[
  {
    "path": ".clang-format",
    "content": "---\nLanguage:        Cpp\nAccessModifierOffset: -1\nAlignAfterOpenBracket: Align\nAlignConsecutiveAssignments: false\nAlignConsecutiveDeclarations: false\nAlignEscapedNewlinesLeft: true\nAlignOperands:   true\nAlignTrailingComments: true\nAllowAllParametersOfDeclarationOnNextLine: true\nAllowShortBlocksOnASingleLine: true\nAllowShortCaseLabelsOnASingleLine: true\nAllowShortFunctionsOnASingleLine: All\nAllowShortIfStatementsOnASingleLine: true\nAllowShortLoopsOnASingleLine: true\nAlwaysBreakAfterDefinitionReturnType: None\nAlwaysBreakAfterReturnType: None\nAlwaysBreakBeforeMultilineStrings: false\nAlwaysBreakTemplateDeclarations: true\nBinPackArguments: true\nBinPackParameters: true\nBraceWrapping:\n  AfterClass:      true\n  AfterControlStatement: false\n  AfterEnum:       false\n  AfterFunction:   false\n  AfterNamespace:  false\n  AfterObjCDeclaration: false\n  AfterStruct:     false\n  AfterUnion:      false\n  BeforeCatch:     false\n  BeforeElse:      false\n  IndentBraces:    false\nBreakBeforeBinaryOperators: NonAssignment\nBreakBeforeBraces: Attach\nBreakBeforeTernaryOperators: true\nBreakConstructorInitializersBeforeComma: false\nBreakAfterJavaFieldAnnotations: false\nBreakStringLiterals: true\nColumnLimit:     100\nCommentPragmas:  '^ IWYU pragma:'\nBreakBeforeInheritanceComma: false\nConstructorInitializerAllOnOneLineOrOnePerLine: true\nConstructorInitializerIndentWidth: 4\nContinuationIndentWidth: 4\nCpp11BracedListStyle: true\nDisableFormat:   false\nExperimentalAutoDetectBinPacking: false\nFixNamespaceComments: true\nForEachMacros:   [ foreach, Q_FOREACH, BOOST_FOREACH ]\nIncludeCategories:\n  - Regex:           '^<.*\\.h>'\n    Priority:        1\n  - Regex:           '^<.*'\n    Priority:        2\n  - Regex:           '.*'\n    Priority:        3\nIncludeIsMainRegex: '([-_](test|unittest))?$'\nIndentCaseLabels: true\nIndentWidth:     2\nIndentWrappedFunctionNames: false\nJavaScriptQuotes: Leave\nJavaScriptWrapImports: true\nKeepEmptyLinesAtTheStartOfBlocks: false\nMacroBlockBegin: ''\nMacroBlockEnd:   ''\nMaxEmptyLinesToKeep: 1\nNamespaceIndentation: None\nObjCBlockIndentWidth: 2\nObjCSpaceAfterProperty: false\nObjCSpaceBeforeProtocolList: false\nPenaltyBreakBeforeFirstCallParameter: 1\nPenaltyBreakComment: 300\nPenaltyBreakFirstLessLess: 120\nPenaltyBreakString: 1000\nPenaltyExcessCharacter: 1000000\nPenaltyReturnTypeOnItsOwnLine: 200\nPointerAlignment: Left\nReflowComments:  true\nSortIncludes:    false\nSpaceAfterCStyleCast: false\nSpaceAfterTemplateKeyword: false\nSpaceBeforeAssignmentOperators: true\nSpaceBeforeParens: ControlStatements\nSpaceInEmptyParentheses: false\nSpacesBeforeTrailingComments: 2\nSpacesInAngles:  false\nSpacesInContainerLiterals: true\nSpacesInCStyleCastParentheses: false\nSpacesInParentheses: false\nSpacesInSquareBrackets: false\nStandard:        Auto\nTabWidth:        8\nUseTab:          Never\n...\n"
  },
  {
    "path": ".clang-tidy",
    "content": "# `maybe-*` checks are only available on OneFlow custom clang-tidy and clangd\n# `-allow-enabling-analyzer-alpha-checkers` should be passed to clang-tidy for CSA checkers named `clang-analyzer-alpha.*` (or `-allow-enabling-alpha-checkers` for run-clang-tidy.py)\n# `aggressive-binary-operation-simplification` should be enabled (via `-Xclang -analyzer-config -Xclang aggressive-binary-operation-simplification=true` in clang)\n# there is some problem in `clang-analyzer-alpha.clone.*`, so do not enable it\n# `clang-analyzer-alpha.deadcode.*` is just too verbose to enable\nChecks: >-\n  -*,\n  clang-diagnostic-*,\n  maybe-*,\n  clang-analyzer-core.*,\n  clang-analyzer-cplusplus.*,\n  clang-analyzer-nullability.*,\n  clang-analyzer-deadcode.*,\n  clang-analyzer-security.*,\n  clang-analyzer-optin.cplusplus.*,\n  clang-analyzer-optin.performance.*,\n  clang-analyzer-alpha.core.*,\n  clang-analyzer-alpha.cplusplus.*,\n  clang-analyzer-alpha.security.*,\n  cppcoreguidelines-avoid-goto,\n  cppcoreguidelines-init-variables,\n  cppcoreguidelines-interfaces-global-init,\n  cppcoreguidelines-no-malloc,\n  cppcoreguidelines-prefer-member-initializer,\n  cppcoreguidelines-pro-type-member-init,\n  cppcoreguidelines-pro-type-static-cast-downcast,\n  cppcoreguidelines-slicing,\n  cppcoreguidelines-special-member-functions,\n  performance-unnecessary-value-param,\n  performance-unnecessary-copy-initialization,\n  performance-noexcept-move-constructor,\n  performance-no-automatic-move,\n  performance-move-const-arg,\n  performance-implicit-conversion-in-loop,\n  performance-for-range-copy,\n  google-default-arguments,\n  google-global-names-in-headers,\n  google-explicit-constructor,\n  modernize-use-emplace\n\n# TODO: treat all maybe warnings as errors when existing warnings are all fixed\n# `clang-analyzer-cplusplus.NewDelete` cannot model reference counting properly for ObjectMsg\nWarningsAsErrors: >-\n  maybe-unused,\n  clang-analyzer-nullability.*,\n  clang-analyzer-cplusplus.*,\n  performance-implicit-conversion-in-loop,\n  performance-move-const-arg,\n  performance-no-automatic-move,\n  performance-noexcept-move-constructor,\n  google-default-arguments,\n  google-global-names-in-headers,\n  -clang-analyzer-cplusplus.NewDelete,\n  modernize-use-emplace\n\nCheckOptions:\n  # `cppcoreguidelines-special-member-functions` is enabled, refer to https://en.cppreference.com/w/cpp/language/rule_of_three\n  - key:             cppcoreguidelines-special-member-functions.AllowSoleDefaultDtor\n    value:           True\n  - key:             performance-move-const-arg.CheckTriviallyCopyableMove\n    value:           False\n  - key:             cppcoreguidelines-special-member-functions.AllowMissingMoveFunctionsWhenCopyIsDeleted\n    value:           True\n"
  },
  {
    "path": ".cmake-format.py",
    "content": "# ----------------------------------\n# Options affecting listfile parsing\n# ----------------------------------\nwith section(\"parse\"):\n\n    # Specify structure for custom cmake functions\n    additional_commands = {\n        \"cc_binary\": {\n            \"flags\": [\"ADD_RUNTARGET\"],\n            \"kwargs\": {\n                \"DEPS\": \"*\",\n                \"INC\": {\n                    \"kwargs\": {\"INTERFACE\": \"*\", \"PRIVATE\": \"*\", \"PUBLIC\": \"*\"},\n                    \"pargs\": 0,\n                },\n                \"LIBDIRS\": {\n                    \"kwargs\": {\"INTERFACE\": \"*\", \"PRIVATE\": \"*\", \"PUBLIC\": \"*\"},\n                    \"pargs\": \"*\",\n                },\n                \"PKGDEPS\": \"*\",\n                \"PROPERTIES\": {\"kwargs\": {\"EXPORT_NAME\": 1, \"OUTPUT_NAME\": 1}},\n                \"SRCS\": \"*\",\n            },\n            \"pargs\": \"1+\",\n        },\n        \"cc_library\": {\n            \"flags\": [\"STATIC\", \"SHARED\"],\n            \"kwargs\": {\n                \"DEPS\": {\n                    \"kwargs\": {\"INTERFACE\": \"*\", \"PRIVATE\": \"*\", \"PUBLIC\": \"*\"},\n                    \"pargs\": \"*\",\n                },\n                \"INC\": {\n                    \"kwargs\": {\"INTERFACE\": \"*\", \"PRIVATE\": \"*\", \"PUBLIC\": \"*\"},\n                    \"pargs\": 0,\n                },\n                \"LIBDIRS\": {\n                    \"kwargs\": {\"INTERFACE\": \"*\", \"PRIVATE\": \"*\", \"PUBLIC\": \"*\"},\n                    \"pargs\": \"*\",\n                },\n                \"PKGDEPS\": \"*\",\n                \"PROPERTIES\": {\n                    \"kwargs\": {\n                        \"ARCHIVE_OUTPUT_NAME\": 1,\n                        \"EXPORT_NAME\": 1,\n                        \"INTERFACE_INCLUDE_DIRECTORIES\": 1,\n                        \"LIBRARY_OUTPUT_NAME\": 1,\n                        \"OUTPUT_NAME\": 1,\n                        \"SOVERSION\": 1,\n                        \"SUFFIX\": 1,\n                        \"VERSION\": 1,\n                    }\n                },\n                \"SRCS\": \"*\",\n            },\n            \"pargs\": \"1+\",\n        },\n        \"cc_test\": {\n            \"kwargs\": {\n                \"ARGV\": \"*\",\n                \"DEPS\": \"*\",\n                \"LABELS\": \"*\",\n                \"PKGDEPS\": \"*\",\n                \"SRCS\": \"*\",\n                \"TEST_DEPS\": \"*\",\n                \"WORKING_DIRECTORY\": \"*\",\n            },\n            \"pargs\": 1,\n        },\n        \"check_call\": {\n            \"flags\": [\n                \"OUTPUT_QUIET\",\n                \"ERROR_QUIET\",\n                \"OUTPUT_STRIP_TRAILING_WHITESPACE\",\n                \"ERROR_STRIP_TRAILING_WHITESPACE\",\n            ],\n            \"kwargs\": {\n                \"COMMAND\": \"*\",\n                \"ENCODING\": \"1\",\n                \"ERROR_FILE\": \"1\",\n                \"ERROR_VARIABLE\": \"1\",\n                \"INPUT_FILE\": \"1\",\n                \"OUTPUT_FILE\": \"1\",\n                \"OUTPUT_VARIABLE\": \"1\",\n                \"RESULTS_VARIABLE\": \"1\",\n                \"RESULT_VARIABLE\": \"1\",\n                \"TIMEOUT\": \"1\",\n                \"WORKING_DIRECTORY\": \"1\",\n            },\n        },\n        \"check_pyoneline\": {\n            \"kwargs\": {\"ERROR_VARIABLE\": 1, \"OUTPUT_VARIABLE\": 1},\n            \"pargs\": \"+\",\n        },\n        
\"create_debian_binary_packages\": {\n            \"kwargs\": {\"DEPS\": \"*\", \"OUTPUTS\": \"*\"},\n            \"pargs\": [3, \"+\"],\n        },\n        \"create_debian_depsrepo\": {\"pargs\": [3, \"+\"]},\n        \"create_debian_packages\": {\n            \"kwargs\": {\"DEPS\": \"*\", \"OUTPUTS\": \"*\"},\n            \"pargs\": [{\"flags\": [\"FORCE_PBUILDER\"], \"nargs\": \"+\"}],\n        },\n        \"debhelp\": {\"pargs\": [\"1+\"], \"spelling\": \"DEBHELP\"},\n        \"exportvars\": {\n            \"kwargs\": {\"VARS\": \"+\"},\n            \"pargs\": \"1+\",\n            \"spelling\": \"EXPORTVARS\",\n        },\n        \"format_and_lint\": {\n            \"kwargs\": {\"CC\": \"*\", \"CMAKE\": \"*\", \"JS\": \"*\", \"PY\": \"*\", \"SHELL\": \"*\"}\n        },\n        \"get_debs\": {\"pargs\": [3, \"*\"]},\n        \"gresource\": {\"kwargs\": {\"DEPENDS\": \"+\", \"SRCDIR\": 1}, \"pargs\": 2},\n        \"gtk_doc_add_module\": {\n            \"kwargs\": {\n                \"FIXREFOPTS\": \"*\",\n                \"IGNOREHEADERS\": \"*\",\n                \"LIBRARIES\": \"*\",\n                \"LIBRARY_DIRS\": \"*\",\n                \"SOURCE\": \"*\",\n                \"SUFFIXES\": \"*\",\n                \"XML\": 1,\n            },\n            \"pargs\": 1,\n        },\n        \"importvars\": {\n            \"kwargs\": {\"VARS\": \"+\"},\n            \"pargs\": \"1+\",\n            \"spelling\": \"IMPORTVARS\",\n        },\n        \"join\": {\"kwargs\": {\"GLUE\": 1}, \"pargs\": [1, \"+\"]},\n        \"pkg_find\": {\"kwargs\": {\"PKG\": \"*\"}},\n        \"stage_files\": {\n            \"kwargs\": {\"FILES\": \"*\", \"LIST\": 1, \"SOURCEDIR\": 1, \"STAGE\": 1}\n        },\n        \"tangent_addtest\": {\n            \"kwargs\": {\n                \"COMMAND\": \"+\",\n                \"CONFIGURATIONS\": \"+\",\n                \"DEPENDS\": \"+\",\n                \"LABELS\": \"+\",\n                \"NAME\": 1,\n                \"WORKING_DIRECTORY\": 1,\n            }\n        },\n        \"tangent_extract_svg\": {\"kwargs\": {\"EXPORT\": 1, \"OUTPUT\": 1, \"SRC\": 1}},\n        \"tangent_fetchobj\": {\"kwargs\": {\"OUTDIR\": 1}, \"pargs\": 2},\n        \"tangent_rmark_render\": {\n            \"kwargs\": {\"DEPENDS\": 1, \"FORMAT\": 1, \"OUTPUT\": 1, \"PAGENO\": 1, \"UUID\": 1},\n            \"pargs\": 1,\n        },\n        \"tangent_unzip\": {\n            \"kwargs\": {\"OUTPUT\": \"1+\", \"WORKING_DIRECTORY\": 1},\n            \"pargs\": \"1+\",\n        },\n        \"travis_decrypt\": {\"kwargs\": {}, \"pargs\": [3]},\n    }\n\n    # Override configurations per-command where available\n    override_spec = {}\n\n    # Specify variable tags.\n    vartags = []\n\n    # Specify property tags.\n    proptags = []\n\n# -----------------------------\n# Options affecting formatting.\n# -----------------------------\nwith section(\"format\"):\n\n    # Disable formatting entirely, making cmake-format a no-op\n    disable = False\n\n    # How wide to allow formatted cmake files\n    line_width = 100\n\n    # How many spaces to tab for indent\n    tab_size = 2\n\n    # If true, lines are indented using tab characters (utf-8 0x09) instead of\n    # <tab_size> space characters (utf-8 0x20). 
In cases where the layout would\n    # require a fractional tab character, the behavior of the fractional\n    # indentation is governed by <fractional_tab_policy>\n    use_tabchars = False\n\n    # If <use_tabchars> is True, then the value of this variable indicates how\n    # fractional indentations are handled during whitespace replacement. If set to\n    # 'use-space', fractional indentation is left as spaces (utf-8 0x20). If set\n    # to `round-up` fractional indentation is replaced with a single tab character\n    # (utf-8 0x09) effectively shifting the column to the next tabstop\n    fractional_tab_policy = \"use-space\"\n\n    # If an argument group contains more than this many sub-groups (parg or kwarg\n    # groups) then force it to a vertical layout.\n    max_subgroups_hwrap = 3\n\n    # If a positional argument group contains more than this many arguments, then\n    # force it to a vertical layout.\n    max_pargs_hwrap = 6\n\n    # If a cmdline positional group consumes more than this many lines without\n    # nesting, then invalidate the layout (and nest)\n    max_rows_cmdline = 3\n\n    # If true, separate flow control names from their parentheses with a space\n    separate_ctrl_name_with_space = False\n\n    # If true, separate function names from parentheses with a space\n    separate_fn_name_with_space = False\n\n    # If a statement is wrapped to more than one line, then dangle the closing\n    # parenthesis on its own line.\n    dangle_parens = False\n\n    # If the trailing parenthesis must be 'dangled' on its own line, then align it\n    # to this reference: `prefix`: the start of the statement, `prefix-indent`:\n    # the start of the statement, plus one indentation level, `child`: align to\n    # the column of the arguments\n    dangle_align = \"prefix\"\n\n    # If the statement spelling length (including space and parenthesis) is\n    # smaller than this amount, then force reject nested layouts.\n    min_prefix_chars = 4\n\n    # If the statement spelling length (including space and parenthesis) is larger\n    # than the tab width by more than this amount, then force reject un-nested\n    # layouts.\n    max_prefix_chars = 10\n\n    # If a candidate layout is wrapped horizontally but it exceeds this many\n    # lines, then reject the layout.\n    max_lines_hwrap = 2\n\n    # What style line endings to use in the output.\n    line_ending = \"unix\"\n\n    # Format command names consistently as 'lower' or 'upper' case\n    command_case = \"canonical\"\n\n    # Format keywords consistently as 'lower' or 'upper' case\n    keyword_case = \"unchanged\"\n\n    # A list of command names which should always be wrapped\n    always_wrap = []\n\n    # If true, the argument lists which are known to be sortable will be sorted\n    # lexicographically\n    enable_sort = True\n\n    # If true, the parsers may infer whether or not an argument list is sortable\n    # (without annotation).\n    autosort = False\n\n    # By default, if cmake-format cannot successfully fit everything into the\n    # desired linewidth it will apply the last, most aggressive attempt that it\n    # made. If this flag is True, however, cmake-format will print an error, exit\n    # with a non-zero status code, and write out nothing\n    require_valid_layout = False\n\n    # A dictionary mapping layout nodes to a list of wrap decisions. 
See the\n    # documentation for more information.\n    layout_passes = {}\n\n# ------------------------------------------------\n# Options affecting comment reflow and formatting.\n# ------------------------------------------------\nwith section(\"markup\"):\n\n    # What character to use for bulleted lists\n    bullet_char = \"*\"\n\n    # What character to use as punctuation after numerals in an enumerated list\n    enum_char = \".\"\n\n    # If comment markup is enabled, don't reflow the first comment block in each\n    # listfile. Use this to preserve formatting of your copyright/license\n    # statements.\n    first_comment_is_literal = False\n\n    # If comment markup is enabled, don't reflow any comment block which matches\n    # this (regex) pattern. Default is `None` (disabled).\n    literal_comment_pattern = None\n\n    # Regular expression to match preformat fences in comments default=\n    # ``r'^\\s*([`~]{3}[`~]*)(.*)$'``\n    fence_pattern = \"^\\\\s*([`~]{3}[`~]*)(.*)$\"\n\n    # Regular expression to match rulers in comments default=\n    # ``r'^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$'``\n    ruler_pattern = \"^\\\\s*[^\\\\w\\\\s]{3}.*[^\\\\w\\\\s]{3}$\"\n\n    # If a comment line starts with this pattern then it is explicitly a\n    # trailing comment for the preceding argument. Default is '#<'\n    explicit_trailing_pattern = \"#<\"\n\n    # If a comment line starts with at least this many consecutive hash\n    # characters, then don't lstrip() them off. This allows for lazy hash rulers\n    # where the first hash char is not separated by space\n    hashruler_min_length = 10\n\n    # If true, then insert a space between the first hash char and remaining hash\n    # chars in a hash ruler, and normalize its length to fill the column\n    canonicalize_hashrulers = True\n\n    # enable comment markup parsing and reflow\n    enable_markup = False\n\n# ----------------------------\n# Options affecting the linter\n# ----------------------------\nwith section(\"lint\"):\n\n    # a list of lint codes to disable\n    disabled_codes = [\"C0113\"]\n\n    # regular expression pattern describing valid function names\n    function_pattern = \"[0-9a-z_]+\"\n\n    # regular expression pattern describing valid macro names\n    macro_pattern = \"[0-9A-Z_]+\"\n\n    # regular expression pattern describing valid names for variables with global\n    # (cache) scope\n    global_var_pattern = \"[A-Z][0-9A-Z_]+\"\n\n    # regular expression pattern describing valid names for variables with global\n    # scope (but internal semantic)\n    internal_var_pattern = \"_[A-Z][0-9A-Z_]+\"\n\n    # regular expression pattern describing valid names for variables with local\n    # scope\n    local_var_pattern = \"[a-z][a-z0-9_]+\"\n\n    # regular expression pattern describing valid names for private directory\n    # variables\n    private_var_pattern = \"_[0-9a-z_]+\"\n\n    # regular expression pattern describing valid names for public directory\n    # variables\n    public_var_pattern = \"[A-Z][0-9A-Z_]+\"\n\n    # regular expression pattern describing valid names for function/macro\n    # arguments and loop variables.\n    argument_var_pattern = \"[a-z][a-z0-9_]+\"\n\n    # regular expression pattern describing valid names for keywords used in\n    # functions or macros\n    keyword_pattern = \"[A-Z][0-9A-Z_]+\"\n\n    # In the heuristic for C0201, how many conditionals to match within a loop\n    # before considering the loop a parser.\n    max_conditionals_custom_parser = 2\n\n    # Require at 
least this many newlines between statements\n    min_statement_spacing = 1\n\n    # Require no more than this many newlines between statements\n    max_statement_spacing = 2\n    max_returns = 6\n    max_branches = 12\n    max_arguments = 5\n    max_localvars = 15\n    max_statements = 50\n\n# -------------------------------\n# Options affecting file encoding\n# -------------------------------\nwith section(\"encode\"):\n\n    # If true, emit the unicode byte-order mark (BOM) at the start of the file\n    emit_byteorder_mark = False\n\n    # Specify the encoding of the input file. Defaults to utf-8\n    input_encoding = \"utf-8\"\n\n    # Specify the encoding of the output file. Defaults to utf-8. Note that cmake\n    # only claims to support utf-8 so be careful when using anything else\n    output_encoding = \"utf-8\"\n\n# -------------------------------------\n# Miscellaneous configuration options.\n# -------------------------------------\nwith section(\"misc\"):\n\n    # A dictionary containing any per-command configuration overrides. Currently\n    # only `command_case` is supported.\n    per_command = {}\n"
  },
  {
    "path": ".devcontainer/Dockerfile",
    "content": "# See here for image contents: https://github.com/Oneflow-Inc/docker-images/blob/main/oneflow/Dockerfile\n# [Choice] llvm12 llvm13 cuda11.1\nARG VARIANT=\"llvm13\"\nARG REPO=\"oneflowinc/devcontainer\"\nFROM ${REPO}:${VARIANT}\n"
  },
  {
    "path": ".devcontainer/devcontainer.json",
    "content": "// For format details, see https://aka.ms/devcontainer.json. For config options, see the README at:\n// https://github.com/microsoft/vscode-dev-containers/tree/v0.209.6/containers/cpp\n// workaround for EACCES: permission denied, mkdir '/tmp/vsch.....\n// https://github.com/microsoft/vscode-remote-release/issues/2347\n// sudo chmod 777 /tmp/vsch/container-features\n{\n\t\"name\": \"oneflow-devel\",\n\t\"image\": \"oneflowinc/manylinux2014_x86_64_cuda11.2\",\n\t\"runArgs\": [\n\t\t\"--cap-add=SYS_PTRACE\",\n\t\t\"--privileged\",\n\t\t\"--shm-size=8g\",\n\t\t\"--security-opt\",\n\t\t\"seccomp=unconfined\",\n\t\t\"--network=host\",\n\t\t// \"--gpus\",\n\t\t// \"all\",\n\t],\n\t\"remoteEnv\": {\n\t\t\"PATH\": \"${containerEnv:PATH}:/opt/python/cp37-cp37m/bin\",\n\t\t\"ONEFLOW_CI_PYTHON_EXE\": \"/opt/python/cp37-cp37m/bin/python3\",\n\t\t\"ONEFLOW_CI_SRC_DIR\": \"${containerWorkspaceFolder}\",\n\t\t\"ONEFLOW_CI_BUILD_DIR\": \"${containerWorkspaceFolder}/build\",\n\t\t\"ONEFLOW_CI_CMAKE_INIT_CACHE\": \"${containerWorkspaceFolder}/cmake/caches/ci/cuda.cmake\",\n\t\t\"ONEFLOW_CI_BUILD_PARALLEL\": \"20\"\n\t},\n\t\"initializeCommand\": \"mkdir -p ${localWorkspaceFolder}/devcontainer-cache/dot/ccache && mkdir -p ${localWorkspaceFolder}/devcontainer-cache/dot/local && mkdir -p ${localWorkspaceFolder}/devcontainer-cache/dot/cache\",\n\t\"mounts\": [\n\t\t\"source=${localWorkspaceFolder}/devcontainer-cache/dot/ccache,target=/root/.ccache,type=bind,consistency=cached\",\n\t\t\"source=${localWorkspaceFolder}/devcontainer-cache/dot/local,target=/root/.local,type=bind,consistency=cached\",\n\t\t\"source=${localWorkspaceFolder}/devcontainer-cache/dot/cache,target=/root/.cache,type=bind,consistency=cached\",\n\t\t\"source=/dataset,target=/dataset,type=bind,consistency=cached,readonly\",\n\t\t\"source=/model_zoo,target=/model_zoo,type=bind,consistency=cached,readonly\",\n\t],\n\t// Set *default* container specific settings.json values on container create.\n\t\"settings\": {\n\t\t\"files.insertFinalNewline\": true,\n\t\t\"files.trimFinalNewlines\": true,\n\t\t\"files.trimTrailingWhitespace\": true,\n\t\t\"files.eol\": \"\\n\",\n\t\t\"clangd.arguments\": [\n\t\t\t\"-j\",\n\t\t\t\"8\",\n\t\t\t\"-header-insertion=never\"\n\t\t],\n\t},\n\t// Add the IDs of extensions you want installed when the container is created.\n\t\"extensions\": [\n\t\t\"llvm-vs-code-extensions.vscode-clangd\",\n\t\t\"ms-vscode.cmake-tools\",\n\t\t\"ms-python.python\"\n\t],\n\t// Comment out connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root.\n\t\"remoteUser\": \"root\",\n}\n"
  },
  {
    "path": ".dockerignore",
    "content": "**/.git\n/build\n/build-*\n/docs/build\n/cmake-build-*\n/third_party\n/examples/**/oneflow\n/benchmark/**/oneflow\n/.vscode\n/.idea\n/.clangd\n/dist\n/wheelhouse*\n/.DS_Store\n/tmp_wheel\n/manylinux*\n\n**/__pycache__\n**/*.pyc\n**/log\n**/.ipynb_checkpoints\n**/core.0*\n**/core.1*\n**/core.2*\n**/core.3*\n**/core.4*\n**/core.5*\n**/core.6*\n**/core.7*\n**/core.8*\n**/core.9*\n/.cache\n/oneflow-src.zip\n/distributed-tmp\n/serving-tmp\n"
  },
  {
    "path": ".github/CODEOWNERS",
    "content": "*.cu @liujuncheng\n*.py @BBuf @daquexian\n/oneflow/core/cuda @liujuncheng\n/oneflow/core/eager @daquexian\n/oneflow/core/framework @chengtbf @strint\n/oneflow/core/functional @hjchen2\n/oneflow/core/graph @chengtbf\n/oneflow/core/ndarray @daquexian\n/oneflow/core/object_msg @daquexian\n/oneflow/core/platform @jackalcooper\n/oneflow/core/ep @liujuncheng\n/oneflow/core/rpc @jackalcooper\n/oneflow/core/stream @liujuncheng\n/oneflow/core/hardware @liujuncheng\n/oneflow/core/transport @chengtbf\n/oneflow/core/vm @daquexian\n/oneflow/xrt @hjchen2\n/oneflow/ir @hjchen2 @BBuf @jackalcooper\n/ci @jackalcooper\n/python/oneflow/test_utils @daquexian @BBuf\n/cmake @daquexian @jackalcooper\nCMakeLists.txt @daquexian @jackalcooper\n/.github @jackalcooper\n/tools @jackalcooper\n/docs @doombeaker\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/blank_issue.yml",
    "content": "name: Blank Issue\ndescription: Submit an issue about OneFlow.\nlabels: [Blank Issue]\nbody:\n  - type: textarea\n    id: description\n    attributes:\n      label: Description\n      description: Please describe the issue here.\n      placeholder: Description\n    validations:\n      required: false\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: bug, community\nassignees: ''\n\n---\n\n## Summary\n\nA short description about the bug/issue\n\n## Code to reproduce bug\n\nPlease post a minimal example to repro the bug. GitHub Gist or repo is highly recommended.\n\n## System Information\n\n- What is your OneFlow installation (pip, source, dockerhub):\n- OS:\n- OneFlow version (run `python3 -m oneflow --doctor`):\n- Python version:\n- CUDA driver version:\n- GPU models:\n- Other info:\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/documention_issue.yml",
    "content": "name: Documentation Issue\ndescription: Report an issue about OneFlow ducumention or require a documention.\ntitle: \"[Documention Issue]: \"\nlabels: [Documention Issue]\nbody:\n  - type: markdown\n    attributes:\n      value: |\n        Welcome to suggest to OneFlow documention! This template will help us gather the information we need to improve it.\n  - type: textarea\n    id: brief-description\n    attributes:\n      label: Brief Description\n      description: Please describe the problem or the requst for new documention here.\n      placeholder: Description\n    validations:\n      required: true\n  - type: textarea\n    id: alternatives\n    attributes:\n      label: Alternatives\n      description: |\n        Please provide some alternative information here, if any.\n      placeholder: Alternatives\n    validations:\n      required: false\n  - type: markdown\n    attributes:\n      value: |\n        Thanks for your contributing!\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.yml",
    "content": "name: Feature Request\ndescription: Request/Propose a new OneFlow feature.\ntitle: \"[Feature Request]: \"\nlabels: [feature-request]\nbody:\n  - type: markdown\n    attributes:\n      value: |\n        We welcome feature proposal/request! This template will help us gather the information we need to review the proposal/request.\n  - type: textarea\n    id: background\n    attributes:\n      label: Background and motivation\n      description: Please describe the purpose and value of the new feature here. If the feature is linked to a specific problem, please describe it or put the link here.\n      placeholder: Purpose\n    validations:\n      required: true\n  - type: textarea\n    id: api-proposal\n    attributes:\n      label: API Proposal\n      description: |\n        Please provide the specific public API signature diff that you are proposing. If a new API is not required, please provide the current API related to the feature, or note that there is no related public API.\n      placeholder: API declaration (no method bodies)\n      value: |\n        ```py\n        def new_api(value: Tensor) -> Tensor:\n          pass\n        ```\n    validations:\n      required: true\n  - type: textarea\n    id: api-usage\n    attributes:\n      label: API Usage\n      description: |\n        Please provide code examples that highlight how the proposed API additions are meant to be consumed. This will help suggest whether the API has the right shape to be functional, performant and usable.\n        If there is not a new API in step 2, please skip it.\n      placeholder: API usage\n    validations:\n      required: false\n  - type: textarea\n    id: alternatives\n    attributes:\n      label: Alternatives\n      description: |\n        Please provide some alternative information of the feature, if any. For example, if you request a feature which depends on a specific device, please provide the device information.\n      placeholder: Alternatives\n    validations:\n      required: false\n  - type: textarea\n    id: risks\n    attributes:\n      label: Risks\n      description: |\n        Please mention any risks that to your knowledge the API proposal might entail, such as breaking changes, performance regressions, etc.\n      placeholder: Risks\n    validations:\n      required: false\n  - type: markdown\n    attributes:\n      value: |\n        Thanks for your contributing!\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/performance_issue.yml",
    "content": "name: Performance Issue\ndescription: Submit an issue about performance problem or regression of OneFlow.\ntitle: \"[Performance Issue]: \"\nlabels: [Performance Issue]\nbody:\n  - type: markdown\n    attributes:\n      value: |\n        We welcome issues about OneFlow performance! This template will help us gather the information we need to locate the problem improve the performance.\n  - type: textarea\n    id: brief-description\n    attributes:\n      label: Brief Description\n      description: Please give a brief description about the performance issue here.\n      placeholder: Description\n    validations:\n      required: true\n  - type: textarea\n    id: device-and-context\n    attributes:\n      label: Device and Context\n      description: |\n        Please describe the device and context you used when you encounter the performance problem/regression.\n      placeholder: Device and Context\n    validations:\n      required: true\n  - type: textarea\n    id: benchmark\n    attributes:\n      label: Benchmark\n      description: |\n        We will appreciate it if you'd like to provide benchmark comparison of the performance issue.\n      placeholder: Benchmark\n    validations:\n      required: false\n  - type: textarea\n    id: alternatives\n    attributes:\n      label: Alternatives\n      description: |\n        Please provide some alternative information of the performance issue here, if any.\n      placeholder: Alternatives\n    validations:\n      required: false\n  - type: markdown\n    attributes:\n      value: |\n        Thanks for your contributing!\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/question.yml",
    "content": "name: Question\ndescription: Ask a question about OneFlow and discuss with community members.\ntitle: \"[Question]: \"\nlabels: [Question]\nbody:\n  - type: markdown\n    attributes:\n      value: |\n        Welcome to ask questions about OneFlow! This template will help us get your point.\n  - type: textarea\n    id: description\n    attributes:\n      label: Description\n      description: Please describe your question here.\n      placeholder: Description\n    validations:\n      required: true\n  - type: textarea\n    id: alternatives\n    attributes:\n      label: Alternatives\n      description: |\n        Please provide some alternative information here, if any.\n      placeholder: Alternatives\n    validations:\n      required: false\n  - type: markdown\n    attributes:\n      value: |\n        We are always willing to answer your questions!\n"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE/general_template.md",
    "content": "## 概述\n\n\n## PR Checklist\n - [ ] PR 标题语句通畅，明确表达 PR 内容，适合直接作为新版本发布时的 changelog\n - [ ] 代码格式化\n - [ ] 已经本地编译通过\n - [ ] 已本地针对改动测试\n - [ ] 已添加 type 标签:(填写 type 标签名，如 `bug, enhancement, purge, feature, documentation`)\n - [ ] 已添加 component 标签:(填写 component 标签名，如 `op, system, eager, build, xla, python, ci, test, tooling`)\n - [ ] Draft 转正式 PR 前已请人 Review\n"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE/op_template.md",
    "content": "## 概述\n描述 op 的功能、公式等。若参考了其它框架的接口，应列出超链接。\n\n## 功能 CheckList\n**注意** : 功能复选框均为可选项，若未选择，说明理由即可。例如：该 Op 由 Python 接口拼接而成，因此无 `SetBatchAxisInferFn` Op 注册；再比如：该 Op 无输入，因此无 `SetInputArgModifyFn`。\n\n模板中自带的复选框可留空，但是不能删除。可根据实际情况增加复选框选项。\n\n### Op\n - [ ] Op SetBatchAxisInferFn\n - [ ] Op SetGetSbpFn\n - [ ] Op SetInputArgModifyFn\n - [ ] Op 反向梯度注册\n\n### Kernel\n - [ ] CPU in:float32\n - [ ] CPU in:float64\n - [ ] CPU in:int32\n - [ ] CPU in:int64\n - [ ] CPU in:int8\n\n - [ ] GPU in:float32\n - [ ] GPU in:float64\n - [ ] GPU in:int32\n - [ ] GPU in:int64\n - [ ] GPU in:float16\n - [ ] GPU in:int8\n\n\n### Python Wrapper\n - [ ] Python API 参数检查及异常提示\n - [ ] 接口注释\n - [ ] Example \n\n### 测试\n - [ ] 单机单卡  CPU Test Case\n - [ ] 单机单卡  GPU Test Case\n - [ ] 单机多卡  CPU Test Case\n - [ ] 单机多卡  GPU Test Case\n - [ ] 分布式  CPU Test Case\n - [ ] 分布式  GPU Test Case\n\n## GPU 有效带宽\n带 GPU 的 Op，请参考 https://github.com/Oneflow-Inc/OneTeam/issues/167 测试有效带宽，并附带测试报告。\n以下是报告样例：\n\n理论带宽：\n```text\n Device to Device Bandwidth, 1 Device(s)\n PINNED Memory Transfers\n   Transfer Size (Bytes)\tBandwidth(MB/s)\n   33554432\t\t\t250798.5\n```\n\n实际带宽：\n```\nPROFILER::KERNEL::CUDA_MEMORY_BANDWIDTH op_name: sqrt_2 elapsed(ms): 0.196064 memory_size(Byte): 50331648 bandwidth(GB/s): 239.08\nPROFILER::KERNEL::CUDA_MEMORY_BANDWIDTH op_name: sqrt_2_grad elapsed(ms): 0.29072 memory_size(Byte): 75497472 bandwidth(GB/s): 241.856\n```\n\n\n## PR Checklist\n - [ ] PR 标题语句通畅，明确表达 PR 内容，适合直接作为新版本发布时的 changelog\n - [ ] 代码格式化\n - [ ] 已经本地编译通过\n - [ ] 已本地针对改动测试\n - [ ] 已添加 type 标签:(填写 type 标签名，如 `bug, enhancement, purge, feature, documentation`)\n - [ ] 已添加 component 标签:(填写 component 标签名，如 `op, system, eager, build, xla, python, ci, test, tooling`)\n - [ ] Draft 转正式 PR 前已请人 Review\n"
  },
  {
    "path": ".github/actions/mac-build/action.yml",
    "content": "name: \"Build OneFlow on macOS\"\ndescription: \"\"\nruns:\n  using: \"composite\"\n  steps:\n    - name: Install dependencies\n      run: |\n        brew install nasm\n      shell: bash\n    - name: Set environment variables\n      run: |\n        set -x\n        cmake_flags=\"\"\n        cmake_flags+=\" -DPython3_EXECUTABLE=$(which python3)\"\n        cmake_flags+=\" -DRPC_BACKEND=LOCAL\"\n        cmake_flags+=\" -DCMAKE_BUILD_TYPE=Release\"\n        cmake_flags+=\" -DBUILD_CUDA=OFF\"\n        echo \"cmake_flags=${cmake_flags}\" >> $GITHUB_ENV\n      shell: bash\n    - name: Build (third party)\n      run: |\n        mkdir -p build\n        cd build\n        cmake .. $cmake_flags -DTHIRD_PARTY=ON -DONEFLOW=OFF\n        make -j $(nproc)\n      shell: bash\n    - name: Build (oneflow)\n      run: |\n        mkdir -p build\n        cd build\n        cmake .. $cmake_flags -DTHIRD_PARTY=OFF -DONEFLOW=ON\n        make -j 2 oneflow\n      shell: bash\n    - name: Build (oneflow_internal)\n      run: |\n        mkdir -p build\n        cd build\n        cmake .. $cmake_flags -DTHIRD_PARTY=OFF -DONEFLOW=ON\n        make -j 2 oneflow_internal\n      shell: bash\n    - name: Build (generate_api)\n      run: |\n        mkdir -p build\n        cd build\n        cmake .. $cmake_flags -DTHIRD_PARTY=OFF -DONEFLOW=ON\n        make -j 2 generate_api\n      shell: bash\n"
  },
  {
    "path": ".github/actions/setup/action.yml",
    "content": "inputs:\n  name:\n    description: 'Placeholder'\n    default: 'Placeholder'\nruns:\n  using: \"composite\"\n  steps:\n    - run: |\n        echo $HOSTNAME\n        rm -rf build/third_party\n        bash ci/setup_submodule.sh\n        auth_header=\"$(git config --local --get http.https://github.com/.extraheader)\"\n        git -c \"http.extraheader=$auth_header\" -c protocol.version=2 submodule update --init --recursive\n      shell: bash\n"
  },
  {
    "path": ".github/actions/upload_oss/action.yml",
    "content": "inputs:\n  src_path:\n    required: true\n  oss_dst_path:\n    required: true\n  oss_access_key_id:\n    required: true\n  oss_access_key_secret:\n    required: true\n  upload_core:\n    required: false\nruns:\n  using: \"composite\"\n  steps:\n    - run: |\n        if [ -z \"$OSS_ACCESS_KEY_ID\" ]\n        then\n          exit 0\n        fi\n        if [ ! -f \"$HOME/ossutil64\" ]; then\n          curl http://gosspublic.alicdn.com/ossutil/1.7.15/ossutil64 -o $HOME/ossutil64\n        fi\n        chmod 755 $HOME/ossutil64\n        $HOME/ossutil64 config -e oss-cn-beijing.aliyuncs.com -i ${{ inputs.oss_access_key_id }} -k ${{ inputs.oss_access_key_secret }}  -L EN -c $HOME/.ossutilconfig\n        dir_arg=\"\"\n        if [ -d \"${{ inputs.src_path }}\" ]; then\n          dir_arg=\"--recursive\"\n        fi\n        upload_core_arg=\"\"\n        if [ \"${{ inputs.upload_core }}\" == \"true\" ]; then\n            echo \"will upload core files\"\n        else\n            upload_core_arg+='--exclude \"core*\"'\n        fi\n        set -x\n        $HOME/ossutil64 cp --disable-ignore-error --update ${dir_arg} ${upload_core_arg} ${{ inputs.src_path }} ${{ inputs.oss_dst_path }}\n      shell: bash\n      env:\n        OSS_ACCESS_KEY_ID: ${{ inputs.oss_access_key_id }}\n        OSS_ACCESS_KEY_SECRET: ${{ inputs.oss_access_key_secret }}\n"
  },
  {
    "path": ".github/actions/upload_ssh/action.yml",
    "content": "name: \"Upload via ssh\"\ndescription: \"\"\ninputs:\n  src_path:\n    required: true\n    description: \"\"\n  dst_host:\n    required: true\n    description: \"\"\n  dst_path:\n    required: true\n    description: \"\"\nruns:\n  using: \"composite\"\n  steps:\n    - run: |\n        set -x\n        dir_arg=\"\"\n        if [ -d \"${{ inputs.src_path }}\" ]; then\n          dir_arg=\"-r\"\n        fi\n        parent_dir=$(dirname ${{ inputs.dst_path }})\n        ssh -o StrictHostKeyChecking=no ${{ inputs.dst_host }} mkdir -p $parent_dir\n        ssh ${{ inputs.dst_host }} rm -rf ${{ inputs.dst_path }}\n        scp ${dir_arg} ${{ inputs.src_path }} ${{ inputs.dst_host }}:${{ inputs.dst_path }}\n      shell: bash\n"
  },
  {
    "path": ".github/actions/whl/action.yml",
    "content": "inputs:\n  tmp_dir:\n    description: \"tmp dir\"\n    required: true\n  cuda_version:\n    description: \"cuda_version\"\n    default: \"10.2\"\n  python_version:\n    description: \"python_version\"\n    default: \"3.8\"\n  extra_flags:\n    description: \"flags like --xla\"\n    default: \"\"\n  extra_docker_args:\n    description: \"\"\n    default: \"\"\nruns:\n  using: \"composite\"\n  steps:\n    - run: |\n        set -x\n        src_dir=${PWD}\n        tmp_dir=\"${{ inputs.tmp_dir }}\"\n        mkdir -p ${tmp_dir}\n        cd ${tmp_dir}\n        docker run --rm -v $PWD:/p -w $PWD:/p busybox rm -rf /p/wheelhouse\n        python3 ${src_dir}/docker/package/manylinux/build_wheel.py \\\n            --cuda_version=${{ inputs.cuda_version }} \\\n            --python_version=${{ inputs.python_version }} \\\n            --use_tuna --use_system_proxy --use_aliyun_mirror \\\n            --wheel_house_dir=${tmp_dir}/wheelhouse \\\n            --oneflow_src_dir=${src_dir} ${{ inputs.extra_flags }} \\\n            --retry=1 \\\n            --extra_docker_args \"${extra_docker_args}\"\n      shell: bash\n"
  },
  {
    "path": ".github/scripts/requirements.txt",
    "content": "PyYAML>=5.1\nparsec\n"
  },
  {
    "path": ".github/scripts/set_initial_variables.py",
    "content": "import json\n\n\ndef create_one(name=None, allow_fail=None):\n    return {\n        \"test_suite\": name,\n        \"cuda_version\": \"N/A\",\n        \"extra_flags\": \"N/A\",\n        \"os\": [\"self-hosted\", \"linux\", \"build\"],\n        \"allow_fail\": allow_fail,\n        \"python_version\": \"N/A\",\n    }\n\n\ndef create_conda(name=None):\n    return create_one(name=name, allow_fail=False)\n\n\ndef print_github_action_output(name=None, value=None):\n    print(f\"::set-output name={name}::{value}\")\n\n\ndef print_result(build_matrix=None, test_matrix=None, out=None):\n    check_include(include_key=\"test_suite\", matrix=build_matrix)\n    if test_matrix != {}:\n        check_include(include_key=\"test_suite\", matrix=test_matrix)\n    assert build_matrix\n    assert test_matrix != None\n    root = {\n        \"build_matrix\": build_matrix,\n        \"test_matrix\": test_matrix,\n    }\n    for k, v in root.items():\n        print_github_action_output(\n            name=k, value=json.dumps(v),\n        )\n    if out:\n        with open(out, \"w+\") as f:\n            f.write(json.dumps(root, indent=4))\n\n\ndef check_include(include_key=None, matrix: dict = None):\n    assert include_key in matrix\n    in_declare = set(matrix[include_key])\n    in_include = set()\n    for include_value in matrix[\"include\"]:\n        in_include.add(include_value[include_key])\n    assert in_declare == in_include, {\n        \"in_declare\": in_declare,\n        \"in_include\": in_include,\n    }\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--labels\", type=lambda x: (str(x).replace(\" \", \"\").split(\",\")), required=True,\n    )\n    parser.add_argument(\"--out\", type=str, required=False)\n    args = parser.parse_args()\n    if \"need-clang-only\" in args.labels:\n        print_result(\n            build_matrix={\n                \"test_suite\": [\"cpu-clang\"],\n                \"include\": [create_conda(\"cpu-clang\")],\n            },\n            test_matrix={},\n            out=args.out,\n        )\n    else:\n        full_build_matrix = {\n            \"test_suite\": [\"cuda\", \"cpu\", \"xla\", \"xla_cpu\", \"cpu-clang\"],\n            \"include\": [\n                {\n                    \"test_suite\": \"cuda\",\n                    \"cuda_version\": 10.2,\n                    \"extra_flags\": \"--extra_oneflow_cmake_args=-DCUDA_ARCHITECTURES=61 --extra_oneflow_cmake_args=-DRPC_BACKEND=GRPC,LOCAL --extra_oneflow_cmake_args=-DPIP_INDEX_MIRROR=https://pypi.tuna.tsinghua.edu.cn/simple\",\n                    \"os\": [\"self-hosted\", \"linux\", \"build\"],\n                    \"allow_fail\": False,\n                    \"python_version\": \"3.6,3.7\",\n                },\n                {\n                    \"test_suite\": \"cpu\",\n                    \"cuda_version\": 10.2,\n                    \"extra_flags\": \"--extra_oneflow_cmake_args=-DBUILD_SHARED_LIBS=OFF --extra_oneflow_cmake_args=-DRPC_BACKEND=LOCAL --cpu\",\n                    \"os\": [\"self-hosted\", \"linux\", \"build\"],\n                    \"allow_fail\": False,\n                    \"python_version\": \"3.6,3.7\",\n                },\n                {\n                    \"test_suite\": \"xla\",\n                    \"cuda_version\": 10.1,\n                    \"extra_flags\": \"--extra_oneflow_cmake_args=-DCUDA_ARCHITECTURES=61 --extra_oneflow_cmake_args=-DRPC_BACKEND=GRPC,LOCAL --xla 
--extra_oneflow_cmake_args=-DPIP_INDEX_MIRROR=https://pypi.tuna.tsinghua.edu.cn/simple\",\n                    \"os\": [\"self-hosted\", \"linux\", \"build\"],\n                    \"allow_fail\": True,\n                    \"python_version\": 3.6,\n                },\n                {\n                    \"test_suite\": \"xla_cpu\",\n                    \"cuda_version\": 10.1,\n                    \"extra_flags\": \"--extra_oneflow_cmake_args=-DRPC_BACKEND=GRPC,LOCAL --xla --cpu --extra_oneflow_cmake_args=-DPIP_INDEX_MIRROR=https://pypi.tuna.tsinghua.edu.cn/simple\",\n                    \"os\": [\"self-hosted\", \"linux\", \"build\"],\n                    \"allow_fail\": True,\n                    \"python_version\": 3.6,\n                },\n                create_conda(\"cpu-clang\"),\n            ],\n        }\n        full_test_matrix = {\n            \"test_suite\": [\n                \"cuda\",\n                \"cuda_op\",\n                \"cuda_new_interface\",\n                \"cpu_new_interface\",\n                \"cpu\",\n                \"xla\",\n                \"xla_cpu\",\n            ],\n            \"include\": [\n                {\n                    \"test_suite\": \"cuda\",\n                    \"os\": [\"self-hosted\", \"linux\", \"gpu\"],\n                    \"allow_fail\": False,\n                    \"build_env\": \"build.cuda.env\",\n                },\n                {\n                    \"test_suite\": \"cuda_op\",\n                    \"os\": [\"self-hosted\", \"linux\", \"gpu\"],\n                    \"allow_fail\": False,\n                    \"build_env\": \"build.cuda.env\",\n                },\n                {\n                    \"test_suite\": \"cuda_new_interface\",\n                    \"os\": [\"self-hosted\", \"linux\", \"gpu\"],\n                    \"allow_fail\": False,\n                    \"build_env\": \"build.cuda.env\",\n                },\n                {\n                    \"test_suite\": \"cpu\",\n                    \"os\": [\"self-hosted\", \"linux\", \"cpu\"],\n                    \"allow_fail\": False,\n                    \"build_env\": \"build.cpu.env\",\n                },\n                {\n                    \"test_suite\": \"cpu_new_interface\",\n                    \"os\": [\"self-hosted\", \"linux\", \"cpu\"],\n                    \"allow_fail\": False,\n                    \"build_env\": \"build.cpu.env\",\n                },\n                {\n                    \"test_suite\": \"xla\",\n                    \"os\": [\"self-hosted\", \"linux\", \"gpu\"],\n                    \"allow_fail\": True,\n                    \"build_env\": \"build.xla.env\",\n                },\n                {\n                    \"test_suite\": \"xla_cpu\",\n                    \"os\": [\"self-hosted\", \"linux\", \"cpu\"],\n                    \"allow_fail\": True,\n                    \"build_env\": \"build.xla_cpu.env\",\n                },\n            ],\n        }\n        print_result(\n            build_matrix=full_build_matrix, test_matrix=full_test_matrix, out=args.out,\n        )\n"
  },
  {
    "path": ".github/workflows/canary.yml",
    "content": "name: Canary\n\non:\n  push:\n    branches:\n      - master\n      - \"canary/*\"\n  workflow_dispatch:\n    inputs:\n      oneflow-ref:\n        description: \"\"\n        default: \"master\"\n        required: true\nconcurrency:\n  group: canary-${{ github.ref }}\n  cancel-in-progress: false\njobs:\n  canary_release:\n    name: Canary Release\n    timeout-minutes: 120\n    runs-on: [self-hosted, linux, release]\n    if: github.repository == 'Oneflow-Inc/oneflow'\n    strategy:\n      max-parallel: 1\n      fail-fast: false\n      matrix:\n        entry: [\"canary\", \"profiler\"]\n        include:\n          - entry: \"canary\"\n            cmake-init-cache: \"cmake/caches/ci/canary/cuda.cmake\"\n          - entry: \"profiler\"\n            cmake-init-cache: \"cmake/caches/ci/profiler/cuda.cmake\"\n    env:\n      ONEFLOW_SRC: .\n      MANYLINUX_CACHE_DIR: ~/manylinux-cache-dir/canary-cu112\n      WHEELHOUSE_DIR: manylinux-wheelhouse\n      COMPUTE_PLATFORM: cu118\n      OSS_BUCKET: oneflow-staging\n      OSS_WHEEL_HOUSE_DIR: ${{ matrix.entry }}/commit/${{ github.sha }}\n      OSS_GITHUB_REF_DIR: ${{ matrix.entry }}/${{ github.ref }}\n    steps:\n      - name: Fix permissions\n        run: |\n          set -x\n          docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf *\n      - name: Remove leftover cuda-installer.log\n        run: |\n          docker run --rm -v /tmp:/host/tmp -w /p busybox rm -f /host/tmp/cuda-installer.log\n      - name: Checkout Oneflow-Inc/oneflow\n        if: ${{ github.event.inputs.oneflow-ref != '' }}\n        uses: actions/checkout@v2\n        with:\n          ref: ${{ github.event.inputs.oneflow-ref }}\n      - name: Checkout Oneflow-Inc/oneflow\n        if: ${{ github.event.inputs.oneflow-ref == '' }}\n        uses: actions/checkout@v2\n      - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118\n        name: Build manylinux\n        id: build-cuda\n        with:\n          cmake-init-cache: ${{ env.ONEFLOW_SRC }}/${{ matrix.cmake-init-cache }}\n          build-script: ${{ env.ONEFLOW_SRC }}/ci/manylinux/build-gcc9.sh\n          oneflow-src: ${{ env.ONEFLOW_SRC }}\n          oneflow-build-env: manylinux\n          wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }}\n          clear-wheelhouse-dir: true\n          self-hosted: true\n          manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }}\n          docker-run-use-system-http-proxy: false\n          docker-run-use-lld: true\n          retry-failed-build: true\n          clean-ccache: true\n          compute-platform: ${{ env.COMPUTE_PLATFORM }}\n          python-versions: |\n            3.8\n            3.10\n      - name: Upload wheelhouse\n        uses: ./.github/actions/upload_oss\n        with:\n          src_path: ${{ env.WHEELHOUSE_DIR }}\n          oss_dst_path: oss://${{ env.OSS_BUCKET }}/${{ env.OSS_WHEEL_HOUSE_DIR }}/${{ env.COMPUTE_PLATFORM }}\n          oss_access_key_id: ${{ secrets.OSS_ACCESS_KEY_ID }}\n          oss_access_key_secret: ${{ secrets.OSS_ACCESS_KEY_SECRET }}\n      - name: Update pip index\n        env:\n          OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID }}\n          OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET }}\n        run: |\n          python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple\n          python3 -m pip install oss2 beautifulsoup4 --user\n          python3 tools/create_pip_index.py -b ${{ env.OSS_BUCKET }} \\\n            --dir_key ${{ env.OSS_WHEEL_HOUSE_DIR }}/${{ env.COMPUTE_PLATFORM }} \\\n            
--index_key=${{ env.OSS_WHEEL_HOUSE_DIR }}/${{ env.COMPUTE_PLATFORM }}/index.html \\\n            --index_key=${{ env.OSS_GITHUB_REF_DIR}}/${{ env.COMPUTE_PLATFORM }}/index.html\n"
  },
  {
    "path": ".github/workflows/community_release.yml",
    "content": "name: Community Release\n\non:\n  push:\n    branches:\n      - \"community/*\"\n  schedule:\n    # beijing: 6 pm.\n    # utc: 10 am.\n    - cron: \"0 10 * * sat\"\n  workflow_dispatch:\n    inputs:\n      priv_branch:\n        required: false\n        default: \"main\"\n\nconcurrency:\n  group: community-release-${{ github.ref }}-${{ inputs.priv_branch }}\n  cancel-in-progress: true\n\njobs:\n  release:\n    name: Release pip\n    permissions:\n      contents: read\n      pull-requests: write\n    uses: ./.github/workflows/release.yml\n    with:\n      is_priv: true\n      branch: ${{ inputs.priv_branch || 'main' }}\n      upload_override_branch: \"community\"\n      cuda_cmake_cache: cmake/caches/ci/release/cuda_community.cmake\n    secrets:\n      ONEFLOW_PRIV_ORG: ${{ secrets.ONEFLOW_PRIV_ORG }}\n      ONEFLOW_PRIV_GH_TOKEN: ${{ secrets.ONEFLOW_PRIV_GH_TOKEN }}\n      ONEFLOW_PRIV_OSS_BUCKET: ${{ secrets.ONEFLOW_PRIV_OSS_BUCKET }}\n      OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID }}\n      OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET }}\n      ONEFLOW_CI_HTTP_PROXY: ${{ secrets.ONEFLOW_CI_HTTP_PROXY }}\n"
  },
  {
    "path": ".github/workflows/on_merge.yml",
    "content": "name: Update Benchmark History\non:\n  pull_request:\n    types:\n      - closed\n    branches:\n      - master\n\nenv:\n  OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID }}\n  OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET }}\n\njobs:\n  if_merged:\n    if: github.event.pull_request.merged == true\n    runs-on: ubuntu-latest\n    steps:\n      - uses: Oneflow-Inc/get-oneflow/update-benchmark-history@ci-test-with-cu118\n        name: Update benchmark history\n        timeout-minutes: 10\n"
  },
  {
    "path": ".github/workflows/pr.yml",
    "content": "name: Check PR\n\non:\n  pull_request:\n    types: [opened, labeled, unlabeled, synchronize]\n\njobs:\n  check_labels:\n    runs-on: ubuntu-22.04\n    name: Labels\n    if: github.event.pull_request.draft == false && github.base_ref == 'master'\n    steps:\n      - name: Check type labels 'bug, enhancement, purge, feature, documentation'\n        if: (contains(github.event.pull_request.labels.*.name, 'bug') || contains(github.event.pull_request.labels.*.name, 'enhancement') || contains(github.event.pull_request.labels.*.name, 'purge') || contains(github.event.pull_request.labels.*.name, 'feature') || contains(github.event.pull_request.labels.*.name, 'documentation')) == false\n        run: |\n          exit 1\n      - name: Check component labels 'op, system, eager, build, xla, python, ci, test, tooling, quantization, graph, ir, serving'\n        if: (contains(github.event.pull_request.labels.*.name, 'op') || contains(github.event.pull_request.labels.*.name, 'system') || contains(github.event.pull_request.labels.*.name, 'eager') || contains(github.event.pull_request.labels.*.name, 'build') || contains(github.event.pull_request.labels.*.name, 'xla') || contains(github.event.pull_request.labels.*.name, 'python') || contains(github.event.pull_request.labels.*.name, 'ci') || contains(github.event.pull_request.labels.*.name, 'test') || contains(github.event.pull_request.labels.*.name, 'tooling') || contains(github.event.pull_request.labels.*.name, 'quantization') || contains(github.event.pull_request.labels.*.name, 'graph') || contains(github.event.pull_request.labels.*.name, 'ir') || contains(github.event.pull_request.labels.*.name, 'serving')) == false\n        run: |\n          exit 2\n"
  },
  {
    "path": ".github/workflows/priv_release.yml",
    "content": "name: Priv Release\n\non:\n  push:\n    branches:\n      - \"pro/*\"\n  schedule:\n    # beijing: 12 pm.\n    # utc: 4 am.\n    - cron: \"0 4 * * sun\"\n  workflow_dispatch:\n    inputs:\n      priv_branch:\n        required: false\n        default: \"main\"\n\nconcurrency:\n  group: priv-release-${{ github.ref }}-${{ inputs.priv_branch }}\n  cancel-in-progress: true\n\njobs:\n  release:\n    name: Release pip\n    permissions:\n      contents: read\n      pull-requests: write\n    uses: ./.github/workflows/release.yml\n    with:\n      is_priv: true\n      branch: ${{ inputs.priv_branch || 'main' }}\n      cuda_cmake_cache: cmake/caches/ci/release/cuda_pro.cmake\n    secrets:\n      ONEFLOW_PRIV_ORG: ${{ secrets.ONEFLOW_PRIV_ORG }}\n      ONEFLOW_PRIV_GH_TOKEN: ${{ secrets.ONEFLOW_PRIV_GH_TOKEN }}\n      ONEFLOW_PRIV_OSS_BUCKET: ${{ secrets.ONEFLOW_PRIV_OSS_BUCKET }}\n      OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID }}\n      OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET }}\n      ONEFLOW_CI_HTTP_PROXY: ${{ secrets.ONEFLOW_CI_HTTP_PROXY }}\n"
  },
  {
    "path": ".github/workflows/release.yml",
    "content": "name: Release\n\non:\n  push:\n    branches:\n      - \"release/*\"\n\n  schedule:\n    # beijing: 2 am.\n    # utc: 6 pm.\n    - cron: \"0 18 * * *\"\n  workflow_dispatch:\n    inputs:\n      placeholder:\n        description: \"update .github/workflows/release.yml to config your build\"\n        required: false\n  workflow_call:\n    inputs:\n      is_priv:\n        required: true\n        type: boolean\n      branch:\n        required: false\n        type: string\n        default: \"main\"\n      upload_override_branch:\n        required: false\n        type: string\n      cuda_cmake_cache:\n        required: false\n        type: string\n    secrets:\n      ONEFLOW_PRIV_ORG:\n        required: true\n      ONEFLOW_PRIV_GH_TOKEN:\n        required: true\n      ONEFLOW_PRIV_OSS_BUCKET:\n        required: true\n      OSS_ACCESS_KEY_ID:\n        required: true\n      OSS_ACCESS_KEY_SECRET:\n        required: true\n      ONEFLOW_CI_HTTP_PROXY:\n        required: false\nconcurrency:\n  group: release-${{ github.ref }}-${{ inputs.branch }}\n  cancel-in-progress: ${{ github.ref != 'refs/heads/master' }}\nenv:\n  ONEFLOW_SRC: .\njobs:\n  generate-build-matrix:\n    name: \"Generate build matrix\"\n    runs-on: ubuntu-latest\n    env:\n      ONEFLOW_SRC: .\n    outputs:\n      matrix: ${{ steps.find-cache.outputs.matrix }}\n      formatted_date: ${{ steps.date.outputs.formatted_date }}\n    steps:\n      - name: Checkout Oneflow-Inc/oneflow\n        uses: actions/checkout@v2\n        if: ${{ !inputs.is_priv }}\n        with:\n          ref: ${{ github.event.pull_request.head.sha }}\n          repository: ${{github.event.pull_request.head.repo.full_name}}\n      - name: Checkout oneflow\n        uses: actions/checkout@v2\n        if: ${{ inputs.is_priv }}\n        with:\n          ref: ${{ inputs.branch }}\n          repository: ${{ secrets.ONEFLOW_PRIV_ORG }}/oneflow\n          token: ${{ secrets.ONEFLOW_PRIV_GH_TOKEN }}\n      - uses: Oneflow-Inc/get-oneflow/cache-complete/matrix/build@ci-test-with-cu118\n        name: Find build cache\n        id: find-cache\n        timeout-minutes: 5\n        with:\n          delete-cache: ${{ contains(github.event.pull_request.labels.*.name, 'need-clean-ccache') }}\n          runner-labels: |\n            self-hosted\n            linux\n            release\n          oneflow-src: ${{ env.ONEFLOW_SRC }}\n          entries: |\n            cu122\n            cu121\n            cu118\n            cpu\n      - name: Get current date\n        id: date\n        run: echo \"formatted_date=$(date +'%Y%m%d')\" >> $GITHUB_OUTPUT\n\n  staging_release:\n    env:\n      MANYLINUX_CACHE_DIR: ~/manylinux-cache-dir/release/${{ matrix.entry }}\n      WHEELHOUSE_DIR: manylinux_wheelhouse\n      OSS_DIR: branch/${{ github.ref_name }}/${{ matrix.entry }}/${{ github.sha }}\n      GITHUB_REF_NAME: ${{ github.ref_name }}\n      GITHUB_SHA: ${{ github.sha }}\n      ONEFLOW_OSS_BUCKET: oneflow-staging\n      https_proxy: ${{ secrets.ONEFLOW_CI_HTTP_PROXY }}\n    needs: [generate-build-matrix]\n    name: Staging Release\n    timeout-minutes: 240\n    runs-on: [self-hosted, linux, release]\n    if: github.repository == 'Oneflow-Inc/oneflow' || inputs.is_priv\n    strategy:\n      fail-fast: false\n      max-parallel: 6\n      matrix: ${{ fromJson(needs.generate-build-matrix.outputs.matrix) }}\n    steps:\n      - name: Fix permissions\n        run: |\n          docker run --rm -v $PWD:/p -w /p busybox rm -rf *\n      - name: Install dependencies\n        run: |\n         
 python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple\n          python3 -m pip install -U setuptools wheel --user\n          python3 -m pip install oss2  --user\n      - name: Checkout Oneflow-Inc/oneflow\n        uses: actions/checkout@v2\n        if: ${{ !inputs.is_priv }}\n        with:\n          ref: ${{ github.event.pull_request.head.sha }}\n          repository: ${{github.event.pull_request.head.repo.full_name}}\n      - name: Checkout private oneflow\n        uses: actions/checkout@v2\n        if: ${{ inputs.is_priv }}\n        with:\n          ref: ${{ inputs.branch }}\n          repository: ${{ secrets.ONEFLOW_PRIV_ORG }}/oneflow\n          token: ${{ secrets.ONEFLOW_PRIV_GH_TOKEN }}\n      - name: Checkout cutlass_extension\n        uses: actions/checkout@v2\n        if: ${{ inputs.is_priv }}\n        with:\n          repository: ${{ secrets.ONEFLOW_PRIV_ORG }}/cutlass-extension\n          token: ${{ secrets.ONEFLOW_PRIV_GH_TOKEN }}\n          path: cutlass-extension\n      - name: Set Private env\n        if: ${{ inputs.is_priv }}\n        run: |\n          GITHUB_SHA=$(git rev-parse HEAD)\n          echo \"OSS_DIR=branch/${{ inputs.upload_override_branch || inputs.branch }}/${{ matrix.entry }}/${GITHUB_SHA}\" >> $GITHUB_ENV\n          echo \"GITHUB_REF_NAME=${{ inputs.upload_override_branch || inputs.branch }}\" >> $GITHUB_ENV\n          echo \"GITHUB_SHA=${GITHUB_SHA}\" >> $GITHUB_ENV\n          echo \"ONEFLOW_OSS_BUCKET=${{ secrets.ONEFLOW_PRIV_OSS_BUCKET }}\" >> $GITHUB_ENV\n      - name: Print env\n        if: ${{ inputs.is_priv }}\n        run: |\n          env\n      - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118\n        name: Build ${{ matrix.entry }}\n        if: ${{ matrix.entry =='cu118' || startsWith(matrix.entry, 'cu12') }}\n        with:\n          cmake-init-cache: ${{ env.ONEFLOW_SRC }}/${{ inputs.cuda_cmake_cache || 'cmake/caches/ci/release/cu118.cmake' }}\n          build-script: ${{ env.ONEFLOW_SRC }}/ci/manylinux/build-gcc9.sh\n          oneflow-src: ${{ env.ONEFLOW_SRC }}\n          oneflow-build-env: manylinux\n          wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }}\n          clear-wheelhouse-dir: true\n          self-hosted: true\n          compute-platform: ${{ matrix.entry }}\n          manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }}\n          docker-run-use-system-http-proxy: false\n          docker-run-use-lld: false\n          retry-failed-build: true\n          clean-ccache: true\n          nightly: ${{ inputs.is_priv || github.event_name == 'schedule' || github.ref == 'refs/heads/release/add_nightly_date_index'}}\n          nightly-date: ${{ needs.generate-build-matrix.outputs.formatted_date }}\n          use-nvidia-wheels: ${{ matrix.entry !='cu112' }}\n          python-versions: |\n            3.12\n            3.11\n            3.10\n            3.9\n            3.8\n      - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118\n        name: Build ${{ matrix.entry }}\n        if: ${{ startsWith(matrix.entry, 'cu') && matrix.entry !='cu118' && !startsWith(matrix.entry, 'cu12') }}\n        with:\n          cmake-init-cache: ${{ env.ONEFLOW_SRC }}/cmake/caches/ci/release/cuda.cmake\n          build-script: ${{ env.ONEFLOW_SRC }}/ci/manylinux/build-gcc9.sh\n          oneflow-src: ${{ env.ONEFLOW_SRC }}\n          oneflow-build-env: manylinux\n          wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }}\n          clear-wheelhouse-dir: true\n          self-hosted: true\n          compute-platform: ${{ matrix.entry }}\n       
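   # MANYLINUX_CACHE_DIR is keyed by matrix.entry (see env above), so each compute platform keeps a separate build cache\n       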
   manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }}\n          docker-run-use-system-http-proxy: false\n          docker-run-use-lld: false\n          retry-failed-build: true\n          clean-ccache: true\n          nightly: ${{ inputs.is_priv || github.event_name == 'schedule' || github.ref == 'refs/heads/release/add_nightly_date_index'}}\n          nightly-date: ${{ needs.generate-build-matrix.outputs.formatted_date }}\n          use-nvidia-wheels: ${{ matrix.entry !='cu112' }}\n          python-versions: |\n            3.12\n            3.11\n            3.10\n            3.9\n            3.8\n      - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118\n        name: Build ${{ matrix.entry }}\n        if: ${{ matrix.entry =='cpu' }}\n        with:\n          cmake-init-cache: ${{ env.ONEFLOW_SRC }}/cmake/caches/ci/release/cpu.cmake\n          build-script: ${{ env.ONEFLOW_SRC }}/ci/manylinux/build.sh\n          oneflow-src: ${{ env.ONEFLOW_SRC }}\n          oneflow-build-env: manylinux\n          wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }}\n          clear-wheelhouse-dir: true\n          self-hosted: true\n          compute-platform: ${{ matrix.entry }}\n          manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }}\n          docker-run-use-system-http-proxy: false\n          docker-run-use-lld: false\n          retry-failed-build: true\n          clean-ccache: false\n          nightly: ${{ inputs.is_priv || github.event_name == 'schedule' || github.ref == 'refs/heads/release/add_nightly_date_index'}}\n          nightly-date: ${{ needs.generate-build-matrix.outputs.formatted_date }}\n          python-versions: |\n            3.12\n            3.11\n            3.10\n            3.9\n            3.8\n      - name: Upload wheel\n        uses: ./.github/actions/upload_oss\n        with:\n          src_path: ${{ env.WHEELHOUSE_DIR }}\n          oss_dst_path: oss://${{ env.ONEFLOW_OSS_BUCKET }}/${{ env.OSS_DIR }}\n          oss_access_key_id: ${{ secrets.OSS_ACCESS_KEY_ID }}\n          oss_access_key_secret: ${{ secrets.OSS_ACCESS_KEY_SECRET }}\n      - name: Update pip index\n        env:\n          OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID }}\n          OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET }}\n        run: |\n          python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple\n          python3 -m pip install oss2 beautifulsoup4 --user\n          python3 tools/create_pip_index.py --dir_key ${{ env.OSS_DIR }} -b ${{ env.ONEFLOW_OSS_BUCKET }} \\\n            --index_key=branch/${{ env.GITHUB_REF_NAME }}/${{ matrix.entry }}/index.html \\\n            --index_key=branch/${{ env.GITHUB_REF_NAME }}/date/${{ needs.generate-build-matrix.outputs.formatted_date }}/${{ matrix.entry }}/index.html \\\n            --index_key=${{ env.OSS_DIR }}/index.html \\\n            --index_key=commit/${{ env.GITHUB_SHA }}/${{ matrix.entry }}/index.html\n      - name: Update API docs\n        if: github.ref == 'refs/heads/master' && matrix.entry == 'cpu' && !inputs.is_priv\n        env:\n          READTHEDOCS_TOKEN: ${{ secrets.READTHEDOCS_TOKEN }}\n        run: |\n          curl -X POST -d \"branches=master\" -d \"token=${READTHEDOCS_TOKEN}\"  https://readthedocs.org/api/v2/webhook/oneflow/135376/\n"
  },
  {
    "path": ".github/workflows/simple.yml",
    "content": "name: Simple CI\non:\n  pull_request:\n    types: [review_requested]\n    branches:\n      - \"*\"\n  push:\n    branches:\n      - master\n  workflow_dispatch:\n    inputs:\n      placeholder:\n        description: \"placeholder, no effect\"\n        required: false\nconcurrency:\n  group: simple-ci-${{ github.ref }}\n  cancel-in-progress: ${{ github.ref != 'refs/heads/master' }}\njobs:\n  static_analysis_with_clang:\n    name: Static analysis with clang\n    runs-on: ubuntu-22.04\n    if: github.ref == 'refs/heads/master' || (github.event.pull_request.draft == false && contains(github.event.pull_request.requested_reviewers.*.login, 'oneflow-ci-bot') && contains(github.event.pull_request.labels.*.name, 'need-simple-ci'))\n    steps:\n      - name: Check out OneFlow\n        uses: actions/checkout@v2\n        with:\n          ref: ${{ github.event.pull_request.head.ref }}\n          repository: ${{github.event.pull_request.head.repo.full_name}}\n      - name: Install dependencies\n        run: |\n          sudo apt-get update\n          sudo apt-get install -y libopenblas-dev nasm python3-pip ninja-build\n      - name: Download OneFlow custom clang-tidy\n        run: |\n          wget https://github.com/Oneflow-Inc/llvm-project/releases/download/maybe-14.0.4/clang-tidy-14.AppImage\n          wget https://raw.githubusercontent.com/oneflow-inc/llvm-project/maybe/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py\n          chmod +x clang-tidy-14.AppImage run-clang-tidy.py\n      - name: Build third party libs and generate files\n        run: |\n          mkdir build\n          cd build\n          cmake .. -C ../cmake/caches/international/cpu.cmake \\\n            -DCMAKE_BUILD_TYPE=Release \\\n            -DBUILD_TESTING=ON\n          cmake --build . -j$(nproc) --target oneflow_deps of_protoobj of_functional_obj of_functional_tensor_obj of_op_schema\n      - name: Run clang-tidy for all translation units\n        # use clang as compiler for correct compiler flags\n        run: |\n          cd build\n          rm CMakeCache.txt\n          cmake .. 
-C ../cmake/caches/international/cpu.cmake \\\n            -DCMAKE_C_COMPILER=clang-12 \\\n            -DCMAKE_CXX_COMPILER=clang++-12 \\\n            -DCMAKE_BUILD_TYPE=Release \\\n            -DBUILD_TESTING=ON \\\n            -DCMAKE_EXPORT_COMPILE_COMMANDS=ON\n          cd ..\n          ./run-clang-tidy.py -clang-tidy-binary ./clang-tidy-14.AppImage -p build -quiet -allow-enabling-alpha-checkers -extra-arg=\"-Xclang\" -extra-arg=\"-analyzer-config\" -extra-arg=\"-Xclang\" -extra-arg=\"aggressive-binary-operation-simplification=true\" \"^(?!$(pwd)/build)\"\n\n  hosted:\n    name: CPU-only\n    if: github.ref == 'refs/heads/master' || (github.event.pull_request.draft == false && contains(github.event.pull_request.requested_reviewers.*.login, 'oneflow-ci-bot') && contains(github.event.pull_request.labels.*.name, 'need-simple-ci'))\n    runs-on: ${{ matrix.os }}\n    env:\n      CFLAGS: \"-w\"\n      CXXFLAGS: \"-w\"\n    strategy:\n      fail-fast: true\n      max-parallel: 1\n      matrix:\n        test_suite: [\"mac\", \"ubuntu\"]\n        cmake_generator: [\"Ninja\", \"Unix Makefiles\"]\n        cmake_build_type: [\"Debug\", \"Release\"]\n        build_shared_libs: [\"ON\", \"OFF\"]\n        include:\n          - test_suite: mac\n            os: \"macos-10.15\"\n            make_concurrency: 2\n          - test_suite: ubuntu\n            os: \"ubuntu-22.04\"\n            make_concurrency: 2\n        exclude:\n          - test_suite: mac\n            cmake_build_type: \"Debug\"\n          - test_suite: mac\n            cmake_generator: \"Ninja\"\n          - test_suite: ubuntu\n            cmake_generator: \"Ninja\"\n            cmake_build_type: \"Debug\"\n          - test_suite: ubuntu\n            cmake_generator: \"Ninja\"\n            build_shared_libs: \"OFF\"\n          - test_suite: ubuntu\n            cmake_build_type: \"Debug\"\n            build_shared_libs: \"OFF\"\n          - test_suite: ubuntu\n            cmake_generator: \"Unix Makefiles\"\n            cmake_build_type: \"Release\"\n    steps:\n      - name: Set Swap Space\n        uses: pierotofy/set-swap-space@master\n        with:\n          swap-size-gb: 5\n      - uses: actions/checkout@v2\n        with:\n          ref: ${{ github.event.pull_request.head.sha }}\n      - name: Install dependencies (homebrew)\n        if: matrix.test_suite == 'mac'\n        run: |\n          brew install nasm ninja\n      - name: Install dependencies (apt)\n        if: matrix.test_suite == 'ubuntu'\n        run: |\n          sudo apt install -y libopenblas-dev nasm g++ gcc python3-pip ninja-build\n      - name: Cache pip (Linux)\n        if: startsWith(runner.os, 'Linux')\n        uses: actions/cache@v4\n        with:\n          path: ~/.cache/pip\n          key: ${{ matrix.os }}-pip-${{ hashFiles('**/requirements.txt') }}\n      - name: Cache pip (macOS)\n        if: startsWith(runner.os, 'macOS')\n        uses: actions/cache@v4\n        with:\n          path: ~/Library/Caches/pip\n          key: ${{ matrix.os }}-pip-${{ hashFiles('**/requirements.txt') }}\n      - name: Install dependencies (pip)\n        run: |\n          python3 -m pip install -r ci/requirements.txt\n          python3 -m pip install -r dev-requirements.txt\n      - name: Set environment variables\n        run: |\n          set -x\n          cmake_flags=\"\"\n          cmake_flags+=\" -DBUILD_CUDA=OFF\"\n          cmake_flags+=\" -DBUILD_TESTING=ON\"\n          cmake_flags+=\" -G '${{ matrix.cmake_generator }}'\"\n          cmake_flags+=\" -DCMAKE_BUILD_TYPE=${{ 
matrix.cmake_build_type }}\"\n          cmake_flags+=\" -DBUILD_SHARED_LIBS=${{ matrix.build_shared_libs }}\"\n          cmake_flags+=\" -DCMAKE_MACOSX_RPATH=FALSE\"\n          cmake_flags+=\" -DCMAKE_BUILD_WITH_INSTALL_RPATH=FALSE\"\n          echo \"cmake_flags=${cmake_flags}\" >> $GITHUB_ENV\n      - name: Build (third party)\n        if: matrix.cmake_generator != 'Ninja'\n        run: |\n          set -x\n          mkdir -p build-third_party\n          mkdir -p third_party_install\n          cd build-third_party\n          cmake .. ${{ env.cmake_flags }} -DTHIRD_PARTY=ON -DONEFLOW=OFF -DTHIRD_PARTY_DIR=$PWD/../third_party_install\n          cmake --build . -j $(nproc)\n      - name: Build (oneflow)\n        if: matrix.cmake_generator != 'Ninja'\n        run: |\n          mkdir -p build\n          cd build\n          cmake .. ${{ env.cmake_flags }} -DTHIRD_PARTY=OFF -DONEFLOW=ON -DTHIRD_PARTY_DIR=$PWD/../third_party_install\n          cmake --build . -j ${{ matrix.make_concurrency }} --target oneflow\n      - name: Build (oneflow_internal)\n        if: always() && matrix.cmake_generator != 'Ninja'\n        run: |\n          mkdir -p build\n          cd build\n          cmake .. ${{ env.cmake_flags }} -DTHIRD_PARTY=OFF -DONEFLOW=ON\n          cmake --build . -j ${{ matrix.make_concurrency }} --target oneflow_internal\n      - name: Build (oneflow_py)\n        if: always() && matrix.cmake_generator != 'Ninja'\n        run: |\n          mkdir -p build\n          cd build\n          cmake .. ${{ env.cmake_flags }} -DTHIRD_PARTY=OFF -DONEFLOW=ON\n          cmake --build . -j ${{ matrix.make_concurrency }} --target oneflow_py\n      - name: Build (oneflow_testexe)\n        if: always() && matrix.cmake_generator != 'Ninja'\n        run: |\n          mkdir -p build\n          cd build\n          cmake .. ${{ env.cmake_flags }} -DTHIRD_PARTY=OFF -DONEFLOW=ON\n          cmake --build . -j ${{ matrix.make_concurrency }} --target oneflow_testexe\n      - name: Build (ALL)\n        if: always()\n        continue-on-error: ${{ startsWith(runner.os, 'macOS') && matrix.cmake_generator == 'Ninja' && matrix.build_shared_libs == 'ON' }}\n        run: |\n          mkdir -p build\n          cd build\n          cmake .. ${{ env.cmake_flags }}\n          cmake --build . 
-j ${{ matrix.make_concurrency }}\n      - name: Exe test\n        if: always()\n        continue-on-error: true\n        run: |\n          ulimit -c\n          ulimit -c unlimited\n          ulimit -c\n          mkdir -p build\n          cd build\n          ./bin/oneflow_testexe\n      - name: Op test\n        if: always()\n        continue-on-error: true\n        run: |\n          ulimit -c\n          ulimit -c unlimited\n          ulimit -c\n          source build/source.sh\n          ONEFLOW_TEST_GITHUB_HOSTED=1 ONEFLOW_TEST_CPU_ONLY=1 bash ci/test/1node_op_test.sh\n      - name: \"Tar logs\"\n        if: always() && contains(github.event.pull_request.labels.*.name, 'need-simple-ci-upload-artifact')\n        continue-on-error: true\n        run: |\n          set -ex\n          if [[ -d \"${HOME}/oneflow_temp\" ]]\n          then\n              tar -cvf home_oneflow_temp.tar ${HOME}/oneflow_temp\n          fi\n          if [[ -d \"${PWD}/test_tmp_dir\" ]]\n          then\n              tar -cvf cwd_test_tmp_dir.tar ${PWD}/test_tmp_dir\n          fi\n      - name: Upload logs\n        if: always() && contains(github.event.pull_request.labels.*.name, 'need-simple-ci-upload-artifact')\n        uses: actions/upload-artifact@v4\n        with:\n          name: logs-${{ matrix.test_suite }}-${{ matrix.cmake_generator }}-${{ matrix.cmake_build_type }}-shared-${{ matrix.build_shared_libs }}\n          path: |\n            home_oneflow_temp.tar\n            cwd_test_tmp_dir.tar\n\n  conda:\n    name: Build with conda\n    if: github.ref == 'refs/heads/master' || (github.event.pull_request.draft == false && contains(github.event.pull_request.requested_reviewers.*.login, 'oneflow-ci-bot') && contains(github.event.pull_request.labels.*.name, 'need-simple-ci'))\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: true\n      max-parallel: 1\n      matrix:\n        build-type: [\"gcc7\", \"clang10\"]\n    steps:\n      - name: Checkout Oneflow-Inc/oneflow\n        uses: actions/checkout@v2\n      - name: Checkout Oneflow-Inc/conda-env\n        uses: actions/checkout@v2\n        with:\n          repository: Oneflow-Inc/conda-env\n          ref: 30a7f00eb48ee9009d85a848e720823e5054c66b\n          path: conda-env\n      - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118\n        name: Build with gcc7\n        if: ${{ matrix.build-type == 'gcc7'}}\n        with:\n          cmake-init-cache: cmake/caches/ci/gh-hosted/cpu-gcc.cmake\n          oneflow-src: .\n          oneflow-build-env: conda\n          conda-env-file: conda-env/dev/gcc7/environment-v2.yml\n          conda-env-name: oneflow-dev-gcc7-v2\n      - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118\n        name: Build with clang10\n        if: ${{ matrix.build-type == 'clang10'}}\n        with:\n          cmake-init-cache: cmake/caches/ci/gh-hosted/cpu-clang.cmake\n          oneflow-src: .\n          oneflow-build-env: conda\n          conda-env-file: conda-env/dev/clang10/environment-v2.yml\n          conda-env-name: oneflow-dev-clang10-v2\n"
  },
  {
    "path": ".github/workflows/test.yml",
    "content": "name: Build and Test CI\non:\n  pull_request:\n    types: [opened, review_requested, ready_for_review, synchronize, unlocked]\n  merge_group:\n    types: [checks_requested]\n\nconcurrency:\n  group: build-and-test-${{ github.ref }}\n  cancel-in-progress: true\n\nenv:\n  OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID }}\n  OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET }}\n  ONEFLOW_TIMEOUT_SECONDS: 90\n  ONEFLOW_THRAED_LOCAL_CACHED_SIZE: 16384\n  FLOW_VISION_SRC: flow_vision\n  FLOW_VISION_COMMIT: ca8ebc663b58667cf8cd1b6ef0c861522780b7bb\n  LIBAI_SRC: libai\n  LIBAI_COMMIT: 94eb85ff0131e8dfce953a3a916de7a4f897c647\n  ONEFLOW_FACE_SRC: oneflow_face\n  ONEFLOW_FACE_COMMIT: 110a97e8d5737a1f1856281a7df556a5ac8f06de\n  ONEFLOW_IREE_SRC: oneflow_iree\n  ONEFLOW_IREE_COMMIT: 42fd479de7047e6af1d42c6e62b9b056e0a762aa\n  ONE_FX_SRC: one-fx\n  ONE_FX_COMMIT: da4051c7f1ace7a20b3f54395b580cd102fc99da\n  TEST_WITH_TORCH_IMG_TAG: registry.cn-beijing.aliyuncs.com/oneflow/test-with-pytorch-1.10.0-cuda11.3-cudnn8-runtime:25817b5c0e1dd79bef8fdd43d729b98af381e7d5\n  MLIR_DOCKER_ARGS: \"-e ONEFLOW_MLIR_ENABLE_ROUND_TRIP=1 -e ONEFLOW_MLIR_PREFER_NHWC=0 -e ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=1\"\n  SSH_TANK_HOST: 192.168.1.40\n  SSH_TANK_PATH: /data/tank\n\njobs:\n  source_info:\n    name: Collect information about PR and source\n    runs-on: ubuntu-22.04\n    if: github.event.pull_request.draft == false && github.base_ref == 'master'\n    steps:\n      - name: Check out OneFlow\n        uses: actions/checkout@v2\n        with:\n          ref: ${{ github.event.pull_request.head.sha }}\n          repository: ${{github.event.pull_request.head.repo.full_name}}\n          fetch-depth: 0\n      - name: Python diff\n        id: py-diff\n        run: |\n          ONEFLOW_TEST_FILES=\"$(git diff --diff-filter=d --name-only ${{ github.event.pull_request.base.sha }}  -- python/oneflow/test/**/test_*.py | { grep -v expensive || true; })\"\n          ONEFLOW_TEST_FILES=$(echo \"${ONEFLOW_TEST_FILES}\" | xargs)\n          if [ -z \"$ONEFLOW_TEST_FILES\" ]; then\n              echo \"no changed python tests\"\n              echo \"has_changed_python_tests=false\" >> $GITHUB_OUTPUT\n          else\n              echo \"changed python tests: ${ONEFLOW_TEST_FILES}\"\n              echo \"has_changed_python_tests=true\" >> $GITHUB_OUTPUT\n          fi\n          echo \"changed_python_tests=${ONEFLOW_TEST_FILES}\" >> $GITHUB_OUTPUT\n    outputs:\n      changed_python_tests: ${{ steps.py-diff.outputs.changed_python_tests }}\n      has_changed_python_tests: ${{ steps.py-diff.outputs.has_changed_python_tests }}\n\n  mirror_third_party:\n    name: Mirror third party dependencies\n    runs-on: ubuntu-22.04\n    if: github.event.pull_request.draft == false && github.base_ref == 'master'\n    steps:\n      - uses: actions/checkout@v2\n      - name: Mirror dependencies to aliyun\n        if: github.event.pull_request.head.repo.full_name == github.repository\n        run: |\n          set -x\n          if [ -z \"$OSS_ACCESS_KEY_ID\" ]\n          then\n            exit 0\n          fi\n          python3 -m pip install -U pip \"setuptools<=68.2.2\" wheel\n          python3 -m pip install 'cryptography<=3.4' oss2\n          python3 tools/package_mirror.py -i $PWD\n\n  check_license_and_format:\n    name: License and format\n    runs-on: ubuntu-22.04\n    if: github.event.pull_request.draft == false\n    steps:\n      - uses: actions/checkout@v2\n        with:\n          repository: 
${{github.event.pull_request.head.repo.full_name}}\n          ref: ${{ github.head_ref }}\n      - name: Check license\n        id: license_check\n        run: |\n          python3 ci/check/run_license_format.py -i oneflow -c\n          python3 ci/check/run_license_format.py -i python -c\n      - name: Add license\n        id: license_fmt\n        if: ${{ failure() }}\n        run: |\n          python3 ci/check/run_license_format.py -i oneflow --fix\n          python3 ci/check/run_license_format.py -i python --fix\n      - name: Check C++/CUDA format\n        id: cpp_check\n        run: |\n          sudo apt install libtinfo5\n          python3 ci/check/run_clang_format.py --clang_format_binary clang-format --source_dir oneflow\n      - name: Run C++/CUDA format\n        id: cpp_fmt\n        if: ${{ failure() }}\n        run: |\n          sudo apt install libtinfo5\n          python3 ci/check/run_clang_format.py --clang_format_binary clang-format --source_dir oneflow --fix\n      - name: Check Python format\n        id: py_check\n        run: |\n          python3 -m pip install black==19.10b0 click==8.0.0\n          python3 ci/check/run_py_format.py --source_dir $PWD\n      - name: Run Python Format\n        id: py_fmt\n        if: ${{ failure() }}\n        run: |\n          python3 -m pip install black==19.10b0 --user\n          python3 ci/check/run_py_format.py --source_dir $PWD --fix\n      - name: Check CMake format\n        id: cmake_check\n        run: |\n          python3 -m pip install cmakelang\n          python3 ci/check/run_cmake_format.py --source_dir $PWD\n      - name: Run CMake Format\n        id: cmake_fmt\n        if: ${{ failure() }}\n        run: |\n          python3 -m pip install cmakelang\n          python3 ci/check/run_cmake_format.py --source_dir $PWD --fix\n      - name: Git push\n        id: git_push\n        if: ${{ failure() }}\n        run: |\n          git diff -p > license_and_format.patch\n          cat license_and_format.patch\n          git config --global user.email \"ci-bot@oneflow.org\"\n          git config --global user.name \"oneflow-ci-bot\"\n          git add -u\n          git commit -m \"auto format by CI\"\n          git push\n      - name: Upload patch\n        if: ${{ failure() && steps.git_push.outcome == 'failure' }}\n        uses: actions/upload-artifact@v4\n        with:\n          name: license_and_format-${{ github.sha }}.patch\n          path: license_and_format.patch\n      - name: Add comment\n        if: ${{ failure() }}\n        uses: actions/github-script@v4\n        with:\n          script: |\n            github.issues.createComment({\n              issue_number: context.issue.number,\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              body: 'Code got formatted by CI. Please request CI again if you still want to have this PR merged. 
If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.'\n            })\n      - name: Please request CI again\n        if: ${{ failure() }}\n        run: |\n          exit 1\n      - name: Check source code (prevent creating files in the wrong places)\n        run: |\n          python3 tools/check_src.py\n\n  find-build-cache:\n    name: \"Find build cache\"\n    if: github.event.pull_request.draft == false && github.base_ref == 'master'\n    runs-on: ubuntu-latest\n    env:\n      ONEFLOW_SRC: .\n    outputs:\n      matrix: ${{ steps.find-cache.outputs.matrix }}\n    steps:\n      - name: Checkout Oneflow-Inc/oneflow\n        uses: actions/checkout@v2\n        with:\n          ref: ${{ github.event.pull_request.head.sha }}\n          repository: ${{github.event.pull_request.head.repo.full_name}}\n      - uses: Oneflow-Inc/get-oneflow/cache-complete/matrix/build@ci-test-with-cu118\n        name: find cache\n        id: find-cache\n        timeout-minutes: 5\n        with:\n          delete-cache: ${{ contains(github.event.pull_request.labels.*.name, 'need-clean-ccache') }}\n          runner-labels: |\n            self-hosted\n            linux\n            builder\n          oneflow-src: ${{ env.ONEFLOW_SRC }}\n          entries: |\n            cu118\n            cpu\n            cpu-asan-ubsan\n            cpu-tsan\n            llvm15\n\n  build-oneflow:\n    name: \"Build OneFlow\"\n    if: github.event.pull_request.draft == false && github.base_ref == 'master'\n    runs-on: ${{ matrix.runs-on }}\n    needs: [find-build-cache]\n    timeout-minutes: 80\n    strategy:\n      fail-fast: true\n      max-parallel: 5\n      matrix: ${{ fromJson(needs.find-build-cache.outputs.matrix) }}\n    env:\n      ONEFLOW_SRC: .\n      MANYLINUX_CACHE_DIR: ~/manylinux-cache-dir/${{ matrix.entry }}\n      WHEELHOUSE_DIR: manylinux-wheelhouse\n    steps:\n      - name: Set proxy\n        if: ${{ contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          echo \"https_proxy=${{ secrets.ONEFLOW_CI_HTTP_PROXY }}\" >> $GITHUB_ENV\n      - name: Fix permissions\n        if: ${{ contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          set -x\n          docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf *\n      - name: Checkout Oneflow-Inc/oneflow\n        uses: actions/checkout@v2\n        with:\n          ref: ${{ github.event.pull_request.head.sha }}\n          repository: ${{github.event.pull_request.head.repo.full_name}}\n      - uses: Oneflow-Inc/get-oneflow/cache-complete@ci-test-with-cu118\n        name: Save cache if successful\n        id: save-cache\n        timeout-minutes: 5\n        with:\n          oneflow-src: ${{ env.ONEFLOW_SRC }}\n          entry: ${{ matrix.entry }}\n          digest-type: build\n          mark-as-completed: ${{ contains(matrix.runs-on, 'self-hosted') && github.event.pull_request.head.repo.full_name == github.repository }}\n      
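# save-cache recomputes the source digest here, while matrix.cache-hit was computed when the build matrix was generated; a mismatch presumably means new commits landed mid-run, so the next step fails fast instead of testing a stale build\n      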
- name: Check digest cache result. If this step fails, it is usually caused by new commits pushed while this CI run is in progress.\n        if: ${{ fromJSON(steps.save-cache.outputs.cache-hit) != matrix.cache-hit }}\n        run: |\n          echo \"::error file=test.yml,line=204,col=10::steps.save-cache.outputs.cache-hit != matrix.cache-hit\"\n          exit 1\n      - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118\n        name: Build manylinux ${{ matrix.entry }}\n        id: build-cpu\n        if: ${{ matrix.entry =='cpu' && !matrix.cache-hit }}\n        with:\n          cmake-init-cache: ${{ env.ONEFLOW_SRC }}/cmake/caches/ci/cpu.cmake\n          build-script: ${{ env.ONEFLOW_SRC }}/ci/manylinux/build.sh\n          run-lit: true\n          oneflow-src: ${{ env.ONEFLOW_SRC }}\n          oneflow-build-env: manylinux\n          wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }}\n          clear-wheelhouse-dir: true\n          self-hosted: ${{ contains(matrix.runs-on, 'self-hosted') }}\n          cuda-version: none\n          manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }}\n          docker-run-use-system-http-proxy: false\n          docker-run-use-lld: true\n          retry-failed-build: true\n          clean-ccache: ${{ contains(github.event.pull_request.labels.*.name, 'need-clean-ccache') }}\n          python-versions: |\n            3.7\n            3.8\n      - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118\n        name: Build manylinux ${{ matrix.entry }}\n        id: build-cpu-sanitizers\n        if: ${{ (matrix.entry == 'cpu-asan-ubsan' || matrix.entry == 'cpu-tsan') && !matrix.cache-hit && false }}\n        with:\n          cmake-init-cache: ${{ env.ONEFLOW_SRC }}/cmake/caches/ci/${{ matrix.entry }}.cmake\n          build-script: ${{ env.ONEFLOW_SRC }}/ci/manylinux/build.sh\n          run-lit: false\n          oneflow-src: ${{ env.ONEFLOW_SRC }}\n          oneflow-build-env: manylinux\n          wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }}\n          clear-wheelhouse-dir: true\n          self-hosted: ${{ contains(matrix.runs-on, 'self-hosted') }}\n          cuda-version: none\n          manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }}\n          docker-run-use-system-http-proxy: false\n          docker-run-use-lld: true\n          retry-failed-build: true\n          clean-ccache: ${{ contains(github.event.pull_request.labels.*.name, 'need-clean-ccache') }}\n          python-versions: |\n            3.8\n      - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118\n        name: Build manylinux ${{ matrix.entry }}\n        id: build-cuda\n        if: ${{ matrix.entry =='cu118' && !matrix.cache-hit }}\n        with:\n          cmake-init-cache: ${{ env.ONEFLOW_SRC }}/cmake/caches/ci/cuda.cmake\n          build-script: ${{ env.ONEFLOW_SRC }}/ci/manylinux/build-gcc9.sh\n          oneflow-src: ${{ env.ONEFLOW_SRC }}\n          oneflow-build-env: manylinux\n          wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }}\n          clear-wheelhouse-dir: true\n          self-hosted: ${{ contains(matrix.runs-on, 'self-hosted') }}\n          cuda-version: \"11.8\"\n          manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }}\n          docker-run-use-system-http-proxy: false\n          docker-run-use-lld: false\n          retry-failed-build: true\n          clean-ccache: ${{ contains(github.event.pull_request.labels.*.name, 'need-clean-ccache') }}\n          python-versions: |\n            3.7\n      - uses: Oneflow-Inc/get-oneflow@ci-test-with-cu118\n        name: Build ${{ matrix.entry }}\n        if: ${{ matrix.entry == 'llvm15' && 
!matrix.cache-hit }}\n        with:\n          cmake-init-cache: ${{ env.ONEFLOW_SRC }}/cmake/caches/ci/llvm/cuda-75-clang.cmake\n          build-script: ${{ env.ONEFLOW_SRC }}/ci/clang/build-llvm.sh\n          oneflow-src: ${{ env.ONEFLOW_SRC }}\n          oneflow-build-env: llvm\n          wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }}\n          clear-wheelhouse-dir: true\n          self-hosted: true\n          cuda-version: ${{ env.CUDA_VERSION }}\n          manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }}\n          docker-run-use-system-http-proxy: false\n          docker-run-use-lld: false\n          retry-failed-build: true\n          clean-ccache: ${{ contains(github.event.pull_request.labels.*.name, 'need-clean-ccache') }}\n          wheel-audit: false\n          python-versions: |\n            3.8\n      - name: Remove automerge\n        if: ${{ failure() && contains(matrix.runs-on, 'self-hosted') && cancelled() == false && contains(github.event.pull_request.labels.*.name, 'automerge') }}\n        uses: actions/github-script@v4\n        with:\n          script: |\n            github.issues.removeLabel({\n              issue_number: context.issue.number,\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              name: 'automerge'\n            })\n            github.issues.createComment({\n              issue_number: context.issue.number,\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              body: 'CI failed when running job: Build ${{ matrix.entry }}. PR label automerge has been removed'\n            })\n      - name: Upload packed liboneflow\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.entry != 'llvm15' && matrix.entry != 'cpu-asan-ubsan' && matrix.entry != 'cpu-tsan' }}\n        uses: Oneflow-Inc/get-oneflow/digest/upload@ci-test-with-cu118\n        timeout-minutes: 10\n        with:\n          digest: ${{ steps.save-cache.outputs.build-digest }}\n          entry: ${{ matrix.entry }}\n          ssh-tank-host: ${{ env.SSH_TANK_HOST }}\n          ssh-tank-path: ${{ env.SSH_TANK_PATH }}\n          src-dir: ${{ env.MANYLINUX_CACHE_DIR }}/build/cpack\n          dst-dir: cpack\n      - name: Upload whl\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.entry != 'llvm15' && matrix.entry != 'cpu-asan-ubsan' && matrix.entry != 'cpu-tsan' }}\n        uses: Oneflow-Inc/get-oneflow/digest/upload@ci-test-with-cu118\n        timeout-minutes: 10\n        with:\n          digest: ${{ steps.save-cache.outputs.build-digest }}\n          entry: ${{ matrix.entry }}\n          ssh-tank-host: ${{ env.SSH_TANK_HOST }}\n          ssh-tank-path: ${{ env.SSH_TANK_PATH }}\n          src-dir: ${{ env.WHEELHOUSE_DIR }}\n          dst-dir: whl\n\n  find-test-cache-distributed:\n    name: \"Find test cache (distributed)\"\n    if: github.event.pull_request.draft == false && github.base_ref == 'master' && contains(github.event.pull_request.labels.*.name, 'need-test-distributed')\n    runs-on: ubuntu-latest\n    needs: [build-oneflow]\n    env:\n      ONEFLOW_SRC: .\n    outputs:\n      matrix: ${{ steps.find-cache.outputs.matrix }}\n    steps:\n      - name: Checkout Oneflow-Inc/oneflow\n        uses: actions/checkout@v2\n        with:\n          ref: ${{ github.event.pull_request.head.sha }}\n          repository: ${{github.event.pull_request.head.repo.full_name}}\n      - uses: Oneflow-Inc/get-oneflow/cache-complete/matrix/test@ci-test-with-cu118\n        name: find cache\n        id: find-cache\n        
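# with world-size 2 below, the generated matrix presumably contains one entry per rank; the distributed job then pins each entry to a rank via matrix.rank\n        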
timeout-minutes: 5\n        with:\n          runner-labels: |\n            self-hosted\n            linux\n          oneflow-src: ${{ env.ONEFLOW_SRC }}\n          include-distributed: true\n          world-size: 2\n          devices: |\n            cuda\n          tests: |\n            module\n\n  find-test-cache:\n    name: \"Find test cache\"\n    if: github.event.pull_request.draft == false && github.base_ref == 'master'\n    runs-on: ubuntu-latest\n    needs: [build-oneflow]\n    env:\n      ONEFLOW_SRC: .\n    outputs:\n      matrix: ${{ steps.find-cache.outputs.matrix }}\n    steps:\n      - name: Checkout Oneflow-Inc/oneflow\n        uses: actions/checkout@v2\n        with:\n          ref: ${{ github.event.pull_request.head.sha }}\n          repository: ${{github.event.pull_request.head.repo.full_name}}\n      - uses: Oneflow-Inc/get-oneflow/cache-complete/matrix/test@ci-test-with-cu118\n        name: find cache\n        id: find-cache\n        timeout-minutes: 5\n        with:\n          runner-labels: |\n            self-hosted\n            linux\n          oneflow-src: ${{ env.ONEFLOW_SRC }}\n          devices: |\n            cuda\n            cpu\n          tests: |\n            module\n            misc\n            speed-test\n\n  test-distributed:\n    name: Distributed test suite\n    needs: [find-test-cache-distributed, test]\n    runs-on: ${{ matrix.runs-on }}\n    timeout-minutes: 120\n    if: github.event.pull_request.draft == false && github.base_ref == 'master' && contains(github.event.pull_request.labels.*.name, 'need-test-distributed')\n    concurrency:\n      group: distributed-test-${{ matrix.entry }}-rank-${{ matrix.rank }}\n      cancel-in-progress: false\n    strategy:\n      fail-fast: true\n      max-parallel: 2\n      matrix: ${{ fromJson(needs.find-test-cache-distributed.outputs.matrix) }}\n    env:\n      ONEFLOW_SRC: .\n      TEST_CONTAINER_NAME: \"ci-test-distributed\"\n    steps:\n      - name: Fix permissions\n        if: ${{ contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          set -x\n          docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf *\n          docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf .pytest_cache\n      - name: Checkout Oneflow-Inc/oneflow\n        uses: actions/checkout@v2\n        with:\n          ref: ${{ github.event.pull_request.head.sha }}\n          repository: ${{github.event.pull_request.head.repo.full_name}}\n      - name: Checkout Oneflow-Inc/vision\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        uses: actions/checkout@v2\n        with:\n          repository: Oneflow-Inc/vision\n          # please use a commit here\n          ref: ${{ env.FLOW_VISION_COMMIT}}\n          path: ${{ env.FLOW_VISION_SRC}}\n      - name: Checkout Oneflow-Inc/one-fx\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        uses: actions/checkout@v2\n        with:\n          repository: Oneflow-Inc/one-fx\n          # please use a commit here\n          ref: ${{ env.ONE_FX_COMMIT}}\n          path: ${{ env.ONE_FX_SRC}}\n      - name: Checkout Oneflow-Inc/libai\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        uses: actions/checkout@v2\n        with:\n          repository: Oneflow-Inc/libai\n          # please use a commit here\n          ref: ${{ env.LIBAI_COMMIT}}\n          path: ${{ env.LIBAI_SRC}}\n      - name: Checkout Oneflow-Inc/oneflow_iree\n        if: ${{ 
!fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        uses: actions/checkout@v2\n        with:\n          repository: Oneflow-Inc/oneflow_iree\n          # please use a commit here\n          ref: ${{ env.ONEFLOW_IREE_COMMIT}}\n          path: ${{ env.ONEFLOW_IREE_SRC}}\n      - name: Remove container\n        timeout-minutes: 45\n        if: ${{ contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          docker rm -f ${{ env.TEST_CONTAINER_NAME }} || true\n      - uses: Oneflow-Inc/get-oneflow/cache-complete@ci-test-with-cu118\n        name: Save cache if successful\n        id: save-cache\n        timeout-minutes: 5\n        with:\n          oneflow-src: ${{ env.ONEFLOW_SRC }}\n          entry: ${{ matrix.entry }}\n          digest-type: ${{ matrix.digest-type }}\n          mark-as-completed: ${{ contains(matrix.runs-on, 'self-hosted') && github.event.pull_request.head.repo.full_name == github.repository }}\n      - name: Check digest cache result. If this step fails, it is usually caused by new commits pushed while this CI run is in progress.\n        if: ${{ fromJSON(steps.save-cache.outputs.cache-hit) != matrix.cache-hit }}\n        run: |\n          echo \"::error file=test.yml,line=204,col=10::steps.save-cache.outputs.cache-hit != matrix.cache-hit\"\n          exit 1\n      - name: Download wheel and packed liboneflow\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        uses: Oneflow-Inc/get-oneflow/digest/download@ci-test-with-cu118\n        id: download-digest\n        timeout-minutes: 10\n        with:\n          digest: ${{ steps.save-cache.outputs.build-digest }}\n          entry: ${{ matrix.compute-platform }}\n          ssh-tank-host: ${{ env.SSH_TANK_HOST }}\n          ssh-tank-path: ${{ env.SSH_TANK_PATH }}\n      - name: Get primary node\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        uses: Oneflow-Inc/get-oneflow/master-address@ci-test-with-cu118\n        id: get-primary-node\n        with:\n          rank: ${{ matrix.rank }}\n      - name: Set environment variables\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          set -x\n          extra_docker_args=\"\"\n          if [ \"${{ matrix.device }}\" == \"cpu\" ]; then\n            extra_docker_args+=\" --env ONEFLOW_TEST_CPU_ONLY=1\"\n            extra_docker_args+=\" --env CUDA_VISIBLE_DEVICES=-1\"\n          fi\n          echo \"EXTRA_DOCKER_ARGS=${extra_docker_args}\" >> $GITHUB_ENV\n          echo \"ONEFLOW_TEST_CACHE_DIR=$HOME/ci-cache/test_cache\" >> $GITHUB_ENV\n          echo \"ONEFLOW_TEST_DATASET_DIR=$HOME/dataset\" >> $GITHUB_ENV\n\n          echo \"ONEFLOW_WHEEL_PATH=${{ steps.download-digest.outputs.entry-dir }}/whl\" >> $GITHUB_ENV\n          echo \"ONEFLOW_CPACK_PATH=${{ steps.download-digest.outputs.entry-dir }}/cpack\" >> $GITHUB_ENV\n      - name: Set environment variables (distributed)\n        if: ${{ fromJson(matrix.is-distributed) }}\n        run: |\n          set -x\n          EXTRA_DOCKER_ARGS+=\" --network host \"\n          echo \"EXTRA_DOCKER_ARGS=${EXTRA_DOCKER_ARGS}\" >> $GITHUB_ENV\n      - name: Enable ONEFLOW_TEST_VERBOSE\n        if: ${{ contains(github.event.pull_request.labels.*.name, 'need-test-verbose') }}\n        run: |\n          EXTRA_DOCKER_ARGS+=\" --env ONEFLOW_TEST_VERBOSE=1\"\n          echo \"EXTRA_DOCKER_ARGS=${EXTRA_DOCKER_ARGS}\" >> $GITHUB_ENV\n      - name: Start container\n   
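     # the container is started detached and kept alive by the trailing sleep; later steps docker-exec into it, so one long-lived container serves the whole job\n   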
     if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        working-directory: ${{ env.ONEFLOW_SRC }}\n        run: |\n          docker run --gpus=all -d --rm --privileged --shm-size=8g \\\n            --pids-limit 2000 \\\n            --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \\\n            -v ${ONEFLOW_TEST_DATASET_DIR}:${ONEFLOW_TEST_DATASET_DIR}:ro \\\n            -v ${ONEFLOW_WHEEL_PATH}:${ONEFLOW_WHEEL_PATH}:ro \\\n            -v $HOME/test-container-cache/dot-local:/root/.local \\\n            -v $HOME/test-container-cache/dot-cache:/root/.cache \\\n            -e NODE_RANK=${{ matrix.rank }} \\\n            -e _MASTER_ADDR=${{ steps.get-primary-node.outputs.master-address }} \\\n            -e ONEFLOW_WHEEL_PATH=${ONEFLOW_WHEEL_PATH} \\\n            -e ONEFLOW_CI=1 \\\n            -v $PWD:$PWD \\\n            -w $PWD \\\n            -v ${ONEFLOW_TEST_CACHE_DIR}:${ONEFLOW_TEST_CACHE_DIR} \\\n            -e ONEFLOW_TEST_CACHE_DIR=${ONEFLOW_TEST_CACHE_DIR} \\\n            -e ONEFLOW_TEST_DATASET_DIR=${ONEFLOW_TEST_DATASET_DIR} \\\n            -e ONEFLOW_TIMEOUT_SECONDS=${{ env.ONEFLOW_TIMEOUT_SECONDS }} \\\n            -e ONEFLOW_THRAED_LOCAL_CACHED_SIZE=${{ env.ONEFLOW_THRAED_LOCAL_CACHED_SIZE }} \\\n            ${{ env.MLIR_DOCKER_ARGS }} \\\n            --name ${TEST_CONTAINER_NAME} \\\n            ${{ env.EXTRA_DOCKER_ARGS }} \\\n            ${{ env.TEST_WITH_TORCH_IMG_TAG }} \\\n            sleep 5400\n      - name: Test container\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          docker exec ${{ env.TEST_CONTAINER_NAME }} ls\n          docker exec ${{ env.TEST_CONTAINER_NAME }} python3 -m pip list\n      - name: Install OneFlow\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          ls ${ONEFLOW_WHEEL_PATH}\n          docker exec ${TEST_CONTAINER_NAME} python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple\n          docker exec ${TEST_CONTAINER_NAME} python3 -m pip install --find-links=${ONEFLOW_WHEEL_PATH} oneflow\n      - name: Install downstream libs\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.FLOW_VISION_SRC}}\n          docker exec ${TEST_CONTAINER_NAME} python3 -m pip install pybind11 --user\n          docker exec ${TEST_CONTAINER_NAME} python3 -m pip install tensorboardX==2.6 --user\n          docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.LIBAI_SRC}}\n          docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.ONEFLOW_IREE_SRC}}\n          docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.ONE_FX_SRC}}\n      - name: Module API test (distributed)\n        timeout-minutes: 90\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'module' && matrix.device == 'cuda' && fromJson(matrix.is-distributed) }}\n        continue-on-error: false\n        run: |\n          docker exec -e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/modules ${{ env.TEST_CONTAINER_NAME }} bash ci/test/2node_op_test_multi_client.sh\n      - name: Module API test (distributed, without IB)\n        timeout-minutes: 60\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'module' && matrix.device == 'cuda' && fromJson(matrix.is-distributed) && 
contains(github.event.pull_request.labels.*.name, 'need-distributed-without-ib')}}\n        continue-on-error: false\n        run: |\n          docker exec -e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/modules \\\n            -e ONEFLOW_LIBIBVERBS_PATH=invalid_lib \\\n            -e ONEFLOW_CI_DEVICE_NUMS=\"4\" \\\n            ${{ env.TEST_CONTAINER_NAME }} bash ci/test/2node_op_test_multi_client.sh\n      - name: Print stacks in all core files\n        timeout-minutes: 45\n        if: ${{ failure() && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          docker exec ${{ env.TEST_CONTAINER_NAME }} bash ci/test/print_stack_in_all_dirs.sh || true\n      - name: Remove automerge\n        if: ${{ failure() && contains(matrix.runs-on, 'self-hosted') && cancelled() == false && contains(github.event.pull_request.labels.*.name, 'automerge') }}\n        uses: actions/github-script@v4\n        with:\n          script: |\n            github.issues.removeLabel({\n              issue_number: context.issue.number,\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              name: 'automerge'\n            })\n            github.issues.createComment({\n              issue_number: context.issue.number,\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              body: 'CI failed when running job: ${{ matrix.entry }}. PR label automerge has been removed'\n            })\n      - name: Remove container\n        timeout-minutes: 45\n        if: ${{ always() && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          docker rm -f ${{ env.TEST_CONTAINER_NAME }} || true\n          docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf *\n\n  test:\n    name: Test suite\n    needs: [find-test-cache, source_info]\n    timeout-minutes: 120\n    runs-on: ${{ matrix.runs-on }}\n    if: github.event.pull_request.draft == false && github.base_ref == 'master'\n    strategy:\n      fail-fast: ${{ !contains(github.event.pull_request.labels.*.name, 'need-all-tests-even-fail') }}\n      max-parallel: 10\n      matrix: ${{ fromJson(needs.find-test-cache.outputs.matrix) }}\n    env:\n      ONEFLOW_SRC: .\n      TEST_CONTAINER_NAME: \"pr-${{ github.event.pull_request.number }}-run-id-${{ github.run_id }}-${{ matrix.entry }}-test\"\n      TEST_MANYLINUX_CONTAINER_NAME: \"pr-${{ github.event.pull_request.number }}-run-id-${{ github.run_id }}-${{ matrix.entry }}-test-manylinux\"\n      TEST_WITH_TF_IMG_TAG: registry.cn-beijing.aliyuncs.com/oneflow/test-with-tf-2.3.0:2f831e9354298a11447578e869d983959feb046f\n      TEST_MANYLINUX_IMG_TAG: registry.cn-beijing.aliyuncs.com/oneflow/manylinux2014_x86_64_cuda11.8:6455f9b8154333333e6285fde3747aaac4a92929\n      METRICS_DIR: metrics\n    steps:\n      - name: Set proxy\n        if: ${{ contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          echo \"https_proxy=${{ secrets.ONEFLOW_CI_HTTP_PROXY }}\" >> $GITHUB_ENV\n      - name: Fix permissions\n        if: ${{ contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          set -x\n          docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf *\n          docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf .pytest_cache\n      - name: Checkout Oneflow-Inc/oneflow\n        uses: actions/checkout@v2\n        with:\n          ref: ${{ github.event.pull_request.head.sha }}\n          repository: ${{github.event.pull_request.head.repo.full_name}}\n      - name: Checkout Oneflow-Inc/vision\n        if: ${{ 
!fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        uses: actions/checkout@v2\n        with:\n          repository: Oneflow-Inc/vision\n          # please use a commit here\n          ref: ${{ env.FLOW_VISION_COMMIT}}\n          path: ${{ env.FLOW_VISION_SRC}}\n      - name: Checkout Oneflow-Inc/libai\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        uses: actions/checkout@v2\n        with:\n          repository: Oneflow-Inc/libai\n          # please use a commit here\n          ref: ${{ env.LIBAI_COMMIT}}\n          path: ${{ env.LIBAI_SRC}}\n      - name: Checkout Oneflow-Inc/oneflow_face\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        uses: actions/checkout@v2\n        with:\n          repository: Oneflow-Inc/oneflow_face\n          # please use a commit here\n          ref: ${{ env.ONEFLOW_FACE_COMMIT}}\n          path: ${{ env.ONEFLOW_FACE_SRC}}\n      - name: Checkout Oneflow-Inc/oneflow_iree\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        uses: actions/checkout@v2\n        with:\n          repository: Oneflow-Inc/oneflow_iree\n          # please use a commit here\n          ref: ${{ env.ONEFLOW_IREE_COMMIT}}\n          path: ${{ env.ONEFLOW_IREE_SRC}}\n      - name: Checkout Oneflow-Inc/one-fx\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        uses: actions/checkout@v2\n        with:\n          repository: Oneflow-Inc/one-fx\n          # please use a commit here\n          ref: ${{ env.ONE_FX_COMMIT}}\n          path: ${{ env.ONE_FX_SRC}}\n      - name: Remove container\n        timeout-minutes: 45\n        if: ${{ contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          docker rm -f ${{ env.TEST_CONTAINER_NAME }} || true\n      - name: Remove manylinux container\n        timeout-minutes: 45\n        if: ${{ contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          docker rm -f ${{ env.TEST_MANYLINUX_CONTAINER_NAME }} || true\n      - uses: Oneflow-Inc/get-oneflow/cache-complete@ci-test-with-cu118\n        name: Save cache if successful\n        id: save-cache\n        timeout-minutes: 5\n        with:\n          oneflow-src: ${{ env.ONEFLOW_SRC }}\n          entry: ${{ matrix.entry }}\n          digest-type: ${{ matrix.digest-type }}\n          mark-as-completed: ${{ contains(matrix.runs-on, 'self-hosted') && github.event.pull_request.head.repo.full_name == github.repository }}\n      - name: Check digest cache result. 
If this step fails, it is usually caused by new commits pushed while this CI run is in progress.\n        if: ${{ fromJSON(steps.save-cache.outputs.cache-hit) != matrix.cache-hit }}\n        run: |\n          echo \"::error file=test.yml,line=204,col=10::steps.save-cache.outputs.cache-hit != matrix.cache-hit\"\n          exit 1\n      - name: Download wheel and packed liboneflow\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        uses: Oneflow-Inc/get-oneflow/digest/download@ci-test-with-cu118\n        id: download-digest\n        timeout-minutes: 10\n        with:\n          digest: ${{ steps.save-cache.outputs.build-digest }}\n          entry: ${{ matrix.compute-platform }}\n          ssh-tank-host: ${{ env.SSH_TANK_HOST }}\n          ssh-tank-path: ${{ env.SSH_TANK_PATH }}\n      - name: Download ASAN and UBSAN wheel and packed liboneflow\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') && matrix.device == 'cpu' && false }}\n        uses: Oneflow-Inc/get-oneflow/digest/download@ci-test-with-cu118\n        id: asan-ubsan-download-digest\n        timeout-minutes: 10\n        with:\n          digest: ${{ steps.save-cache.outputs.build-digest }}\n          entry: cpu-asan-ubsan\n          ssh-tank-host: ${{ env.SSH_TANK_HOST }}\n          ssh-tank-path: ${{ env.SSH_TANK_PATH }}\n      - name: Download TSAN wheel and packed liboneflow\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') && matrix.device == 'cpu' && false }}\n        uses: Oneflow-Inc/get-oneflow/digest/download@ci-test-with-cu118\n        id: tsan-download-digest\n        timeout-minutes: 10\n        with:\n          digest: ${{ steps.save-cache.outputs.build-digest }}\n          entry: cpu-tsan\n          ssh-tank-host: ${{ env.SSH_TANK_HOST }}\n          ssh-tank-path: ${{ env.SSH_TANK_PATH }}\n      - name: Enable TF container\n        if: ${{ fromJSON(matrix.is-single-client) }}\n        run: |\n          echo \"TEST_IMG_TAG=${TEST_WITH_TF_IMG_TAG}\" >> $GITHUB_ENV\n      - name: Enable Pytorch container\n        if: ${{ !fromJSON(matrix.is-single-client) }}\n        run: |\n          echo \"TEST_IMG_TAG=${TEST_WITH_TORCH_IMG_TAG}\" >> $GITHUB_ENV\n      - name: Set environment variables\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          set -x\n          extra_docker_args=\"\"\n          if [ \"${{ matrix.device }}\" == \"cpu\" ]; then\n            extra_docker_args+=\" --env ONEFLOW_TEST_CPU_ONLY=1\"\n            extra_docker_args+=\" --env CUDA_VISIBLE_DEVICES=-1\"\n          fi\n          echo \"EXTRA_DOCKER_ARGS=${extra_docker_args}\" >> $GITHUB_ENV\n          echo \"ONEFLOW_TEST_CACHE_DIR=$HOME/ci-cache/test_cache\" >> $GITHUB_ENV\n          echo \"ONEFLOW_TEST_DATASET_DIR=$HOME/dataset\" >> $GITHUB_ENV\n\n          echo \"ONEFLOW_WHEEL_PATH=${{ steps.download-digest.outputs.entry-dir }}/whl\" >> $GITHUB_ENV\n          echo \"ONEFLOW_CPACK_PATH=${{ steps.download-digest.outputs.entry-dir }}/cpack\" >> $GITHUB_ENV\n          echo \"DOCS_PATH=docs/${{ github.repository }}/pr/${{ github.event.pull_request.number }}\" >> $GITHUB_ENV\n      - name: Set environment variables (experimental flags)\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') && fromJson(matrix.is-experimental) }}\n        run: |\n          EXTRA_DOCKER_ARGS+=\" --env ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH=1\"\n          
EXTRA_DOCKER_ARGS+=\" --env ONEFLOW_THREAD_ENABLE_LOCAL_MESSAGE_QUEUE=1\"\n          EXTRA_DOCKER_ARGS+=\" --env ONEFLOW_KERNEL_DISABLE_BLOB_ACCESS_CHECKER=1\"\n          echo \"EXTRA_DOCKER_ARGS=${EXTRA_DOCKER_ARGS}\" >> $GITHUB_ENV\n      - name: Set Thread Limit (CPU)\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.device == 'cpu' }}\n        run: |\n          echo \"THREAD_LIMIT=25000\" >> $GITHUB_ENV\n      - name: Set Thread Limit (CUDA)\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.device == 'cuda' }}\n        run: |\n          echo \"THREAD_LIMIT=20000\" >> $GITHUB_ENV\n      - name: Enable ONEFLOW_TEST_VERBOSE\n        if: ${{ contains(github.event.pull_request.labels.*.name, 'need-test-verbose') }}\n        run: |\n          EXTRA_DOCKER_ARGS+=\" --env ONEFLOW_TEST_VERBOSE=1\"\n          echo \"EXTRA_DOCKER_ARGS=${EXTRA_DOCKER_ARGS}\" >> $GITHUB_ENV\n      - name: Pull image\n        continue-on-error: true\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          docker pull ${{ env.TEST_IMG_TAG }}\n      - name: Unzip packed liboneflow\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') && !fromJson(matrix.is-xla) }}\n        run: |\n          unzip ${{ env.ONEFLOW_CPACK_PATH }}/liboneflow-ci-linux.zip\n      - name: Unzip packed sanitized liboneflow\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') && !fromJson(matrix.is-xla) && matrix.device == 'cpu' && false }}\n        run: |\n          unzip ${{ steps.asan-ubsan-download-digest.outputs.entry-dir }}/cpack/liboneflow-ci-linux.zip -d asan-ubsan\n          unzip ${{ steps.tsan-download-digest.outputs.entry-dir }}/cpack/liboneflow-ci-linux.zip -d tsan\n      - name: Start container\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        working-directory: ${{ env.ONEFLOW_SRC }}\n        run: |\n          docker run --gpus=all -d --rm --privileged --shm-size=8g \\\n            --pids-limit ${{ env.THREAD_LIMIT }} \\\n            --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \\\n            -v ${ONEFLOW_TEST_DATASET_DIR}:${ONEFLOW_TEST_DATASET_DIR}:ro \\\n            -v ${ONEFLOW_WHEEL_PATH}:${ONEFLOW_WHEEL_PATH}:ro \\\n            -v $HOME/test-container-cache/dot-local:/root/.local \\\n            -v $HOME/test-container-cache/dot-cache:/root/.cache \\\n            -e ONEFLOW_WHEEL_PATH=${ONEFLOW_WHEEL_PATH} \\\n            -e ONEFLOW_CI=1 \\\n            -e NVIDIA_TF32_OVERRIDE=0 \\\n            -e NCCL_P2P_DISABLE=1 \\\n            -v $PWD:$PWD \\\n            -w $PWD \\\n            -v ${ONEFLOW_TEST_CACHE_DIR}:${ONEFLOW_TEST_CACHE_DIR} \\\n            -e ONEFLOW_TEST_CACHE_DIR=${ONEFLOW_TEST_CACHE_DIR} \\\n            -e ONEFLOW_TEST_DATASET_DIR=${ONEFLOW_TEST_DATASET_DIR} \\\n            -e ONEFLOW_TIMEOUT_SECONDS=${{ env.ONEFLOW_TIMEOUT_SECONDS }} \\\n            -e ONEFLOW_THRAED_LOCAL_CACHED_SIZE=${{ env.ONEFLOW_THRAED_LOCAL_CACHED_SIZE }} \\\n            ${{ env.MLIR_DOCKER_ARGS }} \\\n            --name ${TEST_CONTAINER_NAME} \\\n            ${{ env.EXTRA_DOCKER_ARGS }} \\\n            ${{ env.TEST_IMG_TAG }} \\\n            sleep 7200\n      - name: Start manylinux container\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        working-directory: ${{ env.ONEFLOW_SRC }}\n        # For unknown reason we need to disable the requirement from nvidia docker\n        # by -e 
NVIDIA_DISABLE_REQUIRE=true\n        run: |\n          docker run --gpus=all -d --rm --privileged --shm-size=8g \\\n            --pids-limit ${{ env.THREAD_LIMIT }} \\\n            --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \\\n            -v ${ONEFLOW_TEST_DATASET_DIR}:${ONEFLOW_TEST_DATASET_DIR}:ro \\\n            -v ${ONEFLOW_WHEEL_PATH}:${ONEFLOW_WHEEL_PATH}:ro \\\n            -v $HOME/test-container-cache/dot-local:/root/.local \\\n            -v $HOME/test-container-cache/dot-cache:/root/.cache \\\n            -e NVIDIA_DISABLE_REQUIRE=true \\\n            -e ONEFLOW_WHEEL_PATH=${ONEFLOW_WHEEL_PATH} \\\n            -e ONEFLOW_CI=1 \\\n            -v $PWD:$PWD \\\n            -w $PWD \\\n            -v ${ONEFLOW_TEST_CACHE_DIR}:${ONEFLOW_TEST_CACHE_DIR} \\\n            -e ONEFLOW_TEST_CACHE_DIR=${ONEFLOW_TEST_CACHE_DIR} \\\n            -e ONEFLOW_TEST_DATASET_DIR=${ONEFLOW_TEST_DATASET_DIR} \\\n            -e ONEFLOW_TIMEOUT_SECONDS=${{ env.ONEFLOW_TIMEOUT_SECONDS }} \\\n            -e ONEFLOW_THRAED_LOCAL_CACHED_SIZE=${{ env.ONEFLOW_THRAED_LOCAL_CACHED_SIZE }} \\\n            ${{ env.MLIR_DOCKER_ARGS }} \\\n            --name ${TEST_MANYLINUX_CONTAINER_NAME} \\\n            ${{ env.EXTRA_DOCKER_ARGS }} \\\n            ${{ env.TEST_MANYLINUX_IMG_TAG }} \\\n            sleep 7200\n      - name: Exe test\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' }}\n        timeout-minutes: 20\n        run: |\n          docker exec ${{ env.TEST_MANYLINUX_CONTAINER_NAME }} ./liboneflow-ci-linux/bin/oneflow_testexe\n      - name: Exe test (C++ API)\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' }}\n        timeout-minutes: 20\n        run: |\n          docker exec -e ONEFLOW_SERVING_DEBUG=1 ${{ env.TEST_MANYLINUX_CONTAINER_NAME }} ./liboneflow-ci-linux/bin/oneflow_cpp_api_testexe --gtest_filter=-Api.embedding*\n      - name: Exe test (C++ API with sanitizers)\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && matrix.device == 'cpu' && false }}\n        timeout-minutes: 10\n        run: |\n          docker exec -e UBSAN_OPTIONS=suppressions=.ubsan-suppressions -e ASAN_OPTIONS=strict_string_checks=1:detect_stack_use_after_return=1 -e LSAN_OPTIONS=suppressions=.lsan-suppressions ${{ env.TEST_MANYLINUX_CONTAINER_NAME }} ./asan-ubsan/liboneflow-ci-linux/bin/oneflow_cpp_api_testexe --gtest_filter=Api.graph_\\*\n          # Run 5 times to avoid false positives caused by an occasional lack of stack info\n          docker exec -e TSAN_OPTIONS=\"history_size=7 suppressions=.tsan-suppressions\" ${{ env.TEST_MANYLINUX_CONTAINER_NAME }} bash -c \"./tsan/liboneflow-ci-linux/bin/oneflow_cpp_api_testexe || ./tsan/liboneflow-ci-linux/bin/oneflow_cpp_api_testexe || ./tsan/liboneflow-ci-linux/bin/oneflow_cpp_api_testexe || ./tsan/liboneflow-ci-linux/bin/oneflow_cpp_api_testexe || ./tsan/liboneflow-ci-linux/bin/oneflow_cpp_api_testexe\"\n      - name: Test container\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          docker exec ${{ env.TEST_CONTAINER_NAME }} ls\n          docker exec ${{ env.TEST_CONTAINER_NAME }} python3 -m pip list\n      - name: Install OneFlow\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          ls ${ONEFLOW_WHEEL_PATH}\n          docker exec ${TEST_CONTAINER_NAME} python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple\n          docker exec 
${TEST_CONTAINER_NAME} python3 -m pip install -U --find-links=${ONEFLOW_WHEEL_PATH} oneflow\n      - name: Install downstream libs\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.FLOW_VISION_SRC }}\n          docker exec ${TEST_CONTAINER_NAME} python3 -m pip install pybind11 --user\n          docker exec ${TEST_CONTAINER_NAME} python3 -m pip install tensorboardX==2.6 --user\n          docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.LIBAI_SRC }}\n          docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.ONEFLOW_FACE_SRC }}\n          docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.ONEFLOW_IREE_SRC }}\n          docker exec ${TEST_CONTAINER_NAME} python3 -m pip install -e ${{ env.ONE_FX_SRC }}\n      - name: Run OneFlow doctor\n        if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          docker exec ${{ env.TEST_CONTAINER_NAME }} python3 -m oneflow --doctor\n      - name: Build documentation\n        timeout-minutes: 10\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && matrix.device == 'cpu' }}\n        run: |\n          docker exec ${{ env.TEST_CONTAINER_NAME }} bash ci/test/build_docs.sh\n      - name: Upload documentation\n        id: upload-docs\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && matrix.device == 'cpu' && github.repository == 'Oneflow-Inc/oneflow' }}\n        continue-on-error: true\n        uses: ./.github/actions/upload_oss\n        with:\n          src_path: build-docs/build/html\n          oss_dst_path: oss://oneflow-staging/${{ env.DOCS_PATH }}\n          oss_access_key_id: ${{ secrets.OSS_ACCESS_KEY_ID }}\n          oss_access_key_secret: ${{ secrets.OSS_ACCESS_KEY_SECRET }}\n      - name: Post docs url\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && matrix.device == 'cpu' && github.repository == 'Oneflow-Inc/oneflow' && steps.upload-docs.outcome == 'success' }}\n        continue-on-error: true\n        uses: actions/github-script@v4\n        with:\n          script: |\n            github.issues.createComment({\n              issue_number: context.issue.number,\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              body: \"View latest API docs preview at: https://oneflow-staging.oss-cn-beijing.aliyuncs.com/${{ env.DOCS_PATH }}/\"\n            })\n      - name: Doctest\n        timeout-minutes: 45\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && matrix.device == 'cuda' }}\n        run: |\n          docker exec ${{ env.TEST_CONTAINER_NAME }} bash ci/test/doctest.sh\n      - name: Checkout Oneflow-Inc/models\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'speed-test' && matrix.device == 'cuda' }}\n        uses: actions/checkout@v2\n        with:\n          repository: Oneflow-Inc/models\n          ref: d6b2b8260e87541726ed87361171438d258e6a4d\n          path: oneflow-models\n      - name: ResNet50 Graph DDP test\n        id: models-resnet50\n        timeout-minutes: 20\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'speed-test' && matrix.device == 'cuda' }}\n        run: |\n          docker exec -e NCCL_DEBUG=INFO -e ONEFLOW_MODELS_DIR=$PWD/oneflow-models ${{ env.TEST_CONTAINER_NAME }} bash 
ci/test/test_resnet50_graph_ddp.sh\n      - name: Speed test\n        id: speed\n        timeout-minutes: 20\n        continue-on-error: ${{ !contains(github.event.pull_request.labels.*.name, 'need-pass-speed-test') }}\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'speed-test' && matrix.device == 'cuda' }}\n        run: |\n          docker exec -e ONEFLOW_MODELS_DIR=$PWD/oneflow-models ${{ env.TEST_CONTAINER_NAME }} bash ci/test/test_speed_multi_client.sh\n      - name: Save speed stats\n        if: ${{ always() && !fromJson(matrix.cache-hit) && matrix.test-type == 'speed-test' && matrix.device == 'cuda' }}\n        run: |\n          mkdir -p ${{ env.METRICS_DIR }}\n          echo \"${{ steps.speed.outputs.stats }}\" >> ${{ env.METRICS_DIR }}/speed_stats.txt\n      - name: Upload speed stats\n        if: ${{ always() && !fromJson(matrix.cache-hit) && matrix.test-type == 'speed-test' && matrix.device == 'cuda' }}\n        # must succeed if it is a branch of Oneflow-Inc/oneflow\n        continue-on-error: ${{ !(github.repository == 'Oneflow-Inc/oneflow') }}\n        uses: ./.github/actions/upload_oss\n        with:\n          src_path: ${{ env.METRICS_DIR }}\n          oss_dst_path: oss://oneflow-log/${{ github.repository }}/metrics/pr/${{ github.event.pull_request.number }}/${{ github.event.pull_request.head.sha }}/${{github.run_id}}\n          oss_access_key_id: ${{ secrets.OSS_ACCESS_KEY_ID }}\n          oss_access_key_secret: ${{ secrets.OSS_ACCESS_KEY_SECRET }}\n      - name: Post speed stats\n        if: ${{ always() && !fromJson(matrix.cache-hit) && matrix.test-type == 'speed-test' && matrix.device == 'cuda' }}\n        continue-on-error: true\n        uses: actions/github-script@v4\n        with:\n          script: |\n            github.issues.createComment({\n              issue_number: context.issue.number,\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              body: \"<details>\\n <summary>Speed stats:</summary>\\n\\n ``` \\n${{ steps.speed.outputs.stats }}\\n ``` \\n\\n</details>\".replace(/\\\\n/g, '\\n')\n            })\n      - name: Run tests in changed files compared to the default branch 100 times\n        timeout-minutes: 60\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'module' && !fromJson(matrix.is-distributed) && steps.py-diff.outputs.has_changed_python_tests }}\n        run: |\n          docker exec -e ONEFLOW_TEST_DIR=diff \\\n            -e ONEFLOW_TEST_FILES=\"${{needs.source_info.outputs.changed_python_tests}}\" \\\n            ${{ env.TEST_CONTAINER_NAME }} bash ci/test/generic_test_multi_client.sh\n      - name: Expensive tests (models and cases that require exclusive access to the GPU)\n        timeout-minutes: 45\n        if: ${{ !fromJson(matrix.cache-hit) && (matrix.test-type == 'speed-test' || (matrix.test-type == 'misc' && matrix.device == 'cuda')) && !fromJson(matrix.is-distributed) }}\n        run: |\n          docker exec \\\n            -e ONEFLOW_TEST_TENSOR_SIZE_LIMIT_MB=1024 \\\n            -e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/expensive \\\n            ${{ env.TEST_CONTAINER_NAME }} bash ci/test/expensive_generic_test_multi_client.sh\n      - name: Module API test\n        timeout-minutes: 60\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'module' && !fromJson(matrix.is-distributed) }}\n        run: |\n          docker exec -e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/modules ${{ env.TEST_CONTAINER_NAME }} bash ci/test/generic_test_multi_client.sh\n    
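  # the Graph API test below also launches 8 processes to cover graphs whose device count differs from the process count (test_neq_device_process_num.py)\n    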
  - name: Graph API test\n        timeout-minutes: 45\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' }}\n        run: |\n          docker exec -e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/graph ${{ env.TEST_CONTAINER_NAME }} bash ci/test/generic_test_multi_client.sh\n          docker exec ${{ env.TEST_CONTAINER_NAME }} python3 -m oneflow.distributed.launch --nproc_per_node 8 $PWD/python/oneflow/test/graph/test_neq_device_process_num.py\n      - name: libai test\n        timeout-minutes: 45\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && matrix.device == 'cuda' }}\n        run: |\n          docker exec -e ONEFLOW_TEST_DEVICE_NUM=4 -w $PWD/${{ env.LIBAI_SRC }} ${{ env.TEST_CONTAINER_NAME }} python3 -m oneflow.distributed.launch --nproc_per_node 4 -m unittest -f tests/models/test_bert.py\n          docker exec -e ONEFLOW_TEST_DEVICE_NUM=4 -w $PWD/${{ env.LIBAI_SRC }} ${{ env.TEST_CONTAINER_NAME }} python3 -m oneflow.distributed.launch --nproc_per_node 4 -m unittest -f tests/models/test_gpt.py\n          docker exec -e ONEFLOW_TEST_DEVICE_NUM=4 -w $PWD/${{ env.LIBAI_SRC }} ${{ env.TEST_CONTAINER_NAME }} python3 -m oneflow.distributed.launch --nproc_per_node 4 -m unittest -f tests/models/test_t5.py\n          docker exec -e ONEFLOW_TEST_DEVICE_NUM=4 -w $PWD/${{ env.LIBAI_SRC }} ${{ env.TEST_CONTAINER_NAME }} python3 -m oneflow.distributed.launch --nproc_per_node 4 -m unittest -f tests/models/test_vit.py\n      - name: oneflow_face test\n        timeout-minutes: 30\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && matrix.device == 'cuda' }}\n        run: |\n          docker exec -e ONEFLOW_TEST_DEVICE_NUM=4 -w $PWD/${{ env.ONEFLOW_FACE_SRC }} ${{ env.TEST_CONTAINER_NAME }} python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest tests/train/test_train.py\n      - name: oneflow_iree test\n        timeout-minutes: 45\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && false }}\n        run: |\n          docker exec -w $PWD/${{ env.ONEFLOW_IREE_SRC }} ${{ env.TEST_CONTAINER_NAME }} python3 -m pytest examples\n      - name: IR tests\n        timeout-minutes: 45\n        if: ${{ !fromJson(matrix.cache-hit) && (matrix.test-type == 'misc' && matrix.device == 'cuda') && !fromJson(matrix.is-distributed) }}\n        run: |\n          docker exec \\\n            -e ONEFLOW_TEST_TENSOR_SIZE_LIMIT_MB=1024 \\\n            ${{ env.TEST_CONTAINER_NAME }} bash ci/test/ir_tests.sh\n      - name: Exception API test\n        timeout-minutes: 45\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && false }}\n        run: docker exec ${{ env.TEST_CONTAINER_NAME }} bash ci/test/multi_client_exception_test.sh\n      - name: Misc test\n        timeout-minutes: 45\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' }}\n        run: |\n          docker exec -e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/misc ${{ env.TEST_CONTAINER_NAME }} bash ci/test/generic_test_multi_client.sh\n      - name: Dataloader API test\n        timeout-minutes: 45\n        # TODO(luyang): dataset check fails\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && false }}\n        run: |\n          docker exec -e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/dataloader ${{ env.TEST_CONTAINER_NAME }} bash ci/test/generic_test_multi_client.sh\n      - name: Tensor API test\n        timeout-minutes: 45\n        if: ${{ !fromJson(matrix.cache-hit) && 
matrix.test-type == 'misc' }}\n        run: |\n          docker exec -e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/tensor ${{ env.TEST_CONTAINER_NAME }} bash ci/test/generic_test_multi_client.sh\n      - name: Test mocking torch by script\n        timeout-minutes: 45\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'module' }}\n        run: |\n          docker exec ${{ env.TEST_CONTAINER_NAME }} bash -x ci/test/test_mock_script.sh\n      - name: Test mocking torch by function\n        timeout-minutes: 45\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'module' }}\n        run: |\n          docker exec ${{ env.TEST_CONTAINER_NAME }} bash -x ci/test/test_mock_function.sh\n      - name: Benchmark Test\n        timeout-minutes: 100\n        if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'benchmark' && matrix.device == 'cuda' }}\n        uses: Oneflow-Inc/get-oneflow/pytest-benchmark@ci-test-with-cu118\n        with:\n          collect-path: ${{ env.FLOW_VISION_SRC }}/benchmark\n          container-name: ${{ env.TEST_CONTAINER_NAME }}\n          unknown-threshold: 30\n          error-threshold: 40\n      - name: Remove automerge\n        if: ${{ failure() && contains(matrix.runs-on, 'self-hosted') && cancelled() == false && contains(github.event.pull_request.labels.*.name, 'automerge') }}\n        uses: actions/github-script@v4\n        with:\n          script: |\n            github.issues.removeLabel({\n              issue_number: context.issue.number,\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              name: 'automerge'\n            })\n            github.issues.createComment({\n              issue_number: context.issue.number,\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              body: 'CI failed when running job: ${{ matrix.entry }}. 
PR label automerge has been removed'\n            })\n      - name: Print stacks in all core files\n        timeout-minutes: 45\n        if: ${{ failure() && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          docker exec ${{ env.TEST_CONTAINER_NAME }} bash ci/test/print_stack_in_all_dirs.sh || true\n      - name: Query system status\n        timeout-minutes: 45\n        if: ${{ failure() && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          nvidia-smi || true\n          docker ps || true\n      - name: Remove container\n        timeout-minutes: 45\n        if: ${{ always() && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          docker rm -f ${{ env.TEST_CONTAINER_NAME }} || true\n      - name: Remove manylinux container\n        timeout-minutes: 45\n        if: ${{ always() && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          docker rm -f ${{ env.TEST_MANYLINUX_CONTAINER_NAME }} || true\n      - name: Clean workspace\n        timeout-minutes: 45\n        if: ${{ always() && contains(matrix.runs-on, 'self-hosted') }}\n        run: |\n          docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf *\n\n  static_analysis_with_clang_on_diff:\n    name: Static analysis with clang on diff\n    runs-on: ubuntu-22.04\n    if: github.event.pull_request.draft == false && github.base_ref == 'master'\n    steps:\n      - name: Check out OneFlow\n        uses: actions/checkout@v2\n        with:\n          ref: ${{ github.event.pull_request.head.sha }}\n          repository: ${{github.event.pull_request.head.repo.full_name}}\n          fetch-depth: 0\n      - uses: Oneflow-Inc/get-oneflow/cache-complete@ci-test-with-cu118\n        name: Save cache if successful\n        id: save-cache\n        timeout-minutes: 5\n        with:\n          oneflow-src: .\n          entry: static_analysis_with_clang_on_diff\n          digest-type: build\n          mark-as-completed: ${{ github.event.pull_request.head.repo.full_name == github.repository }}\n      - name: Install dependencies\n        if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) }}\n        run: |\n          sudo apt-get update\n          sudo apt-get install -y libopenblas-dev nasm python3-pip ninja-build ccache\n      - name: Download OneFlow custom clang-tidy\n        if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) }}\n        run: |\n          wget https://github.com/Oneflow-Inc/llvm-project/releases/download/maybe-16.0.0/oneflow-clang-tidy-16\n          wget https://raw.githubusercontent.com/oneflow-inc/llvm-project/maybe/clang-tools-extra/clang-tidy/tool/clang-tidy-diff.py\n          chmod +x oneflow-clang-tidy-16 clang-tidy-diff.py\n      - name: Cache third party dir\n        uses: actions/cache@v4\n        if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) }}\n        with:\n          path: ~/.ccache\n          key: clang-tidy-diff-third-party-ccache-${{ hashFiles('**/CMakeLists.txt') }}-${{ hashFiles('**/*.cmake') }}\n          restore-keys: |\n            clang-tidy-diff-third-party-ccache-${{ hashFiles('**/CMakeLists.txt') }}-\n            clang-tidy-diff-third-party-ccache-\n      - name: Build third party libs and generate files\n        if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) }}\n        run: |\n          export CCACHE_COMPRESS=true\n          export CCACHE_MAXSIZE=500M\n          mkdir build\n          cd build\n          cmake .. 
-C ../cmake/caches/international/cpu.cmake \\\n            -DCMAKE_BUILD_TYPE=Release \\\n            -DBUILD_TESTING=OFF \\\n            -DCMAKE_C_COMPILER_LAUNCHER=ccache \\\n            -DCMAKE_CXX_COMPILER_LAUNCHER=ccache\n          cmake --build . -j$(nproc) --target oneflow_deps of_protoobj of_functional_obj of_functional_tensor_obj of_op_schema\n      - name: Fetch upstream\n        if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) && github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name }}\n        run: |\n          git remote add upstream https://github.com/Oneflow-Inc/oneflow\n          git fetch upstream\n      - name: Run clang-tidy for modified files\n        # use clang as compiler for correct compiler flags\n        if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) }}\n        run: |\n          sudo apt install -y clang-12 lldb-12 lld-12 libfuse2\n          cd build\n          rm CMakeCache.txt\n          cmake .. -C ../cmake/caches/international/cpu.cmake \\\n            -DCMAKE_C_COMPILER=clang-12 \\\n            -DCMAKE_CXX_COMPILER=clang++-12 \\\n            -DCMAKE_BUILD_TYPE=Release \\\n            -DBUILD_TESTING=OFF \\\n            -DCMAKE_EXPORT_COMPILE_COMMANDS=ON\n          cd ..\n          git diff -U0 ${{ github.event.pull_request.base.sha }} | ./clang-tidy-diff.py -clang-tidy-binary ./oneflow-clang-tidy-16 -path build -allow-enabling-alpha-checkers -j $(nproc) -p1 -extra-arg=\"-Xclang\" -extra-arg=\"-analyzer-config\" -extra-arg=\"-Xclang\" -extra-arg=\"aggressive-binary-operation-simplification=true\" -warnings-as-errors=\"$(cat ./ci/check/clang_tidy_warnings_as_errors_on_diff)\"\n      - name: Check for missing error messages in changed files\n        if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) && contains(github.event.pull_request.labels.*.name, 'need-check-error-message') }}\n        run: |\n          git diff -U0 ${{ github.event.pull_request.base.sha }} | ./clang-tidy-diff.py -clang-tidy-binary ./oneflow-clang-tidy-16 -path build -allow-enabling-alpha-checkers -j $(nproc) -p1 -extra-arg=\"-Xclang\" -extra-arg=\"-analyzer-config\" -extra-arg=\"-Xclang\" -extra-arg=\"aggressive-binary-operation-simplification=true\" -checks=-*,maybe-need-error-msg -warnings-as-errors=* -skip-line-filter\n      - name: Remove automerge\n        if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) && failure() && cancelled() == false && contains(github.event.pull_request.labels.*.name, 'automerge') }}\n        uses: actions/github-script@v4\n        with:\n          script: |\n            github.issues.removeLabel({\n              issue_number: context.issue.number,\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              name: 'automerge'\n            })\n            github.issues.createComment({\n              issue_number: context.issue.number,\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              body: 'Static analysis with clang failed. PR label automerge has been removed'\n            })\n"
  },
  {
    "path": ".gitignore",
    "content": "/build\n/build-*\n/docs/build/\n/docs/build-cn/\n/docs/source/generated\n/cmake-build-*\n/dist\n/third_party/\n/examples/**/oneflow\n/benchmark/**/oneflow\nlog/\n*plan\ncore.*\n*.pyc\n*.ipynb\n/.vscode\n/.idea\n/manylinux*\nwheelhouse/\nwheelhouse*\n.DS_Store\n/tmp_wheel\n/oneflow/python/__export_symbols__.py\n/oneflow/python/compatibility.py\n/oneflow/python/framework/sysconfig_gen.py\n/oneflow/python/test/ops/localhost_script_*.sh\n.clangd\ncompile_commands.json\n.cache\n/oneflow-src.zip\n/oneflow_temp\n/distributed-tmp\n/serving-tmp\ntest_tmp_dir\nunittest-log-*\n/oneflow/python\n/oneflow/compatible_single_client_python\n/benchmarks\n/oneflow/python/version.py\n/data-test\n/tmp\n/python/oneflow/test/dataloader/data-test/\n\n/target\nsaved_model\n/devcontainer-cache\n\nop_prof.csv\n*.lock\n"
  },
  {
    "path": ".lsan-suppressions",
    "content": "leak:CommandT\n"
  },
  {
    "path": ".mergify.yml",
    "content": "pull_request_rules:\n  - name: automatic update for PR with label “automerge“\n    conditions:\n      - \"#approved-reviews-by>=2\"\n      - -conflict # skip conflicts\n      - -draft # skip draft PRs\n      - label=\"automerge\"\n    actions:\n      update:\n  - name: automatic merge\n    conditions:\n      - \"#approved-reviews-by>=2\"\n      - -conflict # skip conflicts\n      - -draft # skip draft PRs\n      - label=\"automerge\"\n      - \"#commits-behind==0\"\n      - -closed\n    actions:\n      merge:\n        method: squash\n"
  },
  {
    "path": ".tsan-suppressions",
    "content": "# These four group of functions are designed to be thread unsafe,\n# it's user's responsibility to use them correctly.\nrace:ThreadUnsafe\nrace:thread_unsafe\nrace:flying_instruction_cnt\nrace:total_erased_instruction_cnt\nrace:ToShape\n# glog\nrace:google::\n# ~basic_string() in DenseElementsAttrToTensor interferes with\n# ~~AccessBlobArgCbInstructionPolicy(). Perhaps it's a false\n# positive.\nrace:~basic_string\n"
  },
  {
    "path": ".ubsan-suppressions",
    "content": "# llvm\nvptr:Class.cpp\n"
  },
  {
    "path": "CMakeLists.txt",
    "content": "# Minimum CMake required\nset(CMAKE_POLICY_DEFAULT_CMP0135 NEW)\ncmake_minimum_required(VERSION 3.18.0)\n\nset(CMAKE_INSTALL_MESSAGE LAZY CACHE STRING \"\")\nset(CMAKE_EXPORT_COMPILE_COMMANDS ON CACHE BOOL \"\")\n\noption(THIRD_PARTY \"Build third party\" ON)\noption(ONEFLOW \"Build oneflow\" ON)\n\nif(NOT THIRD_PARTY AND NOT ONEFLOW)\n  message(FATAL_ERROR \"at least one of flags THIRD_PARTY and ONEFLOW should be ON\")\nendif()\n\noption(USE_CLANG_FORMAT \"\" OFF)\noption(USE_CLANG_TIDY \"\" OFF)\noption(BUILD_PYTHON \"\" ON)\noption(BUILD_CPP_API \"Option to build OneFlow C++ API (beta)\" OFF)\noption(BUILD_RDMA \"\" OFF)\noption(BUILD_CUDA \"\" ON)\noption(BUILD_TESTING \"\" OFF)\noption(BUILD_GIT_VERSION \"\" ON)\noption(BUILD_PROFILER \"\" OFF)\noption(BUILD_FOR_CI \"\" OFF)\noption(WITH_COCOAPI \"Option to build with COCO API\" ON)\noption(WITH_ZLIB \"\" ON)\noption(WITH_ONEDNN \"\" ON)\noption(WITH_MLIR \"\" OFF)\noption(WITH_MLIR_CUDA_CODEGEN \"\" OFF)\noption(OF_SOFTMAX_USE_FAST_MATH \"\" ON)\noption(OF_LAYER_NORM_USE_FAST_MATH \"\" ON)\noption(TREAT_WARNINGS_AS_ERRORS \"\" ON)\noption(MAYBE_NEED_ERROR_MSG_CHECK \"\" OFF)\n\noption(LITE_USE_ASCEND_NPU \"\" OFF)\n\n# Reference:\n# https://medium.com/@alasher/colored-c-compiler-output-with-ninja-clang-gcc-10bfe7f2b949\noption(OF_FORCE_COLORED_DIAGNOSTICS \"Always produce ANSI-colored diagnostics (GNU/Clang only).\" ON)\n\nset(ONEFLOW_CURRENT_VERSION 0.8.1.dev CACHE STRING \"\")\n\nif(BUILD_FOR_CI)\n  set(ONEFLOW_CURRENT_VERSION ci)\nendif()\n\nset(LLVM_PROVIDER \"in-tree\" CACHE STRING \"in-tree, install\")\n\nif(NOT WITH_MLIR)\n  set(LLVM_PROVIDER \"install\"\n      CACHE STRING \"in-tree will build LLVM's ALL, not what we want when not building MLIR\" FORCE)\nendif(NOT WITH_MLIR)\n\nset(RPC_BACKEND \"GRPC,LOCAL\" CACHE STRING \"\")\nset(THIRD_PARTY_MIRROR \"\" CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"\" CACHE STRING \"\")\nset(CPU_THREADING_RUNTIMES \"TBB;OMP\" CACHE STRING \"\")\n\nif(APPLE)\n  set(RPC_BACKEND \"LOCAL\")\n  set(BUILD_CUDA OFF)\n  set(WITH_COCOAPI OFF)\n  set(WITH_ONEDNN OFF)\nendif()\n\nset(CUDNN_STATIC OFF CACHE BOOL \"\")\n\nproject(oneflow C CXX)\n\nif(NOT CMAKE_BUILD_TYPE)\n  message(STATUS \"No build type selected, default to Release\")\n  set(CMAKE_BUILD_TYPE \"Release\" CACHE STRING \"Build type (default Release)\" FORCE)\nendif()\n\nif(NOT CMAKE_BUILD_TYPE MATCHES \"^(Debug|Release|RelWithDebInfo|MinSizeRel)$\")\n  message(\n    FATAL_ERROR\n      \"Expected CMAKE_BUILD_TYPE is Debug, Release, RelWithDebInfo or MinSizeRel, got ${CMAKE_BUILD_TYPE}\"\n  )\nendif()\n\nmessage(STATUS \"CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}\")\n\nset(COMPILER_VERSION_ERROR_MSG\n    \"At least gcc 9, clang 5 or Apple clang 12 is supported. 
Current version ${CMAKE_CXX_COMPILER_VERSION}.\"\n)\n\nif(\"${CMAKE_CXX_COMPILER_ID}\" STREQUAL \"GNU\")\n  if(\"${CMAKE_CXX_COMPILER_VERSION}\" VERSION_LESS 9)\n    message(FATAL_ERROR ${COMPILER_VERSION_ERROR_MSG})\n  endif()\nelseif(\"${CMAKE_CXX_COMPILER_ID}\" STREQUAL \"Clang\")\n  if(\"${CMAKE_CXX_COMPILER_VERSION}\" VERSION_LESS 5)\n    message(FATAL_ERROR ${COMPILER_VERSION_ERROR_MSG})\n  endif()\nelseif(\"${CMAKE_CXX_COMPILER_ID}\" STREQUAL \"AppleClang\")\n  if(\"${CMAKE_CXX_COMPILER_VERSION}\" VERSION_LESS 12)\n    message(FATAL_ERROR ${COMPILER_VERSION_ERROR_MSG})\n  endif()\nelse()\n  message(WARNING \"Unknown compiler \\\"${CMAKE_CXX_COMPILER_ID}\\\".\")\nendif()\n\nset(oneflow_cmake_dir ${PROJECT_SOURCE_DIR}/cmake)\n\nget_filename_component(real_src_dir \"${CMAKE_SOURCE_DIR}\" REALPATH)\nget_filename_component(real_bin_dir \"${CMAKE_BINARY_DIR}\" REALPATH)\n\nif(\"${real_src_dir}\" STREQUAL \"${real_bin_dir}\")\n  message(FATAL_ERROR \"In-source build not allowed\")\nendif()\n\n# Modules\nlist(APPEND CMAKE_MODULE_PATH ${oneflow_cmake_dir}/third_party)\nlist(APPEND CMAKE_MODULE_PATH ${oneflow_cmake_dir})\n\ninclude(threading)\ninclude(util)\ninclude(proto2cpp)\n\nif(NOT DEFINED USE_CXX11_ABI)\n  check_cxx11_abi(CXX11_ABI_AVAILABLE)\n  set(USE_CXX11_ABI ${CXX11_ABI_AVAILABLE})\nelseif(USE_CXX11_ABI)\n  check_cxx11_abi(CXX11_ABI_AVAILABLE)\n\n  if(NOT CXX11_ABI_AVAILABLE)\n    message(FATAL_ERROR \"cxx11 abi is not available for the current compiler\")\n  endif()\nendif()\n\nmessage(STATUS \"USE_CXX11_ABI: ${USE_CXX11_ABI}\")\n\nif(WITH_MLIR)\n  add_definitions(-DWITH_MLIR)\n\n  if(WITH_MLIR_CUDA_CODEGEN)\n    add_definitions(-DWITH_MLIR_CUDA_CODEGEN)\n  endif()\nendif()\n\nif(WITH_COCOAPI)\n  add_definitions(-DWITH_COCOAPI)\nendif()\n\nif(USE_CXX11_ABI)\n  add_definitions(-D_GLIBCXX_USE_CXX11_ABI=1)\nelse()\n  add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)\nendif()\n\nif(BUILD_PROFILER)\n  add_definitions(-DOF_ENABLE_PROFILER)\nendif()\n\nif(OF_SOFTMAX_USE_FAST_MATH)\n  add_definitions(-DOF_SOFTMAX_USE_FAST_MATH)\nendif()\n\nif(OF_LAYER_NORM_USE_FAST_MATH)\n  add_definitions(-DOF_LAYER_NORM_USE_FAST_MATH)\nendif()\n\nif(OF_FORCE_COLORED_DIAGNOSTICS)\n  add_compile_options(\n    $<$<COMPILE_LANGUAGE:CXX>:$<$<CXX_COMPILER_ID:GNU>:-fdiagnostics-color=always>>\n    $<$<COMPILE_LANGUAGE:CXX>:$<$<CXX_COMPILER_ID:Clang>:-fcolor-diagnostics>>\n    $<$<COMPILE_LANGUAGE:CUDA>:$<$<CUDA_COMPILER_ID:Clang>:-fcolor-diagnostics>>)\nendif()\n\nif(RPC_BACKEND MATCHES \"GRPC\")\n  add_definitions(-DRPC_BACKEND_GRPC)\n  message(STATUS \"RPC backend enabled: gRPC\")\n  set(SUPPORTED_RPC_BACKEND_FOUND 1)\nendif()\n\nif(WITH_ONEDNN)\n  add_definitions(-DWITH_ONEDNN)\nendif()\n\nadd_definitions(-DRPC_BACKEND_LOCAL)\nmessage(STATUS \"RPC backend enabled: local\")\nenable_testing()\nset(CMAKE_CXX_STANDARD 17)\nset(CMAKE_POSITION_INDEPENDENT_CODE ON)\n\nset(THIRD_PARTY_DIR \"${PROJECT_BINARY_DIR}/third_party_install\"\n    CACHE PATH \"Where to install third party headers and libs\")\n\nset(ONEFLOW_PYTHON_DIR \"${PROJECT_SOURCE_DIR}/python\" CACHE PATH \"oneflow python src dir\")\n\ninclude(platform)\n\nif((ENABLE_ASAN OR ENABLE_UBSAN) AND ENABLE_TSAN)\n  message(FATAL_ERROR \"TSAN cannot be enabled together with ASAN or UBSAN; only ASAN and UBSAN can be enabled at the same time.\")\nendif()\nif(ENABLE_ASAN)\n  add_compile_options(-fsanitize=address -fno-omit-frame-pointer)\n  add_link_options(-fsanitize=address -fno-omit-frame-pointer)\nendif()\nif(ENABLE_UBSAN)\n  add_compile_options(-fsanitize=undefined)\n  
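# the same -fsanitize flag is passed at link time below so the matching sanitizer runtime gets linked in\n  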
add_link_options(-fsanitize=undefined)\nendif()\nif(ENABLE_TSAN)\n  add_compile_options(-fsanitize=thread)\n  add_link_options(-fsanitize=thread)\nendif()\n\nif(BUILD_PYTHON)\n  set(ONEFLOW_INCLUDE_DIR \"${ONEFLOW_PYTHON_DIR}/oneflow/include\")\nendif(BUILD_PYTHON)\n\nset(CUTLASS_URL\n    https://github.com/Oneflow-Inc/cutlass/archive/e6f548d80bfdf1167d66adbbbcfc2ee3394f4777.zip)\nuse_mirror(VARIABLE CUTLASS_URL URL ${CUTLASS_URL})\nset(CUTLASS_MD5 425f8cf064ff47c81124e55490135f5c)\n\ninclude(cuda)\nadd_subdirectory(external)\ninclude(third_party)\n\nmessage(STATUS \"CMAKE_CXX_COMPILER_VERSION: \" ${CMAKE_CXX_COMPILER_VERSION})\n\nadd_custom_target(oneflow_deps ALL DEPENDS prepare_oneflow_third_party)\n\n# skip oneflow cmake to avoid errors caused by the absence of python-dev and proto src\nif(ONEFLOW)\n  include(oneflow)\nendif()\n\nadd_subdirectory(ci)\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright 2020 The OneFlow Authors. All rights reserved.\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# OneFlow\n\nOneFlow is a deep learning framework designed to be **user-friendly, scalable and efficient**. With OneFlow, it is easy to:\n\n- program a model with [**PyTorch-like API**](https://oneflow.readthedocs.io/en/master/)\n- scale a model to n-dimensional-parallel execution with the [**Global Tensor**](https://docs.oneflow.org/en/master/cookies/global_tensor.html)\n- accelerate/deploy a model with the [**Graph Compiler**](https://oneflow.readthedocs.io/en/master/graph.html).\n\n[![Simple CI](https://github.com/Oneflow-Inc/oneflow/actions/workflows/simple.yml/badge.svg)](https://github.com/Oneflow-Inc/oneflow/actions/workflows/simple.yml)\n[![Nightly Docker Image](https://github.com/Oneflow-Inc/docker-images/actions/workflows/oneflow-nightly.yml/badge.svg)](https://github.com/Oneflow-Inc/docker-images/actions/workflows/oneflow-nightly.yml)\n[![Nightly Release](https://github.com/Oneflow-Inc/oneflow/actions/workflows/release.yml/badge.svg)](https://github.com/Oneflow-Inc/oneflow/actions/workflows/release.yml)\n[![Documentation](https://readthedocs.org/projects/oneflow/badge/?version=master)](https://oneflow.readthedocs.io/en/master/?badge=master)\n\n## Latest News\n\n- Version 1.0.0 is out!\n  - [Full changelog](https://github.com/Oneflow-Inc/oneflow/releases/tag/v1.0.0)\n\n## Publication\n\n- [OneFlow: Redesign the Distributed Deep Learning Framework from Scratch](https://arxiv.org/abs/2110.15032)\n\n## System Requirements\n\n### General\n- Linux\n- Python 3.7, 3.8, 3.9, 3.10, 3.11\n\n### CUDA\n- CUDA arch 60 or above\n- CUDA Toolkit version 10.0 or above\n- Nvidia driver version 440.33 or above\n\n  OneFlow will work on a minimum supported driver, and any driver beyond. For more information, please refer to [CUDA compatibility documentation](https://docs.nvidia.com/deploy/cuda-compatibility/index.html).\n\n## Install\n\n### Preinstall docker image\n\n```\ndocker pull oneflowinc/oneflow:nightly-cuda11.8\n```\n\n### Pip Install\n\n- (**Highly recommended**) Upgrade pip\n\n  ```\n  python3 -m pip install --upgrade pip #--user\n  ```\n\n- To install latest stable release of OneFlow with CUDA support:\n\n  ```bash\n  python3 -m pip install oneflow\n  ```\n\n- To install nightly release of OneFlow with CPU-only support:\n\n  ```bash\n  python3 -m pip install --pre oneflow -f https://oneflow-staging.oss-cn-beijing.aliyuncs.com/branch/master/cpu\n  ```\n\n- To install nightly release of OneFlow with CUDA support:\n\n  ```bash\n  python3 -m pip install --pre oneflow -f https://oneflow-staging.oss-cn-beijing.aliyuncs.com/branch/master/cu118\n  ```\n\n  If you are in China, you could run this to have pip download packages from domestic mirror of pypi:\n  ```\n  python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple\n  ```\n  For more information on this, please refer to [pypi 镜像使用帮助](https://mirror.tuna.tsinghua.edu.cn/help/pypi/)\n\n### Install from Source\n\n<details>\n<summary>Clone Source Code</summary>\n\n- #### Option 1: Clone source code from GitHub\n\n  ```bash\n  git clone https://github.com/Oneflow-Inc/oneflow.git\n  ```\n\n- #### Option 2: Download from Aliyun(Only available in China)\n\n  ```bash\n  curl https://oneflow-public.oss-cn-beijing.aliyuncs.com/oneflow-src.zip -o oneflow-src.zip\n  unzip oneflow-src.zip\n  ```\n\n  </details>\n\n<details>\n<summary>Build OneFlow</summary>\n\n- Install dependencies\n  ```\n  apt install -y libopenblas-dev nasm g++ gcc python3-pip cmake autoconf libtool\n  ```\n  These dependencies are 
preinstalled in the official conda environment and docker image; you can use the official conda environment [here](https://github.com/Oneflow-Inc/conda-env) or pull the docker image with:\n  ```bash\n  docker pull oneflowinc/manylinux2014_x86_64_cuda11.2\n  ```\n- In the root directory of the OneFlow source code, run:\n\n  ```\n  mkdir build\n  cd build\n  ```\n\n- Configure the project, inside the `build` directory:\n\n  - If you are in China\n\n    configure for CPU-only like this:\n\n    ```\n    cmake .. -C ../cmake/caches/cn/cpu.cmake\n    ```\n\n    configure for CUDA like this:\n\n    ```\n    cmake .. -C ../cmake/caches/cn/cuda.cmake -DCMAKE_CUDA_ARCHITECTURES=80 -DCUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda -DCUDNN_ROOT_DIR=/usr/local/cudnn\n    ```\n\n  - If you are not in China\n\n    configure for CPU-only like this:\n\n    ```\n    cmake .. -C ../cmake/caches/international/cpu.cmake\n    ```\n\n    configure for CUDA like this:\n\n    ```\n    cmake .. -C ../cmake/caches/international/cuda.cmake -DCMAKE_CUDA_ARCHITECTURES=80 -DCUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda -DCUDNN_ROOT_DIR=/usr/local/cudnn\n    ```\n    Here `-DCMAKE_CUDA_ARCHITECTURES` specifies the CUDA architectures to build for, and `-DCUDA_TOOLKIT_ROOT_DIR` and `-DCUDNN_ROOT_DIR` specify the root paths of the CUDA Toolkit and cuDNN.\n\n- Build the project; inside the `build` directory, run:\n\n  ```\n  make -j$(nproc)\n  ```\n\n- Add oneflow to your PYTHONPATH; inside the `build` directory, run:\n\n  ```\n  source source.sh\n  ```\n\n  Please note that this change is not permanent.\n\n- Simple validation\n\n  ```\n  python3 -m oneflow --doctor\n  ```\n\n  </details>\n\n### Troubleshooting\n\nPlease refer to [troubleshooting](docs/source/troubleshooting.md) for common issues you might encounter when compiling and running OneFlow.\n\n## Getting Started\n\n- Please refer to [QUICKSTART](https://docs.oneflow.org/en/master/basics/01_quickstart.html)\n- For the Chinese version, please refer to [快速上手](https://docs.oneflow.org/master/basics/01_quickstart.html)\n\n## Documentation\n\n- [API Reference](https://oneflow.readthedocs.io/en/master/)\n- [Usage & Design Docs](http://docs.oneflow.org/)\n- [System Design](https://docs.oneflow.org/en/v0.4.0/basics_topics/essentials_of_oneflow.html)\n\n## Model Zoo and Benchmark\n\n- [Libai (Toolbox for Parallel Training Large-Scale Transformer Models)](https://github.com/Oneflow-Inc/libai)\n  - [BERT-large](https://libai.readthedocs.io/en/latest/tutorials/get_started/quick_run.html)\n  - [GPT](https://libai.readthedocs.io/en/latest/modules/libai.models.html#id5)\n  - [T5](https://libai.readthedocs.io/en/latest/modules/libai.models.html#id4)\n  - [VisionTransformer](https://libai.readthedocs.io/en/latest/modules/libai.models.html#id1)\n  - [SwinTransformer](https://libai.readthedocs.io/en/latest/modules/libai.models.html#id2)\n- [FlowVision (Toolbox for Computer Vision Datasets, SOTA Models and Utils)](https://github.com/Oneflow-Inc/vision)\n- [OneFlow-Models (Outdated)](https://github.com/Oneflow-Inc/models)\n  - [ResNet-50](https://github.com/Oneflow-Inc/models/tree/main/Vision/classification/image/resnet50)\n  - [Wide&Deep](https://github.com/Oneflow-Inc/models/tree/main/RecommenderSystems/wide_and_deep)\n- [OneFlow-Benchmark (Outdated)](https://github.com/Oneflow-Inc/OneFlow-Benchmark)\n\n## Communication\n\n- [GitHub issues](https://github.com/Oneflow-Inc/oneflow/issues): any install, bug, or feature issues.\n- [www.oneflow.org](http://www.oneflow.org): brand-related information.\n\n- ### Chinese\n\n  - QQ group: 331883\n  - WeChat ID (add as a friend to join the group chat): 
OneFlowXZS\n  - [知乎](https://www.zhihu.com/org/oneflow-17)\n\n- ### International\n  - [Discord](https://discord.gg/4kpjGA5bZY)\n  - [Twitter](https://twitter.com/OneFlowNews)\n  - [LinkedIn](https://www.linkedin.com/company/oneflow-inc)\n  - [Medium](https://oneflow2020.medium.com)\n\n## The Team\n\nOneFlow was originally developed by [OneFlow Inc](http://www.oneflow.org) and [Zhejiang Lab](http://www.zhejianglab.com/).\n\n## License\n\n[Apache License 2.0](LICENSE)\n"
  },
  {
    "path": "ci/CMakeLists.txt",
    "content": "add_subdirectory(test)\n"
  },
  {
    "path": "ci/build/ensure_img.py",
    "content": "import os\nimport argparse\nfrom pathlib import Path\nimport re\nimport json\nimport subprocess\n\n\ndef check_and_download(tag, url):\n    img_dir = os.path.join(os.path.expanduser(\"~\"), \"imgs\")\n    if not os.path.exists(img_dir):\n        os.makedirs(img_dir)\n    returncode = subprocess.run(\n        f\"docker image inspect {tag}\",\n        shell=True,\n        stdout=subprocess.DEVNULL,\n        stderr=subprocess.DEVNULL,\n    ).returncode\n    if returncode == 0:\n        print(\"[OK]\", tag)\n    else:\n        basename = os.path.basename(url)\n        dst = os.path.join(img_dir, basename)\n        subprocess.check_call(f\"wget -c {url} -O {dst}\", shell=True)\n        subprocess.check_call(f\"docker load -i {dst}\", shell=True)\n        base = os.path.basename(dst)\n        base = os.path.splitext(base)[0]\n        base = os.path.splitext(base)[0]\n        keep_tag = f\"ofkeep:{base}\"\n        subprocess.check_call(f\"docker tag {tag} {keep_tag}\", shell=True)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--create_index\", action=\"store_true\", required=False, default=False\n    )\n    args = parser.parse_args()\n    imgs = [\n        {\n            \"tag\": \"nvidia/cuda:10.0-cudnn7-devel-centos7\",\n            \"url\": \"https://oneflow-static.oss-cn-beijing.aliyuncs.com/img/nvidiacuda10.0-cudnn7-devel-centos7.tar.gz\",\n        },\n        {\n            \"tag\": \"nvidia/cuda:10.1-cudnn7-devel-centos7\",\n            \"url\": \"https://oneflow-static.oss-cn-beijing.aliyuncs.com/img/nvidiacuda10.1-cudnn7-devel-centos7.tar.gz\",\n        },\n        {\n            \"tag\": \"nvidia/cuda:10.2-cudnn7-devel-centos7\",\n            \"url\": \"https://oneflow-static.oss-cn-beijing.aliyuncs.com/img/nvidiacuda10.2-cudnn7-devel-centos7.tar.gz\",\n        },\n        {\n            \"tag\": \"nvidia/cuda:11.0-cudnn8-devel-centos7\",\n            \"url\": \"https://oneflow-static.oss-cn-beijing.aliyuncs.com/img/nvidiacuda11.0-cudnn8-devel-centos7.tar.gz\",\n        },\n        {\n            \"tag\": \"nvidia/cuda:11.1-cudnn8-devel-centos7\",\n            \"url\": \"https://oneflow-static.oss-cn-beijing.aliyuncs.com/img/nvidiacuda11.1-cudnn8-devel-centos7.tar.gz\",\n        },\n    ]\n    for img in imgs:\n        check_and_download(img[\"tag\"], img[\"url\"])\n"
  },
  {
    "path": "ci/build/make.sh",
    "content": "set -ex\n\nsrc_dir=${ONEFLOW_SRC_DIR:-\"$PWD\"}\ntmp_dir=${ONEFLOW_CI_TMP_DIR:-\"$HOME/ci-tmp\"}\nextra_oneflow_cmake_args=${ONEFLOW_CI_EXTRA_ONEFLOW_CMAKE_ARGS:-\"\"}\npackage_suffix=${ONEFLOW_CI_PACKAGE_SUFFIX:-\"\"}\ncuda_version=${ONEFLOW_CI_CUDA_VERSION:-\"10.2\"}\npython_version_args=${ONEFLOW_CI_PYTHON_VERSION_ARGS:-\"--python3.6\"}\nbuild_wheel_bash_args=${ONEFLOW_CI_BUILD_WHEEL_BASH_ARGS:-\"-l\"}\nmkdir -p $tmp_dir\ndocker_tag=${ONEFLOW_CI_DOCKER_TAG:-\"oneflow:ci-manylinux2014-cuda10.2\"}\n\ndocker_proxy_build_args=\"\"\ndocker_proxy_build_args+=\"--build-arg http_proxy=${ONEFLOW_CI_HTTP_PROXY} --build-arg https_proxy=${ONEFLOW_CI_HTTPS_PROXY}\"\ndocker_proxy_run_args=\"\"\ndocker_proxy_run_args+=\"--env http_proxy=${ONEFLOW_CI_HTTP_PROXY} --env https_proxy=${ONEFLOW_CI_HTTPS_PROXY}\"\n\ndocker_it=\"\"\nif [[ -t 1 ]]; then\n    docker_it=\"-it\"\nfi\n\n# build manylinux image\ncd $src_dir\ndocker build -f $src_dir/docker/package/manylinux/Dockerfile \\\n    --build-arg from=nvidia/cuda:${cuda_version}-cudnn7-devel-centos7 \\\n    $docker_proxy_build_args -t $docker_tag .\n\ncd -\n\n# build function\nfunction build() {\n    set -x\n    docker run --rm \\\n        -v $tmp_dir:/ci-tmp \\\n        -w $tmp_dir:/ci-tmp busybox rm -rf /ci-tmp/wheelhouse\n    docker run \\\n        $docker_proxy_run_args \\\n        --rm $docker_it \\\n        -v $src_dir:/oneflow-src \\\n        -v $tmp_dir:/ci-tmp \\\n        -w /ci-tmp \\\n        \"$docker_tag\" \\\n        bash ${build_wheel_bash_args} /oneflow-src/docker/package/manylinux/build_wheel.sh \\\n            ${python_version_args} \\\n            --house-dir /ci-tmp/wheelhouse \\\n            --package-name oneflow${package_suffix} \\\n            $extra_oneflow_cmake_args\n}\n\nset +e\n# reuse cache\nbuild\n\n# clean cache and retry\ncached_build_ret=$?\nset -e\nif [ $cached_build_ret -ne 0 ] && [[ ! -t 1 ]]; then\n    echo \"retry after cleaning build dir\"\n    docker run --rm -v $tmp_dir:/ci-tmp busybox sh -c \"rm -rf /ci-tmp/*\"\n    build\nfi\n"
  },
  {
    "path": "ci/check/clang_tidy_warnings_as_errors_on_diff",
    "content": "*,-maybe-glog-fatal,-clang-analyzer-alpha.*,-clang-analyzer-cplusplus.NewDelete,-clang-diagnostic-*"
  },
  {
    "path": "ci/check/lintutils.py",
    "content": "# Licensed to the Apache Software Foundation (ASF) under one\n# or more contributor license agreements.  See the NOTICE file\n# distributed with this work for additional information\n# regarding copyright ownership.  The ASF licenses this file\n# to you under the Apache License, Version 2.0 (the\n# \"License\"); you may not use this file except in compliance\n# with the License.  You may obtain a copy of the License at\n#\n#   http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied.  See the License for the\n# specific language governing permissions and limitations\n# under the License.\n\nimport multiprocessing as mp\nimport os\nfrom fnmatch import fnmatch\nfrom subprocess import Popen\n\n\ndef chunk(seq, n):\n    \"\"\"\n    divide a sequence into equal sized chunks\n    (the last chunk may be smaller, but won't be empty)\n    \"\"\"\n    chunks = []\n    some = []\n    for element in seq:\n        if len(some) == n:\n            chunks.append(some)\n            some = []\n        some.append(element)\n    if len(some) > 0:\n        chunks.append(some)\n    return chunks\n\n\ndef dechunk(chunks):\n    \"flatten chunks into a single list\"\n    seq = []\n    for chunk in chunks:\n        seq.extend(chunk)\n    return seq\n\n\ndef run_parallel(cmds, **kwargs):\n    \"\"\"\n    Run each of cmds (with shared **kwargs) using subprocess.Popen\n    then wait for all of them to complete.\n    Runs batches of multiprocessing.cpu_count() * 2 from cmds\n    returns a list of tuples containing each process'\n    returncode, stdout, stderr\n    \"\"\"\n    complete = []\n    for cmds_batch in chunk(cmds, mp.cpu_count() * 2):\n        procs_batch = [Popen(cmd, **kwargs) for cmd in cmds_batch]\n        for proc in procs_batch:\n            stdout, stderr = proc.communicate()\n            complete.append((proc.returncode, stdout, stderr))\n    return complete\n\n\n_source_extensions = \"\"\"\n.h\n.cc\n.cpp\n.cu\n.cuh\n\"\"\".split()\n\n\ndef get_sources(source_dir, exclude_globs=[]):\n    sources = []\n    for directory, subdirs, basenames in os.walk(source_dir):\n        for path in [os.path.join(directory, basename) for basename in basenames]:\n            # filter out non-source files\n            if os.path.splitext(path)[1] not in _source_extensions:\n                continue\n\n            path = os.path.abspath(path)\n\n            # filter out files that match the globs in the globs file\n            if any([fnmatch(path, glob) for glob in exclude_globs]):\n                continue\n\n            sources.append(path)\n    return sources\n\n\ndef stdout_pathcolonline(completed_process, filenames):\n    \"\"\"\n    given a completed process which may have reported some files as problematic\n    by printing the path name followed by ':' then a line number, examine\n    stdout and return the set of actually reported file names\n    \"\"\"\n    returncode, stdout, stderr = completed_process\n    bfilenames = set()\n    for filename in filenames:\n        bfilenames.add(filename.encode(\"utf-8\") + b\":\")\n    problem_files = set()\n    for line in stdout.splitlines():\n        for filename in bfilenames:\n            if line.startswith(filename):\n                problem_files.add(filename.decode(\"utf-8\"))\n                bfilenames.remove(filename)\n                break\n  
  return problem_files, stdout\n"
  },
  {
    "path": "ci/check/run_clang_format.py",
    "content": "#!/usr/bin/env python3\n# Licensed to the Apache Software Foundation (ASF) under one\n# or more contributor license agreements.  See the NOTICE file\n# distributed with this work for additional information\n# regarding copyright ownership.  The ASF licenses this file\n# to you under the Apache License, Version 2.0 (the\n# \"License\"); you may not use this file except in compliance\n# with the License.  You may obtain a copy of the License at\n#\n#   http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied.  See the License for the\n# specific language governing permissions and limitations\n# under the License.\n\nimport asyncio\nimport argparse\nimport pathlib\nimport multiprocessing\nimport subprocess\nimport os\nimport platform\n\n\ndef split_and_print(prefix, text):\n    lines = text.decode().splitlines(keepends=True)\n    prefixed = \"\"\n    for l in lines:\n        prefixed += f\"{prefix} {l.strip()}\"\n    if l.strip():\n        print(prefixed, flush=True)\n\n\nasync def handle_stream(stream, cb):\n    while True:\n        line = await stream.readline()\n        if line:\n            cb(line)\n        else:\n            break\n\n\nasync def run_command(cmd=None, dry=False, name=None):\n    if dry:\n        print(f\"[dry] {cmd}\")\n        return 0\n    process = await asyncio.create_subprocess_shell(\n        cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,\n    )\n    l = lambda x: split_and_print(f\"[{name}]\" if name else \"\", x)\n    # l = lambda x: x\n    await asyncio.gather(\n        handle_stream(process.stdout, l), handle_stream(process.stderr, l),\n    )\n    await process.wait()\n    return process.returncode\n\n\ndef chunks(lst, n):\n    \"\"\"Yield successive n-sized chunks from lst.\"\"\"\n    for i in range(0, len(lst), n):\n        yield lst[i : i + n]\n\n\ndef check_version(bin):\n    try:\n        out = subprocess.check_output([\"bash\", \"-c\", f\"{bin} --version\"]).decode()\n        print(out)\n        return \"version 11.0.0\" in out\n    except:\n        return False\n\n\ndef download(dry=False):\n    if platform.system() != \"Linux\":\n        raise ValueError(\"Please install clang format 11.0.0\")\n    url = \"https://oneflow-static.oss-cn-beijing.aliyuncs.com/bin/clang-format/linux-x86/clang-format-11\"\n    if os.getenv(\"CI\"):\n        url = \"https://github.com/Oneflow-Inc/oneflow-fmt/raw/master/clang-format/linux-x86/clang-format-11\"\n    dst_dir = \".cache/bin\"\n    dst = f\"{dst_dir}/clang-format-11\"\n    if dry:\n        if os.path.isfile(dst):\n            return dst\n        else:\n            None\n    else:\n        assert subprocess.call(f\"mkdir -p {dst_dir}\", shell=True) == 0\n        assert subprocess.call(f\"curl -L {url} -o {dst}\", shell=True) == 0\n        assert subprocess.call(f\"chmod +x {dst}\", shell=True) == 0\n        return dst\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Runs clang-format on all of the source \"\n        \"files. 
If --fix is specified enforce format by \"\n        \"modifying in place, otherwise compare the output \"\n        \"with the existing file and output any necessary \"\n        \"changes as a patch in unified diff format\"\n    )\n    parser.add_argument(\n        \"--clang_format_binary\",\n        required=False,\n        help=\"Path to the clang-format binary.\",\n        default=\"clang-format\",\n    )\n    parser.add_argument(\n        \"--source_dir\", required=True, help=\"Root directory of the source code\"\n    )\n    parser.add_argument(\n        \"--fix\",\n        default=False,\n        action=\"store_true\",\n        help=\"If specified, will re-format the source \"\n        \"code instead of comparing the re-formatted \"\n        \"output, defaults to %(default)s\",\n    )\n    parser.add_argument(\n        \"--quiet\",\n        default=False,\n        action=\"store_true\",\n        help=\"If specified, only print errors\",\n    )\n    args = parser.parse_args()\n    exts = [\".h\", \".cc\", \".cpp\", \".cu\", \".cuh\"]\n    files = filter(\n        lambda p: p.suffix in exts, pathlib.Path(args.source_dir).rglob(\"*\"),\n    )\n    loop = asyncio.get_event_loop()\n    files = [str(f) for f in files]\n    clang_fmt_args = \"-dry-run --Werror\"\n    if args.fix:\n        clang_fmt_args = \"-i\"\n    results = []\n    if check_version(args.clang_format_binary) == False:\n        downloaded = download(dry=True)\n        if downloaded:\n            assert check_version(downloaded)\n            args.clang_format_binary = downloaded\n        else:\n            args.clang_format_binary = download()\n            assert check_version(args.clang_format_binary)\n    for chunk in chunks(files, multiprocessing.cpu_count() * 2):\n        promises = [\n            run_command(f\"{args.clang_format_binary} {clang_fmt_args} {f}\")\n            for f in chunk\n        ]\n        chunk_results = loop.run_until_complete(asyncio.gather(*promises))\n        results.extend(chunk_results)\n    print(len(results), \"files checked\")\n    assert len(results) == len(files)\n    for (r, f) in zip(results, files):\n        if r != 0:\n            print(\"[fail]\", f)\n    assert sum(results) == 0\n"
  },
  {
    "path": "ci/check/run_clang_tidy.py",
    "content": "#!/usr/bin/env python3\n# Licensed to the Apache Software Foundation (ASF) under one\n# or more contributor license agreements.  See the NOTICE file\n# distributed with this work for additional information\n# regarding copyright ownership.  The ASF licenses this file\n# to you under the Apache License, Version 2.0 (the\n# \"License\"); you may not use this file except in compliance\n# with the License.  You may obtain a copy of the License at\n#\n#   http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied.  See the License for the\n# specific language governing permissions and limitations\n# under the License.\n\nimport asyncio\nimport argparse\nimport subprocess\nimport os\nfrom typing import List, Optional\nfrom pathlib import Path\n\n\ndef split_and_print(prefix, text):\n    lines = text.decode().splitlines(keepends=True)\n    prefixed = \"\"\n    for l in lines:\n        prefixed += f\"{prefix} {l.strip()}\"\n    if l.strip():\n        print(prefixed, flush=True)\n\n\nasync def handle_stream(stream, cb):\n    while True:\n        line = await stream.readline()\n        if line:\n            cb(line)\n        else:\n            break\n\n\nasync def run_command(cmd=None, dry=False, name=None):\n    if dry:\n        print(f\"[dry] {cmd}\")\n        return 0\n    process = await asyncio.create_subprocess_shell(\n        cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,\n    )\n    l = lambda x: split_and_print(f\"[{name}]\" if name else \"\", x)\n    await asyncio.gather(\n        handle_stream(process.stdout, l), handle_stream(process.stderr, l),\n    )\n    await process.wait()\n    return process.returncode\n\n\ndef download(build_dir, dry=False) -> Optional[List[str]]:\n    urls = [\n        \"https://github.com/Oneflow-Inc/llvm-project/releases/download/update-err-msg-checker/clang-tidy-15.AppImage\"\n        if os.getenv(\"CI\")\n        else \"https://oneflow-static.oss-cn-beijing.aliyuncs.com/bin/clang-tidy/linux-x86_64/clang-tidy-15.AppImage\",\n        \"https://raw.githubusercontent.com/oneflow-inc/llvm-project/maybe/clang-tools-extra/clang-tidy/tool/clang-tidy-diff.py\",\n    ]\n    dst_dir = f\"{build_dir}/cache/bin\"\n    dst = [f\"{dst_dir}/clang-tidy\", f\"{dst_dir}/clang-tidy-diff.py\"]\n    if dry:\n        if os.path.isfile(dst[0]) and os.path.isfile(dst[1]):\n            return dst\n        else:\n            None\n    else:\n        assert subprocess.call(f\"mkdir -p {dst_dir}\", shell=True) == 0\n        for i, _dst in enumerate(dst):\n            assert subprocess.call(f\"curl -L {urls[i]} -o {_dst}\", shell=True) == 0\n            assert subprocess.call(f\"chmod +x {_dst}\", shell=True) == 0\n        return dst\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Runs clang-tidy on all of the source files.\"\n    )\n    parser.add_argument(\n        \"--build_dir\", required=True,\n    )\n    parser.add_argument(\n        \"--check-error-msg\", action=\"store_true\", default=False,\n    )\n    args = parser.parse_args()\n    loop = asyncio.get_event_loop()\n    downloaded = download(args.build_dir, dry=True)\n    if downloaded is None:\n        downloaded = download(args.build_dir)\n    assert downloaded is not None\n    warnings_as_errors = (\n        (Path(__file__).parent / 
\"clang_tidy_warnings_as_errors_on_diff\")\n        .read_text()\n        .strip()\n    )\n    cmd = f\"git diff -U0 master | {downloaded[1]} -clang-tidy-binary {downloaded[0]} -path {args.build_dir} -j $(nproc) -p1 -allow-enabling-alpha-checkers -extra-arg=-Xclang -extra-arg=-analyzer-config -extra-arg=-Xclang -extra-arg=aggressive-binary-operation-simplification=true\"\n    if args.check_error_msg:\n        command = f\" cd .. && {cmd} -warnings-as-errors='{warnings_as_errors}' && {cmd} -checks=-*,maybe-need-error-msg -warnings-as-errors=* -skip-line-filter\"\n    else:\n        command = f\"cd .. && {cmd} -warnings-as-errors='{warnings_as_errors}'\"\n\n    ret_code = loop.run_until_complete(run_command(command))\n    exit(ret_code)\n"
  },
  {
    "path": "ci/check/run_cmake_format.py",
    "content": "from subprocess import call\nfrom argparse import ArgumentParser\nfrom glob import glob\nfrom pathlib import Path\nfrom multiprocessing.pool import ThreadPool\nfrom multiprocessing import cpu_count\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser(\n        description=\"Runs cmake-format on all of the cmake source files.\"\n    )\n\n    parser.add_argument(\n        \"--bin\", default=\"cmake-format\", help=\"Path of cmake-format binary\"\n    )\n    parser.add_argument(\n        \"--fix\", default=False, action=\"store_true\", help=\"Format all sources in place\"\n    )\n    parser.add_argument(\n        \"--source_dir\", default=\".\", help=\"Root directory of the source code\"\n    )\n    parser.add_argument(\n        \"-j\",\n        \"--jobs\",\n        type=int,\n        default=cpu_count(),\n        help=\"Specifies the number of jobs (commands) to run simultaneously\",\n    )\n\n    args = parser.parse_args()\n\n    patterns = [\n        \"cmake/**/*.cmake\",\n        \"oneflow/**/*.cmake\",\n        \"oneflow/**/CMakeLists.txt\",\n        \"tools/**/*.cmake\",\n        \"tools/**/CMakeLists.txt\",\n        \"CMakeLists.txt\",\n    ]\n\n    files = []\n    for pattern in patterns:\n        files.extend(glob(str(Path(args.source_dir) / pattern), recursive=True))\n\n    def gen_cmd(file):\n        cmd = [args.bin, file]\n        cmd.append(\"-i\" if args.fix else \"--check\")\n        return cmd\n\n    tp = ThreadPool(args.jobs)\n    res = tp.map_async(call, [gen_cmd(file) for file in files])\n\n    tp.close()\n    tp.join()\n\n    count = sum(map(lambda x: 0 if x == 0 else 1, res.get()))\n    total = len(files)\n    if args.fix:\n        print(f\"cmake-format -i done. {total} total\")\n    else:\n        print(f\"cmake-format --check done. {count} failed / {total} total\")\n\n    exit(0 if count == 0 else 1)\n"
  },
  {
    "path": "ci/check/run_license_format.py",
    "content": "import argparse\nimport os\nimport glob\nfrom multiprocessing import Pool\n\nLICENSE_TXT = \"\"\"Copyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nCPP_TXT = \"/*\\n{}*/\\n\".format(LICENSE_TXT)\nPY_TXT = '\"\"\"\\n{}\"\"\"\\n'.format(LICENSE_TXT)\n\n\ndef get_txt(path: str):\n    if path.endswith((\".cpp\", \".h\", \".hpp\", \".cu\", \".cuh\")):\n        return CPP_TXT\n    elif path.endswith((\".py\")):\n        return PY_TXT\n    else:\n        return None\n\n\ndef check_file(path):\n    with open(path, \"r\", encoding=\"utf-8\") as f:\n        content = f.read()\n        txt = get_txt(path)\n        if (\n            \"import doctest\" in content\n            and \"raise_on_error=True\" not in content\n            and \"doctest.DebugRunner\" not in content\n        ):\n            return (\"please add 'doctest.testmod(raise_on_error=True)'\", content)\n        elif content.count(\"The OneFlow Authors. All rights reserved.\") > 1:\n            return (\"license_duplicated\", content)\n        elif content.startswith(txt) or (not content):\n            return (\"ok\", content)\n        elif content.startswith(txt) == False:\n            return (\"license_absent\", content)\n\n\ndef format_file(path):\n    txt = get_txt(path)\n    with open(path, \"r\", encoding=\"utf-8\") as r:\n        content = r.read()\n    format_status, content = check_file(path)\n    if format_status == \"ok\":\n        return True\n    elif format_status == \"license_absent\":\n        with open(path, \"w\") as w:\n            new_content = txt + content\n            w.write(new_content)\n        return False\n    else:\n        raise ValueError(f\"{format_status} {path}\")\n\n\ndef do_check(x):\n    format_status, _ = check_file(x)\n    return (x, format_status)\n\n\ndef do_format(x):\n    return (x, format_file(x))\n\n\ndef glob_files(path: str = None, excludes=None):\n    files = []\n    for ext in (\"**/*.cpp\", \"**/*.h\", \"**/*.hpp\", \"**/*.cu\", \"**/*.cuh\", \"**/*.py\"):\n        joined = os.path.join(path, ext)\n        files.extend(glob.glob(joined, recursive=True))\n    files = [\n        f\n        for f in files\n        if \"version.py\" not in f and all([not e in f for e in excludes])\n    ]\n    print(\"[files]\", len(files))\n    return files\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-i\", \"--root_path\", type=str, required=True)\n    parser.add_argument(\n        \"-v\", \"--verbose\", default=False, action=\"store_true\", required=False\n    )\n    parser.add_argument(\"--silent\", default=False, action=\"store_true\", required=False)\n    parser.add_argument(\n        \"-c\", \"--check\", default=False, action=\"store_true\", required=False\n    )\n    parser.add_argument(\n        \"-f\", \"--fix\", default=False, action=\"store_true\", required=False\n    )\n    parser.add_argument(\"--exclude\", action=\"append\", default=[])\n    args = parser.parse_args()\n    
files = glob_files(args.root_path, excludes=args.exclude)\n    assert args.check != args.fix\n    with Pool(10) as p:\n        if args.check:\n            any_absence = False\n            for (p, format_status) in p.map(do_check, files):\n                if format_status != \"ok\":\n                    print(f\"{format_status}:\", p)\n                    any_absence = True\n            if any_absence:\n                exit(1)\n        if args.fix:\n            for (p, format_result) in p.map(do_format, files):\n                if format_result == True:\n                    if args.verbose:\n                        print(\"license already added:\", p)\n                else:\n                    if args.silent == False:\n                        print(\"license just added:\", p)\n"
  },
  {
    "path": "ci/check/run_py_format.py",
    "content": "import argparse\nimport sys\nimport platform\nfrom subprocess import Popen\nimport os\n\nif __name__ == \"__main__\":\n\n    major = platform.sys.version_info.major\n    minor = platform.sys.version_info.minor\n    if major == 3 and minor < 6:\n        print(\"WARNING: python >= 3.6 required, python source format won't run\")\n        exit(0)\n    parser = argparse.ArgumentParser(\n        description=\"Runs py-format on all of the source files.\"\n        \"If --fix is specified enforce format by modifying in place.\"\n    )\n    parser.add_argument(\n        \"--source_dir\", required=True, help=\"Root directory of the source code\"\n    )\n    parser.add_argument(\n        \"--fix\",\n        default=False,\n        action=\"store_true\",\n        help=\"If specified, will re-format the source\",\n    )\n\n    arguments = parser.parse_args()\n    os.chdir(arguments.source_dir)\n\n    version_cmd = sys.executable + \" -m {} --version | grep {} > /dev/null\"\n    BLACK_VER = \"19.10b0\"\n    if os.system(version_cmd.format(\"black\", BLACK_VER)):\n        print(\n            f\"Please install black {BLACK_VER}. For instance, run 'python3 -m pip install black=={BLACK_VER} --user'\"\n        )\n        sys.exit(1)\n\n    cmd_line = sys.executable + \" -m black \" + \".\"\n    if arguments.fix == False:\n        cmd_line += \" --check\"\n    if os.system(cmd_line):\n        sys.exit(1)\n"
  },
  {
    "path": "ci/clang/build-llvm.sh",
    "content": "set -ex\nexport PATH=/usr/lib/llvm-15/bin:/usr/lib64/ccache:/root/.local/bin:$PATH\n\n# clean python dir\ncd ${ONEFLOW_CI_SRC_DIR}\n${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user -r ci/fixed-dev-requirements.txt\ncd python\ngit config --global --add safe.directory ${ONEFLOW_CI_SRC_DIR}\ngit clean -nXd -e \\!dist -e \\!dist/**\ngit clean -fXd -e \\!dist -e \\!dist/**\n\n# cmake config\nmkdir -p ${ONEFLOW_CI_BUILD_DIR}\ncd ${ONEFLOW_CI_BUILD_DIR}\nfind ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt\nfind ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt -delete\nif [ ! -f \"$ONEFLOW_CI_CMAKE_INIT_CACHE\" ]; then\n    echo \"$ONEFLOW_CI_CMAKE_INIT_CACHE does not exist.\"\n    exit 1\nfi\ncmake -S ${ONEFLOW_CI_SRC_DIR} -C ${ONEFLOW_CI_CMAKE_INIT_CACHE} -DPython3_EXECUTABLE=${ONEFLOW_CI_PYTHON_EXE}\n# cmake build\ncd ${ONEFLOW_CI_BUILD_DIR}\ncmake --build . -j $(nproc)\n\n# build pip\ncd ${ONEFLOW_CI_SRC_DIR}\ncd python\n${ONEFLOW_CI_PYTHON_EXE} setup.py bdist_wheel\n"
  },
  {
    "path": "ci/conda/build-clang.sh",
    "content": "set -ex\nconda activate oneflow-dev-clang10-v2\nmkdir -p build\ncd build\ncmake .. -C ../cmake/caches/cn/fast/cpu-clang.cmake\ncmake --build . -j $(nproc)\ncd -\ncd python\npython setup.py bdist_wheel\necho \"wheelhouse_dir=$PWD/dist\" >> $GITHUB_ENV\n"
  },
  {
    "path": "ci/conda/tuna.condarc",
    "content": "channels:\n  - defaults\nshow_channel_urls: true\ndefault_channels:\n  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main\n  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r\n  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2\ncustom_channels:\n  conda-forge: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n  msys2: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n  bioconda: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n  menpo: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n  pytorch: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n  simpleitk: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n"
  },
  {
    "path": "ci/fixed-dev-requirements.txt",
    "content": "numpy==1.26.4 ; python_version >= \"3.12\"\nnumpy==1.22.1 ; python_version >= \"3.10\" and python_version < \"3.12\"\nnumpy==1.21.6 ; python_version >= \"3.7\" and python_version < \"3.10\"\n"
  },
  {
    "path": "ci/manylinux/build-gcc7-xla.sh",
    "content": "source scl_source enable devtoolset-7\nset -ex\nONEFLOW_CI_BUILD_PARALLEL=${ONEFLOW_CI_BUILD_PARALLEL:-$(nproc)}\ngcc --version\nld --version\n# clean python dir\ncd ${ONEFLOW_CI_SRC_DIR}\n${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user -r ci/fixed-dev-requirements.txt\ncd python\ngit clean -nXd -e \\!dist -e \\!dist/**\ngit clean -fXd -e \\!dist -e \\!dist/**\n# cmake config\nmkdir -p ${ONEFLOW_CI_BUILD_DIR}\ncd ${ONEFLOW_CI_BUILD_DIR}\nfind ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt\nfind ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt -delete\nif [ ! -f \"$ONEFLOW_CI_CMAKE_INIT_CACHE\" ]; then\n    echo \"$ONEFLOW_CI_CMAKE_INIT_CACHE does not exist.\"\n    exit 1\nfi\nexport PATH=\"${PATH}:$(dirname ${ONEFLOW_CI_PYTHON_EXE})\"\nexport PYTHON_BIN_PATH=${ONEFLOW_CI_PYTHON_EXE}\ncmake -S ${ONEFLOW_CI_SRC_DIR} -C ${ONEFLOW_CI_CMAKE_INIT_CACHE} -DPython3_EXECUTABLE=${ONEFLOW_CI_PYTHON_EXE}\n\n# cmake build\ncd ${ONEFLOW_CI_BUILD_DIR}\ncmake --build . --parallel ${ONEFLOW_CI_BUILD_PARALLEL}\n\n# build pip\ncd ${ONEFLOW_CI_SRC_DIR}\ncd python\n${ONEFLOW_CI_PYTHON_EXE} setup.py bdist_wheel\n"
  },
  {
    "path": "ci/manylinux/build-gcc9.sh",
    "content": "source scl_source enable devtoolset-9\nset -ex\nONEFLOW_CI_BUILD_PARALLEL=${ONEFLOW_CI_BUILD_PARALLEL:-$(nproc)}\ngcc --version\nld --version\n# clean python dir\ncd ${ONEFLOW_CI_SRC_DIR}\n${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user -r ci/fixed-dev-requirements.txt\n${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user auditwheel setuptools wheel\ncd python\n\nfunction clean_artifacts {\n    git config --global --add safe.directory ${ONEFLOW_CI_SRC_DIR}\n    git clean -nXd -e \\!dist -e \\!dist/**\n    git clean -fXd -e \\!dist -e \\!dist/**\n}\n\nclean_artifacts\n\n# cmake config\nmkdir -p ${ONEFLOW_CI_BUILD_DIR}\ncd ${ONEFLOW_CI_BUILD_DIR}\nfind ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt\nfind ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt -delete\nif [ ! -f \"$ONEFLOW_CI_CMAKE_INIT_CACHE\" ]; then\n    echo \"$ONEFLOW_CI_CMAKE_INIT_CACHE does not exist.\"\n    exit 1\nfi\nexport PATH=\"${PATH}:$(dirname ${ONEFLOW_CI_PYTHON_EXE})\"\nexport PYTHON_BIN_PATH=${ONEFLOW_CI_PYTHON_EXE}\ncmake -S ${ONEFLOW_CI_SRC_DIR} -C ${ONEFLOW_CI_CMAKE_INIT_CACHE} -DPython3_EXECUTABLE=${ONEFLOW_CI_PYTHON_EXE}\n\n# cmake build\ncd ${ONEFLOW_CI_BUILD_DIR}\ncmake --build . --parallel ${ONEFLOW_CI_BUILD_PARALLEL}\nif [ ! -z \"$ONEFLOW_CI_BUILD_RUN_LIT\" ]; then\n    ${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user flowvision==0.1.0\n    export PATH=$PATH:$(dirname $ONEFLOW_CI_PYTHON_EXE)\n    cmake --build . -t c1\nfi\n\n# build pip\ncd ${ONEFLOW_CI_SRC_DIR}\ncd python\n${ONEFLOW_CI_PYTHON_EXE} setup.py bdist_wheel\n"
  },
  {
    "path": "ci/manylinux/build.sh",
    "content": "set -ex\nONEFLOW_CI_BUILD_PARALLEL=${ONEFLOW_CI_BUILD_PARALLEL:-$(nproc)}\ngcc --version\nld --version\n# clean python dir\ncd ${ONEFLOW_CI_SRC_DIR}\n${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user -r ci/fixed-dev-requirements.txt\n${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user auditwheel setuptools wheel\ncd python\n\nfunction clean_artifacts {\n    git config --global --add safe.directory ${ONEFLOW_CI_SRC_DIR}\n    git clean -nXd -e \\!dist -e \\!dist/**\n    git clean -fXd -e \\!dist -e \\!dist/**\n}\n\nclean_artifacts\n\n# cmake config\nmkdir -p ${ONEFLOW_CI_BUILD_DIR}\ncd ${ONEFLOW_CI_BUILD_DIR}\nfind ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt\nfind ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt -delete\nif [ ! -f \"$ONEFLOW_CI_CMAKE_INIT_CACHE\" ]; then\n    echo \"$ONEFLOW_CI_CMAKE_INIT_CACHE does not exist.\"\n    exit 1\nfi\ncmake -S ${ONEFLOW_CI_SRC_DIR} -C ${ONEFLOW_CI_CMAKE_INIT_CACHE} -DPython3_EXECUTABLE=${ONEFLOW_CI_PYTHON_EXE}\n# cmake build\ncd ${ONEFLOW_CI_BUILD_DIR}\ncmake --build . --parallel ${ONEFLOW_CI_BUILD_PARALLEL}\nif [ ! -z \"$ONEFLOW_CI_BUILD_RUN_LIT\" ]; then\n    ${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user flowvision==0.1.0\n    export PATH=$PATH:$(dirname $ONEFLOW_CI_PYTHON_EXE)\n    cmake --build . -t c1\nfi\n\n# build pip\ncd ${ONEFLOW_CI_SRC_DIR}\ncd python\n${ONEFLOW_CI_PYTHON_EXE} setup.py bdist_wheel\n"
  },
  {
    "path": "ci/requirements.txt",
    "content": "pycocotools\nopencv-python==4.3.0.38; sys_platform == 'darwin'\nopencv-python==4.2.0.34; sys_platform != 'darwin'\nscipy\npillow\ntensorflow-addons==0.13.0\ntensorflow==2.5.0\n"
  },
  {
    "path": "ci/reset_submodule.sh",
    "content": "set -x\nset -e\ngit reset --hard\ngit submodule deinit -f .\nrm -rf .git/modules/*\n"
  },
  {
    "path": "ci/setup_submodule.py",
    "content": "import configparser\nimport argparse\nimport os\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"-s\", \"--oneflow_src_local_path\", type=str, required=False)\nparser.add_argument(\"-r\", \"--oneflow_src_remote_url\", type=str, required=False)\nargs = parser.parse_args()\n\nassert (\n    args.oneflow_src_local_path or args.oneflow_src_remote_url\n), \"require one of oneflow_src_local_path or oneflow_src_remote_url\"\nconfig = configparser.ConfigParser()\nconfig.read(\".gitmodules\")\nfor s in config.sections():\n    path = config[s][\"path\"]\n    if args.oneflow_src_local_path:\n        src_path = os.path.join(args.oneflow_src_local_path, path)\n        assert os.path.exists(\"{}/.git\".format(src_path)), src_path\n        config[s][\"url\"] = \"file://{}\".format(src_path)\n    else:\n        src_path = os.path.join(args.oneflow_src_remote_url, path)\n        config[s][\"url\"] = src_path\n\nwith open(\".gitmodules\", \"w\") as configfile:\n    config.write(configfile)\n"
  },
  {
    "path": "ci/setup_submodule.sh",
    "content": "set -x\nset -e\nsrc_dir=${ONEFLOW_CI_SRC_DIR:-\"$HOME/oneflow\"}\npython3 ci/setup_submodule.py --oneflow_src_local_path=$src_dir\ngit submodule sync\ngit submodule update --init --recursive\n"
  },
  {
    "path": "ci/test/1node_benchmark_test.sh",
    "content": "set -xe\n\nrm -rf /benchmarks\ncp -r python/oneflow/compatible/single_client/benchmarks /benchmarks\ncd /benchmarks\n\npython3 cnn_benchmark/of_cnn_benchmarks.py \\\n    --gpu_num_per_node=1 \\\n    --model=\"vgg16\" \\\n    --batch_size_per_device=8 \\\n    --iter_num=5 \\\n    --learning_rate=0.01 \\\n    --optimizer=\"sgd\" \\\n    --loss_print_every_n_iter=1 \\\n    --data_dir=\"/dataset/imagenet_227/train/32\"\n\npython3 cnn_benchmark/of_cnn_benchmarks.py \\\n    --gpu_num_per_node=1 \\\n    --model=\"alexnet\" \\\n    --batch_size_per_device=8 \\\n    --iter_num=5 \\\n    --learning_rate=0.01 \\\n    --optimizer=\"sgd\" \\\n    --loss_print_every_n_iter=1 \\\n    --data_dir=\"/dataset/imagenet_227/train/32\"\n\npython3 cnn_benchmark/of_cnn_benchmarks.py \\\n    --gpu_num_per_node=1 \\\n    --model=\"resnet50\" \\\n    --batch_size_per_device=8 \\\n    --iter_num=5 \\\n    --gpu_image_decoder=True \\\n    --learning_rate=0.01 \\\n    --optimizer=\"sgd\" \\\n    --loss_print_every_n_iter=1 \\\n    --data_dir=\"/dataset/imagenet_227/train/32\"\n\npython3 cnn_benchmark/of_cnn_benchmarks.py \\\n    --gpu_num_per_node=1 \\\n    --model=\"resnet50\" \\\n    --batch_size_per_device=8 \\\n    --iter_num=5 \\\n    --learning_rate=0.01 \\\n    --optimizer=\"sgd\" \\\n    --loss_print_every_n_iter=1\n\npython3 bert_benchmark/run_pretraining.py \\\n    --gpu_num_per_node=1 \\\n    --node_num=1 \\\n    --learning_rate=1e-4 \\\n    --weight_decay_rate=0.01 \\\n    --batch_size_per_device=24 \\\n    --iter_num=5 \\\n    --loss_print_every_n_iter=1 \\\n    --data_dir=\"/dataset/bert/bert_seq_len_128_repeat1024\" \\\n    --data_part_num=1 \\\n    --seq_length=128 \\\n    --max_predictions_per_seq=20 \\\n    --num_hidden_layers=12 \\\n    --num_attention_heads=12 \\\n    --max_position_embeddings=512 \\\n    --type_vocab_size=2 \\\n    --vocab_size=30522 \\\n    --attention_probs_dropout_prob=0.1 \\\n    --hidden_dropout_prob=0.1 \\\n    --hidden_size_per_head=64\n"
  },
  {
    "path": "ci/test/1node_benchmark_test_fp16.sh",
    "content": "set -ex\n\nrm -rf /benchmarks\ncp -r python/oneflow/compatible/single_client/benchmarks /benchmarks\ncd /benchmarks\n\npython3 cnn_benchmark/of_cnn_benchmarks.py \\\n    --gpu_num_per_node=1 \\\n    --model=\"vgg16\" \\\n    --batch_size_per_device=8 \\\n    --iter_num=5 \\\n    --learning_rate=0.01 \\\n    --optimizer=\"sgd\" \\\n    --loss_print_every_n_iter=1 \\\n    --data_dir=\"/dataset/imagenet_227/train/32\" \\\n    --enable_auto_mixed_precision=True\n\npython3 cnn_benchmark/of_cnn_benchmarks.py \\\n    --gpu_num_per_node=1 \\\n    --model=\"alexnet\" \\\n    --batch_size_per_device=8 \\\n    --iter_num=5 \\\n    --learning_rate=0.01 \\\n    --optimizer=\"sgd\" \\\n    --loss_print_every_n_iter=1 \\\n    --data_dir=\"/dataset/imagenet_227/train/32\" \\\n    --enable_auto_mixed_precision=True\n\npython3 cnn_benchmark/of_cnn_benchmarks.py \\\n    --gpu_num_per_node=1 \\\n    --model=\"resnet50\" \\\n    --batch_size_per_device=8 \\\n    --iter_num=5 \\\n    --learning_rate=0.01 \\\n    --optimizer=\"sgd\" \\\n    --loss_print_every_n_iter=1 \\\n    --data_dir=\"/dataset/imagenet_227/train/32\" \\\n    --enable_auto_mixed_precision=True\n\npython3 bert_benchmark/run_pretraining.py \\\n    --gpu_num_per_node=1 \\\n    --node_num=1 \\\n    --learning_rate=1e-4 \\\n    --weight_decay_rate=0.01 \\\n    --batch_size_per_device=24 \\\n    --iter_num=5 \\\n    --loss_print_every_n_iter=1 \\\n    --data_dir=\"/dataset/bert/bert_seq_len_128_repeat1024\" \\\n    --data_part_num=1 \\\n    --seq_length=128 \\\n    --max_predictions_per_seq=20 \\\n    --num_hidden_layers=12 \\\n    --num_attention_heads=12 \\\n    --max_position_embeddings=512 \\\n    --type_vocab_size=2 \\\n    --vocab_size=30522 \\\n    --attention_probs_dropout_prob=0.1 \\\n    --hidden_dropout_prob=0.1 \\\n    --hidden_size_per_head=64 \\\n    --enable_auto_mixed_precision=True\n"
  },
  {
    "path": "ci/test/1node_custom_op_test.sh",
    "content": "\n#!/bin/bash\nset -xe\n\nsrc_dir=${ONEFLOW_SRC_DIR:-\"$PWD\"}\ntest_tmp_dir=${ONEFLOW_TEST_TMP_DIR:-\"./test_tmp_dir\"}\n\nrm -rf $test_tmp_dir\nmkdir -p $test_tmp_dir\ncp -r $src_dir/python/oneflow/compatible/single_client/test/custom_ops $test_tmp_dir\ncd $test_tmp_dir\n\nexport ONEFLOW_TEST_DEVICE_NUM=1\npython3 -m unittest discover ./custom_ops --failfast --verbose\n"
  },
  {
    "path": "ci/test/1node_model_eager_test.sh",
    "content": "#!/bin/bash\nset -xe\n\ncp -r python/oneflow/test /test_dir\ncd /test_dir\n\npython3 models/eager_1node_test.py\n"
  },
  {
    "path": "ci/test/1node_model_test.sh",
    "content": "#!/bin/bash\nset -xe\n\ncp -r python/oneflow/compatible/single_client/test /test_dir\ncd /test_dir\n\npython3 models/1node_test.py\n"
  },
  {
    "path": "ci/test/1node_op_test.sh",
    "content": "#!/bin/bash\nset -xe\n\nexport TF_CPP_MIN_LOG_LEVEL=3\nexport PYTHONUNBUFFERED=1\n\nsrc_dir=${ONEFLOW_SRC_DIR:-\"$PWD\"}\ntest_tmp_dir=${ONEFLOW_TEST_TMP_DIR:-\"./test_tmp_dir\"}\n\n\nrm -rf $test_tmp_dir\nmkdir -p $test_tmp_dir\ncp -r $src_dir/python/oneflow/compatible/single_client/test $test_tmp_dir\ncd $test_tmp_dir\n\npython3 -m oneflow --doctor\n\ngpu_num=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)\nfor CHUNK in 1\ndo\n\texport ONEFLOW_TEST_DEVICE_NUM=${CHUNK}\n    python3 $src_dir/ci/test/parallel_run.py \\\n        --gpu_num=\"${gpu_num}\" \\\n        --dir=test/ops \\\n        --timeout=1 \\\n        --verbose \\\n        --chunk=${CHUNK}\ndone\n\nif [ -z \"$ONEFLOW_TEST_ENABLE_EAGER\" ]\nthen\n    export ONEFLOW_TEST_DEVICE_NUM=2\n    python3 -m unittest discover test/ops --failfast --verbose\n\n    export ONEFLOW_TEST_DEVICE_NUM=4\n    python3 -m unittest discover test/ops --failfast --verbose\nelse\n    echo \"deadlock unsolved, skipping multi-card eager\"\nfi\n"
  },
  {
    "path": "ci/test/2node_op_test.sh",
    "content": "#!/bin/bash\nset -xe\n\nexport PYTHONUNBUFFERED=1\n\nsrc_dir=${ONEFLOW_SRC_DIR:-\"$PWD\"}\ntest_tmp_dir=${ONEFLOW_TEST_TMP_DIR:-\"./test_tmp_dir\"}\n\n\nrm -rf $test_tmp_dir\nmkdir -p $test_tmp_dir\nchmod -R o+w $test_tmp_dir\ncp -r $src_dir/python/oneflow/compatible/single_client/test $test_tmp_dir\ncd $test_tmp_dir\n\nONEFLOW_TEST_DEVICE_NUM=1 python3 test/ops/test_assign.py --failfast --verbose\nONEFLOW_TEST_DEVICE_NUM=1 python3 test/ops/test_two_node_boxing.py --failfast --verbose\n\nfor device_num in 1 2 4\ndo\n    ONEFLOW_TEST_ENABLE_INIT_BY_HOST_LIST=1 ONEFLOW_TEST_DEVICE_NUM=$device_num python3 -m unittest discover test/ops --failfast --verbose\n    # use a invalid ibverbs lib to test if falling back to epoll works\n    ONEFLOW_TEST_ENABLE_INIT_BY_HOST_LIST=1 ONEFLOW_TEST_DEVICE_NUM=$device_num ONEFLOW_LIBIBVERBS_PATH=invalid_lib python3 -m unittest discover test/ops --failfast --verbose\ndone\n"
  },
  {
    "path": "ci/test/2node_op_test_multi_client.sh",
    "content": "#!/bin/bash\n\nset -xeu\n\nexport PYTHONUNBUFFERED=1\n\nsrc_dir=${ONEFLOW_SRC_DIR:-\"$PWD\"}\nONEFLOW_CI_DEVICE_NUMS=${ONEFLOW_CI_DEVICE_NUMS:-\"1 2 4\"}\n\nfor device_num in ${ONEFLOW_CI_DEVICE_NUMS}\ndo\n    export ONEFLOW_TEST_NODE_NUM=2\n    export ONEFLOW_TEST_DEVICE_NUM=$device_num\n    time python3 ${src_dir}/ci/test/multi_launch.py \\\n        --files \"${ONEFLOW_TEST_DIR}/**/test_*.py\" \\\n        -n 4 \\\n        --group_size $device_num \\\n        --device_num 4 \\\n        --verbose \\\n        --auto_cuda_visible_devices \\\n        -m oneflow.distributed.launch \\\n        --nproc_per_node $device_num --nnodes=2 --node_rank=$NODE_RANK --master_addr $_MASTER_ADDR \\\n        -m pytest --max-worker-restart=0 -x --durations=50 --capture=sys -p no:cacheprovider -p no:randomly --ignore=log\ndone\n"
  },
  {
    "path": "ci/test/CMakeLists.txt",
    "content": "set(PYTHON_EXECUTABLE python3 CACHE STRING \"python3 exe to run test, usually is the python3 installation oneflow is linked to\")\nset(ONEFLOW_SRC_DIR ${CMAKE_SOURCE_DIR} CACHE STRING \"source dir of oneflow\")\nset(IS_DEV ON CACHE BOOL \"\")\nset(CTEST_RESOURCE_SPEC_FILE \"${CMAKE_CURRENT_SOURCE_DIR}/resource-spec/2x-rtx-2080.json\" CACHE STRING \"\")\n\n# CTEST_OUTPUT_ON_FAILURE=1 CTEST_PARALLEL_LEVEL=20 ninja test\n\nfile(GLOB_RECURSE PYTHON_TEST_FILES LIST_DIRECTORIES false RELATIVE ${ONEFLOW_SRC_DIR} \"${ONEFLOW_SRC_DIR}/python/oneflow/test_*.py\")\nforeach(PYTHON_TEST_FILE ${PYTHON_TEST_FILES})\n  set(TEST_NAME ${PYTHON_TEST_FILE})\n  add_test(NAME ${TEST_NAME}\n    COMMAND ${PYTHON_EXECUTABLE} ${ONEFLOW_SRC_DIR}/${PYTHON_TEST_FILE} --failfast --verbose\n  )\n  set_tests_properties(${TEST_NAME}\n    PROPERTIES\n      ENVIRONMENT \"$<$<NOT:$<BOOL:${BUILD_CUDA}>>:ONEFLOW_TEST_CPU_ONLY=1>;$<$<BOOL:${IS_DEV}>:PYTHONPATH=${ONEFLOW_SRC_DIR}/python:$ENV{PYTHONPATH}>\"\n      RESOURCE_GROUPS\n        \"vram:2000\"\n  )\nendforeach()\n"
  },
  {
    "path": "ci/test/build_docs.sh",
    "content": "set -ex\nsrc_dir=${ONEFLOW_SRC_DIR:-\"$PWD\"}\ntest_tmp_dir=${ONEFLOW_TEST_TMP_DIR:-\"$PWD/build-docs\"}\nrm -rf $test_tmp_dir\ncp -r docs ${test_tmp_dir}\ncd ${test_tmp_dir}\n\nmake html SPHINXOPTS=\"-W --keep-going\"\n"
  },
  {
    "path": "ci/test/distributed_run.py",
    "content": "from multiprocessing.connection import Listener\nimport os\nimport subprocess\nimport socket\nimport tempfile\nfrom contextlib import closing\nimport argparse\nimport uuid\nimport getpass\nimport atexit\nimport pathlib\nimport asyncio\nimport glob\nfrom datetime import date\nfrom pathlib import Path\n\nHARD_CODED_AFFILIATIONS = {\n    \"192.168.1.11\": [\"192.168.1.12\",],\n    \"192.168.1.12\": [\"192.168.1.11\",],\n    \"192.168.1.13\": [\"192.168.1.11\",],\n    \"192.168.1.15\": [\"192.168.1.16\",],\n    \"192.168.1.16\": [\"192.168.1.15\",],\n}\n\n\ndef is_img_existing(tag):\n    returncode = subprocess.run(\n        \"docker image inspect {}\".format(tag),\n        shell=True,\n        stdout=subprocess.DEVNULL,\n        stderr=subprocess.DEVNULL,\n    ).returncode\n    if returncode == 0:\n        print(\"[OK]\", tag)\n        return True\n    else:\n        return False\n\n\ndef get_affiliations(host):\n    # TODO(tsai): Implement a HTTP endpoint to retrieve affiliations\n    if host in HARD_CODED_AFFILIATIONS:\n        return HARD_CODED_AFFILIATIONS[host]\n    else:\n        return None\n\n\ndef resolve_hostname_hardcoded(host: str):\n    if host.startswith(\"oneflow\"):\n        number = host.split(\"-\")[-1]\n        return f\"192.168.1.{number}\"\n    else:\n        return host\n\n\ndef find_free_port():\n    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:\n        s.bind((\"localhost\", 0))\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        return s.getsockname()[1]\n\n\nasync def spawn_shell(cmd: str = None):\n    p = await asyncio.create_subprocess_shell(cmd,)\n    await p.wait()\n    assert p.returncode == 0, cmd\n\n\nasync def spawn_shell_ignoring_failure(cmd: str = None):\n    p = await asyncio.create_subprocess_shell(cmd,)\n    await p.wait()\n\n\nasync def build_docker_img(remote_host=None, workspace_dir=None):\n    if remote_host:\n        assert workspace_dir\n        await spawn_shell(\"rm -f > oneflow-src.zip\")\n        await spawn_shell(\"git archive --format zip HEAD > oneflow-src.zip\")\n        await spawn_shell(\n            f\"scp oneflow-src.zip {remote_host}:{workspace_dir}/oneflow-src.zip\",\n        )\n        await spawn_shell(\n            f\"ssh  {remote_host} unzip {workspace_dir}/oneflow-src.zip -d {workspace_dir}/oneflow-src\",\n        )\n        await spawn_shell(\n            f\"ssh  {remote_host} bash {workspace_dir}/oneflow-src/docker/ci/test/build.sh\",\n        )\n    else:\n        await spawn_shell(f\"bash docker/ci/test/build.sh\")\n\n\nasync def create_remote_workspace_dir(\n    remote_host=None, workspace_dir=None, copy_files=None\n):\n    await spawn_shell(f\"ssh {remote_host} mkdir -p {workspace_dir}\")\n    if copy_files is not None:\n        for path in copy_files:\n            # Reference: https://stackoverflow.com/a/31278462\n            if os.path.isdir(path) and path[-1] != \"/\":\n                path += \"/\"\n            await spawn_shell(f\"ssh {remote_host} mkdir -p {workspace_dir}/{path}\")\n            await spawn_shell(\n                f\"rsync -azPq --omit-dir-times --no-perms --no-group --copy-links --exclude='__pycache__' {path} {remote_host}:{workspace_dir}/{path}\"\n            )\n    print(\"create_remote_workspace_dir done\")\n\n\ndef get_docker_cache_args():\n    return \" \".join(\n        [\n            f\"-v {Path.home() / 'test-container-cache/dot-local'}:/root/.local\",\n            f\"-v {Path.home() / 
'test-container-cache/dot-cache'}:/root/.cache\",\n        ]\n    )\n\n\nasync def launch_remote_container(\n    remote_host=None,\n    survival_time=None,\n    workspace_dir=None,\n    container_name=None,\n    img_tag=None,\n    oneflow_wheel_path=None,\n    oneflow_python_path=None,\n    cmd=None,\n    node_rank=None,\n    master_addr=None,\n):\n    print(\"launching remote container at\", remote_host)\n    assert img_tag\n    multi_client_args = [node_rank, master_addr]\n    multi_client_arg_has_value = [x is not None for x in multi_client_args]\n    assert all(multi_client_arg_has_value)\n    pythonpath_args = None\n    if oneflow_wheel_path:\n        pythonpath_args = \"\"\n    elif oneflow_python_path:\n        pythonpath_args = f\"--env PYTHONPATH={workspace_dir}/python\"\n    else:\n        raise ValueError(\"must have oneflow_wheel_path or oneflow_python_path\")\n    docker_cmd = f\"\"\"docker run --privileged -d --network host --shm-size=8g --rm {get_docker_cache_args()} -v {workspace_dir}:{workspace_dir} -w {workspace_dir} -v /dataset:/dataset -v /model_zoo:/model_zoo --name {container_name} {pythonpath_args} {img_tag} sleep {survival_time}\n\"\"\"\n    await spawn_shell(f\"ssh {remote_host} {docker_cmd}\")\n    if oneflow_wheel_path:\n        whl_basename = os.path.basename(oneflow_wheel_path)\n        await spawn_shell(\n            f\"ssh {remote_host} docker exec {container_name} python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple\"\n        )\n        await spawn_shell(\n            f\"ssh {remote_host} docker exec {container_name} python3 -m pip install {workspace_dir}/{whl_basename}\"\n        )\n    await spawn_shell(\n        f\"ssh {remote_host} docker exec {container_name} python3 -m oneflow --doctor\"\n    )\n    if cmd:\n        multi_client_docker_args = (\n            # Use _MASTER_ADDR to avoid name conflict with OneFlow's built-in MASTER_ADDR\n            f\"--env NODE_RANK={node_rank} --env _MASTER_ADDR={master_addr}\"\n        )\n        await spawn_shell(\n            f\"ssh {remote_host} docker exec {multi_client_docker_args} {container_name} {cmd}\"\n        )\n\n\ndef handle_cast(conn=None, cmd=None):\n    received_cmd: str = conn.recv().decode()\n    assert received_cmd.startswith(\"cast/\")\n    received_cmd = received_cmd.replace(\"cast/\", \"\")\n    assert received_cmd == cmd, (received_cmd, cmd)\n    return conn.recv().decode()\n\n\ndef handle_call(conn=None, cmd=None, response=None):\n    received_cmd: str = conn.recv().decode()\n    assert received_cmd.startswith(\"call/\")\n    received_cmd = received_cmd.replace(\"call/\", \"\")\n    assert received_cmd == cmd, (received_cmd, cmd)\n    msg = conn.recv().decode()\n    conn.send(response.encode())\n    return msg\n\n\nclass DockerAgent:\n    def __init__(\n        self,\n        port=None,\n        authkey=None,\n        this_host=None,\n        remote_hosts=None,\n        container_name=None,\n        timeout=None,\n        workspace_dir=None,\n        img_tag=None,\n        oneflow_wheel_path=None,\n        oneflow_python_path=None,\n        oneflow_test_tmp_dir=None,\n        extra_docker_args: str = None,\n    ) -> None:\n        # info\n        self.this_host = this_host\n        self.remote_hosts = remote_hosts\n        self.container_name = container_name\n        self.timeout = timeout\n        self.common_docker_args = \"--privileged --rm --network host --shm-size=8g -v $HOME:$HOME -v /dataset:/dataset -v /model_zoo:/model_zoo\"\n        self.workspace_dir 
= workspace_dir\n        self.img_tag = img_tag\n        self.oneflow_wheel_path = oneflow_wheel_path\n        self.oneflow_python_path = oneflow_python_path\n        self.oneflow_test_tmp_dir = oneflow_test_tmp_dir\n        # impl\n        self.env_proto_txt = None\n        self.bash_tmp_file = None\n        self.bash_proc = None\n        self.remote_docker_proc = {}\n        self.agent_port = port\n        self.agent_authkey = authkey\n        self.extra_docker_args = extra_docker_args\n\n    def __enter__(self):\n        return self\n\n    def run_bash_script_async(self, bash_script=None, cmd=None):\n        remote_hosts_str = \",\".join(self.remote_hosts)\n        ctrl_port = find_free_port()\n        data_port = find_free_port()\n        exports = f\"\"\"\nexport ONEFLOW_TEST_MASTER_PORT={ctrl_port}\nexport ONEFLOW_TEST_DATA_PORT={data_port}\nexport ONEFLOW_TEST_NODE_LIST=\"{self.this_host},{remote_hosts_str}\"\nexport ONEFLOW_WORKER_KEEP_LOG=1\nexport ONEFLOW_TEST_TMP_DIR=\"{self.oneflow_test_tmp_dir}\"\nexport NCCL_DEBUG=INFO\nexport ONEFLOW_TEST_WORKER_AGENT_PORT={agent_port}\nexport ONEFLOW_TEST_WORKER_AGENT_AUTHKEY={agent_authkey}\npython3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple\n\"\"\"\n        if self.oneflow_wheel_path:\n            exports += f\"python3 -m pip install {self.oneflow_wheel_path}\"\n        if self.oneflow_python_path:\n            exports += f\"export PYTHONPATH={self.oneflow_python_path}:$PYTHONPATH\\n\"\n        bash_cmd = None\n        if bash_script:\n            assert os.path.exists(bash_script)\n            bash_cmd = f\"\"\"set -ex\n{exports}\nbash {bash_script}\n\"\"\"\n        elif cmd:\n            bash_cmd = f\"\"\"set -ex\n{exports}\n{cmd}\n\"\"\"\n        else:\n            raise ValueError(\"not impl\")\n        assert bash_cmd\n\n        def get_docker_cmd(f, cmd):\n            f_name = f.name\n            f.write(cmd)\n            f.flush()\n            return f\"docker run {self.common_docker_args} {self.extra_docker_args} {get_docker_cache_args()} -v /tmp:/host/tmp:ro -v $PWD:$PWD -w $PWD --name {self.container_name} {self.img_tag} bash /host{f_name}\"\n\n        f = tempfile.NamedTemporaryFile(mode=\"w+\", encoding=\"utf-8\", delete=True)\n        run_docker_cmd = get_docker_cmd(f, bash_cmd)\n        self.bash_tmp_file = f\n        self.bash_proc = subprocess.Popen(run_docker_cmd, shell=True)\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        pass\n\n\nasync def fix_and_sync_libs(oneflow_internal_path=None, remote_hosts=None):\n    tmp_dir = tempfile.TemporaryDirectory()\n    tmp_lib_dir = os.path.join(tmp_dir.name, \"libs\")\n    os.mkdir(tmp_lib_dir)\n    await spawn_shell(\n        \"\"\"ldd file | grep \"=> /\" | awk '{print $3}' | xargs -I '{}' cp -v '{}' destination\"\"\".replace(\n            \"file\", oneflow_internal_path\n        ).replace(\n            \"destination\", tmp_lib_dir\n        ),\n    )\n    libs = os.listdir(tmp_lib_dir)\n    assert len(libs) > 0\n    excludelist_path = os.path.join(\n        pathlib.Path(__file__).parent.absolute(), \"excludelist\"\n    )\n    excludelist = open(excludelist_path).read().split(\"\\n\")\n    await spawn_shell(f\"cp {oneflow_internal_path} {tmp_dir.name}\")\n\n    def handle_lib(lib):\n        if lib in excludelist or \"libpython\" in lib:\n            print(\"excluding\", lib)\n            return spawn_shell(f\"rm {tmp_lib_dir}/{lib}\")\n        else:\n            print(\"keeping\", lib)\n            return spawn_shell(f\"patchelf 
--set-rpath '$ORIGIN' {tmp_lib_dir}/{lib}\")\n\n    await asyncio.gather(*(handle_lib(lib) for lib in libs))\n\n    tmp_oneflow_internal_path = os.path.join(\n        tmp_dir.name, pathlib.Path(oneflow_internal_path).name\n    )\n    print(\"before fixing .so\")\n    await spawn_shell(f\"ldd {tmp_oneflow_internal_path}\")\n    print(\"fixing .so\")\n    await spawn_shell(\n        f\"patchelf --set-rpath '$ORIGIN/libs' {tmp_oneflow_internal_path}\"\n    )\n\n    await asyncio.gather(\n        *[\n            spawn_shell(\n                f\"ssh {remote_host} 'mkdir -p {workspace_dir}/python/oneflow/libs'\",\n            )\n            for remote_host in remote_hosts\n        ]\n    )\n\n    async def copy_file(path=None, remote_host=None):\n        relpath = os.path.relpath(path, tmp_dir.name)\n        await spawn_shell(\n            f\"scp {path} {remote_host}:{workspace_dir}/python/oneflow/{relpath}\",\n        )\n\n    files = [\n        os.path.join(root, name)\n        for root, dirs, files in os.walk(tmp_dir.name, topdown=True)\n        for name in files\n    ]\n\n    await asyncio.gather(\n        *[\n            copy_file(path=f, remote_host=remote_host)\n            for remote_host in remote_hosts\n            for f in files\n        ],\n        spawn_shell(f\"ldd {tmp_oneflow_internal_path}\"),\n    )\n\n\nasync def remove_containers_by_name(remote_hosts=None, container_name=None):\n    rm_cmd = f\"docker rm -f {container_name}\"\n    assert container_name\n    assert remote_hosts\n    await asyncio.gather(\n        *[\n            spawn_shell_ignoring_failure(f\"ssh {remote_host} {rm_cmd}\")\n            for remote_host in remote_hosts\n        ],\n        spawn_shell_ignoring_failure(rm_cmd),\n    )\n\n\ndef get_remote_hosts(args):\n    remote_hosts = None\n    if len(args.remote_host) == 1:\n        remote_hosts = args.remote_host.split(\",\")\n    elif len(args.remote_host) == 0:\n        affiliations = get_affiliations(this_host)\n        assert (\n            affiliations\n        ), f\"no affiliated node found for {this_host}, you should specify one\"\n        remote_host = affiliations[0]\n        remote_host = socket.gethostbyname(remote_host)\n        remote_hosts = [remote_host]\n    else:\n        remote_hosts = args.remote_host\n    return remote_hosts\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--debug\", action=\"store_true\", required=False, default=False)\n    parser.add_argument(\n        \"--skip_libs\", action=\"store_true\", required=False, default=False\n    )\n    parser.add_argument(\"--bash_script\", type=str, required=False)\n    default_this_host = socket.gethostname()\n    parser.add_argument(\n        \"--this_host\", type=str, required=False, default=default_this_host\n    )\n    parser.add_argument(\"--remote_host\", action=\"append\", default=[])\n    parser.add_argument(\"--oneflow_wheel_path\", type=str, required=False, default=None)\n    parser.add_argument(\n        \"--oneflow_wheel_python_version\", type=str, required=False, default=None\n    )\n    parser.add_argument(\"--oneflow_python_path\", type=str, required=False, default=None)\n    parser.add_argument(\"--custom_img_tag\", type=str, required=False, default=None)\n    parser.add_argument(\"--cmd\", type=str, required=False, default=None)\n    parser.add_argument(\n        \"--oneflow_test_tmp_dir\", type=str, required=False, default=\"distributed-tmp\"\n    )\n    parser.add_argument(\"--timeout\", type=int, required=False, 
default=1 * 60 * 60)\n    parser.add_argument(\"--mode\", type=str, required=False, default=\"multi_client\")\n    parser.add_argument(\"--copy_files\", action=\"append\", default=[])\n    args = parser.parse_args()\n\n    assert args.mode in [\"multi_client\"]\n    assert bool(args.oneflow_wheel_path) != bool(args.oneflow_python_path)\n    assert bool(args.bash_script) != bool(args.cmd)\n    if args.skip_libs:\n        assert args.debug, \"--skip_libs only works with --debug\"\n        assert (\n            args.oneflow_python_path\n        ), \"--skip_libs only works with --oneflow_python_path\"\n\n    oneflow_wheel_path = args.oneflow_wheel_path\n    main_node_extra_docker_args = []\n    if oneflow_wheel_path and os.path.isdir(oneflow_wheel_path):\n        assert os.path.isabs(oneflow_wheel_path)\n        main_node_extra_docker_args.append(\n            f\"-v {oneflow_wheel_path}:{oneflow_wheel_path}:ro\"\n        )\n        whl_paths = [\n            name for name in glob.glob(os.path.join(oneflow_wheel_path, f\"*.whl\",))\n        ]\n        if len(whl_paths) == 1:\n            oneflow_wheel_path = whl_paths[0]\n        else:\n            assert args.oneflow_wheel_python_version\n            assert args.oneflow_wheel_python_version in [\n                \"3.6\",\n                \"3.7\",\n                \"3.8\",\n                \"3.9\",\n                \"3.10\",\n                \"3.11\",\n            ]\n            ver_cat = args.oneflow_wheel_python_version.replace(\".\", \"\")\n            found = False\n            for whl_path in whl_paths:\n                if f\"cp{ver_cat}\" in whl_path:\n                    oneflow_wheel_path = whl_path\n                    found = True\n            assert found, whl_paths\n\n    this_host = args.this_host\n    this_host = resolve_hostname_hardcoded(this_host)\n\n    remote_hosts = get_remote_hosts(args)\n\n    print(f\"this_host: {this_host}, remote_hosts: {remote_hosts}\", flush=True)\n    sub_dir = str(uuid.uuid4())\n    if args.debug:\n        sub_dir = \"debug\"\n    workspace_dir = os.path.join(\n        os.path.expanduser(\"~\"), \"distributed_run_workspace\", sub_dir\n    )\n    print(\"workspace_dir\", workspace_dir)\n    container_name = (\n        getpass.getuser()\n        + \"-distributed-run-main-node-at-\"\n        + this_host.replace(\".\", \"-\")\n    )\n    if args.mode == \"multi_client\":\n        remote_hosts = [this_host] + remote_hosts\n    loop = asyncio.get_event_loop()\n    # add host key to all machines (needed by ssh/scp/rsync)\n    loop.run_until_complete(\n        asyncio.gather(\n            *[\n                spawn_shell(f\"ssh -o StrictHostKeyChecking=no {remote_host} true\")\n                for remote_host in remote_hosts\n            ],\n        ),\n    )\n    loop.run_until_complete(\n        asyncio.gather(\n            *[\n                create_remote_workspace_dir(\n                    remote_host=remote_host,\n                    workspace_dir=workspace_dir,\n                    copy_files=args.copy_files,\n                )\n                for remote_host in remote_hosts\n            ],\n            remove_containers_by_name(\n                remote_hosts=remote_hosts, container_name=container_name\n            ),\n        ),\n    )\n    if args.oneflow_python_path:\n        so_paths = [\n            name\n            for name in glob.glob(\n                os.path.join(\n                    args.oneflow_python_path, f\"oneflow/_oneflow_internal.*.so\",\n                )\n            )\n        
]\n        assert len(so_paths) == 1, so_paths\n        oneflow_internal_path = so_paths[0]\n        oneflow_internal_path = os.path.join(\n            args.oneflow_python_path, oneflow_internal_path\n        )\n        tmp_dir = None\n        print(\"copying oneflow python dir\")\n        loop.run_until_complete(\n            asyncio.gather(\n                *[\n                    spawn_shell(\n                        f\"rsync -azPq --omit-dir-times --no-perms --no-group --copy-links --include='*.py' --exclude='*.so' --exclude='__pycache__' --exclude='oneflow/include' --include='*/' --exclude='*' {args.oneflow_python_path} {remote_host}:{workspace_dir}\"\n                    )\n                    for remote_host in remote_hosts\n                ]\n            )\n        )\n        if args.skip_libs == False:\n            print(\"copying .so\")\n            loop.run_until_complete(\n                fix_and_sync_libs(\n                    oneflow_internal_path=oneflow_internal_path,\n                    remote_hosts=remote_hosts,\n                )\n            )\n    elif oneflow_wheel_path:\n        loop.run_until_complete(\n            asyncio.gather(\n                *[\n                    spawn_shell(\n                        f\"rsync -azPq --omit-dir-times --no-perms --no-group {oneflow_wheel_path} {remote_host}:{workspace_dir}\"\n                    )\n                    for remote_host in remote_hosts\n                ]\n            )\n        )\n    default_docker_image = \"oneflow-test:$USER\"\n    ci_user_docker_image = \"oneflow-test:0.2\"\n    img_tag = None\n    if args.custom_img_tag == None:\n        if is_img_existing(default_docker_image):\n            img_tag = default_docker_image\n        elif is_img_existing(ci_user_docker_image):\n            img_tag = ci_user_docker_image\n        else:\n            loop.run_until_complete(\n                asyncio.gather(\n                    *[\n                        build_docker_img(\n                            remote_host=remote_host, workspace_dir=workspace_dir\n                        )\n                        for remote_host in remote_hosts\n                    ],\n                    build_docker_img(workspace_dir=workspace_dir),\n                )\n            )\n            img_tag = default_docker_image\n    else:\n        img_tag = args.custom_img_tag\n    assert img_tag\n    agent_port = find_free_port()\n    agent_authkey = str(uuid.uuid4())\n\n    def exit_handler():\n        print(\n            \"---------start cleanup, you should ignore errors below and check the errors above---------\"\n        )\n        if args.oneflow_python_path:\n            print(\"fixing permission of\", args.oneflow_python_path)\n            subprocess.call(\n                f\"docker run --rm -v {args.oneflow_python_path}:/p -w /p busybox chmod -R o+w .\",\n                shell=True,\n            )\n        loop.run_until_complete(\n            asyncio.gather(\n                *[\n                    spawn_shell_ignoring_failure(\n                        f\"ssh {remote_host} docker run --rm -v {workspace_dir}:/p -w /p busybox chmod -R 777 .\",\n                    )\n                    for remote_host in remote_hosts\n                ],\n            )\n        )\n        print(\"copying artifacts\")\n        extra_exclude_args = \"\"\n        for path in args.copy_files:\n            extra_exclude_args += f\"--exclude='{path}' \"\n        loop.run_until_complete(\n            asyncio.gather(\n                *[\n                   
 spawn_shell_ignoring_failure(\n                        f\"rsync -azPq --omit-dir-times --no-perms --no-group --exclude='*.whl' --exclude='python' {extra_exclude_args} {remote_host}:{workspace_dir}/ {args.oneflow_test_tmp_dir}/{remote_host}\"\n                    )\n                    for remote_host in remote_hosts\n                ]\n            )\n        )\n        assert workspace_dir\n        if args.debug == False:\n            print(\"removing docker workspace_dir:\", workspace_dir)\n            loop.run_until_complete(\n                asyncio.gather(\n                    *[\n                        spawn_shell_ignoring_failure(\n                            f\"ssh {remote_host} rm -rf {workspace_dir}\",\n                        )\n                        for remote_host in remote_hosts\n                    ],\n                )\n            )\n        print(\"removing docker container:\", container_name)\n        loop.run_until_complete(\n            remove_containers_by_name(\n                remote_hosts=remote_hosts, container_name=container_name\n            )\n        )\n\n    atexit.register(exit_handler)\n    if args.mode == \"multi_client\":\n        if args.bash_script:\n            args.cmd = f\"bash {args.bash_script}\"\n        loop.run_until_complete(\n            asyncio.gather(\n                *[\n                    launch_remote_container(\n                        remote_host=remote_host,\n                        survival_time=args.timeout,\n                        workspace_dir=workspace_dir,\n                        container_name=container_name,\n                        oneflow_wheel_path=oneflow_wheel_path,\n                        oneflow_python_path=args.oneflow_python_path,\n                        img_tag=img_tag,\n                        cmd=args.cmd,\n                        node_rank=node_rank,\n                        master_addr=this_host,\n                    )\n                    for node_rank, remote_host in enumerate(remote_hosts)\n                ],\n            )\n        )\n    else:\n        loop.run_until_complete(\n            asyncio.gather(\n                *[\n                    launch_remote_container(\n                        remote_host=remote_host,\n                        survival_time=args.timeout,\n                        workspace_dir=workspace_dir,\n                        container_name=container_name,\n                        oneflow_wheel_path=oneflow_wheel_path,\n                        oneflow_python_path=args.oneflow_python_path,\n                        img_tag=img_tag,\n                    )\n                    for remote_host in remote_hosts\n                ],\n            )\n        )\n"
  },
  {
    "path": "ci/test/doctest.sh",
    "content": "#!/bin/bash\nset -xe\nexport PYTHONUNBUFFERED=1\nsrc_dir=${ONEFLOW_SRC_DIR:-\"$PWD\"}\ntest_tmp_dir=${ONEFLOW_TEST_TMP_DIR:-\"./test_tmp_dir\"}\n\nmkdir -p ${test_tmp_dir}\ncd ${test_tmp_dir}\npython3 -c 'import oneflow; f=open(\"oneflow_path.txt\", \"w\"); f.write(oneflow.__path__[0])'\n\ngpu_num=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)\npython3 $src_dir/ci/test/parallel_run.py \\\n    --gpu_num=${gpu_num} \\\n    --dir=$(cat oneflow_path.txt) \\\n    --timeout=1 \\\n    --verbose \\\n    --chunk=1 \\\n    --doctest\n"
  },
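  {
    "path": "ci/test/examples/doctest_demo.py",
    "content": "\"\"\"\nHypothetical sketch, NOT an original repo file: shows the shape of a module\nthat the --doctest mode driven by ci/test/doctest.sh would collect.\nparallel_run.py only gathers files whose source contains the literal string\n'import doctest' and runs each one as `python3 <file> -v`, so the module must\ncall doctest.testmod() itself; testmod() honors the -v flag via sys.argv.\n\"\"\"\n\n\ndef add(x, y):\n    \"\"\"Add two numbers.\n\n    >>> add(1, 2)\n    3\n    >>> add(-1, 1)\n    0\n    \"\"\"\n    return x + y\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod()\n"
  },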
  {
    "path": "ci/test/excludelist",
    "content": "# This file lists libraries that we will assume to be present on the host system and hence\n# should NOT be bundled inside AppImages. This is a working document; expect it to change\n# over time. File format: one filename per line. Each entry should have a justification comment.\n\n# See the useful tool at https://abi-laboratory.pro/index.php?view=navigator&symbol=hb_buffer_set_cluster_level#result\n# to investigate issues with missing symbols.\n\nld-linux.so.2\nld-linux-x86-64.so.2\nlibanl.so.1\nlibBrokenLocale.so.1\nlibcidn.so.1\n# libcrypt.so.1 # Not part of glibc anymore as of Fedora 30. See https://github.com/slic3r/Slic3r/issues/4798 and https://pagure.io/fedora-docs/release-notes/c/01d74b33564faa42959c035e1eee286940e9170e?branch=f28\nlibc.so.6\nlibdl.so.2\nlibm.so.6\nlibmvec.so.1\n# libnsl.so.1 # Not part of glibc anymore as of Fedora 28. See https://github.com/RPCS3/rpcs3/issues/5224#issuecomment-434930594\nlibnss_compat.so.2\n# libnss_db.so.2 # Not part of neon-useredition-20190321-0530-amd64.iso\nlibnss_dns.so.2\nlibnss_files.so.2\nlibnss_hesiod.so.2\nlibnss_nisplus.so.2\nlibnss_nis.so.2\nlibpthread.so.0\nlibresolv.so.2\nlibrt.so.1\nlibthread_db.so.1\nlibutil.so.1\n# These files are all part of the GNU C Library which should never be bundled.\n# List was generated from a fresh build of glibc 2.25.\n\nlibstdc++.so.6\n# Workaround for:\n# usr/lib/libstdc++.so.6: version `GLIBCXX_3.4.21' not found\n\nlibGL.so.1\n# The above may be missing on Chrome OS, https://www.reddit.com/r/Crostini/comments/d1lp67/ultimaker_cura_no_longer_running_as_an_appimage/\nlibEGL.so.1\n# Part of the video driver (OpenGL); present on any regular\n# desktop system, may also be provided by proprietary drivers.\n# Known to cause issues if it's bundled.\n\nlibGLdispatch.so.0\nlibGLX.so.0\n# reported to be superfluent and conflicting system libraries (graphics driver)\n# see https://github.com/linuxdeploy/linuxdeploy/issues/89\n\nlibOpenGL.so.0\n# Qt installed via install-qt.sh apparently links to this library\n# part of OpenGL like libGL/libEGL, so excluding it should not cause any problems\n# https://github.com/linuxdeploy/linuxdeploy/issues/152\n\nlibdrm.so.2\n# Workaround for:\n# Antergos Linux release 2015.11 (ISO-Rolling)\n# /usr/lib/libdrm_amdgpu.so.1: error: symbol lookup error: undefined symbol: drmGetNodeTypeFromFd (fatal)\n# libGL error: unable to load driver: swrast_dri.so\n# libGL error: failed to load driver: swrast\n# Unrecognized OpenGL version\n\nlibglapi.so.0\n# Part of mesa\n# known to cause problems with graphics, see https://github.com/RPCS3/rpcs3/issues/4427#issuecomment-381674910\n\nlibgbm.so.1\n# Part of mesa\n# https://github.com/probonopd/linuxdeployqt/issues/390#issuecomment-529036305\n\nlibxcb.so.1\n# Workaround for:\n# Fedora 23\n# symbol lookup error: /lib64/libxcb-dri3.so.0: undefined symbol: xcb_send_fd\n# Uncertain if this is required to be bundled for some distributions - if so we need to write a version check script and use LD_PRELOAD to load the system version if it is newer\n# Fedora 25:\n# undefined symbol: xcb_send_request_with_fds\n# https://github.com/AppImage/AppImages/issues/128\n\nlibX11.so.6\n# Workaround for:\n# Fedora 23\n# symbol lookup error: ./lib/libX11.so.6: undefined symbol: xcb_wait_for_reply64\n# Uncertain if this is required to be bundled for some distributions - if so we need to write a version check script and use LD_PRELOAD to load the system version if it is newer\n\nlibgio-2.0.so.0\n# Workaround for:\n# On Ubuntu, \"symbol lookup 
error: /usr/lib/x86_64-linux-gnu/gtk-2.0/modules/liboverlay-scrollbar.so: undefined symbol: g_settings_new\"\n\n# libgdk-x11-2.0.so.0 # Missing on openSUSE-Tumbleweed-KDE-Live-x86_64-Snapshot20170601-Media.iso\n# libgtk-x11-2.0.so.0 # Missing on openSUSE-Tumbleweed-KDE-Live-x86_64-Snapshot20170601-Media.iso\n\nlibasound.so.2\n# Workaround for:\n# No sound, e.g., in VLC.AppImage (does not find sound cards)\n\n# https://github.com/AppImage/pkg2appimage/issues/475\n# libgdk_pixbuf-2.0.so.0\n# Was: Workaround for:\n# On Ubuntu, get (inkscape:25621): GdkPixbuf-WARNING **: Error loading XPM image loader: Image type 'xpm' is not supported\n\nlibfontconfig.so.1\n# Workaround for:\n# Application stalls when loading fonts during application launch; e.g., KiCad on ubuntu-mate\n\nlibthai.so.0\n# Workaround for:\n# audacity: /tmp/.mount_AudaciUsFbON/usr/lib/libthai.so.0: version `LIBTHAI_0.1.25' not found (required by /usr/lib64/libpango-1.0.so.0)\n# on openSUSE Tumbleweed\n\n# other \"low-level\" font rendering libraries\n# should fix https://github.com/probonopd/linuxdeployqt/issues/261#issuecomment-377522251\n# and https://github.com/probonopd/linuxdeployqt/issues/157#issuecomment-320755694\nlibfreetype.so.6\nlibharfbuzz.so.0\n\n# Note, after discussion we do not exlude this, but we can use a dummy library that just does nothing\n# libselinux.so.1\n# Workaround for:\n# sed: error while loading shared libraries: libpcre.so.3: cannot open shared object file: No such file or directory\n# Some distributions, such as Arch Linux, do not come with libselinux.so.1 by default.\n# The solution is to bundle a dummy mock library:\n# echo \"extern int is_selinux_enabled(void){return 0;}\" >> selinux-mock.c\n# gcc -s -shared -o libselinux.so.1 -Wl,-soname,libselinux.so.1 selinux-mock.c\n# strip libselinux.so.1\n# More information: https://github.com/AppImage/AppImages/issues/83\n# and https://github.com/AppImage/AppImageKit/issues/775#issuecomment-614954821\n# https://gitlab.com/sulinos/devel/libselinux-dummy\n\n# The following are assumed to be part of the base system\n# Removing these has worked e.g., for Krita. Feel free to report if\n# you think that some of these should go into AppImages and why.\nlibcom_err.so.2\nlibexpat.so.1\nlibgcc_s.so.1\nlibglib-2.0.so.0\nlibgpg-error.so.0\n# libgssapi_krb5.so.2 # Disputed, seemingly needed by Arch Linux since Kerberos is named differently there\n# libgssapi.so.3 # Seemingly needed when running Ubuntu 14.04 binaries on Fedora 23\n# libhcrypto.so.4 # Missing on openSUSE LEAP 42.0\n# libheimbase.so.1 # Seemingly needed when running Ubuntu 14.04 binaries on Fedora 23\n# libheimntlm.so.0 # Seemingly needed when running Ubuntu 14.04 binaries on Fedora 23\n# libhx509.so.5 # Missing on openSUSE LEAP 42.0\nlibICE.so.6\n# libidn.so.11 # Does not come with Solus by default\n# libk5crypto.so.3 # Runnning AppImage built on Debian 9 or Ubuntu 16.04 on an Archlinux fails otherwise; https://github.com/AppImage/AppImages/issues/301\n# libkeyutils.so.1 # Does not come with Void Linux by default; https://github.com/Subsurface-divelog/subsurface/issues/1971#issuecomment-466606834\n# libkrb5.so.26 # Disputed, seemingly needed by Arch Linux since Kerberos is named differently there. 
Missing on openSUSE LEAP 42.0\n# libkrb5.so.3 # Disputed, seemingly needed by Arch Linux since Kerberos is named differently there\n# libkrb5support.so.0 # Disputed, seemingly needed by Arch Linux since Kerberos is named differently there\nlibp11-kit.so.0\n# libpcre.so.3 # Missing on Fedora 24, SLED 12 SP1, and openSUSE Leap 42.2\n# libroken.so.18 # Mission on openSUSE LEAP 42.0\n# libsasl2.so.2 # Seemingly needed when running Ubuntu 14.04 binaries on Fedora 23\nlibSM.so.6\nlibusb-1.0.so.0\nlibuuid.so.1\n# libwind.so.0 # Missing on openSUSE LEAP 42.0\n\n# Potentially dangerous libraries\nlibgobject-2.0.so.0\n\n# Workaround for:\n# Rectangles instead of fonts\n# https://github.com/AppImage/AppImages/issues/240\nlibpangoft2-1.0.so.0\nlibpangocairo-1.0.so.0\nlibpango-1.0.so.0\n\n# FIXME:\n# Can get symbol lookup error: /lib64/libpango-1.0.so.0: undefined symbol: g_log_structured_standard\n# if libcairo is bundled but libpango is not\n\n# Workaround for:\n# e.g., Spotify\n# relocation error: /lib/x86_64-linux-gnu/libgcrypt.so.20:\n# symbol gpgrt_lock_lock, version GPG_ERROR_1.0 not defined\n# in file libgpg-error.so.0 with link time reference\nlibgpg-error.so.0\n\nlibjack.so.0\n# it must match the ABI of the JACK server which is installed in the base system\n# rncbc confirmed this\n# However, this library is missing on Fedora-WS-Live-31-1-9\n# which means that we should avoid using JACK altogether if possible\n\n# Unsolved issue:\n# https://github.com/probonopd/linuxdeployqt/issues/35\n# Error initializing NSS with a persistent database (sql:/home/me/.pki/nssdb): libsoftokn3.so: cannot open shared object file: No such file or directory\n# Error initializing NSS without a persistent database: NSS error code: -5925\n# nss_error=-5925, os_error=0\n# libnss3.so should not be removed from the bundles, as this causes other issues, e.g.,\n# https://github.com/probonopd/linuxdeployqt/issues/35#issuecomment-256213517\n# and https://github.com/AppImage/AppImages/pull/114\n# libnss3.so\n\n# The following cannot be excluded, see\n# https://github.com/AppImage/AppImages/commit/6c7473d8cdaaa2572248dcc53d7f617a577ade6b\n# http://stackoverflow.com/questions/32644157/forcing-a-binary-to-use-a-specific-newer-version-of-a-shared-library-so\n# libssl.so.1\n# libssl.so.1.0.0\n# libcrypto.so.1\n# libcrypto.so.1.0.0\n\n# According to https://github.com/RicardoEPRodrigues/3Engine/issues/4#issuecomment-511598362\n# libGLEW is not tied to a specific GPU. It's linked against libGL.so.1\n# and that one is different depending on the installed driver.\n# In fact libGLEW is changing its soversion very often, so you should always bundle libGLEW.so.2.0\n\n# libglut.so.3 # to be confirmed\n\nlibxcb-dri3.so.0 # https://github.com/AppImage/AppImages/issues/348\nlibxcb-dri2.so.0 # https://github.com/probonopd/linuxdeployqt/issues/331#issuecomment-442276277\n\n# If the next line turns out to cause issues, we will have to remove it again and find another solution\nlibfribidi.so.0 # https://github.com/olive-editor/olive/issues/221 and https://github.com/knapsu/plex-media-player-appimage/issues/14\n\n# Workaround for:\n# symbol lookup error: /lib/x86_64-linux-gnu/libgnutls.so.30: undefined symbol: __gmpz_limbs_write\n# https://github.com/ONLYOFFICE/appimage-desktopeditors/issues/3\n# Apparently coreutils depends on it, so it should be safe to assume that it comes with every target system\nlibgmp.so.10\n"
  },
  {
    "path": "ci/test/expensive_generic_test_multi_client.sh",
    "content": "#!/bin/bash\nset -xe\n\nexport PYTHONUNBUFFERED=1\n\nsrc_dir=${ONEFLOW_SRC_DIR:-\"$PWD\"}\nONEFLOW_TEST_DIR=${ONEFLOW_TEST_DIR:-\"$PWD/python/oneflow/test/modules\"}\n\ncd $ONEFLOW_TEST_DIR\n\nif [ -z \"$ONEFLOW_TEST_CPU_ONLY\" ]\nthen\n    gpu_num=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)\n    for ((i=0;i<gpu_num;i++)); do\n        parallel_spec=\"$parallel_spec --tx popen//env:CUDA_VISIBLE_DEVICES=${i}\"\n    done\nelse\n    parallel_spec=\"-n auto\"\nfi\n\nunset HTTP_PROXY\nunset HTTPS_PROXY\nunset http_proxy\nunset https_proxy\n\nexport ONEFLOW_TEST_DEVICE_NUM=1\n\nCOMMON_PYTEST_ARGS=\"--max-worker-restart=0 -x --durations=50 --capture=sys\"\npython3 -m pytest ${COMMON_PYTEST_ARGS} --failed-first --dist loadfile ${parallel_spec} ${PWD}\nif [[ \"$(python3 -c 'import oneflow.sysconfig;print(oneflow.sysconfig.has_rpc_backend_grpc())')\" == *\"True\"* ]]; then\n    export ONEFLOW_TEST_DEVICE_NUM=2\n    python3 -m oneflow.distributed.launch --nproc_per_node 2 -m pytest ${COMMON_PYTEST_ARGS} ${PWD}\n\n    export ONEFLOW_TEST_DEVICE_NUM=4\n    python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest ${COMMON_PYTEST_ARGS} ${PWD}\nelse\n    python3 -c 'import oneflow.sysconfig;assert(oneflow.sysconfig.has_rpc_backend_grpc() == False)'\nfi\n"
  },
  {
    "path": "ci/test/generic_test.sh",
    "content": "#!/bin/bash\nset -xe\n\nexport TF_CPP_MIN_LOG_LEVEL=3\nexport PYTHONUNBUFFERED=1\n\nsrc_dir=${ONEFLOW_SRC_DIR:-\"$PWD\"}\ntest_dir=${ONEFLOW_TEST_DIR:-\"$PWD/python/oneflow/test/ops\"}\ntest_tmp_dir=${ONEFLOW_TEST_TMP_DIR:-\"./test_tmp_dir\"}\nexport ONEFLOW_TEST_UTILS_DIR=$src_dir/python/oneflow/test_utils\n\nrm -rf $test_tmp_dir\nmkdir -p $test_tmp_dir\ncp -r $test_dir $test_tmp_dir\ncd ${test_tmp_dir}/$(basename $test_dir)\n\ngpu_num=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)\nexport ONEFLOW_TEST_DEVICE_NUM=1\npython3 $src_dir/ci/test/parallel_run.py \\\n    --gpu_num=${gpu_num} \\\n    --dir=${PWD} \\\n    --timeout=1 \\\n    --verbose \\\n    --chunk=1\n\nexport ONEFLOW_TEST_DEVICE_NUM=2\npython3 -m unittest discover ${PWD} --failfast --verbose\n\nexport ONEFLOW_TEST_DEVICE_NUM=4\npython3 -m unittest discover ${PWD} --failfast --verbose\n"
  },
  {
    "path": "ci/test/generic_test_multi_client.sh",
    "content": "#!/bin/bash\nset -xe\n\nexport PYTHONUNBUFFERED=1\n\nsrc_dir=${ONEFLOW_SRC_DIR:-\"$PWD\"}\nONEFLOW_TEST_DIR=${ONEFLOW_TEST_DIR:-\"$PWD/python/oneflow/test/modules\"}\nONEFLOW_TEST_TASKS_PER_GPU=${ONEFLOW_TEST_TASKS_PER_GPU:-\"4\"}\n\n\nif [ -z \"$ONEFLOW_TEST_FILES\" ]; then\n  ONEFLOW_TEST_FILES=\"${ONEFLOW_TEST_DIR}\"\n  ONEFLOW_TEST_FILES_WILD=\"${ONEFLOW_TEST_DIR}/**/test_*.py\"\n  ONEFLOW_TEST_REPEAT_TIMES=\"\"\nelse\n  ONEFLOW_TEST_FILES_WILD=\"${ONEFLOW_TEST_FILES}\"\n  ONEFLOW_TEST_REPEAT_TIMES=\"--count=100\"\nfi\n\nif [ -z \"$ONEFLOW_TEST_CPU_ONLY\" ]\nthen\n    gpu_num=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)\n    for ((i=0;i<gpu_num;i++)); do\n        for ((j=0;j<ONEFLOW_TEST_TASKS_PER_GPU;j++)); do\n            parallel_spec=\"$parallel_spec --tx popen//env:CUDA_VISIBLE_DEVICES=${i}\"\n        done\n    done\n    multi_launch_device_num=${gpu_num}\nelse\n    parallel_spec=\"-n auto\"\n    multi_launch_device_num=8\nfi\n\nunset HTTP_PROXY\nunset HTTPS_PROXY\nunset http_proxy\nunset https_proxy\n\nexport ONEFLOW_TEST_DEVICE_NUM=1\n\nCOMMON_PYTEST_ARGS=\"-p no:warnings -p no:randomly -p no:cacheprovider --max-worker-restart=0 -x --durations=50 --capture=sys --ignore=log ${ONEFLOW_TEST_REPEAT_TIMES}\"\ntime python3 -m pytest ${COMMON_PYTEST_ARGS} --dist loadfile ${parallel_spec} ${ONEFLOW_TEST_FILES}\nif [[ \"$(python3 -c 'import oneflow.sysconfig;print(oneflow.sysconfig.has_rpc_backend_grpc())')\" == *\"True\"* ]]; then\n    export ONEFLOW_TEST_DEVICE_NUM=2\n    time python3 ${src_dir}/ci/test/multi_launch.py \\\n        --files \"${ONEFLOW_TEST_FILES_WILD}\" \\\n        --master_port 29500 \\\n        --master_port 29501 \\\n        --master_port 29502 \\\n        --master_port 29503 \\\n        -n master_port \\\n        --group_size 2 \\\n        --auto_cuda_visible_devices \\\n        --device_num $multi_launch_device_num \\\n        -m oneflow.distributed.launch --nproc_per_node 2 -m pytest ${COMMON_PYTEST_ARGS}\n\n    export ONEFLOW_TEST_DEVICE_NUM=4\n    time python3 ${src_dir}/ci/test/multi_launch.py \\\n        --files \"${ONEFLOW_TEST_FILES_WILD}\" \\\n        -n 3 \\\n        --group_size 4 \\\n        --device_num $multi_launch_device_num \\\n        --auto_cuda_visible_devices \\\n        -m oneflow.distributed.launch --nproc_per_node 4 -m pytest ${COMMON_PYTEST_ARGS}\nelse\n    python3 -c 'import oneflow.sysconfig;assert(oneflow.sysconfig.has_rpc_backend_grpc() == False)'\nfi\n"
  },
  {
    "path": "ci/test/ir_tests.sh",
    "content": "#!/bin/bash\nset -xe\n\nexport PYTHONUNBUFFERED=1\n\nsrc_dir=${ONEFLOW_SRC_DIR:-\"$PWD\"}\nONEFLOW_TEST_DIR=${ONEFLOW_TEST_DIR:-\"$PWD/oneflow/ir/test\"}\n\ncd $ONEFLOW_TEST_DIR\n\nif [ -z \"$ONEFLOW_TEST_CPU_ONLY\" ]\nthen\n    gpu_num=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)\n    for ((i=0;i<gpu_num;i++)); do\n        parallel_spec=\"$parallel_spec --tx popen//env:CUDA_VISIBLE_DEVICES=${i}\"\n    done\nelse\n    parallel_spec=\"-n auto\"\nfi\n\nunset HTTP_PROXY\nunset HTTPS_PROXY\nunset http_proxy\nunset https_proxy\n\nexport ONEFLOW_TEST_DEVICE_NUM=1\n\nCOMMON_PYTEST_ARGS=\"--max-worker-restart=0 --durations=50 --ignore=OneFlow/cuda_code_gen --ignore=OneFlow/psig/test_2nd_basic_parse.py\"\npython3 -m pytest ${COMMON_PYTEST_ARGS} --failed-first --dist loadfile ${parallel_spec} ${PWD}\n"
  },
  {
    "path": "ci/test/multi_client_exception_test.sh",
    "content": "#!/bin/bash\nset -xe\n\nexport PYTHONUNBUFFERED=1\n\nsrc_dir=${ONEFLOW_SRC_DIR:-\"$PWD\"}\ntest_dir=\"$PWD/python/oneflow/test/exceptions\"\ntest_tmp_dir=${ONEFLOW_TEST_TMP_DIR:-\"./test_tmp_dir\"}\nexport ONEFLOW_TEST_UTILS_DIR=$src_dir/python/oneflow/test_utils\n\n\nrm -rf $test_tmp_dir\nmkdir -p $test_tmp_dir\ncp -r $test_dir $test_tmp_dir\ncd ${test_tmp_dir}/$(basename $test_dir)\n\nexport ONEFLOW_DEBUG_MODE=1\n\nfor file in $(ls ${PWD}/test_*.py)\ndo\n    if test -f $file\n    then\n        export ONEFLOW_TEST_DEVICE_NUM=1\n        python3 $file --failfast --verbose\n        if [[ \"$(python3 -c 'import oneflow.sysconfig;print(oneflow.sysconfig.has_rpc_backend_grpc())')\" == *\"True\"* ]]; then\n            export ONEFLOW_TEST_DEVICE_NUM=2\n            python3 -m oneflow.distributed.launch --nproc_per_node 2 $file --failfast --verbose\n\n            export ONEFLOW_TEST_DEVICE_NUM=4\n            python3 -m oneflow.distributed.launch --nproc_per_node 4 $file --failfast --verbose\n        else\n            python3 -c 'import oneflow.sysconfig;assert(oneflow.sysconfig.has_rpc_backend_grpc() == False)'\n        fi\n    fi\ndone\n\nunset ONEFLOW_DEBUG_MODE\n"
  },
  {
    "path": "ci/test/multi_launch.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\"\"\"\nThis file is mostly copied from PyTorch v1.8.1 torch/distributed/launch.py\n\"\"\"\nimport asyncio\nimport os\nimport random\nimport sys\nfrom argparse import REMAINDER, ArgumentParser\nfrom typing import IO, Any, List, Optional\nimport glob\nimport hashlib\nfrom math import ceil\n\nstdout_filename = \"stdout\"\nstderr_filename = \"stderr\"\n\nglobal PARALLEL_NUM\nglobal SUCCESS_NUM\nPARALLEL_NUM = 0\nSUCCESS_NUM = 0\n\n\ndef parse_args():\n    \"\"\"\n    Helper function parsing the command line options\n    @retval ArgumentParser\n    \"\"\"\n    parser = ArgumentParser(\n        description=\"helper to start multiple distributed launches in parallel\"\n    )\n    parser.add_argument(\n        \"--files\",\n        type=str,\n        help=\"files to run, support pattern\",\n        required=True,\n        nargs=\"+\",\n    )\n    parser.add_argument(\n        \"--group_size\",\n        type=int,\n        help=\"for one command, how many duplications to run\",\n        required=True,\n    )\n    parser.add_argument(\n        \"--device_num\", type=int, help=\"how many devices to run on\", required=True,\n    )\n    parser.add_argument(\n        \"-n\",\n        \"--parallel_num\",\n        type=str,\n        help=\"how many launches, could be a number, or 'master_port'\",\n        required=True,\n    )\n    parser.add_argument(\n        \"--auto_cuda_visible_devices\",\n        action=\"store_true\",\n        required=False,\n        default=False,\n    )\n    parser.add_argument(\n        \"--shuffle\", action=\"store_true\", required=False, default=False,\n    )\n    parser.add_argument(\n        \"--verbose\", action=\"store_true\", required=False, default=False,\n    )\n    parser.add_argument(\n        \"--master_port\",\n        default=[],\n        action=\"append\",\n        help=\"Master node (rank 0)'s free port, pass this multiple `--master_port` to launch more instances\",\n    )\n    parser.add_argument(\n        \"-m\",\n        \"--module\",\n        default=False,\n        action=\"store_true\",\n        help=\"Changes each process to interpret the launch script as a python module, executing with the same behavior as'python -m'.\",\n    )\n    parser.add_argument(\n        \"training_script\",\n        type=str,\n        help=\"The full path to the single GPU training program/script to be launched in parallel, followed by all the arguments for the training script\",\n    )\n    parser.add_argument(\"training_script_args\", nargs=REMAINDER)\n    return parser.parse_args()\n\n\nasync def run_and_capture(cmd=None, prefix=None, **kwargs):\n    proc = await asyncio.create_subprocess_exec(\n        *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT, **kwargs\n    )\n    while True:\n        line = await proc.stdout.readline()\n        print(prefix, line.decode(), end=\"\")\n        if not line:\n            break\n    await 
proc.wait()\n    assert proc.returncode == 0, prefix\n    global PARALLEL_NUM\n    global SUCCESS_NUM\n    SUCCESS_NUM += 1\n    print(f\"{prefix} succeed ({SUCCESS_NUM}/{PARALLEL_NUM})\")\n\n\nasync def launch_multiple(\n    cmds=None, group_size=None, auto_cuda_env=False, device_num=None\n):\n    visible_groups = [\n        [str(x) for x in range(device_num)[i : i + group_size]]  # to get [\"0\", \"1\"]\n        for i in range(0, device_num, group_size)\n    ]\n    spawns = []\n    for i, cmd in enumerate(cmds):\n        group_idx = i % len(visible_groups)\n        cuda_visible_devices = \",\".join(visible_groups[group_idx])\n        print(cuda_visible_devices, cmd, \"\\n\")\n        env = os.environ\n        if auto_cuda_env:\n            env = dict(env, CUDA_VISIBLE_DEVICES=cuda_visible_devices)\n        process = run_and_capture(\n            cmd=cmd, prefix=f\"[wg={i}][device={cuda_visible_devices}]\", env=env,\n        )\n        spawns.append(process)\n    await asyncio.gather(*spawns)\n\n\ndef main():\n    args = parse_args()\n    # find files and chuck them\n    files = []\n    for f in args.files:\n        for ff in f.strip().split(\" \"):\n            if len(ff) > 0:\n                files += list(glob.glob(ff, recursive=True))\n    print(\"total files:\", len(files))\n    files = sorted(\n        files,\n        key=lambda x: hashlib.md5(os.path.basename(x.encode(\"ascii\"))).hexdigest(),\n    )\n    if args.shuffle:\n        random.shuffle(files)\n    files_hash = hashlib.md5(\n        \"\".join([os.path.basename(x) for x in files]).encode()\n    ).hexdigest()[:8]\n    if args.verbose:\n        print(\n            f\"::warning file=testFilesHash,line={len(files)},col=0,endColumn=0::shuffle-{args.shuffle}-group_size-{args.group_size}-md5-{files_hash}\"\n        )\n    if args.parallel_num == \"master_port\":\n        parallel_num = len(args.master_port)\n        master_ports = args.master_port\n    else:\n        parallel_num = int(args.parallel_num)\n        if parallel_num != len(args.master_port):\n            print(\n                \"warning\", \"parallel_num != len(args.master_port)\", \"will auto generate\"\n            )\n        default_master_port = 29500\n        master_ports = list(\n            range(default_master_port, default_master_port + parallel_num)\n        )\n    assert parallel_num > 0\n    assert len(master_ports) == parallel_num\n    chunk_size = ceil(len(files) / parallel_num)\n    global PARALLEL_NUM\n    PARALLEL_NUM = parallel_num\n    chunks = [files[i : i + chunk_size] for i in range(0, len(files), chunk_size)]\n\n    # check args\n    assert args.training_script == \"oneflow.distributed.launch\"\n\n    # generate commands\n    cmds = [\n        [sys.executable, \"-m\", args.training_script, \"--master_port\", str(master_port)]\n        + args.training_script_args\n        + chunck\n        for (master_port, chunck) in zip(master_ports, chunks)\n    ]\n    loop = asyncio.get_event_loop()\n    processes = launch_multiple(\n        cmds=cmds,\n        auto_cuda_env=args.auto_cuda_visible_devices,\n        group_size=args.group_size,\n        device_num=args.device_num,\n    )\n    loop.run_until_complete(processes)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
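  {
    "path": "ci/test/multi_launch_example.py",
    "content": "\"\"\"\nHypothetical sketch, NOT an original repo file: mirrors the chunking and\ndevice-group arithmetic in ci/test/multi_launch.py so the file-to-port and\nport-to-GPU mapping is easy to see. Assumes 8 devices, group_size=2 and four\nmaster ports, matching the ONEFLOW_TEST_DEVICE_NUM=2 invocation in\nci/test/generic_test_multi_client.sh.\n\"\"\"\nfrom math import ceil\n\nfiles = [f\"test_{i}.py\" for i in range(10)]  # stand-ins for collected test files\nmaster_ports = [29500, 29501, 29502, 29503]\nparallel_num = len(master_ports)\ngroup_size, device_num = 2, 8\n\n# one chunk of files per launch, as in multi_launch.main()\nchunk_size = ceil(len(files) / parallel_num)\nchunks = [files[i : i + chunk_size] for i in range(0, len(files), chunk_size)]\n\n# contiguous device groups, as in launch_multiple(): [[\"0\", \"1\"], [\"2\", \"3\"], ...]\nvisible_groups = [\n    [str(x) for x in range(device_num)[i : i + group_size]]\n    for i in range(0, device_num, group_size)\n]\n\nfor i, (port, chunk) in enumerate(zip(master_ports, chunks)):\n    devices = \",\".join(visible_groups[i % len(visible_groups)])\n    print(f\"--master_port {port} CUDA_VISIBLE_DEVICES={devices} -> {chunk}\")\n"
  },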
  {
    "path": "ci/test/parallel_run.py",
    "content": "import asyncio\nimport os\nimport argparse\nfrom subprocess import PIPE, STDOUT\nimport glob\nimport sys\nimport time\nimport socket\nfrom contextlib import closing\nimport uuid\n\n\ndef gen_cmds(cmd=None, dir=None, doctest=False):\n    if doctest:\n        paths = glob.glob(os.path.join(dir, \"**/*.py\"), recursive=True)\n        paths = [\n            p\n            for p in paths\n            if \"compatible\" not in p\n            and \"single_client\" not in p\n            and \"unittest.py\" not in p\n        ]\n        with_doctest = []\n        for p in paths:\n            with open(p) as f:\n                content = f.read()\n                if \"import doctest\" in content:\n                    with_doctest.append(\"{} {} -v\".format(cmd, p))\n        print(with_doctest)\n        return with_doctest\n    else:\n        paths = glob.glob(os.path.join(dir, \"test_*.py\"), recursive=False)\n        return [\"{} {} --failfast --verbose\".format(cmd, p) for p in paths]\n\n\ndef find_free_port():\n    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:\n        s.bind((\"localhost\", 0))\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        return s.getsockname()[1]\n\n\ndef split_and_print(prefix, text):\n    lines = text.splitlines(keepends=True)\n    prefixed = \"\"\n    for l in lines:\n        prefixed += f\"{prefix} {l}\"\n    print(prefixed, flush=True)\n\n\ndef everyN(l: list, n: int):\n    for i in range(0, len(l), n):\n        yield l[i : i + n]\n\n\ndef contains_oom_info(txt: str):\n    return \"memory\" in txt or \"Memory\" in txt or \"CUDNN\" in txt or \"ALLOC\" in txt\n\n\ndef should_retry(txt: str):\n    return contains_oom_info(txt)\n\n\ndef print_out(prefix: str = \"\", content: str = \"\"):\n    for l in content.split(\"\\n\"):\n        print(f\"[{prefix}]\", l)\n\n\nasync def spawn_shell_and_check(cmd: str = None, gpu_id: int = -1, check: bool = False):\n    is_cpu_only = os.getenv(\"ONEFLOW_TEST_CPU_ONLY\")\n    print(f\"[gpu={gpu_id}]\", cmd)\n    p = await asyncio.create_subprocess_shell(\n        cmd,\n        stdout=PIPE,\n        stderr=STDOUT,\n        env=dict(\n            os.environ,\n            CUDA_VISIBLE_DEVICES=(\"-1\" if is_cpu_only else \",\".join([str(gpu_id)])),\n            ONEFLOW_TEST_MASTER_PORT=str(find_free_port()),\n            ONEFLOW_TEST_LOG_DIR=(\"./unittest-log-\" + str(uuid.uuid4())),\n        ),\n    )\n    (stdout_data, stderr_data) = await p.communicate()\n    decoded = stdout_data.decode()\n    if check or should_retry(decoded) == False:\n        if p.returncode != 0:\n            print_out(prefix=cmd, content=decoded)\n            raise RuntimeError(cmd)\n    return {\"returncode\": p.returncode, \"cmd\": cmd, \"stdout\": decoded}\n\n\nasync def run_cmds(\n    cmds, gpu_num=0, timeout=10, chunk=1, verbose=False, per_gpu_process_num=1\n):\n    is_cpu_only = os.getenv(\"ONEFLOW_TEST_CPU_ONLY\")\n    if is_cpu_only:\n        gpu_num = os.cpu_count()\n    fails = []\n    assert gpu_num > 0\n    for cmdN in everyN(cmds, per_gpu_process_num * gpu_num):\n        results = await asyncio.gather(\n            *[\n                spawn_shell_and_check(\n                    cmd=cmd, gpu_id=i, check=(per_gpu_process_num == 1)\n                )\n                for cmd_gpu_num in everyN(cmdN, gpu_num)\n                for (i, cmd) in enumerate(cmd_gpu_num)\n            ],\n        )\n        for r in list(results):\n            if r[\"returncode\"] != 0:\n                
fails.append(r[\"cmd\"])\n            else:\n                print_out(prefix=r[\"cmd\"], content=r[\"stdout\"])\n    return fails\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--gpu_num\", type=int, required=True, default=0)\n    parser.add_argument(\"--dir\", type=str, required=True, default=\".\")\n    parser.add_argument(\"--cmd\", type=str, required=False, default=sys.executable)\n    parser.add_argument(\"--timeout\", type=int, required=False, default=2)\n    parser.add_argument(\"--chunk\", type=int, required=True)\n    parser.add_argument(\"--verbose\", action=\"store_true\", required=False, default=False)\n    parser.add_argument(\"--doctest\", action=\"store_true\", required=False, default=False)\n    args = parser.parse_args()\n    cmds = gen_cmds(cmd=args.cmd, dir=args.dir, doctest=args.doctest)\n    start = time.time()\n    loop = asyncio.get_event_loop()\n    PER_GPU_PROCESS_NUMS = [12, 8, 2, 1]\n    is_cpu_only = os.getenv(\"ONEFLOW_TEST_CPU_ONLY\")\n    if is_cpu_only:\n        PER_GPU_PROCESS_NUMS = [1]\n    for per_gpu_process_num in PER_GPU_PROCESS_NUMS:\n        print(\"[per_gpu_process_num]\", per_gpu_process_num)\n        cmds = loop.run_until_complete(\n            run_cmds(\n                cmds,\n                gpu_num=args.gpu_num,\n                timeout=args.timeout,\n                chunk=args.chunk,\n                verbose=args.verbose,\n                per_gpu_process_num=per_gpu_process_num,\n            )\n        )\n    elapsed = time.time() - start\n    elapsed_time_txt = time.strftime(\"elapsed: %H:%M:%S\", time.gmtime(elapsed))\n    print(elapsed_time_txt)\n"
  },
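  {
    "path": "ci/test/parallel_run_example.py",
    "content": "\"\"\"\nHypothetical sketch, NOT an original repo file: mirrors the retry ladder at\nthe bottom of ci/test/parallel_run.py. Each pass reruns only the commands\nthat failed, stepping down through PER_GPU_PROCESS_NUMS so tests whose output\nlooks like an OOM (see contains_oom_info there) are retried with less\nconcurrency. The fake run_cmds below stands in for the real asyncio runner.\n\"\"\"\n\n\ndef fake_run_cmds(cmds, per_gpu_process_num):\n    # pretend \"test_big\" only passes once it has a GPU to itself\n    return [c for c in cmds if \"big\" in c and per_gpu_process_num > 1]\n\n\ncmds = [\"python3 test_small.py\", \"python3 test_big.py\"]\nfor per_gpu_process_num in [12, 8, 2, 1]:\n    print(\"[per_gpu_process_num]\", per_gpu_process_num, \"running\", len(cmds), \"cmd(s)\")\n    cmds = fake_run_cmds(cmds, per_gpu_process_num)  # returns the failures\nprint(\"unresolved failures:\", cmds)\n"
  },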
  {
    "path": "ci/test/print_stack_from_core.sh",
    "content": "set -ex\nif compgen -G \"$2/core.*\" > /dev/null; then\n    gdb --batch --quiet -ex \"thread apply all bt full\" -ex \"quit\" $1 $2/core.*\nfi\n"
  },
  {
    "path": "ci/test/print_stack_in_all_dirs.sh",
    "content": "set -ex\nfind . -type f -name \"core.*\" -exec gdb --batch --quiet -ex \"thread apply all bt full\" -ex \"quit\" python3 {} \\;\n"
  },
  {
    "path": "ci/test/resource-spec/1x-gtx-1080.json",
    "content": "{\n  \"version\": {\n    \"major\": 1,\n    \"minor\": 0\n  },\n  \"local\": [\n    {\n      \"vram\": [\n        {\n          \"id\": \"0\",\n          \"slots\": 8117\n        }\n      ]\n    }\n  ]\n}\n"
  },
  {
    "path": "ci/test/resource-spec/2x-rtx-2080.json",
    "content": "{\n  \"version\": {\n    \"major\": 1,\n    \"minor\": 0\n  },\n  \"local\": [\n    {\n      \"vram\": [\n        {\n          \"id\": \"0\",\n          \"slots\": 7982\n        },\n        {\n          \"id\": \"1\",\n          \"slots\": 7982\n        }\n      ]\n    }\n  ]\n}\n"
  },
  {
    "path": "ci/test/resource-spec/4x-rtx-2080ti.json",
    "content": "{\n  \"version\": {\n    \"major\": 1,\n    \"minor\": 0\n  },\n  \"local\": [\n    {\n      \"vram\": [\n        {\n          \"id\": \"0\",\n          \"slots\": 11019\n        },\n        {\n          \"id\": \"1\",\n          \"slots\": 11019\n        },\n        {\n          \"id\": \"2\",\n          \"slots\": 11019\n        },\n        {\n          \"id\": \"3\",\n          \"slots\": 11019\n        }\n      ]\n    }\n  ]\n}\n"
  },
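  {
    "path": "ci/test/resource-spec/read_spec_example.py",
    "content": "\"\"\"\nHypothetical sketch, NOT an original repo file: the JSON files under\nci/test/resource-spec/ describe per-GPU VRAM budgets as integer \"slots\"\n(apparently MiB, matching the cards named in the file names). The CI-side\nconsumer of these specs is not part of this snapshot, so treat this reader\npurely as an illustration of the schema.\n\"\"\"\nimport json\n\nwith open(\"ci/test/resource-spec/4x-rtx-2080ti.json\") as f:\n    spec = json.load(f)\n\nassert spec[\"version\"] == {\"major\": 1, \"minor\": 0}\nfor node in spec[\"local\"]:\n    for gpu in node[\"vram\"]:\n        print(f\"gpu {gpu['id']}: {gpu['slots']} slots\")\n"
  },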
  {
    "path": "ci/test/test_mock_function.sh",
    "content": "#!/bin/bash\nset -e\nMOCK_UNITTEST=$PWD/python/oneflow/test/misc/test_mock_scope.py\n\npython3 $MOCK_UNITTEST --failfast --verbose\n# testing import *\npython3 -c \"\nimport oneflow\nimport oneflow.nn\nimport oneflow.mock_torch as mock; mock.enable();\nfrom torch.sbp import *; assert(sbp == oneflow.sbp.sbp);\nfrom torch import *; assert(randn == oneflow.randn);\nfrom torch.nn import *; assert(Graph == oneflow.nn.Graph);\nmock.disable();\nfrom torch import *; assert(randn != oneflow.randn);\nfrom torch.nn import *; assert(Graph != oneflow.nn.Graph);\n\"\n"
  },
  {
    "path": "ci/test/test_mock_script.sh",
    "content": "#!/bin/bash\nset -e\npython_version=$(python3 --version 2>&1 | awk '{print $2}')\n\nif [[ \"$python_version\" < \"3.8\" ]]; then\n    echo \"Python version is less than 3.8.\"\n    exit 0\nfi\n\nMOCK_TORCH=$PWD/python/oneflow/test/misc/mock_example.py\n\nsame_or_exit() {\n    if [[ \"$(python3 $MOCK_TORCH)\" != *\"$1\"* ]]; then\n        exit 1\n    fi\n}\n\n# generate pytorch file\npython3 -c \"import torch; torch.save(torch.ones(1), 'test.pt')\"\n\neval $(python3 -m oneflow.mock_torch) # test call to python module, default argument is enable\nsame_or_exit \"True\"\n\n# test load pytorch file with mock torch enabled\npython3 -c \"\"\"\nimport torch\nx = torch.load('test.pt')\nassert torch.equal(x, torch.ones(1))\nimport torch.nn\nassert 'oneflow/nn/__init__.py' in torch.nn.__file__\n\"\"\"\n\n# testing import\npython3 -c 'import torch; torch.randn(2,3)'\npython3 -c 'import torch.nn; torch.nn.Graph'\npython3 -c 'import torch.version; torch.version.__version__'\npython3 -c 'from torch import *; randn(2,3)'\npython3 -c 'from torch.nn import *; Graph'\npython3 -c 'from torch.sbp import *; sbp'\npython3 -c 'from torch import nn; nn.Graph'\npython3 -c 'from torch.version import __version__'\npython3 -c 'import torch; torch.not_exist' 2>&1 >/dev/null | grep -q 'AttributeError'\npython3 -c 'import torch.not_exist' 2>&1 >/dev/null | grep -q 'ModuleNotFoundError'\n\neval $(python3 -m oneflow.mock_torch disable)\nsame_or_exit \"False\"\neval $(python3 -m oneflow.mock_torch enable)\nsame_or_exit \"True\"\neval $(python3 -m oneflow.mock_torch disable) # recover\nsame_or_exit \"False\"\neval $(oneflow-mock-torch) # test scripts\nsame_or_exit \"True\"\neval $(oneflow-mock-torch disable)\nsame_or_exit \"False\"\neval $(oneflow-mock-torch enable)\nsame_or_exit \"True\"\neval $(oneflow-mock-torch disable)\nsame_or_exit \"False\"\n\n# test load pytorch file with mock torch disabled\npython3 -c \"import oneflow as flow; x = flow.load('test.pt'); assert flow.equal(x, flow.ones(1))\"\n\nrm test.pt\n\neval $(python3 -m oneflow.mock_torch --lazy --verbose)\npython3 -c \"import torch.not_exist\" | grep -q 'dummy object'"
  },
  {
    "path": "ci/test/test_resnet50_graph_ddp.sh",
    "content": "#!/usr/bin/env bash\n\nset -ex\n\ncd $ONEFLOW_MODELS_DIR\nONEFLOW_TEST_DATASET_DIR=${ONEFLOW_TEST_DATASET_DIR:-\"/dataset\"}\nOFRECORD_PATH=${ONEFLOW_TEST_DATASET_DIR}/imagenette/ofrecord\n\nif [ ! -d \"${ONEFLOW_TEST_DATASET_DIR}/imagenette/ofrecord/train\" ];then\n    mkdir -p ./dataset/ofrecord\n    ln -s ${ONEFLOW_TEST_DATASET_DIR}/imagenette/ofrecord ./dataset/ofrecord/train\n    OFRECORD_PATH=./dataset/ofrecord\nfi\n\npython3 -m oneflow.distributed.launch --nproc_per_node 1 --nnodes 1 --node_rank 0 --master_addr 127.0.0.1 Vision/classification/image/resnet50/train.py --ofrecord-path $OFRECORD_PATH --ofrecord-part-num 1 --num-devices-per-node 1 --lr 0.004 --momentum 0.875 --num-epochs 1 --train-batch-size 4 --val-batch-size 50 --print-interval 10 --exit-num 1 --ddp\npython3 -m oneflow.distributed.launch --nproc_per_node 2 --nnodes 1 --node_rank 0 --master_addr 127.0.0.1 Vision/classification/image/resnet50/train.py --ofrecord-path $OFRECORD_PATH --ofrecord-part-num 2 --num-devices-per-node 1 --lr 0.004 --momentum 0.875 --num-epochs 1 --train-batch-size 4 --val-batch-size 50 --print-interval 10 --exit-num 1 --use-fp16 --channel-last --scale-grad --graph --fuse-bn-relu --fuse-bn-add-relu --use-gpu-decode\n"
  },
  {
    "path": "ci/test/test_speed_multi_client.sh",
    "content": "#!/usr/bin/env bash\n\nset -uxo pipefail\n\nrc=0\n# accumulate the score of every test\ntrap 'rc=$(($rc + $?))' ERR\n\ncd $ONEFLOW_MODELS_DIR\n\nfunction check_relative_speed {\n  # Default score is 1\n  SCORE=${2:-1}\n  awk -F'[:(]' -v threshold=$1 -v score=$SCORE 'BEGIN { ret=2 } /Relative speed/{ if ($2 >= threshold) { printf \"✔️ \"; ret=0 } else { printf \"❌ \"; ret=score }} {print $0} END { exit ret }'\n}\n\nfunction check_millisecond_time {\n  # Default score is 1\n  SCORE=${2:-1}\n  awk -F'[:(]' -v threshold=$1 -v score=$SCORE 'BEGIN { ret=2 } /OneFlow/{ if (substr($2, 2, length($2) - 4) <= threshold) { printf \"✔️ \"; ret=0 } else { printf \"❌ \"; ret=score }} { print $0 } END { exit ret }'\n}\n\nfunction write_to_file_and_print {\n  tee -a result\n  printf \"\\n\" >> result\n}\n\npython3 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 16x3x224x224 --no-show-memory --times 100 | check_relative_speed 1.05 | check_millisecond_time 129.0 2 | write_to_file_and_print\npython3 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 8x3x224x224 --no-show-memory --times 100 | check_relative_speed 1.04 | write_to_file_and_print\npython3 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.01 | write_to_file_and_print\npython3 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 200 | check_relative_speed 0.99 | write_to_file_and_print\npython3 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 200 | check_relative_speed 0.95 | write_to_file_and_print\n\npython3 scripts/swin_dataloader_compare_speed_with_pytorch.py --batch_size 32 --num_workers 1 | write_to_file_and_print\npython3 scripts/swin_dataloader_compare_speed_with_pytorch.py --batch_size 32 --num_workers 4 | write_to_file_and_print\npython3 scripts/swin_dataloader_compare_speed_with_pytorch.py --batch_size 32 --num_workers 8 | write_to_file_and_print\n\nexport OMP_NUM_THREADS=1\npython3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 16x3x224x224 --no-show-memory --times 100 --ddp | check_relative_speed 1.12 | check_millisecond_time 136.3 2 | write_to_file_and_print\npython3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 8x3x224x224 --no-show-memory --times 100 --ddp | check_relative_speed 1.1 | write_to_file_and_print\npython3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 1.18 | write_to_file_and_print\npython3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 1.18 | write_to_file_and_print\npython3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py Vision/classification/image/resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory 
--times 200 --ddp | check_relative_speed 1.15 | write_to_file_and_print\n\nresult=\"GPU Name: `nvidia-smi --query-gpu=name --format=csv,noheader -i 0` \\n\\n `cat result`\"\n# escape newline for github actions: https://github.community/t/set-output-truncates-multiline-strings/16852/2\n# note that we escape \\n and \\r to \\\\n and \\\\r (i.e. raw string \"\\n\" and \"\\r\") instead of %0A and %0D, \n# so that they can be correctly handled in javascript code\nresult=\"${result//'%'/'%25'}\"\nresult=\"${result//$'\\n'/'\\\\n'}\"\nresult=\"${result//$'\\r'/'\\\\r'}\"\n\necho \"::set-output name=stats::$result\"\n\n# Only fail when the sum of score >= 2\nif (( $rc >= 2 ))\nthen\n  exit 1\nelse\n  exit 0\nfi\n"
  },
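  {
    "path": "ci/test/test_speed_multi_client_example.py",
    "content": "\"\"\"\nHypothetical sketch, NOT an original repo file: a Python rendering of the awk\nprogram in check_relative_speed from ci/test/test_speed_multi_client.sh. A\nline like 'Relative speed: 1.08 (...)' passes when the ratio meets the\nthreshold; otherwise the check contributes `score` to the accumulated rc, and\nthe whole run only fails once the total reaches 2.\n\"\"\"\n\n\ndef check_relative_speed(output: str, threshold: float, score: int = 1) -> int:\n    for line in output.splitlines():\n        if \"Relative speed\" in line:\n            # awk -F'[:(]' reads the field between ':' and '(' as the ratio\n            value = float(line.split(\":\")[1].split(\"(\")[0])\n            return 0 if value >= threshold else score\n    return 2  # marker line never seen: treat as an error, like awk's ret=2\n\n\nassert check_relative_speed(\"Relative speed: 1.08 (higher is better)\", 1.05) == 0\nassert check_relative_speed(\"Relative speed: 1.01 (higher is better)\", 1.05) == 1\n"
  },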
  {
    "path": "ci/test/try_install.sh",
    "content": "#!/bin/bash\nset -xe\n\nsrc_dir=${ONEFLOW_SRC_DIR:-\"$PWD\"}\nwheel_path=${ONEFLOW_WHEEL_PATH:-\"$PWD/wheelhouse\"}\nindex=${ONEFLOW_PIP_INDEX}\npkg_name=${ONEFLOW_PACKAGE_NAME:-\"oneflow\"}\n\nif [ -n \"$index\" ]; then\n    python3 -m pip install --find-links ${index} ${pkg_name}\nelif [ -d \"$wheel_path\" ]; then\n    ls -la $wheel_path\n    export PATH=/root/.local/bin:$PATH\n    python3 -m pip install https://oneflow-static.oss-cn-beijing.aliyuncs.com/pipindex/pipindex-0.1.3-py2.py3-none-any.whl --user\n    pipindex build $wheel_path\n    python3 -m pip install -U --user --extra-index-url file://${wheel_path}/simple ${pkg_name}\nelif [ -e \"$wheel_path\" ]; then\n    python3 -m pip install --user \"$wheel_path\"\nelif [ -d \"$src_dir\" ]; then\n    python3 -m pip install -e \"$src_dir\" --user\nelse\n    echo \"wheel not found: $wheel_path, src dir not found: $src_dir, continue anyway...\"\nfi\n"
  },
  {
    "path": "cmake/caches/ci/canary/cuda.cmake",
    "content": "set(BUILD_CUDA YES CACHE BOOL \"\")\nset(BUILD_GIT_VERSION YES CACHE BOOL \"\")\nset(BUILD_TESTING OFF CACHE BOOL \"\")\nset(BUILD_RDMA YES CACHE BOOL \"\")\nset(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE Release CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_CUDA_ARCHITECTURES \"61-real;70-real;75-real;80-real;86-real\" CACHE STRING \"\")\nset(CUDNN_STATIC OFF CACHE BOOL \"\")\nset(WITH_MLIR ON CACHE BOOL \"\")\nset(BUILD_CPP_API OFF CACHE BOOL \"\")\nset(CUDA_NVCC_THREADS_NUMBER 8 CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\n"
  },
  {
    "path": "cmake/caches/ci/cpu-asan-ubsan.cmake",
    "content": "set(BUILD_CUDA NO CACHE BOOL \"\")\nset(BUILD_GIT_VERSION YES CACHE BOOL \"\")\nset(BUILD_TESTING YES CACHE BOOL \"\")\nset(WITH_ONEDNN YES CACHE BOOL \"\")\nset(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(BUILD_CPP_API ON CACHE BOOL \"\")\nset(WITH_MLIR ON CACHE BOOL \"\")\nset(BUILD_FOR_CI ON CACHE BOOL \"\")\nset(BUILD_SHARED_LIBS ON CACHE BOOL \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(ENABLE_ASAN ON CACHE BOOL \"\")\nset(ENABLE_UBSAN OFF CACHE BOOL \"\")\n"
  },
  {
    "path": "cmake/caches/ci/cpu-tsan.cmake",
    "content": "set(BUILD_CUDA NO CACHE BOOL \"\")\nset(BUILD_GIT_VERSION YES CACHE BOOL \"\")\nset(BUILD_TESTING YES CACHE BOOL \"\")\nset(WITH_ONEDNN YES CACHE BOOL \"\")\nset(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(BUILD_CPP_API ON CACHE BOOL \"\")\nset(WITH_MLIR ON CACHE BOOL \"\")\nset(BUILD_FOR_CI ON CACHE BOOL \"\")\nset(BUILD_SHARED_LIBS ON CACHE BOOL \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(ENABLE_TSAN ON CACHE BOOL \"\")\n"
  },
  {
    "path": "cmake/caches/ci/cpu.cmake",
    "content": "set(BUILD_CUDA NO CACHE BOOL \"\")\nset(BUILD_NPU NO CACHE BOOL \"\")\nset(BUILD_MLU NO CACHE BOOL \"\")\nset(BUILD_GIT_VERSION YES CACHE BOOL \"\")\nset(BUILD_TESTING YES CACHE BOOL \"\")\nset(WITH_ONEDNN YES CACHE BOOL \"\")\nset(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE Release CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(BUILD_CPP_API ON CACHE BOOL \"\")\nset(WITH_MLIR ON CACHE BOOL \"\")\nset(BUILD_FOR_CI ON CACHE BOOL \"\")\nset(BUILD_SHARED_LIBS ON CACHE BOOL \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\n"
  },
  {
    "path": "cmake/caches/ci/cuda-xla.cmake",
    "content": "set(BUILD_CUDA YES CACHE BOOL \"\")\nset(BUILD_GIT_VERSION YES CACHE BOOL \"\")\nset(BUILD_TESTING YES CACHE BOOL \"\")\nset(BUILD_RDMA YES CACHE BOOL \"\")\nset(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE Release CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_ARCHITECTURES \"61;75\" CACHE STRING \"\")\nset(CUDNN_STATIC OFF CACHE BOOL \"\")\nset(RPC_BACKEND \"LOCAL\" CACHE STRING \"\")\nset(CUDA_NVCC_THREADS_NUMBER 8 CACHE STRING \"\")\n"
  },
  {
    "path": "cmake/caches/ci/cuda.cmake",
    "content": "set(BUILD_CUDA YES CACHE BOOL \"\")\nset(BUILD_GIT_VERSION YES CACHE BOOL \"\")\nset(BUILD_TESTING YES CACHE BOOL \"\")\nset(BUILD_RDMA YES CACHE BOOL \"\")\nset(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE Release CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_ARCHITECTURES \"75;86\" CACHE STRING \"\")\nset(CUDNN_STATIC ON CACHE BOOL \"\")\nset(WITH_MLIR ON CACHE BOOL \"\")\nset(BUILD_CPP_API ON CACHE BOOL \"\")\nset(CUDA_NVCC_THREADS_NUMBER 8 CACHE STRING \"\")\nset(BUILD_FOR_CI ON CACHE BOOL \"\")\nset(CMAKE_CXX_FLAGS\n    \"-Wno-unused-but-set-parameter -Wno-unused-variable -Wno-class-memaccess -Wno-cast-function-type -Wno-comment -Wno-reorder\"\n    CACHE STRING \"\")\n"
  },
  {
    "path": "cmake/caches/ci/gh-hosted/cpu-clang.cmake",
    "content": "set(CMAKE_C_COMPILER \"clang\" CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER \"clang++\" CACHE STRING \"\")\nset(CMAKE_EXE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_MODULE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_SHARED_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\nset(BUILD_CUDA NO CACHE BOOL \"\")\nset(BUILD_TESTING YES CACHE BOOL \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\n"
  },
  {
    "path": "cmake/caches/ci/gh-hosted/cpu-gcc.cmake",
    "content": "set(BUILD_CUDA NO CACHE BOOL \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\n"
  },
  {
    "path": "cmake/caches/ci/llvm/cuda-75-clang.cmake",
    "content": "set(CMAKE_C_COMPILER \"clang\" CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER \"clang++\" CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER \"clang++\" CACHE STRING \"\")\nset(CMAKE_EXE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_MODULE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_SHARED_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(WITH_MLIR YES CACHE BOOL \"\")\nset(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\nset(BUILD_CUDA YES CACHE BOOL \"\")\nset(CMAKE_CUDA_ARCHITECTURES \"75;52-real\" CACHE STRING \"\")\nset(BUILD_TESTING YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(RPC_BACKEND \"LOCAL\" CACHE STRING \"\")\nset(BUILD_HWLOC NO CACHE BOOL \"\")\n"
  },
  {
    "path": "cmake/caches/ci/profiler/cuda.cmake",
    "content": "set(BUILD_CUDA YES CACHE BOOL \"\")\nset(BUILD_GIT_VERSION YES CACHE BOOL \"\")\nset(BUILD_TESTING OFF CACHE BOOL \"\")\nset(BUILD_RDMA YES CACHE BOOL \"\")\nset(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE Release CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_CUDA_ARCHITECTURES \"61-real;70-real;75-real;80-real;86-real\" CACHE STRING \"\")\nset(CUDNN_STATIC OFF CACHE BOOL \"\")\nset(WITH_MLIR ON CACHE BOOL \"\")\nset(BUILD_PROFILER ON CACHE BOOL \"\")\nset(BUILD_CPP_API OFF CACHE BOOL \"\")\nset(CUDA_NVCC_THREADS_NUMBER 8 CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\n"
  },
  {
    "path": "cmake/caches/ci/release/cpu.cmake",
    "content": "set(BUILD_CUDA OFF CACHE BOOL \"\")\nset(BUILD_GIT_VERSION YES CACHE BOOL \"\")\nset(BUILD_TESTING OFF CACHE BOOL \"\")\nset(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE Release CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CUDNN_STATIC OFF CACHE BOOL \"\")\nset(WITH_MLIR ON CACHE BOOL \"\")\nset(BUILD_CPP_API OFF CACHE BOOL \"\")\nset(CUDA_NVCC_THREADS_NUMBER 8 CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_FLAGS\n    \"-Wno-unused-but-set-parameter -Wno-unused-variable -Wno-class-memaccess -Wno-cast-function-type -Wno-comment -Wno-reorder\"\n    CACHE STRING \"\")\n"
  },
  {
    "path": "cmake/caches/ci/release/cu118.cmake",
    "content": "set(BUILD_CUDA YES CACHE BOOL \"\")\nset(BUILD_GIT_VERSION YES CACHE BOOL \"\")\nset(BUILD_TESTING OFF CACHE BOOL \"\")\nset(BUILD_RDMA YES CACHE BOOL \"\")\nset(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE Release CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_CUDA_ARCHITECTURES \"70-real;80-real;86-real;89-real;90-real\" CACHE STRING \"\")\nset(CUDNN_STATIC OFF CACHE BOOL \"\")\nset(WITH_MLIR ON CACHE BOOL \"\")\nset(BUILD_CPP_API OFF CACHE BOOL \"\")\nset(CUDA_NVCC_THREADS_NUMBER 2 CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_FLAGS\n    \"-Wno-unused-but-set-parameter -Wno-unused-variable -Wno-class-memaccess -Wno-cast-function-type -Wno-comment -Wno-reorder\"\n    CACHE STRING \"\")\n"
  },
  {
    "path": "cmake/caches/ci/release/cuda.cmake",
    "content": "set(BUILD_CUDA YES CACHE BOOL \"\")\nset(BUILD_GIT_VERSION YES CACHE BOOL \"\")\nset(BUILD_TESTING OFF CACHE BOOL \"\")\nset(BUILD_RDMA YES CACHE BOOL \"\")\nset(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE Release CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CUDNN_STATIC OFF CACHE BOOL \"\")\nset(WITH_MLIR ON CACHE BOOL \"\")\nset(BUILD_CPP_API OFF CACHE BOOL \"\")\nset(CUDA_NVCC_THREADS_NUMBER 2 CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_FLAGS\n    \"-Wno-unused-but-set-parameter -Wno-unused-variable -Wno-class-memaccess -Wno-cast-function-type -Wno-comment -Wno-reorder\"\n    CACHE STRING \"\")\n"
  },
  {
    "path": "cmake/caches/ci/serving/cuda-75.cmake",
    "content": "set(BUILD_CUDA YES CACHE BOOL \"\")\nset(BUILD_TESTING YES CACHE BOOL \"\")\nset(BUILD_CPP_API YES CACHE BOOL \"\")\nset(WITH_MLIR YES CACHE BOOL \"\")\nset(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\n# uncomment only if you know what you are doing\n# set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE Release CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda CACHE STRING \"\")\nset(CUDNN_ROOT_DIR /usr/local/cudnn CACHE STRING \"\")\nset(CMAKE_CUDA_ARCHITECTURES \"75\" CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(CUDA_NVCC_THREADS_NUMBER 8 CACHE STRING \"\")\n"
  },
  {
    "path": "cmake/caches/ci/serving/openvino.cmake",
    "content": "set(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\n# uncomment only if you know what you are doing\n# set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL \"\")\nset(BUILD_CUDA NO CACHE BOOL \"\")\nset(BUILD_CPP_API ON CACHE BOOL \"\")\nset(BUILD_GIT_VERSION NO CACHE BOOL \"\")\nset(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL \"\")\nset(BUILD_HWLOC NO CACHE BOOL \"\")\nset(BUILD_TESTING ON CACHE BOOL \"\")\nset(WITH_MLIR YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE Release CACHE STRING \"\")\n# set(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(BUILD_HWLOC OFF CACHE BOOL \"\")\nset(WITH_ONEDNN OFF CACHE BOOL \"\")\nset(CMAKE_EXPORT_COMPILE_COMMANDS ON CACHE STRING \"\")\nset(CMAKE_EXE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_MODULE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_SHARED_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\n"
  },
  {
    "path": "cmake/caches/cn/cpu.cmake",
    "content": "set(BUILD_CUDA NO CACHE BOOL \"\")\nset(BUILD_NPU NO CACHE BOOL \"\")\nset(BUILD_MLU NO CACHE BOOL \"\")\nset(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\n"
  },
  {
    "path": "cmake/caches/cn/cuda.cmake",
    "content": "set(BUILD_CUDA YES CACHE BOOL \"\")\nset(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\n"
  },
  {
    "path": "cmake/caches/cn/fast/cpu-clang.cmake",
    "content": "set(CMAKE_C_COMPILER \"clang\" CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER \"clang++\" CACHE STRING \"\")\nset(CMAKE_EXE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_MODULE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_SHARED_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\n# uncomment only if you know what you are doing\n# set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL \"\")\nset(BUILD_CUDA NO CACHE BOOL \"\")\nset(BUILD_TESTING YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(BUILD_HWLOC OFF CACHE BOOL \"\")\n"
  },
  {
    "path": "cmake/caches/cn/fast/cpu.cmake",
    "content": "set(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\n# uncomment only if you know what you are doing\n# set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL \"\")\nset(BUILD_CUDA NO CACHE BOOL \"\")\nset(BUILD_TESTING YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(BUILD_HWLOC OFF CACHE BOOL \"\")\n"
  },
  {
    "path": "cmake/caches/cn/fast/cuda-61-clang.cmake",
    "content": "set(CMAKE_C_COMPILER \"clang\" CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER \"clang++\" CACHE STRING \"\")\nset(CMAKE_EXE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_MODULE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_SHARED_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\n# uncomment only if you know what you are doing\n# set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL \"\")\nset(BUILD_CUDA YES CACHE BOOL \"\")\nset(CMAKE_CUDA_ARCHITECTURES \"61\" CACHE STRING \"\")\nset(BUILD_TESTING YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(BUILD_HWLOC OFF CACHE BOOL \"\")\n"
  },
  {
    "path": "cmake/caches/cn/fast/cuda-61.cmake",
    "content": "set(BUILD_CUDA YES CACHE BOOL \"\")\nset(BUILD_TESTING YES CACHE BOOL \"\")\nset(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\n# uncomment only if you know what you are doing\n# set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_CUDA_ARCHITECTURES \"61\" CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(BUILD_HWLOC OFF CACHE BOOL \"\")\n"
  },
  {
    "path": "cmake/caches/cn/fast/cuda-75-clang.cmake",
    "content": "set(CMAKE_C_COMPILER \"clang\" CACHE STRING \"\")\nset(WITH_MLIR YES CACHE BOOL \"\")\nset(WITH_MLIR_CUDA_CODEGEN YES CACHE BOOL \"\")\nset(CMAKE_CXX_COMPILER \"clang++\" CACHE STRING \"\")\nset(CMAKE_EXE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_MODULE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_SHARED_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\n# uncomment only if you know what you are doing\n# set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL \"\")\nset(BUILD_CUDA YES CACHE BOOL \"\")\nset(CMAKE_CUDA_ARCHITECTURES \"75\" CACHE STRING \"\")\nset(BUILD_TESTING YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(BUILD_HWLOC OFF CACHE BOOL \"\")\n"
  },
  {
    "path": "cmake/caches/cn/fast/cuda-75.cmake",
    "content": "set(BUILD_CUDA YES CACHE BOOL \"\")\nset(BUILD_TESTING YES CACHE BOOL \"\")\nset(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\n# uncomment only if you know what you are doing\n# set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_CUDA_ARCHITECTURES \"75\" CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(BUILD_HWLOC OFF CACHE BOOL \"\")\n# uncomment these when necessary, otherwise it is for the demonstration purpose\n\n# set(CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda CACHE STRING \"\")\n# set(CUDNN_ROOT_DIR /usr/local/cudnn CACHE STRING \"\")\n\n# set(CMAKE_CUDA_HOST_COMPILER clang++ CACHE STRING \"\")\n# set(CMAKE_C_COMPILER \"clang\" CACHE STRING \"\")\n# set(CMAKE_CXX_COMPILER \"clang++\" CACHE STRING \"\")\n# set(CMAKE_EXE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\n# set(CMAKE_MODULE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\n# set(CMAKE_SHARED_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\n"
  },
  {
    "path": "cmake/caches/cn/fast/cuda-86.cmake",
    "content": "set(BUILD_CUDA YES CACHE BOOL \"\")\nset(BUILD_TESTING YES CACHE BOOL \"\")\nset(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\n# uncomment only if you know what you are doing\n# set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_CUDA_ARCHITECTURES \"86\" CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(BUILD_HWLOC OFF CACHE BOOL \"\")\n"
  },
  {
    "path": "cmake/caches/cn/fast/mlir-cpu.cmake",
    "content": "set(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\n# uncomment only if you know what you are doing\n# set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL \"\")\nset(BUILD_CUDA NO CACHE BOOL \"\")\nset(BUILD_GIT_VERSION NO CACHE BOOL \"\")\nset(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL \"\")\nset(BUILD_HWLOC NO CACHE BOOL \"\")\nset(BUILD_TESTING OFF CACHE BOOL \"\")\nset(WITH_MLIR YES CACHE BOOL \"\")\nset(WITH_MLIR_CUDA_CODEGEN NO CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(CMAKE_EXE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_MODULE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_SHARED_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(BUILD_HWLOC OFF CACHE BOOL \"\")\nset(WITH_ONEDNN OFF CACHE BOOL \"\")\n"
  },
  {
    "path": "cmake/caches/cn/fast/mlir-cuda-61.cmake",
    "content": "set(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\n# uncomment only if you know what you are doing\n# set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL \"\")\nset(BUILD_CUDA YES CACHE BOOL \"\")\nset(BUILD_GIT_VERSION NO CACHE BOOL \"\")\nset(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL \"\")\nset(BUILD_HWLOC NO CACHE BOOL \"\")\nset(BUILD_TESTING OFF CACHE BOOL \"\")\nset(WITH_MLIR YES CACHE BOOL \"\")\nset(WITH_MLIR_CUDA_CODEGEN YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_CUDA_ARCHITECTURES \"61-real\" CACHE STRING \"\")\nset(CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda CACHE STRING \"\")\nset(CUDNN_ROOT_DIR /usr/local/cudnn CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(CMAKE_C_COMPILER \"clang\" CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER \"clang++\" CACHE STRING \"\")\nset(CMAKE_EXE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_MODULE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_SHARED_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(BUILD_HWLOC OFF CACHE BOOL \"\")\n"
  },
  {
    "path": "cmake/caches/cn/fast/mlir-cuda-75.cmake",
    "content": "set(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\n# uncomment only if you know what you are doing\n# set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL \"\")\nset(BUILD_CUDA YES CACHE BOOL \"\")\nset(BUILD_GIT_VERSION NO CACHE BOOL \"\")\nset(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL \"\")\nset(BUILD_HWLOC NO CACHE BOOL \"\")\nset(BUILD_TESTING OFF CACHE BOOL \"\")\nset(WITH_MLIR YES CACHE BOOL \"\")\nset(WITH_MLIR_CUDA_CODEGEN YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_CUDA_ARCHITECTURES \"75\" CACHE STRING \"\")\nset(CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda CACHE STRING \"\")\nset(CUDNN_ROOT_DIR /usr/local/cudnn CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(CMAKE_EXE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_MODULE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_SHARED_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(BUILD_HWLOC OFF CACHE BOOL \"\")\nset(WITH_ONEDNN OFF CACHE BOOL \"\")\n"
  },
  {
    "path": "cmake/caches/cn/fast/mlir-cuda-80.cmake",
    "content": "set(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\n# uncomment only if you know what you are doing\n# set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL \"\")\nset(BUILD_CUDA YES CACHE BOOL \"\")\nset(BUILD_GIT_VERSION NO CACHE BOOL \"\")\nset(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL \"\")\nset(BUILD_HWLOC NO CACHE BOOL \"\")\nset(BUILD_TESTING OFF CACHE BOOL \"\")\nset(WITH_MLIR YES CACHE BOOL \"\")\nset(WITH_MLIR_CUDA_CODEGEN YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_CUDA_ARCHITECTURES \"80\" CACHE STRING \"\")\nset(CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda CACHE STRING \"\")\nset(CUDNN_ROOT_DIR /usr/local/cudnn CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(CMAKE_EXE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_MODULE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_SHARED_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CPU_THREADING_RUNTIME SEQ CACHE STRING\n                                    \"when using lld with TBB enabled, there will be linkage error\")\nset(BUILD_HWLOC OFF CACHE BOOL \"\")\nset(WITH_ONEDNN OFF CACHE BOOL \"\")\n"
  },
  {
    "path": "cmake/caches/cn/fast/mlir-cuda-86.cmake",
    "content": "set(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\n# uncomment only if you know what you are doing\n# set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL \"\")\nset(BUILD_CUDA YES CACHE BOOL \"\")\nset(BUILD_GIT_VERSION NO CACHE BOOL \"\")\nset(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL \"\")\nset(BUILD_HWLOC NO CACHE BOOL \"\")\nset(BUILD_TESTING OFF CACHE BOOL \"\")\nset(WITH_MLIR YES CACHE BOOL \"\")\nset(WITH_MLIR_CUDA_CODEGEN YES CACHE BOOL \"\")\nset(THIRD_PARTY_MIRROR aliyun CACHE STRING \"\")\nset(PIP_INDEX_MIRROR \"https://pypi.tuna.tsinghua.edu.cn/simple\" CACHE STRING \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\nset(CMAKE_GENERATOR Ninja CACHE STRING \"\")\nset(CMAKE_CUDA_ARCHITECTURES \"86\" CACHE STRING \"\")\nset(CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda CACHE STRING \"\")\nset(CUDNN_ROOT_DIR /usr/local/cudnn CACHE STRING \"\")\nset(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING \"\")\nset(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL \"\")\nset(CMAKE_EXE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_MODULE_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CMAKE_SHARED_LINKER_FLAGS_INIT \"-fuse-ld=lld\" CACHE STRING \"\")\nset(CPU_THREADING_RUNTIME SEQ CACHE STRING\n                                    \"when using lld with TBB enabled, there will be linkage error\")\nset(BUILD_HWLOC OFF CACHE BOOL \"\")\nset(WITH_ONEDNN OFF CACHE BOOL \"\")\n"
  },
  {
    "path": "cmake/caches/international/cpu.cmake",
    "content": "set(BUILD_CUDA NO CACHE BOOL \"\")\nset(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\n"
  },
  {
    "path": "cmake/caches/international/cuda.cmake",
    "content": "set(BUILD_CUDA YES CACHE BOOL \"\")\nset(BUILD_SHARED_LIBS YES CACHE BOOL \"\")\nset(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING \"\")\n"
  },
  {
    "path": "cmake/cuda.cmake",
    "content": "if(BUILD_CUDA)\n  if(DEFINED CUDA_TOOLKIT_ROOT_DIR)\n    message(WARNING \"CUDA_TOOLKIT_ROOT_DIR is deprecated, use CUDAToolkit_ROOT instead\")\n    set(CUDAToolkit_ROOT ${CUDA_TOOLKIT_ROOT_DIR})\n  endif(DEFINED CUDA_TOOLKIT_ROOT_DIR)\n  find_package(CUDAToolkit REQUIRED)\n  message(STATUS \"CUDAToolkit_FOUND: ${CUDAToolkit_FOUND}\")\n  message(STATUS \"CUDAToolkit_VERSION: ${CUDAToolkit_VERSION}\")\n  message(STATUS \"CUDAToolkit_VERSION_MAJOR: ${CUDAToolkit_VERSION_MAJOR}\")\n  message(STATUS \"CUDAToolkit_VERSION_MINOR: ${CUDAToolkit_VERSION_MINOR}\")\n  message(STATUS \"CUDAToolkit_VERSION_PATCH: ${CUDAToolkit_VERSION_PATCH}\")\n  message(STATUS \"CUDAToolkit_BIN_DIR: ${CUDAToolkit_BIN_DIR}\")\n  message(STATUS \"CUDAToolkit_INCLUDE_DIRS: ${CUDAToolkit_INCLUDE_DIRS}\")\n  message(STATUS \"CUDAToolkit_LIBRARY_DIR: ${CUDAToolkit_LIBRARY_DIR}\")\n  message(STATUS \"CUDAToolkit_LIBRARY_ROOT: ${CUDAToolkit_LIBRARY_ROOT}\")\n  message(STATUS \"CUDAToolkit_TARGET_DIR: ${CUDAToolkit_TARGET_DIR}\")\n  message(STATUS \"CUDAToolkit_NVCC_EXECUTABLE: ${CUDAToolkit_NVCC_EXECUTABLE}\")\n  if(CUDA_NVCC_GENCODES)\n    message(FATAL_ERROR \"CUDA_NVCC_GENCODES is deprecated, use CMAKE_CUDA_ARCHITECTURES instead\")\n  endif()\n  add_definitions(-DWITH_CUDA)\n\n  # NOTE: For some unknown reason, CUDAToolkit_VERSION may become empty when running cmake again\n  set(CUDA_VERSION ${CUDAToolkit_VERSION} CACHE STRING \"\")\n  if(NOT CUDA_VERSION)\n    message(FATAL_ERROR \"CUDA_VERSION empty\")\n  endif()\n  message(STATUS \"CUDA_VERSION: ${CUDA_VERSION}\")\n  if(CUDA_VERSION VERSION_GREATER_EQUAL \"11.0\")\n    set(CUDA_STATIC OFF CACHE BOOL \"\")\n  else()\n    set(CUDA_STATIC ON CACHE BOOL \"\")\n  endif()\n\n  if((NOT CUDA_STATIC) OR BUILD_SHARED_LIBS)\n    set(OF_CUDA_LINK_DYNAMIC_LIBRARY ON)\n  else()\n    set(OF_CUDA_LINK_DYNAMIC_LIBRARY OFF)\n  endif()\n\n  if(OF_CUDA_LINK_DYNAMIC_LIBRARY)\n    list(APPEND VENDOR_CUDA_LIBRARIES CUDA::cublas)\n    list(APPEND VENDOR_CUDA_LIBRARIES CUDA::curand)\n    list(APPEND VENDOR_CUDA_LIBRARIES CUDA::cusolver)\n    list(APPEND VENDOR_CUDA_LIBRARIES CUDA::cufft)\n    if(CUDA_VERSION VERSION_GREATER_EQUAL \"10.1\")\n      list(APPEND VENDOR_CUDA_LIBRARIES CUDA::cublasLt)\n    endif()\n    if(CUDA_VERSION VERSION_GREATER_EQUAL \"10.2\")\n      list(APPEND VENDOR_CUDA_LIBRARIES CUDA::nvjpeg)\n      list(APPEND VENDOR_CUDA_LIBRARIES CUDA::nppc)\n      list(APPEND VENDOR_CUDA_LIBRARIES CUDA::nppig)\n    endif()\n  else()\n    list(APPEND VENDOR_CUDA_LIBRARIES CUDA::cublas_static)\n    list(APPEND VENDOR_CUDA_LIBRARIES CUDA::curand_static)\n    list(APPEND VENDOR_CUDA_LIBRARIES CUDA::cufft_static)\n    list(APPEND VENDOR_CUDA_LIBRARIES CUDA::cusolver_static)\n    if(CUDA_VERSION VERSION_GREATER_EQUAL \"10.1\")\n      list(APPEND VENDOR_CUDA_LIBRARIES CUDA::cublasLt_static)\n    endif()\n    if(CUDA_VERSION VERSION_GREATER_EQUAL \"10.2\")\n      list(APPEND VENDOR_CUDA_LIBRARIES CUDA::nvjpeg_static)\n      list(APPEND VENDOR_CUDA_LIBRARIES CUDA::nppig_static)\n      # Must put nppc_static after nppig_static in CUDA 10.2\n      list(APPEND VENDOR_CUDA_LIBRARIES CUDA::nppc_static)\n      list(APPEND VENDOR_CUDA_LIBRARIES CUDA::culibos)\n    endif()\n  endif()\n  message(STATUS \"VENDOR_CUDA_LIBRARIES: ${VENDOR_CUDA_LIBRARIES}\")\n  # add a cache entry if want to use a ccache/sccache wrapped nvcc\n  set(CMAKE_CUDA_COMPILER ${CUDAToolkit_NVCC_EXECUTABLE} CACHE STRING \"\")\n  message(STATUS \"CMAKE_CUDA_COMPILER: ${CMAKE_CUDA_COMPILER}\")\n  
set(CMAKE_CUDA_STANDARD 17)\n  find_package(CUDNN REQUIRED)\n\n  # NOTE: if you need PTX with a version different from the produced PTX/binary, add the corresponding flags yourself\n  if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)\n    if(CUDA_VERSION VERSION_GREATER_EQUAL \"10.0\")\n      # T4, Quadro RTX xxxx, Txxxx, GeForce RTX 20xx, TITAN RTX\n      list(APPEND CMAKE_CUDA_ARCHITECTURES 75-real)\n    endif()\n\n    if(CUDA_VERSION VERSION_GREATER_EQUAL \"11.0\")\n      # A100\n      list(APPEND CMAKE_CUDA_ARCHITECTURES 80-real)\n    endif()\n\n    if(CUDA_VERSION VERSION_GREATER_EQUAL \"11.1\")\n      # GeForce RTX 30xx\n      list(APPEND CMAKE_CUDA_ARCHITECTURES 86-real)\n    endif()\n\n    if(CUDA_VERSION VERSION_GREATER_EQUAL \"11.8\")\n      # GeForce RTX 40xx\n      list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real)\n    endif()\n\n    if(CUDA_VERSION VERSION_GREATER_EQUAL \"12.0\")\n      # H100, H20\n      list(APPEND CMAKE_CUDA_ARCHITECTURES 90-real)\n    endif()\n  endif()\n\n  foreach(CUDA_ARCH ${CMAKE_CUDA_ARCHITECTURES})\n    if(CUDA_ARCH MATCHES \"^([0-9]+)\\\\-real$\")\n      list(APPEND CUDA_REAL_ARCHS_LIST ${CMAKE_MATCH_1})\n    elseif(CUDA_ARCH MATCHES \"^([0-9]+)$\")\n      list(APPEND CUDA_REAL_ARCHS_LIST ${CMAKE_MATCH_1})\n    endif()\n  endforeach()\n\n  enable_language(CUDA)\n  include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})\n  message(STATUS \"CMAKE_CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}\")\n  set(CUDA_SEPARABLE_COMPILATION OFF)\n\n  if(\"${CMAKE_CUDA_COMPILER_ID}\" STREQUAL \"NVIDIA\")\n    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL \"11.2\")\n      set(CUDA_NVCC_THREADS_NUMBER \"4\" CACHE STRING \"\")\n      list(APPEND CUDA_NVCC_FLAGS -t ${CUDA_NVCC_THREADS_NUMBER})\n    endif()\n    list(APPEND CUDA_NVCC_FLAGS \"-Xcompiler=-fno-strict-aliasing\")\n    message(STATUS \"CUDA_NVCC_FLAGS: \" ${CUDA_NVCC_FLAGS})\n    list(JOIN CUDA_NVCC_FLAGS \" \" CMAKE_CUDA_FLAGS)\n  endif()\nendif()\n"
  },
  {
    "path": "cmake/functional.cmake",
    "content": "function(GENERATE_FUNCTIONAL_API_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR)\n  set(YAML_FILE ${PROJECT_SOURCE_DIR}/oneflow/core/functional/functional_api.yaml)\n  set(GENERATED_API_DIR oneflow/core/functional)\n\n  list(APPEND SRCS ${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/functional_api.yaml.cpp)\n  list(APPEND HDRS ${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/functional_api.yaml.h)\n\n  if(BUILD_PYTHON)\n    set(GENERATED_PYBIND_DIR oneflow/api/python/functional)\n    list(APPEND PYBIND_SRCS\n         ${PROJECT_BINARY_DIR}/${GENERATED_PYBIND_DIR}/functional_api.yaml.pybind.cpp)\n  endif(BUILD_PYTHON)\n\n  if(BUILD_PYTHON)\n\n    add_custom_command(\n      OUTPUT \"${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/functional_api.yaml.cpp\"\n             \"${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/functional_api.yaml.h\"\n             \"${PROJECT_BINARY_DIR}/${GENERATED_PYBIND_DIR}/functional_api.yaml.pybind.cpp\"\n      COMMAND ${CMAKE_COMMAND} ARGS -E make_directory ${GENERATED_API_DIR}\n      COMMAND ${CMAKE_COMMAND} ARGS -E make_directory ${GENERATED_PYBIND_DIR}\n      COMMAND ${CODEGEN_PYTHON_EXECUTABLE} ARGS\n              ${PROJECT_SOURCE_DIR}/tools/functional/generate_functional_api.py --project_source_dir\n              ${PROJECT_SOURCE_DIR} --export_pybind\n      DEPENDS ${CODEGEN_PYTHON_EXECUTABLE}\n              ${PROJECT_SOURCE_DIR}/tools/functional/generate_functional_api.py\n              ${PROJECT_SOURCE_DIR}/tools/functional/generator.py ${YAML_FILE}\n      VERBATIM)\n\n  else() # build_python\n\n    add_custom_command(\n      OUTPUT \"${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/functional_api.yaml.cpp\"\n             \"${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/functional_api.yaml.h\"\n      COMMAND ${CMAKE_COMMAND} ARGS -E make_directory ${GENERATED_API_DIR}\n      COMMAND ${CODEGEN_PYTHON_EXECUTABLE} ARGS\n              ${PROJECT_SOURCE_DIR}/tools/functional/generate_functional_api.py --project_source_dir\n              ${PROJECT_SOURCE_DIR}\n      DEPENDS ${CODEGEN_PYTHON_EXECUTABLE}\n              ${PROJECT_SOURCE_DIR}/tools/functional/generate_functional_api.py\n              ${PROJECT_SOURCE_DIR}/tools/functional/generator.py ${YAML_FILE}\n      VERBATIM)\n\n  endif(BUILD_PYTHON)\n\n  set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE)\n  set(${SRCS} ${${SRCS}} PARENT_SCOPE)\n  set(${HDRS} ${${HDRS}} PARENT_SCOPE)\n\n  if(BUILD_PYTHON)\n    set_source_files_properties(${${PYBIND_SRCS}} PROPERTIES GENERATED TRUE)\n    set(${PYBIND_SRCS} ${${PYBIND_SRCS}} PARENT_SCOPE)\n  endif(BUILD_PYTHON)\n\nendfunction()\n\nfunction(GENERATE_FUNCTIONAL_TENSOR_API_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR)\n  set(YAML_FILE ${PROJECT_SOURCE_DIR}/oneflow/api/python/functional/tensor_api.yaml)\n  set(GENERATED_API_DIR oneflow/api/python/functional)\n  set(GENERATED_PYBIND_DIR oneflow/api/python/functional)\n\n  list(APPEND SRCS ${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/tensor_api.yaml.cpp)\n  list(APPEND HDRS ${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/tensor_api.yaml.h)\n  list(APPEND PYBIND_SRCS ${PROJECT_BINARY_DIR}/${GENERATED_PYBIND_DIR}/tensor_api.yaml.pybind.cpp)\n\n  add_custom_command(\n    OUTPUT \"${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/tensor_api.yaml.cpp\"\n           \"${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/tensor_api.yaml.h\"\n           \"${PROJECT_BINARY_DIR}/${GENERATED_PYBIND_DIR}/tensor_api.yaml.pybind.cpp\"\n    COMMAND ${CMAKE_COMMAND} ARGS -E make_directory ${GENERATED_API_DIR}\n    COMMAND ${CMAKE_COMMAND} 
ARGS -E make_directory ${GENERATED_PYBIND_DIR}\n    COMMAND ${CODEGEN_PYTHON_EXECUTABLE} ARGS\n            ${PROJECT_SOURCE_DIR}/tools/functional/generate_tensor_api.py --project_source_dir\n            ${PROJECT_SOURCE_DIR}\n    DEPENDS ${CODEGEN_PYTHON_EXECUTABLE}\n            ${PROJECT_SOURCE_DIR}/tools/functional/generate_tensor_api.py\n            ${PROJECT_SOURCE_DIR}/tools/functional/generator.py ${YAML_FILE}\n    VERBATIM)\n\n  set_source_files_properties(${${SRCS}} ${${HDRS}} ${${PYBIND_SRCS}} PROPERTIES GENERATED TRUE)\n  set(${SRCS} ${${SRCS}} PARENT_SCOPE)\n  set(${HDRS} ${${HDRS}} PARENT_SCOPE)\n  set(${PYBIND_SRCS} ${${PYBIND_SRCS}} PARENT_SCOPE)\n\nendfunction()\n\nfunction(GENERATE_FUNCTIONAL_DISPATCH_STATEFUL_OPS_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR)\n  set(YAML_FILE ${PROJECT_SOURCE_DIR}/oneflow/api/python/functional/dispatch_stateful_ops.yaml)\n  set(GENERATED_API_DIR oneflow/api/python/functional)\n  set(GENERATED_PYBIND_DIR oneflow/api/python/functional)\n\n  list(APPEND SRCS ${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/dispatch_stateful_ops.yaml.cpp)\n  list(APPEND HDRS ${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/dispatch_stateful_ops.yaml.h)\n  list(APPEND PYBIND_SRCS\n       ${PROJECT_BINARY_DIR}/${GENERATED_PYBIND_DIR}/dispatch_stateful_ops.yaml.pybind.cpp)\n\n  add_custom_command(\n    OUTPUT \"${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/dispatch_stateful_ops.yaml.cpp\"\n           \"${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/dispatch_stateful_ops.yaml.h\"\n           \"${PROJECT_BINARY_DIR}/${GENERATED_PYBIND_DIR}/dispatch_stateful_ops.yaml.pybind.cpp\"\n    COMMAND ${CMAKE_COMMAND} ARGS -E make_directory ${GENERATED_API_DIR}\n    COMMAND ${CMAKE_COMMAND} ARGS -E make_directory ${GENERATED_PYBIND_DIR}\n    COMMAND ${CODEGEN_PYTHON_EXECUTABLE} ARGS\n            ${PROJECT_SOURCE_DIR}/tools/functional/generate_dispatch_stateful_ops.py\n            --project_source_dir ${PROJECT_SOURCE_DIR}\n    DEPENDS ${CODEGEN_PYTHON_EXECUTABLE}\n            ${PROJECT_SOURCE_DIR}/tools/functional/generate_dispatch_stateful_ops.py\n            ${PROJECT_SOURCE_DIR}/tools/functional/generator.py ${YAML_FILE}\n    VERBATIM)\n\n  set_source_files_properties(${${SRCS}} ${${HDRS}} ${${PYBIND_SRCS}} PROPERTIES GENERATED TRUE)\n  set(${SRCS} ${${SRCS}} PARENT_SCOPE)\n  set(${HDRS} ${${HDRS}} PARENT_SCOPE)\n  set(${PYBIND_SRCS} ${${PYBIND_SRCS}} PARENT_SCOPE)\n\nendfunction()\n"
  },
  {
    "path": "cmake/git_version.cmake",
    "content": "cmake_minimum_required(VERSION 3.5)\nexecute_process(\n  COMMAND git describe --tags --always --dirty=-snapshot\n  WORKING_DIRECTORY ${OF_GIT_VERSION_ROOT}\n  OUTPUT_VARIABLE GIT_REV\n  ERROR_QUIET)\nif((\"${GIT_REV}\" STREQUAL \"\") OR (NOT BUILD_GIT_VERSION))\n  set(GIT_REV \"N/A\")\nelse()\n  string(STRIP \"${GIT_REV}\" GIT_REV)\nendif()\n\nset(VERSION_FILE_CONTENT\n    \"namespace oneflow {\\n\\\n\\n\\\nconst char* GetOneFlowGitVersion() {\\n\\\n  return \\\"${GIT_REV}\\\";\\n\\\n}\\n\\\n\\n\\\n}\\n\")\n\nif(EXISTS ${OF_GIT_VERSION_FILE})\n  file(READ ${OF_GIT_VERSION_FILE} VERSION_FILE_CONTENT_)\nelse()\n  set(VERSION_FILE_CONTENT_ \"\")\nendif()\n\nif(NOT \"${VERSION_FILE_CONTENT}\" STREQUAL \"${VERSION_FILE_CONTENT_}\")\n  file(WRITE ${OF_GIT_VERSION_FILE} \"${VERSION_FILE_CONTENT}\")\nendif()\n"
  },
  {
    "path": "cmake/oneflow-config.cmake",
    "content": "if(DEFINED ENV{ONEFLOW_INSTALL_PREFIX})\n  set(ONEFLOW_INSTALL_PREFIX $ENV{ONEFLOW_INSTALL_PREFIX})\nelse()\n  get_filename_component(CMAKE_CURRENT_LIST_DIR \"${CMAKE_CURRENT_LIST_FILE}\" PATH)\n  get_filename_component(ONEFLOW_INSTALL_PREFIX \"${CMAKE_CURRENT_LIST_DIR}/../\" ABSOLUTE)\nendif()\n\nset(ONEFLOW_INCLUDE_DIRS ${ONEFLOW_INSTALL_PREFIX}/include)\n\nfind_library(ONEFLOW_LIBRARY NAMES oneflow_cpp PATHS ${ONEFLOW_INSTALL_PREFIX}/lib REQUIRED)\n\nif(NOT TARGET OneFlow::liboneflow)\n  add_library(OneFlow::liboneflow INTERFACE IMPORTED)\n\n  set_property(TARGET OneFlow::liboneflow PROPERTY INTERFACE_LINK_LIBRARIES ${ONEFLOW_LIBRARY})\n  set_property(TARGET OneFlow::liboneflow PROPERTY INTERFACE_INCLUDE_DIRECTORIES\n                                                   ${ONEFLOW_INCLUDE_DIRS})\nendif()\n"
  },
  {
    "path": "cmake/oneflow.cmake",
    "content": "include(python)\n\nfunction(oneflow_add_executable)\n  add_executable(${ARGV})\n  set_compile_options_to_oneflow_target(${ARGV0})\nendfunction()\n\nfunction(oneflow_add_library)\n  add_library(${ARGV})\n  set_compile_options_to_oneflow_target(${ARGV0})\nendfunction()\n\n# source_group\nif(WIN32)\n  set(oneflow_platform \"windows\")\n  list(APPEND oneflow_platform_excludes \"linux\")\nelse()\n  set(oneflow_platform \"linux\")\n  list(APPEND oneflow_platform_excludes \"windows\")\nendif()\n\nfile(GLOB_RECURSE oneflow_all_hdr_to_be_expanded \"${PROJECT_SOURCE_DIR}/oneflow/core/*.e.h\"\n     \"${PROJECT_SOURCE_DIR}/oneflow/python/*.e.h\")\nforeach(oneflow_hdr_to_be_expanded ${oneflow_all_hdr_to_be_expanded})\n  file(RELATIVE_PATH of_ehdr_rel_path ${PROJECT_SOURCE_DIR} ${oneflow_hdr_to_be_expanded})\n  set(of_e_h_expanded \"${PROJECT_BINARY_DIR}/${of_ehdr_rel_path}.expanded.h\")\n  if(WIN32)\n    error(\"Expanding macro in WIN32 is not supported yet\")\n  else()\n    add_custom_command(\n      OUTPUT ${of_e_h_expanded}\n      COMMAND ${CMAKE_C_COMPILER} ARGS -E -I\"${PROJECT_SOURCE_DIR}\" -I\"${PROJECT_BINARY_DIR}\" -o\n              \"${of_e_h_expanded}\" \"${oneflow_hdr_to_be_expanded}\"\n      DEPENDS ${oneflow_hdr_to_be_expanded}\n      COMMENT \"Expanding macros in ${oneflow_hdr_to_be_expanded}\")\n    list(APPEND oneflow_all_hdr_expanded \"${of_e_h_expanded}\")\n  endif()\n  set_source_files_properties(${oneflow_all_hdr_expanded} PROPERTIES GENERATED TRUE)\nendforeach()\n\nfile(\n  GLOB_RECURSE\n  oneflow_all_src\n  \"${PROJECT_SOURCE_DIR}/oneflow/core/*.*\"\n  \"${PROJECT_SOURCE_DIR}/oneflow/user/*.*\"\n  \"${PROJECT_SOURCE_DIR}/oneflow/api/*.*\"\n  \"${PROJECT_SOURCE_DIR}/oneflow/maybe/*.*\"\n  \"${PROJECT_SOURCE_DIR}/oneflow/extension/*.*\")\n\nforeach(oneflow_single_file ${oneflow_all_src})\n  # Verify whether this file is for other platforms\n  set(exclude_this OFF)\n  set(group_this OFF)\n  foreach(oneflow_platform_exclude ${oneflow_platform_excludes})\n    string(FIND ${oneflow_single_file} ${oneflow_platform_exclude} platform_found)\n    if(NOT ${platform_found} EQUAL -1) # the ${oneflow_single_file} is for other platforms\n      set(exclude_this ON)\n    endif()\n  endforeach()\n  # If this file is for other platforms, just exclude it from current project\n  if(exclude_this)\n    continue()\n  endif()\n\n  if(\"${oneflow_single_file}\" MATCHES\n     \"^${PROJECT_SOURCE_DIR}/oneflow/(core|user|maybe)/.*\\\\.(h|hpp)$\")\n    if((NOT RPC_BACKEND MATCHES \"GRPC\") AND \"${oneflow_single_file}\" MATCHES\n                                            \"^${PROJECT_SOURCE_DIR}/oneflow/core/control/.*\")\n      # skip if GRPC not enabled\n    elseif(APPLE AND \"${oneflow_single_file}\" MATCHES\n                     \"^${PROJECT_SOURCE_DIR}/oneflow/core/comm_network/(epoll|ibverbs)/.*\")\n      # skip if macOS\n    else()\n      list(APPEND of_all_obj_cc ${oneflow_single_file})\n      set(group_this ON)\n    endif()\n  endif()\n\n  if(\"${oneflow_single_file}\" MATCHES \"^${PROJECT_SOURCE_DIR}/oneflow/(core|user)/.*\\\\.(cuh|cu)$\")\n    if(BUILD_CUDA)\n      list(APPEND of_all_obj_cc ${oneflow_single_file})\n    endif()\n    set(group_this ON)\n  endif()\n\n  if(\"${oneflow_single_file}\" MATCHES \"^${PROJECT_SOURCE_DIR}/oneflow/(core|user)/.*\\\\.proto$\")\n    list(APPEND of_all_proto ${oneflow_single_file})\n    #list(APPEND of_all_obj_cc ${oneflow_single_file})   # include the proto file in the project\n    set(group_this ON)\n  endif()\n\n  if(BUILD_PYTHON)\n\n    
if(\"${oneflow_single_file}\" MATCHES \"^${PROJECT_SOURCE_DIR}/oneflow/api/python/.*\\\\.(h|cpp)$\")\n      list(APPEND of_pybind_obj_cc ${oneflow_single_file})\n      set(group_this ON)\n    endif()\n\n    if(\"${oneflow_single_file}\" MATCHES \"^${PROJECT_SOURCE_DIR}/oneflow/extension/.*\\\\.(c|h|cpp)$\")\n      list(APPEND of_pyext_obj_cc ${oneflow_single_file})\n      set(group_this ON)\n    endif()\n  endif(BUILD_PYTHON)\n\n  if(\"${oneflow_single_file}\" MATCHES \"^${PROJECT_SOURCE_DIR}/oneflow/(core|user|maybe)/.*\\\\.cpp$\")\n    if(\"${oneflow_single_file}\" MATCHES\n       \"^${PROJECT_SOURCE_DIR}/oneflow/(core|user|maybe|thread)/.*_test\\\\.cpp$\")\n      # test file\n      list(APPEND of_all_test_cc ${oneflow_single_file})\n    elseif(APPLE AND \"${oneflow_single_file}\" MATCHES\n                     \"^${PROJECT_SOURCE_DIR}/oneflow/core/comm_network/(epoll|ibverbs)/.*\")\n      # skip if macOS\n    elseif(APPLE AND \"${oneflow_single_file}\" MATCHES\n                     \"^${PROJECT_SOURCE_DIR}/oneflow/core/transport/.*\")\n      # skip if macOS\n    elseif((NOT RPC_BACKEND MATCHES \"GRPC\") AND \"${oneflow_single_file}\" MATCHES\n                                                \"^${PROJECT_SOURCE_DIR}/oneflow/core/control.*\")\n      # skip if GRPC not enabled\n    else()\n      list(APPEND of_all_obj_cc ${oneflow_single_file})\n    endif()\n    set(group_this ON)\n  endif()\n  if(group_this)\n    file(RELATIVE_PATH oneflow_relative_file ${PROJECT_SOURCE_DIR}/oneflow/core/\n         ${oneflow_single_file})\n    get_filename_component(oneflow_relative_path ${oneflow_relative_file} PATH)\n    string(REPLACE \"/\" \"\\\\\" group_name ${oneflow_relative_path})\n    source_group(\"${group_name}\" FILES ${oneflow_single_file})\n  endif()\nendforeach()\n\n# clang format\nadd_custom_target(\n  of_format\n  COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_license_format.py -i\n          ${CMAKE_CURRENT_SOURCE_DIR}/oneflow --fix\n  COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_license_format.py -i\n          ${ONEFLOW_PYTHON_DIR} --fix --exclude=\"oneflow/include\" --exclude=\"oneflow/core\"\n  COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_clang_format.py --source_dir\n          ${CMAKE_CURRENT_SOURCE_DIR}/oneflow --fix --quiet\n  COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_py_format.py --source_dir\n          ${CMAKE_CURRENT_SOURCE_DIR}/python --fix\n  COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_clang_format.py\n          --source_dir ${CMAKE_CURRENT_SOURCE_DIR}/tools/oneflow-tblgen --fix --quiet)\n# clang tidy\nset(RUN_CLANG_TIDY_ARGS --build_dir ${CMAKE_BINARY_DIR})\nif(MAYBE_NEED_ERROR_MSG_CHECK)\n  list(APPEND RUN_CLANG_TIDY_ARGS --check-error-msg)\nendif()\nmessage(STATUS \"RUN_CLANG_TIDY_ARGS: ${RUN_CLANG_TIDY_ARGS}\")\nadd_custom_target(\n  of_tidy COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/ci/check/run_clang_tidy.py\n                  ${RUN_CLANG_TIDY_ARGS} DEPENDS of_git_version oneflow_deps of_functional_obj\n                                                 of_functional_tensor_obj)\n# generate version\nset(OF_GIT_VERSION_DIR ${CMAKE_CURRENT_BINARY_DIR}/of_git_version)\nset(OF_GIT_VERSION_FILE ${OF_GIT_VERSION_DIR}/version.cpp)\nset(OF_GIT_VERSION_DUMMY_FILE ${OF_GIT_VERSION_DIR}/_version.cpp)\nadd_custom_target(of_git_version_create_dir COMMAND ${CMAKE_COMMAND} -E make_directory\n                                                    
${OF_GIT_VERSION_DIR})\nadd_custom_command(\n  OUTPUT ${OF_GIT_VERSION_DUMMY_FILE}\n  COMMAND ${CMAKE_COMMAND} -DOF_GIT_VERSION_FILE=${OF_GIT_VERSION_FILE}\n          -DOF_GIT_VERSION_ROOT=${PROJECT_SOURCE_DIR} -DBUILD_GIT_VERSION=${BUILD_GIT_VERSION} -P\n          ${CMAKE_CURRENT_SOURCE_DIR}/cmake/git_version.cmake\n  DEPENDS of_git_version_create_dir)\nadd_custom_target(of_git_version DEPENDS ${OF_GIT_VERSION_DUMMY_FILE})\nset_source_files_properties(${OF_GIT_VERSION_FILE} PROPERTIES GENERATED TRUE)\nlist(APPEND of_all_obj_cc ${OF_GIT_VERSION_FILE})\n\nset(of_proto_python_dir \"${PROJECT_BINARY_DIR}/of_proto_python\")\n\n# proto obj lib\nadd_custom_target(make_pyproto_dir ALL COMMAND ${CMAKE_COMMAND} -E make_directory\n                                               ${of_proto_python_dir})\nforeach(proto_name ${of_all_proto})\n  file(RELATIVE_PATH proto_rel_name ${PROJECT_SOURCE_DIR} ${proto_name})\n  list(APPEND of_all_rel_protos ${proto_rel_name})\nendforeach()\n\nrelative_protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS ${PROJECT_SOURCE_DIR} ${of_all_rel_protos})\n\noneflow_add_library(of_protoobj SHARED ${PROTO_SRCS} ${PROTO_HDRS})\nadd_dependencies(of_protoobj make_pyproto_dir protobuf)\ntarget_link_libraries(of_protoobj protobuf_imported)\n\ninclude(functional)\ngenerate_functional_api_and_pybind11_cpp(FUNCTIONAL_GENERATED_SRCS FUNCTIONAL_GENERATED_HRCS\n                                         FUNCTIONAL_PYBIND11_SRCS ${PROJECT_SOURCE_DIR})\noneflow_add_library(of_functional_obj OBJECT ${FUNCTIONAL_GENERATED_SRCS}\n                    ${FUNCTIONAL_GENERATED_HRCS})\ntarget_link_libraries(of_functional_obj LLVMSupportWithHeader glog::glog fmt)\nadd_dependencies(of_functional_obj prepare_oneflow_third_party)\n\nif(BUILD_PYTHON)\n\n  generate_functional_tensor_api_and_pybind11_cpp(\n    FUNCTIONAL_TENSOR_GENERATED_SRCS FUNCTIONAL_TENSOR_GENERATED_HRCS\n    FUNCTIONAL_TENSOR_PYBIND11_SRCS ${PROJECT_SOURCE_DIR})\n\n  generate_functional_dispatch_stateful_ops_and_pybind11_cpp(\n    FUNCTIONAL_OPS_GENERATED_SRCS FUNCTIONAL_OPS_GENERATED_HRCS FUNCTIONAL_OPS_PYBIND11_SRCS\n    ${PROJECT_SOURCE_DIR})\n\n  oneflow_add_library(\n    of_functional_tensor_obj OBJECT ${FUNCTIONAL_TENSOR_GENERATED_SRCS}\n    ${FUNCTIONAL_TENSOR_GENERATED_HRCS} ${FUNCTIONAL_OPS_GENERATED_SRCS}\n    ${FUNCTIONAL_OPS_GENERATED_HRCS})\n  target_link_libraries(of_functional_tensor_obj LLVMSupportWithHeader glog::glog fmt)\n  add_dependencies(of_functional_tensor_obj prepare_oneflow_third_party)\n  target_include_directories(of_functional_tensor_obj PRIVATE ${Python_INCLUDE_DIRS}\n                                                              ${Python_NumPy_INCLUDE_DIRS})\n\n  set(PYBIND11_SRCS ${FUNCTIONAL_PYBIND11_SRCS} ${FUNCTIONAL_TENSOR_PYBIND11_SRCS}\n                    ${FUNCTIONAL_OPS_PYBIND11_SRCS})\n\nendif(BUILD_PYTHON)\n\ninclude_directories(${PROJECT_SOURCE_DIR}) # TO FIND: third_party/eigen3/..\ninclude_directories(${PROJECT_BINARY_DIR})\n\n# cc obj lib\noneflow_add_library(oneflow SHARED ${of_all_obj_cc})\n\nadd_dependencies(oneflow of_protoobj)\nadd_dependencies(oneflow of_functional_obj)\nadd_dependencies(oneflow of_op_schema)\nadd_dependencies(oneflow of_git_version)\n\nif(USE_CLANG_FORMAT)\n  add_dependencies(oneflow of_format)\nendif()\nif(USE_CLANG_TIDY)\n  add_dependencies(oneflow of_tidy)\nendif()\n\ntarget_compile_definitions(oneflow PRIVATE GOOGLE_LOGGING)\n\nset(ONEFLOW_TOOLS_DIR \"${PROJECT_BINARY_DIR}/tools\"\n    CACHE STRING \"dir to put binary for debugging and 
development\")\n\nset(CACHE_LLVM_MONO_REPO_URL_LIST\n    \"https://github.com/llvm/llvm-project/archive/c63522e6ba7782c335043893ae7cbd37eca24fe5.zip\"\n    \"https://github.com/llvm/llvm-project/archive/a0595f8c99a253c65f30a151337e7aadc19ee3a1.zip\"\n    \"https://github.com/llvm/llvm-project/archive/7eaa84eac3ba935d13f4267d3d533a6c3e1283ed.zip\"\n    \"https://github.com/llvm/llvm-project/archive/35e60f5de180aea55ed478298f4b40f04dcc57d1.zip\"\n    \"https://github.com/llvm/llvm-project/archive/6a9bbd9f20dcd700e28738788bb63a160c6c088c.zip\"\n    \"https://github.com/llvm/llvm-project/archive/32805e60c9de1f82887cd2af30d247dcabd2e1d3.zip\"\n    \"https://github.com/llvm/llvm-project/archive/6d6268dcbf0f48e43f6f9fe46b3a28c29ba63c7d.zip\"\n    \"https://github.com/llvm/llvm-project/archive/5c9a84960de2260f149ee15313998593255a78df.zip\"\n    \"https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-16.0.0-rc4.zip\"\n    \"https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-15.0.6.zip\"\n    \"https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-16.0.0.zip\"\n    \"https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-16.0.3.zip\")\n\nset(CACHE_LLVM_MONO_REPO_MD5_LIST\n    \"f2f17229cf21049663b8ef4f2b6b8062\"\n    \"6b7c6506d5922de9632c8ff012b2f945\"\n    \"e0ea669a9f0872d35bffda5ec6c5ac6f\"\n    \"241a333828bba1efa35aff4c4fc2ce87\"\n    \"075fbfdf06cb3f02373ea44971af7b03\"\n    \"e412dc61159b5e929b0c94e44b11feb2\"\n    \"1ccc00accc87a1a5d42a275d6e31cd8c\"\n    \"b64481eaca658a2ff4e3e193440d0f68\"\n    \"78172b0f67282e28956cd310612091fd\"\n    \"0c2a3196e656aaab7ca1c2ef21b6091c\"\n    \"2702b822b71c196a0cc9c8d821c069d7\"\n    \"334997b4879aba15d9323a732356cf2a\")\n\n# clean cache for last LLVM version\nif(\"${LLVM_MONO_REPO_URL}\" IN_LIST CACHE_LLVM_MONO_REPO_URL_LIST OR \"${LLVM_MONO_REPO_MD5}\" IN_LIST\n                                                                    CACHE_LLVM_MONO_REPO_MD5_LIST)\n  unset(LLVM_MONO_REPO_URL CACHE)\n  unset(LLVM_MONO_REPO_MD5 CACHE)\nendif()\nset(LLVM_MONO_REPO_URL\n    \"https://github.com/llvm/llvm-project/archive/c2ce2a509f74a85a3c0ef4b9d6d79fbacc7e8bdf.zip\"\n    CACHE STRING \"\")\nuse_mirror(VARIABLE LLVM_MONO_REPO_URL URL ${LLVM_MONO_REPO_URL})\nset(LLVM_MONO_REPO_MD5 \"25489a23c6fa971fcd0d1167a560bf0a\" CACHE STRING \"\")\nset(ONEFLOW_BUILD_ROOT_DIR \"${PROJECT_BINARY_DIR}\")\nadd_subdirectory(${PROJECT_SOURCE_DIR}/oneflow/ir)\nif(WITH_MLIR)\n  set(ONEFLOW_MLIR_LIBS -Wl,--no-as-needed MLIROneFlowExtension -Wl,--as-needed)\nendif()\n\nif(\"${LLVM_PROVIDER}\" STREQUAL \"install\")\n  get_property(LLVM_INSTALL_DIR GLOBAL PROPERTY LLVM_INSTALL_DIR)\n  check_variable_defined(LLVM_INSTALL_DIR)\n  find_library(LLVMSupportLib LLVMSupport PATHS ${LLVM_INSTALL_DIR}/lib REQUIRED)\n  add_library(LLVMSupportWithHeader UNKNOWN IMPORTED)\n  set_property(TARGET LLVMSupportWithHeader PROPERTY IMPORTED_LOCATION ${LLVMSupportLib})\nelse()\n  add_library(LLVMSupportWithHeader INTERFACE IMPORTED)\n  target_link_libraries(LLVMSupportWithHeader INTERFACE LLVMSupport)\nendif()\ncheck_variable_defined(LLVM_INCLUDE_DIRS)\nset_property(TARGET LLVMSupportWithHeader PROPERTY INTERFACE_INCLUDE_DIRECTORIES\n                                                   ${LLVM_INCLUDE_DIRS})\n\nlist(APPEND oneflow_third_party_libs LLVMSupportWithHeader)\n\n# for stack backtrace\nfind_package(BFD)\nif(BFD_FOUND)\n  add_definitions(-DBACKWARD_HAS_BFD=1)\n  list(APPEND oneflow_third_party_libs bfd::bfd)\nendif()\nfind_package(Unwind)\nif(Unwind_FOUND)\n  
add_definitions(-DBACKWARD_HAS_LIBUNWIND=1)\n  list(APPEND oneflow_third_party_libs unwind::unwind)\nendif()\nadd_definitions(-DONEFLOW_SOURCE_DIR=\"${PROJECT_SOURCE_DIR}\")\nadd_definitions(-DONEFLOW_BINARY_DIR=\"${PROJECT_BINARY_DIR}\")\n\ninclude(op_schema)\n\nget_property(EXTERNAL_TARGETS GLOBAL PROPERTY EXTERNAL_TARGETS)\n\nif(APPLE)\n  set(of_libs ${ALL_ARCHIVE_BEGIN} oneflow of_op_schema ${ALL_ARCHIVE_END})\n  target_link_libraries(oneflow of_protoobj of_functional_obj ${oneflow_third_party_libs})\nelseif(UNIX)\n  set(of_libs ${ALL_ARCHIVE_BEGIN} oneflow of_op_schema ${ALL_ARCHIVE_END} -ldl -lrt)\n  target_link_libraries(\n    oneflow\n    of_protoobj\n    of_functional_obj\n    ${oneflow_third_party_libs}\n    ${EXTERNAL_TARGETS}\n    -Wl,--no-whole-archive\n    -Wl,--as-needed\n    -ldl\n    -lrt)\n  if(BUILD_CUDA)\n    target_link_libraries(oneflow CUDA::cudart_static)\n  endif()\n  if(WITH_OMP)\n    if(OpenMP_CXX_FOUND)\n      target_link_libraries(oneflow OpenMP::OpenMP_CXX)\n    endif()\n  endif()\nelseif(WIN32)\n  set(of_libs oneflow of_protoobj of_functional_obj of_op_schema)\n  set(CMAKE_EXE_LINKER_FLAGS \"${CMAKE_EXE_LINKER_FLAGS} /WHOLEARCHIVE:oneflow\")\nendif()\n\nif(BUILD_CUDA)\n  string(JOIN \",\" CUDA_REAL_ARCHS ${CUDA_REAL_ARCHS_LIST})\n  set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/core/hardware/cuda_device_descriptor.cpp\n                              PROPERTIES COMPILE_FLAGS \"-DCUDA_REAL_ARCHS=\\\"${CUDA_REAL_ARCHS}\\\"\")\nendif()\n\nif(BUILD_NPU)\n  add_definitions(-DWITH_NPU)\nendif()\nmessage(STATUS \"BUILD_NPU: ${BUILD_NPU}\")\n\nif(BUILD_MLU)\n  add_definitions(-DWITH_MLU)\nendif()\nmessage(STATUS \"BUILD_MLU: ${BUILD_MLU}\")\n\nif(BUILD_CUDA AND WITH_CUTLASS)\n  if(CUDA_VERSION VERSION_GREATER_EQUAL \"10.1\")\n    add_definitions(-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1)\n  endif()\n\n  set_property(SOURCE ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/fused_attention_kernels.cu APPEND\n               PROPERTY INCLUDE_DIRECTORIES ${CUTLASS_INSTALL_DIR}/examples/xformers_fmha)\n  set_property(SOURCE ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/fused_glu_kernel.cu APPEND\n               PROPERTY INCLUDE_DIRECTORIES ${CUTLASS_INSTALL_DIR}/examples/45_dual_gemm)\n  if(\"${CMAKE_CUDA_COMPILER_ID}\" STREQUAL \"NVIDIA\")\n    set_property(\n      SOURCE\n        ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/fused_multi_head_attention_inference_kernel.cu\n      APPEND\n      PROPERTY COMPILE_OPTIONS \"--use_fast_math\")\n  endif()\nendif()\n\n# oneflow api common\nif(BUILD_PYTHON OR BUILD_CPP_API)\n  file(GLOB_RECURSE of_api_common_files ${PROJECT_SOURCE_DIR}/oneflow/api/common/*.h\n       ${PROJECT_SOURCE_DIR}/oneflow/api/common/*.cpp)\n  oneflow_add_library(of_api_common OBJECT ${of_api_common_files})\n  target_link_libraries(of_api_common oneflow)\n  if(WITH_MLIR)\n    target_link_libraries(of_api_common ${ALL_ARCHIVE_BEGIN} ${ONEFLOW_MLIR_LIBS}\n                          ${ALL_ARCHIVE_END})\n  endif()\nendif()\n\nif(BUILD_PYTHON)\n\n  # py ext lib\n  # This library should be static to make sure all python symbols are included in the final ext shared lib,\n  # so that it is safe to do wheel audits of multiple Python versions in parallel.\n  oneflow_add_library(of_pyext_obj STATIC ${of_pyext_obj_cc})\n  target_include_directories(of_pyext_obj PRIVATE ${Python_INCLUDE_DIRS}\n                                                  ${Python_NumPy_INCLUDE_DIRS})\n  target_link_libraries(of_pyext_obj oneflow pybind11::headers)\n  if(BUILD_SHARED_LIBS AND APPLE)\n    
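# presumably required because the macOS linker, unlike on Linux, rejects undefined symbols in shared libraries by default\n    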
target_link_libraries(of_pyext_obj ${Python3_LIBRARIES})\n  endif()\n  add_dependencies(of_pyext_obj oneflow)\n\n  pybind11_add_module(oneflow_internal ${PYBIND11_SRCS} ${of_pybind_obj_cc} ${PYBIND_REGISTRY_CC})\n  set_property(TARGET oneflow_internal APPEND PROPERTY BUILD_RPATH \"\\$ORIGIN/../nvidia/cublas/lib\")\n  set_property(TARGET oneflow_internal APPEND PROPERTY BUILD_RPATH \"\\$ORIGIN/../nvidia/cudnn/lib\")\n  set_property(TARGET oneflow_internal APPEND PROPERTY BUILD_RPATH \"\\$ORIGIN/../nvidia/nccl/lib\")\n  set_property(TARGET oneflow_internal APPEND PROPERTY BUILD_RPATH\n                                                       \"\\$ORIGIN/../nvidia/cusparse/lib\")\n  set_property(TARGET oneflow_internal APPEND PROPERTY BUILD_RPATH \"\\$ORIGIN/../nvidia/cufft/lib\")\n  set_compile_options_to_oneflow_target(oneflow_internal)\n  set_property(TARGET oneflow_internal PROPERTY CXX_VISIBILITY_PRESET \"default\")\n  add_dependencies(oneflow_internal of_functional_obj of_functional_tensor_obj of_op_schema)\n  set_target_properties(oneflow_internal PROPERTIES PREFIX \"_\")\n  set_target_properties(oneflow_internal PROPERTIES LIBRARY_OUTPUT_DIRECTORY\n                                                    \"${ONEFLOW_PYTHON_DIR}/oneflow\")\n  target_link_libraries(\n    oneflow_internal PRIVATE ${of_libs} of_functional_tensor_obj of_api_common\n                             ${oneflow_third_party_libs} of_pyext_obj glog::glog)\n  target_include_directories(oneflow_internal PRIVATE ${Python_INCLUDE_DIRS}\n                                                      ${Python_NumPy_INCLUDE_DIRS})\n\n  if(WITH_MLIR)\n    add_dependencies(check-oneflow oneflow_internal)\n  endif(WITH_MLIR)\n\n  set(gen_pip_args \"\")\n  if(BUILD_CUDA)\n    list(APPEND gen_pip_args --cuda=${CUDA_VERSION})\n  endif()\n\n  add_custom_target(\n    of_pyscript_copy ALL\n    COMMAND ${CMAKE_COMMAND} -E touch \"${of_proto_python_dir}/oneflow/core/__init__.py\"\n    COMMAND ${CMAKE_COMMAND} -E create_symlink \"${of_proto_python_dir}/oneflow/core\"\n            \"${ONEFLOW_PYTHON_DIR}/oneflow/core\"\n    COMMAND\n      ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/tools/generate_pip_version.py ${gen_pip_args}\n      --src=${PROJECT_SOURCE_DIR} --cmake_project_binary_dir=${PROJECT_BINARY_DIR}\n      --out=${ONEFLOW_PYTHON_DIR}/oneflow/version.py)\n\n  # source this file to add oneflow in PYTHONPATH\n  file(WRITE \"${PROJECT_BINARY_DIR}/source.sh\"\n       \"export PYTHONPATH=${ONEFLOW_PYTHON_DIR}:$PYTHONPATH\")\n\n  add_dependencies(of_pyscript_copy of_protoobj)\n\nendif(BUILD_PYTHON)\n\nif(BUILD_CPP_API)\n  file(GLOB_RECURSE of_cpp_api_files ${PROJECT_SOURCE_DIR}/oneflow/api/cpp/*.cpp\n       ${PROJECT_SOURCE_DIR}/oneflow/api/cpp/*.h)\n  list(FILTER of_cpp_api_files EXCLUDE REGEX \"oneflow/api/cpp/tests\")\n  oneflow_add_library(oneflow_cpp SHARED ${of_cpp_api_files})\n  set_target_properties(oneflow_cpp PROPERTIES ARCHIVE_OUTPUT_DIRECTORY \"${LIBONEFLOW_LIBRARY_DIR}\"\n                                               LIBRARY_OUTPUT_DIRECTORY \"${LIBONEFLOW_LIBRARY_DIR}\")\n  target_link_libraries(oneflow_cpp PRIVATE ${of_libs} of_api_common ${oneflow_third_party_libs})\nendif()\n\nfile(RELATIVE_PATH PROJECT_BINARY_DIR_RELATIVE ${PROJECT_SOURCE_DIR} ${PROJECT_BINARY_DIR})\n\nfunction(oneflow_add_test target_name)\n  cmake_parse_arguments(arg \"\" \"TEST_NAME;WORKING_DIRECTORY\" \"SRCS\" ${ARGN})\n  oneflow_add_executable(${target_name} ${arg_SRCS})\n  if(BUILD_CUDA)\n    target_link_libraries(${target_name} CUDA::cudart_static)\n  
endif()\n  set_target_properties(${target_name} PROPERTIES RUNTIME_OUTPUT_DIRECTORY\n                                                  \"${PROJECT_BINARY_DIR}/bin\")\n  add_test(NAME ${arg_TEST_NAME} COMMAND ${target_name} WORKING_DIRECTORY ${arg_WORKING_DIRECTORY})\n  set_tests_properties(\n    ${arg_TEST_NAME} PROPERTIES ENVIRONMENT\n                                \"HTTP_PROXY='';HTTPS_PROXY='';http_proxy='';https_proxy='';\")\nendfunction()\n\n# build test\nif(BUILD_TESTING)\n  if(of_all_test_cc)\n    oneflow_add_test(oneflow_testexe SRCS ${of_all_test_cc} TEST_NAME oneflow_test)\n    target_link_libraries(oneflow_testexe ${of_libs} ${oneflow_third_party_libs} glog::glog\n                          ${oneflow_test_libs})\n    if(WITH_MLIR)\n      target_link_libraries(oneflow_testexe ${ALL_ARCHIVE_BEGIN} MLIROneFlowExtension\n                            ${ALL_ARCHIVE_END})\n    endif()\n  endif()\n\n  if(BUILD_CPP_API)\n    file(GLOB_RECURSE cpp_api_test_files ${PROJECT_SOURCE_DIR}/oneflow/api/cpp/tests/*.cpp)\n    oneflow_add_test(\n      oneflow_cpp_api_testexe\n      SRCS\n      ${cpp_api_test_files}\n      TEST_NAME\n      oneflow_cpp_api_test\n      WORKING_DIRECTORY\n      ${PROJECT_SOURCE_DIR})\n    find_package(Threads REQUIRED)\n    target_link_libraries(oneflow_cpp_api_testexe oneflow_cpp ${oneflow_third_party_libs}\n                          ${oneflow_test_libs} Threads::Threads)\n  endif()\nendif()\n\n# build include\nadd_custom_target(of_include_copy ALL)\n\nif(BUILD_PYTHON)\n\n  add_dependencies(of_include_copy oneflow_internal of_pyscript_copy)\n  install(\n    DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/oneflow/core\n    DESTINATION ${ONEFLOW_INCLUDE_DIR}/oneflow\n    COMPONENT oneflow_py_include\n    EXCLUDE_FROM_ALL FILES_MATCHING\n    PATTERN *.h\n    PATTERN *.hpp)\n  install(\n    DIRECTORY ${CMAKE_SOURCE_DIR}/oneflow\n    DESTINATION ${ONEFLOW_INCLUDE_DIR}\n    COMPONENT oneflow_py_include\n    EXCLUDE_FROM_ALL FILES_MATCHING\n    REGEX \"oneflow/core/common/.+(h|hpp)$\"\n    REGEX \"oneflow/core/device/.+(h|hpp)$\"\n    REGEX \"oneflow/core/framework/.+(h|hpp)$\"\n    REGEX \"oneflow/core/kernel/util/.+(h|hpp)$\"\n    REGEX \"oneflow/core/persistence/.+(h|hpp)$\"\n    REGEX \"oneflow/core/ep/include/.+(h|hpp)$\"\n    REGEX \"oneflow/core/ep/common/.+(h|hpp)$\"\n    REGEX \"oneflow/core/ep/cpu/.+(h|hpp)$\"\n    REGEX \"oneflow/core/ep/cuda/.+(h|hpp)$\"\n    REGEX \"oneflow/core/job/.+(h|hpp)$\"\n    REGEX \"oneflow/core/intrusive/.+(h|hpp)$\"\n    REGEX \"oneflow/core/graph/boxing/.+(h|hpp)$\"\n    REGEX \"oneflow/core/vm/.+(h|hpp)$\"\n    REGEX \"oneflow/core/.+(proto)$\"\n    REGEX \"oneflow/user/.+(h|hpp)$\"\n    PATTERN \"oneflow/core/kernel/chain_kernel_observer.h\"\n    PATTERN \"oneflow/core/kernel/cuda_graph_support.h\"\n    PATTERN \"oneflow/core/kernel/new_kernel_util.h\"\n    PATTERN \"oneflow/core/kernel/kernel.h\"\n    PATTERN \"oneflow/core/kernel/kernel_context.h\"\n    PATTERN \"oneflow/core/kernel/kernel_observer.h\"\n    PATTERN \"oneflow/core/kernel/kernel_util.h\"\n    PATTERN \"oneflow/core/kernel/kernel_util.cuh\"\n    PATTERN \"oneflow/core/kernel/kernel_registration.h\"\n    PATTERN \"oneflow/core/common/symbol.h\"\n    PATTERN \"oneflow/core/register/blob.h\"\n    PATTERN \"oneflow/core/register/op_blob_arg_info.h\"\n    PATTERN \"oneflow/core/register/register.h\"\n    PATTERN \"oneflow/core/register/register_desc.h\"\n    PATTERN \"oneflow/core/register/register_manager.h\"\n    PATTERN \"oneflow/core/register/runtime_register_desc.h\"\n    
PATTERN \"oneflow/core/register/tensor_slice_view.h\"\n    PATTERN \"oneflow/core/ndarray/xpu_util.h\"\n    PATTERN \"oneflow/core/rpc/include/base.h\"\n    PATTERN \"oneflow/core/rpc/include/ctrl.h\"\n    PATTERN \"oneflow/core/rpc/include/global_process_ctx.h\"\n    PATTERN \"oneflow/core/control/ctrl_client.h\"\n    PATTERN \"oneflow/core/control/global_process_ctx.h\"\n    PATTERN \"oneflow/core/autograd/autograd_meta.h\"\n    PATTERN \"oneflow/core/register/blob_desc.h\"\n    PATTERN \"oneflow/core/operator/operator.h\"\n    PATTERN \"oneflow/core/operator/operator_util.h\"\n    PATTERN \"oneflow/core/operator/op_conf_util.h\"\n    PATTERN \"oneflow/core/graph/compute_task_node.h\"\n    PATTERN \"oneflow/core/graph/copy_task_node.h\"\n    PATTERN \"oneflow/core/graph/exec_graph.h\"\n    PATTERN \"oneflow/core/graph/graph.h\"\n    PATTERN \"oneflow/core/graph/node.h\"\n    PATTERN \"oneflow/core/graph/op_graph.h\"\n    PATTERN \"oneflow/core/graph/task_graph.h\"\n    PATTERN \"oneflow/core/graph/task_id.h\"\n    PATTERN \"oneflow/core/graph/task_id_generator.h\"\n    PATTERN \"oneflow/core/graph/task_node.h\"\n    PATTERN \"oneflow/core/graph/task_stream_index_manager.h\"\n    PATTERN \"oneflow/core/graph/stream_id.h\"\n    PATTERN \"oneflow/core/graph/stream_index_generator.h\"\n    PATTERN \"oneflow/core/graph/fake_consumed_regst_provider.h\"\n    PATTERN \"oneflow/core/graph/transport_task_node.h\"\n    PATTERN \"oneflow/core/thread/thread.h\"\n    PATTERN \"oneflow/core/thread/thread_manager.h\"\n    PATTERN \"oneflow/core/thread/thread_pool.h\"\n    PATTERN \"oneflow/core/thread/thread_runtime.h\"\n    PATTERN \"oneflow/core/thread/thread_runtime_factory.h\"\n    PATTERN \"oneflow/core/profiler/profiler.h\"\n    PATTERN \"oneflow/extension/stack/foreign_stack_getter.h\"\n    PATTERN \"oneflow/core/platform/include/pthread_fork.h\"\n    PATTERN \"oneflow/core/lazy/actor/actor.h\"\n    PATTERN \"oneflow/core/lazy/actor/actor_base.h\"\n    PATTERN \"oneflow/core/lazy/actor/actor_context.h\"\n    PATTERN \"oneflow/core/lazy/actor/actor_message.h\"\n    PATTERN \"oneflow/core/lazy/actor/actor_message_bus.h\"\n    PATTERN \"oneflow/core/lazy/actor/register_slot.h\"\n    PATTERN \"oneflow/core/lazy/stream_context/include/stream_context.h\"\n    PATTERN \"oneflow/core/memory/memory_allocator.h\"\n    PATTERN \"oneflow/core/memory/memory_case_util.h\"\n    PATTERN \"oneflow/core/memory/memory_zone.h\"\n    PATTERN \"oneflow/user/ops/convert_memory_format.h\"\n    PATTERN \"oneflow/api\" EXCLUDE\n    PATTERN \"oneflow/maybe\" EXCLUDE\n    PATTERN \"oneflow/core/graph_impl\" EXCLUDE\n    PATTERN \"oneflow/core/job_rewriter\" EXCLUDE\n    PATTERN \"oneflow/core/hardware\" EXCLUDE\n    PATTERN \"oneflow/core/stream\" EXCLUDE\n    PATTERN \"oneflow/core/functional\" EXCLUDE\n    PATTERN \"oneflow/core/boxing\" EXCLUDE\n    PATTERN \"oneflow/core/transport\" EXCLUDE\n    PATTERN \"oneflow/core/comm_network\" EXCLUDE\n    PATTERN \"oneflow/ir\" EXCLUDE)\n  add_custom_target(\n    install_oneflow_py_include\n    COMMAND \"${CMAKE_COMMAND}\" -DCMAKE_INSTALL_COMPONENT=oneflow_py_include -P\n            \"${CMAKE_BINARY_DIR}/cmake_install.cmake\" DEPENDS oneflow_internal)\n  add_custom_target(oneflow_py ALL)\n  add_dependencies(oneflow_py of_include_copy install_oneflow_py_include)\n\nendif(BUILD_PYTHON)\n\nif(BUILD_CPP_API)\n\n  set(LIBONEFLOW_DIR ${PROJECT_BINARY_DIR}/liboneflow_cpp)\n\n  install(\n    DIRECTORY oneflow/api/cpp/\n    COMPONENT oneflow_cpp_all\n    DESTINATION include/oneflow\n    
FILES_MATCHING\n    PATTERN \"*.h\"\n    PATTERN \"tests\" EXCLUDE)\n  set(LIBONEFLOW_THIRD_PARTY_DIRS)\n  checkdirandappendslash(DIR ${PROTOBUF_LIBRARY_DIR} OUTPUT PROTOBUF_LIBRARY_DIR_APPENDED)\n  list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${PROTOBUF_LIBRARY_DIR_APPENDED})\n  if(BUILD_CUDA)\n    checkdirandappendslash(DIR ${NCCL_LIBRARY_DIR} OUTPUT NCCL_LIBRARY_DIR_APPENDED)\n    list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${NCCL_LIBRARY_DIR_APPENDED})\n    checkdirandappendslash(DIR ${TRT_FLASH_ATTENTION_LIBRARY_DIR} OUTPUT\n                           TRT_FLASH_ATTENTION_LIBRARY_DIR_APPENDED)\n    list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${TRT_FLASH_ATTENTION_LIBRARY_DIR_APPENDED})\n    if(CUDA_VERSION VERSION_GREATER_EQUAL \"11.7\")\n      checkdirandappendslash(DIR ${FLASH_ATTENTION_LIBRARY_DIR} OUTPUT\n                             FLASH_ATTENTION_LIBRARY_DIR_APPENDED)\n      list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${FLASH_ATTENTION_LIBRARY_DIR_APPENDED})\n    endif()\n    if(WITH_CUTLASS)\n      checkdirandappendslash(DIR ${CUTLASS_LIBRARY_DIR} OUTPUT CUTLASS_LIBRARY_DIR_APPENDED)\n      list(APPEND LIBONEFLOW_THIRD_PARTY_DIRS ${CUTLASS_LIBRARY_DIR_APPENDED})\n    endif()\n  endif()\n\n  install(\n    DIRECTORY ${LIBONEFLOW_THIRD_PARTY_DIRS}\n    COMPONENT oneflow_cpp_all\n    DESTINATION lib\n    FILES_MATCHING\n    PATTERN \"*.so*\"\n    PATTERN \"*.a\" EXCLUDE\n    PATTERN \"libprotobuf-lite.so*\" EXCLUDE\n    PATTERN \"libprotoc.so*\" EXCLUDE\n    PATTERN \"cmake\" EXCLUDE\n    PATTERN \"pkgconfig\" EXCLUDE)\n\n  install(FILES ${PROJECT_SOURCE_DIR}/cmake/oneflow-config.cmake COMPONENT oneflow_cpp_all\n          DESTINATION share)\n\n  get_property(MLIR_RELATED_TARGETS GLOBAL PROPERTY MLIR_EXPORTS)\n  get_property(LLVM_RELATED_TARGETS GLOBAL PROPERTY LLVM_EXPORTS)\n\n  list(\n    REMOVE_ITEM\n    LLVM_RELATED_TARGETS\n    count\n    not\n    FileCheck\n    lli-child-target\n    llvm-jitlink-executor\n    llvm-PerfectShuffle\n    llvm-tblgen\n    mlir-tblgen\n    mlir-pdll\n    obj2yaml\n    oneflow_tblgen\n    yaml-bench\n    yaml2obj)\n\n  set(LIBONEFLOW_TARGETS)\n  list(\n    APPEND\n    LIBONEFLOW_TARGETS\n    oneflow_cpp\n    oneflow\n    of_protoobj\n    glog\n    ${MLIR_RELATED_TARGETS}\n    ${LLVM_RELATED_TARGETS}\n    ${EXTERNAL_TARGETS})\n\n  if(BUILD_TESTING AND BUILD_SHARED_LIBS)\n    list(APPEND LIBONEFLOW_TARGETS gtest_main gtest)\n  endif()\n\n  if(BUILD_TESTING)\n    list(APPEND LIBONEFLOW_TARGETS oneflow_cpp_api_testexe)\n    list(APPEND LIBONEFLOW_TARGETS oneflow_testexe)\n  endif(BUILD_TESTING)\n\n  install(\n    TARGETS ${LIBONEFLOW_TARGETS}\n    COMPONENT oneflow_cpp_all\n    LIBRARY DESTINATION lib\n    ARCHIVE DESTINATION lib\n    RUNTIME DESTINATION bin)\n\n  add_custom_target(\n    install_oneflow_cpp\n    COMMAND \"${CMAKE_COMMAND}\" -DCMAKE_INSTALL_COMPONENT=oneflow_cpp_all\n            -DCMAKE_INSTALL_PREFIX=\"${LIBONEFLOW_DIR}\" -P \"${CMAKE_BINARY_DIR}/cmake_install.cmake\"\n    DEPENDS oneflow_cpp)\n  if(BUILD_TESTING)\n    add_dependencies(install_oneflow_cpp oneflow_cpp_api_testexe oneflow_testexe)\n  endif(BUILD_TESTING)\n  add_dependencies(of_include_copy install_oneflow_cpp)\n\n  string(TOLOWER ${CMAKE_SYSTEM_NAME} CPACK_SYSTEM_NAME)\n  set(CPACK_GENERATOR ZIP)\n  set(CPACK_PACKAGE_DIRECTORY ${PROJECT_BINARY_DIR}/cpack)\n  set(CPACK_PACKAGE_NAME liboneflow)\n  # TODO: by Shenghang, unify python and c++ version generating and getting\n  set(CPACK_PACKAGE_VERSION ${ONEFLOW_CURRENT_VERSION})\n  set(CPACK_INSTALL_CMAKE_PROJECTS 
${PROJECT_BINARY_DIR};oneflow;oneflow_cpp_all;/)\n  include(CPack)\nendif(BUILD_CPP_API)\n"
  },
  {
    "path": "cmake/op_schema.cmake",
    "content": "get_property(LLVM_INSTALL_DIR GLOBAL PROPERTY LLVM_INSTALL_DIR)\nset(LLVM_INSTALL_DIR ${THIRD_PARTY_DIR}/llvm)\nset(LLVM_DIR ${LLVM_INSTALL_DIR}/lib/cmake/llvm)\nset(ONEFLOW_OP_GROUPS\n    \"ASSIGN\"\n    \"BINARY\"\n    \"BROADCAST\"\n    \"CONV\"\n    \"CROSS_ENTROPY\"\n    \"CUDA\"\n    \"DATASET\"\n    \"DETECTION\"\n    \"EAGER\"\n    \"FUSED\"\n    \"IDEMPOTENT\"\n    \"IDENTITY\"\n    \"IMAGE\"\n    \"INDICES\"\n    \"INVOLUTION\"\n    \"LOSS\"\n    \"MATH\"\n    \"MATMUL\"\n    \"MISC\"\n    \"NCCL\"\n    \"NORMALIZATION\"\n    \"OPTIMIZER\"\n    \"PADDING\"\n    \"PARALLEL_CAST\"\n    \"POOL\"\n    \"QUANTIZATION\"\n    \"REDUCE\"\n    \"RESHAPE\"\n    \"SCALAR\"\n    \"SOFTMAX\"\n    \"SUMMARY\"\n    \"TENSOR_BUFFER\"\n    \"TEST\"\n    \"TRIGONOMETRIC\"\n    \"UNARY\"\n    \"UPSAMPLE\"\n    \"ONE_EMBEDDING\"\n    \"LINEAR_ALGEBRA\"\n    \"SYSTEM\")\nif(WITH_MLIR)\n  list(APPEND ONEFLOW_OP_GROUPS \"MLIR_JIT\")\nendif(WITH_MLIR)\n\nforeach(OP_GROUP_NAME IN LISTS ONEFLOW_OP_GROUPS)\n  list(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS \"-DGET_ONEFLOW_${OP_GROUP_NAME}_OP_DEFINITIONS\")\nendforeach()\nlist(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS \"-DREMOVE_ONEFLOW_MLIR_ONLY_OP_DEFINITIONS\")\n\nset(GENERATED_OP_SCHEMA_DIR oneflow/core/framework)\nset(GENERATED_IR_INCLUDE_DIR oneflow/ir/include)\nset(SOURCE_IR_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/oneflow/ir/include)\nset(ONEFLOW_ODS ${SOURCE_IR_INCLUDE_DIR}/OneFlow/OneFlowOps.td)\n\nlist(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS \"-I${GENERATED_IR_INCLUDE_DIR}\")\nlist(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS \"-I${SOURCE_IR_INCLUDE_DIR}\")\nlist(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS \"-I${LLVM_INSTALL_DIR}/include\")\n\nset(GENERATED_OP_SCHEMA_H \"${GENERATED_OP_SCHEMA_DIR}/op_generated.h\")\nset(GENERATED_OP_SCHEMA_CPP \"${GENERATED_OP_SCHEMA_DIR}/op_generated.cpp\")\n\nset(ONEFLOW_TABLE_GEN_EXE ${LLVM_INSTALL_DIR}/bin/oneflow_tblgen)\nif(LLVM_PROVIDER STREQUAL \"in-tree\")\n  set(ONEFLOW_TABLE_GEN_TARGET oneflow_tblgen install-oneflow-tblgen install-mlir-headers)\nelseif(LLVM_PROVIDER STREQUAL \"install\")\n  set(ONEFLOW_TABLE_GEN_TARGET ${ONEFLOW_TABLE_GEN_EXE})\nendif()\n\nfile(GLOB_RECURSE ODS_FILES LIST_DIRECTORIES false \"${SOURCE_IR_INCLUDE_DIR}/*.td\")\nif(NOT ODS_FILES)\n  message(FATAL_ERROR \"ODS_FILES not found: ${ODS_FILES}\")\nendif()\nadd_custom_command(\n  OUTPUT ${GENERATED_OP_SCHEMA_H} ${GENERATED_OP_SCHEMA_CPP}\n  COMMAND ${CMAKE_COMMAND} ARGS -E make_directory ${GENERATED_OP_SCHEMA_DIR}\n  COMMAND ${ONEFLOW_TABLE_GEN_EXE} ARGS --gen-op-schema-h ${ONEFLOW_ODS}\n          ${ONEFLOW_SCHEMA_TABLEGEN_FLAGS} -o ${GENERATED_OP_SCHEMA_H}\n  COMMAND ${ONEFLOW_TABLE_GEN_EXE} ARGS --gen-op-schema-cpp ${ONEFLOW_ODS}\n          ${ONEFLOW_SCHEMA_TABLEGEN_FLAGS} --op-include ${GENERATED_OP_SCHEMA_H} -o\n          ${GENERATED_OP_SCHEMA_CPP}\n  DEPENDS ${ONEFLOW_TABLE_GEN_TARGET} ${ODS_FILES}\n  VERBATIM)\nset_source_files_properties(${GENERATED_OP_SCHEMA_H} ${GENERATED_OP_SCHEMA_CPP} PROPERTIES GENERATED\n                                                                                           TRUE)\n\noneflow_add_library(of_op_schema OBJECT ${GENERATED_OP_SCHEMA_H} ${GENERATED_OP_SCHEMA_CPP})\ntarget_link_libraries(of_op_schema LLVMSupportWithHeader glog::glog fmt)\nadd_dependencies(of_op_schema prepare_oneflow_third_party)\n"
  },
  {
    "path": "cmake/platform.cmake",
    "content": "if(WIN32)\n  set(CMAKE_BUILD_TYPE Debug)\n  add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11 -DCOMPILER_MSVC\n                  -D__VERSION__=\\\"MSVC\\\")\n  add_definitions(\n    -DWIN32\n    -DOS_WIN\n    -D_MBCS\n    -DWIN64\n    -DWIN32_LEAN_AND_MEAN\n    -DNOGDI\n    -DPLATFORM_WINDOWS\n    -D_ITERATOR_DEBUG_LEVEL=0)\n  add_definitions(\n    /bigobj\n    /nologo\n    /EHsc\n    /GF\n    /FC\n    /MP\n    /Gm-)\n  add_definitions(-DGOOGLE_GLOG_DLL_DECL=)\n  set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} /MP\")\n\n  foreach(\n    flag_var\n    CMAKE_C_FLAGS\n    CMAKE_C_FLAGS_DEBUG\n    CMAKE_C_FLAGS_RELEASE\n    CMAKE_CXX_FLAGS\n    CMAKE_CXX_FLAGS_DEBUG\n    CMAKE_CXX_FLAGS_RELEASE\n    CMAKE_CXX_FLAGS_MINSIZEREL\n    CMAKE_CXX_FLAGS_RELWITHDEBINFO)\n    if(${flag_var} MATCHES \"/MD\")\n      string(REGEX REPLACE \"/MD\" \"/MT\" ${flag_var} \"${${flag_var}}\")\n    endif()\n  endforeach()\n\n  # set(CMAKE_EXE_LINKER_FLAGS_DEBUG \"${CMAKE_EXE_LINKER_FLAGS} /DEBUG:FASTLINK\")\n  set(CMAKE_CXX_FLAGS_DEBUG \"${CMAKE_CXX_FLAGS_DEBUG} /D_ITERATOR_DEBUG_LEVEL=0\")\nelse()\n  set(EXTRA_CXX_FLAGS \"-Wall -Wno-sign-compare -Wno-unused-function -fPIC\")\n\n  if(APPLE)\n    set(EXTRA_CXX_FLAGS \"${EXTRA_CXX_FLAGS} -Wno-deprecated-declarations\")\n  endif()\n\n  set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} ${EXTRA_CXX_FLAGS}\")\n  set(CMAKE_CXX_FLAGS_DEBUG \"${CMAKE_CXX_FLAGS_DEBUG} ${EXTRA_CXX_FLAGS}\")\n  set(CMAKE_CXX_FLAGS_RELEASE \"${CMAKE_CXX_FLAGS_RELEASE} ${EXTRA_CXX_FLAGS}\")\n  set(CMAKE_CXX_FLAGS_RELWITHDEBINFO \"${CMAKE_CXX_FLAGS_RELWITHDEBINFO} ${EXTRA_CXX_FLAGS}\")\nendif(WIN32)\n"
  },
  {
    "path": "cmake/proto2cpp.cmake",
    "content": "function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR)\n  if(NOT ARGN)\n    message(SEND_ERROR \"Error: RELATIVE_PROTOBUF_GENERATE_CPP() called without any proto files\")\n    return()\n  endif()\n\n  set(${SRCS})\n  set(${HDRS})\n\n  foreach(FIL ${ARGN})\n    set(ABS_FIL ${ROOT_DIR}/${FIL})\n    get_filename_component(FIL_WE ${FIL} NAME_WE)\n    get_filename_component(FIL_DIR ${ABS_FIL} PATH)\n    file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR})\n\n    list(APPEND ${SRCS} \"${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc\")\n    list(APPEND ${HDRS} \"${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h\")\n\n    add_custom_command(\n      OUTPUT \"${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc\"\n             \"${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h\"\n             \"${of_proto_python_dir}/${REL_DIR}/${FIL_WE}_pb2.py\"\n      COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} ARGS --cpp_out ${CMAKE_CURRENT_BINARY_DIR} -I\n              ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIR}\n      COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} ARGS --python_out ${of_proto_python_dir} -I ${ROOT_DIR}\n              ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIR}\n      COMMAND ${CMAKE_COMMAND} ARGS -E touch ${of_proto_python_dir}/${REL_DIR}/__init__.py\n      DEPENDS ${ABS_FIL} protobuf\n      COMMENT \"Running Protocol Buffer Compiler on ${FIL}\"\n      VERBATIM)\n  endforeach()\n\n  set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE)\n  set(${SRCS} ${${SRCS}} PARENT_SCOPE)\n  set(${HDRS} ${${HDRS}} PARENT_SCOPE)\nendfunction()\n"
  },
  {
    "path": "cmake/pybind11.cmake",
    "content": "include(FetchContent)\n\nset_mirror_url_with_hash(PYBIND11_URL https://github.com/pybind/pybind11/archive/v2.11.1.zip\n                         c62d9e05243bd31cdb3bae1bb2f56655)\n\nFetchContent_Declare(pybind11 URL ${PYBIND11_URL} URL_HASH MD5=${PYBIND11_URL_HASH})\n\nFetchContent_MakeAvailable(pybind11)\n"
  },
  {
    "path": "cmake/python.cmake",
    "content": "if(NOT DEFINED Python3_EXECUTABLE)\n  execute_process(\n    COMMAND which python3\n    RESULT_VARIABLE STATUS\n    OUTPUT_VARIABLE OUTPUT\n    ERROR_QUIET)\n  if(STATUS EQUAL 0)\n    string(STRIP ${OUTPUT} STRIPPED)\n    message(STATUS \"Using Python3 from 'which python3': ${STRIPPED}\")\n    set(Python3_EXECUTABLE ${STRIPPED})\n  endif()\nendif()\nfind_package(Python3 COMPONENTS Interpreter REQUIRED)\nmessage(STATUS \"Python3 specified. Version found: \" ${Python3_VERSION})\nset(Python_EXECUTABLE ${Python3_EXECUTABLE})\nmessage(STATUS \"Using Python executable: \" ${Python_EXECUTABLE})\n\nmessage(STATUS \"Installing necessary Python packages...\")\nset(requirements_txt ${PROJECT_SOURCE_DIR}/dev-requirements.txt)\nset_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${requirements_txt})\nmessage(STATUS \"PIP_INDEX_MIRROR: ${PIP_INDEX_MIRROR}\")\nif(PIP_INDEX_MIRROR)\n  set(extra_index_arg \"-i\")\nendif()\n\nfunction(install_py_dev_deps)\n  execute_process(COMMAND ${ARGV0} -m pip install ${extra_index_arg} ${PIP_INDEX_MIRROR} -r\n                          ${requirements_txt} --user RESULT_VARIABLE PIP_INSTALL_STATUS)\n  if(NOT PIP_INSTALL_STATUS EQUAL 0)\n    message(FATAL_ERROR \"fail to install pip packages\")\n  endif()\n  message(STATUS \"Python packages are installed.\")\nendfunction(install_py_dev_deps)\ninstall_py_dev_deps(${Python_EXECUTABLE})\n\nfind_package(Python3 COMPONENTS Development NumPy)\nif(Python3_Development_FOUND AND Python3_INCLUDE_DIRS)\n  set(Python_INCLUDE_DIRS ${Python3_INCLUDE_DIRS})\nendif()\nif(Python3_NumPy_FOUND AND Python3_NumPy_INCLUDE_DIRS)\n  set(Python_NumPy_INCLUDE_DIRS ${Python3_NumPy_INCLUDE_DIRS})\nendif()\nif(NOT Python_INCLUDE_DIRS)\n  message(STATUS \"Getting python include directory from sysconfig..\")\n  execute_process(\n    COMMAND ${Python_EXECUTABLE} -c \"import sysconfig; print(sysconfig.get_paths()['include'])\"\n    OUTPUT_VARIABLE Python_INCLUDE_DIRS RESULT_VARIABLE ret_code)\n  string(STRIP \"${Python_INCLUDE_DIRS}\" Python_INCLUDE_DIRS)\n  if((NOT (ret_code EQUAL \"0\")) OR (NOT IS_DIRECTORY ${Python_INCLUDE_DIRS})\n     OR (NOT EXISTS ${Python_INCLUDE_DIRS}/Python.h))\n    set(Python_INCLUDE_DIRS \"\")\n  endif()\nendif()\nif(NOT Python_INCLUDE_DIRS)\n  message(FATAL_ERROR \"Cannot find python include directory\")\nendif()\nmessage(STATUS \"Found python include directory ${Python_INCLUDE_DIRS}\")\n\nif(NOT Python_NumPy_INCLUDE_DIRS)\n  message(STATUS \"Getting numpy include directory by numpy.get_include()..\")\n  execute_process(COMMAND ${Python_EXECUTABLE} -c \"import numpy; print(numpy.get_include())\"\n                  OUTPUT_VARIABLE Python_NumPy_INCLUDE_DIRS RESULT_VARIABLE ret_code)\n  string(STRIP \"${Python_NumPy_INCLUDE_DIRS}\" Python_NumPy_INCLUDE_DIRS)\n  if((NOT ret_code EQUAL 0) OR (NOT IS_DIRECTORY ${Python_NumPy_INCLUDE_DIRS})\n     OR (NOT EXISTS ${Python_NumPy_INCLUDE_DIRS}/numpy/arrayobject.h))\n    set(Python_NumPy_INCLUDE_DIRS \"\")\n  endif()\nendif()\nif(NOT Python_NumPy_INCLUDE_DIRS)\n  message(FATAL_ERROR \"Cannot find numpy include directory\")\nendif()\nmessage(STATUS \"Found numpy include directory ${Python_NumPy_INCLUDE_DIRS}\")\n\n# PYTHON_EXECUTABLE will be used by pybind11\nset(PYTHON_EXECUTABLE ${Python_EXECUTABLE})\ninclude(pybind11)\n\nset(CODEGEN_PYTHON_EXECUTABLE ${Python_EXECUTABLE}\n    CACHE STRING \"Python executable to generate .cpp/.h files\")\nif(NOT \"${CODEGEN_PYTHON_EXECUTABLE}\" STREQUAL \"${Python_EXECUTABLE}\")\n  
install_py_dev_deps(${CODEGEN_PYTHON_EXECUTABLE})\nendif()\n"
  },
  {
    "path": "cmake/third_party/FindBFD.cmake",
    "content": "# - BFD Library module.\n#=============================================================================\n# This module finds libbfd and associated headers.\n#\n#=== Variables ===============================================================\n# This module will set the following variables in your project:\n#\n#   BFD_FOUND            Whether libbfd was successfully found.\n#   bfd::bfd             Cmake target for bfd\n#\n#=============================================================================\n\ninclude(FindPackageHandleStandardArgs)\n\nset(CMAKE_LIBRARY_PATH /lib /usr/lib /usr/local/lib)\nset(CMAKE_INCLUDE_PATH /usr/include /usr/local/include)\n\nfind_path(BFD_INCLUDE_PATH bfd.h PATH /usr/include /usr/local/include)\nfind_library(BFD_LIBRARIES bfd PATH /lib /usr/lib /usr/local/lib)\n\nfind_package_handle_standard_args(BFD DEFAULT_MSG BFD_LIBRARIES BFD_INCLUDE_PATH)\n\nif(BFD_FOUND)\n  if(NOT TARGET bfd::bfd)\n    add_library(bfd::bfd INTERFACE IMPORTED)\n    set_property(TARGET bfd::bfd PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${BFD_INCLUDE_PATH})\n    set_property(TARGET bfd::bfd PROPERTY INTERFACE_LINK_LIBRARIES ${BFD_LIBRARIES})\n    set_property(TARGET bfd::bfd PROPERTY IMPORTED_CONFIGURATIONS RELEASE)\n  endif(NOT TARGET bfd::bfd)\nendif()\n\nmark_as_advanced(BFD_INCLUDE_PATH BFD_LIBRARIES)\n"
  },
  {
    "path": "cmake/third_party/FindBLAS.cmake",
    "content": "#.rst:\n# FindBLAS\n# --------\n#\n# Find BLAS library\n#\n# This module finds an installed fortran library that implements the\n# BLAS linear-algebra interface (see http://www.netlib.org/blas/).  The\n# list of libraries searched for is taken from the autoconf macro file,\n# acx_blas.m4 (distributed at\n# http://ac-archive.sourceforge.net/ac-archive/acx_blas.html).\n#\n# This module sets the following variables:\n#\n# ::\n#\n#   BLAS_FOUND - set to true if a library implementing the BLAS interface\n#     is found\n#   BLAS_LINKER_FLAGS - uncached list of required linker flags (excluding -l\n#     and -L).\n#   BLAS_LIBRARIES - uncached list of libraries (using full path name) to\n#     link against to use BLAS\n#   BLAS95_LIBRARIES - uncached list of libraries (using full path name)\n#     to link against to use BLAS95 interface\n#   BLAS95_FOUND - set to true if a library implementing the BLAS f95 interface\n#     is found\n#   BLA_STATIC  if set on this determines what kind of linkage we do (static)\n#   BLA_VENDOR  if set checks only the specified vendor, if not set checks\n#      all the possibilities\n#   BLA_F95     if set on tries to find the f95 interfaces for BLAS/LAPACK\n#\n# ######### ## List of vendors (BLA_VENDOR) valid in this module #\n# Goto,OpenBLAS,ATLAS PhiPACK,CXML,DXML,SunPerf,SCSL,SGIMATH,IBMESSL,\n# Intel10_32 (intel mkl v10 32 bit),Intel10_64lp (intel mkl v10 64 bit,\n# lp thread model, lp64 model), # Intel10_64lp_seq (intel mkl v10 64\n# bit,sequential code, lp64 model), # Intel( older versions of mkl 32\n# and 64 bit), ACML,ACML_MP,ACML_GPU,Apple, NAS, Generic C/CXX should be\n# enabled to use Intel mkl\n\n#=============================================================================\n# Copyright 2007-2009 Kitware, Inc.\n#\n# Distributed under the OSI-approved BSD License (the \"License\");\n# see accompanying file Copyright.txt for details.\n#\n# This software is distributed WITHOUT ANY WARRANTY; without even the\n# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n# See the License for more information.\n#=============================================================================\n# (To distribute this file outside of CMake, substitute the full\n#  License text for the above reference.)\n\nset(CMAKE_REQUIRED_QUIET ${BLAS_FIND_QUIETLY})\n\nset(_blas_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES})\n\n# Check the language being used\nif(NOT (CMAKE_C_COMPILER_LOADED OR CMAKE_CXX_COMPILER_LOADED OR CMAKE_Fortran_COMPILER_LOADED))\n  if(BLAS_FIND_REQUIRED)\n    message(FATAL_ERROR \"FindBLAS requires Fortran, C, or C++ to be enabled.\")\n  else()\n    message(STATUS \"Looking for BLAS... - NOT found (Unsupported languages)\")\n    return()\n  endif()\nendif()\n\nmacro(\n  Check_Fortran_Libraries\n  LIBRARIES\n  _prefix\n  _name\n  _flags\n  _list\n  _thread)\n  # This macro checks for the existence of the combination of fortran libraries\n  # given by _list.  If the combination is found, this macro checks (using the\n  # Check_Fortran_Function_Exists macro) whether can link against that library\n  # combination using the name of a routine given by _name using the linker\n  # flags given by _flags.  If the combination of libraries is found and passes\n  # the link test, LIBRARIES is set to the list of complete library paths that\n  # have been found.  Otherwise, LIBRARIES is set to FALSE.\n\n  # N.B. 
_prefix is the prefix applied to the names of all cached variables that\n  # are generated internally and marked advanced by this macro.\n\n  set(_libdir ${ARGN})\n\n  set(_libraries_work TRUE)\n  set(${LIBRARIES})\n  set(_combined_name)\n  if(NOT _libdir)\n    if(WIN32)\n      set(_libdir ENV LIB)\n    elseif(APPLE)\n      set(_libdir ENV DYLD_LIBRARY_PATH)\n    else()\n      set(_libdir ENV LD_LIBRARY_PATH)\n    endif()\n  endif()\n\n  foreach(_library ${_list})\n    set(_combined_name ${_combined_name}_${_library})\n\n    if(_libraries_work)\n      if(BLA_STATIC)\n        if(WIN32)\n          set(CMAKE_FIND_LIBRARY_SUFFIXES .lib ${CMAKE_FIND_LIBRARY_SUFFIXES})\n        endif()\n        if(APPLE)\n          set(CMAKE_FIND_LIBRARY_SUFFIXES .lib ${CMAKE_FIND_LIBRARY_SUFFIXES})\n        else()\n          set(CMAKE_FIND_LIBRARY_SUFFIXES .a ${CMAKE_FIND_LIBRARY_SUFFIXES})\n        endif()\n      else()\n        if(CMAKE_SYSTEM_NAME STREQUAL \"Linux\")\n          # for ubuntu's libblas3gf and liblapack3gf packages\n          set(CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES} .so.3gf)\n        endif()\n      endif()\n      find_library(${_prefix}_${_library}_LIBRARY NAMES ${_library} PATHS ${_libdir})\n      mark_as_advanced(${_prefix}_${_library}_LIBRARY)\n      set(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY})\n      set(_libraries_work ${${_prefix}_${_library}_LIBRARY})\n    endif()\n  endforeach()\n  if(_libraries_work)\n    set(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}} ${_thread})\n    set(CMAKE_REQUIRED_LIBRARIES)\n    mark_as_advanced(${_prefix}${_combined_name}_WORKS)\n    set(_libraries_work ${${_prefix}${_combined_name}_WORKS})\n  endif()\nendmacro()\n\nset(BLAS_LINKER_FLAGS)\nset(BLAS_LIBRARIES)\nset(BLAS95_LIBRARIES)\nif(NOT $ENV{BLA_VENDOR} STREQUAL \"\")\n  set(BLA_VENDOR $ENV{BLA_VENDOR})\nelse()\n  if(NOT BLA_VENDOR)\n    set(BLA_VENDOR \"All\")\n  endif()\nendif()\n\nif(BLA_VENDOR STREQUAL \"Goto\" OR BLA_VENDOR STREQUAL \"All\")\n  if(NOT BLAS_LIBRARIES)\n    # gotoblas (http://www.tacc.utexas.edu/tacc-projects/gotoblas2)\n    Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm \"\" \"goto2\" \"\")\n  endif()\nendif()\n\nif(BLA_VENDOR STREQUAL \"OpenBLAS\" OR BLA_VENDOR STREQUAL \"All\")\n  if(NOT BLAS_LIBRARIES)\n    # OpenBLAS (http://www.openblas.net)\n    Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm \"\" \"openblas\" \"\")\n  endif()\nendif()\n\nif(BLA_VENDOR STREQUAL \"ATLAS\" OR BLA_VENDOR STREQUAL \"All\")\n  if(NOT BLAS_LIBRARIES)\n    # BLAS in ATLAS library? (http://math-atlas.sourceforge.net/)\n    Check_Fortran_Libraries(BLAS_LIBRARIES BLAS dgemm \"\" \"f77blas;atlas\" \"\")\n  endif()\nendif()\n\n# BLAS in PhiPACK libraries? (requires generic BLAS lib, too)\nif(BLA_VENDOR STREQUAL \"PhiPACK\" OR BLA_VENDOR STREQUAL \"All\")\n  if(NOT BLAS_LIBRARIES)\n    Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm \"\" \"sgemm;dgemm;blas\" \"\")\n  endif()\nendif()\n\n# BLAS in Alpha CXML library?\nif(BLA_VENDOR STREQUAL \"CXML\" OR BLA_VENDOR STREQUAL \"All\")\n  if(NOT BLAS_LIBRARIES)\n    Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm \"\" \"cxml\" \"\")\n  endif()\nendif()\n\n# BLAS in Alpha DXML library? 
(now called CXML, see above)\nif(BLA_VENDOR STREQUAL \"DXML\" OR BLA_VENDOR STREQUAL \"All\")\n  if(NOT BLAS_LIBRARIES)\n    Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm \"\" \"dxml\" \"\")\n  endif()\nendif()\n\n# BLAS in Sun Performance library?\nif(BLA_VENDOR STREQUAL \"SunPerf\" OR BLA_VENDOR STREQUAL \"All\")\n  if(NOT BLAS_LIBRARIES)\n    Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm \"-xlic_lib=sunperf\" \"sunperf;sunmath\" \"\")\n    if(BLAS_LIBRARIES)\n      set(BLAS_LINKER_FLAGS \"-xlic_lib=sunperf\")\n    endif()\n  endif()\nendif()\n\n# BLAS in SCSL library?  (SGI/Cray Scientific Library)\nif(BLA_VENDOR STREQUAL \"SCSL\" OR BLA_VENDOR STREQUAL \"All\")\n  if(NOT BLAS_LIBRARIES)\n    Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm \"\" \"scsl\" \"\")\n  endif()\nendif()\n\n# BLAS in SGIMATH library?\nif(BLA_VENDOR STREQUAL \"SGIMATH\" OR BLA_VENDOR STREQUAL \"All\")\n  if(NOT BLAS_LIBRARIES)\n    Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm \"\" \"complib.sgimath\" \"\")\n  endif()\nendif()\n\n# BLAS in IBM ESSL library? (requires generic BLAS lib, too)\nif(BLA_VENDOR STREQUAL \"IBMESSL\" OR BLA_VENDOR STREQUAL \"All\")\n  if(NOT BLAS_LIBRARIES)\n    Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm \"\" \"essl;blas\" \"\")\n  endif()\nendif()\n\n#BLAS in acml library?\nif(BLA_VENDOR MATCHES \"ACML\" OR BLA_VENDOR STREQUAL \"All\")\n  if(((BLA_VENDOR STREQUAL \"ACML\") AND (NOT BLAS_ACML_LIB_DIRS))\n     OR ((BLA_VENDOR STREQUAL \"ACML_MP\") AND (NOT BLAS_ACML_MP_LIB_DIRS))\n     OR ((BLA_VENDOR STREQUAL \"ACML_GPU\") AND (NOT BLAS_ACML_GPU_LIB_DIRS)))\n    # try to find acml in \"standard\" paths\n    if(WIN32)\n      file(GLOB _ACML_ROOT \"C:/AMD/acml*/ACML-EULA.txt\")\n    else()\n      file(GLOB _ACML_ROOT \"/opt/acml*/ACML-EULA.txt\")\n    endif()\n    if(WIN32)\n      file(GLOB _ACML_GPU_ROOT \"C:/AMD/acml*/GPGPUexamples\")\n    else()\n      file(GLOB _ACML_GPU_ROOT \"/opt/acml*/GPGPUexamples\")\n    endif()\n    list(GET _ACML_ROOT 0 _ACML_ROOT)\n    list(GET _ACML_GPU_ROOT 0 _ACML_GPU_ROOT)\n    if(_ACML_ROOT)\n      get_filename_component(_ACML_ROOT ${_ACML_ROOT} PATH)\n      if(SIZEOF_INTEGER EQUAL 8)\n        set(_ACML_PATH_SUFFIX \"_int64\")\n      else()\n        set(_ACML_PATH_SUFFIX \"\")\n      endif()\n      if(CMAKE_Fortran_COMPILER_ID STREQUAL \"Intel\")\n        set(_ACML_COMPILER32 \"ifort32\")\n        set(_ACML_COMPILER64 \"ifort64\")\n      elseif(CMAKE_Fortran_COMPILER_ID STREQUAL \"SunPro\")\n        set(_ACML_COMPILER32 \"sun32\")\n        set(_ACML_COMPILER64 \"sun64\")\n      elseif(CMAKE_Fortran_COMPILER_ID STREQUAL \"PGI\")\n        set(_ACML_COMPILER32 \"pgi32\")\n        if(WIN32)\n          set(_ACML_COMPILER64 \"win64\")\n        else()\n          set(_ACML_COMPILER64 \"pgi64\")\n        endif()\n      elseif(CMAKE_Fortran_COMPILER_ID STREQUAL \"Open64\")\n        # 32 bit builds not supported on Open64 but for code simplicity\n        # We'll just use the same directory twice\n        set(_ACML_COMPILER32 \"open64_64\")\n        set(_ACML_COMPILER64 \"open64_64\")\n      elseif(CMAKE_Fortran_COMPILER_ID STREQUAL \"NAG\")\n        set(_ACML_COMPILER32 \"nag32\")\n        set(_ACML_COMPILER64 \"nag64\")\n      else()\n        set(_ACML_COMPILER32 \"gfortran32\")\n        set(_ACML_COMPILER64 \"gfortran64\")\n      endif()\n\n      if(BLA_VENDOR STREQUAL \"ACML_MP\")\n        set(_ACML_MP_LIB_DIRS \"${_ACML_ROOT}/${_ACML_COMPILER32}_mp${_ACML_PATH_SUFFIX}/lib\"\n                              
\"${_ACML_ROOT}/${_ACML_COMPILER64}_mp${_ACML_PATH_SUFFIX}/lib\")\n      else()\n        set(_ACML_LIB_DIRS \"${_ACML_ROOT}/${_ACML_COMPILER32}${_ACML_PATH_SUFFIX}/lib\"\n                           \"${_ACML_ROOT}/${_ACML_COMPILER64}${_ACML_PATH_SUFFIX}/lib\")\n      endif()\n    endif()\n  elseif(BLAS_${BLA_VENDOR}_LIB_DIRS)\n    set(_${BLA_VENDOR}_LIB_DIRS ${BLAS_${BLA_VENDOR}_LIB_DIRS})\n  endif()\n\n  if(BLA_VENDOR STREQUAL \"ACML_MP\")\n    foreach(BLAS_ACML_MP_LIB_DIRS ${_ACML_MP_LIB_DIRS})\n      Check_Fortran_Libraries(\n        BLAS_LIBRARIES\n        BLAS\n        sgemm\n        \"\"\n        \"acml_mp;acml_mv\"\n        \"\"\n        ${BLAS_ACML_MP_LIB_DIRS})\n      if(BLAS_LIBRARIES)\n        break()\n      endif()\n    endforeach()\n  elseif(BLA_VENDOR STREQUAL \"ACML_GPU\")\n    foreach(BLAS_ACML_GPU_LIB_DIRS ${_ACML_GPU_LIB_DIRS})\n      Check_Fortran_Libraries(\n        BLAS_LIBRARIES\n        BLAS\n        sgemm\n        \"\"\n        \"acml;acml_mv;CALBLAS\"\n        \"\"\n        ${BLAS_ACML_GPU_LIB_DIRS})\n      if(BLAS_LIBRARIES)\n        break()\n      endif()\n    endforeach()\n  else()\n    foreach(BLAS_ACML_LIB_DIRS ${_ACML_LIB_DIRS})\n      Check_Fortran_Libraries(\n        BLAS_LIBRARIES\n        BLAS\n        sgemm\n        \"\"\n        \"acml;acml_mv\"\n        \"\"\n        ${BLAS_ACML_LIB_DIRS})\n      if(BLAS_LIBRARIES)\n        break()\n      endif()\n    endforeach()\n  endif()\n\n  # Either acml or acml_mp should be in LD_LIBRARY_PATH but not both\n  if(NOT BLAS_LIBRARIES)\n    Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm \"\" \"acml;acml_mv\" \"\")\n  endif()\n  if(NOT BLAS_LIBRARIES)\n    Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm \"\" \"acml_mp;acml_mv\" \"\")\n  endif()\n  if(NOT BLAS_LIBRARIES)\n    Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm \"\" \"acml;acml_mv;CALBLAS\" \"\")\n  endif()\nendif() # ACML\n\n# Apple BLAS library?\nif(BLA_VENDOR STREQUAL \"Apple\" OR BLA_VENDOR STREQUAL \"All\")\n  if(NOT BLAS_LIBRARIES)\n    Check_Fortran_Libraries(BLAS_LIBRARIES BLAS dgemm \"\" \"Accelerate\" \"\")\n  endif()\nendif()\n\nif(BLA_VENDOR STREQUAL \"NAS\" OR BLA_VENDOR STREQUAL \"All\")\n  if(NOT BLAS_LIBRARIES)\n    Check_Fortran_Libraries(BLAS_LIBRARIES BLAS dgemm \"\" \"vecLib\" \"\")\n  endif()\nendif()\n# Generic BLAS library?\nif(BLA_VENDOR STREQUAL \"Generic\" OR BLA_VENDOR STREQUAL \"All\")\n  if(NOT BLAS_LIBRARIES)\n    Check_Fortran_Libraries(BLAS_LIBRARIES BLAS sgemm \"\" \"blas\" \"\")\n  endif()\nendif()\n\n#BLAS in intel mkl 10 library? 
(em64t 64bit)\nif(BLA_VENDOR MATCHES \"Intel\" OR BLA_VENDOR STREQUAL \"All\")\n  if(NOT WIN32)\n    set(LM \"-lm\")\n  endif()\n  if(CMAKE_C_COMPILER_LOADED OR CMAKE_CXX_COMPILER_LOADED)\n    if(BLAS_FIND_QUIETLY OR NOT BLAS_FIND_REQUIRED)\n      find_package(Threads)\n    else()\n      find_package(Threads REQUIRED)\n    endif()\n\n    set(BLAS_SEARCH_LIBS \"\")\n\n    if(BLA_F95)\n      set(BLAS_mkl_SEARCH_SYMBOL SGEMM)\n      set(_LIBRARIES BLAS95_LIBRARIES)\n      if(WIN32)\n        if(BLA_STATIC)\n          set(BLAS_mkl_DLL_SUFFIX \"\")\n        else()\n          set(BLAS_mkl_DLL_SUFFIX \"_dll\")\n        endif()\n\n        # Find the main file (32-bit or 64-bit)\n        set(BLAS_SEARCH_LIBS_WIN_MAIN \"\")\n        if(BLA_VENDOR STREQUAL \"Intel10_32\" OR BLA_VENDOR STREQUAL \"All\")\n          list(APPEND BLAS_SEARCH_LIBS_WIN_MAIN\n               \"mkl_blas95${BLAS_mkl_DLL_SUFFIX} mkl_intel_c${BLAS_mkl_DLL_SUFFIX}\")\n        endif()\n        if(BLA_VENDOR MATCHES \"^Intel10_64lp\" OR BLA_VENDOR STREQUAL \"All\")\n          list(APPEND BLAS_SEARCH_LIBS_WIN_MAIN\n               \"mkl_blas95_lp64${BLAS_mkl_DLL_SUFFIX} mkl_intel_lp64${BLAS_mkl_DLL_SUFFIX}\")\n        endif()\n\n        # Add threading/sequential libs\n        set(BLAS_SEARCH_LIBS_WIN_THREAD \"\")\n        if(BLA_VENDOR MATCHES \"_seq$\" OR BLA_VENDOR STREQUAL \"All\")\n          list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD \"mkl_sequential${BLAS_mkl_DLL_SUFFIX}\")\n        endif()\n        if(NOT BLA_VENDOR MATCHES \"_seq$\" OR BLA_VENDOR STREQUAL \"All\")\n          # old version\n          list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD\n               \"libguide40 mkl_intel_thread${BLAS_mkl_DLL_SUFFIX}\")\n          # mkl >= 10.3\n          list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD\n               \"libiomp5md mkl_intel_thread${BLAS_mkl_DLL_SUFFIX}\")\n        endif()\n\n        # Cartesian product of the above\n        foreach(MAIN ${BLAS_SEARCH_LIBS_WIN_MAIN})\n          foreach(THREAD ${BLAS_SEARCH_LIBS_WIN_THREAD})\n            list(APPEND BLAS_SEARCH_LIBS \"${MAIN} ${THREAD} mkl_core${BLAS_mkl_DLL_SUFFIX}\")\n          endforeach()\n        endforeach()\n      else()\n        if(BLA_VENDOR STREQUAL \"Intel10_32\" OR BLA_VENDOR STREQUAL \"All\")\n          list(APPEND BLAS_SEARCH_LIBS \"mkl_blas95 mkl_intel mkl_intel_thread mkl_core guide\")\n        endif()\n        if(BLA_VENDOR STREQUAL \"Intel10_64lp\" OR BLA_VENDOR STREQUAL \"All\")\n          # old version\n          list(APPEND BLAS_SEARCH_LIBS \"mkl_blas95 mkl_intel_lp64 mkl_intel_thread mkl_core guide\")\n\n          # mkl >= 10.3\n          if(CMAKE_C_COMPILER MATCHES \".+gcc\")\n            list(APPEND BLAS_SEARCH_LIBS\n                 \"mkl_blas95_lp64 mkl_intel_lp64 mkl_gnu_thread mkl_core gomp\")\n          else()\n            list(APPEND BLAS_SEARCH_LIBS\n                 \"mkl_blas95_lp64 mkl_intel_lp64 mkl_intel_thread mkl_core iomp5\")\n          endif()\n        endif()\n        if(BLA_VENDOR STREQUAL \"Intel10_64lp_seq\" OR BLA_VENDOR STREQUAL \"All\")\n          list(APPEND BLAS_SEARCH_LIBS \"mkl_intel_lp64 mkl_sequential mkl_core\")\n        endif()\n      endif()\n    else()\n      set(BLAS_mkl_SEARCH_SYMBOL sgemm)\n      set(_LIBRARIES BLAS_LIBRARIES)\n      if(WIN32)\n        if(BLA_STATIC)\n          set(BLAS_mkl_DLL_SUFFIX \"\")\n        else()\n          set(BLAS_mkl_DLL_SUFFIX \"_dll\")\n        endif()\n\n        # Find the main file (32-bit or 64-bit)\n        set(BLAS_SEARCH_LIBS_WIN_MAIN \"\")\n        if(BLA_VENDOR STREQUAL \"Intel10_32\" OR 
BLA_VENDOR STREQUAL \"All\")\n          list(APPEND BLAS_SEARCH_LIBS_WIN_MAIN \"mkl_intel_c${BLAS_mkl_DLL_SUFFIX}\")\n        endif()\n        if(BLA_VENDOR MATCHES \"^Intel10_64lp\" OR BLA_VENDOR STREQUAL \"All\")\n          list(APPEND BLAS_SEARCH_LIBS_WIN_MAIN \"mkl_intel_lp64${BLAS_mkl_DLL_SUFFIX}\")\n        endif()\n\n        # Add threading/sequential libs\n        set(BLAS_SEARCH_LIBS_WIN_THREAD \"\")\n        if(NOT BLA_VENDOR MATCHES \"_seq$\" OR BLA_VENDOR STREQUAL \"All\")\n          # old version\n          list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD\n               \"libguide40 mkl_intel_thread${BLAS_mkl_DLL_SUFFIX}\")\n          # mkl >= 10.3\n          list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD\n               \"libiomp5md mkl_intel_thread${BLAS_mkl_DLL_SUFFIX}\")\n        endif()\n        if(BLA_VENDOR MATCHES \"_seq$\" OR BLA_VENDOR STREQUAL \"All\")\n          list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD \"mkl_sequential${BLAS_mkl_DLL_SUFFIX}\")\n        endif()\n\n        # Cartesian product of the above\n        foreach(MAIN ${BLAS_SEARCH_LIBS_WIN_MAIN})\n          foreach(THREAD ${BLAS_SEARCH_LIBS_WIN_THREAD})\n            list(APPEND BLAS_SEARCH_LIBS \"${MAIN} ${THREAD} mkl_core${BLAS_mkl_DLL_SUFFIX}\")\n          endforeach()\n        endforeach()\n      else()\n        if(BLA_VENDOR STREQUAL \"Intel10_32\" OR BLA_VENDOR STREQUAL \"All\")\n          list(APPEND BLAS_SEARCH_LIBS \"mkl_intel mkl_intel_thread mkl_core guide\")\n        endif()\n        if(BLA_VENDOR STREQUAL \"Intel10_64lp\" OR BLA_VENDOR STREQUAL \"All\")\n\n          # old version\n          list(APPEND BLAS_SEARCH_LIBS \"mkl_intel_lp64 mkl_intel_thread mkl_core guide\")\n\n          # mkl >= 10.3\n          if(CMAKE_C_COMPILER MATCHES \".+gcc\")\n            list(APPEND BLAS_SEARCH_LIBS \"mkl_intel_lp64 mkl_gnu_thread mkl_core gomp\")\n          else()\n            list(APPEND BLAS_SEARCH_LIBS \"mkl_intel_lp64 mkl_intel_thread mkl_core iomp5\")\n          endif()\n        endif()\n        if(BLA_VENDOR STREQUAL \"Intel10_64lp_seq\" OR BLA_VENDOR STREQUAL \"All\")\n          list(APPEND BLAS_SEARCH_LIBS \"mkl_intel_lp64 mkl_sequential mkl_core\")\n        endif()\n\n        # older versions of intel mkl libs\n        if(BLA_VENDOR STREQUAL \"Intel\" OR BLA_VENDOR STREQUAL \"All\")\n          list(APPEND BLAS_SEARCH_LIBS \"mkl\")\n          list(APPEND BLAS_SEARCH_LIBS \"mkl_ia32\")\n          list(APPEND BLAS_SEARCH_LIBS \"mkl_em64t\")\n        endif()\n      endif()\n    endif()\n\n    foreach(IT ${BLAS_SEARCH_LIBS})\n      string(REPLACE \" \" \";\" SEARCH_LIBS ${IT})\n      if(NOT ${_LIBRARIES})\n        Check_Fortran_Libraries(${_LIBRARIES} BLAS ${BLAS_mkl_SEARCH_SYMBOL} \"\" \"${SEARCH_LIBS}\"\n                                \"${CMAKE_THREAD_LIBS_INIT};${LM}\")\n      endif()\n    endforeach()\n\n  endif()\nendif()\n\nif(BLA_F95)\n  if(BLAS95_LIBRARIES)\n    set(BLAS95_FOUND TRUE)\n  else()\n    set(BLAS95_FOUND FALSE)\n  endif()\n\n  if(NOT BLAS_FIND_QUIETLY)\n    if(BLAS95_FOUND)\n      message(STATUS \"A library with BLAS95 API found.\")\n    else()\n      if(BLAS_FIND_REQUIRED)\n        message(\n          FATAL_ERROR\n            \"A required library with BLAS95 API not found. Please specify library location.\")\n      else()\n        message(STATUS \"A library with BLAS95 API not found. 
Please specify library location.\")\n      endif()\n    endif()\n  endif()\n  set(BLAS_FOUND TRUE)\n  set(BLAS_LIBRARIES \"${BLAS95_LIBRARIES}\")\nelse()\n  if(BLAS_LIBRARIES)\n    set(BLAS_FOUND TRUE)\n  else()\n    set(BLAS_FOUND FALSE)\n  endif()\n\n  if(NOT BLAS_FIND_QUIETLY)\n    if(BLAS_FOUND)\n      message(STATUS \"A library with BLAS API found.\")\n    else()\n      if(BLAS_FIND_REQUIRED)\n        message(\n          FATAL_ERROR \"A required library with BLAS API not found. Please specify library location.\"\n        )\n      else()\n        message(STATUS \"A library with BLAS API not found. Please specify library location.\")\n      endif()\n    endif()\n  endif()\nendif()\n\nset(CMAKE_FIND_LIBRARY_SUFFIXES ${_blas_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES})\n"
  },
  {
    "path": "cmake/third_party/FindCUDNN.cmake",
    "content": "# - Try to find cuDNN\n#\n# The following variables are optionally searched for defaults\n#  CUDNN_ROOT_DIR:            Base directory where all cuDNN components are found\n#\n# The following are set after configuration is done:\n#  CUDNN_FOUND\n#  CUDNN_INCLUDE_DIRS\n#  CUDNN_LIBRARIES\n#  CUDNN_LIBRARY_DIRS\n\ninclude(FindPackageHandleStandardArgs)\ninclude(CMakeDependentOption)\n\nset(CUDNN_ROOT_DIR \"\" CACHE PATH \"Folder contains NVIDIA cuDNN\")\n\nif(CUDA_VERSION VERSION_LESS \"11.0\")\n  set(CUDA_VERSION_VERSION_LESS_11 TRUE)\nendif()\n\ncmake_dependent_option(CUDNN_STATIC \"Look for static cuDNN\" ON \"CUDA_VERSION_VERSION_LESS_11\" OFF)\n\nif(OF_CUDA_LINK_DYNAMIC_LIBRARY)\n  set(CUDNN_STATIC OFF)\nendif()\nif(CUDNN_STATIC)\n  set(__cudnn_libname \"libcudnn_static.a\")\nelse()\n  set(__cudnn_libname \"libcudnn.so\")\nendif()\n\nfind_path(CUDNN_INCLUDE_DIR cudnn.h HINTS ${CUDNN_ROOT_DIR} ${CUDAToolkit_INCLUDE_DIRS}\n          PATH_SUFFIXES cuda/include include)\n\nunset(CUDNN_LIBRARY CACHE)\nfind_library(CUDNN_LIBRARY ${__cudnn_libname} HINTS ${CUDNN_ROOT_DIR} ${CUDAToolkit_LIBRARY_DIR}\n             PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)\n\nfind_package_handle_standard_args(CUDNN DEFAULT_MSG CUDNN_INCLUDE_DIR CUDNN_LIBRARY)\n\nif(CUDNN_FOUND)\n  # get cuDNN version\n  if(EXISTS \"${CUDNN_INCLUDE_DIR}/cudnn_version.h\")\n    file(READ ${CUDNN_INCLUDE_DIR}/cudnn_version.h CUDNN_HEADER_CONTENTS)\n  else()\n    file(READ ${CUDNN_INCLUDE_DIR}/cudnn.h CUDNN_HEADER_CONTENTS)\n  endif()\n  string(REGEX MATCH \"define CUDNN_MAJOR * +([0-9]+)\" CUDNN_VERSION_MAJOR\n               \"${CUDNN_HEADER_CONTENTS}\")\n  string(REGEX REPLACE \"define CUDNN_MAJOR * +([0-9]+)\" \"\\\\1\" CUDNN_VERSION_MAJOR\n                       \"${CUDNN_VERSION_MAJOR}\")\n  string(REGEX MATCH \"define CUDNN_MINOR * +([0-9]+)\" CUDNN_VERSION_MINOR\n               \"${CUDNN_HEADER_CONTENTS}\")\n  string(REGEX REPLACE \"define CUDNN_MINOR * +([0-9]+)\" \"\\\\1\" CUDNN_VERSION_MINOR\n                       \"${CUDNN_VERSION_MINOR}\")\n  string(REGEX MATCH \"define CUDNN_PATCHLEVEL * +([0-9]+)\" CUDNN_VERSION_PATCH\n               \"${CUDNN_HEADER_CONTENTS}\")\n  string(REGEX REPLACE \"define CUDNN_PATCHLEVEL * +([0-9]+)\" \"\\\\1\" CUDNN_VERSION_PATCH\n                       \"${CUDNN_VERSION_PATCH}\")\n  # Assemble cuDNN version\n  if(NOT CUDNN_VERSION_MAJOR)\n    set(CUDNN_VERSION \"?\")\n  else()\n    set(CUDNN_VERSION \"${CUDNN_VERSION_MAJOR}.${CUDNN_VERSION_MINOR}.${CUDNN_VERSION_PATCH}\")\n  endif()\n\n  set(CUDNN_INCLUDE_DIRS ${CUDNN_INCLUDE_DIR})\n\n  if(NOT CUDNN_STATIC AND CUDNN_VERSION_MAJOR GREATER_EQUAL 9)\n    # skipping: libcudnn_adv_infer.so libcudnn_adv_train.so\n    set(CUDNN_DYNAMIC_NAMES libcudnn_cnn.so libcudnn_ops.so)\n    get_filename_component(CUDNN_LIBRARY_DIRECTORY ${CUDNN_LIBRARY} DIRECTORY)\n    foreach(CUDNN_DYNAMIC_NAME ${CUDNN_DYNAMIC_NAMES})\n      list(APPEND CUDNN_LIBRARIES ${CUDNN_LIBRARY_DIRECTORY}/${CUDNN_DYNAMIC_NAME})\n    endforeach()\n  elseif(NOT CUDNN_STATIC AND CUDNN_VERSION_MAJOR GREATER_EQUAL 8)\n    # skipping: libcudnn_adv_infer.so libcudnn_adv_train.so\n    set(CUDNN_DYNAMIC_NAMES libcudnn_cnn_infer.so libcudnn_cnn_train.so libcudnn_ops_infer.so\n                            libcudnn_ops_train.so)\n    get_filename_component(CUDNN_LIBRARY_DIRECTORY ${CUDNN_LIBRARY} DIRECTORY)\n    foreach(CUDNN_DYNAMIC_NAME ${CUDNN_DYNAMIC_NAMES})\n      list(APPEND CUDNN_LIBRARIES ${CUDNN_LIBRARY_DIRECTORY}/${CUDNN_DYNAMIC_NAME})\n    endforeach()\n  
else()\n    set(CUDNN_LIBRARIES ${CUDNN_LIBRARY})\n  endif()\n  message(\n    STATUS\n      \"Found cuDNN: v${CUDNN_VERSION}  (include: ${CUDNN_INCLUDE_DIR}, library: ${CUDNN_LIBRARIES})\"\n  )\n  mark_as_advanced(CUDNN_ROOT_DIR CUDNN_LIBRARY CUDNN_INCLUDE_DIR)\nendif()\n"
  },
  {
    "path": "cmake/third_party/FindUnwind.cmake",
    "content": "# - Try to find libunwind\n# Once done this will define\n#\n#  Unwind_FOUND - system has libunwind\n#  unwind::unwind - cmake target for libunwind\n\ninclude(FindPackageHandleStandardArgs)\n\nfind_path(Unwind_INCLUDE_DIR NAMES unwind.h libunwind.h DOC \"unwind include directory\")\nfind_library(Unwind_LIBRARY NAMES unwind DOC \"unwind library\")\n\nmark_as_advanced(Unwind_INCLUDE_DIR Unwind_LIBRARY)\n\n# Extract version information\nif(Unwind_LIBRARY)\n  set(_Unwind_VERSION_HEADER ${Unwind_INCLUDE_DIR}/libunwind-common.h)\n\n  if(EXISTS ${_Unwind_VERSION_HEADER})\n    file(READ ${_Unwind_VERSION_HEADER} _Unwind_VERSION_CONTENTS)\n\n    string(REGEX REPLACE \".*#define UNW_VERSION_MAJOR[ \\t]+([0-9]+).*\" \"\\\\1\" Unwind_VERSION_MAJOR\n                         \"${_Unwind_VERSION_CONTENTS}\")\n    string(REGEX REPLACE \".*#define UNW_VERSION_MINOR[ \\t]+([0-9]+).*\" \"\\\\1\" Unwind_VERSION_MINOR\n                         \"${_Unwind_VERSION_CONTENTS}\")\n    string(REGEX REPLACE \".*#define UNW_VERSION_EXTRA[ \\t]+([0-9]+).*\" \"\\\\1\" Unwind_VERSION_PATCH\n                         \"${_Unwind_VERSION_CONTENTS}\")\n\n    set(Unwind_VERSION ${Unwind_VERSION_MAJOR}.${Unwind_VERSION_MINOR})\n\n    if(CMAKE_MATCH_0)\n      # Third version component may be empty\n      set(Unwind_VERSION ${Unwind_VERSION}.${Unwind_VERSION_PATCH})\n      set(Unwind_VERSION_COMPONENTS 3)\n    else(CMAKE_MATCH_0)\n      set(Unwind_VERSION_COMPONENTS 2)\n    endif(CMAKE_MATCH_0)\n  endif(EXISTS ${_Unwind_VERSION_HEADER})\nendif(Unwind_LIBRARY)\n\n# handle the QUIETLY and REQUIRED arguments and set Unwind_FOUND to TRUE\n# if all listed variables are TRUE\nfind_package_handle_standard_args(Unwind REQUIRED_VARS Unwind_INCLUDE_DIR Unwind_LIBRARY\n                                  VERSION_VAR Unwind_VERSION)\n\nif(Unwind_FOUND)\n  if(NOT TARGET unwind::unwind)\n    add_library(unwind::unwind INTERFACE IMPORTED)\n\n    set_property(TARGET unwind::unwind PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${Unwind_INCLUDE_DIR})\n    set_property(TARGET unwind::unwind PROPERTY INTERFACE_LINK_LIBRARIES ${Unwind_LIBRARY})\n    set_property(TARGET unwind::unwind PROPERTY IMPORTED_CONFIGURATIONS RELEASE)\n  endif(NOT TARGET unwind::unwind)\nendif(Unwind_FOUND)\n"
  },
  {
    "path": "cmake/third_party/absl.cmake",
    "content": "include(ExternalProject)\ninclude(GNUInstallDirs)\n\nset(ABSL_PROJECT absl)\nset(ABSL_TAR_URL https://github.com/abseil/abseil-cpp/archive/refs/tags/20230125.2.tar.gz)\nuse_mirror(VARIABLE ABSL_TAR_URL URL ${ABSL_TAR_URL})\nset(ABSL_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/absl/src/absl)\nset(ABSL_INSTALL ${THIRD_PARTY_DIR}/absl)\n\nset(ABSL_INCLUDE_DIR ${THIRD_PARTY_DIR}/absl/include CACHE PATH \"\" FORCE)\nset(ABSL_LIBRARY_DIR ${THIRD_PARTY_DIR}/absl/${CMAKE_INSTALL_LIBDIR} CACHE PATH \"\" FORCE)\n\nif(WIN32)\n  set(ABSL_BUILD_LIBRARY_DIR ${ABSL_INSTALL}/${CMAKE_INSTALL_LIBDIR})\n  set(ABSL_LIBRARY_NAMES\n      absl_spinlock_wait.lib\n      absl_malloc_internal.lib\n      absl_throw_delegate.lib\n      absl_int128.lib\n      absl_strings.lib\n      absl_str_format_internal.lib\n      absl_time.lib\n      absl_bad_optional_access.lib\n      absl_base.lib)\nelse()\n  set(ABSL_BUILD_LIBRARY_DIR ${ABSL_INSTALL}/${CMAKE_INSTALL_LIBDIR})\n  set(ABSL_LIBRARY_NAMES\n      libabsl_spinlock_wait.a\n      libabsl_malloc_internal.a\n      libabsl_throw_delegate.a\n      libabsl_int128.a\n      libabsl_strings.a\n      libabsl_str_format_internal.a\n      libabsl_time.a\n      libabsl_bad_optional_access.a\n      libabsl_base.a)\nendif()\n\nforeach(LIBRARY_NAME ${ABSL_LIBRARY_NAMES})\n  list(APPEND ABSL_STATIC_LIBRARIES ${ABSL_LIBRARY_DIR}/${LIBRARY_NAME})\n  list(APPEND ABSL_BUILD_STATIC_LIBRARIES ${ABSL_BUILD_LIBRARY_DIR}/${LIBRARY_NAME})\nendforeach()\n\nif(THIRD_PARTY)\n  ExternalProject_Add(\n    ${ABSL_PROJECT}\n    PREFIX absl\n    URL ${ABSL_TAR_URL}\n    URL_MD5 52b9786ca6fbc679869fee2b6fef25a5\n    UPDATE_COMMAND \"\"\n    BUILD_BYPRODUCTS ${ABSL_STATIC_LIBRARIES}\n    CMAKE_CACHE_ARGS\n      -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}\n      -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}\n      -DCMAKE_INSTALL_PREFIX:PATH=${ABSL_INSTALL}\n      -DCMAKE_INSTALL_LIBDIR:PATH=${ABSL_INSTALL}/${CMAKE_INSTALL_LIBDIR}\n      -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE}\n      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON\n      -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE})\nendif(THIRD_PARTY)\n"
  },
  {
    "path": "cmake/third_party/cares.cmake",
    "content": "include(ExternalProject)\nset(CARES_TAR_URL\n    https://github.com/c-ares/c-ares/releases/download/cares-1_15_0/c-ares-1.15.0.tar.gz)\nuse_mirror(VARIABLE CARES_TAR_URL URL ${CARES_TAR_URL})\nset(CARES_URL_HASH d2391da274653f7643270623e822dff7)\nset(CARES_INSTALL ${THIRD_PARTY_DIR}/cares)\nset(CARES_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cares/src/cares)\n\nif(THIRD_PARTY)\n  ExternalProject_Add(\n    cares\n    PREFIX cares\n    URL ${CARES_TAR_URL}\n    URL_HASH MD5=${CARES_URL_HASH}\n    UPDATE_COMMAND \"\"\n    CONFIGURE_COMMAND \"\"\n    BUILD_COMMAND \"\"\n    INSTALL_COMMAND \"\")\n\nendif()\n"
  },
  {
    "path": "cmake/third_party/cocoapi.cmake",
    "content": "include(ExternalProject)\n\nset(COCOAPI_INCLUDE_DIR ${THIRD_PARTY_DIR}/cocoapi/include)\nset(COCOAPI_LIBRARY_DIR ${THIRD_PARTY_DIR}/cocoapi/lib)\n\nset(COCOAPI_URL https://github.com/Oneflow-Inc/cocoapi/archive/refs/tags/ed842bf.tar.gz)\nuse_mirror(VARIABLE COCOAPI_URL URL ${COCOAPI_URL})\nset(COCOAPI_URL_HASH e7e0504231e5614ffaa34f081773f7f1)\nset(COCOAPI_BASE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cocoapi/src/cocoapi)\nset(COCOAPI_LIBRARY_NAME libcocoapi_static.a)\n\nlist(APPEND COCOAPI_STATIC_LIBRARIES ${COCOAPI_LIBRARY_DIR}/${COCOAPI_LIBRARY_NAME})\nlist(APPEND COCOAPI_BUILD_STATIC_LIBRARIES ${COCOAPI_BASE_DIR}/${COCOAPI_LIBRARY_NAME})\n\nset(COCOAPI_HEADERS \"${COCOAPI_BASE_DIR}/common/maskApi.h\")\n\nif(THIRD_PARTY)\n\n  ExternalProject_Add(\n    cocoapi\n    PREFIX cocoapi\n    URL ${COCOAPI_URL}\n    URL_HASH MD5=${COCOAPI_URL_HASH}\n    UPDATE_COMMAND \"\"\n    CONFIGURE_COMMAND \"\"\n    BUILD_IN_SOURCE 1\n    BUILD_BYPRODUCTS ${COCOAPI_STATIC_LIBRARIES}\n    BUILD_COMMAND ${CMAKE_C_COMPILER} -fPIC -O3 -c common/maskApi.c -o maskApi.o && ${CMAKE_AR} rcs\n                  ${COCOAPI_LIBRARY_NAME} maskApi.o\n    INSTALL_COMMAND \"\")\n\n  add_custom_target(cocoapi_create_header_dir COMMAND ${CMAKE_COMMAND} -E make_directory\n                                                      ${COCOAPI_INCLUDE_DIR} DEPENDS cocoapi)\n\n  add_custom_target(cocoapi_copy_headers_to_destination DEPENDS cocoapi_create_header_dir)\n\n  foreach(header_file ${COCOAPI_HEADERS})\n    add_custom_command(\n      TARGET cocoapi_copy_headers_to_destination PRE_BUILD\n      COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${COCOAPI_INCLUDE_DIR})\n  endforeach()\n\n  add_custom_target(cocoapi_create_library_dir COMMAND ${CMAKE_COMMAND} -E make_directory\n                                                       ${COCOAPI_LIBRARY_DIR} DEPENDS cocoapi)\n\n  add_custom_target(\n    cocoapi_copy_libs_to_destination\n    COMMAND ${CMAKE_COMMAND} -E copy_if_different ${COCOAPI_BUILD_STATIC_LIBRARIES}\n            ${COCOAPI_LIBRARY_DIR} DEPENDS cocoapi_create_library_dir)\nendif(THIRD_PARTY)\n"
  },
  {
    "path": "cmake/third_party/cub.cmake",
    "content": "include(ExternalProject)\n\nset(CUB_INCLUDE_DIR ${THIRD_PARTY_DIR}/cub/include)\nset(CUB_BUILD_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub/cub)\n\nset(CUB_URL https://github.com/NVIDIA/cub/archive/refs/tags/1.11.0.tar.gz)\nuse_mirror(VARIABLE CUB_URL URL ${CUB_URL})\n\nif(THIRD_PARTY)\n\n  ExternalProject_Add(\n    cub\n    PREFIX cub\n    URL ${CUB_URL}\n    URL_MD5 97196a885598e40592100e1caaf3d5ea\n    CONFIGURE_COMMAND \"\"\n    BUILD_COMMAND \"\"\n    INSTALL_COMMAND \"\")\n\n  add_copy_headers_target(\n    NAME\n    cub\n    SRC\n    ${CUB_BUILD_INCLUDE}\n    DST\n    ${CUB_INCLUDE_DIR}/cub\n    DEPS\n    cub\n    INDEX_FILE\n    \"${oneflow_cmake_dir}/third_party/header_index/cub_headers.txt\")\n\nendif(THIRD_PARTY)\n"
  },
  {
    "path": "cmake/third_party/cutlass.cmake",
    "content": "include(ExternalProject)\n\nif(CMAKE_CXX_COMPILER_ID STREQUAL \"Clang\")\n  set(WITH_CUTLASS_INIT OFF)\nelse()\n  set(WITH_CUTLASS_INIT ON)\nendif()\n\nset(WITH_CUTLASS ${WITH_CUTLASS_INIT} CACHE BOOL \"\")\n\nif(WITH_CUTLASS)\n\n  add_definitions(-DWITH_CUTLASS)\n\n  find_package(Threads)\n\n  set(CUTLASS_PROJECT cutlass)\n\n  set(CUTLASS_INSTALL_DIR ${THIRD_PARTY_DIR}/cutlass)\n\n  set(CUTLASS_INCLUDE_DIR ${CUTLASS_INSTALL_DIR}/include CACHE PATH \"\" FORCE)\n  set(CUTLASS_LIBRARY_DIR ${CUTLASS_INSTALL_DIR}/lib CACHE PATH \"\" FORCE)\n  set(CUTLASS_LIBRARIES ${CUTLASS_LIBRARY_DIR}/libcutlass.so)\n  set(CUTLASS_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cutlass/src/cutlass/)\n\n  foreach(arch ${CUDA_REAL_ARCHS_LIST})\n    if(arch GREATER_EQUAL 70)\n      list(APPEND CUTLASS_REAL_ARCHS ${arch})\n    endif()\n  endforeach()\n\n  if(THIRD_PARTY)\n    ExternalProject_Add(\n      ${CUTLASS_PROJECT}\n      PREFIX cutlass\n      URL ${CUTLASS_URL}\n      URL_MD5 ${CUTLASS_MD5}\n      UPDATE_COMMAND \"\"\n      BUILD_BYPRODUCTS ${CUTLASS_LIBRARIES}\n      CMAKE_ARGS -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}\n                 -DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS}\n                 -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG}\n                 -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE}\n      CMAKE_CACHE_ARGS\n        -DCMAKE_CUDA_COMPILER:STRING=${CUDAToolkit_NVCC_EXECUTABLE}\n        -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}\n        -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}\n        -DCMAKE_INSTALL_PREFIX:PATH=${CUTLASS_INSTALL_DIR}\n        -DCMAKE_INSTALL_LIBDIR:PATH=${CUTLASS_LIBRARY_DIR}\n        -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE}\n        -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}\n        -DCUTLASS_LIBRARY_OPERATIONS:STRING=conv2d\n        -DCUTLASS_LIBRARY_KERNELS:STRING=simt_hfprop_*,tensorop_f16_*fprop,tensorop_h*fprop\n        -DCUTLASS_ENABLE_EXAMPLES:BOOL=OFF\n        -DCUTLASS_ENABLE_PROFILER:BOOL=OFF\n        -DCUTLASS_ENABLE_LIBRARY:BOOL=ON\n        -DCUTLASS_NVCC_ARCHS:STRING=${CUTLASS_REAL_ARCHS}\n        -DCUTLASS_ENABLE_TESTS:BOOL=OFF\n        -DCUTLASS_UNITY_BUILD_ENABLED:BOOL=ON\n        -DCUTLASS_LIBRARY_DEBUG_POSTFIX:STRING=\n        -DCUTLASS_NVCC_EMBED_PTX:BOOL=OFF)\n\n    add_custom_target(cutlass_copy_examples_to_destination DEPENDS cutlass)\n    set(CUTLASS_SOURCE_EXAMPLES_DIR ${CUTLASS_SOURCE_DIR}/examples)\n\n    set(CUTLASS_INSTALL_EXAMPLES_FILES\n        \"45_dual_gemm/test_run.h\"\n        \"45_dual_gemm/kernel/dual_gemm.h\"\n        \"45_dual_gemm/device/dual_gemm.h\"\n        \"45_dual_gemm/dual_gemm_run.h\"\n        \"45_dual_gemm/thread/left_silu_and_mul.h\"\n        \"45_dual_gemm/threadblock/dual_mma_multistage.h\"\n        \"45_dual_gemm/threadblock/dual_epilogue.h\"\n        \"45_dual_gemm/threadblock/dual_mma_base.h\"\n        \"xformers_fmha/gemm_kernel_utils.h\"\n        \"xformers_fmha/gemm/find_default_mma.h\"\n        \"xformers_fmha/gemm/mma_accum_lambda_iterator.h\"\n        \"xformers_fmha/gemm/custom_mma_multistage.h\"\n        \"xformers_fmha/gemm/mma_from_smem.h\"\n        \"xformers_fmha/gemm/custom_mma.h\"\n        \"xformers_fmha/gemm/custom_mma_base.h\"\n        \"xformers_fmha/gemm/custom_mma_pipelined.h\"\n        \"xformers_fmha/epilogue/epilogue_thread_apply_logsumexp.h\"\n        \"xformers_fmha/epilogue/epilogue_rescale_output.h\"\n        \"xformers_fmha/epilogue/epilogue_pipelined.h\"\n        
\"xformers_fmha/debug_utils.h\"\n        \"xformers_fmha/kernel_forward.h\"\n        \"xformers_fmha/pytorch_utils.h\"\n        \"xformers_fmha/transform/tile_smem_loader.h\"\n        \"xformers_fmha/autogen/cutlassB.h\"\n        \"xformers_fmha/autogen/cutlassF.h\"\n        \"xformers_fmha/iterators/make_residual_last.h\"\n        \"xformers_fmha/iterators/predicated_tile_iterator_residual_last.h\"\n        \"xformers_fmha/iterators/epilogue_predicated_tile_iterator.h\"\n        \"xformers_fmha/iterators/transpose_warp_iterator.h\"\n        \"xformers_fmha/iterators/warp_iterator_from_smem.h\"\n        \"xformers_fmha/iterators/predicated_tile_access_iterator_residual_last.h\"\n        \"xformers_fmha/kernel_backward.h\")\n\n    foreach(filename ${CUTLASS_INSTALL_EXAMPLES_FILES})\n      add_custom_command(\n        TARGET cutlass_copy_examples_to_destination\n        COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CUTLASS_SOURCE_EXAMPLES_DIR}/${filename}\n                ${CUTLASS_INSTALL_DIR}/examples/${filename})\n    endforeach()\n\n  endif(THIRD_PARTY)\nendif(WITH_CUTLASS)\n"
  },
  {
    "path": "cmake/third_party/eigen.cmake",
    "content": "include(ExternalProject)\n\nset(EIGEN_INCLUDE_DIR ${THIRD_PARTY_DIR}/eigen/include/eigen3)\nset(EIGEN_INSTALL_DIR ${THIRD_PARTY_DIR}/eigen)\n\nset(EIGEN_URL https://github.com/Oneflow-Inc/eigen-git-mirror/archive/refs/tags/e9e95489a.tar.gz)\nset(EIGEN_MD5 a23cb70e12d1bf9b09cb28af51bc26ae)\nuse_mirror(VARIABLE EIGEN_URL URL ${EIGEN_URL})\n\nif(BUILD_CUDA)\n  add_definitions(-DEIGEN_USE_GPU)\nendif()\n\nif(THIRD_PARTY)\n\n  ExternalProject_Add(\n    eigen\n    PREFIX eigen\n    URL ${EIGEN_URL}\n    URL_MD5 ${EIGEN_MD5}\n    UPDATE_COMMAND \"\"\n    INSTALL_DIR \"${EIGEN_INSTALL_DIR}\"\n    CMAKE_CACHE_ARGS\n      -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}\n      -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}\n      -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}\n      -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF\n      -DCMAKE_INSTALL_PREFIX:STRING=${EIGEN_INSTALL_DIR}\n      -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE}\n      -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG}\n      -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE}\n      -DBUILD_TESTING:BOOL=OFF)\n\nendif(THIRD_PARTY)\n"
  },
  {
    "path": "cmake/third_party/flash_attention.cmake",
    "content": "include(ExternalProject)\n\nfind_package(Threads)\n\n# NOTE: A git version of 1.6.5 or later is required if this download method is used.\nfind_package(Git QUIET REQUIRED)\n\nset(FLASH_ATTENTION_PROJECT flash_attention)\n\nset(FLASH_ATTENTION_URL\n    https://oneflow-static.oss-cn-beijing.aliyuncs.com/third_party_mirror/flash-attention-v2-eed2e82b880e06237af3e50ceac4cf6728b15645.zip\n)\n\nset(FLASH_ATTENTION_INSTALL_DIR ${THIRD_PARTY_DIR}/flash_attention)\nset(FLASH_ATTENTION_INCLUDE_DIR ${FLASH_ATTENTION_INSTALL_DIR}/include CACHE PATH \"\" FORCE)\nset(FLASH_ATTENTION_LIBRARY_DIR ${FLASH_ATTENTION_INSTALL_DIR}/lib CACHE PATH \"\" FORCE)\nset(FLASH_ATTENTION_LIBRARIES ${FLASH_ATTENTION_LIBRARY_DIR}/libflash_attention.so)\n\nif(THIRD_PARTY)\n  ExternalProject_Add(\n    ${FLASH_ATTENTION_PROJECT}\n    PREFIX flash_attention\n    URL ${FLASH_ATTENTION_URL}\n    URL_HASH MD5=63192a05973f614aff594a8bd11813ce\n    UPDATE_COMMAND \"\"\n    BUILD_BYPRODUCTS ${FLASH_ATTENTION_LIBRARIES}\n    CMAKE_ARGS -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}\n               -DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS}\n               -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG}\n               -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE}\n               -DCMAKE_CUDA_ARCHITECTURES:STRING=${CMAKE_CUDA_ARCHITECTURES}\n    CMAKE_CACHE_ARGS\n      -DCMAKE_CUDA_COMPILER:STRING=${CUDAToolkit_NVCC_EXECUTABLE}\n      -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}\n      -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}\n      -DCMAKE_INSTALL_PREFIX:PATH=${FLASH_ATTENTION_INSTALL_DIR}\n      -DCMAKE_INSTALL_LIBDIR:PATH=${FLASH_ATTENTION_LIBRARY_DIR}\n      -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE}\n      -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE})\nendif(THIRD_PARTY)\n"
  },
  {
    "path": "cmake/third_party/flatbuffers.cmake",
    "content": "include(ExternalProject)\n\nset(FLATBUFFERS_URL https://github.com/google/flatbuffers/archive/v1.12.0.tar.gz)\n\nset(FLATBUFFERS_INSTALL_PREFIX ${THIRD_PARTY_DIR}/flatbuffers)\nset(FLATBUFFERS_INSTALL_INCLUDEDIR include)\nset(FLATBUFFERS_INSTALL_LIBDIR lib)\nset(FLATBUFFERS_INSTALL_BINDIR bin)\n\nuse_mirror(VARIABLE FLATBUFFERS_URL URL ${FLATBUFFERS_URL})\n\nset(FLATBUFFERS_INCLUDE_DIR ${FLATBUFFERS_INSTALL_PREFIX}/${FLATBUFFERS_INSTALL_INCLUDEDIR})\nset(FLATBUFFERS_LIBRARY_DIR ${FLATBUFFERS_INSTALL_PREFIX}/${FLATBUFFERS_INSTALL_LIBDIR})\nset(FLATBUFFERS_BINARY_DIR ${FLATBUFFERS_INSTALL_PREFIX}/${FLATBUFFERS_INSTALL_BINDIR})\n\nset(FLATC_EXECUTABLE_NAME flatc)\nset(FLATBUFFERS_FLATC_EXECUTABLE ${FLATBUFFERS_BINARY_DIR}/${FLATC_EXECUTABLE_NAME})\n\nset(FLATBUFFERS_LIBRARY_NAMES libflatbuffers.a)\nforeach(LIBRARY_NAME ${FLATBUFFERS_LIBRARY_NAMES})\n  list(APPEND FLATBUFFERS_STATIC_LIBRARIES ${FLATBUFFERS_LIBRARY_DIR}/${LIBRARY_NAME})\nendforeach()\n\nif(THIRD_PARTY)\n\n  ExternalProject_Add(\n    flatbuffers\n    PREFIX flatbuffers\n    URL ${FLATBUFFERS_URL}\n    URL_MD5 c62ffefb3d4548b127cca14ce047f16c\n    UPDATE_COMMAND bash -c \"rm -f BUILD || true\"\n    BUILD_IN_SOURCE 1\n    SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/flatbuffers/src/flatbuffers\n    BUILD_BYPRODUCTS ${FLATBUFFERS_STATIC_LIBRARIES}\n    CMAKE_ARGS -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}\n               -DCMAKE_INSTALL_PREFIX=${FLATBUFFERS_INSTALL_PREFIX}\n               -DCMAKE_INSTALL_INCLUDEDIR=${FLATBUFFERS_INSTALL_INCLUDEDIR}\n               -DCMAKE_INSTALL_LIBDIR=${FLATBUFFERS_INSTALL_LIBDIR}\n               -DCMAKE_INSTALL_BINDIR=${FLATBUFFERS_INSTALL_BINDIR}\n               -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE}\n               -DFLATBUFFERS_BUILD_TESTS=OFF)\nendif(THIRD_PARTY)\n"
  },
  {
    "path": "cmake/third_party/glog.cmake",
    "content": "include(ExternalProject)\n\nset_mirror_url_with_hash(glog_URL https://github.com/google/glog/archive/refs/tags/v0.5.0.tar.gz\n                         2368e3e0a95cce8b5b35a133271b480f)\n\ninclude(FetchContent)\n\nFetchContent_Declare(glog URL ${glog_URL} URL_HASH MD5=${glog_URL_HASH})\n\nset(WITH_GFLAGS OFF CACHE BOOL \"\")\nset(BUILD_SHARED_LIBS OFF CACHE BOOL \"\")\nset(WITH_GTEST OFF CACHE BOOL \"\")\nFetchContent_MakeAvailable(glog)\n\n# just for tensorflow, DO NOT USE IN OTHER PLACE\nFetchContent_GetProperties(glog)\nset(GLOG_INCLUDE_DIR ${glog_BINARY_DIR})\n"
  },
  {
    "path": "cmake/third_party/googletest.cmake",
    "content": "include(FetchContent)\n\nset_mirror_url_with_hash(\n  googletest_URL https://github.com/google/googletest/archive/release-1.11.0.tar.gz\n  e8a8df240b6938bb6384155d4c37d937)\n\nFetchContent_Declare(googletest URL ${googletest_URL} URL_HASH MD5=${googletest_URL_HASH})\n\nFetchContent_MakeAvailable(googletest)\n"
  },
  {
    "path": "cmake/third_party/grpc.cmake",
    "content": "include(ExternalProject)\n\nset(GRPC_INSTALL_DIR ${THIRD_PARTY_DIR}/grpc)\nset(GRPC_INSTALL_INCLUDE_DIR include)\nset(GRPC_INSTALL_LIBRARY_DIR lib)\nset(GRPC_INCLUDE_DIR ${THIRD_PARTY_DIR}/grpc/${GRPC_INSTALL_INCLUDE_DIR})\nset(GRPC_LIBRARY_DIR ${THIRD_PARTY_DIR}/grpc/${GRPC_INSTALL_LIBRARY_DIR})\n\nset(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include)\nset(GRPC_TAR_URL https://github.com/grpc/grpc/archive/v1.27.3.tar.gz)\nuse_mirror(VARIABLE GRPC_TAR_URL URL ${GRPC_TAR_URL})\nset(GRPC_URL_HASH 0c6c3fc8682d4262dd0e5e6fabe1a7e2)\nset(GRPC_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/grpc)\n\nif(WIN32)\n  set(GRPC_LIBRARY_NAMES grpc++_unsecure.lib grpc_unsecure.lib gpr.lib upb.lib address_sorting.lib\n                         cares.lib)\nelseif(APPLE AND (\"${CMAKE_GENERATOR}\" STREQUAL \"Xcode\"))\n  set(GRPC_LIBRARY_NAMES libgrpc++_unsecure.a libgrpc_unsecure.a libgpr.a libupb.a\n                         libaddress_sorting.a libcares.a)\nelse()\n  include(GNUInstallDirs)\n  set(GRPC_LIBRARY_NAMES libgrpc++_unsecure.a libgrpc_unsecure.a libgpr.a libupb.a\n                         libaddress_sorting.a libcares.a)\nendif()\n\nforeach(LIBRARY_NAME ${GRPC_LIBRARY_NAMES})\n  list(APPEND GRPC_STATIC_LIBRARIES ${GRPC_LIBRARY_DIR}/${LIBRARY_NAME})\nendforeach()\n\nset(PROTOBUF_CONFIG_DIR ${PROTOBUF_LIBRARY_DIR}/cmake/protobuf)\nset(ABSL_CONFIG_DIR ${ABSL_INSTALL}/${CMAKE_INSTALL_LIBDIR}/cmake/absl)\n\nif(THIRD_PARTY)\n\n  include(ProcessorCount)\n  ProcessorCount(PROC_NUM)\n  ExternalProject_Add(\n    grpc\n    PREFIX ${GRPC_SOURCE_DIR}\n    DEPENDS protobuf absl cares openssl zlib\n    URL ${GRPC_TAR_URL}\n    URL_HASH MD5=${GRPC_URL_HASH}\n    UPDATE_COMMAND \"\"\n    BUILD_IN_SOURCE 1\n    BUILD_BYPRODUCTS ${GRPC_STATIC_LIBRARIES}\n    BUILD_COMMAND\n      ${CMAKE_COMMAND} --build . -j ${PROC_NUM} --target grpc && ${CMAKE_COMMAND} --build . -j\n      ${PROC_NUM} --target grpc_unsecure && ${CMAKE_COMMAND} --build . 
-j ${PROC_NUM} --target\n      grpc++_unsecure\n    CMAKE_CACHE_ARGS\n      -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}\n      -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}\n      -DCMAKE_POLICY_DEFAULT_CMP0074:STRING=NEW\n      -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}\n      -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG}\n      -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE}\n      -DCMAKE_C_FLAGS_DEBUG:STRING=${CMAKE_C_FLAGS_DEBUG}\n      -DCMAKE_C_FLAGS_RELEASE:STRING=${CMAKE_C_FLAGS_RELEASE}\n      -DCMAKE_CXX_STANDARD:STRING=${CMAKE_CXX_STANDARD}\n      -DgRPC_INSTALL:BOOL=ON\n      -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF\n      -DgRPC_BUILD_TESTS:BOOL=OFF\n      -DgRPC_BUILD_GRPC_CPP_PLUGIN:BOOL=ON\n      -DgRPC_BUILD_GRPC_CSHARP_PLUGIN:BOOL=OFF\n      -DgRPC_BUILD_GRPC_NODE_PLUGIN:BOOL=OFF\n      -DgRPC_BUILD_GRPC_OBJECTIVE_C_PLUGIN:BOOL=OFF\n      -DgRPC_BUILD_GRPC_PHP_PLUGIN:BOOL=OFF\n      -DgRPC_BUILD_GRPC_PYTHON_PLUGIN:BOOL=OFF\n      -DgRPC_BUILD_GRPC_RUBY_PLUGIN:BOOL=OFF\n      -DgRPC_ABSL_PROVIDER:STRING=package\n      -Dabsl_DIR:PATH=${ABSL_CONFIG_DIR}\n      -DgRPC_PROTOBUF_PROVIDER:STRING=package\n      -DgRPC_PROTOBUF_PACKAGE_TYPE:STRING=CONFIG\n      -DProtobuf_ROOT:STRING=${PROTOBUF_INSTALL_DIR}\n      -DProtobuf_DIR:PATH=${PROTOBUF_CONFIG_DIR}\n      -DgRPC_CARES_PROVIDER:STRING=module\n      -DCARES_ROOT_DIR:PATH=${CARES_SOURCE_DIR}\n      -DgRPC_ZLIB_PROVIDER:STRING=package\n      -DZLIB_ROOT:PATH=${ZLIB_INSTALL}\n      -DgRPC_SSL_PROVIDER:STRING=package\n      -DOpenSSL_ROOT:PATH=${OPENSSL_INSTALL}\n      -DCMAKE_INSTALL_PREFIX:STRING=${GRPC_INSTALL_DIR}\n      -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE})\nendif(THIRD_PARTY)\n"
  },
  {
    "path": "cmake/third_party/half.cmake",
    "content": "include(ExternalProject)\n\nset(HALF_INCLUDE_DIR ${THIRD_PARTY_DIR}/half/include)\n\nset(HALF_URL https://github.com/Oneflow-Inc/half/archive/refs/tags/v2.1.0-fix-cuda-raise.zip)\nuse_mirror(VARIABLE HALF_URL URL ${HALF_URL})\nset(HALF_BASE_DIR ${CMAKE_CURRENT_BINARY_DIR}/half/src/half)\nset(HALF_URL_HASH 30b0dc289729f9e85ddf6995f2e6968f)\nset(HALF_HEADERS \"${HALF_BASE_DIR}/include/half.hpp\")\n\nif(THIRD_PARTY)\n\n  ExternalProject_Add(\n    half\n    PREFIX half\n    URL ${HALF_URL}\n    URL_HASH MD5=${HALF_URL_HASH}\n    UPDATE_COMMAND \"\"\n    CONFIGURE_COMMAND \"\"\n    BUILD_COMMAND \"\"\n    BUILD_IN_SOURCE 1\n    INSTALL_COMMAND \"\")\n\n  add_custom_target(half_create_header_dir COMMAND ${CMAKE_COMMAND} -E make_directory\n                                                   ${HALF_INCLUDE_DIR} DEPENDS half)\n\n  add_custom_target(half_copy_headers_to_destination DEPENDS half_create_header_dir)\n\n  foreach(header_file ${HALF_HEADERS})\n    add_custom_command(\n      TARGET half_copy_headers_to_destination PRE_BUILD\n      COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${HALF_INCLUDE_DIR})\n  endforeach()\nendif(THIRD_PARTY)\n"
  },
  {
    "path": "cmake/third_party/header_index/cub_headers.txt",
    "content": "config.cuh\ncub.cuh\nutil_allocator.cuh\nutil_arch.cuh\nutil_compiler.cuh\nutil_cpp_dialect.cuh\nutil_debug.cuh\nutil_deprecated.cuh\nutil_device.cuh\nutil_macro.cuh\nutil_namespace.cuh\nutil_ptx.cuh\nutil_type.cuh\nversion.cuh\nagent/agent_histogram.cuh\nagent/agent_radix_sort_downsweep.cuh\nagent/agent_radix_sort_histogram.cuh\nagent/agent_radix_sort_onesweep.cuh\nagent/agent_radix_sort_upsweep.cuh\nagent/agent_reduce.cuh\nagent/agent_reduce_by_key.cuh\nagent/agent_rle.cuh\nagent/agent_scan.cuh\nagent/agent_segment_fixup.cuh\nagent/agent_select_if.cuh\nagent/agent_spmv_orig.cuh\nagent/single_pass_scan_operators.cuh\nblock/block_adjacent_difference.cuh\nblock/block_discontinuity.cuh\nblock/block_exchange.cuh\nblock/block_histogram.cuh\nblock/block_load.cuh\nblock/block_radix_rank.cuh\nblock/block_radix_sort.cuh\nblock/block_raking_layout.cuh\nblock/block_reduce.cuh\nblock/block_scan.cuh\nblock/block_shuffle.cuh\nblock/block_store.cuh\nblock/radix_rank_sort_operations.cuh\nblock/specializations/block_histogram_atomic.cuh\nblock/specializations/block_histogram_sort.cuh\nblock/specializations/block_reduce_raking.cuh\nblock/specializations/block_reduce_raking_commutative_only.cuh\nblock/specializations/block_reduce_warp_reductions.cuh\nblock/specializations/block_scan_raking.cuh\nblock/specializations/block_scan_warp_scans.cuh\nblock/specializations/block_scan_warp_scans2.cuh\nblock/specializations/block_scan_warp_scans3.cuh\ndevice/device_histogram.cuh\ndevice/device_partition.cuh\ndevice/device_radix_sort.cuh\ndevice/device_reduce.cuh\ndevice/device_run_length_encode.cuh\ndevice/device_scan.cuh\ndevice/device_segmented_radix_sort.cuh\ndevice/device_segmented_reduce.cuh\ndevice/device_select.cuh\ndevice/device_spmv.cuh\ndevice/dispatch/dispatch_histogram.cuh\ndevice/dispatch/dispatch_radix_sort.cuh\ndevice/dispatch/dispatch_reduce.cuh\ndevice/dispatch/dispatch_reduce_by_key.cuh\ndevice/dispatch/dispatch_rle.cuh\ndevice/dispatch/dispatch_scan.cuh\ndevice/dispatch/dispatch_select_if.cuh\ndevice/dispatch/dispatch_spmv_orig.cuh\ngrid/grid_barrier.cuh\ngrid/grid_even_share.cuh\ngrid/grid_mapping.cuh\ngrid/grid_queue.cuh\nhost/mutex.cuh\niterator/arg_index_input_iterator.cuh\niterator/cache_modified_input_iterator.cuh\niterator/cache_modified_output_iterator.cuh\niterator/constant_input_iterator.cuh\niterator/counting_input_iterator.cuh\niterator/discard_output_iterator.cuh\niterator/tex_obj_input_iterator.cuh\niterator/tex_ref_input_iterator.cuh\niterator/transform_input_iterator.cuh\nthread/thread_load.cuh\nthread/thread_operators.cuh\nthread/thread_reduce.cuh\nthread/thread_scan.cuh\nthread/thread_search.cuh\nthread/thread_store.cuh\nwarp/warp_reduce.cuh\nwarp/warp_scan.cuh\nwarp/specializations/warp_reduce_shfl.cuh\nwarp/specializations/warp_reduce_smem.cuh\nwarp/specializations/warp_scan_shfl.cuh\nwarp/specializations/warp_scan_smem.cuh\n"
  },
  {
    "path": "cmake/third_party/header_index/grpc_headers.txt",
    "content": "grpc++/alarm.h\ngrpc++/channel.h\ngrpc++/client_context.h\ngrpc++/completion_queue.h\ngrpc++/create_channel.h\ngrpc++/create_channel_posix.h\ngrpc++/grpc++.h\ngrpc++/health_check_service_interface.h\ngrpc++/resource_quota.h\ngrpc++/server.h\ngrpc++/server_builder.h\ngrpc++/server_context.h\ngrpc++/server_posix.h\ngrpc++/ext/health_check_service_server_builder_option.h\ngrpc++/ext/proto_server_reflection_plugin.h\ngrpc++/generic/async_generic_service.h\ngrpc++/generic/generic_stub.h\ngrpc++/impl/call.h\ngrpc++/impl/channel_argument_option.h\ngrpc++/impl/client_unary_call.h\ngrpc++/impl/grpc_library.h\ngrpc++/impl/method_handler_impl.h\ngrpc++/impl/rpc_method.h\ngrpc++/impl/rpc_service_method.h\ngrpc++/impl/serialization_traits.h\ngrpc++/impl/server_builder_option.h\ngrpc++/impl/server_builder_plugin.h\ngrpc++/impl/server_initializer.h\ngrpc++/impl/service_type.h\ngrpc++/impl/sync_cxx11.h\ngrpc++/impl/sync_no_cxx11.h\ngrpc++/impl/codegen/async_stream.h\ngrpc++/impl/codegen/async_unary_call.h\ngrpc++/impl/codegen/byte_buffer.h\ngrpc++/impl/codegen/call.h\ngrpc++/impl/codegen/call_hook.h\ngrpc++/impl/codegen/channel_interface.h\ngrpc++/impl/codegen/client_context.h\ngrpc++/impl/codegen/client_unary_call.h\ngrpc++/impl/codegen/completion_queue.h\ngrpc++/impl/codegen/completion_queue_tag.h\ngrpc++/impl/codegen/config.h\ngrpc++/impl/codegen/config_protobuf.h\ngrpc++/impl/codegen/core_codegen.h\ngrpc++/impl/codegen/core_codegen_interface.h\ngrpc++/impl/codegen/create_auth_context.h\ngrpc++/impl/codegen/grpc_library.h\ngrpc++/impl/codegen/metadata_map.h\ngrpc++/impl/codegen/method_handler_impl.h\ngrpc++/impl/codegen/proto_utils.h\ngrpc++/impl/codegen/rpc_method.h\ngrpc++/impl/codegen/rpc_service_method.h\ngrpc++/impl/codegen/serialization_traits.h\ngrpc++/impl/codegen/server_context.h\ngrpc++/impl/codegen/server_interface.h\ngrpc++/impl/codegen/service_type.h\ngrpc++/impl/codegen/slice.h\ngrpc++/impl/codegen/status.h\ngrpc++/impl/codegen/status_code_enum.h\ngrpc++/impl/codegen/string_ref.h\ngrpc++/impl/codegen/stub_options.h\ngrpc++/impl/codegen/sync_stream.h\ngrpc++/impl/codegen/time.h\ngrpc++/impl/codegen/security/auth_context.h\ngrpc++/security/auth_context.h\ngrpc++/security/auth_metadata_processor.h\ngrpc++/security/credentials.h\ngrpc++/security/server_credentials.h\ngrpc++/support/async_stream.h\ngrpc++/support/async_unary_call.h\ngrpc++/support/byte_buffer.h\ngrpc++/support/channel_arguments.h\ngrpc++/support/config.h\ngrpc++/support/error_details.h\ngrpc++/support/slice.h\ngrpc++/support/status.h\ngrpc++/support/status_code_enum.h\ngrpc++/support/string_ref.h\ngrpc++/support/stub_options.h\ngrpc++/support/sync_stream.h\ngrpc++/support/time.h\ngrpc++/test/mock_stream.h\ngrpc++/test/server_context_test_spouse.h\ngrpc/byte_buffer.h\ngrpc/byte_buffer_reader.h\ngrpc/census.h\ngrpc/compression.h\ngrpc/fork.h\ngrpc/grpc.h\ngrpc/grpc_cronet.h\ngrpc/grpc_posix.h\ngrpc/grpc_security.h\ngrpc/grpc_security_constants.h\ngrpc/load_reporting.h\ngrpc/slice.h\ngrpc/slice_buffer.h\ngrpc/status.h\ngrpc/impl/codegen/atm.h\ngrpc/impl/codegen/atm_gcc_atomic.h\ngrpc/impl/codegen/atm_gcc_sync.h\ngrpc/impl/codegen/atm_windows.h\ngrpc/impl/codegen/byte_buffer.h\ngrpc/impl/codegen/byte_buffer_reader.h\ngrpc/impl/codegen/compression_types.h\ngrpc/impl/codegen/connectivity_state.h\ngrpc/impl/codegen/fork.h\ngrpc/impl/codegen/gpr_slice.h\ngrpc/impl/codegen/gpr_types.h\ngrpc/impl/codegen/grpc_types.h\ngrpc/impl/codegen/log.h\ngrpc/impl/codegen/port_platform.h\ngrpc/impl/codegen/propagation_bits.h\ngrpc/im
pl/codegen/slice.h\ngrpc/impl/codegen/status.h\ngrpc/impl/codegen/sync.h\ngrpc/impl/codegen/sync_custom.h\ngrpc/impl/codegen/sync_generic.h\ngrpc/impl/codegen/sync_posix.h\ngrpc/impl/codegen/sync_windows.h\ngrpc/support/alloc.h\ngrpc/support/atm.h\ngrpc/support/atm_gcc_atomic.h\ngrpc/support/atm_gcc_sync.h\ngrpc/support/atm_windows.h\ngrpc/support/cpu.h\ngrpc/support/log.h\ngrpc/support/log_windows.h\ngrpc/support/port_platform.h\ngrpc/support/string_util.h\ngrpc/support/sync.h\ngrpc/support/sync_custom.h\ngrpc/support/sync_generic.h\ngrpc/support/sync_posix.h\ngrpc/support/sync_windows.h\ngrpc/support/thd_id.h\ngrpc/support/time.h\ngrpc/support/workaround_list.h\ngrpcpp/alarm.h\ngrpcpp/alarm_impl.h\ngrpcpp/channel.h\ngrpcpp/channel_impl.h\ngrpcpp/client_context.h\ngrpcpp/completion_queue.h\ngrpcpp/completion_queue_impl.h\ngrpcpp/create_channel.h\ngrpcpp/create_channel_impl.h\ngrpcpp/create_channel_posix.h\ngrpcpp/create_channel_posix_impl.h\ngrpcpp/grpcpp.h\ngrpcpp/health_check_service_interface.h\ngrpcpp/health_check_service_interface_impl.h\ngrpcpp/opencensus.h\ngrpcpp/opencensus_impl.h\ngrpcpp/resource_quota.h\ngrpcpp/resource_quota_impl.h\ngrpcpp/server.h\ngrpcpp/server_builder.h\ngrpcpp/server_builder_impl.h\ngrpcpp/server_context.h\ngrpcpp/server_impl.h\ngrpcpp/server_posix.h\ngrpcpp/server_posix_impl.h\ngrpcpp/ext/channelz_service_plugin.h\ngrpcpp/ext/channelz_service_plugin_impl.h\ngrpcpp/ext/health_check_service_server_builder_option.h\ngrpcpp/ext/proto_server_reflection_plugin.h\ngrpcpp/ext/proto_server_reflection_plugin_impl.h\ngrpcpp/ext/server_load_reporting.h\ngrpcpp/ext/server_load_reporting_impl.h\ngrpcpp/generic/async_generic_service.h\ngrpcpp/generic/generic_stub.h\ngrpcpp/generic/generic_stub_impl.h\ngrpcpp/impl/call.h\ngrpcpp/impl/channel_argument_option.h\ngrpcpp/impl/client_unary_call.h\ngrpcpp/impl/grpc_library.h\ngrpcpp/impl/method_handler_impl.h\ngrpcpp/impl/rpc_method.h\ngrpcpp/impl/rpc_service_method.h\ngrpcpp/impl/serialization_traits.h\ngrpcpp/impl/server_builder_option.h\ngrpcpp/impl/server_builder_option_impl.h\ngrpcpp/impl/server_builder_plugin.h\ngrpcpp/impl/server_initializer.h\ngrpcpp/impl/server_initializer_impl.h\ngrpcpp/impl/service_type.h\ngrpcpp/impl/sync_cxx11.h\ngrpcpp/impl/sync_no_cxx11.h\ngrpcpp/impl/codegen/async_generic_service.h\ngrpcpp/impl/codegen/async_stream.h\ngrpcpp/impl/codegen/async_stream_impl.h\ngrpcpp/impl/codegen/async_unary_call.h\ngrpcpp/impl/codegen/async_unary_call_impl.h\ngrpcpp/impl/codegen/byte_buffer.h\ngrpcpp/impl/codegen/call.h\ngrpcpp/impl/codegen/call_hook.h\ngrpcpp/impl/codegen/call_op_set.h\ngrpcpp/impl/codegen/call_op_set_interface.h\ngrpcpp/impl/codegen/callback_common.h\ngrpcpp/impl/codegen/channel_interface.h\ngrpcpp/impl/codegen/client_callback.h\ngrpcpp/impl/codegen/client_callback_impl.h\ngrpcpp/impl/codegen/client_context.h\ngrpcpp/impl/codegen/client_context_impl.h\ngrpcpp/impl/codegen/client_interceptor.h\ngrpcpp/impl/codegen/client_unary_call.h\ngrpcpp/impl/codegen/completion_queue.h\ngrpcpp/impl/codegen/completion_queue_impl.h\ngrpcpp/impl/codegen/completion_queue_tag.h\ngrpcpp/impl/codegen/config.h\ngrpcpp/impl/codegen/config_protobuf.h\ngrpcpp/impl/codegen/core_codegen.h\ngrpcpp/impl/codegen/core_codegen_interface.h\ngrpcpp/impl/codegen/create_auth_context.h\ngrpcpp/impl/codegen/delegating_channel.h\ngrpcpp/impl/codegen/grpc_library.h\ngrpcpp/impl/codegen/intercepted_channel.h\ngrpcpp/impl/codegen/interceptor.h\ngrpcpp/impl/codegen/interceptor_common.h\ngrpcpp/impl/codegen/message_allocator.h\ngrpcpp/imp
l/codegen/metadata_map.h\ngrpcpp/impl/codegen/method_handler.h\ngrpcpp/impl/codegen/method_handler_impl.h\ngrpcpp/impl/codegen/proto_buffer_reader.h\ngrpcpp/impl/codegen/proto_buffer_writer.h\ngrpcpp/impl/codegen/proto_utils.h\ngrpcpp/impl/codegen/rpc_method.h\ngrpcpp/impl/codegen/rpc_service_method.h\ngrpcpp/impl/codegen/serialization_traits.h\ngrpcpp/impl/codegen/server_callback.h\ngrpcpp/impl/codegen/server_callback_handlers.h\ngrpcpp/impl/codegen/server_callback_impl.h\ngrpcpp/impl/codegen/server_context.h\ngrpcpp/impl/codegen/server_context_impl.h\ngrpcpp/impl/codegen/server_interceptor.h\ngrpcpp/impl/codegen/server_interface.h\ngrpcpp/impl/codegen/service_type.h\ngrpcpp/impl/codegen/slice.h\ngrpcpp/impl/codegen/status.h\ngrpcpp/impl/codegen/status_code_enum.h\ngrpcpp/impl/codegen/string_ref.h\ngrpcpp/impl/codegen/stub_options.h\ngrpcpp/impl/codegen/sync.h\ngrpcpp/impl/codegen/sync_stream.h\ngrpcpp/impl/codegen/sync_stream_impl.h\ngrpcpp/impl/codegen/time.h\ngrpcpp/impl/codegen/security/auth_context.h\ngrpcpp/security/alts_context.h\ngrpcpp/security/alts_util.h\ngrpcpp/security/auth_context.h\ngrpcpp/security/auth_metadata_processor.h\ngrpcpp/security/auth_metadata_processor_impl.h\ngrpcpp/security/credentials.h\ngrpcpp/security/credentials_impl.h\ngrpcpp/security/cronet_credentials.h\ngrpcpp/security/cronet_credentials_impl.h\ngrpcpp/security/server_credentials.h\ngrpcpp/security/server_credentials_impl.h\ngrpcpp/security/tls_credentials_options.h\ngrpcpp/support/async_stream.h\ngrpcpp/support/async_stream_impl.h\ngrpcpp/support/async_unary_call.h\ngrpcpp/support/async_unary_call_impl.h\ngrpcpp/support/byte_buffer.h\ngrpcpp/support/channel_arguments.h\ngrpcpp/support/channel_arguments_impl.h\ngrpcpp/support/client_callback.h\ngrpcpp/support/client_callback_impl.h\ngrpcpp/support/client_interceptor.h\ngrpcpp/support/config.h\ngrpcpp/support/error_details.h\ngrpcpp/support/error_details_impl.h\ngrpcpp/support/interceptor.h\ngrpcpp/support/message_allocator.h\ngrpcpp/support/method_handler.h\ngrpcpp/support/proto_buffer_reader.h\ngrpcpp/support/proto_buffer_writer.h\ngrpcpp/support/server_callback.h\ngrpcpp/support/server_callback_impl.h\ngrpcpp/support/server_interceptor.h\ngrpcpp/support/slice.h\ngrpcpp/support/status.h\ngrpcpp/support/status_code_enum.h\ngrpcpp/support/string_ref.h\ngrpcpp/support/stub_options.h\ngrpcpp/support/sync_stream.h\ngrpcpp/support/sync_stream_impl.h\ngrpcpp/support/time.h\ngrpcpp/support/validate_service_config.h\ngrpcpp/test/default_reactor_test_peer.h\ngrpcpp/test/mock_stream.h\ngrpcpp/test/server_context_test_spouse.h\n"
  },
  {
    "path": "cmake/third_party/header_index/libpng_headers.txt",
    "content": "png.h\npngconf.h\npngdebug.h\npnginfo.h\npnglibconf.h\npngpriv.h\npngstruct.h\n"
  },
  {
    "path": "cmake/third_party/header_index/opencv_headers.txt",
    "content": "opencv2/cvconfig.h\nopencv2/core/cv_cpu_dispatch.h\nopencv2/core/types_c.h\nopencv2/core/cvdef.h\nopencv2/core/core_c.h\nopencv2/core/cv_cpu_helper.h\nopencv2/core/hal/interface.h\nopencv2/imgproc/imgproc_c.h\nopencv2/imgproc/types_c.h\nopencv2/imgproc/hal/interface.h\nopencv2/imgcodecs/ios.h\nopencv2/imgcodecs/imgcodecs_c.h\nopencv/cvwimage.h\nopencv/cxcore.h\nopencv/highgui.h\nopencv/cvaux.h\nopencv/ml.h\nopencv/cv.h\nopencv/cxmisc.h\nopencv2/opencv.hpp\nopencv2/imgproc.hpp\nopencv2/opencv_modules.hpp\nopencv2/imgcodecs.hpp\nopencv2/core.hpp\nopencv2/core/directx.hpp\nopencv2/core/fast_math.hpp\nopencv2/core/persistence.hpp\nopencv2/core/traits.hpp\nopencv2/core/mat.hpp\nopencv2/core/affine.hpp\nopencv2/core/cuda_stream_accessor.hpp\nopencv2/core/wimage.hpp\nopencv2/core/cvstd.hpp\nopencv2/core/base.hpp\nopencv2/core/optim.hpp\nopencv2/core/vsx_utils.hpp\nopencv2/core/va_intel.hpp\nopencv2/core/ocl.hpp\nopencv2/core/ptr.inl.hpp\nopencv2/core/saturate.hpp\nopencv2/core/neon_utils.hpp\nopencv2/core/cuda.inl.hpp\nopencv2/core/utility.hpp\nopencv2/core/opengl.hpp\nopencv2/core/eigen.hpp\nopencv2/core/cuda_types.hpp\nopencv2/core/cuda.hpp\nopencv2/core/mat.inl.hpp\nopencv2/core/operations.hpp\nopencv2/core/cvstd.inl.hpp\nopencv2/core/ovx.hpp\nopencv2/core/ippasync.hpp\nopencv2/core/bufferpool.hpp\nopencv2/core/matx.hpp\nopencv2/core/sse_utils.hpp\nopencv2/core/types.hpp\nopencv2/core/version.hpp\nopencv2/core/ocl_genbase.hpp\nopencv2/core/core.hpp\nopencv2/core/softfloat.hpp\nopencv2/core/hal/hal.hpp\nopencv2/core/hal/intrin_sse.hpp\nopencv2/core/hal/intrin_neon.hpp\nopencv2/core/hal/intrin_cpp.hpp\nopencv2/core/hal/intrin.hpp\nopencv2/core/hal/intrin_vsx.hpp\nopencv2/core/cuda/reduce.hpp\nopencv2/core/cuda/warp_shuffle.hpp\nopencv2/core/cuda/emulation.hpp\nopencv2/core/cuda/limits.hpp\nopencv2/core/cuda/warp_reduce.hpp\nopencv2/core/cuda/filters.hpp\nopencv2/core/cuda/vec_distance.hpp\nopencv2/core/cuda/scan.hpp\nopencv2/core/cuda/utility.hpp\nopencv2/core/cuda/type_traits.hpp\nopencv2/core/cuda/block.hpp\nopencv2/core/cuda/vec_traits.hpp\nopencv2/core/cuda/funcattrib.hpp\nopencv2/core/cuda/datamov_utils.hpp\nopencv2/core/cuda/vec_math.hpp\nopencv2/core/cuda/common.hpp\nopencv2/core/cuda/warp.hpp\nopencv2/core/cuda/color.hpp\nopencv2/core/cuda/border_interpolate.hpp\nopencv2/core/cuda/simd_functions.hpp\nopencv2/core/cuda/dynamic_smem.hpp\nopencv2/core/cuda/functional.hpp\nopencv2/core/cuda/saturate_cast.hpp\nopencv2/core/cuda/transform.hpp\nopencv2/core/cuda/detail/reduce.hpp\nopencv2/core/cuda/detail/reduce_key_val.hpp\nopencv2/core/cuda/detail/vec_distance_detail.hpp\nopencv2/core/cuda/detail/color_detail.hpp\nopencv2/core/cuda/detail/transform_detail.hpp\nopencv2/core/cuda/detail/type_traits_detail.hpp\nopencv2/core/utils/logger.hpp\nopencv2/core/utils/trace.hpp\nopencv2/imgproc/imgproc.hpp\nopencv2/imgproc/hal/hal.hpp\nopencv2/imgproc/detail/distortion_model.hpp\nopencv2/imgcodecs/imgcodecs.hpp\nopencv/cxcore.hpp\nopencv/cv.hpp\nopencv/cxeigen.hpp\nopencv/cvaux.hpp\n"
  },
  {
    "path": "cmake/third_party/hwloc.cmake",
    "content": "include(ExternalProject)\n\nif(UNIX AND NOT APPLE)\n  set(BUILD_HWLOC_DEFAULT ON)\nelse()\n  set(BUILD_HWLOC_DEFAULT OFF)\nendif()\noption(BUILD_HWLOC \"\" ${BUILD_HWLOC_DEFAULT})\n\nif(BUILD_HWLOC)\n\n  set(PCIACCESS_INSTALL ${THIRD_PARTY_DIR}/pciaccess)\n  set(PCIACCESS_INCLUDE_DIR ${PCIACCESS_INSTALL}/include)\n  set(PCIACCESS_LIBRARY_DIR ${PCIACCESS_INSTALL}/lib)\n  set(PCIACCESS_LIBRARY_NAMES libpciaccess.a)\n  foreach(LIBRARY_NAME ${PCIACCESS_LIBRARY_NAMES})\n    list(APPEND PCIACCESS_STATIC_LIBRARIES ${PCIACCESS_LIBRARY_DIR}/${LIBRARY_NAME})\n  endforeach()\n\n  set(HWLOC_INSTALL ${THIRD_PARTY_DIR}/hwloc)\n  set(HWLOC_INCLUDE_DIR ${HWLOC_INSTALL}/include)\n  set(HWLOC_LIBRARY_DIR ${HWLOC_INSTALL}/lib)\n  set(HWLOC_LIBRARY_NAMES libhwloc.a)\n  foreach(LIBRARY_NAME ${HWLOC_LIBRARY_NAMES})\n    list(APPEND ONEFLOW_HWLOC_STATIC_LIBRARIES ${HWLOC_LIBRARY_DIR}/${LIBRARY_NAME})\n  endforeach()\n\n  if(THIRD_PARTY)\n\n    include(ProcessorCount)\n    ProcessorCount(PROC_NUM)\n\n    set(XORG_MACROS_INSTALL ${THIRD_PARTY_DIR}/xorg-macros)\n    set(XORG_MACROS_TAR_URL\n        https://gitlab.freedesktop.org/xorg/util/macros/-/archive/util-macros-1.19.1/macros-util-macros-1.19.1.tar.gz\n    )\n    use_mirror(VARIABLE XORG_MACROS_TAR_URL URL ${XORG_MACROS_TAR_URL})\n    set(XORG_MACROS_URL_HASH 764fb1647d7ebd1c8c5d707db525832f)\n    set(XORG_MACROS_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/xorg-macros)\n    set(XORG_MACROS_PKG_CONFIG_DIR ${XORG_MACROS_INSTALL}/share/pkgconfig)\n\n    ExternalProject_Add(\n      xorg-macros\n      PREFIX xorg-macros\n      URL ${XORG_MACROS_TAR_URL}\n      URL_HASH MD5=${XORG_MACROS_URL_HASH}\n      UPDATE_COMMAND \"\"\n      CONFIGURE_COMMAND ${XORG_MACROS_SOURCE_DIR}/src/xorg-macros/autogen.sh\n      COMMAND ${XORG_MACROS_SOURCE_DIR}/src/xorg-macros/configure --prefix=${XORG_MACROS_INSTALL}\n      BUILD_COMMAND make -j${PROC_NUM}\n      INSTALL_COMMAND make install)\n\n    set(PCIACCESS_TAR_URL\n        https://gitlab.freedesktop.org/xorg/lib/libpciaccess/-/archive/libpciaccess-0.16/libpciaccess-libpciaccess-0.16.tar.gz\n    )\n    use_mirror(VARIABLE PCIACCESS_TAR_URL URL ${PCIACCESS_TAR_URL})\n    set(PCIACCESS_URL_HASH 93554c189796c27dfc72af17a367a0b4)\n    set(PCIACCESS_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/pciaccess)\n\n    set(PCIACCESS_CFLAGS \"-O3 -fPIC\")\n\n    ExternalProject_Add(\n      pciaccess\n      PREFIX pciaccess\n      URL ${PCIACCESS_TAR_URL}\n      URL_HASH MD5=${PCIACCESS_URL_HASH}\n      UPDATE_COMMAND \"\"\n      PATCH_COMMAND cp ${XORG_MACROS_INSTALL}/share/aclocal/xorg-macros.m4\n                    ${PCIACCESS_SOURCE_DIR}/src/pciaccess/m4\n      CONFIGURE_COMMAND ${PCIACCESS_SOURCE_DIR}/src/pciaccess/autogen.sh\n      COMMAND ${PCIACCESS_SOURCE_DIR}/src/pciaccess/configure --prefix=${PCIACCESS_INSTALL}\n              --enable-shared=no\n      BUILD_COMMAND make -j${PROC_NUM} CFLAGS=${PCIACCESS_CFLAGS}\n      BUILD_BYPRODUCTS ${PCIACCESS_STATIC_LIBRARIES}\n      INSTALL_COMMAND make install\n      DEPENDS xorg-macros)\n    set(HWLOC_TAR_URL https://github.com/open-mpi/hwloc/archive/refs/tags/hwloc-2.4.1.tar.gz)\n    use_mirror(VARIABLE HWLOC_TAR_URL URL ${HWLOC_TAR_URL})\n    set(HWLOC_URL_HASH ac25fc7c2a665b7914c6c21b782f1c4f)\n    set(HWLOC_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/hwloc)\n\n    set(HWLOC_CFLAGS \"-O3 -fPIC\")\n\n    ExternalProject_Add(\n      hwloc\n      PREFIX hwloc\n      URL ${HWLOC_TAR_URL}\n      URL_HASH MD5=${HWLOC_URL_HASH}\n      UPDATE_COMMAND \"\"\n      CONFIGURE_COMMAND 
${HWLOC_SOURCE_DIR}/src/hwloc/autogen.sh\n      COMMAND ${HWLOC_SOURCE_DIR}/src/hwloc/configure --prefix=${HWLOC_INSTALL}\n              PKG_CONFIG_PATH=${PCIACCESS_INSTALL}/lib/pkgconfig --disable-libxml2 --enable-static\n              --enable-shared=no\n      BUILD_COMMAND make -j${PROC_NUM} CFLAGS=${HWLOC_CFLAGS}\n      BUILD_BYPRODUCTS ${ONEFLOW_HWLOC_STATIC_LIBRARIES}\n      INSTALL_COMMAND make install\n      DEPENDS pciaccess)\n  endif(THIRD_PARTY)\n\nendif(BUILD_HWLOC)\n"
  },
  {
    "path": "cmake/third_party/json.cmake",
    "content": "include(FetchContent)\n\nset_mirror_url_with_hash(JSON_URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.zip\n                         49097a7ec390ffaf1cd2e14b734b6c75)\nset(JSON_Install ON CACHE STRING \"\" FORCE)\n\nFetchContent_Declare(json URL ${JSON_URL} URL_HASH MD5=${JSON_URL_HASH})\n\nFetchContent_MakeAvailable(json)\n"
  },
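  {
    "path": "cmake/third_party/json_usage.example.cmake",
    "content": "# NOTE: hypothetical usage sketch, NOT part of the original tree.\n# json.cmake makes the nlohmann_json::nlohmann_json INTERFACE target\n# available via FetchContent; a consumer (the target and source names are\n# illustrative assumptions) links it directly:\nadd_executable(my_config_tool my_config_tool.cpp)\ntarget_link_libraries(my_config_tool PRIVATE nlohmann_json::nlohmann_json)\n"
  },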
  {
    "path": "cmake/third_party/libjpeg-turbo.cmake",
    "content": "include(ExternalProject)\n\nset(LIBJPEG_INCLUDE_DIR ${THIRD_PARTY_DIR}/libjpeg-turbo/include)\nset(LIBJPEG_LIBRARY_DIR ${THIRD_PARTY_DIR}/libjpeg-turbo/lib)\n\nset(LIBJPEG_URL https://github.com/libjpeg-turbo/libjpeg-turbo/archive/refs/tags/2.1.3.tar.gz)\nuse_mirror(VARIABLE LIBJPEG_URL URL ${LIBJPEG_URL})\n\nif(WIN32)\n\nelseif(APPLE AND (\"${CMAKE_GENERATOR}\" STREQUAL \"Xcode\"))\n  set(LIBJPEG_BUILD_SRC_DIR ${CMAKE_CURRENT_BINARY_DIR}/libjpeg-turbo/src/libjpeg-turbo)\n  set(LIBJPEG_BUILD_LIBRARY_DIR\n      ${CMAKE_CURRENT_BINARY_DIR}/libjpeg-turbo/src/libjpeg-turbo/${CMAKE_BUILD_TYPE})\n  set(LIBJPEG_LIBRARY_NAMES libturbojpeg.a)\nelse()\n  set(LIBJPEG_BUILD_SRC_DIR ${CMAKE_CURRENT_BINARY_DIR}/libjpeg-turbo/src/libjpeg-turbo)\n  set(LIBJPEG_BUILD_LIBRARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/libjpeg-turbo/src/libjpeg-turbo)\n  set(LIBJPEG_LIBRARY_NAMES libturbojpeg.a)\nendif()\n\nforeach(LIBRARY_NAME ${LIBJPEG_LIBRARY_NAMES})\n  list(APPEND LIBJPEG_STATIC_LIBRARIES ${LIBJPEG_LIBRARY_DIR}/${LIBRARY_NAME})\n  list(APPEND LIBJPEG_BUILD_STATIC_LIBRARIES ${LIBJPEG_BUILD_LIBRARY_DIR}/${LIBRARY_NAME})\nendforeach()\n\nset(LIBJPEG_HEADERS\n    \"${LIBJPEG_BUILD_SRC_DIR}/cderror.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/cdjpeg.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/cmyk.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jchuff.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jconfig.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jdcoefct.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jdct.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jdhuff.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jdmainct.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jdmaster.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jdsample.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jerror.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jinclude.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jmemsys.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jmorecfg.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jpegcomp.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jpegint.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jpeglib.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jpeg_nbits_table.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jsimddct.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jsimd.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/jversion.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/tjutil.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/transupp.h\"\n    \"${LIBJPEG_BUILD_SRC_DIR}/turbojpeg.h\")\n\nif(THIRD_PARTY)\n\n  ExternalProject_Add(\n    libjpeg-turbo\n    PREFIX libjpeg-turbo\n    URL ${LIBJPEG_URL}\n    URL_MD5 627b980fad0573e08e4c3b80b290fc91\n    UPDATE_COMMAND \"\"\n    INSTALL_COMMAND \"\"\n    BUILD_IN_SOURCE 1\n    BUILD_BYPRODUCTS ${LIBJPEG_STATIC_LIBRARIES}\n    CMAKE_CACHE_ARGS\n      -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}\n      -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}\n      -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}\n      -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG}\n      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON)\n\n  # put libjpeg-turbo includes in the directory where they are expected\n  add_custom_target(libjpeg_create_header_dir COMMAND ${CMAKE_COMMAND} -E make_directory\n                                                      ${LIBJPEG_INCLUDE_DIR} DEPENDS libjpeg-turbo)\n\n  add_custom_target(libjpeg_copy_headers_to_destination DEPENDS libjpeg_create_header_dir)\n\n  foreach(header_file ${LIBJPEG_HEADERS})\n    add_custom_command(\n      TARGET libjpeg_copy_headers_to_destination PRE_BUILD\n      COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${LIBJPEG_INCLUDE_DIR})\n  endforeach()\n\n  # pub libjpeg libs in the directory where they are expected\n 
 add_custom_target(libjpeg_create_library_dir COMMAND ${CMAKE_COMMAND} -E make_directory\n                                                       ${LIBJPEG_LIBRARY_DIR} DEPENDS libjpeg-turbo)\n\n  add_custom_target(\n    libjpeg_copy_libs_to_destination\n    COMMAND ${CMAKE_COMMAND} -E copy_if_different ${LIBJPEG_BUILD_STATIC_LIBRARIES}\n            ${LIBJPEG_LIBRARY_DIR} DEPENDS libjpeg_create_library_dir)\n\nendif(THIRD_PARTY)\n"
  },
  {
    "path": "cmake/third_party/nccl.cmake",
    "content": "option(NCCL_STATIC \"\" ON)\nif(OF_CUDA_LINK_DYNAMIC_LIBRARY)\n  set(NCCL_STATIC OFF)\nendif()\noption(USE_SYSTEM_NCCL \"\" OFF)\nset(NCCL_ROOT_DIR \"\" CACHE PATH \"Folder contains NVIDIA NCCL\")\n\nif(WIN32)\n  set(NCCL_LIBRARY_NAME libnccl_static.lib)\nelse()\n  if(NCCL_STATIC)\n    set(NCCL_LIBRARY_NAME libnccl_static.a)\n  else()\n    set(NCCL_LIBRARY_NAME libnccl.so)\n  endif()\nendif()\n\nif(USE_SYSTEM_NCCL)\n  include(FindPackageHandleStandardArgs)\n  find_path(NCCL_INCLUDE_DIR nccl.h HINTS ${NCCL_ROOT_DIR} ${CUDAToolkit_INCLUDE_DIRS}\n            PATH_SUFFIXES cuda/include include)\n  unset(NCCL_LIBRARY CACHE)\n  find_library(\n    NCCL_LIBRARY ${NCCL_LIBRARY_NAME} HINTS ${NCCL_ROOT_DIR} ${CUDAToolkit_LIBRARY_DIR}\n                                            ${CUDAToolkit_LIBRARY_ROOT}\n    PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)\n  find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR NCCL_LIBRARY)\n  set(NCCL_LIBRARIES ${NCCL_LIBRARY})\n  add_custom_target(nccl)\nelse()\n  get_filename_component(CUDATOOLKIT_BIN_ROOT ${CUDAToolkit_BIN_DIR} DIRECTORY)\n  include(ExternalProject)\n  set(NCCL_INSTALL_DIR ${THIRD_PARTY_DIR}/nccl)\n  set(NCCL_INCLUDE_DIR ${NCCL_INSTALL_DIR}/include)\n  set(NCCL_LIBRARY_DIR ${NCCL_INSTALL_DIR}/lib)\n\n  # Versions 2.13 and above may cause deadlocks\n  if(CUDA_VERSION VERSION_GREATER_EQUAL \"11.8\")\n    set(NCCL_URL https://github.com/NVIDIA/nccl/archive/refs/tags/v2.15.1-1.tar.gz)\n    set(NCCL_MD5 37b787ff8934cd9374b4612f663c17fa)\n  else()\n    set(NCCL_URL https://github.com/NVIDIA/nccl/archive/refs/tags/v2.12.10-1.tar.gz)\n    set(NCCL_MD5 bdb91f80b78c99831f09ca8bb28a1032)\n  endif()\n\n  use_mirror(VARIABLE NCCL_URL URL ${NCCL_URL})\n\n  list(APPEND NCCL_LIBRARIES ${NCCL_LIBRARY_DIR}/${NCCL_LIBRARY_NAME})\n\n  set(NCCL_ARCHS_LIST ${CUDA_REAL_ARCHS_LIST})\n\n  # remove redundant archs, https://github.com/NVIDIA/nccl/blob/cb111f764a6d46370f24f75101d6b219bb2dda54/makefiles/common.mk#L28\n  if(\"70\" IN_LIST NCCL_ARCHS_LIST AND \"75\" IN_LIST NCCL_ARCHS_LIST)\n    list(REMOVE_ITEM NCCL_ARCHS_LIST \"75\")\n  endif()\n  if(\"80\" IN_LIST NCCL_ARCHS_LIST AND \"86\" IN_LIST NCCL_ARCHS_LIST)\n    list(REMOVE_ITEM NCCL_ARCHS_LIST \"86\")\n  endif()\n  if(\"80\" IN_LIST NCCL_ARCHS_LIST AND \"89\" IN_LIST NCCL_ARCHS_LIST)\n    list(REMOVE_ITEM NCCL_ARCHS_LIST \"89\")\n  endif()\n\n  foreach(arch ${NCCL_ARCHS_LIST})\n    string(APPEND NCCL_GENCODE \"-gencode=arch=compute_${arch},code=sm_${arch} \")\n  endforeach()\n\n  if(THIRD_PARTY)\n\n    include(ProcessorCount)\n    ProcessorCount(PROC_NUM)\n    ExternalProject_Add(\n      nccl\n      PREFIX nccl\n      URL ${NCCL_URL}\n      URL_MD5 ${NCCL_MD5}\n      UPDATE_COMMAND \"\"\n      CONFIGURE_COMMAND \"\"\n      BUILD_IN_SOURCE 1\n      BUILD_COMMAND make -j${PROC_NUM} src.build CUDA_HOME=${CUDATOOLKIT_BIN_ROOT}\n                    NVCC_GENCODE=${NCCL_GENCODE}\n      INSTALL_COMMAND make src.install PREFIX=${NCCL_INSTALL_DIR}\n      BUILD_BYPRODUCTS ${NCCL_LIBRARIES})\n\n  endif(THIRD_PARTY)\n\nendif()\n"
  },
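  {
    "path": "cmake/third_party/nccl_gencode.example.cmake",
    "content": "# NOTE: hypothetical sketch, NOT part of the original tree; it only\n# illustrates the NVCC_GENCODE string assembled by nccl.cmake.\n# Assuming CUDA_REAL_ARCHS_LIST were \"70;75;80;86\", the dedup rules there\n# drop 75 and 86 (sm_70 / sm_80 cubins are binary-compatible with them),\n# so the foreach() would yield:\n#   -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_80,code=sm_80\nset(DEMO_ARCHS_LIST 70 80) # archs left after deduplication\nset(DEMO_GENCODE \"\")\nforeach(arch ${DEMO_ARCHS_LIST})\n  string(APPEND DEMO_GENCODE \"-gencode=arch=compute_${arch},code=sm_${arch} \")\nendforeach()\nmessage(STATUS \"NVCC_GENCODE would be: ${DEMO_GENCODE}\")\n"
  },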
  {
    "path": "cmake/third_party/oneDNN.cmake",
    "content": "include(ExternalProject)\ninclude(GNUInstallDirs)\n\nset(ONEDNN_INSTALL_DIR ${THIRD_PARTY_DIR}/onednn)\nset(ONEDNN_INCLUDE_DIR ${ONEDNN_INSTALL_DIR}/include)\nset(ONEDNN_LIBRARY_DIR ${ONEDNN_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR})\n\nset(ONEDNN_URL https://github.com/oneapi-src/oneDNN/archive/refs/tags/v2.4.3.tar.gz)\nuse_mirror(VARIABLE ONEDNN_URL URL ${ONEDNN_URL})\n\nif(WIN32)\n  message(FATAL_ERROR \"Windows system does not support onednn\")\nelse()\n  if(BUILD_CPP_API)\n    set(ONEDNN_BUILD_SHARED_LIBS OFF)\n  else()\n    set(ONEDNN_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})\n  endif()\n\n  if(ONEDNN_BUILD_SHARED_LIBS)\n    if(\"${CMAKE_SHARED_LIBRARY_SUFFIX}\" STREQUAL \".dylib\")\n      set(ONEDNN_LIBRARY_NAMES libdnnl.dylib)\n    elseif(\"${CMAKE_SHARED_LIBRARY_SUFFIX}\" STREQUAL \".so\")\n      set(ONEDNN_LIBRARY_NAMES libdnnl.so)\n      set(DNNL_LIBRARY_TYPE SHARED)\n      set(DNNL_LIBRARY_RPATH ON)\n    else()\n      message(FATAL_ERROR \"${CMAKE_SHARED_LIBRARY_SUFFIX} not support for onednn\")\n    endif()\n  else()\n    set(ONEDNN_LIBRARY_NAMES libdnnl.a)\n    set(DNNL_LIBRARY_TYPE STATIC)\n    set(DNNL_LIBRARY_RPATH OFF)\n  endif()\nendif()\n\nforeach(LIBRARY_NAME ${ONEDNN_LIBRARY_NAMES})\n  list(APPEND ONEDNN_STATIC_LIBRARIES ${ONEDNN_LIBRARY_DIR}/${LIBRARY_NAME})\nendforeach()\n\n# the order of the following codes can't be changed\nset(ONEDNN_CPU_RUNTIME SEQ)\nif(WITH_OMP)\n  set(ONEDNN_CPU_RUNTIME OMP)\nendif()\nif(WITH_TBB)\n  set(ONEDNN_CPU_RUNTIME TBB)\n  set(ONEDNN_DEPENDS install-tbb)\nendif()\n\nif(THIRD_PARTY)\n  ExternalProject_Add(\n    onednn\n    PREFIX onednn\n    DEPENDS ${ONEDNN_DEPENDS}\n    URL ${ONEDNN_URL}\n    URL_MD5 c60ea96acbaccec053be7e3fa81c6184\n    UPDATE_COMMAND \"\"\n    BUILD_IN_SOURCE 1\n    BUILD_BYPRODUCTS ${ONEDNN_STATIC_LIBRARIES}\n    CMAKE_CACHE_ARGS\n      -DCMAKE_INSTALL_PREFIX:STRING=${ONEDNN_INSTALL_DIR}\n      -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE}\n      -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}\n      -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}\n      -DCMAKE_POLICY_DEFAULT_CMP0074:STRING=NEW\n      -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}\n      -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG}\n      -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE}\n      -DCMAKE_C_FLAGS_DEBUG:STRING=${CMAKE_C_FLAGS_DEBUG}\n      -DCMAKE_C_FLAGS_RELEASE:STRING=${CMAKE_C_FLAGS_RELEASE}\n      -DDNNL_IS_MAIN_PROJECT:BOOL=OFF\n      -DDNNL_BUILD_EXAMPLES:BOOL=OFF\n      -DDNNL_BUILD_TESTS:BOOL=OFF\n      -DDNNL_LIBRARY_TYPE:STRING=${DNNL_LIBRARY_TYPE}\n      -DCMAKE_INSTALL_RPATH_USE_LINK_PATH:BOOL=${DNNL_LIBRARY_RPATH}\n      -DCMAKE_INSTALL_RPATH:STRING=${ONETBB_INSTALL_DIR}\n      -DDNNL_CPU_RUNTIME:STRING=${ONEDNN_CPU_RUNTIME}\n      -DTBBROOT:STRING=${ONETBB_INSTALL_DIR}\n      -DTBB_ROOT:STRING=${ONETBB_INSTALL_DIR}/lib/cmake/TBB)\n\nendif(THIRD_PARTY)\nadd_library(onednn_imported UNKNOWN IMPORTED)\nset_property(TARGET onednn_imported PROPERTY IMPORTED_LOCATION \"${ONEDNN_STATIC_LIBRARIES}\")\n"
  },
  {
    "path": "cmake/third_party/opencv.cmake",
    "content": "include(ExternalProject)\ninclude(GNUInstallDirs)\n\nset(OPENCV_INSTALL_DIR ${THIRD_PARTY_DIR}/opencv)\nset(OPENCV_INCLUDE_DIR ${OPENCV_INSTALL_DIR}/include)\nset(LIBPNG_INSTALL_DIR ${THIRD_PARTY_DIR}/libpng)\nset(LIBPNG_INCLUDE_DIR ${LIBPNG_INSTALL_DIR}/include)\nset(OPENCV_LIBRARY_DIR ${OPENCV_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR})\nset(OPENCV_3RDPARTY_LIBRARY_DIR ${OPENCV_INSTALL_DIR}/share/OpenCV/3rdparty/${CMAKE_INSTALL_LIBDIR})\n\nset(OPENCV_SRC_DIR ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/src)\nset(OPENCV_URL\n    https://github.com/opencv/opencv/archive/83391ac59d270f2148fc99a62ae279b04d37f5d0.tar.gz)\nuse_mirror(VARIABLE OPENCV_URL URL ${OPENCV_URL})\n\nset(OPENCV_LIBRARY_NAMES libopencv_imgcodecs.a libopencv_imgproc.a libopencv_core.a)\nset(OPENCV_3RDPARTY_LIBRARY_NAMES libIlmImf.a liblibjasper.a liblibpng.a liblibtiff.a liblibwebp.a)\n\nforeach(LIBRARY_NAME ${OPENCV_LIBRARY_NAMES})\n  list(APPEND OPENCV_STATIC_LIBRARIES ${OPENCV_LIBRARY_DIR}/${LIBRARY_NAME})\nendforeach()\n\nforeach(LIBRARY_NAME ${OPENCV_3RDPARTY_LIBRARY_NAMES})\n  list(APPEND OPENCV_STATIC_LIBRARIES ${OPENCV_3RDPARTY_LIBRARY_DIR}/${LIBRARY_NAME})\nendforeach()\n\nif(THIRD_PARTY)\n\n  if(CMAKE_C_COMPILER_LAUNCHER STREQUAL \"ccache\")\n    set(OPENCV_C_COMPILER_LAUNCHER_DEF \"-DENABLE_CCACHE:BOOL=ON\")\n  else()\n    set(OPENCV_C_COMPILER_LAUNCHER_DEF\n        \"-DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}\")\n  endif()\n\n  if(CMAKE_CXX_COMPILER_LAUNCHER STREQUAL \"ccache\")\n    set(OPENCV_CXX_COMPILER_LAUNCHER_DEF \"-DENABLE_CCACHE:BOOL=ON\")\n  else()\n    set(OPENCV_CXX_COMPILER_LAUNCHER_DEF\n        \"-DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}\")\n  endif()\n\n  ExternalProject_Add(\n    opencv\n    DEPENDS libjpeg_copy_headers_to_destination libjpeg_copy_libs_to_destination\n    PREFIX opencv\n    URL ${OPENCV_URL}\n    URL_MD5 b09dc79dec7766a3550907bcafc8bbf5\n    UPDATE_COMMAND \"\"\n    PATCH_COMMAND cmake -E make_directory ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/build\n    BUILD_IN_SOURCE 0\n    SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv\n    BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/build\n    BUILD_BYPRODUCTS ${OPENCV_STATIC_LIBRARIES}\n    CMAKE_CACHE_ARGS\n      ${OPENCV_C_COMPILER_LAUNCHER_DEF}\n      ${OPENCV_CXX_COMPILER_LAUNCHER_DEF}\n      -DCMAKE_POLICY_DEFAULT_CMP0074:STRING=NEW\n      -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}\n      -DCMAKE_INSTALL_PREFIX:STRING=${OPENCV_INSTALL_DIR}\n      -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE}\n      -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF\n      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON\n      -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG}\n      -DWITH_IPP:BOOL=OFF\n      -DWITH_1394:BOOL=OFF\n      -DWITH_AVFOUNDATION:BOOL=OFF\n      -DWITH_CAROTENE:BOOL=OFF\n      -DWITH_CPUFEATURES:BOOL=OFF\n      -DWITH_VTK:BOOL=OFF\n      -DWITH_CUDA:BOOL=OFF\n      -DWITH_CUFFT:BOOL=OFF\n      -DWITH_CUBLAS:BOOL=OFF\n      -DWITH_NVCUVID:BOOL=OFF\n      -DWITH_EIGEN:BOOL=OFF\n      -DWITH_VFW:BOOL=OFF\n      -DWITH_FFMPEG:BOOL=OFF\n      -DWITH_WEBP:BOOL=ON\n      -DBUILD_WEBP:BOOL=ON\n      -DWITH_GSTREAMER:BOOL=OFF\n      -DWITH_GSTREAMER_0_10:BOOL=OFF\n      -DWITH_GTK:BOOL=OFF\n      -DWITH_GTK_2_X:BOOL=OFF\n      -DWITH_WIN32UI:BOOL=OFF\n      -DWITH_PTHREADS_PF:BOOL=OFF\n      -DWITH_DSHOW:BOOL=OFF\n      -DWITH_OPENCL:BOOL=OFF\n      -DWITH_OPENCL_SVM:BOOL=OFF\n      -DWITH_OPENCLAMDFFT:BOOL=OFF\n      -DWITH_OPENCLAMDBLAS:BOOL=OFF\n    
  -DWITH_DIRECTX:BOOL=OFF\n      -DWITH_MATLAB:BOOL=OFF\n      -DWITH_GPHOTO2:BOOL=OFF\n      -DWITH_LAPACK:BOOL=OFF\n      -DBUILD_SHARED_LIBS:BOOL=OFF\n      -DBUILD_ANDROID_EXAMPLES:BOOL=OFF\n      -DBUILD_DOCS:BOOL=OFF\n      -DBUILD_PACKAGE:BOOL=OFF\n      -DBUILD_PERF_TESTS:BOOL=OFF\n      -DBUILD_TESTS:BOOL=OFF\n      -DBUILD_FAT_JAVA_LIBS:BOOL=OFF\n      -DBUILD_ANDROID_SERVICE:BOOL=OFF\n      -DBUILD_CUDA_STUBS:BOOL=OFF\n      -DENABLE_PYLINT:BOOL=OFF\n      -DBUILD_opencv_python3:BOOL=OFF\n      -DBUILD_opencv_python2:BOOL=OFF\n      -DBUILD_opencv_world:BOOL=OFF\n      -DBUILD_opencv_apps:BOOL=OFF\n      -DBUILD_opencv_js:BOOL=OFF\n      -DBUILD_ZLIB:BOOL=OFF\n      -DZLIB_ROOT:PATH=${ZLIB_INSTALL}\n      -DBUILD_TIFF:BOOL=ON\n      -DBUILD_JASPER:BOOL=ON\n      -DWITH_JPEG:BOOL=ON\n      -DBUILD_JPEG:BOOL=OFF\n      -DJPEG_INCLUDE_DIR:STRING=${LIBJPEG_INCLUDE_DIR}\n      -DJPEG_LIBRARY:STRING=${LIBJPEG_STATIC_LIBRARIES}\n      -DBUILD_PNG:BOOL=ON\n      -DBUILD_OPENEXR:BOOL=ON\n      -DBUILD_TBB:BOOL=ON\n      -DBUILD_IPP_IW:BOOL=OFF\n      -DWITH_ITT:BOOL=OFF\n      -DBUILD_opencv_flann:BOOL=OFF\n      -DBUILD_opencv_ml:BOOL=OFF\n      -DBUILD_opencv_objdetect:BOOL=OFF\n      -DBUILD_opencv_photo:BOOL=OFF\n      -DBUILD_opencv_video:BOOL=OFF\n      -DBUILD_opencv_dnn:BOOL=OFF\n      -DBUILD_opencv_shape:BOOL=OFF\n      -DBUILD_opencv_videoio:BOOL=OFF\n      -DBUILD_opencv_highgui:BOOL=OFF\n      -DBUILD_opencv_superres:BOOL=OFF\n      -DBUILD_opencv_features2d:BOOL=OFF\n      -DBUILD_opencv_calib3d:BOOL=OFF\n      -DBUILD_opencv_stitching:BOOL=OFF\n      -DBUILD_opencv_videostab:BOOL=OFF\n      -DBUILD_opencv_imgproc:BOOL=ON\n      -DBUILD_opencv_imgcodecs:BOOL=ON\n      -DENABLE_CXX11:BOOL=ON\n      # -DLIB_SUFFIX:STRING=64\n  )\n\n  if(WITH_ZLIB)\n    add_dependencies(opencv zlib)\n  endif()\n\n  install(\n    FILES ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/3rdparty/libpng/pngconf.h\n          ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/3rdparty/libpng/pngdebug.h\n          ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/3rdparty/libpng/png.h\n          ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/3rdparty/libpng/pnginfo.h\n          ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/3rdparty/libpng/pnglibconf.h\n          ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/3rdparty/libpng/pngpriv.h\n          ${CMAKE_CURRENT_BINARY_DIR}/opencv/src/opencv/3rdparty/libpng/pngstruct.h\n    TYPE INCLUDE\n    COMPONENT libpng_headers)\n  add_custom_target(\n    install_libpng_headers\n    COMMAND\n      \"${CMAKE_COMMAND}\" -DCMAKE_INSTALL_COMPONENT=libpng_headers\n      -DCMAKE_INSTALL_PREFIX=\"${LIBPNG_INSTALL_DIR}\"\n      -DCMAKE_INSTALL_MESSAGE=${CMAKE_INSTALL_MESSAGE} -P \"${CMAKE_BINARY_DIR}/cmake_install.cmake\"\n    DEPENDS opencv)\nendif(THIRD_PARTY)\n"
  },
  {
    "path": "cmake/third_party/openssl.cmake",
    "content": "include(ExternalProject)\n\nset(OPENSSL_INSTALL ${THIRD_PARTY_DIR}/openssl)\nset(OPENSSL_INCLUDE_DIR ${THIRD_PARTY_DIR}/openssl/include)\nset(OPENSSL_LIBRARY_DIR ${THIRD_PARTY_DIR}/openssl/lib)\n\nset(OPENSSL_TAR_URL https://github.com/openssl/openssl/archive/OpenSSL_1_1_1g.tar.gz)\nuse_mirror(VARIABLE OPENSSL_TAR_URL URL ${OPENSSL_TAR_URL})\nset(OPENSSL_URL_HASH dd32f35dd5d543c571bc9ebb90ebe54e)\nset(OPENSSL_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/openssl)\n\nif(WIN32)\n  set(OPENSSL_BUILD_LIBRARY_DIR ${OPENSSL_INSTALL}/lib)\n  set(OPENSSL_LIBRARY_NAMES ssl.lib crypto.lib)\nelseif(APPLE AND (\"${CMAKE_GENERATOR}\" STREQUAL \"Xcode\"))\n  set(OPENSSL_BUILD_LIBRARY_DIR ${OPENSSL_INSTALL}/lib)\n  set(OPENSSL_LIBRARY_NAMES libssl.a libcrypto.a)\nelse()\n  set(OPENSSL_BUILD_LIBRARY_DIR ${OPENSSL_INSTALL}/lib)\n  set(OPENSSL_LIBRARY_NAMES libssl.a libcrypto.a)\nendif()\n\nforeach(LIBRARY_NAME ${OPENSSL_LIBRARY_NAMES})\n  list(APPEND OPENSSL_STATIC_LIBRARIES ${OPENSSL_LIBRARY_DIR}/${LIBRARY_NAME})\n  list(APPEND OPENSSL_BUILD_STATIC_LIBRARIES ${OPENSSL_BUILD_LIBRARY_DIR}/${LIBRARY_NAME})\nendforeach()\n\nif(THIRD_PARTY)\n\n  include(ProcessorCount)\n  ProcessorCount(PROC_NUM)\n  ExternalProject_Add(\n    openssl\n    PREFIX openssl\n    URL ${OPENSSL_TAR_URL}\n    URL_HASH MD5=${OPENSSL_URL_HASH}\n    UPDATE_COMMAND \"\"\n    CONFIGURE_COMMAND ${OPENSSL_SOURCE_DIR}/src/openssl/config --prefix=${OPENSSL_INSTALL}\n    BUILD_BYPRODUCTS ${OPENSSL_STATIC_LIBRARIES}\n    BUILD_COMMAND make -j${PROC_NUM}\n    INSTALL_COMMAND make install_sw)\n\nendif(THIRD_PARTY)\n"
  },
  {
    "path": "cmake/third_party/patches/tensorflow-logging.patch",
    "content": "--- ./build/third_party_install/tensorflow/include/tensorflow_inc/tensorflow/stream_executor/platform/logging.h\t2021-06-22 16:41:20.000000000 +0800\n+++ logging.h\t2021-08-16 19:41:43.082449275 +0800\n@@ -19,7 +19,7 @@\n #include \"tensorflow/core/platform/logging.h\"\n #include \"tensorflow/stream_executor/platform/port.h\"\n \n-#if !defined(PLATFORM_GOOGLE) && !defined(PLATFORM_GOOGLE_ANDROID)\n+#if !defined(PLATFORM_GOOGLE) && !defined(PLATFORM_GOOGLE_ANDROID) && !defined(GOOGLE_LOGGING)\n \n #define PCHECK(invocation) CHECK(invocation)\n \n"
  },
  {
    "path": "cmake/third_party/protobuf.cmake",
    "content": "include(ExternalProject)\n\nset(PROTOBUF_INSTALL_DIR ${THIRD_PARTY_DIR}/protobuf)\nset(PROTOBUF_INSTALL_INCLUDEDIR include)\nset(PROTOBUF_INSTALL_LIBDIR lib)\nset(PROTOBUF_INSTALL_BINDIR bin)\nset(PROTOBUF_INCLUDE_DIR ${PROTOBUF_INSTALL_DIR}/${PROTOBUF_INSTALL_INCLUDEDIR})\nset(PROTOBUF_LIBRARY_DIR ${PROTOBUF_INSTALL_DIR}/${PROTOBUF_INSTALL_LIBDIR})\nset(PROTOBUF_BINARY_DIR ${PROTOBUF_INSTALL_DIR}/${PROTOBUF_INSTALL_BINDIR})\n\nset(PROTOBUF_SRC_DIR ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src)\nset(PROTOBUF_URL \"https://github.com/protocolbuffers/protobuf/archive/v3.9.2.zip\")\nset(PROTOBUF_MD5 cf02c32870a1f78c860039e0f63a6343)\n\nuse_mirror(VARIABLE PROTOBUF_URL URL ${PROTOBUF_URL})\n\nif(WIN32)\n  set(PROTOBUF_LIBRARY_NAMES libprotobufd.lib)\n  set(PROTOC_EXECUTABLE_NAME protoc.exe)\n  set(PROTOBUF_ADDITIONAL_CMAKE_OPTIONS -Dprotobuf_MSVC_STATIC_RUNTIME:BOOL=ON -A x64)\nelse()\n  # NOTE: (houjiang, shenghang), to support xrt, must make libproto built as shared\n  if(\"${CMAKE_SHARED_LIBRARY_SUFFIX}\" STREQUAL \".dylib\")\n    set(PROTOBUF_LIBRARY_NAMES libprotobuf.dylib)\n  elseif(\"${CMAKE_SHARED_LIBRARY_SUFFIX}\" STREQUAL \".so\")\n    set(PROTOBUF_LIBRARY_NAMES libprotobuf.so)\n  else()\n    message(FATAL_ERROR \"${CMAKE_SHARED_LIBRARY_SUFFIX} not support for protobuf\")\n  endif()\n  set(PROTOBUF_BUILD_SHARED_LIBS ON)\n  set(PROTOC_EXECUTABLE_NAME protoc)\nendif()\n\nforeach(LIBRARY_NAME ${PROTOBUF_LIBRARY_NAMES})\n  list(APPEND PROTOBUF_STATIC_LIBRARIES ${PROTOBUF_LIBRARY_DIR}/${LIBRARY_NAME})\nendforeach()\n\nset(PROTOBUF_PROTOC_EXECUTABLE ${PROTOBUF_BINARY_DIR}/${PROTOC_EXECUTABLE_NAME})\n\nif(THIRD_PARTY)\n\n  ExternalProject_Add(\n    protobuf\n    PREFIX protobuf\n    URL ${PROTOBUF_URL}\n    URL_MD5 ${PROTOBUF_MD5}\n    UPDATE_COMMAND \"\"\n    BUILD_IN_SOURCE 1\n    SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf\n    SOURCE_SUBDIR cmake\n    BUILD_BYPRODUCTS ${PROTOBUF_STATIC_LIBRARIES}\n    CMAKE_CACHE_ARGS\n      -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}\n      -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}\n      -DCMAKE_POLICY_DEFAULT_CMP0074:STRING=NEW\n      -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}\n      -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF\n      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON\n      -DZLIB_ROOT:PATH=${ZLIB_INSTALL}\n      -Dprotobuf_WITH_ZLIB:BOOL=${WITH_ZLIB}\n      -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG}\n      -DBUILD_SHARED_LIBS:BOOL=${PROTOBUF_BUILD_SHARED_LIBS}\n      -Dprotobuf_BUILD_SHARED_LIBS:BOOL=${PROTOBUF_BUILD_SHARED_LIBS}\n      -Dprotobuf_BUILD_TESTS:BOOL=OFF\n      -DCMAKE_INSTALL_PREFIX:STRING=${PROTOBUF_INSTALL_DIR}\n      -DCMAKE_INSTALL_INCLUDEDIR:STRING=${PROTOBUF_INSTALL_INCLUDEDIR}\n      -DCMAKE_INSTALL_LIBDIR:STRING=${PROTOBUF_INSTALL_LIBDIR}\n      -DCMAKE_INSTALL_BINDIR:STRING=${PROTOBUF_INSTALL_BINDIR}\n      -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE}\n      -Dprotobuf_DEBUG_POSTFIX:STRING=\n      ${PROTOBUF_ADDITIONAL_CMAKE_OPTIONS})\n  if(WITH_ZLIB)\n    add_dependencies(protobuf zlib)\n  endif()\nelse()\n  add_custom_target(protobuf)\nendif(THIRD_PARTY)\nadd_library(protobuf_imported UNKNOWN IMPORTED)\nset_property(TARGET protobuf_imported PROPERTY IMPORTED_LOCATION \"${PROTOBUF_STATIC_LIBRARIES}\")\n"
  },
  {
    "path": "cmake/third_party/re2.cmake",
    "content": "include(ExternalProject)\n\nset(RE2_PROJECT re2)\n\nset(RE2_INSTALL_DIR ${THIRD_PARTY_DIR}/re2)\n\nset(RE2_INCLUDE_DIR ${RE2_INSTALL_DIR}/include CACHE PATH \"\" FORCE)\nset(RE2_LIBRARY_DIR ${RE2_INSTALL_DIR}/lib CACHE PATH \"\" FORCE)\nset(RE2_LIBRARIES ${RE2_LIBRARY_DIR}/libre2.a)\nset(RE2_URL https://github.com/Oneflow-Inc/re2/archive/refs/tags/e17af7789.tar.gz)\nuse_mirror(VARIABLE RE2_URL URL ${RE2_URL})\n\nif(THIRD_PARTY)\n  ExternalProject_Add(\n    ${RE2_PROJECT}\n    PREFIX re2\n    URL ${RE2_URL}\n    URL_MD5 3b2e20c1edd1cfe887aeef3b0747eac0\n    UPDATE_COMMAND \"\"\n    BUILD_BYPRODUCTS ${RE2_LIBRARIES}\n    CMAKE_ARGS -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}\n               -DBUILD_SHARED_LIBS:BOOL=OFF\n               -DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS}\n               -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG}\n               -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE}\n    CMAKE_CACHE_ARGS\n      -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}\n      -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}\n      -DCMAKE_INSTALL_PREFIX:PATH=${RE2_INSTALL_DIR}\n      -DCMAKE_INSTALL_LIBDIR:PATH=${RE2_LIBRARY_DIR}\n      -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE}\n      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON\n      -DRE2_BUILD_TESTING:BOOL=OFF\n      -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE})\nendif(THIRD_PARTY)\n"
  },
  {
    "path": "cmake/third_party/trt_flash_attention.cmake",
    "content": "include(ExternalProject)\n\nfind_package(Threads)\n\nset(TRT_FLASH_ATTENTION_PROJECT trt_flash_attention)\n\nset(TRT_FLASH_ATTENTION_URL\n    https://github.com/Oneflow-Inc/trt_flash_attention/archive/d8b74631eb811c95a0d20f247238db6e91acafe3.zip\n)\nuse_mirror(VARIABLE TRT_FLASH_ATTENTION_URL URL ${TRT_FLASH_ATTENTION_URL})\nset(TRT_FLASH_ATTENTION_MD5 9e0e822ce1450e11515533fbe32e58a9)\n\nset(TRT_FLASH_ATTENTION_INSTALL_DIR ${THIRD_PARTY_DIR}/trt_flash_attention)\nset(TRT_FLASH_ATTENTION_INCLUDE_DIR ${TRT_FLASH_ATTENTION_INSTALL_DIR}/include CACHE PATH \"\" FORCE)\nset(TRT_FLASH_ATTENTION_LIBRARY_DIR ${TRT_FLASH_ATTENTION_INSTALL_DIR}/lib CACHE PATH \"\" FORCE)\nset(TRT_FLASH_ATTENTION_LIBRARIES ${TRT_FLASH_ATTENTION_LIBRARY_DIR}/libtrt_flash_attention.so)\n\nif(THIRD_PARTY)\n  ExternalProject_Add(\n    ${TRT_FLASH_ATTENTION_PROJECT}\n    PREFIX trt_flash_attention\n    URL ${TRT_FLASH_ATTENTION_URL}\n    URL_MD5 ${TRT_FLASH_ATTENTION_MD5}\n    UPDATE_COMMAND \"\"\n    BUILD_BYPRODUCTS ${TRT_FLASH_ATTENTION_LIBRARIES}\n    CMAKE_ARGS -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}\n               -DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS}\n               -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG}\n               -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE}\n    CMAKE_CACHE_ARGS\n      -DCMAKE_CUDA_COMPILER:STRING=${CUDAToolkit_NVCC_EXECUTABLE}\n      -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}\n      -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}\n      -DCMAKE_INSTALL_PREFIX:PATH=${TRT_FLASH_ATTENTION_INSTALL_DIR}\n      -DCMAKE_INSTALL_LIBDIR:PATH=${TRT_FLASH_ATTENTION_LIBRARY_DIR}\n      -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE}\n      -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE})\nendif(THIRD_PARTY)\n"
  },
  {
    "path": "cmake/third_party/zlib.cmake",
    "content": "include(ExternalProject)\n\nset(ZLIB_INSTALL ${THIRD_PARTY_DIR}/zlib)\nset(ZLIB_INCLUDE_DIR ${ZLIB_INSTALL}/include)\nset(ZLIB_LIBRARY_DIR ${ZLIB_INSTALL}/lib)\nset(ZLIB_URL https://github.com/madler/zlib/archive/v1.2.8.tar.gz)\nuse_mirror(VARIABLE ZLIB_URL URL ${ZLIB_URL})\n\n# only use zlib shared lib to prevent using zlib in the system\nif(WIN32)\n  set(ZLIB_LIBRARY_NAMES zlibstaticd.lib)\nelse()\n  if(\"${CMAKE_SHARED_LIBRARY_SUFFIX}\" STREQUAL \".dylib\")\n    set(ZLIB_LIBRARY_NAMES libz.dylib)\n  elseif(\"${CMAKE_SHARED_LIBRARY_SUFFIX}\" STREQUAL \".so\")\n    set(ZLIB_LIBRARY_NAMES libz.so)\n  else()\n    message(FATAL_ERROR \"${CMAKE_SHARED_LIBRARY_SUFFIX} not support for zlib\")\n  endif()\nendif()\n\nforeach(LIBRARY_NAME ${ZLIB_LIBRARY_NAMES})\n  list(APPEND ZLIB_STATIC_LIBRARIES ${ZLIB_LIBRARY_DIR}/${LIBRARY_NAME})\nendforeach()\n\nset(ZLIB_HEADERS \"${ZLIB_INSTALL}/include/zconf.h\" \"${ZLIB_INSTALL}/include/zlib.h\")\n\nif(THIRD_PARTY)\n\n  ExternalProject_Add(\n    zlib\n    PREFIX zlib\n    URL ${ZLIB_URL}\n    URL_MD5 1eabf2698dc49f925ce0ffb81397098f\n    UPDATE_COMMAND \"\"\n    BUILD_IN_SOURCE 1\n    BUILD_BYPRODUCTS ${ZLIB_STATIC_LIBRARIES}\n    CMAKE_CACHE_ARGS\n      -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}\n      -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}\n      -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}\n      -DBUILD_SHARED_LIBS:BOOL=${BUILD_SHARED_LIBS}\n      -DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS}\n      -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG}\n      -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE}\n      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON\n      -DCMAKE_INSTALL_PREFIX:STRING=${ZLIB_INSTALL}\n      -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE})\n\nendif(THIRD_PARTY)\nadd_library(zlib_imported UNKNOWN IMPORTED)\nset_property(TARGET zlib_imported PROPERTY IMPORTED_LOCATION \"${ZLIB_STATIC_LIBRARIES}\")\n"
  },
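  {
    "path": "cmake/third_party/zlib_usage.example.cmake",
    "content": "# NOTE: hypothetical usage sketch, NOT part of the original tree.\n# zlib.cmake exposes the bundled zlib through the IMPORTED target\n# zlib_imported; a consumer target (its name and source are illustrative\n# assumptions) would use it like this:\nadd_library(my_codec STATIC my_codec.cpp)\ntarget_include_directories(my_codec PRIVATE ${ZLIB_INCLUDE_DIR})\ntarget_link_libraries(my_codec PRIVATE zlib_imported)\n"
  },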
  {
    "path": "cmake/third_party.cmake",
    "content": "cmake_policy(SET CMP0074 NEW)\nif(NOT WIN32)\n  find_package(Threads)\nendif()\n\nif(WITH_ZLIB)\n  include(zlib)\nendif()\ninclude(protobuf)\ninclude(googletest)\ninclude(glog)\ninclude(libjpeg-turbo)\ninclude(opencv)\ninclude(eigen)\nif(WITH_COCOAPI)\n  include(cocoapi)\nendif()\ninclude(half)\ninclude(re2)\ninclude(json)\nif(RPC_BACKEND MATCHES \"GRPC\")\n  include(absl)\n  include(cares)\n  include(openssl)\n  include(grpc)\nendif()\ninclude(flatbuffers)\n\ninclude(hwloc)\nif(WITH_ONEDNN)\n  include(oneDNN)\nendif()\n\nset_mirror_url_with_hash(INJA_URL https://github.com/pantor/inja/archive/refs/tags/v3.3.0.zip\n                         611e6b7206d0fb89728a3879f78b4775)\n\nif(NOT WIN32)\n  set(BLA_STATIC ON)\n  set(BLA_VENDOR \"Intel10_64lp_seq\")\n  find_package(BLAS)\n  if(NOT BLAS_FOUND)\n    set(BLA_VENDOR \"All\")\n    find_package(BLAS)\n  endif()\nelse()\n  set(MKL_LIB_PATH\n      \"C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_2017/windows/mkl/lib/intel64_win\"\n  )\n  set(BLAS_LIBRARIES ${MKL_LIB_PATH}/mkl_core_dll.lib ${MKL_LIB_PATH}/mkl_sequential_dll.lib\n                     ${MKL_LIB_PATH}/mkl_intel_lp64_dll.lib)\nendif()\nmessage(STATUS \"Found Blas Lib: \" ${BLAS_LIBRARIES})\n\nset(oneflow_test_libs gtest_main)\n\nset(oneflow_third_party_libs\n    protobuf_imported\n    ${GRPC_STATIC_LIBRARIES}\n    ${farmhash_STATIC_LIBRARIES}\n    ${BLAS_LIBRARIES}\n    ${OPENCV_STATIC_LIBRARIES}\n    ${COCOAPI_STATIC_LIBRARIES}\n    ${LIBJPEG_STATIC_LIBRARIES}\n    ${ABSL_STATIC_LIBRARIES}\n    ${OPENSSL_STATIC_LIBRARIES}\n    ${CMAKE_THREAD_LIBS_INIT}\n    ${FLATBUFFERS_STATIC_LIBRARIES}\n    nlohmann_json::nlohmann_json)\nif(WITH_ONEDNN)\n  set(oneflow_third_party_libs ${oneflow_third_party_libs} ${ONEDNN_STATIC_LIBRARIES})\nendif()\n\nlist(APPEND oneflow_third_party_libs ${RE2_LIBRARIES})\n\nif(WITH_ZLIB)\n  list(APPEND oneflow_third_party_libs zlib_imported)\nendif()\n\nif(WIN32)\n  # static gflags lib requires \"PathMatchSpecA\" defined in \"ShLwApi.Lib\"\n  list(APPEND oneflow_third_party_libs \"ShLwApi.Lib\")\n  list(APPEND oneflow_third_party_libs \"Ws2_32.lib\")\nendif()\n\nset(oneflow_third_party_dependencies\n    protobuf\n    eigen\n    half_copy_headers_to_destination\n    re2\n    opencv\n    install_libpng_headers\n    flatbuffers)\nif(WITH_ONEDNN)\n  list(APPEND oneflow_third_party_dependencies onednn)\nendif()\nif(WITH_ZLIB)\n  list(APPEND oneflow_third_party_dependencies zlib)\nendif()\n\nif(WITH_COCOAPI)\n  list(APPEND oneflow_third_party_dependencies cocoapi_copy_headers_to_destination)\n  list(APPEND oneflow_third_party_dependencies cocoapi_copy_libs_to_destination)\nendif()\n\nif(RPC_BACKEND MATCHES \"GRPC\")\n  list(APPEND oneflow_third_party_dependencies grpc)\nendif()\n\nlist(\n  APPEND\n  ONEFLOW_THIRD_PARTY_INCLUDE_DIRS\n  ${ZLIB_INCLUDE_DIR}\n  ${PROTOBUF_INCLUDE_DIR}\n  ${GRPC_INCLUDE_DIR}\n  ${GLOG_INCLUDE_DIR}\n  ${LIBJPEG_INCLUDE_DIR}\n  ${OPENCV_INCLUDE_DIR}\n  ${LIBPNG_INCLUDE_DIR}\n  ${EIGEN_INCLUDE_DIR}\n  ${COCOAPI_INCLUDE_DIR}\n  ${HALF_INCLUDE_DIR}\n  ${ABSL_INCLUDE_DIR}\n  ${OPENSSL_INCLUDE_DIR}\n  ${FLATBUFFERS_INCLUDE_DIR})\nif(WITH_ONEDNN)\n  list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${ONEDNN_INCLUDE_DIR})\nendif()\n\nlist(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${RE2_INCLUDE_DIR})\n\nif(BUILD_CUDA)\n  # Always use third_party/cub for Clang CUDA in case of compatibility issues\n  if(\"${CMAKE_CUDA_COMPILER_ID}\" STREQUAL \"NVIDIA\" AND CUDA_VERSION VERSION_GREATER_EQUAL \"11.0\")\n    if(CMAKE_CXX_STANDARD 
LESS 14)\n      add_definitions(-DTHRUST_IGNORE_DEPRECATED_CPP_DIALECT)\n      add_definitions(-DCUB_IGNORE_DEPRECATED_CPP11)\n    endif()\n    if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS \"5.0\")\n      add_definitions(-DCUB_IGNORE_DEPRECATED_COMPILER)\n    endif()\n  else()\n    include(cub)\n    list(APPEND oneflow_third_party_dependencies cub_copy_headers_to_destination)\n  endif()\n  include(nccl)\n  include(cutlass)\n  include(trt_flash_attention)\n  if(CUDA_VERSION VERSION_GREATER_EQUAL \"11.7\")\n    include(flash_attention)\n  endif()\n\n  list(APPEND oneflow_third_party_libs ${NCCL_LIBRARIES})\n  list(APPEND oneflow_third_party_libs ${CUDNN_LIBRARIES})\n  list(APPEND oneflow_third_party_libs ${VENDOR_CUDA_LIBRARIES})\n\n  list(APPEND oneflow_third_party_dependencies nccl)\n\n  list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${CUDNN_INCLUDE_DIRS} ${CUB_INCLUDE_DIR}\n       ${NCCL_INCLUDE_DIR})\n\n  if(WITH_CUTLASS)\n    list(APPEND oneflow_third_party_dependencies cutlass)\n    list(APPEND oneflow_third_party_dependencies cutlass_copy_examples_to_destination)\n    list(APPEND oneflow_third_party_libs ${CUTLASS_LIBRARIES})\n    list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${CUTLASS_INCLUDE_DIR})\n  endif()\n  list(APPEND oneflow_third_party_dependencies trt_flash_attention)\n  list(APPEND oneflow_third_party_libs ${TRT_FLASH_ATTENTION_LIBRARIES})\n  list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${TRT_FLASH_ATTENTION_INCLUDE_DIR})\n  if(CUDA_VERSION VERSION_GREATER_EQUAL \"11.7\")\n    list(APPEND oneflow_third_party_dependencies flash_attention)\n    list(APPEND oneflow_third_party_libs ${FLASH_ATTENTION_LIBRARIES})\n    list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${FLASH_ATTENTION_INCLUDE_DIR})\n  endif()\nendif()\n\nif(BUILD_RDMA)\n  if(UNIX)\n    include(CheckIncludeFiles)\n    include(CheckLibraryExists)\n    check_include_files(infiniband/verbs.h HAVE_VERBS_H)\n    if(HAVE_VERBS_H)\n      add_definitions(-DWITH_RDMA)\n    else()\n      message(FATAL_ERROR \"RDMA header file not found\")\n    endif()\n  else()\n    message(FATAL_ERROR \"UNIMPLEMENTED\")\n  endif()\nendif()\n\nif(BUILD_HWLOC)\n  list(APPEND oneflow_third_party_dependencies hwloc)\n  list(APPEND oneflow_third_party_libs ${ONEFLOW_HWLOC_STATIC_LIBRARIES})\n  list(APPEND oneflow_third_party_libs ${PCIACCESS_STATIC_LIBRARIES})\n  list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${HWLOC_INCLUDE_DIR})\n  add_definitions(-DWITH_HWLOC)\nendif()\n\ninclude_directories(SYSTEM ${ONEFLOW_THIRD_PARTY_INCLUDE_DIRS})\n\nforeach(oneflow_third_party_lib IN LISTS oneflow_third_party_libs)\n  if(NOT \"${oneflow_third_party_lib}\" MATCHES \"^-l.+\"\n     AND NOT TARGET ${oneflow_third_party_lib}\n     AND \"${oneflow_third_party_lib}\" MATCHES \"^\\/.+\"\n     AND NOT \"${oneflow_third_party_lib}\" MATCHES \"^.+\\.framework\")\n    get_filename_component(IMPORTED_LIB_NAME ${oneflow_third_party_lib} NAME_WE)\n    set(IMPORTED_LIB_NAME \"imported::${IMPORTED_LIB_NAME}\")\n    message(STATUS \"Creating imported lib: ${oneflow_third_party_lib} => ${IMPORTED_LIB_NAME}\")\n    add_library(${IMPORTED_LIB_NAME} UNKNOWN IMPORTED)\n    set_property(TARGET ${IMPORTED_LIB_NAME} PROPERTY IMPORTED_LOCATION\n                                                      \"${oneflow_third_party_lib}\")\n    list(APPEND ONEFLOW_THIRD_PARTY_LIBS_TO_LINK \"${IMPORTED_LIB_NAME}\")\n  else()\n    list(APPEND ONEFLOW_THIRD_PARTY_LIBS_TO_LINK \"${oneflow_third_party_lib}\")\n  endif()\nendforeach()\n\nset(oneflow_third_party_libs ${ONEFLOW_THIRD_PARTY_LIBS_TO_LINK})\nmessage(STATUS \"oneflow_third_party_libs: ${oneflow_third_party_libs}\")\n\nadd_definitions(-DHALF_ENABLE_CPP11_USER_LITERALS=0)\n\nif(THIRD_PARTY)\n  add_custom_target(prepare_oneflow_third_party ALL DEPENDS ${oneflow_third_party_dependencies})\n  if(BUILD_PYTHON)\n    if(NOT ONEFLOW_INCLUDE_DIR MATCHES \"/include$\")\n      message(\n        FATAL_ERROR\n          \"ONEFLOW_INCLUDE_DIR must end with '/include', current value: ${ONEFLOW_INCLUDE_DIR}\")\n    endif()\n    get_filename_component(ONEFLOW_INCLUDE_DIR_PARENT \"${ONEFLOW_INCLUDE_DIR}\" DIRECTORY)\n    foreach(of_include_src_dir ${ONEFLOW_THIRD_PARTY_INCLUDE_DIRS})\n      if(of_include_src_dir MATCHES \"/include$\")\n        # two slashes are required here, although the CMake docs state that one slash is enough\n        set(of_include_src_dir \"${of_include_src_dir}//\")\n      endif()\n      install(\n        DIRECTORY ${of_include_src_dir}\n        DESTINATION ${ONEFLOW_INCLUDE_DIR}\n        COMPONENT oneflow_py_include\n        EXCLUDE_FROM_ALL)\n    endforeach()\n  endif(BUILD_PYTHON)\nelse()\n  add_custom_target(prepare_oneflow_third_party ALL)\nendif()\n"
  },
  {
    "path": "cmake/threading.cmake",
    "content": "foreach(threading_runtime_item ${CPU_THREADING_RUNTIMES})\n  if(NOT ${threading_runtime_item} MATCHES \"^(TBB|OMP)$\")\n    message(FATAL_ERROR \"Unsupported cpu threading runtime: ${threading_runtime_item}\")\n  endif()\n\n  if(${threading_runtime_item} STREQUAL \"OMP\")\n    # Reference:\n    # https://releases.llvm.org/11.0.0/tools/clang/docs/OpenMPSupport.html\n    if(\"${CMAKE_CXX_COMPILER_ID}\" STREQUAL \"Clang\")\n      if(\"${CMAKE_CXX_COMPILER_VERSION}\" VERSION_LESS 11)\n        message(\n          FATAL_ERROR\n            \"OpenMP is not supported with Clang 10 or older; please use TBB with '-DCPU_THREADING_RUNTIMES=TBB'.\"\n        )\n      endif()\n    endif()\n    find_package(OpenMP)\n    if(OPENMP_FOUND)\n      set(WITH_${threading_runtime_item} ON)\n      add_definitions(-DWITH_${threading_runtime_item})\n    endif()\n  else()\n    set(WITH_${threading_runtime_item} ON)\n    add_definitions(-DWITH_${threading_runtime_item})\n  endif()\nendforeach()\n"
  },
  {
    "path": "cmake/util.cmake",
    "content": "function(SHOW_VARIABLES)\n  get_cmake_property(_variableNames VARIABLES)\n  foreach(_variableName ${_variableNames})\n    message(STATUS \"${_variableName}=${${_variableName}}\")\n  endforeach()\nendfunction()\n\nmacro(write_file_if_different file_path content)\n  if(EXISTS ${file_path})\n    file(READ ${file_path} current_content)\n    # NOTE: it seems a cmake bug that \"content\" in this macro is not\n    # treated as a variable\n    if(NOT (current_content STREQUAL ${content}))\n      file(WRITE ${file_path} ${content})\n    endif()\n  else()\n    file(WRITE ${file_path} ${content})\n  endif()\nendmacro()\n\nmacro(copy_all_files_in_dir source_dir dest_dir target)\n  find_program(rsync rsync)\n  if(rsync)\n    add_custom_command(\n      TARGET ${target}\n      POST_BUILD\n      COMMAND\n        ${rsync}\n        # NOTE: the trailing slash of source_dir is needed.\n        # Reference: https://stackoverflow.com/a/56627246\n        ARGS -a --omit-dir-times --no-perms --no-owner --no-group --inplace ${source_dir}/\n        ${dest_dir})\n  else()\n    add_custom_command(TARGET ${target} POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory\n                                                           ${source_dir} ${dest_dir})\n  endif()\nendmacro()\n\nset(_COUNTER 0)\nmacro(copy_files file_paths source_dir dest_dir target)\n  find_program(rsync rsync)\n  if(rsync)\n    set(CACHE_FILELIST ${PROJECT_BINARY_DIR}/cached_filename_lists/cache_${_COUNTER})\n    math(EXPR _COUNTER \"${_COUNTER} + 1\")\n    file(WRITE ${CACHE_FILELIST} \"\")\n    foreach(file ${file_paths})\n      file(RELATIVE_PATH rel_path \"${source_dir}\" ${file})\n      file(APPEND ${CACHE_FILELIST} ${rel_path}\\n)\n    endforeach()\n    add_custom_command(\n      TARGET ${target} POST_BUILD\n      COMMAND ${rsync} ARGS -a --omit-dir-times --no-perms --no-owner --no-group --inplace\n              --files-from=${CACHE_FILELIST} ${source_dir} ${dest_dir})\n  else()\n    foreach(file ${file_paths})\n      file(RELATIVE_PATH rel_path \"${source_dir}\" ${file})\n      add_custom_command(TARGET ${target} POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different\n                                                             \"${file}\" \"${dest_dir}/${rel_path}\")\n    endforeach()\n  endif()\nendmacro()\n\nfunction(add_copy_headers_target)\n  cmake_parse_arguments(PARSED_ARGS \"\" \"NAME;SRC;DST;INDEX_FILE\" \"DEPS\" ${ARGN})\n  if(NOT PARSED_ARGS_NAME)\n    message(FATAL_ERROR \"name required\")\n  endif(NOT PARSED_ARGS_NAME)\n  if(NOT PARSED_ARGS_SRC)\n    message(FATAL_ERROR \"src required\")\n  endif(NOT PARSED_ARGS_SRC)\n  if(NOT PARSED_ARGS_DST)\n    message(FATAL_ERROR \"dst required\")\n  endif(NOT PARSED_ARGS_DST)\n  add_custom_target(\n    \"${PARSED_ARGS_NAME}_create_header_dir\" COMMAND ${CMAKE_COMMAND} -E make_directory\n                                                    \"${PARSED_ARGS_DST}\"\n    DEPENDS ${PARSED_ARGS_DEPS})\n\n  add_custom_target(\"${PARSED_ARGS_NAME}_copy_headers_to_destination\" ALL\n                    DEPENDS \"${PARSED_ARGS_NAME}_create_header_dir\")\n  file(GLOB_RECURSE headers \"${PARSED_ARGS_SRC}/*.h\")\n  file(GLOB_RECURSE cuda_headers \"${PARSED_ARGS_SRC}/*.cuh\")\n  file(GLOB_RECURSE hpp_headers \"${PARSED_ARGS_SRC}/*.hpp\")\n  list(APPEND headers ${cuda_headers})\n  list(APPEND headers ${hpp_headers})\n\n  foreach(header_file ${headers})\n    file(RELATIVE_PATH relative_file_path ${PARSED_ARGS_SRC} ${header_file})\n    add_custom_command(\n      TARGET 
\"${PARSED_ARGS_NAME}_copy_headers_to_destination\" PRE_BUILD\n      COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file}\n              \"${PARSED_ARGS_DST}/${relative_file_path}\")\n  endforeach()\n\n  if(PARSED_ARGS_INDEX_FILE)\n    file(STRINGS ${PARSED_ARGS_INDEX_FILE} inventory_headers)\n  endif(PARSED_ARGS_INDEX_FILE)\n  foreach(header_file ${inventory_headers})\n    add_custom_command(\n      TARGET \"${PARSED_ARGS_NAME}_copy_headers_to_destination\" PRE_BUILD\n      COMMAND ${CMAKE_COMMAND} -E copy_if_different \"${PARSED_ARGS_SRC}/${header_file}\"\n              \"${PARSED_ARGS_DST}/${header_file}\")\n  endforeach()\nendfunction()\n\nfunction(use_mirror)\n  set(ALIYUN_URL_PREFIX\n      \"https://oneflow-static.oss-cn-beijing.aliyuncs.com/third_party_mirror/https/\"\n      CACHE STRING \"URL prefix of Aliyun OSS mirror\")\n  cmake_parse_arguments(PARSED_ARGS \"\" \"VARIABLE;URL\" \"\" ${ARGN})\n\n  if((NOT PARSED_ARGS_VARIABLE) OR (NOT PARSED_ARGS_URL))\n    message(FATAL_ERROR \"VARIABLE or URL required\")\n  endif()\n\n  if(PARSED_ARGS_URL MATCHES \"file://\")\n    set(${PARSED_ARGS_VARIABLE} ${PARSED_ARGS_URL} PARENT_SCOPE)\n    return()\n  endif()\n  if(DEFINED THIRD_PARTY_MIRROR)\n    if(THIRD_PARTY_MIRROR STREQUAL \"aliyun\")\n      if(NOT PARSED_ARGS_URL MATCHES \"^https://\")\n        message(FATAL_ERROR \"URL should start with 'https://'\")\n      endif()\n      string(REPLACE \"https://\" ${ALIYUN_URL_PREFIX} MIRRORED_URL ${PARSED_ARGS_URL})\n      set(${PARSED_ARGS_VARIABLE} ${MIRRORED_URL} PARENT_SCOPE)\n      message(NOTICE \"-- fetch ${PARSED_ARGS_VARIABLE} using aliyun mirror ${MIRRORED_URL}\")\n    elseif(NOT THIRD_PARTY_MIRROR STREQUAL \"\")\n      message(FATAL_ERROR \"invalid key for third party mirror\")\n    endif()\n  endif()\nendfunction()\n\nmacro(set_mirror_url variable url)\n  set(${variable} ${url} ${ARGN})\n  use_mirror(VARIABLE ${variable} URL ${url})\nendmacro()\n\nmacro(set_mirror_url_with_hash variable url hash)\n  set_mirror_url(${variable} ${url} ${ARGN})\n  set(${variable}_HASH ${hash} ${ARGN})\nendmacro()\n\nfunction(check_cxx11_abi OUTPUT_VAR)\n  execute_process(\n    COMMAND ${CMAKE_COMMAND} -E echo \"#include <string>\\n void test(std::string){}\\n int main(){}\"\n    OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/temp.cpp)\n  try_compile(\n    COMPILE_SUCCESS ${CMAKE_CURRENT_BINARY_DIR}\n    ${CMAKE_CURRENT_BINARY_DIR}/temp.cpp\n    COMPILE_DEFINITIONS -D_GLIBCXX_USE_CXX11_ABI=1\n    COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/temp)\n  if(NOT COMPILE_SUCCESS)\n    message(FATAL_ERROR \"Detecting cxx11 availability failed. Please report to OneFlow developers.\")\n  endif()\n  execute_process(COMMAND nm ${CMAKE_CURRENT_BINARY_DIR}/temp COMMAND grep -q cxx11\n                  RESULT_VARIABLE RET_CODE)\n  if(RET_CODE EQUAL 0)\n    set(CXX11_ABI_AVAILABLE ON)\n  else()\n    set(CXX11_ABI_AVAILABLE OFF)\n  endif()\n  execute_process(COMMAND rm ${CMAKE_CURRENT_BINARY_DIR}/temp ${CMAKE_CURRENT_BINARY_DIR}/temp.cpp)\n  set(${OUTPUT_VAR} ${CXX11_ABI_AVAILABLE} PARENT_SCOPE)\nendfunction()\n\ninclude(CheckCXXCompilerFlag)\n\nfunction(target_try_compile_option target flag)\n  # We cannot check for -Wno-foo as this won't throw a warning so we must check for the -Wfoo option directly\n  # http://stackoverflow.com/questions/38785168/cc1plus-unrecognized-command-line-option-warning-on-any-other-warning\n  string(REGEX REPLACE \"^-Wno-\" \"-W\" checkedFlag ${flag})\n  string(REGEX REPLACE \"[-=]\" \"_\" varName CXX_FLAG${checkedFlag})\n  # Avoid double checks. 
A compiler will not magically support a flag it did not support before\n  if(NOT DEFINED ${varName}_SUPPORTED)\n    check_cxx_compiler_flag(${checkedFlag} ${varName}_SUPPORTED)\n  endif()\n  if(${varName}_SUPPORTED)\n    target_compile_options(${target} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${flag}>)\n    if(BUILD_CUDA)\n      if(\"${CMAKE_CXX_COMPILER_ID}\" STREQUAL \"Clang\" AND \"${CMAKE_CUDA_COMPILER_ID}\" STREQUAL\n                                                         \"Clang\")\n        target_compile_options(${target} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${flag}>)\n      endif()\n    endif()\n  endif()\nendfunction()\n\nfunction(target_try_compile_options target)\n  foreach(flag ${ARGN})\n    target_try_compile_option(${target} ${flag})\n  endforeach()\nendfunction()\n\nfunction(target_treat_warnings_as_errors target)\n  if(TREAT_WARNINGS_AS_ERRORS)\n    target_compile_options(${target} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-Werror>)\n    if(BUILD_CUDA)\n      # Only pass the flag when the cuda compiler is Clang, because cmake handles -Xcompiler incorrectly\n      if(\"${CMAKE_CUDA_COMPILER_ID}\" STREQUAL \"Clang\")\n        target_compile_options(${target} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Werror>)\n      endif()\n    endif()\n\n    # TODO: remove this once all deprecated calls are fixed\n    target_try_compile_options(${target} -Wno-error=deprecated-declarations)\n\n    # disable unused-* warnings across compile modes (a symbol may be unused in cpu.cmake, but used in cuda.cmake)\n    target_try_compile_options(\n      ${target} -Wno-error=unused-const-variable -Wno-error=unused-variable\n      -Wno-error=unused-local-typedefs -Wno-error=unused-private-field\n      -Wno-error=unused-lambda-capture)\n\n    # there are some strict-overflow warnings in oneflow/user/kernels/ctc_loss_kernel_util.cpp for an unknown reason; disable them for now\n    target_try_compile_options(${target} -Wno-error=strict-overflow)\n\n    target_try_compile_options(${target} -Wno-error=instantiation-after-specialization)\n\n    # disable for pointer operations of intrusive linked lists\n    target_try_compile_options(${target} -Wno-error=array-bounds)\n\n    target_try_compile_options(${target} -Wno-error=comment)\n\n    # disable visibility warnings related to https://github.com/Oneflow-Inc/oneflow/pull/3676.\n    target_try_compile_options(${target} -Wno-error=attributes)\n\n    # disable the error that XXX has no out-of-line virtual method definitions (its vtable would be emitted in every translation unit)\n    target_try_compile_options(${target} -Wno-error=weak-vtables)\n\n  endif()\nendfunction()\n\nfunction(set_compile_options_to_oneflow_target target)\n  target_treat_warnings_as_errors(${target})\n  target_compile_options(${target} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-Werror=return-type>)\n  target_compile_definitions(${target} PRIVATE ONEFLOW_CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE})\n  # the mangled names of `struct X` and `class X` differ in the MSVC ABI; remove this once Windows (MSVC/cl or clang-cl) is supported\n  target_try_compile_options(${target} -Wno-covered-switch-default)\n\n  set_target_properties(${target} PROPERTIES INSTALL_RPATH \"$ORIGIN/../lib\")\n\n  if(BUILD_CUDA)\n    if(\"${CMAKE_CUDA_COMPILER_ID}\" STREQUAL \"NVIDIA\")\n      target_compile_options(\n        ${target}\n        PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:\n                -Xcompiler\n                -Werror=return-type;\n                -Wno-deprecated-gpu-targets;\n                -Werror\n                cross-execution-space-call;\n                -Xcudafe\n                --diag_suppress=declared_but_not_referenced;\n                >)\n    elseif(\"${CMAKE_CUDA_COMPILER_ID}\" STREQUAL \"Clang\")\n      target_compile_options(\n        ${target}\n        PRIVATE\n          $<$<COMPILE_LANGUAGE:CUDA>:\n          -Werror=return-type;\n          # Suppress a warning from the cub library -- marking it as a system header does not seem to work for .cuh files\n          -Wno-pass-failed;\n          >)\n    else()\n      message(FATAL_ERROR \"Unknown CUDA compiler ${CMAKE_CUDA_COMPILER_ID}\")\n    endif()\n    # remove THRUST_IGNORE_CUB_VERSION_CHECK once the bundled cub is used\n    target_compile_definitions(${target} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:\n                                                 THRUST_IGNORE_CUB_VERSION_CHECK; >)\n  endif()\nendfunction()\n\nfunction(check_variable_defined variable)\n  if(NOT DEFINED ${variable})\n    message(FATAL_ERROR \"Variable ${variable} is not defined\")\n  endif()\nendfunction()\n\nfunction(checkDirAndAppendSlash)\n  set(singleValues DIR;OUTPUT)\n  set(prefix ARG)\n  cmake_parse_arguments(PARSE_ARGV 0 ${prefix} \"${noValues}\" \"${singleValues}\" \"${multiValues}\")\n\n  if(\"${${prefix}_DIR}\" STREQUAL \"\" OR \"${${prefix}_DIR}\" STREQUAL \"/\")\n    message(FATAL_ERROR \"empty path found: ${${prefix}_DIR}\")\n  else()\n    set(${${prefix}_OUTPUT} \"${${prefix}_DIR}/\" PARENT_SCOPE)\n  endif()\n\nendfunction()\n\nfunction(mark_targets_as_system)\n  # TODO(daquexian): update this function once https://gitlab.kitware.com/cmake/cmake/-/merge_requests/7308\n  # and its following PRs are merged in cmake v3.25.\n  foreach(target ${ARGV})\n    get_target_property(include_dir ${target} INTERFACE_INCLUDE_DIRECTORIES)\n    set_target_properties(${target} PROPERTIES INTERFACE_SYSTEM_INCLUDE_DIRECTORIES\n                                               \"${include_dir}\")\n  endforeach()\nendfunction()\n\nif(NOT BUILD_SHARED_LIBS)\n  if(APPLE)\n    set(ALL_ARCHIVE_BEGIN -Wl,-force_load)\n    set(ALL_ARCHIVE_END)\n  elseif(UNIX)\n    set(ALL_ARCHIVE_BEGIN -Wl,--whole-archive)\n    set(ALL_ARCHIVE_END -Wl,--no-whole-archive)\n  endif()\nendif()\n"
  },
  {
    "path": "dev-requirements.txt",
    "content": "black==19.10b0; python_version >= \"3.6\"\nclick==8.0.0; python_version >= \"3.6\" # https://github.com/psf/black/issues/2964\nnumpy>=1.21.6, <2.0\nprotobuf>=3.9.2, <4.0\nwheel\ntqdm\nrequests\njinja2\nopencv-python; python_version >= \"3.9\" and sys_platform != 'darwin' and platform_machine != 'aarch64'\nopencv-python==4.2.0.34; python_version < '3.9' and sys_platform != 'darwin' and platform_machine != 'aarch64'\nPyYAML>=5.1\npillow\ndataclasses; python_version<\"3.7\"\ncmakelang==0.6.13\npytest-xdist\npytest-repeat\nrich\nportalocker\ntyping-extensions>=4.0.0, <5.0\n"
  },
  {
    "path": "docker/build/Dockerfile",
    "content": "# warning: never share the container image this dockerfile produces\nARG CUDA=10.0\n\nFROM nvidia/cuda:${CUDA}-cudnn7-devel-centos7\nRUN yum-config-manager --add-repo https://yum.repos.intel.com/setup/intelproducts.repo && \\\n    rpm --import https://yum.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB\nRUN yum update -y && yum install -y epel-release\nRUN yum update -y && yum install -y rdma-core-devel \\\n    nasm \\\n    cmake3 \\\n    make \\\n    git \\\n    centos-release-scl \\\n    intel-mkl-2020.0-088 \\\n    zlib-devel \\\n    curl-devel \\\n    which\n\nRUN ln -sf /usr/bin/cmake3 /usr/bin/cmake\n\nRUN mkdir -p /tmp/download/cmake-extracted && \\\n    cd /tmp/download && \\\n    curl --location https://github.com/Kitware/CMake/releases/download/v3.14.0/cmake-3.14.0.tar.gz --output cmake.tar.gz && \\\n    tar -xvzf cmake.tar.gz --directory cmake-extracted && \\\n    cd cmake-extracted/* && \\\n    mkdir /cmake-install\nRUN cd /tmp/download/cmake-extracted/* && \\\n    cmake . -DCMAKE_USE_SYSTEM_CURL=ON -DCMAKE_INSTALL_PREFIX=/cmake-install && \\\n    make -j $(nproc) && \\\n    make install\nENV PATH=\"/cmake-install/bin:${PATH}\"\n\nARG USE_PYTHON_3_OR_2=3\n\nRUN if [ \"${USE_PYTHON_3_OR_2}\" -eq 2 ] ; then yum update -y \\\n    && yum install -y python-devel.x86_64 \\\n    && curl https://bootstrap.pypa.io/get-pip.py --output ./get-pip.py \\\n    && python ./get-pip.py \\\n    && rm get-pip.py \\\n    && pip install numpy==1.12.0 protobuf ; fi\n\nCOPY dev-requirements.txt /workspace/dev-requirements.txt\n\nRUN if [ \"${USE_PYTHON_3_OR_2}\" -eq 3 ] ; then yum update -y \\\n    && yum install -y rh-python36 python36-devel.x86_64 python36-devel \\\n    && python3 -m ensurepip \\\n    && pip3 install -r /workspace/dev-requirements.txt; fi\n\nWORKDIR /workspace/build\n\nCOPY cmake /workspace/cmake\nCOPY CMakeLists.txt /workspace/CMakeLists.txt\n\n# BUILD DEPENDENCY\nCOPY build/third_party /workspace/build/third_party\nRUN cmake -DTHIRD_PARTY=ON -DCMAKE_BUILD_TYPE=Release -DRELEASE_VERSION=ON .. && make -j\n\n# BUILD ONEFLOW\nCOPY oneflow /workspace/oneflow\nCOPY tools /workspace/tools\n\nRUN export LD_LIBRARY_PATH=/opt/intel/lib/intel64_lin:/opt/intel/mkl/lib/intel64:$LD_LIBRARY_PATH; \\\n    cmake -DTHIRD_PARTY=OFF .. && make -j $(nproc) ;\n\n## BUILD WHEEL\nWORKDIR /workspace\nRUN pip${USE_PYTHON_3_OR_2} install wheel\nCOPY setup.py /workspace/setup.py\nRUN python${USE_PYTHON_3_OR_2} setup.py bdist_wheel\nRUN pip${USE_PYTHON_3_OR_2} install /workspace/dist/*.whl\n\nRUN rm -rf oneflow third_party cmake CMakeLists.txt\n"
  },
  {
    "path": "docker/build/build-ubuntu.sh",
    "content": "docker build \\\n  --rm \\\n  -t oneflow-build:ubuntu -f docker/build/build.ubuntu.dockerfile .\n"
  },
  {
    "path": "docker/build/build.sh",
    "content": "docker build \\\n  --rm \\\n  -t oneflow-build -f docker/build/Dockerfile .\n"
  },
  {
    "path": "docker/build/build.ubuntu.dockerfile",
    "content": "ARG CUDA=10.0\nARG UBUNTU_VERSION=16.04\nFROM nvidia/cuda:${CUDA}-cudnn7-devel-ubuntu${UBUNTU_VERSION}\n\nUSER 0\n\nRUN apt-get update && \\\n    apt-get install -y apt-transport-https && \\\n    apt-get install -y --no-install-recommends \\\n    curl \\\n    nasm \\\n    make \\\n    git \\\n    gcc \\\n    g++ \\\n    libopenblas-dev \\\n    python3-dev\n\n# speed up pip install in China\nENV TUNA_PIP_INSTALL=\" -i https://pypi.tuna.tsinghua.edu.cn/simple\"\n\nCOPY dev-requirements.txt /workspace/dev-requirements.txt\n\nRUN curl https://bootstrap.pypa.io/get-pip.py --output ./get-pip.py \\\n    && python3 ./get-pip.py \\\n    && pip3 install $TUNA_PIP_INSTALL cmake \\\n    && pip3 install $TUNA_PIP_INSTALL -r /workspace/dev-requirements.txt\n\nWORKDIR /workspace/build\n\nCOPY cmake /workspace/cmake\nCOPY CMakeLists.txt /workspace/CMakeLists.txt\n\n# BUILD DEPENDENCY\nCOPY build/third_party /workspace/build/third_party\nRUN cmake -DTHIRD_PARTY=ON -DONEFLOW=OFF -DCMAKE_BUILD_TYPE=Release .. && make -j$(nproc)\n\n# BUILD ONEFLOW\nCOPY oneflow /workspace/oneflow\nCOPY tools /workspace/tools\n\nRUN cmake -DTHIRD_PARTY=OFF -DONEFLOW=ON .. && make -j$(nproc) of_pyscript_copy\nRUN cmake -DTHIRD_PARTY=OFF -DONEFLOW=ON .. && make -j$(nproc)\n\n# BUILD WHEEL\nWORKDIR /workspace\nCOPY setup.py /workspace/setup.py\nRUN python3 setup.py bdist_wheel\nRUN pip3 install /workspace/dist/*.whl\n\nRUN rm -rf oneflow third_party cmake CMakeLists.txt\n"
  },
  {
    "path": "docker/build/launch.sh",
    "content": "docker run -it --rm \\\n\t-v /dataset:/dataset/ \\\n\toneflow-build \n"
  },
  {
    "path": "docker/build/test.sh",
    "content": "docker run -it --rm \\\n\t-v /dataset:/dataset/ \\\n\toneflow-build \\\n    python3 -c \"import oneflow\"\n"
  },
  {
    "path": "docker/ci/base/Dockerfile",
    "content": "# warning: never share the container image this dockerfile produces\nARG CUDA=10.0\n\nFROM nvidia/cuda:${CUDA}-cudnn7-devel-centos7\n\nCOPY dev-requirements.txt /workspace/dev-requirements.txt\nRUN yum-config-manager --add-repo https://yum.repos.intel.com/setup/intelproducts.repo && \\\n    rpm --import https://yum.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB && \\\n    yum update -y && yum install -y epel-release && \\\n    yum update -y && yum install -y rdma-core-devel \\\n    nasm \\\n    make \\\n    git \\\n    centos-release-scl \\\n    intel-mkl-2020.0-088 \\\n    zlib-devel \\\n    curl-devel \\\n    which \\\n    rh-python36 python36-devel.x86_64 python36-devel && \\\n    python3 -m ensurepip && \\\n    pip3 install -r /workspace/dev-requirements.txt && \\\n    yum clean all\n\nRUN mkdir -p /tmp/download && \\\n    mkdir /cmake-extracted && \\\n    cd /tmp/download && \\\n    curl --location https://github.com/Kitware/CMake/releases/download/v3.14.0/cmake-3.14.0-Linux-x86_64.tar.gz --output cmake.tar.gz && \\\n    tar -xvzf cmake.tar.gz --directory /cmake-extracted && \\\n    mv /cmake-extracted/* /cmake-extracted/cmake-install && \\\n    rm -rf /tmp/download\n\nENV PATH=\"/cmake-extracted/cmake-install/bin:${PATH}\"\n"
  },
  {
    "path": "docker/ci/fmt/Dockerfile",
    "content": "FROM python:3.7\nRUN curl https://oneflow-static.oss-cn-beijing.aliyuncs.com/bin/clang-format -o /usr/local/bin/clang-format && chmod +x /usr/local/bin/clang-format\nRUN apt update && apt install -y libncurses5\n"
  },
  {
    "path": "docker/ci/fmt/build.sh",
    "content": "set -ex\ncd docker/ci/fmt\ndocker build -t oneflow-fmt .\n"
  },
  {
    "path": "docker/ci/make/Dockerfile",
    "content": "ARG from\nFROM ${from}\nWORKDIR /workspace/build\n\n# BUILD ONEFLOW\nCOPY oneflow /workspace/oneflow\nCOPY tools /workspace/tools\n\nRUN export LD_LIBRARY_PATH=/opt/intel/lib/intel64_lin:/opt/intel/mkl/lib/intel64:$LD_LIBRARY_PATH; \\\n    cmake -DTHIRD_PARTY=OFF -DONEFLOW=ON .. && make -j $(nproc) ;\n\n## BUILD WHEEL\nWORKDIR /workspace\nCOPY setup.py /workspace/setup.py\nRUN python3 setup.py bdist_wheel\n\nFROM centos:7\nWORKDIR /workspace\nCOPY --from=0 /workspace/dist/*.whl .\nCOPY --from=0 /workspace/build/bin/oneflow_testexe .\n"
  },
  {
    "path": "docker/ci/test/Dockerfile",
    "content": "FROM ufoym/deepo\n\nRUN apt remove openmpi-common libfabric1 openmpi-bin librdmacm1:amd64 libopenmpi2 libopenmpi2:amd64 -y\nENV MOFED_DIR MLNX_OFED_LINUX-4.3-1.0.1.0-ubuntu18.04-x86_64\nRUN wget https://oneflow-static.oss-cn-beijing.aliyuncs.com/deps/${MOFED_DIR}.tgz && \\\n    tar -xzvf ${MOFED_DIR}.tgz && \\\n    ${MOFED_DIR}/mlnxofedinstall --user-space-only --without-fw-update --all -q --force && \\\n    cd .. && \\\n    rm -rf ${MOFED_DIR} && \\\n    rm -rf *.tgz\n\nRUN apt update && apt install -y --no-install-recommends gdb openssh-server openssh-client\n\nRUN echo 'ALL ALL=(ALL) NOPASSWD: ALL' >> /etc/sudoers\nRUN sed -i 's/PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config\n\nCOPY requirements.txt .\nRUN pip3 install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt\n"
  },
  {
    "path": "docker/ci/test/build.sh",
    "content": "set -ex\ntest_img_dir=\"$(dirname \"${BASH_SOURCE[0]}\")\"\ntest_img_dir=\"$(realpath \"${test_img_dir}\")\"\ncd $test_img_dir\n\nproxy_args=\"\"\nproxy_args+=\" --network=host\"\nproxy_args+=\" --build-arg HTTP_PROXY=${HTTP_PROXY}\"\nproxy_args+=\" --build-arg HTTPS_PROXY=${HTTPS_PROXY}\"\nproxy_args+=\" --build-arg http_proxy=${http_proxy}\"\nproxy_args+=\" --build-arg https_proxy=${https_proxy}\"\n\nimg_tag=\"oneflow-test:0.2\" # update me if any related files are changed\nif [[ \"$(docker images -q ${img_tag} 2> /dev/null)\" == \"\" ]]; then\n  docker build --rm $proxy_args \\\n    -t $img_tag .\nfi\n"
  },
  {
    "path": "docker/ci/test/launch.sh",
    "content": "docker run --shm-size=8g --privileged --network=host --rm -it -w $PWD -v $PWD:$PWD -v /dataset:/dataset -v /model_zoo:/model_zoo \\\n    -v $HOME:$HOME \\\n    oneflow-test:0.2 \\\n    bash\n"
  },
  {
    "path": "docker/ci/test/requirements.txt",
    "content": "sphinx==3.5.4\njinja2<3.1\nrecommonmark==0.6.0\nfuro==2021.4.11b34\nsphinx-copybutton==0.5.0\n# dependencies above must be identical to docs/requirements.txt\npycocotools\nopencv-python==4.2.0.34\nscipy\npillow\ntensorflow-addons==0.9.1\nhttps://oneflow-static.oss-cn-beijing.aliyuncs.com/pipindex/pipindex-0.1.3-py2.py3-none-any.whl\n"
  },
  {
    "path": "docker/ci/test-v2/Dockerfile",
    "content": "FROM pytorch/pytorch:1.9.0-cuda10.2-cudnn7-runtime\nCOPY sources.list /etc/apt/sources.list\nRUN apt update && apt install ffmpeg libsm6 libxext6 gdb gcc g++ -y --no-install-recommends\nCOPY requirements.txt .\nRUN python3 -m pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt\n"
  },
  {
    "path": "docker/ci/test-v2/build.sh",
    "content": "set -ex\ntest_img_dir=\"$(dirname \"${BASH_SOURCE[0]}\")\"\ntest_img_dir=\"$(realpath \"${test_img_dir}\")\"\ncd $test_img_dir\n\nproxy_args=\"\"\nproxy_args+=\" --network=host\"\nproxy_args+=\" --build-arg HTTP_PROXY=${HTTP_PROXY}\"\nproxy_args+=\" --build-arg HTTPS_PROXY=${HTTPS_PROXY}\"\nproxy_args+=\" --build-arg http_proxy=${http_proxy}\"\nproxy_args+=\" --build-arg https_proxy=${https_proxy}\"\n\nimg_tag=\"oneflow-test-v2:0.1\" # update me if any related files are changed\nif [[ \"$(docker images -q ${img_tag} 2> /dev/null)\" == \"\" ]]; then\n  docker build --rm $proxy_args \\\n    -t $img_tag .\nfi\n"
  },
  {
    "path": "docker/ci/test-v2/requirements.txt",
    "content": "sphinx==3.5.4\njinja2<3.1\nrecommonmark==0.6.0\nfuro==2021.4.11b34\nsphinx-copybutton==0.5.0\n# dependencies above must be identical to docs/requirements.txt\npycocotools\nopencv-python==4.2.0.34\nscipy\npillow\nhttps://oneflow-static.oss-cn-beijing.aliyuncs.com/pipindex/pipindex-0.1.3-py2.py3-none-any.whl\n"
  },
  {
    "path": "docker/ci/test-v2/sources.list",
    "content": "# Source-code mirrors are commented out by default to speed up apt update; uncomment them if needed\ndeb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal main restricted universe multiverse\n# deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal main restricted universe multiverse\ndeb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-updates main restricted universe multiverse\n# deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-updates main restricted universe multiverse\ndeb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-backports main restricted universe multiverse\n# deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-backports main restricted universe multiverse\ndeb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-security main restricted universe multiverse\n# deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-security main restricted universe multiverse\n\n# Pre-release repositories; enabling them is not recommended\n# deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-proposed main restricted universe multiverse\n# deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-proposed main restricted universe multiverse\n"
  },
  {
    "path": "docker/ci/third_party/Dockerfile",
    "content": "ARG from\nFROM ${from}\nWORKDIR /workspace/build\n\nCOPY cmake /workspace/cmake\nCOPY CMakeLists.txt /workspace/CMakeLists.txt\n\n# BUILD DEPENDENCY\nCOPY build/third_party /workspace/build/third_party\nRUN export LD_LIBRARY_PATH=/opt/intel/lib/intel64_lin:/opt/intel/mkl/lib/intel64:$LD_LIBRARY_PATH; \\\n    cmake -DTHIRD_PARTY=ON -DONEFLOW=OFF -DCMAKE_BUILD_TYPE=Release -DRELEASE_VERSION=ON .. && make -j prepare_oneflow_third_party\n"
  },
  {
    "path": "docker/package/manylinux/CentOS-Base.repo",
    "content": "# CentOS-Base.repo\n#\n# From https://mirror.tuna.tsinghua.edu.cn/help/centos/\n#\n# The mirror system uses the connecting IP address of the client and the\n# update status of each mirror to pick mirrors that are updated to and\n# geographically close to the client.  You should use this for CentOS updates\n# unless you are manually picking other mirrors.\n#\n# If the mirrorlist= does not work for you, as a fall back you can try the\n# remarked out baseurl= line instead.\n#\n#\n\n\n[base]\nname=CentOS-$releasever - Base\nbaseurl=https://mirrors.tuna.tsinghua.edu.cn/centos/$releasever/os/$basearch/\n        http://mirrors.aliyun.com/centos/$releasever/os/$basearch/\n        http://mirrors.aliyuncs.com/centos/$releasever/os/$basearch/\n#mirrorlist=http://mirrorlist.centos.org/?release=$releasever&arch=$basearch&repo=os\nenabled=1\ngpgcheck=1\ngpgkey=file:///etc/pki/rpm-gpg/RPM-GPG-KEY-7\n\n#released updates\n[updates]\nname=CentOS-$releasever - Updates\nbaseurl=https://mirrors.tuna.tsinghua.edu.cn/centos/$releasever/updates/$basearch/\n        http://mirrors.aliyun.com/centos/$releasever/updates/$basearch/\n        http://mirrors.aliyuncs.com/centos/$releasever/updates/$basearch/\n#mirrorlist=http://mirrorlist.centos.org/?release=$releasever&arch=$basearch&repo=updates\nenabled=1\ngpgcheck=1\ngpgkey=file:///etc/pki/rpm-gpg/RPM-GPG-KEY-7\n\n\n\n#additional packages that may be useful\n[extras]\nname=CentOS-$releasever - Extras\nbaseurl=https://mirrors.tuna.tsinghua.edu.cn/centos/$releasever/extras/$basearch/\n        http://mirrors.aliyun.com/centos/$releasever/extras/$basearch/\n        http://mirrors.aliyuncs.com/centos/$releasever/extras/$basearch/\n#mirrorlist=http://mirrorlist.centos.org/?release=$releasever&arch=$basearch&repo=extras\nenabled=1\ngpgcheck=1\ngpgkey=file:///etc/pki/rpm-gpg/RPM-GPG-KEY-7\n\n\n\n#additional packages that extend functionality of existing packages\n[centosplus]\nname=CentOS-$releasever - Plus\nbaseurl=https://mirrors.tuna.tsinghua.edu.cn/centos/$releasever/centosplus/$basearch/\n        http://mirrors.aliyun.com/centos/$releasever/centosplus/$basearch/\n        http://mirrors.aliyuncs.com/centos/$releasever/centosplus/$basearch/\n#mirrorlist=http://mirrorlist.centos.org/?release=$releasever&arch=$basearch&repo=centosplus\ngpgcheck=1\nenabled=0\ngpgkey=file:///etc/pki/rpm-gpg/RPM-GPG-KEY-7\n"
  },
  {
    "path": "docker/package/manylinux/CentOS7-Base-163.repo",
    "content": "# CentOS-Base.repo\n#\n# The mirror system uses the connecting IP address of the client and the\n# update status of each mirror to pick mirrors that are updated to and\n# geographically close to the client.  You should use this for CentOS updates\n# unless you are manually picking other mirrors.\n#\n# If the mirrorlist= does not work for you, as a fall back you can try the\n# remarked out baseurl= line instead.\n#\n#\n[base]\nname=CentOS-$releasever - Base - 163.com\n#mirrorlist=http://mirrorlist.centos.org/?release=$releasever&arch=$basearch&repo=os\nbaseurl=http://mirrors.163.com/centos/$releasever/os/$basearch/\ngpgcheck=1\ngpgkey=http://mirrors.163.com/centos/RPM-GPG-KEY-CentOS-7\n\n#released updates\n[updates]\nname=CentOS-$releasever - Updates - 163.com\n#mirrorlist=http://mirrorlist.centos.org/?release=$releasever&arch=$basearch&repo=updates\nbaseurl=http://mirrors.163.com/centos/$releasever/updates/$basearch/\ngpgcheck=1\ngpgkey=http://mirrors.163.com/centos/RPM-GPG-KEY-CentOS-7\n\n#additional packages that may be useful\n[extras]\nname=CentOS-$releasever - Extras - 163.com\n#mirrorlist=http://mirrorlist.centos.org/?release=$releasever&arch=$basearch&repo=extras\nbaseurl=http://mirrors.163.com/centos/$releasever/extras/$basearch/\ngpgcheck=1\ngpgkey=http://mirrors.163.com/centos/RPM-GPG-KEY-CentOS-7\n\n#additional packages that extend functionality of existing packages\n[centosplus]\nname=CentOS-$releasever - Plus - 163.com\nbaseurl=http://mirrors.163.com/centos/$releasever/centosplus/$basearch/\ngpgcheck=1\nenabled=0\ngpgkey=http://mirrors.163.com/centos/RPM-GPG-KEY-CentOS-7\n"
  },
  {
    "path": "docker/package/manylinux/Dockerfile",
    "content": "ARG from\nFROM ${from}\nARG use_tuna_yum=0\nARG pip_args=\"\"\nARG bazel_url=\"https://github.com/bazelbuild/bazel/releases/download/3.4.1/bazel-3.4.1-linux-x86_64\"\nLABEL maintainer=\"OneFlow Maintainers\"\n\n# manylinux2014\nENV AUDITWHEEL_ARCH x86_64\nENV AUDITWHEEL_PLAT manylinux2014_$AUDITWHEEL_ARCH\nENV LC_ALL en_US.UTF-8\nENV LANG en_US.UTF-8\nENV LANGUAGE en_US.UTF-8\nENV PATH $PATH:/usr/local/bin\nENV LD_LIBRARY_PATH /usr/local/lib64:/usr/local/lib\nENV PKG_CONFIG_PATH /usr/local/lib/pkgconfig\n\n# use tuna mirror\nCOPY docker/package/manylinux/CentOS7-Base-163.repo /tmp/CentOS-Base.repo\nRUN if [ \"${use_tuna_yum}\" = \"1\" ]; then mv /tmp/CentOS-Base.repo /etc/yum.repos.d/ && yum makecache ; fi\n\n# To speed up docker image builds, disable the cuda repo:\n# on 10.1, the cuda yum repo would update cublas to 10.2 and break the build\nRUN yum-config-manager --disable cuda nvidia-ml\n\nARG MANYLINUX_SHA=b634044\nRUN yum -y install unzip && curl -L -o manylinux.zip https://github.com/Oneflow-Inc/manylinux/archive/${MANYLINUX_SHA}.zip && unzip manylinux.zip -d tmp && cp -r tmp/*/docker/build_scripts /build_scripts && bash build_scripts/build.sh && rm -r build_scripts tmp manylinux.zip\n\nENV SSL_CERT_FILE=/opt/_internal/certs.pem\n# manylinux2014 end\n\nRUN yum-config-manager --add-repo https://yum.repos.intel.com/oneapi && \\\n    rpm --import https://yum.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB && \\\n    yum update -y && yum install -y epel-release && \\\n    yum -y install centos-release-scl && \\\n    yum install -y intel-oneapi-mkl-devel-2021.2.0 nasm rdma-core-devel devtoolset-7-gcc* rsync gdb\n\nRUN /opt/python/cp35-cp35m/bin/pip install $pip_args -U cmake==3.18.4.post1 && ln -s /opt/_internal/cpython-3.5.9/bin/cmake /usr/bin/cmake\n\nRUN mkdir -p /tmp && cd /tmp && \\\n    curl -L -o patchelf-src.zip \\\n    https://github.com/Oneflow-Inc/patchelf/archive/64bf5388ef7d45d3697c4aadbd3f5d7d68a22aa3.zip && \\\n    unzip patchelf-src.zip && cd patchelf-* && ./bootstrap.sh && ./configure && make -j`nproc` && \\\n    make install && cd .. && rm -rf patchelf-*\n\nRUN curl -L $bazel_url -o /usr/local/bin/bazel \\\n    && chmod +x /usr/local/bin/bazel \\\n    && bazel\n\nCOPY dev-requirements.txt /tmp/dev-requirements.txt\nRUN /opt/python/cp36-cp36m/bin/pip install $pip_args -r /tmp/dev-requirements.txt --user \\\n    && /opt/python/cp37-cp37m/bin/pip install $pip_args -r /tmp/dev-requirements.txt --user \\\n    && /opt/python/cp38-cp38/bin/pip install $pip_args -r /tmp/dev-requirements.txt --user \\\n    && rm /tmp/dev-requirements.txt\n"
  },
  {
    "path": "docker/package/manylinux/README.md",
    "content": "# Building OneFlow wheel packages with docker\n\n### Create the docker container\n\nRun this in the OneFlow source root directory:\n```\ndocker build -f docker/package/manylinux/Dockerfile --build-arg from=nvidia/cuda:10.2-cudnn7-devel-centos7 -t oneflow:manylinux2014-cuda10.2 .\n```\n\n### Build manylinux python wheels\n\nA manylinux2014 (centos7) + cuda10.2 Dockerfile is provided here, with the libraries required to compile oneflow already installed. Assuming you have used the Dockerfile to build a docker image called oneflow:manylinux2014-cuda10.2, simply run the following in the oneflow source directory:\n\n```bash\ndocker run --rm -it -v `pwd`:/oneflow-src -w /oneflow-src oneflow:manylinux2014-cuda10.2\n```\n\nIf you prefer to operate inside docker:\n\n```bash\ndocker run --rm -it -v `pwd`:/oneflow-src -w /oneflow-src oneflow:manylinux2014-cuda10.2 bash\n```\n\n```bash\n/oneflow-src/docker/package/manylinux/build_wheel.sh --python3.6 --wheel-dir /oneflow-src/wheel-test\n```\n\nThis executes build_wheel.sh inside the docker image to build oneflow manylinux2014 wheels for Python 3.5 through Python 3.8. The generated packages are placed in the wheelhouse/ folder under the oneflow source directory.\n\n#### Notes\n\n1. When running `docker run`, you may need to add the `-e http_proxy=$http_proxy -e https_proxy=$https_proxy` arguments so that the container uses the host's proxy, avoiding network errors while compiling the third-party libraries.\n\n2. Whenever `cmake -DTHIRD_PARTY=ON ..` is run, oneflow itself is recompiled from scratch. So if the third-party libraries have already been compiled by the docker container and you only want an incremental build of oneflow itself this time, use:\n\n    ```bash\n    docker run --rm -it -v `pwd`:/oneflow-src oneflow:manylinux2014-cuda10.2 /oneflow-src/docker/package/manylinux/build_wheel.sh --skip-third-party\n    ```\n\n   This passes a `--skip-third-party` argument to build_wheel.sh, which skips compiling the third-party libraries.\n\n3. To build packages only for certain python versions, e.g. python3.5, use:\n\n    ```bash\n    docker run --rm -it -v `pwd`:/oneflow-src oneflow:manylinux2014-cuda10.2 /oneflow-src/docker/package/manylinux/build_wheel.sh --python3.5\n    ```\n\n    The supported flags are `--python3.5`, `--python3.6`, `--python3.7` and `--python3.8`; pass several of them to build several versions. If no version flag is passed, packages for all python versions are built.\n\n4. To customize the cmake arguments used when compiling oneflow, write the cmake arguments out directly, e.g.:\n\n    ```bash\n    docker run --rm -it -v `pwd`:/oneflow-src oneflow:manylinux2014-cuda10.2 /oneflow-src/docker/package/manylinux/build_wheel.sh -DWITH_XLA=ON\n    ```\n"
  },
  {
    "path": "docker/package/manylinux/build_wheel.py",
    "content": "import os\nimport subprocess\nimport tempfile\nfrom pathlib import Path\nimport getpass\nimport uuid\n\n\ndef get_arg_env(env_var_name: str, mode=\"run\"):\n    val = os.getenv(env_var_name)\n    assert val, f\"system environment variable {env_var_name} is empty or not set\"\n    if mode == \"run\":\n        return f\"--env {env_var_name}={val}\"\n    elif mode == \"build\":\n        return f\"--build-arg {env_var_name}={val}\"\n    else:\n        raise ValueError(f\"{mode} not supported\")\n\n\ndef get_proxy_build_args():\n    proxy_build_args = []\n    if os.getenv(\"HTTP_PROXY\"):\n        for v in [\"HTTP_PROXY\", \"HTTPS_PROXY\"]:\n            proxy_build_args.append(get_arg_env(v, mode=\"build\"))\n    if os.getenv(\"http_proxy\"):\n        for v in [\"http_proxy\", \"https_proxy\"]:\n            proxy_build_args.append(get_arg_env(v, mode=\"build\"))\n    return \" \".join(proxy_build_args)\n\n\ndef get_proxy_env_args():\n    proxy_build_args = []\n    if os.getenv(\"HTTP_PROXY\"):\n        for v in [\"HTTP_PROXY\", \"HTTPS_PROXY\"]:\n            proxy_build_args.append(get_arg_env(v))\n    if os.getenv(\"http_proxy\"):\n        for v in [\"http_proxy\", \"https_proxy\"]:\n            proxy_build_args.append(get_arg_env(v))\n    return \" \".join(proxy_build_args)\n\n\ndef build_img(\n    cuda_version,\n    oneflow_src_dir,\n    use_aliyun_mirror,\n    use_tuna,\n    use_system_proxy,\n    img_tag,\n    dry,\n):\n    cudnn_version = 7\n    if str(cuda_version).startswith(\"11\"):\n        cudnn_version = 8\n    cuda_version_img = cuda_version\n    if cuda_version == \"11.2\":\n        cuda_version_img = \"11.2.2\"\n    if cuda_version == \"11.1\":\n        cuda_version_img = \"11.1.1\"\n    if cuda_version == \"11.0\":\n        cuda_version_img = \"11.0.3\"\n    from_img = f\"nvidia/cuda:{cuda_version_img}-cudnn{cudnn_version}-devel-centos7\"\n    tuna_build_arg = \"\"\n    if use_tuna:\n        tuna_build_arg = '--build-arg use_tuna_yum=1 --build-arg pip_args=\"-i https://mirrors.aliyun.com/pypi/simple\"'\n    if use_aliyun_mirror:\n        tuna_build_arg += ' --build-arg bazel_url=\"https://oneflow-static.oss-cn-beijing.aliyuncs.com/deps/bazel-3.4.1-linux-x86_64\"'\n\n    proxy_build_arg = get_proxy_build_args() if use_system_proxy else \"\"\n    cmd = f\"docker build -f docker/package/manylinux/Dockerfile {proxy_build_arg} {tuna_build_arg} --build-arg from={from_img} -t {img_tag} .\"\n    print(cmd)\n    if dry == False:\n        subprocess.check_call(cmd, cwd=oneflow_src_dir, shell=True)\n\n\ndef common_cmake_args(cache_dir=None, extra_oneflow_cmake_args=None):\n    assert cache_dir\n    ret = \"\"\n    if (\n        not extra_oneflow_cmake_args\n        or \"-DCMAKE_BUILD_TYPE\" not in extra_oneflow_cmake_args\n    ):\n        ret += \" -DCMAKE_BUILD_TYPE=Release\"\n    if not extra_oneflow_cmake_args or \"-DBUILD_RDMA\" not in extra_oneflow_cmake_args:\n        ret += \" -DBUILD_RDMA=ON\"\n    third_party_install_dir = os.path.join(cache_dir, \"build-third-party-install\")\n    ret += f\" -DTHIRD_PARTY_DIR={third_party_install_dir}\"\n    return ret\n\n\ndef get_build_dir_arg(cache_dir, oneflow_src_dir):\n    # NOTE: mounting a cached build dir is disabled for now via this early return;\n    # the code below is kept for reference\n    return \"\"\n    build_dir_real = os.path.join(cache_dir, \"build\")\n    build_dir_mount = os.path.join(oneflow_src_dir, \"build\")\n    return f\"-v {build_dir_real}:{build_dir_mount}\"\n\n\ndef force_rm_dir(dir_to_clean):\n    print(\"cleaning:\", dir_to_clean)\n    assert dir_to_clean\n    clean_cmd = f\"docker run --network=host --rm -v {dir_to_clean}:{dir_to_clean} -w 
{dir_to_clean} busybox rm -rf {dir_to_clean}/*\"\n    subprocess.check_call(clean_cmd, shell=True)\n\n\ndef create_tmp_bash_and_run(docker_cmd, img, bash_cmd, bash_args, bash_wrap, dry):\n    with tempfile.NamedTemporaryFile(mode=\"w+\", encoding=\"utf-8\") as wrapper_f:\n        with tempfile.NamedTemporaryFile(mode=\"w+\", encoding=\"utf-8\") as f:\n            w_name = \"/host\" + wrapper_f.name\n            f_name = \"/host\" + f.name\n            bash_cmd = \"PATH=/opt/python/cp37-cp37m/bin:$PATH\\n\" + bash_cmd\n            f.write(bash_cmd)\n            f.flush()\n            wrapped = f\"\"\"\n{bash_wrap}\nbash {bash_args} {f_name}\n\"\"\"\n            wrapper_f.write(wrapped)\n            wrapper_f.flush()\n\n            print(\"=\" * 5 + f\"bash_cmd: {f_name}\" + \"=\" * 5)\n            print(bash_cmd)\n            print(\"=\" * 5 + f\"bash_cmd: {f_name}\" + \"=\" * 5)\n\n            print(\"=\" * 5 + f\"wrapped: {w_name}\" + \"=\" * 5)\n            print(wrapped)\n            print(\"=\" * 5 + f\"wrapped: {w_name}\" + \"=\" * 5)\n\n            docker_cmd = f\"{docker_cmd} -v /tmp:/host/tmp {img}\"\n            cmd = f\"{docker_cmd} bash {bash_args} {w_name}\"\n            print(cmd)\n            if dry:\n                print(\"dry run, skipping\")\n            else:\n                subprocess.check_call(cmd, shell=True)\n\n\ndef get_common_docker_args(\n    oneflow_src_dir=None,\n    cache_dir=None,\n    current_dir=None,\n    house_dir=None,\n    use_system_proxy=True,\n    inplace=False,\n):\n    root = Path(cache_dir)\n    child = Path(current_dir)\n    assert root in child.parents\n    cwd = os.getcwd()\n    pwd_arg = f\"-v {cwd}:{cwd}\"\n    cache_dir_arg = f\"-v {cache_dir}:{cache_dir}\"\n    house_dir_arg = \"\"\n    if house_dir:\n        house_dir_arg = f\"-v {house_dir}:{house_dir}\"\n    build_dir_arg = get_build_dir_arg(cache_dir, oneflow_src_dir)\n    proxy_env_arg = get_proxy_env_args() if use_system_proxy else \"\"\n    inplace_attr = \"\"\n    if inplace == False:\n        inplace_attr = \":ro\"\n    cache_dir_args = \" \".join(\n        [\n            f\"-v {os.path.join(cache_dir, 'ccache')}:/root/.ccache\",\n            f\"-v {os.path.join(cache_dir, 'local')}:/root/.local\",\n            f\"-v {os.path.join(cache_dir, 'cache')}:/root/.cache\",\n        ]\n    )\n    return f\"{cache_dir_args} -v {oneflow_src_dir}:{oneflow_src_dir}{inplace_attr} {proxy_env_arg} {pwd_arg} {house_dir_arg} {cache_dir_arg} {build_dir_arg} -w {current_dir} --shm-size=8g\"\n\n\ndef get_python_dir(inplace=True, oneflow_src_dir=None, cache_dir=None):\n    if inplace:\n        assert oneflow_src_dir\n        return os.path.join(oneflow_src_dir, \"python\")\n    else:\n        assert cache_dir\n        return os.path.join(cache_dir, \"python\")\n\n\ndef build_third_party(\n    img_tag,\n    oneflow_src_dir,\n    cache_dir,\n    extra_oneflow_cmake_args,\n    extra_docker_args,\n    bash_args,\n    bash_wrap,\n    dry,\n    use_system_proxy,\n    inplace,\n):\n    third_party_build_dir = os.path.join(cache_dir, \"build-third-party\")\n    oneflow_python_dir = get_python_dir(\n        inplace=inplace, oneflow_src_dir=oneflow_src_dir, cache_dir=cache_dir\n    )\n    if inplace:\n        inplace_arg = \"\"\n        oneflow_python_dir_cmd = \"\"\n    else:\n        inplace_arg = f\"-DONEFLOW_PYTHON_DIR={oneflow_python_dir}\"\n        oneflow_python_dir_cmd = f\"\"\"\n        rm -rf {oneflow_python_dir}\n        cp -r {oneflow_src_dir}/python {oneflow_python_dir}\n        cd 
{oneflow_python_dir}\n        git init\n        git clean -nXd\n        git clean -fXd\n        cd -\n        \"\"\"\n    cmake_cmd = \" \".join(\n        [\n            \"cmake\",\n            common_cmake_args(\n                cache_dir=cache_dir, extra_oneflow_cmake_args=extra_oneflow_cmake_args\n            ),\n            \"-DTHIRD_PARTY=ON -DONEFLOW=OFF\",\n            extra_oneflow_cmake_args,\n            oneflow_src_dir,\n            inplace_arg,\n        ]\n    )\n\n    bash_cmd = f\"\"\"set -ex\nexport ONEFLOW_PYTHON_DIR={oneflow_python_dir}\n{oneflow_python_dir_cmd}\nexport PATH=\"$PATH:$(dirname {get_python_bin('3.6')})\"\nexport PYTHON_BIN_PATH={get_python_bin('3.6')}\n$PYTHON_BIN_PATH -m pip install -i https://mirrors.aliyun.com/pypi/simple --user -r {os.path.join(oneflow_src_dir, \"ci/fixed-dev-requirements.txt\")}\n$PYTHON_BIN_PATH -c \"from __future__ import print_function;import numpy; print(numpy.get_include());\"\n{cmake_cmd}\ncmake --build . -j `nproc` --target oneflow_deps\n\"\"\"\n    common_docker_args = get_common_docker_args(\n        oneflow_src_dir=oneflow_src_dir,\n        cache_dir=cache_dir,\n        current_dir=third_party_build_dir,\n        use_system_proxy=use_system_proxy,\n        inplace=inplace,\n    )\n    docker_cmd = (\n        f\"docker run --network=host {extra_docker_args} --rm {common_docker_args}\"\n    )\n    create_tmp_bash_and_run(docker_cmd, img_tag, bash_cmd, bash_args, bash_wrap, dry)\n\n\ndef get_python_bin(version):\n    assert version in [\"3.5\", \"3.6\", \"3.7\", \"3.8\", \"3.9\"]\n    py_ver = \"\".join(version.split(\".\"))\n    py_abi = f\"cp{py_ver}-cp{py_ver}\"\n    if version in [\"3.5\", \"3.6\", \"3.7\"]:\n        py_abi = f\"{py_abi}m\"\n    py_root = f\"/opt/python/{py_abi}\"\n    py_bin = f\"{py_root}/bin/python\"\n    return py_bin\n\n\ndef build_oneflow(\n    img_tag,\n    oneflow_src_dir,\n    cache_dir,\n    extra_oneflow_cmake_args,\n    extra_docker_args,\n    python_version,\n    skip_wheel,\n    package_name,\n    house_dir,\n    bash_args,\n    bash_wrap,\n    dry,\n    use_system_proxy,\n    enter_bash,\n    skip_audit,\n    inplace,\n):\n    oneflow_build_dir = os.path.join(cache_dir, \"build-oneflow\")\n    python_bin = get_python_bin(python_version)\n    oneflow_python_dir = get_python_dir(\n        inplace=inplace, oneflow_src_dir=oneflow_src_dir, cache_dir=cache_dir\n    )\n    if inplace:\n        inplace_arg = \"\"\n    else:\n        inplace_arg = f\"-DONEFLOW_PYTHON_DIR={oneflow_python_dir}\"\n    cmake_cmd = \" \".join(\n        [\n            \"cmake\",\n            common_cmake_args(\n                cache_dir=cache_dir, extra_oneflow_cmake_args=extra_oneflow_cmake_args\n            ),\n            \"-DTHIRD_PARTY=OFF -DONEFLOW=ON\",\n            extra_oneflow_cmake_args,\n            \"-DCMAKE_EXPORT_COMPILE_COMMANDS=1\",\n            f\"-DPython3_EXECUTABLE={python_bin}\",\n            f\"-DCODEGEN_PYTHON_EXECUTABLE={get_python_bin('3.6')}\",\n            oneflow_src_dir,\n            inplace_arg,\n        ]\n    )\n    common_docker_args = get_common_docker_args(\n        oneflow_src_dir=oneflow_src_dir,\n        cache_dir=cache_dir,\n        current_dir=oneflow_build_dir,\n        house_dir=house_dir,\n        use_system_proxy=use_system_proxy,\n        inplace=inplace,\n    )\n    docker_cmd = (\n        f\"docker run --network=host --rm {common_docker_args} {extra_docker_args}\"\n    )\n    if enter_bash:\n        docker_cmd += \" -it\"\n    bash_cmd = f\"\"\"set -ex\nexport 
LD_LIBRARY_PATH=/opt/intel/lib/intel64_lin:/opt/intel/mkl/lib/intel64:$LD_LIBRARY_PATH\nexport LD_LIBRARY_PATH=/opt/intel/lib:$LD_LIBRARY_PATH\nexport LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH\nexport ONEFLOW_SRC_DIR={oneflow_src_dir}\nexport ONEFLOW_CMAKE_CMD=\"{cmake_cmd}\"\n{python_bin} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user -r {os.path.join(oneflow_src_dir, \"ci/fixed-dev-requirements.txt\")}\n\"\"\"\n    if enter_bash:\n        bash_cmd += \"\\nbash\"\n    else:\n        bash_cmd += f\"\"\"\ncd {oneflow_python_dir}\ngit clean -nXd -e \\!oneflow/include -e \\!oneflow/include/**\ngit clean -fXd -e \\!oneflow/include -e \\!oneflow/include/**\ncd -\n{cmake_cmd}\ncmake --build . -j `nproc`\n\"\"\"\n    if skip_wheel or enter_bash:\n        pass\n    else:\n        bash_cmd += f\"\"\"\ncd {oneflow_python_dir}\n{python_bin} setup.py bdist_wheel -d /tmp/tmp_wheel --package_name {package_name}\ncd -\n\"\"\"\n    if skip_wheel == False:\n        if skip_audit:\n            bash_cmd += f\"\"\"\n    cp /tmp/tmp_wheel/*.whl {house_dir}\n    \"\"\"\n        else:\n            bash_cmd += f\"\"\"\n    auditwheel repair /tmp/tmp_wheel/*.whl --wheel-dir {house_dir}\n    \"\"\"\n    return create_tmp_bash_and_run(\n        docker_cmd, img_tag, bash_cmd, bash_args, bash_wrap, dry\n    )\n\n\ndef is_img_existing(tag):\n    returncode = subprocess.run(\n        f\"docker image inspect {tag}\",\n        shell=True,\n        stdout=subprocess.DEVNULL,\n        stderr=subprocess.DEVNULL,\n    ).returncode\n    if returncode == 0:\n        return True\n    else:\n        return False\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--custom_img_tag\", type=str, required=False, default=None,\n    )\n    parser.add_argument(\n        \"--container_name\", type=str, required=False, default=None,\n    )\n    parser.add_argument(\n        \"--cache_dir\", type=str, required=False, default=None,\n    )\n    default_wheel_house_dir = os.path.join(os.getcwd(), \"wheelhouse\")\n    parser.add_argument(\n        \"--wheel_house_dir\", type=str, required=False, default=default_wheel_house_dir,\n    )\n    parser.add_argument(\"--python_version\", type=str, required=True)\n    parser.add_argument(\n        \"--cuda_version\", type=str, required=False, default=\"10.2\",\n    )\n    parser.add_argument(\n        \"--package_name\", type=str, required=False, default=\"oneflow\",\n    )\n    parser.add_argument(\n        \"--extra_oneflow_cmake_args\", action=\"append\", nargs=\"+\", default=[]\n    )\n    parser.add_argument(\n        \"--extra_docker_args\", type=str, required=False, default=\"\",\n    )\n    parser.add_argument(\n        \"--oneflow_src_dir\", type=str, required=False, default=os.getcwd(),\n    )\n    parser.add_argument(\n        \"--skip_third_party\", default=False, action=\"store_true\", required=False\n    )\n    parser.add_argument(\n        \"--skip_wheel\", default=False, action=\"store_true\", required=False\n    )\n    parser.add_argument(\n        \"--skip_img\", default=False, action=\"store_true\", required=False\n    )\n    parser.add_argument(\n        \"--skip_audit\", default=False, action=\"store_true\", required=False\n    )\n    parser.add_argument(\n        \"--build_img\", default=False, action=\"store_true\", required=False\n    )\n    parser.add_argument(\n        \"--use_tuna\", default=False, action=\"store_true\", required=False\n    )\n   
 parser.add_argument(\"--dry\", default=False, action=\"store_true\", required=False)\n    parser.add_argument(\n        \"--use_system_proxy\", default=False, action=\"store_true\", required=False\n    )\n    parser.add_argument(\"--mlir\", default=False, action=\"store_true\", required=False)\n    parser.add_argument(\"--gcc4\", default=False, action=\"store_true\", required=False)\n    parser.add_argument(\"--gcc7\", default=False, action=\"store_true\", required=False)\n    parser.add_argument(\"--gcc9\", default=False, action=\"store_true\", required=False)\n    parser.add_argument(\n        \"--use_aliyun_mirror\", default=False, action=\"store_true\", required=False\n    )\n    parser.add_argument(\"--cpu\", default=False, action=\"store_true\", required=False)\n    parser.add_argument(\"--bash\", default=False, action=\"store_true\", required=False)\n    parser.add_argument(\"--inplace\", default=False, action=\"store_true\", required=False)\n    parser.add_argument(\n        \"--shared_lib\", default=False, action=\"store_true\", required=False\n    )\n    parser.add_argument(\"--retry\", default=0, type=int)\n    args = parser.parse_args()\n    if args.skip_img:\n        \"Arg skip_img is deprecated. Setting it has no effect. If you want to build image, use --build_img\"\n    if args.skip_wheel:\n        args.skip_audit = True\n    print(\"args.extra_oneflow_cmake_args\", args.extra_oneflow_cmake_args)\n    assert args.package_name\n    extra_oneflow_cmake_args = \" \".join(\n        [\" \".join(l) for l in args.extra_oneflow_cmake_args]\n    )\n    if (not args.gcc4) and (not args.gcc7) and (not args.gcc9):\n        args.gcc7 = True\n    cuda_versions = []\n    if args.use_aliyun_mirror:\n        extra_oneflow_cmake_args += \" -DTHIRD_PARTY_MIRROR=aliyun\"\n    if args.shared_lib:\n        extra_oneflow_cmake_args += \" -DBUILD_SHARED_LIBS=ON\"\n    if args.cpu:\n        extra_oneflow_cmake_args += \" -DBUILD_CUDA=OFF\"\n        cuda_versions = [\"10.2\"]\n    else:\n        extra_oneflow_cmake_args += \" -DBUILD_CUDA=ON\"\n    cuda_versions = args.cuda_version.split(\",\")\n    cuda_versions = [v.strip() for v in cuda_versions]\n    if args.mlir:\n        extra_oneflow_cmake_args += \" -DWITH_MLIR=ON\"\n    else:\n        extra_oneflow_cmake_args += \" -DWITH_MLIR=Off\"\n    for cuda_version in cuda_versions:\n\n        cache_dir = None\n\n        def build():\n            img_tag = None\n            img_prefix = f\"oneflow-manylinux2014-cuda{cuda_version}\"\n            user = getpass.getuser()\n            versioned_img_tag = f\"{img_prefix}:0.1\"\n            if cuda_version in [\"11.0\", \"11.1\"]:\n                versioned_img_tag = f\"{img_prefix}:0.2\"\n            enforced_oneflow_cmake_args = \"\"\n            enforced_oneflow_cmake_args += \" -DBUILD_TESTING=ON\"\n            if float(cuda_version) >= 11:\n                assert (\n                    \"CUDNN_STATIC\" not in extra_oneflow_cmake_args\n                ), \"CUDNN_STATIC will be set to OFF if cuda_version > 11\"\n                enforced_oneflow_cmake_args += \" -DCUDNN_STATIC=OFF\"\n            extra_docker_args = args.extra_docker_args\n            if not args.container_name:\n                args.container_name = f\"manylinux-build-run-by-{getpass.getuser()}\"\n            assert args.container_name\n            subprocess.call(\n                f\"docker rm -f {args.container_name}\", shell=True,\n            )\n            extra_docker_args += f\" --name {args.container_name}\"\n            
user_img_tag = f\"{img_prefix}:{user}\"\n            inc_img_tag = f\"oneflowinc/{versioned_img_tag}\"\n            img_tag = inc_img_tag\n            if args.build_img:\n                img_tag = user_img_tag\n            elif args.custom_img_tag:\n                img_tag = args.custom_img_tag\n            else:\n                if is_img_existing(versioned_img_tag):\n                    img_tag = versioned_img_tag\n                elif is_img_existing(inc_img_tag):\n                    img_tag = inc_img_tag\n                else:\n                    raise ValueError(\n                        f\"img not found, please run 'docker pull {inc_img_tag}'\"\n                    )\n            assert img_tag is not None\n            print(\"using\", img_tag)\n            if args.build_img:\n                build_img(\n                    cuda_version,\n                    args.oneflow_src_dir,\n                    args.use_aliyun_mirror,\n                    args.use_tuna,\n                    args.use_system_proxy,\n                    img_tag,\n                    args.dry,\n                )\n            bash_args = \"\"\n            bash_wrap = \"\"\n            if args.gcc4:\n                bash_wrap = \"gcc --version\"\n            elif args.gcc7:\n                bash_wrap = \"\"\"\nsource scl_source enable devtoolset-7\ngcc --version\n\"\"\"\n            elif args.gcc9:\n                bash_wrap = \"\"\"\nsource scl_source enable devtoolset-9\ngcc --version\n\"\"\"\n            else:\n                raise ValueError(\"either one in gcc4, gcc7, gcc9 must be enabled\")\n\n            global cache_dir\n            if args.cache_dir:\n                cache_dir = args.cache_dir\n            else:\n                cache_dir = os.path.join(os.getcwd(), \"manylinux2014-build-cache\")\n                sub_dir = cuda_version\n                if args.mlir:\n                    sub_dir += \"-mlir\"\n                if args.gcc4:\n                    sub_dir += \"-gcc4\"\n                if args.gcc7:\n                    sub_dir += \"-gcc7\"\n                if args.gcc9:\n                    sub_dir += \"-gcc9\"\n                if args.cpu:\n                    assert len(cuda_versions) == 1\n                    sub_dir += \"-cpu\"\n                if args.shared_lib:\n                    sub_dir += \"-shared\"\n                cache_dir = os.path.join(cache_dir, sub_dir)\n            if args.build_img:\n                return\n            if args.skip_third_party == False:\n                build_third_party(\n                    img_tag,\n                    args.oneflow_src_dir,\n                    cache_dir,\n                    extra_oneflow_cmake_args + enforced_oneflow_cmake_args,\n                    extra_docker_args,\n                    bash_args,\n                    bash_wrap,\n                    args.dry,\n                    args.use_system_proxy,\n                    args.inplace,\n                )\n            print(cuda_version.split(\".\"))\n            cuda_version_literal = \"\".join(cuda_version.split(\".\")[:2])\n            assert len(cuda_version_literal) == 3\n            python_versions = args.python_version.split(\",\")\n            python_versions = [pv.strip() for pv in python_versions]\n            for python_version in python_versions:\n                print(\"building for python version:\", python_version)\n                build_oneflow(\n                    img_tag,\n                    args.oneflow_src_dir,\n                    cache_dir,\n                   
\n        try:\n            build()\n        except subprocess.CalledProcessError as e:\n            print(\"failed: \", e.cmd, e.args)\n            if cache_dir and args.retry > 0:\n                print(\"clean: \", cache_dir, flush=True)\n                print(\"start retrying...\", flush=True)\n                if args.dry:\n                    pass\n                else:\n                    force_rm_dir(cache_dir)\n                build()\n            else:\n                exit(1)\n"
  },
  {
    "path": "docker/package/manylinux/launch.sh",
    "content": "set -ex\ndocker run --rm -it \\\n    -v `pwd`:`pwd` \\\n    -w `pwd` oneflow:rel-manylinux2014-cuda-11.0 bash\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = sphinx-build\nSOURCEDIR     = source\nBUILDDIR      = build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\nhtml_cn: Makefile\n\t@CN_DOCS=1 $(SPHINXBUILD) -M html \"$(SOURCEDIR)\" \"$(BUILDDIR)-cn\" $(SPHINXOPTS) $(O)\n\nhtml: Makefile\n\t@$(SPHINXBUILD) -M html \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\nclean: Makefile\n\t@rm -rf build build-cn\n"
  },
  {
    "path": "docs/requirements.txt",
    "content": "sphinx==3.5.4\njinja2<3.1\nrecommonmark==0.6.0\nfuro==2021.4.11b34\nsphinx-copybutton==0.5.0\n# above are dev dependencies\n--pre\n--find-links https://oneflow-staging.oss-cn-beijing.aliyuncs.com/branch/master/cpu\noneflow\n"
  },
  {
    "path": "docs/source/_static/.gitkeep",
    "content": ""
  },
  {
    "path": "docs/source/auto_parallel.rst",
    "content": "Auto Parallelism\n====================================================\n\nAs the scale of deep-learning models grows larger and larger, distributed training,\nor parallelism, is needed. Data parallelism and model parallelism has been designed\nto speed up the training and solve memory issues.\n\nIn oneflow, SBP signature enables users to configure parallelism policy easily.\nHowever, users still need to specify the SBP property for each operator, or most of them.\nUsers might spend a couple of days digging into the detail of parallelism and get a\nlow throughput just because of a slight mistake in the configuration of SBP signature.\n\n.. note::\n\n   It only works on :doc:`graph` mode.\n\n\nOur strength\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nTo get rid of all those configurations for SBP signatures, we developed auto parallelism.\nStill, configurations of placement are necessary and we have not supported auto placement\nyet. If you read this paragraph before you rush into any SBP stuff, then congratulation,\nyou do not need to learn SBPs. You can start writing your code as you did under CPU mode.\nOur auto parallelism would generate a fast strategy customized for your specific models,\nthe size of parameters, and the number of available GPUs.\n\n\nHow to use auto parallelism?\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nYou just need to simply enable the configuration settings in the model\nof :doc:`graph` .\n\nExample::\n\n    import oneflow as flow\n    class SubclassGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__() # MUST be called\n            # auto parallelism configuration\n            self.config.enable_auto_parallel(True)\n            # other configurations about auto parallelism\n            # ......\n\n        def build(self):\n            pass\n\n.. warning::\n\n   If you enable auto parallelism, OneFlow will take care of the SBP configurations\n   of operators except for explicit ``to_global`` functions.\n\n\nConfiguration API for auto parallelism\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. currentmodule:: oneflow.nn.graph.graph_config.GraphConfig\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    enable_auto_parallel\n    enable_auto_parallel_ignore_user_sbp_config\n    set_auto_parallel_computation_cost_ratio\n    set_auto_parallel_wait_time\n    enable_auto_parallel_trunk_algo\n    enable_auto_parallel_sbp_collector\n    enable_auto_memory\n\n"
  },
  {
    "path": "docs/source/autograd.rst",
    "content": "oneflow.autograd\n====================================================\n\n.. The documentation is referenced from:\n   https://pytorch.org/docs/1.10/autograd.html\n\n``oneflow.autograd`` provides classes and functions implementing automatic differentiation of arbitrary scalar \nvalued functions. It requires minimal changes to the existing code - you only need to declare ``Tensor`` s \nfor which gradients should be computed with the ``requires_grad=True`` keyword. As of now, we only support \nautograd for floating point ``Tensor`` types ( half, float, double and bfloat16).\n\n\n.. currentmodule:: oneflow.autograd\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    backward\n    grad\n\nLocally disabling gradient computation\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    no_grad\n    enable_grad\n    set_grad_enabled\n    inference_mode\n\n.. TODO(wyg): uncomment this after aligning accumulate grad\n.. Default gradient layouts\n.. ^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. A ``param.grad`` is accumulated by replacing ``.grad`` with a \n.. new tensor ``.grad + new grad`` during :func:`oneflow.autograd.backward()` or \n.. :func:`oneflow.Tensor.backward()`.\n\nIn-place operations on Tensors\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nSupporting in-place operations in autograd is a hard matter, and we discourage\ntheir use in most cases. Autograd's aggressive buffer freeing and reuse makes\nit very efficient and there are very few occasions when in-place operations\nactually lower memory usage by any significant amount. Unless you're operating\nunder heavy memory pressure, you might never need to use them.\n\nTensor autograd functions\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n.. autosummary::\n    :nosignatures:\n\n   oneflow.Tensor.grad\n   oneflow.Tensor.requires_grad\n   oneflow.Tensor.is_leaf\n   oneflow.Tensor.backward\n   oneflow.Tensor.detach\n   oneflow.Tensor.register_hook\n   oneflow.Tensor.retain_grad\n\nFunction\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. autoclass:: Function\n.. currentmodule:: oneflow.autograd\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    Function.forward\n    Function.backward\n    Function.apply\n\nContext method mixins\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nWhen creating a new :class:`Function`, the following methods are available to `ctx`.\n\n.. currentmodule:: oneflow._oneflow_internal.autograd.Function\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    \n    FunctionCtx.mark_non_differentiable\n    FunctionCtx.save_for_backward\n    FunctionCtx.saved_tensors\n\nfunctional\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n.. currentmodule:: oneflow.autograd.functional\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    vjp\n    jvp\n    jacobian\n    hessian\n    vhp\n    hvp    \n"
  },
  {
    "path": "docs/source/cn/__init__.py",
    "content": "from .math_ops import *\nfrom .activation import *\n"
  },
  {
    "path": "docs/source/cn/activation.py",
    "content": "import oneflow\nfrom oneflow.framework.docstr.utils import reset_docstr\n\nreset_docstr(\n    oneflow.nn.ReLU,\n    r\"\"\"ReLU(inplace=False)\n    \n    ReLU 激活函数，对张量中的每一个元素做 element-wise 运算，公式如下:\n\n    :math:`\\text{ReLU}(x) = (x)^+ = \\max(0, x)`\n\n    参数:\n        inplace: 是否做 in-place 操作。 默认为 ``False``\n\n    形状:\n        - Input: :math:`(N, *)` 其中 `*` 的意思是，可以指定任意维度\n        - Output: :math:`(N, *)` 输入形状与输出形状一致\n\n    示例：\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> relu = flow.nn.ReLU()\n        >>> ndarr = np.asarray([1, -2, 3])\n        >>> x = flow.Tensor(ndarr)\n        >>> relu(x)\n        tensor([1., 0., 3.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n"
  },
  {
    "path": "docs/source/cn/math_ops.py",
    "content": "import oneflow\nfrom oneflow.framework.docstr.utils import reset_docstr\n\nreset_docstr(\n    oneflow.add,\n    r\"\"\"add(input, other)\n    \n    计算 `input` 和 `other` 的和。支持 element-wise、标量和广播形式的加法。\n    公式为：\n\n    .. math::\n        out = input + other\n\n    示例：\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        # element-wise 加法\n        >>> x = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> y = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> out = flow.add(x, y).numpy()\n        >>> out.shape\n        (2, 3)\n\n        # 标量加法\n        >>> x = 5\n        >>> y = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> out = flow.add(x, y).numpy()\n        >>> out.shape\n        (2, 3)\n\n        # 广播加法\n        >>> x = flow.tensor(np.random.randn(1,1), dtype=flow.float32)\n        >>> y = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> out = flow.add(x, y).numpy()\n        >>> out.shape\n        (2, 3)\n\n    \"\"\",\n)\n"
  },
  {
    "path": "docs/source/conf.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# Configuration file for the Sphinx documentation builder.\n#\n# This file does only contain a selection of the most common options. For a\n# full list see the documentation:\n# http://www.sphinx-doc.org/en/master/config\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\nimport os\nimport sys\nimport oneflow\n\nsys.path.insert(0, os.path.abspath(\".\"))\nCN_DOCS = os.getenv(\"CN_DOCS\")\nif CN_DOCS:\n    import cn\n\n# -- Project information -----------------------------------------------------\n\nproject = u\"OneFlow\"\ncopyright = u\"2020, OneFlow\"\nauthor = u\"OneFlow\"\n\n# The short X.Y version\nversion = u\"\"\n# The full version, including alpha/beta/rc tags\nrelease = u\"\"\n\n# -- General configuration ---------------------------------------------------\n\n# If your documentation needs a minimal Sphinx version, state it here.\n#\n# needs_sphinx = '1.0'\n\n# Add any Sphinx extension module names here, as strings. They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.napoleon\",\n    \"recommonmark\",\n    \"sphinx.ext.autosummary\",\n    \"sphinx_copybutton\",\n]\n\n# build the templated autosummary files\nautosummary_generate = True\nnumpydoc_show_class_members = False\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = [\"_templates\"]\n\n# The suffix(es) of source filenames.\n# You can specify multiple suffix as a list of string:\n#\nsource_suffix = {\n    \".rst\": \"restructuredtext\",\n    \".txt\": \"markdown\",\n    \".md\": \"markdown\",\n}\n\n# The master toctree document.\nmaster_doc = \"index\"\n\n# The language for content autogenerated by Sphinx. Refer to documentation\n# for a list of supported languages.\n#\n# This is also used if you do content translation via gettext catalogs.\n# Usually you set \"language\" from the command line for these cases.\nlanguage = u\"en\"\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = []\n\n# The name of the Pygments (syntax highlighting) style to use.\npygments_style = None\n\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = \"furo\"\n\n# Theme options are theme-specific and customize the look and feel of a theme\n# further.  For a list of options available for each theme, see the\n# documentation.\n#\n# html_theme_options = {}\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = [\"_static\"]\n\n# Custom sidebar templates, must be a dictionary that maps document names\n# to template names.\n#\n# The default sidebars (for documents that don't match any pattern) are\n# defined by theme itself.  
Builtin themes are using these templates by\n# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',\n# 'searchbox.html']``.\n#\n# html_sidebars = {}\n\n# -- Options for HTMLHelp output ---------------------------------------------\n\n# Output file base name for HTML help builder.\nhtmlhelp_basename = \"OneFlowdoc\"\n\n\n# -- Options for LaTeX output ------------------------------------------------\n\nlatex_elements = {\n    # The paper size ('letterpaper' or 'a4paper').\n    #\n    # 'papersize': 'letterpaper',\n    # The font size ('10pt', '11pt' or '12pt').\n    #\n    # 'pointsize': '10pt',\n    # Additional stuff for the LaTeX preamble.\n    #\n    # 'preamble': '',\n    # Latex figure (float) alignment\n    #\n    # 'figure_align': 'htbp',\n}\n\n# Grouping the document tree into LaTeX files. List of tuples\n# (source start file, target name, title,\n#  author, documentclass [howto, manual, or own class]).\nlatex_documents = [\n    (\n        master_doc,\n        \"OneFlow.tex\",\n        u\"OneFlow API Reference\",\n        u\"Oneflow Contributors\",\n        \"manual\",\n    ),\n]\n\n\n# -- Options for manual page output ------------------------------------------\n\n# One entry per manual page. List of tuples\n# (source start file, name, description, authors, manual section).\nman_pages = [(master_doc, \"oneflow\", u\"OneFlow API Reference\", [author], 1)]\n\n\n# -- Options for Texinfo output ----------------------------------------------\n\n# Grouping the document tree into Texinfo files. List of tuples\n# (source start file, target name, title, author,\n#  dir menu entry, description, category)\ntexinfo_documents = [\n    (\n        master_doc,\n        \"OneFlow\",\n        u\"OneFlow API Reference\",\n        author,\n        \"OneFlow\",\n        \"OneFlow API Reference\",\n        \"Miscellaneous\",\n    ),\n]\n\n\n# -- Options for Epub output -------------------------------------------------\n\n# Bibliographic Dublin Core info.\nepub_title = project\n\n# The unique identifier of the text. This can be a ISBN number\n# or the project homepage.\n#\n# epub_identifier = ''\n\n# A unique identification for the text.\n#\n# epub_uid = ''\n\n# A list of files that should not be packed into the epub file.\nepub_exclude_files = [\"search.html\"]\n\n\n# -- Extension configuration -------------------------------------------------\n\nautodoc_default_options = {\n    \"undoc-members\": True,\n    \"exclude-members\": \"forward, extra_repr, reset_parameters\",\n}\n\n\n# Hide deprecated members and common magic attributes from the generated API docs.\ndef should_skip_member(app, what, name, obj, skip, options):\n    is_deprecated = oneflow.is_deprecated(obj)\n    if is_deprecated:\n        print(\"skipping deprecated\", what, name, obj)\n    magical = name in [\"__weakref__\", \"__doc__\", \"__module__\", \"__dict__\"]\n    return skip or is_deprecated or magical\n\n\ndef setup(app):\n    app.connect(\"autodoc-skip-member\", should_skip_member)\n"
  },
  {
    "path": "docs/source/cuda.rst",
    "content": "oneflow.cuda\n===================================\n\n.. The documentation is referenced from: https://pytorch.org/docs/1.10/cuda.html.\n\n.. currentmodule:: oneflow.cuda\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    is_available\n    device_count\n    current_device\n    set_device\n    synchronize\n    get_device_properties\n    get_device_capability\n    get_device_name\n\n.. note::\n   The :attr:`current_device` returns local rank as device index. It is different from the 'torch.current_device()' in PyTorch.\n\n\nRandom Number Generator\n-------------------------\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    manual_seed_all\n    manual_seed\n    get_rng_state\n    get_rng_state_all\n    set_rng_state\n    set_rng_state_all\n\n\nGPU tensor\n-----------------------------\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    HalfTensor\n    FloatTensor\n    DoubleTensor\n    BoolTensor\n    ByteTensor\n    CharTensor\n    IntTensor\n    LongTensor\n\nMemory management\n-----------------------------\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    empty_cache\n    "
  },
  {
    "path": "docs/source/distributed.rst",
    "content": "oneflow.distributed\n=========================================================\n\n.. note ::\n    Please refer to `OneFlow Distributed Overview <https://docs.oneflow.org/master/parallelism/01_introduction.html>`__\n    for a brief introduction to all features related to distributed training.\n\nOneFlow provides two ways to accomplish `Distributed Training`:\n\n- The first way is that users are recommended to use OneFlow's global Tensor for distributed training. Global Tensor regards the computing cluster as a supercomputing device, allowing users to write distributed training code just like in a single-machine environment.\n\n- OneFlow also provides a DDP（DistributedDataParallel） module aligned with PyTorch. DDP has been well-known and widely used in data parallelism by the majority of PyTorch users. Also see `PyTorch DDP introduction <https://pytorch.org/docs/1.10/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_.\n\n\n\nBasic\n-------------------------------\nWhen you start distributed training in OneFlow, the following functions can be used.\n\n.. currentmodule:: oneflow.env\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    get_world_size\n    get_rank\n    get_local_rank\n    get_node_size\n    init_rdma\n    rdma_is_initialized\n\n\n`Global Tensor`\n--------------------------------------------------------------\n\nConstruct `Global Tensor`\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nA `Global Tensor` can be created with a ``placement`` and a ``sbp``. The ``placement`` describes the physical devices of the global tensor will be allocated, and the ``sbp`` describes its distribution among these devices.\n\n::\n\n    >>>import oneflow as flow\n    >>> # Place a global tensor on cuda device of rank(process) 0 and 1\n    >>> placement = flow.placement(type=\"cuda\", ranks=[0, 1])\n    >>> # Each rank's local data is a part data as a result of spliting global data on dim 0\n    >>> sbp = flow.sbp.split(dim=0)\n    >>> # Create a global tensor by randn\n    >>> x = flow.randn(4, 5, placement=placement, sbp=sbp)\n    >>> x.shape\n    oneflow.Size([4, 5])\n\n\nConvert `Local Tensor` to `Global Tensor`\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nWith ``Tensor.to_global`` interface, `Local Tensor` can create a `Global Tensor` and use that `Local Tensor` as its local component at the current node.\n\nTwo `local tensors` with the shape of ``(2,5)`` are created separately on two devices. While after the ``to_global`` method, the `global tensor` with a shape of ``(4,5)`` is obtained.\n\nCode running on Node 0\n\n::\n\n    import oneflow as flow\n\n    x = flow.randn(2,5)\n    placement = flow.placement(\"cuda\", [0,1])\n    sbp = flow.sbp.split(0)\n    x_global = x.to_global(placement=placement, sbp=sbp)\n    x_global.shape\n\nCode running on Node 1\n\n::\n\n    import oneflow as flow\n\n    x = flow.randn(2,5)\n    placement = flow.placement(\"cuda\", [0,1])\n    sbp = flow.sbp.split(0)\n    x_global = x.to_global(placement=placement, sbp=sbp)\n    x_global.shape\n\nRedistribute `Global Tensor`\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nRedistributing a `Global Tensor` means moving its data to another device group (or placement), or changing its data distribution (or SBP) across the group, or both at the same time. 
\nRedistribute `Global Tensor`\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nRedistributing a `Global Tensor` means moving its data to another device group (or placement), or changing its data distribution (or SBP) across the group, or both at the same time. The redistributed tensor is still a `Global Tensor`.\n\n::\n\n    >>> import oneflow as flow\n    >>> x = flow.tensor([1.0, 2.0], placement=flow.placement(\"cuda\", ranks=[0, 1]), sbp=flow.sbp.split(0))\n    >>> y = x.to_global(placement=flow.placement(\"cuda\", ranks=[2, 3]), sbp=flow.sbp.broadcast)\n\nAccording to the operator's semantics, OneFlow defines the valid input and output SBP combinations for each built-in operator, so OneFlow can automatically redistribute a `Global Tensor` to satisfy the operator's SBP requirements for its input Tensor. For example, the following code:\n\n::\n\n    >>> import oneflow as flow\n    >>> x = flow.randn(4, 4,\n            placement=flow.placement(\"cuda\", ranks=[0, 1]),\n            sbp=flow.sbp.split(0))\n    >>> y = flow.randn(4, 4,\n            placement=flow.placement(\"cuda\", ranks=[0, 1]),\n            sbp=flow.sbp.split(1))\n    >>> z = x + y\n\nWhen ``x + y`` is executed, since ``x`` is split along dimension ``0`` and ``y`` is split along dimension ``1``, their local components at each node cannot be added directly; OneFlow then automatically redistributes one of ``x`` and ``y`` so that they have the same SBP, and completes the add operation successfully.\n\n.. note ::\n    - Global Tensor cannot be used in combination with DDP currently.\n    - Global Tensor requires all devices to execute at the same pace, otherwise it may cause a multi-process deadlock.\n\nGet Local Tensor from Global Tensor\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nWith the ``Tensor.to_local`` interface, a `Global Tensor` can return its local component at the current node.\n\n::\n\n    >>> y = x.to_local()\n    >>> y.is_local\n    True\n    >>> y\n    tensor([[ 2.9186e-01, -3.9442e-01,  4.7072e-04, -3.2216e-01,  1.7788e-01],\n            [-4.5284e-01,  1.2361e-01, -3.5962e-01,  2.6651e-01,  1.2951e+00]],\n           device='cuda:0', dtype=oneflow.float32)\n
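\nThe reverse also holds: calling ``to_global`` on the returned local tensor with the original ``placement`` and ``sbp`` reassembles the same global tensor. A minimal sketch, reusing ``x`` and ``y`` from above:\n\n::\n\n    >>> z = y.to_global(placement=x.placement, sbp=x.sbp)  # z has the same shape and data as x\n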
\n\nDistributedDataParallel\n--------------------------------------------------------------\n\nFor more information about DistributedDataParallel, see ``nn.parallel.DistributedDataParallel``.\n\nThe following script shows the process of using ``oneflow.nn.parallel.DistributedDataParallel`` for data-parallel training:\n\n.. code-block:: \n\n    import oneflow as flow\n    from oneflow.nn.parallel import DistributedDataParallel as ddp\n\n    train_x = [\n        flow.tensor([[1, 2], [2, 3]], dtype=flow.float32),\n        flow.tensor([[4, 6], [3, 1]], dtype=flow.float32),\n    ]\n    train_y = [\n        flow.tensor([[8], [13]], dtype=flow.float32),\n        flow.tensor([[26], [9]], dtype=flow.float32),\n    ]\n\n\n    class Model(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.lr = 0.01\n            self.iter_count = 500\n            self.w = flow.nn.Parameter(flow.tensor([[0], [0]], dtype=flow.float32))\n\n        def forward(self, x):\n            x = flow.matmul(x, self.w)\n            return x\n\n\n    m = Model().to(\"cuda\")\n    m = ddp(m)\n    loss = flow.nn.MSELoss(reduction=\"sum\")\n    optimizer = flow.optim.SGD(m.parameters(), m.lr)\n\n    for i in range(0, m.iter_count):\n        rank = flow.env.get_rank()\n        x = train_x[rank].to(\"cuda\")\n        y = train_y[rank].to(\"cuda\")\n\n        y_pred = m(x)\n        l = loss(y_pred, y)\n        if (i + 1) % 50 == 0:\n            print(f\"{i+1}/{m.iter_count} loss:{l}\")\n\n        optimizer.zero_grad()\n        l.backward()\n        optimizer.step()\n\n    print(f\"\\nw:{m.w}\")\n\nThere are only two differences between the data-parallel training code and the stand-alone single-card script:\n\n- Use `DistributedDataParallel` to wrap the module object (`m = ddp(m)`)\n- Use `get_rank` to get the current device number and distribute the data to that device.\n\nThen use the `launch` module to run the script and leave everything else to OneFlow, which makes distributed training as simple as stand-alone single-card training:\n\n::\n\n    python3 -m oneflow.distributed.launch --nproc_per_node 2 ./ddp_train.py\n\n\nCommunication collectives\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n.. currentmodule:: oneflow.comm\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n        all_reduce\n        all_gather\n        all_gather_into_tensor\n        all_to_all\n        broadcast\n        barrier\n        gather\n        reduce\n        reduce_scatter\n        reduce_scatter_tensor\n        recv\n        scatter\n        send\n\nWe also provide PyTorch-compatible APIs for communication collectives, for example, `oneflow.distributed.all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False)`. For more information, see `PyTorch Distributed Communication <https://pytorch.org/docs/stable/distributed.html>`_. Note that we currently only support ``op=ReduceOp.SUM``, ``group=None`` and ``async_op=False`` in these operations.\n
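\nFor example, a minimal sketch of ``all_reduce`` (assuming two ranks started with ``oneflow.distributed.launch``; the tensor is summed in place across ranks):\n\n::\n\n    import oneflow as flow\n\n    # every rank contributes its own tensor; after all_reduce each rank holds the sum\n    t = flow.tensor([1.0, 2.0], device=\"cuda\") * (flow.env.get_rank() + 1)\n    flow.comm.all_reduce(t)\n    print(t)  # tensor([3., 6.], ...) on both ranks\n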
\nLaunching distributed training\n--------------------------------------------------------------\n\n.. currentmodule:: oneflow.distributed\n\nRun the command below to see more about its usage.\n\n::\n\n    python3 -m oneflow.distributed.launch -h\n\n.. code-block::\n\n    usage: launch.py [-h] [--nnodes NNODES] [--node_rank NODE_RANK]\n                 [--nproc_per_node NPROC_PER_NODE] [--master_addr MASTER_ADDR]\n                 [--master_port MASTER_PORT] [-m] [--no_python]\n                 [--redirect_stdout_and_stderr] [--logdir LOGDIR]\n                 training_script ...\n\n    OneFlow distributed training launch helper utility that will spawn up multiple\n    distributed processes\n\n    positional arguments:\n    training_script       The full path to the single GPU training program/script to be\n                            launched in parallel, followed by all the arguments for the\n                            training script\n    training_script_args\n\n    optional arguments:\n    -h, --help            show this help message and exit\n    --nnodes NNODES       The number of nodes to use for distributed training\n    --node_rank NODE_RANK\n                            The rank of the node for multi-node distributed training\n    --nproc_per_node NPROC_PER_NODE\n                            The number of processes to launch on each node, for GPU\n                            training, this is recommended to be set to the number of GPUs in\n                            your system so that each process can be bound to a single GPU.\n    --master_addr MASTER_ADDR\n                            Master node (rank 0)'s address, should be either the IP address\n                            or the hostname of node 0, for single node multi-proc training,\n                            the --master_addr can simply be 127.0.0.1\n    --master_port MASTER_PORT\n                            Master node (rank 0)'s free port that needs to be used for\n                            communication during distributed training\n    -m, --module          Changes each process to interpret the launch script as a python\n                            module, executing with the same behavior as 'python -m'.\n    --no_python           Do not prepend the training script with \"python\" - just exec it\n                            directly. Useful when the script is not a Python script.\n    --redirect_stdout_and_stderr\n                            write the stdout and stderr to files 'stdout' and 'stderr'. Only\n                            available when logdir is set\n    --logdir LOGDIR       Relative path to write subprocess logs to. Passing in a relative\n                            path will create a directory if needed. Note that successive\n                            runs with the same path to write logs to will overwrite existing\n                            logs, so be sure to save logs as needed.\n"
  },
  {
    "path": "docs/source/distributions.rst",
    "content": "oneflow.distributions\n==================================================\n\n.. contents:: oneflow.distributions\n    :depth: 2\n    :local:\n    :class: this-will-duplicate-information-and-it-is-still-useful-here\n    :backlinks: top\n\n.. currentmodule:: oneflow.distributions\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    Distribution\n    Categorical \n"
  },
  {
    "path": "docs/source/environment_variables.rst",
    "content": "Environment Variables\n================================================\n\nOneFlow has an extensive set of environment variables to tune for specific usage.\n\n`ONEFLOW_COMM_NET_IB_HCA <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp#L47>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nWhen there are multiple IB NIC(which can be checked by ``ibstatus`` on the server), the system uses the first IB NIC for comm_net communication by default.\n\nWhen this environment variable is set, the system will check all IB NIC and find the NIC with the corresponding name. `#5626 <https://github.com/Oneflow-Inc/oneflow/pull/5626>`_\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is empty, such as ``mlx5_0:1``、 ``mlx5_1:1``. When the port is 0, the default value is 1, representing the first port.\n\n`ONEFLOW_COMM_NET_IB_GID_INDEX <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp#L142>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nFor the query of `ibv_query_gid <https://www.ibm.com/docs/en/aix/7.2?topic=management-ibv-query-gid>`_, and 0 represents success. It often used with ``ONEFLOW_COMM_NET_IB_HCA``. GID means the Global ID, QP under RoCE network must be built by this value, instead of just using the LID as in the IB network. `#5626 <https://github.com/Oneflow-Inc/oneflow/pull/5626>`_\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is 0, representing the port index value\n\n`ONEFLOW_COMM_NET_IB_QUEUE_DEPTH <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp#L44>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nQueue length of jobs in IB network.\n\nThis value effectively controls the size of the module without instead of using IB's default size, such as ``ONEFLOW_COMM_NET_IB_MEM_BLOCK_SIZE``.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``1024``, receiving ``int64_t``. 
\n`ONEFLOW_COMM_NET_IB_QUEUE_DEPTH <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp#L44>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nQueue length of jobs in the IB network.\n\nLike ``ONEFLOW_COMM_NET_IB_MEM_BLOCK_SIZE``, this value controls the sizing directly instead of relying on IB's default size.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``1024``, received as an ``int64_t``. The system compares it with ``max_qp_wr`` (maximum number of outstanding WR on any work queue) and takes the smaller one.\n\n`ONEFLOW_COMM_NET_IB_MEM_BLOCK_SIZE <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp#L68>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nThe size of each memory block read when communicating.\n\nData is divided into blocks of this size and transmitted after encapsulation.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``8388608`` (8M).\n\n`ONEFLOW_STREAM_CUDA_EVENT_FLAG_BLOCKING_SYNC <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/ep/cuda/cuda_device.cpp#L59>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nMarks CUDA events used by streams with blocking synchronization. `Detailed information <https://www.cnblogs.com/1024incn/p/5891051.html>`_, `#5612 <https://github.com/Oneflow-Inc/oneflow/pull/5612>`_, `#5837 <https://github.com/Oneflow-Inc/oneflow/pull/5837>`_\n\nValues accepted\n^^^^^^^^^^^^^^^\nDefaults to ``false``; it becomes ``true`` only when the value is ``1``, ``true``, ``yes``, ``on`` or ``y``.\n\n`ONEFLOW_LIBIBVERBS_PATH <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/platform/lib/ibv_wrapper.cpp#L24>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nLoads the dynamic library at runtime via dlopen, to find the symbols of ibverbs functions without linking at compile time, for better compatibility. `#4852 <https://github.com/Oneflow-Inc/oneflow/pull/4852>`_.\n\nIf loading fails, it outputs ``libibverbs not available, ibv_fork_init skipped``; if it succeeds, ``import oneflow`` outputs something like ``loaded library: /usr/lib/x86_64-linux-gnu/libibverbs.so.1``.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is empty, in which case ``libibverbs.so.1`` and then ``libibverbs.so`` are loaded.\n\n`ONEFLOW_DEBUG_MODE <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/common/env_var/debug_mode.h#L23>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nEnables ``debug`` mode; setting ``ONEFLOW_DEBUG`` has the same effect.\n\nIf ``debug`` mode is on, more INFO-level logs are emitted and different ``prototxt`` and ``dot`` files are written. The automatically inserted boxing information will be printed to the log file under eager global mode.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is empty, but it will accept any string.\n
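\nA minimal sketch of enabling debug mode from Python (the variable needs to be set before ``import oneflow``):\n\n::\n\n    import os\n\n    os.environ[\"ONEFLOW_DEBUG_MODE\"] = \"1\"  # any string enables debug mode\n\n    import oneflow as flow  # now emits more INFO-level logs and dump files\n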
\n`ONEFLOW_DRY_RUN <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/job/resource_desc.cpp#L65>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nOnly for test runs; it can generate log files like ``dot``.\n\nIt exits once the dry run succeeds and does not attempt real training.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is empty, but it will accept any string.\n\n`ONEFLOW_DEBUG_KERNEL_SYNC_CHECK_NUMERICS <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/lazy/stream_context/cuda/cuda_stream_context.cpp#L66>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nOnly used when debugging because performance is affected; it can detect which op in the network produces NaN or Inf.\n\nIt will create ``CpuCheckNumericsKernelObserver`` under ``cpu``, and ``CudaCheckNumericsKernelObserver`` under ``cuda``. `#6052 <https://github.com/Oneflow-Inc/oneflow/pull/6052>`_\n\nValues accepted\n^^^^^^^^^^^^^^^\nDefaults to ``false``; it becomes ``true`` only when the value is ``1``, ``true``, ``yes``, ``on`` or ``y``.\n\n`ONEFLOW_DEBUG_KERNEL_SYNC_CHECK <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/job/env_global_objects_scope.cpp#L193>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nOnly used when debugging because performance is affected.\n\nIt will create ``SyncCheckKernelObserver`` and sync after each kernel.\n\nIt can be used to debug CUDA errors. `#6052 <https://github.com/Oneflow-Inc/oneflow/pull/6052>`_\n\nValues accepted\n^^^^^^^^^^^^^^^\nDefaults to ``false``; it becomes ``true`` only when the value is ``1``, ``true``, ``yes``, ``on`` or ``y``.\n\n`ONEFLOW_PROFILER_KERNEL_PROFILE_CUDA_MEMORY_BANDWIDTH <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/profiler/kernel.cpp#L34>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nUsed when generating profiler files with nsys.\n\nThe profiler is currently only available in lazy mode.\n\nIt can estimate the memory bandwidth reached by a kernel, by measuring the execution time of the GPU kernel and the size of the input and output memory, and helps find potential kernels that can be optimized. `Details <https://github.com/Oneflow-Inc/oneflow/blob/02e29f9648f63a4d936cd818061e90064d027005/oneflow/core/profiler/kernel.cpp#L53>`_\n\nValues accepted\n^^^^^^^^^^^^^^^\nDefaults to ``false``. The package needs to be compiled with ``BUILD_PROFILER`` enabled.\n
\n`ONEFLOW_PROFILER_KERNEL_PROFILE_KERNEL_FORWARD_RANGE <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/profiler/kernel.cpp#L36>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nThe same as above; it additionally collects the `op name <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/profiler/kernel.cpp#L62>`_.\n\nValues accepted\n^^^^^^^^^^^^^^^\nDefaults to ``false``. The package needs to be compiled with ``BUILD_PROFILER`` enabled.\n\n`ONEFLOW_KERNEL_DISABLE_BLOB_ACCESS_CHECKER <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/job/env_global_objects_scope.cpp#L199>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nThe blob_access_checker exists for correctness assurance, but it adds kernel overhead in some cases; this variable disables it. `#5728 <https://github.com/Oneflow-Inc/oneflow/pull/5728>`_\n\nValues accepted\n^^^^^^^^^^^^^^^\nDefaults to ``false``; it becomes ``true`` only when the value is ``1``, ``true``, ``yes``, ``on`` or ``y``.\n\n`ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/kernel/user_kernel.cpp#L692>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nTakes effect only when built with ``WITH_CUDA_GRAPHS``; the default value is ``false``. Turning on CUDA Graphs support uses more memory, so it may fail to run when memory is barely enough. `#5868 <https://github.com/Oneflow-Inc/oneflow/pull/5868>`_\n\nValues accepted\n^^^^^^^^^^^^^^^\nDefaults to ``false``; it becomes ``true`` only when the value is ``1``, ``true``, ``yes``, ``on`` or ``y``.\n\n`ONEFLOW_ACTOR_ENABLE_LIGHT_ACTOR <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/thread/thread.cpp#L30>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nLightActor is a new type of Actor that only handles NormalForward and similar tasks where every regst_num is 1, or tasks with only one kernel. `#5868 <https://github.com/Oneflow-Inc/oneflow/pull/5868>`_.\n
\n``export ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH=1`` (uses more memory), ``export ONEFLOW_THREAD_ENABLE_LOCAL_MESSAGE_QUEUE=1``, ``export ONEFLOW_KERNEL_DISABLE_BLOB_ACCESS_CHECKER=1``, ``export ONEFLOW_ACTOR_ENABLE_LIGHT_ACTOR=1`` and ``export ONEFLOW_STREAM_REUSE_CUDA_EVENT=1`` can be used together.\n\nValues accepted\n^^^^^^^^^^^^^^^\nDefaults to ``false``; it becomes ``true`` only when the value is ``1``, ``true``, ``yes``, ``on`` or ``y``.\n\n`ONEFLOW_THREAD_ENABLE_LOCAL_MESSAGE_QUEUE <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/thread/thread.cpp#L29>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nEnables the local message queue; ``oneflow.config.thread_enable_local_message_queue(True)`` is no longer used. `#5720 <https://github.com/Oneflow-Inc/oneflow/pull/5720>`_\n\nValues accepted\n^^^^^^^^^^^^^^^\nDefaults to ``false``; it becomes ``true`` only when the value is ``1``, ``true``, ``yes``, ``on`` or ``y``.\n\n`ONEFLOW_PERSISTENT_IN_STREAM_BUFFER_SIZE_BYTES <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/persistence/persistent_in_stream.cpp#L30>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nThe size of each read from disk. `#5162 <https://github.com/Oneflow-Inc/oneflow/pull/5162>`_\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is empty. If an invalid string or a negative number is given, ``32 * 1024`` (32KB) is used.\n\n`ONEFLOW_DECODER_ENABLE_NVJPEG_HARDWARE_ACCELERATION <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/kernel/image_decoder_random_crop_resize_kernel.cpp#L290>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\n``NVJPEG_VER_MAJOR`` needs to be greater than ``11``.\n
It enables nvJPEG hardware acceleration and warms up the jpeg decoder and the hw_jpeg decoder. `#5851 <https://github.com/Oneflow-Inc/oneflow/pull/5851>`_.\n\nIt uses the hardware JPEG decoder and the NVIDIA nvJPEG library on NVIDIA A100 GPUs.\n\nValues accepted\n^^^^^^^^^^^^^^^\nIt is treated as ``true`` only when the value is ``1``, ``true``, ``yes``, ``on`` or ``y``.\n\n`ONEFLOW_SERVING_DEBUG <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/api/cpp/framework/graph.cpp#L213>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nPrints debug information for OneFlow Serving.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``false``.\n\n`ONEFLOW_DISABLE_VIEW <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/framework/tensor_methods.cpp#L35>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nDisables the view mechanism, meaning that view-related ops stop running.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``false``.\n\n`ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/auto_parallel/boxing_collector.cpp#L82>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nWhether to disable the Middle Node. When it is ``false``, all inter-SBP communication is supported.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``false``.\n\n`ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/embedding/full_cache.cu#L414>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nWhether to disable NUMA_AWARE memory allocation when the OneEmbedding module allocates device memory.\n\nNUMA_AWARE memory allocation means that when allocating pinned host memory, the CPU close to the GPU is considered (for example, for GPU 0 or 1, memory will be allocated on CPU 0).\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``false``.\n\n`ONEFLOW_EP_CUDA_ENABLE_TF32_EXECUTION <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/ep/cuda/cuda_stream.cpp#L96>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nWhether to allow CUDA to use TF32 numeric types for computation.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``true``.\n
\n`ONEFLOW_FUNCTOR_DISABLE_FUSED_MLP <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/functional/impl/nn_functor.cpp#L554>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nWhether to disable the fused_mlp operator implemented with cublasLt in FusedMLPFunctor; if disabled, it degenerates into multiple matrix multiplications.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``false``.\n\n`ONEFLOW_ONE_EMBEDDING_EMBEDDING_SHUFFLE_INDEPENTENT_STREAM <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/job_rewriter/replace_embedding_ops_pass.cpp#L192>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nWhether to put the EmbeddingShuffle of the OneEmbedding module on a separate stream for overlapping execution.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``false``.\n\n`ONEFLOW_ONE_EMBEDDING_GRADIENT_SHUFFLE_USE_FP16 <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/job_rewriter/replace_embedding_ops_pass.cpp#L209>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nWhether to allow the EmbeddingGradientShuffle operator of the OneEmbedding module to use the FP16 data type in the AMP case.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``true``.\n\n`ONEFLOW_ONE_EMBEDDING_NOT_FUSE_CAST_TO_UPDATE <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/job_rewriter/replace_embedding_ops_pass.cpp#L260>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nWhether to disable fusing the cast type conversion and the parameter update of OneEmbedding parameters into one operator in the AMP case.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``false``.\n\n`ONEFLOW_DEBUG_KERNEL_SYNC_CHECK_NUMERICS_DUMP <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/kernel/cpu_numerics_kernel_observer.cpp#L65>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nWhen a value becomes NaN or Inf, dump the data.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``false``.\n\n`ONEFLOW_MLIR_ENABLE_IR_PRINTING <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/ir/lib/OneFlow/Passes.cpp#L768>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nControls whether to print the IR when running each pass during debugging.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``false``.\n\n`ONEFLOW_MLIR_STDOUT <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/ir/oneflow-extension/extension.cpp#L151>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nControls whether MLIR outputs log information to the console.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``false``.\n
\n`ONEFLOW_MLIR_DUMP_IR <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/ir/oneflow-extension/extension.cpp#L152>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nControls whether to dump the IR to files.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``false``.\n\n`ONEFLOW_MLIR_ENABLE_ROUND_TRIP <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/ir/oneflow-extension/ir_pass.cpp#L157>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nControls whether the OneFlow Job goes through MLIR.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``false``.\n\n`ONEFLOW_KERNEL_REDUCE_SUM_USE_MATMUL <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/user/kernels/reduce_kernel.cpp#L333>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nWhether to use matrix multiplication for reduce_sum.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``false``.\n\n`ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM <https://github.com/Oneflow-Inc/oneflow/blob/dd580f21ffb6e4d23a899c7e0ac6d2bc502f3f1a/oneflow/core/job_rewriter/fuse_embedding_interaction_pass.cpp#L35>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nWhether to quantize the shuffle communication in the multi-GPU OneEmbedding case.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``false``.\n\n`ONEFLOW_TENSOR_BUFFER_ALIGNED_SIZE <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/common/tensor_buffer.cpp#L29>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nThe alignment size used when allocating TensorBuffer memory.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``1024``.\n\n`ONEFLOW_TENSOR_BUFFER_POOL_THREAD_LOCAL_CACHE_SIZE <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/common/tensor_buffer.cpp#L206>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nControls the size of the ``thread_local_cache`` in TensorBufferPool.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``64``.\n\n`ONEFLOW_GRPC_MAX_MESSAGE_BYTE_SIZE <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/control/ctrl_service.cpp#L45>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nSets the maximum size of a gRPC transport message.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``-1``.\n
\n`ONEFLOW_ONE_EMBEDDING_PERSISTENT_TABLE_CAPACITY_HINT <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/embedding/persistent_table.cpp#L410>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nControls the initial capacity of OneEmbedding's PersistentTable, to avoid frequent expansion.\n\nValues accepted\n^^^^^^^^^^^^^^^\nOneEmbedding calculates it according to the actual situation; users can also choose to configure a larger capacity.\n\n`ONEFLOW_ONE_EMBEDDING_PERSISTENT_TABLE_NUM_WORKERS <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/embedding/persistent_table.cpp#L435>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nThe number of threads used for reading and writing the PersistentTable of OneEmbedding.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``4``.\n\n`ONEFLOW_EP_CUDA_CONST_BUFFER_ELEMENT_COUNT <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/ep/cuda/cuda_device.cpp#L62>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nSpecifies the size of the all-zero and all-one buffers on the CUDA device.\n\nThese buffers can be used together with matrix multiplication to implement operations such as reduce_sum.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``1024x1024``.\n\n`OMP_NUM_THREADS <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/job/env_global_objects_scope.cpp#L96>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nSets the number of threads used by OpenMP.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is generated by specific `computational logic <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/job/env_global_objects_scope.cpp#L106-L108>`_.\n\n`SBP_INFER_RULE_TAG <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/operator/operator.cpp#L718>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nSpecifies the SBP derivation rule.\n\nValues accepted\n^^^^^^^^^^^^^^^\nWhen the value is ``1``, select the SBP that satisfies the producer, or the SBP with the smallest cost, as much as possible.\n\nWhen the value is ``2``, select the SBP that matches best.\n\nWhen the value is ``3``, select the SBP with the smallest cost.\n\n`ONEFLOW_TENSOR_BUFFER_GROWTH_FACTOR <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/common/tensor_buffer.cpp#L35>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nControls the growth factor of TensorBuffer.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``1.0``.\n
``1.0``\n\n`ONEFLOW_TENSOR_BUFFER_SHRINK_FACTOR <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/common/tensor_buffer.cpp#L41>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nControl the shrink factor of TensorBuffer\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``0.7``\n\n`ONEFLOW_TENSOR_BUFFER_POOL_SIZE_FACTOR <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/common/tensor_buffer.cpp#L200>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nControl the size factor of TensorBuffer\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``2.0``\n\n`AUTO_PARALLEL_TRANSFER_COST <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/framework/sbp_infer_util.cpp#L544>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nControl the transfer cost used in automatic parallelism\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``1.65e8``\n\n\n`ONEFLOW_DEBUG_PASS <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/job/job_build_and_infer_ctx.cpp#L991>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nSet it to a pass name to print the job before and after that specific pass, such as ``export ONEFLOW_DEBUG_PASS=\"FuseAddToOutputPass\"``.\n\nOr set it to ``ALL`` to print the job before and after every pass, such as ``export ONEFLOW_DEBUG_PASS=\"ALL\"``.\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``empty``\n\n
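A minimal sketch of setting this variable from Python rather than from the shell (an illustration only; set it early, before OneFlow compiles the job):\n\n::\n\n    import os\n\n    # set the variable before OneFlow compiles any nn.Graph job\n    os.environ[\"ONEFLOW_DEBUG_PASS\"] = \"FuseAddToOutputPass\"\n\n    import oneflow as flow\n    # building and running an nn.Graph afterwards prints the job\n    # before and after the selected pass\n\n`ONEFLOW_PROFILER_HOST_THREAD_NAME_PREFIX <https://github.com/Oneflow-Inc/oneflow/blob/v1.0.0/oneflow/core/profiler/profiler.cpp#L39>`_\n---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n\nAdd a prefix to the name of the named host thread in the profiling context to facilitate sorting in the visualization tool (Nsight)\n\nValues accepted\n^^^^^^^^^^^^^^^\nThe default value is ``empty``\n"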
  },
  {
    "path": "docs/source/graph.rst",
    "content": "oneflow.nn.Graph\n============================================================\nBase class for running neural networks in Static Graph Mode.\n\nCurrently, there are two main ways to run models in deep learning frameworks, namely dynamic graphs and static graphs , which are also conventionally referred to as :ref:`dynamic graph` and :ref:`static graph` in OneFlow.\n\nBoth approaches have their advantages and disadvantages, and OneFlow provides support for both approaches, with Eager mode being the default.\n\nGenerally speaking, dynamic graphs are easier to use and static graphs have more performance advantages. :class:`oneflow.nn.Graph` module is provided by OneFlow to allow users to build static graphs and train models with Eager-like programming conventions.\n\n.. contents:: oneflow.nn.Graph\n    :depth: 2\n    :local:\n    :class: this-will-duplicate-information-and-it-is-still-useful-here\n    :backlinks: top\n\n.. _dynamic graph:\n\nEager Mode to Static Graph Mode\n------------------------------------------------------------\n\nOneFlow runs in Eager mode by default.\n\nOneFlow's nn.Graph is programmed in a style very similar to Eager Mode, so it is possible to make small changes and get large performance gains.\n\nThe following script shows the process of building a neural network in eager mode using the interface under ``oneflow.nn`` :\n\n\n.. code-block:: \n\n    import oneflow as flow\n    import oneflow.nn as nn\n\n    class ModuleMyLinear(nn.Module):\n        def __init__(self, in_features, out_features):\n            super().__init__()\n            self.weight = nn.Parameter(flow.randn(in_features, out_features))\n            self.bias = nn.Parameter(flow.randn(out_features))\n\n        def forward(self, input):\n            return flow.matmul(input, self.weight) + self.bias\n\n    linear_model = ModuleMyLinear(4, 3)\n\n\nEager ``nn.Module`` can be reused by ``nn.Graph``. The above script for eager mode can be changed to static Graph mode by adding just a few lines of code, which consists of the following steps:\n\n- Define your customized graph as a subclass of ``nn.Graph``\n- At the beginning of __init__. Call super().__init__() to let OneFlow do the necessary initialization of the Graph\n- Reuse the ``nn.Module`` object in Eager mode in __init__ (self.model = model)\n- Describe the computation in the ``build`` method\n- Instantiate your graph then call it.\n\n.. code-block:: \n\n    class GraphMyLinear(nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.model = linear_model\n\n        def build(self, input):\n            return self.model(input)\n\n    graph_mylinear = GraphMyLinear()\n    input = flow.randn(1, 4)\n    out = graph_mylinear(input)\n    print(out)\n\n    tensor([[-0.3298, -3.7907,  0.1661]], dtype=oneflow.float32)\n\n.. _static graph:\n\nStatic Graph Mode\n------------------------------------------------------------\n\n\nConstructing a Graph\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nBase class for training or evaluating a neural network in static graph mode.\n\n.. currentmodule:: oneflow.nn.Graph\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    __init__\n    build\n    add_optimizer\n    set_grad_scaler\n\nExecuting a Graph\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nCall a nn.Graph instance to run a customized graph.\n\n.. currentmodule:: oneflow.nn.Graph\n\n.. 
autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    __call__\n\n\n\nConfig options on a Graph\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nOptimization options of an nn.Graph.\n\n.. currentmodule:: oneflow.nn.graph.graph_config.GraphConfig\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    enable_amp\n    enable_zero\n    allow_fuse_model_update_ops\n    allow_fuse_add_to_output\n    allow_fuse_cast_scale\n    set_gradient_accumulation_steps\n    enable_cudnn_conv_heuristic_search_algo\n    enable_straighten_algorithm\n    enable_compress_memory\n    \n\nConfig options on a GraphModule\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nGraphModule is the graph representation of an nn.Module in an nn.Graph.\n\nWhen an nn.Module is added into an nn.Graph, it is wrapped into a ProxyModule. The ProxyModule has a GraphModule inside it.\nYou can get and set the GraphModule to enable graph optimization on the nn.Module.\n\n.. currentmodule:: oneflow.nn.graph.graph_block.GraphModule\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    set_stage\n    activation_checkpointing\n\nSave & Load a Model\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. currentmodule:: oneflow.nn.Graph\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    state_dict\n    load_state_dict\n\n
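For example, a minimal sketch of saving and restoring a graph's state, reusing the ``graph_mylinear`` instance from the example above (the checkpoint path is an assumption):\n\n.. code-block:: \n\n    # save the state of the graph (the weights of the reused nn.Module)\n    state_dict = graph_mylinear.state_dict()\n    flow.save(state_dict, \"./graph_ckpt\")  # assumed checkpoint path\n\n    # restore it later, e.g. into a freshly built graph\n    graph_mylinear.load_state_dict(flow.load(\"./graph_ckpt\"))\n\n\nDebug a Graph\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    __repr__\n    debug\n    name\n\n\n\n"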
  },
  {
    "path": "docs/source/hub.rst",
    "content": "oneflow.hub\n===================================\n\n.. The documentation is referenced from: \n   https://pytorch.org/docs/1.10/hub.html\n\nOneflow Hub is a pre-trained model repository designed to facilitate research reproducibility.\n\nPublishing models\n-----------------\n\nOneflow Hub supports publishing pre-trained models(model definitions and pre-trained weights)\nto a github repository by adding a simple ``hubconf.py`` file;\n\n``hubconf.py`` can have multiple entrypoints. Each entrypoint is defined as a python function\n(example: a pre-trained model you want to publish).\n\n::\n\n    def entrypoint_name(*args, **kwargs):\n        # args & kwargs are optional, for models which take positional/keyword arguments.\n        ...\n\nHow to implement an entrypoint?\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nHere is a code snippet specifies an entrypoint for ``resnet18`` model if we expand\nthe implementation in ``Oneflow-Inc/vision/hubconf.py``.\nIn most case importing the right function in ``hubconf.py`` is sufficient. Here we\njust want to use the expanded version as an example to show how it works.\nYou can see the full script in\n`Oneflow-Inc/vision repo <https://github.com/Oneflow-Inc/vision/blob/master/hubconf.py>`_\n\n::\n\n    dependencies = ['oneflow']\n    from flowvision.models.resnet import resnet18 as _resnet18\n\n    # resnet18 is the name of entrypoint\n    def resnet18(pretrained=False, **kwargs):\n        \"\"\" # This docstring shows up in hub.help()\n        Resnet18 model\n        pretrained (bool): kwargs, load pretrained weights into the model\n        \"\"\"\n        # Call the model, load pretrained weights\n        model = _resnet18(pretrained=pretrained, **kwargs)\n        return model\n\n\n- ``dependencies`` variable is a **list** of package names required to **load** the model. Note this might\n  be slightly different from dependencies required for training a model.\n- ``args`` and ``kwargs`` are passed along to the real callable function.\n- Docstring of the function works as a help message. It explains what does the model do and what\n  are the allowed positional/keyword arguments. It's highly recommended to add a few examples here.\n- Entrypoint function can either return a model(nn.module), or auxiliary tools to make the user workflow smoother, e.g. tokenizers.\n- Callables prefixed with underscore are considered as helper functions which won't show up in :func:`oneflow.hub.list()`.\n- Pretrained weights can either be stored locally in the github repo, or loadable by\n  :func:`oneflow.hub.load_state_dict_from_url()`. If less than 2GB, it's recommended to attach it to a `project release <https://help.github.com/en/articles/distributing-large-binaries>`_\n  and use the url from the release.\n  In the example above ``flowvision.models.resnet.resnet18`` handles ``pretrained``, alternatively you can put the following logic in the entrypoint definition.\n\n::\n\n    if pretrained:\n        # For checkpoint saved in local github repo, e.g. 
<RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth\n        dirname = os.path.dirname(__file__)\n        checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)\n        state_dict = oneflow.load(checkpoint)\n        model.load_state_dict(state_dict)\n\n        # For checkpoint saved elsewhere\n        checkpoint = 'https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ResNet/resnet18.zip'\n        model.load_state_dict(oneflow.hub.load_state_dict_from_url(checkpoint, progress=False))\n\n\nImportant Notice\n^^^^^^^^^^^^^^^^\n\n- The published models should be at least in a branch/tag. It can't be a random commit.\n\n\nLoading models from Hub\n-----------------------\n\nOneFlow Hub provides convenient APIs to explore all available models in the hub\nthrough :func:`oneflow.hub.list()`, show docstring and examples through\n:func:`oneflow.hub.help()`, and load the pre-trained models using\n:func:`oneflow.hub.load()`.\n\n\n.. automodule:: oneflow.hub\n\n.. autofunction:: list\n\n.. autofunction:: help\n\n.. autofunction:: load\n\n.. autofunction:: download_url_to_file\n\n.. autofunction:: load_state_dict_from_url\n\nRunning a loaded model:\n^^^^^^^^^^^^^^^^^^^^^^^\n\nNote that ``*args`` and ``**kwargs`` in :func:`oneflow.hub.load()` are used to\n**instantiate** a model. After you have loaded a model, how can you find out\nwhat you can do with the model?\nA suggested workflow is:\n\n- ``dir(model)`` to see all available methods of the model.\n- ``help(model.foo)`` to check what arguments ``model.foo`` takes to run.\n\nTo help users explore without referring to documentation back and forth, we strongly\nrecommend repo owners make function help messages clear and succinct. It's also helpful\nto include a minimal working example.\n\n
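For instance, a minimal sketch of this workflow, assuming the ``Oneflow-Inc/vision`` repo and the ``resnet18`` entrypoint from the publishing example above:\n\n::\n\n    import oneflow.hub\n\n    # explore the entrypoints exposed by the repo's hubconf.py\n    print(oneflow.hub.list('Oneflow-Inc/vision'))\n\n    # read the docstring of one entrypoint, then instantiate it\n    print(oneflow.hub.help('Oneflow-Inc/vision', 'resnet18'))\n    model = oneflow.hub.load('Oneflow-Inc/vision', 'resnet18', pretrained=True)\n\nWhere are my downloaded models saved?\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nThe locations are searched in the following order:\n\n- Calling ``hub.set_dir(<PATH_TO_HUB_DIR>)``\n- ``$ONEFLOW_HOME/hub``, if environment variable ``ONEFLOW_HOME`` is set.\n- ``$XDG_CACHE_HOME/oneflow/hub``, if environment variable ``XDG_CACHE_HOME`` is set.\n- ``~/.cache/oneflow/hub``\n\n.. autofunction:: get_dir\n\n.. autofunction:: set_dir\n\nCaching logic\n^^^^^^^^^^^^^\n\nBy default, we don't clean up files after loading them. Hub uses the cache by default if it already exists in the\ndirectory returned by :func:`~oneflow.hub.get_dir()`.\n\nUsers can force a reload by calling ``hub.load(..., force_reload=True)``. This will delete\nthe existing github folder and downloaded weights, and reinitialize a fresh download. This is useful\nwhen updates are published to the same branch, so users can keep up with the latest release.\n\n\nKnown limitations:\n^^^^^^^^^^^^^^^^^^\nOneFlow Hub works by importing the package as if it was installed. There are some side effects\nintroduced by importing in Python. For example, you can see new items in the Python caches\n``sys.modules`` and ``sys.path_importer_cache``, which is normal Python behavior.\nThis also means that you may have import errors when importing different models\nfrom different repos, if the repos have the same sub-package names (typically, a\n``model`` subpackage). 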
A workaround for these kinds of import errors is to\nremove the offending sub-package from the ``sys.modules`` dict; more details can\nbe found in `this github issue\n<https://github.com/pytorch/hub/issues/243#issuecomment-942403391>`_.\n\nA known limitation that is worth mentioning here: users **CANNOT** load two different branches of\nthe same repo in the **same python process**. It's just like installing two packages with the\nsame name in Python, which is not good. Cache might join the party and give you surprises if you\nactually try that. Of course it's totally fine to load them in separate processes.\n"
  },
  {
    "path": "docs/source/image.rst",
    "content": "oneflow.nn.image\n======================================\nImage operations for neural networks\n--------------------------------------\n.. currentmodule:: oneflow.nn.image\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    Resize\n    batch_align\n    decode\n    flip\n    normalize\n\n\n"
  },
  {
    "path": "docs/source/index.rst",
    "content": "OneFlow API Reference\n===================================\n\n\nDistributed performance (high efficiency) is the core technical difficulty of deep learning frameworks. \n\nOneFlow upholds the core concept and architecture of static compilation and streaming parallelism around performance improvement and heterogeneous distributed scaling, solving the challenge of memory wall at cluster level with world-leading technology.\n\n\n.. toctree::\n    :maxdepth: 1\n\n    troubleshooting\n\n\n\n.. toctree::\n    :maxdepth: 1\n    :caption: OneFlow Python API\n\n    oneflow\n    nn\n    nn.functional\n    tensor\n    tensor_attributes\n    type_info\n    autograd\n    cuda\n    distributed\n    distributions\n    hub\n    linalg\n    nn.init\n    optim\n    graph\n    auto_parallel\n    image\n    utils.data\n    utils.global_view\n    utils.tensor\n    one_embedding\n    environment_variables\n    special\n\n\n\nIndices and tables\n==================\n\n* :ref:`genindex`\n* :ref:`modindex`\n* :ref:`search`\n"
  },
  {
    "path": "docs/source/linalg.rst",
    "content": "oneflow.linalg\n===================================\n\n.. The documentation is referenced from: \n   https://pytorch.org/docs/1.10/linalg.html\n\nCommon linear algebra operations.\n\nMatrix Properties\n-----------------\n\n.. currentmodule:: oneflow.linalg\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    norm \n    vector_norm\n    matrix_norm\n    diagonal\n    inv\n    cross\n    det\n"
  },
  {
    "path": "docs/source/nn.functional.rst",
    "content": "oneflow.nn.functional\n===========================================\n\n.. The documentation is referenced from: https://pytorch.org/docs/1.10/nn.functional.html.\n\n.. contents:: oneflow.nn.functional\n    :depth: 2\n    :local:\n    :class: this-will-duplicate-information-and-it-is-still-useful-here\n    :backlinks: top\n\n.. currentmodule:: oneflow.nn.functional\n\nConvolution functions\n-------------------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    conv1d\n    conv2d\n    conv3d\n    conv_transpose1d\n    conv_transpose2d\n    conv_transpose3d\n    fold\n    unfold\n\nNormalization functions\n-----------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    batch_norm\n    layer_norm\n    normalize\n    group_norm\n\nPooling functions\n----------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    avg_pool1d\n    avg_pool2d\n    avg_pool3d\n    max_pool1d\n    max_pool2d\n    max_pool3d\n    max_unpool1d\n    max_unpool2d\n    max_unpool3d\n    adaptive_avg_pool1d\n    adaptive_avg_pool2d\n    adaptive_avg_pool3d\n    adaptive_max_pool1d\n    adaptive_max_pool2d\n    adaptive_max_pool3d\n\nNon-linear activation functions\n-------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    threshold\n    relu\n    hardtanh\n    hardswish\n    relu6\n    elu\n    selu\n    celu\n    leaky_relu\n    square_relu\n    prelu\n    glu\n    gelu\n    quick_gelu\n    logsigmoid\n    hardshrink\n    softsign\n    softplus\n    softmax\n    softshrink\n    log_softmax\n    gumbel_softmax\n    tanh\n    sigmoid\n    hardsigmoid\n    silu\n    mish\n\nLinear functions\n----------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    linear\n\nDropout functions\n-----------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    dropout\n    dropout1d\n    dropout2d\n    dropout3d\n\nSparse functions\n----------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    embedding\n    one_hot\n\nDistance functions\n----------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    cosine_similarity\n    pairwise_distance\n\n\nLoss functions\n--------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    sparse_softmax_cross_entropy\n    cross_entropy\n    ctc_loss\n    l1_loss\n    mse_loss\n    smooth_l1_loss\n    triplet_margin_loss\n    binary_cross_entropy\n    binary_cross_entropy_with_logits\n\nVision functions\n----------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    deform_conv2d\n    pad\n    interpolate\n    upsample\n    grid_sample\n    affine_grid\n\nGreedy decoder\n----------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    ctc_greedy_decoder\n\n"
  },
  {
    "path": "docs/source/nn.init.rst",
    "content": "oneflow.nn.init\n===============\n\n.. The documentation is referenced from: \n   https://pytorch.org/docs/1.10/nn.init.html\n\n.. currentmodule:: oneflow.nn.init\n.. autofunction:: calculate_gain\n.. autofunction:: uniform_\n.. autofunction:: normal_\n.. autofunction:: constant_\n.. autofunction:: ones_\n.. autofunction:: zeros_\n.. autofunction:: xavier_uniform_\n.. autofunction:: xavier_normal_\n.. autofunction:: kaiming_uniform_\n.. autofunction:: kaiming_normal_\n.. autofunction:: trunc_normal_\n.. autofunction:: orthogonal_\n"
  },
  {
    "path": "docs/source/nn.rst",
    "content": "oneflow.nn\n===================================\n\n.. The documentation is referenced from: \n   https://pytorch.org/docs/1.10/nn.html\n\nThese are the basic building blocks for graphs:\n\n.. contents:: oneflow.nn\n    :depth: 2\n    :local:\n    :class: this-will-duplicate-information-and-it-is-still-useful-here\n    :backlinks: top\n\n.. currentmodule:: oneflow.nn\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: \n\n    Parameter\n\n\nContainers\n----------------------------------\n.. currentmodule:: oneflow.nn\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    Module\n    Sequential\n    ModuleList\n    ModuleDict\n    ParameterList\n    ParameterDict\n\nnn.Module\n----------------------------------\n.. currentmodule:: oneflow.nn.Module\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    add_module\n    apply\n    buffers\n    children\n    cpu\n    cuda\n    double\n    train\n    eval\n    extra_repr\n    float\n    forward\n    load_state_dict\n    modules\n    named_buffers\n    named_children\n    named_modules\n    named_parameters\n    parameters\n    register_buffer\n    register_forward_hook\n    register_forward_pre_hook\n    register_backward_hook\n    register_full_backward_hook\n    register_state_dict_pre_hook\n    register_parameter\n    requires_grad_\n    state_dict\n    to\n    zero_grad\n\n\n\nContainers\n\nConvolution Layers\n----------------------------------\n.. currentmodule:: oneflow\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    nn.Conv1d \n    nn.Conv2d \n    nn.Conv3d\n    nn.ConvTranspose1d \n    nn.ConvTranspose2d \n    nn.ConvTranspose3d\n    nn.Unfold\n    nn.Fold\n\nPooling Layers\n----------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    nn.MaxPool1d \n    nn.MaxPool2d \n    nn.MaxPool3d \n    nn.MaxUnpool1d\n    nn.MaxUnpool2d\n    nn.MaxUnpool3d\n    nn.AdaptiveAvgPool1d \n    nn.AdaptiveAvgPool2d \n    nn.AdaptiveAvgPool3d\n    nn.AdaptiveMaxPool1d \n    nn.AdaptiveMaxPool2d \n    nn.AdaptiveMaxPool3d\n    nn.AvgPool1d \n    nn.AvgPool2d \n    nn.AvgPool3d\n\nPadding Layers\n----------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    nn.ConstantPad1d \n    nn.ConstantPad2d \n    nn.ConstantPad3d\n    nn.ReflectionPad1d\n    nn.ReflectionPad2d\n    nn.ReplicationPad1d\n    nn.ReplicationPad2d\n    nn.ZeroPad2d\n\nNon-linear Activations (weighted sum, nonlinearity)\n----------------------------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    nn.ELU \n    nn.Hardshrink\n    nn.Hardsigmoid \n    nn.Hardswish \n    nn.Hardtanh \n    nn.LeakyReLU \n    nn.LogSigmoid \n    nn.PReLU \n    nn.ReLU\n    nn.ReLU6 \n    nn.SELU \n    nn.CELU \n    nn.GELU \n    nn.QuickGELU \n    nn.SquareReLU\n    nn.SiLU \n    nn.Sigmoid \n    nn.Mish \n    nn.Softplus \n    nn.Softshrink \n    nn.Softsign \n    nn.Tanh \n    nn.Threshold \n    nn.GLU\n\nNon-linear Activations (other)\n----------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    nn.Softmax\n    nn.LogSoftmax\n\nNormalization Layers\n----------------------------------\n\n.. 
autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    nn.BatchNorm1d \n    nn.BatchNorm2d \n    nn.BatchNorm3d\n    nn.SyncBatchNorm\n    nn.FusedBatchNorm1d \n    nn.FusedBatchNorm2d\n    nn.FusedBatchNorm3d \n    nn.GroupNorm \n    nn.InstanceNorm1d \n    nn.InstanceNorm2d \n    nn.InstanceNorm3d \n    nn.LayerNorm\n    nn.RMSLayerNorm\n    nn.RMSNorm\n\nRecurrent Layers\n----------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    nn.RNN\n    nn.LSTM\n    nn.GRU\n    nn.RNNCell\n    nn.LSTMCell\n    nn.GRUCell\n\nLinear Layers\n----------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    nn.Identity\n    nn.Linear\n\nDropout Layers\n----------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    nn.Dropout\n    nn.Dropout1d\n    nn.Dropout2d\n    nn.Dropout3d\n\nSparse Layers\n----------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    nn.Embedding\n\nDistance Functions\n------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    nn.CosineSimilarity\n    nn.PairwiseDistance\n\nLoss Functions\n----------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    nn.BCELoss \n    nn.BCEWithLogitsLoss \n    nn.CTCLoss \n    nn.CombinedMarginLoss \n    nn.CrossEntropyLoss \n    nn.KLDivLoss \n    nn.L1Loss \n    nn.MSELoss \n    nn.MarginRankingLoss \n    nn.NLLLoss \n    nn.SmoothL1Loss \n    nn.TripletMarginLoss\n\nVision Layers\n----------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    nn.PixelShuffle \n    nn.Upsample \n    nn.UpsamplingBilinear2d \n    nn.UpsamplingNearest2d\n\n\nDataParallel Layers (multi-GPU, distributed)\n--------------------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n    \n    nn.parallel.DistributedDataParallel\n\n\nData loading and preprocessing Layers\n----------------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    nn.COCOReader\n    nn.CoinFlip\n    nn.CropMirrorNormalize\n    nn.OFRecordBytesDecoder\n    nn.OFRecordImageDecoder\n    nn.OFRecordImageDecoderRandomCrop\n    nn.OFRecordRawDecoder\n    nn.OFRecordReader\n\nQuantization Aware Training\n--------------------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    nn.MinMaxObserver\n    nn.MovingAverageMinMaxObserver\n    nn.FakeQuantization\n    nn.QatConv1d\n    nn.QatConv2d\n    nn.QatConv3d\n\nUtilities\n---------\n\nFrom the ``oneflow.nn.utils`` module\n\n.. currentmodule:: oneflow.nn.utils\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    clip_grad_norm_\n    clip_grad_value_\n    weight_norm\n    remove_weight_norm\n\nUtility functions in other modules\n\n.. currentmodule:: oneflow\n.. 
autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    nn.utils.rnn.PackedSequence\n    nn.utils.rnn.pack_padded_sequence\n    nn.utils.rnn.pad_packed_sequence\n    nn.utils.rnn.pad_sequence\n    nn.utils.rnn.pack_sequence\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    nn.Flatten\n\nQuantized Functions\n--------------------\n\nQuantization refers to techniques for performing computations and \nstoring tensors at lower bitwidths than floating point precision.\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template:\n\n    nn.FakeQuantization\n    nn.MinMaxObserver\n    nn.MovingAverageMinMaxObserver\n    nn.Quantization\n"
  },
  {
    "path": "docs/source/one_embedding.rst",
    "content": "oneflow.one_embedding\n===================================\n\nEmbedding is an important component of recommender system, and it has also spread to many fields outside recommender systems. Each framework provides basic operators for Embedding, for example, ``flow.nn.Embedding`` in OneFlow:\n\n::\n\n    import numpy as np\n    import oneflow as flow\n    indices = flow.tensor([[1, 2, 4, 5], [4, 3, 2, 9]], dtype=flow.int)\n    embedding = flow.nn.Embedding(10, 3)\n    y = embedding(indices)\n\n\nOneEmbedding is the large-scale Embedding solution that OneFlow provides to solve the problem of large-scale deep recommender systems. OneEmbedding has the following advantages compared to ordinary opeartors:\n\n    - With Flexible hierarchical storage, OneEmbedding can place the Embedding table on GPU memory, CPU memory or SSD, and allow high-speed devices to be used as caches for low-speed devices to achieve both speed and capacity.\n\n    - OneEmbedding supports dynamic expansion.\n\n.. note ::\n    Please refer to `Large-Scale Embedding Solution: OneEmbedding <https://docs.oneflow.org/en/master/cookies/one_embedding.html>`__\n    for a brief introduction to all features related to OneEmbedding.\n\nConfigure Embedding Table \n----------------------------------\n\nOneEmbedding supports simultaneous creation of multiple Embedding table. The following codes configured three Embedding tables.\n\n.. code-block:: \n\n    import oneflow as flow\n    import oneflow.nn as nn\n    import numpy as np\n\n    tables = [\n        flow.one_embedding.make_table_options(\n            flow.one_embedding.make_uniform_initializer(low=-0.1, high=0.1)\n        ),\n        flow.one_embedding.make_table_options(\n            flow.one_embedding.make_uniform_initializer(low=-0.05, high=0.05)\n        ),\n        flow.one_embedding.make_table_options(\n            flow.one_embedding.make_uniform_initializer(low=-0.15, high=0.15)\n        ),\n    ]\n\nWhen configuring the Embedding table, you need to specify the initialization method. The above Embedding tables are initialized in the ``uniform`` method. The result of configuring the Embedding table is stored in the ``tables`` variable\n\n.. autofunction:: oneflow.one_embedding.make_table_options\n.. autofunction:: oneflow.one_embedding.make_table\n\ninitialization method\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. currentmodule:: oneflow.one_embedding\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    make_uniform_initializer\n    make_normal_initializer\n\n\nConfigure the Storage Attribute of the Embedding Table\n--------------------------------------------------------------------\nThen run the following codes to configure the storage attribute of the Embedding table:\n\n.. code-block:: \n\n    store_options = flow.one_embedding.make_cached_ssd_store_options(\n    cache_budget_mb=8142,\n    persistent_path=\"/your_path_to_ssd\", \n    capacity=40000000,\n    size_factor=1,              \n    physical_block_size=4096\n    )\n\nStorage Method\n^^^^^^^^^^^^^^^^^^^^\n\n.. currentmodule:: oneflow.one_embedding\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    make_device_mem_store_options\n    make_cached_ssd_store_options \n    make_cached_host_mem_store_options\n\n.. 
note ::\n    \n    Please refer to `Large-Scale Embedding Solution: OneEmbedding <https://docs.oneflow.org/en/master/cookies/one_embedding.html#feature-id-and-dynamic-insertion>`__\n    to learn how to choose the proper storage configuration.\n\n\nInstantiate Embedding\n--------------------------------------------------------------------\nAfter the above configuration is completed, you can use ``MultiTableEmbedding`` to instantiate the Embedding layer.\n\n.. code-block:: \n\n    embedding_size = 128\n    embedding = flow.one_embedding.MultiTableEmbedding(\n        name=\"my_embedding\",\n        embedding_dim=embedding_size,\n        dtype=flow.float,\n        key_type=flow.int64,\n        tables=tables,\n        store_options=store_options,\n    )\n\n    embedding.to(\"cuda\")\n\n.. note ::\n    \n    Please refer to `Large-Scale Embedding Solution: OneEmbedding <https://docs.oneflow.org/en/master/cookies/one_embedding.html#feature-id-and-multi-table-query>`__\n    to learn about Feature ID and Multi-Table Query.\n\n\nMultiTableEmbedding\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. autofunction:: oneflow.one_embedding.MultiTableEmbedding\n\n.. currentmodule:: oneflow.one_embedding.MultiTableEmbedding\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    \n    forward\n    save_snapshot\n    load_snapshot\n\nMultiTableMultiColumnEmbedding\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. autofunction:: oneflow.one_embedding.MultiTableMultiColumnEmbedding\n\n.. currentmodule:: oneflow.one_embedding.MultiTableMultiColumnEmbedding\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    \n    forward\n    save_snapshot\n    load_snapshot\n\nConstruct Graph for Training\n--------------------------------------------------------------------\nOneEmbedding is only supported in Graph mode.\n\n.. code-block:: \n\n    num_tables = 3\n    mlp = flow.nn.FusedMLP(\n        in_features=embedding_size * num_tables,\n        hidden_features=[512, 256, 128],\n        out_features=1,\n        skip_final_activation=True,\n    )\n    mlp.to(\"cuda\")\n\n    class TrainGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.embedding_lookup = embedding\n            self.mlp = mlp\n            self.add_optimizer(\n                flow.optim.SGD(self.embedding_lookup.parameters(), lr=0.1, momentum=0.0)\n            )\n            self.add_optimizer(\n                flow.optim.SGD(self.mlp.parameters(), lr=0.1, momentum=0.0)\n            )\n        def build(self, ids):\n            embedding = self.embedding_lookup(ids)\n            loss = self.mlp(flow.reshape(embedding, (-1, num_tables * embedding_size)))\n            loss = loss.sum()\n            loss.backward()\n            return loss\n\n.. note ::\n    \n    Please refer to `Distributed Training: OneEmbedding <https://docs.oneflow.org/en/master/parallelism/01_introduction.html>`__\n    to learn about constructing a Graph for training.\n\n
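A minimal usage sketch of the graph above (the batch size and id values are illustrative assumptions; ``ids`` has shape ``(batch_size, num_tables)``, one column per Embedding table):\n\n.. code-block:: \n\n    graph = TrainGraph()\n\n    # random feature IDs below the configured capacity, one column per table\n    ids = np.random.randint(0, 40000000, (16, num_tables), dtype=np.int64)\n    loss = graph(flow.tensor(ids, dtype=flow.int64).to(\"cuda\"))\n\n\nPersistent Read & Write\n-----------------------------------------------\n.. currentmodule:: oneflow.one_embedding\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    \n    make_persistent_table_reader\n    make_persistent_table_writer\n\n.. automodule:: oneflow.one_embedding\n    :members: Ftrl\n\n"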
  },
  {
    "path": "docs/source/oneflow.rst",
    "content": "oneflow\n===================================\n\n.. The documentation is referenced from: \n   https://pytorch.org/docs/1.10/torch.html\n\nThe oneflow package contains data structures for multi-dimensional tensors and defines mathematical operations over these tensors. Additionally, it provides many utilities for efficient serializing of Tensors and arbitrary types, and other useful utilities.\n\nIt has a CUDA counterpart, that enables you to run your tensor computations on an NVIDIA GPU with compute capability >= 3.0\n\n.. currentmodule:: oneflow\n\n\nTensor\n-------------------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    BoolTensor\n    ByteTensor\n    CharTensor\n    DoubleTensor\n    FloatTensor\n    HalfTensor\n    IntTensor\n    LongTensor\n\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    is_tensor\n    is_floating_point\n    is_nonzero\n    numel\n    set_printoptions\n    get_default_dtype\n    set_default_dtype\n    set_default_tensor_type\n\n.. _tensor-creation-ops:\n\nCreation Ops\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. note::\n    Random sampling creation ops are listed under :ref:`random-sampling` and\n    include:\n    :func:`oneflow.rand`\n    :func:`oneflow.randn`\n    :func:`oneflow.randint`\n    :func:`oneflow.randperm`\n    \n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    tensor\n    as_tensor\n    as_strided\n    from_numpy\n    zeros\n    zeros_like\n    ones\n    ones_like\n    randn_like\n    randint_like\n    masked_fill\n    new_ones\n    arange\n    linspace\n    eye\n    empty\n    empty_like\n    full\n    full_like\n    tensor_scatter_nd_update\n    logspace\n\n.. _indexing-slicing-joining:\n\nIndexing, Slicing, Joining, Mutating Ops\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    argwhere\n    atleast_1d\n    atleast_2d\n    atleast_3d\n    cat\n    column_stack\n    concat\n    chunk\n    dstack\n    expand\n    gather\n    gather_nd\n    batch_gather\n    hsplit\n    hstack\n    vsplit\n    vstack\n    index_select\n    index_add\n    masked_select\n    movedim\n    narrow\n    nonzero\n    permute\n    repeat\n    reshape\n    row_stack\n    select\n    scatter\n    scatter_add\n    scatter_nd\n    slice\n    slice_update\n    split\n    squeeze\n    stack\n    swapaxes\n    swapdims\n    t\n    tile\n    transpose\n    unbind\n    unsqueeze\n    where\n    tensor_split\n\n.. _random-sampling:\n\nRandom sampling\n-------------------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    seed\n    manual_seed\n    initial_seed\n    get_rng_state\n    set_rng_state\n    bernoulli\n    normal\n    rand\n    randint\n    randn\n    randperm\n    multinomial\n    \nIn-place random sampling\n~~~~~~~~~~~~~~~~~~~~~~~~\n\nThere are a few more in-place random sampling functions defined on Tensors as well. Click through to refer to their documentation:\n- :func:`oneflow.Tensor.normal_` - in-place version of :func:`oneflow.normal`\n- :func:`oneflow.Tensor.uniform_` - numbers sampled from the continuous uniform distribution\n\n\n\nSerialization\n-------------------------------------------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    save\n    load\n\nParallelism\n-------------------------------------------\n\n.. 
autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    set_num_threads\n\n\nLocally disabling gradient computation\n-------------------------------------------\nThe context managers :func:`oneflow.no_grad`, :func:`oneflow.enable_grad`, and\n:func:`oneflow.set_grad_enabled` are helpful for locally disabling and enabling\ngradient computation. These context managers are thread local, so they won't\nwork if you send work to another thread using the ``threading`` module, etc.\n\nExamples::\n\n  >>> import oneflow\n  >>> x = oneflow.zeros(1, requires_grad=True)\n  >>> with oneflow.no_grad():\n  ...     y = x * 2\n  >>> y.requires_grad\n  False\n\n  >>> with oneflow.set_grad_enabled(False):\n  ...     y = x * 2\n  >>> y.requires_grad\n  False\n  \n  >>> with oneflow.set_grad_enabled(True):\n  ...     y = x * 2\n  >>> y.requires_grad\n  True\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    no_grad\n    set_grad_enabled\n    enable_grad\n    is_grad_enabled\n    inference_mode\n\nMath operations\n-------------------------------------------\n\nPointwise Ops\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    abs \n    acos \n    acosh \n    arccos \n    arccosh\n    add \n    addcdiv\n    addcmul\n    asin \n    asinh \n    arcsin \n    arcsinh \n    atan\n    atanh \n    arctan \n    arctanh \n    atan2 \n    ceil \n    ceil_\n    clamp \n    clamp_min\n    clamp_max\n    clip \n    cos \n    cosh \n    div \n    erf \n    erfc \n    erfinv\n    exp \n    expm1 \n    floor \n    floor_ \n    frac\n    frac_\n    fmod \n    gelu\n    quick_gelu\n    square_relu\n    log \n    log1p \n    log2 \n    log10\n    logical_and \n    logical_not \n    logical_or \n    logical_xor \n    bitwise_and\n    bitwise_or\n    bitwise_xor\n    bitwise_not\n    mish\n    mul \n    neg \n    negative \n    pow \n    reciprocal \n    round \n    round_\n    rsqrt \n    selu\n    softmax\n    softplus\n    softsign\n    silu\n    sigmoid \n    sign \n    sin \n    sinh \n    sin_ \n    sqrt \n    square \n    sub \n    tan \n    tanh\n    trunc\n    floor_divide\n    lerp\n    lerp_\n    quantile\n\nReduction Ops\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    \n    argmax  \n    argmin  \n    amax\n    amin\n    any\n    max\n    min  \n    mean  \n    median\n    mode\n    prod\n    nansum\n    std  \n    sum  \n    logsumexp\n    var\n    norm\n    all\n\n\nComparison Ops\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    argsort \n    eq \n    equal \n    gt \n    isinf \n    isnan \n    le \n    lt \n    ne \n    sort \n    topk\n    ge\n    greater\n    greater_equal\n    maximum\n    minimum\n    not_equal\n    isclose\n    allclose\n\nSpectral Ops\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    hann_window\n    \nOther Ops\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. 
autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    adaptive_avg_pool1d\n    adaptive_avg_pool2d\n    adaptive_avg_pool3d\n    broadcast_like \n    cast\n    cumprod \n    cumsum \n    diag \n    diagonal \n    einsum \n    flatten \n    flip \n    in_top_k\n    meshgrid \n    nms\n    roc_auc_score\n    roll \n    searchsorted\n    tensordot\n    tril\n    repeat_interleave\n    triu\n    cross\n    bincount\n    broadcast_shapes\n    broadcast_tensors\n    broadcast_to\n    unique\n\nBLAS and LAPACK Operations\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    addmm \n    bmm\n    baddbmm \n    dot \n    matmul\n    mm\n    mv\n\n"
  },
  {
    "path": "docs/source/optim.rst",
    "content": "oneflow.optim\n===================================\n\n.. The documentation is referenced from: \n   https://pytorch.org/docs/1.10/optim.html\n\noneflow.optim is a package implementing various optimization algorithms. Most commonly used methods are already supported, and the interface is general enough, so that more sophisticated ones can be also easily integrated in the future.\n\nHow to use an optimizer\n-----------------------\n\nTo use :mod:`oneflow.optim` you have to construct an optimizer object, that will hold\nthe current state and will update the parameters based on the computed gradients.\n\nConstructing it\n^^^^^^^^^^^^^^^\n\nTo construct an :class:`Optimizer` you have to give it an iterable containing the\nparameters (all should be :class:`~oneflow.autograd.Variable` s) to optimize. Then,\nyou can specify optimizer-specific options such as the learning rate, weight decay, etc.\n\n.. note::\n    If you need to move a model to GPU via ``.cuda()``, please do so before \n    constructing optimizers for it. Parameters of a model after ``.cuda()`` \n    will be different objects with those before the call.\n\n    In general, you should make sure that optimized parameters live in \n    consistent locations when optimizers are constructed and used. \n    \nExample::\n\n    import oneflow\n    import oneflow.nn as nn\n    import oneflow.optim as optim\n\n    model = nn.Linear(16, 3)\n    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)\n\nPer-parameter options\n^^^^^^^^^^^^^^^^^^^^^\n\n:class:`Optimizer` also support specifying per-parameter options. To do this, instead\nof passing an iterable of :class:`~oneflow.autograd.Variable`, pass in an iterable of\n:class:`dict`. Each of them will define a separate parameter group, and should contain\na ``params`` key, containing a list of parameters belonging to it. Other keys\nshould match the keyword arguments accepted by the optimizers, and will be used\nas optimization options for this group.\n\n.. note::\n\n    You can still pass options as keyword arguments. They will be used as\n    defaults, in the groups that didn't override them. This is useful when you\n    only want to vary a single option, while keeping all others consistent\n    between parameter groups.\n\n\nFor example, this is very useful when one wants to specify per-layer learning rates::\n\n    import oneflow.nn as nn\n    import oneflow.optim as optim\n\n\n    class Model(nn.Module):\n        def __init__(self):\n            super(Model, self).__init__()\n            self.base = nn.Linear(64, 32)\n            self.classifier = nn.Linear(32, 10)\n\n        def forward(self, x):\n            out = self.base(x)\n            out = self.classifier(out)\n            return out\n\n\n    model = Model()\n    optim.SGD(\n        [\n            {\"params\": model.base.parameters()},\n            {\"params\": model.classifier.parameters(), \"lr\": 1e-3},\n        ],\n        lr=1e-2,\n        momentum=0.9,\n    )\n\n\nThis means that ``model.base``'s parameters will use the default learning rate of ``1e-2``,\n``model.classifier``'s parameters will use a learning rate of ``1e-3``, and a momentum of\n``0.9`` will be used for all parameters.\n\nTaking an optimization step\n^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nAll optimizers implement a :func:`~Optimizer.step` method, that updates the\nparameters. It can be used in two ways:\n\n``optimizer.step()``\n~~~~~~~~~~~~~~~~~~~~\n\nThis is a simplified version supported by most optimizers. 
The function can be\ncalled once the gradients are computed using e.g.\n:func:`~oneflow.autograd.Variable.backward`.\n\nExample::\n\n    import oneflow\n    import oneflow.nn as nn\n    import oneflow.nn.functional as F\n    import oneflow.optim as optim\n    from oneflow.utils.data import Dataset, DataLoader\n\n\n    class CustomDataset(Dataset):\n        def __init__(self, num):\n            self.inputs = oneflow.randn(num, 1)\n            self.targets = oneflow.sin(self.inputs)\n\n        def __len__(self):\n            return self.inputs.shape[0]\n\n        def __getitem__(self, index):\n            return self.inputs[index], self.targets[index]\n\n\n    class Model(nn.Module):\n        def __init__(self, input_size):\n            super(Model, self).__init__()\n            self.linear1 = nn.Linear(input_size, 64)\n            self.linear2 = nn.Linear(64, input_size)\n\n        def forward(self, x):\n            out = self.linear1(x)\n            return self.linear2(F.relu(out))\n\n\n    dataset = CustomDataset(10000)\n    dataloader = DataLoader(dataset, batch_size=10)\n    model = Model(1)\n    loss_fn = nn.MSELoss()\n    optimizer = optim.SGD(model.parameters(), lr=1e-3)\n\n    for epoch in range(100):\n        for input, target in dataloader:\n            optimizer.zero_grad()\n            output = model(input)\n            loss = loss_fn(output, target)\n            loss.backward()\n            optimizer.step()\n\n.. _optimizer-algorithms:\n\n.. currentmodule:: oneflow.optim\n\n\nBase class\n----------\n\n.. autoclass:: Optimizer\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    Optimizer.add_param_group\n    Optimizer.load_state_dict\n    Optimizer.state_dict\n    Optimizer.step\n    Optimizer.zero_grad\n\nAlgorithms\n----------\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    Adagrad\n    Adam\n    AdamW\n    LAMB\n    RMSprop\n    SGD\n    LBFGS\n\nAdjust Learning Rate\n--------------------\n\n:mod:`oneflow.optim.lr_scheduler` provides several methods to adjust the learning\nrate based on the number of epochs. 
:class:`oneflow.optim.lr_scheduler.ReduceLROnPlateau`\nallows dynamic learning rate reducing based on some validation measurements.\n\nLearning rate scheduling should be applied after optimizer's update; e.g., you\nshould write your code this way:\n\nExample::\n\n    import oneflow\n    import oneflow.nn as nn\n    import oneflow.nn.functional as F\n    import oneflow.optim as optim\n    from oneflow.utils.data import Dataset, DataLoader\n\n\n    class CustomDataset(Dataset):\n        def __init__(self, num):\n            self.inputs = oneflow.randn(num, 1)\n            self.targets = oneflow.sin(self.inputs)\n\n        def __len__(self):\n            return self.inputs.shape[0]\n\n        def __getitem__(self, index):\n            return self.inputs[index], self.targets[index]\n\n\n    class Model(nn.Module):\n        def __init__(self, input_size):\n            super(Model, self).__init__()\n            self.linear1 = nn.Linear(input_size, 64)\n            self.linear2 = nn.Linear(64, input_size)\n\n        def forward(self, x):\n            out = self.linear1(x)\n            return self.linear2(F.relu(out))\n\n\n    dataset = CustomDataset(10000)\n    dataloader = DataLoader(dataset, batch_size=10)\n    model = Model(1)\n    loss_fn = nn.MSELoss()\n    optimizer = optim.SGD(model.parameters(), lr=1e-3)\n    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)\n\n    for epoch in range(20):\n        for input, target in dataloader:\n            optimizer.zero_grad()\n            output = model(input)\n            loss = loss_fn(output, target)\n            loss.backward()\n            optimizer.step()\n        scheduler.step()\n\nMost learning rate schedulers can be chained (also referred to as\nchaining schedulers).\n\nExample::\n\n    import oneflow\n    import oneflow.nn as nn\n    import oneflow.nn.functional as F\n    import oneflow.optim as optim\n    from oneflow.utils.data import Dataset, DataLoader\n\n\n    class CustomDataset(Dataset):\n        def __init__(self, num):\n            self.inputs = oneflow.randn(num, 1)\n            self.targets = oneflow.sin(self.inputs)\n\n        def __len__(self):\n            return self.inputs.shape[0]\n\n        def __getitem__(self, index):\n            return self.inputs[index], self.targets[index]\n\n\n    class Model(nn.Module):\n        def __init__(self, input_size):\n            super(Model, self).__init__()\n            self.linear1 = nn.Linear(input_size, 64)\n            self.linear2 = nn.Linear(64, input_size)\n\n        def forward(self, x):\n            out = self.linear1(x)\n            return self.linear2(F.relu(out))\n\n\n    dataset = CustomDataset(10000)\n    dataloader = DataLoader(dataset, batch_size=10)\n    model = Model(1)\n    loss_fn = nn.MSELoss()\n    optimizer = optim.SGD(model.parameters(), lr=1e-3)\n    scheduler1 = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)\n    scheduler2 = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10], gamma=0.1)\n\n    for epoch in range(20):\n        for input, target in dataloader:\n            optimizer.zero_grad()\n            output = model(input)\n            loss = loss_fn(output, target)\n            loss.backward()\n            optimizer.step()\n        scheduler1.step()\n        scheduler2.step()\n\nIn many places in the documentation, we will use the following template to refer to schedulers\nalgorithms.\n\n    >>> scheduler = ...\n    >>> for epoch in range(100):\n    >>>     train(...)\n    >>>     validate(...)\n    >>>     
scheduler.step()\n\n.. warning::\n  If you use the learning rate scheduler (calling ``scheduler.step()``) before the optimizer's update\n  (calling ``optimizer.step()``), this will skip the first value of the learning rate schedule. Please \n  check if you are calling ``scheduler.step()`` at the wrong time.\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    lr_scheduler.CosineAnnealingLR\n    lr_scheduler.CosineDecayLR \n    lr_scheduler.ExponentialLR \n    lr_scheduler.LambdaLR \n    lr_scheduler.MultiStepLR\n    lr_scheduler.PolynomialLR \n    lr_scheduler.ReduceLROnPlateau \n    lr_scheduler.StepLR\n    lr_scheduler.ConstantLR\n    lr_scheduler.LinearLR\n    lr_scheduler.ChainedScheduler\n    lr_scheduler.SequentialLR\n    lr_scheduler.CosineAnnealingWarmRestarts\n"
  },
  {
    "path": "docs/source/special.rst",
    "content": "oneflow.special\n======================================\nThe oneflow.special module, modeled after SciPy's special module.\n-------------------------------------------------------------------\n\n.. currentmodule:: oneflow.special\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    digamma\n    erf\n    erfc\n    erfinv\n    exp2\n    expm1\n    log1p\n    log_softmax\n    logsumexp\n    round\n    softmax\n    zeta\n"
  },
  {
    "path": "docs/source/tensor.rst",
    "content": "oneflow.Tensor\n===================================\n\n.. The documentation is referenced from: \n   https://pytorch.org/docs/1.10/tensors.html\n\nA :class:`oneflow.Tensor` is a multi-dimensional matrix containing elements of\na single data type.\n\n.. currentmodule:: oneflow\n\nData types\n----------\n\nOneFlow defines 8 Tensor types with CPU and GPU variants which are as follows:\n\n======================================= =============================================== =============================== ==================================\nData type                               dtype                                           CPU tensor                      GPU tensor\n======================================= =============================================== =============================== ==================================\nBoolean                                 ``oneflow.bool``                                :class:`oneflow.BoolTensor`     :class:`oneflow.cuda.BoolTensor`\n8-bit integer (unsigned)                ``oneflow.uint8``                               :class:`oneflow.ByteTensor`     :class:`oneflow.cuda.ByteTensor`\n8-bit integer (signed)                  ``oneflow.int8``                                :class:`oneflow.CharTensor`     :class:`oneflow.cuda.CharTensor`\n64-bit floating point                   ``oneflow.float64`` or ``oneflow.double``       :class:`oneflow.DoubleTensor`   :class:`oneflow.cuda.DoubleTensor`\n32-bit floating point                   ``oneflow.float32`` or ``oneflow.float``        :class:`oneflow.FloatTensor`    :class:`oneflow.cuda.FloatTensor`\n16-bit floating point                   ``oneflow.float16`` or ``oneflow.half``         :class:`oneflow.HalfTensor`     :class:`oneflow.cuda.HalfTensor`\n32-bit integer (signed)                 ``oneflow.int32`` or ``oneflow.int``            :class:`oneflow.IntTensor`      :class:`oneflow.cuda.IntTensor`\n64-bit integer (signed)                 ``oneflow.int64`` or ``oneflow.long``           :class:`oneflow.LongTensor`     :class:`oneflow.cuda.LongTensor`\n======================================= =============================================== =============================== ==================================\n\nInitializing and basic operations\n---------------------------------\n\nA tensor can be constructed from a Python :class:`list` or sequence using the\n:func:`oneflow.tensor` constructor:\n\n::\n\n    >>> import oneflow\n    >>> import numpy as np\n    >>> oneflow.tensor([[1., -1.], [1., -1.]])\n    tensor([[ 1., -1.],\n            [ 1., -1.]], dtype=oneflow.float32)\n    >>> oneflow.tensor(np.array([[1, 2, 3], [4, 5, 6]]))\n    tensor([[ 1, 2, 3],\n            [ 4, 5, 6]], dtype=oneflow.int64)\n\n.. warning::\n\n    :func:`oneflow.tensor` always copies :attr:`data`. If you have a Tensor\n    :attr:`data` and just want to change its ``requires_grad`` flag, use\n    :meth:`~oneflow.Tensor.requires_grad_` or\n    :meth:`~oneflow.Tensor.detach` to avoid a copy.\n    If you have a numpy array and want to avoid a copy, use\n    :func:`oneflow.as_tensor`.\n\n.. 
A tensor of a specific data type can be constructed by passing a :class:`oneflow.dtype` and/or a :class:`oneflow.device` to a constructor or tensor creation op:\n\n::\n\n    >>> import oneflow\n    >>> oneflow.zeros([2, 4], dtype=oneflow.int32)\n    tensor([[ 0, 0, 0, 0],\n            [ 0, 0, 0, 0]], dtype=oneflow.int32)\n    >>> cuda0 = oneflow.device('cuda:0')\n    >>> oneflow.ones([2, 4], dtype=oneflow.float64, device=cuda0)\n    tensor([[ 1., 1., 1., 1.],\n            [ 1., 1., 1., 1.]], device='cuda:0', dtype=oneflow.float64)\n\nFor more information about building tensors, see :ref:`tensor-creation-ops`\n\nThe contents of a tensor can be accessed and modified using Python's indexing\nand slicing notation:\n\n::\n\n    >>> import oneflow\n    >>> x = oneflow.tensor([[1, 2, 3], [4, 5, 6]])\n    >>> print(x[1][2])\n    tensor(6, dtype=oneflow.int64)\n    >>> x[0][1] = 8\n    >>> print(x)\n    tensor([[1, 8, 3],\n            [4, 5, 6]], dtype=oneflow.int64)\n\nUse :meth:`oneflow.Tensor.item` to get a Python number from a tensor containing a\nsingle value:\n\n::\n\n    >>> import oneflow\n    >>> x = oneflow.tensor([[1]])\n    >>> x\n    tensor([[1]], dtype=oneflow.int64)\n    >>> x.item()\n    1\n    >>> x = oneflow.tensor(2.5)\n    >>> x\n    tensor(2.5000, dtype=oneflow.float32)\n    >>> x.item()\n    2.5\n\nFor more information about indexing, see :ref:`indexing-slicing-joining`\n\nA tensor can be created with :attr:`requires_grad=True` so that\n:mod:`oneflow.autograd` records operations on it for automatic differentiation.\n\n::\n\n    >>> import oneflow\n    >>> x = oneflow.tensor([[1., -1.], [1., 1.]], requires_grad=True)\n    >>> out = x.pow(2).sum()\n    >>> out.backward()\n    >>> x.grad\n    tensor([[ 2., -2.],\n            [ 2.,  2.]], dtype=oneflow.float32)\n\n.. note::\n   For more information on the :class:`oneflow.dtype`, :class:`oneflow.device`, and\n   :class:`oneflow.layout` attributes of a :class:`oneflow.Tensor`, see\n   :ref:`tensor-attributes-doc`.\n\n.. note::\n   Methods which mutate a tensor are marked with an underscore suffix.\n   For example, :func:`oneflow.FloatTensor.add_` performs the addition\n   in place and returns the modified tensor, while :func:`oneflow.FloatTensor.add`\n   computes the result in a new tensor.\n\n.. note::\n    To change an existing tensor's :class:`oneflow.device` and/or :class:`oneflow.dtype`, consider using the\n    :meth:`~oneflow.Tensor.to` method of the Tensor object.\n\n.. warning::\n   The current implementation of :class:`oneflow.Tensor` introduces memory overhead,\n   thus it might lead to unexpectedly high memory usage in applications with many tiny tensors.\n   If this is your case, consider using one large structure.\n\nTensor class reference\n----------------------\n\n.. class:: Tensor()\n\n   There are a few main ways to create a tensor, depending on your use case.\n\n   - To create a tensor with pre-existing data, use :func:`oneflow.tensor`.\n   - To create a tensor with a specific size, use ``oneflow.*`` tensor creation\n     ops (see :ref:`tensor-creation-ops`).\n   - To create a tensor with the same size (and similar types) as another tensor,\n     use ``oneflow.*_like`` tensor creation ops\n     (see :ref:`tensor-creation-ops`).\n\n.. currentmodule:: oneflow\n.. 
autosummary::\n    :toctree: generated\n    :nosignatures:\n    \n    Tensor.new_empty\n    Tensor.new_ones \n    Tensor.new_zeros\n    Tensor.new_full\n    Tensor.new_tensor\n    Tensor.is_cuda\n    Tensor.is_global\n    Tensor.device\n    Tensor.grad\n    Tensor.ndim\n    Tensor.abs\n    Tensor.acos\n    Tensor.acosh\n    Tensor.add\n    Tensor.add_\n    Tensor.addcdiv\n    Tensor.addcdiv_\n    Tensor.addcmul\n    Tensor.addcmul_\n    Tensor.addmm\n    Tensor.all\n    Tensor.amin\n    Tensor.amax\n    Tensor.any\n    Tensor.arccos\n    Tensor.arccosh\n    Tensor.arcsin\n    Tensor.arcsinh\n    Tensor.arctan\n    Tensor.arctanh\n    Tensor.argmax\n    Tensor.argmin\n    Tensor.argsort\n    Tensor.argwhere\n    Tensor.asin\n    Tensor.asinh\n    Tensor.atan\n    Tensor.atan2\n    Tensor.atanh\n    Tensor.backward\n    Tensor.bmm\n    Tensor.bool\n    Tensor.byte\n    Tensor.cast\n    Tensor.ceil\n    Tensor.ceil_\n    Tensor.chunk\n    Tensor.clamp\n    Tensor.clamp_\n    Tensor.clip\n    Tensor.clip_\n    Tensor.clone\n    Tensor.contiguous\n    Tensor.copy_\n    Tensor.cos\n    Tensor.cosh\n    Tensor.cpu\n    Tensor.cuda\n    Tensor.cumprod\n    Tensor.cumsum\n    Tensor.data\n    Tensor.dot\n    Tensor.detach\n    Tensor.placement\n    Tensor.sbp\n    Tensor.diag\n    Tensor.diagonal\n    Tensor.dim\n    Tensor.div\n    Tensor.div_\n    Tensor.double\n    Tensor.dtype \n    Tensor.digamma\n    Tensor.element_size\n    Tensor.eq\n    Tensor.equal\n    Tensor.erf\n    Tensor.erfc\n    Tensor.erfinv\n    Tensor.erfinv_\n    Tensor.exp\n    Tensor.exp2\n    Tensor.expand\n    Tensor.expand_as\n    Tensor.expm1\n    Tensor.fill_\n    Tensor.flatten\n    Tensor.flip\n    Tensor.float\n    Tensor.floor\n    Tensor.floor_\n    Tensor.floor_divide\n    Tensor.fmod\n    Tensor.gather\n    Tensor.ge\n    Tensor.get_device\n    Tensor.grad_fn\n    Tensor.gt\n    Tensor.gt_\n    Tensor.half\n    Tensor.in_top_k\n    Tensor.index_select\n    Tensor.index_add\n    Tensor.index_add_\n    Tensor.int\n    Tensor.is_contiguous\n    Tensor.is_floating_point\n    Tensor.is_lazy\n    Tensor.is_leaf\n    Tensor.isinf\n    Tensor.isnan\n    Tensor.item\n    Tensor.le\n    Tensor.lerp\n    Tensor.lerp_\n    Tensor.log\n    Tensor.log1p\n    Tensor.log2\n    Tensor.log10\n    Tensor.logical_and\n    Tensor.logical_or\n    Tensor.logical_not\n    Tensor.logical_xor\n    Tensor.long\n    Tensor.lt\n    Tensor.masked_fill\n    Tensor.masked_fill_\n    Tensor.masked_select\n    Tensor.matmul\n    Tensor.mm\n    Tensor.mv\n    Tensor.max\n    Tensor.maximum\n    Tensor.median\n    Tensor.mean\n    Tensor.min\n    Tensor.minimum\n    Tensor.mish\n    Tensor.mode\n    Tensor.mul\n    Tensor.mul_\n    Tensor.frac\n    Tensor.frac_\n    Tensor.nansum\n    Tensor.narrow\n    Tensor.ndimension\n    Tensor.ne\n    Tensor.neg\n    Tensor.negative\n    Tensor.nelement\n    Tensor.nonzero\n    Tensor.norm\n    Tensor.normal_\n    Tensor.numel\n    Tensor.numpy\n    Tensor.offload\n    Tensor.load\n    Tensor.is_offloaded\n    Tensor.permute\n    Tensor.pow\n    Tensor.prod\n    Tensor.quantile\n    Tensor.reciprocal\n    Tensor.register_hook\n    Tensor.relu\n    Tensor.repeat\n    Tensor.repeat_interleave\n    Tensor.requires_grad\n    Tensor.requires_grad_\n    Tensor.reshape\n    Tensor.reshape_as\n    Tensor.retain_grad\n    Tensor.roll\n    Tensor.round\n    Tensor.round_\n    Tensor.rsqrt\n    Tensor.selu\n    Tensor.shape\n    Tensor.sigmoid\n    Tensor.sign\n    Tensor.silu\n    Tensor.sin\n    Tensor.sin_\n    
Tensor.sinh\n    Tensor.size\n    Tensor.softmax\n    Tensor.softplus\n    Tensor.softsign\n    Tensor.sort\n    Tensor.split\n    Tensor.sqrt\n    Tensor.square\n    Tensor.squeeze\n    Tensor.squeeze_\n    Tensor.std\n    Tensor.storage_offset\n    Tensor.stride\n    Tensor.logsumexp\n    Tensor.sum\n    Tensor.swapaxes\n    Tensor.swapdims\n    Tensor.sub\n    Tensor.sub_\n    Tensor.tan\n    Tensor.tanh\n    Tensor.tile\n    Tensor.to\n    Tensor.local_to_global\n    Tensor.global_to_global\n    Tensor.to_global\n    Tensor.to_local\n    Tensor.to_consistent\n    Tensor.tolist\n    Tensor.topk\n    Tensor.transpose\n    Tensor.tril\n    Tensor.triu\n    Tensor.trunc\n    Tensor.type_as\n    Tensor.type\n    Tensor.t\n    Tensor.T\n    Tensor.unbind\n    Tensor.unfold\n    Tensor.uniform_\n    Tensor.unsqueeze\n    Tensor.unsqueeze_\n    Tensor.as_strided\n    Tensor.as_strided_\n    Tensor.var\n    Tensor.view\n    Tensor.view_as\n    Tensor.where\n    Tensor.zero_\n    Tensor.nms\n    Tensor.pin_memory\n    Tensor.is_pinned\n    Tensor.inverse\n    Tensor.cross\n    Tensor.scatter\n    Tensor.scatter_\n    Tensor.scatter_add\n    Tensor.scatter_add_\n    Tensor.bernoulli\n    Tensor.bernoulli_\n    Tensor.bincount\n    Tensor.isclose\n    Tensor.allclose\n    Tensor.broadcast_to\n    Tensor.unique\n    Tensor.bitwise_and\n    Tensor.bitwise_or\n    Tensor.bitwise_xor\n    Tensor.baddbmm\n"
  },
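To make the in-place note in `docs/source/tensor.rst` above concrete, here is a minimal sketch using only methods documented on that page (``add``/``add_``, ``item``, and ``Tensor.to``); the variable names are ours, and the commented CUDA line assumes a GPU build.

```python
import oneflow

x = oneflow.tensor([1.0, 2.0, 3.0])

# Out-of-place: `add` returns a new tensor; `x` is unchanged.
y = x.add(1.0)

# In-place: the underscore-suffixed `add_` mutates `x` and returns it.
x.add_(1.0)
assert (x == y).all().item()

# `Tensor.to` is the documented way to change dtype and/or device.
x_int = x.to(dtype=oneflow.int32)
# x_gpu = x.to(device=oneflow.device("cuda:0"))  # only with a CUDA build
print(x_int.dtype)  # oneflow.int32
```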
  {
    "path": "docs/source/tensor_attributes.rst",
    "content": ".. currentmodule:: oneflow\n\n.. _tensor-attributes-doc:\n\nTensor Attributes\n=============================================================\n\n.. The documentation is referenced from: https://pytorch.org/docs/1.10/tensor_attributes.html.\n\n\nEach local ``oneflow.Tensor`` has a :class:`oneflow.dtype`, :class:`oneflow.device`, and global ``oneflow.Tensor`` has a :class:`oneflow.dtype`, :class:`oneflow.placement`, :class:`oneflow.sbp`.\n\n.. contents:: oneflow\n    :depth: 2\n    :local:\n    :class: this-will-duplicate-information-and-it-is-still-useful-here\n    :backlinks: top\n\n\n.. _dtype-doc:\n\noneflow.dtype\n-----------------------\n\n.. class:: dtype\n\nA :class:`oneflow.dtype` is an object that represents the data type of a\n:class:`oneflow.Tensor`. Oneflow has eight different data types:\n\n======================================= =============================================== =============================== ==================================\nData type                               dtype                                           CPU tensor                      GPU tensor\n======================================= =============================================== =============================== ==================================\nBoolean                                 ``oneflow.bool``                                :class:`oneflow.BoolTensor`     :class:`oneflow.cuda.BoolTensor`\n8-bit integer (unsigned)                ``oneflow.uint8``                               :class:`oneflow.ByteTensor`     :class:`oneflow.cuda.ByteTensor`\n8-bit integer (signed)                  ``oneflow.int8``                                :class:`oneflow.CharTensor`     :class:`oneflow.cuda.CharTensor`\n64-bit floating point                   ``oneflow.float64`` or ``oneflow.double``       :class:`oneflow.DoubleTensor`   :class:`oneflow.cuda.DoubleTensor`\n32-bit floating point                   ``oneflow.float32`` or ``oneflow.float``        :class:`oneflow.FloatTensor`    :class:`oneflow.cuda.FloatTensor`\n16-bit floating point                   ``oneflow.float16`` or ``oneflow.half``         :class:`oneflow.HalfTensor`     :class:`oneflow.cuda.HalfTensor`\n32-bit integer (signed)                 ``oneflow.int32`` or ``oneflow.int``            :class:`oneflow.IntTensor`      :class:`oneflow.cuda.IntTensor`\n64-bit integer (signed)                 ``oneflow.int64`` or ``oneflow.long``           :class:`oneflow.LongTensor`     :class:`oneflow.cuda.LongTensor`\n======================================= =============================================== =============================== ==================================\n\n\nTo find out if a :class:`oneflow.dtype` is a floating point data type, the property :attr:`is_floating_point`\ncan be used, which returns ``True`` if the data type is a floating point data type.\n\n.. 
_type-promotion-doc:\n\nWhen the dtypes of inputs to an arithmetic operation (`add`, `sub`, `div`, `mul`) differ, we promote\nby finding the minimum dtype that satisfies the following rules:\n\n* If the type of a scalar operand is of a higher category than tensor operands\n  (where complex > floating > integral > boolean), we promote to a type with sufficient size to hold\n  all scalar operands of that category.\n* If a zero-dimension tensor operand has a higher category than dimensioned operands,\n  we promote to a type with sufficient size and category to hold all zero-dim tensor operands of\n  that category.\n* If there are no higher-category zero-dim operands, we promote to a type with sufficient size\n  and category to hold all dimensioned operands.\n\nA floating point scalar operand has dtype `oneflow.get_default_dtype()` and an integral\nnon-boolean scalar operand has dtype `oneflow.int64`. Unlike numpy, we do not inspect\nvalues when determining the minimum `dtypes` of an operand.  Quantized and complex types\nare not yet supported.\n\nPromotion Examples::\n\n    >>> float_tensor = oneflow.ones(1, dtype=oneflow.float)\n    >>> double_tensor = oneflow.ones(1, dtype=oneflow.double)\n    >>> int_tensor = oneflow.ones(1, dtype=oneflow.int)\n    >>> long_tensor = oneflow.ones(1, dtype=oneflow.long)\n    >>> uint_tensor = oneflow.ones(1, dtype=oneflow.uint8)\n    >>> double_tensor = oneflow.ones(1, dtype=oneflow.double)\n    >>> bool_tensor = oneflow.ones(1, dtype=oneflow.bool)\n    # zero-dim tensors\n    >>> long_zerodim = oneflow.tensor(1, dtype=oneflow.long)\n    >>> int_zerodim = oneflow.tensor(1, dtype=oneflow.int)\n\n    >>> a,b=oneflow.tensor(5),oneflow.tensor(5)\n    >>> oneflow.add(a, b).dtype\n    oneflow.int64\n    # 5 is an int64, but does not have higher category than int_tensor so is not considered.\n    >>> (int_tensor + 5).dtype\n    oneflow.int32\n    >>> (int_tensor + long_zerodim).dtype\n    oneflow.int64\n    >>> (long_tensor + int_tensor).dtype\n    oneflow.int64\n    >>> (bool_tensor + long_tensor).dtype\n    oneflow.int64\n    >>> (bool_tensor + uint_tensor).dtype\n    oneflow.uint8\n    >>> (float_tensor + double_tensor).dtype\n    oneflow.float64\n    >>> (bool_tensor + int_tensor).dtype\n    oneflow.int32\n    # Since long is a different kind than float, result dtype only needs to be large enough\n    # to hold the float.\n    >>> oneflow.add(long_tensor, float_tensor).dtype\n    oneflow.float32\n\nWhen the output tensor of an arithmetic operation is specified, we allow casting to its `dtype` except that:\n  * An integral output tensor cannot accept a floating point tensor.\n  * A boolean output tensor cannot accept a non-boolean tensor.\n  * A non-complex output tensor cannot accept a complex tensor\n\nCasting Examples::\n\n    # allowed:\n    >>> float_tensor *= float_tensor\n    >>> float_tensor *= int_tensor\n    >>> float_tensor *= uint_tensor\n    >>> float_tensor *= bool_tensor\n    >>> int_tensor *= uint_tensor\n\n    # disallowed (RuntimeError: result type can't be cast to the desired output type):\n    >>> float_tensor *= double_tensor\n    >>> int_tensor *= float_tensor\n    >>> int_tensor *= long_tensor\n    >>> uint_tensor *= int_tensor\n    >>> bool_tensor *= int_tensor\n    >>> bool_tensor *= uint_tensor\n\n.. _device-doc:\n\noneflow.device\n------------------------\n\n.. 
class:: device\n\nA :class:`oneflow.device` is an object representing the device on which a :class:`oneflow.Tensor` is\nor will be allocated.\n\nThe :class:`oneflow.device` contains a device type (``'cpu'`` or ``'cuda'``) and an optional device\nordinal for the device type. If the device ordinal is not present, this object will always represent\nthe current device for the device type, even after :func:`oneflow.cuda.set_device()` is called; e.g.,\na :class:`oneflow.Tensor` constructed with device ``'cuda'`` is equivalent to ``'cuda:X'`` where X is\nthe result of :func:`oneflow.cuda.current_device()`.\n\nA :class:`oneflow.Tensor`'s device can be accessed via the :attr:`Tensor.device` property.\n\nA :class:`oneflow.device` can be constructed via a string, or via a string and a device ordinal.\n\nVia a string:\n::\n\n    >>> oneflow.device('cuda:0')\n    device(type='cuda', index=0)\n\n    >>> oneflow.device('cpu')\n    device(type='cpu', index=0)\n\n    >>> oneflow.device('cuda')  # current cuda device\n    device(type='cuda', index=0)\n\nVia a string and device ordinal:\n\n::\n\n    >>> oneflow.device('cuda', 0)\n    device(type='cuda', index=0)\n\n    >>> oneflow.device('cpu', 0)\n    device(type='cpu', index=0)\n\n.. note::\n   The :class:`oneflow.device` argument in functions can generally be substituted with a string.\n   This allows for fast prototyping of code.\n\n   >>> # Example of a function that takes in a oneflow.device\n   >>> cuda1 = oneflow.device('cuda:1')\n   >>> oneflow.randn((2,3), device=cuda1)\n\n   >>> # You can substitute the oneflow.device with a string\n   >>> oneflow.randn((2,3), device='cuda:1')\n\n.. note::\n   For legacy reasons, a device can be constructed via a single device ordinal, which is treated\n   as a cuda device.  This matches :meth:`Tensor.get_device`, which returns an ordinal for cuda\n   tensors and is not supported for cpu tensors.\n\n   >>> oneflow.device(1)\n   device(type='cuda', index=1)\n\n.. note::\n   Methods which take a device will generally accept a (properly formatted) string\n   or (legacy) integer device ordinal, i.e. the following are all equivalent:\n\n   >>> oneflow.randn((2,3), device=oneflow.device('cuda:1'))\n   >>> oneflow.randn((2,3), device='cuda:1')\n   >>> oneflow.randn((2,3), device=1)  # legacy\n\noneflow.placement\n--------------------------------------------------------------\n.. autoclass:: oneflow.placement\n\noneflow.placement.all\n--------------------------------------------------------------\n.. autofunction:: oneflow.placement.all\n\noneflow.env.all_device_placement\n--------------------------------------------------------------\n.. autofunction:: oneflow.env.all_device_placement\n\noneflow.sbp.sbp\n--------------------------------------------------------------\n.. autoclass:: oneflow.sbp.sbp\n"
  },
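A runnable recap of the promotion rules in `docs/source/tensor_attributes.rst` above, restricted to cases that page itself states, so the expected dtypes in the comments come straight from its examples:

```python
import oneflow

int_tensor = oneflow.ones(1, dtype=oneflow.int)
long_tensor = oneflow.ones(1, dtype=oneflow.long)
float_tensor = oneflow.ones(1, dtype=oneflow.float)
double_tensor = oneflow.ones(1, dtype=oneflow.double)

# A Python int scalar is int64 but not of a higher category than
# int_tensor, so the tensor's dtype wins.
print((int_tensor + 5).dtype)                        # oneflow.int32

# Within one category, the larger dimensioned dtype wins.
print((long_tensor + int_tensor).dtype)              # oneflow.int64
print((float_tensor + double_tensor).dtype)          # oneflow.float64

# Across categories, floating point beats integral, and the result only
# needs to be large enough to hold the float operand.
print(oneflow.add(long_tensor, float_tensor).dtype)  # oneflow.float32
```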
  {
    "path": "docs/source/troubleshooting.md",
    "content": "# Troubleshooting\n\n- 'libunwind.h' not found\n\n  - You might add CMake argument `-DWITH_UNWIND=OFF`, or install libunwind in your system.\n\n- `CUDNN_STATUS_NOT_INITIALIZED`\n\n  - You might see error message like these:\n    ```\n    I0729 22:37:45.483937439   56788 ev_epoll_linux.c:82]        Use of signals is disabled. Epoll enginll not be used\n    E0729 22:37:45.515343 56788 version.cpp:82] Failed to get cuda runtime version: CUDA driver version nsufficient for CUDA runtime version\n    F0729 22:38:31.209002 56788 improver.cpp:535] Check failed: mem_size > 0 (-524288000 vs. 0)\n    ```\n    ```\n    F0723 19:05:56.194067 40970 cuda_util.cpp:82] Check failed: error == CUDNN_STATUS_SUCCESS (1 vs. 0) CUDNN_STATUS_NOT_INITIALIZED\n    ```\n  - Please upgrade to Nvidia Linux x86_64 driver. Version >= 440.33 is recommended.\n  - For more information, please refer to [CUDA compatibility documentation](https://docs.nvidia.com/deploy/cuda-compatibility/index.html).\n\n- Failed to compile `.cu` files\n\n  - Please refer to [CUDA System Requirements](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#system-requirements) . Make sure your linux distribution and libraries shipped with it meet the requirements.\n  - If you are using tools like conda, please make sure libraries you install doesn't shade the proper installation comes with linux distribution or package management like apt-get.\n  - Please build OneFlow with a newer version of CMake. You could download version 3.14 from here: [https://github.com/Kitware/CMake/releases/download/v3.14.0/cmake-3.14.0-Linux-x86_64.tar.gz](https://github.com/Kitware/CMake/releases/download/v3.14.0/cmake-3.14.0-Linux-x86_64.tar.gz)\n\n- How do I know what compilers and flags are used to compile OneFlow?\n\n  - run `make clean && make VERBOSE=1` to get exact compile commands with compiler path and flags\n\n- How to compile OneFlow with RDMA support?\n\n  - add cmake flag `-DBUILD_RDMA` to compile OneFlow\n\n- Which version of g++ CMake is using to build OneFlow?\n\n  - You should find a line like this in CMake output:\n\n    ```bash\n    -- CMAKE_CXX_COMPILER_VERSION: [YOUR G++ VERSION NUMBER]\n    ```\n\n- Failed to compile NCCL\n\n  - Try use less threads when compiling OneFlow third party. For instance, use\n\n    ```bash\n    cmake -DTHIRD_PARTY=ON .. && make\n    ```\n\n    instead of\n\n    ```bash\n    cmake -DTHIRD_PARTY=ON .. && make -j$(nproc) `\n    ```\n\n- `\"CUDA_VERSION\" \"VERSION_GREATER_EQUAL\" \"10.0\"`\n\n  - Please use a newer version of CMake\n  - Make sure cmake is correctly included in `PATH`\n\n- CUBLAS not found\n\n  - Usually it happens when using CUDA 10.1 or newer\n  - You should see error massage by CMake like this:\n\n    ```\n    cuda lib not found: /usr/local/miniconda3/envs/dl/lib/libcublas_static.a or\n    /usr/local/cuda/lib64/libcublas_static.a\n    ```\n\n  - Make sure `libcublas_static.a` is in one of the two directories.\n\n- When running OneFlow in gdb, there is no debug information for code location.\n\n  - add cmake flag `-DCMAKE_BUILD_TYPE=RELWITHDEBINFO` or `-DCMAKE_BUILD_TYPE=DEBUG` and recompile\n\n- `libof_ccobj.a: File truncated`\n\n  - You might see error message like this:\n\n    ```\n    /usr/bin/ar: libof_ccobj.a: File truncated\n    make[2]: *** [libof_ccobj.a] Error 1\n    make[2]: *** Deleting file `libof_ccobj.a'\n    make[1]: *** [CMakeFiles/of_ccobj.dir/all] Error 2\n    make: *** [all] Error 2\n    ```\n\n  - You should upgrade your GNU Binutils. 
Version 2.33.1 is recommended. If you are using conda, you could install it by running `conda install -c conda-forge binutils`.\n\n- Failed to compile because C++17 is enabled\n\n  - In some cases, the environment variable `CXXFLAGS` is not empty and contains `--std c++17`.\n  - Check if it is empty by running `echo $CXXFLAGS` and clear it with `unset CXXFLAGS`.\n  - If you are using conda, to make the change to the environment variable permanent, you can run:\n    ```bash\n    conda env config vars set CXXFLAGS=\"-fPIC\"\n    ```\n\n- cmake outputs the error `No CMAKE_ASM_NASM_COMPILER could be found.`\n\n  - Install `nasm`. For instance, run `sudo yum install nasm` if you are on CentOS.\n\n- `No module named 'google.protobuf'`\n\n  - You might see an error message like this:\n    ```\n    Scanning dependencies of target generate_api\n    ...\n        from google.protobuf import descriptor as _descriptor\n    ModuleNotFoundError: No module named 'google.protobuf'\n    CMakeFiles/generate_api.dir/build.make:57: recipe for target 'CMakeFiles/generate_api' failed\n    make[2]: *** [CMakeFiles/generate_api] Error 1\n    ```\n  - Install the development dependencies by running:\n    ```\n    pip3 install -r dev-requirements.txt\n    ```\n\n- Get gdb warning `ptrace: Operation not permitted.` and gdb command `bt` prints no backtrace\n\n  - You might get this warning when debugging OneFlow with gdb inside a docker container. Try adding these flags when launching your container:\n    ```\n    docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined\n    ```\n  - Please refer to https://stackoverflow.com/questions/19215177/how-to-solve-ptrace-operation-not-permitted-when-trying-to-attach-gdb-to-a-pro\n\n- It takes too long to download Python packages when running `make`\n  - If you are in China, you could run this to have pip download packages from a domestic mirror of PyPI:\n    ```\n    python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple\n    ```\n  - For more information on this, please refer to the [PyPI mirror help page](https://mirror.tuna.tsinghua.edu.cn/help/pypi/)\n"
  },
  {
    "path": "docs/source/type_info.rst",
    "content": ".. currentmodule:: oneflow\n\n.. _type-info-doc:\n\nType Info\n=========\n\n.. The documentation is referenced from: https://pytorch.org/docs/1.10/type_info.html.\n\nThe numerical properties of a :class:`oneflow.dtype` can be accessed through either the :class:`oneflow.finfo` or the :class:`oneflow.iinfo`.\n\n\n.. contents:: oneflow\n    :depth: 2\n    :local:\n    :class: this-will-duplicate-information-and-it-is-still-useful-here\n    :backlinks: top\n\noneflow.finfo\n-------------\n\n.. class:: oneflow.finfo\n\nA :class:`oneflow.finfo` is an object that represents the numerical properties of a floating point :class:`oneflow.dtype`, (i.e. ``oneflow.float32``, ``oneflow.float64`` and ``oneflow.float16``). This is similar to `numpy.finfo <https://numpy.org/doc/stable/reference/generated/numpy.finfo.html>`_.\n\nA :class:`oneflow.finfo` provides the following attributes:\n\n================== ======= ========================================================================== \nName               Type    Description                                                               \n================== ======= ========================================================================== \nbits               int     The number of bits occupied by the type.                                  \neps                float   The smallest representable number such that ``1.0 + eps != 1.0``.             \nmin                float   The largest representable number.                                         \nmax                float   The smallest representable number (typically ``-max``).                       \ntiny               float   The smallest positive normal number. See notes.\nresolution         float   The approximate decimal resolution of this type, i.e., ``10**-precision``.    \n================== ======= ========================================================================== \n\nFor example:\n\n.. code-block::\n\n    >>> import oneflow as flow\n    >>> flow.finfo()\n    finfo(resolution=1e-06, min=-3.40282e+38, max=3.40282e+38, eps=1.19209e-07, tiny=1.17549e-38, dtype=oneflow.float32, bits=32)\n    >>> flow.finfo(flow.float)\n    finfo(resolution=1e-06, min=-3.40282e+38, max=3.40282e+38, eps=1.19209e-07, tiny=1.17549e-38, dtype=oneflow.float32, bits=32)\n    >>> flow.finfo(flow.float16).bits\n    16\n    >>> flow.finfo(flow.float16).max\n    65504.0\n\noneflow.iinfo\n-------------\n\n.. class:: oneflow.iinfo\n\nA :class:`oneflow.iinfo` is an object that represents the numerical properties of a integer :class:`oneflow.dtype` (i.e. ``oneflow.uint8``, ``oneflow.int8``, ``oneflow.int16``, ``oneflow.int32``, and ``oneflow.int64``). This is similar to `numpy.iinfo <https://numpy.org/doc/stable/reference/generated/numpy.iinfo.html>`_.\n\nA :class:`oneflow.iinfo` provides the following attributes:\n\n================== ======= ========================================================================== \nName               Type    Description                                                               \n================== ======= ========================================================================== \nbits               int     The number of bits occupied by the type.                                  \nmin                float   The largest representable number.                                         \nmax                float   The smallest representable number.                       
\n================== ======= ========================================================================== \n\nFor example:\n\n.. code-block ::\n\n    >>> import oneflow as flow\n    >>> flow.iinfo(flow.int8)\n    iinfo(min=-128, max=127, dtype=oneflow.int8, bits=8)\n    >>> flow.iinfo(flow.int).max\n    2147483647\n    >>> flow.iinfo(flow.int).bits\n    32\n"
  },
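One practical use of the properties documented in `docs/source/type_info.rst` above is deriving numerical tolerances from the dtype instead of hard-coding them; a small sketch (the helper name and the factor of 8 are ours, not part of the API):

```python
import oneflow as flow

def allclose_for_dtype(a, b, factor=8):
    # Scale the absolute tolerance with the machine epsilon of the
    # input dtype rather than using one magic constant everywhere.
    tol = factor * flow.finfo(a.dtype).eps
    return ((a - b).abs() <= tol).all().item()

x = flow.ones(3, dtype=flow.float32)
print(allclose_for_dtype(x, x + flow.finfo(flow.float32).eps))  # True
print(flow.iinfo(flow.int32).max)  # 2147483647
```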
  {
    "path": "docs/source/utils.data.rst",
    "content": "oneflow.utils.data\n===================================\n\n.. The documentation is referenced from: \n   https://pytorch.org/docs/1.10/data.html\n\n.. automodule:: oneflow.utils.data\n\nAt the heart of Oneflow data loading utility is the :class:`oneflow.utils.data.DataLoader`\nclass.  It represents a Python iterable over a dataset, with support for\n\n* `map-style and iterable-style datasets <Dataset Types_>`_,\n\n* `customizing data loading order <Data Loading Order and Sampler_>`_,\n\n* `automatic batching <Loading Batched and Non-Batched Data_>`_,\n\n* `single- and multi-process data loading <Single- and Multi-process Data Loading_>`_,\n\n* `automatic memory pinning <Memory Pinning_>`_.\n\nThese options are configured by the constructor arguments of a\n:class:`~oneflow.utils.data.DataLoader`, which has signature::\n\n    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,\n               batch_sampler=None, num_workers=0, collate_fn=None,\n               pin_memory=False, drop_last=False, timeout=0,\n               worker_init_fn=None, *, prefetch_factor=2,\n               persistent_workers=False)\n\nThe sections below describe in details the effects and usages of these options.\n\nDataset Types\n-------------\n\nThe most important argument of :class:`~oneflow.utils.data.DataLoader`\nconstructor is :attr:`dataset`, which indicates a dataset object to load data\nfrom. Oneflow supports two different types of datasets:\n\n* `map-style datasets <Map-style datasets_>`_,\n\n* `iterable-style datasets <Iterable-style datasets_>`_.\n\nMap-style datasets\n^^^^^^^^^^^^^^^^^^\n\nA map-style dataset is one that implements the :meth:`__getitem__` and\n:meth:`__len__` protocols, and represents a map from (possibly non-integral)\nindices/keys to data samples.\n\nFor example, such a dataset, when accessed with ``dataset[idx]``, could read\nthe ``idx``-th image and its corresponding label from a folder on the disk.\n\nSee :class:`~oneflow.utils.data.Dataset` for more details.\n\nIterable-style datasets\n^^^^^^^^^^^^^^^^^^^^^^^\n\nAn iterable-style dataset is an instance of a subclass of :class:`~oneflow.utils.data.IterableDataset`\nthat implements the :meth:`__iter__` protocol, and represents an iterable over\ndata samples. This type of datasets is particularly suitable for cases where\nrandom reads are expensive or even improbable, and where the batch size depends\non the fetched data.\n\nFor example, such a dataset, when called ``iter(dataset)``, could return a\nstream of data reading from a database, a remote server, or even logs generated\nin real time.\n\nSee :class:`~oneflow.utils.data.IterableDataset` for more details.\n\n.. note:: When using an :class:`~oneflow.utils.data.IterableDataset` with\n          `multi-process data loading <Multi-process data loading_>`_. The same\n          dataset object is replicated on each worker process, and thus the\n          replicas must be configured differently to avoid duplicated data. See\n          :class:`~oneflow.utils.data.IterableDataset` documentations for how to\n          achieve this.\n\nData Loading Order and :class:`~oneflow.utils.data.Sampler`\n-----------------------------------------------------------\n\nFor `iterable-style datasets <Iterable-style datasets_>`_, data loading order\nis entirely controlled by the user-defined iterable. 
This allows easier\nimplementations of chunk-reading and dynamic batch size (e.g., by yielding a\nbatched sample at each time).\n\nThe rest of this section concerns the case with\n`map-style datasets <Map-style datasets_>`_. :class:`oneflow.utils.data.Sampler`\nclasses are used to specify the sequence of indices/keys used in data loading.\nThey represent iterable objects over the indices to datasets.  E.g., in the\ncommon case with stochastic gradient descent (SGD), a\n:class:`~oneflow.utils.data.Sampler` could randomly permute a list of indices\nand yield each one at a time, or yield a small number of them for mini-batch\nSGD.\n\nA sequential or shuffled sampler will be automatically constructed based on the :attr:`shuffle` argument to a :class:`~oneflow.utils.data.DataLoader`.\nAlternatively, users may use the :attr:`sampler` argument to specify a\ncustom :class:`~oneflow.utils.data.Sampler` object that at each time yields\nthe next index/key to fetch.\n\nA custom :class:`~oneflow.utils.data.Sampler` that yields a list of batch\nindices at a time can be passed as the :attr:`batch_sampler` argument.\nAutomatic batching can also be enabled via the :attr:`batch_size` and\n:attr:`drop_last` arguments. See\n`the next section <Loading Batched and Non-Batched Data_>`_ for more details\non this.\n\n.. note::\n  Neither :attr:`sampler` nor :attr:`batch_sampler` is compatible with\n  iterable-style datasets, since such datasets have no notion of a key or an\n  index.\n\nLoading Batched and Non-Batched Data\n------------------------------------\n\n:class:`~oneflow.utils.data.DataLoader` supports automatically collating\nindividual fetched data samples into batches via the arguments\n:attr:`batch_size`, :attr:`drop_last`, :attr:`batch_sampler`, and\n:attr:`collate_fn` (which has a default function).\n\n\nAutomatic batching (default)\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nThis is the most common case, and corresponds to fetching a minibatch of\ndata and collating them into batched samples, i.e., containing Tensors with\none dimension being the batch dimension (usually the first).\n\nWhen :attr:`batch_size` (default ``1``) is not ``None``, the data loader yields\nbatched samples instead of individual samples. The :attr:`batch_size` and\n:attr:`drop_last` arguments are used to specify how the data loader obtains\nbatches of dataset keys. For map-style datasets, users can alternatively\nspecify :attr:`batch_sampler`, which yields a list of keys at a time.\n\n.. note::\n  The :attr:`batch_size` and :attr:`drop_last` arguments essentially are used\n  to construct a :attr:`batch_sampler` from :attr:`sampler`. For map-style\n  datasets, the :attr:`sampler` is either provided by the user or constructed\n  based on the :attr:`shuffle` argument. For iterable-style datasets, the\n  :attr:`sampler` is a dummy infinite one. See\n  `this section <Data Loading Order and Sampler_>`_ for more details on\n  samplers.\n\n.. 
note::\n  When fetching from\n  `iterable-style datasets <Iterable-style datasets_>`_ with\n  `multi-processing <Multi-process data loading_>`_, the :attr:`drop_last`\n  argument drops the last non-full batch of each worker's dataset replica.\n\nAfter fetching a list of samples using the indices from the sampler, the function\npassed as the :attr:`collate_fn` argument is used to collate lists of samples\ninto batches.\n\nIn this case, loading from a map-style dataset is roughly equivalent to::\n\n    for indices in batch_sampler:\n        yield collate_fn([dataset[i] for i in indices])\n\nand loading from an iterable-style dataset is roughly equivalent to::\n\n    dataset_iter = iter(dataset)\n    for indices in batch_sampler:\n        yield collate_fn([next(dataset_iter) for _ in indices])\n\nA custom :attr:`collate_fn` can be used to customize collation, e.g., padding\nsequential data to the max length of a batch. See\n`this section <dataloader-collate_fn_>`_ for more about :attr:`collate_fn`.\n\nDisable automatic batching\n^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nIn certain cases, users may want to handle batching manually in dataset code,\nor simply load individual samples. For example, it could be cheaper to directly\nload batched data (e.g., bulk reads from a database or reading continuous\nchunks of memory), or the batch size is data dependent, or the program is\ndesigned to work on individual samples.  Under these scenarios, it's likely\nbetter to not use automatic batching (where :attr:`collate_fn` is used to\ncollate the samples), but let the data loader directly return each member of\nthe :attr:`dataset` object.\n\nWhen both :attr:`batch_size` and :attr:`batch_sampler` are ``None`` (the default\nvalue for :attr:`batch_sampler` is already ``None``), automatic batching is\ndisabled. Each sample obtained from the :attr:`dataset` is processed with the\nfunction passed as the :attr:`collate_fn` argument.\n\n**When automatic batching is disabled**, the default :attr:`collate_fn` simply\nconverts NumPy arrays into Oneflow Tensors, and keeps everything else untouched.\n\nIn this case, loading from a map-style dataset is roughly equivalent to::\n\n    for index in sampler:\n        yield collate_fn(dataset[index])\n\nand loading from an iterable-style dataset is roughly equivalent to::\n\n    for data in iter(dataset):\n        yield collate_fn(data)\n\nSee `this section <dataloader-collate_fn_>`_ for more about :attr:`collate_fn`.\n\n.. _dataloader-collate_fn:\n\nWorking with :attr:`collate_fn`\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nThe use of :attr:`collate_fn` is slightly different when automatic batching is\nenabled or disabled.\n\n**When automatic batching is disabled**, :attr:`collate_fn` is called with\neach individual data sample, and the output is yielded from the data loader\niterator. In this case, the default :attr:`collate_fn` simply converts NumPy\narrays into Oneflow tensors.\n\n**When automatic batching is enabled**, :attr:`collate_fn` is called with a list\nof data samples at each time. It is expected to collate the input samples into\na batch for yielding from the data loader iterator. 
The rest of this section\ndescribes the behavior of the default :attr:`collate_fn`\n(:func:`~oneflow.utils.data.default_collate`).\n\nFor instance, if each data sample consists of a 3-channel image and an integral\nclass label, i.e., each element of the dataset returns a tuple\n``(image, class_index)``, the default :attr:`collate_fn` collates a list of\nsuch tuples into a single tuple of a batched image tensor and a batched class\nlabel Tensor. In particular, the default :attr:`collate_fn` has the following\nproperties:\n\n* It always prepends a new dimension as the batch dimension.\n\n* It automatically converts NumPy arrays and Python numerical values into\n  Oneflow Tensors.\n\n* It preserves the data structure, e.g., if each sample is a dictionary, it\n  outputs a dictionary with the same set of keys but batched Tensors as values\n  (or lists if the values cannot be converted into Tensors). Same\n  for ``list`` s, ``tuple`` s, ``namedtuple`` s, etc.\n\nUsers may use a customized :attr:`collate_fn` to achieve custom batching, e.g.,\ncollating along a dimension other than the first, padding sequences of\nvarious lengths, or adding support for custom data types.\n\nIf you run into a situation where the outputs of :class:`~oneflow.utils.data.DataLoader`\nhave dimensions or types that are different from your expectation, you may\nwant to check your :attr:`collate_fn`.\n\nSingle- and Multi-process Data Loading\n--------------------------------------\n\nA :class:`~oneflow.utils.data.DataLoader` uses single-process data loading by\ndefault.\n\nWithin a Python process, the\n`Global Interpreter Lock (GIL) <https://wiki.python.org/moin/GlobalInterpreterLock>`_\nprevents fully parallelizing Python code across threads. To avoid blocking\ncomputation code with data loading, Oneflow provides an easy switch to perform\nmulti-process data loading by simply setting the argument :attr:`num_workers`\nto a positive integer.\n\nSingle-process data loading (default)\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nIn this mode, data fetching is done in the same process in which a\n:class:`~oneflow.utils.data.DataLoader` is initialized.  Therefore, data loading\nmay block computing.  However, this mode may be preferred when the resource(s) used\nfor sharing data among processes (e.g., shared memory, file descriptors) are\nlimited, or when the entire dataset is small and can be loaded entirely in\nmemory.  Additionally, single-process loading often shows more readable error\ntraces and thus is useful for debugging.\n\n\nMulti-process data loading\n^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nSetting the argument :attr:`num_workers` to a positive integer will\nturn on multi-process data loading with the specified number of loader worker\nprocesses.\n\n.. warning::\n   After several iterations, the loader worker processes will consume\n   the same amount of CPU memory as the parent process for all Python\n   objects in the parent process which are accessed from the worker\n   processes.  This can be problematic if the Dataset contains a lot of\n   data (e.g., you are loading a very large list of filenames at Dataset\n   construction time) and/or you are using a lot of workers (overall\n   memory usage is ``number of workers * size of parent process``).  The\n   simplest workaround is to replace Python objects with non-refcounted\n   representations such as Pandas, Numpy or PyArrow objects. 
\n\nIn this mode, each time an iterator of a :class:`~oneflow.utils.data.DataLoader`\nis created (e.g., when you call ``enumerate(dataloader)``), :attr:`num_workers`\nworker processes are created. At this point, the :attr:`dataset`,\n:attr:`collate_fn`, and :attr:`worker_init_fn` are passed to each\nworker, where they are used to initialize and fetch data. This means that\ndataset access, together with its internal IO and transforms\n(including :attr:`collate_fn`), runs in the worker process.\n\nFor map-style datasets, the main process generates the indices using\n:attr:`sampler` and sends them to the workers. So any shuffle randomization is\ndone in the main process which guides loading by assigning indices to load.\n\nFor iterable-style datasets, since each worker process gets a replica of the\n:attr:`dataset` object, naive multi-process loading will often result in\nduplicated data. Using :attr:`worker_init_fn`, users may configure each replica independently. (See\nthe :class:`~oneflow.utils.data.IterableDataset` documentation for how to achieve\nthis.) For similar reasons, in multi-process loading, the :attr:`drop_last`\nargument drops the last non-full batch of each worker's iterable-style dataset\nreplica.\n\nWorkers are shut down once the end of the iteration is reached, or when the\niterator becomes garbage collected.\n\n.. warning::\n  It is generally not recommended to return CUDA tensors in multi-process\n  loading because of many subtleties in using CUDA and sharing CUDA tensors in\n  multiprocessing. Instead, we recommend\n  using `automatic memory pinning <Memory Pinning_>`_ (i.e., setting\n  :attr:`pin_memory=True`), which enables fast data transfer to CUDA-enabled\n  GPUs.\n\nPlatform-specific behaviors\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nSince workers rely on Python :py:mod:`multiprocessing`, worker launch behavior is\ndifferent on Windows compared to Unix.\n\n* On Unix, :func:`fork()` is the default :py:mod:`multiprocessing` start method.\n  Using :func:`fork`, child workers typically can access the :attr:`dataset` and\n  Python argument functions directly through the cloned address space.\n\n* On Windows or macOS, :func:`spawn()` is the default :py:mod:`multiprocessing` start method.\n  Using :func:`spawn()`, another interpreter is launched which runs your main script,\n  followed by the internal worker function that receives the :attr:`dataset`,\n  :attr:`collate_fn` and other arguments through :py:mod:`pickle` serialization.\n\nThis separate serialization means that you should take two steps to ensure you\nare compatible with Windows while using multi-process data loading:\n\n- Wrap most of your main script's code within an ``if __name__ == '__main__':`` block,\n  to make sure it doesn't run again (most likely generating an error) when each worker\n  process is launched. You can place your dataset and :class:`~oneflow.utils.data.DataLoader`\n  instance creation logic here, as it doesn't need to be re-executed in workers.\n\n- Make sure that any custom :attr:`collate_fn`, :attr:`worker_init_fn`\n  or :attr:`dataset` code is declared as top level definitions, outside of the\n  ``__main__`` check. This ensures that they are available in worker processes.\n  (This is needed since functions are pickled as references only, not as ``bytecode``.)\n\n.. 
_data-loading-randomness:\n\nRandomness in multi-process data loading\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nBy default, each worker will have its Oneflow seed set to ``base_seed + worker_id``,\nwhere ``base_seed`` is a long generated by the main process using its RNG (thereby\nconsuming one RNG state) or a specified :attr:`generator`. However, seeds for other\nlibraries may be duplicated upon initializing workers, causing each worker to return\nidentical random numbers.\n\nIn :attr:`worker_init_fn`, you may access the Oneflow seed set for each worker\nwith :func:`oneflow.initial_seed()`, and use it to seed other libraries before data\nloading.\n\nMemory Pinning\n--------------\n\nHost to GPU copies are much faster when they originate from pinned (page-locked)\nmemory. See `cuda-memory-pinning` for more details on when and how to use\npinned memory generally.\n\nFor data loading, passing :attr:`pin_memory=True` to a\n:class:`~oneflow.utils.data.DataLoader` will automatically put the fetched data\nTensors in pinned memory, and thus enables faster data transfer to CUDA-enabled\nGPUs.\n\nThe default memory pinning logic only recognizes Tensors and maps and iterables\ncontaining Tensors.  By default, if the pinning logic sees a batch that is a\ncustom type (which will occur if you have a :attr:`collate_fn` that returns a\ncustom batch type), or if each element of your batch is a custom type, the\npinning logic will not recognize them, and it will return that batch (or those\nelements) without pinning the memory.  To enable memory pinning for custom\nbatch or data type(s), define a :meth:`pin_memory` method on your custom\ntype(s).\n\nSee the example below.\n\nExample::\n\n    import oneflow\n    from oneflow.utils.data import DataLoader, TensorDataset\n\n    class SimpleCustomBatch:\n        def __init__(self, data):\n            transposed_data = list(zip(*data))\n            self.inp = oneflow.stack(transposed_data[0], 0)\n            self.tgt = oneflow.stack(transposed_data[1], 0)\n\n        # custom memory pinning method on custom type\n        def pin_memory(self):\n            self.inp = self.inp.pin_memory()\n            self.tgt = self.tgt.pin_memory()\n            return self\n\n    def collate_wrapper(batch):\n        return SimpleCustomBatch(batch)\n\n    inps = oneflow.arange(10 * 5, dtype=oneflow.float32).view(10, 5)\n    tgts = oneflow.arange(10 * 5, dtype=oneflow.float32).view(10, 5)\n    dataset = TensorDataset(inps, tgts)\n\n    loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,\n                        pin_memory=True)\n\n    for batch_ndx, sample in enumerate(loader):\n        print(sample.inp.is_pinned())\n        print(sample.tgt.is_pinned())\n\n\n.. autoclass:: DataLoader\n.. autoclass:: Dataset\n.. autoclass:: IterableDataset\n.. autoclass:: TensorDataset\n.. autoclass:: ConcatDataset\n.. autoclass:: Subset\n.. autofunction:: oneflow.utils.data.random_split\n.. autoclass:: oneflow.utils.data.Sampler\n.. autoclass:: oneflow.utils.data.SequentialSampler\n.. autoclass:: oneflow.utils.data.RandomSampler\n.. autoclass:: oneflow.utils.data.SubsetRandomSampler\n.. autoclass:: oneflow.utils.data.BatchSampler\n.. autoclass:: oneflow.utils.data.distributed.DistributedSampler\n\n"
  },
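Tying the `docs/source/utils.data.rst` concepts above together, a hedged end-to-end sketch: a map-style dataset (implements ``__getitem__``/``__len__``) with variable-length samples, and a custom ``collate_fn`` that pads to the max length of each batch. The class and function names are illustrative, not part of the library.

```python
import numpy as np
import oneflow
from oneflow.utils.data import DataLoader, Dataset

class VarLenDataset(Dataset):
    """Map-style dataset: indices map to variable-length 1-D samples."""

    def __init__(self, lengths):
        self.samples = [np.arange(n, dtype=np.float32) for n in lengths]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def pad_collate(batch):
    # Pad each sample to the max length of this batch, then stack so a new
    # batch dimension is prepended, mirroring what the default collate does.
    max_len = max(len(s) for s in batch)
    padded = np.stack([np.pad(s, (0, max_len - len(s))) for s in batch])
    return oneflow.tensor(padded)

loader = DataLoader(VarLenDataset([3, 5, 2, 4]), batch_size=2,
                    collate_fn=pad_collate)
for batch in loader:
    print(batch.shape)  # one padded 2-D tensor per mini-batch
```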
  {
    "path": "docs/source/utils.global_view.rst",
    "content": "oneflow.utils.global_view\n======================================\nSome global view Ops\n--------------------------------------\n.. currentmodule:: oneflow.utils.global_view\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n    :template: classtemplate.rst\n\n    to_global\n    to_local\n    global_mode\n    current_global_mode\n\n"
  },
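A usage sketch for the `docs/source/utils.global_view.rst` ops above. It assumes, as the module's purpose suggests, that ``to_global``/``to_local`` accept nested structures (e.g. a dict of tensors) in addition to single tensors, and it runs on a single rank with a CPU placement:

```python
import oneflow as flow
from oneflow.utils.global_view import to_global, to_local

state = {"weight": flow.ones(2, 3), "bias": flow.zeros(3)}
placement = flow.placement("cpu", ranks=[0])

# Convert every local tensor in the structure to a global tensor.
global_state = to_global(state, placement=placement, sbp=flow.sbp.broadcast)
print(global_state["weight"].is_global)  # True

# And back to local tensors on this rank.
local_state = to_local(global_state)
print(local_state["weight"].is_global)  # False
```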
  {
    "path": "docs/source/utils.tensor.rst",
    "content": "oneflow.utils.tensor\n==========================================================\nSome torch-related Ops are suitable for tensor conversion.\n----------------------------------------------------------\n.. currentmodule:: oneflow.utils.tensor\n\n.. autosummary::\n    :toctree: generated\n    :nosignatures:\n\n    from_torch\n    to_torch\n\n"
  },
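A hedged usage sketch for the two ops in `docs/source/utils.tensor.rst` above; it assumes PyTorch is installed alongside OneFlow and that ``from_torch``/``to_torch`` each take the tensor as their only required argument:

```python
import torch
import oneflow
from oneflow.utils.tensor import from_torch, to_torch

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)

flow_t = from_torch(t)          # torch.Tensor -> oneflow.Tensor
print(flow_t.shape, flow_t.dtype)  # shape and dtype are preserved

torch_t = to_torch(flow_t)      # oneflow.Tensor -> torch.Tensor
print(torch.equal(t, torch_t))  # True if values round-trip unchanged
```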
  {
    "path": "external/CMakeLists.txt",
    "content": "set(ONETBB_URL\n    https://github.com/oneapi-src/oneTBB/archive/3db67b5ba2a81bd1288325c5847e09e13c46f4d7.zip)\nuse_mirror(VARIABLE ONETBB_URL URL ${ONETBB_URL})\nset(ONETBB_MD5 7545d4084baff17af73da2dae5ab8005)\n\nset(ROBIN_HOOD_HASHING_URL\n    https://github.com/martinus/robin-hood-hashing/archive/refs/tags/3.11.5.tar.gz)\nuse_mirror(VARIABLE ROBIN_HOOD_HASHING_URL URL ${ROBIN_HOOD_HASHING_URL})\nset(ROBIN_HOOD_HASHING_MD5 a78bd30a7582f25984f8592652836467)\n\nset(FMT_URL https://github.com/fmtlib/fmt/archive/fc07217d85e6dcec52878807d6bbd89a9d9156a5.zip)\nuse_mirror(VARIABLE FMT_URL URL ${FMT_URL})\nset(FMT_MD5 7d9bb2ececc9ede29cd35bdc42a7e22c)\n\nset(KINETO_URL\n    https://github.com/pytorch/kineto/archive/ff8dba20499a660650632952be76450bd70a52a6.zip)\nuse_mirror(VARIABLE KINETO_URL URL ${KINETO_URL})\nset(KINETO_MD5 f9b550591b3899fb267270c19484933f)\n\nset(EXTERNAL_TARGETS)\n\nif(WITH_TBB) # set(WITH_${threading_runtime_item} ON) in threading.cmake\n  add_subdirectory(onetbb)\n  list(APPEND EXTERNAL_TARGETS tbb)\nendif()\n\nadd_subdirectory(robin-hood-hashing)\nlist(APPEND EXTERNAL_TARGETS robin_hood)\n\nadd_subdirectory(fmt)\nlist(APPEND EXTERNAL_TARGETS fmt)\n\nadd_subdirectory(kineto)\nlist(APPEND EXTERNAL_TARGETS kineto)\n\nmark_targets_as_system(${EXTERNAL_TARGETS})\n\nset_property(GLOBAL PROPERTY EXTERNAL_TARGETS ${EXTERNAL_TARGETS})\n"
  },
  {
    "path": "external/fmt/CMakeLists.txt",
    "content": "include(FetchContent)\n\nset(FMT_INSTALL_DIR ${THIRD_PARTY_DIR}/fmt)\n\nFetchContent_Declare(fmt URL ${FMT_URL} URL_HASH MD5=${FMT_MD5})\n\nFetchContent_MakeAvailable(fmt)\n\n# Clang doesn't support __float128 when compiling CUDA\ntarget_compile_definitions(fmt PUBLIC FMT_USE_FLOAT128=0)\n\ninstall(\n  TARGETS fmt\n  EXPORT oneflow\n  LIBRARY DESTINATION ${FMT_INSTALL_DIR}/lib\n  ARCHIVE DESTINATION ${FMT_INSTALL_DIR}/lib)\ninstall(DIRECTORY ${fmt_SOURCE_DIR}/include DESTINATION ${FMT_INSTALL_DIR})\ninstall(DIRECTORY ${fmt_SOURCE_DIR}/include/ DESTINATION ${ONEFLOW_INCLUDE_DIR}\n        COMPONENT oneflow_py_include EXCLUDE_FROM_ALL)\n"
  },
  {
    "path": "external/kineto/CMakeLists.txt",
    "content": "include(FetchContent)\n\n# reference: https://github.com/PaddlePaddle/Paddle/blob/develop/cmake/cupti.cmake\n\nset(CUPTI_ROOT \"/usr\" CACHE PATH \"CUPTI ROOT\")\n\nset(CUDA_SOURCE_DIR ${CUDAToolkit_TARGET_DIR})\n\nfind_path(\n  CUPTI_INCLUDE_DIR cupti.h\n  PATHS ${CUPTI_ROOT}\n        ${CUPTI_ROOT}/include\n        $ENV{CUPTI_ROOT}\n        $ENV{CUPTI_ROOT}/include\n        ${CUDA_SOURCE_DIR}/extras/CUPTI/include\n        ${CUDA_SOURCE_DIR}/targets/x86_64-linux/include\n        ${CUDA_SOURCE_DIR}/targets/aarch64-linux/include\n  NO_DEFAULT_PATH)\n\nset(TARGET_ARCH \"x86_64\")\nif(NOT ${CMAKE_SYSTEM_PROCESSOR})\n  set(TARGET_ARCH ${CMAKE_SYSTEM_PROCESSOR})\nendif()\n\nlist(\n  APPEND\n  CUPTI_CHECK_LIBRARY_DIRS\n  ${CUPTI_ROOT}\n  ${CUPTI_ROOT}/lib64\n  ${CUPTI_ROOT}/lib\n  ${CUPTI_ROOT}/lib/${TARGET_ARCH}-linux-gnu\n  $ENV{CUPTI_ROOT}\n  $ENV{CUPTI_ROOT}/lib64\n  $ENV{CUPTI_ROOT}/lib\n  /usr/lib\n  ${CUDA_SOURCE_DIR}/targets/x86_64-linux/lib64\n  ${CUDA_SOURCE_DIR}/targets/x86_64-linux/lib\n  ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64\n  ${CUDA_SOURCE_DIR}/extras/CUPTI/lib)\n\nfind_library(\n  CUDA_cupti_LIBRARY\n  NAMES libcupti.so libcupti.dylib # libcupti_static.a\n  PATHS ${CUPTI_CHECK_LIBRARY_DIRS} ${CUPTI_INCLUDE_DIR}\n  NO_DEFAULT_PATH\n  DOC \"Path to cuPTI library.\")\n\nlist(APPEND CUDA_cupti_LIBRARY CUDA::cudart_static) # for undefined symbol: cudaGetDeviceCount∂\n\nFetchContent_Declare(\n  kineto\n  URL ${KINETO_URL}\n  URL_HASH MD5=${KINETO_MD5}\n  SOURCE_SUBDIR libkineto)\n\nFetchContent_MakeAvailable(kineto)\n\ntarget_include_directories(kineto PUBLIC $<BUILD_INTERFACE:${kineto_SOURCE_DIR}/libkineto/include>)\n"
  },
  {
    "path": "external/onetbb/CMakeLists.txt",
    "content": "find_package(Threads REQUIRED)\nset(ONETBB_INSTALL_DIR ${THIRD_PARTY_DIR}/tbb CACHE PATH \" \")\n\ninclude(FetchContent)\nFetchContent_Declare(tbb URL ${ONETBB_URL} URL_HASH MD5=${ONETBB_MD5})\nFetchContent_GetProperties(tbb)\n\nset(TBB_EXAMPLES OFF CACHE BOOL \"\")\nset(TBB_TEST OFF CACHE BOOL \"\")\nset(TBB_ENABLE_IPO OFF CACHE BOOL \"\")\nset(BUILD_SHARED_LIBS ON)\nset(CMAKE_POLICY_DEFAULT_CMP0079 NEW)\n\nFetchContent_MakeAvailable(tbb)\n\n# workaround compile error in GCC 12 or later\n# refer to https://github.com/Oneflow-Inc/oneflow/pull/10236\nif(CMAKE_CXX_COMPILER_ID STREQUAL \"GNU\" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12)\n  target_compile_options(tbb PRIVATE \"-Wno-error=stringop-overflow\")\nendif()\n\nset(TBBBIND_LIBRARY_NAME)\nif(HWLOC_VERSION)\n  if(HWLOC_VERSION VERSION_LESS 2)\n    set(TBBBIND_LIBRARY_NAME tbbbind)\n  elseif(HWLOC_VERSION VERSION_LESS 2.5)\n    set(TBBBIND_LIBRARY_NAME tbbbind_2_0)\n  else()\n    set(TBBBIND_LIBRARY_NAME tbbbind_2_5)\n  endif()\nendif()\n\nadd_custom_target(\n  install-tbb DEPENDS tbb tbbmalloc tbbmalloc_proxy ${TBBBIND_LIBRARY_NAME}\n  COMMAND \"${CMAKE_COMMAND}\" -DCMAKE_INSTALL_PREFIX=${ONETBB_INSTALL_DIR} -P\n          \"${tbb_BINARY_DIR}/cmake_install.cmake\")\n"
  },
  {
    "path": "external/robin-hood-hashing/CMakeLists.txt",
    "content": "include(FetchContent)\nFetchContent_Declare(\n        robin_hood_hashing\n        URL ${ROBIN_HOOD_HASHING_URL}\n        URL_HASH MD5=${ROBIN_HOOD_HASHING_MD5}\n)\nFetchContent_MakeAvailable(robin_hood_hashing)\n"
  },
  {
    "path": "oneflow/api/common/ir_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifdef WITH_MLIR\n\n#include \"oneflow/ir/include/OneFlow/Extension.h\"\n#include \"oneflow/ir/oneflow-extension/include/OneFlow/OneFlowRoundTrip.h\"\n#include <glog/logging.h>\n\nnamespace oneflow {\n\nREGISTER_JOB_PASS(\"IRRoundTripBeforeAD\", IRRoundTrip<kBeforeAD>);\nREGISTER_JOB_PASS(\"IRRoundTrip\", IRRoundTrip<kAfterAD>);\n\n}  // namespace oneflow\n\n#endif  // WITH_MLIR\n"
  },
  {
    "path": "oneflow/api/common/job_build_and_infer_ctx.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_API_COMMON_JOB_BUILD_AND_INFER_CTX_H_\n#define ONEFLOW_API_COMMON_JOB_BUILD_AND_INFER_CTX_H_\n\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx_mgr.h\"\n\nnamespace oneflow {\n\ninline Maybe<Job> GetCurrentJob() {\n  auto* job_ctx_mgr = Singleton<LazyJobBuildAndInferCtxMgr>::Get();\n  CHECK_NOTNULL_OR_RETURN(job_ctx_mgr);\n  auto* job_ctx =\n      JUST(job_ctx_mgr->FindJobBuildAndInferCtx(*JUST(job_ctx_mgr->GetCurrentJobName())));\n  CHECK_NOTNULL_OR_RETURN(job_ctx);\n  return job_ctx->job();\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_COMMON_JOB_BUILD_AND_INFER_CTX_H_\n"
  },
  {
    "path": "oneflow/api/common/sbp.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_API_COMMON_SBP_H_\n#define ONEFLOW_API_COMMON_SBP_H_\n\n#include \"oneflow/core/job/sbp_parallel.pb.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nnamespace api {\n\n// NOTE: The api inferface will print the whole name of sbp.\n\ninline Maybe<std::string> ApiSbpToString(Symbol<SbpParallel> sbp_sym) {\n  std::string sbp_str = \"oneflow.sbp.\";\n  if (sbp_sym->has_broadcast_parallel()) {\n    sbp_str += \"broadcast\";\n  } else if (sbp_sym->has_partial_sum_parallel()) {\n    sbp_str += \"partial_sum\";\n  } else if (sbp_sym->has_split_parallel()) {\n    sbp_str += \"split(dim=\" + std::to_string(sbp_sym->split_parallel().axis()) + \")\";\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  return sbp_str;\n}\n\ninline Maybe<std::string> ApiNdSbpToString(Symbol<NdSbp> nd_sbp) {\n  std::string str = \"(\";\n  for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) {\n    if (i > 0) { str += \", \"; }\n    str += *JUST(ApiSbpToString(SymbolOf(nd_sbp->sbp_parallel(i))));\n  }\n  if (nd_sbp->sbp_parallel_size() == 1) { str += \",\"; }\n  str += \")\";\n  return str;\n}\n\n}  // namespace api\n\n}  // namespace oneflow\n\n#endif  // !ONEFLOW_API_COMMON_SBP_H_\n"
  },
  {
    "path": "oneflow/api/common/variable_tensor_mgr.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_COMMON_VARIABLE_TENSOR_MGR_H_\n#define ONEFLOW_API_COMMON_VARIABLE_TENSOR_MGR_H_\n\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/variable_tensor_mgr.h\"\n\nnamespace oneflow {\n\ninline Maybe<void> FillVariableTensorMgr(\n    const std::vector<std::string>& variable_op_names,\n    const std::vector<std::shared_ptr<one::Tensor>>& variable_tensors) {\n  auto mgr = Singleton<VariableTensorMgr>::Get();\n  return mgr->Fill(variable_op_names, variable_tensors);\n}\n\ninline void ResetVariableTensorMgr() {\n  auto mgr = Singleton<VariableTensorMgr>::Get();\n  mgr->Reset();\n}\n\ninline std::tuple<std::vector<std::string>, std::vector<std::shared_ptr<one::Tensor>>>\nDumpVariableTensorMgr() {\n  auto mgr = Singleton<VariableTensorMgr>::Get();\n  return mgr->Dump();\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_COMMON_VARIABLE_TENSOR_MGR_H_\n"
  },
  {
    "path": "oneflow/api/cpp/api.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_API_CPP_API_H_\n#define ONEFLOW_API_CPP_API_H_\n\n#include \"env.h\"\n#include \"framework.h\"\n#include \"nn.h\"\n\n#endif  // !ONEFLOW_API_CPP_API_H_\n"
  },
  {
    "path": "oneflow/api/cpp/embedding/embedding.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/cpp/embedding/embedding.h\"\n#include \"oneflow/core/embedding/embedding_manager.h\"\n\nnamespace oneflow_api {\nnamespace embedding {\n\nstd::string CreateKeyValueStore(const std::string& key_value_store_options, int64_t local_rank_id,\n                                int64_t rank_id, int64_t world_size) {\n  oneflow::embedding::KeyValueStoreOptions options(key_value_store_options);\n#ifdef WITH_CUDA\n  oneflow::Singleton<oneflow::embedding::EmbeddingManager>::Get()->CreateKeyValueStore(\n      options, local_rank_id, rank_id, world_size);\n  return options.Name();\n#else\n  UNIMPLEMENTED() << \"OneEmbedding Only Support with CUDA\";\n#endif\n  return \"\";\n}\n\nvoid LoadSnapshot(const std::string& snapshot_name, const std::string& embedding_name,\n                  int64_t local_rank_id, int64_t rank_id) {\n#ifdef WITH_CUDA\n  oneflow::Singleton<oneflow::embedding::EmbeddingManager>::Get()->LoadSnapshot(\n      embedding_name, local_rank_id, rank_id, snapshot_name);\n#else\n  UNIMPLEMENTED() << \"OneEmbedding Only Support with CUDA\";\n#endif\n}\n\n}  // namespace embedding\n}  // namespace oneflow_api\n"
  },
  {
    "path": "oneflow/api/cpp/embedding/embedding.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_CPP_ONE_EMBEDDING_ONE_EMBEDDING_H_\n#define ONEFLOW_API_CPP_ONE_EMBEDDING_ONE_EMBEDDING_H_\n\n#include <string>\n\nnamespace oneflow_api {\nnamespace embedding {\n\n// CreateKeyValueStore returns embedding name in the options.\nstd::string CreateKeyValueStore(const std::string& key_value_store_options, int64_t local_rank_id,\n                                int64_t rank_id,\n                                int64_t world_size);  // key_value_store_options is\n                                                      // a serialized json string.\nvoid LoadSnapshot(const std::string& snapshot_name, const std::string& embedding_name,\n                  int64_t local_rank_id, int64_t rank_id);\n\n}  // namespace embedding\n}  // namespace oneflow_api\n\n#endif  // ONEFLOW_API_CPP_ONE_EMBEDDING_ONE_EMBEDDING_H_\n"
  },
  {
    "path": "oneflow/api/cpp/env.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <glog/logging.h>\n#include \"oneflow/api/cpp/env.h\"\n#include \"oneflow/api/cpp/env_impl.h\"\n#include \"oneflow/core/framework/shut_down_util.h\"\n#include \"oneflow/core/thread/thread_global_id.h\"\n\nnamespace oneflow_api {\nvoid initialize() {\n  if (of::Singleton<OneFlowEnv>::Get() == nullptr) { of::Singleton<OneFlowEnv>::New(); }\n  of::SetShuttingDown(false);\n}\n\nvoid release() {\n  if (of::Singleton<OneFlowEnv>::Get() != nullptr) { of::Singleton<OneFlowEnv>::Delete(); }\n  of::SetShuttingDown();\n}\n\n}  // namespace oneflow_api\n"
  },
  {
    "path": "oneflow/api/cpp/env.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_CPP_ENV_H_\n#define ONEFLOW_API_CPP_ENV_H_\n\nnamespace oneflow_api {\n\nvoid initialize();\nvoid release();\n\n}  // namespace oneflow_api\n\n#endif  // !ONEFLOW_API_CPP_ENV_H_\n"
  },
  {
    "path": "oneflow/api/cpp/env_impl.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <glog/logging.h>\n#include <sys/socket.h>\n#include <netinet/in.h>\n#include <arpa/inet.h>\n#include <cstddef>\n#include <cstdint>\n#include <cstdlib>\n#include <memory>\n#include <random>\n#include <type_traits>\n#include \"oneflow/api/cpp/env_impl.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/session_util.h\"\n#include \"oneflow/core/job/env.pb.h\"\n#include \"oneflow/core/job/cluster_instruction.h\"\n#include \"oneflow/core/control/ctrl_bootstrap.h\"\n#include \"oneflow/core/job/session.h\"\n#include \"oneflow/core/rpc/include/base.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n\nnamespace oneflow_api {\n\nnamespace of = oneflow;\n\nnamespace {  // for inltialize\n\ninline bool IsEnvInited() { return of::Singleton<of::EnvGlobalObjectsScope>::Get() != nullptr; }\n\nbool HasEnvVar(const std::string& key) {\n  const char* value = getenv(key.c_str());\n  return value != nullptr;\n}\n\nstd::string GetEnvVar(const std::string& key, const std::string& default_value) {\n  const char* value = getenv(key.c_str());\n  if (value == nullptr) { return default_value; }\n  return std::string(value);\n}\n\nint64_t GetEnvVar(const std::string& key, int64_t default_value) {\n  const char* value = getenv(key.c_str());\n  if (value == nullptr) { return default_value; }\n  return std::atoll(value);\n}\n\nint32_t FindFreePort(const std::string& addr) {\n#ifdef __linux__\n  int sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);\n  CHECK_GE(sock, 0) << \"fail to find a free port.\";\n  int optval = 1;\n  setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval));\n\n  std::mt19937 rng;\n  rng.seed(std::random_device()());\n  std::uniform_int_distribution<std::mt19937::result_type> dist(1, 1000);\n\n  int count = 0;\n  int num_attempts = 200;\n  do {\n    int port = 5000 + dist(rng);\n    struct sockaddr_in sockaddr {};\n    memset(&sockaddr, 0, sizeof(sockaddr));\n    sockaddr.sin_family = AF_INET;\n    sockaddr.sin_port = htons(port);\n    sockaddr.sin_addr.s_addr = inet_addr(addr.c_str());\n    int error = bind(sock, (struct sockaddr*)&sockaddr, sizeof(sockaddr));\n    if (error == 0) { return port; }\n    ++count;\n  } while (count < num_attempts);\n  CHECK_NE(count, num_attempts) << \"fail to find a free port.\";\n#endif  // __linux__\n  return -1;\n}\n\nvoid CompleteEnvProto(of::EnvProto& env_proto) {\n  auto bootstrap_conf = env_proto.mutable_ctrl_bootstrap_conf();\n  auto master_addr = bootstrap_conf->mutable_master_addr();\n  const std::string addr = GetEnvVar(\"MASTER_ADDR\", \"127.0.0.1\");\n  master_addr->set_host(addr);\n  master_addr->set_port(GetEnvVar(\"MASTER_PORT\", FindFreePort(addr)));\n  bootstrap_conf->set_world_size(GetEnvVar(\"WORLD_SIZE\", 1));\n  bootstrap_conf->set_rank(GetEnvVar(\"RANK\", 
0));\n\n  auto cpp_logging_conf = env_proto.mutable_cpp_logging_conf();\n  if (HasEnvVar(\"GLOG_log_dir\")) { cpp_logging_conf->set_log_dir(GetEnvVar(\"GLOG_log_dir\", \"\")); }\n  if (HasEnvVar(\"GLOG_logtostderr\")) {\n    cpp_logging_conf->set_logtostderr(GetEnvVar(\"GLOG_logtostderr\", -1));\n  }\n  if (HasEnvVar(\"GLOG_logbuflevel\")) {\n    cpp_logging_conf->set_logbuflevel(GetEnvVar(\"GLOG_logbuflevel\", -1));\n  }\n  if (HasEnvVar(\"GLOG_minloglevel\")) {\n    cpp_logging_conf->set_minloglevel(GetEnvVar(\"GLOG_minloglevel\", -1));\n  }\n}\n}  // namespace\n\nOneFlowEnv::OneFlowEnv() {\n  of::EnvProto env_proto;\n  CompleteEnvProto(env_proto);\n\n  env_ctx_ = std::make_shared<of::EnvGlobalObjectsScope>(env_proto);\n\n  of::ConfigProto config_proto;\n  config_proto.mutable_resource()->set_cpu_device_num(1);  // useless, will be set in TryInit\n  const int64_t session_id = of::NewSessionId();\n  config_proto.set_session_id(session_id);\n  CHECK(of::RegsterSessionId(session_id));\n  session_ctx_ = std::make_shared<of::MultiClientSessionContext>(env_ctx_);\n  CHECK_JUST(session_ctx_->TryInit(config_proto));\n}\n\nOneFlowEnv::~OneFlowEnv() {\n  session_ctx_.reset();\n  CHECK(of::ClearSessionId(CHECK_JUST(of::GetDefaultSessionId())));\n  env_ctx_.reset();\n}\n\n}  // namespace oneflow_api\n"
  },
  {
    "path": "oneflow/api/cpp/env_impl.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <memory>\n#include \"oneflow/core/framework/multi_client_session_context.h\"\n#include \"oneflow/core/job/env_global_objects_scope.h\"\n\n#ifndef ONEFLOW_API_CPP_ENV_IMPL_H_\n#define ONEFLOW_API_CPP_ENV_IMPL_H_\n\nnamespace oneflow_api {\nnamespace of = oneflow;\nclass OneFlowEnv {\n public:\n  OF_DISALLOW_COPY(OneFlowEnv);\n  OneFlowEnv();\n  ~OneFlowEnv();\n  std::shared_ptr<of::MultiClientSessionContext> GetSessionCtx() { return session_ctx_; }\n\n private:\n  std::shared_ptr<of::EnvGlobalObjectsScope> env_ctx_;\n  std::shared_ptr<of::MultiClientSessionContext> session_ctx_;\n};\n}  // namespace oneflow_api\n\n#endif  // ONEFLOW_API_CPP_ENV_IMPL_H_\n"
  },
  {
    "path": "oneflow/api/cpp/framework/device.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/api/cpp/framework/device.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/framework/device.h\"\n\nnamespace oneflow_api {\n\nnamespace of = oneflow;\n\nDevice::Device(const std::string& type_or_type_with_device_id)\n    : device_(std::make_shared<of::Symbol<of::Device>>(\n        of::Device::ParseAndNew(type_or_type_with_device_id).GetOrThrow())) {}\n\nDevice::Device(const std::string& type, int64_t device_id)\n    : device_(\n        std::make_shared<of::Symbol<of::Device>>(of::Device::New(type, device_id).GetOrThrow())) {}\n\nconst std::string& Device::type() const { return (*device_)->type(); }\n\nint64_t Device::device_id() const { return (*device_)->device_id(); }\n\nbool Device::operator==(const Device& rhs) const { return *device_ == *rhs.device_; }\nbool Device::operator!=(const Device& rhs) const { return *device_ != *rhs.device_; }\n\n}  // namespace oneflow_api\n"
  },
  {
    "path": "oneflow/api/cpp/framework/device.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_CPP_FRAMEWORK_DEVICE_H_\n#define ONEFLOW_API_CPP_FRAMEWORK_DEVICE_H_\n\n#include <string>\n#include <memory>\n\nnamespace oneflow {\n\nclass Device;\n\ntemplate<typename T>\nclass Symbol;\n\n}  // namespace oneflow\n\nnamespace oneflow_api {\n\nclass Device final {\n  friend class Tensor;\n  friend class Graph;\n\n public:\n  explicit Device(const std::string& type_or_type_with_device_id);\n  explicit Device(const std::string& type, int64_t device_id);\n  [[nodiscard]] const std::string& type() const;\n  [[nodiscard]] int64_t device_id() const;\n\n  [[nodiscard]] bool operator==(const Device& rhs) const;\n  [[nodiscard]] bool operator!=(const Device& rhs) const;\n\n private:\n  std::shared_ptr<oneflow::Symbol<oneflow::Device>> device_ = nullptr;\n};\n\n}  // namespace oneflow_api\n\n#endif  // !ONEFLOW_API_CPP_FRAMEWORK_DEVICE_H_\n"
  },
  {
    "path": "oneflow/api/cpp/framework/dtype.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/api/cpp/framework/dtype.h\"\n#include <map>\n\nnamespace oneflow_api {\n\nnamespace {\n\nstd::map<DType, int32_t> DTypeSize = {\n    {DType::kFloat, sizeof(float)},   {DType::kDouble, sizeof(double)},\n    {DType::kInt8, sizeof(int8_t)},   {DType::kInt32, sizeof(int32_t)},\n    {DType::kInt64, sizeof(int64_t)}, {DType::kBool, sizeof(bool)},\n};\n\n}\n\nint32_t GetDTypeSize(DType dtype) { return DTypeSize[dtype]; }\n\n}  // namespace oneflow_api\n"
  },
  {
    "path": "oneflow/api/cpp/framework/dtype.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_CPP_FRAMEWORK_DTYPE_H_\n#define ONEFLOW_API_CPP_FRAMEWORK_DTYPE_H_\n\n#include <cstdint>\n\nnamespace oneflow_api {\n\nenum class DType {\n  kInvalidDataType = 0,\n  kChar = 1,\n  kFloat = 2,\n  kDouble = 3,\n  kInt8 = 4,\n  kInt32 = 5,\n  kInt64 = 6,\n  kUInt8 = 7,\n  kOFRecord = 8,\n  kFloat16 = 9,\n  kTensorBuffer = 10,\n  kBFloat16 = 11,\n  kBool = 12,\n  kMaxDataType = 13\n};\n\n[[nodiscard]] int32_t GetDTypeSize(DType dtype);\n\n}  // namespace oneflow_api\n\n#endif  // ONEFLOW_API_CPP_FRAMEWORK_DTYPE_H_\n"
  },
  {
    "path": "oneflow/api/cpp/framework/graph.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"nlohmann/json.hpp\"\n#include \"oneflow/api/common/variable_tensor_mgr.h\"\n#include \"oneflow/api/cpp/env_impl.h\"\n#include \"oneflow/api/cpp/framework/device.h\"\n#include \"oneflow/api/cpp/framework/dtype.h\"\n#include \"oneflow/api/cpp/framework/graph.h\"\n#include \"oneflow/api/cpp/framework/ivalue.h\"\n#include \"oneflow/api/cpp/framework/shape.h\"\n#include \"oneflow/api/cpp/framework/tensor.h\"\n#include \"oneflow/api/cpp/embedding/embedding.h\"\n#include \"oneflow/api/common/job_build_and_infer_ctx.h\"\n#include \"oneflow/api/python/job_build/job_build_and_infer.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/common/hash_container.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/embedding/posix_file.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/multi_client_session_context.h\"\n#include \"oneflow/core/framework/nn_graph.h\"\n#include \"oneflow/core/framework/scope_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/tensor_util.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx_mgr.h\"\n#include \"oneflow/core/job/job_conf.pb.h\"\n#include \"oneflow/core/job/job_ir.h\"\n#include \"oneflow/core/job/job_set.pb.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/job/session.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/memory/memory_case_util.h\"\n#include \"oneflow/core/operator/interface_blob_conf.pb.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/register/logical_blob_id.pb.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n\nnamespace oneflow_api {\n\nnamespace of = oneflow;\n\nnamespace {\n\nclass CompileScope {\n public:\n  CompileScope(const of::JobConfigProto& job_config, const of::Device& device) {\n    of::JobConfigProto mut_job_config = job_config;\n    const std::shared_ptr<of::Scope> scope = CHECK_JUST(MakeScope(mut_job_config, device));\n    CHECK_JUST(of::ThreadLocalScopeStackPush(scope));\n\n    CHECK_JUST(of::JobBuildAndInferCtx_Open(mut_job_config.job_name()));\n    CHECK_JUST(CHECK_JUST(of::GetCurInferCtx())->SetJobConf(mut_job_config));\n  }\n\n  ~CompileScope() {\n    CHECK_JUST(of::JobBuildAndInferCtx_Close());\n    
CHECK_JUST(of::ThreadLocalScopeStackPop());\n  }\n\n private:\n  of::LazyMode::Guard lazy_mode_enabled_guard{true};\n};\n\nstd::shared_ptr<of::one::TensorTuple> ConvertToTensorTuple(\n    const std::vector<std::shared_ptr<of::one::Tensor>>& tensors) {\n  auto tensor_tuple = std::make_shared<of::one::TensorTuple>();\n  for (const auto& tensor : tensors) { tensor_tuple->emplace_back(tensor); }\n  return tensor_tuple;\n}\n\nstd::string GetDeviceTag(const Device& device) { return device.type(); }\n\ntemplate<class T1, class T2>\nconst std::pair<std::vector<T1>, std::vector<T2>> Unzip(const of::HashMap<T1, T2>& hash_map) {\n  std::vector<T1> vec1;\n  std::vector<T2> vec2;\n  for (const auto& entry : hash_map) {\n    vec1.emplace_back(entry.first);\n    vec2.emplace_back(entry.second);\n  }\n  return std::make_pair(vec1, vec2);\n}\n\nShape OfShapeToOfApiShape(const of::Shape& of_shape) {\n  std::vector<int64_t> dims(of_shape.dim_vec().begin(), of_shape.dim_vec().end());\n  return Shape(dims);\n}\n\n#ifdef __linux__\n\nvoid LoadOneEmbedding(const std::string& model_path, const Device& device) {\n  const std::string one_embedding_info_name(\"one_embedding_options.json\");\n  const std::string one_embedding_info_save_path(\n      oneflow::JoinPath(model_path, one_embedding_info_name));\n  if (oneflow::embedding::PosixFile::FileExists(one_embedding_info_save_path)) {\n    std::ifstream one_embedding_info_file(one_embedding_info_save_path);\n    auto one_embedding_json = nlohmann::json::parse(one_embedding_info_file);\n    for (auto& it : one_embedding_json[\"embedding\"]) {\n      const std::string snapshot_path = it[\"snapshot\"];\n      auto kv_options_json = it[\"kv_options\"];\n      std::string embedding_name = embedding::CreateKeyValueStore(kv_options_json.dump(),\n                                                                  /*local_rank_id=*/0,\n                                                                  /*rank_id=*/0,\n                                                                  /*world_size=*/1);\n      embedding::LoadSnapshot(snapshot_path, embedding_name, /*local_rank_id=*/0,\n                              /*rank_id=*/0);\n    }\n  }\n}\n\n#endif  // __linux__\n\n}  // namespace\n\nclass Graph::GraphImpl final {\n public:\n  explicit GraphImpl(const std::string& model_path, const Device& device = Device(\"cpu\"));\n\n  GraphImpl(const GraphImpl& graph) = delete;\n  GraphImpl(GraphImpl&& graph) = default;\n\n  ~GraphImpl();\n\n  GraphImpl& operator=(const GraphImpl& graph) = delete;\n  GraphImpl& operator=(GraphImpl&& graph) = default;\n\n  InputOutputInfos GetInputInfos();\n  InputOutputInfos GetOutputInfos();\n  std::vector<Tensor> Forward(const std::vector<Tensor>& inputs);\n  void set_batch_size(int batch_size) { batch_size_ = batch_size; }\n\n  of::Maybe<void> RegisterJobPass(\n      const std::function<std::string(const std::string& job)>& pass_fn);\n\n private:\n  of::Maybe<void> CollectInputOutputInfos();\n  of::Maybe<void> Compile(const std::vector<Tensor>& inputs);\n  of::Maybe<std::vector<Tensor>> Run(const std::vector<Tensor>& inputs) const;\n  of::Maybe<void> AddOp(of::OperatorConf op_conf);\n  of::Maybe<void> BuildGraph();\n  of::Maybe<void> LoadCheckpoint();\n  of::Maybe<void> RegisterTensors(const std::vector<Tensor>& inputs);\n  of::Maybe<of::Job> ApplyJobPasses(const of::Job& job);\n\n  std::shared_ptr<of::NNGraph> graph_ = nullptr;\n  std::string model_path_;\n  bool is_compiled_ = false;\n  int batch_size_ = 0;\n  Device device_;\n  of::Job job_;\n\n  
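// Metadata describing the job's input/output ops, plus the tensors bound to the NNGraph.\n  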
InputOutputInfos input_infos_;\n  InputOutputInfos output_infos_;\n  of::HashMap<std::string, std::shared_ptr<of::one::Tensor>> output_name_to_tensor_;\n  of::HashMap<std::string, std::shared_ptr<of::one::Tensor>> variable_op_name_to_tensor_;\n  std::shared_ptr<of::one::TensorTuple> output_tensor_tuple_;\n  std::shared_ptr<of::one::TensorTuple> parameter_tensor_tuple_;\n  std::vector<std::function<std::string(const std::string&)>> registered_job_passes_;\n};\n\nGraph::Graph(const std::string& model_path, const Device& device)\n    : graph_(std::make_unique<GraphImpl>(model_path, device)) {}\n\nGraph::~Graph() = default;\n\nGraph::Graph(Graph&& graph) noexcept : graph_(std::move(graph.graph_)) {}\n\nGraph& Graph::operator=(Graph&& graph) noexcept {\n  if (&graph == this) { return *this; }\n  graph_ = std::move(graph.graph_);\n  return *this;\n}\n\nInputOutputInfos Graph::GetInputInfos() { return graph_->GetInputInfos(); }\n\nInputOutputInfos Graph::GetOutputInfos() { return graph_->GetOutputInfos(); }\n\nvoid Graph::RegisterJobPass(const std::function<std::string(const std::string& job)>& pass_fn) {\n  CHECK_JUST(graph_->RegisterJobPass(pass_fn));\n}\n\nIValue Graph::Forward(const IValue& inputs) {\n  std::vector<Tensor> input_tensors;\n  if (inputs.IsNone()) {\n    // do nothing\n  } else if (inputs.IsTensor()) {\n    input_tensors.emplace_back(inputs.ToTensor());\n  } else if (inputs.IsTensorVector()) {\n    input_tensors = inputs.ToTensorVector();\n  } else {\n    LOG(WARNING) << \"Graph currently only supports types: Tensor/vector(Tensor)/None\";\n  }\n\n  std::vector<Tensor> output_tensors = graph_->Forward(input_tensors);\n  if (output_tensors.empty()) {\n    return IValue{};\n  } else if (output_tensors.size() == 1) {\n    return IValue(output_tensors.at(0));\n  } else {\n    return IValue(output_tensors);\n  }\n}\n\nvoid Graph::set_batch_size(int batch_size) { graph_->set_batch_size(batch_size); }\n\nGraph Graph::Load(const std::string& model_path, const Device& device) {\n#ifdef __linux__\n  LoadOneEmbedding(model_path, device);\n#endif  // __linux__\n  Graph graph(model_path, device);\n  return graph;\n}\n\nGraph::GraphImpl::GraphImpl(const std::string& model_path, const Device& device)\n    : model_path_(model_path), device_(device) {\n  CHECK_JUST(of::LoadJobFromIR(&job_, model_path + \"/model.mlir\"));\n  CHECK_JUST(CollectInputOutputInfos());\n  if (of::ParseBooleanFromEnv(\"ONEFLOW_SERVING_DEBUG\", false)) { LOG(ERROR) << job_.DebugString(); }\n  job_.mutable_job_conf()->mutable_predict_conf();\n  job_.mutable_job_conf()->set_job_name(job_.mutable_job_conf()->job_name() + of::NewUniqueId());\n}\n\nInputOutputInfos Graph::GraphImpl::GetInputInfos() { return input_infos_; }\n\nInputOutputInfos Graph::GraphImpl::GetOutputInfos() { return output_infos_; }\n\nof::Maybe<void> Graph::GraphImpl::CollectInputOutputInfos() {\n  const of::OpGraph op_graph(job_);\n  size_t input_order = 0;\n  size_t output_order = 0;\n  op_graph.TopoForEachNode([&](const of::OpNode* node) -> of::Maybe<void> {\n    const of::OperatorConf& op_conf = node->op().op_conf();\n    if (op_conf.has_input_conf()) {\n      of::InterfaceBlobConf blob_conf = op_conf.input_conf().blob_conf();\n      input_infos_[op_conf.name()] =\n          InputOutputAttribute(static_cast<DType>(blob_conf.data_type()),\n                               OfShapeToOfApiShape(of::Shape(blob_conf.shape())), input_order);\n      input_order += 1;\n    } else if (op_conf.has_output_conf()) {\n      of::InterfaceBlobConf blob_conf = 
op_conf.output_conf().blob_conf();\n      output_infos_[op_conf.name()] =\n          InputOutputAttribute(static_cast<DType>(blob_conf.data_type()),\n                               OfShapeToOfApiShape(of::Shape(blob_conf.shape())), output_order);\n      output_order += 1;\n    }\n    return of::Maybe<void>::Ok();\n  });\n  return of::Maybe<void>::Ok();\n}\n\nof::Maybe<void> Graph::GraphImpl::RegisterJobPass(\n    const std::function<std::string(const std::string& job)>& pass_fn) {\n  if (is_compiled_) {\n    return of::Error::RuntimeError() << \"job pass should be registered before compile and forward\";\n  }\n  registered_job_passes_.emplace_back(pass_fn);\n  return of::Maybe<void>::Ok();\n}\n\nof::Maybe<of::Job> Graph::GraphImpl::ApplyJobPasses(const of::Job& job) {\n  auto current_job = std::make_shared<of::Job>(job);\n  for (const auto& pass_fn : registered_job_passes_) {\n    std::string new_serialized_original_job = pass_fn(current_job->SerializeAsString());\n    of::Job new_job;\n    if (!new_job.ParseFromString(new_serialized_original_job)) {\n      return of::Error::RuntimeError() << \"invalid serialized job after pass applied\";\n    }\n    current_job->Swap(&new_job);\n  }\n  return current_job;\n}\n\nstd::vector<Tensor> Graph::GraphImpl::Forward(const std::vector<Tensor>& inputs) {\n  if (!is_compiled_) {\n    static std::mutex mtx;\n    std::lock_guard<std::mutex> lock(mtx);\n    // Re-check under the lock so that concurrent first calls do not compile the graph twice.\n    if (!is_compiled_) {\n      Compile(inputs).GetOrThrow();\n      is_compiled_ = true;\n    }\n  }\n  return Run(inputs).GetOrThrow();\n}\n\nof::Maybe<void> Graph::GraphImpl::Compile(const std::vector<Tensor>& inputs) {\n  JUST(BuildGraph());\n  JUST(RegisterTensors(inputs));\n  JUST(graph_->CompileAndInitRuntime());\n  return of::Maybe<void>::Ok();\n}\n\nof::Maybe<std::vector<Tensor>> Graph::GraphImpl::Run(const std::vector<Tensor>& inputs) const {\n  const auto input_tensor_tuple = std::make_shared<of::one::TensorTuple>();\n  for (const auto& tensor : inputs) { input_tensor_tuple->emplace_back(tensor.tensor_); }\n\n  JUST(of::RunLazyNNGraph(*input_tensor_tuple, *output_tensor_tuple_, graph_));\n  JUST(of::SoftSyncNNGraphBuffers(*output_tensor_tuple_, graph_));\n\n  std::vector<Tensor> outputs;\n  for (const auto& tensor : *output_tensor_tuple_) { outputs.emplace_back(Tensor(tensor)); }\n  return outputs;\n}\n\nof::Maybe<void> Graph::GraphImpl::AddOp(of::OperatorConf op_conf) {\n  {\n    const std::shared_ptr<of::Scope> scope = JUST(of::GetCurrentScope());\n    op_conf.set_scope_symbol_id(scope->symbol_id().value_or(0));\n  }\n  op_conf.set_device_tag(GetDeviceTag(device_));\n  if (batch_size_ > 0 && op_conf.has_input_conf()) {\n    op_conf.mutable_input_conf()->mutable_blob_conf()->mutable_shape()->mutable_dim()->Set(\n        0, batch_size_);\n  }\n  auto* ctx = JUST(of::GetCurInferCtx());\n  JUST(ctx->AddAndInferGlobalOp(op_conf));\n  return of::Maybe<void>::Ok();\n}\n\nof::Maybe<void> Graph::GraphImpl::BuildGraph() {\n  CompileScope build_graph_scope(job_.job_conf(), *device_.device_->shared_from_symbol());\n  {\n    const of::OpGraph op_graph(job_);\n    op_graph.TopoForEachNode([&](const of::OpNode* node) -> of::Maybe<void> {\n      const of::OperatorConf& op_conf = node->op().op_conf();\n      JUST(AddOp(op_conf));\n      if (op_conf.has_variable_conf()) {\n        const of::LazyMode::Guard lazy_mode_disabled_guard{false};\n        const of::VariableOpConf& variable_conf = op_conf.variable_conf();\n        variable_op_name_to_tensor_[op_conf.name()] = JUST(of::one::functional::Empty(\n            
of::Shape(variable_conf.shape()),\n            JUST(of::DType::Get(static_cast<of::DataType>(variable_conf.data_type()))),\n            *device_.device_, /*requires_grad=*/false, /*pin_memory=*/false));\n      }\n      return of::Maybe<void>::Ok();\n    });\n  }\n  JUST(LoadCheckpoint());\n  JUST(of::CurJobBuildAndInferCtx_Complete());\n  std::shared_ptr<of::Job> complete_job = JUST(of::GetCurrentJob());\n  int64_t job_id = JUST(of::JobBuildAndInferCtx_GetCurrentJobId());\n  CHECK(of::Singleton<OneFlowEnv>::Get() != nullptr);\n\n  // apply custom job passes\n  complete_job = JUST(ApplyJobPasses(*complete_job));\n  graph_ = std::make_shared<of::NNGraph>(job_.job_conf().job_name(), *complete_job, job_id,\n                                         of::Singleton<OneFlowEnv>::Get()->GetSessionCtx());\n  {\n    const of::OpGraph complete_graph(*complete_job);\n    complete_graph.TopoForEachNode([&](const of::OpNode* node) -> of::Maybe<void> {\n      const of::LazyMode::Guard lazy_mode_disabled_guard{false};\n      const of::OperatorConf& op_conf = node->op().op_conf();\n      if (op_conf.has_output_conf()) {\n        of::InterfaceBlobConf blob_conf = op_conf.output_conf().blob_conf();\n        if (batch_size_ > 0) {\n          const std::string input_lbi_str = op_conf.output_conf().in();\n          const of::LogicalBlobId input_lbi = of::GenLogicalBlobId(input_lbi_str);\n          int64_t batch_size = node->LogicalBlobDesc4Lbi(input_lbi).shape().At(0);\n          blob_conf.mutable_shape()->set_dim(0, batch_size);\n        }\n        output_name_to_tensor_[op_conf.name()] = JUST(of::one::functional::Empty(\n            of::Shape(blob_conf.shape()),\n            JUST(of::DType::Get(static_cast<of::DataType>(blob_conf.data_type()))),\n            *device_.device_, /*requires_grad=*/false, /*pin_memory=*/false));\n      }\n      return of::Maybe<void>::Ok();\n    });\n  }\n  return of::Maybe<void>::Ok();\n}\n\nof::Maybe<void> Graph::GraphImpl::LoadCheckpoint() {\n  for (const auto& variable_op_name_and_tensor : variable_op_name_to_tensor_) {\n    const auto& variable_op_name = variable_op_name_and_tensor.first;\n    const auto& variable_tensor = variable_op_name_and_tensor.second;\n    const std::string variable_filename = model_path_ + \"/\" + variable_op_name + \"/out\";\n    const std::string buffer = [&]() {\n      std::ifstream variable_file(variable_filename, std::ios::binary);\n      CHECK(variable_file.is_open());\n      std::stringstream ss;\n      ss << variable_file.rdbuf();\n      return ss.str();\n    }();\n    const auto& callback = [&](of::ep::Stream* stream,\n                               const std::shared_ptr<of::vm::EagerBlobObject>& eager_blob_object) {\n      of::AutoMemcpy(stream, eager_blob_object->mut_dptr(), buffer.data(),\n                     variable_tensor->shape()->elem_cnt()\n                         * of::GetSizeOfDataType(variable_tensor->dtype()->data_type()),\n                     eager_blob_object->mem_case(), of::memory::MakeHostMemCase());\n    };\n    JUST(of::one::SyncAccessTensorWithTimeOut(variable_tensor, callback, \"mut\"));\n  }\n  const auto& pair = Unzip(variable_op_name_to_tensor_);\n  JUST(of::FillVariableTensorMgr(pair.first, pair.second));\n  return of::Maybe<void>::Ok();\n}\n\nof::Maybe<void> Graph::GraphImpl::RegisterTensors(const std::vector<Tensor>& inputs) {\n  {\n    std::vector<std::string> input_op_names(inputs.size());\n    std::vector<std::shared_ptr<of::one::Tensor>> input_tensors(inputs.size());\n    for (const auto& input_info : 
input_infos_) {\n      size_t index = input_info.second.input_output_index_;\n      input_op_names[index] = input_info.first;\n      input_tensors[index] = inputs.at(index).tensor_;\n    }\n    JUST(graph_->RegisterInputOpNamesAndTensors(input_op_names, input_tensors));\n  }\n  {\n    const auto& pair = Unzip(output_name_to_tensor_);\n    const std::vector<std::string>& output_op_names = pair.first;\n    const std::vector<std::shared_ptr<of::one::Tensor>>& output_tensors = pair.second;\n    JUST(graph_->RegisterOutputOpNamesAndTensors(output_op_names, output_tensors));\n    output_tensor_tuple_ = ConvertToTensorTuple(output_tensors);\n  }\n  {\n    const auto& t = of::DumpVariableTensorMgr();\n    const std::vector<std::string>& variable_op_names = std::get<0>(t);\n    const std::vector<std::shared_ptr<of::one::Tensor>>& variable_tensors = std::get<1>(t);\n    JUST(graph_->RegisterVariableOpNamesAndTensors(variable_op_names, variable_tensors));\n    parameter_tensor_tuple_ = ConvertToTensorTuple(variable_tensors);\n  }\n  return of::Maybe<void>::Ok();\n}\n\nGraph::GraphImpl::~GraphImpl() { of::vm::ClusterSync().GetOrThrow(); }\n\n}  // namespace oneflow_api\n"
  },
  {
    "path": "oneflow/api/cpp/framework/graph.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_API_CPP_GRAPH_H_\n#define ONEFLOW_API_CPP_GRAPH_H_\n\n#include \"dtype.h\"\n#include \"shape.h\"\n#include \"device.h\"\n#include \"ivalue.h\"\n#include \"tensor.h\"\n#include <cstddef>\n#include <string>\n#include <functional>\n#include <unordered_map>\n\nnamespace oneflow {\n\nclass NNGraph;\n\n}  // namespace oneflow\n\nnamespace oneflow_api {\n\nstruct InputOutputAttribute {\n  InputOutputAttribute(DType datatype, const Shape& input_output_shape, size_t input_output_index)\n      : datatype_(datatype),\n        input_output_shape_(input_output_shape),\n        input_output_index_(input_output_index) {}\n  InputOutputAttribute() : InputOutputAttribute(DType::kInvalidDataType, Shape(), 0) {}\n\n  DType datatype_;\n  Shape input_output_shape_;\n  size_t input_output_index_;\n};\n\nusing InputOutputInfos = std::unordered_map<std::string, InputOutputAttribute>;\n\nclass Graph {\n public:\n  explicit Graph(const std::string& model_path, const Device& device = Device(\"cpu\"));\n  ~Graph();\n\n  Graph(const Graph& graph) = delete;\n  Graph(Graph&& graph) noexcept;\n\n  Graph& operator=(const Graph& graph) = delete;\n  Graph& operator=(Graph&& graph) noexcept;\n\n  InputOutputInfos GetInputInfos();\n  InputOutputInfos GetOutputInfos();\n  IValue Forward(const IValue& inputs);\n  void set_batch_size(int batch_size);\n\n  void RegisterJobPass(const std::function<std::string(const std::string& job)>& pass_fn);\n\n  static Graph Load(const std::string& model_path, const Device& device = Device(\"cpu\"));\n\n private:\n  class GraphImpl;\n  std::unique_ptr<GraphImpl> graph_;\n};\n\n}  // namespace oneflow_api\n\n#endif  // ONEFLOW_API_CPP_GRAPH_H_\n"
  },
  {
    "path": "oneflow/api/cpp/framework/ivalue.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/cpp/framework/ivalue.h\"\n#include <glog/logging.h>\n\nnamespace oneflow_api {\n\nnamespace of = oneflow;\n\nstd::ostream& operator<<(std::ostream& os, const IValue::Tag& tag) {\n  os << static_cast<int>(tag);\n  return os;\n}\n\nint64_t IValue::ToInt() const {\n  CHECK_EQ(tag_, Tag::kInt) << \"Current value is not an int.\";\n  return payload_.i.v_int;\n}\n\ndouble IValue::ToDouble() const {\n  CHECK_EQ(tag_, Tag::kDouble) << \"Current value is not a double.\";\n  return payload_.i.v_double;\n}\n\nbool IValue::ToBool() const {\n  CHECK_EQ(tag_, Tag::kBool) << \"Current value is not a bool.\";\n  return payload_.i.v_bool;\n}\n\nconst Tensor& IValue::ToTensor() const {\n  CHECK_EQ(tag_, Tag::kTensor) << \"Current value is not a tensor.\";\n  return payload_.v_tensor;\n}\n\nconst std::vector<Tensor>& IValue::ToTensorVector() const {\n  CHECK_EQ(tag_, Tag::kTensorVector) << \"Current value is not a vector of tensor.\";\n  return payload_.v_tensor_vector;\n}\n\n}  // namespace oneflow_api\n"
  },
  {
    "path": "oneflow/api/cpp/framework/ivalue.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_CPP_FRAMEWORK_IVALUE_H_\n#define ONEFLOW_API_CPP_FRAMEWORK_IVALUE_H_\n\n#include <cstdint>\n#include <memory>\n#include <vector>\n#include \"tensor.h\"\n\nnamespace oneflow_api {\n\nclass IValue {\n public:\n  IValue() : tag_(IValue::Tag::kNone) {}\n  explicit IValue(int value) : tag_(IValue::Tag::kInt) { payload_.i.v_int = value; }\n\n  explicit IValue(int64_t value) : tag_(IValue::Tag::kInt) { payload_.i.v_int = value; }\n\n  explicit IValue(double value) : tag_(IValue::Tag::kDouble) { payload_.i.v_double = value; }\n\n  explicit IValue(bool value) : tag_(IValue::Tag::kBool) { payload_.i.v_bool = value; }\n\n  IValue(const Tensor& value) : tag_(IValue::Tag::kTensor) {  // NOLINT\n    new (&payload_.v_tensor) Tensor(value);\n  }\n\n  IValue(Tensor&& value) : tag_(IValue::Tag::kTensor) {  // NOLINT\n    new (&payload_.v_tensor) Tensor(std::move(value));\n  }\n\n  IValue(const std::vector<Tensor>& value) : tag_(IValue::Tag::kTensorVector) {  // NOLINT\n    new (&payload_.v_tensor_vector) std::vector<Tensor>(value);\n  }\n\n  IValue(std::vector<Tensor>&& value) : tag_(IValue::Tag::kTensorVector) {  // NOLINT\n    new (&payload_.v_tensor_vector) std::vector<Tensor>(std::move(value));\n  }\n\n  IValue(const IValue& value) : tag_(value.tag_) {\n    if (IsTensor()) {\n      new (&payload_.v_tensor) Tensor(value.payload_.v_tensor);\n    } else if (IsTensorVector()) {\n      new (&payload_.v_tensor_vector) std::vector<Tensor>(value.payload_.v_tensor_vector);\n    } else {\n      payload_.i = value.payload_.i;\n    }\n  }\n\n  IValue(IValue&& value) noexcept : tag_(value.tag_) { MoveFrom(std::move(value)); }\n\n  IValue& operator=(const IValue& value) {\n    if (&value == this) { return *this; }\n    this->tag_ = value.tag_;\n    *this = IValue(value);\n    return *this;\n  }\n\n  IValue& operator=(IValue&& value) noexcept {\n    if (&value == this) { return *this; }\n    Destory();\n    this->tag_ = value.tag_;\n    MoveFrom(std::move(value));\n    return *this;\n  }\n\n  ~IValue() { Destory(); }\n\n  bool IsNone() const { return tag_ == Tag::kNone; }\n\n  bool IsInt() const { return tag_ == Tag::kInt; }\n\n  bool IsDouble() const { return tag_ == Tag::kDouble; }\n\n  bool IsBool() const { return tag_ == Tag::kBool; }\n\n  bool IsTensor() const { return tag_ == Tag::kTensor; }\n\n  bool IsTensorVector() const { return tag_ == Tag::kTensorVector; }\n\n  int64_t ToInt() const;\n  double ToDouble() const;\n  bool ToBool() const;\n  const Tensor& ToTensor() const;\n  const std::vector<Tensor>& ToTensorVector() const;\n\n private:\n  enum class Tag { kNone = 0, kInt = 1, kDouble = 2, kBool = 3, kTensor = 4, kTensorVector = 5 };\n  friend std::ostream& operator<<(std::ostream&, const Tag&);\n\n  union Payload {  // NOLINT\n    union InternalPayload {\n      InternalPayload() : v_int(0) {}\n\n      int64_t v_int;\n      double v_double;\n      bool v_bool;\n    } i;\n\n    
Tensor v_tensor;\n    std::vector<Tensor> v_tensor_vector;\n\n    Payload() : i() {}\n    ~Payload() {}\n  };\n\n  Payload payload_;\n  Tag tag_;\n\n  inline void Destroy() {\n    if (IsTensor()) { payload_.v_tensor.~Tensor(); }\n    if (IsTensorVector()) { payload_.v_tensor_vector.~vector(); }\n  }\n\n  inline void MoveFrom(IValue&& value) {\n    if (IsTensor()) {\n      new (&payload_.v_tensor) Tensor(std::move(value.payload_.v_tensor));\n    } else if (IsTensorVector()) {\n      new (&payload_.v_tensor_vector)\n          std::vector<Tensor>(std::move(value.payload_.v_tensor_vector));\n    } else {\n      payload_.i = value.payload_.i;\n    }\n    value.ClearToNone();\n  }\n\n  inline void ClearToNone() {\n    Destroy();\n    payload_.i.v_int = 0;\n    tag_ = Tag::kNone;\n  }\n};\n\n}  // namespace oneflow_api\n\n#endif  // ONEFLOW_API_CPP_FRAMEWORK_IVALUE_H_\n"
  },
  {
    "path": "oneflow/api/cpp/framework/shape.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/cpp/framework/shape.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/shape_vec.h\"\n\nnamespace oneflow_api {\n\nnamespace of = oneflow;\nnamespace {\n\nof::DimVector ToOneflowDimVcetor(const std::vector<int64_t>& dim_vec) {\n  return of::DimVector(dim_vec.begin(), dim_vec.end());\n}\n\n}  // namespace\n\nShape::Shape() : shape_(std::make_shared<of::Shape>(of::Shape({0}))) {}\n\nShape::Shape(const std::vector<int64_t>& dim_vec)\n    : shape_(std::make_shared<of::Shape>(ToOneflowDimVcetor(dim_vec))) {}\n\nShape::Shape(const std::initializer_list<int64_t>& dim_vec)\n    : shape_(std::make_shared<of::Shape>(dim_vec)) {}\n\nShape& Shape::operator=(const Shape& shape) {\n  this->shape_.reset();\n  this->shape_ = shape.shape_;\n  return *this;\n}\n\nbool Shape::operator==(const Shape& rhs) const { return *shape_ == *rhs.shape_; }\n\nbool Shape::operator!=(const Shape& rhs) const { return !(*this == rhs); }\n\nint64_t Shape::elem_cnt() const { return shape_->elem_cnt(); }\n\nint64_t Shape::At(int64_t index) const { return shape_->At(index); }\n\nvoid Shape::Set(int64_t index, int64_t val) { shape_->Set(index, val); }\n\nint64_t Shape::NumAxes() const { return shape_->NumAxes(); }\n\nint64_t Shape::Count(int64_t begin_axis, int64_t end_axis) const {\n  return shape_->Count(begin_axis, end_axis);\n}\n\nint64_t Shape::Count(int64_t begin_axis) const { return shape_->Count(begin_axis); }\n\nstd::ostream& operator<<(std::ostream& os, const Shape& shape) {\n  os << shape.shape_->DebugStr();\n  return os;\n}\n\n}  // namespace oneflow_api\n"
  },
  {
    "path": "oneflow/api/cpp/framework/shape.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_CPP_FRAMEWORK_SHAPE_H_\n#define ONEFLOW_API_CPP_FRAMEWORK_SHAPE_H_\n\n#include <memory>\n#include <vector>\n\nnamespace oneflow {\n\nclass Shape;\n\n}\n\nnamespace oneflow_api {\n\nclass Shape final {\n  friend class Tensor;\n\n public:\n  Shape();\n  explicit Shape(const std::vector<int64_t>& dim_vec);\n  Shape(const std::initializer_list<int64_t>& dim_vec);\n  ~Shape() = default;\n  Shape& operator=(const Shape& shape);\n\n  [[nodiscard]] bool operator==(const Shape& rhs) const;\n  [[nodiscard]] bool operator!=(const Shape& rhs) const;\n\n  void Set(int64_t index, int64_t val);\n\n  [[nodiscard]] int64_t elem_cnt() const;\n  [[nodiscard]] int64_t At(int64_t index) const;\n  [[nodiscard]] int64_t NumAxes() const;\n  [[nodiscard]] int64_t Count(int64_t begin_axis, int64_t end_axis) const;\n  [[nodiscard]] int64_t Count(int64_t begin_axis) const;\n\n private:\n  std::shared_ptr<oneflow::Shape> shape_ = nullptr;\n\n  friend std::ostream& operator<<(std::ostream&, const Shape&);\n};\n}  // namespace oneflow_api\n\n#endif  // ONEFLOW_API_CPP_FRAMEWORK_SHAPE_H_\n"
  },
  {
    "path": "oneflow/api/cpp/framework/tensor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/cpp/framework/tensor.h\"\n#include \"oneflow/api/cpp/framework/device.h\"\n#include \"oneflow/api/cpp/framework/dtype.h\"\n#include \"oneflow/api/cpp/framework/shape.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n\nnamespace oneflow_api {\n\nnamespace of = oneflow;\nnamespace functional = of::one::functional;\n\nTensor::Tensor(const Shape& shape, const Device& device, const DType& dtype) {\n  of::LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);\n  tensor_ = functional::Empty(*shape.shape_,\n                              of::DType::Get(static_cast<of::DataType>(dtype)).GetOrThrow(),\n                              *device.device_, /*requires_grad=*/false, /*pin_memory=*/false)\n                .GetPtrOrThrow();\n}\nTensor::Tensor(const std::shared_ptr<oneflow::one::Tensor>& tensor) : tensor_(tensor) {}\n\nTensor::Tensor(const Tensor& tensor) : tensor_(tensor.tensor_) {}\nTensor::Tensor(Tensor&& tensor) noexcept : tensor_(std::move(tensor.tensor_)) {}\n\nTensor& Tensor::operator=(const Tensor& tensor) {\n  if (&tensor == this) { return *this; }\n  tensor_ = tensor.tensor_;\n  return *this;\n}\nTensor& Tensor::operator=(Tensor&& tensor) noexcept {\n  if (&tensor == this) { return *this; }\n  tensor_ = std::move(tensor.tensor_);\n  return *this;\n}\n\nShape Tensor::shape() const {\n  const auto shape_ = tensor_->shape();\n  return Shape(std::vector<int64_t>(shape_->dim_vec().begin(), shape_->dim_vec().end()));\n}\n\nDevice Tensor::device() const {\n  const auto device_ = tensor_->device().GetOrThrow();\n  return Device(device_->type(), device_->device_id());\n}\n\nDType Tensor::dtype() const { return static_cast<DType>(tensor_->dtype()->data_type()); }\n\nvoid Tensor::zeros_() {\n  std::shared_ptr<of::one::LocalTensor> local_tensor = tensor_->AsLocalTensor().GetPtrOrThrow();\n  of::PhysicalRun([&](of::InstructionsBuilder* builder) -> of::Maybe<void> {\n    JUST(builder->AccessBlobByCallback(\n        local_tensor,\n        [](of::ep::Stream* stream,\n           const std::shared_ptr<of::vm::EagerBlobObject>& eager_blob_object) {\n          of::AutoMemset(stream, eager_blob_object->mut_dptr(), 0,\n                         eager_blob_object->ByteSizeOfBlobBody(), eager_blob_object->mem_case());\n        },\n        \"mut\"));\n    return of::Maybe<void>::Ok();\n  }).GetOrThrow();\n}\n\nTensor Tensor::from_buffer(const void* buffer, const Shape& shape, const Device& device,\n                           const DType& dtype) {\n  Tensor tensor(shape, device, dtype);\n  std::shared_ptr<of::one::LocalTensor> 
local_tensor =\n      tensor.tensor_->AsLocalTensor().GetPtrOrThrow();\n  of::PhysicalRun([&](of::InstructionsBuilder* builder) -> of::Maybe<void> {\n    return builder->AccessBlobByCallback(\n        local_tensor,\n        [buffer, shape, dtype](of::ep::Stream* stream,\n                               const std::shared_ptr<of::vm::EagerBlobObject>& eager_blob_object) {\n          of::AutoMemcpy(stream, eager_blob_object->mut_dptr(), buffer,\n                         shape.Count(0) * GetDTypeSize(dtype), eager_blob_object->mem_case(),\n                         of::memory::MakeHostMemCase());\n        },\n        \"mut\");\n  }).GetOrThrow();\n  return tensor;\n}\n\ntemplate<typename T>\nvoid Tensor::copy_to(T* buffer) const {\n  std::shared_ptr<of::one::LocalTensor> local_tensor = tensor_->AsLocalTensor().GetPtrOrThrow();\n  const auto shape = this->shape();\n\n  const auto& Callback = [buffer, shape](\n                             of::ep::Stream* stream,\n                             const std::shared_ptr<of::vm::EagerBlobObject>& eager_blob_object) {\n    of::AutoMemcpy(stream, buffer, eager_blob_object->mut_dptr(), shape.Count(0) * sizeof(T),\n                   of::memory::MakeHostMemCase(), eager_blob_object->mem_case());\n  };\n  auto btb = std::make_shared<of::BlockingThenBusy>();\n  CHECK_JUST(of::PhysicalRun([&](of::InstructionsBuilder* builder) -> of::Maybe<void> {\n    return builder->SyncAccessBlobByCallback(local_tensor, btb, Callback, \"const\");\n  }));\n  TRY(btb->WaitUntilCntEqualZero(of::VirtualMachine::GetPredicatorNoMoreInstructionsFinished()))\n      .GetOrThrow();\n}\n\nconst std::shared_ptr<oneflow::one::Tensor>& Tensor::__internal_tensor() const { return tensor_; }\n\n#define REGISTER_TENSOR_COPY_TO(cpp_dtype) \\\n  template void Tensor::copy_to<cpp_dtype>(cpp_dtype * buffer) const;\n\nREGISTER_TENSOR_COPY_TO(float)\nREGISTER_TENSOR_COPY_TO(double)\nREGISTER_TENSOR_COPY_TO(bool)\nREGISTER_TENSOR_COPY_TO(int8_t)\nREGISTER_TENSOR_COPY_TO(int32_t)\nREGISTER_TENSOR_COPY_TO(int64_t)\n\n}  // namespace oneflow_api\n"
  },
  {
    "path": "oneflow/api/cpp/framework/tensor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_CPP_FRAMEWORK_TENSOR_H_\n#define ONEFLOW_API_CPP_FRAMEWORK_TENSOR_H_\n\n#include <memory>\n#include \"device.h\"\n#include \"shape.h\"\n#include \"dtype.h\"\n\nnamespace oneflow {\nnamespace one {\n\nclass Tensor;\n\n}\n\n}  // namespace oneflow\n\nnamespace oneflow_api {\n\nclass Tensor final {\n  friend class Graph;\n\n public:\n  explicit Tensor(const Shape& shape = Shape(), const Device& device = Device(\"cpu\"),\n                  const DType& dtype = DType::kFloat);\n  explicit Tensor(const std::shared_ptr<oneflow::one::Tensor>& tensor);\n\n  Tensor(const Tensor& tensor);\n  Tensor(Tensor&& tensor) noexcept;\n\n  ~Tensor() = default;\n\n  Tensor& operator=(const Tensor& tensor);\n  Tensor& operator=(Tensor&& tensor) noexcept;\n\n  [[nodiscard]] Shape shape() const;\n  [[nodiscard]] Device device() const;\n  [[nodiscard]] DType dtype() const;\n\n  void zeros_();\n\n  // You should never call __internal_tensor() directly.\n  [[nodiscard]] const std::shared_ptr<oneflow::one::Tensor>& __internal_tensor() const;\n\n  template<typename T>\n  void copy_to(T* buffer) const;\n\n  [[nodiscard]] static Tensor from_buffer(const void* buffer, const Shape& shape,\n                                          const Device& device, const DType& dtype);\n\n private:\n  std::shared_ptr<oneflow::one::Tensor> tensor_ = nullptr;\n};\n\n}  // namespace oneflow_api\n\n#endif  // ONEFLOW_API_CPP_FRAMEWORK_TENSOR_H_\n"
  },
  {
    "path": "oneflow/api/cpp/framework.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_API_CPP_FRAMEWORK_H_\n#define ONEFLOW_API_CPP_FRAMEWORK_H_\n\n#include \"framework/device.h\"\n#include \"framework/shape.h\"\n#include \"framework/dtype.h\"\n#include \"framework/tensor.h\"\n#include \"framework/ivalue.h\"\n#include \"framework/graph.h\"\n\n#endif  // ONEFLOW_API_CPP_FRAMEWORK_H_\n"
  },
  {
    "path": "oneflow/api/cpp/nn/functional/activation.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/cpp/nn/functional/activation.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow_api {\nnamespace nn {\n\nnamespace of = oneflow;\nnamespace functional = of::one::functional;\n\nTensor relu(const Tensor& tensor) {\n  return Tensor(functional::Relu(tensor.__internal_tensor(), false).GetPtrOrThrow());\n}\n\n}  // namespace nn\n}  // namespace oneflow_api\n"
  },
  {
    "path": "oneflow/api/cpp/nn/functional/activation.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_CPP_NN_FUNCTIONAL_ACTIVATION_H_\n#define ONEFLOW_API_CPP_NN_FUNCTIONAL_ACTIVATION_H_\n\n#include \"../../framework.h\"\n\nnamespace oneflow_api {\nnamespace nn {\n\nTensor relu(const Tensor& tensor);\n\n}\n\n}  // namespace oneflow_api\n\n#endif  // ONEFLOW_API_CPP_NN_FUNCTIONAL_ACTIVATION_H_\n"
  },
  {
    "path": "oneflow/api/cpp/nn.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_API_CPP_NN_H_\n#define ONEFLOW_API_CPP_NN_H_\n\n#include \"nn/functional/activation.h\"\n\n#endif  // ONEFLOW_API_CPP_NN_H_\n"
  },
  {
    "path": "oneflow/api/cpp/tests/api_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/api/cpp/tests/api_test.h\"\n#include <cstddef>\n#include <random>\n#include <string>\n#ifdef __linux__\n\n#include <unistd.h>  // readlink\n\n#elif defined(__APPLE__)\n\n#include <mach-o/dyld.h>  //  _NSGetExecutablePath\n\n#endif\n\nnamespace oneflow_api {\n\nShape RandomShape() {\n  thread_local static std::mt19937 rng(std::random_device{}());\n  std::uniform_int_distribution<> dist_ndim(1, 4), dist_dims(16, 64);\n  std::vector<std::int64_t> dims(dist_ndim(rng), 0);\n  for (auto& x : dims) { x = dist_dims(rng); }\n  return Shape(dims);\n}\n\ntemplate<typename T>\nstd::vector<T> RandomData(size_t size) {\n  thread_local static std::mt19937 rng(std::random_device{}());\n  std::uniform_int_distribution<> dist(-100, 100);\n  std::vector<T> data(size);\n  for (auto& x : data) { x = static_cast<T>(dist(rng)); }\n  return data;\n}\n#define REGISTER_RANDOM_DATA(cpp_dtype) template std::vector<cpp_dtype> RandomData(size_t size);\n\nREGISTER_RANDOM_DATA(float)\nREGISTER_RANDOM_DATA(double)\nREGISTER_RANDOM_DATA(int8_t)\nREGISTER_RANDOM_DATA(int32_t)\nREGISTER_RANDOM_DATA(int64_t)\n\nstd::string GetExeDir() {\n  const size_t path_max_size = 4096;  // PATH_MAX = 4096 on linux\n  char result[path_max_size];\n\n  const auto get_dir_from_path = [](char result[], size_t count) -> std::string {\n    std::string exe_path(result, (count > 0) ? count : 0);\n\n    // string(path).rfind('/') will never be string::npos on linux or macos.\n    return exe_path.substr(0, exe_path.rfind('/'));\n  };\n\n#ifdef __linux__\n  ssize_t count = readlink(\"/proc/self/exe\", result, path_max_size);\n  return get_dir_from_path(result, count);\n#elif defined(__APPLE__)\n  uint32_t count = path_max_size;\n  CHECK_EQ(_NSGetExecutablePath(result, &count), 0) << \"Fail to get executable file path.\";\n  return get_dir_from_path(result, count);\n#else\n#error oneflow_api::GetExeDir() has not been supported on windows.\n#endif\n}\n\n}  // namespace oneflow_api\n"
  },
  {
    "path": "oneflow/api/cpp/tests/api_test.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_API_CPP_TESTS_API_TEST_H_\n#define ONEFLOW_API_CPP_TESTS_API_TEST_H_\n\n#include \"oneflow/api/cpp/api.h\"\n\nnamespace oneflow_api {\n\nclass EnvScope {  // NOLINT\n public:\n  EnvScope() { initialize(); }\n  ~EnvScope() { release(); }\n};\n\nShape RandomShape();\n\ntemplate<typename T>\nstd::vector<T> RandomData(size_t size);\n\nstd::string GetExeDir();\n\n}  // namespace oneflow_api\n\n#endif  // !ONEFLOW_API_CPP_TESTS_API_TEST_H_\n"
  },
  {
    "path": "oneflow/api/cpp/tests/graph_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <gtest/gtest.h>\n#include <algorithm>\n#include <array>\n#include <chrono>\n#include <cstdint>\n#include <functional>\n#include <iostream>\n#include <thread>\n#include <vector>\n#include \"oneflow/api/cpp/framework.h\"\n#include \"oneflow/api/cpp/framework/dtype.h\"\n#include \"oneflow/api/cpp/framework/shape.h\"\n#include \"oneflow/api/cpp/tests/api_test.h\"\n\nnamespace oneflow_api {\n\nnamespace {\n\ninline Graph LoadGraph(const Device& device) {\n  Graph graph =\n      Graph::Load(\"./oneflow/api/cpp/tests/graph_test_model/affine_with_parameter\", device);\n  return graph;\n}\n\ninline void Forward(Graph& graph, const Device& device, int expected_batch_dim = 1) {\n  std::vector<float> data(expected_batch_dim * 3);\n  std::fill(data.begin(), data.end(), 1);\n  std::vector<Tensor> inputs;\n  inputs.emplace_back(\n      Tensor::from_buffer(data.data(), Shape({expected_batch_dim, 3}), device, DType::kFloat));\n  const auto& value = graph.Forward(inputs);\n  ASSERT_TRUE(value.IsTensor());\n  Tensor output = value.ToTensor();\n  Shape shape = output.shape();\n  ASSERT_EQ(shape.At(0), expected_batch_dim);\n  ASSERT_EQ(shape.At(1), 4);\n  std::vector<float> buf(expected_batch_dim * 4);\n  output.copy_to(buf.data());\n  for (const float& element : buf) { ASSERT_EQ(element, 4); }\n}\n\n}  // namespace\n\nTEST(Api, graph_cpu_test) {\n  EnvScope scope;\n  Device device(\"cpu\");\n  Graph graph = LoadGraph(device);\n  Forward(graph, device, 1);\n}\n\n#ifdef WITH_CUDA\nTEST(Api, graph_gpu_test) {\n  EnvScope scope;\n  Device device(\"cuda\", 0);\n  Graph graph = LoadGraph(device);\n  Forward(graph, device);\n}\n\nTEST(Api, graph_multi_gpu_test) {\n  EnvScope scope;\n  Device device(\"cuda\", 0);\n  Graph graph = LoadGraph(device);\n  Forward(graph, device);\n\n  Device device1(\"cuda\", 1);\n  Graph graph1 = LoadGraph(device1);\n  Forward(graph1, device1);\n}\n#endif\n\nTEST(Api, graph_cpu_batching_test) {\n  EnvScope scope;\n  Device device(\"cpu\");\n  Graph graph = LoadGraph(device);\n  graph.set_batch_size(10);\n  Forward(graph, device, 10);\n}\n\n#ifdef WITH_CUDA\nTEST(Api, graph_gpu_batching_test) {\n  EnvScope scope;\n  Device device(\"cuda\", 0);\n  Graph graph = LoadGraph(device);\n  graph.set_batch_size(10);\n  Forward(graph, device, 10);\n}\n\nTEST(Api, graph_multi_device_test) {\n  EnvScope scope;\n  Device device(\"cuda\", 0);\n  Graph graph = LoadGraph(device);\n  Forward(graph, device, 1);\n\n  Device device1(\"cuda\", 1);\n  Graph graph1 = LoadGraph(device1);\n  Forward(graph1, device1, 1);\n\n  Device device2(\"cpu\");\n  Graph graph2 = LoadGraph(device2);\n  Forward(graph2, device2, 1);\n}\n\nTEST(Api, graph_unload_test) {\n  {\n    EnvScope scope;\n\n    Device device(\"cuda\", 0);\n    Graph graph = LoadGraph(device);\n    Forward(graph, device, 1);\n\n    {\n      Device device1(\"cuda\", 1);\n      Graph graph1 = LoadGraph(device1);\n      Forward(graph1, 
device1, 1);\n    }\n\n    Device device2(\"cpu\");\n    Graph graph2 = LoadGraph(device2);\n    Forward(graph2, device2, 1);\n  }\n\n  {\n    EnvScope scope;\n\n    Device device(\"cpu\");\n    Graph graph = LoadGraph(device);\n    Forward(graph, device, 1);\n  }\n}\n#endif\n\nTEST(Api, graph_thread_test) {\n  EnvScope scope;\n\n  Device device(\"cpu\");\n  std::vector<Graph> graphs;\n  for (int i = 0; i < 10; i++) { graphs.emplace_back(LoadGraph(device)); }\n\n  std::vector<std::thread> threads;\n  for (Graph& graph : graphs) {\n    threads.emplace_back(std::thread(std::bind(Forward, std::move(graph), device, 1)));\n  }\n  for (auto& thread : threads) { thread.join(); }\n}\n\nTEST(Api, graph_input_order_test) {\n  EnvScope scope;\n\n  Device device(\"cpu\");\n  Graph graph = Graph::Load(\"./oneflow/api/cpp/tests/graph_test_model/affine_no_parameter\", device);\n\n  std::vector<Tensor> inputs;\n  std::vector<float> x(3);\n  std::fill(x.begin(), x.end(), 1);\n  inputs.emplace_back(Tensor::from_buffer(x.data(), Shape({1, 3}), device, DType::kFloat));\n  std::vector<float> a(3 * 2);\n  std::fill(a.begin(), a.end(), 1);\n  inputs.emplace_back(Tensor::from_buffer(a.data(), Shape({3, 2}), device, DType::kFloat));\n  std::vector<float> b(2);\n  std::fill(b.begin(), b.end(), 1);\n  inputs.emplace_back(Tensor::from_buffer(b.data(), Shape({2}), device, DType::kFloat));\n\n  const auto& value = graph.Forward(inputs);\n  ASSERT_TRUE(value.IsTensor());\n  Tensor output = value.ToTensor();\n  Shape shape = output.shape();\n  ASSERT_EQ(shape.At(0), 1);\n  ASSERT_EQ(shape.At(1), 2);\n  std::array<float, 2> buf{};\n  output.copy_to(buf.data());\n  ASSERT_EQ(buf[0], 4);\n  ASSERT_EQ(buf[1], 4);\n}\n\nTEST(Api, graph_input_output_infos_test) {\n  EnvScope scope;\n  Device device(\"cpu\");\n  Graph graph = LoadGraph(device);\n\n  auto input_infos = graph.GetInputInfos();\n  auto output_infos = graph.GetOutputInfos();\n\n  ASSERT_EQ(input_infos.size(), 1);\n  ASSERT_EQ(output_infos.size(), 1);\n\n  auto it = input_infos.begin();\n  DType dtype = it->second.datatype_;\n  Shape shape = it->second.input_output_shape_;\n  size_t order = it->second.input_output_index_;\n  ASSERT_EQ(dtype, DType::kFloat);\n  ASSERT_EQ(shape.NumAxes(), 2);\n  ASSERT_EQ(shape.At(0), 1);\n  ASSERT_EQ(shape.At(1), 3);\n  ASSERT_EQ(order, 0);\n\n  it = output_infos.begin();\n  dtype = it->second.datatype_;\n  shape = it->second.input_output_shape_;\n  order = it->second.input_output_index_;\n  ASSERT_EQ(dtype, DType::kFloat);\n  ASSERT_EQ(shape.NumAxes(), 2);\n  ASSERT_EQ(shape.At(0), 1);\n  ASSERT_EQ(shape.At(1), 4);\n  ASSERT_EQ(order, 0);\n}\n\n}  // namespace oneflow_api\n"
  },
  {
    "path": "oneflow/api/cpp/tests/graph_test_model/affine_no_parameter/model.mlir",
    "content": "module  {\n  oneflow.job @MyGraph_1(%arg0: tensor<1x3xf32>, %arg1: tensor<3x2xf32>, %arg2: tensor<2xf32>) -> tensor<1x2xf32> {\n    %output = \"oneflow.input\"(%arg0) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_MyGraph_1-input_0\", output_lbns = [\"_MyGraph_1-input_0/out\"], scope_symbol_id = 4611686018427527167 : i64, shape = [1 : si64, 3 : si64]} : (tensor<1x3xf32>) -> tensor<1x3xf32>\n    %output_0 = \"oneflow.input\"(%arg1) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_MyGraph_1-input_1\", output_lbns = [\"_MyGraph_1-input_1/out\"], scope_symbol_id = 4611686018427527167 : i64, shape = [3 : si64, 2 : si64]} : (tensor<3x2xf32>) -> tensor<3x2xf32>\n    %output_1 = \"oneflow.input\"(%arg2) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_MyGraph_1-input_2\", output_lbns = [\"_MyGraph_1-input_2/out\"], scope_symbol_id = 4611686018427527167 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32>\n    %0 = \"oneflow.matmul\"(%output, %output_0) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"model-matmul_0\", output_lbns = [\"model-matmul_0/out_0\"], scope_symbol_id = 4611686018427535359 : i64, transpose_a = false, transpose_b = false} : (tensor<1x3xf32>, tensor<3x2xf32>) -> tensor<1x2xf32>\n    %1 = \"oneflow.broadcast_add\"(%0, %output_1) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"model-broadcast_add_1\", output_lbns = [\"model-broadcast_add_1/z_0\"], scope_symbol_id = 4611686018427535359 : i64} : (tensor<1x2xf32>, tensor<2xf32>) -> tensor<1x2xf32>\n    %output_2 = \"oneflow.output\"(%1) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_MyGraph_1-output_0\", output_lbns = [\"_MyGraph_1-output_0/out\"], scope_symbol_id = 4611686018427527167 : i64, shape = [1 : si64, 2 : si64]} : (tensor<1x2xf32>) -> tensor<1x2xf32>\n    oneflow.return %output_2 : tensor<1x2xf32>\n  }\n}\n"
  },
  {
    "path": "oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.a/meta",
    "content": "shape {\n  dim: 3\n  dim: 4\n}\ndata_type: kFloat\n"
  },
  {
    "path": "oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.b/meta",
    "content": "shape {\n  dim: 4\n}\ndata_type: kFloat\n"
  },
  {
    "path": "oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.mlir",
    "content": "module  {\n  oneflow.job @MyGraph_0(%arg0: tensor<1x3xf32>) -> tensor<1x4xf32> {\n    %output = \"oneflow.input\"(%arg0) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_MyGraph_0-input_0\", output_lbns = [\"_MyGraph_0-input_0/out\"], scope_symbol_id = 4611686018427469823 : i64, shape = [1 : si64, 3 : si64]} : (tensor<1x3xf32>) -> tensor<1x3xf32>\n    %output_0 = \"oneflow.variable\"() {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], parallel = #sbp.parallel<[] -> [#sbp.B]>, op_name = \"model.a\", output_lbns = [\"model.a/out\"], scope_symbol_id = 4611686018427482111 : i64, shape = [3 : si64, 4 : si64]} : () -> tensor<3x4xf32>\n    %output_1 = \"oneflow.variable\"() {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], parallel = #sbp.parallel<[] -> [#sbp.B]>, op_name = \"model.b\", output_lbns = [\"model.b/out\"], scope_symbol_id = 4611686018427494399 : i64, shape = [4 : si64]} : () -> tensor<4xf32>\n    %0 = \"oneflow.matmul\"(%output, %output_0) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"model-matmul_0\", output_lbns = [\"model-matmul_0/out_0\"], scope_symbol_id = 4611686018427486207 : i64, transpose_a = false, transpose_b = false} : (tensor<1x3xf32>, tensor<3x4xf32>) -> tensor<1x4xf32>\n    %1 = \"oneflow.broadcast_add\"(%0, %output_1) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"model-broadcast_add_1\", output_lbns = [\"model-broadcast_add_1/z_0\"], scope_symbol_id = 4611686018427486207 : i64} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>\n    %output_2 = \"oneflow.output\"(%1) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_MyGraph_0-output_0\", output_lbns = [\"_MyGraph_0-output_0/out\"], scope_symbol_id = 4611686018427469823 : i64, shape = [1 : si64, 4 : si64]} : (tensor<1x4xf32>) -> tensor<1x4xf32>\n    oneflow.return %output_2 : tensor<1x4xf32>\n  }\n}\n"
  },
  {
    "path": "oneflow/api/cpp/tests/ivalue_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <random>\n#include <gtest/gtest.h>\n#include \"oneflow/api/cpp/framework/dtype.h\"\n#include \"oneflow/api/cpp/framework/ivalue.h\"\n#include \"oneflow/api/cpp/tests/api_test.h\"\n\nnamespace oneflow_api {\n\nnamespace {\n\nstd::mt19937 rng(std::random_device{}());\n\n}\n\nTEST(Api, ivalue) {\n  std::uniform_real_distribution<> dist(-100, 100);\n  std::uniform_int_distribution<> dist_bool(0, 1);\n\n  const auto v_int = static_cast<int>(dist(rng));\n  ASSERT_EQ(IValue(v_int).ToInt(), v_int);\n\n  const auto v_int64 = static_cast<int64_t>(dist(rng));\n  ASSERT_EQ(IValue(v_int64).ToInt(), v_int64);\n\n  const auto v_float = static_cast<float>(dist(rng));\n  ASSERT_EQ(IValue(v_float).ToDouble(), v_float);\n\n  const auto v_double = static_cast<double>(dist(rng));\n  ASSERT_EQ(IValue(v_double).ToDouble(), v_double);\n\n  const auto v_bool = static_cast<bool>(dist_bool(rng));\n  ASSERT_EQ(IValue(v_bool).ToBool(), v_bool);\n}\n\nTEST(Api, ivalue_tensor) {\n  EnvScope scope;\n\n  const auto device = Device(\"cpu\");\n  const auto shape = RandomShape();\n  const auto dtype = DType::kDouble;\n\n  const IValue i_tensor(Tensor(shape, device, dtype));\n  const auto& tensor = i_tensor.ToTensor();\n\n  ASSERT_EQ(tensor.shape(), shape);\n  ASSERT_EQ(tensor.device(), device);\n  ASSERT_EQ(tensor.dtype(), dtype);\n}\n\nTEST(Api, ivalue_tensor_vector) {\n  EnvScope scope;\n\n  const auto device = Device(\"cpu\");\n\n  const std::vector<Tensor> v_tensor_vector{Tensor(RandomShape(), device, DType::kDouble),\n                                            Tensor(RandomShape(), device, DType::kFloat)};\n  const auto i_tensor = IValue(v_tensor_vector);\n  const auto& tensor_vector = i_tensor.ToTensorVector();\n\n  ASSERT_EQ(v_tensor_vector.size(), tensor_vector.size());\n\n  for (size_t i = 0; i < tensor_vector.size(); ++i) {\n    ASSERT_EQ(v_tensor_vector[i].device(), tensor_vector[i].device());\n    ASSERT_EQ(v_tensor_vector[i].shape(), tensor_vector[i].shape());\n    ASSERT_EQ(v_tensor_vector[i].dtype(), tensor_vector[i].dtype());\n  }\n}\n\nTEST(Api, ivalue_copy) {\n  EnvScope scope;\n\n  const auto device = Device(\"cpu\");\n  const auto shape = RandomShape();\n  const auto dtype = DType::kDouble;\n\n  const IValue i_tensor(Tensor(shape, device, dtype));\n  const auto i_tensor_a = i_tensor;  // NOLINT\n\n  ASSERT_EQ(i_tensor_a.ToTensor().shape(), shape);\n  ASSERT_EQ(i_tensor_a.ToTensor().device(), device);\n  ASSERT_EQ(i_tensor_a.ToTensor().dtype(), dtype);\n\n  IValue i_tensor_b;\n  i_tensor_b = i_tensor;\n\n  ASSERT_EQ(i_tensor_b.ToTensor().shape(), shape);\n  ASSERT_EQ(i_tensor_b.ToTensor().device(), device);\n  ASSERT_EQ(i_tensor_b.ToTensor().dtype(), dtype);\n}\n\nTEST(Api, ivalue_move) {\n  EnvScope scope;\n\n  const auto device = Device(\"cpu\");\n  const auto shape = RandomShape();\n  const auto dtype = DType::kDouble;\n\n  IValue i_tensor_a = IValue(Tensor(shape, device, 
dtype));\n  IValue i_tensor_b = IValue(Tensor(shape, device, dtype));\n\n  IValue i_tensor_c = std::move(i_tensor_a);\n  ASSERT_EQ(i_tensor_c.ToTensor().shape(), shape);\n  ASSERT_EQ(i_tensor_c.ToTensor().device(), device);\n  ASSERT_EQ(i_tensor_c.ToTensor().dtype(), dtype);\n\n  IValue i_tensor_d;\n  i_tensor_d = std::move(i_tensor_b);\n  ASSERT_EQ(i_tensor_d.ToTensor().shape(), shape);\n  ASSERT_EQ(i_tensor_d.ToTensor().device(), device);\n  ASSERT_EQ(i_tensor_d.ToTensor().dtype(), dtype);\n\n  ASSERT_EQ(i_tensor_a.IsNone(), true);\n  ASSERT_EQ(i_tensor_b.IsNone(), true);\n}\n\n}  // namespace oneflow_api\n"
  },
  {
    "path": "oneflow/api/cpp/tests/nn_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <random>\n#include <thread>\n#include <gtest/gtest.h>\n#include \"oneflow/api/cpp/tests/api_test.h\"\n\nnamespace oneflow_api {\n\nnamespace {\n\nstd::mt19937 rng(std::random_device{}());\n\ntemplate<typename T>\nstd::vector<T> Relu(const std::vector<T>& data) {\n  std::vector<T> result(data.begin(), data.end());\n  T zero = static_cast<T>(0);\n  for (auto& x : result) {\n    if (x < zero) { x = zero; }\n  }\n  return result;\n}\n\n}  // namespace\n\nvoid TestRelu() {\n  const auto shape = RandomShape();\n  const auto data = RandomData<float>(shape.Count(0));\n  const auto target_data = Relu(data);\n  std::vector<float> result(shape.Count(0));\n\n  auto tensor = Tensor::from_buffer(data.data(), shape, Device(\"cpu\"), DType::kFloat);\n  auto result_tensor = nn::relu(tensor);\n\n  result_tensor.copy_to(result.data());\n\n  ASSERT_EQ(result, target_data);\n}\n\nTEST(Api, nn_relu) {\n  EnvScope scope;\n\n  TestRelu();\n}\n\nTEST(Api, nn_relu_multithreading) {\n  EnvScope scope;\n\n  std::vector<std::thread> threads;\n  std::uniform_int_distribution<> dist(8, 32);\n  int n_threads = dist(rng);\n\n  for (int i = 0; i < n_threads; ++i) { threads.emplace_back(std::thread(TestRelu)); }\n\n  for (auto& x : threads) { x.join(); }\n}\n\n}  // namespace oneflow_api\n"
  },
  {
    "path": "oneflow/api/cpp/tests/one_embedding_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/api/cpp/tests/api_test.h\"\n\nnamespace oneflow_api {\n\n#ifdef WITH_CUDA\nTEST(Api, embedding_test) {\n  EnvScope scope;\n  Device device(\"cuda\");\n  Graph graph = Graph::Load(\"/path/to/embedding\", device);\n  int64_t batch_size = 10000;\n  int64_t num_features = 39;\n\n  std::vector<int64_t> data(batch_size * num_features);\n  std::fill(data.begin(), data.end(), 1);\n  std::vector<Tensor> inputs;\n  inputs.emplace_back(\n      Tensor::from_buffer(data.data(), Shape({batch_size, num_features}), device, DType::kInt64));\n\n  const auto& value = graph.Forward(inputs);\n\n  ASSERT_TRUE(value.IsTensor());\n  Tensor output = value.ToTensor();\n  Shape shape = output.shape();\n  ASSERT_EQ(shape.At(0), batch_size);\n  ASSERT_EQ(shape.At(1), 1);\n\n  std::vector<float> buf(batch_size);\n  output.copy_to(buf.data());\n}\n#endif\n\n}  // namespace oneflow_api\n"
  },
  {
    "path": "oneflow/api/cpp/tests/tensor_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <gtest/gtest.h>\n#include \"oneflow/api/cpp/tests/api_test.h\"\n\nnamespace oneflow_api {\n\nTEST(Api, device) {\n  EnvScope scope;\n\n  auto device = Device(\"cpu\");\n  ASSERT_EQ(device.type(), \"cpu\");\n\n#ifdef WITH_CUDA\n  device = Device(\"cuda:0\");\n  ASSERT_EQ(device.type(), \"cuda\");\n  ASSERT_EQ(device.device_id(), 0);\n\n  device = Device(\"cuda\", 1);\n  ASSERT_EQ(device.type(), \"cuda\");\n  ASSERT_EQ(device.device_id(), 1);\n#endif\n}\n\nTEST(Api, tensor) {\n  EnvScope scope;\n\n  const auto device = Device(\"cpu\");\n  const auto shape = RandomShape();\n  const auto dtype = DType::kDouble;\n\n  Tensor tensor;\n  ASSERT_EQ(tensor.shape(), Shape());\n  ASSERT_EQ(tensor.device(), Device(\"cpu\"));\n  ASSERT_EQ(tensor.dtype(), DType::kFloat);\n\n  Tensor tensor_with_all(shape, device, dtype);\n\n  ASSERT_EQ(tensor_with_all.shape(), shape);\n  ASSERT_EQ(tensor_with_all.device(), device);\n  ASSERT_EQ(tensor_with_all.dtype(), dtype);\n}\n\nTEST(Api, tensor_from_buffer_and_copy_to) {\n  EnvScope scope;\n\n  const auto shape = RandomShape();\n\n#define TEST_TENSOR_FROM_AND_TO_BLOB(dtype, cpp_dtype)                                           \\\n  std::vector<cpp_dtype> data_##cpp_dtype(shape.Count(0)), new_data_##cpp_dtype(shape.Count(0)); \\\n  for (int i = 0; i < shape.Count(0); ++i) { data_##cpp_dtype[i] = i; }                          \\\n  auto tensor_##cpp_dtype =                                                                      \\\n      Tensor::from_buffer(data_##cpp_dtype.data(), shape, Device(\"cpu\"), dtype);                 \\\n  tensor_##cpp_dtype.copy_to(new_data_##cpp_dtype.data());                                       \\\n  ASSERT_EQ(new_data_##cpp_dtype, data_##cpp_dtype);\n\n  TEST_TENSOR_FROM_AND_TO_BLOB(DType::kFloat, float)\n  TEST_TENSOR_FROM_AND_TO_BLOB(DType::kDouble, double)\n  TEST_TENSOR_FROM_AND_TO_BLOB(DType::kInt8, int8_t)\n  TEST_TENSOR_FROM_AND_TO_BLOB(DType::kInt32, int32_t)\n  TEST_TENSOR_FROM_AND_TO_BLOB(DType::kInt64, int64_t)\n}\n\nTEST(Api, tensor_zeros) {\n  EnvScope scope;\n\n  const auto shape = RandomShape();\n\n  std::vector<float> data(shape.Count(0)), target_data(shape.Count(0));\n\n  Tensor tensor(shape, Device(\"cpu\"), DType::kFloat);\n  tensor.zeros_();\n  tensor.copy_to(data.data());\n\n  std::fill(target_data.begin(), target_data.end(), 0);\n\n  ASSERT_EQ(data, target_data);\n}\n\n}  // namespace oneflow_api\n"
  },
  {
    "path": "oneflow/api/python/autograd/autograd.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <pybind11/pybind11.h>\n#include <memory>\n#include <utility>\n#include <vector>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/api/python/job_build/job_build_and_infer.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/scope_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/autograd/autograd_engine.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/saved_tensor_hooks.h\"\n#include \"oneflow/extension/stack/python/stack_getter.h\"\n\nnamespace oneflow {\nnamespace autograd {\n\nnamespace {\n\nbool IsScalarTensor(const one::Tensor& tensor) {\n  const auto& shape = tensor.shape();\n  return shape->elem_cnt() == 1;\n}\n\n// Checks and sets default value for initial gradients based on out_grads\n// If output is the tensor whose size is greater than 1, out_grad's shape must be same as output's.\n// If output is a scalar tensor, out_grad will also be a scaler or empty(will be initted to\n// `oneflow.ones([1])`).\nMaybe<one::TensorTuple> CheckAndInitOutGrads(const one::TensorTuple& outputs,\n                                             const one::TensorTuple& out_grads,\n                                             bool is_grads_batched) {\n  size_t grad_size = out_grads.empty() ? 
outputs.size() : out_grads.size();\n  auto gradients = std::make_shared<one::TensorTuple>(grad_size);\n  CHECK_EQ_OR_RETURN(outputs.size(), gradients->size())\n      << \"RuntimeError: got \" << outputs.size() << \" tensors and \" << gradients->size()\n      << \" gradients\";\n  for (int i = 0; i < outputs.size(); ++i) {\n    CHECK_OR_RETURN(outputs.at(i)->requires_grad())\n        << \"\\nRuntimeError: element \" << i\n        << \" of tensors does not require grad and does not have a grad_fn\";\n    if (!outputs.at(i)->grad_fn_node()) {\n      CHECK_OR_RETURN(outputs.at(i)->is_leaf())\n          << \"output[\" << i << \"] doesn't have grad_fn and it is not leaf tensor!\\n\"\n          << \"It is a bug with oneflow, please submit an issue on GitHub: \"\n             \"https://github.com/Oneflow-Inc/oneflow/issues\";\n      JUST(one::AddAccumulateFunctionNode(outputs.at(i)));\n    }\n    if (out_grads.empty() || !out_grads.at(i)) {\n      CHECK_OR_RETURN(IsScalarTensor(*outputs.at(i)))\n          << \"Grad can be implicitly created only for scalar outputs\";\n      gradients->at(i) = JUST(one::functional::OnesLike(outputs.at(i)));\n    } else {\n      if (is_grads_batched) {\n        if (*(outputs.at(i)->shape()) != *JUST(out_grads.at(i)->shape()->Slice(1))) {\n          THROW(RuntimeError) << \"If `is_grads_batched=True`, we interpret the first \"\n                              << \"dimension of each grad_output as the batch dimension. \"\n                              << \"The sizes of the remaining dimensions are expected to match \"\n                              << \"the shape of corresponding output, but a mismatch \"\n                              << \"was detected: grad_output[\" << i << \"] has a shape of \"\n                              << out_grads.at(i)->shape()->ToString() << \" and output[\" << i\n                              << \"] has a shape of \" << outputs.at(i)->shape()->ToString() << \".\";\n        }\n\n      } else {\n        CHECK_EQ_OR_RETURN(*(outputs.at(i)->shape()), *(out_grads.at(i)->shape()))\n            << \"out_grad's shape must be same as output's (\" << outputs.at(i)->shape()->ToString()\n            << \" vs \" << out_grads.at(i)->shape()->ToString() << \")\";\n      }\n      if (JUST(oneflow::VectorAt(outputs, i))->dtype()\n          != JUST(oneflow::VectorAt(out_grads, i))->dtype()) {\n        JUST(oneflow::VectorAt(*gradients, i)) =\n            JUST(one::functional::Cast(out_grads[i], outputs[i]->dtype(), /*pin_memory=*/false));\n      } else {\n        JUST(oneflow::VectorAt(*gradients, i)) = out_grads[i];\n      }\n    }\n  }\n  if (LazyMode::is_enabled()) { JUST(MarkOutputGradients(outputs, *gradients)); }\n  return gradients;\n}\n\n}  // namespace\n\nMaybe<one::TensorTuple> Backward(const one::TensorTuple& outputs, const one::TensorTuple& out_grads,\n                                 bool retain_graph, bool create_graph) {\n  PythonFrameGuard pf;\n  BackwardPassScopeGuard backward_guard;\n  if (create_graph) { retain_graph = true; }\n  std::shared_ptr<one::TensorTuple> gradients =\n      JUST(CheckAndInitOutGrads(outputs, out_grads, /*is_grads_batched=*/false));\n  JUST(one::GetThreadLocalAutogradEngine()->RunBackwardAndSaveGrads4LeafTensorIf(\n      outputs, *gradients, retain_graph, create_graph));\n  return std::make_shared<one::TensorTuple>(0);\n}\n\nMaybe<one::TensorTuple> Grad(const one::TensorTuple& outputs, const one::TensorTuple& inputs,\n                             const one::TensorTuple& out_grads, bool retain_graph,\n               
              bool create_graph, bool allow_unused, bool is_grads_batched) {\n  PythonFrameGuard pf;\n  BackwardPassScopeGuard backward_guard;\n  if (create_graph) { retain_graph = true; }\n  if (inputs.empty()) { return Backward(outputs, out_grads, retain_graph, create_graph); }\n  CHECK_OR_RETURN(std::all_of(\n      inputs.begin(), inputs.end(),\n      [](const std::shared_ptr<one::Tensor>& tensor) { return tensor->requires_grad(); }))\n      << \"All input tensors `.requires_grad` should be true\";\n  std::shared_ptr<one::TensorTuple> gradients =\n      JUST(CheckAndInitOutGrads(outputs, out_grads, is_grads_batched));\n  return one::GetThreadLocalAutogradEngine()->RunBackwardAndReturnInputsTensorGradIf(\n      outputs, inputs, *gradients, retain_graph, create_graph, allow_unused);\n}\n\nnamespace py = pybind11;\n\nclass PySavedTensorHook final : public one::SavedTensorHook {\n public:\n  PySavedTensorHook(const py::function& pack_hook, const py::function& unpack_hook)\n      : pack_hook_(pack_hook), unpack_hook_(unpack_hook) {}\n\n  void pack(const std::shared_ptr<one::Tensor>& tensor) {\n    py::gil_scoped_acquire acquire;\n    py::object packed = pack_hook_(tensor);\n    data_ = packed.release().ptr();\n  }\n  std::shared_ptr<one::Tensor> unpack() {\n    py::gil_scoped_acquire acquire;\n    py::object obj = py::cast<py::object>(data_);\n    py::object x = unpack_hook_(obj);\n    std::shared_ptr<one::Tensor> tensor;\n    try {\n      tensor = py::cast<std::shared_ptr<one::Tensor>>(x);\n    } catch (const py::cast_error& e) {\n      THROW(RuntimeError) << \"unpack_hook should return a Tensor, but got `\"\n                          << py::str(x.get_type()).cast<std::string>() << \"` instead\";\n    }\n    return tensor;\n  }\n\n private:\n  PyObject* data_ = nullptr;\n  py::function pack_hook_;\n  py::function unpack_hook_;\n};\n\nclass PySavedTensorHookCreator final : public one::SavedTensorHookCreator {\n public:\n  std::unique_ptr<one::SavedTensorHook> new_saved_tensor_hook() const override {\n    if (hooks_.empty()) { return nullptr; }\n    return std::make_unique<PySavedTensorHook>(hooks_.back().first, hooks_.back().second);\n  }\n  void append_new_hooks(const py::function& pack_hook, const py::function& unpack_hook) {\n    hooks_.emplace_back(pack_hook, unpack_hook);\n  }\n  void pop_hooks() {\n    CHECK_OR_THROW(!hooks_.empty()) << \"pop_hooks should not be called when there are no hooks\";\n    hooks_.pop_back();\n  }\n\n private:\n  small_vector<std::pair<py::function, py::function>, 1> hooks_;\n};\n\nONEFLOW_API_PYBIND11_MODULE(\"autograd\", m) {\n  m.def(\"backward\", &Backward);\n  m.def(\"grad\", &Grad);\n  m.def_submodule(\"graph\")\n      .def(\"register_saved_tensors_hook_manager\",\n           []() {\n             Singleton<one::SavedTensorHookCreator>::SetAllocated(new PySavedTensorHookCreator());\n           })\n      .def(\"append_new_hooks\",\n           [](const py::function& pack_hook, const py::function& unpack_hook) {\n             PySavedTensorHookCreator* creator = dynamic_cast<PySavedTensorHookCreator*>(\n                 Singleton<one::SavedTensorHookCreator>::Get());\n             CHECK_NOTNULL_OR_THROW(creator)\n                 << \"`register_saved_tensors_hook_manager` should be called \"\n                    \"before calling `append_new_hooks`\";\n             creator->append_new_hooks(pack_hook, unpack_hook);\n           })\n      .def(\"pop_hooks\", []() {\n        PySavedTensorHookCreator* creator =\n            
dynamic_cast<PySavedTensorHookCreator*>(Singleton<one::SavedTensorHookCreator>::Get());\n        CHECK_NOTNULL_OR_THROW(creator) << \"`register_saved_tensors_hook_manager` should be called \"\n                                           \"before calling `pop_hooks`\";\n        creator->pop_hooks();\n      });\n}\n\n}  // namespace autograd\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/autograd/autograd_engine.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/global_param_grad_sync_mode.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  py::class_<GlobalParamGradSyncMode, std::shared_ptr<GlobalParamGradSyncMode>>(\n      m, \"GlobalParamGradSyncMode\")\n      .def(py::init([](bool flag) { return std::make_shared<GlobalParamGradSyncMode>(flag); }));\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/autograd/autograd_function.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <memory>\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include <pybind11/functional.h>\n\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/core/autograd/autograd_function.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nnamespace {\n\n// Transform input to TensorTuple\nMaybe<one::TensorTuple> UnpackTensorTuple(const py::object& input) {\n  one::TensorTuple tp;\n  if (one::PyTensor_Check(input.ptr())) {\n    tp.emplace_back(input.cast<std::shared_ptr<one::Tensor>>());\n  } else if (py::isinstance<py::tuple>(input)) {\n    auto tuple = input.cast<py::tuple>();\n    tp.resize(tuple.size());\n    for (int i = 0; i < tuple.size(); ++i) {\n      PyObject* obj = tuple[i].ptr();\n      if (obj == Py_None) {\n        // do nothing\n      } else if (one::PyTensor_Check(obj)) {\n        tp[i] = one::PyTensor_Unpack(obj);\n      } else {\n        return Error::RuntimeError()\n               << \"expected Tensor or None as element \" << i << \", but got \"\n               << one::functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(obj)));\n      }\n    }\n  } else {\n    return Error::RuntimeError()\n           << \"autograd.Function's output only support tensor or list of tensors\";\n  }\n  return tp;\n}\n\n// Return single Tensor when TensorTuple's size is one, otherwise py::tuple\npy::object PackTensorTuple(const one::TensorTuple& tp) {\n  if (tp.size() == 1) {\n    return py::cast(tp.at(0));\n  } else {\n    py::tuple out = py::tuple(tp.size());\n    for (int i = 0; i < tp.size(); ++i) { out[i] = tp.at(i); }\n    return py::cast<py::object>(out);\n  }\n}\n\n// wrap PyFunction, unpack the inputs from TensorTuple and pack outputs to TensorTuple\none::AutogradFunctionBase::FType PackPyFunctionToFType(const py::function& func) {\n  return [func](const std::shared_ptr<one::FunctionAutoGradCaptureState>& ctx,\n                const one::TensorTuple& inputs) {\n    const py::tuple& a = py::cast(inputs);\n    py::object res = func(ctx, *a);\n    return UnpackTensorTuple(res).GetPtrOrThrow();\n  };\n}\n\n}  // namespace\n\nnamespace one {\n\nONEFLOW_API_PYBIND11_MODULE(\"autograd\", m) {\n  py::class_<AutogradFunctionBase, std::shared_ptr<AutogradFunctionBase>>(m, \"AutogradFunctionBase\")\n      .def(py::init([]() { return std::make_shared<AutogradFunctionBase>(); }))\n      .def_static(\"apply\",\n                  [](const std::string& name, const py::function& forward_fn,\n                     const py::function& backward_fn, const py::args& input) -> Maybe<py::object> {\n                    const auto& input_tensor_tuple = JUST(UnpackTensorTuple(input));\n                    const std::shared_ptr<TensorTuple>& res = JUST(AutogradFunctionBase::Apply(\n               
         name, PackPyFunctionToFType(forward_fn), PackPyFunctionToFType(backward_fn),\n                        *input_tensor_tuple));\n                    return PackTensorTuple(*res);\n                  });\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/autograd/autograd_function_state.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/python/autograd/autograd_function_state.h\"\n\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/exception/exception.h\"\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n\nnamespace py = pybind11;\nnamespace oneflow {\nnamespace one {\nnamespace {\ninline FunctionAutoGradCaptureState* CheckAndGetStateData(PyAutogradFunctionState* state) {\n  if (!state->data.lock()) {\n    PyErr_Format(PyExc_RuntimeError, \"Data is deallocated. Please don't hold context outside \"\n                                     \"autograd.Function.forward or autograd.Function.backward\");\n    return nullptr;\n  }\n  return state->data.lock().get();\n}\n}  // namespace\n\n#if PY_VERSION_HEX < 0x03070000\n#define PYGETSET_NAME(name) const_cast<char*>(name)\n#else\n#define PYGETSET_NAME(name) (name)\n#endif\n\n#define PY_XINCREF(p) (({ Py_XINCREF(p); }), (p))\n\nstatic PyObject* PyAutogradFunctionState_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {\n  PyAutogradFunctionState* self = (PyAutogradFunctionState*)type->tp_alloc(type, 0);\n  if (self != NULL) {\n    self->dynamic_attr_dict = PyDict_New();\n    if (self->dynamic_attr_dict == NULL) {\n      Py_DECREF(self);\n      return NULL;\n    }\n  }\n  return (PyObject*)self;\n}\n\nstatic void PyAutogradFunctionState_dealloc(PyAutogradFunctionState* self) {\n  Py_XDECREF(self->dynamic_attr_dict);\n  Py_TYPE(self)->tp_free((PyObject*)self);\n}\n\n// PyMethodDef start\nstatic PyObject* PyAutogradFunctionState_save_for_backward(PyObject* self, PyObject* args) {\n  HANDLE_ERRORS\n  auto* _self = (PyAutogradFunctionState*)self;\n  if (!functional::PyTensorSequenceCheck(args)) {\n    return PyErr_Format(PyExc_TypeError, \"save_for_backward() only support Tensor or Tensors\");\n  }\n  const std::vector<std::shared_ptr<Tensor>>& tensor_list =\n      functional::PyUnpackTensorSequence(args);\n  for (const auto& tensor : tensor_list) {\n    CheckAndGetStateData(_self)->SaveTensorForBackward(tensor);\n  }\n  Py_RETURN_NONE;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyAutogradFunctionState_mark_non_differentiable(PyObject* self, PyObject* args) {\n  HANDLE_ERRORS\n  auto* _self = (PyAutogradFunctionState*)self;\n  if (!functional::PyTensorSequenceCheck(args)) {\n    return PyErr_Format(PyExc_TypeError, \"save_for_backward() only support Tensor or Tensors\");\n  }\n  const std::vector<std::shared_ptr<Tensor>>& tensor_list =\n      functional::PyUnpackTensorSequence(args);\n  for (const auto& tensor : tensor_list) {\n    CheckAndGetStateData(_self)->MarkNonDifferentiable(tensor);\n  }\n  Py_RETURN_NONE;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyAutogradFunctionState_is_data_valid(PyObject* self) {\n  auto* _self = (PyAutogradFunctionState*)self;\n  return functional::CastToPyObject(_self->data.lock() != nullptr);\n}\n\nstatic PyMethodDef 
PyAutogradFunctionState_methods[] = {\n    {\"save_for_backward\", (PyCFunction)PyAutogradFunctionState_save_for_backward, METH_VARARGS,\n     NULL},\n    {\"mark_non_differentiable\", (PyCFunction)PyAutogradFunctionState_mark_non_differentiable,\n     METH_VARARGS, NULL},\n    {\"_is_data_valid\", (PyCFunction)PyAutogradFunctionState_is_data_valid, METH_NOARGS, NULL},\n    {NULL} /* Sentinel */\n};\n// PyMethodDef end\n\n// PyAutogradFunctionState_getset start\nstatic PyObject* PyAutogradFunctionState_saved_tensors(PyObject* self, void*) {\n  auto* _self = (PyAutogradFunctionState*)self;\n  auto* data = CheckAndGetStateData(_self);\n  if (data == nullptr) { return NULL; }\n  return functional::CastToPyObject<Maybe<TensorTuple>>(data->SavedTensors());\n}\n\nstatic PyObject* PyAutogradFunctionState_get_dict(PyObject* self, PyObject* args) {\n  HANDLE_ERRORS\n  auto* _self = (PyAutogradFunctionState*)self;\n  Py_XINCREF(_self->dynamic_attr_dict);  // the getter must return a new reference\n  return _self->dynamic_attr_dict;\n  END_HANDLE_ERRORS\n}\n\nstatic PyGetSetDef PyAutogradFunctionState_properties[] = {\n    {PYGETSET_NAME(\"saved_tensors\"), (getter)PyAutogradFunctionState_saved_tensors, NULL, NULL,\n     NULL},\n    {PYGETSET_NAME(\"__dict__\"), (getter)PyAutogradFunctionState_get_dict, NULL, NULL, NULL},\n    {NULL} /* Sentinel */\n};\n// PyAutogradFunctionState_getset end\n\nPyObject* PyAutogradFunctionState_getattro(PyObject* self, PyObject* attr) {\n  PyObject* res = PyDict_GetItem(((PyAutogradFunctionState*)self)->dynamic_attr_dict, attr);\n  if (res) {\n    Py_INCREF(res);  // PyDict_GetItem returns a borrowed reference\n    return res;\n  }\n  // The attr is not found in dynamic_attr_dict, so try to find it in tp_dict.\n  res = PyObject_GenericGetAttr(self, attr);\n  if (!res) {\n    return PyErr_Format(PyExc_AttributeError, \"attribute %s not found\", PyUnicode_AsUTF8(attr));\n  }\n  return res;\n}\n\nint PyAutogradFunctionState_setattro(PyObject* self, PyObject* attr, PyObject* value) {\n  auto* _self = (PyAutogradFunctionState*)self;\n  return PyDict_SetItem(_self->dynamic_attr_dict, attr, value);\n}\n\nPyTypeObject PyAutogradFunctionState_Type = {\n    PyVarObject_HEAD_INIT(NULL, 0) \"oneflow.autograd.Function.FunctionCtx\", /* tp_name */\n    sizeof(PyAutogradFunctionState),                                        /* tp_basicsize */\n    0,                                                                      /* tp_itemsize */\n    (destructor)PyAutogradFunctionState_dealloc,                            /* tp_dealloc */\n    0,                                                    /* tp_vectorcall_offset */\n    NULL,                                                 /* tp_getattr */\n    NULL,                                                 /* tp_setattr */\n    NULL,                                                 /* tp_reserved */\n    NULL,                                                 /* tp_repr */\n    NULL,                                                 /* tp_as_number */\n    NULL,                                                 /* tp_as_sequence */\n    NULL,                                                 /* tp_as_mapping */\n    NULL,                                                 /* tp_hash  */\n    NULL,                                                 /* tp_call */\n    NULL,                                                 /* tp_str */\n    PyAutogradFunctionState_getattro,                     /* tp_getattro */\n    PyAutogradFunctionState_setattro,                     /* tp_setattro */\n    NULL,                                                 /* tp_as_buffer */\n    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,             /* tp_flags */\n    NULL,     
                                            /* tp_doc */\n    NULL,                                                 /* tp_traverse */\n    NULL,                                                 /* tp_clear */\n    NULL,                                                 /* tp_richcompare */\n    0,                                                    /* tp_weaklistoffset */\n    NULL,                                                 /* tp_iter */\n    NULL,                                                 /* tp_iternext */\n    PyAutogradFunctionState_methods,                      /* tp_methods */\n    NULL,                                                 /* tp_members */\n    PyAutogradFunctionState_properties,                   /* tp_getset */\n    0,                                                    /* tp_base */\n    NULL,                                                 /* tp_dict */\n    NULL,                                                 /* tp_descr_get */\n    NULL,                                                 /* tp_descr_set */\n    offsetof(PyAutogradFunctionState, dynamic_attr_dict), /* tp_dictoffset */\n    NULL,                                                 /* tp_init */\n    NULL,                                                 /* tp_alloc */\n    PyAutogradFunctionState_new,                          /* tp_new */\n    NULL,                                                 /* tp_free */\n};\n\nPyObject* PyAutogradFunctionState_NewFromPtr(\n    const std::shared_ptr<FunctionAutoGradCaptureState>& data) {\n  if (!data) { Py_RETURN_NONE; }\n  if (data->pyobject()) { return PY_XINCREF((PyObject*)data->pyobject()); }\n  auto* self = (PyAutogradFunctionState*)(PyObject_CallObject(\n      (PyObject*)&PyAutogradFunctionState_Type, NULL));\n  if (self) {\n    PY_XINCREF(self);\n    self->data = data;\n    CheckAndGetStateData(self)->set_pyobject_ptr(\n        std::unique_ptr<void, void (*)(void*)>(self, [](void* ptr) { Py_DECREF((PyObject*)ptr); }));\n  }\n  return (PyObject*)self;\n}\n\nONEFLOW_API_PYBIND11_MODULE(\"autograd.Function\", m) {\n  if (PyType_Ready(&PyAutogradFunctionState_Type) < 0) { return; }\n  Py_INCREF(&PyAutogradFunctionState_Type);\n  if (PyModule_AddObject(m.ptr(), \"FunctionCtx\", (PyObject*)&PyAutogradFunctionState_Type) < 0) {\n    return;\n  }\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/autograd/autograd_function_state.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_AUTOGRAD_AUTOGRAD_FUNCTION_STATE_H_\n#define ONEFLOW_API_PYTHON_AUTOGRAD_AUTOGRAD_FUNCTION_STATE_H_\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include <pybind11/pybind11.h>\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\ntypedef struct {\n  PyObject_HEAD;\n  PyObject* dynamic_attr_dict;\n  std::weak_ptr<FunctionAutoGradCaptureState> data;\n} PyAutogradFunctionState;\n\nextern PyTypeObject PyAutogradFunctionState_Type;\n\ninline bool PyAutogradFunctionState_Check(PyObject* state) {\n  return PyObject_TypeCheck(state, &PyAutogradFunctionState_Type);\n}\n\nPyObject* PyAutogradFunctionState_NewFromPtr(\n    const std::shared_ptr<FunctionAutoGradCaptureState>& data);\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_AUTOGRAD_AUTOGRAD_FUNCTION_STATE_H_\n"
  },
  {
    "path": "oneflow/api/python/autograd/autograd_mode.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <memory>\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/autograd/autograd_mode.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nnamespace autograd {\n\nONEFLOW_API_PYBIND11_MODULE(\"autograd\", m) {\n  py::class_<AutoGradMode, std::shared_ptr<AutoGradMode>>(m, \"AutoGradMode\")\n      .def(py::init([](bool mode) { return std::make_shared<AutoGradMode>(mode); }))\n      .def(\"__enter__\", [](const AutoGradMode& no_grad_obj) {})\n      .def(\"__exit__\", [](const AutoGradMode& no_grad_obj, const py::object& type,\n                          const py::object& value, const py::object& traceback) {});\n  m.def(\"is_grad_enabled\", &GradMode::is_enabled);\n  m.def(\"set_grad_enabled\", &GradMode::set_enabled);\n}\n\n}  // namespace autograd\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/autograd/function_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <vector>\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include <pybind11/functional.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/autograd/autograd_engine.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nnamespace {\n\nstruct FunctionNodeUtil final {\n  static std::string ToString(const one::FunctionNode& func_node) {\n    std::stringstream ss;\n    ss << \"<\";\n    ss << func_node.name();\n    ss << \" at \" << &func_node;\n    ss << \">\";\n    return ss.str();\n  }\n};\n\n}  // namespace\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  py::class_<one::FunctionNode, std::shared_ptr<one::FunctionNode>>(m, \"FunctionNode\")\n      .def(\"__str__\", &FunctionNodeUtil::ToString)\n      .def(\"__repr__\", &FunctionNodeUtil::ToString)\n      .def(\"_register_hook_dict\", []() { TODO(); })\n      .def_property_readonly(\n          \"next_functions\",\n          [](const one::FunctionNode& func_node) { return func_node.next_functions(); })\n      .def_property_readonly(\"metadata\", []() { TODO(); })\n      .def_property_readonly(\"requires_grad\", []() { TODO(); })\n      .def(\"register_hook\", &one::FunctionNode::add_post_hook)\n      .def(\"name\", [](const one::FunctionNode& func_node) { return func_node.name(); })\n      .def_property_readonly(\n          \"variable\", [](const one::FunctionNode& func_node) { return func_node.Variable(); });\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/caster/autograd_function_state.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_CASTER_AUTOGRAD_FUNCTION_STATE_H_\n#define ONEFLOW_API_PYTHON_CASTER_AUTOGRAD_FUNCTION_STATE_H_\n\n#include <pybind11/pybind11.h>\n\n#include \"oneflow/api/python/caster/common.h\"\n#include \"oneflow/api/python/autograd/autograd_function_state.h\"\n\nnamespace py = pybind11;\n\nnamespace pybind11 {\nnamespace detail {\n\ntemplate<typename T>\nstruct autograd_function_state_type_caster {\n public:\n  bool load(handle src, bool convert) {\n    using namespace oneflow::one;\n    value_ = nullptr;\n    if (!src) { return false; }\n    if (src.is_none()) { return true; }\n    if (!PyAutogradFunctionState_Check(src.ptr())) { return false; }\n    value_ = ((PyAutogradFunctionState*)src.ptr())->data;\n    return true;\n  }\n\n  template<typename U>\n  static handle cast(U&& src, return_value_policy policy, handle parent) {\n    using namespace oneflow::one;\n    return reinterpret_steal<object>(\n               PyAutogradFunctionState_NewFromPtr(\n                   std::const_pointer_cast<FunctionAutoGradCaptureState>(src)))\n        .release();\n  }\n\n  operator std::shared_ptr<T>*() { return &value_; }\n  operator std::shared_ptr<T>&() { return value_; }\n  operator std::shared_ptr<T>&&() && { return std::move(value_); }\n\n  static constexpr auto name = _(\"autograd_function_state\");\n\n protected:\n  std::shared_ptr<T> value_;\n};\n\ntemplate<>\nstruct type_caster<std::shared_ptr<oneflow::one::FunctionAutoGradCaptureState>>\n    : public autograd_function_state_type_caster<oneflow::one::FunctionAutoGradCaptureState> {};\ntemplate<>\nstruct type_caster<std::shared_ptr<const oneflow::one::FunctionAutoGradCaptureState>>\n    : public autograd_function_state_type_caster<const oneflow::one::FunctionAutoGradCaptureState> {\n};\n\n}  // namespace detail\n}  // namespace pybind11\n\n#endif  // ONEFLOW_API_PYTHON_CASTER_AUTOGRAD_FUNCTION_STATE_H_\n"
  },
  {
    "path": "oneflow/api/python/caster/common.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_CASTER_COMMON_H_\n#define ONEFLOW_API_PYTHON_CASTER_COMMON_H_\n\n#include <type_traits>\n#include <pybind11/pybind11.h>\n\nnamespace pybind11 {\nnamespace detail {\n\n// The condition follows the pybind11 source code\ntemplate<typename T>\nusing IsSupportedByPybind11WhenInsideSharedPtr =\n    std::is_base_of<type_caster_base<T>, type_caster<T>>;\n\n#define PYBIND11_TYPE_CASTER_WITH_SHARED_PTR(type, py_name)                               \\\n protected:                                                                               \\\n  std::shared_ptr<type> value;                                                            \\\n                                                                                          \\\n public:                                                                                  \\\n  static constexpr auto name = py_name;                                                   \\\n  template<typename T_, enable_if_t<std::is_same<type, remove_cv_t<T_>>::value, int> = 0> \\\n  static handle cast(T_* src, return_value_policy policy, handle parent) {                \\\n    if (!src) return none().release();                                                    \\\n    if (policy == return_value_policy::take_ownership) {                                  \\\n      auto h = cast(std::move(*src), policy, parent);                                     \\\n      delete src;                                                                         \\\n      return h;                                                                           \\\n    }                                                                                     \\\n    return cast(*src, policy, parent);                                                    \\\n  }                                                                                       \\\n  operator type*() { return value.get(); }                                                \\\n  operator type&() { return *value; }                                                     \\\n  operator type&&()&& { return std::move(*value); }                                       \\\n  template<typename T_>                                                                   \\\n  using cast_op_type = pybind11::detail::movable_cast_op_type<T_>\n\n}  // namespace detail\n}  // namespace pybind11\n\n#endif  // ONEFLOW_API_PYTHON_CASTER_COMMON_H_\n"
  },
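  {
    "path": "docs/examples/shared_ptr_caster_sketch.h",
    "content": "// Editor's note: this file is a hypothetical sketch, not part of OneFlow. It shows\n// how the PYBIND11_TYPE_CASTER_WITH_SHARED_PTR macro defined in\n// oneflow/api/python/caster/common.h is meant to be used: the macro supplies the\n// `std::shared_ptr<type> value` slot, the `name` constant, a pointer-based `cast`\n// overload and the conversion operators, so a custom caster only has to provide\n// `load` and a value-based `cast`. `IntBox` is an invented toy type.\n#include <memory>\n\n#include <pybind11/pybind11.h>\n\n#include \"oneflow/api/python/caster/common.h\"\n\nstruct IntBox {\n  int v = 0;\n};\n\nnamespace pybind11 {\nnamespace detail {\n\ntemplate<>\nstruct type_caster<IntBox> {\n  // Conversion part 1 (Python -> C++): accept a Python int.\n  bool load(handle src, bool /*convert*/) {\n    if (!src || !PyLong_Check(src.ptr())) { return false; }\n    value = std::make_shared<IntBox>();\n    value->v = static_cast<int>(PyLong_AsLong(src.ptr()));\n    return true;\n  }\n\n  // Conversion part 2 (C++ -> Python); the macro's pointer overload forwards here.\n  static handle cast(const IntBox& src, return_value_policy /*policy*/, handle /*parent*/) {\n    return PyLong_FromLong(src.v);\n  }\n\n  PYBIND11_TYPE_CASTER_WITH_SHARED_PTR(IntBox, _(\"IntBox\"));\n};\n\n}  // namespace detail\n}  // namespace pybind11\n"
  },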
  {
    "path": "oneflow/api/python/caster/maybe.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_CASTER_MAYBE_H_\n#define ONEFLOW_API_PYTHON_CASTER_MAYBE_H_\n#include <pybind11/pybind11.h>\n\n#include \"oneflow/api/python/caster/common.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace pybind11 {\nnamespace detail {\n\nusing oneflow::Maybe;\n\nnamespace impl {\n\ntemplate<typename T>\nusing IsHoldedInsideSharedPtrByMaybe =\n    std::is_same<decltype(\n                     std::declval<Maybe<T>>().Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()),\n                 std::shared_ptr<T>>;\n\ntemplate<typename T, typename std::enable_if_t<IsSupportedByPybind11WhenInsideSharedPtr<T>::value\n                                                   && IsHoldedInsideSharedPtrByMaybe<T>::value,\n                                               int> = 0>\nstd::shared_ptr<T> GetOrThrowHelper(Maybe<T> x) {\n  return x.GetPtrOrThrow();\n}\n\ntemplate<typename T, typename std::enable_if_t<!IsSupportedByPybind11WhenInsideSharedPtr<T>::value\n                                                   || !IsHoldedInsideSharedPtrByMaybe<T>::value,\n                                               int> = 0>\nT GetOrThrowHelper(Maybe<T> x) {\n  return x.GetOrThrow();\n}\n\n}  // namespace impl\n\n// Information about pybind11 custom type caster can be found\n// at oneflow/api/python/caster/optional.h, and also at\n// https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html\ntemplate<typename Type>\nstruct maybe_caster {\n  using Value = decltype(impl::GetOrThrowHelper(std::declval<Type>()));\n  using value_conv = make_caster<Value>;\n\n  bool load(handle src, bool convert) {\n    if (!src) { return false; }\n    if (src.is_none()) {\n      // Maybe<T> (except Maybe<void>) does not accept `None` from Python. Users can use Optional in\n      // those cases.\n      return false;\n    }\n    value_conv inner_caster;\n    if (!inner_caster.load(src, convert)) { return false; }\n\n    value = std::make_shared<Type>(cast_op<Value&&>(std::move(inner_caster)));\n    return true;\n  }\n\n  template<typename T>\n  static handle cast(T&& src, return_value_policy policy, handle parent) {\n    if (!std::is_lvalue_reference<T>::value) {\n      policy = return_value_policy_override<Value>::policy(policy);\n    }\n    return value_conv::cast(impl::GetOrThrowHelper(std::forward<T>(src)), policy, parent);\n  }\n\n  PYBIND11_TYPE_CASTER_WITH_SHARED_PTR(Maybe<void>, _(\"Maybe[void]\"));\n};\n\ntemplate<>\nstruct maybe_caster<Maybe<void>> {\n  template<typename T>\n  static handle cast(T&& src, return_value_policy policy, handle parent) {\n    if (!src.IsOk()) { oneflow::ThrowError(src.stacked_error()); }\n    return none().inc_ref();\n  }\n\n  bool load(handle src, bool convert) {\n    if (src && src.is_none()) {\n      return true;  // None is accepted because NoneType (i.e. 
void) is the value type of\n                    // Maybe<void>\n    }\n    return false;\n  }\n\n  PYBIND11_TYPE_CASTER_WITH_SHARED_PTR(Maybe<void>, _(\"Maybe[void]\"));\n};\n\ntemplate<typename T>\nstruct type_caster<Maybe<T>> : public maybe_caster<Maybe<T>> {};\n\n}  // namespace detail\n}  // namespace pybind11\n\n#endif  // ONEFLOW_API_PYTHON_CASTER_MAYBE_H_\n"
  },
  {
    "path": "oneflow/api/python/caster/optional.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_CASTER_OPTIONAL_H_\n#define ONEFLOW_API_PYTHON_CASTER_OPTIONAL_H_\n\n#include <pybind11/pybind11.h>\n\n#include \"oneflow/api/python/caster/common.h\"\n#include \"oneflow/core/common/optional.h\"\n\nnamespace pybind11 {\nnamespace detail {\n\nusing oneflow::Optional;\n\nnamespace impl {\n\ntemplate<typename T>\nT& DeferenceIfSharedPtr(std::shared_ptr<T> ptr) {\n  return *ptr;\n}\n\ntemplate<typename T>\nT&& DeferenceIfSharedPtr(T&& obj) {\n  return std::forward<T>(obj);\n}\n\ntemplate<typename T>\nusing IsHoldedInsideSharedPtrByOptional =\n    std::is_same<typename Optional<T>::storage_type, std::shared_ptr<T>>;\n\ntemplate<typename T, typename std::enable_if_t<IsSupportedByPybind11WhenInsideSharedPtr<T>::value\n                                                   && IsHoldedInsideSharedPtrByOptional<T>::value,\n                                               int> = 0>\nstd::shared_ptr<T> GetDataHelper(Optional<T> x) {\n  return CHECK_JUST(x);\n}\n\ntemplate<typename T, typename std::enable_if_t<!IsSupportedByPybind11WhenInsideSharedPtr<T>::value\n                                                   || !IsHoldedInsideSharedPtrByOptional<T>::value,\n                                               int> = 0>\nT GetDataHelper(Optional<T> x) {\n  return DeferenceIfSharedPtr<T>(CHECK_JUST(x));\n}\n\n}  // namespace impl\n\n// Code is copied from pybind11 include/pybind11/stl.h\n// Comments wrapped by /* */ are copied from\n// https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html\ntemplate<typename Type>\nstruct oneflow_optional_caster {\n  using Value = decltype(impl::GetDataHelper(std::declval<Type>()));\n  using value_conv = make_caster<Value>;\n\n  /**\n   * Conversion part 1 (Python->C++): convert a PyObject into a Optional<T>\n   * instance or return false upon failure. The second argument\n   * indicates whether implicit conversions should be applied.\n   */\n  bool load(handle src, bool convert) {\n    if (!src) { return false; }\n    if (src.is_none()) {\n      return true;  // default-constructed value is already empty\n    }\n    value_conv inner_caster;\n    if (!inner_caster.load(src, convert)) { return false; }\n\n    value = cast_op<Value&&>(std::move(inner_caster));\n    return true;\n  }\n\n  /**\n   * Conversion part 2 (C++ -> Python): convert an Optional<T> instance into\n   * a Python object. 
The second and third arguments are used to\n   * indicate the return value policy and parent object (for\n   * ``return_value_policy::reference_internal``) and are generally\n   * ignored by implicit casters.\n   */\n  template<typename T>\n  static handle cast(T&& src, return_value_policy policy, handle parent) {\n    if (!src) { return none().inc_ref(); }\n    if (!std::is_lvalue_reference<T>::value) {\n      policy = return_value_policy_override<Value>::policy(policy);\n    }\n    return value_conv::cast(impl::GetDataHelper(std::forward<T>(src)), policy, parent);\n  }\n\n  /**\n   * This macro establishes the name 'Optional[T]' in\n   * function signatures and declares a local variable\n   * 'value' of type inty\n   */\n  PYBIND11_TYPE_CASTER(Type, _(\"Optional[\") + value_conv::name + _(\"]\"));\n};\n\ntemplate<typename T>\nstruct type_caster<Optional<T>> : public oneflow_optional_caster<Optional<T>> {};\n\n}  // namespace detail\n}  // namespace pybind11\n\n#endif  // ONEFLOW_API_PYTHON_CASTER_OPTIONAL_H_\n"
  },
  {
    "path": "oneflow/api/python/caster/size.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_CASTER_SIZE_H_\n#define ONEFLOW_API_PYTHON_CASTER_SIZE_H_\n#include <type_traits>\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include <pybind11/pybind11.h>\n\n#include \"oneflow/api/python/framework/size.h\"\n#include \"oneflow/core/common/shape.h\"\n\nPYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)\n\nclass shape : public object {\n public:\n  PYBIND11_OBJECT_CVT(shape, object, oneflow::TensorSize_Check, raw_shape)\n  explicit shape(size_t size = 0) : object(oneflow::TensorSize_New((ssize_t)size), stolen_t{}) {\n    if (!m_ptr) pybind11_fail(\"Could not allocate tensor size object!\");\n  }\n  size_t size() const { return (size_t)PyTuple_Size(m_ptr); }\n  bool empty() const { return size() == 0; }\n  detail::tuple_accessor operator[](size_t index) const { return {*this, index}; }\n  detail::item_accessor operator[](handle h) const { return object::operator[](h); }\n  detail::tuple_iterator begin() const { return {*this, 0}; }\n  detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; }\n\n private:\n  static PyObject* raw_shape(PyObject* op) {\n    if (oneflow::TensorSize_Check(op)) return handle(op).inc_ref().ptr();\n    return PyObject_CallFunctionObjArgs((PyObject*)&oneflow::TensorSize_Type, op, NULL);\n  }\n};\n\nPYBIND11_NAMESPACE_BEGIN(detail)\n\ntemplate<typename T>\nstruct shape_type_caster {\n public:\n  bool load(handle src, bool convert) {\n    value_ = nullptr;\n    if (src && src.is_none()) { return true; }\n    if (!oneflow::TensorSize_Check(src.ptr())) { return false; }\n    value_ = std::make_shared<T>(oneflow::TensorSize_AsShape(src.ptr()));\n    return true;\n  }\n\n  template<typename U>\n  static handle cast(U&& src, return_value_policy /*policy*/, handle /*parent*/) {\n    return cast_impl(std::forward<U>(src));\n  }\n\n  template<typename U>\n  static handle cast(U* src, return_value_policy policy, handle parent) {\n    if (!src) { return none().release(); }\n    return cast(*src, policy, parent);\n  }\n\n  operator T*() { return value_.get(); }\n  operator T&() { return *value_; }\n  operator T&&() && { return std::move(*value_); }\n\n  operator std::shared_ptr<T>*() { return &value_; }\n  operator std::shared_ptr<T>&() { return value_; }\n  operator std::shared_ptr<T>&&() && { return std::move(value_); }\n\n  static constexpr auto name = _(\"shape\");\n  template<typename U>\n  using cast_op_type = pybind11::detail::cast_op_type<std::shared_ptr<T>>;\n\n private:\n  static handle cast_impl(const oneflow::Shape& src) {\n    return reinterpret_steal<shape>(oneflow::TensorSize_NewFromShape(src)).release();\n  }\n  static handle cast_impl(const std::shared_ptr<const oneflow::Shape>& src) {\n    return reinterpret_steal<shape>(oneflow::TensorSize_NewFromShape(*src)).release();\n  }\n\n protected:\n  std::shared_ptr<T> value_;\n};\n\ntemplate<>\nstruct type_caster<oneflow::Shape> : public 
shape_type_caster<oneflow::Shape> {};\ntemplate<>\nstruct type_caster<std::shared_ptr<oneflow::Shape>> : public shape_type_caster<oneflow::Shape> {};\ntemplate<>\nstruct type_caster<std::shared_ptr<const oneflow::Shape>>\n    : public shape_type_caster<const oneflow::Shape> {};\n\nPYBIND11_NAMESPACE_END(detail)\nPYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)\n\n#endif  // ONEFLOW_API_PYTHON_CASTER_SIZE_H_\n"
  },
  {
    "path": "oneflow/api/python/caster/tensor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_CASTER_TENSOR_H_\n#define ONEFLOW_API_PYTHON_CASTER_TENSOR_H_\n\n#include <pybind11/pybind11.h>\n\n#include \"oneflow/api/python/caster/common.h\"\n#include \"oneflow/api/python/framework/tensor.h\"\n\nnamespace pybind11 {\nnamespace detail {\n\ntemplate<typename T>\nstruct tensor_type_caster {\n public:\n  bool load(handle src, bool convert) {\n    using namespace oneflow::one;\n    value_ = nullptr;\n    if (!src) { return false; }\n    if (src.is_none()) { return true; }\n    if (!PyTensor_Check(src.ptr())) { return false; }\n    value_ = PyTensor_Unpack(src.ptr());\n    return true;\n  }\n\n  template<typename U>\n  static handle cast(U&& src, return_value_policy policy, handle parent) {\n    using namespace oneflow::one;\n    return reinterpret_steal<object>(PyTensor_New(std::const_pointer_cast<Tensor>(src))).release();\n  }\n\n  operator std::shared_ptr<T>*() { return &value_; }\n  operator std::shared_ptr<T>&() { return value_; }\n  operator std::shared_ptr<T>&&() && { return std::move(value_); }\n\n  static constexpr auto name = _(\"tensor\");\n  template<typename U>\n  using cast_op_type = pybind11::detail::cast_op_type<std::shared_ptr<T>>;\n\n protected:\n  std::shared_ptr<T> value_;\n};\n\ntemplate<typename T>\nstruct parameter_type_caster {\n public:\n  bool load(handle src, bool convert) {\n    using namespace oneflow::one;\n    value_ = nullptr;\n    if (!src) { return false; }\n    if (src.is_none()) { return true; }\n    if (!PyTensor_Check(src.ptr())) { return false; }\n    value_ = PyTensor_Unpack(src.ptr());\n    return true;\n  }\n\n  template<typename U>\n  static handle cast(U&& src, return_value_policy policy, handle parent) {\n    using namespace oneflow::one;\n    return reinterpret_steal<object>(PyParameter_New(std::const_pointer_cast<Parameter>(src)))\n        .release();\n  }\n\n  operator std::shared_ptr<T>*() { return &value_; }\n  operator std::shared_ptr<T>&() { return value_; }\n  operator std::shared_ptr<T>&&() && { return std::move(value_); }\n\n  static constexpr auto name = _(\"parameter\");\n  template<typename U>\n  using cast_op_type = pybind11::detail::cast_op_type<std::shared_ptr<T>>;\n\n protected:\n  std::shared_ptr<T> value_;\n};\n\ntemplate<>\nstruct type_caster<std::shared_ptr<oneflow::one::Tensor>>\n    : public tensor_type_caster<oneflow::one::Tensor> {};\ntemplate<>\nstruct type_caster<std::shared_ptr<const oneflow::one::Tensor>>\n    : public tensor_type_caster<const oneflow::one::Tensor> {};\n\ntemplate<>\nstruct type_caster<std::shared_ptr<oneflow::one::Parameter>>\n    : public parameter_type_caster<oneflow::one::Parameter> {};\ntemplate<>\nstruct type_caster<std::shared_ptr<const oneflow::one::Parameter>>\n    : public parameter_type_caster<const oneflow::one::Parameter> {};\n\n}  // namespace detail\n}  // namespace pybind11\n\n#endif  // ONEFLOW_API_PYTHON_CASTER_TENSOR_H_\n"
  },
  {
    "path": "oneflow/api/python/caster/test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nclass A {\n public:\n  void inc_x() { x++; }\n  int get_x() { return x; }\n\n private:\n  int x = 0;\n};\n\nstd::shared_ptr<A> get_singleton_a() {\n  static std::shared_ptr<A> a = std::make_shared<A>();\n  return a;\n}\n\nONEFLOW_API_PYBIND11_MODULE(\"test_api\", m) {\n  py::class_<A, std::shared_ptr<A>>(m, \"A\").def(\"inc_x\", &A::inc_x).def(\"get_x\", &A::get_x);\n\n  m.def(\"get_singleton_a\", []() -> Maybe<A> { return get_singleton_a(); });\n\n  m.def(\"increase_x_of_a_if_not_none\", [](const Optional<A>& a) -> Optional<A> {\n    a.map([](const std::shared_ptr<A>& a) -> std::shared_ptr<A> {\n      a->inc_x();\n      return a;\n    });\n    return a;\n  });\n\n  m.def(\"increase_if_not_none\",\n        [](const Optional<int>& x) -> Optional<int> { return x.map([](int i) { return i + 1; }); });\n\n  m.def(\"divide\", [](float x, float y) -> Maybe<float> {\n    CHECK_NE_OR_RETURN(y, 0);\n    return x / y;\n  });\n\n  m.def(\"throw_if_zero\", [](int x) -> Maybe<void> {\n    CHECK_NE_OR_RETURN(x, 0);\n    return Maybe<void>::Ok();\n  });\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/deprecated.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/framework/dtype.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"deprecated\", m) {\n  m.def(\"GetProtoDtype4OfDtype\",\n        [](const Symbol<DType>& x) { return static_cast<int>(x->data_type()); });\n\n  m.def(\"GetDTypeByDataType\",\n        [](int data_type) { return DType::Get(static_cast<DataType>(data_type)); });\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/dlpack/converter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/python/dlpack/dlpack.h\"\n#include \"oneflow/api/python/exception/exception.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/tensor_util.h\"\n\nnamespace oneflow {\n\nMaybe<Symbol<Device>> ToOneFlowDevice(const DLDevice& ctx) {\n  switch (ctx.device_type) {\n    case DLDeviceType::kDLCPU: return JUST(Device::New(\"cpu\"));\n#ifdef WITH_CUDA\n    case DLDeviceType::kDLCUDA: return JUST(Device::New(\"cuda\", ctx.device_id));\n#endif\n    default: UNIMPLEMENTED_THEN_RETURN() << \"Unsupported device type: \" << ctx.device_type;\n  }\n}\n\nMaybe<DataType> ToOneFlowDataType(const DLDataType& dtype) {\n  DataType ofdtype = DataType::kInvalidDataType;\n  CHECK_EQ_OR_RETURN(dtype.lanes, 1) << \"OneFlow does not support lanes != 1\";\n  switch (dtype.code) {\n    case DLDataTypeCode::kDLUInt:\n      switch (dtype.bits) {\n        case 8: ofdtype = DataType::kUInt8; break;\n        default:\n          UNIMPLEMENTED_THEN_RETURN() << \"Unsupported data type: \" << dtype.code << dtype.bits;\n      }\n      break;\n    case DLDataTypeCode::kDLInt:\n      switch (dtype.bits) {\n        case 8: ofdtype = DataType::kInt8; break;\n        case 16: ofdtype = DataType::kInt16; break;\n        case 32: ofdtype = DataType::kInt32; break;\n        case 64: ofdtype = DataType::kInt64; break;\n        default:\n          UNIMPLEMENTED_THEN_RETURN() << \"Unsupported data type: \" << dtype.code << dtype.bits;\n      }\n      break;\n    case DLDataTypeCode::kDLFloat:\n      switch (dtype.bits) {\n        case 16: ofdtype = DataType::kFloat16; break;\n        case 32: ofdtype = DataType::kFloat; break;\n        case 64: ofdtype = DataType::kDouble; break;\n        default:\n          UNIMPLEMENTED_THEN_RETURN() << \"Unsupported data type: \" << dtype.code << dtype.bits;\n      }\n      break;\n    case DLDataTypeCode::kDLBfloat:\n      switch (dtype.bits) {\n        case 16: ofdtype = DataType::kBFloat16; break;\n        default: UNIMPLEMENTED_THEN_RETURN() << \"Unsupported data type: bfloat\" << dtype.bits;\n      }\n      break;\n    case DLDataTypeCode::kDLComplex:\n      UNIMPLEMENTED_THEN_RETURN() << \"Unsupported data type: complex\" << dtype.bits;\n      break;\n    default: UNIMPLEMENTED_THEN_RETURN() << \"Unsupported code \" << dtype.code;\n  }\n  CHECK_NE_OR_RETURN(ofdtype, DataType::kInvalidDataType);\n  return ofdtype;\n}\n\nMaybe<one::Tensor> fromDLPack(const DLManagedTensor* src) {\n  using namespace one;\n  const auto& dl_tensor = src->dl_tensor;\n\n  Symbol<Device> device = JUST(ToOneFlowDevice(dl_tensor.device));\n  DataType dtype = JUST(ToOneFlowDataType(dl_tensor.dtype));\n\n  // Build 
TensorMeta\n  const Shape shape(dl_tensor.shape, dl_tensor.shape + dl_tensor.ndim);\n  Symbol<LocalTensorMeta> tensor_meta;\n  if (dl_tensor.strides) {\n    const auto stride = Stride(dl_tensor.strides, dl_tensor.strides + dl_tensor.ndim);\n    tensor_meta =\n        SymbolOf(LocalTensorMeta(shape, stride, dtype, MemoryFormat::kContiguous, device));\n  } else {\n    tensor_meta = SymbolOf(LocalTensorMeta(shape, dtype, MemoryFormat::kContiguous, device));\n  }\n\n  // Build TensorBuffer\n  const auto& Free = [src](char* dptr) {\n    if (src->deleter) {\n      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)\n      src->deleter(const_cast<DLManagedTensor*>(src));\n    }\n  };\n\n  size_t array_size_in_bytes = shape.elem_cnt() * GetSizeOfDataType(dtype);\n  auto tensor_data = std::make_shared<vm::TensorStorage>(false, device);\n  tensor_data->set_blob_dptr(\n      std::unique_ptr<char, std::function<void(char*)>>(static_cast<char*>(dl_tensor.data), Free),\n      array_size_in_bytes);\n\n  // Build TensorStorage: decrease ndarray reference count before releasing\n  auto tensor_storage = std::make_shared<TensorStorage>(tensor_data);\n\n  // Build Tensor\n  auto tensor_impl = std::make_shared<EagerLocalTensorImpl>(tensor_storage,\n                                                            /*requires_grad=*/false,\n                                                            /*ls_leaf=*/true);\n\n  // Init blob\n  JUST(tensor_impl->InitEagerBlobObject(tensor_meta, NewLocalDepObject()));\n  const auto& stream = JUST(GetDefaultStreamByDevice(device));\n  const auto& eager_blob_object = JUST(tensor_impl->eager_blob_object());\n  JUST(eager_blob_object->init_producer_stream(stream));\n  eager_blob_object->set_last_used_stream(stream);\n  return std::static_pointer_cast<Tensor>(std::make_shared<LocalTensor>(tensor_impl));\n}\n\nMaybe<DLDevice> ToDLDevice(Symbol<Device> ofdevice) {\n  DLDevice ctx;\n  ctx.device_id = ofdevice->device_id();\n  switch (ofdevice->enum_type()) {\n    case DeviceType::kCPU: ctx.device_type = DLDeviceType::kDLCPU; break;\n#ifdef WITH_CUDA\n    case DeviceType::kCUDA: ctx.device_type = DLDeviceType::kDLCUDA; break;\n#endif\n    default: UNIMPLEMENTED_THEN_RETURN() << \"Unsupported device type: \" << ofdevice->type();\n  }\n  return ctx;\n}\n\nMaybe<DLDataType> ToDLDataType(DataType ofdtype) {\n  DLDataType dtype;\n  dtype.lanes = 1;\n  dtype.bits = GetSizeOfDataType(ofdtype) * 8;\n  switch (ofdtype) {\n    case DataType::kUInt8: dtype.code = DLDataTypeCode::kDLUInt; break;\n    case DataType::kInt8: dtype.code = DLDataTypeCode::kDLInt; break;\n    case DataType::kInt16: dtype.code = DLDataTypeCode::kDLInt; break;\n    case DataType::kInt32: dtype.code = DLDataTypeCode::kDLInt; break;\n    case DataType::kInt64: dtype.code = DLDataTypeCode::kDLInt; break;\n    case DataType::kFloat16: dtype.code = DLDataTypeCode::kDLFloat; break;\n    case DataType::kFloat: dtype.code = DLDataTypeCode::kDLFloat; break;\n    case DataType::kDouble: dtype.code = DLDataTypeCode::kDLFloat; break;\n    case DataType::kBFloat16: dtype.code = DLDataTypeCode::kDLBfloat; break;\n    default: UNIMPLEMENTED_THEN_RETURN() << \"Unsupported data type: \" << DataType_Name(ofdtype);\n  }\n  return dtype;\n}\n\n// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)\nstruct ATenDLMTensor {\n  std::shared_ptr<one::Tensor> handle;\n  DLManagedTensor tensor;\n};\n\nvoid deleter(DLManagedTensor* arg) { delete static_cast<ATenDLMTensor*>(arg->manager_ctx); }\n\nMaybe<DLManagedTensor*> toDLPack(const 
std::shared_ptr<one::Tensor>& src) {\n  auto shape = *src->shape();\n  auto strides = *JUST(src->stride());\n  // create a new tensor with possibly normalized strides\n  // Reference:\n  // https://github.com/pytorch/pytorch/issues/83069\n  // https://github.com/pytorch/pytorch/issues/82610\n  for (int i = 0; i < src->ndim(); i++) {\n    if (shape[i] <= 1) { strides[i] = 1; }\n  }\n\n  ATenDLMTensor* atDLMTensor(new ATenDLMTensor);\n  atDLMTensor->handle = src;\n  atDLMTensor->tensor.manager_ctx = atDLMTensor;\n  atDLMTensor->tensor.deleter = &deleter;\n  JUST(one::SyncAccessTensorWithTimeOut(\n      src,\n      [&](ep::Stream*, const std::shared_ptr<vm::EagerBlobObject>& tensor) {\n        atDLMTensor->tensor.dl_tensor.data = tensor->mut_raw_dptr();\n      },\n      \"const\"));\n  auto dldevice = JUST(ToDLDevice(JUST(src->device())));\n  auto dldtype = JUST(ToDLDataType(src->dtype()->data_type()));\n  atDLMTensor->tensor.dl_tensor.device = *dldevice;\n  atDLMTensor->tensor.dl_tensor.ndim = src->ndim();\n  atDLMTensor->tensor.dl_tensor.dtype = *dldtype;\n  atDLMTensor->tensor.dl_tensor.shape =\n      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)\n      const_cast<int64_t*>(src->shape()->data());\n  atDLMTensor->tensor.dl_tensor.strides =\n      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)\n      const_cast<int64_t*>(JUST(src->stride())->data());\n  atDLMTensor->tensor.dl_tensor.byte_offset = 0;\n  return &(atDLMTensor->tensor);\n}\n\n// This function is mostly copied from PyTorch\nvoid DLPack_Capsule_Destructor(PyObject* data) {\n  if (likely(!PyCapsule_IsValid(data, \"dltensor\"))) {\n    // early out, see DLPack spec: if a consuming library sets the capsule\n    // name to something else, they own it and we don't need to do anything\n    return;\n  }\n  HANDLE_ERRORS\n  // Causes overheads for validity checks again, but this case is rare\n  // since consuming libraries should rename the capsule according to spec.\n  // Note that this cannot set a python error (we checked validity above),\n  // so we don't need to handle python error state here.\n  DLManagedTensor* dlMTensor = (DLManagedTensor*)PyCapsule_GetPointer(data, \"dltensor\");\n  // the dlMTensor has not been consumed, call deleter ourselves.\n  // DLPack spec mentions that deleter may be NULL, but deleter from\n  // `flow.to_dlpack` is never NULL, so no need for an additional check here.\n  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)\n  dlMTensor->deleter(const_cast<DLManagedTensor*>(dlMTensor));\n  END_HANDLE_ERRORS_RET()\n}\n\nnamespace py = pybind11;\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"to_dlpack\", [](const std::shared_ptr<one::Tensor>& tensor) -> Maybe<py::capsule> {\n    DLManagedTensor* dlMTensor = JUST(toDLPack(tensor));\n    return py::capsule(dlMTensor, \"dltensor\", DLPack_Capsule_Destructor);\n  });\n  // from_dlpack is exported in tensor_api.yaml\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/dlpack/converter.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/python/dlpack/dlpack.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nnamespace one {\nclass Tensor;\n}\n\nMaybe<one::Tensor> fromDLPack(const DLManagedTensor* src);\nMaybe<DLManagedTensor*> toDLPack(const std::shared_ptr<one::Tensor>& src);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/dlpack/dlpack.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n/*!\n *  Copyright (c) 2017 by Contributors\n * \\file dlpack.h\n * \\brief The common header of DLPack.\n */\n#ifndef DLPACK_DLPACK_H_\n#define DLPACK_DLPACK_H_\n\n/**\n * \\brief Compatibility with C++\n */\n#ifdef __cplusplus\n#define DLPACK_EXTERN_C extern \"C\"\n#else\n#define DLPACK_EXTERN_C\n#endif\n\n/*! \\brief The current version of dlpack */\n#define DLPACK_VERSION 70\n\n/*! \\brief The current ABI version of dlpack */\n#define DLPACK_ABI_VERSION 1\n\n/*! \\brief DLPACK_DLL prefix for windows */\n#ifdef _WIN32\n#ifdef DLPACK_EXPORTS\n#define DLPACK_DLL __declspec(dllexport)\n#else\n#define DLPACK_DLL __declspec(dllimport)\n#endif\n#else\n#define DLPACK_DLL\n#endif\n\n#include <stdint.h>\n#include <stddef.h>\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n/*!\n * \\brief The device type in DLDevice.\n */\n#ifdef __cplusplus\ntypedef enum : int32_t {\n#else\ntypedef enum {\n#endif\n  /*! \\brief CPU device */\n  kDLCPU = 1,\n  /*! \\brief CUDA GPU device */\n  kDLCUDA = 2,\n  /*!\n   * \\brief Pinned CUDA CPU memory by cudaMallocHost\n   */\n  kDLCUDAHost = 3,\n  /*! \\brief OpenCL devices. */\n  kDLOpenCL = 4,\n  /*! \\brief Vulkan buffer for next generation graphics. */\n  kDLVulkan = 7,\n  /*! \\brief Metal for Apple GPU. */\n  kDLMetal = 8,\n  /*! \\brief Verilog simulator buffer */\n  kDLVPI = 9,\n  /*! \\brief ROCm GPUs for AMD GPUs */\n  kDLROCM = 10,\n  /*!\n   * \\brief Pinned ROCm CPU memory allocated by hipMallocHost\n   */\n  kDLROCMHost = 11,\n  /*!\n   * \\brief Reserved extension device type,\n   * used for quickly test extension device\n   * The semantics can differ depending on the implementation.\n   */\n  kDLExtDev = 12,\n  /*!\n   * \\brief CUDA managed/unified memory allocated by cudaMallocManaged\n   */\n  kDLCUDAManaged = 13,\n  /*!\n   * \\brief Unified shared memory allocated on a oneAPI non-partititioned\n   * device. Call to oneAPI runtime is required to determine the device\n   * type, the USM allocation type and the sycl context it is bound to.\n   *\n   */\n  kDLOneAPI = 14,\n  /*! \\brief GPU support for next generation WebGPU standard. */\n  kDLWebGPU = 15,\n  /*! \\brief Qualcomm Hexagon DSP */\n  kDLHexagon = 16,\n} DLDeviceType;\n\n/*!\n * \\brief A Device for Tensor and operator.\n */\ntypedef struct {\n  /*! \\brief The device type used in the device. */\n  DLDeviceType device_type;\n  /*!\n   * \\brief The device index.\n   * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0.\n   */\n  int32_t device_id;\n} DLDevice;\n\n/*!\n * \\brief The type code options DLDataType.\n */\ntypedef enum {\n  /*! \\brief signed integer */\n  kDLInt = 0U,\n  /*! \\brief unsigned integer */\n  kDLUInt = 1U,\n  /*! 
\\brief IEEE floating point */\n  kDLFloat = 2U,\n  /*!\n   * \\brief Opaque handle type, reserved for testing purposes.\n   * Frameworks need to agree on the handle data type for the exchange to be well-defined.\n   */\n  kDLOpaqueHandle = 3U,\n  /*! \\brief bfloat16 */\n  kDLBfloat = 4U,\n  /*!\n   * \\brief complex number\n   * (C/C++/Python layout: compact struct per complex number)\n   */\n  kDLComplex = 5U,\n} DLDataTypeCode;\n\n/*!\n * \\brief The data type the tensor can hold. The data type is assumed to follow the\n * native endian-ness. An explicit error message should be raised when attempting to\n * export an array with non-native endianness\n *\n *  Examples\n *   - float: type_code = 2, bits = 32, lanes=1\n *   - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4\n *   - int8: type_code = 0, bits = 8, lanes=1\n *   - std::complex<float>: type_code = 5, bits = 64, lanes = 1\n */\ntypedef struct {\n  /*!\n   * \\brief Type code of base types.\n   * We keep it uint8_t instead of DLDataTypeCode for minimal memory\n   * footprint, but the value should be one of DLDataTypeCode enum values.\n   * */\n  uint8_t code;\n  /*!\n   * \\brief Number of bits, common choices are 8, 16, 32.\n   */\n  uint8_t bits;\n  /*! \\brief Number of lanes in the type, used for vector types. */\n  uint16_t lanes;\n} DLDataType;\n\n/*!\n * \\brief Plain C Tensor object, does not manage memory.\n */\ntypedef struct {\n  /*!\n   * \\brief The data pointer points to the allocated data. This will be CUDA\n   * device pointer or cl_mem handle in OpenCL. It may be opaque on some device\n   * types. This pointer is always aligned to 256 bytes as in CUDA. The\n   * `byte_offset` field should be used to point to the beginning of the data.\n   *\n   * Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow,\n   * TVM, perhaps others) do not adhere to this 256 byte aligment requirement\n   * on CPU/CUDA/ROCm, and always use `byte_offset=0`.  This must be fixed\n   * (after which this note will be updated); at the moment it is recommended\n   * to not rely on the data pointer being correctly aligned.\n   *\n   * For given DLTensor, the size of memory required to store the contents of\n   * data is calculated as follows:\n   *\n   * \\code{.c}\n   * static inline size_t GetDataSize(const DLTensor* t) {\n   *   size_t size = 1;\n   *   for (tvm_index_t i = 0; i < t->ndim; ++i) {\n   *     size *= t->shape[i];\n   *   }\n   *   size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;\n   *   return size;\n   * }\n   * \\endcode\n   */\n  void* data;\n  /*! \\brief The device of the tensor */\n  DLDevice device;\n  /*! \\brief Number of dimensions */\n  int32_t ndim;\n  /*! \\brief The data type of the pointer*/\n  DLDataType dtype;\n  /*! \\brief The shape of the tensor */\n  int64_t* shape;\n  /*!\n   * \\brief strides of the tensor (in number of elements, not bytes)\n   *  can be NULL, indicating tensor is compact and row-majored.\n   */\n  int64_t* strides;\n  /*! \\brief The offset in bytes to the beginning pointer to data */\n  uint64_t byte_offset;\n} DLTensor;\n\n/*!\n * \\brief C Tensor object, manage memory of DLTensor. This data structure is\n *  intended to facilitate the borrowing of DLTensor by another framework. It is\n *  not meant to transfer the tensor. When the borrowing framework doesn't need\n *  the tensor, it should call the deleter to notify the host that the resource\n *  is no longer needed.\n */\ntypedef struct DLManagedTensor {\n  /*! 
\\brief DLTensor which is being memory managed */\n  DLTensor dl_tensor;\n  /*! \\brief the context of the original host framework of DLManagedTensor in\n   *   which DLManagedTensor is used in the framework. It can also be NULL.\n   */\n  void* manager_ctx;\n  /*! \\brief Destructor signature void (*)(void*) - this should be called\n   *   to destruct manager_ctx which holds the DLManagedTensor. It can be NULL\n   *   if there is no way for the caller to provide a reasonable destructor.\n   *   The destructors deletes the argument self as well.\n   */\n  void (*deleter)(struct DLManagedTensor* self);\n} DLManagedTensor;\n#ifdef __cplusplus\n}  // DLPACK_EXTERN_C\n#endif\n#endif  // DLPACK_DLPACK_H_\n"
  },
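  {
    "path": "docs/examples/dlpack_consumer_sketch.cpp",
    "content": "// Editor's note: this file is a hypothetical sketch, not part of OneFlow. It shows\n// the consumer side of the DLPack handshake that dlpack.h (above) and\n// oneflow/api/python/dlpack/converter.cpp assume: a borrowing framework receives a\n// DLManagedTensor*, reads dl_tensor while the producer keeps the buffer alive, and\n// calls the deleter exactly once when done (the spec allows deleter to be NULL).\n#include <cstdint>\n#include <cstdio>\n\n#include \"oneflow/api/python/dlpack/dlpack.h\"\n\n// Element count of a DLTensor; the byte size would additionally multiply by\n// (dtype.bits * dtype.lanes + 7) / 8, as in the GetDataSize sample in dlpack.h.\nstatic int64_t NumElements(const DLTensor& t) {\n  int64_t n = 1;\n  for (int32_t i = 0; i < t.ndim; ++i) { n *= t.shape[i]; }\n  return n;\n}\n\nvoid ConsumeBorrowedTensor(DLManagedTensor* managed) {\n  const DLTensor& t = managed->dl_tensor;\n  std::printf(\"device_type=%d ndim=%d elements=%lld bits=%u\\n\", (int)t.device.device_type, t.ndim,\n              (long long)NumElements(t), (unsigned)t.dtype.bits);\n  // Done borrowing: notify the producer so it can release its reference. For\n  // OneFlow's toDLPack, this deletes the ATenDLMTensor and drops its shared_ptr.\n  if (managed->deleter) { managed->deleter(managed); }\n}\n"
  },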
  {
    "path": "oneflow/api/python/eager/eager.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/eager/dev_vm_dep_object_consume_mode.h\"\n\nONEFLOW_API_PYBIND11_MODULE(\"eager\", m) {\n  using namespace oneflow;\n  namespace py = pybind11;\n  m.def(\n      \"Sync\", []() { return vm::CurrentRankSync(); }, py::call_guard<py::gil_scoped_release>());\n  m.def(\n      \"ClusterSync\", []() { return vm::ClusterSync(); }, py::call_guard<py::gil_scoped_release>());\n\n  py::class_<one::DevVmDepObjectConsumeModeGuard,\n             std::shared_ptr<one::DevVmDepObjectConsumeModeGuard>>(\n      m, \"DevVmDepObjectConsumeModeGuard\");\n\n  m.def(\"SourceOpOnlyResourceDependenceModeGuard\", []() {\n    return std::make_shared<one::DevVmDepObjectConsumeModeGuard>(\n        one::DevVmDepObjectConsumeMode::NONE);\n  });\n}\n"
  },
  {
    "path": "oneflow/api/python/env/env.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/env/env.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/job/env_global_objects_scope.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/job/graph_scope_vars.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n#include \"oneflow/core/framework/shut_down_util.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/common/mem_util.h\"\n\n#ifdef WITH_CUDA\n#include <cuda.h>\n#endif  // WITH_CUDA\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\n#ifdef WITH_CUDA\n\nvoid RegisterCudaDeviceProperties(py::module& m) {\n  py::class_<cudaDeviceProp>(m, \"_CudaDeviceProperties\", py::module_local())\n      .def(py::init<>())\n      .def_readonly(\"name\", &cudaDeviceProp::name)\n      .def_readonly(\"major\", &cudaDeviceProp::major)\n      .def_readonly(\"minor\", &cudaDeviceProp::minor)\n      .def_readonly(\"is_multi_gpu_board\", &cudaDeviceProp::isMultiGpuBoard)\n      .def_readonly(\"is_integrated\", &cudaDeviceProp::integrated)\n      .def_readonly(\"multi_processor_count\", &cudaDeviceProp::multiProcessorCount)\n      .def_readonly(\"total_memory\", &cudaDeviceProp::totalGlobalMem)\n      .def(\"__repr__\", [](const cudaDeviceProp& prop) {\n        std::ostringstream stream;\n        stream << \"_CudaDeviceProperties(name='\" << prop.name << \"', major=\" << prop.major\n               << \", minor=\" << prop.minor\n               << \", total_memory=\" << prop.totalGlobalMem / (1024 * 1024)\n               << \"MB, multi_processor_count=\" << prop.multiProcessorCount << \")\";\n        return stream.str();\n      });\n}\n\n#endif  // WITH_CUDA\n\nMaybe<void> SwitchToShuttingDownPhase(EnvGlobalObjectsScope* env, bool is_normal_exit) {\n  JUST(env->init_is_normal_exit(is_normal_exit));\n  SetShuttingDown(true);\n  if (is_normal_exit) {\n    JUST(vm::ClusterSync());\n    auto* vm = JUST(SingletonMaybe<VirtualMachine>());\n    JUST(vm->CloseVMThreads());\n  }\n  return Maybe<void>::Ok();\n}\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"CurrentResource\", &CurrentResource);\n  m.def(\"EnvResource\", &EnvResource);\n\n  py::class_<oneflow::EnvGlobalObjectsScope, std::shared_ptr<oneflow::EnvGlobalObjectsScope>>(\n      m, \"EnvContext\")\n      .def(py::init<const std::string&>())\n      .def(\"SwitchToShuttingDownPhase\", &SwitchToShuttingDownPhase,\n           py::call_guard<py::gil_scoped_release>());\n\n  m.def(\"CurrentMachineId\", &CurrentMachineId);\n\n  m.def(\"GetRank\", &GetRank);\n  m.def(\"GetWorldSize\", &GetWorldSize);\n  m.def(\"GetNodeSize\", &GetNodeSize);\n  m.def(\"GetLocalRank\", &GetLocalRank);\n  m.def(\"InitRDMA\", &InitRDMA);\n  m.def(\"RDMAIsInitialized\", &RDMAIsInitialized);\n  m.def(\"DestoryRDMA\", &DestoryRDMA);\n  m.def(\"CudaGetDeviceCount\", &CudaGetDeviceCount);\n  
m.def(\"EmptyCache\", &EmptyCache);\n#ifdef WITH_CUDA\n  RegisterCudaDeviceProperties(m);\n  m.def(\"GetCudaDeviceIndex\", &GetCudaDeviceIndex);\n  m.def(\"SetCudaDeviceIndex\", &SetCudaDeviceIndex);\n  m.def(\"CudaSynchronize\", &CudaSynchronize);\n  m.def(\"GetCUDAMemoryUsed\", &GetCUDAMemoryUsed);\n  m.def(\"GetCPUMemoryUsed\", &GetCPUMemoryUsed);\n  m.def(\"CudaMemGetInfo\", [](int device) -> std::pair<size_t, size_t> {\n    CudaCurrentDeviceGuard guard(device);\n    size_t device_free = 0;\n    size_t device_total = 0;\n    OF_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));\n    return {device_free, device_total};\n  });\n  m.def(\n      \"_get_device_properties\",\n      [](int device) -> cudaDeviceProp* { return GetDeviceProperties(device); },\n      py::return_value_policy::reference);\n#endif  // WITH_CUDA\n  m.def(\"SetFLAGS_alsologtostderr\", &SetFLAGS_alsologtostderr);\n  m.def(\"GetFLAGS_alsologtostderr\", &GetFLAGS_alsologtostderr);\n  m.def(\"SetFLAGS_v\", &SetFLAGS_v);\n  m.def(\"GetFLAGS_v\", &GetFLAGS_v);\n  m.def(\"SetGraphLRVerbose\", &SetGraphLRVerbose);\n  m.def(\"GetGraphLRVerbose\", &GetGraphLRVerbose);\n  m.def(\"SetGraphDebugMaxPyStackDepth\", &SetGraphDebugMaxPyStackDepth);\n  m.def(\"GetGraphDebugMaxPyStackDepth\", &GetGraphDebugMaxPyStackDepth);\n  m.def(\"SetGraphDebugMode\", &SetGraphDebugMode);\n  m.def(\"GetGraphDebugMode\", &GetGraphDebugMode);\n  m.def(\"SetGraphDebugOnlyUserPyStack\", &SetGraphDebugOnlyUserPyStack);\n  m.def(\"GetGraphDebugOnlyUserPyStack\", &GetGraphDebugOnlyUserPyStack);\n  m.def(\"InitPythonPathsToBeKeptAndFilteredForDebugging\",\n        &InitPythonPathsToBeKeptAndFilteredForDebugging);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/env/env.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_ENV_ENV_H_\n#define ONEFLOW_API_PYTHON_ENV_ENV_H_\n\n#include <string>\n#include <google/protobuf/text_format.h>\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/job/cluster_instruction.h\"\n#include \"oneflow/core/job/env_global_objects_scope.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/graph_scope_vars.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/rpc/include/base.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n\nnamespace oneflow {\n\ninline Maybe<std::string> CurrentResource() {\n  CHECK_NOTNULL_OR_RETURN((Singleton<ResourceDesc, ForSession>::Get()));\n  return PbMessage2TxtString(Singleton<ResourceDesc, ForSession>::Get()->resource());\n}\n\ninline Maybe<std::string> EnvResource() {\n  CHECK_NOTNULL_OR_RETURN((Singleton<ResourceDesc, ForEnv>::Get()));\n  return PbMessage2TxtString(Singleton<ResourceDesc, ForEnv>::Get()->resource());\n}\n\ninline Maybe<long long> CurrentMachineId() { return GlobalProcessCtx::Rank(); }\n\ninline Maybe<int64_t> GetRank() { return GlobalProcessCtx::Rank(); }\ninline Maybe<size_t> GetWorldSize() { return GlobalProcessCtx::WorldSize(); }\ninline Maybe<size_t> GetNodeSize() { return GlobalProcessCtx::NodeSize(); }\ninline Maybe<size_t> GetLocalRank() { return GlobalProcessCtx::LocalRank(); }\ninline Maybe<size_t> CudaGetDeviceCount() {\n  return Singleton<ep::DeviceManagerRegistry>::Get()->GetDeviceCount(DeviceType::kCUDA);\n}\n\ninline Maybe<void> SetFLAGS_alsologtostderr(bool flag) {\n  FLAGS_alsologtostderr = flag;\n  return Maybe<void>::Ok();\n}\ninline Maybe<bool> GetFLAGS_alsologtostderr() {\n  return FLAGS_alsologtostderr;\n}  // namespace oneflow\ninline Maybe<void> SetFLAGS_v(int32_t v_level) {\n  FLAGS_v = v_level;\n  return Maybe<void>::Ok();\n}\ninline Maybe<int32_t> GetFLAGS_v() { return FLAGS_v; }\n\ninline Maybe<void> EmptyCache() {\n  JUST(vm::CurrentRankSync());\n  auto* vm = JUST(SingletonMaybe<VirtualMachine>());\n  JUST(vm->ShrinkAllMem());\n  return Maybe<void>::Ok();\n}\n\ninline Maybe<void> SetGraphLRVerbose(bool verbose) {\n  SetGraphVerboseStepLr(verbose);\n  return Maybe<void>::Ok();\n}\ninline bool GetGraphLRVerbose() { return IsOpenGraphVerboseStepLr(); }\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_ENV_ENV_H_\n"
  },
  {
    "path": "oneflow/api/python/ep/cuda_matmul_mode.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <memory>\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/ep/cuda/cuda_matmul_mode.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nnamespace ep {\n\nONEFLOW_API_PYBIND11_MODULE(\"ep\", m) {\n  m.def(\"is_matmul_allow_tf32\", &CudaMatmulMode::is_matmul_allow_tf32);\n  m.def(\"set_matmul_allow_tf32\", &CudaMatmulMode::set_matmul_allow_tf32);\n  m.def(\"is_matmul_allow_fp16_reduced_precision_reduction\",\n        &CudaMatmulMode::is_matmul_allow_fp16_reduced_precision_reduction);\n  m.def(\"set_matmul_allow_fp16_reduced_precision_reduction\",\n        &CudaMatmulMode::set_matmul_allow_fp16_reduced_precision_reduction);\n}\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/exception/exception.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/core/common/exception.h\"\n#include \"oneflow/core/common/error.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"exception\", m) {\n  m.def(\"GetThreadLocalLastError\", &ThreadLocalError);\n  py::register_exception<oneflow::Exception>(m, \"Exception\");\n  py::register_exception<oneflow::RuntimeException>(m, \"RuntimeError\", PyExc_RuntimeError);\n  py::register_exception<oneflow::TypeException>(m, \"TypeError\", PyExc_TypeError);\n  py::register_exception<oneflow::IndexException>(m, \"IndexError\", PyExc_IndexError);\n  py::register_exception<oneflow::NotImplementedException>(m, \"NotImplementedError\",\n                                                           PyExc_NotImplementedError);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/exception/exception.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_COMMON_EXCEPTION_H_\n#define ONEFLOW_API_PYTHON_COMMON_EXCEPTION_H_\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include <pybind11/pybind11.h>\n\n#include \"oneflow/core/common/exception.h\"\n\nnamespace py = pybind11;\n\n#define HANDLE_ERRORS try {\n#define END_HANDLE_ERRORS_RETSTMT(retstmt)                \\\n  }                                                       \\\n  catch (py::error_already_set & e) {                     \\\n    e.restore();                                          \\\n    retstmt;                                              \\\n  }                                                       \\\n  catch (const oneflow::RuntimeException& e) {            \\\n    PyErr_SetString(PyExc_RuntimeError, e.what());        \\\n    retstmt;                                              \\\n  }                                                       \\\n  catch (const oneflow::IndexException& e) {              \\\n    PyErr_SetString(PyExc_IndexError, e.what());          \\\n    retstmt;                                              \\\n  }                                                       \\\n  catch (const oneflow::TypeException& e) {               \\\n    PyErr_SetString(PyExc_TypeError, e.what());           \\\n    retstmt;                                              \\\n  }                                                       \\\n  catch (const oneflow::NotImplementedException& e) {     \\\n    PyErr_SetString(PyExc_NotImplementedError, e.what()); \\\n    retstmt;                                              \\\n  }                                                       \\\n  catch (const std::exception& e) {                       \\\n    PyErr_SetString(PyExc_RuntimeError, e.what());        \\\n    retstmt;                                              \\\n  }\n\n#define END_HANDLE_ERRORS END_HANDLE_ERRORS_RETSTMT(return NULL)\n#define END_HANDLE_ERRORS_RET(retval) END_HANDLE_ERRORS_RETSTMT(return retval)\n#define END_HANDLE_ERRORS_NORET END_HANDLE_ERRORS_RETSTMT(void)\n\n#endif  // ONEFLOW_API_PYTHON_COMMON_EXCEPTION_H_\n"
  },
  {
    "path": "oneflow/api/python/flags.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/python/of_api_registry.h\"\n#ifdef WITH_CUDA\n#include <cuda.h>\n#endif\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"flags\", m) {\n  m.def(\"with_cuda\", []() {\n#ifdef WITH_CUDA\n    return true;\n#else\n    return false;\n#endif  // WITH_CUDA\n  });\n\n  m.def(\"with_npu\", []() {\n#ifdef WITH_NPU\n    return true;\n#else\n    return false;\n#endif  // WITH_NPU\n  });\n\n  m.def(\"with_mlu\", []() {\n#ifdef WITH_MLU\n    return true;\n#else\n    return false;\n#endif  // WITH_MLU\n  });\n\n  m.def(\"cuda_version\", []() {\n#ifdef WITH_CUDA\n    return CUDA_VERSION;\n#else\n    return 0;\n#endif  // WITH_CUDA\n  });\n\n  m.def(\"use_cxx11_abi\", []() {\n#if _GLIBCXX_USE_CXX11_ABI == 1\n    return true;\n#else\n    return false;\n#endif  // _GLIBCXX_USE_CXX11_ABI\n  });\n\n  m.def(\"with_mlir\", []() {\n#ifdef WITH_MLIR\n    return true;\n#else\n    return false;\n#endif  // WITH_MLIR\n  });\n\n  m.def(\"with_mlir_cuda_codegen\", []() {\n#ifdef WITH_MLIR_CUDA_CODEGEN\n    return true;\n#else\n    return false;\n#endif  // WITH_MLIR_CUDA_CODEGEN\n  });\n\n  m.def(\"with_rdma\", []() {\n#ifdef WITH_RDMA\n    return true;\n#else\n    return false;\n#endif  // WITH_RDMA\n  });\n\n  m.def(\"has_rpc_backend_grpc\", []() {\n#ifdef RPC_BACKEND_GRPC\n    return true;\n#else\n    return false;\n#endif  // RPC_BACKEND_GRPC\n  });\n\n  m.def(\"has_rpc_backend_local\", []() {\n#ifdef RPC_BACKEND_LOCAL\n    return true;\n#else\n    return false;\n#endif  // RPC_BACKEND_LOCAL\n  });\n\n#define STRINGIFY(x) STRINGIFY_(x)\n#define STRINGIFY_(x) #x\n  m.def(\"cmake_build_type\", []() {\n#ifdef ONEFLOW_CMAKE_BUILD_TYPE\n    return std::string(STRINGIFY(ONEFLOW_CMAKE_BUILD_TYPE));\n#else\n    return std::string(\"Undefined\");\n#endif  // ONEFLOW_CMAKE_BUILD_TYPE\n  });\n#undef STRINGIFY\n#undef STRINGIFY_\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/autocast.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/framework/autocast.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nsize_t* nested_count() {\n  static thread_local size_t _nested_count = 0;\n  return &_nested_count;\n}\n\nbool is_nested_count_zero() { return (*nested_count()) == 0; }\nvoid increase_nested_count() { (*nested_count())++; }\nvoid decrease_nested_count() { (*nested_count())--; }\n\nclass AutoCastMode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AutoCastMode);\n\n  AutoCastMode(const std::string& device_type, Symbol<DType> dtype, bool enabled,\n               bool cache_enabled)\n      : prev_enabled_(autocast::is_enabled()),\n        prev_cache_enabled_(autocast::is_autocast_cache_enabled()),\n        prev_device_type_(autocast::get_autocast_device_type()),\n        prev_dtype_(autocast::get_autocast_dtype()),\n        prev_gpu_dtype_(autocast::get_autocast_gpu_dtype()),\n        prev_cpu_dtype_(autocast::get_autocast_cpu_dtype()) {\n    // update autocast state\n    increase_nested_count();\n    autocast::set_enabled(enabled);\n    autocast::set_autocast_cache_enabled(cache_enabled);\n    if (device_type == \"cpu\") {\n      autocast::set_autocast_device_type(kCPU);\n      autocast::set_autocast_dtype(dtype);\n      autocast::set_autocast_cpu_dtype(dtype);\n    } else if (device_type == \"cuda\") {\n      autocast::set_autocast_device_type(kCUDA);\n      autocast::set_autocast_dtype(dtype);\n      autocast::set_autocast_gpu_dtype(dtype);\n    } else {\n      THROW(RuntimeError) << \"User specified autocast device_type must be 'cuda' or 'cpu'\";\n    }\n  }\n\n  ~AutoCastMode() {\n    decrease_nested_count();\n    autocast::set_enabled(prev_enabled_);\n    autocast::set_autocast_cache_enabled(prev_cache_enabled_);\n    autocast::set_autocast_device_type(prev_device_type_);\n    autocast::set_autocast_dtype(prev_dtype_);\n    autocast::set_autocast_gpu_dtype(prev_gpu_dtype_);\n    autocast::set_autocast_cpu_dtype(prev_cpu_dtype_);\n    if ((!prev_enabled_ || !prev_cache_enabled_) && is_nested_count_zero()) {\n      autocast::clear_cache();\n    }\n  }\n\n private:\n  bool prev_enabled_;\n  bool prev_cache_enabled_;\n  DeviceType prev_device_type_;\n  Symbol<DType> prev_dtype_;\n  Symbol<DType> prev_gpu_dtype_;\n  Symbol<DType> prev_cpu_dtype_;\n};\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  py::class_<AutoCastMode, std::shared_ptr<AutoCastMode>>(m, \"AutoCastMode\")\n      .def(py::init([](const std::string& device_type, Symbol<DType> dtype, bool enabled,\n                       bool cache_enabled) {\n        return std::make_shared<AutoCastMode>(device_type, dtype, enabled, cache_enabled);\n      }));\n\n  m.def(\"is_autocast_enabled\", autocast::is_enabled);\n  m.def(\"set_autocast_enabled\", autocast::set_enabled);\n  m.def(\"get_autocast_gpu_dtype\", 
autocast::get_autocast_gpu_dtype);\n  m.def(\"get_autocast_cpu_dtype\", autocast::get_autocast_cpu_dtype);\n  m.def(\"set_autocast_gpu_dtype\", autocast::set_autocast_gpu_dtype);\n  m.def(\"set_autocast_cpu_dtype\", autocast::set_autocast_cpu_dtype);\n  m.def(\"is_autocast_cache_enabled\", autocast::is_autocast_cache_enabled);\n  m.def(\"set_autocast_cache_enabled\", autocast::set_autocast_cache_enabled);\n  m.def(\"clear_autocast_cache\", autocast::clear_cache);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/device.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/operators.h>\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/ep/include/device.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  py::class_<Symbol<Device>, std::shared_ptr<Symbol<Device>>>(m, \"device\")\n      .def(py::init([](const std::string& type_or_type_with_device_id) {\n        return Device::ParseAndNew(type_or_type_with_device_id).GetOrThrow();\n      }))\n      .def(py::init([](const std::string& type, int64_t index) {\n             return Device::New(type, index).GetOrThrow();\n           }),\n           py::arg(\"type\"), py::arg(\"index\"))\n      .def(py::init([](const Symbol<Device>& other_device) { return other_device; }))\n      .def_property_readonly(\"type\", [](const Symbol<Device>& d) { return d->type(); })\n      .def_property_readonly(\"index\", [](const Symbol<Device>& d) { return d->device_id(); })\n      .def_property_readonly(\"rematable\", [](const Symbol<Device>& d) { return d->rematable(); })\n      .def(\"__str__\", [](const Symbol<Device>& d) { return d->ToString(); })\n      .def(\"__repr__\", [](const Symbol<Device>& d) { return d->ToRepr(); })\n      .def(py::self == py::self)\n      .def(py::hash(py::self));\n\n  m.def(\n      \"max_alignment_size\", []() { return ep::kMaxAlignmentRequirement; },\n      py::return_value_policy::copy);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/doc.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/common/throw.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\npy::object AddFunctionDoc(py::object f, const std::string& doc_string) {\n  static std::vector<std::string> all_doc_strings;\n  all_doc_strings.emplace_back(doc_string);\n  const char* doc_str = all_doc_strings.back().c_str();\n  PyObject* obj = f.ptr();\n  if (PyCFunction_Check(obj)) {\n    auto* f = (PyCFunctionObject*)obj;\n    if (f->m_ml->ml_doc) {\n      THROW(RuntimeError) << \"function \" << f->m_ml->ml_name << \" already has a docstring \"\n                          << \"shows: \" << f->m_ml->ml_doc;\n    }\n    f->m_ml->ml_doc = doc_str;\n  } else if (PyFunction_Check(obj)) {\n    auto* f = (PyFunctionObject*)obj;\n    if (f->func_doc != Py_None) {\n      THROW(RuntimeError) << \"function \"\n                          << PyBytes_AsString(\n                                 PyUnicode_AsEncodedString(f->func_name, \"utf-8\", \"~E~\"))\n                          << \" already has a docstring\";\n    }\n    f->func_doc = PyUnicode_FromString(doc_str);\n  } else if (strcmp(Py_TYPE(obj)->tp_name, \"method_descriptor\") == 0) {\n    PyMethodDescrObject* f = (PyMethodDescrObject*)obj;\n    if (f->d_method->ml_doc) {\n      THROW(RuntimeError) << \"function \" << f->d_method->ml_name << \"already has a docstring\";\n    }\n    f->d_method->ml_doc = doc_str;\n  } else if (strcmp(Py_TYPE(obj)->tp_name, \"getset_descriptor\") == 0) {\n    PyMethodDescrObject* f = (PyMethodDescrObject*)obj;\n    if (f->d_method->ml_doc) {\n      THROW(RuntimeError) << \"function \" << f->d_method->ml_name << \"already has a docstring\";\n    }\n    f->d_method->ml_doc = doc_str;\n  } else if (py::isinstance<py::detail::generic_type>(f)) {\n    if (py::hasattr(f, \"__doc__\")) {\n      auto doc = py::getattr(f, \"__doc__\");\n      if (!doc.is(py::none())) {\n        THROW(RuntimeError) << Py_TYPE(obj)->tp_name << \" already has a docstring\";\n      }\n    }\n    py::setattr(f, \"__doc__\", py::reinterpret_steal<py::object>(PyUnicode_FromString(doc_str)));\n  } else if (Py_TYPE(obj)->tp_name == PyProperty_Type.tp_name) {\n    py::setattr(f, \"__doc__\", py::reinterpret_steal<py::object>(PyUnicode_FromString(doc_str)));\n  } else if (PyInstanceMethod_Check(obj)) {\n    auto* f = (PyCFunctionObject*)(PyInstanceMethod_Function(obj));\n    f->m_ml->ml_doc = doc_str;\n  } else {\n    THROW(RuntimeError) << \"function is \" << Py_TYPE(obj)->tp_name << \", not a valid function\";\n  }\n  f.inc_ref();\n  return f;\n}\n\npy::object ReplaceDoc(py::object f, const std::string& doc_string) {\n  static std::vector<std::string> all_doc_strings;\n  all_doc_strings.emplace_back(doc_string);\n  const char* doc_str = all_doc_strings.back().c_str();\n  PyObject* obj = f.ptr();\n  if (PyCFunction_Check(obj)) {\n    auto* f = 
(PyCFunctionObject*)obj;\n    if (!f->m_ml->ml_doc) {\n      THROW(RuntimeError) << \"function \" << f->m_ml->ml_name << \" has not a docstring yet.\";\n    }\n    f->m_ml->ml_doc = doc_str;\n  } else if (PyFunction_Check(obj)) {\n    auto* f = (PyFunctionObject*)obj;\n    if (f->func_doc == Py_None) {\n      THROW(RuntimeError) << \"function \"\n                          << PyBytes_AsString(\n                                 PyUnicode_AsEncodedString(f->func_name, \"utf-8\", \"~E~\"))\n                          << \" has not a docstring yet.\";\n    }\n    Py_DECREF(f->func_doc);\n    f->func_doc = PyUnicode_FromString(doc_str);\n  } else {\n    THROW(RuntimeError) << \"function is \" << Py_TYPE(obj)->tp_name << \", not a valid function.\";\n  }\n  f.inc_ref();\n  return f;\n}\n\n}  // namespace oneflow\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"add_doc\", &oneflow::AddFunctionDoc);\n  m.def(\"reset_doc\", &oneflow::ReplaceDoc);\n}\n"
  },
  {
    "path": "oneflow/api/python/framework/dtype.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/operators.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/api/python/framework/tensortype.h\"\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/core/framework/dtype.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  py::class_<Symbol<DType>, std::shared_ptr<Symbol<DType>>>(m, \"dtype\")\n      .def_property_readonly(\"is_signed\", [](const Symbol<DType>& d) { return d->is_signed(); })\n      .def_property_readonly(\"is_complex\", [](const Symbol<DType>& d) { return d->is_complex(); })\n      .def_property_readonly(\"is_floating_point\",\n                             [](const Symbol<DType>& d) { return d->is_floating_point(); })\n      .def(\"__str__\", [](const Symbol<DType>& d) { return d->name(); })\n      .def(\"__repr__\", [](const Symbol<DType>& d) { return d->name(); })\n      .def(py::self == py::self)\n      .def(py::hash(py::self))\n      .def(py::pickle(\n          [](const Symbol<DType>& dtype) {  // __getstate__\n            return static_cast<int>(dtype->data_type());\n          },\n          [](int t) {  // __setstate__\n            return CHECK_JUST(DType::Get(DataType(t)));\n          }))\n      .def_property_readonly(\"bytes\", [](const Symbol<DType>& dtype) { return dtype->bytes(); })\n      .def(\"get\", [](const int data_type_enum) {\n        return CHECK_JUST(DType::Get(static_cast<DataType>(data_type_enum)));\n      });\n\n  m.attr(\"bool\") = &CHECK_JUST(DType::Get(DataType::kBool));\n  m.attr(\"char\") = &CHECK_JUST(DType::Get(DataType::kChar));\n  m.attr(\"float16\") = &CHECK_JUST(DType::Get(DataType::kFloat16));\n  m.attr(\"float\") = &CHECK_JUST(DType::Get(DataType::kFloat));\n  m.attr(\"float32\") = &CHECK_JUST(DType::Get(DataType::kFloat));\n  m.attr(\"double\") = &CHECK_JUST(DType::Get(DataType::kDouble));\n  m.attr(\"float64\") = &CHECK_JUST(DType::Get(DataType::kDouble));\n  m.attr(\"int8\") = &CHECK_JUST(DType::Get(DataType::kInt8));\n  m.attr(\"int32\") = &CHECK_JUST(DType::Get(DataType::kInt32));\n  m.attr(\"int64\") = &CHECK_JUST(DType::Get(DataType::kInt64));\n  m.attr(\"uint8\") = &CHECK_JUST(DType::Get(DataType::kUInt8));\n  m.attr(\"record\") = &CHECK_JUST(DType::Get(DataType::kOFRecord));\n  m.attr(\"tensor_buffer\") = &CHECK_JUST(DType::Get(DataType::kTensorBuffer));\n  m.attr(\"bfloat16\") = &CHECK_JUST(DType::Get(DataType::kBFloat16));\n  m.attr(\"uint16\") = &CHECK_JUST(DType::Get(DataType::kUInt16));\n  m.attr(\"uint32\") = &CHECK_JUST(DType::Get(DataType::kUInt32));\n  m.attr(\"uint64\") = &CHECK_JUST(DType::Get(DataType::kUInt64));\n  m.attr(\"uint128\") = &CHECK_JUST(DType::Get(DataType::kUInt128));\n  m.attr(\"int16\") = &CHECK_JUST(DType::Get(DataType::kInt16));\n  m.attr(\"int128\") = &CHECK_JUST(DType::Get(DataType::kInt128));\n  m.attr(\"complex32\") = 
&CHECK_JUST(DType::Get(DataType::kComplex32));\n  m.attr(\"chalf\") = &CHECK_JUST(DType::Get(DataType::kComplex32));\n  m.attr(\"complex64\") = &CHECK_JUST(DType::Get(DataType::kComplex64));\n  m.attr(\"cfloat\") = &CHECK_JUST(DType::Get(DataType::kComplex64));\n  m.attr(\"complex128\") = &CHECK_JUST(DType::Get(DataType::kComplex128));\n  m.attr(\"cdouble\") = &CHECK_JUST(DType::Get(DataType::kComplex128));\n  m.attr(\"short\") = &CHECK_JUST(DType::Get(DataType::kInt16));\n\n  py::options options;\n  options.disable_function_signatures();\n  m.def(\"get_default_dtype\", []() { return GetDefaultDType(); });\n  m.def(\"set_default_dtype\",\n        [](const Symbol<DType>& dtype) { SetDefaultDType(dtype).GetOrThrow(); });\n  m.def(\"set_default_tensor_type\", [](const py::object& tensor_type) {\n    if (one::PyTensorType_Check(tensor_type.ptr())) {\n      CHECK_JUST(SetDefaultDType(one::PyTensorType_UnpackDType(tensor_type.ptr())));\n    } else {\n      throw py::type_error(\"invalid type object\");\n    }\n  });\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/framework.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <string>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx_mgr.h\"\n#include \"oneflow/api/python/framework/framework.h\"\n#include \"oneflow/core/framework/load_library.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"GetSerializedCurrentJob\",\n        []() -> Maybe<py::bytes> { return py::bytes(*JUST(GetSerializedCurrentJob())); });\n  m.def(\"GetFunctionConfigDef\", &GetFunctionConfigDef);\n  m.def(\"GetScopeConfigDef\", &GetScopeConfigDef);\n\n  m.def(\"LoadLibrary\", &LoadLibrary);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/framework.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_FRAMEWORK_FRAMEWORK_H_\n#define ONEFLOW_API_PYTHON_FRAMEWORK_FRAMEWORK_H_\n\n#include <string>\n#include <google/protobuf/text_format.h>\n#include \"oneflow/core/common/buffer_manager.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx_mgr.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/inter_user_job_info.pb.h\"\n#include \"oneflow/core/job/job_instance.h\"\n#include \"oneflow/core/job/oneflow.h\"\n#include \"oneflow/core/job/placement.pb.h\"\n#include \"oneflow/core/framework/config_def.h\"\n#include \"oneflow/core/framework/load_library.h\"\n\nnamespace oneflow {\n\ninline Maybe<std::string> GetSerializedCurrentJob() {\n  auto* job_ctx_mgr = Singleton<LazyJobBuildAndInferCtxMgr>::Get();\n  CHECK_NOTNULL_OR_RETURN(job_ctx_mgr);\n  auto* job_ctx =\n      JUST(job_ctx_mgr->FindJobBuildAndInferCtx(*JUST(job_ctx_mgr->GetCurrentJobName())));\n  CHECK_NOTNULL_OR_RETURN(job_ctx);\n  return job_ctx->job().SerializeAsString();\n}\n\ninline Maybe<std::string> GetFunctionConfigDef() {\n  std::string ret;\n  google::protobuf::TextFormat::PrintToString(GlobalFunctionConfigDef(), &ret);\n  return ret;\n}\n\ninline Maybe<std::string> GetScopeConfigDef() {\n  std::string ret;\n  google::protobuf::TextFormat::PrintToString(GlobalScopeConfigDef(), &ret);\n  return ret;\n}\n\ninline Maybe<std::string> GetSerializedMachineId2DeviceIdListOFRecord(\n    const std::string& parallel_conf_str) {\n  ParallelConf parallel_conf;\n  CHECK_OR_RETURN(TxtString2PbMessage(parallel_conf_str, &parallel_conf))\n      << \"parallel conf parse failed\";\n  return PbMessage2TxtString(*JUST(ParseMachineAndDeviceIdList(parallel_conf)));\n}\n\ninline Maybe<void> LoadLibraryNow(const std::string& lib_path) { return LoadLibrary(lib_path); }\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_FRAMEWORK_FRAMEWORK_H_\n"
  },
  {
    "path": "oneflow/api/python/framework/global_mode.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/job/global_mode.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"global_view\", m) {\n  py::class_<GlobalMode::Guard, std::shared_ptr<GlobalMode::Guard>>(m, \"global_mode\")\n      .def(py::init([](const bool enabled) {\n        if (enabled) {\n          THROW(RuntimeError) << \"To enable global mode, placement and sbp must be provided.\";\n        }\n        return std::make_shared<GlobalMode::Guard>(enabled);\n      }))\n      .def(py::init([](const bool enabled, const Symbol<ParallelDesc>& placement,\n                       const std::vector<Symbol<SbpParallel>>& sbp) {\n             if (!enabled) {\n               THROW(RuntimeError)\n                   << \"To disable global mode, placement and sbp must not be provided.\";\n             }\n             return std::make_shared<GlobalMode::Guard>(enabled, CHECK_JUST(GetNdSbp(sbp)),\n                                                        placement);\n           }),\n           py::arg(\"enabled\").none(false), py::arg(\"placement\").none(false),\n           py::arg(\"sbp\").none(false))\n      .def(py::init([](const bool enabled, const Symbol<ParallelDesc>& placement,\n                       const Symbol<SbpParallel>& sbp) {\n             return std::make_shared<GlobalMode::Guard>(enabled, CHECK_JUST(SbpToNdSbp(sbp)),\n                                                        placement);\n           }),\n           py::arg(\"enabled\").none(false), py::arg(\"placement\").none(false),\n           py::arg(\"sbp\").none(false))\n      .def(\"__enter__\", [](const GlobalMode::Guard& guard_obj) {})\n      .def(\"__exit__\", [](const GlobalMode::Guard& guard_obj, const py::object& type,\n                          const py::object& value, const py::object& traceback) {});\n\n  py::class_<GlobalMode, std::shared_ptr<GlobalMode>>(m, \"current_global_mode\")\n      .def(py::init([]() { return std::make_shared<GlobalMode>(); }))\n      .def_property_readonly(\"is_enabled\", [](const GlobalMode& gm) { return gm.is_enabled(); })\n      .def_property_readonly(\"sbp\",\n                             [](const GlobalMode& gm) {\n                               if (!gm.is_enabled()) {\n                                 THROW(RuntimeError)\n                                     << \"Current global mode is disabled, there is no sbp.\";\n                               }\n                               const auto& nd_sbp = gm.nd_sbp();\n                               auto tuple = py::tuple(nd_sbp->sbp_parallel_size());\n                               for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) {\n                                 tuple[i] = 
SymbolOf(nd_sbp->sbp_parallel(i));\n                               }\n                               return tuple;\n                             })\n      .def_property_readonly(\"placement\", [](const GlobalMode& gm) {\n        if (!gm.is_enabled()) {\n          THROW(RuntimeError) << \"Current global mode is disabled, there is no placement.\";\n        }\n        return gm.parallel_desc();\n      });\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/id_state.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/detail/common.h>\n#include <pybind11/pybind11.h>\n#include <pybind11/pytypes.h>\n#include <pybind11/stl.h>\n#include <pybind11/stl_bind.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/framework/multi_client_session_context.h\"\n#include \"oneflow/core/job/id_state.h\"\n\nnamespace py = pybind11;\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  using namespace oneflow;\n\n  py::class_<IdState>(m, \"IdState\")\n      .def(py::init<>())\n      .def_readwrite(\"regst_desc_id_state\", &IdState::regst_desc_id_state_)\n      .def_readwrite(\"mem_block_id_state\", &IdState::mem_block_id_state_)\n      .def_readwrite(\"chunk_id_state\", &IdState::chunk_id_state_)\n      .def_readwrite(\"job_id_state\", &IdState::job_id_state_)\n      .def_readwrite(\"task_index_state\", &IdState::task_index_state_)\n      .def_readwrite(\"stream_index_state\", &IdState::stream_index_state_)\n      // support pickle\n      .def(py::pickle(\n          [](const IdState& id_state) {\n            return py::make_tuple(id_state.regst_desc_id_state_, id_state.mem_block_id_state_,\n                                  id_state.chunk_id_state_, id_state.job_id_state_,\n                                  id_state.task_index_state_, id_state.stream_index_state_);\n          },\n          [](const py::tuple& t) {\n            CHECK(t.size() == 6);\n            IdState id_state;\n            id_state.regst_desc_id_state_ = t[0].cast<int64_t>();\n            id_state.mem_block_id_state_ = t[1].cast<int64_t>();\n            id_state.chunk_id_state_ = t[2].cast<int64_t>();\n            id_state.job_id_state_ = t[3].cast<int64_t>();\n            id_state.task_index_state_ = t[4].cast<HashMap<int64_t, uint32_t>>();\n            id_state.stream_index_state_ = t[5].cast<HashMap<int64_t, uint32_t>>();\n            return id_state;\n          }));\n\n  m.def(\"set_id_state\", [](const IdState& id_state) {\n    Singleton<MultiClientSessionContext>::Get()->SetIdState(id_state);\n  });\n  m.def(\"get_id_state\", []() { return Singleton<MultiClientSessionContext>::Get()->GetIdState(); });\n}\n"
  },
  {
    "path": "oneflow/api/python/framework/id_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/framework/id_util.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) { m.def(\"UniqueStr\", &UniqueStr); }\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/instructions_builder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/functional.h>\n#include <pybind11/stl.h>\n#include <functional>\n#include \"oneflow/api/python/framework/size.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/tensor.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> DeprecatedPhysicalRun(const std::function<void(InstructionsBuilder*)>& Build) {\n  return PhysicalRun([&](InstructionsBuilder* instruction_builder) -> Maybe<void> {\n    Build(instruction_builder);\n    return Maybe<void>::Ok();\n  });\n}\n\n}  // namespace\n\nONEFLOW_API_PYBIND11_MODULE(\"deprecated\", m) {\n  py::class_<InstructionsBuilder, std::shared_ptr<InstructionsBuilder>>(m, \"InstructionsBuilder\")\n      .def(\n          \"BuildInitialScope\",\n          [](const std::shared_ptr<InstructionsBuilder>& builder, int64_t session_id,\n             const std::string& job_conf_str, const std::string& device_tag,\n             const std::vector<std::string>& machine_device_ids,\n             const std::shared_ptr<Shape>& hierarchy, bool is_local) -> Maybe<Scope> {\n            JobConfigProto job_conf;\n            CHECK_OR_RETURN(TxtString2PbMessage(job_conf_str, &job_conf))\n                << Error::RuntimeError() << \"job conf parse failed\";\n            return builder->BuildInitialScope(session_id, job_conf, device_tag, machine_device_ids,\n                                              hierarchy, is_local);\n          },\n          py::arg(\"session_id\").none(false), py::arg(\"job_conf_str\").none(false),\n          py::arg(\"device_tag\").none(false), py::arg(\"machine_device_ids\").none(false),\n          py::arg(\"hierarchy\").none(true), py::arg(\"is_local\").none(false))\n      .def(\n          \"BuildInitialScopeWithPlacement\",\n          [](const std::shared_ptr<InstructionsBuilder>& builder, int64_t session_id,\n             const std::string& job_conf_str, Symbol<ParallelDesc> placement,\n             bool is_local) -> Maybe<Scope> {\n            JobConfigProto job_conf;\n            CHECK_OR_RETURN(TxtString2PbMessage(job_conf_str, &job_conf))\n                << Error::RuntimeError() << \"job conf parse failed\";\n            return builder->BuildInitialScopeWithPlacement(session_id, job_conf, placement,\n                                                           is_local);\n          },\n          py::arg(\"session_id\").none(false), py::arg(\"job_conf_str\").none(false),\n          py::arg(\"placement\").none(false), py::arg(\"is_local\").none(false))\n      .def(\"BuildScopeWithNewParallelDesc\", &InstructionsBuilder::BuildScopeWithNewParallelDesc,\n           py::arg(\"scope\").none(false), py::arg(\"device_tag\").none(false),\n           py::arg(\"machine_device_ids\").none(false), py::arg(\"hierarchy\").none(true))\n      
.def(\"BuildScopeWithNewParallelConf\",\n           [](const std::shared_ptr<InstructionsBuilder>& builder,\n              const std::shared_ptr<Scope>& scope,\n              const std::string& parallel_conf_str) -> Maybe<Scope> {\n             ParallelConf parallel_conf;\n             CHECK_OR_RETURN(TxtString2PbMessage(parallel_conf_str, &parallel_conf))\n                 << Error::RuntimeError() << \"parallel conf parse failed\";\n             return builder->BuildScopeWithNewParallelConf(scope, parallel_conf);\n           })\n      .def(\"BuildScopeWithNewIsLocal\", &InstructionsBuilder::BuildScopeWithNewIsLocal)\n      .def(\"BuildScopeWithNewScopeName\", &InstructionsBuilder::BuildScopeWithNewScopeName)\n      .def(\"BuildScopeByProtoStrSetter\", &InstructionsBuilder::BuildScopeByProtoStrSetter);\n\n  m.def(\"PhysicalRun\", &DeprecatedPhysicalRun, py::call_guard<py::gil_scoped_release>());\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/layout.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/operators.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/api/python/framework/tensortype.h\"\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/layout.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  py::class_<Symbol<Layout>>(m, \"layout\")\n      .def(\"__str__\", [](Symbol<Layout> d) { return d->name(); })\n      .def(\"__repr__\", [](Symbol<Layout> d) { return d->name(); })\n      .def(py::self == py::self)\n      .def(py::hash(py::self))\n      .def(py::pickle(\n          [](Symbol<Layout> layout) {  // __getstate__\n            return static_cast<int>(layout->layout_type());\n          },\n          [](int t) {  // __setstate__\n            return Layout::Get(LayoutType(t));\n          }))\n      .def(\"get\", [](const int layout_type_enum) {\n        return Layout::Get(static_cast<LayoutType>(layout_type_enum));\n      });\n\n  m.attr(\"strided\") = Layout::Get(LayoutType::kStrided);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/memory_format.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/api/python/functional/common.h\"\n\n#include \"oneflow/api/python/framework/memory_format.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nstatic PyObject* PyMemoryFormat_repr(PyMemoryFormatObject* self) {\n  auto memory_format = PyMemoryFormat_Unpack((PyObject*)self);\n  if (memory_format == MemoryFormat::kContiguous) {\n    return PyUnicode_FromString(\"oneflow.contiguous_format\");\n  } else if (memory_format == MemoryFormat::kChannelsLast) {\n    return PyUnicode_FromString(\"oneflow.channels_last\");\n  } else if (memory_format == MemoryFormat::kPreserve) {\n    return PyUnicode_FromString(\"oneflow.preserve_format\");\n  } else {\n    THROW(TypeError) << \"invalid memory format\";\n    return nullptr;\n  }\n}\n\nPyTypeObject PyMemoryFormat_Type = {\n    PyVarObject_HEAD_INIT(NULL, 0) \"oneflow.memory_format\", /* tp_name */\n    sizeof(PyMemoryFormatObject),                           /* tp_basicsize */\n    0,                                                      /* tp_itemsize */\n    NULL,                                                   /* tp_dealloc */\n    0,                                                      /* tp_vectorcall_offset */\n    NULL,                                                   /* tp_getattr */\n    NULL,                                                   /* tp_setattr */\n    NULL,                                                   /* tp_reserved */\n    (reprfunc)PyMemoryFormat_repr,                          /* tp_repr */\n    NULL,                                                   /* tp_as_number */\n    NULL,                                                   /* tp_as_sequence */\n    NULL,                                                   /* tp_as_mapping */\n    NULL,                                                   /* tp_hash  */\n    NULL,                                                   /* tp_call */\n    NULL,                                                   /* tp_str */\n    NULL,                                                   /* tp_getattro */\n    NULL,                                                   /* tp_setattro */\n    NULL,                                                   /* tp_as_buffer */\n    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,               /* tp_flags */\n};\n\nbool PyMemoryFormat_Check(PyObject* self) { return self && self->ob_type == &PyMemoryFormat_Type; }\n\nPyObject* PyMemoryFormat_New(MemoryFormat memory_format) {\n  auto* self = (PyMemoryFormatObject*)PyMemoryFormat_Type.tp_alloc(&PyMemoryFormat_Type, 0);\n  self->memory_format = memory_format;\n  return (PyObject*)self;\n}\n\nstatic PyObject* PyMemoryFormat_contiguous = nullptr;\nstatic PyObject* PyMemoryFormat_channels_last = nullptr;\nstatic PyObject* PyMemoryFormat_preserve = nullptr;\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  if 
(PyType_Ready(&PyMemoryFormat_Type) < 0) { return; }\n  Py_INCREF(&PyMemoryFormat_Type);\n  if (PyModule_AddObject(m.ptr(), \"memory_format\", (PyObject*)&PyMemoryFormat_Type) < 0) { return; }\n\n  PyMemoryFormat_contiguous = PyMemoryFormat_New(MemoryFormat::kContiguous);\n  PyMemoryFormat_channels_last = PyMemoryFormat_New(MemoryFormat::kChannelsLast);\n  PyMemoryFormat_preserve = PyMemoryFormat_New(MemoryFormat::kPreserve);\n  if (PyModule_AddObject(m.ptr(), \"contiguous_format\", PyMemoryFormat_contiguous) < 0) { return; }\n  if (PyModule_AddObject(m.ptr(), \"channels_last\", PyMemoryFormat_channels_last) < 0) { return; }\n  if (PyModule_AddObject(m.ptr(), \"preserve_format\", PyMemoryFormat_preserve) < 0) { return; }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/memory_format.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_FRAMEWORK_MEMORY_FORMAT_H_\n#define ONEFLOW_API_PYTHON_FRAMEWORK_MEMORY_FORMAT_H_\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include <pybind11/pybind11.h>\n\n#include \"oneflow/core/common/memory_format.pb.h\"\n\nnamespace oneflow {\n\ntypedef struct PyMemoryFormatObject {\n  PyTypeObject ob_type;\n  MemoryFormat memory_format;\n} PyMemoryFormatObject;\n\nbool PyMemoryFormat_Check(PyObject*);\n\ninline MemoryFormat PyMemoryFormat_Unpack(PyObject* self) {\n  return ((PyMemoryFormatObject*)self)->memory_format;\n}\n\nPyObject* PyMemoryFormat_New(MemoryFormat memory_format);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_FRAMEWORK_MEMORY_FORMAT_H_\n"
  },
  {
    "path": "oneflow/api/python/framework/nn_graph.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include <memory>\n#include <string>\n#include \"oneflow/api/python/job_build/job_build_and_infer.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/framework/multi_client_session_context.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/nn_graph.h\"\n#include \"oneflow/core/job/runtime.h\"\n#include \"oneflow/core/register/blob.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job/job_ir.h\"\n#include \"oneflow/core/job/job_interpreter.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\nnamespace {\nMaybe<py::object> APINNGraphAdditionalVarNames(const std::shared_ptr<NNGraph>& graph) {\n  const auto names = *JUST(graph->GetAdditionalVarOpNames());\n  py::list name_list = py::cast(names);\n  return py::cast<py::object>(name_list);\n}\nMaybe<py::object> APINNGraphAdditionalVarTensors(const std::shared_ptr<NNGraph>& graph) {\n  const auto tensors = *JUST(graph->GetAdditionalVarOpTensors());\n  py::list tensor_list = py::cast(tensors);\n  return py::cast<py::object>(tensor_list);\n}\n\nMaybe<py::bytes> APINNGraphGetCurrentSerializedJob(const std::shared_ptr<NNGraph>& graph) {\n  const auto job = graph->job();\n  return py::bytes(job.SerializeAsString());\n}\n}  // namespace\n\nONEFLOW_API_PYBIND11_MODULE(\"nn.graph.\", m) {\n  using namespace oneflow;\n  py::class_<NNGraph, std::shared_ptr<NNGraph>>(m, \"CNNGraph\")\n      .def(py::init([](const std::string& name, const std::string& serialized_job, int64_t job_id,\n                       const std::shared_ptr<MultiClientSessionContext>& session_ctx) {\n        Job job;\n        if (!job.ParseFromString(serialized_job)) {\n          PyErr_SetString(PyExc_TypeError, \"The second argument is not a valid job\");\n        }\n        return std::make_shared<NNGraph>(name, job, job_id, session_ctx);\n      }))\n      .def(py::init([](const std::string& name, const std::string& serialized_plan, int64_t job_id,\n                       const std::shared_ptr<MultiClientSessionContext>& session_ctx,\n                       bool init_from_plan) {\n        if (!init_from_plan) {\n          PyErr_SetString(\n              PyExc_TypeError,\n              \"init_from_plan must be True when init CNNGraph with this bool parameter.\");\n        }\n        Plan plan;\n        if (!plan.ParseFromString(serialized_plan)) {\n          PyErr_SetString(PyExc_TypeError, \"The second argument is not a valid plan\");\n        }\n        return std::make_shared<NNGraph>(name, plan, job_id, session_ctx);\n      }))\n      .def_property_readonly(\"name\", &NNGraph::job_name)\n      .def_property(\n          \"job\", /*getter*/\n          [](const NNGraph& nn_graph) { return py::bytes(nn_graph.job().SerializeAsString()); },\n          /*setter*/\n          [](NNGraph& nn_graph, const std::string& serialized_job) {\n      
      Job job;\n            if (!job.ParseFromString(serialized_job)) {\n              PyErr_SetString(PyExc_TypeError, \"the value is not a valid job\");\n            }\n            nn_graph.restore_job(job);\n          })\n      .def_property(\"job_id\", &NNGraph::job_id,\n                    [](NNGraph& nn_graph, int64_t job_id) { nn_graph.restore_job_id(job_id); })\n      .def_property(\n          \"plan\", /*getter*/\n          [](const NNGraph& nn_graph) { return py::bytes(nn_graph.plan().SerializeAsString()); },\n          /*setter*/\n          [](NNGraph& nn_graph, const std::string& serialized_plan) {\n            Plan plan;\n            if (!plan.ParseFromString(serialized_plan)) {\n              PyErr_SetString(PyExc_TypeError, \"the value is not a valid plan\");\n            }\n            nn_graph.restore_plan(plan);\n          })\n      .def(\"register_input_op_names_and_tensors\", &NNGraph::RegisterInputOpNamesAndTensors)\n      .def(\"register_output_op_names_and_tensors\", &NNGraph::RegisterOutputOpNamesAndTensors)\n      .def(\"register_variable_op_names_and_tensors\", &NNGraph::RegisterVariableOpNamesAndTensors)\n      .def(\"register_additional_variable_names_and_tensors\",\n           &NNGraph::RegisterAdditionalVarOpNamesAndTensorsToBeLoaded)\n      .def_property_readonly(\"additional_var_names\", &APINNGraphAdditionalVarNames)\n      .def_property_readonly(\"additional_var_tensors\", &APINNGraphAdditionalVarTensors)\n      .def(\"align_states_after_logical_graph_compile\",\n           &NNGraph::AlignStatesAfterLogicalGraphCompile)\n      .def(\"complete_graph_for_runtime\", &NNGraph::CompleteLogicalGraphForRuntime)\n      .def(\"build_with_new_input_from_shared_graph\", &NNGraph::BuildWithNewInputFromSharedGraph)\n      .def(\"compile_plan_for_runtime\", &NNGraph::CompilePlanForRuntime)\n      .def(\"init_runtime\", &NNGraph::InitRuntime)\n      .def(\"get_current_job_str\", &APINNGraphGetCurrentSerializedJob);\n\n  m.def(\"RunLazyNNGraph\", &RunLazyNNGraph);\n  m.def(\"RunLazyNNGraphByVM\", &one::InterpretJob);\n  m.def(\"SoftSyncNNGraphBuffers\", &SoftSyncNNGraphBuffers);\n  m.def(\"AddTensorAsGraphLoss\", &AddTensorAsGraphLoss);\n  m.def(\"MarkVariableGradients\", [](const std::vector<std::shared_ptr<one::Tensor>>& variables,\n                                    const std::vector<std::shared_ptr<one::Tensor>>& gradients) {\n    one::TensorTuple variable_tuple(variables.size());\n    one::TensorTuple gradient_tuple(gradients.size());\n    for (int i = 0; i < variables.size(); ++i) { variable_tuple[i] = variables[i]; }\n    for (int i = 0; i < gradients.size(); ++i) { gradient_tuple[i] = gradients[i]; }\n    return MarkVariableGradients(variable_tuple, gradient_tuple);\n  });\n  m.def(\"ConvertJobToTosaIR\", [](const std::string& serialized_job) -> Maybe<std::string> {\n    Job job;\n    CHECK_OR_RETURN(job.ParseFromString(serialized_job)) << \"serialized job conversion failed.\";\n    return ConvertJobToTosaIR(&job);\n  });\n  m.def(\n      \"SaveJobToIR\", [](const std::string& serialized_job, const std::string& path) -> Maybe<void> {\n        Job job;\n        CHECK_OR_RETURN(job.ParseFromString(serialized_job)) << \"serialized job conversion failed.\";\n        return SaveJobToIR(&job, path);\n      });\n  m.def(\"ConvertJobToIR\", [](const std::string& serialized_job) -> Maybe<std::string> {\n    Job job;\n    CHECK_OR_RETURN(job.ParseFromString(serialized_job)) << \"serialized job conversion failed.\";\n    return ConvertJobToIR(&job);\n  });\n  
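// Inverse of SaveJobToIR: deserialize the Job stored as IR at the given path and return it as\n  // serialized protobuf bytes.\n  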
m.def(\"LoadSerializedJobFromIR\", [](const std::string& path) -> Maybe<py::bytes> {\n    Job job;\n    JUST(LoadJobFromIR(&job, path));\n    return py::bytes(job.SerializeAsString());\n  });\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/one_embedding.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include <pybind11/numpy.h>\n#include <pybind11/operators.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/embedding/embedding_manager.h\"\n#include \"oneflow/core/embedding/persistent_table.h\"\n#include \"oneflow/core/embedding/hash_functions.cuh\"\n#include \"oneflow/core/framework/dtype.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nclass OneEmbeddingHandler final {\n public:\n  OneEmbeddingHandler(const std::string& key_value_store_option_string, int64_t local_rank_id,\n                      int64_t rank_id, int64_t world_size)\n      : local_rank_id_(local_rank_id), rank_id_(rank_id), world_size_(world_size) {\n    embedding::KeyValueStoreOptions key_value_store_options(key_value_store_option_string);\n    embedding_name_ = key_value_store_options.Name();\n    CreateKeyValueStore(key_value_store_options);\n  }\n\n  void LoadSnapshot(const std::string& snapshot_name) {\n#ifdef WITH_CUDA\n    Singleton<embedding::EmbeddingManager>::Get()->LoadSnapshot(embedding_name_, local_rank_id_,\n                                                                rank_id_, snapshot_name);\n#else\n    UNIMPLEMENTED() << \"Only Support with CUDA\";\n#endif\n  }\n\n  void SaveSnapshot(const std::string& snapshot_name) {\n#ifdef WITH_CUDA\n    Singleton<embedding::EmbeddingManager>::Get()->SaveSnapshot(embedding_name_, local_rank_id_,\n                                                                rank_id_, snapshot_name);\n#else\n    UNIMPLEMENTED() << \"Only Support with CUDA\";\n#endif\n  }\n\n private:\n  void CreateKeyValueStore(const embedding::KeyValueStoreOptions& key_value_store_options) {\n#ifdef WITH_CUDA\n    Singleton<embedding::EmbeddingManager>::Get()->CreateKeyValueStore(\n        key_value_store_options, local_rank_id_, rank_id_, world_size_);\n#else\n    UNIMPLEMENTED() << \"Only Support with CUDA\";\n#endif\n  }\n\n  std::string embedding_name_;\n  int64_t local_rank_id_;\n  int64_t rank_id_;\n  int64_t world_size_;\n};\n\nnamespace embedding {\n\nclass PersistentTableWriter {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PersistentTableWriter);\n  PersistentTableWriter() = default;\n  virtual ~PersistentTableWriter() = default;\n\n  virtual void Write(const py::array& keys, const py::array& values) = 0;\n  virtual void Close() = 0;\n};\n\ntemplate<typename Key, typename Value>\nclass PersistentTableWriterImpl : public PersistentTableWriter {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PersistentTableWriterImpl);\n  PersistentTableWriterImpl(const std::vector<std::string>& paths, const std::string& snapshot_name,\n                            uint32_t storage_dim, uint64_t target_chunk_size_mb,\n                            uint16_t physical_block_size)\n      : closed_(false), snapshot_name_(snapshot_name), storage_dim_(storage_dim) {\n    tables_.resize(paths.size());\n    for (size_t i = 
0; i < paths.size(); ++i) {\n      PersistentTableOptions options;\n      options.path = paths[i];\n      options.key_size = sizeof(Key);\n      options.value_size = storage_dim * sizeof(Value);\n      options.target_chunk_size_mb = target_chunk_size_mb;\n      options.physical_block_size = physical_block_size;\n      tables_[i] = NewPersistentTable(options);\n    }\n  }\n  ~PersistentTableWriterImpl() override { CloseImpl(); }\n\n  void Write(const py::array& keys, const py::array& values) override {\n    CHECK(!closed_) << \"Write on closed table\";\n    CHECK_EQ(keys.ndim(), 1);\n    CHECK_EQ(values.ndim(), 2);\n    CHECK_EQ(keys.shape(0), values.shape(0));\n    CHECK_EQ(values.shape(1), storage_dim_);\n    CHECK(keys.dtype().equal(py::dtype::of<Key>()));\n    CHECK(values.dtype().equal(py::dtype::of<Value>()));\n    const size_t n = keys.size();\n    std::vector<std::vector<Key>> keys_buffers(tables_.size());\n    std::vector<std::vector<char>> values_buffers(tables_.size());\n    for (size_t i = 0; i < n; ++i) {\n      const Key key = *(reinterpret_cast<const Key*>(keys.template data(i)));\n      const uint32_t shard = ShardingHash()(key) % tables_.size();\n      keys_buffers[shard].push_back(key);\n      const size_t values_offset = values_buffers[shard].size();\n      values_buffers[shard].resize(values_offset + storage_dim_ * sizeof(Value));\n      for (size_t j = 0; j < values.shape(1); ++j) {\n        std::memcpy(values_buffers[shard].data() + values_offset + j * values.itemsize(),\n                    values.template data(i, j), values.itemsize());\n      }\n    }\n    for (size_t shard = 0; shard < tables_.size(); ++shard) {\n      tables_[shard]->Put(keys_buffers[shard].size(), keys_buffers[shard].data(),\n                          values_buffers[shard].data());\n    }\n  }\n\n  void Close() override { CloseImpl(); }\n\n private:\n  void CloseImpl() {\n    if (!closed_) {\n      for (auto& table : tables_) {\n        table->SaveSnapshot(snapshot_name_);\n        table.reset();\n      }\n    }\n    closed_ = true;\n  }\n\n  bool closed_;\n  std::string snapshot_name_;\n  std::vector<std::unique_ptr<PersistentTable>> tables_;\n  uint32_t storage_dim_;\n};\n\ntemplate<typename Key>\nstd::shared_ptr<PersistentTableWriter> NewPersistentTableWriter(\n    const std::vector<std::string>& paths, const std::string& snapshot_name,\n    const Symbol<DType>& key_type, const Symbol<DType>& value_type, uint32_t storage_dim,\n    uint64_t target_chunk_size_mb, uint16_t physical_block_size) {\n  if (value_type->data_type() == DataType::kFloat) {\n    return std::shared_ptr<PersistentTableWriter>(new PersistentTableWriterImpl<Key, float>(\n        paths, snapshot_name, storage_dim, target_chunk_size_mb, physical_block_size));\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nstd::shared_ptr<PersistentTableWriter> NewPersistentTableWriter(\n    const std::vector<std::string>& paths, const std::string& snapshot_name,\n    const Symbol<DType>& key_type, const Symbol<DType>& value_type, uint32_t storage_dim,\n    uint64_t target_chunk_size_mb, uint16_t physical_block_size) {\n  if (key_type->data_type() == DataType::kInt32) {\n    return NewPersistentTableWriter<int32_t>(paths, snapshot_name, key_type, value_type,\n                                             storage_dim, target_chunk_size_mb,\n                                             physical_block_size);\n  } else if (key_type->data_type() == DataType::kUInt32) {\n    
return NewPersistentTableWriter<uint32_t>(paths, snapshot_name, key_type, value_type,\n                                              storage_dim, target_chunk_size_mb,\n                                              physical_block_size);\n  } else if (key_type->data_type() == DataType::kInt64) {\n    return NewPersistentTableWriter<int64_t>(paths, snapshot_name, key_type, value_type,\n                                             storage_dim, target_chunk_size_mb,\n                                             physical_block_size);\n  } else if (key_type->data_type() == DataType::kUInt64) {\n    return NewPersistentTableWriter<uint64_t>(paths, snapshot_name, key_type, value_type,\n                                              storage_dim, target_chunk_size_mb,\n                                              physical_block_size);\n  } else {\n    UNIMPLEMENTED();\n    return std::shared_ptr<embedding::PersistentTableWriter>(nullptr);\n  }\n}\n\nclass PersistentTableReader {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PersistentTableReader);\n  PersistentTableReader() = default;\n  virtual ~PersistentTableReader() = default;\n\n  virtual std::tuple<py::object, py::object> Next() = 0;\n  virtual void Close() = 0;\n};\n\ntemplate<typename Key, typename Value>\nclass PersistentTableReaderImpl : public PersistentTableReader {\n public:\n  constexpr static uint32_t kBatchSize = 65536;\n  OF_DISALLOW_COPY_AND_MOVE(PersistentTableReaderImpl);\n  PersistentTableReaderImpl(const std::vector<std::string>& paths, const std::string& snapshot_name,\n                            uint32_t storage_dim, uint64_t target_chunk_size_mb,\n                            uint16_t physical_block_size)\n      : closed_(false),\n        snapshot_name_(snapshot_name),\n        storage_dim_(storage_dim),\n        current_table_(0) {\n    tables_.resize(paths.size());\n    iterators_.resize(paths.size());\n    for (size_t i = 0; i < paths.size(); ++i) {\n      PersistentTableOptions options;\n      options.path = paths[i];\n      options.key_size = sizeof(Key);\n      options.value_size = storage_dim * sizeof(Value);\n      options.target_chunk_size_mb = target_chunk_size_mb;\n      options.physical_block_size = physical_block_size;\n      options.read_only = true;\n      tables_[i] = NewPersistentTable(options);\n      iterators_[i] =\n          std::unique_ptr<PersistentTable::Iterator>(tables_[i]->ReadSnapshot(snapshot_name));\n    }\n    keys_buffer_.resize(kBatchSize);\n    values_buffer_.resize(kBatchSize * storage_dim_);\n  }\n  ~PersistentTableReaderImpl() override { CloseImpl(); }\n\n  std::tuple<py::object, py::object> Next() override {\n    while (current_table_ < tables_.size()) {\n      uint32_t n_result = 0;\n      iterators_[current_table_]->Next(kBatchSize, &n_result, keys_buffer_.data(),\n                                       values_buffer_.data());\n      if (n_result != 0) {\n        py::array_t<Key> keys_arr(py::array::ShapeContainer({n_result}));\n        py::array_t<Value> values_arr(py::array::ShapeContainer({n_result, storage_dim_}));\n        std::memcpy(keys_arr.mutable_data(), keys_buffer_.data(), n_result * sizeof(Key));\n        std::memcpy(values_arr.mutable_data(), values_buffer_.data(),\n                    n_result * storage_dim_ * sizeof(Value));\n        return std::make_tuple(keys_arr, values_arr);\n      } else {\n        current_table_ += 1;\n        continue;\n      }\n    }\n    throw py::stop_iteration();\n  }\n\n  void Close() override { CloseImpl(); }\n\n private:\n  void CloseImpl() {\n   
 if (!closed_) {\n      for (auto& table : tables_) { table.reset(); }\n    }\n    closed_ = true;\n  }\n\n  bool closed_;\n  std::string snapshot_name_;\n  std::vector<std::unique_ptr<PersistentTable>> tables_;\n  std::vector<std::unique_ptr<PersistentTable::Iterator>> iterators_;\n  uint32_t storage_dim_;\n  size_t current_table_;\n  std::vector<Key> keys_buffer_;\n  std::vector<Value> values_buffer_;\n};\n\ntemplate<typename Key>\nstd::shared_ptr<PersistentTableReader> NewPersistentTableReader(\n    const std::vector<std::string>& paths, const std::string& snapshot_name,\n    const Symbol<DType>& key_type, const Symbol<DType>& value_type, uint32_t storage_dim,\n    uint64_t target_chunk_size_mb, uint16_t physical_block_size) {\n  if (value_type->data_type() == DataType::kFloat) {\n    return std::shared_ptr<PersistentTableReader>(new PersistentTableReaderImpl<Key, float>(\n        paths, snapshot_name, storage_dim, target_chunk_size_mb, physical_block_size));\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nstd::shared_ptr<PersistentTableReader> NewPersistentTableReader(\n    const std::vector<std::string>& paths, const std::string& snapshot_name,\n    const Symbol<DType>& key_type, const Symbol<DType>& value_type, uint32_t storage_dim,\n    uint64_t target_chunk_size_mb, uint16_t physical_block_size) {\n  if (key_type->data_type() == DataType::kInt32) {\n    return NewPersistentTableReader<int32_t>(paths, snapshot_name, key_type, value_type,\n                                             storage_dim, target_chunk_size_mb,\n                                             physical_block_size);\n  } else if (key_type->data_type() == DataType::kUInt32) {\n    return NewPersistentTableReader<uint32_t>(paths, snapshot_name, key_type, value_type,\n                                              storage_dim, target_chunk_size_mb,\n                                              physical_block_size);\n  } else if (key_type->data_type() == DataType::kInt64) {\n    return NewPersistentTableReader<int64_t>(paths, snapshot_name, key_type, value_type,\n                                             storage_dim, target_chunk_size_mb,\n                                             physical_block_size);\n  } else if (key_type->data_type() == DataType::kUInt64) {\n    return NewPersistentTableReader<uint64_t>(paths, snapshot_name, key_type, value_type,\n                                              storage_dim, target_chunk_size_mb,\n                                              physical_block_size);\n  } else {\n    UNIMPLEMENTED();\n    return std::shared_ptr<embedding::PersistentTableReader>(nullptr);\n  }\n}\n\n}  // namespace embedding\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  py::class_<OneEmbeddingHandler, std::shared_ptr<OneEmbeddingHandler>>(m, \"OneEmbeddingHandler\")\n      .def(py::init([](const std::string& key_value_store_option_str, const int64_t local_rank_id,\n                       const int64_t rank_id, const int64_t world_size) {\n        return std::make_shared<OneEmbeddingHandler>(key_value_store_option_str, local_rank_id,\n                                                     rank_id, world_size);\n      }))\n      .def(\"SaveSnapshot\", &OneEmbeddingHandler::SaveSnapshot)\n      .def(\"LoadSnapshot\", &OneEmbeddingHandler::LoadSnapshot);\n\n  py::class_<embedding::PersistentTableWriter, std::shared_ptr<embedding::PersistentTableWriter>>(\n      m, \"PersistentTableWriter\")\n      .def(py::init([](const std::vector<std::string>& paths, const std::string& snapshot_name,\n                       
const Symbol<DType>& key_type, const Symbol<DType>& value_type,\n                       uint32_t storage_dim, uint64_t target_chunk_size_mb,\n                       uint16_t physical_block_size) {\n        return embedding::NewPersistentTableWriter(paths, snapshot_name, key_type, value_type,\n                                                   storage_dim, target_chunk_size_mb,\n                                                   physical_block_size);\n      }))\n      .def(\"__enter__\", [](embedding::PersistentTableWriter* writer) { return writer; })\n      .def(\"__exit__\", [](embedding::PersistentTableWriter* writer, const py::object& exc_type,\n                          const py::object& exc_val, const py::object& exc_tb) { writer->Close(); })\n      .def(\"write\", &embedding::PersistentTableWriter::Write)\n      .def(\"close\", &embedding::PersistentTableWriter::Close);\n\n  py::class_<embedding::PersistentTableReader, std::shared_ptr<embedding::PersistentTableReader>>(\n      m, \"PersistentTableReader\")\n      .def(py::init([](const std::vector<std::string>& paths, const std::string& snapshot_name,\n                       const Symbol<DType>& key_type, const Symbol<DType>& value_type,\n                       uint32_t storage_dim, uint64_t target_chunk_size_mb,\n                       uint16_t physical_block_size) {\n        return embedding::NewPersistentTableReader(paths, snapshot_name, key_type, value_type,\n                                                   storage_dim, target_chunk_size_mb,\n                                                   physical_block_size);\n      }))\n      .def(\"__next__\", &embedding::PersistentTableReader::Next)\n      .def(\"__iter__\", [](embedding::PersistentTableReader* reader) { return reader; })\n      .def(\"__enter__\", [](embedding::PersistentTableReader* reader) { return reader; })\n      .def(\"__exit__\", [](embedding::PersistentTableReader* reader, const py::object& exc_type,\n                          const py::object& exc_val, const py::object& exc_tb) { reader->Close(); })\n      .def(\"close\", &embedding::PersistentTableReader::Close);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/op_builder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include <pybind11/functional.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nnamespace one {\n\nONEFLOW_API_PYBIND11_MODULE(\"one\", m) {\n  py::class_<one::OpBuilder, std::shared_ptr<one::OpBuilder>>(m, \"OpBuilder\")\n      .def(py::init<const std::string&>())\n      .def(py::init<const std::string&, const std::string&>())\n      .def(\"input\", &OpBuilder::MaybeInput)\n      .def(\"output\", &OpBuilder::MaybeOutput)\n      .def(\"attr\",\n           [](const std::shared_ptr<one::OpBuilder>& x, const std::string& attr_name,\n              const std::string& attr_val_str) -> Maybe<OpBuilder&> {\n             AttrValue attr_val;\n             if (!TxtString2PbMessage(attr_val_str, &attr_val)) {\n               THROW(RuntimeError) << \"attr val parse failed.\\n\" << attr_val_str;\n             }\n             return x->MaybeAttr(attr_name, attr_val);\n           })\n      .def(\"build\", &OpBuilder::Build);\n}\n\n}  // namespace one\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/op_expr.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename OpT, typename ConfT,\n         typename std::enable_if<std::is_base_of<one::BuiltinOpExpr, OpT>::value>::type* = nullptr>\npy::class_<OpT, one::BuiltinOpExpr, std::shared_ptr<OpT>> PybindExportOpExpr(\n    py::module& m, const char* op_type_name) {\n  return py::class_<OpT, one::BuiltinOpExpr, std::shared_ptr<OpT>>(m, op_type_name)\n      .def(py::init([](const std::string& op_name, const std::string& op_conf_str,\n                       const std::vector<std::string>& indexed_ibns,\n                       const std::vector<std::string>& indexed_obns) {\n        ConfT proto_op_conf;\n        if (!TxtString2PbMessage(op_conf_str, &proto_op_conf)) {\n          THROW(RuntimeError) << \"op conf parse failed.\\n\" << op_conf_str;\n        }\n        return OpT::New(op_name, std::move(proto_op_conf), indexed_ibns, indexed_obns)\n            .GetPtrOrThrow();\n      }));\n}\n\n}  // namespace\n\nONEFLOW_API_PYBIND11_MODULE(\"one\", m) {\n  py::class_<one::OpExpr, std::shared_ptr<one::OpExpr>>(m, \"OpExpr\")\n      .def_property_readonly(\"op_type_name\", &one::OpExpr::op_type_name)\n      .def_property_readonly(\"input_size\", &one::OpExpr::input_size)\n      .def_property_readonly(\"output_size\", &one::OpExpr::output_size);\n\n  py::class_<one::BuiltinOpExpr, one::OpExpr, std::shared_ptr<one::BuiltinOpExpr>>(m,\n                                                                                   \"BuiltinOpExpr\")\n      .def_property_readonly(\"name\", &one::BuiltinOpExpr::op_name)\n      .def_property_readonly(\"indexed_ibns\", &one::BuiltinOpExpr::indexed_ibns)\n      .def_property_readonly(\"indexed_obns\", &one::BuiltinOpExpr::indexed_obns);\n\n  auto py_user_op_class = PybindExportOpExpr<one::UserOpExpr, UserOpConf>(m, \"UserOpExpr\");\n  py_user_op_class.def_property_readonly(\n      \"op_type_name\", [](const one::UserOpExpr& op) { return op.proto().op_type_name(); });\n  PybindExportOpExpr<one::VariableOpExpr, VariableOpConf>(m, \"VariableOpExpr\");\n  // NOTE(chengcheng): export for Lazy nn.Graph Feed/Fetch EagerTensor to/from LazyTensor.\n  PybindExportOpExpr<one::FeedInputOpExpr, FeedInputOpConf>(m, \"FeedInputOpExpr\");\n  PybindExportOpExpr<one::FeedVariableOpExpr, FeedVariableOpConf>(m, \"FeedVariableOpExpr\");\n  PybindExportOpExpr<one::FetchOutputOpExpr, FetchOutputOpConf>(m, 
\"FetchOutputOpExpr\");\n  PybindExportOpExpr<one::ImageDecoderRandomCropResizeOpExpr, ImageDecoderRandomCropResizeOpConf>(\n      m, \"ImageDecoderRandomCropResizeOpExpr\");\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/parallel_conf_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include \"oneflow/api/python/framework/size.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/framework/parallel_conf_util.h\"\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"GetDeviceTagAndMachineDeviceIdsAndHierarchy\",\n        &GetDeviceTagAndMachineDeviceIdsAndHierarchy);\n  m.def(\"MakeParallelConf\", &MakeParallelConf);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/py_kernel_registry.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <string>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/extension/python/py_kernel_registry.h\"\n\nnamespace py = pybind11;\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"RegisterPyKernelCaller\", &::oneflow::pyext::RegisterPyKernelCaller);\n  m.def(\"RegisterPyKernels\",\n        [](py::object py_kernels) { ::oneflow::pyext::RegisterPyKernels(py_kernels.ptr()); });\n}\n"
  },
  {
    "path": "oneflow/api/python/framework/random_generator.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#ifdef WITH_CUDA\n#include \"oneflow/core/device/cuda_util.h\"\n#endif  // WITH_CUDA\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nMaybe<one::Generator> CreateGenerator(const std::string& device_str) {\n  auto [device_name, device_index, rematable] = *JUST(ParseDeviceString(device_str));\n  return one::MakeGenerator(device_name, device_index);\n}\n\npy::tuple GetCudaDefaultGenerators() {\n#ifdef WITH_CUDA\n  static int device_count = GetCudaDeviceCount();\n#else\n  static int device_count = 0;\n#endif\n  py::tuple default_cuda_generators(device_count);\n  FOR_RANGE(int, device_id, 0, device_count) {\n    const auto& cuda_gen = one::DefaultCUDAGenerator(device_id);\n    default_cuda_generators[device_id] = py::cast(cuda_gen);\n  }\n  return default_cuda_generators;\n}\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  py::class_<one::Generator, std::shared_ptr<one::Generator>>(m, \"Generator\")\n      .def(py::init([](const std::string& device_tag) {\n        return CreateGenerator(device_tag).GetPtrOrThrow();\n      }))\n      .def(\"manual_seed\",\n           [](const std::shared_ptr<one::Generator>& generator,\n              const py::object& seed) -> std::shared_ptr<one::Generator> {\n             int64_t seed_val = (one::functional::PyUnpackLong(seed.ptr())).GetOrThrow();\n             generator->set_current_seed(seed_val);\n             return generator;\n           })\n      .def(\"initial_seed\", &one::Generator::current_seed)\n      .def(\"seed\", &one::Generator::seed)\n      .def_property_readonly(\"device\", &one::Generator::device)\n      .def(\"get_state\", &one::Generator::GetState)\n      .def(\"set_state\", &one::Generator::SetState);\n\n  m.def(\"manual_seed\", [](const py::object& seed) -> Maybe<one::Generator> {\n    int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr()));\n    return one::ManualSeed(seed_val);\n  });\n  m.def(\"manual_seed\",\n        [](const py::object& seed, const std::string& device, int device_index) -> Maybe<void> {\n          int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr()));\n          return one::ManualSeed(seed_val, device, device_index);\n        });\n  m.def(\"create_generator\", &CreateGenerator);\n  m.def(\"default_generator\", [](const std::string& device_str) -> Maybe<one::Generator> {\n    auto [device_name, device_index, rematable] = *JUST(ParseDeviceString(device_str));\n    return one::DefaultGenerator(device_name, device_index);\n  });\n  m.def(\"ManualSeedAllCudaGenerator\", [](const py::object& seed) -> Maybe<void> {\n    int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr()));\n    return one::ManualSeedAllCudaGenerator(seed_val);\n  });\n  
m.def(\"default_generators\", &GetCudaDefaultGenerators);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/scope_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/framework/scope_util.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"GetCurrentScope\", &GetCurrentScope);\n  m.def(\"MakeInitialScope\",\n        [](const std::string& job_conf_str, Symbol<ParallelDesc> placement,\n           bool is_local) -> Maybe<Scope> {\n          JobConfigProto job_conf;\n          CHECK_OR_RETURN(TxtString2PbMessage(job_conf_str, &job_conf)) << \"job conf parse failed\";\n          return MakeInitialScope(job_conf, placement, is_local);\n        });\n  m.def(\"InitGlobalScopeStack\", &InitThreadLocalScopeStack);\n\n  m.def(\"GlobalScopeStackPush\", &ThreadLocalScopeStackPush);\n  m.def(\"GlobalScopeStackPop\", &ThreadLocalScopeStackPop);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/session_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/framework/session_util.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"GetDefaultSessionId\", []() -> int64_t { return GetDefaultSessionId().GetOrThrow(); });\n  m.def(\"RegsterSessionId\", &RegsterSessionId);\n  m.def(\"ClearSessionId\", &ClearSessionId);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/shut_down_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/framework/shut_down_util.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"SetShuttingDown\", []() { return SetShuttingDown(); });\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/size.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/api/python/framework/size.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/common/shape.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nusing one::functional::PyObjectPtr;\n\nstatic PyObject* TensorSize_repr(TensorSize* self) {\n  std::stringstream ss;\n  int32_t idx = 0;\n  int32_t size = PyTuple_Size((PyObject*)self);\n  ss << \"oneflow.Size([\";\n  for (int i = 0; i < size; ++i) {\n    int64_t dim = PyLong_AsLongLong(PyTuple_GET_ITEM(self, i));\n    ss << dim;\n    if (++idx != size) { ss << \", \"; }\n  }\n  ss << \"])\";\n  return PyUnicode_FromString(ss.str().c_str());\n}\n\nstatic PyObject* TensorSize_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) {\n  PyObjectPtr self(PyTuple_Type.tp_new(type, args, kwargs));\n  if (self.get()) {\n    for (int i = 0; i < PyTuple_Size(self.get()); ++i) {\n      PyObject* item = PyTuple_GET_ITEM(self.get(), i);\n      if (!PyLong_Check(item)) {\n        return PyErr_Format(PyExc_TypeError,\n                            \"oneflow.Size() takes an iterable of 'int', but item '%d' is '%s'\", i,\n                            Py_TYPE(item)->tp_name);\n      }\n    }\n  }\n  return self.release();\n}\n\nstatic Py_ssize_t TensorSize_length(TensorSize* self) {\n  return PyTuple_Type.tp_as_sequence->sq_length((PyObject*)self);\n}\n\nstatic PyObject* TensorSize_concat(TensorSize* self, PyObject* other) {\n  PyObjectPtr result(PyTuple_Type.tp_as_sequence->sq_concat((PyObject*)self, other));\n  if (!result.get()) { return nullptr; }\n  if (PyTuple_Check(result.get())) {\n    PyObjectPtr args(PyTuple_Pack(1, result.get()));\n    return TensorSize_new(&TensorSize_Type, args.get(), nullptr);\n  }\n  return result.release();\n}\n\nstatic PyObject* TensorSize_repeat(TensorSize* self, Py_ssize_t n) {\n  PyObjectPtr result(PyTuple_Type.tp_as_sequence->sq_repeat((PyObject*)self, n));\n  if (!result.get()) { return nullptr; }\n  if (PyTuple_Check(result.get())) {\n    PyObjectPtr args(PyTuple_Pack(1, result.get()));\n    return TensorSize_new(&TensorSize_Type, args.get(), nullptr);\n  }\n  return result.release();\n}\n\nstatic PyObject* TensorSize_item(TensorSize* self, Py_ssize_t i) {\n  return PyTuple_Type.tp_as_sequence->sq_item((PyObject*)self, i);\n}\n\nstatic int TensorSize_contains(TensorSize* self, PyObject* el) {\n  return PyTuple_Type.tp_as_sequence->sq_contains((PyObject*)self, el);\n}\n\nstatic PySequenceMethods TensorSize_as_sequence = {\n    (lenfunc)TensorSize_length,      /* sq_length */\n    (binaryfunc)TensorSize_concat,   /* sq_concat */\n    (ssizeargfunc)TensorSize_repeat, /* sq_repeat */\n    (ssizeargfunc)TensorSize_item,   /* sq_item */\n    0,                               /* sq_slice */\n    0,                               /* sq_ass_item */\n    0,                               /* 
sq_ass_slice */\n    (objobjproc)TensorSize_contains, /* sq_contains */\n};\n\nstatic PyObject* TensorSize_subscript(TensorSize* self, PyObject* item) {\n  PyObjectPtr result(PyTuple_Type.tp_as_mapping->mp_subscript((PyObject*)self, item));\n  if (!result.get()) { return nullptr; }\n  if (PyTuple_Check(result.get())) {\n    PyObjectPtr args(PyTuple_Pack(1, result.get()));\n    return TensorSize_new(&TensorSize_Type, args.get(), nullptr);\n  }\n  return result.release();\n};\n\nstatic PyMappingMethods TensorSize_as_mapping = {\n    (lenfunc)TensorSize_length,       /* mp_length */\n    (binaryfunc)TensorSize_subscript, /* mp_subscript */\n    0,                                /* mp_ass_subscript */\n};\n\nstatic PyObject* TensorSize_numel(PyObject* self, PyObject* args) {\n  int64_t numel = 1;\n  for (int i = 0; i < PyTuple_Size(self); ++i) {\n    numel *= PyLong_AsLongLong(PyTuple_GET_ITEM((TensorSize*)self, i));\n  }\n  return PyLong_FromLongLong(numel);\n}\n\nstatic PyMethodDef TensorSize_methods[] = {\n    {\"numel\", (PyCFunction)TensorSize_numel, METH_NOARGS, NULL}, {NULL}};\n\nPyTypeObject TensorSize_Type = {\n    PyVarObject_HEAD_INIT(NULL, 0) \"oneflow.Size\", /* tp_name */\n    sizeof(TensorSize),                            /* tp_basicsize */\n    0,                                             /* tp_itemsize */\n    NULL,                                          /* tp_dealloc */\n    0,                                             /* tp_vectorcall_offset */\n    NULL,                                          /* tp_getattr */\n    NULL,                                          /* tp_setattr */\n    NULL,                                          /* tp_reserved */\n    (reprfunc)TensorSize_repr,                     /* tp_repr */\n    NULL,                                          /* tp_as_number */\n    &TensorSize_as_sequence,                       /* tp_as_sequence */\n    &TensorSize_as_mapping,                        /* tp_as_mapping */\n    NULL,                                          /* tp_hash  */\n    NULL,                                          /* tp_call */\n    NULL,                                          /* tp_str */\n    NULL,                                          /* tp_getattro */\n    NULL,                                          /* tp_setattro */\n    NULL,                                          /* tp_as_buffer */\n    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,      /* tp_flags */\n    NULL,                                          /* tp_doc */\n    NULL,                                          /* tp_traverse */\n    NULL,                                          /* tp_clear */\n    NULL,                                          /* tp_richcompare */\n    0,                                             /* tp_weaklistoffset */\n    NULL,                                          /* tp_iter */\n    NULL,                                          /* tp_iternext */\n    TensorSize_methods,                            /* tp_methods */\n    NULL,                                          /* tp_members */\n    NULL,                                          /* tp_getset */\n    &PyTuple_Type,                                 /* tp_base */\n    NULL,                                          /* tp_dict */\n    NULL,                                          /* tp_descr_get */\n    NULL,                                          /* tp_descr_set */\n    0,                                             /* tp_dictoffset */\n    NULL,                                          /* 
tp_init */\n    NULL,                                          /* tp_alloc */\n    TensorSize_new,                                /* tp_new */\n    NULL,                                          /* tp_free */\n};\n\nint TensorSize_Check(PyObject* p) { return p && p->ob_type == &TensorSize_Type; }\n\nPyObject* TensorSize_New(Py_ssize_t len) { return TensorSize_Type.tp_alloc(&TensorSize_Type, len); }\n\nPyObject* TensorSize_NewFromShape(const Shape& size) {\n  PyObjectPtr self(TensorSize_New(size.NumAxes()));\n  if (self.get()) {\n    for (int i = 0; i < size.NumAxes(); ++i) {\n      PyTuple_SET_ITEM(self.get(), i, PyLong_FromLongLong(size.At(i)));\n    }\n  }\n  return self.release();\n}\n\nShape TensorSize_AsShape(PyObject* self) {\n  if (!TensorSize_Check(self)) {\n    PyErr_Format(PyExc_TypeError, \"can only convert TensorSize(not \\\"%s\\\") to Shape\",\n                 Py_TYPE(self)->tp_name);\n    return Shape();\n  }\n  int size = TensorSize_length((TensorSize*)self);\n  DimVector dim_vec(size);\n  for (int i = 0; i < size; ++i) {\n    dim_vec[i] = PyLong_AsLongLong(PyTuple_GET_ITEM((TensorSize*)self, i));\n  }\n  return Shape(std::move(dim_vec));\n}\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  if (PyType_Ready(&TensorSize_Type) < 0) { return; }\n  Py_INCREF(&TensorSize_Type);\n  if (PyModule_AddObject(m.ptr(), \"Size\", (PyObject*)&TensorSize_Type) < 0) { return; }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/size.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_FRAMEWORK_SIZE_H_\n#define ONEFLOW_API_PYTHON_FRAMEWORK_SIZE_H_\n#include <type_traits>\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include <pybind11/pybind11.h>\n#include \"oneflow/core/common/shape.h\"\n\nnamespace oneflow {\n\ntypedef struct {\n  PyTupleObject ob_base;\n} TensorSize;\n\nextern PyTypeObject TensorSize_Type;\n\nint TensorSize_Check(PyObject* p);\n\nPyObject* TensorSize_New(Py_ssize_t len);\nPyObject* TensorSize_NewFromShape(const Shape& size);\n\nShape TensorSize_AsShape(PyObject* self);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_FRAMEWORK_SIZE_H_\n"
  },
  {
    "path": "oneflow/api/python/framework/stream.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/api/python/framework/thread.h\"\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/framework/stream_set.h\"\n#include \"oneflow/core/framework/stream_guard.h\"\n\nnamespace py = pybind11;\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  using namespace oneflow;\n  py::class_<StreamSet, std::shared_ptr<StreamSet>>(m, \"StreamSet\")\n      .def(py::init([](const AsyncThread& async_thread) {\n        return StreamSet::New(async_thread.thread_uid()).GetPtrOrThrow();\n      }));\n\n  py::class_<StreamGuard, std::shared_ptr<StreamGuard>>(m, \"StreamGuard\")\n      .def(py::init([](const std::shared_ptr<StreamSet>& stream_set) {\n        auto stream_converter = std::make_shared<StreamConverter>(stream_set);\n        return std::make_shared<StreamGuard>(stream_converter);\n      }));\n}\n"
  },
  {
    "path": "oneflow/api/python/framework/tensor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/python/framework/tensor.h\"\n\n#include <pybind11/pybind11.h>\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include \"oneflow/api/python/exception/exception.h\"\n#include \"oneflow/api/python/framework/size.h\"\n#include \"oneflow/api/python/framework/tensortype.h\"\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/api/python/functional/python_arg.h\"\n#include \"oneflow/api/python/functional/functional_api.yaml.pybind.h\"\n#include \"oneflow/api/python/functional/tensor_api.yaml.pybind.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/api/python/utils/tensor_utils.h\"\n#include \"oneflow/core/autograd/autograd_engine.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_rpc_util.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/placement_utils.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/tensor_index.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\nnamespace one {\n\n#define ASSERT(x) (x).GetOrThrow()\n#define ASSERT_PTR(x) (x).GetPtrOrThrow()\n#define PY_XINCREF(p) (({ Py_XINCREF(p); }), (p))\n\n#if PY_VERSION_HEX < 0x03070000\n#define PYGETSET_NAME(name) const_cast<char*>(name)\n#else\n#define PYGETSET_NAME(name) (name)\n#endif\n\nPyTypeObject* PyTensorObject_Type = NULL;\nPyTypeObject* PyParameterObject_Type = NULL;\n\nnamespace {\n\ntemplate<typename T>\nstruct AllocType {};\n#define DEFINE_ALLOC_TYPE(type)  \\\n  template<>                     \\\n  struct AllocType<type> {       \\\n    static PyTypeObject** value; \\\n  };                             \\\n  PyTypeObject** AllocType<type>::value = &Py##type##Object_Type\n\nDEFINE_ALLOC_TYPE(Tensor);\nDEFINE_ALLOC_TYPE(Parameter);\n#undef DEFINE_ALLOC_TYPE\n\ntemplate<typename T>\nPyObject* PyTensor_wrap(const std::shared_ptr<T>& data, PyTensorObject* bind_pyobj) {\n  if (!data) { Py_RETURN_NONE; }\n  PyObject* py_tensor = (PyObject*)data->pyobject();\n  if (bind_pyobj == nullptr && py_tensor) {\n    // Has been wrapped by python before\n    if (data->owns_pyobj()) {\n      // PyTensor are not alive in python side, so we flip back the ownership to PyTensor\n      data->set_owns_pyobj(false);\n      ((PyTensorObject*)py_tensor)->data = data;\n      // NOTE: Needn't incref here, because the reference count of py_tensor is already increased\n      return py_tensor;\n    } else {\n      // PyTensor is alive, so we directly incref it and return it\n      Py_XINCREF(py_tensor);\n      return py_tensor;\n    }\n  } else {\n    // Has not been wrapped by python before, so we create a new PyTensor and give it the ownership\n    if (bind_pyobj == nullptr) {\n      bind_pyobj = 
(PyTensorObject*)PyTensorObject_Type->tp_alloc(*AllocType<T>::value, 0);\n    }\n    bind_pyobj->data = data;\n    if (py_tensor) {\n      // If it has bind pyobj, reset the shared_ptr in origin PyTensorObject\n      ((PyTensorObject*)py_tensor)->data.reset();\n    }\n    bind_pyobj->data->set_pyobject_ptr(std::unique_ptr<void, void (*)(void*)>(\n        bind_pyobj, [](void* ptr) { Py_DECREF((PyObject*)ptr); }));\n    bind_pyobj->data->set_owns_pyobj(false);\n    return (PyObject*)bind_pyobj;\n  }\n}\n\nbool PyTensor_tryResurrect(PyObject* py_tensor) {\n  auto* self = (PyTensorObject*)py_tensor;\n  if (self->data) {\n    // PyTensor holds the ownership, now we flip it back to C++ and resurrect python object\n    // temporarily\n    auto tensor = self->data;\n    self->data.reset();\n    tensor->set_owns_pyobj(true);\n    Py_XINCREF(py_tensor);\n    return true;\n  }\n  // Otherwise, PyTensor was already not alive in python side\n  return false;\n}\n\n}  // namespace\n\nstatic int PyTensorObject_init(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  auto* temp = functional::_legacy_tensor_ctor(NULL, args, kwargs);\n  if (PyErr_Occurred()) { throw py::error_already_set(); }\n  PyTensor_wrap<Tensor>(PyTensor_Unpack(temp), (PyTensorObject*)self);\n  return 0;\n  END_HANDLE_ERRORS_RET(-1)\n}\n\nstatic void PyTensorObject_dealloc(PyObject* self) {\n  if (PyTensor_tryResurrect(self)) { return; }\n\n  // clear __dict__\n  PyObject** dict_ptr = _PyObject_GetDictPtr(self);\n  if (dict_ptr) { Py_CLEAR(*dict_ptr); }\n  auto* type = Py_TYPE(self);\n  type->tp_free(self);\n  Py_DECREF(type);\n}\n\nstatic int PyParameterObject_init(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  PyObject* data = NULL;\n  int requires_grad = 1;\n  static const char* keywords[3] = {\"data\", \"requires_grad\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"O|p:__init__\", const_cast<char**>(keywords),\n                                   &data, &requires_grad)) {\n    return -1;\n  }\n  if (self) {\n    PyTensor_wrap<Parameter>(\n        ASSERT_PTR(Parameter::MakeTensor(PyTensor_Unpack(data), requires_grad)),\n        (PyTensorObject*)self);\n  }\n  return 0;\n  END_HANDLE_ERRORS_RET(-1)\n}\n\nstatic Py_ssize_t PyTensorObject_length(PyTensorObject* self) {\n  if (self->data->ndim() == 0) { return 0; }\n  return self->data->dim(0);\n}\n\nstatic PyObject* PyTensorObject_getitem(PyObject* self, Py_ssize_t item) {\n  HANDLE_ERRORS\n  const auto& p = PyTensor_Unpack(self);\n  return PyTensor_New(\n      ASSERT_PTR(functional::TensorGetItem(p, {functional::detail::IndexItem(item)})));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_subscript(PyObject* self, PyObject* item) {\n  HANDLE_ERRORS\n  const auto& p = PyTensor_Unpack(self);\n  functional::PythonArg arg(item);\n  return PyTensor_New(ASSERT_PTR(functional::TensorGetItem(p, arg.As<functional::TensorIndex>())));\n  END_HANDLE_ERRORS\n}\n\nstatic PySequenceMethods PyTensorObject_as_sequence = {\n    (lenfunc)PyTensorObject_length, NULL, /*sq_concat*/\n    NULL,                                 /*sq_repeat*/\n    (ssizeargfunc)PyTensorObject_getitem, /*sq_item*/\n};\n\nextern int PyTensorObject_setitem(PyObject*, PyObject*, PyObject*);\nstatic PyMappingMethods PyTensorObject_as_mapping = {\n    (lenfunc)PyTensorObject_length,\n    (binaryfunc)PyTensorObject_subscript,\n    (objobjargproc)PyTensorObject_setitem,\n};\n\nstatic PyObject* PyTensorObject_storage_offset(PyObject* self, PyObject* unused) {\n  
HANDLE_ERRORS\n  return functional::CastToPyObject(PyTensor_Unpack(self)->storage_offset());\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_stride(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  const auto& stride = ASSERT_PTR(PyTensor_Unpack(self)->stride());\n  PyObject* tup = PyTuple_New(stride->size());\n  for (int i = 0; i < stride->size(); ++i) {\n    PyTuple_SetItem(tup, i, PyLong_FromUnsignedLong(stride->at(i)));\n  }\n  return tup;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_is_contiguous(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  return functional::CastToPyObject(PyTensor_Unpack(self)->is_contiguous());\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_is_view(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  if (PyTensor_Unpack(self)->is_view()) {\n    Py_RETURN_TRUE;\n  } else {\n    Py_RETURN_FALSE;\n  }\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_contiguous(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  return PyTensor_New(PyTensor_Unpack(self)->contiguous());\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_contiguous_(PyObject* self, PyObject* unused) {\n  // NOTE: inplace version of contiguous\n  HANDLE_ERRORS\n  return PyTensor_New(ASSERT_PTR(functional::InplaceToContiguous(PyTensor_Unpack(self))));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_pin_memory(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  return PyTensor_New(PyTensor_Unpack(self)->pin_memory());\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_is_pinned(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  return functional::CastToPyObject(CHECK_JUST(PyTensor_Unpack(self)->is_pinned()));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_offload(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  const auto& t = PyTensor_Unpack(self);\n  CHECK_JUST(t->offload());\n  Py_RETURN_NONE;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_load(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  const auto& t = PyTensor_Unpack(self);\n  CHECK_JUST(t->load());\n  Py_RETURN_NONE;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_is_offloaded(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  return functional::CastToPyObject(CHECK_JUST(PyTensor_Unpack(self)->is_offloaded()));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_is_floating_point(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  if (PyTensor_Unpack(self)->dtype()->is_floating_point()) {\n    Py_RETURN_TRUE;\n  } else {\n    Py_RETURN_FALSE;\n  }\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_requires_grad_(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  int requires_grad = 1;\n  static const char* keywords[2] = {\"requires_grad\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"|p:requires_grad_\", const_cast<char**>(keywords),\n                                   &requires_grad)) {\n    return NULL;\n  }\n  ASSERT(PyTensor_Unpack(self)->set_requires_grad(requires_grad));\n  Py_XINCREF(self);\n  return self;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_retain_grad(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  const auto& t = PyTensor_Unpack(self);\n  CHECK_JUST(t->set_retain_grad(true));\n  Py_RETURN_NONE;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_detach(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  return PyTensor_New(ASSERT_PTR(PyTensor_Unpack(self)->detach()));\n  
END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_clone(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  return PyTensor_New(ASSERT_PTR(PyTensor_Unpack(self)->clone()));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_zero_(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  ASSERT(EagerLocalTensorZeros(PyTensor_Unpack(self)));\n  Py_XINCREF(self);\n  return self;\n  END_HANDLE_ERRORS\n}\n\nstd::vector<Symbol<SbpParallel>> RawSbpBToP(Symbol<NdSbp> nd_sbp) {\n  std::vector<Symbol<SbpParallel>> new_nd_sbp;\n  for (const auto& old_sbp : nd_sbp->sbp_parallel()) {\n    SbpParallel new_sbp = old_sbp;\n    if (new_sbp.has_broadcast_parallel()) { new_sbp.mutable_partial_sum_parallel(); }\n    new_nd_sbp.push_back(SymbolOf(new_sbp));\n  }\n  return new_nd_sbp;\n}\n\nstatic constexpr auto* SbpBToP = DECORATE(&RawSbpBToP, ThreadLocalCached);\n\nstatic PyObject* PyTensorObject_zero_grad(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  int set_to_none = 0;\n  static const char* keywords[2] = {\"set_to_none\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"|p:_zero_grad_\", const_cast<char**>(keywords),\n                                   &set_to_none)) {\n    return NULL;\n  }\n  const auto& t = PyTensor_Unpack(self);\n  const auto acc_grad = ASSERT_PTR(t->acc_grad());\n  if (acc_grad) {\n    if (set_to_none) {\n      ASSERT(t->set_acc_grad(NULL));\n    } else {\n      ASSERT(EagerLocalTensorZeros(acc_grad));\n      if (acc_grad->is_global() && acc_grad->is_eager()) {\n        const auto local_tensor = ASSERT_PTR(functional::GlobalToLocal(acc_grad, false));\n        const auto p = ASSERT_PTR(functional::LocalToGlobal(\n            local_tensor, ASSERT(acc_grad->parallel_desc()), SbpBToP(ASSERT(acc_grad->nd_sbp())),\n            *acc_grad->shape(), acc_grad->dtype(), false, false));\n        ASSERT(acc_grad->set_data(p));\n      }\n    }\n  }\n  Py_XINCREF(self);\n  return self;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_register_hook(PyObject* self, PyObject* hook) {\n  HANDLE_ERRORS\n  const auto& _hook = py::cast<AutogradMeta::Hook>(py::reinterpret_borrow<py::object>(hook));\n  ASSERT(RegisterTensorHook(PyTensor_Unpack(self), _hook));\n  Py_RETURN_NONE;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject__register_post_grad_accumulation_hook(PyObject* self,\n                                                                      PyObject* hook) {\n  HANDLE_ERRORS\n  const auto& _hook = py::cast<AutogradMeta::Hook>(py::reinterpret_borrow<py::object>(hook));\n  ASSERT(RegisterTensorPostGradAccumulationHook(PyTensor_Unpack(self), _hook));\n  Py_RETURN_NONE;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_global_id(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  uint64_t global_id = static_cast<uint64_t>(ASSERT(PyTensor_Unpack(self)->transport_token()));\n  return functional::CastToPyObject(global_id);\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_check_meta_consistency(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  ASSERT(CheckMetaConsistency(PyTensor_Unpack(self)));\n  Py_RETURN_NONE;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_data_ptr(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  const auto& t = PyTensor_Unpack(self);\n  const std::shared_ptr<LocalTensor> local_tensor =\n      t->is_local() ? 
ASSERT_PTR(t->AsLocalTensor()) : ASSERT_PTR(t->cur_rank_phy_tensor());\n  return functional::CastToPyObject(\n      reinterpret_cast<int64_t>(ASSERT(GetTensorDataPtr(local_tensor))));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_to_numpy(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  const auto& t = PyTensor_Unpack(self);\n  DataType data_type = t->dtype()->data_type();\n  switch (data_type) {\n#define SWITCH_EAGER_TENSOR_TO_NUMPY(cpp_type, of_type) \\\n  case of_type: return ASSERT(EagerLocalTensorToNumpy<cpp_type>(self));\n    OF_PP_FOR_EACH_TUPLE(SWITCH_EAGER_TENSOR_TO_NUMPY,\n                         POD_DATA_TYPE_SEQ INT16_DATA_TYPE_SEQ COMPLEX_DATA_TYPE_SEQ)\n    case DataType::kFloat16: return ASSERT(EagerLocalTensorToNumpy<float16>(self));\n    default: {\n      return PyErr_Format(PyExc_RuntimeError,\n                          (\"Invalid datatype \" + DataType_Name(data_type)).data());\n    }\n  }\n#undef SWITCH_EAGER_TENSOR_TO_NUMPY\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_item(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  const auto& t = PyTensor_Unpack(self);\n  DataType data_type = t->dtype()->data_type();\n  switch (data_type) {\n#define CASE_SCALAR_TENSOR_TO_SCALAR(cpp_type, of_type) \\\n  case of_type: return ASSERT(EagerLocalTensorItem<cpp_type>(t));\n    OF_PP_FOR_EACH_TUPLE(CASE_SCALAR_TENSOR_TO_SCALAR,\n                         POD_AND_HALF_DATA_TYPE_SEQ COMPLEX_DATA_TYPE_SEQ);\n    default: {\n      return PyErr_Format(PyExc_RuntimeError,\n                          (\"Invalid datatype \" + DataType_Name(data_type)).data());\n    }\n  }\n#undef CASE_SCALAR_TENSOR_TO_SCALAR\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_type(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  const auto& tensor = PyTensor_Unpack(self);\n  PyObject* tensor_type = NULL;\n  int non_blocking = 0;\n  static const char* keywords[3] = {\"dtype\", \"non_blocking\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"|Op:type\", const_cast<char**>(keywords),\n                                   &tensor_type, &non_blocking)) {\n    return NULL;\n  }\n  // TODO: support non_blocking=True\n  if (non_blocking == 1) {\n    return PyErr_Format(PyExc_TypeError, \"non_blocking=True is not supported yet\");\n  }\n  if (tensor_type == NULL) {\n    tensor_type =\n        PyTensorType_FromDTypeAndDeviceType(tensor->dtype(), ASSERT(tensor->device())->enum_type());\n    return PyUnicode_FromString(((PyTensorType*)tensor_type)->name);\n  }\n  if (PyTensorMetaClass_CheckExact(tensor_type)) {\n    Optional<std::string> device = \"cpu\";\n    return PyTensor_New(ASSERT_PTR(functional::To(tensor, device, DType::Float(), /*copy=*/false)));\n  }\n  if (PyUnicode_Check(tensor_type)) {\n    tensor_type = PyTensorType_FromString(PyUnicode_AsUTF8(tensor_type));\n  }\n  if (PyTensorType_Check(tensor_type)) {\n    const auto& dtype = PyTensorType_UnpackDType(tensor_type);\n    DeviceType device_type = PyTensorType_UnpackDevice(tensor_type);\n    if (device_type == ASSERT(tensor->device())->enum_type()) {\n      return PyTensor_New(ASSERT_PTR(functional::To(tensor, dtype, /*copy=*/false)));\n    }\n    Optional<std::string> device = ASSERT(DeviceTag4DeviceType(device_type));\n    return PyTensor_New(ASSERT_PTR(functional::To(tensor, device, dtype, /*copy=*/false)));\n\n  } else if (functional::PyDTypeCheck(tensor_type)) {\n    return PyTensor_New(\n        ASSERT_PTR(functional::To(tensor, functional::PyUnpackDType(tensor_type), 
/*copy=*/false)));\n  }\n  return PyErr_Format(PyExc_TypeError, \"dtype must be a type, str, or dtype object\");\n  END_HANDLE_ERRORS\n}\n\nnamespace {\nvoid CopyFromNumpyArray(ep::Stream* stream,\n                        const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,\n                        const NumPyArrayPtr& array_ptr) {\n  SyncAutoMemcpy(stream, eager_blob_object->mut_dptr(), array_ptr.data(),\n                 eager_blob_object->ByteSizeOfBlobBody(), eager_blob_object->mem_case(),\n                 memory::MakeHostMemCase());\n}\n\nvoid CopyToNumpyArray(ep::Stream* stream,\n                      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,\n                      const NumPyArrayPtr& array_ptr) {\n  SyncAutoMemcpy(stream, array_ptr.data(), eager_blob_object->dptr(),\n                 eager_blob_object->ByteSizeOfBlobBody(), memory::MakeHostMemCase(),\n                 eager_blob_object->mem_case());\n}\n}  // namespace\n\nstatic PyObject* PyTensorObject__copy_to_numpy(PyObject* self, PyObject* array) {\n  HANDLE_ERRORS\n  ASSERT(CopyBetweenLocalTensorAndNumpy(PyTensor_Unpack(self), array, CopyToNumpyArray, \"const\",\n                                        /*block_host_until_done=*/true));\n  Py_RETURN_NONE;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject__copy_from_numpy(PyObject* self, PyObject* array) {\n  HANDLE_ERRORS\n  auto* copied = PyArray_NewCopy((PyArrayObject*)array, NPY_CORDER);\n  ASSERT(CopyBetweenLocalTensorAndNumpy(PyTensor_Unpack(self), copied, CopyFromNumpyArray, \"mut\",\n                                        /*block_host_until_done=*/false));\n  Py_DECREF(copied);\n  Py_RETURN_NONE;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject__register_storage_delete_hook(PyObject* self, PyObject* hook) {\n  HANDLE_ERRORS\n  auto _hook = py::cast<std::function<void()>>(py::reinterpret_borrow<py::object>(hook));\n  ASSERT(PyTensor_Unpack(self)->RegisterStorageDeleteHook(_hook));\n  Py_RETURN_NONE;\n  END_HANDLE_ERRORS\n}\n\nstatic std::vector<PyMethodDef> concat_method_def(PyMethodDef methods[],\n                                                  PyMethodDef extra_methods[]) {\n  int len1 = 0;\n  int len2 = 0;\n  PyMethodDef* p1 = methods;\n  PyMethodDef* p2 = extra_methods;\n  while ((p1++)->ml_name != NULL) { len1++; }\n  while ((p2++)->ml_name != NULL) { len2++; }\n  std::vector<PyMethodDef> total_methods(len1 + len2 + 1);\n  for (int i = 0; i < len1; i++) total_methods[i] = methods[i];\n  for (int i = 0; i < len2; i++) total_methods[i + len1] = extra_methods[i];\n  total_methods[len1 + len2] = {NULL};\n  return total_methods;\n}\n\nstatic PyMethodDef PyTensorObject_methods[] = {\n    {\"storage_offset\", PyTensorObject_storage_offset, METH_NOARGS, NULL},\n    {\"stride\", PyTensorObject_stride, METH_NOARGS, NULL},\n    {\"is_contiguous\", PyTensorObject_is_contiguous, METH_NOARGS, NULL},\n    {\"is_view\", PyTensorObject_is_view, METH_NOARGS, NULL},\n    {\"contiguous\", PyTensorObject_contiguous, METH_NOARGS, NULL},\n    {\"contiguous_\", PyTensorObject_contiguous_, METH_NOARGS, NULL},\n    {\"pin_memory\", PyTensorObject_pin_memory, METH_NOARGS, NULL},\n    {\"is_pinned\", PyTensorObject_is_pinned, METH_NOARGS, NULL},\n    {\"offload\", PyTensorObject_offload, METH_NOARGS, NULL},\n    {\"load\", PyTensorObject_load, METH_NOARGS, NULL},\n    {\"is_offloaded\", PyTensorObject_is_offloaded, METH_NOARGS, NULL},\n    {\"is_floating_point\", PyTensorObject_is_floating_point, METH_NOARGS, NULL},\n    
{\"requires_grad_\", (PyCFunction)PyTensorObject_requires_grad_, METH_VARARGS | METH_KEYWORDS,\n     NULL},\n    {\"retain_grad\", PyTensorObject_retain_grad, METH_NOARGS, NULL},\n    {\"detach\", PyTensorObject_detach, METH_NOARGS, NULL},\n    {\"clone\", PyTensorObject_clone, METH_NOARGS, NULL},\n    {\"zero_\", PyTensorObject_zero_, METH_NOARGS, NULL},\n    {\"_zero_grad_\", (PyCFunction)PyTensorObject_zero_grad, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"register_hook\", PyTensorObject_register_hook, METH_O, NULL},\n    {\"_register_post_grad_accumulation_hook\", PyTensorObject__register_post_grad_accumulation_hook,\n     METH_O, NULL},\n    {\"global_id\", PyTensorObject_global_id, METH_NOARGS, NULL},\n    {\"check_meta_consistency\", PyTensorObject_check_meta_consistency, METH_NOARGS, NULL},\n    {\"to_numpy\", PyTensorObject_to_numpy, METH_NOARGS, NULL},\n    {\"data_ptr\", PyTensorObject_data_ptr, METH_NOARGS, NULL},\n    {\"item\", PyTensorObject_item, METH_NOARGS, NULL},\n    {\"type\", (PyCFunction)PyTensorObject_type, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"_copy_to_numpy\", PyTensorObject__copy_to_numpy, METH_O, NULL},\n    {\"_copy_from_numpy\", PyTensorObject__copy_from_numpy, METH_O, NULL},\n    {\"_register_storage_delete_hook\", PyTensorObject__register_storage_delete_hook, METH_O, NULL},\n    {NULL}};\n\nstatic PyObject* PyTensorObject_ndim(PyObject* self, void* unused) {\n  return functional::CastToPyObject(PyTensor_Unpack(self)->ndim());\n}\n\nstatic PyObject* PyTensorObject_shape(PyObject* self, void* unused) {\n  return functional::CastToPyObject(PyTensor_Unpack(self)->shape());\n}\n\nstatic PyObject* PyTensorObject_dtype(PyObject* self, void* unused) {\n  HANDLE_ERRORS\n  const Symbol<DType>* dtype = &ASSERT(DType::Get(PyTensor_Unpack(self)->dtype()->data_type()));\n  return functional::CastToPyObject(dtype);\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_is_cpu(PyObject* self, void* unused) {\n  return functional::CastToPyObject(PyTensor_Unpack(self)->is_cpu());\n}\n\nstatic PyObject* PyTensorObject_is_cuda(PyObject* self, void* unused) {\n  return functional::CastToPyObject(PyTensor_Unpack(self)->is_cuda());\n}\n\nstatic PyObject* PyTensorObject_grad(PyObject* self, void* unused) {\n  HANDLE_ERRORS\n  return PyTensor_New(ASSERT_PTR(PyTensor_Unpack(self)->acc_grad()));\n  END_HANDLE_ERRORS\n}\n\nstatic int PyTensorObject_set_grad(PyObject* self, PyObject* grad, void* unused) {\n  HANDLE_ERRORS\n  const auto& t = PyTensor_Unpack(self);\n  if (self == grad) { PyErr_Format(PyExc_RuntimeError, \"can't assign Tensor as its own grad\"); }\n  if (grad && grad != Py_None) {\n    ASSERT(t->set_acc_grad(ASSERT_PTR(PyTensor_Unpack(grad)->detach())));\n  } else {\n    ASSERT(t->set_acc_grad(NULL));\n  }\n  return 0;\n  END_HANDLE_ERRORS_RET(-1)\n}\n\nstatic PyObject* PyTensorObject_data(PyObject* self, void* unused) {\n  HANDLE_ERRORS\n  return PyTensor_New(ASSERT_PTR(PyTensor_Unpack(self)->data()));\n  END_HANDLE_ERRORS\n}\n\nstatic int PyTensorObject_set_data(PyObject* self, PyObject* data, void* unused) {\n  HANDLE_ERRORS\n  const auto& t = PyTensor_Unpack(self);\n  auto hooks = t->autograd_meta()->hooks();\n  ASSERT(t->set_data(PyTensor_Unpack(data)));\n  // Re-register hooks\n  for (const auto& hook : hooks) { ASSERT(RegisterTensorHook(t, hook)); }\n  return 0;\n  END_HANDLE_ERRORS_RET(-1)\n}\n\nstatic PyObject* PyTensorObject_ref_tensor(PyObject* self, void* unused) {\n  HANDLE_ERRORS\n  return 
PyTensor_New(ASSERT_PTR(PyTensor_Unpack(self)->ref_tensor()));\n  END_HANDLE_ERRORS\n}\n\nstatic int PyTensorObject_set_ref_tensor(PyObject* self, PyObject* ref, void* unused) {\n  HANDLE_ERRORS\n  const auto& t = PyTensor_Unpack(self);\n  if (self == ref) { PyErr_Format(PyExc_RuntimeError, \"can't assign Tensor as its own reference\"); }\n  if (ref && ref != Py_None) {\n    ASSERT(t->set_ref_tensor(PyTensor_Unpack(ref)));\n  } else {\n    ASSERT(t->set_ref_tensor(NULL));\n  }\n  return 0;\n  END_HANDLE_ERRORS_RET(-1)\n}\n\nstatic PyObject* PyTensorObject_ref_index(PyObject* self, void* unused) {\n  return functional::CastToPyObject(PyTensor_Unpack(self)->ref_index());\n}\n\nstatic int PyTensorObject_set_ref_index(PyObject* self, PyObject* index, void* unused) {\n  HANDLE_ERRORS\n  const auto& t = PyTensor_Unpack(self);\n  CHECK_OR_THROW(PyLong_Check(index)) << Error::RuntimeError() << \"Index must be Integer type.\";\n  ASSERT(t->set_ref_index(PyLong_AsLong(index)));\n  return 0;\n  END_HANDLE_ERRORS_RET(-1)\n}\n\nstatic PyObject* PyTensorObject_grad_fn(PyObject* self, void* unused) {\n  return functional::CastToPyObject(PyTensor_Unpack(self)->grad_fn_node());\n}\n\nstatic PyObject* PyTensorObject_is_leaf(PyObject* self, void* unused) {\n  return functional::CastToPyObject(PyTensor_Unpack(self)->is_leaf());\n}\n\nstatic PyObject* PyTensorObject_requires_grad(PyObject* self, void* unused) {\n  return functional::CastToPyObject(PyTensor_Unpack(self)->requires_grad());\n}\n\nstatic int PyTensorObject_set_requires_grad(PyObject* self, PyObject* requires_grad, void* unused) {\n  HANDLE_ERRORS\n  const auto& t = PyTensor_Unpack(self);\n  CHECK_OR_THROW(t->is_leaf()) << Error::RuntimeError()\n                               << \"You can only change requires_grad flags of leaf tensors.\";\n  ASSERT(t->set_requires_grad(requires_grad == Py_True));\n  return 0;\n  END_HANDLE_ERRORS_RET(-1)\n}\n\nstatic PyObject* PyTensorObject_is_lazy(PyObject* self, void* unused) {\n  return functional::CastToPyObject(PyTensor_Unpack(self)->is_lazy());\n}\n\nstatic PyObject* PyTensorObject_is_eager(PyObject* self, void* unused) {\n  return functional::CastToPyObject(PyTensor_Unpack(self)->is_eager());\n}\n\nstatic PyObject* PyTensorObject_is_global(PyObject* self, void* unused) {\n  return functional::CastToPyObject(PyTensor_Unpack(self)->is_global());\n}\n\nstatic PyObject* PyTensorObject_is_local(PyObject* self, void* unused) {\n  return functional::CastToPyObject(PyTensor_Unpack(self)->is_local());\n}\n\nstatic PyObject* PyTensorObject__tensor_buffer_shapes_and_dtypes(PyObject* self, void* unused) {\n  HANDLE_ERRORS\n  return functional::CastToPyObject(MaybeGetTensorBufferShapesAndDTypes(PyTensor_Unpack(self)));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_device(PyObject* self, void* unused) {\n  HANDLE_ERRORS\n  return functional::CastToPyObject(PyTensor_Unpack(self)->device());\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_placement(PyObject* self, void* unused) {\n  HANDLE_ERRORS\n  return functional::CastToPyObject(PyTensor_Unpack(self)->parallel_desc());\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_sbp(PyObject* self, void* unused) {\n  HANDLE_ERRORS\n  return functional::CastToPyObject(TensorGetPyTupleOfSbp(*PyTensor_Unpack(self)));\n  END_HANDLE_ERRORS\n}\n\n// NOLINTNEXTLINE\nstatic PyGetSetDef PyTensorObject_properties[] = {\n    {PYGETSET_NAME(\"ndim\"), (getter)PyTensorObject_ndim, NULL, NULL, NULL},\n    {PYGETSET_NAME(\"shape\"), 
(getter)PyTensorObject_shape, NULL, NULL, NULL},\n    {PYGETSET_NAME(\"dtype\"), (getter)PyTensorObject_dtype, NULL, NULL, NULL},\n    {PYGETSET_NAME(\"is_cpu\"), (getter)PyTensorObject_is_cpu, NULL, NULL, NULL},\n    {PYGETSET_NAME(\"is_cuda\"), (getter)PyTensorObject_is_cuda, NULL, NULL, NULL},\n    {PYGETSET_NAME(\"grad\"), (getter)PyTensorObject_grad, (setter)PyTensorObject_set_grad, NULL,\n     NULL},\n    {PYGETSET_NAME(\"data\"), (getter)PyTensorObject_data, (setter)PyTensorObject_set_data, NULL,\n     NULL},\n    {PYGETSET_NAME(\"_ref_tensor\"), (getter)PyTensorObject_ref_tensor,\n     (setter)PyTensorObject_set_ref_tensor, NULL, NULL},\n    {PYGETSET_NAME(\"_ref_index\"), (getter)PyTensorObject_ref_index,\n     (setter)PyTensorObject_set_ref_index, NULL, NULL},\n    {PYGETSET_NAME(\"grad_fn\"), (getter)PyTensorObject_grad_fn, NULL, NULL, NULL},\n    {PYGETSET_NAME(\"is_leaf\"), (getter)PyTensorObject_is_leaf, NULL, NULL, NULL},\n    {PYGETSET_NAME(\"requires_grad\"), (getter)PyTensorObject_requires_grad,\n     (setter)PyTensorObject_set_requires_grad, NULL, NULL},\n    {PYGETSET_NAME(\"is_lazy\"), (getter)PyTensorObject_is_lazy, NULL, NULL, NULL},\n    {PYGETSET_NAME(\"is_eager\"), (getter)PyTensorObject_is_eager, NULL, NULL, NULL},\n    {PYGETSET_NAME(\"is_global\"), (getter)PyTensorObject_is_global, NULL, NULL, NULL},\n    {PYGETSET_NAME(\"is_local\"), (getter)PyTensorObject_is_local, NULL, NULL, NULL},\n    {PYGETSET_NAME(\"_tensor_buffer_shapes_and_dtypes\"),\n     (getter)PyTensorObject__tensor_buffer_shapes_and_dtypes, NULL, NULL, NULL},\n    {PYGETSET_NAME(\"device\"), (getter)PyTensorObject_device, NULL, NULL, NULL},\n    {PYGETSET_NAME(\"placement\"), (getter)PyTensorObject_placement, NULL, NULL, NULL},\n    {PYGETSET_NAME(\"sbp\"), (getter)PyTensorObject_sbp, NULL, NULL, NULL},\n    {NULL}};\n\n// create a Tensor instance\nstatic PyObject* TensorMetaCls_call(PyObject* type, PyObject* args, PyObject* kwargs) {\n  return PyType_Type.tp_call(type, args, kwargs);\n}\n\nstatic void TensorMetaCls_dealloc(PyObject* type) { PyType_Type.tp_dealloc(type); }\n\nstatic PyHeapTypeObject* MakeTensorMetaclass() {\n  PyObject* name = PyUnicode_FromString(\"_TensorMeta\");\n\n  auto* heap_type = (PyHeapTypeObject*)PyType_Type.tp_alloc(&PyType_Type, 0);\n  heap_type->ht_name = name;\n  heap_type->ht_qualname = PY_XINCREF(name);\n\n  auto* type = &heap_type->ht_type;\n  type->tp_name = \"_TensorMeta\";\n  type->tp_base = PY_XINCREF(&PyType_Type);\n  type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;\n\n  type->tp_call = TensorMetaCls_call;\n  type->tp_dealloc = TensorMetaCls_dealloc;\n\n  if (PyType_Ready(type) < 0) { return NULL; }\n  PyObject_SetAttrString((PyObject*)type, \"__module__\", PyUnicode_FromString(\"oneflow._C\"));\n  return heap_type;\n}\n\nextern PyNumberMethods PyTensorObject_as_number;\nextern PyObject* PyTensorObject_richcompare(PyObject*, PyObject*, int);\nextern PyMethodDef PyTensorObject_extra_methods[];\n\nstatic PyHeapTypeObject* TensorMetaclass_Type = MakeTensorMetaclass();\n\nstatic PyTypeObject* MakeTensorType() {\n  PyObject* name = PyUnicode_FromString(\"Tensor\");\n\n  auto* metaclass = &TensorMetaclass_Type->ht_type;\n  auto* heap_type = (PyHeapTypeObject*)metaclass->tp_alloc(metaclass, 0);\n  if (!heap_type) { return NULL; }\n  heap_type->ht_name = name;\n  heap_type->ht_qualname = PY_XINCREF(name);\n  auto* type = &heap_type->ht_type;\n  type->tp_name = \"Tensor\";\n  type->tp_basicsize = sizeof(PyTensorObject);\n\n  
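// Wire up the Python type slots for oneflow.Tensor. Note that tp_methods is
\n  // assembled once below by concatenating PyTensorObject_methods with
\n  // PyTensorObject_extra_methods (defined in tensor_functions.cpp).
\n  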
type->tp_init = PyTensorObject_init;\n  type->tp_dealloc = PyTensorObject_dealloc;\n  type->tp_getset = PyTensorObject_properties;\n\n  static std::vector<PyMethodDef> total_methods =\n      concat_method_def(PyTensorObject_methods, PyTensorObject_extra_methods);\n  type->tp_methods = total_methods.data();\n\n  type->tp_as_number = &PyTensorObject_as_number;\n  type->tp_as_sequence = &PyTensorObject_as_sequence;\n  type->tp_as_mapping = &PyTensorObject_as_mapping;\n  type->tp_richcompare = PyTensorObject_richcompare;\n  type->tp_hash = (hashfunc)_Py_HashPointer;\n\n  type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;\n\n  if (PyType_Ready(type) < 0) { return NULL; }\n  PyObject_SetAttrString((PyObject*)type, \"__module__\", PyUnicode_FromString(\"oneflow\"));\n  return type;\n}\n\nstatic PyTypeObject* MakeParameterType() {\n  PyObject* name = PyUnicode_FromString(\"Parameter\");\n\n  auto* metaclass = &TensorMetaclass_Type->ht_type;\n  auto* heap_type = (PyHeapTypeObject*)metaclass->tp_alloc(metaclass, 0);\n  if (!heap_type) { return NULL; }\n  heap_type->ht_name = name;\n  heap_type->ht_qualname = PY_XINCREF(name);\n  auto* type = &heap_type->ht_type;\n  type->tp_name = \"Parameter\";\n  type->tp_basicsize = sizeof(PyTensorObject);\n\n  type->tp_init = PyParameterObject_init;\n\n  type->tp_base = PY_XINCREF(PyTensorObject_Type);\n\n  type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;\n\n  if (PyType_Ready(type) < 0) { return NULL; }\n  PyObject_SetAttrString((PyObject*)type, \"__module__\", PyUnicode_FromString(\"oneflow.nn\"));\n  return type;\n}\n\nPyObject* PyTensor_New(const std::shared_ptr<Tensor>& data) {\n  return PyTensor_wrap<Tensor>(data, /*bind_pyobj=*/nullptr);\n}\n\nPyObject* PyParameter_New(const std::shared_ptr<Parameter>& data) {\n  return PyTensor_wrap<Parameter>(data, /*bind_pyobj=*/nullptr);\n}\n\nPyObject* PyParameter_New(const std::shared_ptr<Tensor>& data, bool requires_grad) {\n  if (!data) { Py_RETURN_NONE; }\n  return PyTensor_wrap<Parameter>(ASSERT_PTR(Parameter::MakeTensor(data, requires_grad)),\n                                  /*bind_pyobj=*/nullptr);\n}\n\n}  // namespace one\n}  // namespace oneflow\n\n#undef ASSERT\n#undef ASSERT_PTR\n\nusing namespace oneflow::one;\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  PyTensorObject_Type = MakeTensorType();\n  PyParameterObject_Type = MakeParameterType();\n  if (PyTensorObject_Type\n      && PyModule_AddObject(m.ptr(), \"Tensor\", (PyObject*)PyTensorObject_Type) < 0) {\n    return;\n  }\n  auto nn = m.def_submodule(\"nn\");\n  if (PyParameterObject_Type\n      && PyModule_AddObject(nn.ptr(), \"Parameter\", (PyObject*)PyParameterObject_Type) < 0) {\n    return;\n  }\n}\n"
  },
  {
    "path": "oneflow/api/python/framework/tensor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_FRAMEWORK_TENSOR_H_\n#define ONEFLOW_API_PYTHON_FRAMEWORK_TENSOR_H_\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n\n#include \"oneflow/core/framework/tensor.h\"\n\nnamespace oneflow {\nnamespace one {\n\ntypedef struct {\n  PyObject_HEAD;\n  std::shared_ptr<Tensor> data;\n} PyTensorObject;\n\nextern PyTypeObject* PyTensorObject_Type;\nextern PyTypeObject* PyParameterObject_Type;\n\ninline bool PyTensorMetaClass_CheckExact(PyObject* obj) {\n  return obj == (PyObject*)PyTensorObject_Type;\n}\n\ninline bool PyTensor_Check(PyObject* op) { return PyObject_TypeCheck(op, PyTensorObject_Type); }\n\ninline bool PyTensor_CheckExact(PyObject* op) {\n  return op->ob_type == PyTensorObject_Type || op->ob_type == PyParameterObject_Type;\n}\n\ninline std::shared_ptr<Tensor>& PyTensor_Unpack(PyObject* op) {\n  assert(PyTensor_Check(op));\n  return ((PyTensorObject*)op)->data;\n}\n\nPyObject* PyTensor_New(const std::shared_ptr<Tensor>& data);\nPyObject* PyParameter_New(const std::shared_ptr<Parameter>& data);\nPyObject* PyParameter_New(const std::shared_ptr<Tensor>& data, bool requires_grad);\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_FRAMEWORK_TENSOR_H_\n"
  },
  {
    "path": "oneflow/api/python/framework/tensor_functions.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include \"oneflow/api/python/exception/exception.h\"\n#include \"oneflow/api/python/framework/size.h\"\n#include \"oneflow/api/python/framework/tensor_functions_util.h\"\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/api/python/functional/functional_api.yaml.pybind.h\"\n#include \"oneflow/api/python/functional/tensor_api.yaml.pybind.h\"\n#include \"oneflow/core/common/shape_vec.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/wrap_dim_utils.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/api/python/functional/tensor_api.yaml.h\"\n#include \"oneflow/extension/python/numpy.h\"\n#include \"oneflow/api/python/utils/tensor_utils.h\"\n\nnamespace oneflow {\nnamespace one {\n\n#define ASSERT(x) (x).GetOrThrow()\n#define ASSERT_PTR(x) (x).GetPtrOrThrow()\n\nusing functional::PyObjectPtr;\nnamespace {\nPyObject* concat_self(PyObject* self, PyObject* args) {\n  PyObjectPtr self_tuple(PyTuple_Pack(1, self));\n  PyObject* tuple = PySequence_Concat(self_tuple.get(), args);\n  CHECK_OR_THROW(tuple != NULL);\n  return tuple;\n}\n\nPyObject* ndarray_judgment_and_compatibility(PyObject* self, PyObject* other) {\n  if (PyArray_Check(other)) {\n    const auto& tensor = PyTensor_Unpack(self);\n    CHECK_OR_THROW(tensor->is_cpu())\n        << Error::RuntimeError() << \"Can't convert non-cpu device tensor to numpy\";\n    if (tensor->is_global()) {\n      Symbol<ParallelDesc> placement = ASSERT(tensor->parallel_desc());\n      auto ndsbp = ASSERT(tensor->nd_sbp());\n      std::vector<Symbol<SbpParallel>> sbp(ndsbp->sbp_parallel_size(),\n                                           ASSERT(MakeBroadcastSbpParallel()));\n      other = functional::CastToPyObject(MakeGlobalTensorFromData(other, tensor->dtype(), placement,\n                                                                  sbp, /*requires_grad=*/false));\n    } else {\n      other = functional::CastToPyObject(functional::LocalTensorSharedNumpyData(other));\n    }\n  }\n  return other;\n}\n\n}  // namespace\n\n#define NB_UNARY_FUNC(func_name, bind_func)                  \\\n  static PyObject* func_name(PyObject* self) {               \\\n    HANDLE_ERRORS                                            \\\n    PyObjectPtr tuple(PyTuple_Pack(1, self));                \\\n    auto* result = bind_func(NULL, tuple.get(), NULL);       \\\n    if (PyErr_Occurred()) { throw py::error_already_set(); } \\\n    return result;                                           \\\n    END_HANDLE_ERRORS                                        \\\n  }\n\n#define NB_BINARY_FUNC(func_name, bind_func)                 \\\n  static PyObject* func_name(PyObject* a, PyObject* b) {     \\\n    HANDLE_ERRORS                                            \\\n    b = 
ndarray_judgment_and_compatibility(a, b);            \\
\n    PyObjectPtr tuple(PyTuple_Pack(2, a, b));                \\
\n    auto* result = bind_func(NULL, tuple.get(), NULL);       \\
\n    if (PyErr_Occurred()) { throw py::error_already_set(); } \\
\n    return result;                                           \\
\n    END_HANDLE_ERRORS                                        \\
\n  }
\n\nNB_UNARY_FUNC(PyTensorObject_nb_absolute, functional::abs);
\nNB_UNARY_FUNC(PyTensorObject_nb_negative, functional::negative);
\nNB_UNARY_FUNC(PyTensorObject_nb_invert, functional::bitwise_not);
\n// TODO: not implemented yet
\n// NB_UNARY_FUNC(PyTensorObject_positive, functional::positive);
\n\nNB_BINARY_FUNC(PyTensorObject_nb_add, functional::add);
\nNB_BINARY_FUNC(PyTensorObject_nb_sub, functional::sub);
\nNB_BINARY_FUNC(PyTensorObject_nb_mul, functional::mul);
\nNB_BINARY_FUNC(PyTensorObject_nb_fmod, functional::fmod);
\nNB_BINARY_FUNC(PyTensorObject_nb_div, functional::div);
\nNB_BINARY_FUNC(PyTensorObject_nb_and, functional::logical_and);
\nNB_BINARY_FUNC(PyTensorObject_nb_xor, functional::logical_xor);
\nNB_BINARY_FUNC(PyTensorObject_nb_or, functional::logical_or);
\nNB_BINARY_FUNC(PyTensorObject_nb_floor_div, functional::floor_divide);
\nNB_BINARY_FUNC(PyTensorObject_nb_true_div, functional::div);
\nNB_BINARY_FUNC(PyTensorObject_nb_matrix_multiply, functional::matmul);
\n\nstatic PyObject* PyTensorObject_nb_pow(PyObject* a, PyObject* b, PyObject* unused) {
\n  HANDLE_ERRORS
\n  b = ndarray_judgment_and_compatibility(a, b);
\n  PyObjectPtr tuple(PyTuple_Pack(2, a, b));
\n  PyObject* result = functional::pow(NULL, tuple.get(), NULL);
\n  if (PyErr_Occurred()) { throw py::error_already_set(); }
\n  return result;
\n  END_HANDLE_ERRORS\n}
\n\n#define NB_INPLACE_BINARY_FUNC(func_name, bind_func)                           \\
\n  static PyObject* func_name(PyObject* a, PyObject* b) {                       \\
\n    HANDLE_ERRORS                                                              \\
\n    b = ndarray_judgment_and_compatibility(a, b);                              \\
\n    PyObjectPtr tuple(PyTuple_Pack(2, a, b));                                  \\
\n    PyObjectPtr dict(PyDict_New());                                            \\
\n    CHECK_OR_THROW(PyDict_SetItemString(dict.get(), \"inplace\", Py_True) > -1); \\
\n    PyObject* result = bind_func(NULL, tuple.get(), dict.get());               \\
\n    if (PyErr_Occurred()) { throw py::error_already_set(); }                   \\
\n    return result;                                                             \\
\n    END_HANDLE_ERRORS                                                          \\
\n  }
\n\n// inplace operators
\nNB_INPLACE_BINARY_FUNC(PyTensorObject_nb_inplace_add, functional::add);
\nNB_INPLACE_BINARY_FUNC(PyTensorObject_nb_inplace_sub, functional::sub);
\n// The inplace-mul interface is mul_ rather than mul(*, inplace=True)
\nNB_BINARY_FUNC(PyTensorObject_nb_inplace_mul, functional::mul_);
\nNB_BINARY_FUNC(PyTensorObject_nb_inplace_true_div, functional::div_);
\n\nPyObject* PyTensorObject_nb_inplace_pow(PyObject* a, PyObject* b, PyObject* unused) {
\n  HANDLE_ERRORS
\n  PyObjectPtr tuple(PyTuple_Pack(2, a, b));
\n  PyObjectPtr dict(PyDict_New());
\n  CHECK_OR_THROW(PyDict_SetItemString(dict.get(), \"inplace\", Py_True) > -1);
\n  auto* result = functional::pow(NULL, tuple.get(), dict.get());
\n  if (PyErr_Occurred()) { throw py::error_already_set(); }
\n  return result;
\n  END_HANDLE_ERRORS\n}
\n\nPyNumberMethods PyTensorObject_as_number = {
\n    PyTensorObject_nb_add,       
// nb_add\n    PyTensorObject_nb_sub,       // nb_subtract\n    PyTensorObject_nb_mul,       // nb_multiply\n    PyTensorObject_nb_fmod,      // nb_remainder\n    NULL,                        // nb_divmod\n    PyTensorObject_nb_pow,       // nb_power\n    PyTensorObject_nb_negative,  // nb_negative\n    NULL,                        // nb_positive\n    PyTensorObject_nb_absolute,  // nb_absolute\n    NULL,                        // nb_bool\n    PyTensorObject_nb_invert,    // nb_invert\n    NULL,                        // nb_lshift\n    NULL,                        // nb_rshift\n    PyTensorObject_nb_and,       // nb_and\n    PyTensorObject_nb_xor,       // nb_xor\n    PyTensorObject_nb_or,        // nb_or\n    NULL,                        // nb_int\n    NULL,                        // nb_reserved\n    NULL,                        // nb_float\n\n    PyTensorObject_nb_inplace_add,  // nb_inplace_add\n    PyTensorObject_nb_inplace_sub,  // nb_inplace_sub\n    PyTensorObject_nb_inplace_mul,  // nb_inplace_mul\n    NULL,                           // nb_inplace_remainder\n    PyTensorObject_nb_inplace_pow,  // nb_inplace_pow\n    NULL,                           // nb_inplace_lshift\n    NULL,                           // nb_inplace_rshift\n    NULL,                           // nb_inplace_and\n    NULL,                           // nb_inplace_xor\n    NULL,                           // nb_inplace_or\n\n    PyTensorObject_nb_floor_div,         // nb_floor_div\n    PyTensorObject_nb_true_div,          // nb_true_div\n    NULL,                                // nb_inplace_floor_div\n    PyTensorObject_nb_inplace_true_div,  // nb_inplace_true_div\n\n    NULL,                               // nb_index\n    PyTensorObject_nb_matrix_multiply,  // nb_matrix_multiply\n    NULL,                               // nb_inplace_matrix_multiply\n\n};\n\n// extra methods\n\n// functions that accept only one Tensor\n#define UNARY_METHOD(func_name, bind_func)                             \\\n  static PyObject* func_name(PyObject* self, PyObject* unused) {       \\\n    HANDLE_ERRORS                                                      \\\n    return PyTensor_New(ASSERT_PTR(bind_func(PyTensor_Unpack(self)))); \\\n    END_HANDLE_ERRORS                                                  \\\n  }\n\nUNARY_METHOD(PyTensorObject_abs, functional::Abs);\nUNARY_METHOD(PyTensorObject_digamma, functional::Digamma);\nUNARY_METHOD(PyTensorObject_exp, functional::Exp);\nUNARY_METHOD(PyTensorObject_exp2, functional::Exp2);\nUNARY_METHOD(PyTensorObject_floor, functional::Floor);\nUNARY_METHOD(PyTensorObject_floor_, functional::Floor_);\nUNARY_METHOD(PyTensorObject_sign, functional::Sign);\nUNARY_METHOD(PyTensorObject_gelu, functional::Gelu);\nUNARY_METHOD(PyTensorObject_mish, functional::Mish);\nUNARY_METHOD(PyTensorObject_negative, functional::Negative);\nUNARY_METHOD(PyTensorObject_sigmoid, functional::Sigmoid);\nUNARY_METHOD(PyTensorObject_silu, functional::Silu);\nUNARY_METHOD(PyTensorObject_selu, functional::Selu);\nUNARY_METHOD(PyTensorObject_softsign, functional::SoftSign);\nUNARY_METHOD(PyTensorObject_log1p, functional::Log1p);\nUNARY_METHOD(PyTensorObject_log2, functional::Log2);\nUNARY_METHOD(PyTensorObject_log10, functional::Log10);\nUNARY_METHOD(PyTensorObject_reciprocal, functional::Reciprocal);\nUNARY_METHOD(PyTensorObject_ceil, functional::Ceil);\nUNARY_METHOD(PyTensorObject_ceil_, functional::Ceil_);\nUNARY_METHOD(PyTensorObject_erf, functional::Erf);\nUNARY_METHOD(PyTensorObject_erfc, 
functional::Erfc);\nUNARY_METHOD(PyTensorObject_erfinv, functional::Erfinv);\nUNARY_METHOD(PyTensorObject_erfinv_, functional::ErfinvInplace);\nUNARY_METHOD(PyTensorObject_expm1, functional::Expm1);\nUNARY_METHOD(PyTensorObject_log, functional::Log);\nUNARY_METHOD(PyTensorObject_rsqrt, functional::Rsqrt);\nUNARY_METHOD(PyTensorObject_sqrt, functional::Sqrt);\nUNARY_METHOD(PyTensorObject_square, functional::Square);\nUNARY_METHOD(PyTensorObject_round, functional::Round);\nUNARY_METHOD(PyTensorObject_round_, functional::Round_);\nUNARY_METHOD(PyTensorObject_t, functional::TransposeAllDimFunction);\nUNARY_METHOD(PyTensorObject_isnan, functional::IsNan);\nUNARY_METHOD(PyTensorObject_isinf, functional::IsInf);\nUNARY_METHOD(PyTensorObject_sin, functional::Sin);\nUNARY_METHOD(PyTensorObject_sin_, functional::Sin_);\nUNARY_METHOD(PyTensorObject_asin, functional::Asin);\nUNARY_METHOD(PyTensorObject_cos, functional::Cos);\nUNARY_METHOD(PyTensorObject_acos, functional::Acos);\nUNARY_METHOD(PyTensorObject_tan, functional::Tan);\nUNARY_METHOD(PyTensorObject_atan, functional::Atan);\nUNARY_METHOD(PyTensorObject_sinh, functional::Sinh);\nUNARY_METHOD(PyTensorObject_asinh, functional::Asinh);\nUNARY_METHOD(PyTensorObject_cosh, functional::Cosh);\nUNARY_METHOD(PyTensorObject_acosh, functional::Acosh);\nUNARY_METHOD(PyTensorObject_tanh, functional::Tanh);\nUNARY_METHOD(PyTensorObject_atanh, functional::Atanh);\nUNARY_METHOD(PyTensorObject_logical_not, functional::LogicalNot);\nUNARY_METHOD(PyTensorObject_bitwise_not, functional::BitwiseNot);\nUNARY_METHOD(PyTensorObject_inv, functional::Inv);\nUNARY_METHOD(PyTensorObject_trunc, functional::Trunc);\n// functions that directly pass arguments without parsing\n#define DIRECT_PASS_FUNC(func_name, bind_func)                                   \\\n  static PyObject* func_name(PyObject* self, PyObject* args, PyObject* kwargs) { \\\n    HANDLE_ERRORS                                                                \\\n    PyObjectPtr concat_args(concat_self(self, args));                            \\\n    PyObject* result = bind_func(NULL, concat_args.get(), kwargs);               \\\n    if (PyErr_Occurred()) { throw py::error_already_set(); }                     \\\n    return result;                                                               \\\n    END_HANDLE_ERRORS                                                            \\\n  }\n\nDIRECT_PASS_FUNC(PyTensorObject_floor_divide, functional::floor_divide)\nDIRECT_PASS_FUNC(PyTensorObject_atan2, functional::atan2)\nDIRECT_PASS_FUNC(PyTensorObject_gt, functional::greater)\nDIRECT_PASS_FUNC(PyTensorObject_gt_, functional::greater_)\nDIRECT_PASS_FUNC(PyTensorObject_frac, functional::frac)\nDIRECT_PASS_FUNC(PyTensorObject_frac_, functional::frac_)\nDIRECT_PASS_FUNC(PyTensorObject_ge, functional::greater_equal)\nDIRECT_PASS_FUNC(PyTensorObject_div, functional::div)\nDIRECT_PASS_FUNC(PyTensorObject_div_, functional::div_)\nDIRECT_PASS_FUNC(PyTensorObject_mul, functional::mul)\nDIRECT_PASS_FUNC(PyTensorObject_mul_, functional::mul_)\nDIRECT_PASS_FUNC(PyTensorObject_fmod, functional::fmod)\nDIRECT_PASS_FUNC(PyTensorObject_logical_and, functional::logical_and)\nDIRECT_PASS_FUNC(PyTensorObject_logical_or, functional::logical_or)\nDIRECT_PASS_FUNC(PyTensorObject_logical_xor, functional::logical_xor)\nDIRECT_PASS_FUNC(PyTensorObject_equal, functional::equal)\nDIRECT_PASS_FUNC(PyTensorObject_ne, functional::not_equal)\nDIRECT_PASS_FUNC(PyTensorObject_lt, functional::less)\nDIRECT_PASS_FUNC(PyTensorObject_le, 
functional::less_equal)\nDIRECT_PASS_FUNC(PyTensorObject_bmm, functional::batch_matmul)\nDIRECT_PASS_FUNC(PyTensorObject_argmax, functional::argmax)\nDIRECT_PASS_FUNC(PyTensorObject_argmin, functional::argmin)\nDIRECT_PASS_FUNC(PyTensorObject_amin, functional::amin)\nDIRECT_PASS_FUNC(PyTensorObject_amax, functional::amax)\nDIRECT_PASS_FUNC(PyTensorObject_addcmul, functional::addcmul)\nDIRECT_PASS_FUNC(PyTensorObject_addcmul_, functional::addcmul_)\nDIRECT_PASS_FUNC(PyTensorObject_addcdiv, functional::addcdiv)\nDIRECT_PASS_FUNC(PyTensorObject_addcdiv_, functional::addcdiv_)\nDIRECT_PASS_FUNC(PyTensorObject_flip, functional::flip)\nDIRECT_PASS_FUNC(PyTensorObject_clip, functional::clip)\nDIRECT_PASS_FUNC(PyTensorObject_clip_, functional::clip_)\nDIRECT_PASS_FUNC(PyTensorObject_clamp, functional::clamp)\nDIRECT_PASS_FUNC(PyTensorObject_clamp_min, functional::clamp_min)\nDIRECT_PASS_FUNC(PyTensorObject_clamp_max, functional::clamp_max)\nDIRECT_PASS_FUNC(PyTensorObject_clamp_, functional::clamp_)\nDIRECT_PASS_FUNC(PyTensorObject_clamp_min_, functional::clamp_min_)\nDIRECT_PASS_FUNC(PyTensorObject_clamp_max_, functional::clamp_max_)\nDIRECT_PASS_FUNC(PyTensorObject_flatten, functional::flatten)\nDIRECT_PASS_FUNC(PyTensorObject_in_top_k, functional::in_top_k)\nDIRECT_PASS_FUNC(PyTensorObject_index_select, functional::index_select)\nDIRECT_PASS_FUNC(PyTensorObject_logsumexp, functional::logsumexp)\nDIRECT_PASS_FUNC(PyTensorObject_maximum, functional::maximum)\nDIRECT_PASS_FUNC(PyTensorObject_minimum, functional::minimum)\nDIRECT_PASS_FUNC(PyTensorObject_tril, functional::tril)\nDIRECT_PASS_FUNC(PyTensorObject_tril_, functional::tril_)\nDIRECT_PASS_FUNC(PyTensorObject_triu, functional::triu)\nDIRECT_PASS_FUNC(PyTensorObject_triu_, functional::triu_)\nDIRECT_PASS_FUNC(PyTensorObject_softmax, functional::softmax)\nDIRECT_PASS_FUNC(PyTensorObject_log_softmax, functional::log_softmax)\nDIRECT_PASS_FUNC(PyTensorObject_roll, functional::roll)\nDIRECT_PASS_FUNC(PyTensorObject_unbind, functional::unbind)\nDIRECT_PASS_FUNC(PyTensorObject_squeeze, functional::squeeze)\nDIRECT_PASS_FUNC(PyTensorObject_swapaxes, functional::swapaxes)\nDIRECT_PASS_FUNC(PyTensorObject_swapdims, functional::swapdims)\nDIRECT_PASS_FUNC(PyTensorObject_unfold, functional::unfold_tensor)\nDIRECT_PASS_FUNC(PyTensorObject_unsqueeze, functional::unsqueeze)\nDIRECT_PASS_FUNC(PyTensorObject_max, functional::max)\nDIRECT_PASS_FUNC(PyTensorObject_min, functional::min)\nDIRECT_PASS_FUNC(PyTensorObject_median, functional::median)\nDIRECT_PASS_FUNC(PyTensorObject_mode, functional::mode)\nDIRECT_PASS_FUNC(PyTensorObject_pow, functional::pow)\nDIRECT_PASS_FUNC(PyTensorObject_chunk, functional::chunk)\nDIRECT_PASS_FUNC(PyTensorObject_split, functional::split)\nDIRECT_PASS_FUNC(PyTensorObject_narrow, functional::narrow)\nDIRECT_PASS_FUNC(PyTensorObject_masked_fill, functional::masked_fill)\nDIRECT_PASS_FUNC(PyTensorObject_masked_fill_, functional::masked_fill_)\nDIRECT_PASS_FUNC(PyTensorObject_dot, functional::dot)\nDIRECT_PASS_FUNC(PyTensorObject_nansum, functional::reduce_nansum)\nDIRECT_PASS_FUNC(PyTensorObject_sum, functional::reduce_sum)\nDIRECT_PASS_FUNC(PyTensorObject_bernoulli, functional::bernoulli)\nDIRECT_PASS_FUNC(PyTensorObject_bernoulli_, functional::bernoulli_)\nDIRECT_PASS_FUNC(PyTensorObject_bincount, functional::bincount)\nDIRECT_PASS_FUNC(PyTensorObject_isclose, functional::isclose)\nDIRECT_PASS_FUNC(PyTensorObject_broadcast_to, functional::broadcast_to)\nDIRECT_PASS_FUNC(PyTensorObject_lerp, 
functional::lerp)
\nDIRECT_PASS_FUNC(PyTensorObject_lerp_, functional::lerp_)
\nDIRECT_PASS_FUNC(PyTensorObject_unique, functional::unique)
\nDIRECT_PASS_FUNC(PyTensorObject_topk, functional::topk)
\nDIRECT_PASS_FUNC(PyTensorObject_quantile, functional::quantile)
\nDIRECT_PASS_FUNC(PyTensorObject_bitwise_and, functional::bitwise_and)
\nDIRECT_PASS_FUNC(PyTensorObject_bitwise_or, functional::bitwise_or)
\nDIRECT_PASS_FUNC(PyTensorObject_bitwise_xor, functional::bitwise_xor)
\nDIRECT_PASS_FUNC(PyTensorObject_baddbmm, functional::baddbmm)
\nDIRECT_PASS_FUNC(PyTensorObject_mm, functional::mm)
\nDIRECT_PASS_FUNC(PyTensorObject_sub, functional::sub)
\nDIRECT_PASS_FUNC(PyTensorObject_mv, functional::matrix_vector_product)
\nDIRECT_PASS_FUNC(PyTensorObject_fill_, functional::fill_)
\nDIRECT_PASS_FUNC(PyTensorObject_gather, functional::dim_gather)
\nDIRECT_PASS_FUNC(PyTensorObject_repeat_interleave, functional::repeat_interleave)
\nDIRECT_PASS_FUNC(PyTensorObject_scatter_add, functional::scatter_add)
\nDIRECT_PASS_FUNC(PyTensorObject_logaddexp, functional::logaddexp)
\n\n// functions that parse arguments at the Python C API layer
\nstatic PyObject* PyTensorObject_byte(PyObject* self, PyObject* unused) {
\n  HANDLE_ERRORS
\n  return PyTensor_New(ASSERT_PTR(functional::To(PyTensor_Unpack(self), DType::UInt8(), false)));
\n  END_HANDLE_ERRORS\n}
\n\nstatic PyObject* PyTensorObject_dim(PyObject* self, PyObject* unused) {
\n  HANDLE_ERRORS
\n  return functional::CastToPyObject(PyTensor_Unpack(self)->ndim());
\n  END_HANDLE_ERRORS\n}
\n\nstatic PyObject* PyTensorObject_nelement(PyObject* self, PyObject* unused) {
\n  HANDLE_ERRORS
\n  return functional::CastToPyObject(PyTensor_Unpack(self)->nelement());
\n  END_HANDLE_ERRORS\n}
\n\nstatic PyObject* PyTensorObject_element_size(PyObject* self, PyObject* unused) {
\n  HANDLE_ERRORS
\n  return functional::CastToPyObject(PyTensor_Unpack(self)->dtype()->bytes());
\n  END_HANDLE_ERRORS\n}
\n\nstatic PyObject* PyTensorObject_get_device(PyObject* self, PyObject* unused) {
\n  HANDLE_ERRORS
\n  DeviceType device_type = ASSERT(PyTensor_Unpack(self)->device())->enum_type();
\n  CHECK_OR_THROW(device_type == DeviceType::kCUDA)
\n      << \"get_device is only available for GPU tensors.\";
\n  return functional::CastToPyObject(ASSERT(PyTensor_Unpack(self)->device())->device_id());
\n  END_HANDLE_ERRORS\n}
\n\nstatic PyObject* PyTensorObject_size(PyObject* self, PyObject* args, PyObject* kwargs) {
\n  HANDLE_ERRORS
\n  PyObject* idx_obj = Py_None;
\n  static const char* keywords[2] = {\"idx\", NULL};
\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"|O:size\", const_cast<char**>(keywords),
\n                                   &idx_obj)) {
\n    return NULL;
\n  }
\n  auto shape = PyTensor_Unpack(self)->shape();
\n  if (idx_obj == NULL || idx_obj == Py_None) return TensorSize_NewFromShape(*shape);
\n  int64_t idx = PyLong_AsLongLong(idx_obj);
\n  int64_t ndim = shape->NumAxes();
\n  idx = CHECK_JUST(maybe_wrap_dim(idx, ndim));
\n  idx = idx < 0 ? 
idx + ndim : idx;\n  return PyLong_FromLongLong(shape->At(idx));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_cast(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  PyObject* dtype = NULL;\n  PyObject* pin_memory = Py_False;\n  static const char* keywords[3] = {\"dtype\", \"pin_memory\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"O|O!:cast\", const_cast<char**>(keywords), &dtype,\n                                   &PyBool_Type, &pin_memory)) {\n    return NULL;\n  }\n  CHECK_OR_THROW(functional::PyDTypeCheck(dtype))\n      << Error::TypeError() << \"cast(): argument 'dtype' must be data type, but found \"\n      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dtype)));\n  const auto& result = functional::Cast(PyTensor_Unpack(self), functional::PyUnpackDType(dtype),\n                                        pin_memory == Py_True);\n  return PyTensor_New(ASSERT_PTR(result));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_diag(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  int32_t diagonal = 0;\n  static const char* keywords[2] = {\"diagonal\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"|i:diag\", const_cast<char**>(keywords),\n                                   &diagonal)) {\n    return NULL;\n  }\n  return PyTensor_New(ASSERT_PTR(functional::Diag(PyTensor_Unpack(self), diagonal)));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_diagonal(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  int32_t offset = 0;\n  int32_t dim1 = 0;\n  int32_t dim2 = 1;\n  static const char* keywords[4] = {\"offset\", \"dim1\", \"dim2\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"|iii:diagonal\", const_cast<char**>(keywords),\n                                   &offset, &dim1, &dim2)) {\n    return NULL;\n  }\n  return PyTensor_New(ASSERT_PTR(functional::Diagonal(PyTensor_Unpack(self), offset, dim1, dim2)));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_matmul(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  PyObject* other = NULL;\n  static const char* keywords[2] = {\"other\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"O:matmul\", const_cast<char**>(keywords),\n                                   &other)) {\n    return NULL;\n  }\n  PyObjectPtr concat_args(PyTuple_Pack(2, self, other));\n  PyObject* result = functional::matmul(NULL, concat_args.get(), NULL);\n  if (PyErr_Occurred()) { throw py::error_already_set(); }\n  return result;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_reshape(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  PyObject* shape = PyParseArgs(args, kwargs, \"reshape\", \"shape\");\n  PyObjectPtr _args = PyObjectPtr(PyTuple_Pack(2, self, shape));\n  PyObject* result = functional::reshape(NULL, _args.get(), NULL);\n  if (PyErr_Occurred()) { throw py::error_already_set(); }\n  return result;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_reshape_as(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  auto tensor = PyTensor_Unpack(self);\n  PyObject* other = NULL;\n  static const char* keywords[2] = {\"other\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"O|:reshape_as\", const_cast<char**>(keywords),\n                                   &other)) {\n    return NULL;\n  }\n  return PyTensor_New(ASSERT_PTR(functional::Reshape(tensor, *PyTensor_Unpack(other)->shape())));\n  
END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_cpu(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  Optional<std::string> device = \"cpu\";\n  return PyTensor_New(ASSERT_PTR(functional::To(PyTensor_Unpack(self), device, NullOpt, false)));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_cuda(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  PyObject* device_obj = Py_None;\n  static const char* keywords[2] = {\"device\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"|O:cuda\", const_cast<char**>(keywords),\n                                   &device_obj)) {\n    return NULL;\n  }\n  auto tensor = PyTensor_Unpack(self);\n  if (functional::PyDeviceCheck(device_obj)) {\n    Optional<Symbol<Device>> device = functional::PyUnpackDevice(device_obj);\n    return PyTensor_New(ASSERT_PTR(functional::To(tensor, device, NullOpt, false)));\n  }\n  Optional<std::string> device_str;\n  if (device_obj == Py_None) {\n    device_str = \"cuda\";\n  } else if (PyLong_Check(device_obj)) {\n    device_str = \"cuda:\" + std::to_string(PyLong_AsLongLong(device_obj));\n  }\n  return PyTensor_New(ASSERT_PTR(functional::To(tensor, device_str, tensor->dtype(), false)));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_var(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  PyObject* dim_obj = Py_None;\n  PyObject* unbiased_obj = Py_True;\n  PyObject* keepdim_obj = Py_False;\n  static const char* keywords[4] = {\"dim\", \"unbiased\", \"keepdim\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"|OO!O!:var\", const_cast<char**>(keywords),\n                                   &dim_obj, &PyBool_Type, &unbiased_obj, &PyBool_Type,\n                                   &keepdim_obj)) {\n    return NULL;\n  }\n  bool unbiased = unbiased_obj == Py_True;\n  bool keepdim = keepdim_obj == Py_True;\n  CHECK_OR_THROW(dim_obj == Py_None || PyLong_Check(dim_obj)\n                 || functional::PyLongSequenceCheck(dim_obj))\n      << Error::TypeError() << \"var(): argument 'dim' must be int32 list, not \"\n      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dim_obj)));\n  auto tensor = PyTensor_Unpack(self);\n  if (dim_obj == Py_None) {\n    return PyTensor_New(ASSERT_PTR(functional::Variance(tensor, NullOpt, unbiased, keepdim)));\n  }\n  std::vector<int32_t> dim;\n  if (PyLong_Check(dim_obj)) {\n    dim.emplace_back(static_cast<int32_t>(PyLong_AsLong(dim_obj)));\n    return PyTensor_New(ASSERT_PTR(functional::Variance(tensor, dim, unbiased, keepdim)));\n  }\n  dim = functional::PyUnpackLongSequence<int32_t>(dim_obj);\n  return PyTensor_New(ASSERT_PTR(functional::Variance(tensor, dim, unbiased, keepdim)));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_std(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  PyObject* dim_obj = Py_None;\n  PyObject* unbiased_obj = Py_True;\n  PyObject* keepdim_obj = Py_False;\n  static const char* keywords[4] = {\"dim\", \"unbiased\", \"keepdim\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"|OO!O!:std\", const_cast<char**>(keywords),\n                                   &dim_obj, &PyBool_Type, &unbiased_obj, &PyBool_Type,\n                                   &keepdim_obj)) {\n    return NULL;\n  }\n  bool unbiased = unbiased_obj == Py_True;\n  bool keepdim = keepdim_obj == Py_True;\n  CHECK_OR_THROW(dim_obj == Py_None || PyLong_Check(dim_obj)\n                 || functional::PyLongSequenceCheck(dim_obj))\n      << Error::TypeError() << 
\"std(): argument 'dim' must be int32 list, not \"\n      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dim_obj)));\n  auto tensor = PyTensor_Unpack(self);\n  if (dim_obj == Py_None) {\n    return PyTensor_New(\n        ASSERT_PTR(functional::StandardDeviation(tensor, NullOpt, unbiased, keepdim)));\n  }\n  std::vector<int32_t> dim;\n  if (PyLong_Check(dim_obj)) {\n    dim.emplace_back(static_cast<int32_t>(PyLong_AsLong(dim_obj)));\n    return PyTensor_New(ASSERT_PTR(functional::StandardDeviation(tensor, dim, unbiased, keepdim)));\n  }\n  dim = functional::PyUnpackLongSequence<int32_t>(dim_obj);\n  return PyTensor_New(ASSERT_PTR(functional::StandardDeviation(tensor, dim, unbiased, keepdim)));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_softplus(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  double beta = 1.0;\n  double threshold = 20.0;\n  static const char* keywords[3] = {\"beta\", \"threshold\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"dd:softplus\", const_cast<char**>(keywords), &beta,\n                                   &threshold)) {\n    return NULL;\n  }\n  return PyTensor_New(ASSERT_PTR(functional::Softplus(PyTensor_Unpack(self), beta, threshold)));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_relu(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  return PyTensor_New(ASSERT_PTR(functional::Relu(PyTensor_Unpack(self), false)));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_relu_(PyObject* self, PyObject* unused) {\n  HANDLE_ERRORS\n  return PyTensor_New(ASSERT_PTR(functional::Relu(PyTensor_Unpack(self), true)));\n  END_HANDLE_ERRORS\n}\n\n#define REDUCE_FUNC(func_name, bind_func, whole_func)                            \\\n  static PyObject* func_name(PyObject* self, PyObject* args, PyObject* kwargs) { \\\n    HANDLE_ERRORS                                                                \\\n    if ((args == NULL || PyTuple_Size(args) == 0)                                \\\n        && (kwargs == NULL || PyDict_Size(kwargs) == 0)) {                       \\\n      return PyTensor_New(ASSERT_PTR(whole_func(PyTensor_Unpack(self))));        \\\n    }                                                                            \\\n    PyObjectPtr concat_args(concat_self(self, args));                            \\\n    PyObject* result = bind_func(NULL, concat_args.get(), kwargs);               \\\n    if (PyErr_Occurred()) { throw py::error_already_set(); }                     \\\n    return result;                                                               \\\n    END_HANDLE_ERRORS                                                            \\\n  }\n\nREDUCE_FUNC(PyTensorObject_any, functional::reduce_any, functional::ReduceAnyWhole)\nREDUCE_FUNC(PyTensorObject_all, functional::reduce_all, functional::ReduceAllWhole)\nREDUCE_FUNC(PyTensorObject_mean, functional::reduce_mean, functional::ReduceMeanWhole)\n\n#define DATATYPE_FUNC(func_name, dtype)                                    \\\n  static PyObject* func_name(PyObject* self, PyObject* unused) {           \\\n    HANDLE_ERRORS                                                          \\\n    auto tensor = PyTensor_Unpack(self);                                   \\\n    return PyTensor_New(ASSERT_PTR(functional::To(tensor, dtype, false))); \\\n    END_HANDLE_ERRORS                                                      \\\n  }\n\nDATATYPE_FUNC(PyTensorObject_bool, DType::Bool());\nDATATYPE_FUNC(PyTensorObject_int, 
DType::Int32());\nDATATYPE_FUNC(PyTensorObject_long, DType::Int64());\nDATATYPE_FUNC(PyTensorObject_half, DType::Float16());\nDATATYPE_FUNC(PyTensorObject_float, DType::Float());\nDATATYPE_FUNC(PyTensorObject_double, DType::Double());\nDATATYPE_FUNC(PyTensorObject_bfloat16, DType::BFloat16());\n\nstatic PyObject* PyTensorObject_view(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  PyObject* size = PyParseArgs(args, kwargs, \"view\", \"size\");\n  PyObjectPtr _args = PyObjectPtr(PyTuple_Pack(2, self, size));\n  PyObject* result = functional::view(NULL, _args.get(), NULL);\n  if (PyErr_Occurred()) { throw py::error_already_set(); }\n  return result;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_view_as(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  auto tensor = PyTensor_Unpack(self);\n  PyObject* other = NULL;\n  static const char* keywords[2] = {\"other\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"O|:view_as\", const_cast<char**>(keywords),\n                                   &other)) {\n    return NULL;\n  }\n  return PyTensor_New(ASSERT_PTR(functional::View(tensor, *PyTensor_Unpack(other)->shape())));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_permute(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  PyObject* dims = PyParseArgs(args, kwargs, \"permute\", \"dims\");\n  PyObjectPtr _args = PyObjectPtr(PyTuple_Pack(2, self, dims));\n  PyObject* result = functional::permute(NULL, _args.get(), NULL);\n  if (PyErr_Occurred()) { throw py::error_already_set(); }\n  return result;\n\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_transpose(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  auto tensor = PyTensor_Unpack(self);\n  int dim0 = 0;\n  int dim1 = 0;\n  static const char* keywords[3] = {\"dim0\", \"dim1\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"ii:transpose\", const_cast<char**>(keywords),\n                                   &dim0, &dim1)) {\n    return NULL;\n  }\n  return PyTensor_New(ASSERT_PTR(functional::Transpose2dim(tensor, dim0, dim1)));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_local_to_global(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  auto tensor = PyTensor_Unpack(self);\n  CHECK_OR_THROW(tensor->is_local()) << Error::RuntimeError() << \"input must be a local tensor\";\n  PyObject* placement_obj = Py_None;\n  PyObject* sbp_obj = Py_None;\n  PyObject* check_meta_obj = Py_True;\n  PyObject* copy_obj = Py_False;\n  static const char* keywords[5] = {\"placement\", \"sbp\", \"check_meta\", \"copy\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"|OO$O!O!:local_to_global\",\n                                   const_cast<char**>(keywords), &placement_obj, &sbp_obj,\n                                   &PyBool_Type, &check_meta_obj, &PyBool_Type, &copy_obj)) {\n    return NULL;\n  }\n  const bool check_meta = (check_meta_obj == Py_True);\n  const bool copy = (copy_obj == Py_True);\n\n  CHECK_OR_THROW(placement_obj != Py_None && sbp_obj != Py_None)\n      << Error::InvalidValueError()\n      << \"Converting a local tensor to global tensor must have placement and sbp parameters.\";\n  CHECK_OR_THROW(functional::PyParallelDescCheck(placement_obj))\n      << Error::TypeError() << \"Invalid parameter placement with type \"\n      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(placement_obj)));\n\n  std::vector<Symbol<SbpParallel>> sbp;\n  if 
(functional::PySbpParallelCheck(sbp_obj)) {\n    sbp.emplace_back(functional::PyUnpackSbpParallel(sbp_obj));\n  } else {\n    CHECK_OR_THROW(functional::PySbpParallelSequenceCheck(sbp_obj))\n        << Error::TypeError() << \"Invalid parameter sbp with type \"\n        << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(sbp_obj)));\n    sbp = functional::PyUnpackSbpParallelSequence(sbp_obj);\n  }\n  return PyTensor_New(ASSERT_PTR(functional::ToGlobal(\n      tensor, functional::PyUnpackParallelDesc(placement_obj), sbp, {}, check_meta, copy)));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyTensorObject_global_to_global(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  auto tensor = PyTensor_Unpack(self);\n  CHECK_OR_THROW(tensor->is_global()) << Error::RuntimeError() << \"input must be a global tensor\";\n  PyObject* placement_obj = Py_None;\n  PyObject* sbp_obj = Py_None;\n  PyObject* grad_sbp_obj = Py_None;\n  Symbol<ParallelDesc> placement;\n  std::vector<Symbol<SbpParallel>> sbp;\n  std::vector<Symbol<SbpParallel>> grad_sbp;\n  PyObject* check_meta_obj = Py_False;\n  PyObject* copy_obj = Py_False;\n  static const char* keywords[6] = {\"placement\", \"sbp\", \"grad_sbp\", \"check_meta\", \"copy\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"|OO$OO!O!:global_to_global\",\n                                   const_cast<char**>(keywords), &placement_obj, &sbp_obj,\n                                   &grad_sbp_obj, &PyBool_Type, &check_meta_obj, &copy_obj)) {\n    return NULL;\n  }\n  const bool check_meta = (check_meta_obj == Py_True);\n  const bool copy = (copy_obj == Py_True);\n\n  // sbp\n  CHECK_OR_THROW(sbp_obj == Py_None || functional::PySbpParallelCheck(sbp_obj)\n                 || functional::PySbpParallelSequenceCheck(sbp_obj))\n      << Error::TypeError()\n      << \"sbp parameter must be type of oneflow.sbp.sbp or list/tuple of oneflow.sbp.sbp\";\n  if (functional::PySbpParallelCheck(sbp_obj)) {\n    sbp.emplace_back(functional::PyUnpackSbpParallel(sbp_obj));\n  } else if (functional::PySbpParallelSequenceCheck(sbp_obj)) {\n    sbp = functional::PyUnpackSbpParallelSequence(sbp_obj);\n  } else {\n    for (int32_t i = 0; i < ASSERT(tensor->nd_sbp())->sbp_parallel_size(); i++)\n      sbp.emplace_back(ASSERT(tensor->nd_sbp())->sbp_parallel(i));\n  }\n\n  // placement\n  CHECK_OR_THROW(placement_obj == Py_None || functional::PyParallelDescCheck(placement_obj))\n      << Error::TypeError() << \"Invalid parameter placement with type \"\n      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(placement_obj)));\n  if (placement_obj == Py_None) {\n    placement = ASSERT(tensor->parallel_desc());\n  } else {\n    placement = functional::PyUnpackParallelDesc(placement_obj);\n  }\n\n  // grad_sbp\n  CHECK_OR_THROW(grad_sbp_obj == Py_None || functional::PySbpParallelCheck(grad_sbp_obj)\n                 || functional::PySbpParallelSequenceCheck(grad_sbp_obj))\n      << Error::TypeError()\n      << \"grad_sbp parameter must be type of oneflow.sbp.sbp or list/tuple of oneflow.sbp.sbp\";\n  if (functional::PySbpParallelCheck(grad_sbp_obj)) {\n    grad_sbp.emplace_back(functional::PyUnpackSbpParallel(grad_sbp_obj));\n  } else if (functional::PySbpParallelSequenceCheck(grad_sbp_obj)) {\n    grad_sbp = functional::PyUnpackSbpParallelSequence(grad_sbp_obj);\n  }\n  return PyTensor_New(\n      ASSERT_PTR(functional::ToGlobal(tensor, placement, sbp, grad_sbp, check_meta, copy)));\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* 
PyTensorObject_to_global(PyObject* self, PyObject* args, PyObject* kwargs) {
\n  HANDLE_ERRORS
\n  const auto& tensor = PyTensor_Unpack(self);
\n  PyObject* result = NULL;
\n  if (tensor->is_global()) {
\n    result = PyTensorObject_global_to_global(self, args, kwargs);
\n  } else {
\n    result = PyTensorObject_local_to_global(self, args, kwargs);
\n  }
\n  if (PyErr_Occurred()) { throw py::error_already_set(); }
\n  return result;
\n\n  END_HANDLE_ERRORS\n}
\n\nstatic PyObject* PyTensorObject_to_local(PyObject* self, PyObject* unused, PyObject* kwargs) {
\n  HANDLE_ERRORS
\n  auto tensor = PyTensor_Unpack(self);
\n  CHECK_OR_THROW(tensor->is_global())
\n      << Error::RuntimeError() << \"Expected global tensor for to_local but got local tensor!\";
\n  PyObject* copy_obj = Py_False;
\n  static const char* keywords[2] = {\"copy\", NULL};
\n  if (!PyArg_ParseTupleAndKeywords(unused, kwargs, \"|$O!:to_local\", const_cast<char**>(keywords),
\n                                   &PyBool_Type, &copy_obj)) {
\n    return NULL;
\n  }
\n  const bool copy = (copy_obj == Py_True);
\n  return PyTensor_New(ASSERT_PTR(functional::GlobalToLocal(tensor, /*copy=*/copy)));
\n  END_HANDLE_ERRORS\n}
\n\nstatic PyObject* PyTensorObject_type_as(PyObject* self, PyObject* args, PyObject* kwargs) {
\n  HANDLE_ERRORS
\n  auto self_tensor = PyTensor_Unpack(self);
\n  PyObject* other = NULL;
\n  static const char* keywords[2] = {\"other\", NULL};
\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"O|:type_as\", const_cast<char**>(keywords),
\n                                   &other)) {
\n    return NULL;
\n  }
\n\n  // target is local
\n  auto other_tensor = PyTensor_Unpack(other);
\n  if (other_tensor->is_local()) {
\n    Optional<Symbol<Device>> device = ASSERT(other_tensor->device());
\n    if (self_tensor->is_global()) {
\n      self_tensor = ASSERT_PTR(functional::GlobalToLocal(self_tensor, /*copy=*/false));
\n    }
\n    return PyTensor_New(
\n        ASSERT_PTR(functional::To(self_tensor, device, other_tensor->dtype(), /*copy=*/false)));
\n  }
\n\n  // target is global
\n  std::shared_ptr<Tensor> value_tensor;
\n  value_tensor = ASSERT_PTR(functional::To(self_tensor, other_tensor->dtype(), /*copy=*/false));
\n  Symbol<ParallelDesc> placement = ASSERT(other_tensor->parallel_desc());
\n  std::vector<Symbol<SbpParallel>> sbp;
\n  auto ndsbp = ASSERT(other_tensor->nd_sbp());
\n  for (int32_t i = 0; i < ndsbp->sbp_parallel_size(); i++) {
\n    sbp.emplace_back(ndsbp->sbp_parallel(i));
\n  }
\n  return PyTensor_New(
\n      ASSERT_PTR(functional::ToGlobal(value_tensor, placement, sbp, {}, true, /*copy=*/false)));
\n  END_HANDLE_ERRORS\n}
\n\nstatic PyObject* PyTensorObject_new(PyObject* self, PyObject* args, PyObject* kwargs) {
\n  HANDLE_ERRORS
\n  auto self_tensor = PyTensor_Unpack(self);
\n\n  if (!kwargs) {
\n    if (PyTuple_Size(args) == 1 && PyTensor_Check(PyTuple_GET_ITEM(args, 0))) {
\n      // tensor.new(other)
\n      auto other_tensor = PyTensor_Unpack(PyTuple_GET_ITEM(args, 0));
\n      CHECK_OR_THROW(!self_tensor->is_global() && !other_tensor->is_global())
\n          << \"Tensor.new(Tensor) only supports local tensors.\";
\n      CHECK_OR_THROW(self_tensor->dtype() == other_tensor->dtype())
\n          << \"Tensor.new() expects \" << self_tensor->dtype()->name() << \" dtype tensor, but got \"
\n          << other_tensor->dtype()->name() << \" dtype tensor.\";
\n      CHECK_OR_THROW(ASSERT(self_tensor->device())->enum_type()
\n                     == ASSERT(other_tensor->device())->enum_type())
\n          << \"Tensor.new() expects tensor on \" << ASSERT(self_tensor->device())->type()
\n          << \", but got tensor on \" << 
ASSERT(other_tensor->device())->type() << \".\";\n      return PyTensor_New(ASSERT_PTR(functional::TensorWithOtherCtor(other_tensor)));\n    }\n    kwargs = PyDict_New();\n  }\n  PyObjectPtr dtype_key(PyUnicode_FromString(\"dtype\"));\n  PyObjectPtr dtype_value(functional::CastToPyObject(self_tensor->dtype()));\n  CHECK_OR_THROW(PyDict_Contains(kwargs, dtype_key.get()) < 1)\n      << \"Some of the keywords were incorrect: dtype\";\n  CHECK_OR_THROW(PyDict_SetItemString(kwargs, \"dtype\", dtype_value.get()) > -1);\n\n  if (self_tensor->is_global()) {\n    PyObjectPtr placement_key(PyUnicode_FromString(\"placement\"));\n    PyObjectPtr sbp_key(PyUnicode_FromString(\"sbp\"));\n    CHECK_OR_THROW(PyDict_Contains(kwargs, placement_key.get()) < 1)\n        << \"Some of the keywords were incorrect: placement\";\n    CHECK_OR_THROW(PyDict_Contains(kwargs, sbp_key.get()) < 1)\n        << \"Some of the keywords were incorrect: sbp\";\n\n    Symbol<ParallelDesc> placement = ASSERT(self_tensor->parallel_desc());\n    std::vector<Symbol<SbpParallel>> sbp;\n    auto ndsbp = ASSERT(self_tensor->nd_sbp());\n    for (int32_t i = 0; i < ndsbp->sbp_parallel_size(); i++) {\n      sbp.emplace_back(ndsbp->sbp_parallel(i));\n    }\n\n    PyObjectPtr placement_value(functional::CastToPyObject(placement));\n    PyObjectPtr sbp_value(functional::CastToPyObject(sbp));\n    CHECK_OR_THROW(PyDict_SetItemString(kwargs, \"placement\", placement_value.get()) > -1);\n    CHECK_OR_THROW(PyDict_SetItemString(kwargs, \"sbp\", sbp_value.get()) > -1);\n  } else {\n    auto device = ASSERT(self_tensor->device());\n\n    PyObjectPtr device_key(PyUnicode_FromString(\"device\"));\n    CHECK_OR_THROW(PyDict_Contains(kwargs, device_key.get()) < 1)\n        << \"Some of the keywords were incorrect: device\";\n    PyObjectPtr device_value(functional::CastToPyObject(device));\n    CHECK_OR_THROW(PyDict_SetItemString(kwargs, \"device\", device_value.get()) > -1);\n  }\n  return functional::_legacy_tensor_generic_ctor(NULL, args, kwargs);\n  END_HANDLE_ERRORS\n}\n\nint PyTensorObject_setitem(PyObject* self, PyObject* item, PyObject* value) {\n  HANDLE_ERRORS\n  CHECK_OR_THROW(functional::PyTensorIndexCheck(item))\n      << Error::TypeError() << \"tensor_setitem(): argument 'index' must be index, not \"\n      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(item)));\n  CHECK_OR_THROW(functional::PyScalarCheck(value) || PyTensor_Check(value))\n      << Error::TypeError() << \"tensor_setitem(): argument 'value' must be tensor or scalar, not \"\n      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(value)));\n  const auto& index_item = functional::PyUnpackTensorIndex(item);\n\n  auto tensor = PyTensor_Unpack(self);\n  // NOTE: use masked_fill_ (works for both local and global tensors) to avoid a D2H copy in\n  // TensorSetItem when the index is a bool/uint8 mask with the same shape as self\n  if (functional::PyScalarCheck(value) && index_item.size() == 1 && index_item[0].IsTensor()) {\n    const auto& index_tensor = index_item[0].tensor();\n    if (index_tensor->shape() == tensor->shape()\n        && (index_tensor->dtype() == DType::Bool() || index_tensor->dtype() == DType::UInt8())) {\n      ASSERT_PTR(\n          functional::MaskedFillInplace(tensor, index_tensor, functional::PyUnpackScalar(value)));\n      return 0;\n    }\n  }\n\n  std::shared_ptr<Tensor> value_tensor;\n  {\n    if (tensor->is_global()) {\n      Symbol<ParallelDesc> placement = ASSERT(tensor->parallel_desc());\n      auto ndsbp = ASSERT(tensor->nd_sbp());\n      std::vector<Symbol<SbpParallel>> sbp(ndsbp->sbp_parallel_size(),\n                                           ASSERT(MakeBroadcastSbpParallel()));\n      if (functional::PyScalarCheck(value)) {\n        Scalar value_scalar = 
functional::PyUnpackScalar(value);\n        value_tensor = ASSERT_PTR(\n            functional::GlobalConstant(Shape({}), value_scalar, tensor->dtype(), placement, sbp));\n      } else {\n        value_tensor = PyTensor_Unpack(value);\n        CHECK_OR_THROW(value_tensor->is_global())\n            << Error::RuntimeError()\n            << \"tensor_setitem(): value must be a global tensor when self is global\";\n        value_tensor = ASSERT_PTR(\n            functional::ToGlobal(value_tensor, placement, sbp, {}, true, /*copy=*/false));\n      }\n    } else {\n      if (functional::PyScalarCheck(value)) {\n        // NOTE: initialize value_tensor in eager mode\n        LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled=*/false);\n        Scalar value_scalar = functional::PyUnpackScalar(value);\n        value_tensor = ASSERT_PTR(functional::Constant(Shape({}), value_scalar, tensor->dtype(),\n                                                       ASSERT(tensor->device())));\n      } else {\n        value_tensor = PyTensor_Unpack(value);\n        CHECK_OR_THROW(value_tensor->is_local())\n            << Error::RuntimeError()\n            << \"tensor_setitem(): value must be a local tensor when self is local\";\n        Optional<Symbol<Device>> device = ASSERT(tensor->device());\n        value_tensor =\n            ASSERT_PTR(functional::To(value_tensor, device, value_tensor->dtype(), false));\n      }\n    }\n  }\n  ASSERT(functional::TensorSetItem(tensor, index_item, value_tensor));\n  return 0;\n  END_HANDLE_ERRORS_RET(-1)\n}\n\nPyMethodDef PyTensorObject_extra_methods[] = {\n    {\"byte\", PyTensorObject_byte, METH_NOARGS, NULL},\n    {\"size\", (PyCFunction)PyTensorObject_size, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"argmax\", (PyCFunction)PyTensorObject_argmax, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"argmin\", (PyCFunction)PyTensorObject_argmin, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"amin\", (PyCFunction)PyTensorObject_amin, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"dim\", PyTensorObject_dim, METH_NOARGS, NULL},\n    {\"ndimension\", PyTensorObject_dim, METH_NOARGS, NULL},\n    {\"nelement\", PyTensorObject_nelement, METH_NOARGS, NULL},\n    {\"numel\", PyTensorObject_nelement, METH_NOARGS, NULL},\n    {\"element_size\", PyTensorObject_element_size, METH_NOARGS, NULL},\n    {\"get_device\", PyTensorObject_get_device, METH_NOARGS, NULL},\n    {\"cast\", (PyCFunction)PyTensorObject_cast, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"diag\", (PyCFunction)PyTensorObject_diag, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"diagonal\", (PyCFunction)PyTensorObject_diagonal, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"addcmul\", (PyCFunction)PyTensorObject_addcmul, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"addcmul_\", (PyCFunction)PyTensorObject_addcmul_, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"addcdiv\", (PyCFunction)PyTensorObject_addcdiv, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"addcdiv_\", (PyCFunction)PyTensorObject_addcdiv_, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"matmul\", (PyCFunction)PyTensorObject_matmul, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"bool\", PyTensorObject_bool, METH_NOARGS, NULL},\n    {\"int\", PyTensorObject_int, METH_NOARGS, NULL},\n    {\"long\", PyTensorObject_long, METH_NOARGS, NULL},\n    {\"half\", PyTensorObject_half, METH_NOARGS, NULL},\n    {\"float\", PyTensorObject_float, METH_NOARGS, NULL},\n    {\"double\", PyTensorObject_double, METH_NOARGS, NULL},\n    {\"bfloat16\", PyTensorObject_bfloat16, METH_NOARGS, NULL},\n    
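// NOTE: methods that take keyword arguments are cast to PyCFunction and registered with\n    // METH_VARARGS | METH_KEYWORDS; METH_NOARGS entries keep the bare PyCFunction signature.\n    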
{\"local_to_global\", (PyCFunction)PyTensorObject_local_to_global, METH_VARARGS | METH_KEYWORDS,\n     NULL},\n    {\"global_to_global\", (PyCFunction)PyTensorObject_global_to_global, METH_VARARGS | METH_KEYWORDS,\n     NULL},\n    {\"to_local\", (PyCFunction)PyTensorObject_to_local, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"to_global\", (PyCFunction)PyTensorObject_to_global, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"type_as\", (PyCFunction)PyTensorObject_type_as, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"cpu\", PyTensorObject_cpu, METH_NOARGS, NULL},\n    {\"cuda\", (PyCFunction)PyTensorObject_cuda, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"var\", (PyCFunction)PyTensorObject_var, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"std\", (PyCFunction)PyTensorObject_std, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"softplus\", (PyCFunction)PyTensorObject_softplus, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"relu\", PyTensorObject_relu, METH_NOARGS, NULL},\n    {\"relu_\", PyTensorObject_relu_, METH_NOARGS, NULL},\n    {\"all\", (PyCFunction)PyTensorObject_all, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"any\", (PyCFunction)PyTensorObject_any, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"sum\", (PyCFunction)PyTensorObject_sum, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"mean\", (PyCFunction)PyTensorObject_mean, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"new\", (PyCFunction)PyTensorObject_new, METH_VARARGS | METH_KEYWORDS, NULL},\n\n    // macro DIRECT_PASS_FUNC\n    {\"floor_divide\", (PyCFunction)PyTensorObject_floor_divide, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"atan2\", (PyCFunction)PyTensorObject_atan2, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"equal\", (PyCFunction)PyTensorObject_equal, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"gt\", (PyCFunction)PyTensorObject_gt, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"gt_\", (PyCFunction)PyTensorObject_gt_, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"frac\", (PyCFunction)PyTensorObject_frac, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"frac_\", (PyCFunction)PyTensorObject_frac_, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"ge\", (PyCFunction)PyTensorObject_ge, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"div\", (PyCFunction)PyTensorObject_div, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"div_\", (PyCFunction)PyTensorObject_div_, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"mul\", (PyCFunction)PyTensorObject_mul, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"mul_\", (PyCFunction)PyTensorObject_mul_, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"fmod\", (PyCFunction)PyTensorObject_fmod, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"logical_and\", (PyCFunction)PyTensorObject_logical_and, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"logical_or\", (PyCFunction)PyTensorObject_logical_or, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"logical_xor\", (PyCFunction)PyTensorObject_logical_xor, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"bmm\", (PyCFunction)PyTensorObject_bmm, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"ne\", (PyCFunction)PyTensorObject_ne, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"lt\", (PyCFunction)PyTensorObject_lt, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"le\", (PyCFunction)PyTensorObject_le, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"flip\", (PyCFunction)PyTensorObject_flip, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"clip\", (PyCFunction)PyTensorObject_clip, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"clip_\", (PyCFunction)PyTensorObject_clip_, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"clamp\", 
(PyCFunction)PyTensorObject_clamp, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"clamp_min\", (PyCFunction)PyTensorObject_clamp_min, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"clamp_max\", (PyCFunction)PyTensorObject_clamp_max, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"clamp_\", (PyCFunction)PyTensorObject_clamp_, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"clamp_min_\", (PyCFunction)PyTensorObject_clamp_min_, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"clamp_max_\", (PyCFunction)PyTensorObject_clamp_max_, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"flatten\", (PyCFunction)PyTensorObject_flatten, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"in_top_k\", (PyCFunction)PyTensorObject_in_top_k, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"index_select\", (PyCFunction)PyTensorObject_index_select, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"maximum\", (PyCFunction)PyTensorObject_maximum, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"minimum\", (PyCFunction)PyTensorObject_minimum, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"tril\", (PyCFunction)PyTensorObject_tril, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"tril_\", (PyCFunction)PyTensorObject_tril_, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"triu\", (PyCFunction)PyTensorObject_triu, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"triu_\", (PyCFunction)PyTensorObject_triu_, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"softmax\", (PyCFunction)PyTensorObject_softmax, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"log_softmax\", (PyCFunction)PyTensorObject_log_softmax, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"roll\", (PyCFunction)PyTensorObject_roll, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"unbind\", (PyCFunction)PyTensorObject_unbind, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"squeeze\", (PyCFunction)PyTensorObject_squeeze, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"swapaxes\", (PyCFunction)PyTensorObject_swapaxes, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"amax\", (PyCFunction)PyTensorObject_amax, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"swapdims\", (PyCFunction)PyTensorObject_swapdims, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"unfold\", (PyCFunction)PyTensorObject_unfold, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"unsqueeze\", (PyCFunction)PyTensorObject_unsqueeze, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"max\", (PyCFunction)PyTensorObject_max, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"min\", (PyCFunction)PyTensorObject_min, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"median\", (PyCFunction)PyTensorObject_median, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"mode\", (PyCFunction)PyTensorObject_mode, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"pow\", (PyCFunction)PyTensorObject_pow, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"chunk\", (PyCFunction)PyTensorObject_chunk, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"split\", (PyCFunction)PyTensorObject_split, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"narrow\", (PyCFunction)PyTensorObject_narrow, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"masked_fill\", (PyCFunction)PyTensorObject_masked_fill, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"masked_fill_\", (PyCFunction)PyTensorObject_masked_fill_, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"dot\", (PyCFunction)PyTensorObject_dot, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"nansum\", (PyCFunction)PyTensorObject_nansum, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"bernoulli\", (PyCFunction)PyTensorObject_bernoulli, METH_VARARGS 
| METH_KEYWORDS, NULL},\n    {\"bernoulli_\", (PyCFunction)PyTensorObject_bernoulli_, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"bincount\", (PyCFunction)PyTensorObject_bincount, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"isclose\", (PyCFunction)PyTensorObject_isclose, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"broadcast_to\", (PyCFunction)PyTensorObject_broadcast_to, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"lerp\", (PyCFunction)PyTensorObject_lerp, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"lerp_\", (PyCFunction)PyTensorObject_lerp_, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"unique\", (PyCFunction)PyTensorObject_unique, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"topk\", (PyCFunction)PyTensorObject_topk, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"bitwise_and\", (PyCFunction)PyTensorObject_bitwise_and, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"bitwise_or\", (PyCFunction)PyTensorObject_bitwise_or, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"bitwise_xor\", (PyCFunction)PyTensorObject_bitwise_xor, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"baddbmm\", (PyCFunction)PyTensorObject_baddbmm, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"mm\", (PyCFunction)PyTensorObject_mm, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"sub\", (PyCFunction)PyTensorObject_sub, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"mv\", (PyCFunction)PyTensorObject_mv, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"fill_\", (PyCFunction)PyTensorObject_fill_, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"gather\", (PyCFunction)PyTensorObject_gather, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"repeat_interleave\", (PyCFunction)PyTensorObject_repeat_interleave,\n     METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"scatter_add\", (PyCFunction)PyTensorObject_scatter_add, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"logaddexp\", (PyCFunction)PyTensorObject_logaddexp, METH_VARARGS | METH_KEYWORDS, NULL},\n\n    // macro UNARY_METHOD\n    {\"abs\", PyTensorObject_abs, METH_NOARGS, NULL},\n    {\"digamma\", PyTensorObject_digamma, METH_NOARGS, NULL},\n    {\"exp\", PyTensorObject_exp, METH_NOARGS, NULL},\n    {\"exp2\", PyTensorObject_exp2, METH_NOARGS, NULL},\n    {\"floor\", PyTensorObject_floor, METH_NOARGS, NULL},\n    {\"floor_\", PyTensorObject_floor_, METH_NOARGS, NULL},\n    {\"acos\", PyTensorObject_acos, METH_NOARGS, NULL},\n    {\"arccos\", PyTensorObject_acos, METH_NOARGS, NULL},\n    {\"acosh\", PyTensorObject_acosh, METH_NOARGS, NULL},\n    {\"arccosh\", PyTensorObject_acosh, METH_NOARGS, NULL},\n    {\"atanh\", PyTensorObject_atanh, METH_NOARGS, NULL},\n    {\"arctanh\", PyTensorObject_atanh, METH_NOARGS, NULL},\n    {\"sign\", PyTensorObject_sign, METH_NOARGS, NULL},\n    {\"sinh\", PyTensorObject_sinh, METH_NOARGS, NULL},\n    {\"tan\", PyTensorObject_tan, METH_NOARGS, NULL},\n    {\"gelu\", PyTensorObject_gelu, METH_NOARGS, NULL},\n    {\"mish\", PyTensorObject_mish, METH_NOARGS, NULL},\n    {\"negative\", PyTensorObject_negative, METH_NOARGS, NULL},\n    {\"neg\", PyTensorObject_negative, METH_NOARGS, NULL},\n    {\"sigmoid\", PyTensorObject_sigmoid, METH_NOARGS, NULL},\n    {\"tanh\", PyTensorObject_tanh, METH_NOARGS, NULL},\n    {\"silu\", PyTensorObject_silu, METH_NOARGS, NULL},\n    {\"selu\", PyTensorObject_selu, METH_NOARGS, NULL},\n    {\"softsign\", PyTensorObject_softsign, METH_NOARGS, NULL},\n    {\"log1p\", PyTensorObject_log1p, METH_NOARGS, NULL},\n    {\"log2\", PyTensorObject_log2, METH_NOARGS, NULL},\n    {\"log10\", PyTensorObject_log10, METH_NOARGS, NULL},\n    {\"reciprocal\", 
PyTensorObject_reciprocal, METH_NOARGS, NULL},\n    {\"asin\", PyTensorObject_asin, METH_NOARGS, NULL},\n    {\"arcsin\", PyTensorObject_asin, METH_NOARGS, NULL},\n    {\"asinh\", PyTensorObject_asinh, METH_NOARGS, NULL},\n    {\"arcsinh\", PyTensorObject_asinh, METH_NOARGS, NULL},\n    {\"atan\", PyTensorObject_atan, METH_NOARGS, NULL},\n    {\"arctan\", PyTensorObject_atan, METH_NOARGS, NULL},\n    {\"ceil\", PyTensorObject_ceil, METH_NOARGS, NULL},\n    {\"ceil_\", PyTensorObject_ceil_, METH_NOARGS, NULL},\n    {\"cos\", PyTensorObject_cos, METH_NOARGS, NULL},\n    {\"cosh\", PyTensorObject_cosh, METH_NOARGS, NULL},\n    {\"erf\", PyTensorObject_erf, METH_NOARGS, NULL},\n    {\"erfc\", PyTensorObject_erfc, METH_NOARGS, NULL},\n    {\"erfinv\", PyTensorObject_erfinv, METH_NOARGS, NULL},\n    {\"erfinv_\", PyTensorObject_erfinv_, METH_NOARGS, NULL},\n    {\"expm1\", PyTensorObject_expm1, METH_NOARGS, NULL},\n    {\"log\", PyTensorObject_log, METH_NOARGS, NULL},\n    {\"rsqrt\", PyTensorObject_rsqrt, METH_NOARGS, NULL},\n    {\"sqrt\", PyTensorObject_sqrt, METH_NOARGS, NULL},\n    {\"square\", PyTensorObject_square, METH_NOARGS, NULL},\n    {\"round\", PyTensorObject_round, METH_NOARGS, NULL},\n    {\"round_\", PyTensorObject_round_, METH_NOARGS, NULL},\n    {\"t\", PyTensorObject_t, METH_NOARGS, NULL},\n    {\"sin\", PyTensorObject_sin, METH_NOARGS, NULL},\n    {\"sin_\", PyTensorObject_sin_, METH_NOARGS, NULL},\n    {\"isnan\", PyTensorObject_isnan, METH_NOARGS, NULL},\n    {\"inverse\", PyTensorObject_inv, METH_NOARGS, NULL},\n    {\"trunc\", PyTensorObject_trunc, METH_NOARGS, NULL},\n    {\"isinf\", PyTensorObject_isinf, METH_NOARGS, NULL},\n    {\"logical_not\", PyTensorObject_logical_not, METH_NOARGS, NULL},\n    {\"bitwise_not\", (PyCFunction)PyTensorObject_bitwise_not, METH_NOARGS, NULL},\n    {\"reshape\", (PyCFunction)PyTensorObject_reshape, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"reshape_as\", (PyCFunction)PyTensorObject_reshape_as, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"view\", (PyCFunction)PyTensorObject_view, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"view_as\", (PyCFunction)PyTensorObject_view_as, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"permute\", (PyCFunction)PyTensorObject_permute, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"transpose\", (PyCFunction)PyTensorObject_transpose, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"logsumexp\", (PyCFunction)PyTensorObject_logsumexp, METH_VARARGS | METH_KEYWORDS, NULL},\n    {\"quantile\", (PyCFunction)PyTensorObject_quantile, METH_VARARGS | METH_KEYWORDS, NULL},\n    {NULL},\n};\n\n// tp_richcompare\nPyObject* PyTensorObject_richcompare(PyObject* self, PyObject* other, int op) {\n  PyObjectPtr tuple(PyTuple_Pack(2, self, other));\n\n  switch (op) {\n    case Py_LT: return functional::less(NULL, tuple.get(), NULL);\n    case Py_LE: return functional::less_equal(NULL, tuple.get(), NULL);\n    case Py_EQ: {\n      if (self == Py_None || other == Py_None) Py_RETURN_FALSE;\n      return functional::broadcast_equal(NULL, tuple.get(), NULL);\n    }\n    case Py_NE: {\n      if (self == Py_None || other == Py_None) Py_RETURN_TRUE;\n      return functional::not_equal(NULL, tuple.get(), NULL);\n    }\n    case Py_GT: return functional::greater(NULL, tuple.get(), NULL);\n    case Py_GE: return functional::greater_equal(NULL, tuple.get(), NULL);\n  }\n  return NULL;\n}\n\n}  // namespace one\n}  // namespace 
oneflow\n\n#undef ASSERT\n#undef ASSERT_PTR\n"
  },
  {
    "path": "oneflow/api/python/framework/tensor_functions_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include <string>\n#include \"oneflow/api/python/exception/exception.h\"\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/core/common/error.pb.h\"\n#include \"oneflow/core/common/throw.h\"\n\nnamespace oneflow {\nnamespace one {\n\nusing functional::PyObjectPtr;\n\nstd::string PyUnpack_String(PyObject* obj) {\n  CHECK_OR_THROW(PyUnicode_Check(obj)) << \"PyUnpack_String(): expect a PyUnicode object\";\n  Py_ssize_t size = -1;\n  const char* data = PyUnicode_AsUTF8AndSize(obj, &size);\n  CHECK_NOTNULL_OR_THROW(data) << \"error unpacking string as utf-8\";\n  return std::string(data, (size_t)size);\n}\n\n// For signature like Tensor.reshape(*shape), this function can handle these cases:\n// 1. parse positional arguments only case, like Tensor.reshape(1, 2)\n// 2. parse keyword arguments only case, like Tensor.reshape(shape=(1, 2))\n// 3. raise Error for multiple arguments case, like Tensor.reshape(1, shape=(1, ))\n// 4. return empty tuple for empty arguments, like Tensor.reshape()\nPyObject* PyParseArgs(PyObject* args, PyObject* kwargs, const char* func_name,\n                      const std::string& param_name) {\n  PyObject* args_obj = NULL;\n  // Tensor.reshape(shape=(1, 2)), get (1, 2) for kwargs[\"shape\"]\n  if (kwargs != NULL) {\n    PyObject* key = nullptr;\n    PyObject* value = nullptr;\n    Py_ssize_t pos = 0;\n    while (PyDict_Next(kwargs, &pos, &key, &value)) {\n      CHECK_OR_THROW(args_obj == NULL)\n          << Error::TypeError() << func_name << \"() got multiple values for argument '\"\n          << param_name << \"' or get invalid argument\";\n      CHECK_EQ_OR_THROW(PyUnpack_String(key), param_name)\n          << Error::TypeError() << func_name << \"() got an unexpected keyword argument \"\n          << PyUnpack_String(key);\n      args_obj = value;\n    }\n  }\n  if (PyTuple_GET_SIZE(args) != 0) {\n    CHECK_OR_THROW(args_obj == NULL)\n        << Error::TypeError() << func_name << \"() got multiple values for argument '\" << param_name\n        << \"' or get invalid argument\";\n    if (PyTuple_Size(args) == 1 && functional::PyShapeSequenceCheck(args)) {\n      args_obj = PyTuple_GET_ITEM(args, 0);\n    } else {\n      args_obj = args;\n    }\n  };\n  if (args_obj == NULL) { args_obj = args; }\n  return args_obj;\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/tensor_tuple.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <vector>\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/tensor.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\nnamespace one {\n\nnamespace {\n\nstruct TensorTupleUtil final {\n  static std::string ToString(const TensorTuple& tensor_tuple) {\n    std::stringstream ss;\n    int32_t idx = 0;\n    ss << \"TensorTuple(\";\n    for (const std::shared_ptr<Tensor>& tensor : tensor_tuple) {\n      ss << tensor;\n      if (++idx != tensor_tuple.size() || tensor_tuple.size() == 1) { ss << \", \"; }\n    }\n    ss << \")\";\n    return ss.str();\n  }\n\n  static void MergeFrom(std::shared_ptr<TensorTuple>& tensor_tuple, const TensorTuple& other) {\n    for (const auto& tensor : other) { tensor_tuple->emplace_back(tensor); }\n  }\n\n  static void AppendTensor(std::shared_ptr<TensorTuple>& tensor_tuple,\n                           const std::shared_ptr<Tensor>& tensor) {\n    tensor_tuple->emplace_back(tensor);\n  }\n};\n\n}  // namespace\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  py::class_<TensorTuple, std::shared_ptr<TensorTuple>>(m, \"TensorTuple\")\n      .def(py::init([]() { return std::make_shared<TensorTuple>(); }))\n      .def(py::init([](const std::shared_ptr<TensorTuple>& other) { return other; }))\n      .def(py::init([](const std::vector<std::shared_ptr<Tensor>>& list) {\n        auto tensor_tuple = std::make_shared<TensorTuple>();\n        for (const auto& t : list) { tensor_tuple->emplace_back(t); }\n        return tensor_tuple;\n      }))\n      .def(\"__str__\", &TensorTupleUtil::ToString)\n      .def(\"__repr__\", &TensorTupleUtil::ToString)\n      .def(\"__getitem__\",\n           [](const TensorTuple& tensor_tuple, int idx) { return tensor_tuple.at(idx); })\n      .def(\"__setitem__\",\n           [](std::shared_ptr<TensorTuple>& tensor_tuple, int idx,\n              const std::shared_ptr<Tensor>& tensor) { tensor_tuple->at(idx) = tensor; })\n      .def(\n          \"__iter__\",\n          [](const TensorTuple& tensor_tuple) {\n            return py::make_iterator(tensor_tuple.begin(), tensor_tuple.end());\n          },\n          py::keep_alive<0, 1>())\n      .def(\"__len__\", [](const TensorTuple& tensor_tuple) { return tensor_tuple.size(); })\n      .def(\"merge_from\", &TensorTupleUtil::MergeFrom)\n      .def(\"append\", &TensorTupleUtil::AppendTensor);\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/framework/tensortype.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/framework/tensor.h\"\n#include \"oneflow/api/python/framework/tensortype.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/api/python/functional/tensor_api.yaml.pybind.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/api/python/exception/exception.h\"\n\nnamespace oneflow {\nnamespace one {\n\n#define ASSERT(x) (x).GetOrThrow()\n#define ASSERT_PTR(x) (x).GetPtrOrThrow()\nusing functional::PyObjectPtr;\n\nstatic PyTypeObject PyTensorTypeMetaClass{\n    PyVarObject_HEAD_INIT(NULL, 0) \"oneflow.tensortype\",  // tp_name\n    sizeof(PyTypeObject),                                 // tp_basicsize\n};\n\nstatic PyTypeObject PyTensorTypeTemplate{\n    PyVarObject_HEAD_INIT(&PyTensorTypeMetaClass, 0) NULL,  // tp_name\n    sizeof(PyTensorType),                                   // tp_basicsize\n};\n\nstatic std::vector<PyTensorType*> tensor_types;\n\nstatic const std::unordered_map<Symbol<DType>, std::string> all_data_types = {\n    {DType::Float(), \"FloatTensor\"},\n    {DType::Double(), \"DoubleTensor\"},\n    {DType::Int8(), \"CharTensor\"},\n    {DType::Int32(), \"IntTensor\"},\n    {DType::Int64(), \"LongTensor\"},\n    {DType::UInt8(), \"ByteTensor\"},\n    {DType::Float16(), \"HalfTensor\"},\n    {DType::BFloat16(), \"BFloat16Tensor\"},\n    {DType::Bool(), \"BoolTensor\"},\n    {DType::Complex32(), \"ComplexHalfTensor\"},\n    {DType::Complex64(), \"ComplexFloatTensor\"},\n    {DType::Complex128(), \"ComplexDoubleTensor\"},\n    {DType::Char(), \"CharTensor\"},\n    {DType::Int16(), \"ShortTensor\"},\n};\n\nstatic const std::string get_dtype_string(PyTensorType* tensortype) {\n  return all_data_types.at(tensortype->dtype);\n}\n\nstatic std::vector<std::pair<DeviceType, std::string>> all_device_types = {\n    {kCPU, \"oneflow\"},\n    {kCUDA, \"oneflow.cuda\"},\n};\n\nstatic PyObject* PyTensorTypeMetaCls_call(PyObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  const auto& dtype = PyTensorType_UnpackDType(self);\n  PyObjectPtr dtype_value(functional::CastToPyObject(dtype));\n  if (!kwargs) {\n    kwargs = PyDict_New();\n  } else {\n    const char* dtype_str = \"dtype\";\n    PyObjectPtr dtype_key(PyUnicode_FromString(dtype_str));\n    CHECK_OR_THROW(PyDict_Contains(kwargs, dtype_key.get()) < 1)\n        << \"Some of the keywords were incorrect: dtype\";\n  }\n  CHECK_OR_THROW(PyDict_SetItemString(kwargs, \"dtype\", dtype_value.get()) > -1);\n\n  Maybe<std::string> maybe_device = DeviceTag4DeviceType(PyTensorType_UnpackDevice(self));\n  if (!TRY(maybe_device).IsOk()) { return PyErr_Format(PyExc_ValueError, 
\"invalid device\"); }\n\n  {\n    const char* placement_str = \"placement\";\n    PyObjectPtr placement_key(PyUnicode_FromString(placement_str));\n    if (PyDict_Contains(kwargs, placement_key.get()) == 1) {\n      // If creat global tensor, the device of TensorType will be cover by param placement\n      // Raise a warning to inform users of using oneflow.Tensortype rather than\n      // oneflow.xxx.Tensortype\n      CHECK_OR_THROW(PyTensorType_UnpackDevice(self) == kCPU)\n          << \"`\" << ((PyTensorType*)self)->name\n          << \"` can not creat a global tensor, consider use `oneflow.\"\n          << get_dtype_string((PyTensorType*)self) << \"`\";\n    } else {\n      std::string device = ASSERT(maybe_device);\n      PyObjectPtr device_value(PyUnicode_FromString(device.data()));\n      CHECK_OR_THROW(PyDict_SetItemString(kwargs, \"device\", device_value.get()) > -1);\n    }\n  }\n  auto* tensor = functional::_legacy_tensor_generic_ctor(NULL, args, kwargs);\n  if (PyErr_Occurred()) { throw py::error_already_set(); }\n  return tensor;\n  END_HANDLE_ERRORS\n};\n\nPyObject* PyTensorType_FromString(const std::string& tensortype) {\n  auto it = std::find_if(\n      tensor_types.begin(), tensor_types.end(),\n      [tensortype](PyTensorType* type) { return std::string(type->name) == tensortype; });\n  if (it == tensor_types.end()) {\n    PyErr_Format(PyExc_ValueError, \"invalid type: %s\", tensortype.data());\n    throw py::error_already_set();\n  }\n  return (PyObject*)(*it);\n}\n\nstatic const char* get_doc(PyTensorType* tensortype) {\n  // all tensortype docs\n  static std::vector<std::string> tensortype_doc;\n\n  std::string dtype = tensortype->dtype->name();\n  std::string doc = \"\";\n  if (!TRY(DeviceTag4DeviceType(tensortype->devicetype)).IsOk())\n    doc = \"The tensortype \" + std::string(tensortype->name) + \" is not available.\";\n  else {\n    std::string device = ASSERT(DeviceTag4DeviceType(tensortype->devicetype));\n    doc = \"Creates a Tensor with the dtype of \" + dtype + \" and the device on \" + device\n          + \", it has the same parameters as :func:`oneflow.Tensor`\";\n  }\n  tensortype_doc.emplace_back(doc);\n  return tensortype_doc.back().data();\n}\n\nstatic void init_tensortype_metaclass(PyTypeObject* metaclass) {\n  metaclass->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;\n  metaclass->tp_base = &PyType_Type;\n  metaclass->tp_call = PyTensorTypeMetaCls_call;\n  if (PyType_Ready(metaclass) < 0) { return; }\n}\n\nstatic void init_tensortype(PyTypeObject* type, PyTypeObject& type_template, const char* name,\n                            const char* doc) {\n  memcpy(type, &type_template, sizeof(PyTypeObject));\n  type->tp_name = name;\n  type->tp_doc = doc;\n  type->tp_flags = Py_TPFLAGS_DEFAULT;\n  if (PyType_Ready(type) < 0) { THROW(RuntimeError) << \"tensortype initialization failed\"; }\n}\n\nstatic void generalize_tensor_types() {\n  init_tensortype_metaclass(&PyTensorTypeMetaClass);\n\n  for (const auto& devicetype : all_device_types) {\n    for (const auto& dtype : all_data_types) {\n      PyTensorType* tensortype = new PyTensorType();\n      // set name\n      std::string name = devicetype.second + \".\" + dtype.second;\n      size_t n = sizeof(tensortype->name);\n      strncpy(tensortype->name, name.c_str(), n - 1);\n      tensortype->name[n - 1] = '\\0';\n\n      // set type\n      tensortype->dtype = dtype.first;\n      tensortype->devicetype = devicetype.first;\n      tensortype->is_cuda = tensortype->devicetype == DeviceType::kCUDA;\n      
tensor_types.push_back(tensortype);\n\n      const char* doc = get_doc(tensortype);\n      init_tensortype(&tensortype->py_type, PyTensorTypeTemplate, tensortype->name, doc);\n    }\n  }\n}\n\nbool PyTensorType_Check(PyObject* obj) { return PyObject_TypeCheck(obj, &PyTensorTypeMetaClass); }\n\nPyObject* PyTensorType_FromDTypeAndDeviceType(Symbol<DType> dtype, DeviceType device) {\n  auto it =\n      std::find_if(tensor_types.begin(), tensor_types.end(), [dtype, device](PyTensorType* x) {\n        return (x->dtype == dtype) && (x->devicetype == device);\n      });\n  if (it == tensor_types.end()) {\n    if (!TRY(DeviceTag4DeviceType(device)).IsOk())\n      return PyErr_Format(PyExc_ValueError, \"unsupported device\");\n    return PyErr_Format(PyExc_ValueError, \"unsupported data type (%s) or device (%s)\",\n                        dtype->name().c_str(), ASSERT(DeviceTag4DeviceType(device)).c_str());\n  }\n  return (PyObject*)(*it);\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#undef ASSERT\n\nusing namespace oneflow::one;\n\nONEFLOW_API_PYBIND11_MODULE(\"_C\", m) {\n  static std::string oneflow_prefix = \"oneflow.\";\n  generalize_tensor_types();\n\n  for (PyTensorType* tensortype : tensor_types) {\n    Py_INCREF(tensortype);\n    std::string name = std::string(tensortype->name);\n    size_t idx = name.rfind('.');\n    std::string type_name = name.substr(idx + 1);\n\n    name = name.substr(0, idx);\n    std::string module_name =\n        name.size() > oneflow_prefix.size() ? name.substr(oneflow_prefix.size()) : \"\";\n    auto module = m;\n    if (!module_name.empty()) { module = m.def_submodule(module_name.data()); }\n    if (tensortype\n        && PyModule_AddObject(module.ptr(), type_name.c_str(), (PyObject*)tensortype) < 0) {\n      return;\n    }\n  }\n}\n"
  },
  {
    "path": "oneflow/api/python/framework/tensortype.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_FRAMEWORK_TENSORTYPE_H_\n#define ONEFLOW_API_PYTHON_FRAMEWORK_TENSORTYPE_H_\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/device.h\"\n\nnamespace oneflow {\nnamespace one {\n\ntypedef struct {\n  PyTypeObject py_type;\n  char name[64];\n  bool is_cuda;\n  Symbol<DType> dtype;\n  DeviceType devicetype;\n} PyTensorType;\n\nbool PyTensorType_Check(PyObject*);\n\ninline DeviceType PyTensorType_UnpackDevice(PyObject* self) {\n  return ((PyTensorType*)self)->devicetype;\n}\ninline Symbol<DType> PyTensorType_UnpackDType(PyObject* self) {\n  return ((PyTensorType*)self)->dtype;\n}\n\nPyObject* PyTensorType_FromDTypeAndDeviceType(Symbol<DType>, DeviceType);\nPyObject* PyTensorType_FromString(const std::string&);\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_FRAMEWORK_TENSORTYPE_H_\n"
  },
  {
    "path": "oneflow/api/python/framework/thread.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/api/python/framework/thread.h\"\n#include \"oneflow/core/common/env_var/vm.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nnamespace {\n\nclass UsingThreadUidSet final {\n public:\n  UsingThreadUidSet()\n      : using_thread_uids_({Stream::kDefaultStreamThreadUid}),\n        thread_limits_(using_thread_uids_.size()\n                       + ThreadLocalEnvInteger<ONEFLOW_VM_WORKER_THREAD_LIMIT>()) {}\n  ~UsingThreadUidSet() = default;\n\n  Maybe<int64_t> Get() {\n    std::unique_lock<std::mutex> lock(mutex_);\n    CHECK_LT_OR_RETURN(using_thread_uids_.size(), thread_limits_)\n        << \"can not create more worker threads. please check your code or increase environment \"\n           \"variable ONEFLOW_VM_WORKER_THREAD_LIMIT(default value:\"\n        << ThreadLocalEnvInteger<ONEFLOW_VM_WORKER_THREAD_LIMIT>() << \")\";\n    for (int i = 0; i < using_thread_uids_.size() + 1; ++i) {\n      if (using_thread_uids_.count(i) == 0) {\n        using_thread_uids_.insert(i);\n        return i;\n      }\n    }\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n\n  Maybe<void> Put(int64_t thread_uid) {\n    std::unique_lock<std::mutex> lock(mutex_);\n    CHECK_NE_OR_RETURN(thread_uid, Stream::kDefaultStreamThreadUid)\n        << \"default thread_uid should not be erased. value: \" << thread_uid;\n    CHECK_OR_RETURN(using_thread_uids_.erase(thread_uid) > 0)\n        << \"no thread_uid found. (current: \" << thread_uid << \").\";\n    return Maybe<void>::Ok();\n  }\n\n private:\n  std::set<int64_t> using_thread_uids_;\n  size_t thread_limits_;\n  std::mutex mutex_;\n};\n\nUsingThreadUidSet* MutUsingThreadUidSet() {\n  static UsingThreadUidSet thread_uid_set;\n  return &thread_uid_set;\n}\n\n}  // namespace\n\n/*static*/ Maybe<AsyncThread> AsyncThread::New() {\n  return std::shared_ptr<AsyncThread>(new AsyncThread(JUST(MutUsingThreadUidSet()->Get())));\n}\n\nAsyncThread::~AsyncThread() { MutUsingThreadUidSet()->Put(thread_uid_).GetOrThrow(); }\n\n}  // namespace oneflow\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  using namespace oneflow;\n  py::class_<AsyncThread, std::shared_ptr<AsyncThread>>(m, \"AsyncThread\").def(py::init([]() {\n    return AsyncThread::New().GetPtrOrThrow();\n  }));\n}\n"
  },
  {
    "path": "oneflow/api/python/framework/thread.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_FRAMEWORK_THREAD_H_\n#define ONEFLOW_API_PYTHON_FRAMEWORK_THREAD_H_\n\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nclass AsyncThread final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AsyncThread);\n  ~AsyncThread();\n\n  static Maybe<AsyncThread> New();\n\n  int64_t thread_uid() const { return thread_uid_; }\n\n private:\n  AsyncThread(int64_t thread_uid) : thread_uid_(thread_uid) {}\n\n  int64_t thread_uid_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_FRAMEWORK_THREAD_H_\n"
  },
  {
    "path": "oneflow/api/python/framework/typeinfo.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <limits>\n#include \"oneflow/api/python/exception/exception.h\"\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/api/python/framework/typeinfo.h\"\n\nnamespace oneflow {\nnamespace one {\n\n#define ASSERT(x) (x).GetOrThrow()\n#if PY_VERSION_HEX < 0x03070000\n#define PYGETSET_NAME(name) const_cast<char*>(name)\n#else\n#define PYGETSET_NAME(name) (name)\n#endif\n\nusing functional::PyObjectPtr;\n\n#define INFO_FLOAT_TYPE_SEQ FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ\n#define INFO_TYPE_SEQ INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ INFO_FLOAT_TYPE_SEQ\n\ntemplate<typename>\nstruct is_floating_point_with_half : public std::false_type {};\n\n#define DEFINE_IS_FLOATING_POINT_WITH_HALF(cpp_type, of_datatype) \\\n  template<>                                                      \\\n  struct is_floating_point_with_half<cpp_type> : public std::true_type {};\n\nOF_PP_FOR_EACH_TUPLE(DEFINE_IS_FLOATING_POINT_WITH_HALF, INFO_FLOAT_TYPE_SEQ);\n#undef DEFINE_IS_FLOATING_POINT_WITH_HALF\n\ntemplate<typename T>\ntypename std::enable_if<is_floating_point_with_half<T>::value, PyObject*>::type PyGetVal(T value) {\n  return PyFloat_FromDouble(value);\n}\n\ntemplate<typename T>\ntypename std::enable_if<std::is_integral<T>::value, PyObject*>::type PyGetVal(T value) {\n  return PyLong_FromLong(value);\n}\n\nPyObject* PyGetMaxVal(DataType datatype) {\n#define GET_MAX_VAL(cpp_type, of_datatype) \\\n  case of_datatype: return PyGetVal(std::numeric_limits<DataTypeToType<of_datatype>>::max());\n\n  switch (datatype) {\n    OF_PP_FOR_EACH_TUPLE(GET_MAX_VAL, INFO_TYPE_SEQ);\n    default: return NULL;\n#undef GET_MAX_VAL\n  }\n}\n\nPyObject* PyGetMinVal(DataType datatype) {\n#define GET_MIN_VAL(cpp_type, of_datatype) \\\n  case of_datatype: return PyGetVal(std::numeric_limits<DataTypeToType<of_datatype>>::lowest());\n\n  switch (datatype) {\n    OF_PP_FOR_EACH_TUPLE(GET_MIN_VAL, INFO_TYPE_SEQ);\n    default: return NULL;\n\n#undef GET_MIN_VAL\n  }\n}\n\n#define GET_FLOAT_RESOLUTION(cpp_type, of_datatype) \\\n  case of_datatype:                                 \\\n    return PyFloat_FromDouble(                      \\\n        std::pow(10, -std::numeric_limits<DataTypeToType<of_datatype>>::digits10));\n\n#define GET_FLOAT_EPS(cpp_type, of_datatype) \\\n  case of_datatype:                          \\\n    return PyFloat_FromDouble(std::numeric_limits<DataTypeToType<of_datatype>>::epsilon());\n\n#define GET_FLOAT_TINY(cpp_type, of_datatype) \\\n  case of_datatype:                           \\\n    return PyFloat_FromDouble(std::numeric_limits<DataTypeToType<of_datatype>>::min());\n\nPyTypeObject PyIInfoType = {\n    PyVarObject_HEAD_INIT(NULL, 0) \"oneflow.iinfo\",  // tp_name\n    sizeof(PyDTypeInfo),                             // tp_basicsize\n};\n\nPyTypeObject PyFInfoType = {\n    
PyVarObject_HEAD_INIT(NULL, 0) \"oneflow.finfo\",  // tp_name\n    sizeof(PyDTypeInfo),                             // tp_basicsize\n};\n\nstatic PyObject* PyIInfo_new(PyTypeObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  PyObject* dtype_obj = NULL;\n  static const char* keywords[2] = {\"type\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"O:iinfo\", const_cast<char**>(keywords),\n                                   &dtype_obj)) {\n    return NULL;\n  }\n  CHECK_OR_THROW(functional::PyDTypeCheck(dtype_obj))\n      << Error::TypeError() << \"iinfo(): argument 'type' must be oneflow.dtype, but found \"\n      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dtype_obj)));\n\n  auto* self = (PyDTypeInfo*)PyIInfoType.tp_alloc(&PyIInfoType, 0);\n  if (!self) { throw py::error_already_set(); }\n  self->dtype = functional::PyUnpackDType(dtype_obj);\n  CHECK_OR_THROW(!self->dtype->is_floating_point() && !self->dtype->is_complex())\n      << Error::TypeError()\n      << \"oneflow.iinfo() requires an integer input type. Use oneflow.finfo to handle '\"\n      << self->dtype->name() << \"' \";\n  return (PyObject*)self;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyFInfo_new(PyTypeObject* self, PyObject* args, PyObject* kwargs) {\n  HANDLE_ERRORS\n  PyObject* dtype_obj = functional::CastToPyObject(DType::Float());\n  static const char* keywords[2] = {\"type\", NULL};\n  if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"|O:finfo\", const_cast<char**>(keywords),\n                                   &dtype_obj)) {\n    return NULL;\n  }\n  CHECK_OR_THROW(functional::PyDTypeCheck(dtype_obj))\n      << Error::TypeError() << \"finfo(): argument 'type' must be oneflow.dtype, but found \"\n      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dtype_obj)));\n\n  auto* self = (PyDTypeInfo*)PyFInfoType.tp_alloc(&PyFInfoType, 0);\n  if (!self) { throw py::error_already_set(); }\n  self->dtype = functional::PyUnpackDType(dtype_obj);\n  CHECK_OR_THROW(self->dtype->is_floating_point() && !self->dtype->is_complex())\n      << Error::TypeError()\n      << \"oneflow.finfo() requires a float input type. 
Use oneflow.iinfo to handle '\"\n      << self->dtype->name() << \"' \";\n  return (PyObject*)self;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyDInfo_bits(PyObject* self, void*) {\n  HANDLE_ERRORS\n  size_t bits = ASSERT(((PyDTypeInfo*)self)->dtype->bytes()) * 8;\n  return PyLong_FromSize_t(bits);\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyDInfo_min(PyObject* self, void*) {\n  HANDLE_ERRORS\n  DataType datatype = PyDTypeInfo_UnpackDataType(self);\n  PyObject* result = PyGetMinVal(datatype);\n  if (!result) {\n    THROW(RuntimeError) << PyDTypeInfo_UnpackDType(self)->name() << \" not supported by \"\n                        << self->ob_type->tp_name;\n  }\n  return result;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyDInfo_max(PyObject* self, void*) {\n  HANDLE_ERRORS\n  DataType datatype = PyDTypeInfo_UnpackDataType(self);\n  PyObject* result = PyGetMaxVal(datatype);\n  if (!result) {\n    THROW(RuntimeError) << PyDTypeInfo_UnpackDType(self)->name() << \" not supported by \"\n                        << self->ob_type->tp_name;\n  }\n  return result;\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyFInfo_resolution(PyObject* self, void*) {\n  HANDLE_ERRORS\n  DataType datatype = PyDTypeInfo_UnpackDataType(self);\n  switch (datatype) {\n    OF_PP_FOR_EACH_TUPLE(GET_FLOAT_RESOLUTION, INFO_FLOAT_TYPE_SEQ);\n    default:\n      THROW(RuntimeError) << PyDTypeInfo_UnpackDType(self)->name()\n                          << \" not supported by oneflow.finfo\";\n      return NULL;\n  }\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyFInfo_eps(PyObject* self, void*) {\n  HANDLE_ERRORS\n  DataType datatype = PyDTypeInfo_UnpackDataType(self);\n  switch (datatype) {\n    OF_PP_FOR_EACH_TUPLE(GET_FLOAT_EPS, INFO_FLOAT_TYPE_SEQ);\n    default:\n      THROW(RuntimeError) << PyDTypeInfo_UnpackDType(self)->name()\n                          << \" not supported by oneflow.finfo\";\n      return NULL;\n  }\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyFInfo_tiny(PyObject* self, void*) {\n  HANDLE_ERRORS\n  DataType datatype = PyDTypeInfo_UnpackDataType(self);\n  switch (datatype) {\n    OF_PP_FOR_EACH_TUPLE(GET_FLOAT_TINY, INFO_FLOAT_TYPE_SEQ);\n    default:\n      THROW(RuntimeError) << PyDTypeInfo_UnpackDType(self)->name()\n                          << \" not supported by oneflow.finfo\";\n      return NULL;\n  }\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyDInfo_dtype(PyObject* self, void*) {\n  HANDLE_ERRORS\n  std::string name = ((PyDTypeInfo*)self)->dtype->name();\n  name = name.erase(0, name.find('.') + 1);\n  return PyUnicode_FromString(name.data());\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyIInfo_str(PyObject* self) {\n  HANDLE_ERRORS\n  std::ostringstream oss;\n  oss << \"iinfo(min=\" << PyLong_AS_LONG(PyDInfo_min((PyObject*)self, NULL)) << \", \";\n  oss << \"max=\" << PyLong_AS_LONG(PyDInfo_max((PyObject*)self, NULL)) << \", \";\n  oss << \"dtype=\" << PyDTypeInfo_UnpackDType(self)->name() << \", \";\n  oss << \"bits=\" << PyLong_AS_LONG(PyDInfo_bits((PyObject*)self, NULL)) << \")\";\n  return PyUnicode_FromString(oss.str().data());\n  END_HANDLE_ERRORS\n}\n\nstatic PyObject* PyFInfo_str(PyObject* self) {\n  HANDLE_ERRORS\n  std::ostringstream oss;\n  oss << \"finfo(resolution=\" << PyFloat_AS_DOUBLE(PyFInfo_resolution((PyObject*)self, NULL))\n      << \", \";\n  oss << \"min=\" << PyFloat_AS_DOUBLE(PyDInfo_min((PyObject*)self, NULL)) << \", \";\n  oss << \"max=\" << PyFloat_AS_DOUBLE(PyDInfo_max((PyObject*)self, NULL)) << \", \";\n  oss << \"eps=\" << 
PyFloat_AS_DOUBLE(PyFInfo_eps((PyObject*)self, NULL)) << \", \";\n  oss << \"tiny=\" << PyFloat_AS_DOUBLE(PyFInfo_tiny((PyObject*)self, NULL)) << \", \";\n  oss << \"dtype=\" << PyDTypeInfo_UnpackDType(self)->name() << \", \";\n  oss << \"bits=\" << PyLong_AS_LONG(PyDInfo_bits((PyObject*)self, NULL)) << \")\";\n  return PyUnicode_FromString(oss.str().data());\n  END_HANDLE_ERRORS\n}\n\nstatic struct PyGetSetDef PyIInfo_properties[] = {\n    {PYGETSET_NAME(\"bits\"), (getter)PyDInfo_bits, nullptr, nullptr, nullptr},\n    {PYGETSET_NAME(\"max\"), (getter)PyDInfo_max, nullptr, nullptr, nullptr},\n    {PYGETSET_NAME(\"min\"), (getter)PyDInfo_min, nullptr, nullptr, nullptr},\n    {PYGETSET_NAME(\"dtype\"), (getter)PyDInfo_dtype, nullptr, nullptr, nullptr},\n    {nullptr},\n};\n\nstatic struct PyGetSetDef PyFInfo_properties[] = {\n    {PYGETSET_NAME(\"bits\"), (getter)PyDInfo_bits, nullptr, nullptr, nullptr},\n    {PYGETSET_NAME(\"max\"), (getter)PyDInfo_max, nullptr, nullptr, nullptr},\n    {PYGETSET_NAME(\"min\"), (getter)PyDInfo_min, nullptr, nullptr, nullptr},\n    {PYGETSET_NAME(\"resolution\"), (getter)PyFInfo_resolution, nullptr, nullptr, nullptr},\n    {PYGETSET_NAME(\"eps\"), (getter)PyFInfo_eps, nullptr, nullptr, nullptr},\n    {PYGETSET_NAME(\"tiny\"), (getter)PyFInfo_tiny, nullptr, nullptr, nullptr},\n    {PYGETSET_NAME(\"dtype\"), (getter)PyDInfo_dtype, nullptr, nullptr, nullptr},\n    {nullptr},\n};\n\nstatic void init_info_type() {\n  PyIInfoType.tp_flags = Py_TPFLAGS_DEFAULT;\n  PyIInfoType.tp_str = (reprfunc)PyIInfo_str;\n  PyIInfoType.tp_repr = (reprfunc)PyIInfo_str;\n  PyIInfoType.tp_new = (newfunc)PyIInfo_new;\n  PyIInfoType.tp_getset = PyIInfo_properties;\n  if (PyType_Ready(&PyIInfoType) < 0) { return; }\n\n  PyFInfoType.tp_flags = Py_TPFLAGS_DEFAULT;\n  PyFInfoType.tp_str = (reprfunc)PyFInfo_str;\n  PyFInfoType.tp_repr = (reprfunc)PyFInfo_str;\n  PyFInfoType.tp_new = (newfunc)PyFInfo_new;\n  PyFInfoType.tp_getset = PyFInfo_properties;\n  if (PyType_Ready(&PyFInfoType) < 0) { return; }\n}\n\nONEFLOW_API_PYBIND11_MODULE(\"_C\", m) {\n  init_info_type();\n  if (PyModule_AddObject(m.ptr(), \"iinfo\", (PyObject*)&PyIInfoType) < 0) return;\n  if (PyModule_AddObject(m.ptr(), \"finfo\", (PyObject*)&PyFInfoType) < 0) return;\n}\n\n}  // namespace one\n}  // namespace oneflow\n#undef ASSERT\n#undef GET_FLOAT_RESOLUTION\n#undef GET_FLOAT_EPS\n#undef GET_FLOAT_TINY\n#undef INFO_FLOAT_TYPE_SEQ\n#undef INFO_TYPE_SEQ\n#undef PYGETSET_NAME"
  },
  {
    "path": "oneflow/api/python/framework/typeinfo.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_API_PYTHON_FRAMEWORK_TYPEINFO_H_\n#define ONEFLOW_API_PYTHON_FRAMEWORK_TYPEINFO_H_\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/framework/dtype.h\"\n\nnamespace oneflow {\nnamespace one {\n\ntypedef struct {\n  PyObject_HEAD;\n  Symbol<DType> dtype;\n} PyDTypeInfo;\n\nextern PyTypeObject PyIInfoType;\nextern PyTypeObject PyFInfoType;\n\ninline bool PyIInfo_Check(PyObject* obj) { return PyObject_TypeCheck(obj, &PyIInfoType); }\ninline bool PyFInfo_Check(PyObject* obj) { return PyObject_TypeCheck(obj, &PyFInfoType); }\ninline bool PyDTypeInfo_Check(PyObject* obj) { return PyIInfo_Check(obj) || PyFInfo_Check(obj); }\n\ninline Symbol<DType> PyDTypeInfo_UnpackDType(PyObject* obj) {\n  assert(PyDTypeInfo_Check(obj));\n  return ((PyDTypeInfo*)obj)->dtype;\n}\n\ninline DataType PyDTypeInfo_UnpackDataType(PyObject* obj) {\n  assert(PyDTypeInfo_Check(obj));\n  return ((PyDTypeInfo*)obj)->dtype->data_type();\n}\n\n}  // namespace one\n}  // namespace oneflow\n#endif  // ONEFLOW_API_PYTHON_FRAMEWORK_TYPEINFO_H_\n"
  },
  {
    "path": "oneflow/api/python/framework/variable_tensor_mgr.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include <tuple>\n#include \"oneflow/api/common/variable_tensor_mgr.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"FillVariableTensorMgr\", &FillVariableTensorMgr);\n  m.def(\"DumpVariableTensorMgr\", &DumpVariableTensorMgr);\n  m.def(\"ResetVariableTensorMgr\", &ResetVariableTensorMgr);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/functional/common.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/python/functional/common.h\"\n#include <object.h>\n#include <string>\n#include <complex>\n\n#include \"oneflow/api/python/framework/memory_format.h\"\n#include \"oneflow/api/python/functional/indexing.h\"\n#include \"oneflow/extension/python/numpy.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/functional/tensor_index.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/framework/tensor_util.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace detail {\n\nnamespace {\n\ntemplate<typename T>\nMaybe<T> GetItemInPyScalarTensor(PyObject* obj) {\n  return GetItemInScalarTensor<T>(PyTensor_Unpack(obj));\n}\n\n}  // namespace\n\ntemplate<typename T, typename std::enable_if<!std::is_base_of<py::object, T>::value, int>::type = 0>\nbool isinstance_fast(PyObject* obj) {\n  static auto type = py::detail::get_type_handle(typeid(T), false);\n  if (!type) { return false; }\n  const auto result = PyObject_IsInstance(obj, type.ptr());\n  if (result == -1) { throw py::error_already_set(); }\n  return result != 0;\n}\n\ntemplate<typename T, typename std::enable_if<!std::is_base_of<py::object, T>::value\n                                                 && !py::detail::is_shared_ptr<T>::value,\n                                             int>::type = 0>\nconst T& cast_fast(PyObject* obj) {\n  auto vh = reinterpret_cast<py::detail::instance*>(obj)->get_value_and_holder();\n  auto*& vptr = vh.value_ptr();\n  if (!vptr) {\n    throw py::cast_error(\"Unable to cast from object to T& since lazy allocation is not allowed \"\n                         \"for fast cast, please use pybind11::cast instead\");\n  }\n  return *reinterpret_cast<T*>(&vptr);\n}\n\ntemplate<typename T, typename std::enable_if<!std::is_base_of<py::object, T>::value\n                                                 && py::detail::is_shared_ptr<T>::value,\n                                             int>::type = 0>\nconst T& cast_fast(PyObject* obj) {\n  auto vh = reinterpret_cast<py::detail::instance*>(obj)->get_value_and_holder();\n  if (!vh.holder_constructed()) {\n    throw py::cast_error(\"Unable to cast from non-held to held instance (T& to Holder<T>)\");\n  }\n  return vh.template holder<T>();\n}\n\n}  // namespace detail\n\nbool PySequenceCheck(PyObject* obj, const std::function<bool(PyObject*)>& item_check) {\n  bool is_tuple = 
PyTuple_Check(obj);\n  if (!is_tuple && !PyList_Check(obj)) { return false; }\n  size_t size = is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);\n  if (size == 0) { return true; }\n  PyObject* item = is_tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);\n  return item_check(item);\n}\n\nbool PyLongSequenceCheck(PyObject* obj) {\n  return PySequenceCheck(\n      obj, [](PyObject* item) { return PyLong_Check(item) || PyIntegerScalarTensorCheck(item); });\n}\n\nbool PyFloatSequenceCheck(PyObject* obj) {\n  return PySequenceCheck(obj, [](PyObject* item) {\n    return PyFloat_Check(item) || PyLong_Check(item) || PyFloatScalarTensorCheck(item)\n           || PyIntegerScalarTensorCheck(item);\n  });\n}\n\nbool PyStringCheck(PyObject* obj) { return PyBytes_Check(obj) || PyUnicode_Check(obj); }\n\nbool PyStringSequenceCheck(PyObject* obj) {\n  return PySequenceCheck(obj, [](PyObject* item) { return PyStringCheck(item); });\n}\n\nstd::string PyStringAsString(PyObject* obj) {\n  PyObject* bytes = PyUnicode_AsEncodedString(obj, \"utf-8\", \"~E~\");\n  std::string str = PyBytes_AS_STRING(bytes);\n  Py_XDECREF(bytes);\n  return str;\n}\n\nstd::string PyObjectToReprStr(PyObject* obj) {\n  PyObject* repr_obj = PyObject_Repr(obj);\n  std::string str = PyStringAsString(repr_obj);\n  Py_XDECREF(repr_obj);\n  return str;\n}\n\n// Tensor list\nbool PyTensorSequenceCheck(PyObject* obj) {\n  return PySequenceCheck(obj, [](PyObject* item) { return PyTensor_Check(item); });\n}\nstd::vector<std::shared_ptr<Tensor>> PyUnpackTensorSequence(PyObject* obj) {\n  return PyUnpackSequence<std::shared_ptr<Tensor>>(\n      obj, [](PyObject* item) { return PyTensor_Unpack(item); });\n}\n\n// TensorTuple\nbool PyTensorTupleCheck(PyObject* obj) { return detail::isinstance_fast<TensorTuple>(obj); }\n\nstd::shared_ptr<TensorTuple> PyUnpackTensorTuple(PyObject* obj) {\n  return detail::cast_fast<std::shared_ptr<TensorTuple>>(obj);\n}\n\n// Scalar\nbool PyScalarCheck(PyObject* obj) {\n  return PyLong_Check(obj) || PyFloat_Check(obj) || PyComplex_Check(obj);\n}\n\nScalar PyUnpackScalar(PyObject* obj) {\n  if (PyBool_Check(obj)) {\n    return obj == Py_True;\n  } else if (PyLong_Check(obj)) {\n    return static_cast<int64_t>(PyLong_AsLongLong(obj));\n  } else if (PyFloat_Check(obj)) {\n    return PyFloat_AsDouble(obj);\n  } else if (PyComplex_Check(obj)) {\n    Py_complex value = PyComplex_AsCComplex(obj);\n    return std::complex<double>{value.real, value.imag};\n  } else if (PyArray_IsScalar(obj, Bool)) {\n    // A numpy bool scalar is a singleton distinct from Py_True, so read its value directly.\n    return PyArrayScalar_VAL(obj, Bool) != 0;\n  } else if (PyArray_IsScalar(obj, Floating)) {\n    return PyFloat_AsDouble(obj);\n  } else if (PyArray_IsScalar(obj, Complex64) || PyArray_IsScalar(obj, Complex128)) {\n    Py_complex value = PyComplex_AsCComplex(obj);\n    return std::complex<double>{value.real, value.imag};\n  }\n  THROW(RuntimeError) << \"The object is not scalar, but is \" << Py_TYPE(obj)->tp_name;\n  return 0;\n}\n\n// Scalar Tensor\nbool PyScalarTensorCheck(PyObject* obj) {\n  if (!LazyMode::is_enabled() && PyTensor_Check(obj)) {\n    const auto& tensor = PyTensor_Unpack(obj);\n    return tensor->shape()->size() == 0\n           && IsTriviallyCopyableDataType(tensor->dtype()->data_type());\n  }\n  return false;\n}\n\nScalar PyUnpackScalarTensor(PyObject* obj) {\n  if (PyBoolScalarTensorCheck(obj)) {\n    return PyUnpackBoolScalarTensor(obj);\n  } else if (PyIntegerScalarTensorCheck(obj)) {\n    return PyUnpackIntegerScalarTensor_AsLongLong(obj);\n  } else if (PyFloatScalarTensorCheck(obj)) {\n    return 
PyUnpackFloatScalarTensor_AsDouble(obj);\n  } else if (PyComplexScalarTensorCheck(obj)) {\n    return PyUnpackComplexScalarTensor_AsCComplex(obj);\n  }\n  THROW(RuntimeError) << \"The object is not scalar tensor, but is \" << Py_TYPE(obj)->tp_name\n                      << \" with data type: \"\n                      << DataType_Name(PyTensor_Unpack(obj)->dtype()->data_type());\n  return 0;\n}\n\n#define SWITCH_SCALAR_TENSOR_TO_SCALAR(cpp_type, of_type) \\\n  case of_type:                                           \\\n    return detail::GetItemInPyScalarTensor<cpp_type>(obj).GetOrThrow();\n\n#define SCALAR_TENSOR_UNPACK_FUNC_IMPL(func_name, return_type, type_seq)                  \\\n  return_type func_name(PyObject* obj) {                                                  \\\n    const auto& tensor = PyTensor_Unpack(obj);                                            \\\n    DataType data_type = tensor->dtype()->data_type();                                    \\\n    switch (data_type) {                                                                  \\\n      OF_PP_FOR_EACH_TUPLE(SWITCH_SCALAR_TENSOR_TO_SCALAR, type_seq)                      \\\n      default: {                                                                          \\\n        throw py::cast_error(\"Cannot get \" #return_type \" from scalar tensor with data type: \" \\\n                             + DataType_Name(data_type));                                 \\\n      }                                                                                   \\\n    }                                                                                     \\\n  }\n\nSCALAR_TENSOR_UNPACK_FUNC_IMPL(PyUnpackBoolScalarTensor, bool,\n                               BOOL_DATA_TYPE_SEQ CHAR_DATA_TYPE_SEQ);\nSCALAR_TENSOR_UNPACK_FUNC_IMPL(PyUnpackIntegerScalarTensor_AsLongLong, long long,\n                               INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ\n                                   CHAR_DATA_TYPE_SEQ);\nSCALAR_TENSOR_UNPACK_FUNC_IMPL(PyUnpackFloatScalarTensor_AsDouble, double,\n                               FLOATING_DATA_TYPE_SEQ INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ);\nSCALAR_TENSOR_UNPACK_FUNC_IMPL(PyUnpackComplexScalarTensor_AsCComplex, std::complex<double>,\n                               COMPLEX_DATA_TYPE_SEQ FLOATING_DATA_TYPE_SEQ INT_DATA_TYPE_SEQ\n                                   UNSIGNED_INT_DATA_TYPE_SEQ);\n#undef SWITCH_SCALAR_TENSOR_TO_SCALAR\n#undef SCALAR_TENSOR_UNPACK_FUNC_IMPL\n\n// DType\nbool PyDTypeCheck(PyObject* obj) { return detail::isinstance_fast<Symbol<DType>>(obj); }\nSymbol<DType> PyUnpackDType(PyObject* obj) { return *detail::cast_fast<Symbol<DType>*>(obj); }\n\n// Layout\nbool PyLayoutCheck(PyObject* obj) { return detail::isinstance_fast<Symbol<Layout>>(obj); }\nSymbol<Layout> PyUnpackLayout(PyObject* obj) { return *detail::cast_fast<Symbol<Layout>*>(obj); }\n\n// Memory Format\nbool PyMemoryFormatCheck(PyObject* obj) { return PyMemoryFormat_Check(obj); }\nMemoryFormat PyUnpackMemoryFormat(PyObject* obj) { return PyMemoryFormat_Unpack(obj); }\n\n// DType list\nbool PyDTypeSequenceCheck(PyObject* obj) {\n  return PySequenceCheck(obj, [](PyObject* item) { return PyDTypeCheck(item); });\n}\nstd::vector<Symbol<DType>> PyUnpackDTypeSequence(PyObject* obj) {\n  return PyUnpackSequence<Symbol<DType>>(obj, [](PyObject* item) { return PyUnpackDType(item); });\n}\n\n// Shape\nbool PyShapeCheck(PyObject* obj) { return PyLongSequenceCheck(obj); }\n\nShape PyUnpackShape(PyObject* obj) {\n  
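// Unpack a Python list or tuple of integers into a Shape: both (2, 3) and [2, 3]\n  // yield a 2-D Shape. Items are read with PyLong_AsLongLong; per-item conversion\n  // errors are not checked here.\n  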
bool is_tuple = PyTuple_Check(obj);\n  CHECK_OR_THROW(is_tuple || PyList_Check(obj))\n      << \"The object is not list or tuple, but is \" << Py_TYPE(obj)->tp_name;\n  size_t size = is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);\n  DimVector values(size);\n  for (int i = 0; i < size; ++i) {\n    PyObject* item = is_tuple ? PyTuple_GET_ITEM(obj, i) : PyList_GET_ITEM(obj, i);\n    values[i] = PyLong_AsLongLong(item);\n  }\n  return Shape(values);\n}\n\n// Shape list\nbool PyShapeSequenceCheck(PyObject* obj) {\n  return PySequenceCheck(obj, [](PyObject* item) { return PyLongSequenceCheck(item); });\n}\nstd::vector<Shape> PyUnpackShapeSequence(PyObject* obj) {\n  return PyUnpackSequence<Shape>(obj, [](PyObject* item) -> Shape { return PyUnpackShape(item); });\n}\n\n// Generator\nbool PyGeneratorCheck(PyObject* obj) { return detail::isinstance_fast<Generator>(obj); }\nstd::shared_ptr<Generator> PyUnpackGenerator(PyObject* obj) {\n  return detail::cast_fast<std::shared_ptr<one::Generator>>(obj);\n}\n\n// Device\nbool PyDeviceCheck(PyObject* obj) { return detail::isinstance_fast<Symbol<Device>>(obj); }\nSymbol<Device> PyUnpackDevice(PyObject* obj) {\n  return *detail::cast_fast<std::shared_ptr<Symbol<Device>>>(obj);\n}\n\n// Placement\nbool PyParallelDescCheck(PyObject* obj) {\n  return detail::isinstance_fast<Symbol<ParallelDesc>>(obj);\n}\nSymbol<ParallelDesc> PyUnpackParallelDesc(PyObject* obj) {\n  return *detail::cast_fast<std::shared_ptr<Symbol<ParallelDesc>>>(obj);\n}\n\n// SBP\nbool PySbpParallelCheck(PyObject* obj) { return detail::isinstance_fast<Symbol<SbpParallel>>(obj); }\nSymbol<SbpParallel> PyUnpackSbpParallel(PyObject* obj) {\n  return *detail::cast_fast<std::shared_ptr<Symbol<SbpParallel>>>(obj);\n}\n\n// SBP list\nbool PySbpParallelSequenceCheck(PyObject* obj) {\n  return PySequenceCheck(obj, [](PyObject* item) { return PySbpParallelCheck(item); });\n}\nstd::vector<Symbol<SbpParallel>> PyUnpackSbpParallelSequence(PyObject* obj) {\n  return PyUnpackSequence<Symbol<SbpParallel>>(\n      obj, [](PyObject* item) { return PyUnpackSbpParallel(item); });\n}\n\n// Tensor index\nbool PyTensorIndexCheck(PyObject* obj) {\n  return PySlice_Check(obj) || PyLong_Check(obj) || obj == Py_Ellipsis || obj == Py_None\n         || PyTensor_Check(obj) || PySequence_Check(obj) || PyUnicode_Check(obj)\n         || numpy::PyArrayCheckLongScalar(obj);\n}\nTensorIndex PyUnpackTensorIndex(PyObject* obj) {\n  TensorIndex tensor_index;\n  // Obvious single-entry cases.\n  if (PySlice_Check(obj)                     // NOLINT\n      || PyLong_Check(obj)                   // NOLINT\n      || obj == Py_Ellipsis                  // NOLINT\n      || obj == Py_None                      // NOLINT\n      || PyTensor_Check(obj)                 // NOLINT\n      || !PySequence_Check(obj)              // NOLINT\n      || numpy::PyArrayCheckLongScalar(obj)  // NOLINT\n      || PyUnicode_Check(obj)) {\n    tensor_index.emplace_back(detail::UnpackIndexItem(obj));\n    return tensor_index;\n  }\n  PyObject* tup = NULL;\n  Py_ssize_t n = 0;\n  if (PyTuple_Check(obj)) {\n    tup = PySequence_Tuple(obj);\n    n = PySequence_Size(tup);\n  } else {\n    // The following comments are from numpy:\n    // https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/mapping.c#L266\n    /*\n     * At this point, we're left with a non-tuple, non-array, sequence:\n     * typically, a list. 
We use some somewhat-arbitrary heuristics from here\n     * onwards to decide whether to treat that list as a single index, or a\n     * list of indices.\n     */\n    n = PySequence_Size(obj);\n    // Negative size indicates a Python error in the PySequence_Size call.\n    if (n < 0) {\n      PyErr_Clear();\n      tensor_index.emplace_back(detail::UnpackIndexItem(obj));\n      return tensor_index;\n    }\n    // The following comments are from numpy:\n    // https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/mapping.c#L280\n    /*\n     * Backwards compatibility only takes effect for short sequences - otherwise\n     * we treat it like any other scalar.\n     *\n     * Sequences < NPY_MAXDIMS with any slice objects\n     * or newaxis, Ellipsis or other arrays or sequences\n     * embedded, are considered equivalent to an indexing\n     * tuple. (`a[[[1,2], [3,4]]] == a[[1,2], [3,4]]`)\n     */\n    if (n >= /*NPY_MAXDIMS=*/32) {\n      tensor_index.emplace_back(detail::UnpackIndexItem(obj));\n      return tensor_index;\n    }\n    // Check whether we should unpack the index like a tuple.\n    bool commit_to_unpack = false;\n    for (Py_ssize_t i = 0; i < n; ++i) {\n      PyObject* item = PySequence_GetItem(obj, i);\n      if (commit_to_unpack) {\n        CHECK_OR_THROW(item) << \"Sequence index is required.\";\n      } else {\n        if (!item) {\n          PyErr_Clear();\n          break;\n        }\n        if (PySequence_Check(item)   // NOLINT\n            || PySlice_Check(item)   // NOLINT\n            || PyTensor_Check(item)  // NOLINT\n            || item == Py_Ellipsis || item == Py_None) {\n          commit_to_unpack = true;\n        }\n      }\n      Py_DECREF(item);\n    }\n    if (commit_to_unpack) {\n      tup = PySequence_Tuple(obj);\n    } else {\n      tensor_index.emplace_back(detail::UnpackIndexItem(obj));\n      return tensor_index;\n    }\n  }\n\n  tensor_index.resize(n);\n  for (Py_ssize_t i = 0; i < n; ++i) {\n    PyObject* item = PySequence_GetItem(tup, i);\n    tensor_index[i] = detail::UnpackIndexItem(item);\n    Py_DECREF(item);\n  }\n  Py_DECREF(tup);\n  return tensor_index;\n}\n\n// OpExpr\nbool PyOpExprCheck(PyObject* obj) { return detail::isinstance_fast<OpExpr>(obj); }\n\nstd::shared_ptr<OpExpr> PyUnpackOpExpr(PyObject* obj) {\n  return detail::cast_fast<std::shared_ptr<OpExpr>>(obj);\n}\n\n// int64_t\nMaybe<int64_t> PyUnpackLong(PyObject* py_obj) {\n  int overflow = -1;\n  long long val = PyLong_AsLongLongAndOverflow(py_obj, &overflow);\n  if (val == -1 && PyErr_Occurred()) { return Error::RuntimeError() << \"Python exception occurs\"; }\n  if (overflow != 0) { return Error::RuntimeError() << \"Overflow when unpacking long\"; }\n  return (int64_t)val;\n}\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/functional/common.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_COMMON_H_\n#define ONEFLOW_API_PYTHON_FUNCTIONAL_COMMON_H_\n\n#include <string>\n#include <vector>\n#include <complex>\n#include <pybind11/pybind11.h>\n\n#include \"oneflow/api/python/framework/tensor.h\"\n#include \"oneflow/api/python/caster/maybe.h\"\n#include \"oneflow/api/python/caster/optional.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/layout.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#include \"oneflow/core/functional/tensor_index.h\"\n#include \"oneflow/core/common/foreign_lock_helper.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nstruct PyObjectPtrDeleter {\n  inline void operator()(PyObject* obj) {\n    CHECK_JUST(Singleton<ForeignLockHelper>::Get()->WithScopedAcquire([&]() -> Maybe<void> {\n      if (obj) { Py_DECREF(obj); }\n      obj = NULL;\n      return Maybe<void>::Ok();\n    }));\n  }\n};\n\nusing PyObjectPtr = std::unique_ptr<PyObject, PyObjectPtrDeleter>;\n\n#define INTEGER_AND_BOOL_TYPE_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t)   \\\n  OF_PP_MAKE_TUPLE_SEQ(uint32_t)  \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t)   \\\n  OF_PP_MAKE_TUPLE_SEQ(uint64_t)  \\\n  OF_PP_MAKE_TUPLE_SEQ(bool)\n\n#define FLOATING_TYPE_SEQ     \\\n  OF_PP_MAKE_TUPLE_SEQ(float) \\\n  OF_PP_MAKE_TUPLE_SEQ(double)\n\nbool PySequenceCheck(PyObject* obj);\nbool PySequenceCheck(PyObject* obj, const std::function<bool(PyObject*)>& item_check);\n\ntemplate<typename T, typename UnpackItemFunc>\ninline std::vector<T> PyUnpackSequence(PyObject* obj, UnpackItemFunc unpack_item) {\n  bool is_tuple = PyTuple_Check(obj);\n  CHECK_OR_THROW(is_tuple || PyList_Check(obj))\n      << \"The object is not list or tuple, but is \" << Py_TYPE(obj)->tp_name;\n  size_t size = is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);\n  std::vector<T> values(size);\n  for (int i = 0; i < size; ++i) {\n    PyObject* item = is_tuple ? 
PyTuple_GET_ITEM(obj, i) : PyList_GET_ITEM(obj, i);\n    values[i] = unpack_item(item);\n  }\n  return values;\n}\n\n// Scalar Tensor\nbool PyScalarTensorCheck(PyObject* obj);\nScalar PyUnpackScalarTensor(PyObject* obj);\n\n#define DefinePyTypeScalarTensorCheck(type, type_check_func)               \\\n  inline bool Py##type##ScalarTensorCheck(PyObject* obj) {                 \\\n    return PyScalarTensorCheck(obj)                                        \\\n           && type_check_func(PyTensor_Unpack(obj)->dtype()->data_type()); \\\n  }\n\nDefinePyTypeScalarTensorCheck(Bool, IsBoolDataType);         // PyBoolScalarTensorCheck\nDefinePyTypeScalarTensorCheck(Integer, IsIntegralDataType);  // PyIntegerScalarTensorCheck\nDefinePyTypeScalarTensorCheck(Float, IsFloatingDataType);    // PyFloatScalarTensorCheck\nDefinePyTypeScalarTensorCheck(Complex, IsComplexDataType);   // PyComplexScalarTensorCheck\n#undef DefinePyTypeScalarTensorCheck\n\nbool PyUnpackBoolScalarTensor(PyObject* obj);\nlong long PyUnpackIntegerScalarTensor_AsLongLong(PyObject* obj);\ndouble PyUnpackFloatScalarTensor_AsDouble(PyObject* obj);\nstd::complex<double> PyUnpackComplexScalarTensor_AsCComplex(PyObject* obj);\n\n// Integer/Float list\nbool PyLongSequenceCheck(PyObject* obj);\nbool PyFloatSequenceCheck(PyObject* obj);\n\ntemplate<typename T>\ninline std::vector<T> PyUnpackLongSequence(PyObject* obj) {\n  return PyUnpackSequence<T>(obj, [](PyObject* item) -> T {\n    if (PyIntegerScalarTensorCheck(item)) {\n      return static_cast<T>(PyUnpackIntegerScalarTensor_AsLongLong(item));\n    }\n    return static_cast<T>(PyLong_AsLongLong(item));\n  });\n}\n\ntemplate<typename T>\ninline std::vector<T> PyUnpackFloatSequence(PyObject* obj) {\n  return PyUnpackSequence<T>(obj, [](PyObject* item) -> T {\n    if (PyFloatScalarTensorCheck(item)) {\n      return static_cast<T>(PyUnpackFloatScalarTensor_AsDouble(item));\n    }\n    return static_cast<T>(PyFloat_AsDouble(item));\n  });\n}\n\n// String\nbool PyStringCheck(PyObject* obj);\nbool PyStringSequenceCheck(PyObject* obj);\n\nstd::string PyStringAsString(PyObject* obj);\n\nstd::string PyObjectToReprStr(PyObject* obj);\n\n// Scalar\nbool PyScalarCheck(PyObject* obj);\nScalar PyUnpackScalar(PyObject* obj);\n\n// Tensor list\nbool PyTensorSequenceCheck(PyObject* obj);\nstd::vector<std::shared_ptr<Tensor>> PyUnpackTensorSequence(PyObject* obj);\n\n// TensorTuple\nbool PyTensorTupleCheck(PyObject* obj);\nstd::shared_ptr<TensorTuple> PyUnpackTensorTuple(PyObject* obj);\n\n// DType\nbool PyDTypeCheck(PyObject* obj);\nSymbol<DType> PyUnpackDType(PyObject* obj);\n\n// Layout\nbool PyLayoutCheck(PyObject* obj);\nSymbol<Layout> PyUnpackLayout(PyObject* obj);\n\n// Memory Format\nbool PyMemoryFormatCheck(PyObject* obj);\nMemoryFormat PyUnpackMemoryFormat(PyObject* obj);\n\n// DType list\nbool PyDTypeSequenceCheck(PyObject* obj);\nstd::vector<Symbol<DType>> PyUnpackDTypeSequence(PyObject* obj);\n\n// Shape\nbool PyShapeCheck(PyObject* obj);\nShape PyUnpackShape(PyObject* obj);\n\n// Shape list\nbool PyShapeSequenceCheck(PyObject* obj);\nstd::vector<Shape> PyUnpackShapeSequence(PyObject* obj);\n\n// Generator\nbool PyGeneratorCheck(PyObject* obj);\nstd::shared_ptr<Generator> PyUnpackGenerator(PyObject* obj);\n\n// Device\nbool PyDeviceCheck(PyObject* obj);\nSymbol<Device> PyUnpackDevice(PyObject* obj);\n\n// Placement\nbool PyParallelDescCheck(PyObject* obj);\nSymbol<ParallelDesc> PyUnpackParallelDesc(PyObject* obj);\n\n// SBP\nbool PySbpParallelCheck(PyObject* obj);\nSymbol<SbpParallel> 
PyUnpackSbpParallel(PyObject* obj);\n\n// SBP list\nbool PySbpParallelSequenceCheck(PyObject* obj);\nstd::vector<Symbol<SbpParallel>> PyUnpackSbpParallelSequence(PyObject* obj);\n\n// Tensor index\nbool PyTensorIndexCheck(PyObject* obj);\nTensorIndex PyUnpackTensorIndex(PyObject* obj);\n\n// OpExpr\nbool PyOpExprCheck(PyObject* obj);\nstd::shared_ptr<OpExpr> PyUnpackOpExpr(PyObject* obj);\n\ntemplate<typename T>\ninline PyObject* CastToPyObject(T&& t) {\n  return py::cast(t).inc_ref().ptr();\n}\n\ntemplate<>\ninline PyObject* CastToPyObject<Maybe<Tensor>>(Maybe<Tensor>&& t) {\n  return PyTensor_New(t.GetPtrOrThrow());\n}\n\ntemplate<>\ninline PyObject* CastToPyObject<Maybe<TensorTuple>>(Maybe<TensorTuple>&& t) {\n  const auto& tensor_tuple = t.GetPtrOrThrow();\n  py::tuple tup(tensor_tuple->size());\n  for (int i = 0; i < tensor_tuple->size(); ++i) { tup[i] = py::cast(tensor_tuple->at(i)); }\n  return py::cast<py::object>(tup).inc_ref().ptr();\n}\n\ntemplate<>\ninline PyObject* CastToPyObject<Maybe<void>>(Maybe<void>&& t) {\n  t.GetOrThrow();\n  Py_RETURN_NONE;\n}\n\n// int64_t\nMaybe<int64_t> PyUnpackLong(PyObject* py_obj);\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_FUNCTIONAL_COMMON_H_\n"
  },
  {
    "path": "oneflow/api/python/functional/dispatch_stateful_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/op_interpreter/lazy_op_interpreter.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/function_library.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor(\n      \"DispatchFeedInput\",\n      [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input) -> Maybe<Tensor> {\n        const auto& origin_input = JUST(OpInterpUtil::Dispatch<Tensor>(*op, {input}));\n        // Unpack input when do grad acc\n        return GradAccTryInsertUnpackAfterInput(origin_input);\n      });\n  m.add_functor(\n      \"DispatchFetchOutput\",\n      [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input) -> Maybe<Tensor> {\n        // Pack output when do grad acc\n        const auto& pack_input = JUST(GradAccTryInsertPackBeforeOutput(input));\n        return OpInterpUtil::Dispatch<Tensor>(*op, {pack_input});\n      });\n  m.add_functor(\"DispatchFeedVariable\",\n                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,\n                   const Scalar& l2) -> Maybe<Tensor> {\n                  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"l2\");\n                  attrs.SetAllAttrs(l2.As<double>());\n                  const auto& origin_var =\n                      JUST(OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs));\n                  // Repeat variable when do grad acc\n                  return GradAccTryInsertRepeatAfterVar(origin_var);\n                });\n  m.add_functor(\n      \"DispatchOfrecordReader\",\n      [](const std::shared_ptr<OpExpr>& op, const std::string& data_dir, int32_t data_part_num,\n         const std::string& part_name_prefix, int32_t part_name_suffix_length, int32_t batch_size,\n         int32_t shuffle_buffer_size, bool random_shuffle, bool shuffle_after_epoch, int64_t seed,\n         const Optional<Symbol<Device>>& device) -> Maybe<Tensor> {\n        auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\n            \"data_dir\", \"data_part_num\", \"part_name_prefix\", \"part_name_suffix_length\",\n            \"batch_size\", \"shuffle_buffer_size\", \"random_shuffle\", \"shuffle_after_epoch\", \"seed\");\n        attrs.SetAllAttrs(data_dir, data_part_num, part_name_prefix, part_name_suffix_length,\n                          batch_size, shuffle_buffer_size, random_shuffle, shuffle_after_epoch,\n                          seed);\n        return OpInterpUtil::Dispatch<Tensor>(*op, {}, 
OpExprInterpContext(attrs, JUST(device)));\n      });\n  m.add_functor(\n      \"DispatchOfrecordReader\",\n      [](const std::shared_ptr<OpExpr>& op, const std::string& data_dir, int32_t data_part_num,\n         const std::string& part_name_prefix, int32_t part_name_suffix_length, int32_t batch_size,\n         int32_t shuffle_buffer_size, bool random_shuffle, bool shuffle_after_epoch, int64_t seed,\n         const Symbol<ParallelDesc>& placement,\n         const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {\n        auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\n            \"data_dir\", \"data_part_num\", \"part_name_prefix\", \"part_name_suffix_length\",\n            \"batch_size\", \"shuffle_buffer_size\", \"random_shuffle\", \"shuffle_after_epoch\", \"seed\",\n            \"nd_sbp\");\n        attrs.SetAllAttrs(data_dir, data_part_num, part_name_prefix, part_name_suffix_length,\n                          batch_size, shuffle_buffer_size, random_shuffle, shuffle_after_epoch,\n                          seed, *JUST(GetNdSbpStrList(sbp_tuple)));\n        auto nd_sbp = JUST(GetNdSbp(sbp_tuple));\n        return OpInterpUtil::Dispatch<Tensor>(*op, {},\n                                              OpExprInterpContext(attrs, placement, nd_sbp));\n      });\n  m.add_functor(\"DispatchOfrecordRawDecoder\",\n                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,\n                   const std::string& name, const Shape& shape, const Symbol<DType>& data_type,\n                   bool dim1_varying_length, bool truncate) -> Maybe<Tensor> {\n                  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"name\", \"shape\", \"data_type\",\n                                                               \"dim1_varying_length\", \"truncate\");\n                  attrs.SetAllAttrs(name, shape, data_type->data_type(), dim1_varying_length,\n                                    truncate);\n                  return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);\n                });\n  m.add_functor(\n      \"DispatchCoinFlip\",\n      [](const std::shared_ptr<OpExpr>& op, int64_t batch_size, Scalar probability, int64_t seed,\n         bool has_seed, const Optional<Symbol<Device>>& device) -> Maybe<Tensor> {\n        auto& attrs =\n            THREAD_CACHED_MUTABLE_ATTR_MAP(\"probability\", \"batch_size\", \"seed\", \"has_seed\");\n        attrs.SetAllAttrs(probability.As<float>(), batch_size, seed, has_seed);\n        return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, JUST(device)));\n      });\n  m.add_functor(\"DispatchCoinFlip\",\n                [](const std::shared_ptr<OpExpr>& op, int64_t batch_size, Scalar probability,\n                   int64_t seed, bool has_seed, const Symbol<ParallelDesc>& placement,\n                   const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {\n                  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"probability\", \"batch_size\", \"seed\",\n                                                               \"has_seed\", \"nd_sbp\");\n                  attrs.SetAllAttrs(probability.As<float>(), batch_size, seed, has_seed,\n                                    *JUST(GetNdSbpStrList(sbp_tuple)));\n                  auto nd_sbp = JUST(GetNdSbp(sbp_tuple));\n                  return OpInterpUtil::Dispatch<Tensor>(\n                      *op, {}, OpExprInterpContext(attrs, placement, nd_sbp));\n                });\n  m.add_functor(\n      
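// NOTE: \"Paritical\" is a misspelling of \"Partial\"; the name is kept as-is because it\n      // must match \"=> DispatchDistributedPariticalFCSample\" in dispatch_stateful_ops.yaml.\n      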
\"DispatchDistributedPariticalFCSample\",\n      [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& weight,\n         const std::shared_ptr<Tensor>& label, const int64_t& num_sample) -> Maybe<TensorTuple> {\n        auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"num_sample\");\n        attrs.SetAllAttrs(num_sample);\n        return OpInterpUtil::Dispatch<TensorTuple>(*op, {weight, label}, attrs);\n      });\n  m.add_functor(\n      \"DispatchCropMirrorNormalizeFromUint8\",\n      [](const std::shared_ptr<OpExpr>& op, const TensorTuple& input, int64_t crop_h,\n         int64_t crop_w, float crop_pos_x, float crop_pos_y, const std::vector<float>& mean,\n         const std::vector<float>& std, const Symbol<DType>& output_dtype,\n         const std::string& output_layout, const std::string& color_space) -> Maybe<Tensor> {\n        auto& attrs =\n            THREAD_CACHED_MUTABLE_ATTR_MAP(\"color_space\", \"output_layout\", \"mean\", \"std\", \"crop_h\",\n                                           \"crop_w\", \"crop_pos_x\", \"crop_pos_y\", \"output_dtype\");\n        attrs.SetAllAttrs(color_space, output_layout, mean, std, crop_h, crop_w, crop_pos_x,\n                          crop_pos_y, output_dtype->data_type());\n        return OpInterpUtil::Dispatch<Tensor>(*op, input, attrs);\n      });\n  m.add_functor(\n      \"DispatchCropMirrorNormalizeFromTensorBuffer\",\n      [](const std::shared_ptr<OpExpr>& op, const TensorTuple& input, int64_t crop_h,\n         int64_t crop_w, float crop_pos_x, float crop_pos_y, const std::vector<float>& mean,\n         const std::vector<float>& std, const Symbol<DType>& output_dtype,\n         const std::string& output_layout, const std::string& color_space) -> Maybe<Tensor> {\n        auto& attrs =\n            THREAD_CACHED_MUTABLE_ATTR_MAP(\"color_space\", \"output_layout\", \"mean\", \"std\", \"crop_h\",\n                                           \"crop_w\", \"crop_pos_x\", \"crop_pos_y\", \"output_dtype\");\n        attrs.SetAllAttrs(color_space, output_layout, mean, std, crop_h, crop_w, crop_pos_x,\n                          crop_pos_y, output_dtype->data_type());\n        return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);\n      });\n  m.add_functor(\n      \"DispatchOfrecordImageDecoderRandomCrop\",\n      [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,\n         const std::string& name, const std::string& color_space,\n         const std::vector<float>& random_area, const std::vector<float>& random_aspect_ratio,\n         int32_t num_attempts, int64_t seed, bool has_seed) -> Maybe<Tensor> {\n        auto& attrs =\n            THREAD_CACHED_MUTABLE_ATTR_MAP(\"name\", \"color_space\", \"num_attempts\", \"seed\",\n                                           \"has_seed\", \"random_area\", \"random_aspect_ratio\");\n        attrs.SetAllAttrs(name, color_space, num_attempts, seed, has_seed, random_area,\n                          random_aspect_ratio);\n        return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);\n      });\n  m.add_functor(\"DispatchOfrecordImageDecoder\",\n                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,\n                   const std::string& name, const std::string& color_space) -> Maybe<Tensor> {\n                  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"name\", \"color_space\");\n                  attrs.SetAllAttrs(name, color_space);\n                  return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);\n              
  });\n  m.add_functor(\"DispatchImageDecoderRandomCropResize\",\n                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,\n                   int64_t target_width, int64_t target_height, int64_t seed, int64_t num_workers,\n                   int64_t max_num_pixels, float random_area_min, float random_area_max,\n                   float random_aspect_ratio_min, float random_aspect_ratio_max,\n                   int64_t warmup_size, int64_t num_attempts) -> Maybe<Tensor> {\n                  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\n                      \"target_width\", \"target_height\", \"seed\", \"num_workers\", \"max_num_pixels\",\n                      \"random_area_min\", \"random_area_max\", \"random_aspect_ratio_min\",\n                      \"random_aspect_ratio_max\", \"warmup_size\", \"num_attempts\");\n                  attrs.SetAllAttrs(target_width, target_height, seed, num_workers, max_num_pixels,\n                                    random_area_min, random_area_max, random_aspect_ratio_min,\n                                    random_aspect_ratio_max, warmup_size, num_attempts);\n                  return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);\n                });\n  m.add_functor(\n      \"DispatchTensorBufferToListOfTensorsV2\",\n      [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,\n         const std::vector<Shape>& out_shapes, const std::vector<Symbol<DType>>& out_dtypes,\n         bool dynamic_out) -> Maybe<TensorTuple> {\n        auto out_data_types = std::vector<DataType>();\n        for (auto it = out_dtypes.begin(); it != out_dtypes.end(); it++) {\n          out_data_types.emplace_back((*it)->data_type());\n        }\n        auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"out_shapes\", \"dynamic_out\", \"out_dtypes\");\n        attrs.SetAllAttrs(out_shapes, dynamic_out, out_data_types);\n        return OpInterpUtil::Dispatch<TensorTuple>(*op, {input}, attrs);\n      });\n  m.add_functor(\"DispatchImageResizeKeepAspectRatio\",\n                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,\n                   int32_t target_size, int32_t min_size, int32_t max_size, bool resize_longer,\n                   const std::string& interpolation_type) -> Maybe<TensorTuple> {\n                  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\n                      \"target_size\", \"min_size\", \"max_size\", \"resize_longer\", \"interpolation_type\");\n                  attrs.SetAllAttrs(target_size, min_size, max_size, resize_longer,\n                                    interpolation_type);\n                  return OpInterpUtil::Dispatch<TensorTuple>(*op, {input}, attrs);\n                });\n  m.add_functor(\"DispatchImageResizeToFixed\",\n                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,\n                   int64_t target_width, int64_t target_height, int64_t channels,\n                   const Symbol<DType>& data_type,\n                   const std::string& interpolation_type) -> Maybe<TensorTuple> {\n                  auto& attrs =\n                      THREAD_CACHED_MUTABLE_ATTR_MAP(\"target_width\", \"target_height\", \"channels\",\n                                                     \"data_type\", \"interpolation_type\");\n                  attrs.SetAllAttrs(target_width, target_height, channels, data_type->data_type(),\n                                    interpolation_type);\n                  return 
OpInterpUtil::Dispatch<TensorTuple>(*op, {input}, attrs);\n                });\n  m.add_functor(\n      \"DispatchImageDecode\",\n      [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,\n         const std::string& color_space, const Symbol<DType>& data_type) -> Maybe<Tensor> {\n        auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"color_space\", \"data_type\");\n        attrs.SetAllAttrs(color_space, data_type->data_type());\n        return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);\n      });\n  m.add_functor(\"DispatchImageNormalize\",\n                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,\n                   const std::vector<float>& mean, const std::vector<float>& std) -> Maybe<Tensor> {\n                  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"std\", \"mean\");\n                  attrs.SetAllAttrs(std, mean);\n                  return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);\n                });\n  m.add_functor(\"DispatchCOCOReader\",\n                [](const std::shared_ptr<OpExpr>& op, const std::string& image_dir,\n                   const std::string& annotation_file, int64_t batch_size, bool shuffle_after_epoch,\n                   int64_t random_seed, bool group_by_ratio, bool remove_images_without_annotations,\n                   bool stride_partition, int64_t session_id,\n                   const Optional<Symbol<Device>>& device) -> Maybe<TensorTuple> {\n                  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\n                      \"session_id\", \"annotation_file\", \"image_dir\", \"batch_size\",\n                      \"shuffle_after_epoch\", \"random_seed\", \"group_by_ratio\",\n                      \"remove_images_without_annotations\", \"stride_partition\");\n                  attrs.SetAllAttrs(session_id, annotation_file, image_dir, batch_size,\n                                    shuffle_after_epoch, random_seed, group_by_ratio,\n                                    remove_images_without_annotations, stride_partition);\n                  return OpInterpUtil::Dispatch<TensorTuple>(\n                      *op, {}, OpExprInterpContext(attrs, JUST(device)));\n                });\n  m.add_functor(\"DispatchCOCOReader\",\n                [](const std::shared_ptr<OpExpr>& op, const std::string& image_dir,\n                   const std::string& annotation_file, int64_t batch_size, bool shuffle_after_epoch,\n                   int64_t random_seed, bool group_by_ratio, bool remove_images_without_annotations,\n                   bool stride_partition, int64_t session_id, const Symbol<ParallelDesc>& placement,\n                   const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<TensorTuple> {\n                  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\n                      \"session_id\", \"annotation_file\", \"image_dir\", \"batch_size\",\n                      \"shuffle_after_epoch\", \"random_seed\", \"group_by_ratio\",\n                      \"remove_images_without_annotations\", \"stride_partition\", \"nd_sbp\");\n                  attrs.SetAllAttrs(session_id, annotation_file, image_dir, batch_size,\n                                    shuffle_after_epoch, random_seed, group_by_ratio,\n                                    remove_images_without_annotations, stride_partition,\n                                    *JUST(GetNdSbpStrList(sbp_tuple)));\n                  auto nd_sbp = JUST(GetNdSbp(sbp_tuple));\n                  return 
OpInterpUtil::Dispatch<TensorTuple>(\n                      *op, {}, OpExprInterpContext(attrs, placement, nd_sbp));\n                });\n  m.add_functor(\n      \"DispatchImageBatchAlign\",\n      [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input, int32_t alignment,\n         const Shape& shape, const Symbol<DType>& data_type, bool dynamic_out) -> Maybe<Tensor> {\n        auto& attrs =\n            THREAD_CACHED_MUTABLE_ATTR_MAP(\"shape\", \"data_type\", \"alignment\", \"dynamic_out\");\n        attrs.SetAllAttrs(shape, data_type->data_type(), alignment, dynamic_out);\n        return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);\n      });\n  m.add_functor(\"DispatchOfrecordBytesDecoder\",\n                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,\n                   const std::string& name) -> Maybe<Tensor> {\n                  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"name\");\n                  attrs.SetAllAttrs(name);\n                  return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);\n                });\n  m.add_functor(\n      \"DispatchMegatronGptMmapDataLoader\",\n      [](const std::shared_ptr<OpExpr>& op, const std::string& data_file_prefix, int64_t seq_length,\n         int64_t label_length, int64_t num_samples, int64_t batch_size, const Symbol<DType>& dtype,\n         const std::vector<int64_t>& split_sizes, int64_t split_index, bool shuffle,\n         int64_t random_seed, const Optional<Symbol<Device>>& device) -> Maybe<Tensor> {\n        auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\n            \"data_file_prefix\", \"seq_length\", \"label_length\", \"num_samples\", \"batch_size\", \"dtype\",\n            \"split_sizes\", \"split_index\", \"shuffle\", \"random_seed\");\n        attrs.SetAllAttrs(data_file_prefix, seq_length, label_length, num_samples, batch_size,\n                          dtype->data_type(), split_sizes, split_index, shuffle, random_seed);\n        return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, JUST(device)));\n      });\n  m.add_functor(\n      \"DispatchMegatronGptMmapDataLoader\",\n      [](const std::shared_ptr<OpExpr>& op, const std::string& data_file_prefix, int64_t seq_length,\n         int64_t label_length, int64_t num_samples, int64_t batch_size, const Symbol<DType>& dtype,\n         const std::vector<int64_t>& split_sizes, int64_t split_index, bool shuffle,\n         int64_t random_seed, const Symbol<ParallelDesc>& placement,\n         const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {\n        auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\n            \"data_file_prefix\", \"seq_length\", \"label_length\", \"num_samples\", \"batch_size\", \"dtype\",\n            \"split_sizes\", \"split_index\", \"shuffle\", \"random_seed\");\n        attrs.SetAllAttrs(data_file_prefix, seq_length, label_length, num_samples, batch_size,\n                          dtype->data_type(), split_sizes, split_index, shuffle, random_seed);\n        auto nd_sbp = JUST(GetNdSbp(sbp_tuple));\n        return OpInterpUtil::Dispatch<Tensor>(*op, {},\n                                              OpExprInterpContext(attrs, placement, nd_sbp));\n      });\n  m.add_functor(\"DispatchRmspropUpdate\",\n                [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs,\n                   float learning_rate, double scale, float l1, float l2, bool centered,\n                   float epsilon, float decay_rate, float weight_decay) -> 
Maybe<void> {\n                  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"learning_rate_val\", \"scale\", \"l1\",\n                                                               \"l2\", \"centered\", \"epsilon\",\n                                                               \"decay_rate\", \"weight_decay\");\n                  attrs.SetAllAttrs(learning_rate, scale, l1, l2, centered, epsilon, decay_rate,\n                                    weight_decay);\n                  JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));\n                  return Maybe<void>::Ok();\n                });\n  m.add_functor(\n      \"DispatchAdamUpdate\",\n      [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,\n         float bias_correction1, float bias_correction2, double scale, float l1, float l2,\n         float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad,\n         bool do_bias_correction) -> Maybe<void> {\n        auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\n            \"learning_rate_val\", \"bias_correction1_val\", \"bias_correction2_val\", \"scale\", \"l1\",\n            \"l2\", \"beta1\", \"beta2\", \"epsilon\", \"weight_decay\", \"amsgrad\", \"do_bias_correction\");\n        attrs.SetAllAttrs(learning_rate, bias_correction1, bias_correction2, scale, l1, l2, beta1,\n                          beta2, epsilon, weight_decay, amsgrad, do_bias_correction);\n        JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));\n        return Maybe<void>::Ok();\n      });\n  m.add_functor(\"DispatchAdagradUpdate\",\n                [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs,\n                   float learning_rate, double scale, float l1, float l2, float lr_decay,\n                   float weight_decay, float epsilon, int32_t train_step) -> Maybe<void> {\n                  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"learning_rate_val\", \"scale\", \"l1\",\n                                                               \"l2\", \"lr_decay\", \"weight_decay\",\n                                                               \"epsilon\", \"train_step_val\");\n                  attrs.SetAllAttrs(learning_rate, scale, l1, l2, lr_decay, weight_decay, epsilon,\n                                    train_step);\n                  JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));\n                  return Maybe<void>::Ok();\n                });\n  m.add_functor(\n      \"DispatchMomentumUpdate\",\n      [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,\n         double scale, float l1, float l2, float beta, float dampening, bool nesterov,\n         bool maximize, float weight_decay) -> Maybe<void> {\n        auto& attrs =\n            THREAD_CACHED_MUTABLE_ATTR_MAP(\"learning_rate_val\", \"scale\", \"l1\", \"l2\", \"beta\",\n                                           \"dampening\", \"nesterov\", \"maximize\", \"weight_decay\");\n        attrs.SetAllAttrs(learning_rate, scale, l1, l2, beta, dampening, nesterov, maximize,\n                          weight_decay);\n        JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));\n        return Maybe<void>::Ok();\n      });\n  m.add_functor(\n      \"DispatchSgdUpdate\",\n      [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,\n         double scale, float l1, float l2, float weight_decay) -> Maybe<void> {\n        auto& attrs = 
THREAD_CACHED_MUTABLE_ATTR_MAP(\"learning_rate_val\", \"scale\", \"l1\", \"l2\",\n                                                     \"weight_decay\");\n        attrs.SetAllAttrs(learning_rate, scale, l1, l2, weight_decay);\n        JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));\n        return Maybe<void>::Ok();\n      });\n  m.add_functor(\"DispatchLambUpdate\",\n                [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs,\n                   float learning_rate, float bias_correction1, float bias_correction2,\n                   double scale, float l1, float l2, float beta1, float beta2, float epsilon,\n                   float weight_decay, bool do_bias_correction) -> Maybe<void> {\n                  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\n                      \"learning_rate_val\", \"bias_correction1_val\", \"bias_correction2_val\", \"scale\",\n                      \"l1\", \"l2\", \"beta1\", \"beta2\", \"epsilon\", \"weight_decay\",\n                      \"do_bias_correction\");\n                  attrs.SetAllAttrs(learning_rate, bias_correction1, bias_correction2, scale, l1,\n                                    l2, beta1, beta2, epsilon, weight_decay, do_bias_correction);\n                  JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));\n                  return Maybe<void>::Ok();\n                });\n  m.add_functor(\"DispatchFtrlUpdate\",\n                [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs,\n                   float learning_rate, double scale, float l1, float l2, float lr_power,\n                   float lambda1, float lambda2, float beta, float weight_decay) -> Maybe<void> {\n                  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"learning_rate_val\", \"scale\", \"l1\",\n                                                               \"l2\", \"lr_power\", \"lambda1\",\n                                                               \"lambda2\", \"beta\", \"weight_decay\");\n                  attrs.SetAllAttrs(learning_rate, scale, l1, l2, lr_power, lambda1, lambda2, beta,\n                                    weight_decay);\n                  JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));\n                  return Maybe<void>::Ok();\n                });\n  m.add_functor(\n      \"DispatchAdadeltaUpdate\",\n      [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,\n         double scale, float l1, float l2, float rho, float epsilon, bool maximize,\n         float weight_decay) -> Maybe<void> {\n        auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"learning_rate_val\", \"scale\", \"l1\", \"l2\",\n                                                     \"rho\", \"epsilon\", \"maximize\", \"weight_decay\");\n        attrs.SetAllAttrs(learning_rate, scale, l1, l2, rho, epsilon, maximize, weight_decay);\n        JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));\n        return Maybe<void>::Ok();\n      });\n  m.add_functor(\"DispatchEagerCclAllReduce\",\n                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,\n                   const std::string& parallel_conf) -> Maybe<Tensor> {\n                  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"parallel_conf\");\n                  attrs.SetAllAttrs(parallel_conf);\n                  return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);\n                });\n  m.add_functor(\n      \"DispatchRawReader\",\n      [](const 
std::shared_ptr<OpExpr>& op, const std::vector<std::string>& files,\n         const Shape& shape, const Symbol<DType>& data_type, const int64_t batch_size,\n         const bool random_shuffle, const int64_t shuffle_block_size, int64_t random_seed,\n         const Optional<Symbol<Device>>& device) -> Maybe<Tensor> {\n        auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"files\", \"shape\", \"data_type\", \"batch_size\",\n                                                     \"random_shuffle\", \"shuffle_block_size\", \"seed\",\n                                                     \"nd_sbp\");\n        attrs.SetAllAttrs(files, shape, data_type->data_type(), batch_size, random_shuffle,\n                          shuffle_block_size, random_seed, std::vector<std::string>());\n        return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, JUST(device)));\n      });\n  m.add_functor(\"DispatchRawReader\",\n                [](const std::shared_ptr<OpExpr>& op, const std::vector<std::string>& files,\n                   const Shape& shape, const Symbol<DType>& data_type, const int64_t batch_size,\n                   const bool random_shuffle, const int64_t shuffle_block_size, int64_t random_seed,\n                   const Symbol<ParallelDesc>& placement,\n                   const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {\n                  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\n                      \"files\", \"shape\", \"data_type\", \"batch_size\", \"random_shuffle\",\n                      \"shuffle_block_size\", \"seed\", \"nd_sbp\");\n                  attrs.SetAllAttrs(files, shape, data_type->data_type(), batch_size,\n                                    random_shuffle, shuffle_block_size, random_seed,\n                                    *JUST(GetNdSbpStrList(sbp_tuple)));\n                  auto nd_sbp = JUST(GetNdSbp(sbp_tuple));\n                  return OpInterpUtil::Dispatch<Tensor>(\n                      *op, {}, OpExprInterpContext(attrs, placement, nd_sbp));\n                });\n}\n\n}  // namespace impl\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/functional/dispatch_stateful_ops.yaml",
    "content": "# Copyright 2020 The OneFlow Authors. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# The following data types are allowed,\n# {\n#   \"Tensor\", \"TensorTuple\", \"Scalar\", \"Int\", \"Int32\", \"Int64\", \"Float\", \"Double\", \"String\", \"Bool\",\n#   \"ScalarList\", \"IntList\", \"Int32List\", \"Int64List\", \"FloatList\", \"DoubleList\", \"StringList\",\n#   \"BoolList\", \"DataType\", \"Shape\", \"Generator\", \"TensorIndex\", \"Device\", \"Placement\",\n#   \"Sbp\", \"SbpList\"\n# }\n\n- name: \"dispatch_feed_input\"\n  signature: \"Tensor (OpExpr op, Tensor input) => DispatchFeedInput\"\n  bind_python: True\n\n- name: \"dispatch_feed_variable\"\n  signature: \"Tensor (OpExpr op, Tensor input, Scalar l2) => DispatchFeedVariable\"\n  bind_python: True\n\n- name: \"dispatch_fetch_output\"\n  signature: \"Tensor (OpExpr op, Tensor input) => DispatchFetchOutput\"\n  bind_python: True\n\n- name: \"dispatch_ofrecord_reader\"\n  signature: [\n      \"Tensor (OpExpr op, String data_dir, Int32 data_part_num, String part_name_prefix=\\\"part-\\\", Int32 part_name_suffix_length=-1, Int32 batch_size, Int32 shuffle_buffer_size=1024, Bool random_shuffle=False, Bool shuffle_after_epoch=False, Int64 seed=-1, Device device=None) => DispatchOfrecordReader\",\n      \"Tensor (OpExpr op, String data_dir, Int32 data_part_num, String part_name_prefix=\\\"part-\\\", Int32 part_name_suffix_length=-1, Int32 batch_size, Int32 shuffle_buffer_size=1024, Bool random_shuffle=False, Bool shuffle_after_epoch=False, Int64 seed=-1, Placement placement, SbpList sbp) => DispatchOfrecordReader\",\n  ]\n  bind_python: True\n\n- name: \"dispatch_ofrecord_raw_decoder\"\n  signature: \"Tensor (OpExpr op, Tensor input, String name, Shape shape, DataType data_type, Bool dim1_varying_length=False, Bool truncate=False) => DispatchOfrecordRawDecoder\"\n  bind_python: True\n\n- name: \"dispatch_coin_flip\"\n  signature: [\n      \"Tensor (OpExpr op, Int64 batch_size, Scalar probability=0.5, Int64 seed=-1, Bool has_seed=False, Device device=None) => DispatchCoinFlip\",\n      \"Tensor (OpExpr op, Int64 batch_size, Scalar probability=0.5, Int64 seed=-1, Bool has_seed=False, Placement placement, SbpList sbp) => DispatchCoinFlip\",\n  ]\n  bind_python: True\n\n- name: \"dispatch_distributed_partial_fc_sample\"\n  signature:\n    \"TensorTuple (OpExpr op, Tensor weight, Tensor label, Int64 num_sample) => DispatchDistributedPariticalFCSample\"\n  bind_python: True\n\n- name: \"dispatch_crop_mirror_normalize_from_uint8\"\n  signature: \"Tensor (OpExpr op, TensorTuple input, Int64 crop_h=0, Int64 crop_w=0, Float crop_pos_x=0.5, Float crop_pos_y=0.5, FloatList mean, FloatList std, DataType output_dtype=kFloat, String output_layout=\\\"NCHW\\\", String color_space=\\\"BGR\\\") => DispatchCropMirrorNormalizeFromUint8\"\n  bind_python: True\n\n- name: \"dispatch_crop_mirror_normalize_from_tensorbuffer\"\n  signature: \"Tensor (OpExpr op, TensorTuple input, Int64 crop_h=0, Int64 crop_w=0, Float 
crop_pos_x=0.5, Float crop_pos_y=0.5, FloatList mean, FloatList std, DataType output_dtype=kFloat, String output_layout=\\\"NCHW\\\", String color_space=\\\"BGR\\\") => DispatchCropMirrorNormalizeFromTensorBuffer\"\n  bind_python: True\n\n- name: \"dispatch_ofrecord_image_decoder_random_crop\"\n  signature: \"Tensor (OpExpr op, Tensor input, String name, String color_space=\\\"BGR\\\", FloatList random_area, FloatList random_aspect_ratio, Int32 num_attempts=10, Int64 seed=-1, Bool has_seed=False) => DispatchOfrecordImageDecoderRandomCrop\"\n  bind_python: True\n\n- name: \"dispatch_ofrecord_image_decoder\"\n  signature: \"Tensor (OpExpr op, Tensor input, String name, String color_space=\\\"BGR\\\") => DispatchOfrecordImageDecoder\"\n  bind_python: True\n\n- name: \"dispatch_image_decoder_random_crop_resize\"\n  signature: \"Tensor (OpExpr op, Tensor input, Int64 target_width, Int64 target_height, Int64 seed, Int64 num_workers=3, Int64 max_num_pixels=67108864, Float random_area_min=0.08f, Float random_area_max=1.0f, Float random_aspect_ratio_min=0.75f, Float random_aspect_ratio_max=1.333333f, Int64 warmup_size=6400, Int64 num_attempts=10) => DispatchImageDecoderRandomCropResize\"\n  bind_python: True\n\n- name: \"dispatch_tensor_buffer_to_list_of_tensors_v2\"\n  signature: \"TensorTuple (OpExpr op, Tensor input, ShapeList out_shapes, DataTypeList out_dtypes, Bool dynamic_out) => DispatchTensorBufferToListOfTensorsV2\"\n  bind_python: True\n\n- name: \"dispatch_image_resize_keep_aspect_ratio\"\n  signature: \"TensorTuple (OpExpr op, Tensor input, Int32 target_size, Int32 min_size=0, Int32 max_size=0, Bool resize_longer=False, String interpolation_type=\\\"bilinear\\\") => DispatchImageResizeKeepAspectRatio\"\n  bind_python: True\n\n- name: \"dispatch_image_resize_to_fixed\"\n  signature: \"TensorTuple (OpExpr op, Tensor input, Int64 target_width=0, Int64 target_height=0, Int64 channels=3, DataType data_type=kUInt8, String interpolation_type=\\\"bilinear\\\") => DispatchImageResizeToFixed\"\n  bind_python: True\n\n- name: \"dispatch_image_decode\"\n  signature: \"Tensor (OpExpr op, Tensor input, String color_space=\\\"BGR\\\", DataType data_type=kUInt8) => DispatchImageDecode\"\n  bind_python: True\n\n- name: \"dispatch_image_normalize\"\n  signature: \"Tensor (OpExpr op, Tensor input, FloatList mean, FloatList std) => DispatchImageNormalize\"\n  bind_python: True\n\n- name: \"dispatch_coco_reader\"\n  signature: [\n      \"TensorTuple (OpExpr op, String image_dir, String annotation_file, Int64 batch_size, Bool shuffle_after_epoch=False, Int64 random_seed=-1, Bool group_by_ratio=True, Bool remove_images_without_annotations=True, Bool stride_partition=False, Int64 session_id, Device device=None) => DispatchCOCOReader\",\n      \"TensorTuple (OpExpr op, String image_dir, String annotation_file, Int64 batch_size, Bool shuffle_after_epoch=False, Int64 random_seed=-1, Bool group_by_ratio=True, Bool remove_images_without_annotations=True, Bool stride_partition=False, Int64 session_id, Placement placement, SbpList sbp) => DispatchCOCOReader\",\n  ]\n  bind_python: True\n\n- name: \"dispatch_image_batch_align\"\n  signature: \"Tensor (OpExpr op, Tensor input, Int32 alignment, Shape shape, DataType data_type, Bool dynamic_out) => DispatchImageBatchAlign\"\n  bind_python: True\n\n- name: \"dispatch_ofrecord_bytes_decoder\"\n  signature: \"Tensor (OpExpr op, Tensor input, String name) => DispatchOfrecordBytesDecoder\"\n  bind_python: True\n\n- name: \"dispatch_megatron_gpt_mmap_data_loader\"\n  
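# Like the reader ops above, this loader provides two overloads: the Device signature\n  # creates the op on a local device, while the Placement + SbpList signature creates a\n  # global op (a brief note added here; the pattern is visible throughout this file).\n  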
signature: [\n      \"Tensor (OpExpr op, String data_file_prefix, Int64 seq_length, Int64 label_length=1, Int64 num_samples, Int64 batch_size, DataType dtype, Int64List split_sizes, Int64 split_index, Bool shuffle, Int64 random_seed, Device device=None) => DispatchMegatronGptMmapDataLoader\",\n      \"Tensor (OpExpr op, String data_file_prefix, Int64 seq_length, Int64 label_length=1, Int64 num_samples, Int64 batch_size, DataType dtype, Int64List split_sizes, Int64 split_index, Bool shuffle, Int64 random_seed, Placement placement, SbpList sbp) => DispatchMegatronGptMmapDataLoader\",\n  ]\n  bind_python: True\n\n- name: \"dispatch_rmsprop_update\"\n  signature: \"Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Bool centered=False, Float epsilon=1e-8, Float decay_rate=0.99, Float weight_decay=0.0) => DispatchRmspropUpdate\"\n  bind_python: True\n\n- name: \"dispatch_adam_update\"\n  signature: \"Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Float bias_correction1=1.0, Float bias_correction2=1.0, Double scale=1.0, Float l1=0, Float l2=0, Float beta1=0.9, Float beta2=0.999, Float epsilon=1e-8, Float weight_decay=0, Bool amsgrad=False, Bool do_bias_correction=True) => DispatchAdamUpdate\"\n  bind_python: True\n\n- name: \"dispatch_adagrad_update\"\n  signature: \"Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float lr_decay=0, Float weight_decay=0, Float epsilon=1e-10, Int32 train_step_val=0) => DispatchAdagradUpdate\"\n  bind_python: True\n\n- name: \"dispatch_momentum_update\"\n  signature: \"Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float beta=0.9, Float dampening=0.0, Bool nesterov=False, Bool maximize=False, Float weight_decay=0) => DispatchMomentumUpdate\"\n  bind_python: True\n\n- name: \"dispatch_sgd_update\"\n  signature: \"Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float weight_decay=0) => DispatchSgdUpdate\"\n  bind_python: True\n\n- name: \"dispatch_lamb_update\"\n  signature: \"Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Float bias_correction1=1.0, Float bias_correction2=1.0, Double scale=1.0, Float l1=0, Float l2=0, Float beta1=0.9, Float beta2=0.999, Float epsilon=1e-8, Float weight_decay=0, Bool do_bias_correction=True) => DispatchLambUpdate\"\n  bind_python: True\n\n- name: \"dispatch_ftrl_update\"\n  signature: \"Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float lr_power, Float lambda1, Float lambda2, Float beta, Float weight_decay=0) => DispatchFtrlUpdate\"\n  bind_python: True\n\n- name: \"dispatch_adadelta_update\"\n  signature: \"Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float rho, Float epsilon, Bool maximize, Float weight_decay=0) => DispatchAdadeltaUpdate\"\n  bind_python: True\n\n- name: \"dispatch_eager_ccl_all_reduce\"\n  signature: \"Tensor (OpExpr op, Tensor input, String parallel_conf) => DispatchEagerCclAllReduce\"\n  bind_python: True\n\n- name: \"dispatch_raw_reader\"\n  signature: [\n      \"Tensor (OpExpr op, StringList files, Shape shape, DataType data_type, Int64 batch_size, Bool random_shuffle, Int64 shuffle_block_size, Int64 random_seed=-1, Device device=None) => DispatchRawReader\",\n      \"Tensor (OpExpr op, StringList files, Shape shape, DataType data_type, Int64 batch_size, Bool random_shuffle, Int64 shuffle_block_size, Int64 random_seed=-1, Placement placement, SbpList sbp) => DispatchRawReader\",\n  ]\n  bind_python: True\n"
  },
  {
    "path": "oneflow/api/python/functional/function_def.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_FUNCTION_DEF_H_\n#define ONEFLOW_API_PYTHON_FUNCTIONAL_FUNCTION_DEF_H_\n\n#include <memory>\n#include <string>\n#include <vector>\n\n#include \"oneflow/api/python/functional/python_arg.h\"\n#include \"oneflow/api/python/functional/value_types.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nstruct ReturnDef {\n  explicit ReturnDef(const ValueType& t) : type(t) {}\n  ValueType type;\n};\n\nstruct ArgumentDef {\n  ArgumentDef(const std::string& arg_name, const ValueType& arg_type, int arg_size,\n              bool arg_keyword_only, bool arg_optional)\n      : name(arg_name),\n        type(arg_type),\n        size(arg_size),\n        keyword_only(arg_keyword_only),\n        optional(arg_optional),\n        has_default_value(false) {}\n\n  template<typename T>\n  ArgumentDef(const std::string& arg_name, const T& arg_val, int arg_size, bool arg_keyword_only,\n              bool arg_optional)\n      : name(arg_name),\n        type(ValueTypeOf<T>()),\n        size(arg_size),\n        keyword_only(arg_keyword_only),\n        optional(arg_optional),\n        has_default_value(true) {\n    default_value = std::make_shared<detail::TypedDefaultVal<T>>(arg_val);\n  }\n\n  std::string name;\n  ValueType type;\n\n  int size;\n  bool keyword_only;\n  bool optional;\n  bool has_default_value;\n  std::shared_ptr<const detail::DefaultVal> default_value;\n};\n\nstruct FunctionDef {\n  std::string name;\n  ReturnDef return_def;\n  std::vector<ArgumentDef> argument_def;\n};\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_FUNCTIONAL_FUNCTION_DEF_H_\n"
  },
  {
    "path": "oneflow/api/python/functional/indexing.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/python/functional/indexing.h\"\n\n#include <object.h>\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/extension/python/numpy.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/api/python/functional/tensor_api.yaml.h\"\n#include \"oneflow/core/common/foreign_lock_helper.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace detail {\n\nvoid PySliceUnpack(PyObject* object, Py_ssize_t* start, Py_ssize_t* stop, Py_ssize_t* step) {\n  PySliceObject* obj = (PySliceObject*)object;\n  if (obj->step == Py_None) {\n    *step = 1;\n  } else {\n    CHECK_OR_THROW(_PyEval_SliceIndex(obj->step, step))\n        << \"Invalid slice \" << PyObjectToReprStr(object);\n    CHECK_NE_OR_THROW(*step, 0) << \"slice step cannot be zero.\";\n    if (*step < -PY_SSIZE_T_MAX) *step = -PY_SSIZE_T_MAX;\n  }\n  if (obj->start == Py_None) {\n    *start = *step < 0 ? PY_SSIZE_T_MAX : 0;\n  } else {\n    CHECK_OR_THROW(_PyEval_SliceIndex(obj->start, start))\n        << \"Invalid slice \" << PyObjectToReprStr(object);\n  }\n  if (obj->stop == Py_None) {\n    *stop = *step < 0 ? 
/* for a negative step, an omitted stop bound means iterate past the first element (e.g. t[::-1]) */ PY_SSIZE_T_MIN : PY_SSIZE_T_MAX;\n  } else {\n    CHECK_OR_THROW(_PyEval_SliceIndex(obj->stop, stop))\n        << \"Invalid slice \" << PyObjectToReprStr(object);\n  }\n}\n\nDataType InferScalarType(PyObject* object) {\n  if (PyBool_Check(object)) {\n    return DataType::kBool;\n  } else if (PyLong_Check(object)) {\n    return DataType::kInt64;\n  } else if (PyArray_Check(object)) {\n    return numpy::GetOFDataTypeFromNpArray(reinterpret_cast<PyArrayObject*>(object)).GetOrThrow();\n  } else if (PyArray_CheckScalar(object)) {\n    return numpy::NumpyTypeToOFDataType(PyArray_DescrFromScalar(object)->type_num).GetOrThrow();\n  } else if (PySequence_Check(object)) {\n    int64_t length = PySequence_Length(object);\n    if (length == 0) { return DataType::kInt64; }\n    DataType scalar_type = DataType::kInvalidDataType;\n    for (int64_t i = 0; i < length; ++i) {\n      PyObjectPtr item(PySequence_GetItem(object, i));\n      const auto& item_scalar_type = InferScalarType(item.get());\n      if (scalar_type != DataType::kInvalidDataType) {\n        CHECK_EQ_OR_THROW(scalar_type, item_scalar_type)\n            << \"Different scalar types are not allowed.\";\n      } else {\n        scalar_type = item_scalar_type;\n      }\n    }\n    return scalar_type;\n  }\n  THROW(TypeError) << \"Can't infer scalar type of \" << Py_TYPE(object)->tp_name;\n  return DataType::kInvalidDataType;\n}\n\nvoid ParseScalar(PyObject* object, char* data, const DataType& dtype) {\n  if (dtype == DataType::kInt64) {\n    CHECK_OR_THROW(PyLong_Check(object) || numpy::PyArrayCheckLongScalar(object))\n        << \"Expected a long value.\";\n    *(reinterpret_cast<int64_t*>(data)) = PyLong_AsLongLong(object);\n  } else if (dtype == DataType::kInt32) {\n    CHECK_OR_THROW(PyLong_Check(object) || numpy::PyArrayCheckLongScalar(object))\n        << \"Expected a long value.\";\n    *(reinterpret_cast<int32_t*>(data)) = PyLong_AsLongLong(object);\n  } else if (dtype == DataType::kUInt8 || dtype == DataType::kBool) {\n    CHECK_OR_THROW(PyBool_Check(object) || PyLong_Check(object)\n                   || numpy::PyArrayCheckLongScalar(object))\n        << \"Expected a boolean or long value.\";\n    if (PyBool_Check(object) || numpy::PyArrayCheckBoolScalar(object)) {\n      *(reinterpret_cast<bool*>(data)) = (object == Py_True);\n    } else {\n      int64_t value = PyLong_AsLongLong(object);\n      CHECK_OR_THROW(value >= 0 && value <= 255) << \"Out of range 0-255.\";\n      *(reinterpret_cast<uint8_t*>(data)) = static_cast<uint8_t>(value);\n    }\n  } else {\n    THROW(TypeError) << \"Can't parse scalar with data type \" << dtype;\n  }\n}\n\nvoid RecursiveParseAndAssign(PyObject* object, char* data, const int& ndims, const int& dim,\n                             const ShapeView& shape, const DimVector& strides,\n                             const DataType& dtype) {\n  if (dim == ndims) { return ParseScalar(object, data, dtype); }\n  auto seq = PyObjectPtr(PySequence_Fast(object, \"Expected a sequence.\"));\n  int64_t size = PySequence_Fast_GET_SIZE(seq.get());\n  CHECK_EQ_OR_THROW(size, shape.At(dim)) << \"Sequence size is \" << size << \" at dimension \" << dim\n                                         << \", but expected \" << shape.At(dim);\n  for (int64_t i = 0; i < size; ++i) {\n    PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);\n    RecursiveParseAndAssign(item, data, ndims, dim + 1, shape, strides, dtype);\n    data += strides.at(dim) * GetSizeOfDataType(dtype);\n  }\n}\n\n// Writes the values of a (possibly nested) Python sequence into the eager blob's buffer.\nvoid ParseArrayToTensor(PyObject* object,\n    
                    const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n  const DataType dtype = eager_blob_object->data_type();\n  const int ndims = eager_blob_object->shape().NumAxes();\n  DimVector strides(ndims);\n  int64_t size = 1;\n  for (int i = ndims - 1; i >= 0; --i) {\n    strides[i] = size;\n    size *= eager_blob_object->shape().At(i);\n  }\n  RecursiveParseAndAssign(object, eager_blob_object->mut_dptr<char>(), ndims, 0,\n                          eager_blob_object->shape(), strides, dtype);\n}\n\nShape InferArraySizes(PyObject* object) {\n  DimVector sizes;\n  PyObject* seq = object;\n  PyObjectPtr handle;\n  while (PySequence_Check(seq)) {\n    int64_t length = PySequence_Length(seq);\n    sizes.emplace_back(length);\n    CHECK_LE_OR_THROW(sizes.size(), /*MAX_DIMS=*/128)\n        << \"Too many dimensions \" << Py_TYPE(seq)->tp_name;\n    if (length == 0) break;\n    handle = PyObjectPtr(PySequence_GetItem(seq, 0));\n    seq = handle.get();\n  }\n  return Shape(sizes);\n}\n\nMaybe<Tensor> ConvertToIndexingTensor(PyObject* object) {\n  // NOTE: converting data to an indexing tensor must be done in eager mode\n  LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);\n  const DataType dtype = InferScalarType(object);\n  const auto& device = JUST(Device::New(\"cpu\"));\n\n  // index types must be integer or boolean\n  if (!(IsIntegralDataType(dtype) || (IsBoolDataType(dtype)))) {\n    return Error::IndexError() << \"only integers, slices (`:`), ellipsis (`...`), numpy.newaxis \"\n                                  \"(`None`) and integer or boolean arrays are valid indices\";\n  }\n  // For advanced indexing, the index can be a numpy array object, which needs special handling.\n  if (PyArray_Check(object)) {\n    return TensorWithData(object, NullOpt, device, /*requires_grad=*/false, /*pin_memory=*/false);\n  }\n\n  const auto& sizes = InferArraySizes(object);\n  const auto& tensor = JUST(functional::Empty(sizes, CHECK_JUST(DType::Get(dtype)), device,\n                                              /*requires_grad=*/false, /*pin_memory=*/false));\n  // Prevent the Python object from being released until the callback completes.\n  Py_INCREF(object);\n  auto handle = std::shared_ptr<PyObject>(PyObjectPtr(object));\n\n  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    return builder->AccessBlobByCallback(\n        JUST(tensor->AsLocalTensor()),\n        [handle](ep::Stream* stream,\n                 const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n          CHECK_JUST(Singleton<ForeignLockHelper>::Get()->WithScopedAcquire([&]() -> Maybe<void> {\n            ParseArrayToTensor(handle.get(), eager_blob_object);\n            return Maybe<void>::Ok();\n          }));\n        },\n        \"mut\");\n  }));\n  return tensor;\n}\n\nIndexItem UnpackIndexItem(PyObject* object) {\n  if (object == Py_Ellipsis) {\n    return IndexItem(EllipsisIndex{});\n  } else if (PySlice_Check(object)) {\n    Py_ssize_t start, end, step;\n    PySliceUnpack(object, &start, &end, &step);\n    return IndexItem(start, end, step);\n  } else if (PyLong_Check(object) && object != Py_False && object != Py_True) {\n    return IndexItem(static_cast<int64_t>(PyLong_AsLongLong(object)));\n  } else if (numpy::PyArrayCheckLongScalar(object)) {\n    return IndexItem(static_cast<int64_t>(PyLong_AsLongLong(object)));\n  } else if (object == Py_False || object == Py_True) {\n    return IndexItem(object == Py_True);\n  } else if (object == Py_None) {\n    return IndexItem(NoneIndex{});\n  } else if 
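/* a tensor object is used directly as an (advanced) index */ 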
(PyTensor_Check(object)) {\n    return IndexItem(PyTensor_Unpack(object));\n  } else if (PySequence_Check(object)) {\n    return IndexItem(ConvertToIndexingTensor(object).GetPtrOrThrow());\n  }\n  THROW(IndexError) << \"Invalid index \" << Py_TYPE(object)->tp_name;\n  return IndexItem();\n}\n\n}  // namespace detail\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/functional/indexing.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_INDEXING_H_\n#define ONEFLOW_API_PYTHON_FUNCTIONAL_INDEXING_H_\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/functional/tensor_index.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace detail {\n\nvoid PySliceUnpack(PyObject* object, Py_ssize_t* start, Py_ssize_t* stop, Py_ssize_t* step);\n\nMaybe<Tensor> ConvertToIndexingTensor(PyObject* object);\n\nIndexItem UnpackIndexItem(PyObject* object);\n\n}  // namespace detail\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_FUNCTIONAL_INDEXING_H_\n"
  },
  {
    "path": "oneflow/api/python/functional/python_arg.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/python/functional/python_arg.h\"\n\n#include \"oneflow/api/python/framework/tensor.h\"\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/api/python/functional/indexing.h\"\n#include \"oneflow/api/python/framework/memory_format.h\"\n#include \"oneflow/extension/python/numpy.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/layout.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#include \"oneflow/core/functional/tensor_index.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\n#define INSTANCE_OBJECT_AS_INTEGER(T)                                                            \\\n  template<>                                                                                     \\\n  T PythonArg::ObjectAs<T>() const {                                                             \\\n    if (PyIntegerScalarTensorCheck(object_)) {                                                   \\\n      return static_cast<T>(PyUnpackIntegerScalarTensor_AsLongLong(object_));                    \\\n    }                                                                                            \\\n    return static_cast<T>(PyLong_AsLongLong(object_));                                           \\\n  }                                                                                              \\\n  template<>                                                                                     \\\n  std::vector<T> PythonArg::ObjectAs<std::vector<T>>() const {                                   \\\n    if (size_ > 0 && PyLong_Check(object_)) {                                                    \\\n      return std::vector<T>(size_, static_cast<T>(PyLong_AsLongLong(object_)));                  \\\n    }                                                                                            \\\n    return PyUnpackLongSequence<T>(object_);                                                     \\\n  }                                                                                              \\\n  template<>                                                                                     \\\n  std::shared_ptr<std::vector<T>> PythonArg::ObjectAs<std::shared_ptr<std::vector<T>>>() const { \\\n    return std::make_shared<std::vector<T>>(ObjectAs<std::vector<T>>());                         \\\n  }\n\nOF_PP_FOR_EACH_TUPLE(INSTANCE_OBJECT_AS_INTEGER, INTEGER_AND_BOOL_TYPE_SEQ)\n#undef INSTANCE_OBJECT_AS_INTEGER\n\n#define INSTANCE_OBJECT_AS_FLOAT(T)                                                              \\\n  template<>                           
                                                          \\\n  T PythonArg::ObjectAs<T>() const {                                                             \\\n    if (PyFloatScalarTensorCheck(object_)) {                                                     \\\n      return static_cast<T>(PyUnpackFloatScalarTensor_AsDouble(object_));                        \\\n    }                                                                                            \\\n    return static_cast<T>(PyFloat_AsDouble(object_));                                            \\\n  }                                                                                              \\\n  template<>                                                                                     \\\n  std::vector<T> PythonArg::ObjectAs<std::vector<T>>() const {                                   \\\n    if (size_ > 0 && PyFloat_Check(object_)) {                                                   \\\n      return std::vector<T>(size_, static_cast<T>(PyFloat_AsDouble(object_)));                   \\\n    }                                                                                            \\\n    return PyUnpackFloatSequence<T>(object_);                                                    \\\n  }                                                                                              \\\n  template<>                                                                                     \\\n  std::shared_ptr<std::vector<T>> PythonArg::ObjectAs<std::shared_ptr<std::vector<T>>>() const { \\\n    return std::make_shared<std::vector<T>>(ObjectAs<std::vector<T>>());                         \\\n  }\n\nOF_PP_FOR_EACH_TUPLE(INSTANCE_OBJECT_AS_FLOAT, FLOATING_TYPE_SEQ)\n#undef INSTANCE_OBJECT_AS_FLOAT\n\n#define INSTANCE_OBJECT_AS_SHARED_PTR(T)                               \\\n  template<>                                                           \\\n  std::shared_ptr<T> PythonArg::ObjectAs<std::shared_ptr<T>>() const { \\\n    return std::make_shared<T>(ObjectAs<T>());                         \\\n  }\n\ntemplate<>\nstd::string PythonArg::ObjectAs<std::string>() const {\n  return PyStringAsString(object_);\n}\nINSTANCE_OBJECT_AS_SHARED_PTR(std::string)\n\ntemplate<>\nScalar PythonArg::ObjectAs<Scalar>() const {\n  if (PyScalarTensorCheck(object_)) { return PyUnpackScalarTensor(object_); }\n  return PyUnpackScalar(object_);\n}\nINSTANCE_OBJECT_AS_SHARED_PTR(Scalar)\n\ntemplate<>\nstd::shared_ptr<one::Tensor> PythonArg::ObjectAs<std::shared_ptr<one::Tensor>>() const {\n  return PyTensor_Unpack(object_);\n}\n\ntemplate<>\none::TensorTuple PythonArg::ObjectAs<one::TensorTuple>() const {\n  if (PyTensorTupleCheck(object_)) { return *PyUnpackTensorTuple(object_); }\n  const auto& v = PyUnpackTensorSequence(object_);\n  one::TensorTuple values(v.size());\n  for (int i = 0; i < v.size(); ++i) { values[i] = v.at(i); }\n  return values;\n}\nINSTANCE_OBJECT_AS_SHARED_PTR(one::TensorTuple)\n\ntemplate<>\nSymbol<DType> PythonArg::ObjectAs<Symbol<DType>>() const {\n  return PyUnpackDType(object_);\n}\n\ntemplate<>\nSymbol<Layout> PythonArg::ObjectAs<Symbol<Layout>>() const {\n  return PyUnpackLayout(object_);\n}\n\ntemplate<>\nSymbol<MemoryFormat> PythonArg::ObjectAs<Symbol<MemoryFormat>>() const {\n  return PyUnpackMemoryFormat(object_);\n}\n\ntemplate<>\nstd::vector<Symbol<DType>> PythonArg::ObjectAs<std::vector<Symbol<DType>>>() const {\n  return 
PyUnpackDTypeSequence(object_);\n}\nINSTANCE_OBJECT_AS_SHARED_PTR(std::vector<Symbol<DType>>)\n\ntemplate<>\nShape PythonArg::ObjectAs<Shape>() const {\n  return PyUnpackShape(object_);\n}\nINSTANCE_OBJECT_AS_SHARED_PTR(Shape)\n\ntemplate<>\nstd::vector<Shape> PythonArg::ObjectAs<std::vector<Shape>>() const {\n  return PyUnpackShapeSequence(object_);\n}\nINSTANCE_OBJECT_AS_SHARED_PTR(std::vector<Shape>)\n\ntemplate<>\nstd::shared_ptr<one::Generator> PythonArg::ObjectAs<std::shared_ptr<one::Generator>>() const {\n  return PyUnpackGenerator(object_);\n}\n\ntemplate<>\nSymbol<Device> PythonArg::ObjectAs<Symbol<Device>>() const {\n  if (PyStringCheck(object_)) {\n    std::string device_str = PyStringAsString(object_);\n    return Device::ParseAndNew(device_str).GetOrThrow();\n  }\n  return PyUnpackDevice(object_);\n}\n\ntemplate<>\nSymbol<ParallelDesc> PythonArg::ObjectAs<Symbol<ParallelDesc>>() const {\n  return PyUnpackParallelDesc(object_);\n}\n\ntemplate<>\nSymbol<SbpParallel> PythonArg::ObjectAs<Symbol<SbpParallel>>() const {\n  return PyUnpackSbpParallel(object_);\n}\n\ntemplate<>\nstd::vector<Symbol<SbpParallel>> PythonArg::ObjectAs<std::vector<Symbol<SbpParallel>>>() const {\n  if (PySbpParallelCheck(object_)) {\n    return std::vector<Symbol<SbpParallel>>(1, PyUnpackSbpParallel(object_));\n  }\n  return PyUnpackSbpParallelSequence(object_);\n}\nINSTANCE_OBJECT_AS_SHARED_PTR(std::vector<Symbol<SbpParallel>>)\n\ntemplate<>\nTensorIndex PythonArg::ObjectAs<TensorIndex>() const {\n  return PyUnpackTensorIndex(object_);\n}\nINSTANCE_OBJECT_AS_SHARED_PTR(TensorIndex)\n\ntemplate<>\nstd::shared_ptr<one::OpExpr> PythonArg::ObjectAs<std::shared_ptr<one::OpExpr>>() const {\n  return PyUnpackOpExpr(object_);\n}\n\ntemplate<>\nPyObject* PythonArg::ObjectAs<PyObject*>() const {\n  return object_;\n}\n\ntemplate<>\nstd::vector<std::string> PythonArg::ObjectAs<std::vector<std::string>>() const {\n  return PyUnpackSequence<std::string>(\n      object_, [](PyObject* item) -> std::string { return PyStringAsString(item); });\n}\n\nINSTANCE_OBJECT_AS_SHARED_PTR(std::vector<std::string>)\n\ntemplate<>\nMemoryFormat PythonArg::ObjectAs<MemoryFormat>() const {\n  return PyMemoryFormat_Unpack(object_);\n}\n\n#undef INSTANCE_OBJECT_AS_SHARED_PTR\n\nbool PythonArg::TypeCheck(ValueType type) const {\n  if (tag_ == HAS_DEFAULT) { return default_val_->value_type() == type; }\n  switch (type) {\n    case kINT32:\n    case kINT16:\n    case kCHAR:\n    case kUINT32:\n    case kINT64:\n    case kUINT64:\n    case kBOOL:\n      return PyLong_Check(object_) || numpy::PyArrayCheckLongScalar(object_)\n             || PyIntegerScalarTensorCheck(object_) || PyBoolScalarTensorCheck(object_);\n    case kINT32_LIST:\n    case kUINT32_LIST:\n    case kINT64_LIST:\n    case kUINT64_LIST:\n    case kBOOL_LIST: return PyLongSequenceCheck(object_) || (size_ > 0 && PyLong_Check(object_));\n    case kFLOAT:\n    case kDOUBLE:\n      return PyFloat_Check(object_) || PyLong_Check(object_)\n             || numpy::PyArrayCheckFloatScalar(object_) || numpy::PyArrayCheckLongScalar(object_)\n             || PyFloatScalarTensorCheck(object_) || PyIntegerScalarTensorCheck(object_);\n    case kFLOAT_LIST:\n    case kDOUBLE_LIST:\n      return PyFloatSequenceCheck(object_)\n             || (size_ > 0 && (PyFloat_Check(object_) || PyLong_Check(object_)));\n    case kSTRING: return PyStringCheck(object_);\n    case kSTRING_LIST: return PyStringSequenceCheck(object_);\n    case kSCALAR:\n      return PyScalarCheck(object_) || 
numpy::PyArrayCheckLongScalar(object_)\n             || numpy::PyArrayCheckFloatScalar(object_) || PyScalarTensorCheck(object_);\n    case kTENSOR:\n    case kTENSOR_REF: return PyTensor_Check(object_);\n    case kTENSOR_TUPLE: return PyTensorTupleCheck(object_) || PyTensorSequenceCheck(object_);\n    case kDTYPE: return PyDTypeCheck(object_);\n    case kLAYOUT: return PyLayoutCheck(object_);\n    case kMEMORY_FORMAT: return PyMemoryFormat_Check(object_);\n    case kSHAPE: return PyLongSequenceCheck(object_);\n    case kGENERATOR:\n    case kGENERATOR_REF: return PyGeneratorCheck(object_);\n    case kTENSOR_INDEX: return PyTensorIndexCheck(object_);\n    case kDEVICE: return PyStringCheck(object_) || PyDeviceCheck(object_);\n    case kPARALLEL_DESC: return PyParallelDescCheck(object_);\n    case kSBP_PARALLEL: return PySbpParallelCheck(object_);\n    case kSBP_PARALLEL_LIST:\n      return PySbpParallelSequenceCheck(object_) || PySbpParallelCheck(object_);\n    case kOPEXPR_REF: return PyOpExprCheck(object_);\n    case kPY_OBJECT: return nullptr != object_;\n    case kDTYPE_LIST: return PyDTypeSequenceCheck(object_);\n    case kSHAPE_LIST: return PyShapeSequenceCheck(object_);\n    case kCOMPLEX_FLOAT:\n    case kCOMPLEX_DOUBLE:\n      return PyComplex_Check(object_) || PyFloat_Check(object_) || PyLong_Check(object_)\n             || numpy::PyArrayCheckComplexScalar(object_) || numpy::PyArrayCheckFloatScalar(object_)\n             || numpy::PyArrayCheckLongScalar(object_) || PyComplexScalarTensorCheck(object_)\n             || PyFloatScalarTensorCheck(object_) || PyIntegerScalarTensorCheck(object_);\n    default: {\n      THROW(RuntimeError) << \"Can not check type \" << ValueTypeName(type);\n    }\n  }\n  return false;\n}\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/functional/python_arg.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_H_\n#define ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_H_\n\n#include <pybind11/pybind11.h>\n#include <Python.h>\n#undef _PyGC_FINALIZED\n\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/api/python/functional/value_types.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace detail {\n\nstruct DefaultVal {\n  virtual ValueType value_type() const = 0;\n  virtual const void* Ptr() const = 0;\n};\n\ntemplate<typename T>\nstruct TypedDefaultVal final : public DefaultVal {\n  T content;\n  explicit TypedDefaultVal(const T& v) : content(v) {}\n\n  ValueType value_type() const override { return ValueTypeOf<T>(); }\n  const void* Ptr() const override { return &content; }\n};\n\ntemplate<typename T>\nstruct optional_traits {\n  using type = void;\n};\n\ntemplate<typename T>\nstruct optional_traits<Optional<T>> {\n  using type =\n      decltype(std::declval<Optional<T>>().Data_YouAreNotAllowedToCallThisFuncOutsideThisFile());\n};\n\n}  // namespace detail\n\nclass PythonArg {\n public:\n  PythonArg() = default;\n\n  PythonArg(PyObject* object, int size = 0)\n      : object_(object), default_val_(), size_(size), tag_(HAS_OBJECT) {}\n\n  PythonArg(const detail::DefaultVal* value, int size = 0)\n      : object_(nullptr), default_val_(value), size_(size), tag_(HAS_DEFAULT) {}\n\n  template<typename T, typename std::enable_if<!internal::IsOptional<T>::value, int>::type = 0>\n  T As() const {\n    if (tag_ == HAS_DEFAULT) {\n      CHECK_EQ_OR_THROW(ValueTypeOf<T>(), default_val_->value_type())\n          << \"Could not convert default value from type \" << default_val_->value_type()\n          << \" to type \" << ValueTypeOf<T>();\n      return *reinterpret_cast<const T*>(default_val_->Ptr());\n    }\n    CHECK_EQ_OR_THROW(tag_, HAS_OBJECT);\n    return ObjectAs<oneflow::detail::remove_cvref_t<T>>();\n  }\n\n  template<typename T, typename std::enable_if<internal::IsOptional<T>::value, int>::type = 0>\n  T As() const {\n    if (tag_ == HAS_DEFAULT) {\n      CHECK_EQ_OR_THROW(ValueTypeOf<T>(), default_val_->value_type())\n          << \"Could not convert default value from type \" << default_val_->value_type()\n          << \" to type \" << ValueTypeOf<T>();\n      return *reinterpret_cast<const T*>(default_val_->Ptr());\n    }\n    CHECK_EQ_OR_THROW(tag_, HAS_OBJECT);\n    if (object_ == Py_None) { return T(); }\n    return ObjectAs<typename detail::optional_traits<T>::type>();\n  }\n\n  bool TypeCheck(ValueType type) const;\n\n private:\n  template<typename T>\n  T ObjectAs() const;\n\n  PyObject* object_;\n  const detail::DefaultVal* default_val_;\n  size_t size_;\n  enum { HAS_OBJECT, HAS_DEFAULT, HAS_NONE } tag_;\n};\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // 
ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_H_\n"
  },
  {
    "path": "oneflow/api/python/functional/python_arg_parser.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/python/functional/python_arg_parser.h\"\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/api/python/functional/python_arg.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nvoid FunctionSchema::ReportKwargsError(PyObject* kwargs, size_t nargs) const {\n  PyObject *key = nullptr, *value = nullptr;\n  Py_ssize_t pos = 0;\n\n  while (PyDict_Next(kwargs, &pos, &key, &value)) {\n    if (!PyStringCheck(key)) { THROW(TypeError) << def_->name << \"(): keywords must be strings\"; }\n    int64_t index = -1;\n    const std::string string_key = PyStringAsString(key);\n    for (int i = 0; i < def_->argument_def.size(); ++i) {\n      const auto& arg = def_->argument_def[i];\n      if (arg.name == string_key) {\n        index = i;\n        break;\n      }\n    }\n    if (index < 0) {\n      THROW(TypeError) << def_->name << \"(): got an unexpected keyword argument '\" << string_key\n                       << \"'\";\n    }\n    if (index < nargs) {\n      THROW(TypeError) << def_->name << \"(): got multiple values for argument '\" << string_key\n                       << \"'\";\n    }\n  }\n  THROW(TypeError) << def_->name << \"(): kwargs unknown error\";\n}\n\n// The argument parsing refers to the implementation of Pytorch.\nbool FunctionSchema::Parse(PyObject* args, PyObject* kwargs, PythonArg* parsed_args,\n                           bool raise_exception) const {\n  bool treat_args_as_list = false;\n  size_t nargs = args ? PyTuple_Size(args) : 0;\n  size_t remaining_kwargs = kwargs ? 
PyDict_Size(kwargs) : 0;\n\n  if (max_pos_nargs_ == 1) {\n    const auto& type = def_->argument_def[0].type;\n    treat_args_as_list = IsIntegralListType(type) || type == kSHAPE || type == kTENSOR_TUPLE;\n  }\n  if (nargs > max_pos_nargs_ && !treat_args_as_list) {\n    if (raise_exception) {\n      THROW(TypeError) << def_->name << \"(): takes \" << max_pos_nargs_\n                       << \" positional arguments but \" << nargs << \" were given\";\n    }\n    return false;\n  }\n  int arg_pos = 0;\n  for (int i = 0; i < def_->argument_def.size(); ++i) {\n    const auto& param = def_->argument_def[i];\n    PyObject* obj = NULL;\n    if (args && arg_pos < nargs) {\n      if (param.keyword_only) {\n        if (raise_exception) {\n          THROW(TypeError) << def_->name << \"(): argument '\" << param.name << \"' is keyword only\";\n        }\n        return false;\n      }\n      obj = PyTuple_GET_ITEM(args, arg_pos);\n    } else if (kwargs) {\n      obj = PyDict_GetItemString(kwargs, param.name.c_str());\n      if (obj) { --remaining_kwargs; }\n    }\n\n    if (obj) {\n      if (arg_pos == 0 && treat_args_as_list && !param.keyword_only\n          && (PyLong_Check(obj) || PyTensor_Check(obj))) {\n        obj = args;\n        arg_pos = nargs;\n      } else {\n        ++arg_pos;\n      }\n      PythonArg arg(obj, param.size);\n      if ((obj == Py_None && param.optional) || arg.TypeCheck(param.type)) {\n        parsed_args[i] = arg;\n      } else {\n        if (raise_exception) {\n          THROW(TypeError) << def_->name << \"(): argument '\" << param.name << \"' must be \"\n                           << ValueTypeName(param.type) << \", not \"\n                           << PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(obj)));\n        }\n        return false;\n      }\n    } else {\n      if (!param.has_default_value) {\n        if (raise_exception) {\n          THROW(TypeError) << def_->name << \"(): missing required argument \" << param.name;\n        }\n        return false;\n      }\n      parsed_args[i] = param.default_value.get();\n    }\n  }\n  if (remaining_kwargs > 0) {\n    if (raise_exception) { ReportKwargsError(kwargs, nargs); }\n    return false;\n  }\n  return true;\n}\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/functional/python_arg_parser.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_PARSER_H_\n#define ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_PARSER_H_\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n\n#include \"oneflow/api/python/functional/function_def.h\"\n#include \"oneflow/api/python/functional/python_arg.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\ntemplate<int N>\nclass ParsedArgs {\n public:\n  ParsedArgs() = default;\n\n  const PythonArg& operator[](size_t idx) const { return data[idx]; }\n  PythonArg& operator[](size_t idx) { return data[idx]; }\n\n public:\n  PythonArg data[N];\n};\n\nclass FunctionSchema {\n public:\n  FunctionSchema() = default;\n  FunctionSchema(const std::string& signature, const FunctionDef* def, size_t max_pos_nargs)\n      : signature_(signature), def_(def), max_pos_nargs_(max_pos_nargs) {}\n\n  const std::string& signature() const { return signature_; }\n\n  bool Parse(PyObject* args, PyObject* kwargs, PythonArg* parsed_args, bool raise_exception) const;\n\n private:\n  void ReportKwargsError(PyObject* kwargs, size_t nargs) const;\n\n  std::string signature_;\n  const FunctionDef* def_;\n  size_t max_pos_nargs_;\n};\n\ntemplate<typename... SchemaT>\nclass PythonArgParser {\n public:\n  static_assert(sizeof...(SchemaT) >= 1, \"requires 1 template argument at least.\");\n  static constexpr size_t kSchemaSize = sizeof...(SchemaT);\n  static constexpr size_t N = std::max({SchemaT::max_args...});\n\n  template<size_t I>\n  using schema_t = typename std::tuple_element<I, std::tuple<SchemaT...>>::type;\n\n  PythonArgParser(const std::string& name) : name_(name) {\n    Init(std::make_index_sequence<sizeof...(SchemaT)>{});\n  }\n\n  int Parse(PyObject* args, PyObject* kwargs, ParsedArgs<N>* parsed_args) const {\n    bool raise_exception = (kSchemaSize == 1);\n    for (int i = 0; i < kSchemaSize; ++i) {\n      if (schema_[i].Parse(args, kwargs, parsed_args->data, raise_exception)) { return i; }\n    }\n    ReportInvalidArgsError(args, kwargs);\n    return -1;\n  }\n\n private:\n  template<size_t... I>\n  void Init(std::index_sequence<I...>) {\n    ((schema_[I] = FunctionSchema(schema_t<I>::signature, &schema_t<I>::function_def,\n                                  schema_t<I>::max_pos_args)),\n     ...);\n  }\n\n  void ReportInvalidArgsError(PyObject* args, PyObject* kwargs) const {\n    std::ostringstream ss;\n    ss << name_ << \"(): received an invalid combination of arguments. The valid signatures are:\";\n    for (int i = 0; i < kSchemaSize; ++i) { ss << \"\\n\\t*\" << i << \": \" << schema_[i].signature(); }\n    THROW(TypeError) << ss.str();\n  }\n\n private:\n  std::string name_;\n  FunctionSchema schema_[kSchemaSize];\n};\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_PARSER_H_\n"
  },
  {
    "path": "oneflow/api/python/functional/python_return_types.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n// This code is referenced from:\n// https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/structseq.cpp\n\n#ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_RETURN_TYPES_H_\n#define ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_RETURN_TYPES_H_\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include <string>\n#include <sstream>\n#include <structmember.h>\n\n#include \"oneflow/api/python/exception/exception.h\"\n#include \"oneflow/api/python/functional/common.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\ninline PyObject* toTuple(PyStructSequence* obj) {\n#if PY_MAJOR_VERSION == 2\n  ROF_RUNTIME_ERROR() << \"OneFlow does not support Python 2\";\n#else\n  Py_INCREF(obj);\n  return (PyObject*)obj;\n#endif\n}\n\n// Builds a namedtuple-like repr of a PyStructSequence, one field per line.\ninline PyObject* returned_structseq_repr(PyStructSequence* obj) {\n  HANDLE_ERRORS\n  PyTypeObject* tp = Py_TYPE(obj);\n  PyObject* tuple = toTuple(obj);\n  if (tuple == nullptr) { return nullptr; }\n\n  std::stringstream ss;\n  ss << tp->tp_name << \"(\\n\";\n  Py_ssize_t num_elements = Py_SIZE(obj);\n\n  for (Py_ssize_t i = 0; i < num_elements; i++) {\n    const char* cname = tp->tp_members[i].name;\n    if (cname == nullptr) {\n      PyErr_Format(PyExc_SystemError,\n                   \"In structseq_repr(), member %zd name is nullptr\"\n                   \" for type %.500s\",\n                   i, tp->tp_name);\n      Py_DECREF(tuple);\n      return nullptr;\n    }\n\n    PyObject* val = PyTuple_GetItem(tuple, i);\n    if (val == nullptr) {\n      Py_DECREF(tuple);\n      return nullptr;\n    }\n\n    auto repr = PyObject_Repr(val);\n    if (repr == nullptr) {\n      Py_DECREF(tuple);\n      return nullptr;\n    }\n\n    const char* crepr = PyUnicode_AsUTF8(repr);\n    Py_DECREF(repr);\n    if (crepr == nullptr) {\n      Py_DECREF(tuple);\n      return nullptr;\n    }\n\n    ss << cname << '=' << crepr;\n    if (i < num_elements - 1) { ss << \",\\n\"; }\n  }\n  ss << \")\";\n\n  Py_DECREF(tuple);\n  return PyUnicode_FromString(ss.str().c_str());\n  END_HANDLE_ERRORS\n}\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_RETURN_TYPES_H_\n"
  },
  {
    "path": "oneflow/api/python/functional/tensor_api.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include <memory>\n\n#include \"oneflow/api/python/utils/tensor_utils.h\"\n#include \"oneflow/api/python/dlpack/converter.h\"\n#include \"oneflow/api/python/framework/size.h\"\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/api/python/functional/tensor_api.yaml.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/common/foreign_lock_helper.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nclass TensorWithDataFunctor {\n public:\n  Maybe<Tensor> operator()(PyObject* data, const Optional<Symbol<DType>>& dtype,\n                           const Optional<Symbol<Device>>& device, const bool requires_grad,\n                           const bool pin_memory) const {\n    // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now.\n    //  even in nn.Graph build (module forward function), creating a flow.Tensor yields\n    //  an eager tensor, built by running functional::Empty() under LazyMode::Guard(false)\n    LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);\n    if (GlobalMode::is_enabled()) {\n      auto global_mode_guard = GlobalMode::Guard(false);\n      return JUST(\n          functional::GlobalTensorWithData(data, dtype, GetGlobalParallelDescFromDevice(device),\n                                           *JUST(GetSbpList(GlobalMode::nd_sbp())), requires_grad));\n    }\n\n    if (PyTensor_Check(data)) {\n      // Throw warnings like pytorch.\n      auto ret = PyErr_WarnEx(\n          PyExc_UserWarning,\n          \"To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() \"\n          \"or sourceTensor.clone().detach().requires_grad_(True), rather than \"\n          \"oneflow.tensor(sourceTensor).\",\n          1);\n      if (ret != 0) { return Error::RuntimeError(); }\n\n      const auto& other = PyTensor_Unpack(data);\n      return MakeTensorFromOtherTensor(other, dtype, device, requires_grad, pin_memory);\n    } else {\n      // Make tensor from python sequence or numpy array.\n      return MakeLocalTensorFromData(data, dtype, device, requires_grad, pin_memory);\n    }\n  }\n};\n\nclass 
GlobalTensorWithDataFunctor {\n public:\n  Maybe<Tensor> operator()(PyObject* data, const Optional<Symbol<DType>>& dtype,\n                           const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp_tuple,\n                           const bool requires_grad) const {\n    // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now.\n    LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);\n    JUST(CheckDeviceIdsIsValid(placement));\n\n    if (PyTensor_Check(data)) {\n      // Throw warnings like pytorch.\n      auto ret = PyErr_WarnEx(\n          PyExc_UserWarning,\n          \"To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() \"\n          \"or sourceTensor.clone().detach().requires_grad_(True), rather than \"\n          \"oneflow.tensor(sourceTensor).\",\n          1);\n      if (ret != 0) { return Error::RuntimeError(); }\n\n      const auto& other = PyTensor_Unpack(data);\n      return MakeTensorFromOtherTensor(other, dtype, placement, sbp_tuple, requires_grad);\n    }\n    // Make global tensor from python sequence or numpy array.\n    return MakeGlobalTensorFromData(data, dtype, placement, sbp_tuple, requires_grad);\n  }\n};\n\nclass TensorEmptyGenericCtorFunctor {\n public:\n  Maybe<Tensor> operator()(const Symbol<DType>& dtype,\n                           const Optional<Symbol<Device>>& device) const {\n    Shape shape(DimVector{0});\n    return TensorWithShapeGenericCtor(shape, dtype, device);\n  }\n};\n\nclass GlobalTensorEmptyGenericCtorFunctor {\n public:\n  Maybe<Tensor> operator()(const Symbol<DType>& dtype, const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp_tuple) const {\n    Shape shape(DimVector{0});\n    JUST(CheckDeviceIdsIsValid(placement));\n    return GlobalTensorWithShapeGenericCtor(shape, dtype, placement, sbp_tuple);\n  }\n};\n\nclass TensorWithOtherGenericCtorFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& other,\n                           const Optional<Symbol<DType>>& dtype) const {\n    // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now.\n    LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);\n    bool is_pinned = false;\n    if (other->is_local()) { is_pinned = JUST(CHECK_JUST(other->AsLocalTensor())->is_pinned()); }\n    return To(JUST(MakeTensorFromOtherTensor(other, is_pinned)), dtype, false);\n  }\n};\n\nclass TensorWithDataGenericCtorFunctor {\n public:\n  Maybe<Tensor> operator()(PyObject* data, const Symbol<DType>& dtype,\n                           const Optional<Symbol<Device>>& device) const {\n    // Treat the single long as shape.\n    if (PyLong_Check(data)) {\n      int64_t size = PyLong_AsLongLong(data);\n      Shape shape(DimVector{size});\n      return TensorWithShapeGenericCtor(shape, dtype, device);\n    }\n    if (TensorSize_Check(data)) {\n      return TensorWithShapeGenericCtor(TensorSize_AsShape(data), dtype, device);\n    }\n\n    // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now.\n    LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);\n\n    if (PyTensor_Check(data)) {\n      const auto& other = PyTensor_Unpack(data);\n      const bool pin_memory =\n          other->is_local() ? 
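/* inherit the pin_memory property from a local source tensor */ 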
JUST(JUST(other->AsLocalTensor())->is_pinned()) : false;\n      return MakeTensorFromOtherTensor(other, dtype, device,\n                                       /*requires_grad=*/false, /*pin_memory=*/pin_memory);\n    }\n    // Make tensor from python sequence or numpy array.\n    return MakeLocalTensorFromData(data, dtype, device, /*requires_grad=*/false,\n                                   /*pin_memory=*/false);\n  }\n};\n\nclass GlobalTensorWithDataGenericCtorFunctor {\n public:\n  Maybe<Tensor> operator()(PyObject* data, const Symbol<DType>& dtype,\n                           const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp_tuple) const {\n    JUST(CheckDeviceIdsIsValid(placement));\n    // Treat the single long as shape.\n    if (PyLong_Check(data)) {\n      int64_t size = PyLong_AsLongLong(data);\n      Shape shape(DimVector{size});\n      return GlobalTensorWithShapeGenericCtor(shape, dtype, placement, sbp_tuple);\n    }\n    if (TensorSize_Check(data)) {\n      return GlobalTensorWithShapeGenericCtor(TensorSize_AsShape(data), dtype, placement,\n                                              sbp_tuple);\n    }\n\n    // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now.\n    LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);\n\n    if (PyTensor_Check(data)) {\n      const auto& other = PyTensor_Unpack(data);\n      return MakeTensorFromOtherTensor(other, dtype, placement, sbp_tuple,\n                                       /*requires_grad=*/false);\n    }\n    // Make global tensor from python sequence or numpy array.\n    return MakeGlobalTensorFromData(data, dtype, placement, sbp_tuple, /*requires_grad=*/false);\n  }\n};\n\nclass TensorWithShapeGenericCtorFunctor {\n public:\n  Maybe<Tensor> operator()(const Shape& shape, const Symbol<DType>& dtype,\n                           const Optional<Symbol<Device>>& device) const {\n    // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now.\n    LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);\n    Symbol<Device> device_;\n    if (device) {\n      device_ = JUST(device);\n    } else {\n      device_ = JUST(Device::New(\"cpu\"));\n    }\n    return functional::Empty(shape, dtype, device_, /*requires_grad=*/false, /*pin_memory=*/false);\n  }\n};\n\nclass GlobalTensorWithShapeGenericCtorFunctor {\n public:\n  Maybe<Tensor> operator()(const Shape& shape, const Symbol<DType>& dtype,\n                           const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp_tuple) const {\n    // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now.\n    LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);\n    JUST(CheckDeviceIdsIsValid(placement));\n    return functional::GlobalEmpty(shape, dtype, placement, sbp_tuple);\n  }\n};\n\nclass AssignLocalTensorFunctor {\n public:\n  AssignLocalTensorFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"copy\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<void> operator()(const std::shared_ptr<one::Tensor>& y,\n                         const std::shared_ptr<one::Tensor>& x) const {\n    // JUST(CheckInplaceValid(y)); // align check to torch\n    CHECK_OR_RETURN(y->is_local() && x->is_local()) << \"Both x and y must be local tensor.\";\n    std::shared_ptr<one::Tensor> src = x;\n    if (y->dtype() != src->dtype()) { src = JUST(To(src, y->dtype(), false)); }\n\n    auto device 
= JUST(y->device());\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"device\", \"pin_memory\");\n    attrs.SetAllAttrs(device, false);\n    TensorTuple outputs{y};\n    return OpInterpUtil::Dispatch(*op_, {src}, &outputs, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nstatic std::vector<int64_t> get_shape_or_stride_from_numpy(size_t ndim, npy_intp* values) {\n  auto result = std::vector<int64_t>(ndim);\n  for (size_t i = 0; i < ndim; ++i) { result[i] = static_cast<int64_t>(values[i]); }\n  return result;\n}\n\nclass LocalTensorSharedDlPackDataFunctor {\n public:\n  LocalTensorSharedDlPackDataFunctor() {}\n  Maybe<Tensor> operator()(PyObject* obj) const {\n    DLManagedTensor* dlMTensor = (DLManagedTensor*)PyCapsule_GetPointer(obj, \"dltensor\");\n    CHECK_NOTNULL_OR_RETURN(dlMTensor)\n        << \"from_dlpack received an invalid capsule. \"\n           \"Note that DLTensor capsules can be consumed only once, \"\n           \"so you might have already constructed a tensor from it once.\";\n\n    // `tensor` steals the ownership of the underlying storage. It also passes a\n    // destructor function that will be called when the underlying storage goes\n    // out of scope. When the destructor is called, the dlMTensor is destructed\n    // too.\n    auto tensor = fromDLPack(dlMTensor);\n\n    // Make sure this capsule will never be used again.\n    PyCapsule_SetName(obj, \"used_dltensor\");\n\n    return tensor;\n  }\n};\n\nclass LocalTensorSharedNumpyDataFunctor {\n public:\n  LocalTensorSharedNumpyDataFunctor() {}\n  Maybe<Tensor> operator()(PyObject* obj) const {\n    if (!PyArray_Check(obj)) {\n      return Error::TypeError() << \"expected np.ndarray, but got \" << Py_TYPE(obj)->tp_name;\n    }\n    auto* array = reinterpret_cast<PyArrayObject*>(obj);\n    const size_t ndim = PyArray_NDIM(array);\n    std::vector<int64_t> sizes = get_shape_or_stride_from_numpy(ndim, PyArray_DIMS(array));\n    std::vector<int64_t> strides = get_shape_or_stride_from_numpy(ndim, PyArray_STRIDES(array));\n    // NumPy strides use bytes. OneFlow strides use element counts.\n    // These checks are consistent with PyTorch (v1.10.0):\n    // https://github.com/pytorch/pytorch/blob/v1.10.0/torch/csrc/utils/tensor_numpy.cpp#L171\n    const auto element_size_in_bytes = PyArray_ITEMSIZE(array);\n    for (auto& stride : strides) {\n      if (stride % element_size_in_bytes != 0) {\n        return Error::InvalidValueError()\n               << \"given numpy array strides are not a multiple of the element byte size. \"\n               << \"Copy the numpy array to reallocate the memory.\";\n      }\n      stride /= element_size_in_bytes;\n    }\n    for (size_t i = 0; i < ndim; ++i) {\n      if (strides[i] < 0) {\n        return Error::InvalidValueError()\n               << \"At least one stride in the given numpy array is negative, \"\n               << \"and tensors with negative strides are not currently supported. \"\n               << \"(You can probably work around this by making a copy of your array \"\n               << \"with array.copy().) \";\n      }\n    }\n    void* data_ptr = PyArray_DATA(array);\n    if (!PyArray_EquivByteorders(PyArray_DESCR(array)->byteorder, NPY_NATIVE)) {\n      return Error::InvalidValueError()\n             << \"given numpy array has byte order different from the native byte order. 
\"\n             << \"Conversion between byte orders is currently not supported.\";\n    }\n    Py_INCREF(obj);\n\n    // Build TensorMeta\n    const auto shape = Shape(DimVector(sizes.begin(), sizes.end()));\n    const auto stride = Stride(strides.begin(), strides.end());\n    DataType data_type = JUST(numpy::GetOFDataTypeFromNpArray(array));\n    Symbol<Device> device = JUST(Device::New(\"cpu\"));\n\n    auto tensor_meta =\n        SymbolOf(LocalTensorMeta(shape, stride, data_type, MemoryFormat::kContiguous, device));\n\n    // Build TensorBuffer\n    const auto& Free = [array](char* dptr) {\n      CHECK_JUST(Singleton<ForeignLockHelper>::Get()->WithScopedAcquire([&]() -> Maybe<void> {\n        Py_DECREF(array);\n        return Maybe<void>::Ok();\n      }));\n    };\n\n    const auto array_size_in_bytes = PyArray_NBYTES(array);\n    auto tensor_data = std::make_shared<vm::TensorStorage>(false, device);\n    tensor_data->set_blob_dptr(\n        std::unique_ptr<char, std::function<void(char*)>>(static_cast<char*>(data_ptr), Free),\n        array_size_in_bytes);\n\n    // Build TensorStorage: decrease ndarray reference count before releasing\n    auto tensor_storage = std::make_shared<TensorStorage>(tensor_data);\n\n    // Build Tensor\n    auto tensor_impl = std::make_shared<EagerLocalTensorImpl>(tensor_storage,\n                                                              /*requires_grad=*/false,\n                                                              /*ls_leaf=*/true);\n\n    // Init blob\n    JUST(tensor_impl->InitEagerBlobObject(tensor_meta, NewLocalDepObject()));\n    const auto& stream = JUST(GetDefaultStreamByDevice(device));\n    const auto& eager_blob_object = JUST(tensor_impl->eager_blob_object());\n    JUST(eager_blob_object->init_producer_stream(stream));\n    eager_blob_object->set_last_used_stream(stream);\n    std::shared_ptr<Tensor> out(new LocalTensor(tensor_impl));\n    return out;\n  }\n};\n\n}  // namespace impl\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<impl::TensorWithDataFunctor>(\"TensorWithData\");\n  m.add_functor<impl::GlobalTensorWithDataFunctor>(\"GlobalTensorWithData\");\n  m.add_functor<impl::TensorEmptyGenericCtorFunctor>(\"TensorEmptyGenericCtor\");\n  m.add_functor<impl::GlobalTensorEmptyGenericCtorFunctor>(\"GlobalTensorEmptyGenericCtor\");\n  m.add_functor<impl::TensorWithOtherGenericCtorFunctor>(\"TensorWithOtherGenericCtor\");\n  m.add_functor<impl::TensorWithDataGenericCtorFunctor>(\"TensorWithDataGenericCtor\");\n  m.add_functor<impl::GlobalTensorWithDataGenericCtorFunctor>(\"GlobalTensorWithDataGenericCtor\");\n  m.add_functor<impl::TensorWithShapeGenericCtorFunctor>(\"TensorWithShapeGenericCtor\");\n  m.add_functor<impl::GlobalTensorWithShapeGenericCtorFunctor>(\"GlobalTensorWithShapeGenericCtor\");\n  m.add_functor<impl::AssignLocalTensorFunctor>(\"AssignLocalTensor\");\n  m.add_functor<impl::LocalTensorSharedNumpyDataFunctor>(\"LocalTensorSharedNumpyData\");\n  m.add_functor(\"TensorEmptyCtor\", [](const Optional<Symbol<Device>>& device) -> Maybe<Tensor> {\n    return TensorEmptyGenericCtor(GetDefaultDType(), device);\n  });\n  m.add_functor(\"GlobalTensorEmptyCtor\",\n                [](const Symbol<ParallelDesc>& placement,\n                   const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {\n                  return GlobalTensorEmptyGenericCtor(GetDefaultDType(), placement, sbp_tuple);\n                });\n  m.add_functor(\"TensorWithOtherCtor\", [](const std::shared_ptr<Tensor>& other) -> Maybe<Tensor> {\n 
   return TensorWithOtherGenericCtor(other, NullOpt);\n  });\n  m.add_functor(\"TensorWithDataCtor\",\n                [](PyObject* data, const Optional<Symbol<Device>>& device) -> Maybe<Tensor> {\n                  return TensorWithDataGenericCtor(data, GetDefaultDType(), device);\n                });\n  m.add_functor(\"GlobalTensorWithDataCtor\",\n                [](PyObject* data, const Symbol<ParallelDesc>& placement,\n                   const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {\n                  return GlobalTensorWithDataGenericCtor(data, GetDefaultDType(), placement,\n                                                         sbp_tuple);\n                });\n  m.add_functor(\"TensorWithShapeCtor\",\n                [](const Shape& shape, const Optional<Symbol<Device>>& device) -> Maybe<Tensor> {\n                  return TensorWithShapeGenericCtor(shape, GetDefaultDType(), device);\n                });\n  m.add_functor(\"GlobalTensorWithShapeCtor\",\n                [](const Shape& shape, const Symbol<ParallelDesc>& placement,\n                   const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {\n                  return GlobalTensorWithShapeGenericCtor(shape, GetDefaultDType(), placement,\n                                                          sbp_tuple);\n                });\n  m.add_functor<impl::LocalTensorSharedDlPackDataFunctor>(\"LocalTensorSharedDlPackData\");\n}\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/functional/tensor_api.yaml",
    "content": "# Copyright 2020 The OneFlow Authors. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n- name: \"tensor\"\n  signature: [\n      \"Tensor (PyObject* data, *, DataType dtype=None, Device device=None,\n      Bool requires_grad=False, Bool pin_memory=False) => TensorWithData\",\n      \"Tensor (PyObject* data, *, DataType dtype=None, Placement placement,\n      SbpList sbp, Bool requires_grad=False) => GlobalTensorWithData\",\n    ]\n  bind_python: True\n\n- name: \"_legacy_tensor_generic_ctor\"\n  signature: [\n      \"Tensor (*, DataType dtype, Device device=None) => TensorEmptyGenericCtor\",\n      \"Tensor (*, DataType dtype, Placement placement, SbpList sbp) => GlobalTensorEmptyGenericCtor\",\n      \"Tensor (Tensor other, *, DataType dtype=None) => TensorWithOtherGenericCtor\",\n      \"Tensor (PyObject* data, *, DataType dtype, Device device=None) => TensorWithDataGenericCtor\",\n      \"Tensor (PyObject* data, *, DataType dtype, Placement placement, SbpList sbp) => GlobalTensorWithDataGenericCtor\",\n      \"Tensor (Shape size, *, DataType dtype, Device device=None) => TensorWithShapeGenericCtor\",\n      \"Tensor (Shape size, *, DataType dtype, Placement placement, SbpList sbp) => GlobalTensorWithShapeGenericCtor\",\n  ]\n  bind_python: True\n\n\n- name: \"_legacy_tensor_ctor\"\n  signature:\n    [\n      \"Tensor (*, Device device=None) => TensorEmptyCtor\",\n      \"Tensor (*, Placement placement, SbpList sbp) => GlobalTensorEmptyCtor\",\n      \"Tensor (Tensor other) => TensorWithOtherCtor\",\n      \"Tensor (PyObject* data, *, Device device=None) => TensorWithDataCtor\",\n      \"Tensor (PyObject* data, *, Placement placement, SbpList sbp) => GlobalTensorWithDataCtor\",\n      \"Tensor (Shape size, *, Device device=None) => TensorWithShapeCtor\",\n      \"Tensor (Shape size, *, Placement placement, SbpList sbp) => GlobalTensorWithShapeCtor\",\n    ]\n  bind_python: True\n\n- name: \"assign_local_tensor\"\n  signature: \"Void (Tensor ref, Tensor value) => AssignLocalTensor\"\n  bind_python: True\n\n- name: \"from_numpy\"\n  signature: \"Tensor (PyObject* obj) => LocalTensorSharedNumpyData\"\n  bind_python: True\n\n- name: \"from_dlpack\"\n  signature: \"Tensor (PyObject* obj) => LocalTensorSharedDlPackData\"\n  bind_python: True\n"
  },
  {
    "path": "oneflow/api/python/functional/value_types.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/python/functional/value_types.h\"\n\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/common/hash_container.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nHashMap<ValueType, std::string>* GetValueTypeNameMap() {\n  static HashMap<ValueType, std::string> value_type_name_map = {\n      {kVOID, \"void\"},\n      {kINT32, \"int32\"},\n      {kUINT32, \"unsigned int32\"},\n      {kINT64, \"int64\"},\n      {kUINT64, \"unsigned int64\"},\n      {kFLOAT, \"float\"},\n      {kDOUBLE, \"double\"},\n      {kBOOL, \"bool\"},\n      {kSTRING, \"string\"},\n      {kINT32_LIST, \"int32 list\"},\n      {kUINT32_LIST, \"unsigned int32 list\"},\n      {kINT64_LIST, \"int64 list\"},\n      {kUINT64_LIST, \"unsigned int64 list\"},\n      {kFLOAT_LIST, \"float list\"},\n      {kDOUBLE_LIST, \"double list\"},\n      {kDOUBLE_LIST, \"bool list\"},\n      {kSTRING_LIST, \"string list\"},\n      {kVOID_MAYBE, \"maybe void\"},\n      {kBOOL_MAYBE, \"maybe bool\"},\n      {kSCALAR, \"scalar\"},\n      {kTENSOR, \"tensor\"},\n      {kTENSOR_REF, \"tensor\"},\n      {kTENSOR_MAYBE, \"maybe tensor\"},\n      {kTENSOR_TUPLE, \"tensor tuple\"},\n      {kTENSOR_TUPLE_REF, \"tensor tuple\"},\n      {kTENSOR_TUPLE_MAYBE, \"maybe tensor tuple\"},\n      {kATTR, \"attr\"},\n      {kATTR_REF, \"attr\"},\n      {kDTYPE, \"data type\"},\n      {kDTYPE_LIST, \"data type list\"},\n      {kSHAPE, \"shape\"},\n      {kSHAPE_LIST, \"shape list\"},\n      {kGENERATOR, \"generator\"},\n      {kGENERATOR_REF, \"generator\"},\n      {kGENERATOR_MAYBE, \"maybe generator\"},\n      {kTENSOR_INDEX, \"index\"},\n      {kDEVICE, \"device\"},\n      {kPARALLEL_DESC, \"placement\"},\n      {kSBP_PARALLEL, \"sbp\"},\n      {kSBP_PARALLEL_LIST, \"sbp list\"},\n      {kOPEXPR, \"opexpr\"},\n      {kOPEXPR_REF, \"opexpr\"},\n      {kPY_OBJECT, \"python object\"},\n      {kLAYOUT, \"layout\"},\n      {kMEMORY_FORMAT, \"memory format\"},\n      {kCOMPLEX_FLOAT, \"complex float\"},\n      {kCOMPLEX_DOUBLE, \"complex double\"},\n      {kCHAR, \"char\"},\n      {kINT16, \"int16\"}};\n  return &value_type_name_map;\n}\n\nconst std::string& ValueTypeName(ValueType type) {\n  const auto* type_name_map = GetValueTypeNameMap();\n  const auto& it = type_name_map->find(type);\n  CHECK_OR_THROW(it != type_name_map->end()) << \"Value type \" << type << \" has no type name.\";\n  return it->second;\n}\n\nbool IsIntegralType(ValueType type) { return type >= kINT32 && type < kINTEGRAL_MASK; }\nbool IsIntegralListType(ValueType type) {\n  return type >= kINT32_LIST && type < kINTEGRAL_LIST_MASK;\n}\nbool IsFloatingType(ValueType type) { return type >= kFLOAT && type < kFLOATING_MASK; }\nbool IsFloatingListType(ValueType type) {\n  return type >= kFLOAT_LIST && type < kFLOATING_LIST_MASK;\n}\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/functional/value_types.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_FUNCTIONAL_VALUE_TYPES_H_\n#define ONEFLOW_CORE_FUNCTIONAL_VALUE_TYPES_H_\n\n#include <complex>\n#include <memory>\n#include <Python.h>\n#undef _PyGC_FINALIZED\n\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/memory_format.pb.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/layout.h\"\n\nnamespace oneflow {\nclass Scalar;\nclass Shape;\n\ntemplate<typename T>\nclass Symbol;\n\nclass Device;\nclass ParallelDesc;\nclass SbpParallel;\n\nnamespace one {\nclass Tensor;\nclass TensorTuple;\nclass Generator;\nclass OpExpr;\n\nnamespace functional {\nclass TensorIndex;\n}  // namespace functional\n}  // namespace one\n\nnamespace one {\nnamespace functional {\n\nenum ValueType : int {\n  kINVALID = 0,\n  kVOID,\n  // Integral\n  kINT32,\n  kINT64,\n  kUINT32,\n  kUINT64,\n  kINTEGRAL_MASK = 10,\n  // Floating\n  kFLOAT,\n  kDOUBLE,\n  kFLOATING_MASK = 15,\n\n  kBOOL,\n  kSTRING,\n  // Integral list\n  kINT32_LIST = 50,\n  kUINT32_LIST,\n  kINT64_LIST,\n  kUINT64_LIST,\n  kINTEGRAL_LIST_MASK = 60,\n  // Floating list\n  kFLOAT_LIST,\n  kDOUBLE_LIST,\n  kFLOATING_LIST_MASK = 65,\n\n  kBOOL_LIST,\n  kSTRING_LIST,\n\n  kVOID_MAYBE = 100,\n  kBOOL_MAYBE,\n\n  kSCALAR = 200,\n  kTENSOR,\n  kTENSOR_REF,\n  kTENSOR_MAYBE,\n  kTENSOR_TUPLE,\n  kTENSOR_TUPLE_REF,\n  kTENSOR_TUPLE_MAYBE,\n  kATTR,\n  kATTR_REF,\n  kDTYPE,\n  kSHAPE,\n  kLAYOUT,\n  kSHAPE_MAYBE,\n  kGENERATOR,\n  kGENERATOR_REF,\n  kGENERATOR_MAYBE,\n  kTENSOR_INDEX,\n  kDEVICE,\n  kPARALLEL_DESC,\n  kSBP_PARALLEL,\n  kSBP_PARALLEL_LIST,\n  kSHAPE_LIST,\n  kDTYPE_LIST,\n\n  kMEMORY_FORMAT,\n\n  kOPEXPR = 390,\n  kOPEXPR_REF,\n  kPY_OBJECT = 400,\n\n  // Complex\n  kCOMPLEX_FLOAT,\n  kCOMPLEX_DOUBLE,\n  kCHAR,\n  kINT16\n};\n\n#define VALUE_TYPE_OF_IMPL(cpp_type, value_type)                                                 \\\n  template<typename T, typename std::enable_if<std::is_same<T, cpp_type>::value, int>::type = 0> \\\n  inline ValueType ValueTypeOf() {                                                               \\\n    return value_type;                                                                           \\\n  }                                                                                              \\\n  template<typename T,                                                                           \\\n           typename std::enable_if<std::is_same<T, Optional<cpp_type>>::value, int>::type = 0>   \\\n  inline ValueType ValueTypeOf() {                                                               \\\n    return value_type;                                                                           \\\n  }\n\nVALUE_TYPE_OF_IMPL(void, kVOID);\nVALUE_TYPE_OF_IMPL(int32_t, kINT32);\nVALUE_TYPE_OF_IMPL(int16_t, kINT16);\nVALUE_TYPE_OF_IMPL(char, 
kCHAR);\nVALUE_TYPE_OF_IMPL(uint32_t, kUINT32);\nVALUE_TYPE_OF_IMPL(int64_t, kINT64);\nVALUE_TYPE_OF_IMPL(uint64_t, kUINT64);\nVALUE_TYPE_OF_IMPL(float, kFLOAT);\nVALUE_TYPE_OF_IMPL(double, kDOUBLE);\nVALUE_TYPE_OF_IMPL(bool, kBOOL);\nVALUE_TYPE_OF_IMPL(std::string, kSTRING);\nVALUE_TYPE_OF_IMPL(std::vector<int32_t>, kINT32_LIST);\nVALUE_TYPE_OF_IMPL(std::vector<uint32_t>, kUINT32_LIST);\nVALUE_TYPE_OF_IMPL(std::vector<int64_t>, kINT64_LIST);\nVALUE_TYPE_OF_IMPL(std::vector<uint64_t>, kUINT64_LIST);\nVALUE_TYPE_OF_IMPL(std::vector<float>, kFLOAT_LIST);\nVALUE_TYPE_OF_IMPL(std::vector<double>, kDOUBLE_LIST);\nVALUE_TYPE_OF_IMPL(std::vector<bool>, kBOOL_LIST);\nVALUE_TYPE_OF_IMPL(std::vector<std::string>, kSTRING_LIST);\n\nVALUE_TYPE_OF_IMPL(Maybe<void>, kVOID_MAYBE);\nVALUE_TYPE_OF_IMPL(Maybe<bool>, kBOOL_MAYBE);\n\nVALUE_TYPE_OF_IMPL(Scalar, kSCALAR);\nVALUE_TYPE_OF_IMPL(one::Tensor, kTENSOR);\nVALUE_TYPE_OF_IMPL(std::shared_ptr<one::Tensor>, kTENSOR_REF);\nVALUE_TYPE_OF_IMPL(Maybe<one::Tensor>, kTENSOR_MAYBE);\nVALUE_TYPE_OF_IMPL(one::TensorTuple, kTENSOR_TUPLE);\nVALUE_TYPE_OF_IMPL(std::shared_ptr<one::TensorTuple>, kTENSOR_TUPLE_REF);\nVALUE_TYPE_OF_IMPL(Maybe<one::TensorTuple>, kTENSOR_TUPLE_MAYBE);\nVALUE_TYPE_OF_IMPL(Symbol<DType>, kDTYPE);\nVALUE_TYPE_OF_IMPL(Symbol<Layout>, kLAYOUT);\nVALUE_TYPE_OF_IMPL(std::vector<Symbol<DType>>, kDTYPE_LIST);\nVALUE_TYPE_OF_IMPL(Shape, kSHAPE);\nVALUE_TYPE_OF_IMPL(Maybe<Shape>, kSHAPE_MAYBE);\nVALUE_TYPE_OF_IMPL(std::vector<Shape>, kSHAPE_LIST);\nVALUE_TYPE_OF_IMPL(one::Generator, kGENERATOR);\nVALUE_TYPE_OF_IMPL(std::shared_ptr<one::Generator>, kGENERATOR_REF);\nVALUE_TYPE_OF_IMPL(Maybe<one::Generator>, kGENERATOR_MAYBE);\nVALUE_TYPE_OF_IMPL(TensorIndex, kTENSOR_INDEX);\nVALUE_TYPE_OF_IMPL(Symbol<Device>, kDEVICE);\nVALUE_TYPE_OF_IMPL(Symbol<ParallelDesc>, kPARALLEL_DESC);\nVALUE_TYPE_OF_IMPL(Symbol<SbpParallel>, kSBP_PARALLEL);\nVALUE_TYPE_OF_IMPL(std::vector<Symbol<SbpParallel>>, kSBP_PARALLEL_LIST);\n\nVALUE_TYPE_OF_IMPL(MemoryFormat, kMEMORY_FORMAT);\n\nVALUE_TYPE_OF_IMPL(one::OpExpr, kOPEXPR);\nVALUE_TYPE_OF_IMPL(std::shared_ptr<one::OpExpr>, kOPEXPR_REF);\n\nVALUE_TYPE_OF_IMPL(PyObject*, kPY_OBJECT);\nVALUE_TYPE_OF_IMPL(const PyObject*, kPY_OBJECT);\n\nVALUE_TYPE_OF_IMPL(std::complex<float>, kCOMPLEX_FLOAT);\nVALUE_TYPE_OF_IMPL(std::complex<double>, kCOMPLEX_DOUBLE);\n\n#undef VALUE_TYPE_OF_IMPL\n\nconst std::string& ValueTypeName(ValueType type);\n\nbool IsIntegralType(ValueType type);\nbool IsIntegralListType(ValueType type);\nbool IsFloatingType(ValueType type);\nbool IsFloatingListType(ValueType type);\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n\nnamespace std {\ntemplate<>\nstruct hash<oneflow::one::functional::ValueType> {\n  std::size_t operator()(oneflow::one::functional::ValueType v) const noexcept { return v; }\n};\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_FUNCTIONAL_VALUE_TYPES_H_\n"
  },
  {
    "path": "oneflow/api/python/gil_foreign_lock_helper.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/foreign_lock_helper.h\"\n\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/common/singleton.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\nclass GILForeignLockHelper final : public ForeignLockHelper {\n  Maybe<void> WithScopedRelease(const std::function<Maybe<void>()>& Callback) const override {\n    if (PyGILState_Check()) {\n      py::gil_scoped_release release;\n      JUST(Callback());\n    } else {\n      JUST(Callback());\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> WithScopedAcquire(const std::function<Maybe<void>()>& Callback) const override {\n    if (!PyGILState_Check()) {\n      py::gil_scoped_acquire acquire;\n      JUST(Callback());\n    } else {\n      JUST(Callback());\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"RegisterGILForeignLockHelper\", []() {\n    Singleton<ForeignLockHelper>::Delete();\n    Singleton<ForeignLockHelper>::SetAllocated(new GILForeignLockHelper());\n  });\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/init.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <vector>\n#include <unordered_map>\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include \"oneflow/core/job/env_global_objects_scope.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/job/cluster_instruction.h\"\n\nnamespace py = pybind11;\n\nPYBIND11_MAKE_OPAQUE(std::vector<int64_t>);\nPYBIND11_MAKE_OPAQUE(std::unordered_map<int64_t, std::shared_ptr<std::vector<int64_t>>>);\n\nnamespace oneflow {\n\nnamespace {\n\nusing IntList = std::vector<int64_t>;\nusing Int2IntListMap = std::unordered_map<int64_t, std::shared_ptr<IntList>>;\n\nbool Int2IntListMapContaining(const Int2IntListMap& bigger, const Int2IntListMap& smaller) {\n  for (const auto& pair : smaller) {\n    if (bigger.find(pair.first) == bigger.end()) { return false; }\n    const auto& bigger_device_ids = bigger.find(pair.first)->second;\n    std::vector<int64_t>::iterator ret;\n    for (int64_t device_id : *pair.second) {\n      ret = std::find(bigger_device_ids->begin(), bigger_device_ids->end(), device_id);\n      if (ret == bigger_device_ids->end()) { return false; }\n    }\n  }\n  return true;\n}\n\n}  // namespace\n\nPYBIND11_MODULE(_oneflow_internal, m) {\n  using IntList = std::vector<int64_t>;\n  using Int2IntListMap = std::unordered_map<int64_t, std::shared_ptr<IntList>>;\n\n  py::module_ oneflow_api_util = m.def_submodule(\"util\");\n\n  py::class_<IntList, std::shared_ptr<IntList>>(oneflow_api_util, \"IntList\")\n      .def(py::init<>())\n      .def(\"__len__\", [](const std::shared_ptr<IntList>& v) { return v->size(); })\n      .def(\n          \"items\",\n          [](std::shared_ptr<IntList>& v) { return py::make_iterator(v->begin(), v->end()); },\n          py::keep_alive<0, 1>())\n      .def(\"__getitem__\", (IntList::reference & (IntList::*)(IntList::size_type pos)) & IntList::at)\n      .def(\n          \"__iter__\",\n          [](std::shared_ptr<IntList>& v) { return py::make_iterator(v->begin(), v->end()); },\n          py::keep_alive<0, 1>())\n      .def(\"__eq__\", [](std::shared_ptr<IntList>& lhs, std::shared_ptr<IntList>& rhs) {\n        return *lhs == *rhs;\n      });\n\n  py::class_<Int2IntListMap, std::shared_ptr<Int2IntListMap>>(oneflow_api_util, \"Int2IntListMap\")\n      .def(py::init<>())\n      .def(\"__len__\", [](const std::shared_ptr<Int2IntListMap>& v) { return v->size(); })\n      .def(\n          \"items\",\n          [](std::shared_ptr<Int2IntListMap>& v) {\n            return py::make_iterator(v->begin(), v->end());\n          },\n          py::keep_alive<0, 1>())\n      .def(\"__getitem__\",\n           (Int2IntListMap::mapped_type & (Int2IntListMap::*)(const Int2IntListMap::key_type& pos))\n               & Int2IntListMap::operator[])\n      .def(\n          \"__iter__\",\n          [](std::shared_ptr<Int2IntListMap>& v) {\n            return py::make_iterator(v->begin(), v->end());\n          },\n          py::keep_alive<0, 
1>())\n      .def(\"__eq__\",\n           [](std::shared_ptr<Int2IntListMap>& lhs, std::shared_ptr<Int2IntListMap>& rhs) {\n             return Int2IntListMapContaining(*lhs, *rhs) && Int2IntListMapContaining(*rhs, *lhs);\n           });\n  ::oneflow::OneflowModuleRegistry().ImportAll(m);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/ir.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/ir/oneflow-extension/include/PyAst/Ast.h\"\n\n#include <llvm/IR/IntrinsicsS390.h>\n\n#include <pybind11/pybind11.h>\n#include <pybind11/pytypes.h>\n#include <pybind11/stl.h>\n\n#include <algorithm>\n#include <string>\n#include <iostream>\n#include <tuple>\n#include <vector>\n\n#ifdef WITH_MLIR\n\n#include \"oneflow/ir/include/OneFlow/Extension.h\"\n#include \"oneflow/ir/oneflow-extension/include/OneFlow/OneFlowRoundTrip.h\"\n#include \"oneflow/ir/oneflow-extension/include/OneFlow/OneFlowLRJITRegistry.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include <glog/logging.h>\n#include <functional>\n#include <utility>\n\nnamespace oneflow {\nONEFLOW_API_PYBIND11_MODULE(\"ir\", m) {\n  m.def(\"load_jit_shared_lib\",\n        [](const std::string& lib_path) { MutSharedLibPaths()->insert(lib_path); });\n\n  // TODO: this may be move to a common place for create global singleton.\n  m.def(\"create_global_lr_jit\", []() { Singleton<LRJITRegistry>::New(); });\n\n  m.def(\"compile_and_register_lr_jit\", [](const std::string& function_id,\n                                          std::shared_ptr<pyast::FunctionDef>& func, bool is_dump) {\n    Singleton<LRJITRegistry>::Get()->Register(function_id, *func.get(), is_dump);\n  });\n\n  // look up and execute the registered function for python api\n  m.def(\"get_lr\", [](const std::string& function_id, float base_lr, float step) {\n    auto engine = Singleton<LRJITRegistry>::Get()->LookUp(function_id);\n    return engine(base_lr, step);\n  });\n\n  pybind11::class_<pyast::stmt, std::shared_ptr<pyast::stmt>>(m, \"smt\");\n\n  pybind11::class_<pyast::expr, std::shared_ptr<pyast::expr>>(m, \"expr\");\n\n  pybind11::class_<pyast::FunctionDef, pyast::stmt, std::shared_ptr<pyast::FunctionDef>>(\n      m, \"FunctionDef\");\n  m.def(\"FunctionDef_\", &pyast::FunctionDef::FunctionDef_);\n\n  pybind11::class_<pyast::Return, pyast::stmt, std::shared_ptr<pyast::Return>>(m, \"Return\");\n  m.def(\"Return_\", &pyast::Return::Return_);\n\n  pybind11::class_<pyast::Assign, pyast::stmt, std::shared_ptr<pyast::Assign>>(m, \"Assign\");\n  m.def(\"Assign_\", &pyast::Assign::Assign_);\n\n  pybind11::class_<pyast::If, pyast::stmt, std::shared_ptr<pyast::If>>(m, \"If\");\n  m.def(\"If_\", &pyast::If::If_);\n\n  pybind11::class_<pyast::Raise, pyast::stmt, std::shared_ptr<pyast::Raise>>(m, \"Raise\");\n  m.def(\"Raise_\", &pyast::Raise::Raise_);\n\n  pybind11::class_<pyast::Assert, pyast::stmt, std::shared_ptr<pyast::Assert>>(m, \"Assert\");\n  m.def(\"Assert_\", &pyast::Assert::Assert_);\n\n  pybind11::class_<pyast::Expr, pyast::stmt, std::shared_ptr<pyast::Expr>>(m, \"Expr\");\n  m.def(\"Expr_\", &pyast::Expr::Expr_);\n\n  pybind11::class_<pyast::BoolOp, pyast::expr, std::shared_ptr<pyast::BoolOp>>(m, \"BoolOp\");\n  m.def(\"BoolOp_\", &pyast::BoolOp::BoolOp_);\n\n  
pybind11::class_<pyast::BinOp, pyast::expr, std::shared_ptr<pyast::BinOp>>(m, \"BinOp\");\n  m.def(\"BinOp_\", &pyast::BinOp::BinOp_);\n\n  pybind11::class_<pyast::Lambda, pyast::expr, std::shared_ptr<pyast::Lambda>>(m, \"Lambda\");\n  m.def(\"Lambda_\", &pyast::Lambda::Lambda_);\n\n  pybind11::class_<pyast::Compare, pyast::expr, std::shared_ptr<pyast::Compare>>(m, \"Compare\");\n  m.def(\"Compare_\", &pyast::Compare::Compare_);\n\n  pybind11::class_<pyast::Call, pyast::expr, std::shared_ptr<pyast::Call>>(m, \"Call\");\n  m.def(\"Call_\", &pyast::Call::Call_);\n\n  pybind11::class_<pyast::Num, pyast::expr, std::shared_ptr<pyast::Num>>(m, \"Num\");\n  m.def(\"Num_\", &pyast::Num::Num_);\n\n  pybind11::class_<pyast::Constant, pyast::expr, std::shared_ptr<pyast::Constant>>(m, \"Constant\");\n  m.def(\"Constant_\", &pyast::Constant::Constant_);\n\n  pybind11::class_<pyast::Attribute, pyast::expr, std::shared_ptr<pyast::Attribute>>(m,\n                                                                                     \"Attribute\");\n  m.def(\"Attribute_\", &pyast::Attribute::Attribute_);\n\n  pybind11::class_<pyast::Name, pyast::expr, std::shared_ptr<pyast::Name>>(m, \"Name\");\n  m.def(\"Name_\", &pyast::Name::Name_);\n\n  pybind11::class_<pyast::arguments, std::shared_ptr<pyast::arguments>>(m, \"arguments\");\n  m.def(\"arguments_\", &pyast::arguments::arguments_);\n\n  pybind11::class_<pyast::arg, std::shared_ptr<pyast::arg>>(m, \"arg\");\n  m.def(\"arg_\", &pyast::arg::arg_);\n}\n\n}  // namespace oneflow\n\n#endif  // WITH_MLIR\n"
  },
  {
    "path": "oneflow/api/python/job_build/job_build_and_infer.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <string>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/api/python/job_build/job_build_and_infer.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nMaybe<void> MarkVariableGradients(const one::TensorTuple& variables,\n                                  const one::TensorTuple& gradients) {\n  CHECK_OR_RETURN(LazyMode::is_enabled());                 // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(variables.size(), gradients.size());  // NOLINT(maybe-need-error-msg)\n  HashMap<std::string, std::string> variable_grad_lbns;\n  for (int i = 0; i < variables.size(); ++i) {\n    const std::string& variable_lbn = one::TensorNameScope::Global()->Lookup(variables[i]);\n    CHECK_OR_RETURN(!variable_lbn.empty())\n        << \"variable which index is \" << i << \" expected to have a tensor name\";\n    const std::string& gradient_lbn = one::TensorNameScope::Global()->Lookup(gradients[i]);\n    CHECK_OR_RETURN(!gradient_lbn.empty())\n        << \"gradient which index is \" << i << \" expected to have a tensor name\";\n    variable_grad_lbns.emplace(variable_lbn, gradient_lbn);\n  }\n  return JUST(GetCurInferCtx())->MarkVariableGradientBlobNames(variable_grad_lbns);\n}\n\nMaybe<void> MarkOutputGradients(const one::TensorTuple& outputs,\n                                const one::TensorTuple& gradients) {\n  CHECK_OR_RETURN(LazyMode::is_enabled());               // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(outputs.size(), gradients.size());  // NOLINT(maybe-need-error-msg)\n  HashMap<std::string, std::string> output_gradient_lbns;\n  for (int i = 0; i < outputs.size(); ++i) {\n    const std::string& output_lbn = one::TensorNameScope::Global()->Lookup(outputs[i]);\n    CHECK_OR_RETURN(!output_lbn.empty())\n        << \"output which index is \" << i << \" expected to have a tensor name\";\n    const std::string& gradient_lbn = one::TensorNameScope::Global()->Lookup(gradients[i]);\n    CHECK_OR_RETURN(!gradient_lbn.empty())\n        << \"gradient which index is \" << i << \" expected to have a tensor name\";\n    output_gradient_lbns.emplace(output_lbn, gradient_lbn);\n  }\n  return JUST(GetCurInferCtx())->MarkOutputGradientBlobNames(output_gradient_lbns);\n}\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"JobBuildAndInferCtx_Open\", &JobBuildAndInferCtx_Open);\n  m.def(\"JobBuildAndInferCtx_GetCurrentJobName\", &JobBuildAndInferCtx_GetCurrentJobName);\n  m.def(\"JobBuildAndInferCtx_GetCurrentJobId\", &JobBuildAndInferCtx_GetCurrentJobId);\n  m.def(\"JobBuildAndInferCtx_Close\", &JobBuildAndInferCtx_Close);\n\n  m.def(\"CurJobBuildAndInferCtx_SetJobConf\", &CurJobBuildAndInferCtx_SetJobConf);\n\n  m.def(\"CurJobBuildAndInferCtx_Complete\", &CurJobBuildAndInferCtx_Complete,\n        py::call_guard<py::gil_scoped_release>());\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/job_build/job_build_and_infer.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_JOB_BUILD_JOB_BUILD_AND_INFER_H_\n#define ONEFLOW_API_PYTHON_JOB_BUILD_JOB_BUILD_AND_INFER_H_\n\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/tensor_name_scope.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx_mgr.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/record/record.pb.h\"\n\nnamespace oneflow {\n\ninline Maybe<void> JobBuildAndInferCtx_Open(const std::string& job_name) {\n  auto* mgr = JUST(GlobalJobBuildAndInferCtxMgr());\n  return mgr->OpenJobBuildAndInferCtx(job_name);\n}\n\ninline Maybe<std::string> JobBuildAndInferCtx_GetCurrentJobName() {\n  auto* mgr = JUST(GlobalJobBuildAndInferCtxMgr());\n  return mgr->GetCurrentJobName();\n}\n\ninline Maybe<int64_t> JobBuildAndInferCtx_GetCurrentJobId() {\n  return JUST(GetCurInferCtx())->job_id();\n}\n\ninline Maybe<void> JobBuildAndInferCtx_Close() {\n  auto* mgr = JUST(GlobalJobBuildAndInferCtxMgr());\n  JUST(mgr->CloseCurrentJobBuildAndInferCtx());\n  return Maybe<void>::Ok();\n}\n\ninline Maybe<void> CurJobBuildAndInferCtx_SetJobConf(const std::string& job_conf_str) {\n  JobConfigProto job_conf;\n  CHECK_OR_RETURN(TxtString2PbMessage(job_conf_str, &job_conf)) << \"job conf parse failed\";\n  return JUST(GetCurInferCtx())->SetJobConf(job_conf);\n}\n\ninline Maybe<void> CurJobBuildAndInferCtx_Complete() { return JUST(GetCurInferCtx())->Complete(); }\n\ninline Maybe<void> AddTensorAsGraphLoss(const std::shared_ptr<one::Tensor>& t) {\n  CHECK_OR_RETURN(t->is_lazy());\n  CHECK_OR_RETURN(LazyMode::is_enabled());\n  const std::string& loss_lbn = one::TensorNameScope::Global()->Lookup(t);\n  CHECK_OR_RETURN(\"\" != loss_lbn);\n  return JUST(GetCurInferCtx())->AddLossLogicalBlobName(loss_lbn);\n}\n\nMaybe<void> MarkVariableGradients(const one::TensorTuple& variables,\n                                  const one::TensorTuple& gradients);\n\nMaybe<void> MarkOutputGradients(const one::TensorTuple& outputs, const one::TensorTuple& gradients);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_JOB_BUILD_JOB_BUILD_AND_INFER_H_\n"
  },
  {
    "path": "oneflow/api/python/job_build/lazy_mode.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <memory>\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"lazy_mode\", m) {\n  py::class_<LazyMode::Guard, std::shared_ptr<LazyMode::Guard>>(m, \"guard\")\n      .def(py::init(\n          [](const bool is_enabled) { return std::make_shared<LazyMode::Guard>(is_enabled); }))\n      .def(\"__enter__\", [](const LazyMode::Guard& guard_obj) {})\n      .def(\"__exit__\", [](const LazyMode::Guard& guard_obj, const py::object& type,\n                          const py::object& value, const py::object& traceback) {});\n\n  m.def(\"is_enabled\", []() { return LazyMode::is_enabled(); });\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/multiprocessing/init.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/api/python/multiprocessing/object_ptr.h\"\n#include \"oneflow/core/ep/cpu/cpu_device_manager.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/ep/cpu/cpu_device.h\"\n#include <csignal>\n\n#include <stdexcept>\n\n#if defined(__linux__)\n#include <sys/prctl.h>\n#include <system_error>\n#endif\n\n#define SYSASSERT(rv, ...) \\\n  if ((rv) < 0) { throw std::system_error(errno, std::system_category(), ##__VA_ARGS__); }\n\nnamespace oneflow {\nnamespace multiprocessing {\n\nnamespace py = pybind11;\n\nvoid multiprocessing_init() {\n  auto multiprocessing_module = OFObjectPtr(PyImport_ImportModule(\"oneflow.multiprocessing\"));\n  if (!multiprocessing_module) {\n    throw std::runtime_error(\"multiprocessing init error >> multiprocessing_module init fail!\");\n  }\n\n  auto module = py::handle(multiprocessing_module).cast<py::module>();\n\n  module.def(\"_prctl_pr_set_pdeathsig\", [](int signal) {\n#if defined(__linux__)\n    auto rv = prctl(PR_SET_PDEATHSIG, signal);\n    SYSASSERT(rv, \"prctl\");\n#endif\n  });\n\n  // Py_RETURN_TRUE;\n}\n\nvoid set_num_threads(int num) {\n  int64_t cpu_logic_core = std::thread::hardware_concurrency();\n  if (num <= 0) {\n    py::print(\"Warning : \", num, \" less than 1 will be set to 1.\");\n    num = 1;\n  } else if (num >= cpu_logic_core) {\n    py::print(\"Warning : \", num,\n              \" is greater than the number of logical cores and will be set to the maximum number \"\n              \"of logical cores \",\n              cpu_logic_core);\n    num = cpu_logic_core;\n  }\n\n  auto cpu_device = std::static_pointer_cast<ep::CpuDevice>(\n      Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCPU, 0));\n  cpu_device->SetNumThreads(num);\n}\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  py::options options;\n  options.disable_function_signatures();\n  m.def(\"_multiprocessing_init\", &multiprocessing_init);\n  m.def(\"_set_num_threads\", &set_num_threads);\n  options.disable_function_signatures();\n}\n\n}  // namespace multiprocessing\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/multiprocessing/object_ptr.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/python/multiprocessing/object_ptr.h\"\n\ntemplate<>\nvoid OFPointer<PyObject>::free() {\n  if (ptr) Py_DECREF(ptr);\n}\n\ntemplate class OFPointer<PyObject>;\n"
  },
  {
    "path": "oneflow/api/python/multiprocessing/object_ptr.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#pragma once\n\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n\n// reference: pytorch/torch/csrc/utils/object_ptr.h\n// https://github.com/pytorch/pytorch/blob/d69c22dd61a2f006dcfe1e3ea8468a3ecaf931aa/torch/csrc/utils/object_ptr.h\ntemplate<class T>\nclass OFPointer {\n public:\n  OFPointer() : ptr(nullptr){};\n  explicit OFPointer(T* ptr) noexcept : ptr(ptr){};\n  OFPointer(OFPointer&& p) noexcept {\n    free();\n    ptr = p.ptr;\n    p.ptr = nullptr;\n  };\n\n  ~OFPointer() { free(); };\n  T* get() { return ptr; }\n  const T* get() const { return ptr; }\n  T* release() {\n    T* tmp = ptr;\n    ptr = nullptr;\n    return tmp;\n  }\n  operator T*() { return ptr; }\n  OFPointer& operator=(T* new_ptr) noexcept {\n    free();\n    ptr = new_ptr;\n    return *this;\n  }\n  OFPointer& operator=(OFPointer&& p) noexcept {\n    free();\n    ptr = p.ptr;\n    p.ptr = nullptr;\n    return *this;\n  }\n  T* operator->() { return ptr; }\n  explicit operator bool() const { return ptr != nullptr; }\n\n private:\n  void free();\n  T* ptr = nullptr;\n};\n\n/**\n * An RAII-style, owning pointer to a PyObject.  You must protect\n * destruction of this object with the GIL.\n *\n * WARNING: Think twice before putting this as a field in a C++\n * struct.  This class does NOT take out the GIL on destruction,\n * so if you will need to ensure that the destructor of your struct\n * is either (a) always invoked when the GIL is taken or (b) takes\n * out the GIL itself.  Easiest way to avoid this problem is to\n * not use THPPointer in this situation.\n */\nusing OFObjectPtr = OFPointer<PyObject>;\n"
  },
  {
    "path": "oneflow/api/python/multiprocessing/shared_memory.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/ipc/shared_memory.h\"\n\nnamespace oneflow {\n\nnamespace py = pybind11;\n\nONEFLOW_API_PYBIND11_MODULE(\"multiprocessing\", m) {\n  py::class_<ipc::SharedMemory, std::shared_ptr<ipc::SharedMemory>>(m, \"SharedMemory\")\n      .def(py::init([](const std::string& name, bool create, size_t size) {\n             if (create) { return ipc::SharedMemory::Open(size, create).GetPtrOrThrow(); }\n             return ipc::SharedMemory::Open(name, create).GetPtrOrThrow();\n           }),\n           py::arg(\"name\") = \"\", py::arg(\"create\") = false, py::arg(\"size\") = 0)\n      .def(\"close\", &ipc::SharedMemory::Close)\n      .def(\"unlink\", &ipc::SharedMemory::Unlink)\n      .def_property_readonly(\"buf\",\n                             [](ipc::SharedMemory* shm) {\n                               return py::memoryview::from_memory(shm->mut_buf(), shm->size());\n                             })\n      .def_property_readonly(\"name\", &ipc::SharedMemory::name)\n      .def_property_readonly(\"size\", &ipc::SharedMemory::size);\n  m.def(\"unlink_all_shared_memory\",\n        []() { return ipc::SharedMemoryManager::get().UnlinkAllShms(); });\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/numpy/init_numpy_c_api.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/extension/python/numpy.h\"\n\nnamespace py = pybind11;\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"InitNumpyCAPI\", []() { return oneflow::numpy::InitNumpyCAPI(); });\n}\n"
  },
  {
    "path": "oneflow/api/python/of_api_registry.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/python/of_api_registry.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n// If different APIs are registered under the same path, the BuildModuleFuntion of which will be\n// saved in the corresponding vector.\nusing SubModuleMap = std::map<std::string, std::vector<std::function<void(pybind11::module&)>>>;\n\nSubModuleMap* GetSubModuleMap() {\n  static SubModuleMap sub_module_map;\n  return &sub_module_map;\n}\n\n}  // namespace\n\nvoid OneflowModuleRegistry::Register(std::string module_path,\n                                     std::function<void(pybind11::module&)> BuildModule) {\n  (*GetSubModuleMap())[module_path].emplace_back(BuildModule);\n}\n\nvoid OneflowModuleRegistry::ImportAll(pybind11::module& m) {\n  for (const auto& pair : (*GetSubModuleMap())) {\n    for (const auto& BuildModule : pair.second) { BuildSubModule(pair.first, m, BuildModule); }\n  }\n}\n\nvoid OneflowModuleRegistry::BuildSubModule(\n    const std::string& module_path, pybind11::module& m,\n    const std::function<void(pybind11::module&)>& BuildModule) {\n  if (module_path.empty()) {\n    BuildModule(m);\n    return;\n  }\n  size_t dot_pos = module_path.find(\".\");\n  if (dot_pos == std::string::npos) {\n    pybind11::module sub_module = m.def_submodule(module_path.data());\n    BuildModule(sub_module);\n  } else {\n    const std::string& sub_module_name = module_path.substr(0, dot_pos);\n    pybind11::module sub_module = m.def_submodule(sub_module_name.data());\n    BuildSubModule(module_path.substr(dot_pos + 1), sub_module, BuildModule);\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/of_api_registry.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_UTIL_OF_API_REGISTRY_H_\n#define ONEFLOW_API_PYTHON_UTIL_OF_API_REGISTRY_H_\n#include <pybind11/pybind11.h>\n#include <map>\n#include <vector>\n#include <functional>\n#include \"oneflow/api/python/caster/maybe.h\"\n#include \"oneflow/api/python/caster/optional.h\"\n#include \"oneflow/api/python/caster/size.h\"\n#include \"oneflow/api/python/caster/tensor.h\"\n#include \"oneflow/api/python/caster/autograd_function_state.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n\nnamespace oneflow {\n\nclass OneflowModuleRegistry {\n public:\n  OneflowModuleRegistry() = default;\n  ~OneflowModuleRegistry() = default;\n\n  void Register(std::string module_path, std::function<void(pybind11::module&)> BuildModule);\n  void ImportAll(pybind11::module& m);\n\n private:\n  void BuildSubModule(const std::string& module_path, pybind11::module& m,\n                      const std::function<void(pybind11::module&)>& BuildModule);\n};\n\n}  // namespace oneflow\n\n#define ONEFLOW_API_PYBIND11_MODULE(module_path, m)                                              \\\n  static void OF_PP_CAT(OneflowApiPythonModule, __LINE__)(pybind11::module&);                    \\\n  namespace {                                                                                    \\\n  struct OfApiRegistryInit {                                                                     \\\n    OfApiRegistryInit() {                                                                        \\\n      ::oneflow::OneflowModuleRegistry().Register(module_path,                                   \\\n                                                  &OF_PP_CAT(OneflowApiPythonModule, __LINE__)); \\\n    }                                                                                            \\\n  };                                                                                             \\\n  OfApiRegistryInit of_api_registry_init;                                                        \\\n  }                                                                                              \\\n  static void OF_PP_CAT(OneflowApiPythonModule, __LINE__)(pybind11::module & m)\n\n#endif  // ONEFLOW_API_PYTHON_UTIL_OF_API_REGISTRY_H_\n"
  },
  {
    "path": "oneflow/api/python/profiler.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"profiler\", m) {\n  m.def(\"RangePush\", [](const std::string& str) { OF_PROFILER_RANGE_PUSH(str); });\n\n  m.def(\"RangePop\", []() { OF_PROFILER_RANGE_POP(); });\n\n  m.def(\"ProfilerStart\", []() { profiler::ProfilerStart(); });\n\n  m.def(\"ProfilerStop\", []() { profiler::ProfilerStop(); });\n\n  m.def(\"EnableProfiler\", &profiler::EnableProfiler);\n\n  m.def(\"DisableProfilerAndReturnResult\", &profiler::DisableProfilerAndReturnResult);\n\n  m.def(\"StartRecord\", &profiler::StartRecord);\n\n  m.def(\"EndRecord\", &profiler::EndRecord);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/registry/registry.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/common/registry_error.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"CheckAndClearRegistryFlag\", &CheckAndClearRegistryFlag);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/remat/remat.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/vm/remat/allocator.h\"\n#include \"oneflow/core/vm/remat/env.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nnamespace {\nMaybe<vm::RematableTensorStorage> rematable_storage(const std::shared_ptr<one::Tensor>& tensor) {\n  auto ret = std::dynamic_pointer_cast<vm::RematableTensorStorage>(\n      JUST(tensor->eager_blob_object())->tensor_storage());\n  CHECK_NOTNULL_OR_RETURN(ret);\n  return ret;\n}\n}  // namespace\n\nONEFLOW_API_PYBIND11_MODULE(\"remat\", m) {\n  m.def(\"is_in_memory\", [](const std::shared_ptr<one::Tensor>& tensor) -> Maybe<bool> {\n    return JUST(rematable_storage(tensor))->is_in_memory();\n  });\n  m.def(\"allocated_memory\", [](const std::string& device_str) -> Maybe<size_t> {\n    auto device = JUST(Device::ParseAndNew(device_str));\n    return Singleton<remat::AllocatorManager>::Get()\n        ->CreateOrGetAllocator(device->enum_type(), device->device_id())\n        ->allocated_memory();\n  });\n  m.def(\"display\", [](const std::string& device_str) -> Maybe<void> {\n    auto device = JUST(Device::ParseAndNew(device_str));\n    Singleton<remat::AllocatorManager>::Get()\n        ->CreateOrGetAllocator(device->enum_type(), device->device_id())\n        ->DisplayAllPieces();\n    return Maybe<void>::Ok();\n  });\n  m.def(\"remat\", [](const std::shared_ptr<one::Tensor>& t) -> Maybe<void> {\n    // TODO: an instruction\n    JUST(rematable_storage(t))->Remat();\n    return Maybe<void>::Ok();\n  });\n  m.def(\"evict\", [](const std::shared_ptr<one::Tensor>& t) -> Maybe<void> {\n    // TODO: an instruction\n    JUST(rematable_storage(t))->Evict(false);\n    return Maybe<void>::Ok();\n  });\n  m.def(\"is_evictable\", [](const std::shared_ptr<one::Tensor>& t) -> Maybe<bool> {\n    return JUST(rematable_storage(t))->is_evictable();\n  });\n  m.def(\"disable_eviction\", [](const std::shared_ptr<one::Tensor>& t) -> Maybe<void> {\n    JUST(rematable_storage(t))->set_eviction_disabled(true);\n    return Maybe<void>::Ok();\n  });\n  m.def(\"clear_compute_op\", [](const std::shared_ptr<one::Tensor>& t) -> Maybe<void> {\n    JUST(rematable_storage(t))->clear_compute_op();\n    return Maybe<void>::Ok();\n  });\n  m.def(\"clear_stats\", []() { Singleton<remat::Env>::Get()->clear_stats(); });\n  m.def(\"forced_eviction_num\",\n        []() { return Singleton<remat::Env>::Get()->forced_eviction_num(); });\n  m.def(\"eager_eviction_num\", []() { return Singleton<remat::Env>::Get()->eager_eviction_num(); });\n  m.def(\"recomputation_num\", []() { return Singleton<remat::Env>::Get()->recomputation_num(); });\n  m.def(\"set_budget_in_bytes\", [](size_t budget_in_bytes) {\n    
Singleton<remat::Env>::Get()->set_budget_in_bytes(budget_in_bytes);\n  });\n  m.def(\"budget_in_bytes\", []() { return Singleton<remat::Env>::Get()->budget_in_bytes(); });\n  m.def(\"set_small_pieces_optimization\", [](bool enabled) {\n    return Singleton<remat::Env>::Get()->set_small_pieces_optimization(enabled);\n  });\n  m.def(\"is_small_pieces_optimization_enabled\",\n        []() { return Singleton<remat::Env>::Get()->is_small_pieces_optimization_enabled(); });\n}\n\n}  // namespace oneflow\n"
  },
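Every binding above returns `Maybe<T>` and unwraps intermediate results with `JUST`, so a failed lookup (for example the `CHECK_NOTNULL_OR_RETURN` in `rematable_storage`) short-circuits back to the caller instead of crashing. A compilable sketch of that propagation pattern using `std::optional` in place of `oneflow::Maybe` (assumption: the real `Maybe` also carries an error message, which is omitted here):

```cpp
#include <iostream>
#include <memory>
#include <optional>

template<typename T>
using Maybe = std::optional<T>;  // stand-in for oneflow::Maybe

struct Storage { bool evictable = true; };

Maybe<std::shared_ptr<Storage>> GetStorage(bool ok) {
  if (!ok) return std::nullopt;  // analogue of CHECK_NOTNULL_OR_RETURN
  return std::make_shared<Storage>();
}

Maybe<bool> IsEvictable(bool ok) {
  auto storage = GetStorage(ok);
  if (!storage) return std::nullopt;  // what JUST(...) expands to, morally
  return (*storage)->evictable;
}

int main() {
  std::cout << IsEvictable(true).value_or(false) << "\n";  // 1
  std::cout << IsEvictable(false).has_value() << "\n";     // 0: the error propagated
}
```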
  {
    "path": "oneflow/api/python/rpc/ccl.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/pytypes.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/ccl/ccl.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/job/rank_group.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nnamespace {\nMaybe<py::bytes> CpuBroadcast(py::bytes* in, int64_t root) {\n  const auto& rank_group = JUST(RankGroup::DefaultRankGroup());\n  const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(DeviceType::kCPU, rank_group));\n  Py_ssize_t length;\n  char* buffer;\n  if (GlobalProcessCtx::Rank() == root) {\n    CHECK_NOTNULL_OR_RETURN(in);\n    PyBytes_AsStringAndSize(in->ptr(), &buffer, &length);\n  }\n  const auto& meta_transport_token =\n      JUST(TransportToken::NewTransportToken(kTransportTokenTypeMeta));\n  JUST(ccl::CpuBroadcast(&length, &length, sizeof(length), root, parallel_desc,\n                         meta_transport_token));\n\n  const auto& data_transport_token =\n      JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));\n  if (GlobalProcessCtx::Rank() == root) {\n    JUST(ccl::CpuBroadcast(buffer, buffer, length, root, parallel_desc,  // NOLINT\n                           data_transport_token));                       // NOLINT\n    return *in;\n  } else {\n    // https://github.com/pybind/pybind11/issues/1236#issuecomment-527730864\n    PyBytesObject* bytesObject =\n        static_cast<PyBytesObject*>(PyObject_Malloc(offsetof(PyBytesObject, ob_sval) + length + 1));\n\n    PyObject_INIT_VAR(bytesObject, &PyBytes_Type, length);\n    bytesObject->ob_shash = -1;\n    bytesObject->ob_sval[length] = '\\0';\n    buffer = bytesObject->ob_sval;\n    JUST(ccl::CpuBroadcast(nullptr, buffer, length, root, parallel_desc, data_transport_token));\n    return py::reinterpret_steal<py::bytes>(reinterpret_cast<PyObject*>(bytesObject));\n  }\n}\n\n}  // namespace\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"cpu_broadcast\",\n        [](py::bytes in, int64_t root) -> Maybe<py::bytes> { return CpuBroadcast(&in, root); });\n  m.def(\"cpu_broadcast\", [](const py::none& in, int64_t root) -> Maybe<py::bytes> {\n    return CpuBroadcast(nullptr, root);\n  });\n}\n\n}  // namespace oneflow\n"
  },
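`CpuBroadcast` runs in two phases: the root first broadcasts the payload length under a meta transport token so that non-root ranks can allocate an exactly sized bytes object, then broadcasts the payload itself under a data token. A single-process sketch of that length-then-payload protocol; `FakeBroadcast` simulates the transport in-process and is not a OneFlow API:

```cpp
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

// Stand-in transport: delivers `size` bytes from the root's buffer to `out`.
void FakeBroadcast(const void* root_buf, void* out, std::size_t size) {
  if (out != root_buf) std::memcpy(out, root_buf, size);
}

std::string BroadcastString(const std::string* in, bool is_root) {
  // Phase 1: broadcast the length so every rank can size its buffer exactly.
  std::size_t length = is_root ? in->size() : 0;
  FakeBroadcast(&length, &length, sizeof(length));
  // Phase 2: broadcast the payload into the freshly allocated buffer.
  std::vector<char> buf(length);
  if (is_root) std::memcpy(buf.data(), in->data(), length);
  FakeBroadcast(buf.data(), buf.data(), length);
  return std::string(buf.data(), length);
}

int main() {
  const std::string payload = "hello";
  std::cout << BroadcastString(&payload, /*is_root=*/true) << "\n";  // prints "hello"
}
```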
  {
    "path": "oneflow/api/python/rpc/rank_group.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include <pybind11/functional.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/framework/rank_group_rpc_util.h\"\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/job/rank_group_scope.h\"\n#include \"oneflow/core/common/symbol.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> CheckCurrentRankGroupConsistency() {\n  const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());\n  const auto& ctx = JUST(CheckTransportToken(rank_group));\n  JUST(ctx->WaitDone());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"check_current_rank_group_consistency\", &CheckCurrentRankGroupConsistency);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/session/session.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/pytypes.h>\n#include <string>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/job/session.h\"\n#include \"oneflow/core/job/env_global_objects_scope.h\"\n#include \"oneflow/core/framework/multi_client_session_context.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  using namespace oneflow;\n  py::class_<MultiClientSessionContext, std::shared_ptr<MultiClientSessionContext>>(\n      m, \"SessionContext\")\n      .def(py::init<const std::shared_ptr<EnvGlobalObjectsScope>&>())\n      .def(\"try_init\",\n           [](MultiClientSessionContext& session, const std::string& config_proto_str) {\n             return session.TryInit(config_proto_str).GetOrThrow();\n           })\n      .def(\"update_resource\",\n           [](MultiClientSessionContext& session, const std::string& reso_proto_str) {\n             return session.UpdateResource(reso_proto_str).GetOrThrow();\n           });\n\n  m.def(\"NewSessionId\", &NewSessionId);\n  py::class_<LogicalConfigProtoContext>(m, \"LogicalConfigProtoContext\")\n      .def(py::init<const std::string&>());\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/stack_getter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <utility>\n\n#include \"pybind11/pybind11.h\"\n\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/extension/stack/foreign_stack_getter.h\"\n#include \"oneflow/extension/stack/python/stack_getter.h\"\n#include \"oneflow/extension/stack/stacktrace.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"RegisterStackGetter\", &RegisterPyStackGetter);\n  m.def(\"GetCurrentStack\", []() {\n    auto* stack_getter = Singleton<ForeignStackGetter>::Get();\n    return stack_getter->GetFormattedStack(stack_getter->GetCurrentFrame());\n  });\n  m.def(\"RegisterSignalHandler\", []() {\n    if (ParseBooleanFromEnv(\"ONEFLOW_ENABLE_SIGNAL_HANDLER\", true)) {\n      Singleton<backward::SignalHandling>::New();\n    }\n  });\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/symbol/job_conf_symbol.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/operators.h>\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/job_conf.pb.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nMaybe<JobDesc> CreateJobConfSymbol(int64_t symbol_id, const std::string& serialized_symbol_conf) {\n  JobConfigProto symbol_pb;\n  if (!TxtString2PbMessage(serialized_symbol_conf, &symbol_pb)) {\n    THROW(RuntimeError) << \"job conf parse failed.\\n\" << serialized_symbol_conf;\n  }\n  return JobDesc::New(symbol_id, symbol_pb);\n}\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  py::class_<JobDesc, std::shared_ptr<JobDesc>>(m, \"JobConfSymbol\")\n      .def(py::init([](int64_t symbol_id, const std::string& serialized_symbol_conf) {\n        return CreateJobConfSymbol(symbol_id, serialized_symbol_conf).GetPtrOrThrow();\n      }))\n      .def_property_readonly(\"symbol_id\",\n                             [](const JobDesc& x) {\n                               if (!x.symbol_id().has_value()) {\n                                 THROW(RuntimeError) << \"symbol_id not initialized\";\n                               }\n                               return CHECK_JUST(x.symbol_id());\n                             })\n      .def_property_readonly(\"data\", [](const JobDesc& job_conf_sym) -> std::string {\n        return PbMessage2TxtString(job_conf_sym.job_conf());\n      });\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/symbol/op_conf_symbol.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/operators.h>\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/operator/op_conf_symbol.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  py::class_<OperatorConfSymbol, std::shared_ptr<OperatorConfSymbol>>(m, \"OpConfSymbol\")\n      .def_property_readonly(\"symbol_id\",\n                             [](const OperatorConfSymbol& x) {\n                               if (!x.symbol_id().has_value()) {\n                                 THROW(RuntimeError) << \"symbol_id not initialized\";\n                               }\n                               return CHECK_JUST(x.symbol_id());\n                             })\n      .def_property_readonly(\"data\", &OperatorConfSymbol::data);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/symbol/placement_symbol.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/numpy.h>\n#include <pybind11/stl.h>\n#include <pybind11/operators.h>\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/extension/python/numpy.h\"\n#include \"oneflow/api/python/framework/size.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/parallel_conf_util.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nnamespace {\n\nint64_t GetDeviceCount(const std::string& device_name) {\n  return Singleton<ep::DeviceManagerRegistry>::Get()->GetDeviceCount(device_name);\n}\n\nstruct PlacementSymbolExportUtil {\n  static Maybe<void> CheckDeviceTag(const std::string& type) {\n    if (!TRY(DeviceType4DeviceTag(type)).IsOk()) {\n      return Error::RuntimeError() << \"Expected one of \" << PrintAvailableDevices()\n                                   << \" device type at start of device string: \" << type;\n    }\n    return Maybe<void>::Ok();\n  }\n\n  static Maybe<ParallelDesc> CreateParallelDesc(\n      const std::string& type, const std::vector<std::string>& formated_machine_device_ids,\n      const std::shared_ptr<Shape>& hierarchy_shape) {\n    JUST(CheckDeviceTag(type));\n    auto parallel_conf = JUST(MakeParallelConf(type, formated_machine_device_ids, hierarchy_shape));\n    std::shared_ptr<ParallelDesc> parallel_desc;\n    JUST(PhysicalRun([&parallel_desc, &parallel_conf](InstructionsBuilder* builder) -> Maybe<void> {\n      parallel_desc = JUST(builder->GetParallelDescSymbol(*parallel_conf));\n      return Maybe<void>::Ok();\n    }));\n\n    return parallel_desc;\n  }\n\n  static Maybe<ParallelDesc> CreateParallelDesc(const std::string& proto_str) {\n    ParallelConf parallel_conf;\n    CHECK_OR_RETURN(TxtString2PbMessage(proto_str, &parallel_conf))\n        << \" Get ParallelConf Pb from string failed.\";\n    std::shared_ptr<ParallelDesc> parallel_desc;\n    JUST(PhysicalRun([&parallel_desc, &parallel_conf](InstructionsBuilder* builder) -> Maybe<void> {\n      parallel_desc = JUST(builder->GetParallelDescSymbol(parallel_conf));\n      return Maybe<void>::Ok();\n    }));\n\n    return parallel_desc;\n  }\n\n  static Maybe<std::vector<std::string>> ParseAndFormatRanks(const py::dict& device_ids) {\n    std::vector<std::pair<int64_t, int64_t>> machine_device_id_vec;\n    for (const auto& pair : device_ids) {\n      CHECK_OR_RETURN(py::isinstance<py::int_>(pair.first))\n          << \"The key (node id) of placement device_ids must be int64.\";\n      int64_t machine_id = pair.first.cast<int64_t>();\n    
  if (py::isinstance<py::int_>(pair.second)) {\n        machine_device_id_vec.emplace_back(machine_id, pair.second.cast<int64_t>());\n      } else {\n        CHECK_OR_RETURN(py::isinstance<py::iterable>(pair.second))\n            << \"Value of device_ids dict must be int, list or range\";\n        for (const auto& device_id : pair.second) {\n          CHECK_OR_RETURN(py::isinstance<py::int_>(device_id))\n              << \"Value of device_ids dict must be int, list or range of int.\";\n          machine_device_id_vec.emplace_back(machine_id, device_id.cast<int64_t>());\n        }\n      }\n    }\n    auto formated_machine_device_ids = std::make_shared<std::vector<std::string>>();\n    for (const auto& pair : machine_device_id_vec) {\n      const std::string& device_name =\n          std::to_string(pair.first) + \":\" + std::to_string(pair.second);\n      formated_machine_device_ids->emplace_back(device_name);\n    }\n    return formated_machine_device_ids;\n  }\n\n  static Maybe<Shape> GetRanksShape(PyArrayObject* ranks) {\n    auto* shape = PyArray_SHAPE(ranks);\n    return std::make_shared<Shape>(DimVector(shape, shape + PyArray_NDIM(ranks)));\n  }\n\n  // Parse and format ranks to string \"machine_id:local_rank\"\n  static Maybe<std::vector<std::string>> ParseAndFormatRanks(PyArrayObject* ranks) {\n    size_t size = PyArray_SIZE(ranks);\n    CHECK_EQ_OR_RETURN(PyArray_TYPE(ranks), NPY_INT64)\n        << Error::RuntimeError() << \"placement ranks should be an array of long int\";\n    int64_t* rank_data = static_cast<int64_t*>(PyArray_DATA(ranks));\n\n    std::vector<std::pair<int64_t, int64_t>> machine_device_id_vec;\n    for (int i = 0; i < size; ++i) {\n      int64_t rank = rank_data[i];\n      int64_t machine_id = GlobalProcessCtx::NodeId(rank);\n      int64_t device_id = GlobalProcessCtx::LocalRank(rank);\n      machine_device_id_vec.emplace_back(machine_id, device_id);\n    }\n\n    auto formated_machine_device_ids = std::make_shared<std::vector<std::string>>();\n    for (const auto& pair : machine_device_id_vec) {\n      auto device_name = std::to_string(pair.first) + \":\" + std::to_string(pair.second);\n      formated_machine_device_ids->emplace_back(device_name);\n    }\n    return formated_machine_device_ids;\n  }\n\n  static Maybe<Symbol<ParallelDesc>> CreateParallelDescSymbol(\n      const std::string& type, const py::dict& device_ids,\n      const std::shared_ptr<Shape>& hierarchy) {\n    const auto& formated_machine_device_ids = JUST(ParseAndFormatRanks(device_ids));\n    return SymbolOf(*JUST(CreateParallelDesc(type, *formated_machine_device_ids, hierarchy)));\n  }\n\n  // Create a Symbol<ParallelDesc> object from the given device_type and ranks parameters\n  static Maybe<Symbol<ParallelDesc>> CreateParallelDescSymbol(const std::string& type,\n                                                              const py::object& ranks) {\n    auto* obj = reinterpret_cast<PyArrayObject*>(PyArray_FromAny(\n        ranks.ptr(), nullptr, 0, 0, NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY, nullptr));\n    if (!obj) { return Error::RuntimeError() << \"placement ranks should be an array of long int\"; }\n\n    const auto& shape = JUST(GetRanksShape(obj));\n    const auto& formated_machine_device_ids = JUST(ParseAndFormatRanks(obj));\n    return SymbolOf(*JUST(CreateParallelDesc(type, *formated_machine_device_ids, shape)));\n  }\n\n  static Maybe<Symbol<ParallelDesc>> CreateParallelDescSymbol(const std::string& proto_str) {\n    return SymbolOf(*JUST(CreateParallelDesc(proto_str)));\n  }\n\n  
static Maybe<Symbol<ParallelDesc>> AllDevicePlacement(const std::string& type) {\n    static thread_local HashMap<std::string, Symbol<ParallelDesc>> device_tag2placement;\n    CHECK_NOTNULL((Singleton<ResourceDesc, ForEnv>::Get()));\n    JUST(CheckDeviceTag(type));\n    auto it = device_tag2placement.find(type);\n    if (it == device_tag2placement.end()) {\n      int64_t node_size = GlobalProcessCtx::NodeSize();\n      int64_t device_num = GlobalProcessCtx::NumOfProcessPerNode();\n      if (type != \"cpu\") {\n        const int64_t device_count = GetDeviceCount(type);\n        CHECK_NE_OR_RETURN(device_count, 0)\n            << Error::RuntimeError() << \"Can\\'t construct placement with \\\"\" << type\n            << \"\\\" type because there is no device!\";\n        device_num = std::min(device_num, device_count);\n      }\n      std::vector<std::string> machine_device_ids;\n      for (int64_t node_id = 0; node_id < node_size; ++node_id) {\n        std::string device_name = std::to_string(node_id) + \":0-\" + std::to_string(device_num - 1);\n        machine_device_ids.emplace_back(device_name);\n      }\n      Symbol<ParallelDesc> placement =\n          SymbolOf(*JUST(CreateParallelDesc(type, machine_device_ids, std::shared_ptr<Shape>())));\n      it = device_tag2placement.emplace(type, placement).first;\n    }\n    return it->second;\n  }\n\n  static Maybe<py::array> GetPlacementRanks(const Symbol<ParallelDesc>& placement) {\n    py::list ranks;\n    for (int64_t machine_id : placement->sorted_machine_ids()) {\n      int64_t node_id = GlobalProcessCtx::NodeId(machine_id);\n      for (int64_t device_id : placement->sorted_dev_phy_ids(machine_id)) {\n        ranks.append(py::cast(node_id * GlobalProcessCtx::NumOfProcessPerNode() + device_id));\n      }\n    }\n    auto array_ranks = py::cast<py::array>(ranks);\n    array_ranks.resize(placement->hierarchy()->dim_vec());\n    return array_ranks;\n  }\n};\n\n}  // namespace\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  py::class_<Symbol<ParallelDesc>, std::shared_ptr<Symbol<ParallelDesc>>>(m, \"placement\",\n                                                                          py::dynamic_attr())\n      .def(py::init([](const std::string& device_type, const py::dict& device_ids,\n                       const std::shared_ptr<Shape>& hierarchy) {\n             PyErr_WarnEx(\n                 PyExc_UserWarning,\n                 \"The way to construct placement is deprecated, and it will be removed in next \"\n                 \"versions. Please use oneflow.placement(type=str, ranks=int array) instead\",\n                 1);\n             return PlacementSymbolExportUtil::CreateParallelDescSymbol(device_type, device_ids,\n                                                                        hierarchy)\n                 .GetOrThrow();\n           }),\n           py::arg(\"device_type\"), py::arg(\"device_ids\"), py::arg(\"hierarchy\"))\n      .def(py::init([](const std::string& device_type, const py::dict& device_ids,\n                       const py::tuple& hierarchy) {\n             PyErr_WarnEx(\n                 PyExc_UserWarning,\n                 \"The way to construct placement is deprecated, and it will be removed in next \"\n                 \"versions. 
Please use oneflow.placement(type=str, ranks=int array) instead\",\n                 1);\n             DimVector shape_dims{};\n             for (const auto& dim : hierarchy) { shape_dims.emplace_back(dim.cast<int64_t>()); }\n             return PlacementSymbolExportUtil::CreateParallelDescSymbol(\n                        device_type, device_ids, std::make_shared<Shape>(shape_dims))\n                 .GetOrThrow();\n           }),\n           py::arg(\"device_type\"), py::arg(\"device_ids\"), py::arg(\"hierarchy\") = py::tuple())\n      .def(py::init([](const std::string& type, const py::object& ranks) {\n             return PlacementSymbolExportUtil::CreateParallelDescSymbol(type, ranks).GetOrThrow();\n           }),\n           py::arg(\"type\"), py::arg(\"ranks\"))\n      .def(py::init([](const std::string& proto_str) {\n             return PlacementSymbolExportUtil::CreateParallelDescSymbol(proto_str).GetOrThrow();\n           }),\n           py::arg(\"proto_str\"))\n      .def_property_readonly(\n          \"device_type\",\n          [](Symbol<ParallelDesc> p) {\n            PyErr_WarnEx(\n                PyExc_UserWarning,\n                \"The property .device_type of placement is deprecated, please use .type instead\",\n                1);\n            return p->device_tag();\n          })\n      .def_property_readonly(\"type\", [](Symbol<ParallelDesc> p) { return p->device_tag(); })\n      .def_property_readonly(\"hierarchy\",\n                             [](Symbol<ParallelDesc> p) {\n                               PyErr_WarnEx(PyExc_UserWarning,\n                                            \"The property .hierarchy of placement is deprecated, \"\n                                            \"please use .ranks.shape instead\",\n                                            1);\n                               return p->hierarchy();\n                             })\n      .def_property_readonly(\"ranks\", &PlacementSymbolExportUtil::GetPlacementRanks)\n      .def(\"__str__\", PlacementToString)\n      .def(\"__repr__\", PlacementToString)\n      .def(py::self == py::self)\n      .def(py::hash(py::self))\n      .def_static(\"all\", &PlacementSymbolExportUtil::AllDevicePlacement);\n}\n\n}  // namespace oneflow\n"
  },
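`ParseAndFormatRanks` reduces every flat rank to a (node id, local rank) pair and renders it as `"machine_id:device_id"`, while `GetPlacementRanks` inverts that with `node_id * NumOfProcessPerNode() + device_id`. A standalone sketch of the round trip, with `kProcPerNode` as an assumed stand-in for `GlobalProcessCtx::NumOfProcessPerNode()`:

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

constexpr int64_t kProcPerNode = 4;  // assumption: 4 ranks per node

std::vector<std::string> FormatRanks(const std::vector<int64_t>& ranks) {
  std::vector<std::string> out;
  out.reserve(ranks.size());
  for (int64_t rank : ranks) {
    int64_t machine_id = rank / kProcPerNode;  // GlobalProcessCtx::NodeId analogue
    int64_t device_id = rank % kProcPerNode;   // GlobalProcessCtx::LocalRank analogue
    out.push_back(std::to_string(machine_id) + ":" + std::to_string(device_id));
  }
  return out;
}

int main() {
  // Ranks 0,1 live on node 0; ranks 4,5 live on node 1.
  for (const auto& s : FormatRanks({0, 1, 4, 5})) std::cout << s << " ";  // 0:0 0:1 1:0 1:1
}
```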
  {
    "path": "oneflow/api/python/symbol/sbp_symbol.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/operators.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/api/common/sbp.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/constant.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<std::vector<Symbol<SbpParallel>>> MakeSplitSbpParallelList(int max_split_axis) {\n  std::shared_ptr<std::vector<Symbol<SbpParallel>>> ret =\n      std::make_shared<std::vector<Symbol<SbpParallel>>>(max_split_axis);\n  for (int i = 0; i < max_split_axis; ++i) { ret->at(i) = JUST(MakeSplitSbpParallel(i)); }\n  return ret;\n}\n\nMaybe<Symbol<SbpParallel>> GetSplitSbpParallel(int axis) {\n  CHECK_GE_OR_RETURN(axis, 0) << Error::RuntimeError()\n                              << \"Split axis must not be negative, but got \" << axis << \"!\";\n  CHECK_LT_OR_RETURN(axis, kMaxSplitAxis)\n      << Error::RuntimeError() << \"Expected split axis to be less than the supported maximum axis (\"\n      << kMaxSplitAxis << \"), but got \" << axis << \"!\";\n  static std::vector<Symbol<SbpParallel>> split_sbp_sym_list =\n      *JUST(MakeSplitSbpParallelList(kMaxSplitAxis));\n  return split_sbp_sym_list.at(axis);\n}\n\nMaybe<Symbol<SbpParallel>> GetBroadcastSbpParallel() {\n  static Symbol<SbpParallel> broadcast_sbp = JUST(MakeBroadcastSbpParallel());\n  return broadcast_sbp;\n}\n\nMaybe<Symbol<SbpParallel>> GetPartialSumSbpParallel() {\n  static Symbol<SbpParallel> partial_sum_sbp = JUST(MakePartialSumSbpParallel());\n  return partial_sum_sbp;\n}\n\nMaybe<std::pair<std::string, int>> SbpGetState(const Symbol<SbpParallel>& sbp) {\n  if (sbp->has_broadcast_parallel()) {\n    return std::make_shared<std::pair<std::string, int>>(\"B\", -1);\n  } else if (sbp->has_partial_sum_parallel()) {\n    return std::make_shared<std::pair<std::string, int>>(\"P\", -1);\n  } else if (sbp->has_split_parallel()) {\n    return std::make_shared<std::pair<std::string, int>>(\"S\", sbp->split_parallel().axis());\n  } else {\n    return Error::RuntimeError() << \"Invalid sbp signature: \" << sbp->DebugString();\n  }\n}\n\nMaybe<Symbol<SbpParallel>> GetSbpFromState(const std::pair<std::string, int>& state) {\n  if (state.first == \"B\") {\n    return GetBroadcastSbpParallel();\n  } else if (state.first == \"P\") {\n    return GetPartialSumSbpParallel();\n  } else if (state.first == \"S\") {\n    return GetSplitSbpParallel(state.second);\n  } else {\n    return Error::RuntimeError() << \"Invalid sbp signature state: (\" << state.first << \", \"\n                                 << state.second << \");\";\n  }\n}\n\n}  // namespace\n\nONEFLOW_API_PYBIND11_MODULE(\"sbp\", m) {\n  
m.attr(\"max_split_axis\") = kMaxSplitAxis;\n  py::class_<Symbol<SbpParallel>, std::shared_ptr<Symbol<SbpParallel>>>(m, \"sbp\",\n                                                                        py::dynamic_attr())\n      .def(\"__str__\", &api::ApiSbpToString)\n      .def(\"__repr__\", &api::ApiSbpToString)\n      .def(py::self == py::self)\n      .def(py::hash(py::self))\n      .def(\"_ToAttrStr\",\n           [](const Symbol<SbpParallel>& sbp_sym) { return SbpParallelToString(*sbp_sym); })\n      .def(py::pickle(\n          [](const Symbol<SbpParallel>& sbp) {  // __getstate__\n            return SbpGetState(sbp).GetOrThrow();\n          },\n          [](const std::pair<std::string, int>& state) {  // __setstate__\n            return GetSbpFromState(state).GetOrThrow();\n          }));\n  m.def(\"split\", GetSplitSbpParallel, py::arg(\"axis\"));\n  m.def(\"broadcast\", &GetBroadcastSbpParallel);\n  m.def(\"partial_sum\", &GetPartialSumSbpParallel);\n}\n\n}  // namespace oneflow\n"
  },
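The `py::pickle` pair above reduces an sbp symbol to a `("B" | "P" | "S", axis)` tuple and rebuilds the symbol from that state, which is why both directions must handle every variant and reject anything else. A self-contained sketch of the same round trip; the `Sbp` enum and helper names are illustrative, not OneFlow types:

```cpp
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>

enum class Sbp { kBroadcast, kPartialSum, kSplit };

// __getstate__ analogue: serialize to a (letter, axis) pair; axis is -1 unless split.
std::pair<std::string, int> GetState(Sbp sbp, int axis) {
  switch (sbp) {
    case Sbp::kBroadcast: return {"B", -1};
    case Sbp::kPartialSum: return {"P", -1};
    case Sbp::kSplit: return {"S", axis};
  }
  throw std::runtime_error("invalid sbp");
}

// __setstate__ analogue: rebuild from the pair, rejecting unknown letters.
std::pair<Sbp, int> FromState(const std::pair<std::string, int>& state) {
  if (state.first == "B") return {Sbp::kBroadcast, -1};
  if (state.first == "P") return {Sbp::kPartialSum, -1};
  if (state.first == "S") return {Sbp::kSplit, state.second};
  throw std::runtime_error("invalid sbp state");
}

int main() {
  auto state = GetState(Sbp::kSplit, 2);
  auto back = FromState(state);  // round trip preserves (S, 2)
  std::cout << state.first << state.second << " -> " << (back.first == Sbp::kSplit) << "\n";
}
```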
  {
    "path": "oneflow/api/python/symbol/scope_symbol.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include <pybind11/operators.h>\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/job/scope.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nMaybe<Scope> CreateScopeSymbol(int64_t symbol_id, const std::string& symbol_conf_str) {\n  ScopeProto symbol_pb;\n  if (!TxtString2PbMessage(symbol_conf_str, &symbol_pb)) {\n    THROW(RuntimeError) << \"symbol conf parse failed.\\n\" << symbol_conf_str;\n  }\n  return Scope::New(symbol_id, symbol_pb);\n}\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  py::class_<Scope, std::shared_ptr<Scope>>(m, \"ScopeSymbol\")\n      .def(py::init([](int64_t symbol_id, const std::string& symbol_conf_str) {\n        return CreateScopeSymbol(symbol_id, symbol_conf_str).GetPtrOrThrow();\n      }))\n      .def_property_readonly(\"symbol_id\",\n                             [](const Scope& x) {\n                               if (!x.symbol_id().has_value()) {\n                                 THROW(RuntimeError) << \"symbol_id not initialized\";\n                               }\n                               return CHECK_JUST(x.symbol_id());\n                             })\n      .def_property_readonly(\"_proto_str\",\n                             [](const Scope& x) { return PbMessage2TxtString(x.scope_proto()); })\n      .def(\"auto_increment_id\", &Scope::auto_increment_id)\n      .def_property_readonly(\"session_id\", &Scope::session_id)\n      .def_property_readonly(\"job_desc_symbol\", &Scope::job_desc_symbol)\n      .def_property_readonly(\n          \"device_parallel_desc_symbol\",\n          [](const Scope& x) { return x.device_parallel_desc_symbol().shared_from_symbol(); })\n      .def_property_readonly(\"parent_scope_symbol\", &Scope::parent_scope_symbol)\n      .def(\"MakeChildScopeProto\", &Scope::MakeChildScopeProto);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/utils/dataloader.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef _WIN32\n\n#include <atomic>\n#include <map>\n#include <set>\n#include <csignal>\n#include <sstream>\n#include <sys/wait.h>\n\n#include <pybind11/pybind11.h>\n#include \"oneflow/api/python/of_api_registry.h\"\n\n#include <stdexcept>\n\nnamespace oneflow {\n\nnamespace py = pybind11;\n\n// reference: pytorch/torch/csrc/DataLoader.cpp\n// https://github.com/pytorch/pytorch/blob/d69c22dd61a2f006dcfe1e3ea8468a3ecaf931aa/torch/csrc/DataLoader.cpp\n\n// Critical signal handlers should be registered on worker processes before\n// doing work.\n// The handler will raise default handler so that the kill information will be\n// retrieved from main process.\n// Python handle is _set_worker_signal_handlers().\n#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG)                          \\\n  static void HANDLER_NAME(int sig, siginfo_t* info, void* ctx) {                \\\n    auto _w = write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) / sizeof(char)); \\\n    (void)_w;                                                                    \\\n    struct sigaction sa {};                                                      \\\n    sa.sa_handler = SIG_DFL;                                                     \\\n    sa.sa_flags = 0;                                                             \\\n    if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGNAL, &sa, nullptr) != 0) { \\\n      _exit(EXIT_FAILURE);                                                       \\\n    } else {                                                                     \\\n      raise(SIGNAL);                                                             \\\n    }                                                                            \\\n  }\n\n// signal(2) is really not portable. So use sigaction.\n// http://man7.org/linux/man-pages/man2/signal.2.html\nstatic inline void setSignalHandler(int signal, void (*handler)(int, siginfo_t*, void*),\n                                    struct sigaction* old_sa_ptr) {\n  struct sigaction sa {};\n  sa.sa_sigaction = handler;\n  sa.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER;\n  if (sigemptyset(&sa.sa_mask) != 0 || sigaction(signal, &sa, old_sa_ptr) != 0) {\n    std::ostringstream oss;\n    oss << \"An error occurred while setting handler for \" << strsignal(signal) << \".\";\n    throw std::runtime_error(oss.str());\n  }\n}\n\nSIGNAL_HANDLER(SIGBUS, handler_SIGBUS,\n               \"ERROR: Unexpected bus error encountered in worker. 
\"\n               \"This might be caused by insufficient shared memory (shm).\\n\");\nSIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV,\n               \"ERROR: Unexpected segmentation fault encountered in worker.\\n\");\nSIGNAL_HANDLER(SIGFPE, handler_SIGFPE,\n               \"ERROR: Unexpected floating-point exception encountered in worker.\\n\");\n\n// When an error happened in DataLoader methods and Python starts to exit, the\n// error trace will keep the loader alive, and Python may kill the children\n// processes first before deleting the loader object. Then the cleaning up\n// methods in DataLoader.__del__ are not yet called, and SIGCHILD will print an\n// error saying a worker is killed by SIGTERM. So we suppress SIGTERM from main\n// loader process here to avoid this by _exit(EXIT_SUCCESS). Note that if we\n// exit with nonzero code, the loader SIGCHLD handler may report RuntimeError\n// again, and then it defeats the whole purpose.\nstatic void handler_SIGTERM(int sig, siginfo_t* info, void* ctx) {\n  if (info->si_pid == getppid()) { _exit(EXIT_SUCCESS); }\n  struct sigaction sa {};\n  sa.sa_handler = SIG_DFL;\n  sa.sa_flags = 0;\n  if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGTERM, &sa, nullptr) != 0) {\n    _exit(EXIT_FAILURE);\n  } else {\n    raise(SIGTERM);\n  }\n}\n\nstatic void set_worker_signal_handlers() {\n  setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr);\n  setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr);\n  setSignalHandler(SIGTERM, &handler_SIGTERM, nullptr);\n  setSignalHandler(SIGFPE, &handler_SIGFPE, nullptr);\n}\n\n// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)\nstatic std::map<int64_t, std::set<pid_t>> worker_pids = {};\n\nstatic void error_if_any_worker_fails() {\n  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)\n  int error;\n  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)\n  std::set<pid_t>* pid_set;\n  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)\n  pid_t worker_pid;\n  siginfo_t infop;\n\n  // Only check the pids we care about\n  for (auto& w : worker_pids) {\n    pid_set = &(w.second);\n    for (auto pid_it = pid_set->begin(); pid_it != pid_set->end(); ++pid_it) {\n      worker_pid = *pid_it;\n      // Use waitid rather than waitpid so that we can set NOWAIT, and that Python\n      // and other handlers can get whatever info they want about the child.\n      infop.si_pid = 0;\n      error = waitid(P_PID, worker_pid, &infop, WEXITED | WNOHANG | WNOWAIT);\n      // ignore errors and case with no waitable child\n      if (error < 0 || infop.si_pid == 0) continue;\n      if (infop.si_code == CLD_EXITED && infop.si_status != EXIT_SUCCESS) {  // exit with error\n        std::ostringstream oss;\n        oss << \"DataLoader worker (pid \" << worker_pid << \") exited \"\n            << \"unexpectedly with exit code \" << infop.si_status << \". \"\n            << \"Details are lost due to multiprocessing. Rerunning with \"\n            << \"num_workers=0 may give better error trace.\";\n        // This is necessary. Otherwise, the runtime error will kill the other\n        // workers, and trigger this again.\n        pid_set->clear();\n        throw std::runtime_error(oss.str());\n      } else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) {  // killed by signal\n        std::ostringstream oss;\n        oss << \"DataLoader worker (pid \" << worker_pid << \") is killed \"\n            << \"by signal: \" << strsignal(infop.si_status) << \". 
\";\n        if (infop.si_status == SIGBUS) {\n          oss << \"It is possible that dataloader's workers are out of shared memory. \"\n              << \"Please try to raise your shared memory limit.\";\n        }\n        // This is necessary. Otherwise, the runtime error will kill the other\n        // workers, and trigger this again.\n        pid_set->clear();\n        throw std::runtime_error(oss.str());\n      }\n    }\n  }\n}\n\ninline int64_t utils_unpackLong(PyObject* obj) {\n  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)\n  int overflow;\n  long long value = PyLong_AsLongLongAndOverflow(obj, &overflow);\n  if (value == -1 && PyErr_Occurred()) { throw py::value_error(); }\n  if (overflow != 0) { throw std::runtime_error(\"Overflow when unpacking long\"); }\n  return (int64_t)value;\n}\n\n// We don't want to exit on any SIGCHLD from any child. child_pids is a tuple\n// of pids we are interested in.\nstatic void set_worker_pids(py::args py_args) {\n  PyObject* args = py_args.ptr();\n  if (PyTuple_GET_SIZE(args) != 2) {\n    throw py::type_error(\"_set_worker_pids expects exactly 2 arguments.\");\n  }\n  int64_t key = utils_unpackLong(PyTuple_GET_ITEM(args, 0));\n  if (worker_pids.find(key) != worker_pids.end()) {\n    throw py::value_error(\n        \"_set_worker_pids should be called only once for each _BaseDataLoaderIter.\");\n  }\n  PyObject* child_pids = PyTuple_GET_ITEM(args, 1);\n  if (!PyTuple_Check(child_pids)) {\n    py::print(\"_set_worker_pids expects a tuple for child_pids, but got: \",\n              Py_TYPE(child_pids)->tp_name);\n    throw py::type_error(\"_set_worker_pids expects a tuple for child_pids\");\n  }\n\n  std::set<pid_t> pids_set = {};\n  auto size = PyTuple_GET_SIZE(child_pids);\n  for (int idx = 0; idx < size; idx++) {\n    PyObject* obj = PyTuple_GET_ITEM(child_pids, idx);\n    pids_set.insert(static_cast<pid_t>(utils_unpackLong(obj)));\n  }\n\n  worker_pids[key] = pids_set;\n}\n\nstatic void remove_worker_pids(py::args py_args) {\n  PyObject* args = py_args.ptr();\n  int64_t key = utils_unpackLong(PyTuple_GET_ITEM(args, 0));\n  auto it = worker_pids.find(key);\n  if (it == worker_pids.end()) {\n    py::print(\"Cannot find worker information for _BaseDataLoaderIter with id :\", key);\n    throw py::value_error(\"Cannot find worker information for _BaseDataLoaderIter\");\n  }\n  worker_pids.erase(it);\n}\n\n#undef SIGNAL_HANDLER\n\n#else\n// dummy implementations for windows\n\nstatic PyObject* set_worker_signal_handlers(PyObject* module, PyObject* _ignored) {\n  Py_RETURN_NONE;\n}\n\nstatic PyObject* set_worker_pids(PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; }\n\nstatic PyObject* remove_worker_pids(PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; }\n\nstatic PyObject* error_if_any_worker_fails(PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; }\n\n#endif\n\nONEFLOW_API_PYBIND11_MODULE(\"\", m) {\n  m.def(\"_set_worker_signal_handlers\", &set_worker_signal_handlers);\n  m.def(\"_set_worker_pids\", &set_worker_pids);\n  m.def(\"_remove_worker_pids\", &remove_worker_pids);\n  m.def(\"_error_if_any_worker_fails\", &error_if_any_worker_fails);\n}\n\n}  // namespace oneflow\n"
  },
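The `SIGNAL_HANDLER` macro and `handler_SIGTERM` follow the same recipe: write an async-signal-safe message, restore the default disposition with `sigaction`, and re-raise, so the parent's `waitid()` observes the worker killed by the original signal rather than by a handled one. A minimal POSIX sketch of that restore-and-reraise idiom (illustrative names, not the dataloader's actual handlers):

```cpp
#include <cstdlib>
#include <signal.h>
#include <unistd.h>

static void on_fatal(int sig, siginfo_t*, void*) {
  const char msg[] = "worker hit a fatal signal\n";
  (void)write(STDERR_FILENO, msg, sizeof(msg));  // async-signal-safe output only
  struct sigaction sa {};
  sa.sa_handler = SIG_DFL;  // restore the default disposition...
  sa.sa_flags = 0;
  if (sigemptyset(&sa.sa_mask) != 0 || sigaction(sig, &sa, nullptr) != 0) { _exit(EXIT_FAILURE); }
  raise(sig);  // ...and die by the original signal so the parent sees it
}

int main() {
  struct sigaction sa {};
  sa.sa_sigaction = on_fatal;
  sa.sa_flags = SA_RESTART | SA_SIGINFO;
  sigemptyset(&sa.sa_mask);
  sigaction(SIGSEGV, &sa, nullptr);
}
```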
  {
    "path": "oneflow/api/python/utils/tensor_utils.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/api/python/utils/tensor_utils.h\"\n\n#include \"oneflow/core/autograd/autograd_engine.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/device_type.pb.h\"\n#include \"oneflow/core/common/switch_func.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/job/global_mode.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/extension/python/numpy.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/framework/consistency_check.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\nnamespace one {\n\nMaybe<void> EagerLocalTensorZeros(const std::shared_ptr<Tensor>& t) {\n  JUST(functional::CheckInplaceValid(t));\n  std::shared_ptr<LocalTensor> local_tensor;\n  if (t->is_local()) {\n    local_tensor = JUST(t->AsLocalTensor());\n  } else {\n    local_tensor = JUST(t->cur_rank_phy_tensor());\n  }\n  CHECK_OR_RETURN(local_tensor->is_eager()) << \"eager tensors supported only\";\n  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    JUST(builder->AccessBlobByCallback(\n        local_tensor,\n        [](ep::Stream* stream, const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n          AutoMemset(stream, eager_blob_object->mut_dptr(), 0,\n                     eager_blob_object->ByteSizeOfBlobBody(), eager_blob_object->mem_case());\n        },\n        \"mut\"));\n    return Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\nnamespace {\nvoid CopyFromNumpyArray(ep::Stream* stream,\n                        const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,\n                        const NumPyArrayPtr& array_ptr) {\n  SyncAutoMemcpy(stream, eager_blob_object->mut_dptr(), array_ptr.data(),\n                 eager_blob_object->ByteSizeOfBlobBody(), eager_blob_object->mem_case(),\n                 memory::MakeHostMemCase());\n}\n}  // namespace\n\nMaybe<void> CopyLocalTensorFromUntypedArray(const std::shared_ptr<Tensor>& tensor,\n                                            PyObject* array) {\n  return CopyBetweenLocalTensorAndNumpy(tensor, array, CopyFromNumpyArray, \"mut\",\n                                        /*block_host_until_done=*/false);\n}\n\nMaybe<std::tuple<std::vector<Shape>, std::vector<Symbol<DType>>>>\nMaybeGetTensorBufferShapesAndDTypes(const std::shared_ptr<Tensor>& t) {\n  const auto& tensor = JUST(t->AsLocalTensor());\n  if (tensor->dtype() != DType::TensorBuffer()) {\n    return Error::RuntimeError() << \"tensor buffer supported only\";\n  }\n  CHECK_OR_RETURN(tensor->is_eager()) << \"eager tensors supported only\";\n  std::vector<Shape> shapes;\n  std::vector<Symbol<DType>> dtypes;\n\n  auto btb = std::make_shared<BlockingThenBusy>();\n  
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    return builder->SyncAccessBlobByCallback(\n        tensor, btb, [](ep::Stream* stream, const std::shared_ptr<vm::EagerBlobObject>&) {},\n        \"const\");\n  }));\n  JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));\n\n  const auto& eager_blob_object = JUST(tensor->eager_blob_object());\n  const Shape& blob_shape = eager_blob_object->shape();\n  const auto* tensor_buffer_ptr = eager_blob_object->dptr<TensorBuffer>();\n  for (int64_t i = 0; i < blob_shape.elem_cnt(); ++i) {\n    const TensorBuffer* tensor_buffer = tensor_buffer_ptr + i;\n    shapes.emplace_back(tensor_buffer->shape());\n    dtypes.emplace_back(DType::Get(tensor_buffer->data_type()).GetOrThrow());\n  }\n  return std::make_tuple(shapes, dtypes);\n}\n\nMaybe<void> RegisterTensorHook(const std::shared_ptr<Tensor>& self,\n                               const AutogradMeta::Hook& hook) {\n  CHECK_OR_RETURN(self->requires_grad())\n      << \"cannot register a hook on a tensor that doesn't require gradient\";\n  if (!self->grad_fn_node()) { JUST(AddAccumulateFunctionNode(self)); }\n  self->mut_autograd_meta()->add_hook(hook);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RegisterTensorPostGradAccumulationHook(const std::shared_ptr<Tensor>& self,\n                                                   const AutogradMeta::Hook& hook) {\n  if (!self->grad_fn_node()) { JUST(AddAccumulateFunctionNode(self)); }\n  self->mut_autograd_meta()->add_post_grad_accumulation_hook(hook);\n  return Maybe<void>::Ok();\n}\n\nMaybe<py::tuple> TensorGetPyTupleOfSbp(const Tensor& tensor) {\n  const auto& nd_sbp = JUST(tensor.nd_sbp());\n  const auto& tuple = std::make_shared<py::tuple>(nd_sbp->sbp_parallel_size());\n  for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) {\n    (*tuple)[i] = SymbolOf(nd_sbp->sbp_parallel(i));\n  }\n  return tuple;\n}\n\nMaybe<Tensor> MakeLocalTensorFromData(PyObject* data, const Optional<Symbol<DType>>& dtype,\n                                      const Optional<Symbol<Device>>& device,\n                                      const bool requires_grad, const bool pin_memory) {\n  bool is_bfloat16_dtype = dtype ? JUST(dtype)->data_type() == DataType::kBFloat16 : false;\n  bool is_cuda_device = device ? JUST(device)->enum_type() == DeviceType::kCUDA : false;\n  if (is_bfloat16_dtype && is_cuda_device) {\n#if CUDA_VERSION < 11000\n    return Error::RuntimeError()\n           << \"Cannot create a bfloat16 tensor on gpu under cuda version: 11000\";\n#endif  // CUDA_VERSION >= 11000\n  }\n  PyArray_Descr* np_dtype =\n      dtype.has_value() && !is_bfloat16_dtype\n          ? 
PyArray_DescrFromType(JUST(numpy::OFDataTypeToNumpyType(JUST(dtype)->data_type())))\n          : nullptr;\n  // NPY_ARRAY_DEFAULT is NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_BEHAVED, so the\n  // array with NPY_ARRAY_DEFAULT flag is C-style contiguous.\n  // NPY_ARRAY_FORCECAST is needed, otherwise there will be a segfault.\n  //\n  // Even though PyArray_FromAny can cast the input array to the desired dtype\n  // if `dtype` argument is set, it fails to handle the following case:\n  // >> x = [flow.tensor([1, 2])] * 3 <-- x is a list of flow.Tensor\n  // >> y = flow.tensor(x, dtype=flow.float32) <-- returns nullptr\n  // However, the following case without `dtype` argument works well:\n  // >> x = [flow.tensor([1, 2])] * 3\n  // >> y = flow.tensor(x)\n  // So we cast the input array to the desired dtype manually.\n  PyArrayObject* _array = reinterpret_cast<PyArrayObject*>(\n      PyArray_FromAny(data, nullptr, 0, 0,\n                      NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY | NPY_ARRAY_FORCECAST, nullptr));\n  if (!_array) {\n    return Error::RuntimeError() << \"Cannot convert input data to a new numpy array.\";\n  }\n  // PyArray_FromArray steals a reference to np_dtype object, so no need to decref it.\n  PyObject* array = PyArray_FromArray(\n      _array, np_dtype, NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY | NPY_ARRAY_FORCECAST);\n  Py_DECREF(_array);\n  auto* np_arr = reinterpret_cast<PyArrayObject*>(array);\n  const npy_intp* dims_ptr = PyArray_SHAPE(np_arr);\n  const Shape shape(DimVector(dims_ptr, dims_ptr + PyArray_NDIM(np_arr)));\n  DataType np_data_type = JUST(numpy::GetOFDataTypeFromNpArray(np_arr));\n\n  Symbol<Device> device_;\n  if (device) {\n    device_ = JUST(device);\n  } else {\n    device_ = JUST(Device::New(\"cpu\"));\n  }\n  std::shared_ptr<Tensor> tensor =\n      JUST(functional::Empty(shape, JUST(DType::Get(np_data_type)), device_,\n                             /*requires_grad=*/false, /*pin_memory=*/pin_memory));\n  if (device_->enum_type() != DeviceType::kMeta) {\n    JUST(CopyLocalTensorFromUntypedArray(tensor, array));\n  }\n\n  Py_DECREF(array);\n  if (dtype && JUST(dtype)->data_type() != np_data_type) {\n    tensor = JUST(functional::To(tensor, JUST(dtype), false));\n  } else if (!dtype && !PyArray_Check(data) && tensor->dtype()->is_floating_point()\n             && GetDefaultDType() != tensor->dtype()) {\n    // If no dtype was assigned and the tensor was created from a PySequence,\n    // cast it to the default floating dtype\n    tensor = JUST(functional::To(tensor, JUST(DType::Get(DataType::kFloat)), false));\n  }\n  JUST(tensor->set_requires_grad(requires_grad));\n  return tensor;\n}\n\nnamespace {\n\nMaybe<Symbol<NdSbp>> GetAllBroadcastNdSbp(size_t ndim) {\n  NdSbp broadcast_nd_sbp;\n  for (size_t i = 0; i < ndim; ++i) {\n    broadcast_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();\n  }\n  return SymbolOf(broadcast_nd_sbp);\n}\n\nauto* CachedGetAllBroadcastNdSbp = DECORATE(&GetAllBroadcastNdSbp, ThreadLocal);\n\n}  // namespace\n\nMaybe<Tensor> MakeGlobalTensorFromData(PyObject* data, const Optional<Symbol<DType>>& dtype,\n                                       Symbol<ParallelDesc> placement,\n                                       const std::vector<Symbol<SbpParallel>>& sbp_tuple,\n                                       const bool requires_grad) {\n  PyObject* array = NULL;\n  if (PyArray_Check(data)) {\n    // Only NPY_CORDER is supported, and returns a new C-style contiguous array.\n    array = PyArray_NewCopy((PyArrayObject*)data, NPY_CORDER);\n  } else {\n    
// NPY_ARRAY_DEFAULT is NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_BEHAVED, so the\n    // array with NPY_ARRAY_DEFAULT flag is C-style contiguous.\n    array = PyArray_FromAny(data, nullptr, 0, 0, NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY, nullptr);\n    if (!array) { return Error::RuntimeError() << \"Cannot convert input data to a numpy array.\"; }\n  }\n  auto* np_arr = reinterpret_cast<PyArrayObject*>(array);\n  const npy_intp* dims_ptr = PyArray_SHAPE(np_arr);\n  const Shape shape(DimVector(dims_ptr, dims_ptr + PyArray_NDIM(np_arr)));\n  DataType data_type = JUST(numpy::GetOFDataTypeFromNpArray(np_arr));\n\n  if (placement->parallel_num() > 1) {\n    const void* buf_ptr = PyArray_DATA(np_arr);\n    size_t array_size = PyArray_SIZE(np_arr);\n    CHECK_EQ_OR_RETURN(array_size, shape.elem_cnt());\n    size_t byte_size = array_size * GetSizeOfDataType(data_type);\n    JUST(DataConsistencyCheck(buf_ptr, byte_size, placement));\n  }\n\n  Symbol<Device> device = JUST(Device::New(placement->device_tag()));\n  std::shared_ptr<Tensor> local_tensor;\n  {\n    GlobalMode::Guard guard(/* disable global mode */ false);\n    local_tensor =\n        JUST(functional::Empty(shape, JUST(DType::Get(data_type)), device, /*requires_grad=*/false,\n                               /*pin_memory=*/false));\n  }\n  if (device->enum_type() != DeviceType::kMeta) {\n    JUST(CopyLocalTensorFromUntypedArray(local_tensor, array));\n  }\n\n  Py_DECREF(array);\n  // Cast to float if the data is a double sequence rather than a numpy array.\n  Symbol<DType> dtype_;\n  if (dtype) {\n    dtype_ = JUST(dtype);\n  } else if (!dtype && data_type == DataType::kDouble && !PyArray_Check(data)) {\n    dtype_ = DType::Float();\n  }\n  if (dtype_) { local_tensor = JUST(functional::Cast(local_tensor, dtype_, /*pin_memory=*/false)); }\n\n  size_t sbp_dims = sbp_tuple.size();\n  Symbol<NdSbp> broadcast_nd_sbp = JUST(CachedGetAllBroadcastNdSbp(sbp_dims));\n\n  std::shared_ptr<Tensor> broadcast_tensor = JUST(\n      functional::LocalToGlobal(local_tensor, placement, *JUST(GetSbpList(broadcast_nd_sbp)), shape,\n                                local_tensor->dtype(), /* sync_data */ true, /*copy=*/false));\n\n  std::vector<Symbol<SbpParallel>> grad_sbp_tuple;\n  auto global_tensor =\n      JUST(functional::ToGlobal(broadcast_tensor, placement, sbp_tuple, grad_sbp_tuple,\n                                /* check_meta */ false, /*copy=*/false));\n  JUST(global_tensor->set_requires_grad(requires_grad));\n  return global_tensor;\n}\n\nMaybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,\n                                        const bool pin_memory) {\n  if (other->is_local()) {\n    const Symbol<Device>& device = JUST(other->device());\n    return functional::Copy(other, device->type(), device->device_id(), pin_memory);\n  } else {\n    const Symbol<NdSbp>& nd_sbp = JUST(other->nd_sbp());\n    const std::vector<Symbol<SbpParallel>>& sbp_tuple = *JUST(GetSbpList(nd_sbp));\n    std::vector<Symbol<SbpParallel>> grad_sbp_tuple;\n    // TODO:(zhaoluyang) global case support pin_memory\n    return functional::ToGlobal(other, JUST(other->parallel_desc()), sbp_tuple, grad_sbp_tuple,\n                                /* check_meta */ false, /*copy=*/false);\n  }\n}\n\nMaybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,\n                                        const Optional<Symbol<DType>>& dtype,\n                                        const Optional<Symbol<Device>>& device,\n                                        
const bool requires_grad, const bool pin_memory) {\n  std::shared_ptr<Tensor> tensor;\n  Symbol<Device> device_;\n  if (device) { device_ = JUST(device); }\n  if (other->is_local()) {\n    if (!device) { device_ = JUST(other->device()); }\n    tensor = JUST(functional::Copy(other, device_->type(), device_->device_id(),\n                                   pin_memory && !dtype.has_value()));\n  } else {\n    tensor = JUST(functional::GlobalToLocal(other, /*copy=*/false));\n    if (!device) { device_ = JUST(Device::New(\"cpu\")); }\n    tensor = JUST(functional::Copy(tensor, device_->type(), device_->device_id(),\n                                   pin_memory && !dtype.has_value()));\n  }\n  if (dtype) {\n    const Symbol<DType>& dtype_ = JUST(dtype);\n    if (tensor->dtype() != dtype_) { tensor = JUST(functional::Cast(tensor, dtype_, pin_memory)); }\n  }\n  JUST(tensor->set_requires_grad(requires_grad));\n  return tensor;\n}\n\nMaybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,\n                                        const Optional<Symbol<DType>>& dtype,\n                                        const Symbol<ParallelDesc>& placement,\n                                        const std::vector<Symbol<SbpParallel>>& sbp_tuple,\n                                        const bool requires_grad) {\n  std::vector<Symbol<SbpParallel>> grad_sbp_tuple;\n  bool check_meta = other->is_global() ? false : true;\n  std::shared_ptr<Tensor> tensor = JUST(functional::ToGlobal(\n      other, placement, sbp_tuple, grad_sbp_tuple, check_meta, /*copy=*/false));\n  if (dtype) {\n    const Symbol<DType>& dtype_ = JUST(dtype);\n    if (tensor->dtype() != dtype_) {\n      tensor = JUST(functional::Cast(tensor, dtype_, /*pin_memory=*/false));\n    }\n  }\n  JUST(tensor->set_requires_grad(requires_grad));\n  return tensor;\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/api/python/utils/tensor_utils.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_UTILS_TENSOR_UTILS_H_\n#define ONEFLOW_API_PYTHON_UTILS_TENSOR_UTILS_H_\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include <pybind11/functional.h>\n#include <pybind11/numpy.h>\n\n#include \"oneflow/api/python/framework/tensor.h\"\n#include \"oneflow/extension/python/numpy.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/common/blocking_then_busy.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n#include \"oneflow/core/common/foreign_lock_helper.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/core/framework/tensor_util.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n\nnamespace py = pybind11;\n\nnamespace pybind11 {\n// reference: https://github.com/pybind/pybind11/issues/1776\ntemplate<>\nstruct format_descriptor<oneflow::float16> {\n  static pybind11::dtype dtype() {\n    handle ptr = detail::npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);\n    return reinterpret_borrow<pybind11::dtype>(ptr);\n  }\n  static std::string format() {\n    // following: https://docs.python.org/3/library/struct.html#format-characters\n    return \"e\";\n  }\n  static constexpr auto name() { return detail::_(\"float16\"); }\n};\n}  // namespace pybind11\n\nnamespace oneflow {\nnamespace one {\n\nMaybe<void> EagerLocalTensorZeros(const std::shared_ptr<Tensor>& t);\n\ninline Maybe<void*> GetTensorDataPtr(const std::shared_ptr<LocalTensor>& tensor) {\n  void* data_ptr = nullptr;\n  const auto& Callback = [&](ep::Stream*,\n                             const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n    data_ptr = eager_blob_object->mut_raw_dptr();\n  };\n  auto btb = std::make_shared<BlockingThenBusy>();\n  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    return builder->SyncAccessBlobByCallback(tensor, btb, Callback, \"const\");\n  }));\n  JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));\n  return data_ptr;\n}\n\ntemplate<typename T>\ninline static Maybe<PyObject*> EagerLocalTensorToNumpy(PyObject* py_tensor) {\n  const auto& t = PyTensor_Unpack(py_tensor);\n\n  std::shared_ptr<LocalTensor> tensor = JUST(t->AsLocalTensor());\n  CHECK_OR_RETURN(JUST(tensor->device()) == JUST(Device::New(\"cpu\")));\n  CHECK_OR_RETURN(tensor->is_eager()) << \"eager tensors supported only.\";\n  // set base object attr\n  py::handle handle = py::handle(py_tensor);\n\n  const size_t ndim = tensor->ndim();\n  const auto shape = numpy::OFShapeToNumpyShape(tensor->shape()->dim_vec());\n  // NumPy strides use bytes. 
OneFlow strides use element counts.\n  const auto stride =\n      numpy::OFStrideToNumpyStride(*JUST(tensor->stride()), tensor->dtype()->data_type());\n\n  void* data_ptr = JUST(GetTensorDataPtr(tensor));\n\n  return py::array(py::buffer_info(data_ptr, sizeof(T), py::format_descriptor<T>::format(), ndim,\n                                   shape, stride),\n                   handle)\n      .release()\n      .ptr();\n}\n\ntemplate<typename T>\nstruct TensorTypeToPyType final {\n  typedef T type;\n};\n\ntemplate<>\nstruct TensorTypeToPyType<float16> final {\n  typedef float type;\n};\n\ntemplate<>\nstruct TensorTypeToPyType<bfloat16> final {\n  typedef float type;\n};\n\ntemplate<typename T>\ninline static Maybe<PyObject*> EagerLocalTensorItem(const std::shared_ptr<Tensor>& tensor) {\n  // OF_PROFILER_RANGE_GUARD(\"EagerLocalTensorItem\");\n  T value = JUST(GetItemInScalarTensor<T>(tensor));\n  return functional::CastToPyObject(static_cast<typename TensorTypeToPyType<T>::type>(value));\n}\n\ninline Maybe<void> CopyBetweenLocalTensorAndNumpy(\n    const std::shared_ptr<Tensor>& t, PyObject* array,\n    void (*Copy)(ep::Stream*, const std::shared_ptr<vm::EagerBlobObject>&, const NumPyArrayPtr&),\n    const std::string& modifier, bool block_host_until_done) {\n  auto tensor = JUST(t->AsLocalTensor());\n  CHECK_OR_RETURN(tensor->is_contiguous()) << \"contiguous tensors supported only.\";\n  CHECK_OR_RETURN(tensor->is_eager()) << \"eager tensors supported only.\";\n\n  if (block_host_until_done) {\n    NumPyArrayPtr array_ptr(array);\n    const auto& Callback = [array_ptr, Copy](\n                               ep::Stream* stream,\n                               const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n      Copy(stream, eager_blob_object, array_ptr);\n    };\n    auto btb = std::make_shared<BlockingThenBusy>();\n    JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n      return builder->SyncAccessBlobByCallback(tensor, btb, Callback, modifier);\n    }));\n    JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));\n  } else {\n    Py_INCREF(array);\n    NumPyArrayPtr array_ptr(array, [array]() {\n      // release array in main thread to eliminate the time-consuming gil request\n      CHECK_JUST(SingletonMaybe<VirtualMachine>())->add_main_thread_pending_task([array]() {\n        Py_DECREF(array);\n      });\n    });\n\n    JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n      return builder->AccessBlobByCallback(\n          tensor,\n          [array_ptr, Copy](ep::Stream* stream,\n                            const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n            Copy(stream, eager_blob_object, array_ptr);\n          },\n          modifier);\n    }));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<std::tuple<std::vector<Shape>, std::vector<Symbol<DType>>>>\nMaybeGetTensorBufferShapesAndDTypes(const std::shared_ptr<Tensor>& t);\n\nMaybe<void> RegisterTensorHook(const std::shared_ptr<Tensor>& self, const AutogradMeta::Hook& hook);\n\nMaybe<void> RegisterTensorPostGradAccumulationHook(const std::shared_ptr<Tensor>& self,\n                                                   const AutogradMeta::Hook& hook);\n\nMaybe<py::tuple> TensorGetPyTupleOfSbp(const Tensor& tensor);\n\nMaybe<Tensor> MakeLocalTensorFromData(PyObject* data, const Optional<Symbol<DType>>& dtype,\n                                      const Optional<Symbol<Device>>& device,\n                                      
const bool requires_grad, const bool pin_memory);\n\nMaybe<Tensor> MakeGlobalTensorFromData(PyObject* data, const Optional<Symbol<DType>>& dtype,\n                                       Symbol<ParallelDesc> placement,\n                                       const std::vector<Symbol<SbpParallel>>& sbp_tuple,\n                                       const bool requires_grad);\n\nMaybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,\n                                        const bool pin_memory);\n\nMaybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,\n                                        const Optional<Symbol<DType>>& dtype,\n                                        const Optional<Symbol<Device>>& device,\n                                        const bool requires_grad, const bool pin_memory);\n\nMaybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,\n                                        const Optional<Symbol<DType>>& dtype,\n                                        const Symbol<ParallelDesc>& placement,\n                                        const std::vector<Symbol<SbpParallel>>& sbp_tuple,\n                                        const bool requires_grad);\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_UTILS_TENSOR_UTILS_H_\n"
  },
  {
    "path": "oneflow/core/auto_parallel/algorithm_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/auto_parallel/algorithm_util.h\"\n\nnamespace oneflow {\nnamespace auto_parallel {\n\n// Inverse function of order\n// The reason why we need the inverse_order, a.k.a id2order, instead of id2value is to eliminate\n// equality. For example, we have v[0] < v[1] = v[2] < v[3] We do not know v[1] is before or after\n// v[2] with comp(v[1], v[2]). But if we transfer it to order order[0] < order[1] < order[2] <\n// order[3] We know the strict order.\nvoid InverseOrder(const std::vector<int32_t>& order, std::vector<int32_t>& inverse_order) {\n  inverse_order.resize(order.size());\n  for (int32_t i = 0; i < order.size(); i++) { inverse_order[order[i]] = i; }\n}\n\n}  // namespace auto_parallel\n\n// Ceil quotient define a division process, denoted by (/),\n// which give us the maximum part of an integer division.\n// For example,\n// 16 (/) 4 = 4, 17 (/) 4 = 5\n// 5 (/) 2 = 3, 6 (/) 2 = 3\n// 1 (/) 3 = 1, 2 (/) 7 = 1\n// 17 divide by 4 give us 5, 4, 4, 4\n// The normal quotient would take the smaller one 4,\n// but the ceil quotient would take the larger one 5.\nint64_t CeilQuotient(int64_t dividend, int64_t divisor) {\n  return (dividend + divisor - 1) / divisor;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/auto_parallel/algorithm_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_AUTO_PARALLEL_ALGORITHM_UTIL_H_\n#define ONEFLOW_CORE_AUTO_PARALLEL_ALGORITHM_UTIL_H_\n\n#include <vector>\n#include <cstdlib>\n#include <algorithm>\n#include <unordered_map>\n\nnamespace oneflow {\nnamespace auto_parallel {\n\n// this function is to remove the i-th element from a vector in Constant time.\n// the vector should not care about ordering.\n// Be more careful about this function. Make sure that the traveling order of\n// the vector goes from back to front.\ntemplate<class T>\nvoid RemoveFrom(std::vector<T>& v, int32_t i) {\n  v[i] = v.back();\n  v.pop_back();\n}\n\ntemplate<class T>\nvoid CheckAndRemoveFrom(std::vector<T>& v, T& t) {\n  for (int32_t i = v.size() - 1; i >= 0; i--) {\n    if (v[i] == t) {\n      RemoveFrom<T>(v, i);\n      break;\n    }\n  }\n}\n\n// Inverse function, which transfer a vector to an unordered_map.\ntemplate<class T>\nvoid InverseFunction(const std::vector<T>& v, std::unordered_map<T, int32_t>& inverse_map) {\n  inverse_map.clear();\n  for (int32_t i = 0; i < v.size(); i++) { inverse_map[v[i]] = i; }\n}\n\n// When you want to sort something but you can not move any elements, use order.\n// Decide the order of sorting in a list v, we have\n// v[order[i]] < v[order[j]] for all i<j.\n// We could define the comparison, then we have\n// comp(v[order[i]], v[order[j]]) == true for all i<j.\ntemplate<class T, class Compare>\nvoid DecideOrder(const T& v, std::vector<int32_t>& order, const Compare& comp) {\n  // Initialize order\n  order.resize(v.size());\n  for (int32_t i = 0; i < v.size(); i++) { order[i] = i; }\n  // sort\n  std::sort(order.begin(), order.end(), [&](int32_t i, int32_t j) { return comp(v[i], v[j]); });\n}\n\n// Inverse function of order\n// The reason why we need the inverse_order, a.k.a id2order, instead of id2value is to eliminate\n// equality. For example, we have v[0] < v[1] = v[2] < v[3] We do not know v[1] is before or after\n// v[2] with comp(v[1], v[2]). But if we transfer it to order order[0] < order[1] < order[2] <\n// order[3] We know the strict order.\nvoid InverseOrder(const std::vector<int32_t>& order, std::vector<int32_t>& inverse_order);\n\n}  // namespace auto_parallel\n\n// Ceil quotient define a division process, denoted by (/),\n// which give us the maximum part of an integer division.\n// For example,\n// 16 (/) 4 = 4, 17 (/) 4 = 5\n// 5 (/) 2 = 3, 6 (/) 2 = 3\n// 17 divide by 4 give us 5, 4, 4, 4\n// The normal quotient would take the smaller one 4,\n// but the ceil quotient would take the larger one 5.\nint64_t CeilQuotient(int64_t dividend, int64_t divisor);\n\nstatic const double kFloatDeviationMinus = 0.9999999;\nstatic const double kFloatDeviationPlus = 1.0000001;\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_AUTO_PARALLEL_ALGORITHM_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/auto_parallel/auto_memory.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/auto_parallel/auto_memory.h\"\n#include \"oneflow/core/auto_parallel/sbp_constructor.h\"\n#include \"oneflow/core/common/hash_container.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/graph/normal_forward_compute_task_node.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/graph/straighten_nodes.h\"\n#include \"oneflow/core/register/logical_blob_id.pb.h\"\n\nnamespace oneflow {\nnamespace auto_parallel {\n\nnamespace {\n\nclass TopoStruct {\n public:\n  SbpNode* sbp_node = nullptr;\n  const OpNode* op_node = nullptr;\n  // Memory increment = (memory of out registers) - (memory of in registers)\n  int64_t memory_increment = -1;\n  int32_t exceed_time = -1;\n  bool is_reusable = false;\n  int32_t counter = 0;\n  int32_t min_layer = -1;\n  // The maximum min_layer among out_topo_structs\n  int32_t max_layer = -1;\n  // TODO: remove tributary layer\n  // This node should be finished before tributary layer\n  int32_t tributary_layer = -1;\n\n  HashSet<TopoStruct*> in_topo_structs;\n  HashSet<TopoStruct*> out_topo_structs;\n\n  explicit TopoStruct(SbpNode* sbp_node_);\n  explicit TopoStruct(const OpNode* op_node_);\n\n  // Compute the minimum layer of this node\n  int32_t ComputeMinLayer();\n  // Compute the maximum layer of this node\n  void ComputeMaxLayer(int32_t max_min_layer);\n  // Compute the tributary layer\n  int32_t ComputeTributaryLayer(int32_t max_min_layer);\n  // Decide whether all the produced registers are reusable\n  void ComputeIsReusable();\n  // Exceed time = time of cpu - time of gpu\n  void ComputeExceedTime();\n\n  // deciding parameter\n  // kTributaryLayerAscend = 0,     // small tributary layers go first\n  // kDistanceToOverlapAscend = 1,  // small minimum distance to overlap go first\n  // kLayerAscend = 2,              // first in first out\n  // kMemoryIncrementAscend = 3,    // small memory increment go first\n  // kExceedTimeAscend = 4,         // small exceed time go first\n  // kTributaryLayerDescend = 100,     // large tributary layers go first\n  // kDistanceToOverlapDescend = 101,  // long distance to overlap go first\n  // kLayerDescend = 102,              // last in first out\n  // kMemoryIncrementDescend = 103,    // large memory increment go first\n  // kExceedTimeDescend = 104,         // large exceed time go first\n  int64_t GetDecidingParameter(StraightenOrder so) const;\n};\n\nstatic StraightenAlgorithmTag sat;\n\nstatic std::vector<StraightenOrder> decide_parameters;\n\n// Order in the waiting sets\nstruct comp {\n  bool operator()(const TopoStruct* a, const TopoStruct* b) const {\n    for (auto decide_parameter : decide_parameters) {\n      auto decide_parameter_a = a->GetDecidingParameter(decide_parameter);\n      auto decide_parameter_b = b->GetDecidingParameter(decide_parameter);\n      if (decide_parameter_a != decide_parameter_b) {\n        return 
decide_parameter_a < decide_parameter_b;\n      }\n    }\n    return a->op_node->op().op_name() < b->op_node->op().op_name();\n  }\n};\n\nbool IsProducedRegisterReusable(const Operator& op) {\n  // The repeat, acc, pack and unpack operators have non-reusable registers\n  // and a -1 register num at this moment.\n  if (op.op_conf().has_user_conf()) {\n    const auto& op_type_name = op.op_conf().user_conf().op_type_name();\n    // We record the frequency in swin-transformer on the right hand side\n    // and adjust the position accordingly.\n    if (op_type_name == \"repeat\"     // 213\n        || op_type_name == \"acc\"     // 173\n        || op_type_name == \"unpack\"  // 2\n        || op_type_name == \"pack\"    // 1\n    ) {\n      return false;\n    }\n  }\n  // NOTE: Please refer to oneflow/core/graph_impl/normal_forward_compute_task_node.cpp\n  // NormalForwardCompTaskNode::ProduceOutRegstByNameAndBlockNum\n  // for detail.\n  // We can not use <= 0 here since RegstNum4Op returns a number with type size_t.\n  // -1 is actually 18446744073709551615 here.\n  return RegstNum4Op(op) == -1;\n}\n\nTopoStruct::TopoStruct(SbpNode* sbp_node_)\n    : sbp_node(sbp_node_), op_node(sbp_node_->GetOperatorNode()) {\n  ComputeIsReusable();\n  ComputeExceedTime();\n}\n\nTopoStruct::TopoStruct(const OpNode* op_node_) : op_node(op_node_) {\n  ComputeIsReusable();\n  ComputeExceedTime();\n}\n\n// deciding parameter\n// kTributaryLayerAscend = 0,     // small tributary layers go first\n// kDistanceToOverlapAscend = 1,  // small minimum distance to overlap go first\n// kLayerAscend = 2,              // first in first out\n// kMemoryIncrementAscend = 3,    // small memory increment go first\n// kExceedTimeAscend = 4,         // small exceed time go first\n// kTributaryLayerDescend = 100,     // large tributary layers go first\n// kDistanceToOverlapDescend = 101,  // long distance to overlap go first\n// kLayerDescend = 102,              // last in first out\n// kMemoryIncrementDescend = 103,    // large memory increment go first\n// kExceedTimeDescend = 104,         // large exceed time go first\nint64_t TopoStruct::GetDecidingParameter(StraightenOrder so) const {\n  int64_t sign = 1;\n  if (so >= kDiff4AscendDescend) {\n    so = StraightenOrder(int(so) - kDiff4AscendDescend);\n    sign = -1;\n  }\n  switch (so) {\n    case StraightenOrder::kTributaryLayerAscend: return sign * tributary_layer;\n    case StraightenOrder::kDistanceToOverlapAscend: return 0;\n    case StraightenOrder::kLayerAscend: return sign * min_layer;\n    case StraightenOrder::kMemoryIncrementAscend: return sign * memory_increment;\n    case StraightenOrder::kExceedTimeAscend: return sign * exceed_time;\n    default: return 0;\n  }\n}\n\n// Exceed time = time of cpu - time of gpu\nvoid TopoStruct::ComputeExceedTime() {\n  if (ShortGpuTime(op_node->op().op_conf())) {\n    exceed_time = 1;\n  } else {\n    exceed_time = 0;\n  }\n}\n\n// Compute the minimum layer of this node\nint32_t TopoStruct::ComputeMinLayer() {\n  if (min_layer >= 0) { return min_layer; }\n  for (auto& in_topo_struct : in_topo_structs) {\n    min_layer = std::max(min_layer, in_topo_struct->ComputeMinLayer());\n  }\n  return ++min_layer;\n}\n\n// Compute the maximum layer of this node\nvoid TopoStruct::ComputeMaxLayer(int32_t max_min_layer) {\n  // Execute those optimizer as soon as possible to release the register of weight_diff\n  if (out_topo_structs.empty()) {\n    max_layer = min_layer;\n    return;\n  }\n  max_layer = max_min_layer;\n  for (auto& out_topo_struct 
: out_topo_structs) {\n    if (max_layer > out_topo_struct->min_layer) { max_layer = out_topo_struct->min_layer; }\n  }\n  --max_layer;\n}\n\n// Compute the tributary layer\nint32_t TopoStruct::ComputeTributaryLayer(int32_t max_min_layer) {\n  if (tributary_layer >= 0) { return tributary_layer; }\n  tributary_layer = max_min_layer;\n  for (auto& out_topo_struct : out_topo_structs) {\n    if (tributary_layer > out_topo_struct->ComputeTributaryLayer(max_min_layer)) {\n      tributary_layer = out_topo_struct->tributary_layer;\n    }\n  }\n  return --tributary_layer;\n}\n\nvoid TopoStruct::ComputeIsReusable() { is_reusable = IsProducedRegisterReusable(op_node->op()); }\n\n// Compute the memory increment for all the topological structures\nvoid ComputeAllMemoryIncrement(std::vector<TopoStruct*>& topo_structs,\n                               HashMap<LogicalBlobId, int32_t>& lbi2id,\n                               std::vector<std::vector<TopoStruct*>>& id2consumer_topo_structs,\n                               std::vector<int64_t>& id2blob_size) {\n  // Compute the memory increment for produced blobs\n  for (auto& topo_struct : topo_structs) {\n    topo_struct->memory_increment = 0;\n    const auto& curr_operator = topo_struct->op_node->op();\n    if (topo_struct->is_reusable) {\n      for (const auto& obn : curr_operator.output_bns()) {\n        const LogicalBlobId& lbi = curr_operator.BnInOp2Lbi(obn);\n        auto it = lbi2id.find(lbi);\n        if (it == lbi2id.end()) {\n          // There exist some blobs that do not have any consumer\n          // Such as: op name:\n          // model.cls_head.loss_func.lm_loss-sparse_softmax_cross_entropy_ms-231-split_softmax_reduce_max_global_stage\n          // blob name: mask_0\n          const BlobDesc& logical_blob_desc = topo_struct->op_node->LogicalBlobDesc4Lbi(lbi);\n          lbi2id[lbi] = id2blob_size.size();\n          id2blob_size.push_back(TotalByteSize4BlobDesc(logical_blob_desc));\n          // There is some inconsistency between id2blob_size and id2consumer_topo_structs.\n          // We deal with that at the end to avoid division by 0.\n          topo_struct->memory_increment += id2blob_size.back();\n        } else {\n          topo_struct->memory_increment += id2blob_size[it->second];\n        }\n      }\n    }\n  }\n  // Subtract the consumed memory\n  for (int32_t index = 0; index < id2consumer_topo_structs.size(); index++) {\n    int64_t memory_decrease = id2blob_size[index] / id2consumer_topo_structs[index].size();\n    for (auto& consumer_topo_struct : id2consumer_topo_structs[index]) {\n      consumer_topo_struct->memory_increment -= memory_decrease;\n    }\n  }\n  // Add empty vectors for all those blobs without consumers\n  id2consumer_topo_structs.resize(id2blob_size.size());\n}\n\nvoid UpdateSat(const std::vector<TopoStruct*>& topo_structs, StraightenAlgorithmTag* sat) {\n  *sat = GlobalJobDesc().job_conf().straighten_algorithm_tag_in_task_graph();\n  if (*sat == StraightenAlgorithmTag::kOverlap4CpuGpu) {\n    // Without cpu nodes, the overlap strategy between cpu and gpu might consume large memory\n    bool exist_cpu_nodes = false;\n    for (const auto& topo_struct : topo_structs) {\n      // Found a cpu node\n      if (topo_struct->exceed_time == 1) {\n        exist_cpu_nodes = true;\n        break;\n      }\n    }\n    if (!exist_cpu_nodes) {\n      // Switch to the compress memory strategy, the default one,\n      // since the overlap strategy for transfer might not work on 1n1d.\n      *sat = 
StraightenAlgorithmTag::kCompressMemory;\n    }\n  }\n}\n\nvoid InitInOutTopoStructs(std::vector<TopoStruct*>* topo_structs) {\n  // Generate the map from operator names to topological structure\n  HashMap<std::string, TopoStruct*> op_name2topo_structs;\n  for (auto& topo_struct : *topo_structs) {\n    op_name2topo_structs[topo_struct->op_node->op().op_name()] = topo_struct;\n  }\n\n  // Traverse the topological structures\n  for (auto& this_topo_struct : *topo_structs) {\n    auto& node = this_topo_struct->op_node;\n    // Initialize input nodes for edges with data\n    node->ForEachNodeOnInEdge([&](OpNode* in) {\n      // Since we might be looking at a sub-graph of the operator graph.\n      // We need to check if the op_node exists in the sub-graph.\n      auto it = op_name2topo_structs.find(in->op().op_name());\n      if (it != op_name2topo_structs.end()) {\n        this_topo_struct->in_topo_structs.insert(it->second);\n        it->second->out_topo_structs.insert(this_topo_struct);\n      }\n    });\n    // Initialize input nodes for control edges\n    for (const auto& ctrl_in_op_name : node->op().op_conf().ctrl_in_op_name()) {\n      auto it = op_name2topo_structs.find(ctrl_in_op_name);\n      if (it != op_name2topo_structs.end()) {\n        auto& ctrl_in_topo_struct = it->second;\n        this_topo_struct->in_topo_structs.insert(ctrl_in_topo_struct);\n        // Initialize output nodes for this control edge simultaneously\n        ctrl_in_topo_struct->out_topo_structs.insert(this_topo_struct);\n      }\n    }\n  }\n}\n\nvoid ComputeLayer(std::vector<TopoStruct*>* topo_structs) {\n  int32_t max_min_layer = -1;\n  // Compute the minimum layer for the whole graph\n  for (auto& topo_struct : *topo_structs) {\n    if (max_min_layer < topo_struct->ComputeMinLayer()) { max_min_layer = topo_struct->min_layer; }\n  }\n  max_min_layer++;\n  // Compute the maximum layer for the whole graph\n  for (auto& topo_struct : *topo_structs) { topo_struct->ComputeMaxLayer(max_min_layer); }\n  // Compute the tributary layer\n  for (auto& topo_struct : *topo_structs) { topo_struct->ComputeTributaryLayer(max_min_layer); }\n}\n\nvoid InitAllParameters(std::vector<TopoStruct*>* topo_structs,\n                       HashMap<LogicalBlobId, int32_t>* lbi2id,\n                       std::vector<std::vector<TopoStruct*>>* id2consumer_topo_structs,\n                       std::vector<int64_t>* id2blob_size) {\n  // Construct the map from a lbi to its id, consumers, blob size\n  for (auto& topo_struct : *topo_structs) {\n    const auto& consumer = topo_struct->op_node->op();\n    for (const auto& ibn : consumer.input_bns()) {\n      const LogicalBlobId& lbi = consumer.BnInOp2Lbi(ibn);\n      auto it = lbi2id->find(lbi);\n      if (it == lbi2id->end()) {\n        (*lbi2id)[lbi] = id2blob_size->size();\n        const BlobDesc& logical_blob_desc = topo_struct->op_node->LogicalBlobDesc4Lbi(lbi);\n        id2blob_size->push_back(TotalByteSize4BlobDesc(logical_blob_desc));\n        id2consumer_topo_structs->push_back({topo_struct});\n      } else {\n        id2consumer_topo_structs->at(it->second).push_back(topo_struct);\n      }\n    }\n  }\n\n  // Construct all the data edges and control edges\n  InitInOutTopoStructs(topo_structs);\n\n  // Compute the layers\n  ComputeLayer(topo_structs);\n\n  // Compute the memory increment for all the topological structures\n  ComputeAllMemoryIncrement(*topo_structs, *lbi2id, *id2consumer_topo_structs, *id2blob_size);\n\n  // Update sat, since sat might be changed in previous jobs\n  
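// NOTE: sat and decide_parameters are file-scope statics shared across jobs, so refresh both.\n  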
UpdateSat(*topo_structs, &sat);\n\n  // Decide which node should run first\n  InitDecideParameters(sat, &decide_parameters);\n  VLOG(3) << \"Straightening order in sbp graph: \";\n  for (int32_t decide_parameter : decide_parameters) { VLOG(3) << decide_parameter; }\n}\n\nvoid StraightenOpNodes(HashMap<const OpNode*, TopoStruct>& op_node2topo_struct,\n                       std::vector<TopoStruct*>* topo_structs,\n                       HashMap<LogicalBlobId, int32_t>* lbi2id,\n                       std::vector<std::vector<TopoStruct*>>* id2consumer_topo_structs,\n                       std::vector<int64_t>* id2blob_size,\n                       std::vector<TopoStruct*>* ordered_topo_structs) {\n  InitAllParameters(topo_structs, lbi2id, id2consumer_topo_structs, id2blob_size);\n\n  std::set<TopoStruct*, comp> waiting_list;\n\n  // Wait in the list\n  auto wait = [&](TopoStruct* topo_struct) { waiting_list.insert(topo_struct); };\n\n  // Initialization\n  for (auto& topo_struct : *topo_structs) {\n    topo_struct->counter = topo_struct->in_topo_structs.size();\n    if (topo_struct->counter == 0) { wait(topo_struct); }\n  }\n\n  // Finish execution\n  auto finish_execution = [&](TopoStruct* topo_struct) {\n    for (auto& out : topo_struct->out_topo_structs) {\n      out->counter--;\n      if (out->counter == 0) { wait(out); }\n    }\n  };\n\n  // Execute the first node in the waiting list\n  // Make sure to check that waiting list is not empty before execution\n  auto execute = [&]() {\n    auto first_topo_struct = *waiting_list.begin();\n    // Set the order of execution for sbp nodes\n    ordered_topo_structs->push_back(first_topo_struct);\n    waiting_list.erase(waiting_list.begin());\n    finish_execution(first_topo_struct);\n  };\n\n  // straightening\n  while (!waiting_list.empty()) { execute(); }\n}\n\n}  // anonymous namespace\n\n// Use two function\nvoid InitMemory(const OpGraph& op_graph, SbpGraph* sbp_graph, bool nccl_use_compute_stream) {\n  // Generate topological data structure for each sbp node\n  HashMap<const OpNode*, TopoStruct> op_node2topo_struct;\n  std::vector<TopoStruct*> topo_structs;\n  std::vector<TopoStruct*> ordered_topo_structs;\n\n  // Traverse all the nodes in the sbp graph\n  for (const auto& sbp_node : sbp_graph->GetNodeList()) {\n    auto* op_node = sbp_node->GetOperatorNode();\n    CHECK(op_node != nullptr)\n        << \"No proxy node allow at this status. 
InitMemory() should be run before sbp collector!\";\n    op_node2topo_struct.insert({op_node, TopoStruct(sbp_node)});\n    topo_structs.push_back(&op_node2topo_struct.at(op_node));\n  }\n\n  // Construct the map from a lbi to its id, consumers, blob size\n  HashMap<LogicalBlobId, int32_t> lbi2id;\n  std::vector<std::vector<TopoStruct*>> id2consumer_topo_structs;\n  std::vector<int64_t> id2blob_size;\n\n  StraightenOpNodes(op_node2topo_struct, &topo_structs, &lbi2id, &id2consumer_topo_structs,\n                    &id2blob_size, &ordered_topo_structs);\n\n  // Mark the memory support, which contains two parts:\n  // all the non-reusable memory and those blobs which are a part of the maximum reusable memory\n  int64_t max_reusable_memory = 0;\n  int64_t curr_reusable_memory = 0;\n  std::vector<int32_t> id2count(id2blob_size.size(), -1);\n  // Blobs born, increase count and memory\n  auto GenerateBlobs = [&](TopoStruct* topo_struct) {\n    const auto& curr_operator = topo_struct->op_node->op();\n    if (topo_struct->is_reusable) {\n      for (const auto& obn : curr_operator.output_bns()) {\n        const LogicalBlobId& lbi = curr_operator.BnInOp2Lbi(obn);\n        int32_t index = lbi2id.at(lbi);\n        // Reusable blobs born\n        curr_reusable_memory += id2blob_size[index];\n        id2count[index] = id2consumer_topo_structs[index].size();\n      }\n    }\n  };\n  // Blobs die, decrease count and memory\n  auto KillBlobs = [&](TopoStruct* topo_struct) {\n    const auto& curr_operator = topo_struct->op_node->op();\n    // Those reusable blobs that do not have a consumer would die immediately\n    // For example:\n    // register_num: 1, op_name:\n    // \"model.cls_head.loss_func.lm_loss-sparse_softmax_cross_entropy_ms-231-split_softmax_reduce_max_device_stage\",\n    // blob_name: \"mask_0\", shape { dim: 2048 dim: 21248 },\n    // data_type: kBool, time_shape { dim: 1 dim: 1 }, enable_reuse_mem: true,\n    // alloc_before_actor: 369, free_after_actor: 369\n    if (topo_struct->is_reusable) {\n      for (const auto& obn : curr_operator.output_bns()) {\n        const LogicalBlobId& lbi = curr_operator.BnInOp2Lbi(obn);\n        int32_t index = lbi2id.at(lbi);\n        // Does not have a consumer\n        if (id2count[index] == 0) {\n          // Reusable blobs die\n          curr_reusable_memory -= id2blob_size[index];\n        }\n      }\n    }\n    // Reduce the counter and kill the blobs if the count reaches 0\n    for (const auto& ibn : curr_operator.input_bns()) {\n      const LogicalBlobId& lbi = curr_operator.BnInOp2Lbi(ibn);\n      int32_t index = lbi2id.at(lbi);\n      if (id2count[index] > 0) {\n        --id2count[index];\n        if (id2count[index] == 0) {\n          // Reusable blobs die\n          curr_reusable_memory -= id2blob_size[index];\n        }\n      }\n    }\n  };\n  // Calculate the maximum reusable memory and mark the fixed memory\n  for (auto& topo_struct : ordered_topo_structs) {\n    // Blobs born, increase count and memory\n    GenerateBlobs(topo_struct);\n    // Record the maximum memory\n    if (curr_reusable_memory > max_reusable_memory) { max_reusable_memory = curr_reusable_memory; }\n    // Blobs die, decrease count and memory\n    KillBlobs(topo_struct);\n  }\n\n  // Make sure that every blob dies\n  CHECK_EQ(curr_reusable_memory, 0) << \" Have not killed all the reusable blobs!\";\n\n  // Mark the reusable memory which constitutes the maximum reusable memory\n  for (auto& topo_struct : ordered_topo_structs) {\n    // Blobs born, increase count and memory\n    
GenerateBlobs(topo_struct);\n    // Mark the first found support\n    if (curr_reusable_memory == max_reusable_memory) {\n      // Mark the temporary memory created by this operator\n      if (topo_struct->is_reusable) {\n        const auto& curr_operator = topo_struct->op_node->op();\n        for (const auto& obn : curr_operator.output_bns()) {\n          const LogicalBlobId& lbi = curr_operator.BnInOp2Lbi(obn);\n          int32_t index = lbi2id.at(lbi);\n          // We would use id2count != 0 to record the lbi support\n          // Those obn with no consumers have id2count[index] == 0, now it would be set to 1\n          id2count[index] = 1;\n        }\n      }\n      // The other lbi in the support would have a non-zero id2count\n      // No further process needed\n      break;\n    }\n    // Blobs die, decrease count and memory\n    KillBlobs(topo_struct);\n  }\n\n  // Initialize memory for each sbp node\n  for (auto& topo_struct : topo_structs) {\n    topo_struct->sbp_node->InitializeMemory(topo_struct->is_reusable, lbi2id, id2count,\n                                            nccl_use_compute_stream);\n  }\n}\n\n// Straighten a subset of the op graph\nvoid StraightenSubGraph(const std::vector<const OpNode*>& sub_graph,\n                        std::vector<const OpNode*>* ordered_op_nodes) {\n  // Generate topological data structure for each op node\n  HashMap<const OpNode*, TopoStruct> op_node2topo_struct;\n  std::vector<TopoStruct*> topo_structs;\n  std::vector<TopoStruct*> ordered_topo_structs;\n\n  // Traverse all the nodes in the sub graph\n  for (const auto& node : sub_graph) {\n    op_node2topo_struct.insert({node, TopoStruct(node)});\n    topo_structs.push_back(&op_node2topo_struct.at(node));\n  }\n\n  // Construct the map from a lbi to its id, consumers, blob size\n  HashMap<LogicalBlobId, int32_t> lbi2id;\n  std::vector<std::vector<TopoStruct*>> id2consumer_topo_structs;\n  std::vector<int64_t> id2blob_size;\n\n  StraightenOpNodes(op_node2topo_struct, &topo_structs, &lbi2id, &id2consumer_topo_structs,\n                    &id2blob_size, &ordered_topo_structs);\n\n  for (auto& ordered_topo_struct : ordered_topo_structs) {\n    ordered_op_nodes->push_back(ordered_topo_struct->op_node);\n  }\n}\n\n// Straighten the whole op graph\nvoid StraightenOpGraph(const OpGraph& op_graph, std::vector<const OpNode*>* ordered_op_nodes) {\n  std::vector<const OpNode*> sub_graph;\n\n  // Traverse and store all the nodes in the op graph\n  op_graph.ForEachNode([&](OpNode* node) { sub_graph.push_back(node); });\n\n  StraightenSubGraph(sub_graph, ordered_op_nodes);\n}\n\n}  // namespace auto_parallel\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/auto_parallel/auto_memory.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_AUTO_PARALLEL_AUTO_MEMORY_H_\n#define ONEFLOW_CORE_AUTO_PARALLEL_AUTO_MEMORY_H_\n\n#include \"oneflow/core/auto_parallel/sbp_graph.h\"\n#include \"oneflow/core/graph/op_graph.h\"\nnamespace oneflow {\n\nnamespace auto_parallel {\nvoid InitMemory(const OpGraph& op_graph, SbpGraph* sbp_graph, bool nccl_use_compute_stream);\n\n// Straighten a subset of the op graph\nvoid StraightenSubGraph(const std::vector<const OpNode*>& sub_graph,\n                        std::vector<const OpNode*>* ordered_op_nodes);\n\n// Straighten the whole op graph\nvoid StraightenOpGraph(const OpGraph& op_graph, std::vector<const OpNode*>* ordered_op_nodes);\n\n}  // namespace auto_parallel\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_AUTO_PARALLEL_AUTO_MEMORY_H_\n"
  },
  {
    "path": "oneflow/core/auto_parallel/binary_set.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/auto_parallel/binary_set.h\"\n\nnamespace oneflow {\nnamespace auto_parallel {\n\nnamespace {\n// A static function for initialization of log_2 mapping\nstd::unordered_map<BinarySetEntryType, int32_t> InitLog2() {\n  std::unordered_map<BinarySetEntryType, int32_t> log_2;\n  for (int32_t i = 0; i < 8 * sizeof(BinarySetEntryType); i++) {\n    log_2[static_cast<BinarySetEntryType>(1 << i)] = i;\n  }\n  return log_2;\n}\n\n// Initialization of log_2 mapping\n// Take log2 of a integer value: 2^n -> n.\nconst std::unordered_map<BinarySetEntryType, int32_t> log_2 = InitLog2();\n\n}  // namespace\n\n// Constructor\nBinarySet::BinarySet(int32_t size_of_set) : size_of_set_(size_of_set) {\n  int32_t k = (size_of_set - 1) / bit_entry_type_ + 1;\n  binary_set_values_.resize(k, 0);\n}\n\n// Initialization if needed\nvoid BinarySet::Initialize(int32_t size_of_set) {\n  size_of_set_ = size_of_set;\n  int32_t k = (size_of_set - 1) / bit_entry_type_ + 1;\n  binary_set_values_.resize(k, 0);\n}\n\n// Clear all the elements in the set\nvoid BinarySet::Clear() { binary_set_values_.assign(binary_set_values_.size(), 0); }\n\n// Check if i-th element in this subset\nbool BinarySet::CheckExistence(int32_t i) const {\n  int32_t k = i / bit_entry_type_;\n  int32_t j = i % bit_entry_type_;\n  return bool((binary_set_values_[k] >> j) & 1);\n}\n\n// Add i-th element into this subset\nvoid BinarySet::AddEntry(int32_t i) {\n  int32_t k = i / bit_entry_type_;\n  int32_t j = i % bit_entry_type_;\n  binary_set_values_[k] |= (1 << j);\n}\n// Take i-th element out from this subset\nvoid BinarySet::DeleteEntry(int32_t i) {\n  int32_t k = i / bit_entry_type_;\n  int32_t j = i % bit_entry_type_;\n  binary_set_values_[k] &= ~(1 << j);\n}\n// Get the union with another subset and store it into u\nvoid BinarySet::UnionTo(const BinarySet& bs, BinarySet& u) {\n  for (int32_t k = 0; k < binary_set_values_.size(); k++) {\n    u.binary_set_values_[k] = binary_set_values_[k] | bs.binary_set_values_[k];\n  }\n}\n// If this binary set intersects another one\nbool BinarySet::IfIntersect(const BinarySet& bs) const {\n  int32_t min_bs_size = std::min(binary_set_values_.size(), bs.binary_set_values_.size());\n  for (int32_t k = 0; k < min_bs_size; k++) {\n    if (binary_set_values_[k] & bs.binary_set_values_[k]) { return true; }\n  }\n  return false;\n}\n// Get the intersection with another subset and store it into i\nvoid BinarySet::IntersectionTo(const BinarySet& bs, BinarySet& i) const {\n  int32_t min_bs_size = std::min(binary_set_values_.size(), bs.binary_set_values_.size());\n  if (min_bs_size > i.binary_set_values_.size()) { i.binary_set_values_.resize(min_bs_size, 0); }\n  for (int32_t k = 0; k < binary_set_values_.size(); k++) {\n    i.binary_set_values_[k] = binary_set_values_[k] & bs.binary_set_values_[k];\n  }\n}\n// Count number of elements in this subset\nint32_t BinarySet::Total() const {\n  
int32_t t = 0;\n  for (int32_t k = 0; k < binary_set_values_.size(); k++) {\n    BinarySetEntryType bsv = binary_set_values_[k];\n    bsv = (bsv & 0x5555555555555555) + ((bsv >> 1) & 0x5555555555555555);\n    bsv = (bsv & 0x3333333333333333) + ((bsv >> 2) & 0x3333333333333333);\n    bsv = (bsv & 0x0F0F0F0F0F0F0F0F) + ((bsv >> 4) & 0x0F0F0F0F0F0F0F0F);\n    bsv = (bsv & 0x00FF00FF00FF00FF) + ((bsv >> 8) & 0x00FF00FF00FF00FF);\n    bsv = (bsv & 0x0000FFFF0000FFFF) + ((bsv >> 16) & 0x0000FFFF0000FFFF);\n    // bsv = (bsv & 0x00000000FFFFFFFF) + ((bsv >> 32) & 0x00000000FFFFFFFF);\n    t += int32_t(bsv);\n  }\n  return t;\n}\n\n// Output all the elements in the subset\nvoid BinarySet::Output(std::vector<int32_t>& out) const {\n  out.clear();\n  for (int32_t i = 0; i < size_of_set_; i++) {\n    if (CheckExistence(i)) { out.emplace_back(i); }\n  }\n}\n\n// Output all the elements in the subset\nvoid BinarySet::QuickOutput(std::vector<int32_t>& out) const {\n  out.clear();\n  for (int32_t i = 0; i < binary_set_values_.size(); i++) {\n    BinarySetEntryType x = binary_set_values_[i];\n    BinarySetEntryType y = 0;\n    while (x) {\n      y = x;\n      x &= x - 1;\n      out.emplace_back(i * BinarySet::bit_entry_type_ + log_2.find(y - x)->second);\n    }\n  }\n}\n\n// Add elements of input into this subset\nvoid BinarySet::AddEntries(std::vector<int32_t>& in) {\n  for (int32_t i : in) { AddEntry(i); }\n}\n\n// If two binary sets are equal to each other\nbool BinarySet::operator==(const BinarySet& rhs) const {\n  if (size_of_set_ != rhs.size_of_set_) { return false; }\n  for (int32_t i = 0; i < binary_set_values_.size(); i++) {\n    if (binary_set_values_[i] != rhs.binary_set_values_[i]) { return false; }\n  }\n  return true;\n}\n\n}  // namespace auto_parallel\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/auto_parallel/binary_set.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_AUTO_PARALLEL_BINARY_SET_H_\n#define ONEFLOW_CORE_AUTO_PARALLEL_BINARY_SET_H_\n\n#include <cstdlib>\n#include <unordered_map>\n#include <vector>\n#include \"oneflow/core/common/hash.h\"\n\nnamespace oneflow {\nnamespace auto_parallel {\n\n// log_2_ index only support 32-bit int. Don't know why.\n// Don't have any other bugs for unsigned int.\nusing BinarySetEntryType = unsigned int;\n\nclass BinarySet {\n public:\n  BinarySet() {}\n  explicit BinarySet(int32_t size_of_set);\n\n  // Initialization\n  void Initialize(int32_t size_of_set);\n  // Clear all the elements in the set\n  void Clear();\n  // Check if i-th element in this subset\n  bool CheckExistence(int32_t i) const;\n  // Add i-th element into this subset\n  void AddEntry(int32_t i);\n  // Take i-th element out from this subset\n  void DeleteEntry(int32_t i);\n  // Get the union with another subset and store it into u\n  void UnionTo(const BinarySet& bs, BinarySet& u);\n  // If this binary set intersects another one\n  bool IfIntersect(const BinarySet& bs) const;\n  // Get the intersection with another subset and store it into i\n  void IntersectionTo(const BinarySet& bs, BinarySet& i) const;\n  // Count number of elements in this subset\n  int32_t Total() const;\n  // Output all the elements in the subset\n  void Output(std::vector<int32_t>& out) const;\n  // Output all the elements in the subset\n  void QuickOutput(std::vector<int32_t>& out) const;\n  // Add elements of input into this subset\n  void AddEntries(std::vector<int32_t>& in);\n  // If two binary sets are equal to each other\n  bool operator==(const BinarySet& rhs) const;\n\n  inline int32_t GetSizeOfSet() const { return size_of_set_; };\n\n private:\n  friend struct BinarySetHasher;\n  // binary_set_values_ contains a vector of 64-bit or 32-bit int.\n  // Each bit means whether an entry is in the set\n  std::vector<BinarySetEntryType> binary_set_values_;\n\n  int32_t size_of_set_ = -1;\n\n  // total bits of the entry type in vector binary_set_values_.\n  static constexpr int32_t bit_entry_type_ = 8 * sizeof(BinarySetEntryType);\n};\n\nstruct BinarySetHasher {\n  std::size_t operator()(const BinarySet& bs) const {\n    using std::hash;\n    using std::size_t;\n\n    size_t h = 0;\n    for (int i = 0; i < bs.binary_set_values_.size(); i++) {\n      h = HashCombine(h, hash<BinarySetEntryType>()(bs.binary_set_values_[i]));\n    }\n    return h;\n  };\n};\n\n}  // namespace auto_parallel\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_AUTO_PARALLEL_BINARY_SET_H_\n"
  },
  {
    "path": "oneflow/core/auto_parallel/boxing_collector.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <memory>\n#include <string>\n#include \"oneflow/core/auto_parallel/algorithm_util.h\"\n#include \"oneflow/core/auto_parallel/boxing_collector.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/device_type.pb.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/job/sbp_parallel.pb.h\"\n#include \"oneflow/core/register/blob_desc.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstatic bool disable_middle_node = false;\n\nvoid DfsSetNdSbp(const std::vector<SbpParallel>& id2sbp_parallel, int32_t depth, int32_t max_depth,\n                 NdSbp& nd_sbp, std::vector<NdSbp>& nd_sbp_lists,\n                 std::unordered_map<NdSbp, int32_t>& nd_sbp_universe) {\n  if (depth == max_depth) {\n    nd_sbp_universe[nd_sbp] = nd_sbp_lists.size();\n    nd_sbp_lists.push_back(nd_sbp);\n  } else {\n    for (const auto& sbp_parallel : id2sbp_parallel) {\n      *nd_sbp.mutable_sbp_parallel(depth) = sbp_parallel;\n      DfsSetNdSbp(id2sbp_parallel, depth + 1, max_depth, nd_sbp, nd_sbp_lists, nd_sbp_universe);\n    }\n  }\n}\n\n// Let a nd sbp be consistent with the given hierarchy number\nMaybe<NdSbp> SetNdSbpDim(const NdSbp& nd_sbp, int32_t hierarchy_num) {\n  // Do not need to change\n  if (nd_sbp.sbp_parallel_size() == hierarchy_num) { return nd_sbp; }\n  // (S0, S0) -> S0\n  if (hierarchy_num == 1) {\n    CHECK_OR_RETURN(Is1dSbp(nd_sbp))\n        << NdSbpToString(nd_sbp) << \" can not be converted to a 1d sbp!\";\n    NdSbp new_sbp;\n    new_sbp.add_sbp_parallel();\n    *new_sbp.mutable_sbp_parallel(0) = nd_sbp.sbp_parallel(0);\n    return new_sbp;\n  }\n  // S0 -> (S0, S0)\n  CHECK_EQ_OR_RETURN(nd_sbp.sbp_parallel_size(), 1) << \"Illegal nd sbp transform.\";\n  NdSbp new_sbp;\n  for (int32_t i = 0; i < hierarchy_num; i++) {\n    new_sbp.add_sbp_parallel();\n    *new_sbp.mutable_sbp_parallel(i) = nd_sbp.sbp_parallel(0);\n  }\n  return new_sbp;\n}\n\nint32_t TotalNumSplit(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc) {\n  int32_t total_num_split = 1;\n  for (int32_t i = 0; i < nd_sbp.sbp_parallel_size(); i++) {\n    if (nd_sbp.sbp_parallel(i).has_split_parallel()) {\n      total_num_split *= parallel_desc.hierarchy()->At(i);\n    }\n  }\n  return total_num_split;\n}\n\n// Dealing with 1D sbp to 1D sbp\n// Specifically, S -> P.\nMaybe<void> AskSbpCombinationFor1DSbp(const NdSbp& sbp_producer, const NdSbp& sbp_consumer,\n                                      const ParallelDesc& producer_parallel_desc,\n       
                               const ParallelDesc& consumer_parallel_desc,\n                                      std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos) {\n  if (sbp_consumer.sbp_parallel(0).has_partial_sum_parallel()) {\n    // Support [4]: P <--> [2, 2]: (P, P)\n    // Support {0, 1, 2, 3}: P <--> {2, 0, 6, 7}: (P, P)\n    if (producer_parallel_desc.parallel_num() == consumer_parallel_desc.parallel_num()\n        && sbp_producer.sbp_parallel(0).has_partial_sum_parallel()) {\n      return Maybe<void>::Ok();\n    }\n\n    if (!sbp_producer.sbp_parallel(0).has_broadcast_parallel()) {\n      // S -> B -> P (Large cost!)\n      // TODO: Please implement S -> P directly.\n      // We do not support [3]: P <--> [2, 2]: (P, P) as well.\n\n      int32_t hierarchy_size = 0;\n      if (producer_parallel_desc.hierarchy()->elem_cnt()\n          < consumer_parallel_desc.hierarchy()->elem_cnt()) {\n        // The diagonal node uses the parallel description from producer\n        // (S, S) -> (B, B) -> P/(P, P) or S -> B -> P/(P, P)\n        *diag_node_pos = 1;\n        hierarchy_size = producer_parallel_desc.hierarchy()->NumAxes();\n      } else {\n        // The diagonal node uses the parallel description from consumer\n        // S/(S, S) -> B -> P or S/(S, S) -> (B, B) -> (P, P)\n        *diag_node_pos = 0;\n        hierarchy_size = consumer_parallel_desc.hierarchy()->NumAxes();\n      }\n\n      NdSbp broadcast_nd;\n      for (int32_t i = 0; i < hierarchy_size; i++) {\n        broadcast_nd.add_sbp_parallel();\n        broadcast_nd.mutable_sbp_parallel(i)->mutable_broadcast_parallel();\n      }\n      middle_sbps.emplace_back(broadcast_nd);\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n// A constructor with init, designed for pre-stored boxing collector\nBoxingCollector::BoxingCollector(int32_t max_axis) { CHECK_JUST(Init(max_axis)); }\n\n// Construct a boxing collector with given maximum number of axis\nMaybe<void> BoxingCollector::Init(int32_t max_axis) {\n  // Update environment parameter\n  disable_middle_node = ParseBooleanFromEnv(\"ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK\", false);\n  // Not allowed two-step boxing and disable checking for debugging\n  if (disable_middle_node) { return Maybe<void>::Ok(); }\n  // Set up at least two split for op graph.\n  // For a negative example: Resnet50 only have B, P, S(0)\n  CollectUniverse(max_axis);\n  GenerateNdSbpList(2);\n  GenerateMap1d2nd();\n  // Get copy cost in lazy mode\n  LazyMode::Guard enable_lazy_mode(true);\n  JUST(GenerateCombination4SamePlacement(3));\n  JUST(GenerateCombination4DiffHierarchy(this, this));\n  JUST(GenerateCombination4DiffPlacement(this, this));\n  init_type_ = int32_t(enable_general_basic_communication\n                       || Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream());\n  return Maybe<void>::Ok();\n}\n\n// Customized initialization with given blob and parallel description\nMaybe<void> BoxingCollector::Init(const BlobDesc& logical_blob_desc,\n                                  const ParallelDesc& parallel_desc) {\n  CollectUniverse(logical_blob_desc.shape().NumAxes());\n  GenerateNdSbpList(parallel_desc.hierarchy()->NumAxes());\n  // Filter out unsuitable middle nodes before computing minimum cost.\n  JUST(FilterNdSbpList4LogicalShape(logical_blob_desc, *parallel_desc.hierarchy()));\n  GenerateMap1d2nd();\n  // Get copy cost in lazy mode\n  LazyMode::Guard enable_lazy_mode(true);\n  JUST(GenerateCombination4SamePlacement(5, logical_blob_desc, 
parallel_desc));\n  init_type_ = int32_t(enable_general_basic_communication\n                       || Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream());\n  return Maybe<void>::Ok();\n}\n\n// Collect Sbp Parallel\nvoid BoxingCollector::CollectUniverse(const SbpParallel& sbp) {\n  if (sbp_parallel_universe_.find(sbp) == sbp_parallel_universe_.end()) {\n    int32_t curr_size = sbp_parallel_universe_.size();\n    sbp_parallel_universe_[sbp] = curr_size;\n    id2sbp_parallel_.push_back(sbp);\n  }\n}\n\n// Find corresponding id for Nd sbp\nint32_t BoxingCollector::FindId4NdSbp(const NdSbp& nd_sbp) {\n  // Directly search on the nd_sbp_list\n  if (nd_sbp.sbp_parallel_size() == hierarchy_num_) {\n    const auto& it_nd_sbp = nd_sbp_universe_.find(nd_sbp);\n    if (it_nd_sbp != nd_sbp_universe_.end()) {\n      return it_nd_sbp->second;\n    } else {\n      return -1;\n    }\n  }\n\n  // Find the diagonal node if it could be converted to a 1D sbp\n  if (Is1dSbp(nd_sbp)) {\n    const auto& it_nd_sbp = sbp_parallel_universe_.find(nd_sbp.sbp_parallel(0));\n    if (it_nd_sbp != sbp_parallel_universe_.end()) { return id_1d_2_nd_[it_nd_sbp->second]; }\n  }\n\n  // Can not be converted to a 1D sbp or not found in the 1D sbp list\n  return -1;\n}\n\n// Set default Sbp list\nvoid BoxingCollector::CollectUniverse(int32_t max_axis) {\n  SbpParallel sbp;\n  sbp.mutable_broadcast_parallel();\n  CollectUniverse(sbp);\n  for (int32_t axis = 0; axis < max_axis; axis++) {\n    sbp.mutable_split_parallel()->set_axis(axis);\n    CollectUniverse(sbp);\n  }\n  sbp.mutable_partial_sum_parallel();\n  CollectUniverse(sbp);\n}\n\n// Generate nd sbp list\nvoid BoxingCollector::GenerateNdSbpList(int32_t hierarchy_num) {\n  // 1D sbp does not support S->P. But it seems that we do not need to deal with it for now.\n  // And we do not have 3D sbp or higher dimension.\n  hierarchy_num_ = hierarchy_num;\n\n  // Generate possible nd_sbp lists\n  NdSbp nd_sbp;\n  for (int32_t dim_sbp = 0; dim_sbp < hierarchy_num; dim_sbp++) { nd_sbp.add_sbp_parallel(); }\n  DfsSetNdSbp(id2sbp_parallel_, 0, hierarchy_num, nd_sbp, nd_sbp_lists_, nd_sbp_universe_);\n}\n\n// Generate the map from 1d sbp to 2d sbp\nvoid BoxingCollector::GenerateMap1d2nd() {\n  // Number of 1d sbp\n  int32_t m = id2sbp_parallel_.size();\n\n  // Generate the id Map from 1d sbp to nd sbp\n  NdSbp nd_sbp;\n  for (int32_t dim_sbp = 0; dim_sbp < hierarchy_num_; dim_sbp++) { nd_sbp.add_sbp_parallel(); }\n  id_1d_2_nd_.clear();\n  id_1d_2_nd_.resize(m, -1);\n  for (int32_t id_1d = 0; id_1d < m; id_1d++) {\n    for (int32_t dim_sbp = 0; dim_sbp < hierarchy_num_; dim_sbp++) {\n      *nd_sbp.mutable_sbp_parallel(dim_sbp) = id2sbp_parallel_[id_1d];\n    }\n    // NOTE: The 2d sbp might be filtered out already.\n    const auto& it_ = nd_sbp_universe_.find(nd_sbp);\n    if (it_ != nd_sbp_universe_.end()) { id_1d_2_nd_[id_1d] = it_->second; }\n  }\n}\n\n// Generate the transfer rule for different combinations with the same hierarchy\nMaybe<void> BoxingCollector::GenerateCombination4SamePlacement(int32_t max_middle_node_num) {\n  // other parameters\n  // NOTE: The performance of this function are all the same with different hierarchy\n  int32_t world_size = GlobalProcessCtx::WorldSize();\n  Shape hierarchy44({4 * world_size, 4 * world_size});\n  int32_t virtual_range_size = hierarchy44.elem_cnt();\n  std::shared_ptr<Shape> virtual_hierarchy = std::make_shared<Shape>(hierarchy44);\n  auto parallel_desc = JUST(ParallelDesc::New(\n      \"cpu\", {\"0:0-\" + 
std::to_string(hierarchy44.elem_cnt() - 1)}, virtual_hierarchy));\n  BlobDesc blob_desc({virtual_range_size, virtual_range_size, virtual_range_size,\n                      virtual_range_size, virtual_range_size, virtual_range_size},\n                     DataType::kInt8, MemoryFormat::kContiguous, /*is_dynamic=*/false);\n  JUST(GenerateCombination4SamePlacement(max_middle_node_num, blob_desc, *parallel_desc));\n  return Maybe<void>::Ok();\n}\n\n// Generate the transfer rule for different combinations with the same hierarchy\nMaybe<void> BoxingCollector::GenerateCombination4SamePlacement(int32_t max_middle_node_num,\n                                                               const BlobDesc& blob_desc,\n                                                               const ParallelDesc& parallel_desc) {\n  // Store the original transfer cost information\n  int32_t n = nd_sbp_lists_.size();\n  minimum_copy_cost_.clear();\n  minimum_copy_cost_.resize(n);\n  middle_nodes_.clear();\n  middle_nodes_.resize(n);\n  for (int32_t i = 0; i < n; i++) {\n    minimum_copy_cost_[i].resize(n);\n    middle_nodes_[i].resize(n);\n    for (int32_t j = 0; j < n; j++) {\n      minimum_copy_cost_[i][j] = JUST(ComputeLazyCopyCostBetweenNdSbp(\n          nd_sbp_lists_[i], nd_sbp_lists_[j], blob_desc, parallel_desc, parallel_desc,\n          /*requires_same_sbp=*/false));\n    }\n  }\n\n  auto NotMiddleNode = [&](int32_t i, int32_t j, int32_t k, int32_t middle_node_num_ik) -> bool {\n    // Do not allow i -> i -> j or i -> j -> j.\n    if (k == j || k == i) { return true; }\n    // We add middle nodes one by one.\n    // Thus, we allow multiple nodes from i to k, but we only accept one step from k to j.\n    // i -> ? -> k -> j\n    if (middle_nodes_[k][j].size() > 0) { return true; }\n    // To avoid multiple counting and bugs, the number of middle nodes between i and k\n    // must be exactly middle_node_num_ik, which is (middle_node_num - 1)\n    if (middle_node_num_ik) {\n      if (middle_nodes_[i][k].size() == 0 || middle_nodes_[i][k][0].size() != middle_node_num_ik) {\n        return true;\n      }\n    } else {\n      if (middle_nodes_[i][k].size() > 0) { return true; }\n    }\n    return false;\n  };\n\n  
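// For example, when middle_node_num == 2 (so middle_node_num_ik == 1), we only extend paths\n  // i -> k that already contain exactly one middle node by a single direct hop k -> j, so every\n  // candidate path is discovered exactly once.\n  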
for (int32_t middle_node_num = 1; middle_node_num <= max_middle_node_num; middle_node_num++) {\n    int32_t middle_node_num_ik = middle_node_num - 1;\n\n    for (int32_t i = 0; i < n; i++) {\n      for (int32_t j = 0; j < n; j++) {\n        if (minimum_copy_cost_[i][j] < GetValidMaxCopyCost()) { continue; }\n        // Compute the smallest transfer cost\n        // k is the middle node, i -> k -> j\n        for (int32_t k = 0; k < n; k++) {\n          if (NotMiddleNode(i, j, k, middle_node_num_ik)) { continue; }\n          double curr_copy_cost = minimum_copy_cost_[i][k] + minimum_copy_cost_[k][j];\n          if (curr_copy_cost < minimum_copy_cost_[i][j]) {\n            minimum_copy_cost_[i][j] = curr_copy_cost;\n          }\n        }\n        // If the minimum copy cost remains infinite, adding one more middle node does not help.\n        if (minimum_copy_cost_[i][j] > GetValidMaxCopyCost()) { continue; }\n        // Find those middle nodes\n        for (int32_t k = 0; k < n; k++) {\n          if (NotMiddleNode(i, j, k, middle_node_num_ik)) { continue; }\n          // Now we judge whether the edge has the minimum cost.\n          // It needs to be \"<=\" since we have 0 cost.\n          // Using \"<\" would give no middle nodes from (B, B) to any other nd sbp.\n          if (minimum_copy_cost_[i][k] + minimum_copy_cost_[k][j]\n              <= minimum_copy_cost_[i][j] * 1.0000001) {\n            // i -> ? -> k\n            if (middle_nodes_[i][k].size() > 0) {\n              // We have multiple choices going from i to k\n              for (const auto& middle_node_ik : middle_nodes_[i][k]) {\n                middle_nodes_[i][j].push_back(middle_node_ik);\n                middle_nodes_[i][j][middle_nodes_[i][j].size() - 1].push_back(k);\n              }\n            } else {\n              // We only need one middle node k to reach j from i\n              middle_nodes_[i][j].push_back({k});\n            }\n          }\n        }\n        CHECK_OR_RETURN(middle_nodes_[i][j].size() > 0)\n            << \"No middle nodes found from \" << NdSbpToString(nd_sbp_lists_[i]) << \" to \"\n            << NdSbpToString(nd_sbp_lists_[j]) << \" in the boxing collector\";\n      }\n    }\n  }\n\n  return Maybe<void>::Ok();\n}\n\n// Generate the transfer rule for different combinations with different hierarchies on the same\n// placement\nMaybe<void> BoxingCollector::GenerateCombination4DiffHierarchy(\n    BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer) {\n  // Search the path that contains one of the diagonal sbps\n  int32_t n = nd_sbp_lists_.size();\n  diag_node_diff_hierarchy_.clear();\n  diag_node_diff_hierarchy_.resize(n);\n  for (int32_t i = 0; i < n; i++) {\n    diag_node_diff_hierarchy_[i].resize(n);\n    for (int32_t j = 0; j < n; j++) {\n      JUST(Generate1Combination4DiffHierarchy(i, j, boxing_collector_producer,\n                                              boxing_collector_consumer,\n                                              diag_node_diff_hierarchy_[i][j]));\n    }\n  }\n\n  return Maybe<void>::Ok();\n}\n\n// Generate the transfer rule for different combinations with different placements\nMaybe<void> BoxingCollector::GenerateCombination4DiffPlacement(\n    BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer) {\n  // Virtual parallel and blob description\n  int32_t world_size = GlobalProcessCtx::WorldSize();\n  int32_t virtual_range_size = 4 * world_size * (4 * world_size + 1);\n  BlobDesc blob_desc({virtual_range_size, virtual_range_size, virtual_range_size,\n                      virtual_range_size, virtual_range_size, virtual_range_size},\n                     DataType::kInt8, MemoryFormat::kContiguous, /*is_dynamic=*/false);\n  // Virtual placements before transfer\n  Shape in_hierarchy44({4 * world_size + 1, 4 * world_size});\n  std::shared_ptr<Shape> in_hierarchy = std::make_shared<Shape>(in_hierarchy44);\n  auto in_parallel_desc = JUST(ParallelDesc::New(\n      \"cpu\", {\"0:0-\" + std::to_string(in_hierarchy44.elem_cnt() - 1)}, in_hierarchy));\n  // Virtual placements after transfer\n  Shape out_hierarchy44({4 * world_size, 4 * world_size});\n  std::shared_ptr<Shape> out_hierarchy = std::make_shared<Shape>(out_hierarchy44);\n  auto out_parallel_desc = JUST(ParallelDesc::New(\n      \"cpu\", {\"0:0-\" + std::to_string(out_hierarchy44.elem_cnt() - 1)}, out_hierarchy));\n\n  JUST(GenerateCombination4DiffPlacement(boxing_collector_producer, boxing_collector_consumer,\n                                         blob_desc, *in_parallel_desc, *out_parallel_desc));\n  return Maybe<void>::Ok();\n}\n\n// The cost for transferring a 1D sbp between different placements\nMaybe<void> BoxingCollector::ComputeCostFor1DSbpDiffPlacement(\n    const BlobDesc& blob_desc, const 
ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc,\n    std::vector<std::vector<double>>& cost_4_diff_placement) {\n  // Number of 1d sbps\n  int32_t m = id2sbp_parallel_.size();\n  // Compute the cost while transferring a 1D sbp between different placements\n  cost_4_diff_placement.clear();\n  cost_4_diff_placement.resize(m);\n  for (int32_t id_1d_producer = 0; id_1d_producer < m; id_1d_producer++) {\n    cost_4_diff_placement[id_1d_producer].resize(m, GetMaxVal<float>());\n    int32_t diag_producer = id_1d_2_nd_[id_1d_producer];\n    if (diag_producer < 0) { continue; }\n\n    for (int32_t id_1d_consumer = 0; id_1d_consumer < m; id_1d_consumer++) {\n      int32_t diag_consumer = id_1d_2_nd_[id_1d_consumer];\n      if (diag_consumer < 0) { continue; }\n      cost_4_diff_placement[id_1d_producer][id_1d_consumer] = JUST(ComputeLazyCopyCostBetweenNdSbp(\n          nd_sbp_lists_[diag_producer], nd_sbp_lists_[diag_consumer], blob_desc, in_parallel_desc,\n          out_parallel_desc, false));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n// Generate the transfer rule for different combinations with different placements\nMaybe<void> BoxingCollector::GenerateCombination4DiffPlacement(\n    BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer,\n    const BlobDesc& blob_desc, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc) {\n  // The cost for transferring a 1D sbp between different placements\n  std::vector<std::vector<double>> cost_4_diff_placement;\n  // Compute the cost while transferring a 1D sbp between different placements\n  JUST(ComputeCostFor1DSbpDiffPlacement(blob_desc, in_parallel_desc, out_parallel_desc,\n                                        cost_4_diff_placement));\n\n  // Search the path that contains two of the diagonal sbps\n  int32_t n = nd_sbp_lists_.size();\n  diag_node_diff_placement_.clear();\n  diag_node_diff_placement_.resize(n);\n  for (int32_t i = 0; i < n; i++) {\n    diag_node_diff_placement_[i].resize(n);\n    for (int32_t j = 0; j < n; j++) {\n      JUST(Generate1Combination4DiffPlacement(i, j, boxing_collector_producer,\n                                              boxing_collector_consumer, cost_4_diff_placement,\n                                              diag_node_diff_placement_[i][j]));\n    }\n  }\n\n  return Maybe<void>::Ok();\n}\n\n// Print the cost and middle nodes\nvoid BoxingCollector::PrintBoxingTables() {\n  if (GlobalProcessCtx::Rank() == 0) {\n    std::cout << \"===================minimum copy cost==================\" << std::endl;\n    // other parameters\n    // Note that the result of this function is the same for any hierarchy\n    Shape hierarchy44({4, 4});\n    std::shared_ptr<Shape> in_hierarchy = std::make_shared<Shape>(hierarchy44);\n    double logical_blob_size = 1024.0;\n    int32_t n = nd_sbp_lists_.size();\n    // Print the original copy cost table\n    std::cout << \"Cost\\t\";\n    for (int32_t j = 0; j < n; j++) { std::cout << NdSbpToString(nd_sbp_lists_[j]) << \"\\t\"; }\n    std::cout << std::endl;\n    for (int32_t i = 0; i < n; i++) {\n      std::cout << NdSbpToString(nd_sbp_lists_[i]) << \"\\t\";\n      for (int32_t j = 0; j < n; j++) {\n        if (minimum_copy_cost_[i][j] > GetValidMaxCopyCost()) {\n          std::cout << \"X\\t\";\n        } else {\n          std::cout << minimum_copy_cost_[i][j] << \"\\t\";\n        }\n      }\n      std::cout << std::endl;\n    }\n\n    std::cout << std::endl;\n    std::cout 
<< \"Original Copy Cost\" << std::endl;\n    std::cout << \"logical blob size: \" << logical_blob_size << std::endl;\n    std::cout << \"hierarchy: \" << *in_hierarchy << std::endl;\n\n    std::cout << \"============================middle nodes===========================\" << std::endl;\n\n    // Print the middle nodes\n    std::cout << \"Middle Sbp\\t\";\n    for (int32_t j = 0; j < n; j++) { std::cout << NdSbpToString(nd_sbp_lists_[j]) << \"\\t\"; }\n    std::cout << std::endl;\n    for (int32_t i = 0; i < n; i++) {\n      std::cout << NdSbpToString(nd_sbp_lists_[i]) << \"\\t\";\n      for (int32_t j = 0; j < n; j++) {\n        if (minimum_copy_cost_[i][j] > GetValidMaxCopyCost()) {\n          std::cout << \"X\";\n        } else if (middle_nodes_[i][j].size() > 0) {\n          for (int32_t k = 0; k < middle_nodes_[i][j].size(); k++) {\n            std::cout << NdSbpToString(nd_sbp_lists_[middle_nodes_[i][j][k][0]]);\n            for (int32_t l = 1; l < middle_nodes_[i][j][k].size(); l++) {\n              std::cout << \"->\" << NdSbpToString(nd_sbp_lists_[middle_nodes_[i][j][k][l]]);\n            }\n            std::cout << \"; \";\n          }\n        }\n\n        std::cout << \"\\t\";\n      }\n      std::cout << std::endl;\n    }\n\n    std::cout << std::endl;\n    std::cout << \"Minimum Copy Cost after second search\" << std::endl;\n    std::cout << \"logical blob size: \" << logical_blob_size << std::endl;\n    std::cout << \"hierarchy: \" << *in_hierarchy << std::endl;\n\n    std::cout << \"====================middle nodes for different placement====================\"\n              << std::endl;\n\n    std::cout << \"Middle nodes for different placement\\t\";\n    for (int32_t j = 0; j < n; j++) { std::cout << NdSbpToString(nd_sbp_lists_[j]) << \"\\t\"; }\n    std::cout << std::endl;\n    for (int32_t i = 0; i < n; i++) {\n      std::cout << NdSbpToString(nd_sbp_lists_[i]) << \"\\t\";\n      for (int32_t j = 0; j < n; j++) {\n        if (diag_node_diff_placement_[i][j].size() > 0) {\n          for (int32_t k = 0; k < diag_node_diff_placement_[i][j].size(); k++) {\n            std::cout << \"[\" << NdSbpToString(nd_sbp_lists_[diag_node_diff_placement_[i][j][k][0]])\n                      << \", \" << NdSbpToString(nd_sbp_lists_[diag_node_diff_placement_[i][j][k][1]])\n                      << \"]; \";\n          }\n        }\n        std::cout << \"\\t\";\n      }\n      std::cout << std::endl;\n    }\n\n    std::cout << \"====================middle nodes for different hierarchy====================\"\n              << std::endl;\n\n    std::cout << \"Middle nodes for different hierarchy\\t\";\n    for (int32_t j = 0; j < n; j++) { std::cout << NdSbpToString(nd_sbp_lists_[j]) << \"\\t\"; }\n    std::cout << std::endl;\n    for (int32_t i = 0; i < n; i++) {\n      std::cout << NdSbpToString(nd_sbp_lists_[i]) << \"\\t\";\n      for (int32_t j = 0; j < n; j++) {\n        if (diag_node_diff_hierarchy_[i][j].size() > 0) {\n          for (int32_t k = 0; k < diag_node_diff_hierarchy_[i][j].size(); k++) {\n            std::cout << NdSbpToString(nd_sbp_lists_[diag_node_diff_hierarchy_[i][j][k][0]])\n                      << \"; \";\n          }\n        }\n        std::cout << \"\\t\";\n      }\n      std::cout << std::endl;\n    }\n\n    std::cout << \"================================================\" << std::endl;\n  }\n}\n\n// Ask if the boxing algorithm accepts the current sbp combination\nMaybe<void> BoxingCollector::AskSbpCombination(const NdSbp& sbp_producer, const NdSbp& 
sbp_consumer,\n                                               const BlobDesc& logical_blob_desc,\n                                               const ParallelDesc& producer_parallel_desc,\n                                               const ParallelDesc& consumer_parallel_desc,\n                                               bool is_customized, std::vector<NdSbp>& middle_sbps,\n                                               int32_t* diag_node_pos, bool compute_cost) {\n  middle_sbps.clear();\n  // Two-step boxing is not allowed and checking is disabled for debugging\n  if (disable_middle_node) { return Maybe<void>::Ok(); }\n  if (producer_parallel_desc == consumer_parallel_desc && sbp_producer == sbp_consumer) {\n    return Maybe<void>::Ok();\n  }\n\n  // Dealing with 1D sbp to 1D sbp\n  if (Is1dSbp(sbp_producer) && Is1dSbp(sbp_consumer)) {\n    JUST(AskSbpCombinationFor1DSbp(sbp_producer, sbp_consumer, producer_parallel_desc,\n                                   consumer_parallel_desc, middle_sbps, diag_node_pos));\n    // No middle nodes for the other 1d-sbp combinations\n    return Maybe<void>::Ok();\n  }\n\n#if defined(WITH_CUDA) || defined(WITH_NPU) || defined(WITH_MLU)\n  // Use general basic communication if there is no P in the consumer\n  if (((Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()\n        && producer_parallel_desc == consumer_parallel_desc)\n       || enable_general_basic_communication)\n      && (!NdSbpHasPartialParallel(sbp_consumer))\n      && producer_parallel_desc.device_type() == consumer_parallel_desc.device_type()\n      && producer_parallel_desc.device_type() != DeviceType::kCPU) {\n    if (NdSbpHasPartialParallel(sbp_producer) && NdSbpHasBroadcastParallel(sbp_consumer)) {\n      // (?, P, ?)->(Si, Sj)->(?, B, ?), two-step transfer\n      // Directly applying general basic communication would have O(n^2) time complexity for P->B\n      // Using two-step transfer would reduce it to a linear cost\n      JUST(AskSbpCombination4GeneralBasicCommunication(\n          sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc,\n          consumer_parallel_desc, middle_sbps, diag_node_pos));\n    }\n    // Otherwise, one-step transfer\n    return Maybe<void>::Ok();\n  }\n#endif  // WITH_CUDA || WITH_NPU || WITH_MLU\n\n  if (JUST(ComputeLazyCopyCostBetweenNdSbp(sbp_producer, sbp_consumer, logical_blob_desc,\n                                           producer_parallel_desc, consumer_parallel_desc,\n                                           /*requires_same_sbp=*/false))\n      < GetValidMaxCopyCost()) {\n    return Maybe<void>::Ok();\n  } else {\n    int32_t require_init_type =\n        int32_t(enable_general_basic_communication\n                || Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream());\n    if (init_type_ != require_init_type) {\n      // We assemble the boxing table from S(0) to S(5).\n      // Splits on higher axes are considered in the customized boxing collector.\n      constexpr int32_t kRegularMaxSplitAxes = 6;\n      JUST(Init(kRegularMaxSplitAxes));\n    }\n  }\n\n  // The middle nodes algorithm supports transfers across machines, devices, and hierarchies\n  if (producer_parallel_desc != consumer_parallel_desc) {\n    JUST(AskSbpCombination4DiffPlacement(sbp_producer, sbp_consumer, logical_blob_desc,\n                                         producer_parallel_desc, consumer_parallel_desc,\n                                         is_customized, middle_sbps, diag_node_pos, compute_cost));\n\n    
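// For example, suppose the call above produces middle_sbps == [(S0, S0), (B, B)] with\n    // *diag_node_pos == 1: then (S0, S0) uses the producer parallel description while (B, B)\n    // uses the consumer parallel description.\n    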
return Maybe<void>::Ok();\n  }\n  // Transfer for the same machines, devices and hierarchy.\n  if (sbp_producer == sbp_consumer) { return Maybe<void>::Ok(); }\n  const auto& parallel_hierarchy = producer_parallel_desc.hierarchy();\n\n  *diag_node_pos = 0;\n  // Dealing with nD sbp, n>2\n  if (parallel_hierarchy->NumAxes() > 2) {\n    CHECK_OR_RETURN(compute_cost)\n        << \"Boxing does not support a hierarchy with dimension greater than 2\";\n    return Maybe<void>::Ok();\n  }\n  // Ask for sbp combination with the same 2-D hierarchy and placement\n  JUST(AskSbpCombination4Same2DPlacement(sbp_producer, sbp_consumer, logical_blob_desc,\n                                         producer_parallel_desc, consumer_parallel_desc,\n                                         is_customized, middle_sbps, diag_node_pos, compute_cost));\n\n  return Maybe<void>::Ok();\n}\n\n// Ask for sbp combination with the same 2-D hierarchy and placement\nMaybe<void> BoxingCollector::AskSbpCombination4Same2DPlacement(\n    const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,\n    const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,\n    bool is_customized, std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos,\n    bool compute_cost) {\n  CHECK_OR_RETURN(producer_parallel_desc == consumer_parallel_desc)\n      << \"Producer and consumer have different placements. Please use AskSbpCombination directly\";\n  middle_sbps.clear();\n\n  // Find the 2D sbp id\n  int32_t i = FindId4NdSbp(sbp_producer);\n  int32_t j = FindId4NdSbp(sbp_consumer);\n  // Dealing with 2D sbp\n  if (i >= 0 && j >= 0) {\n    // Such a combination cannot be supported with a limited number of middle nodes\n    if (minimum_copy_cost_[i][j] > GetValidMaxCopyCost()) {\n      CHECK_OR_RETURN(compute_cost) << \"Boxing does not support \" << NdSbpToString(sbp_producer)\n                                    << \" -> \" << NdSbpToString(sbp_consumer) << \" for 2D sbp\";\n      return Maybe<void>::Ok();\n    }\n    // The current design can deal with such a combination; no middle nodes need to be inserted\n    if (middle_nodes_[i][j].size() == 0) { return Maybe<void>::Ok(); }\n    // Find a list of middle nodes with minimum storage\n    
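// For example, if paths through {(S(0), S(0))} and through {(B, B)} both reach j with a valid\n    // copy cost, the split middle node would be preferred here, since its intermediate tensor\n    // occupies less storage per device than the broadcast one.\n    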
int32_t min_k = -1;\n    double min_cost = GetValidMaxCopyCost();\n    for (int32_t k = 0; k < middle_nodes_[i][j].size(); k++) {\n      double curr_cost = 0.0;\n      for (int32_t middle_sbp_id : middle_nodes_[i][j][k]) {\n        Shape logical_shape = logical_blob_desc.shape();\n        // Storage4NdSbp would modify logical_shape as well\n        curr_cost += Storage4NdSbp(nd_sbp_lists_[middle_sbp_id], logical_shape,\n                                   *producer_parallel_desc.hierarchy());\n        if (curr_cost > GetValidMaxCopyCost()) { break; }\n      }\n      // Store k if it renews the minimum cost\n      if (curr_cost < min_cost) {\n        min_k = k;\n        min_cost = curr_cost;\n      }\n    }\n\n    // If we found a list of middle nodes with the current boxing collector\n    int32_t producer_hierarchy_num = producer_parallel_desc.hierarchy()->NumAxes();\n    if (min_k >= 0) {\n      for (int32_t middle_sbp_id : middle_nodes_[i][j][min_k]) {\n        middle_sbps.emplace_back(\n            *JUST(SetNdSbpDim(nd_sbp_lists_[middle_sbp_id], producer_hierarchy_num)));\n      }\n      return Maybe<void>::Ok();\n    }\n  }\n\n  // If we cannot find a list of middle nodes even with the customized boxing collector\n  if (is_customized) {\n    CHECK_OR_RETURN(compute_cost) << \"Boxing does not support \" << NdSbpToString(sbp_producer)\n                                  << \" -> \" << NdSbpToString(sbp_consumer)\n                                  << \" for Shape: \" << logical_blob_desc.shape();\n    return Maybe<void>::Ok();\n  }\n\n  // Build a customized boxing collector and try the algorithm again\n  BoxingCollector customized_boxing_collector;\n  JUST(customized_boxing_collector.Init(logical_blob_desc, producer_parallel_desc));\n  JUST(customized_boxing_collector.AskSbpCombination4Same2DPlacement(\n      sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc,\n      /*is_customized=*/true, middle_sbps, diag_node_pos, compute_cost));\n  return Maybe<void>::Ok();\n}\n\n// Ask for sbp combination with different hierarchies and placements\nMaybe<void> BoxingCollector::AskSbpCombination4DiffPlacement(\n    const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,\n    const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,\n    bool is_customized, std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos,\n    bool compute_cost) {\n  middle_sbps.clear();\n  // Find the 2D sbp id\n  int32_t i = FindId4NdSbp(sbp_producer);\n  int32_t j = FindId4NdSbp(sbp_consumer);\n  // Different placements: [2, 3] vs 5, or [3, 2] vs [2, 2], or cpu vs cuda\n  // Different hierarchies: [2, 3] vs [6], or [4, 3] vs [6, 2]\n  bool same_placement = producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc);\n  // Dealing with 2D sbp\n  if (i >= 0 && j >= 0) {\n    // Pure copy between machines and devices\n    if (i == j && (*producer_parallel_desc.hierarchy() == *consumer_parallel_desc.hierarchy())) {\n      return Maybe<void>::Ok();\n    }\n    if (same_placement) {\n      // Different hierarchies\n      CHECK_OR_RETURN(diag_node_diff_hierarchy_.size() > 0)\n          << \"Have not initialized the combination table for different hierarchies yet! 
\"\n             \"Please run JUST(GenerateCombination4DiffHierarchy(this, this)); \"\n             \"before Asking sbp combination for different parallel description.\";\n      if (JUST(Ask1Combination4DiffPlacement(\n              sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc,\n              consumer_parallel_desc, is_customized, middle_sbps, diag_node_pos, compute_cost, this,\n              this, diag_node_diff_hierarchy_[i][j]))) {\n        return Maybe<void>::Ok();\n      }\n    } else {\n      // Different placements\n      CHECK_OR_RETURN(diag_node_diff_placement_.size() > 0)\n          << \"Have not initialized the combination table for different hierarchies yet! \"\n             \"Please run JUST(GenerateCombination4DiffPlacement(this, this)); \"\n             \"before Asking sbp combination for different parallel description.\";\n      if (JUST(Ask1Combination4DiffPlacement(\n              sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc,\n              consumer_parallel_desc, is_customized, middle_sbps, diag_node_pos, compute_cost, this,\n              this, diag_node_diff_placement_[i][j]))) {\n        return Maybe<void>::Ok();\n      }\n    }\n  }\n  // Customized boxing collector and try the algorithm again\n  if (is_customized) {\n    CHECK_OR_RETURN(compute_cost) << \"Boxing does not support \" << NdSbpToString(sbp_producer)\n                                  << \"[hierarchy: \" << *producer_parallel_desc.hierarchy()\n                                  << \"] -> \" << NdSbpToString(sbp_consumer)\n                                  << \"[hierarchy: \" << *consumer_parallel_desc.hierarchy()\n                                  << \"] for blob shape: \" << logical_blob_desc.shape();\n    return Maybe<void>::Ok();\n  }\n  // Customize boxing collector for producer\n  BoxingCollector customized_boxing_collector_producer;\n  JUST(customized_boxing_collector_producer.Init(logical_blob_desc, producer_parallel_desc));\n  // Customize boxing collector for consumer\n  BoxingCollector customized_boxing_collector_consumer;\n  JUST(customized_boxing_collector_consumer.Init(logical_blob_desc, consumer_parallel_desc));\n\n  std::vector<std::vector<int32_t>> diag_nodes;\n  // Generate the combination table for different hierarchies or placements\n  if (same_placement) {\n    JUST(customized_boxing_collector_producer.Generate1Combination4DiffHierarchy(\n        customized_boxing_collector_producer.FindId4NdSbp(sbp_producer),\n        customized_boxing_collector_consumer.FindId4NdSbp(sbp_consumer),\n        &customized_boxing_collector_producer, &customized_boxing_collector_consumer, diag_nodes));\n  } else {\n    // Compute the cost while transferring a 1D sbp between different placements\n    std::vector<std::vector<double>> cost_4_diff_placement;\n    JUST(ComputeCostFor1DSbpDiffPlacement(logical_blob_desc, producer_parallel_desc,\n                                          consumer_parallel_desc, cost_4_diff_placement));\n\n    JUST(customized_boxing_collector_producer.Generate1Combination4DiffPlacement(\n        customized_boxing_collector_producer.FindId4NdSbp(sbp_producer),\n        customized_boxing_collector_consumer.FindId4NdSbp(sbp_consumer),\n        &customized_boxing_collector_producer, &customized_boxing_collector_consumer,\n        cost_4_diff_placement, diag_nodes));\n  }\n\n  JUST(customized_boxing_collector_producer.Ask1Combination4DiffPlacement(\n      sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, 
consumer_parallel_desc,\n      /*is_customized=*/true, middle_sbps, diag_node_pos, compute_cost,\n      &customized_boxing_collector_producer, &customized_boxing_collector_consumer, diag_nodes));\n  return Maybe<void>::Ok();\n}\n\n// Generate the transfer rule for one combination with different hierarchies on the same\n// placement. id_producer -> id_consumer.\nMaybe<void> BoxingCollector::Generate1Combination4DiffHierarchy(\n    int32_t id_producer, int32_t id_consumer, BoxingCollector* boxing_collector_producer,\n    BoxingCollector* boxing_collector_consumer, std::vector<std::vector<int32_t>>& diag_nodes) {\n  // Number of 1d sbps\n  int32_t m = id2sbp_parallel_.size();\n\n  // Search the path that contains one of the diagonal sbps\n\n  // minimum number of nodes\n  int32_t min_path_length = 100;\n  // minimum cost\n  double min_cost = GetValidMaxCopyCost();\n\n  for (int32_t id_1d = 0; id_1d < m; id_1d++) {\n    // We do not support [2, 3]: (S0, S1) -> [6]: S0 for a tensor with shape (14, 21)\n    // Thus, the diagonal node should suit both hierarchies.\n    int32_t diag_producer = boxing_collector_producer->id_1d_2_nd_[id_1d];\n    if (diag_producer < 0) { continue; }\n    int32_t diag_consumer = boxing_collector_consumer->id_1d_2_nd_[id_1d];\n    if (diag_consumer < 0) { continue; }\n    // Find the path with the minimum number of nodes\n    int32_t path_length = 0;\n    // Transfer from id_producer to diag_producer\n    if (boxing_collector_producer->middle_nodes_[id_producer][diag_producer].size() > 0) {\n      path_length +=\n          boxing_collector_producer->middle_nodes_[id_producer][diag_producer][0].size() + 1;\n    } else if (id_producer != diag_producer) {\n      path_length++;\n    }\n    // Transfer from diag_consumer to id_consumer\n    if (boxing_collector_consumer->middle_nodes_[diag_consumer][id_consumer].size() > 0) {\n      path_length +=\n          boxing_collector_consumer->middle_nodes_[diag_consumer][id_consumer][0].size() + 1;\n    } else if (diag_consumer != id_consumer) {\n      path_length++;\n    }\n    
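// For example, if reaching diag_producer from id_producer requires one intermediate nd sbp,\n    // that segment contributes middle_nodes_[id_producer][diag_producer][0].size() + 1 == 2\n    // nodes: the intermediate sbp plus diag_producer itself.\n    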
// Pick the path with minimum copy cost\n    if (path_length <= min_path_length) {\n      double curr_cost =\n          boxing_collector_producer->minimum_copy_cost_[id_producer][diag_producer]\n          + boxing_collector_consumer->minimum_copy_cost_[diag_consumer][id_consumer];\n\n      min_path_length = path_length;\n      // Found a candidate with a small cost\n      if (curr_cost < min_cost * kFloatDeviationPlus) {\n        // Found a smaller cost; clear the previous paths.\n        if (curr_cost < min_cost * kFloatDeviationMinus) {\n          min_cost = curr_cost;\n          diag_nodes.clear();\n        }\n        // Add the current pair of diagonal nodes, one for the producer side and one for the\n        // consumer side of the transfer.\n        diag_nodes.push_back({diag_producer, diag_consumer});\n      }\n    }\n  }\n\n  return Maybe<void>::Ok();\n}\n\n// Ask for one combination with different hierarchies and placements\nMaybe<bool> BoxingCollector::Ask1Combination4DiffPlacement(\n    const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,\n    const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,\n    bool is_customized, std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos, bool compute_cost,\n    BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer,\n    const std::vector<std::vector<int32_t>>& diag_nodes) {\n  // Pick the path with the minimum storage for the diagonal node\n  int32_t id_producer = boxing_collector_producer->FindId4NdSbp(sbp_producer);\n  if (id_producer < 0) {\n    CHECK_OR_RETURN(compute_cost) << \"Source data with shape \" << logical_blob_desc.shape()\n                                  << \" has an invalid sbp \" << NdSbpToString(sbp_producer);\n    return false;\n  }\n  int32_t id_consumer = boxing_collector_consumer->FindId4NdSbp(sbp_consumer);\n  if (id_consumer < 0) {\n    CHECK_OR_RETURN(compute_cost) << \"Target data with shape \" << logical_blob_desc.shape()\n                                  << \" has an invalid sbp \" << NdSbpToString(sbp_consumer);\n    return false;\n  }\n  middle_sbps.clear();\n  // NOTE: For simplicity, we do not examine the storage costs of the other middle nodes at\n  // the moment.\n  double min_cost = GetValidMaxCopyCost();\n  int32_t producer_hierarchy_num_axes = producer_parallel_desc.hierarchy()->NumAxes();\n  int32_t consumer_hierarchy_num_axes = consumer_parallel_desc.hierarchy()->NumAxes();\n  int32_t min_diag_producer = -1, min_diag_consumer = -1;\n  for (const auto& diag_pair : diag_nodes) {\n    Shape logical_shape = logical_blob_desc.shape();\n    // We do not check whether such a shape is valid on either side of the sbp list in the\n    // middle nodes algorithm. 
Thus, we need to check them here.\n    double curr_cost =\n        Storage4NdSbp(*JUST(SetNdSbpDim(boxing_collector_producer->nd_sbp_lists_[diag_pair[0]],\n                                        producer_hierarchy_num_axes)),\n                      logical_shape, *producer_parallel_desc.hierarchy());\n    // Check the shape for both producer and consumer.\n    logical_shape = logical_blob_desc.shape();\n    curr_cost +=\n        Storage4NdSbp(*JUST(SetNdSbpDim(boxing_collector_consumer->nd_sbp_lists_[diag_pair[1]],\n                                        consumer_hierarchy_num_axes)),\n                      logical_shape, *consumer_parallel_desc.hierarchy());\n    if (curr_cost < min_cost) {\n      min_cost = curr_cost;\n      min_diag_producer = diag_pair[0];\n      min_diag_consumer = diag_pair[1];\n    }\n  }\n\n  // Different placements: [2, 3] vs 5, or [3, 2] vs [2, 2], or cpu vs cuda\n  // Different hierarchies: [2, 3] vs [6], or [4, 3] vs [6, 2]\n  bool diff_placement = !producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc);\n\n  // If we found a diagonal middle node with the current boxing collector\n  if (min_diag_producer >= 0) {\n    std::vector<NdSbp> middle_sbps_buffer;\n    // Find the middle nodes between the producer and the diagonal node\n    if (id_producer != min_diag_producer) {\n      JUST(boxing_collector_producer->AskSbpCombination(\n          sbp_producer, boxing_collector_producer->nd_sbp_lists_[min_diag_producer],\n          logical_blob_desc, producer_parallel_desc, producer_parallel_desc,\n          /*is_customized=*/false, middle_sbps_buffer, diag_node_pos, compute_cost));\n      // Add the path into middle_sbps\n      for (auto& middle_sbp : middle_sbps_buffer) {\n        middle_sbps.emplace_back(*JUST(SetNdSbpDim(middle_sbp, producer_hierarchy_num_axes)));\n      }\n      // If the placements are different,\n      // or the placement is the same but the hierarchy is 2D\n      // For example: OneFlow supports [6]: (S0) -> [3, 2]: (S0, S1)\n      // but does not support [2, 3]: (S0, S0) -> [3, 2]: (S0, S1)\n      if (diff_placement || producer_hierarchy_num_axes > 1) {\n        middle_sbps.emplace_back(\n            *JUST(SetNdSbpDim(boxing_collector_producer->nd_sbp_lists_[min_diag_producer],\n                              producer_hierarchy_num_axes)));\n      }\n    }\n    // In case we do not have middle nodes on the consumer side\n    *diag_node_pos = middle_sbps.size();\n    // Find the middle nodes between the diagonal node and the consumer\n    if (id_consumer != min_diag_consumer) {\n      JUST(boxing_collector_consumer->AskSbpCombination(\n          boxing_collector_consumer->nd_sbp_lists_[min_diag_consumer], sbp_consumer,\n          logical_blob_desc, consumer_parallel_desc, consumer_parallel_desc,\n          /*is_customized=*/false, middle_sbps_buffer, diag_node_pos, compute_cost));\n      // Set the diagonal node position and stop using it as a buffer\n      *diag_node_pos = middle_sbps.size();\n      // If the placements are different, or the consumer hierarchy is 2D\n      if (diff_placement || consumer_hierarchy_num_axes > 1) {\n        middle_sbps.emplace_back(\n            *JUST(SetNdSbpDim(boxing_collector_consumer->nd_sbp_lists_[min_diag_consumer],\n                              consumer_hierarchy_num_axes)));\n      }\n      // Add the path into middle_sbps\n      for (auto& middle_sbp : middle_sbps_buffer) {\n        middle_sbps.emplace_back(*JUST(SetNdSbpDim(middle_sbp, consumer_hierarchy_num_axes)));\n      }\n    }\n    return true;\n  }\n  return false;\n}\n\n// Generate the transfer 
rule for one combination with different placements\n// id_producer -> id_consumer.\nMaybe<void> BoxingCollector::Generate1Combination4DiffPlacement(\n    int32_t id_producer, int32_t id_consumer, BoxingCollector* boxing_collector_producer,\n    BoxingCollector* boxing_collector_consumer,\n    const std::vector<std::vector<double>>& cost_4_diff_placement,\n    std::vector<std::vector<int32_t>>& diag_nodes) {\n  // Number of 1d sbps\n  int32_t m = id2sbp_parallel_.size();\n  // minimum number of nodes\n  int32_t min_path_length = 100;\n  // minimum cost\n  double min_cost = GetValidMaxCopyCost();\n\n  // Search the path that contains two of the diagonal sbps\n  // From the producer to the first diagonal node\n  for (int32_t id_1d_producer = 0; id_1d_producer < m; id_1d_producer++) {\n    // We do not support [2, 3]: (S0, S1) -> [6]: S0 for a tensor with shape (14, 21)\n    // Thus, the diagonal node should suit both hierarchies.\n    int32_t diag_producer = boxing_collector_producer->id_1d_2_nd_[id_1d_producer];\n    if (diag_producer < 0\n        || boxing_collector_producer->minimum_copy_cost_[id_producer][diag_producer]\n               > GetValidMaxCopyCost()) {\n      continue;\n    }\n    // Find the path with the minimum number of nodes\n    int32_t path_length = 0;\n    // Transfer from id_producer to diag_producer\n    if (boxing_collector_producer->middle_nodes_[id_producer][diag_producer].size() > 0) {\n      path_length +=\n          boxing_collector_producer->middle_nodes_[id_producer][diag_producer][0].size() + 1;\n    } else if (id_producer != diag_producer) {\n      path_length++;\n    }\n    // pruning\n    if (path_length > min_path_length) { continue; }\n\n    // From the second diagonal node to the consumer\n    for (int32_t id_1d_consumer = 0; id_1d_consumer < m; id_1d_consumer++) {\n      int32_t diag_consumer = boxing_collector_consumer->id_1d_2_nd_[id_1d_consumer];\n      // The diagonal sbp is not supported or no paths exist from the diagonal sbp to the\n      // consumer or between the two diagonal sbps.\n      if (diag_consumer < 0\n          || boxing_collector_consumer->minimum_copy_cost_[diag_consumer][id_consumer]\n                 > GetValidMaxCopyCost()\n          || cost_4_diff_placement[id_1d_producer][id_1d_consumer] > GetValidMaxCopyCost()) {\n        continue;\n      }\n\n      // Transfer from diag_consumer to id_consumer\n      int32_t curr_path_length = path_length;\n      if (boxing_collector_consumer->middle_nodes_[diag_consumer][id_consumer].size() > 0) {\n        curr_path_length +=\n            boxing_collector_consumer->middle_nodes_[diag_consumer][id_consumer][0].size() + 1;\n      } else if (diag_consumer != id_consumer) {\n        curr_path_length++;\n      }\n      // Pick the path with minimum copy cost\n      if (curr_path_length <= min_path_length) {\n        double curr_cost =\n            boxing_collector_producer->minimum_copy_cost_[id_producer][diag_producer]\n            + cost_4_diff_placement[id_1d_producer][id_1d_consumer]\n            + boxing_collector_consumer->minimum_copy_cost_[diag_consumer][id_consumer];\n\n        min_path_length = curr_path_length;\n        // Found a candidate with a small cost\n        if (curr_cost < min_cost * kFloatDeviationPlus) {\n          // Found a smaller cost; clear the previous paths.\n          if (curr_cost < min_cost * kFloatDeviationMinus) {\n            min_cost = curr_cost;\n            diag_nodes.clear();\n          }\n          // Add the current pair of diagonal nodes, one for the producer side and one for the\n          // consumer side of the transfer.\n          diag_nodes.push_back({diag_producer, diag_consumer});\n        }\n      }\n    }\n  }\n\n  return Maybe<void>::Ok();\n}\n\n// Filter nd sbp from nd_sbp_lists_ with given logical shape\nMaybe<void> BoxingCollector::FilterNdSbpList4LogicalShape(const BlobDesc& logical_blob_desc,\n                                                          const Shape& parallel_hierarchy) {\n  for (int32_t middle_sbp_id = nd_sbp_lists_.size() - 1; middle_sbp_id >= 0; middle_sbp_id--) {\n    Shape logical_shape = logical_blob_desc.shape();\n    if (JUST(FilterNdSbpByLogicalShape(nd_sbp_lists_[middle_sbp_id], logical_shape,\n                                       parallel_hierarchy))) {\n      // Change the value before erasing\n      // This might be true: nd_sbp_lists_.size() - 1 == middle_sbp_id\n      nd_sbp_universe_[nd_sbp_lists_[nd_sbp_lists_.size() - 1]] = middle_sbp_id;\n      nd_sbp_universe_.erase(nd_sbp_lists_[middle_sbp_id]);\n      nd_sbp_lists_[middle_sbp_id] = nd_sbp_lists_[nd_sbp_lists_.size() - 1];\n      nd_sbp_lists_.pop_back();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n// Ask for sbp combination for general basic communication\nMaybe<void> BoxingCollector::AskSbpCombination4GeneralBasicCommunication(\n    const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,\n    const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,\n    std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos) {\n  // (P, X) -> (B, X) || (X, P) -> (X, B), X is any SBP\n  // One-step transfer: at most a 50% reduction in the transfer cost, so do not use middle nodes\n  if (producer_parallel_desc == consumer_parallel_desc\n      && producer_parallel_desc.hierarchy()->NumAxes() == 2\n      && (sbp_producer.sbp_parallel(0) == sbp_consumer.sbp_parallel(0)\n          || sbp_producer.sbp_parallel(1) == sbp_consumer.sbp_parallel(1))) {\n    return Maybe<void>::Ok();\n  }\n\n  // Not enough gain in transfer cost, do not use middle nodes\n  int32_t partial_ratio4producer = PartialRatio4Producer(sbp_producer, producer_parallel_desc);\n  int32_t broadcast_ratio4consumer = BroadcastRatio4Consumer(sbp_consumer, consumer_parallel_desc);\n  if (2 * (partial_ratio4producer + broadcast_ratio4consumer)\n      >= partial_ratio4producer * broadcast_ratio4consumer) {\n    return Maybe<void>::Ok();\n  }\n\n  
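// For example, with partial_ratio4producer == 4 and broadcast_ratio4consumer == 4,\n  // 2 * (4 + 4) >= 4 * 4 holds and the one-step transfer is kept; with ratios 4 and 6,\n  // 2 * (4 + 6) < 4 * 6, so a two-step transfer through an all-split sbp pays off.\n  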
bool close2producer = true;\n  if (producer_parallel_desc.parallel_num() == consumer_parallel_desc.parallel_num()) {\n    // Get close to the one with more splits\n    close2producer = TotalNumSplit(sbp_producer, producer_parallel_desc)\n                     > TotalNumSplit(sbp_consumer, consumer_parallel_desc);\n  } else {\n    // Get close to the one with more machines\n    close2producer = producer_parallel_desc.parallel_num() > consumer_parallel_desc.parallel_num();\n  }\n  // Get the contiguous sbp\n  if (close2producer) {\n    JUST(AskCloseAllSplitSbp(sbp_producer, producer_parallel_desc, logical_blob_desc, middle_sbps));\n    *diag_node_pos = 1;\n  } else {\n    JUST(AskCloseAllSplitSbp(sbp_consumer, consumer_parallel_desc, logical_blob_desc, middle_sbps));\n    *diag_node_pos = 0;\n  }\n  return Maybe<void>::Ok();\n}\n\n// Ask for an all-split sbp which is close to the original one\n
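// For example, take shape (12, 8) with nd sbp (P, S(0)) on hierarchy [2, 2]: the S(0) component\n// halves axis 0, and the P component is then replaced by a split on the axis with the smallest\n// current split number, yielding the all-split middle node (S(1), S(0)).\n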
Maybe<void> BoxingCollector::AskCloseAllSplitSbp(const NdSbp& nd_sbp,\n                                                 const ParallelDesc& parallel_desc,\n                                                 const BlobDesc& logical_blob_desc,\n                                                 std::vector<NdSbp>& middle_sbps) {\n  Shape remain_shape = logical_blob_desc.shape();\n  Shape rest_split_shape = logical_blob_desc.shape();\n  int32_t dim_shape = remain_shape.NumAxes();\n  // Initialize the remains and splitting\n  // logical_blob_desc.shape() == remain_shape .* rest_split_shape (elementwise product);\n  for (int32_t i = 0; i < dim_shape; i++) { rest_split_shape.Set(i, 1); }\n  for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) {\n    const auto& sbp = nd_sbp.sbp_parallel(sbp_id);\n    if (sbp.has_split_parallel()) {\n      int32_t axis = sbp.split_parallel().axis();\n      int32_t split_num = parallel_desc.hierarchy()->At(sbp_id);\n      remain_shape.Set(axis, remain_shape.At(axis) / split_num);\n      rest_split_shape.Set(axis, rest_split_shape.At(axis) * split_num);\n    }\n  }\n  // Get the contiguous sbp\n  NdSbp new_sbp = nd_sbp;\n  for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) {\n    const auto& sbp = nd_sbp.sbp_parallel(sbp_id);\n    int32_t split_num = parallel_desc.hierarchy()->At(sbp_id);\n    if (sbp.has_split_parallel()) {\n      int32_t axis = sbp.split_parallel().axis();\n      // rest_split_shape holds the total split number from sbp_id to the end\n      rest_split_shape.Set(axis, rest_split_shape.At(axis) / split_num);\n    } else {\n      // Change P or B to S(axis)\n      int32_t axis = -1;\n      // 4096 is large enough; we are unlikely to have that many devices\n      int32_t min_split_num = 4096;\n      // We need to pick a suitable axis\n      for (int32_t i = 0; i < remain_shape.NumAxes(); i++) {\n        if (remain_shape.At(i) % split_num == 0) {\n          if (rest_split_shape.At(i) < min_split_num) {\n            // Pick the axis with the smallest split number among the rest of the sbps\n            min_split_num = rest_split_shape.At(i);\n            axis = i;\n          }\n        }\n      }\n      // P, B -> S(axis)\n      if (axis >= 0) {\n        new_sbp.mutable_sbp_parallel(sbp_id)->mutable_split_parallel()->set_axis(axis);\n        remain_shape.Set(axis, remain_shape.At(axis) / split_num);\n      } else {\n        // Cannot find a suitable contiguous sbp\n        return Maybe<void>::Ok();\n      }\n    }\n  }\n  // Add the new sbp into the middle node list\n  middle_sbps.emplace_back(new_sbp);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/auto_parallel/boxing_collector.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_AUTO_PARALLEL_BOXING_COLLECTOR_H_\n#define ONEFLOW_CORE_AUTO_PARALLEL_BOXING_COLLECTOR_H_\n\n#include \"oneflow/core/common/hash_container.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n\nnamespace oneflow {\n\nclass BoxingCollector final {\n public:\n  BoxingCollector() = default;\n\n  ~BoxingCollector() = default;\n\n  // A constructor with init, designed for non-customized boxing collector\n  BoxingCollector(int32_t max_axis);\n\n  // Set default Sbp list\n  void CollectUniverse(int32_t max_axis);\n\n  // Construct a boxing collector with given maximum number of axis\n  Maybe<void> Init(int32_t max_axis);\n  // Init with given blob description\n  Maybe<void> Init(const BlobDesc& logical_blob_desc, const ParallelDesc& parallel_desc);\n\n  // Generate nd sbp list\n  void GenerateNdSbpList(int32_t hierarchy_num);\n  // Generate the map from 1d sbp to 2d sbp\n  void GenerateMap1d2nd();\n  // Generate the transfer rule for different combinations with the same hierarchy\n  Maybe<void> GenerateCombination4SamePlacement(int32_t max_middle_node_num);\n  Maybe<void> GenerateCombination4SamePlacement(int32_t max_middle_node_num,\n                                                const BlobDesc& blob_desc,\n                                                const ParallelDesc& parallel_desc);\n  // Generate the transfer rule for different combinations with different hierarchies\n  // on the same placement\n  Maybe<void> GenerateCombination4DiffHierarchy(BoxingCollector* boxing_collector_producer,\n                                                BoxingCollector* boxing_collector_consumer);\n  // Generate the transfer rule for different combinations with different placements\n  Maybe<void> GenerateCombination4DiffPlacement(BoxingCollector* boxing_collector_producer,\n                                                BoxingCollector* boxing_collector_consumer);\n  Maybe<void> GenerateCombination4DiffPlacement(BoxingCollector* boxing_collector_producer,\n                                                BoxingCollector* boxing_collector_consumer,\n                                                const BlobDesc& blob_desc,\n                                                const ParallelDesc& in_parallel_desc,\n                                                const ParallelDesc& out_parallel_desc);\n  // Print the cost and middle nodes\n  void PrintBoxingTables();\n  // Ask if the boxing algorithm accepts the current sbp combination\n  // If is_customized is true and we can not find a middle node list with\n  // reasonable cost, error occurs.\n  // If compute_cost is true, then no error occur even if no suitable middle nodes paths found.\n  // For different placements, we would return a diagonal node.\n  // Before this diagonal node (< *diag_node_pos), we use the parallel 
description of the producer.\n  // After this diagonal node (>= *diag_node_pos), we use the parallel description of the consumer.\n  Maybe<void> AskSbpCombination(const NdSbp& sbp_producer, const NdSbp& sbp_consumer,\n                                const BlobDesc& logical_blob_desc,\n                                const ParallelDesc& producer_parallel_desc,\n                                const ParallelDesc& consumer_parallel_desc, bool is_customized,\n                                std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos,\n                                bool compute_cost);\n  // Filter nd sbp from nd_sbp_lists_ with given logical shape\n  Maybe<void> FilterNdSbpList4LogicalShape(const BlobDesc& logical_blob_desc,\n                                           const Shape& parallel_hierarchy);\n\n private:\n  // Collect Sbp Parallel\n  void CollectUniverse(const SbpParallel& sbp);\n  // Find corresponding id for Nd sbp\n  int32_t FindId4NdSbp(const NdSbp& nd_sbp);\n  // Ask for sbp combination with the same 2-D hierarchy and placement\n  Maybe<void> AskSbpCombination4Same2DPlacement(const NdSbp& sbp_producer,\n                                                const NdSbp& sbp_consumer,\n                                                const BlobDesc& logical_blob_desc,\n                                                const ParallelDesc& producer_parallel_desc,\n                                                const ParallelDesc& consumer_parallel_desc,\n                                                bool is_customized, std::vector<NdSbp>& middle_sbps,\n                                                int32_t* diag_node_pos, bool compute_cost);\n  // Ask for sbp combination with different hierarchies on the same placement\n  Maybe<void> AskSbpCombination4DiffPlacement(const NdSbp& sbp_producer, const NdSbp& sbp_consumer,\n                                              const BlobDesc& logical_blob_desc,\n                                              const ParallelDesc& producer_parallel_desc,\n                                              const ParallelDesc& consumer_parallel_desc,\n                                              bool is_customized, std::vector<NdSbp>& middle_sbps,\n                                              int32_t* diag_node_pos, bool compute_cost);\n  // Generate the transfer rule for one combination with different hierarchies on the same\n  // placement. 
id_producer -> id_consumer.\n  Maybe<void> Generate1Combination4DiffHierarchy(int32_t id_producer, int32_t id_consumer,\n                                                 BoxingCollector* boxing_collector_producer,\n                                                 BoxingCollector* boxing_collector_consumer,\n                                                 std::vector<std::vector<int32_t>>& diag_nodes);\n  // The cost for transferring a 1D sbp between different placements\n  Maybe<void> ComputeCostFor1DSbpDiffPlacement(\n      const BlobDesc& blob_desc, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc,\n      std::vector<std::vector<double>>& cost_4_diff_placement);\n  // Generate the transfer rule for one combination with different placements\n  // id_producer -> id_consumer.\n  Maybe<void> Generate1Combination4DiffPlacement(\n      int32_t id_producer, int32_t id_consumer, BoxingCollector* boxing_collector_producer,\n      BoxingCollector* boxing_collector_consumer,\n      const std::vector<std::vector<double>>& cost_4_diff_placement,\n      std::vector<std::vector<int32_t>>& diag_nodes);\n  // Ask for one combination with different hierarchies and placements\n  Maybe<bool> Ask1Combination4DiffPlacement(const NdSbp& sbp_producer, const NdSbp& sbp_consumer,\n                                            const BlobDesc& logical_blob_desc,\n                                            const ParallelDesc& producer_parallel_desc,\n                                            const ParallelDesc& consumer_parallel_desc,\n                                            bool is_customized, std::vector<NdSbp>& middle_sbps,\n                                            int32_t* diag_node_pos, bool compute_cost,\n                                            BoxingCollector* boxing_collector_producer,\n                                            BoxingCollector* boxing_collector_consumer,\n                                            const std::vector<std::vector<int32_t>>& diag_nodes);\n  // Ask for sbp combination for general basic communication\n  Maybe<void> AskSbpCombination4GeneralBasicCommunication(\n      const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,\n      const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,\n      std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos);\n  // Ask for an all-split sbp which is close to the original one\n  Maybe<void> AskCloseAllSplitSbp(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc,\n                                  const BlobDesc& logical_blob_desc,\n                                  std::vector<NdSbp>& middle_sbps);\n  // Stores all the possible SbpParallel.\n  HashMap<SbpParallel, int32_t> sbp_parallel_universe_;\n  // Relationship between id and Sbp Parallel\n  std::vector<SbpParallel> id2sbp_parallel_;\n  // minimum cost\n  // minimum_copy_cost_[producer][consumer]\n  std::vector<std::vector<double>> minimum_copy_cost_;\n  // middle nodes\n  // middle_nodes_[producer][consumer][different choices] is a vector of middle nodes\n  // middle_nodes_[producer][consumer][different choices].size() is the minimum number of middle\n  // nodes that need to be inserted\n  std::vector<std::vector<std::vector<std::vector<int32_t>>>> middle_nodes_;\n  // Stores all the possible NdSbp.\n  std::unordered_map<NdSbp, int32_t> nd_sbp_universe_;\n  // Relationship between id and Nd Sbp\n  std::vector<NdSbp> nd_sbp_lists_;\n  // The diagonal middle node for different 
placements\n  std::vector<std::vector<std::vector<std::vector<int32_t>>>> diag_node_diff_placement_;\n  // The diagonal middle node for different hierarchies in the same placement\n  std::vector<std::vector<std::vector<std::vector<int32_t>>>> diag_node_diff_hierarchy_;\n  // Id Map from 1d sbp to 2d sbp\n  // For example: B -> (B, B), S0 -> (S0, S0)\n  std::vector<int32_t> id_1d_2_nd_;\n  // The sbp size in the combination table\n  int32_t hierarchy_num_;\n  // How the boxing collector is initialized\n  int32_t init_type_ = -1;\n  // Enable general basic communication or not\n  const bool enable_general_basic_communication =\n      ParseBooleanFromEnv(\"ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION\", false);\n};  // class BoxingCollector\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_AUTO_PARALLEL_BOXING_COLLECTOR_H_\n"
  },
  {
    "path": "oneflow/core/auto_parallel/sbp_collector.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <string>\n#include \"oneflow/core/auto_parallel/sbp_collector.h\"\n#include \"oneflow/core/auto_parallel/binary_set.h\"\n#include \"oneflow/core/auto_parallel/sbp_util.h\"\n#include \"oneflow/core/auto_parallel/sbp_constructor.h\"\n\nnamespace oneflow {\n\nnamespace auto_parallel {\n\nnamespace {\n// Whether the given binary set intersects all the sbp sets of the consumers\nbool IfIntersectAll(\n    const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,\n    const BinarySet& bs) {\n  for (const auto& sbp_set_group : consumer_bn2sbp_set) {\n    if (!bs.IfIntersect(sbp_set_group.second)) { return false; }\n  }\n\n  return true;\n}\n\n// Find unique sbp sets\nvoid FindUniqueSbpSets(\n    const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,\n    const std::unordered_set<int32_t>& all_sbp_set, std::vector<int32_t>& accumulator,\n    BinarySet& unique_sbps) {\n  std::vector<int32_t> sbp_ids;\n  // count the number of sbp\n  for (const auto& sbp_set_group : consumer_bn2sbp_set) {\n    sbp_set_group.second.QuickOutput(sbp_ids);\n    for (int32_t sbp_id : sbp_ids) { accumulator[sbp_id]++; }\n  }\n  // find unique sbp and clear the accumulator\n  for (const auto& sbp_id : all_sbp_set) {\n    if (accumulator[sbp_id] == 1) { unique_sbps.AddEntry(sbp_id); }\n    accumulator[sbp_id] = 0;\n  }\n}\n\n// Find unique sbp groups\nvoid FindUniqueSbpGroups(\n    const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,\n    const std::unordered_set<int32_t>& all_sbp_set, std::vector<int32_t>& accumulator,\n    BinarySet& bs_buffer, std::vector<BinarySet>& unique_sbp_groups) {\n  // find the unique sbp sets\n  BinarySet unique_sbps(accumulator.size());\n  FindUniqueSbpSets(consumer_bn2sbp_set, all_sbp_set, accumulator, unique_sbps);\n\n  // A: {B, S0, S1, S2, S3}, C: {B, S0}, D: {B, S0}\n  // {S1, S2, S3} show up only once, a parallel candidate should not contain two of them\n  for (const auto& sbp_set_group : consumer_bn2sbp_set) {\n    unique_sbps.IntersectionTo(sbp_set_group.second, bs_buffer);\n    // Find those unique sbp groups with more than two sbp\n    // For example {B, S1, S2} is an impossible proxy candidate,\n    // since {S1, S2} is only contained by A but not contained by C and D.\n    // A could be either S1 or S2. 
The tensor does not need to be transferred to both S1 and S2.\n    if (bs_buffer.Total() >= 2) { unique_sbp_groups.push_back(bs_buffer); }\n  }\n  bs_buffer.Clear();\n}\n\n// Whether the given binary set contains no two sbp from the same unique group\nbool No2SbpFromSameUniqueGroup(const BinarySet& bs,\n                               const std::vector<BinarySet>& unique_sbp_groups) {\n  BinarySet intersection(bs.GetSizeOfSet());\n  for (const auto& unique_sbp_group : unique_sbp_groups) {\n    bs.IntersectionTo(unique_sbp_group, intersection);\n    // For example {B, S1, S2} is an impossible proxy candidate,\n    // since {S1, S2} is only contained by A but not contained by C and D.\n    // A could be either S1 or S2. The tensor does not need to be transferred to both S1 and S2.\n    if (intersection.Total() >= 2) { return false; }\n  }\n  return true;\n}\n}  // namespace\n\n// Default constructor for SbpCollector\n// Don't allow any special case for broadcast!\nSbpCollector::SbpCollector() {\n  // initialize Sbp Parallel Universe with broadcast.\n  // NdSbp sbp_broadcast;\n  // sbp_broadcast.mutable_broadcast_parallel();\n  // nd_sbp_universe_[sbp_broadcast] = 0;\n  // id2nd_sbp_.push_back(sbp_broadcast);\n}\n\n// Collect all the possible Sbp Parallel from a NdSbpSignature\nvoid SbpCollector::CollectUniverse(const NdSbpSignature& nd_sbp_sig) {\n  for (auto& bn_sbp_pair : nd_sbp_sig.bn_in_op2nd_sbp()) {\n    if (nd_sbp_universe_.find(bn_sbp_pair.second) == nd_sbp_universe_.end()) {\n      int32_t curr_size = nd_sbp_universe_.size();\n      nd_sbp_universe_[bn_sbp_pair.second] = curr_size;\n      id2nd_sbp_.push_back(bn_sbp_pair.second);\n    }\n  }\n}\n// Collect all the possible Sbp Parallel from a SbpNode\nvoid SbpCollector::CollectUniverse(const SbpNode* sbp_node) {\n  for (auto& nd_sbp_sig : sbp_node->sbp_sig_list_) { CollectUniverse(nd_sbp_sig); }\n}\n// Collect all the possible Sbp Parallel from a SbpGraph\nvoid SbpCollector::CollectUniverse(const SbpGraph& sbp_graph) {\n  for (auto* sbp_node : sbp_graph.node_list_) { CollectUniverse(sbp_node); }\n  accumulator_.resize(nd_sbp_universe_.size(), 0);\n  bs_buffer_.Initialize(nd_sbp_universe_.size());\n}\n\n// TODO: Auto Placement!\n// It only collects the same sbp with the same parallel description.\n// At this moment their hierarchies are the same!\n\n// Initialize copy cost from producer to proxy of producer\nvoid SbpCollector::InitializeCopyCostFromNode2Proxy(const SbpNode* sbp_proxy,\n                                                    const LogicalBlobId& lbi) const {\n  // the only edge from producer to proxy of producer\n  SbpEdge* sbp_edge = sbp_proxy->edges_in_[0];\n  SbpNode* sbp_node_producer = sbp_edge->start_node_;\n  sbp_edge->cost_.resize(sbp_node_producer->sbp_sig_list_.size());\n  int32_t consumer_sbp_size = sbp_proxy->parallel_candidates_.size();\n  // look through sbp signature in producer\n  for (int32_t sbp_id_producer = 0; sbp_id_producer < sbp_node_producer->sbp_sig_list_.size();\n       sbp_id_producer++) {\n    sbp_edge->cost_[sbp_id_producer].resize(consumer_sbp_size, 0);\n  }\n\n  // Assemble copy cost from producer to proxy of producer\n  OpNode* producer = sbp_node_producer->op_node_;\n\n  // get parallel description. Number of devices.\n  const ParallelDesc& producer_parallel_desc = producer->parallel_desc();\n  // Need to be careful, the logical blob description should be independent of the current\n  // NdSbp. 
Use producer or op_node?\n  const BlobDesc& logical_blob_desc = producer->LogicalBlobDesc4Lbi(lbi);\n  const std::string& obn = *CHECK_JUST(producer->op().obn4lbi(lbi));\n\n  // A buffer to store the sbp parallel id\n  std::vector<int32_t> sbp_parallel_ids;\n\n  // look through sbp signature in producer\n  for (int32_t sbp_id_producer = 0; sbp_id_producer < sbp_node_producer->sbp_sig_list_.size();\n       sbp_id_producer++) {\n    // get sbp parallel for a logical blob in producer\n    const auto& producer_sbp_bn_in_op2sbp_parallel =\n        sbp_node_producer->sbp_sig_list_[sbp_id_producer].bn_in_op2nd_sbp();\n    const NdSbp& sbp_producer = producer_sbp_bn_in_op2sbp_parallel.at(obn);\n\n    // look through sbp parallel set in consumer\n    for (int32_t sbp_id_consumer = 0; sbp_id_consumer < consumer_sbp_size; sbp_id_consumer++) {\n      const BinarySet& sbp_parallel_set = sbp_proxy->parallel_candidates_[sbp_id_consumer];\n      sbp_parallel_set.QuickOutput(sbp_parallel_ids);\n\n      // look through all sbp parallels in a sbp parallel set\n      for (int32_t sbp_parallel_id : sbp_parallel_ids) {\n        // get sbp parallel for a logical blob in consumer\n        const NdSbp& sbp_consumer = id2nd_sbp_[sbp_parallel_id];\n\n        // compute copy cost for a specific logical blob\n        // Use the parallel description of producer as those for consumer for now.\n        sbp_edge->cost_[sbp_id_producer][sbp_id_consumer] +=\n            CHECK_JUST(ComputeCopyCostWithMiddleNodes(sbp_producer, sbp_consumer, logical_blob_desc,\n                                                      producer_parallel_desc,\n                                                      producer_parallel_desc, /*is_same=*/false));\n      }\n    }\n  }\n}\n\n// Initialize copy cost from proxy of producer to consumers\nvoid SbpCollector::InitializeCopyCostFromProxy2Consumer(\n    SbpNode* sbp_proxy,\n    const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,\n    const HashMap<std::string, SbpNode*>& op_name2sbp_node) const {\n  // Connect sbp proxy and consumers\n  for (const auto& consumer_bn_group : consumer_bn2sbp_set) {\n    // consumer in cost model\n    SbpNode* sbp_node_consumer = op_name2sbp_node.find(consumer_bn_group.first.first)->second;\n    // input blob name of logical blob in consumer\n    const std::string& ibn = consumer_bn_group.first.second;\n\n    // check is_mutable in consumer\n    OpNode* consumer = sbp_node_consumer->op_node_;\n    CHECK(!RequireSameSbp(consumer, ibn)) << \"Create a proxy for an unsuitable consumer!\\n\";\n\n    // Connect sbp proxy and consumer\n    sbp_proxy->PointTo(sbp_node_consumer);\n    // the sbp edge connecting proxy and consumer\n    SbpEdge* sbp_edge = sbp_node_consumer->FindEdgeWithNode(sbp_proxy);\n    sbp_edge->cost_.resize(sbp_proxy->parallel_candidates_.size());\n    int32_t consumer_sbp_size = sbp_node_consumer->sbp_sig_list_.size();\n\n    // look through sbp parallel set in proxy\n    for (int32_t sbp_id_producer = 0; sbp_id_producer < sbp_proxy->parallel_candidates_.size();\n         sbp_id_producer++) {\n      // initialization for copy cost\n      sbp_edge->cost_[sbp_id_producer].resize(consumer_sbp_size, 0);\n      // get sbp parallel set for a logical blob in proxy\n      BinarySet& parallel_candidate = sbp_proxy->parallel_candidates_[sbp_id_producer];\n\n      // look through sbp signatures in consumers\n      for (int32_t sbp_id_consumer = 0; sbp_id_consumer < consumer_sbp_size; sbp_id_consumer++) {\n        // get sbp parallel 
for a logical blob in consumer\n        const auto& consumer_sbp_bn_in_op2sbp_parallel =\n            sbp_node_consumer->sbp_sig_list_[sbp_id_consumer].bn_in_op2nd_sbp();\n        const NdSbp& sbp_consumer = consumer_sbp_bn_in_op2sbp_parallel.at(ibn);\n\n        if ((!parallel_candidate.CheckExistence(nd_sbp_universe_.find(sbp_consumer)->second))) {\n          sbp_edge->cost_[sbp_id_producer][sbp_id_consumer] = GetMaxVal<float>();\n        }\n      }\n    }\n  }\n}\n\n// Export list of possible combination of Sbp Parallels\nvoid SbpCollector::ProxySbpCandidate(const OpGraph& op_graph,\n                                     const HashMap<std::string, SbpNode*>& op_name2sbp_node,\n                                     SbpGraph& sbp_graph) {\n  // If needed, we can output the mapping from operator name to its proxy.\n  // HashMap<std::string, HashMap<LogicalBlobId, SbpNode*>>&\n  //     op_name2lbi2sbp_proxy;\n\n  // mapping from a logical blob id to index\n  HashMap<LogicalBlobId, int32_t> lbi2index;\n  // mapping from the index to producer, consumer and corresponding input blob name, possible sbp\n  // sets\n  std::vector<const OpNode*> index2producer;\n  std::vector<std::unordered_set<int32_t>> index2sbp_set;\n  // mapping from consumers and input blob names to an unordered_set of SBP Parallel.\n  std::vector<HashMap<std::pair<std::string, std::string>, BinarySet>> index2consumer_bn2sbp_set;\n\n  for (auto* consumer_sbp_node : sbp_graph.node_list_) {\n    auto* node = consumer_sbp_node->op_node_;\n\n    OperatorConf::OpTypeCase op_type_case = node->op().op_conf().op_type_case();\n    // If boxing is not supported, just skip this operator.\n    if (IsClassRegistered<int32_t, DisableInputBoxingGroup>(op_type_case)) { continue; }\n    for (const std::string& ibn : node->op().input_bns()) {\n      // Skip those blobs that enforce the same SBP.\n      if (RequireSameSbp(node, ibn)) {\n        // Enforcing same SBP. 
Cannot collect sbp from this blob.\n        continue;\n      }\n\n      const LogicalBlobId& lbi = node->op().BnInOp2Lbi(ibn);\n      const OpNode& producer = node->ProducerOpNode4Lbi(lbi);\n\n      // not building proxy for fixed operators\n      if (op_name2sbp_node.find(producer.op().op_name()) == op_name2sbp_node.end()) { continue; }\n      // decide the index of a logical blob description\n      const auto& iterator_lbi = lbi2index.find(lbi);\n      int32_t index = 0;\n      if (iterator_lbi == lbi2index.end()) {\n        index = lbi2index.size();\n        lbi2index[lbi] = index;\n        // map from lbi to the producer\n        index2producer.push_back(&producer);\n        // Initialize consumer_bns and the sbp sets\n        index2consumer_bn2sbp_set.resize(index + 1);\n        index2sbp_set.resize(index + 1);\n      } else {\n        index = iterator_lbi->second;\n      }\n\n      // a set to store the id of all possible SBP Parallel for a downstream op\n      // should filter out repeated SBP Parallel by pre-storing them into an unordered_set\n      BinarySet& nd_sbp_ids = index2consumer_bn2sbp_set[index][{node->op().op_name(), ibn}];\n      nd_sbp_ids.Initialize(nd_sbp_universe_.size());\n      // The union sbp set of all the consumers\n      std::unordered_set<int32_t>& union_nd_sbp_ids = index2sbp_set[index];\n      for (auto& sbp_sig : consumer_sbp_node->sbp_sig_list_) {\n        const auto& map = sbp_sig.bn_in_op2nd_sbp();\n        const auto& iter = map.find(ibn);\n        CHECK(iter != map.end()) << \"blob_name \" << ibn << \" not found in sbp signature\";\n        const NdSbp& consumer_sbp = iter->second;\n        // filter out repeated SBP\n        int32_t sbp_universe_id = nd_sbp_universe_.find(consumer_sbp)->second;\n        nd_sbp_ids.AddEntry(sbp_universe_id);\n        union_nd_sbp_ids.insert(sbp_universe_id);\n      }\n    }\n  }\n\n  // A set of binary set with broadcast only\n  // std::unordered_set<BinarySet, BinarySetHasher> parallel_candidates_initializer;\n  // BinarySet one_broadcast(nd_sbp_universe_.size());\n  // one_broadcast.AddEntry(0);\n  // parallel_candidates_initializer.insert(std::move(one_broadcast));\n\n  // Decide if we should insert a proxy for each logical blob\n  for (auto& lbi_index : lbi2index) {\n    int32_t index = lbi_index.second;\n    // Only insert proxy for those blobs with multiple downstream consumers.\n    if (index2consumer_bn2sbp_set[index].size() < 2) { continue; }\n    // Maximum number of possible sbp in the proxy\n    int32_t max_num_sbp_proxy =\n        std::min(max_num_sbp_proxy_, index2consumer_bn2sbp_set[index].size());\n    // producer in cost model\n    const std::string& producer_name = index2producer[index]->op().op_name();\n    SbpNode* sbp_node_producer = op_name2sbp_node.find(producer_name)->second;\n\n    const LogicalBlobId& lbi = lbi_index.first;\n    // store all the binary sets of SBP Parallel into an unordered_set.\n    // std::vector<BinarySet> parallel_candidates;\n\n    // generate sbp proxy\n    SbpNode* sbp_proxy = sbp_graph.GenerateNode();\n\n    // A: {B, S0, S1, S2, S3}, C: {B, S0}, D: {B, S0}\n    // {S1, S2, S3} show up only once, a parallel candidate should not contain two of them\n    std::vector<BinarySet> unique_sbp_groups;\n    FindUniqueSbpGroups(index2consumer_bn2sbp_set[index], index2sbp_set[index], accumulator_,\n                        bs_buffer_, unique_sbp_groups);\n\n    // Depth first search to collect Sbp Parallel information for the whole sbp set\n    DfsSbpSet(0, max_num_sbp_proxy, 
index2sbp_set[index], index2sbp_set[index].begin(),\n              index2consumer_bn2sbp_set[index], unique_sbp_groups, sbp_proxy->parallel_candidates_);\n\n    // Initialize computation cost\n    sbp_proxy->cost_.resize(sbp_proxy->parallel_candidates_.size(), 0);\n\n    // Transfer a logical blob from producer to a sbp proxy of this blob\n    sbp_node_producer->PointTo(sbp_proxy);\n\n    // Compute copy cost between producer and proxy\n    InitializeCopyCostFromNode2Proxy(sbp_proxy, lbi);\n\n    // Build connection and compute copy cost between proxy and consumers\n    InitializeCopyCostFromProxy2Consumer(sbp_proxy, index2consumer_bn2sbp_set[index],\n                                         op_name2sbp_node);\n\n    // Unloading\n    for (const auto& consumer_bn_group : index2consumer_bn2sbp_set[index]) {\n      // consumer in cost model\n      SbpNode* sbp_node_consumer = op_name2sbp_node.find(consumer_bn_group.first.first)->second;\n      // the sbp edge connecting producer and consumer\n      SbpEdge* edge_found = sbp_node_consumer->FindEdgeWithNode(sbp_node_producer);\n      // unload logical blob from sbp edges\n      edge_found->UnloadLbi(lbi);\n      // Clip this edge if it no longer carries any blob and has no wait time.\n      // Edges with positive wait time are saved rather than clipped.\n      // We did not clip edges before since we had transfer cost;\n      // now we clip the empty ones, which makes the topology simpler.\n      if (edge_found->EmptyLbi() && edge_found->wait_time_ <= 0.0\n          && edge_found->wait_time_ > -0.5) {\n        sbp_graph.ClipEdge(edge_found);\n      }\n    }\n  }\n}\n\n// Depth first search to collect Sbp Parallel information for different logical blob ids\nvoid SbpCollector::DfsSbpSet(\n    int32_t depth, int32_t max_depth, const std::unordered_set<int32_t>& sbp_sets,\n    const std::unordered_set<int32_t>::iterator& start_it,\n    const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,\n    const std::vector<BinarySet>& unique_sbp_groups, std::vector<BinarySet>& parallel_candidates) {\n  if (depth > 0) {\n    if (IfIntersectAll(consumer_bn2sbp_set, bs_buffer_)\n        && No2SbpFromSameUniqueGroup(bs_buffer_, unique_sbp_groups)) {\n      // store the binary set into the candidate list\n      parallel_candidates.push_back(bs_buffer_);\n    }\n  }\n  if (depth >= max_depth) { return; }\n\n  // go through the rest of the sbp parallel\n  std::unordered_set<int32_t>::iterator curr_it = start_it;\n  while (curr_it != sbp_sets.end()) {\n    // Take the value out\n    int32_t nd_sbp_num = *curr_it;\n    // Then move to the next iterator\n    ++curr_it;\n    if (accumulator_[nd_sbp_num] == 0) {\n      bs_buffer_.AddEntry(nd_sbp_num);\n      ++accumulator_[nd_sbp_num];\n      DfsSbpSet(depth + 1, max_depth, sbp_sets, curr_it, consumer_bn2sbp_set, unique_sbp_groups,\n                parallel_candidates);\n      bs_buffer_.DeleteEntry(nd_sbp_num);\n      --accumulator_[nd_sbp_num];\n    }\n  }\n}\n\n}  // namespace auto_parallel\n\n}  // namespace oneflow\n"
  },
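  {
    "path": "oneflow/core/auto_parallel/examples/dfs_sbp_set_sketch.cpp",
    "content": "// Illustrative sketch only: NOT part of the OneFlow source tree.\n// It mirrors the shape of SbpCollector::DfsSbpSet() in sbp_collector.cpp above: enumerate the\n// subsets of the union sbp set with at most max_depth elements, and keep a subset as a proxy\n// candidate iff it intersects the sbp set of every consumer. The real code uses BinarySet and\n// additionally rejects subsets holding two sbp from one unique group; std::set<int32_t> is a\n// simplified stand-in here.\n#include <cstdint>\n#include <iostream>\n#include <set>\n#include <vector>\n\nusing SbpSet = std::set<int32_t>;\n\n// Whether the candidate set intersects all the sbp sets of the consumers\nbool IfIntersectAll(const std::vector<SbpSet>& consumer_sbp_sets, const SbpSet& candidate) {\n  for (const SbpSet& consumer : consumer_sbp_sets) {\n    bool intersect = false;\n    for (int32_t sbp : candidate) {\n      if (consumer.count(sbp) > 0) {\n        intersect = true;\n        break;\n      }\n    }\n    if (!intersect) { return false; }\n  }\n  return true;\n}\n\n// Depth first search over the union sbp ids, collecting the intersecting subsets\nvoid DfsSbpSet(std::vector<int32_t>::const_iterator start_it,\n               const std::vector<int32_t>& union_sbp_ids, int32_t max_depth,\n               const std::vector<SbpSet>& consumer_sbp_sets, SbpSet& buffer,\n               std::vector<SbpSet>& candidates) {\n  if (!buffer.empty() && IfIntersectAll(consumer_sbp_sets, buffer)) {\n    candidates.push_back(buffer);\n  }\n  if (static_cast<int32_t>(buffer.size()) >= max_depth) { return; }\n  for (auto it = start_it; it != union_sbp_ids.end(); ++it) {\n    buffer.insert(*it);\n    DfsSbpSet(it + 1, union_sbp_ids, max_depth, consumer_sbp_sets, buffer, candidates);\n    buffer.erase(*it);\n  }\n}\n\nint main() {\n  // Consumer A accepts sbp ids {0, 1} and consumer B accepts {1, 2}, e.g. 0 = B, 1 = S0, 2 = S1.\n  const std::vector<SbpSet> consumer_sbp_sets = {{0, 1}, {1, 2}};\n  const std::vector<int32_t> union_sbp_ids = {0, 1, 2};\n  SbpSet buffer;\n  std::vector<SbpSet> candidates;\n  DfsSbpSet(union_sbp_ids.begin(), union_sbp_ids, /*max_depth=*/2, consumer_sbp_sets, buffer,\n            candidates);\n  // Prints {0 1}, {0 2}, {1} and {1 2}; {0} and {2} miss one of the consumers.\n  for (const SbpSet& candidate : candidates) {\n    std::cout << \"{ \";\n    for (int32_t sbp : candidate) { std::cout << sbp << \" \"; }\n    std::cout << \"}\\n\";\n  }\n  return 0;\n}\n"
  },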
  {
    "path": "oneflow/core/auto_parallel/sbp_collector.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef SBP_COLLECTOR_\n#define SBP_COLLECTOR_\n\n#include <unordered_map>\n#include <vector>\n#include <unordered_set>\n#include <utility>\n#include <type_traits>\n#include \"oneflow/core/auto_parallel/sbp_graph.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/job/sbp_parallel.pb.h\"\n#include \"oneflow/core/job/local_sig_infer_hint.h\"\n#include \"oneflow/core/job/job_builder.h\"\n// #include \"sbp_constructor.h\"\n#define DEBUG_COLLECTOR_\n\nnamespace oneflow {\n\nnamespace auto_parallel {\n\nclass SbpCollector {\n public:\n  SbpCollector();\n\n  ~SbpCollector() {}\n\n  // Collect all the possible Sbp Parallel from a SbpGraph\n  void CollectUniverse(const SbpGraph& sbp_graph);\n\n  // Export list of possible combination of Sbp Parallels\n  void ProxySbpCandidate(const OpGraph& op_graph,\n                         const HashMap<std::string, SbpNode*>& op_name2sbp_node,\n                         SbpGraph& sbp_graph);\n\n private:\n  // Stores all the possible NdSbp.\n  std::unordered_map<NdSbp, int32_t> nd_sbp_universe_;\n  // Relationship between id and Sbp Parallel\n  std::vector<NdSbp> id2nd_sbp_;\n  // Calculate number of downstream sbp\n  std::vector<int32_t> accumulator_;\n  // A binary set buffer to indicate sets of downstream sbp\n  BinarySet bs_buffer_;\n\n  // Collect all the possible Sbp Parallel from a NdSbpSignature\n  void CollectUniverse(const NdSbpSignature& nd_sbp_sig);\n  // Collect all the possible Sbp Parallel from a SbpNode\n  void CollectUniverse(const SbpNode* sbp_node);\n\n  // Initialize copy cost from producer to proxy of producer\n  void InitializeCopyCostFromNode2Proxy(const SbpNode* sbp_proxy, const LogicalBlobId& lbi) const;\n\n  // Initialize copy cost from proxy of producer to consumers\n  void InitializeCopyCostFromProxy2Consumer(\n      SbpNode* sbp_proxy,\n      const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,\n      const HashMap<std::string, SbpNode*>& op_name2sbp_node) const;\n\n  // Maximum number of possible sbp in the proxy\n  const unsigned long max_num_sbp_proxy_ = 3;\n\n  // Depth first search to collect Sbp Parallel information for the whole sbp set\n  void DfsSbpSet(int32_t depth, int32_t max_depth, const std::unordered_set<int32_t>& sbp_sets,\n                 const std::unordered_set<int32_t>::iterator& sbp_set_it,\n                 const HashMap<std::pair<std::string, std::string>, BinarySet>& consumer_bn2sbp_set,\n                 const std::vector<BinarySet>& unique_sbp_groups,\n                 std::vector<BinarySet>& parallel_candidates);\n};  // class SbpCollector\n\n}  // namespace auto_parallel\n\n}  // namespace oneflow\n\n#endif  // SBP_COLLECTOR_\n"
  },
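  {
    "path": "oneflow/core/auto_parallel/examples/binary_set_semantics_sketch.cpp",
    "content": "// Illustrative sketch only: NOT part of the OneFlow source tree.\n// A minimal re-implementation showing the BinarySet semantics that SbpCollector relies on,\n// inferred from the call sites in sbp_collector.cpp/.h above; the real class lives in\n// binary_set.h (not shown here).\n#include <cstdint>\n#include <iostream>\n#include <vector>\n\nclass MiniBinarySet {\n public:\n  MiniBinarySet() = default;\n  explicit MiniBinarySet(int32_t size) { Initialize(size); }\n  void Initialize(int32_t size) { bits_.assign(size, false); }\n  int32_t GetSizeOfSet() const { return static_cast<int32_t>(bits_.size()); }\n  void AddEntry(int32_t id) { bits_[id] = true; }\n  void DeleteEntry(int32_t id) { bits_[id] = false; }\n  bool CheckExistence(int32_t id) const { return bits_[id]; }\n  void Clear() { bits_.assign(bits_.size(), false); }\n  // Number of entries in the set\n  int32_t Total() const {\n    int32_t total = 0;\n    for (bool bit : bits_) { total += bit; }\n    return total;\n  }\n  // Whether the two sets share at least one entry\n  bool IfIntersect(const MiniBinarySet& other) const {\n    for (size_t i = 0; i < bits_.size(); i++) {\n      if (bits_[i] && other.bits_[i]) { return true; }\n    }\n    return false;\n  }\n  // Write the intersection of the two sets into out\n  void IntersectionTo(const MiniBinarySet& other, MiniBinarySet& out) const {\n    out.Initialize(GetSizeOfSet());\n    for (size_t i = 0; i < bits_.size(); i++) { out.bits_[i] = bits_[i] && other.bits_[i]; }\n  }\n  // Dump the ids of all entries into out, overwriting its previous content\n  void QuickOutput(std::vector<int32_t>& out) const {\n    out.clear();\n    for (size_t i = 0; i < bits_.size(); i++) {\n      if (bits_[i]) { out.push_back(static_cast<int32_t>(i)); }\n    }\n  }\n\n private:\n  std::vector<bool> bits_;\n};\n\nint main() {\n  MiniBinarySet a(4), b(4), intersection;\n  a.AddEntry(0);\n  a.AddEntry(2);\n  b.AddEntry(2);\n  b.AddEntry(3);\n  a.IntersectionTo(b, intersection);\n  std::cout << \"intersect: \" << a.IfIntersect(b) << \", total: \" << intersection.Total() << \"\\n\";\n  return 0;\n}\n"
  },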
  {
    "path": "oneflow/core/auto_parallel/sbp_constructor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/auto_parallel/sbp_constructor.h\"\n#include \"oneflow/core/auto_parallel/auto_memory.h\"\n#include \"oneflow/core/auto_parallel/sbp_node.h\"\n#include \"oneflow/core/auto_parallel/sbp_util.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/job/job_conf.pb.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/auto_parallel/sbp_collector.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\n\nnamespace auto_parallel {\n\nnamespace {\n\n// AMS, a.k.a. Applied Mathematics & Statistics, is a department of the Stony Brook University.\n// It contains 5 tracks: Computational & Applied Mathematics, Computational Biology,\n// Operation Research, Quantitative Finance, Statistics.\nAutoMemoryStrategy ams;\n\n// kMemoryRatio increase by this rate at each time.\nstatic const double kMemoryIncreaseRatio = 2.0;\n// The ceil of kMemoryRatio.\nstatic const double kMaxMemoryRatio = 22.0;\n// The floor of kMemoryRatio\nstatic const double kMinMemoryRatio = 0.1;\n// If the current memory > available memory * kImpossibleRatio,\n// then it is impossible to reduce the memory to an acceptable size\nstatic const double kImpossibleRatio = 1.4;\n\n// Pick from 5 fixed types of memory ratio.\ndouble UpdateMemoryRatio() {\n  switch (ams) {\n    case kAdaptiveAutoMemory:\n    case kDisableAutoMemory: return 0.0;\n    case kSlightAutoMemory: return 0.4;\n    case kModerateAutoMemory: return 4.3;\n    default: return 11.0;  // case kHeavyAutoMemory\n  }\n}\n\n}  // namespace\n\ndouble kMemoryRatio;\n\nMaybe<void> SbpConstructor::Init(const OpGraph& op_graph, Job* job /*Maybe not use*/) {\n  JUST(InitSbpGraph(op_graph, *job));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SbpConstructor::InitSbpGraph(const OpGraph& op_graph, const Job& job) {\n  // Update nccl_use_compute_stream\n  nccl_use_compute_stream_ = Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream();\n  ams = job.job_conf().enable_auto_memory();\n  kMemoryRatio = UpdateMemoryRatio();\n  // TODO: process local node\n  JUST(GenerateNodeAndEdge(op_graph, job));\n  JUST(FillSbpSignatureForOpNode(op_graph, job));\n  JUST(InitComputationCost(op_graph));\n  if (enable_trunk_algo_) { JUST(ApplyTrunkAlgo()); }\n  // Load logical blobs on all sbp edges.\n  LoadLbi2SbpEdge(op_graph);\n  // InitMemory() should be run before the sbp collector and after the ApplyTrunkAlgo() and\n  // LoadLbi2SbpEdge(op_graph).\n  InitAvailableMemory();\n  InitMemory(op_graph, &sbp_graph_, nccl_use_compute_stream_);\n  if (use_sbp_collector_) {\n    // Use sbp collector to create sbp proxy for 
nodes with multiple downstream operators.\n    SbpCollector sbp_collector;\n    sbp_collector.CollectUniverse(sbp_graph_);\n    // TODO: Init memory cost for proxy\n    sbp_collector.ProxySbpCandidate(op_graph, op_name2sbp_node_, sbp_graph_);\n  }\n\n  JUST(InitCopyAndMemoryCost(op_graph));\n  // We need to store the original cost and memory after the initialization (InitComputationCost(),\n  // InitMemory(), InitCopyAndMemoryCost()) and before the usage of them (InitWeightedCost())\n  sbp_graph_.StoreOriginMemory();\n  InitWeightedCost();\n  // TODO:  Set all the sbp signature id to be 0 for initialization.\n  //        Could revert it back to\n  // sbp_graph_.RandomSbpSignature(use_sbp_collector_);\n  //        after settling down the synchronization of sbp strategy.\n  sbp_graph_.SetDefaultSbpSig();\n  double ori_cost = sbp_graph_.ComputeCost();\n  LOG(INFO) << \"Initial cost: \" << ori_cost;\n  // If we do not prune those parallel cast ops, steal the initial strategy from user setting and\n  // semi-auto parallelism\n  if (!job.job_conf().enable_auto_parallel_ignore_user_sbp_config()) {\n    JUST(StealSbpSignatureFromOpNode(op_graph, job));\n    ori_cost = sbp_graph_.ComputeCost();\n    LOG(INFO) << \"OpGraph cost: \" << ori_cost;\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SbpConstructor::FindBestSbpSignature() {\n  double ori_cost = sbp_graph_.ComputeCost();\n  LOG(INFO) << \"Initial cost: \" << ori_cost;\n  int elimination_num = sbp_graph_.NodeAndEdgeEliminations();\n  LOG(INFO) << \"Elimination number: \" << elimination_num;\n  if (ori_cost > GetValidMaxCopyCost()) {\n    JUST(sbp_graph_.Find1Strategy4Greedy());\n    ori_cost = sbp_graph_.ComputeCost();\n    LOG(INFO) << \"Greedy cost: \" << ori_cost;\n  }\n\n  int32_t step = 1;\n  while (true) {\n    sbp_graph_.GreedyStrategy(/*nbh_num=*/4);\n    double curr_memory = sbp_graph_.GetMemory();\n    double total_weighted_cost = sbp_graph_.ComputeWeightedCost();\n    LOG(INFO) << \"The \" << step << \"-th try, memory ratio: \" << kMemoryRatio\n              << \", memory: \" << curr_memory << \", total cost: \" << total_weighted_cost\n              << \", time cost: \" << (total_weighted_cost - kMemoryRatio * curr_memory);\n    if (ams != AutoMemoryStrategy::kAdaptiveAutoMemory) { break; }\n    if (curr_memory < available_memory_ || kMemoryRatio >= kMaxMemoryRatio) { break; }\n    if (curr_memory > available_memory_ * kImpossibleRatio) {\n      kMemoryRatio = kMaxMemoryRatio;\n    } else {\n      kMemoryRatio =\n          std::max(std::min(kMaxMemoryRatio, kMemoryRatio * kMemoryIncreaseRatio), kMinMemoryRatio);\n    }\n    step++;\n    sbp_graph_.ReComputeWeightedCost();\n  }\n  sbp_graph_.FinalizeSbp();\n\n  double final_cost = sbp_graph_.ComputeCost();\n  LOG(INFO) << \"Final cost: \" << final_cost;\n  // TODO: Restart searching with another original random strategy\n  CHECK_LT_OR_RETURN(final_cost, GetValidMaxCopyCost())\n      << \"Failed! 
Auto parallel can't find a strategy with reasonable cost!\";\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SbpConstructor::DumpNdSbpSignatureForJob(const OpGraph& op_graph, Job* job) {\n  for (auto& op_conf : *job->mutable_net()->mutable_op()) {\n    const OpNode* node = op_graph.OpNode4OpName(op_conf.name());\n    SbpNode* sbp_node = op_name2sbp_node_[node->op().op_name()];\n    const NdSbpSignature& nd_sbp_sig = sbp_node->FinalSbpSignature();\n    // Update NdSbpSignature\n    (*job->mutable_job_parallel_view_conf()\n          ->mutable_op_name2nd_sbp_signature_conf())[node->op().op_name()]\n        .CopyFrom(nd_sbp_sig);\n    // If we have 1D SbpSignature Conf\n    if (node->parallel_desc().hierarchy()->NumAxes() == 1) {\n      // Update SbpSignature\n      SbpSignature sbp_signature;\n      NdSbpSignatureToSbpSignature(nd_sbp_sig, &sbp_signature);\n      (*job->mutable_job_parallel_view_conf()\n            ->mutable_op_name2sbp_signature_conf())[node->op().op_name()]\n          .CopyFrom(sbp_signature);\n    }\n    JUST(node->op().GetDumpNdSbpSignatureForOpConfFn()(nd_sbp_sig, &op_conf));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SbpConstructor::GenerateNodeAndEdge(const OpGraph& op_graph, const Job& job) {\n  JobParallelViewConf job_parallel_view_conf(job.job_parallel_view_conf());\n\n  // Collect op_node\n  std::vector<OpNode*> op_node_list;\n  op_graph.ForEachNode([&](OpNode* op_node) {\n    // TODO: support local op\n    bool is_local_conf = false;\n    {\n      const auto& op_name2is_local = job_parallel_view_conf.op_name2is_local_parallel_view();\n      const auto& iter = op_name2is_local.find(op_node->op().op_name());\n      if (iter != op_name2is_local.end()) { is_local_conf = iter->second; }\n    }\n    CHECK(is_local_conf == false) << \"Haven't dealt with local operators.\";\n    op_node_list.push_back(op_node);\n  });\n\n  // Decide the order to visit the op\n  std::vector<int32_t> order;\n  auto CompareOpName = [&](OpNode* a, OpNode* b) {\n    return a->op().op_name().compare(b->op().op_name()) > 0;\n  };\n  auto_parallel::DecideOrder(op_node_list, order, CompareOpName);\n  std::vector<int32_t> output_order;\n\n  // Create sbp nodes\n  for (int32_t i = 0; i < op_node_list.size(); i++) {\n    OpNode* op_node = op_node_list[order[i]];\n    // Generate sbp node in cost model and link it with corresponding op node\n    SbpNode* sbp_node = sbp_graph_.GenerateNode();\n    // Mapping from sbp_node to op_node\n    sbp_node->op_node_ = op_node;  // TODO: SetOpNode()\n    op_name2sbp_node_[op_node->op().op_name()] = sbp_node;\n  }\n  // Create sbp edges\n  for (int32_t i = 0; i < op_node_list.size(); i++) {\n    OpNode* op_node = op_node_list[order[i]];\n    // Get corresponding sbp node\n    SbpNode* sbp_node = op_name2sbp_node_[op_node->op().op_name()];\n    std::vector<OpNode*> output_node_list;\n    for (const auto* op_edge : op_node->out_edges()) {\n      output_node_list.push_back(op_edge->dst_node());\n    }\n    auto_parallel::DecideOrder(output_node_list, output_order, CompareOpName);\n    for (int32_t j : output_order) {\n      const auto& end_node_name = output_node_list[j]->op().op_name();\n      // Generate sbp edge in cost model\n      sbp_node->PointTo(op_name2sbp_node_[end_node_name]);\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SbpConstructor::FillSbpSignatureForOpNode(const OpGraph& op_graph, const Job& job) {\n  // TODO: use user sbp signature in JobParallelViewConf\n  // const JobParallelViewConf& 
job_parallel_view_conf(job.job_parallel_view_conf());\n  JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](OpNode* op_node) -> Maybe<void> {\n    HashMap<std::string, const BlobDesc*> ibn2blob_desc;\n    auto FindShape4Blobs = [&](const PbRpf<std::string>& bns) -> Maybe<void> {\n      for (const std::string& ibn : bns) {\n        const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn);\n        const BlobDesc* logical_blob_desc = &op_node->LogicalBlobDesc4Lbi(lbi);\n        ibn2blob_desc.emplace(ibn, logical_blob_desc);\n      }\n      return Maybe<void>::Ok();\n    };\n    JUST(FindShape4Blobs(op_node->op().input_bns()));\n    JUST(FindShape4Blobs(op_node->op().output_bns()));\n    // Get logical blob description\n    auto LogicalBlobDesc4Ibn = [&](const std::string& ibn) -> Maybe<const BlobDesc&> {\n      const auto& it = ibn2blob_desc.find(ibn);\n      if (it == ibn2blob_desc.end()) {\n        return Error::InvalidValueError()\n               << \"Cannot find corresponding blob description for input_blob_name : \" + ibn + \" in \"\n                      + op_node->op().op_name();\n      }\n      return *(it->second);\n    };\n    // Get all valid sbp_signatures\n    SbpNode* sbp_node = op_name2sbp_node_[op_node->op().op_name()];\n    JUST(op_node->op().GetValidNdSbpSignatureList(LogicalBlobDesc4Ibn, op_node->parallel_desc(),\n                                                  &sbp_node->sbp_sig_list_, /*check_output=*/true));\n    sbp_node->InitializeSbp();\n    return Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SbpConstructor::StealSbpSignatureFromOpNode(const OpGraph& op_graph, const Job& job) {\n  // Steal some strategy from original op graph\n  for (auto* sbp_node : sbp_graph_.node_list_) {\n    // sbp_collectors do not have op_node\n    if (sbp_node->op_node_) {\n      for (int32_t sbp_id = 0; sbp_id < sbp_node->sbp_sig_list_.size(); sbp_id++) {\n        if (*JUST(sbp_node->op_node_->op().nd_sbp_signature()) == sbp_node->sbp_sig_list_[sbp_id]) {\n          sbp_node->final_sbp_sig_id_ = sbp_id;\n          break;\n        }\n      }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SbpConstructor::InitComputationCost(const OpGraph& op_graph) {\n  // Compute computation cost for sbp nodes\n  JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](OpNode* op_node) -> Maybe<void> {\n    // get corresponding sbp node producer\n    SbpNode* sbp_node = op_name2sbp_node_[op_node->op().op_name()];\n    // get parallel description. 
Number of devices.\n    const ParallelDesc& parallel_desc = op_node->parallel_desc();\n\n    CHECK_EQ_OR_RETURN(sbp_node->cost_.size(), sbp_node->sbp_sig_list_.size());\n    auto LogicalBlobDesc4Bn = [&](const std::string& bn) -> const BlobDesc& {\n      const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(bn);\n      return op_node->LogicalBlobDesc4Lbi(lbi);\n    };\n    for (int32_t sbp_id = 0; sbp_id < sbp_node->sbp_sig_list_.size(); sbp_id++) {\n      double comp_cost = JUST(op_node->op().GetComputeComplexity(\n          &sbp_node->sbp_sig_list_[sbp_id], LogicalBlobDesc4Bn, parallel_desc));\n      if (comp_cost > GetValidMaxCopyCost()) {\n        sbp_node->cost_[sbp_id] = comp_cost;\n      } else {\n        sbp_node->cost_[sbp_id] =\n            cost_ratio_ * comp_cost\n            * JUST(op_node->op().GetInputOutputFastestTimeShape())->elem_cnt();\n      }\n    }\n    return Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\n// Init copy cost and memory for edges\nMaybe<void> SbpConstructor::InitCopyAndMemoryCost(const OpGraph& op_graph) {\n  bool nccl_not_use_compute_stream = !nccl_use_compute_stream_;\n  // Compute copy cost for sbp edges\n  op_graph.ForEachNode([&](OpNode* op_node) {\n    // get corresponding sbp node consumer\n    SbpNode* sbp_node_consumer = op_name2sbp_node_[op_node->op().op_name()];\n    // Initialize copy cost between two nodes\n    for (auto* sbp_edge : sbp_node_consumer->edges_in_) {\n      // producer sbp node\n      const auto* sbp_node_producer = sbp_edge->start_node_;\n      // skip it if proxy\n      if (!sbp_node_producer->op_node_) { continue; }\n      sbp_edge->cost_.resize(sbp_node_producer->sbp_sig_list_.size());\n      if (nccl_not_use_compute_stream) {\n        sbp_edge->memory_.resize(sbp_node_producer->sbp_sig_list_.size());\n      }\n      int32_t consumer_sbp_size = sbp_node_consumer->sbp_sig_list_.size();\n      // look through sbp signature in producer\n      for (int32_t i = 0; i < sbp_node_producer->sbp_sig_list_.size(); ++i) {\n        sbp_edge->cost_[i].resize(consumer_sbp_size, 0);\n        if (nccl_not_use_compute_stream) { sbp_edge->memory_[i].resize(consumer_sbp_size, 0); }\n      }\n    }\n    // Find all those cases with wait time\n    // Do not skip edges carrying no lbi\n    sbp_node_consumer->InitCopyAndMemoryCost(use_sbp_collector_, nccl_not_use_compute_stream);\n  });\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SbpConstructor::ApplyTrunkAlgo() {\n  // TODO: Remove this\n  auto OpNode2MutableOpCtrlDeps = JUST(GetMutableOpCtrlDeps(*op_graph_));\n  // Compute layer number for each node\n  int32_t max_min_layer = sbp_graph_.ComputeLayer(op_name2sbp_node_, *OpNode2MutableOpCtrlDeps);\n  // Accumulate cost on the trunk after initializing computation cost\n  sbp_graph_.FindTrunk(max_min_layer, op_name2sbp_node_);\n  return Maybe<void>::Ok();\n}\n\n// Load logical blob ids onto sbp edges\nvoid SbpConstructor::LoadLbi2SbpEdge(const OpGraph& op_graph) {\n  // Load logical blobs onto sbp edges\n\n  for (auto* sbp_node_consumer : sbp_graph_.node_list_) {\n    auto* op_node = sbp_node_consumer->op_node_;\n\n    // Loading logical blobs between two nodes\n    // look through input blobs\n    for (const std::string& ibn : op_node->op().input_bns()) {\n      // Each input blob has one source op node.\n      OpNode* producer = op_node->MutSrcNode4Ibn(ibn);\n      // producer sbp node\n      const auto* sbp_node_producer = op_name2sbp_node_[producer->op().op_name()];\n      // TODO: recode this\n      auto* edge_found = 
sbp_node_consumer->FindEdgeWithNode(sbp_node_producer);\n\n      CHECK(edge_found != nullptr) << \"SbpEdge not found while loading!\" << std::endl;\n\n      // Add copy cost for each blob\n      const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn);\n      edge_found->LoadLbi(lbi);\n    }\n  }\n}\n\nMaybe<void> SbpConstructor::CheckSbpAgreement(const Job& job) {\n  Job new_job;\n  new_job.CopyFrom(job);\n  OpGraph op_graph(new_job);\n  // Compare sbp in job\n  JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](OpNode* op_node) -> Maybe<void> {\n    const std::string& op_name = op_node->op().op_name();\n    const NdSbpSignature& auto_parallel_sbp =\n        NdSbpSignature(job.job_parallel_view_conf().op_name2nd_sbp_signature_conf().at(op_name));\n    const NdSbpSignature& new_sbp = op_node->nd_sbp_signature();\n    CHECK_EQ_OR_RETURN(auto_parallel_sbp.bn_in_op2nd_sbp_size(), new_sbp.bn_in_op2nd_sbp_size());\n    for (const auto& iter : auto_parallel_sbp.bn_in_op2nd_sbp()) {\n      const NdSbp& new_sbp_parallel = new_sbp.bn_in_op2nd_sbp().at(iter.first);\n      const NdSbp& auto_parallel_sbp = iter.second;\n      // According to the error message, we can find the op type in op_conf.proto by its type_id\n      // and locate the erroneous op type.\n      const std::string& error_msg =\n          \"Op: `\" + op_name + \"`(type_id: \" + std::to_string(op_node->op().op_conf().op_type_case())\n          + \") changed sbp from \" + NdSbpToString(auto_parallel_sbp) + \"(AutoParallel) to \"\n          + NdSbpToString(new_sbp_parallel) + \"(OpGraph) with blob_name: `\" + iter.first + \"`.\";\n      CHECK_OR_RETURN(new_sbp_parallel == auto_parallel_sbp) << error_msg;\n    }\n    return Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\n// TODO: delete this, this is for variable op only\nMaybe<HashMap<const OpNode*, HashSet<std::string>>> SbpConstructor::GetMutableOpCtrlDeps(\n    const OpGraph& op_graph) {\n  auto IsMutableConsumedLbi = [](const Operator& op, const LogicalBlobId& lbi) -> bool {\n    for (const std::string& bn : op.input_bns()) {\n      if (op.BnInOp2Lbi(bn) == lbi && op.InputBlobModifier4Ibn(bn).is_mutable()) { return true; }\n    }\n    return false;\n  };\n  const auto& IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable();\n  HashMap<const OpNode*, HashSet<std::string>> op_node2ctrl_in_op_names;\n  JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {\n    if (op_node->op().op_conf().has_variable_conf() == false) { return Maybe<void>::Ok(); }\n    if (op_node->out_edges().size() <= 1) { return Maybe<void>::Ok(); }\n    const Operator& variable_op = op_node->op();\n    const LogicalBlobId& variable_lbi = variable_op.BnInOp2Lbi(variable_op.SoleObn());\n    const OpNode* mutable_consumer = nullptr;\n    std::vector<const OperatorConf*> naive_consumers;\n    naive_consumers.reserve(op_node->out_edges().size());\n    for (OpEdge* edge : op_node->out_edges()) {\n      const auto& op_conf = edge->dst_node()->op().op_conf();\n      if (IsMutableConsumedLbi(edge->dst_node()->op(), variable_lbi)) {\n        CHECK_OR_RETURN(mutable_consumer == nullptr);\n        mutable_consumer = edge->dst_node();\n      } else {\n        naive_consumers.emplace_back(&op_conf);\n      }\n    }\n    if (mutable_consumer == nullptr) { return Maybe<void>::Ok(); }\n    for (const auto* fw_bw_op : naive_consumers) {\n      op_node2ctrl_in_op_names[mutable_consumer].insert(fw_bw_op->name());\n    }\n    return Maybe<void>::Ok();\n  }));\n  // Filter ctrl edges if all ctrl_in_op_names are 
reachable\n  HashMap<const OpNode*, HashSet<std::string>> filter_op_ctrl_deps;\n  for (const auto& pair : op_node2ctrl_in_op_names) {\n    const OpNode* op_node = pair.first;\n    for (const auto& fw_bw_op_name : pair.second) {\n      if (!IsReachable(fw_bw_op_name, op_node->op().op_name())) {\n        filter_op_ctrl_deps[op_node].insert(fw_bw_op_name);\n      }\n    }\n  }\n  return filter_op_ctrl_deps;\n}\n\nvoid SbpConstructor::InitAvailableMemory() {\n  size_t free = 0;\n  size_t total = 0;\n#ifdef WITH_CUDA\n  CudaCurrentDeviceGuard guard(GlobalProcessCtx::Rank());\n  OF_CUDA_CHECK(cudaMemGetInfo(&free, &total));\n#else\n  free = 1e13;   // 10T = 10,000G\n  total = 1e13;  // 10T = 10,000G\n  LOG(INFO) << \"We do not use CUDA in CPU mode; auto memory is unnecessary since all the SBPs are \"\n               \"Broadcast.\";\n#endif\n  // The estimated memory differs from the lower bound of the peak memory by the first ratio.\n  // The first ratio varies from -3% to 3.2% if not enabling nccl_use_compute_stream.\n  // It varies from 0.00313% to 0.5% if enabling nccl_use_compute_stream.\n  double first_ratio = 1.0;\n  if (nccl_use_compute_stream_) {\n    first_ratio = 1.01;\n  } else {\n    first_ratio = 1.04;\n  }\n  // The lower bound of the peak memory differs from the allocated memory by the second ratio.\n  // The second ratio varies from 0 to 2.65% if not using pipeline parallelism.\n  // It varies from 0 to 5.23% if using pipeline parallelism.\n  double second_ratio = 1.06;\n  // The occupied memory at this moment would be around 1114MB to 1240MB.\n  // When it gets to the training process, the occupied memory might drop by 162MB.\n  // But the key is that we start to allocate memory before the training process.\n  // Thus, this 162MB should not be added to the free memory.\n  // We still use \"available memory = free / ratio\" instead of \"free / ratio + 162MB\".\n  available_memory_ = int64_t(free / (first_ratio * second_ratio));\n  LOG(INFO) << \"Free memory: \" << free << \", total memory: \" << total\n            << \", available memory: \" << available_memory_;\n}\n\nvoid SbpConstructor::InitWeightedCost() {\n  for (auto& sbp_node : sbp_graph_.node_list_) {\n    sbp_node->ComputeWeightedCost();\n    for (auto& sbp_edge : sbp_node->edges_in_) { sbp_edge->ComputeWeightedCost(); }\n  }\n}\n\n// Print the graph with SBP in order\nvoid SbpConstructor::PrintSBPGraphDebugInfo() {\n  // sbp constructor information\n  std::cout << \"cost_ratio_:\" << cost_ratio_ << std::endl;\n  std::cout << \"wait_time_:\" << sbp_graph_.wait_time_ << std::endl;\n  std::cout << \"use_sbp_collector_:\" << use_sbp_collector_ << std::endl;\n  std::cout << \"Total auto parallel guessed memory: \" << sbp_graph_.GetMemory() << std::endl;\n  std::cout << \"Final memory ratio: \" << kMemoryRatio << std::endl;\n  // test debug\n  std::cout << \"Get Into Print Op Graph\" << std::endl;\n  // Collect op_node\n  std::vector<OpNode*> node_list;\n  for (const auto& op_name_sbp_node : op_name2sbp_node_) {\n    auto* op_node_ = op_name_sbp_node.second->op_node_;\n    if (op_node_) { node_list.push_back(op_node_); }\n  }\n\n  // test debug\n  std::cout << \"Deciding order\" << std::endl;\n  // Decide the order to visit the op\n  std::vector<int32_t> order;\n  auto_parallel::DecideOrder(node_list, order, [&](OpNode* a, OpNode* b) {\n    return a->op().op_name().compare(b->op().op_name()) > 0;\n  });\n  std::vector<int32_t> str_order;\n\n  // test debug\n  std::cout << \"Finish deciding order\" << std::endl;\n\n  for 
(int32_t i = 0; i < node_list.size(); i++) {\n    OpNode* op_node = node_list[order[i]];\n    std::cout << op_node->op().op_name() << \" (^_^):\" << std::endl;\n    // get corresponding sbp node\n    const auto& it = op_name2sbp_node_.find(op_node->op().op_name());\n    // Print debug information for sbp graph\n    CHECK(it != op_name2sbp_node_.end());\n    const SbpNode* sbp_node = it->second;\n    std::cout << \"Computation Cost: \" << sbp_node->weighted_cost_[sbp_node->final_sbp_sig_id_];\n    std::cout << \", Min Layer: \" << sbp_node->min_layer_ << \", Max Layer: \" << sbp_node->max_layer_\n              << \", Tributary Layer: \" << sbp_node->tributary_layer_\n              << \", in trunk: \" << sbp_node->on_trunk_\n              << \", Remain Cost: \" << sbp_node->acc_trunk_cost_ << std::endl;\n    // Sort before printing\n    const auto& op_input_bns = op_node->op().input_bns();\n    auto CompareString = [](const std::string& a, const std::string& b) {\n      return a.compare(b) > 0;\n    };\n    auto_parallel::DecideOrder(op_input_bns, str_order, CompareString);\n    const NdSbpSignature& sbp_signature = sbp_node->FinalSbpSignature();\n    // Print out SBP information for input operator\n    for (int32_t j : str_order) {\n      const auto& ibn = op_input_bns[j];\n      const auto& producer_node = op_node->SrcNode4Ibn(ibn);\n      std::cout << \"Pre Op:\" << producer_node.op().op_name() << \": \" << ibn;\n      const auto& this_sbp_parallel = sbp_signature.bn_in_op2nd_sbp().at(ibn);\n      std::cout << \", \" << NdSbpToString(this_sbp_parallel);\n      if (RequireSameSbp(op_node, ibn)) { std::cout << \", require same SBP\"; }\n      std::cout << \", \" << op_node->LogicalBlobDesc4Lbi(op_node->op().BnInOp2Lbi(ibn)).shape();\n      std::cout << std::endl;\n    }\n    // Sort before printing\n    const auto& op_output_bns = op_node->op().output_bns();\n    auto_parallel::DecideOrder(op_output_bns, str_order, CompareString);\n    // Print out SBP information for output blobs\n    for (int32_t j : str_order) {\n      const auto& obn = op_output_bns[j];\n      std::cout << \"Out Op:\" << obn;\n      const auto& this_sbp_parallel = sbp_signature.bn_in_op2nd_sbp().at(obn);\n      std::cout << \", \" << NdSbpToString(this_sbp_parallel);\n      std::cout << \", \" << op_node->LogicalBlobDesc4Lbi(op_node->op().BnInOp2Lbi(obn)).shape();\n      std::cout << std::endl;\n    }\n    std::cout << std::endl;\n  }\n}\n\n}  // namespace auto_parallel\n}  // namespace oneflow\n"
  },
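  {
    "path": "oneflow/core/auto_parallel/examples/adaptive_memory_ratio_sketch.cpp",
    "content": "// Illustrative sketch only: NOT part of the OneFlow source tree.\n// It isolates the adaptive loop of SbpConstructor::FindBestSbpSignature() in\n// sbp_constructor.cpp above: the weighted cost is time + memory_ratio * memory, and for the\n// kAdaptiveAutoMemory strategy the ratio grows until the strategy fits the available memory or\n// the ratio hits its ceiling. PickStrategy() is a hypothetical stand-in for the greedy search.\n#include <algorithm>\n#include <iostream>\n\nstatic const double kMemoryIncreaseRatio = 2.0;  // the ratio grows by this factor at each step\nstatic const double kMaxMemoryRatio = 22.0;      // the ceiling of the ratio\nstatic const double kMinMemoryRatio = 0.1;       // the floor of the ratio\nstatic const double kImpossibleRatio = 1.4;      // beyond this, assume the budget cannot be met\n\nstruct Strategy {\n  double time;\n  double memory;\n};\n\n// Hypothetical search result: a larger memory ratio trades time for memory.\nStrategy PickStrategy(double memory_ratio) {\n  if (memory_ratio < 1.0) { return {10.0, 100.0}; }\n  if (memory_ratio < 8.0) { return {14.0, 70.0}; }\n  return {25.0, 40.0};\n}\n\nint main() {\n  const double available_memory = 60.0;\n  double memory_ratio = kMinMemoryRatio;\n  int step = 1;\n  while (true) {\n    Strategy strategy = PickStrategy(memory_ratio);\n    double weighted_cost = strategy.time + memory_ratio * strategy.memory;\n    std::cout << \"step \" << step << \": ratio \" << memory_ratio << \", memory \" << strategy.memory\n              << \", weighted cost \" << weighted_cost << \"\\n\";\n    if (strategy.memory < available_memory || memory_ratio >= kMaxMemoryRatio) { break; }\n    if (strategy.memory > available_memory * kImpossibleRatio) {\n      // Far over budget: jump to the ceiling directly instead of ramping up slowly.\n      memory_ratio = kMaxMemoryRatio;\n    } else {\n      memory_ratio = std::max(std::min(kMaxMemoryRatio, memory_ratio * kMemoryIncreaseRatio),\n                              kMinMemoryRatio);\n    }\n    step++;\n  }\n  return 0;\n}\n"
  },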
  {
    "path": "oneflow/core/auto_parallel/sbp_constructor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_CONSTRUCTOR_H_\n#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_CONSTRUCTOR_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/auto_parallel/sbp_graph.h\"\n#include \"oneflow/core/job/global_for.h\"\n\nnamespace oneflow {\n\nclass OpGraph;\nclass Job;\n\nnamespace auto_parallel {\n\n// A constructor which will assemble the sbp_graph with the information from oneflow.\n// SbpGraph contains the algorithms for elimination and search which is mainly for the strategy\n// itself. Constructor mainly deal with the assemblage of each node, edge and the cost computation,\n// activation of functions.\nclass SbpConstructor final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SbpConstructor);\n  SbpConstructor() = delete;\n  SbpConstructor(const OpGraph& op_graph, Job* job)\n      : cost_ratio_(job->job_conf().auto_parallel_computation_cost_ratio()),\n        enable_trunk_algo_(job->job_conf().enable_auto_parallel_trunk_algo()),\n        use_sbp_collector_(!Singleton<ResourceDesc, ForSession>::Get()\n                                ->resource()\n                                .disable_group_boxing_by_dst_parallel()\n                           && job->job_conf().enable_auto_parallel_sbp_collector()),\n        op_graph_(&op_graph) {\n    sbp_graph_.SetWaitTime(job->job_conf().auto_parallel_wait_time());\n    CHECK_JUST(Init(op_graph, job));\n  }\n  ~SbpConstructor() = default;\n\n  Maybe<void> Init(const OpGraph& op_graph, Job* job);\n  Maybe<void> FindBestSbpSignature();\n  Maybe<void> DumpNdSbpSignatureForJob(const OpGraph& op_graph, Job* job);\n  // Re-build OpGraph and check all sbp is same between op_graph and job\n  Maybe<void> CheckSbpAgreement(const Job& job);\n  // Print the graph with SBP in order\n  void PrintSBPGraphDebugInfo();\n\n private:\n  Maybe<void> InitSbpGraph(const OpGraph& op_graph, const Job& job);\n  Maybe<void> GenerateNodeAndEdge(const OpGraph& op_graph, const Job& job);\n  Maybe<void> FillSbpSignatureForOpNode(const OpGraph& op_graph, const Job& job);\n  Maybe<void> StealSbpSignatureFromOpNode(const OpGraph& op_graph, const Job& job);\n  Maybe<void> InitComputationCost(const OpGraph& op_graph);\n  Maybe<void> InitCopyAndMemoryCost(const OpGraph& op_graph);\n  Maybe<void> ApplyTrunkAlgo();\n  Maybe<HashMap<const OpNode*, HashSet<std::string>>> GetMutableOpCtrlDeps(const OpGraph& op_graph);\n  void InitAvailableMemory();\n  void InitWeightedCost();\n  // Load logical blob ids onto sbp edges\n  void LoadLbi2SbpEdge(const OpGraph& op_graph);\n\n  double cost_ratio_;\n  bool enable_trunk_algo_;\n  bool use_sbp_collector_;\n  SbpGraph sbp_graph_;\n  const OpGraph* op_graph_;\n  HashMap<std::string, SbpNode*> op_name2sbp_node_;\n  bool nccl_use_compute_stream_;\n  int64_t available_memory_;\n};\n\n}  // namespace auto_parallel\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_AUTO_PARALLEL_SBP_CONSTRUCTOR_H_\n"
  },
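  {
    "path": "oneflow/core/auto_parallel/examples/node_elimination_dp_sketch.cpp",
    "content": "// Illustrative sketch only: NOT part of the OneFlow source tree.\n// It shows the dynamic programming behind node elimination in SbpEdge::SummarizeCost() in\n// sbp_edge.cpp below: after a middle node is absorbed into an edge, the merged cost for a\n// (start, end) sbp pair is minimized over the middle node's sbp choices, and the argmin is\n// remembered (as mid_node_sbp_sig_ is) so that FinalizeSbp() can recover the middle node's sbp.\n// The cost tables are made-up numbers.\n#include <iostream>\n#include <vector>\n\nint main() {\n  // A path start -> mid -> end with 2 start, 3 mid and 2 end sbp signatures.\n  const std::vector<double> mid_cost = {1.0, 5.0, 2.0};                  // mid node cost per sbp\n  const std::vector<std::vector<double>> first_edge = {{0.0, 3.0, 4.0},  // cost[start][mid]\n                                                       {2.0, 0.0, 1.0}};\n  const std::vector<std::vector<double>> second_edge = {{0.0, 9.0},      // cost[mid][end]\n                                                        {1.0, 0.0},\n                                                        {3.0, 1.0}};\n  std::vector<std::vector<double>> merged_cost(2, std::vector<double>(2));\n  std::vector<std::vector<int>> mid_node_sbp_sig(2, std::vector<int>(2));\n  for (int sbp_start = 0; sbp_start < 2; sbp_start++) {\n    for (int sbp_end = 0; sbp_end < 2; sbp_end++) {\n      for (int sbp_mid = 0; sbp_mid < 3; sbp_mid++) {\n        double cost =\n            mid_cost[sbp_mid] + first_edge[sbp_start][sbp_mid] + second_edge[sbp_mid][sbp_end];\n        // Compare and look for the minimum cost, remembering the minimizing middle sbp\n        if (sbp_mid == 0 || cost < merged_cost[sbp_start][sbp_end]) {\n          merged_cost[sbp_start][sbp_end] = cost;\n          mid_node_sbp_sig[sbp_start][sbp_end] = sbp_mid;\n        }\n      }\n      std::cout << \"start \" << sbp_start << \", end \" << sbp_end << \": cost \"\n                << merged_cost[sbp_start][sbp_end] << \" via mid sbp \"\n                << mid_node_sbp_sig[sbp_start][sbp_end] << \"\\n\";\n    }\n  }\n  return 0;\n}\n"
  },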
  {
    "path": "oneflow/core/auto_parallel/sbp_edge.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <assert.h>\n#include <algorithm>\n#include <unordered_set>\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/auto_parallel/sbp_edge.h\"\n#include \"oneflow/core/auto_parallel/sbp_node.h\"\n#include \"oneflow/core/auto_parallel/sbp_graph.h\"\n#include \"oneflow/core/auto_parallel/sbp_util.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n\nnamespace oneflow {\nnamespace auto_parallel {\n\nextern double kMemoryRatio;\n\n// function in cpp. Should be put in one file due to use of template\n// Otherwise we will need to declare specific template at the end of cpp file.\nSbpEdge::SbpEdge(SbpNode* start_node, SbpNode* mid_node, SbpNode* end_node, SbpEdge* first_edge,\n                 SbpEdge* second_edge)\n    : start_node_(start_node), mid_node_(mid_node), end_node_(end_node) {\n  // The first edge must between start_node and mid_node, but it could be\n  // start_node -> mid_node or mid_node -> start node\n  // Same for the second edge.\n  edge_list_.emplace_back(first_edge);\n  edge_list_.emplace_back(second_edge);\n};\n\n// Deconstructor\nSbpEdge::~SbpEdge() {\n  if (mid_node_ != nullptr) { delete mid_node_; }\n  for (auto& this_edge : edge_list_) { delete this_edge; }\n}\n\nvoid SbpEdge::SummarizeCost() {\n  // If any sub data structure is in the memory support,\n  // then this edge is in the memory support\n  if (mid_node_ && mid_node_->in_memory_support_) {\n    in_memory_support_ = true;\n  } else {\n    in_memory_support_ = std::any_of(edge_list_.begin(), edge_list_.end(), [](SbpEdge* sbp_edge) {\n      return sbp_edge->in_memory_support_;\n    });\n  }\n  // We would need to compute the memory for this elimination\n  int32_t start_node_sbp_size = start_node_->weighted_cost_.size();\n  if (in_memory_support_) { memory_.resize(start_node_sbp_size); }\n  weighted_cost_.resize(start_node_sbp_size);\n  // Copy cost and memory cost\n  if (mid_node_) {\n    // Buffer\n    int64_t memory_cost = 0;\n    int64_t min_memory_cost = 0;\n    int32_t min_sbp_mid = 0;\n    double weighted_cost = 0.0;\n    double min_weighted_cost = 0.0;\n    // Node elimination\n    mid_node_sbp_sig_.resize(start_node_sbp_size);\n    int32_t end_node_sbp_size = end_node_->weighted_cost_.size();\n    int32_t mid_node_sbp_size = mid_node_->weighted_cost_.size();\n    for (int32_t sbp_start = 0; sbp_start < start_node_sbp_size; sbp_start++) {\n      if (in_memory_support_) { memory_[sbp_start].resize(end_node_sbp_size); }\n      weighted_cost_[sbp_start].resize(end_node_sbp_size);\n      mid_node_sbp_sig_[sbp_start].resize(end_node_sbp_size);\n      for (int32_t sbp_end = 0; sbp_end < end_node_sbp_size; sbp_end++) {\n        for (int32_t sbp_mid = 0; sbp_mid < mid_node_sbp_size; sbp_mid++) {\n          // Add middle node cost\n          memory_cost = mid_node_->GetMemory(sbp_mid);\n   
       weighted_cost = mid_node_->weighted_cost_[sbp_mid];\n          // Add first edge cost\n          if (edge_list_[0]->end_node_ == mid_node_) {\n            int32_t edge_sbp_start =\n                start_node_->GetComponentSbpId(sbp_start, edge_list_[0]->start_node_);\n            memory_cost += edge_list_[0]->GetMemory(edge_sbp_start, sbp_mid);\n            weighted_cost += edge_list_[0]->weighted_cost_[edge_sbp_start][sbp_mid];\n          } else {\n            int32_t edge_sbp_start =\n                start_node_->GetComponentSbpId(sbp_start, edge_list_[0]->end_node_);\n            memory_cost += edge_list_[0]->GetMemory(sbp_mid, edge_sbp_start);\n            weighted_cost += edge_list_[0]->weighted_cost_[sbp_mid][edge_sbp_start];\n          }\n          // Add second edge cost\n          if (edge_list_[1]->start_node_ == mid_node_) {\n            int32_t edge_sbp_end = end_node_->GetComponentSbpId(sbp_end, edge_list_[1]->end_node_);\n            memory_cost += edge_list_[1]->GetMemory(sbp_mid, edge_sbp_end);\n            weighted_cost += edge_list_[1]->weighted_cost_[sbp_mid][edge_sbp_end];\n          } else {\n            int32_t edge_sbp_end =\n                end_node_->GetComponentSbpId(sbp_end, edge_list_[1]->start_node_);\n            memory_cost += edge_list_[1]->GetMemory(edge_sbp_end, sbp_mid);\n            weighted_cost += edge_list_[1]->weighted_cost_[edge_sbp_end][sbp_mid];\n          }\n\n          // Compare and look for the minimum cost\n          if (sbp_mid == 0 || weighted_cost < min_weighted_cost) {\n            min_sbp_mid = sbp_mid;\n            min_memory_cost = memory_cost;\n            min_weighted_cost = weighted_cost;\n          }\n        }\n        // Store the results of the dynamic programming for minimizing the weighted sum\n        if (in_memory_support_) { memory_[sbp_start][sbp_end] = min_memory_cost; }\n        weighted_cost_[sbp_start][sbp_end] = min_weighted_cost;\n        mid_node_sbp_sig_[sbp_start][sbp_end] = min_sbp_mid;\n      }\n    }\n  } else {\n    // Edge elimination\n    int32_t end_node_sbp_size = end_node_->weighted_cost_.size();\n    for (int32_t sbp_start = 0; sbp_start < weighted_cost_.size(); sbp_start++) {\n      if (in_memory_support_) { memory_[sbp_start].resize(end_node_sbp_size); }\n      weighted_cost_[sbp_start].resize(end_node_sbp_size);\n      for (int32_t sbp_end = 0; sbp_end < end_node_sbp_size; sbp_end++) {\n        int64_t memory_cost = 0;\n        double weighted_cost = 0.0;\n        for (int32_t edge_num = 0; edge_num < edge_list_.size(); edge_num++) {\n          // For normal edge elimination, instead of recomputation with different memory ratio\n          // Either (start_node_ == edge_list_[edge_num]->start_node_\n          // and end_node_ == edge_list_[edge_num]->end_node_) is true\n          // Or (start_node_ == edge_list_[edge_num]->end_node_ and\n          // end_node_ == edge_list_[edge_num]->start_node_) is true.\n          // At this moment, start_node_->component2merged_sig_id2component_sig_id_ is not\n          // initialized. 
As a result, if start_node_ != edge_list_[edge_num]->start_node_,\n          // IsComponent() would return false immediately.\n          if (start_node_->IsComponent(edge_list_[edge_num]->start_node_)) {\n            int32_t edge_sbp_start =\n                start_node_->GetComponentSbpId(sbp_start, edge_list_[edge_num]->start_node_);\n            int32_t edge_sbp_end =\n                end_node_->GetComponentSbpId(sbp_end, edge_list_[edge_num]->end_node_);\n            memory_cost += edge_list_[edge_num]->GetMemory(edge_sbp_start, edge_sbp_end);\n            weighted_cost += edge_list_[edge_num]->weighted_cost_[edge_sbp_start][edge_sbp_end];\n          } else {\n            // At this moment\n            // start_node_->IsComponent(edge_list_[edge_num]->end_node_)\n            // end_node_->IsComponent(edge_list_[edge_num]->start_node_)\n            int32_t edge_sbp_start =\n                start_node_->GetComponentSbpId(sbp_start, edge_list_[edge_num]->end_node_);\n            int32_t edge_sbp_end =\n                end_node_->GetComponentSbpId(sbp_end, edge_list_[edge_num]->start_node_);\n            memory_cost += edge_list_[edge_num]->GetMemory(edge_sbp_end, edge_sbp_start);\n            weighted_cost += edge_list_[edge_num]->weighted_cost_[edge_sbp_end][edge_sbp_start];\n          }\n        }\n        if (in_memory_support_) { memory_[sbp_start][sbp_end] = memory_cost; }\n        weighted_cost_[sbp_start][sbp_end] = weighted_cost;\n      }\n    }\n  }\n}\n\nvoid SbpEdge::DuplicateCost(\n    bool merged_node_is_start_node, bool duplicating_first_node,\n    const std::vector<std::pair<int32_t, int32_t>>& merged_sig_id2half_sig_id) {\n  const int32_t num_sig = merged_sig_id2half_sig_id.size();\n  std::vector<std::vector<double>> copy_cost;\n  std::vector<std::vector<int32_t>> temp_mid_node_sbp_sig;\n  std::vector<std::vector<int64_t>> temp_memory;\n  std::vector<std::vector<double>> weighted_cost;\n  if (merged_node_is_start_node) {\n    if (edge_list_.empty()) { copy_cost.resize(num_sig); }\n    if (mid_node_) { temp_mid_node_sbp_sig.resize(num_sig); }\n    weighted_cost.resize(num_sig);\n    if (in_memory_support_) { temp_memory.resize(num_sig); }\n    for (int32_t i = 0; i < num_sig; i++) {\n      const int32_t sig_idx = duplicating_first_node ? merged_sig_id2half_sig_id[i].first\n                                                     : merged_sig_id2half_sig_id[i].second;\n      if (edge_list_.empty()) { copy_cost[i] = cost_[sig_idx]; }\n      weighted_cost[i] = weighted_cost_[sig_idx];\n      if (mid_node_) { temp_mid_node_sbp_sig[i] = mid_node_sbp_sig_[sig_idx]; }\n      if (in_memory_support_) { temp_memory[i] = memory_[sig_idx]; }\n    }\n  } else {\n    const int32_t num_start_sig = weighted_cost_.size();\n    if (edge_list_.empty()) { copy_cost.resize(num_start_sig); }\n    weighted_cost.resize(num_start_sig);\n    if (mid_node_) { temp_mid_node_sbp_sig.resize(num_start_sig); }\n    if (in_memory_support_) { temp_memory.resize(num_start_sig); }\n    for (int32_t i = 0; i < num_start_sig; i++) {\n      if (edge_list_.empty()) { copy_cost[i].resize(num_sig); }\n      weighted_cost[i].resize(num_sig);\n      if (mid_node_) { temp_mid_node_sbp_sig[i].resize(num_sig); }\n      if (in_memory_support_) { temp_memory[i].resize(num_sig); }\n      for (int32_t j = 0; j < num_sig; j++) {\n        const int32_t sig_idx = duplicating_first_node ? 
merged_sig_id2half_sig_id[j].first\n                                                       : merged_sig_id2half_sig_id[j].second;\n        if (edge_list_.empty()) { copy_cost[i][j] = cost_[i][sig_idx]; }\n        weighted_cost[i][j] = weighted_cost_[i][sig_idx];\n        if (mid_node_) { temp_mid_node_sbp_sig[i][j] = mid_node_sbp_sig_[i][sig_idx]; }\n        if (in_memory_support_) { temp_memory[i][j] = memory_[i][sig_idx]; }\n      }\n    }\n  }\n\n  if (edge_list_.empty()) { cost_ = copy_cost; }\n  weighted_cost_ = weighted_cost;\n  if (mid_node_) { mid_node_sbp_sig_ = temp_mid_node_sbp_sig; }\n  if (in_memory_support_) { memory_ = temp_memory; }\n}\n\n// Compute the weighted sum of the time and memory cost\nvoid SbpEdge::ComputeWeightedCost() {\n  if (edge_list_.empty()) {\n    // If this edge does not contain any sub edges, it should have the original cost\n    weighted_cost_ = cost_;\n    if (in_memory_support_) {\n      for (int32_t i = 0; i < memory_.size(); i++) {\n        auto& memory_i = memory_[i];\n        auto& weighted_cost_i = weighted_cost_[i];\n        for (int32_t j = 0; j < memory_[i].size(); j++) {\n          weighted_cost_i[j] += kMemoryRatio * memory_i[j];\n        }\n      }\n    }\n  } else {\n    // Compute the weighted cost for sub components\n    for (auto& sbp_edge : edge_list_) { sbp_edge->ComputeWeightedCost(); }\n    if (mid_node_) { mid_node_->ComputeWeightedCost(); }\n    // Generate relationship if two vertices are merged nodes\n    // For example, we have 4 nodes: A, B, C, D\n    // and two edges: 1: A->B, 2: A->B\n    // We merge the two edges 1 and 2 into 3: A->B.\n    // Then we merge A and C into E and merge B and D into F.\n    // Now the edge 3: E->F has two sub edges: 1: A->B, 2: A->B,\n    // which tells us that the sub edges might have different vertices from the current edge.\n    start_node_->GenerateComponentRelationship();\n    end_node_->GenerateComponentRelationship();\n    // Re-compute the weighted cost\n    SummarizeCost();\n  }\n}\n\nvoid SbpEdge::FinalizeSbp() {\n  // Finalize Sbp for mid_node_\n  if (mid_node_) {\n    mid_node_->final_sbp_sig_id_ =\n        mid_node_sbp_sig_[start_node_->final_sbp_sig_id_][end_node_->final_sbp_sig_id_];\n    mid_node_->FinalizeSbp();\n  }\n  for (const auto& this_edge : edge_list_) { this_edge->FinalizeSbp(); }\n}\n\ndouble SbpEdge::GreedyStrategy() {\n  // Sbp combination of the minimum cost\n  int32_t min_sbp_start = start_node_->final_sbp_sig_id_,\n          min_sbp_end = end_node_->final_sbp_sig_id_;\n  // An unordered_map to evaluate cost between two edge nodes and other nodes.\n  std::unordered_map<int32_t, int32_t> node_list_id2nbh_id = {{start_node_->node_list_id_, 0},\n                                                              {end_node_->node_list_id_, 1}};\n  // pre-compute and store the current cost between end_node_ and outside.\n  std::vector<double> end_node_out_cost(end_node_->weighted_cost_.size());\n  for (int32_t sbp_end = 0; sbp_end < weighted_cost_[0].size(); sbp_end++) {\n    end_node_->final_sbp_sig_id_ = sbp_end;\n    end_node_out_cost[sbp_end] = end_node_->EvalOutNbhCost(node_list_id2nbh_id);\n  }\n  // pre-compute and store the current cost between start_node_ and outside.\n  std::vector<double> start_node_out_cost(start_node_->weighted_cost_.size());\n  for (int32_t sbp_start = 0; sbp_start < weighted_cost_.size(); sbp_start++) {\n    start_node_->final_sbp_sig_id_ = sbp_start;\n    start_node_out_cost[sbp_start] = start_node_->EvalOutNbhCost(node_list_id2nbh_id);\n  }\n
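  // The double loop below scans every (sbp_start, sbp_end) signature pair and keeps the\n  // pair that minimizes the pre-computed out-of-neighborhood cost of both nodes plus the\n  // weighted cost of this edge itself.\n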
  // Current Cost, Minimum Cost, Cost with original sbp\n  double curr_cost = 0.0;\n  double min_cost = start_node_out_cost[min_sbp_start] + end_node_out_cost[min_sbp_end]\n                    + weighted_cost_[min_sbp_start][min_sbp_end];\n  double original_cost = min_cost;\n\n  for (int32_t sbp_start = 0; sbp_start < weighted_cost_.size(); sbp_start++) {\n    for (int32_t sbp_end = 0; sbp_end < weighted_cost_[0].size(); sbp_end++) {\n      // compute Current Cost for Neighborhood of edge\n      end_node_->final_sbp_sig_id_ = sbp_end;\n      curr_cost = start_node_out_cost[sbp_start] + end_node_out_cost[sbp_end]\n                  + weighted_cost_[sbp_start][sbp_end];\n      // Find the minimum current cost\n      if (curr_cost < min_cost) {\n        min_cost = curr_cost;\n        min_sbp_start = sbp_start;\n        min_sbp_end = sbp_end;\n      }\n    }\n  }\n  start_node_->final_sbp_sig_id_ = min_sbp_start;\n  end_node_->final_sbp_sig_id_ = min_sbp_end;\n  return min_cost - original_cost;\n}\n\n// Get the minimum element in Cost\ndouble SbpEdge::GetMinWeightedCost() {\n  // Use the stored value if pre-computed.\n  if (kMemoryRatio == memory_ratio4min_weighted_cost_ && min_weighted_cost_ >= 0) {\n    return min_weighted_cost_;\n  }\n  // Check the size of Cost\n  CHECK(weighted_cost_.size() > 0) << \"Cost not initialized!\" << std::endl;\n  // Compute the min_cost for the corresponding memory ratio\n  min_weighted_cost_ = GetWeightedCost();\n  for (int32_t i = 0; i < weighted_cost_.size(); i++) {\n    for (int32_t j = 0; j < weighted_cost_[i].size(); j++) {\n      min_weighted_cost_ = std::min(min_weighted_cost_, GetWeightedCost(i, j));\n    }\n  }\n  // Store the current memory ratio\n  memory_ratio4min_weighted_cost_ = kMemoryRatio;\n  return min_weighted_cost_;\n}\n\n// Assemble copy cost\nvoid SbpEdge::InitCopyAndMemoryCost(const std::string& ibn, bool use_sbp_collector,\n                                    bool nccl_not_use_compute_stream) {\n  std::vector<int64_t> consumer_nd_sbp_sig2memory;\n  if (nccl_not_use_compute_stream) {\n    in_memory_support_ = true;\n    // Compute and store the memory for consumer\n    const auto& consumer_operator = end_node_->op_node_->op();\n    const auto& end_sbp_sig_list = end_node_->sbp_sig_list_;\n    consumer_nd_sbp_sig2memory.resize(end_sbp_sig_list.size(), 0);\n    const auto& lbi = consumer_operator.BnInOp2Lbi(ibn);\n    const auto& consumer_hierarchy =\n        *CHECK_JUST(consumer_operator.GetParallelDesc4BnInOp(ibn))->hierarchy();\n    const auto& logical_blob_desc = start_node_->op_node_->LogicalBlobDesc4Lbi(lbi);\n    HashMap<NdSbp, int64_t> consumer_nd_sbp2memory;\n    for (int32_t sbp_sig_id = 0; sbp_sig_id < end_sbp_sig_list.size(); sbp_sig_id++) {\n      const NdSbp& nd_sbp = end_sbp_sig_list[sbp_sig_id].bn_in_op2nd_sbp().at(ibn);\n      auto it = consumer_nd_sbp2memory.find(nd_sbp);\n      if (it == consumer_nd_sbp2memory.end()) {\n        // This computes the memory at rank 0, the largest one.\n        // We could be faster if we just computed the average memory.\n        it = consumer_nd_sbp2memory\n                 .insert({nd_sbp,\n                          MaxByteSize4BlobDescSbp(logical_blob_desc, nd_sbp, consumer_hierarchy)})\n                 .first;\n      }\n      consumer_nd_sbp_sig2memory[sbp_sig_id] += it->second;\n    }\n  }\n\n  // In this part, we assemble the cost from nodes to nodes.\n  if (start_node_->op_node_ && end_node_->op_node_) {\n    OpNode* consumer = end_node_->op_node_;\n\n    // Add copy cost for each blob\n
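    // For each (producer signature, consumer signature) pair, the loops below look up the\n    // blob's NdSbp on both sides and accumulate the transfer cost between them into\n    // cost_[sbp_id_producer][sbp_id_consumer].\n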
    const LogicalBlobId& lbi = consumer->op().BnInOp2Lbi(ibn);\n\n    // Check whether lbi is transferred by this edge\n    if (use_sbp_collector && !SearchLbi(lbi)) { return; }\n\n    OpNode* producer = start_node_->op_node_;\n    const std::string& producer_lbn = *CHECK_JUST(producer->op().obn4lbi(lbi));\n    const ParallelDesc& producer_parallel_desc =\n        *CHECK_JUST(producer->op().GetParallelDesc4BnInOp(producer_lbn));\n    const ParallelDesc& consumer_parallel_desc =\n        *CHECK_JUST(consumer->op().GetParallelDesc4BnInOp(ibn));\n\n    // Need to be careful: the logical blob description should be independent of the current\n    // SbpParallel. Use producer or op_node?\n    const BlobDesc& logical_blob_desc = producer->LogicalBlobDesc4Lbi(lbi);\n    const std::string& obn = *CHECK_JUST(producer->op().obn4lbi(lbi));\n    // If we are deciding whether we need the wait time, then make require_same_sbp true.\n    // B->S causes cudaEventSynchronize in the current implementation.\n    bool require_same_sbp = RequireSameSbp(consumer, ibn);\n    int32_t consumer_sbp_size = end_node_->sbp_sig_list_.size();\n    LazyMode::Guard enable_lazy_mode(true);\n\n    // look through sbp signature in producer\n    for (int32_t sbp_id_producer = 0; sbp_id_producer < start_node_->sbp_sig_list_.size();\n         sbp_id_producer++) {\n      // get sbp parallel for a logical blob in producer\n      const auto& producer_sbp_bn_in_op2sbp_parallel =\n          start_node_->sbp_sig_list_[sbp_id_producer].bn_in_op2nd_sbp();\n      const NdSbp& sbp_producer = producer_sbp_bn_in_op2sbp_parallel.at(obn);\n      auto& cost4sbp_id_producer = cost_[sbp_id_producer];\n\n      // look through sbp signature in consumer\n      for (int32_t sbp_id_consumer = 0; sbp_id_consumer < consumer_sbp_size; sbp_id_consumer++) {\n        // get sbp parallel for a logical blob in consumer\n        const auto& consumer_sbp_bn_in_op2sbp_parallel =\n            end_node_->sbp_sig_list_[sbp_id_consumer].bn_in_op2nd_sbp();\n        const NdSbp& sbp_consumer = consumer_sbp_bn_in_op2sbp_parallel.at(ibn);\n\n        // compute copy cost for a specific logical blob\n        double curr_edge_cost = CHECK_JUST(ComputeCopyCostWithMiddleNodes(\n            sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc,\n            consumer_parallel_desc, require_same_sbp));\n        if (curr_edge_cost < GetValidMaxCopyCost()) {\n          cost4sbp_id_producer[sbp_id_consumer] +=\n              CHECK_JUST(producer->op().GetOpTimeShape())->elem_cnt() * curr_edge_cost;\n        } else {\n          cost4sbp_id_producer[sbp_id_consumer] = curr_edge_cost;\n        }\n        // If nccl_use_compute_stream is enabled and a transfer occurs,\n        // the current code would create a non-reusable register to receive data.\n        if (nccl_not_use_compute_stream && curr_edge_cost > 0) {\n          memory_[sbp_id_producer][sbp_id_consumer] += consumer_nd_sbp_sig2memory[sbp_id_consumer];\n        }\n      }\n    }\n  }\n}\n\n// Assemble memory cost\nvoid SbpEdge::InitializeMemory(const HashMap<LogicalBlobId, int32_t>& lbi2id,\n                               const std::vector<int32_t>& id2count,\n                               const std::vector<int64_t>& producer_nd_sbp_sig2memory) {\n  const auto& consumer_operator = end_node_->op_node_->op();\n  const auto& end_sbp_sig_list = end_node_->sbp_sig_list_;\n  std::vector<int64_t> consumer_nd_sbp_sig2memory(end_sbp_sig_list.size(), 0);\n
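  // Note: the consumer-side table below is later compared against the producer-side table;\n  // only the positive increment max(0, consumer memory - producer memory) is stored in memory_.\n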
  // Compute and store the memory for consumer\n  for (const auto& ibn : consumer_operator.input_bns()) {\n    // Match the ibn to find the hierarchy\n    const auto& lbi = consumer_operator.BnInOp2Lbi(ibn);\n    if (SearchLbi(lbi) && id2count.at(lbi2id.at(lbi)) > 0) {\n      const auto& consumer_hierarchy =\n          *CHECK_JUST(consumer_operator.GetParallelDesc4BnInOp(ibn))->hierarchy();\n      const auto& logical_blob_desc = start_node_->op_node_->LogicalBlobDesc4Lbi(lbi);\n      HashMap<NdSbp, int64_t> consumer_nd_sbp2memory;\n      for (int32_t sbp_sig_id = 0; sbp_sig_id < end_sbp_sig_list.size(); sbp_sig_id++) {\n        const NdSbp& nd_sbp = end_sbp_sig_list[sbp_sig_id].bn_in_op2nd_sbp().at(ibn);\n        auto it = consumer_nd_sbp2memory.find(nd_sbp);\n        if (it == consumer_nd_sbp2memory.end()) {\n          // This computes the memory at rank 0, the largest one.\n          // We could be faster if we just computed the average memory.\n          it = consumer_nd_sbp2memory\n                   .insert({nd_sbp,\n                            MaxByteSize4BlobDescSbp(logical_blob_desc, nd_sbp, consumer_hierarchy)})\n                   .first;\n        }\n        consumer_nd_sbp_sig2memory[sbp_sig_id] += it->second;\n      }\n    }\n  }\n  // Avoid negative values for memory\n  // For example, B -> S might reduce memory but we still consider 0 memory increment instead of\n  // negative memory increment.\n  if (*std::max_element(consumer_nd_sbp_sig2memory.begin(), consumer_nd_sbp_sig2memory.end())\n      > *std::min_element(producer_nd_sbp_sig2memory.begin(), producer_nd_sbp_sig2memory.end())) {\n    in_memory_support_ = true;\n    memory_.resize(producer_nd_sbp_sig2memory.size());\n    int32_t consumer_sbp_sig_size = consumer_nd_sbp_sig2memory.size();\n    for (int32_t i = 0; i < producer_nd_sbp_sig2memory.size(); i++) {\n      auto& memory_i = memory_[i];\n      memory_i.resize(consumer_sbp_sig_size, 0);\n      for (int32_t j = 0; j < consumer_sbp_sig_size; j++) {\n        int64_t memory_difference = consumer_nd_sbp_sig2memory[j] - producer_nd_sbp_sig2memory[i];\n        // Only accept positive memory change\n        if (memory_difference > 0) { memory_i[j] = memory_difference; }\n      }\n    }\n  }\n}\n\n// Get the cut ratio\ndouble SbpEdge::GetCutRatio() const {\n  int32_t num = 0;\n  for (int32_t i = 0; i < weighted_cost_.size(); i++) {\n    for (int32_t j = 0; j < weighted_cost_[i].size(); j++) {\n      if (weighted_cost_[i][j] < GetValidMaxCopyCost()) { num++; }\n    }\n  }\n  return double(num) / double(weighted_cost_.size() * weighted_cost_[0].size());\n}\n\n// Find the cut ratio:\n// (#c < GetValidMaxCopyCost() in Cost) / (#c in Cost)\ndouble SbpEdge::FindCutRatio(int32_t threshold) const {\n  double cut_ratio = GetCutRatio();\n  // lift the cut ratio to 1 to filter out some improper couples to avoid unlimited merging\n  double n = weighted_cost_.size();\n  double m = weighted_cost_[0].size();\n  double num = cut_ratio * n * m;\n  cut_ratio += 0.16 * (n + m) / double(threshold);\n  if (num <= n * 2 || num <= m * 2 || (num <= threshold && cut_ratio < 0.51)) {\n    return cut_ratio;\n  } else {\n    return 1.0;\n  }\n}\n\n// load a logical blob\nvoid SbpEdge::LoadLbi(const LogicalBlobId& lbi) { carry_lbis_.insert(lbi); }\n\n// check the existence of a logical blob\nbool SbpEdge::SearchLbi(const LogicalBlobId& lbi) const {\n  return carry_lbis_.find(lbi) != carry_lbis_.end();\n}\n\n// unload a logical blob\nvoid SbpEdge::UnloadLbi(const LogicalBlobId& lbi) {\n  if (carry_lbis_.erase(lbi) == 0) { std::cout << \"Unload an empty lbi!\" << std::endl; }\n}\n
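\n// Example (illustrative): for a 4x4 weighted_cost_ in which 8 entries are below\n// GetValidMaxCopyCost(), GetCutRatio() above returns 8 / 16 = 0.5, and FindCutRatio() adds a\n// size penalty of 0.16 * (4 + 4) / threshold before deciding whether to report 1.0 instead.\n\n// 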
Not carrying any blob\nbool SbpEdge::EmptyLbi() const { return carry_lbis_.empty(); }\n\n}  // namespace auto_parallel\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/auto_parallel/sbp_edge.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_EDGE_H_\n#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_EDGE_H_\n\n#include <assert.h>\n#include <algorithm>\n#include <unordered_set>\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/auto_parallel/sbp_node.h\"\n#include \"oneflow/core/auto_parallel/sbp_util.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n\nnamespace oneflow {\nnamespace auto_parallel {\n\n// An edge structure to deal with the SBP strategy.\n// Please see SbpGraph for the whole algorithm and introduction.\nclass SbpEdge final {\n  /* There are 3 types of edges:\n   * 1. start_node_ -> end_node_\n   *      Nothing special\n   * 2. Multiple start_node_ -> end_node_\n   *      edge_list_ will store all the edges which goes from start_node_ to end_node_\n   * 3. start_node_ -> mid_node_ -> end_node_\n   *      It will pass by a middle node.\n   */\n public:\n  // Constructor for type 1 & 2\n  SbpEdge(SbpNode* start_node, SbpNode* end_node) : start_node_(start_node), end_node_(end_node) {\n    mid_node_ = nullptr;\n  }\n  // Constructor for type 3\n  SbpEdge(SbpNode* start_node, SbpNode* mid_node, SbpNode* end_node, SbpEdge* first_edge,\n          SbpEdge* second_edge);\n\n  // Deconstructor\n  ~SbpEdge();\n\n  OF_DISALLOW_COPY_AND_MOVE(SbpEdge);\n  bool operator==(const SbpEdge& other) { return this == &other; }\n\n  // Update copy cost for type 2 and 3\n  void SummarizeCost();\n  // Duplicate Cost. Designed for merging two nodes.\n  void DuplicateCost(bool merged_node_is_start_node, bool duplicating_first_node,\n                     const std::vector<std::pair<int32_t, int32_t>>& merged_sig_id2half_sig_id);\n  // Compute the weighted sum of the time and memory cost\n  void ComputeWeightedCost();\n  // Determine Final SbpSignature for attachment of this edge\n  void FinalizeSbp();\n  // Use Greedy Strategy to pick the sbp signature with minimum cost for this\n  // edge. You should have an initial strategy before running this. 
  double GreedyStrategy();\n\n  // load a logical blob\n  void LoadLbi(const LogicalBlobId& lbi);\n\n  // check the existence of a logical blob\n  bool SearchLbi(const LogicalBlobId& lbi) const;\n\n  // unload a logical blob\n  void UnloadLbi(const LogicalBlobId& lbi);\n\n  // Not carrying any blob\n  bool EmptyLbi() const;\n\n  // Get the minimum element in Cost\n  double GetMinWeightedCost();\n\n  // Assemble copy and memory cost\n  void InitCopyAndMemoryCost(const std::string& ibn, bool use_sbp_collector,\n                             bool nccl_not_use_compute_stream);\n  // Assemble memory cost\n  void InitializeMemory(const HashMap<LogicalBlobId, int32_t>& lbi2id,\n                        const std::vector<int32_t>& id2count,\n                        const std::vector<int64_t>& producer_nd_sbp_sig2memory);\n\n  // Find the cut ratio:\n  // (#c < GetValidMaxCopyCost() in Cost) / (#c in Cost)\n  // But we would lift the cut ratio to 1 to filter out some improper couples\n  double FindCutRatio(int32_t threshold) const;\n  // Get the cut ratio\n  double GetCutRatio() const;\n\n  // Constant getter\n  SbpNode* GetEndNode() const { return end_node_; }\n  int64_t GetMemory(int32_t i, int32_t j) const { return in_memory_support_ ? memory_[i][j] : 0; }\n  // Get the current memory with the current sbp signature index\n  int64_t GetMemory() const {\n    return GetMemory(start_node_->final_sbp_sig_id_, end_node_->final_sbp_sig_id_);\n  }\n  double GetWeightedCost(int32_t i, int32_t j) const { return weighted_cost_[i][j]; }\n  // Get the current weighted cost with the current sbp signature index\n  double GetWeightedCost() const {\n    return GetWeightedCost(start_node_->final_sbp_sig_id_, end_node_->final_sbp_sig_id_);\n  }\n\n private:\n  friend class SbpNode;\n  friend class SbpGraph;\n  friend class SbpCollector;\n  friend class SbpConstructor;\n\n  // The edge points from start_node_ to end_node_\n  // It will have a middle node if and only if type 3\n  SbpNode *start_node_, *mid_node_, *end_node_;\n  // Cost[sbp_i][sbp_j] is the total cost from start_node_ with sbp_i to end_node_\n  // with sbp_j\n  std::vector<std::vector<double>> cost_;\n  // SbpSignature for mid_node_ with corresponding Cost if type 3, empty otherwise\n  std::vector<std::vector<int32_t>> mid_node_sbp_sig_;\n  // Contained edge list:\n  // empty if type 1,\n  // parallel edges if type 2,\n  // successive edges if type 3\n  // the edge list might have reverse direction:\n  // example 1: a type 3 edge_list_ contains two edges:\n  //        mid_node_ -> start_node_, mid_node_ -> end_node_;\n  // example 2: a type 2 edge_list_ contains three edges:\n  //        start_node_ -> end_node_, end_node_ -> start_node_, start_node_ -> end_node_;\n  std::vector<SbpEdge*> edge_list_;\n
  // Time waiting for other GPUs (pthread_cond_wait).\n  double wait_time_ = -1.0;\n\n  // a set of ids of logical blobs carried/transferred on this sbp edge\n  std::unordered_set<LogicalBlobId> carry_lbis_;\n\n  // The minimum and maximum cost would not be changed by eliminations, which generate new edges.\n  // They would also not be changed by node merging, which only performs cost copy for the\n  // expanding dimensions.\n  // Minimum cost in the 2D array Cost.\n  // Would be initialized after GetMinWeightedCost();\n  // Only used in the final graph.\n  // Such a pre-store-and-access process saves a lot of time.\n  // GPT-2 has 1178 stores and 14053 reads.\n  // BERT has 1464 stores and 17633 reads.\n  double min_weighted_cost_ = -1.0;\n  // If memory is considered, each GetMinWeightedCost() call is tied to a memory ratio.\n  // Reuse the stored value when called with the same memory ratio.\n  double memory_ratio4min_weighted_cost_ = -1.0;\n\n  // The produced blob belongs to the support of the total memory\n  bool in_memory_support_ = false;\n  // The consumed memory for different sbp strategies\n  std::vector<std::vector<int64_t>> memory_;\n  // The weighted sum of time cost and memory cost\n  std::vector<std::vector<double>> weighted_cost_;\n};\n\n}  // namespace auto_parallel\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_AUTO_PARALLEL_SBP_EDGE_H_\n"
  },
  {
    "path": "oneflow/core/auto_parallel/sbp_graph.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <algorithm>\n#include <unordered_map>\n#include \"oneflow/core/auto_parallel/binary_set.h\"\n#include \"oneflow/core/auto_parallel/sbp_graph.h\"\n#include \"oneflow/core/auto_parallel/sbp_edge.h\"\n#include \"oneflow/core/auto_parallel/sbp_node.h\"\n#include \"oneflow/core/auto_parallel/algorithm_util.h\"\n\nnamespace oneflow {\nnamespace auto_parallel {\n\n// function in cpp. Should be put in one file due to use of template\n// Otherwise we will need to declare specific template at the end of cpp file.\n\nnamespace {\nstatic const int32_t kMinNodeInGraphForMerging = 4;\n}  // anonymous namespace\n\n// Generate a node\nSbpNode* SbpGraph::GenerateNode() {\n  SbpNode* this_node = new SbpNode();\n  node_list_.emplace_back(this_node);\n  this_node->node_list_id_ = node_list_.size() - 1;\n  return this_node;\n}\n\nvoid SbpGraph::RemoveFromNodeList(SbpNode* this_node) {\n  if (this_node->node_list_id_ < 0) { return; }\n  node_list_.back()->node_list_id_ = this_node->node_list_id_;\n  RemoveFrom<SbpNode*>(node_list_, this_node->node_list_id_);\n  this_node->node_list_id_ = -1;\n}\n\nSbpGraph::~SbpGraph() {\n  for (auto this_node : node_list_) { delete this_node; }\n  node_list_.clear();\n}\n\nvoid SbpGraph::RandomSbpSignature(bool use_sbp_collector) const {\n  for (const auto& this_node : node_list_) {\n    if (this_node->sbp_sig_list_.size() > 0) {\n      this_node->final_sbp_sig_id_ = rand() % this_node->sbp_sig_list_.size();\n    } else {\n      // It must be a proxy when this_node->sbp_sig_list_.size() == 0\n      this_node->final_sbp_sig_id_ = rand() % this_node->parallel_candidates_.size();\n    }\n  }\n};\n\nvoid SbpGraph::SetDefaultSbpSig() const {\n  for (const auto& this_node : node_list_) { this_node->final_sbp_sig_id_ = 0; }\n};\n\nvoid SbpGraph::StoreOriginMemory() {\n  // We do not need to store the origin cost and memory for edges\n  // Because the origin cost and memory is the current cost and memory for a bare edge.\n  // For nodes, we need to do so because child elimination would attach the child cost and memory to\n  // the current cost and memory.\n  for (auto& this_node : node_list_) {\n    this_node->origin_cost_ = this_node->cost_;\n    this_node->origin_memory_ = this_node->memory_;\n  }\n}\n\ndouble SbpGraph::ComputeCost() const {\n  // Overall cost under current strategy\n  double graph_cost_ = 0;\n  for (const auto& this_node : node_list_) {\n    int32_t this_id = this_node->final_sbp_sig_id_;\n\n    graph_cost_ += this_node->weighted_cost_[this_id];\n    for (const auto& edge_out : this_node->edges_out_) {\n      graph_cost_ += edge_out->weighted_cost_[this_id][edge_out->end_node_->final_sbp_sig_id_];\n    }\n  }\n  return graph_cost_;\n}\n\ndouble SbpGraph::ComputeWeightedCost() const {\n  // Overall cost under current strategy\n  double graph_cost_ = 0;\n  for (const auto& this_node : node_list_) {\n    int32_t this_id = 
\ndouble SbpGraph::ComputeWeightedCost() const {\n  // Overall cost under current strategy\n  double graph_cost = 0;\n  for (const auto& this_node : node_list_) {\n    int32_t this_id = this_node->final_sbp_sig_id_;\n\n    graph_cost += this_node->weighted_cost_[this_id];\n    for (const auto& edge_out : this_node->edges_out_) {\n      graph_cost += edge_out->weighted_cost_[this_id][edge_out->end_node_->final_sbp_sig_id_];\n    }\n  }\n  return graph_cost;\n}\n\n// Re-compute weighted cost\nvoid SbpGraph::ReComputeWeightedCost() {\n  for (const auto& this_node : node_list_) {\n    this_node->ComputeWeightedCost();\n    for (const auto& edge_out : this_node->edges_out_) { edge_out->ComputeWeightedCost(); }\n  }\n}\n\nint64_t SbpGraph::GetMemory() const {\n  // Overall memory under current strategy\n  int64_t total_memory = 0;\n  for (const auto& this_node : node_list_) {\n    total_memory += this_node->GetMemory();\n    for (const auto& edge_out : this_node->edges_out_) { total_memory += edge_out->GetMemory(); }\n  }\n  return total_memory;\n}\n\nint32_t SbpGraph::NodeElimination(SbpNode* this_node) {\n  if (this_node->edges_in_.size() + this_node->edges_out_.size() == 2) {\n    std::vector<SbpNode*> two_nodes;\n    for (const auto& one_edge : this_node->edges_in_) two_nodes.emplace_back(one_edge->start_node_);\n    for (const auto& one_edge : this_node->edges_out_) two_nodes.emplace_back(one_edge->end_node_);\n\n    // If a node is pointing to itself, which could happen when shrinking from a circle\n    if (two_nodes[0] == two_nodes[1]) {\n      int32_t elimination_number = 0;\n      if (this_node->edges_out_.empty()) {\n        elimination_number += EdgeElimination(two_nodes[0]);\n      } else {\n        elimination_number += EdgeElimination(this_node);\n      }\n\n      elimination_number += ChildElimination(this_node);\n      return elimination_number;\n    }\n\n    std::vector<SbpEdge*> two_edges(this_node->edges_in_);\n    two_edges.insert(two_edges.end(), this_node->edges_out_.begin(), this_node->edges_out_.end());\n\n    int32_t edges_in_size = this_node->edges_in_.size();\n\n    SbpEdge* e = new SbpEdge(two_nodes[0], this_node, two_nodes[1], two_edges[0], two_edges[1]);\n    e->SummarizeCost();\n    // Check and remove each old in-edge from its start node's edges_out_\n    for (int32_t i = 0; i < edges_in_size; i++) {\n      CheckAndRemoveFrom<SbpEdge*>(two_nodes[i]->edges_out_, two_edges[i]);\n    }\n    // Check and remove each old out-edge from its end node's edges_in_\n    for (int32_t i = edges_in_size; i < 2; i++) {\n      CheckAndRemoveFrom<SbpEdge*>(two_nodes[i]->edges_in_, two_edges[i]);\n    }\n    // Let e take control of edge_list_ completely by disconnecting mid_node_\n    e->mid_node_->edges_out_.clear();\n    e->mid_node_->edges_in_.clear();\n\n    // Insert new compound edge into graph\n    two_nodes[0]->edges_out_.emplace_back(e);\n    two_nodes[1]->edges_in_.emplace_back(e);\n\n    // eliminate the node from graph by swapping with the last element and\n    // popping\n    RemoveFromNodeList(this_node);\n\n    // successfully eliminate this node\n    return 1;\n  }\n  // cannot eliminate this node\n  return 0;\n}\n\nint32_t SbpGraph::NodeAndEdgeEliminations() {\n  // Total elimination number\n  int32_t total_elimination_num = 0;\n  int32_t elimination_num = 1;\n  // repeat these kinds of elimination until stuck\n  while (elimination_num > 0) {\n    elimination_num = 0;\n    for (int32_t i = node_list_.size() - 1; i >= 0; i--) {\n      elimination_num += NodeElimination(node_list_[i]);\n    }\n\n    for (int32_t i = node_list_.size() - 1; i >= 0; i--) {\n      elimination_num += EdgeElimination(node_list_[i]);\n    }\n\n    for (int32_t i = node_list_.size() - 1; i >= 0; i--) {\n
      elimination_num += ChildElimination(node_list_[i]);\n    }\n\n    if (elimination_num == 0 && node_list_.size() > 2) {\n      elimination_num += PickAndMerge();\n      for (int32_t i = node_list_.size() - 1; i >= 0; i--) {\n        elimination_num += EdgeElimination(node_list_[i]);\n      }\n    }\n\n    total_elimination_num += elimination_num;\n  }\n\n  return total_elimination_num;\n}\n\nint32_t SbpGraph::EdgeElimination(SbpNode* this_node) const {\n  // Remove all edges with (start_node -> end_node) from edges_in_ of end_node\n  auto RemoveFromEdgesIn = [](SbpNode* start_node, SbpNode* end_node) -> void {\n    for (int32_t i = end_node->edges_in_.size() - 1; i >= 0; i--) {\n      if (start_node == end_node->edges_in_[i]->start_node_) {\n        RemoveFrom<SbpEdge*>(end_node->edges_in_, i);\n      }\n    }\n  };\n  auto LookForParallelEdge = [](SbpEdge*& e, SbpNode* start_node, SbpNode* end_node,\n                                bool if_reverse, int32_t stop_sign) -> int32_t {\n    // Eliminate edges with the given start node and end node in\n    // start_node->edges_out_, from index stop_sign to the end.\n    // start_node->edges_out_[stop_sign] is not included and needs special treatment\n    // after this process.\n    int32_t elimination_num = 0;\n    for (int32_t j = start_node->edges_out_.size() - 1; j > stop_sign; j--) {\n      if (end_node == start_node->edges_out_[j]->end_node_) {\n        if (!e) {\n          if (if_reverse) {\n            e = new SbpEdge(end_node, start_node);\n          } else {\n            e = new SbpEdge(start_node, end_node);\n          }\n        }\n        // edge elimination\n        e->edge_list_.emplace_back(start_node->edges_out_[j]);\n        elimination_num++;\n        RemoveFrom<SbpEdge*>(start_node->edges_out_, j);\n      }\n    }\n    return elimination_num;\n  };\n\n  int32_t elimination_num = 0;\n\n  for (int32_t i = 0; i < this_node->edges_out_.size(); i++) {\n    SbpEdge* e = nullptr;\n    // Find and delete Parallel Edges from edges_out_\n    elimination_num += LookForParallelEdge(e, this_node, this_node->edges_out_[i]->end_node_,\n                                           /*if_reverse=*/false, i);\n    elimination_num += LookForParallelEdge(e, this_node->edges_out_[i]->end_node_, this_node,\n                                           /*if_reverse=*/true, /*stop_sign=*/-1);\n    if (e) {\n      // Delete Parallel Edges from edges_in_\n      RemoveFromEdgesIn(this_node, e->end_node_);\n      RemoveFromEdgesIn(e->end_node_, this_node);\n      // Add the compound edge\n      e->edge_list_.emplace_back(this_node->edges_out_[i]);\n      this_node->edges_out_[i] = e;\n      e->SummarizeCost();\n      e->end_node_->edges_in_.emplace_back(e);\n    }\n  }\n  return elimination_num;\n}\n\nint32_t SbpGraph::ChildElimination(SbpNode* this_node) {\n  if (this_node->EliminateItselfAsChild()) {\n    // eliminate this node from global node list\n    RemoveFromNodeList(this_node);\n    // successfully eliminate this node\n    return 1;\n  } else {\n    // cannot eliminate this node\n    return 0;\n  }\n}\n\n// Merge two nodes\nint32_t SbpGraph::NodeMerging(SbpNode* first, SbpNode* second) {\n  SbpNode* new_node = new SbpNode(first, second);\n\n  // Adjust node_list_\n  RemoveFromNodeList(first);\n  RemoveFromNodeList(second);\n\n  new_node->node_list_id_ = node_list_.size();\n  node_list_.emplace_back(new_node);\n\n  return 1;\n}\n\nvoid SbpGraph::FinalizeSbp() const {\n  for (const auto& this_node : node_list_) { this_node->FinalizeSbp(); }\n}\n
\ndouble SbpGraph::GreedyStrategy(bool for_node) const {\n  // Overall, this function should be replaced by GreedyStrategy(nbh_num);\n  // Total cost reduction & cost reduction for one loop\n  double total_cost_reduction = 0, cost_reduction = 0;\n  for (int32_t step = node_list_.size(); step >= 0; step--) {\n    cost_reduction = 0;\n    for (SbpNode* this_node : node_list_) {\n      // Use GreedyStrategy on Nodes if there is one node left for this\n      // connected component. Otherwise, use GreedyStrategy on Edges.\n      if (for_node || this_node->edges_in_.size() + this_node->edges_out_.size() == 0) {\n        cost_reduction += this_node->GreedyStrategy();\n      } else {\n        // GreedyStrategy on Edges.\n        for (SbpEdge* this_edge : this_node->edges_out_) {\n          double second_rdc = this_edge->GreedyStrategy();\n          cost_reduction += second_rdc;\n        }\n      }\n    }\n    if (cost_reduction == 0) { break; }\n    total_cost_reduction += cost_reduction;\n  }\n  return total_cost_reduction;\n}\n\ndouble SbpGraph::GreedyStrategy(int32_t nbh_num) const {\n  // nbh_num is the maximum size of the neighborhood whose sbp strategy is adjusted in each step\n  // Total cost reduction & cost reduction for one loop\n  double total_cost_reduction = 0, cost_reduction = 0;\n  // A global buffer to store part of the one ring neighborhood.\n  std::vector<int32_t> nbh_id2node_list_id;\n  // Do not accept a number lower than 1\n  if (nbh_num < 1) { nbh_num = 1; }\n  nbh_id2node_list_id.resize(nbh_num);\n  std::vector<int32_t> original_sbp_sig_id(nbh_num);\n  // store all the node_list_id whose corresponding nodes will be visited\n  // We could use an unordered_map to do this but a vector is faster\n  std::vector<int32_t> pre_visit_node_list(node_list_.size() + 1);\n  for (int32_t nbh_id = 0; nbh_id < node_list_.size(); nbh_id++) {\n    pre_visit_node_list[nbh_id] = nbh_id;\n  }\n  int32_t head = 0, tail = node_list_.size();\n  // whether a node_list_id is in pre_visit_node_list\n  std::vector<bool> pre_visit_tags(node_list_.size(), true);\n  int32_t step = 0;\n  // 1-ring neighborhood buffer\n  std::vector<int32_t> nbh_1ring(nbh_num);\n  // 2-ring neighborhood buffer\n  std::vector<int32_t> nbh_2ring;\n  std::vector<bool> node_tags(node_list_.size(), false);\n  std::vector<int32_t> nbh_1ring_buffer;\n\n  while (head != tail && step < node_list_.size()) {\n    auto* this_node = node_list_[pre_visit_node_list[head]];\n    if (nbh_num <= 1) {\n      // Greedy strategy on nodes, here we use nbh_1ring to store the nbh_id2node_list_id\n      // information for reutilization\n      nbh_1ring[0] = this_node->node_list_id_;\n      // store the original sbp signature of the 1-ring neighborhood for comparison\n      original_sbp_sig_id[0] = this_node->final_sbp_sig_id_;\n      cost_reduction = NbhGreedyStrategy(nbh_1ring);\n    } else {\n      // Use GreedyStrategy on the one ring neighborhood of this node.\n      this_node->OneRingNeighborhood(nbh_1ring);\n      // store the original sbp signature of the 1-ring neighborhood for comparison\n      original_sbp_sig_id.resize(nbh_1ring.size());\n      for (int32_t nbh_id = 0; nbh_id < nbh_1ring.size(); nbh_id++) {\n        original_sbp_sig_id[nbh_id] = node_list_[nbh_1ring[nbh_id]]->final_sbp_sig_id_;\n      }\n      if (nbh_1ring.size() <= nbh_num) {\n        cost_reduction = NbhGreedyStrategy(nbh_1ring);\n      } else {\n        // Use GreedyStrategy on part of the one ring neighborhood.\n        // Loop through the neighborhood. Each loop should contain the centroid.\n
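        // Sliding scheme (assuming nbh_1ring[0] is the centroid itself): after the first\n        // pass, slot 0 stays pinned to the centroid while slots 1..nbh_num-1 rotate through\n        // the rest of the ring, so every neighbor is optimized together with the centroid.\n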
\n        // Initialize part of the one ring neighborhood\n        int32_t nbh_1ring_id = nbh_1ring.size() - nbh_num;\n        for (int32_t nbh_id = 1; nbh_id < nbh_num; ++nbh_id) {\n          nbh_id2node_list_id[nbh_id] = nbh_1ring[++nbh_1ring_id];\n        }\n        // loop through the one ring neighborhood\n        cost_reduction = 0;\n        int32_t nbh_id = 0;\n        for (nbh_1ring_id = 0; nbh_1ring_id < nbh_1ring.size(); ++nbh_1ring_id) {\n          nbh_id2node_list_id[nbh_id] = nbh_1ring[nbh_1ring_id];\n          cost_reduction += NbhGreedyStrategy(nbh_id2node_list_id);\n          // nbh_id for the next step\n          if (++nbh_id >= nbh_num) { nbh_id = 1; }\n        }\n      }\n    }\n    // change of strategies\n    if (cost_reduction != 0) {\n      // Add neighborhood into pre-visited node list for each node with changing strategy\n      for (int32_t nbh_id = 0; nbh_id < nbh_1ring.size(); nbh_id++) {\n        // If changes occur\n        if (original_sbp_sig_id[nbh_id] != node_list_[nbh_1ring[nbh_id]]->final_sbp_sig_id_) {\n          // schedule to visit the neighborhood of that changing node\n          node_list_[nbh_1ring[nbh_id]]->NRingNeighborhood(2, nbh_2ring, nbh_1ring_buffer,\n                                                           node_list_, node_tags);\n          for (int32_t nbh_node_list_id : nbh_2ring) {\n            // Put them into the pre-visited node list\n            if (!pre_visit_tags[nbh_node_list_id]) {\n              pre_visit_node_list[tail] = nbh_node_list_id;\n              pre_visit_tags[nbh_node_list_id] = true;\n              tail++;\n              if (tail == pre_visit_node_list.size()) { tail = 0; }\n            }\n          }\n        }\n      }\n    }\n    // Finish visiting\n    pre_visit_tags[pre_visit_node_list[head]] = false;\n    head++;\n    if (head == pre_visit_node_list.size()) {\n      head = 0;\n      step++;\n    }\n\n    total_cost_reduction += cost_reduction;\n  }\n  return total_cost_reduction;\n}\n\nvoid SbpGraph::DfsAddNbhCost(std::vector<int32_t>& nbh_id2node_list_id,\n                             std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,\n                             std::vector<int32_t>& order2nbh_id, std::vector<int32_t>& nbh_id2order,\n                             std::vector<double>& order2acc_min_in_nbh_cost,\n                             std::vector<std::vector<double>>& out_nbh_costs,\n                             std::vector<std::vector<int32_t>>& nbh_id2order2sbp_id,\n                             std::vector<int32_t>& min_sbp_sig_id, double& min_cost, int32_t order,\n                             double curr_cost) const {\n  // We have finished visiting the neighborhood\n  if (order >= nbh_id2node_list_id.size()) {\n    // relative difference > 1e-12\n    if (curr_cost < min_cost * kFloatDeviationMinus) {\n      min_cost = curr_cost;\n      for (int32_t nbh_id = 0; nbh_id < nbh_id2node_list_id.size(); nbh_id++) {\n        min_sbp_sig_id[nbh_id] = node_list_[nbh_id2node_list_id[nbh_id]]->final_sbp_sig_id_;\n      }\n    }\n    return;\n  }\n  // Pruning: remove all those branches with large cost\n  if (curr_cost + order2acc_min_in_nbh_cost[order] >= min_cost) { return; }\n  // Depth-first search in the next order\n  int32_t nbh_id = order2nbh_id[order];\n  SbpNode* sbp_node = node_list_[nbh_id2node_list_id[nbh_id]];\n  for (int32_t sbp_id : nbh_id2order2sbp_id[nbh_id]) {\n    sbp_node->final_sbp_sig_id_ = sbp_id;\n
    DfsAddNbhCost(nbh_id2node_list_id, node_list_id2nbh_id, order2nbh_id, nbh_id2order,\n                  order2acc_min_in_nbh_cost, out_nbh_costs, nbh_id2order2sbp_id, min_sbp_sig_id,\n                  min_cost, order + 1,\n                  curr_cost + out_nbh_costs[nbh_id][sbp_id]\n                      + sbp_node->EvalInNbhCost(node_list_id2nbh_id, nbh_id2order));\n  }\n}\n\nbool SbpGraph::DfsFindReasonableCost(std::vector<int32_t>& nbh_id2node_list_id,\n                                     std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,\n                                     std::vector<int32_t>& nbh_id2order, int32_t nbh_id) const {\n  // We found such a strategy\n  if (nbh_id == nbh_id2order.size()) { return true; }\n  SbpNode* sbp_node = node_list_[nbh_id2node_list_id[nbh_id]];\n  // Start from B.\n  for (int32_t sbp_id = sbp_node->weighted_cost_.size() - 1; sbp_id >= 0; sbp_id--) {\n    sbp_node->final_sbp_sig_id_ = sbp_id;\n    // If the cost for this node is reasonable, then go to the next one\n    if (sbp_node->weighted_cost_[sbp_id]\n            + sbp_node->EvalInNbhCost(node_list_id2nbh_id, nbh_id2order)\n        < GetValidMaxCopyCost()) {\n      if (DfsFindReasonableCost(nbh_id2node_list_id, node_list_id2nbh_id, nbh_id2order,\n                                nbh_id + 1)) {\n        // If we found one strategy, then exit the DFS.\n        return true;\n      }\n    }\n  }\n  // Cannot find a reasonable strategy with the settings for the previous nodes.\n  // Go back and change the previous node.\n  return false;\n}\n\n// Find one strategy with finite cost for adjustment\nMaybe<void> SbpGraph::Find1Strategy4Greedy() const {\n  std::vector<int32_t> nbh_id2node_list_id;\n  std::vector<bool> not_visited(node_list_.size(), true);\n  std::vector<int32_t> nbh_1ring;\n  int32_t head = 0;\n  int32_t tail = 0;\n  std::vector<double> node_cut_ratios(node_list_.size());\n  // Initialize cut ratio for all the nodes\n  for (int32_t node_list_id = 0; node_list_id < node_list_.size(); node_list_id++) {\n    node_cut_ratios[node_list_id] = node_list_[node_list_id]->GetCutRatio();\n  }\n  // While we have not visited all the nodes\n  while (tail < node_list_.size()) {\n    // Find the node with the minimum cut ratio\n    int32_t node_with_min_cut_ratio = -1;\n    double min_cut_ratio = 2.0;\n    for (int32_t node_list_id = 0; node_list_id < node_list_.size(); node_list_id++) {\n      if (not_visited[node_list_id]) {\n        double curr_cut_ratio = node_cut_ratios[node_list_id];\n        if (curr_cut_ratio < min_cut_ratio) {\n          min_cut_ratio = curr_cut_ratio;\n          node_with_min_cut_ratio = node_list_id;\n        }\n      }\n    }\n    // put this node into the open set\n    nbh_id2node_list_id.push_back(node_with_min_cut_ratio);\n    not_visited[node_with_min_cut_ratio] = false;\n    tail++;\n    // BFS\n    while (head < tail) {\n      // look for the neighborhood of the head\n      int32_t node_list_id = nbh_id2node_list_id[head];\n      node_list_[node_list_id]->OneRingNeighborhood(nbh_1ring);\n      // sort\n      std::sort(nbh_1ring.begin(), nbh_1ring.end(),\n                [&](int32_t i, int32_t j) { return node_cut_ratios[i] < node_cut_ratios[j]; });\n      for (int32_t curr_id : nbh_1ring) {\n        if (not_visited[curr_id]) {\n          nbh_id2node_list_id.push_back(curr_id);\n          tail++;\n          not_visited[curr_id] = false;\n        }\n      }\n      head++;\n    }\n  }\n  // mapping from the node_list_id to the id in the nbh_id2node_list_id\n
  std::unordered_map<int32_t, int32_t> node_list_id2nbh_id;\n  InverseFunction<int32_t>(nbh_id2node_list_id, node_list_id2nbh_id);\n  // Initialize an ordinary order\n  std::vector<int32_t> nbh_id2order(nbh_id2node_list_id.size());\n  for (int32_t nbh_id = 0; nbh_id < nbh_id2node_list_id.size(); nbh_id++) {\n    nbh_id2order[nbh_id] = nbh_id;\n  }\n  // Combine depth-first search and pruning based on the cut ratio\n  CHECK(DfsFindReasonableCost(nbh_id2node_list_id, node_list_id2nbh_id, nbh_id2order, /*nbh_id=*/0))\n      << \"Can't find a reasonable strategy!\";\n  return Maybe<void>::Ok();\n}\n\n// Use brute force to search for a strategy with minimum cost for a neighborhood\ndouble SbpGraph::NbhGreedyStrategy(std::vector<int32_t>& nbh_id2node_list_id) const {\n  // number of nodes in the neighborhood\n  int32_t num_nbh = nbh_id2node_list_id.size();\n  // mapping from the node_list_id to the id in the nbh_id2node_list_id\n  std::unordered_map<int32_t, int32_t> node_list_id2nbh_id;\n  InverseFunction<int32_t>(nbh_id2node_list_id, node_list_id2nbh_id);\n  // an sbp signature id set minimizing the overall cost; store the original one as the default\n  std::vector<int32_t> min_sbp_sig_id(num_nbh);\n  for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {\n    min_sbp_sig_id[nbh_id] = node_list_[nbh_id2node_list_id[nbh_id]]->final_sbp_sig_id_;\n  }\n\n  // pre-compute and store the cost between neighborhood and outside nodes under different sbp for\n  // each node within the neighborhood\n  std::vector<std::vector<double>> out_nbh_costs(num_nbh);\n  for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {\n    SbpNode* sbp_node = node_list_[nbh_id2node_list_id[nbh_id]];\n    out_nbh_costs[nbh_id].resize(sbp_node->weighted_cost_.size());\n    for (int32_t sbp_id = sbp_node->weighted_cost_.size() - 1; sbp_id >= 0; sbp_id--) {\n      sbp_node->final_sbp_sig_id_ = sbp_id;\n      out_nbh_costs[nbh_id][sbp_id] = sbp_node->EvalOutNbhCost(node_list_id2nbh_id);\n    }\n  }\n  // pre-compute and store the order of the out_nbh_costs\n  std::vector<std::vector<int32_t>> nbh_id2order2sbp_id(num_nbh);\n  auto CompareDoubleLess = [](double a, double b) { return a < b; };\n  for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {\n    DecideOrder(out_nbh_costs[nbh_id], nbh_id2order2sbp_id[nbh_id], CompareDoubleLess);\n  }\n\n  // Decide the order to go through the neighborhood.\n  // Should visit those nodes with a larger difference in the out cost first.\n  std::vector<double> out_nbh_cost_diff(num_nbh);\n  for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {\n    out_nbh_cost_diff[nbh_id] =\n        *std::max_element(out_nbh_costs[nbh_id].begin(), out_nbh_costs[nbh_id].end())\n        - *std::min_element(out_nbh_costs[nbh_id].begin(), out_nbh_costs[nbh_id].end());\n  }\n  std::vector<int32_t> order2nbh_id;\n  DecideOrder(out_nbh_cost_diff, order2nbh_id, [](double a, double b) { return a > b; });\n  // Find the inverse map of order\n  std::vector<int32_t> nbh_id2order;\n  InverseOrder(order2nbh_id, nbh_id2order);\n\n  // Current Cost, Minimum Cost, Cost with original sbp\n  double original_cost = 0;\n  // Recover original sbp\n  for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {\n    node_list_[nbh_id2node_list_id[nbh_id]]->final_sbp_sig_id_ = min_sbp_sig_id[nbh_id];\n  }\n  // Compute cost with original sbp\n  for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {\n    SbpNode* sbp_node = node_list_[nbh_id2node_list_id[nbh_id]];\n    original_cost += out_nbh_costs[nbh_id][min_sbp_sig_id[nbh_id]];\n
    original_cost += sbp_node->EvalInNbhCost(node_list_id2nbh_id, nbh_id2order);\n  }\n  double min_cost = original_cost;\n  // Accumulate minimum cost from the current node to the end of the neighborhood node list.\n  // The accumulated cost includes the current node.\n  std::vector<double> order2acc_min_in_nbh_cost(num_nbh);\n  order2acc_min_in_nbh_cost[num_nbh - 1] =\n      *std::min_element(out_nbh_costs[order2nbh_id[num_nbh - 1]].begin(),\n                        out_nbh_costs[order2nbh_id[num_nbh - 1]].end());\n  for (int32_t order = num_nbh - 2; order >= 0; order--) {\n    int32_t nbh_id = order2nbh_id[order];\n    order2acc_min_in_nbh_cost[order] =\n        order2acc_min_in_nbh_cost[order + 1]\n        + *std::min_element(out_nbh_costs[nbh_id].begin(), out_nbh_costs[nbh_id].end())\n        + node_list_[nbh_id2node_list_id[nbh_id]]->EvalMinInNbhCost(node_list_id2nbh_id,\n                                                                    nbh_id2order);\n  }\n  // Use brute force (DFS) to adjust for the best strategy in the neighborhood.\n  DfsAddNbhCost(nbh_id2node_list_id, node_list_id2nbh_id, order2nbh_id, nbh_id2order,\n                order2acc_min_in_nbh_cost, out_nbh_costs, nbh_id2order2sbp_id, min_sbp_sig_id,\n                min_cost, /*order=*/0, /*curr_cost=*/0);\n  // Use the sbp strategy with minimum cost\n  for (int32_t nbh_id = 0; nbh_id < num_nbh; nbh_id++) {\n    node_list_[nbh_id2node_list_id[nbh_id]]->final_sbp_sig_id_ = min_sbp_sig_id[nbh_id];\n  }\n\n  if (min_cost < original_cost) {\n    // Directly returning (min_cost - original_cost) might have a floating point error of up to\n    // 3e-16. For example, original_cost: 2.22507e+06, min_cost: 2.22507e+06,\n    // diff: -4.65661e-10, relative diff: 2.09279e-16\n    // Therefore, we use a threshold to filter out such false positives to\n    // avoid an unlimited search.\n    if (original_cost * kFloatDeviationMinus > min_cost) { return min_cost - original_cost; }\n  }\n  return 0.0;\n}\n\n// Select and Merge two nodes\nint32_t SbpGraph::PickAndMerge() {\n  if (node_list_.size() < kMinNodeInGraphForMerging) { return 0; }\n  // Pick the one with the smallest cut ratio\n  double min_cut_ratio = 1.0;\n  double curr_cut_ratio = 0.0;\n  SbpEdge* merging_edge = nullptr;\n  for (int32_t i = 0; i < node_list_.size(); i++) {\n    for (SbpEdge* edge_in : node_list_[i]->edges_in_) {\n      curr_cut_ratio = edge_in->FindCutRatio(threshold_);\n      if (curr_cut_ratio < min_cut_ratio) {\n        min_cut_ratio = curr_cut_ratio;\n        merging_edge = edge_in;\n      }\n    }\n  }\n\n  if (merging_edge != nullptr) {\n    // Merge two nodes on the edge with the minimum cut ratio\n    return NodeMerging(merging_edge->start_node_, merging_edge->end_node_);\n  } else {\n    // Pick the couple with the largest similar neighborhood\n    std::vector<BinarySet> node_binary_sets(node_list_.size());\n    for (int32_t i = 0; i < node_list_.size(); i++) {\n      // Transfer edges to a binary set\n      node_binary_sets[i].Initialize(node_list_.size());\n      node_binary_sets[i].AddEntry(i);\n      for (const SbpEdge* edge_in : node_list_[i]->edges_in_) {\n        node_binary_sets[i].AddEntry(edge_in->start_node_->node_list_id_);\n      }\n      for (const SbpEdge* edge_out : node_list_[i]->edges_out_) {\n        // For an out edge, the neighbor is the end node; edge_out->start_node_ is node i itself.\n        node_binary_sets[i].AddEntry(edge_out->end_node_->node_list_id_);\n      }\n    }\n    // Find the two nodes with the largest common subset\n    // buffer of binary set\n    BinarySet buffer_binary_set(node_list_.size());\n
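    // Heuristic: prefer merging the pair of nodes sharing the most neighbors (largest\n    // intersection of the adjacency sets above); ties are broken by the smaller product of\n    // signature counts, computed below.\n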
    // Number of common edges\n    int32_t max_comm_edge_num = 0, curr_comm_edge_num = 0;\n    int32_t min_node_pair[2] = {-1, -1};\n    // Number of SbpSignatures in the merged node\n    int32_t min_sbp_num = 0, curr_sbp_num = 0;\n    for (int32_t i = 0; i < node_list_.size(); i++) {\n      for (int32_t j = i + 1; j < node_list_.size(); j++) {\n        curr_sbp_num = node_list_[i]->weighted_cost_.size() * node_list_[j]->weighted_cost_.size();\n        if (curr_sbp_num <= threshold_) {\n          node_binary_sets[i].IntersectionTo(node_binary_sets[j], buffer_binary_set);\n          curr_comm_edge_num = buffer_binary_set.Total();\n          if (curr_comm_edge_num > max_comm_edge_num\n              || (curr_comm_edge_num == max_comm_edge_num && curr_sbp_num < min_sbp_num)) {\n            min_node_pair[0] = i;\n            min_node_pair[1] = j;\n            max_comm_edge_num = curr_comm_edge_num;\n            min_sbp_num = curr_sbp_num;\n          }\n        }\n      }\n    }\n    if (max_comm_edge_num > 0) {\n      return NodeMerging(node_list_[min_node_pair[0]], node_list_[min_node_pair[1]]);\n    } else {\n      return 0;\n    }\n  }\n}\n\n// Clip an edge, remove it from graph\nvoid SbpGraph::ClipEdge(SbpEdge* this_edge) const {\n  CheckAndRemoveFrom<SbpEdge*>(this_edge->end_node_->edges_in_, this_edge);\n  CheckAndRemoveFrom<SbpEdge*>(this_edge->start_node_->edges_out_, this_edge);\n  delete this_edge;\n}\n\n// Compute the minimum and maximum layer of each node in the graph\nint32_t SbpGraph::ComputeLayer(\n    HashMap<std::string, SbpNode*>& op_name2sbp_node,\n    const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps) const {\n  // Compute minimum layer\n  for (SbpNode* this_node : node_list_) {\n    this_node->GetMinLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps);\n  }\n  // Find the largest minimum layer\n  int32_t max_min_layer = -1;\n  for (SbpNode* this_node : node_list_) {\n    if (max_min_layer < this_node->min_layer_) { max_min_layer = this_node->min_layer_; }\n  }\n  // Compute maximum layer\n  for (SbpNode* this_node : node_list_) {\n    this_node->SpreadMaxLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps);\n  }\n  for (SbpNode* this_node : node_list_) { this_node->LiftMaxLayer(max_min_layer); }\n  return max_min_layer;\n}\n\n// TODO: Remove the tributary layer here.\n// Find the trunk of the sbp graph, then reduce the wait time for tributaries\nvoid SbpGraph::FindTrunk(int32_t max_min_layer,\n                         HashMap<std::string, SbpNode*>& op_name2sbp_node) const {\n  // Summarize cost for each layer, on the trunk or tributaries\n  std::vector<double> trunk_cost(max_min_layer + 1, 0);\n  for (SbpNode* this_node : node_list_) {\n    trunk_cost[this_node->min_layer_] += this_node->GetMinCost();\n  }\n  // Decide trunks\n  double acc_cost = 0;\n  // All the nodes with MinLayer>=trunk_end_id would be considered as trunks\n  int32_t trunk_end_id = max_min_layer;\n  for (int32_t layer_id = max_min_layer; layer_id >= 0; layer_id--) {\n    acc_cost += trunk_cost[layer_id];\n    if (acc_cost > 0.5 * wait_time_) {\n      trunk_end_id = layer_id;\n      break;\n    }\n  }\n  // Find out all the nodes on the trunk.\n  for (SbpNode* this_node : node_list_) {\n    if (this_node->min_layer_ >= trunk_end_id) { this_node->SpreadTrunk(op_name2sbp_node); }\n  }\n\n  // Compute maximum layer for tributaries\n  // Clear counter and initialize tributary layer for each sbp node\n  for (SbpNode* this_node : node_list_) {\n    this_node->counter_ = 0;\n    this_node->DropTributaryLayer(max_min_layer);\n  }\n
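  // The idea: tributary (off-trunk) operators can overlap their transfers with trunk\n  // computation, so part of the wait-time budget is discounted for them below.\n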
  // Count the number of consumers and downstream nodes\n  for (SbpNode* this_node : node_list_) { this_node->RaiseConsumerNum(op_name2sbp_node); }\n  // Compute maximum layer for tributaries\n  for (SbpNode* this_node : node_list_) { this_node->SpreadTributaryLayer(op_name2sbp_node); }\n\n  // Summarize cost for each layer on the trunk, store it to avoid subtraction of large values.\n  trunk_cost.assign(max_min_layer + 1, 0);\n  // tributary cost starts from each min layer\n  std::vector<double> tributary_cost(max_min_layer + 1, 0);\n  // tributary cost would be outdated after Max Layer (before Max Layer + 1)\n  std::vector<double> outdated_tributary_cost(max_min_layer + 1, 0);\n  // operators on the trunk at each layer\n  std::vector<std::vector<SbpNode*>> trunk_ops(max_min_layer + 1);\n\n  for (SbpNode* this_node : node_list_) {\n    if (this_node->on_trunk_) {\n      trunk_cost[this_node->min_layer_] += this_node->GetMinCost();\n      trunk_ops[this_node->min_layer_].emplace_back(this_node);\n    } else {\n      double curr_min_cost = this_node->GetMinCost();\n      tributary_cost[this_node->min_layer_] += curr_min_cost;\n      outdated_tributary_cost[this_node->tributary_layer_] += curr_min_cost;\n    }\n  }\n  // Accumulate the cost from the consumer to the end, not including itself\n  std::vector<double> acc_trunk_cost(max_min_layer + 1, 0);\n  for (int32_t layer_id = max_min_layer; layer_id > 0; layer_id--) {\n    acc_trunk_cost[layer_id - 1] = acc_trunk_cost[layer_id] + trunk_cost[layer_id];\n  }\n\n  // Clear counter for each sbp node\n  for (SbpNode* this_node : node_list_) { this_node->counter_ = 0; }\n  // Count the number of consumers and downstream nodes\n  for (SbpNode* this_node : node_list_) { this_node->RaiseConsumerNum(op_name2sbp_node); }\n  // Reduce the wait time for tributaries\n  for (SbpNode* this_node : node_list_) {\n    this_node->SpreadAvailWaitTime(trunk_cost, acc_trunk_cost, op_name2sbp_node, wait_time_);\n  }\n\n  // Reduce the wait time for the trunk from the end to the beginning\n  double acc_tributary_cost = outdated_tributary_cost[max_min_layer];\n  double used_tributary_cost = 0.0;\n  double curr_wait_time = 0.0;\n  for (int32_t layer_id = max_min_layer - 1; layer_id >= 0; layer_id--) {\n    // Cannot move it backward since we need to do this at the 0th layer.\n    // At some point, the cost that hasn't been used would disappear.\n    if (tributary_cost[layer_id + 1] > used_tributary_cost) {\n      acc_tributary_cost -= tributary_cost[layer_id + 1] - used_tributary_cost;\n      used_tributary_cost = 0.0;\n      if (acc_tributary_cost < 0.0) {\n        // should not happen aside from floating point error\n        std::cout << \"Caution! 
Current accumulated tributary cost is: \" << acc_tributary_cost\n                  << std::endl;\n        acc_tributary_cost = 0.0;\n      }\n    } else {\n      used_tributary_cost -= tributary_cost[layer_id + 1];\n    }\n    // accumulate tributary cost at this layer\n    acc_tributary_cost += outdated_tributary_cost[layer_id];\n    // If we have more cost in tributaries, we reduce the wait time\n    // This code maintains ( acc_tributary_cost + used_tributary_cost )\n    if (acc_tributary_cost > 0.0) {\n      if (acc_tributary_cost > wait_time_) {\n        curr_wait_time = 0.0;\n        acc_tributary_cost -= wait_time_;\n        used_tributary_cost += wait_time_;\n      } else {\n        curr_wait_time = wait_time_ - acc_tributary_cost;\n        used_tributary_cost += acc_tributary_cost;\n        acc_tributary_cost = 0.0;\n      }\n      // Reduce the wait time in the trunk\n      for (SbpNode* this_node : trunk_ops[layer_id]) {\n        this_node->SetTrunkWaitTime(curr_wait_time);\n      }\n    }\n  }\n}\n\n// Set wait time\nvoid SbpGraph::SetWaitTime(double wait_time) { wait_time_ = wait_time; }\n\n}  // namespace auto_parallel\n}  // namespace oneflow\n"
  },
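  {
    "path": "docs/auto_parallel_sketches/find_trunk_wait_time_sketch.cpp",
    "content": "// NOTE (editor): Illustrative sketch added during review; this file is NOT part of the original\n// sources and every name in it is a hypothetical stand-in. It replays, in isolation, two steps of\n// SbpGraph::FindTrunk(): the suffix-sum accumulation of per-layer trunk cost, and the wait-time\n// reduction that applies when tributaries still carry enough unused cost to overlap communication.\n#include <cstdint>\n#include <iostream>\n#include <vector>\n\nint main() {\n  // Hypothetical minimal computation cost of the trunk at each layer.\n  std::vector<double> trunk_cost = {5.0, 3.0, 7.0, 2.0};\n  const int32_t max_min_layer = static_cast<int32_t>(trunk_cost.size()) - 1;\n  // acc_trunk_cost[l] accumulates the cost from the consumer side to the end, excluding layer l\n  // itself, exactly like the corresponding loop in FindTrunk().\n  std::vector<double> acc_trunk_cost(max_min_layer + 1, 0.0);\n  for (int32_t layer_id = max_min_layer; layer_id > 0; layer_id--) {\n    acc_trunk_cost[layer_id - 1] = acc_trunk_cost[layer_id] + trunk_cost[layer_id];\n  }\n  for (int32_t layer_id = 0; layer_id <= max_min_layer; layer_id++) {\n    std::cout << \"layer \" << layer_id << \": acc_trunk_cost = \" << acc_trunk_cost[layer_id]\n              << std::endl;  // 12, 9, 2, 0\n  }\n  // Wait-time reduction for one layer: tributary cost that has not been used yet absorbs the\n  // wait time of the trunk operators at this layer.\n  const double wait_time = 4.0;\n  double acc_tributary_cost = 6.0;  // hypothetical leftover tributary cost\n  double curr_wait_time = wait_time;\n  if (acc_tributary_cost > wait_time) {\n    curr_wait_time = 0.0;\n    acc_tributary_cost -= wait_time;\n  } else {\n    curr_wait_time = wait_time - acc_tributary_cost;\n    acc_tributary_cost = 0.0;\n  }\n  std::cout << \"reduced wait time: \" << curr_wait_time << std::endl;  // 0\n  return 0;\n}\n"
  },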
  {
    "path": "oneflow/core/auto_parallel/sbp_graph.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_GRAPH_H_\n#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_GRAPH_H_\n\n#include <algorithm>\n#include <unordered_map>\n#include \"oneflow/core/auto_parallel/binary_set.h\"\n#include \"oneflow/core/auto_parallel/sbp_node.h\"\n#include \"oneflow/core/auto_parallel/sbp_edge.h\"\n#include \"oneflow/core/auto_parallel/algorithm_util.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\nnamespace auto_parallel {\n\n// A graph structure to deal with the SBP strategy.\n// It contains a lot of eliminations to shrink the topography structure of the original graph.\n// Furthermore, it contains some adjustment tricks for search a good strategy in the shrunk graph.\nclass SbpGraph final {\n public:\n  // Constructor\n  SbpGraph() = default;\n\n  // Deconstructor\n  ~SbpGraph();\n\n  OF_DISALLOW_COPY_AND_MOVE(SbpGraph);\n  bool operator==(const SbpGraph& other) { return this == &other; }\n\n  // Randomly assign a SbpSignature strategy\n  void RandomSbpSignature(bool use_sbp_collector) const;\n  // assign 0 to a SbpSignature strategy to avoid randomness\n  void SetDefaultSbpSig() const;\n\n  void StoreOriginMemory();\n  // Compute Cost for current strategy\n  double ComputeCost() const;\n  double ComputeWeightedCost() const;\n  // Re-compute weighted cost\n  void ReComputeWeightedCost();\n\n  // Generate a node\n  SbpNode* GenerateNode();\n\n  // Merge all parallel edges & Check and eliminate all nodes with only one\n  // degree-in and one degree-out\n  int32_t NodeAndEdgeEliminations();\n\n  // Finalize Sbp Cost for the whole graph\n  void FinalizeSbp() const;\n\n  // Use Greedy Strategy to decide Sbp for Nodes in node_list_. Should be used\n  // after we have a initial strategy.\n  // Set for_node to be true will only use GreedyStrategy on Nodes.\n  double GreedyStrategy(bool for_node) const;\n  // Use greedy strategy on the one ring neighborhood with the maximum number of points nbh_num.\n  double GreedyStrategy(int32_t nbh_num = 4) const;\n\n  // Find one strategy with finite cost for adjustment\n  Maybe<void> Find1Strategy4Greedy() const;\n  // Use brute force to search for a strategy with minimum cost for a neighborhood\n  double NbhGreedyStrategy(std::vector<int32_t>& nbh_id2node_list_id) const;\n\n  // Set threshold_ for SbpNode Merging\n  void SetThreshold(int32_t threshold) { threshold_ = threshold; }\n\n  // Clip an edge, remove it from graph\n  // Clipping an edge will also delete the nodes and edges contained in this edge. Though not\n  // suffering from any compiling and runtime bugs, clipping an edge on a shrunk graph is not\n  // recommended. 
We should carefully think about it before any clipping.\n  void ClipEdge(SbpEdge* this_edge) const;\n\n  // Compute the minimum and maximum layer of each node in the graph\n  int32_t ComputeLayer(\n      HashMap<std::string, SbpNode*>& op_name2sbp_node,\n      const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps) const;\n\n  // Find the trunk of the sbp graph, then reduce the wait time for tributaries\n  void FindTrunk(int32_t max_min_layer, HashMap<std::string, SbpNode*>& op_name2sbp_node) const;\n\n  // Set wait time\n  void SetWaitTime(double wait_time);\n\n  // Getters\n  std::vector<SbpNode*>& GetNodeList() { return node_list_; }\n  int64_t GetMemory() const;\n\n private:\n  friend class SbpCollector;\n  friend class SbpConstructor;\n\n  // All the nodes\n  std::vector<SbpNode*> node_list_;\n\n  // Limitation: A merged node should not have more SbpSignatures than the\n  // threshold.\n  int32_t threshold_ = 100;\n  // Wait time for copy cost, which occurs before communication between devices.\n  double wait_time_ = 16500.0;\n\n  // Remove a node from the node list\n  void RemoveFromNodeList(SbpNode* this_node);\n\n  // Check and eliminate one node with only one degree-in and one degree-out\n  int32_t NodeElimination(SbpNode* this_node);\n  // Merge all parallel edges with given start_node_ and end_node_\n  int32_t EdgeElimination(SbpNode* this_node) const;\n  // Check and eliminate one child node\n  int32_t ChildElimination(SbpNode* this_node);\n\n  // Merge two nodes\n  int32_t NodeMerging(SbpNode* first, SbpNode* second);\n  // Select two nodes and merge them\n  int32_t PickAndMerge();\n\n  void DfsAddNbhCost(std::vector<int32_t>& nbh_id2node_list_id,\n                     std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,\n                     std::vector<int32_t>& order2nbh_id, std::vector<int32_t>& nbh_id2order,\n                     std::vector<double>& order2acc_min_in_nbh_cost,\n                     std::vector<std::vector<double>>& out_nbh_costs,\n                     std::vector<std::vector<int32_t>>& nbh_id2order2sbp_id,\n                     std::vector<int32_t>& min_sbp_sig_id, double& min_cost, int32_t order,\n                     double curr_cost) const;\n\n  bool DfsFindReasonableCost(std::vector<int32_t>& nbh_id2node_list_id,\n                             std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,\n                             std::vector<int32_t>& nbh_id2order, int32_t nbh_id) const;\n};\n\n}  // namespace auto_parallel\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_AUTO_PARALLEL_SBP_GRAPH_H_\n"
  },
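  {
    "path": "docs/auto_parallel_sketches/node_elimination_min_plus_sketch.cpp",
    "content": "// NOTE (editor): Illustrative sketch added during review; this file is NOT part of the original\n// sources and all data below is hypothetical. SbpGraph::NodeElimination() removes a node with\n// exactly one degree-in and one degree-out. Cost-wise this amounts to a min-plus product (the\n// real code also records which middle signature achieves each minimum so that FinalizeSbp() can\n// recover the choice later):\n//   new_cost[i][k] = min over j of (in_cost[i][j] + node_cost[j] + out_cost[j][k])\n#include <algorithm>\n#include <cstddef>\n#include <iostream>\n#include <limits>\n#include <vector>\n\nusing CostMatrix = std::vector<std::vector<double>>;\n\nCostMatrix EliminateMiddleNode(const CostMatrix& in_cost, const std::vector<double>& node_cost,\n                               const CostMatrix& out_cost) {\n  const std::size_t n_left = in_cost.size();\n  const std::size_t n_mid = node_cost.size();\n  const std::size_t n_right = out_cost[0].size();\n  CostMatrix new_cost(n_left,\n                      std::vector<double>(n_right, std::numeric_limits<double>::infinity()));\n  for (std::size_t i = 0; i < n_left; i++) {\n    for (std::size_t k = 0; k < n_right; k++) {\n      for (std::size_t j = 0; j < n_mid; j++) {\n        new_cost[i][k] = std::min(new_cost[i][k], in_cost[i][j] + node_cost[j] + out_cost[j][k]);\n      }\n    }\n  }\n  return new_cost;\n}\n\nint main() {\n  CostMatrix in_cost = {{0.0, 2.0}, {3.0, 0.0}};   // left sbp i -> middle sbp j\n  std::vector<double> node_cost = {1.0, 4.0};      // computation cost of the middle node\n  CostMatrix out_cost = {{0.0, 5.0}, {5.0, 0.0}};  // middle sbp j -> right sbp k\n  for (const auto& row : EliminateMiddleNode(in_cost, node_cost, out_cost)) {\n    for (double cost : row) { std::cout << cost << \" \"; }\n    std::cout << std::endl;  // expected: \"1 6\" then \"4 4\"\n  }\n  return 0;\n}\n"
  },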
  {
    "path": "oneflow/core/auto_parallel/sbp_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <algorithm>\n#include <cstdlib>\n#include <functional>\n#include <iostream>\n#include <vector>\n#include \"oneflow/core/auto_parallel/binary_set.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/auto_parallel/algorithm_util.h\"\n#include \"oneflow/core/job/sbp_parallel.pb.h\"\n#include \"oneflow/core/auto_parallel/sbp_node.h\"\n#include \"oneflow/core/auto_parallel/sbp_edge.h\"\n#include \"oneflow/core/auto_parallel/sbp_graph.h\"\n#include \"oneflow/core/register/logical_blob_id.pb.h\"\n\nnamespace oneflow {\nnamespace auto_parallel {\n\n// In dynamic programming, we can not minimize a vector (copy cost, memory cost)\n// Instead, we minimize the weighted sum of the vector, copy cost + kMemoryRatio * memory cost\nextern double kMemoryRatio;\n\n// function in cpp. Should be put in one file due to use of template\n// Otherwise we will need to declare specific template at the end of cpp file.\nSbpNode::SbpNode(SbpNode* first, SbpNode* second) {\n  half_node_.resize(2);\n  half_node_[0] = first;\n  half_node_[1] = second;\n\n  // Get the edge between first and second\n  // NOTE: It must zero or one edge between them\n  SbpEdge* common_edge = nullptr;\n  for (int32_t k = 0; k < first->edges_in_.size(); k++) {\n    if (first->edges_in_[k]->start_node_ == second) {\n      // CHECK_ISNULL(edge);\n      common_edge = first->edges_in_[k];\n    }\n  }\n  for (int32_t k = 0; k < first->edges_out_.size(); k++) {\n    if (first->edges_out_[k]->end_node_ == second) { common_edge = first->edges_out_[k]; }\n  }\n\n  // Find all available merged-SbpSignature(edge's cost less than threshold).\n  if (common_edge) {\n    in_memory_support_ =\n        first->in_memory_support_ || second->in_memory_support_ || common_edge->in_memory_support_;\n    // If there is no one case can choose, we will blow up\n    for (int32_t i = 0; i < first->weighted_cost_.size(); i++) {\n      for (int32_t j = 0; j < second->weighted_cost_.size(); j++) {\n        const double edge_weighted_cost = common_edge->start_node_ == first\n                                              ? common_edge->weighted_cost_[i][j]\n                                              : common_edge->weighted_cost_[j][i];\n        if (edge_weighted_cost < GetValidMaxCopyCost()) {\n          merged_sig_id2half_sig_id_.emplace_back(std::make_pair(i, j));\n          if (in_memory_support_) {\n            memory_.push_back((common_edge->start_node_ == first ? 
common_edge->GetMemory(i, j)\n                                                                 : common_edge->GetMemory(j, i))\n                              + first->GetMemory(i) + second->GetMemory(j));\n          }\n          weighted_cost_.emplace_back(edge_weighted_cost + first->weighted_cost_[i]\n                                      + second->weighted_cost_[j]);\n        }\n      }\n    }\n    CHECK(merged_sig_id2half_sig_id_.size() > 0)\n        << \"0 size for merge two half nodes with common edge!\";\n  } else {\n    in_memory_support_ = first->in_memory_support_ || second->in_memory_support_;\n    for (int32_t i = 0; i < first->weighted_cost_.size(); i++) {\n      for (int32_t j = 0; j < second->weighted_cost_.size(); j++) {\n        merged_sig_id2half_sig_id_.emplace_back(std::make_pair(i, j));\n        if (in_memory_support_) { memory_.push_back(first->GetMemory(i) + second->GetMemory(j)); }\n        weighted_cost_.emplace_back(first->weighted_cost_[i] + second->weighted_cost_[j]);\n      }\n    }\n  }\n\n  // Initialize default sbp choice\n  // If the original sbp pair does not go through, then use 0 as default.\n  final_sbp_sig_id_ = 0;\n  // Track the original strategy\n  for (int32_t sig_id = 0; sig_id < merged_sig_id2half_sig_id_.size(); sig_id++) {\n    if (merged_sig_id2half_sig_id_[sig_id].first == first->final_sbp_sig_id_\n        && merged_sig_id2half_sig_id_[sig_id].second == second->final_sbp_sig_id_) {\n      final_sbp_sig_id_ = sig_id;\n    }\n  }\n\n  // Merge edges_in_\n  edges_in_.reserve(first->edges_in_.size() + second->edges_in_.size());\n  edges_in_.insert(edges_in_.end(), first->edges_in_.begin(), first->edges_in_.end());\n  edges_in_.insert(edges_in_.end(), second->edges_in_.begin(), second->edges_in_.end());\n  // Merge edges_out_\n  edges_out_.reserve(first->edges_out_.size() + second->edges_out_.size());\n  edges_out_.insert(edges_out_.end(), first->edges_out_.begin(), first->edges_out_.end());\n  edges_out_.insert(edges_out_.end(), second->edges_out_.begin(), second->edges_out_.end());\n  // Merge SbpEdge Cost\n  for (SbpEdge*& this_edge : first->edges_in_) {\n    this_edge->DuplicateCost(false, true, merged_sig_id2half_sig_id_);\n    this_edge->end_node_ = this;\n  }\n  for (SbpEdge*& this_edge : first->edges_out_) {\n    this_edge->DuplicateCost(true, true, merged_sig_id2half_sig_id_);\n    this_edge->start_node_ = this;\n  }\n  for (SbpEdge*& this_edge : second->edges_in_) {\n    this_edge->DuplicateCost(false, false, merged_sig_id2half_sig_id_);\n    this_edge->end_node_ = this;\n  }\n  for (SbpEdge*& this_edge : second->edges_out_) {\n    this_edge->DuplicateCost(true, false, merged_sig_id2half_sig_id_);\n    this_edge->start_node_ = this;\n  }\n  // Remove edges from original nodes\n  first->edges_in_.clear();\n  first->edges_out_.clear();\n  second->edges_in_.clear();\n  second->edges_out_.clear();\n\n  // Move edges between two nodes to each half node\n  for (int32_t k = edges_out_.size() - 1; k >= 0; k--) {\n    if (edges_out_[k]->end_node_ == this) {\n      // Remove this edge from edges_out_ and edges_in_ and put it inside the node\n      CheckAndRemoveFrom<SbpEdge*>(edges_in_, edges_out_[k]);\n      first->edges_out_.emplace_back(edges_out_[k]);\n      second->edges_in_.emplace_back(edges_out_[k]);\n      RemoveFrom<SbpEdge*>(edges_out_, k);\n    }\n  }\n}\n\nSbpNode::~SbpNode() {\n  for (auto& edge_out : edges_out_) { delete edge_out; }\n  for (auto& child_node : children_) {\n    if (child_node->edges_in_.size()) { delete 
child_node->edges_in_[0]; }\n    delete child_node;\n  }\n  for (auto& half_node : half_node_) { delete half_node; }\n}\n\nvoid SbpNode::InitializeSbp() {\n  global_sbp_sig_size_ = sbp_sig_list_.size();\n  cost_.resize(sbp_sig_list_.size());\n};\n\n// Let one node point to another\nvoid SbpNode::StartPointToEnd(SbpNode* start_node, SbpNode* end_node) {\n  // generate the edge between them\n  SbpEdge* e = new SbpEdge(start_node, end_node);\n  start_node->edges_out_.emplace_back(e);\n  end_node->edges_in_.emplace_back(e);\n};\n\nvoid SbpNode::PointFrom(SbpNode* start_node) { StartPointToEnd(start_node, this); };\n\nvoid SbpNode::PointTo(SbpNode* end_node) { StartPointToEnd(this, end_node); };\n\nvoid SbpNode::SummarizeCost() {\n  if (children_.size() == child_node_sbp_sig_.size()) { return; }\n  int32_t previous_children_size = child_node_sbp_sig_.size();\n  child_node_sbp_sig_.resize(children_.size());\n  in_memory_support_ =\n      in_memory_support_\n      || std::any_of(children_.begin() + previous_children_size, children_.end(),\n                     [](SbpNode* sbp_node) { return sbp_node->in_memory_support_; });\n  if (in_memory_support_) { memory_.resize(weighted_cost_.size(), 0); }\n  // Buffer\n  int64_t min_memory_cost = 0, memory_cost = 0;\n  double min_weighted_sum = 0.0, weighted_sum = 0.0;\n  int32_t min_sbp_child = 0;\n  // Only deal with new children_\n  for (int32_t child = previous_children_size; child < children_.size(); child++) {\n    child_node_sbp_sig_[child].resize(weighted_cost_.size());\n\n    for (int32_t sbp_this = 0; sbp_this < weighted_cost_.size(); sbp_this++) {\n      SbpNode* child_node = children_[child];\n      for (int32_t sbp_child = 0; sbp_child < child_node->weighted_cost_.size(); sbp_child++) {\n        if (child_node->edges_in_.size()) {\n          // edge in graph: father -> child\n          memory_cost = child_node->edges_in_[0]->GetMemory(sbp_this, sbp_child)\n                        + child_node->GetMemory(sbp_child);\n          weighted_sum = child_node->edges_in_[0]->weighted_cost_[sbp_this][sbp_child]\n                         + child_node->weighted_cost_[sbp_child];\n        } else {\n          // edge in graph: child -> father\n          memory_cost = child_node->edges_out_[0]->GetMemory(sbp_child, sbp_this)\n                        + child_node->GetMemory(sbp_child);\n          weighted_sum = child_node->edges_out_[0]->weighted_cost_[sbp_child][sbp_this]\n                         + child_node->weighted_cost_[sbp_child];\n        }\n        // update min_cost with fixed SbpSignature for this node and child node\n        if (sbp_child == 0 || weighted_sum < min_weighted_sum) {\n          min_memory_cost = memory_cost;\n          min_weighted_sum = weighted_sum;\n          min_sbp_child = sbp_child;\n        }\n      }\n      child_node_sbp_sig_[child][sbp_this] = min_sbp_child;\n      // Add the cost for child node to this node\n      if (in_memory_support_) { memory_[sbp_this] += min_memory_cost; }\n      weighted_cost_[sbp_this] += min_weighted_sum;\n    }\n  }\n}\n\nbool SbpNode::EliminateItselfAsChild() {\n  if (edges_in_.size() + edges_out_.size() == 1) {\n    if (edges_in_.size()) {\n      // edge in graph: father -> this_node\n      SbpNode* father = edges_in_[0]->start_node_;\n      father->children_.emplace_back(this);\n      CheckAndRemoveFrom<SbpEdge*>(father->edges_out_, edges_in_[0]);\n      father->SummarizeCost();\n    } else {\n      // edge in graph: this_node -> father\n      SbpNode* father = edges_out_[0]->end_node_;\n     
 father->children_.emplace_back(this);\n      CheckAndRemoveFrom<SbpEdge*>(father->edges_in_, edges_out_[0]);\n      father->SummarizeCost();\n    }\n    // successfully eliminate this node\n    return true;\n  }\n  // cannot eliminate this node\n  return false;\n}\n\n// Compute the weighted sum of the time and memory cost\nvoid SbpNode::ComputeWeightedCost() {\n  if (half_node_.empty()) {\n    // If this node is not generated from merging, it should have original cost\n    // weighted_cost_ = cost_;\n    weighted_cost_ = origin_cost_;\n    memory_ = origin_memory_;\n    if (in_memory_support_) {\n      for (int32_t sbp_id = 0; sbp_id < origin_memory_.size(); sbp_id++) {\n        weighted_cost_[sbp_id] += kMemoryRatio * origin_memory_[sbp_id];\n      }\n    }\n  } else {\n    half_node_[0]->ComputeWeightedCost();\n    half_node_[1]->ComputeWeightedCost();\n    // The edge between the two half nodes\n    SbpEdge* edge_found = nullptr;\n    if (!half_node_[0]->edges_in_.empty()) {\n      edge_found = half_node_[0]->edges_in_[0];\n    } else if (!half_node_[0]->edges_out_.empty()) {\n      edge_found = half_node_[0]->edges_out_[0];\n    }\n    if (edge_found != nullptr) { edge_found->ComputeWeightedCost(); }\n    // Compute the weighted cost from the half nodes\n    for (int32_t merged_sig_id = 0; merged_sig_id < merged_sig_id2half_sig_id_.size();\n         merged_sig_id++) {\n      const auto& pair = merged_sig_id2half_sig_id_[merged_sig_id];\n      if (in_memory_support_) {\n        memory_[merged_sig_id] =\n            half_node_[0]->GetMemory(pair.first) + half_node_[1]->GetMemory(pair.second);\n      }\n      weighted_cost_[merged_sig_id] =\n          half_node_[0]->weighted_cost_[pair.first] + half_node_[1]->weighted_cost_[pair.second];\n      if (edge_found != nullptr) {\n        // The dimension of the weighted cost has been expanded for the found edge.\n        // Both dimensions of its weighted_cost_ are merged_sig_id2half_sig_id_.size().\n        // The start node and end node of the found edge have been changed to this node.\n        if (in_memory_support_) {\n          memory_[merged_sig_id] += edge_found->GetMemory(merged_sig_id, merged_sig_id);\n        }\n        weighted_cost_[merged_sig_id] += edge_found->weighted_cost_[merged_sig_id][merged_sig_id];\n      }\n    }\n  }\n  // Compute the weighted cost for children\n  for (auto& child_node : children_) {\n    child_node->ComputeWeightedCost();\n    for (auto& in_edge : child_node->edges_in_) { in_edge->ComputeWeightedCost(); }\n    for (auto* out_edge : child_node->edges_out_) { out_edge->ComputeWeightedCost(); }\n  }\n  // Compute the weighted cost from children\n  child_node_sbp_sig_.clear();\n  SummarizeCost();\n}\n\n// Generate the relationship between this merged node and its components\nvoid SbpNode::GenerateComponentRelationship() {\n  // Do nothing if this is not a merged node or the relationship is already generated\n  if (half_node_.empty() || !component2merged_sig_id2component_sig_id_.empty()) { return; }\n  // Add the map for the two half nodes\n  auto& first_merged2component_id = component2merged_sig_id2component_sig_id_[half_node_[0]];\n  auto& second_merged2component_id = component2merged_sig_id2component_sig_id_[half_node_[1]];\n  int32_t total_sbp_num = weighted_cost_.size();\n  first_merged2component_id.resize(total_sbp_num);\n  second_merged2component_id.resize(total_sbp_num);\n  for (int32_t i = 0; i < total_sbp_num; i++) {\n    first_merged2component_id[i] = merged_sig_id2half_sig_id_[i].first;\n    second_merged2component_id[i] = 
 merged_sig_id2half_sig_id_[i].second;\n  }\n  // Add the maps for the components of the half nodes\n  for (int32_t i = 0; i < 2; i++) {\n    half_node_[i]->GenerateComponentRelationship();\n    auto& merged2half_id = component2merged_sig_id2component_sig_id_[half_node_[i]];\n    for (auto& pair : half_node_[i]->component2merged_sig_id2component_sig_id_) {\n      auto& merged2component_id = component2merged_sig_id2component_sig_id_[pair.first];\n      merged2component_id.resize(total_sbp_num);\n      auto& half2component_id = pair.second;\n      for (int32_t merged_id = 0; merged_id < total_sbp_num; merged_id++) {\n        merged2component_id[merged_id] = half2component_id[merged2half_id[merged_id]];\n      }\n    }\n  }\n}\n\nvoid SbpNode::FinalizeSbp() {\n  if (!half_node_.empty()) {\n    // Finalize Sbp of merged nodes\n    half_node_[0]->final_sbp_sig_id_ = merged_sig_id2half_sig_id_[final_sbp_sig_id_].first;\n    half_node_[1]->final_sbp_sig_id_ = merged_sig_id2half_sig_id_[final_sbp_sig_id_].second;\n  }\n\n  // Finalize Sbp of children_\n  for (int32_t i = 0; i < children_.size(); i++) {\n    children_[i]->final_sbp_sig_id_ = child_node_sbp_sig_[i][this->final_sbp_sig_id_];\n  }\n\n  // Finalize Sbp of half_node_ Attachment\n  if (!half_node_.empty()) {\n    half_node_[0]->FinalizeSbp();\n    half_node_[1]->FinalizeSbp();\n  }\n\n  // Finalize Sbp of edges in edges_out_\n  for (const auto& edge_out : edges_out_) { edge_out->FinalizeSbp(); }\n\n  // Finalize Sbp again in case the node on the other side is not finalized yet.\n  // This may happen when the two sides of an edge are merged into two larger nodes\n  // and this edge is just a sub-edge.\n  for (const auto& edge_in : edges_in_) { edge_in->FinalizeSbp(); }\n\n  // Finalize Sbp of children_ Attachment\n  for (int32_t i = 0; i < children_.size(); i++) {\n    children_[i]->FinalizeSbp();\n    for (const auto& edge_in : children_[i]->edges_in_) { edge_in->FinalizeSbp(); }\n  }\n}\n\ndouble SbpNode::GreedyStrategy() {\n  // Current Cost, Minimum Cost, Cost with original sbp\n  double curr_cost = 0;\n  double original_cost = EvalNbhCost();\n  double min_cost = original_cost;\n  int32_t min_sbp = final_sbp_sig_id_;\n  for (int32_t sbp = 0; sbp < weighted_cost_.size(); sbp++) {\n    final_sbp_sig_id_ = sbp;\n    curr_cost = EvalNbhCost();\n    if (curr_cost < min_cost) {\n      min_cost = curr_cost;\n      min_sbp = sbp;\n    }\n  }\n  final_sbp_sig_id_ = min_sbp;\n  return min_cost - original_cost;\n}\n\ndouble SbpNode::EvalNbhCost() const {\n  // Sum of the node cost and the cost of all adjacent edges under the current sbp\n  double curr_cost = GetWeightedCost();\n  for (SbpEdge* this_edge : edges_in_) { curr_cost += this_edge->GetWeightedCost(); }\n  for (SbpEdge* this_edge : edges_out_) { curr_cost += this_edge->GetWeightedCost(); }\n  return curr_cost;\n}\n\ndouble SbpNode::EvalOutNbhCost(\n    const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id) const {\n  // check if this node is in the node list\n  CHECK(node_list_id_ >= 0) << \"Compute out cost for a node out of the node list\" << std::endl;\n  // Cost with original sbp\n  double curr_cost = GetWeightedCost();\n  for (SbpEdge* this_edge : edges_in_) {\n    // if the start node is not in the neighborhood\n    if (node_list_id2nbh_id.find(this_edge->start_node_->node_list_id_)\n        == node_list_id2nbh_id.end()) {\n      curr_cost += this_edge->GetWeightedCost();\n    }\n  }\n  for (SbpEdge* this_edge : edges_out_) {\n    // if the end node is not in the neighborhood\n    if 
(node_list_id2nbh_id.find(this_edge->end_node_->node_list_id_)\n        == node_list_id2nbh_id.end()) {\n      curr_cost += this_edge->GetWeightedCost();\n    }\n  }\n  return curr_cost;\n}\n\n// Compute the cost between this node and adjacent nodes with a lower order\ndouble SbpNode::EvalInNbhCost(const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,\n                              const std::vector<int32_t>& nbh_id2order) const {\n  // check if this node is in the node list\n  CHECK(node_list_id_ >= 0) << \"Compute in cost for a node out of the node list\";\n  // check if the node is in the neighborhood\n  const auto& this_it = node_list_id2nbh_id.find(node_list_id_);\n  CHECK(this_it != node_list_id2nbh_id.end())\n      << \"Compute in cost for a node out of the neighborhood\";\n  // Compute the minimum cost between this node and adjacent nodes with a lower order\n  int32_t order = nbh_id2order[this_it->second];\n  double curr_cost = 0;\n  for (SbpEdge* this_edge : edges_in_) {\n    const auto& it = node_list_id2nbh_id.find(this_edge->start_node_->node_list_id_);\n    // if the start node is in the neighborhood\n    if (it != node_list_id2nbh_id.end() && nbh_id2order[it->second] < order) {\n      curr_cost += this_edge->GetWeightedCost();\n      // End this function and return infinity.\n      if (curr_cost > GetValidMaxCopyCost()) { return GetMaxVal<float>(); }\n    }\n  }\n  for (SbpEdge* this_edge : edges_out_) {\n    const auto& it = node_list_id2nbh_id.find(this_edge->end_node_->node_list_id_);\n    // if the end node is in the neighborhood\n    if (it != node_list_id2nbh_id.end() && nbh_id2order[it->second] < order) {\n      curr_cost += this_edge->GetWeightedCost();\n      if (curr_cost > GetValidMaxCopyCost()) { return GetMaxVal<float>(); }\n    }\n  }\n  return curr_cost;\n}\n\ndouble SbpNode::EvalMinInNbhCost(const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,\n                                 const std::vector<int32_t>& nbh_id2order) const {\n  // check if this node is in the node list\n  CHECK(node_list_id_ >= 0) << \"Compute out cost for a node out of the node list\" << std::endl;\n  // check if the node is in the neighborhood\n  const auto& this_it = node_list_id2nbh_id.find(node_list_id_);\n  CHECK(this_it != node_list_id2nbh_id.end())\n      << \"Compute out cost for a node out of the neighborhood\" << std::endl;\n  // Compute the minimum cost between this node and adjacent nodes with a higher order\n  int32_t order = nbh_id2order[this_it->second];\n  double curr_cost = 0;\n  for (SbpEdge* this_edge : edges_in_) {\n    const auto& it = node_list_id2nbh_id.find(this_edge->start_node_->node_list_id_);\n    // if the start node is in the neighborhood\n    if (it != node_list_id2nbh_id.end() && nbh_id2order[it->second] > order) {\n      curr_cost += this_edge->GetMinWeightedCost();\n    }\n  }\n  for (SbpEdge* this_edge : edges_out_) {\n    const auto& it = node_list_id2nbh_id.find(this_edge->end_node_->node_list_id_);\n    // if the end node is in the neighborhood\n    if (it != node_list_id2nbh_id.end() && nbh_id2order[it->second] > order) {\n      curr_cost += this_edge->GetMinWeightedCost();\n    }\n  }\n  return curr_cost;\n}\n\nvoid SbpNode::OneRingNeighborhood(std::vector<int32_t>& nbh_1ring) const {\n  nbh_1ring.resize(edges_in_.size() + edges_out_.size() + 1);\n  int32_t nbh_id = 0;\n  nbh_1ring[nbh_id] = node_list_id_;\n  for (SbpEdge* this_edge : edges_in_) {\n    nbh_id++;\n    nbh_1ring[nbh_id] = this_edge->start_node_->node_list_id_;\n  
}\n  for (SbpEdge* this_edge : edges_out_) {\n    nbh_id++;\n    nbh_1ring[nbh_id] = this_edge->end_node_->node_list_id_;\n  }\n}\n\n// Get the n ring neighborhood of this node\n// Pre-allocate buffer, which will be faster.\nvoid SbpNode::NRingNeighborhood(int32_t n, std::vector<int32_t>& nbh_n_ring,\n                                std::vector<int32_t>& nbh_1ring,\n                                const std::vector<SbpNode*>& node_list,\n                                std::vector<bool>& node_tags) const {\n  // Initialize 0 ring\n  if (n <= 0) { n = 0; }\n  nbh_n_ring.resize(1);\n  nbh_n_ring[0] = node_list_id_;\n  node_tags[node_list_id_] = true;\n  int32_t l = 0;\n  // do ring expansion for n times\n  for (int32_t i = 0; i < n; i++) {\n    for (int32_t r = nbh_n_ring.size(); l < r; l++) {\n      node_list[nbh_n_ring[l]]->OneRingNeighborhood(nbh_1ring);\n      for (auto nbh_id : nbh_1ring) {\n        if (!node_tags[nbh_id]) {\n          nbh_n_ring.push_back(nbh_id);\n          node_tags[nbh_id] = true;\n        }\n      }\n    }\n  }\n  // Recover false for buffer\n  for (auto nbh_id : nbh_n_ring) { node_tags[nbh_id] = false; }\n}\n\n// Get or compute the minimum layer of this node\nint32_t SbpNode::GetMinLayer(\n    const HashMap<std::string, SbpNode*>& op_name2sbp_node,\n    const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps) {\n  if (min_layer_ >= 0) { return min_layer_; }\n  if (!op_node_) { return min_layer_; }\n  for (SbpEdge* this_edge : edges_in_) {\n    int32_t producer_min_layer =\n        this_edge->start_node_->GetMinLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps);\n    if (producer_min_layer > min_layer_) { min_layer_ = producer_min_layer; }\n  }\n  for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {\n    const auto& it = op_name2sbp_node.find(ctrl_in_op_name);\n    if (it != op_name2sbp_node.end()) {\n      int32_t producer_min_layer =\n          it->second->GetMinLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps);\n      if (producer_min_layer > min_layer_) { min_layer_ = producer_min_layer; }\n    }\n  }\n  if (op_node2mutable_op_ctrl_deps.find(op_node_) != op_node2mutable_op_ctrl_deps.end()) {\n    for (const auto& ctrl_in_op_name : op_node2mutable_op_ctrl_deps.at(op_node_)) {\n      const auto& it = op_name2sbp_node.find(ctrl_in_op_name);\n      if (it != op_name2sbp_node.end()) {\n        int32_t producer_min_layer =\n            it->second->GetMinLayer(op_name2sbp_node, op_node2mutable_op_ctrl_deps);\n        if (producer_min_layer > min_layer_) { min_layer_ = producer_min_layer; }\n      }\n    }\n  }\n  return ++min_layer_;\n}\n\n// Spread the minimum layer to compute the maximum layer of producers\nvoid SbpNode::SpreadMaxLayer(\n    const HashMap<std::string, SbpNode*>& op_name2sbp_node,\n    const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps) {\n  if (min_layer_ <= 0) { return; }\n  int32_t producer_max_lay = min_layer_ - 1;\n  for (SbpEdge* this_edge : edges_in_) { this_edge->start_node_->DropMaxLayer(producer_max_lay); }\n  for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {\n    const auto& it = op_name2sbp_node.find(ctrl_in_op_name);\n    if (it != op_name2sbp_node.end()) { it->second->DropMaxLayer(producer_max_lay); }\n  }\n  if (op_node2mutable_op_ctrl_deps.find(op_node_) != op_node2mutable_op_ctrl_deps.end()) {\n    for (const auto& ctrl_in_op_name : op_node2mutable_op_ctrl_deps.at(op_node_)) {\n      const auto& it = 
 op_name2sbp_node.find(ctrl_in_op_name);\n      if (it != op_name2sbp_node.end()) { it->second->DropMaxLayer(producer_max_lay); }\n    }\n  }\n}\n\n// Drop down the maximum layer with the minimum layer from the consumer\nvoid SbpNode::DropMaxLayer(int32_t upper_bound) {\n  if (upper_bound < max_layer_ || max_layer_ < 0) { max_layer_ = upper_bound; }\n}\n\n// Set max_layer_ = min_layer_ if this node does not have any consumer\n// This is the end of the whole graph\n// We could also set it to be the maximum of the min_layer_ in the graph. (It should be the same.)\nvoid SbpNode::LiftMaxLayer() {\n  if (max_layer_ < min_layer_) { max_layer_ = min_layer_; }\n}\n\n// Set max_layer_ = upper_bound if this node does not have any consumer\nvoid SbpNode::LiftMaxLayer(int32_t upper_bound) {\n  if (max_layer_ < min_layer_) { max_layer_ = upper_bound; }\n}\n\n// Get the minimum element in Cost\ndouble SbpNode::GetMinCost() const {\n  // Check the size of Cost\n  // Cannot use the weighted cost here since this function is used for finding the trunk.\n  // We have not initialized the weighted cost at this moment.\n  CHECK(cost_.size() > 0) << \"Cost not initialized!\" << std::endl;\n  // Compute the min_comp_cost\n  return *std::min_element(cost_.begin(), cost_.end());\n}\n\n// Get the cut ratio\ndouble SbpNode::GetCutRatio() const {\n  double curr_cut_ratio = 1.0;\n  for (auto* this_edge : edges_in_) { curr_cut_ratio *= this_edge->GetCutRatio(); }\n  for (auto* this_edge : edges_out_) { curr_cut_ratio *= this_edge->GetCutRatio(); }\n  return curr_cut_ratio;\n}\n\n// Judge if this node is on the trunk\n// If so, judge it for its producer/upstream nodes\nvoid SbpNode::SpreadTrunk(const HashMap<std::string, SbpNode*>& op_name2sbp_node) {\n  // Skip it if this node is already judged.\n  if (on_trunk_) { return; }\n  // Skip sbp proxy. 
This is before we have any proxy.\n  if (min_layer_ < 0) { return; }\n  on_trunk_ = true;\n  // If I am in the trunk, then all the children with (min_layer_ >= my layer id - 1) would be\n  // considered as in the trunk\n  for (SbpEdge* this_edge : edges_in_) {\n    if (this_edge->start_node_->min_layer_ >= min_layer_ - 1) {\n      this_edge->start_node_->SpreadTrunk(op_name2sbp_node);\n    }\n  }\n  for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {\n    const auto& it = op_name2sbp_node.find(ctrl_in_op_name);\n    if (it != op_name2sbp_node.end() && it->second->min_layer_ >= min_layer_ - 1) {\n      it->second->SpreadTrunk(op_name2sbp_node);\n    }\n  }\n}\n\n// Count consumers and any downstream nodes defined by control edges\nvoid SbpNode::RaiseConsumerNum(const HashMap<std::string, SbpNode*>& op_name2sbp_node) {\n  // Should clear counter_ before running.\n  // skip the proxy nodes and the sources\n  if (min_layer_ <= 0) { return; }\n  for (SbpEdge* this_edge : edges_in_) { this_edge->start_node_->counter_++; }\n  for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {\n    const auto& it = op_name2sbp_node.find(ctrl_in_op_name);\n    if (it != op_name2sbp_node.end()) { it->second->counter_++; }\n  }\n}\n\n// Compute the minimal available wait time for producers or upstream nodes\nvoid SbpNode::SpreadAvailWaitTime(const std::vector<double>& trunk_cost,\n                                  const std::vector<double>& acc_trunk_cost,\n                                  const HashMap<std::string, SbpNode*>& op_name2sbp_node,\n                                  double wait_time) {\n  // skip the proxy nodes and the sources\n  if (min_layer_ <= 0) { return; }\n  // Have not finished spreading for consumers or downstream nodes, or already visited.\n  if (counter_) { return; }\n  if (on_trunk_) {\n    // Nodes on the trunk do not have any accumulated cost\n    acc_trunk_cost_ = 0;\n  } else {\n    if (acc_trunk_cost_ < 0) {\n      // Do not have any consumer or downstream node\n      acc_trunk_cost_ = acc_trunk_cost[min_layer_ - 1];\n    } else {\n      // Add the trunk cost at this layer\n      acc_trunk_cost_ += trunk_cost[min_layer_];\n    }\n  }\n\n  // Reduce the wait time for edges_in_, put the rest of the trunk cost in the producers\n  for (SbpEdge* this_edge : edges_in_) {\n    CHECK(this_edge->wait_time_ < 0)\n        << \"Double assign values into wait_time_ of this edge!\" << std::endl;\n    SbpNode* producer = this_edge->start_node_;\n    // Accumulate the cost from the start node to this node\n    double curr_trunk_cost =\n        acc_trunk_cost_ + acc_trunk_cost[producer->min_layer_] - acc_trunk_cost[min_layer_ - 1];\n    if (curr_trunk_cost >= wait_time) {\n      // The remaining cost in the trunk is able to cover all the wait time\n      this_edge->wait_time_ = 0.0;\n      curr_trunk_cost -= wait_time;\n    } else {\n      // The remaining cost in the trunk can only cover part of the wait time\n      this_edge->wait_time_ = wait_time - curr_trunk_cost;\n      curr_trunk_cost = 0.0;\n    }\n    // Reducing non-matching edges\n    // For example:\n    // (1) P->S0->S0->S0->B\n    // (2) P->B->B->B->B\n    // We would use (2) when the tensor is relatively tiny.\n    // Do not inherit trunk cost for nodes on the trunk\n    if (!producer->on_trunk_) {\n      // Inherit the minimum of the trunk cost from consumers\n      producer->DropAvailWaitTime(curr_trunk_cost);\n    }\n    producer->counter_--;\n    producer->SpreadAvailWaitTime(trunk_cost, acc_trunk_cost, 
 op_name2sbp_node, wait_time);\n  }\n  // Put the rest of the trunk cost in the upstream nodes.\n  for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {\n    const auto& it = op_name2sbp_node.find(ctrl_in_op_name);\n    if (it != op_name2sbp_node.end()) {\n      SbpNode* producer = it->second;\n      // Do not inherit trunk cost for nodes on the trunk\n      if (!producer->on_trunk_) {\n        // Accumulate the cost from the start node to this node\n        double curr_trunk_cost =\n            acc_trunk_cost_ + acc_trunk_cost[producer->min_layer_] - acc_trunk_cost[min_layer_ - 1];\n        // Inherit the minimum of the trunk cost from consumers\n        producer->DropAvailWaitTime(curr_trunk_cost);\n      }\n      producer->counter_--;\n      producer->SpreadAvailWaitTime(trunk_cost, acc_trunk_cost, op_name2sbp_node, wait_time);\n    }\n  }\n  // Set counter_ to -1, do not visit this node again.\n  counter_--;\n}\n\n// Drop down the available wait time with the minimum cost from downstream\nvoid SbpNode::DropAvailWaitTime(double curr_trunk_cost) {\n  if (acc_trunk_cost_ < 0.0 || acc_trunk_cost_ > curr_trunk_cost) {\n    acc_trunk_cost_ = curr_trunk_cost;\n  }\n}\n\n// Assemble copy cost and partial memory cost for all the incoming edges\nvoid SbpNode::InitCopyAndMemoryCost(bool use_sbp_collector, bool nccl_not_use_compute_stream) {\n  for (SbpEdge* this_edge : edges_in_) {\n    const auto* sbp_node_producer = this_edge->start_node_;\n    OpNode* producer = sbp_node_producer->op_node_;\n\n    // skip it if proxy\n    if (use_sbp_collector && !producer) { continue; }\n    // look through input blobs\n    for (const std::string& ibn : op_node_->op().input_bns()) {\n      if (producer->op().op_name() == op_node_->SrcNode4Ibn(ibn).op().op_name()) {\n        this_edge->InitCopyAndMemoryCost(ibn, use_sbp_collector, nccl_not_use_compute_stream);\n      }\n    }\n    // Add Wait time\n    for (auto& cost_row : this_edge->cost_) {\n      for (auto& cost_value : cost_row) {\n        // If transferring between devices, we need to add wait time.\n        if (cost_value > 0.0) { cost_value += this_edge->wait_time_; }\n      }\n    }\n  }\n}\n\n// Assemble memory cost\nvoid SbpNode::InitializeMemory(bool is_reusable, const HashMap<LogicalBlobId, int32_t>& lbi2id,\n                               const std::vector<int32_t>& id2count, bool nccl_use_compute_stream) {\n  const auto& curr_operator = op_node_->op();\n  // An edge should not be initialized twice.\n  // During each initialization, we are computing sum(memory of consumer) - sum(memory of producer).\n  // This is why we need to pre-store the memory of the producer.\n  HashMap<SbpEdge*, std::vector<int64_t>> sbp_edge2nd_sbp_sig2memory;\n  for (const auto& obn : curr_operator.output_bns()) {\n    const LogicalBlobId& lbi = curr_operator.BnInOp2Lbi(obn);\n    // Fixed memory or in the support of the reusable memory\n    if (!is_reusable || id2count.at(lbi2id.at(lbi)) > 0) {\n      // If not in the support, memory_ would be empty.\n      in_memory_support_ = true;\n      memory_.resize(sbp_sig_list_.size(), 0);\n      const auto& logical_blob_desc = op_node_->LogicalBlobDesc4Lbi(lbi);\n      const auto& hierarchy = *CHECK_JUST(curr_operator.GetParallelDesc4BnInOp(obn))->hierarchy();\n      // There are some operators with a fixed sbp for some blobs, such as conv.\n      // {in: S0, kernel: B, out: S0}\n      // {in: B, kernel: B, out: B}\n      // The blob kernel has the same sbp for different signatures.\n      // We pre-store the results for 
 the same sbp while accessing the same blobs.\n      HashMap<NdSbp, int64_t> nd_sbp2memory;\n      SbpEdge* edge_contain_lbi = nullptr;\n      for (const auto& edge_out : edges_out_) {\n        if (edge_out->SearchLbi(lbi)) { edge_contain_lbi = edge_out; }\n      }\n      // There exist some lbis which do not have a consumer;\n      // in that case edge_contain_lbi == nullptr.\n      auto& nd_sbp_sig2memory = sbp_edge2nd_sbp_sig2memory[edge_contain_lbi];\n      nd_sbp_sig2memory.resize(sbp_sig_list_.size(), 0);\n      for (int32_t sbp_sig_id = 0; sbp_sig_id < sbp_sig_list_.size(); sbp_sig_id++) {\n        const NdSbp& nd_sbp = sbp_sig_list_[sbp_sig_id].bn_in_op2nd_sbp().at(obn);\n        auto it = nd_sbp2memory.find(nd_sbp);\n        if (it == nd_sbp2memory.end()) {\n          // This computes the memory at rank 0, the largest one.\n          // We could be faster if we just computed the average memory.\n          it = nd_sbp2memory\n                   .insert({nd_sbp, MaxByteSize4BlobDescSbp(logical_blob_desc, nd_sbp, hierarchy)})\n                   .first;\n        }\n        memory_[sbp_sig_id] += it->second;\n        nd_sbp_sig2memory[sbp_sig_id] += it->second;\n      }\n    }\n  }\n  // Even after the correction in the memory of edges, the relative error is still 0.73%.\n  if (nccl_use_compute_stream && in_memory_support_ && is_reusable) {\n    for (const auto& pair : sbp_edge2nd_sbp_sig2memory) {\n      // Init memory for each out-going edge\n      pair.first->InitializeMemory(lbi2id, id2count, pair.second);\n    }\n  }\n}\n\n// Reduce and set the wait time for ops in the trunk\nvoid SbpNode::SetTrunkWaitTime(double trunk_wait_time) {\n  // only reduce the wait time for operators in the trunk\n  if (on_trunk_) {\n    // Reduce the wait time for edges_out_\n    for (SbpEdge* edge_out : edges_out_) {\n      if (edge_out->wait_time_ < 0.0 || edge_out->wait_time_ > trunk_wait_time) {\n        edge_out->wait_time_ = trunk_wait_time;\n      }\n    }\n    // Might reduce it for edges_in_\n  }\n}\n\n// Drop down the tributary layer with the minimum layer from the consumer\nvoid SbpNode::DropTributaryLayer(int32_t upper_bound) {\n  if (upper_bound < tributary_layer_ || tributary_layer_ < 0) { tributary_layer_ = upper_bound; }\n}\n\n// Compute maximum layer for tributaries\nvoid SbpNode::SpreadTributaryLayer(const HashMap<std::string, SbpNode*>& op_name2sbp_node) {\n  if (counter_ || min_layer_ <= 0) { return; }\n  int32_t producer_max_lay = 0;\n  if (on_trunk_) {\n    producer_max_lay = min_layer_ - 1;\n  } else {\n    // On a tributary, the operator could be run later.\n    producer_max_lay = tributary_layer_;\n    // producer_max_lay = tributary_layer_ - 1;\n  }\n  for (SbpEdge* this_edge : edges_in_) {\n    this_edge->start_node_->DropTributaryLayer(producer_max_lay);\n    if (--this_edge->start_node_->counter_ == 0) {\n      this_edge->start_node_->SpreadTributaryLayer(op_name2sbp_node);\n    }\n  }\n  for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) {\n    const auto& it = op_name2sbp_node.find(ctrl_in_op_name);\n    if (it != op_name2sbp_node.end()) {\n      it->second->DropTributaryLayer(producer_max_lay);\n      if (--it->second->counter_ == 0) { it->second->SpreadTributaryLayer(op_name2sbp_node); }\n    }\n  }\n  counter_--;\n}\n\nSbpEdge* SbpNode::FindEdgeWithNode(const SbpNode* other_node) const {\n  for (auto* sbp_edge : edges_in_) {\n    if (sbp_edge->start_node_ == other_node) { return sbp_edge; }\n  }\n  for (auto* sbp_edge : edges_out_) {\n    if 
 (sbp_edge->end_node_ == other_node) { return sbp_edge; }\n  }\n  return nullptr;\n};\n\n// Decide to use this SbpSignature\nconst NdSbpSignature& SbpNode::FinalSbpSignature() const {\n  CHECK(!sbp_sig_list_.empty()) << \"Asking for sbp signature for an empty node\";\n  return sbp_sig_list_[final_sbp_sig_id_];\n};\n\nint32_t SbpNode::GetComponentSbpId(int32_t merged_id, SbpNode* component_node) const {\n  if (this == component_node) { return merged_id; }\n  CHECK(!component2merged_sig_id2component_sig_id_.empty())\n      << \"Check the component before initialization!\" << std::endl;\n  return component2merged_sig_id2component_sig_id_.at(component_node).at(merged_id);\n}\n\n// Judge if sbp_node is a part of the current node\nbool SbpNode::IsComponent(SbpNode* sbp_node) const {\n  if (this == sbp_node) { return true; }\n  // If IsComponent() is called before we initialize component2merged_sig_id2component_sig_id_,\n  // we would also return false.\n  // Please do not call GenerateComponentRelationship() here.\n  // Please see SbpEdge::SummarizeCost() for more details.\n  return component2merged_sig_id2component_sig_id_.find(sbp_node)\n         != component2merged_sig_id2component_sig_id_.end();\n}\n\n}  // namespace auto_parallel\n}  // namespace oneflow\n"
  },
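  {
    "path": "docs/auto_parallel_sketches/node_merging_signatures_sketch.cpp",
    "content": "// NOTE (editor): Illustrative sketch added during review; this file is NOT part of the original\n// sources and all numbers are hypothetical. The merging constructor SbpNode(SbpNode* first,\n// SbpNode* second) enumerates the cross product of both nodes' signatures and keeps only the\n// pairs whose connecting edge cost stays below a validity cap (GetValidMaxCopyCost() in the real\n// code); this is how merged_sig_id2half_sig_id_ and the merged cost list are built.\n#include <cstdint>\n#include <iostream>\n#include <utility>\n#include <vector>\n\nint main() {\n  const double kValidMaxCopyCost = 100.0;        // hypothetical validity cap\n  std::vector<double> first_cost = {1.0, 2.0};   // first node, cost per signature\n  std::vector<double> second_cost = {3.0, 4.0};  // second node, cost per signature\n  // edge_cost[i][j]: copy cost when first uses signature i and second uses signature j.\n  // The 1e9 entry plays the role of an infeasible (blown-up) boxing cost.\n  std::vector<std::vector<double>> edge_cost = {{0.0, 1e9}, {5.0, 0.0}};\n\n  std::vector<std::pair<int32_t, int32_t>> merged_sig_id2half_sig_id;\n  std::vector<double> merged_cost;\n  for (int32_t i = 0; i < 2; i++) {\n    for (int32_t j = 0; j < 2; j++) {\n      if (edge_cost[i][j] < kValidMaxCopyCost) {\n        merged_sig_id2half_sig_id.emplace_back(i, j);\n        merged_cost.push_back(edge_cost[i][j] + first_cost[i] + second_cost[j]);\n      }\n    }\n  }\n  for (std::size_t sig_id = 0; sig_id < merged_cost.size(); sig_id++) {\n    std::cout << \"merged signature \" << sig_id << \" = (\"\n              << merged_sig_id2half_sig_id[sig_id].first << \", \"\n              << merged_sig_id2half_sig_id[sig_id].second << \"), cost \" << merged_cost[sig_id]\n              << std::endl;  // (0, 0) cost 4, (1, 0) cost 10, (1, 1) cost 6\n  }\n  return 0;\n}\n"
  },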
  {
    "path": "oneflow/core/auto_parallel/sbp_node.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_NODE_H_\n#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_NODE_H_\n\n#include <cstdlib>\n#include <functional>\n#include <iostream>\n#include <vector>\n#include \"oneflow/core/auto_parallel/binary_set.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/hash_container.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/auto_parallel/algorithm_util.h\"\n#include \"oneflow/core/job/sbp_parallel.pb.h\"\n\nnamespace oneflow {\nnamespace auto_parallel {\n\nclass SbpEdge;\n\n// A node structure to deal with the SBP strategy.\n// Please see SbpGraph for the whole algorithm and introduction.\nclass SbpNode final {\n public:\n  // default constructor\n  SbpNode() : final_sbp_sig_id_(0) {}\n\n  // This constructor is to merge two node into one\n  SbpNode(SbpNode* first, SbpNode* second);\n\n  ~SbpNode();\n\n  OF_DISALLOW_COPY_AND_MOVE(SbpNode);\n  bool operator==(const SbpNode& other) { return this == &other; }\n\n  // another node point to this node\n  void PointFrom(SbpNode* start_node);\n  // this node point to another node\n  void PointTo(SbpNode* end_node);\n\n  SbpEdge* FindEdgeWithNode(const SbpNode* other_node) const;\n\n  // Check and eliminate one child node.\n  // Only used by SbpGraph since it need to remove it from the NodeList after this.\n  bool EliminateItselfAsChild();\n\n  // Initialize SbpSignature from Signature Objects\n  void InitializeSbp();\n  // Decide to use this SbpSignature\n  const NdSbpSignature& FinalSbpSignature() const;\n\n  // Recompute Computation Cost after adding child nodes in it\n  void SummarizeCost();\n  // Compute the weighted sum of the time and memory cost\n  void ComputeWeightedCost();\n  // Generate the relationship between this merged node and its components\n  void GenerateComponentRelationship();\n  // Determine Final SbpSignature for attachment of this node\n  void FinalizeSbp();\n  // Use Greedy Strategy to pick the sbp signature with minimum cost for this\n  // node You should have an initial strategy before running this\n  double GreedyStrategy();\n  // Evaluate summery of cost between neighborhood and outside nodes\n  double EvalOutNbhCost(const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id) const;\n  // Evaluate summery of cost within neighborhood\n  // We only accumulate the edge cost with a lower order.\n  double EvalInNbhCost(const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,\n                       const std::vector<int32_t>& nbh_id2order) const;\n  // Evaluate summery of cost within neighborhood\n  // We only accumulate the minimum edge cost with a higher order.\n  double EvalMinInNbhCost(const std::unordered_map<int32_t, int32_t>& node_list_id2nbh_id,\n                          const std::vector<int32_t>& nbh_id2order) const;\n  // 
Get the one ring neighborhood of this node, which is itself and all the adjacent nodes.\n  void OneRingNeighborhood(std::vector<int32_t>& nbh_1ring) const;\n  // Get the n ring neighborhood of this node\n  // Pre-allocate buffer, which will be faster.\n  void NRingNeighborhood(int32_t n, std::vector<int32_t>& nbh_n_ring,\n                         std::vector<int32_t>& nbh_1ring, const std::vector<SbpNode*>& node_list,\n                         std::vector<bool>& node_tags) const;\n\n  // Get or compute the minimum layer of this node\n  int32_t GetMinLayer(\n      const HashMap<std::string, SbpNode*>& op_name2sbp_node,\n      const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps);\n  // Spread the minimum layer to compute the maximum layer of producers\n  void SpreadMaxLayer(\n      const HashMap<std::string, SbpNode*>& op_name2sbp_node,\n      const HashMap<const OpNode*, HashSet<std::string>>& op_node2mutable_op_ctrl_deps);\n  // Set max_layer_ = min_layer_ if this node does not have any consumer\n  void LiftMaxLayer();\n  // Set max_layer_ = upper_bound if this node does not have any consumer\n  void LiftMaxLayer(int32_t upper_bound);\n  // Compute maximum layer for tributaries\n  void SpreadTributaryLayer(const HashMap<std::string, SbpNode*>& op_name2sbp_node);\n  // Drop down the tributary layer\n  void DropTributaryLayer(int32_t upper_bound);\n\n  // Get the minimum element in Cost\n  double GetMinCost() const;\n  // get the cut ratio\n  double GetCutRatio() const;\n\n  // Judge if this node is on the trunk\n  // If so, judge it for its producer/upstream nodes\n  void SpreadTrunk(const HashMap<std::string, SbpNode*>& op_name2sbp_node);\n  // Count consumers and any downstream nodes defined by control edges\n  // for producers or upstream nodes\n  void RaiseConsumerNum(const HashMap<std::string, SbpNode*>& op_name2sbp_node);\n  // Compute the minimal available wait time for producers or upstream nodes\n  void SpreadAvailWaitTime(const std::vector<double>& trunk_cost,\n                           const std::vector<double>& acc_trunk_cost,\n                           const HashMap<std::string, SbpNode*>& op_name2sbp_node,\n                           double wait_time);\n  // Reduce and set the wait time for op in the trunk\n  void SetTrunkWaitTime(double trunk_wait_time);\n\n  // Assemble copy cost and partial memory cost for all the incoming edges\n  void InitCopyAndMemoryCost(bool use_sbp_collector, bool nccl_not_use_compute_stream);\n  // Assemble memory cost\n  void InitializeMemory(bool is_reusable, const HashMap<LogicalBlobId, int32_t>& lbi2id,\n                        const std::vector<int32_t>& id2count, bool nccl_use_compute_stream);\n\n  // Constant getter\n  int32_t GetMinLayer() const { return min_layer_; }\n  int32_t GetTributaryLayer() const { return tributary_layer_; }\n  OpNode* GetOperatorNode() const { return op_node_; }\n  const std::vector<SbpEdge*>& GetEdgesIn() const { return edges_in_; }\n  const std::vector<SbpEdge*>& GetEdgesOut() const { return edges_out_; }\n  int64_t GetMemory(int32_t i) const { return in_memory_support_ ? 
 memory_[i] : 0; }\n  // Get the current memory with the current sbp signature index\n  int64_t GetMemory() const { return GetMemory(final_sbp_sig_id_); }\n  double GetWeightedCost(int32_t i) const { return weighted_cost_[i]; }\n  // Get the current weighted cost with the current sbp signature index\n  double GetWeightedCost() const { return GetWeightedCost(final_sbp_sig_id_); }\n  int32_t GetComponentSbpId(int32_t merged_id, SbpNode* component_node) const;\n  // Judge if sbp_node is a part of the current node\n  bool IsComponent(SbpNode* sbp_node) const;\n\n  // Setter\n  void SetInMemorySupport(bool in_memory_support) { in_memory_support_ = in_memory_support; }\n\n private:\n  friend class SbpEdge;\n  friend class SbpGraph;\n  friend class SbpCollector;\n  friend class SbpConstructor;\n\n  // compound edge in\n  std::vector<SbpEdge*> edges_in_;\n  // compound edge out\n  std::vector<SbpEdge*> edges_out_;\n\n  // Location in node_list of SbpGraph\n  int32_t node_list_id_ = -1;\n  // Global SbpSignature List Size\n  int32_t global_sbp_sig_size_ = -1;\n  // Decide to use SbpSignature with this id\n  int32_t final_sbp_sig_id_;\n  // Available SbpSignature object for this node\n  std::vector<NdSbpSignature> sbp_sig_list_;\n  // Cost[sbp] is Computation Cost when using sbp_sig_list_[sbp]\n  std::vector<double> cost_;\n  std::vector<double> origin_cost_;\n\n  // Child node list\n  std::vector<SbpNode*> children_;\n  // SbpSignature for each child node when using a specific SbpSignature for this\n  // node. Its dimension is Number of Child Nodes * Number of Available\n  // SbpSignatures for this node.\n  std::vector<std::vector<int32_t>> child_node_sbp_sig_;\n\n  // Merge two nodes into this compound node\n  std::vector<SbpNode*> half_node_;\n  // We should delete those merged signatures which have a very large cost, for speed-up.\n  // New sbp_sig_list_ index maps to each half_node_'s sig_index\n  std::vector<std::pair<int32_t, int32_t>> merged_sig_id2half_sig_id_;\n\n  std::vector<BinarySet> parallel_candidates_;\n\n  OpNode* op_node_ = nullptr;\n\n  // We divide the sbp graph into multiple layers.\n  // min_layer_ is the minimum layer number to run this op as soon as possible.\n  // max_layer_ is the maximum layer number without slowing down the whole process of the graph.\n  // producer.max_layer_ < this_node.min_layer_ <= this_node.max_layer_ < consumer.min_layer_\n  int32_t min_layer_ = -1, max_layer_ = -1;\n  // Maximum layer in tributaries\n  int32_t tributary_layer_ = -1;\n  // Whether we are on the trunk\n  bool on_trunk_ = false;\n  // A counter_ buffer for topological traversal or something else\n  int32_t counter_ = 0;\n  // Accumulate trunk cost from consumer to the end\n  double acc_trunk_cost_ = -1.0;\n\n  // The produced blob belongs to the support of the total memory\n  bool in_memory_support_ = false;\n  // The consumed memory for different sbp strategies\n  std::vector<int64_t> memory_;\n  std::vector<int64_t> origin_memory_;\n  // The weighted sum of time cost and memory cost\n  // More specifically, weighted cost = time cost + kMemoryRatio * memory;\n  // We do not add any weight for the time cost since we need to judge if a cost is less than\n  // GetValidMaxCopyCost().\n  std::vector<double> weighted_cost_;\n  // Relationship between a merged node and its components\n  HashMap<SbpNode*, std::vector<int32_t>> component2merged_sig_id2component_sig_id_;\n\n  // Let one node point to another\n  void StartPointToEnd(SbpNode* start_node, SbpNode* end_node);\n\n  // Evaluate summary of 
cost in 1-ring neighborhood.\n  double EvalNbhCost() const;\n  // Drop down the maximum layer with the minimum layer from consumer\n  void DropMaxLayer(int32_t upper_bound);\n  // Drop down the available wait time with the minimum cost from downstream\n  void DropAvailWaitTime(double curr_trunk_cost);\n};  // class SbpNode\n\n}  // namespace auto_parallel\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_AUTO_PARALLEL_SBP_NODE_H_\n"
  },
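  {
    "path": "docs/auto_parallel_sketches/n_ring_neighborhood_sketch.cpp",
    "content": "// NOTE (editor): Illustrative sketch added during review; this file is NOT part of the original\n// sources and the graph below is hypothetical. It mirrors the loop structure of\n// SbpNode::NRingNeighborhood() on a plain adjacency list: the neighborhood grows ring by ring, a\n// reusable node_tags buffer deduplicates nodes, and the tags are reset at the end for reuse.\n#include <cstddef>\n#include <cstdint>\n#include <iostream>\n#include <vector>\n\nstd::vector<int32_t> NRingNeighborhood(int32_t start, int32_t n,\n                                       const std::vector<std::vector<int32_t>>& adjacency,\n                                       std::vector<bool>& node_tags) {\n  std::vector<int32_t> nbh_n_ring = {start};\n  node_tags[start] = true;\n  std::size_t l = 0;  // first node whose neighbors have not been expanded yet\n  for (int32_t ring = 0; ring < n; ring++) {\n    for (std::size_t r = nbh_n_ring.size(); l < r; l++) {\n      for (int32_t next : adjacency[nbh_n_ring[l]]) {\n        if (!node_tags[next]) {\n          nbh_n_ring.push_back(next);\n          node_tags[next] = true;\n        }\n      }\n    }\n  }\n  for (int32_t node_id : nbh_n_ring) { node_tags[node_id] = false; }  // recover the buffer\n  return nbh_n_ring;\n}\n\nint main() {\n  // Chain 0 - 1 - 2 - 3 plus a branch 1 - 4.\n  std::vector<std::vector<int32_t>> adjacency = {{1}, {0, 2, 4}, {1, 3}, {2}, {1}};\n  std::vector<bool> node_tags(adjacency.size(), false);\n  for (int32_t node_id : NRingNeighborhood(1, 2, adjacency, node_tags)) {\n    std::cout << node_id << \" \";  // expected: 1 0 2 4 3\n  }\n  std::cout << std::endl;\n  return 0;\n}\n"
  },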
  {
    "path": "oneflow/core/auto_parallel/sbp_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <memory>\n#include \"oneflow/core/auto_parallel/sbp_util.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h\"\n\nnamespace oneflow {\nnamespace auto_parallel {\n\n// Judge whether we need the same SBP for both producer and consumer\nbool RequireSameSbp(const OpNode* consumer, const std::string& ibn) {\n  // is mutable\n  const auto& input_blob_modifier_ = consumer->op().InputBlobModifier4Ibn(ibn);\n  if (input_blob_modifier_.has_is_mutable() && input_blob_modifier_.is_mutable()) { return true; }\n  // kOFRecord or kTensorBuffer don't accept boxing\n  const LogicalBlobId& lbi = consumer->op().BnInOp2Lbi(ibn);\n  const OpNode& producer = consumer->ProducerOpNode4Lbi(lbi);\n  const BlobDesc& logical_blob_desc = producer.LogicalBlobDesc4Lbi(lbi);\n  return (logical_blob_desc.data_type() == DataType::kOFRecord\n          || logical_blob_desc.data_type() == DataType::kTensorBuffer);\n}\n\n}  // namespace auto_parallel\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/auto_parallel/sbp_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_UTIL_H_\n#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_UTIL_H_\n\n#include \"oneflow/core/graph/op_graph.h\"\n\nnamespace oneflow {\nnamespace auto_parallel {\n\n// Judge whether we need the same SBP for both producer and consumer\nbool RequireSameSbp(const OpNode* consumer, const std::string& ibn);\n\n}  // namespace auto_parallel\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_AUTO_PARALLEL_SBP_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/autograd/autograd_captured_tensor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_CAPTURED_TENSOR_H_\n#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_CAPTURED_TENSOR_H_\n\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nclass AutogradCapturedTensor final : public ProxyTensor<AutogradCapturedTensor> {\n public:\n  static Maybe<AutogradCapturedTensor> MakeTensor(const std::shared_ptr<Tensor>& tensor) {\n    if (tensor->requires_grad()) {\n      CHECK_NOTNULL_OR_RETURN(tensor->grad_fn_node().get())\n          << Error::RuntimeError()\n          << \"a grad function node is expected for the captured tensor \"\n             \"which requires_grad is True.\";\n    }\n    std::shared_ptr<AutogradCapturedTensor> captured_tensor(\n        new AutogradCapturedTensor(JUST(tensor->detach())));\n    captured_tensor->set_autograd_meta(tensor->mut_autograd_meta());\n    captured_tensor->grad_fn_node_ = tensor->mut_grad_fn_node();\n    return captured_tensor;\n  }\n\n  std::shared_ptr<const FunctionNode> grad_fn_node() const override { return grad_fn_node_.lock(); }\n  void set_grad_fn_node(const std::shared_ptr<FunctionNode>& grad_fn_node) override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n  }\n  std::shared_ptr<FunctionNode> mut_grad_fn_node() override { return grad_fn_node_.lock(); }\n\n  std::shared_ptr<Tensor> contiguous() const override {\n    const auto& tensor = std::const_pointer_cast<Tensor>(shared_from_this());\n    if (tensor_->is_contiguous()) { return tensor; }\n    return CHECK_JUST(functional::ToContiguous(tensor));\n  }\n\n private:\n  explicit AutogradCapturedTensor(const std::shared_ptr<Tensor>& tensor)\n      : ProxyTensor<AutogradCapturedTensor>(tensor) {}\n\n private:\n  std::weak_ptr<FunctionNode> grad_fn_node_;\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_CAPTURED_TENSOR_H_\n"
  },
  {
    "path": "oneflow/core/autograd/autograd_engine.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <memory>\n#include <stack>\n#include <queue>\n#include \"fmt/core.h\"\n#include \"fmt/format.h\"\n#include \"oneflow/core/autograd/autograd_engine.h\"\n#include \"oneflow/core/autograd/autograd_meta.h\"\n#include \"oneflow/core/autograd/autograd_mode.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/error.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_arg.h\"\n#include \"oneflow/core/framework/tensor_methods.h\"\n#include \"oneflow/core/framework/tensor_util.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/tensor_rpc_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/global_param_grad_sync_mode.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n\nnamespace oneflow {\nnamespace one {\n\nnamespace {\n\nvoid GatherFunctionNodes(FunctionNode* node, std::stack<std::shared_ptr<FunctionNode>>& stack) {\n  for (auto& prev_node : node->next_functions()) {\n    auto prev_node_fun = std::get<0>(prev_node);\n    if (prev_node_fun) {\n      if (prev_node_fun.use_count() == 1) { stack.push(prev_node_fun); }\n    }\n  }\n}\n\n/* NOTE:\n * Stack overflows when releasing a very deep computation graph without\n * a custom deleter.\n *\n * For example, here is a very deep computation graph:\n * Tensor -> FunctionNode -> Tensor -> FunctionNode -> ... 
-> Tensor -> FunctionNode\n * When releasing the first Tensor, it will trigger the recursive deletion and stack overflow.\n *\n * So we must set a custom deleter and release them iteratively.\n */\nvoid FunctionNodeDeleter(FunctionNode* node) {\n  std::stack<std::shared_ptr<FunctionNode>> stack;\n  node->ReleaseData();\n  GatherFunctionNodes(node, stack);\n  delete node;\n\n  while (!stack.empty()) {\n    auto now_node = std::move(stack.top());\n    stack.pop();\n    now_node->ReleaseData();\n    GatherFunctionNodes(now_node.get(), stack);\n  }\n}\n\nbool IsReadyToRun(const std::vector<std::shared_ptr<AutogradMeta>>& out_meta_datas) {\n  return std::any_of(out_meta_datas.begin(), out_meta_datas.end(),\n                     [](const std::shared_ptr<AutogradMeta>& meta_data) {\n                       return !meta_data->current_grad()->Empty();\n                     });\n}\n\nMaybe<void> CopyOrAccGrad(AutogradMeta* autograd_meta, bool autograd_mode) {\n  autograd::AutoGradMode mode(autograd_mode);\n  auto current_grad = JUST(autograd_meta->current_grad_value());\n  if (!current_grad) { return Maybe<void>::Ok(); }\n  if (autograd_meta->acc_grad()) {\n    JUST(functional::Add(autograd_meta->acc_grad(), current_grad, /*alpha=*/1.0,\n                         /*inplace=*/true));\n  } else {\n    // NOTE: acc_grad cannot share data with current_grad, because acc_grad is accumulated\n    // with an inplace operation, which may change current_grad and produce a wrong result.\n    // See more details in https://github.com/Oneflow-Inc/oneflow/issues/8248\n    if (!LazyMode::is_enabled()) { current_grad = JUST(functional::Identity(current_grad)); }\n    JUST(autograd_meta->set_acc_grad(current_grad));\n  }\n  for (const auto& hook : autograd_meta->post_grad_accumulation_hooks()) {\n    auto new_grad = hook(autograd_meta->acc_grad());\n    if (new_grad) { JUST(autograd_meta->set_acc_grad(new_grad)); }\n  }\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RawTouchGlobalTensor(const std::shared_ptr<one::Tensor>& tensor) {\n  // Do nothing.\n  return Maybe<void>::Ok();\n}\n\nstatic constexpr auto* TouchGlobalTensor = DECORATE(&RawTouchGlobalTensor, CheckGlobalTensorMeta);\n\nMaybe<void> CheckGlobalTensorsMeta(const TensorTuple& tensor_tuple) {\n  for (const auto& tensor : tensor_tuple) {\n    if (tensor->is_global() && tensor->is_eager()) { JUST(TouchGlobalTensor(tensor)); }\n  }\n  return Maybe<void>::Ok();\n}\n\nstd::string GetDebugGraphFileName(const std::string& mode, const std::string& suffix) {\n  return fmt::format(\"autograd_{}_rank{}_{}_graph.dot\", mode, GlobalProcessCtx::Rank(), suffix);\n}\n\n}  // namespace\n\nMaybe<void> AutogradEngine::RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs,\n                                                                 const TensorTuple& out_grads,\n                                                                 bool retain_graph,\n                                                                 bool create_graph) {\n  JUST(CheckGlobalTensorsMeta(outputs));\n  JUST(CheckGlobalTensorsMeta(out_grads));\n  DisableCheckGlobalTensorMetaScope disable_meta_check;\n  return RunBackwardAndSaveGrads4LeafTensor(outputs, out_grads, retain_graph, create_graph);\n}\n\nMaybe<TensorTuple> AutogradEngine::RunBackwardAndReturnInputsTensorGradIf(\n    const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,\n    bool retain_graph, bool create_graph, bool allow_unused) {\n  JUST(CheckGlobalTensorsMeta(outputs));\n  
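// NOTE: the CheckGlobalTensorMeta decorator above appears to verify the meta consistency of\n  // eager global tensors before autograd runs; these Check calls only touch metas, no data.\n  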
JUST(CheckGlobalTensorsMeta(inputs));\n  JUST(CheckGlobalTensorsMeta(out_grads));\n  DisableCheckGlobalTensorMetaScope disable_meta_check;\n  return RunBackwardAndReturnInputsTensorGrad(outputs, inputs, out_grads, retain_graph,\n                                              create_graph, allow_unused);\n}\n\nMaybe<void> FunctionNode::AccGrad4RetainGradTensor(bool create_graph) {\n  for (const std::shared_ptr<AutogradMeta>& out : output_meta_data_) {\n    if (out->retain_grad()) { JUST(CopyOrAccGrad(out.get(), create_graph)); }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FunctionNode::AccGrad4LeafTensor(bool create_graph) {\n  for (size_t i = 0; i < output_meta_data_.size(); i++) {\n    auto& out = output_meta_data_[i];\n\n    if (out->is_leaf() && out->requires_grad()) {\n      JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/create_graph));\n\n      // conditionally box acc_grad back to the placement and SBP recorded for this output\n      const auto& acc_grad = out->acc_grad();\n      if (!LazyMode::is_enabled() && GlobalGradSyncMode::is_enabled() && acc_grad->is_global()\n          && acc_grad->is_eager()) {\n        auto& tensor_info = output_tensor_infos_[i];\n        const auto& placement = JUST(tensor_info.placement());\n        const auto& nd_sbp = JUST(tensor_info.sbp());\n        JUST(out->set_acc_grad(\n            JUST(functional::ToGlobal(acc_grad, placement, *JUST(GetSbpList(nd_sbp)),\n                                      GetNoneSbpList(), /* check_meta */ false, /*copy=*/false))));\n      }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nvoid FunctionNode::ReleaseOutTensorArgs() {\n  for (const std::shared_ptr<AutogradMeta>& meta_data : output_meta_data_) {\n    meta_data->current_grad()->Release();\n  }\n}\n\nMaybe<bool> FunctionNode::Apply(bool create_graph) {\n  CHECK_NOTNULL_OR_RETURN(backward_fn_)\n      << \"This FunctionNode with name `\" << name() << \"` has been released.\\n\"\n      << \"Maybe you are trying to backward through the node a second time. 
Specify retain_graph=True when \"\n         \"calling .backward() or autograd.grad() the first time.\";\n  if (!IsReadyToRun(output_meta_data_)) { return false; }\n  TensorTuple input_grads(input_meta_data_.size());\n  TensorTuple output_grads(output_meta_data_.size());\n  for (int i = 0; i < output_meta_data_.size(); ++i) {\n    if (output_meta_data_[i]->current_grad()->Empty()) {\n      // Only initialize out_grads for those requires_grad outputs\n      if (output_meta_data_[i]->requires_grad()) {\n        output_grads[i] = JUST(output_tensor_infos_[i].zeros());\n      }\n    } else {\n      JUST(oneflow::VectorAt(output_grads, i)) =\n          JUST(JUST(oneflow::VectorAt(output_meta_data_, i))->current_grad_value());\n    }\n  }\n  JUST(backward_fn_->body(output_grads, &input_grads, create_graph));\n  for (const auto& hook : hooks_) {\n    auto new_input_grads = hook(input_grads, output_grads);\n    if (new_input_grads.has_value()) {\n      auto new_input_grads_value = *JUST(new_input_grads);\n      CHECK_EQ_OR_RETURN(new_input_grads_value.size(), input_grads.size())\n          << \"The number of input grads returned by hook is not correct, expected \"\n          << input_grads.size() << \", but got \" << new_input_grads_value.size() << \".\";\n      for (int i = 0; i < input_grads.size(); ++i) { input_grads[i] = new_input_grads_value[i]; }\n    }\n  }\n  for (int i = 0; i < input_meta_data_.size(); ++i) {\n    if (JUST(VectorAt(input_grads, i))) {\n      CHECK_NOTNULL_OR_RETURN(input_meta_data_[i])\n          << name_\n          << \" calculates grad for a tensor whose requires_grad is False. Please submit an issue in \"\n             \"`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as \"\n             \"possible.\";\n      JUST(input_meta_data_[i]->current_grad()->PushPartialTensor(JUST(VectorAt(input_grads, i))));\n    } else {\n      CHECK_OR_RETURN(!input_meta_data_[i])\n          << name() << \"'s input[\" << i\n          << \"] needs to calculate grad but got nullptr. 
Please submit an issue in \"\n             \"`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as \"\n             \"possible.\";\n    }\n  }\n  return true;\n}\n\nvoid GraphFunctionNode::ReleaseData() {\n  if (backward_fn_ && backward_fn_->status()) { backward_fn_.reset(); }\n}\n\n/*static*/ std::shared_ptr<GraphFunctionNode> GraphFunctionNode::New(\n    const std::string& name, const std::shared_ptr<BackwardFunction>& backward_fn,\n    const TensorTuple& inputs, const TensorTuple& outputs) {\n  auto node = std::shared_ptr<GraphFunctionNode>(\n      new GraphFunctionNode(name, backward_fn, inputs, outputs), FunctionNodeDeleter);\n  return node;\n}\n\nGraphFunctionNode::GraphFunctionNode(const std::string& name,\n                                     const std::shared_ptr<BackwardFunction>& backward_fn,\n                                     const TensorTuple& inputs, const TensorTuple& outputs)\n    : FunctionNode(name, backward_fn) {\n  input_meta_data_.resize(inputs.size());\n  next_functions_.reserve(inputs.size());\n  for (int i = 0; i < inputs.size(); ++i) {\n    if (inputs.at(i)->requires_grad()) {\n      input_meta_data_.at(i) = inputs.at(i)->mut_autograd_meta();\n      next_functions_.emplace_back(inputs.at(i)->mut_grad_fn_node(), 0);\n    }\n  }\n\n  output_meta_data_.resize(outputs.size());\n  output_tensor_infos_.reserve(outputs.size());\n  for (int i = 0; i < outputs.size(); ++i) {\n    const auto& autograd_meta =\n        NewAutogradMeta(outputs.at(i)->requires_grad(), outputs.at(i)->is_leaf());\n    outputs.at(i)->set_autograd_meta(autograd_meta);\n    output_meta_data_.at(i) = outputs.at(i)->mut_autograd_meta();\n    output_tensor_infos_.emplace_back(*outputs.at(i));\n  }\n}\n\nGraphTask::GraphTask(const TensorTuple& outputs, bool retain_graph, bool create_graph)\n    : retain_graph_(retain_graph), create_graph_(create_graph) {\n  roots_.reserve(outputs.size());\n  for (const auto& out_tensor : outputs) {\n    FunctionNode* node = out_tensor->mut_grad_fn_node().get();\n    roots_.emplace_back(node);\n  }\n}\n\nMaybe<void> GraphTask::WriteGraphToDotFile(const std::string& file_name) const {\n  auto ExecInfoToDotString = [](const ExecInfo& exec_info) -> std::string {\n    std::stringstream ss;\n    ss << \"ExecInfo{\\\\l\";\n    ss << \"\\tdependencies: \" << exec_info.dependencies << \"\\\\l\";\n    ss << \"\\tneed_execute: \" << exec_info.need_execute << \"\\\\l\";\n    if (exec_info.capture_indices) {\n      ss << \"\\tcapture_indices: [\";\n      for (const auto& out_idx_and_capture_idx : *exec_info.capture_indices) {\n        ss << out_idx_and_capture_idx.second << \", \";\n      }\n      ss << \"]\\\\l\";\n    }\n    ss << \"}\\\\l\";\n    return ss.str();\n  };\n\n  auto log_stream = TeePersistentLogStream::Create(file_name);\n  std::vector<std::string> lines;\n  lines.emplace_back(\"digraph AutogradTaskGraph {\");\n  lines.emplace_back(\"\\tmargin=\\\"1.5\\\";\");\n  lines.emplace_back(\"\\tnode [shape=box];\");\n  for (auto iter = grad_fn2exec_info_.begin(); iter != grad_fn2exec_info_.end(); ++iter) {\n    const FunctionNode* node = iter->first;\n    const ExecInfo& exec_info = iter->second;\n    // write label attribute\n    std::string node_color = \"black\";\n    if (exec_info.dependencies == 0 && exec_info.need_execute) {  // start node\n      node_color = \"red\";\n    } else if (exec_info.need_execute && exec_info.capture_indices) {  // end node\n      node_color = \"green\";\n    }\n    
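// Each FunctionNode becomes one DOT node keyed by its address (red: ready-to-run root,\n    // green: a node whose grads are captured), with one edge per next_functions() entry.\n    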
lines.emplace_back(fmt::format(\n        \"\\t\\\"{}\\\" [label=\\\"{}\\\\l{}\\\\l{}\\\", color={}];\", static_cast<const void*>(node),\n        node->name(), static_cast<const void*>(node), ExecInfoToDotString(exec_info), node_color));\n    // write edge\n    for (const auto& next_fn : node->next_functions()) {\n      lines.emplace_back(fmt::format(\"\\t\\\"{}\\\" -> \\\"{}\\\";\", static_cast<const void*>(node),\n                                     static_cast<const void*>(std::get<0>(next_fn).get())));\n    }\n  }\n  lines.emplace_back(\"}\");\n  log_stream << fmt::format(\"{}\", fmt::join(lines, \"\\n\"));\n  log_stream->Flush();\n  return Maybe<void>::Ok();\n}\n\n// Computes the number of dependencies for each FunctionNode\nMaybe<void> GraphTask::ComputeDependencies() {\n  HashSet<FunctionNode*> seen;\n  std::stack<FunctionNode*> stack;\n  for (FunctionNode* node : roots_) {\n    stack.push(node);\n    grad_fn2exec_info_[node].need_execute = true;\n  }\n\n  while (!stack.empty()) {\n    FunctionNode* node = stack.top();\n    stack.pop();\n    if (/*bool has_seen=*/!seen.insert(node).second) { continue; }\n    for (const auto& next_grad_fn : node->next_functions()) {\n      FunctionNode* next_node = std::get<0>(next_grad_fn).get();\n      ExecInfo& exec_info = grad_fn2exec_info_[next_node];\n      exec_info.dependencies += 1;\n      exec_info.need_execute = true;\n      if (seen.find(next_node) == seen.end()) { stack.push(next_node); }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n// Computes the number of dependencies for each FunctionNode and prunes unused FunctionNodes\n// according to the input tensors\nMaybe<void> GraphTask::ComputeDependenciesAndPruneNode(const TensorTuple& inputs,\n                                                       bool allow_unused) {\n  struct NodeFrame {\n    explicit NodeFrame(FunctionNode* node) : node_(node), next_function_idx_(0) {}\n    FunctionNode* node_;\n    size_t next_function_idx_;\n\n    FunctionNode* GetNextFunction() {\n      if (next_function_idx_ < node_->next_functions().size()) {\n        next_function_idx_ += 1;\n        return std::get<0>(node_->next_functions().at(next_function_idx_ - 1)).get();\n      } else {\n        return nullptr;\n      }\n    }\n  };\n\n  // initialize all variables to capture grads for the input tensors\n  captured_grads_ = std::make_shared<TensorTuple>(inputs.size());\n  for (int idx = 0; idx < inputs.size(); idx++) {\n    const auto& input = inputs[idx];\n    if (allow_unused && !input->mut_grad_fn_node().get()) { continue; }\n    CHECK_NOTNULL_OR_RETURN(input->mut_grad_fn_node().get())\n        << Error::RuntimeError()\n        << \"One of the differentiated Tensors appears to not have been used in the graph. 
Set \"\n           \"allow_unused=True if this is the desired behavior.\";\n    ExecInfo& exec_info = grad_fn2exec_info_[input->mut_grad_fn_node().get()];\n    exec_info.need_execute = true;\n    if (!exec_info.capture_indices) {\n      exec_info.capture_indices = std::make_unique<std::vector<std::pair<size_t, size_t>>>();\n    }\n    exec_info.capture_indices->emplace_back(std::make_pair(input->get_grad_fn_output_index(), idx));\n  }\n\n  HashSet<FunctionNode*> seen;\n  std::stack<NodeFrame> stack;\n\n  // Note: DFS to determine whether each FunctionNode should be executed or not.\n  for (const auto& root : roots_) { stack.push(NodeFrame(root)); }\n  while (!stack.empty()) {\n    NodeFrame& frame = stack.top();\n    if (/*bool has_seen=*/seen.find(frame.node_) != seen.end()) {\n      stack.pop();\n      continue;\n    }\n    if (FunctionNode* node = frame.GetNextFunction()) {\n      grad_fn2exec_info_[node].dependencies += 1;\n      if (seen.find(node) == seen.end()) {\n        stack.push(NodeFrame(node));\n        continue;  // recurse\n      }\n    } else {\n      for (auto& fn : frame.node_->next_functions()) {\n        grad_fn2exec_info_[frame.node_].need_execute |=\n            grad_fn2exec_info_[std::get<0>(fn).get()].need_execute;\n      }\n      seen.insert(frame.node_);\n      stack.pop();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GraphTask::Apply(bool save_grad_for_leaf) {\n  std::queue<FunctionNode*> queue;\n  for (FunctionNode* node : roots_) {\n    if (grad_fn2exec_info_[node].dependencies == 0) { queue.push(node); }\n  }\n\n  while (!queue.empty()) {\n    FunctionNode* node = queue.front();\n    queue.pop();\n    auto& exec_info = grad_fn2exec_info_[node];\n\n    if (!exec_info.need_execute) {\n      node->ReleaseOutTensorArgs();\n      continue;\n    }\n    BackwardPassScopeGuard backward_guard(node->scope());\n    if (/*bool not_ready_to_apply=*/!(JUST(node->Apply(create_graph_)))) { continue; }\n    if (exec_info.capture_indices) {\n      CHECK_NOTNULL_OR_RETURN(captured_grads_.get()) << \"captured grads in GraphTask is nullptr\";\n      for (const auto& out_idx_and_capture_idx : *exec_info.capture_indices) {\n        JUST(VectorAt(*captured_grads_, out_idx_and_capture_idx.second)) =\n            JUST(JUST(VectorAt(node->output_meta_data_, out_idx_and_capture_idx.first))\n                     ->current_grad_value());\n      }\n    }\n    if (save_grad_for_leaf) { JUST(node->AccGrad4LeafTensor(create_graph_)); }\n    JUST(node->AccGrad4RetainGradTensor(create_graph_));\n    node->ReleaseOutTensorArgs();\n    if (!retain_graph_) { node->ReleaseData(); }\n\n    for (const auto& next_grad_fn : node->next_functions()) {\n      FunctionNode* next_node = std::get<0>(next_grad_fn).get();\n      int32_t& dependencies = grad_fn2exec_info_[next_node].dependencies;\n      dependencies -= 1;\n      if (dependencies == 0) { queue.push(next_node); }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GraphAutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,\n                                                                    const TensorTuple& out_grads,\n                                                                    bool retain_graph,\n                                                                    bool create_graph) {\n  for (int i = 0; i < outputs.size(); ++i) {\n    JUST(JUST(outputs.at(i)->current_grad())->PushPartialTensor(out_grads.at(i)));\n  }\n  GraphTask graph_task(outputs, retain_graph, create_graph);\n  
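// The backward pass runs in two phases: ComputeDependencies() counts, for each FunctionNode,\n  // how many nodes feed gradients into it; Apply() then executes nodes whose pending count\n  // drops to zero, in Kahn-style topological order.\n  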
JUST(graph_task.ComputeDependencies());\n  if (IsInDebugMode()) {\n    JUST(\n        graph_task.WriteGraphToDotFile(GetDebugGraphFileName(\"backward\", std::to_string(clock()))));\n  }\n  JUST(graph_task.Apply(/*save_grad_for_leaf=*/true));\n  return Maybe<void>::Ok();\n}\n\nMaybe<TensorTuple> GraphAutogradEngine::RunBackwardAndReturnInputsTensorGrad(\n    const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,\n    bool retain_graph, bool create_graph, bool allow_unused) {\n  for (int i = 0; i < outputs.size(); ++i) {\n    JUST(JUST(outputs.at(i)->current_grad())->PushPartialTensor(out_grads.at(i)));\n  }\n\n  GraphTask graph_task(outputs, retain_graph, create_graph);\n  JUST(graph_task.ComputeDependenciesAndPruneNode(inputs, allow_unused));\n  if (IsInDebugMode()) {\n    JUST(graph_task.WriteGraphToDotFile(GetDebugGraphFileName(\"grad\", std::to_string(clock()))));\n  }\n  JUST(graph_task.Apply(/*save_grad_for_leaf=*/false));\n  return graph_task.GetCapturedGrads();\n}\n\nMaybe<FunctionNode> GraphAutogradEngine::AddNode(\n    const std::string& name, const std::shared_ptr<BackwardFunction>& backward_fn,\n    const TensorTuple& inputs, TensorTuple* outputs) {\n  OF_PROFILER_RANGE_PUSH(\"AddAccumulateFunctionNode\");\n  // First, add an accumulate FunctionNode for each leaf input tensor that requires grad\n  for (const std::shared_ptr<Tensor>& in_tensor : inputs) {\n    if (in_tensor->is_leaf() && in_tensor->requires_grad()) {\n      if (!in_tensor->grad_fn_node()) { JUST(AddAccumulateFunctionNode(in_tensor)); }\n    }\n  }\n\n  OF_PROFILER_RANGE_POP();\n  OF_PROFILER_RANGE_PUSH(\"set_grad_fn_node\");\n  std::shared_ptr<FunctionNode> func_node =\n      GraphFunctionNode::New(name, backward_fn, inputs, *outputs);\n  for (int i = 0; i < outputs->size(); ++i) {\n    const std::shared_ptr<Tensor>& out_tensor = JUST(VectorAt(*outputs, i));\n    out_tensor->set_grad_fn_node(func_node);\n    out_tensor->set_grad_fn_output_index(i);\n  }\n  if (LazyMode::is_enabled()) { func_node->set_scope(JUST(GetCurrentScope())); }\n  OF_PROFILER_RANGE_POP();\n  return func_node;\n}\n\nAutogradEngine* GetThreadLocalAutogradEngine() {\n  thread_local static GraphAutogradEngine autograd_engine;\n  return &autograd_engine;\n}\n\nMaybe<void> AddAccumulateFunctionNode(const std::shared_ptr<Tensor>& tensor) {\n  auto backward_fn = std::make_shared<BackwardFunction>();\n  backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,\n                          bool create_graph) -> Maybe<void> { return Maybe<void>::Ok(); };\n  backward_fn->status = []() { return false; };\n  tensor->set_grad_fn_node(GraphFunctionNode::New(\"accumulategrad\", backward_fn,\n                                                  /*inputs=*/TensorTuple{},\n                                                  /*outputs=*/TensorTuple{tensor}));\n  tensor->mut_grad_fn_node()->set_variable(tensor);\n  tensor->set_grad_fn_output_index(0);\n  if (LazyMode::is_enabled()) {\n    tensor->mut_grad_fn_node()->set_scope(JUST(GetTensorScope(tensor)));\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/autograd_engine.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_\n#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_\n\n#include <functional>\n#include <list>\n#include <memory>\n#include <vector>\n\n#include \"oneflow/core/autograd/autograd_meta.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/scope_util.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n\nnamespace oneflow {\n\nnamespace one {\n\nclass Tensor;\nclass TensorTuple;\n\nusing CaptureStatus = bool;\n\nstruct BackwardFunction {\n  std::function<Maybe<void>(const TensorTuple&, TensorTuple*, bool)> body;\n  std::function<CaptureStatus()> status;\n};\n\n// Calculates one backward op\nclass FunctionNode {\n public:\n  virtual ~FunctionNode() = default;\n\n  Maybe<bool> Apply(bool create_graph);\n  Maybe<void> AccGrad4LeafTensor(bool create_graph);\n  Maybe<void> AccGrad4RetainGradTensor(bool create_graph);\n  void ReleaseOutTensorArgs();\n  // Releases the eventual c++ std::function for backward if retain_graph=False to avoid calling\n  // `Apply` in second time\n  virtual void ReleaseData() = 0;\n\n  const std::vector<std::tuple<std::shared_ptr<FunctionNode>, int>>& next_functions() const {\n    return next_functions_;\n  }\n  const std::string& name() const { return name_; }\n\n  const std::shared_ptr<Scope>& scope() const { return scope_; }\n  void set_scope(const std::shared_ptr<Scope>& scope) { scope_ = scope; }\n  void set_variable(const std::weak_ptr<Tensor>& variable) { variable_ = variable; }\n  const Maybe<Tensor> Variable() const {\n    if (!variable_.lock()) { THROW(RuntimeError) << \"The tensor has already been deleted!\"; }\n    return variable_.lock();\n  }\n\n  using Hook = std::function<Optional<std::vector<std::shared_ptr<Tensor>>>(const TensorTuple&,\n                                                                            const TensorTuple&)>;\n  void add_post_hook(const Hook& hook) { hooks_.push_back(hook); }\n\n protected:\n  friend class GraphTask;\n  explicit FunctionNode(const std::string& name,\n                        const std::shared_ptr<BackwardFunction>& backward_fn)\n      : name_(name), backward_fn_(backward_fn), scope_(nullptr) {}\n\n  const std::string name_;\n  std::vector<std::tuple<std::shared_ptr<FunctionNode>, int>> next_functions_;\n\n  std::vector<std::shared_ptr<AutogradMeta>> input_meta_data_;\n  std::vector<std::shared_ptr<AutogradMeta>> output_meta_data_;\n  std::vector<TensorInfo> output_tensor_infos_;\n\n  // Actual backward function builds in `AutogradInterpreter` to calculate one backward op\n  std::shared_ptr<BackwardFunction> backward_fn_;\n  std::weak_ptr<Tensor> variable_;\n\n  // The execution scope\n  std::shared_ptr<Scope> scope_;\n\n  std::vector<Hook> hooks_;\n};\n\nclass AutogradEngine {\n public:\n  virtual ~AutogradEngine() = default;\n\n  Maybe<void> RunBackwardAndSaveGrads4LeafTensorIf(const 
TensorTuple& outputs,\n                                                   const TensorTuple& out_grads, bool retain_graph,\n                                                   bool create_graph);\n  Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGradIf(const TensorTuple& outputs,\n                                                            const TensorTuple& inputs,\n                                                            const TensorTuple& out_grads,\n                                                            bool retain_graph, bool create_graph,\n                                                            bool allow_unused);\n  virtual void ClearEngine() = 0;\n  // Builds a FunctionNode, binds it to all `outputs` tensors, and saves it in the AutogradEngine\n  virtual Maybe<FunctionNode> AddNode(const std::string& name,\n                                      const std::shared_ptr<BackwardFunction>& backward_fn,\n                                      const TensorTuple& inputs, TensorTuple* outputs) = 0;\n\n protected:\n  AutogradEngine() = default;\n\n private:\n  virtual Maybe<void> RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,\n                                                         const TensorTuple& out_grads,\n                                                         bool retain_graph, bool create_graph) = 0;\n  virtual Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGrad(\n      const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,\n      bool retain_graph, bool create_graph, bool allow_unused) = 0;\n};\n\n// Graph Autograd Node and Engine\nclass GraphFunctionNode final : public FunctionNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(GraphFunctionNode);\n  static std::shared_ptr<GraphFunctionNode> New(\n      const std::string& name, const std::shared_ptr<BackwardFunction>& backward_fn,\n      const TensorTuple& inputs, const TensorTuple& outputs);\n\n  GraphFunctionNode() = delete;\n  ~GraphFunctionNode() override = default;\n\n  void ReleaseData() override;\n\n private:\n  GraphFunctionNode(const std::string& name, const std::shared_ptr<BackwardFunction>& backward_fn,\n                    const TensorTuple& inputs, const TensorTuple& outputs);\n};\n\nclass GraphTask final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(GraphTask);\n  GraphTask() = delete;\n  GraphTask(const TensorTuple& outputs, bool retain_graph, bool create_graph);\n\n  Maybe<void> ComputeDependencies();\n  Maybe<void> ComputeDependenciesAndPruneNode(const TensorTuple& inputs, bool allow_unused);\n  Maybe<void> Apply(bool save_grad_for_leaf);\n  std::shared_ptr<TensorTuple> GetCapturedGrads() const { return captured_grads_; }\n  Maybe<void> WriteGraphToDotFile(const std::string& file_name) const;\n\n private:\n  class ExecInfo {\n   public:\n    ExecInfo() = default;\n\n    int32_t dependencies = 0;\n    bool need_execute = false;\n    // Used by the autograd.grad interface to record which tensor grads will be captured.\n    // The pair means: <output index of this Node, the index in captured_grads_ to save into>\n    std::unique_ptr<std::vector<std::pair<size_t, size_t>>> capture_indices;\n  };\n\n  bool retain_graph_;\n  bool create_graph_;\n  std::vector<FunctionNode*> roots_;\n  HashMap<FunctionNode*, ExecInfo> grad_fn2exec_info_;\n  std::shared_ptr<TensorTuple> captured_grads_;\n};\n\nclass GraphAutogradEngine final : public AutogradEngine {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(GraphAutogradEngine);\n  GraphAutogradEngine() = default;\n  ~GraphAutogradEngine() 
override = default;\n\n  void ClearEngine() override {}\n  Maybe<FunctionNode> AddNode(const std::string& name,\n                              const std::shared_ptr<BackwardFunction>& backward_fn,\n                              const TensorTuple& inputs, TensorTuple* outputs) override;\n\n private:\n  Maybe<void> RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,\n                                                 const TensorTuple& out_grads, bool retain_graph,\n                                                 bool create_graph) override;\n  Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGrad(const TensorTuple& outputs,\n                                                          const TensorTuple& inputs,\n                                                          const TensorTuple& out_grads,\n                                                          bool retain_graph, bool create_graph,\n                                                          bool allow_unused) override;\n};\n\nAutogradEngine* GetThreadLocalAutogradEngine();\n\nMaybe<void> AddAccumulateFunctionNode(const std::shared_ptr<Tensor>& tensor);\n\n}  // namespace one\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_\n"
  },
  {
    "path": "oneflow/core/autograd/autograd_function.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/autograd/autograd_function.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\n/*static*/ Maybe<TensorTuple> AutogradFunctionBase::Apply(const std::string& name,\n                                                          const FType& forward_fn,\n                                                          const FType& backward_fn,\n                                                          const TensorTuple& inputs) {\n  std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>();\n  const auto& op = JUST(FunctionOpExpr::New(name, forward_fn, backward_fn));\n  JUST(OpInterpUtil::Dispatch(*op, inputs, outputs.get(), {}));\n  const HashSet<Tensor*>& non_differentiable_tensors = op->state()->NonDifferentiableTensors();\n  for (const auto& tensor : *outputs) {\n    if (non_differentiable_tensors.find(tensor.get()) != non_differentiable_tensors.end()) {\n      JUST(tensor->set_requires_grad(false));\n    }\n  }\n  return outputs;\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/autograd_function.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_FUNCTION_H_\n#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_FUNCTION_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nclass TensorTuple;\nclass FunctionAutoGradCaptureState;\nclass FunctionOpExpr;\n\nclass AutogradFunctionBase {\n public:\n  using FType = std::function<std::shared_ptr<TensorTuple>(\n      const std::shared_ptr<FunctionAutoGradCaptureState>&, const TensorTuple&)>;\n  AutogradFunctionBase() = default;\n  virtual ~AutogradFunctionBase() = default;\n\n  static Maybe<TensorTuple> Apply(const std::string& name, const FType& forward_fn,\n                                  const FType& backward_fn, const TensorTuple& inputs);\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_FUNCTION_H_\n"
  },
  {
    "path": "oneflow/core/autograd/autograd_meta.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/tensor_arg.h\"\n#include \"oneflow/core/autograd/autograd_meta.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\n\nnamespace one {\n\nTensorInfo::TensorInfo(const Tensor& tensor) : shape_(tensor.shape()), dtype_(tensor.dtype()) {\n  if (tensor.is_global()) {\n    parallel_desc_ = CHECK_JUST(tensor.parallel_desc());\n    nd_sbp_ = CHECK_JUST(tensor.nd_sbp());\n  } else {\n    device_ = CHECK_JUST(tensor.device());\n  }\n}\n\nMaybe<const std::vector<Symbol<SbpParallel>>&> GetSbpTuple(Symbol<NdSbp> nd_sbp) {\n  static thread_local HashMap<Symbol<NdSbp>, std::vector<Symbol<SbpParallel>>> map;\n  auto iter = map.find(nd_sbp);\n  if (iter == map.end()) {\n    std::vector<Symbol<SbpParallel>> sbp_tuple;\n    sbp_tuple.reserve(nd_sbp->sbp_parallel().size());\n    for (const auto& sbp_parallel : nd_sbp->sbp_parallel()) {\n      sbp_tuple.push_back(SymbolOf(sbp_parallel));\n    }\n    iter = map.emplace(nd_sbp, sbp_tuple).first;\n  }\n  return iter->second;\n}\n\nMaybe<Tensor> TensorInfo::zeros() const {\n  if (device_.has_value()) {\n    const auto& device = JUST(device_);\n    return functional::Constant(*shape_.get(), 0, dtype_, device);\n  } else {\n    const auto& parallel_desc = JUST(parallel_desc_);\n    const auto& nd_sbp = JUST(nd_sbp_);\n    const auto& sbp_tuple = JUST(GetSbpTuple(nd_sbp));\n    return functional::GlobalConstant(*shape_.get(), 0, dtype_, parallel_desc, sbp_tuple);\n  }\n}\n\nAutogradMeta::AutogradMeta(bool requires_grad, bool is_leaf)\n    : is_leaf_(is_leaf),\n      requires_grad_(requires_grad),\n      retain_grad_(false),\n      current_grad_(new TensorArg) {}\n\nMaybe<void> AutogradMeta::set_acc_grad(const std::shared_ptr<Tensor>& grad) {\n  // NOTE(daquexian): update here if we support remat on global tensors\n  if (grad && acc_grad_ != nullptr && acc_grad_->is_eager() && acc_grad_->is_local()) {\n    // set old acc_grad evictable\n    if (auto rematable_storage = std::dynamic_pointer_cast<vm::RematableTensorStorage>(\n            JUST(acc_grad_->eager_blob_object())->tensor_storage())) {\n      rematable_storage->set_eviction_disabled(false);\n    }\n  }\n  if (const auto& static_zeros_tensor = std::dynamic_pointer_cast<StaticZerosTensor>(grad)) {\n    acc_grad_ = JUST(static_zeros_tensor->AsLocalTensor());\n  } else {\n    acc_grad_ = grad;\n  }\n  if (acc_grad_ != nullptr && acc_grad_->is_eager() && acc_grad_->is_local()) {\n    // set new acc_grad non-evictable\n    if (auto rematable_storage = std::dynamic_pointer_cast<vm::RematableTensorStorage>(\n            JUST(acc_grad_->eager_blob_object())->tensor_storage())) {\n      rematable_storage->set_eviction_disabled(true);\n    }\n  }\n  return 
Maybe<void>::Ok();\n}\n\nMaybe<Tensor> AutogradMeta::current_grad_value() const {\n  std::shared_ptr<Tensor> res = JUST(current_grad_->GetAccTensor());\n  for (const auto& hook : hooks_) {\n    const auto& new_tensor = hook(res);\n    if (new_tensor) { res = new_tensor; }\n  }\n  return res;\n}\n\n}  // namespace one\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/autograd_meta.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_\n#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_\n\n#include <memory>\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/optional.h\"\n\nnamespace oneflow {\n\nclass Shape;\n\nclass Device;\nclass ParallelDesc;\nclass NdSbp;\n\nnamespace one {\n\nclass Tensor;\nclass TensorArg;\nclass LocalTensor;\n\nclass AutogradMeta final {\n public:\n  AutogradMeta() = delete;\n  AutogradMeta(bool requires_grad, bool is_leaf);\n\n  // Getters\n  const std::shared_ptr<Tensor>& acc_grad() const { return acc_grad_; }\n  const std::shared_ptr<TensorArg>& current_grad() const { return current_grad_; }\n  // get current grad processed by hooks\n  Maybe<Tensor> current_grad_value() const;\n  bool requires_grad() const { return requires_grad_; }\n  bool is_leaf() const { return is_leaf_; }\n  bool retain_grad() const { return retain_grad_; }\n  using Hook = std::function<std::shared_ptr<Tensor>(const std::shared_ptr<const Tensor>&)>;\n  const std::vector<Hook>& hooks() const { return hooks_; }\n  const std::vector<Hook>& post_grad_accumulation_hooks() const {\n    return post_grad_accumulation_hooks_;\n  }\n\n  // Setters\n  Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad);\n  std::shared_ptr<Tensor> mut_acc_grad() { return acc_grad_; }\n  void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }\n  void set_retain_grad(bool retain_grad) { retain_grad_ = retain_grad; }\n  void set_is_leaf(bool is_leaf) { is_leaf_ = is_leaf; }\n  void add_hook(const Hook& hook) { hooks_.emplace_back(hook); }\n  void add_post_grad_accumulation_hook(const Hook& hook) {\n    post_grad_accumulation_hooks_.emplace_back(hook);\n  }\n\n private:\n  bool is_leaf_;\n\n  // Only meaningful on leaf Tensors (must be false otherwise)\n  bool requires_grad_;\n\n  // Only meaningful on non_leaf Tensors (must be false otherwise)\n  bool retain_grad_;\n\n  std::shared_ptr<Tensor> acc_grad_;\n  std::shared_ptr<TensorArg> current_grad_;\n  std::vector<Hook> hooks_;\n  std::vector<Hook> post_grad_accumulation_hooks_;\n};\n\ninline std::shared_ptr<AutogradMeta> NewAutogradMeta(bool requires_grad, bool is_leaf) {\n  return std::shared_ptr<AutogradMeta>(new AutogradMeta(requires_grad, is_leaf));\n}\n\nclass TensorInfo final {\n public:\n  TensorInfo() = delete;\n  explicit TensorInfo(const Tensor& tensor);\n\n  Maybe<Tensor> zeros() const;\n  Optional<Symbol<ParallelDesc>> placement() const { return parallel_desc_; }\n  Optional<Symbol<NdSbp>> sbp() const { return nd_sbp_; }\n\n private:\n  std::shared_ptr<const Shape> shape_;\n  Symbol<DType> dtype_;\n  Optional<Symbol<Device>> device_;               // for local tensor\n  Optional<Symbol<ParallelDesc>> parallel_desc_;  // for global tensor\n  
Optional<Symbol<NdSbp>> nd_sbp_;                // for global tensor\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_\n"
  },
  {
    "path": "oneflow/core/autograd/autograd_mode.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/autograd/autograd_mode.h\"\n\nnamespace oneflow {\n\nnamespace autograd {\n\nnamespace {\n\nbool* GetThreadLocalGradMode() {\n  static thread_local bool g_grad_mode = true;\n  return &g_grad_mode;\n}\n\n}  // namespace\n\nbool GradMode::is_enabled() { return *GetThreadLocalGradMode(); }\n\nvoid GradMode::set_enabled(bool enabled) { *GetThreadLocalGradMode() = enabled; }\n\n}  // namespace autograd\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/autograd_mode.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_\n#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_\n\nnamespace oneflow {\nnamespace autograd {\n\nstruct GradMode {\n  static bool is_enabled();\n  static void set_enabled(bool enabled);\n};\n\nclass AutoGradMode {\n public:\n  AutoGradMode(bool enabled) : prev_mode_(GradMode::is_enabled()) {\n    GradMode::set_enabled(enabled);\n  }\n  ~AutoGradMode() { GradMode::set_enabled(prev_mode_); }\n  bool prev_mode() const { return prev_mode_; }\n\n private:\n  bool prev_mode_;\n};\n\nclass NoGradGuard : public AutoGradMode {\n public:\n  NoGradGuard() : AutoGradMode(false){};\n};\n\n}  // namespace autograd\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/activation.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct BaseActivationCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n};\n\nclass BaseActivation : public OpExprGradFunction<BaseActivationCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(BaseActivationCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }\n    return Maybe<void>::Ok();\n  }\n};\n\nclass Silu : public BaseActivation {\n public:\n  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::SiluGrad(out_grads.at(0), x));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nclass Mish : public BaseActivation {\n public:\n  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::MishGrad(out_grads.at(0), x));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nclass Selu : public BaseActivation {\n public:\n  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::SeluGrad(out_grads.at(0), x));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nclass Softsign : public BaseActivation {\n public:\n  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(0) = 
JUST(functional::SoftSignGrad(out_grads.at(0), x));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nclass GeLU : public BaseActivation {\n public:\n  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::GeluGrad(out_grads.at(0), x));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nclass FastGeLU : public BaseActivation {\n public:\n  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::FastGeluGrad(out_grads.at(0), x));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nstruct QuickGeluCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n};\n\nclass QuickGeLU : public OpExprGradFunction<QuickGeluCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(QuickGeluCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    ctx->SaveTensorForBackward(inputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const QuickGeluCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::QuickGeluGrad(out_grads.at(0), x));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nstruct SquareReLUCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n};\n\nclass SquareReLU : public OpExprGradFunction<SquareReLUCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(SquareReLUCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    ctx->SaveTensorForBackward(inputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const SquareReLUCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::SquareReLUGrad(out_grads.at(0), x));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nclass HardSigmoid : public BaseActivation 
{\n public:\n  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::HardSigmoidGrad(out_grads.at(0), x));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nstruct HardShrinkCaptureState : public AutoGradCaptureState {\n  bool requires_grad = true;\n  double lambd = 0.5;\n};\n\nclass HardShrink : public OpExprGradFunction<HardShrinkCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(HardShrinkCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->lambd = JUST(composed_attrs.GetAttr<double>(\"lambd\"));\n    ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 0)));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const HardShrinkCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& y = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));\n      JUST(oneflow::VectorAt(*in_grads, 0)) =\n          JUST(functional::HardShrinkGrad(y, JUST(oneflow::VectorAt(out_grads, 0)), ctx->lambd));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nclass HardSwish : public BaseActivation {\n public:\n  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::HardSwishGrad(out_grads.at(0), x));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\n// ===== Activation with params =====\nstruct ReLUCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n};\n\nclass ReLU : public OpExprGradFunction<ReLUCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(ReLUCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (ctx->requires_grad) { ctx->SaveTensorForBackward(outputs.at(0)); }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ReLUCaptureState* ctx, const TensorTuple& out_grads,\n             
       TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& y = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::ReluGrad(out_grads.at(0), y));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\n// ===== Activation with parms ====\nstruct LeakyReluCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n  float alpha;\n};\n\nclass LeakyRelu : public OpExprGradFunction<LeakyReluCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(LeakyReluCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->alpha = JUST(composed_attrs.GetAttr<float>(\"alpha\"));\n    ctx->SaveTensorForBackward(inputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const LeakyReluCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::LeakyReluGrad(x, out_grads.at(0), ctx->alpha));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nstruct SoftplusCaptureState : public AutoGradCaptureState {\n  bool requires_grad = true;\n  double beta = 1.0;\n  double threshold = 20.0;\n};\n\nclass Softplus : public OpExprGradFunction<SoftplusCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(SoftplusCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->beta = JUST(composed_attrs.GetAttr<double>(\"beta\"));\n    ctx->threshold = JUST(composed_attrs.GetAttr<double>(\"threshold\"));\n    ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0)));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const SoftplusCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));\n      JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::SoftplusGrad(\n          x, JUST(oneflow::VectorAt(out_grads, 0)), ctx->beta, ctx->threshold));\n    }\n    
return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nstruct HardTanhCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n  double min_val;\n  double max_val;\n};\n\nclass HardTanh : public OpExprGradFunction<HardTanhCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(HardTanhCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->min_val = JUST(composed_attrs.GetAttr<double>(\"min_val\"));\n    ctx->max_val = JUST(composed_attrs.GetAttr<double>(\"max_val\"));\n    ctx->SaveTensorForBackward(outputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const HardTanhCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& y = ctx->SavedTensors().at(0);\n      in_grads->at(0) =\n          JUST(functional::HardTanhGrad(y, out_grads.at(0), ctx->min_val, ctx->max_val));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nstruct EluCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n  double alpha;\n};\n\nclass Elu : public OpExprGradFunction<EluCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(EluCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->alpha = JUST(composed_attrs.GetAttr<double>(\"alpha\"));\n    ctx->SaveTensorForBackward(inputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const EluCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::EluGrad(x, out_grads.at(0), ctx->alpha));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nstruct CeluCaptureState : public AutoGradCaptureState {\n  bool requires_grad = true;\n  double alpha = 1.0;\n};\n\nclass Celu : public OpExprGradFunction<CeluCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = 
dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(CeluCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->alpha = JUST(composed_attrs.GetAttr<double>(\"alpha\"));\n    ctx->SaveTensorForBackward(outputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const CeluCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& y = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::CeluGrad(y, out_grads.at(0), ctx->alpha));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nstruct SoftShrinkCaptureState : public AutoGradCaptureState {\n  bool requires_grad = true;\n  double alpha = 0.5;\n};\n\nclass SoftShrink : public OpExprGradFunction<SoftShrinkCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(SoftShrinkCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->alpha = JUST(composed_attrs.GetAttr<double>(\"alpha\"));\n    ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 0)));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const SoftShrinkCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& y = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));\n      JUST(oneflow::VectorAt(*in_grads, 0)) =\n          JUST(functional::SoftShrinkGrad(y, JUST(oneflow::VectorAt(out_grads, 0)), ctx->alpha));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nstruct PReLUCaptureState : public AutoGradCaptureState {\n  bool input_requires_grad;\n  bool alpha_requires_grad;\n};\n\nclass PReLU : public OpExprGradFunction<PReLUCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(PReLUCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);                      // NOLINT(maybe-need-error-msg)\n    ctx->input_requires_grad = 
inputs.at(0)->requires_grad();  // input\n    ctx->alpha_requires_grad = inputs.at(1)->requires_grad();  // alpha\n    ctx->SaveTensorForBackward(inputs.at(0));\n    ctx->SaveTensorForBackward(inputs.at(1));\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const PReLUCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    const auto& dy = out_grads.at(0);\n    const auto& x = ctx->SavedTensors().at(0);\n    const auto& alpha = ctx->SavedTensors().at(1);\n    in_grads->resize(2);\n    if (ctx->input_requires_grad || ctx->alpha_requires_grad) {\n      const auto& grads = JUST(functional::PReluGrad(dy, x, alpha));\n      if (ctx->input_requires_grad) { in_grads->at(0) = grads->at(0); }\n      if (ctx->alpha_requires_grad) { in_grads->at(1) = grads->at(1); }\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  std::shared_ptr<OpExpr> grad_op_;\n};\n\nstruct ThresholdCaptureState : public AutoGradCaptureState {\n  bool requires_grad = true;\n  double threshold = 0.0;\n};\n\nclass Threshold : public OpExprGradFunction<ThresholdCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ThresholdCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->threshold = JUST(composed_attrs.GetAttr<double>(\"threshold_val\"));\n    ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0)));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ThresholdCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));\n      JUST(oneflow::VectorAt(*in_grads, 0)) =\n          JUST(functional::ThresholdGrad(x, JUST(oneflow::VectorAt(out_grads, 0)), ctx->threshold));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nstruct FracCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n};\n\nclass Frac : public OpExprGradFunction<FracCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(FracCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FracCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"frac\", Frac);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"silu\", Silu);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"mish\", Mish);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"selu\", Selu);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"softsign\", Softsign);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"relu\", ReLU);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"gelu\", GeLU);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"hardsigmoid\", HardSigmoid);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"hardshrink\", HardShrink);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"hardswish\", HardSwish);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"leaky_relu\", LeakyRelu);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"hardtanh\", HardTanh);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"elu\", Elu);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"celu\", Celu);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"prelu\", PReLU);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"threshold\", Threshold);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"softplus\", Softplus);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"softshrink\", SoftShrink);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fast_gelu\", FastGeLU);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"quick_gelu\", QuickGeLU);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"square_relu\", SquareReLU);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/adaptive_avg_pool.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct AdaptivePoolCaptureState : public AutoGradCaptureState {\n  std::string data_format;\n  bool requires_grad;\n};\n\nclass AdaptivePoolNdGrad : public OpExprGradFunction<AdaptivePoolCaptureState> {\n public:\n  using OpExprGradFunction<AdaptivePoolCaptureState>::Init;\n\n  Maybe<void> Init(const OpExpr& op, std::string mode, const int& ndims);\n  Maybe<void> Capture(AdaptivePoolCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const AdaptivePoolCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n  std::string mode_;\n  int32_t ndims_;\n};\n\nMaybe<void> AdaptivePoolNdGrad::Init(const OpExpr& op, std::string mode, const int& ndims) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  mode_ = mode;\n  ndims_ = ndims;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AdaptivePoolNdGrad::Capture(AdaptivePoolCaptureState* ctx, const TensorTuple& inputs,\n                                        const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  ctx->data_format = JUST(attrs.GetAttr<std::string>(\"data_format\"));\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ctx->SaveTensorForBackward(inputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AdaptivePoolNdGrad::Apply(const AdaptivePoolCaptureState* ctx,\n                                      const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);\n  in_grads->resize(1);\n  in_grads->at(0) =\n      JUST(functional::AdaptivePoolNdGrad(x, out_grads.at(0), mode_, ndims_, ctx->data_format));\n  return Maybe<void>::Ok();\n}\n\nclass AdaptiveAvgPool1dGrad final : public AdaptivePoolNdGrad {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return AdaptivePoolNdGrad::Init(op, \"avg\", 1); }\n};\n\nclass AdaptiveAvgPool2dGrad final : public AdaptivePoolNdGrad {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return AdaptivePoolNdGrad::Init(op, \"avg\", 2); }\n};\n\nclass AdaptiveAvgPool3dGrad final : public AdaptivePoolNdGrad {\n public:\n  
Maybe<void> Init(const OpExpr& op) override { return AdaptivePoolNdGrad::Init(op, \"avg\", 3); }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"adaptive_avg_pool1d\", AdaptiveAvgPool1dGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"adaptive_avg_pool2d\", AdaptiveAvgPool2dGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"adaptive_avg_pool3d\", AdaptiveAvgPool3dGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/adaptive_max_pool.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct AdaptiveMaxPoolCaptureState : public AutoGradCaptureState {\n  std::string data_format;\n  bool requires_grad = false;\n};\n\nclass AdaptiveMaxPoolNdGrad : public OpExprGradFunction<AdaptiveMaxPoolCaptureState> {\n public:\n  using OpExprGradFunction<AdaptiveMaxPoolCaptureState>::Init;\n\n  Maybe<void> Init(const OpExpr& op, const int& ndims);\n  Maybe<void> Capture(AdaptiveMaxPoolCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const AdaptiveMaxPoolCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  int32_t ndims_ = 0;\n};\n\nMaybe<void> AdaptiveMaxPoolNdGrad::Init(const OpExpr& op, const int& ndims) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  ndims_ = ndims;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AdaptiveMaxPoolNdGrad::Capture(AdaptiveMaxPoolCaptureState* ctx,\n                                           const TensorTuple& inputs, const TensorTuple& outputs,\n                                           const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  ctx->data_format = JUST(attrs.GetAttr<std::string>(\"data_format\"));\n  ctx->SaveTensorForBackward(inputs.at(0));\n  ctx->SaveTensorForBackward(outputs.at(1));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AdaptiveMaxPoolNdGrad::Apply(const AdaptiveMaxPoolCaptureState* ctx,\n                                         const TensorTuple& out_grads,\n                                         TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)\n  const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);\n  const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(1);\n  in_grads->resize(1);\n  in_grads->at(0) =\n      JUST(functional::AdaptiveMaxPoolNdGrad(x, out_grads.at(0), index, ndims_, ctx->data_format));\n  return Maybe<void>::Ok();\n}\n\nclass AdaptiveMaxPool1dGrad final : public AdaptiveMaxPoolNdGrad {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return AdaptiveMaxPoolNdGrad::Init(op, 1); }\n};\n\nclass AdaptiveMaxPool2dGrad final : public AdaptiveMaxPoolNdGrad {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return AdaptiveMaxPoolNdGrad::Init(op, 
2); }\n};\n\nclass AdaptiveMaxPool3dGrad final : public AdaptiveMaxPoolNdGrad {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return AdaptiveMaxPoolNdGrad::Init(op, 3); }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"adaptive_max_pool1d\", AdaptiveMaxPool1dGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"adaptive_max_pool2d\", AdaptiveMaxPool2dGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"adaptive_max_pool3d\", AdaptiveMaxPool3dGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/add_n.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct AddNCaptureState : public AutoGradCaptureState {\n  int32_t input_num;\n  std::vector<bool> requires_grad;\n};\n\nclass AddN : public OpExprGradFunction<AddNCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(AddNCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    ctx->input_num = inputs.size();\n    ctx->requires_grad.resize(inputs.size());\n    for (int i = 0; i < inputs.size(); ++i) {\n      ctx->requires_grad[i] = inputs.at(i)->requires_grad();\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const AddNCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(ctx->input_num);\n    for (int i = 0; i < ctx->input_num; ++i) {\n      if (ctx->requires_grad.at(i)) { in_grads->at(i) = out_grads.at(0); }\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"add_n\", AddN);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/affine_grid.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct AffineGridInterpState : public AutoGradCaptureState {\n  Shape size;\n  bool align_corners = false;\n  bool requires_grad = false;\n};\n\nclass AffineGrid : public OpExprGradFunction<AffineGridInterpState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(AffineGridInterpState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);                // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();  // theta\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->size = JUST(composed_attrs.GetAttr<Shape>(\"size\"));\n    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>(\"align_corners\"));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const AffineGridInterpState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    in_grads->at(0) =\n        JUST(functional::AffineGridGrad(out_grads.at(0), ctx->size, ctx->align_corners));\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"affine_grid\", AffineGrid);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/amp_white_identity.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nenum class AmpIdentityType {\n  kWhite = 0,\n  kBlack,\n};\n\nstruct AmpIdentityCaptureState : public AutoGradCaptureState {};\n\ntemplate<AmpIdentityType type>\nclass AmpIdentityGrad : public OpExprGradFunction<AmpIdentityCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(AmpIdentityCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const AmpIdentityCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(1);\n    if (type == AmpIdentityType::kWhite) {\n      (*in_grads)[0] = JUST(functional::AmpWhiteIdentity(out_grads[0]));\n    } else if (type == AmpIdentityType::kBlack) {\n      (*in_grads)[0] = JUST(functional::AmpBlackIdentity(out_grads[0]));\n    } else {\n      (*in_grads)[0] = out_grads[0];\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"amp_white_identity\", AmpIdentityGrad<AmpIdentityType::kWhite>);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"amp_black_identity\", AmpIdentityGrad<AmpIdentityType::kBlack>);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/as_strided.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct AsStridedCaptureState : public AutoGradCaptureState {\n  std::vector<int64_t> size;\n  std::vector<int64_t> stride;\n  int64_t storage_offset = 0;\n  bool requires_grad = false;\n};\n\nclass AsStrided : public OpExprGradFunction<AsStridedCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(AsStridedCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const AsStridedCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> AsStrided::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AsStrided::Capture(AsStridedCaptureState* ctx, const TensorTuple& inputs,\n                               const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ctx->SaveTensorForBackward(inputs.at(0));\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"size\"));\n  ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"stride\"));\n  ctx->storage_offset = JUST(composed_attrs.GetAttr<int64_t>(\"storage_offset\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AsStrided::Apply(const AsStridedCaptureState* ctx, const TensorTuple& out_grads,\n                             TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n  const auto& input = ctx->SavedTensors().at(0);\n  std::vector<int64_t> size = ctx->size;\n  std::vector<int64_t> stride = ctx->stride;\n  int64_t storage_offset = ctx->storage_offset;\n\n  in_grads->at(0) =\n      JUST(functional::AsStridedGrad(out_grads.at(0), input, size, stride, storage_offset));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"as_strided\", AsStrided);\n\n}  // namespace one\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/avg_pool.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nnamespace {\n\nstruct AvgPoolCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  size_t input_index = 0;\n\n  std::string data_format;\n  std::vector<int32_t> padding;\n  std::vector<int32_t> kernel_size;\n  std::vector<int32_t> stride;\n  bool ceil_mode = false;\n  bool count_include_pad = false;\n  int32_t divisor_override = 0;\n};\n\nclass AvgPoolNdGrad : public OpExprGradFunction<AvgPoolCaptureState> {\n public:\n  virtual ~AvgPoolNdGrad() = default;\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(AvgPoolCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const AvgPoolCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> AvgPoolNdGrad::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AvgPoolNdGrad::Capture(AvgPoolCaptureState* ctx, const TensorTuple& inputs,\n                                   const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n  ctx->padding = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"padding\"));\n  ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"kernel_size\"));\n  ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"stride\"));\n  ctx->ceil_mode = JUST(composed_attrs.GetAttr<bool>(\"ceil_mode\"));\n  ctx->count_include_pad = JUST(composed_attrs.GetAttr<bool>(\"count_include_pad\"));\n  ctx->divisor_override = JUST(composed_attrs.GetAttr<int32_t>(\"divisor_override\"));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AvgPoolNdGrad::Apply(const AvgPoolCaptureState* ctx, const TensorTuple& out_grads,\n                                 TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n  int32_t ndims = ctx->kernel_size.size();\n  
const auto& input = ctx->SavedTensors().at(ctx->input_index);\n\n  in_grads->resize(1);\n  (*in_grads)[0] = JUST(functional::AvgPoolNdGrad(\n      input, out_grads[0], ndims, ctx->data_format, ctx->padding, ctx->kernel_size, ctx->stride,\n      ctx->ceil_mode, ctx->count_include_pad, ctx->divisor_override));\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"avg_pool_1d\", AvgPoolNdGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"avg_pool_2d\", AvgPoolNdGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"avg_pool_3d\", AvgPoolNdGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/batch_gather.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct BatchGatherCaptureState : public AutoGradCaptureState {\n  int64_t num_segments;\n  bool requires_grad;\n};\n\nclass BatchGather : public OpExprGradFunction<BatchGatherCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(BatchGatherCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const BatchGatherCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n};\n\nMaybe<void> BatchGather::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BatchGather::Capture(BatchGatherCaptureState* ctx, const TensorTuple& inputs,\n                                 const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  const auto& in_shape = inputs.at(0)->shape();\n  const auto& indices_shape = inputs.at(1)->shape();\n  ctx->num_segments = in_shape->At(indices_shape->NumAxes() - 1);\n  ctx->SaveTensorForBackward(inputs.at(1));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BatchGather::Apply(const BatchGatherCaptureState* ctx, const TensorTuple& out_grads,\n                               TensorTuple* in_grads) const {\n  in_grads->resize(2);\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  const auto& indices = ctx->SavedTensors().at(0);\n  in_grads->at(0) =\n      JUST(functional::UnsortedBatchSegmentSum(out_grads.at(0), indices, ctx->num_segments));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"batch_gather\", BatchGather);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/bias_add.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct BiasAddCaptureState : public AutoGradCaptureState {\n  bool input_requires_grad;\n  bool bias_requires_grad;\n  int32_t axis;\n};\n\nclass BiasAdd : public OpExprGradFunction<BiasAddCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(BiasAddCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)\n    ctx->input_requires_grad = inputs.at(0)->requires_grad();\n    ctx->bias_requires_grad = inputs.at(1)->requires_grad();\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->axis = JUST(composed_attrs.GetAttr<int32_t>(\"axis\"));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const BiasAddCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    const int64_t num_axes = out_grads.at(0)->shape()->NumAxes();\n    in_grads->resize(2);\n    if (ctx->bias_requires_grad) {\n      std::vector<int32_t> reduce_axes_vec;\n      reduce_axes_vec.reserve(num_axes);\n      for (int i = 0; i < num_axes; ++i) {\n        if (i != ctx->axis) { reduce_axes_vec.emplace_back(i); }\n      }\n      if (ctx->bias_requires_grad) {\n        in_grads->at(1) =\n            JUST(functional::ReduceSum(out_grads.at(0), reduce_axes_vec, false, NullOpt));\n      }\n    }\n    if (ctx->input_requires_grad) { in_grads->at(0) = out_grads.at(0); }\n\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"bias_add\", BiasAdd);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/binary_cross_entropy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct BinaryCrossEntropyCaptureState : public AutoGradCaptureState {\n  bool input_requires_grad = false;\n  bool target_requires_grad = false;\n  bool has_weight = false;\n};\n\nclass BinaryCrossEntropy : public OpExprGradFunction<BinaryCrossEntropyCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(BinaryCrossEntropyCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const BinaryCrossEntropyCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n};\n\nMaybe<void> BinaryCrossEntropy::Init(const OpExpr& op) { return Maybe<void>::Ok(); }\n\nMaybe<void> BinaryCrossEntropy::Capture(BinaryCrossEntropyCaptureState* ctx,\n                                        const TensorTuple& inputs, const TensorTuple& outputs,\n                                        const AttrMap& attrs) const {\n  CHECK_OR_RETURN(inputs.size() >= 2 && inputs.size() <= 3);  // NOLINT(maybe-need-error-msg)\n  ctx->input_requires_grad = inputs[0]->requires_grad();\n  ctx->target_requires_grad = inputs[1]->requires_grad();\n  ctx->has_weight = inputs.size() == 3;\n\n  ctx->SaveTensorForBackward(inputs[0]);  // input\n  ctx->SaveTensorForBackward(inputs[1]);  // target\n  if (ctx->has_weight) {\n    ctx->SaveTensorForBackward(inputs[2]);  // weight\n  }\n  return Maybe<void>::Ok();\n}\nMaybe<void> BinaryCrossEntropy::Apply(const BinaryCrossEntropyCaptureState* ctx,\n                                      const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(),\n                     2 + ctx->has_weight);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(2 + ctx->has_weight);\n\n  const auto& dy = out_grads[0];\n  const auto& input = ctx->SavedTensors()[0];\n  const auto& target = ctx->SavedTensors()[1];\n  const auto& weight = ctx->has_weight ? Optional<one::Tensor>(ctx->SavedTensors()[2]) : NullOpt;\n\n  if (ctx->input_requires_grad) {\n    (*in_grads)[0] = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, weight));\n  }\n  if (ctx->target_requires_grad) {\n    (*in_grads)[1] = JUST(functional::BinaryCrossEntropyLossTargetGrad(dy, input, target, weight));\n  }\n  return Maybe<void>::Ok();\n}\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"binary_cross_entropy\", BinaryCrossEntropy);\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/binary_cross_entropy_with_logits.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct BinaryCrossEntropyWithLogitsCaptureState : public AutoGradCaptureState {\n  bool input_requires_grad = false;\n  bool target_requires_grad = false;\n  bool has_weight = false;\n  bool has_pos_weight = false;\n};\n\nclass BinaryCrossEntropyWithLogits\n    : public OpExprGradFunction<BinaryCrossEntropyWithLogitsCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(BinaryCrossEntropyWithLogitsCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const BinaryCrossEntropyWithLogitsCaptureState* ctx,\n                    const TensorTuple& out_grads, TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> BinaryCrossEntropyWithLogits::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\nMaybe<void> BinaryCrossEntropyWithLogits::Capture(BinaryCrossEntropyWithLogitsCaptureState* ctx,\n                                                  const TensorTuple& inputs,\n                                                  const TensorTuple& outputs,\n                                                  const AttrMap& attrs) const {\n  CHECK_OR_RETURN(inputs.size() >= 2 && inputs.size() <= 4);  // NOLINT(maybe-need-error-msg)\n  ctx->input_requires_grad = inputs[0]->requires_grad();\n  ctx->target_requires_grad = inputs[1]->requires_grad();\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->has_pos_weight = JUST(composed_attrs.GetAttr<bool>(\"has_pos_weight\"));\n  ctx->has_weight = inputs.size() == 4 || (inputs.size() == 3 && !ctx->has_pos_weight);\n  ctx->SaveTensorForBackward(inputs[0]);  // input\n  ctx->SaveTensorForBackward(inputs[1]);  // target\n\n  if (inputs.size() == 3) {\n    ctx->SaveTensorForBackward(inputs[2]);  // weight or pos_weight\n  }\n  if (inputs.size() == 4) {\n    ctx->SaveTensorForBackward(inputs[2]);  // weight\n    ctx->SaveTensorForBackward(inputs[3]);  // pos_weight\n  }\n  return Maybe<void>::Ok();\n}\nMaybe<void> BinaryCrossEntropyWithLogits::Apply(const BinaryCrossEntropyWithLogitsCaptureState* ctx,\n                                                const TensorTuple& out_grads,\n                                                TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(),\n                     2 + ctx->has_weight + ctx->has_pos_weight);  // NOLINT(maybe-need-error-msg)\n  const auto& dy = out_grads[0];\n  
const auto& input = ctx->SavedTensors()[0];\n  const auto& target = ctx->SavedTensors()[1];\n\n  in_grads->resize(ctx->SavedTensors().size());\n\n  size_t pos_weight_index = ctx->has_weight ? 3 : 2;\n  auto weight = ctx->has_weight ? Optional<one::Tensor>(ctx->SavedTensors()[2]) : NullOpt;\n  auto pos_weight =\n      ctx->has_pos_weight ? Optional<one::Tensor>(ctx->SavedTensors()[pos_weight_index]) : NullOpt;\n\n  if (ctx->input_requires_grad) {\n    (*in_grads)[0] = JUST(\n        functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, weight, pos_weight));\n  }\n  if (ctx->target_requires_grad) {\n    (*in_grads)[1] = JUST(functional::BinaryCrossEntropyWithLogitsLossTargetGrad(\n        dy, input, target, weight, pos_weight));\n  }\n\n  return Maybe<void>::Ok();\n}\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"binary_cross_entropy_with_logits\", BinaryCrossEntropyWithLogits);\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/binary_cross_entropy_with_logits_reduce_mean.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct BinaryCrossEntropyWithLogitsReduceMeanCaptureState : public AutoGradCaptureState {\n  bool input_requires_grad = false;\n  bool target_requires_grad = false;\n};\n\nclass BinaryCrossEntropyWithLogitsReduceMean\n    : public OpExprGradFunction<BinaryCrossEntropyWithLogitsReduceMeanCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx,\n                      const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx,\n                    const TensorTuple& out_grads, TensorTuple* in_grads) const override;\n};\n\nMaybe<void> BinaryCrossEntropyWithLogitsReduceMean::Init(const OpExpr& op) {\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BinaryCrossEntropyWithLogitsReduceMean::Capture(\n    BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& inputs,\n    const TensorTuple& outputs, const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)\n  ctx->input_requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();\n  ctx->target_requires_grad = JUST(VectorAt(inputs, 1))->requires_grad();\n\n  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));  // input\n  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));  // target\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BinaryCrossEntropyWithLogitsReduceMean::Apply(\n    const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& out_grads,\n    TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  const auto& dy = JUST(VectorAt(out_grads, 0));\n  const auto& input = JUST(VectorAt(ctx->SavedTensors(), 0));\n  const auto& target = JUST(VectorAt(ctx->SavedTensors(), 1));\n  in_grads->resize(2);\n\n  if (ctx->input_requires_grad) {\n    (*in_grads)[0] =\n        JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossGrad(dy, input, target));\n  }\n  if (ctx->target_requires_grad) {\n    (*in_grads)[1] =\n        JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossTargetGrad(dy, input, target));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"binary_cross_entropy_with_logits_reduce_mean\",\n                               BinaryCrossEntropyWithLogitsReduceMean);\n\n}  // namespace one\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct BroadcastBinaryCaptureState : public AutoGradCaptureState {\n  int x_index = -1;\n  int y_index = -1;\n  int z_index = -1;\n  bool x_requires_grad = false;\n  bool y_requires_grad = false;\n  bool broadcast_x = false;\n  bool broadcast_y = false;\n};\n\nclass BroadcastBinaryGrad : public OpExprGradFunction<BroadcastBinaryCaptureState> {\n public:\n  BroadcastBinaryGrad() = default;\n  virtual ~BroadcastBinaryGrad() = default;\n\n  virtual Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    ctx->y_requires_grad = inputs.at(1)->requires_grad();\n    ctx->broadcast_x = (*inputs.at(0)->shape() != *outputs.at(0)->shape());\n    ctx->broadcast_y = (*inputs.at(1)->shape() != *outputs.at(0)->shape());\n    return SaveTensorForBackward(ctx, inputs, outputs);\n  }\n\n protected:\n  virtual Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx,\n                                            const TensorTuple& inputs,\n                                            const TensorTuple& outputs) const = 0;\n};\n\nclass BroadcastAdd : public BroadcastBinaryGrad {\n public:\n  Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n    if (ctx->x_requires_grad) {\n      if (ctx->broadcast_x) {\n        const auto& x = ctx->SavedTensors().at(ctx->x_index);\n        in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), x));\n      } else {\n        in_grads->at(0) = out_grads.at(0);\n      }\n    }\n    if (ctx->y_requires_grad) {\n      if (ctx->broadcast_y) {\n        const auto& y = ctx->SavedTensors().at(ctx->y_index);\n        in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), y));\n      } else {\n        in_grads->at(1) = out_grads.at(0);\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n\n protected:\n  Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,\n                                    const TensorTuple& outputs) const override {\n    if (ctx->x_requires_grad && ctx->broadcast_x) {\n      ctx->x_index = 
ctx->SaveTensorForBackward(inputs.at(0));\n    }\n    if (ctx->y_requires_grad && ctx->broadcast_y) {\n      ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"broadcast_add\", BroadcastAdd);\n\nclass BroadcastSub : public BroadcastBinaryGrad {\n public:\n  Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n    if (ctx->x_requires_grad) {\n      if (ctx->broadcast_x) {\n        const auto& x = ctx->SavedTensors().at(ctx->x_index);\n        in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), x));\n      } else {\n        in_grads->at(0) = out_grads.at(0);\n      }\n    }\n    if (ctx->y_requires_grad) {\n      const auto& grad = JUST(functional::ScalarMul(out_grads.at(0), Scalar(-1.f), false));\n      if (ctx->broadcast_y) {\n        const auto& y = ctx->SavedTensors().at(ctx->y_index);\n        in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(grad, y));\n      } else {\n        in_grads->at(1) = grad;\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n\n protected:\n  Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,\n                                    const TensorTuple& outputs) const override {\n    if (ctx->x_requires_grad && ctx->broadcast_x) {\n      ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));\n    }\n    if (ctx->y_requires_grad && ctx->broadcast_y) {\n      ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"broadcast_sub\", BroadcastSub);\n\nclass BroadcastMul : public BroadcastBinaryGrad {\n public:\n  Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n    if (ctx->x_requires_grad) {\n      const auto& y = ctx->SavedTensors().at(ctx->y_index);\n      const auto& x_grad = JUST(functional::Mul(out_grads.at(0), JUST(functional::Conj(y))));\n      if (ctx->broadcast_x) {\n        const auto& x = ctx->SavedTensors().at(ctx->x_index);\n        in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(x_grad, x));\n      } else {\n        in_grads->at(0) = x_grad;\n      }\n    }\n    if (ctx->y_requires_grad) {\n      const auto& x = ctx->SavedTensors().at(ctx->x_index);\n      const auto& y_grad = JUST(functional::Mul(out_grads.at(0), JUST(functional::Conj(x))));\n      if (ctx->broadcast_y) {\n        const auto& y = ctx->SavedTensors().at(ctx->y_index);\n        in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(y_grad, y));\n      } else {\n        in_grads->at(1) = y_grad;\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n\n protected:\n  Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,\n                                    const TensorTuple& outputs) const override {\n    if (ctx->x_requires_grad) {\n      ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));\n      if (ctx->broadcast_x) { ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); }\n    }\n    if (ctx->y_requires_grad) {\n      if (ctx->x_index == -1 /*x has not been saved*/) {\n        ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));\n      }\n      if (ctx->broadcast_y && ctx->y_index == -1 /*y has not been saved*/) {\n        
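/* y is saved so Apply can reduce y_grad back to y's shape */ 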
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"broadcast_mul\", BroadcastMul);\n\nclass BroadcastDiv : public BroadcastBinaryGrad {\n public:\n  Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n    if (ctx->x_requires_grad) {\n      const auto& y = ctx->SavedTensors().at(ctx->y_index);\n      // const auto& x_grad = JUST(functional::Div(out_grads.at(0), y));\n      const auto& x_grad = JUST(functional::Div(out_grads.at(0), JUST(functional::Conj(y))));\n      if (ctx->broadcast_x) {\n        const auto& x = ctx->SavedTensors().at(ctx->x_index);\n        in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(x_grad, x));\n      } else {\n        in_grads->at(0) = x_grad;\n      }\n    }\n    if (ctx->y_requires_grad) {\n      const auto& y = ctx->SavedTensors().at(ctx->y_index);\n      const auto& z = ctx->SavedTensors().at(ctx->z_index);\n      in_grads->at(1) = JUST(functional::DivGrad(out_grads.at(0), z, y));\n    }\n    return Maybe<void>::Ok();\n  }\n\n protected:\n  Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,\n                                    const TensorTuple& outputs) const override {\n    if (ctx->x_requires_grad) {\n      ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));\n      if (ctx->broadcast_x) { ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); }\n    }\n    if (ctx->y_requires_grad) {\n      if (ctx->y_index == -1 /*y has not been saved*/) {\n        ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));\n      }\n      ctx->z_index = ctx->SaveTensorForBackward(outputs.at(0));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"broadcast_div\", BroadcastDiv);\n\nclass BroadcastPow : public BroadcastBinaryGrad {\n public:\n  Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    const auto& x = ctx->SavedTensors().at(ctx->x_index);\n    const auto& y = ctx->SavedTensors().at(ctx->y_index);\n    in_grads->resize(2);\n    if (ctx->x_requires_grad) {\n      (*in_grads)[0] = JUST(functional::BroadcastPowXGrad(x, y, out_grads[0]));\n    }\n    if (ctx->y_requires_grad) {\n      (*in_grads)[1] = JUST(functional::BroadcastPowYGrad(x, y, out_grads[0]));\n    }\n    return Maybe<void>::Ok();\n  }\n\n protected:\n  Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,\n                                    const TensorTuple& outputs) const override {\n    ctx->x_index = ctx->SaveTensorForBackward(inputs[0]);\n    ctx->y_index = ctx->SaveTensorForBackward(inputs[1]);\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"broadcast_pow\", BroadcastPow);\n\nclass BroadcastMinMax : public BroadcastBinaryGrad {\n public:\n  Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    const auto& out_shape = *(out_grads.at(0)->shape());\n    in_grads->resize(2);\n    if (ctx->x_requires_grad || ctx->y_requires_grad) {\n      const auto& x = ctx->SavedTensors().at(ctx->x_index);\n      const auto& y = ctx->SavedTensors().at(ctx->y_index);\n      auto broad_x_ = x;\n      auto broad_y_ = y;\n      if (ctx->broadcast_x) 
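/* materialize x at the output shape so the elementwise grad functor can compare operands */ 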
{\n        const auto& x_shape = *(x->shape());\n        const Shape& left_extended_x_shape =\n            CreateLeftExtendedShape(ShapeView(x_shape), out_shape.NumAxes());\n        if (left_extended_x_shape == out_shape) {\n          broad_x_ = JUST(functional::ReshapeLike(x, JUST(VectorAt(out_grads, 0))));\n        } else {\n          const AxisVector& broadcast_axis_vec = left_extended_x_shape.Axes4BroadcastTo(out_shape);\n          const std::vector<int32_t> x_axis =\n              std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};\n          broad_x_ = JUST(functional::BroadcastLike(x, JUST(VectorAt(out_grads, 0)), x_axis));\n        }\n      }\n      if (ctx->broadcast_y) {\n        const auto& y_shape = *(y->shape());\n        const Shape& left_extended_y_shape =\n            CreateLeftExtendedShape(ShapeView(y_shape), out_shape.NumAxes());\n        if (left_extended_y_shape == out_shape) {\n          broad_y_ = JUST(functional::ReshapeLike(y, JUST(VectorAt(out_grads, 0))));\n        } else {\n          const AxisVector& broadcast_axis_vec = left_extended_y_shape.Axes4BroadcastTo(out_shape);\n          const std::vector<int32_t> y_axis =\n              std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};\n          broad_y_ = JUST(functional::BroadcastLike(y, JUST(VectorAt(out_grads, 0)), y_axis));\n        }\n      }\n      const auto& broad_grads =\n          JUST(elementwise_grad_functor_(out_grads.at(0), broad_x_, broad_y_));\n      if (ctx->x_requires_grad) {\n        if (ctx->broadcast_x) {\n          in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(broad_grads->at(0), x));\n        } else {\n          in_grads->at(0) = broad_grads->at(0);\n        }\n      }\n      if (ctx->y_requires_grad) {\n        if (ctx->broadcast_y) {\n          in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(broad_grads->at(1), y));\n        } else {\n          in_grads->at(1) = broad_grads->at(1);\n        }\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n\n protected:\n  Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,\n                                    const TensorTuple& outputs) const override {\n    if (ctx->x_requires_grad || ctx->y_requires_grad) {\n      ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));\n      ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));\n    }\n    return Maybe<void>::Ok();\n  }\n\n  std::function<Maybe<TensorTuple>(const std::shared_ptr<Tensor>&, const std::shared_ptr<Tensor>&,\n                                   const std::shared_ptr<Tensor>&)>\n      elementwise_grad_functor_;\n};\n\nclass BroadcastMinimum : public BroadcastMinMax {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    JUST(BroadcastMinMax::Init(op));\n    elementwise_grad_functor_ = functional::ElementwiseMinGrad;\n    return Maybe<void>::Ok();\n  }\n};\n\nclass BroadcastMaximum : public BroadcastMinMax {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    JUST(BroadcastMinMax::Init(op));\n    elementwise_grad_functor_ = functional::ElementwiseMaxGrad;\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"broadcast_minimum\", BroadcastMinimum);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"broadcast_maximum\", BroadcastMaximum);\n\nclass BroadcastFMod : public BroadcastBinaryGrad {\n public:\n  Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const 
override {\n    const auto& out_shape = *(JUST(VectorAt(out_grads, 0))->shape());\n    in_grads->resize(2);\n    if (ctx->x_requires_grad || ctx->y_requires_grad) {\n      const auto& x = JUST(VectorAt(ctx->SavedTensors(), ctx->x_index));\n      const auto& y = JUST(VectorAt(ctx->SavedTensors(), ctx->y_index));\n      auto broad_x_ = x;\n      auto broad_y_ = y;\n      if (ctx->broadcast_x) {\n        const auto& x_shape = *(x->shape());\n        const Shape& left_extended_x_shape =\n            CreateLeftExtendedShape(ShapeView(x_shape), out_shape.NumAxes());\n        if (left_extended_x_shape == out_shape) {\n          broad_x_ = JUST(functional::ReshapeLike(x, JUST(VectorAt(out_grads, 0))));\n        } else {\n          const AxisVector& broadcast_axis_vec = left_extended_x_shape.Axes4BroadcastTo(out_shape);\n          const std::vector<int32_t> x_axis =\n              std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};\n          broad_x_ = JUST(functional::BroadcastLike(x, JUST(VectorAt(out_grads, 0)), x_axis));\n        }\n      }\n      if (ctx->broadcast_y) {\n        const auto& y_shape = *(y->shape());\n        const Shape& left_extended_y_shape =\n            CreateLeftExtendedShape(ShapeView(y_shape), out_shape.NumAxes());\n        if (left_extended_y_shape == out_shape) {\n          broad_y_ = JUST(functional::ReshapeLike(y, JUST(VectorAt(out_grads, 0))));\n        } else {\n          const AxisVector& broadcast_axis_vec = left_extended_y_shape.Axes4BroadcastTo(out_shape);\n          const std::vector<int32_t> y_axis =\n              std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};\n          broad_y_ = JUST(functional::BroadcastLike(y, JUST(VectorAt(out_grads, 0)), y_axis));\n        }\n      }\n      if (ctx->x_requires_grad) {\n        if (ctx->broadcast_x) {\n          JUST(VectorAt(*in_grads, 0)) =\n              JUST(functional::BroadcastReduceSumLike(JUST(VectorAt(out_grads, 0)), x));\n        } else {\n          JUST(VectorAt(*in_grads, 0)) = JUST(VectorAt(out_grads, 0));\n        }\n      }\n      if (ctx->y_requires_grad) {\n        auto result = JUST(functional::TruncDiv(broad_x_, broad_y_));\n        result = JUST(functional::Mul(JUST(VectorAt(out_grads, 0)), result));\n        JUST(functional::ScalarMul(result, Scalar(-1.f), true));\n        if (ctx->broadcast_y) {\n          in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(result, y));\n        } else {\n          in_grads->at(1) = result;\n        }\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n\n protected:\n  Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,\n                                    const TensorTuple& outputs) const override {\n    // Apply reads both x and y whenever either input needs a gradient, so save both exactly once.\n    if (ctx->x_requires_grad || ctx->y_requires_grad) {\n      ctx->x_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));\n      ctx->y_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"broadcast_fmod\", BroadcastFMod);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/broadcast_like.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\nnamespace oneflow {\nnamespace one {\n\nstruct BroadCastLikeCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n  size_t input_index;\n\n  std::vector<int32_t> broadcast_axes;\n};\n\nclass BroadCastLike : public OpExprGradFunction<BroadCastLikeCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(BroadCastLikeCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const BroadCastLikeCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> BroadCastLike::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BroadCastLike::Capture(BroadCastLikeCaptureState* ctx, const TensorTuple& inputs,\n                                   const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->broadcast_axes = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"broadcast_axes\"));\n  ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BroadCastLike::Apply(const BroadCastLikeCaptureState* ctx, const TensorTuple& out_grads,\n                                 TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n  const auto& x = ctx->SavedTensors().at(ctx->input_index);\n  in_grads->resize(2);\n  in_grads->at(0) = JUST(functional::ReduceSumLike(out_grads.at(0), x, ctx->broadcast_axes));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"broadcast_like\", BroadCastLike);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/cast.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/symbol.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct CastCaptureState : public AutoGradCaptureState {\n  Symbol<DType> in_dtype;\n  Symbol<DType> out_dtype;\n};\n\nclass Cast : public OpExprGradFunction<CastCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(CastCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    ctx->in_dtype = inputs.at(0)->dtype();\n    ctx->out_dtype = outputs.at(0)->dtype();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const CastCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(1);\n    if (!IsComplexDataType(ctx->in_dtype->data_type())\n        && IsComplexDataType(ctx->out_dtype->data_type())) {\n      (*in_grads)[0] = JUST(functional::Real(out_grads[0]));\n    } else {\n      (*in_grads)[0] = JUST(functional::Cast(out_grads[0], ctx->in_dtype, /*pin_memory=*/false));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"cast\", Cast);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/clip_by_scalar.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ClipByScalarCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n  Scalar min;\n  Scalar max;\n};\n\nclass ClipByScalar : public OpExprGradFunction<ClipByScalarCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ClipByScalarCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    ctx->SaveTensorForBackward(inputs.at(0));\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    if (IsFloatingDataType(inputs.at(0)->dtype()->data_type())) {\n      ctx->min = Scalar(JUST(composed_attrs.GetAttr<double>(\"floating_min\")));\n      ctx->max = Scalar(JUST(composed_attrs.GetAttr<double>(\"floating_max\")));\n    } else if (IsIntegralDataType(inputs.at(0)->dtype()->data_type())) {\n      ctx->min = Scalar(JUST(composed_attrs.GetAttr<int64_t>(\"integral_min\")));\n      ctx->max = Scalar(JUST(composed_attrs.GetAttr<int64_t>(\"integral_max\")));\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"Data type is not floating or integral type.\";\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ClipByScalarCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::ClampGrad(out_grads.at(0), x, ctx->min, ctx->max));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"clip_by_scalar\", ClipByScalar);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/clip_by_scalar_max.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ClipByScalarMaxCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n  Scalar max;\n};\n\nclass ClipByScalarMax : public OpExprGradFunction<ClipByScalarMaxCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ClipByScalarMaxCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    ctx->SaveTensorForBackward(inputs.at(0));\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    if (IsFloatingDataType(inputs.at(0)->dtype()->data_type())) {\n      ctx->max = Scalar(JUST(composed_attrs.GetAttr<double>(\"floating_max\")));\n    } else if (IsIntegralDataType(inputs.at(0)->dtype()->data_type())) {\n      ctx->max = Scalar(JUST(composed_attrs.GetAttr<int64_t>(\"integral_max\")));\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"Data type is not floating or integral type.\";\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ClipByScalarMaxCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::ClampGrad(out_grads.at(0), x, /*min=*/NullOpt, ctx->max));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"clip_by_scalar_max\", ClipByScalarMax);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/clip_by_scalar_min.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ClipByScalarMinCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n  Scalar min;\n};\n\nclass ClipByScalarMin : public OpExprGradFunction<ClipByScalarMinCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ClipByScalarMinCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    ctx->SaveTensorForBackward(inputs.at(0));\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    if (IsFloatingDataType(inputs.at(0)->dtype()->data_type())) {\n      ctx->min = Scalar(JUST(composed_attrs.GetAttr<double>(\"floating_min\")));\n    } else if (IsIntegralDataType(inputs.at(0)->dtype()->data_type())) {\n      ctx->min = Scalar(JUST(composed_attrs.GetAttr<int64_t>(\"integral_min\")));\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"Data type is not floating or integral type.\";\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ClipByScalarMinCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::ClampGrad(out_grads.at(0), x, ctx->min,\n                                                   /*max=*/NullOpt));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"clip_by_scalar_min\", ClipByScalarMin);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/combined_margin_loss.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct CombinedMarginLossCaptureState : public AutoGradCaptureState {\n  float m1;\n  float m2;\n  float m3;\n  int64_t depth;\n  size_t label_index;\n  size_t theta_index;\n  bool requires_grad;\n};\n\nclass CombinedMarginLoss : public OpExprGradFunction<CombinedMarginLossCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(CombinedMarginLossCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);                // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();  // x\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ctx->label_index = ctx->SaveTensorForBackward(inputs.at(1));   // label\n    ctx->theta_index = ctx->SaveTensorForBackward(outputs.at(1));  // theta\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->m1 = JUST(composed_attrs.GetAttr<float>(\"m1\"));\n    ctx->m2 = JUST(composed_attrs.GetAttr<float>(\"m2\"));\n    ctx->m3 = JUST(composed_attrs.GetAttr<float>(\"m3\"));\n    ctx->depth = JUST(composed_attrs.GetAttr<int64_t>(\"depth\"));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const CombinedMarginLossCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(2);\n\n    if (ctx->requires_grad) {\n      const auto& label = ctx->SavedTensors().at(ctx->label_index);\n      const auto& theta = ctx->SavedTensors().at(ctx->theta_index);\n      in_grads->at(0) = JUST(functional::CombinedMarginLossGrad(\n          out_grads.at(0), label, theta, ctx->m1, ctx->m2, ctx->m3, ctx->depth));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"combined_margin_loss\", CombinedMarginLoss);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/complex.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct BaseComplexCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n};\n\n// TODO(lml): redesign these Apply method to support high order autograd.\nclass RealGrad : public OpExprGradFunction<BaseComplexCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(BaseComplexCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const BaseComplexCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& results = JUST(functional::RealGrad(out_grads.at(0)));\n      in_grads->at(0) = results;\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nclass ImagGrad : public OpExprGradFunction<BaseComplexCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(BaseComplexCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const BaseComplexCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      const auto& results = JUST(functional::ImagGrad(out_grads.at(0)));\n      in_grads->at(0) = results;\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nclass ConjPhysicalGrad : public OpExprGradFunction<BaseComplexCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(BaseComplexCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const BaseComplexCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n    in_grads->resize(1);\n    if (ctx->requires_grad) 
{\n      const auto& results = JUST(functional::ConjPhysical(out_grads.at(0)));\n      in_grads->at(0) = results;\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"real\", RealGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"imag\", ImagGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"conj_physical\", ConjPhysicalGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/concat.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ConcatCaptureState : public AutoGradCaptureState {\n  std::vector<bool> requires_grad;\n  int64_t axis;\n  int64_t input_num;\n};\n\nclass Concat : public OpExprGradFunction<ConcatCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(ConcatCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const ConcatCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Concat::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Concat::Capture(ConcatCaptureState* ctx, const TensorTuple& inputs,\n                            const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad.resize(inputs.size());\n  for (int i = 0; i < inputs.size(); ++i) { ctx->requires_grad[i] = inputs.at(i)->requires_grad(); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->axis = JUST(composed_attrs.GetAttr<int64_t>(\"axis\"));\n  for (const auto& input : inputs) { ctx->SaveTensorForBackward(input); }\n  ctx->input_num = inputs.size();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Concat::Apply(const ConcatCaptureState* ctx, const TensorTuple& out_grads,\n                          TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(ctx->input_num);\n  TensorTuple like(ctx->input_num);\n  for (int i = 0; i < ctx->input_num; ++i) { like[i] = ctx->SavedTensors().at(i); }\n  if (ctx->input_num == 1) {\n    in_grads->at(0) = out_grads.at(0);\n  } else {\n    const auto& results = JUST(functional::SplitLike(out_grads.at(0), like, ctx->axis));\n    CHECK_EQ_OR_RETURN(results->size(), ctx->input_num)\n        << Error::RuntimeError() << \"The size of results (\" << results->size()\n        << \") must match the size of inputs (\" << ctx->input_num << \")\";\n\n    for (int i = 0; i < ctx->input_num; ++i)\n      if (ctx->requires_grad.at(i)) { in_grads->at(i) = results->at(i); }\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"cat\", Concat);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/conv.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ConvolutionNdCaptureState : public AutoGradCaptureState {\n  bool input_requires_grad = false;\n  bool weight_requires_grad = false;\n  bool has_bias = false;\n  bool bias_requires_grad = false;\n  size_t input_index;\n  size_t weight_index;\n\n  std::string data_format;\n  std::vector<int32_t> padding_before;\n  std::vector<int32_t> kernel_size;\n  std::vector<int32_t> strides;\n  std::vector<int32_t> dilation_rate;\n  int32_t groups;\n};\n\nclass ConvolutionNd : public OpExprGradFunction<ConvolutionNdCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(ConvolutionNdCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const ConvolutionNdCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> ConvolutionNd::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ConvolutionNd::Capture(ConvolutionNdCaptureState* ctx, const TensorTuple& inputs,\n                                   const TensorTuple& outputs, const AttrMap& attrs) const {\n  CHECK_OR_RETURN(inputs.size() == 2 || inputs.size() == 3);  // NOLINT(maybe-need-error-msg)\n  ctx->input_requires_grad = inputs.at(0)->requires_grad();\n  ctx->weight_requires_grad = inputs.at(1)->requires_grad();\n  if (inputs.size() == 3) {\n    ctx->has_bias = true;\n    ctx->bias_requires_grad = inputs.at(2)->requires_grad();\n  }\n\n  if (!ctx->input_requires_grad && !ctx->weight_requires_grad && !ctx->bias_requires_grad) {\n    return Maybe<void>::Ok();\n  }\n  if (ctx->input_requires_grad) {\n    ctx->weight_index = ctx->SaveTensorForBackward(inputs.at(1));  // weight\n  }\n  ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));  // input\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n  ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"padding_before\"));\n  ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"kernel_size\"));\n  ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"strides\"));\n  ctx->dilation_rate = 
JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"dilation_rate\"));\n  ctx->groups = JUST(composed_attrs.GetAttr<int32_t>(\"groups\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ConvolutionNd::Apply(const ConvolutionNdCaptureState* ctx, const TensorTuple& out_grads,\n                                 TensorTuple* in_grads) const {\n  if (ctx->has_bias) {\n    in_grads->resize(3);\n  } else {\n    in_grads->resize(2);\n  }\n  size_t num_spatial_dims = ctx->kernel_size.size();\n  if (ctx->input_requires_grad) {\n    const auto& weight = ctx->SavedTensors().at(ctx->weight_index);\n    const auto& input = ctx->SavedTensors().at(ctx->input_index);\n    in_grads->at(0) = JUST(functional::ConvDataGrad(\n        out_grads.at(0), weight, input, num_spatial_dims, ctx->kernel_size, ctx->strides,\n        ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));\n  }\n  if (ctx->weight_requires_grad) {\n    const auto& input = ctx->SavedTensors().at(ctx->input_index);\n    in_grads->at(1) = JUST(functional::ConvFilterGrad(\n        out_grads.at(0), input, num_spatial_dims, ctx->kernel_size, ctx->strides,\n        ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));\n  }\n  if (ctx->bias_requires_grad) {\n    std::vector<int32_t> dim;\n    for (int i = 0; i < out_grads.at(0)->shape()->NumAxes(); ++i) {\n      if ((ctx->data_format == \"channels_first\" && i == 1)\n          || (ctx->data_format == \"channels_last\"\n              && i == out_grads.at(0)->shape()->NumAxes() - 1)) {\n        continue;\n      }\n      dim.push_back(i);\n    }\n    in_grads->at(2) = JUST(functional::ReduceSum(out_grads.at(0), dim, false, NullOpt));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"conv1d\", ConvolutionNd);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"conv2d\", ConvolutionNd);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"conv3d\", ConvolutionNd);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/copy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct CopyCaptureState : public AutoGradCaptureState {\n  std::string device_type;\n  int64_t device_id;\n};\n\nclass Copy : public OpExprGradFunction<CopyCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(CopyCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    if (inputs[0]->is_global()) {\n      ctx->device_type = JUST(inputs[0]->parallel_desc())->device_tag();\n      ctx->device_id = 0;  // global tensor only has one local device\n    } else {\n      ctx->device_type = JUST(inputs[0]->device())->type();\n      ctx->device_id = JUST(inputs[0]->device())->device_id();\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const CopyCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(1);\n    (*in_grads)[0] = JUST(\n        functional::Copy(out_grads[0], ctx->device_type, ctx->device_id, /*pin_memory=*/false));\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"copy\", Copy);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/ctc_loss.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct CTCLossCaptureState : public AutoGradCaptureState {\n  int64_t max_target_length;\n  int32_t blank;\n  bool zero_infinity;\n  bool requires_grad;\n};\n\nclass CTCLoss : public OpExprGradFunction<CTCLossCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(CTCLossCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const CTCLossCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n  std::shared_ptr<OpExpr> grad_op_;\n};\n\nMaybe<void> CTCLoss::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CTCLoss::Capture(CTCLossCaptureState* ctx, const TensorTuple& inputs,\n                             const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->max_target_length = JUST(composed_attrs.GetAttr<int64_t>(\"max_target_length\"));\n  ctx->blank = JUST(composed_attrs.GetAttr<int64_t>(\"blank\"));\n  ctx->zero_infinity = JUST(composed_attrs.GetAttr<bool>(\"zero_infinity\"));\n\n  CHECK_EQ_OR_RETURN(inputs.size(), 4);       // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(outputs.size(), 2);      // NOLINT(maybe-need-error-msg)\n  ctx->SaveTensorForBackward(outputs.at(0));  // loss\n  ctx->SaveTensorForBackward(outputs.at(1));  // alpha\n  ctx->SaveTensorForBackward(inputs.at(0));   // log_probs\n  ctx->SaveTensorForBackward(inputs.at(1));   // targets\n  ctx->SaveTensorForBackward(inputs.at(2));   // input_lengths\n  ctx->SaveTensorForBackward(inputs.at(3));   // target_lengths\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CTCLoss::Apply(const CTCLossCaptureState* ctx, const TensorTuple& out_grads,\n                           TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)\n\n  const auto& grad_out = out_grads.at(0);\n  const auto& loss = ctx->SavedTensors().at(0);\n  const auto& alpha = ctx->SavedTensors().at(1);\n  const auto& log_probs = 
ctx->SavedTensors().at(2);\n  const auto& targets = ctx->SavedTensors().at(3);\n  const auto& input_lengths = ctx->SavedTensors().at(4);\n  const auto& target_lengths = ctx->SavedTensors().at(5);\n  in_grads->resize(4);\n  in_grads->at(0) = JUST(functional::CtcLossGrad(grad_out, log_probs, targets, input_lengths,\n                                                 target_lengths, loss, alpha, ctx->blank,\n                                                 ctx->zero_infinity, ctx->max_target_length));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"ctc_loss\", CTCLoss);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/cublas_fused_mlp.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/error.pb.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#if CUDA_VERSION >= 11060\n\nnamespace oneflow {\n\nnamespace one {\n\nstruct CublasFusedMLPCaptureState : public AutoGradCaptureState {\n  int32_t weight_num = 0;\n  bool skip_final_activation = false;\n  bool x_requires_grad = false;\n  std::vector<bool> weights_requires_grad;\n  std::vector<bool> biases_requires_grad;\n};\n\nclass CublasFusedMLP : public OpExprGradFunction<CublasFusedMLPCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(CublasFusedMLPCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const CublasFusedMLPCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n protected:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> CublasFusedMLP::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CublasFusedMLP::Capture(CublasFusedMLPCaptureState* ctx, const TensorTuple& inputs,\n                                    const TensorTuple& outputs, const AttrMap& attrs) const {\n  CHECK_OR_RETURN(inputs.size() % 2 == 1)\n      << Error::RuntimeError() << \"Both weight and bias should be passed together\";\n  int32_t weight_num = (inputs.size() - 1) / 2;\n  ctx->weight_num = weight_num;\n  ctx->x_requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();\n  ctx->weights_requires_grad.resize(weight_num);\n  ctx->biases_requires_grad.resize(weight_num);\n\n  for (int32_t i = 0; i < weight_num; i++) {\n    ctx->weights_requires_grad.at(i) = inputs.at(i + 1)->requires_grad();              // NOLINT\n    ctx->biases_requires_grad.at(i) = inputs.at(i + 1 + weight_num)->requires_grad();  // NOLINT\n  }\n\n  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));  // x. idx_sum:1\n  for (int32_t i = 0; i < weight_num; i++) {\n    ctx->SaveTensorForBackward(JUST(VectorAt(inputs, i + 1)));  // weights. idx_sum:1+w\n  }\n\n  ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0)));  // final layers output. idx_sum:2+w\n  for (int32_t i = 0; i < weight_num; i++) {\n    ctx->SaveTensorForBackward(\n        JUST(VectorAt(outputs, i + 1)));  // cublas aux. 
need minus 1. idx_sum:2+2w\n  }\n  for (int32_t i = 0; i < weight_num; i++) {\n    ctx->SaveTensorForBackward(JUST(VectorAt(outputs, i + 1 + weight_num)));  // hidden.\n  }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->skip_final_activation = JUST(composed_attrs.GetAttr<bool>(\"skip_final_activation\"));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CublasFusedMLP::Apply(const CublasFusedMLPCaptureState* ctx,\n                                  const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  int32_t weight_num = ctx->weight_num;\n  in_grads->resize(1 + 2 * weight_num);\n  std::shared_ptr<one::Tensor> last_bias_dy = JUST(VectorAt(out_grads, 0));\n\n  if (!ctx->skip_final_activation) {\n    // step1: use dy and final output to get last layer's relu grad.\n    last_bias_dy = JUST(functional::ReluGrad(JUST(VectorAt(out_grads, 0)),\n                                             JUST(VectorAt(ctx->SavedTensors(), 1 + weight_num))));\n  }\n\n  TensorTuple hiddens(weight_num);\n  TensorTuple weights(weight_num);\n  TensorTuple cublas_auxs(weight_num);\n  TensorTuple dgrad(weight_num);\n\n  std::shared_ptr<one::Tensor> x = JUST(VectorAt(ctx->SavedTensors(), 0));\n\n  for (int32_t i = 0; i < weight_num; ++i) {\n    weights[i] = JUST(VectorAt(ctx->SavedTensors(), 1 + i));\n  }\n\n  for (int32_t i = 0; i < weight_num; ++i) {\n    cublas_auxs[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + weight_num));\n  }\n\n  for (int32_t i = 0; i < weight_num; ++i) {\n    hiddens[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + 2 * weight_num));\n  }\n\n  std::shared_ptr<one::Tensor> cublas_dy = last_bias_dy;\n\n  // Use Fully Fused MLP Backward.\n  if (ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_FUSED_MLP_ASYNC_GRAD\", false)) {\n    const std::vector<float> alpha_list(weight_num - 1, 1.0);\n    const auto& fused_mlp_grad =\n        JUST(functional::FusedMLPGrad(cublas_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), weights,\n                                      cublas_auxs, hiddens, alpha_list));\n    if (ctx->x_requires_grad) {\n      // dx:\n      JUST(VectorAt(*in_grads, 0)) = fused_mlp_grad->at(0);\n    }\n\n    for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > -1; hidden_layer_idx--) {\n      if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx)))) {\n        // dbias\n        JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx + 1)) =\n            fused_mlp_grad->at(1 + hidden_layer_idx);  // NOLINT\n      }\n\n      // dw\n      if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) {\n        JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) =\n            fused_mlp_grad->at(1 + weight_num + hidden_layer_idx);\n      }\n    }\n  } else {\n    // step2: use reduce_sum to get last layer's bias grad.\n    std::vector<int32_t> reduce_axes_vec{0};\n    if (JUST(VectorAt(ctx->biases_requires_grad, weight_num - 1))) {\n      JUST(VectorAt(*in_grads, 2 * weight_num)) =\n          JUST(functional::ReduceSum(last_bias_dy, reduce_axes_vec, false, NullOpt));\n    }\n\n    for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > 0; hidden_layer_idx--) {\n      // If it is final layer, we use out_grads[0] as dy.\n      if (hidden_layer_idx != weight_num - 1) {\n        cublas_dy = JUST(VectorAt(dgrad, hidden_layer_idx + 1));\n      }\n      /*\n      Here we use cublas to compute bias + relu + matmul grad.\n      Then use Matmul to compute weight grad.\n      */\n      const auto& matmul_relu_bias_bgrad = 
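/* returns the matmul dgrad and the bias grad */ 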
JUST(functional::CublasBiasAddReluMatmulGrad(\n          cublas_dy, JUST(VectorAt(weights, hidden_layer_idx)),\n          JUST(VectorAt(cublas_auxs, hidden_layer_idx - 1)), /*alpha=*/1.0));\n\n      // dgrad\n      dgrad.at(hidden_layer_idx) = matmul_relu_bias_bgrad->at(0);  // NOLINT\n\n      if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx - 1)))) {\n        // dbias\n        JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx)) =\n            matmul_relu_bias_bgrad->at(1);  // NOLINT\n      }\n      // dw\n      if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) {\n        JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) = JUST(functional::MatMul(\n            cublas_dy, JUST(VectorAt(hiddens, hidden_layer_idx - 1)), true, false, 1.0));\n      }\n    }\n\n    // For the first layer, we need to use 2 matmul to get grads.\n    std::shared_ptr<one::Tensor> last_dy;\n    if (weight_num != 1) {\n      last_dy = JUST(VectorAt(dgrad, 1));\n    } else {\n      last_dy = last_bias_dy;\n    }\n\n    if (ctx->x_requires_grad) {\n      // dx:\n      JUST(VectorAt(*in_grads, 0)) =\n          JUST(functional::MatMul(last_dy, JUST(VectorAt(weights, 0)), false, false, 1.0));\n    }\n    if (JUST(VectorAt(ctx->weights_requires_grad, 0))) {\n      // dw:\n      JUST(VectorAt(*in_grads, 1)) = JUST(\n          functional::MatMul(last_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), true, false, 1.0));\n    }\n  }\n\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"cublas_fused_mlp\", CublasFusedMLP);\n\n}  // namespace one\n\n}  // namespace oneflow\n#endif  // CUDA_VERSION >= 11060\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/cum_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct CumCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  int32_t dim = 0;\n};\n\ntemplate<typename StateT>\nclass CumGrad : public OpExprGradFunction<StateT> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n protected:\n  AttrMap base_attrs_;\n};\n\nclass CumsumGrad : public CumGrad<CumCaptureState> {\n public:\n  Maybe<void> Capture(CumCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->dim = JUST(composed_attrs.GetAttr<int64_t>(\"dim\"));\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Apply(const CumCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      std::vector<int32_t> flip_dim(1, ctx->dim);\n      (*in_grads)[0] = JUST(\n          functional::Flip(JUST(functional::Cumsum(JUST(functional::Flip(out_grads[0], flip_dim)),\n                                                   ctx->dim, out_grads[0]->dtype())),\n                           flip_dim));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"cumsum\", CumsumGrad);\n\nclass CumProdGrad : public CumGrad<CumCaptureState> {\n public:\n  Maybe<void> Capture(CumCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->dim = JUST(composed_attrs.GetAttr<int64_t>(\"dim\"));\n    ctx->SaveTensorForBackward(outputs.at(0));\n    ctx->SaveTensorForBackward(inputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const CumCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      
in_grads->at(0) = JUST(functional::CumprodGrad(out_grads.at(0), ctx->SavedTensors().at(0),\n                                                     ctx->SavedTensors().at(1), ctx->dim));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"cumprod\", CumProdGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/deconv.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cstdint>\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct DeConvolutionNdCaptureState : public AutoGradCaptureState {\n  bool weight_requires_grad = false;\n  bool activation_requires_grad = false;\n  size_t ndims;\n  std::string data_format;\n  std::vector<int32_t> padding_before;\n  std::vector<int32_t> kernel_size;\n  std::vector<int32_t> strides;\n  std::vector<int32_t> dilation_rate;\n  int32_t groups;\n};\n\nclass DeConvolutionNd : public OpExprGradFunction<DeConvolutionNdCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(DeConvolutionNdCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const DeConvolutionNdCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> DeConvolutionNd::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DeConvolutionNd::Capture(DeConvolutionNdCaptureState* ctx, const TensorTuple& inputs,\n                                     const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->activation_requires_grad = inputs.at(0)->requires_grad();\n  ctx->weight_requires_grad = inputs.at(1)->requires_grad();\n  if (ctx->activation_requires_grad) {\n    ctx->SaveTensorForBackward(inputs.at(1));  // weight\n  }\n  if (ctx->weight_requires_grad) {\n    ctx->SaveTensorForBackward(inputs.at(0));  // x\n  }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n  ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"padding_before\"));\n  ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"kernel_size\"));\n  ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"strides\"));\n  ctx->dilation_rate = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"dilation_rate\"));\n  ctx->groups = JUST(composed_attrs.GetAttr<int32_t>(\"groups\"));\n  ctx->ndims = ctx->kernel_size.size();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DeConvolutionNd::Apply(const DeConvolutionNdCaptureState* ctx,\n                                   const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  in_grads->resize(2);\n  if (ctx->activation_requires_grad) {\n    const auto& x = ctx->SavedTensors().at(1);\n    
std::vector<int64_t> start, stop, step;\n    for (int i = 0; i < x->shape()->NumAxes(); i++) {\n      start.emplace_back(0);\n      stop.emplace_back(x->shape()->At(i));\n      step.emplace_back(1);\n    }\n    const auto& weight = ctx->SavedTensors().at(0);\n    if (ctx->ndims == 1) {\n      std::shared_ptr<Tensor> result = JUST(functional::Conv1d(\n          out_grads.at(0), weight, Optional<Tensor>(), ctx->strides, ctx->padding_before,\n          ctx->dilation_rate, ctx->groups, ctx->data_format));\n      result = JUST(functional::Slice(result, start, stop, step, /*enable_view_slice=*/false));\n      in_grads->at(0) = result;\n    } else if (ctx->ndims == 2) {\n      std::shared_ptr<Tensor> result = JUST(functional::Conv2d(\n          out_grads.at(0), weight, Optional<Tensor>(), ctx->strides, ctx->padding_before,\n          ctx->dilation_rate, ctx->groups, ctx->data_format));\n      result = JUST(functional::Slice(result, start, stop, step, /*enable_view_slice=*/false));\n      in_grads->at(0) = result;\n    } else if (ctx->ndims == 3) {\n      std::shared_ptr<Tensor> result = JUST(functional::Conv3d(\n          out_grads.at(0), weight, Optional<Tensor>(), ctx->strides, ctx->padding_before,\n          ctx->dilation_rate, ctx->groups, ctx->data_format));\n      result = JUST(functional::Slice(result, start, stop, step, /*enable_view_slice=*/false));\n      in_grads->at(0) = result;\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"Invalid ndim \" << ctx->ndims << \" for conv functor\";\n    }\n  }\n  if (ctx->weight_requires_grad) {\n    int idx = ctx->activation_requires_grad;\n    const auto& x = ctx->SavedTensors().at(idx);\n    in_grads->at(1) = JUST(functional::ConvFilterGrad(\n        x, out_grads.at(0), ctx->ndims, ctx->kernel_size, ctx->strides, ctx->padding_before,\n        ctx->dilation_rate, ctx->groups, ctx->data_format));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"deconv1d\", DeConvolutionNd);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"deconv2d\", DeConvolutionNd);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"deconv3d\", DeConvolutionNd);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/deform_conv.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct DeformConvNdCaptureState : public AutoGradCaptureState {\n  bool input_requires_grad = false;\n  bool offset_requires_grad = false;\n  bool weight_requires_grad = false;\n  bool mask_requires_grad = false;\n  bool bias_requires_grad = false;\n  int32_t stride_h = 0;\n  int32_t stride_w = 0;\n  int32_t pad_h = 0;\n  int32_t pad_w = 0;\n  int32_t dilation_h = 0;\n  int32_t dilation_w = 0;\n  int32_t groups = 0;\n  int32_t offset_groups = 0;\n  bool use_mask = false;\n};\n\nclass DeformConvNd : public OpExprGradFunction<DeformConvNdCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(DeformConvNdCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const DeformConvNdCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> DeformConvNd::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DeformConvNd::Capture(DeformConvNdCaptureState* ctx, const TensorTuple& inputs,\n                                  const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->input_requires_grad = inputs.at(0)->requires_grad();\n  ctx->weight_requires_grad = inputs.at(1)->requires_grad();\n  ctx->offset_requires_grad = inputs.at(2)->requires_grad();\n  ctx->mask_requires_grad = inputs.at(3)->requires_grad();\n\n  ctx->SaveTensorForBackward(inputs.at(0));  // input\n  ctx->SaveTensorForBackward(inputs.at(1));  // weight\n  ctx->SaveTensorForBackward(inputs.at(2));  // offset\n  ctx->SaveTensorForBackward(inputs.at(3));  // mask\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n\n  ctx->use_mask = JUST(composed_attrs.GetAttr<bool>(\"use_mask\"));\n  ctx->stride_h = JUST(composed_attrs.GetAttr<int32_t>(\"stride_h\"));\n  ctx->stride_w = JUST(composed_attrs.GetAttr<int32_t>(\"stride_w\"));\n  ctx->pad_h = JUST(composed_attrs.GetAttr<int32_t>(\"pad_h\"));\n  ctx->pad_w = JUST(composed_attrs.GetAttr<int32_t>(\"pad_w\"));\n  ctx->dilation_h = JUST(composed_attrs.GetAttr<int32_t>(\"dilation_h\"));\n  ctx->dilation_w = JUST(composed_attrs.GetAttr<int32_t>(\"dilation_w\"));\n  ctx->groups = JUST(composed_attrs.GetAttr<int32_t>(\"groups\"));\n  ctx->offset_groups = JUST(composed_attrs.GetAttr<int32_t>(\"offset_groups\"));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DeformConvNd::Apply(const DeformConvNdCaptureState* ctx, const TensorTuple& out_grads,\n            
                    TensorTuple* in_grads) const {\n  in_grads->resize(5);\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  const auto& input = ctx->SavedTensors().at(0);\n  const auto& weight = ctx->SavedTensors().at(1);\n  const auto& offset = ctx->SavedTensors().at(2);\n  const auto& mask = ctx->SavedTensors().at(3);\n  const auto& output_grad = out_grads.at(0);\n  if (ctx->input_requires_grad || ctx->offset_requires_grad || ctx->mask_requires_grad) {\n    std::shared_ptr<TensorTuple> grads_tuple;\n    if (ctx->use_mask) {\n      grads_tuple = JUST(functional::DeformConv2dInputGrad(\n          output_grad, input, weight, offset, mask, ctx->stride_h, ctx->stride_w, ctx->pad_h,\n          ctx->pad_w, ctx->dilation_h, ctx->dilation_w, ctx->groups, ctx->offset_groups,\n          ctx->use_mask));\n    } else {\n      grads_tuple = JUST(functional::DeformConv2dInputGrad(\n          output_grad, input, weight, offset, NullOpt, ctx->stride_h, ctx->stride_w, ctx->pad_h,\n          ctx->pad_w, ctx->dilation_h, ctx->dilation_w, ctx->groups, ctx->offset_groups,\n          ctx->use_mask));\n    }\n    if (ctx->input_requires_grad) {\n      in_grads->at(0) = grads_tuple->at(0);  // input_grad\n    }\n    if (ctx->offset_requires_grad) {\n      in_grads->at(2) = grads_tuple->at(1);  // offset_grad\n    }\n    if (ctx->use_mask && ctx->mask_requires_grad) {\n      in_grads->at(3) = grads_tuple->at(2);  // mask_grad\n    }\n  }\n\n  if (ctx->weight_requires_grad) {  // weight_grad\n    in_grads->at(1) = JUST(functional::DeformConv2dParamGrad(\n        output_grad, input, weight, offset, mask, ctx->stride_h, ctx->stride_w, ctx->pad_h,\n        ctx->pad_w, ctx->dilation_h, ctx->dilation_w, ctx->groups, ctx->offset_groups,\n        ctx->use_mask));\n  }\n\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"deform_conv2d\", DeformConvNd);\n\n}  // namespace one\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/depand.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct DependCaptureState : public AutoGradCaptureState {\n  bool in_requires_grad = false;\n  bool depend_tensor_requires_grad = false;\n  Shape depend_tensor_shape;\n  Symbol<DType> depend_tensor_dtype;\n  Maybe<Symbol<Device>> depend_tensor_device;\n};\n\nclass Depend : public OpExprGradFunction<DependCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(DependCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->in_requires_grad = inputs.at(0)->requires_grad();\n    ctx->depend_tensor_requires_grad = inputs.at(1)->requires_grad();\n    if (ctx->depend_tensor_requires_grad) {\n      ctx->depend_tensor_shape = *(inputs.at(1)->shape());\n      ctx->depend_tensor_dtype = inputs.at(1)->dtype();\n      ctx->depend_tensor_device = inputs.at(1)->device();\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const DependCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(2);\n    if (ctx->in_requires_grad) { in_grads->at(0) = out_grads.at(0); }\n    if (ctx->depend_tensor_requires_grad) {\n      in_grads->at(1) =\n          JUST(functional::Constant(ctx->depend_tensor_shape, Scalar(0), ctx->depend_tensor_dtype,\n                                    JUST(ctx->depend_tensor_device)));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"depend\", Depend);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/det.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct DetCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  size_t input_index = 0;\n  size_t output_index = 0;\n};\n\nclass Det : public OpExprGradFunction<DetCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n  Maybe<void> Capture(DetCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();\n    if (ctx->requires_grad) {\n      ctx->input_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));\n      ctx->output_index = ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0)));\n    }\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Apply(const DetCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (ctx->requires_grad) {\n      const auto& output = JUST(VectorAt(ctx->SavedTensors(), ctx->output_index));\n      const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index));\n      const auto& dy = JUST(VectorAt(out_grads, 0));\n      const auto& dy_unsqueeze = JUST(functional::UnsqueezeMultiple(dy, {-2, -1}, dy->ndim() + 2));\n      const auto& output_unsqueeze =\n          JUST(functional::UnsqueezeMultiple(output, {-2, -1}, output->ndim() + 2));\n      JUST(VectorAt(*in_grads, 0)) = JUST(functional::Transpose2dim(\n          JUST(functional::Mul(\n              dy_unsqueeze, JUST(functional::Mul(JUST(functional::Inv(input)), output_unsqueeze)))),\n          -2, -1));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"det\", Det);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/diag.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct DiagCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n  int32_t diagonal;\n};\n\nclass Diag : public OpExprGradFunction<DiagCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(DiagCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const DiagCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Diag::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Diag::Capture(DiagCaptureState* ctx, const TensorTuple& inputs,\n                          const TensorTuple& outputs, const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->diagonal = JUST(composed_attrs.GetAttr<int32_t>(\"diagonal\"));\n  ctx->SaveTensorForBackward(inputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Diag::Apply(const DiagCaptureState* ctx, const TensorTuple& out_grads,\n                        TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(2);\n  if (ctx->requires_grad) {\n    const auto& x = ctx->SavedTensors().at(0);\n    in_grads->at(0) = JUST(functional::DiagGrad(out_grads.at(0), x, ctx->diagonal));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"diag\", Diag);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/diagonal.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct DiagonalInterpState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  int32_t offset = 0;\n};\n\nclass Diagonal : public OpExprGradFunction<DiagonalInterpState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(DiagonalInterpState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const DiagonalInterpState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Diagonal::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Diagonal::Capture(DiagonalInterpState* ctx, const TensorTuple& inputs,\n                              const TensorTuple& outputs, const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->offset = JUST(composed_attrs.GetAttr<int32_t>(\"offset\"));\n  ctx->SaveTensorForBackward(inputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Diagonal::Apply(const DiagonalInterpState* ctx, const TensorTuple& out_grads,\n                            TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(2);\n  if (ctx->requires_grad) {\n    const auto& x = ctx->SavedTensors().at(0);\n    in_grads->at(0) = JUST(functional::DiagonalGrad(out_grads.at(0), x, ctx->offset));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"diagonal\", Diagonal);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/dim_gather.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct DimGatherCaptureState : public AutoGradCaptureState {\n  int32_t dim;\n  bool requires_grad;\n};\n\nclass DimGather : public OpExprGradFunction<DimGatherCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(DimGatherCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const DimGatherCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n  std::shared_ptr<OpExpr> bw_dim_gather_op_;\n};\n\nMaybe<void> DimGather::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DimGather::Capture(DimGatherCaptureState* ctx, const TensorTuple& inputs,\n                               const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ctx->SaveTensorForBackward(inputs.at(1));\n  ctx->SaveTensorForBackward(inputs.at(0));\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->dim = JUST(composed_attrs.GetAttr<int32_t>(\"dim\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DimGather::Apply(const DimGatherCaptureState* ctx, const TensorTuple& out_grads,\n                             TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);\n  const std::shared_ptr<oneflow::one::Tensor>& like = ctx->SavedTensors().at(1);\n\n  in_grads->at(0) = JUST(functional::DimScatterAddLike(like, ctx->dim, index, out_grads.at(0)));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"dim_gather\", DimGather);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/dim_scatter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct DimScatterCaptureState : public AutoGradCaptureState {\n  int32_t dim;\n  bool input_requires_grad;\n  bool src_requires_grad;\n};\nenum class ScatterType { kUpdate, kAdd, kMultiply };\n\ntemplate<ScatterType T>\nclass DimScatter : public OpExprGradFunction<DimScatterCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(DimScatterCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const DimScatterCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\ntemplate<ScatterType T>\nMaybe<void> DimScatter<T>::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\ntemplate<ScatterType T>\nMaybe<void> DimScatter<T>::Capture(DimScatterCaptureState* ctx, const TensorTuple& inputs,\n                                   const TensorTuple& outputs, const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), 3);   // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n  ctx->input_requires_grad = inputs.at(0)->requires_grad();\n  ctx->src_requires_grad = inputs.at(2)->requires_grad();\n  if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }\n\n  ctx->SaveTensorForBackward(inputs.at(1));  // index saved\n  if (T == ScatterType::kMultiply) {\n    ctx->SaveTensorForBackward(inputs.at(2));  // src saved\n  }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->dim = JUST(composed_attrs.GetAttr<int32_t>(\"dim\"));\n  return Maybe<void>::Ok();\n}\n\ntemplate<ScatterType T>\nMaybe<void> DimScatter<T>::Apply(const DimScatterCaptureState* ctx, const TensorTuple& out_grads,\n                                 TensorTuple* in_grads) const {\n  if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }\n\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(3);\n\n  const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);\n\n  if (ctx->src_requires_grad) {\n    in_grads->at(2) = JUST(functional::DimGather(out_grads.at(0), ctx->dim, index, false));\n  }\n\n  if (ctx->input_requires_grad) {\n    if (T == ScatterType::kAdd) { in_grads->at(0) = out_grads.at(0); }\n\n    if (T 
== ScatterType::kUpdate) {\n      in_grads->at(0) = JUST(functional::DimScatterUpdateScalar(out_grads.at(0), ctx->dim, index,\n                                                                0.0f, /*inplace*/ false));\n    }\n\n    if (T == ScatterType::kMultiply) {\n      const std::shared_ptr<oneflow::one::Tensor>& src = ctx->SavedTensors().at(1);\n      in_grads->at(0) =\n          JUST(functional::DimScatterMul(out_grads.at(0), ctx->dim, index, src, /*inplace*/ false));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nclass DimScatterUpdateScalar : public OpExprGradFunction<DimScatterCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(DimScatterCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const DimScatterCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> DimScatterUpdateScalar::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DimScatterUpdateScalar::Capture(DimScatterCaptureState* ctx, const TensorTuple& inputs,\n                                            const TensorTuple& outputs,\n                                            const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n  ctx->input_requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }\n\n  ctx->SaveTensorForBackward(inputs.at(1));  // index saved\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->dim = JUST(composed_attrs.GetAttr<int32_t>(\"dim\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DimScatterUpdateScalar::Apply(const DimScatterCaptureState* ctx,\n                                          const TensorTuple& out_grads,\n                                          TensorTuple* in_grads) const {\n  if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);\n\n  in_grads->resize(2);\n  in_grads->at(0) = JUST(functional::DimScatterUpdateScalar(out_grads.at(0), ctx->dim, index, 0.0f,\n                                                            /*inplace*/ false));\n\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"dim_scatter_update\", DimScatter<ScatterType::kUpdate>);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"dim_scatter_add\", DimScatter<ScatterType::kAdd>);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"dim_scatter_mul\", DimScatter<ScatterType::kMultiply>);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"dim_scatter_update_scalar\", DimScatterUpdateScalar);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/dot.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct DotCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad = false;\n  bool y_requires_grad = false;\n  size_t x_offset = 0;\n  size_t y_offset = 0;\n};\n\nclass DotGrad : public OpExprGradFunction<DotCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(DotCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    if (ctx->x_requires_grad) { ctx->x_offset = ctx->SaveTensorForBackward(inputs.at(1)); }\n    ctx->y_requires_grad = inputs.at(1)->requires_grad();\n    if (ctx->y_requires_grad) { ctx->y_offset = ctx->SaveTensorForBackward(inputs.at(0)); }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const DotCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n    in_grads->resize(2);\n    if (ctx->x_requires_grad) {\n      const auto& x = ctx->SavedTensors().at(ctx->x_offset);\n      const auto& results = JUST(functional::Mul(x, out_grads.at(0)));\n      in_grads->at(0) = results;\n    }\n\n    if (ctx->y_requires_grad) {\n      const auto& y = ctx->SavedTensors().at(ctx->y_offset);\n      const auto& results = JUST(functional::Mul(y, out_grads.at(0)));\n      in_grads->at(1) = results;\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"dot\", DotGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/dropout.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct DropoutCaptureState : public AutoGradCaptureState {\n  bool requires_grad = true;\n  bool has_addend = false;\n  float rate = 0.0;\n};\n\nclass Dropout : public OpExprGradFunction<DropoutCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(DropoutCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const DropoutCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Dropout::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Dropout::Capture(DropoutCaptureState* ctx, const TensorTuple& inputs,\n                             const TensorTuple& outputs, const AttrMap& attrs) const {\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  ctx->rate = JUST(composed_attrs.GetAttr<float>(\"rate\"));\n\n  if (inputs.size() == 1) {\n    ctx->has_addend = false;\n  } else if (inputs.size() == 2) {\n    ctx->has_addend = true;\n  } else {\n    UNIMPLEMENTED();\n  }\n\n  ctx->SaveTensorForBackward(outputs.at(1));  // output mask\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Dropout::Apply(const DropoutCaptureState* ctx, const TensorTuple& out_grads,\n                           TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // Output has y and mask.\n  float scale = 0.0f;                       // When dropout rate = 1.0, we set scale as zero.\n  if (ctx->rate < 1.0f) { scale = 1.0f / (1.0f - ctx->rate); }\n  const std::shared_ptr<oneflow::one::Tensor>& mask = ctx->SavedTensors().at(0);\n  if (ctx->has_addend) {\n    in_grads->resize(2);\n    in_grads->at(0) = JUST(functional::DropoutGrad(out_grads.at(0), mask, scale));\n    in_grads->at(1) = out_grads.at(0);\n    return Maybe<void>::Ok();\n  } else {\n    in_grads->resize(1);\n    in_grads->at(0) = JUST(functional::DropoutGrad(out_grads.at(0), mask, scale));\n    return Maybe<void>::Ok();\n  }\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"dropout\", Dropout);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/eager_ccl_broadcast.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/id_util.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n\nnamespace oneflow {\n\nnamespace one {\n\nnamespace {\n\nMaybe<one::UserOpExpr> EagerCclReduce(Symbol<ParallelDesc> parallel_desc, int64_t root) {\n  return one::OpBuilder(\"eager_ccl_reduce\", *JUST(UniqueStr(\"eager_ccl_reduce\")))\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<std::string>(\"parallel_conf\", PbMessage2TxtString(parallel_desc->parallel_conf()))\n      .Attr<int64_t>(\"root\", root)\n      .Build();\n}\n\nMaybe<one::UserOpExpr> FindOrCreatEagerCclReduceOpExpr(Symbol<ParallelDesc> parallel_desc,\n                                                       int64_t root) {\n  thread_local HashMap<std::pair<Symbol<ParallelDesc>, int64_t>, std::shared_ptr<one::UserOpExpr>>\n      parallel_desc_and_root_device2eager_nccl_reduce;\n  const auto& key = std::make_pair(parallel_desc, root);\n  auto iter = parallel_desc_and_root_device2eager_nccl_reduce.find(key);\n  if (iter == parallel_desc_and_root_device2eager_nccl_reduce.end()) {\n    std::shared_ptr<UserOpExpr> op_expr = JUST(EagerCclReduce(parallel_desc, root));\n    iter = parallel_desc_and_root_device2eager_nccl_reduce.emplace(key, op_expr).first;\n  }\n  return iter->second;\n}\n\n}  // namespace\n\nstruct EagerCclBroadcastCaptureState : public AutoGradCaptureState {  // NOLINT\n  Symbol<ParallelDesc> parallel_desc;\n  int64_t root;\n};\n\nclass EagerCclBroadcast : public OpExprGradFunction<EagerCclBroadcastCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(EagerCclBroadcastCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs,\n                      const OpExprInterpContext& interp_ctx) const override {\n    ctx->root = JUST(interp_ctx.attrs.GetAttr<int64_t>(\"root\"));\n    ctx->parallel_desc = JUST(interp_ctx.parallel_desc);\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const EagerCclBroadcastCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    const auto& grad_op = JUST(FindOrCreatEagerCclReduceOpExpr(ctx->parallel_desc, ctx->root));\n    in_grads->resize(1);\n    in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op, {out_grads.at(0)}));\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"eager_ccl_broadcast\", EagerCclBroadcast);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/elementwise_minimum_maximum.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ElementwiseXimumCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad;\n  bool y_requires_grad;\n};\n\nclass ElementwiseXimumOp : public OpExprGradFunction<ElementwiseXimumCaptureState> {\n public:\n  Maybe<void> Capture(ElementwiseXimumCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    ctx->y_requires_grad = inputs.at(1)->requires_grad();\n    ctx->SaveTensorForBackward(inputs.at(0));\n    ctx->SaveTensorForBackward(inputs.at(1));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ElementwiseXimumCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!(ctx->x_requires_grad || ctx->y_requires_grad)) { return Maybe<void>::Ok(); }\n\n    in_grads->resize(2);\n    const std::shared_ptr<one::Tensor>& x = ctx->SavedTensors().at(0);\n    const std::shared_ptr<one::Tensor>& y = ctx->SavedTensors().at(1);\n    if (ctx->x_requires_grad || ctx->y_requires_grad) {\n      const auto& grads = JUST(grad_functor(out_grads.at(0), x, y));\n      if (ctx->x_requires_grad) { in_grads->at(0) = grads->at(0); }\n      if (ctx->y_requires_grad) { in_grads->at(1) = grads->at(1); }\n    }\n\n    return Maybe<void>::Ok();\n  }\n\n protected:\n  std::function<Maybe<TensorTuple>(const std::shared_ptr<Tensor>&, const std::shared_ptr<Tensor>&,\n                                   const std::shared_ptr<Tensor>&)>\n      grad_functor;\n};\n\nclass ElementwiseMinimum : public ElementwiseXimumOp {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    grad_functor = functional::ElementwiseMinGrad;\n    return Maybe<void>::Ok();\n  }\n};\n\nclass ElementwiseMaximum : public ElementwiseXimumOp {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    grad_functor = functional::ElementwiseMaxGrad;\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"elementwise_minimum\", ElementwiseMinimum);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"elementwise_maximum\", ElementwiseMaximum);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/embedding.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct EmbeddingCaptureState : public AutoGradCaptureState {\n  int64_t padding_idx = -1;\n  bool scale_grad_by_freq = false;\n  bool requires_grad = false;\n};\n\nclass Embedding : public OpExprGradFunction<EmbeddingCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(EmbeddingCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const EmbeddingCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Embedding::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr) << \"Forward op must be not null\";\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Embedding::Capture(EmbeddingCaptureState* ctx, const TensorTuple& inputs,\n                               const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0)));\n  ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 1)));\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->padding_idx = JUST(composed_attrs.GetAttr<int64_t>(\"padding_idx\"));\n  ctx->scale_grad_by_freq = JUST(composed_attrs.GetAttr<bool>(\"scale_grad_by_freq\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Embedding::Apply(const EmbeddingCaptureState* ctx, const TensorTuple& out_grads,\n                             TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  in_grads->resize(ctx->SavedTensors().size());\n  const auto& weight = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));\n  const auto& indices = JUST(oneflow::VectorAt(ctx->SavedTensors(), 1));\n  int64_t padding_idx = ctx->padding_idx;\n  bool scale_grad_by_freq = ctx->scale_grad_by_freq;\n\n  JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::EmbeddingGrad(\n      JUST(oneflow::VectorAt(out_grads, 0)), weight, indices, padding_idx, scale_grad_by_freq));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"embedding\", Embedding);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/expand.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ExpandCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n  int32_t lpad;\n  bool keep_dims;\n  std::vector<int32_t> reduce_dims;\n};\n\nclass Expand : public OpExprGradFunction<ExpandCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(ExpandCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const ExpandCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n};\n\nMaybe<void> Expand::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Expand::Capture(ExpandCaptureState* ctx, const TensorTuple& inputs,\n                            const TensorTuple& outputs, const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n  ctx->requires_grad = inputs[0]->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  const Shape& in_shape = *inputs[0]->shape();\n  const Shape& expand_shape = *outputs[0]->shape();\n  ctx->lpad = expand_shape.size() - in_shape.size();\n  ctx->keep_dims = (in_shape.size() > 0);\n  ctx->reduce_dims.reserve(expand_shape.size());\n  if (ctx->keep_dims) {\n    for (size_t i = 0; i < expand_shape.size(); ++i) {\n      const auto& t_dim = expand_shape[i];\n      const auto& dim = i < ctx->lpad ? 1 : in_shape[i - ctx->lpad];\n      if (dim != t_dim) { ctx->reduce_dims.push_back(i); }\n    }\n  } else {\n    for (int32_t axis = 0; axis < expand_shape.size(); ++axis) { ctx->reduce_dims.push_back(axis); }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Expand::Apply(const ExpandCaptureState* ctx, const TensorTuple& out_grads,\n                          TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(1);\n  in_grads->at(0) = out_grads[0];\n  if (ctx->reduce_dims.size() > 0) {\n    in_grads->at(0) =\n        JUST(functional::ReduceSum(in_grads->at(0), ctx->reduce_dims, ctx->keep_dims, NullOpt));\n  }\n  if (ctx->lpad > 0 && ctx->keep_dims) {\n    in_grads->at(0) = JUST(functional::Flatten(in_grads->at(0), 0, ctx->lpad));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"expand\", Expand);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
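  {
    "path": "oneflow/core/autograd/gradient_funcs/examples/expand_grad_reduce_dims_sketch.cpp",
    "content": "// NOTE: hypothetical example file, not part of the OneFlow source tree.\n// A minimal standalone sketch of the reduce-dims selection in Expand::Capture\n// above: when a tensor of shape `in_shape` is expanded to `expand_shape`, the\n// backward pass sums the output gradient over every broadcast axis, i.e. the\n// left-padded axes and the axes where the input dim was 1.\n#include <cstdint>\n#include <iostream>\n#include <vector>\n\nint main() {\n  const std::vector<int64_t> in_shape{3, 1};         // original shape\n  const std::vector<int64_t> expand_shape{2, 3, 5};  // expanded shape\n  const int32_t lpad = expand_shape.size() - in_shape.size();\n\n  std::vector<int32_t> reduce_dims;\n  for (size_t i = 0; i < expand_shape.size(); ++i) {\n    const int64_t dim = i < static_cast<size_t>(lpad) ? 1 : in_shape[i - lpad];\n    if (dim != expand_shape[i]) { reduce_dims.push_back(static_cast<int32_t>(i)); }\n  }\n  // Prints 0 and 2: axis 0 was left-padded, axis 2 was broadcast from 1 to 5.\n  for (int32_t d : reduce_dims) { std::cout << d << std::endl; }\n  return 0;\n}\n"
  },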
  {
    "path": "oneflow/core/autograd/gradient_funcs/fake_quantization.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FakeQuantizationCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n};\n\nclass FakeQuantization : public OpExprGradFunction<FakeQuantizationCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(FakeQuantizationCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 3);\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FakeQuantizationCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n    in_grads->resize(3);\n    if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fake_quantization\", FakeQuantization);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fft.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <string>\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FftR2CCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  bool onesided = false;\n  std::vector<int64_t> dims;\n  DimVector input_shape_vec;\n  int32_t norm_mode = 0;\n};\n\nclass FftR2C : public OpExprGradFunction<FftR2CCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(FftR2CCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1) << \"RuntimeError: assert `inputs.size() == 1`\";\n    ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ctx->onesided = JUST(attrs.GetAttr<bool>(\"onesided\"));\n    ctx->dims = JUST(attrs.GetAttr<std::vector<int64_t>>(\"dims\"));\n    ctx->norm_mode = JUST(attrs.GetAttr<int32_t>(\"norm_mode\"));\n    ctx->input_shape_vec = JUST(oneflow::VectorAt(inputs, 0))->shape()->dim_vec();\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FftR2CCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1) << \"RuntimeError: assert `out_grads.size() == 1`\";\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    in_grads->resize(1);\n    if (!ctx->onesided) {\n      auto complex_grad = JUST(functional::FftC2C(JUST(oneflow::VectorAt(out_grads, 0)), NullOpt,\n                                                  ctx->dims, ctx->norm_mode,\n                                                  /*forward=*/false, /*normalized=*/false));\n      JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::Real(complex_grad));\n    } else {\n      std::vector<int64_t> fft_dims = ctx->dims;\n      std::vector<int64_t> fft_shapes(fft_dims.size(), 0);\n      FOR_RANGE(size_t, i, 0, fft_dims.size()) {\n        fft_shapes[i] = ctx->input_shape_vec[fft_dims[i]];\n      }\n\n      // fill the last dim\n      bool must_copy = false;\n      auto x_sizes = JUST(oneflow::VectorAt(out_grads, 0))->shape()->dim_vec();\n      std::vector<int64_t> pad_amount(x_sizes.size() * 2, 0);\n      int64_t last_dim = ctx->dims.back();\n      if (x_sizes[last_dim] < ctx->input_shape_vec[last_dim]) {\n        must_copy = true;\n        auto pad_idx = pad_amount.size() - 2 * last_dim - 1;\n        pad_amount[pad_idx] = ctx->input_shape_vec[last_dim] - x_sizes[last_dim];\n      }\n      auto 
complex_full_grad =\n          must_copy\n              ? JUST(functional::ConstantPad(JUST(oneflow::VectorAt(out_grads, 0)), pad_amount, 0))\n              : JUST(oneflow::VectorAt(out_grads, 0));\n      complex_full_grad =\n          JUST(functional::FftC2C(complex_full_grad, NullOpt, ctx->dims, ctx->norm_mode,\n                                  /*forward=*/false, /*normalized=*/false));\n\n      JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::Real(complex_full_grad));\n    }\n\n    return Maybe<void>::Ok();\n  }\n};\n\nstruct FftC2CCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  bool forward = false;\n  std::vector<int64_t> dims;\n  int32_t norm_mode = 0;\n};\n\nclass FftC2C : public OpExprGradFunction<FftC2CCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(FftC2CCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1) << \"RuntimeError: assert `inputs.size() == 1`\";\n\n    ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ctx->forward = JUST(attrs.GetAttr<bool>(\"forward\"));\n    ctx->dims = JUST(attrs.GetAttr<std::vector<int64_t>>(\"dims\"));\n    ctx->norm_mode = JUST(attrs.GetAttr<int32_t>(\"norm_mode\"));\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FftC2CCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1) << \"RuntimeError: assert `out_grads.size() == 1`\";\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    in_grads->resize(1);\n    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::FftC2C(\n        JUST(oneflow::VectorAt(out_grads, 0)), NullOpt, ctx->dims, ctx->norm_mode,\n        /*forward=*/!(ctx->forward), /*normalized=*/false));\n    return Maybe<void>::Ok();\n  }\n};\n\nstruct FftC2RCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  std::vector<int64_t> dims;\n  int32_t norm_mode = 0;\n  int64_t last_dim_size = 1;\n  DimVector input_shape_vec;\n};\n\nclass FftC2R : public OpExprGradFunction<FftC2RCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(FftC2RCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1) << \"RuntimeError: assert `inputs.size() == 1`\";\n    ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ctx->dims = JUST(attrs.GetAttr<std::vector<int64_t>>(\"dims\"));\n    ctx->norm_mode = JUST(attrs.GetAttr<int32_t>(\"norm_mode\"));\n    ctx->last_dim_size = JUST(attrs.GetAttr<int64_t>(\"last_dim_size\"));\n    ctx->input_shape_vec = JUST(oneflow::VectorAt(inputs, 0))->shape()->dim_vec();\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FftC2RCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1) << \"RuntimeError: out_grads.size() == 1\";\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    in_grads->resize(1);\n\n    // NOTE: set `forward` True to prevent 
conjugating result\n    auto complex_grad = JUST(functional::FftR2C(\n        JUST(oneflow::VectorAt(out_grads, 0)), NullOpt, ctx->dims, ctx->norm_mode,\n        /*onesided=*/true, /*forward=*/true, /*normalized=*/false));  // no need conj\n    Shape input_shape(ctx->input_shape_vec);\n    int64_t last_dim = ctx->dims.back();\n    auto double_length =\n        JUST(oneflow::VectorAt(out_grads, 0))->dim(last_dim) - complex_grad->dim(last_dim);\n    auto in_grad = complex_grad;\n\n    // Mul by 2, and slice\n    if (double_length > 0) {\n      in_grad = JUST(functional::Narrow(complex_grad, last_dim, 1,\n                                        double_length));  // will change shape of in_grad\n      in_grad = JUST(functional::ScalarMul(in_grad, 2, /*inplace=*/true));\n    }\n\n    std::vector<int64_t> slice_st(input_shape.size(), 0);\n    std::vector<int64_t> slice_end(input_shape.begin(), input_shape.end());\n    std::vector<int64_t> slice_step(input_shape.size(), 1);\n    auto sliced_tensor =\n        JUST(functional::Slice(complex_grad, slice_st, slice_end, slice_step, false));\n\n    JUST(oneflow::VectorAt(*in_grads, 0)) = sliced_tensor;\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fft_r2c\", FftR2C);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fft_c2c\", FftC2C);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fft_c2r\", FftC2R);\n\n}  // namespace one\n\n}  // namespace oneflow"
  },
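  {
    "path": "oneflow/core/autograd/gradient_funcs/examples/fft_onesided_pad_sketch.cpp",
    "content": "// NOTE: hypothetical example file, not part of the OneFlow source tree.\n// A minimal sketch of the pad_amount layout assumed by FftR2C::Apply above:\n// pads are stored as (left, right) pairs ordered from the LAST dimension\n// backwards, so for an n-dim tensor the right-pad slot of dimension d is\n// 2 * n - 2 * d - 1. Recovering a onesided rfft gradient only needs zeros\n// appended on the right of the last transformed dimension.\n#include <cstdint>\n#include <iostream>\n#include <vector>\n\nint main() {\n  const std::vector<int64_t> grad_shape{4, 3};   // onesided grad: last dim is n/2+1\n  const std::vector<int64_t> input_shape{4, 5};  // original real input\n  std::vector<int64_t> pad_amount(grad_shape.size() * 2, 0);\n\n  const int64_t last_dim = 1;  // the transformed dimension in this example\n  if (grad_shape[last_dim] < input_shape[last_dim]) {\n    const size_t pad_idx = pad_amount.size() - 2 * last_dim - 1;\n    pad_amount[pad_idx] = input_shape[last_dim] - grad_shape[last_dim];\n  }\n  // Prints \"0 2 0 0\": two zero columns appended on the right of dim 1.\n  for (int64_t p : pad_amount) { std::cout << p << ' '; }\n  std::cout << std::endl;\n  return 0;\n}\n"
  },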
  {
    "path": "oneflow/core/autograd/gradient_funcs/fill.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FillCaptureState : public AutoGradCaptureState {\n  bool in_requires_grad = false;\n  bool value_requires_grad = false;\n};\n\nclass Fill : public OpExprGradFunction<FillCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(FillCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const FillCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Fill::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Fill::Capture(FillCaptureState* ctx, const TensorTuple& inputs,\n                          const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->in_requires_grad = inputs[0]->requires_grad();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Fill::Apply(const FillCaptureState* ctx, const TensorTuple& out_grads,\n                        TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1) << \"out_grads.size() must be equal to 1.\";\n  in_grads->resize(1);\n  if (ctx->in_requires_grad) { (*in_grads)[0] = JUST(functional::Fill(out_grads[0], 0)); }\n  return Maybe<void>::Ok();\n}\n\nclass FillTensor : public OpExprGradFunction<FillCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(FillCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const FillCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> FillTensor::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FillTensor::Capture(FillCaptureState* ctx, const TensorTuple& inputs,\n                                const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->in_requires_grad = inputs[0]->requires_grad();\n  ctx->value_requires_grad = inputs[1]->requires_grad();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> 
FillTensor::Apply(const FillCaptureState* ctx, const TensorTuple& out_grads,\n                              TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1) << \"out_grads.size() must be equal to 1.\";\n  in_grads->resize(2);\n  if (ctx->value_requires_grad) {\n    int32_t num_axes = out_grads[0]->shape()->NumAxes();\n    std::vector<int32_t> axes_vec(num_axes);\n    std::iota(axes_vec.begin(), axes_vec.end(), 0);\n    (*in_grads)[1] =\n        JUST(functional::ReduceSum(out_grads[0], axes_vec, /*keepdims=*/false, NullOpt));\n  }\n  if (ctx->in_requires_grad) { (*in_grads)[0] = JUST(functional::Fill(out_grads[0], 0)); }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fill_\", Fill);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fill_tensor_\", FillTensor);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
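  {
    "path": "oneflow/core/autograd/gradient_funcs/examples/fill_tensor_value_grad_sketch.cpp",
    "content": "// NOTE: hypothetical example file, not part of the OneFlow source tree.\n// A minimal sketch of the value gradient in FillTensor::Apply above:\n// fill_tensor_ broadcasts a scalar `value` into every output element, so the\n// gradient w.r.t. value is the output gradient reduced over all axes\n// (axes_vec = 0, 1, ..., num_axes - 1, built with std::iota), which for a\n// dense tensor is simply the total sum of its elements.\n#include <iostream>\n#include <numeric>\n#include <vector>\n\nint main() {\n  const int num_axes = 3;\n  std::vector<int> axes_vec(num_axes);\n  std::iota(axes_vec.begin(), axes_vec.end(), 0);  // {0, 1, 2}\n\n  const std::vector<double> out_grad{0.5, 1.5, -1.0, 2.0};  // flattened grad\n  const double value_grad = std::accumulate(out_grad.begin(), out_grad.end(), 0.0);\n  std::cout << value_grad << std::endl;  // prints 3\n  return 0;\n}\n"
  },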
  {
    "path": "oneflow/core/autograd/gradient_funcs/flatten.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FlattenCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n};\n\nclass Flatten : public OpExprGradFunction<FlattenCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(FlattenCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const FlattenCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n};\n\nMaybe<void> Flatten::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Flatten::Capture(FlattenCaptureState* ctx, const TensorTuple& inputs,\n                             const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  ctx->SaveTensorForBackward(inputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Flatten::Apply(const FlattenCaptureState* ctx, const TensorTuple& out_grads,\n                           TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n  const auto& like = ctx->SavedTensors().at(0);\n  in_grads->resize(1);\n  in_grads->at(0) = JUST(functional::ReshapeLike(out_grads.at(0), like));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"flatten\", Flatten);\n\n}  // namespace one\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/flip.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FlipCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n  std::vector<int32_t> dims;\n};\n\nclass Flip : public OpExprGradFunction<FlipCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(FlipCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const FlipCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Flip::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Flip::Capture(FlipCaptureState* ctx, const TensorTuple& inputs,\n                          const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->dims = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"dims\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Flip::Apply(const FlipCaptureState* ctx, const TensorTuple& out_grads,\n                        TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n  in_grads->resize(1);\n  if (ctx->requires_grad) { (*in_grads)[0] = JUST(functional::Flip(out_grads[0], ctx->dims)); }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"flip\", Flip);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
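  {
    "path": "oneflow/core/autograd/gradient_funcs/examples/flip_grad_sketch.cpp",
    "content": "// NOTE: hypothetical example file, not part of the OneFlow source tree.\n// A minimal 1-D sketch of why Flip::Apply above can reuse the forward op:\n// flip only permutes elements and is its own inverse, so the input gradient\n// is just the output gradient flipped over the same dims.\n#include <algorithm>\n#include <iostream>\n#include <vector>\n\nint main() {\n  std::vector<int> out_grad{1, 2, 3, 4};\n  std::vector<int> in_grad(out_grad);            // grad w.r.t. the flip input\n  std::reverse(in_grad.begin(), in_grad.end());  // == flip(out_grad)\n  for (int g : in_grad) { std::cout << g << ' '; }  // prints 4 3 2 1\n  std::cout << std::endl;\n  return 0;\n}\n"
  },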
  {
    "path": "oneflow/core/autograd/gradient_funcs/fold.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FoldInterpState : public AutoGradCaptureState {\n  bool requires_grad = true;\n  std::string data_format = \"channels_first\";\n  std::vector<int32_t> kernel_size;\n  std::vector<int32_t> dilation_rate;\n  std::vector<int32_t> padding;\n  std::vector<int32_t> strides;\n};\n\nclass Fold : public OpExprGradFunction<FoldInterpState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(FoldInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const FoldInterpState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Fold::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Fold::Capture(FoldInterpState* ctx, const TensorTuple& inputs,\n                          const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n  ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"kernel_size\"));\n  ctx->dilation_rate = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"dilation_rate\"));\n  ctx->padding = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"padding\"));\n  ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"strides\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Fold::Apply(const FoldInterpState* ctx, const TensorTuple& out_grads,\n                        TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n  in_grads->resize(1);\n  in_grads->at(0) = JUST(functional::Unfold(out_grads.at(0), ctx->kernel_size, ctx->dilation_rate,\n                                            ctx->padding, ctx->strides, ctx->data_format));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fold\", Fold);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
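  {
    "path": "oneflow/core/autograd/gradient_funcs/examples/fold_grad_unfold_sketch.cpp",
    "content": "// NOTE: hypothetical example file, not part of the OneFlow source tree.\n// A minimal 1-D sketch of why Fold::Apply above calls functional::Unfold:\n// fold (col2im) sums sliding patches back into an image, and the adjoint of\n// that linear map is unfold (im2col), which re-extracts the patches. Below, a\n// length-4 gradient with kernel 2 and stride 1 yields 3 overlapping patch\n// gradients.\n#include <iostream>\n#include <vector>\n\nint main() {\n  const int kernel = 2, n_patches = 3;  // stride 1, no padding\n  const std::vector<double> fold_out_grad{1.0, 2.0, 3.0, 4.0};\n  std::vector<std::vector<double>> patch_grads(n_patches, std::vector<double>(kernel));\n  for (int p = 0; p < n_patches; ++p) {\n    for (int k = 0; k < kernel; ++k) { patch_grads[p][k] = fold_out_grad[p + k]; }\n  }\n  // Patch p reads positions [p, p + kernel): prints 1,2 then 2,3 then 3,4.\n  for (const auto& pg : patch_grads) { std::cout << pg[0] << ',' << pg[1] << std::endl; }\n  return 0;\n}\n"
  },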
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_bias_add_dropout.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedBiasAddDropoutInterpState : public AutoGradCaptureState {\n  bool input_requires_grad = true;\n  bool bias_requires_grad = true;\n  int32_t axis = 1;\n  float scale = 1.0;\n};\n\nclass FusedBiasAddDropout : public OpExprGradFunction<FusedBiasAddDropoutInterpState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(FusedBiasAddDropoutInterpState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const FusedBiasAddDropoutInterpState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> FusedBiasAddDropout::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedBiasAddDropout::Capture(FusedBiasAddDropoutInterpState* ctx,\n                                         const TensorTuple& inputs, const TensorTuple& outputs,\n                                         const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), 3);\n  ctx->input_requires_grad = inputs.at(0)->requires_grad();  // input\n  ctx->bias_requires_grad = inputs.at(1)->requires_grad();   // bias\n\n  if (!ctx->input_requires_grad && !ctx->bias_requires_grad) { return Maybe<void>::Ok(); }\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->scale = JUST(composed_attrs.GetAttr<float>(\"scale\"));\n  ctx->axis = JUST(composed_attrs.GetAttr<int32_t>(\"axis\"));\n\n  ctx->SaveTensorForBackward(inputs.at(2));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedBiasAddDropout::Apply(const FusedBiasAddDropoutInterpState* ctx,\n                                       const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n  if (!ctx->input_requires_grad && !ctx->bias_requires_grad) { return Maybe<void>::Ok(); }\n\n  // mask have no grad(reqiures_grad=False), but still take a place in in_grads\n  in_grads->resize(3);\n\n  const std::shared_ptr<oneflow::one::Tensor>& mask = ctx->SavedTensors().at(0);\n  const std::shared_ptr<oneflow::one::Tensor>& dropout_grad =\n      JUST(functional::DropoutGrad(out_grads.at(0), mask, ctx->scale));\n\n  if (ctx->input_requires_grad) { in_grads->at(0) = dropout_grad; }\n\n  const int64_t num_axes = out_grads.at(0)->shape()->NumAxes();\n  if (ctx->bias_requires_grad) {\n    std::vector<int32_t> reduce_axes_vec;\n    reduce_axes_vec.reserve(num_axes);\n    for (int i = 0; i < 
num_axes; ++i) {\n      if (i != ctx->axis) { reduce_axes_vec.emplace_back(i); }\n    }\n    in_grads->at(1) = JUST(functional::ReduceSum(dropout_grad, reduce_axes_vec, false, NullOpt));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_bias_add_mask_scale\", FusedBiasAddDropout);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
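  {
    "path": "oneflow/core/autograd/gradient_funcs/examples/bias_grad_reduce_axes_sketch.cpp",
    "content": "// NOTE: hypothetical example file, not part of the OneFlow source tree.\n// A minimal sketch of the bias-gradient reduction in FusedBiasAddDropout::Apply\n// above (FusedBiasAddGelu below uses the same pattern): the bias is broadcast\n// along every axis except `axis`, so its gradient sums the upstream gradient\n// over all the other axes.\n#include <iostream>\n#include <vector>\n\nint main() {\n  const int num_axes = 4;  // e.g. an NCHW activation\n  const int axis = 1;      // the bias (channel) axis\n  std::vector<int> reduce_axes_vec;\n  reduce_axes_vec.reserve(num_axes);\n  for (int i = 0; i < num_axes; ++i) {\n    if (i != axis) { reduce_axes_vec.push_back(i); }\n  }\n  for (int a : reduce_axes_vec) { std::cout << a << ' '; }  // prints 0 2 3\n  std::cout << std::endl;\n  return 0;\n}\n"
  },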
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_bias_add_gelu.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedBiasAddGeluInterpState : public AutoGradCaptureState {\n  bool input_requires_grad = true;\n  bool bias_requires_grad = true;\n  int32_t axis = 1;\n};\n\nclass FusedBiasAddGelu : public OpExprGradFunction<FusedBiasAddGeluInterpState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(FusedBiasAddGeluInterpState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);\n    ctx->input_requires_grad = inputs.at(0)->requires_grad();\n    ctx->bias_requires_grad = inputs.at(1)->requires_grad();\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->axis = JUST(composed_attrs.GetAttr<int32_t>(\"axis\"));\n    if (ctx->input_requires_grad || ctx->bias_requires_grad) {\n      ctx->SaveTensorForBackward(inputs.at(0));\n      ctx->SaveTensorForBackward(inputs.at(1));\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FusedBiasAddGeluInterpState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->input_requires_grad && !ctx->bias_requires_grad) { return Maybe<void>::Ok(); }\n\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n    const int64_t num_axes = out_grads.at(0)->shape()->NumAxes();\n    in_grads->resize(2);\n    const auto& a = ctx->SavedTensors().at(0);\n    const auto& b = ctx->SavedTensors().at(1);\n    const std::shared_ptr<oneflow::one::Tensor>& fused_bias_add_gelu_grad =\n        JUST(functional::FusedBiasAddGeluGrad(a, b, out_grads.at(0), ctx->axis));\n    if (ctx->bias_requires_grad) {\n      std::vector<int32_t> reduce_axes_vec;\n      reduce_axes_vec.reserve(num_axes);\n      for (int i = 0; i < num_axes; ++i) {\n        if (i != ctx->axis) { reduce_axes_vec.emplace_back(i); }\n      }\n      in_grads->at(1) =\n          JUST(functional::ReduceSum(fused_bias_add_gelu_grad, reduce_axes_vec, false, NullOpt));\n    }\n    if (ctx->input_requires_grad) { in_grads->at(0) = fused_bias_add_gelu_grad; }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_bias_add_gelu\", FusedBiasAddGelu);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_bias_add_scale_mask_softmax_dropout.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedBiasAddScaleMaskSoftmaxDropoutCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad = false;\n  bool bias_requires_grad = false;\n  bool bias_broadcast = false;\n  int softmax_y_index = -1;\n  int bias_index = -1;\n  int mask_index = -1;\n  int dropout_mask_index = -1;\n  float scale = 1.0;\n  float dropout_scale = 1.0;\n};\n\nclass FusedBiasAddScaleMaskSoftmaxDropoutGradFunction\n    : public OpExprGradFunction<FusedBiasAddScaleMaskSoftmaxDropoutCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(FusedBiasAddScaleMaskSoftmaxDropoutCaptureState* ctx,\n                      const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(outputs.size(), 2);  // (y, softmax_y)\n    CHECK_EQ_OR_RETURN(inputs.size(), 4);   // (x, bias, mask, dropout_mask)\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    ctx->bias_requires_grad = inputs.at(1)->requires_grad();\n\n    if (!ctx->x_requires_grad && !ctx->bias_requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->scale = JUST(composed_attrs.GetAttr<float>(\"scale_value\"));\n    ctx->dropout_scale = JUST(composed_attrs.GetAttr<float>(\"dropout_scale_value\"));\n\n    if (ctx->x_requires_grad) {\n      ctx->mask_index = ctx->SaveTensorForBackward(inputs.at(2));          // mask\n      ctx->dropout_mask_index = ctx->SaveTensorForBackward(inputs.at(3));  // dropout_mask\n      ctx->softmax_y_index = ctx->SaveTensorForBackward(outputs.at(1));    // softmax_y\n    }\n\n    if (ctx->bias_requires_grad) {\n      ctx->bias_broadcast = (inputs.at(0)->shape() != inputs.at(1)->shape());\n      if (ctx->bias_broadcast) {\n        ctx->bias_index = ctx->SaveTensorForBackward(inputs.at(1));  // bias\n      }\n    }\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FusedBiasAddScaleMaskSoftmaxDropoutCaptureState* ctx,\n                    const TensorTuple& out_grads, TensorTuple* in_grads) const override {\n    if (!ctx->x_requires_grad && !ctx->bias_requires_grad) { return Maybe<void>::Ok(); }\n\n    CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // (dy, d_softmax_y)\n    in_grads->resize(4);                      // (x, bias, mask, dropout_mask)\n\n    const auto& saved_tensors = ctx->SavedTensors();\n    const auto& dy = out_grads.at(0);\n    CHECK_GE_OR_RETURN(saved_tensors.size(), 3);  // (mask, dropout_mask, softmax_y, [bias])\n\n    if (ctx->x_requires_grad || ctx->bias_requires_grad) {\n      
const auto& mask = saved_tensors.at(ctx->mask_index);\n      const auto& dropout_mask = saved_tensors.at(ctx->dropout_mask_index);\n      const auto& softmax_y = saved_tensors.at(ctx->softmax_y_index);\n      in_grads->at(0) = JUST(functional::FusedScaleMaskSoftmaxDropoutGrad(\n          softmax_y, dy, mask, dropout_mask, ctx->scale, ctx->dropout_scale));\n    }\n\n    if (ctx->bias_requires_grad) {\n      if (ctx->bias_broadcast) {\n        const auto& bias = saved_tensors.at(ctx->bias_index);\n        in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(in_grads->at(0), bias));\n      } else {\n        in_grads->at(1) = in_grads->at(0);\n      }\n    }\n\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_bias_add_scale_mask_softmax_dropout\",\n                               FusedBiasAddScaleMaskSoftmaxDropoutGradFunction);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_center.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nconst int32_t INPUT_LEN = 8;\nstruct FusedCenterCaptureState : public AutoGradCaptureState {\n  std::vector<bool> requires_grad;\n};\n\nclass FusedCenterGrad : public OpExprGradFunction<FusedCenterCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(FusedCenterCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN);\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);\n    for (int i = 0; i < INPUT_LEN; i++) {\n      ctx->requires_grad.push_back(inputs.at(i)->requires_grad());\n    }\n    for (int i = 0; i < INPUT_LEN; i++) { ctx->SaveTensorForBackward(inputs.at(i)); }\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FusedCenterCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n    const auto& rho2_diff = out_grads.at(0);\n\n    const auto& b1_x1 = ctx->SavedTensors().at(0);\n    const auto& b1_x2 = ctx->SavedTensors().at(1);\n    const auto& b2_x1 = ctx->SavedTensors().at(2);\n    const auto& b2_x2 = ctx->SavedTensors().at(3);\n    const auto& b1_y1 = ctx->SavedTensors().at(4);\n    const auto& b1_y2 = ctx->SavedTensors().at(5);\n    const auto& b2_y1 = ctx->SavedTensors().at(6);\n    const auto& b2_y2 = ctx->SavedTensors().at(7);\n\n    in_grads->resize(INPUT_LEN);\n    auto result = JUST(functional::FusedCenterGrad(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1,\n                                                   b2_y2, rho2_diff));\n\n    CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN);\n    for (int i = 0; i < INPUT_LEN; i++) {\n      if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); }\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_get_center_dist\", FusedCenterGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
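  {
    "path": "oneflow/core/autograd/gradient_funcs/examples/fused_grad_dispatch_sketch.cpp",
    "content": "// NOTE: hypothetical example file, not part of the OneFlow source tree.\n// A minimal sketch of the capture/apply pattern shared by FusedCenterGrad above\n// and the fused_get_* gradients that follow: record one requires_grad flag per\n// input at capture time, let the fused backward kernel produce a gradient for\n// every input, then copy back only the gradients that are actually required.\n#include <iostream>\n#include <optional>\n#include <vector>\n\nint main() {\n  const std::vector<bool> requires_grad{true, false, true};  // captured flags\n  const std::vector<double> fused_results{0.1, 0.2, 0.3};    // all grads computed\n  std::vector<std::optional<double>> in_grads(fused_results.size());\n  for (size_t i = 0; i < fused_results.size(); ++i) {\n    if (requires_grad[i]) { in_grads[i] = fused_results[i]; }\n  }\n  for (const auto& g : in_grads) { std::cout << (g ? *g : 0.0) << ' '; }  // 0.1 0 0.3\n  std::cout << std::endl;\n  return 0;\n}\n"
  },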
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_cross_interaction.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedCrossFeatureInteractionInterpState : public AutoGradCaptureState {\n  bool x_requires_grad = true;\n  bool weight_requires_grad = true;\n  bool x0_requires_grad = true;\n  bool bias_requires_grad = true;\n  size_t x_idx = 0;\n  size_t bias_idx = 0;\n  size_t weight_idx = 0;\n  size_t x0_idx = 0;\n  size_t matmul_result_idx = 0;\n  std::string interaction_mode;\n};\n\nclass FusedCrossFeatureInteraction\n    : public OpExprGradFunction<FusedCrossFeatureInteractionInterpState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr) << \"fw_op_expr should not be None. \";\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(FusedCrossFeatureInteractionInterpState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 4) << \"Input size should be equal to 4. \";\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->interaction_mode = JUST(composed_attrs.GetAttr<std::string>(\"interaction_mode\"));\n    ctx->x_requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();\n    ctx->weight_requires_grad = JUST(oneflow::VectorAt(inputs, 1))->requires_grad();\n    ctx->x_requires_grad = JUST(oneflow::VectorAt(inputs, 2))->requires_grad();\n    ctx->weight_requires_grad = JUST(oneflow::VectorAt(inputs, 3))->requires_grad();\n    ctx->x_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0)));\n    ctx->weight_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 1)));\n    ctx->x0_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 2)));\n    if (ctx->interaction_mode == \"matrix\") {\n      ctx->bias_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 3)));\n    }\n    ctx->matmul_result_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 1)));\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FusedCrossFeatureInteractionInterpState* ctx,\n                    const TensorTuple& out_grads, TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 2) << \"Out grads size should be equal to 2. 
\";\n    std::shared_ptr<oneflow::one::TensorTuple> grads;\n    in_grads->resize(4);\n    if (ctx->interaction_mode == \"vector\") {\n      grads = JUST(functional::FusedCrossFeatureInteractionV1Grad(\n          JUST(oneflow::VectorAt(out_grads, 0)),\n          JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->weight_idx)),\n          JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->x_idx)),\n          JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->x0_idx)),\n          JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->matmul_result_idx))));\n    } else if (ctx->interaction_mode == \"matrix\") {\n      grads = JUST(functional::FusedCrossFeatureInteractionV2Grad(\n          JUST(oneflow::VectorAt(out_grads, 0)),\n          JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->weight_idx)),\n          JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->bias_idx)),\n          JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->x_idx)),\n          JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->x0_idx)),\n          JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->matmul_result_idx))));\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"Interaction mode only support `vector` and `matrix`. \";\n    }\n\n    if (ctx->x_requires_grad) {\n      JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(oneflow::VectorAt(*grads, 0));\n    }\n    if (ctx->weight_requires_grad) {\n      JUST(oneflow::VectorAt(*in_grads, 1)) = JUST(oneflow::VectorAt(*grads, 1));\n    }\n    if (ctx->x0_requires_grad) {\n      JUST(oneflow::VectorAt(*in_grads, 2)) = JUST(oneflow::VectorAt(*grads, 2));\n    }\n    if (ctx->bias_requires_grad) {\n      JUST(oneflow::VectorAt(*in_grads, 3)) = JUST(oneflow::VectorAt(*grads, 3));\n    }\n\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_cross_feature_interaction\", FusedCrossFeatureInteraction);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_dot_feature_interaction.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedDotFeatureInteractionCaptureState : public AutoGradCaptureState {\n  bool need_grad_op = false;\n  std::vector<bool> features_requires_grad;\n  std::vector<int32_t> feature_dims;\n  int32_t output_concat_grad_dim = 0;\n  bool self_interaction = false;\n  bool has_output_concat = false;\n  bool has_output_concat_grad = false;\n  std::string pooling;\n};\n\nclass FusedDotFeatureInteraction\n    : public OpExprGradFunction<FusedDotFeatureInteractionCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(FusedDotFeatureInteractionCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const FusedDotFeatureInteractionCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> FusedDotFeatureInteraction::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedDotFeatureInteraction::Capture(FusedDotFeatureInteractionCaptureState* ctx,\n                                                const TensorTuple& inputs,\n                                                const TensorTuple& outputs,\n                                                const AttrMap& attrs) const {\n  ctx->has_output_concat = JUST(attrs.GetAttr<bool>(\"has_output_concat\"));\n  int32_t num_features = 0;\n  if (ctx->has_output_concat) {\n    num_features = inputs.size() - 1;\n    const auto& output_concat = JUST(oneflow::VectorAt(inputs, num_features));\n    ctx->has_output_concat_grad = output_concat->requires_grad();\n    ctx->output_concat_grad_dim = output_concat->shape()->At(1);\n  } else {\n    num_features = inputs.size();\n  }\n  if (ctx->has_output_concat_grad) { ctx->need_grad_op = true; }\n  ctx->features_requires_grad.resize(num_features);\n  ctx->feature_dims.resize(num_features);\n  for (int32_t i = 0; i < num_features; ++i) {\n    const auto& feature = JUST(oneflow::VectorAt(inputs, i));\n    ctx->features_requires_grad[i] = feature->requires_grad();\n    ctx->feature_dims[i] = feature->shape()->At(1);\n    if (feature->requires_grad()) { ctx->need_grad_op = true; }\n    ctx->SaveTensorForBackward(feature);\n  }\n  ctx->pooling = JUST(attrs.GetAttr<std::string>(\"pooling\"));\n  if (!ctx->need_grad_op) { return Maybe<void>::Ok(); }\n  ctx->self_interaction = 
JUST(attrs.GetAttr<bool>(\"self_interaction\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedDotFeatureInteraction::Apply(const FusedDotFeatureInteractionCaptureState* ctx,\n                                              const TensorTuple& out_grads,\n                                              TensorTuple* in_grads) const {\n  if (!ctx->need_grad_op) { return Maybe<void>::Ok(); }\n  int32_t num_features = ctx->features_requires_grad.size();\n  in_grads->resize(num_features + 1);\n  TensorTuple features(num_features);\n  for (int i = 0; i < num_features; ++i) {\n    features[i] = JUST(oneflow::VectorAt(ctx->SavedTensors(), i));\n  }\n  std::shared_ptr<oneflow::one::TensorTuple> grads;\n  grads = JUST(functional::FusedDotFeatureInteractionGrad(\n      JUST(oneflow::VectorAt(out_grads, 0)), features, ctx->has_output_concat,\n      ctx->self_interaction, ctx->output_concat_grad_dim, ctx->pooling));\n  for (int32_t i = 0; i < num_features; ++i) {\n    if (JUST(oneflow::VectorAt(ctx->features_requires_grad, i))) {\n      JUST(oneflow::VectorAt(*in_grads, i)) = JUST(oneflow::VectorAt(*grads, i));\n    }\n  }\n  if (ctx->has_output_concat_grad) {\n    JUST(oneflow::VectorAt(*in_grads, num_features)) =\n        JUST(oneflow::VectorAt(*grads, num_features));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_dot_feature_interaction\", FusedDotFeatureInteraction);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
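  {
    "path": "oneflow/core/autograd/gradient_funcs/examples/dot_interaction_inputs_sketch.cpp",
    "content": "// NOTE: hypothetical example file, not part of the OneFlow source tree.\n// A minimal sketch of the input bookkeeping in\n// FusedDotFeatureInteraction::Capture above: when has_output_concat is set,\n// the last input tensor is the extra output-concat operand, so only the first\n// inputs.size() - 1 entries are interaction features.\n#include <iostream>\n#include <string>\n#include <vector>\n\nint main() {\n  const std::vector<std::string> inputs{\"feat0\", \"feat1\", \"feat2\", \"output_concat\"};\n  const bool has_output_concat = true;\n  const int num_features = has_output_concat ? inputs.size() - 1 : inputs.size();\n  for (int i = 0; i < num_features; ++i) { std::cout << inputs[i] << std::endl; }\n  std::cout << \"num_features: \" << num_features << std::endl;  // num_features: 3\n  return 0;\n}\n"
  },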
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_fast_gelu_mul.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedFastGeluMulGradCaptureState : public AutoGradCaptureState {\n  bool requires_grad = true;\n};\n\nclass FusedFastGeluMulGrad : public OpExprGradFunction<FusedFastGeluMulGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(FusedFastGeluMulGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // (in, multiplier)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // (out,)\n    ctx->requires_grad = inputs.at(0)->requires_grad() || inputs.at(1)->requires_grad();\n    if (ctx->requires_grad) {\n      ctx->SaveTensorForBackward(inputs.at(0));  // in\n      ctx->SaveTensorForBackward(inputs.at(1));  // multiplier\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FusedFastGeluMulGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n    const auto& out_diff = out_grads.at(0);\n\n    const auto& saved_tensors = ctx->SavedTensors();\n    CHECK_EQ_OR_RETURN(saved_tensors.size(), 2);\n    const auto& in = saved_tensors.at(0);\n    const auto& multiplier = saved_tensors.at(1);\n\n    in_grads->resize(2);  // (in_diff, multiplier_diff)\n    auto result = JUST(functional::FusedFastGeluMulGrad(out_diff, in, multiplier));\n    CHECK_EQ_OR_RETURN(result->size(), 2);\n    in_grads->at(0) = result->at(0);\n    in_grads->at(1) = result->at(1);\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_fast_gelu_mul\", FusedFastGeluMulGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_get_boundding_boxes_coord.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <vector>\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nconst int32_t INPUT_LEN = 8;\nstruct FusedGetBounddingBoxesCoordGradCaptureState : public AutoGradCaptureState {\n  std::vector<bool> requires_grad;\n};\n\nclass FusedGetBounddingBoxesCoordGrad\n    : public OpExprGradFunction<FusedGetBounddingBoxesCoordGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(FusedGetBounddingBoxesCoordGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN);\n    CHECK_EQ_OR_RETURN(outputs.size(), INPUT_LEN);\n    for (int i = 0; i < INPUT_LEN; i++) {\n      ctx->requires_grad.push_back(inputs.at(i)->requires_grad());\n    }\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FusedGetBounddingBoxesCoordGradCaptureState* ctx,\n                    const TensorTuple& out_grads, TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), INPUT_LEN);\n    const auto& b1_x1_diff = out_grads.at(0);\n    const auto& b1_x2_diff = out_grads.at(1);\n    const auto& b1_y1_diff = out_grads.at(2);\n    const auto& b1_y2_diff = out_grads.at(3);\n    const auto& b2_x1_diff = out_grads.at(4);\n    const auto& b2_x2_diff = out_grads.at(5);\n    const auto& b2_y1_diff = out_grads.at(6);\n    const auto& b2_y2_diff = out_grads.at(7);\n\n    in_grads->resize(8);\n    auto result = JUST(functional::FusedGetBounddingBoxesCoordGrad(\n        b1_x1_diff, b1_x2_diff, b1_y1_diff, b1_y2_diff, b2_x1_diff, b2_x2_diff, b2_y1_diff,\n        b2_y2_diff));\n    CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN);\n    for (int i = 0; i < result->size(); i++) {\n      if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); }\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_get_boundding_boxes_coord\", FusedGetBounddingBoxesCoordGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_get_ciou_diagonal_angle.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <vector>\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nconst int32_t INPUT_LEN = 4;\nstruct FusedCiouAngleCaptureState : public AutoGradCaptureState {\n  std::vector<bool> requires_grad;\n  float eps = 1e-8;\n};\n\nclass FusedGetCiouDiagonalAngleGrad : public OpExprGradFunction<FusedCiouAngleCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(FusedCiouAngleCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN);\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);\n\n    for (int i = 0; i < INPUT_LEN; i++) {\n      ctx->requires_grad.push_back(inputs.at(i)->requires_grad());\n    }\n    for (int i = 0; i < INPUT_LEN; i++) { ctx->SaveTensorForBackward(inputs.at(i)); }\n\n    ComposedAttrMap composed_attrs(attrs);\n    ctx->eps = JUST(composed_attrs.GetAttr<float>(\"eps\"));\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FusedCiouAngleCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n    const auto& v_diff = out_grads.at(0);\n\n    const auto& w1 = ctx->SavedTensors().at(0);\n    const auto& h1 = ctx->SavedTensors().at(1);\n    const auto& w2 = ctx->SavedTensors().at(2);\n    const auto& h2 = ctx->SavedTensors().at(3);\n\n    auto result = JUST(functional::FusedGetCiouDiagonalAngleGrad(w1, h1, w2, h2, v_diff, ctx->eps));\n    CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN);\n\n    in_grads->resize(INPUT_LEN);\n    for (int i = 0; i < INPUT_LEN; i++) {\n      if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); }\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_get_ciou_diagonal_angle\", FusedGetCiouDiagonalAngleGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_get_ciou_result.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <vector>\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedGetCiouResultGradCaptureState : public AutoGradCaptureState {\n  bool v_requires_grad = false;\n  bool iou_requires_grad = false;\n  bool rho2_requires_grad = false;\n  bool c2_requires_grad = false;\n};\n\nclass FusedGetCiouResultGrad : public OpExprGradFunction<FusedGetCiouResultGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(FusedGetCiouResultGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 4);\n    CHECK_EQ_OR_RETURN(outputs.size(), 2);\n    ctx->v_requires_grad = inputs.at(0)->requires_grad();\n    ctx->iou_requires_grad = inputs.at(1)->requires_grad();\n    ctx->rho2_requires_grad = inputs.at(2)->requires_grad();\n    ctx->c2_requires_grad = inputs.at(3)->requires_grad();\n    if (ctx->v_requires_grad && ctx->iou_requires_grad && ctx->rho2_requires_grad\n        && ctx->c2_requires_grad) {\n      ctx->SaveTensorForBackward(outputs.at(1));  // alpha\n      ctx->SaveTensorForBackward(inputs.at(2));   // rho2\n      ctx->SaveTensorForBackward(inputs.at(3));   // c2\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FusedGetCiouResultGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 2);\n    const auto& dy = out_grads.at(0);\n\n    const auto& saved_tensors = ctx->SavedTensors();\n    CHECK_EQ_OR_RETURN(saved_tensors.size(), 3);\n    const auto& alpha = saved_tensors.at(0);\n    const auto& rho2 = saved_tensors.at(1);\n    const auto& c2 = saved_tensors.at(2);\n\n    in_grads->resize(4);\n    auto result = JUST(functional::FusedGetCiouResultGrad(dy, alpha, rho2, c2));\n    CHECK_EQ_OR_RETURN(result->size(), 4);\n    if (ctx->v_requires_grad && ctx->iou_requires_grad && ctx->rho2_requires_grad\n        && ctx->c2_requires_grad) {\n      in_grads->at(0) = result->at(0);\n      in_grads->at(1) = result->at(1);\n      in_grads->at(2) = result->at(2);\n      in_grads->at(3) = result->at(3);\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_get_ciou_result\", FusedGetCiouResultGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_get_convex_diagonal_squared.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nconst int32_t INPUT_LEN = 8;\nstruct FusedGetConvexDiagonalSquaredCaptureState : public AutoGradCaptureState {\n  std::vector<bool> requires_grad;\n  float eps = 1e-8;\n};\n\nclass FusedGetConvexDiagonalSquaredGrad\n    : public OpExprGradFunction<FusedGetConvexDiagonalSquaredCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(FusedGetConvexDiagonalSquaredCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN);\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);\n    for (int i = 0; i < INPUT_LEN; i++) {\n      ctx->requires_grad.push_back(inputs.at(i)->requires_grad());\n    }\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->eps = JUST(composed_attrs.GetAttr<float>(\"eps\"));\n    for (int i = 0; i < INPUT_LEN; i++) { ctx->SaveTensorForBackward(inputs.at(i)); }\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FusedGetConvexDiagonalSquaredCaptureState* ctx,\n                    const TensorTuple& out_grads, TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n    const auto& c2_diff = out_grads.at(0);\n\n    const auto& b1_x1 = ctx->SavedTensors().at(0);\n    const auto& b1_x2 = ctx->SavedTensors().at(1);\n    const auto& b2_x1 = ctx->SavedTensors().at(2);\n    const auto& b2_x2 = ctx->SavedTensors().at(3);\n    const auto& b1_y1 = ctx->SavedTensors().at(4);\n    const auto& b1_y2 = ctx->SavedTensors().at(5);\n    const auto& b2_y1 = ctx->SavedTensors().at(6);\n    const auto& b2_y2 = ctx->SavedTensors().at(7);\n\n    in_grads->resize(INPUT_LEN);\n    auto result = JUST(functional::FusedGetConvexDiagonalSquaredGrad(\n        c2_diff, b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, ctx->eps));\n\n    CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN);\n    for (int i = 0; i < INPUT_LEN; i++) {\n      if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); }\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_get_convex_diagonal_squared\",\n                               FusedGetConvexDiagonalSquaredGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_get_intersection_area.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nconst int32_t INPUT_LEN = 8;\nstruct FusedGetIntersectionAreaCaptureState : public AutoGradCaptureState {\n  std::vector<bool> requires_grad;\n};\n\nclass FusedGetIntersectionAreaGrad\n    : public OpExprGradFunction<FusedGetIntersectionAreaCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(FusedGetIntersectionAreaCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN);\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);\n    for (int i = 0; i < INPUT_LEN; i++) {\n      ctx->requires_grad.push_back(inputs.at(i)->requires_grad());\n    }\n    for (int i = 0; i < INPUT_LEN; i++) { ctx->SaveTensorForBackward(inputs.at(i)); }\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FusedGetIntersectionAreaCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n    const auto& rho2_diff = out_grads.at(0);\n\n    const auto& b1_x1 = ctx->SavedTensors().at(0);\n    const auto& b1_x2 = ctx->SavedTensors().at(1);\n    const auto& b2_x1 = ctx->SavedTensors().at(2);\n    const auto& b2_x2 = ctx->SavedTensors().at(3);\n    const auto& b1_y1 = ctx->SavedTensors().at(4);\n    const auto& b1_y2 = ctx->SavedTensors().at(5);\n    const auto& b2_y1 = ctx->SavedTensors().at(6);\n    const auto& b2_y2 = ctx->SavedTensors().at(7);\n\n    in_grads->resize(INPUT_LEN);\n    auto result = JUST(functional::FusedGetIntersectionAreaGrad(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1,\n                                                                b1_y2, b2_y1, b2_y2, rho2_diff));\n\n    CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN);\n    for (int i = 0; i < INPUT_LEN; i++) {\n      if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); }\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_get_intersection_area\", FusedGetIntersectionAreaGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_get_iou.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <vector>\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/placed_nd_sbp.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedGetIouGradCaptureState : public AutoGradCaptureState {\n  bool requires_grad = true;\n  float eps = 1e-8;\n};\n\nclass FusedGetIouGrad : public OpExprGradFunction<FusedGetIouGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(FusedGetIouGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 5);\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);\n    ctx->requires_grad = inputs.at(0)->requires_grad() && inputs.at(1)->requires_grad()\n                         && inputs.at(4)->requires_grad();\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->eps = JUST(composed_attrs.GetAttr<float>(\"eps\"));\n    if (ctx->requires_grad) {\n      ctx->SaveTensorForBackward(inputs.at(0));  // w1\n      ctx->SaveTensorForBackward(inputs.at(1));  // h1\n      ctx->SaveTensorForBackward(inputs.at(2));  // w2\n      ctx->SaveTensorForBackward(inputs.at(3));  // h2\n      ctx->SaveTensorForBackward(inputs.at(4));  // inter\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FusedGetIouGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n    const auto& diou = out_grads.at(0);\n\n    const auto& saved_tensors = ctx->SavedTensors();\n    CHECK_EQ_OR_RETURN(saved_tensors.size(), 5);\n    const auto& w1 = saved_tensors.at(0);\n    const auto& h1 = saved_tensors.at(1);\n    const auto& w2 = saved_tensors.at(2);\n    const auto& h2 = saved_tensors.at(3);\n    const auto& inter = saved_tensors.at(4);\n\n    in_grads->resize(5);\n    auto result = JUST(functional::FusedGetIouGrad(diou, w1, h1, w2, h2, inter, ctx->eps));\n    CHECK_EQ_OR_RETURN(result->size(), 3);\n    if (ctx->requires_grad) {\n      in_grads->at(0) = result->at(0);\n      in_grads->at(1) = result->at(1);\n      in_grads->at(4) = result->at(2);\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_get_iou\", FusedGetIouGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_glu.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedGluGradCaptureState : public AutoGradCaptureState {\n  bool is_split_mode = false;\n  bool has_bias = false;\n  std::string activation = \"none\";\n  bool w_requires_grad = false;\n  bool v_requires_grad = false;\n  bool b_requires_grad = false;\n  bool c_requires_grad = false;\n};\n\nclass FusedGluGrad : public OpExprGradFunction<FusedGluGradCaptureState> {\n  Maybe<void> Init(const OpExpr& op) override;\n\n  Maybe<void> Capture(FusedGluGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n\n  Maybe<void> Apply(const FusedGluGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> FusedGluGrad::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGluGrad::Capture(FusedGluGradCaptureState* ctx, const TensorTuple& inputs,\n                                  const TensorTuple& outputs, const AttrMap& attrs) const {\n  // check input size\n  const size_t in_size = inputs.size();\n  CHECK_OR_RETURN(in_size == 2 || in_size == 3 || in_size == 5)\n      << \"FusedGluGrad::Capture(): input tensor size must be 2 or 3 or 5\";\n\n  // check the input pattern:\n  ctx->has_bias = JUST(attrs.GetAttr<bool>(\"has_bias\"));\n  ctx->is_split_mode = JUST(attrs.GetAttr<bool>(\"is_split\"));\n\n  // check whether input tensors need grad\n  ctx->w_requires_grad = inputs[1]->requires_grad();\n  if (ctx->has_bias) {\n    ctx->b_requires_grad = inputs[2]->requires_grad();\n    if (ctx->is_split_mode) {\n      ctx->v_requires_grad = inputs[3]->requires_grad();\n      ctx->c_requires_grad = inputs[4]->requires_grad();\n    }\n  } else {\n    if (ctx->is_split_mode) { ctx->v_requires_grad = inputs[2]->requires_grad(); }\n  }\n\n  // save tensors for backward\n  ctx->SaveTensorForBackward(inputs[0]);   // x\n  ctx->SaveTensorForBackward(outputs[1]);  // matmul_wx\n  if (ctx->is_split_mode) {\n    ctx->SaveTensorForBackward(outputs[2]);  // matmul_vx\n  }\n\n  // save activation type\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->activation = JUST(composed_attrs.GetAttr<std::string>(\"activation\"));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGluGrad::Apply(const FusedGluGradCaptureState* ctx, const TensorTuple& out_grads,\n                                TensorTuple* in_grads) const {\n  // obtain saved tensors from forward process\n  const auto& x = ctx->SavedTensors()[0];\n  const 
auto& matmul_wx = ctx->SavedTensors()[1];\n\n  // obtain gradient dy\n  const auto& dy = out_grads[0];\n\n  if (ctx->is_split_mode) {\n    // obtain saved optional tensor from forward process\n    const auto& matmul_vx = ctx->SavedTensors()[2];\n\n    if (ctx->w_requires_grad or ctx->b_requires_grad or ctx->v_requires_grad\n        or ctx->c_requires_grad) {\n      // calculate the intermediate gradient using fused kernel\n      const auto& middle_results =\n          JUST(functional::FusedGluWithoutLinearGrad(dy, matmul_wx, matmul_vx, ctx->activation));\n      const auto& d_matmul_wx = (*middle_results)[0];\n      const auto& d_matmul_vx = (*middle_results)[1];\n\n      // calculate the final gradient result of w (if necessary)\n      if (ctx->w_requires_grad) {\n        (*in_grads)[1] = JUST(functional::BroadcastMatmulGradB(d_matmul_wx, x, 1.0));\n      }\n\n      // calculate the final gradient result of b (if necessary)\n      if (ctx->b_requires_grad) {\n        const int64_t num_axes = d_matmul_wx->shape()->NumAxes();\n        std::vector<int32_t> reduce_axes_vec;\n        reduce_axes_vec.reserve(num_axes - 1);\n        for (int i = 0; i < num_axes - 1; i++) { reduce_axes_vec.push_back(i); }\n\n        (*in_grads)[2] = JUST(functional::ReduceSum(d_matmul_wx, reduce_axes_vec, false, NullOpt));\n      }\n\n      // calculate the final gradient result of v (if necessary)\n      if (ctx->v_requires_grad) {\n        if (ctx->has_bias) {\n          (*in_grads)[3] = JUST(functional::BroadcastMatmulGradB(d_matmul_vx, x, 1.0));\n        } else {\n          (*in_grads)[2] = JUST(functional::BroadcastMatmulGradB(d_matmul_vx, x, 1.0));\n        }\n      }\n\n      // calculate the final gradient result of c (if necessary)\n      if (ctx->c_requires_grad) {\n        const int64_t num_axes = d_matmul_vx->shape()->NumAxes();\n        std::vector<int32_t> reduce_axes_vec;\n        reduce_axes_vec.reserve(num_axes - 1);\n        for (int i = 0; i < num_axes - 1; i++) { reduce_axes_vec.push_back(i); }\n\n        (*in_grads)[4] = JUST(functional::ReduceSum(d_matmul_vx, reduce_axes_vec, false, NullOpt));\n      }\n    }\n  } else {\n    if (ctx->w_requires_grad or ctx->b_requires_grad) {\n      // calculate the intermediate gradient using fused kernel\n      const auto& middle_results =\n          JUST(functional::FusedGluWithoutLinearGrad(dy, matmul_wx, nullptr, ctx->activation));\n      const auto& d_matmul_wx = (*middle_results)[0];\n\n      // calculate the final gradient result of w (if necessary)\n      if (ctx->w_requires_grad) {\n        (*in_grads)[1] = JUST(functional::BroadcastMatmulGradB(d_matmul_wx, x, 1.0));\n      }\n\n      // calculate the final gradient result of b (if necessary)\n      if (ctx->b_requires_grad) {\n        const int64_t num_axes = d_matmul_wx->shape()->NumAxes();\n        std::vector<int32_t> reduce_axes_vec;\n        reduce_axes_vec.reserve(num_axes - 1);\n        for (int i = 0; i < num_axes - 1; i++) { reduce_axes_vec.push_back(i); }\n\n        (*in_grads)[2] = JUST(functional::ReduceSum(d_matmul_wx, reduce_axes_vec, false, NullOpt));\n      }\n    }\n  }\n\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_glu\", FusedGluGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_gru_cell.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedGruCellGradCaptureState : public AutoGradCaptureState {\n  bool has_bias = true;\n  bool hx_needs_grad = true;\n};\n\nclass FusedGruCellGrad : public OpExprGradFunction<FusedGruCellGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr) << \"FusedGruCellGrad::Init forward op expr is null.\";\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(FusedGruCellGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    const size_t in_size = inputs.size();\n    CHECK_OR_RETURN(in_size == 3 || in_size == 5)\n        << \"FusedGruCellGrad::Capture(): input tensor size must be 3 or 5\";\n    ctx->has_bias = in_size == 5;\n    ctx->hx_needs_grad = inputs[2]->requires_grad();\n    ctx->SaveTensorForBackward(outputs[1]);  // workspace\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FusedGruCellGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    const auto& workspace = ctx->SavedTensors()[0];  // workspace\n    const auto& grad_hy = out_grads[0];\n    const auto& results =\n        JUST(functional::FusedGruCellGrad(grad_hy, workspace, ctx->has_bias, ctx->hx_needs_grad));\n\n    if (ctx->has_bias) {\n      in_grads->resize(5);\n    } else {\n      in_grads->resize(3);\n    }\n    (*in_grads)[0] = (*results)[0];\n    (*in_grads)[1] = (*results)[1];\n\n    if (ctx->hx_needs_grad) { (*in_grads)[2] = (*results)[2]; }\n\n    if (ctx->has_bias) {\n      if (ctx->hx_needs_grad) {\n        (*in_grads)[3] = (*results)[3];\n        (*in_grads)[4] = (*results)[4];\n      } else {\n        (*in_grads)[3] = (*results)[2];\n        (*in_grads)[4] = (*results)[3];\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_gru_cell\", FusedGruCellGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_lstm_cell.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedLstmCellGradCaptureState : public AutoGradCaptureState {\n  bool has_bias = true;\n  bool need_grad_cx = true;\n};\n\nclass FusedLstmCellGrad : public OpExprGradFunction<FusedLstmCellGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr) << \"FusedLstmCellGrad::Init forward op expr is null.\";\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(FusedLstmCellGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    const size_t in_size = inputs.size();\n    CHECK_OR_RETURN(in_size == 3 || in_size == 5)\n        << \"FusedLstmCellGrad::Capture(): input tensor size must be 3 or 5\";\n    ctx->has_bias = in_size == 5;\n    ctx->need_grad_cx = inputs[2]->requires_grad();\n    ctx->SaveTensorForBackward(inputs[2]);   // cx\n    ctx->SaveTensorForBackward(outputs[1]);  // cy\n    ctx->SaveTensorForBackward(outputs[2]);  // workspace\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FusedLstmCellGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    const auto& cx = ctx->SavedTensors()[0];         // cx\n    const auto& cy = ctx->SavedTensors()[1];         // cy\n    const auto& workspace = ctx->SavedTensors()[2];  // workspace\n\n    const auto& grad_hy = out_grads[0];\n    const auto& grad_cy = out_grads[1];\n\n    const auto& results = JUST(functional::FusedLstmCellGrad(grad_hy, grad_cy, cx, cy, workspace,\n                                                             ctx->need_grad_cx, ctx->has_bias));\n\n    if (ctx->has_bias) {\n      in_grads->resize(5);\n    } else {\n      in_grads->resize(3);\n    }\n    (*in_grads)[0] = (*results)[0];\n    (*in_grads)[1] = (*results)[0];\n\n    if (ctx->need_grad_cx) { (*in_grads)[2] = (*results)[1]; }\n\n    if (ctx->has_bias) {\n      if (ctx->need_grad_cx) {\n        (*in_grads)[3] = (*results)[2];\n        (*in_grads)[4] = (*results)[2];\n      } else {\n        (*in_grads)[3] = (*results)[1];\n        (*in_grads)[4] = (*results)[1];\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_lstm_cell\", FusedLstmCellGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_matmul_bias.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\n\nnamespace one {\n\nstruct FusedMatmulBiasCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad = false;\n  bool weight_requires_grad = false;\n  bool bias_requires_grad = false;\n};\n\nclass FusedMatmulBias : public OpExprGradFunction<FusedMatmulBiasCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(FusedMatmulBiasCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const FusedMatmulBiasCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n protected:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> FusedMatmulBias::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedMatmulBias::Capture(FusedMatmulBiasCaptureState* ctx, const TensorTuple& inputs,\n                                     const TensorTuple& outputs, const AttrMap& attrs) const {\n  CHECK_GE_OR_RETURN(inputs.size(), 3)\n      << \"x, weight, and bias, [add_to_output] should all be included\";\n  ctx->x_requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();\n  ctx->weight_requires_grad = JUST(VectorAt(inputs, 1))->requires_grad();\n  ctx->bias_requires_grad = JUST(VectorAt(inputs, 2))->requires_grad();\n\n  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));\n  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedMatmulBias::Apply(const FusedMatmulBiasCaptureState* ctx,\n                                   const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1) << \"FusedMatmulBias more than one output\";\n  const auto& x = ctx->SavedTensors().at(0);\n  const auto& weight = ctx->SavedTensors().at(1);\n\n  if (ctx->x_requires_grad) {\n    in_grads->at(0) =\n        JUST(functional::MatMul(JUST(VectorAt(out_grads, 0)), weight, false, false, 1.0));\n  }\n  if (ctx->weight_requires_grad) {\n    in_grads->at(1) = JUST(functional::BroadcastMatmulGradB(JUST(VectorAt(out_grads, 0)), x, 1.0));\n  }\n  if (ctx->bias_requires_grad) {\n    const int64_t num_axes = out_grads.at(0)->shape()->NumAxes();\n    std::vector<int32_t> reduce_axes_vec;\n    reduce_axes_vec.reserve(num_axes - 1);\n    for (int i = 0; i < num_axes - 1; i++) { reduce_axes_vec.push_back(i); }\n    in_grads->at(2) =\n        JUST(functional::ReduceSum(JUST(VectorAt(out_grads, 0)), reduce_axes_vec, false, NullOpt));\n  }\n\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_matmul_bias\", 
FusedMatmulBias);\n\n}  // namespace one\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_matmul_bias_add_relu_dropout.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/error.pb.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#if CUDA_VERSION >= 11060\n\nnamespace oneflow {\n\nnamespace one {\n\nstruct FusedMatmulBiasAddReluDropoutCaptureState : public AutoGradCaptureState {\n  int32_t weight_num = 0;\n  bool skip_final_activation = false;\n  bool x_requires_grad = false;\n  std::vector<bool> weights_requires_grad;\n  std::vector<bool> biases_requires_grad;\n  std::vector<float> dropout_rate_list;\n};\n\nclass FusedMatmulBiasAddReluDropout\n    : public OpExprGradFunction<FusedMatmulBiasAddReluDropoutCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(FusedMatmulBiasAddReluDropoutCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const FusedMatmulBiasAddReluDropoutCaptureState* ctx,\n                    const TensorTuple& out_grads, TensorTuple* in_grads) const override;\n\n protected:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> FusedMatmulBiasAddReluDropout::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedMatmulBiasAddReluDropout::Capture(FusedMatmulBiasAddReluDropoutCaptureState* ctx,\n                                                   const TensorTuple& inputs,\n                                                   const TensorTuple& outputs,\n                                                   const AttrMap& attrs) const {\n  CHECK_OR_RETURN(inputs.size() % 2 == 1) << \"Both weight and bias should be passed together. \";\n  int32_t weight_num = (inputs.size() - 1) / 2;\n  ctx->weight_num = weight_num;\n  ctx->x_requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();\n  ctx->weights_requires_grad.resize(weight_num);\n  ctx->biases_requires_grad.resize(weight_num);\n\n  for (int32_t i = 0; i < weight_num; i++) {\n    ctx->weights_requires_grad.at(i) = inputs.at(i + 1)->requires_grad();              // NOLINT\n    ctx->biases_requires_grad.at(i) = inputs.at(i + 1 + weight_num)->requires_grad();  // NOLINT\n  }\n\n  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));  // x. idx_sum:1\n  for (int32_t i = 0; i < weight_num; i++) {\n    ctx->SaveTensorForBackward(JUST(VectorAt(inputs, i + 1)));  // weights. 
idx_sum:1+w\n  }\n\n  ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0)));  // final layers output. idx_sum:2+w\n  for (int32_t i = 0; i < weight_num; i++) {\n    ctx->SaveTensorForBackward(\n        JUST(VectorAt(outputs, i + 1)));  // cublas aux. need minus 1. idx_sum:2+2w\n  }\n  for (int32_t i = 0; i < weight_num; i++) {\n    ctx->SaveTensorForBackward(JUST(VectorAt(outputs, i + 1 + weight_num)));  // hidden.\n  }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->skip_final_activation = JUST(composed_attrs.GetAttr<bool>(\"skip_final_activation\"));\n  ctx->dropout_rate_list = JUST(composed_attrs.GetAttr<std::vector<float>>(\"dropout_rate_list\"));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedMatmulBiasAddReluDropout::Apply(\n    const FusedMatmulBiasAddReluDropoutCaptureState* ctx, const TensorTuple& out_grads,\n    TensorTuple* in_grads) const {\n  int32_t weight_num = ctx->weight_num;\n  in_grads->resize(1 + 2 * weight_num);\n\n  TensorTuple hiddens(weight_num);\n  TensorTuple weights(weight_num);\n  TensorTuple cublas_auxs(weight_num);\n  TensorTuple dgrad(weight_num);\n\n  std::shared_ptr<one::Tensor> x = JUST(VectorAt(ctx->SavedTensors(), 0));\n  std::shared_ptr<one::Tensor> out = JUST(VectorAt(ctx->SavedTensors(), 1 + weight_num));\n\n  for (int32_t i = 0; i < weight_num; ++i) {\n    weights[i] = JUST(VectorAt(ctx->SavedTensors(), 1 + i));\n  }\n\n  for (int32_t i = 0; i < weight_num; ++i) {\n    cublas_auxs[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + weight_num));\n  }\n\n  for (int32_t i = 0; i < weight_num; ++i) {\n    hiddens[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + 2 * weight_num));\n  }\n\n  float rate = ctx->dropout_rate_list.at(weight_num - 1);\n  float scale = 0.0f;\n  if (rate < 1.0f) { scale = 1.0f / (1.0f - rate); }\n\n  /*\n  step1: use dy and mask to get last layer's dropout + relu grad.\n  Because curand_uniform distribution is (0.0, 1.0], so the value after relu will be write into mask\n  too. 
And DropoutGrad use this mask to generate grad, it will generate dropout and relu grad\n  simultaneously.\n  */\n  std::shared_ptr<one::Tensor> last_bias_dy = JUST(VectorAt(out_grads, 0));\n  if (!ctx->skip_final_activation || rate != 0.0f) {\n    last_bias_dy = JUST(functional::FusedReluDropoutGrad(JUST(VectorAt(out_grads, 0)),\n                                                         cublas_auxs[weight_num - 1], scale));\n  }\n\n  if (ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_FUSED_MLP_ASYNC_GRAD\", false)) {\n    std::vector<float> alpha_list(weight_num - 1, 1.0);\n    for (int i = 0; i < weight_num - 1; i++) {\n      rate = ctx->dropout_rate_list.at(i);\n      scale = 1.0;\n      if (rate < 1.0f) { scale = 1.0f / (1.0f - rate); }\n      alpha_list.at(i) = scale;\n    }\n    const auto& fused_mlp_grad =\n        JUST(functional::FusedMLPGrad(last_bias_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), weights,\n                                      cublas_auxs, hiddens, alpha_list));\n    if (ctx->x_requires_grad) {\n      // dx:\n      JUST(VectorAt(*in_grads, 0)) = fused_mlp_grad->at(0);\n    }\n\n    for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > -1; hidden_layer_idx--) {\n      if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx)))) {\n        // dbias\n        JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx + 1)) =\n            fused_mlp_grad->at(1 + hidden_layer_idx);  // NOLINT\n      }\n\n      // dw\n      if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) {\n        JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) =\n            fused_mlp_grad->at(1 + weight_num + hidden_layer_idx);\n      }\n    }\n  } else {\n    // step2: use reduce_sum to get last layer's bias grad.\n    std::vector<int32_t> reduce_axes_vec{0};\n    if (JUST(VectorAt(ctx->biases_requires_grad, weight_num - 1))) {\n      JUST(VectorAt(*in_grads, 2 * weight_num)) =\n          JUST(functional::ReduceSum(last_bias_dy, reduce_axes_vec, false, NullOpt));\n    }\n\n    std::shared_ptr<one::Tensor> cublas_dy = last_bias_dy;\n    for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > 0; hidden_layer_idx--) {\n      // If it is final layer, we use out_grads[0] as dy.\n      if (hidden_layer_idx != weight_num - 1) {\n        cublas_dy = JUST(VectorAt(dgrad, hidden_layer_idx + 1));\n      }\n      rate = ctx->dropout_rate_list.at(hidden_layer_idx - 1);\n      scale = 1.0;\n      if (rate < 1.0f) { scale = 1.0f / (1.0f - rate); }\n      /*\n      Here we use cublas to compute bias + relu + matmul grad.\n      Then use Matmul to compute weight grad.\n      */\n      const auto& matmul_relu_bias_bgrad = JUST(functional::CublasBiasAddReluMatmulGrad(\n          cublas_dy, JUST(VectorAt(weights, hidden_layer_idx)),\n          JUST(VectorAt(cublas_auxs, hidden_layer_idx - 1)), /*alpha=*/scale));\n\n      // dgrad\n      dgrad.at(hidden_layer_idx) = matmul_relu_bias_bgrad->at(0);  // NOLINT\n\n      if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx - 1)))) {\n        // dbias\n        JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx)) =\n            matmul_relu_bias_bgrad->at(1);  // NOLINT\n      }\n      // dw\n      if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) {\n        JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) = JUST(functional::MatMul(\n            cublas_dy, JUST(VectorAt(hiddens, hidden_layer_idx - 1)), true, false, 1.0));\n      }\n    }\n\n    // For the first layer, we need to use 2 matmul to get grads.\n    
std::shared_ptr<one::Tensor> last_dy;\n    if (weight_num != 1) {\n      last_dy = JUST(VectorAt(dgrad, 1));\n    } else {\n      last_dy = last_bias_dy;\n    }\n\n    if (ctx->x_requires_grad) {\n      // dx:\n      JUST(VectorAt(*in_grads, 0)) =\n          JUST(functional::MatMul(last_dy, JUST(VectorAt(weights, 0)), false, false, 1.0));\n    }\n    if (JUST(VectorAt(ctx->weights_requires_grad, 0))) {\n      // dw:\n      JUST(VectorAt(*in_grads, 1)) = JUST(\n          functional::MatMul(last_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), true, false, 1.0));\n    }\n  }\n\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_matmul_bias_add_relu_dropout\", FusedMatmulBiasAddReluDropout);\n\n}  // namespace one\n\n}  // namespace oneflow\n#endif  // CUDA_VERSION >= 11060\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_scale_mask_bias_softmax.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedScaleMaskBiasSoftmaxCaptureState : public AutoGradCaptureState {\n  bool input_requires_grad = false;\n  bool bias_requires_grad = false;\n  int32_t input_size = 3;\n  float scale = 1.0;\n};\n\nclass FusedScaleMaskBiasSoftmax : public OpExprGradFunction<FusedScaleMaskBiasSoftmaxCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(FusedScaleMaskBiasSoftmaxCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const FusedScaleMaskBiasSoftmaxCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> FusedScaleMaskBiasSoftmax::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedScaleMaskBiasSoftmax::Capture(FusedScaleMaskBiasSoftmaxCaptureState* ctx,\n                                               const TensorTuple& inputs,\n                                               const TensorTuple& outputs,\n                                               const AttrMap& attrs) const {\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->input_requires_grad = inputs.at(0)->requires_grad();\n  if (inputs.size() == 3) ctx->bias_requires_grad = inputs.at(2)->requires_grad();\n  if (!ctx->input_requires_grad && !ctx->bias_requires_grad) { return Maybe<void>::Ok(); }\n\n  ctx->scale = JUST(composed_attrs.GetAttr<float>(\"scale\"));\n  ctx->SaveTensorForBackward(outputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedScaleMaskBiasSoftmax::Apply(const FusedScaleMaskBiasSoftmaxCaptureState* ctx,\n                                             const TensorTuple& out_grads,\n                                             TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // dy\n  if (!ctx->input_requires_grad && !ctx->bias_requires_grad) { return Maybe<void>::Ok(); }\n  in_grads->resize(ctx->input_size);\n\n  const std::shared_ptr<oneflow::one::Tensor>& y = ctx->SavedTensors().at(0);\n  const std::shared_ptr<oneflow::one::Tensor>& input_grad =\n      JUST(functional::FusedScaleMaskBiasSoftmaxGrad(y, out_grads.at(0), 
ctx->scale));\n\n  if (ctx->input_requires_grad) in_grads->at(0) = input_grad;\n\n  if (ctx->bias_requires_grad) {\n    int batch_dim = (y->shape()->NumAxes() == 5) ? 1 : 0;\n    in_grads->at(2) = JUST(functional::ScalarMul(\n        1 / ctx->scale, JUST(functional::ReduceSum(input_grad, {batch_dim}, true, NullOpt))));\n  }\n\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_scale_mask_bias_softmax\", FusedScaleMaskBiasSoftmax);\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_scale_mask_softmax.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedScaleMaskSoftmaxInterState : public AutoGradCaptureState {\n  bool input_requires_grad = false;\n  float scale = 1.0;\n};\n\nclass FusedScaleMaskSoftmax : public OpExprGradFunction<FusedScaleMaskSoftmaxInterState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(FusedScaleMaskSoftmaxInterState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const FusedScaleMaskSoftmaxInterState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> FusedScaleMaskSoftmax::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedScaleMaskSoftmax::Capture(FusedScaleMaskSoftmaxInterState* ctx,\n                                           const TensorTuple& inputs, const TensorTuple& outputs,\n                                           const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), 2);  // input, mask\n  ctx->input_requires_grad = inputs.at(0)->requires_grad();\n\n  if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->scale = JUST(composed_attrs.GetAttr<float>(\"scale_value\"));\n\n  ctx->SaveTensorForBackward(inputs.at(1));   // save mask\n  ctx->SaveTensorForBackward(outputs.at(0));  // save y, ie. 
softmax result\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedScaleMaskSoftmax::Apply(const FusedScaleMaskSoftmaxInterState* ctx,\n                                         const TensorTuple& out_grads,\n                                         TensorTuple* in_grads) const {\n  if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }\n\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // dy\n  in_grads->resize(2);                      // input, mask\n\n  const std::shared_ptr<oneflow::one::Tensor>& mask = ctx->SavedTensors().at(0);\n  const std::shared_ptr<oneflow::one::Tensor>& y = ctx->SavedTensors().at(1);\n  const std::shared_ptr<oneflow::one::Tensor>& fused_scale_mask_softmax_grad =\n      JUST(functional::FusedScaleMaskSoftmaxGrad(y, out_grads.at(0), mask, ctx->scale));\n\n  in_grads->at(0) = fused_scale_mask_softmax_grad;\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_scale_mask_softmax\", FusedScaleMaskSoftmax);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_scale_mask_softmax_dropout.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedScaleMaskSoftmaxDropoutInterState : public AutoGradCaptureState {\n  bool input_requires_grad = true;\n  float scale = 1.0;\n  float dropout_scale = 1.0;\n};\n\nclass FusedScaleMaskSoftmaxDropout\n    : public OpExprGradFunction<FusedScaleMaskSoftmaxDropoutInterState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(FusedScaleMaskSoftmaxDropoutInterState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const FusedScaleMaskSoftmaxDropoutInterState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> FusedScaleMaskSoftmaxDropout::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedScaleMaskSoftmaxDropout::Capture(FusedScaleMaskSoftmaxDropoutInterState* ctx,\n                                                  const TensorTuple& inputs,\n                                                  const TensorTuple& outputs,\n                                                  const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), 3);  // input, mask, dropout_mask\n  ctx->input_requires_grad = inputs.at(0)->requires_grad();\n\n  if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->scale = JUST(composed_attrs.GetAttr<float>(\"scale_value\"));\n  ctx->dropout_scale = JUST(composed_attrs.GetAttr<float>(\"dropout_scale_value\"));\n\n  ctx->SaveTensorForBackward(inputs.at(1));   // mask\n  ctx->SaveTensorForBackward(inputs.at(2));   // dropout_mask\n  ctx->SaveTensorForBackward(outputs.at(1));  // softmax_y\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedScaleMaskSoftmaxDropout::Apply(const FusedScaleMaskSoftmaxDropoutInterState* ctx,\n                                                const TensorTuple& out_grads,\n                                                TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // dy, d_softmax_y\n  if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }\n  in_grads->resize(3);  // input, mask, dropout_mask\n\n  const std::shared_ptr<oneflow::one::Tensor>& mask = ctx->SavedTensors().at(0);\n  const std::shared_ptr<oneflow::one::Tensor>& dropout_mask = ctx->SavedTensors().at(1);\n  const 
std::shared_ptr<oneflow::one::Tensor>& softmax_y = ctx->SavedTensors().at(2);\n  const std::shared_ptr<oneflow::one::Tensor>& input_grad =\n      JUST(functional::FusedScaleMaskSoftmaxDropoutGrad(\n          softmax_y, out_grads.at(0), mask, dropout_mask, ctx->scale, ctx->dropout_scale));\n\n  in_grads->at(0) = input_grad;\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_scale_mask_softmax_dropout\", FusedScaleMaskSoftmaxDropout);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_scale_tril.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedScaleTrilState : public AutoGradCaptureState {\n  bool requires_grad;\n  int64_t diagonal;\n  double floating_scale_value;\n  int64_t integer_scale_value;\n  bool is_floating_scale_value;\n};\n\nclass FusedScaleTril : public OpExprGradFunction<FusedScaleTrilState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(FusedScaleTrilState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const FusedScaleTrilState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> FusedScaleTril::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedScaleTril::Capture(FusedScaleTrilState* ctx, const TensorTuple& inputs,\n                                    const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->diagonal = JUST(composed_attrs.GetAttr<int64_t>(\"diagonal\"));\n  ctx->floating_scale_value = JUST(composed_attrs.GetAttr<double>(\"floating_scale_value\"));\n  ctx->integer_scale_value = JUST(composed_attrs.GetAttr<int64_t>(\"integer_scale_value\"));\n  ctx->is_floating_scale_value = JUST(composed_attrs.GetAttr<bool>(\"is_floating_scale_value\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedScaleTril::Apply(const FusedScaleTrilState* ctx, const TensorTuple& out_grads,\n                                  TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n  in_grads->resize(1);\n  Scalar scale;\n  if (ctx->is_floating_scale_value) {\n    scale = ctx->floating_scale_value;\n  } else {\n    scale = ctx->integer_scale_value;\n  }\n  (*in_grads)[0] = JUST(functional::FusedScaleTril(out_grads[0], ctx->diagonal, 0, scale));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_scale_tril\", FusedScaleTril);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_scale_tril_softmax_mask_scale.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedScaleTrilSoftmaxMaskScaleInterpState : public AutoGradCaptureState {\n  bool input_requires_grad = true;\n  int64_t diagonal = 0;\n  float tril_scale_value = 0.0;\n  float mask_scale_value = 1.0;\n};\n\nclass FusedScaleTrilSoftmaxMaskScale\n    : public OpExprGradFunction<FusedScaleTrilSoftmaxMaskScaleInterpState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(FusedScaleTrilSoftmaxMaskScaleInterpState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const FusedScaleTrilSoftmaxMaskScaleInterpState* ctx,\n                    const TensorTuple& out_grads, TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> FusedScaleTrilSoftmaxMaskScale::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedScaleTrilSoftmaxMaskScale::Capture(FusedScaleTrilSoftmaxMaskScaleInterpState* ctx,\n                                                    const TensorTuple& inputs,\n                                                    const TensorTuple& outputs,\n                                                    const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), 2);\n  ctx->input_requires_grad = inputs.at(0)->requires_grad();  // input\n\n  if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->diagonal = JUST(composed_attrs.GetAttr<int64_t>(\"diagonal\"));\n  ctx->tril_scale_value = JUST(composed_attrs.GetAttr<float>(\"tril_scale_value\"));\n  ctx->mask_scale_value = JUST(composed_attrs.GetAttr<float>(\"mask_scale_value\"));\n  ctx->SaveTensorForBackward(inputs.at(1));   // Save Mask\n  ctx->SaveTensorForBackward(outputs.at(1));  // Save softmax_y\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedScaleTrilSoftmaxMaskScale::Apply(\n    const FusedScaleTrilSoftmaxMaskScaleInterpState* ctx, const TensorTuple& out_grads,\n    TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // Cause output has y and softmax_y\n  if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }\n\n  // mask have no grad(reqiures_grad=False), but still take a place in in_grads\n  in_grads->resize(2);\n\n  const std::shared_ptr<oneflow::one::Tensor>& mask = ctx->SavedTensors().at(0);\n  const std::shared_ptr<oneflow::one::Tensor>& 
softmax_y = ctx->SavedTensors().at(1);\n  const std::shared_ptr<oneflow::one::Tensor>& input_grad =\n      JUST(functional::FusedScaleTrilSoftmaxMaskScaleGrad(softmax_y, out_grads.at(0), mask,\n                                                          ctx->diagonal, ctx->tril_scale_value,\n                                                          ctx->mask_scale_value));\n  if (ctx->input_requires_grad) { in_grads->at(0) = input_grad; }\n\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_tril_scale_softmax_mask_scale\",\n                               FusedScaleTrilSoftmaxMaskScale);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_self_attention.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedSelfAttentionInterpState : public AutoGradCaptureState {\n  bool input_requires_grad = false;\n  float alpha = 1.0;\n};\n\nclass FusedSelfAttention : public OpExprGradFunction<FusedSelfAttentionInterpState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(FusedSelfAttentionInterpState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);\n    ctx->input_requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->alpha = JUST(composed_attrs.GetAttr<float>(\"alpha\"));\n    ctx->SaveTensorForBackward(inputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FusedSelfAttentionInterpState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }\n\n    CHECK_EQ_OR_RETURN(out_grads.size(), 2);\n    in_grads->resize(1);\n    const auto& hidden_states = ctx->SavedTensors().at(0);\n    const std::shared_ptr<oneflow::one::Tensor>& fused_self_attention_grad =\n        JUST(functional::FusedSelfAttentionGrad(out_grads.at(0), out_grads.at(1), hidden_states,\n                                                ctx->alpha));\n    in_grads->at(0) = fused_self_attention_grad;\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_self_attention_query_mul_key_and_value\", FusedSelfAttention);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/fused_weighted_sum.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct FusedWeightedSumCaptureState : public AutoGradCaptureState {\n  std::vector<bool> requires_grad;\n  std::vector<float> weights;\n  float alpha{};\n};\n\nclass FusedWeightedSum : public OpExprGradFunction<FusedWeightedSumCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(FusedWeightedSumCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    ctx->requires_grad.resize(inputs.size());\n    ctx->weights = JUST(attrs.GetAttr<std::vector<float>>(\"weights\"));\n    ctx->alpha = JUST(attrs.GetAttr<float>(\"alpha\"));\n    CHECK_EQ_OR_RETURN(ctx->weights.size(), inputs.size());\n    for (int i = 0; i < inputs.size(); ++i) { ctx->requires_grad[i] = inputs[i]->requires_grad(); }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const FusedWeightedSumCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(ctx->requires_grad.size());\n    for (int i = 0; i < ctx->requires_grad.size(); ++i) {\n      if (ctx->requires_grad[i]) {\n        (*in_grads)[i] =\n            JUST(functional::ScalarMul(out_grads[0], ctx->weights[i] * ctx->alpha, false));\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"fused_weighted_sum\", FusedWeightedSum);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/gather.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct GatherCaptureState : public AutoGradCaptureState {\n  int64_t axis;\n  bool requires_grad;\n};\n\nclass Gather : public OpExprGradFunction<GatherCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(GatherCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const GatherCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Gather::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Gather::Capture(GatherCaptureState* ctx, const TensorTuple& inputs,\n                            const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ctx->SaveTensorForBackward(inputs.at(0));\n  ctx->SaveTensorForBackward(inputs.at(1));\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->axis = JUST(composed_attrs.GetAttr<int64_t>(\"axis\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Gather::Apply(const GatherCaptureState* ctx, const TensorTuple& out_grads,\n                          TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  const auto& x = ctx->SavedTensors().at(0);\n  const auto& indices = ctx->SavedTensors().at(1);\n  in_grads->at(0) =\n      JUST(functional::UnsortedSegmentSumLike(out_grads.at(0), indices, x, ctx->axis));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"gather\", Gather);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/gather_nd.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct GatherNdCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n};\n\nclass GatherNd : public OpExprGradFunction<GatherNdCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(GatherNdCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (ctx->requires_grad) {\n      ctx->SaveTensorForBackward(inputs.at(0));  // params\n      ctx->SaveTensorForBackward(inputs.at(1));  // indices\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const GatherNdCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(2);\n    if (ctx->requires_grad) {\n      const auto& params = ctx->SavedTensors().at(0);\n      const auto& indices = ctx->SavedTensors().at(1);\n      in_grads->at(0) = JUST(functional::ScatterNdLike(params, out_grads.at(0), indices));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"gather_nd\", GatherNd);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/global_cast.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter_mgr.h\"\n#include \"oneflow/core/framework/tensor_rpc_util.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct CastGlobalCaptureState : public AutoGradCaptureState {\n  Symbol<ParallelDesc> parallel_desc;\n  Symbol<NdSbp> nd_sbp;\n  std::shared_ptr<const Shape> shape;\n  Symbol<DType> dtype;\n};\n\nclass LocalToGlobal : public OpExprGradFunction<CastGlobalCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const LocalToGlobalOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    const std::string& op_name = fw_op_expr->op_name();\n    grad_op_ = JUST(one::GlobalToLocalOpExpr::New(GradientOpName(op_name)));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(CastGlobalCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs,\n                      const OpExprInterpContext& interp_ctx) const override {\n    ctx->parallel_desc = JUST(interp_ctx.parallel_desc);\n    ctx->nd_sbp = JUST(GetDualNdSbp(JUST(interp_ctx.nd_sbp)));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const CastGlobalCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    std::shared_ptr<Tensor> out_grad = out_grads.at(0);\n    CHECK_OR_RETURN(out_grad->is_global())\n        << Error::RuntimeError()\n        << \"Expected global tensor for local_to_global but got local tensor\";\n    {\n      Symbol<NdSbp> nd_sbp_constraint = ctx->nd_sbp;\n      Symbol<ParallelDesc> parallel_desc_constraint = ctx->parallel_desc;\n      out_grad = JUST(functional::ToGlobal(out_grad, parallel_desc_constraint,\n                                           *JUST(GetSbpList(nd_sbp_constraint)), GetNoneSbpList(),\n                                           /* check_meta */ false, /*copy=*/false));\n    }\n    in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {out_grad}));\n    return Maybe<void>::Ok();\n  }\n\n private:\n  std::shared_ptr<OpExpr> grad_op_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"local_to_global\", LocalToGlobal);\n\nclass GlobalToLocal : public OpExprGradFunction<CastGlobalCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const GlobalToLocalOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    const std::string& 
op_name = fw_op_expr->op_name();\n    grad_op_ = JUST(one::LocalToGlobalOpExpr::New(GradientOpName(op_name)));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(CastGlobalCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    const auto& input = inputs.at(0);\n    CHECK_OR_RETURN(input->is_global())\n        << Error::RuntimeError()\n        << \"Expected global tensor for global_to_local but got local tensor\";\n    ctx->parallel_desc = JUST(input->parallel_desc());\n    ctx->nd_sbp = JUST(input->nd_sbp());\n    ctx->shape = input->shape();\n    ctx->dtype = input->dtype();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const CastGlobalCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    const auto& dual_nd_sbp = JUST(GetDualNdSbp(ctx->nd_sbp));\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"shape\", \"dtype\", \"sync_data\");\n    attrs.SetAllAttrs(*ctx->shape, ctx->dtype->data_type(), true);\n    in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(\n        *grad_op_, {out_grads.at(0)}, OpExprInterpContext(attrs, ctx->parallel_desc, dual_nd_sbp)));\n    return Maybe<void>::Ok();\n  }\n\n private:\n  std::shared_ptr<OpExpr> grad_op_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"global_to_local\", GlobalToLocal);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/global_to_global.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/id_util.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/optional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct GlobalToGlobalState : public AutoGradCaptureState {\n  Symbol<ParallelDesc> parallel_desc;\n  Symbol<NdSbp> nd_sbp;\n};\n\nclass GlobalToGlobalGradFunction : public OpExprGradFunction<GlobalToGlobalState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const GlobalToGlobalOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    grad_nd_sbp_ = fw_op_expr->grad_nd_sbp();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(GlobalToGlobalState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs,\n                      const OpExprInterpContext& interp_ctx) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->parallel_desc = JUST(inputs.at(0)->parallel_desc());\n    ctx->nd_sbp = JUST(inputs.at(0)->nd_sbp());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const GlobalToGlobalState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    const auto& out_grad = out_grads.at(0);\n    CHECK_OR_RETURN(out_grad->is_global())\n        << Error::RuntimeError()\n        << \"Expected global tensor for global_to_global but got local tensor\";\n    in_grads->resize(1);\n    const auto& grad_nd_sbp = grad_nd_sbp_.value_or(JUST(out_grad->nd_sbp()));\n    const auto& grad_sbp_list = JUST(GetSbpList(grad_nd_sbp));\n\n    if (LazyMode::is_enabled()) {\n      (*in_grads)[0] = JUST(one::functional::ToGlobal(out_grad, ctx->parallel_desc, *grad_sbp_list,\n                                                      {}, /* check_meta */ false, /*copy=*/false));\n    } else {\n      const auto& grad_grad_sbp_list = JUST(GetSbpList(ctx->nd_sbp));\n      (*in_grads)[0] = JUST(one::functional::ToGlobal(out_grad, ctx->parallel_desc, *grad_sbp_list,\n                                                      *grad_grad_sbp_list, /* check_meta */ false,\n                                                      /*copy=*/false));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  Optional<Symbol<NdSbp>> grad_nd_sbp_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"global_to_global\", GlobalToGlobalGradFunction);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/gradient_accumulation.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct GradAccRepeatCaptureState : public AutoGradCaptureState {\n  int32_t repeat_num = 1;\n};\n\nclass GradAccRepeat : public OpExprGradFunction<GradAccRepeatCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(GradAccRepeatCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const GradAccRepeatCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> GradAccRepeat::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GradAccRepeat::Capture(GradAccRepeatCaptureState* ctx, const TensorTuple& inputs,\n                                   const TensorTuple& outputs, const AttrMap& attrs) const {\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->repeat_num = JUST(composed_attrs.GetAttr<int32_t>(\"repeat_num\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GradAccRepeat::Apply(const GradAccRepeatCaptureState* ctx, const TensorTuple& out_grads,\n                                 TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(1);\n  (*in_grads)[0] = JUST(functional::GradAccCollect(out_grads[0], ctx->repeat_num));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"repeat\", GradAccRepeat);\n\nstruct GradAccCollectCaptureState : public AutoGradCaptureState {\n  int32_t max_acc_num = 1;\n};\n\nclass GradAccCollect : public OpExprGradFunction<GradAccCollectCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(GradAccCollectCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const GradAccCollectCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> GradAccCollect::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GradAccCollect::Capture(GradAccCollectCaptureState* ctx, const TensorTuple& inputs,\n                                  
  const TensorTuple& outputs, const AttrMap& attrs) const {\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->max_acc_num = JUST(composed_attrs.GetAttr<int32_t>(\"max_acc_num\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GradAccCollect::Apply(const GradAccCollectCaptureState* ctx,\n                                  const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(1);\n  (*in_grads)[0] = JUST(functional::GradAccRepeat(out_grads[0], ctx->max_acc_num));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"acc\", GradAccCollect);\n\nstruct GradAccPackCaptureState : public AutoGradCaptureState {\n  int32_t pack_num = 1;\n};\n\nclass GradAccPack : public OpExprGradFunction<GradAccPackCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(GradAccPackCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const GradAccPackCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> GradAccPack::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GradAccPack::Capture(GradAccPackCaptureState* ctx, const TensorTuple& inputs,\n                                 const TensorTuple& outputs, const AttrMap& attrs) const {\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->pack_num = JUST(composed_attrs.GetAttr<int32_t>(\"pack_num\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GradAccPack::Apply(const GradAccPackCaptureState* ctx, const TensorTuple& out_grads,\n                               TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(1);\n  (*in_grads)[0] = JUST(functional::GradAccUnpack(out_grads[0], ctx->pack_num));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"pack\", GradAccPack);\n\nstruct GradAccUnpackCaptureState : public AutoGradCaptureState {\n  int32_t unpack_num = 1;\n};\n\nclass GradAccUnpack : public OpExprGradFunction<GradAccUnpackCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(GradAccUnpackCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const GradAccUnpackCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> GradAccUnpack::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GradAccUnpack::Capture(GradAccUnpackCaptureState* ctx, const TensorTuple& inputs,\n                                   const TensorTuple& outputs, const AttrMap& attrs) const {\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->unpack_num = JUST(composed_attrs.GetAttr<int32_t>(\"unpack_num\"));\n 
 return Maybe<void>::Ok();\n}\n\nMaybe<void> GradAccUnpack::Apply(const GradAccUnpackCaptureState* ctx, const TensorTuple& out_grads,\n                                 TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(1);\n  (*in_grads)[0] = JUST(functional::GradAccPack(out_grads[0], ctx->unpack_num));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"unpack\", GradAccUnpack);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/graph_feed_and_fetch.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct GraphFeedAndFetchCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n};\n\nclass GraphFeedAndFetch : public OpExprGradFunction<GraphFeedAndFetchCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(GraphFeedAndFetchCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const GraphFeedAndFetchCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"graph_feed_and_fetch\", GraphFeedAndFetch);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/grid_sample.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct GridSampleInterpState : public AutoGradCaptureState {\n  std::string interpolation_mode = \"\";\n  std::string padding_mode = \"\";\n  bool align_corners = false;\n  size_t input_index = -1;\n  size_t grid_index = -1;\n  bool input_requires_grad = false;\n  bool grid_requires_grad = false;\n  bool requires_grad = false;\n};\n\nclass GridSample : public OpExprGradFunction<GridSampleInterpState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(GridSampleInterpState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)\n    ctx->input_requires_grad = inputs.at(0)->requires_grad();\n    ctx->grid_requires_grad = inputs.at(1)->requires_grad();\n    ctx->requires_grad = ctx->input_requires_grad || ctx->grid_requires_grad;\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));  // input\n    ctx->grid_index = ctx->SaveTensorForBackward(inputs.at(1));   // grid\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->interpolation_mode = JUST(composed_attrs.GetAttr<std::string>(\"interpolation_mode\"));\n    ctx->padding_mode = JUST(composed_attrs.GetAttr<std::string>(\"padding_mode\"));\n    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>(\"align_corners\"));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const GridSampleInterpState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n    const auto& input = ctx->SavedTensors().at(ctx->input_index);\n    const auto& grid = ctx->SavedTensors().at(ctx->grid_index);\n    const auto& results =\n        JUST(functional::GridSampleGrad(out_grads.at(0), input, grid, ctx->interpolation_mode,\n                                        ctx->padding_mode, ctx->align_corners));\n    in_grads->resize(2);\n    if (ctx->input_requires_grad) { in_grads->at(0) = results->at(0); }\n    if (ctx->grid_requires_grad) { in_grads->at(1) = results->at(1); }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"grid_sample\", GridSample);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/group_norm.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct GroupNormCaptureState : public AutoGradCaptureState {\n  double epsilon = 1e-5;\n  bool x_requires_grad = true;\n  bool gamma_requires_grad = true;\n  bool beta_requires_grad = true;\n  bool affine = true;\n  int32_t num_groups = 1;\n  size_t x_index = 0;\n  size_t mean_index = 1;\n  size_t inv_variance_index = 2;\n  size_t gamma_index = 3;\n  std::string data_format;\n  std::string activation;\n};\n\nclass GroupNorm : public OpExprGradFunction<GroupNormCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n\n  Maybe<void> Capture(GroupNormCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n\n  Maybe<void> Apply(const GroupNormCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n  std::string op_name_;\n};\n\nMaybe<void> GroupNorm::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  op_name_ = fw_op_expr->op_name();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GroupNorm::Capture(GroupNormCaptureState* ctx, const TensorTuple& inputs,\n                               const TensorTuple& outputs, const AttrMap& attrs) const {\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->affine = JUST(composed_attrs.GetAttr<bool>(\"affine\"));\n  ctx->epsilon = JUST(composed_attrs.GetAttr<double>(\"epsilon\"));\n  ctx->num_groups = JUST(composed_attrs.GetAttr<int32_t>(\"num_groups\"));\n  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n  ctx->activation = JUST(composed_attrs.GetAttr<std::string>(\"activation\"));\n  if (ctx->affine) {\n    CHECK_EQ_OR_RETURN(inputs.size(), 3);  // NOLINT(maybe-need-error-msg)\n    ctx->gamma_requires_grad = inputs.at(1)->requires_grad();\n    ctx->beta_requires_grad = inputs.at(2)->requires_grad();\n  } else {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n  }\n  CHECK_EQ_OR_RETURN(outputs.size(), 3);  // NOLINT(maybe-need-error-msg)\n\n  ctx->x_requires_grad = inputs.at(0)->requires_grad();\n  if (ctx->x_requires_grad || ctx->affine) {\n    ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));\n    ctx->mean_index = ctx->SaveTensorForBackward(outputs.at(1));\n    ctx->inv_variance_index = ctx->SaveTensorForBackward(outputs.at(2));\n    if (ctx->affine) {\n      ctx->gamma_index = ctx->SaveTensorForBackward(inputs.at(1));  // save gamma.\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> 
GroupNorm::Apply(const GroupNormCaptureState* ctx, const TensorTuple& out_grads,\n                             TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(ctx->data_format, \"channels_first\");\n  CHECK_EQ_OR_RETURN(ctx->activation, \"none\");\n  const auto& saved_tensors = ctx->SavedTensors();\n  if (ctx->affine) {\n    in_grads->resize(3);\n  } else {\n    in_grads->resize(1);\n  }\n  const auto& dy = out_grads.at(0);\n  const auto& x = saved_tensors.at(ctx->x_index);\n  const auto& mean = saved_tensors.at(ctx->mean_index);\n  const auto& inv_variance = saved_tensors.at(ctx->inv_variance_index);\n\n  if (ctx->affine && (ctx->gamma_requires_grad || ctx->beta_requires_grad)) {\n    const auto& results = JUST(functional::GroupNormParamGrad(dy, x, mean, inv_variance));\n    if (ctx->gamma_requires_grad) { in_grads->at(1) = results->at(0); }  // For gamma.\n    if (ctx->beta_requires_grad) { in_grads->at(2) = results->at(1); }   // For beta.\n  }\n  if (ctx->x_requires_grad) {\n    if (ctx->affine) {\n      std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);\n      in_grads->at(0) = JUST(functional::GroupNormGrad(dy, x, mean, inv_variance, gamma,\n                                                       ctx->num_groups, ctx->epsilon));\n    } else {\n      in_grads->at(0) = JUST(functional::GroupNormGrad(dy, x, mean, inv_variance, NullOpt,\n                                                       ctx->num_groups, ctx->epsilon));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"group_norm\", GroupNorm);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/identity.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct IdentityCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n};\n\nclass Identity : public OpExprGradFunction<IdentityCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(IdentityCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const IdentityCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      if (LazyMode::is_enabled()) {\n        // requires an intermediate node to avoid redundant memory copy or commnet\n        // communication in lazy mode\n        in_grads->at(0) = JUST(functional::Identity(out_grads.at(0)));\n      } else {\n        in_grads->at(0) = out_grads.at(0);\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"identity\", Identity);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/inv.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct InvCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n};\n\nclass Inv : public OpExprGradFunction<InvCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n  Maybe<void> Capture(InvCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();\n    if (ctx->requires_grad) { ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0))); }\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Apply(const InvCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (ctx->requires_grad) {\n      const auto& output = JUST(VectorAt(ctx->SavedTensors(), 0));\n      const auto& dy = JUST(VectorAt(out_grads, 0));\n      JUST(VectorAt(*in_grads, 0)) = JUST(functional::Negative(JUST(functional::MatMul(\n          output, JUST(functional::MatMul(dy, output, false, true, 1.0)), true, false, 1.0))));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"inv\", Inv);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/kl_div.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct KLDivLossCaptureState : public AutoGradCaptureState {\n  bool input_requires_grad = false;\n  bool target_requires_grad = false;\n  bool log_target = false;\n};\n\nclass KLDivLoss : public OpExprGradFunction<KLDivLossCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(KLDivLossCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const KLDivLossCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> KLDivLoss::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\nMaybe<void> KLDivLoss::Capture(KLDivLossCaptureState* ctx, const TensorTuple& inputs,\n                               const TensorTuple& outputs, const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)\n  ctx->input_requires_grad = inputs[0]->requires_grad();\n  ctx->target_requires_grad = inputs[1]->requires_grad();\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->log_target = JUST(composed_attrs.GetAttr<bool>(\"log_target\"));\n  ctx->SaveTensorForBackward(inputs[0]);  // input\n  ctx->SaveTensorForBackward(inputs[1]);  // target\n  return Maybe<void>::Ok();\n}\nMaybe<void> KLDivLoss::Apply(const KLDivLossCaptureState* ctx, const TensorTuple& out_grads,\n                             TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);            // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 2);  // NOLINT(maybe-need-error-msg)\n  const auto& dy = out_grads[0];\n  const auto& input = ctx->SavedTensors()[0];\n  const auto& target = ctx->SavedTensors()[1];\n  in_grads->resize(2);\n\n  if (ctx->input_requires_grad) {\n    (*in_grads)[0] = JUST(functional::KLDivLossGrad(dy, input, target, ctx->log_target));\n  }\n  if (ctx->target_requires_grad) {\n    (*in_grads)[1] = JUST(functional::KLDivLossTargetGrad(dy, input, target, ctx->log_target));\n  }\n\n  return Maybe<void>::Ok();\n}\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"kl_div_loss\", KLDivLoss);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/l2_normalize.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct L2NormalizeCaptureState : public AutoGradCaptureState {\n  int64_t axis;\n  float epsilon;\n  bool requires_grad;\n};\n\nclass L2Normalize : public OpExprGradFunction<L2NormalizeCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(L2NormalizeCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const L2NormalizeCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> L2Normalize::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> L2Normalize::Capture(L2NormalizeCaptureState* ctx, const TensorTuple& inputs,\n                                 const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ctx->SaveTensorForBackward(outputs.at(0));  // y\n  ctx->SaveTensorForBackward(outputs.at(1));  // square_x_sum\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->axis = JUST(composed_attrs.GetAttr<int32_t>(\"axis\"));\n  ctx->epsilon = JUST(composed_attrs.GetAttr<float>(\"epsilon\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> L2Normalize::Apply(const L2NormalizeCaptureState* ctx, const TensorTuple& out_grads,\n                               TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  in_grads->resize(1);\n  CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)\n  const auto& y = ctx->SavedTensors().at(0);\n  const auto& square_x_sum = ctx->SavedTensors().at(1);\n  in_grads->at(0) =\n      JUST(functional::L2NormalizeGrad(out_grads.at(0), y, square_x_sum, ctx->axis, ctx->epsilon));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"l2_normalize\", L2Normalize);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/layer_norm.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\n\nDEFINE_ENV_BOOL(ONEFLOW_USE_FUSE_LAYER_NORM_GRAD, false);\n\nnamespace one {\n\nstruct LayerNormCaptureState : public AutoGradCaptureState {\n  bool center = true;\n  bool scale = true;\n\n  int64_t begin_norm_axis = 1;\n  int64_t begin_params_axis = 1;\n\n  double epsilon = 1e-5;\n\n  bool x_requires_grad = true;\n  bool has_affine = true;\n\n  size_t gamma_index = 0;\n  size_t x_index = 1;\n  size_t mean_index = 2;\n  size_t inv_variance_index = 3;\n};\n\n// y, mean, inv_variance =\n//   layer_norm(x, [gamma], [beta], center=False, scale=False, begin_norm_axis=1,\n//              begin_params_axis=-1, epsilon=1e-5)\nclass LayerNorm : public OpExprGradFunction<LayerNormCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n\n  Maybe<void> Capture(LayerNormCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n\n  Maybe<void> Apply(const LayerNormCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n  std::string op_name_;\n};\n\nMaybe<void> LayerNorm::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  op_name_ = fw_op_expr->op_name();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LayerNorm::Capture(LayerNormCaptureState* ctx, const TensorTuple& inputs,\n                               const TensorTuple& outputs, const AttrMap& attrs) const {\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->center = JUST(composed_attrs.GetAttr<bool>(\"center\"));\n  ctx->scale = JUST(composed_attrs.GetAttr<bool>(\"scale\"));\n  ctx->begin_norm_axis = JUST(composed_attrs.GetAttr<int64_t>(\"begin_norm_axis\"));\n  ctx->begin_params_axis = JUST(composed_attrs.GetAttr<int64_t>(\"begin_params_axis\"));\n  ctx->epsilon = JUST(composed_attrs.GetAttr<double>(\"epsilon\"));\n\n  CHECK_EQ_OR_RETURN(inputs.size(), ctx->center + ctx->scale + 1);  // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(outputs.size(), 3);                            // NOLINT(maybe-need-error-msg)\n\n  bool has_gamma_diff = ctx->scale && inputs.at(1)->requires_grad();\n  bool has_beta_diff = ctx->center && inputs.at(2)->requires_grad();\n\n  ctx->has_affine = has_gamma_diff && has_beta_diff;\n\n  ctx->x_requires_grad = inputs.at(0)->requires_grad();\n  if (ctx->x_requires_grad || ctx->has_affine) {\n    ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));\n    ctx->mean_index = ctx->SaveTensorForBackward(outputs.at(1));\n    ctx->inv_variance_index = 
ctx->SaveTensorForBackward(outputs.at(2));\n    if (ctx->x_requires_grad && ctx->scale) {\n      ctx->gamma_index = ctx->SaveTensorForBackward(inputs.at(1));  // save gamma.\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LayerNorm::Apply(const LayerNormCaptureState* ctx, const TensorTuple& out_grads,\n                             TensorTuple* in_grads) const {\n  const auto& saved_tensors = ctx->SavedTensors();\n  in_grads->resize(ctx->center + ctx->scale + 1);\n  std::shared_ptr<Tensor> dy = out_grads.at(0);\n  int64_t begin_params_axis = ctx->begin_params_axis;\n  if (begin_params_axis < 0) { begin_params_axis += dy->shape()->NumAxes(); }\n  int64_t begin_norm_axis = ctx->begin_norm_axis;\n  if (begin_norm_axis < 0) { begin_norm_axis += dy->shape()->NumAxes(); }\n\n  std::shared_ptr<Tensor> x = saved_tensors.at(ctx->x_index);\n  std::shared_ptr<Tensor> mean = saved_tensors.at(ctx->mean_index);\n  std::shared_ptr<Tensor> inv_variance = saved_tensors.at(ctx->inv_variance_index);\n\n  if (EnvBool<ONEFLOW_USE_FUSE_LAYER_NORM_GRAD>()) {\n    // only used for NPU\n    CHECK(ctx->has_affine) << \"LayerNorm::Apply requires has_affine for the NPU GPT2 test\";\n    if (ctx->x_requires_grad) {\n      if (ctx->scale) {\n        std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);\n        *in_grads = *JUST(functional::FuseLayerNormGrad(\n            dy, x, mean, inv_variance, gamma, begin_norm_axis, begin_params_axis, ctx->epsilon));\n      } else {\n        UNIMPLEMENTED();\n      }\n    }\n  } else {\n    if (ctx->has_affine) {\n      // Use LayerNormParamGrad(Tensor dy, Tensor x, Tensor mean, Tensor inv_variance,\n      // Int64 begin_params_axis)\n      const auto& results =\n          JUST(functional::LayerNormParamGrad(dy, x, mean, inv_variance, begin_params_axis));\n      in_grads->at(1) = results->at(0);  // For gamma.\n      in_grads->at(2) = results->at(1);  // For beta.\n    }\n    if (ctx->x_requires_grad) {\n      if (ctx->scale) {\n        std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);\n        in_grads->at(0) = JUST(functional::LayerNormAffineGrad(dy, x, mean, inv_variance, gamma,\n                                                               begin_norm_axis, ctx->epsilon));\n      } else {\n        in_grads->at(0) = JUST(\n            functional::LayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis, ctx->epsilon));\n      }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"layer_norm\", LayerNorm);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/lerp.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nconst int32_t INPUT_LEN = 3;\nstruct LerpCaptureState : public AutoGradCaptureState {\n  std::vector<bool> requires_grad;\n};\nstruct ScalarLerpCaptureState : public AutoGradCaptureState {\n  std::vector<bool> requires_grad;\n  Scalar operand;\n};\n\nclass LerpGrad : public OpExprGradFunction<LerpCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(LerpCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN);\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);\n\n    for (int i = 0; i < INPUT_LEN; i++) {\n      ctx->requires_grad.push_back(inputs.at(i)->requires_grad());\n      ctx->SaveTensorForBackward(inputs.at(i));\n    }\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const LerpCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n    const auto& out_diff = out_grads.at(0);\n\n    const auto& start = ctx->SavedTensors().at(0);\n    const auto& end = ctx->SavedTensors().at(1);\n    const auto& weight = ctx->SavedTensors().at(2);\n\n    auto result = JUST(functional::LerpGrad(start, end, weight, out_diff));\n    CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN);\n\n    in_grads->resize(INPUT_LEN);\n    for (int i = 0; i < INPUT_LEN; i++) {\n      if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); }\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nclass ScalarLerpGrad : public OpExprGradFunction<ScalarLerpCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ScalarLerpCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN - 1);\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);\n\n    for (int i = 0; i < INPUT_LEN - 1; i++) {\n      ctx->requires_grad.push_back(inputs.at(i)->requires_grad());\n      ctx->SaveTensorForBackward(inputs.at(i));\n    }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    bool has_float_operand = JUST(composed_attrs.GetAttr<bool>(\"has_float_operand\"));\n    if (has_float_operand) {\n      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<double>(\"float_operand\")));\n    } else {\n      ctx->operand = 
Scalar(JUST(composed_attrs.GetAttr<int64_t>(\"int_operand\")));\n    }\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ScalarLerpCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);\n    const auto& out_diff = out_grads.at(0);\n\n    const auto& start = ctx->SavedTensors().at(0);\n    const auto& end = ctx->SavedTensors().at(1);\n\n    auto result = JUST(functional::ScalarLerpGrad(start, end, out_diff, ctx->operand));\n    CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN - 1);\n\n    in_grads->resize(INPUT_LEN - 1);\n    for (int i = 0; i < INPUT_LEN - 1; i++) {\n      if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); }\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"lerp\", LerpGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scalar_lerp\", ScalarLerpGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
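  {
    "path": "docs/gradient_sketches/lerp_grad_check.cpp",
    "content": "// NOTE: hypothetical illustration only -- neither this path nor this file is part of the\n// OneFlow build. A minimal standalone sketch checking the lerp gradients consumed by LerpGrad\n// above, assuming the forward definition lerp(start, end, weight) = start + weight * (end - start),\n// which gives d/dstart = 1 - weight, d/dend = weight, and d/dweight = end - start.\n#include <cassert>\n#include <cmath>\n\nint main() {\n  const double start = 2.0, end = 5.0, weight = 0.25, eps = 1e-6;\n  const auto lerp = [](double s, double e, double w) { return s + w * (e - s); };\n  // Central finite differences with respect to each operand.\n  const double d_start = (lerp(start + eps, end, weight) - lerp(start - eps, end, weight)) / (2 * eps);\n  const double d_end = (lerp(start, end + eps, weight) - lerp(start, end - eps, weight)) / (2 * eps);\n  const double d_weight = (lerp(start, end, weight + eps) - lerp(start, end, weight - eps)) / (2 * eps);\n  // Compare against the analytic formulas the backward op is expected to produce.\n  assert(std::fabs(d_start - (1.0 - weight)) < 1e-4);\n  assert(std::fabs(d_end - weight) < 1e-4);\n  assert(std::fabs(d_weight - (end - start)) < 1e-4);\n  return 0;\n}\n"
  },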
  {
    "path": "oneflow/core/autograd/gradient_funcs/linalg_cross.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct LinalgCrossCaptureState : public AutoGradCaptureState {\n  int64_t dim = -1;\n  bool input_requires_grad = false;\n  bool other_requires_grad = false;\n};\n\nclass LinalgCross : public OpExprGradFunction<LinalgCrossCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(LinalgCrossCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const LinalgCrossCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> LinalgCross::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LinalgCross::Capture(LinalgCrossCaptureState* ctx, const TensorTuple& inputs,\n                                 const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->input_requires_grad = inputs.at(0)->requires_grad();\n  ctx->other_requires_grad = inputs.at(1)->requires_grad();\n\n  if (ctx->input_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }\n  if (ctx->other_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->dim = JUST(composed_attrs.GetAttr<int64_t>(\"dim\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LinalgCross::Apply(const LinalgCrossCaptureState* ctx, const TensorTuple& out_grads,\n                               TensorTuple* in_grads) const {\n  in_grads->resize(ctx->SavedTensors().size());\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n  if (ctx->input_requires_grad) {\n    in_grads->at(0) =\n        JUST(functional::LinalgCross(ctx->SavedTensors().at(0), out_grads.at(0), ctx->dim));\n  }\n  if (ctx->other_requires_grad) {\n    in_grads->at(1) = JUST(functional::LinalgCross(\n        out_grads.at(0), ctx->SavedTensors().at(ctx->input_requires_grad ? 1 : 0), ctx->dim));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"linalg_cross\", LinalgCross);\n\n}  // namespace one\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/log_softmax.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct LogSoftmaxCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n};\n\nclass LogSoftmax : public OpExprGradFunction<LogSoftmaxCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(LogSoftmaxCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const LogSoftmaxCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n  std::shared_ptr<OpExpr> grad_op_;\n};\n\nMaybe<void> LogSoftmax::Init(const OpExpr& op) { return Maybe<void>::Ok(); }\n\nMaybe<void> LogSoftmax::Capture(LogSoftmaxCaptureState* ctx, const TensorTuple& inputs,\n                                const TensorTuple& outputs, const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  ctx->SaveTensorForBackward(outputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LogSoftmax::Apply(const LogSoftmaxCaptureState* ctx, const TensorTuple& out_grads,\n                              TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  const auto& dy = out_grads.at(0);\n  const auto& y = ctx->SavedTensors().at(0);\n  in_grads->resize(1);\n  in_grads->at(0) = JUST(functional::LogSoftmaxGrad(dy, y));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"log_softmax\", LogSoftmax);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
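  {
    "path": "docs/gradient_sketches/log_softmax_grad_check.cpp",
    "content": "// NOTE: hypothetical illustration only -- neither this path nor this file is part of the\n// OneFlow build. A minimal numeric check of the identity behind functional::LogSoftmaxGrad as\n// used by LogSoftmax above: with the saved forward output y = log_softmax(x), the input gradient\n// is dx_i = dy_i - exp(y_i) * sum_j dy_j, so only y (not x) is needed for the backward pass.\n#include <cassert>\n#include <cmath>\n#include <vector>\n\n// loss(x) = sum_j dy_j * log_softmax(x)_j, so dloss/dx_i is the backward value of interest.\nstatic double Loss(const std::vector<double>& x, const std::vector<double>& dy) {\n  double sum_exp = 0.0;\n  for (const double v : x) { sum_exp += std::exp(v); }\n  double loss = 0.0;\n  for (size_t j = 0; j < x.size(); j++) { loss += dy[j] * (x[j] - std::log(sum_exp)); }\n  return loss;\n}\n\nint main() {\n  const std::vector<double> x = {0.5, -1.0, 2.0};\n  const std::vector<double> dy = {0.3, -0.2, 0.7};\n  const double eps = 1e-6;\n  double sum_exp = 0.0, sum_dy = 0.0;\n  for (const double v : x) { sum_exp += std::exp(v); }\n  for (const double v : dy) { sum_dy += v; }\n  for (size_t i = 0; i < x.size(); i++) {\n    const double y_i = x[i] - std::log(sum_exp);       // saved forward output\n    const double dx = dy[i] - std::exp(y_i) * sum_dy;  // gradient computed from y alone\n    std::vector<double> xp = x, xm = x;\n    xp[i] += eps;\n    xm[i] -= eps;\n    assert(std::fabs(dx - (Loss(xp, dy) - Loss(xm, dy)) / (2 * eps)) < 1e-4);\n  }\n  return 0;\n}\n"
  },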
  {
    "path": "oneflow/core/autograd/gradient_funcs/masked_fill.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct MaskedFillCaptureState : public AutoGradCaptureState {\n  bool requires_grad = true;\n};\n\nclass MaskedFill : public OpExprGradFunction<MaskedFillCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n  Maybe<void> Capture(MaskedFillCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ctx->SaveTensorForBackward(inputs.at(0));\n    ctx->SaveTensorForBackward(inputs.at(1));\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Apply(const MaskedFillCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);\n    const std::shared_ptr<oneflow::one::Tensor>& mask = ctx->SavedTensors().at(1);\n\n    std::shared_ptr<oneflow::one::Tensor> zero_out = JUST(functional::ZerosLike(x));\n    in_grads->resize(2);\n    in_grads->at(0) = JUST(functional::Where(mask, zero_out, out_grads.at(0)));\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"masked_fill\", MaskedFill);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
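  {
    "path": "docs/gradient_sketches/masked_fill_grad_check.cpp",
    "content": "// NOTE: hypothetical illustration only -- neither this path nor this file is part of the\n// OneFlow build. A minimal sketch of the Where-based backward used by MaskedFill above: since\n// masked_fill(x, mask, value) overwrites x wherever mask is set, those positions contribute no\n// gradient to x, i.e. dx = where(mask, 0, dy), matching functional::Where(mask, zero_out, dy).\n#include <cassert>\n#include <vector>\n\nint main() {\n  const std::vector<bool> mask = {true, false, true, false};\n  const std::vector<double> dy = {0.1, 0.2, 0.3, 0.4};\n  std::vector<double> dx(mask.size());\n  // Elementwise where(mask, 0, dy): masked positions receive zero gradient.\n  for (size_t i = 0; i < mask.size(); i++) { dx[i] = mask[i] ? 0.0 : dy[i]; }\n  assert(dx[0] == 0.0 && dx[1] == 0.2 && dx[2] == 0.0 && dx[3] == 0.4);\n  return 0;\n}\n"
  },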
  {
    "path": "oneflow/core/autograd/gradient_funcs/math_binary_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/user/ops/math_binary_elementwise_seq.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct BinaryMathCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad;\n  bool y_requires_grad;\n};\n\ntypedef Maybe<one::Tensor> (*BinaryBwFunc)(const std::shared_ptr<one::Tensor>&,\n                                           const std::shared_ptr<one::Tensor>&,\n                                           const std::shared_ptr<one::Tensor>&);\n\ntemplate<BinaryBwFunc BwXFunc, BinaryBwFunc BwYFunc>\nclass BinaryMathOp : public OpExprGradFunction<BinaryMathCaptureState> {\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(BinaryMathCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    ctx->y_requires_grad = inputs.at(1)->requires_grad();\n    ctx->SaveTensorForBackward(inputs.at(0));\n    ctx->SaveTensorForBackward(inputs.at(1));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const BinaryMathCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!(ctx->x_requires_grad || ctx->y_requires_grad)) { return Maybe<void>::Ok(); }\n    in_grads->resize(2);\n    const std::shared_ptr<one::Tensor>& x = ctx->SavedTensors().at(0);\n    const std::shared_ptr<one::Tensor>& y = ctx->SavedTensors().at(1);\n    if (ctx->x_requires_grad) { in_grads->at(0) = JUST(BwXFunc(x, y, out_grads.at(0))); }\n    if (ctx->y_requires_grad) { in_grads->at(1) = JUST(BwYFunc(x, y, out_grads.at(0))); }\n    return Maybe<void>::Ok();\n  }\n};\n\n#define INSTANTIAT_AND_REGISTER_BINARY_MATHOP_CLASS(op_type_name, op_cls)             \\\n  class op_cls##Cls final                                                             \\\n      : public BinaryMathOp<functional::op_cls##XGrad, functional::op_cls##YGrad> {}; \\\n  REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##Cls);\n\nOF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_BINARY_MATHOP_CLASS, MATH_BINARY_ELEMENTWISE_FUNC_SEQ);\n\n#undef INSTANTIAT_AND_REGISTER_BINARY_MATHOP_CLASS\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/math_unary_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/user/ops/math_unary_elementwise_seq.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct UnaryMathCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad;\n};\n\ntypedef Maybe<one::Tensor> (*UnaryBwFunc)(const std::shared_ptr<one::Tensor>&,\n                                          const std::shared_ptr<one::Tensor>&);\n\ntemplate<UnaryBwFunc BwFunc>\nclass UnaryMathBwdWithDyXOp : public OpExprGradFunction<UnaryMathCaptureState> {\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(UnaryMathCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    ctx->SaveTensorForBackward(inputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const UnaryMathCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->x_requires_grad) { return Maybe<void>::Ok(); }\n    const auto& x = ctx->SavedTensors().at(0);\n    in_grads->at(0) = JUST(BwFunc(x, out_grads.at(0)));\n    return Maybe<void>::Ok();\n  }\n\n protected:\n  std::shared_ptr<OpExpr> grad_op_;\n};\n\ntemplate<UnaryBwFunc BwFunc>\nclass UnaryMathBwdWithDyYOp : public OpExprGradFunction<UnaryMathCaptureState> {\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(UnaryMathCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    ctx->SaveTensorForBackward(outputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const UnaryMathCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->x_requires_grad) { return Maybe<void>::Ok(); }\n    const auto& y = ctx->SavedTensors().at(0);\n    in_grads->at(0) = JUST(BwFunc(y, out_grads.at(0)));\n    return Maybe<void>::Ok();\n  }\n\n protected:\n  std::shared_ptr<OpExpr> grad_op_;\n};\n\nclass UnaryMathBwdWithFillZeroOp : public OpExprGradFunction<UnaryMathCaptureState> {\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(UnaryMathCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const UnaryMathCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const 
override {\n    if (!ctx->x_requires_grad) { return Maybe<void>::Ok(); }\n    in_grads->at(0) = JUST(functional::ZerosLike(out_grads[0]));\n    return Maybe<void>::Ok();\n  }\n\n protected:\n  std::shared_ptr<OpExpr> grad_op_;\n};\n\n#define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_X_CLASS(op_type_name, op_cls)     \\\n  class op_cls##Cls final : public UnaryMathBwdWithDyXOp<functional::op_cls##Grad> {}; \\\n  REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##Cls);\n\nOF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_X_CLASS,\n                     MATH_UNARY_ELEMENTWISE_PRIMITIVE_FUNC_BWD_WITH_DY_X_SEQ);\n\n#undef INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_X_CLASS\n\n#define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_Y_CLASS(op_type_name, op_cls)     \\\n  class op_cls##Cls final : public UnaryMathBwdWithDyYOp<functional::op_cls##Grad> {}; \\\n  REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##Cls);\n\nOF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_Y_CLASS,\n                     MATH_UNARY_ELEMENTWISE_FUNC_BWD_WITH_DY_Y_SEQ);\n\nOF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_Y_CLASS,\n                     OF_PP_MAKE_TUPLE_SEQ(\"tanh\", Tanh));\n\n#undef INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_Y_CLASS\n\n// All ops in the fill seq share UnaryMathBwdWithFillZeroOp, so no per-op grad class is defined.\n#define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_FILL_CLASS(op_type_name, op_cls) \\\n  REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, UnaryMathBwdWithFillZeroOp);\n\nOF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_FILL_CLASS,\n                     MATH_UNARY_ELEMENTWISE_FUNC_BWD_WITH_FILL_SEQ);\n#undef INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_FILL_CLASS\n\nclass NegativeOp : public OpExprGradFunction<UnaryMathCaptureState> {\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(UnaryMathCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const UnaryMathCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->x_requires_grad) { return Maybe<void>::Ok(); }\n    in_grads->at(0) = JUST(functional::Negative(out_grads[0]));\n    return Maybe<void>::Ok();\n  }\n\n protected:\n  std::shared_ptr<OpExpr> grad_op_;\n};\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"negative\", NegativeOp);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/matmul.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct MatmulCaptureState : public AutoGradCaptureState {\n  bool transpose_a;\n  bool transpose_b;\n  double alpha;\n  bool requires_grad_a;\n  bool requires_grad_b;\n  size_t a_index;\n  size_t b_index;\n};\n\nclass Matmul : public OpExprGradFunction<MatmulCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(MatmulCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const MatmulCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n protected:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Matmul::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Matmul::Capture(MatmulCaptureState* ctx, const TensorTuple& inputs,\n                            const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad_a = inputs.at(0)->requires_grad();\n  ctx->requires_grad_b = inputs.at(1)->requires_grad();\n  if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->transpose_a = JUST(composed_attrs.GetAttr<bool>(\"transpose_a\"));\n  ctx->transpose_b = JUST(composed_attrs.GetAttr<bool>(\"transpose_b\"));\n  ctx->alpha = JUST(composed_attrs.GetAttr<double>(\"alpha\"));\n  if (ctx->requires_grad_a) {\n    ctx->b_index = ctx->SaveTensorForBackward(inputs.at(1));  // input b\n  }\n  if (ctx->requires_grad_b) {\n    ctx->a_index = ctx->SaveTensorForBackward(inputs.at(0));  // input a\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Matmul::Apply(const MatmulCaptureState* ctx, const TensorTuple& out_grads,\n                          TensorTuple* in_grads) const {\n  if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n  in_grads->resize(2);\n  if (ctx->requires_grad_a) {\n    const auto& input_b = ctx->SavedTensors().at(ctx->b_index);\n    if (ctx->transpose_a) {\n      in_grads->at(0) =\n          JUST(functional::MatMul(input_b, out_grads.at(0), ctx->transpose_b, true, ctx->alpha));\n    } else {\n      in_grads->at(0) = JUST(\n          functional::MatMul(out_grads.at(0), input_b, 
false, !(ctx->transpose_b), ctx->alpha));\n    }\n  }\n\n  if (ctx->requires_grad_b) {\n    const auto& input_a = ctx->SavedTensors().at(ctx->a_index);\n    if (ctx->transpose_b) {\n      in_grads->at(1) =\n          JUST(functional::MatMul(out_grads.at(0), input_a, true, ctx->transpose_a, ctx->alpha));\n    } else {\n      in_grads->at(1) = JUST(\n          functional::MatMul(input_a, out_grads.at(0), !(ctx->transpose_a), false, ctx->alpha));\n    }\n  }\n\n  return Maybe<void>::Ok();\n}\n\nstruct BroadcastMatmulCaptureState : public AutoGradCaptureState {\n  bool transpose_a = false;\n  bool transpose_b = false;\n  double alpha = 1.0;\n  bool requires_grad_a = true;\n  bool requires_grad_b = true;\n  size_t a_index = 0;\n  size_t b_index = 1;\n  bool broadcast_a = false;\n  bool broadcast_b = false;\n  int64_t b_num_axes = 0;\n};\n\nclass BroadcastMatmul : public OpExprGradFunction<BroadcastMatmulCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(BroadcastMatmulCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const BroadcastMatmulCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n protected:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> BroadcastMatmul::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr) << \"fw_op_expr should not be null. \";\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BroadcastMatmul::Capture(BroadcastMatmulCaptureState* ctx, const TensorTuple& inputs,\n                                     const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad_a = JUST(VectorAt(inputs, 0))->requires_grad();\n  ctx->requires_grad_b = JUST(VectorAt(inputs, 1))->requires_grad();\n  if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }\n\n  const auto a_shape = JUST(VectorAt(inputs, 0))->shape();\n  const auto b_shape = JUST(VectorAt(inputs, 1))->shape();\n\n  const int64_t a_num_axes = a_shape->NumAxes();\n  const int64_t b_num_axes = b_shape->NumAxes();\n\n  const size_t num_max_batch_dims = std::max(a_num_axes, b_num_axes) - 2;\n  auto MakeGetBatchDim = [num_max_batch_dims](size_t num_dims, const Shape& shape_dim) {\n    const int64_t num_batch_dims = num_dims - 2;\n    const int64_t num_padding_dims = num_max_batch_dims - num_batch_dims;\n    return [num_padding_dims, shape_dim](size_t index) {\n      return index < num_padding_dims ? 
1 : shape_dim.At(index - num_padding_dims);\n    };\n  };\n  auto GetABatchDim = MakeGetBatchDim(a_num_axes, *a_shape);\n  auto GetBBatchDim = MakeGetBatchDim(b_num_axes, *b_shape);\n  bool broadcast_a = false;\n  bool broadcast_b = false;\n\n  for (int32_t i = 0; i < num_max_batch_dims; i++) {\n    if (GetABatchDim(i) < GetBBatchDim(i) || a_num_axes < b_num_axes) {\n      broadcast_a = true;\n      break;\n    }\n  }\n\n  for (int32_t i = 0; i < num_max_batch_dims; i++) {\n    if (GetBBatchDim(i) < GetABatchDim(i) || b_num_axes < a_num_axes) {\n      broadcast_b = true;\n      break;\n    }\n  }\n\n  // Read the attrs before they are consulted below: the broadcast_b decision depends on\n  // ctx->transpose_a, so it must already hold the op's attribute value here.\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->transpose_a = JUST(composed_attrs.GetAttr<bool>(\"transpose_a\"));\n  ctx->transpose_b = JUST(composed_attrs.GetAttr<bool>(\"transpose_b\"));\n  ctx->alpha = JUST(composed_attrs.GetAttr<double>(\"alpha\"));\n\n  if (b_num_axes == 2 && !ctx->transpose_a) {\n    // In this case, we can directly use `broadcast_matmul_grad_b` OP to generate Grad instead of\n    // broadcast_matmul+reduce_sum_like.\n    broadcast_b = false;\n  }\n\n  ctx->broadcast_a = broadcast_a;\n  ctx->broadcast_b = broadcast_b;\n\n  if (ctx->requires_grad_a) {\n    ctx->b_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));  // input b\n    if (broadcast_a) {\n      ctx->a_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));  // input a\n    }\n  }\n\n  if (ctx->requires_grad_b) {\n    ctx->b_num_axes = JUST(VectorAt(inputs, 1))->shape()->NumAxes();\n    ctx->a_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));  // input a\n    if (broadcast_b) {\n      ctx->b_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));  // input b\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BroadcastMatmul::Apply(const BroadcastMatmulCaptureState* ctx,\n                                   const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1) << \"Out grad size should be equal to 1. \";\n  in_grads->resize(2);\n  const auto out_shape = JUST(VectorAt(out_grads, 0))->shape();\n  const int64_t out_num_axes = out_shape->NumAxes();\n  const size_t num_max_batch_dims = out_num_axes - 2;\n  auto MakeGetBatchDim = [num_max_batch_dims](size_t num_dims, const Shape& shape_dim) {\n    const int64_t num_batch_dims = num_dims - 2;\n    const int64_t num_padding_dims = num_max_batch_dims - num_batch_dims;\n    return [num_padding_dims, shape_dim](size_t index) {\n      return index < num_padding_dims ? 
1 : shape_dim.At(index - num_padding_dims);\n    };\n  };\n  auto GetOutBatchDim = MakeGetBatchDim(out_num_axes, *out_shape);\n  if (ctx->requires_grad_a) {\n    std::shared_ptr<Tensor> broadcast_grad_a;\n    const auto& input_b = ctx->SavedTensors().at(ctx->b_index);\n    if (ctx->transpose_a) {\n      broadcast_grad_a = JUST(functional::MatMul(input_b, JUST(VectorAt(out_grads, 0)),\n                                                 ctx->transpose_b, true, ctx->alpha));\n    } else {\n      broadcast_grad_a = JUST(functional::MatMul(JUST(VectorAt(out_grads, 0)), input_b, false,\n                                                 !(ctx->transpose_b), ctx->alpha));\n    }\n    if (ctx->broadcast_a) {\n      const auto& input_a = JUST(VectorAt(ctx->SavedTensors(), ctx->a_index));\n      const auto a_shape = input_a->shape();\n      const int64_t a_num_axes = a_shape->NumAxes();\n\n      std::vector<int32_t> a_reduce_vec;\n      auto GetABatchDim = MakeGetBatchDim(a_num_axes, *a_shape);\n      const int64_t a_out_num_dim_differ = out_num_axes - a_num_axes;\n      for (int32_t i = 0; i < out_num_axes - 2; i++) {\n        if (GetOutBatchDim(i) > GetABatchDim(i)\n            || (GetOutBatchDim(i) == 1 && i < a_out_num_dim_differ)) {\n          a_reduce_vec.push_back(i);\n        }\n      }\n      JUST(VectorAt(*in_grads, 0)) =\n          JUST(functional::ReduceSumLike(broadcast_grad_a, input_a, a_reduce_vec));\n    } else {\n      JUST(VectorAt(*in_grads, 0)) = broadcast_grad_a;\n    }\n  }\n\n  if (ctx->requires_grad_b) {\n    const auto& input_a = ctx->SavedTensors().at(ctx->a_index);\n    if (ctx->b_num_axes == 2 && !ctx->transpose_a) {\n      if (ctx->transpose_b) {\n        JUST(VectorAt(*in_grads, 1)) = JUST(\n            functional::BroadcastMatmulGradB(JUST(VectorAt(out_grads, 0)), input_a, ctx->alpha));\n      } else {\n        JUST(VectorAt(*in_grads, 1)) = JUST(\n            functional::BroadcastMatmulGradB(input_a, JUST(VectorAt(out_grads, 0)), ctx->alpha));\n      }\n    } else {\n      std::shared_ptr<Tensor> broadcast_grad_b;\n      if (ctx->transpose_b) {\n        broadcast_grad_b = JUST(functional::MatMul(JUST(VectorAt(out_grads, 0)), input_a, true,\n                                                   ctx->transpose_a, ctx->alpha));\n      } else {\n        broadcast_grad_b = JUST(functional::MatMul(input_a, JUST(VectorAt(out_grads, 0)),\n                                                   !ctx->transpose_a, false, ctx->alpha));\n      }\n      if (ctx->broadcast_b) {\n        const auto& input_b = JUST(VectorAt(ctx->SavedTensors(), ctx->b_index));\n        const auto b_shape = input_b->shape();\n        std::vector<int32_t> b_reduce_vec;\n        auto GetBBatchDim = MakeGetBatchDim(ctx->b_num_axes, *b_shape);\n        const int64_t b_out_num_dim_differ = out_num_axes - ctx->b_num_axes;\n        for (int32_t i = 0; i < out_num_axes - 2; i++) {\n          if (GetOutBatchDim(i) > GetBBatchDim(i)\n              || (GetOutBatchDim(i) == 1 && i < b_out_num_dim_differ)) {\n            b_reduce_vec.push_back(i);\n          }\n        }\n        JUST(VectorAt(*in_grads, 1)) =\n            JUST(functional::ReduceSumLike(broadcast_grad_b, input_b, b_reduce_vec));\n      } else {\n        JUST(VectorAt(*in_grads, 1)) = broadcast_grad_b;\n      }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"matmul\", Matmul);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"batch_matmul\", Matmul);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"broadcast_matmul\", BroadcastMatmul);\n\n}  // namespace 
one\n}  // namespace oneflow\n"
  },
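  {
    "path": "docs/gradient_sketches/matmul_grad_check.cpp",
    "content": "// NOTE: hypothetical illustration only -- neither this path nor this file is part of the\n// OneFlow build. A minimal numeric check of the non-transposed case handled by Matmul::Apply\n// above: for out = alpha * (a @ b), the backward formulas are da = alpha * (dy @ b^T) and,\n// symmetrically, db = alpha * (a^T @ dy); only da is verified here.\n#include <cassert>\n#include <cmath>\n\n// loss(a) = sum_ij dy[i][j] * out[i][j] with out = alpha * (a @ b).\nstatic double Loss(const double a[2][3], const double b[3][2], const double dy[2][2],\n                   double alpha) {\n  double loss = 0.0;\n  for (int i = 0; i < 2; i++) {\n    for (int j = 0; j < 2; j++) {\n      double out = 0.0;\n      for (int k = 0; k < 3; k++) { out += a[i][k] * b[k][j]; }\n      loss += dy[i][j] * alpha * out;\n    }\n  }\n  return loss;\n}\n\nint main() {\n  const double alpha = 2.0, eps = 1e-6;\n  double a[2][3] = {{1, 2, 3}, {4, 5, 6}};\n  const double b[3][2] = {{1, 0}, {0, 1}, {1, 1}};\n  const double dy[2][2] = {{0.1, 0.2}, {0.3, 0.4}};\n  for (int i = 0; i < 2; i++) {\n    for (int k = 0; k < 3; k++) {\n      // Analytic entry: da[i][k] = alpha * sum_j dy[i][j] * b[k][j], i.e. da = alpha * dy @ b^T.\n      double da = 0.0;\n      for (int j = 0; j < 2; j++) { da += alpha * dy[i][j] * b[k][j]; }\n      // Central finite-difference reference.\n      a[i][k] += eps;\n      const double lp = Loss(a, b, dy, alpha);\n      a[i][k] -= 2 * eps;\n      const double lm = Loss(a, b, dy, alpha);\n      a[i][k] += eps;\n      assert(std::fabs(da - (lp - lm) / (2 * eps)) < 1e-4);\n    }\n  }\n  return 0;\n}\n"
  },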
  {
    "path": "oneflow/core/autograd/gradient_funcs/matrix_vector_product.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct MatrixVectorProductCaptureState : public AutoGradCaptureState {\n  bool requires_grad_a = false;\n  bool requires_grad_b = false;\n  size_t a_index = 0;\n  size_t b_index = 1;\n};\n\nclass MatrixVectorProduct : public OpExprGradFunction<MatrixVectorProductCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(MatrixVectorProductCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const MatrixVectorProductCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n protected:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> MatrixVectorProduct::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr) << \"fw_op_expr should not be null. \";\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MatrixVectorProduct::Capture(MatrixVectorProductCaptureState* ctx,\n                                         const TensorTuple& inputs, const TensorTuple& outputs,\n                                         const AttrMap& attrs) const {\n  ctx->requires_grad_a = JUST(VectorAt(inputs, 0))->requires_grad();\n  ctx->requires_grad_b = JUST(VectorAt(inputs, 1))->requires_grad();\n  if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  if (ctx->requires_grad_a) {\n    ctx->b_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));  // input b\n  }\n  if (ctx->requires_grad_b) {\n    ctx->a_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));  // input a\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MatrixVectorProduct::Apply(const MatrixVectorProductCaptureState* ctx,\n                                       const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1) << \"Out grad size should be equal to 1. 
\";\n\n  in_grads->resize(2);\n  if (ctx->requires_grad_a) {\n    const auto& input_b = JUST(VectorAt(ctx->SavedTensors(), ctx->b_index));\n    JUST(VectorAt(*in_grads, 0)) =\n        JUST(functional::MatrixVectorProductGradA(JUST(VectorAt(out_grads, 0)), input_b));\n  }\n\n  if (ctx->requires_grad_b) {\n    const auto& input_a = JUST(VectorAt(ctx->SavedTensors(), ctx->a_index));\n    JUST(VectorAt(*in_grads, 1)) =\n        JUST(functional::MatrixVectorProductGradB(JUST(VectorAt(out_grads, 0)), input_a));\n    if (input_a->dtype()->is_complex()) {\n      JUST(VectorAt(*in_grads, 1)) = JUST(functional::Conj(JUST(VectorAt(*in_grads, 1))));\n    }\n  }\n\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"matrix_vector_product\", MatrixVectorProduct);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
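  {
    "path": "docs/gradient_sketches/matrix_vector_product_grad_check.cpp",
    "content": "// NOTE: hypothetical illustration only -- neither this path nor this file is part of the\n// OneFlow build. A minimal sketch spelling out the formulas behind\n// functional::MatrixVectorProductGradA/GradB used above, for real-valued inputs: with\n// y = A @ v, dA = outer(dy, v) and dv = A^T @ dy (the extra Conj in Apply only matters for\n// complex dtypes).\n#include <cassert>\n#include <cmath>\n\nint main() {\n  const double A[2][3] = {{1, 2, 3}, {4, 5, 6}};\n  const double v[3] = {0.5, -1.0, 2.0};\n  const double dy[2] = {0.1, 0.2};\n  // dA = outer(dy, v): dA[i][k] = dy[i] * v[k].\n  double dA[2][3];\n  for (int i = 0; i < 2; i++) {\n    for (int k = 0; k < 3; k++) { dA[i][k] = dy[i] * v[k]; }\n  }\n  // dv = A^T @ dy: dv[k] = sum_i A[i][k] * dy[i].\n  double dv[3];\n  for (int k = 0; k < 3; k++) {\n    dv[k] = 0.0;\n    for (int i = 0; i < 2; i++) { dv[k] += A[i][k] * dy[i]; }\n  }\n  // Spot-check entries against hand-computed values.\n  assert(std::fabs(dA[1][2] - 0.2 * 2.0) < 1e-12);\n  assert(std::fabs(dv[0] - (1 * 0.1 + 4 * 0.2)) < 1e-12);\n  return 0;\n}\n"
  },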
  {
    "path": "oneflow/core/autograd/gradient_funcs/max_pool.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nnamespace {\n\nstruct MaxPoolCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  size_t input_index = 0;\n  size_t indice_index = 0;\n\n  std::string data_format;\n  std::vector<int32_t> padding;\n  std::vector<int32_t> kernel_size;\n  std::vector<int32_t> stride;\n  std::vector<int32_t> dilation;\n  bool return_indices = false;\n  bool ceil_mode = false;\n};\n\nclass MaxPoolNdGrad : public OpExprGradFunction<MaxPoolCaptureState> {\n public:\n  virtual ~MaxPoolNdGrad() = default;\n\n  using OpExprGradFunction<MaxPoolCaptureState>::Init;\n\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(MaxPoolCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const MaxPoolCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> MaxPoolNdGrad::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MaxPoolNdGrad::Capture(MaxPoolCaptureState* ctx, const TensorTuple& inputs,\n                                   const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));\n  ctx->indice_index = ctx->SaveTensorForBackward(outputs.at(1));\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n  ctx->padding = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"padding\"));\n  ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"kernel_size\"));\n  ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"stride\"));\n  ctx->dilation = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"dilation\"));\n  ctx->return_indices = JUST(composed_attrs.GetAttr<bool>(\"return_indices\"));\n  ctx->ceil_mode = JUST(composed_attrs.GetAttr<bool>(\"ceil_mode\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MaxPoolNdGrad::Apply(const MaxPoolCaptureState* ctx, const TensorTuple& out_grads,\n                                 TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return 
Maybe<void>::Ok(); }\n  CHECK_LE_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)\n\n  int32_t ndims = ctx->kernel_size.size();\n  const auto& input = ctx->SavedTensors().at(ctx->input_index);\n  const auto& indice = ctx->SavedTensors().at(ctx->indice_index);\n\n  in_grads->resize(1);\n  (*in_grads)[0] = JUST(functional::MaxPoolNdGrad(\n      input, indice, out_grads[0], ndims, ctx->data_format, ctx->padding, ctx->kernel_size,\n      ctx->stride, ctx->dilation, ctx->return_indices, ctx->ceil_mode));\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"max_pool_1d\", MaxPoolNdGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"max_pool_2d\", MaxPoolNdGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"max_pool_3d\", MaxPoolNdGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/max_unpool.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nnamespace {\n\nstruct MaxUnpoolCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  size_t input_index = 0;\n  size_t indices_index = 0;\n};\n\nusing FuncType = decltype(functional::MaxUnpool1dGrad);\n\ntemplate<FuncType F>\nclass MaxUnpoolNdGrad : public OpExprGradFunction<MaxUnpoolCaptureState> {\n public:\n  virtual ~MaxUnpoolNdGrad() = default;\n\n  using OpExprGradFunction<MaxUnpoolCaptureState>::Init;\n\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(MaxUnpoolCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const MaxUnpoolCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\ntemplate<FuncType F>\nMaybe<void> MaxUnpoolNdGrad<F>::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\ntemplate<FuncType F>\nMaybe<void> MaxUnpoolNdGrad<F>::Capture(MaxUnpoolCaptureState* ctx, const TensorTuple& inputs,\n                                        const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));\n  ctx->indices_index = ctx->SaveTensorForBackward(inputs.at(1));\n  return Maybe<void>::Ok();\n}\n\ntemplate<FuncType F>\nMaybe<void> MaxUnpoolNdGrad<F>::Apply(const MaxUnpoolCaptureState* ctx,\n                                      const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_LE_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)\n\n  const auto& input = ctx->SavedTensors().at(ctx->input_index);\n  const auto& indices = ctx->SavedTensors().at(ctx->indices_index);\n\n  in_grads->resize(2);\n  (*in_grads)[0] = JUST(F(input, indices, out_grads[0]));\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"max_unpool_1d\", MaxUnpoolNdGrad<functional::MaxUnpool1dGrad>);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"max_unpool_2d\", MaxUnpoolNdGrad<functional::MaxUnpool2dGrad>);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"max_unpool_3d\", MaxUnpoolNdGrad<functional::MaxUnpool3dGrad>);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/median.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct MedianCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n};\n\nclass Median : public OpExprGradFunction<MedianCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n  Maybe<void> Capture(MedianCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();\n    if (ctx->requires_grad) {\n      ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));\n      ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0)));\n    }\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Apply(const MedianCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (ctx->requires_grad) {\n      const auto& input = JUST(VectorAt(ctx->SavedTensors(), 0));\n      const auto& output = JUST(VectorAt(ctx->SavedTensors(), 1));\n      const auto& dy = JUST(VectorAt(out_grads, 0));\n      std::vector<int32_t> axis(input->ndim());\n      std::iota(axis.begin(), axis.end(), 0);\n      const auto cast_like =\n          JUST(functional::SequenceFunction<Maybe<Tensor>()>(\n                   [&]() { return functional::BroadcastLike(output, input, axis); })\n                   .then(std::bind(functional::BroadcastEqual, input, std::placeholders::_1))\n                   .then(std::bind(functional::CastLike, std::placeholders::_1, input))\n                   .call());\n\n      const auto bcast_like_div =\n          JUST(functional::SequenceFunction<Maybe<Tensor>()>(\n                   [&]() { return functional::ReduceSum(cast_like, axis, false, NullOpt); })\n                   .then(std::bind(functional::Div, dy, std::placeholders::_1))\n                   .then(std::bind(functional::BroadcastLike, std::placeholders::_1, input, axis))\n                   .call());\n\n      in_grads->resize(1);\n      JUST(VectorAt(*in_grads, 0)) = JUST(functional::Mul(bcast_like_div, cast_like));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nstruct MedianWithIndicesCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n};\n\nclass MedianWithIndices : public OpExprGradFunction<MedianWithIndicesCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n  Maybe<void> Capture(MedianWithIndicesCaptureState* ctx, const TensorTuple& inputs,\n                      
const TensorTuple& outputs, const AttrMap& attrs) const override {\n    ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();\n    if (ctx->requires_grad) {\n      ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));\n      ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 1)));\n    }\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Apply(const MedianWithIndicesCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (ctx->requires_grad) {\n      in_grads->resize(1);\n      const auto& input = JUST(VectorAt(ctx->SavedTensors(), 0));\n      const auto& indices = JUST(functional::Unsqueeze(JUST(VectorAt(ctx->SavedTensors(), 1)), -1));\n      const auto& dout = JUST(functional::Unsqueeze(JUST(VectorAt(out_grads, 0)), -1));\n      JUST(VectorAt(*in_grads, 0)) = JUST(functional::DimScatterUpdate(\n          JUST(functional::Constant(*(input->shape()), Scalar(0), *dout->dtype(),\n                                    JUST(dout->device()))),\n          -1, indices, dout, /*inplace*/ false));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"median\", Median);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"median_with_indices\", MedianWithIndices);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/mode.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ModeCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n};\n\nclass Mode : public OpExprGradFunction<ModeCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n  Maybe<void> Capture(ModeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();\n    if (ctx->requires_grad) {\n      ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));\n      ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 1)));\n    }\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Apply(const ModeCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (ctx->requires_grad) {\n      in_grads->resize(1);\n      const auto& input = JUST(VectorAt(ctx->SavedTensors(), 0));\n      const auto& indices = JUST(functional::Unsqueeze(JUST(VectorAt(ctx->SavedTensors(), 1)), -1));\n      const auto& dout = JUST(functional::Unsqueeze(JUST(VectorAt(out_grads, 0)), -1));\n      JUST(VectorAt(*in_grads, 0)) = JUST(functional::DimScatterUpdate(\n          JUST(functional::Constant(*(input->shape()), Scalar(0), *dout->dtype(),\n                                    JUST(dout->device()))),\n          -1, indices, dout, /*inplace*/ false));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"mode\", Mode);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/narrow.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct NarrowCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n  Shape shape;\n  int64_t dim;\n  int64_t start;\n  int64_t length;\n};\n\nclass Narrow : public OpExprGradFunction<NarrowCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(NarrowCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->dim = JUST(composed_attrs.GetAttr<int64_t>(\"dim\"));\n    ctx->start = JUST(composed_attrs.GetAttr<int64_t>(\"start\"));\n    ctx->length = JUST(composed_attrs.GetAttr<int64_t>(\"length\"));\n    if (LazyMode::is_enabled()) {\n      ctx->SaveTensorForBackward(inputs.at(0));\n    } else {\n      ctx->shape = *(inputs.at(0)->shape());\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const NarrowCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    const auto& dy = out_grads.at(0);\n    if (ctx->requires_grad) {\n      std::shared_ptr<Tensor> like;\n      if (LazyMode::is_enabled()) {\n        like = ctx->SavedTensors().at(0);\n      } else if (dy->is_local()) {\n        like = JUST(functional::Empty(ctx->shape, dy->dtype(), JUST(dy->device()),\n                                      ctx->requires_grad, /*pin_memory=*/false));\n      } else {\n        like = JUST(\n            functional::GlobalEmpty(ctx->shape, dy->dtype(), JUST(dy->parallel_desc()),\n                                    *JUST(private_details::RawGetSbpList(JUST(dy->nd_sbp())))));\n      }\n      in_grads->resize(1);\n      in_grads->at(0) = JUST(functional::NarrowGrad(dy, like, ctx->dim, ctx->start, ctx->length));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"narrow\", Narrow);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/nll.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\n\nnamespace one {\n\nstruct NLLCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  int64_t ignore_index = -100;\n};\n\nclass NLLGradFunction : public OpExprGradFunction<NLLCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(NLLCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> NLLGradFunction::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NLLGradFunction::Capture(NLLCaptureState* ctx, const TensorTuple& inputs,\n                                     const TensorTuple& outputs, const AttrMap& attrs) const {\n  auto input = JUST(VectorAt(inputs, 0));\n  ctx->requires_grad = input->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->ignore_index = JUST(composed_attrs.GetAttr<int64_t>(\"ignore_index\"));\n  ctx->SaveTensorForBackward(input);                      // input\n  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));  // target\n  if (inputs.size() == 3) {\n    ctx->SaveTensorForBackward(inputs[2]);  // weight\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NLLGradFunction::Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads,\n                                   TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)\n  CHECK_GE_OR_RETURN(ctx->SavedTensors().size(), 2)\n      << Error::RuntimeError()\n      << \"The number of saved tensors is expected to be greater than or equal to 2, but got \"\n      << ctx->SavedTensors().size();\n  const auto& out_grad = out_grads[0];\n  const auto& input = ctx->SavedTensors()[0];\n  const auto& target = ctx->SavedTensors()[1];\n\n  in_grads->resize(ctx->SavedTensors().size());\n\n  if (ctx->SavedTensors().size() == 2) {\n    JUST(VectorAt(*in_grads, 0)) =\n        JUST(functional::NLLGrad(out_grad, input, target, NullOpt, ctx->ignore_index));\n  } else {\n    // has weight\n    auto weight = JUST(VectorAt(ctx->SavedTensors(), 2));\n    JUST(VectorAt(*in_grads, 0)) =\n        JUST(functional::NLLGrad(out_grad, input, target, weight, ctx->ignore_index));\n  }\n\n  return 
Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"nll\", NLLGradFunction);\n\n}  // namespace one\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/noncontiguous_binary_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <glog/logging.h>\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct NonContiguousBinaryOpCaptureState : public AutoGradCaptureState {\n  bool lhs_requires_grad = false;\n  bool rhs_requires_grad = false;\n  std::string op = \"add\";\n  bool inplace = false;\n};\n\nclass NonContiguousBinaryOp : public OpExprGradFunction<NonContiguousBinaryOpCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(NonContiguousBinaryOpCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const NonContiguousBinaryOpCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> NonContiguousBinaryOp::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NonContiguousBinaryOp::Capture(NonContiguousBinaryOpCaptureState* ctx,\n                                           const TensorTuple& inputs, const TensorTuple& outputs,\n                                           const AttrMap& attrs) const {\n  ctx->lhs_requires_grad = inputs.at(0)->requires_grad();\n  ctx->rhs_requires_grad = inputs.at(1)->requires_grad();\n  if (!ctx->lhs_requires_grad && !ctx->rhs_requires_grad) { return Maybe<void>::Ok(); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->inplace = JUST(composed_attrs.GetAttr<bool>(\"inplace\"));\n  ctx->op = JUST(composed_attrs.GetAttr<std::string>(\"op\"));\n  if (ctx->inplace && ctx->rhs_requires_grad) {\n    CHECK_OR_RETURN(ctx->op == \"add\" || ctx->op == \"sub\")\n        << \"when inplace and rhs requires grad, op should be add/sub\";\n  }\n  ctx->SaveTensorForBackward(inputs.at(0));\n  ctx->SaveTensorForBackward(inputs.at(1));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NonContiguousBinaryOp::Apply(const NonContiguousBinaryOpCaptureState* ctx,\n                                         const TensorTuple& out_grads,\n                                         TensorTuple* in_grads) const {\n  if (!ctx->lhs_requires_grad && !ctx->rhs_requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(2);\n  auto 
lhs = ctx->SavedTensors().at(0);\n  auto rhs = ctx->SavedTensors().at(1);\n  auto ret = JUST(functional::NonContiguousBinaryOpGrad(out_grads.at(0), lhs, rhs, ctx->op, false));\n  if (ctx->lhs_requires_grad) in_grads->at(0) = ret->at(0);\n  if (ctx->rhs_requires_grad) in_grads->at(1) = ret->at(1);\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"noncontiguous_binary_op\", NonContiguousBinaryOp);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/normalization.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct NormalizationGradCaptureState : public AutoGradCaptureState {\n  int32_t axis;\n  float epsilon;\n  bool track_running_stats;\n  bool is_training;\n  bool x_requires_grad;\n  bool gamma_requires_grad;\n  bool beta_requires_grad;\n};\n\n// training:\n// y, mean, inv_variance = normalization(x, moving_mean, moving_variance, gamma, beta,\n// axis=1, epsilon=0.01, momentum=0.9)\n// y, mean, inv_variance = normalization(x, gamma, beta, axis=1, epsilon=0.01, momentum=0.9)\n\n// inference:\n// y = normalization(x, moving_mean, moving_variance, gamma, beta, axis=1, epsilon=0.01,\n// momentum=0.9)\nclass NormalizationGrad : public OpExprGradFunction<NormalizationGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(NormalizationGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // input_size may be 3 or 5, as inputs may be\n    // (x, gamma, beta) or (x, moving_mean, moving_variance, gamma, beta)\n    // ref to track_running_stats false/true\n    // output_size may be 1 or 3, as outputs may be\n    // (x, ) or (x, mean, inv_variance)\n    // ref to is_training false/true\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    std::shared_ptr<Tensor> gamma, beta;\n    if (inputs.size() == 3) {\n      gamma = inputs.at(1);\n      beta = inputs.at(2);\n      ctx->track_running_stats = false;\n    } else {\n      CHECK_EQ_OR_RETURN(inputs.size(), 5);  // NOLINT(maybe-need-error-msg)\n      gamma = inputs.at(3);\n      beta = inputs.at(4);\n      ctx->track_running_stats = true;\n    }\n    ctx->gamma_requires_grad = gamma->requires_grad();\n    ctx->beta_requires_grad = beta->requires_grad();\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n\n    ctx->axis = JUST(composed_attrs.GetAttr<int32_t>(\"axis\"));\n    ctx->epsilon = JUST(composed_attrs.GetAttr<float>(\"epsilon\"));\n    ctx->is_training = JUST(composed_attrs.GetAttr<bool>(\"training\"));\n    ctx->SaveTensorForBackward(inputs.at(0));  // x\n    ctx->SaveTensorForBackward(gamma);         // gamma\n    if (ctx->is_training || !ctx->track_running_stats) {\n      ctx->SaveTensorForBackward(outputs.at(1));  // mean\n      ctx->SaveTensorForBackward(outputs.at(2));  // inv_variance\n    } else {\n      ctx->SaveTensorForBackward(inputs.at(1));  // moving_mean\n      
ctx->SaveTensorForBackward(inputs.at(2));  // moving_variance\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const NormalizationGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    const auto& x = ctx->SavedTensors().at(0);      // x\n    const auto& gamma = ctx->SavedTensors().at(1);  // gamma\n    const auto& y_grad = out_grads.at(0);\n\n    std::shared_ptr<Tensor> mean, inv_variance;\n    if (ctx->is_training || !ctx->track_running_stats) {\n      mean = ctx->SavedTensors().at(2);          // mean\n      inv_variance = ctx->SavedTensors().at(3);  // inv_variance\n    } else {\n      const auto& moving_mean = ctx->SavedTensors().at(2);      // moving_mean\n      const auto& moving_variance = ctx->SavedTensors().at(3);  // moving_variance\n      const auto& add_eps = JUST(\n          functional::ScalarAdd(moving_variance, ctx->epsilon, /*alpha=*/1, /*inplace=*/false));\n      mean = moving_mean;\n      inv_variance = JUST(functional::Rsqrt(add_eps));\n    }\n    const auto& results = JUST(functional::NormalizationGrad(y_grad, x, mean, inv_variance, gamma,\n                                                             ctx->epsilon, ctx->axis));\n    CHECK_EQ_OR_RETURN(results->size(), 3)\n        << Error::RuntimeError() << \"The number of results is expected to be 3, but got \"\n        << results->size();\n\n    if (ctx->track_running_stats) {\n      // The normalization op has 5 inputs which are x, moving_mean, moving_variance, gamma and\n      // beta.\n      in_grads->resize(5);\n      if (ctx->gamma_requires_grad) {\n        in_grads->at(3) = results->at(1);  // gamma_diff;\n      }\n      if (ctx->beta_requires_grad) {\n        in_grads->at(4) = results->at(2);  // beta_diff\n      }\n    } else {\n      // The normalization op has 3 inputs which are x, gamma and beta.\n      in_grads->resize(3);\n      if (ctx->gamma_requires_grad) {\n        in_grads->at(1) = results->at(1);  // gamma_diff;\n      }\n      if (ctx->beta_requires_grad) {\n        in_grads->at(2) = results->at(2);  // beta_diff\n      }\n    }\n\n    if (!ctx->x_requires_grad) { return Maybe<void>::Ok(); }\n    if (ctx->is_training) {\n      in_grads->at(0) = results->at(0);\n      return Maybe<void>::Ok();\n    }\n\n    Shape shape;\n    for (int i = 0; i < x->shape()->NumAxes(); ++i) {\n      if (i != ctx->axis) {\n        shape.emplace_back(1);\n      } else {\n        shape.emplace_back(x->shape()->At(ctx->axis));\n      }\n    }\n    const auto& reshaped_gamma = JUST(functional::Reshape(gamma, shape));\n    const auto& reshaped_inv_variance = JUST(functional::Reshape(inv_variance, shape));\n\n    std::shared_ptr<Tensor> y_grad_fp32 = y_grad;\n    bool is_fp16 = y_grad->dtype()->data_type() == DataType::kFloat16;\n    if (is_fp16) {\n      y_grad_fp32 = JUST(functional::Cast(y_grad, DType::Float(), /*pin_memory=*/false));\n    }\n    const auto& dy_mul_gamma = JUST(functional::Mul(reshaped_gamma, y_grad_fp32));\n    const auto& dy_mul_inv_var = JUST(functional::Mul(dy_mul_gamma, reshaped_inv_variance));\n    if (is_fp16) {\n      (*in_grads)[0] =\n          JUST(functional::Cast(dy_mul_inv_var, DType::Float16(), /*pin_memory=*/false));\n    } else {\n      (*in_grads)[0] = dy_mul_inv_var;\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"normalization\", NormalizationGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/normalization_add_relu.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct NormalizationAddReluGradCaptureState : public AutoGradCaptureState {\n  int32_t axis = 1;\n  float epsilon = 1e-5;\n  bool track_running_stats = true;\n  bool is_training = true;\n  bool has_addend = false;\n  bool x_requires_grad = true;\n  bool addend_requires_grad = true;\n  bool gamma_requires_grad = true;\n  bool beta_requires_grad = true;\n};\n\n// training:\n// y, mean, inv_variance = normalization_add_relu(x, Optional(add_end), moving_mean,\n// moving_variance, gamma, beta, axis=1, epsilon=0.01, momentum=0.9) y, mean, inv_variance =\n// normalization_add_relu(x, Optional(add_end), gamma, beta, axis=1, epsilon=0.01, momentum=0.9)\n\n// inference:\n// y = normalization_add_relu(x, Optional(add_end), moving_mean, moving_variance, gamma, beta,\n// axis=1, epsilon=0.01, momentum=0.9)\n\nclass NormalizationAddReluGrad : public OpExprGradFunction<NormalizationAddReluGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(NormalizationAddReluGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // input_size may be 3/4/5/6, as inputs may be\n    // (x, gamma, beta) or (x, moving_mean, moving_variance, gamma, beta)\n    // (x, addend, gamma, beta) or (x, addend, moving_mean, moving_variance, gamma, beta)\n\n    // ref to track_running_stats false/true\n    // output_size may be 2 or 4, as outputs may be\n    // (x, reserve_space) or (x, reserve_space, mean, inv_variance)\n    // ref to is_training false/true\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    std::shared_ptr<Tensor> add_end, gamma, beta;\n\n    if (inputs.size() == 3 || inputs.size() == 5) {\n      add_end = nullptr;\n      if (inputs.size() == 3) {\n        gamma = inputs.at(1);\n        beta = inputs.at(2);\n        ctx->track_running_stats = false;\n      } else {\n        gamma = inputs.at(3);\n        beta = inputs.at(4);\n        ctx->track_running_stats = true;\n      }\n      ctx->has_addend = false;\n    } else if (inputs.size() == 4 || inputs.size() == 6) {\n      add_end = inputs.at(1);\n      if (inputs.size() == 4) {\n        gamma = inputs.at(2);\n        beta = inputs.at(3);\n        ctx->track_running_stats = false;\n      } else {\n        gamma = inputs.at(4);\n        beta = inputs.at(5);\n        ctx->track_running_stats = true;\n      }\n      
ctx->has_addend = true;\n      ctx->addend_requires_grad = inputs.at(1)->requires_grad();\n    }\n\n    ctx->gamma_requires_grad = gamma->requires_grad();\n    ctx->beta_requires_grad = beta->requires_grad();\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n\n    ctx->axis = JUST(composed_attrs.GetAttr<int32_t>(\"axis\"));\n    ctx->epsilon = JUST(composed_attrs.GetAttr<float>(\"epsilon\"));\n    ctx->is_training = JUST(composed_attrs.GetAttr<bool>(\"training\"));\n\n    ctx->SaveTensorForBackward(inputs.at(0));  // x 0\n    ctx->SaveTensorForBackward(gamma);         // gamma 1\n    ctx->SaveTensorForBackward(beta);          // beta 2\n\n    if (ctx->is_training || !ctx->track_running_stats) {\n      ctx->SaveTensorForBackward(outputs.at(2));  // mean 3\n      ctx->SaveTensorForBackward(outputs.at(3));  // inv_variance 4\n    } else {\n      if (inputs.size() == 5) {\n        // without add_end\n        ctx->SaveTensorForBackward(inputs.at(1));  // moving_mean 3\n        ctx->SaveTensorForBackward(inputs.at(2));  // moving_variance 4\n      } else {\n        CHECK_EQ_OR_RETURN(inputs.size(), 6);  // NOLINT(maybe-need-error-msg)\n        // with add_end\n        ctx->SaveTensorForBackward(inputs.at(2));  // moving_mean 3\n        ctx->SaveTensorForBackward(inputs.at(3));  // moving_variance 4\n      }\n    }\n    ctx->SaveTensorForBackward(outputs.at(0));  // y 5\n    ctx->SaveTensorForBackward(outputs.at(1));  // reserve space 6\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const NormalizationAddReluGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    const auto& x = ctx->SavedTensors().at(0);      // x\n    const auto& gamma = ctx->SavedTensors().at(1);  // gamma\n    const auto& beta = ctx->SavedTensors().at(2);   // beta\n    const auto& y_grad = out_grads.at(0);\n\n    std::shared_ptr<Tensor> mean, inv_variance;\n    if (ctx->is_training || !ctx->track_running_stats) {\n      mean = ctx->SavedTensors().at(3);          // mean\n      inv_variance = ctx->SavedTensors().at(4);  // inv_variance\n    } else {\n      const auto& moving_mean = ctx->SavedTensors().at(3);      // moving_mean\n      const auto& moving_variance = ctx->SavedTensors().at(4);  // moving_variance\n      const auto& add_eps = JUST(\n          functional::ScalarAdd(moving_variance, ctx->epsilon, /*alpha=*/1, /*inplace=*/false));\n      mean = moving_mean;\n      inv_variance = JUST(functional::Rsqrt(add_eps));\n    }\n    const auto& y = ctx->SavedTensors().at(5);\n    const auto& reserve_space = ctx->SavedTensors().at(6);\n\n    const auto& results = JUST(functional::NormalizationAddReluGrad(\n        x, y_grad, mean, inv_variance, gamma, beta, reserve_space, y, ctx->axis, ctx->epsilon,\n        ctx->has_addend));\n    CHECK_EQ_OR_RETURN(results->size(), (ctx->has_addend ? 4 : 3))\n        << Error::RuntimeError() << \"The number of results is expected to be \"\n        << (ctx->has_addend ? 4 : 3) << \", but got \"\n        << results->size();  // here output includes \"gamma_diff\" \"beta_diff\" \"dx\" \"addend_diff\"\n\n    if (ctx->track_running_stats) {\n      // The normalization op has 5 inputs which are x, moving_mean, moving_variance, gamma and\n      // beta. 
or 6 inputs: x, add_end, moving_mean, moving_variance, gamma and beta.\n      if (ctx->has_addend) {\n        in_grads->resize(6);\n        if (ctx->gamma_requires_grad) {\n          in_grads->at(4) = results->at(1);  // gamma_diff;\n        }\n        if (ctx->beta_requires_grad) {\n          in_grads->at(5) = results->at(2);  // beta_diff\n        }\n        if (ctx->addend_requires_grad) {\n          in_grads->at(1) = results->at(3);  // add_end_diff\n        }\n      } else {\n        in_grads->resize(5);\n        if (ctx->gamma_requires_grad) {\n          in_grads->at(3) = results->at(1);  // gamma_diff;\n        }\n        if (ctx->beta_requires_grad) {\n          in_grads->at(4) = results->at(2);  // beta_diff\n        }\n      }\n\n    } else {\n      // The normalization op has 3 inputs which are x, gamma and beta,\n      // or 4 inputs which are x, addend, gamma and beta.\n      if (ctx->has_addend) {\n        in_grads->resize(4);\n        if (ctx->addend_requires_grad) {\n          in_grads->at(1) = results->at(3);  // addend_diff\n        }\n        if (ctx->gamma_requires_grad) {\n          in_grads->at(2) = results->at(1);  // gamma_diff;\n        }\n        if (ctx->beta_requires_grad) {\n          in_grads->at(3) = results->at(2);  // beta_diff\n        }\n      } else {\n        in_grads->resize(3);\n        if (ctx->gamma_requires_grad) {\n          in_grads->at(1) = results->at(1);  // gamma_diff;\n        }\n        if (ctx->beta_requires_grad) {\n          in_grads->at(2) = results->at(2);  // beta_diff\n        }\n      }\n    }\n\n    if (!ctx->x_requires_grad) { return Maybe<void>::Ok(); }\n    if (ctx->is_training) {\n      in_grads->at(0) = results->at(0);\n      return Maybe<void>::Ok();\n    }\n\n    // todo(zzk): add eval mode.\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"normalization_add_relu\", NormalizationAddReluGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/one_embedding_fused_lookup.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct OneEmbeddingFusedLookupCaptureState : public AutoGradCaptureState {\n  bool requires_grad{};\n  std::string embedding_name{};\n  int64_t line_size{};\n  int64_t embedding_size{};\n  int shadow_index{};\n  int ids_index{};\n  int input_num{};\n};\n\nclass OneEmbeddingFusedLookup : public OpExprGradFunction<OneEmbeddingFusedLookupCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(OneEmbeddingFusedLookupCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_GE_OR_RETURN(inputs.size(), 2);                          // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();            // shadow\n    ctx->shadow_index = ctx->SaveTensorForBackward(inputs.at(0));  // shadow\n    ctx->ids_index = ctx->SaveTensorForBackward(inputs.at(1));     // id\n    ctx->embedding_name = JUST(attrs.GetAttr<std::string>(\"embedding_name\"));\n    ctx->line_size = JUST(attrs.GetAttr<int64_t>(\"line_size\"));\n    ctx->embedding_size = JUST(attrs.GetAttr<int64_t>(\"embedding_size\"));\n    ctx->input_num = inputs.size();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const OneEmbeddingFusedLookupCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(ctx->input_num);\n    const auto& saved_tensors = ctx->SavedTensors();\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    if (ctx->requires_grad) {\n      JUST(functional::OneEmbeddingFusedLookupGrad(\n          saved_tensors.at(ctx->ids_index), JUST(VectorAt(out_grads, 0)), ctx->embedding_name,\n          ctx->line_size, ctx->embedding_size));\n      (*in_grads)[0] = JUST(functional::ZerosLike(saved_tensors.at(ctx->shadow_index)));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"one_embedding_fused_lookup\", OneEmbeddingFusedLookup);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/padding.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct PadNdCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  std::vector<int64_t> paddings{};\n};\n\nclass PadNd : public OpExprGradFunction<PadNdCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(PadNdCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->paddings = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"padding\"));\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nclass ReflectionPadNd : public PadNd {\n public:\n  Maybe<void> Apply(const PadNdCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      (*in_grads)[0] =\n          JUST(functional::PadGrad(JUST(VectorAt(out_grads, 0)), ctx->paddings, \"reflect\", 0));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nclass ReplicationPadNd : public PadNd {\n public:\n  Maybe<void> Apply(const PadNdCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      (*in_grads)[0] =\n          JUST(functional::PadGrad(JUST(VectorAt(out_grads, 0)), ctx->paddings, \"replicate\", 0));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nstruct ConstantPadNdCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n  std::vector<int64_t> paddings;\n};\n\nclass ConstantPadNd : public OpExprGradFunction<ConstantPadNdCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> 
Capture(ConstantPadNdCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    const std::shared_ptr<Tensor>& input_0 = JUST(VectorAt(inputs, 0));\n    ctx->requires_grad = input_0->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->paddings = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"padding\"));\n    for (int i = 0; i < ctx->paddings.size(); i++) { ctx->paddings[i] = -ctx->paddings[i]; }\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Apply(const ConstantPadNdCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      (*in_grads)[0] =\n          JUST(functional::Pad(JUST(VectorAt(out_grads, 0)), ctx->paddings, \"constant\", Scalar(0)));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"pad\", ConstantPadNd);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"reflection_pad1d\", ReflectionPadNd);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"reflection_pad2d\", ReflectionPadNd);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"replication_pad1d\", ReplicationPadNd);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"replication_pad2d\", ReplicationPadNd);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/partial_fc_sample.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct PartialFCSampleState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  int32_t index_sampled_label = -1;\n  int32_t index_weight = -1;\n};\n\nclass PartialFCSample : public OpExprGradFunction<PartialFCSampleState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(PartialFCSampleState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const PartialFCSampleState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> PartialFCSample::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> PartialFCSample::Capture(PartialFCSampleState* ctx, const TensorTuple& inputs,\n                                     const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  ctx->index_sampled_label = ctx->SaveTensorForBackward(outputs.at(1));  // sampled_label\n  ctx->index_weight = ctx->SaveTensorForBackward(inputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> PartialFCSample::Apply(const PartialFCSampleState* ctx, const TensorTuple& out_grads,\n                                   TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 3);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(2);\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  const auto& diff_sampled_weight = out_grads.at(2);  // diff of sampled_weight\n\n  const auto& sampled_tensor = ctx->SavedTensors().at(ctx->index_sampled_label);\n  const auto& weight = ctx->SavedTensors().at(ctx->index_weight);\n  const auto& out_tensors_of_op0 = JUST(\n      functional::DistributedPariticalFCSampleDisableBoxing(diff_sampled_weight, sampled_tensor));\n\n  const auto& out_tensors_of_op1 = JUST(functional::UnsortedSegmentSumLike(\n      out_tensors_of_op0->at(0), out_tensors_of_op0->at(1), weight, 0));\n  in_grads->at(0) = out_tensors_of_op1;\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"distributed_partial_fc_sample\", PartialFCSample);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/reduce_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ReduceSumCaptureState : public AutoGradCaptureState {\n  std::vector<int32_t> axis;\n};\n\nclass ReduceSum : public OpExprGradFunction<ReduceSumCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(ReduceSumCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const ReduceSumCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> ReduceSum::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReduceSum::Capture(ReduceSumCaptureState* ctx, const TensorTuple& inputs,\n                               const TensorTuple& outputs, const AttrMap& attrs) const {\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"axis\"));\n  ctx->SaveTensorForBackward(inputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReduceSum::Apply(const ReduceSumCaptureState* ctx, const TensorTuple& out_grads,\n                             TensorTuple* in_grads) const {\n  const auto& input = ctx->SavedTensors().at(0);\n  const auto& dy = out_grads.at(0);\n  in_grads->resize(1);\n  in_grads->at(0) = JUST(functional::BroadcastLike(dy, input, ctx->axis));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"reduce_sum\", ReduceSum);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"reduce_nansum\", ReduceSum);\n\nstruct ReduceProdOpInterpState : public AutoGradCaptureState {\n  std::vector<int32_t> axis;\n  bool requires_grad;\n};\n\nclass ReduceProdOp : public OpExprGradFunction<ReduceProdOpInterpState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(ReduceProdOpInterpState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const ReduceProdOpInterpState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> ReduceProdOp::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // 
NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReduceProdOp::Capture(ReduceProdOpInterpState* ctx, const TensorTuple& inputs,\n                                  const TensorTuple& outputs, const AttrMap& attrs) const {\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"axis\"));\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  ctx->SaveTensorForBackward(inputs.at(0));\n  ctx->SaveTensorForBackward(outputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReduceProdOp::Apply(const ReduceProdOpInterpState* ctx, const TensorTuple& out_grads,\n                                TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  const auto& input = ctx->SavedTensors().at(0);\n  const auto& output = ctx->SavedTensors().at(1);\n  const auto& dy = out_grads.at(0);\n\n  in_grads->resize(1);\n  in_grads->at(0) = JUST(\n      functional::SequenceFunction<Maybe<Tensor>()>([&]() { return functional::Mul(dy, output); })\n          .then(std::bind(functional::BroadcastLike, std::placeholders::_1, input, ctx->axis))\n          .then(std::bind(functional::Div, std::placeholders::_1, input))\n          .call());\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"reduce_prod\", ReduceProdOp);\n\nstruct ReduceMaxOrMinCaptureState : public AutoGradCaptureState {\n  std::vector<int32_t> axis;\n  bool keepdims;\n};\n\nclass ReduceMaxOrMin : public OpExprGradFunction<ReduceMaxOrMinCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(ReduceMaxOrMinCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const ReduceMaxOrMinCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> ReduceMaxOrMin::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReduceMaxOrMin::Capture(ReduceMaxOrMinCaptureState* ctx, const TensorTuple& inputs,\n                                    const TensorTuple& outputs, const AttrMap& attrs) const {\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"axis\"));\n  ctx->keepdims = JUST(composed_attrs.GetAttr<bool>(\"keepdims\"));\n  ctx->SaveTensorForBackward(inputs.at(0));\n  ctx->SaveTensorForBackward(outputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReduceMaxOrMin::Apply(const ReduceMaxOrMinCaptureState* ctx,\n                                  const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  const auto& input = ctx->SavedTensors().at(0);\n  const auto& output = ctx->SavedTensors().at(1);\n  const auto& dy = out_grads.at(0);\n\n  const auto cast_like =\n      JUST(functional::SequenceFunction<Maybe<Tensor>()>(\n               [&]() { return functional::BroadcastLike(output, input, ctx->axis); })\n               .then(std::bind(functional::BroadcastEqual, input, std::placeholders::_1))\n               .then(std::bind(functional::CastLike, std::placeholders::_1, input))\n               .call());\n\n  
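// note: cast_like is a 0/1 mask marking the entries of input that attained the\n  // reduced max/min; dy is divided by the per-slice tie count (the sum of the mask)\n  // so tied entries share the incoming gradient evenly\n  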
const auto& bcast_like_div =\n      JUST(functional::SequenceFunction<Maybe<Tensor>()>([&]() {\n             return functional::ReduceSum(cast_like, ctx->axis, ctx->keepdims, NullOpt);\n           })\n               .then(std::bind(functional::Div, dy, std::placeholders::_1))\n               .then(std::bind(functional::BroadcastLike, std::placeholders::_1, input, ctx->axis))\n               .call());\n\n  in_grads->resize(1);\n  in_grads->at(0) = JUST(functional::Mul(bcast_like_div, cast_like));\n\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"reduce_min\", ReduceMaxOrMin);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"reduce_max\", ReduceMaxOrMin);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/reduce_sum_like.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ReduceSumLikeCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  std::vector<int32_t> axis;\n};\n\nclass ReduceSumLike : public OpExprGradFunction<ReduceSumLikeCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(ReduceSumLikeCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const ReduceSumLikeCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> ReduceSumLike::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReduceSumLike::Capture(ReduceSumLikeCaptureState* ctx, const TensorTuple& inputs,\n                                   const TensorTuple& outputs, const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  CHECK_OR_RETURN(!inputs.at(1)->requires_grad())\n      << Error::RuntimeError() << \"like tensor does not require grad\";\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"axis\"));\n  ctx->SaveTensorForBackward(inputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReduceSumLike::Apply(const ReduceSumLikeCaptureState* ctx, const TensorTuple& out_grads,\n                                 TensorTuple* in_grads) const {\n  const auto& x = ctx->SavedTensors().at(0);\n  in_grads->resize(2);\n  in_grads->at(0) = JUST(functional::BroadcastLike(out_grads.at(0), x, ctx->axis));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"reduce_sum_like\", ReduceSumLike);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/reshape.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ReshapeCaptureState : public AutoGradCaptureState {\n  DimVector input_shape_vec;\n};\n\nclass ReshapeGrad : public OpExprGradFunction<ReshapeCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ReshapeCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    ctx->input_shape_vec = inputs.at(0)->shape()->dim_vec();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ReshapeCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(1);\n    Shape shape(ctx->input_shape_vec);\n    in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), shape));\n    return Maybe<void>::Ok();\n  }\n};\n\nclass ReshapeLikeGrad : public OpExprGradFunction<ReshapeCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ReshapeCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)\n    CHECK_OR_RETURN(!inputs.at(1)->requires_grad())\n        << \"ReshapeLikeOp's input[1] need not requires_grad.\";\n    ctx->input_shape_vec = inputs.at(0)->shape()->dim_vec();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ReshapeCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n    Shape shape(ctx->input_shape_vec);\n    in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), shape));\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"reshape\", ReshapeGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"reshape_like\", ReshapeLikeGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/rms_norm.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct RMSNormCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad = false;\n  bool weight_requires_grad = false;\n  int x_index = -1;\n  int inv_rms_index = -1;\n  int weight_index = -1;\n};\n\nclass RMSNormGrad : public OpExprGradFunction<RMSNormCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n  Maybe<void> Capture(RMSNormCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const RMSNormCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n};\n\nMaybe<void> RMSNormGrad::Capture(RMSNormCaptureState* ctx, const TensorTuple& inputs,\n                                 const TensorTuple& outputs, const AttrMap& attrs) const {\n  // (x, [weight])\n  CHECK_GE_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n  CHECK_LE_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)\n  // (y, inv_rms)\n  CHECK_EQ_OR_RETURN(outputs.size(), 2);  // NOLINT(maybe-need-error-msg)\n\n  // save x\n  ctx->x_requires_grad = inputs[0]->requires_grad();\n  ctx->x_index = ctx->SaveTensorForBackward(inputs[0]);\n\n  // save weight\n  ctx->weight_requires_grad = false;\n  if (inputs.size() > 1) {\n    ctx->weight_requires_grad = inputs[1]->requires_grad();\n    ctx->weight_index = ctx->SaveTensorForBackward(inputs[1]);\n  }\n\n  // save inv_rms\n  if (ctx->x_requires_grad || ctx->weight_requires_grad) {\n    ctx->inv_rms_index = ctx->SaveTensorForBackward(outputs[1]);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RMSNormGrad::Apply(const RMSNormCaptureState* ctx, const TensorTuple& out_grads,\n                               TensorTuple* in_grads) const {\n  // (x, inv_rms) or (x, weight, inv_rms)\n  const auto& saved_tensors = ctx->SavedTensors();\n  CHECK_GE_OR_RETURN(saved_tensors.size(), 2);  // NOLINT(maybe-need-error-msg)\n  CHECK_LE_OR_RETURN(saved_tensors.size(), 3);  // NOLINT(maybe-need-error-msg)\n\n  // (dy, inv_rms_diff)\n  CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)\n  const auto& dy = out_grads[0];\n  const auto& x = saved_tensors.at(ctx->x_index);\n  const auto& inv_rms = saved_tensors.at(ctx->inv_rms_index);\n\n  // (x_grad, weight_grad)\n  in_grads->resize(2);\n  if (ctx->x_requires_grad) {\n    if (saved_tensors.size() == 3) {\n      const auto& weight = saved_tensors.at(ctx->weight_index);\n      in_grads->at(0) = JUST(functional::RMSNormGrad(dy, x, inv_rms, weight, /*param_grad*/ false));\n    } else {\n      in_grads->at(0) =\n          JUST(functional::RMSNormGrad(dy, x, inv_rms, /*weight*/ NullOpt, /*param_grad*/ false));\n    }\n  }\n  if 
(ctx->weight_requires_grad) {\n    in_grads->at(1) =\n        JUST(functional::RMSNormGrad(dy, x, inv_rms, /*weight*/ NullOpt, /*param_grad*/ true));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"rms_norm\", RMSNormGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/roi_align.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct RoiAlignCaptureState : public AutoGradCaptureState {\n  float spatial_scale = 1.0;\n  int32_t pooled_h = 0;\n  int32_t pooled_w = 0;\n  int32_t sampling_ratio = -1;\n  bool aligned = false;\n  bool requires_grad = false;\n};\n\nclass RoiAlign : public OpExprGradFunction<RoiAlignCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(RoiAlignCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ctx->SaveTensorForBackward(inputs.at(0));\n    ctx->SaveTensorForBackward(inputs.at(1));\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->spatial_scale = JUST(composed_attrs.GetAttr<float>(\"spatial_scale\"));\n    ctx->pooled_h = JUST(composed_attrs.GetAttr<int32_t>(\"pooled_h\"));\n    ctx->pooled_w = JUST(composed_attrs.GetAttr<int32_t>(\"pooled_w\"));\n    ctx->sampling_ratio = JUST(composed_attrs.GetAttr<int32_t>(\"sampling_ratio\"));\n    ctx->aligned = JUST(composed_attrs.GetAttr<bool>(\"aligned\"));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const RoiAlignCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    const auto& x_like = ctx->SavedTensors().at(0);\n    const auto& rois = ctx->SavedTensors().at(1);\n    in_grads->at(0) = JUST(\n        functional::RoiAlignGrad(out_grads.at(0), x_like, rois, ctx->spatial_scale, ctx->pooled_h,\n                                 ctx->pooled_w, ctx->sampling_ratio, ctx->aligned));\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"roi_align\", RoiAlign);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/roll.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct RollCaptureState : public AutoGradCaptureState {\n  std::vector<int32_t> shifts;\n  std::vector<int32_t> dims;\n  bool requires_grad = false;\n};\n\nclass Roll : public OpExprGradFunction<RollCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(RollCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const RollCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Roll::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Roll::Capture(RollCaptureState* ctx, const TensorTuple& inputs,\n                          const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->shifts = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"shifts\"));\n  ctx->dims = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"dims\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Roll::Apply(const RollCaptureState* ctx, const TensorTuple& out_grads,\n                        TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n  std::vector<int32_t> new_shifts;\n  new_shifts.resize(ctx->shifts.size());\n  for (int i = 0; i < new_shifts.size(); ++i) { new_shifts[i] = -ctx->shifts[i]; }\n\n  in_grads->at(0) = JUST(functional::Roll(out_grads.at(0), new_shifts, ctx->dims));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"roll\", Roll);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/rrelu.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct RReluCaptureState : public AutoGradCaptureState {\n  bool requires_grad = true;\n  float lower = 1.0 / 8;\n  float upper = 1.0 / 3;\n  bool training = false;\n  int x_index = -1;\n  int noise_data_index = -1;\n};\n\nclass RRelu : public OpExprGradFunction<RReluCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(RReluCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const RReluCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> RRelu::Init(const OpExpr& op) { return Maybe<void>::Ok(); }\n\nMaybe<void> RRelu::Capture(RReluCaptureState* ctx, const TensorTuple& inputs,\n                           const TensorTuple& outputs, const AttrMap& attrs) const {\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  ctx->lower = JUST(composed_attrs.GetAttr<float>(\"lower\"));\n  ctx->upper = JUST(composed_attrs.GetAttr<float>(\"upper\"));\n  ctx->training = JUST(composed_attrs.GetAttr<bool>(\"training\"));\n  ctx->x_index = ctx->SaveTensorForBackward(inputs[0]);\n  ctx->noise_data_index = ctx->SaveTensorForBackward(outputs[1]);  // output noise data\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RRelu::Apply(const RReluCaptureState* ctx, const TensorTuple& out_grads,\n                         TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  const auto& saved_tensors = ctx->SavedTensors();\n  if (!ctx->training) {\n    float scale = (ctx->lower + ctx->upper) / 2;\n    const auto& x = saved_tensors.at(ctx->x_index);\n    in_grads->at(0) = JUST(functional::LeakyReluGrad(x, out_grads.at(0), scale));\n    return Maybe<void>::Ok();\n\n  } else {\n    const auto& noise_data = saved_tensors.at(ctx->noise_data_index);\n    in_grads->at(0) = JUST(functional::Mul(out_grads.at(0), noise_data));\n    return Maybe<void>::Ok();\n  }\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"rrelu\", RRelu);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/scalar_add.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ScalarAddCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n};\n\nclass ScalarAdd : public OpExprGradFunction<ScalarAddCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(ScalarAddCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ScalarAddCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scalar_add\", ScalarAdd);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/scalar_div.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ScalarDivCaptureState : public AutoGradCaptureState {\n  bool requires_grad = true;\n  Scalar operand;\n};\n\nclass ScalarDiv : public OpExprGradFunction<ScalarDivCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ScalarDivCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    bool has_float_operand = JUST(composed_attrs.GetAttr<bool>(\"has_float_operand\"));\n    if (has_float_operand) {\n      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<double>(\"float_operand\")));\n    } else {\n      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<int64_t>(\"int_operand\")));\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ScalarDivCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      JUST(VectorAt(*in_grads, 0)) =\n          JUST(functional::ScalarDiv(JUST(VectorAt(out_grads, 0)), ctx->operand));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scalar_div\", ScalarDiv);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/scalar_floordiv.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ScalarFloorDivCaptureState : public AutoGradCaptureState {\n  bool requires_grad = true;\n};\n\nclass ScalarFloorDiv : public OpExprGradFunction<ScalarFloorDivCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(ScalarFloorDivCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ScalarFloorDivCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      JUST(VectorAt(*in_grads, 0)) = JUST(functional::ZerosLike(JUST(VectorAt(out_grads, 0))));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scalar_floordiv\", ScalarFloorDiv);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/scalar_fmod.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ScalarFModGradCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n};\n\nclass ScalarFModGrad : public OpExprGradFunction<ScalarFModGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(ScalarFModGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ScalarFModGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scalar_fmod\", ScalarFModGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/scalar_mul.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ScalarMulCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n  Scalar operand;\n};\n\nclass ScalarMul : public OpExprGradFunction<ScalarMulCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ScalarMulCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    bool has_float_operand = JUST(composed_attrs.GetAttr<bool>(\"has_float_operand\"));\n    if (has_float_operand) {\n      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<double>(\"float_operand\")));\n    } else {\n      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<int64_t>(\"int_operand\")));\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ScalarMulCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      in_grads->at(0) = JUST(functional::ScalarMul(out_grads.at(0), ctx->operand, false));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scalar_mul\", ScalarMul);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/scalar_pow.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ScalarPowCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n  Scalar operand;\n};\n\nclass ScalarPow : public OpExprGradFunction<ScalarPowCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ScalarPowCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    bool has_float_operand = JUST(composed_attrs.GetAttr<bool>(\"has_float_operand\"));\n    if (has_float_operand) {\n      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<double>(\"float_operand\")));\n    } else {\n      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<int64_t>(\"int_operand\")));\n    }\n    ctx->SaveTensorForBackward(inputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ScalarPowCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    const auto& x = ctx->SavedTensors().at(0);\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      in_grads->at(0) = JUST(functional::ScalarPowGrad(x, out_grads.at(0), ctx->operand));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scalar_pow\", ScalarPow);\n\nclass ScalarReversePow : public OpExprGradFunction<ScalarPowCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ScalarPowCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs[0]->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap 
composed_attrs(attrs, base_attrs_);\n    bool has_float_operand = JUST(composed_attrs.GetAttr<bool>(\"has_float_operand\"));\n    if (has_float_operand) {\n      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<double>(\"float_operand\")));\n    } else {\n      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<int64_t>(\"int_operand\")));\n    }\n    ctx->SaveTensorForBackward(inputs[0]);\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ScalarPowCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    const auto& x = ctx->SavedTensors()[0];\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      (*in_grads)[0] = JUST(functional::ScalarReversePowGrad(x, out_grads[0], ctx->operand));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scalar_reverse_pow\", ScalarReversePow);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/scalar_truncdiv.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ScalarTruncDivCaptureState : public AutoGradCaptureState {\n  bool requires_grad = true;\n};\n\nclass ScalarTruncDiv : public OpExprGradFunction<ScalarTruncDivCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(ScalarTruncDivCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ScalarTruncDivCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) {\n      JUST(VectorAt(*in_grads, 0)) = JUST(functional::ZerosLike(JUST(VectorAt(out_grads, 0))));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scalar_truncdiv\", ScalarTruncDiv);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/scaled_dot_product_attention.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n#if CUDA_VERSION >= 11070\n\nnamespace oneflow {\n\nnamespace one {\n\nstruct ScaledDotProductFlashAttentionCaptureState : public AutoGradCaptureState {\n  bool query_requires_grad = true;\n  bool key_requires_grad = true;\n  bool value_requires_grad = true;\n  size_t query_idx = 0;\n  size_t key_idx = 0;\n  size_t value_idx = 0;\n  size_t out_idx = 0;\n  size_t softmax_lse_idx = 0;\n  size_t rng_state_idx = 0;\n  float p_dropout = .0f;\n  float softmax_scale = .0f;\n  bool is_causal = false;\n};\n\nclass ScaledDotProductFlashAttention\n    : public OpExprGradFunction<ScaledDotProductFlashAttentionCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr) << \"fw_op_expr should not be None. \";\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ScaledDotProductFlashAttentionCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 3) << \"Input size should be equal to 3. 
\";\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->p_dropout = JUST(composed_attrs.GetAttr<float>(\"p_dropout\"));\n    ctx->softmax_scale = JUST(composed_attrs.GetAttr<float>(\"softmax_scale\"));\n    ctx->is_causal = JUST(composed_attrs.GetAttr<bool>(\"is_causal\"));\n    ctx->query_requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();\n    ctx->key_requires_grad = JUST(oneflow::VectorAt(inputs, 1))->requires_grad();\n    ctx->value_requires_grad = JUST(oneflow::VectorAt(inputs, 2))->requires_grad();\n    ctx->query_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0)));\n    ctx->key_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 1)));\n    ctx->value_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 2)));\n    ctx->out_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 0)));\n    ctx->softmax_lse_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 1)));\n    ctx->rng_state_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 2)));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ScaledDotProductFlashAttentionCaptureState* ctx,\n                    const TensorTuple& out_grads, TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 3) << \"Out grads size should be equal to 3. \";\n    std::shared_ptr<oneflow::one::TensorTuple> grads;\n    in_grads->resize(3);\n    grads = JUST(functional::ScaledDotProductFlashAttentionGrad(\n        JUST(oneflow::VectorAt(out_grads, 0)),\n        JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->query_idx)),\n        JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->key_idx)),\n        JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->value_idx)),\n        JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->out_idx)),\n        JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->softmax_lse_idx)),\n        JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->rng_state_idx)), ctx->p_dropout,\n        ctx->is_causal, ctx->softmax_scale));\n\n    if (ctx->query_requires_grad) {\n      JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(oneflow::VectorAt(*grads, 0));\n    }\n    if (ctx->key_requires_grad) {\n      JUST(oneflow::VectorAt(*in_grads, 1)) = JUST(oneflow::VectorAt(*grads, 1));\n    }\n    if (ctx->value_requires_grad) {\n      JUST(oneflow::VectorAt(*in_grads, 2)) = JUST(oneflow::VectorAt(*grads, 2));\n    }\n\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scaled_dot_product_flash_attention\",\n                               ScaledDotProductFlashAttention);\n\n}  // namespace one\n\n}  // namespace oneflow\n\n#endif  // CUDA_VERSION >= 11070\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/scatter_nd.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ScatterNdCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n};\n\nclass ScatterNd : public OpExprGradFunction<ScatterNdCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(ScatterNdCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(1)->requires_grad();\n    if (ctx->requires_grad) {\n      ctx->SaveTensorForBackward(inputs.at(0));  // indices\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ScatterNdCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(2);\n    if (ctx->requires_grad) {\n      const auto& indices = ctx->SavedTensors().at(0);\n      in_grads->at(1) = JUST(functional::GatherNd(out_grads.at(0), indices));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scatter_nd\", ScatterNd);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/select_top_n.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct SelectTopNCaptureState : public AutoGradCaptureState {\n  TensorTuple inputs;\n  std::vector<bool> requires_grad;\n  int32_t top_n = 0;\n};\n\nclass SelectTopN : public OpExprGradFunction<SelectTopNCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(SelectTopNCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    ctx->inputs = inputs;\n    ctx->top_n = JUST(attrs.GetAttr<int32_t>(\"top_n\"));\n    ctx->requires_grad.resize(inputs.size());\n    for (int i = 0; i < ctx->requires_grad.size(); ++i) {\n      ctx->requires_grad.at(i) = inputs.at(i)->requires_grad();\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const SelectTopNCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(ctx->top_n, out_grads.size());  // NOLINT(maybe-need-error-msg)\n    for (int i = 0; i < ctx->top_n; ++i) {\n      if (!ctx->requires_grad.at(i)) { continue; }\n      in_grads->at(i) = out_grads.at(i);\n    }\n    for (int i = ctx->top_n; i < ctx->inputs.size(); ++i) {\n      if (!ctx->requires_grad.at(i)) { continue; }\n      const auto& tensor = ctx->inputs.at(i);\n      in_grads->at(i) =\n          JUST(StaticZerosTensor::MakeTensor(tensor->shape(), tensor->dtype()->data_type(),\n                                             tensor->memory_format(), JUST(tensor->device())));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"select_top_n\", SelectTopN);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/slice.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct SliceCaptureState : public AutoGradCaptureState {\n  Shape like_shape;\n  std::vector<int64_t> start;\n  std::vector<int64_t> stop;\n  std::vector<int64_t> step;\n};\n\nclass Slice : public OpExprGradFunction<SliceCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(SliceCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->start = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"start\"));\n    ctx->stop = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"stop\"));\n    ctx->step = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"step\"));\n    ctx->like_shape = *(inputs[0]->shape());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const SliceCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(1);\n    (*in_grads)[0] = JUST(\n        functional::SliceGrad(out_grads[0], ctx->like_shape, ctx->start, ctx->stop, ctx->step));\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nstruct SliceUpdateCaptureState : public AutoGradCaptureState {\n  bool requires_grad_ref = false;\n  bool requires_grad_value = false;\n  std::vector<int64_t> start;\n  std::vector<int64_t> stop;\n  std::vector<int64_t> step;\n  Shape value_shape;  // used to calculate ref gradient\n  Symbol<NdSbp> value_sbp;\n};\n\nclass SliceUpdate : public OpExprGradFunction<SliceUpdateCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(SliceUpdateCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 
1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad_ref = inputs[0]->requires_grad();\n    ctx->requires_grad_value = inputs[1]->requires_grad();\n    if (!ctx->requires_grad_ref && !ctx->requires_grad_value) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->start = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"start\"));\n    ctx->stop = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"stop\"));\n    ctx->step = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"step\"));\n\n    if (ctx->requires_grad_ref) {\n      ctx->value_shape = *(inputs[1]->shape());\n      if (inputs[1]->is_global()) { ctx->value_sbp = JUST(inputs[1]->nd_sbp()); }\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const SliceUpdateCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n\n    if (ctx->requires_grad_ref) {\n      std::shared_ptr<Tensor> zeros;\n      if (out_grads[0]->is_local()) {\n        zeros = JUST(functional::Constant(ctx->value_shape, 0, out_grads[0]->dtype(),\n                                          JUST(out_grads[0]->device())));\n      } else {\n        const auto& parallel_desc = JUST(out_grads[0]->parallel_desc());\n        zeros = JUST(functional::GlobalConstant(ctx->value_shape, 0, out_grads[0]->dtype(),\n                                                parallel_desc, *JUST(GetSbpList(ctx->value_sbp))));\n      }\n      (*in_grads)[0] = JUST(functional::SliceUpdate(out_grads[0], zeros, ctx->start, ctx->stop,\n                                                    ctx->step, /*inplace=*/false));\n    }\n    if (ctx->requires_grad_value) {\n      (*in_grads)[1] = JUST(functional::Slice(out_grads[0], ctx->start, ctx->stop, ctx->step,\n                                              /*enable_view_slice=*/false));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"slice_update\", SliceUpdate);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"slice\", Slice);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/smooth_l1_loss.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct SmoothL1LossCaptureState : public AutoGradCaptureState {\n  bool input_requires_grad = false;\n  bool target_requires_grad = false;\n  float beta = 0.0;\n};\n\nclass SmoothL1Loss : public OpExprGradFunction<SmoothL1LossCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(SmoothL1LossCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)\n\n    ctx->input_requires_grad = inputs.at(0)->requires_grad();   // input\n    ctx->target_requires_grad = inputs.at(1)->requires_grad();  // target\n\n    ctx->SaveTensorForBackward(inputs.at(0));  // input\n    ctx->SaveTensorForBackward(inputs.at(1));  // target\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->beta = JUST(composed_attrs.GetAttr<float>(\"beta\"));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const SmoothL1LossCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);            // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 2);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(2);\n    const auto& input = ctx->SavedTensors().at(0);\n    const auto& target = ctx->SavedTensors().at(1);\n    const auto& grad = JUST(functional::SmoothL1LossGrad(out_grads[0], input, target, ctx->beta));\n\n    if (ctx->input_requires_grad) { (*in_grads)[0] = grad; }\n    if (ctx->target_requires_grad) { (*in_grads)[1] = JUST(functional::Negative(grad)); }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"smooth_l1_loss\", SmoothL1Loss);  // todo: name\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/softmax.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct SoftmaxCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n};\n\nclass Softmax : public OpExprGradFunction<SoftmaxCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(SoftmaxCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const SoftmaxCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n};\n\nMaybe<void> Softmax::Init(const OpExpr& op) { return Maybe<void>::Ok(); }\n\nMaybe<void> Softmax::Capture(SoftmaxCaptureState* ctx, const TensorTuple& inputs,\n                             const TensorTuple& outputs, const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n\n  if (!ctx->requires_grad) return Maybe<void>::Ok();\n\n  ctx->SaveTensorForBackward(outputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Softmax::Apply(const SoftmaxCaptureState* ctx, const TensorTuple& out_grads,\n                           TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) return Maybe<void>::Ok();\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  const auto& dy = out_grads.at(0);\n  const auto& y = ctx->SavedTensors().at(0);\n  in_grads->resize(1);\n  in_grads->at(0) = JUST(functional::SoftmaxGrad(dy, y));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"softmax\", Softmax);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/softmax_cross_entropy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct SoftmaxCrossEntropyGradState : public AutoGradCaptureState {\n  bool requires_grad = false;\n};\n\nclass SoftmaxCrossEntropy : public OpExprGradFunction<SoftmaxCrossEntropyGradState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(SoftmaxCrossEntropyGradState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const SoftmaxCrossEntropyGradState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n};\n\nMaybe<void> SoftmaxCrossEntropy::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SoftmaxCrossEntropy::Capture(SoftmaxCrossEntropyGradState* ctx,\n                                         const TensorTuple& inputs, const TensorTuple& outputs,\n                                         const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  CHECK_EQ_OR_RETURN(inputs.size(), 2);       // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(outputs.size(), 2);      // NOLINT(maybe-need-error-msg)\n  ctx->SaveTensorForBackward(inputs.at(1));   // label\n  ctx->SaveTensorForBackward(outputs.at(1));  // prob\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SoftmaxCrossEntropy::Apply(const SoftmaxCrossEntropyGradState* ctx,\n                                       const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)\n  const auto& dy = out_grads.at(0);\n  const auto& label = ctx->SavedTensors().at(0);\n  const auto& prob = ctx->SavedTensors().at(1);\n\n  in_grads->resize(2);  // prediction, label\n  (*in_grads)[0] = JUST(functional::SoftmaxCrossEntropyGrad(dy, label, prob));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"softmax_cross_entropy\", SoftmaxCrossEntropy);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/sparse_cross_entropy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct SparseCrossEntropyCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  int64_t depth = -1;\n  size_t prediction_index = -1;\n  size_t label_index = -1;\n};\n\ntemplate<bool is_distributed>\nclass SparseCrossEntropy : public OpExprGradFunction<SparseCrossEntropyCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(SparseCrossEntropyCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->depth = JUST(composed_attrs.GetAttr<int64_t>(\"depth\"));\n    ctx->prediction_index = ctx->SaveTensorForBackward(inputs.at(0));  // prediction\n    ctx->label_index = ctx->SaveTensorForBackward(inputs.at(1));       // label\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const SparseCrossEntropyCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    const auto& prediction = ctx->SavedTensors().at(ctx->prediction_index);\n    const auto& label = ctx->SavedTensors().at(ctx->label_index);\n    in_grads->resize(2);\n    if (is_distributed) {\n      in_grads->at(0) = JUST(\n          functional::SparseCrossEntropyMsGrad(prediction, label, out_grads.at(0), ctx->depth));\n    } else {\n      in_grads->at(0) =\n          JUST(functional::SparseCrossEntropyGrad(prediction, label, out_grads.at(0), ctx->depth));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"sparse_cross_entropy_ms\", SparseCrossEntropy<true>);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"sparse_cross_entropy\", SparseCrossEntropy<false>);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct SparseSoftmaxCrossEntropyCaptureState : public AutoGradCaptureState {\n  int64_t depth;\n};\n\nclass SparseSoftmaxCrossEntropy : public OpExprGradFunction<SparseSoftmaxCrossEntropyCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(SparseSoftmaxCrossEntropyCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const SparseSoftmaxCrossEntropyCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> SparseSoftmaxCrossEntropy::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SparseSoftmaxCrossEntropy::Capture(SparseSoftmaxCrossEntropyCaptureState* ctx,\n                                               const TensorTuple& inputs,\n                                               const TensorTuple& outputs,\n                                               const AttrMap& attrs) const {\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->depth = JUST(composed_attrs.GetAttr<int64_t>(\"depth\"));\n  CHECK_EQ_OR_RETURN(inputs.size(), 2);                    // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(outputs.size(), 2);                   // NOLINT(maybe-need-error-msg)\n  ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0)));  // prob\n  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));   // label\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SparseSoftmaxCrossEntropy::Apply(const SparseSoftmaxCrossEntropyCaptureState* ctx,\n                                             const TensorTuple& out_grads,\n                                             TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)\n  const auto& dy = JUST(VectorAt(out_grads, 1));\n  const auto& prob = JUST(VectorAt(ctx->SavedTensors(), 0));\n  const auto& label = JUST(VectorAt(ctx->SavedTensors(), 1));\n  // SparseSoftmaxCrossEntropy has 2 inputs (prediction and label), and the second input does not\n  // require gradient.\n  in_grads->resize(2);\n  JUST(VectorAt(*in_grads, 0)) =\n      JUST(functional::SparseSoftmaxCrossEntropyGrad(dy, prob, label, ctx->depth));\n  return 
Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"sparse_softmax_cross_entropy\", SparseSoftmaxCrossEntropy);\n\n}  // namespace one\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy_ms.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct SparseSoftmaxCrossEntropyMsCaptureState : public AutoGradCaptureState {\n  int64_t depth = 0;\n};\n\nclass SparseSoftmaxCrossEntropyMs\n    : public OpExprGradFunction<SparseSoftmaxCrossEntropyMsCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(SparseSoftmaxCrossEntropyMsCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const SparseSoftmaxCrossEntropyMsCaptureState* ctx,\n                    const TensorTuple& out_grads, TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> SparseSoftmaxCrossEntropyMs::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SparseSoftmaxCrossEntropyMs::Capture(SparseSoftmaxCrossEntropyMsCaptureState* ctx,\n                                                 const TensorTuple& inputs,\n                                                 const TensorTuple& outputs,\n                                                 const AttrMap& attrs) const {\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->depth = JUST(composed_attrs.GetAttr<int64_t>(\"depth\"));\n  CHECK_EQ_OR_RETURN(inputs.size(), 2);                    // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(outputs.size(), 2);                   // NOLINT(maybe-need-error-msg)\n  ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0)));  // prob\n  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));   // label\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SparseSoftmaxCrossEntropyMs::Apply(const SparseSoftmaxCrossEntropyMsCaptureState* ctx,\n                                               const TensorTuple& out_grads,\n                                               TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)\n  const auto& dy = JUST(VectorAt(out_grads, 1));\n  const auto& prob = JUST(VectorAt(ctx->SavedTensors(), 0));\n  const auto& label = JUST(VectorAt(ctx->SavedTensors(), 1));\n  // SparseSoftmaxCrossEntropy has 2 inputs (prediction and label), and the second input does not\n  // require gradient.\n  in_grads->resize(2);\n  JUST(VectorAt(*in_grads, 0)) =\n      JUST(functional::SparseSoftmaxCrossEntropyMsGrad(dy, prob, label, ctx->depth));\n  return 
Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"sparse_softmax_cross_entropy_ms\", SparseSoftmaxCrossEntropyMs);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/split_like.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct SplitLikeCaptureState : public AutoGradCaptureState {\n  int64_t axis;\n  bool requires_grad;\n};\n\nclass SplitLike : public OpExprGradFunction<SplitLikeCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(SplitLikeCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const SplitLikeCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> SplitLike::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SplitLike::Capture(SplitLikeCaptureState* ctx, const TensorTuple& inputs,\n                               const TensorTuple& outputs, const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), outputs.size() + 1);  // NOLINT(maybe-need-error-msg)\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->axis = JUST(composed_attrs.GetAttr<int64_t>(\"axis\"));\n  for (int i = 0; i < outputs.size(); ++i) { ctx->SaveTensorForBackward(outputs.at(i)); }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SplitLike::Apply(const SplitLikeCaptureState* ctx, const TensorTuple& out_grads,\n                             TensorTuple* in_grads) const {\n  in_grads->resize(out_grads.size() + 1);\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  const auto& saved_tensors = ctx->SavedTensors();\n  TensorTuple inputs;\n  inputs.reserve(out_grads.size());\n  for (int i = 0; i < out_grads.size(); ++i) {\n    const auto& out_grad_i = out_grads.at(i);\n    if (out_grad_i.get()) {\n      inputs.emplace_back(out_grad_i);\n    } else {\n      const auto& zero_grad = JUST(functional::ZerosLike(saved_tensors.at(i)));\n      inputs.emplace_back(zero_grad);\n    }\n  }\n  in_grads->at(0) = JUST(functional::Concat(inputs, ctx->axis));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"split_like\", SplitLike);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/squeeze.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct SqueezeCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n};\n\nclass Squeeze : public OpExprGradFunction<SqueezeCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(SqueezeCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const SqueezeCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Squeeze::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Squeeze::Capture(SqueezeCaptureState* ctx, const TensorTuple& inputs,\n                             const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ctx->SaveTensorForBackward(inputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Squeeze::Apply(const SqueezeCaptureState* ctx, const TensorTuple& out_grads,\n                           TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n  const std::shared_ptr<oneflow::one::Tensor>& like = ctx->SavedTensors().at(0);\n  in_grads->resize(1);\n  in_grads->at(0) = JUST(functional::ReshapeLike(out_grads.at(0), like));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"squeeze\", Squeeze);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/stack.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct StackCaptureState : public AutoGradCaptureState {\n  std::vector<bool> requires_grad;\n  int64_t axis = 1;\n  int64_t input_num = 2;\n};\n\nclass Stack : public OpExprGradFunction<StackCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(StackCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const StackCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Stack::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Stack::Capture(StackCaptureState* ctx, const TensorTuple& inputs,\n                           const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad.resize(inputs.size());\n  for (int i = 0; i < inputs.size(); ++i) { ctx->requires_grad[i] = inputs.at(i)->requires_grad(); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->axis = JUST(composed_attrs.GetAttr<int64_t>(\"axis\"));\n  for (const auto& input : inputs) { ctx->SaveTensorForBackward(input); }\n  ctx->input_num = inputs.size();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Stack::Apply(const StackCaptureState* ctx, const TensorTuple& out_grads,\n                         TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(ctx->input_num);\n  TensorTuple like(ctx->input_num);\n  for (int i = 0; i < ctx->input_num; ++i) { like[i] = ctx->SavedTensors().at(i); }\n  const auto& results = JUST(functional::StackGrad(out_grads.at(0), like, ctx->axis));\n  CHECK_EQ_OR_RETURN(results->size(), ctx->input_num)\n      << Error::RuntimeError() << \"The number of results (\" << results->size()\n      << \") must match the number of inputs (\" << ctx->input_num << \")\";\n  for (int i = 0; i < ctx->input_num; ++i) {\n    if (ctx->requires_grad.at(i)) { in_grads->at(i) = results->at(i); }\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"stack\", Stack);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/tensor_scalar_binary.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct TensorScalarCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad;\n  bool scalar_requires_grad;\n};\n\nclass TensorScalarAddOrSub : public OpExprGradFunction<TensorScalarCaptureState> {\n public:\n  TensorScalarAddOrSub() = default;\n  virtual ~TensorScalarAddOrSub() = default;\n\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n};\n\nMaybe<void> TensorScalarAddOrSub::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> TensorScalarAddOrSub::Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,\n                                          const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->x_requires_grad = inputs.at(0)->requires_grad();\n  ctx->scalar_requires_grad = inputs.at(1)->requires_grad();\n  return Maybe<void>::Ok();\n}\n\nclass TensorScalarAdd : public TensorScalarAddOrSub {\n public:\n  Maybe<void> Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n    if (ctx->x_requires_grad) { in_grads->at(0) = JUST(functional::Identity(out_grads.at(0))); }\n    if (ctx->scalar_requires_grad) {\n      int32_t num_axes = out_grads.at(0)->shape()->NumAxes();\n      std::vector<int32_t> axes_vec(num_axes);\n      std::iota(axes_vec.begin(), axes_vec.end(), 0);\n      in_grads->at(1) = JUST(functional::ReduceSum(out_grads.at(0), axes_vec, false, NullOpt));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nclass TensorScalarSub : public TensorScalarAddOrSub {\n public:\n  Maybe<void> Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n    if (ctx->x_requires_grad) { in_grads->at(0) = JUST(functional::Identity(out_grads.at(0))); }\n    if (ctx->scalar_requires_grad) {\n      int32_t num_axes = out_grads.at(0)->shape()->NumAxes();\n      std::vector<int32_t> axes_vec(num_axes);\n      std::iota(axes_vec.begin(), axes_vec.end(), 0);\n      const auto& reduce_sum =\n          JUST(functional::ReduceSum(out_grads.at(0), axes_vec, /*keepdims=*/false, NullOpt));\n      in_grads->at(1) = JUST(functional::ScalarMul(reduce_sum, /*other=*/1.0, false));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scalar_add_by_tensor\", 
TensorScalarAdd);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scalar_sub_by_tensor\", TensorScalarSub);\n\nclass TensorScalarMul : public OpExprGradFunction<TensorScalarCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n};\n\nMaybe<void> TensorScalarMul::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> TensorScalarMul::Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,\n                                     const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->x_requires_grad = inputs.at(0)->requires_grad();\n  ctx->scalar_requires_grad = inputs.at(1)->requires_grad();\n  if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }\n  if (ctx->scalar_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> TensorScalarMul::Apply(const TensorScalarCaptureState* ctx,\n                                   const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  in_grads->resize(2);\n  if (ctx->x_requires_grad) {\n    const auto& scalar = ctx->SavedTensors().at(0);\n    in_grads->at(0) = JUST(functional::Mul(out_grads.at(0), scalar));\n  }\n  if (ctx->scalar_requires_grad) {\n    // If x also required grad, Capture saved the scalar at index 0 and x at index 1;\n    // indexing with the bool picks the right slot in either case.\n    const auto& x = ctx->SavedTensors().at(ctx->x_requires_grad);\n    const auto& y = JUST(functional::Mul(out_grads.at(0), x));\n    int32_t num_axes = out_grads.at(0)->shape()->NumAxes();\n    std::vector<int32_t> axes_vec(num_axes);\n    std::iota(axes_vec.begin(), axes_vec.end(), 0);\n    in_grads->at(1) = JUST(functional::ReduceSum(y, axes_vec, /*keepdims=*/false, NullOpt));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scalar_mul_by_tensor\", TensorScalarMul);\n\nclass TensorScalarDiv : public OpExprGradFunction<TensorScalarCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  std::shared_ptr<OpExpr> tensor_scalar_div_op_;\n  std::shared_ptr<OpExpr> broadcast_div_grad_op_;\n};\n\nMaybe<void> TensorScalarDiv::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> TensorScalarDiv::Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,\n                                     const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->x_requires_grad = inputs.at(0)->requires_grad();\n  ctx->scalar_requires_grad = inputs.at(1)->requires_grad();\n  if (ctx->x_requires_grad || ctx->scalar_requires_grad) {\n    ctx->SaveTensorForBackward(inputs.at(1));\n  }\n  if (ctx->scalar_requires_grad) { ctx->SaveTensorForBackward(outputs.at(0)); }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> 
TensorScalarDiv::Apply(const TensorScalarCaptureState* ctx,\n                                   const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  in_grads->resize(2);\n  if (ctx->x_requires_grad) {\n    const auto& scalar = ctx->SavedTensors().at(0);\n    in_grads->at(0) = JUST(functional::Div(out_grads.at(0), scalar));\n  }\n  if (ctx->scalar_requires_grad) {\n    const auto& scalar = ctx->SavedTensors().at(0);\n    const auto& y = ctx->SavedTensors().at(1);\n    in_grads->at(1) = JUST(functional::DivGrad(out_grads.at(0), y, scalar));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scalar_div_by_tensor\", TensorScalarDiv);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
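  {
    "path": "docs/autograd_sketches/tensor_scalar_backward_sketch.cpp",
    "content": "// Illustrative standalone sketch under a hypothetical path; it is not part of\n// the OneFlow source tree. For y = x + s or y = x - s with a scalar tensor s\n// broadcast over every element of x, dL/ds is the sum of dL/dy over all axes,\n// negated in the subtraction case. That is what the iota/ReduceSum/ScalarMul\n// sequence in TensorScalarAdd and TensorScalarSub above computes; on flat\n// storage the all-axes reduction is a plain sum.\n#include <cassert>\n#include <numeric>\n#include <vector>\n\ndouble ScalarGrad(const std::vector<double>& out_grad, bool is_sub) {\n  const double sum = std::accumulate(out_grad.begin(), out_grad.end(), 0.0);\n  return is_sub ? -sum : sum;  // d(x - s)/ds = -1 flips the sign\n}\n\nint main() {\n  const std::vector<double> dy = {0.5, 1.5, -1.0};\n  assert(ScalarGrad(dy, /*is_sub=*/false) == 1.0);\n  assert(ScalarGrad(dy, /*is_sub=*/true) == -1.0);\n  return 0;\n}\n"
  },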
  {
    "path": "oneflow/core/autograd/gradient_funcs/tensor_scatter_nd_update.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct TensorScatterNdUpdateCaptureState : public AutoGradCaptureState {\n  bool tensor_requires_grad = false;\n  bool update_requires_grad = false;\n};\n\nclass TensorScatterNdUpdate : public OpExprGradFunction<TensorScatterNdUpdateCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(TensorScatterNdUpdateCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 3);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->tensor_requires_grad = inputs.at(0)->requires_grad();\n    ctx->update_requires_grad = inputs.at(2)->requires_grad();\n    if (ctx->update_requires_grad || ctx->tensor_requires_grad) {\n      ctx->SaveTensorForBackward(inputs.at(1));  // indices\n    }\n    if (ctx->tensor_requires_grad) {\n      ctx->SaveTensorForBackward(inputs.at(2));  // update: only use meta information\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const TensorScatterNdUpdateCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(3);\n    if (ctx->update_requires_grad) {\n      const auto& indices = ctx->SavedTensors().at(0);\n      in_grads->at(2) = JUST(functional::GatherNd(out_grads.at(0), indices));\n    }\n    if (ctx->tensor_requires_grad) {\n      const auto& indices = ctx->SavedTensors().at(0);\n      const auto& update = ctx->SavedTensors().at(1);\n      const auto& temp = JUST(functional::ZerosLike(update));\n      in_grads->at(0) = JUST(\n          functional::TensorScatterNdUpdate(out_grads.at(0), indices, temp, /*inplace=*/false));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"tensor_scatter_nd_update\", TensorScatterNdUpdate);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
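  {
    "path": "docs/autograd_sketches/tensor_scatter_nd_update_backward_sketch.cpp",
    "content": "// Illustrative standalone sketch under a hypothetical path; it is not part of\n// the OneFlow source tree. It restates TensorScatterNdUpdate's backward on a\n// flat array: positions named by `indices` were overwritten in the forward\n// pass, so the update receives the gradient gathered at those positions (the\n// functional::GatherNd call above), while the original tensor receives the\n// gradient with those positions zeroed out (the ZerosLike + scatter above).\n#include <cassert>\n#include <cstddef>\n#include <vector>\n\nvoid ScatterUpdateBackward(const std::vector<double>& out_grad,\n                           const std::vector<std::size_t>& indices,\n                           std::vector<double>* tensor_grad,\n                           std::vector<double>* update_grad) {\n  *tensor_grad = out_grad;  // start from dL/dy\n  update_grad->clear();\n  for (std::size_t idx : indices) {\n    update_grad->push_back(out_grad[idx]);  // gather at the scattered positions\n    (*tensor_grad)[idx] = 0.0;              // those positions never see the original tensor\n  }\n}\n\nint main() {\n  const std::vector<double> dy = {1.0, 2.0, 3.0, 4.0};\n  std::vector<double> d_tensor;\n  std::vector<double> d_update;\n  ScatterUpdateBackward(dy, {1, 3}, &d_tensor, &d_update);\n  assert((d_update == std::vector<double>{2.0, 4.0}));\n  assert((d_tensor == std::vector<double>{1.0, 0.0, 3.0, 0.0}));\n  return 0;\n}\n"
  },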
  {
    "path": "oneflow/core/autograd/gradient_funcs/tf_pool.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nnamespace {\n\nstruct TFPoolCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  size_t input_index = 0;\n  size_t output_index = 0;\n\n  std::string data_format;\n  std::string padding;\n  std::vector<int32_t> padding_before;\n  std::vector<int32_t> padding_after;\n  std::vector<int32_t> pool_size;\n  std::vector<int32_t> strides;\n  bool ceil_mode = false;\n};\n\nclass TFPoolNdGrad : public OpExprGradFunction<TFPoolCaptureState> {\n public:\n  virtual ~TFPoolNdGrad() = default;\n\n  using OpExprGradFunction<TFPoolCaptureState>::Init;\n\n  Maybe<void> Init(const OpExpr& op, const std::string& mode);\n  Maybe<void> Capture(TFPoolCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const TFPoolCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  std::string mode_;\n  AttrMap base_attrs_;\n};\n\nMaybe<void> TFPoolNdGrad::Init(const OpExpr& op, const std::string& mode) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  mode_ = mode;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> TFPoolNdGrad::Capture(TFPoolCaptureState* ctx, const TensorTuple& inputs,\n                                  const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));\n  ctx->output_index = ctx->SaveTensorForBackward(outputs.at(0));\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n  ctx->padding = JUST(composed_attrs.GetAttr<std::string>(\"padding\"));\n  ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"padding_before\"));\n  ctx->padding_after = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"padding_after\"));\n  ctx->pool_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"pool_size\"));\n  ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"strides\"));\n  ctx->ceil_mode = JUST(composed_attrs.GetAttr<bool>(\"ceil_mode\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> TFPoolNdGrad::Apply(const TFPoolCaptureState* ctx, const TensorTuple& out_grads,\n                      
          TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n  int32_t ndims = ctx->pool_size.size();\n  const auto& input = ctx->SavedTensors().at(ctx->input_index);\n  const auto& output = ctx->SavedTensors().at(ctx->output_index);\n\n  in_grads->resize(1);\n  (*in_grads)[0] = JUST(functional::TFPoolNdGrad(\n      input, output, out_grads[0], mode_, ndims, ctx->data_format, ctx->padding,\n      ctx->padding_before, ctx->padding_after, ctx->pool_size, ctx->strides, ctx->ceil_mode));\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nclass TFMaxPoolNdGrad final : public TFPoolNdGrad {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return TFPoolNdGrad::Init(op, \"tf_max\"); }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"tf_max_pool_1d\", TFMaxPoolNdGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"tf_max_pool_2d\", TFMaxPoolNdGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"tf_max_pool_3d\", TFMaxPoolNdGrad);\n\nclass TFAvgPoolNdGrad final : public TFPoolNdGrad {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return TFPoolNdGrad::Init(op, \"tf_avg\"); }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"tf_avg_pool_1d\", TFAvgPoolNdGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"tf_avg_pool_2d\", TFAvgPoolNdGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"tf_avg_pool_3d\", TFAvgPoolNdGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/to_contiguous.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ToContiguousCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n};\n\nclass ToContiguous : public OpExprGradFunction<ToContiguousCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(ToContiguousCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs[0]->requires_grad();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ToContiguousCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(1);\n    if (ctx->requires_grad) { (*in_grads)[0] = out_grads[0]; }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"to_contiguous\", ToContiguous);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/transpose.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct TransposeCaptureState : public AutoGradCaptureState {\n  std::vector<int32_t> perm;\n  bool requires_grad;\n};\n\nclass Transpose : public OpExprGradFunction<TransposeCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(TransposeCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const TransposeCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Transpose::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Transpose::Capture(TransposeCaptureState* ctx, const TensorTuple& inputs,\n                               const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->perm = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"perm\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Transpose::Apply(const TransposeCaptureState* ctx, const TensorTuple& out_grads,\n                             TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  std::vector<int32_t> grad_perm;\n  grad_perm.resize(ctx->perm.size());\n  FOR_RANGE(int32_t, i, 0, ctx->perm.size()) { grad_perm.at(ctx->perm.at(i)) = i; }\n  in_grads->at(0) = JUST(functional::Transpose(out_grads.at(0), grad_perm));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"transpose\", Transpose);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
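  {
    "path": "docs/autograd_sketches/transpose_backward_sketch.cpp",
    "content": "// Illustrative standalone sketch under a hypothetical path; it is not part of\n// the OneFlow source tree. Transpose::Apply above inverts the forward\n// permutation with grad_perm[perm[i]] = i; transposing the output gradient by\n// that inverse permutation routes every axis back to where it came from.\n#include <cassert>\n#include <cstdint>\n#include <vector>\n\nstd::vector<int32_t> InversePermutation(const std::vector<int32_t>& perm) {\n  std::vector<int32_t> inv(perm.size());\n  // Forward axis i lands at position perm[i], so the inverse sends it back.\n  for (int32_t i = 0; i < static_cast<int32_t>(perm.size()); ++i) { inv[perm[i]] = i; }\n  return inv;\n}\n\nint main() {\n  const std::vector<int32_t> perm = {2, 0, 1};\n  const std::vector<int32_t> inv = InversePermutation(perm);\n  assert((inv == std::vector<int32_t>{1, 2, 0}));\n  // Composing the two permutations yields the identity, i.e. transposing by\n  // `inv` undoes a transpose by `perm`.\n  for (int32_t i = 0; i < 3; ++i) { assert(inv[perm[i]] == i); }\n  return 0;\n}\n"
  },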
  {
    "path": "oneflow/core/autograd/gradient_funcs/tril.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct TrilCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  int64_t diagonal = 0;\n};\n\nclass Tril : public OpExprGradFunction<TrilCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(TrilCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const TrilCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Tril::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Tril::Capture(TrilCaptureState* ctx, const TensorTuple& inputs,\n                          const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->diagonal = JUST(composed_attrs.GetAttr<int64_t>(\"diagonal\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Tril::Apply(const TrilCaptureState* ctx, const TensorTuple& out_grads,\n                        TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(1);\n  if (ctx->requires_grad) {\n    in_grads->at(0) = JUST(functional::Tril(out_grads.at(0), ctx->diagonal));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"tril\", Tril);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/triu.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct TriuCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n  int64_t diagonal;\n};\n\nclass Triu : public OpExprGradFunction<TriuCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(TriuCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const TriuCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Triu::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Triu::Capture(TriuCaptureState* ctx, const TensorTuple& inputs,\n                          const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->diagonal = JUST(composed_attrs.GetAttr<int64_t>(\"diagonal\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Triu::Apply(const TriuCaptureState* ctx, const TensorTuple& out_grads,\n                        TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(1);\n  if (ctx->requires_grad) {\n    in_grads->at(0) = JUST(functional::Triu(out_grads.at(0), ctx->diagonal));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"triu\", Triu);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
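  {
    "path": "docs/autograd_sketches/triu_backward_sketch.cpp",
    "content": "// Illustrative standalone sketch under a hypothetical path; it is not part of\n// the OneFlow source tree. Triu::Apply above pushes the output gradient\n// through functional::Triu again; that is valid because keeping the upper\n// triangle is an elementwise 0/1 mask, and such a mask is its own adjoint\n// (and idempotent). The same argument covers Tril.\n#include <cassert>\n#include <cstdint>\n#include <vector>\n\nusing Matrix = std::vector<std::vector<double>>;\n\n// Zeroes everything below the `diagonal`-th diagonal, like functional::Triu.\nMatrix Triu(Matrix m, int64_t diagonal) {\n  for (int64_t i = 0; i < static_cast<int64_t>(m.size()); ++i) {\n    for (int64_t j = 0; j < static_cast<int64_t>(m[i].size()); ++j) {\n      if (j - i < diagonal) { m[i][j] = 0.0; }\n    }\n  }\n  return m;\n}\n\nint main() {\n  const Matrix dy = {{1.0, 2.0}, {3.0, 4.0}};\n  // The backward pass applies the same mask to the incoming gradient;\n  // applying it twice changes nothing, so the masks agree.\n  assert(Triu(dy, 0) == Triu(Triu(dy, 0), 0));\n  assert((Triu(dy, 0) == Matrix{{1.0, 2.0}, {0.0, 4.0}}));\n  return 0;\n}\n"
  },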
  {
    "path": "oneflow/core/autograd/gradient_funcs/trunc.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct TruncCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n};\n\nclass Trunc : public OpExprGradFunction<TruncCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(TruncCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const TruncCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n};\n\nMaybe<void> Trunc::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Trunc::Capture(TruncCaptureState* ctx, const TensorTuple& inputs,\n                           const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Trunc::Apply(const TruncCaptureState* ctx, const TensorTuple& out_grads,\n                         TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(1);\n  if (ctx->requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"trunc\", Trunc);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/two_stage_reduce.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nenum class ReduceMode : int32_t {\n  kMin = 0,\n  kMax = 1,\n};\n\nstruct ReduceDeviceCaptureState : public AutoGradCaptureState {\n  std::vector<int32_t> axis;\n  bool requires_grad = false;\n  size_t mask_index = -1;\n  size_t count_index = -1;\n};\n\ntemplate<ReduceMode mode>\nclass ReduceDevice : public OpExprGradFunction<ReduceDeviceCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ReduceDeviceCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"axis\"));\n    ctx->mask_index = ctx->SaveTensorForBackward(outputs.at(1));   // mask\n    ctx->count_index = ctx->SaveTensorForBackward(outputs.at(2));  // count\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ReduceDeviceCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    CHECK_EQ_OR_RETURN(out_grads.size(), 3);  // NOLINT(maybe-need-error-msg)\n    const auto& mask = ctx->SavedTensors().at(ctx->mask_index);\n    const auto& count = ctx->SavedTensors().at(ctx->count_index);\n    in_grads->resize(1);\n    if (mode == ReduceMode::kMin) {\n      in_grads->at(0) =\n          JUST(functional::ReduceMinDeviceStageGrad(out_grads.at(0), mask, count, ctx->axis));\n    } else {\n      in_grads->at(0) =\n          JUST(functional::ReduceMaxDeviceStageGrad(out_grads.at(0), mask, count, ctx->axis));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"reduce_min_device_stage\", ReduceDevice<ReduceMode::kMin>);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"reduce_max_device_stage\", ReduceDevice<ReduceMode::kMax>);\n\nstruct ReduceGlobalCaptureState : public AutoGradCaptureState {\n  std::vector<int32_t> axis;\n  bool requires_grad = false;\n  bool keepdims = false;\n  size_t mask_index = -1;\n  size_t device_count_index = -1;\n};\n\ntemplate<ReduceMode mode>\nclass ReduceGlobal : public OpExprGradFunction<ReduceGlobalCaptureState> {\n public:\n  Maybe<void> Init(const 
OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ReduceGlobalCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 2);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"axis\"));\n    ctx->keepdims = JUST(composed_attrs.GetAttr<bool>(\"keepdims\"));\n    ctx->mask_index = ctx->SaveTensorForBackward(outputs.at(1));         // mask\n    ctx->device_count_index = ctx->SaveTensorForBackward(inputs.at(1));  // device_count\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ReduceGlobalCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)\n    const auto& mask = ctx->SavedTensors().at(ctx->mask_index);\n    const auto& device_count = ctx->SavedTensors().at(ctx->device_count_index);\n    in_grads->resize(2);\n    if (mode == ReduceMode::kMin) {\n      in_grads->at(0) = JUST(functional::ReduceMinGlobalStageGrad(\n          out_grads.at(0), mask, device_count, ctx->axis, ctx->keepdims));\n    } else {\n      in_grads->at(0) = JUST(functional::ReduceMaxGlobalStageGrad(\n          out_grads.at(0), mask, device_count, ctx->axis, ctx->keepdims));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"reduce_min_global_stage\", ReduceGlobal<ReduceMode::kMin>);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"reduce_max_global_stage\", ReduceGlobal<ReduceMode::kMax>);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/unfold.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct UnfoldInterpState : public AutoGradCaptureState {\n  bool requires_grad = true;\n  std::string data_format = \"channels_first\";\n  std::vector<int32_t> output_size;\n  std::vector<int32_t> kernel_size;\n  std::vector<int32_t> dilation_rate;\n  std::vector<int32_t> padding;\n  std::vector<int32_t> strides;\n};\n\nclass Unfold : public OpExprGradFunction<UnfoldInterpState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(UnfoldInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const UnfoldInterpState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Unfold::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Unfold::Capture(UnfoldInterpState* ctx, const TensorTuple& inputs,\n                            const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  std::vector<int32_t> out_shape(2);\n  const std::shared_ptr<Tensor>& x = inputs.at(0);\n  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n  ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"kernel_size\"));\n  ctx->dilation_rate = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"dilation_rate\"));\n  ctx->padding = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"padding\"));\n  ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"strides\"));\n  // Only support 4-d Tensor Input.\n  for (int i = 0; i < 2; i++) { out_shape.at(i) = (x->shape()->At(i + 2)); }\n  ctx->output_size = out_shape;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Unfold::Apply(const UnfoldInterpState* ctx, const TensorTuple& out_grads,\n                          TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(1);\n  in_grads->at(0) =\n      JUST(functional::Fold(out_grads.at(0), ctx->output_size, ctx->kernel_size, ctx->dilation_rate,\n                            ctx->padding, ctx->strides, ctx->data_format));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"unfold\", Unfold);\n\n}  // namespace 
one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/unfold_tensor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\n\nnamespace one {\n\nstruct UnfoldTensorCaptureState : public AutoGradCaptureState {\n  int32_t dimension = -1;\n  int32_t size = -1;\n  int32_t step = -1;\n  bool requires_grad = false;\n};\n\nclass UnfoldTensor : public OpExprGradFunction<UnfoldTensorCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(UnfoldTensorCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const UnfoldTensorCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n  std::shared_ptr<OpExpr> grad_op_;\n};\n\nMaybe<void> UnfoldTensor::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\nMaybe<void> UnfoldTensor::Capture(UnfoldTensorCaptureState* ctx, const TensorTuple& inputs,\n                                  const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->dimension = JUST(composed_attrs.GetAttr<int32_t>(\"dimension\"));\n  ctx->size = JUST(composed_attrs.GetAttr<int32_t>(\"size\"));\n  ctx->step = JUST(composed_attrs.GetAttr<int32_t>(\"step\"));\n  ctx->SaveTensorForBackward(inputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> UnfoldTensor::Apply(const UnfoldTensorCaptureState* ctx, const TensorTuple& out_grads,\n                                TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  const auto& in = ctx->SavedTensors().at(0);\n  in_grads->at(0) =\n      JUST(functional::UnfoldTensorGrad(out_grads.at(0), in, ctx->dimension, ctx->size, ctx->step));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"unfold_tensor\", UnfoldTensor);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/unsqueeze.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct UnsqueezeCaptureState : public AutoGradCaptureState {\n  bool requires_grad;\n  Shape shape;\n};\n\nclass Unsqueeze : public OpExprGradFunction<UnsqueezeCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(UnsqueezeCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const UnsqueezeCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Unsqueeze::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Unsqueeze::Capture(UnsqueezeCaptureState* ctx, const TensorTuple& inputs,\n                               const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  if (LazyMode::is_enabled()) {\n    ctx->SaveTensorForBackward(inputs.at(0));\n  } else {\n    ctx->shape = *(inputs.at(0)->shape());\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Unsqueeze::Apply(const UnsqueezeCaptureState* ctx, const TensorTuple& out_grads,\n                             TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n  in_grads->resize(1);\n  if (LazyMode::is_enabled()) {\n    const auto& like = ctx->SavedTensors().at(0);\n    in_grads->at(0) = JUST(functional::ReshapeLike(out_grads.at(0), like));\n  } else {\n    in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), ctx->shape));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"expand_dims\", Unsqueeze);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
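  {
    "path": "docs/autograd_sketches/unsqueeze_backward_sketch.cpp",
    "content": "// Illustrative standalone sketch under a hypothetical path; it is not part of\n// the OneFlow source tree. Unsqueeze (expand_dims) only inserts a length-1\n// axis, so its backward is a pure reshape of the output gradient to the\n// captured input shape, exactly as Unsqueeze::Apply above does via\n// functional::Reshape / functional::ReshapeLike.\n#include <cassert>\n#include <cstddef>\n#include <cstdint>\n#include <vector>\n\nstd::vector<int64_t> UnsqueezedShape(std::vector<int64_t> shape, std::size_t axis) {\n  shape.insert(shape.begin() + axis, 1);  // the forward op's only effect on shape\n  return shape;\n}\n\nint main() {\n  const std::vector<int64_t> in_shape = {2, 3};\n  const std::vector<int64_t> out_shape = UnsqueezedShape(in_shape, 1);\n  assert((out_shape == std::vector<int64_t>{2, 1, 3}));\n  // Element counts agree, so reshaping the output gradient back to in_shape\n  // is always legal and moves no data.\n  int64_t in_count = 1;\n  int64_t out_count = 1;\n  for (int64_t d : in_shape) { in_count *= d; }\n  for (int64_t d : out_shape) { out_count *= d; }\n  assert(in_count == out_count);\n  return 0;\n}\n"
  },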
  {
    "path": "oneflow/core/autograd/gradient_funcs/upsample.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct UpsampleCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  double height_scale = 0.0;\n  double width_scale = 0.0;\n  float align_corners;\n  std::string data_format;\n  std::string interpolation;\n};\n\nclass Upsample : public OpExprGradFunction<UpsampleCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(UpsampleCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const UpsampleCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n  std::shared_ptr<OpExpr> grad_op_;\n};\n\nMaybe<void> Upsample::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Upsample::Capture(UpsampleCaptureState* ctx, const TensorTuple& inputs,\n                              const TensorTuple& outputs, const AttrMap& attrs) const {\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->height_scale = JUST(composed_attrs.GetAttr<double>(\"height_scale\"));\n  ctx->width_scale = JUST(composed_attrs.GetAttr<double>(\"width_scale\"));\n  ctx->align_corners = JUST(composed_attrs.GetAttr<bool>(\"align_corners\"));\n  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n  ctx->interpolation = JUST(composed_attrs.GetAttr<std::string>(\"interpolation\"));\n  ctx->SaveTensorForBackward(inputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Upsample::Apply(const UpsampleCaptureState* ctx, const TensorTuple& out_grads,\n                            TensorTuple* in_grads) const {\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n  const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);\n  in_grads->resize(1);\n  JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleGrad(\n      JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale,\n      ctx->align_corners, ctx->data_format, ctx->interpolation));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"upsample\", Upsample);\n\nstruct UpsampleNearest2DCaptureState : public AutoGradCaptureState 
{\n  bool requires_grad = false;\n  double height_scale = 0.0;\n  double width_scale = 0.0;\n  std::vector<int64_t> output_size;\n  std::string data_format;\n};\n\nclass UpsampleNearest2D : public OpExprGradFunction<UpsampleNearest2DCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(UpsampleNearest2DCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->height_scale = JUST(composed_attrs.GetAttr<double>(\"height_scale\"));\n    ctx->width_scale = JUST(composed_attrs.GetAttr<double>(\"width_scale\"));\n    if (composed_attrs.Has(\"output_size\")) {\n      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"output_size\"));\n    }\n    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n    ctx->SaveTensorForBackward(inputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const UpsampleNearest2DCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);\n    in_grads->resize(1);\n    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest2DGrad(\n        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale,\n        ctx->output_size, ctx->data_format));\n\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"upsample_nearest_2d\", UpsampleNearest2D);\n\nstruct UpsampleBilinear2DCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  double height_scale = 0.0;\n  double width_scale = 0.0;\n  bool align_corners;\n  std::vector<int64_t> output_size;\n  std::string data_format;\n};\n\nclass UpsampleBilinear2D : public OpExprGradFunction<UpsampleBilinear2DCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(UpsampleBilinear2DCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->height_scale = JUST(composed_attrs.GetAttr<double>(\"height_scale\"));\n    ctx->width_scale = JUST(composed_attrs.GetAttr<double>(\"width_scale\"));\n    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>(\"align_corners\"));\n    if (composed_attrs.Has(\"output_size\")) {\n      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"output_size\"));\n    }\n    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n    
ctx->SaveTensorForBackward(inputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const UpsampleBilinear2DCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);\n    in_grads->resize(1);\n    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleBilinear2DGrad(\n        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale,\n        ctx->align_corners, ctx->output_size, ctx->data_format));\n\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"upsample_bilinear_2d\", UpsampleBilinear2D);\n\nstruct UpsampleLinear1DCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  double scale_factor = 0.0;\n  bool align_corners = false;\n  std::vector<int64_t> output_size;\n  std::string data_format;\n};\n\nclass UpsampleLinear1D : public OpExprGradFunction<UpsampleLinear1DCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(UpsampleLinear1DCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->scale_factor = JUST(composed_attrs.GetAttr<double>(\"scale_factor\"));\n    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>(\"align_corners\"));\n    if (composed_attrs.Has(\"output_size\")) {\n      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"output_size\"));\n    }\n    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n    ctx->SaveTensorForBackward(inputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const UpsampleLinear1DCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);\n    in_grads->resize(1);\n    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleLinear1DGrad(\n        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->scale_factor, ctx->align_corners,\n        ctx->output_size, ctx->data_format));\n\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"upsample_linear_1d\", UpsampleLinear1D);\n\nstruct UpsampleNearest1DCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  double scale_factor = 0.0;\n  std::vector<int64_t> output_size;\n  std::string data_format;\n};\n\nclass UpsampleNearest1D : public OpExprGradFunction<UpsampleNearest1DCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(UpsampleNearest1DCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& 
outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->scale_factor = JUST(composed_attrs.GetAttr<double>(\"scale_factor\"));\n    if (composed_attrs.Has(\"output_size\")) {\n      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"output_size\"));\n    }\n    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n    ctx->SaveTensorForBackward(inputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const UpsampleNearest1DCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);\n    in_grads->resize(1);\n    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(\n        functional::UpsampleNearest1DGrad(JUST(oneflow::VectorAt(out_grads, 0)), x,\n                                          ctx->scale_factor, ctx->output_size, ctx->data_format));\n\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"upsample_nearest_1d\", UpsampleNearest1D);\n\nstruct UpsampleBicubic2DCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  double height_scale = 0.0;\n  double width_scale = 0.0;\n  bool align_corners = false;\n  std::vector<int64_t> output_size;\n  std::string data_format;\n};\n\nclass UpsampleBicubic2D : public OpExprGradFunction<UpsampleBicubic2DCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(UpsampleBicubic2DCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->height_scale = JUST(composed_attrs.GetAttr<double>(\"height_scale\"));\n    ctx->width_scale = JUST(composed_attrs.GetAttr<double>(\"width_scale\"));\n    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>(\"align_corners\"));\n    if (composed_attrs.Has(\"output_size\")) {\n      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"output_size\"));\n    }\n    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n    ctx->SaveTensorForBackward(inputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const UpsampleBicubic2DCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);\n    in_grads->resize(1);\n    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleBicubic2DGrad(\n    
    JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale,\n        ctx->align_corners, ctx->output_size, ctx->data_format));\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"upsample_bicubic_2d\", UpsampleBicubic2D);\n\nstruct UpsampleNearest3DCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  double depth_scale = 0.0;\n  double height_scale = 0.0;\n  double width_scale = 0.0;\n  std::vector<int64_t> output_size;\n  std::string data_format;\n};\n\nclass UpsampleNearest3D : public OpExprGradFunction<UpsampleNearest3DCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(UpsampleNearest3DCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->depth_scale = JUST(composed_attrs.GetAttr<double>(\"depth_scale\"));\n    ctx->height_scale = JUST(composed_attrs.GetAttr<double>(\"height_scale\"));\n    ctx->width_scale = JUST(composed_attrs.GetAttr<double>(\"width_scale\"));\n    if (composed_attrs.Has(\"output_size\")) {\n      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"output_size\"));\n    }\n    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n    ctx->SaveTensorForBackward(inputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const UpsampleNearest3DCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);\n    in_grads->resize(1);\n    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest3DGrad(\n        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->depth_scale, ctx->height_scale,\n        ctx->width_scale, ctx->output_size, ctx->data_format));\n\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"upsample_nearest_3d\", UpsampleNearest3D);\n\nstruct UpsampleTrilinear3DCaptureState : public AutoGradCaptureState {\n  bool requires_grad = false;\n  double depth_scale = 0.0;\n  double height_scale = 0.0;\n  double width_scale = 0.0;\n  bool align_corners = false;\n  std::vector<int64_t> output_size;\n  std::string data_format;\n};\n\nclass UpsampleTrilinear3D : public OpExprGradFunction<UpsampleTrilinear3DCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(UpsampleTrilinear3DCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->requires_grad = inputs.at(0)->requires_grad();\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    ComposedAttrMap 
composed_attrs(attrs, base_attrs_);\n    ctx->depth_scale = JUST(composed_attrs.GetAttr<double>(\"depth_scale\"));\n    ctx->height_scale = JUST(composed_attrs.GetAttr<double>(\"height_scale\"));\n    ctx->width_scale = JUST(composed_attrs.GetAttr<double>(\"width_scale\"));\n    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>(\"align_corners\"));\n    if (composed_attrs.Has(\"output_size\")) {\n      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"output_size\"));\n    }\n    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n    ctx->SaveTensorForBackward(inputs.at(0));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const UpsampleTrilinear3DCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);\n    in_grads->resize(1);\n    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleTrilinear3DGrad(\n        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->depth_scale, ctx->height_scale,\n        ctx->width_scale, ctx->align_corners, ctx->output_size, ctx->data_format));\n\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"upsample_trilinear_3d\", UpsampleTrilinear3D);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/variance.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct VarianceState : public AutoGradCaptureState {\n  VarianceState() : requires_grad(false), unbiased(true), keepdim(false), axis({}){};\n  bool requires_grad;\n  bool unbiased;\n  bool keepdim;\n  std::vector<int32_t> axis;\n};\n\nclass Variance : public OpExprGradFunction<VarianceState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(VarianceState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const VarianceState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> Variance::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Variance::Capture(VarianceState* ctx, const TensorTuple& inputs,\n                              const TensorTuple& outputs, const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n  ctx->requires_grad = inputs.at(0)->requires_grad();\n  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->keepdim = JUST(composed_attrs.GetAttr<bool>(\"keepdim\"));\n  ctx->unbiased = JUST(composed_attrs.GetAttr<bool>(\"unbiased\"));\n  ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"dim\"));\n  ctx->SaveTensorForBackward(inputs.at(0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Variance::Apply(const VarianceState* ctx, const TensorTuple& out_grads,\n                            TensorTuple* in_grads) const {\n  // TODO(): replace it using kernel\n  const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);\n  DataType data_type = x->dtype()->data_type();\n  CHECK_NE_OR_RETURN(data_type, DataType::kBFloat16)\n      << Error::RuntimeError() << \"Variance op not support backward for bfloat16 yet!\";\n  size_t correction = ctx->unbiased ? 
1 : 0;\n  size_t elem_cnt = 1;\n  CHECK_OR_RETURN(ctx->axis.size() > 0)\n      << Error::RuntimeError() << \"The size of the axis must be greater than 0, but got \"\n      << ctx->axis.size();\n  for (const auto& item : ctx->axis) { elem_cnt *= x->shape()->At(item); }\n\n  std::shared_ptr<Tensor> out_grad = out_grads.at(0);\n  if (!ctx->keepdim) {\n    // for broadcast mul\n    const std::shared_ptr<const Shape>& out_grad_shape = out_grad->shape();\n    DimVector unsqueeze_vector(out_grad_shape->dim_vec());\n    for (int i = 0; i < ctx->axis.size(); i++) {\n      unsqueeze_vector.insert(unsqueeze_vector.begin() + ctx->axis.at(i), 1);\n    }\n    Shape unsqueeze_shape(unsqueeze_vector);\n    CHECK_EQ_OR_RETURN(unsqueeze_shape.elem_cnt(), out_grad_shape->elem_cnt())\n        << Error::RuntimeError()\n        << \"tensor size mismatch, expected tensor to have the same number of elements, but got \"\n        << unsqueeze_shape.elem_cnt() << \" and \" << out_grad_shape->elem_cnt()\n        << \" elements respectively\";\n    out_grad = JUST(functional::Reshape(out_grad, unsqueeze_shape));\n  }\n\n  in_grads->resize(1);\n  in_grads->at(0) = JUST(functional::Mul(\n      out_grad,\n      JUST(functional::ScalarMul(\n          Scalar(2.0 / (elem_cnt - correction)),\n          JUST(functional::Sub(x, JUST(functional::ReduceMean(x, ctx->axis, /*keepdim=*/true)),\n                               /*alpha=*/1.0, /*inplace=*/false))))));\n\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"var\", Variance);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/vector_matrix_product.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct VectorMatrixProductCaptureState : public AutoGradCaptureState {\n  bool requires_grad_a = false;\n  bool requires_grad_b = false;\n  size_t a_index = 0;\n  size_t b_index = 1;\n};\n\nclass VectorMatrixProduct : public OpExprGradFunction<VectorMatrixProductCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(VectorMatrixProductCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const VectorMatrixProductCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n protected:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> VectorMatrixProduct::Init(const OpExpr& op) {\n  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr) << \"fw_op_expr should not be null. \";\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> VectorMatrixProduct::Capture(VectorMatrixProductCaptureState* ctx,\n                                         const TensorTuple& inputs, const TensorTuple& outputs,\n                                         const AttrMap& attrs) const {\n  ctx->requires_grad_a = JUST(VectorAt(inputs, 0))->requires_grad();\n  ctx->requires_grad_b = JUST(VectorAt(inputs, 1))->requires_grad();\n  if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  if (ctx->requires_grad_a) {\n    ctx->b_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));  // input b\n  }\n  if (ctx->requires_grad_b) {\n    ctx->a_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));  // input a\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> VectorMatrixProduct::Apply(const VectorMatrixProductCaptureState* ctx,\n                                       const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1) << \"Out grad size should be equal to 1. 
\";\n\n  in_grads->resize(2);\n  if (ctx->requires_grad_a) {\n    const auto& input_b = JUST(VectorAt(ctx->SavedTensors(), ctx->b_index));\n    JUST(VectorAt(*in_grads, 0)) =\n        JUST(functional::VectorMatrixProductGradA(JUST(VectorAt(out_grads, 0)), input_b));\n  }\n\n  if (ctx->requires_grad_b) {\n    const auto& input_a = JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->a_index));\n    JUST(VectorAt(*in_grads, 1)) =\n        JUST(functional::VectorMatrixProductGradB(JUST(VectorAt(out_grads, 0)), input_a));\n  }\n\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"vector_matrix_product\", VectorMatrixProduct);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/gradient_funcs/where.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct WhereCaptureState : public AutoGradCaptureState {\n  bool requires_grad_x = false;\n  bool requires_grad_y = false;\n  DimVector x_reduce_dims = {};\n  DimVector y_reduce_dims = {};\n  DimVector x_squeeze_dims = {};\n  DimVector y_squeeze_dims = {};\n};\n\nclass Where : public OpExprGradFunction<WhereCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(WhereCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const WhereCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n};\n\nMaybe<void> Where::Init(const OpExpr& op) { return Maybe<void>::Ok(); }\n\nMaybe<void> Where::Capture(WhereCaptureState* ctx, const TensorTuple& inputs,\n                           const TensorTuple& outputs, const AttrMap& attrs) const {\n  // cond, x, y\n  CHECK_EQ_OR_RETURN(inputs.size(), 3);  // NOLINT(maybe-need-error-msg)\n  ctx->requires_grad_x = inputs.at(1)->requires_grad();\n  ctx->requires_grad_y = inputs.at(2)->requires_grad();\n  if ((!ctx->requires_grad_x) && (!ctx->requires_grad_y)) { return Maybe<void>::Ok(); }\n\n  ctx->SaveTensorForBackward(inputs.at(0));  // condition\n\n  CHECK_EQ_OR_RETURN(outputs.size(), 1);\n  const Shape& out_shape = *outputs.at(0)->shape();\n  auto GetReduceDims = [&](DimVector& reduce_dim_vec, DimVector& squeeze_dim_vec,\n                           const std::shared_ptr<oneflow::one::Tensor>& tensor) -> Maybe<void> {\n    reduce_dim_vec.clear();\n    squeeze_dim_vec.clear();\n    const Shape& shape = *tensor->shape();\n    if (functional::IsScalarTensor(tensor)) {\n      reduce_dim_vec.resize(out_shape.size());\n      squeeze_dim_vec.resize(out_shape.size());\n      std::iota(reduce_dim_vec.begin(), reduce_dim_vec.end(), 0);\n      std::iota(squeeze_dim_vec.begin(), squeeze_dim_vec.end(), 0);\n    } else if (shape != out_shape) {\n      CHECK_GE_OR_RETURN(out_shape.size(), shape.size());  // NOLINT(maybe-need-error-msg)\n      size_t ddiff = out_shape.size() - shape.size();\n      for (int i = 0; i < out_shape.size(); ++i) {\n        if (i < ddiff) {\n          reduce_dim_vec.push_back(i);\n          squeeze_dim_vec.push_back(i);\n        } else if (out_shape[i] != shape[i - ddiff]) {\n          reduce_dim_vec.push_back(i);\n        }\n      }\n    }\n    return Maybe<void>::Ok();\n  };\n  JUST(GetReduceDims(ctx->x_reduce_dims, ctx->x_squeeze_dims, inputs.at(1)));\n  JUST(GetReduceDims(ctx->y_reduce_dims, ctx->y_squeeze_dims, inputs.at(2)));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Where::Apply(const WhereCaptureState* ctx, const TensorTuple& 
out_grads,\n                         TensorTuple* in_grads) const {\n  if ((!ctx->requires_grad_x) && (!ctx->requires_grad_y)) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  const auto& out_grad = out_grads.at(0);\n  CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 1);  // NOLINT(maybe-need-error-msg)\n  const auto& condition = ctx->SavedTensors().at(0);\n  std::shared_ptr<oneflow::one::Tensor> zero;\n  if (out_grad->is_local()) {\n    zero = JUST(\n        functional::Constant(Shape({}), Scalar(0), out_grad->dtype(), JUST(out_grad->device())));\n  } else {\n    const size_t sbp_ndim = JUST(out_grad->nd_sbp())->sbp_parallel_size();\n    std::vector<Symbol<SbpParallel>> nd_sbp_vec;\n    nd_sbp_vec.reserve(sbp_ndim);\n    for (int i = 0; i < sbp_ndim; ++i) {\n      SbpParallel sbp;\n      sbp.mutable_broadcast_parallel();\n      nd_sbp_vec.push_back(SymbolOf(sbp));\n    }\n    const auto& parallel_desc = JUST(out_grad->parallel_desc());\n    zero = JUST(functional::GlobalConstant(Shape({}), Scalar(0), out_grad->dtype(), parallel_desc,\n                                           nd_sbp_vec));\n  }\n  in_grads->resize(3);  // cond, x, y\n  if (ctx->requires_grad_x) {\n    auto x_grad = JUST(functional::Where(condition, out_grad, zero));\n    if (!ctx->x_reduce_dims.empty()) {\n      x_grad = JUST(functional::ReduceSum(\n          x_grad, std::vector<int32_t>{ctx->x_reduce_dims.begin(), ctx->x_reduce_dims.end()},\n          /*keepdims=*/true, NullOpt));\n    }\n    if (!ctx->x_squeeze_dims.empty()) {\n      x_grad = JUST(functional::Squeeze(\n          x_grad, std::vector<int32_t>{ctx->x_squeeze_dims.begin(), ctx->x_squeeze_dims.end()}));\n    }\n    in_grads->at(1) = x_grad;\n  }\n  if (ctx->requires_grad_y) {\n    auto y_grad = JUST(functional::Where(condition, zero, out_grad));\n    if (!ctx->y_reduce_dims.empty()) {\n      y_grad = JUST(functional::ReduceSum(\n          y_grad, std::vector<int32_t>{ctx->y_reduce_dims.begin(), ctx->y_reduce_dims.end()},\n          /*keepdims=*/true, NullOpt));\n    }\n    if (!ctx->y_squeeze_dims.empty()) {\n      y_grad = JUST(functional::Squeeze(\n          y_grad, std::vector<int32_t>{ctx->y_squeeze_dims.begin(), ctx->y_squeeze_dims.end()}));\n    }\n    in_grads->at(2) = y_grad;\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"where\", Where);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/activation.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cstddef>\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct BaseActivationGradGradCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad = false;\n  bool grad_requires_grad = false;\n};\n\ntypedef Maybe<one::Tensor> (*NoParamActivationBwFunc)(const std::shared_ptr<one::Tensor>&,\n                                                      const std::shared_ptr<one::Tensor>&);\n\ntemplate<NoParamActivationBwFunc BwFunc, NoParamActivationBwFunc BwBwFunc>\nclass NoParamActivationGradGrad : public OpExprGradFunction<BaseActivationGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(BaseActivationGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // dy, x\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n    ctx->x_requires_grad = inputs.at(1)->requires_grad();\n    ctx->grad_requires_grad = inputs.at(0)->requires_grad();\n\n    if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }\n\n    ctx->SaveTensorForBackward(inputs.at(1));\n    if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const BaseActivationGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n    const auto& x = ctx->SavedTensors().at(0);\n\n    if (ctx->x_requires_grad) {\n      const auto& grad = ctx->SavedTensors().at(1);\n      in_grads->at(1) = JUST(functional::Mul(out_grads.at(0), JUST(BwBwFunc(x, grad))));\n    }\n    if (ctx->grad_requires_grad) { in_grads->at(0) = JUST(BwFunc(out_grads.at(0), x)); }\n    return Maybe<void>::Ok();\n  }\n};\n\n#define INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS(op_type_name, op_cls)                     \\\n  class op_cls##GradGradCls final                                                                  \\\n      : public NoParamActivationGradGrad<functional::op_cls##Grad, functional::op_cls##GradGrad> { \\\n  };                                                                                               \\\n  REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##GradGradCls);\n\n// first order backward param: (dy, x)\nINSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS(\"mish_grad\", Mish)\nINSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS(\"gelu_grad\", 
Gelu)\nINSTANTIATE_AND_REGISTER_NOPARAM_ACTIVATION_CLASS(\"silu_grad\", Silu)\nINSTANTIATE_AND_REGISTER_NOPARAM_ACTIVATION_CLASS(\"selu_grad\", Selu)\nINSTANTIATE_AND_REGISTER_NOPARAM_ACTIVATION_CLASS(\"softsign_grad\", SoftSign)\nINSTANTIATE_AND_REGISTER_NOPARAM_ACTIVATION_CLASS(\"hardsigmoid_grad\", HardSigmoid)\nINSTANTIATE_AND_REGISTER_NOPARAM_ACTIVATION_CLASS(\"hardswish_grad\", HardSwish)\n\n#undef INSTANTIATE_AND_REGISTER_NOPARAM_ACTIVATION_CLASS\n\nstruct HardShrinkGradGradCaptureState : public AutoGradCaptureState {\n  bool y_requires_grad = false;\n  bool grad_requires_grad = false;\n  double lambd = 0.5;\n};\n\nclass HardShrinkGradGrad : public OpExprGradFunction<HardShrinkGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Capture(HardShrinkGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // y, dy\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n    ctx->y_requires_grad = inputs.at(0)->requires_grad();\n    ctx->grad_requires_grad = inputs.at(1)->requires_grad();\n    if (!ctx->y_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->lambd = JUST(composed_attrs.GetAttr<double>(\"lambd\"));\n    if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const HardShrinkGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n\n    if (ctx->y_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }\n    if (ctx->grad_requires_grad) {\n      const auto& y = ctx->SavedTensors().at(0);\n      in_grads->at(1) = JUST(functional::HardShrinkGrad(y, out_grads.at(0), ctx->lambd));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nstruct SoftShrinkGradGradCaptureState : public AutoGradCaptureState {\n  bool y_requires_grad = false;\n  bool grad_requires_grad = false;\n  double alpha = 0.5;\n};\n\nclass SoftShrinkGradGrad : public OpExprGradFunction<SoftShrinkGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Capture(SoftShrinkGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // y, dy\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n    ctx->y_requires_grad = inputs.at(0)->requires_grad();\n    ctx->grad_requires_grad = inputs.at(1)->requires_grad();\n    if (!ctx->y_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    
ctx->alpha = JUST(composed_attrs.GetAttr<double>(\"alpha\"));\n    if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const SoftShrinkGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n\n    if (ctx->y_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }\n    if (ctx->grad_requires_grad) {\n      const auto& y = ctx->SavedTensors().at(0);\n      in_grads->at(1) = JUST(functional::SoftShrinkGrad(y, out_grads.at(0), ctx->alpha));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nstruct ReluGradGradCaptureState : public AutoGradCaptureState {\n  bool y_requires_grad = false;\n  bool grad_requires_grad = false;\n};\n\nclass ReluGradGrad : public OpExprGradFunction<ReluGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n  Maybe<void> Capture(ReluGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // dy, y\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n    ctx->y_requires_grad = inputs.at(1)->requires_grad();\n    ctx->grad_requires_grad = inputs.at(0)->requires_grad();\n\n    if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Apply(const ReluGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n    if (ctx->y_requires_grad) { in_grads->at(1) = JUST(functional::ZerosLike(out_grads.at(0))); }\n    if (ctx->grad_requires_grad) {\n      const auto& y = ctx->SavedTensors().at(0);\n      in_grads->at(0) = JUST(functional::ReluGrad(out_grads.at(0), y));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nstruct LeakyReluGradGradCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad = false;\n  bool grad_requires_grad = false;\n  float alpha = 0.01;\n};\n\nclass LeakyReluGradGrad : public OpExprGradFunction<LeakyReluGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(LeakyReluGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // x, dy\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    ctx->grad_requires_grad = inputs.at(1)->requires_grad();\n    if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->alpha = JUST(composed_attrs.GetAttr<float>(\"alpha\"));\n\n    if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const LeakyReluGradGradCaptureState* ctx, const TensorTuple& out_grads,\n        
            TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n    if (ctx->x_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }\n    if (ctx->grad_requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(1) = JUST(functional::LeakyReluGrad(x, out_grads.at(0), ctx->alpha));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nstruct SoftplusGradGradCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad = false;\n  bool grad_requires_grad = false;\n  double beta = 1.0;\n  double threshold = 20.0;\n};\n\nclass SoftplusGradGrad : public OpExprGradFunction<SoftplusGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(SoftplusGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // x, dy\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)\n\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    ctx->grad_requires_grad = inputs.at(1)->requires_grad();\n    if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->beta = JUST(composed_attrs.GetAttr<double>(\"beta\"));\n    ctx->threshold = JUST(composed_attrs.GetAttr<double>(\"threshold\"));\n\n    ctx->SaveTensorForBackward(inputs.at(0));\n    if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const SoftplusGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n    const auto& x = ctx->SavedTensors().at(0);\n\n    if (ctx->x_requires_grad) {\n      const auto& grad = ctx->SavedTensors().at(1);\n      in_grads->at(0) = JUST(functional::Mul(\n          out_grads.at(0), JUST(functional::SoftplusGradGrad(x, grad, ctx->beta, ctx->threshold))));\n    }\n    if (ctx->grad_requires_grad) {\n      in_grads->at(1) =\n          JUST(functional::SoftplusGrad(x, out_grads.at(0), ctx->beta, ctx->threshold));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nstruct HardTanhGradGradCaptureState : public AutoGradCaptureState {\n  bool y_requires_grad = false;\n  bool grad_requires_grad = false;\n  double min_val = -1.0;\n  double max_val = 1.0;\n};\n\nclass HardTanhGradGrad : public OpExprGradFunction<HardTanhGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Capture(HardTanhGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // y, dy\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n    ctx->y_requires_grad = 
inputs.at(0)->requires_grad();\n    ctx->grad_requires_grad = inputs.at(1)->requires_grad();\n    if (!ctx->y_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->min_val = JUST(composed_attrs.GetAttr<double>(\"min_val\"));\n    ctx->max_val = JUST(composed_attrs.GetAttr<double>(\"max_val\"));\n    if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const HardTanhGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n\n    if (ctx->y_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }\n    if (ctx->grad_requires_grad) {\n      const auto& y = ctx->SavedTensors().at(0);\n      in_grads->at(1) =\n          JUST(functional::HardTanhGrad(y, out_grads.at(0), ctx->min_val, ctx->max_val));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nstruct EluGradGradCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad = false;\n  bool grad_requires_grad = false;\n  double alpha = 1.0;\n};\n\nclass EluGradGrad : public OpExprGradFunction<EluGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(EluGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // x, dy\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)\n\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    ctx->grad_requires_grad = inputs.at(1)->requires_grad();\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->alpha = JUST(composed_attrs.GetAttr<double>(\"alpha\"));\n\n    if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }\n    ctx->SaveTensorForBackward(inputs.at(0));\n    if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const EluGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n    const auto& x = ctx->SavedTensors().at(0);\n\n    if (ctx->x_requires_grad) {\n      const auto& grad = ctx->SavedTensors().at(1);\n      in_grads->at(0) = JUST(\n          functional::Mul(out_grads.at(0), JUST(functional::EluGradGrad(x, grad, ctx->alpha))));\n    }\n    if (ctx->grad_requires_grad) {\n      in_grads->at(1) = JUST(functional::EluGrad(x, out_grads.at(0), ctx->alpha));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nclass CeluGradGrad : public EluGradGrad {\n public:\n  Maybe<void> Apply(const EluGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n    const auto& y = ctx->SavedTensors().at(0);\n\n    if (ctx->x_requires_grad) {\n      const auto& grad = ctx->SavedTensors().at(1);\n      in_grads->at(0) = JUST(\n          functional::CeluGradGrad(y, JUST(functional::Mul(out_grads.at(0), (grad))), ctx->alpha));\n    }\n    if 
(ctx->grad_requires_grad) {\n      in_grads->at(1) = JUST(functional::CeluGrad(y, out_grads.at(0), ctx->alpha));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nstruct PReluGradGradCaptureState : public AutoGradCaptureState {\n  bool grad_requires_grad = false;\n  bool input_requires_grad = false;\n  bool alpha_requires_grad = false;\n  size_t grad_index = 0;\n  size_t input_index = 1;\n  size_t alpha_index = 2;\n};\n\nclass PReluGradGrad : public OpExprGradFunction<PReluGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(PReluGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // dy, x, alpha\n    CHECK_EQ_OR_RETURN(inputs.size(), 3);  // NOLINT(maybe-need-error-msg)\n\n    ctx->grad_requires_grad = inputs.at(0)->requires_grad();   // grad\n    ctx->input_requires_grad = inputs.at(1)->requires_grad();  // input\n    ctx->alpha_requires_grad = inputs.at(2)->requires_grad();  // alpha\n\n    ctx->input_index = ctx->SaveTensorForBackward(inputs.at(1));\n    ctx->alpha_index = ctx->SaveTensorForBackward(inputs.at(2));\n    ctx->grad_index = ctx->SaveTensorForBackward(inputs.at(0));\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const PReluGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(3);\n\n    const auto& input = ctx->SavedTensors().at(ctx->input_index);\n    const auto& alpha = ctx->SavedTensors().at(ctx->alpha_index);\n    const auto& grad = ctx->SavedTensors().at(ctx->grad_index);\n    const auto& grad_for_input = out_grads.at(0);\n    const auto& grad_for_alpha = out_grads.at(1);\n    const auto& condition = JUST(functional::ScalarLogicalLess(input, Scalar(0.0)));\n    const auto& zero_grad = JUST(functional::ZerosLike(alpha));  // alpha can broadcast to input\n\n    if (ctx->grad_requires_grad) {\n      auto input_mul_grad = JUST(functional::Mul(alpha, grad_for_input));\n      auto alpha_mul_grad = JUST(functional::Mul(input, grad_for_alpha));\n      auto result = JUST(functional::Add(input_mul_grad, alpha_mul_grad, /*alpha=*/Scalar(1.0),\n                                         /*inplace*/ false));\n      in_grads->at(0) = JUST(functional::Where(condition, result, grad_for_input));\n    }\n    if (ctx->input_requires_grad) {\n      auto result = JUST(functional::Mul(grad, grad_for_alpha));\n      in_grads->at(1) = JUST(functional::Where(condition, result, zero_grad));\n    }\n    if (ctx->alpha_requires_grad) {\n      auto result = JUST(functional::Mul(grad, grad_for_input));\n      in_grads->at(2) = JUST(functional::Where(condition, result, zero_grad));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  std::shared_ptr<OpExpr> grad_op_;\n};\n\nstruct ThresholdGradGradCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad = false;\n  bool grad_requires_grad = false;\n  double threshold = 0.0;\n};\n\nclass ThresholdGradGrad : public OpExprGradFunction<ThresholdGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ThresholdGradGradCaptureState* ctx, const TensorTuple& inputs,\n         
             const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // x, dy\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    ctx->grad_requires_grad = inputs.at(1)->requires_grad();\n    if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->threshold = JUST(composed_attrs.GetAttr<double>(\"threshold_val\"));\n\n    if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ThresholdGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n    if (ctx->x_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }\n    if (ctx->grad_requires_grad) {\n      const auto& x = ctx->SavedTensors().at(0);\n      in_grads->at(1) = JUST(functional::ThresholdGrad(x, out_grads.at(0), ctx->threshold));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"relu_grad\", ReluGradGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"elu_grad\", EluGradGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"celu_grad\", CeluGradGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"prelu_grad\", PReluGradGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"hardshrink_grad\", HardShrinkGradGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"softshrink_grad\", SoftShrinkGradGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"leaky_relu_grad\", LeakyReluGradGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"hardtanh_grad\", HardTanhGradGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"threshold_grad\", ThresholdGradGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"softplus_grad\", SoftplusGradGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/avg_pool.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct AdaptiveAvgPoolNDGradGradCaptureState : public AutoGradCaptureState {\n  bool input_requires_grad = false;\n  bool grad_requires_grad = false;\n  std::vector<int64_t> pool_output_size;\n  std::string data_format;\n};\n\ntemplate<int ndims>\nclass AdaptiveAvgPoolNdNdGradGrad\n    : public OpExprGradFunction<AdaptiveAvgPoolNDGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(AdaptiveAvgPoolNDGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // dy, x\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n    ctx->grad_requires_grad = inputs[0]->requires_grad();\n    ctx->input_requires_grad = inputs[1]->requires_grad();\n    if (ctx->grad_requires_grad) {\n      const auto& grad_shape = *inputs[0]->shape();\n      if (ndims == 1) {\n        ctx->pool_output_size = {grad_shape[grad_shape.size() - 1]};\n      } else if (ndims == 2) {\n        ctx->pool_output_size = {grad_shape[grad_shape.size() - 2],\n                                 grad_shape[grad_shape.size() - 1]};\n      } else if (ndims == 3) {\n        ctx->pool_output_size = {grad_shape[grad_shape.size() - 3],\n                                 grad_shape[grad_shape.size() - 2],\n                                 grad_shape[grad_shape.size() - 1]};\n      } else {\n        UNIMPLEMENTED_THEN_RETURN();\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const AdaptiveAvgPoolNDGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(2);\n\n    if (ctx->grad_requires_grad) {\n      if (ndims == 1) {\n        (*in_grads)[0] = JUST(\n            functional::AdaptiveAvgPool1D(out_grads[0], ctx->pool_output_size, ctx->data_format));\n      } else if (ndims == 2) {\n        (*in_grads)[0] = JUST(\n            functional::AdaptiveAvgPool2D(out_grads[0], ctx->pool_output_size, ctx->data_format));\n      } else if (ndims == 3) {\n        (*in_grads)[0] = JUST(\n            
functional::AdaptiveAvgPool3D(out_grads[0], ctx->pool_output_size, ctx->data_format));\n      } else {\n        UNIMPLEMENTED_THEN_RETURN();\n      }\n    }\n    if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nstruct AvgPoolGradGradCaptureState : public AutoGradCaptureState {\n  bool input_requires_grad = false;\n  bool grad_requires_grad = false;\n\n  std::string data_format;\n  std::vector<int32_t> padding;\n  std::vector<int32_t> kernel_size;\n  std::vector<int32_t> stride;\n  bool ceil_mode = false;\n  bool count_include_pad = false;\n  int32_t divisor_override = 0;\n};\n\nclass AvgPoolNdGradGrad : public OpExprGradFunction<AvgPoolGradGradCaptureState> {\n public:\n  virtual ~AvgPoolNdGradGrad() = default;\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Capture(AvgPoolGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // dy, x\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n    ctx->grad_requires_grad = inputs[0]->requires_grad();\n    ctx->input_requires_grad = inputs[1]->requires_grad();\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n    ctx->padding = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"padding\"));\n    ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"kernel_size\"));\n    ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"stride\"));\n    ctx->ceil_mode = JUST(composed_attrs.GetAttr<bool>(\"ceil_mode\"));\n    ctx->count_include_pad = JUST(composed_attrs.GetAttr<bool>(\"count_include_pad\"));\n    ctx->divisor_override = JUST(composed_attrs.GetAttr<int32_t>(\"divisor_override\"));\n\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Apply(const AvgPoolGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(2);\n\n    if (ctx->grad_requires_grad) {\n      int32_t ndims = ctx->kernel_size.size();\n      const auto pool_op =\n          (ndims == 1 ? functional::AvgPool1D\n                      : (ndims == 2 ? functional::AvgPool2D\n                                    : (ndims == 3 ? 
functional::AvgPool3D : nullptr)));\n      CHECK_NOTNULL_OR_RETURN(pool_op);  // NOLINT(maybe-need-error-msg)\n      (*in_grads)[0] =\n          JUST(pool_op(out_grads[0], ctx->kernel_size, ctx->stride, ctx->padding, ctx->ceil_mode,\n                       ctx->count_include_pad, ctx->divisor_override, ctx->data_format));\n    }\n    if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); }\n\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"avg_pool_1d_grad\", AvgPoolNdGradGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"avg_pool_2d_grad\", AvgPoolNdGradGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"avg_pool_3d_grad\", AvgPoolNdGradGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"adaptive_avg_pool1d_grad\", AdaptiveAvgPoolNdNdGradGrad<1>);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"adaptive_avg_pool2d_grad\", AdaptiveAvgPoolNdNdGradGrad<2>);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"adaptive_avg_pool3d_grad\", AdaptiveAvgPoolNdNdGradGrad<3>);\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/binary_cross_entropy_loss.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct BinaryCrossEntropyGradGradCaptureState : public AutoGradCaptureState {\n  bool grad_requires_grad = false;\n  bool input_requires_grad = false;\n  bool target_requires_grad = false;\n  bool has_weight = false;\n};\n\nclass BinaryCrossEntropyGradGrad\n    : public OpExprGradFunction<BinaryCrossEntropyGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(BinaryCrossEntropyGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const BinaryCrossEntropyGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n};\n\nMaybe<void> BinaryCrossEntropyGradGrad::Init(const OpExpr& op) { return Maybe<void>::Ok(); }\n\nMaybe<void> BinaryCrossEntropyGradGrad::Capture(BinaryCrossEntropyGradGradCaptureState* ctx,\n                                                const TensorTuple& inputs,\n                                                const TensorTuple& outputs,\n                                                const AttrMap& attrs) const {\n  // dy, input, target[, weight]\n  CHECK_OR_RETURN(inputs.size() >= 3 && inputs.size() <= 4);  // NOLINT(maybe-need-error-msg)\n  ctx->grad_requires_grad = inputs[0]->requires_grad();\n  ctx->input_requires_grad = inputs[1]->requires_grad();\n  ctx->target_requires_grad = inputs[2]->requires_grad();\n  ctx->has_weight = inputs.size() == 4;\n\n  ctx->SaveTensorForBackward(inputs[0]);  // grad\n  ctx->SaveTensorForBackward(inputs[1]);  // input\n  ctx->SaveTensorForBackward(inputs[2]);  // target\n  if (ctx->has_weight) {\n    ctx->SaveTensorForBackward(inputs[3]);  // weight\n  }\n  return Maybe<void>::Ok();\n}\nMaybe<void> BinaryCrossEntropyGradGrad::Apply(const BinaryCrossEntropyGradGradCaptureState* ctx,\n                                              const TensorTuple& out_grads,\n                                              TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(),\n                     3 + ctx->has_weight);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(3 + ctx->has_weight);\n  const auto& grad = ctx->SavedTensors()[0];\n  const auto& input = ctx->SavedTensors()[1];\n  const auto& target = ctx->SavedTensors()[2];\n\n  // dx = grad * [-target/input + (1-target)/(1-input)]\n  // grad_for_grad = out_grad * [-target/input + (1-target)/(1-input)]\n  // grad_for_input = out_grad * grad * [target/(input*input) + (1-target)/((1-input)*(1-input))]\n  //                = out_grad * grad * 
[(input*input-2*input*target+target)/(input*(1-input))^2]\n  // grad_for_target = out_grad * grad * [-1/(input*(1-input))]\n  if (ctx->grad_requires_grad) {\n    const auto& weight = ctx->has_weight ? Optional<one::Tensor>(ctx->SavedTensors()[3]) : NullOpt;\n    (*in_grads)[0] =\n        JUST(functional::BinaryCrossEntropyLossGrad(out_grads[0], input, target, weight));\n  }\n  if (ctx->input_requires_grad) {\n    auto one_sub_input = JUST(functional::ScalarSub(1, input, /*alpha=*/1));\n    auto input_mul_target = JUST(functional::Mul(input, target));\n    auto numerator =\n        JUST(functional::sequence_function(functional::Square)\n                 .then(std::bind(functional::Sub, std::placeholders::_1, input_mul_target,\n                                 /*alpha=*/2, /*inplace=*/false))\n                 .then([&target](const std::shared_ptr<Tensor>& in) {\n                   return functional::Add(in, target, /*alpha=*/1, /*inplace=*/false);\n                 })\n                 .call(input));\n    auto res = JUST(functional::sequence_function(functional::Mul)\n                        .then(functional::Square)\n                        .then(std::bind(functional::Div, numerator, std::placeholders::_1))\n                        .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))\n                        .then(std::bind(functional::Mul, std::placeholders::_1, grad))\n                        .call(input, one_sub_input));\n    (*in_grads)[1] = ctx->has_weight ? JUST(functional::Mul(ctx->SavedTensors()[3], res)) : res;\n  }\n  if (ctx->target_requires_grad) {\n    auto input_sub_one = JUST(functional::ScalarAdd(-1, input, /*alpha=*/1));\n    auto res = JUST(functional::sequence_function(functional::Mul)\n                        .then(std::bind(functional::LogGrad, std::placeholders::_1, out_grads[0]))\n                        .then(std::bind(functional::Mul, std::placeholders::_1, grad))\n                        .call(input, input_sub_one));\n    (*in_grads)[2] = ctx->has_weight ? JUST(functional::Mul(ctx->SavedTensors()[3], res)) : res;\n  }\n\n  return Maybe<void>::Ok();\n}\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"binary_cross_entropy_grad\", BinaryCrossEntropyGradGrad);\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/binary_cross_entropy_with_logits.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct BinaryCrossEntropyWithLogitsGradGradCaptureState : public AutoGradCaptureState {\n  bool grad_requires_grad = false;\n  bool input_requires_grad = false;\n  bool target_requires_grad = false;\n  bool has_weight = false;\n  bool has_pos_weight = false;\n};\n\nclass BinaryCrossEntropyWithLogitsGradGrad\n    : public OpExprGradFunction<BinaryCrossEntropyWithLogitsGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(BinaryCrossEntropyWithLogitsGradGradCaptureState* ctx,\n                      const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const BinaryCrossEntropyWithLogitsGradGradCaptureState* ctx,\n                    const TensorTuple& out_grads, TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> BinaryCrossEntropyWithLogitsGradGrad::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\nMaybe<void> BinaryCrossEntropyWithLogitsGradGrad::Capture(\n    BinaryCrossEntropyWithLogitsGradGradCaptureState* ctx, const TensorTuple& inputs,\n    const TensorTuple& outputs, const AttrMap& attrs) const {\n  // dy, input, target[, weight][, pos_weight]\n  CHECK_OR_RETURN(inputs.size() >= 3 && inputs.size() <= 5);  // NOLINT(maybe-need-error-msg)\n  ctx->grad_requires_grad = inputs[0]->requires_grad();\n  ctx->input_requires_grad = inputs[1]->requires_grad();\n  ctx->target_requires_grad = inputs[2]->requires_grad();\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->has_pos_weight = JUST(composed_attrs.GetAttr<bool>(\"has_pos_weight\"));\n  ctx->has_weight = inputs.size() == 5 || (inputs.size() == 4 && !ctx->has_pos_weight);\n  ctx->SaveTensorForBackward(inputs[0]);  // grad\n  ctx->SaveTensorForBackward(inputs[1]);  // input\n  ctx->SaveTensorForBackward(inputs[2]);  // target\n\n  if (inputs.size() == 4) {\n    ctx->SaveTensorForBackward(inputs[3]);  // weight or pos_weight\n  }\n  if (inputs.size() == 5) {\n    ctx->SaveTensorForBackward(inputs[3]);  // weight\n    ctx->SaveTensorForBackward(inputs[4]);  // pos_weight\n  }\n  return Maybe<void>::Ok();\n}\nMaybe<void> BinaryCrossEntropyWithLogitsGradGrad::Apply(\n    const BinaryCrossEntropyWithLogitsGradGradCaptureState* ctx, const TensorTuple& out_grads,\n    TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(),\n           
          3 + ctx->has_weight + ctx->has_pos_weight);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(3 + ctx->has_weight + ctx->has_pos_weight);\n  const auto& grad = ctx->SavedTensors()[0];\n  const auto& input = ctx->SavedTensors()[1];\n  const auto& target = ctx->SavedTensors()[2];\n  const size_t pos_weight_index = ctx->has_weight ? 4 : 3;\n  const auto& weight = ctx->has_weight ? Optional<one::Tensor>(ctx->SavedTensors()[3]) : NullOpt;\n  const auto& pos_weight =\n      ctx->has_pos_weight ? Optional<one::Tensor>(ctx->SavedTensors()[pos_weight_index]) : NullOpt;\n\n  // dx = grad * weight * (-target*(1-input.sigmoid())*pos_weight + input.sigmoid()*(1-target))\n  // grad_for_input = out_grad * grad * weight * sig * (1-sig) * [pos_weight * target + 1 - target]\n  // grad_for_target = -out_grad * grad * weight * [pos_weight + sig - pos_weight * sig]\n  if (ctx->grad_requires_grad) {\n    (*in_grads)[0] = JUST(functional::BinaryCrossEntropyWithLogitsLossGrad(\n        out_grads[0], input, target, weight, pos_weight));\n  }\n  if (ctx->input_requires_grad) {\n    auto res = JUST(functional::sequence_function(functional::Sigmoid)\n                        .then(std::bind(functional::SigmoidGrad, std::placeholders::_1, grad))\n                        .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))\n                        .call(input));\n    if (ctx->has_pos_weight) {\n      res = JUST(functional::sequence_function(functional::Mul)\n                     .then([](const std::shared_ptr<Tensor>& input) {\n                       return functional::ScalarAdd(1, input, /*alpha=*/Scalar(1));\n                     })\n                     .then(std::bind(functional::Sub, std::placeholders::_1, target, /*alpha=*/1,\n                                     /*inplace=*/false))\n                     .then(std::bind(functional::Mul, std::placeholders::_1, res))\n                     .call(JUST(pos_weight), target));\n    }\n    if (ctx->has_weight) { res = JUST(functional::Mul(res, JUST(weight))); }\n    (*in_grads)[1] = res;\n  }\n  if (ctx->target_requires_grad) {\n    auto res = JUST(functional::sequence_function(functional::Mul)\n                        .then(functional::Negative)\n                        .call(out_grads[0], grad));\n    if (ctx->has_pos_weight) {\n      auto sig = JUST(functional::Sigmoid(input));\n      auto one_sub_sig = JUST(functional::ScalarSub(1, sig, /*alpha=*/1));\n      res = JUST(functional::sequence_function(functional::Mul)\n                     .then([&sig](const std::shared_ptr<Tensor>& input) {\n                       return functional::Add(input, sig, /*alpha=*/Scalar(1), /*inplace=*/false);\n                     })\n                     .then(std::bind(functional::Mul, std::placeholders::_1, res))\n                     .call(one_sub_sig, JUST(pos_weight)));\n    }\n    if (ctx->has_weight) { res = JUST(functional::Mul(res, JUST(weight))); }\n    (*in_grads)[2] = res;\n  }\n\n  return Maybe<void>::Ok();\n}\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"binary_cross_entropy_with_logits_grad\",\n                               BinaryCrossEntropyWithLogitsGradGrad);\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/binary_cross_entropy_with_logits_reduce_mean.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState : public AutoGradCaptureState {\n  bool grad_requires_grad = false;\n  bool input_requires_grad = false;\n  bool target_requires_grad = false;\n\n  size_t grad_index = 0;\n  size_t input_index = 0;\n  size_t target_index = 0;\n};\n\nclass BinaryCrossEntropyWithLogitsReduceMeanGradGrad\n    : public OpExprGradFunction<BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState* ctx,\n                      const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState* ctx,\n                    const TensorTuple& out_grads, TensorTuple* in_grads) const override;\n};\n\nMaybe<void> BinaryCrossEntropyWithLogitsReduceMeanGradGrad::Init(const OpExpr& op) {\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BinaryCrossEntropyWithLogitsReduceMeanGradGrad::Capture(\n    BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState* ctx, const TensorTuple& inputs,\n    const TensorTuple& outputs, const AttrMap& attrs) const {\n  // dy, input, target\n  CHECK_EQ_OR_RETURN(inputs.size(), 3);  // NOLINT(maybe-need-error-msg)\n  ctx->grad_requires_grad = inputs[0]->requires_grad();\n  ctx->input_requires_grad = inputs[1]->requires_grad();\n  ctx->target_requires_grad = inputs[2]->requires_grad();\n\n  if (ctx->input_requires_grad || ctx->target_requires_grad) {\n    ctx->grad_index = ctx->SaveTensorForBackward(inputs[0]);  // grad\n  }\n  if (ctx->input_requires_grad || ctx->grad_requires_grad) {\n    ctx->input_index = ctx->SaveTensorForBackward(inputs[1]);  // input\n  }\n  if (ctx->grad_requires_grad) {\n    ctx->target_index = ctx->SaveTensorForBackward(inputs[2]);  // target\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BinaryCrossEntropyWithLogitsReduceMeanGradGrad::Apply(\n    const BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState* ctx,\n    const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(3);\n\n  // dx = grad * weight * (input.sigmoid() - target)\n  // grad_for_input = out_grad * grad * weight * sig * (1-sig)\n  // grad_for_target = -out_grad * grad * weight\n  if (ctx->grad_requires_grad) {\n    const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index));\n    const auto& target = JUST(VectorAt(ctx->SavedTensors(), ctx->target_index));\n    (*in_grads)[0] 
= JUST(\n        functional::sequence_function(functional::Sigmoid)\n            .then(std::bind(functional::Sub, std::placeholders::_1, target, /*alpha=*/1,\n                            /*inplace=*/false))\n            .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))\n            .then(std::bind(functional::ReduceMean, std::placeholders::_1, std::vector<int32_t>{},\n                            /*keepdim=*/false))\n            .call(input));\n  }\n  if (ctx->input_requires_grad) {\n    const auto& grad = JUST(VectorAt(ctx->SavedTensors(), ctx->grad_index));\n    const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index));\n    const auto& mean_grad = JUST(functional::ScalarMul(1.0 / out_grads[0]->nelement(), grad));\n    (*in_grads)[1] =\n        JUST(functional::sequence_function(functional::Sigmoid)\n                 .then(std::bind(functional::SigmoidGrad, std::placeholders::_1, out_grads[0]))\n                 .then(std::bind(functional::Mul, std::placeholders::_1, mean_grad))\n                 .call(input));\n  }\n  if (ctx->target_requires_grad) {\n    const auto& grad = JUST(VectorAt(ctx->SavedTensors(), ctx->grad_index));\n    const auto& mean_grad = JUST(functional::ScalarMul(1.0 / out_grads[0]->nelement(), grad));\n    (*in_grads)[2] = JUST(functional::sequence_function(functional::Mul)\n                              .then(functional::Negative)\n                              .call(out_grads[0], mean_grad));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"binary_cross_entropy_with_logits_reduce_mean_grad\",\n                               BinaryCrossEntropyWithLogitsReduceMeanGradGrad);\n\n}  // namespace one\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/conv.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ConvDataGradGradCaptureState : public AutoGradCaptureState {\n  bool w_requires_grad = false;\n  bool grad_requires_grad = false;\n\n  size_t w_index = 0;\n  size_t grad_index = 0;\n\n  std::string data_format;\n  std::vector<int32_t> padding_before;\n  std::vector<int32_t> kernel_size;\n  std::vector<int32_t> strides;\n  std::vector<int32_t> dilation_rate;\n  int32_t groups = 0;\n};\n\nclass ConvDataGradGrad : public OpExprGradFunction<ConvDataGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(ConvDataGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const ConvDataGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> ConvDataGradGrad::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ConvDataGradGrad::Capture(ConvDataGradGradCaptureState* ctx, const TensorTuple& inputs,\n                                      const TensorTuple& outputs, const AttrMap& attrs) const {\n  // input: dy, w, x_like, [add to output]\n  // output: dx\n  CHECK_EQ_OR_RETURN(inputs.size(), 3);   // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n  ctx->w_requires_grad = inputs.at(1)->requires_grad();\n  ctx->grad_requires_grad = inputs.at(0)->requires_grad();\n\n  if (ctx->grad_requires_grad) { ctx->w_index = ctx->SaveTensorForBackward(inputs.at(1)); }\n  if (ctx->w_requires_grad) { ctx->grad_index = ctx->SaveTensorForBackward(inputs.at(0)); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n  ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"padding_before\"));\n  ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"kernel_size\"));\n  ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"strides\"));\n  ctx->dilation_rate = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"dilation_rate\"));\n  ctx->groups = JUST(composed_attrs.GetAttr<int32_t>(\"groups\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ConvDataGradGrad::Apply(const 
ConvDataGradGradCaptureState* ctx,\n                                    const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  in_grads->resize(3);\n  size_t num_spatial_dims = ctx->kernel_size.size();\n\n  // first order forward: ConvND\n  // x * w = y ( * => convolution)\n  // first order backward:\n  // x_grad = y_grad * w.rot180           (y.shape * w.shape -> x.shape)  call ConvDataGrad\n  // w_grad = x * y_grad                  (x.shape * y.shape -> w.shape)  call ConvFilterGrad\n\n  // second order forward (first order backward): ConvDataGrad\n  // y_grad * w.rot180 = x_grad\n  // second order forward:\n  // w_grad_grad = out_grads_x * y_grad   (x.shape * y.shape -> w.shape)  call ConvFilterGrad\n  // grad_for_y_grad = out_grads_x * w    (x.shape * w.shape -> y.shape)  call ConvND\n\n  // w_grad_grad\n  if (ctx->w_requires_grad) {\n    const auto& grad = ctx->SavedTensors().at(ctx->grad_index);\n    in_grads->at(1) = JUST(functional::ConvFilterGrad(\n        grad, out_grads.at(0), num_spatial_dims, ctx->kernel_size, ctx->strides,\n        ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));\n  }\n\n  // grad_for_y_grad\n  if (ctx->grad_requires_grad) {\n    const auto& w = ctx->SavedTensors().at(ctx->w_index);\n    const int32_t ndims = ctx->kernel_size.size();\n    const auto conv_op = (ndims == 1 ? functional::Conv1d\n                                     : (ndims == 2 ? functional::Conv2d\n                                                   : (ndims == 3 ? functional::Conv3d : nullptr)));\n    CHECK_NOTNULL_OR_RETURN(conv_op);  // NOLINT(maybe-need-error-msg)\n    in_grads->at(0) =\n        JUST(conv_op(out_grads.at(0), w, Optional<Tensor>(), ctx->strides, ctx->padding_before,\n                     ctx->dilation_rate, ctx->groups, ctx->data_format));\n  }\n\n  return Maybe<void>::Ok();\n}\n\nstruct ConvFilterGradGradCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad = false;\n  bool grad_requires_grad = false;\n\n  size_t x_index = 0;\n  size_t grad_index = 0;\n\n  std::string data_format;\n  std::vector<int32_t> padding_before;\n  std::vector<int32_t> kernel_size;\n  std::vector<int32_t> strides;\n  std::vector<int32_t> dilation_rate;\n  int32_t groups = 0;\n};\n\nclass ConvFilterGradGrad : public OpExprGradFunction<ConvFilterGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(ConvFilterGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const ConvFilterGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> ConvFilterGradGrad::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ConvFilterGradGrad::Capture(ConvFilterGradGradCaptureState* ctx,\n                                        const TensorTuple& inputs, const TensorTuple& outputs,\n                                        const AttrMap& attrs) const {\n  // input: dy, x\n  // output: dw\n  CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n  ctx->x_requires_grad = 
inputs.at(1)->requires_grad();\n  ctx->grad_requires_grad = inputs.at(0)->requires_grad();\n\n  ctx->x_index = ctx->SaveTensorForBackward(inputs.at(1));\n  if (ctx->x_requires_grad) { ctx->grad_index = ctx->SaveTensorForBackward(inputs.at(0)); }\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>(\"data_format\"));\n  ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"padding_before\"));\n  ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"kernel_size\"));\n  ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"strides\"));\n  ctx->dilation_rate = JUST(composed_attrs.GetAttr<std::vector<int32_t>>(\"dilation_rate\"));\n  ctx->groups = JUST(composed_attrs.GetAttr<int32_t>(\"groups\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ConvFilterGradGrad::Apply(const ConvFilterGradGradCaptureState* ctx,\n                                      const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  in_grads->resize(2);\n  size_t num_spatial_dims = ctx->kernel_size.size();\n\n  // first order forward: ConvND\n  // x * w = y ( * => convolution)\n  // first order backward:\n  // x_grad = y_grad * w.rot180           (y.shape * w.shape -> x.shape)  call ConvDataGrad\n  // w_grad = x * y_grad                  (x.shape * y.shape -> w.shape)  call ConvFilterGrad\n\n  // second order forward (first order backward): ConvFilterGrad\n  // x * y_grad = w_grad\n  // second order backward:\n  // x_grad_grad = out_grads_w * y_grad.rot180    (y.shape * w.shape -> x.shape)  call ConvDataGrad\n  // grad_for_y_grad = x * out_grads_w            (x.shape * w.shape -> y.shape)  call ConvND\n\n  // x_grad_grad\n  if (ctx->x_requires_grad) {\n    const auto& grad = ctx->SavedTensors().at(ctx->grad_index);\n    const auto& x = ctx->SavedTensors().at(ctx->x_index);\n    in_grads->at(1) = JUST(functional::ConvDataGrad(\n        grad, out_grads.at(0), JUST(x->detach()), num_spatial_dims, ctx->kernel_size, ctx->strides,\n        ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));\n  }\n\n  // grad_for_y_grad\n  if (ctx->grad_requires_grad) {\n    const auto& x = ctx->SavedTensors().at(ctx->x_index);\n    const int32_t ndims = ctx->kernel_size.size();\n    const auto conv_op = (ndims == 1 ? functional::Conv1d\n                                     : (ndims == 2 ? functional::Conv2d\n                                                   : (ndims == 3 ? functional::Conv3d : nullptr)));\n    CHECK_NOTNULL_OR_RETURN(conv_op);  // NOLINT(maybe-need-error-msg)\n    in_grads->at(0) =\n        JUST(conv_op(x, out_grads.at(0), Optional<Tensor>(), ctx->strides, ctx->padding_before,\n                     ctx->dilation_rate, ctx->groups, ctx->data_format));\n  }\n\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"conv_data_grad\", ConvDataGradGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"conv_filter_grad\", ConvFilterGradGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/div.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <functional>\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct DivGradGradCaptureState : public AutoGradCaptureState {\n  bool y_requires_grad = false;\n  bool z_requires_grad = false;\n  bool grad_requires_grad = false;\n\n  size_t y_index = 0;\n  size_t z_index = 1;\n  size_t grad_index = 2;\n};\n\nclass DivGradGrad : public OpExprGradFunction<DivGradGradCaptureState> {\n  // div_grad    = -x/(y*y)*dz = -z/y*dz\n  // div_grad_y  = out_grad * z*dz/(y*y)\n  // div_grad_z  = out_grad * -dz/y\n  // div_grad_dz = out_grad * -z/y\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(DivGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // dz, z, y\n    CHECK_EQ_OR_RETURN(inputs.size(), 3);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->grad_requires_grad = inputs.at(0)->requires_grad();\n    ctx->z_requires_grad = inputs.at(1)->requires_grad();\n    ctx->y_requires_grad = inputs.at(2)->requires_grad();\n\n    ctx->y_index = ctx->SaveTensorForBackward(inputs.at(2));\n    if (ctx->y_requires_grad || ctx->grad_requires_grad) {\n      ctx->z_index = ctx->SaveTensorForBackward(inputs.at(1));\n    }\n    if (ctx->y_requires_grad || ctx->z_requires_grad) {\n      ctx->grad_index = ctx->SaveTensorForBackward(inputs.at(0));\n    }\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const DivGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(3);\n    const auto& y = ctx->SavedTensors().at(ctx->y_index);\n\n    if (ctx->grad_requires_grad) {\n      const auto& z = ctx->SavedTensors().at(ctx->z_index);\n      in_grads->at(0) = JUST(functional::sequence_function(functional::Mul)\n                                 .then(functional::Negative)\n                                 .then(std::bind(functional::Div, std::placeholders::_1, y))\n                                 .call(out_grads.at(0), z));\n    }\n    if (ctx->z_requires_grad) {\n      const auto& grad = ctx->SavedTensors().at(ctx->grad_index);\n      in_grads->at(1) = JUST(functional::sequence_function(functional::Mul)\n                                 .then(functional::Negative)\n                                 .then(std::bind(functional::Div, std::placeholders::_1, y))\n                                 .call(out_grads.at(0), grad));\n    }\n    if (ctx->y_requires_grad) {\n      const auto& z = 
ctx->SavedTensors().at(ctx->z_index);\n      const auto& grad = ctx->SavedTensors().at(ctx->grad_index);\n      in_grads->at(2) = JUST(\n          functional::sequence_function(functional::Mul)\n              .then(std::bind(functional::BroadcastReduceSumLike, std::placeholders::_1, y))\n              .then(std::bind(functional::Mul, std::placeholders::_1, out_grads.at(0)))\n              .then(std::bind(functional::Div, std::placeholders::_1, JUST(functional::Square(y))))\n              .call(z, grad));\n    }\n\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"broadcast_div_grad\", DivGradGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/kl_div_loss.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct KLDivLossGradGradCaptureState : public AutoGradCaptureState {\n  bool grad_requires_grad = false;\n  bool input_requires_grad = false;\n  bool target_requires_grad = false;\n  bool log_target = false;\n\n  size_t input_index = 0;\n  size_t target_index = 0;\n};\n\nclass KLDivLossGradGrad : public OpExprGradFunction<KLDivLossGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(KLDivLossGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const KLDivLossGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> KLDivLossGradGrad::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\nMaybe<void> KLDivLossGradGrad::Capture(KLDivLossGradGradCaptureState* ctx,\n                                       const TensorTuple& inputs, const TensorTuple& outputs,\n                                       const AttrMap& attrs) const {\n  // grad, input, target\n  CHECK_EQ_OR_RETURN(inputs.size(), 3);  // NOLINT(maybe-need-error-msg)\n  ctx->grad_requires_grad = inputs[0]->requires_grad();\n  ctx->input_requires_grad = inputs[1]->requires_grad();\n  ctx->target_requires_grad = inputs[2]->requires_grad();\n\n  ComposedAttrMap composed_attrs(attrs, base_attrs_);\n  ctx->log_target = JUST(composed_attrs.GetAttr<bool>(\"log_target\"));\n\n  ctx->input_index = ctx->SaveTensorForBackward(inputs[1]);   // input\n  ctx->target_index = ctx->SaveTensorForBackward(inputs[2]);  // target\n\n  return Maybe<void>::Ok();\n}\nMaybe<void> KLDivLossGradGrad::Apply(const KLDivLossGradGradCaptureState* ctx,\n                                     const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(3);\n\n  if (ctx->grad_requires_grad) {\n    const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index));\n    const auto& target = JUST(VectorAt(ctx->SavedTensors(), ctx->target_index));\n    (*in_grads)[0] = JUST(functional::KLDivLossGrad(out_grads[0], input, target, ctx->log_target));\n  }\n  if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); }\n  if (ctx->target_requires_grad) { (*in_grads)[2] = 
JUST(functional::ZerosLike(out_grads[0])); }\n  //// In PyTorch 1.13 the higher-order gradient w.r.t. target was fixed, which causes a difference here\n  // if (ctx->target_requires_grad) {\n  //   if (ctx->log_target) (*in_grads)[2] =\n  //   JUST(functional::Mul(JUST(functional::Negative(JUST(functional::Exp(target)))),\n  //   out_grads[0])); else (*in_grads)[2] = JUST(functional::Negative(out_grads[0]));\n  // }\n\n  return Maybe<void>::Ok();\n}\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"kl_div_loss_grad\", KLDivLossGradGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/log_softmax.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct LogSoftmaxGradGradCaptureState : public AutoGradCaptureState {\n  bool y_requires_grad = false;\n  bool dy_requires_grad = false;\n};\n\nclass LogSoftmaxGradGrad : public OpExprGradFunction<LogSoftmaxGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(LogSoftmaxGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const LogSoftmaxGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n};\n\nMaybe<void> LogSoftmaxGradGrad::Init(const OpExpr& op) { return Maybe<void>::Ok(); }\n\nMaybe<void> LogSoftmaxGradGrad::Capture(LogSoftmaxGradGradCaptureState* ctx,\n                                        const TensorTuple& inputs, const TensorTuple& outputs,\n                                        const AttrMap& attrs) const {\n  // y, dy\n  CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n  ctx->y_requires_grad = inputs[0]->requires_grad();\n  ctx->dy_requires_grad = inputs[1]->requires_grad();\n\n  ctx->SaveTensorForBackward(inputs[0]);\n  if (ctx->y_requires_grad) ctx->SaveTensorForBackward(inputs[1]);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LogSoftmaxGradGrad::Apply(const LogSoftmaxGradGradCaptureState* ctx,\n                                      const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  in_grads->resize(2);\n  const auto& y = ctx->SavedTensors()[0];\n  const std::vector<int32_t> reduce_axis{static_cast<int32_t>(y->ndim() - 1)};\n\n  if (ctx->y_requires_grad) {\n    const auto& dy = ctx->SavedTensors()[1];\n    in_grads->at(0) =\n        JUST(functional::sequence_function(functional::ReduceSum)\n                 .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))\n                 .then(std::bind(functional::Mul, std::placeholders::_1, JUST(functional::Exp(y))))\n                 .then(functional::Negative)\n                 .call(dy, reduce_axis, true, NullOpt));\n  }\n  if (ctx->dy_requires_grad) {\n    in_grads->at(1) =\n        JUST(functional::sequence_function(functional::Exp)\n                 .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))\n                 .then(std::bind(functional::ReduceSum, std::placeholders::_1, reduce_axis,\n                                 /*keepdim=*/true, NullOpt))\n                 .then(std::bind(functional::Sub, 
out_grads[0], std::placeholders::_1, /*alpha=*/1,\n                                 /*inplace=*/false))\n                 .call(y));\n  }\n\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"log_softmax_grad\", LogSoftmaxGradGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/math_unary_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct UnaryMathGradGradState : public AutoGradCaptureState {\n  bool input_requires_grad = false;\n  bool grad_requires_grad = false;\n};\n\ntypedef Maybe<one::Tensor> (*UnaryBwFunc)(const std::shared_ptr<one::Tensor>&,\n                                          const std::shared_ptr<one::Tensor>&);\n\ntemplate<UnaryBwFunc BwFunc, UnaryBwFunc BwBwFunc>\nclass UnaryMathGradGrad : public OpExprGradFunction<UnaryMathGradGradState> {\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n  Maybe<void> Capture(UnaryMathGradGradState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->input_requires_grad = inputs[0]->requires_grad();\n    ctx->grad_requires_grad = inputs[1]->requires_grad();\n    ctx->SaveTensorForBackward(inputs[0]);\n    if (ctx->input_requires_grad) { ctx->SaveTensorForBackward(inputs[1]); }\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Apply(const UnaryMathGradGradState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n    const auto& input = ctx->SavedTensors()[0];\n    if (ctx->input_requires_grad) {\n      const auto& grad = ctx->SavedTensors()[1];\n      (*in_grads)[0] = JUST(functional::Mul(out_grads[0], JUST(BwBwFunc(input, grad))));\n    }\n    if (ctx->grad_requires_grad) { (*in_grads)[1] = JUST(BwFunc(input, out_grads[0])); }\n    return Maybe<void>::Ok();\n  }\n};\n\ntemplate<UnaryBwFunc BwFunc>\nclass UnaryMathGradGradWithZeroDDX : public OpExprGradFunction<UnaryMathGradGradState> {\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n  Maybe<void> Capture(UnaryMathGradGradState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->input_requires_grad = inputs[0]->requires_grad();\n    ctx->grad_requires_grad = inputs[1]->requires_grad();\n    ctx->SaveTensorForBackward(inputs[0]);\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Apply(const UnaryMathGradGradState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    
in_grads->resize(2);\n    const auto& input = ctx->SavedTensors()[0];\n    if (ctx->input_requires_grad) { (*in_grads)[0] = JUST(functional::ZerosLike(input)); }\n    if (ctx->grad_requires_grad) { (*in_grads)[1] = JUST(BwFunc(input, out_grads[0])); }\n    return Maybe<void>::Ok();\n  }\n};\n\n// TODO: Lgamma, first order backward unimplemented\n#define MATH_UNARY_ELEMENTWISE_GRAD_GRAD_DY_X_FUNC_SEQ            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sin_grad\", Sin)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"cos_grad\", Cos)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"tan_grad\", Tan)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sinh_grad\", Sinh)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"cosh_grad\", Cosh)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"asin_grad\", Asin)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"acos_grad\", Acos)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"atan_grad\", Atan)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"asinh_grad\", Asinh)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"acosh_grad\", Acosh)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"atanh_grad\", Atanh)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"erf_grad\", Erf)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"erfc_grad\", Erfc)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"exp_grad\", Exp)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"exp2_grad\", Exp2)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"expm1_grad\", Expm1)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log_grad\", Log)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log_sigmoid_grad\", LogSigmoid)            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log2_grad\", Log2)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log1p_grad\", Log1p)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"reciprocal_grad\", Reciprocal)             \\\n  OF_PP_MAKE_TUPLE_SEQ(\"reciprocal_no_nan_grad\", ReciprocalNoNan) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"rsqrt_grad\", Rsqrt)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sqrt_grad\", Sqrt)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"square_grad\", Square)\n\n#define MATH_UNARY_ELEMENTWISE_GRAD_GRAD_DY_Y_FUNC_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sigmoid_grad\", Sigmoid)        \\\n  OF_PP_MAKE_TUPLE_SEQ(\"tanh_grad\", Tanh)\n\n#define MATH_UNARY_ELEMENTWISE_GRAD_GRAD_ZERO_DDX_FUNC_SEQ OF_PP_MAKE_TUPLE_SEQ(\"abs_grad\", Abs)\n\n#define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_CLASS(op_type_name, op_cls)           \\\n  class op_cls##GradGradCls final                                                            \\\n      : public UnaryMathGradGrad<functional::op_cls##Grad, functional::op_cls##GradGrad> {}; \\\n  REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##GradGradCls);\n\nOF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_CLASS,\n                     MATH_UNARY_ELEMENTWISE_GRAD_GRAD_DY_X_FUNC_SEQ);\n\nOF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_CLASS,\n                     MATH_UNARY_ELEMENTWISE_GRAD_GRAD_DY_Y_FUNC_SEQ);\n\n#define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_ZERO_DDX_CLASS(op_type_name, op_cls) \\\n  class op_cls##GradGradCls final                                                           \\\n      : public UnaryMathGradGradWithZeroDDX<functional::op_cls##Grad> {};                   \\\n  REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, 
op_cls##GradGradCls);\nOF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_ZERO_DDX_CLASS,\n                     MATH_UNARY_ELEMENTWISE_GRAD_GRAD_ZERO_DDX_FUNC_SEQ);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/matmul.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct BroadcastMatmulGradBGradCaptureState : public AutoGradCaptureState {\n  bool a_requires_grad = false;\n  bool b_requires_grad = false;\n  size_t a_index = 0;\n  size_t b_index = 1;\n  double alpha = 1.0;\n};\n\nclass BroadcastMatmulGradBGrad : public OpExprGradFunction<BroadcastMatmulGradBGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  };\n  Maybe<void> Capture(BroadcastMatmulGradBGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n    ctx->a_requires_grad = inputs.at(0)->requires_grad();\n    ctx->b_requires_grad = inputs.at(1)->requires_grad();\n    if (ctx->a_requires_grad) { ctx->b_index = ctx->SaveTensorForBackward(inputs.at(1)); }\n    if (ctx->b_requires_grad) { ctx->a_index = ctx->SaveTensorForBackward(inputs.at(0)); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->alpha = JUST(composed_attrs.GetAttr<double>(\"alpha\"));\n\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Apply(const BroadcastMatmulGradBGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(2);\n\n    // for matmul: input_a[dims..., m, k] * input_b[k, n] -> [dims..., m, n]\n    // if forward: BroadcastMatmulGradB(input_a, JUST(VectorAt(out_grads, 0)), ctx->alpha))\n    //       then: a.shape = [dims..., m, k], b.shape = [dims..., m, n], grad.shape = [k, n]\n    // if forward: BroadcastMatmulGradB(JUST(VectorAt(out_grads, 0)), input_a, ctx->alpha))\n    //       then: a.shape = [dims..., m, n], b.shape = [dims..., m, k], grad.shape = [n, k]\n    if (ctx->a_requires_grad) {\n      const auto& b = ctx->SavedTensors()[ctx->b_index];\n      in_grads->at(0) = JUST(functional::MatMul(b, out_grads.at(0), false, true, ctx->alpha));\n    }\n    if (ctx->b_requires_grad) {\n      const auto& a = ctx->SavedTensors()[ctx->a_index];\n      in_grads->at(1) = JUST(functional::MatMul(a, out_grads.at(0), false, false, ctx->alpha));\n    }\n\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap 
base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"broadcast_matmul_grad_b\", BroadcastMatmulGradBGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/max_pool.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct MaxPoolGradGradCaptureState : public AutoGradCaptureState {\n  bool grad_requires_grad = false;\n  bool input_requires_grad = false;\n};\n\ntemplate<int ndims>\nclass MaxPoolNdGradGrad : public OpExprGradFunction<MaxPoolGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n\n  Maybe<void> Capture(MaxPoolGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // dy, x, indice\n    CHECK_EQ_OR_RETURN(inputs.size(), 3);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n    ctx->grad_requires_grad = inputs[0]->requires_grad();\n    ctx->input_requires_grad = inputs[1]->requires_grad();\n    if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs[2]); }\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const MaxPoolGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(3);\n\n    if (ctx->grad_requires_grad) {\n      const auto& indices = JUST(VectorAt(ctx->SavedTensors(), 0));\n      (*in_grads)[0] = JUST(functional::MaxPoolNdGradGrad(out_grads[0], indices, ndims));\n    }\n    if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"max_pool_1d_grad\", MaxPoolNdGradGrad<1>);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"max_pool_2d_grad\", MaxPoolNdGradGrad<2>);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"max_pool_3d_grad\", MaxPoolNdGradGrad<3>);\n// REGISTER_OP_EXPR_GRAD_FUNCTION(\"adaptive_max_pool1d_grad\", MaxPoolNdGradGrad<1>);\n// REGISTER_OP_EXPR_GRAD_FUNCTION(\"adaptive_max_pool2d_grad\", MaxPoolNdGradGrad<2>);\n// REGISTER_OP_EXPR_GRAD_FUNCTION(\"adaptive_max_pool3d_grad\", MaxPoolNdGradGrad<3>);\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/nll_loss.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\n\nnamespace one {\n\nstruct NLLCaptureState : public AutoGradCaptureState {\n  bool input_requires_grad = false;\n  bool grad_requires_grad = false;\n  bool has_weight = false;\n  int64_t ignore_index = -100;\n};\n\nclass NLLLossGradGrad : public OpExprGradFunction<NLLCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(NLLCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                      const AttrMap& attrs) const override;\n  Maybe<void> Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n\n private:\n  AttrMap base_attrs_;\n};\n\nMaybe<void> NLLLossGradGrad::Init(const OpExpr& op) {\n  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NLLLossGradGrad::Capture(NLLCaptureState* ctx, const TensorTuple& inputs,\n                                     const TensorTuple& outputs, const AttrMap& attrs) const {\n  // dy, input, target[, weight]\n  CHECK_OR_RETURN(inputs.size() >= 3 && inputs.size() <= 4);  // NOLINT(maybe-need-error-msg)\n  ctx->grad_requires_grad = inputs[0]->requires_grad();\n  ctx->input_requires_grad = inputs[1]->requires_grad();\n  ctx->has_weight = inputs.size() == 4;\n\n  if (ctx->grad_requires_grad) {\n    ctx->SaveTensorForBackward(inputs[2]);\n    if (ctx->has_weight) { ctx->SaveTensorForBackward(inputs[3]); }  // weight\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->ignore_index = JUST(composed_attrs.GetAttr<int64_t>(\"ignore_index\"));\n  }\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NLLLossGradGrad::Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads,\n                                   TensorTuple* in_grads) const {\n  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n  in_grads->resize(3 + ctx->has_weight);\n\n  if (ctx->grad_requires_grad) {\n    const auto& target = JUST(VectorAt(ctx->SavedTensors(), 0));\n    if (ctx->has_weight) {\n      auto weight = JUST(VectorAt(ctx->SavedTensors(), 1));\n      (*in_grads)[0] =\n          JUST(functional::NLLLoss(out_grads[0], target, weight, ctx->ignore_index, \"none\"));\n    } else {\n      (*in_grads)[0] =\n          JUST(functional::NLLLoss(out_grads[0], target, NullOpt, ctx->ignore_index, \"none\"));\n    }\n  }\n  if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); }\n\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"nll_grad\", NLLLossGradGrad);\n\n}  // namespace 
one\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/pow.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <functional>\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\nstruct PowXGradGradCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad = false;\n  bool y_requires_grad = false;\n  bool dz_requires_grad = false;\n\n  size_t x_index = 0;\n  size_t y_index = 1;\n  size_t dz_index = 2;\n};\n\nclass PowXGradGrad : public OpExprGradFunction<PowXGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n  Maybe<void> Capture(PowXGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // x, y, dz\n    CHECK_EQ_OR_RETURN(inputs.size(), 3);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    ctx->y_requires_grad = inputs.at(1)->requires_grad();\n    ctx->dz_requires_grad = inputs.at(2)->requires_grad();\n\n    ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));\n    ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));\n    if (ctx->x_requires_grad || ctx->y_requires_grad) {\n      ctx->dz_index = ctx->SaveTensorForBackward(inputs.at(2));\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const PowXGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(3);\n    const auto& x = ctx->SavedTensors().at(ctx->x_index);\n    const auto& y = ctx->SavedTensors().at(ctx->y_index);\n\n    // dx = y * x^(y-1) * dz\n    // grad_for_x  = out_grads * dz * y * [x^(y-1)]'\n    // grad_for_y  = out_grads * dz * [x^(y-1) * (1 + y * ln(x))]\n    // grad_for_dz = out_grads * y * x^(y-1)\n\n    if (ctx->x_requires_grad || ctx->y_requires_grad) {\n      const auto& dz = ctx->SavedTensors().at(ctx->dz_index);\n      const auto& y_sub_one = JUST(functional::ScalarSub(y, 1, /*alpha=*/1, /*inplace=*/false));\n      if (ctx->x_requires_grad) {\n        in_grads->at(0) = JUST(functional::sequence_function(functional::PowXGrad)\n                                   .then(std::bind(functional::Mul, std::placeholders::_1, y))\n                                   .then(std::bind(functional::Mul, std::placeholders::_1, dz))\n                                   .call(x, y_sub_one, out_grads.at(0)));\n      }\n      if (ctx->y_requires_grad) {\n        in_grads->at(1) =\n            JUST(functional::sequence_function(functional::Log)\n                     .then(std::bind(functional::Mul, std::placeholders::_1, y))\n                     
.then([](const std::shared_ptr<Tensor>& input) {\n                       return functional::ScalarAdd(1, input, /*alpha=*/1);\n                     })\n                     .then(std::bind(functional::Mul, std::placeholders::_1,\n                                     JUST(functional::Pow(x, y_sub_one))))\n                     .then(std::bind(functional::Mul, std::placeholders::_1, dz))\n                     .then(std::bind(functional::Mul, std::placeholders::_1, out_grads.at(0)))\n                     .call(x));\n      }\n    }\n    if (ctx->dz_requires_grad) {\n      in_grads->at(2) = JUST(functional::PowXGrad(x, y, out_grads.at(0)));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nstruct PowYGradGradCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad = false;\n  bool y_requires_grad = false;\n  bool dz_requires_grad = false;\n\n  size_t x_index = 0;\n  size_t y_index = 1;\n  size_t dz_index = 2;\n  size_t dy_index = 3;\n};\n\nclass PowYGradGrad : public OpExprGradFunction<PowYGradGradCaptureState> {\n public:\n  // dy = x^y*ln(x)*dz = z*ln(x)*dz\n  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }\n  Maybe<void> Capture(PowYGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // x, y, dz\n    CHECK_EQ_OR_RETURN(inputs.size(), 3);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    ctx->y_requires_grad = inputs.at(1)->requires_grad();\n    ctx->dz_requires_grad = inputs.at(2)->requires_grad();\n\n    ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));\n    // y is used by both the grad_for_x and grad_for_dz branches in Apply, so it\n    // must also be saved when only dz requires grad\n    if (ctx->x_requires_grad || ctx->dz_requires_grad) {\n      ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));\n    }\n    if (ctx->x_requires_grad) { ctx->dz_index = ctx->SaveTensorForBackward(inputs.at(2)); }\n    if (ctx->y_requires_grad) { ctx->dy_index = ctx->SaveTensorForBackward(outputs.at(0)); }\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const PowYGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(3);\n    const auto& x = ctx->SavedTensors().at(ctx->x_index);\n\n    // dy = x^y * ln(x) * dz = z * ln(x) * dz\n    // grad_for_x  = out_grads * dz * [x^(y-1) * (1 + y * ln(x))]\n    // grad_for_y  = out_grads * dy' = out_grads * dy * ln(x)\n    // grad_for_dz = out_grads * x^y * ln(x)\n\n    if (ctx->x_requires_grad) {\n      const auto& y = ctx->SavedTensors().at(ctx->y_index);\n      const auto& dz = ctx->SavedTensors().at(ctx->dz_index);\n      const auto& y_sub_one = JUST(functional::ScalarSub(y, 1, /*alpha=*/1, /*inplace=*/false));\n      in_grads->at(0) =\n          JUST(functional::sequence_function(functional::Log)\n                   .then(std::bind(functional::Mul, std::placeholders::_1, y))\n                   .then([](const std::shared_ptr<Tensor>& input) {\n                     return functional::ScalarAdd(1, input, /*alpha=*/1);\n                   })\n                   .then(std::bind(functional::Mul, std::placeholders::_1,\n                                   JUST(functional::Pow(x, y_sub_one))))\n                   .then(std::bind(functional::Mul, std::placeholders::_1, dz))\n                   .then(std::bind(functional::Mul, std::placeholders::_1, out_grads.at(0)))\n                   .call(x));\n    }\n\n    if (ctx->y_requires_grad) {\n      
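// Derivation sketch: dy = x^y * ln(x) * dz, so d(dy)/d(y) = x^y * ln(x)^2 * dz = dy * ln(x).\n      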
const auto& dy = ctx->SavedTensors().at(ctx->dy_index);\n      in_grads->at(1) =\n          JUST(functional::sequence_function(functional::Log)\n                   .then(std::bind(functional::Mul, std::placeholders::_1, dy))\n                   .then(std::bind(functional::Mul, std::placeholders::_1, out_grads.at(0)))\n                   .call(x));\n    }\n\n    if (ctx->dz_requires_grad) {\n      const auto& y = ctx->SavedTensors().at(ctx->y_index);\n      in_grads->at(2) = JUST(functional::PowYGrad(x, y, out_grads.at(0)));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"pow_x_grad\", PowXGradGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"pow_y_grad\", PowYGradGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/scalar_pow.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct ScalarPowGradGradCaptureState : public AutoGradCaptureState {\n  bool x_requires_grad = false;\n  bool grad_requires_grad = false;\n  Scalar operand;\n};\n\nclass ScalarPowGradGrad : public OpExprGradFunction<ScalarPowGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ScalarPowGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    ctx->grad_requires_grad = inputs.at(1)->requires_grad();\n    if (!(ctx->x_requires_grad || ctx->grad_requires_grad)) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    bool has_float_operand = JUST(composed_attrs.GetAttr<bool>(\"has_float_operand\"));\n    if (has_float_operand) {\n      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<double>(\"float_operand\")));\n    } else {\n      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<int64_t>(\"int_operand\")));\n    }\n    ctx->SaveTensorForBackward(inputs.at(0));\n    if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ScalarPowGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    const auto& x = ctx->SavedTensors().at(0);\n    in_grads->resize(2);\n\n    // z = x^a, dx = a * x^(a-1) * dz\n    // grad_for_x  = out_grad * a * dz * [x^(a-1)]'\n    // grad_for_dz = out_grad * [x^a]'\n\n    if (ctx->x_requires_grad) {\n      const auto& grad = ctx->SavedTensors().at(1);\n      const auto operand_sub_one = ctx->operand - Scalar(1);\n      in_grads->at(0) = JUST(\n          functional::sequence_function(functional::Mul)\n              .then(std::bind(functional::ScalarPowGrad, x, std::placeholders::_1, operand_sub_one))\n              .then([&ctx](const std::shared_ptr<Tensor>& input) {\n                return functional::ScalarMul(ctx->operand, input);\n              })\n              .call(grad, out_grads.at(0)));\n    }\n    if (ctx->grad_requires_grad) {\n      in_grads->at(1) = 
JUST(functional::ScalarPowGrad(x, out_grads.at(0), ctx->operand));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nclass ScalarReversePowGradGrad : public OpExprGradFunction<ScalarPowGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(ScalarPowGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n\n    ctx->x_requires_grad = inputs.at(0)->requires_grad();\n    ctx->grad_requires_grad = inputs.at(1)->requires_grad();\n    if (!(ctx->x_requires_grad || ctx->grad_requires_grad)) { return Maybe<void>::Ok(); }\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    bool has_float_operand = JUST(composed_attrs.GetAttr<bool>(\"has_float_operand\"));\n    if (has_float_operand) {\n      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<double>(\"float_operand\")));\n    } else {\n      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<int64_t>(\"int_operand\")));\n    }\n    ctx->SaveTensorForBackward(inputs.at(0));\n    if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(outputs.at(0)); }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const ScalarPowGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    const auto& x = ctx->SavedTensors().at(0);\n    in_grads->resize(2);\n\n    // z = a^x, dx = a^x * ln(a) * dz\n    // grad_for_x  = out_grad * dz * a^x * ln(a)^2\n    // grad_for_dz = out_grad * [a^x]'\n\n    if (ctx->x_requires_grad) {\n      const auto& dx = ctx->SavedTensors().at(1);\n      const auto log_operand = std::log(ctx->operand.As<double>());\n      in_grads->at(0) = JUST(functional::sequence_function(functional::Mul)\n                                 .then([&log_operand](const std::shared_ptr<Tensor>& input) {\n                                   return functional::ScalarMul(log_operand, input);\n                                 })\n                                 .call(dx, out_grads.at(0)));\n    }\n    if (ctx->grad_requires_grad) {\n      in_grads->at(1) = JUST(functional::ScalarReversePowGrad(x, out_grads.at(0), ctx->operand));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scalar_pow_grad\", ScalarPowGradGrad);\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"scalar_reverse_pow_grad\", ScalarReversePowGradGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/slice.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct SliceGradGradCaptureState : public AutoGradCaptureState {\n  std::vector<int64_t> start;\n  std::vector<int64_t> stop;\n  std::vector<int64_t> step;\n};\n\nclass SliceGradGrad : public OpExprGradFunction<SliceGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(SliceGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)\n    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->start = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"start\"));\n    ctx->stop = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"stop\"));\n    ctx->step = JUST(composed_attrs.GetAttr<std::vector<int64_t>>(\"step\"));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const SliceGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    in_grads->resize(1);\n    in_grads->at(0) = JUST(functional::Slice(out_grads.at(0), ctx->start, ctx->stop, ctx->step,\n                                             /*enable_view_slice=*/false));\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"slice_grad\", SliceGradGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/smooth_l1_loss.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct SmoothL1LossGradGradCaptureState : public AutoGradCaptureState {\n  bool grad_requires_grad = false;\n  bool input_requires_grad = false;\n  bool target_requires_grad = false;\n  size_t grad_index = 0;\n  size_t input_index = 0;\n  size_t target_index = 0;\n  float beta = 0.0;\n};\n\nclass SmoothL1LossGradGrad : public OpExprGradFunction<SmoothL1LossGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override {\n    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);\n    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)\n    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Capture(SmoothL1LossGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override {\n    // grad, input, target\n    CHECK_EQ_OR_RETURN(inputs.size(), 3);  // NOLINT(maybe-need-error-msg)\n\n    ctx->grad_requires_grad = inputs[0]->requires_grad();\n    ctx->input_requires_grad = inputs[1]->requires_grad();\n    ctx->target_requires_grad = inputs[2]->requires_grad();\n\n    if (ctx->input_requires_grad || ctx->target_requires_grad) {\n      ctx->grad_index = ctx->SaveTensorForBackward(inputs[0]);\n    }\n    ctx->input_index = ctx->SaveTensorForBackward(inputs[1]);\n    ctx->target_index = ctx->SaveTensorForBackward(inputs[2]);\n\n    ComposedAttrMap composed_attrs(attrs, base_attrs_);\n    ctx->beta = JUST(composed_attrs.GetAttr<float>(\"beta\"));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(const SmoothL1LossGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override {\n    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n    in_grads->resize(3);\n    const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index));\n    const auto& target = JUST(VectorAt(ctx->SavedTensors(), ctx->target_index));\n\n    if (ctx->grad_requires_grad) {\n      (*in_grads)[0] = JUST(functional::SmoothL1LossGrad(out_grads[0], input, target, ctx->beta));\n    }\n    if (ctx->input_requires_grad || ctx->target_requires_grad) {\n      const auto& grad = JUST(VectorAt(ctx->SavedTensors(), ctx->grad_index));\n      auto condition = JUST(functional::sequence_function(functional::Sub)\n                                .then(functional::Abs)\n                                .then([&ctx](const std::shared_ptr<Tensor>& input) {\n                                  return functional::ScalarLogicalLess(input, ctx->beta);\n                                })\n                        
        .call(input, target, /*alpha=*/1, /*inplace=*/false));\n      auto out = JUST(functional::sequence_function(functional::Mul)\n                          .then(std::bind(functional::Mul, std::placeholders::_1, condition))\n                          .then([&ctx](const std::shared_ptr<Tensor>& input) {\n                            double inv_beta = ctx->beta == 0.0 ? 0.0 : 1.0 / ctx->beta;\n                            return functional::ScalarMul(inv_beta, input);\n                          })\n                          .call(out_grads[0], grad));\n      if (ctx->input_requires_grad) { (*in_grads)[1] = out; }\n      if (ctx->target_requires_grad) { (*in_grads)[2] = JUST(functional::Negative(out)); }\n    }\n\n    return Maybe<void>::Ok();\n  }\n\n private:\n  AttrMap base_attrs_;\n};\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"smooth_l1_loss_grad\", SmoothL1LossGradGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/autograd/higher_order_gradient_funcs/softmax.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct SoftmaxGradGradCaptureState : public AutoGradCaptureState {\n  bool y_requires_grad = false;\n  bool dy_requires_grad = false;\n};\n\nclass SoftmaxGradGrad : public OpExprGradFunction<SoftmaxGradGradCaptureState> {\n public:\n  Maybe<void> Init(const OpExpr& op) override;\n  Maybe<void> Capture(SoftmaxGradGradCaptureState* ctx, const TensorTuple& inputs,\n                      const TensorTuple& outputs, const AttrMap& attrs) const override;\n  Maybe<void> Apply(const SoftmaxGradGradCaptureState* ctx, const TensorTuple& out_grads,\n                    TensorTuple* in_grads) const override;\n};\n\nMaybe<void> SoftmaxGradGrad::Init(const OpExpr& op) { return Maybe<void>::Ok(); }\n\nMaybe<void> SoftmaxGradGrad::Capture(SoftmaxGradGradCaptureState* ctx, const TensorTuple& inputs,\n                                     const TensorTuple& outputs, const AttrMap& attrs) const {\n  // y, dy\n  CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n  ctx->y_requires_grad = inputs[0]->requires_grad();\n  ctx->dy_requires_grad = inputs[1]->requires_grad();\n\n  ctx->SaveTensorForBackward(inputs[0]);\n  if (ctx->y_requires_grad) ctx->SaveTensorForBackward(inputs[1]);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SoftmaxGradGrad::Apply(const SoftmaxGradGradCaptureState* ctx,\n                                   const TensorTuple& out_grads, TensorTuple* in_grads) const {\n  in_grads->resize(2);\n  const auto& y = ctx->SavedTensors()[0];\n\n  if (ctx->y_requires_grad) {\n    const auto& dy = ctx->SavedTensors()[1];\n    const std::vector<int32_t> reduce_axis{static_cast<int32_t>(y->ndim() - 1)};\n    const auto& a = JUST(functional::sequence_function(functional::Mul)\n                             .then(std::bind(functional::ReduceSum, std::placeholders::_1,\n                                             reduce_axis, /*keepdim=*/true, NullOpt))\n                             .then(std::bind(functional::Mul, std::placeholders::_1, dy))\n                             .call(y, out_grads[0]));\n    const auto& b = JUST(functional::sequence_function(functional::Mul)\n                             .then(std::bind(functional::ReduceSum, std::placeholders::_1,\n                                             reduce_axis, /*keepdim=*/true, NullOpt))\n                             .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))\n                             .call(y, dy));\n    in_grads->at(0) = JUST(functional::sequence_function(functional::Mul)\n      
                         .then(std::bind(functional::Sub, std::placeholders::_1, a,\n                                               /*alpha=*/1, /*inplace=*/false))\n                               .then(std::bind(functional::Sub, std::placeholders::_1, b,\n                                               /*alpha=*/1, /*inplace=*/false))\n                               .call(out_grads[0], dy));\n  }\n  if (ctx->dy_requires_grad) { in_grads->at(1) = JUST(functional::SoftmaxGrad(out_grads[0], y)); }\n\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_EXPR_GRAD_FUNCTION(\"softmax_grad\", SoftmaxGradGrad);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/asymmetric_broadcast.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/framework/id_util.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/framework/placement_sbp_util.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/decorator.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> RawCheckAsymmetricBroadcast(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                        const Shape& logical_shape) {\n  // NOLINTBEGIN(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_OR_RETURN(NdSbpIsAllBroadcast(*in->nd_sbp()));\n  CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp()));\n  CHECK_OR_RETURN(out->placement()->Bigger(*in->placement())\n                  || in->placement()->Bigger(*out->placement()));\n  CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU\n                  || in->placement()->device_type() == DeviceType::kCUDA);\n  // NOLINTEND(maybe-need-error-msg)\n  return Maybe<void>::Ok();\n}\n\nstatic constexpr auto* CheckAsymmetricBroadcast =\n    DECORATE(&RawCheckAsymmetricBroadcast, ThreadLocalCachedCopiable);\n\nMaybe<int64_t> CalBroadcastRoot(Symbol<ParallelDesc> src_parallel_desc,\n                                Symbol<ParallelDesc> dst_parallel_desc) {\n  int64_t machine_id = -1;\n  int64_t device_id = -1;\n  for (int64_t mach_id : src_parallel_desc->sorted_machine_ids()) {\n    bool machine_and_device_id_inited = false;\n    for (int64_t dev_id : src_parallel_desc->sorted_dev_phy_ids(mach_id)) {\n      if (dst_parallel_desc->Containing(mach_id, dev_id)) {\n        machine_id = mach_id;\n        device_id = dev_id;\n        machine_and_device_id_inited = true;\n        break;\n      }\n    }\n    if (machine_and_device_id_inited) { break; }\n  }\n  // Always true, if check failed, there is a bug in oneflow needed to be resolved.\n  CHECK_OR_RETURN(machine_id != -1 && device_id != -1)\n      << Error::RuntimeError()\n      << \"Calculate the intersection of placements \"\n         \"failed during execution of asymmetric broadcast,\"\n      << \", placement_a: \" << *JUST(PlacementToString(src_parallel_desc))\n      << \", placement_b: \" << *JUST(PlacementToString(dst_parallel_desc))\n      << \"! 
Please submit an issue in `https://github.com/Oneflow-Inc/oneflow/issues` \"\n         \"and we will fix it as soon as possible\";\n  return machine_id;\n}\n\nstatic constexpr auto* CachedGetBroadcastRoot = DECORATE(&CalBroadcastRoot, ThreadLocalCached);\n\nMaybe<one::UserOpExpr> EagerCclBroadcast(Symbol<ParallelDesc> parallel_desc, int64_t root,\n                                         const Shape& shape) {\n  return one::OpBuilder(\"eager_ccl_broadcast\", *JUST(UniqueStr(\"eager_ccl_broadcast\")))\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<std::string>(\"parallel_conf\", PbMessage2TxtString(parallel_desc->parallel_conf()))\n      .Attr<std::vector<Shape>>(\"shape_list\", {shape})\n      .Attr<int64_t>(\"root\", root)\n      .Build();\n}\n\nstatic constexpr auto* CachedEagerCclBroadcast =\n    DECORATE(&EagerCclBroadcast, ThreadLocalCachedCopiable);\n}  // namespace\n\nMaybe<one::Tensor> AsymmetricBroadcast(const std::shared_ptr<one::Tensor>& tensor,\n                                       Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {\n  const auto& in_placement = in->placement();\n  const auto& out_placement = out->placement();\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in_placement)\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in_placement)) << \")\";\n  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());\n  if (out->placement()->Bigger(*in->placement())) {\n    const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_placement));\n    if (out_parallel_id->has_value()) {\n      const auto& broadcast_group = JUST(GetBroadcastGroup(in_placement, out_placement));\n\n      Symbol<ParallelDesc> broadcast_placement_cur_rank =\n          JUST(MapAt(*broadcast_group, GlobalProcessCtx::Rank()));\n      int64_t root = JUST(CachedGetBroadcastRoot(in_placement, broadcast_placement_cur_rank));\n      std::shared_ptr<one::UserOpExpr> op_expr =\n          JUST(CachedEagerCclBroadcast(broadcast_placement_cur_rank, root, *tensor->shape()));\n      local_tensor = JUST(one::OpInterpUtil::Dispatch<one::Tensor>(*op_expr, {local_tensor}));\n    }\n  }\n  return one::functional::LocalToGlobal(local_tensor, out_placement,\n                                        *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),\n                                        tensor->dtype(), /* sync_data */ false, /*copy=*/false);\n}\n\nCOMMAND(RegisterBoxingFunction(\"asymmetric-broadcast\", CheckAsymmetricBroadcast,\n                               &AsymmetricBroadcast));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/boxing_dividor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_H_\n#define ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_H_\n\n#include <functional>\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/symbol.h\"\n\nnamespace oneflow {\n\nclass PlacedNdSbp;\n\nclass BoxingDividor final {\n public:\n  BoxingDividor(const BoxingDividor&) = delete;\n  BoxingDividor(BoxingDividor&&) = delete;\n  ~BoxingDividor() = default;\n\n  using FunctionT =\n      std::function<Maybe<Symbol<PlacedNdSbp>>(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out)>;\n\n  BoxingDividor(const std::string& name, const FunctionT& function)\n      : name_(name), function_(function) {}\n\n  const std::string& name() const { return name_; }\n\n  Maybe<Symbol<PlacedNdSbp>> operator()(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) const {\n    return function_(in, out);\n  }\n\n private:\n  std::string name_;\n  FunctionT function_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_H_\n"
  },
  {
    "path": "oneflow/core/boxing/boxing_dividor_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/boxing/boxing_dividor_util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/placed_nd_sbp.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<BoxingDividor> RawReplaceInDeviceType(DeviceType device_type) {\n  return std::make_shared<BoxingDividor>(\n      \"ReplaceInDeviceType\",\n      [device_type](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {\n        const auto& new_placement = JUST(ReplaceDeviceType(in->placement(), device_type));\n        return PlacedNdSbp::New(in->nd_sbp(), new_placement);\n      });\n}\n\nMaybe<BoxingDividor> RawReplaceOutDeviceType(DeviceType device_type) {\n  return std::make_shared<BoxingDividor>(\n      \"ReplaceOutDeviceType\",\n      [device_type](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {\n        const auto& new_placement = JUST(ReplaceDeviceType(out->placement(), device_type));\n        return PlacedNdSbp::New(out->nd_sbp(), new_placement);\n      });\n}\n\n}  // namespace\n\ndecltype(ReplaceInDeviceType) ReplaceInDeviceType =\n    DECORATE(&RawReplaceInDeviceType, ThreadLocalCached);\ndecltype(ReplaceOutDeviceType) ReplaceOutDeviceType =\n    DECORATE(&RawReplaceOutDeviceType, ThreadLocalCached);\n\nnamespace {\n\nMaybe<Symbol<PlacedNdSbp>> RawFlattenHierarchy(Symbol<PlacedNdSbp> placed_nd_sbp) {\n  CHECK_GE_OR_RETURN(placed_nd_sbp->nd_sbp()->sbp_parallel_size(), 0)\n      << Error::RuntimeError() << \"Invalid nd_sbp with ndim equal 0!\";\n  const auto& first_sbp_parallel = placed_nd_sbp->nd_sbp()->sbp_parallel(0);\n  for (const auto& sbp_parallel : placed_nd_sbp->nd_sbp()->sbp_parallel()) {\n    CHECK_OR_RETURN(sbp_parallel == first_sbp_parallel)\n        << Error::RuntimeError()\n        << \"Expected all sbps to be on the same in sbp list during flatten sbps list, but find at \"\n           \"least two sbps, \"\n        << SbpToString(first_sbp_parallel) << \" and \" << SbpToString(sbp_parallel) << \"!\";\n  }\n  std::vector<Symbol<SbpParallel>> vec{SymbolOf(first_sbp_parallel)};\n  const auto& flattened_nd_sbp = JUST(GetNdSbp(vec));\n  ParallelConf flattened_parallel_conf(placed_nd_sbp->placement()->parallel_conf());\n  flattened_parallel_conf.clear_hierarchy();\n  const auto& flattened_placement = SymbolOf(ParallelDesc(flattened_parallel_conf));\n  return JUST(PlacedNdSbp::New(flattened_nd_sbp, flattened_placement));\n}\n\nstatic constexpr auto* FlattenHierarchy = DECORATE(&RawFlattenHierarchy, ThreadLocalCached);\n\nMaybe<BoxingDividor> RawFlattenInHierarchy() {\n  return std::make_shared<BoxingDividor>(\n      \"FlattenInHierarchy\",\n      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {\n        return 
FlattenHierarchy(in);\n      });\n}\n\nMaybe<Symbol<PlacedNdSbp>> RawUnflattenHierarchy(Symbol<PlacedNdSbp> in_placed_nd_sbp,\n                                                 Symbol<PlacedNdSbp> out_placed_nd_sbp) {\n  CHECK_GE_OR_RETURN(in_placed_nd_sbp->nd_sbp()->sbp_parallel_size(), 0)\n      << Error::RuntimeError() << \"Invalid nd_sbp with ndim equal 0!\";\n  CHECK_GE_OR_RETURN(out_placed_nd_sbp->nd_sbp()->sbp_parallel_size(), 0)\n      << Error::RuntimeError() << \"Invalid nd_sbp with ndim equal 0!\";\n  const auto& in_sbp_parallel = in_placed_nd_sbp->nd_sbp()->sbp_parallel(0);\n  NdSbp unflattened_nd_sbp;\n  for (int64_t i = 0; i < out_placed_nd_sbp->nd_sbp()->sbp_parallel_size(); ++i) {\n    unflattened_nd_sbp.mutable_sbp_parallel()->Add()->CopyFrom(in_sbp_parallel);\n  }\n  return JUST(PlacedNdSbp::New(SymbolOf(unflattened_nd_sbp), out_placed_nd_sbp->placement()));\n}\n\nstatic constexpr auto* UnflattenHierarchy = DECORATE(&RawUnflattenHierarchy, ThreadLocalCached);\n\nMaybe<BoxingDividor> RawUnflattenInHierarchy() {\n  return std::make_shared<BoxingDividor>(\n      \"UnflattenInHierarchy\",\n      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {\n        return UnflattenHierarchy(in, out);\n      });\n}\n\nMaybe<BoxingDividor> RawUnflattenOutHierarchy() {\n  return std::make_shared<BoxingDividor>(\n      \"UnflattenOutHierarchy\",\n      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {\n        return UnflattenHierarchy(out, in);\n      });\n}\n\n}  // namespace\n\ndecltype(FlattenInHierarchy) FlattenInHierarchy =\n    DECORATE(&RawFlattenInHierarchy, ThreadLocalCached);\ndecltype(UnflattenInHierarchy) UnflattenInHierarchy =\n    DECORATE(&RawUnflattenInHierarchy, ThreadLocalCached);\ndecltype(UnflattenOutHierarchy) UnflattenOutHierarchy =\n    DECORATE(&RawUnflattenOutHierarchy, ThreadLocalCached);\n\nnamespace {\n\nMaybe<Symbol<NdSbp>> GetAllPartialSumNdSbp(int64_t ndim) {\n  NdSbp partial_sum_nd_sbp;\n  for (int64_t i = 0; i < ndim; ++i) {\n    partial_sum_nd_sbp.mutable_sbp_parallel()->Add()->mutable_partial_sum_parallel();\n  }\n  return SymbolOf(partial_sum_nd_sbp);\n}\n\nauto* CachedGetAllPartialSumNdSbp = DECORATE(&GetAllPartialSumNdSbp, ThreadLocalCached);\n\nMaybe<Symbol<PlacedNdSbp>> RawReplaceNdSbpWithPartialSum(Symbol<PlacedNdSbp> placed_nd_sbp) {\n  Symbol<NdSbp> partial_sum_nd_sbp =\n      JUST(CachedGetAllPartialSumNdSbp(placed_nd_sbp->nd_sbp()->sbp_parallel_size()));\n  return JUST(PlacedNdSbp::New(partial_sum_nd_sbp, placed_nd_sbp->placement()));\n}\n\nstatic constexpr auto* ReplaceNdSbpWithPartialSum =\n    DECORATE(&RawReplaceNdSbpWithPartialSum, ThreadLocalCached);\n\nMaybe<BoxingDividor> RawOutPlacementAndPartialSum() {\n  return std::make_shared<BoxingDividor>(\n      \"OutPlacementAndPartialSum\",\n      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {\n        return ReplaceNdSbpWithPartialSum(out);\n      });\n}\n\n}  // namespace\n\ndecltype(OutPlacementAndPartialSum) OutPlacementAndPartialSum =\n    DECORATE(&RawOutPlacementAndPartialSum, ThreadLocalCached);\n\nnamespace {\n\nMaybe<Symbol<NdSbp>> GetAllBroadcastNdSbp(int64_t ndim) {\n  NdSbp broadcast_nd_sbp;\n  for (int64_t i = 0; i < ndim; ++i) {\n    broadcast_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();\n  }\n  return SymbolOf(broadcast_nd_sbp);\n}\n\nauto* CachedGetAllBroadcastNdSbp = DECORATE(&GetAllBroadcastNdSbp, ThreadLocalCached);\n\nMaybe<Symbol<PlacedNdSbp>> 
RawReplaceNdSbpWithBroadcast(Symbol<PlacedNdSbp> placed_nd_sbp) {\n  Symbol<NdSbp> broadcast_nd_sbp =\n      JUST(CachedGetAllBroadcastNdSbp(placed_nd_sbp->nd_sbp()->sbp_parallel_size()));\n  return JUST(PlacedNdSbp::New(broadcast_nd_sbp, placed_nd_sbp->placement()));\n}\n\nstatic constexpr auto* ReplaceNdSbpWithBroadcast =\n    DECORATE(&RawReplaceNdSbpWithBroadcast, ThreadLocalCached);\n\nMaybe<BoxingDividor> RawInPlacementAndBroadcast() {\n  return std::make_shared<BoxingDividor>(\n      \"InPlacementAndBroadcast\",\n      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {\n        return ReplaceNdSbpWithBroadcast(in);\n      });\n}\n\nMaybe<BoxingDividor> RawOutPlacementAndBroadcast() {\n  return std::make_shared<BoxingDividor>(\n      \"OutPlacementAndBroadcast\",\n      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {\n        return ReplaceNdSbpWithBroadcast(out);\n      });\n}\n\n}  // namespace\n\ndecltype(InPlacementAndBroadcast) InPlacementAndBroadcast =\n    DECORATE(&RawInPlacementAndBroadcast, ThreadLocalCached);\ndecltype(OutPlacementAndBroadcast) OutPlacementAndBroadcast =\n    DECORATE(&RawOutPlacementAndBroadcast, ThreadLocalCached);\n\nnamespace {\n\nMaybe<Symbol<NdSbp>> GetSplitNdSbp(int64_t axis) {\n  NdSbp split_nd_sbp;\n  split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_split_parallel()->set_axis(axis);\n  return SymbolOf(split_nd_sbp);\n}\n\nauto* CachedGetSplitNdSbp = DECORATE(&GetSplitNdSbp, ThreadLocalCached);\n\nMaybe<BoxingDividor> RawInPlacementAndSplit(int64_t axis) {\n  return std::make_shared<BoxingDividor>(\n      \"InPlacementAndSplit\",\n      [=](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {\n        Symbol<NdSbp> split_nd_sbp = JUST(CachedGetSplitNdSbp(axis));\n        return PlacedNdSbp::New(split_nd_sbp, in->placement());\n      });\n}\n\nMaybe<BoxingDividor> RawOutPlacementAndSplit(int64_t axis) {\n  return std::make_shared<BoxingDividor>(\n      \"OutPlacementAndSplit\",\n      [=](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {\n        Symbol<NdSbp> split_nd_sbp = JUST(CachedGetSplitNdSbp(axis));\n        return PlacedNdSbp::New(split_nd_sbp, out->placement());\n      });\n}\n\n}  // namespace\n\ndecltype(InPlacementAndSplit) InPlacementAndSplit =\n    DECORATE(&RawInPlacementAndSplit, ThreadLocalCached);\ndecltype(OutPlacementAndSplit) OutPlacementAndSplit =\n    DECORATE(&RawOutPlacementAndSplit, ThreadLocalCached);\n\nnamespace {\n\nMaybe<Symbol<ParallelDesc>> GetFisrtDeviceOfPlacement(Symbol<ParallelDesc> placement) {\n  ParallelConf parallel_conf;\n  int64_t machine_id = JUST(placement->MachineId4ParallelId(0));\n  int64_t device_id = JUST(placement->DeviceId4ParallelId(0));\n  parallel_conf.set_device_tag(placement->device_tag());\n  parallel_conf.add_device_name(std::string(\"@\") + std::to_string(machine_id) + \":\"\n                                + std::to_string(device_id));\n  for (int64_t i = 0; i < placement->hierarchy()->NumAxes(); ++i) {\n    parallel_conf.mutable_hierarchy()->add_dim(1);\n  }\n  std::shared_ptr<ParallelDesc> parallel_desc;\n  JUST(PhysicalRun([&parallel_desc, &parallel_conf](InstructionsBuilder* builder) -> Maybe<void> {\n    parallel_desc = JUST(builder->GetParallelDescSymbol(parallel_conf));\n    return Maybe<void>::Ok();\n  }));\n  return SymbolOf(*parallel_desc);\n}\n\nMaybe<BoxingDividor> RawInFirstDeviceAndAllBroadcast() {\n  return std::make_shared<BoxingDividor>(\n     
 \"InFirstDeviceAndAllBroadcast\",\n      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {\n        return PlacedNdSbp::New(JUST(CachedGetAllBroadcastNdSbp(in->nd_sbp()->sbp_parallel_size())),\n                                JUST(GetFisrtDeviceOfPlacement(in->placement())));\n      });\n}\n\nMaybe<BoxingDividor> RawOutFirstDeviceAndAllBroadcast() {\n  return std::make_shared<BoxingDividor>(\n      \"OutFirstDeviceAndAllBroadcast\",\n      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {\n        return PlacedNdSbp::New(\n            JUST(CachedGetAllBroadcastNdSbp(out->nd_sbp()->sbp_parallel_size())),\n            JUST(GetFisrtDeviceOfPlacement(out->placement())));\n      });\n}\n\n}  //  namespace\n\ndecltype(InFirstDeviceAndAllBroadcast) InFirstDeviceAndAllBroadcast =\n    DECORATE(&RawInFirstDeviceAndAllBroadcast, ThreadLocalCached);\ndecltype(OutFirstDeviceAndAllBroadcast) OutFirstDeviceAndAllBroadcast =\n    DECORATE(&RawOutFirstDeviceAndAllBroadcast, ThreadLocalCached);\n\nnamespace {\n\nMaybe<Symbol<PlacedNdSbp>> RawPlacementAndRepeatFirstSbp(Symbol<PlacedNdSbp> placed_nd_sbp) {\n  const auto& first_sbp_parallel = placed_nd_sbp->nd_sbp()->sbp_parallel(0);\n  NdSbp out_nd_sbp;\n  for (int64_t i = 0; i < placed_nd_sbp->nd_sbp()->sbp_parallel_size(); ++i) {\n    out_nd_sbp.mutable_sbp_parallel()->Add()->CopyFrom(first_sbp_parallel);\n  }\n  return JUST(PlacedNdSbp::New(SymbolOf(out_nd_sbp), placed_nd_sbp->placement()));\n}\n\nstatic constexpr auto* PlacementAndRepeatFirstSbp =\n    DECORATE(&RawPlacementAndRepeatFirstSbp, ThreadLocalCached);\n\nMaybe<BoxingDividor> RawInPlacementAndRepeatFirstSbp() {\n  return std::make_shared<BoxingDividor>(\n      \"InPlacementAndRepeatFirstSbp\",\n      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {\n        return PlacementAndRepeatFirstSbp(in);\n      });\n}\n\n}  // namespace\n\ndecltype(InPlacementAndRepeatFirstSbp) InPlacementAndRepeatFirstSbp =\n    DECORATE(&RawInPlacementAndRepeatFirstSbp, ThreadLocalCached);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/boxing_dividor_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_UTIL_H_\n#define ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_UTIL_H_\n\n#include \"oneflow/core/common/device_type.pb.h\"\n#include \"oneflow/core/boxing/boxing_dividor.h\"\n\nnamespace oneflow {\n\nextern Maybe<BoxingDividor> (*ReplaceInDeviceType)(DeviceType device_type);\nextern Maybe<BoxingDividor> (*ReplaceOutDeviceType)(DeviceType device_type);\nextern Maybe<BoxingDividor> (*FlattenInHierarchy)();\nextern Maybe<BoxingDividor> (*UnflattenInHierarchy)();\nextern Maybe<BoxingDividor> (*UnflattenOutHierarchy)();\nextern Maybe<BoxingDividor> (*OutPlacementAndPartialSum)();\nextern Maybe<BoxingDividor> (*InPlacementAndBroadcast)();\nextern Maybe<BoxingDividor> (*OutPlacementAndBroadcast)();\nextern Maybe<BoxingDividor> (*InPlacementAndSplit)(int64_t axis);\nextern Maybe<BoxingDividor> (*OutPlacementAndSplit)(int64_t axis);\nextern Maybe<BoxingDividor> (*InFirstDeviceAndAllBroadcast)();\nextern Maybe<BoxingDividor> (*OutFirstDeviceAndAllBroadcast)();\nextern Maybe<BoxingDividor> (*InPlacementAndRepeatFirstSbp)();\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/boxing/boxing_interpreter_status.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/placed_nd_sbp.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/boxing/boxing_interpreter_status.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<BoxingInterpreterStatus> RawMakeBoxingInterpreterStatus(const std::string& boxing_name,\n                                                              const Shape& logical_shape,\n                                                              Symbol<PlacedNdSbp> in,\n                                                              Symbol<PlacedNdSbp> out) {\n  std::vector<std::string> sorted_boxing_names{boxing_name};\n  BoxingInterpreterStatus status(SymbolOf(sorted_boxing_names), logical_shape, in, out);\n  return status;\n}\n\nMaybe<BoxingInterpreterStatus> RawMakeComposedBoxingInterpreterStatus(\n    const std::shared_ptr<BoxingInterpreterStatus>& lhs_status,\n    const std::shared_ptr<BoxingInterpreterStatus>& rhs_status) {\n  CHECK_OR_RETURN(lhs_status->dst_placed_nd_sbp()\n                  == rhs_status->src_placed_nd_sbp())  // always true\n      << Error::RuntimeError()\n      << \"Intermediate placed_nd_sbp must be equal when compose boxing interpreter status\"\n      << \". lhs_status.dst_nd_sbp: \" << NdSbpToString(lhs_status->dst_placed_nd_sbp()->nd_sbp())\n      << \", rhs_status.dst_nd_sbp: \" << NdSbpToString(rhs_status->src_placed_nd_sbp()->nd_sbp())\n      << \", lhs_status.dst_placement: \"\n      << *JUST(PlacementToString(lhs_status->dst_placed_nd_sbp()->placement()))\n      << \", rhs_status.dst_placement: \"\n      << *JUST(PlacementToString(rhs_status->src_placed_nd_sbp()->placement()));\n  CHECK_OR_RETURN(lhs_status->logical_shape() == rhs_status->logical_shape())  // always true\n      << Error::RuntimeError()\n      << \"Logical_shape must be equal when compose boxing interpreter status\"\n      << \". lhs_status.logical_shape: \" << (lhs_status->logical_shape().ToString())\n      << \". 
rhs_status.logical_shape: \" << (rhs_status->logical_shape().ToString());\n  std::vector<std::string> sorted_boxing_names(*lhs_status->sorted_boxing_names());\n  sorted_boxing_names.insert(sorted_boxing_names.end(), rhs_status->sorted_boxing_names()->begin(),\n                             rhs_status->sorted_boxing_names()->end());\n  std::vector<Symbol<PlacedNdSbp>> mid_placed_nd_sbp(*lhs_status->mid_placed_nd_sbp());\n  mid_placed_nd_sbp.emplace_back(lhs_status->dst_placed_nd_sbp());\n  mid_placed_nd_sbp.insert(mid_placed_nd_sbp.end(), rhs_status->mid_placed_nd_sbp()->begin(),\n                           rhs_status->mid_placed_nd_sbp()->end());\n  BoxingInterpreterStatus status(sorted_boxing_names, lhs_status->logical_shape(),\n                                 lhs_status->src_placed_nd_sbp(), SymbolOf(mid_placed_nd_sbp),\n                                 rhs_status->dst_placed_nd_sbp());\n  return status;\n}\n\n}  // namespace\n\ndecltype(MakeBoxingInterpreterStatus) MakeBoxingInterpreterStatus =\n    DECORATE(&RawMakeBoxingInterpreterStatus, ThreadLocalCachedCopiable);\ndecltype(MakeComposedBoxingInterpreterStatus) MakeComposedBoxingInterpreterStatus =\n    DECORATE(&RawMakeComposedBoxingInterpreterStatus, ThreadLocalCachedCopiable);\n\nnamespace {\n\nMaybe<std::string> RawGetNdSbpRouting(Symbol<PlacedNdSbp> src_placed_nd_sbp,\n                                      Symbol<std::vector<Symbol<PlacedNdSbp>>> mid_placed_nd_sbp,\n                                      Symbol<PlacedNdSbp> dst_placed_nd_sbp) {\n  std::ostringstream ss;\n  ss << NdSbpToString(src_placed_nd_sbp->nd_sbp());\n  for (const auto& placed_nd_sbp : *mid_placed_nd_sbp) {\n    ss << \" -> \" << NdSbpToString(placed_nd_sbp->nd_sbp());\n  }\n  ss << \" -> \" << NdSbpToString(dst_placed_nd_sbp->nd_sbp());\n  return ss.str();\n}\n\nMaybe<std::string> RawGetPlacementRouting(\n    Symbol<PlacedNdSbp> src_placed_nd_sbp,\n    Symbol<std::vector<Symbol<PlacedNdSbp>>> mid_placed_nd_sbp,\n    Symbol<PlacedNdSbp> dst_placed_nd_sbp) {\n  std::ostringstream ss;\n  ss << *JUST(PlacementToString(src_placed_nd_sbp->placement()));\n  for (const auto& placed_nd_sbp : *mid_placed_nd_sbp) {\n    ss << \" -> \" << *JUST(PlacementToString(placed_nd_sbp->placement()));\n  }\n  ss << \" -> \" << *JUST(PlacementToString(dst_placed_nd_sbp->placement()));\n  return ss.str();\n}\n\nMaybe<std::string> RawGetBoxingDesc(Symbol<std::vector<std::string>> sorted_boxing_names) {\n  CHECK_OR_RETURN(!sorted_boxing_names->empty())  // always true\n      << Error::RuntimeError() << \"boxing_names of eager boxing status can't be empty!\";\n  std::ostringstream ss;\n  ss << sorted_boxing_names->at(0);\n  for (size_t i = 1; i < sorted_boxing_names->size(); ++i) {\n    ss << \" -> \" << sorted_boxing_names->at(i);\n  }\n  return ss.str();\n}\n\nstatic constexpr auto* GetNdSbpRouting = DECORATE(&RawGetNdSbpRouting, ThreadLocalCached);\nstatic constexpr auto* GetPlacementRouting = DECORATE(&RawGetPlacementRouting, ThreadLocalCached);\nstatic constexpr auto* GetBoxingDesc = DECORATE(&RawGetBoxingDesc, ThreadLocalCached);\n\n}  // namespace\n\nconst std::string& BoxingInterpreterStatus::boxing_routing() const {\n  return *CHECK_JUST(GetBoxingDesc(sorted_boxing_names_));\n}\n\nconst std::string& BoxingInterpreterStatus::nd_sbp_routing() const {\n  return *CHECK_JUST(GetNdSbpRouting(src_placed_nd_sbp_, mid_placed_nd_sbp_, dst_placed_nd_sbp_));\n}\n\nconst std::string& BoxingInterpreterStatus::placement_routing() const {\n  return *CHECK_JUST(\n      
GetPlacementRouting(src_placed_nd_sbp_, mid_placed_nd_sbp_, dst_placed_nd_sbp_));\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/boxing_interpreter_status.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_BOXING_BOXING_INTERPRETER_STATUS_H_\n#define ONEFLOW_CORE_BOXING_BOXING_INTERPRETER_STATUS_H_\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/framework/placed_nd_sbp.h\"\n#include \"oneflow/core/common/shape.h\"\n\nnamespace oneflow {\n\nclass BoxingInterpreterStatus;\n\nextern Maybe<BoxingInterpreterStatus> (*MakeBoxingInterpreterStatus)(const std::string& boxing_name,\n                                                                     const Shape& logical_shape,\n                                                                     Symbol<PlacedNdSbp> in,\n                                                                     Symbol<PlacedNdSbp> out);\n\nextern Maybe<BoxingInterpreterStatus> (*MakeComposedBoxingInterpreterStatus)(\n    const std::shared_ptr<BoxingInterpreterStatus>& lhs_status,\n    const std::shared_ptr<BoxingInterpreterStatus>& rhs_status);\n\nclass BoxingInterpreterStatus final {\n public:\n  BoxingInterpreterStatus(Symbol<std::vector<std::string>> sorted_boxing_names,\n                          const Shape& logical_shape, Symbol<PlacedNdSbp> src_placed_nd_sbp,\n                          Symbol<std::vector<Symbol<PlacedNdSbp>>> mid_placed_nd_sbp,\n                          Symbol<PlacedNdSbp> dst_placed_nd_sbp)\n      : sorted_boxing_names_(sorted_boxing_names),\n        logical_shape_(logical_shape),\n        src_placed_nd_sbp_(src_placed_nd_sbp),\n        mid_placed_nd_sbp_(mid_placed_nd_sbp),\n        dst_placed_nd_sbp_(dst_placed_nd_sbp) {}\n  BoxingInterpreterStatus(Symbol<std::vector<std::string>> sorted_boxing_names,\n                          const Shape& logical_shape, Symbol<PlacedNdSbp> src_placed_nd_sbp,\n                          Symbol<PlacedNdSbp> dst_placed_nd_sbp)\n      : BoxingInterpreterStatus(sorted_boxing_names, logical_shape, src_placed_nd_sbp,\n                                SymbolOf(std::vector<Symbol<PlacedNdSbp>>()), dst_placed_nd_sbp) {}\n  ~BoxingInterpreterStatus() = default;\n\n  bool operator==(const BoxingInterpreterStatus& other) const {\n    return this->sorted_boxing_names_ == other.sorted_boxing_names_\n           && this->src_placed_nd_sbp_ == other.src_placed_nd_sbp_\n           && this->mid_placed_nd_sbp_ == other.mid_placed_nd_sbp_\n           && this->dst_placed_nd_sbp_ == other.dst_placed_nd_sbp_;\n  }\n\n  // Getters\n  Symbol<std::vector<std::string>> sorted_boxing_names() const { return sorted_boxing_names_; }\n  const Shape& logical_shape() const { return logical_shape_; }\n  Symbol<PlacedNdSbp> src_placed_nd_sbp() const { return src_placed_nd_sbp_; }\n  Symbol<PlacedNdSbp> dst_placed_nd_sbp() const { return dst_placed_nd_sbp_; }\n  Symbol<std::vector<Symbol<PlacedNdSbp>>> mid_placed_nd_sbp() const { return mid_placed_nd_sbp_; }\n\n  const std::string& boxing_routing() const;\n  const std::string& nd_sbp_routing() const;\n  
const std::string& placement_routing() const;\n\n private:\n  Symbol<std::vector<std::string>> sorted_boxing_names_;\n  const Shape logical_shape_;\n  Symbol<PlacedNdSbp> src_placed_nd_sbp_;\n  Symbol<std::vector<Symbol<PlacedNdSbp>>> mid_placed_nd_sbp_;\n  Symbol<PlacedNdSbp> dst_placed_nd_sbp_;\n};\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::BoxingInterpreterStatus> {\n  size_t operator()(const oneflow::BoxingInterpreterStatus& status) const {\n    using namespace oneflow;\n    size_t ret = 0;\n    for (const auto& boxing_name : *status.sorted_boxing_names()) { AddHash(&ret, boxing_name); }\n    AddHash(&ret, *status.src_placed_nd_sbp());\n    for (const auto& mid_placed_nd_sbp : *status.mid_placed_nd_sbp()) {\n      AddHash(&ret, *mid_placed_nd_sbp);\n    }\n    AddHash(&ret, *status.dst_placed_nd_sbp());\n    return ret;\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_BOXING_BOXING_INTERPRETER_STATUS_H_\n"
  },
  {
    "path": "oneflow/core/boxing/ccl_boxing_function.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/id_util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass EagerBoxingKernelRegContext final : public user_op::KernelRegContext {\n public:\n  explicit EagerBoxingKernelRegContext(DeviceType device_type) : device_type_(device_type) {}\n  ~EagerBoxingKernelRegContext() = default;\n\n  DeviceType device_type() const override { return device_type_; }\n  const ParallelContext& parallel_ctx() const override { PRINT_BUG_PROMPT_AND_ABORT(); }\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n  }\n  const std::vector<std::pair<std::string, int32_t>>& inputs() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n  }\n  const std::vector<std::pair<std::string, int32_t>>& outputs() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n  }\n\n  const user_op::UserOpConfWrapper& user_op_conf() const override { PRINT_BUG_PROMPT_AND_ABORT(); }\n\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n  }\n\n private:\n  DeviceType device_type_;\n};\n\nMaybe<bool> RawCheckCclKernelRegistered(const std::string& op_type_name, DeviceType device_type) {\n  EagerBoxingKernelRegContext reg_ctx(device_type);\n  return user_op::UserOpRegistryMgr::Get().IsOpKernelRegistered(op_type_name, reg_ctx);\n}\n\nstatic constexpr auto* CheckCclKernelRegistered =\n    DECORATE(&RawCheckCclKernelRegistered, ThreadLocalCachedCopiable);\n\nMaybe<void> RawCheckCclP2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                           const Shape& logical_shape) {\n  // NOLINTBEGIN(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);\n\n  CHECK_OR_RETURN(NdSbpIsAllPartialSum(*in->nd_sbp()));\n  CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp()));\n\n  CHECK_OR_RETURN(in->placement() == out->placement());\n  CHECK_OR_RETURN(                                                      // NOLINT\n      JUST(CheckCclKernelRegistered(\"eager_ccl_all_reduce\",             // NOLINT\n                                    in->placement()->device_type())));  // NOLINT\n  // NOLINTEND(maybe-need-error-msg)\n  return Maybe<void>::Ok();\n}\n\nstatic constexpr auto* CheckCclP2B = DECORATE(&RawCheckCclP2B, ThreadLocalCachedCopiable);\n\nMaybe<void> RawCheckCclP2S(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                           const Shape& 
logical_shape) {\n  // NOLINTBEGIN(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_OR_RETURN(NdSbpIsAllPartialSum(*in->nd_sbp()));\n  CHECK_OR_RETURN(NdSbpIsAllSplit(*out->nd_sbp(), 0));\n\n  CHECK_GT_OR_RETURN(logical_shape.NumAxes(), 0);\n  CHECK_OR_RETURN(logical_shape.At(0) % in->placement()->parallel_num() == 0);\n\n  CHECK_OR_RETURN(in->placement() == out->placement());\n  CHECK_OR_RETURN(                                                      // NOLINT\n      JUST(CheckCclKernelRegistered(\"eager_ccl_reduce_scatter\",         // NOLINT\n                                    in->placement()->device_type())));  // NOLINT\n  // NOLINTEND(maybe-need-error-msg)\n  return Maybe<void>::Ok();\n}\n\nstatic constexpr auto* CheckCclP2S = DECORATE(&RawCheckCclP2S, ThreadLocalCachedCopiable);\n\nMaybe<void> RawCheckCclS2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                           const Shape& logical_shape) {\n  // NOLINTBEGIN(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);\n\n  CHECK_OR_RETURN(NdSbpIsAllSplit(*in->nd_sbp(), 0));\n  CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp()));\n\n  CHECK_GT_OR_RETURN(logical_shape.NumAxes(), 0);\n  CHECK_OR_RETURN(logical_shape.At(0) % in->placement()->parallel_num() == 0);\n\n  CHECK_OR_RETURN(in->placement() == out->placement());\n  CHECK_OR_RETURN(                                                      // NOLINT\n      JUST(CheckCclKernelRegistered(\"eager_ccl_all_gather\",             // NOLINT\n                                    in->placement()->device_type())));  // NOLINT\n  // NOLINTEND(maybe-need-error-msg)\n  return Maybe<void>::Ok();\n}\n\nstatic constexpr auto* CheckCclS2B = DECORATE(&RawCheckCclS2B, ThreadLocalCachedCopiable);\n\nMaybe<void> RawCheckCclS2S(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                           const Shape& logical_shape) {\n  // NOLINTBEGIN(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);\n\n  CHECK_OR_RETURN(in->nd_sbp()->sbp_parallel(0).has_split_parallel());\n  CHECK_OR_RETURN(out->nd_sbp()->sbp_parallel(0).has_split_parallel());\n  CHECK_NE_OR_RETURN(in->nd_sbp()->sbp_parallel(0).split_parallel().axis(),\n                     out->nd_sbp()->sbp_parallel(0).split_parallel().axis());\n\n  int64_t in_split_axis = in->nd_sbp()->sbp_parallel(0).split_parallel().axis();\n  int64_t out_split_axis = out->nd_sbp()->sbp_parallel(0).split_parallel().axis();\n  CHECK_GT_OR_RETURN(logical_shape.NumAxes(), in_split_axis);\n  CHECK_GT_OR_RETURN(logical_shape.NumAxes(), out_split_axis);\n  CHECK_OR_RETURN(logical_shape.At(in_split_axis) % in->placement()->parallel_num() == 0);\n  CHECK_OR_RETURN(logical_shape.At(out_split_axis) % in->placement()->parallel_num() == 0);\n\n  CHECK_OR_RETURN(in->placement() == out->placement());\n  CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU\n                  || in->placement()->device_type() == DeviceType::kCUDA);\n  // NOLINTEND(maybe-need-error-msg)\n  return Maybe<void>::Ok();\n}\n\nstatic constexpr auto* CheckCclS2S = DECORATE(&RawCheckCclS2S, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<one::Tensor> CclP2B(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                          Symbol<PlacedNdSbp> out) {\n  const 
auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n  return JUST(one::functional::GlobalAllReduce(tensor));\n}\n\nMaybe<one::Tensor> CclP2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                          Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n\n  return JUST(one::functional::GlobalReduceScatter(tensor, \"sum\"));\n}\n\nMaybe<one::Tensor> CclS2B(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                          Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n  return JUST(one::functional::GlobalAllGather(tensor));\n}\n\nMaybe<one::Tensor> CclS2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                          Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n  return JUST(one::functional::GlobalS2S(tensor, *JUST(GetSbpList(out->nd_sbp()))));\n}\n\nCOMMAND(RegisterBoxingFunction(\"ccl-p-to-b\", CheckCclP2B, &CclP2B));\nCOMMAND(RegisterBoxingFunction(\"ccl-p-to-s\", CheckCclP2S, &CclP2S));\nCOMMAND(RegisterBoxingFunction(\"ccl-s-to-b\", CheckCclS2B, &CclS2B));\nCOMMAND(RegisterBoxingFunction(\"ccl-s-to-s\", CheckCclS2S, 
&CclS2S));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/cuda_copy_boxing_interpreter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<bool> IgnoringDeviceTypeEqual(Symbol<ParallelDesc> lhs, Symbol<ParallelDesc> rhs) {\n  return lhs == JUST(ReplaceDeviceType(rhs, lhs->device_type()));\n}\n\n}  // namespace\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<void> CheckCopyH2D(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                         const Shape& logical_shape) {\n  bool equal = JUST(IgnoringDeviceTypeEqual(in->placement(), out->placement()));\n  CHECK_OR_RETURN(equal);\n  CHECK_EQ_OR_RETURN(in->placement()->device_type(), DeviceType::kCPU);\n  CHECK_NE_OR_RETURN(out->placement()->device_type(), DeviceType::kCPU);\n  CHECK_OR_RETURN(in->nd_sbp() == out->nd_sbp());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckCopyD2H(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                         const Shape& logical_shape) {\n  bool equal = JUST(IgnoringDeviceTypeEqual(in->placement(), out->placement()));\n  CHECK_OR_RETURN(equal);\n  CHECK_NE_OR_RETURN(in->placement()->device_type(), DeviceType::kCPU);\n  CHECK_EQ_OR_RETURN(out->placement()->device_type(), DeviceType::kCPU);\n  CHECK_OR_RETURN(in->nd_sbp() == out->nd_sbp());\n  return Maybe<void>::Ok();\n}\n// NOLINTEND(maybe-need-error-msg)\n\nMaybe<one::Tensor> CopyBoxingFunction(const std::shared_ptr<one::Tensor>& tensor,\n                                      Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n  const std::shared_ptr<one::Tensor>& local_tensor = JUST(tensor->cur_rank_phy_tensor());\n  const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));\n  return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,\n                                             *tensor->shape(), tensor->dtype(),\n                                             /* sync_data */ false, /*copy=*/false));\n}\n\nCOMMAND(RegisterBoxingFunction(\"copy-h2d\", &CheckCopyH2D, &CopyBoxingFunction));\nCOMMAND(RegisterBoxingFunction(\"copy-d2h\", &CheckCopyD2H, &CopyBoxingFunction));\n\n}  
// namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/eager_boxing_interpreter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <typeinfo>\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/registry_error.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/framework/tensor_rpc_util.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter_mgr.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n\nnamespace oneflow {\n\nnamespace {\nMaybe<void> CheckEagerBoxingDataType(DataType val) {\n  CHECK_OR_RETURN(val != DataType::kTensorBuffer && val != DataType::kOFRecord)\n      << Error::RuntimeError() << \"invalid boxing data type \" << ToString(val);\n  return Maybe<void>::Ok();\n}\n}  // namespace\n\nMaybe<one::Tensor> EagerBoxingInterpreter::Interpret(const std::shared_ptr<one::Tensor>& input,\n                                                     Symbol<NdSbp> in_nd_sbp,\n                                                     Symbol<NdSbp> out_nd_sbp,\n                                                     Symbol<ParallelDesc> in_parallel_desc,\n                                                     Symbol<ParallelDesc> out_parallel_desc) const {\n  JUST(CheckEagerBoxingDataType(input->dtype()->data_type()));\n  DisableCheckGlobalTensorMetaScope disable_meta_check;\n  const auto& tensor =\n      JUST(InterpretImpl(input, in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc));\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_nd_sbp == out_nd_sbp)\n      << Error::RuntimeError() << \"The sbp of output tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the output sbp (\" << NdSbpToString(out_nd_sbp) << \")\";\n  CHECK_OR_RETURN(tensor_placement == out_parallel_desc)\n      << Error::RuntimeError() << \"The placement of output tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the output placement (\"\n      << *JUST(PlacementToString(out_parallel_desc)) << \")\";\n  return tensor;\n}\n\nnamespace {\n\nHashMap<std::string, BoxingCheckerT>* MutName2BoxingChecker() {\n  static HashMap<std::string, BoxingCheckerT> map;\n  return &map;\n}\n\nHashMap<std::string, BoxingFunctionT>* MutName2BoxingFunction() {\n  static HashMap<std::string, BoxingFunctionT> map;\n  return &map;\n}\n\nMaybe<BoxingFunctionT> RawGetBoxingFunction(const std::string& method_name, Symbol<PlacedNdSbp> in,\n                                            Symbol<PlacedNdSbp> out, const Shape& logical_shape) {\n  const auto& Checker =\n      JUST_MSG(MapAt(*MutName2BoxingChecker(), method_name),\n               std::stringstream() << \"boxing checker not found. 
checker_name: \" << method_name);\n  JUST(Checker(in, out, logical_shape));\n  return JUST_MSG(MapAt(*MutName2BoxingFunction(), method_name),\n                  std::stringstream()\n                      << \"boxing function not found. function_name: \" << method_name);\n}\n\n}  // namespace\n\nMaybe<BoxingFunctionT> GetBoxingFunction(const std::string& method_name, Symbol<PlacedNdSbp> in,\n                                         Symbol<PlacedNdSbp> out, const Shape& logical_shape) {\n  return DECORATE(&RawGetBoxingFunction, ThreadLocalCachedCopiable)(method_name, in, out,\n                                                                    logical_shape);\n}\n\nvoid RegisterBoxingFunction(const std::string& method_name, const BoxingCheckerT& Checker,\n                            const BoxingFunctionT& BoxingFunction) {\n  CatchRegistryError([&]() -> Maybe<void> {\n    CHECK_OR_RETURN(MutName2BoxingChecker()->emplace(method_name, Checker).second)\n        << Error::RuntimeError() << \"register boxing checker failed: \" << method_name;\n    CHECK_OR_RETURN(MutName2BoxingFunction()->emplace(method_name, BoxingFunction).second)\n        << Error::RuntimeError() << \"register boxing function failed: \" << method_name;\n    return Maybe<void>::Ok();\n  });\n}\n\nMaybe<BoxingInterpreterStatus> AtomicBoxingExpr::Check(Symbol<PlacedNdSbp> in,\n                                                       Symbol<PlacedNdSbp> out,\n                                                       const Shape& logical_shape) const {\n  const auto& Checker =\n      JUST_MSG(MapAt(*MutName2BoxingChecker(), boxing_name_),\n               std::stringstream() << \"boxing checker not found. checker_name: \" << boxing_name_);\n  JUST(Checker(in, out, logical_shape));\n  return MakeBoxingInterpreterStatus(boxing_name_, logical_shape, in, out);\n}\n\nMaybe<BoxingFunctionT> AtomicBoxingExpr::GetBoxingFunction(Symbol<PlacedNdSbp> in,\n                                                           Symbol<PlacedNdSbp> out,\n                                                           const Shape& logical_shape) const {\n  return DECORATE(&RawGetBoxingFunction, ThreadLocalCachedCopiable)(boxing_name_, in, out,\n                                                                    logical_shape);\n}\n\nMaybe<BoxingInterpreterStatus> DivideAndConquerBoxingExpr::Check(Symbol<PlacedNdSbp> in,\n                                                                 Symbol<PlacedNdSbp> out,\n                                                                 const Shape& logical_shape) const {\n  const auto& middle = JUST((*boxing_dividor_)(in, out));\n  const auto& lhs_status = JUST(lhs_conquer_->Check(in, middle, logical_shape));\n  const auto& rhs_status = JUST(rhs_conquer_->Check(middle, out, logical_shape));\n  return MakeComposedBoxingInterpreterStatus(lhs_status, rhs_status);\n}\n\nMaybe<BoxingFunctionT> DivideAndConquerBoxingExpr::GetBoxingFunction(\n    Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out, const Shape& logical_shape) const {\n  const auto& middle = JUST((*boxing_dividor_)(in, out));\n  const auto& lhs_boxing_func = JUST(lhs_conquer_->GetBoxingFunction(in, middle, logical_shape));\n  const auto& rhs_boxing_func = JUST(rhs_conquer_->GetBoxingFunction(middle, out, logical_shape));\n  BoxingFunctionT boxing_function =\n      [lhs_boxing_func, rhs_boxing_func, middle, in, out, &logical_shape](\n          const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> arg_in,\n          Symbol<PlacedNdSbp> arg_out) -> Maybe<one::Tensor> 
{\n    // Always true; if a check below fails, there is a bug in oneflow that needs to be resolved.\n    CHECK_OR_RETURN(in == arg_in) << Error::RuntimeError() << \"The placement (\"\n                                  << *JUST(PlacementToString(arg_in->placement())) << \") and sbp (\"\n                                  << NdSbpToString(arg_in->nd_sbp())\n                                  << \") of input tensor must match the placement (\"\n                                  << *JUST(PlacementToString(in->placement())) << \") and sbp (\"\n                                  << NdSbpToString(in->nd_sbp())\n                                  << \") used to get this boxing function! Please submit an issue \"\n                                     \"in `https://github.com/Oneflow-Inc/oneflow/issues` \"\n                                     \"and we will fix it as soon as possible\";\n    CHECK_OR_RETURN(logical_shape == *tensor->shape())\n        << Error::RuntimeError() << \"The logical_shape \" << tensor->shape()->ToString()\n        << \" of input tensor must match the logical_shape \" << logical_shape.ToString()\n        << \" used to get this boxing function! Please submit an issue in \"\n           \"`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it \"\n           \"as soon as possible\";\n    CHECK_OR_RETURN(out == arg_out)\n        << Error::RuntimeError() << \"The placement (\"\n        << *JUST(PlacementToString(arg_out->placement())) << \") and sbp (\"\n        << NdSbpToString(arg_out->nd_sbp()) << \") of output tensor must match the placement (\"\n        << *JUST(PlacementToString(out->placement())) << \") and sbp (\"\n        << NdSbpToString(out->nd_sbp())\n        << \") used to get this boxing function! Please submit \"\n           \"an issue in `https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it \"\n           \"as soon as possible\";\n    const auto& middle_tensor = JUST((*lhs_boxing_func)(tensor, in, middle));\n    return JUST((*rhs_boxing_func)(middle_tensor, middle, out));\n  };\n  return boxing_function;\n}\n\nMaybe<BoxingInterpreterStatus> OrBoxingExpr::Check(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                                   const Shape& logical_shape) const {\n  const auto& lhs_status = TRY(lhs_boxing_->Check(in, out, logical_shape));\n  if (lhs_status.IsOk()) { return lhs_status; }\n  return rhs_boxing_->Check(in, out, logical_shape);\n}\n\nMaybe<BoxingFunctionT> OrBoxingExpr::GetBoxingFunction(Symbol<PlacedNdSbp> in,\n                                                       Symbol<PlacedNdSbp> out,\n                                                       const Shape& logical_shape) const {\n  if (lhs_boxing_->Check(in, out, logical_shape).IsOk()) {\n    return lhs_boxing_->GetBoxingFunction(in, out, logical_shape);\n  }\n  JUST(rhs_boxing_->Check(in, out, logical_shape));\n  return rhs_boxing_->GetBoxingFunction(in, out, logical_shape);\n}\n\nMaybe<BoxingExprIf> BoxingExpr(const std::string& boxing_name) {\n  JUST(MapAt(*MutName2BoxingChecker(), boxing_name));\n  auto boxing_expr = std::make_unique<AtomicBoxingExpr>(boxing_name);\n  return std::shared_ptr<BoxingExprIf>(std::move(boxing_expr));\n}\n\nMaybe<BoxingExprIf> BoxingExpr(const std::shared_ptr<BoxingDividor>& boxing_dividor,\n                               const std::string& lhs_conquer, const std::string& rhs_conquer) {\n  return BoxingExpr(boxing_dividor, JUST(BoxingExpr(lhs_conquer)), JUST(BoxingExpr(rhs_conquer)));\n}\n\nMaybe<BoxingExprIf> 
BoxingExpr(const std::shared_ptr<BoxingDividor>& boxing_dividor,\n                               const std::shared_ptr<BoxingExprIf>& lhs_conquer,\n                               const std::string& rhs_conquer) {\n  return BoxingExpr(boxing_dividor, lhs_conquer, JUST(BoxingExpr(rhs_conquer)));\n}\n\nMaybe<BoxingExprIf> BoxingExpr(const std::shared_ptr<BoxingDividor>& boxing_dividor,\n                               const std::string& lhs_conquer,\n                               const std::shared_ptr<BoxingExprIf>& rhs_conquer) {\n  return BoxingExpr(boxing_dividor, JUST(BoxingExpr(lhs_conquer)), rhs_conquer);\n}\n\nMaybe<BoxingExprIf> BoxingExpr(const std::shared_ptr<BoxingDividor>& boxing_dividor,\n                               const std::shared_ptr<BoxingExprIf>& lhs_conquer,\n                               const std::shared_ptr<BoxingExprIf>& rhs_conquer) {\n  auto divide_and_conquer =\n      std::make_unique<DivideAndConquerBoxingExpr>(boxing_dividor, lhs_conquer, rhs_conquer);\n  return std::shared_ptr<BoxingExprIf>(std::move(divide_and_conquer));\n}\n\nstd::shared_ptr<BoxingExprIf> operator|(const std::shared_ptr<BoxingExprIf>& lhs_boxing,\n                                        const std::shared_ptr<BoxingExprIf>& rhs_boxing) {\n  auto or_boxing = std::make_unique<OrBoxingExpr>(lhs_boxing, rhs_boxing);\n  return std::shared_ptr<BoxingExprIf>(std::move(or_boxing));\n}\n\nMaybe<BoxingExprIf> OptionalBoxing(const std::string& boxing_name) {\n  return JUST(BoxingExpr(boxing_name)) | JUST(BoxingExpr(\"identity\"));\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/eager_boxing_interpreter.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_BOXING_EAGER_BOXING_INTERPRETER_H_\n#define ONEFLOW_CORE_BOXING_EAGER_BOXING_INTERPRETER_H_\n\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/boxing/boxing_dividor.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/placed_nd_sbp.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/boxing/boxing_interpreter_status.h\"\n\nnamespace oneflow {\n\nclass EagerBoxingInterpreter {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(EagerBoxingInterpreter);\n  EagerBoxingInterpreter() = default;\n  virtual ~EagerBoxingInterpreter() = default;\n\n  Maybe<one::Tensor> Interpret(const std::shared_ptr<one::Tensor>& input, Symbol<NdSbp> in_nd_sbp,\n                               Symbol<NdSbp> out_nd_sbp, Symbol<ParallelDesc> in_parallel_desc,\n                               Symbol<ParallelDesc> out_parallel_desc) const;\n  virtual Maybe<BoxingInterpreterStatus> boxing_interpreter_status() const = 0;\n\n protected:\n  virtual Maybe<one::Tensor> InterpretImpl(const std::shared_ptr<one::Tensor>& input,\n                                           Symbol<NdSbp> in_nd_sbp, Symbol<NdSbp> out_nd_sbp,\n                                           Symbol<ParallelDesc> in_parallel_desc,\n                                           Symbol<ParallelDesc> out_parallel_desc) const = 0;\n};\n\nusing BoxingCheckerT = std::function<Maybe<void>(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                                 const Shape& logical_shape)>;\nusing BoxingFunctionT = std::function<Maybe<one::Tensor>(\n    const std::shared_ptr<one::Tensor>& input, Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out)>;\n\nMaybe<BoxingFunctionT> GetBoxingFunction(const std::string& method_name, Symbol<PlacedNdSbp> in,\n                                         Symbol<PlacedNdSbp> out, const Shape& logical_shape);\n\nvoid RegisterBoxingFunction(const std::string& method_name, const BoxingCheckerT& Check,\n                            const BoxingFunctionT& BoxingFunction);\n\ninline void RegisterBoxingFunction(\n    const std::string& method_name,\n    const std::pair<BoxingCheckerT, BoxingFunctionT>& CheckAndBoxing) {\n  RegisterBoxingFunction(method_name, CheckAndBoxing.first, CheckAndBoxing.second);\n}\n\nclass NaiveEagerBoxingInterpreter : public EagerBoxingInterpreter {\n public:\n  explicit NaiveEagerBoxingInterpreter(\n      const std::shared_ptr<BoxingFunctionT>& boxing_function,\n      const std::shared_ptr<BoxingInterpreterStatus>& boxing_interpreter_status)\n      : boxing_function_(boxing_function), boxing_interpreter_status_(boxing_interpreter_status) {}\n  NaiveEagerBoxingInterpreter(const NaiveEagerBoxingInterpreter&) = delete;\n  NaiveEagerBoxingInterpreter(NaiveEagerBoxingInterpreter&&) = delete;\n  
~NaiveEagerBoxingInterpreter() override = default;\n\n  Maybe<BoxingInterpreterStatus> boxing_interpreter_status() const override {\n    return boxing_interpreter_status_;\n  }\n\n private:\n  Maybe<one::Tensor> InterpretImpl(const std::shared_ptr<one::Tensor>& input,\n                                   Symbol<NdSbp> in_nd_sbp, Symbol<NdSbp> out_nd_sbp,\n                                   Symbol<ParallelDesc> in_parallel_desc,\n                                   Symbol<ParallelDesc> out_parallel_desc) const override {\n    const auto& in_placed_nd_sbp = JUST(PlacedNdSbp::New(in_nd_sbp, in_parallel_desc));\n    const auto& out_placed_nd_sbp = JUST(PlacedNdSbp::New(out_nd_sbp, out_parallel_desc));\n    return JUST((*boxing_function_)(input, in_placed_nd_sbp, out_placed_nd_sbp));\n  }\n\n  const std::shared_ptr<BoxingFunctionT> boxing_function_;\n  const std::shared_ptr<BoxingInterpreterStatus> boxing_interpreter_status_;\n};\n\nclass BoxingExprIf {\n public:\n  BoxingExprIf(const BoxingExprIf&) = default;\n  BoxingExprIf(BoxingExprIf&&) = default;\n  virtual ~BoxingExprIf() = default;\n\n  virtual Maybe<BoxingInterpreterStatus> Check(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                               const Shape& logical_shape) const = 0;\n  virtual Maybe<BoxingFunctionT> GetBoxingFunction(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                                   const Shape& logical_shape) const = 0;\n\n protected:\n  BoxingExprIf() = default;\n};\n\nclass AtomicBoxingExpr final : public BoxingExprIf {\n public:\n  AtomicBoxingExpr(const AtomicBoxingExpr&) = delete;\n  AtomicBoxingExpr(AtomicBoxingExpr&&) = delete;\n  ~AtomicBoxingExpr() override = default;\n\n  explicit AtomicBoxingExpr(const std::string& boxing_name)\n      : BoxingExprIf(), boxing_name_(boxing_name) {}\n\n  Maybe<BoxingInterpreterStatus> Check(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                       const Shape& logical_shape) const override;\n  Maybe<BoxingFunctionT> GetBoxingFunction(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                           const Shape& logical_shape) const override;\n\n private:\n  const std::string boxing_name_;\n};\n\nclass DivideAndConquerBoxingExpr final : public BoxingExprIf {\n public:\n  DivideAndConquerBoxingExpr(const DivideAndConquerBoxingExpr&) = delete;\n  DivideAndConquerBoxingExpr(DivideAndConquerBoxingExpr&&) = delete;\n  ~DivideAndConquerBoxingExpr() override = default;\n\n  explicit DivideAndConquerBoxingExpr(const std::shared_ptr<BoxingDividor>& boxing_dividor,\n                                      const std::shared_ptr<BoxingExprIf>& lhs_conquer,\n                                      const std::shared_ptr<BoxingExprIf>& rhs_conquer)\n      : BoxingExprIf(),\n        boxing_dividor_(boxing_dividor),\n        lhs_conquer_(lhs_conquer),\n        rhs_conquer_(rhs_conquer) {}\n\n  Maybe<BoxingInterpreterStatus> Check(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                       const Shape& logical_shape) const override;\n  Maybe<BoxingFunctionT> GetBoxingFunction(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                           const Shape& logical_shape) const override;\n\n private:\n  const std::shared_ptr<BoxingDividor> boxing_dividor_;\n  const std::shared_ptr<BoxingExprIf> lhs_conquer_;\n  const std::shared_ptr<BoxingExprIf> rhs_conquer_;\n};\n\nclass OrBoxingExpr final : public BoxingExprIf {\n public:\n  
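// Fallback composition built by operator| (declared below): Check() returns the\n  // lhs status when lhs_boxing_ accepts the (in, out, logical_shape) triple and\n  // otherwise delegates to rhs_boxing_ (see OrBoxingExpr::Check in\n  // eager_boxing_interpreter.cpp).\n  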
OrBoxingExpr(const OrBoxingExpr&) = delete;\n  OrBoxingExpr(OrBoxingExpr&&) = delete;\n  ~OrBoxingExpr() override = default;\n\n  explicit OrBoxingExpr(const std::shared_ptr<BoxingExprIf>& lhs_boxing,\n                        const std::shared_ptr<BoxingExprIf>& rhs_boxing)\n      : BoxingExprIf(), lhs_boxing_(lhs_boxing), rhs_boxing_(rhs_boxing) {}\n\n  Maybe<BoxingInterpreterStatus> Check(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                       const Shape& logical_shape) const override;\n  Maybe<BoxingFunctionT> GetBoxingFunction(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                           const Shape& logical_shape) const override;\n\n private:\n  const std::shared_ptr<BoxingExprIf> lhs_boxing_;\n  const std::shared_ptr<BoxingExprIf> rhs_boxing_;\n};\n\nMaybe<BoxingExprIf> BoxingExpr(const std::string& boxing_name);\nMaybe<BoxingExprIf> BoxingExpr(const std::shared_ptr<BoxingDividor>& boxing_dividor,\n                               const std::string& lhs_conquer, const std::string& rhs_conquer);\nMaybe<BoxingExprIf> BoxingExpr(const std::shared_ptr<BoxingDividor>& boxing_dividor,\n                               const std::shared_ptr<BoxingExprIf>& lhs_conquer,\n                               const std::string& rhs_conquer);\nMaybe<BoxingExprIf> BoxingExpr(const std::shared_ptr<BoxingDividor>& boxing_dividor,\n                               const std::string& lhs_conquer,\n                               const std::shared_ptr<BoxingExprIf>& rhs_conquer);\nMaybe<BoxingExprIf> BoxingExpr(const std::shared_ptr<BoxingDividor>& boxing_dividor,\n                               const std::shared_ptr<BoxingExprIf>& lhs_conquer,\n                               const std::shared_ptr<BoxingExprIf>& rhs_conquer);\n\nstd::shared_ptr<BoxingExprIf> operator|(const std::shared_ptr<BoxingExprIf>& lhs_boxing,\n                                        const std::shared_ptr<BoxingExprIf>& rhs_boxing);\n\nMaybe<BoxingExprIf> OptionalBoxing(const std::string& boxing_name);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_BOXING_EAGER_BOXING_INTERPRETER_H_\n"
  },
  {
    "path": "oneflow/core/boxing/eager_boxing_interpreter_mgr.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <utility>\n#include \"oneflow/core/common/constant.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter_mgr.h\"\n#include \"oneflow/core/boxing/boxing_dividor_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<bool> IgnoringDeviceTypeEqual(Symbol<ParallelDesc> lhs, Symbol<ParallelDesc> rhs) {\n  if (lhs == rhs) { return true; }\n  return lhs == JUST(ReplaceDeviceType(rhs, lhs->device_type()));\n}\n\nnamespace {\n\nMaybe<BoxingExprIf> OptionalCudaCopy(const std::shared_ptr<BoxingExprIf>& core_boxing_expr) {\n  return JUST(BoxingExpr(JUST(ReplaceInDeviceType(DeviceType::kCUDA)),\n                         JUST(OptionalBoxing(\"copy-h2d\")),\n                         JUST(BoxingExpr(JUST(ReplaceOutDeviceType(DeviceType::kCUDA)),\n                                         core_boxing_expr, JUST(OptionalBoxing(\"copy-d2h\"))))));\n}\n\nMaybe<BoxingExprIf> OptionalCpuCopy(const std::shared_ptr<BoxingExprIf>& core_boxing_expr) {\n  return JUST(BoxingExpr(JUST(ReplaceInDeviceType(DeviceType::kCPU)),\n                         JUST(OptionalBoxing(\"copy-d2h\")),\n                         JUST(BoxingExpr(JUST(ReplaceOutDeviceType(DeviceType::kCPU)),\n                                         core_boxing_expr, JUST(OptionalBoxing(\"copy-h2d\"))))));\n}\n\nMaybe<BoxingExprIf> SymmetricOneDimSxToBBoxingExpr() {\n  return JUST(BoxingExpr(JUST(InPlacementAndSplit(0)), JUST(OptionalBoxing(\"ccl-s-to-s\")),\n                         JUST(BoxingExpr(\"ccl-s-to-b\"))));\n}\n\nMaybe<BoxingExprIf> SymmetricOneDimPToSxBoxingExpr() {\n  return JUST(BoxingExpr(JUST(OutPlacementAndSplit(0)), JUST(BoxingExpr(\"ccl-p-to-s\")),\n                         JUST(OptionalBoxing(\"ccl-s-to-s\"))));\n}\n\nMaybe<BoxingExprIf> SymmetricCyclicNDimToNDimBoxingExpr() {\n  return JUST(BoxingExpr(JUST(InPlacementAndRepeatFirstSbp()),\n                         JUST(BoxingExpr(\"symmetric-acyclic-nd-sbp-to-nd-sbp\")),\n                         JUST(BoxingExpr(\"symmetric-acyclic-nd-sbp-to-nd-sbp\"))))\n         | JUST(BoxingExpr(JUST(InPlacementAndBroadcast()),\n                           JUST(BoxingExpr(\"symmetric-acyclic-nd-sbp-to-nd-sbp\")),\n                           JUST(BoxingExpr(\"symmetric-acyclic-nd-sbp-to-nd-sbp\"))));\n}\n\nMaybe<BoxingExprIf> SymmetricNDimToNDimBoxingExpr() {\n  return JUST(BoxingExpr(\"symmetric-acyclic-nd-sbp-to-nd-sbp\"))\n         | JUST(SymmetricCyclicNDimToNDimBoxingExpr());\n}\n\nMaybe<BoxingExprIf> SymmetricOneDimToNDimBoxingExpr() {\n  return JUST(BoxingExpr(JUST(UnflattenInHierarchy()), JUST(BoxingExpr(\"unflatten-hierarchy\")),\n                         JUST(SymmetricNDimToNDimBoxingExpr()) | JUST(BoxingExpr(\"identity\"))));\n}\n\nMaybe<BoxingExprIf> SymmetricNDimToOneDimBoxingExpr() {\n  return JUST(BoxingExpr(JUST(UnflattenOutHierarchy()),\n                         
JUST(SymmetricNDimToNDimBoxingExpr()) | JUST(BoxingExpr(\"identity\")),\n                         JUST(BoxingExpr(\"flatten-hierarchy\"))));\n}\n\nMaybe<BoxingExprIf> NToOneBoxingExpr() {\n  return JUST(BoxingExpr(JUST(InPlacementAndBroadcast()),\n                         JUST(BoxingExpr(\"identity\")) | JUST(BoxingExpr(\"ccl-p-to-b\"))\n                             | JUST(SymmetricOneDimSxToBBoxingExpr())\n                             | JUST(BoxingExpr(\"naive-p-to-b\")) | JUST(BoxingExpr(\"naive-s-to-b\"))\n                             | JUST(SymmetricNDimToNDimBoxingExpr())\n                             | JUST(BoxingExpr(\"generic-symmetric-nd-sbp-to-nd-sbp\")),\n                         JUST(BoxingExpr(\"naive-b-to-1\"))));\n}\n\nMaybe<BoxingExprIf> OneToNBoxingExpr() {\n  return JUST(BoxingExpr(JUST(OutPlacementAndPartialSum()), JUST(BoxingExpr(\"naive-1-to-p\")),\n                         JUST(BoxingExpr(\"identity\")) | JUST(BoxingExpr(\"ccl-p-to-b\"))\n                             | JUST(SymmetricOneDimPToSxBoxingExpr())\n                             | JUST(BoxingExpr(\"naive-p-to-b\")) | JUST(BoxingExpr(\"naive-p-to-s\"))\n                             | JUST(SymmetricNDimToNDimBoxingExpr())\n                             | JUST(BoxingExpr(\"generic-symmetric-nd-sbp-to-nd-sbp\"))));\n}\n\nMaybe<BoxingExprIf> SymmetricOneDimXToBBoxingExpr() {\n  return JUST(BoxingExpr(\"ccl-p-to-b\"))\n         | JUST(BoxingExpr(JUST(InPlacementAndSplit(0)),\n                           JUST(BoxingExpr(\"identity\")) | JUST(BoxingExpr(\"ccl-s-to-s\")),\n                           JUST(BoxingExpr(\"ccl-s-to-b\"))));\n}\n\nMaybe<BoxingExprIf> ASymmetricOneDimXToBBoxingExpr() {\n  return JUST(BoxingExpr(JUST(InPlacementAndBroadcast()),\n                         JUST(BoxingExpr(\"identity\")) | JUST(SymmetricOneDimXToBBoxingExpr()),\n                         JUST(BoxingExpr(\"asymmetric-broadcast\"))));\n}\n\nMaybe<BoxingExprIf> GenericBoxingExpr() {\n  // in_placement contains out_placement or out_placement contains in_placement\n  const auto& boxing_expr_with_inclusive_placement =\n      JUST(BoxingExpr(JUST(OutPlacementAndBroadcast()), JUST(ASymmetricOneDimXToBBoxingExpr()),\n                      JUST(BoxingExpr(\"identity\")) | JUST(BoxingExpr(\"symmetric-b-to-p\"))\n                          | JUST(BoxingExpr(\"symmetric-b-to-s\"))));\n  // in_placement and out_placement have no containment relationship\n  // n to 1\n  const auto& lhs_boxing = JUST(NToOneBoxingExpr());\n  // 1 to 1 -> 1 to n\n  const auto& rhs_boxing =\n      JUST(BoxingExpr(JUST(OutFirstDeviceAndAllBroadcast()), JUST(OptionalBoxing(\"naive-1-to-1\")),\n                      JUST(OneToNBoxingExpr())));\n  return boxing_expr_with_inclusive_placement\n         | JUST(BoxingExpr(JUST(InFirstDeviceAndAllBroadcast()), lhs_boxing, rhs_boxing));\n}\n\nMaybe<BoxingExprIf> RawMainBoxingExpr() {\n  // clang-format off\n  const auto& core = JUST(BoxingExpr(\"identity\"))\n                     | JUST(BoxingExpr(\"copy-h2d\"))\n                     | JUST(BoxingExpr(\"copy-d2h\"))\n                     | JUST(BoxingExpr(\"ccl-p-to-b\"))\n                     | JUST(BoxingExpr(\"ccl-s-to-s\"))\n                     | JUST(SymmetricOneDimSxToBBoxingExpr())\n                     | JUST(SymmetricOneDimPToSxBoxingExpr())\n                     | JUST(BoxingExpr(\"symmetric-b-to-p\"))\n                     | JUST(BoxingExpr(\"symmetric-b-to-s\"))\n                     | JUST(BoxingExpr(\"symmetric-s-to-p\"))\n                     | 
JUST(SymmetricOneDimXToBBoxingExpr())\n                     | JUST(ASymmetricOneDimXToBBoxingExpr())\n                     | JUST(BoxingExpr(\"naive-1-to-1\"))\n                     | JUST(OneToNBoxingExpr())\n                     | JUST(NToOneBoxingExpr())\n                     | JUST(BoxingExpr(\"naive-s-to-s\"))\n                     | JUST(BoxingExpr(\"naive-s-to-b\"))\n                     | JUST(BoxingExpr(\"naive-b-to-s\"))\n                     | JUST(BoxingExpr(\"naive-p-to-b\"))\n                     | JUST(BoxingExpr(\"naive-p-to-s\"))\n                     | JUST(BoxingExpr(\"naive-s-to-p\"))\n                     | JUST(BoxingExpr(\"nd-sbp-dim-reduce\"))\n                     | JUST(SymmetricNDimToNDimBoxingExpr())\n                     | JUST(BoxingExpr(\"generic-symmetric-nd-sbp-to-nd-sbp\"))\n                     | JUST(SymmetricOneDimToNDimBoxingExpr())\n                     | JUST(SymmetricNDimToOneDimBoxingExpr())\n                     | JUST(GenericBoxingExpr());\n  // clang-format on\n  return core | JUST(OptionalCudaCopy(core)) | JUST(OptionalCpuCopy(core));\n}\n\n}  // namespace\n\nstatic constexpr auto* MainBoxingExpr = DECORATE(&RawMainBoxingExpr, ThreadLocalCached);\n\nMaybe<EagerBoxingInterpreter> GetBoxingInterpreter(Symbol<NdSbp> in_nd_sbp,\n                                                   Symbol<NdSbp> out_nd_sbp,\n                                                   Symbol<ParallelDesc> in_parallel_desc,\n                                                   Symbol<ParallelDesc> out_parallel_desc,\n                                                   const Shape& logical_shape) {\n  const auto& in = JUST(PlacedNdSbp::New(in_nd_sbp, in_parallel_desc));\n  const auto& out = JUST(PlacedNdSbp::New(out_nd_sbp, out_parallel_desc));\n  const auto& main_boxing_expr = JUST(MainBoxingExpr());\n  const auto& status = TRY(main_boxing_expr->Check(in, out, logical_shape));\n  if (status.IsOk()) {\n    const auto& boxing_func = JUST(main_boxing_expr->GetBoxingFunction(in, out, logical_shape));\n    return std::shared_ptr<EagerBoxingInterpreter>(\n        new NaiveEagerBoxingInterpreter(boxing_func, JUST(status)));\n  }\n\n  UNIMPLEMENTED_THEN_RETURN() << Error::RuntimeError() << \"global-to-global not supported\"\n                              << \". from_nd_sbp: \" << NdSbpToString(in_nd_sbp)\n                              << \", to_nd_sbp: \" << NdSbpToString(out_nd_sbp)\n                              << \", from_placement: \" << *JUST(PlacementToString(in_parallel_desc))\n                              << \", to_placement: \" << *JUST(PlacementToString(out_parallel_desc));\n}\n\nstatic constexpr auto* CachedGetBoxingInterpreter =\n    DECORATE(&GetBoxingInterpreter, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<EagerBoxingInterpreter> EagerBoxingInterpreterManager::GetEagerBoxingInterpreter(\n    Symbol<NdSbp> in_nd_sbp, Symbol<NdSbp> out_nd_sbp, Symbol<ParallelDesc> in_parallel_desc,\n    Symbol<ParallelDesc> out_parallel_desc, const Shape& logical_shape) const {\n  return JUST(CachedGetBoxingInterpreter(in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc,\n                                         logical_shape));\n}\n\nCOMMAND(\n    Singleton<EagerBoxingInterpreterManager>::SetAllocated(new EagerBoxingInterpreterManager()));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/eager_boxing_interpreter_mgr.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_BOXING_EAGER_BOXING_INTERPRETER_MGR_H_\n#define ONEFLOW_CORE_BOXING_EAGER_BOXING_INTERPRETER_MGR_H_\n\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n\nnamespace oneflow {\n\nclass EagerBoxingInterpreterManager final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(EagerBoxingInterpreterManager);\n  EagerBoxingInterpreterManager() = default;\n  virtual ~EagerBoxingInterpreterManager() = default;\n\n  Maybe<EagerBoxingInterpreter> GetEagerBoxingInterpreter(Symbol<NdSbp> in_nd_sbp,\n                                                          Symbol<NdSbp> out_nd_sbp,\n                                                          Symbol<ParallelDesc> in_parallel_desc,\n                                                          Symbol<ParallelDesc> out_parallel_desc,\n                                                          const Shape& logical_shape) const;\n};\n\ntemplate<typename RetT, typename... Args>\nstruct DisableRecusiveBoxingCall {\n  static_assert(is_maybe<RetT>::value, \"returned value type must be Maybe<T>.\");\n  template<RetT (*func)(Args...)>\n  static RetT Call(Args... arg) {\n    static thread_local bool disable_boxing = false;\n    CHECK_OR_RETURN(!disable_boxing);\n    disable_boxing = true;\n    RetT ret = func(arg...);\n    disable_boxing = false;\n    return ret;\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_BOXING_EAGER_BOXING_INTERPRETER_MGR_H_\n"
  },
  {
    "path": "oneflow/core/boxing/eager_boxing_logger.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n#include \"oneflow/core/boxing/eager_boxing_logger.h\"\n#include \"oneflow/core/boxing/boxing_interpreter_status.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass NullEagerBoxingLogger final : public EagerBoxingLogger {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NullEagerBoxingLogger);\n  NullEagerBoxingLogger() = default;\n  ~NullEagerBoxingLogger() override = default;\n\n  void Log(const BoxingInterpreterStatus& status, const std::string& prefix) const override {}\n};\n\nclass NaiveEagerBoxingLogger final : public EagerBoxingLogger {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NaiveEagerBoxingLogger);\n  NaiveEagerBoxingLogger() = default;\n  ~NaiveEagerBoxingLogger() override = default;\n\n  void Log(const BoxingInterpreterStatus& status, const std::string& prefix) const override {\n    LOG(INFO) << prefix << \"Boxing route: \" << (status.boxing_routing());\n    LOG(INFO) << prefix << \"Logical shape: \" << (status.logical_shape().ToString());\n    LOG(INFO) << prefix << \"Altered state of sbp: \" << (status.nd_sbp_routing());\n    LOG(INFO) << prefix << \"Altered state of placement: \" << (status.placement_routing());\n  }\n};\n\nconst EagerBoxingLogger* CreateEagerBoxingLogger() {\n  if (IsInDebugMode()) {\n    return new NaiveEagerBoxingLogger();\n  } else {\n    return new NullEagerBoxingLogger();\n  }\n}\n\n}  // namespace\n\nCOMMAND(Singleton<const EagerBoxingLogger>::SetAllocated(CreateEagerBoxingLogger()));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/eager_boxing_logger.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_BOXING_EAGER_BOXING_LOGGER_H_\n#define ONEFLOW_CORE_BOXING_EAGER_BOXING_LOGGER_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nclass BoxingInterpreterStatus;\n\nclass EagerBoxingLogger {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(EagerBoxingLogger);\n  EagerBoxingLogger() = default;\n  virtual ~EagerBoxingLogger() = default;\n\n  virtual void Log(const BoxingInterpreterStatus& status, const std::string& prefix) const = 0;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_BOXING_EAGER_BOXING_LOGGER_H_\n"
  },
  {
    "path": "oneflow/core/boxing/flatten_hierarchy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<void> RawCheckFlattenHierarchy(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                     const Shape& logical_shape) {\n  CHECK_GT_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);\n  for (int i = 0; i < in->nd_sbp()->sbp_parallel_size(); ++i) {\n    const auto& sbp_parallel = in->nd_sbp()->sbp_parallel(i);\n    CHECK_OR_RETURN(sbp_parallel == out->nd_sbp()->sbp_parallel(0)) << \"nd_sbp axis: \" << i;\n  }\n  CHECK_EQ_OR_RETURN(in->placement()->device_type(), out->placement()->device_type());\n  CHECK_EQ_OR_RETURN(in->placement()->parallel_num(), out->placement()->parallel_num());\n  ParallelConf flattened_parallel_conf(in->placement()->parallel_conf());\n  flattened_parallel_conf.clear_hierarchy();\n  const auto& flatten_placement = SymbolOf(ParallelDesc(flattened_parallel_conf));\n  CHECK_OR_RETURN(flatten_placement == out->placement())\n      << \"The output placement is not a hierarch-flattened version of the input placement\";\n  for (int64_t in_parallel_id = 0; in_parallel_id < in->placement()->parallel_num();\n       ++in_parallel_id) {\n    const auto& in_physical_shape =\n        JUST(GetPhysicalShape(logical_shape, *in->nd_sbp(), *in->placement(), in_parallel_id));\n    const auto& out_physical_shape =\n        JUST(GetPhysicalShape(logical_shape, *out->nd_sbp(), *out->placement(), in_parallel_id));\n    CHECK_EQ_OR_RETURN(*in_physical_shape, *out_physical_shape);\n  }\n  return Maybe<void>::Ok();\n}\n// NOLINTEND(maybe-need-error-msg)\n\n}  // namespace\n\nstatic constexpr auto* CheckFlattenHierarchy =\n    DECORATE(&RawCheckFlattenHierarchy, ThreadLocalCachedCopiable);\n\nMaybe<one::Tensor> FlattenHierarchy(const std::shared_ptr<one::Tensor>& tensor,\n                                    Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n  const auto& local_tensor = JUST(tensor->cur_rank_phy_tensor());\n  const auto& sbp_list = 
JUST(GetSbpList(out->nd_sbp()));\n  return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,\n                                             *tensor->shape(), tensor->dtype(),\n                                             /* sync_data */ false, /*copy=*/true));\n}\n\nCOMMAND(RegisterBoxingFunction(\"flatten-hierarchy\", CheckFlattenHierarchy, &FlattenHierarchy));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/generic_symmetric_nd_sbp_boxing.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/boxing/eager_boxing_interpreter_mgr.h\"\n#include \"oneflow/core/boxing/eager_boxing_logger.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/framework/placement_sbp_util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/stride.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool RawIsAllBroadcastNdSbpAfterDim(Symbol<NdSbp> nd_sbp, int dim) {\n  for (int i = dim; i < nd_sbp->sbp_parallel_size(); ++i) {\n    if (!nd_sbp->sbp_parallel(i).has_broadcast_parallel()) { return false; }\n  }\n  return true;\n}\n\nstatic constexpr auto* IsAllBroadcastNdSbpAfterDim =\n    DECORATE(&RawIsAllBroadcastNdSbpAfterDim, ThreadLocalCached);\n\nMaybe<Symbol<SbpParallel>> GetBroadcastSbp() {\n  SbpParallel broadcast_sbp;\n  broadcast_sbp.mutable_broadcast_parallel();\n  return SymbolOf(broadcast_sbp);\n}\n\nauto* CachedGetBroadcastSbp = DECORATE(&GetBroadcastSbp, ThreadLocalCached);\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<Shape> CalcLogicalShape4Axis(const Shape& logical_shape, int axis,\n                                   Symbol<ParallelDesc> parallel_desc, Symbol<NdSbp> nd_sbp) {\n  CHECK_LT_OR_RETURN(axis, nd_sbp->sbp_parallel_size());  // Always true\n  std::shared_ptr<Shape> sub_logical_shape = std::make_shared<Shape>(logical_shape);\n\n  const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc));\n  int64_t parallel_id = JUST(*opt_parallel_id);\n  const auto& hierarchy_shape = *parallel_desc->hierarchy();\n  Stride hierarchy_stride(hierarchy_shape);\n\n  FOR_RANGE(int64_t, i, 0, axis) {\n    const auto& sbp_parallel = nd_sbp->sbp_parallel(i);\n    if (sbp_parallel.has_split_parallel()) {\n      int64_t index = CalcIndex4Axis(parallel_id, hierarchy_stride, i);\n      int64_t dim = hierarchy_shape.At(i);\n      const int64_t split_axis = sbp_parallel.split_parallel().axis();\n\n      if (sub_logical_shape->At(split_axis) > 0) {\n        CHECK_GE_OR_RETURN(sub_logical_shape->At(split_axis), dim)\n            << Error::RuntimeError() << \"The size of tensor (\" << sub_logical_shape->At(split_axis)\n            << \") at split dimension (\" << i\n            << \") should be greater than or equal to parallle num (\" << dim << \")\";\n        const BalancedSplitter bs(sub_logical_shape->At(split_axis), dim);\n        sub_logical_shape->Set(split_axis, bs.At(index).size());\n      }\n    }\n  }\n\n  return sub_logical_shape;\n}\n\nstatic constexpr auto* GetLogicalShape4Axis =\n    DECORATE(&CalcLogicalShape4Axis, ThreadLocalCachedCopiable);\n\nMaybe<int> CalcTheFirstDiffAxisBetweenTwoNdSbp(Symbol<NdSbp> in_nd_sbp, Symbol<NdSbp> out_nd_sbp) {\n  
CHECK_EQ_OR_RETURN(in_nd_sbp->sbp_parallel_size(),\n                     out_nd_sbp->sbp_parallel_size());  // Always true\n  int dim = 0;\n  for (; dim < in_nd_sbp->sbp_parallel_size(); ++dim) {\n    if (in_nd_sbp->sbp_parallel(dim) != out_nd_sbp->sbp_parallel(dim)) { break; }\n  }\n  return dim;\n}\n\n// Delegates one 1D boxing step to the eager boxing interpreter and logs the conversion.\nMaybe<one::Tensor> Apply1DBoxing(const std::shared_ptr<one::Tensor>& input, Symbol<NdSbp> in_nd_sbp,\n                                 Symbol<NdSbp> out_nd_sbp, Symbol<ParallelDesc> in_parallel_desc,\n                                 Symbol<ParallelDesc> out_parallel_desc) {\n  const auto& boxing_interpreter =\n      JUST(Singleton<EagerBoxingInterpreterManager>::Get()->GetEagerBoxingInterpreter(\n          in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc, *input->shape()));\n  Singleton<const EagerBoxingLogger>::Get()->Log(\n      *JUST(boxing_interpreter->boxing_interpreter_status()),\n      /* prefix */ \"\\t\\tInternal boxing of generic-symmetric-nd-sbp-to-nd-sbp, \");\n  return JUST(boxing_interpreter->Interpret(input, in_nd_sbp, out_nd_sbp, in_parallel_desc,\n                                            out_parallel_desc));\n}\n\nMaybe<void> RawCheckGenericSymmetricNdSbpBoxing(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                                const Shape& logical_shape) {\n  CHECK_OR_RETURN(in->placement() == out->placement());\n  CHECK_OR_RETURN(in->nd_sbp() != out->nd_sbp());\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), out->nd_sbp()->sbp_parallel_size());\n  CHECK_GT_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  return Maybe<void>::Ok();\n}\n// NOLINTEND(maybe-need-error-msg)\n\nstatic constexpr auto* CheckGenericSymmetricNdSbpBoxing =\n    DECORATE(&RawCheckGenericSymmetricNdSbpBoxing, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<one::Tensor> GenericSymmetricNdSbpBoxing(const std::shared_ptr<one::Tensor>& input,\n                                               Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {\n  const auto& in_parallel_desc = in->placement();\n  const auto& out_nd_sbp = out->nd_sbp();\n  const auto& out_parallel_desc = out->placement();\n  std::shared_ptr<one::Tensor> output;\n\n  const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_parallel_desc));\n  if (out_parallel_id->has_value()) {\n    output = input;\n\n    int first_diff_sbp_dim = JUST(CalcTheFirstDiffAxisBetweenTwoNdSbp(in->nd_sbp(), out_nd_sbp));\n    Symbol<SbpParallel> broadcast_sbp = JUST(CachedGetBroadcastSbp());\n\n    const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(in_parallel_desc));\n    int64_t parallel_id = JUST(*opt_parallel_id);\n    const auto& hierarchy_shape = *in_parallel_desc->hierarchy();\n    Stride hierarchy_stride(hierarchy_shape);\n\n    const auto& logical_shape = input->shape();\n\n    // Convert input to broadcast tensor step by step\n    // e.g.\n    // If in_nd_sbp is (S(0), B, S(0)) and out_nd_sbp is (S(0), S(0), S(1))\n    // Altered state of sbp is (S(0), B, S(0)) -> (S(0), B, B)\n    for (int64_t i = out_nd_sbp->sbp_parallel_size() - 1; i >= first_diff_sbp_dim; --i) {\n      const auto& nd_sbp = JUST(output->nd_sbp());\n      const auto& sbp_parallel = nd_sbp->sbp_parallel(i);\n      if (sbp_parallel.has_broadcast_parallel()) { continue; }\n\n      const auto& one_dim_nd_sbp = JUST(SbpToNdSbp(sbp_parallel));\n      const auto& sub_logical_shape =\n          *JUST(GetLogicalShape4Axis(*logical_shape, i, in_parallel_desc, nd_sbp));\n      std::shared_ptr<one::Tensor> local_tensor = 
JUST(output->cur_rank_phy_tensor());\n      const auto& sub_parallel_desc = JUST(CalcSubParallelDesc4Axis(in_parallel_desc, i));\n\n      int64_t index = CalcIndex4Axis(parallel_id, hierarchy_stride, i);\n\n      const auto& physical_shape =\n          JUST(GetPhysicalShape(sub_logical_shape, *one_dim_nd_sbp, *sub_parallel_desc, index));\n      CHECK_EQ_OR_RETURN(*physical_shape, *local_tensor->shape())\n          << Error::RuntimeError() << \"Invalid input tensor, size of local tensor (\"\n          << local_tensor->shape()->ToString() << \") does not match global tensor (\"\n          << logical_shape->ToString() << \")!\";\n      std::shared_ptr<one::Tensor> sub_global_tensor = JUST(one::functional::LocalToGlobal(\n          local_tensor, sub_parallel_desc, *JUST(GetSbpList(one_dim_nd_sbp)), sub_logical_shape,\n          local_tensor->dtype(), /* sync_data */ false, /*copy=*/false));\n\n      sub_global_tensor =\n          JUST(Apply1DBoxing(sub_global_tensor, one_dim_nd_sbp, JUST(SbpToNdSbp(broadcast_sbp)),\n                             sub_parallel_desc, sub_parallel_desc));\n\n      local_tensor = JUST(sub_global_tensor->cur_rank_phy_tensor());\n\n      const auto& new_nd_sbp = JUST(SetSbpAtAxis(*nd_sbp, *broadcast_sbp, i));\n\n      output = JUST(one::functional::LocalToGlobal(\n          local_tensor, in_parallel_desc, *JUST(GetSbpList(new_nd_sbp)), *logical_shape,\n          local_tensor->dtype(), /* sync_data */ false, /*copy=*/false));\n    }\n\n    CHECK_OR_RETURN(IsAllBroadcastNdSbpAfterDim(JUST(output->nd_sbp()), first_diff_sbp_dim))\n        << Error::RuntimeError()\n        << \"Compute generic-symmetric-nd-sbp-to-nd-sbp failed. Please submit an issue in \"\n           \"`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as \"\n           \"possible\";\n\n    // Convert broadcast tensor to output with out_nd_sbp data step by step\n    // e.g.\n    // If out_nd_sbp is (S(0), S(0), S(1))\n    // Altered state of sbp is (S(0), B, B) -> (S(0), S(0), B) -> (S(0), S(0), S(1))\n    std::shared_ptr<Shape> sub_logical_shape = JUST(GetLogicalShape4Axis(\n        *logical_shape, first_diff_sbp_dim, in_parallel_desc, JUST(output->nd_sbp())));\n    for (int64_t i = first_diff_sbp_dim; i < out_nd_sbp->sbp_parallel_size(); ++i) {\n      const auto& sbp_parallel = out_nd_sbp->sbp_parallel(i);\n      if (sbp_parallel.has_broadcast_parallel()) { continue; }\n\n      const auto& nd_sbp = JUST(output->nd_sbp());\n\n      const auto& sub_parallel_desc = JUST(CalcSubParallelDesc4Axis(in_parallel_desc, i));\n\n      std::shared_ptr<one::Tensor> local_tensor = JUST(output->cur_rank_phy_tensor());\n\n      std::shared_ptr<one::Tensor> sub_global_tensor = JUST(one::functional::LocalToGlobal(\n          local_tensor, sub_parallel_desc, *JUST(GetSbpList(JUST(SbpToNdSbp(broadcast_sbp)))),\n          *sub_logical_shape, local_tensor->dtype(), /* sync_data */ false, /*copy=*/false));\n\n      const auto& one_dim_nd_sbp = JUST(SbpToNdSbp(sbp_parallel));\n      sub_global_tensor = JUST(Apply1DBoxing(sub_global_tensor, JUST(SbpToNdSbp(broadcast_sbp)),\n                                             one_dim_nd_sbp, sub_parallel_desc, sub_parallel_desc));\n\n      local_tensor = JUST(sub_global_tensor->cur_rank_phy_tensor());\n\n      int64_t index = CalcIndex4Axis(parallel_id, hierarchy_stride, i);\n      const auto& physical_shape =\n          JUST(GetPhysicalShape(*sub_logical_shape, *one_dim_nd_sbp, *sub_parallel_desc, index));\n      CHECK_EQ_OR_RETURN(*physical_shape, 
*local_tensor->shape())\n          << Error::RuntimeError()\n          << \"Compute generic-symmetric-nd-sbp-to-nd-sbp failed. Please submit an issue in \"\n             \"`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as \"\n             \"possible\";\n\n      const auto& new_nd_sbp = JUST(SetSbpAtAxis(*nd_sbp, sbp_parallel, i));\n\n      output = JUST(one::functional::LocalToGlobal(\n          local_tensor, in_parallel_desc, *JUST(GetSbpList(new_nd_sbp)), *logical_shape,\n          local_tensor->dtype(), /* sync_data */ false, /*copy=*/false));\n      // The physical shape at this axis is the logical shape of the next axis\n      sub_logical_shape = physical_shape;\n    }\n  } else {\n    one::GlobalTensorMeta tensor_meta(*input->shape(), input->dtype()->data_type(),\n                                      input->memory_format(), out_nd_sbp, out_parallel_desc);\n    const auto& tensor_impl =\n        JUST(one::EagerGlobalTensorImpl::New(SymbolOf(tensor_meta), input->requires_grad(), false));\n    output = std::make_shared<one::GlobalTensor>(tensor_impl);\n  }\n\n  return output;\n}\n\nCOMMAND(RegisterBoxingFunction(\"generic-symmetric-nd-sbp-to-nd-sbp\",\n                               CheckGenericSymmetricNdSbpBoxing, &GenericSymmetricNdSbpBoxing));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/identity_boxing_interpreter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\nnamespace oneflow {\n\nnamespace {\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<void> RawCheckIdentity(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                             const Shape& logical_shape) {\n  if (in->placement()->parallel_num() == 1) {\n    CHECK_OR_RETURN(in->placement()->EqualsIgnoringHierarchy(*out->placement()));\n    return Maybe<void>::Ok();\n  }\n  CHECK_OR_RETURN(in->placement() == out->placement());\n  CHECK_OR_RETURN(in->nd_sbp() == out->nd_sbp());\n  return Maybe<void>::Ok();\n}\n// NOLINTEND(maybe-need-error-msg)\n\n}  // namespace\n\nMaybe<one::Tensor> GetIdentity(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                               Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n  // reset sbp if parallel_num == 1 and reset transport_token\n  const auto& local_tensor = JUST(tensor->cur_rank_phy_tensor());\n  const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));\n  return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,\n                                             *tensor->shape(), tensor->dtype(),\n                                             /* sync_data */ false, /*copy=*/true));\n}\n\nCOMMAND(RegisterBoxingFunction(\"identity\", DECORATE(&RawCheckIdentity, ThreadLocalCachedCopiable),\n                               &GetIdentity));\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/naive_1_to_p_boxing.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/decorator.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool NdSbpIsAllPartialSum(Symbol<NdSbp> nd_sbp) {\n  for (const auto& sbp_parallel : nd_sbp->sbp_parallel()) {\n    if (!sbp_parallel.has_partial_sum_parallel()) { return false; }\n  }\n  return true;\n}\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<void> RawCheckNaive1ToP(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                              const Shape& logical_shape) {\n  CHECK_EQ_OR_RETURN(in->placement()->parallel_num(), 1);\n  CHECK_OR_RETURN(NdSbpIsAllPartialSum(out->nd_sbp()));\n  CHECK_OR_RETURN(out->placement()->Bigger(*in->placement()));\n  return Maybe<void>::Ok();\n}\n// NOLINTEND(maybe-need-error-msg)\n\nstatic constexpr auto* CheckNaive1ToP = DECORATE(&RawCheckNaive1ToP, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<one::Tensor> Naive1ToP(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                             Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n\n  int64_t root = JUST(tensor_placement->MachineId4ParallelId(0));\n  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());\n  const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement()));\n  if (root == GlobalProcessCtx::Rank() || !out_parallel_id->has_value()) {\n    // do nothing\n  } else {\n    const std::string& device_type = tensor_placement->device_tag();\n    local_tensor = JUST(one::functional::Constant(*tensor->shape(), 0, tensor->dtype(),\n                                                  JUST(Device::New(device_type))));\n  }\n  return JUST(one::functional::LocalToGlobal(\n      local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),\n      tensor->dtype(), /* sync_data */ false, /*copy=*/true));\n}\n\nCOMMAND(RegisterBoxingFunction(\"naive-1-to-p\", CheckNaive1ToP, &Naive1ToP));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/naive_b_to_1_boxing.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/decorator.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> RawCheckNaiveBTo1(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                              const Shape& logical_shape) {\n  // NOLINTBEGIN(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(out->placement()->parallel_num(), 1);\n  CHECK_OR_RETURN(NdSbpIsAllBroadcast(*in->nd_sbp()));\n  CHECK_OR_RETURN(in->placement()->Bigger(*out->placement()));\n  // NOLINTEND(maybe-need-error-msg)\n  return Maybe<void>::Ok();\n}\n\nstatic constexpr auto* CheckNaiveBTo1 = DECORATE(&RawCheckNaiveBTo1, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<one::Tensor> NaiveBTo1(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                             Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n\n  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());\n  return JUST(one::functional::LocalToGlobal(\n      local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),\n      tensor->dtype(), /* sync_data */ false, /*copy=*/true));\n}\n\nCOMMAND(RegisterBoxingFunction(\"naive-b-to-1\", CheckNaiveBTo1, &NaiveBTo1));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/naive_b_to_s_boxing.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/boxing/slice_boxing_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool RawIsSplitSbp(Symbol<SbpParallel> sbp_parallel) { return sbp_parallel->has_split_parallel(); }\n\nstatic constexpr auto* IsSplitSbp = DECORATE(&RawIsSplitSbp, ThreadLocalCached);\n\nbool RawIsBroadcastSbp(Symbol<SbpParallel> sbp_parallel) {\n  return sbp_parallel->has_broadcast_parallel();\n}\n\nstatic constexpr auto* IsBroadcastSbp = DECORATE(&RawIsBroadcastSbp, ThreadLocalCached);\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<void> RawCheckNaiveBToS(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                              const Shape& logical_shape) {\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);\n\n  CHECK_OR_RETURN(IsBroadcastSbp(in->nd_sbp()->sbp_parallel(0)));\n  CHECK_OR_RETURN(IsSplitSbp(out->nd_sbp()->sbp_parallel(0)));\n\n  CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag());\n  return Maybe<void>::Ok();\n}\n// NOLINTEND(maybe-need-error-msg)\n\nstatic constexpr auto* CheckNaiveBToS = DECORATE(&RawCheckNaiveBToS, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<one::Tensor> NaiveBToS(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                             Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n  const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));\n  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());\n  {\n    const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(tensor_placement));\n    const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement()));\n    if (in_parallel_id->has_value() || out_parallel_id->has_value()) {\n      local_tensor = JUST(one::functional::EagerBToS(\n          local_tensor, tensor_placement, out->placement(), *sbp_list, *tensor->shape()));\n    }\n  }\n\n  return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,\n                               
              *tensor->shape(), tensor->dtype(),\n                                             /* sync_data */ false, /*copy=*/false));\n}\n\nstatic constexpr auto* NaiveBToSWithAutoConvert =\n    EAGER_SLICE_BOXING_WARPPER(&NaiveBToS, EagerSliceBoxingType::kNaiveBToS);\n\nCOMMAND(RegisterBoxingFunction(\"naive-b-to-s\", CheckNaiveBToS, NaiveBToSWithAutoConvert));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/naive_p_to_b_boxing.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/boxing/slice_boxing_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool RawIsPartialSumSbp(Symbol<SbpParallel> sbp_parallel) {\n  return sbp_parallel->has_partial_sum_parallel();\n}\n\nstatic constexpr auto* IsPartialSumSbp = DECORATE(&RawIsPartialSumSbp, ThreadLocalCached);\n\nbool RawIsBroadcastSbp(Symbol<SbpParallel> sbp_parallel) {\n  return sbp_parallel->has_broadcast_parallel();\n}\n\nstatic constexpr auto* IsBroadcastSbp = DECORATE(&RawIsBroadcastSbp, ThreadLocalCached);\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<void> RawCheckNaivePToB(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                              const Shape& logical_shape) {\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_OR_RETURN(IsPartialSumSbp(in->nd_sbp()->sbp_parallel(0)));\n  CHECK_OR_RETURN(IsBroadcastSbp(out->nd_sbp()->sbp_parallel(0)));\n  CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag());\n  return Maybe<void>::Ok();\n}\n// NOLINTEND(maybe-need-error-msg)\n\nstatic constexpr auto* CheckNaivePToB = DECORATE(&RawCheckNaivePToB, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<one::Tensor> NaivePToB(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                             Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());\n  {\n    const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(tensor_placement));\n    const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement()));\n    if (in_parallel_id->has_value() || out_parallel_id->has_value()) {\n      local_tensor = JUST(one::functional::EagerPToB(local_tensor, tensor_placement,\n                                                     out->placement(), *tensor->shape()));\n    }\n  }\n\n  const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));\n  return JUST(one::functional::LocalToGlobal(local_tensor, 
out->placement(), *sbp_list,\n                                             *tensor->shape(), tensor->dtype(),\n                                             /* sync_data */ false, /*copy=*/false));\n}\n\nstatic constexpr auto* NaivePToBWithAutoConvert =\n    EAGER_SLICE_BOXING_WARPPER(&NaivePToB, EagerSliceBoxingType::kNaivePToB);\n\nCOMMAND(RegisterBoxingFunction(\"naive-p-to-b\", CheckNaivePToB, NaivePToBWithAutoConvert));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/naive_p_to_s_boxing.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/boxing/slice_boxing_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool RawIsPartialSumSbp(Symbol<SbpParallel> sbp_parallel) {\n  return sbp_parallel->has_partial_sum_parallel();\n}\n\nstatic constexpr auto* IsPartialSumSbp = DECORATE(&RawIsPartialSumSbp, ThreadLocalCached);\n\nbool RawIsSplitSbp(Symbol<SbpParallel> sbp_parallel) { return sbp_parallel->has_split_parallel(); }\n\nstatic constexpr auto* IsSplitSbp = DECORATE(&RawIsSplitSbp, ThreadLocalCached);\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<void> RawCheckNaivePToS(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                              const Shape& logical_shape) {\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);\n\n  CHECK_OR_RETURN(IsPartialSumSbp(in->nd_sbp()->sbp_parallel(0)));\n  CHECK_OR_RETURN(IsSplitSbp(out->nd_sbp()->sbp_parallel(0)));\n  CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag());\n  return Maybe<void>::Ok();\n}\n// NOLINTEND(maybe-need-error-msg)\n\nstatic constexpr auto* CheckNaivePToS = DECORATE(&RawCheckNaivePToS, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<one::Tensor> NaivePToS(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                             Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n  const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));\n  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());\n  {\n    const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(tensor_placement));\n    const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement()));\n    if (in_parallel_id->has_value() || out_parallel_id->has_value()) {\n      local_tensor = JUST(one::functional::EagerPToS(\n          local_tensor, tensor_placement, out->placement(), *sbp_list, *tensor->shape()));\n    }\n  }\n\n  return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,\n                           
                  *tensor->shape(), tensor->dtype(),\n                                             /* sync_data */ true, /*copy=*/false));\n}\n\nstatic constexpr auto* NaivePToSWithAutoConvert =\n    EAGER_SLICE_BOXING_WARPPER(&NaivePToS, EagerSliceBoxingType::kNaivePToS);\n\nCOMMAND(RegisterBoxingFunction(\"naive-p-to-s\", CheckNaivePToS, NaivePToSWithAutoConvert));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/naive_s_to_b_boxing.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/boxing/slice_boxing_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool RawIsSplitSbp(Symbol<SbpParallel> sbp_parallel) { return sbp_parallel->has_split_parallel(); }\n\nstatic constexpr auto* IsSplitSbp = DECORATE(&RawIsSplitSbp, ThreadLocalCached);\n\nbool RawIsBroadcastSbp(Symbol<SbpParallel> sbp_parallel) {\n  return sbp_parallel->has_broadcast_parallel();\n}\n\nstatic constexpr auto* IsBroadcastSbp = DECORATE(&RawIsBroadcastSbp, ThreadLocalCached);\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<void> RawCheckNaiveSToB(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                              const Shape& logical_shape) {\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_OR_RETURN(IsSplitSbp(in->nd_sbp()->sbp_parallel(0)));\n  CHECK_OR_RETURN(IsBroadcastSbp(out->nd_sbp()->sbp_parallel(0)));\n  CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag());\n  return Maybe<void>::Ok();\n}\n// NOLINTEND(maybe-need-error-msg)\n\nstatic constexpr auto* CheckNaiveSToB = DECORATE(&RawCheckNaiveSToB, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<one::Tensor> NaiveSToB(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                             Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());\n  {\n    const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(tensor_placement));\n    const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement()));\n    if (in_parallel_id->has_value() || out_parallel_id->has_value()) {\n      local_tensor =\n          JUST(one::functional::EagerSToB(local_tensor, tensor_placement, out->placement(),\n                                          *JUST(GetSbpList(tensor_nd_sbp)), *tensor->shape()));\n    }\n  }\n\n  const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));\n  return JUST(one::functional::LocalToGlobal(local_tensor, 
out->placement(), *sbp_list,\n                                             *tensor->shape(), tensor->dtype(),\n                                             /* sync_data */ false, /*copy=*/false));\n}\n\nstatic constexpr auto* NaiveSToBWithAutoConvert =\n    EAGER_SLICE_BOXING_WARPPER(&NaiveSToB, EagerSliceBoxingType::kNaiveSToB);\n\nCOMMAND(RegisterBoxingFunction(\"naive-s-to-b\", CheckNaiveSToB, NaiveSToBWithAutoConvert));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/naive_s_to_p_boxing.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/boxing/slice_boxing_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool RawIsSplitSbp(Symbol<SbpParallel> sbp_parallel) { return sbp_parallel->has_split_parallel(); }\n\nstatic constexpr auto* IsSplitSbp = DECORATE(&RawIsSplitSbp, ThreadLocalCached);\n\nbool RawIsPartialSumSbp(Symbol<SbpParallel> sbp_parallel) {\n  return sbp_parallel->has_partial_sum_parallel();\n}\n\nstatic constexpr auto* IsPartialSumSbp = DECORATE(&RawIsPartialSumSbp, ThreadLocalCached);\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<void> RawCheckNaiveSToP(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                              const Shape& logical_shape) {\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_OR_RETURN(IsSplitSbp(in->nd_sbp()->sbp_parallel(0)));\n  CHECK_OR_RETURN(IsPartialSumSbp(out->nd_sbp()->sbp_parallel(0)));\n  CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag());\n  return Maybe<void>::Ok();\n}\n// NOLINTEND(maybe-need-error-msg)\n\nstatic constexpr auto* CheckNaiveSToP = DECORATE(&RawCheckNaiveSToP, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<one::Tensor> NaiveSToP(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                             Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());\n  {\n    const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(tensor_placement));\n    const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement()));\n    if (in_parallel_id->has_value() || out_parallel_id->has_value()) {\n      local_tensor =\n          JUST(one::functional::EagerSToP(local_tensor, tensor_placement, out->placement(),\n                                          *JUST(GetSbpList(tensor_nd_sbp)), *tensor->shape()));\n    }\n  }\n\n  const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));\n  return 
JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,\n                                             *tensor->shape(), tensor->dtype(),\n                                             /* sync_data */ false, /*copy=*/false));\n}\n\nstatic constexpr auto* NaiveSToPWithAutoConvert =\n    EAGER_SLICE_BOXING_WARPPER(&NaiveSToP, EagerSliceBoxingType::kNaiveSToP);\n\nCOMMAND(RegisterBoxingFunction(\"naive-s-to-p\", CheckNaiveSToP, NaiveSToPWithAutoConvert));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/naive_s_to_s_boxing.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/boxing/slice_boxing_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool RawIsSplitSbp(Symbol<SbpParallel> sbp_parallel) { return sbp_parallel->has_split_parallel(); }\n\nstatic constexpr auto* IsSplitSbp = DECORATE(&RawIsSplitSbp, ThreadLocalCached);\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<void> RawCheckNaiveSToS(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                              const Shape& logical_shape) {\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);\n\n  CHECK_OR_RETURN(IsSplitSbp(in->nd_sbp()->sbp_parallel(0)));\n  CHECK_OR_RETURN(IsSplitSbp(out->nd_sbp()->sbp_parallel(0)));\n\n  CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag());\n  return Maybe<void>::Ok();\n}\n// NOLINTEND(maybe-need-error-msg)\n\nstatic constexpr auto* CheckNaiveSToS = DECORATE(&RawCheckNaiveSToS, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<one::Tensor> NaiveSToS(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                             Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n  const auto& in_sbp_list = JUST(GetSbpList(tensor_nd_sbp));\n  const auto& out_sbp_list = JUST(GetSbpList(out->nd_sbp()));\n\n  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());\n  {\n    const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(tensor_placement));\n    const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement()));\n    if (in_parallel_id->has_value() || out_parallel_id->has_value()) {\n      local_tensor =\n          JUST(one::functional::EagerNaiveSToS(local_tensor, tensor_placement, out->placement(),\n                                               *in_sbp_list, *out_sbp_list, *tensor->shape()));\n    }\n  }\n\n  return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *out_sbp_list,\n                                             *tensor->shape(), tensor->dtype(),\n                   
                          /* sync_data */ false, /*copy=*/false));\n}\n\nstatic constexpr auto* NaiveSToSWithAutoConvert =\n    EAGER_SLICE_BOXING_WARPPER(&NaiveSToS, EagerSliceBoxingType::kNaiveSToS);\n\nCOMMAND(RegisterBoxingFunction(\"naive-s-to-s\", CheckNaiveSToS, NaiveSToSWithAutoConvert));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/nd_sbp_dim_reduce_boxing.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/boxing/eager_boxing_interpreter_mgr.h\"\n#include \"oneflow/core/boxing/eager_boxing_logger.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<std::tuple<Symbol<PlacedNdSbp>, Symbol<PlacedNdSbp>>> RawInOutPlacedNdSbpDimReduce(\n    Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out, const Shape& logical_shape) {\n  // reduce hierarchy\n  ParallelDesc reduced_in_placement = *in->placement();\n  ParallelDesc reduced_out_placement = *out->placement();\n  NdSbp reduced_in_nd_sbp;\n  NdSbp reduced_out_nd_sbp;\n  InOutParallelDimReduce(*in->placement(), *out->placement(), *in->nd_sbp(), *out->nd_sbp(),\n                         &reduced_in_placement, &reduced_out_placement, &reduced_in_nd_sbp,\n                         &reduced_out_nd_sbp, logical_shape);\n  return std::make_tuple(\n      JUST(PlacedNdSbp::New(SymbolOf(reduced_in_nd_sbp), SymbolOf(reduced_in_placement))),\n      JUST(PlacedNdSbp::New(SymbolOf(reduced_out_nd_sbp), SymbolOf(reduced_out_placement))));\n}\n\nconstexpr auto* InOutPlacedNdSbpDimReduce =\n    DECORATE(&RawInOutPlacedNdSbpDimReduce, ThreadLocalCachedCopiable);\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<void> RawCheckParallelDimReduce(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                      const Shape& logical_shape) {\n  CHECK_OR_RETURN(in->nd_sbp()->sbp_parallel_size() > 1 || out->nd_sbp()->sbp_parallel_size() > 1);\n  CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag());\n  Symbol<PlacedNdSbp> reduced_in;\n  Symbol<PlacedNdSbp> reduced_out;\n  std::tie(reduced_in, reduced_out) = *JUST(InOutPlacedNdSbpDimReduce(in, out, logical_shape));\n\n  for (int64_t in_parallel_id = 0; in_parallel_id < in->placement()->parallel_num();\n       ++in_parallel_id) {\n    const auto& in_physical_shape =\n        JUST(GetPhysicalShape(logical_shape, *in->nd_sbp(), *in->placement(), in_parallel_id));\n    const auto& reduce_in_physical_shape = JUST(GetPhysicalShape(\n        logical_shape, *reduced_in->nd_sbp(), *reduced_in->placement(), in_parallel_id));\n    CHECK_EQ_OR_RETURN(*in_physical_shape, *reduce_in_physical_shape);\n  }\n\n  for (int64_t out_parallel_id = 0; out_parallel_id < out->placement()->parallel_num();\n       ++out_parallel_id) {\n    const auto& out_physical_shape =\n        JUST(GetPhysicalShape(logical_shape, *out->nd_sbp(), *out->placement(), out_parallel_id));\n    const auto& reduce_out_physical_shape = JUST(GetPhysicalShape(\n        logical_shape, *reduced_out->nd_sbp(), *reduced_out->placement(), out_parallel_id));\n    CHECK_EQ_OR_RETURN(*out_physical_shape, 
*reduce_out_physical_shape);\n  }\n\n  if (reduced_in->nd_sbp()->sbp_parallel_size() == 1\n      && reduced_out->nd_sbp()->sbp_parallel_size() == 1) {\n    return Maybe<void>::Ok();\n  }\n  if ((reduced_in->placement() != in->placement() || reduced_out->placement() != out->placement())\n      && reduced_in->placement() == reduced_out->placement()) {\n    return Maybe<void>::Ok();\n  }\n  return Error::CheckFailedError();\n}\n// NOLINTEND(maybe-need-error-msg)\n\nstatic constexpr auto* CheckParallelDimReduce =\n    DECORATE(&RawCheckParallelDimReduce, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<one::Tensor> ParallelDimReduce(const std::shared_ptr<one::Tensor>& tensor,\n                                     Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n\n  Symbol<PlacedNdSbp> reduced_in;\n  Symbol<PlacedNdSbp> reduced_out;\n  std::tie(reduced_in, reduced_out) = *JUST(InOutPlacedNdSbpDimReduce(in, out, *tensor->shape()));\n\n  const std::shared_ptr<one::Tensor>& local_tensor = JUST(tensor->cur_rank_phy_tensor());\n\n  std::shared_ptr<one::Tensor> reduced_in_tensor = JUST(one::functional::LocalToGlobal(\n      local_tensor, reduced_in->placement(), *JUST(GetSbpList(reduced_in->nd_sbp())),\n      *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/false));\n\n  const auto& boxing_interpreter =\n      JUST(Singleton<EagerBoxingInterpreterManager>::Get()->GetEagerBoxingInterpreter(\n          reduced_in->nd_sbp(), reduced_out->nd_sbp(), reduced_in->placement(),\n          reduced_out->placement(), *tensor->shape()));\n  Singleton<const EagerBoxingLogger>::Get()->Log(\n      *JUST(boxing_interpreter->boxing_interpreter_status()),\n      /* prefix */ \"\\t\\tInternal boxing of nd-sbp-dim-reduce, \");\n  std::shared_ptr<one::Tensor> reduced_out_tensor = JUST(\n      boxing_interpreter->Interpret(reduced_in_tensor, reduced_in->nd_sbp(), reduced_out->nd_sbp(),\n                                    reduced_in->placement(), reduced_out->placement()));\n\n  const std::shared_ptr<one::Tensor>& reduced_out_local_tensor =\n      JUST(reduced_out_tensor->cur_rank_phy_tensor());\n\n  return JUST(one::functional::LocalToGlobal(\n      reduced_out_local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())),\n      *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/false));\n}\n\nCOMMAND(RegisterBoxingFunction(\"nd-sbp-dim-reduce\", CheckParallelDimReduce, &ParallelDimReduce));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/one_to_one_boxing.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/user/kernels/communicate_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<void> RawCheckNaiveOneToOne(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                  const Shape& logical_shape) {\n  CHECK_EQ_OR_RETURN(in->placement()->parallel_num(), 1);\n  CHECK_EQ_OR_RETURN(out->placement()->parallel_num(), 1);\n  CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag());\n  CHECK_OR_RETURN(in->placement() != out->placement());\n  CHECK_OR_RETURN(IsSendAndRecvRegistered(in->placement()->device_type()));  // NOLINT\n  return Maybe<void>::Ok();\n}\n// NOLINTEND(maybe-need-error-msg)\n\nstatic constexpr auto* CheckNaiveOneToOne =\n    DECORATE(&RawCheckNaiveOneToOne, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<one::Tensor> NaiveOneToOne(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                                 Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n\n  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());\n  int64_t src = JUST(tensor_placement->MachineId4ParallelId(0));\n  int64_t dst = JUST(out->placement()->MachineId4ParallelId(0));\n\n  bool copy = true;\n  if (src != dst) {\n    copy = false;\n    if (GlobalProcessCtx::Rank() == src) {\n      JUST(one::functional::Send(local_tensor, dst, /* send_meta */ false));\n    }\n    if (GlobalProcessCtx::Rank() == dst) {\n      local_tensor = JUST(one::functional::Recv(src, *tensor->shape(), tensor->dtype(),\n                                                JUST(local_tensor->device()), NullOpt));\n    }\n  }\n  return JUST(one::functional::LocalToGlobal(\n      local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),\n      tensor->dtype(), /* sync_data */ false, /*copy=*/copy));\n}\n\nCOMMAND(RegisterBoxingFunction(\"naive-1-to-1\", CheckNaiveOneToOne, &NaiveOneToOne));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/slice_boxing_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/boxing/slice_boxing_util.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter_mgr.h\"\n#include \"oneflow/core/boxing/eager_boxing_logger.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/user/kernels/communicate_util.h\"\n\nnamespace oneflow {\n\nnamespace private_details {\n\nMaybe<one::Tensor> PreprocessInputTensor4SliceBoxing(const std::shared_ptr<one::Tensor>& tensor,\n                                                     const std::string& log_prefix) {\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  if (IsSendAndRecvRegistered(tensor_placement->device_type())) { return tensor; }\n\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  Symbol<ParallelDesc> new_placement = JUST(ReplaceDeviceType(tensor_placement, DeviceType::kCPU));\n\n  const auto& boxing_interpreter =\n      JUST(Singleton<EagerBoxingInterpreterManager>::Get()->GetEagerBoxingInterpreter(\n          tensor_nd_sbp, tensor_nd_sbp, tensor_placement, new_placement, *tensor->shape()));\n  Singleton<const EagerBoxingLogger>::Get()->Log(\n      *JUST(boxing_interpreter->boxing_interpreter_status()), log_prefix);\n  return JUST(boxing_interpreter->Interpret(tensor, tensor_nd_sbp, tensor_nd_sbp, tensor_placement,\n                                            new_placement));\n}\n\nMaybe<one::Tensor> PostprocessOutputTensor4SliceBoxing(const std::shared_ptr<one::Tensor>& tensor,\n                                                       Symbol<PlacedNdSbp> placed_nd_sbp,\n                                                       const std::string& log_prefix) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_nd_sbp == placed_nd_sbp->nd_sbp())\n      << Error::RuntimeError()\n      << \"Compute slice boxing failed.  Please submit an issue in \"\n         \"`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as \"\n         \"possible\";\n  CHECK_OR_RETURN(tensor_placement->EqualsIgnoringDeviceType(*placed_nd_sbp->placement()))\n      << Error::RuntimeError()\n      << \"Compute slice boxing failed. 
Please submit an issue in \"\n         \"`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as \"\n         \"possible\";\n\n  if (JUST(tensor->parallel_desc()) == placed_nd_sbp->placement()) { return tensor; }\n  const auto& boxing_interpreter =\n      JUST(Singleton<EagerBoxingInterpreterManager>::Get()->GetEagerBoxingInterpreter(\n          placed_nd_sbp->nd_sbp(), placed_nd_sbp->nd_sbp(), JUST(tensor->parallel_desc()),\n          placed_nd_sbp->placement(), *tensor->shape()));\n  Singleton<const EagerBoxingLogger>::Get()->Log(\n      *JUST(boxing_interpreter->boxing_interpreter_status()), log_prefix);\n  return JUST(boxing_interpreter->Interpret(tensor, placed_nd_sbp->nd_sbp(),\n                                            placed_nd_sbp->nd_sbp(), JUST(tensor->parallel_desc()),\n                                            placed_nd_sbp->placement()));\n}\n\nconst std::string& LogPrefix4EagerSliceBoxingType(EagerSliceBoxingType boxing_type) {\n  static thread_local const HashMap<EagerSliceBoxingType, std::string> boxing_type2log_prefix = {\n      {EagerSliceBoxingType::kNaiveBToS, \"\\t\\tInternal boxing of naive-b-to-s, \"},\n      {EagerSliceBoxingType::kNaivePToB, \"\\t\\tInternal boxing of naive-p-to-b, \"},\n      {EagerSliceBoxingType::kNaivePToS, \"\\t\\tInternal boxing of naive-p-to-s, \"},\n      {EagerSliceBoxingType::kNaiveSToB, \"\\t\\tInternal boxing of naive-s-to-b, \"},\n      {EagerSliceBoxingType::kNaiveSToP, \"\\t\\tInternal boxing of naive-s-to-p, \"},\n      {EagerSliceBoxingType::kNaiveSToS, \"\\t\\tInternal boxing of naive-s-to-s, \"}};\n  return CHECK_JUST(MapAt(boxing_type2log_prefix, boxing_type));\n}\n\n}  // namespace private_details\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/slice_boxing_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_BOXING_SLICE_BOXING_UTIL_H_\n#define ONEFLOW_CORE_BOXING_SLICE_BOXING_UTIL_H_\n\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/placed_nd_sbp.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n\nnamespace oneflow {\n\nenum class EagerSliceBoxingType : unsigned int;\n\nnamespace private_details {\n\n// Copy to cpu if device of input tensor is not cpu or cuda, otherwise return self\nMaybe<one::Tensor> PreprocessInputTensor4SliceBoxing(const std::shared_ptr<one::Tensor>& tensor,\n                                                     const std::string& log_prefix);\n\n// Copy to corresponding device if device of output tensor is not same with that of placed_nd_sbp,\n// otherwise return self\nMaybe<one::Tensor> PostprocessOutputTensor4SliceBoxing(const std::shared_ptr<one::Tensor>& tensor,\n                                                       Symbol<PlacedNdSbp> placed_nd_sbp,\n                                                       const std::string& log_prefix);\n\nconst std::string& LogPrefix4EagerSliceBoxingType(EagerSliceBoxingType boxing_type);\n\n}  // namespace private_details\n\nenum class EagerSliceBoxingType : unsigned int {\n  kNaiveBToS = 0,\n  kNaivePToB = 1,\n  kNaivePToS = 2,\n  kNaiveSToB = 3,\n  kNaiveSToP = 4,\n  kNaiveSToS = 5\n};\n\ntemplate<EagerSliceBoxingType boxing_type>\nstruct EagerSliceBoxingAutoConvert {\n  template<Maybe<one::Tensor> (*func)(const std::shared_ptr<one::Tensor>&, Symbol<PlacedNdSbp>,\n                                      Symbol<PlacedNdSbp>)>\n  static Maybe<one::Tensor> Call(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                                 Symbol<PlacedNdSbp> out) {\n    std::shared_ptr<one::Tensor> processed_in_tensor =\n        JUST(private_details::PreprocessInputTensor4SliceBoxing(\n            tensor, private_details::LogPrefix4EagerSliceBoxingType(boxing_type)));\n    const auto& new_in =\n        JUST(PlacedNdSbp::New(in->nd_sbp(), JUST(processed_in_tensor->parallel_desc())));\n    Symbol<ParallelDesc> new_out_placement = JUST(ReplaceDeviceType(\n        out->placement(), JUST(processed_in_tensor->parallel_desc())->device_type()));\n    const auto& new_out = JUST(PlacedNdSbp::New(out->nd_sbp(), new_out_placement));\n    std::shared_ptr<one::Tensor> out_tensor = JUST(func(processed_in_tensor, new_in, new_out));\n    return JUST(private_details::PostprocessOutputTensor4SliceBoxing(\n        out_tensor, out, private_details::LogPrefix4EagerSliceBoxingType(boxing_type)));\n  }\n};\n\n#define EAGER_SLICE_BOXING_WARPPER(fn_ptr, boxing_type) \\\n  (&EagerSliceBoxingAutoConvert<boxing_type>::Call<fn_ptr>)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_BOXING_SLICE_BOXING_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/boxing/symmetric_acyclic_nd_sbp_boxing.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/boxing/eager_boxing_interpreter_mgr.h\"\n#include \"oneflow/core/boxing/eager_boxing_logger.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/placement_sbp_util.h\"\n#include \"oneflow/core/framework/placed_nd_sbp.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/id_util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<one::Tensor> ReinterpterGlobalTensor(const std::shared_ptr<one::Tensor>& tensor,\n                                           const Shape& shape, Symbol<ParallelDesc> parallel_desc,\n                                           Symbol<NdSbp> nd_sbp) {\n  const auto& parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc));\n  std::shared_ptr<Shape> pyhsical_shape =\n      JUST(GetPhysicalShape(shape, *nd_sbp, *parallel_desc, JUST(*parallel_id)));\n  std::shared_ptr<one::Tensor> x = JUST(tensor->cur_rank_phy_tensor());\n  if (*x->shape() != *pyhsical_shape) { x = JUST(one::functional::Reshape(x, *pyhsical_shape)); }\n  return JUST(one::functional::LocalToGlobal(x, parallel_desc, *JUST(GetSbpList(nd_sbp)), shape,\n                                             tensor->dtype(), /* sync_data */ false,\n                                             /*copy=*/false));\n}\n\nMaybe<one::Tensor> Apply1DBoxing(const std::shared_ptr<one::Tensor>& input, Symbol<NdSbp> in_nd_sbp,\n                                 Symbol<NdSbp> out_nd_sbp, Symbol<ParallelDesc> in_parallel_desc,\n                                 Symbol<ParallelDesc> out_parallel_desc) {\n  const auto& boxing_interpreter =\n      JUST(Singleton<EagerBoxingInterpreterManager>::Get()->GetEagerBoxingInterpreter(\n          in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc, *input->shape()));\n  Singleton<const EagerBoxingLogger>::Get()->Log(\n      *JUST(boxing_interpreter->boxing_interpreter_status()),\n      /* prefix */ \"\\t\\tInternal boxing of symmetric-acyclic-nd-sbp-to-nd-sbp, \");\n  return JUST(boxing_interpreter->Interpret(input, in_nd_sbp, out_nd_sbp, in_parallel_desc,\n                                            out_parallel_desc));\n}\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<void> RawCheckSymmetricAcyclicNdSbpBoxing(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                                const Shape& logical_shape) {\n  CHECK_OR_RETURN(in->placement() == out->placement());\n  CHECK_OR_RETURN(in->nd_sbp() != out->nd_sbp());\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), out->nd_sbp()->sbp_parallel_size());\n  CHECK_GT_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  JUST(CheckIsNdSbpBoxingAcyclicWithDecompose(in, out, logical_shape));\n  
return Maybe<void>::Ok();\n}\n// NOLINTEND(maybe-need-error-msg)\n\nstatic constexpr auto* CheckSymmetricAcyclicNdSbpBoxing =\n    DECORATE(&RawCheckSymmetricAcyclicNdSbpBoxing, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<one::Tensor> SymmetricAcyclicNdSbpBoxing(const std::shared_ptr<one::Tensor>& input,\n                                               Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(input->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(input->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n  const auto& out_nd_sbp = out->nd_sbp();\n  const auto& out_parallel_desc = out->placement();\n  std::shared_ptr<one::Tensor> output;\n  const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_parallel_desc));\n  if (out_parallel_id->has_value()) {\n    const auto& tensor_meta = JUST(input->global_tensor_meta());\n    const auto& naive_transformations =\n        JUST(DecomposeIntoNaiveTransformations(tensor_meta, out_nd_sbp));\n    std::shared_ptr<one::Tensor> tensor = input;\n    for (const auto& naive_transformation : *naive_transformations) {\n      const auto& sub_tensor_meta = naive_transformation.global_tensor_meta;\n      tensor = JUST(ReinterpterGlobalTensor(tensor, sub_tensor_meta->shape(),\n                                            sub_tensor_meta->parallel_desc(),\n                                            sub_tensor_meta->nd_sbp()));\n      tensor =\n          JUST(Apply1DBoxing(tensor, sub_tensor_meta->nd_sbp(), naive_transformation.dst_nd_sbp,\n                             sub_tensor_meta->parallel_desc(), sub_tensor_meta->parallel_desc()));\n    }\n    output = JUST(ReinterpterGlobalTensor(tensor, *input->shape(), out_parallel_desc, out_nd_sbp));\n  } else {\n    one::GlobalTensorMeta tensor_meta(*input->shape(), input->dtype()->data_type(),\n                                      input->memory_format(), out_nd_sbp, out_parallel_desc);\n    const auto& tensor_impl =\n        JUST(one::EagerGlobalTensorImpl::New(SymbolOf(tensor_meta), input->requires_grad(), false));\n    output = std::make_shared<one::GlobalTensor>(tensor_impl);\n  }\n  return output;\n}\n\nCOMMAND(RegisterBoxingFunction(\"symmetric-acyclic-nd-sbp-to-nd-sbp\",\n                               CheckSymmetricAcyclicNdSbpBoxing, &SymmetricAcyclicNdSbpBoxing));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/symmetric_b_to_p_boxing.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> RawCheckSymmetricBToP(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                  const Shape& logical_shape) {\n  // NOLINTBEGIN(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_OR_RETURN(NdSbpIsAllBroadcast(*in->nd_sbp()));\n  CHECK_OR_RETURN(NdSbpIsAllPartialSum(*out->nd_sbp()));\n\n  CHECK_OR_RETURN(in->placement() == out->placement());\n  // NOLINTEND(maybe-need-error-msg)\n  return Maybe<void>::Ok();\n}\n\nstatic constexpr auto* CheckSymmetricBToP =\n    DECORATE(&RawCheckSymmetricBToP, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<one::Tensor> SymmetricBToP(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                                 Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n\n  int64_t root = JUST(tensor_placement->MachineId4ParallelId(0));\n  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());\n  if (root == GlobalProcessCtx::Rank()) {\n    // do nothing\n  } else {\n    local_tensor = JUST(one::functional::ZerosLike(local_tensor));\n  }\n  return JUST(one::functional::LocalToGlobal(\n      local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),\n      tensor->dtype(), /* sync_data */ false, /*copy=*/true));\n}\n\nCOMMAND(RegisterBoxingFunction(\"symmetric-b-to-p\", CheckSymmetricBToP, &SymmetricBToP));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/symmetric_b_to_s_boxing.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/register/tensor_slice_view.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool RawIsBroadcastSbp(Symbol<SbpParallel> sbp_parallel) {\n  return sbp_parallel->has_broadcast_parallel();\n}\n\nstatic constexpr auto* IsBroadcastSbp = DECORATE(&RawIsBroadcastSbp, ThreadLocalCached);\n\nbool RawIsSplitSbp(Symbol<SbpParallel> sbp_parallel) { return sbp_parallel->has_split_parallel(); }\n\nstatic constexpr auto* IsSplitSbp = DECORATE(&RawIsSplitSbp, ThreadLocalCached);\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<void> RawCheckSymmetricB2S(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                 const Shape& logical_shape) {\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_OR_RETURN(IsBroadcastSbp(SymbolOf(in->nd_sbp()->sbp_parallel(0))));\n  CHECK_OR_RETURN(IsSplitSbp(SymbolOf(out->nd_sbp()->sbp_parallel(0))));\n\n  CHECK_OR_RETURN(in->placement() == out->placement());                           // NOLINT\n  CHECK_OR_RETURN(in->placement()->device_type() != DeviceType::kInvalidDevice    // NOLINT\n                  && in->placement()->device_type() != kMeta                      // NOLINT\n                  && in->placement()->device_type() != DeviceType::kMockDevice);  // NOLINT\n  return Maybe<void>::Ok();\n}\n// NOLINTEND(maybe-need-error-msg)\n\nstatic constexpr auto* CheckSymmetricB2S =\n    DECORATE(&RawCheckSymmetricB2S, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<one::Tensor> SymmetricB2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                                Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n\n  const auto& local_shape = *tensor->shape();\n  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());\n\n  const auto& parallel_id = 
JUST(GetParallelId4CurrentProcessCtx(tensor_placement));\n\n  if (parallel_id->has_value()) {\n    const TensorSliceView& in_slice = GetTensorSliceView4ParallelId(\n        *tensor_placement->hierarchy(), *tensor_nd_sbp, local_shape, JUST(*parallel_id));\n    CHECK(!in_slice.IsEmpty());\n    const TensorSliceView& out_slice = GetTensorSliceView4ParallelId(\n        *tensor_placement->hierarchy(), *out->nd_sbp(), local_shape, JUST(*parallel_id));\n    CHECK(!out_slice.IsEmpty());\n    const TensorSliceView& intersection = out_slice.Intersect(in_slice);\n    CHECK(!intersection.IsEmpty());\n    const std::vector<Range>& range_vec = intersection.range_vec();\n    std::vector<int64_t> start;\n    std::vector<int64_t> stop;\n    std::vector<int64_t> step(range_vec.size(), 1);\n    for (const auto& range : range_vec) {\n      start.emplace_back(range.begin());\n      stop.emplace_back(range.end());\n    }\n    local_tensor = JUST(one::functional::Slice(local_tensor, start, stop, step,\n                                               /*enable_view_slice=*/false));\n  }\n\n  return JUST(one::functional::LocalToGlobal(\n      local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),\n      tensor->dtype(), /* sync_data */ false, /*copy=*/false));\n}\n\nCOMMAND(RegisterBoxingFunction(\"symmetric-b-to-s\", CheckSymmetricB2S, &SymmetricB2S));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/symmetric_s_to_p_boxing.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/id_util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool RawIsSplitSbp(Symbol<SbpParallel> sbp_parallel) { return sbp_parallel->has_split_parallel(); }\n\nstatic constexpr auto* IsSplitSbp = DECORATE(&RawIsSplitSbp, ThreadLocalCached);\n\nbool RawIsPartialSumSbp(Symbol<SbpParallel> sbp_parallel) {\n  return sbp_parallel->has_partial_sum_parallel();\n}\n\nstatic constexpr auto* IsPartialSumSbp = DECORATE(&RawIsPartialSumSbp, ThreadLocalCached);\n\nMaybe<one::UserOpExpr> EagerSymmetricSToP(Symbol<ParallelDesc> parallel_desc,\n                                          Symbol<SbpParallel> src_sbp, const Shape& logical_shape) {\n  return one::OpBuilder(\"eager_symmetric_s_to_p\", *JUST(UniqueStr(\"eager_symmetric_s_to_p\")))\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<int64_t>(\"in_split_axis\", src_sbp->split_parallel().axis())\n      .Attr<std::string>(\"parallel_conf\", PbMessage2TxtString(parallel_desc->parallel_conf()))\n      .Build();\n}\n\nstatic constexpr auto* CachedEagerSymmetricSToPOpExpr =\n    DECORATE(&EagerSymmetricSToP, ThreadLocalCachedCopiable);\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<void> RawCheckSymmetricSToP(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                  const Shape& logical_shape) {\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);\n\n  CHECK_OR_RETURN(IsSplitSbp(in->nd_sbp()->sbp_parallel(0)));\n  CHECK_OR_RETURN(IsPartialSumSbp(out->nd_sbp()->sbp_parallel(0)));\n\n  CHECK_OR_RETURN(in->placement() == out->placement());\n  return Maybe<void>::Ok();\n}\n// NOLINTEND(maybe-need-error-msg)\n\nstatic constexpr auto* CheckSymmetricSToP =\n    DECORATE(&RawCheckSymmetricSToP, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<one::Tensor> SymmetricSToP(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,\n                                 Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input 
placement (\"\n      << *JUST(PlacementToString(in->placement())) << \")\";\n\n  std::shared_ptr<one::OpExpr> op_expr = JUST(CachedEagerSymmetricSToPOpExpr(\n      tensor_placement, SymbolOf(tensor_nd_sbp->sbp_parallel(0)), *tensor->shape()));\n\n  return JUST(one::OpInterpUtil::Dispatch<one::Tensor>(*op_expr, {tensor}));\n}\n\nCOMMAND(RegisterBoxingFunction(\"symmetric-s-to-p\", CheckSymmetricSToP, &SymmetricSToP));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/boxing/unflatten_hierarchy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n// NOLINTBEGIN(maybe-need-error-msg)\nMaybe<void> RawCheckUnflattenHierarchy(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,\n                                       const Shape& logical_shape) {\n  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);\n  CHECK_GT_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);\n  for (int i = 0; i < out->nd_sbp()->sbp_parallel_size(); ++i) {\n    const auto& sbp_parallel = out->nd_sbp()->sbp_parallel(i);\n    CHECK_OR_RETURN(sbp_parallel == out->nd_sbp()->sbp_parallel(0)) << \"nd_sbp axis: \" << i;\n  }\n  CHECK_EQ_OR_RETURN(in->placement()->device_type(), out->placement()->device_type());\n  CHECK_EQ_OR_RETURN(in->placement()->parallel_num(), out->placement()->parallel_num());\n  ParallelConf unflattened_parallel_conf(in->placement()->parallel_conf());\n  unflattened_parallel_conf.mutable_hierarchy()->CopyFrom(\n      out->placement()->parallel_conf().hierarchy());\n  const auto& unflatten_placement = SymbolOf(ParallelDesc(unflattened_parallel_conf));\n  CHECK_OR_RETURN(unflatten_placement == out->placement())\n      << \"The output placement is not a hierarch-unflattened version of the input placement\";\n  for (int64_t in_parallel_id = 0; in_parallel_id < in->placement()->parallel_num();\n       ++in_parallel_id) {\n    const auto& in_physical_shape =\n        JUST(GetPhysicalShape(logical_shape, *in->nd_sbp(), *in->placement(), in_parallel_id));\n    const auto& out_physical_shape =\n        JUST(GetPhysicalShape(logical_shape, *out->nd_sbp(), *out->placement(), in_parallel_id));\n    CHECK_EQ_OR_RETURN(*in_physical_shape, *out_physical_shape);\n  }\n  return Maybe<void>::Ok();\n}\n// NOLINTEND(maybe-need-error-msg)\n\n}  // namespace\n\nstatic constexpr auto* CheckUnflattenHierarchy =\n    DECORATE(&RawCheckUnflattenHierarchy, ThreadLocalCachedCopiable);\n\nMaybe<one::Tensor> UnflattenHierarchy(const std::shared_ptr<one::Tensor>& tensor,\n                                      Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {\n  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());\n  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())\n      << Error::RuntimeError() << \"The sbp of input tensor (\" << NdSbpToString(tensor_nd_sbp)\n      << \") must match the input sbp (\" << NdSbpToString(in->nd_sbp()) << \")\";\n  const auto& tensor_placement = JUST(tensor->parallel_desc());\n  CHECK_OR_RETURN(tensor_placement == in->placement())\n      << Error::RuntimeError() << \"The placement of input tensor (\"\n      << *JUST(PlacementToString(tensor_placement)) << \") must match the input placement (\"\n      << *JUST(PlacementToString(in->placement())) << 
\")\";\n  const auto& local_tensor = JUST(tensor->cur_rank_phy_tensor());\n  const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));\n  return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,\n                                             *tensor->shape(), tensor->dtype(),\n                                             /* sync_data */ false, /*copy=*/true));\n}\n\nCOMMAND(RegisterBoxingFunction(\"unflatten-hierarchy\", CheckUnflattenHierarchy,\n                               &UnflattenHierarchy));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ccl/ccl.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ccl/ccl.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/data_type_seq.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/core/job/eager_nccl_comm_manager.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/common/constant.h\"\n\nnamespace oneflow {\nnamespace ccl {\n\nnamespace {\n\nMaybe<void> InitBroadcastRankHeap(std::vector<int64_t>* ranks, const ParallelDesc& parallel_desc,\n                                  int64_t root) {\n  CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), parallel_desc.sorted_machine_ids().size());\n  ranks->resize(parallel_desc.parallel_num());\n  int64_t root_index = -1;\n  for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) {\n    int64_t machine_id = JUST(parallel_desc.MachineId4ParallelId(parallel_id));\n    if (machine_id == root) { root_index = parallel_id; }\n    (*ranks)[parallel_id] = machine_id;\n  }\n  CHECK_NE_OR_RETURN(root_index, -1);\n  std::swap((*ranks)[0], (*ranks)[root_index]);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> CpuBroadcast(const void* in, void* out, size_t buffer_size, int64_t root,\n                         Symbol<ParallelDesc> parallel_desc,\n                         const TransportToken& transport_token) {\n  static thread_local std::vector<int64_t> rank_heap{};\n  JUST(InitBroadcastRankHeap(&rank_heap, *parallel_desc, root));\n  auto Send = [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n    *buffer = (root == GlobalProcessCtx::Rank() ? 
const_cast<void*>(in) : out);\n    *size = buffer_size;\n    *Cb = [] {};\n    return Maybe<void>::Ok();\n  };\n  auto Recv = [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n    *buffer = out;\n    *size = buffer_size;\n    *Cb = [] {};\n    return Maybe<void>::Ok();\n  };\n  {\n    NaiveAsyncTransportCtx transport_ctx(transport_token, Send, Recv);\n    JUST(TransportUtil::ReceiveDataFromParentInHeap(rank_heap, transport_token, &transport_ctx));\n    JUST_MSG(transport_ctx.WaitDone(), kAsymmetricCodeErrorMsg);\n  }\n  {\n    NaiveAsyncTransportCtx transport_ctx(transport_token, Send, Recv);\n    JUST(TransportUtil::SendDataToChildrenInHeap(rank_heap, transport_token, &transport_ctx));\n    if (GlobalProcessCtx::Rank() == root && out != in) { std::memcpy(out, in, buffer_size); }\n    JUST_MSG(transport_ctx.WaitDone(), kAsymmetricCodeErrorMsg);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CpuSend(const void* in, size_t buffer_size, int64_t dst) {\n  TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));\n  NaiveAsyncTransportCtx transport_ctx(\n      transport_token,\n      [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n        *buffer = const_cast<void*>(in);\n        *size = buffer_size;\n        *Cb = [] {};\n        return Maybe<void>::Ok();\n      },\n      [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n        UNIMPLEMENTED_THEN_RETURN();\n      });\n  JUST(TransportUtil::SendDataToRank(dst, transport_token, &transport_ctx));\n  JUST(transport_ctx.WaitDone());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CpuRecv(void* out, size_t buffer_size, int64_t src) {\n  TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));\n  NaiveAsyncTransportCtx transport_ctx(\n      transport_token,\n      [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n        UNIMPLEMENTED_THEN_RETURN();\n      },\n      [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n        *buffer = out;\n        *size = buffer_size;\n        *Cb = [] {};\n        return Maybe<void>::Ok();\n      });\n  JUST(TransportUtil::ReceiveDataFromRank(src, transport_token, &transport_ctx));\n  JUST(transport_ctx.WaitDone());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace ccl\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ccl/ccl.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CCL_CCL_H_\n#define ONEFLOW_CORE_CCL_CCL_H_\n\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/switch_func.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\nclass ParallelDesc;\nclass TransportToken;\n\n// collective communication library\nnamespace ccl {\n\nMaybe<void> CpuSend(const void* in, size_t buffer_size, int64_t dst);\n\nMaybe<void> CpuRecv(void* out, size_t buffer_size, int64_t src);\n\nMaybe<void> CpuBroadcast(const void* in, void* out, size_t buffer_size, int64_t root,\n                         Symbol<ParallelDesc> parallel_desc, const TransportToken& transport_token);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CCL_CCL_H_\n"
  },
  {
    "path": "oneflow/core/comm_network/comm_network.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/comm_network/comm_network.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/env_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n\nnamespace oneflow {\n\nCommNet::~CommNet() {\n  ready_cbs_.Close();\n  ready_cb_poller_.join();\n}\n\nvoid* CommNet::NewActorReadId() { return new ActorReadContext; }\n\nvoid CommNet::DeleteActorReadId(void* actor_read_id) {\n  auto actor_read_ctx = static_cast<ActorReadContext*>(actor_read_id);\n  CHECK(actor_read_ctx->waiting_list.empty());\n  delete actor_read_ctx;\n}\n\nvoid CommNet::Read(void* actor_read_id, int64_t src_machine_id, void* src_token, void* dst_token) {\n  auto actor_read_ctx = static_cast<ActorReadContext*>(actor_read_id);\n  ReadContext* read_ctx = new ReadContext;\n  read_ctx->actor_read_ctx = actor_read_ctx;\n  auto do_read = [this, read_ctx, src_machine_id, src_token, dst_token]() {\n    DoRead(read_ctx, src_machine_id, src_token, dst_token);\n  };\n  AddWorkToStream(actor_read_id, do_read, true);\n}\n\nvoid CommNet::AddReadCallBack(void* actor_read_id, std::function<void()> callback) {\n  AddWorkToStream(actor_read_id, callback, false);\n}\n\nvoid CommNet::ReadDone(void* read_id) {\n  ReadContext* read_ctx = static_cast<ReadContext*>(read_id);\n  ActorReadContext* actor_read_ctx = read_ctx->actor_read_ctx;\n  CommNetItem item;\n  std::unique_lock<std::mutex> lck(actor_read_ctx->waiting_list_mtx);\n  CHECK(!actor_read_ctx->waiting_list.empty());\n  CHECK(actor_read_ctx->waiting_list.front().callback == nullptr);\n  actor_read_ctx->waiting_list.pop_front();\n  while (true) {\n    if (actor_read_ctx->waiting_list.empty()) { break; }\n    item = actor_read_ctx->waiting_list.front();\n    actor_read_ctx->waiting_list.pop_front();\n    CHECK(item.callback);\n    ready_cbs_.Send(item.callback);\n    if (item.is_read) { break; }\n  }\n  delete read_ctx;\n}\n\nvoid CommNet::AddWorkToStream(void* actor_read_id, const std::function<void()>& cb, bool is_read) {\n  auto actor_read_ctx = static_cast<ActorReadContext*>(actor_read_id);\n  std::unique_lock<std::mutex> lck(actor_read_ctx->waiting_list_mtx);\n  if (actor_read_ctx->waiting_list.empty()) {\n    ready_cbs_.Send(cb);\n  } else {\n    CommNetItem work_item(is_read, cb);\n    actor_read_ctx->waiting_list.emplace_back(work_item);\n  }\n  if (is_read) {\n    CommNetItem empty_cb;\n    actor_read_ctx->waiting_list.emplace_back(empty_cb);\n  }\n}\n\nCommNet::CommNet() {\n  int64_t this_machine_id = GlobalProcessCtx::Rank();\n  for (int64_t i : Singleton<ResourceDesc, ForSession>::Get()->process_ranks()) {\n    if (i == this_machine_id) { continue; }\n    peer_machine_id_.insert(i);\n  }\n\n  ready_cb_poller_ = std::thread([this]() {\n    std::function<void()> cb;\n    while (ready_cbs_.Receive(&cb) == kChannelStatusSuccess) { cb(); }\n  });\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/comm_network/comm_network.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMM_NETWORK_COMM_NETWORK_H_\n#define ONEFLOW_CORE_COMM_NETWORK_COMM_NETWORK_H_\n\n#ifndef DEPRECATED\n#define DEPRECATED __attribute__((deprecated))\n#endif\n\n#include \"oneflow/core/lazy/actor/actor_message.h\"\n#include \"oneflow/core/common/platform.h\"\n#include \"oneflow/core/common/channel.h\"\n\nnamespace oneflow {\n\nstruct CommNetItem {\n  bool is_read;\n  std::function<void()> callback;\n  CommNetItem() : CommNetItem(false, nullptr) {}\n  CommNetItem(bool read, const std::function<void()>& cb) : is_read(read), callback(cb) {}\n};\n\nclass CommNet {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CommNet);\n  virtual ~CommNet();\n\n  // \"RegisterMemory\" will return a Token, after \"RegisterMemoryDone\",\n  // we can use this token to use the \"Read\"\n  virtual void* RegisterMemory(void* ptr, size_t byte_size) = 0;\n  virtual void UnRegisterMemory(void* token) = 0;\n\n  // Stream\n  void* NewActorReadId();\n  void DeleteActorReadId(void* actor_read_id);\n  void Read(void* actor_read_id, int64_t src_machine_id, void* src_token, void* dst_token);\n  void AddReadCallBack(void* actor_read_id, std::function<void()> callback);\n  void ReadDone(void* read_id);\n\n  virtual void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) = 0;\n\n protected:\n  CommNet();\n\n  virtual void DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) = 0;\n  const HashSet<int64_t>& peer_machine_id() { return peer_machine_id_; }\n\n  Channel<std::function<void()>> ready_cbs_;\n\n private:\n  friend class Singleton<CommNet>;\n  void AddWorkToStream(void* actor_read_id, const std::function<void()>& cb, bool is_read);\n  struct ActorReadContext;\n  struct ReadContext {\n    ActorReadContext* actor_read_ctx;\n  };\n  struct ActorReadContext {\n    std::mutex waiting_list_mtx;\n    std::list<CommNetItem> waiting_list;\n  };\n  HashSet<int64_t> peer_machine_id_;\n  std::thread ready_cb_poller_;\n};\n\ntemplate<typename MemDescType>\nclass CommNetIf : public CommNet {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CommNetIf);\n  CommNetIf() : CommNet() {}\n  virtual ~CommNetIf() {}\n\n  void* RegisterMemory(void* ptr, size_t byte_size) override {\n    std::unique_lock<std::mutex> lck(mem_descs_mtx_);\n    MemDescType* mem_desc = NewMemDesc(ptr, byte_size);\n    CHECK(mem_descs_.insert(mem_desc).second);\n    return mem_desc;\n  }\n\n  void UnRegisterMemory(void* token) override {\n    std::unique_lock<std::mutex> lck(mem_descs_mtx_);\n    MemDescType* mem_desc = static_cast<MemDescType*>(token);\n    delete mem_desc;\n    CHECK_EQ(mem_descs_.erase(mem_desc), 1);\n  }\n\n protected:\n  virtual MemDescType* NewMemDesc(void* ptr, size_t byte_size) = 0;\n  const HashSet<MemDescType*>& mem_descs() { return mem_descs_; }\n\n private:\n  std::mutex mem_descs_mtx_;\n  HashSet<MemDescType*> mem_descs_;\n};\n\n}  // namespace oneflow\n\n#endif  // 
ONEFLOW_CORE_COMM_NETWORK_COMM_NETWORK_H_\n"
  },
  {
    "path": "oneflow/core/comm_network/epoll/epoll_comm_network.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef __linux__\n\n#include \"oneflow/core/comm_network/epoll/epoll_comm_network.h\"\n#include \"glog/logging.h\"\n#include \"oneflow/core/control/ctrl_client.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/env_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include <netinet/tcp.h>\n\nnamespace oneflow {\n\nnamespace {\n\nstatic const int32_t kInvlidPort = 0;\n\nsockaddr_in GetSockAddr(const std::string& addr, uint16_t port) {\n  sockaddr_in sa;\n  sa.sin_family = AF_INET;\n  sa.sin_port = htons(port);\n  PCHECK(inet_pton(AF_INET, addr.c_str(), &(sa.sin_addr)) == 1)\n      << \"addr: \" << addr << \", port: \" << port;\n  return sa;\n}\n\nint SockListen(int listen_sockfd, int32_t* listen_port, int32_t total_machine_num) {\n  // System designated available port if listen_port == kInvlidPort, otherwise, the configured port\n  // is used.\n  sockaddr_in sa = GetSockAddr(\"0.0.0.0\", *listen_port);\n  int reuse = 1;\n  int ret_setopt =\n      setsockopt(listen_sockfd, SOL_SOCKET, SO_REUSEADDR, (const void*)&reuse, sizeof(int));\n  CHECK_EQ(ret_setopt, 0);\n  int bind_result = bind(listen_sockfd, reinterpret_cast<sockaddr*>(&sa), sizeof(sa));\n  {\n    sockaddr_in bound_sock;\n    socklen_t bound_sock_size = sizeof(bound_sock);\n    getsockname(listen_sockfd, reinterpret_cast<sockaddr*>(&bound_sock), &bound_sock_size);\n    if (*listen_port != kInvlidPort) {\n      CHECK_EQ(*listen_port, static_cast<int32_t>(ntohs(bound_sock.sin_port)));\n    } else {\n      *listen_port = static_cast<int32_t>(ntohs(bound_sock.sin_port));\n    }\n  }\n  if (bind_result == 0) {\n    PCHECK(listen(listen_sockfd, total_machine_num) == 0);\n    LOG(INFO) << \"CommNet:Epoll listening on \"\n              << \"0.0.0.0:\" + std::to_string(*listen_port);\n  } else {\n    PCHECK(errno == EACCES || errno == EADDRINUSE) << \"SockListen errno: \" << errno;\n  }\n  return bind_result;\n}\n\nstd::string GenPortKey(int64_t machine_id) { return \"EpollPort/\" + std::to_string(machine_id); }\nvoid PushPort(int64_t machine_id, uint16_t port) {\n  Singleton<CtrlClient>::Get()->PushKV(GenPortKey(machine_id), std::to_string(port));\n}\nvoid ClearPort(int64_t machine_id) {\n  Singleton<CtrlClient>::Get()->ClearKV(GenPortKey(machine_id));\n}\nuint16_t PullPort(int64_t machine_id) {\n  uint16_t port = 0;\n  Singleton<CtrlClient>::Get()->PullKV(\n      GenPortKey(machine_id), [&](const std::string& v) { port = oneflow_cast<uint16_t>(v); });\n  return port;\n}\n\n}  // namespace\n\nEpollCommNet::~EpollCommNet() {\n  for (size_t i = 0; i < pollers_.size(); ++i) {\n    VLOG(1) << \"CommNet Thread \" << i << \" finish\";\n    pollers_[i]->Stop();\n  }\n  OF_ENV_BARRIER();\n  for (IOEventPoller* poller : pollers_) { delete poller; }\n  for (auto& pair : sockfd2helper_) { delete pair.second; }\n}\n\nvoid 
EpollCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& actor_msg) {\n  SocketMsg msg;\n  msg.msg_type = SocketMsgType::kActor;\n  msg.actor_msg = actor_msg;\n  if (actor_msg.IsDataRegstMsgToConsumer()) {\n    msg.actor_msg.set_comm_net_token(actor_msg.regst()->comm_net_token());\n  }\n  GetSocketHelper(dst_machine_id)->AsyncWrite(msg);\n}\n\nvoid EpollCommNet::SendTransportMsg(int64_t dst_machine_id, const TransportMsg& transport_msg) {\n  SocketMsg msg;\n  msg.msg_type = SocketMsgType::kTransport;\n  msg.transport_msg = transport_msg;\n  SendSocketMsg(dst_machine_id, msg);\n}\n\nvoid EpollCommNet::SendSocketMsg(int64_t dst_machine_id, const SocketMsg& msg) {\n  GetSocketHelper(dst_machine_id)->AsyncWrite(msg);\n}\n\nSocketMemDesc* EpollCommNet::NewMemDesc(void* ptr, size_t byte_size) {\n  SocketMemDesc* mem_desc = new SocketMemDesc;\n  mem_desc->mem_ptr = ptr;\n  mem_desc->byte_size = byte_size;\n  return mem_desc;\n}\n\nEpollCommNet::EpollCommNet() : CommNetIf() {\n  pollers_.resize(Singleton<ResourceDesc, ForSession>::Get()->CommNetWorkerNum(), nullptr);\n  for (size_t i = 0; i < pollers_.size(); ++i) { pollers_[i] = new IOEventPoller; }\n  InitSockets();\n  for (IOEventPoller* poller : pollers_) { poller->Start(); }\n}\n\nvoid EpollCommNet::InitSockets() {\n  int64_t this_machine_id = GlobalProcessCtx::Rank();\n  auto this_machine = Singleton<ResourceDesc, ForSession>::Get()->machine(this_machine_id);\n  int64_t total_machine_num = Singleton<ResourceDesc, ForSession>::Get()->process_ranks().size();\n  machine_id2sockfd_.assign(total_machine_num, -1);\n  sockfd2helper_.clear();\n  size_t poller_idx = 0;\n  auto NewSocketHelper = [&](int sockfd) {\n    IOEventPoller* poller = pollers_[poller_idx];\n    poller_idx = (poller_idx + 1) % pollers_.size();\n    return new SocketHelper(sockfd, poller);\n  };\n\n  // listen\n  int listen_sockfd = socket(AF_INET, SOCK_STREAM, 0);\n  int32_t this_listen_port = kInvlidPort;\n  {\n    if (this_machine.data_port_agent() != -1) {\n      this_listen_port = this_machine.data_port_agent();\n    } else if (Singleton<EnvDesc>::Get()->data_port() != -1) {\n      this_listen_port = Singleton<EnvDesc>::Get()->data_port();\n    }\n  }\n  CHECK_EQ(SockListen(listen_sockfd, &this_listen_port, total_machine_num), 0);\n  CHECK_NE(this_listen_port, 0);\n  PushPort(this_machine_id, this_listen_port);\n  int32_t src_machine_count = 0;\n\n  // connect\n  for (int64_t peer_id : peer_machine_id()) {\n    if (peer_id < this_machine_id) {\n      ++src_machine_count;\n      continue;\n    }\n    uint16_t peer_port = PullPort(peer_id);\n    auto peer_machine = Singleton<ResourceDesc, ForSession>::Get()->machine(peer_id);\n    sockaddr_in peer_sockaddr = GetSockAddr(peer_machine.addr(), peer_port);\n    int sockfd = socket(AF_INET, SOCK_STREAM, 0);\n    const int val = 1;\n    PCHECK(setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char*)&val, sizeof(int)) == 0);\n    PCHECK(connect(sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr), sizeof(peer_sockaddr))\n           == 0);\n    ssize_t n = write(sockfd, &this_machine_id, sizeof(int64_t));\n    PCHECK(n == sizeof(int64_t));\n    CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);\n    machine_id2sockfd_[peer_id] = sockfd;\n  }\n\n  // accept\n  HashSet<int64_t> processed_ranks;\n  FOR_RANGE(int32_t, idx, 0, src_machine_count) {\n    sockaddr_in peer_sockaddr;\n    socklen_t len = sizeof(peer_sockaddr);\n    int sockfd = accept(listen_sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr), &len);\n  
  PCHECK(sockfd != -1);\n    int64_t peer_rank;\n    ssize_t n = read(sockfd, &peer_rank, sizeof(int64_t));\n    PCHECK(n == sizeof(int64_t));\n    CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);\n    CHECK(processed_ranks.emplace(peer_rank).second);\n    machine_id2sockfd_[peer_rank] = sockfd;\n  }\n  PCHECK(close(listen_sockfd) == 0);\n  ClearPort(this_machine_id);\n\n  // useful log\n  FOR_RANGE(int64_t, machine_id, 0, total_machine_num) {\n    VLOG(2) << \"machine \" << machine_id << \" sockfd \" << machine_id2sockfd_[machine_id];\n  }\n}\n\nSocketHelper* EpollCommNet::GetSocketHelper(int64_t machine_id) {\n  int sockfd = machine_id2sockfd_.at(machine_id);\n  return sockfd2helper_.at(sockfd);\n}\n\nvoid EpollCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) {\n  SocketMsg msg;\n  msg.msg_type = SocketMsgType::kRequestWrite;\n  msg.request_write_msg.src_token = src_token;\n  msg.request_write_msg.dst_machine_id = GlobalProcessCtx::Rank();\n  msg.request_write_msg.dst_token = dst_token;\n  msg.request_write_msg.read_id = read_id;\n  GetSocketHelper(src_machine_id)->AsyncWrite(msg);\n}\n\n}  // namespace oneflow\n\n#endif  // __linux__\n"
  },
  {
    "path": "oneflow/core/comm_network/epoll/epoll_comm_network.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_EPOLL_COMM_NETWORK_H_\n#define ONEFLOW_CORE_COMM_NETWORK_EPOLL_EPOLL_COMM_NETWORK_H_\n\n#ifdef __linux__\n\n#include \"oneflow/core/comm_network/comm_network.h\"\n#include \"oneflow/core/comm_network/epoll/socket_helper.h\"\n#include \"oneflow/core/comm_network/epoll/socket_memory_desc.h\"\n\nnamespace oneflow {\n\nclass EpollCommNet final : public CommNetIf<SocketMemDesc> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(EpollCommNet);\n  ~EpollCommNet();\n\n  void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;\n  void SendSocketMsg(int64_t dst_machine_id, const SocketMsg& msg);\n  void SendTransportMsg(int64_t dst_machine_id, const TransportMsg& msg);\n\n private:\n  SocketMemDesc* NewMemDesc(void* ptr, size_t byte_size) override;\n\n  friend class Singleton<EpollCommNet>;\n  EpollCommNet();\n  void InitSockets();\n  SocketHelper* GetSocketHelper(int64_t machine_id);\n  void DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) override;\n\n  std::vector<IOEventPoller*> pollers_;\n  std::vector<int> machine_id2sockfd_;\n  HashMap<int, SocketHelper*> sockfd2helper_;\n};\n\n}  // namespace oneflow\n\n#endif  // __linux__\n\n#endif  // ONEFLOW_CORE_COMM_NETWORK_EPOLL_EPOLL_COMM_NETWORK_H_\n"
  },
  {
    "path": "oneflow/core/comm_network/epoll/io_event_poller.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef __linux__\n\n#include \"oneflow/core/comm_network/epoll/io_event_poller.h\"\n#include <sys/eventfd.h>\n\nnamespace oneflow {\n\nconst int IOEventPoller::max_event_num_ = 32;\n\nIOEventPoller::IOEventPoller() {\n  epfd_ = epoll_create1(0);\n  ep_events_ = new epoll_event[max_event_num_];\n  io_handlers_.clear();\n  break_epoll_loop_fd_ = eventfd(0, 0);\n  PCHECK(break_epoll_loop_fd_ != -1);\n  AddFdWithOnlyReadHandler(break_epoll_loop_fd_, []() { VLOG(1) << \"Break Epoll Loop\"; });\n}\n\nIOEventPoller::~IOEventPoller() {\n  for (IOHandler* handler : io_handlers_) {\n    PCHECK(close(handler->fd) == 0);\n    delete handler;\n  }\n  delete[] ep_events_;\n  PCHECK(close(epfd_) == 0);\n}\n\nvoid IOEventPoller::AddFd(int fd, std::function<void()> read_handler,\n                          std::function<void()> write_handler) {\n  AddFd(fd, &read_handler, &write_handler);\n}\n\nvoid IOEventPoller::AddFdWithOnlyReadHandler(int fd, std::function<void()> read_handler) {\n  AddFd(fd, &read_handler, nullptr);\n}\n\nvoid IOEventPoller::Start() { thread_ = std::thread(&IOEventPoller::EpollLoop, this); }\n\nvoid IOEventPoller::Stop() {\n  uint64_t break_epoll_loop_event = 1;\n  PCHECK(write(break_epoll_loop_fd_, &break_epoll_loop_event, 8) == 8);\n  thread_.join();\n}\n\nvoid IOEventPoller::AddFd(int fd, std::function<void()>* read_handler,\n                          std::function<void()>* write_handler) {\n  // Set Fd NONBLOCK\n  int opt = fcntl(fd, F_GETFL);\n  PCHECK(opt != -1);\n  PCHECK(fcntl(fd, F_SETFL, opt | O_NONBLOCK) == 0);\n  // Set CLOEXEC\n  opt = fcntl(fd, F_GETFD);\n  PCHECK(opt != -1);\n  PCHECK(fcntl(fd, F_SETFD, opt | FD_CLOEXEC) == 0);\n  // New IOHandler on Heap\n  IOHandler* io_handler = new IOHandler;\n  if (read_handler) { io_handler->read_handler = *read_handler; }\n  if (write_handler) { io_handler->write_handler = *write_handler; }\n  io_handler->fd = fd;\n  io_handlers_.push_front(io_handler);\n  // Add Fd to Epoll\n  epoll_event ep_event;\n  ep_event.events = EPOLLET;\n  if (read_handler) { ep_event.events |= EPOLLIN; }\n  if (write_handler) { ep_event.events |= EPOLLOUT; }\n  ep_event.data.ptr = io_handler;\n  PCHECK(epoll_ctl(epfd_, EPOLL_CTL_ADD, fd, &ep_event) == 0);\n}\n\nvoid IOEventPoller::EpollLoop() {\n  while (true) {\n    int event_num = epoll_wait(epfd_, ep_events_, max_event_num_, -1);\n    if (event_num == -1) {\n      PCHECK(errno == EINTR);\n      continue;\n    }\n    const epoll_event* cur_event = ep_events_;\n    for (int event_idx = 0; event_idx < event_num; ++event_idx, ++cur_event) {\n      auto io_handler = static_cast<IOHandler*>(cur_event->data.ptr);\n      PCHECK(!(cur_event->events & EPOLLERR)) << \"fd: \" << io_handler->fd;\n      if (io_handler->fd == break_epoll_loop_fd_) { return; }\n      if (cur_event->events & EPOLLIN) {\n        if (cur_event->events & EPOLLRDHUP) {\n          LOG(FATAL) << \"fd \" << io_handler->fd << \" 
closed by peer\";\n        } else {\n          io_handler->read_handler();\n        }\n      }\n      if (cur_event->events & EPOLLOUT) { io_handler->write_handler(); }\n    }\n  }\n}\n\n}  // namespace oneflow\n\n#endif  // __linux__\n"
  },
  {
    "path": "oneflow/core/comm_network/epoll/io_event_poller.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_IO_EVENT_POLLER_H_\n#define ONEFLOW_CORE_COMM_NETWORK_EPOLL_IO_EVENT_POLLER_H_\n\n#include \"oneflow/core/comm_network/epoll/socket_message.h\"\n\n#ifdef OF_PLATFORM_POSIX\n\nnamespace oneflow {\n\nclass IOEventPoller final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(IOEventPoller);\n  IOEventPoller();\n  ~IOEventPoller();\n\n  void AddFd(int fd, std::function<void()> read_handler, std::function<void()> write_handler);\n  void AddFdWithOnlyReadHandler(int fd, std::function<void()> read_handler);\n\n  void Start();\n  void Stop();\n\n private:\n  struct IOHandler {\n    IOHandler() {\n      read_handler = []() { UNIMPLEMENTED(); };\n      write_handler = []() { UNIMPLEMENTED(); };\n      fd = -1;\n    }\n    std::function<void()> read_handler;\n    std::function<void()> write_handler;\n    int fd;\n  };\n\n  void AddFd(int fd, std::function<void()>* read_handler, std::function<void()>* write_handler);\n\n  void EpollLoop();\n  static const int max_event_num_;\n\n  int epfd_;\n  epoll_event* ep_events_;\n  std::forward_list<IOHandler*> io_handlers_;\n  int break_epoll_loop_fd_;\n  std::thread thread_;\n};\n\n}  // namespace oneflow\n\n#endif  // OF_PLATFORM_POSIX\n\n#endif  // ONEFLOW_CORE_COMM_NETWORK_EPOLL_IO_EVENT_POLLER_H_\n"
  },
  {
    "path": "oneflow/core/comm_network/epoll/socket_helper.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef __linux__\n\n#include \"oneflow/core/comm_network/epoll/socket_helper.h\"\n\nnamespace oneflow {\n\nSocketHelper::SocketHelper(int sockfd, IOEventPoller* poller) {\n  read_helper_ = new SocketReadHelper(sockfd);\n  write_helper_ = new SocketWriteHelper(sockfd, poller);\n  poller->AddFd(\n      sockfd, [this]() { read_helper_->NotifyMeSocketReadable(); },\n      [this]() { write_helper_->NotifyMeSocketWriteable(); });\n}\n\nSocketHelper::~SocketHelper() {\n  delete read_helper_;\n  delete write_helper_;\n}\n\nvoid SocketHelper::AsyncWrite(const SocketMsg& msg) { write_helper_->AsyncWrite(msg); }\n\n}  // namespace oneflow\n\n#endif  // __linux__\n"
  },
  {
    "path": "oneflow/core/comm_network/epoll/socket_helper.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_HELPER_H_\n#define ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_HELPER_H_\n\n#include \"oneflow/core/comm_network/epoll/io_event_poller.h\"\n#include \"oneflow/core/comm_network/epoll/socket_read_helper.h\"\n#include \"oneflow/core/comm_network/epoll/socket_write_helper.h\"\n\n#ifdef OF_PLATFORM_POSIX\n\nnamespace oneflow {\n\nclass SocketHelper final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SocketHelper);\n  SocketHelper() = delete;\n  ~SocketHelper();\n\n  SocketHelper(int sockfd, IOEventPoller* poller);\n\n  void AsyncWrite(const SocketMsg& msg);\n\n private:\n  SocketReadHelper* read_helper_;\n  SocketWriteHelper* write_helper_;\n};\n\n}  // namespace oneflow\n\n#endif  // OF_PLATFORM_POSIX\n\n#endif  // ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_HELPER_H_\n"
  },
  {
    "path": "oneflow/core/comm_network/epoll/socket_memory_desc.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MEMORY_DESC_H_\n#define ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MEMORY_DESC_H_\n\n#include \"oneflow/core/comm_network/epoll/socket_memory_desc.h\"\n\n#ifdef OF_PLATFORM_POSIX\n\nnamespace oneflow {\n\nstruct SocketMemDesc {\n  void* mem_ptr;\n  size_t byte_size;\n};\n\n}  // namespace oneflow\n\n#endif  // OF_PLATFORM_POSIX\n\n#endif  // ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MEMORY_DESC_H_\n"
  },
  {
    "path": "oneflow/core/comm_network/epoll/socket_message.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MESSAGE_H_\n#define ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MESSAGE_H_\n\n#include \"oneflow/core/common/platform.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/comm_network/comm_network.h\"\n\n#ifdef OF_PLATFORM_POSIX\n\n#include <arpa/inet.h>\n#include <fcntl.h>\n#include <netinet/in.h>\n#include <netinet/ip.h>\n#include <sys/epoll.h>\n#include <sys/socket.h>\n#include <sys/types.h>\n#include <unistd.h>\n#include \"oneflow/core/lazy/actor/actor_message.h\"\n#include \"oneflow/core/transport/transport_message.h\"\n\nnamespace oneflow {\n\n#define SOCKET_MSG_TYPE_SEQ                         \\\n  OF_PP_MAKE_TUPLE_SEQ(RequestWrite, request_write) \\\n  OF_PP_MAKE_TUPLE_SEQ(RequestRead, request_read)   \\\n  OF_PP_MAKE_TUPLE_SEQ(Actor, actor)                \\\n  OF_PP_MAKE_TUPLE_SEQ(Transport, transport)\n\nenum class SocketMsgType {\n#define MAKE_ENTRY(x, y) k##x,\n  OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ)\n#undef MAKE_ENTRY\n};\n\nstruct RequestWriteMsg {\n  void* src_token;\n  int64_t dst_machine_id;\n  void* dst_token;\n  void* read_id;\n};\n\nstruct RequestReadMsg {\n  void* src_token;\n  void* dst_token;\n  void* read_id;\n};\n\nstruct SocketMsg {\n  SocketMsgType msg_type;\n  union {\n#define MAKE_ENTRY(x, y) x##Msg y##_msg;\n    OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ)\n#undef MAKE_ENTRY\n  };\n};\n\nusing CallBackList = std::list<std::function<void()>>;\n\n}  // namespace oneflow\n\n#endif  // OF_PLATFORM_POSIX\n\n#endif  // ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MESSAGE_H_\n"
  },
  {
    "path": "oneflow/core/comm_network/epoll/socket_read_helper.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef __linux__\n\n#include \"oneflow/core/comm_network/epoll/socket_read_helper.h\"\n#include \"oneflow/core/lazy/actor/actor_message_bus.h\"\n#include \"oneflow/core/comm_network/epoll/epoll_comm_network.h\"\n#include \"oneflow/core/transport/transport.h\"\n\n#include <netinet/tcp.h>\n\nnamespace oneflow {\n\nSocketReadHelper::~SocketReadHelper() {\n  // do nothing\n}\n\nSocketReadHelper::SocketReadHelper(int sockfd) {\n  sockfd_ = sockfd;\n  SwitchToMsgHeadReadHandle();\n}\n\nvoid SocketReadHelper::NotifyMeSocketReadable() { ReadUntilSocketNotReadable(); }\n\nvoid SocketReadHelper::SwitchToMsgHeadReadHandle() {\n  cur_read_handle_ = &SocketReadHelper::MsgHeadReadHandle;\n  read_ptr_ = reinterpret_cast<char*>(&cur_msg_);\n  read_size_ = sizeof(cur_msg_);\n}\n\nvoid SocketReadHelper::ReadUntilSocketNotReadable() {\n  while ((this->*cur_read_handle_)()) {}\n}\n\nbool SocketReadHelper::MsgHeadReadHandle() {\n  return DoCurRead(&SocketReadHelper::SetStatusWhenMsgHeadDone);\n}\n\nbool SocketReadHelper::MsgBodyReadHandle() {\n  return DoCurRead(&SocketReadHelper::SetStatusWhenMsgBodyDone);\n}\n\nbool SocketReadHelper::DoCurRead(void (SocketReadHelper::*set_cur_read_done)()) {\n  ssize_t n = read(sockfd_, read_ptr_, read_size_);\n  const int val = 1;\n  PCHECK(setsockopt(sockfd_, IPPROTO_TCP, TCP_QUICKACK, (char*)&val, sizeof(int)) == 0);\n  if (n == read_size_) {\n    (this->*set_cur_read_done)();\n    return true;\n  } else if (n >= 0) {\n    read_ptr_ += n;\n    read_size_ -= n;\n    return true;\n  } else {\n    CHECK_EQ(n, -1);\n    PCHECK(errno == EAGAIN || errno == EWOULDBLOCK);\n    return false;\n  }\n}\n\nvoid SocketReadHelper::SetStatusWhenMsgHeadDone() {\n  switch (cur_msg_.msg_type) {\n#define MAKE_ENTRY(x, y) \\\n  case SocketMsgType::k##x: SetStatusWhen##x##MsgHeadDone(); break;\n    OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ);\n#undef MAKE_ENTRY\n    default: UNIMPLEMENTED();\n  }\n}\n\nvoid SocketReadHelper::SetStatusWhenMsgBodyDone() {\n  if (cur_msg_.msg_type == SocketMsgType::kRequestRead) {\n    Singleton<EpollCommNet>::Get()->ReadDone(cur_msg_.request_read_msg.read_id);\n  }\n  SwitchToMsgHeadReadHandle();\n}\n\nvoid SocketReadHelper::SetStatusWhenRequestWriteMsgHeadDone() {\n  SocketMsg msg_to_send;\n  msg_to_send.msg_type = SocketMsgType::kRequestRead;\n  msg_to_send.request_read_msg.src_token = cur_msg_.request_write_msg.src_token;\n  msg_to_send.request_read_msg.dst_token = cur_msg_.request_write_msg.dst_token;\n  msg_to_send.request_read_msg.read_id = cur_msg_.request_write_msg.read_id;\n  Singleton<EpollCommNet>::Get()->SendSocketMsg(cur_msg_.request_write_msg.dst_machine_id,\n                                                msg_to_send);\n  SwitchToMsgHeadReadHandle();\n}\n\nvoid SocketReadHelper::SetStatusWhenRequestReadMsgHeadDone() {\n  auto mem_desc = static_cast<const SocketMemDesc*>(cur_msg_.request_read_msg.dst_token);\n  read_ptr_ = 
reinterpret_cast<char*>(mem_desc->mem_ptr);\n  read_size_ = mem_desc->byte_size;\n  cur_read_handle_ = &SocketReadHelper::MsgBodyReadHandle;\n}\n\nvoid SocketReadHelper::SetStatusWhenActorMsgHeadDone() {\n  Singleton<ActorMsgBus>::Get()->SendMsgWithoutCommNet(cur_msg_.actor_msg);\n  SwitchToMsgHeadReadHandle();\n}\n\nvoid SocketReadHelper::SetStatusWhenTransportMsgHeadDone() {\n  Singleton<Transport>::Get()->EnqueueTransportMsg(cur_msg_.transport_msg);\n  SwitchToMsgHeadReadHandle();\n}\n\n}  // namespace oneflow\n\n#endif  // __linux__\n"
  },
  {
    "path": "oneflow/core/comm_network/epoll/socket_read_helper.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_READ_HELPER_H_\n#define ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_READ_HELPER_H_\n\n#include \"oneflow/core/comm_network/epoll/socket_message.h\"\n\n#ifdef OF_PLATFORM_POSIX\n\nnamespace oneflow {\n\nclass SocketReadHelper final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SocketReadHelper);\n  SocketReadHelper() = delete;\n  ~SocketReadHelper();\n\n  SocketReadHelper(int sockfd);\n\n  void NotifyMeSocketReadable();\n\n private:\n  void SwitchToMsgHeadReadHandle();\n  void ReadUntilSocketNotReadable();\n\n  bool MsgHeadReadHandle();\n  bool MsgBodyReadHandle();\n\n  bool DoCurRead(void (SocketReadHelper::*set_cur_read_done)());\n  void SetStatusWhenMsgHeadDone();\n  void SetStatusWhenMsgBodyDone();\n\n#define MAKE_ENTRY(x, y) void SetStatusWhen##x##MsgHeadDone();\n  OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ);\n#undef MAKE_ENTRY\n\n  int sockfd_;\n\n  SocketMsg cur_msg_;\n  bool (SocketReadHelper::*cur_read_handle_)();\n  char* read_ptr_;\n  size_t read_size_;\n};\n\n}  // namespace oneflow\n\n#endif  // OF_PLATFORM_POSIX\n\n#endif  // ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_READ_HELPER_H_\n"
  },
  {
    "path": "oneflow/core/comm_network/epoll/socket_write_helper.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef __linux__\n\n#include \"oneflow/core/comm_network/epoll/socket_write_helper.h\"\n#include \"oneflow/core/comm_network/epoll/socket_memory_desc.h\"\n\n#include <sys/eventfd.h>\n\nnamespace oneflow {\n\nSocketWriteHelper::~SocketWriteHelper() {\n  delete cur_msg_queue_;\n  cur_msg_queue_ = nullptr;\n  {\n    std::unique_lock<std::mutex> lck(pending_msg_queue_mtx_);\n    delete pending_msg_queue_;\n    pending_msg_queue_ = nullptr;\n  }\n}\n\nSocketWriteHelper::SocketWriteHelper(int sockfd, IOEventPoller* poller) {\n  sockfd_ = sockfd;\n  queue_not_empty_fd_ = eventfd(0, 0);\n  PCHECK(queue_not_empty_fd_ != -1);\n  poller->AddFdWithOnlyReadHandler(queue_not_empty_fd_,\n                                   std::bind(&SocketWriteHelper::ProcessQueueNotEmptyEvent, this));\n  cur_msg_queue_ = new std::queue<SocketMsg>;\n  pending_msg_queue_ = new std::queue<SocketMsg>;\n  cur_write_handle_ = &SocketWriteHelper::InitMsgWriteHandle;\n  write_ptr_ = nullptr;\n  write_size_ = 0;\n}\n\nvoid SocketWriteHelper::AsyncWrite(const SocketMsg& msg) {\n  pending_msg_queue_mtx_.lock();\n  bool need_send_event = pending_msg_queue_->empty();\n  pending_msg_queue_->push(msg);\n  pending_msg_queue_mtx_.unlock();\n  if (need_send_event) { SendQueueNotEmptyEvent(); }\n}\n\nvoid SocketWriteHelper::NotifyMeSocketWriteable() { WriteUntilMsgQueueEmptyOrSocketNotWriteable(); }\n\nvoid SocketWriteHelper::SendQueueNotEmptyEvent() {\n  uint64_t event_num = 1;\n  PCHECK(write(queue_not_empty_fd_, &event_num, 8) == 8);\n}\n\nvoid SocketWriteHelper::ProcessQueueNotEmptyEvent() {\n  uint64_t event_num = 0;\n  PCHECK(read(queue_not_empty_fd_, &event_num, 8) == 8);\n  WriteUntilMsgQueueEmptyOrSocketNotWriteable();\n}\n\nvoid SocketWriteHelper::WriteUntilMsgQueueEmptyOrSocketNotWriteable() {\n  while ((this->*cur_write_handle_)()) {}\n}\n\nbool SocketWriteHelper::InitMsgWriteHandle() {\n  if (cur_msg_queue_->empty()) {\n    {\n      std::unique_lock<std::mutex> lck(pending_msg_queue_mtx_);\n      std::swap(cur_msg_queue_, pending_msg_queue_);\n    }\n    if (cur_msg_queue_->empty()) { return false; }\n  }\n  cur_msg_ = cur_msg_queue_->front();\n  cur_msg_queue_->pop();\n  write_ptr_ = reinterpret_cast<const char*>(&cur_msg_);\n  write_size_ = sizeof(cur_msg_);\n  cur_write_handle_ = &SocketWriteHelper::MsgHeadWriteHandle;\n  return true;\n}\n\nbool SocketWriteHelper::MsgHeadWriteHandle() {\n  return DoCurWrite(&SocketWriteHelper::SetStatusWhenMsgHeadDone);\n}\n\nbool SocketWriteHelper::MsgBodyWriteHandle() {\n  return DoCurWrite(&SocketWriteHelper::SetStatusWhenMsgBodyDone);\n}\n\nbool SocketWriteHelper::DoCurWrite(void (SocketWriteHelper::*set_cur_write_done)()) {\n  ssize_t n = write(sockfd_, write_ptr_, write_size_);\n  if (n == write_size_) {\n    (this->*set_cur_write_done)();\n    return true;\n  } else if (n >= 0) {\n    write_ptr_ += n;\n    write_size_ -= n;\n    return true;\n  } else {\n    CHECK_EQ(n, 
-1);\n    PCHECK(errno == EAGAIN || errno == EWOULDBLOCK);\n    return false;\n  }\n}\n\nvoid SocketWriteHelper::SetStatusWhenMsgHeadDone() {\n  switch (cur_msg_.msg_type) {\n#define MAKE_ENTRY(x, y) \\\n  case SocketMsgType::k##x: return SetStatusWhen##x##MsgHeadDone();\n    OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ);\n#undef MAKE_ENTRY\n    default: UNIMPLEMENTED();\n  }\n}\n\nvoid SocketWriteHelper::SetStatusWhenMsgBodyDone() {\n  cur_write_handle_ = &SocketWriteHelper::InitMsgWriteHandle;\n}\n\nvoid SocketWriteHelper::SetStatusWhenRequestWriteMsgHeadDone() {\n  cur_write_handle_ = &SocketWriteHelper::InitMsgWriteHandle;\n}\n\nvoid SocketWriteHelper::SetStatusWhenRequestReadMsgHeadDone() {\n  const void* src_token = cur_msg_.request_read_msg.src_token;\n  auto src_mem_desc = static_cast<const SocketMemDesc*>(src_token);\n  write_ptr_ = reinterpret_cast<const char*>(src_mem_desc->mem_ptr);\n  write_size_ = src_mem_desc->byte_size;\n  cur_write_handle_ = &SocketWriteHelper::MsgBodyWriteHandle;\n}\n\nvoid SocketWriteHelper::SetStatusWhenActorMsgHeadDone() {\n  cur_write_handle_ = &SocketWriteHelper::InitMsgWriteHandle;\n}\n\nvoid SocketWriteHelper::SetStatusWhenTransportMsgHeadDone() {\n  cur_write_handle_ = &SocketWriteHelper::InitMsgWriteHandle;\n}\n\n}  // namespace oneflow\n\n#endif  // __linux__\n"
  },
  {
    "path": "oneflow/core/comm_network/epoll/socket_write_helper.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_WRITE_HELPER_H_\n#define ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_WRITE_HELPER_H_\n\n#include \"oneflow/core/comm_network/epoll/io_event_poller.h\"\n#include \"oneflow/core/comm_network/epoll/socket_message.h\"\n\n#ifdef OF_PLATFORM_POSIX\n\nnamespace oneflow {\n\nclass SocketWriteHelper final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SocketWriteHelper);\n  SocketWriteHelper() = delete;\n  ~SocketWriteHelper();\n\n  SocketWriteHelper(int sockfd, IOEventPoller* poller);\n\n  void AsyncWrite(const SocketMsg& msg);\n\n  void NotifyMeSocketWriteable();\n\n private:\n  void SendQueueNotEmptyEvent();\n  void ProcessQueueNotEmptyEvent();\n\n  void WriteUntilMsgQueueEmptyOrSocketNotWriteable();\n  bool InitMsgWriteHandle();\n  bool MsgHeadWriteHandle();\n  bool MsgBodyWriteHandle();\n\n  bool DoCurWrite(void (SocketWriteHelper::*set_cur_write_done)());\n  void SetStatusWhenMsgHeadDone();\n  void SetStatusWhenMsgBodyDone();\n\n#define MAKE_ENTRY(x, y) void SetStatusWhen##x##MsgHeadDone();\n  OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ);\n#undef MAKE_ENTRY\n\n  int sockfd_;\n  int queue_not_empty_fd_;\n\n  std::queue<SocketMsg>* cur_msg_queue_;\n\n  std::mutex pending_msg_queue_mtx_;\n  std::queue<SocketMsg>* pending_msg_queue_;\n\n  SocketMsg cur_msg_;\n  bool (SocketWriteHelper::*cur_write_handle_)();\n  const char* write_ptr_;\n  size_t write_size_;\n};\n\n}  // namespace oneflow\n\n#endif  // OF_PLATFORM_POSIX\n\n#endif  // ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_WRITE_HELPER_H_\n"
  },
  {
    "path": "oneflow/core/comm_network/ibverbs/ibverbs.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage IBVerbsConnectionInfo {\n  required uint32 lid = 1;\n  required uint32 qp_num = 2;\n  required uint64 subnet_prefix = 3;\n  required uint64 interface_id = 4;\n  required uint32 port_num = 5;\n  required int32 mtu = 6;\n}\n\n"
  },
  {
    "path": "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h\"\n#include \"oneflow/core/control/ctrl_client.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/platform/include/ibv.h\"\n#include \"oneflow/core/lazy/actor/actor_message_bus.h\"\n\n#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)\n\nnamespace oneflow {\n\nnamespace {\n\nstd::string GenTokensMsgKey(int64_t machine_id) {\n  return \"IBVerbsTokensMsg/\" + std::to_string(machine_id);\n}\n\nstd::string GenConnInfoKey(int64_t src_machine_id, int64_t dst_machine_id) {\n  return \"IBVerbsConnInfo/\" + std::to_string(src_machine_id) + \"/\" + std::to_string(dst_machine_id);\n}\n\nvoid IBVForkInit() {\n  if (ibv::IsAvailable()) {\n    if (ibv::wrapper.ibv_fork_init() != 0) { std::cerr << \"ibv_fork_init failed\\n\"; }\n  } else {\n    std::cerr << \"libibverbs not available, ibv_fork_init skipped\\n\";\n  }\n}\n\nvoid ParseUserDevicePort(std::string* device_name, int* port) {\n  std::string user_device_port = GetStringFromEnv(\"ONEFLOW_COMM_NET_IB_HCA\", \"\");\n  if (user_device_port.empty()) {\n    *device_name = \"\";\n    *port = 0;\n    return;\n  } else {\n    const std::string::size_type pos = user_device_port.find(':', 0);\n    if (pos == std::string::npos) {\n      *device_name = user_device_port;\n      *port = 0;\n      return;\n    } else {\n      *device_name = user_device_port.substr(0, pos);\n      *port = std::strtol(user_device_port.data() + pos + 1, nullptr, 10);\n      return;\n    }\n  }\n}\n\n}  // namespace\n\nIBVerbsCommNet::~IBVerbsCommNet() {\n  while (poll_exit_flag_.test_and_set() == true) {}\n  poll_thread_.join();\n  for (IBVerbsQP* qp : qp_vec_) {\n    if (qp) { delete qp; }\n  }\n  PCHECK(ibv::wrapper.ibv_destroy_cq(cq_) == 0);\n  PCHECK(ibv::wrapper.ibv_dealloc_pd(pd_) == 0);\n  CHECK_EQ(ibv::wrapper.ibv_close_device(context_), 0)\n      << \"Error, failed to close the IB device \"\n      << ibv::wrapper.ibv_get_device_name(context_->device);\n}\n\nvoid IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) {\n  IBVerbsActorMsgWrapper msg_wrapper;\n  msg_wrapper.msg = msg;\n  if (msg.IsDataRegstMsgToConsumer()) {\n    auto* mem_desc = reinterpret_cast<IBVerbsMemDesc*>(msg.regst()->comm_net_token());\n    CHECK(mem_desc != nullptr);\n    msg_wrapper.rma_desc.mem_ptr = reinterpret_cast<uint64_t>(mem_desc->mem_ptr());\n    msg_wrapper.rma_desc.mem_size = mem_desc->mem_size();\n    msg_wrapper.rma_desc.mr_rkey = mem_desc->mr()->rkey;\n  }\n  qp_vec_.at(dst_machine_id)->PostSendRequest(msg_wrapper);\n}\n\nvoid IBVerbsCommNet::RecvActorMsg(const IBVerbsActorMsgWrapper& msg_wrapper) {\n  ActorMsg new_msg = msg_wrapper.msg;\n  if (msg_wrapper.msg.IsDataRegstMsgToConsumer()) {\n    std::lock_guard<std::mutex> lock(remote_regst2rma_desc_mutex_);\n    
auto& desc = remote_regst2rma_desc_[std::make_pair(\n        msg_wrapper.msg.src_actor_id(), reinterpret_cast<uint64_t>(msg_wrapper.msg.regst()))];\n    if (!desc) { desc.reset(new IBVerbsCommNetRMADesc); }\n    *desc = msg_wrapper.rma_desc;\n    new_msg.set_comm_net_token(desc.get());\n  }\n  Singleton<ActorMsgBus>::Get()->SendMsgWithoutCommNet(new_msg);\n}\n\nIBVerbsCommNet::IBVerbsCommNet() : CommNetIf(), poll_exit_flag_(ATOMIC_FLAG_INIT) {\n  int num_device;\n  ibv_device** device_list = ibv::wrapper.ibv_get_device_list(&num_device);\n  CHECK_GT(num_device, 0) << \"No IB device found\";\n  PCHECK(device_list);\n  std::string user_device;\n  int user_port;\n  ParseUserDevicePort(&user_device, &user_port);\n  ibv_device* device = nullptr;\n  if (user_device.empty()) {\n    device = device_list[0];\n  } else {\n    for (int i = 0; i < num_device; ++i) {\n      if (device_list[i]->name == user_device) {\n        device = device_list[i];\n        break;\n      }\n    }\n    CHECK(device != nullptr) << \"No IB device match \" << user_device;\n  }\n  context_ = ibv::wrapper.ibv_open_device(device);\n  CHECK(context_ != NULL) << \"Error, failed to open the IB device \"\n                          << ibv::wrapper.ibv_get_device_name(device);\n  ibv::wrapper.ibv_free_device_list(device_list);\n  pd_ = ibv::wrapper.ibv_alloc_pd(context_);\n  CHECK(pd_) << \"Error, ibv_alloc_pd() allocates a Protection Domain (PD) failed\";\n  ibv_device_attr device_attr{};\n  PCHECK(ibv::wrapper.ibv_query_device(context_, &device_attr) == 0);\n  cq_ = ibv::wrapper.ibv_create_cq(context_, device_attr.max_cqe, nullptr, nullptr, 0);\n  PCHECK(cq_);\n  ibv_port_attr port_attr{};\n  const uint8_t port = user_port == 0 ? 1 : user_port;\n  PCHECK(ibv::wrapper.ibv_query_port_wrap(context_, port, &port_attr) == 0);\n  ibv_gid gid{};\n  const int64_t gid_index = ParseIntegerFromEnv(\"ONEFLOW_COMM_NET_IB_GID_INDEX\", 0);\n  PCHECK(ibv::wrapper.ibv_query_gid(context_, port, gid_index, &gid) == 0);\n  VLOG(1) << \"Using IB device \" << device->name << \" port \" << static_cast<int32_t>(port)\n          << \" gid index \" << gid_index;\n  int64_t this_machine_id = GlobalProcessCtx::Rank();\n  qp_vec_.assign(Singleton<ResourceDesc, ForEnv>::Get()->process_ranks().size(), nullptr);\n  for (int64_t peer_id : peer_machine_id()) {\n    IBVerbsQP* cur_qp = new IBVerbsQP(context_, pd_, port_attr, port, cq_, cq_);\n    qp_vec_.at(peer_id) = cur_qp;\n    IBVerbsConnectionInfo conn_info;\n    conn_info.set_lid(port_attr.lid);\n    conn_info.set_qp_num(cur_qp->qp_num());\n    conn_info.set_subnet_prefix(gid.global.subnet_prefix);\n    conn_info.set_interface_id(gid.global.interface_id);\n    conn_info.set_port_num(port);\n    conn_info.set_mtu(static_cast<int>(port_attr.active_mtu));\n    Singleton<CtrlClient>::Get()->PushKV(GenConnInfoKey(this_machine_id, peer_id), conn_info);\n  }\n  for (int64_t peer_id : peer_machine_id()) {\n    IBVerbsConnectionInfo conn_info;\n    Singleton<CtrlClient>::Get()->PullKV(GenConnInfoKey(peer_id, this_machine_id), &conn_info);\n    if (conn_info.lid() == 0) {\n      VLOG(2) << \"Connecting to peer \" << peer_id << \" port \" << conn_info.port_num() << \" qpn \"\n              << conn_info.qp_num() << \" gid index \" << gid_index << \" spn \"\n              << conn_info.subnet_prefix() << \" iid \" << conn_info.interface_id() << \" mtu \"\n              << conn_info.mtu();\n    } else {\n      VLOG(2) << \"Connecting to peer \" << peer_id << \" port \" << conn_info.port_num() << \" qpn \"\n              
<< conn_info.qp_num() << \" lid \" << conn_info.interface_id() << \" mtu \"\n              << conn_info.mtu();\n    }\n    qp_vec_.at(peer_id)->Connect(conn_info);\n    VLOG(1) << \"Connected to peer \" << peer_id;\n  }\n  OF_ENV_BARRIER();\n  for (int64_t peer_id : peer_machine_id()) {\n    qp_vec_.at(peer_id)->PostAllRecvRequest();\n    Singleton<CtrlClient>::Get()->ClearKV(GenConnInfoKey(this_machine_id, peer_id));\n  }\n  OF_ENV_BARRIER();\n  poll_thread_ = std::thread(&IBVerbsCommNet::PollCQ, this);\n  OF_ENV_BARRIER();\n}\n\nvoid IBVerbsCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token,\n                            void* dst_token) {\n  qp_vec_.at(src_machine_id)\n      ->PostReadRequest(*reinterpret_cast<IBVerbsCommNetRMADesc*>(src_token),\n                        *static_cast<const IBVerbsMemDesc*>(dst_token), read_id);\n}\n\nvoid IBVerbsCommNet::PollCQ() {\n  std::vector<ibv_wc> wc_vec(max_poll_wc_num_);\n  while (poll_exit_flag_.test_and_set() == false) {\n    poll_exit_flag_.clear();\n    int32_t found_wc_num = ibv_poll_cq(cq_, max_poll_wc_num_, wc_vec.data());\n    CHECK_GE(found_wc_num, 0);\n    FOR_RANGE(int32_t, i, 0, found_wc_num) {\n      const ibv_wc& wc = wc_vec.at(i);\n      CHECK_EQ(wc.status, IBV_WC_SUCCESS) << wc.opcode;\n      WorkRequestId* wr_id = reinterpret_cast<WorkRequestId*>(wc.wr_id);\n      IBVerbsQP* qp = wr_id->qp;\n      switch (wc.opcode) {\n        case IBV_WC_RDMA_READ: {\n          qp->ReadDone(wr_id);\n          break;\n        }\n        case IBV_WC_SEND: {\n          qp->SendDone(wr_id);\n          break;\n        }\n        case IBV_WC_RECV: {\n          qp->RecvDone(wr_id);\n          break;\n        }\n        default: UNIMPLEMENTED();\n      }\n    }\n  }\n}\n\nconst int32_t IBVerbsCommNet::max_poll_wc_num_ = 32;\n\nCOMMAND(IBVForkInit());\n\n}  // namespace oneflow\n\n#endif  // WITH_RDMA && OF_PLATFORM_POSIX\n"
  },
  {
    "path": "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_COMM_NETWORK_H_\n#define ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_COMM_NETWORK_H_\n\n#include \"oneflow/core/common/platform.h\"\n#include \"oneflow/core/comm_network/comm_network.h\"\n#include \"oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.h\"\n#include \"oneflow/core/comm_network/ibverbs/ibverbs_qp.h\"\n\n#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)\n\n#include <netdb.h>\n#include <arpa/inet.h>\n\nnamespace oneflow {\n\nclass IBVerbsCommNet final : public CommNetIf<IBVerbsMemDesc> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(IBVerbsCommNet);\n  ~IBVerbsCommNet();\n\n  void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;\n  void RecvActorMsg(const IBVerbsActorMsgWrapper& msg_wrapper);\n\n private:\n  friend class Singleton<IBVerbsCommNet>;\n  IBVerbsCommNet();\n\n  IBVerbsMemDesc* NewMemDesc(void* ptr, size_t byte_size) override {\n    return new IBVerbsMemDesc(pd_, ptr, byte_size);\n  }\n\n  void DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) override;\n  void PollCQ();\n\n  static const int32_t max_poll_wc_num_;\n\n  ibv_context* context_;\n  ibv_pd* pd_;\n  ibv_cq* cq_;\n  std::vector<IBVerbsQP*> qp_vec_;\n  std::atomic_flag poll_exit_flag_;\n  std::thread poll_thread_;\n  HashMap<std::pair<int64_t, uint64_t>, std::shared_ptr<IBVerbsCommNetRMADesc>>\n      remote_regst2rma_desc_;\n  std::mutex remote_regst2rma_desc_mutex_;\n};\n\n}  // namespace oneflow\n\n#endif  // WITH_RDMA && OF_PLATFORM_POSIX\n\n#endif  // ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_COMM_NETWORK_H_\n"
  },
  {
    "path": "oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/platform/include/ibv.h\"\n\n#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)\n\nnamespace oneflow {\n\nIBVerbsMemDesc::IBVerbsMemDesc(ibv_pd* pd, void* mem_ptr, size_t byte_size)\n    : mem_ptr_(mem_ptr), mem_size_(byte_size) {\n  mr_ = ibv::wrapper.ibv_reg_mr_wrap(\n      pd, mem_ptr, byte_size,\n      IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ);\n  PCHECK(mr_);\n}\n\nIBVerbsMemDesc::~IBVerbsMemDesc() { PCHECK(ibv::wrapper.ibv_dereg_mr(mr_) == 0); }\n\n}  // namespace oneflow\n\n#endif  // WITH_RDMA && OF_PLATFORM_POSIX\n"
  },
  {
    "path": "oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_MEMORY_DESC_H_\n#define ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_MEMORY_DESC_H_\n\n#include \"oneflow/core/common/platform.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/comm_network/ibverbs/ibverbs.pb.h\"\n\n#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)\n\n#include <infiniband/verbs.h>\n\nnamespace oneflow {\n\nclass IBVerbsMemDesc final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(IBVerbsMemDesc);\n  IBVerbsMemDesc() = delete;\n  IBVerbsMemDesc(ibv_pd* pd, void* mem_ptr, size_t byte_size);\n  ~IBVerbsMemDesc();\n\n  void* mem_ptr() const { return mem_ptr_; }\n\n  size_t mem_size() const { return mem_size_; }\n\n  const ibv_mr* mr() const { return mr_; }\n\n private:\n  ibv_mr* mr_;\n  void* mem_ptr_;\n  uint64_t mem_size_;\n};\n\n}  // namespace oneflow\n\n#endif  // WITH_RDMA && OF_PLATFORM_POSIX\n\n#endif  // ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_MEMORY_DESC_H_\n"
  },
  {
    "path": "oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/comm_network/ibverbs/ibverbs_qp.h\"\n#include \"oneflow/core/comm_network/comm_network.h\"\n#include \"oneflow/core/lazy/actor/actor_message_bus.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/platform/include/ibv.h\"\n#include \"oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h\"\n\n#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr uint32_t kDefaultQueueDepth = 1024;\nconstexpr uint64_t kDefaultMemBlockSize = 8388608;  // 8M\n\n}  // namespace\n\nIBVerbsQP::IBVerbsQP(ibv_context* ctx, ibv_pd* pd, const struct ibv_port_attr& port_attr,\n                     uint8_t port_num, ibv_cq* send_cq, ibv_cq* recv_cq) {\n  // ctx_, pd_\n  ctx_ = ctx;\n  pd_ = pd;\n  port_num_ = port_num;\n  // qp_\n  ibv_device_attr device_attr{};\n  PCHECK(ibv::wrapper.ibv_query_device(ctx, &device_attr) == 0);\n  const int64_t user_queue_depth =\n      ParseIntegerFromEnv(\"ONEFLOW_COMM_NET_IB_QUEUE_DEPTH\", kDefaultQueueDepth);\n  const uint32_t queue_depth = std::min<uint32_t>(device_attr.max_qp_wr, user_queue_depth);\n  ibv_qp_init_attr qp_init_attr{};\n  qp_init_attr.qp_context = nullptr;\n  qp_init_attr.send_cq = send_cq;\n  qp_init_attr.recv_cq = recv_cq;\n  qp_init_attr.srq = nullptr;\n  qp_init_attr.cap.max_send_wr = queue_depth;\n  qp_init_attr.cap.max_recv_wr = queue_depth;\n  qp_init_attr.cap.max_send_sge = 1;\n  qp_init_attr.cap.max_recv_sge = 1;\n  qp_init_attr.cap.max_inline_data = 0;\n  qp_init_attr.qp_type = IBV_QPT_RC;\n  qp_init_attr.sq_sig_all = 1;\n  qp_ = ibv::wrapper.ibv_create_qp(pd, &qp_init_attr);\n  PCHECK(qp_);\n  // recv_msg_buf_\n  recv_msg_buf_.assign(queue_depth, nullptr);\n  FOR_RANGE(size_t, i, 0, recv_msg_buf_.size()) { recv_msg_buf_.at(i) = new ActorMsgMR(pd_); }\n  // send_msg_buf_\n  CHECK(send_msg_buf_.empty());\n  num_outstanding_send_wr_ = 0;\n  max_outstanding_send_wr_ = queue_depth;\n  read_block_size_ =\n      ParseIntegerFromEnv(\"ONEFLOW_COMM_NET_IB_MEM_BLOCK_SIZE\", kDefaultMemBlockSize);\n  mtu_ = static_cast<int32_t>(port_attr.active_mtu);\n}\n\nIBVerbsQP::~IBVerbsQP() {\n  PCHECK(ibv::wrapper.ibv_destroy_qp(qp_) == 0);\n  while (send_msg_buf_.empty() == false) {\n    delete send_msg_buf_.front();\n    send_msg_buf_.pop();\n  }\n  for (ActorMsgMR* msg_mr : recv_msg_buf_) { delete msg_mr; }\n}\n\nvoid IBVerbsQP::Connect(const IBVerbsConnectionInfo& peer_info) {\n  ibv_qp_attr qp_attr{};\n\n  // IBV_QPS_INIT\n  memset(&qp_attr, 0, sizeof(ibv_qp_attr));\n  qp_attr.qp_state = IBV_QPS_INIT;\n  // TODO(liujuncheng): Make pkey_index configurable\n  qp_attr.pkey_index = 0;\n  qp_attr.port_num = port_num_;\n  qp_attr.qp_access_flags =\n      IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;\n  PCHECK(ibv::wrapper.ibv_modify_qp(\n             qp_, &qp_attr, IBV_QP_STATE | 
IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS)\n         == 0);\n\n  // IBV_QPS_RTR\n  memset(&qp_attr, 0, sizeof(ibv_qp_attr));\n  qp_attr.qp_state = IBV_QPS_RTR;\n  // TODO(liujuncheng): Make sl configurable;\n  qp_attr.ah_attr.sl = 0;\n  qp_attr.ah_attr.src_path_bits = 0;\n  if (peer_info.lid() == 0) {\n    qp_attr.ah_attr.is_global = 1;\n    qp_attr.ah_attr.grh.dgid.global.subnet_prefix = peer_info.subnet_prefix();\n    qp_attr.ah_attr.grh.dgid.global.interface_id = peer_info.interface_id();\n    qp_attr.ah_attr.grh.flow_label = 0;\n    const int64_t gid_index = ParseIntegerFromEnv(\"ONEFLOW_COMM_NET_IB_GID_INDEX\", 0);\n    qp_attr.ah_attr.grh.sgid_index = gid_index;\n    qp_attr.ah_attr.grh.hop_limit = 255;\n    // TODO(liujuncheng): Make traffic_class configurable;\n    qp_attr.ah_attr.grh.traffic_class = 0;\n  } else {\n    qp_attr.ah_attr.is_global = 0;\n    qp_attr.ah_attr.dlid = peer_info.lid();\n  }\n  qp_attr.ah_attr.port_num = peer_info.port_num();\n  qp_attr.path_mtu = static_cast<ibv_mtu>(std::min(peer_info.mtu(), mtu_));\n  qp_attr.dest_qp_num = peer_info.qp_num();\n  qp_attr.rq_psn = 0;\n  qp_attr.max_dest_rd_atomic = 1;\n  qp_attr.min_rnr_timer = 12;\n  PCHECK(ibv::wrapper.ibv_modify_qp(qp_, &qp_attr,\n                                    IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN\n                                        | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC\n                                        | IBV_QP_MIN_RNR_TIMER)\n         == 0);\n\n  // IBV_QPS_RTS\n  memset(&qp_attr, 0, sizeof(ibv_qp_attr));\n  qp_attr.qp_state = IBV_QPS_RTS;\n  qp_attr.sq_psn = 0;\n  qp_attr.max_rd_atomic = 1;\n  qp_attr.retry_cnt = 7;\n  qp_attr.rnr_retry = 7;\n  qp_attr.timeout = 14;\n  PCHECK(ibv::wrapper.ibv_modify_qp(qp_, &qp_attr,\n                                    IBV_QP_STATE | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC\n                                        | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_TIMEOUT)\n         == 0);\n}\n\nvoid IBVerbsQP::PostAllRecvRequest() {\n  for (ActorMsgMR* msg_mr : recv_msg_buf_) { PostRecvRequest(msg_mr); }\n}\n\nvoid IBVerbsQP::PostReadRequest(const IBVerbsCommNetRMADesc& remote_mem,\n                                const IBVerbsMemDesc& local_mem, void* read_id) {\n  CHECK_EQ(remote_mem.mem_size, local_mem.mem_size());\n  WorkRequestId* wr_id = NewWorkRequestId();\n  const size_t block_num = RoundUp(remote_mem.mem_size, read_block_size_) / read_block_size_;\n  wr_id->outstanding_sge_cnt = static_cast<int32_t>(block_num);\n  wr_id->read_id = read_id;\n  FOR_RANGE(size_t, i, 0, block_num) {\n    ibv_send_wr wr{};\n    ibv_sge sge{};\n    sge.addr = reinterpret_cast<uint64_t>(local_mem.mem_ptr()) + i * read_block_size_;\n    sge.length = std::min(read_block_size_, local_mem.mem_size() - i * read_block_size_);\n    sge.lkey = local_mem.mr()->lkey;\n    wr.wr_id = reinterpret_cast<uint64_t>(wr_id);\n    wr.next = nullptr;\n    wr.sg_list = &sge;\n    wr.num_sge = 1;\n    wr.opcode = IBV_WR_RDMA_READ;\n    wr.send_flags = 0;\n    wr.imm_data = 0;\n    wr.wr.rdma.remote_addr = remote_mem.mem_ptr + i * read_block_size_;\n    wr.wr.rdma.rkey = remote_mem.mr_rkey;\n    EnqueuePostSendReadWR(wr, sge);\n  }\n}\n\nvoid IBVerbsQP::PostSendRequest(const IBVerbsActorMsgWrapper& msg_wrapper) {\n  ActorMsgMR* msg_mr = GetOneSendMsgMRFromBuf();\n  msg_mr->set_msg(msg_wrapper);\n  WorkRequestId* wr_id = NewWorkRequestId();\n  wr_id->msg_mr = msg_mr;\n  ibv_send_wr wr{};\n  ibv_sge sge{};\n  sge.addr = 
reinterpret_cast<uint64_t>(msg_mr->mem_desc().mem_ptr());\n  sge.length = msg_mr->mem_desc().mem_size();\n  sge.lkey = msg_mr->mem_desc().mr()->lkey;\n  wr.wr_id = reinterpret_cast<uint64_t>(wr_id);\n  wr.next = nullptr;\n  wr.sg_list = &sge;\n  wr.num_sge = 1;\n  wr.opcode = IBV_WR_SEND;\n  wr.send_flags = 0;\n  wr.imm_data = 0;\n  memset(&(wr.wr), 0, sizeof(wr.wr));\n  EnqueuePostSendReadWR(wr, sge);\n}\n\nvoid IBVerbsQP::EnqueuePostSendReadWR(ibv_send_wr wr, ibv_sge sge) {\n  std::unique_lock<std::mutex> pending_send_wr_lock_(pending_send_wr_mutex_);\n  if (num_outstanding_send_wr_ < max_outstanding_send_wr_) {\n    num_outstanding_send_wr_++;\n    ibv_send_wr* bad_wr = nullptr;\n    PCHECK(ibv_post_send(qp_, &wr, &bad_wr) == 0);\n  } else {\n    std::pair<ibv_send_wr, ibv_sge> ibv_send_wr_sge = std::make_pair(wr, sge);\n    pending_send_wr_queue_.push(ibv_send_wr_sge);\n  }\n}\n\nvoid IBVerbsQP::ReadDone(WorkRequestId* wr_id) {\n  CHECK_GE(wr_id->outstanding_sge_cnt, 1);\n  wr_id->outstanding_sge_cnt -= 1;\n  if (wr_id->outstanding_sge_cnt == 0) {\n    Singleton<CommNet>::Get()->ReadDone(wr_id->read_id);\n    DeleteWorkRequestId(wr_id);\n  }\n  PostPendingSendWR();\n}\n\nvoid IBVerbsQP::SendDone(WorkRequestId* wr_id) {\n  {\n    std::unique_lock<std::mutex> lck(send_msg_buf_mtx_);\n    send_msg_buf_.push(wr_id->msg_mr);\n  }\n  DeleteWorkRequestId(wr_id);\n  PostPendingSendWR();\n}\n\nvoid IBVerbsQP::RecvDone(WorkRequestId* wr_id) {\n  auto* ibv_comm_net = dynamic_cast<IBVerbsCommNet*>(Singleton<CommNet>::Get());\n  CHECK(ibv_comm_net != nullptr);\n  ibv_comm_net->RecvActorMsg(wr_id->msg_mr->msg());\n  PostRecvRequest(wr_id->msg_mr);\n  DeleteWorkRequestId(wr_id);\n}\n\nvoid IBVerbsQP::PostPendingSendWR() {\n  std::unique_lock<std::mutex> pending_send_wr_lock_(pending_send_wr_mutex_);\n  if (pending_send_wr_queue_.empty() == false) {\n    std::pair<ibv_send_wr, ibv_sge> ibv_send_wr_sge = std::move(pending_send_wr_queue_.front());\n    ibv_send_wr wr = ibv_send_wr_sge.first;\n    wr.sg_list = &ibv_send_wr_sge.second;\n    pending_send_wr_queue_.pop();\n    ibv_send_wr* bad_wr = nullptr;\n    PCHECK(ibv_post_send(qp_, &wr, &bad_wr) == 0);\n  } else {\n    if (num_outstanding_send_wr_ > 0) { num_outstanding_send_wr_--; }\n  }\n}\n\nvoid IBVerbsQP::PostRecvRequest(ActorMsgMR* msg_mr) {\n  WorkRequestId* wr_id = NewWorkRequestId();\n  wr_id->msg_mr = msg_mr;\n  ibv_recv_wr wr{};\n  ibv_sge sge{};\n  sge.addr = reinterpret_cast<uint64_t>(msg_mr->mem_desc().mem_ptr());\n  sge.length = msg_mr->mem_desc().mem_size();\n  sge.lkey = msg_mr->mem_desc().mr()->lkey;\n  wr.wr_id = reinterpret_cast<uint64_t>(wr_id);\n  wr.next = nullptr;\n  wr.sg_list = &sge;\n  wr.num_sge = 1;\n  ibv_recv_wr* bad_wr = nullptr;\n  PCHECK(ibv_post_recv(qp_, &wr, &bad_wr) == 0);\n}\n\nActorMsgMR* IBVerbsQP::GetOneSendMsgMRFromBuf() {\n  std::unique_lock<std::mutex> lck(send_msg_buf_mtx_);\n  if (send_msg_buf_.empty()) { send_msg_buf_.push(new ActorMsgMR(pd_)); }\n  ActorMsgMR* msg_mr = send_msg_buf_.front();\n  send_msg_buf_.pop();\n  return msg_mr;\n}\n\nWorkRequestId* IBVerbsQP::NewWorkRequestId() {\n  WorkRequestId* wr_id = new WorkRequestId;\n  wr_id->qp = this;\n  wr_id->outstanding_sge_cnt = 0;\n  wr_id->read_id = nullptr;\n  wr_id->msg_mr = nullptr;\n  return wr_id;\n}\n\nvoid IBVerbsQP::DeleteWorkRequestId(WorkRequestId* wr_id) {\n  CHECK_EQ(wr_id->qp, this);\n  delete wr_id;\n}\n\n}  // namespace oneflow\n\n#endif  // WITH_RDMA && OF_PLATFORM_POSIX\n"
  },
  {
    "path": "oneflow/core/comm_network/ibverbs/ibverbs_qp.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_QP_H_\n#define ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_QP_H_\n\n#include \"oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.h\"\n#include \"oneflow/core/lazy/actor/actor_message.h\"\n\n#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)\n\nnamespace oneflow {\n\nstruct IBVerbsCommNetRMADesc {\n  uint64_t mem_ptr;\n  uint64_t mem_size;\n  uint32_t mr_rkey;\n};\n\nstruct IBVerbsActorMsgWrapper final {\n  ActorMsg msg;\n  IBVerbsCommNetRMADesc rma_desc;\n};\n\nclass ActorMsgMR final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ActorMsgMR);\n  ActorMsgMR() = delete;\n  ActorMsgMR(ibv_pd* pd) { mem_desc_.reset(new IBVerbsMemDesc(pd, &msg_, sizeof(msg_))); }\n  ~ActorMsgMR() { mem_desc_.reset(); }\n\n  const IBVerbsActorMsgWrapper& msg() const { return msg_; }\n  void set_msg(const IBVerbsActorMsgWrapper& val) { msg_ = val; }\n  const IBVerbsMemDesc& mem_desc() const { return *mem_desc_; }\n\n private:\n  IBVerbsActorMsgWrapper msg_;\n  std::unique_ptr<IBVerbsMemDesc> mem_desc_;\n};\n\nclass IBVerbsQP;\n\nstruct WorkRequestId {\n  IBVerbsQP* qp;\n  int32_t outstanding_sge_cnt;\n  void* read_id;\n  ActorMsgMR* msg_mr;\n};\n\nstruct IBVerbsCommNetRMADesc;\n\nclass IBVerbsQP final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(IBVerbsQP);\n  IBVerbsQP() = delete;\n  IBVerbsQP(ibv_context*, ibv_pd*, const struct ibv_port_attr&, uint8_t port_num, ibv_cq* send_cq,\n            ibv_cq* recv_cq);\n  ~IBVerbsQP();\n\n  uint32_t qp_num() const { return qp_->qp_num; }\n  void Connect(const IBVerbsConnectionInfo& peer_info);\n  void PostAllRecvRequest();\n\n  void PostReadRequest(const IBVerbsCommNetRMADesc& remote_mem, const IBVerbsMemDesc& local_mem,\n                       void* read_id);\n  void PostSendRequest(const IBVerbsActorMsgWrapper& msg_wrapper);\n\n  void ReadDone(WorkRequestId*);\n  void SendDone(WorkRequestId*);\n  void RecvDone(WorkRequestId*);\n\n private:\n  void EnqueuePostSendReadWR(ibv_send_wr wr, ibv_sge sge);\n  void PostPendingSendWR();\n  WorkRequestId* NewWorkRequestId();\n  void DeleteWorkRequestId(WorkRequestId* wr_id);\n  ActorMsgMR* GetOneSendMsgMRFromBuf();\n  void PostRecvRequest(ActorMsgMR*);\n\n  ibv_context* ctx_;\n  ibv_pd* pd_;\n  uint8_t port_num_;\n  ibv_qp* qp_;\n  std::vector<ActorMsgMR*> recv_msg_buf_;\n\n  std::mutex send_msg_buf_mtx_;\n  std::queue<ActorMsgMR*> send_msg_buf_;\n  std::mutex pending_send_wr_mutex_;\n  uint32_t num_outstanding_send_wr_;\n  uint32_t max_outstanding_send_wr_;\n  std::queue<std::pair<ibv_send_wr, ibv_sge>> pending_send_wr_queue_;\n  size_t read_block_size_;\n  int32_t mtu_;\n};\n\n}  // namespace oneflow\n\n#endif  // WITH_RDMA && OF_PLATFORM_POSIX\n\n#endif  // ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_QP_H_\n"
  },
  {
    "path": "oneflow/core/common/array_ref.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_ARRAY_REF_H_\n#define ONEFLOW_CORE_COMMON_ARRAY_REF_H_\n\n#include \"llvm/ADT/ArrayRef.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nusing ArrayRef = llvm::ArrayRef<T>;\n\ntemplate<typename T>\nusing MutableArrayRef = llvm::MutableArrayRef<T>;\n\n}  // namespace oneflow\n\n#endif\n"
  },
  {
    "path": "oneflow/core/common/auto_registration_factory.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_AUTO_REGISTRATION_FACTORY_H_\n#define ONEFLOW_CORE_COMMON_AUTO_REGISTRATION_FACTORY_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\ntemplate<typename Key, typename Base, typename... Args>\nstruct AutoRegistrationFactory {\n public:\n  using Creator = std::function<Base*(Args&&...)>;\n  template<typename Derived>\n  struct RawRegisterType {\n    RawRegisterType(Key k) {\n      CHECK((AutoRegistrationFactory<Key, Base, Args...>::Get()\n                 .mutable_creators()\n                 ->emplace(k, [](Args&&...) { return new Derived; })\n                 .second))\n          << k;\n    }\n  };\n\n  struct CreatorRegisterType {\n    CreatorRegisterType(Key k, Creator v) {\n      CHECK((AutoRegistrationFactory<Key, Base, Args...>::Get()\n                 .mutable_creators()\n                 ->emplace(k, v)\n                 .second))\n          << k;\n    }\n  };\n\n  Base* New(Key k, Args&&... args) const {\n    auto creators_it = creators().find(k);\n    CHECK(creators_it != creators().end())\n        << \"Unregistered: key: \" << k << \"  Base type name:\" << typeid(Base).name()\n        << \"  Key type name\" << typeid(Key).name();\n    return creators_it->second(std::forward<Args>(args)...);\n  }\n\n  bool IsClassRegistered(Key k, Args&&... args) const {\n    return creators().find(k) != creators().end();\n  }\n\n  static AutoRegistrationFactory<Key, Base, Args...>& Get() {\n    static AutoRegistrationFactory<Key, Base, Args...> obj;\n    return obj;\n  }\n\n private:\n  std::unique_ptr<HashMap<Key, Creator>> creators_;\n\n  bool has_creators() const { return creators_.get() != nullptr; }\n\n  const HashMap<Key, Creator>& creators() const {\n    CHECK(has_creators()) << \"Unregistered key type: \" << typeid(Key).name()\n                          << \"Base type name:\" << typeid(Base).name();\n    return *creators_.get();\n  }\n\n  HashMap<Key, Creator>* mutable_creators() {\n    if (!creators_) { creators_.reset(new HashMap<Key, Creator>); }\n    return creators_.get();\n  }\n};\n\n#define REGISTER_VAR_NAME OF_PP_CAT(g_registry_var, __COUNTER__)\n\n#define REGISTER_CLASS(Key, k, Base, Derived) \\\n  static AutoRegistrationFactory<Key, Base>::RawRegisterType<Derived> REGISTER_VAR_NAME(k)\n#define REGISTER_CLASS_WITH_ARGS(Key, k, Base, Derived, ...)                       \\\n  static AutoRegistrationFactory<Key, Base, __VA_ARGS__>::RawRegisterType<Derived> \\\n      REGISTER_VAR_NAME(k)\n#define REGISTER_CLASS_CREATOR(Key, k, Base, f, ...)                                               \\\n  static AutoRegistrationFactory<Key, Base, ##__VA_ARGS__>::CreatorRegisterType REGISTER_VAR_NAME( \\\n      k, f)\n\ntemplate<typename Key, typename Base, typename... Args>\ninline Base* NewObj(Key k, Args&&... 
args) {\n  return AutoRegistrationFactory<Key, Base, Args...>::Get().New(k, std::forward<Args>(args)...);\n}\n\ntemplate<typename Key, typename Base, typename... Args>\ninline std::unique_ptr<Base> NewObjUniquePtr(Key k, Args&&... args) {\n  return std::unique_ptr<Base>(\n      AutoRegistrationFactory<Key, Base, Args...>::Get().New(k, std::forward<Args>(args)...));\n}\n\ntemplate<typename Key, typename Base, typename... Args>\ninline bool IsClassRegistered(Key k, Args&&... args) {\n  return AutoRegistrationFactory<Key, Base, Args...>::Get().IsClassRegistered(\n      k, std::forward<Args>(args)...);\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_AUTO_REGISTRATION_FACTORY_H_\n"
  },
  {
    "path": "oneflow/core/common/balanced_splitter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/balanced_splitter.h\"\n\nnamespace oneflow {\n\nBalancedSplitter::BalancedSplitter(int64_t total_num, int64_t split_num) {\n  base_part_size_ = total_num / split_num;\n  base_begin_idx_ = total_num % split_num;\n  split_num_ = split_num;\n  CHECK_EQ(this->total_num(), total_num);\n}\n\nint64_t BalancedSplitter::total_num() const { return At(split_num_ - 1).end(); }\n\nRange BalancedSplitter::At(int64_t idx) const {\n  CHECK_LT(idx, split_num_);\n  int64_t left_bound = -1;\n  int64_t right_bound = -1;\n  if (idx < base_begin_idx_) {\n    left_bound = (base_part_size_ + 1) * idx;\n    right_bound = left_bound + (base_part_size_ + 1);\n  } else {\n    left_bound =\n        (base_part_size_ + 1) * base_begin_idx_ + base_part_size_ * (idx - base_begin_idx_);\n    right_bound = left_bound + base_part_size_;\n  }\n  return Range(left_bound, right_bound);\n}\n\nRange BalancedSplitter::At(int64_t first_idx, int64_t last_idx) const {\n  CHECK_LE(first_idx, last_idx);\n  CHECK_LT(last_idx, split_num_);\n  Range first_range = At(first_idx);\n  Range last_range = At(last_idx);\n  return Range(first_range.begin(), last_range.end());\n}\n\nint64_t BalancedSplitter::GetRangeIndexForVal(int64_t value) const {\n  CHECK_GE(value, 0);\n  CHECK_LT(value, total_num());\n  int64_t base_size = (base_part_size_ + 1) * base_begin_idx_;\n  if (value < base_size) {\n    return value / (base_part_size_ + 1);\n  } else {\n    return base_begin_idx_ + (value - base_size) / base_part_size_;\n  }\n}\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/balanced_splitter.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_BALANCED_SPLITTER_H_\n#define ONEFLOW_CORE_COMMON_BALANCED_SPLITTER_H_\n\n#include <stdint.h>\n#include \"oneflow/core/common/range.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\n// For example\n// BalancedSplitter splitter(20, 6)\n// the result of splitter.At\n//     0    [0, 4)\n//     1    [4, 8)\n//     2    [8, 11)\n//     3    [11, 14)\n//     4    [14, 17)\n//     5    [17, 20)\nclass BalancedSplitter final {\n public:\n  // OF_DISALLOW_COPY_AND_MOVE(BalancedSplitter);\n  BalancedSplitter() = delete;\n  ~BalancedSplitter() = default;\n\n  BalancedSplitter(int64_t total_num, int64_t split_num);\n\n  Range At(int64_t idx) const;\n  Range At(int64_t first_idx, int64_t last_idx) const;\n\n  // Get the range index number of a value.\n  int64_t GetRangeIndexForVal(int64_t value) const;\n  int64_t total_num() const;\n\n private:\n  int64_t base_part_size_;\n  int64_t base_begin_idx_;\n  int64_t split_num_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_BALANCED_SPLITTER_H_\n"
  },
  {
    "path": "oneflow/core/common/balanced_splitter_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n\nnamespace oneflow {\n\nTEST(BalancedSplitter, split_20_to_6_part) {\n  BalancedSplitter splitter(20, 6);\n  ASSERT_TRUE(splitter.At(0) == Range(0, 4));\n  ASSERT_TRUE(splitter.At(1) == Range(4, 8));\n  ASSERT_TRUE(splitter.At(2) == Range(8, 11));\n  ASSERT_TRUE(splitter.At(3) == Range(11, 14));\n  ASSERT_TRUE(splitter.At(4) == Range(14, 17));\n  ASSERT_TRUE(splitter.At(5) == Range(17, 20));\n}\n\nTEST(BalancedSplitter, split_2_to_3_part) {\n  BalancedSplitter splitter(2, 3);\n  ASSERT_TRUE(splitter.At(0) == Range(0, 1));\n  ASSERT_TRUE(splitter.At(1) == Range(1, 2));\n  ASSERT_TRUE(splitter.At(2) == Range(2, 2));\n}\n\nTEST(BalancedSplitter, GetRangeIndexForVal) {\n  const size_t total_num = 937;\n  const size_t split_num = 11;\n  BalancedSplitter bs(total_num, split_num);\n  ASSERT_TRUE(bs.total_num() == total_num);\n  for (size_t i = 0; i < split_num; ++i) {\n    Range range = bs.At(i);\n    for (size_t value = range.begin(); value < range.end(); ++value) {\n      ASSERT_TRUE(bs.GetRangeIndexForVal(value) == i);\n    }\n  }\n}\n\n}  // namespace oneflow\n"
  },
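`GetRangeIndexForVal` is the O(1) inverse of `At`: values below `(base_part_size_ + 1) * base_begin_idx_` belong to the longer prefix ranges, everything at or above that to the base-sized ones. A minimal sketch of the same arithmetic as a hypothetical free function; note that the `total < parts` case (as in the `split_2_to_3_part` test) keeps `base == 0` harmless, because every valid value then falls in the prefix branch and the division by `base` is never reached:

```cpp
#include <cassert>
#include <cstdint>

// Sketch of the inverse mapping in GetRangeIndexForVal (hypothetical free
// function mirroring the member's arithmetic): the first `extra` ranges
// occupy a prefix of (base + 1)-sized blocks, the rest are base-sized.
int64_t RangeIndexForVal(int64_t total, int64_t parts, int64_t value) {
  const int64_t base = total / parts;
  const int64_t extra = total % parts;
  const int64_t prefix = (base + 1) * extra;  // elements covered by the longer ranges
  return value < prefix ? value / (base + 1) : extra + (value - prefix) / base;
}

int main() {
  // 20 elements in 6 parts: value 10 falls in range 2, i.e. [8, 11).
  assert(RangeIndexForVal(20, 6, 10) == 2);
  // total < parts: base is 0, but every valid value lands in the
  // (base + 1)-sized prefix, so the division by base is never reached.
  assert(RangeIndexForVal(2, 3, 1) == 1);
  return 0;
}
```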
  {
    "path": "oneflow/core/common/bfloat16.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_BFLOAT16_H_\n#define ONEFLOW_CORE_COMMON_BFLOAT16_H_\n\n#include <stdint.h>\n#include <limits>\n#include <cmath>\n#include <cstring>\n\nnamespace oneflow {\n\n#if defined(__CUDACC__)\n#define OF_DEVICE_FUNCTION __device__ __host__ __forceinline__\n#else\n#define OF_DEVICE_FUNCTION inline\n#endif\n\nstruct alignas(2) bfloat16 {\n  uint16_t x;\n\n  bfloat16() = default;\n  bfloat16(const bfloat16& o) = default;\n  bfloat16& operator=(const bfloat16& o) = default;\n  bfloat16(bfloat16&& o) = default;\n  bfloat16& operator=(bfloat16&& o) = default;\n  ~bfloat16() = default;\n\n  struct from_bits_t {};\n  static constexpr inline from_bits_t from_bits() { return from_bits_t(); }\n\n  constexpr inline bfloat16(unsigned short bits, from_bits_t) : x(bits){};\n\n  // reference: pytorch/c10/util/BFloat16.h\n  // https://github.com/pytorch/pytorch/blob/release/1.12/c10/util/BFloat16.h\n  bfloat16(float value) {\n    if (std::isnan(value)) {\n      x = 0x7FC0;\n    } else {\n      union {\n        uint32_t U32;\n        float F32;\n      };\n\n      F32 = value;\n      uint32_t rounding_bias = ((U32 >> 16) & 1) + 0x7FFFU;\n      x = static_cast<uint16_t>((U32 + rounding_bias) >> 16);\n    }\n  }\n\n  inline operator float() const {\n    float res = 0;\n    uint32_t tmp = x;\n    tmp <<= 16;\n    std::memcpy(&res, &tmp, sizeof(tmp));\n    return res;\n  }\n\n  inline bool operator==(const bfloat16& other) const { return x == other.x; }\n\n  inline explicit operator bool() const { return (x & 0x7fff) != 0; }\n\n  inline explicit operator int8_t() const { return static_cast<int8_t>(static_cast<float>(*this)); }\n\n  inline explicit operator uint8_t() const {\n    return static_cast<uint8_t>(static_cast<float>(*this));\n  }\n\n  inline explicit operator int16_t() const {\n    return static_cast<int16_t>(static_cast<float>(*this));\n  }\n\n  inline explicit operator uint16_t() const {\n    return static_cast<uint16_t>(static_cast<float>(*this));\n  }\n\n  inline explicit operator int32_t() const {\n    return static_cast<int32_t>(static_cast<float>(*this));\n  }\n\n  inline explicit operator uint32_t() const {\n    return static_cast<uint32_t>(static_cast<float>(*this));\n  }\n\n  inline explicit operator int64_t() const {\n    return static_cast<int64_t>(static_cast<float>(*this));\n  }\n\n  inline explicit operator uint64_t() const {\n    return static_cast<uint64_t>(static_cast<float>(*this));\n  }\n\n  inline explicit operator double() const { return static_cast<double>(static_cast<float>(*this)); }\n};\n\n// Arithmetic\n\ninline bfloat16 operator+(const bfloat16& a, const bfloat16& b) {\n  return static_cast<float>(a) + static_cast<float>(b);\n}\n\ninline bfloat16 operator-(const bfloat16& a, const bfloat16& b) {\n  return static_cast<float>(a) - static_cast<float>(b);\n}\n\ninline bfloat16 operator*(const bfloat16& a, const bfloat16& b) {\n  return 
static_cast<float>(a) * static_cast<float>(b);\n}\n\ninline bfloat16 operator/(const bfloat16& a, const bfloat16& b) {\n  return static_cast<float>(a) / static_cast<float>(b);\n}\n\ninline bfloat16 operator-(const bfloat16& a) {\n  bfloat16 output;\n  output.x = a.x ^ 0x8000U;\n  return output;\n}\n\ninline bfloat16& operator+=(bfloat16& a, const bfloat16& b) {\n  a = a + b;\n  return a;\n}\n\ninline bfloat16& operator-=(bfloat16& a, const bfloat16& b) {\n  a = a - b;\n  return a;\n}\n\ninline bfloat16& operator*=(bfloat16& a, const bfloat16& b) {\n  a = a * b;\n  return a;\n}\n\ninline bfloat16& operator/=(bfloat16& a, const bfloat16& b) {\n  a = a / b;\n  return a;\n}\n\ninline bfloat16& operator|(bfloat16& a, const bfloat16& b) {\n  a.x = a.x | b.x;\n  return a;\n}\n\ninline bfloat16& operator^(bfloat16& a, const bfloat16& b) {\n  a.x = a.x ^ b.x;\n  return a;\n}\n\ninline bfloat16& operator&(bfloat16& a, const bfloat16& b) {\n  a.x = a.x & b.x;\n  return a;\n}\n\n// Arithmetic with floats\n\ninline float operator+(bfloat16 a, float b) { return static_cast<float>(a) + b; }\ninline float operator-(bfloat16 a, float b) { return static_cast<float>(a) - b; }\ninline float operator*(bfloat16 a, float b) { return static_cast<float>(a) * b; }\ninline float operator/(bfloat16 a, float b) { return static_cast<float>(a) / b; }\n\ninline float operator+(float a, bfloat16 b) { return a + static_cast<float>(b); }\ninline float operator-(float a, bfloat16 b) { return a - static_cast<float>(b); }\ninline float operator*(float a, bfloat16 b) { return a * static_cast<float>(b); }\ninline float operator/(float a, bfloat16 b) { return a / static_cast<float>(b); }\n\ninline float& operator+=(float& a, const bfloat16& b) { return a += static_cast<float>(b); }\ninline float& operator-=(float& a, const bfloat16& b) { return a -= static_cast<float>(b); }\ninline float& operator*=(float& a, const bfloat16& b) { return a *= static_cast<float>(b); }\ninline float& operator/=(float& a, const bfloat16& b) { return a /= static_cast<float>(b); }\n\n// Arithmetic with doubles\n\ninline double operator+(bfloat16 a, double b) { return static_cast<double>(a) + b; }\ninline double operator-(bfloat16 a, double b) { return static_cast<double>(a) - b; }\ninline double operator*(bfloat16 a, double b) { return static_cast<double>(a) * b; }\ninline double operator/(bfloat16 a, double b) { return static_cast<double>(a) / b; }\n\ninline double operator+(double a, bfloat16 b) { return a + static_cast<double>(b); }\ninline double operator-(double a, bfloat16 b) { return a - static_cast<double>(b); }\ninline double operator*(double a, bfloat16 b) { return a * static_cast<double>(b); }\ninline double operator/(double a, bfloat16 b) { return a / static_cast<double>(b); }\n\n// Arithmetic with int32_t\n\ninline bfloat16 operator+(bfloat16 a, int32_t b) { return a + static_cast<bfloat16>(b); }\ninline bfloat16 operator-(bfloat16 a, int32_t b) { return a - static_cast<bfloat16>(b); }\ninline bfloat16 operator*(bfloat16 a, int32_t b) { return a * static_cast<bfloat16>(b); }\ninline bfloat16 operator/(bfloat16 a, int32_t b) { return a / static_cast<bfloat16>(b); }\n\ninline bfloat16 operator+(int32_t a, bfloat16 b) { return static_cast<bfloat16>(a) + b; }\ninline bfloat16 operator-(int32_t a, bfloat16 b) { return static_cast<bfloat16>(a) - b; }\ninline bfloat16 operator*(int32_t a, bfloat16 b) { return static_cast<bfloat16>(a) * b; }\ninline bfloat16 operator/(int32_t a, bfloat16 b) { return static_cast<bfloat16>(a) / b; }\n\n// Arithmetic 
with int64_t\n\ninline bfloat16 operator+(bfloat16 a, int64_t b) { return a + static_cast<bfloat16>(b); }\ninline bfloat16 operator-(bfloat16 a, int64_t b) { return a - static_cast<bfloat16>(b); }\ninline bfloat16 operator*(bfloat16 a, int64_t b) { return a * static_cast<bfloat16>(b); }\ninline bfloat16 operator/(bfloat16 a, int64_t b) { return a / static_cast<bfloat16>(b); }\n\ninline bfloat16 operator+(int64_t a, bfloat16 b) { return static_cast<bfloat16>(a) + b; }\ninline bfloat16 operator-(int64_t a, bfloat16 b) { return static_cast<bfloat16>(a) - b; }\ninline bfloat16 operator*(int64_t a, bfloat16 b) { return static_cast<bfloat16>(a) * b; }\ninline bfloat16 operator/(int64_t a, bfloat16 b) { return static_cast<bfloat16>(a) / b; }\n\n// Comparison operators\n\ninline bool operator>(bfloat16& lhs, bfloat16& rhs) {\n  return static_cast<float>(lhs) > static_cast<float>(rhs);\n}\n\ninline bool operator>=(bfloat16& lhs, bfloat16& rhs) {\n  return static_cast<float>(lhs) >= static_cast<float>(rhs);\n}\n\ninline bool operator<(bfloat16& lhs, bfloat16& rhs) {\n  return static_cast<float>(lhs) < static_cast<float>(rhs);\n}\n\ninline bool operator<=(bfloat16& lhs, bfloat16& rhs) {\n  return static_cast<float>(lhs) <= static_cast<float>(rhs);\n}\n\ninline bool operator==(bfloat16& lhs, bfloat16& rhs) {\n  return static_cast<float>(lhs) == static_cast<float>(rhs);\n}\n\ninline bool operator!=(bfloat16& lhs, bfloat16& rhs) {\n  return static_cast<float>(lhs) != static_cast<float>(rhs);\n}\n\n}  // namespace oneflow\n\nnamespace std {\n\ninline bool isnan(const oneflow::bfloat16& value) { return (value.x & 0x7FFFU) > 0x7F80U; }\n\n// Mask off the sign bit so that negative infinity (0xFF80) is detected too.\ninline bool isinf(const oneflow::bfloat16& value) { return (value.x & 0x7FFFU) == 0x7F80U; }\n\ninline bool isfinite(const oneflow::bfloat16& value) { return !isinf(value) && !isnan(value); }\n\ntemplate<>\nclass numeric_limits<oneflow::bfloat16> {\n public:\n  static constexpr bool is_signed = true;\n  static constexpr bool is_specialized = true;\n  static constexpr bool is_integer = false;\n  static constexpr bool is_exact = false;\n  static constexpr bool has_infinity = true;\n  static constexpr bool has_quiet_NaN = true;\n  static constexpr bool has_signaling_NaN = true;\n  static constexpr auto has_denorm = numeric_limits<float>::has_denorm;\n  static constexpr auto has_denorm_loss = numeric_limits<float>::has_denorm_loss;\n  static constexpr auto round_style = numeric_limits<float>::round_style;\n  static constexpr bool is_iec559 = false;\n  static constexpr bool is_bounded = true;\n  static constexpr bool is_modulo = false;\n  static constexpr int digits = 8;\n  static constexpr int digits10 = 2;\n  static constexpr int max_digits10 = 4;\n  static constexpr int radix = 2;\n  static constexpr int min_exponent = -125;\n  static constexpr int min_exponent10 = -37;\n  static constexpr int max_exponent = 128;\n  static constexpr int max_exponent10 = 38;\n  static constexpr auto traps = numeric_limits<float>::traps;\n  static constexpr auto tinyness_before = numeric_limits<float>::tinyness_before;\n  static constexpr oneflow::bfloat16 min() {\n    return oneflow::bfloat16(0x0080U, oneflow::bfloat16::from_bits());\n  }\n  static constexpr oneflow::bfloat16 lowest() {\n    return oneflow::bfloat16(0xFF7FU, oneflow::bfloat16::from_bits());\n  }\n  static constexpr oneflow::bfloat16 max() {\n    return oneflow::bfloat16(0x7F7FU, oneflow::bfloat16::from_bits());\n  }\n  static constexpr oneflow::bfloat16 epsilon() {\n    return oneflow::bfloat16(0x3C00U, 
oneflow::bfloat16::from_bits());\n  }\n  static constexpr oneflow::bfloat16 round_error() {\n    return oneflow::bfloat16(0x3F00U, oneflow::bfloat16::from_bits());\n  }\n  static constexpr oneflow::bfloat16 infinity() {\n    return oneflow::bfloat16(0x7F80U, oneflow::bfloat16::from_bits());\n  }\n  static constexpr oneflow::bfloat16 quiet_NaN() {\n    return oneflow::bfloat16(0x7FC0U, oneflow::bfloat16::from_bits());\n  }\n  static constexpr oneflow::bfloat16 signaling_NaN() {\n    // 0x7FA0: exponent all ones, quiet bit clear, mantissa non-zero (0x7F80 would be +infinity).\n    return oneflow::bfloat16(0x7FA0U, oneflow::bfloat16::from_bits());\n  }\n  static constexpr oneflow::bfloat16 denorm_min() {\n    return oneflow::bfloat16(0x0001U, oneflow::bfloat16::from_bits());\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_COMMON_BFLOAT16_H_\n"
  },
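The `bfloat16(float)` constructor above rounds to nearest-even instead of truncating the low 16 bits: the bias `((U32 >> 16) & 1) + 0x7FFF` rounds up exactly when the discarded half exceeds one half, and resolves exact ties toward an even (zero) low bit. NaN is special-cased first because adding the bias to a NaN payload could carry into the exponent and yield infinity. A standalone sketch (not the repo code) showing two tie cases:

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>

// Standalone sketch (not the repo code) of the round-to-nearest-even step
// in bfloat16's float constructor.
uint16_t float_to_bf16_bits(float value) {
  uint32_t u32 = 0;
  std::memcpy(&u32, &value, sizeof(u32));  // type pun without UB
  const uint32_t rounding_bias = ((u32 >> 16) & 1) + 0x7FFFU;
  return static_cast<uint16_t>((u32 + rounding_bias) >> 16);
}

int main() {
  // Both inputs sit exactly halfway between two adjacent bfloat16 values.
  std::cout << std::hex
            << float_to_bf16_bits(1.00390625f) << "\n"   // 3f80: tie, low bit already even
            << float_to_bf16_bits(1.01171875f) << "\n";  // 3f82: tie, rounds up to even
  return 0;
}
```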
  {
    "path": "oneflow/core/common/bfloat16_math.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_BFLOAT16_MATH_H_\n#define ONEFLOW_CORE_COMMON_BFLOAT16_MATH_H_\n\n#include \"oneflow/core/common/bfloat16.h\"\n\nnamespace std {\n\n// reference: pytorch/c10/util/BFloat16-math.h\n// https://github.com/pytorch/pytorch/blob/release/1.12/c10/util/BFloat16-math.h\ninline oneflow::bfloat16 acos(oneflow::bfloat16 a) { return std::acos(static_cast<float>(a)); }\ninline oneflow::bfloat16 asin(oneflow::bfloat16 a) { return std::asin(static_cast<float>(a)); }\ninline oneflow::bfloat16 atan(oneflow::bfloat16 a) { return std::atan(static_cast<float>(a)); }\ninline oneflow::bfloat16 erf(oneflow::bfloat16 a) { return std::erf(static_cast<float>(a)); }\ninline oneflow::bfloat16 erfc(oneflow::bfloat16 a) { return std::erfc(static_cast<float>(a)); }\ninline oneflow::bfloat16 exp(oneflow::bfloat16 a) { return std::exp(static_cast<float>(a)); }\ninline oneflow::bfloat16 expm1(oneflow::bfloat16 a) { return std::expm1(static_cast<float>(a)); }\ninline oneflow::bfloat16 log(oneflow::bfloat16 a) { return std::log(static_cast<float>(a)); }\ninline oneflow::bfloat16 log10(oneflow::bfloat16 a) { return std::log10(static_cast<float>(a)); }\ninline oneflow::bfloat16 log1p(oneflow::bfloat16 a) { return std::log1p(static_cast<float>(a)); }\ninline oneflow::bfloat16 log2(oneflow::bfloat16 a) { return std::log2(static_cast<float>(a)); }\ninline oneflow::bfloat16 ceil(oneflow::bfloat16 a) { return std::ceil(static_cast<float>(a)); }\ninline oneflow::bfloat16 cos(oneflow::bfloat16 a) { return std::cos(static_cast<float>(a)); }\ninline oneflow::bfloat16 floor(oneflow::bfloat16 a) { return std::floor(static_cast<float>(a)); }\ninline oneflow::bfloat16 nearbyint(oneflow::bfloat16 a) {\n  return std::nearbyint(static_cast<float>(a));\n}\ninline oneflow::bfloat16 sin(oneflow::bfloat16 a) { return std::sin(static_cast<float>(a)); }\ninline oneflow::bfloat16 tan(oneflow::bfloat16 a) { return std::tan(static_cast<float>(a)); }\ninline oneflow::bfloat16 sinh(oneflow::bfloat16 a) { return std::sinh(static_cast<float>(a)); }\ninline oneflow::bfloat16 cosh(oneflow::bfloat16 a) { return std::cosh(static_cast<float>(a)); }\ninline oneflow::bfloat16 tanh(oneflow::bfloat16 a) { return std::tanh(static_cast<float>(a)); }\ninline oneflow::bfloat16 trunc(oneflow::bfloat16 a) { return std::trunc(static_cast<float>(a)); }\ninline oneflow::bfloat16 lgamma(oneflow::bfloat16 a) { return std::lgamma(static_cast<float>(a)); }\ninline oneflow::bfloat16 sqrt(oneflow::bfloat16 a) { return std::sqrt(static_cast<float>(a)); }\ninline oneflow::bfloat16 rsqrt(oneflow::bfloat16 a) {\n  return 1.0 / std::sqrt(static_cast<float>(a));\n}\ninline oneflow::bfloat16 abs(oneflow::bfloat16 a) { return std::abs(static_cast<float>(a)); }\ninline oneflow::bfloat16 pow(oneflow::bfloat16 a, double b) {\n  return std::pow(static_cast<float>(a), b);\n}\ninline oneflow::bfloat16 pow(oneflow::bfloat16 a, oneflow::bfloat16 b) {\n  
return std::pow(static_cast<float>(a), static_cast<float>(b));\n}\ninline oneflow::bfloat16 fmod(oneflow::bfloat16 a, oneflow::bfloat16 b) {\n  return std::fmod(static_cast<float>(a), static_cast<float>(b));\n}\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_COMMON_BFLOAT16_MATH_H_\n"
  },
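Each wrapper follows one pattern: promote to float, call the libm routine, and let the implicit `bfloat16(float)` constructor round the result exactly once on return. A minimal usage sketch, assuming the two headers above are on the include path:

```cpp
#include <iostream>
// Hypothetical usage; assumes the two headers above are available.
#include "oneflow/core/common/bfloat16.h"
#include "oneflow/core/common/bfloat16_math.h"

int main() {
  oneflow::bfloat16 x(2.0f);
  // Promoted to float, computed by libm, rounded once by bfloat16(float).
  oneflow::bfloat16 y = std::sqrt(x);
  std::cout << static_cast<float>(y) << "\n";  // 1.41406..., sqrt(2) after one rounding
  return 0;
}
```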
  {
    "path": "oneflow/core/common/bfloat16_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/bfloat16.h\"\n#include \"oneflow/core/common/bfloat16_math.h\"\n\nnamespace oneflow {\nnamespace test {\n\nfloat float_from_bytes(uint32_t sign, uint32_t exponent, uint32_t fraction) {\n  // reference: pytorch/c10/test/util/bfloat16_test.cpp\n  // https://github.com/pytorch/pytorch/blob/release/1.12/c10/test/util/bfloat16_test.cpp\n  uint32_t bytes = 0;\n  bytes |= sign;\n  bytes <<= 8;\n  bytes |= exponent;\n  bytes <<= 23;\n  bytes |= fraction;\n  float res = NAN;\n  std::memcpy(&res, &bytes, sizeof(res));\n  return res;\n}\n\nTEST(BFLOAT16MATH, Add) {\n  // 6.25\n  float input = float_from_bytes(0, 0, 0x40C80000U);\n  // 7.25\n  float expected = float_from_bytes(0, 0, 0x40E80000U);\n\n  bfloat16 b(input);\n  b = b + 1;\n\n  float res = static_cast<float>(b);\n  EXPECT_EQ(res, expected);\n}\n\nTEST(BFLOAT16MATH, Sub) {\n  // 7.25\n  float input = float_from_bytes(0, 0, 0x40E80000U);\n  // 6.25\n  float expected = float_from_bytes(0, 0, 0x40C80000U);\n\n  bfloat16 b(input);\n  b = b - 1;\n\n  float res = static_cast<float>(b);\n  EXPECT_EQ(res, expected);\n}\n\nTEST(BFLOAT16MATH, Mul) {\n  // 3.125\n  float input = float_from_bytes(0, 0, 0x40480000U);\n  // 6.25\n  float expected = float_from_bytes(0, 0, 0x40C80000U);\n\n  bfloat16 b(input);\n  b = b * 2;\n\n  float res = static_cast<float>(b);\n  EXPECT_EQ(res, expected);\n}\n\nTEST(BFLOAT16MATH, Div) {\n  // 6.25\n  float input = float_from_bytes(0, 0, 0x40C80000U);\n  // 3.125\n  float expected = float_from_bytes(0, 0, 0x40480000U);\n\n  bfloat16 b(input);\n  b = b / 2;\n\n  float res = static_cast<float>(b);\n  EXPECT_EQ(res, expected);\n}\n\nTEST(BFLOAT16MATH, Log2) {\n  // 16\n  float input = float_from_bytes(0, 0, 0x41800000U);\n  // 4\n  float expected = float_from_bytes(0, 0, 0x40800000U);\n\n  bfloat16 b(input);\n  b = std::log2(b);\n\n  float res = static_cast<float>(b);\n  EXPECT_EQ(res, expected);\n}\n\nTEST(BFLOAT16MATH, Log10) {\n  // 100\n  float input = float_from_bytes(0, 0, 0x42C80000U);\n  // 2\n  float expected = float_from_bytes(0, 0, 0x40000000U);\n\n  bfloat16 b(input);\n  b = std::log10(b);\n\n  float res = static_cast<float>(b);\n  EXPECT_EQ(res, expected);\n}\n\nTEST(BFLOAT16MATH, Sqrt) {\n  // 25\n  float input = float_from_bytes(0, 0, 0x41C80000U);\n  // 5\n  float expected = float_from_bytes(0, 0, 0x40A00000U);\n\n  bfloat16 b(input);\n  b = std::sqrt(b);\n\n  float res = static_cast<float>(b);\n  EXPECT_EQ(res, expected);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/blas.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_BLAS_H_\n#define ONEFLOW_CORE_COMMON_BLAS_H_\n\n#include <type_traits>\n#include <utility>\n#include <complex>\n#include \"oneflow/core/common/cblas.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n\nnamespace oneflow {\n\n#define BLAS_NAME_SEQ                      \\\n  OF_PP_MAKE_TUPLE_SEQ(dot)                \\\n  OF_PP_MAKE_TUPLE_SEQ(swap)               \\\n  OF_PP_MAKE_TUPLE_SEQ(copy)               \\\n  OF_PP_MAKE_TUPLE_SEQ(axpy)               \\\n  OF_PP_MAKE_TUPLE_SEQ(scal)               \\\n  OF_PP_MAKE_TUPLE_SEQ(gemv)               \\\n  OF_PP_MAKE_TUPLE_SEQ(gemm)               \\\n  OF_PP_MAKE_TUPLE_SEQ(gemmBatched)        \\\n  OF_PP_MAKE_TUPLE_SEQ(gemmStridedBatched) \\\n  OF_PP_MAKE_TUPLE_SEQ(getrfBatched)       \\\n  OF_PP_MAKE_TUPLE_SEQ(getriBatched)\n\n#define CBLAS_TEMPLATE(name)                                                                    \\\n  template<typename T, typename... Args>                                                        \\\n  auto cblas_##name(Args&&... args)                                                             \\\n      ->typename std::enable_if<std::is_same<T, float>::value,                                  \\\n                                decltype(cblas_##s##name(std::forward<Args>(args)...))>::type { \\\n    return cblas_##s##name(std::forward<Args>(args)...);                                        \\\n  }                                                                                             \\\n  template<typename T, typename... Args>                                                        \\\n  auto cblas_##name(Args&&... args)                                                             \\\n      ->typename std::enable_if<std::is_same<T, double>::value,                                 \\\n                                decltype(cblas_##d##name(std::forward<Args>(args)...))>::type { \\\n    return cblas_##d##name(std::forward<Args>(args)...);                                        \\\n  }                                                                                             \\\n  template<typename T, typename... Args>                                                        \\\n  auto cblas_##name(Args&&... args)                                                             \\\n      ->typename std::enable_if<std::is_same<T, std::complex<float>>::value,                    \\\n                                decltype(cblas_##c##name(std::forward<Args>(args)...))>::type { \\\n    return cblas_##c##name(std::forward<Args>(args)...);                                        \\\n  }                                                                                             \\\n  template<typename T, typename... Args>                                                        \\\n  auto cblas_##name(Args&&... 
args)                                                             \\\n      ->typename std::enable_if<std::is_same<T, std::complex<double>>::value,                   \\\n                                decltype(cblas_##z##name(std::forward<Args>(args)...))>::type { \\\n    return cblas_##z##name(std::forward<Args>(args)...);                                        \\\n  }\n\nOF_PP_FOR_EACH_TUPLE(CBLAS_TEMPLATE, BLAS_NAME_SEQ);\n\n#undef CBLAS_TEMPLATE\n\n#undef BLAS_NAME_SEQ\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_BLAS_H_\n"
  },
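`CBLAS_TEMPLATE` expands to four `enable_if`-constrained overloads per routine name, so a single `cblas_gemm<T>` call site dispatches to `cblas_sgemm`, `cblas_dgemm`, `cblas_cgemm`, or `cblas_zgemm` at compile time. A minimal usage sketch, assuming a CBLAS implementation (e.g. OpenBLAS) is linked in:

```cpp
// Minimal usage sketch; assumes a CBLAS implementation is linked.
#include "oneflow/core/common/blas.h"

int main() {
  const float a[4] = {1, 2, 3, 4};  // 2x2, row-major
  const float b[4] = {5, 6, 7, 8};
  float c[4] = {0, 0, 0, 0};
  // cblas_gemm<float> is selected via enable_if and forwards to cblas_sgemm.
  oneflow::cblas_gemm<float>(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                             /*M=*/2, /*N=*/2, /*K=*/2, 1.0f, a, 2, b, 2, 0.0f, c, 2);
  // c now holds {19, 22, 43, 50}.
  return 0;
}
```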
  {
    "path": "oneflow/core/common/blocking_counter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/blocking_counter.h\"\n#include \"oneflow/core/common/foreign_lock_helper.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/env_var/env_var.h\"\n\nnamespace oneflow {\n\nint64_t BlockingCounter::Increase() {\n  std::unique_lock<std::mutex> lck(mtx_);\n  CHECK_GT(cnt_val_, 0);\n  cnt_val_ += 1;\n  return cnt_val_;\n}\n\nint64_t BlockingCounter::Decrease() {\n  std::unique_lock<std::mutex> lck(mtx_);\n  cnt_val_ -= 1;\n  if (cnt_val_ == 0) { cond_.notify_all(); }\n  return cnt_val_;\n}\n\nMaybe<void> BlockingCounter::WaitUntilCntEqualZero(size_t timeout_seconds) {\n  return Singleton<ForeignLockHelper>::Get()->WithScopedRelease([&, this]() -> Maybe<void> {\n    std::chrono::duration<size_t> seconds(timeout_seconds);\n    std::unique_lock<std::mutex> lck(mtx_);\n    CHECK_OR_RETURN(cond_.wait_for(lck, seconds, [this]() { return cnt_val_ == 0; }))\n        << Error::TimeoutError();\n    return Maybe<void>::Ok();\n  });\n}\n\nvoid BlockingCounter::WaitForeverUntilCntEqualZero() {\n  CHECK_JUST(WaitUntilCntEqualZero([]() -> Maybe<bool> { return false; }));\n}\n\nMaybe<void> BlockingCounter::WaitUntilCntEqualZero(\n    const std::function<Maybe<bool>()>& StopWaitingAfterTimeout) {\n  while (true) {\n    auto status = TRY(WaitUntilCntEqualZero(EnvInteger<ONEFLOW_TIMEOUT_SECONDS>()));\n    if (status.IsOk()) { return status; }\n    if (!status.error()->has_timeout_error()) { return status; }\n    if (JUST(StopWaitingAfterTimeout())) { return status; }\n  }\n  UNIMPLEMENTED_THEN_RETURN();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/blocking_counter.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_BLOCKING_COUNTER_H_\n#define ONEFLOW_CORE_COMMON_BLOCKING_COUNTER_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nclass BlockingCounter final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BlockingCounter);\n  BlockingCounter() = delete;\n  ~BlockingCounter() = default;\n\n  BlockingCounter(int64_t cnt_val) { cnt_val_ = cnt_val; }\n\n  int64_t Increase();\n  int64_t Decrease();\n  void WaitForeverUntilCntEqualZero();\n  Maybe<void> WaitUntilCntEqualZero(size_t timeout_seconds);\n  Maybe<void> WaitUntilCntEqualZero(const std::function<Maybe<bool>()>& StopWaitingAfterTimeout);\n\n private:\n  std::mutex mtx_;\n  std::condition_variable cond_;\n  int64_t cnt_val_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_BLOCKING_COUNTER_H_\n"
  },
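`BlockingCounter` lets one thread sleep until `cnt_val` calls to `Decrease` have happened; the timeout overload in the `.cpp` above re-waits in `ONEFLOW_TIMEOUT_SECONDS` slices so a caller-supplied predicate can abort a stuck wait. A self-contained sketch of the underlying pattern (a hypothetical `MiniBlockingCounter`; the repo class additionally routes waits through `ForeignLockHelper` and reports errors via `Maybe<void>`):

```cpp
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <thread>
#include <vector>

// Self-contained sketch of the blocking-counter pattern (not the repo class):
// workers call Decrease(), the owner blocks until the count reaches zero.
class MiniBlockingCounter {
 public:
  explicit MiniBlockingCounter(int64_t cnt) : cnt_(cnt) {}
  void Decrease() {
    std::unique_lock<std::mutex> lck(mtx_);
    if (--cnt_ == 0) { cond_.notify_all(); }
  }
  void Wait() {
    std::unique_lock<std::mutex> lck(mtx_);
    cond_.wait(lck, [this]() { return cnt_ == 0; });
  }

 private:
  std::mutex mtx_;
  std::condition_variable cond_;
  int64_t cnt_;
};

int main() {
  MiniBlockingCounter counter(4);
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i) {
    workers.emplace_back([&counter]() { counter.Decrease(); });
  }
  counter.Wait();  // returns once all four workers have decremented
  for (auto& t : workers) { t.join(); }
  return 0;
}
```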
  {
    "path": "oneflow/core/common/blocking_then_busy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_BLOCKING_THEN_BUSY_H_\n#define ONEFLOW_CORE_COMMON_BLOCKING_THEN_BUSY_H_\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/notifier.h\"\n#include \"oneflow/core/common/spin_counter.h\"\n\nnamespace oneflow {\n\nclass BlockingThenBusy final {\n public:\n  BlockingThenBusy(const BlockingThenBusy&) = delete;\n  BlockingThenBusy(BlockingThenBusy&&) = delete;\n  constexpr static int kCnt = 1;\n  BlockingThenBusy() : notifier_(), spin_counter_(kCnt) {}\n\n  Notifier* mut_notifier() { return &notifier_; }\n  SpinCounter* mut_spin_counter() { return &spin_counter_; }\n\n  void Reset() { mut_spin_counter()->Reset(kCnt); }\n\n  Maybe<void> WaitUntilCntEqualZero(const std::function<Maybe<bool>()>& StopAfterTimeout) {\n    JUST(notifier_.TimedWaitAndClearNotifiedCnt(StopAfterTimeout));\n    JUST(spin_counter_.WaitUntilCntEqualZero());\n    return Maybe<void>::Ok();\n  }\n\n private:\n  Notifier notifier_;\n  SpinCounter spin_counter_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_BLOCKING_THEN_BUSY_H_\n"
  },
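The point of combining the two primitives is latency: the waiter sleeps on the cheap blocking `Notifier` for the bulk of the delay, then busy-spins on the `SpinCounter` for the final step so the last wakeup does not pay a scheduler round-trip. A standalone sketch of that blocking-then-busy idea using only the standard library (not the repo's `Notifier`/`SpinCounter`):

```cpp
#include <atomic>
#include <condition_variable>
#include <mutex>
#include <thread>

// Standalone sketch of the blocking-then-busy idea: sleep on a condition
// variable until the producer signals, then spin on an atomic for the tail.
int main() {
  std::mutex mtx;
  std::condition_variable cv;
  bool notified = false;
  std::atomic<int> spin_cnt(1);

  std::thread producer([&]() {
    {
      std::lock_guard<std::mutex> lck(mtx);
      notified = true;
    }
    cv.notify_one();        // phase 1: wake the blocked waiter
    spin_cnt.fetch_sub(1);  // phase 2: release the spin counter
  });

  {
    std::unique_lock<std::mutex> lck(mtx);
    cv.wait(lck, [&]() { return notified; });  // cheap blocking wait
  }
  while (spin_cnt.load() != 0) {}  // short busy wait for the final step
  producer.join();
  return 0;
}
```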
  {
    "path": "oneflow/core/common/buffer.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_BUFFER_H_\n#define ONEFLOW_CORE_COMMON_BUFFER_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nenum BufferStatus { kBufferStatusSuccess = 0, kBufferStatusErrorClosed, kBufferStatusEmpty };\n\ntemplate<typename T>\nclass Buffer final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Buffer);\n  Buffer(size_t max_len) : max_len_(max_len), is_closed_(false) {}\n  ~Buffer() = default;\n\n  template<typename U>\n  BufferStatus Push(U&& item);\n  BufferStatus Pull(T* item);\n  BufferStatus TryReceive(T* item);\n  void Close();\n\n private:\n  std::queue<T> queue_;\n  mutable std::mutex mutex_;\n  size_t max_len_;\n  bool is_closed_;\n  std::condition_variable cond_;\n};\n\ntemplate<typename T>\ntemplate<typename U>\nBufferStatus Buffer<T>::Push(U&& item) {\n  std::unique_lock<std::mutex> lock(mutex_);\n  cond_.wait(lock, [this]() { return queue_.size() < max_len_ || is_closed_; });\n  if (is_closed_) { return kBufferStatusErrorClosed; }\n  queue_.push(std::forward<U>(item));\n  cond_.notify_one();\n  return kBufferStatusSuccess;\n}\n\ntemplate<typename T>\nBufferStatus Buffer<T>::Pull(T* item) {\n  std::unique_lock<std::mutex> lock(mutex_);\n  cond_.wait(lock, [this]() { return (!queue_.empty()) || is_closed_; });\n  if (queue_.empty()) { return kBufferStatusErrorClosed; }\n  *item = std::move(queue_.front());\n  queue_.pop();\n  if (queue_.size() < max_len_) { cond_.notify_all(); }\n  return kBufferStatusSuccess;\n}\n\ntemplate<typename T>\nBufferStatus Buffer<T>::TryReceive(T* item) {\n  std::unique_lock<std::mutex> lock(mutex_);\n  if (queue_.empty()) { return is_closed_ ? kBufferStatusErrorClosed : kBufferStatusEmpty; }\n  *item = std::move(queue_.front());\n  queue_.pop();\n  if (queue_.size() < max_len_) { cond_.notify_all(); }\n  return kBufferStatusSuccess;\n}\n\ntemplate<typename T>\nvoid Buffer<T>::Close() {\n  std::unique_lock<std::mutex> lock(mutex_);\n  is_closed_ = true;\n  cond_.notify_all();\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_BUFFER_H_\n"
  },
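`Buffer<T>` is a bounded blocking queue: `Push` blocks while `max_len_` items are queued, `Pull` blocks while the queue is empty, and `Close` wakes every waiter; note that after `Close`, `Pull` still drains the remaining items before reporting `kBufferStatusErrorClosed`. A minimal usage sketch, assuming the oneflow headers are available:

```cpp
#include <iostream>
#include <thread>
// Hypothetical usage; assumes the oneflow headers are available.
#include "oneflow/core/common/buffer.h"

int main() {
  oneflow::Buffer<int> buffer(/*max_len=*/2);  // Push blocks once 2 items queue up
  std::thread producer([&buffer]() {
    for (int i = 0; i < 5; ++i) { buffer.Push(i); }
    buffer.Close();  // wakes the consumer; Pull drains the rest, then reports closed
  });
  int item = 0;
  while (buffer.Pull(&item) == oneflow::kBufferStatusSuccess) { std::cout << item << "\n"; }
  producer.join();
  return 0;
}
```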
  {
    "path": "oneflow/core/common/buffer_manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_BUFFER_MANAGER_H_\n#define ONEFLOW_CORE_COMMON_BUFFER_MANAGER_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/buffer.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass BufferMgr final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BufferMgr);\n  ~BufferMgr() = default;\n\n  void NewBuffer(const std::string& buffer_name, size_t buffer_size) {\n    CHECK(name2buffer_.emplace(buffer_name, std::make_unique<Buffer<T>>(buffer_size)).second);\n  }\n  Buffer<T>* Get(const std::string& buffer_name) const {\n    const auto& iter = name2buffer_.find(buffer_name);\n    CHECK(iter != name2buffer_.end()) << \"buffer_name: \" << buffer_name;\n    return iter->second.get();\n  }\n\n private:\n  friend class Singleton<BufferMgr>;\n  BufferMgr() = default;\n\n  HashMap<std::string, std::unique_ptr<Buffer<T>>> name2buffer_;\n};\n\nstatic const std::string kBufferNameGlobalWaitJobId = \"GlobalWaitJobId\";\nstatic const std::string kCallbackNotifierBufferNamePrefix = \"CallbackNotifier-\";\nstatic const std::string kInputCriticalSectionWaitBufferNamePrefix = \"InputCriticalSectionWait-\";\nstatic const std::string kInputCriticalSectionCallbackBufferNamePrefix =\n    \"InputCriticalSectionCallback-\";\nstatic const std::string kOutputCriticalSectionWaitBufferNamePrefix = \"OutputCriticalSectionWait-\";\nstatic const std::string kOutputCriticalSectionCallbackBufferNamePrefix =\n    \"OutputCriticalSectionCallback-\";\nstatic const std::string kInputBufferNamePrefix = \"Input-\";\nstatic const std::string kOutputBufferNamePrefix = \"Output-\";\nstatic const std::string kSourceTickBufferNamePrefix = \"SourceTick-\";\n\ninline std::string GetCallbackNotifierBufferName(const std::string& job_name) {\n  return kCallbackNotifierBufferNamePrefix + job_name;\n}\n\ninline std::string GetInputCriticalSectionWaitBufferName(const std::string& job_name) {\n  return kInputCriticalSectionWaitBufferNamePrefix + job_name;\n}\n\ninline std::string GetInputCriticalSectionCallbackBufferName(const std::string& job_name) {\n  return kInputCriticalSectionCallbackBufferNamePrefix + job_name;\n}\n\ninline std::string GetOutputCriticalSectionWaitBufferName(const std::string& job_name) {\n  return kOutputCriticalSectionWaitBufferNamePrefix + job_name;\n}\n\ninline std::string GetOutputCriticalSectionCallbackBufferName(const std::string& job_name) {\n  return kOutputCriticalSectionCallbackBufferNamePrefix + job_name;\n}\n\ninline std::string GetInputBufferName(const std::string& job_name, const std::string& op_name) {\n  return kInputBufferNamePrefix + job_name + \"-\" + op_name;\n}\n\ninline std::string GetOutputBufferName(const std::string& job_name, const std::string& op_name) {\n  return kOutputBufferNamePrefix + job_name + \"-\" + op_name;\n}\n\ninline std::string GetSourceTickBufferName(const std::string& job_name) {\n  return kSourceTickBufferNamePrefix + 
job_name;\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_BUFFER_MANAGER_H_\n"
  },
  {
    "path": "oneflow/core/common/cached_caller.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/cached_caller.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n\nnamespace oneflow {\n\nbool IsThreadLocalCacheEnabled() {\n  if (Singleton<ResourceDesc, ForSession>::Get() == nullptr) { return true; }\n  return Singleton<ResourceDesc, ForSession>::Get()->enable_thread_local_cache();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/cached_caller.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_CACHED_CALLER_H_\n#define ONEFLOW_CORE_COMMON_CACHED_CALLER_H_\n\n#include <list>\n#include <tuple>\n#include <thread>\n#include \"oneflow/core/common/function_traits.h\"\n#include \"oneflow/core/common/hash_eq_trait_ptr.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/tuple_hash.h\"\n\n// gcc 11 falsely reports error:\n// ‘void operator delete(void*, std::size_t)’ called on unallocated object ‘cache’\n// However, `DeleteAndClear` is only called after `cache` is allocated in\n// if (cache == nullptr) block.\n// The reason not to use #pragma GCC diagnostic push/pop is that gcc reports\n// the error on the caller of `ThreadLocalCachedCall`.\n// TODO: replace ThreadLocalCachedCall with ThreadLocalCached decorator?\n#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 11\n#pragma GCC diagnostic ignored \"-Wfree-nonheap-object\"\n#endif\n\nnamespace oneflow {\n\ntemplate<typename T>\nvoid DeleteAndClear(T** ptr, size_t obj_cnt) {\n  static const size_t kThreshold = 4096;\n  if (obj_cnt <= kThreshold) {\n    delete ptr;\n  } else {\n    std::thread([](T* ptr) { delete ptr; }, *ptr);\n  }\n  *ptr = nullptr;\n}\n\nbool IsThreadLocalCacheEnabled();\n\ntemplate<\n    typename F, typename Ret = typename function_traits<F>::return_type,\n    typename RawArg = typename std::tuple_element<0, typename function_traits<F>::args_type>::type,\n    typename Arg = typename std::remove_const<typename std::remove_reference<RawArg>::type>::type>\nRet ThreadLocalCachedCall(size_t max_size, F f, const Arg& arg) {\n  if (IsThreadLocalCacheEnabled() == false) { return f(arg); }\n  using HashMap = std::unordered_map<HashEqTraitPtr<const Arg>, Ret>;\n  using KeyStorage = std::list<std::unique_ptr<Arg>>;\n  static thread_local HashMap* cache = nullptr;\n  static thread_local KeyStorage* key_storage = nullptr;\n  if (cache != nullptr && cache->size() >= max_size) {\n    DeleteAndClear(&cache, cache->size());\n    DeleteAndClear(&key_storage, cache->size());\n  }\n  if (cache == nullptr) {\n    cache = new HashMap();\n    key_storage = new KeyStorage();\n  }\n  size_t hash_value = std::hash<Arg>()(arg);\n  {\n    HashEqTraitPtr<const Arg> ptr_wrapper(&arg, hash_value);\n    const auto& iter = cache->find(ptr_wrapper);\n    if (iter != cache->end()) { return iter->second; }\n  }\n  Arg* new_arg = new Arg(arg);\n  key_storage->emplace_back(new_arg);\n  HashEqTraitPtr<const Arg> ptr_wrapper(new_arg, hash_value);\n  return cache->emplace(ptr_wrapper, f(*new_arg)).first->second;\n}\n\ntemplate<\n    typename F, typename Ret = typename function_traits<F>::return_type,\n    typename RawArg = typename std::tuple_element<0, typename function_traits<F>::args_type>::type,\n    typename Arg = typename std::remove_const<typename std::remove_reference<RawArg>::type>::type>\nstd::function<Ret(const Arg&)> WithResultCached(F f) {\n  auto cache = 
std::make_shared<std::unordered_map<Arg, Ret>>();\n  return [cache, f](const Arg& arg) -> Ret {\n    const auto& iter = cache->find(arg);\n    if (iter != cache->end()) { return iter->second; }\n    return cache->emplace(arg, f(arg)).first->second;\n  };\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_CACHED_CALLER_H_\n"
  },
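`ThreadLocalCachedCall` keeps a per-thread map keyed by `HashEqTraitPtr` (a pointer plus its precomputed hash) and owns the keys in a side `std::list`, so a lookup hashes the argument once and never copies it on a hit; `WithResultCached` is the simpler shared-map variant. A minimal usage sketch of `WithResultCached`, assuming oneflow's `function_traits` can deduce the signature of a capturing lambda:

```cpp
#include <iostream>
// Hypothetical usage; assumes the oneflow headers are available and that
// function_traits handles a capturing lambda's operator().
#include "oneflow/core/common/cached_caller.h"

int main() {
  int calls = 0;
  auto slow_square = [&calls](const int& x) {
    ++calls;  // counts how often the wrapped function actually runs
    return x * x;
  };
  auto cached = oneflow::WithResultCached(slow_square);
  std::cout << cached(7) << " " << cached(7) << "\n";  // 49 49
  std::cout << calls << "\n";                          // 1: the second call hit the cache
  return 0;
}
```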
  {
    "path": "oneflow/core/common/cblas.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_CBLAS_H_\n#define ONEFLOW_CORE_COMMON_CBLAS_H_\n#include <stddef.h>\n\n/*\n * Enumerated and derived types\n */\n#define CBLAS_INDEX size_t /* this may vary between platforms */\n\nenum CBLAS_ORDER { CblasRowMajor = 101, CblasColMajor = 102 };\nenum CBLAS_TRANSPOSE { CblasNoTrans = 111, CblasTrans = 112, CblasConjTrans = 113 };\nenum CBLAS_UPLO { CblasUpper = 121, CblasLower = 122 };\nenum CBLAS_DIAG { CblasNonUnit = 131, CblasUnit = 132 };\nenum CBLAS_SIDE { CblasLeft = 141, CblasRight = 142 };\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n/*\n * ===========================================================================\n * Prototypes for level 1 BLAS functions (complex are recast as routines)\n * ===========================================================================\n */\nfloat cblas_sdsdot(const int N, const float alpha, const float* X, const int incX, const float* Y,\n                   const int incY);\ndouble cblas_dsdot(const int N, const float* X, const int incX, const float* Y, const int incY);\nfloat cblas_sdot(const int N, const float* X, const int incX, const float* Y, const int incY);\ndouble cblas_ddot(const int N, const double* X, const int incX, const double* Y, const int incY);\n\n/*\n * Functions having prefixes Z and C only\n */\nvoid cblas_cdotu_sub(const int N, const void* X, const int incX, const void* Y, const int incY,\n                     void* dotu);\nvoid cblas_cdotc_sub(const int N, const void* X, const int incX, const void* Y, const int incY,\n                     void* dotc);\n\nvoid cblas_zdotu_sub(const int N, const void* X, const int incX, const void* Y, const int incY,\n                     void* dotu);\nvoid cblas_zdotc_sub(const int N, const void* X, const int incX, const void* Y, const int incY,\n                     void* dotc);\n\n/*\n * Functions having prefixes S D SC DZ\n */\nfloat cblas_snrm2(const int N, const float* X, const int incX);\nfloat cblas_sasum(const int N, const float* X, const int incX);\n\ndouble cblas_dnrm2(const int N, const double* X, const int incX);\ndouble cblas_dasum(const int N, const double* X, const int incX);\n\nfloat cblas_scnrm2(const int N, const void* X, const int incX);\nfloat cblas_scasum(const int N, const void* X, const int incX);\n\ndouble cblas_dznrm2(const int N, const void* X, const int incX);\ndouble cblas_dzasum(const int N, const void* X, const int incX);\n\n/*\n * Functions having standard 4 prefixes (S D C Z)\n */\nCBLAS_INDEX cblas_isamax(const int N, const float* X, const int incX);\nCBLAS_INDEX cblas_idamax(const int N, const double* X, const int incX);\nCBLAS_INDEX cblas_icamax(const int N, const void* X, const int incX);\nCBLAS_INDEX cblas_izamax(const int N, const void* X, const int incX);\n\n/*\n * ===========================================================================\n * Prototypes for level 1 BLAS routines\n * 
===========================================================================\n */\n\n/*\n * Routines with standard 4 prefixes (s, d, c, z)\n */\nvoid cblas_sswap(const int N, float* X, const int incX, float* Y, const int incY);\nvoid cblas_scopy(const int N, const float* X, const int incX, float* Y, const int incY);\nvoid cblas_saxpy(const int N, const float alpha, const float* X, const int incX, float* Y,\n                 const int incY);\n\nvoid cblas_dswap(const int N, double* X, const int incX, double* Y, const int incY);\nvoid cblas_dcopy(const int N, const double* X, const int incX, double* Y, const int incY);\nvoid cblas_daxpy(const int N, const double alpha, const double* X, const int incX, double* Y,\n                 const int incY);\n\nvoid cblas_cswap(const int N, void* X, const int incX, void* Y, const int incY);\nvoid cblas_ccopy(const int N, const void* X, const int incX, void* Y, const int incY);\nvoid cblas_caxpy(const int N, const void* alpha, const void* X, const int incX, void* Y,\n                 const int incY);\n\nvoid cblas_zswap(const int N, void* X, const int incX, void* Y, const int incY);\nvoid cblas_zcopy(const int N, const void* X, const int incX, void* Y, const int incY);\nvoid cblas_zaxpy(const int N, const void* alpha, const void* X, const int incX, void* Y,\n                 const int incY);\n\n/*\n * Routines with S and D prefix only\n */\nvoid cblas_srotg(float* a, float* b, float* c, float* s);\nvoid cblas_srotmg(float* d1, float* d2, float* b1, const float b2, float* P);\nvoid cblas_srot(const int N, float* X, const int incX, float* Y, const int incY, const float c,\n                const float s);\nvoid cblas_srotm(const int N, float* X, const int incX, float* Y, const int incY, const float* P);\n\nvoid cblas_drotg(double* a, double* b, double* c, double* s);\nvoid cblas_drotmg(double* d1, double* d2, double* b1, const double b2, double* P);\nvoid cblas_drot(const int N, double* X, const int incX, double* Y, const int incY, const double c,\n                const double s);\nvoid cblas_drotm(const int N, double* X, const int incX, double* Y, const int incY,\n                 const double* P);\n\n/*\n * Routines with S D C Z CS and ZD prefixes\n */\nvoid cblas_sscal(const int N, const float alpha, float* X, const int incX);\nvoid cblas_dscal(const int N, const double alpha, double* X, const int incX);\nvoid cblas_cscal(const int N, const void* alpha, void* X, const int incX);\nvoid cblas_zscal(const int N, const void* alpha, void* X, const int incX);\nvoid cblas_csscal(const int N, const float alpha, void* X, const int incX);\nvoid cblas_zdscal(const int N, const double alpha, void* X, const int incX);\n\n/*\n * ===========================================================================\n * Prototypes for level 2 BLAS\n * ===========================================================================\n */\n\n/*\n * Routines with standard 4 prefixes (S, D, C, Z)\n */\nvoid cblas_sgemv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M,\n                 const int N, const float alpha, const float* A, const int lda, const float* X,\n                 const int incX, const float beta, float* Y, const int incY);\nvoid cblas_sgbmv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M,\n                 const int N, const int KL, const int KU, const float alpha, const float* A,\n                 const int lda, const float* X, const int incX, const float beta, float* Y,\n                 const int incY);\nvoid 
cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const float* A, const int lda, float* X, const int incX);\nvoid cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const int K, const float* A, const int lda, float* X, const int incX);\nvoid cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const float* Ap, float* X, const int incX);\nvoid cblas_strsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const float* A, const int lda, float* X, const int incX);\nvoid cblas_stbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const int K, const float* A, const int lda, float* X, const int incX);\nvoid cblas_stpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const float* Ap, float* X, const int incX);\n\nvoid cblas_dgemv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M,\n                 const int N, const double alpha, const double* A, const int lda, const double* X,\n                 const int incX, const double beta, double* Y, const int incY);\nvoid cblas_dgbmv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M,\n                 const int N, const int KL, const int KU, const double alpha, const double* A,\n                 const int lda, const double* X, const int incX, const double beta, double* Y,\n                 const int incY);\nvoid cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const double* A, const int lda, double* X, const int incX);\nvoid cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const int K, const double* A, const int lda, double* X, const int incX);\nvoid cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const double* Ap, double* X, const int incX);\nvoid cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const double* A, const int lda, double* X, const int incX);\nvoid cblas_dtbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const int K, const double* A, const int lda, double* X, const int incX);\nvoid cblas_dtpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const double* Ap, double* X, const int 
incX);\n\nvoid cblas_cgemv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M,\n                 const int N, const void* alpha, const void* A, const int lda, const void* X,\n                 const int incX, const void* beta, void* Y, const int incY);\nvoid cblas_cgbmv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M,\n                 const int N, const int KL, const int KU, const void* alpha, const void* A,\n                 const int lda, const void* X, const int incX, const void* beta, void* Y,\n                 const int incY);\nvoid cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const void* A, const int lda, void* X, const int incX);\nvoid cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const int K, const void* A, const int lda, void* X, const int incX);\nvoid cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const void* Ap, void* X, const int incX);\nvoid cblas_ctrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const void* A, const int lda, void* X, const int incX);\nvoid cblas_ctbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const int K, const void* A, const int lda, void* X, const int incX);\nvoid cblas_ctpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const void* Ap, void* X, const int incX);\n\nvoid cblas_zgemv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M,\n                 const int N, const void* alpha, const void* A, const int lda, const void* X,\n                 const int incX, const void* beta, void* Y, const int incY);\nvoid cblas_zgbmv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M,\n                 const int N, const int KL, const int KU, const void* alpha, const void* A,\n                 const int lda, const void* X, const int incX, const void* beta, void* Y,\n                 const int incY);\nvoid cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const void* A, const int lda, void* X, const int incX);\nvoid cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const int K, const void* A, const int lda, void* X, const int incX);\nvoid cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const void* Ap, void* X, const int incX);\nvoid cblas_ztrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n            
     const void* A, const int lda, void* X, const int incX);\nvoid cblas_ztbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const int K, const void* A, const int lda, void* X, const int incX);\nvoid cblas_ztpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,\n                 const void* Ap, void* X, const int incX);\n\n/*\n * Routines with S and D prefixes only\n */\nvoid cblas_ssymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                 const float alpha, const float* A, const int lda, const float* X, const int incX,\n                 const float beta, float* Y, const int incY);\nvoid cblas_ssbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const int K,\n                 const float alpha, const float* A, const int lda, const float* X, const int incX,\n                 const float beta, float* Y, const int incY);\nvoid cblas_sspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                 const float alpha, const float* Ap, const float* X, const int incX,\n                 const float beta, float* Y, const int incY);\nvoid cblas_sger(const enum CBLAS_ORDER order, const int M, const int N, const float alpha,\n                const float* X, const int incX, const float* Y, const int incY, float* A,\n                const int lda);\nvoid cblas_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                const float alpha, const float* X, const int incX, float* A, const int lda);\nvoid cblas_sspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                const float alpha, const float* X, const int incX, float* Ap);\nvoid cblas_ssyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                 const float alpha, const float* X, const int incX, const float* Y, const int incY,\n                 float* A, const int lda);\nvoid cblas_sspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                 const float alpha, const float* X, const int incX, const float* Y, const int incY,\n                 float* A);\n\nvoid cblas_dsymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                 const double alpha, const double* A, const int lda, const double* X,\n                 const int incX, const double beta, double* Y, const int incY);\nvoid cblas_dsbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const int K,\n                 const double alpha, const double* A, const int lda, const double* X,\n                 const int incX, const double beta, double* Y, const int incY);\nvoid cblas_dspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                 const double alpha, const double* Ap, const double* X, const int incX,\n                 const double beta, double* Y, const int incY);\nvoid cblas_dger(const enum CBLAS_ORDER order, const int M, const int N, const double alpha,\n                const double* X, const int incX, const double* Y, const int incY, double* A,\n                const int lda);\nvoid cblas_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                const double alpha, const double* X, const int incX, double* A, const int 
lda);\nvoid cblas_dspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                const double alpha, const double* X, const int incX, double* Ap);\nvoid cblas_dsyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                 const double alpha, const double* X, const int incX, const double* Y,\n                 const int incY, double* A, const int lda);\nvoid cblas_dspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                 const double alpha, const double* X, const int incX, const double* Y,\n                 const int incY, double* A);\n\n/*\n * Routines with C and Z prefixes only\n */\nvoid cblas_chemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                 const void* alpha, const void* A, const int lda, const void* X, const int incX,\n                 const void* beta, void* Y, const int incY);\nvoid cblas_chbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const int K,\n                 const void* alpha, const void* A, const int lda, const void* X, const int incX,\n                 const void* beta, void* Y, const int incY);\nvoid cblas_chpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                 const void* alpha, const void* Ap, const void* X, const int incX, const void* beta,\n                 void* Y, const int incY);\nvoid cblas_cgeru(const enum CBLAS_ORDER order, const int M, const int N, const void* alpha,\n                 const void* X, const int incX, const void* Y, const int incY, void* A,\n                 const int lda);\nvoid cblas_cgerc(const enum CBLAS_ORDER order, const int M, const int N, const void* alpha,\n                 const void* X, const int incX, const void* Y, const int incY, void* A,\n                 const int lda);\nvoid cblas_cher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                const float alpha, const void* X, const int incX, void* A, const int lda);\nvoid cblas_chpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                const float alpha, const void* X, const int incX, void* A);\nvoid cblas_cher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                 const void* alpha, const void* X, const int incX, const void* Y, const int incY,\n                 void* A, const int lda);\nvoid cblas_chpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                 const void* alpha, const void* X, const int incX, const void* Y, const int incY,\n                 void* Ap);\n\nvoid cblas_zhemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                 const void* alpha, const void* A, const int lda, const void* X, const int incX,\n                 const void* beta, void* Y, const int incY);\nvoid cblas_zhbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const int K,\n                 const void* alpha, const void* A, const int lda, const void* X, const int incX,\n                 const void* beta, void* Y, const int incY);\nvoid cblas_zhpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                 const void* alpha, const void* Ap, const void* X, const int incX, const void* beta,\n                 void* Y, const int incY);\nvoid cblas_zgeru(const enum CBLAS_ORDER order, const int M, const int N, const void* alpha,\n                 const void* X, const int incX, 
const void* Y, const int incY, void* A,\n                 const int lda);\nvoid cblas_zgerc(const enum CBLAS_ORDER order, const int M, const int N, const void* alpha,\n                 const void* X, const int incX, const void* Y, const int incY, void* A,\n                 const int lda);\nvoid cblas_zher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                const double alpha, const void* X, const int incX, void* A, const int lda);\nvoid cblas_zhpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                const double alpha, const void* X, const int incX, void* A);\nvoid cblas_zher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                 const void* alpha, const void* X, const int incX, const void* Y, const int incY,\n                 void* A, const int lda);\nvoid cblas_zhpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,\n                 const void* alpha, const void* X, const int incX, const void* Y, const int incY,\n                 void* Ap);\n\n/*\n * ===========================================================================\n * Prototypes for level 3 BLAS\n * ===========================================================================\n */\n\n/*\n * Routines with standard 4 prefixes (S, D, C, Z)\n */\nvoid cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,\n                 const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,\n                 const float alpha, const float* A, const int lda, const float* B, const int ldb,\n                 const float beta, float* C, const int ldc);\nvoid cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,\n                 const enum CBLAS_UPLO Uplo, const int M, const int N, const float alpha,\n                 const float* A, const int lda, const float* B, const int ldb, const float beta,\n                 float* C, const int ldc);\nvoid cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const float alpha,\n                 const float* A, const int lda, const float beta, float* C, const int ldc);\nvoid cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,\n                  const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const float alpha,\n                  const float* A, const int lda, const float* B, const int ldb, const float beta,\n                  float* C, const int ldc);\nvoid cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,\n                 const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,\n                 const enum CBLAS_DIAG Diag, const int M, const int N, const float alpha,\n                 const float* A, const int lda, float* B, const int ldb);\nvoid cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,\n                 const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,\n                 const enum CBLAS_DIAG Diag, const int M, const int N, const float alpha,\n                 const float* A, const int lda, float* B, const int ldb);\n\nvoid cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,\n                 const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,\n                 const double alpha, const double* A, const int lda, const double* B, const int ldb,\n                 const double beta, double* C, 
const int ldc);\nvoid cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,\n                 const enum CBLAS_UPLO Uplo, const int M, const int N, const double alpha,\n                 const double* A, const int lda, const double* B, const int ldb, const double beta,\n                 double* C, const int ldc);\nvoid cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const double alpha,\n                 const double* A, const int lda, const double beta, double* C, const int ldc);\nvoid cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,\n                  const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const double alpha,\n                  const double* A, const int lda, const double* B, const int ldb, const double beta,\n                  double* C, const int ldc);\nvoid cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,\n                 const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,\n                 const enum CBLAS_DIAG Diag, const int M, const int N, const double alpha,\n                 const double* A, const int lda, double* B, const int ldb);\nvoid cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,\n                 const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,\n                 const enum CBLAS_DIAG Diag, const int M, const int N, const double alpha,\n                 const double* A, const int lda, double* B, const int ldb);\n\nvoid cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,\n                 const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,\n                 const void* alpha, const void* A, const int lda, const void* B, const int ldb,\n                 const void* beta, void* C, const int ldc);\nvoid cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,\n                 const enum CBLAS_UPLO Uplo, const int M, const int N, const void* alpha,\n                 const void* A, const int lda, const void* B, const int ldb, const void* beta,\n                 void* C, const int ldc);\nvoid cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha,\n                 const void* A, const int lda, const void* beta, void* C, const int ldc);\nvoid cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,\n                  const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha,\n                  const void* A, const int lda, const void* B, const int ldb, const void* beta,\n                  void* C, const int ldc);\nvoid cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,\n                 const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,\n                 const enum CBLAS_DIAG Diag, const int M, const int N, const void* alpha,\n                 const void* A, const int lda, void* B, const int ldb);\nvoid cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,\n                 const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,\n                 const enum CBLAS_DIAG Diag, const int M, const int N, const void* alpha,\n                 const void* A, const int lda, void* B, const int ldb);\n\nvoid cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,\n                 const enum 
CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,\n                 const void* alpha, const void* A, const int lda, const void* B, const int ldb,\n                 const void* beta, void* C, const int ldc);\nvoid cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,\n                 const enum CBLAS_UPLO Uplo, const int M, const int N, const void* alpha,\n                 const void* A, const int lda, const void* B, const int ldb, const void* beta,\n                 void* C, const int ldc);\nvoid cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha,\n                 const void* A, const int lda, const void* beta, void* C, const int ldc);\nvoid cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,\n                  const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha,\n                  const void* A, const int lda, const void* B, const int ldb, const void* beta,\n                  void* C, const int ldc);\nvoid cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,\n                 const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,\n                 const enum CBLAS_DIAG Diag, const int M, const int N, const void* alpha,\n                 const void* A, const int lda, void* B, const int ldb);\nvoid cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,\n                 const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,\n                 const enum CBLAS_DIAG Diag, const int M, const int N, const void* alpha,\n                 const void* A, const int lda, void* B, const int ldb);\n\n/*\n * Routines with prefixes C and Z only\n */\nvoid cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,\n                 const enum CBLAS_UPLO Uplo, const int M, const int N, const void* alpha,\n                 const void* A, const int lda, const void* B, const int ldb, const void* beta,\n                 void* C, const int ldc);\nvoid cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const float alpha,\n                 const void* A, const int lda, const float beta, void* C, const int ldc);\nvoid cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,\n                  const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha,\n                  const void* A, const int lda, const void* B, const int ldb, const float beta,\n                  void* C, const int ldc);\n\nvoid cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,\n                 const enum CBLAS_UPLO Uplo, const int M, const int N, const void* alpha,\n                 const void* A, const int lda, const void* B, const int ldb, const void* beta,\n                 void* C, const int ldc);\nvoid cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,\n                 const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const double alpha,\n                 const void* A, const int lda, const double beta, void* C, const int ldc);\nvoid cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,\n                  const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha,\n                  const void* A, const int lda, const void* B, const int ldb, const double beta,\n                  
void* C, const int ldc);\n\nvoid cblas_xerbla(int p, const char* rout, const char* form, ...);\n\n#ifdef __cplusplus\n}\n#endif\n#endif\n"
  },
  {
    "path": "oneflow/core/common/channel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_CHANNEL_H_\n#define ONEFLOW_CORE_COMMON_CHANNEL_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nenum ChannelStatus { kChannelStatusSuccess = 0, kChannelStatusErrorClosed };\n\ntemplate<typename T>\nclass Channel final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Channel);\n  Channel() : is_closed_(false) {}\n  ~Channel() = default;\n\n  template<typename U>\n  ChannelStatus Send(U&& item);\n  ChannelStatus Receive(T* item);\n  ChannelStatus ReceiveMany(std::queue<T>* items);\n  void Close();\n\n private:\n  std::queue<T> queue_;\n  std::mutex mutex_;\n  bool is_closed_;\n  std::condition_variable cond_;\n};\n\ntemplate<typename T>\ntemplate<typename U>\nChannelStatus Channel<T>::Send(U&& item) {\n  bool notify;\n  {\n    std::unique_lock<std::mutex> lock(mutex_);\n    if (is_closed_) { return kChannelStatusErrorClosed; }\n    notify = queue_.empty();\n    queue_.push(std::forward<U>(item));\n  }\n  if (notify) { cond_.notify_one(); }\n  return kChannelStatusSuccess;\n}\n\ntemplate<typename T>\nChannelStatus Channel<T>::Receive(T* item) {\n  std::unique_lock<std::mutex> lock(mutex_);\n  cond_.wait(lock, [this]() { return (!queue_.empty()) || is_closed_; });\n  if (queue_.empty()) { return kChannelStatusErrorClosed; }\n  *item = std::move(queue_.front());\n  queue_.pop();\n  return kChannelStatusSuccess;\n}\n\ntemplate<typename T>\nChannelStatus Channel<T>::ReceiveMany(std::queue<T>* items) {\n  std::unique_lock<std::mutex> lock(mutex_);\n  cond_.wait(lock, [this]() { return (!queue_.empty()) || is_closed_; });\n  if (queue_.empty()) { return kChannelStatusErrorClosed; }\n  while (!queue_.empty()) {\n    items->push(std::move(queue_.front()));\n    queue_.pop();\n  }\n  return kChannelStatusSuccess;\n}\n\ntemplate<typename T>\nvoid Channel<T>::Close() {\n  std::unique_lock<std::mutex> lock(mutex_);\n  is_closed_ = true;\n  cond_.notify_all();\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_CHANNEL_H_\n"
  },
  {
    "path": "oneflow/core/common/channel_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/channel.h\"\n#include \"oneflow/core/common/range.h\"\n\nnamespace oneflow {\n\nvoid CallFromSenderThread(Channel<int>* channel, Range range) {\n  for (int i = range.begin(); i < range.end(); ++i) {\n    if (channel->Send(i) != kChannelStatusSuccess) { break; }\n  }\n}\n\nvoid CallFromReceiverThread(std::vector<int>* visit, Channel<int>* channel) {\n  int num = -1;\n  int* num_ptr = &num;\n  while (channel->Receive(num_ptr) == kChannelStatusSuccess) { ++visit->at(*num_ptr); }\n}\n\nTEST(Channel, 30sender40receiver) {\n  Channel<int> channel;\n  std::vector<std::thread> senders;\n  std::vector<std::thread> receivers;\n  int sender_num = 30;\n  int receiver_num = 40;\n  int range_num = 200;\n  std::vector<std::vector<int>> visits;\n  for (int i = 0; i < receiver_num; ++i) {\n    std::vector<int> visit_i;\n    for (int j = 0; j < range_num; j++) { visit_i.emplace_back(0); }\n    visits.emplace_back(visit_i);\n  }\n  for (int i = 0; i < sender_num; ++i) {\n    senders.emplace_back(CallFromSenderThread, &channel, Range(0, range_num));\n  }\n  for (int i = 0; i < receiver_num; ++i) {\n    receivers.emplace_back(CallFromReceiverThread, &visits[i], &channel);\n  }\n  for (std::thread& this_thread : senders) { this_thread.join(); }\n  channel.Close();\n  for (std::thread& this_thread : receivers) { this_thread.join(); }\n  for (int i = 0; i < range_num; ++i) {\n    int visit_count = 0;\n    for (int j = 0; j < receiver_num; j++) { visit_count += visits[j][i]; }\n    ASSERT_EQ(visit_count, sender_num);\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/check.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/throw.h\"\nnamespace oneflow {\nvoid GLOGCHECK(bool value) { CHECK_OR_THROW(value); }\nvoid GLOGLOGFATAL(const char* error_msg) { LOG(FATAL) << error_msg; }\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/core/common/check.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n// The functions in this header file are used to replace `CHECK` and `LOG(FATAL)` macros of glog\n// in those header files included by oneflow/core/common/throw.h, so those header files\n// do not need to include <glog/logging.h>, and we can undef CHECK series macro of\n// glog in oneflow/core/common/throw.h and use another impl instead with less modification.\nnamespace oneflow {\nvoid GLOGCHECK(bool);\nvoid GLOGLOGFATAL(const char*);\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/core/common/check_level.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cstdlib>\n#include <type_traits>\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n\nnamespace oneflow {\n\nbool IsEnvEnabled(int32_t check_level) {\n  static const int env_check_level = ParseIntegerFromEnv(\"ONEFLOW_CHECK_LEVEL\", -1);\n  static const bool env_debug_mode = IsInDebugMode();\n  return env_debug_mode || env_check_level >= check_level;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/check_level.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_CHECK_LEVEL_H_\n#define ONEFLOW_CORE_COMMON_CHECK_LEVEL_H_\n\nnamespace oneflow {\n\nbool IsEnvEnabled(int32_t check_level);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_CHECK_LEVEL_H_\n"
  },
  {
    "path": "oneflow/core/common/constant.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_CONSTANT_H_\n#define ONEFLOW_CORE_COMMON_CONSTANT_H_\n\n#include <string>\n\nnamespace oneflow {\n\nstatic const int64_t kInvalidSessionId = -1;\nstatic const std::string kNoPassTag = \"\";\nstatic const std::string kMainOp = \"main_op\";\nstatic const int64_t kMaxSplitAxis = 6;\nconstexpr size_t kMaxNumDims = 8;\nstatic const std::string kAsymmetricCodeErrorMsg =\n    \"Maybe executing different code in different ranks, please check if the code is branched and \"\n    \"operates on the global tensor.\";\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_CONSTANT_H_\n"
  },
  {
    "path": "oneflow/core/common/container_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_CONTAINER_UTIL_H_\n#define ONEFLOW_CORE_COMMON_CONTAINER_UTIL_H_\n\n#include <vector>\n#include \"oneflow/core/common/hash_container.h\"\n#include \"oneflow/core/common/type_traits.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\ntemplate<typename MapT, typename KeyT, typename U>\nscalar_or_const_ref_t<typename MapT::mapped_type> MapAt(const MapT& map, const KeyT& key,\n                                                        const U& default_val) {\n  const auto& iter = map.find(key);\n  if (iter == map.end()) { return default_val; }\n  return iter->second;\n}\n\ntemplate<typename MapT, typename KeyT>\nMaybe<scalar_or_const_ref_t<typename MapT::mapped_type>> MapAt(const MapT& map, const KeyT& key) {\n  const auto& iter = map.find(key);\n  if constexpr (printable<KeyT>()) {\n    CHECK_OR_RETURN(iter != map.end()) << \"Key \\\"\" << key << \"\\\" not found\";\n  } else {\n    CHECK_OR_RETURN(iter != map.end())\n        << \"MapAt failed, but the key is not printable. Please implement operator<< if you want to \"\n           \"see the key in this error message.\";\n  }\n  return iter->second;\n}\n\ntemplate<typename MapT, typename KeyT>\nMaybe<typename MapT::mapped_type&> MapAt(MapT& map, const KeyT& key) {\n  const auto& iter = map.find(key);\n  if constexpr (printable<KeyT>()) {\n    CHECK_OR_RETURN(iter != map.end()) << \"Key \\\"\" << key << \"\\\" not found\";\n  } else {\n    CHECK_OR_RETURN(iter != map.end())\n        << \"MapAt failed, but the key is not printable. 
Please implement operator<< if you want to \"\n           \"see the key in this error message.\";\n  }\n  return iter->second;\n}\n\ntemplate<typename VecT>\nMaybe<scalar_or_const_ref_t<typename VecT::value_type>> VectorAt(const VecT& vec,\n                                                                 typename VecT::size_type index) {\n  CHECK_LT_OR_RETURN(index, vec.size());\n  return vec[index];\n}\n\ntemplate<typename VecT>\nMaybe<typename VecT::value_type&> VectorAt(VecT& vec, typename VecT::size_type index) {\n  static_assert(!std::is_same<typename VecT::value_type, bool>::value,\n                \"VectorAt(vector<bool>&, size_t) is not supported.\");\n  CHECK_LT_OR_RETURN(index, vec.size());\n  return vec[index];\n}\n\ntemplate<>\ninline Maybe<bool> VectorAt(const std::vector<bool>& vec,\n                            typename std::vector<bool>::size_type index) {\n  CHECK_LT_OR_RETURN(index, vec.size());\n  // convert vector bool proxy to bool\n  return static_cast<bool>(vec[index]);\n}\n\ntemplate<typename T>\nstd::string Join(const T& con, const std::string& delimiter) {\n  std::ostringstream os;\n  auto b = begin(con), e = end(con);\n\n  if (b != e) {\n    std::copy(b, prev(e), std::ostream_iterator<typename T::value_type>(os, delimiter));\n    b = prev(e);\n  }\n  if (b != e) { os << *b; }\n\n  return os.str();\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_CONTAINER_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/common/container_util_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(VectorAt, write_int_vector) {\n  std::vector<int> vec = {1, 2, 3, 4, 5};\n  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 1)), 2);\n  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 3)), 4);\n  CHECK_JUST(VectorAt(vec, 1)) = 6;\n  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 1)), 6);\n  CHECK_JUST(VectorAt(vec, 3)) = 8;\n  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 3)), 8);\n  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 0)), 1);\n  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 2)), 3);\n  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 4)), 5);\n}\n\nnamespace {\nclass A {\n public:\n  explicit A(int a) : a(a) {}\n  int a;\n};\n}  // namespace\n\nTEST(VectorAt, write_custom_class_vector) {\n  std::vector<A> vec = {A(1), A(2)};\n  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 0)).a, 1);\n  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 1)).a, 2);\n  CHECK_JUST(VectorAt(vec, 0)) = A(3);\n  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 0)).a, 3);\n  CHECK_JUST(VectorAt(vec, 1)) = A(4);\n  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 1)).a, 4);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/cost_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_TIME_UTIL_H_\n#define ONEFLOW_CORE_COMMON_TIME_UTIL_H_\n\n#include <chrono>\n#include <sstream>\n#include <string>\n\n#include \"nlohmann/json.hpp\"\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/mem_util.h\"\n#include \"oneflow/core/job/utils/progress_bar.h\"\n\nnamespace oneflow {\n\ntemplate<typename DurationT>\nstruct Duration {\n  static const std::string& Repr() {\n    static const std::string repr = \"\";\n    return repr;\n  }\n};\n\n#define DEFINE_DURATION_TRAIT(time_type)             \\\n  template<>                                         \\\n  struct Duration<typename std::chrono::time_type> { \\\n    static const std::string& Repr() {               \\\n      static const std::string repr = #time_type;    \\\n      return repr;                                   \\\n    }                                                \\\n  };\n\nDEFINE_DURATION_TRAIT(nanoseconds)\nDEFINE_DURATION_TRAIT(microseconds)\nDEFINE_DURATION_TRAIT(milliseconds)\nDEFINE_DURATION_TRAIT(seconds)\nDEFINE_DURATION_TRAIT(minutes)\nDEFINE_DURATION_TRAIT(hours)\n#undef DEFINE_DURATION_TRAIT\n\ntemplate<class Resolution = std::chrono::seconds>\nclass CostCounter final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CostCounter);\n  explicit CostCounter(bool with_log = true, bool with_mem = false)\n      : with_log_(with_log), with_mem_(with_mem) {}\n  ~CostCounter() = default;\n\n  void Count(const std::string& log_prefix = \"\", int v_log_level = 0, bool log_progress = false);\n\n private:\n  using Clock = std::conditional_t<std::chrono::high_resolution_clock::is_steady,\n                                   std::chrono::high_resolution_clock, std::chrono::steady_clock>;\n\n  Clock::time_point start_{Clock::now()};\n  bool with_log_{false};\n  bool with_mem_{false};\n};\n\ntemplate<class Resolution>\nvoid CostCounter<Resolution>::Count(const std::string& log_prefix, int v_log_level,\n                                    bool log_progress) {\n  if (log_progress) { CHECK_JUST(LogProgress(log_prefix)); }\n\n  const auto end = Clock::now();\n  if (FLAGS_minloglevel <= 0 && VLOG_IS_ON(v_log_level) && with_log_ && v_log_level >= 0) {\n    // only do time/mem count and log when glog level is INFO and VLOG level is matched.\n    auto dur = std::chrono::duration_cast<Resolution>(end - start_).count();\n\n    nlohmann::json json_log;\n    json_log[\"loc\"] = log_prefix;\n    json_log[\"time_cost\"] = std::to_string(dur) + \" \" + Duration<Resolution>::Repr();\n\n    if (with_mem_) {\n#ifdef __linux__\n      double vm = 0, rss = 0;\n      ProcessMemUsage(&vm, &rss);\n      json_log[\"mem_rss\"] = std::to_string(rss) + \" MB\";\n#endif  // __linux__\n    }\n\n    if (v_log_level == 0) {\n      LOG(INFO) << \"[count log]\" << json_log.dump();\n    } else {\n      VLOG(v_log_level) << \"[count log]\" << json_log.dump();\n    }\n  }\n  start_ = end;\n  return;\n}\n\n}  
// namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_TIME_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/common/cpp_attribute.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_CPP_ATTRIBUTE_H_\n#define ONEFLOW_CORE_COMMON_CPP_ATTRIBUTE_H_\n\n#define likely GOOGLE_PREDICT_TRUE\n#define unlikely GOOGLE_PREDICT_FALSE\n\n#endif  // ONEFLOW_CORE_COMMON_CPP_ATTRIBUTE_H_\n"
  },
  {
    "path": "oneflow/core/common/data_type.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n\nnamespace oneflow {\n\nbool IsBoolDataType(DataType data_type) {\n  switch (data_type) {\n#define BOOL_CASE(type_cpp, type_proto) \\\n  case type_proto: return true;\n    OF_PP_FOR_EACH_TUPLE(BOOL_CASE, BOOL_DATA_TYPE_SEQ)\n    default: return false;\n  }\n#undef BOOL_CASE\n}\n\nbool IsIntegralDataType(DataType data_type) {\n  switch (data_type) {\n#define INTEGRAL_CASE(type_cpp, type_proto) \\\n  case type_proto: return true;\n    OF_PP_FOR_EACH_TUPLE(INTEGRAL_CASE, INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ)\n    default: return false;\n  }\n#undef INTEGRAL_CASE\n}\nbool IsFloatingDataType(DataType data_type) {\n  switch (data_type) {\n#define FLOATING_CASE(type_cpp, type_proto) \\\n  case type_proto: return true;\n    OF_PP_FOR_EACH_TUPLE(FLOATING_CASE,\n                         FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ)\n    default: return false;\n  }\n#undef FLOATING_CASE\n}\nbool IsHalfDataType(DataType data_type) {\n  switch (data_type) {\n#define HALF_CASE(type_cpp, type_proto) \\\n  case type_proto: return true;\n    OF_PP_FOR_EACH_TUPLE(HALF_CASE, FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ)\n    default: return false;\n  }\n#undef HALF_CASE\n}\nbool IsComplexDataType(DataType data_type) {\n  switch (data_type) {\n#define COMPLEX_CASE(type_cpp, type_proto) \\\n  case type_proto: return true;\n    OF_PP_FOR_EACH_TUPLE(COMPLEX_CASE, COMPLEX_DATA_TYPE_SEQ)\n    default: return false;\n  }\n#undef COMPLEX_CASE\n}\nbool IsTriviallyCopyableDataType(DataType data_type) {\n  switch (data_type) {\n#define TRIVIALLY_COPY_CASE(type_cpp, type_proto) \\\n  case type_proto: return true;\n    OF_PP_FOR_EACH_TUPLE(TRIVIALLY_COPY_CASE, TRIVIALLY_COPY_DATA_TYPE_SEQ INT16_DATA_TYPE_SEQ)\n    default: return false;\n  }\n#undef TRIVIALLY_COPY_CASE\n}\nbool IsIndexDataType(DataType data_type) {\n  switch (data_type) {\n#define INDEX_CASE(type_cpp, type_proto) \\\n  case type_proto: return true;\n    OF_PP_FOR_EACH_TUPLE(INDEX_CASE, INDEX_DATA_TYPE_SEQ)\n    default: return false;\n  }\n#undef INDEX_CASE\n}\nbool IsSupportRequireGradDataType(DataType data_type) {\n  switch (data_type) {\n#define REQUIRE_GRAD_CASE(type_cpp, type_proto) \\\n  case type_proto: return true;\n    OF_PP_FOR_EACH_TUPLE(\n        REQUIRE_GRAD_CASE,\n        FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ COMPLEX_DATA_TYPE_SEQ)\n    default: return false;\n  }\n#undef REQUIRE_GRAD_CASE\n}\nbool NotSupportBoxingDataType(DataType data_type) {\n  switch (data_type) {\n#define NO_BOXING_CASE(type_cpp, type_proto) \\\n  case type_proto: return true;\n    OF_PP_FOR_EACH_TUPLE(NO_BOXING_CASE, NO_BOXING_DATA_TYPE_SEQ)\n    default: return false;\n  }\n#undef NO_BOXING_CASE\n}\n\nsize_t GetSizeOfDataType(DataType data_type) {\n  switch (data_type) {\n    // 8-bit\n    case kChar: 
return 1;\n    case kInt8: return 1;\n    case kUInt8: return 1;\n    case kBool: return 1;\n\n    // 16-bit\n    case kInt16: return 2;\n    case kUInt16: return 2;\n    case kFloat16: return 2;\n    case kBFloat16: return 2;\n\n    // 32-bit\n    case kInt32: return 4;\n    case kUInt32: return 4;\n    case kFloat: return 4;\n    case kComplex32: return 4;\n\n    // 64-bit\n    case kInt64: return 8;\n    case kUInt64: return 8;\n    case kDouble: return 8;\n    case kComplex64: return 8;\n\n    // 128-bit\n    case kInt128: return 16;\n    case kUInt128: return 16;\n    case kComplex128: return 16;\n\n    // non pod\n    case kOFRecord: return sizeof(OFRecord);\n    case kTensorBuffer: return sizeof(TensorBuffer);\n    default: LOG(FATAL) << \"invalid data_type: \" << DataType_Name(data_type);\n  }\n}\n\nnamespace {\n\nvoid CheckDataType() {\n  static_assert(sizeof(int8_t) == sizeof(char), \"sizeof(int8_t) != sizeof(char)\");\n  static_assert(sizeof(int16_t) == sizeof(short), \"sizeof(int16_t) != sizeof(short)\");\n  static_assert(sizeof(int32_t) == sizeof(int), \"sizeof(int32_t) != sizeof(int)\");\n  static_assert(sizeof(int64_t) == sizeof(long long), \"sizeof(int64_t) != sizeof(long long)\");\n\n#if defined(WITH_CUDA)\n\n#define CHECK_DEVICE_FP16(get_val)                              \\\n  do {                                                          \\\n    float16 host_fp16 = get_val<float16>();                     \\\n    half device_fp16 = get_val<half>();                         \\\n    CHECK_EQ(*(uint16_t*)&host_fp16, *(uint16_t*)&device_fp16); \\\n  } while (0)\n\n  CHECK_DEVICE_FP16(GetZeroVal);\n  CHECK_DEVICE_FP16(GetOneVal);\n  CHECK_DEVICE_FP16(GetMaxVal);\n  CHECK_DEVICE_FP16(GetMinVal);\n#undef CHECK_DEVICE_FP16\n\n#endif\n\n#define CHECK_MAX_VAL(T, limit_value) CHECK_EQ(GetMaxVal<T>(), std::numeric_limits<T>::max());\n  OF_PP_FOR_EACH_TUPLE(CHECK_MAX_VAL, MAX_VAL_SEQ);\n#undef CHECK_MAX_VAL\n\n#define CHECK_MIN_VAL(T, limit_value) CHECK_EQ(GetMinVal<T>(), std::numeric_limits<T>::lowest());\n  OF_PP_FOR_EACH_TUPLE(CHECK_MIN_VAL, MIN_VAL_SEQ);\n#undef CHECK_MIN_VAL\n}\n\nCOMMAND(CheckDataType());\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/data_type.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_DATA_TYPE_H_\n#define ONEFLOW_CORE_COMMON_DATA_TYPE_H_\n\n#include <cfloat>\n#include <type_traits>\n#if defined(WITH_CUDA)\n#include <cuda_fp16.h>\n#include <cuda.h>\n#include <cuComplex.h>\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\n#endif\n#include \"oneflow/core/common/bfloat16.h\"\n#include \"oneflow/core/common/bfloat16_math.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/data_type_seq.h\"\n#include \"oneflow/core/record/record.pb.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include <half.hpp>\n\nnamespace std {\n\n// Extend numeric_limits<half> for the C++ standard library.\n#ifdef WITH_CUDA\n\ntemplate<>\nstruct numeric_limits<half> {\n  static constexpr int digits = std::numeric_limits<half_float::half>::digits;\n\n  static constexpr half_float::half lowest() {\n    return std::numeric_limits<half_float::half>::lowest();\n  }\n\n  static constexpr half_float::half max() { return std::numeric_limits<half_float::half>::max(); }\n};\n\n#endif  // WITH_CUDA\n\n}  // namespace std\n\nnamespace oneflow {\n\nnamespace detail {\n\ntemplate<typename>\nstruct IsFloat16Helper : std::false_type {};\n\ntemplate<typename>\nstruct IsFloatingHelper : std::false_type {};\n\ntemplate<typename>\nstruct IsIntegralHelper : std::false_type {};\n\ntemplate<typename>\nstruct IsUnsignedIntegralHelper : std::false_type {};\n\n#ifdef WITH_CUDA\ntemplate<typename>\nstruct IsCudaComplexHelper : std::false_type {};\n#endif  // WITH_CUDA\n\n}  // namespace detail\n\nusing float16 = half_float::half;\n\n#define DEFINE_SPEC(Trait, Type, Value) \\\n  template<>                            \\\n  struct Trait<Type> : std::integral_constant<bool, Value> {};\n\n// Type Trait: IsFloat16\n\nDEFINE_SPEC(detail::IsFloat16Helper, float16, true)\n#ifdef WITH_CUDA\nDEFINE_SPEC(detail::IsFloat16Helper, half, true)\n#endif  // WITH_CUDA\n\ntemplate<typename T>\nstruct IsFloat16\n    : std::integral_constant<bool,\n                             (detail::IsFloat16Helper<typename std::remove_cv<T>::type>::value)> {};\n\n// Type Trait: IsCudaComplex\n#ifdef WITH_CUDA\nDEFINE_SPEC(detail::IsCudaComplexHelper, cuComplex, true)\nDEFINE_SPEC(detail::IsCudaComplexHelper, cuDoubleComplex, true)\n\ntemplate<typename T>\nstruct IsCudaComplex\n    : std::integral_constant<\n          bool, (detail::IsCudaComplexHelper<typename std::remove_cv<T>::type>::value)> {};\n#endif  // WITH_CUDA\n\n// Type Trait: IsFloating\n\n#define SPECIALIZE_TRUE_FLOATING(type_cpp, type_proto) \\\n  DEFINE_SPEC(detail::IsFloatingHelper, type_cpp, true)\nOF_PP_FOR_EACH_TUPLE(SPECIALIZE_TRUE_FLOATING, FLOATING_DATA_TYPE_SEQ);\n#undef SPECIALIZE_TRUE_FLOATING\nDEFINE_SPEC(detail::IsFloatingHelper, float16, true)\n#ifdef WITH_CUDA\nDEFINE_SPEC(detail::IsFloatingHelper, half, true)\n#endif  // 
WITH_CUDA\n\ntemplate<typename T>\nstruct IsFloating\n    : std::integral_constant<bool,\n                             (detail::IsFloatingHelper<typename std::remove_cv<T>::type>::value)> {\n};\n\n// Type Trait: IsIntegral\n\n#define SPECIALIZE_TRUE_INTEGRAL(type_cpp, type_proto) \\\n  DEFINE_SPEC(detail::IsIntegralHelper, type_cpp, true)\nOF_PP_FOR_EACH_TUPLE(SPECIALIZE_TRUE_INTEGRAL, INT_DATA_TYPE_SEQ);\n#undef SPECIALIZE_TRUE_INTEGRAL\n\ntemplate<typename T>\nstruct IsIntegral\n    : std::integral_constant<bool,\n                             (detail::IsIntegralHelper<typename std::remove_cv<T>::type>::value)> {\n};\n\n// Type Trait: IsUnsignedIntegral\n\n#define SPECIALIZE_TRUE_INTEGRAL(type_cpp, type_proto) \\\n  DEFINE_SPEC(detail::IsUnsignedIntegralHelper, type_cpp, true)\nOF_PP_FOR_EACH_TUPLE(SPECIALIZE_TRUE_INTEGRAL, UNSIGNED_INT_DATA_TYPE_SEQ);\n#undef SPECIALIZE_TRUE_INTEGRAL\n\ntemplate<typename T>\nstruct IsUnsignedIntegral\n    : std::integral_constant<\n          bool, (detail::IsUnsignedIntegralHelper<typename std::remove_cv<T>::type>::value)> {};\n\n#undef DEFINE_SPEC\n\n// Type Trait: GetDataType\n\ntemplate<typename T, typename T2 = void>\nstruct GetDataType;\n\ntemplate<>\nstruct GetDataType<void> : std::integral_constant<DataType, DataType::kChar> {};\n\n#define SPECIALIZE_GET_DATA_TYPE(type_cpp, type_proto)                            \\\n  template<>                                                                      \\\n  struct GetDataType<type_cpp> : std::integral_constant<DataType, type_proto> {}; \\\n  inline type_cpp GetTypeByDataType(std::integral_constant<DataType, type_proto>) { return {}; }\nOF_PP_FOR_EACH_TUPLE(SPECIALIZE_GET_DATA_TYPE,\n                     ALL_DATA_TYPE_SEQ UNSIGNED_INT32_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ\n                         BFLOAT16_DATA_TYPE_SEQ COMPLEX_DATA_TYPE_SEQ UNSIGNED_INT64_DATA_TYPE_SEQ\n                             INT16_DATA_TYPE_SEQ);\n#undef SPECIALIZE_GET_DATA_TYPE\n\ntemplate<typename T>\nstruct GetDataType<T, typename std::enable_if<IsFloat16<T>::value>::type>\n    : std::integral_constant<DataType, DataType::kFloat16> {};\n\n#ifdef WITH_CUDA\ntemplate<>\nstruct GetDataType<cuComplex> : std::integral_constant<DataType, DataType::kComplex64> {};\ntemplate<>\nstruct GetDataType<cuDoubleComplex> : std::integral_constant<DataType, DataType::kComplex128> {};\n#endif  // WITH_CUDA\n\n#if CUDA_VERSION >= 11000\ntemplate<>\nstruct GetDataType<nv_bfloat16> : std::integral_constant<DataType, DataType::kBFloat16> {};\n#endif\n\ntemplate<DataType type>\nusing DataTypeToType = decltype(GetTypeByDataType(std::integral_constant<DataType, type>{}));\n\n#if defined(__CUDACC__)\n#define OF_DEVICE_FUNC __device__ __host__ __forceinline__\n#else\n#define OF_DEVICE_FUNC inline\n#endif\n\n#ifdef WITH_CUDA\ntemplate<typename T, typename std::enable_if<!(IsFloat16<T>::value\n                                               || IsCudaComplex<T>::value)>::type* = nullptr>\nOF_DEVICE_FUNC T GetZeroVal() {\n  return static_cast<T>(0);\n}\n\ntemplate<typename T, typename std::enable_if<!(IsFloat16<T>::value\n                                               || IsCudaComplex<T>::value)>::type* = nullptr>\nOF_DEVICE_FUNC T GetOneVal() {\n  return static_cast<T>(1);\n}\n#else\ntemplate<typename T, typename std::enable_if<!IsFloat16<T>::value>::type* = nullptr>\nOF_DEVICE_FUNC T GetZeroVal() {\n  return static_cast<T>(0);\n}\n\ntemplate<typename T, typename std::enable_if<!IsFloat16<T>::value>::type* = nullptr>\nOF_DEVICE_FUNC T GetOneVal() {\n  
return static_cast<T>(1);\n}\n#endif  // WITH_CUDA\n\ntemplate<typename T, typename std::enable_if<!IsFloat16<T>::value>::type* = nullptr>\nOF_DEVICE_FUNC T GetMinVal();\n\ntemplate<typename T, typename std::enable_if<!IsFloat16<T>::value>::type* = nullptr>\nOF_DEVICE_FUNC T GetMaxVal();\n\n#ifdef __APPLE__\n#define APPLE_MAX_VAL_SEQ OF_PP_MAKE_TUPLE_SEQ(unsigned long, ULONG_MAX)\n#else\n#define APPLE_MAX_VAL_SEQ\n#endif\n\n#define MAX_VAL_SEQ                          \\\n  OF_PP_MAKE_TUPLE_SEQ(int8_t, INT8_MAX)     \\\n  OF_PP_MAKE_TUPLE_SEQ(int16_t, INT16_MAX)   \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, INT32_MAX)   \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, INT64_MAX)   \\\n  OF_PP_MAKE_TUPLE_SEQ(uint8_t, UINT8_MAX)   \\\n  OF_PP_MAKE_TUPLE_SEQ(uint16_t, UINT16_MAX) \\\n  OF_PP_MAKE_TUPLE_SEQ(uint32_t, UINT32_MAX) \\\n  APPLE_MAX_VAL_SEQ                          \\\n  OF_PP_MAKE_TUPLE_SEQ(uint64_t, UINT64_MAX) \\\n  OF_PP_MAKE_TUPLE_SEQ(float, FLT_MAX)       \\\n  OF_PP_MAKE_TUPLE_SEQ(double, DBL_MAX)      \\\n  OF_PP_MAKE_TUPLE_SEQ(bool, true)\n\n#ifdef __APPLE__\n#define APPLE_MIN_VAL_SEQ OF_PP_MAKE_TUPLE_SEQ(unsigned long, 0)\n#else\n#define APPLE_MIN_VAL_SEQ\n#endif\n\n#define MIN_VAL_SEQ                        \\\n  OF_PP_MAKE_TUPLE_SEQ(int8_t, INT8_MIN)   \\\n  OF_PP_MAKE_TUPLE_SEQ(int16_t, INT16_MIN) \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, INT32_MIN) \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, INT64_MIN) \\\n  OF_PP_MAKE_TUPLE_SEQ(uint8_t, 0)         \\\n  OF_PP_MAKE_TUPLE_SEQ(uint16_t, 0)        \\\n  OF_PP_MAKE_TUPLE_SEQ(uint32_t, 0)        \\\n  APPLE_MIN_VAL_SEQ                        \\\n  OF_PP_MAKE_TUPLE_SEQ(uint64_t, 0)        \\\n  OF_PP_MAKE_TUPLE_SEQ(float, -FLT_MAX)    \\\n  OF_PP_MAKE_TUPLE_SEQ(double, -DBL_MAX)   \\\n  OF_PP_MAKE_TUPLE_SEQ(bool, false)\n\n#define SPECIALIZE_MAX_VAL(T, limit_value) \\\n  template<>                               \\\n  OF_DEVICE_FUNC T GetMaxVal<T>() {        \\\n    return limit_value;                    \\\n  }\nOF_PP_FOR_EACH_TUPLE(SPECIALIZE_MAX_VAL, MAX_VAL_SEQ);\n#undef SPECIALIZE_MAX_VAL\n\n#define SPECIALIZE_MIN_VAL(T, limit_value) \\\n  template<>                               \\\n  OF_DEVICE_FUNC T GetMinVal<T>() {        \\\n    return limit_value;                    \\\n  }\nOF_PP_FOR_EACH_TUPLE(SPECIALIZE_MIN_VAL, MIN_VAL_SEQ);\n#undef SPECIALIZE_MIN_VAL\n\ntemplate<typename T>\nconst T* GetZeroPtr() {\n  static const T ret = GetZeroVal<T>();\n  return &ret;\n}\n\ntemplate<typename T>\nconst T* GetOnePtr() {\n  static const T ret = GetOneVal<T>();\n  return &ret;\n}\n\ntemplate<typename T, typename std::enable_if<IsFloat16<T>::value>::type* = nullptr>\nOF_DEVICE_FUNC T GetZeroVal() {\n  uint16_t ret = 0x0;  // Decimal: 0; Binary: 0 00000 0000000000\n  return *(T*)&ret;\n}\n\n#ifdef WITH_CUDA\ntemplate<typename T, typename std::enable_if<std::is_same<T, cuComplex>::value>::type* = nullptr>\nOF_DEVICE_FUNC T GetZeroVal() {\n  return make_cuFloatComplex((float)0.0, (float)0.0);\n}\ntemplate<typename T,\n         typename std::enable_if<std::is_same<T, cuDoubleComplex>::value>::type* = nullptr>\nOF_DEVICE_FUNC T GetZeroVal() {\n  return make_cuDoubleComplex((double)0.0, (double)0.0);\n}\n#endif  // WITH_CUDA\n\ntemplate<typename T, typename std::enable_if<IsFloat16<T>::value>::type* = nullptr>\nOF_DEVICE_FUNC T GetOneVal() {\n  uint16_t ret = 0x3c00;  // Decimal: 15360; Binary: 0 01111 0000000000\n  return *(T*)&ret;\n}\n\n#ifdef WITH_CUDA\ntemplate<typename T, typename std::enable_if<std::is_same<T, cuComplex>::value>::type* = 
nullptr>\nOF_DEVICE_FUNC T GetOneVal() {\n  // complex \"one\" is the multiplicative identity 1 + 0i\n  return make_cuFloatComplex((float)1.0, (float)0.0);\n}\n\ntemplate<typename T,\n         typename std::enable_if<std::is_same<T, cuDoubleComplex>::value>::type* = nullptr>\nOF_DEVICE_FUNC T GetOneVal() {\n  return make_cuDoubleComplex((double)1.0, (double)0.0);\n}\n#endif  // WITH_CUDA\n\ntemplate<typename T, typename std::enable_if<IsFloat16<T>::value>::type* = nullptr>\nOF_DEVICE_FUNC T GetMaxVal() {\n  uint16_t ret = 0x7bff;  // Decimal: 31743; Binary: 0 11110 1111111111\n  return *(T*)&ret;\n}\n\ntemplate<typename T, typename std::enable_if<IsFloat16<T>::value>::type* = nullptr>\nOF_DEVICE_FUNC T GetMinVal() {\n  uint16_t ret = 0xfbff;  // Decimal: 64511; Binary: 1 11110 1111111111\n  return *(T*)&ret;\n}\n\n#if CUDA_VERSION >= 11000\ntemplate<>\nOF_DEVICE_FUNC nv_bfloat16 GetMinVal<nv_bfloat16>() {\n  uint16_t ret = 0xff7f;\n  return *(nv_bfloat16*)&ret;\n}\n#endif  // CUDA_VERSION >= 11000\n\ntemplate<DeviceType, typename T>\nstruct DevDType {\n  typedef T type;\n};\n\n#if defined(WITH_CUDA)\ntemplate<>\nstruct DevDType<DeviceType::kCUDA, float16> {\n  static_assert(sizeof(float16) == sizeof(half), \"sizeof(float16) != sizeof(half)\");\n  typedef half type;\n};\n#if CUDA_VERSION >= 11000\ntemplate<>\nstruct DevDType<DeviceType::kCUDA, bfloat16> {\n  static_assert(sizeof(bfloat16) == sizeof(nv_bfloat16), \"sizeof(bfloat16) != sizeof(nv_bfloat16)\");\n  typedef nv_bfloat16 type;\n};\n#endif  // CUDA_VERSION >= 11000\n#endif  // defined(WITH_CUDA)\n\n// Func\n\nbool IsBoolDataType(DataType data_type);\nbool IsIntegralDataType(DataType data_type);\nbool IsFloatingDataType(DataType data_type);\nbool IsHalfDataType(DataType data_type);\nbool IsSupportRequireGradDataType(DataType data_type);\nbool IsComplexDataType(DataType data_type);\nbool IsTriviallyCopyableDataType(DataType data_type);\nbool IsIndexDataType(DataType data_type);\nbool NotSupportBoxingDataType(DataType data_type);\nsize_t GetSizeOfDataType(DataType data_type);\n\ninline bool operator==(const OptInt64& lhs, const OptInt64& rhs) {\n  return (lhs.has_value() && rhs.has_value() && lhs.value() == rhs.value())\n         || (!lhs.has_value() && !rhs.has_value());\n}\n\ntemplate<typename T>\nvoid CheckDataType(DataType data_type) {\n  LOG_IF(FATAL, (std::is_same<T, void>::value == false && std::is_same<T, char>::value == false\n                 && data_type != DataType::kChar && data_type != GetDataType<T>::value))\n      << data_type << \" \" << GetDataType<T>::value;\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_DATA_TYPE_H_\n"
  },
  {
    "path": "oneflow/core/common/data_type.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nenum DataType {\n  kInvalidDataType = 0;\n  kChar = 1;\n  kFloat = 2;\n  kDouble = 3;\n  kInt8 = 4;\n  kInt32 = 5;\n  kInt64 = 6;\n  kUInt8 = 7;\n  kOFRecord = 8;\n  kFloat16 = 9;\n  kTensorBuffer = 10;\n  kBFloat16 = 11;\n  kBool = 12;\n  kUInt16 = 13;\n  kUInt32 = 14;\n  kUInt64 = 15;\n  kUInt128 = 16;\n  kInt16 = 17;\n  kInt128 = 18;\n  kComplex32 = 19;\n  kComplex64 = 20;\n  kComplex128 = 21;\n}\n\nmessage OptInt64 {\n  optional int64 value = 1 [ default = -1 ];\n}\n"
  },
  {
    "path": "oneflow/core/common/data_type_converter.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_H_\n#define ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_H_\n\n#ifdef WITH_CUDA\n#include <cuda_runtime.h>\n#endif\n#include <cstdint>\n#include <limits>\n#include <type_traits>\n#include \"oneflow/core/common/data_type.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstruct IsFloatingOrHalf {\n  static const bool value = IsFloating<T>::value || IsFloat16<T>::value;\n};\n\ntemplate<typename T>\nstruct IsArithmeticOrHalf {\n  static const bool value = std::is_arithmetic<T>::value || IsFloat16<T>::value;\n};\n\ntemplate<typename From, typename To>\nstruct NeedsClamp {\n  static const bool from_fp = IsFloatingOrHalf<From>::value;\n  static const bool to_fp = IsFloatingOrHalf<To>::value;\n  static const bool from_fp16 = IsFloat16<From>::value;\n  static const bool to_fp16 = IsFloat16<To>::value;\n  static const bool from_unsigned = std::is_unsigned<From>::value;\n  static const bool to_unsigned = std::is_unsigned<To>::value;\n  static const bool value =\n      // to smaller type of same kind (fp, int)\n      (from_fp == to_fp && sizeof(To) < sizeof(From)) ||\n      // fp32 has range in excess of (u)int64\n      (from_fp && !to_fp) ||\n      // converting to unsigned requires clamping negatives to zero\n      (!from_unsigned && to_unsigned) ||\n      // zero-extending signed unsigned integers requires more bits\n      (from_unsigned && !to_unsigned && sizeof(To) <= sizeof(From)) ||\n      // float16\n      (to_fp16 && sizeof(To) <= sizeof(From));\n};\n\ntemplate<typename To>\nstruct NeedsClamp<bool, To> {\n  static const bool value = false;\n};\n\ntemplate<typename T, typename U, typename Enabled = void>\nstruct ClampHelper {};\n\n// floating-point and signed integer -> floating-point and signed integer\ntemplate<typename T, typename U>\nstruct ClampHelper<\n    T, U,\n    std::enable_if_t<\n        NeedsClamp<U, T>::value && std::is_signed<U>::value && std::is_signed<T>::value, void>> {\n  OF_DEVICE_FUNC static const T Call(U value) {\n    return value <= GetMinVal<T>()   ? GetMinVal<T>()\n           : value >= GetMaxVal<T>() ? GetMaxVal<T>()\n                                     : static_cast<T>(value);\n  }\n};\n\n// floating-point -> unsigned types\ntemplate<typename T, typename U>\nstruct ClampHelper<T, U,\n                   std::enable_if_t<NeedsClamp<U, T>::value && std::is_signed<U>::value\n                                        && IsFloatingOrHalf<U>::value && std::is_unsigned<T>::value,\n                                    void>> {\n  OF_DEVICE_FUNC static const T Call(U value) {\n    return value <= GetMinVal<T>()   ? GetMinVal<T>()\n           : value >= GetMaxVal<T>() ? 
GetMaxVal<T>()\n                                     : static_cast<T>(value);\n  }\n};\n\n// signed integer types -> unsigned types\ntemplate<typename T, typename U>\nstruct ClampHelper<T, U,\n                   std::enable_if_t<NeedsClamp<U, T>::value && std::is_signed<U>::value\n                                        && std::is_integral<U>::value && std::is_unsigned<T>::value,\n                                    void>> {\n  OF_DEVICE_FUNC static const T Call(U value) {\n    return value <= 0                                                      ? 0\n           : static_cast<std::make_unsigned_t<U>>(value) >= GetMaxVal<T>() ? GetMaxVal<T>()\n                                                                           : static_cast<T>(value);\n  }\n};\n\n// unsigned types -> any types\ntemplate<typename T, typename U>\nstruct ClampHelper<T, U,\n                   std::enable_if_t<NeedsClamp<U, T>::value && std::is_unsigned<U>::value, void>> {\n  OF_DEVICE_FUNC static const T Call(U value) {\n    return value >= GetMaxVal<T>() ? GetMaxVal<T>() : static_cast<T>(value);\n  }\n};\n\n// not clamp\ntemplate<typename T, typename U>\nstruct ClampHelper<T, U, std::enable_if_t<!NeedsClamp<U, T>::value, void>> {\n  OF_DEVICE_FUNC static const T Call(U value) { return value; }\n};\n\nOF_DEVICE_FUNC const int32_t Clamp(uint32_t value) {\n  return value & 0x80000000u ? 0x7fffffff : value;\n}\n\nOF_DEVICE_FUNC const uint32_t Clamp(int32_t value) { return value < 0 ? 0u : value; }\n\nOF_DEVICE_FUNC const int32_t Clamp(int64_t value) {\n  return value < static_cast<int64_t>(GetMinVal<int32_t>())   ? GetMinVal<int32_t>()\n         : value > static_cast<int64_t>(GetMaxVal<int32_t>()) ? GetMaxVal<int32_t>()\n                                                              : static_cast<int32_t>(value);\n}\n\ntemplate<>\nstruct ClampHelper<int32_t, uint64_t> {\n  OF_DEVICE_FUNC static const int32_t Call(uint64_t value) {\n    return value > static_cast<uint64_t>(GetMaxVal<int32_t>()) ? GetMaxVal<int32_t>()\n                                                               : static_cast<int32_t>(value);\n  }\n};\n\ntemplate<>\nstruct ClampHelper<uint32_t, int64_t> {\n  OF_DEVICE_FUNC static const uint32_t Call(int64_t value) {\n    return value < 0                                             ? 0\n           : value > static_cast<int64_t>(GetMaxVal<uint32_t>()) ? GetMaxVal<uint32_t>()\n                                                                 : static_cast<uint32_t>(value);\n  }\n};\n\ntemplate<>\nstruct ClampHelper<uint32_t, uint64_t> {\n  OF_DEVICE_FUNC static const uint32_t Call(uint64_t value) {\n    return value > static_cast<uint64_t>(GetMaxVal<uint32_t>()) ? GetMaxVal<uint32_t>()\n                                                                : static_cast<uint32_t>(value);\n  }\n};\n\ntemplate<typename T>\nstruct ClampHelper<bool, T> {\n  OF_DEVICE_FUNC static const bool Call(T value) { return static_cast<bool>(value); }\n};\n\ntemplate<typename T>\nstruct ClampHelper<float16, T> {\n  inline static const float16 Call(T value) {\n    return static_cast<float16>(ClampHelper<T, float>::Call(value) < GetMinVal<float16>()\n                                    ? GetMinVal<float16>()\n                                : ClampHelper<T, float>::Call(value) > GetMaxVal<float16>()\n                                    ? 
GetMaxVal<float16>()\n                                    : ClampHelper<T, float>::Call(value));\n  }\n};\n\ntemplate<typename T>\nstruct ClampHelper<T, float16> {\n  inline static const T Call(float16 value) {\n    return ClampHelper<T, float>::Call(static_cast<float>(value));\n  }\n};\n\ninline const float16 Clamp(float16 value) { return value; }\n\ntemplate<typename T, typename U>\nOF_DEVICE_FUNC const T Clamp(U value) {\n  return ClampHelper<T, U>::Call(value);\n}\n\nnamespace {\n#ifdef __CUDA_ARCH__\n\ninline __device__ int cuda_round_helper(float f, int) { return __float2int_rn(f); }\n\ninline __device__ unsigned cuda_round_helper(float f, unsigned) { return __float2uint_rn(f); }\n\ninline __device__ long long cuda_round_helper(float f, long long) {\n  return __float2ll_rd(f + 0.5f);\n}\n\ninline __device__ unsigned long long cuda_round_helper(float f, unsigned long long) {\n  return __float2ull_rd(f + 0.5f);\n}\n\ninline __device__ long cuda_round_helper(float f, long) {\n  return sizeof(long) == sizeof(int) ? __float2int_rn(f) : __float2ll_rd(f + 0.5f);\n}\n\ninline __device__ unsigned long cuda_round_helper(float f, unsigned long) {\n  return sizeof(unsigned long) == sizeof(unsigned int) ? __float2uint_rn(f)\n                                                       : __float2ull_rd(f + 0.5f);\n}\n\ninline __device__ int cuda_round_helper(double f, int) { return __double2int_rn(f); }\n\ninline __device__ unsigned cuda_round_helper(double f, unsigned) { return __double2uint_rn(f); }\n\ninline __device__ long long cuda_round_helper(double f, long long) {\n  return __double2ll_rd(f + 0.5f);\n}\n\ninline __device__ unsigned long long cuda_round_helper(double f, unsigned long long) {\n  return __double2ull_rd(f + 0.5f);\n}\n\ninline __device__ long cuda_round_helper(double f, long) {\n  return sizeof(long) == sizeof(int) ? __double2int_rn(f) : __double2ll_rd(f + 0.5f);\n}\n\ninline __device__ unsigned long cuda_round_helper(double f, unsigned long) {\n  return sizeof(unsigned long) == sizeof(unsigned int) ? 
__double2uint_rn(f)\n                                                       : __double2ull_rd(f + 0.5f);\n}\n#endif\n\ntemplate<typename Out, typename In, bool OutIsFp = IsFloatingOrHalf<Out>::value,\n         bool InIsFp = IsFloatingOrHalf<In>::value>\nstruct ConverterBase;\n\ntemplate<typename Out, typename In>\nstruct Converter : ConverterBase<Out, In> {\n  static_assert(IsArithmeticOrHalf<Out>::value && IsArithmeticOrHalf<In>::value,\n                \"Default ConverterBase can only be used with arithmetic types.\");\n};\n\n// Converts between two FP types\ntemplate<typename Out, typename In>\nstruct ConverterBase<Out, In, true, true> {\n  OF_DEVICE_FUNC static const Out Convert(In value) { return value; }\n  OF_DEVICE_FUNC static const Out ConvertNorm(In value) { return value; }\n  OF_DEVICE_FUNC static const Out ConvertSat(In value) { return value; }\n  OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) { return value; }\n};\n\n// Converts integral to FP type\ntemplate<typename Out, typename In>\nstruct ConverterBase<Out, In, true, false> {\n  OF_DEVICE_FUNC static const Out Convert(In value) { return value; }\n  OF_DEVICE_FUNC static const Out ConvertSat(In value) { return value; }\n  OF_DEVICE_FUNC static const Out ConvertNorm(In value) {\n    return value * (Out(1) / (GetMaxVal<In>()));\n  }\n  OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) {\n    return value * (Out(1) / (GetMaxVal<In>()));\n  }\n};\n\n// Converts integral to float16\ntemplate<typename In>\nstruct ConverterBase<float16, In, true, false> {\n  OF_DEVICE_FUNC static const float16 Convert(In value) {\n    auto out = ConverterBase<float, In, true, false>::Convert(value);\n    return static_cast<float16>(out);\n  }\n\n  OF_DEVICE_FUNC static const float16 ConvertSat(In value) {\n    auto out = ConverterBase<float, In, true, false>::ConvertSat(value);\n    return static_cast<float16>(out);\n  }\n\n  OF_DEVICE_FUNC static const float16 ConvertNorm(In value) {\n    auto out = ConverterBase<float, In, true, false>::ConvertNorm(value);\n    return static_cast<float16>(out);\n  }\n\n  OF_DEVICE_FUNC static const float16 ConvertSatNorm(In value) {\n    auto out = ConverterBase<float, In, true, false>::ConvertSatNorm(value);\n    return static_cast<float16>(out);\n  }\n};\n\n// Converts FP to integral type\ntemplate<typename Out, typename In>\nstruct ConverterBase<Out, In, false, true> {\n  OF_DEVICE_FUNC static const Out Convert(In value) {\n#ifdef __CUDA_ARCH__\n    return Clamp<Out>(cuda_round_helper(value, Out()));\n#else\n    return Clamp<Out>(std::round(value));\n#endif\n  }\n\n  OF_DEVICE_FUNC static const Out ConvertSat(In value) {\n#ifdef __CUDA_ARCH__\n    return Clamp<Out>(cuda_round_helper(value, Out()));\n#else\n    return Clamp<Out>(std::round(value));\n#endif\n  }\n\n  OF_DEVICE_FUNC static const Out ConvertNorm(In value) {\n#ifdef __CUDA_ARCH__\n    return Clamp<Out>(cuda_round_helper(value * GetMaxVal<Out>(), Out()));\n#else\n    return std::round(value * GetMaxVal<Out>());\n#endif\n  }\n\n  OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) {\n#ifdef __CUDA_ARCH__\n    return std::is_signed<Out>::value\n               ? 
Clamp<Out>(cuda_round_helper(value * GetMaxVal<Out>(), Out()))\n               : cuda_round_helper(GetMaxVal<Out>() * __saturatef(value), Out());\n#else\n    return Clamp<Out>(std::round(value * GetMaxVal<Out>()));\n#endif\n  }\n};\n\n// Converts signed to signed, unsigned to unsigned or unsigned to signed\ntemplate<typename Out, typename In, bool IsOutSigned = std::is_signed<Out>::value,\n         bool IsInSigned = std::is_signed<In>::value>\nstruct ConvertIntInt {\n  OF_DEVICE_FUNC static const Out Convert(In value) { return value; }\n  OF_DEVICE_FUNC static const Out ConvertNorm(In value) {\n    return Converter<Out, float>::Convert(value * (1.0f * GetMaxVal<Out>() / GetMaxVal<In>()));\n  }\n  OF_DEVICE_FUNC static const Out ConvertSat(In value) { return Clamp<Out>(value); }\n  OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) { return ConvertNorm(value); }\n};\n\n// Converts signed to unsigned integer\ntemplate<typename Out, typename In>\nstruct ConvertIntInt<Out, In, false, true> {\n  OF_DEVICE_FUNC static const Out Convert(In value) { return value; }\n  OF_DEVICE_FUNC static const Out ConvertNorm(In value) {\n    return Converter<Out, float>::Convert(value * (1.0f * GetMaxVal<Out>() / GetMaxVal<In>()));\n  }\n  OF_DEVICE_FUNC static const Out ConvertSat(In value) { return Clamp<Out>(value); }\n  OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) {\n#ifdef __CUDA_ARCH__\n    return cuda_round_helper(__saturatef(value * (1.0f / GetMaxVal<In>())) * GetMaxVal<Out>(),\n                             Out());\n#else\n    return value < 0 ? 0 : ConvertNorm(value);\n#endif\n  }\n};\n\n// Converts between integral types\ntemplate<typename Out, typename In>\nstruct ConverterBase<Out, In, false, false> : ConvertIntInt<Out, In> {\n  static_assert(IsArithmeticOrHalf<Out>::value && IsArithmeticOrHalf<In>::value,\n                \"Default ConverterBase can only be used with arithmetic types.\");\n};\n\n// Pass-through conversion\ntemplate<typename T>\nstruct Converter<T, T> {\n  static OF_DEVICE_FUNC const T Convert(T value) { return value; }\n  static OF_DEVICE_FUNC const T ConvertSat(T value) { return value; }\n  static OF_DEVICE_FUNC const T ConvertNorm(T value) { return value; }\n  static OF_DEVICE_FUNC const T ConvertSatNorm(T value) { return value; }\n};\n\ntemplate<typename raw_out, typename raw_in>\nusing converter_t =\n    Converter<std::remove_cv_t<raw_out>, std::remove_cv_t<std::remove_reference_t<raw_in>>>;\n\n}  // namespace\n\ntemplate<typename Out, typename In>\nOF_DEVICE_FUNC const Out Convert(In value) {\n  return converter_t<Out, In>::Convert(value);\n}\n\ntemplate<typename Out, typename In>\nOF_DEVICE_FUNC const Out ConvertNorm(In value) {\n  return converter_t<Out, In>::ConvertNorm(value);\n}\n\ntemplate<typename Out, typename In>\nOF_DEVICE_FUNC const Out ConvertSat(In value) {\n  return converter_t<Out, In>::ConvertSat(value);\n}\n\ntemplate<typename Out, typename In>\nOF_DEVICE_FUNC const Out ConvertSatNorm(In value) {\n  return converter_t<Out, In>::ConvertSatNorm(value);\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_H_\n"
  },
  {
    "path": "oneflow/core/common/data_type_converter_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"util.h\"\n#include \"oneflow/core/common/data_type_converter.h\"\n#include \"oneflow/core/common/data_type_converter_test_static.h\"\n#ifdef __CUDA_ARCH__\n#include <cuda_runtime.h>\n#else\n#include <cmath>\n#endif\n\nnamespace oneflow {\n\nnamespace {\n\n// cpp17 std::clamp possible implementation\ntemplate<class T>\nconstexpr const T& clamp(const T& v, const T& lo, const T& hi) {\n  return (v < lo) ? lo : (hi < v) ? hi : v;\n}\n\n}  // namespace\n\nTEST(ClampTest, Clamp) {\n  ASSERT_TRUE(Clamp<uint8_t>(0) == 0);\n  ASSERT_TRUE(Clamp<uint8_t>(255) == 255);\n  ASSERT_TRUE(Clamp<uint8_t>(100) == 100);\n  ASSERT_TRUE(Clamp<uint8_t>(100.3) == 100);\n  ASSERT_TRUE(Clamp<uint8_t>(256) == 255);\n  ASSERT_TRUE(Clamp<uint8_t>(-4) == 0);\n  ASSERT_TRUE(Clamp<uint8_t>(-4.0f) == 0);\n  ASSERT_TRUE(Clamp<uint8_t>(1e+20f) == 255);\n  ASSERT_TRUE(Clamp<uint8_t>(-1e+20f) == 0);\n  ASSERT_TRUE(Clamp<uint8_t>(1e+200) == 255);\n  ASSERT_TRUE(Clamp<uint8_t>(-1e+200) == 0);\n\n  ASSERT_TRUE(Clamp<int8_t>(-4) == -4);\n  ASSERT_TRUE(Clamp<int8_t>(-4.2) == -4);\n  ASSERT_TRUE(Clamp<int8_t>(4.2) == 4);\n  ASSERT_TRUE(Clamp<int8_t>(127) == 127);\n  ASSERT_TRUE(Clamp<int8_t>(128) == 127);\n  ASSERT_TRUE(Clamp<int8_t>(256) == 127);\n  ASSERT_TRUE(Clamp<int8_t>(-128) == -128);\n  ASSERT_TRUE(Clamp<int8_t>(-256) == -128);\n  ASSERT_TRUE(Clamp<int8_t>(1e+20f) == 127);\n  ASSERT_TRUE(Clamp<int8_t>(-1e+20f) == -128);\n  ASSERT_TRUE(Clamp<int8_t>(1e+200) == 127);\n  ASSERT_TRUE(Clamp<int8_t>(-1e+200) == -128);\n\n  ASSERT_TRUE(Clamp<uint16_t>(0) == 0);\n  ASSERT_TRUE(Clamp<uint16_t>(0xffff) == 0xffff);\n  ASSERT_TRUE(Clamp<uint16_t>(100) == 100);\n  ASSERT_TRUE(Clamp<uint16_t>(100.3) == 100);\n  ASSERT_TRUE(Clamp<uint16_t>(0x10000) == 0xffff);\n  ASSERT_TRUE(Clamp<uint16_t>(-4) == 0);\n  ASSERT_TRUE(Clamp<uint16_t>(-4.0f) == 0);\n  ASSERT_TRUE(Clamp<uint16_t>(1e+20f) == 0xffff);\n  ASSERT_TRUE(Clamp<uint16_t>(-1e+20f) == 0);\n  ASSERT_TRUE(Clamp<uint16_t>(1e+200) == 0xffff);\n  ASSERT_TRUE(Clamp<uint16_t>(-1e+200) == 0);\n\n  ASSERT_TRUE(Clamp<int16_t>(-4) == -4);\n  ASSERT_TRUE(Clamp<int16_t>(-4.2) == -4);\n  ASSERT_TRUE(Clamp<int16_t>(4.2) == 4);\n  ASSERT_TRUE(Clamp<int16_t>(0x7fff) == 0x7fff);\n  ASSERT_TRUE(Clamp<int16_t>(0x8000) == 0x7fff);\n  ASSERT_TRUE(Clamp<int16_t>(0x10000) == 0x7fff);\n  ASSERT_TRUE(Clamp<int16_t>(-0x8000) == -0x8000);\n  ASSERT_TRUE(Clamp<int16_t>(-0x10000) == -0x8000);\n  ASSERT_TRUE(Clamp<int16_t>(1e+20f) == 0x7fff);\n  ASSERT_TRUE(Clamp<int16_t>(-1e+20f) == -0x8000);\n  ASSERT_TRUE(Clamp<int16_t>(1e+200) == 0x7fff);\n  ASSERT_TRUE(Clamp<int16_t>(-1e+200) == -0x8000);\n\n  ASSERT_TRUE(Clamp<uint32_t>(0) == 0);\n  ASSERT_TRUE(Clamp<uint32_t>(0xffffffffLL) == 0xffffffffLL);\n  ASSERT_TRUE(Clamp<uint32_t>(100) == 100);\n  ASSERT_TRUE(Clamp<uint32_t>(100.3) == 100);\n  ASSERT_TRUE(Clamp<uint32_t>(0x100000000LL) == 0xffffffffLL);\n  
ASSERT_TRUE(Clamp<uint32_t>(-4) == 0);\n  ASSERT_TRUE(Clamp<uint32_t>(-4.0f) == 0);\n  ASSERT_TRUE(Clamp<uint32_t>(1e+20f) == 0xffffffffu);\n  ASSERT_TRUE(Clamp<uint32_t>(-1.0e+20f) == 0);\n  ASSERT_TRUE(Clamp<uint32_t>(1e+200) == 0xffffffffu);\n  ASSERT_TRUE(Clamp<uint32_t>(-1.0e+200) == 0);\n\n  ASSERT_TRUE(Clamp<int32_t>(-4) == -4);\n  ASSERT_TRUE(Clamp<int32_t>(-4LL) == -4);\n  ASSERT_TRUE(Clamp<int32_t>(-4.2) == -4);\n  ASSERT_TRUE(Clamp<int32_t>(4.2) == 4);\n  ASSERT_TRUE(Clamp<int32_t>(0x7fffffff) == 0x7fffffff);\n  ASSERT_TRUE(Clamp<int32_t>(0x80000000L) == 0x7fffffff);\n  ASSERT_TRUE(Clamp<int32_t>(0x100000000L) == 0x7fffffff);\n  ASSERT_TRUE(Clamp<int32_t>(-0x80000000LL) == -0x7fffffff - 1);\n  ASSERT_TRUE(Clamp<int32_t>(-0x100000000LL) == -0x7fffffff - 1);\n  ASSERT_TRUE(Clamp<int32_t>(1.0e+20f) == 0x7fffffff);\n  ASSERT_TRUE(Clamp<int32_t>(-1.0e+20f) == -0x80000000L);\n  ASSERT_TRUE(Clamp<int32_t>(1.0e+200) == 0x7fffffff);\n  ASSERT_TRUE(Clamp<int32_t>(-1.0e+200) == -0x80000000L);\n\n  ASSERT_TRUE(Clamp<int64_t>(1.0e+200) == 0x7fffffffffffffffLL);\n  ASSERT_TRUE(Clamp<int64_t>(-1.0e+200) == -0x7fffffffffffffffLL - 1);\n  ASSERT_TRUE(Clamp<uint64_t>(1.0e+200) == 0xffffffffffffffffULL);\n  ASSERT_TRUE(Clamp<uint64_t>(-1.0e+200) == 0);\n}\n\nTEST(ConvertSat, float2int) {\n  FOR_RANGE(int32_t, exp, -10, 100) {\n    FOR_RANGE(float, sig, -256, 257) {\n      float f = ldexpf(sig, exp);\n      float integral;\n      float fract = modff(f, &integral);\n      if (fract == 0.5f || fract == -0.5f) continue;\n      double rounded = roundf(f);\n      int64_t clamped = clamp<double>(rounded, -128, 127);\n      ASSERT_EQ(ConvertSat<int8_t>(f), clamped) << \" with f = \" << f;\n      clamped = clamp<double>(rounded, 0, 255);\n      ASSERT_EQ(ConvertSat<uint8_t>(f), clamped) << \" with f = \" << f;\n      clamped = clamp<double>(rounded, -0x8000, 0x7fff);\n      ASSERT_EQ(ConvertSat<int16_t>(f), clamped) << \" with f = \" << f;\n      clamped = clamp<double>(rounded, 0, 0xffff);\n      ASSERT_EQ(ConvertSat<uint16_t>(f), clamped) << \" with f = \" << f;\n      clamped = clamp<double>(rounded, int32_t(~0x7fffffff), 0x7fffffff);\n      ASSERT_EQ(ConvertSat<int32_t>(f), clamped) << \" with f = \" << f;\n      clamped = clamp<double>(rounded, 0, 0xffffffffu);\n      ASSERT_EQ(ConvertSat<uint32_t>(f), clamped) << \" with f = \" << f;\n    }\n  }\n}\n\nTEST(ConvertNorm, int2int) {\n  EXPECT_EQ((ConvertNorm<uint8_t, uint8_t>(0)), 0);\n  EXPECT_EQ((ConvertNorm<uint8_t, int8_t>(127)), 255);\n}\n\nTEST(ConvertNorm, float2int) {\n  EXPECT_EQ(ConvertNorm<uint8_t>(0.0f), 0);\n  EXPECT_EQ(ConvertNorm<uint8_t>(0.499f), 127);\n  EXPECT_EQ(ConvertNorm<uint8_t>(1.0f), 255);\n  EXPECT_EQ(ConvertNorm<int8_t>(1.0f), 127);\n  EXPECT_EQ(ConvertNorm<int8_t>(0.499f), 63);\n  EXPECT_EQ(ConvertNorm<int8_t>(-1.0f), -127);\n\n  EXPECT_EQ(ConvertNorm<uint16_t>(0.0f), 0);\n  EXPECT_EQ(ConvertNorm<uint16_t>(1.0f), 0xffff);\n  EXPECT_EQ(ConvertNorm<int16_t>(1.0f), 0x7fff);\n  EXPECT_EQ(ConvertNorm<int16_t>(-1.0f), -0x7fff);\n}\n\nTEST(ConvertSatNorm, float2int) {\n  EXPECT_EQ(ConvertSatNorm<uint8_t>(2.0f), 255);\n  EXPECT_EQ(ConvertSatNorm<uint8_t>(0.499f), 127);\n  EXPECT_EQ(ConvertSatNorm<uint8_t>(-2.0f), 0);\n  EXPECT_EQ(ConvertSatNorm<int8_t>(2.0f), 127);\n  EXPECT_EQ(ConvertSatNorm<int8_t>(0.499f), 63);\n  EXPECT_EQ(ConvertSatNorm<int8_t>(-2.0f), -128);\n  EXPECT_EQ(ConvertSatNorm<uint8_t>(0.4f / 255), 0);\n  EXPECT_EQ(ConvertSatNorm<uint8_t>(0.6f / 255), 1);\n\n  EXPECT_EQ(ConvertSatNorm<int16_t>(2.0f), 0x7fff);\n  
EXPECT_EQ(ConvertSatNorm<int16_t>(-2.0f), -0x8000);\n}\n\nTEST(ConvertNorm, int2float) {\n  EXPECT_EQ((ConvertNorm<float, uint8_t>(255)), 1.0f);\n  EXPECT_NEAR((ConvertNorm<float, uint8_t>(127)), 1.0f * 127 / 255, 1e-7f);\n  EXPECT_EQ((ConvertNorm<float, int8_t>(127)), 1.0f);\n  EXPECT_NEAR((ConvertNorm<float, int8_t>(64)), 1.0f * 64 / 127, 1e-7f);\n}\n\nTEST(Clamp1, int64_2_float16) {\n  int64_t big_num = 0x0FFFFFFFFFFFFFFF;\n  EXPECT_EQ(static_cast<float>(Clamp<float16>(big_num)), Clamp<float16>(Clamp<float>(big_num)));\n  EXPECT_EQ(65504.0f, Clamp<float16>(big_num));\n  EXPECT_EQ(-65504.0f, Clamp<float16>(-big_num));\n}\n\nTEST(Clamp2, float16_2_int64) {\n  float16 fp16 = static_cast<float16>(65504.0f);\n  EXPECT_EQ(65504, Clamp<int64_t>(fp16));\n  EXPECT_EQ(-65504, Clamp<int64_t>(-fp16));\n}\n\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/core/common/data_type_converter_test_static.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_TEST_STATIC_H_\n#define ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_TEST_STATIC_H_\n\n#include \"oneflow/core/common/data_type_converter.h\"\n\nnamespace oneflow {\n\nnamespace {\n// fp to int\nstatic_assert(NeedsClamp<float, int8_t>::value, \"Float range exceeds all ints up to 64b\");\nstatic_assert(NeedsClamp<float, uint8_t>::value, \"Float range exceeds all ints up to 64b\");\nstatic_assert(NeedsClamp<float, int16_t>::value, \"Float range exceeds all ints up to 64b\");\nstatic_assert(NeedsClamp<float, uint16_t>::value, \"Float range exceeds all ints up to 64b\");\nstatic_assert(NeedsClamp<float, int32_t>::value, \"Float range exceeds all ints up to 64b\");\nstatic_assert(NeedsClamp<float, uint32_t>::value, \"Float range exceeds all ints up to 64b\");\nstatic_assert(NeedsClamp<float, int64_t>::value, \"Float range exceeds all ints up to 64b\");\nstatic_assert(NeedsClamp<float, uint64_t>::value, \"Float range exceeds all ints up to 64b\");\n\n// same size, different signedness\nstatic_assert(NeedsClamp<int8_t, uint8_t>::value, \"Signed <-> unsigned requires clamp\");\nstatic_assert(NeedsClamp<uint8_t, int8_t>::value, \"Signed <-> unsigned requires clamp\");\nstatic_assert(NeedsClamp<int16_t, uint16_t>::value, \"Signed <-> unsigned requires clamp\");\nstatic_assert(NeedsClamp<uint16_t, int16_t>::value, \"Signed <-> unsigned requires clamp\");\nstatic_assert(NeedsClamp<int32_t, uint32_t>::value, \"Signed <-> unsigned requires clamp\");\nstatic_assert(NeedsClamp<uint32_t, int32_t>::value, \"Signed <-> unsigned requires clamp\");\nstatic_assert(NeedsClamp<int64_t, uint64_t>::value, \"Signed <-> unsigned requires clamp\");\nstatic_assert(NeedsClamp<uint64_t, int64_t>::value, \"Signed <-> unsigned requires clamp\");\n\n// larger, but unsigned\nstatic_assert(NeedsClamp<int8_t, uint16_t>::value, \"Need to clamp negatives to 0\");\nstatic_assert(NeedsClamp<int8_t, uint32_t>::value, \"Need to clamp negatives to 0\");\nstatic_assert(NeedsClamp<int8_t, uint64_t>::value, \"Need to clamp negatives to 0\");\nstatic_assert(NeedsClamp<int16_t, uint32_t>::value, \"Need to clamp negatives to 0\");\nstatic_assert(NeedsClamp<int16_t, uint64_t>::value, \"Need to clamp negatives to 0\");\nstatic_assert(NeedsClamp<int32_t, uint64_t>::value, \"Need to clamp negatives to 0\");\n\nstatic_assert(!NeedsClamp<int8_t, int8_t>::value, \"Clamping not required\");\nstatic_assert(!NeedsClamp<int8_t, int16_t>::value, \"Clamping not required\");\nstatic_assert(!NeedsClamp<uint8_t, int16_t>::value, \"Clamping not required\");\nstatic_assert(!NeedsClamp<uint8_t, uint16_t>::value, \"Clamping not required\");\nstatic_assert(!NeedsClamp<float, float>::value, \"Clamping not required\");\nstatic_assert(!NeedsClamp<float, double>::value, \"Clamping not required\");\n\n}  // namespace\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_TEST_STATIC_H_\n"
  },
  {
    "path": "oneflow/core/common/data_type_seq.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_DATA_TYPE_SEQ_H_\n#define ONEFLOW_CORE_COMMON_DATA_TYPE_SEQ_H_\n\n#include <complex>\n#include \"oneflow/core/common/preprocessor.h\"\n\n// SEQ\n\n#define BOOL_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool)\n\n#define FLOATING_DATA_TYPE_SEQ                  \\\n  OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) \\\n  OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)\n\n#define SIGNED_INT_DATA_TYPE_SEQ                  \\\n  OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8)   \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)\n\n#define INT16_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int16_t, DataType::kInt16)\n#define UNSIGNED_INT_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8)\n#define UNSIGNED_INT32_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32)\n#define UNSIGNED_INT64_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64)\n\n#define INT_DATA_TYPE_SEQ SIGNED_INT_DATA_TYPE_SEQ\n\n#define CHAR_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar)\n\n#define COMPLEX_DATA_TYPE_SEQ                                     \\\n  OF_PP_MAKE_TUPLE_SEQ(std::complex<float>, DataType::kComplex64) \\\n  OF_PP_MAKE_TUPLE_SEQ(std::complex<double>, DataType::kComplex128)\n\n#define ARITHMETIC_DATA_TYPE_SEQ \\\n  FLOATING_DATA_TYPE_SEQ         \\\n  INT_DATA_TYPE_SEQ\n\n#define POD_DATA_TYPE_SEQ \\\n  ARITHMETIC_DATA_TYPE_SEQ CHAR_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ\n#define POD_AND_HALF_DATA_TYPE_SEQ POD_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ\n#define TRIVIALLY_COPY_DATA_TYPE_SEQ POD_AND_HALF_DATA_TYPE_SEQ COMPLEX_DATA_TYPE_SEQ\n#define PB_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(OFRecord, DataType::kOFRecord)\n#define ALL_DATA_TYPE_SEQ POD_DATA_TYPE_SEQ PB_DATA_TYPE_SEQ\n\n#define INDEX_DATA_TYPE_SEQ                       \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)\n\n#define FLOAT16_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16)\n\n#define BFLOAT16_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bfloat16, DataType::kBFloat16)\n\n#if defined(WITH_CUDA)\n#define HALF_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)\n#if CUDA_VERSION >= 11000\n#define NV_BFLOAT16_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16)\n#endif  // CUDA_VERSION >= 11000\n#endif  // defined(WITH_CUDA)\n\n#define IMAGE_DATA_TYPE_SEQ                       \\\n  OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8) \\\n  OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)\n\n#define NO_BOXING_DATA_TYPE_SEQ                       \\\n  OF_PP_MAKE_TUPLE_SEQ(OFRecord, DataType::kOFRecord) \\\n  OF_PP_MAKE_TUPLE_SEQ(TensorBuffer, DataType::kTensorBuffer)\n\n#endif  // ONEFLOW_CORE_COMMON_DATA_TYPE_SEQ_H_\n"
  },
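  {
    "path": "oneflow/core/common/data_type_seq_example.cpp",
    "content": "// Editor's usage sketch (hypothetical file, not part of the OneFlow tree): the\n// (cpp_type, proto_enum) tuple seqs above are typically expanded with OF_PP_FOR_EACH_TUPLE\n// from oneflow/core/common/preprocessor.h to stamp out one case (or overload) per data type.\n#include <cstddef>\n#include \"oneflow/core/common/data_type_seq.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n\nnamespace oneflow {\n\n// Each seq entry is a (type_cpp, type_proto) tuple, so MAKE_CASE receives both elements.\ninline size_t ExampleSizeOfDataType(DataType dtype) {\n  switch (dtype) {\n#define MAKE_CASE(type_cpp, type_proto) \\\n  case type_proto: return sizeof(type_cpp);\n    OF_PP_FOR_EACH_TUPLE(MAKE_CASE, POD_DATA_TYPE_SEQ)\n#undef MAKE_CASE\n    default: return 0;\n  }\n}\n\n}  // namespace oneflow\n"
  },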
  {
    "path": "oneflow/core/common/decorator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_DECORATOR_H_\n#define ONEFLOW_CORE_COMMON_DECORATOR_H_\n\n#include <type_traits>\n#include <unordered_map>\n#include \"tuple_hash.h\"\n#include \"static_check.h\"\n#include \"oneflow/core/common/env_var/env_var.h\"\n#include \"oneflow/core/common/cpp_attribute.h\"\n\nnamespace oneflow {\n\ntemplate<template<typename...> class Decorator>\nstruct WithDecorator final {\n  template<typename T, typename = void>\n  struct Decorate;\n  template<typename T, typename... Args>\n  struct Decorate<T (*)(Args...)> final {\n    template<T (*func)(Args...)>\n    static T Call(Args... args) {\n      return Decorator<T, Args...>::template Call<func>(args...);\n    }\n  };\n};\n\n#define DECORATE(fn_ptr, decorator) \\\n  (&WithDecorator<decorator>::Decorate<decltype(fn_ptr)>::Call<fn_ptr>)\n\ntemplate<typename... Args>\nstruct ThreadLocalCopiable;\n\ntemplate<typename RetT>\nstruct ThreadLocalCopiable<RetT> {\n  template<RetT (*func)()>\n  static RetT Call() {\n    static thread_local RetT value = func();\n    return value;\n  }\n};\n\ntemplate<typename RetT, typename Arg0>\nstruct ThreadLocalCopiable<RetT, Arg0> {\n  template<RetT (*func)(Arg0)>\n  static RetT Call(Arg0 arg0) {\n    using KeyT = typename std::decay<Arg0>::type;\n    using MappedT = typename std::decay<RetT>::type;\n    static thread_local std::unordered_map<KeyT, MappedT> map;\n    auto iter = map.find(arg0);\n    if (iter == map.end()) { iter = map.emplace(arg0, func(arg0)).first; }\n    return iter->second;\n  }\n\n private:\n  static_assert(!IsOutArg<Arg0>::value, \"\");\n  static_assert(!StaticAny<IsOutArg, Arg0>::value, \"\");\n};\n\ntemplate<typename RetT, typename Arg0, typename Arg1>\nstruct ThreadLocalCopiable<RetT, Arg0, Arg1> {\n  template<RetT (*func)(Arg0, Arg1)>\n  static RetT Call(Arg0 arg0, Arg1 arg1) {\n    using KeyT0 = typename std::decay<Arg0>::type;\n    using KeyT1 = typename std::decay<Arg1>::type;\n    using MappedT = typename std::decay<RetT>::type;\n    static thread_local std::unordered_map<KeyT0, std::unordered_map<KeyT1, MappedT>> map;\n    auto* last_map = &map[arg0];\n    auto iter = last_map->find(arg1);\n    if (iter == last_map->end()) { iter = last_map->emplace(arg1, func(arg0, arg1)).first; }\n    return iter->second;\n  }\n\n private:\n  static_assert(!StaticAny<IsOutArg, Arg0, Arg1>::value, \"\");\n};\n\ntemplate<typename RetT, typename Arg0, typename Arg1, typename Arg2>\nstruct ThreadLocalCopiable<RetT, Arg0, Arg1, Arg2> {\n  template<RetT (*func)(Arg0, Arg1, Arg2)>\n  static RetT Call(Arg0 arg0, Arg1 arg1, Arg2 arg2) {\n    using KeyT0 = typename std::decay<Arg0>::type;\n    using KeyT1 = typename std::decay<Arg1>::type;\n    using KeyT2 = typename std::decay<Arg2>::type;\n    using MappedT = typename std::decay<RetT>::type;\n    static thread_local std::unordered_map<\n        KeyT0, std::unordered_map<KeyT1, std::unordered_map<KeyT2, MappedT>>>\n        
map;\n    auto* last_map = &map[arg0][arg1];\n    auto iter = last_map->find(arg2);\n    if (iter == last_map->end()) { iter = last_map->emplace(arg2, func(arg0, arg1, arg2)).first; }\n    return iter->second;\n  }\n\n private:\n  static_assert(!StaticAny<IsOutArg, Arg0, Arg1, Arg2>::value, \"\");\n};\n\ntemplate<typename RetT, typename Arg0, typename Arg1, typename Arg2, typename Arg3,\n         typename... Args>\nstruct ThreadLocalCopiable<RetT, Arg0, Arg1, Arg2, Arg3, Args...> {\n  template<RetT (*func)(Arg0, Arg1, Arg2, Arg3, Args...)>\n  static RetT Call(Arg0 arg0, Arg1 arg1, Arg2 arg2, Arg3 arg3, Args... args) {\n    using KeyT0 = typename std::decay<Arg0>::type;\n    using KeyT1 = typename std::decay<Arg1>::type;\n    using KeyT2 = typename std::decay<Arg2>::type;\n    using KeyT3 = typename std::decay<Arg3>::type;\n    using KeyT = std::tuple<KeyT0, KeyT1, KeyT2, KeyT3, typename std::decay<Args>::type...>;\n    using MappedT = typename std::decay<RetT>::type;\n    static thread_local std::unordered_map<KeyT, MappedT> map;\n    const auto& key = KeyT(arg0, arg1, arg2, arg3, args...);\n    auto iter = map.find(key);\n    if (iter == map.end()) { iter = map.emplace(key, func(arg0, arg1, arg2, arg3, args...)).first; }\n    return iter->second;\n  }\n\n private:\n  static_assert(!StaticAny<IsOutArg, Arg0, Arg1, Arg2, Arg3, Args...>::value, \"\");\n};\n\n// for scalar type key.\ntemplate<typename RetT, typename... Args>\nstruct ThreadLocal : public ThreadLocalCopiable<RetT, Args...> {\n private:\n  static_assert(StaticAll<IsDecayedScalarType, Args...>::value, \"\");\n};\n\ntemplate<typename... Args>\nstruct ThreadLocalCachedCopiable;\n\ntemplate<typename RetT>\nstruct ThreadLocalCachedCopiable<RetT> {\n  template<RetT (*func)()>\n  static RetT Call() {\n    static thread_local RetT value = func();\n    return value;\n  }\n};\n\ntemplate<typename RetT, typename Arg0>\nstruct ThreadLocalCachedCopiable<RetT, Arg0> {\n  template<RetT (*func)(Arg0)>\n  static RetT Call(Arg0 arg0) {\n    using KeyT = typename std::decay<Arg0>::type;\n    using MappedT = typename std::decay<RetT>::type;\n    static thread_local std::unordered_map<KeyT, MappedT> map;\n    auto iter = map.find(arg0);\n    if (iter == map.end()) {\n      if (unlikely(map.size() >= ThreadLocalEnvInteger<ONEFLOW_THRAED_LOCAL_CACHED_SIZE>())) {\n        map.clear();\n      }\n      iter = map.emplace(arg0, func(arg0)).first;\n    }\n    return iter->second;\n  }\n\n private:\n  static_assert(!IsOutArg<Arg0>::value, \"\");\n  static_assert(!StaticAny<IsOutArg, Arg0>::value, \"\");\n};\n\ntemplate<typename RetT, typename Arg0, typename... Args>\nstruct ThreadLocalCachedCopiable<RetT, Arg0, Args...> {\n  template<RetT (*func)(Arg0, Args...)>\n  static RetT Call(Arg0 arg0, Args... args) {\n    using KeyT0 = typename std::decay<Arg0>::type;\n    using KeyT = std::tuple<KeyT0, typename std::decay<Args>::type...>;\n    using MappedT = typename std::decay<RetT>::type;\n    static thread_local std::unordered_map<KeyT, MappedT> map;\n    const auto& key = KeyT(arg0, args...);\n    auto iter = map.find(key);\n    if (iter == map.end()) {\n      if (unlikely(map.size() >= ThreadLocalEnvInteger<ONEFLOW_THRAED_LOCAL_CACHED_SIZE>())) {\n        map.clear();\n      }\n      iter = map.emplace(key, func(arg0, args...)).first;\n    }\n    return iter->second;\n  }\n\n private:\n  static_assert(!StaticAny<IsOutArg, Arg0, Args...>::value, \"\");\n};\n\n// for scalar type key.\ntemplate<typename RetT, typename... 
Args>\nstruct ThreadLocalCached : public ThreadLocalCachedCopiable<RetT, Args...> {\n private:\n  static_assert(StaticAll<IsDecayedScalarType, Args...>::value, \"\");\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_DECORATOR_H_\n"
  },
  {
    "path": "oneflow/core/common/decorator_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\nnamespace test {\n\nMaybe<int> Inc(int x) { return x + 1; }\n\nMaybe<int> IncByConstRef(const int& x) { return x + 1; }\n\nTEST(ThreadLocal, scalar) {\n  auto* CachedInc = DECORATE(&Inc, ThreadLocal);\n\n  int x = CHECK_JUST(CachedInc(0));\n  ASSERT_EQ(x, 1);\n}\n\nTEST(ThreadLocal, const_ref) {\n  auto* CachedIncByConstRef = DECORATE(&IncByConstRef, ThreadLocal);\n\n  int x = CHECK_JUST(CachedIncByConstRef(0));\n  ASSERT_EQ(x, 1);\n}\n\nnamespace {\n\nstruct Foo {\n  static Maybe<Foo> New(int x) { return std::shared_ptr<Foo>(new Foo{x}); }\n\n  int x;\n};\n\n}  // namespace\n\nTEST(ThreadLocal, _class) {\n  auto* CachedFooNew = DECORATE(&Foo::New, ThreadLocal);\n  const auto& foo = CHECK_JUST(CachedFooNew(10));\n  const auto& bar = CHECK_JUST(CachedFooNew(10));\n  ASSERT_EQ(foo->x, 10);\n  ASSERT_TRUE(foo == bar);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/device.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/common/device_type.proto\";\n\nmessage DeviceProto {\n  required DeviceType device_type = 1;\n  required int64 device_id = 2;\n  optional bool rematable = 3 [default = false];\n}\n"
  },
  {
    "path": "oneflow/core/common/device_type.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/device_type.h\"\n\n#include <fmt/ranges.h>\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n\nnamespace oneflow {\n\nstd::vector<std::string> GetAllAvailableDeviceTypeNames() {\n  const auto& device_types = ep::DeviceManagerRegistry::GetRegisteredDeviceTypes();\n  std::vector<std::string> device_type_names;\n  device_type_names.reserve(device_types.size());\n  for (const auto& device_type : device_types) {\n    device_type_names.emplace_back(\n        ep::DeviceManagerRegistry::GetDeviceTypeNameByDeviceType(device_type));\n  }\n  return device_type_names;\n}\n\nstd::string PrintAvailableDevices() {\n  const auto& device_type_names = GetAllAvailableDeviceTypeNames();\n  return fmt::format(\"{}\", fmt::join(device_type_names, \", \"));\n}\n\nstd::string PrintGeneratorAvailableDevices() {\n  auto device_type_names = GetAllAvailableDeviceTypeNames();\n  device_type_names.emplace_back(\"auto\");  // \"auto\" is a fake device type for random generator.\n  return fmt::format(\"{}\", fmt::join(device_type_names, \", \"));\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/device_type.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_DEVICE_TYPE_H_\n#define ONEFLOW_CORE_COMMON_DEVICE_TYPE_H_\n\n#include \"oneflow/core/common/device_type.pb.h\"\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::DeviceType> final {\n  size_t operator()(oneflow::DeviceType device_type) const {\n    return static_cast<size_t>(device_type);\n  }\n};\n\n}  // namespace std\n\nnamespace oneflow {\n\nstd::string PrintAvailableDevices();\nstd::string PrintGeneratorAvailableDevices();\n\n#if defined(WITH_CUDA)\n#define DEVICE_TYPE_SEQ                  \\\n  OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU) \\\n  OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA)\n#else\n#define DEVICE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU)\n#endif\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_DEVICE_TYPE_H_\n"
  },
  {
    "path": "oneflow/core/common/device_type.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nenum DeviceType {\n  kInvalidDevice = 0;\n  kCPU = 1;\n  kCUDA = 2;\n  kMockDevice = 3; // pseudo device for test.\n  kMeta = 4;\n  kMLU = 5;  // Cambricon MLU\n  kNPU = 6;  // Ascend NPU\n  kXPU = 7;  // KunLunXin\n}\n"
  },
  {
    "path": "oneflow/core/common/dtype_signature.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_REGISTER_DTYPE_SIGNATURE_H_\n#define ONEFLOW_CORE_REGISTER_DTYPE_SIGNATURE_H_\n\n#include \"oneflow/core/common/dtype_signature.pb.h\"\n#include \"oneflow/core/common/protobuf.h\"\n\nnamespace oneflow {\n\ninline bool operator==(const DTypeSignature& lhs, const DTypeSignature& rhs) {\n  return PbMd().Equals(lhs, rhs);\n}\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::DTypeSignature> final {\n  size_t operator()(const oneflow::DTypeSignature& dtype_signature) {\n    std::string serialized;\n    dtype_signature.SerializeToString(&serialized);\n    return std::hash<std::string>()(serialized);\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_REGISTER_DTYPE_SIGNATURE_H_\n"
  },
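  {
    "path": "oneflow/core/common/dtype_signature_example.cpp",
    "content": "// Editor's usage sketch (hypothetical file, not part of the OneFlow tree): the operator==\n// and std::hash specialization above are exactly what unordered containers need, so\n// DTypeSignature can be used directly as an unordered_map key.\n#include <unordered_map>\n#include \"oneflow/core/common/dtype_signature.h\"\n\nnamespace oneflow {\n\n// Counts how often a signature has been seen. Not thread-safe; illustration only.\ninline int ExampleCountSignature(const DTypeSignature& signature) {\n  static std::unordered_map<DTypeSignature, int> counter;  // uses std::hash<DTypeSignature>\n  return ++counter[signature];\n}\n\n}  // namespace oneflow\n"
  },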
  {
    "path": "oneflow/core/common/dtype_signature.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/common/data_type.proto\";\n\nmessage DTypeSignature {\n  map<string, DataType> name2dtype = 1;\n}\n"
  },
  {
    "path": "oneflow/core/common/eigen_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_EIGEN_UTIL_H_\n#define ONEFLOW_CORE_COMMON_EIGEN_UTIL_H_\n\n#include \"Eigen/Core\"\n#include \"Eigen/Dense\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nusing EigenMatrixMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>;\n\ntemplate<typename T>\nusing EigenArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;\n\ntemplate<typename T>\nusing ConstEigenMatrixMap = Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>;\n\ntemplate<typename T>\nusing ConstEigenArrayMap = Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_EIGEN_UTIL_H_\n"
  },
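  {
    "path": "oneflow/core/common/eigen_util_example.cpp",
    "content": "// Editor's usage sketch (hypothetical file, not part of the OneFlow tree): the Map aliases\n// above let a raw buffer be viewed and mutated as an Eigen matrix without copying.\n#include \"oneflow/core/common/eigen_util.h\"\n\nnamespace oneflow {\n\ninline void ExampleScaleInPlace(float* buf, int rows, int cols, float factor) {\n  // Eigen::Map defaults to column-major storage, matching Eigen::Matrix.\n  EigenMatrixMap<float> mat(buf, rows, cols);\n  mat *= factor;  // writes straight through to buf\n}\n\n}  // namespace oneflow\n"
  },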
  {
    "path": "oneflow/core/common/either_ptr.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_EITHER_PTR_H_\n#define ONEFLOW_CORE_COMMON_EITHER_PTR_H_\n\n#include <memory>\n#include \"oneflow/core/common/throw.h\"\n\nnamespace oneflow {\n\ntemplate<typename X, typename Y>\nclass EitherPtr final {\n public:\n  static_assert(!std::is_same<X, Y>::value, \"X should not be Y\");\n\n  using XPtr = std::shared_ptr<X>;\n  using YPtr = std::shared_ptr<Y>;\n\n  // WARNING: we should assume that the structure of shared_ptr<X> and shared_ptr<Y> is same,\n  // and obviously at most time the assumption holds\n  static_assert(sizeof(XPtr) == sizeof(YPtr), \"unsupported shared_ptr implementation\");\n\n  EitherPtr() : type_(UnionType<X>::value), x_ptr_(nullptr) {}\n  EitherPtr(const XPtr& ptr) : type_(UnionType<X>::value), x_ptr_(ptr) {}\n  EitherPtr(const YPtr& ptr) : type_(UnionType<Y>::value) { new (&x_ptr_) YPtr(ptr); }\n\n  EitherPtr(XPtr&& ptr) : type_(UnionType<X>::value), x_ptr_(std::move(ptr)) {}\n  EitherPtr(YPtr&& ptr) : type_(UnionType<Y>::value) { new (&x_ptr_) YPtr(std::move(ptr)); }\n\n  EitherPtr(const EitherPtr& either_ptr) : type_(either_ptr.type_), x_ptr_(either_ptr.x_ptr_) {}\n  EitherPtr(EitherPtr&& either_ptr)\n      : type_(either_ptr.type_), x_ptr_(std::move(either_ptr.x_ptr_)) {}\n\n  // the destructor of X or Y will be called properly because it will be stored in the deleter of\n  // shared_ptr while constructed\n  ~EitherPtr() = default;\n\n  EitherPtr& operator=(const EitherPtr& either_ptr) {\n    x_ptr_ = either_ptr.x_ptr_;\n    type_ = either_ptr.type_;\n    return *this;\n  }\n\n  EitherPtr& operator=(EitherPtr&& either_ptr) {\n    x_ptr_ = std::move(either_ptr.x_ptr_);\n    type_ = either_ptr.type_;\n    return *this;\n  }\n\n  template<typename T>\n  bool Has() const {\n    return type_ == UnionType<T>::value;\n  }\n\n  template<typename T>\n  const std::shared_ptr<T>& Get() const {\n    return Get(tag<T>{});\n  }\n\n private:\n  template<typename T, typename Enable = void>\n  struct UnionType;\n  template<typename T>\n  struct UnionType<T, typename std::enable_if<std::is_same<X, T>::value>::type> {\n    static constexpr int8_t value = 0;\n  };\n  template<typename T>\n  struct UnionType<T, typename std::enable_if<std::is_same<Y, T>::value>::type> {\n    static constexpr int8_t value = 1;\n  };\n\n  template<typename>\n  struct tag {};\n\n  const XPtr& Get(tag<X>) const {\n    CHECK(Has<X>());\n    return x_ptr_;\n  }\n\n  const YPtr& Get(tag<Y>) const {\n    CHECK(Has<Y>());\n    const auto* __attribute__((__may_alias__)) ptr = reinterpret_cast<const YPtr*>(&x_ptr_);\n    return *ptr;\n  }\n\n  int8_t type_;\n  std::shared_ptr<X> x_ptr_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_EITHER_PTR_H_\n"
  },
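  {
    "path": "oneflow/core/common/either_ptr_example.cpp",
    "content": "// Editor's usage sketch (hypothetical file, not part of the OneFlow tree): basic EitherPtr\n// usage. Check Has<T>() before Get<T>(); Get<T>() CHECK-fails if the stored alternative\n// does not match T.\n#include <memory>\n#include <string>\n#include \"oneflow/core/common/either_ptr.h\"\n\nnamespace oneflow {\n\ninline std::string ExampleDescribe(const EitherPtr<int, std::string>& value) {\n  if (value.Has<int>()) { return \"int: \" + std::to_string(*value.Get<int>()); }\n  return \"string: \" + *value.Get<std::string>();\n}\n\n// ExampleDescribe(EitherPtr<int, std::string>(std::make_shared<int>(42))) returns \"int: 42\".\n\n}  // namespace oneflow\n"
  },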
  {
    "path": "oneflow/core/common/env_var/bootstrap.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_BOOTSTRAP_H_\n#define ONEFLOW_CORE_COMMON_ENV_VAR_BOOTSTRAP_H_\n\n#include \"oneflow/core/common/env_var/env_var.h\"\n\nnamespace oneflow {\n\nDEFINE_ENV_INTEGER(ONEFLOW_RPC_BOOTSTRAP_SERVER_SLEEP_SECONDS, 20);\nDEFINE_ENV_INTEGER(ONEFLOW_RPC_BOOTSTRAP_SERVER_MAX_RETRY_TIMES, 3);\nDEFINE_ENV_INTEGER(ONEFLOW_RPC_CLIENT_SLEEP_SECONDS, 5);\nDEFINE_ENV_INTEGER(ONEFLOW_RPC_CLIENT_MAX_RETRY_TIMES, 6);\n\n}  // namespace oneflow\n#endif  // ONEFLOW_CORE_COMMON_ENV_VAR_BOOTSTRAP_H_\n"
  },
  {
    "path": "oneflow/core/common/env_var/debug_mode.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_DEBUG_MODE_H_\n#define ONEFLOW_CORE_COMMON_ENV_VAR_DEBUG_MODE_H_\n\n#include \"oneflow/core/common/env_var/env_var.h\"\n\nnamespace oneflow {\n\nDEFINE_ENV_BOOL(ONEFLOW_DEBUG_MODE, false);\nDEFINE_ENV_BOOL(ONEFLOW_DEBUG, false);\n\ninline bool IsInDebugMode() { return EnvBool<ONEFLOW_DEBUG_MODE>() || EnvBool<ONEFLOW_DEBUG>(); }\n\nDEFINE_ENV_BOOL(ENABLE_ACTOR_DEBUG_LOG, false);\ninline bool EnableActorDebugLog() { return EnvBool<ENABLE_ACTOR_DEBUG_LOG>(); }\n\nDEFINE_ENV_BOOL(ENABLE_LOGICAL_CHAIN, true);\ninline bool EnableLogicalChain() { return EnvBool<ENABLE_LOGICAL_CHAIN>(); }\n\nDEFINE_ENV_BOOL(ENABLE_NCCL_LOGICAL_FUSION, true);\ninline bool EnableNcclLogicalFusion() { return EnvBool<ENABLE_NCCL_LOGICAL_FUSION>(); }\n\ninline bool IsPythonStackGetterEnabledByDebugBuild() {\n  if (std::getenv(\"ONEFLOW_DEBUG_MODE\") == nullptr && std::getenv(\"ONEFLOW_DEBUG\") == nullptr\n      && std::getenv(\"ONEFLOW_PYTHON_STACK_GETTER\") == nullptr) {\n    return std::string(OF_PP_STRINGIZE(ONEFLOW_CMAKE_BUILD_TYPE)) == \"Debug\";\n  }\n  return false;\n}\n\ninline bool IsPythonStackGetterEnabled() {\n  if (IsPythonStackGetterEnabledByDebugBuild()) { return true; }\n  return ParseBooleanFromEnv(\"ONEFLOW_PYTHON_STACK_GETTER\", IsInDebugMode());\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_ENV_VAR_DEBUG_MODE_H_\n"
  },
  {
    "path": "oneflow/core/common/env_var/eager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_\n#define ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_\n\n#include \"oneflow/core/common/env_var/env_var.h\"\n#ifdef WITH_CUDA\n#include <nccl.h>\n#endif\n\nnamespace oneflow {\n\n// NOTE: use env variable 'ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE' indicate whether the\n// use infer cache in naive local op interpret.\nDEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE, true);\n\n// NOTE: use env variable 'ONEFLOW_EAGER_TENSOR_INFER_CACHE_SIZE' indicate the size of\n// infer cache in op interpret.\nDEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_EAGER_TENSOR_INFER_CACHE_SIZE, 128 * 1024);\n\nDEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_EAGER_NCCL_USE_COMPUTE_STREAM, false);\n\ninline bool EagerNcclUseComputeStream() {\n#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700\n  static bool eager_nccl_use_compute_stream =\n      ThreadLocalEnvBool<ONEFLOW_EAGER_NCCL_USE_COMPUTE_STREAM>();\n  return eager_nccl_use_compute_stream;\n#else\n  return false;\n#endif\n}\n\n}  // namespace oneflow\n#endif  // ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_\n"
  },
  {
    "path": "oneflow/core/common/env_var/env_var.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_ENV_VAR_H_\n#define ONEFLOW_CORE_COMMON_ENV_VAR_ENV_VAR_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\ntemplate<typename env_var>\nbool EnvBool();\n\n#define DEFINE_ENV_BOOL(env_var, default_value)                          \\\n  struct env_var {};                                                     \\\n  template<>                                                             \\\n  inline bool EnvBool<env_var>() {                                       \\\n    return ParseBooleanFromEnv(OF_PP_STRINGIZE(env_var), default_value); \\\n  }\n\ntemplate<typename env_var>\nint64_t EnvInteger();\n\n#define DEFINE_ENV_INTEGER(env_var, default_value)                       \\\n  struct env_var {};                                                     \\\n  template<>                                                             \\\n  inline int64_t EnvInteger<env_var>() {                                 \\\n    return ParseIntegerFromEnv(OF_PP_STRINGIZE(env_var), default_value); \\\n  }\n\nDEFINE_ENV_INTEGER(ONEFLOW_TIMEOUT_SECONDS, 7200);\nDEFINE_ENV_INTEGER(ONEFLOW_CHECK_TIMEOUT_SLEEP_SECONDS, EnvInteger<ONEFLOW_TIMEOUT_SECONDS>());\n\nDEFINE_ENV_INTEGER(ONEFLOW_VM_BLOCKING_DEBUG_INSTRUCTIONS_DISPLAY_LIMIT, 100);\nDEFINE_ENV_INTEGER(ONEFLOW_DELETE_OUTDATED_SHM_NAMES_INTERVAL, 1000);\n\ntemplate<typename env_var>\nbool ThreadLocalEnvBool();\n\n#define DEFINE_THREAD_LOCAL_ENV_BOOL(env_var, default_value)                                \\\n  struct env_var {};                                                                        \\\n  template<>                                                                                \\\n  inline bool ThreadLocalEnvBool<env_var>() {                                               \\\n    thread_local bool value = ParseBooleanFromEnv(OF_PP_STRINGIZE(env_var), default_value); \\\n    return value;                                                                           \\\n  }\n\ntemplate<typename env_var>\nint64_t ThreadLocalEnvInteger();\n\n#define DEFINE_THREAD_LOCAL_ENV_INTEGER(env_var, default_value)                                \\\n  struct env_var {};                                                                           \\\n  template<>                                                                                   \\\n  inline int64_t ThreadLocalEnvInteger<env_var>() {                                            \\\n    thread_local int64_t value = ParseIntegerFromEnv(OF_PP_STRINGIZE(env_var), default_value); \\\n    return value;                                                                              \\\n  }\n\nDEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_THRAED_LOCAL_CACHED_SIZE, 128 * 1024);\n\ntemplate<typename env_var>\nconst std::string& ThreadLocalEnvString();\n\n#define DEFINE_THREAD_LOCAL_ENV_STRING(env_var, default_value)                                  \\\n  struct env_var 
{};                                                                            \\\n  template<>                                                                                    \\\n  inline const std::string& ThreadLocalEnvString<env_var>() {                                   \\\n    thread_local std::string value = GetStringFromEnv(OF_PP_STRINGIZE(env_var), default_value); \\\n    return value;                                                                               \\\n  }\n\nDEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_ENABLE_LAZY_SEPARATE_COMPILE, false);\n// Default compilation mode during graph compilation. There are 2 modes to choose from:\n// \"naive\": the master rank compiles the full plan.\n// \"rank_per_process\": each process (rank) compiles its own part of the plan separately.\nDEFINE_THREAD_LOCAL_ENV_STRING(ONEFLOW_LAZY_COMPILE_MODE, \"naive\");\n// Default number of threads during graph compilation.\nDEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_LAZY_COMPILE_RPC_THREAD_NUM, 16);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_ENV_VAR_ENV_VAR_H_\n"
  },
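  {
    "path": "oneflow/core/common/env_var/env_var_example.cpp",
    "content": "// Editor's usage sketch (hypothetical file, not part of the OneFlow tree): defining and\n// reading a typed env var with the macros above. DEFINE_ENV_INTEGER generates a tag struct\n// plus an EnvInteger<> specialization, so lookups are keyed by type at compile time.\n// ONEFLOW_EXAMPLE_QUEUE_DEPTH is a made-up variable for illustration.\n#include \"oneflow/core/common/env_var/env_var.h\"\n\nnamespace oneflow {\n\nDEFINE_ENV_INTEGER(ONEFLOW_EXAMPLE_QUEUE_DEPTH, 64);\n\n// Re-reads the environment on every call; the DEFINE_THREAD_LOCAL_ENV_* variants cache the\n// parsed value per thread instead.\ninline int64_t ExampleQueueDepth() { return EnvInteger<ONEFLOW_EXAMPLE_QUEUE_DEPTH>(); }\n\n}  // namespace oneflow\n"
  },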
  {
    "path": "oneflow/core/common/env_var/remat.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#pragma once\n\n#include \"oneflow/core/common/env_var/env_var.h\"\n\nnamespace oneflow {\n\nDEFINE_ENV_BOOL(ONEFLOW_REMAT_DISPLAY_IN_FIRST_TIME, false);\nDEFINE_ENV_BOOL(ONEFLOW_REMAT_RECORD_MEM_FRAG_RATE, true);\nDEFINE_ENV_INTEGER(ONEFLOW_REMAT_GROUP_NUM, 1);\nDEFINE_ENV_BOOL(ONEFLOW_REMAT_NEIGHBOR, true);\nDEFINE_ENV_BOOL(ONEFLOW_REMAT_HEURISTIC_DTE, false);\nDEFINE_ENV_BOOL(ONEFLOW_REMAT_HEURISTIC_DTR, false);\nDEFINE_ENV_BOOL(ONEFLOW_REMAT_LOG, false);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/env_var/stream.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_STREAM_H_\n#define ONEFLOW_CORE_COMMON_ENV_VAR_STREAM_H_\n\n#include \"oneflow/core/common/env_var/env_var.h\"\n\nnamespace oneflow {\n\nDEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_DEVICE_STREAM_MAX_SIZE, 16);\nDEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_STREAM_ENABLE_H2D_STREAM, false);\n\n}  // namespace oneflow\n#endif  // ONEFLOW_CORE_COMMON_ENV_VAR_STREAM_H_\n"
  },
  {
    "path": "oneflow/core/common/env_var/vm.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_VM_H_\n#define ONEFLOW_CORE_COMMON_ENV_VAR_VM_H_\n\n#include \"oneflow/core/common/env_var/env_var.h\"\n\nnamespace oneflow {\n\nDEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_COMPUTE_ON_WORKER_THREAD, true);\nDEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_ENABLE_STREAM_WAIT, true);\nDEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_VM_PENDING_HANDLE_WINDOW_SIZE, 10)\nDEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_ENABLE_SCHEDULE_YIELD, true)\nDEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_VM_WORKER_THREAD_LIMIT, 16);\nDEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_MULTI_THREAD, true);\n\n}  // namespace oneflow\n#endif  // ONEFLOW_CORE_COMMON_ENV_VAR_VM_H_\n"
  },
  {
    "path": "oneflow/core/common/error.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <stdexcept>\n#include \"fmt/core.h\"\n#include \"fmt/color.h\"\n#include \"fmt/ostream.h\"\n#include \"oneflow/core/common/error.h\"\n#include \"oneflow/core/common/exception.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/error_util.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n#include \"oneflow/extension/stack/foreign_stack_getter.h\"\n#include \"oneflow/extension/stack/stacktrace.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n\nnamespace oneflow {\n\nStackedError::StackedError() : stack_frame_(), error_proto_(new ErrorProto()) {}\n\nnamespace {\n\nvoid LogError(const Error& error) {\n  // gdb break point\n  LOG(ERROR) << error->msg();\n}\n\nstd::shared_ptr<StackedError>* MutThreadLocalError() {\n  thread_local std::shared_ptr<StackedError> error;\n  return &error;\n}\n\n}  // namespace\n\nError&& Error::AddStackFrame(Symbol<ErrorStackFrame> error_stack_frame) {\n  stacked_error_->add_stack_frame(error_stack_frame);\n  return std::move(*this);\n}\n\nError&& Error::GetStackTrace(int64_t depth, int64_t skip_n_firsts) {\n  backward::StackTrace st;\n  backward::SnippetFactory snippets;\n  backward::TraceResolver resolver;\n  st.load_here(depth);\n  st.skip_n_firsts(skip_n_firsts);\n  resolver.load_stacktrace(st);\n\n  for (int i = 0; i < st.size(); i++) {\n    const auto& trace = resolver.resolve(st[i]);\n    if (!backward::Printer::is_oneflow_file(trace.object_filename)) { continue; }\n\n    //  without debug info\n    if (!trace.source.filename.size()) {\n      stacked_error_->add_stack_frame(\n          SymbolOf(ErrorStackFrame(trace.object_filename, -1, trace.object_function)));\n    }\n\n    //  with debug info\n    if (trace.source.filename.size()) {\n      const backward::ResolvedTrace::SourceLoc& source_loc = trace.source;\n      backward::SnippetFactory::lines_t lines =\n          snippets.get_snippet(source_loc.filename, source_loc.line, static_cast<unsigned>(1));\n      std::string code_text = lines[0].second;\n      const auto pos = code_text.find_first_not_of(\" \\t\");\n      code_text = code_text.substr(pos, code_text.size() - pos);\n      stacked_error_->add_stack_frame(SymbolOf(\n          ErrorStackFrame(source_loc.filename, source_loc.line, source_loc.function, code_text)));\n    }\n\n    for (size_t inliner_idx = 0; inliner_idx < trace.inliners.size(); ++inliner_idx) {\n      const backward::ResolvedTrace::SourceLoc& source_loc = trace.inliners[inliner_idx];\n      backward::SnippetFactory::lines_t lines =\n          snippets.get_snippet(source_loc.filename, source_loc.line, static_cast<unsigned>(1));\n      std::string code_text = lines[0].second;\n      const auto pos = code_text.find_first_not_of(\" \\t\");\n      code_text = code_text.substr(pos, code_text.size() - pos);\n      stacked_error_->add_stack_frame(SymbolOf(\n          
ErrorStackFrame(source_loc.filename, source_loc.line, source_loc.function, code_text)));\n    }\n  }\n  return std::move(*this);\n}\n\nvoid Error::Merge(const Error& other) {\n  auto* error_proto = stacked_error_->mut_error_proto();\n  error_proto->MergeFrom(*other.stacked_error_->error_proto());\n}\n\nError::operator std::string() const { return stacked_error_->DebugString(); }\n\nError Error::Ok() { return std::make_shared<StackedError>(); }\n\nError Error::ProtoParseFailedError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_proto_parse_failed_error();\n  return error;\n}\n\nError Error::JobSetEmptyError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_job_set_empty_error();\n  return error;\n}\n\nError Error::DeviceTagNotFoundError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_device_tag_not_found_error();\n  return error;\n}\n\nError Error::InvalidValueError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_invalid_value_error();\n  return error;\n}\n\nError Error::IndexError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_index_error();\n  return error;\n}\n\nError Error::TypeError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_type_error();\n  return error;\n}\n\nError Error::TimeoutError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_timeout_error();\n  return error;\n}\n\nError Error::JobNameExistError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_job_name_exist_error();\n  return error;\n}\n\nError Error::JobNameEmptyError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_job_name_empty_error();\n  return error;\n}\n\nError Error::JobNameNotEqualError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_job_name_not_equal_error();\n  return error;\n}\n\nError Error::NoJobBuildAndInferCtxError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_no_job_build_and_infer_ctx_error();\n  return error;\n}\n\nError Error::JobConfFrozenError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_job_conf_frozen_error();\n  return error;\n}\n\nError Error::JobConfNotSetError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_job_conf_not_set_error();\n  return error;\n}\n\nError Error::JobConfRepeatedSetError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_job_conf_repeated_set_error();\n  return error;\n}\n\nError Error::JobTypeNotSetError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_job_type_not_set_error();\n  return error;\n}\n\nError Error::LogicalBlobNameNotExistError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_logical_blob_name_not_exist_error();\n  return error;\n}\n\nError Error::LogicalBlobNameExistError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_logical_blob_name_exist_error();\n  return error;\n}\n\nError Error::LogicalBlobNameInvalidError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_logical_blob_name_invalid_error();\n  
return error;\n}\n\nError Error::OpNameExistError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_op_name_exist_error();\n  return error;\n}\n\nError Error::OpConfDeviceTagNoSetError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_op_conf_device_tag_no_set_error();\n  return error;\n}\n\nError Error::PlacementError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_placement_error();\n  return error;\n}\n\nError Error::BlobSplitAxisInferError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_blob_split_axis_infer_error();\n  return error;\n}\n\nError Error::UnknownJobBuildAndInferError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_unknown_job_build_and_infer_error();\n  return error;\n}\n\nError Error::CheckFailedError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_check_failed_error();\n  return error;\n}\n\nError Error::ValueNotFoundError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_value_not_found_error();\n  return error;\n}\n\nError Error::TodoError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_todo_error();\n  return error;\n}\n\nError Error::UnimplementedError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_unimplemented_error();\n  return error;\n}\n\nError Error::RuntimeError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_runtime_error();\n  return error;\n}\n\nError Error::OutOfMemoryError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_out_of_memory_error();\n  return error;\n}\n\nError Error::BoxingNotSupportedError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_boxing_not_supported_error();\n  return error;\n}\n\nError Error::OpKernelNotFoundError(const std::vector<std::string>& error_msgs) {\n  auto error = std::make_shared<StackedError>();\n  auto* op_kernel_not_found_error = error->mut_error_proto()->mutable_op_kernel_not_found_error();\n  for (const auto& msg : error_msgs) {\n    op_kernel_not_found_error->add_op_kernels_not_found_debug_str(msg);\n  }\n  return error;\n}\n\nError Error::MultipleOpKernelsMatchedError(const std::vector<std::string>& error_msgs) {\n  auto error = std::make_shared<StackedError>();\n  auto* multiple_op_kernels_matched_error =\n      error->mut_error_proto()->mutable_multiple_op_kernels_matched_error();\n  for (const auto& msg : error_msgs) {\n    multiple_op_kernels_matched_error->add_matched_op_kernels_debug_str(msg);\n  }\n  return error;\n}\n\nError Error::MemoryZoneOutOfMemoryError(int64_t machine_id, int64_t mem_zone_id, uint64_t calc,\n                                        uint64_t available, const std::string& device_tag) {\n  auto error = std::make_shared<StackedError>();\n  auto* memory_zone_out_of_memory_error =\n      error->mut_error_proto()->mutable_memory_zone_out_of_memory_error();\n  memory_zone_out_of_memory_error->add_machine_id(std::to_string(machine_id));\n  memory_zone_out_of_memory_error->add_mem_zone_id(std::to_string(mem_zone_id));\n  memory_zone_out_of_memory_error->add_device_tag(device_tag);\n  memory_zone_out_of_memory_error->add_available(std::to_string(available) + \" bytes\");\n  
memory_zone_out_of_memory_error->add_required(std::to_string(calc) + \" bytes\");\n  return error;\n}\n\nError Error::LossBlobNotFoundError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_loss_blob_not_found_error();\n  return error;\n}\n\nError Error::RwMutexedObjectNotFoundError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_rw_mutexed_object_not_found_error();\n  return error;\n}\n\nError Error::GradientFunctionNotFoundError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_gradient_function_not_found_error();\n  return error;\n}\n\nError Error::SymbolIdUninitializedError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_symbol_id_uninitialized_error();\n  return error;\n}\n\nError Error::CompileOptionWrongError() {\n  auto error = std::make_shared<StackedError>();\n  error->mut_error_proto()->mutable_compile_option_wrong_error();\n  return error;\n}\n\nError Error::InputDeviceNotMatchError() {\n  auto error = std::make_shared<StackedError>();\n  auto* input_device_not_match_error =\n      error->mut_error_proto()->mutable_input_device_not_match_error();\n  input_device_not_match_error->add_info(\n      std::string(\"Input tensors are at different devices, please try to use tensor.to or \"\n                  \"module.to to correct it.\"));\n  return error;\n}\n\nstd::string GetStackedErrorString(const std::shared_ptr<StackedError>& error) {\n  const auto& maybe_error = TRY(FormatErrorStr(error));\n  const auto& error_str = maybe_error.GetDataAndStackedError(error->DebugString());\n  CHECK_NE(error->error_proto()->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET);\n  return error_str.first;\n}\n\nstd::string GetErrorString(const std::shared_ptr<StackedError>& error) {\n  std::string error_str;\n  if (IsInDebugMode()) {\n    error_str = GetStackedErrorString(error);\n  } else {\n    error_str = error->error_proto()->msg();\n  }\n  if (error_str.empty()) { error_str = \"<No error message>\"; }\n  return error_str;\n}\n\nvoid ThrowError(const std::shared_ptr<StackedError>& error) {\n  std::string error_str;\n  fmt::format_to(std::back_inserter(error_str), \"{}: {}\\n\",\n                 fmt::styled(\"Error\", fmt::emphasis::bold | fmt::fg(fmt::color::red)),\n                 GetErrorString(error));\n  // Append foreign stack trace (e.g. Python stack trace) when it is available.\n  if (ForeignFrameThreadLocalGuard::Current().has_value()) {\n    auto frame = *CHECK_JUST(ForeignFrameThreadLocalGuard::Current());\n    if (!IsMainThread()) {\n      if (auto* stack_getter = Singleton<ForeignStackGetter>::Get()) {\n        fmt::format_to(std::back_inserter(error_str),\n                       fmt::emphasis::bold | fmt::fg(fmt::color::dark_orange),\n                       \"Related Python stack trace:\");\n        if (IsPythonStackGetterEnabledByDebugBuild()) {\n          fmt::format_to(\n              std::back_inserter(error_str),\n              \" (You are seeing this stack trace because you compiled OneFlow with \"\n              \"CMAKE_BUILD_TYPE=Debug. 
If you want to see it even with other CMAKE_BUILD_TYPEs, \"\n              \"you can set ONEFLOW_DEBUG or ONEFLOW_PYTHON_STACK_GETTER to 1)\");\n        }\n        fmt::format_to(std::back_inserter(error_str), \"\\n{}\",\n                       stack_getter->GetFormattedStack(frame));\n      } else {\n        fmt::format_to(\n            std::back_inserter(error_str),\n            \"You can set {} or {} to 1 to get the Python stack of the error.\",\n            fmt::styled(\"ONEFLOW_DEBUG\", fmt::emphasis::bold | fmt::fg(fmt::color::dark_orange)),\n            fmt::styled(\"ONEFLOW_PYTHON_STACK_GETTER\",\n                        fmt::emphasis::bold | fmt::fg(fmt::color::dark_orange)));\n      }\n    }\n  }\n  *MutThreadLocalError() = error;\n  if ((*error)->has_runtime_error()) { throw RuntimeException(error_str); }\n  if ((*error)->has_type_error()) { throw TypeException(error_str); }\n  if ((*error)->has_index_error()) { throw IndexException(error_str); }\n  if ((*error)->has_unimplemented_error()) { throw NotImplementedException(error_str); }\n  throw Exception(GetStackedErrorString(error));\n}\n\nconst std::shared_ptr<StackedError>& ThreadLocalError() { return *MutThreadLocalError(); }\n\nconst char* kOfBugIssueUploadPrompt = \"This is a OneFlow bug; please submit an issue at \"\n                                      \"'https://github.com/Oneflow-Inc/oneflow/issues' including \"\n                                      \"the log information of the error, the \"\n                                      \"minimum reproduction code, and the system information.\";\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/error.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_ERROR_H_\n#define ONEFLOW_CORE_COMMON_ERROR_H_\n\n#include <sstream>\n#include <vector>\n#include <functional>\n#include <filesystem>\n#include \"oneflow/core/common/error.pb.h\"\n#include \"oneflow/core/common/check.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/small_vector.h\"\n#include \"oneflow/core/common/hash.h\"\n\nnamespace {\nstd::string RemoveProjectPathPrefix(const std::string& filename) {\n#if defined(ONEFLOW_SOURCE_DIR) && defined(ONEFLOW_BINARY_DIR)\n  std::string project_path = ONEFLOW_SOURCE_DIR;\n  std::string project_build_path = ONEFLOW_BINARY_DIR;\n  if (filename.rfind(project_build_path, 0) == 0) {\n    return std::filesystem::relative(filename, project_build_path);\n  } else if (filename.rfind(project_path, 0) == 0) {\n    return std::filesystem::relative(filename, project_path);\n  } else {\n    return filename;\n  }\n#else\n  return filename;\n#endif\n}\n}  // namespace\n\nnamespace oneflow {\n\nclass ErrorStackFrame final {\n public:\n  ErrorStackFrame(const ErrorStackFrame&) = default;\n  ErrorStackFrame(const std::string& file, int64_t line, const std::string& function)\n      : file_(RemoveProjectPathPrefix(file)), line_(line), function_(function), code_text_() {}\n  ErrorStackFrame(const std::string& file, int64_t line, const std::string& function,\n                  const std::string& code_text)\n      : file_(RemoveProjectPathPrefix(file)),\n        line_(line),\n        function_(function),\n        code_text_(code_text) {}\n\n  bool operator==(const ErrorStackFrame& other) const {\n    return this->file_ == other.file_ && this->line_ == other.line_\n           && this->function_ == other.function_ && this->code_text_ == other.code_text_;\n  }\n\n  const std::string& file() const { return file_; }\n  int64_t line() const { return line_; }\n  const std::string& function() const { return function_; }\n  const std::string& code_text() const { return code_text_; }\n\n  std::string DebugString() const {\n    return file_ + \":\" + std::to_string(line_) + \" \" + function_ + \"\\n\\t\" + code_text_ + \"\\n\";\n  }\n\n private:\n  std::string file_;\n  int64_t line_;\n  std::string function_;\n  std::string code_text_;\n};\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<::oneflow::ErrorStackFrame> final {\n  size_t operator()(const ::oneflow::ErrorStackFrame& frame) const {\n    using namespace oneflow;\n    return Hash(frame.file(), frame.line(), frame.function(), frame.code_text());\n  }\n};\n\n}  // namespace std\n\nnamespace oneflow {\n\nclass StackedError final {\n public:\n  StackedError();\n  StackedError(const StackedError&) = default;\n\n  constexpr static int kStackReservedSize = 16;\n  using FrameVector = small_vector<Symbol<ErrorStackFrame>, kStackReservedSize>;\n\n  const ErrorProto* operator->() const { return error_proto().get(); }\n  ErrorProto* 
operator->() { return mut_error_proto(); }\n\n  // Getters\n  const FrameVector& stack_frame() const { return stack_frame_; }\n  const std::shared_ptr<const ErrorProto>& error_proto() const { return error_proto_; }\n  std::string DebugString() const {\n    std::string str;\n    for (const auto& frame : stack_frame()) { str += frame->DebugString() + \"\\n\"; }\n    str += error_proto()->DebugString();\n    return str;\n  }\n\n  // Setters\n  void add_stack_frame(Symbol<ErrorStackFrame> error_frame) { stack_frame_.push_back(error_frame); }\n  ErrorProto* mut_error_proto() { return const_cast<ErrorProto*>(error_proto_.get()); }\n\n private:\n  FrameVector stack_frame_;\n  std::shared_ptr<const ErrorProto> error_proto_;\n};\n\nstd::string GetErrorString(const std::shared_ptr<StackedError>& error);\n\nclass Error final {\n public:\n  Error(const std::shared_ptr<StackedError>& stacked_error)\n      : stacked_error_(stacked_error), msg_collecting_mode_(kMergeMessage) {}\n  Error(const Error&) = default;\n  ~Error() = default;\n\n  std::shared_ptr<StackedError> stacked_error() const { return stacked_error_; }\n  const ErrorProto* operator->() const { return stacked_error_->error_proto().get(); }\n  ErrorProto* operator->() { return stacked_error_->mut_error_proto(); }\n  operator std::string() const;\n  void Assign(const Error& other) { stacked_error_ = other.stacked_error_; }\n  void Merge(const Error& other);\n\n  Error&& AddStackFrame(Symbol<ErrorStackFrame> error_stack_frame);\n  Error&& GetStackTrace(int64_t depth = 32, int64_t skip_n_firsts = 2);\n\n  static Error Ok();\n  static Error ProtoParseFailedError();\n  static Error JobSetEmptyError();\n  static Error DeviceTagNotFoundError();\n  static Error InvalidValueError();\n  static Error IndexError();\n  static Error TypeError();\n  static Error TimeoutError();\n  static Error JobNameExistError();\n  static Error JobNameEmptyError();\n  static Error JobNameNotEqualError();\n  static Error NoJobBuildAndInferCtxError();\n  static Error JobConfFrozenError();\n  static Error JobConfNotSetError();\n  static Error JobConfRepeatedSetError();\n  static Error JobTypeNotSetError();\n  static Error LogicalBlobNameNotExistError();\n  static Error LogicalBlobNameExistError();\n  static Error LogicalBlobNameInvalidError();\n  static Error OpNameExistError();\n  static Error OpConfDeviceTagNoSetError();\n  static Error PlacementError();\n  static Error BlobSplitAxisInferError();\n  static Error UnknownJobBuildAndInferError();\n  static Error CheckFailedError();\n  static Error ValueNotFoundError();\n  static Error TodoError();\n  static Error UnimplementedError();\n  static Error RuntimeError();\n  static Error OutOfMemoryError();\n  static Error BoxingNotSupportedError();\n  static Error MemoryZoneOutOfMemoryError(int64_t machine_id, int64_t mem_zone_id, uint64_t calc,\n                                          uint64_t available, const std::string& device_type);\n  static Error OpKernelNotFoundError(const std::vector<std::string>& error_msgs);\n  static Error MultipleOpKernelsMatchedError(const std::vector<std::string>& error_msgs);\n  static Error LossBlobNotFoundError();\n\n  static Error RwMutexedObjectNotFoundError();\n\n  // gradient\n  static Error GradientFunctionNotFoundError();\n\n  // symbol\n  static Error SymbolIdUninitializedError();\n\n  static Error CompileOptionWrongError();\n\n  static Error InputDeviceNotMatchError();\n\n  enum MsgCollectingMode {\n    kInvalidMsgCollectingMode = 0,\n    kMergeMessage,\n    kOverrideThenMergeMessage,\n 
 };\n\n  MsgCollectingMode msg_collecting_mode() const { return msg_collecting_mode_; }\n  void set_msg_collecting_mode(MsgCollectingMode val) { msg_collecting_mode_ = val; }\n\n private:\n  std::shared_ptr<StackedError> stacked_error_;\n  MsgCollectingMode msg_collecting_mode_;\n};\n\n[[noreturn]] void ThrowError(const std::shared_ptr<StackedError>& error);\nconst std::shared_ptr<StackedError>& ThreadLocalError();\n\ninline Error& operator<<(Error& error, Error::MsgCollectingMode mode) {\n  error.set_msg_collecting_mode(mode);\n  return error;\n}\n\ntemplate<typename T>\nError& operator<<(Error& error, const T& x) {\n  std::ostringstream ss;\n  ss << x;\n  if (error.msg_collecting_mode() == Error::kMergeMessage) {\n    error->set_msg(error->msg() + ss.str());\n  } else if (error.msg_collecting_mode() == Error::kOverrideThenMergeMessage) {\n    error->set_msg(ss.str());\n    error.set_msg_collecting_mode(Error::kMergeMessage);\n  } else {\n    LOG(FATAL) << \"UNIMPLEMENTED\";\n  }\n  return error;\n}\n\n// r-value reference is used to support expressions like `Error() << \"invalid value\"`\ntemplate<typename T>\nError&& operator<<(Error&& error, const T& x) {\n  error << x;\n  return std::move(error);\n}\n\ntemplate<>\ninline Error&& operator<<(Error&& error, const std::stringstream& x) {\n  error << x.str();\n  return std::move(error);\n}\n\ntemplate<>\ninline Error&& operator<<(Error&& error, const std::ostream& x) {\n  error << x.rdbuf();\n  return std::move(error);\n}\n\ntemplate<>\ninline Error&& operator<<(Error&& error, const Error& other) {\n  error.Merge(other);\n  return std::move(error);\n}\n\n// handle CHECK_OR_THROW(expr) << ... << std::endl;\ninline Error&& operator<<(Error&& error, std::ostream& (*os)(std::ostream&)) {\n  error << os;\n  return std::move(error);\n}\n\nextern const char* kOfBugIssueUploadPrompt;\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_ERROR_H_\n"
  },
  {
    "path": "oneflow/core/common/error.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage FieldValue {\n  required string field = 1;\n  required string value = 2;\n}\n\nenum OpcodeType {\n  kInvalidCompareType = 0;\n  kEq = 1;\n  kNe = 2;\n  kGt = 3;\n  kGe = 4;\n  kLt = 5;\n  kLe = 6;\n}\n\nmessage OneFieldAssertError {\n  required OpcodeType compare_type = 1;\n  required FieldValue left = 2;\n  required string right_value = 3;\n}\n\nmessage TwoFieldAssertError {\n  required OpcodeType compare_type = 1;\n  required FieldValue left = 2;\n  required FieldValue right = 3;\n}\n\nmessage ConfigAssertFailedError {\n  oneof oprand_type {\n    OneFieldAssertError one_field_assert_error = 1;\n    TwoFieldAssertError two_field_assert_error = 2;\n  }\n}\n\nmessage ConfigResourceUnavailableError {\n  required FieldValue field_value = 1;\n}\n\nmessage JobSetEmptyError { }\n\nmessage DeviceTagNotFoundError { }\n\nmessage JobNameExistError { }\n\nmessage JobNameEmptyError { }\n\nmessage JobNameNotEqualError { }\n\nmessage NoJobBuildAndInferCtxError { }\n\nmessage JobConfFrozenError { }\n\nmessage JobConfNotSetError { }\n\nmessage JobConfRepeatedSetError { }\n\nmessage JobTypeNotSetError { }\n\nmessage LogicalBlobNameNotExistError { }\n\nmessage LogicalBlobNameExistError { }\n\nmessage LogicalBlobNameInvalidError { }\n\nmessage OpNameExistError { }\n\nmessage OpConfDeviceTagNoSetError { }\n\nmessage PlacementError { }\n\nmessage BlobSplitAxisInferError { }\n\nmessage UnknownJobBuildAndInferError { }\n\nmessage ProtoParseFailedError { }\n\nmessage CheckFailedError { }\n\nmessage TodoError { }\n\nmessage UnimplementedError { }\n\nmessage RuntimeError { }\n\nmessage OutOfMemoryError { }\n\nmessage BoxingNotSupportedError { }\n\nmessage GradientFunctionNotFoundError { }\n\nmessage OpKernelNotFoundError {\n  repeated string op_kernels_not_found_debug_str = 1;\n}\n\nmessage MultipleOpKernelsMatchedError {\n  repeated string matched_op_kernels_debug_str = 1;\n}\n\nmessage MemoryZoneOutOfMemoryError {\n  repeated string machine_id = 1;\n  repeated string mem_zone_id = 2;\n  repeated string device_tag = 3;\n  repeated string required = 4;\n  repeated string available = 5;\n}\n\nmessage LossBlobNotFoundError { }\n\nmessage RwMutexedObjectNotFoundError { }\n\nmessage UnknownError { }\n\nmessage CompileOptionWrongError { }\n\nmessage InputDeviceNotMatchError { \n  repeated string info = 1;\n}\n\nmessage SymbolIdUninitializedError {}\n\nmessage InvalidValueError {}\n\nmessage IndexError {}\nmessage TypeError {}\n\nmessage TimeoutError {}\n\nmessage ValueNotFoundError {}\n\nmessage ErrorProto {\n  optional string msg = 1 [default = \"\"];\n  optional string frame_msg = 2 [default = \"\"];\n  oneof error_type {\n    ConfigAssertFailedError config_assert_failed_error = 12;\n    ConfigResourceUnavailableError config_resource_unavailable_error = 13;\n    ProtoParseFailedError proto_parse_failed_error = 15;\n    CheckFailedError check_failed_error = 16;\n    TodoError todo_error = 17;\n    UnimplementedError unimplemented_error = 18;\n    BoxingNotSupportedError boxing_not_supported_error = 19;\n    GradientFunctionNotFoundError gradient_function_not_found_error = 20;\n    OpKernelNotFoundError op_kernel_not_found_error = 21;\n    MultipleOpKernelsMatchedError multiple_op_kernels_matched_error = 22;\n    MemoryZoneOutOfMemoryError memory_zone_out_of_memory_error = 23;\n    LossBlobNotFoundError loss_blob_not_found_error = 24;\n    JobSetEmptyError job_set_empty_error = 25;\n    DeviceTagNotFoundError device_tag_not_found_error = 26;\n    
InvalidValueError invalid_value_error = 27;\n    IndexError index_error = 28;\n    TypeError type_error = 29;\n    RuntimeError runtime_error = 30;\n    OutOfMemoryError out_of_memory_error = 32;\n    TimeoutError timeout_error = 40;\n    ValueNotFoundError value_not_found_error = 31;\n\n    JobNameExistError job_name_exist_error = 100;\n    JobNameEmptyError job_name_empty_error = 101;\n    JobNameNotEqualError job_name_not_equal_error = 102;\n    NoJobBuildAndInferCtxError no_job_build_and_infer_ctx_error = 200;\n    JobConfFrozenError job_conf_frozen_error = 300;\n    JobConfNotSetError job_conf_not_set_error = 301;\n    JobConfRepeatedSetError job_conf_repeated_set_error = 302;\n    JobTypeNotSetError job_type_not_set_error = 303;\n    LogicalBlobNameNotExistError logical_blob_name_not_exist_error = 400;\n    LogicalBlobNameExistError logical_blob_name_exist_error = 401;\n    LogicalBlobNameInvalidError logical_blob_name_invalid_error = 402;\n    OpNameExistError op_name_exist_error = 450;\n    OpConfDeviceTagNoSetError op_conf_device_tag_no_set_error = 460;\n    PlacementError placement_error = 470;\n    BlobSplitAxisInferError blob_split_axis_infer_error = 480;\n    UnknownJobBuildAndInferError unknown_job_build_and_infer_error = 500;\n    RwMutexedObjectNotFoundError rw_mutexed_object_not_found_error = 600;\n    SymbolIdUninitializedError symbol_id_uninitialized_error = 700;\n    UnknownError unknown_error = 900;\n    CompileOptionWrongError compile_option_wrong_error = 950;\n    InputDeviceNotMatchError input_device_not_match_error = 1000;\n  }\n}\n"
  },
  {
    "path": "oneflow/core/common/error_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <sstream>\n#include \"oneflow/core/common/error_util.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/graph_scope_vars.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::string StripSpace(std::string str) {\n  if (str.size() == 0) { return \"\"; }\n  size_t pos = str.find_first_not_of(\" \");\n  if (pos != std::string::npos) { str.erase(0, pos); }\n  pos = str.find_last_not_of(\" \");\n  if (pos != std::string::npos) { str.erase(pos + 1); }\n  return str;\n}\n\nbool IsLetterNumberOrUnderline(char c) {\n  return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c == '_');\n}\n\nMaybe<std::string> ShortenMsg(std::string str) {\n  // 150 characters is the threshold\n  const int num_character_threshold = 150;\n  const int num_displayed_character = 50;\n  if (str.size() == 0) { return str; }\n  // strip space when JUST(  xx  );\n  str = StripSpace(str);\n  if (str.size() < num_character_threshold) { return str; }\n\n  // left part whose number of characters is just over 50\n  int left_index = num_displayed_character;\n  bool pre_condition = IsLetterNumberOrUnderline(str.at(left_index));\n  for (; left_index < str.size(); left_index++) {\n    bool cur_condition = IsLetterNumberOrUnderline(str.at(left_index));\n    if ((pre_condition && !cur_condition) || (!pre_condition && cur_condition)) { break; }\n  }\n\n  // right part whose number of characters is just over 50\n  int right_index = str.size() - num_displayed_character;\n  pre_condition = IsLetterNumberOrUnderline(str.at(right_index));\n  for (; right_index >= 0; right_index--) {\n    bool cur_condition = IsLetterNumberOrUnderline(str.at(right_index));\n    if ((pre_condition && !cur_condition) || (!pre_condition && cur_condition)) {\n      right_index++;\n      break;\n    }\n  }\n  // a long word of more than 150\n  if (right_index - left_index < 50) { return str; }\n  std::stringstream ss;\n  CHECK_OR_RETURN(left_index >= 0);\n  CHECK_OR_RETURN(left_index < str.size());\n  ss << str.substr(0, left_index);\n  ss << \" ... 
\";\n  CHECK_OR_RETURN(right_index >= 0);\n  CHECK_OR_RETURN(right_index < str.size());\n  ss << str.substr(right_index);\n  return ss.str();\n}\n\n// file info in stack frame\nstd::string FormatFileOfStackFrame(const std::string& file) {\n  std::stringstream ss;\n  ss << \"\\n  File \\\"\" << file << \"\\\", \";\n  return ss.str();\n}\n\n// line info in stack frame\nstd::string FormatLineOfStackFrame(const int64_t& line) {\n  std::stringstream ss;\n  if (line >= 0) {\n    ss << \"line \" << line << \",\";\n  } else {\n    ss << \"line <unknown>,\";\n  }\n  return ss.str();\n}\n\n// function info in stack frame\nstd::string FormatFunctionOfStackFrame(const std::string& function) {\n  std::stringstream ss;\n  ss << \" in \" << function;\n  return ss.str();\n}\n\n// msg in stack frame\nMaybe<std::string> FormatMsgOfStackFrame(std::string error_msg, bool is_last_stack_frame) {\n  const bool debug_mode = GetGraphDebugMode();\n  // only shorten the message if it is not the last stack frame AND not in debug mode\n  if (!is_last_stack_frame && !debug_mode) { error_msg = *JUST(ShortenMsg(error_msg)); }\n  // error_msg of last stack frame come from \"<<\"\n  if (is_last_stack_frame) { error_msg = StripSpace(error_msg); }\n  std::stringstream ss;\n  if (!error_msg.empty()) { ss << \"\\n    \" << error_msg; }\n  return ss.str();\n}\n\n// the msg in error type instance.\nMaybe<std::string> FormatMsgOfErrorType(const std::shared_ptr<StackedError>& error) {\n  const auto& error_proto = error->error_proto();\n  CHECK_NE_OR_RETURN(error_proto->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET)\n      << Error::RuntimeError() << \"Parse error failed, unknown error type\";\n  std::stringstream ss;\n  const google::protobuf::Descriptor* error_des = error_proto->GetDescriptor();\n  const google::protobuf::OneofDescriptor* oneof_field_des =\n      error_des->FindOneofByName(\"error_type\");\n  const google::protobuf::Reflection* error_ref = error_proto->GetReflection();\n  const google::protobuf::FieldDescriptor* field_des =\n      error_ref->GetOneofFieldDescriptor(*error_proto, oneof_field_des);\n  CHECK_OR_RETURN(field_des != nullptr);\n  ss << \"Error Type: \" << field_des->full_name();\n  return ss.str();\n}\n\n}  // namespace\n\nMaybe<std::string> FormatErrorStr(const std::shared_ptr<StackedError>& error) {\n  std::stringstream ss;\n  ss << error->error_proto()->msg();\n  ss << error->error_proto()->frame_msg();\n  // Get msg from stack frame of error proto\n  for (auto iter = error->stack_frame().rbegin(); iter < error->stack_frame().rend(); iter++) {\n    auto stack_frame = *iter;\n    ss << FormatFileOfStackFrame(stack_frame->file()) << FormatLineOfStackFrame(stack_frame->line())\n       << FormatFunctionOfStackFrame(stack_frame->function())\n       << *JUST(FormatMsgOfStackFrame(stack_frame->code_text(),\n                                      iter == error->stack_frame().rend() - 1));\n  }\n  // Get msg from error type of error proto\n  std::string msg_of_error_type = *JUST(FormatMsgOfErrorType(error));\n  if (msg_of_error_type.size() != 0) { ss << \"\\n\" << msg_of_error_type; }\n  return ss.str();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/error_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_ERROR_UTIL_H\n#define ONEFLOW_CORE_COMMON_ERROR_UTIL_H\n\n#include <string>\n#include \"oneflow/core/common/error.pb.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nMaybe<std::string> FormatErrorStr(const std::shared_ptr<StackedError>& error);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_ERROR_UTIL_H\n"
  },
  {
    "path": "oneflow/core/common/exception.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_EXCEPTION_H_\n#define ONEFLOW_CORE_COMMON_EXCEPTION_H_\n\n#include <exception>\n#include <string>\n\nnamespace oneflow {\n\nclass Exception : public std::exception {\n public:\n  explicit Exception(const std::string& what) : what_(what) {}\n  virtual ~Exception() = default;\n\n  const char* what() const noexcept override { return what_.c_str(); }\n\n private:\n  std::string what_;\n};\n\nclass RuntimeException : public Exception {\n public:\n  using Exception::Exception;\n};\n\nclass TypeException : public Exception {\n public:\n  using Exception::Exception;\n};\n\nclass IndexException : public Exception {\n public:\n  using Exception::Exception;\n};\n\nclass NotImplementedException : public Exception {\n public:\n  using Exception::Exception;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_EXCEPTION_H_\n"
  },
  {
    "path": "oneflow/core/common/flat_shape.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/flat_shape.h\"\n#include \"oneflow/core/common/shape.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<FlatShape> FlatShape::New(const Shape& shape) {\n  const auto& flat_shape = std::make_shared<FlatShape>();\n  JUST(flat_shape->Init(shape));\n  return flat_shape;\n}\n\nMaybe<void> FlatShape::Init(const Shape& shape) {\n  CHECK_LE_OR_RETURN(shape.NumAxes(), SHAPE_MAX_AXIS_SIZE);\n  this->clear_dim();\n  for (int i = 0; i < shape.NumAxes(); ++i) { *this->mutable_dim()->Add() = shape.At(i); }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FlatShape::Check(const Shape& shape) const {\n  CHECK_EQ_OR_RETURN(this->dim_size(), shape.NumAxes())\n      << Error::RuntimeError()\n      << \"Expected same shape on each rank, but found at least two shapes, \"\n      << JUST(ToShape())->ToString() << \" and \" << shape.ToString() << \"!\";\n  for (int i = 0; i < this->dim_size(); ++i) { CHECK_EQ_OR_RETURN(this->dim(i), shape.At(i)); }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FlatShape::Check(const FlatShape& flat_shape) const {\n  CHECK_EQ_OR_RETURN(this->dim_size(), flat_shape.NumAxes())\n      << Error::RuntimeError()\n      << \"Expected input of each rank must have the same size, but got at least two size, \"\n      << JUST(ToShape())->ToString() << \" and \" << JUST(flat_shape.ToShape())->ToString();\n  for (int i = 0; i < this->dim_size(); ++i) {\n    CHECK_EQ_OR_RETURN(this->dim(i), flat_shape.At(i))\n        << Error::RuntimeError()\n        << \"Expected input of each rank must have the same size, but got at least two size, \"\n        << JUST(ToShape())->ToString() << \" and \" << JUST(flat_shape.ToShape())->ToString();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<Shape> FlatShape::ToShape() const {\n  const auto& shape = std::make_shared<Shape>();\n  JUST(ToShape(shape.get()));\n  return shape;\n}\n\nMaybe<void> FlatShape::ToShape(Shape* shape) const {\n  DimVector dim_vec;\n  for (int i = 0; i < this->dim_size(); ++i) { dim_vec.emplace_back(this->dim(i)); }\n  *shape = Shape(dim_vec);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/flat_shape.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_\n#define ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_\n\n#include <memory>\n#include \"oneflow/core/intrusive/flat_msg.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/shape_vec.h\"\n\nnamespace oneflow {\n\nclass Shape;\n\n// clang-format off\n\nFLAT_MSG_BEGIN(FlatShape);\n public:\n  // Methods\n  static Maybe<FlatShape> New(const Shape& shape);\n  Maybe<void> Init(const Shape& shape);\n  Maybe<void> Check(const Shape& shape) const;\n  Maybe<void> Check(const FlatShape& flat_shape) const;\n  Maybe<Shape> ToShape() const;\n  Maybe<void> ToShape(Shape* shape) const;\n  int64_t At(int i) const { return dim(i); }\n  int64_t NumAxes() const { return dim_size(); }\n\n  // Fields\n  FLAT_MSG_DEFINE_REPEATED(int64_t, dim, SHAPE_MAX_AXIS_SIZE);\nFLAT_MSG_END(FlatShape);\n\n// clang-format on\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_\n"
  },
  {
    "path": "oneflow/core/common/foreign_lock_helper.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/foreign_lock_helper.h\"\n#include \"oneflow/core/common/singleton.h\"\n\nnamespace oneflow {\nclass NoForeignLockHelper final : public ForeignLockHelper {\n  Maybe<void> WithScopedRelease(const std::function<Maybe<void>()>& Callback) const override {\n    return Callback();\n  }\n\n  Maybe<void> WithScopedAcquire(const std::function<Maybe<void>()>& Callback) const override {\n    return Callback();\n  }\n};\n\nstatic int __register_no_foreign_lock_helper __attribute__((unused)) = []() {\n  Singleton<ForeignLockHelper>::SetAllocated(new NoForeignLockHelper());\n  return 0;\n}();\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/foreign_lock_helper.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_FOREIGN_LOCK_HELPER_H\n#define ONEFLOW_CORE_COMMON_FOREIGN_LOCK_HELPER_H\n#include <functional>\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\nclass ForeignLockHelper {\n public:\n  virtual ~ForeignLockHelper() = default;\n  virtual Maybe<void> WithScopedRelease(const std::function<Maybe<void>()>&) const = 0;\n  virtual Maybe<void> WithScopedAcquire(const std::function<Maybe<void>()>&) const = 0;\n};\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_FOREIGN_LOCK_HELPER_H\n"
  },
  {
    "path": "oneflow/core/common/function_traits.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_FUNCTION_TRAITS_H_\n#define ONEFLOW_CORE_COMMON_FUNCTION_TRAITS_H_\n\n#include <tuple>\n\nnamespace oneflow {\n\ntemplate<typename... Args>\nusing void_t = void;\n\ntemplate<typename T, typename = void>\nstruct function_traits;\n\ntemplate<typename Ret, typename... Args>\nstruct function_traits<Ret(Args...)> {\n  using func_type = Ret(Args...);\n  using return_type = Ret;\n  using args_type = std::tuple<Args...>;\n  template<size_t i>\n  using arg_type = typename std::tuple_element<i, args_type>::type;\n\n  static constexpr size_t nargs = sizeof...(Args);\n};\n\ntemplate<typename Ret, typename... Args>\nstruct function_traits<Ret (*)(Args...)> {\n  using func_type = Ret(Args...);\n  using return_type = Ret;\n  using args_type = std::tuple<Args...>;\n  template<size_t i>\n  using arg_type = typename std::tuple_element<i, args_type>::type;\n\n  static constexpr size_t nargs = sizeof...(Args);\n};\n\ntemplate<typename Ret, typename C, typename... Args>\nstruct function_traits<Ret (C::*)(Args...)> {\n  using func_type = Ret(Args...);\n  using return_type = Ret;\n  using args_type = std::tuple<Args...>;\n  template<size_t i>\n  using arg_type = typename std::tuple_element<i, args_type>::type;\n\n  static constexpr size_t nargs = sizeof...(Args);\n};\n\ntemplate<typename Ret, typename C, typename... Args>\nstruct function_traits<Ret (C::*)(Args...) const> {\n  using func_type = Ret(Args...);\n  using return_type = Ret;\n  using args_type = std::tuple<Args...>;\n  template<size_t i>\n  using arg_type = typename std::tuple_element<i, args_type>::type;\n\n  static constexpr size_t nargs = sizeof...(Args);\n};\n\ntemplate<typename F>\nstruct function_traits<F, void_t<decltype(&F::operator())>>\n    : public function_traits<decltype(&F::operator())> {};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_FUNCTION_TRAITS_H_\n"
  },
  {
    "path": "oneflow/core/common/hash.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_HASH_H_\n#define ONEFLOW_CORE_COMMON_HASH_H_\n#include <functional>\n#include <complex>\n\nnamespace oneflow {\n\ninline size_t HashCombine(size_t lhs, size_t rhs) {\n  return lhs ^ (rhs + 0x9e3779b9 + (lhs << 6U) + (lhs >> 2U));\n}\n\ninline void HashCombine(size_t* seed, size_t hash) { *seed = HashCombine(*seed, hash); }\n\ntemplate<typename... T>\ninline void AddHash(size_t* seed, const T&... v) {\n  (HashCombine(seed, std::hash<T>()(v)), ...);\n}\n\ntemplate<typename T, typename... Ts>\ninline size_t Hash(const T& v1, const Ts&... vn) {\n  size_t seed = std::hash<T>()(v1);\n\n  AddHash<Ts...>(&seed, vn...);\n\n  return seed;\n}\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<typename T0, typename T1>\nstruct hash<std::pair<T0, T1>> {\n  std::size_t operator()(const std::pair<T0, T1>& p) const {\n    return oneflow::Hash<T0, T1>(p.first, p.second);\n  }\n};\n\ntemplate<typename T>\nstruct hash<std::vector<T>> {\n  std::size_t operator()(const std::vector<T>& vec) const {\n    std::size_t hash_value = vec.size();\n    for (const auto& elem : vec) { oneflow::AddHash<T>(&hash_value, elem); }\n    return hash_value;\n  }\n};\n\ntemplate<typename T>\nstruct hash<std::complex<T>> {\n  size_t operator()(const std::complex<T>& c) const { return oneflow::Hash(c.real(), c.imag()); }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_COMMON_HASH_H_\n"
  },
  {
    "path": "oneflow/core/common/hash_container.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_HASH_CONTAINER_\n#define ONEFLOW_CORE_COMMON_HASH_CONTAINER_\n\n#include <unordered_set>\n#include <unordered_map>\n\nnamespace oneflow {\n\ntemplate<typename Key, typename T, typename Hash = std::hash<Key>>\nusing HashMap = std::unordered_map<Key, T, Hash>;\n\ntemplate<typename Key, typename Hash = std::hash<Key>>\nusing HashSet = std::unordered_set<Key, Hash>;\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_HASH_CONTAINER_\n"
  },
  {
    "path": "oneflow/core/common/hash_eq_trait_ptr.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_HASH_EQ_TRAIT_PTR_H_\n#define ONEFLOW_CORE_COMMON_HASH_EQ_TRAIT_PTR_H_\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass HashEqTraitPtr final {\n public:\n  HashEqTraitPtr(const HashEqTraitPtr<T>&) = default;\n  HashEqTraitPtr(T* ptr, size_t hash_value) : ptr_(ptr), hash_value_(hash_value) {}\n  ~HashEqTraitPtr() = default;\n\n  T* ptr() const { return ptr_; }\n  size_t hash_value() const { return hash_value_; }\n\n  bool operator==(const HashEqTraitPtr<T>& rhs) const { return *ptr_ == *rhs.ptr_; }\n\n private:\n  T* ptr_;\n  size_t hash_value_;\n};\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<typename T>\nstruct hash<oneflow::HashEqTraitPtr<T>> final {\n  size_t operator()(const oneflow::HashEqTraitPtr<T>& ptr) const { return ptr.hash_value(); }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_COMMON_HASH_EQ_TRAIT_PTR_H_\n"
  },
  {
    "path": "oneflow/core/common/high_order_bool.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_HIGH_ORDER_BOOL_H_\n#define ONEFLOW_CORE_COMMON_HIGH_ORDER_BOOL_H_\n\n#include <string>\n#include <memory>\n#include <sstream>\n#include <functional>\n#include <utility>\n\n#include \"oneflow/core/common/function_traits.h\"\n#include \"oneflow/core/common/type_traits.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace hob {\n\ntemplate<typename Context, typename ValueT>\nstruct BaseExpr {\n#pragma GCC diagnostic push\n#pragma GCC diagnostic ignored \"-Wnon-virtual-dtor\"\n  // NOTE: Performance will be degraded if the destructor is virtual.\n  //       So please do NOT implement custom destructor in any child classes of BaseExpr,\n  //       and every fields of child classes should be of POD type.\n  ~BaseExpr() = default;\n#pragma GCC diagnostic pop\n  ALWAYS_INLINE virtual scalar_or_const_ref_t<ValueT> get(const Context&) const = 0;\n  virtual std::string DebugStr(const Context&, bool display_result = true) const = 0;  // NOLINT\n  operator bool() = delete;\n};\n\ntemplate<typename Context, typename ValueT, typename E>\nstruct Expr : public BaseExpr<Context, ValueT> {\n#pragma GCC diagnostic push\n#pragma GCC diagnostic ignored \"-Wnon-virtual-dtor\"\n  ~Expr() = default;\n#pragma GCC diagnostic pop\n};\n\ntemplate<typename Context, typename ValueT>\nstruct Literal final : public Expr<Context, ValueT, Literal<Context, ValueT>> {\n  Literal(const ValueT& val) : Literal(ToString(val), val) {}  // NOLINT\n  Literal(const std::string& debug_str, const ValueT& val) : val_(val), debug_str_(debug_str) {}\n  ALWAYS_INLINE scalar_or_const_ref_t<ValueT> get(const Context&) const override { return val_; }\n  std::string DebugStr(const Context&, bool display_result) const override { return debug_str_; }\n\n private:\n  ValueT val_;\n  std::string debug_str_;\n};\n\ntemplate<typename Context>\nusing LiteralBool = Literal<Context, bool>;\n\ntemplate<typename Fn,\n         typename Context =\n             std::decay_t<typename oneflow::function_traits<Fn>::template arg_type<0>>,\n         typename ValueT = std::decay_t<typename oneflow::function_traits<Fn>::return_type>>\nstruct Custom final : public Expr<Context, ValueT, Custom<Fn>> {\n  explicit Custom(Fn fn) : Custom(\"\", fn) {}\n  Custom(std::string debug_str, Fn fn) : fn_(std::move(fn)), debug_str_(std::move(debug_str)) {}\n  ALWAYS_INLINE scalar_or_const_ref_t<ValueT> get(const Context& context) const override {\n    return fn_(context);\n  }\n  std::string DebugStr(const Context&, bool display_result) const override { return debug_str_; }\n\n private:\n  Fn fn_;\n  std::string debug_str_;\n};\n\ntemplate<typename Fn>\nALWAYS_INLINE inline Custom<Fn> make_custom(Fn fn) {\n  return Custom<Fn>(std::forward<Fn>(fn));\n}\n\ntemplate<typename Fn>\nALWAYS_INLINE inline Custom<Fn> make_custom(const std::string& debug_str, Fn fn) {\n  return Custom<Fn>(debug_str, 
std::forward<Fn>(fn));\n}\n\ntemplate<typename Context, typename E>\nusing BoolExpr = Expr<Context, bool, E>;\n\ntemplate<typename Context, typename E>\nstruct NotBoolFunctor final : public BoolExpr<Context, NotBoolFunctor<Context, E>> {\n  explicit NotBoolFunctor(const E& expr) : expr_(expr) {}\n\n  ALWAYS_INLINE bool get(const Context& context) const override { return !expr_.get(context); }\n\n  std::string DebugStr(const Context& ctx, bool display_result) const override {\n    std::ostringstream string_stream;\n    string_stream << \"(\"\n                  << \"not \" << expr_.DebugStr(ctx, display_result) << \")\";\n    return string_stream.str();\n  }\n\n private:\n  const E expr_;\n};\n\ntemplate<typename Context, typename E>\nNotBoolFunctor<Context, E> operator!(BoolExpr<Context, E> const& lhs) {\n  return NotBoolFunctor<Context, E>(*static_cast<const E*>(&lhs));\n}\n\n#define DEFINE_BINARY_FUNCTOR(name, op)                                                           \\\n  template<typename Context, typename E1, typename E2>                                            \\\n  struct name##BoolFunctor final : public BoolExpr<Context, name##BoolFunctor<Context, E1, E2>> { \\\n    name##BoolFunctor(const E1& lhs, const E2& rhs) : lhs_(lhs), rhs_(rhs) {}                     \\\n                                                                                                  \\\n    ALWAYS_INLINE bool get(const Context& context) const override;                                \\\n                                                                                                  \\\n    std::string DebugStr(const Context& ctx, bool display_result) const override;                 \\\n                                                                                                  \\\n   private:                                                                                       \\\n    const E1 lhs_;                                                                                \\\n    const E2 rhs_;                                                                                \\\n  };                                                                                              \\\n                                                                                                  \\\n  template<typename Context, typename ValueT, typename E1, typename E2>                           \\\n  name##BoolFunctor<Context, E1, E2> operator op(Expr<Context, ValueT, E1> const& lhs,            \\\n                                                 Expr<Context, ValueT, E2> const& rhs) {          \\\n    return name##BoolFunctor<Context, E1, E2>(*static_cast<const E1*>(&lhs),                      \\\n                                              *static_cast<const E2*>(&rhs));                     \\\n  }                                                                                               \\\n                                                                                                  \\\n  template<typename Context, typename ValueT, typename E1>                                        \\\n  name##BoolFunctor<Context, E1, Literal<Context, ValueT>> operator op(                           \\\n      Expr<Context, ValueT, E1> const& lhs, ValueT const& rhs) {                                  \\\n    return name##BoolFunctor<Context, E1, Literal<Context, ValueT>>(                              \\\n        *static_cast<const E1*>(&lhs), Literal<Context, ValueT>(rhs));                            \\\n  
}\n\nDEFINE_BINARY_FUNCTOR(Equal, ==)\nDEFINE_BINARY_FUNCTOR(And, &&)\nDEFINE_BINARY_FUNCTOR(Or, ||)\nDEFINE_BINARY_FUNCTOR(Greater, >)\nDEFINE_BINARY_FUNCTOR(Less, <)\nDEFINE_BINARY_FUNCTOR(EqualOrGreater, >=)\nDEFINE_BINARY_FUNCTOR(EqualOrLess, <=)\n\n#undef DEFINE_BINARY_FUNCTOR\n\n#define DEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(name, op)                                  \\\n  template<typename Context, typename E1, typename E2>                                      \\\n  ALWAYS_INLINE inline bool name##BoolFunctor<Context, E1, E2>::get(const Context& context) \\\n      const {                                                                               \\\n    return lhs_.get(context) op rhs_.get(context);                                          \\\n  }                                                                                         \\\n  template<typename Context, typename E1, typename E2>                                      \\\n  std::string name##BoolFunctor<Context, E1, E2>::DebugStr(const Context& ctx,              \\\n                                                           bool display_result) const {     \\\n    std::string l_str = lhs_.DebugStr(ctx, display_result);                                 \\\n    std::string r_str = rhs_.DebugStr(ctx, display_result);                                 \\\n    std::ostringstream string_stream;                                                       \\\n    string_stream << \"(\" << l_str << \" \" << OF_PP_STRINGIZE(op) << \" \" << r_str << \")\";     \\\n    return string_stream.str();                                                             \\\n  }\n\nDEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(Equal, ==)\nDEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(Greater, >)\nDEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(Less, <)\nDEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(EqualOrGreater, >=)\nDEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS(EqualOrLess, <=)\n\n#undef DEFINE_NON_SHORT_CIRCUIT_FUNCTOR_METHODS\n\ntemplate<typename Context, typename E1, typename E2>\nALWAYS_INLINE inline bool AndBoolFunctor<Context, E1, E2>::get(const Context& context) const {\n  bool lhs_result = lhs_.get(context);\n  if (!lhs_result) { return false; }\n  return rhs_.get(context);\n}\n\ntemplate<typename Context, typename E1, typename E2>\nstd::string AndBoolFunctor<Context, E1, E2>::DebugStr(const Context& ctx,\n                                                      bool display_result) const {\n  std::string l_str = lhs_.DebugStr(ctx, display_result);\n  display_result = display_result && lhs_.get(ctx);\n  std::string r_str = rhs_.DebugStr(ctx, display_result);\n  std::ostringstream string_stream;\n  string_stream << \"(\" << l_str << \" and \" << r_str << \")\";\n  return string_stream.str();\n}\n\ntemplate<typename Context, typename E1, typename E2>\nALWAYS_INLINE inline bool OrBoolFunctor<Context, E1, E2>::get(const Context& context) const {\n  bool lhs_result = lhs_.get(context);\n  if (lhs_result) { return true; }\n  return rhs_.get(context);\n}\n\ntemplate<typename Context, typename E1, typename E2>\nstd::string OrBoolFunctor<Context, E1, E2>::DebugStr(const Context& ctx,\n                                                     bool display_result) const {\n  std::string l_str = lhs_.DebugStr(ctx, display_result);\n  display_result = display_result && (!lhs_.get(ctx));\n  std::string r_str = rhs_.DebugStr(ctx, display_result);\n  std::ostringstream string_stream;\n  string_stream << \"(\" << l_str << \" or \" << r_str << \")\";\n  return 
string_stream.str();\n}\n\ntemplate<typename Context, typename E1>\nEqualBoolFunctor<Context, E1, Literal<Context, std::string>> operator==(\n    Expr<Context, std::string, E1> const& lhs, const char* rhs) {\n  return EqualBoolFunctor<Context, E1, Literal<Context, std::string>>(\n      *static_cast<const E1*>(&lhs), Literal<Context, std::string>(rhs));\n}\n\n}  // namespace hob\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_HIGH_ORDER_BOOL_H_\n"
  },
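  {
    "path": "oneflow/core/common/high_order_bool_usage_example.cpp",
    "content": "// Hypothetical usage sketch for the hob expression templates in high_order_bool.h.\n// This file is illustrative only (it is not part of the OneFlow source tree); DeviceCtx\n// and the predicate below are made-up names, and the sketch assumes\n// oneflow::function_traits can deduce Context/ValueT from the lambdas' signatures.\n#include <iostream>\n#include <string>\n\n#include \"oneflow/core/common/high_order_bool.h\"\n\nnamespace {\n\nstruct DeviceCtx {\n  std::string device_tag;\n  int device_count;\n};\n\n}  // namespace\n\nint main() {\n  using namespace oneflow;\n  // Leaf expressions: fetch values out of the context, each with a debug label.\n  auto device_tag = hob::make_custom(\n      \"device_tag\", [](const DeviceCtx& ctx) -> const std::string& { return ctx.device_tag; });\n  auto device_count = hob::make_custom(\n      \"device_count\", [](const DeviceCtx& ctx) -> int { return ctx.device_count; });\n  // Composite predicate; nothing is evaluated until get() is called with a context.\n  auto pred = (device_tag == \"cuda\") && (device_count > 0);\n  DeviceCtx ctx{\"cuda\", 2};\n  std::cout << pred.DebugStr(ctx, true) << \" => \" << pred.get(ctx) << std::endl;  // ... => 1\n  return 0;\n}\n"
  },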
  {
    "path": "oneflow/core/common/just.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_COMMON_JUST_H_\n#define ONEFLOW_CORE_COMMON_JUST_H_\n\n#include <sstream>\n#include <type_traits>\n#include \"oneflow/core/common/error.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename Enabled = void>\nclass Maybe;\n\ntemplate<typename T>\nclass Optional;\n\nMaybe<std::string> FormatErrorStr(const std::shared_ptr<StackedError>&);\nnamespace {\nstd::string GetFormatedSerializedError(const std::shared_ptr<StackedError>&);\n}\n\nnamespace private_details {\n\ninline std::shared_ptr<StackedError>&& JustErrorAddStackFrame(\n    std::shared_ptr<StackedError>&& err, Symbol<ErrorStackFrame> error_stack_frame) {\n  err->add_stack_frame(error_stack_frame);\n  return std::move(err);\n}\n\ntemplate<typename T>\nError&& AddFrameMessage(Error&& error, const T& x) {\n  std::ostringstream ss;\n  ss << x;\n  error->set_frame_msg(error->frame_msg() + ss.str());\n  return std::move(error);\n}\n\ntemplate<>\ninline Error&& AddFrameMessage(Error&& error, const std::stringstream& x) {\n  AddFrameMessage(std::move(error), x.str());\n  return std::move(error);\n}\n\ntemplate<>\ninline Error&& AddFrameMessage(Error&& error, const std::ostream& x) {\n  AddFrameMessage(std::move(error), x.rdbuf());\n  return std::move(error);\n}\n\ntemplate<typename... T>\nError&& JustErrorAddFrameMessage(Error&& err, T&&... msg) {\n  (AddFrameMessage(std::move(err), std::forward<T>(msg)), ...);\n  return std::move(err);\n}\n\ntemplate<typename T>\nbool JustIsOk(const Maybe<T>& val) {\n  return val.IsOk();\n}\n\ntemplate<typename T>\nbool JustIsOk(const Optional<T>& val) {\n  return val.has_value();\n}\n\ntemplate<typename T>\nstd::shared_ptr<StackedError> JustGetError(const Maybe<T>& val) {\n  return val.stacked_error();\n}\n\ntemplate<typename T>\nstd::shared_ptr<StackedError> JustGetError(const Optional<T>&) {\n  return Error::ValueNotFoundError().stacked_error();\n}\n\ntemplate<typename T>\ntypename std::remove_const<typename std::remove_reference<T>::type>::type&& RemoveRValConst(\n    T&& v) noexcept {\n  static_assert(std::is_rvalue_reference<T&&>::value, \"rvalue is expected here\");\n  return const_cast<typename std::remove_const<typename std::remove_reference<T>::type>::type&&>(v);\n}\n\n}  // namespace private_details\n}  // namespace oneflow\n\n#define __JustStackCheckWrapper__(...) __VA_ARGS__\n#define TRY(...) __JustStackCheckWrapper__(__VA_ARGS__)\n\n#if defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__)\n\n#define JUST(...)                                                                        
\\\n  ::oneflow::private_details::RemoveRValConst(({                                         \\\n    auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__);               \\\n    if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) {                  \\\n      return ::oneflow::private_details::JustErrorAddStackFrame(                         \\\n          ::oneflow::private_details::JustGetError(_just_value_to_check_),               \\\n          [](const char* function) {                                                     \\\n            thread_local static auto frame = ::oneflow::SymbolOf(                        \\\n                ::oneflow::ErrorStackFrame(__FILE__, __LINE__, function, #__VA_ARGS__)); \\\n            return frame;                                                                \\\n          }(__FUNCTION__));                                                              \\\n    }                                                                                    \\\n    std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_);                \\\n  })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()\n\n#define CHECK_JUST(...)                                                                            \\\n  ([&](const char* _just_closure_func_name_) {                                                     \\\n    auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__);                         \\\n    if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) {                            \\\n      thread_local static auto frame = ::oneflow::SymbolOf(                                        \\\n          ::oneflow::ErrorStackFrame(__FILE__, __LINE__, _just_closure_func_name_, #__VA_ARGS__)); \\\n      THROW(RuntimeError) << ::oneflow::GetErrorString(                                            \\\n          ::oneflow::private_details::JustErrorAddStackFrame(                                      \\\n              ::oneflow::private_details::JustGetError(_just_value_to_check_), frame));            \\\n    }                                                                                              \\\n    return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_);                   \\\n  })(__FUNCTION__)                                                                                 \\\n      .Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()\n\n#define JUST_MSG(value, ...)                                                                  
\\\n  ::oneflow::private_details::RemoveRValConst(({                                              \\\n    auto&& _just_value_to_check_ = (value);                                                   \\\n    if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) {                       \\\n      return ::oneflow::private_details::JustErrorAddFrameMessage(                            \\\n          ::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_))   \\\n              .AddStackFrame([](const char* function) {                                       \\\n                thread_local static auto frame = ::oneflow::SymbolOf(                         \\\n                    ::oneflow::ErrorStackFrame(__FILE__, __LINE__, function, #value));        \\\n                return frame;                                                                 \\\n              }(__FUNCTION__)),                                                               \\\n          \"\\nError message from \" __FILE__, \":\", __LINE__, \"\\n\\t\", #value, \": \", __VA_ARGS__, \\\n          \"\\n\");                                                                              \\\n    }                                                                                         \\\n    std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_);                     \\\n  })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()\n\n#define CHECK_JUST_MSG(value, ...)                                                                \\\n  ([&](const char* _just_closure_func_name_) {                                                    \\\n    auto&& _just_value_to_check_ = (value);                                                       \\\n    if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) {                           \\\n      thread_local static auto frame = ::oneflow::SymbolOf(                                       \\\n          ::oneflow::ErrorStackFrame(__FILE__, __LINE__, _just_closure_func_name_, #value));      \\\n      THROW(RuntimeError) << ::oneflow::GetErrorString(                                           \\\n          ::oneflow::private_details::JustErrorAddFrameMessage(                                   \\\n              ::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_))   \\\n                  .AddStackFrame(frame),                                                          \\\n              \"\\nError message from \" __FILE__, \":\", __LINE__, \"\\n\\t\", #value, \": \", __VA_ARGS__, \\\n              \"\\n\")                                                                               \\\n              .stacked_error());                                                                  \\\n    }                                                                                             \\\n    return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_);                  \\\n  })(__FUNCTION__)                                                                                \\\n      .Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()\n\n#define JUST_OPT(...)                                                      
\\\n  ::oneflow::private_details::RemoveRValConst(({                           \\\n    auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \\\n    if (!_just_value_to_check_.has_value()) { return NullOpt; }            \\\n    std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_);  \\\n  })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()\n\n#else\n#error statement expressions are not supported, please implement a try-catch version of JUST\n#endif\n\n#endif  // ONEFLOW_CORE_COMMON_JUST_H_\n"
  },
  {
    "path": "oneflow/core/common/layout_standardize.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_LAYOUT_STANDARDIZE_H_\n#define ONEFLOW_CORE_COMMON_LAYOUT_STANDARDIZE_H_\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass LayoutStandardize final {\n public:\n  void __Init__(const T& val) { new (&data_[0]) T(val); }\n  void __Delete__() { Mutable()->~T(); }\n\n  const T& Get() const { return *reinterpret_cast<const T*>(&data_[0]); }\n  T* Mutable() { return reinterpret_cast<T*>(&data_[0]); }\n\n private:\n  union {\n    char data_[sizeof(T)];\n    int64_t align_;\n  };\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_LAYOUT_STANDARDIZE_H_\n"
  },
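  {
    "path": "oneflow/core/common/layout_standardize_usage_example.cpp",
    "content": "// Hypothetical usage sketch for LayoutStandardize (illustrative only, not part of\n// the OneFlow source tree). The wrapper keeps a T in raw storage, so its lifetime\n// is managed explicitly: __Init__ placement-news the value and __Delete__ destroys\n// it. Note the union with int64_t only guarantees 8-byte alignment for T.\n#include <cassert>\n#include <string>\n\n#include \"oneflow/core/common/layout_standardize.h\"\n\nint main() {\n  oneflow::LayoutStandardize<std::string> standardized;\n  standardized.__Init__(\"hello\");  // placement-new construction\n  assert(standardized.Get() == \"hello\");\n  standardized.Mutable()->append(\" world\");\n  assert(standardized.Get() == \"hello world\");\n  standardized.__Delete__();  // manual destruction; must be paired with __Init__\n  return 0;\n}\n"
  },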
  {
    "path": "oneflow/core/common/math_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <utility>\n#include \"glog/logging.h\"\n#include \"oneflow/core/common/math_util.h\"\n\nnamespace oneflow {\n\nint64_t Gcd(int64_t m, int64_t n) {\n  if (m < n) { std::swap(m, n); }\n  if (n == 0) { return m; }\n  CHECK_GT(m, 0);\n  CHECK_GT(n, 0);\n  return Gcd(n, m % n);\n}\n\nint64_t Lcm(int64_t m, int64_t n) { return m * n / Gcd(m, n); }\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/math_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_MATH_UTIL_H_\n#define ONEFLOW_CORE_COMMON_MATH_UTIL_H_\n#include <stdint.h>\n#include \"data_type.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\n/*\n * math constants\n */\ntemplate<typename T>\nconstexpr T pi = static_cast<T>(3.141592653589793238462643383279502);\n\nint64_t Gcd(int64_t m, int64_t n);\n\nint64_t Lcm(int64_t m, int64_t n);\n\ntemplate<typename T>\nOF_DEVICE_FUNC T DeviceMin(T a, T b) {\n#if defined(__CUDA_ARCH__)\n  return a < b ? a : b;\n#else\n  return std::min(a, b);\n#endif\n}\n\ntemplate<typename T>\nOF_DEVICE_FUNC T DeviceMax(T a, T b) {\n#if defined(__CUDA_ARCH__)\n  return a > b ? a : b;\n#else\n  return std::max(a, b);\n#endif\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_MATH_UTIL_H_\n"
  },
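  {
    "path": "oneflow/core/common/math_util_usage_example.cpp",
    "content": "// Hypothetical usage sketch for the helpers above (illustrative only, not part of\n// the OneFlow source tree); it links against math_util.cpp for Gcd/Lcm.\n#include <iostream>\n\n#include \"oneflow/core/common/math_util.h\"\n\nint main() {\n  std::cout << \"Gcd(24, 36) = \" << oneflow::Gcd(24, 36) << std::endl;  // 12\n  std::cout << \"Lcm(24, 36) = \" << oneflow::Lcm(24, 36) << std::endl;  // 72\n  std::cout << \"pi<double> = \" << oneflow::pi<double> << std::endl;\n  // On host builds DeviceMin/DeviceMax fall back to std::min/std::max.\n  std::cout << \"DeviceMax(3, 5) = \" << oneflow::DeviceMax(3, 5) << std::endl;  // 5\n  return 0;\n}\n"
  },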
  {
    "path": "oneflow/core/common/maybe.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_MAYBE_H_\n#define ONEFLOW_CORE_COMMON_MAYBE_H_\n\n#include \"oneflow/core/common/throw.h\"\n#include <google/protobuf/text_format.h>\n#include \"oneflow/core/common/type_traits.h\"\n#include \"oneflow/core/common/either_ptr.h\"\n#include \"oneflow/core/common/shared_or_scalar.h\"\n#include \"oneflow/core/common/error.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/common/just.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstruct is_maybe {\n  static const bool value = false;\n};\n\ntemplate<typename T>\nstruct is_maybe<Maybe<T>> {\n  static const bool value = true;\n};\n\ntemplate<typename T>\nclass Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || IsScalarType<T>::value)\n                                       && !std::is_reference<T>::value>::type>\n    final {\n public:\n  Maybe(const T& data) : data_or_error_(std::make_shared<T>(data)) {}\n  Maybe(T&& data) : data_or_error_(std::make_shared<T>(std::move(data))) {}\n  Maybe(const Error& error) : data_or_error_(error.stacked_error()) {}\n  Maybe(const std::shared_ptr<T>& data) : data_or_error_(data) {}\n  Maybe(std::shared_ptr<T>&& data) : data_or_error_(std::move(data)) {}\n  Maybe(const std::shared_ptr<StackedError>& error) : data_or_error_(error) {}\n  Maybe(const Maybe&) = default;\n  Maybe(Maybe&& other) : data_or_error_(std::move(other.data_or_error_)) {}\n  ~Maybe() = default;\n\n  bool IsOk() const { return data_or_error_.template Has<T>(); }\n  std::shared_ptr<T> Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {\n    return data_or_error_.template Get<T>();\n  }\n  std::shared_ptr<StackedError> stacked_error() const {\n    return data_or_error_.template Get<StackedError>();\n  }\n  std::shared_ptr<const ErrorProto> error() const { return stacked_error()->error_proto(); }\n\n  std::string GetSerializedError() const {\n    CHECK(!IsOk());\n    return GetFormatedSerializedError(this->stacked_error());\n  }\n\n  template<typename Type = T>\n  Type GetDataAndSerializedStackedError(std::string* error_str,\n                                        const Type& default_for_error) const {\n    static_assert(std::is_same<T, Type>::value, \"error type for argument 1\");\n    if (IsOk()) {\n      *error_str = StackedError().DebugString();\n      return *Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();\n    } else {\n      *error_str = this->stacked_error()->DebugString();\n      return default_for_error;\n    }\n  }\n\n  template<typename Type = T>\n  std::pair<Type, std::shared_ptr<StackedError>> GetDataAndStackedError(\n      const Type& default_for_error) const {\n    if (IsOk()) {\n      return std::make_pair(*Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(),\n                            std::shared_ptr<StackedError>());\n    } else {\n      return std::make_pair(default_for_error, stacked_error());\n    }\n  }\n\n  
std::pair<std::shared_ptr<T>, std::shared_ptr<StackedError>> GetDataPtrAndStackedError() const {\n    if (IsOk()) {\n      return std::make_pair(Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(),\n                            std::shared_ptr<StackedError>());\n    } else {\n      return std::make_pair(std::shared_ptr<T>(), stacked_error());\n    }\n  }\n\n  template<typename Type = T>\n  Type GetOrThrow() const {\n    if (!IsOk()) { ThrowError(stacked_error()); }\n    return *Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();\n  }\n\n  std::shared_ptr<T> GetPtrOrThrow() const {\n    if (!IsOk()) { ThrowError(stacked_error()); }\n    return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();\n  }\n\n private:\n  EitherPtr<T, StackedError> data_or_error_;\n};\n\ntemplate<typename T>\nclass Maybe<T, typename std::enable_if<std::is_same<T, void>::value>::type> final {\n public:\n  Maybe(const Error& error) : error_or_scalar_(error.stacked_error()) { CheckError(); }\n  Maybe(const std::shared_ptr<StackedError>& error) : error_or_scalar_(error) { CheckError(); }\n  Maybe(const Maybe&) = default;\n  Maybe(Maybe&&) = default;\n  ~Maybe() = default;\n\n  static Maybe Ok() { return Maybe(); }\n\n  bool IsOk() const { return error_or_scalar_.IsScalar(); }\n  void Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {}\n  std::shared_ptr<StackedError> stacked_error() const { return error_or_scalar_.shared_ptr(); }\n  std::shared_ptr<const ErrorProto> error() const { return stacked_error()->error_proto(); }\n\n  std::string GetSerializedError() const {\n    CHECK(!IsOk());\n    return GetFormatedSerializedError(this->stacked_error());\n  }\n\n  void GetDataAndSerializedStackedError(std::string* error_str) const {\n    if (IsOk()) {\n      *error_str = StackedError().DebugString();\n    } else {\n      *error_str = this->stacked_error()->DebugString();\n    }\n  }\n\n  std::shared_ptr<StackedError> GetDataAndStackedError() const {\n    if (IsOk()) {\n      return std::shared_ptr<StackedError>();\n    } else {\n      return stacked_error();\n    }\n  }\n\n  void GetOrThrow() const {\n    if (!IsOk()) { ThrowError(stacked_error()); }\n    return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();\n  }\n\n private:\n  Maybe() : error_or_scalar_(nullptr) {}\n  void CheckError() const {\n    CHECK_NE(this->error()->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET);\n  }\n\n  SharedOrScalar<StackedError, void*> error_or_scalar_;\n};\n\ninline const std::shared_ptr<StackedError>& UninitializedValueError() {\n  static thread_local const auto& error =\n      (Error::InvalidValueError() << \"uninitialized value\").stacked_error();\n  return error;\n}\n\ntemplate<typename T>\nclass Maybe<T, typename std::enable_if<IsScalarType<T>::value>::type> final {\n public:\n  Maybe(T data) : error_or_scalar_(data) {}\n  Maybe(const Error& error) : error_or_scalar_(error.stacked_error()) { CheckError(); }\n  Maybe(const std::shared_ptr<StackedError>& error) : error_or_scalar_(error) { CheckError(); }\n  Maybe() : error_or_scalar_(UninitializedValueError()) {}\n  Maybe(const Maybe&) = default;\n  Maybe(Maybe&&) = default;\n  ~Maybe() = default;\n\n  void operator=(const Maybe& rhs) { error_or_scalar_ = rhs.error_or_scalar_; }\n\n  bool IsOk() const { return error_or_scalar_.IsScalar(); }\n  T Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {\n    return error_or_scalar_.scalar_value();\n  }\n  std::shared_ptr<StackedError> stacked_error() const { return error_or_scalar_.shared_ptr(); }\n  
std::shared_ptr<const ErrorProto> error() const { return stacked_error()->error_proto(); }\n\n  std::string GetSerializedError() const {\n    CHECK(!IsOk());\n    return GetFormatedSerializedError(this->stacked_error());\n  }\n\n  T GetDataAndSerializedStackedError(std::string* error_str, const T& default_for_error) const {\n    if (IsOk()) {\n      *error_str = StackedError().DebugString();\n      return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();\n    } else {\n      *error_str = this->stacked_error()->DebugString();\n      return default_for_error;\n    }\n  }\n\n  std::pair<T, std::shared_ptr<StackedError>> GetDataAndStackedError(\n      const T& default_for_error) const {\n    if (IsOk()) {\n      return std::make_pair(Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(),\n                            std::shared_ptr<StackedError>());\n    } else {\n      return std::make_pair(default_for_error, stacked_error());\n    }\n  }\n\n  T GetOrThrow() const {\n    if (!IsOk()) { ThrowError(stacked_error()); }\n    return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();\n  }\n\n private:\n  void CheckError() const {\n    CHECK_NE(this->error()->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET);\n  }\n\n  SharedOrScalar<StackedError, T> error_or_scalar_;\n};\n\ntemplate<typename T>\nclass Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || IsScalarType<T>::value)\n                                       && std::is_reference<T>::value>::type>\n    final {\n  using ValueT = typename std::remove_reference<T>::type;\n  using PtrT = ValueT*;\n\n public:\n  Maybe(T data) : maybe_ptr_(&data) {}\n  Maybe(const Error& error) : maybe_ptr_(error) {}\n  Maybe(const std::shared_ptr<StackedError>& error) : maybe_ptr_(error) {}\n  Maybe(const Maybe&) = default;\n  Maybe(Maybe&&) = default;\n  ~Maybe() = default;\n\n  bool IsOk() const { return maybe_ptr_.IsOk(); }\n  T Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {\n    return *maybe_ptr_.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();\n  }\n  std::shared_ptr<StackedError> stacked_error() const { return maybe_ptr_.stacked_error(); }\n  std::shared_ptr<const ErrorProto> error() const { return stacked_error()->error_proto(); }\n\n  std::string GetSerializedError() const {\n    CHECK(!IsOk());\n    return maybe_ptr_.GetSerializedError();\n  }\n\n  T GetDataAndSerializedStackedError(std::string* error_str) const {\n    return *maybe_ptr_.GetDataAndSerializedStackedError(error_str, static_cast<PtrT>(nullptr));\n  }\n\n  T GetOrThrow() const {\n    if (!IsOk()) { ThrowError(stacked_error()); }\n    return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();\n  }\n\n private:\n  Maybe<PtrT> maybe_ptr_;\n};\n\nnamespace {\nstd::string GetFormatedSerializedError(const std::shared_ptr<StackedError>& stacked_error) {\n  // return error msg got from formatted function or debugstring.\n  const auto& maybe_error = TRY(FormatErrorStr(stacked_error));\n  const auto& error_str = maybe_error.GetDataAndStackedError(stacked_error->DebugString());\n  return error_str.first;\n}\n}  // namespace\n}  // namespace oneflow\n\n#define CHECK_OK(...)                                         \\\n  for (auto&& maybe = __JustStackCheckWrapper__(__VA_ARGS__); \\\n       GOOGLE_PREDICT_BRANCH_NOT_TAKEN(!maybe.IsOk());)       \\\n  LOG(FATAL) << OF_PP_STRINGIZE(__VA_ARGS__) << \" is not OK:\\n\" << maybe.GetSerializedError()\n\n#define OF_RETURN_IF_ERROR(...)                                                               
\\\n  for (auto&& maybe_##__LINE__ = __JustStackCheckWrapper__(__VA_ARGS__);                      \\\n       !maybe_##__LINE__.IsOk();)                                                             \\\n  return Error(maybe_##__LINE__.stacked_error()).AddStackFrame([](const char* function) {     \\\n    thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \\\n    return frame;                                                                             \\\n  }(__FUNCTION__))\n\n#define OF_TODO()                                                                             \\\n  return Error::TodoError().AddStackFrame([](const char* function) {                          \\\n    thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \\\n    return frame;                                                                             \\\n  }(__FUNCTION__))\n#define OF_UNIMPLEMENTED()                                                                    \\\n  return Error::UnimplementedError().AddStackFrame([](const char* function) {                 \\\n    thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \\\n    return frame;                                                                             \\\n  }(__FUNCTION__))\n\n#define OF_RUNTIME_ERROR()                                                                    \\\n  return Error::RuntimeError().AddStackFrame([](const char* function) {                       \\\n    thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \\\n    return frame;                                                                             \\\n  }(__FUNCTION__))                                                                            \\\n         << \"RuntimeError \"                                                                   \\\n            \": \"\n#define RETURN_ERROR_WITH_BUG_PROMPT() OF_RUNTIME_ERROR() << kOfBugIssueUploadPrompt\n\n#define OF_LOG_ONCE(x)          \\\n  {                             \\\n    static bool warned = false; \\\n    if (!warned) {              \\\n      warned = true;            \\\n      x;                        \\\n    }                           \\\n  }\n\n#define OF_COMPLIE_OPTION_ERROR()                                                             \\\n  return Error::CompileOptionWrongError().AddStackFrame([](const char* function) {            \\\n    thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \\\n    return frame;                                                                             \\\n  }(__FUNCTION__))                                                                            \\\n         << \"Compile option wrong: \"\n\n#define CHECK_OR_RETURN_INTERNAL(expr, error_msg)                           \\\n  if (!(expr))                                                              \\\n  return Error::CheckFailedError().AddStackFrame([](const char* function) { \\\n    thread_local static auto frame =                                        \\\n        SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function, error_msg)); \\\n    return frame;                                                           \\\n  }(__FUNCTION__))\n\n#define CHECK_OR_RETURN_ERROR(expr)                                                           \\\n  if (!(expr))                                                                                \\\n  return 
Error::CheckFailedError().AddStackFrame([](const char* function) {                   \\\n    thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \\\n    return frame;                                                                             \\\n  }(__FUNCTION__))\n\n// NOTE: Please contact @daquexian if you need to modify these CHECK_(XX_)OR_RETURN macros. There\n// are some static analyzers depending on the internal implementation of them.\n#define CHECK_OR_RETURN(expr)                                            \\\n  CHECK_OR_RETURN_INTERNAL(expr, OF_PP_STRINGIZE(CHECK_OR_RETURN(expr))) \\\n      << \"Check failed: (\" << OF_PP_STRINGIZE(expr) << \") \" << Error::kOverrideThenMergeMessage\n\n#define CHECK_EQ_OR_RETURN(lhs, rhs)                                                      \\\n  CHECK_OR_RETURN_INTERNAL((lhs) == (rhs), OF_PP_STRINGIZE(CHECK_EQ_OR_RETURN(lhs, rhs))) \\\n      << \"Check failed: (\" << (lhs) << \" == \" << (rhs) << \") \" << Error::kOverrideThenMergeMessage\n\n#define CHECK_GE_OR_RETURN(lhs, rhs)                                                      \\\n  CHECK_OR_RETURN_INTERNAL((lhs) >= (rhs), OF_PP_STRINGIZE(CHECK_GE_OR_RETURN(lhs, rhs))) \\\n      << \"Check failed: (\" << (lhs) << \" >= \" << (rhs) << \") \" << Error::kOverrideThenMergeMessage\n\n#define CHECK_GT_OR_RETURN(lhs, rhs)                                                     \\\n  CHECK_OR_RETURN_INTERNAL((lhs) > (rhs), OF_PP_STRINGIZE(CHECK_GT_OR_RETURN(lhs, rhs))) \\\n      << \"Check failed: (\" << (lhs) << \" > \" << (rhs) << \") \" << Error::kOverrideThenMergeMessage\n\n#define CHECK_LE_OR_RETURN(lhs, rhs)                                                      \\\n  CHECK_OR_RETURN_INTERNAL((lhs) <= (rhs), OF_PP_STRINGIZE(CHECK_LE_OR_RETURN(lhs, rhs))) \\\n      << \"Check failed: (\" << (lhs) << \" <= \" << (rhs) << \") \" << Error::kOverrideThenMergeMessage\n\n#define CHECK_LT_OR_RETURN(lhs, rhs)                                                     \\\n  CHECK_OR_RETURN_INTERNAL((lhs) < (rhs), OF_PP_STRINGIZE(CHECK_LT_OR_RETURN(lhs, rhs))) \\\n      << \"Check failed: (\" << (lhs) << \" < \" << (rhs) << \") \" << Error::kOverrideThenMergeMessage\n\n#define CHECK_NE_OR_RETURN(lhs, rhs)                                                      \\\n  CHECK_OR_RETURN_INTERNAL((lhs) != (rhs), OF_PP_STRINGIZE(CHECK_NE_OR_RETURN(lhs, rhs))) \\\n      << \"Check failed: (\" << (lhs) << \" != \" << (rhs) << \") \" << Error::kOverrideThenMergeMessage\n\n#define CHECK_STREQ_OR_RETURN(lhs, rhs) CHECK_EQ_OR_RETURN(std::string(lhs), std::string(rhs))\n\n#define CHECK_STRNE_OR_RETURN(lhs, rhs) CHECK_NE_OR_RETURN(std::string(lhs), std::string(rhs))\n\n#define CHECK_NOTNULL_OR_RETURN(ptr) CHECK_OR_RETURN(ptr != nullptr)\n\n#define CHECK_ISNULL_OR_RETURN(ptr) CHECK_OR_RETURN(ptr == nullptr)\n\n#define TODO_THEN_RETURN() OF_TODO()\n\n#define UNIMPLEMENTED_THEN_RETURN() OF_UNIMPLEMENTED()\n\n#endif  // ONEFLOW_CORE_COMMON_MAYBE_H_\n"
  },
  {
    "path": "oneflow/core/common/maybe_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/maybe.h\"\n#include \"gtest/gtest.h\"\n#include <gtest/gtest-death-test.h>\n#include <memory>\n#include \"oneflow/core/common/exception.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(Maybe, JUST_MSG) {\n  auto f = [](int x) -> Maybe<int> {\n    if (x > 10) { return Error::InvalidValueError() << \"input value \" << x; }\n\n    return 233;\n  };\n\n  auto g = [](int x) { return x * x - 5 * x + 3; };\n\n  auto h = [&](int x) -> Maybe<int> {\n    auto y = g(x);\n    return JUST_MSG(f(y), \"input value g(\", x, \")\");\n  };\n\n  auto i = [&](float x) -> Maybe<int> {\n    int y = x;\n    return JUST_MSG(h(y), std::stringstream() << \"input value int(\" << x << \")\");\n  };\n\n  auto data = CHECK_JUST(i(1));\n  ASSERT_EQ(data, 233);\n\n  auto err = i(10.123).stacked_error();\n  ASSERT_EQ(err->error_proto()->msg(), R\"(input value 53)\");\n  ASSERT_GE(err->stack_frame().size(), 2);\n  ASSERT_EQ(err->stack_frame().at(0)->code_text(), \"f(y)\");\n  ASSERT_EQ(err->stack_frame().at(1)->code_text(), \"h(y)\");\n\n  try {\n    CHECK_JUST(i(10.234));\n  } catch (const RuntimeException& e) {\n    EXPECT_TRUE(std::string(e.what()).find(R\"(input value 53)\") != std::string::npos);\n  }\n}\n\nTEST(Maybe, CHECK_OR_RETURN) {\n  auto f = [](int x) -> Maybe<int> {\n    CHECK_OR_RETURN(x > 10);\n    return 233;\n  };\n\n  auto i = [&](float x) -> Maybe<int> { return JUST(f(x)); };\n\n  auto data = CHECK_JUST(i(20));\n  ASSERT_EQ(data, 233);\n\n  auto err = i(1).stacked_error();\n  ASSERT_GE(err->stack_frame().size(), 2);\n  ASSERT_EQ(err->stack_frame().at(0)->code_text(), \"CHECK_OR_RETURN(x > 10)\");\n  ASSERT_EQ(err->stack_frame().at(1)->code_text(), \"f(x)\");\n}\n\nTEST(Maybe, CHECK_OK) {\n  auto f = [](int x) -> Maybe<int> {\n    if (x > 10) { return Error::InvalidValueError() << \"input value \" << x; }\n\n    return 233;\n  };\n\n  auto g = [&](int x) -> Maybe<int> {\n    auto y = JUST(f(x));\n    return f(y);\n  };\n\n  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto)\n  ASSERT_EXIT(CHECK_OK(g(11)), testing::KilledBySignal(SIGABRT), R\"(g\\(11\\) is not OK)\");\n}\n\nTEST(Maybe, Noncopyable) { Maybe<std::unique_ptr<int>> a{std::make_unique<int>(1)}; }\n\n}  // namespace test\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/mem_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/mem_util.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n\n#include <unistd.h>\n#include <sys/sysinfo.h>\n\nnamespace oneflow {\n\nnamespace {\nstruct ProcStat {\n  std::string pid, comm, state, ppid, pgrp, session, tty_nr;\n  std::string tpgid, flags, minflt, cminflt, majflt, cmajflt;\n  std::string utime, stime, cutime, cstime, priority, nice;\n  std::string num_threads, itrealvalue, starttime;\n  unsigned long vsize = 0;\n  long rss = 0;\n};\n\nMaybe<void> CPUSynchronize() {\n  if (Singleton<VirtualMachine>::Get() != nullptr) { return vm::CurrentRankSync(); }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n// Reference: https://stackoverflow.com/questions/669438/how-to-get-memory-usage-at-runtime-using-c\nvoid ProcessMemUsage(double* vm_usage, double* resident_set) {\n  *vm_usage = 0.0;\n  *resident_set = 0.0;\n\n#ifdef __linux__\n  // 'file' stat seems to give the most reliable results\n  std::ifstream stat_stream(\"/proc/self/stat\", std::ios_base::in);\n  ProcStat proc_stat;\n  stat_stream >> proc_stat.pid >> proc_stat.comm >> proc_stat.state >> proc_stat.ppid\n      >> proc_stat.pgrp >> proc_stat.session >> proc_stat.tty_nr >> proc_stat.tpgid\n      >> proc_stat.flags >> proc_stat.minflt >> proc_stat.cminflt >> proc_stat.majflt\n      >> proc_stat.cmajflt >> proc_stat.utime >> proc_stat.stime >> proc_stat.cutime\n      >> proc_stat.cstime >> proc_stat.priority >> proc_stat.nice >> proc_stat.num_threads\n      >> proc_stat.itrealvalue >> proc_stat.starttime >> proc_stat.vsize\n      >> proc_stat.rss;  // don't care about the rest\n\n  stat_stream.close();\n\n  long page_size_kb = sysconf(_SC_PAGE_SIZE);  // in case x86-64 is configured to use 2MB pages\n  // return with MB\n  *vm_usage = proc_stat.vsize >> 20;\n  // return with MB\n  *resident_set = (proc_stat.rss * page_size_kb) >> 20;\n#endif  // __linux__\n}\n\nMaybe<double> GetCPUMemoryUsed() {\n  JUST(CPUSynchronize());\n  double vm_ = 0, rss_ = 0;\n  ProcessMemUsage(&vm_, &rss_);\n  return rss_;\n}\n\nstd::string FormatMemSize(uint64_t size) {\n  std::ostringstream os;\n  os.precision(1);\n  os << std::fixed;\n  if (size <= 1024UL) {\n    os << size << \" Bytes\";\n  } else if (size <= 1048576UL) {\n    os << ((float)size / 1024.0) << \" KB\";\n  } else if (size <= 1073741824UL) {\n    os << ((float)size / 1048576.0) << \" MB\";\n  } else {\n    os << ((float)size / 1073741824.0) << \" GB\";\n  }\n  return os.str();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/mem_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_MEM_UTIL_H_\n#define ONEFLOW_CORE_COMMON_MEM_UTIL_H_\n\n#include <chrono>\n#include <sstream>\n#include <string>\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\nvoid ProcessMemUsage(double* vm_usage, double* resident_set);\nstd::string FormatMemSize(uint64_t size);\nMaybe<double> GetCPUMemoryUsed();\n}  // namespace oneflow\n\n#define LOG_MEM(...)                                                                \\\n  double vm_ = 0, rss_ = 0;                                                         \\\n  ProcessMemUsage(&vm_, &rss_);                                                     \\\n  VLOG(1) << \"File \" __FILE__ << \", Line \" << __LINE__ << \", Func \" << __FUNCTION__ \\\n          << \", Mem size RSS \" << rss_ << \"MB.\"\n\n#endif  // ONEFLOW_CORE_COMMON_MEM_UTIL_H_\n"
  },
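  {
    "path": "oneflow/core/common/mem_util_usage_example.cpp",
    "content": "// Hypothetical usage sketch for mem_util (illustrative only, not part of the\n// OneFlow source tree). FormatMemSize pretty-prints byte counts with 1024-based\n// units; ProcessMemUsage reports sizes in MB and leaves them 0 on non-Linux.\n#include <iostream>\n\n#include \"oneflow/core/common/mem_util.h\"\n\nint main() {\n  using oneflow::FormatMemSize;\n  std::cout << FormatMemSize(512) << std::endl;         // \"512 Bytes\"\n  std::cout << FormatMemSize(10UL << 10) << std::endl;  // \"10.0 KB\"\n  std::cout << FormatMemSize(3UL << 20) << std::endl;   // \"3.0 MB\"\n  std::cout << FormatMemSize(5UL << 30) << std::endl;   // \"5.0 GB\"\n\n  double vm_mb = 0.0, rss_mb = 0.0;\n  oneflow::ProcessMemUsage(&vm_mb, &rss_mb);\n  std::cout << \"VM: \" << vm_mb << \" MB, RSS: \" << rss_mb << \" MB\" << std::endl;\n  return 0;\n}\n"
  },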
  {
    "path": "oneflow/core/common/memory_format.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nenum MemoryFormat {\n  kContiguous = 0;\n  kChannelsLast = 1;\n  kPreserve = 2;\n\n  kMemoryFormatCount = 3;\n};\n"
  },
  {
    "path": "oneflow/core/common/meta_util.hpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_META_UTIL_HPP_\n#define ONEFLOW_CORE_COMMON_META_UTIL_HPP_\n\n#include <utility>\n#include <tuple>\n\nnamespace oneflow {\n\ntemplate<typename... Args, typename Func, std::size_t... Idx>\nvoid for_each(const std::tuple<Args...>& t, Func&& f, std::index_sequence<Idx...>) {\n  (std::forward<Func>(f)(std::get<Idx>(t)), ...);\n}\n\ntemplate<typename... Args, typename Func, std::size_t... Idx>\nvoid for_each_i(const std::tuple<Args...>& t, Func&& f, std::index_sequence<Idx...>) {\n  (std::forward<Func>(f)(std::get<Idx>(t), std::integral_constant<size_t, Idx>{}), ...);\n}\n\ntemplate<typename T>\nusing remove_const_reference_t = std::remove_const_t<std::remove_reference_t<T>>;\n\ntemplate<std::size_t... Is>\nauto make_tuple_from_sequence(std::index_sequence<Is...>) {\n  return std::make_tuple(Is...);\n}\n\ntemplate<std::size_t N>\nconstexpr auto make_tuple_from_sequence() {\n  return make_tuple_from_sequence(std::make_index_sequence<N>{});\n}\n\nnamespace detail {\ntemplate<class Tuple, class F, std::size_t... Is>\nvoid tuple_switch(const std::size_t i, Tuple&& t, F&& f, std::index_sequence<Is...>) {\n  (void)std::initializer_list<int>{\n      (i == Is && ((void)std::forward<F>(f)(std::integral_constant<size_t, Is>{}), 0))...};\n}\n}  // namespace detail\n\ntemplate<class Tuple, class F>\ninline void tuple_switch(const std::size_t i, Tuple&& t, F&& f) {\n  constexpr auto N = std::tuple_size<std::remove_reference_t<Tuple>>::value;\n\n  detail::tuple_switch(i, std::forward<Tuple>(t), std::forward<F>(f),\n                       std::make_index_sequence<N>{});\n}\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_META_UTIL_HPP_\n"
  },
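  {
    "path": "oneflow/core/common/meta_util_usage_example.cpp",
    "content": "// Hypothetical usage sketch for the tuple helpers in meta_util.hpp (illustrative\n// only, not part of the OneFlow source tree). tuple_switch dispatches a runtime\n// index onto a compile-time tuple position.\n#include <iostream>\n#include <string>\n#include <tuple>\n\n#include \"oneflow/core/common/meta_util.hpp\"\n\nint main() {\n  auto t = std::make_tuple(1, 2.5, std::string(\"three\"));\n\n  // Visit each element in order: prints \"1 2.5 three\".\n  oneflow::for_each(t, [](const auto& x) { std::cout << x << \" \"; },\n                    std::make_index_sequence<3>{});\n  std::cout << std::endl;\n\n  // Visit each element with its compile-time index: prints \"0:1 1:2.5 2:three\".\n  oneflow::for_each_i(t, [](const auto& x, auto idx) { std::cout << idx() << \":\" << x << \" \"; },\n                      std::make_index_sequence<3>{});\n  std::cout << std::endl;\n\n  // Runtime index 2 selects the compile-time std::get<2>: prints \"three\".\n  oneflow::tuple_switch(2, t, [&t](auto idx) {\n    std::cout << std::get<decltype(idx)::value>(t) << std::endl;\n  });\n  return 0;\n}\n"
  },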
  {
    "path": "oneflow/core/common/nd_index.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/nd_index.h\"\n#include \"oneflow/core/common/protobuf.h\"\n\nnamespace oneflow {\n\nNdIndex::NdIndex(const std::initializer_list<int64_t>& dim_vec) : dim_vec_(dim_vec) {}\n\nNdIndex::NdIndex(const DimVector& dim_vec) : dim_vec_(dim_vec) {}\n\nNdIndex& NdIndex::operator=(const NdIndex& shape) {\n  dim_vec_ = shape.dim_vec_;\n  return *this;\n}\n\nbool NdIndex::operator==(const NdIndex& rhs) const { return dim_vec_ == rhs.dim_vec_; }\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/nd_index.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_ND_INDEX_H_\n#define ONEFLOW_CORE_COMMON_ND_INDEX_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/shape.h\"\n\nnamespace oneflow {\n\nclass NdIndex final {\n public:\n  NdIndex() = default;\n  explicit NdIndex(const DimVector& dim_vec);\n  NdIndex(const std::initializer_list<int64_t>& dim_vec);\n  ~NdIndex() = default;\n  NdIndex& operator=(const NdIndex& other);\n\n  bool operator==(const NdIndex& rhs) const;\n  bool operator!=(const NdIndex& rhs) const { return !(*this == rhs); }\n\n  const DimVector& dim_vec() const { return dim_vec_; }\n\n  int64_t At(int64_t index) const { return dim_vec_.at(index); }\n  int64_t NumAxes() const { return dim_vec_.size(); }\n\n private:\n  DimVector dim_vec_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_ND_INDEX_H_\n"
  },
  {
    "path": "oneflow/core/common/nd_index_offset_helper.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_ND_INDEX_OFFSET_HELPER_H_\n#define ONEFLOW_CORE_COMMON_ND_INDEX_OFFSET_HELPER_H_\n\n#include \"oneflow/core/common/data_type.h\"\n#include <cassert>\n\nnamespace oneflow {\n\ntemplate<typename T, int N>\nclass NdIndexOffsetHelper {\n public:\n  OF_DEVICE_FUNC NdIndexOffsetHelper() = default;\n\n  template<class... Ts>\n  OF_DEVICE_FUNC explicit NdIndexOffsetHelper(T d0, Ts... dims) {\n    constexpr int n = 1 + sizeof...(dims);\n    static_assert(n <= N, \"\");\n    T dims_arr[n] = {d0, static_cast<T>(dims)...};\n    InitStrides(dims_arr, n);\n  }\n\n  OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const T* dims) { InitStrides(dims, N); }\n\n  template<typename U>\n  OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const U* dims) {\n    T dims_arr[N];\n    for (int i = 0; i < N; ++i) { dims_arr[i] = dims[i]; }\n    InitStrides(dims_arr, N);\n  }\n\n  OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const T* dims, int n) { InitStrides(dims, n); }\n\n  template<typename U>\n  OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const U* dims, int n) {\n    T dims_arr[N];\n    for (int i = 0; i < N; ++i) {\n      if (i < n) { dims_arr[i] = dims[i]; }\n    }\n    InitStrides(dims_arr, n);\n  }\n\n  virtual ~NdIndexOffsetHelper() = default;\n\n  OF_DEVICE_FUNC T NdIndexToOffset(const T* index) const {\n    T offset = 0;\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n    for (int i = 0; i < N; ++i) { offset += index[i] * stride_[i]; }\n    return offset;\n  }\n\n  OF_DEVICE_FUNC T NdIndexToOffset(const T* index, int n) const {\n    assert(n <= N);\n    T offset = 0;\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n    for (int i = 0; i < N; ++i) {\n      if (i < n) { offset += index[i] * stride_[i]; }\n    }\n    return offset;\n  }\n\n  template<class... Ts>\n  OF_DEVICE_FUNC T NdIndexToOffset(T d0, Ts... 
others) const {\n    constexpr int n = 1 + sizeof...(others);\n    static_assert(n <= N, \"\");\n    T index[n] = {d0, others...};\n    T offset = 0;\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n    for (int i = 0; i < n - 1; ++i) { offset += index[i] * stride_[i]; }\n    if (n == N) {\n      offset += index[n - 1];\n    } else {\n      offset += index[n - 1] * stride_[n - 1];\n    }\n    return offset;\n  }\n\n  OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index) const {\n    T remaining = offset;\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n    for (int i = 0; i < N - 1; ++i) {\n      const T idx = remaining / stride_[i];\n      index[i] = idx;\n      remaining = remaining - idx * stride_[i];\n    }\n    index[N - 1] = remaining;\n  }\n\n  OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index, int n) const {\n    assert(n <= N);\n    T remaining = offset;\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n    for (int i = 0; i < N; ++i) {\n      if (i < n) {\n        const T idx = remaining / stride_[i];\n        index[i] = idx;\n        remaining = remaining - idx * stride_[i];\n      }\n    }\n  }\n\n  template<class... Ts>\n  OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T& d0, Ts&... others) const {\n    constexpr int n = 1 + sizeof...(others);\n    static_assert(n <= N, \"\");\n    T* index[n] = {&d0, &others...};\n    T remaining = offset;\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n    for (int i = 0; i < n - 1; ++i) {\n      const T idx = remaining / stride_[i];\n      *index[i] = idx;\n      remaining = remaining - idx * stride_[i];\n    }\n    if (n == N) {\n      *index[n - 1] = remaining;\n    } else {\n      *index[n - 1] = remaining / stride_[n - 1];\n    }\n  }\n\n  OF_DEVICE_FUNC constexpr int Size() const { return N; }\n\n protected:\n  OF_DEVICE_FUNC void InitStrides(const T* dims, const int n) {\n    for (int i = n - 1; i < N; ++i) { stride_[i] = 1; }\n    for (int i = n - 2; i >= 0; --i) { stride_[i] = dims[i + 1] * stride_[i + 1]; }\n  }\n\n  T stride_[N];\n};\n\ntemplate<typename T, int N>\nclass NdIndexStrideOffsetHelper : public NdIndexOffsetHelper<T, N> {\n public:\n  OF_DEVICE_FUNC NdIndexStrideOffsetHelper() = default;\n  OF_DEVICE_FUNC explicit NdIndexStrideOffsetHelper(const T* strides) {\n    for (int i = 0; i < N; ++i) { stride_[i] = strides[i]; }\n  }\n\n  template<typename U>\n  OF_DEVICE_FUNC explicit NdIndexStrideOffsetHelper(const U* strides) {\n    for (int i = 0; i < N; ++i) { stride_[i] = static_cast<T>(strides[i]); }\n  }\n\n  OF_DEVICE_FUNC explicit NdIndexStrideOffsetHelper(const T* strides, int n) {\n    for (int i = 0; i < N; ++i) {\n      if (i < n) { stride_[i] = strides[i]; }\n    }\n  }\n\n  template<typename U>\n  OF_DEVICE_FUNC explicit NdIndexStrideOffsetHelper(const U* strides, int n) {\n    for (int i = 0; i < N; ++i) {\n      if (i < n) { stride_[i] = static_cast<T>(strides[i]); }\n    }\n  }\n\n private:\n  using NdIndexOffsetHelper<T, N>::stride_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_ND_INDEX_OFFSET_HELPER_H_\n"
  },
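A quick host-side sketch of the round trip this helper implements (`Demo` is an illustrative name, and a plain C++ build where `OF_DEVICE_FUNC` expands to nothing is assumed): for shape `(3, 4, 5)` the row-major strides are `(20, 5, 1)`, so `NdIndexToOffset` is a dot product with the strides and `OffsetToNdIndex` undoes it by repeated division.

```cpp
#include <cassert>
#include "oneflow/core/common/nd_index_offset_helper.h"

void Demo() {
  // Shape (3, 4, 5) -> row-major strides (20, 5, 1).
  oneflow::NdIndexOffsetHelper<int64_t, 3> helper(3, 4, 5);

  // (1, 2, 3) -> 1*20 + 2*5 + 3 = 33.
  const int64_t offset = helper.NdIndexToOffset(1, 2, 3);
  assert(offset == 33);

  // ... and back again via division by the strides.
  int64_t index[3];
  helper.OffsetToNdIndex(offset, index);
  assert(index[0] == 1 && index[1] == 2 && index[2] == 3);
}
```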
  {
    "path": "oneflow/core/common/nd_index_offset_helper_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// include sstream first to avoid some compiling error\n// caused by the following trick\n// reference: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=65899\n#include <sstream>\n#include \"gtest/gtest.h\"\n#define private public\n#define protected public\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\n\nnamespace test {\n\ntemplate<typename T, int ndims>\nvoid test_3d() {\n  const T d0_max = 3;\n  const T d1_max = 4;\n  const T d2_max = 5;\n  const NdIndexOffsetHelper<T, ndims> helper(d0_max, d1_max, d2_max);\n  for (T d0 = 0; d0 < d0_max; ++d0) {\n    const T offset0 = d0 * d1_max * d2_max;\n    {\n      std::vector<T> expected0({d0});\n      {\n        std::vector<T> dims(1);\n        helper.OffsetToNdIndex(offset0, dims.data(), 1);\n        ASSERT_EQ(expected0, dims);\n      }\n      {\n        std::vector<T> dims(1);\n        helper.OffsetToNdIndex(offset0, dims.at(0));\n        ASSERT_EQ(expected0, dims);\n      }\n      ASSERT_EQ(offset0, helper.NdIndexToOffset(expected0.data(), 1));\n      ASSERT_EQ(offset0, helper.NdIndexToOffset(expected0.at(0)));\n    }\n    for (T d1 = 0; d1 < d1_max; ++d1) {\n      const T offset1 = offset0 + d1 * d2_max;\n      {\n        std::vector<T> expected1({d0, d1});\n        {\n          std::vector<T> dims(2);\n          helper.OffsetToNdIndex(offset1, dims.data(), 2);\n          ASSERT_EQ(expected1, dims);\n        }\n        {\n          std::vector<T> dims(2);\n          helper.OffsetToNdIndex(offset1, dims.at(0), dims.at(1));\n          ASSERT_EQ(expected1, dims);\n        }\n        ASSERT_EQ(offset1, helper.NdIndexToOffset(expected1.data(), 2));\n        ASSERT_EQ(offset1, helper.NdIndexToOffset(expected1.at(0), expected1.at(1)));\n      }\n      for (T d2 = 0; d2 < d2_max; ++d2) {\n        const T offset2 = offset1 + d2;\n        {\n          std::vector<T> expected2({d0, d1, d2});\n          {\n            std::vector<T> dims(3);\n            helper.OffsetToNdIndex(offset2, dims.data(), 3);\n            ASSERT_EQ(expected2, dims);\n          }\n          {\n            std::vector<T> dims(3);\n            helper.OffsetToNdIndex(offset2, dims.at(0), dims.at(1), dims.at(2));\n            ASSERT_EQ(expected2, dims);\n          }\n          if (ndims == 3) {\n            std::vector<T> dims(3);\n            helper.OffsetToNdIndex(offset2, dims.data());\n            ASSERT_EQ(expected2, dims);\n            ASSERT_EQ(offset2, helper.NdIndexToOffset(expected2.data()));\n          }\n          ASSERT_EQ(offset2, helper.NdIndexToOffset(expected2.data(), 3));\n          ASSERT_EQ(offset2,\n                    helper.NdIndexToOffset(expected2.at(0), expected2.at(1), expected2.at(2)));\n        }\n      }\n    }\n  }\n}\n\nTEST(NdIndexOffsetHelper, static_3d) {\n  test_3d<int32_t, 3>();\n  test_3d<int64_t, 3>();\n}\n\nTEST(NdIndexOffsetHelper, dynamic_3d) {\n  test_3d<int32_t, 4>();\n  test_3d<int64_t, 4>();\n  
test_3d<int32_t, 8>();\n  test_3d<int64_t, 8>();\n}\n\ntemplate<typename T>\nvoid test_constructor() {\n  const T d0 = 3;\n  const T d1 = 4;\n  const T d2 = 5;\n  // static\n  {\n    std::vector<T> dims({d0, d1, d2});\n    const NdIndexOffsetHelper<T, 3> helper1(d0, d1, d2);\n    const NdIndexOffsetHelper<T, 3> helper2(dims.data());\n    const NdIndexOffsetHelper<T, 3> helper3(dims.data(), dims.size());\n    std::vector<T> stride({d1 * d2, d2, 1});\n    for (int i = 0; i < 3; ++i) {\n      ASSERT_EQ(helper1.stride_[i], stride[i]);\n      ASSERT_EQ(helper2.stride_[i], stride[i]);\n      ASSERT_EQ(helper3.stride_[i], stride[i]);\n    }\n  }\n  // dynamic\n  {\n    std::vector<T> dims({d0, d1, d2});\n    const NdIndexOffsetHelper<T, 6> helper1(d0, d1, d2);\n    const NdIndexOffsetHelper<T, 6> helper2(dims.data(), dims.size());\n    std::vector<T> stride({d1 * d2, d2, 1, 1, 1, 1});\n    for (int i = 0; i < 6; ++i) {\n      ASSERT_EQ(helper1.stride_[i], stride[i]);\n      ASSERT_EQ(helper2.stride_[i], stride[i]);\n    }\n  }\n}\n\nTEST(NdIndexOffsetHelper, constructor) {\n  test_constructor<int32_t>();\n  test_constructor<int64_t>();\n}\n\ntemplate<typename T, typename U>\nvoid test_stride_constructor() {\n  const T d1 = 5;\n  const T d2 = 6;\n\n  const U u1 = 5;\n  const U u2 = 6;\n\n  std::vector<T> strides({d1 * d2, d2, 1});\n  std::vector<U> strides_u({u1 * u2, u2, 1});\n\n  const NdIndexStrideOffsetHelper<T, 3> helper1(strides.data());\n  const NdIndexStrideOffsetHelper<T, 3> helper2(strides.data(), strides.size());\n  const NdIndexStrideOffsetHelper<T, 3> helper3(strides_u.data());\n  const NdIndexStrideOffsetHelper<T, 3> helper4(strides_u.data(), strides_u.size());\n\n  for (int i = 0; i < 3; i++) {\n    ASSERT_EQ(helper1.stride_[i], strides[i]);\n    ASSERT_EQ(helper2.stride_[i], strides[i]);\n    ASSERT_EQ(helper3.stride_[i], strides_u[i]);\n    ASSERT_EQ(helper4.stride_[i], strides_u[i]);\n  }\n}\n\nTEST(NdIndexStrideOffsetHelper, constructor) {\n  test_stride_constructor<int32_t, int64_t>();\n  test_stride_constructor<int64_t, int32_t>();\n}\n\n}  // namespace test\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/not_equal_to_previous_adjacent_iterator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_NOT_EQUAL_TO_PREVIOUS_ADJACENT_ITERATOR_H_\n#define ONEFLOW_CORE_COMMON_NOT_EQUAL_TO_PREVIOUS_ADJACENT_ITERATOR_H_\n\n#include <iterator>\n\nnamespace oneflow {\n\n#define ITER_DEVICE_FUNC __host__ __device__ __forceinline__\n\ntemplate<typename ValueType, typename UnderlyingT, typename OffsetT = ptrdiff_t>\nclass NotEqualToPreviousAdjacentIterator {\n public:\n  typedef NotEqualToPreviousAdjacentIterator self_type;\n  typedef OffsetT difference_type;\n  typedef ValueType value_type;\n  typedef ValueType* pointer;\n  typedef ValueType reference;\n  typedef std::random_access_iterator_tag iterator_category;\n\n private:\n  const UnderlyingT* underlying;\n  OffsetT offset;\n\n public:\n  ITER_DEVICE_FUNC\n  NotEqualToPreviousAdjacentIterator(const UnderlyingT* underlying, OffsetT offset)\n      : underlying(underlying), offset(offset) {}\n\n  ITER_DEVICE_FUNC self_type operator++(int) {\n    self_type ret = *this;\n    offset++;\n    return ret;\n  }\n\n  ITER_DEVICE_FUNC self_type operator++() {\n    offset++;\n    return *this;\n  }\n\n  ITER_DEVICE_FUNC reference operator*() const {\n    return offset == 0 ? 0 : (underlying[offset] == underlying[offset - 1] ? 0 : 1);\n  }\n\n  template<typename Distance>\n  ITER_DEVICE_FUNC self_type operator+(Distance n) const {\n    self_type ret(underlying, offset + n);\n    return ret;\n  }\n\n  template<typename Distance>\n  ITER_DEVICE_FUNC self_type& operator+=(Distance n) {\n    offset += n;\n    return *this;\n  }\n\n  template<typename Distance>\n  ITER_DEVICE_FUNC self_type operator-(Distance n) const {\n    self_type ret(underlying, offset - n);\n    return ret;\n  }\n\n  template<typename Distance>\n  ITER_DEVICE_FUNC self_type& operator-=(Distance n) {\n    offset -= n;\n    return *this;\n  }\n\n  ITER_DEVICE_FUNC difference_type operator-(self_type other) const {\n    return offset - other.offset;\n  }\n\n  template<typename Distance>\n  ITER_DEVICE_FUNC reference operator[](Distance n) const {\n    return *(*this + n);\n  }\n\n  ITER_DEVICE_FUNC pointer operator->() { return nullptr; }\n\n  ITER_DEVICE_FUNC bool operator==(const self_type& rhs) {\n    return (offset == rhs.offset) && ((underlying == rhs.underlying));\n  }\n\n  ITER_DEVICE_FUNC bool operator!=(const self_type& rhs) {\n    return offset != rhs.offset || underlying != rhs.underlying;\n  }\n\n  friend std::ostream& operator<<(std::ostream& os, const self_type& itr) { return os; }\n};\n\n#undef ITER_DEVICE_FUNC\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_NOT_EQUAL_TO_PREVIOUS_ADJACENT_ITERATOR_H_\n"
  },
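Dereferencing this iterator yields 1 wherever the underlying value differs from its predecessor (and 0 at offset 0), which makes it a natural input to a prefix sum that assigns consecutive segment ids to runs of equal sorted keys. A hypothetical sketch using CUB's `DeviceScan` (an nvcc build is assumed because `ITER_DEVICE_FUNC` carries `__host__ __device__`; `SegmentIds` is an illustrative name):

```cpp
#include <cub/cub.cuh>
#include "oneflow/core/common/not_equal_to_previous_adjacent_iterator.h"

// For sorted d_keys = {5, 5, 7, 7, 9}, writes d_segment_ids = {0, 0, 1, 1, 2}.
void SegmentIds(const int64_t* d_keys, int32_t* d_segment_ids, int n) {
  oneflow::NotEqualToPreviousAdjacentIterator<int32_t, int64_t> flags(d_keys, 0);
  size_t temp_bytes = 0;
  // First call only sizes the temp storage; the second runs the scan.
  cub::DeviceScan::InclusiveSum(nullptr, temp_bytes, flags, d_segment_ids, n);
  void* d_temp = nullptr;
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceScan::InclusiveSum(d_temp, temp_bytes, flags, d_segment_ids, n);
  cudaFree(d_temp);
}
```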
  {
    "path": "oneflow/core/common/notifier.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/notifier.h\"\n#include \"oneflow/core/common/foreign_lock_helper.h\"\n#include \"oneflow/core/common/env_var/env_var.h\"\n\nnamespace oneflow {\n\nNotifierStatus Notifier::Notify() {\n  bool notify = false;\n  {\n    std::unique_lock<std::mutex> lock(mutex_);\n    if (is_closed_) { return kNotifierStatusErrorClosed; }\n    notify = (notified_cnt_ == 0);\n    ++notified_cnt_;\n  }\n  if (notify) { cond_.notify_one(); }\n  return kNotifierStatusSuccess;\n}\n\nNotifierStatus Notifier::WaitAndClearNotifiedCnt() {\n  std::unique_lock<std::mutex> lock(mutex_);\n  cond_.wait(lock, [this]() { return notified_cnt_ > 0 || is_closed_; });\n  if (notified_cnt_ == 0) { return kNotifierStatusErrorClosed; }\n  notified_cnt_ = 0;\n  return kNotifierStatusSuccess;\n}\n\nMaybe<void> Notifier::TimedWaitAndClearNotifiedCnt(size_t timeout_seconds) {\n  return Singleton<ForeignLockHelper>::Get()->WithScopedRelease([&, this]() -> Maybe<void> {\n    std::chrono::duration<size_t> seconds(timeout_seconds);\n    std::unique_lock<std::mutex> lock(mutex_);\n    CHECK_OR_RETURN(cond_.wait_for(lock, seconds, [this]() {\n      return notified_cnt_ > 0 || is_closed_;\n    })) << Error::TimeoutError();\n    CHECK_GT_OR_RETURN(notified_cnt_, 0) << \"notifier closed.\";\n    notified_cnt_ = 0;\n    return Maybe<void>::Ok();\n  });\n}\n\nMaybe<void> Notifier::TimedWaitAndClearNotifiedCnt(\n    const std::function<Maybe<bool>()>& StopWaitingAfterTimeout) {\n  while (true) {\n    auto status = TRY(TimedWaitAndClearNotifiedCnt(EnvInteger<ONEFLOW_TIMEOUT_SECONDS>()));\n    if (status.IsOk()) { return status; }\n    if (!status.error()->has_timeout_error()) { return status; }\n    if (JUST(StopWaitingAfterTimeout())) { return status; }\n  }\n  UNIMPLEMENTED_THEN_RETURN();\n}\n\nvoid Notifier::Close() {\n  std::unique_lock<std::mutex> lock(mutex_);\n  is_closed_ = true;\n  cond_.notify_all();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/notifier.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_NOTIFIER_H_\n#define ONEFLOW_CORE_COMMON_NOTIFIER_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nenum NotifierStatus { kNotifierStatusSuccess = 0, kNotifierStatusErrorClosed };\n\nclass Notifier final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Notifier);\n  Notifier() : notified_cnt_(0), is_closed_(false) {}\n  ~Notifier() = default;\n\n  NotifierStatus Notify();\n  NotifierStatus WaitAndClearNotifiedCnt();\n  void Close();\n\n  Maybe<void> TimedWaitAndClearNotifiedCnt(size_t timeout_seconds);\n  Maybe<void> TimedWaitAndClearNotifiedCnt(\n      const std::function<Maybe<bool>()>& StopWaitingAfterTimeout);\n\n private:\n  size_t notified_cnt_;\n  std::mutex mutex_;\n  bool is_closed_;\n  std::condition_variable cond_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_NOTIFIER_H_\n"
  },
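A minimal sketch of the intended handshake (thread roles and `Demo` are illustrative): producers call `Notify()` once per event, a consumer drains batches with `WaitAndClearNotifiedCnt()` — one wake-up may cover several notifications, since the pending count is cleared wholesale — and `Close()` releases any waiter with `kNotifierStatusErrorClosed`.

```cpp
#include <thread>
#include "oneflow/core/common/notifier.h"

void Demo() {
  oneflow::Notifier notifier;
  std::thread consumer([&] {
    // Exits once Close() has been called and all pending counts are drained.
    while (notifier.WaitAndClearNotifiedCnt() == oneflow::kNotifierStatusSuccess) {
      // process everything that is currently pending
    }
  });
  notifier.Notify();
  notifier.Notify();  // may be coalesced with the first into one wake-up
  notifier.Close();
  consumer.join();
}
```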
  {
    "path": "oneflow/core/common/of_unused.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_OF_UNUSED_H_\n#define ONEFLOW_CORE_COMMON_OF_UNUSED_H_\n\nnamespace oneflow {\n\n#define OF_UNUSED(x) (void)(x)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_OF_UNUSED_H_\n"
  },
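Typical use is silencing unused-variable warnings for values that are only consumed conditionally, e.g. (illustrative):

```cpp
#include "oneflow/core/common/of_unused.h"

void Demo(int debug_token) {
  OF_UNUSED(debug_token);  // referenced only in debug builds
}
```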
  {
    "path": "oneflow/core/common/op_args_reserved_size.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_OP_ARGS_RESERVED_SIZE_H_\n#define ONEFLOW_CORE_COMMON_OP_ARGS_RESERVED_SIZE_H_\n\nnamespace oneflow {\n\nconstexpr static int kOpArgsReservedSize = 4;\n\n}\n\n#endif  // ONEFLOW_CORE_COMMON_OP_ARGS_RESERVED_SIZE_H_\n"
  },
  {
    "path": "oneflow/core/common/op_args_vector.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_OP_ARGS_VECTOR_H_\n#define ONEFLOW_CORE_COMMON_OP_ARGS_VECTOR_H_\n\n#include \"oneflow/core/common/small_vector.h\"\n#include \"oneflow/core/common/op_args_reserved_size.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nusing OpArgsVector = small_vector<T>;\n\n}\n\n#endif  // ONEFLOW_CORE_COMMON_OP_ARGS_VECTOR_H_\n"
  },
  {
    "path": "oneflow/core/common/optional.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_COMMON_OPTIONAL_H_\n#define ONEFLOW_CORE_COMMON_OPTIONAL_H_\n\n#include <memory>\n#include <type_traits>\n#include <utility>\n#include \"oneflow/core/common/error.pb.h\"\n#include \"oneflow/core/common/type_traits.h\"\n#include \"oneflow/core/common/just.h\"\n\nnamespace oneflow {\n\nstruct InPlaceConstructType {\n  explicit InPlaceConstructType() = default;\n};\nconstexpr InPlaceConstructType InPlaceConstruct{};\n\nstruct NullOptType {\n  explicit constexpr NullOptType(int) {}\n};\nconstexpr NullOptType NullOpt{0};\n\nnamespace internal {\n\ntemplate<typename T, typename U = void>\nclass OptionalBase;\n\ntemplate<typename T>\nclass OptionalBase<T, typename std::enable_if<IsScalarType<T>::value>::type> {\n public:\n  using value_type = T;\n  using storage_type = T;\n\n  OptionalBase() : init_(false), value_() {}\n  ~OptionalBase() = default;\n\n  explicit OptionalBase(const T& value) : init_(true), value_(value) {}\n  explicit OptionalBase(T&& value) : init_(true), value_(std::move(value)) {}\n\n  OptionalBase(const OptionalBase& base) : init_(base.init_), value_(base.value_) {}\n  OptionalBase(OptionalBase&& base) noexcept : init_(base.init_), value_(std::move(base.value_)) {}\n\n  OptionalBase& operator=(const T& value) {\n    value_ = value;\n    init_ = true;\n\n    return *this;\n  }\n  OptionalBase& operator=(T&& value) {\n    value_ = std::move(value);\n    init_ = true;\n\n    return *this;\n  }\n  OptionalBase& operator=(const OptionalBase& rhs) {\n    value_ = rhs.value_;\n    init_ = rhs.init_;\n\n    return *this;\n  }\n  OptionalBase& operator=(OptionalBase&& rhs) noexcept {\n    value_ = std::move(rhs.value_);\n    init_ = rhs.init_;\n\n    return *this;\n  }\n\n  T value() const& { return value_; }  // `T value() &&` goes here\n  T& value() & { return value_; }\n\n  bool has_value() const { return init_; }\n\n  T value_or(const T& other) const {\n    if (has_value()) {\n      return value();\n    } else {\n      return other;\n    }\n  }\n\n  void reset() { init_ = false; }\n\n private:\n  bool init_;\n  T value_;\n};\n\ntemplate<typename T>\nclass OptionalBase<T, typename std::enable_if<std::is_reference<T>::value>::type> {\n public:\n  using value_type = typename std::remove_reference<T>::type;\n  using storage_type = value_type*;\n\n  static_assert(std::is_lvalue_reference<T>::value, \"rvalue reference is not supported here\");\n\n  OptionalBase() : value_(nullptr){};\n  ~OptionalBase() = default;\n\n  explicit OptionalBase(T value) : value_(&value) {}\n  OptionalBase(const OptionalBase& base) : value_(base.value_) {}\n  OptionalBase(OptionalBase&& base) noexcept : value_(base.value_) {}\n\n  OptionalBase& operator=(T value) {\n    value_ = &value;\n    return *this;\n  }\n  OptionalBase& operator=(const OptionalBase& rhs) {\n    value_ = rhs.value_;\n    return *this;\n  }\n  OptionalBase& operator=(OptionalBase&& rhs) noexcept {\n    
value_ = std::move(rhs.value_);\n    return *this;\n  }\n\n  const value_type& value() const { return *value_; }\n  T value() { return *value_; }\n\n  bool has_value() const { return value_; }\n\n  const value_type& value_or(const value_type& other) const {\n    if (has_value()) {\n      return value();\n    } else {\n      return other;\n    }\n  }\n\n  void reset() { value_ = nullptr; }\n\n private:\n  storage_type value_;\n};\n\ntemplate<typename T>\nclass OptionalBase<\n    T, typename std::enable_if<!IsScalarType<T>::value && !std::is_reference<T>::value>::type> {\n public:\n  using value_type = T;\n  using storage_type = std::shared_ptr<T>;\n\n  OptionalBase() : value_(nullptr){};\n  ~OptionalBase() = default;\n\n  template<typename... Args>\n  explicit OptionalBase(InPlaceConstructType, Args&&... args)\n      : value_(std::make_shared<T>(std::forward<Args>(args)...)) {}\n\n  explicit OptionalBase(const T& value) : value_(std::make_shared<T>(value)) {}\n  explicit OptionalBase(T&& value) : value_(std::make_shared<T>(std::move(value))) {}\n\n  explicit OptionalBase(const storage_type& value) : value_(value) {}\n  explicit OptionalBase(storage_type&& value) : value_(std::move(value)) {}\n\n  OptionalBase(const OptionalBase&) = default;\n  OptionalBase(OptionalBase&&) noexcept = default;\n\n  OptionalBase& operator=(const T& value) {\n    if (value_) {\n      *value_ = value;\n    } else {\n      value_ = std::make_shared<T>(value);\n    }\n    return *this;\n  }\n  OptionalBase& operator=(T&& value) {\n    if (value_) {\n      *value_ = std::move(value);\n    } else {\n      value_ = std::make_shared<T>(std::move(value));\n    }\n    return *this;\n  }\n\n  OptionalBase& operator=(const storage_type& value) {\n    value_ = value;\n    return *this;\n  }\n  OptionalBase& operator=(storage_type&& value) {\n    value_ = std::move(value);\n    return *this;\n  }\n\n  OptionalBase& operator=(const OptionalBase& rhs) {\n    value_ = rhs.value_;\n    return *this;\n  }\n  OptionalBase& operator=(OptionalBase&& rhs) noexcept {\n    value_ = std::move(rhs.value_);\n    return *this;\n  }\n\n  const storage_type& value() const& { return value_; }\n  storage_type& value() & { return value_; }\n\n  storage_type&& value() && { return std::move(value_); }\n\n  bool has_value() const { return bool(value_); }\n\n  const storage_type& value_or(const storage_type& other) const& {\n    if (has_value()) {\n      return value_;\n    } else {\n      return other;\n    }\n  }\n\n  storage_type value_or(const storage_type& other) && {\n    if (has_value()) {\n      return std::move(value_);\n    } else {\n      return other;\n    }\n  }\n\n  storage_type value_or(storage_type&& other) const& {\n    if (has_value()) {\n      return value_;\n    } else {\n      return std::move(other);\n    }\n  }\n\n  storage_type value_or(storage_type&& other) && {\n    if (has_value()) {\n      return std::move(value_);\n    } else {\n      return std::move(other);\n    }\n  }\n\n  // we introduce a dependent name `U` to delay the instantiation,\n  // so only the default parameter of `U` is allowed\n  template<typename U = value_type>\n  typename std::enable_if<!std::is_abstract<U>::value, const U&>::type value_or(\n      const value_type& other) const& {\n    static_assert(std::is_same<U, value_type>::value, \"expected default U\");\n\n    if (has_value()) {\n      return *value_;\n    } else {\n      return other;\n    }\n  }\n\n  template<typename U = value_type>\n  typename std::enable_if<!std::is_abstract<U>::value, 
U>::type value_or(\n      const value_type& other) && {\n    static_assert(std::is_same<U, value_type>::value, \"expected default U\");\n\n    if (has_value()) {\n      return std::move(*value_);\n    } else {\n      return other;\n    }\n  }\n\n  template<typename U = value_type>\n  typename std::enable_if<!std::is_abstract<U>::value, U>::type value_or(\n      value_type&& other) const& {\n    static_assert(std::is_same<U, value_type>::value, \"expected default U\");\n\n    if (has_value()) {\n      return *value_;\n    } else {\n      return std::move(other);\n    }\n  }\n\n  template<typename U = value_type>\n  typename std::enable_if<!std::is_abstract<U>::value, U>::type value_or(value_type&& other) && {\n    static_assert(std::is_same<U, value_type>::value, \"expected default U\");\n\n    if (has_value()) {\n      return std::move(*value_);\n    } else {\n      return std::move(other);\n    }\n  }\n\n  void reset() { value_.reset(); }\n\n private:\n  storage_type value_;\n};\n\ntemplate<typename T>\nstruct IsOptional : std::false_type {};\n\ntemplate<typename T>\nstruct IsOptional<Optional<T>> : std::true_type {};\n\nstruct monadic_operations {\n  template<typename T, typename F>\n  static auto map(T&& opt, F&& f)\n      -> Optional<decltype(std::forward<F>(f)(std::forward<T>(opt).value()))> {\n    if (opt.has_value()) { return std::forward<F>(f)(std::forward<T>(opt).value()); }\n\n    return NullOpt;\n  }\n\n  template<typename T, typename F,\n           typename U = std::decay_t<decltype(std::declval<F>()(std::declval<T>().value()))>>\n  static auto bind(T&& opt, F&& f) -> std::enable_if_t<IsOptional<U>::value, U> {\n    if (opt.has_value()) { return std::forward<F>(f)(std::forward<T>(opt).value()); }\n\n    return NullOpt;\n  }\n\n  template<typename T, typename F,\n           std::enable_if_t<std::is_same<decltype(std::declval<F>()()), void>::value, int> = 0>\n  static auto or_else(T&& opt, F&& f) -> std::decay_t<T> {\n    if (!opt.has_value()) {\n      std::forward<F>(f)();\n      return NullOpt;\n    }\n\n    return std::forward<T>(opt);\n  }\n\n  template<typename T, typename F,\n           std::enable_if_t<\n               std::is_convertible<decltype(std::declval<F>()()), std::decay_t<T>>::value, int> = 0>\n  static auto or_else(T&& opt, F&& f) -> std::decay_t<T> {\n    if (!opt.has_value()) { return std::forward<F>(f)(); }\n\n    return std::forward<T>(opt);\n  }\n};\n\n}  // namespace internal\n\ntemplate<typename T>\nclass Optional final : private internal::OptionalBase<T> {\n private:\n  using base = internal::OptionalBase<T>;\n  using move_value_type = decltype(std::declval<base>().value());\n\n public:\n  using value_type = typename base::value_type;\n  using storage_type = typename base::storage_type;\n\n  explicit Optional() = default;\n  ~Optional() = default;\n\n  Optional(NullOptType)  // NOLINT(google-explicit-constructor)\n      : base() {}\n\n  template<\n      typename Arg1, typename... ArgN,\n      typename std::enable_if<!(sizeof...(ArgN) == 0\n                                && std::is_same<Optional, typename std::decay<Arg1>::type>::value),\n                              int>::type = 0>\n  Optional(Arg1&& v1, ArgN&&... vn)  // NOLINT(google-explicit-constructor)\n      : base(std::forward<Arg1>(v1), std::forward<ArgN>(vn)...) 
{}\n\n  Optional(const Optional&) = default;\n  Optional(Optional&&) noexcept = default;\n\n  template<typename U,\n           typename std::enable_if<!std::is_same<Optional, typename std::decay<U>::type>::value,\n                                   int>::type = 0>\n  Optional& operator=(U&& val) {\n    return static_cast<Optional&>(static_cast<base&>(*this) = std::forward<U>(val));\n  }\n\n  Optional& operator=(const Optional& rhs) = default;\n  Optional& operator=(Optional&& rhs) noexcept = default;\n\n  template<typename U>\n  decltype(auto) value_or(U&& other) const& {\n    return base::value_or(std::forward<U>(other));\n  }\n\n  template<typename U>\n  decltype(auto) value_or(U&& other) && {\n    return std::move(*this).base::value_or(std::forward<U>(other));\n  }\n\n  bool has_value() const { return base::has_value(); }\n  explicit operator bool() const { return has_value(); }\n\n  // generate a temporary object to allow `const auto& x = optval().value()` where `optval()` is a\n  // function call which returns a temporary Optional\n  auto Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() && -> std::conditional_t<\n      std::is_rvalue_reference<move_value_type>::value, std::remove_reference_t<move_value_type>,\n      move_value_type> {\n    return std::move(*this).base::value();\n  }\n\n  friend internal::monadic_operations;\n\n  template<typename F>\n  auto map(F&& f) const& {\n    return internal::monadic_operations::map(*this, std::forward<F>(f));\n  }\n\n  template<typename F>\n  auto map(F&& f) && {\n    return internal::monadic_operations::map(std::move(*this), std::forward<F>(f));\n  }\n\n  template<typename F>\n  auto bind(F&& f) const& {\n    return internal::monadic_operations::bind(*this, std::forward<F>(f));\n  }\n\n  template<typename F>\n  auto bind(F&& f) && {\n    return internal::monadic_operations::bind(std::move(*this), std::forward<F>(f));\n  }\n\n  template<typename F>\n  auto or_else(F&& f) const& {\n    return internal::monadic_operations::or_else(*this, std::forward<F>(f));\n  }\n\n  template<typename F>\n  auto or_else(F&& f) && {\n    return internal::monadic_operations::or_else(std::move(*this), std::forward<F>(f));\n  }\n\n  bool operator==(const Optional& other) const {\n    if (has_value()) {\n      if (other.has_value()) {\n        return base::value() == other.base::value();\n      } else {\n        return false;\n      }\n    } else {\n      return !other.has_value();\n    }\n  }\n\n  bool operator!=(const Optional& other) const { return !operator==(other); }\n\n  void reset() { base::reset(); }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_OPTIONAL_H_\n"
  },
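One property worth spelling out: `storage_type` differs per specialization — scalars are stored inline, `T&` as a pointer, and everything else behind a `std::shared_ptr<T>` — so `value()`/`map()` on a non-scalar `Optional` hands you the shared pointer, not the object. A small sketch mirroring the tests below (`Demo` is illustrative):

```cpp
#include <cassert>
#include <memory>
#include <string>
#include "oneflow/core/common/optional.h"

void Demo() {
  oneflow::Optional<int> a(42), b;  // scalar: value stored inline
  assert(a.value_or(0) == 42);
  assert(b.value_or(7) == 7);  // b is empty

  // Non-scalar payloads live behind a shared_ptr, so map() sees the pointer.
  oneflow::Optional<std::string> s(std::string("abc"));
  auto len = s.map([](const std::shared_ptr<std::string>& p) { return (int)p->size(); });
  assert(len == oneflow::Optional<int>(3));  // an empty input would propagate NullOpt
}
```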
  {
    "path": "oneflow/core/common/optional_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/exception.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(Optional, copy_constructor) {\n  Optional<int64_t> a(0);\n  std::vector<Optional<int64_t>> vec;\n  vec.emplace_back(a);\n  ASSERT_TRUE(vec[0].has_value());\n  int64_t val = CHECK_JUST(vec[0]);\n  ASSERT_EQ(val, 0);\n}\n\nTEST(Optional, move_constructor) {\n  Optional<int64_t> a(0);\n  std::map<int64_t, Optional<int64_t>> map;\n  map.emplace(0, a);\n  ASSERT_TRUE(map.at(0).has_value());\n  int64_t val = CHECK_JUST(map.at(0));\n  ASSERT_EQ(val, 0);\n}\n\nTEST(Optional, JUST) {\n  Optional<int> a(233), b;\n\n  ASSERT_EQ(a.value_or(0), 233);\n  ASSERT_EQ(b.value_or(1), 1);\n\n  auto f = [](const Optional<int>& v) -> Maybe<int> { return JUST(v); };\n\n  ASSERT_EQ(CHECK_JUST(f(a)), 233);\n  ASSERT_EQ(f(b).error()->msg(), \"\");\n\n  auto g = [](const Optional<int>& v) -> Optional<int> { return JUST_OPT(v); };\n\n  ASSERT_EQ(CHECK_JUST(g(a)), 233);\n\n  a = 234;\n  ASSERT_EQ(CHECK_JUST(a), 234);\n\n  b = a;\n  ASSERT_EQ(CHECK_JUST(b), 234);\n\n  b.reset();\n  ASSERT_EQ(b.value_or(1), 1);\n\n  Optional<const int> c(233);\n  ASSERT_EQ(CHECK_JUST(c), 233);\n}\n\nTEST(Optional, reference) {\n  int x = 1, z = 0;\n  Optional<int&> a(x), b;\n\n  x = 2;\n  ASSERT_EQ(CHECK_JUST(a), 2);\n  ASSERT_EQ(b.value_or(z), 0);\n\n  CHECK_JUST(a) = 3;\n  ASSERT_EQ(x, 3);\n\n  Optional<const int&> c(x);\n  ASSERT_EQ(CHECK_JUST(c), 3);\n}\n\nTEST(Optional, non_scalar) {\n  Optional<std::vector<int>> a(InPlaceConstruct, 10), b;\n  CHECK_JUST(a)->at(1) = 1;\n\n  ASSERT_EQ(CHECK_JUST(a)->size(), 10);\n  ASSERT_EQ(CHECK_JUST(a)->at(1), 1);\n\n  auto x = std::make_shared<std::vector<int>>(1);\n  ASSERT_EQ(b.value_or(x), x);\n\n  ASSERT_EQ(b.value_or(std::vector<int>{1, 2, 3}), (std::vector<int>{1, 2, 3}));\n  ASSERT_EQ(b.value_or(*x), *x);\n  ASSERT_EQ(a.value_or(*x), *CHECK_JUST(a));\n\n  ASSERT_EQ(Optional<std::vector<int>>().value_or(*x), *x);\n  ASSERT_EQ(Optional<std::vector<int>>().value_or(std::vector<int>{1, 2, 3}),\n            (std::vector<int>{1, 2, 3}));\n\n  Optional<const std::vector<int>> c(std::vector<int>{1, 2, 3});\n\n  ASSERT_EQ(CHECK_JUST(c)->at(1), 2);\n}\n\nTEST(Optional, optional_just_error_throw) {\n  ASSERT_THROW(  // NOLINT(cppcoreguidelines-avoid-goto)\n      {\n        ([]() -> Maybe<int> {\n          Optional<int> a;\n          return JUST(a);\n        })()\n            .GetOrThrow();\n      },\n      Exception);\n}\n\nTEST(Optional, monadic_operations) {\n  Optional<int> a(1), b, c(2);\n  ASSERT_EQ(a.map([](int x) { return x + 1; }), c);\n  ASSERT_EQ(b.map([](int x) { return x + 1; }), b);\n  ASSERT_EQ(a.map([](int x) { return std::string(x + 1, 'a'); }).map([](const auto& x) {\n    return (int)x->size();\n  }),\n            c);\n  
ASSERT_EQ(a.bind([](int x) -> Optional<float> {\n               if (x < 10) {\n                 return x * 1.1;\n               } else {\n                 return NullOpt;\n               }\n             })\n                .map([](float x) { return x - 1; })\n                .map([](float x) { return std::abs(x - 0.1) < 0.001; }),\n            Optional<bool>(true));\n\n  int x = 0;\n  b.or_else([&] { x++; }).or_else([&] { x *= 2; });\n  ASSERT_EQ(x, 2);\n  ASSERT_EQ(b.or_else([] { return Optional<int>(3); }).map([](int x) { return x - 1; }), c);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/pcheck.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_PCHECK_H_\n#define ONEFLOW_CORE_COMMON_PCHECK_H_\n\n#include <string.h>\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\n#define PCHECK_OR_RETURN(expr)                                             \\\n  for (int __err = (expr), *__cond = nullptr; __cond == nullptr; ++__cond) \\\n  CHECK_EQ_OR_RETURN(__err, 0) << strerror(errno) << \" \"\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_PCHECK_H_\n"
  },
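The `for` wrapper evaluates the expression exactly once into `__err` and runs the `CHECK_EQ_OR_RETURN` body a single time, so the macro behaves like a stream-style check for POSIX calls that return 0 on success and set `errno` otherwise. A hypothetical sketch (`CloseLogFile` is illustrative):

```cpp
#include <unistd.h>
#include "oneflow/core/common/pcheck.h"

namespace oneflow {

Maybe<void> CloseLogFile(int fd) {
  // On failure, returns an error carrying strerror(errno) plus this message.
  PCHECK_OR_RETURN(close(fd)) << "failed to close log file, fd=" << fd;
  return Maybe<void>::Ok();
}

}  // namespace oneflow
```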
  {
    "path": "oneflow/core/common/permutation_iterator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_PERMUTATION_ITERATOR_H_\n#define ONEFLOW_CORE_COMMON_PERMUTATION_ITERATOR_H_\n\n#include <iterator>\n\nnamespace oneflow {\n\n#define ITER_DEVICE_FUNC __host__ __device__ __forceinline__\n\ntemplate<typename T, typename DataIter, typename IndexIter, typename OffsetT = std::ptrdiff_t>\nclass PermutationIterator {\n public:\n  using iterator_category = std::random_access_iterator_tag;\n  using self_type = PermutationIterator;\n  using difference_type = OffsetT;\n  using value_type = T;\n  using pointer = T*;\n  using reference = T&;\n\n  ITER_DEVICE_FUNC PermutationIterator(DataIter data_iter, IndexIter index_iter)\n      : data_iter_(data_iter), index_iter_(index_iter) {}\n\n  // const methods\n\n  ITER_DEVICE_FUNC bool operator==(const PermutationIterator& rhs) const {\n    return index_iter_ == rhs.index_iter_ && data_iter_ == rhs.data_iter_;\n  }\n\n  ITER_DEVICE_FUNC bool operator!=(const PermutationIterator& rhs) const { return !(*this == rhs); }\n\n  template<typename Int>\n  ITER_DEVICE_FUNC PermutationIterator operator+(Int n) const {\n    return PermutationIterator(data_iter_, index_iter_ + n);\n  }\n\n  template<typename Int>\n  ITER_DEVICE_FUNC PermutationIterator operator-(Int n) const {\n    return PermutationIterator(data_iter_, index_iter_ - n);\n  }\n\n  ITER_DEVICE_FUNC difference_type operator-(PermutationIterator other) const {\n    return index_iter_ - other.index_iter_;\n  }\n\n  ITER_DEVICE_FUNC pointer operator->() const { return &data_iter_[*index_iter_]; }\n\n  ITER_DEVICE_FUNC reference operator*() const { return data_iter_[*index_iter_]; }\n\n  template<typename Int>\n  ITER_DEVICE_FUNC reference operator[](Int n) const {\n    return data_iter_[index_iter_[n]];\n  }\n\n  // mutable methods\n\n  ITER_DEVICE_FUNC PermutationIterator operator++(int) {\n    PermutationIterator ret = *this;\n    index_iter_++;\n    return ret;\n  }\n\n  ITER_DEVICE_FUNC PermutationIterator operator++() {\n    index_iter_++;\n    return *this;\n  }\n\n  ITER_DEVICE_FUNC PermutationIterator operator--(int) {\n    PermutationIterator ret = *this;\n    index_iter_--;\n    return ret;\n  }\n\n  ITER_DEVICE_FUNC PermutationIterator operator--() {\n    index_iter_--;\n    return *this;\n  }\n\n  template<typename Int>\n  ITER_DEVICE_FUNC PermutationIterator& operator+=(Int n) {\n    index_iter_ += n;\n    return *this;\n  }\n\n  template<typename Int>\n  ITER_DEVICE_FUNC PermutationIterator& operator-=(Int n) {\n    index_iter_ -= n;\n    return *this;\n  }\n\n  ITER_DEVICE_FUNC pointer operator->() { return &data_iter_[*index_iter_]; }\n\n  ITER_DEVICE_FUNC reference operator*() { return data_iter_[*index_iter_]; }\n\n  template<typename Int>\n  ITER_DEVICE_FUNC reference operator[](Int n) {\n    return data_iter_[index_iter_[n]];\n  }\n\n private:\n  DataIter data_iter_;\n  IndexIter index_iter_;\n};\n\n#undef ITER_DEVICE_FUNC\n\n}  // namespace 
oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_PERMUTATION_ITERATOR_H_\n"
  },
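Element `i` of the permuted view is `data[index[i]]`, so the iterator expresses a gather without materializing it. A hypothetical CUDA sketch (nvcc assumed, since the operators are `__host__ __device__`; `Gather`/`LaunchGather` are illustrative names) — note `T` is `const float` here so that `reference` stays const-correct over a `const float*` data iterator:

```cpp
#include "oneflow/core/common/permutation_iterator.h"

using GatherIter = oneflow::PermutationIterator<const float, const float*, const int*>;

__global__ void Gather(GatherIter src, float* dst, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { dst[i] = src[i]; }  // reads data[index[i]]
}

void LaunchGather(const float* d_data, const int* d_index, float* d_dst, int n) {
  Gather<<<(n + 255) / 256, 256>>>(GatherIter(d_data, d_index), d_dst, n);
}
```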
  {
    "path": "oneflow/core/common/platform.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_PLATFORM_H_\n#define ONEFLOW_CORE_COMMON_PLATFORM_H_\n\n// Set one OF_PLATFORM_* macro and set OF_IS_MOBILE_PLATFORM if the platform is for\n// mobile.\n\n#if !defined(OF_PLATFORM_POSIX) && !defined(OF_PLATFORM_GOOGLE)                    \\\n    && !defined(OF_PLATFORM_POSIX_ANDROID) && !defined(OF_PLATFORM_GOOGLE_ANDROID) \\\n    && !defined(OF_PLATFORM_WINDOWS)\n\n// Choose which platform we are on.\n#if defined(ANDROID) || defined(__ANDROID__)\n#define OF_PLATFORM_POSIX_ANDROID\n#define OF_IS_MOBILE_PLATFORM\n\n#elif defined(__APPLE__)\n#define OF_PLATFORM_POSIX\n#include \"TargetConditionals.h\"\n#if OF_TARGET_IPHONE_SIMULATOR\n#define OF_IS_MOBILE_PLATFORM\n#elif OF_TARGET_OS_IPHONE\n#define OF_IS_MOBILE_PLATFORM\n#endif\n\n#elif defined(_WIN32)\n#define OF_PLATFORM_WINDOWS\n\n#elif defined(__arm__)\n#define OF_PLATFORM_POSIX\n\n// Require an outside macro to tell us if we're building for Raspberry Pi.\n#if !defined(RASPBERRY_PI)\n#define OF_IS_MOBILE_PLATFORM\n#endif  // !defined(RASPBERRY_PI)\n\n#else\n// If no platform specified, use:\n#define OF_PLATFORM_POSIX\n\n#endif\n#endif\n\n// Look for both gcc/clang and Visual Studio macros indicating we're compiling\n// for an x86 device.\n#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64)\n#define OF_PLATFORM_IS_X86\n#endif\n\n#endif  // ONEFLOW_CORE_COMMON_PLATFORM_H_\n"
  },
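Downstream code is expected to branch on the `OF_PLATFORM_*` macros rather than on raw compiler defines; a minimal sketch (`PlatformName` is illustrative):

```cpp
#include "oneflow/core/common/platform.h"

const char* PlatformName() {
#if defined(OF_PLATFORM_WINDOWS)
  return "windows";
#elif defined(OF_PLATFORM_POSIX_ANDROID) || defined(OF_PLATFORM_GOOGLE_ANDROID)
  return "android";
#elif defined(OF_PLATFORM_POSIX)
  return "posix";
#else
  return "unknown";
#endif
}
```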
  {
    "path": "oneflow/core/common/preprocessor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_PREPROCESSOR_H_\n#define ONEFLOW_CORE_COMMON_PREPROCESSOR_H_\n\n#include \"oneflow/core/common/preprocessor_internal.h\"\n\n// basic\n#define OF_PP_CAT(a, b) OF_PP_INTERNAL_CAT(a, b)\n\n#define OF_PP_STRINGIZE(...) OF_PP_INTERNAL_STRINGIZE(__VA_ARGS__)\n\n#define OF_PP_PAIR_FIRST(pair) OF_PP_INTERNAL_PAIR_FIRST(pair)\n\n#define OF_PP_PAIR_SECOND(pair) OF_PP_INTERNAL_PAIR_SECOND(pair)\n\n#define OF_PP_PAIR_THIRD(pair) OF_PP_INTERNAL_PAIR_THIRD(pair)\n\n#define OF_PP_TUPLE_SIZE(t) OF_PP_INTERNAL_TUPLE_SIZE(t)\n\n#define OF_PP_TUPLE_ELEM(n, t) OF_PP_INTERNAL_TUPLE_ELEM(n, t)\n\n#define OF_PP_MAKE_TUPLE_SEQ(...) OF_PP_INTERNAL_MAKE_TUPLE_SEQ(__VA_ARGS__)\n\n#define OF_PP_FOR_EACH_TUPLE(macro, seq) OF_PP_INTERNAL_FOR_EACH_TUPLE(macro, seq)\n\n#define OF_PP_OUTTER_FOR_EACH_TUPLE(macro, seq) OF_PP_INTERNAL_OUTTER_FOR_EACH_TUPLE(macro, seq)\n\n#define OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(macro, ...) \\\n  OF_PP_INTERNAL_SEQ_PRODUCT_FOR_EACH_TUPLE(macro, __VA_ARGS__)\n\n// advanced\n\n#define OF_PP_VARIADIC_SIZE(...) OF_PP_INTERNAL_VARIADIC_SIZE(__VA_ARGS__)\n\n#define OF_PP_SEQ_SIZE(seq) OF_PP_INTERNAL_SEQ_SIZE(seq)\n\n#define OF_PP_ATOMIC_TO_TUPLE(x) (x)\n\n#define OF_PP_FOR_EACH_ATOMIC(macro, seq) \\\n  OF_PP_FOR_EACH_TUPLE(macro, OF_PP_SEQ_MAP(OF_PP_ATOMIC_TO_TUPLE, seq))\n\n#define OF_PP_SEQ_PRODUCT(seq0, ...) OF_PP_INTERNAL_SEQ_PRODUCT(seq0, __VA_ARGS__)\n\n#define OF_PP_SEQ_MAP(macro, seq) \\\n  OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OF_PP_I_SEQ_MAP_DO_EACH, (macro), seq)\n#define OF_PP_I_SEQ_MAP_DO_EACH(macro, elem) (macro(elem))\n\n#define OF_PP_JOIN(glue, ...) OF_PP_INTERNAL_JOIN(glue, __VA_ARGS__)\n\n#define OF_PP_TUPLE_PUSH_FRONT(t, x) OF_PP_INTERNAL_TUPLE_PUSH_FRONT(t, x)\n\n#define OF_PP_FORCE(...) OF_PP_TUPLE2VARADIC(OF_PP_CAT((__VA_ARGS__), ))\n\n#endif  // ONEFLOW_CORE_COMMON_PREPROCESSOR_H_\n"
  },
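The usual pattern pairs a tuple-seq with a generator macro: `OF_PP_FOR_EACH_TUPLE` stamps the macro once per tuple, and `OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE` does the same over a cartesian product of seqs. A hypothetical sketch (`DEMO_TYPE_SEQ` and the generator macros are illustrative, not part of the repo):

```cpp
#include "oneflow/core/common/preprocessor.h"

#define DEMO_TYPE_SEQ            \
  OF_PP_MAKE_TUPLE_SEQ(float, 1) \
  OF_PP_MAKE_TUPLE_SEQ(double, 2)

// Expands to: void Print_1(float v); void Print_2(double v);
#define DECLARE_PRINT(cpp_type, id) void Print_##id(cpp_type v);
OF_PP_FOR_EACH_TUPLE(DECLARE_PRINT, DEMO_TYPE_SEQ)
#undef DECLARE_PRINT

// Cartesian product: one Cast overload per (src, dst) pair, 4 in total.
#define DECLARE_CAST(src_pair, dst_pair) \
  void Cast(OF_PP_PAIR_FIRST(src_pair) src, OF_PP_PAIR_FIRST(dst_pair)* dst);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(DECLARE_CAST, DEMO_TYPE_SEQ, DEMO_TYPE_SEQ)
#undef DECLARE_CAST
```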
  {
    "path": "oneflow/core/common/preprocessor_internal.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_PREPROCESSOR_INTERNAL_H_\n#define ONEFLOW_CORE_COMMON_PREPROCESSOR_INTERNAL_H_\n\n// Base\n\n#define OF_PP_TUPLE2VARADIC(t) OF_PP_TUPLE2VARADIC_I(t)\n\n#define OF_PP_TUPLE2VARADIC_I(t) OF_PP_TUPLE2VARADIC_II t\n\n#define OF_PP_TUPLE2VARADIC_II(...) __VA_ARGS__\n\n#define OF_PP_INTERNAL_STRINGIZE(...) OF_PP_INTERNAL_STRINGIZE_I(__VA_ARGS__)\n#define OF_PP_INTERNAL_STRINGIZE_I(...) #__VA_ARGS__\n\n#define OF_PP_INTERNAL_CAT(a, b) OF_PP_INTERNAL_CAT_I(a, b)\n#define OF_PP_INTERNAL_CAT_I(a, b) a##b\n\n#define OF_PP_INTERNAL_JOIN(glue, ...)                                                     \\\n  OF_PP_INTERNAL_CAT(                                                                      \\\n      OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_JOIN_, OF_PP_INTERNAL_VARIADIC_SIZE(__VA_ARGS__))( \\\n          glue, __VA_ARGS__), )\n\n#define OF_PP_INTERNAL_JOIN_0(glue)\n#define OF_PP_INTERNAL_JOIN_1(glue, x) x\n#define OF_PP_INTERNAL_JOIN_2(glue, x, ...) \\\n  OF_PP_INTERNAL_CAT(                       \\\n      OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_1(glue, __VA_ARGS__)), )\n#define OF_PP_INTERNAL_JOIN_3(glue, x, ...) \\\n  OF_PP_INTERNAL_CAT(                       \\\n      OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_2(glue, __VA_ARGS__)), )\n#define OF_PP_INTERNAL_JOIN_4(glue, x, ...) \\\n  OF_PP_INTERNAL_CAT(                       \\\n      OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_3(glue, __VA_ARGS__)), )\n#define OF_PP_INTERNAL_JOIN_5(glue, x, ...) \\\n  OF_PP_INTERNAL_CAT(                       \\\n      OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_4(glue, __VA_ARGS__)), )\n#define OF_PP_INTERNAL_JOIN_6(glue, x, ...) \\\n  OF_PP_INTERNAL_CAT(                       \\\n      OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_5(glue, __VA_ARGS__)), )\n#define OF_PP_INTERNAL_JOIN_7(glue, x, ...) \\\n  OF_PP_INTERNAL_CAT(                       \\\n      OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_6(glue, __VA_ARGS__)), )\n#define OF_PP_INTERNAL_JOIN_8(glue, x, ...) \\\n  OF_PP_INTERNAL_CAT(                       \\\n      OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_7(glue, __VA_ARGS__)), )\n#define OF_PP_INTERNAL_JOIN_9(glue, x, ...) \\\n  OF_PP_INTERNAL_CAT(                       \\\n      OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_8(glue, __VA_ARGS__)), )\n#define OF_PP_INTERNAL_JOIN_10(glue, x, ...) \\\n  OF_PP_INTERNAL_CAT(                        \\\n      OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), OF_PP_INTERNAL_JOIN_9(glue, __VA_ARGS__)), )\n#define OF_PP_INTERNAL_JOIN_11(glue, x, ...)                         
\\\n  OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), \\\n                                        OF_PP_INTERNAL_JOIN_10(glue, __VA_ARGS__)), )\n#define OF_PP_INTERNAL_JOIN_12(glue, x, ...)                         \\\n  OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), \\\n                                        OF_PP_INTERNAL_JOIN_11(glue, __VA_ARGS__)), )\n#define OF_PP_INTERNAL_JOIN_13(glue, x, ...)                         \\\n  OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), \\\n                                        OF_PP_INTERNAL_JOIN_12(glue, __VA_ARGS__)), )\n#define OF_PP_INTERNAL_JOIN_14(glue, x, ...)                         \\\n  OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), \\\n                                        OF_PP_INTERNAL_JOIN_13(glue, __VA_ARGS__)), )\n#define OF_PP_INTERNAL_JOIN_15(glue, x, ...)                         \\\n  OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(x, glue), \\\n                                        OF_PP_INTERNAL_JOIN_14(glue, __VA_ARGS__)), )\n\n#define OF_PP_INTERNAL_SEQ_HEAD(seq) OF_PP_INTERNAL_PAIR_FIRST(OF_PP_INTERNAL_SEQ_TO_PAIR(seq))\n#define OF_PP_INTERNAL_SEQ_TAIL(seq) OF_PP_INTERNAL_PAIR_SECOND(OF_PP_INTERNAL_SEQ_TO_PAIR(seq))\n\n#define OF_PP_INTERNAL_SEQ_TO_PAIR(seq) (OF_PP_INTERNAL_SEQ_TO_PAIR_ seq)\n#define OF_PP_INTERNAL_SEQ_TO_PAIR_(x) x, OF_PP_INTERNAL_NIL\n#define OF_PP_INTERNAL_NIL\n\n#define OF_PP_INTERNAL_PAIR_FIRST(t) OF_PP_INTERNAL_PAIR_FIRST_I(t)\n#define OF_PP_INTERNAL_PAIR_FIRST_I(t) OF_PP_INTERNAL_FIRST_ARG t\n#define OF_PP_INTERNAL_PAIR_SECOND(t) OF_PP_INTERNAL_PAIR_SECOND_I(t)\n#define OF_PP_INTERNAL_PAIR_SECOND_I(t) OF_PP_INTERNAL_SECOND_ARG t\n#define OF_PP_INTERNAL_PAIR_THIRD(t) OF_PP_INTERNAL_PAIR_THIRD_I(t)\n#define OF_PP_INTERNAL_PAIR_THIRD_I(t) OF_PP_INTERNAL_THIRD_ARG t\n\n#define OF_PP_INTERNAL_FIRST_ARG(x, ...) x\n#define OF_PP_INTERNAL_SECOND_ARG(x, y, ...) y\n#define OF_PP_INTERNAL_THIRD_ARG(x, y, z, ...) z\n\n#define OF_PP_INTERNAL_MAKE_TUPLE(...) (__VA_ARGS__)\n#define OF_PP_INTERNAL_MAKE_TUPLE_SEQ(...) 
(OF_PP_INTERNAL_MAKE_TUPLE(__VA_ARGS__))\n\n// Tuple\n\n#define OF_PP_INTERNAL_TUPLE_PUSH_FRONT(tuple, x)                                        \\\n  OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_TUPLE_PUSH_FRONT_, OF_PP_INTERNAL_TUPLE_SIZE(tuple)) \\\n  (tuple, x)\n\n#define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_0(tuple, x) (x)\n#define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_1(tuple, x) (x, OF_PP_INTERNAL_TUPLE_ELEM(0, tuple))\n#define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_2(tuple, x) \\\n  (x, OF_PP_INTERNAL_TUPLE_ELEM(0, tuple), OF_PP_INTERNAL_TUPLE_ELEM(1, tuple))\n#define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_3(tuple, x)                             \\\n  (x, OF_PP_INTERNAL_TUPLE_ELEM(0, tuple), OF_PP_INTERNAL_TUPLE_ELEM(1, tuple), \\\n   OF_PP_INTERNAL_TUPLE_ELEM(2, tuple))\n#define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_4(tuple, x)                             \\\n  (x, OF_PP_INTERNAL_TUPLE_ELEM(0, tuple), OF_PP_INTERNAL_TUPLE_ELEM(1, tuple), \\\n   OF_PP_INTERNAL_TUPLE_ELEM(2, tuple), OF_PP_INTERNAL_TUPLE_ELEM(3, tuple))\n#define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_5(tuple, x)                             \\\n  (x, OF_PP_INTERNAL_TUPLE_ELEM(0, tuple), OF_PP_INTERNAL_TUPLE_ELEM(1, tuple), \\\n   OF_PP_INTERNAL_TUPLE_ELEM(2, tuple), OF_PP_INTERNAL_TUPLE_ELEM(3, tuple),    \\\n   OF_PP_INTERNAL_TUPLE_ELEM(4, tuple))\n#define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_6(tuple, x)                             \\\n  (x, OF_PP_INTERNAL_TUPLE_ELEM(0, tuple), OF_PP_INTERNAL_TUPLE_ELEM(1, tuple), \\\n   OF_PP_INTERNAL_TUPLE_ELEM(2, tuple), OF_PP_INTERNAL_TUPLE_ELEM(3, tuple),    \\\n   OF_PP_INTERNAL_TUPLE_ELEM(4, tuple), OF_PP_INTERNAL_TUPLE_ELEM(5, tuple))\n#define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_7(tuple, x)                             \\\n  (x, OF_PP_INTERNAL_TUPLE_ELEM(0, tuple), OF_PP_INTERNAL_TUPLE_ELEM(1, tuple), \\\n   OF_PP_INTERNAL_TUPLE_ELEM(2, tuple), OF_PP_INTERNAL_TUPLE_ELEM(3, tuple),    \\\n   OF_PP_INTERNAL_TUPLE_ELEM(4, tuple), OF_PP_INTERNAL_TUPLE_ELEM(5, tuple),    \\\n   OF_PP_INTERNAL_TUPLE_ELEM(6, tuple))\n#define OF_PP_INTERNAL_TUPLE_PUSH_FRONT_8(tuple, x)                             \\\n  (x, OF_PP_INTERNAL_TUPLE_ELEM(0, tuple), OF_PP_INTERNAL_TUPLE_ELEM(1, tuple), \\\n   OF_PP_INTERNAL_TUPLE_ELEM(2, tuple), OF_PP_INTERNAL_TUPLE_ELEM(3, tuple),    \\\n   OF_PP_INTERNAL_TUPLE_ELEM(4, tuple), OF_PP_INTERNAL_TUPLE_ELEM(5, tuple),    \\\n   OF_PP_INTERNAL_TUPLE_ELEM(6, tuple), OF_PP_INTERNAL_TUPLE_ELEM(7, tuple))\n\n#define OF_PP_INTERNAL_TUPLE_ELEM(n, t) OF_PP_INTERNAL_TUPLE_ELEM_I(n, t)\n#define OF_PP_INTERNAL_TUPLE_ELEM_I(n, t) \\\n  OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_ARG_, n) t, )\n\n#define OF_PP_INTERNAL_ARG_0(a0, ...) a0\n#define OF_PP_INTERNAL_ARG_1(a0, a1, ...) a1\n#define OF_PP_INTERNAL_ARG_2(a0, a1, a2, ...) a2\n#define OF_PP_INTERNAL_ARG_3(a0, a1, a2, a3, ...) a3\n#define OF_PP_INTERNAL_ARG_4(a0, a1, a2, a3, a4, ...) a4\n#define OF_PP_INTERNAL_ARG_5(a0, a1, a2, a3, a4, a5, ...) a5\n#define OF_PP_INTERNAL_ARG_6(a0, a1, a2, a3, a4, a5, a6, ...) a6\n#define OF_PP_INTERNAL_ARG_7(a0, a1, a2, a3, a4, a5, a6, a7, ...) a7\n\n#define OF_PP_INTERNAL_TUPLE_SIZE(tuple)                                               \\\n  OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_TUPLE_SIZE_, OF_PP_INTERNAL_IS_TUPLE_EMPTY(tuple)) \\\n  (tuple)\n\n#define OF_PP_INTERNAL_TUPLE_SIZE_1(t) 0\n#define OF_PP_INTERNAL_TUPLE_SIZE_0(t) OF_PP_INTERNAL_TUPLE_SIZE_0_I(t)\n#define OF_PP_INTERNAL_TUPLE_SIZE_0_I(t) OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_VARIADIC_SIZE t, )\n\n#define OF_PP_INTERNAL_VARIADIC_SIZE(...)                                               
           \\\n  OF_PP_INTERNAL_CAT(                                                                              \\\n      OF_PP_INTERNAL_VARIADIC_SIZE_I(                                                              \\\n          __VA_ARGS__, 64, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, \\\n          45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24,  \\\n          23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, ), )\n#define OF_PP_INTERNAL_VARIADIC_SIZE_I(                                                            \\\n    e0, e1, e2, e3, e4, e5, e6, e7, e8, e9, e10, e11, e12, e13, e14, e15, e16, e17, e18, e19, e20, \\\n    e21, e22, e23, e24, e25, e26, e27, e28, e29, e30, e31, e32, e33, e34, e35, e36, e37, e38, e39, \\\n    e40, e41, e42, e43, e44, e45, e46, e47, e48, e49, e50, e51, e52, e53, e54, e55, e56, e57, e58, \\\n    e59, e60, e61, e62, e63, size, ...)                                                            \\\n  size\n\n#define OF_PP_INTERNAL_IS_TUPLE_EMPTY(t) OF_PP_INTERNAL_IS_TUPLE_EMPTY_I(t)\n#define OF_PP_INTERNAL_IS_TUPLE_EMPTY_I(t) OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_IS_VARIADIC_EMPTY t, )\n\n#define OF_PP_INTERNAL_IS_VARIADIC_EMPTY(...)                                                 \\\n  OF_PP_INTERNAL_IS_VARIADIC_EMPTY_(/* test if there is just one argument,                    \\\n                              eventually an empty one */                                      \\\n                                    OF_PP_INTERNAL_VARIADIC_HAS_COMMA(                        \\\n                                        __VA_ARGS__), /* test if                              \\\n                                                         _OF_PP_INTERNAL_TRIGGER_PARENTHESIS_ \\\n                                                         together with the                    \\\n                                                         argument adds a comma                \\\n                                                       */                                     \\\n                                    OF_PP_INTERNAL_VARIADIC_HAS_COMMA(                        \\\n                                        _OF_PP_INTERNAL_TRIGGER_PARENTHESIS_                  \\\n                                            __VA_ARGS__), /* test if the                      \\\n                                                             argument together                \\\n                                                             with a                           \\\n                                                             parenthesis adds                 \\\n                                                             a comma                          \\\n                                                           */                                 \\\n                                    OF_PP_INTERNAL_VARIADIC_HAS_COMMA(__VA_ARGS__(            \\\n                                        /*empty*/)), /* test if placing it                    \\\n                                                        between                               \\\n                                                        _OF_PP_INTERNAL_TRIGGER_PARENTHESIS_  \\\n                                                        and the                               \\\n                                                        parenthesis adds a                    \\\n                                                        comma */      
                        \\\n                                    OF_PP_INTERNAL_VARIADIC_HAS_COMMA(                        \\\n                                        _OF_PP_INTERNAL_TRIGGER_PARENTHESIS_ __VA_ARGS__(     \\\n                                            /*empty*/)))\n\n#define OF_PP_INTERNAL_IS_VARIADIC_EMPTY_(e0, e1, e2, e3) \\\n  OF_PP_INTERNAL_VARIADIC_HAS_COMMA(                      \\\n      OF_PP_INTERNAL_CAT5(OF_PP_INTERNAL_IS_EMPTY_CASE_, e0, e1, e2, e3))\n\n#define OF_PP_INTERNAL_VARIADIC_HAS_COMMA(...)                                                    \\\n  OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_VARIADIC_HAS_COMMA_I(                                         \\\n                         __VA_ARGS__, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, \\\n                         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  \\\n                         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0), )\n#define OF_PP_INTERNAL_VARIADIC_HAS_COMMA_I(                                                       \\\n    e0, e1, e2, e3, e4, e5, e6, e7, e8, e9, e10, e11, e12, e13, e14, e15, e16, e17, e18, e19, e20, \\\n    e21, e22, e23, e24, e25, e26, e27, e28, e29, e30, e31, e32, e33, e34, e35, e36, e37, e38, e39, \\\n    e40, e41, e42, e43, e44, e45, e46, e47, e48, e49, e50, e51, e52, e53, e54, e55, e56, e57, e58, \\\n    e59, e60, e61, e62, e63, has_comma, ...)                                                       \\\n  has_comma\n\n#define _OF_PP_INTERNAL_TRIGGER_PARENTHESIS_(...) ,\n\n#define OF_PP_INTERNAL_CAT5(e0, e1, e2, e3, e4) e0##e1##e2##e3##e4\n#define OF_PP_INTERNAL_IS_EMPTY_CASE_0001 ,\n\n// Seq Product\n\n#define OF_PP_INTERNAL_SEQ_PRODUCT_FOR_EACH_TUPLE(macro, seq0, ...) \\\n  OF_PP_INTERNAL_SEQ_FOR_EACH_TUPLE(macro, _, OF_PP_INTERNAL_SEQ_PRODUCT(seq0, __VA_ARGS__))\n\n#define OF_PP_INTERNAL_SEQ_PRODUCT(seq0, ...)         \\\n  OF_PP_INTERNAL_CAT(                                 \\\n      OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_SEQ_PRODUCT_, \\\n                         OF_PP_INTERNAL_VARIADIC_SIZE(seq0, __VA_ARGS__))(seq0, __VA_ARGS__), )\n\n#define OF_PP_INTERNAL_SEQ_PRODUCT_0()\n#define OF_PP_INTERNAL_SEQ_PRODUCT_1(seq0) OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ((()), seq0)\n#define OF_PP_INTERNAL_SEQ_PRODUCT_2(seq0, ...) \\\n  OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_1(__VA_ARGS__), seq0)\n#define OF_PP_INTERNAL_SEQ_PRODUCT_3(seq0, ...) \\\n  OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_2(__VA_ARGS__), seq0)\n#define OF_PP_INTERNAL_SEQ_PRODUCT_4(seq0, ...) \\\n  OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_3(__VA_ARGS__), seq0)\n#define OF_PP_INTERNAL_SEQ_PRODUCT_5(seq0, ...) \\\n  OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_4(__VA_ARGS__), seq0)\n#define OF_PP_INTERNAL_SEQ_PRODUCT_6(seq0, ...) \\\n  OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_5(__VA_ARGS__), seq0)\n#define OF_PP_INTERNAL_SEQ_PRODUCT_7(seq0, ...) \\\n  OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_6(__VA_ARGS__), seq0)\n#define OF_PP_INTERNAL_SEQ_PRODUCT_8(seq0, ...) \\\n  OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_7(__VA_ARGS__), seq0)\n#define OF_PP_INTERNAL_SEQ_PRODUCT_9(seq0, ...) \\\n  OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_8(__VA_ARGS__), seq0)\n#define OF_PP_INTERNAL_SEQ_PRODUCT_10(seq0, ...) 
\\\n  OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(OF_PP_INTERNAL_SEQ_PRODUCT_9(__VA_ARGS__), seq0)\n\n// Seq ForEach\n\n#define OF_PP_INTERNAL_OUTTER_FOR_EACH_TUPLE(macro, seq) \\\n  OF_PP_INTERNAL_OUTTER_SEQ_FOR_EACH_TUPLE(macro, _, seq)\n#define OF_PP_INTERNAL_FOR_EACH_TUPLE(macro, seq) OF_PP_INTERNAL_SEQ_FOR_EACH_TUPLE(macro, _, seq)\n#define OF_PP_INTERNAL_TUPLE_SEQ_X_ATOMIC_SEQ(tuple_seq, atomic_seq)       \\\n  OF_PP_INTERNAL_D1_SEQ_FOR_EACH(OF_PP_INTERNAL_D1_APPLY_ATOMIC_WITH_DATA, \\\n                                 OF_PP_INTERNAL_TUPLE_X_ATOMIC_SEQ, atomic_seq, tuple_seq)\n\n#define OF_PP_INTERNAL_TUPLE_X_ATOMIC_SEQ(atomic_seq, tuple)               \\\n  OF_PP_INTERNAL_D2_SEQ_FOR_EACH(OF_PP_INTERNAL_D2_APPLY_ATOMIC_WITH_DATA, \\\n                                 OF_PP_INTERNAL_MAKE_SEQ_TUPLE_PUSH_FRONT, tuple, atomic_seq)\n\n#define OF_PP_INTERNAL_D1_APPLY_ATOMIC_WITH_DATA(m, d, x) m(d, x)\n#define OF_PP_INTERNAL_D2_APPLY_ATOMIC_WITH_DATA(m, d, x) m(d, x)\n\n#define OF_PP_INTERNAL_MAKE_SEQ_TUPLE_PUSH_FRONT(tuple, x) \\\n  (OF_PP_INTERNAL_TUPLE_PUSH_FRONT(tuple, x))\n\n// Seq Size\n\n#define OF_PP_INTERNAL_SEQ_SIZE(seq) OF_PP_INTERNAL_SEQ_SIZE_I(seq)\n#define OF_PP_INTERNAL_SEQ_SIZE_I(seq) \\\n  OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_SEQ_SIZE_, OF_PP_INTERNAL_SEQ_SIZE_0 seq)\n\n#define OF_PP_INTERNAL_OUTTER_SEQ_FOR_EACH_TUPLE OF_PP_INTERNAL_D0_SEQ_FOR_EACH_TUPLE\n#define OF_PP_INTERNAL_SEQ_FOR_EACH_TUPLE OF_PP_INTERNAL_D1_SEQ_FOR_EACH_TUPLE\n\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_TUPLE(m, d, seq) \\\n  OF_PP_INTERNAL_D0_SEQ_FOR_EACH(OF_PP_INTERNAL_D0_APPLY_TUPLE, m, d, seq)\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_TUPLE(m, d, seq) \\\n  OF_PP_INTERNAL_D1_SEQ_FOR_EACH(OF_PP_INTERNAL_APPLY_TUPLE, m, d, seq)\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_TUPLE(m, d, seq) \\\n  OF_PP_INTERNAL_D2_SEQ_FOR_EACH(OF_PP_INTERNAL_APPLY_TUPLE, m, d, seq)\n\n#define OF_PP_INTERNAL_SEQ_FOR_EACH_ATOMIC OF_PP_INTERNAL_D1_SEQ_FOR_EACH_ATOMIC\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_ATOMIC(m, d, seq) \\\n  OF_PP_INTERNAL_D1_SEQ_FOR_EACH(OF_PP_INTERNAL_APPLY_ATOMIC, m, d, seq)\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_ATOMIC(m, d, seq) \\\n  OF_PP_INTERNAL_D2_SEQ_FOR_EACH(OF_PP_INTERNAL_APPLY_ATOMIC, m, d, seq)\n\n#define OF_PP_INTERNAL_D0_APPLY_TUPLE(m, d, t) OF_PP_INTERNAL_D0_APPLY_TUPLE_I(m, d, t)\n#define OF_PP_INTERNAL_D0_APPLY_TUPLE_I(m, d, t) m t\n\n#define OF_PP_INTERNAL_APPLY_TUPLE(m, d, t) OF_PP_INTERNAL_APPLY_TUPLE_I(m, d, t)\n#define OF_PP_INTERNAL_APPLY_TUPLE_I(m, d, t) m t\n#define OF_PP_INTERNAL_APPLY_ATOMIC(m, d, x) m(x)\n#define OF_PP_INTERNAL_APPLY_ATOMIC_WITH_DATA(m, d, x) m(d, x)\n\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH(apply, m, d, seq)                            \\\n  OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_D0_SEQ_FOR_EACH_, OF_PP_INTERNAL_SEQ_SIZE(seq)) \\\n  (apply, m, d, seq)\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH(apply, m, d, seq)                            \\\n  OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_D1_SEQ_FOR_EACH_, OF_PP_INTERNAL_SEQ_SIZE(seq)) \\\n  (apply, m, d, seq)\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH(apply, m, d, seq)                            \\\n  OF_PP_INTERNAL_CAT(OF_PP_INTERNAL_D2_SEQ_FOR_EACH_, OF_PP_INTERNAL_SEQ_SIZE(seq)) \\\n  (apply, m, d, seq)\n\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_0(apply, m, d, seq)\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_0(apply, m, d, seq)\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_0(apply, m, d, seq)\n\n// PHP code to generate the iterator macros below\n// clang-format off\n/*\n<?php $limit = 512; for ($i = 0; $i < $limit; ++$i) {?> \n
#define OF_PP_INTERNAL_SEQ_SIZE_<?= $i?>(_) OF_PP_INTERNAL_SEQ_SIZE_<?= $i + 1?> \n\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_<?= $i?> <?= $i?>\n\n<?php $dim = 2; for ($d = 0; $d <= $dim; ++$d) {?> \n#define OF_PP_INTERNAL_D<?= $d?>_SEQ_FOR_EACH_<?= $i + 1?>(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D<?= $d?>_SEQ_FOR_EACH_<?= $i?>(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n<?php }?> \n\n<?php }?> \n*/\n// clang-format on\n\n// Do not edit the iterator macros directly; they are generated by the PHP code above.\n#define OF_PP_INTERNAL_SEQ_SIZE_0(_) OF_PP_INTERNAL_SEQ_SIZE_1\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_0 0\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_1(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_0(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_1(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_0(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_1(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_0(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_1(_) OF_PP_INTERNAL_SEQ_SIZE_2\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_1 1\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_2(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_1(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_2(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_1(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_2(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_1(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_2(_) OF_PP_INTERNAL_SEQ_SIZE_3\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_2 2\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_3(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_2(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_3(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_2(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_3(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_2(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_3(_) OF_PP_INTERNAL_SEQ_SIZE_4\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_3 3\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_4(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_3(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_4(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_3(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_4(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_3(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_4(_) OF_PP_INTERNAL_SEQ_SIZE_5\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_4 4\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_5(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_4(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_5(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_4(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_5(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_4(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_5(_) OF_PP_INTERNAL_SEQ_SIZE_6\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_5 5\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_6(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_5(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_6(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_5(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_6(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_5(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_6(_) OF_PP_INTERNAL_SEQ_SIZE_7\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_6 6\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_7(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_6(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_7(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_6(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_7(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_6(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_7(_) OF_PP_INTERNAL_SEQ_SIZE_8\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_7 7\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_8(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_7(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_8(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_7(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_8(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_7(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_8(_) OF_PP_INTERNAL_SEQ_SIZE_9\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_8 8\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_9(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_8(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_9(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))  
              \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_8(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_9(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_8(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_9(_) OF_PP_INTERNAL_SEQ_SIZE_10\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_9 9\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_10(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_9(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_10(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_9(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_10(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_9(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_10(_) OF_PP_INTERNAL_SEQ_SIZE_11\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_10 10\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_11(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_10(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_11(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_10(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_11(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_10(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_11(_) OF_PP_INTERNAL_SEQ_SIZE_12\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_11 11\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_12(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_11(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_12(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_11(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_12(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_11(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_12(_) OF_PP_INTERNAL_SEQ_SIZE_13\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_12 12\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_13(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_12(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_13(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_12(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_13(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_12(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_13(_) OF_PP_INTERNAL_SEQ_SIZE_14\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_13 13\n#define 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_14(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_13(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_14(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_13(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_14(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_13(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_14(_) OF_PP_INTERNAL_SEQ_SIZE_15\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_14 14\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_15(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_14(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_15(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_14(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_15(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_14(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_15(_) OF_PP_INTERNAL_SEQ_SIZE_16\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_15 15\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_16(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_15(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_16(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_15(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_16(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_15(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_16(_) OF_PP_INTERNAL_SEQ_SIZE_17\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_16 16\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_17(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_16(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_17(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_16(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_17(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_16(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_17(_) OF_PP_INTERNAL_SEQ_SIZE_18\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_17 17\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_18(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_17(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_18(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_17(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_18(apply, 
m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_17(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_18(_) OF_PP_INTERNAL_SEQ_SIZE_19\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_18 18\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_19(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_18(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_19(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_18(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_19(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_18(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_19(_) OF_PP_INTERNAL_SEQ_SIZE_20\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_19 19\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_20(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_19(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_20(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_19(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_20(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_19(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_20(_) OF_PP_INTERNAL_SEQ_SIZE_21\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_20 20\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_21(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_20(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_21(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_20(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_21(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_20(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_21(_) OF_PP_INTERNAL_SEQ_SIZE_22\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_21 21\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_22(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_21(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_22(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_21(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_22(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_21(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_22(_) OF_PP_INTERNAL_SEQ_SIZE_23\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_22 22\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_23(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_22(apply, 
m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_23(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_22(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_23(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_22(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_23(_) OF_PP_INTERNAL_SEQ_SIZE_24\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_23 23\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_24(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_23(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_24(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_23(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_24(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_23(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_24(_) OF_PP_INTERNAL_SEQ_SIZE_25\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_24 24\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_25(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_24(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_25(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_24(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_25(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_24(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_25(_) OF_PP_INTERNAL_SEQ_SIZE_26\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_25 25\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_26(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_25(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_26(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_25(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_26(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_25(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_26(_) OF_PP_INTERNAL_SEQ_SIZE_27\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_26 26\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_27(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_26(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_27(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_26(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_27(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_26(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_27(_) OF_PP_INTERNAL_SEQ_SIZE_28\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_27 27\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_28(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_27(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_28(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_27(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_28(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_27(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_28(_) OF_PP_INTERNAL_SEQ_SIZE_29\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_28 28\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_29(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_28(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_29(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_28(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_29(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_28(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_29(_) OF_PP_INTERNAL_SEQ_SIZE_30\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_29 29\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_30(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_29(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_30(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_29(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_30(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_29(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_30(_) OF_PP_INTERNAL_SEQ_SIZE_31\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_30 30\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_31(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_30(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_31(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_30(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_31(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_30(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_31(_) OF_PP_INTERNAL_SEQ_SIZE_32\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_31 31\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_32(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_31(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_32(apply, m, d, seq) \\\n  apply(m, d, 
OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_31(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_32(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_31(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_32(_) OF_PP_INTERNAL_SEQ_SIZE_33\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_32 32\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_33(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_32(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_33(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_32(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_33(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_32(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_33(_) OF_PP_INTERNAL_SEQ_SIZE_34\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_33 33\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_34(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_33(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_34(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_33(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_34(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_33(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_34(_) OF_PP_INTERNAL_SEQ_SIZE_35\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_34 34\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_35(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_34(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_35(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_34(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_35(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_34(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_35(_) OF_PP_INTERNAL_SEQ_SIZE_36\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_35 35\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_36(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_35(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_36(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_35(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_36(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_35(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_36(_) OF_PP_INTERNAL_SEQ_SIZE_37\n#define 
OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_36 36\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_37(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_36(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_37(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_36(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_37(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_36(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_37(_) OF_PP_INTERNAL_SEQ_SIZE_38\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_37 37\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_38(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_37(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_38(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_37(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_38(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_37(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_38(_) OF_PP_INTERNAL_SEQ_SIZE_39\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_38 38\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_39(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_38(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_39(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_38(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_39(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_38(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_39(_) OF_PP_INTERNAL_SEQ_SIZE_40\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_39 39\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_40(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_39(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_40(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_39(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_40(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_39(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_40(_) OF_PP_INTERNAL_SEQ_SIZE_41\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_40 40\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_41(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_40(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_41(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_40(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_41(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_40(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_41(_) OF_PP_INTERNAL_SEQ_SIZE_42\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_41 41\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_42(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_41(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_42(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_41(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_42(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_41(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_42(_) OF_PP_INTERNAL_SEQ_SIZE_43\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_42 42\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_43(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_42(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_43(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_42(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_43(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_42(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_43(_) OF_PP_INTERNAL_SEQ_SIZE_44\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_43 43\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_44(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_43(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_44(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_43(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_44(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_43(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_44(_) OF_PP_INTERNAL_SEQ_SIZE_45\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_44 44\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_45(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_44(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_45(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_44(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_45(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_44(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_45(_) OF_PP_INTERNAL_SEQ_SIZE_46\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_45 45\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_46(apply, m, d, seq) \\\n  apply(m, d, 
OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_45(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_46(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_45(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_46(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_45(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_46(_) OF_PP_INTERNAL_SEQ_SIZE_47\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_46 46\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_47(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_46(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_47(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_46(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_47(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_46(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_47(_) OF_PP_INTERNAL_SEQ_SIZE_48\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_47 47\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_48(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_47(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_48(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_47(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_48(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_47(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_48(_) OF_PP_INTERNAL_SEQ_SIZE_49\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_48 48\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_49(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_48(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_49(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_48(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_49(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_48(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_49(_) OF_PP_INTERNAL_SEQ_SIZE_50\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_49 49\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_50(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_49(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_50(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_49(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_50(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))              
   \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_49(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_50(_) OF_PP_INTERNAL_SEQ_SIZE_51\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_50 50\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_51(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_50(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_51(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_50(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_51(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_50(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_51(_) OF_PP_INTERNAL_SEQ_SIZE_52\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_51 51\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_52(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_51(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_52(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_51(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_52(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_51(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_52(_) OF_PP_INTERNAL_SEQ_SIZE_53\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_52 52\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_53(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_52(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_53(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_52(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_53(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_52(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_53(_) OF_PP_INTERNAL_SEQ_SIZE_54\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_53 53\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_54(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_53(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_54(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_53(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_54(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_53(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_54(_) OF_PP_INTERNAL_SEQ_SIZE_55\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_54 54\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_55(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_54(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define 
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_55(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_54(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_55(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_54(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_55(_) OF_PP_INTERNAL_SEQ_SIZE_56\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_55 55\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_56(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_55(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_56(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_55(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_56(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_55(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_56(_) OF_PP_INTERNAL_SEQ_SIZE_57\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_56 56\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_57(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_56(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_57(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_56(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_57(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_56(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_57(_) OF_PP_INTERNAL_SEQ_SIZE_58\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_57 57\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_58(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_57(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_58(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_57(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_58(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_57(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_58(_) OF_PP_INTERNAL_SEQ_SIZE_59\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_58 58\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_59(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_58(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_59(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_58(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_59(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_58(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_59(_) 
OF_PP_INTERNAL_SEQ_SIZE_60\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_59 59\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_60(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_59(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_60(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_59(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_60(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_59(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_60(_) OF_PP_INTERNAL_SEQ_SIZE_61\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_60 60\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_61(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_60(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_61(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_60(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_61(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_60(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_61(_) OF_PP_INTERNAL_SEQ_SIZE_62\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_61 61\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_62(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_61(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_62(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_61(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_62(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_61(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_62(_) OF_PP_INTERNAL_SEQ_SIZE_63\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_62 62\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_63(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_62(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_63(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_62(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_63(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_62(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_63(_) OF_PP_INTERNAL_SEQ_SIZE_64\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_63 63\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_64(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_63(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_64(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      
OF_PP_INTERNAL_D1_SEQ_FOR_EACH_63(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_64(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_63(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_64(_) OF_PP_INTERNAL_SEQ_SIZE_65\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_64 64\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_65(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_64(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_65(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_64(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_65(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_64(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_65(_) OF_PP_INTERNAL_SEQ_SIZE_66\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_65 65\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_66(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_65(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_66(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_65(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_66(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_65(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_66(_) OF_PP_INTERNAL_SEQ_SIZE_67\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_66 66\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_67(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_66(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_67(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_66(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_67(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_66(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_67(_) OF_PP_INTERNAL_SEQ_SIZE_68\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_67 67\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_68(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_67(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_68(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_67(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_68(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_67(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_68(_) OF_PP_INTERNAL_SEQ_SIZE_69\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_68 68\n#define 
OF_PP_INTERNAL_D0_SEQ_FOR_EACH_69(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_68(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_69(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_68(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_69(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_68(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_69(_) OF_PP_INTERNAL_SEQ_SIZE_70\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_69 69\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_70(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_69(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_70(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_69(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_70(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_69(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_70(_) OF_PP_INTERNAL_SEQ_SIZE_71\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_70 70\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_71(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_70(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_71(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_70(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_71(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_70(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_71(_) OF_PP_INTERNAL_SEQ_SIZE_72\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_71 71\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_72(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_71(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_72(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_71(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_72(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_71(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_72(_) OF_PP_INTERNAL_SEQ_SIZE_73\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_72 72\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_73(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_72(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_73(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_72(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_73(apply, 
m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_72(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_73(_) OF_PP_INTERNAL_SEQ_SIZE_74\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_73 73\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_74(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_73(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_74(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_73(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_74(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_73(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_74(_) OF_PP_INTERNAL_SEQ_SIZE_75\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_74 74\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_75(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_74(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_75(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_74(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_75(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_74(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_75(_) OF_PP_INTERNAL_SEQ_SIZE_76\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_75 75\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_76(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_75(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_76(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_75(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_76(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_75(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_76(_) OF_PP_INTERNAL_SEQ_SIZE_77\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_76 76\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_77(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_76(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_77(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_76(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_77(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_76(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_77(_) OF_PP_INTERNAL_SEQ_SIZE_78\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_77 77\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_78(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_77(apply, 
m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_78(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_77(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_78(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_77(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_78(_) OF_PP_INTERNAL_SEQ_SIZE_79\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_78 78\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_79(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_78(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_79(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_78(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_79(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_78(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_79(_) OF_PP_INTERNAL_SEQ_SIZE_80\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_79 79\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_80(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_79(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_80(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_79(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_80(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_79(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_80(_) OF_PP_INTERNAL_SEQ_SIZE_81\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_80 80\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_81(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_80(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_81(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_80(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_81(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_80(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_81(_) OF_PP_INTERNAL_SEQ_SIZE_82\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_81 81\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_82(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_81(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_82(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_81(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_82(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_81(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_82(_) OF_PP_INTERNAL_SEQ_SIZE_83\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_82 82\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_83(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_82(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_83(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_82(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_83(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_82(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_83(_) OF_PP_INTERNAL_SEQ_SIZE_84\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_83 83\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_84(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_83(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_84(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_83(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_84(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_83(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_84(_) OF_PP_INTERNAL_SEQ_SIZE_85\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_84 84\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_85(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_84(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_85(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_84(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_85(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_84(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_85(_) OF_PP_INTERNAL_SEQ_SIZE_86\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_85 85\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_86(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_85(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_86(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_85(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_86(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_85(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_86(_) OF_PP_INTERNAL_SEQ_SIZE_87\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_86 86\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_87(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_86(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_87(apply, m, d, seq) \\\n  apply(m, d, 
OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_86(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_87(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_86(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_87(_) OF_PP_INTERNAL_SEQ_SIZE_88\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_87 87\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_88(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_87(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_88(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_87(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_88(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_87(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_88(_) OF_PP_INTERNAL_SEQ_SIZE_89\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_88 88\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_89(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_88(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_89(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_88(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_89(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_88(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_89(_) OF_PP_INTERNAL_SEQ_SIZE_90\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_89 89\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_90(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_89(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_90(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_89(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_90(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_89(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_90(_) OF_PP_INTERNAL_SEQ_SIZE_91\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_90 90\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_91(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_90(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_91(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_90(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_91(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_90(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_91(_) OF_PP_INTERNAL_SEQ_SIZE_92\n#define 
OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_91 91\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_92(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_91(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_92(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_91(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_92(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_91(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_92(_) OF_PP_INTERNAL_SEQ_SIZE_93\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_92 92\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_93(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_92(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_93(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_92(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_93(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_92(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_93(_) OF_PP_INTERNAL_SEQ_SIZE_94\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_93 93\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_94(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_93(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_94(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_93(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_94(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_93(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_94(_) OF_PP_INTERNAL_SEQ_SIZE_95\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_94 94\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_95(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_94(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_95(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_94(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_95(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_94(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_95(_) OF_PP_INTERNAL_SEQ_SIZE_96\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_95 95\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_96(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_95(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_96(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_95(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_96(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_95(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_96(_) OF_PP_INTERNAL_SEQ_SIZE_97\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_96 96\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_97(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_96(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_97(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_96(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_97(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_96(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_97(_) OF_PP_INTERNAL_SEQ_SIZE_98\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_97 97\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_98(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_97(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_98(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_97(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_98(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_97(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_98(_) OF_PP_INTERNAL_SEQ_SIZE_99\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_98 98\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_99(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_98(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_99(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_98(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_99(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                 \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_98(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_99(_) OF_PP_INTERNAL_SEQ_SIZE_100\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_99 99\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_100(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_99(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_100(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_99(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_100(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_99(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_100(_) OF_PP_INTERNAL_SEQ_SIZE_101\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_100 100\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_101(apply, m, d, seq) \\\n  apply(m, d, 
OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_100(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_101(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_100(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_101(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_100(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_101(_) OF_PP_INTERNAL_SEQ_SIZE_102\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_101 101\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_102(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_101(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_102(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_101(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_102(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_101(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_102(_) OF_PP_INTERNAL_SEQ_SIZE_103\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_102 102\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_103(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_102(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_103(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_102(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_103(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_102(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_103(_) OF_PP_INTERNAL_SEQ_SIZE_104\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_103 103\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_104(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_103(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_104(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_103(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_104(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_103(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_104(_) OF_PP_INTERNAL_SEQ_SIZE_105\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_104 104\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_105(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_104(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_105(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_104(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_105(apply, m, d, seq) 
\\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_104(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_105(_) OF_PP_INTERNAL_SEQ_SIZE_106\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_105 105\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_106(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_105(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_106(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_105(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_106(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_105(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_106(_) OF_PP_INTERNAL_SEQ_SIZE_107\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_106 106\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_107(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_106(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_107(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_106(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_107(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_106(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_107(_) OF_PP_INTERNAL_SEQ_SIZE_108\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_107 107\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_108(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_107(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_108(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_107(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_108(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_107(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_108(_) OF_PP_INTERNAL_SEQ_SIZE_109\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_108 108\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_109(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_108(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_109(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_108(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_109(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_108(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_109(_) OF_PP_INTERNAL_SEQ_SIZE_110\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_109 109\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_110(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  
\\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_109(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_110(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_109(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_110(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_109(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_110(_) OF_PP_INTERNAL_SEQ_SIZE_111\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_110 110\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_111(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_110(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_111(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_110(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_111(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_110(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_111(_) OF_PP_INTERNAL_SEQ_SIZE_112\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_111 111\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_112(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_111(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_112(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_111(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_112(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_111(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_112(_) OF_PP_INTERNAL_SEQ_SIZE_113\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_112 112\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_113(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_112(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_113(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_112(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_113(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_112(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_113(_) OF_PP_INTERNAL_SEQ_SIZE_114\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_113 113\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_114(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_113(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_114(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_113(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_114(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))   
               \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_113(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_114(_) OF_PP_INTERNAL_SEQ_SIZE_115\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_114 114\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_115(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_114(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_115(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_114(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_115(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_114(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_115(_) OF_PP_INTERNAL_SEQ_SIZE_116\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_115 115\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_116(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_115(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_116(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_115(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_116(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_115(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_116(_) OF_PP_INTERNAL_SEQ_SIZE_117\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_116 116\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_117(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_116(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_117(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_116(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_117(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_116(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_117(_) OF_PP_INTERNAL_SEQ_SIZE_118\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_117 117\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_118(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_117(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_118(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_117(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_118(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_117(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_118(_) OF_PP_INTERNAL_SEQ_SIZE_119\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_118 118\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_119(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_118(apply, 
m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_119(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_118(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_119(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_118(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_119(_) OF_PP_INTERNAL_SEQ_SIZE_120\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_119 119\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_120(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_119(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_120(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_119(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_120(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_119(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_120(_) OF_PP_INTERNAL_SEQ_SIZE_121\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_120 120\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_121(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_120(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_121(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_120(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_121(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_120(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_121(_) OF_PP_INTERNAL_SEQ_SIZE_122\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_121 121\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_122(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_121(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_122(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_121(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_122(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_121(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_122(_) OF_PP_INTERNAL_SEQ_SIZE_123\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_122 122\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_123(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_122(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_123(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_122(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_123(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_122(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_123(_) OF_PP_INTERNAL_SEQ_SIZE_124\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_123 123\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_124(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_123(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_124(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_123(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_124(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_123(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_124(_) OF_PP_INTERNAL_SEQ_SIZE_125\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_124 124\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_125(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_124(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_125(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_124(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_125(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_124(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_125(_) OF_PP_INTERNAL_SEQ_SIZE_126\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_125 125\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_126(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_125(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_126(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_125(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_126(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_125(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_126(_) OF_PP_INTERNAL_SEQ_SIZE_127\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_126 126\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_127(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_126(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_127(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_126(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_127(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_126(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_127(_) OF_PP_INTERNAL_SEQ_SIZE_128\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_127 127\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_128(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_127(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_128(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_127(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_128(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_127(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_128(_) OF_PP_INTERNAL_SEQ_SIZE_129\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_128 128\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_129(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_128(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_129(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_128(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_129(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_128(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_129(_) OF_PP_INTERNAL_SEQ_SIZE_130\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_129 129\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_130(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_129(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_130(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_129(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_130(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_129(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_130(_) OF_PP_INTERNAL_SEQ_SIZE_131\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_130 130\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_131(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_130(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_131(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_130(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_131(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_130(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_131(_) OF_PP_INTERNAL_SEQ_SIZE_132\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_131 131\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_132(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_131(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_132(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_131(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_132(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_131(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_132(_) OF_PP_INTERNAL_SEQ_SIZE_133\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_132 132\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_133(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_132(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_133(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_132(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_133(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_132(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_133(_) OF_PP_INTERNAL_SEQ_SIZE_134\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_133 133\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_134(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_133(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_134(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_133(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_134(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_133(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_134(_) OF_PP_INTERNAL_SEQ_SIZE_135\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_134 134\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_135(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_134(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_135(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_134(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_135(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_134(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_135(_) OF_PP_INTERNAL_SEQ_SIZE_136\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_135 135\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_136(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_135(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_136(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_135(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_136(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_135(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_136(_) OF_PP_INTERNAL_SEQ_SIZE_137\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_136 136\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_137(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_136(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_137(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_136(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_137(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_136(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_137(_) OF_PP_INTERNAL_SEQ_SIZE_138\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_137 137\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_138(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_137(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_138(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_137(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_138(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_137(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_138(_) OF_PP_INTERNAL_SEQ_SIZE_139\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_138 138\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_139(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_138(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_139(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_138(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_139(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_138(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_139(_) OF_PP_INTERNAL_SEQ_SIZE_140\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_139 139\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_140(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_139(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_140(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_139(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_140(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_139(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_140(_) OF_PP_INTERNAL_SEQ_SIZE_141\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_140 140\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_141(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_140(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_141(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_140(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_141(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_140(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_141(_) OF_PP_INTERNAL_SEQ_SIZE_142\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_141 141\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_142(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_141(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_142(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_141(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_142(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_141(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_142(_) OF_PP_INTERNAL_SEQ_SIZE_143\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_142 142\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_143(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_142(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_143(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_142(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_143(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_142(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_143(_) OF_PP_INTERNAL_SEQ_SIZE_144\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_143 143\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_144(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_143(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_144(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_143(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_144(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_143(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_144(_) OF_PP_INTERNAL_SEQ_SIZE_145\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_144 144\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_145(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_144(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_145(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_144(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_145(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_144(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_145(_) OF_PP_INTERNAL_SEQ_SIZE_146\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_145 145\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_146(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_145(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_146(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_145(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_146(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_145(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_146(_) OF_PP_INTERNAL_SEQ_SIZE_147\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_146 146\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_147(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_146(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_147(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_146(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_147(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_146(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_147(_) OF_PP_INTERNAL_SEQ_SIZE_148\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_147 147\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_148(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_147(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_148(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_147(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_148(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_147(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_148(_) OF_PP_INTERNAL_SEQ_SIZE_149\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_148 148\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_149(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_148(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_149(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_148(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_149(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_148(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_149(_) OF_PP_INTERNAL_SEQ_SIZE_150\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_149 149\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_150(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_149(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_150(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_149(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_150(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_149(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_150(_) OF_PP_INTERNAL_SEQ_SIZE_151\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_150 150\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_151(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_150(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_151(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_150(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_151(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_150(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_151(_) OF_PP_INTERNAL_SEQ_SIZE_152\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_151 151\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_152(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_151(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_152(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_151(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_152(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_151(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_152(_) OF_PP_INTERNAL_SEQ_SIZE_153\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_152 152\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_153(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_152(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_153(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_152(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_153(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_152(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_153(_) OF_PP_INTERNAL_SEQ_SIZE_154\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_153 153\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_154(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_153(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_154(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_153(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_154(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_153(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_154(_) OF_PP_INTERNAL_SEQ_SIZE_155\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_154 154\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_155(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_154(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_155(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_154(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_155(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_154(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_155(_) OF_PP_INTERNAL_SEQ_SIZE_156\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_155 155\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_156(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_155(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_156(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_155(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_156(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_155(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_156(_) OF_PP_INTERNAL_SEQ_SIZE_157\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_156 156\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_157(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_156(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_157(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_156(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_157(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_156(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_157(_) OF_PP_INTERNAL_SEQ_SIZE_158\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_157 157\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_158(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_157(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_158(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_157(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_158(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_157(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_158(_) OF_PP_INTERNAL_SEQ_SIZE_159\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_158 158\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_159(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_158(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_159(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_158(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_159(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_158(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_159(_) OF_PP_INTERNAL_SEQ_SIZE_160\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_159 159\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_160(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_159(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_160(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_159(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_160(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_159(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_160(_) OF_PP_INTERNAL_SEQ_SIZE_161\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_160 160\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_161(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_160(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_161(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_160(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_161(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_160(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_161(_) OF_PP_INTERNAL_SEQ_SIZE_162\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_161 161\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_162(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_161(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_162(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_161(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_162(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_161(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_162(_) OF_PP_INTERNAL_SEQ_SIZE_163\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_162 162\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_163(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_162(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_163(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_162(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_163(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_162(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_163(_) OF_PP_INTERNAL_SEQ_SIZE_164\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_163 163\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_164(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_163(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_164(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_163(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_164(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_163(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_164(_) OF_PP_INTERNAL_SEQ_SIZE_165\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_164 164\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_165(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_164(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_165(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_164(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_165(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_164(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_165(_) OF_PP_INTERNAL_SEQ_SIZE_166\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_165 165\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_166(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_165(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_166(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_165(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_166(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_165(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_166(_) OF_PP_INTERNAL_SEQ_SIZE_167\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_166 166\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_167(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_166(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_167(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_166(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_167(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_166(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_167(_) OF_PP_INTERNAL_SEQ_SIZE_168\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_167 167\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_168(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_167(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_168(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_167(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_168(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_167(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_168(_) OF_PP_INTERNAL_SEQ_SIZE_169\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_168 168\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_169(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_168(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_169(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_168(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_169(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_168(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_169(_) OF_PP_INTERNAL_SEQ_SIZE_170\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_169 169\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_170(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_169(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_170(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_169(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_170(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_169(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_170(_) OF_PP_INTERNAL_SEQ_SIZE_171\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_170 170\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_171(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_170(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_171(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_170(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_171(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_170(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_171(_) OF_PP_INTERNAL_SEQ_SIZE_172\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_171 171\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_172(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_171(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_172(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_171(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_172(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_171(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_172(_) OF_PP_INTERNAL_SEQ_SIZE_173\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_172 172\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_173(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_172(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_173(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_172(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_173(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_172(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_173(_) OF_PP_INTERNAL_SEQ_SIZE_174\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_173 173\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_174(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_173(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_174(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_173(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_174(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_173(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_174(_) OF_PP_INTERNAL_SEQ_SIZE_175\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_174 174\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_175(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_174(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_175(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_174(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_175(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_174(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_175(_) OF_PP_INTERNAL_SEQ_SIZE_176\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_175 175\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_176(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_175(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_176(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_175(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_176(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_175(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_176(_) OF_PP_INTERNAL_SEQ_SIZE_177\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_176 176\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_177(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_176(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_177(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_176(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_177(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_176(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_177(_) OF_PP_INTERNAL_SEQ_SIZE_178\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_177 177\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_178(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_177(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_178(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_177(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_178(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_177(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_178(_) OF_PP_INTERNAL_SEQ_SIZE_179\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_178 178\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_179(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_178(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_179(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_178(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_179(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_178(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_179(_) OF_PP_INTERNAL_SEQ_SIZE_180\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_179 179\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_180(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_179(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_180(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_179(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_180(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_179(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_180(_) OF_PP_INTERNAL_SEQ_SIZE_181\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_180 180\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_181(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_180(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_181(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_180(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_181(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_180(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_181(_) OF_PP_INTERNAL_SEQ_SIZE_182\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_181 181\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_182(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_181(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_182(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_181(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_182(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_181(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_182(_) OF_PP_INTERNAL_SEQ_SIZE_183\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_182 182\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_183(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_182(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_183(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_182(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_183(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_182(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_183(_) OF_PP_INTERNAL_SEQ_SIZE_184\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_183 183\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_184(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_183(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_184(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_183(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_184(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_183(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_184(_) OF_PP_INTERNAL_SEQ_SIZE_185\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_184 184\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_185(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_184(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_185(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_184(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_185(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_184(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_185(_) OF_PP_INTERNAL_SEQ_SIZE_186\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_185 185\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_186(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_185(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_186(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_185(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_186(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_185(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_186(_) OF_PP_INTERNAL_SEQ_SIZE_187\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_186 186\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_187(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_186(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_187(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_186(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_187(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_186(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_187(_) OF_PP_INTERNAL_SEQ_SIZE_188\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_187 187\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_188(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_187(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_188(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_187(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_188(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_187(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_188(_) OF_PP_INTERNAL_SEQ_SIZE_189\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_188 188\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_189(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_188(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_189(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_188(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_189(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_188(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_189(_) OF_PP_INTERNAL_SEQ_SIZE_190\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_189 189\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_190(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_189(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_190(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_189(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_190(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_189(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_190(_) OF_PP_INTERNAL_SEQ_SIZE_191\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_190 190\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_191(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_190(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_191(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_190(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_191(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_190(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_191(_) OF_PP_INTERNAL_SEQ_SIZE_192\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_191 191\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_192(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_191(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_192(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_191(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_192(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_191(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_192(_) OF_PP_INTERNAL_SEQ_SIZE_193\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_192 192\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_193(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_192(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_193(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_192(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_193(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_192(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_193(_) OF_PP_INTERNAL_SEQ_SIZE_194\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_193 193\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_194(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_193(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_194(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_193(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_194(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_193(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_194(_) OF_PP_INTERNAL_SEQ_SIZE_195\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_194 194\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_195(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_194(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_195(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_194(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_195(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_194(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_195(_) OF_PP_INTERNAL_SEQ_SIZE_196\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_195 195\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_196(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_195(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_196(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_195(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_196(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_195(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_196(_) OF_PP_INTERNAL_SEQ_SIZE_197\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_196 196\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_197(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_196(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_197(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_196(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_197(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_196(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_197(_) OF_PP_INTERNAL_SEQ_SIZE_198\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_197 197\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_198(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_197(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_198(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_197(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_198(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_197(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_198(_) OF_PP_INTERNAL_SEQ_SIZE_199\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_198 198\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_199(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_198(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_199(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_198(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_199(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_198(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_199(_) OF_PP_INTERNAL_SEQ_SIZE_200\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_199 199\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_200(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_199(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_200(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_199(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_200(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_199(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_200(_) OF_PP_INTERNAL_SEQ_SIZE_201\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_200 200\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_201(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_200(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_201(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_200(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_201(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_200(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_201(_) OF_PP_INTERNAL_SEQ_SIZE_202\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_201 201\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_202(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_201(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_202(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_201(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_202(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_201(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_202(_) OF_PP_INTERNAL_SEQ_SIZE_203\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_202 202\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_203(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_202(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_203(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_202(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_203(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_202(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_203(_) OF_PP_INTERNAL_SEQ_SIZE_204\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_203 203\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_204(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_203(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_204(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_203(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_204(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_203(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_204(_) OF_PP_INTERNAL_SEQ_SIZE_205\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_204 204\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_205(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_204(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_205(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_204(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_205(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_204(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_205(_) OF_PP_INTERNAL_SEQ_SIZE_206\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_205 205\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_206(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_205(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_206(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_205(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_206(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_205(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_206(_) OF_PP_INTERNAL_SEQ_SIZE_207\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_206 206\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_207(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_206(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_207(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_206(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_207(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_206(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_207(_) OF_PP_INTERNAL_SEQ_SIZE_208\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_207 207\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_208(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_207(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_208(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_207(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_208(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_207(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_208(_) OF_PP_INTERNAL_SEQ_SIZE_209\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_208 208\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_209(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_208(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_209(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_208(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_209(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_208(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_209(_) OF_PP_INTERNAL_SEQ_SIZE_210\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_209 209\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_210(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_209(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_210(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_209(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_210(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_209(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_210(_) OF_PP_INTERNAL_SEQ_SIZE_211\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_210 210\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_211(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_210(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_211(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_210(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_211(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_210(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_211(_) OF_PP_INTERNAL_SEQ_SIZE_212\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_211 211\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_212(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_211(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_212(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_211(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_212(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_211(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_212(_) OF_PP_INTERNAL_SEQ_SIZE_213\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_212 212\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_213(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_212(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_213(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_212(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_213(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_212(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_213(_) OF_PP_INTERNAL_SEQ_SIZE_214\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_213 213\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_214(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_213(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_214(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_213(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_214(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_213(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_214(_) OF_PP_INTERNAL_SEQ_SIZE_215\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_214 214\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_215(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_214(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_215(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_214(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_215(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_214(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_215(_) OF_PP_INTERNAL_SEQ_SIZE_216\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_215 215\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_216(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_215(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_216(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_215(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_216(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_215(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_216(_) OF_PP_INTERNAL_SEQ_SIZE_217\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_216 216\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_217(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_216(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_217(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_216(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_217(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_216(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_217(_) OF_PP_INTERNAL_SEQ_SIZE_218\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_217 217\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_218(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_217(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_218(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_217(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_218(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_217(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_218(_) OF_PP_INTERNAL_SEQ_SIZE_219\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_218 218\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_219(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_218(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_219(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_218(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_219(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_218(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_219(_) OF_PP_INTERNAL_SEQ_SIZE_220\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_219 219\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_220(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_219(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_220(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_219(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_220(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_219(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_220(_) OF_PP_INTERNAL_SEQ_SIZE_221\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_220 220\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_221(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_220(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_221(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_220(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_221(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_220(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_221(_) OF_PP_INTERNAL_SEQ_SIZE_222\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_221 221\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_222(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_221(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_222(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_221(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_222(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_221(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_222(_) OF_PP_INTERNAL_SEQ_SIZE_223\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_222 222\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_223(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_222(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_223(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_222(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_223(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_222(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_223(_) OF_PP_INTERNAL_SEQ_SIZE_224\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_223 223\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_224(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_223(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_224(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_223(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_224(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_223(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_224(_) OF_PP_INTERNAL_SEQ_SIZE_225\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_224 224\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_225(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_224(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_225(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_224(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_225(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_224(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_225(_) OF_PP_INTERNAL_SEQ_SIZE_226\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_225 225\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_226(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_225(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_226(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_225(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_226(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_225(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_226(_) OF_PP_INTERNAL_SEQ_SIZE_227\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_226 226\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_227(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_226(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_227(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_226(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_227(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_226(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_227(_) OF_PP_INTERNAL_SEQ_SIZE_228\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_227 227\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_228(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_227(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_228(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_227(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_228(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_227(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_228(_) OF_PP_INTERNAL_SEQ_SIZE_229\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_228 228\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_229(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_228(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_229(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_228(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_229(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_228(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_229(_) OF_PP_INTERNAL_SEQ_SIZE_230\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_229 229\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_230(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_229(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_230(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_229(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_230(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_229(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_230(_) OF_PP_INTERNAL_SEQ_SIZE_231\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_230 230\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_231(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_230(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_231(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_230(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_231(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_230(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_231(_) OF_PP_INTERNAL_SEQ_SIZE_232\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_231 231\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_232(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_231(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_232(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_231(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_232(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_231(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_232(_) OF_PP_INTERNAL_SEQ_SIZE_233\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_232 232\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_233(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_232(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_233(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_232(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_233(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_232(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_233(_) OF_PP_INTERNAL_SEQ_SIZE_234\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_233 233\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_234(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_233(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_234(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_233(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_234(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_233(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_234(_) OF_PP_INTERNAL_SEQ_SIZE_235\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_234 234\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_235(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_234(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_235(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_234(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_235(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_234(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_235(_) OF_PP_INTERNAL_SEQ_SIZE_236\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_235 235\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_236(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_235(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_236(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_235(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_236(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_235(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_236(_) OF_PP_INTERNAL_SEQ_SIZE_237\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_236 236\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_237(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_236(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_237(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_236(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_237(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_236(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_237(_) OF_PP_INTERNAL_SEQ_SIZE_238\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_237 237\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_238(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_237(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_238(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_237(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_238(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_237(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_238(_) OF_PP_INTERNAL_SEQ_SIZE_239\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_238 238\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_239(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_238(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_239(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_238(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_239(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_238(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_239(_) OF_PP_INTERNAL_SEQ_SIZE_240\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_239 239\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_240(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_239(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_240(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_239(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_240(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_239(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_240(_) OF_PP_INTERNAL_SEQ_SIZE_241\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_240 240\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_241(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_240(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_241(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_240(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_241(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_240(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_241(_) OF_PP_INTERNAL_SEQ_SIZE_242\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_241 241\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_242(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_241(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_242(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_241(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_242(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_241(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_242(_) OF_PP_INTERNAL_SEQ_SIZE_243\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_242 242\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_243(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_242(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_243(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_242(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_243(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_242(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_243(_) OF_PP_INTERNAL_SEQ_SIZE_244\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_243 243\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_244(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_243(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_244(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_243(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_244(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_243(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_244(_) OF_PP_INTERNAL_SEQ_SIZE_245\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_244 244\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_245(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_244(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_245(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_244(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_245(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_244(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_245(_) OF_PP_INTERNAL_SEQ_SIZE_246\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_245 245\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_246(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_245(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_246(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_245(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_246(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_245(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_246(_) OF_PP_INTERNAL_SEQ_SIZE_247\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_246 246\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_247(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_246(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_247(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_246(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_247(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_246(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_247(_) OF_PP_INTERNAL_SEQ_SIZE_248\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_247 247\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_248(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_247(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_248(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_247(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_248(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_247(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_248(_) OF_PP_INTERNAL_SEQ_SIZE_249\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_248 248\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_249(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_248(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_249(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_248(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_249(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_248(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_249(_) OF_PP_INTERNAL_SEQ_SIZE_250\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_249 249\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_250(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_249(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_250(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_249(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_250(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_249(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_250(_) OF_PP_INTERNAL_SEQ_SIZE_251\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_250 250\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_251(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_250(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_251(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_250(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_251(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_250(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_251(_) OF_PP_INTERNAL_SEQ_SIZE_252\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_251 251\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_252(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_251(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_252(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_251(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_252(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_251(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_252(_) OF_PP_INTERNAL_SEQ_SIZE_253\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_252 252\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_253(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_252(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_253(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_252(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_253(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_252(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_253(_) OF_PP_INTERNAL_SEQ_SIZE_254\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_253 253\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_254(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_253(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_254(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_253(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_254(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_253(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_254(_) OF_PP_INTERNAL_SEQ_SIZE_255\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_254 254\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_255(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_254(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_255(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_254(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_255(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_254(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_255(_) OF_PP_INTERNAL_SEQ_SIZE_256\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_255 255\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_256(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_255(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_256(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_255(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_256(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_255(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_256(_) OF_PP_INTERNAL_SEQ_SIZE_257\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_256 256\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_257(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_256(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_257(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_256(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_257(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_256(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_257(_) OF_PP_INTERNAL_SEQ_SIZE_258\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_257 257\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_258(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_257(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_258(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_257(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_258(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_257(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_258(_) OF_PP_INTERNAL_SEQ_SIZE_259\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_258 258\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_259(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_258(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_259(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_258(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_259(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_258(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_259(_) OF_PP_INTERNAL_SEQ_SIZE_260\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_259 259\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_260(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_259(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_260(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_259(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_260(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_259(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_260(_) OF_PP_INTERNAL_SEQ_SIZE_261\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_260 260\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_261(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_260(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_261(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_260(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_261(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_260(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_261(_) OF_PP_INTERNAL_SEQ_SIZE_262\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_261 261\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_262(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_261(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_262(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_261(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_262(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_261(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_262(_) OF_PP_INTERNAL_SEQ_SIZE_263\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_262 262\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_263(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_262(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_263(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_262(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_263(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_262(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_263(_) OF_PP_INTERNAL_SEQ_SIZE_264\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_263 263\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_264(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_263(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_264(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_263(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_264(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_263(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_264(_) OF_PP_INTERNAL_SEQ_SIZE_265\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_264 264\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_265(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_264(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_265(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_264(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_265(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_264(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_265(_) OF_PP_INTERNAL_SEQ_SIZE_266\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_265 265\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_266(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_265(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_266(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_265(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_266(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_265(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_266(_) OF_PP_INTERNAL_SEQ_SIZE_267\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_266 266\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_267(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_266(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_267(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_266(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_267(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_266(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_267(_) OF_PP_INTERNAL_SEQ_SIZE_268\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_267 267\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_268(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_267(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_268(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_267(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_268(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_267(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_268(_) OF_PP_INTERNAL_SEQ_SIZE_269\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_268 268\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_269(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_268(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_269(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_268(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_269(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_268(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_269(_) OF_PP_INTERNAL_SEQ_SIZE_270\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_269 269\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_270(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_269(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_270(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_269(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_270(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_269(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_270(_) OF_PP_INTERNAL_SEQ_SIZE_271\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_270 270\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_271(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_270(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_271(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_270(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_271(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_270(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_271(_) OF_PP_INTERNAL_SEQ_SIZE_272\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_271 271\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_272(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_271(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_272(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_271(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_272(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_271(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_272(_) OF_PP_INTERNAL_SEQ_SIZE_273\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_272 272\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_273(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_272(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_273(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_272(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_273(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_272(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_273(_) OF_PP_INTERNAL_SEQ_SIZE_274\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_273 273\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_274(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_273(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_274(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_273(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_274(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_273(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_274(_) OF_PP_INTERNAL_SEQ_SIZE_275\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_274 274\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_275(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_274(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_275(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_274(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_275(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_274(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_275(_) OF_PP_INTERNAL_SEQ_SIZE_276\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_275 275\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_276(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_275(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_276(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_275(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_276(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_275(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_276(_) OF_PP_INTERNAL_SEQ_SIZE_277\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_276 276\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_277(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_276(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_277(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_276(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_277(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_276(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_277(_) OF_PP_INTERNAL_SEQ_SIZE_278\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_277 277\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_278(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_277(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_278(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_277(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_278(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_277(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_278(_) OF_PP_INTERNAL_SEQ_SIZE_279\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_278 278\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_279(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_278(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_279(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_278(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_279(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_278(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_279(_) OF_PP_INTERNAL_SEQ_SIZE_280\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_279 279\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_280(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_279(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_280(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_279(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_280(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_279(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_280(_) OF_PP_INTERNAL_SEQ_SIZE_281\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_280 280\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_281(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_280(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_281(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_280(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_281(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_280(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_281(_) OF_PP_INTERNAL_SEQ_SIZE_282\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_281 281\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_282(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_281(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_282(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_281(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_282(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_281(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_282(_) OF_PP_INTERNAL_SEQ_SIZE_283\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_282 282\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_283(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_282(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_283(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_282(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_283(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_282(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_283(_) OF_PP_INTERNAL_SEQ_SIZE_284\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_283 283\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_284(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_283(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_284(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_283(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_284(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_283(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_284(_) OF_PP_INTERNAL_SEQ_SIZE_285\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_284 284\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_285(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_284(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_285(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_284(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_285(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_284(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_285(_) OF_PP_INTERNAL_SEQ_SIZE_286\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_285 285\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_286(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_285(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_286(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_285(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_286(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_285(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_286(_) OF_PP_INTERNAL_SEQ_SIZE_287\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_286 286\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_287(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_286(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_287(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_286(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_287(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_286(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_287(_) OF_PP_INTERNAL_SEQ_SIZE_288\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_287 287\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_288(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_287(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_288(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_287(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_288(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_287(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_288(_) OF_PP_INTERNAL_SEQ_SIZE_289\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_288 288\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_289(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_288(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_289(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_288(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_289(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_288(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_289(_) OF_PP_INTERNAL_SEQ_SIZE_290\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_289 289\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_290(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_289(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_290(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_289(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_290(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_289(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_290(_) OF_PP_INTERNAL_SEQ_SIZE_291\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_290 290\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_291(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_290(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_291(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_290(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_291(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_290(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_291(_) OF_PP_INTERNAL_SEQ_SIZE_292\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_291 291\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_292(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_291(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_292(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_291(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_292(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_291(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_292(_) OF_PP_INTERNAL_SEQ_SIZE_293\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_292 292\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_293(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_292(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_293(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_292(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_293(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_292(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_293(_) OF_PP_INTERNAL_SEQ_SIZE_294\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_293 293\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_294(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_293(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_294(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_293(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_294(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_293(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_294(_) OF_PP_INTERNAL_SEQ_SIZE_295\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_294 294\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_295(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_294(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_295(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_294(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_295(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_294(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_295(_) OF_PP_INTERNAL_SEQ_SIZE_296\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_295 295\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_296(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_295(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_296(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_295(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_296(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_295(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_296(_) OF_PP_INTERNAL_SEQ_SIZE_297\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_296 296\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_297(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_296(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_297(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_296(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_297(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_296(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_297(_) OF_PP_INTERNAL_SEQ_SIZE_298\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_297 297\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_298(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_297(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_298(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_297(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_298(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_297(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_298(_) OF_PP_INTERNAL_SEQ_SIZE_299\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_298 298\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_299(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_298(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_299(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_298(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_299(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_298(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_299(_) OF_PP_INTERNAL_SEQ_SIZE_300\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_299 299\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_300(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_299(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_300(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_299(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_300(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_299(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_300(_) OF_PP_INTERNAL_SEQ_SIZE_301\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_300 300\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_301(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_300(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_301(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_300(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_301(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_300(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_301(_) OF_PP_INTERNAL_SEQ_SIZE_302\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_301 301\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_302(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_301(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_302(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_301(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_302(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_301(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_302(_) OF_PP_INTERNAL_SEQ_SIZE_303\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_302 302\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_303(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_302(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_303(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_302(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_303(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_302(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_303(_) OF_PP_INTERNAL_SEQ_SIZE_304\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_303 303\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_304(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_303(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_304(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_303(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_304(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_303(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_304(_) OF_PP_INTERNAL_SEQ_SIZE_305\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_304 304\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_305(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_304(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_305(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_304(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_305(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_304(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_305(_) OF_PP_INTERNAL_SEQ_SIZE_306\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_305 305\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_306(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_305(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_306(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_305(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_306(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_305(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_306(_) OF_PP_INTERNAL_SEQ_SIZE_307\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_306 306\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_307(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_306(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_307(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_306(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_307(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_306(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_307(_) OF_PP_INTERNAL_SEQ_SIZE_308\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_307 307\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_308(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_307(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_308(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_307(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_308(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_307(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_308(_) OF_PP_INTERNAL_SEQ_SIZE_309\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_308 308\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_309(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_308(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_309(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_308(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_309(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_308(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_309(_) OF_PP_INTERNAL_SEQ_SIZE_310\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_309 309\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_310(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_309(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_310(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_309(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_310(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_309(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_310(_) OF_PP_INTERNAL_SEQ_SIZE_311\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_310 310\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_311(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_310(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_311(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_310(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_311(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_310(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_311(_) OF_PP_INTERNAL_SEQ_SIZE_312\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_311 311\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_312(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_311(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_312(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_311(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_312(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_311(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_312(_) OF_PP_INTERNAL_SEQ_SIZE_313\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_312 312\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_313(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_312(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_313(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_312(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_313(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_312(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_313(_) OF_PP_INTERNAL_SEQ_SIZE_314\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_313 313\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_314(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_313(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_314(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_313(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_314(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_313(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_314(_) OF_PP_INTERNAL_SEQ_SIZE_315\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_314 314\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_315(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_314(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_315(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_314(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_315(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_314(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_315(_) OF_PP_INTERNAL_SEQ_SIZE_316\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_315 315\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_316(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_315(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_316(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_315(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_316(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_315(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_316(_) OF_PP_INTERNAL_SEQ_SIZE_317\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_316 316\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_317(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_316(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_317(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_316(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_317(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_316(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_317(_) OF_PP_INTERNAL_SEQ_SIZE_318\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_317 317\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_318(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_317(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_318(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_317(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_318(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_317(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_318(_) OF_PP_INTERNAL_SEQ_SIZE_319\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_318 318\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_319(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_318(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_319(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_318(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_319(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_318(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_319(_) OF_PP_INTERNAL_SEQ_SIZE_320\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_319 319\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_320(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_319(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_320(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_319(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_320(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_319(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_320(_) OF_PP_INTERNAL_SEQ_SIZE_321\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_320 320\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_321(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_320(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_321(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_320(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_321(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_320(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_321(_) OF_PP_INTERNAL_SEQ_SIZE_322\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_321 321\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_322(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_321(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_322(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_321(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_322(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_321(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_322(_) OF_PP_INTERNAL_SEQ_SIZE_323\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_322 322\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_323(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_322(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_323(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_322(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_323(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_322(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_323(_) OF_PP_INTERNAL_SEQ_SIZE_324\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_323 323\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_324(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_323(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_324(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_323(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_324(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_323(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_324(_) OF_PP_INTERNAL_SEQ_SIZE_325\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_324 324\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_325(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_324(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_325(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_324(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_325(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_324(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_325(_) OF_PP_INTERNAL_SEQ_SIZE_326\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_325 325\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_326(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_325(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_326(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_325(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_326(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_325(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_326(_) OF_PP_INTERNAL_SEQ_SIZE_327\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_326 326\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_327(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_326(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_327(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_326(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_327(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_326(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_327(_) OF_PP_INTERNAL_SEQ_SIZE_328\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_327 327\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_328(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_327(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_328(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_327(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_328(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_327(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_328(_) OF_PP_INTERNAL_SEQ_SIZE_329\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_328 328\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_329(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_328(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_329(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_328(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_329(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_328(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_329(_) OF_PP_INTERNAL_SEQ_SIZE_330\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_329 329\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_330(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_329(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_330(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_329(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_330(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_329(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_330(_) OF_PP_INTERNAL_SEQ_SIZE_331\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_330 330\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_331(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_330(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_331(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_330(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_331(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_330(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_331(_) OF_PP_INTERNAL_SEQ_SIZE_332\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_331 331\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_332(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_331(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_332(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_331(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_332(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_331(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_332(_) OF_PP_INTERNAL_SEQ_SIZE_333\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_332 332\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_333(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_332(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_333(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_332(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_333(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_332(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_333(_) OF_PP_INTERNAL_SEQ_SIZE_334\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_333 333\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_334(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_333(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_334(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_333(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_334(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_333(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_334(_) OF_PP_INTERNAL_SEQ_SIZE_335\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_334 334\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_335(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_334(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_335(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_334(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_335(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_334(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_335(_) OF_PP_INTERNAL_SEQ_SIZE_336\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_335 335\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_336(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_335(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_336(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_335(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_336(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_335(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_336(_) OF_PP_INTERNAL_SEQ_SIZE_337\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_336 336\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_337(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_336(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_337(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_336(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_337(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_336(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_337(_) OF_PP_INTERNAL_SEQ_SIZE_338\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_337 337\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_338(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_337(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_338(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_337(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_338(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_337(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_338(_) OF_PP_INTERNAL_SEQ_SIZE_339\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_338 338\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_339(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_338(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_339(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_338(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_339(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_338(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_339(_) OF_PP_INTERNAL_SEQ_SIZE_340\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_339 339\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_340(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_339(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_340(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_339(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_340(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_339(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_340(_) OF_PP_INTERNAL_SEQ_SIZE_341\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_340 340\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_341(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_340(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_341(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_340(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_341(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_340(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_341(_) OF_PP_INTERNAL_SEQ_SIZE_342\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_341 341\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_342(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_341(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_342(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_341(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_342(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_341(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_342(_) OF_PP_INTERNAL_SEQ_SIZE_343\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_342 342\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_343(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_342(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_343(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_342(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_343(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_342(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_343(_) OF_PP_INTERNAL_SEQ_SIZE_344\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_343 343\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_344(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_343(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_344(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_343(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_344(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_343(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_344(_) OF_PP_INTERNAL_SEQ_SIZE_345\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_344 344\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_345(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_344(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_345(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_344(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_345(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_344(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_345(_) OF_PP_INTERNAL_SEQ_SIZE_346\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_345 345\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_346(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_345(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_346(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_345(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_346(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_345(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_346(_) OF_PP_INTERNAL_SEQ_SIZE_347\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_346 346\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_347(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_346(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_347(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_346(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_347(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_346(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_347(_) OF_PP_INTERNAL_SEQ_SIZE_348\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_347 347\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_348(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_347(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_348(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_347(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_348(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_347(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_348(_) OF_PP_INTERNAL_SEQ_SIZE_349\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_348 348\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_349(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_348(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_349(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_348(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_349(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_348(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_349(_) OF_PP_INTERNAL_SEQ_SIZE_350\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_349 349\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_350(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_349(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_350(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_349(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_350(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_349(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_350(_) OF_PP_INTERNAL_SEQ_SIZE_351\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_350 350\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_351(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_350(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_351(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_350(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_351(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_350(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_351(_) OF_PP_INTERNAL_SEQ_SIZE_352\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_351 351\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_352(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_351(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_352(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_351(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_352(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_351(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_352(_) OF_PP_INTERNAL_SEQ_SIZE_353\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_352 352\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_353(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_352(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_353(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_352(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_353(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_352(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_353(_) OF_PP_INTERNAL_SEQ_SIZE_354\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_353 353\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_354(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_353(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_354(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_353(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_354(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_353(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_354(_) OF_PP_INTERNAL_SEQ_SIZE_355\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_354 354\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_355(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_354(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_355(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_354(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_355(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_354(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_355(_) OF_PP_INTERNAL_SEQ_SIZE_356\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_355 355\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_356(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_355(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_356(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_355(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_356(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_355(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_356(_) OF_PP_INTERNAL_SEQ_SIZE_357\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_356 356\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_357(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_356(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_357(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_356(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_357(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_356(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_357(_) OF_PP_INTERNAL_SEQ_SIZE_358\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_357 357\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_358(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_357(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_358(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_357(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_358(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_357(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_358(_) OF_PP_INTERNAL_SEQ_SIZE_359\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_358 358\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_359(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_358(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_359(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_358(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_359(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_358(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_359(_) OF_PP_INTERNAL_SEQ_SIZE_360\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_359 359\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_360(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_359(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_360(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_359(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_360(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_359(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_360(_) OF_PP_INTERNAL_SEQ_SIZE_361\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_360 360\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_361(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_360(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_361(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_360(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_361(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_360(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_361(_) OF_PP_INTERNAL_SEQ_SIZE_362\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_361 361\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_362(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_361(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_362(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_361(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_362(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_361(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_362(_) OF_PP_INTERNAL_SEQ_SIZE_363\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_362 362\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_363(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_362(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_363(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_362(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_363(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_362(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_363(_) OF_PP_INTERNAL_SEQ_SIZE_364\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_363 363\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_364(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_363(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_364(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_363(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_364(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_363(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_364(_) OF_PP_INTERNAL_SEQ_SIZE_365\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_364 364\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_365(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_364(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_365(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_364(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_365(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_364(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_365(_) OF_PP_INTERNAL_SEQ_SIZE_366\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_365 365\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_366(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_365(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_366(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_365(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_366(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_365(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_366(_) OF_PP_INTERNAL_SEQ_SIZE_367\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_366 366\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_367(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_366(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_367(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_366(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_367(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_366(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_367(_) OF_PP_INTERNAL_SEQ_SIZE_368\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_367 367\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_368(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_367(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_368(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_367(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_368(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_367(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_368(_) OF_PP_INTERNAL_SEQ_SIZE_369\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_368 368\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_369(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_368(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_369(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_368(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_369(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_368(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_369(_) OF_PP_INTERNAL_SEQ_SIZE_370\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_369 369\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_370(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_369(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_370(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_369(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_370(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_369(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_370(_) OF_PP_INTERNAL_SEQ_SIZE_371\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_370 370\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_371(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_370(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_371(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_370(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_371(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_370(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_371(_) OF_PP_INTERNAL_SEQ_SIZE_372\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_371 371\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_372(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_371(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_372(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_371(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_372(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_371(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_372(_) OF_PP_INTERNAL_SEQ_SIZE_373\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_372 372\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_373(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_372(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_373(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_372(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_373(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_372(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_373(_) OF_PP_INTERNAL_SEQ_SIZE_374\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_373 373\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_374(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_373(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_374(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_373(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_374(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_373(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_374(_) OF_PP_INTERNAL_SEQ_SIZE_375\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_374 374\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_375(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_374(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_375(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_374(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_375(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_374(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_375(_) OF_PP_INTERNAL_SEQ_SIZE_376\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_375 375\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_376(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_375(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_376(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_375(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_376(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_375(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_376(_) OF_PP_INTERNAL_SEQ_SIZE_377\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_376 376\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_377(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_376(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_377(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_376(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_377(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_376(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_377(_) OF_PP_INTERNAL_SEQ_SIZE_378\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_377 377\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_378(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_377(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_378(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_377(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_378(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_377(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_378(_) OF_PP_INTERNAL_SEQ_SIZE_379\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_378 378\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_379(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_378(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_379(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_378(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_379(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_378(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_379(_) OF_PP_INTERNAL_SEQ_SIZE_380\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_379 379\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_380(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_379(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_380(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_379(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_380(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_379(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_380(_) OF_PP_INTERNAL_SEQ_SIZE_381\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_380 380\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_381(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_380(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_381(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_380(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_381(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_380(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_381(_) OF_PP_INTERNAL_SEQ_SIZE_382\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_381 381\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_382(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_381(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_382(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_381(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_382(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_381(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_382(_) OF_PP_INTERNAL_SEQ_SIZE_383\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_382 382\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_383(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_382(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_383(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_382(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_383(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_382(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_383(_) OF_PP_INTERNAL_SEQ_SIZE_384\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_383 383\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_384(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_383(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_384(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_383(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_384(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_383(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_384(_) OF_PP_INTERNAL_SEQ_SIZE_385\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_384 384\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_385(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_384(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_385(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_384(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_385(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_384(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_385(_) OF_PP_INTERNAL_SEQ_SIZE_386\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_385 385\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_386(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_385(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_386(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_385(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_386(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_385(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_386(_) OF_PP_INTERNAL_SEQ_SIZE_387\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_386 386\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_387(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_386(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_387(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_386(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_387(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_386(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_387(_) OF_PP_INTERNAL_SEQ_SIZE_388\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_387 387\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_388(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_387(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_388(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_387(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_388(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_387(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_388(_) OF_PP_INTERNAL_SEQ_SIZE_389\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_388 388\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_389(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_388(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_389(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_388(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_389(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_388(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_389(_) OF_PP_INTERNAL_SEQ_SIZE_390\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_389 389\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_390(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_389(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_390(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_389(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_390(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_389(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_390(_) OF_PP_INTERNAL_SEQ_SIZE_391\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_390 390\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_391(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_390(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_391(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_390(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_391(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_390(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_391(_) OF_PP_INTERNAL_SEQ_SIZE_392\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_391 391\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_392(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_391(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_392(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_391(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_392(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_391(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_392(_) OF_PP_INTERNAL_SEQ_SIZE_393\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_392 392\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_393(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_392(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_393(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_392(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_393(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_392(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_393(_) OF_PP_INTERNAL_SEQ_SIZE_394\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_393 393\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_394(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_393(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_394(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_393(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_394(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_393(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_394(_) OF_PP_INTERNAL_SEQ_SIZE_395\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_394 394\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_395(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_394(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_395(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_394(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_395(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_394(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_395(_) OF_PP_INTERNAL_SEQ_SIZE_396\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_395 395\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_396(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_395(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_396(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_395(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_396(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_395(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_396(_) OF_PP_INTERNAL_SEQ_SIZE_397\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_396 396\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_397(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_396(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_397(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_396(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_397(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_396(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_397(_) OF_PP_INTERNAL_SEQ_SIZE_398\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_397 397\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_398(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_397(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_398(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_397(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_398(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_397(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_398(_) OF_PP_INTERNAL_SEQ_SIZE_399\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_398 398\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_399(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_398(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_399(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_398(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_399(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_398(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_399(_) OF_PP_INTERNAL_SEQ_SIZE_400\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_399 399\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_400(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_399(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_400(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_399(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_400(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_399(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_400(_) OF_PP_INTERNAL_SEQ_SIZE_401\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_400 400\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_401(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_400(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_401(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_400(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_401(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_400(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_401(_) OF_PP_INTERNAL_SEQ_SIZE_402\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_401 401\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_402(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_401(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_402(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_401(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_402(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_401(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_402(_) OF_PP_INTERNAL_SEQ_SIZE_403\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_402 402\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_403(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_402(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_403(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_402(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_403(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_402(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_403(_) OF_PP_INTERNAL_SEQ_SIZE_404\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_403 403\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_404(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_403(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_404(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_403(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_404(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_403(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_404(_) OF_PP_INTERNAL_SEQ_SIZE_405\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_404 404\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_405(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_404(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_405(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_404(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_405(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_404(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_405(_) OF_PP_INTERNAL_SEQ_SIZE_406\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_405 405\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_406(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_405(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_406(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_405(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_406(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_405(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_406(_) OF_PP_INTERNAL_SEQ_SIZE_407\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_406 406\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_407(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_406(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_407(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_406(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_407(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_406(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_407(_) OF_PP_INTERNAL_SEQ_SIZE_408\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_407 407\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_408(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_407(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_408(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_407(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_408(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_407(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_408(_) OF_PP_INTERNAL_SEQ_SIZE_409\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_408 408\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_409(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_408(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_409(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_408(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_409(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_408(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_409(_) OF_PP_INTERNAL_SEQ_SIZE_410\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_409 409\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_410(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_409(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_410(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_409(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_410(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_409(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_410(_) OF_PP_INTERNAL_SEQ_SIZE_411\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_410 410\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_411(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_410(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_411(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_410(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_411(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_410(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_411(_) OF_PP_INTERNAL_SEQ_SIZE_412\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_411 411\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_412(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_411(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_412(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_411(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_412(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_411(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_412(_) OF_PP_INTERNAL_SEQ_SIZE_413\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_412 412\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_413(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_412(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_413(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_412(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_413(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_412(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_413(_) OF_PP_INTERNAL_SEQ_SIZE_414\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_413 413\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_414(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_413(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_414(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_413(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_414(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_413(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_414(_) OF_PP_INTERNAL_SEQ_SIZE_415\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_414 414\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_415(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_414(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_415(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_414(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_415(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_414(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_415(_) OF_PP_INTERNAL_SEQ_SIZE_416\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_415 415\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_416(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_415(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_416(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_415(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_416(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_415(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_416(_) OF_PP_INTERNAL_SEQ_SIZE_417\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_416 416\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_417(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_416(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_417(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_416(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_417(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_416(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_417(_) OF_PP_INTERNAL_SEQ_SIZE_418\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_417 417\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_418(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_417(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_418(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_417(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_418(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_417(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_418(_) OF_PP_INTERNAL_SEQ_SIZE_419\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_418 418\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_419(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_418(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_419(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_418(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_419(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_418(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_419(_) OF_PP_INTERNAL_SEQ_SIZE_420\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_419 419\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_420(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_419(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_420(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_419(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_420(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_419(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_420(_) OF_PP_INTERNAL_SEQ_SIZE_421\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_420 420\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_421(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_420(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_421(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_420(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_421(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_420(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_421(_) OF_PP_INTERNAL_SEQ_SIZE_422\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_421 421\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_422(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_421(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_422(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_421(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_422(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_421(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_422(_) OF_PP_INTERNAL_SEQ_SIZE_423\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_422 422\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_423(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_422(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_423(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_422(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_423(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_422(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_423(_) OF_PP_INTERNAL_SEQ_SIZE_424\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_423 423\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_424(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_423(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_424(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_423(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_424(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_423(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_424(_) OF_PP_INTERNAL_SEQ_SIZE_425\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_424 424\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_425(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_424(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_425(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_424(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_425(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_424(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_425(_) OF_PP_INTERNAL_SEQ_SIZE_426\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_425 425\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_426(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_425(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_426(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_425(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_426(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_425(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_426(_) OF_PP_INTERNAL_SEQ_SIZE_427\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_426 426\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_427(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_426(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_427(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_426(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_427(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_426(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_427(_) OF_PP_INTERNAL_SEQ_SIZE_428\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_427 427\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_428(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_427(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_428(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_427(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_428(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_427(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_428(_) OF_PP_INTERNAL_SEQ_SIZE_429\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_428 428\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_429(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_428(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_429(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_428(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_429(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_428(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_429(_) OF_PP_INTERNAL_SEQ_SIZE_430\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_429 429\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_430(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_429(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_430(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_429(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_430(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_429(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_430(_) OF_PP_INTERNAL_SEQ_SIZE_431\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_430 430\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_431(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_430(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_431(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_430(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_431(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_430(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_431(_) OF_PP_INTERNAL_SEQ_SIZE_432\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_431 431\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_432(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_431(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_432(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_431(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_432(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_431(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_432(_) OF_PP_INTERNAL_SEQ_SIZE_433\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_432 432\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_433(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_432(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_433(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_432(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_433(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_432(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_433(_) OF_PP_INTERNAL_SEQ_SIZE_434\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_433 433\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_434(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_433(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_434(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_433(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_434(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_433(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_434(_) OF_PP_INTERNAL_SEQ_SIZE_435\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_434 434\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_435(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_434(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_435(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_434(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_435(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_434(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_435(_) OF_PP_INTERNAL_SEQ_SIZE_436\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_435 435\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_436(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_435(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_436(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_435(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_436(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_435(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_436(_) OF_PP_INTERNAL_SEQ_SIZE_437\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_436 436\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_437(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_436(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_437(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_436(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_437(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_436(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_437(_) OF_PP_INTERNAL_SEQ_SIZE_438\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_437 437\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_438(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_437(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_438(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_437(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_438(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_437(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_438(_) OF_PP_INTERNAL_SEQ_SIZE_439\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_438 438\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_439(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_438(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_439(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_438(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_439(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_438(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_439(_) OF_PP_INTERNAL_SEQ_SIZE_440\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_439 439\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_440(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_439(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_440(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_439(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_440(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_439(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_440(_) OF_PP_INTERNAL_SEQ_SIZE_441\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_440 440\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_441(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_440(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_441(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_440(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_441(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_440(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_441(_) OF_PP_INTERNAL_SEQ_SIZE_442\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_441 441\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_442(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_441(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_442(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_441(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_442(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_441(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_442(_) OF_PP_INTERNAL_SEQ_SIZE_443\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_442 442\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_443(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_442(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_443(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_442(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_443(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_442(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_443(_) OF_PP_INTERNAL_SEQ_SIZE_444\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_443 443\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_444(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_443(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_444(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_443(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_444(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_443(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_444(_) OF_PP_INTERNAL_SEQ_SIZE_445\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_444 444\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_445(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_444(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_445(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_444(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_445(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_444(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_445(_) OF_PP_INTERNAL_SEQ_SIZE_446\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_445 445\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_446(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_445(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_446(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_445(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_446(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_445(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_446(_) OF_PP_INTERNAL_SEQ_SIZE_447\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_446 446\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_447(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_446(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_447(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_446(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_447(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_446(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_447(_) OF_PP_INTERNAL_SEQ_SIZE_448\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_447 447\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_448(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_447(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_448(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_447(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_448(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_447(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_448(_) OF_PP_INTERNAL_SEQ_SIZE_449\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_448 448\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_449(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_448(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_449(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_448(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_449(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_448(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_449(_) OF_PP_INTERNAL_SEQ_SIZE_450\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_449 449\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_450(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_449(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_450(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_449(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_450(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_449(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_450(_) OF_PP_INTERNAL_SEQ_SIZE_451\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_450 450\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_451(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_450(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_451(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_450(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_451(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_450(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_451(_) OF_PP_INTERNAL_SEQ_SIZE_452\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_451 451\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_452(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_451(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_452(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_451(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_452(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_451(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_452(_) OF_PP_INTERNAL_SEQ_SIZE_453\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_452 452\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_453(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_452(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_453(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_452(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_453(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_452(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_453(_) OF_PP_INTERNAL_SEQ_SIZE_454\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_453 453\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_454(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_453(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_454(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_453(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_454(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_453(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_454(_) OF_PP_INTERNAL_SEQ_SIZE_455\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_454 454\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_455(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_454(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_455(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_454(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_455(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_454(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_455(_) OF_PP_INTERNAL_SEQ_SIZE_456\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_455 455\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_456(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_455(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_456(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_455(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_456(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_455(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_456(_) OF_PP_INTERNAL_SEQ_SIZE_457\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_456 456\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_457(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_456(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_457(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_456(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_457(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_456(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_457(_) OF_PP_INTERNAL_SEQ_SIZE_458\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_457 457\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_458(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_457(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_458(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_457(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_458(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_457(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_458(_) OF_PP_INTERNAL_SEQ_SIZE_459\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_458 458\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_459(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_458(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_459(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_458(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_459(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_458(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_459(_) OF_PP_INTERNAL_SEQ_SIZE_460\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_459 459\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_460(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_459(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_460(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_459(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_460(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_459(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_460(_) OF_PP_INTERNAL_SEQ_SIZE_461\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_460 460\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_461(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_460(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_461(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_460(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_461(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_460(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_461(_) OF_PP_INTERNAL_SEQ_SIZE_462\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_461 461\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_462(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_461(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_462(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_461(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_462(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_461(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_462(_) OF_PP_INTERNAL_SEQ_SIZE_463\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_462 462\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_463(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_462(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_463(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_462(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_463(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_462(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_463(_) OF_PP_INTERNAL_SEQ_SIZE_464\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_463 463\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_464(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_463(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_464(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_463(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_464(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_463(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_464(_) OF_PP_INTERNAL_SEQ_SIZE_465\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_464 464\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_465(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_464(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_465(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_464(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_465(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_464(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_465(_) OF_PP_INTERNAL_SEQ_SIZE_466\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_465 465\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_466(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_465(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_466(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_465(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_466(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_465(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_466(_) OF_PP_INTERNAL_SEQ_SIZE_467\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_466 466\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_467(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_466(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_467(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_466(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_467(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_466(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_467(_) OF_PP_INTERNAL_SEQ_SIZE_468\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_467 467\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_468(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_467(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_468(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_467(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_468(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_467(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_468(_) OF_PP_INTERNAL_SEQ_SIZE_469\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_468 468\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_469(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_468(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_469(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_468(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_469(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_468(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_469(_) OF_PP_INTERNAL_SEQ_SIZE_470\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_469 469\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_470(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_469(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_470(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_469(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_470(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_469(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_470(_) OF_PP_INTERNAL_SEQ_SIZE_471\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_470 470\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_471(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_470(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_471(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_470(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_471(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_470(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_471(_) OF_PP_INTERNAL_SEQ_SIZE_472\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_471 471\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_472(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_471(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_472(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_471(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_472(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_471(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_472(_) OF_PP_INTERNAL_SEQ_SIZE_473\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_472 472\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_473(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_472(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_473(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_472(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_473(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_472(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_473(_) OF_PP_INTERNAL_SEQ_SIZE_474\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_473 473\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_474(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_473(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_474(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_473(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_474(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_473(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_474(_) OF_PP_INTERNAL_SEQ_SIZE_475\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_474 474\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_475(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_474(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_475(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_474(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_475(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_474(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_475(_) OF_PP_INTERNAL_SEQ_SIZE_476\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_475 475\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_476(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_475(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_476(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_475(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_476(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_475(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_476(_) OF_PP_INTERNAL_SEQ_SIZE_477\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_476 476\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_477(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_476(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_477(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_476(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_477(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_476(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_477(_) OF_PP_INTERNAL_SEQ_SIZE_478\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_477 477\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_478(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_477(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_478(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_477(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_478(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_477(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_478(_) OF_PP_INTERNAL_SEQ_SIZE_479\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_478 478\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_479(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_478(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_479(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_478(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_479(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_478(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_479(_) OF_PP_INTERNAL_SEQ_SIZE_480\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_479 479\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_480(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_479(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_480(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_479(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_480(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_479(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_480(_) OF_PP_INTERNAL_SEQ_SIZE_481\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_480 480\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_481(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_480(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_481(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_480(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_481(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_480(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_481(_) OF_PP_INTERNAL_SEQ_SIZE_482\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_481 481\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_482(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_481(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_482(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_481(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_482(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_481(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_482(_) OF_PP_INTERNAL_SEQ_SIZE_483\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_482 482\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_483(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_482(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_483(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_482(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_483(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_482(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_483(_) OF_PP_INTERNAL_SEQ_SIZE_484\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_483 483\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_484(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_483(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_484(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_483(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_484(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_483(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_484(_) OF_PP_INTERNAL_SEQ_SIZE_485\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_484 484\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_485(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_484(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_485(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_484(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_485(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_484(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_485(_) OF_PP_INTERNAL_SEQ_SIZE_486\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_485 485\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_486(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_485(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_486(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_485(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_486(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_485(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_486(_) OF_PP_INTERNAL_SEQ_SIZE_487\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_486 486\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_487(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_486(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_487(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_486(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_487(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_486(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_487(_) OF_PP_INTERNAL_SEQ_SIZE_488\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_487 487\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_488(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_487(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_488(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_487(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_488(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_487(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_488(_) OF_PP_INTERNAL_SEQ_SIZE_489\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_488 488\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_489(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_488(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_489(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_488(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_489(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_488(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_489(_) OF_PP_INTERNAL_SEQ_SIZE_490\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_489 489\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_490(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_489(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_490(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_489(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_490(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_489(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_490(_) OF_PP_INTERNAL_SEQ_SIZE_491\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_490 490\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_491(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_490(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_491(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_490(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_491(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_490(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_491(_) OF_PP_INTERNAL_SEQ_SIZE_492\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_491 491\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_492(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_491(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_492(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_491(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_492(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_491(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_492(_) OF_PP_INTERNAL_SEQ_SIZE_493\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_492 492\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_493(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_492(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_493(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_492(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_493(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_492(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_493(_) OF_PP_INTERNAL_SEQ_SIZE_494\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_493 493\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_494(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_493(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_494(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_493(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_494(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_493(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_494(_) OF_PP_INTERNAL_SEQ_SIZE_495\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_494 494\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_495(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_494(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_495(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_494(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_495(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_494(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_495(_) OF_PP_INTERNAL_SEQ_SIZE_496\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_495 495\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_496(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_495(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_496(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_495(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_496(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_495(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_496(_) OF_PP_INTERNAL_SEQ_SIZE_497\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_496 496\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_497(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_496(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_497(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_496(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_497(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_496(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_497(_) OF_PP_INTERNAL_SEQ_SIZE_498\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_497 497\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_498(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_497(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_498(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_497(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_498(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_497(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_498(_) OF_PP_INTERNAL_SEQ_SIZE_499\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_498 498\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_499(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_498(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_499(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_498(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_499(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_498(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_499(_) OF_PP_INTERNAL_SEQ_SIZE_500\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_499 499\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_500(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_499(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_500(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_499(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_500(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_499(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_500(_) OF_PP_INTERNAL_SEQ_SIZE_501\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_500 500\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_501(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_500(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_501(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_500(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_501(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_500(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_501(_) OF_PP_INTERNAL_SEQ_SIZE_502\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_501 501\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_502(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_501(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_502(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_501(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_502(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_501(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_502(_) OF_PP_INTERNAL_SEQ_SIZE_503\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_502 502\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_503(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_502(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_503(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_502(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_503(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_502(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_503(_) OF_PP_INTERNAL_SEQ_SIZE_504\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_503 503\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_504(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_503(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_504(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_503(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_504(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_503(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_504(_) OF_PP_INTERNAL_SEQ_SIZE_505\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_504 504\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_505(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_504(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_505(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_504(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_505(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_504(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_505(_) OF_PP_INTERNAL_SEQ_SIZE_506\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_505 505\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_506(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_505(apply, m, d, 
OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_506(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_505(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_506(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_505(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_506(_) OF_PP_INTERNAL_SEQ_SIZE_507\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_506 506\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_507(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_506(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_507(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_506(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_507(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_506(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_507(_) OF_PP_INTERNAL_SEQ_SIZE_508\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_507 507\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_508(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_507(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_508(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_507(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_508(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_507(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_508(_) OF_PP_INTERNAL_SEQ_SIZE_509\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_508 508\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_509(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_508(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_509(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_508(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_509(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_508(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_509(_) OF_PP_INTERNAL_SEQ_SIZE_510\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_509 509\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_510(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_509(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_510(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_509(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_510(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      
OF_PP_INTERNAL_D2_SEQ_FOR_EACH_509(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_510(_) OF_PP_INTERNAL_SEQ_SIZE_511\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_510 510\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_511(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_510(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_511(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_510(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_511(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_510(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_SEQ_SIZE_511(_) OF_PP_INTERNAL_SEQ_SIZE_512\n#define OF_PP_INTERNAL_SEQ_SIZE_OF_PP_INTERNAL_SEQ_SIZE_511 511\n#define OF_PP_INTERNAL_D0_SEQ_FOR_EACH_512(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D0_SEQ_FOR_EACH_511(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D1_SEQ_FOR_EACH_512(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D1_SEQ_FOR_EACH_511(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#define OF_PP_INTERNAL_D2_SEQ_FOR_EACH_512(apply, m, d, seq) \\\n  apply(m, d, OF_PP_INTERNAL_SEQ_HEAD(seq))                  \\\n      OF_PP_INTERNAL_D2_SEQ_FOR_EACH_511(apply, m, d, OF_PP_INTERNAL_SEQ_TAIL(seq))\n\n#endif  // ONEFLOW_CORE_COMMON_PREPROCESSOR_INTERNAL_H_\n"
  },
  {
    "path": "oneflow/core/common/preprocessor_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include <unordered_map>\n#include \"oneflow/core/common/data_type.h\"\n\nnamespace oneflow {\n\nTEST(PP_SEQ, internal_seq_size) {\n#define SEQ (1)(2)(3)\n  ASSERT_EQ(OF_PP_SEQ_SIZE(SEQ), 3);\n#undef SEQ\n}\n\nTEST(PP_SEQ, internal_big_seq_size) {\n#define SEQ                                                                                        \\\n  (0)(1)(2)(3)(4)(5)(6)(7)(8)(9)(10)(11)(12)(13)(14)(15)(16)(17)(18)(19)(20)(21)(22)(23)(24)(25)(  \\\n      26)(27)(28)(29)(30)(31)(32)(33)(34)(35)(36)(37)(38)(39)(40)(41)(42)(43)(44)(45)(46)(47)(48)( \\\n      49)(50)(51)(52)(53)(54)(55)(56)(57)(58)(59)(60)(61)(62)(63)\n  ASSERT_EQ(OF_PP_SEQ_SIZE(SEQ), 64);\n#undef SEQ\n}\n\nTEST(PP_SEQ, internal_for_each) {\n#define SEQ (1)(2)(3)(4)\n#define MAKE_PAIR(x) {x, x},\n  std::unordered_map<int, int> identity = {OF_PP_INTERNAL_SEQ_FOR_EACH_ATOMIC(MAKE_PAIR, _, SEQ)};\n#undef MAKE_PAIR\n#undef SEQ\n  for (int i = 1; i <= 4; ++i) { ASSERT_EQ(i, identity[i]); }\n}\n\nTEST(PP_TUPLE, internal_is_tuple_empty) {\n  ASSERT_EQ(OF_PP_INTERNAL_IS_TUPLE_EMPTY(()), 1);\n  ASSERT_EQ(OF_PP_INTERNAL_IS_TUPLE_EMPTY((1)), 0);\n  ASSERT_EQ(OF_PP_INTERNAL_IS_TUPLE_EMPTY((1, 2)), 0);\n}\n\nTEST(PP_TUPLE, internal_tuple_size) {\n  ASSERT_EQ(OF_PP_INTERNAL_TUPLE_SIZE(()), 0);\n  ASSERT_EQ(OF_PP_INTERNAL_TUPLE_SIZE((1)), 1);\n  ASSERT_EQ(OF_PP_INTERNAL_TUPLE_SIZE((1, 2)), 2);\n  ASSERT_EQ(OF_PP_INTERNAL_TUPLE_SIZE((1, 2, 3)), 3);\n  ASSERT_EQ(OF_PP_INTERNAL_TUPLE_SIZE((1, 2, 3, 4)), 4);\n  ASSERT_EQ(OF_PP_INTERNAL_TUPLE_SIZE((1, 2, 3, 4, 5)), 5);\n}\n\nTEST(PP_SEQ, internal_seq_product) {\n#define SEQ (0)(1)\n  std::string expanded(OF_PP_STRINGIZE(OF_PP_INTERNAL_SEQ_PRODUCT(SEQ, SEQ)));\n#undef SEQ\n  ASSERT_TRUE((expanded == \"((0, 0)) ((1, 0)) ((0, 1)) ((1, 1))\")\n              || (expanded == \"((0, 0)) ((1, 0))  ((0, 1)) ((1, 1))\"));\n}\n\nTEST(PP_SEQ, internal_different_seq_product) {\n#define SEQ1 (0)(1)\n#define SEQ2 (a)(b)\n  std::string expanded(OF_PP_STRINGIZE(OF_PP_INTERNAL_SEQ_PRODUCT(SEQ1, SEQ2)));\n#undef SEQ1\n#undef SEQ2\n  ASSERT_TRUE((expanded == \"((0, a)) ((1, a)) ((0, b)) ((1, b))\")\n              || (expanded == \"((0, a)) ((1, a))  ((0, b)) ((1, b))\"));\n}\n\nTEST(PP_SEQ, internal_seq_product_for_each) {\n#define SEQ (0)(1)\n#define MAKE_ENTRY(x, y) {OF_PP_STRINGIZE(OF_PP_CAT(x, y)), x || y},\n  std::unordered_map<std::string, bool> or_table = {\n      OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, OF_PP_INTERNAL_SEQ_PRODUCT(SEQ, SEQ))};\n#undef MAKE_ENTRY\n#undef SEQ\n  ASSERT_EQ(or_table[\"00\"], false);\n  ASSERT_EQ(or_table[\"01\"], true);\n  ASSERT_EQ(or_table[\"10\"], true);\n  ASSERT_EQ(or_table[\"11\"], true);\n}\n\nTEST(PP, stringize) {\n  ASSERT_EQ(OF_PP_STRINGIZE(foo), \"foo\");\n  ASSERT_EQ(OF_PP_STRINGIZE(bar), \"bar\");\n}\n\nTEST(PP, concate) {\n  ASSERT_EQ(OF_PP_CAT(OF_PP_, STRINGIZE)(foo), \"foo\");\n  ASSERT_EQ(OF_PP_CAT(OF_PP_, STRINGIZE)(bar), 
\"bar\");\n}\n\nTEST(PP_SEQ, make_tuple_seq) { ASSERT_EQ(OF_PP_STRINGIZE(OF_PP_MAKE_TUPLE_SEQ(1, 2)), \"((1, 2))\"); }\n\nTEST(PP_SEQ, for_each_tuple) {\n#define SEQ ((1, 1))((2, 2))((3, 3))((4, 4))\n#define MAKE_ENTRY(x, y) {x, y},\n  std::unordered_map<int, int> identity = {OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SEQ)};\n#undef MAKE_ENTRY\n#undef SEQ\n  for (int i = 1; i <= 4; ++i) { ASSERT_EQ(i, identity[i]); }\n}\n\nTEST(PP_SEQ, outter_for_each_tuple) {\n#define SEQ ((1, 1))((2, 2))((3, 3))((4, 4))\n#define MAKE_ENTRY(x, y) {x, y},\n  std::unordered_map<int, int> identity = {OF_PP_OUTTER_FOR_EACH_TUPLE(MAKE_ENTRY, SEQ)};\n#undef MAKE_ENTRY\n#undef SEQ\n  for (int i = 1; i <= 4; ++i) { ASSERT_EQ(i, identity[i]); }\n}\n\nTEST(PP_SEQ, nested_for_each_tuple) {\n#define SEQ ((0))((1))((2))((3))\n#define MAKE_INNER(x) x,\n#define MAKE_OUTTER(x) {OF_PP_FOR_EACH_TUPLE(MAKE_INNER, SEQ)},\n  std::vector<std::vector<int>> table = {OF_PP_OUTTER_FOR_EACH_TUPLE(MAKE_OUTTER, SEQ)};\n#undef MAKE_OUTTER\n#undef MAKE_INNER\n#undef SEQ\n  ASSERT_EQ(table.size(), 4);\n  for (int i = 0; i < 4; ++i) {\n    ASSERT_EQ(table[i].size(), 4);\n    for (int j = 0; j < 4; ++j) { ASSERT_EQ(j, table[i][j]); }\n  }\n}\n\nTEST(PP_SEQ, seq_product_for_each) {\n#define SEQ (0)(1)\n#define MAKE_ENTRY(x, y) {OF_PP_STRINGIZE(OF_PP_CAT(x, y)), x || y},\n  std::unordered_map<std::string, bool> or_table = {\n      OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_ENTRY, SEQ, SEQ)};\n#undef MAKE_ENTRY\n#undef SEQ\n  ASSERT_EQ(or_table[\"00\"], false);\n  ASSERT_EQ(or_table[\"01\"], true);\n  ASSERT_EQ(or_table[\"10\"], true);\n  ASSERT_EQ(or_table[\"11\"], true);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/process_state.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_PROCESS_STATE_H_\n#define ONEFLOW_CORE_COMMON_PROCESS_STATE_H_\n\n#if defined(_MSC_VER)\n#include <WinSock2.h>\n#include <direct.h>\n#include <stdlib.h>\n#pragma comment(lib, \"Ws2_32.lib\")\n#else\n#include <unistd.h>\n#endif\n#include <memory>\n#include <string>\n\nnamespace oneflow {\n\ninline std::string GetCwd() {\n  size_t len = 128;\n  std::unique_ptr<char[]> a(new char[len]);\n  for (;;) {\n    char* p = getcwd(a.get(), len);\n    if (p != NULL) {\n      return p;\n    } else if (errno == ERANGE) {\n      len += len;\n      a.reset(new char[len]);\n    } else {\n      return NULL;\n    }\n  }\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_PROCESS_STATE_H_\n"
  },
  {
    "path": "oneflow/core/common/protobuf.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/shape.pb.h\"\n#include \"oneflow/core/common/sequential.pb.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/register/blob_desc.pb.h\"\n#include <google/protobuf/io/coded_stream.h>\n#include <google/protobuf/io/zero_copy_stream_impl.h>\n#include <google/protobuf/text_format.h>\n\nnamespace oneflow {\n\n// parse protobuf message from .prototxt file\nbool TryParseProtoFromTextFile(const std::string& file_path, PbMessage* proto) {\n  std::ifstream in_stream(file_path.c_str(), std::ifstream::in);\n  google::protobuf::io::IstreamInputStream input(&in_stream);\n  return google::protobuf::TextFormat::Parse(&input, proto);\n}\n\nvoid ParseProtoFromTextFile(const std::string& file_path, PbMessage* proto) {\n  CHECK(TryParseProtoFromTextFile(file_path, proto));\n}\n\n// parse protobuf message from .pb file\nbool TryParseProtoFromPbFile(const std::string& file_path, PbMessage* proto) {\n  std::ifstream in_stream(file_path.c_str(), std::ifstream::in | std::ifstream::binary);\n  return proto->ParseFromIstream(&in_stream);\n}\n\nvoid ParseProtoFromPbFile(const std::string& file_path, PbMessage* proto) {\n  CHECK(TryParseProtoFromPbFile(file_path, proto));\n}\n\nvoid PrintProtoToTextFile(const PbMessage& proto, const std::string& file_path) {\n  std::ofstream out_stream(file_path.c_str(), std::ofstream::out | std::ofstream::trunc);\n  google::protobuf::io::OstreamOutputStream output(&out_stream);\n  CHECK(google::protobuf::TextFormat::Print(proto, &output));\n}\n\nstd::string PbMessage2TxtString(const PbMessage& proto) {\n  std::string str;\n  PbMessage2TxtString(proto, &str);\n  return str;\n}\n\nvoid PbMessage2TxtString(const PbMessage& proto, std::string* str) {\n  google::protobuf::TextFormat::PrintToString(proto, str);\n}\n\nbool TxtString2PbMessage(const std::string& proto_str, PbMessage* msg) {\n  return google::protobuf::TextFormat::ParseFromString(proto_str, msg);\n}\n\nbool FieldDefinedInPbMessage(const PbMessage& msg, const std::string& field_name) {\n  PROTOBUF_GET_FIELDDESC(msg, field_name);\n  return fd != nullptr;\n}\n\n#define DEFINE_GET_VAL_FROM_PBMESSAGE(cpp_type, pb_type_name)                                   \\\n  template<>                                                                                    \\\n  cpp_type GetValFromPbMessage<cpp_type>(const PbMessage& msg, const std::string& field_name) { \\\n    PROTOBUF_REFLECTION(msg, field_name);                                                       \\\n    return r->Get##pb_type_name(msg, fd);                                                       \\\n  }\n\nOF_PP_FOR_EACH_TUPLE(DEFINE_GET_VAL_FROM_PBMESSAGE,\n                     PROTOBUF_BASIC_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(const PbMessage&, Message))\n\n#define DEFINE_SET_VAL_IN_PBMESSAGE(cpp_type, pb_type_name)                                    \\\n  template<>   
                                                                                \\\n  void SetValInPbMessage(PbMessage* msg, const std::string& field_name, const cpp_type& val) { \\\n    PROTOBUF_REFLECTION((*msg), field_name);                                                   \\\n    r->Set##pb_type_name(msg, fd, val);                                                        \\\n  }\n\nOF_PP_FOR_EACH_TUPLE(DEFINE_SET_VAL_IN_PBMESSAGE, PROTOBUF_BASIC_DATA_TYPE_SEQ)\n\nconst PbMessage& GetMessageInPbMessage(const PbMessage& msg, const std::string& field_name) {\n  PROTOBUF_REFLECTION(msg, field_name);\n  return r->GetMessage(msg, fd);\n}\n\nPbMessage* MutableMessageInPbMessage(PbMessage* msg, const std::string& field_name) {\n  PROTOBUF_REFLECTION((*msg), field_name);\n  return r->MutableMessage(msg, fd);\n}\n\nconst PbMessage& GetMessageInPbMessage(const PbMessage& msg, int field_index) {\n  const auto* d = const_cast<google::protobuf::Descriptor*>(msg.GetDescriptor());\n  const auto* fd = const_cast<PbFd*>(d->FindFieldByNumber(field_index));\n  CHECK_NOTNULL(fd);\n  const auto* r = const_cast<google::protobuf::Reflection*>(msg.GetReflection());\n  return r->GetMessage(msg, fd);\n}\n\nPbMessage* MutableMessageInPbMessage(PbMessage* msg, int field_index) {\n  const auto* d = const_cast<google::protobuf::Descriptor*>(msg->GetDescriptor());\n  const auto* fd = const_cast<PbFd*>(d->FindFieldByNumber(field_index));\n  CHECK_NOTNULL(fd);\n  const auto* r = const_cast<google::protobuf::Reflection*>(msg->GetReflection());\n  return r->MutableMessage(msg, fd);\n}\n\n#define DECLARE_GETTER_FUNC_HEADER(type) \\\n  template<>                             \\\n  type GetValFromPbMessage<type>(const PbMessage& msg, const std::string& field_name)\n\n#define DECLARE_SETTER_FUNC_HEADER(type) \\\n  template<>                             \\\n  void SetValInPbMessage<type>(PbMessage * msg, const std::string& field_name, const type& val)\n\n#define DEFINE_MESSAGE_VAL_GETTER_AND_SETTER(message_type)              \\\n  DECLARE_GETTER_FUNC_HEADER(message_type) {                            \\\n    PROTOBUF_REFLECTION(msg, field_name);                               \\\n    return *dynamic_cast<const message_type*>(&r->GetMessage(msg, fd)); \\\n  }                                                                     \\\n  DECLARE_SETTER_FUNC_HEADER(message_type) {                            \\\n    PROTOBUF_REFLECTION((*msg), field_name);                            \\\n    r->MutableMessage(msg, fd)->CopyFrom(val);                          \\\n  }\n\nDEFINE_MESSAGE_VAL_GETTER_AND_SETTER(ShapeProto);\nDEFINE_MESSAGE_VAL_GETTER_AND_SETTER(Int64ListProto);\n\n#define DEFINE_ENUM_VAL_GETTER_AND_SETTER(enum_type)         \\\n  DECLARE_GETTER_FUNC_HEADER(enum_type) {                    \\\n    PROTOBUF_REFLECTION(msg, field_name);                    \\\n    return static_cast<enum_type>(r->GetEnumValue(msg, fd)); \\\n  }                                                          \\\n  DECLARE_SETTER_FUNC_HEADER(enum_type) {                    \\\n    PROTOBUF_REFLECTION((*msg), field_name);                 \\\n    r->SetEnumValue(msg, fd, val);                           \\\n  }\n\nDEFINE_ENUM_VAL_GETTER_AND_SETTER(DataType);\n\n#define DEFINE_VECTOR_VAL_GETTER_AND_SETTER(vec_type, vec_type_name)                        \\\n  DECLARE_GETTER_FUNC_HEADER(vec_type) {                                                    \\\n    PROTOBUF_REFLECTION(msg, field_name);                                                   \\\n    int32_t 
field_size = r->FieldSize(msg, fd);                                             \\\n    vec_type retval(field_size);                                                            \\\n    for (int i = 0; i < field_size; ++i) { retval[i] = r->Get##vec_type_name(msg, fd, i); } \\\n    return retval;                                                                          \\\n  }                                                                                         \\\n  DECLARE_SETTER_FUNC_HEADER(vec_type) {                                                    \\\n    PROTOBUF_REFLECTION((*msg), field_name);                                                \\\n    for (int i = 0; i < val.size(); ++i) { r->Set##vec_type_name(msg, fd, i, val[i]); }     \\\n  }\n\n#define MAKE_REPEATED_TUPLE_SEQ(type, type_name) \\\n  OF_PP_MAKE_TUPLE_SEQ(std::vector<type>, Repeated##type_name)\n\n#define PROTOBUF_BASIC_REPEATED_DATA_TYPE_SEQ  \\\n  MAKE_REPEATED_TUPLE_SEQ(std::string, String) \\\n  MAKE_REPEATED_TUPLE_SEQ(int32_t, Int32)      \\\n  MAKE_REPEATED_TUPLE_SEQ(uint32_t, UInt32)    \\\n  MAKE_REPEATED_TUPLE_SEQ(int64_t, Int64)      \\\n  MAKE_REPEATED_TUPLE_SEQ(uint64_t, UInt64)    \\\n  MAKE_REPEATED_TUPLE_SEQ(float, Float)        \\\n  MAKE_REPEATED_TUPLE_SEQ(double, Double)      \\\n  MAKE_REPEATED_TUPLE_SEQ(int16_t, EnumValue)  \\\n  MAKE_REPEATED_TUPLE_SEQ(bool, Bool)\n\nOF_PP_FOR_EACH_TUPLE(DEFINE_VECTOR_VAL_GETTER_AND_SETTER, PROTOBUF_BASIC_REPEATED_DATA_TYPE_SEQ);\n\n#define DEFINE_ADD_VAL_IN_PBRF(cpp_type, pb_type_name)                                    \\\n  template<>                                                                              \\\n  void AddValInPbRf(PbMessage* msg, const std::string& field_name, const cpp_type& val) { \\\n    PROTOBUF_REFLECTION((*msg), field_name);                                              \\\n    r->Add##pb_type_name(msg, fd, val);                                                   \\\n  }\n\nOF_PP_FOR_EACH_TUPLE(DEFINE_ADD_VAL_IN_PBRF, PROTOBUF_BASIC_DATA_TYPE_SEQ)\n\nstd::pair<std::string, int32_t> GetFieldNameAndIndex4StrVal(const std::string& fd_name_with_idx) {\n  std::string field_name;\n  int32_t idx = 0;\n  GetPrefixAndIndex(fd_name_with_idx, &field_name, &idx);\n  // the index parsed from the field name must be non-negative\n  CHECK_GE(idx, 0);\n  return std::make_pair(field_name, idx);\n}\n\nstd::string GetStrValInPbFdOrPbRpf(const PbMessage& msg, const std::string& fd_name_may_have_idx) {\n  const PbFd* fd = msg.GetDescriptor()->FindFieldByName(fd_name_may_have_idx);\n  if (fd) {\n    return GetValFromPbMessage<std::string>(msg, fd_name_may_have_idx);\n  } else {\n    const std::pair<std::string, int32_t> prefix_idx =\n        GetFieldNameAndIndex4StrVal(fd_name_may_have_idx);\n    return GetPbRpfFromPbMessage<std::string>(msg, prefix_idx.first).Get(prefix_idx.second);\n  }\n}\n\nbool HasStrFieldInPbFdOrPbRpf(const PbMessage& msg, const std::string& fd_name_may_have_idx) {\n  const PbFd* fd = msg.GetDescriptor()->FindFieldByName(fd_name_may_have_idx);\n  if (fd != nullptr) { return true; }\n  std::string field_name;\n  int32_t index = 0;\n  return TryGetPrefixAndIndex(fd_name_may_have_idx, &field_name, &index);\n}\n\nstd::string ReplaceStrValInPbFdOrPbRpf(PbMessage* msg, const std::string& fd_name_may_have_idx,\n                                       const std::string& new_val) {\n  const PbFd* fd = msg->GetDescriptor()->FindFieldByName(fd_name_may_have_idx);\n  std::string old_val;\n  if (fd) {\n    old_val = GetValFromPbMessage<std::string>(*msg, fd_name_may_have_idx);\n    SetValInPbMessage<std::string>(msg, fd_name_may_have_idx, new_val);\n  } else {\n    const std::pair<std::string, int32_t> prefix_idx =\n        GetFieldNameAndIndex4StrVal(fd_name_may_have_idx);\n    old_val = GetPbRpfFromPbMessage<std::string>(*msg, prefix_idx.first).Get(prefix_idx.second);\n    PbRpf<std::string>* rpf = MutPbRpfFromPbMessage<std::string>(msg, prefix_idx.first);\n    *rpf->Mutable(prefix_idx.second) = new_val;\n  }\n  return old_val;\n}\n\nPersistentOutStream& operator<<(PersistentOutStream& out_stream, const PbMessage& msg) {\n  std::string msg_bin;\n  msg.SerializeToString(&msg_bin);\n  int64_t msg_size = msg_bin.size();\n  CHECK_GT(msg_size, 0);\n  out_stream << msg_size << msg_bin;\n  return out_stream;\n}\n\nbool operator==(const BlobDescProto& lhs, const BlobDescProto& rhs) {\n  return PbMd().Equivalent(lhs, rhs);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/protobuf.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_PROTOBUF_H_\n#define ONEFLOW_CORE_COMMON_PROTOBUF_H_\n\n#ifdef _MSC_VER\n#include <io.h>\n#endif\n#include <google/protobuf/descriptor.h>\n#include <google/protobuf/map.h>\n#include <google/protobuf/message.h>\n#include <google/protobuf/util/message_differencer.h>\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/register/logical_blob_id.pb.h\"\n#include \"oneflow/core/register/op_blob_arg.pb.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/job/sbp_parallel.pb.h\"\n#include \"oneflow/core/job/job_conf.pb.h\"\n#include \"oneflow/core/job/scope.pb.h\"\n#include \"oneflow/core/persistence/persistent_out_stream.h\"\n\nnamespace oneflow {\n\nusing PbMessage = google::protobuf::Message;\ntemplate<typename T>\nusing PbRf = google::protobuf::RepeatedField<T>;\ntemplate<typename T>\nusing PbRpf = google::protobuf::RepeatedPtrField<T>;\ntemplate<typename T1, typename T2>\nusing PbMapPair = google::protobuf::MapPair<T1, T2>;\ntemplate<typename K, typename V>\nusing PbMap = google::protobuf::Map<K, V>;\nusing PbFd = google::protobuf::FieldDescriptor;\nusing PbMd = google::protobuf::util::MessageDifferencer;\n\n#define PROTOBUF_BASIC_DATA_TYPE_SEQ        \\\n  OF_PP_MAKE_TUPLE_SEQ(std::string, String) \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, Int32)      \\\n  OF_PP_MAKE_TUPLE_SEQ(uint32_t, UInt32)    \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, Int64)      \\\n  OF_PP_MAKE_TUPLE_SEQ(uint64_t, UInt64)    \\\n  OF_PP_MAKE_TUPLE_SEQ(float, Float)        \\\n  OF_PP_MAKE_TUPLE_SEQ(double, Double)      \\\n  OF_PP_MAKE_TUPLE_SEQ(int16_t, EnumValue)  \\\n  OF_PP_MAKE_TUPLE_SEQ(bool, Bool)\n\n#define PROTOBUF_GET_FIELDDESC(msg, field_name)                            \\\n  auto d = const_cast<google::protobuf::Descriptor*>(msg.GetDescriptor()); \\\n  auto fd = const_cast<PbFd*>(d->FindFieldByName(field_name));\n\n#define PROTOBUF_REFLECTION(msg, field_name) \\\n  PROTOBUF_GET_FIELDDESC(msg, field_name)    \\\n  CHECK_NOTNULL(fd);                         \\\n  auto r = const_cast<google::protobuf::Reflection*>(msg.GetReflection());\n\n// Prototxt <-> File\nbool TryParseProtoFromTextFile(const std::string& file_path, PbMessage* proto);\nvoid ParseProtoFromTextFile(const std::string& file_path, PbMessage* proto);\nbool TryParseProtoFromPbFile(const std::string& file_path, PbMessage* proto);\nvoid ParseProtoFromPbFile(const std::string& file_path, PbMessage* proto);\nvoid PrintProtoToTextFile(const PbMessage& proto, const std::string& file_path);\nstd::string PbMessage2TxtString(const PbMessage& proto);\nvoid PbMessage2TxtString(const PbMessage& proto, std::string* str);\nbool TxtString2PbMessage(const std::string& proto_str, PbMessage* proto);\n\n// Does PbMessage have the field_name\nbool FieldDefinedInPbMessage(const PbMessage&, const std::string& field_name);\n\n// Get From 
PbMessage\ntemplate<typename T>\nT GetValFromPbMessage(const PbMessage&, const std::string& field_name);\n\ntemplate<typename T>\nconst PbRf<T>& GetPbRfFromPbMessage(const PbMessage& msg, const std::string& field_name) {\n  PROTOBUF_REFLECTION(msg, field_name);\n  return r->GetRepeatedField<T>(msg, fd);\n}\n\ntemplate<typename T>\nconst PbRpf<T>& GetPbRpfFromPbMessage(const PbMessage& msg, const std::string& field_name) {\n  PROTOBUF_REFLECTION(msg, field_name);\n  return r->GetRepeatedPtrField<T>(msg, fd);\n}\n\ntemplate<typename T>\nPbRpf<T>* MutPbRpfFromPbMessage(PbMessage* msg, const std::string& field_name) {\n  PROTOBUF_REFLECTION((*msg), field_name);\n  return r->MutableRepeatedPtrField<T>(msg, fd);\n}\n\n// Set In PbMessage\n\ntemplate<typename T>\nvoid SetValInPbMessage(PbMessage* msg, const std::string& field_name, const T& val);\n\nconst PbMessage& GetMessageInPbMessage(const PbMessage& msg, int field_index);\nconst PbMessage& GetMessageInPbMessage(const PbMessage& msg, const std::string& field_name);\n\nPbMessage* MutableMessageInPbMessage(PbMessage*, const std::string& field_name);\nPbMessage* MutableMessageInPbMessage(PbMessage*, int field_index);\n\n// Get/Replace str val maybe repeated;  field_name with index is like \"name_0\"\nstd::pair<std::string, int32_t> GetFieldNameAndIndex4StrVal(const std::string& fd_name_with_idx);\nstd::string GetStrValInPbFdOrPbRpf(const PbMessage& msg, const std::string& fd_name_may_have_idx);\nbool HasStrFieldInPbFdOrPbRpf(const PbMessage& msg, const std::string& fd_name_may_have_idx);\n// return old value\nstd::string ReplaceStrValInPbFdOrPbRpf(PbMessage* msg, const std::string& fd_name_may_have_idx,\n                                       const std::string& new_val);\n\n// Add In PbMessage RepeatedField\n\ntemplate<typename T>\nvoid AddValInPbRf(PbMessage*, const std::string& field_name, const T& val);\n\n// PbRf <-> std::vector\n\ntemplate<typename T>\ninline std::vector<T> PbRf2StdVec(const PbRf<T>& rf) {\n  return std::vector<T>(rf.begin(), rf.end());\n}\n\ntemplate<typename T>\ninline PbRf<T> StdVec2PbRf(const std::vector<T>& vec) {\n  return PbRf<T>(vec.begin(), vec.end());\n}\n\n// PbRpf <-> std::vector\ntemplate<typename T>\ninline std::vector<T> PbRpf2StdVec(const PbRpf<T>& rpf) {\n  return std::vector<T>(rpf.begin(), rpf.end());\n}\n\ntemplate<typename T>\ninline PbRpf<T> StdVec2PbRpf(const std::vector<T>& vec) {\n  using RetType = PbRpf<T>;\n  return RetType(vec.begin(), vec.end());\n}\n\n// ProtoMap <-> HashMap\ntemplate<typename K, typename V>\nHashMap<K, V> PbMap2HashMap(const google::protobuf::Map<K, V>& pb_map) {\n  return HashMap<K, V>(pb_map.begin(), pb_map.end());\n}\n\ntemplate<typename K, typename V>\ngoogle::protobuf::Map<K, V> HashMap2PbMap(const HashMap<K, V>& hash_map) {\n  using RetType = google::protobuf::Map<K, V>;\n  return RetType(hash_map.begin(), hash_map.end());\n}\n\n// If value exists in RepeatedField\ntemplate<typename T>\nbool IsInRepeatedField(const PbRf<T>& repeated_field, const T& value) {\n  return std::find(repeated_field.cbegin(), repeated_field.cend(), value) != repeated_field.cend();\n}\n\n// LBI compare operator\n\ninline bool operator<(const LogicalBlobId& lhs, const LogicalBlobId& rhs) {\n  if (lhs.op_name() != rhs.op_name()) { return lhs.op_name() < rhs.op_name(); }\n  if (lhs.blob_name() != rhs.blob_name()) { return lhs.blob_name() < rhs.blob_name(); }\n  return false;\n}\n\ninline bool operator==(const LogicalBlobId& lhs, const LogicalBlobId& rhs) {\n  return lhs.op_name() == rhs.op_name() && 
lhs.blob_name() == rhs.blob_name();\n}\n\ninline bool operator!=(const LogicalBlobId& lhs, const LogicalBlobId& rhs) { return !(lhs == rhs); }\n\ninline bool operator==(const OpBlobArg& lhs, const OpBlobArg& rhs) {\n  return PbMd().Equals(lhs, rhs);\n}\n\ninline bool operator!=(const OpBlobArg& lhs, const OpBlobArg& rhs) { return !(lhs == rhs); }\n\nclass BlobDescProto;\nbool operator==(const BlobDescProto& lhs, const BlobDescProto& rhs);\ninline bool operator!=(const BlobDescProto& lhs, const BlobDescProto& rhs) { return !(lhs == rhs); }\n\ninline bool operator==(const JobConfigProto& lhs, const JobConfigProto& rhs) {\n  return PbMd().Equals(lhs, rhs);\n}\n\ninline bool operator==(const ScopeProto& lhs, const ScopeProto& rhs) {\n  return PbMd().Equals(lhs, rhs);\n}\n\n// Persistent\n\nPersistentOutStream& operator<<(PersistentOutStream&, const PbMessage&);\n\ntemplate<typename T>\nstruct SerializedHashPb {\n  size_t operator()(const T& pb) const {\n    std::string serialized_string;\n    pb.SerializeToString(&serialized_string);\n    return std::hash<std::string>()(serialized_string);\n  }\n};\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::DataType> {\n  size_t operator()(const oneflow::DataType data_type) const {\n    return std::hash<int64_t>()(data_type);\n  }\n};\n\ntemplate<>\nstruct hash<oneflow::LogicalBlobId> {\n  size_t operator()(const oneflow::LogicalBlobId& lbi) const {\n    using namespace oneflow;\n    return Hash(lbi.op_name(), lbi.blob_name());\n  }\n};\n\ntemplate<>\nstruct hash<oneflow::OpBlobArg> {\n  size_t operator()(const oneflow::OpBlobArg& oba) const {\n    using namespace oneflow;\n    return Hash(oba.op_name(), oba.bn_in_op());\n  }\n};\n\ntemplate<>\nstruct hash<oneflow::SbpParallel> {\n  size_t operator()(const oneflow::SbpParallel& sbp_parallel) const {\n    using namespace oneflow;\n    size_t ret = 0;\n    if (sbp_parallel.has_broadcast_parallel()) {\n      AddHash(&ret, std::string(\"B\"));\n    } else if (sbp_parallel.has_partial_sum_parallel()) {\n      AddHash(&ret, std::string(\"P\"));\n    } else if (sbp_parallel.has_split_parallel()) {\n      AddHash(&ret, std::string(\"S\"));\n      AddHash(&ret, sbp_parallel.split_parallel().axis());\n    } else {\n      UNIMPLEMENTED();\n    }\n    return ret;\n  }\n};\n\ntemplate<>\nstruct hash<oneflow::NdSbp> {\n  size_t operator()(const oneflow::NdSbp& nd_sbp) const {\n    const auto& sbp_hash = std::hash<oneflow::SbpParallel>();\n    size_t hash = 0;\n    for (int i = 0; i < nd_sbp.sbp_parallel_size(); ++i) {\n      oneflow::HashCombine(&hash, sbp_hash(nd_sbp.sbp_parallel(i)));\n    }\n    return hash;\n  }\n};\n\ntemplate<>\nstruct hash<oneflow::JobConfigProto> {\n  size_t operator()(const oneflow::JobConfigProto& job_conf) const {\n    return oneflow::SerializedHashPb<oneflow::JobConfigProto>()(job_conf);\n  }\n};\n\ntemplate<>\nstruct hash<oneflow::ScopeProto> {\n  size_t operator()(const oneflow::ScopeProto& scope) const {\n    return oneflow::SerializedHashPb<oneflow::ScopeProto>()(scope);\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_COMMON_PROTOBUF_H_\n"
  },
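  {
    "path": "oneflow/core/common/protobuf_example_test.cpp",
    "content": "// NOTE: illustrative sketch, not a file from the original OneFlow tree (the path is\n// hypothetical). It shows the intended use of the reflection helpers declared in\n// protobuf.h, exercised against ShapeProto, whose only field is `repeated int64 dim = 1`\n// (see shape.proto).\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/shape.pb.h\"\n#include \"gtest/gtest.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(ProtobufUtil, reflection_helpers) {\n  ShapeProto shape_proto;\n  // Append to the repeated field by name via reflection.\n  AddValInPbRf<int64_t>(&shape_proto, \"dim\", 2);\n  AddValInPbRf<int64_t>(&shape_proto, \"dim\", 3);\n  ASSERT_TRUE(FieldDefinedInPbMessage(shape_proto, \"dim\"));\n\n  // Read the repeated field back and convert it to std::vector.\n  const PbRf<int64_t>& dims = GetPbRfFromPbMessage<int64_t>(shape_proto, \"dim\");\n  ASSERT_TRUE(IsInRepeatedField(dims, int64_t(3)));\n  std::vector<int64_t> dim_vec = PbRf2StdVec(dims);\n  ASSERT_EQ(dim_vec.size(), 2);\n\n  // Prototxt round trip: serialize to text and parse it back.\n  ShapeProto parsed;\n  ASSERT_TRUE(TxtString2PbMessage(PbMessage2TxtString(shape_proto), &parsed));\n  ASSERT_EQ(parsed.dim_size(), 2);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },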
  {
    "path": "oneflow/core/common/range.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/range.h\"\n\nnamespace oneflow {\n\nRange::Range(const RangeProto& range_proto) {\n  begin_ = range_proto.begin();\n  end_ = range_proto.end();\n}\n\nvoid Range::ToProto(RangeProto* ret) const {\n  ret->set_begin(begin_);\n  ret->set_end(end_);\n}\n\nMaybe<void> Range::ForEachSubRange(\n    int64_t sub_range_size, const std::function<Maybe<void>(const Range&)>& DoEachRange) const {\n  CHECK_EQ_OR_RETURN(size() % sub_range_size, 0);\n  int64_t start = begin();\n  for (; start < end(); start += sub_range_size) {\n    JUST(DoEachRange(Range(start, start + sub_range_size)));\n  }\n  CHECK_EQ_OR_RETURN(start, end());\n  return Maybe<void>::Ok();\n}\n\nRange FindIntersectant(const Range& lhs, const Range& rhs) {\n  if (lhs.end() > rhs.begin() && rhs.end() > lhs.begin()) {\n    int64_t left = lhs.begin() > rhs.begin() ? lhs.begin() : rhs.begin();\n    int64_t right = lhs.end() < rhs.end() ? lhs.end() : rhs.end();\n    return Range(left, right);\n  } else {\n    return Range(0, 0);\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/range.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_RANGE_H_\n#define ONEFLOW_CORE_COMMON_RANGE_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/range.pb.h\"\n\nnamespace oneflow {\n\nclass Range final {\n public:\n  // OF_DISALLOW_COPY_AND_MOVE(Range);\n  Range() : Range(0, 0) {}\n  ~Range() = default;\n\n  Range(int64_t begin, int64_t end) : begin_(begin), end_(end) {}\n  explicit Range(const RangeProto& range_proto);\n\n  bool operator==(const Range& rhs) const { return begin_ == rhs.begin_ && end_ == rhs.end_; }\n  bool operator!=(const Range& rhs) const { return !(*this == rhs); }\n\n  int64_t begin() const { return begin_; }\n  int64_t end() const { return end_; }\n\n  int64_t& mut_begin() { return begin_; }\n  int64_t& mut_end() { return end_; }\n\n  int64_t size() const { return end_ - begin_; }\n\n  Maybe<void> ForEachSubRange(int64_t sub_range_size,\n                              const std::function<Maybe<void>(const Range&)>& DoEachRange) const;\n\n  void ToProto(RangeProto* ret) const;\n\n private:\n  int64_t begin_;\n  int64_t end_;\n};\n\nRange FindIntersectant(const Range& lhs, const Range& rhs);\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::Range> {\n  size_t operator()(const oneflow::Range& range) const {\n    return oneflow::HashCombine(range.begin(), range.end());\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_COMMON_RANGE_H_\n"
  },
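  {
    "path": "oneflow/core/common/range_example_test.cpp",
    "content": "// NOTE: illustrative sketch, not a file from the original OneFlow tree (the path is\n// hypothetical). It shows Range construction, sub-range iteration via the Maybe-based\n// ForEachSubRange, and intersection via FindIntersectant (both declared in range.h).\n#include <vector>\n#include \"oneflow/core/common/range.h\"\n#include \"gtest/gtest.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(Range, for_each_sub_range) {\n  std::vector<Range> sub_ranges;\n  // [0, 8) split into sub-ranges of size 4 yields [0, 4) and [4, 8).\n  Maybe<void> ret = Range(0, 8).ForEachSubRange(4, [&](const Range& sub) -> Maybe<void> {\n    sub_ranges.emplace_back(sub);\n    return Maybe<void>::Ok();\n  });\n  ASSERT_TRUE(ret.IsOk());\n  ASSERT_EQ(sub_ranges.size(), 2);\n  ASSERT_TRUE(sub_ranges.at(1) == Range(4, 8));\n}\n\nTEST(Range, find_intersectant) {\n  // The overlap of [0, 4) and [2, 6) is [2, 4); disjoint ranges yield the empty Range(0, 0).\n  ASSERT_TRUE(FindIntersectant(Range(0, 4), Range(2, 6)) == Range(2, 4));\n  ASSERT_TRUE(FindIntersectant(Range(0, 2), Range(4, 6)) == Range(0, 0));\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },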
  {
    "path": "oneflow/core/common/range.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage RangeProto {\n  required int64 begin = 1;\n  required int64 end = 2;\n}\n"
  },
  {
    "path": "oneflow/core/common/registry_error.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/registry_error.h\"\n\nnamespace oneflow {\n\nnamespace {\nstd::shared_ptr<StackedError>* MutRegistryError() {\n  static std::shared_ptr<StackedError> registry_error;\n  return &registry_error;\n}\n}  // namespace\n\nMaybe<void> CheckAndClearRegistryFlag() {\n  if (!*MutRegistryError()) { return Maybe<void>::Ok(); }\n  std::shared_ptr<StackedError> registry_error_old = *MutRegistryError();\n  *MutRegistryError() = nullptr;\n  return registry_error_old;\n}\n\nvoid CatchRegistryError(const std::function<Maybe<void>()>& handler) {\n  const auto& maybe_error = TRY(handler());\n  if (!maybe_error.IsOk()) {\n    if (!*MutRegistryError()) { *MutRegistryError() = maybe_error.stacked_error(); }\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/registry_error.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_REGISTRY_ERROR_H\n#define ONEFLOW_CORE_COMMON_REGISTRY_ERROR_H\n\n#include <functional>\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\n// Note: there is a time interval between catching error and reporting an error,\n// any error occur in this interval can't be displayed.\nMaybe<void> CheckAndClearRegistryFlag();\nvoid CatchRegistryError(const std::function<Maybe<void>()>&);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_REGISTRY_ERROR_H\n"
  },
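  {
    "path": "oneflow/core/common/registry_error_example_test.cpp",
    "content": "// NOTE: illustrative sketch, not a file from the original OneFlow tree (the path is\n// hypothetical). It shows the latch-then-report flow of registry_error.h: a failing\n// handler is caught by CatchRegistryError and the stored error is surfaced exactly once\n// by CheckAndClearRegistryFlag.\n#include \"oneflow/core/common/registry_error.h\"\n#include \"gtest/gtest.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(RegistryError, catch_and_clear) {\n  // The failing handler does not throw; its error is stored in the registry flag.\n  CatchRegistryError([]() -> Maybe<void> {\n    CHECK_OR_RETURN(false) << \"some registration failed\";\n    return Maybe<void>::Ok();\n  });\n  // The stored error is returned once ...\n  ASSERT_FALSE(CheckAndClearRegistryFlag().IsOk());\n  // ... and the flag is cleared afterwards.\n  ASSERT_TRUE(CheckAndClearRegistryFlag().IsOk());\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },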
  {
    "path": "oneflow/core/common/scalar.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <complex>\n#include \"oneflow/core/common/scalar.h\"\n\nnamespace oneflow {\n\n#define DEFINE_SCALAR_BINARY_OP(op)                                             \\\n  Scalar& Scalar::operator op##=(const Scalar& other) {                         \\\n    if (IsComplex() || other.IsComplex()) {                                     \\\n      std::complex<double> val =                                                \\\n          Value<std::complex<double>>() op other.Value<std::complex<double>>(); \\\n      *this = val;                                                              \\\n    }                                                                           \\\n    if (IsFloatingPoint() || other.IsFloatingPoint()) {                         \\\n      double val = As<double>() op other.As<double>();                          \\\n      *this = val;                                                              \\\n    } else {                                                                    \\\n      int64_t val = As<int64_t>() op other.As<int64_t>();                       \\\n      *this = val;                                                              \\\n    }                                                                           \\\n    return *this;                                                               \\\n  }                                                                             \\\n  Scalar Scalar::operator op(const Scalar& other) const {                       \\\n    if (IsComplex() || other.IsComplex()) {                                     \\\n      std::complex<double> val =                                                \\\n          Value<std::complex<double>>() op other.Value<std::complex<double>>(); \\\n      return Scalar(val);                                                       \\\n    }                                                                           \\\n    if (IsFloatingPoint() || other.IsFloatingPoint()) {                         \\\n      double val = As<double>() op other.As<double>();                          \\\n      return Scalar(val);                                                       \\\n    }                                                                           \\\n    int64_t val = As<int64_t>() op other.As<int64_t>();                         \\\n    return Scalar(val);                                                         \\\n  }\n\nDEFINE_SCALAR_BINARY_OP(+);\nDEFINE_SCALAR_BINARY_OP(-);\nDEFINE_SCALAR_BINARY_OP(*);\nDEFINE_SCALAR_BINARY_OP(/);  // NOLINT\n#undef DEFINE_SCALAR_BINARY_OP\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/scalar.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_COMMON_SCALAR_H_\n#define ONEFLOW_CORE_COMMON_SCALAR_H_\n\n#include <type_traits>\n#include <complex>\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nclass Scalar {\n public:\n  Scalar() : Scalar(int32_t(0)) {}\n\n  template<typename T, typename std::enable_if<std::is_same<std::complex<float>, T>::value\n                                                   || std::is_same<std::complex<double>, T>::value,\n                                               int>::type = 0>\n  Scalar(const T& value) : value_{.c = {value.real(), value.imag()}}, active_tag_(HAS_C) {}\n\n  template<typename T, typename std::enable_if<std::is_same<T, bool>::value, int>::type = 0>\n  OF_DEVICE_FUNC Scalar(const T& value) : value_{.b = value}, active_tag_(HAS_B) {}\n\n  template<typename T, typename std::enable_if<\n                           std::is_integral<T>::value && std::is_signed<T>::value, int>::type = 0>\n  OF_DEVICE_FUNC Scalar(const T& value) : value_{.s = value}, active_tag_(HAS_S) {}\n\n  template<typename T,\n           typename std::enable_if<std::is_integral<T>::value && std::is_unsigned<T>::value\n                                       && !std::is_same<T, bool>::value,\n                                   int>::type = 0>\n  OF_DEVICE_FUNC Scalar(const T& value) : value_{.u = value}, active_tag_(HAS_U) {}\n\n  template<typename T, typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0>\n  OF_DEVICE_FUNC Scalar(const T& value) : value_{.d = value}, active_tag_(HAS_D) {}\n\n  template<typename T, typename std::enable_if<!std::is_same<T, Scalar>::value, int>::type = 0>\n  OF_DEVICE_FUNC Scalar& operator=(const T& value) {\n    *this = Scalar(value);\n    return *this;\n  }\n\n  OF_DEVICE_FUNC Scalar& operator=(const Scalar& other) {\n    value_ = other.value_;\n    active_tag_ = other.active_tag_;\n    return *this;\n  }\n\n  template<typename T, typename std::enable_if<std::is_scalar<T>::value, int>::type = 0>\n  OF_DEVICE_FUNC T As() const {\n    switch (active_tag_) {\n      case HAS_B: return static_cast<T>(value_.b);\n      case HAS_S: return static_cast<T>(value_.s);\n      case HAS_U: return static_cast<T>(value_.u);\n      case HAS_D: return static_cast<T>(value_.d);\n      default: assert(false); return 0;\n    }\n  }\n\n  template<typename T, typename std::enable_if<std::is_scalar<T>::value, int>::type = 0>\n  OF_DEVICE_FUNC T Value() const {\n    return As<T>();\n  }\n\n  template<typename T, typename std::enable_if<std::is_same<std::complex<float>, T>::value\n                                                   || std::is_same<std::complex<double>, T>::value,\n                                               int>::type = 0>\n  T Value() const {\n    if (!IsComplex()) { return T(As<double>(), 0.0); }\n    return T(value_.c.real, value_.c.imag);\n  }\n\n  bool IsBool() const { return 
active_tag_ == HAS_B; }\n  bool IsIntegral() const { return active_tag_ == HAS_S || active_tag_ == HAS_U; }\n  bool IsFloatingPoint() const { return active_tag_ == HAS_D; }\n  bool IsSigned() const { return active_tag_ == HAS_S || active_tag_ == HAS_D; }\n  bool IsUnsigned() const { return active_tag_ == HAS_U; }\n  bool IsComplex() const { return active_tag_ == HAS_C; }\n\n  Scalar operator+(const Scalar& other) const;\n  Scalar operator-(const Scalar& other) const;\n  Scalar operator*(const Scalar& other) const;\n  Scalar operator/(const Scalar& other) const;\n\n  Scalar& operator+=(const Scalar& other);\n  Scalar& operator-=(const Scalar& other);\n  Scalar& operator*=(const Scalar& other);\n  Scalar& operator/=(const Scalar& other);\n\n private:\n  union Value {\n    bool b;\n    int64_t s;\n    uint64_t u;\n    double d;\n    struct {\n      double real;\n      double imag;\n    } c;\n  } value_;\n  enum { HAS_B, HAS_S, HAS_U, HAS_D, HAS_C, HAS_NONE } active_tag_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_SCALAR_H_\n"
  },
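  {
    "path": "oneflow/core/common/scalar_example_test.cpp",
    "content": "// NOTE: illustrative sketch, not a file from the original OneFlow tree (the path is\n// hypothetical). It shows the tagged-union behaviour of Scalar and the promotion rules of\n// its arithmetic: integral op integral stays integral, any floating operand promotes to\n// double, and any complex operand promotes to std::complex<double>.\n#include <complex>\n#include \"oneflow/core/common/scalar.h\"\n#include \"gtest/gtest.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(Scalar, tags_and_promotion) {\n  Scalar i(3);    // signed integral tag\n  Scalar d(0.5);  // floating point tag\n  ASSERT_TRUE(i.IsIntegral());\n  ASSERT_TRUE(d.IsFloatingPoint());\n\n  ASSERT_TRUE((i * i).IsIntegral());\n  ASSERT_EQ((i * i).As<int64_t>(), 9);\n  ASSERT_TRUE((i + d).IsFloatingPoint());\n  ASSERT_DOUBLE_EQ((i + d).As<double>(), 3.5);\n\n  // A complex operand makes the whole expression complex.\n  Scalar c(std::complex<double>(1.0, 2.0));\n  ASSERT_TRUE((c + i).IsComplex());\n  ASSERT_EQ((c + i).Value<std::complex<double>>(), std::complex<double>(4.0, 2.0));\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },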
  {
    "path": "oneflow/core/common/sequential.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage Int64ListProto {\n  repeated int64 dim = 1;\n}\n\n\n"
  },
  {
    "path": "oneflow/core/common/shape.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/common/protobuf.h\"\n\nnamespace oneflow {\n\ntemplate<class T>\nint64_t ConstShapeMixIn<T>::elem_cnt() const {\n  return std::accumulate(tp()->begin(), tp()->end(), int64_t(1), std::multiplies<>());\n}\n\ntemplate<class T>\nint64_t ConstShapeMixIn<T>::At(int64_t index) const {\n  CHECK_GE(index, 0);\n  CHECK_LT(index, tp()->NumAxes()) << \" Shape: \" << tp()->DebugStr() << \" visit index: \" << index\n                                   << \" > num_axes: \" << tp()->NumAxes();\n  return (*tp())[index];\n}\n\ntemplate<class T>\nint64_t ConstShapeMixIn<T>::Count(int64_t begin_axis, int64_t end_axis) const {\n  CHECK(0 <= begin_axis && begin_axis <= end_axis && end_axis <= tp()->NumAxes())\n      << begin_axis << \" \" << end_axis;\n  int64_t cnt = 1;\n  for (int64_t i = begin_axis; i < end_axis; ++i) { cnt *= At(i); }\n  return cnt;\n}\ntemplate<class T>\nint64_t ConstShapeMixIn<T>::Count(int64_t begin_axis) const {\n  return Count(begin_axis, tp()->NumAxes());\n}\n\ntemplate<class T>\nbool ConstShapeMixIn<T>::Containing(ShapeView small_shape) const {\n  if (tp()->NumAxes() < small_shape.NumAxes()) { return false; }\n  FOR_RANGE(int, i, 0, small_shape.NumAxes()) {\n    if (tp()->At(i) != small_shape.At(i)) { return false; }\n  }\n  return true;\n}\n\ntemplate<class T>\nbool ConstShapeMixIn<T>::MatchBeforeLastDim(ShapeView next_shape) const {\n  if (tp()->NumAxes() != next_shape.NumAxes()) { return false; }\n  for (int64_t i = 0; i < tp()->NumAxes() - 1; ++i) {\n    if (next_shape.At(i) != tp()->At(i)) { return false; }\n  }\n  return true;\n}\n\ntemplate<class T>\nstd::string ConstShapeMixIn<T>::ToString() const {\n  std::stringstream ss;\n  int32_t idx = 0;\n  ss << \"(\";\n  for (int64_t dim : *tp()) {\n    ss << dim;\n    if (++idx != tp()->size() || tp()->size() == 1) { ss << \",\"; }\n  }\n  ss << \")\";\n  return ss.str();\n}\n\ntemplate<class T>\nstd::string ConstShapeMixIn<T>::DebugStr() const {\n  return ToString();\n}\n\ntemplate<class T>\nvoid ConstShapeMixIn<T>::ToProto(ShapeProto* ret) const {\n  *(ret->mutable_dim()) = PbRf<int64_t>(tp()->begin(), tp()->end());\n}\n\ntemplate<class T>\nbool ConstShapeMixIn<T>::operator==(const T& rhs) const {\n  if (this->NumAxes() != rhs.NumAxes()) { return false; }\n  FOR_RANGE(int, i, 0, this->NumAxes()) {\n    if (this->At(i) != rhs.At(i)) { return false; }\n  }\n  return true;\n}\n\ntemplate struct ConstShapeMixIn<Shape>;\ntemplate struct MutShapeMixIn<Shape>;\ntemplate struct ConstShapeMixIn<ShapeView>;\ntemplate struct ConstShapeMixIn<MutShapeView>;\ntemplate struct MutShapeMixIn<MutShapeView>;\n\nShape CreateReducedShape(ShapeView shape, const AxisVector& axis_vec) {\n  // For 0-dim Tensor\n  if (axis_vec.empty()) { return Shape({}); }\n  DimVector dim_vec;\n  shape.ToDimVector(&dim_vec);\n  for (int64_t axis : 
axis_vec) { dim_vec.at(ShiftNegativeAxis(axis, shape.NumAxes())) = 1; }\n  return Shape(std::move(dim_vec));\n}\n\nShape CreateLeftExtendedShape(ShapeView shape, int ndims_left_extend_to) {\n  CHECK_GE(ndims_left_extend_to, shape.NumAxes());\n  DimVector dim_vec(ndims_left_extend_to);\n  const size_t left_ones_num = ndims_left_extend_to - shape.NumAxes();\n  int i = 0;\n  for (; i < left_ones_num; ++i) { dim_vec.at(i) = 1LL; }\n  for (; i < ndims_left_extend_to; ++i) { dim_vec.at(i) = shape.At(i - left_ones_num); }\n  return Shape(std::move(dim_vec));\n}\n\nShape ExpandDimIf0D(const Shape& shape) {\n  if (shape.NumAxes() == 0) { return {1}; }\n  return shape;\n}\n\nShape ExpandDimIf0D(ShapeView shape) {\n  if (shape.NumAxes() == 0) { return {1}; }\n  return Shape(shape);\n}\n\nShape CreateReducedShapeOrOnesShape(ShapeView shape, const AxisVector& axis_vec) {\n  if (axis_vec.empty()) { return Shape::Ones(shape.NumAxes()); }\n  return CreateReducedShape(shape, axis_vec);\n}\n\nint64_t ShiftNegativeAxis(int64_t axis, const int64_t num_axes) {\n  if (axis < 0) { axis += num_axes; }\n  CHECK_GE(axis, 0);\n  CHECK_LT(axis, num_axes);\n  return axis;\n}\n\nShape::Shape(const DimVector& dim_vec) : DimVector(dim_vec), is_initialized_(true) {}\nShape::Shape(DimVector&& dim_vec) : DimVector(std::move(dim_vec)), is_initialized_(true) {}\nShape::Shape(const ShapeProto& shape_proto)\n    : DimVector(shape_proto.dim().begin(), shape_proto.dim().end()), is_initialized_(true) {}\nShape::Shape(ShapeView shape_view)\n    : DimVector(shape_view.begin(), shape_view.end()), is_initialized_(true) {}\n\nShape& Shape::CheckNumAxesIdenticalAndAssign(ShapeView shape_view) {\n  CHECK_EQ(NumAxes(), shape_view.NumAxes());\n  std::copy(shape_view.ptr(), shape_view.ptr() + shape_view.NumAxes(), data());\n  return *this;\n}\n\nShape& Shape::LeftOnesExtendedAssign(ShapeView shape_view) {\n  CHECK_GE(NumAxes(), shape_view.NumAxes());\n  size_t left_ones_size = NumAxes() - shape_view.NumAxes();\n  FOR_RANGE(int, i, 0, left_ones_size) { (*this)[i] = 1LL; }\n  std::copy(shape_view.ptr(), shape_view.ptr() + shape_view.NumAxes(), data() + left_ones_size);\n  return *this;\n}\n\nstd::ostream& operator<<(std::ostream& out, const Shape& shape) {\n  out << shape.DebugStr();\n  return out;\n}\n\nAxisVector Shape::ShiftNegativeAxisVec(const AxisVector& axis_vec) const {\n  const int64_t num_axes = this->NumAxes();\n  AxisVector ret = axis_vec;\n  for (int64_t i = 0; i < axis_vec.size(); i++) {\n    ret.at(i) = ShiftNegativeAxis(axis_vec.at(i), num_axes);\n  }\n  return ret;\n}\n\nShape Shape::RemoveOnes(const AxisVector& axis_vec) const {\n  DimVector dim_vec;\n  const AxisVector& axis_vec_shifted = ShiftNegativeAxisVec(axis_vec);\n  for (int64_t i = 0; i < this->dim_vec().size(); i++) {\n    if (std::find(axis_vec_shifted.begin(), axis_vec_shifted.end(), i) == axis_vec_shifted.end()) {\n      dim_vec.emplace_back(this->dim_vec().at(i));\n    } else {\n      CHECK_EQ(this->dim_vec().at(i), 1);\n    }\n  }\n  return Shape(dim_vec);\n}\n\nShape Shape::Ones(const int64_t num_axes) {\n  DimVector dim_vec(num_axes);\n  std::fill(dim_vec.begin(), dim_vec.end(), 1);\n  return Shape(dim_vec);\n}\n\nAxisVector Shape::Axes4BroadcastTo(ShapeView broadcast_shape) const {\n  AxisVector broadcast_axis_vec;\n  CHECK_EQ(broadcast_shape.NumAxes(), NumAxes());\n  for (int64_t i = 0; i < NumAxes(); i++) {\n    if (this->dim_vec().at(i) != broadcast_shape[i] && this->dim_vec().at(i) == 1) {\n      broadcast_axis_vec.emplace_back(i);\n    } else {\n      
CHECK_EQ(this->dim_vec().at(i), broadcast_shape[i]);\n    }\n  }\n  CHECK(!broadcast_axis_vec.empty());\n  return broadcast_axis_vec;\n}\n\nMaybe<Shape> Shape::Slice(int64_t start_dim, int64_t end_dim) const {\n  CHECK_OR_RETURN(start_dim >= 0 && end_dim >= start_dim);\n  int64_t ndims = this->NumAxes();\n  if (start_dim > ndims) { start_dim = ndims; }\n  if (end_dim > ndims) { end_dim = ndims; }\n  std::shared_ptr<Shape> shape = std::make_shared<Shape>();\n  shape->assign(this->begin() + start_dim, this->begin() + end_dim);\n  return shape;\n}\n\nMaybe<Shape> Shape::Slice(int64_t start_dim) const { return Slice(start_dim, NumAxes()); }\n\nbool Shape::operator==(const Shape& rhs) const {\n  if (is_initialized_ != rhs.is_initialized_) { return false; }\n  if (is_initialized_ == false) { return true; }\n  if (this->NumAxes() != rhs.NumAxes()) { return false; }\n  FOR_RANGE(int, i, 0, this->NumAxes()) {\n    if (this->At(i) != rhs.At(i)) { return false; }\n  }\n  return true;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/shape.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_SHAPE_H_\n#define ONEFLOW_CORE_COMMON_SHAPE_H_\n\n#include \"oneflow/core/common/shape.pb.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/shape_vec.h\"\n#include \"oneflow/core/common/optional.h\"\n\nnamespace oneflow {\n\nclass ShapeView;\nclass MutShapeView;\nclass ShapeProto;\n\nnamespace cfg {\nclass ShapeProto;\n}  // namespace cfg\n\n/**\n * NOTE:\n *\n * There are two widely used shape-related classes: Shape and ShapeView.\n * The differences are:\n * 1. Shape owns the data, and ShapeView does not.\n * 2. ShapeView is very lightweight, whose size is only 16 bytes (two int64_t).\n *    So it should be passed by value.\n *\n * When adding new functions accepting a shape as a parameter, please follow\n * the rules:\n * 1. If your function doesn't modify the shape, prefer\n *    ShapeView. Shape can be implicitly converted to ShapeView so the method\n *    with ShapeView parameter can accept both Shape and ShapeView actually.\n * 2. If your function modify the shape but doesn't affect\n *    its rank, prefer MutShapeView. The reason is the same with rule 1.\n * 3. Use Shape otherwise.\n *\n * When adding new member methods of Shape or ShapeView, please follow\n * the rules:\n * 1. If the method is shared between Shape and ShapeView (like `NumAxes()`)\n *    please add it to ConstShapeMixIn.\n * 2. If the method is shared between Shape and MutShapeView (like `Set()`)\n *    please add it to MutShapeMixIn.\n * 3. 
Otherwise, add it to a concrete class (Shape, ShapeView or MutShapeView).\n *\n */\ntemplate<class T>\nstruct ConstShapeMixIn {\n  using DimType = int64_t;\n\n  int64_t NumAxes() const { return tp()->size(); }\n  int64_t elem_cnt() const;\n  int64_t At(int64_t index) const;\n  int64_t Count(int64_t begin_axis, int64_t end_axis) const;\n  int64_t Count(int64_t begin_axis) const;\n  bool Containing(ShapeView small_shape) const;\n  bool MatchBeforeLastDim(ShapeView next_shape) const;\n  std::string ToString() const;\n\n  std::string DebugStr() const;\n\n  void ToProto(ShapeProto* ret) const;\n\n  template<typename StreamT>\n  void SerializeWithTextFormat(StreamT& out_stream) const {\n    for (int64_t dim : *this) { out_stream << std::to_string(dim) << ' '; }\n  }\n\n  bool operator==(const T& rhs) const;\n\n protected:\n  // tp means \"this pointer\"\n  T* tp() { return static_cast<T*>(this); }\n  const T* tp() const { return static_cast<const T*>(this); }\n};\n\ntemplate<class T>\nstruct MutShapeMixIn : public ConstShapeMixIn<T> {\n  void Set(int64_t index, int64_t val) {\n    CHECK_GE(index, 0);\n    CHECK_LT(index, this->tp()->NumAxes())\n        << \" Shape: \" << this->tp()->DebugStr() << \" visit index: \" << index\n        << \" > num_axes: \" << this->tp()->NumAxes();\n    (*this->tp())[index] = val;\n  }\n};\n\nclass Shape final : public DimVector, public MutShapeMixIn<Shape> {\n public:\n  // OF_DISALLOW_COPY_AND_MOVE(Shape);\n  using DimVector::DimVector;\n  Shape() : is_initialized_(false) {}\n  explicit Shape(const DimVector& dim_vec);\n  explicit Shape(DimVector&& dim_vec);\n  explicit Shape(const ShapeProto& shape_proto);\n  // explicit constructor from ShapeView\n  explicit Shape(ShapeView shape_view);\n  ~Shape() = default;\n  using DimVector::operator==;\n\n#define OVERRIDE_ADD_DATA_FUNC(func)              \\\n  template<typename... Args>                      \\\n  void func(Args... 
args) {                       \\\n    DimVector::func(std::forward<Args>(args)...); \\\n    is_initialized_ = true;                       \\\n  }\n\n  OVERRIDE_ADD_DATA_FUNC(assign)\n  OVERRIDE_ADD_DATA_FUNC(push_back)\n  OVERRIDE_ADD_DATA_FUNC(emplace_back)\n  OVERRIDE_ADD_DATA_FUNC(append)\n  OVERRIDE_ADD_DATA_FUNC(insert)\n  OVERRIDE_ADD_DATA_FUNC(resize)\n\n#undef OVERRIDE_ADD_DATA_FUNC\n\n  Shape& CheckNumAxesIdenticalAndAssign(ShapeView shape_view);\n  Shape& LeftOnesExtendedAssign(ShapeView shape_view);\n\n  // Getters and Setters\n  bool is_initialized() const { return is_initialized_; }\n  const DimVector& dim_vec() const { return *this; }\n  DimVector& dim_vec() { return *this; }\n  int64_t NumAxes() const {\n    CHECK(is_initialized());\n    return ConstShapeMixIn<Shape>::NumAxes();\n  }\n  AxisVector ShiftNegativeAxisVec(const AxisVector& axis_vec) const;\n  Shape RemoveOnes(const AxisVector& axis_vec) const;\n  static Shape Ones(const int64_t num_axes);\n  AxisVector Axes4BroadcastTo(ShapeView broadcast_dim_vec) const;\n\n  Maybe<Shape> Slice(int64_t start_dim, int64_t end_dim) const;\n  Maybe<Shape> Slice(int64_t start_dim) const;\n\n  bool operator==(const Shape& rhs) const;\n\n private:\n  // Set default value here because some constructors are inherited from DimVector\n  // TODO(daquexian): remove this field and make it initialized by construction\n  bool is_initialized_ = true;\n};\n\nint64_t ShiftNegativeAxis(int64_t axis, const int64_t num_axes);\n\nShape CreateReducedShape(ShapeView shape, const AxisVector& axis_vec);\nShape CreateLeftExtendedShape(ShapeView shape, int ndims_extend_to);\nShape ExpandDimIf0D(const Shape& shape);\nShape ExpandDimIf0D(ShapeView shape);\nShape CreateReducedShapeOrOnesShape(ShapeView shape, const AxisVector& axis_vec);\n\nstd::ostream& operator<<(std::ostream& out, const Shape& shape);\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::Shape> {\n  size_t operator()(const oneflow::Shape& shape) const {\n    if (!shape.is_initialized()) { return 0; }\n    size_t ret = shape.NumAxes();\n    FOR_RANGE(int, i, 0, shape.NumAxes()) { oneflow::AddHash(&ret, shape.At(i)); }\n    return ret;\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_COMMON_SHAPE_H_\n"
  },
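  {
    "path": "oneflow/core/common/shape_helpers_example_test.cpp",
    "content": "// NOTE: illustrative sketch, not a file from the original OneFlow tree (the path is\n// hypothetical). It exercises the Shape API and the free helpers declared in shape.h;\n// because Shape converts implicitly to ShapeView (see the NOTE in shape.h), a Shape can\n// be passed directly to helpers that take ShapeView.\n#include \"oneflow/core/common/shape.h\"\n#include \"gtest/gtest.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(Shape, helpers) {\n  Shape shape({2, 1, 4});\n  ASSERT_EQ(shape.NumAxes(), 3);\n  ASSERT_EQ(shape.elem_cnt(), 8);\n  ASSERT_EQ(shape.Count(1), 4);  // product of the dims in [1, NumAxes)\n\n  ASSERT_EQ(CreateLeftExtendedShape(shape, 5), Shape({1, 1, 2, 1, 4}));\n  // CreateReducedShape sets the listed axes to extent 1.\n  ASSERT_EQ(CreateReducedShape(shape, {0, 2}), Shape({1, 1, 1}));\n  // RemoveOnes drops the listed axes, which must all have extent 1.\n  ASSERT_EQ(shape.RemoveOnes({1}), Shape({2, 4}));\n  // A 0-dim shape is promoted to (1,) when needed.\n  ASSERT_EQ(ExpandDimIf0D(Shape({})), Shape({1}));\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },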
  {
    "path": "oneflow/core/common/shape.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\n// NOTE: shape.proto can be replaced with sequential.proto\n// for compatibility reasons, it will not be modified here.\nmessage ShapeProto {\n  repeated int64 dim = 1;\n}\n\n\n"
  },
  {
    "path": "oneflow/core/common/shape_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/shape.h\"\n#include \"gtest/gtest.h\"\n#include <functional>\n#include <algorithm>\n\nnamespace oneflow {\n\nnamespace test {\n\nTEST(Shape, constructor_0) {\n  Shape a;\n  ASSERT_EQ(a.is_initialized(), false);\n}\n\nTEST(Shape, function_test_1) {\n  Shape shape({4096, 16, 197, 197});\n  ASSERT_EQ(shape.is_initialized(), true);\n  ASSERT_EQ(shape.NumAxes(), 4);\n  ASSERT_EQ(shape.elem_cnt(), 2543386624);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/shape_vec.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_SHAPE_VEC_H_\n#define ONEFLOW_CORE_COMMON_SHAPE_VEC_H_\n\n#include \"oneflow/core/common/small_vector.h\"\n\nnamespace oneflow {\n\n#define SHAPE_MAX_AXIS_SIZE 20\n\ntypedef small_vector<int64_t, SHAPE_MAX_AXIS_SIZE> DimVector;\ntypedef small_vector<int64_t, SHAPE_MAX_AXIS_SIZE> AxisVector;\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_SHAPE_VEC_H_\n"
  },
  {
    "path": "oneflow/core/common/shape_view.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/shape.pb.h\"\n#include \"oneflow/core/common/shape_view.h\"\n\nnamespace oneflow {\n\nvoid ShapeView::ToDimVector(DimVector* dim_vec) const {\n  dim_vec->resize(this->size());\n  dim_vec->assign(this->data(), this->data() + this->size());\n}\n\nvoid ShapeView::ToShape(Shape* shape) const {\n  DimVector dim_vec;\n  this->ToDimVector(&dim_vec);\n  *shape = Shape(dim_vec);\n}\n\nstd::ostream& operator<<(std::ostream& out, ShapeView shape) {\n  out << shape.ToString();\n  return out;\n}\n\nvoid MutShapeView::set_shape(ShapeView shape) {\n  if (shape.ptr() == mut_ptr() && shape.NumAxes() == NumAxes()) { return; }\n  CHECK_EQ(NumAxes(), shape.NumAxes());\n  std::copy(shape.ptr(), shape.ptr() + shape.NumAxes(), mut_ptr());\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/shape_view.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_REGISTER_SHAPE_VIEW_H_\n#define ONEFLOW_CORE_REGISTER_SHAPE_VIEW_H_\n\n#include \"oneflow/core/common/array_ref.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/shape.h\"\n\nnamespace oneflow {\n\nclass ShapeProto;\nclass Shape;\n\nclass ShapeView : public ArrayRef<int64_t>, public ConstShapeMixIn<ShapeView> {\n public:\n  ShapeView() = default;\n  // NOLINTNEXTLINE\n  ShapeView(const ShapeProto& shape_proto)\n      : ArrayRef<int64_t>(shape_proto.dim().data(), shape_proto.dim_size()){};\n  // NOLINTNEXTLINE\n  ShapeView(const Shape& shape)\n      : ArrayRef<int64_t>(shape.dim_vec().data(), shape.dim_vec().size()){};\n\n  using ArrayRef<DimType>::ArrayRef;\n\n  const DimType* ptr() const { return this->data(); }\n\n  void ToDimVector(DimVector* dim_vec) const;\n  void ToShape(Shape* shape) const;\n};\n\nstd::ostream& operator<<(std::ostream& out, ShapeView shape);\n\nclass MutShapeView final : public MutableArrayRef<int64_t>, public MutShapeMixIn<MutShapeView> {\n public:\n  using MutableArrayRef<DimType>::MutableArrayRef;\n  // NOLINTNEXTLINE\n  MutShapeView(Shape& shape)\n      : MutableArrayRef<int64_t>(shape.dim_vec().data(), shape.dim_vec().size()){};\n\n  int64_t* mut_ptr() const { return this->data(); }\n\n  void set_shape(ShapeView shape);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_REGISTER_SHAPE_VIEW_H_\n"
  },
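  {
    "path": "oneflow/core/common/shape_view_example_test.cpp",
    "content": "// NOTE: illustrative sketch, not a file from the original OneFlow tree (the path is\n// hypothetical). It shows ShapeView as a non-owning, pass-by-value view over a Shape's\n// storage, and MutShapeView writing through to the viewed Shape.\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"gtest/gtest.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(ShapeView, read_and_write_through) {\n  Shape shape({2, 3, 4});\n  ShapeView view(shape);  // no copy of the dims: just a pointer and a size\n  ASSERT_EQ(view.NumAxes(), 3);\n  ASSERT_EQ(view.elem_cnt(), 24);\n\n  // MutShapeView mutates the Shape it was constructed from.\n  MutShapeView mut_view(shape);\n  mut_view.Set(0, 5);\n  ASSERT_EQ(shape.At(0), 5);\n\n  // set_shape copies the dims of another view with the same rank.\n  Shape src({7, 8, 9});\n  mut_view.set_shape(src);  // Shape converts implicitly to ShapeView\n  ASSERT_EQ(shape, src);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },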
  {
    "path": "oneflow/core/common/shared_or_scalar.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_SHARED_OR_SCALAR_H_\n#define ONEFLOW_CORE_COMMON_SHARED_OR_SCALAR_H_\n\n#include <memory>\n\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/common/type_traits.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n\nnamespace oneflow {\n\ntemplate<typename StructT, typename ScalarT>\nclass SharedOrScalar final {\n public:\n  static_assert(IsScalarType<ScalarT>::value, \"ScalarT should be scalar type.\");\n\n  using Shared = std::shared_ptr<StructT>;\n\n  SharedOrScalar(const ScalarT& scalar_value) : is_scalar_(true), scalar_value_(scalar_value) {}\n\n  SharedOrScalar(const std::shared_ptr<StructT>& shared_ptr) : is_scalar_(false) {\n    new (&shared_mem_) Shared(shared_ptr);\n  }\n\n  SharedOrScalar(std::shared_ptr<StructT>&& shared_ptr) : is_scalar_(false) {\n    new (&shared_mem_) Shared(std::move(shared_ptr));\n  }\n\n  SharedOrScalar(const SharedOrScalar& rhs) : is_scalar_(rhs.is_scalar_) {\n    if (rhs.is_scalar_) {\n      scalar_value_ = rhs.scalar_value_;\n    } else {\n      new (&shared_mem_) Shared(rhs.GetShared());\n    }\n  }\n\n  SharedOrScalar(SharedOrScalar&& rhs) : is_scalar_(rhs.is_scalar_) {\n    if (rhs.is_scalar_) {\n      scalar_value_ = rhs.scalar_value_;\n    } else {\n      new (&shared_mem_) Shared(std::move(*rhs.MutableShared()));\n    }\n  }\n\n  SharedOrScalar& operator=(const SharedOrScalar& rhs) {\n    if (rhs.is_scalar_) {\n      scalar_value_ = rhs.scalar_value_;\n    } else {\n      if (is_scalar_) {\n        scalar_value_.~ScalarT();\n        new (&shared_mem_) Shared(rhs.GetShared());\n      } else {\n        *MutableShared() = rhs.GetShared();\n      }\n    }\n    is_scalar_ = rhs.is_scalar_;\n    return *this;\n  }\n\n  SharedOrScalar& operator=(SharedOrScalar&& rhs) {\n    if (rhs.is_scalar_) {\n      scalar_value_ = rhs.scalar_value_;\n    } else {\n      if (is_scalar_) {\n        scalar_value_.~ScalarT();\n        new (&shared_mem_) Shared(std::move(*rhs.MutableShared()));\n      } else {\n        *MutableShared() = std::move(*rhs.MutableShared());\n      }\n    }\n    is_scalar_ = rhs.is_scalar_;\n    return *this;\n  }\n\n  ~SharedOrScalar() {\n    if (is_scalar_) {\n      scalar_value_.~ScalarT();\n    } else {\n      GetShared().~Shared();\n    }\n  }\n\n  bool IsScalar() const { return is_scalar_; }\n  const ScalarT& scalar_value() const {\n    CHECK(is_scalar_);\n    return scalar_value_;\n  }\n\n  const std::shared_ptr<StructT>& shared_ptr() const {\n    CHECK(!is_scalar_);\n    return GetShared();\n  }\n\n  const ScalarT& operator*() const { return scalar_value(); }\n\n private:\n  bool is_scalar_;\n  union {\n    ScalarT scalar_value_;\n\n    //  to avoid error(a non-POD class definition is not allowed inside of a statement expression)\n    //  in nvcc while using with JUST macro (this type is used in Maybe)\n    alignas(Shared) char shared_mem_[sizeof(Shared)];\n  };\n\n  const 
Shared& GetShared() const {\n    const auto* __attribute__((__may_alias__)) shared =\n        reinterpret_cast<const Shared*>(&shared_mem_);\n    return *shared;\n  }\n\n  Shared* MutableShared() {\n    auto* __attribute__((__may_alias__)) shared = reinterpret_cast<Shared*>(&shared_mem_);\n    return shared;\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_SHARED_OR_SCALAR_H_\n"
  },
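  {
    "path": "oneflow/core/common/shared_or_scalar_example_test.cpp",
    "content": "// NOTE: illustrative sketch, not a file from the original OneFlow tree (the path is\n// hypothetical). SharedOrScalar is a tagged union holding either a scalar value or a\n// std::shared_ptr<StructT>; this shows both states and an assignment that switches the\n// active member (releasing the previously held shared_ptr).\n#include <memory>\n#include <string>\n#include \"oneflow/core/common/shared_or_scalar.h\"\n#include \"gtest/gtest.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(SharedOrScalar, scalar_and_shared_states) {\n  SharedOrScalar<std::string, int64_t> scalar_case(42);\n  ASSERT_TRUE(scalar_case.IsScalar());\n  ASSERT_EQ(*scalar_case, 42);  // operator* returns the scalar value\n\n  auto str = std::make_shared<std::string>(\"hello\");\n  SharedOrScalar<std::string, int64_t> shared_case(str);\n  ASSERT_FALSE(shared_case.IsScalar());\n  ASSERT_EQ(*shared_case.shared_ptr(), \"hello\");\n\n  // Assigning a scalar over a shared state releases the old shared_ptr.\n  shared_case = scalar_case;\n  ASSERT_TRUE(shared_case.IsScalar());\n  ASSERT_EQ(str.use_count(), 1);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },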
  {
    "path": "oneflow/core/common/single_thread_obj_pool.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_SINGLE_THREAD_OBJ_POOL_H_\n#define ONEFLOW_CORE_COMMON_SINGLE_THREAD_OBJ_POOL_H_\n\n#include <vector>\n#include <mutex>\n#include <memory>\n#include <thread>\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/common/cpp_attribute.h\"\n\nnamespace oneflow {\nnamespace obj_pool {\n\nenum ReuseStrategy {\n  kEnableReconstruct,\n  kDisableReconstruct,\n};\n\n// object pool for single thread.\ntemplate<typename T, ReuseStrategy reuse_strategy = kEnableReconstruct>\nclass SingleThreadObjPool\n    : public std::enable_shared_from_this<SingleThreadObjPool<T, reuse_strategy>> {\n public:\n  SingleThreadObjPool() : pool_(), invalid_thread_id_(), owner_thread_id_(invalid_thread_id_) {\n    pool_.reserve(kInitPoolCap);\n  }\n  ~SingleThreadObjPool() {\n    if (reuse_strategy != kEnableReconstruct) {\n      for (T* ptr : pool_) { delete ptr; }\n    }\n  }\n\n  template<typename... Args>\n  std::shared_ptr<T> make_shared(Args&&... args) {\n    auto* ptr = New(std::forward<Args>(args)...);\n    std::weak_ptr<SingleThreadObjPool> pool(this->shared_from_this());\n    return std::shared_ptr<T>(ptr, [pool](T* ptr) { TryPut(pool.lock(), ptr); });\n  }\n\n private:\n  static constexpr int kInitPoolCap = 1024;\n\n  template<typename... Args>\n  T* New(Args&&... args) {\n    if (likely(pool_.size())) {\n      auto* ptr = Get();\n      if (reuse_strategy == kEnableReconstruct) { new (ptr) T(std::forward<Args>(args)...); }\n      return ptr;\n    }\n    return new T(std::forward<Args>(args)...);\n  }\n\n  static void TryPut(const std::shared_ptr<SingleThreadObjPool>& pool, T* object) {\n    if (likely(static_cast<bool>(pool))) {\n      pool->Put(object);\n    } else {\n      object->~T();\n    }\n  }\n\n  T* Get() {\n    CheckOrSetSingleThreadFlag();\n    auto* ptr = pool_[pool_.size() - 1];\n    pool_.pop_back();\n    return ptr;\n  }\n\n  void Put(T* obj) {\n    CheckOrSetSingleThreadFlag();\n    pool_.push_back(obj);\n    if (reuse_strategy == kEnableReconstruct) { obj->~T(); }\n  }\n\n  // Try to detect being wrongly used by multi threads, because SingleThreadObjPool does not\n  // guarantee thread safety. This function also is not thread safe, but it's not a big problem. In\n  // the most cases, bugs will be successfully detected even thread unsafe behaviors happen.\n  void CheckOrSetSingleThreadFlag() {\n    if (unlikely(owner_thread_id_ == invalid_thread_id_)) {\n      owner_thread_id_ = std::this_thread::get_id();\n    } else {\n      CHECK(likely(owner_thread_id_ == std::this_thread::get_id()));\n    }\n  }\n\n  std::vector<T*> pool_;\n  std::thread::id invalid_thread_id_;\n  std::thread::id owner_thread_id_;\n};\n\n}  // namespace obj_pool\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_SINGLE_THREAD_OBJ_POOL_H_\n"
  },
  {
    "path": "oneflow/core/common/single_thread_obj_pool_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/single_thread_obj_pool.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\nnamespace obj_pool {\nnamespace test {\n\nTEST(SingleThreadObjPool, naive) {\n  auto pool = std::make_shared<SingleThreadObjPool<int>>();\n  auto* ptr = pool->make_shared().get();\n  ASSERT_EQ(ptr, pool->make_shared().get());\n}\n\nstruct Int {  // NOLINT\n  Int() : x(0) {}\n  explicit Int(int val) : x(val) {}\n  ~Int() { x = 0; }\n  int x;\n};\n\nTEST(SingleThreadObjPool, enable_reconstruct) {\n  auto pool = std::make_shared<SingleThreadObjPool<Int, kEnableReconstruct>>();\n  (void)pool->make_shared(333);\n  ASSERT_EQ(0, pool->make_shared()->x);\n}\n\nTEST(SingleThreadObjPool, disable_reconstruct) {\n  auto pool = std::make_shared<SingleThreadObjPool<Int, kDisableReconstruct>>();\n  int value = pool->make_shared(333)->x;\n  ASSERT_EQ(value, pool->make_shared()->x);\n}\n\n}  // namespace test\n}  // namespace obj_pool\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/singleton.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_SINGLETON_H_\n#define ONEFLOW_CORE_COMMON_SINGLETON_H_\n\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/constant.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename Kind = void>\nclass Singleton final {\n public:\n  static T* Get() { return *GetPPtr(); }\n  static void SetAllocated(T* val) { *GetPPtr() = val; }\n  template<typename... Args>\n  static T* New(Args&&... args) {\n    CHECK(Get() == nullptr);\n    VLOG(3) << \"NewGlobal \" << typeid(T).name();\n    T* ptr = new T(std::forward<Args>(args)...);\n    *GetPPtr() = ptr;\n    return ptr;\n  }\n  static void Delete() {\n    if (Get() != nullptr) {\n      VLOG(3) << \"DeleteGlobal \" << typeid(T).name();\n      delete Get();\n      *GetPPtr() = nullptr;\n    }\n  }\n\n private:\n  static T** GetPPtr() {\n    CheckKind();\n    static T* ptr = nullptr;\n    return &ptr;\n  }\n  static void CheckKind() {\n    if (!std::is_same<Kind, void>::value) {\n      CHECK(Singleton<T>::Get() == nullptr)\n          << typeid(Singleton<T>).name() << \" are disable for avoiding misuse\";\n    }\n  }\n};\n\ntemplate<typename T, typename... Kind>\nMaybe<T*> SingletonMaybe() {\n  CHECK_NOTNULL_OR_RETURN((Singleton<T, Kind...>::Get())) << \" typeid: \" << typeid(T).name();\n  return Singleton<T, Kind...>::Get();\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_SINGLETON_H_\n"
  },
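  {
    "path": "oneflow/core/common/singleton_example_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// NOTE: illustrative usage sketch, not part of the upstream OneFlow sources.\n// It walks the Singleton<T> lifecycle (New/Get/Delete) declared in singleton.h;\n// the `Counter` payload type is a hypothetical example type.\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/singleton.h\"\n\nnamespace oneflow {\nnamespace test {\n\nstruct Counter {\n  explicit Counter(int init) : value(init) {}\n  int value;\n};\n\nTEST(Singleton, new_get_delete) {\n  ASSERT_EQ(Singleton<Counter>::Get(), nullptr);  // nothing registered yet\n  Counter* created = Singleton<Counter>::New(7);  // constructs and registers the instance\n  ASSERT_EQ(Singleton<Counter>::Get(), created);\n  ASSERT_EQ(Singleton<Counter>::Get()->value, 7);\n  Singleton<Counter>::Delete();  // destroys and unregisters it\n  ASSERT_EQ(Singleton<Counter>::Get(), nullptr);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },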
  {
    "path": "oneflow/core/common/sized_buffer_view.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_COMMON_SIZED_BUFFER_VIEW_H_\n#define ONEFLOW_COMMON_SIZED_BUFFER_VIEW_H_\n\nnamespace oneflow {\n\nstruct SizedBufferView {\n  size_t capacity;  // allocated memory size for `data' field\n  size_t size;      // valid data size\n  char data[0];\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_COMMON_SIZED_BUFFER_VIEW_H_\n"
  },
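  {
    "path": "oneflow/core/common/sized_buffer_view_example_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// NOTE: illustrative usage sketch, not part of the upstream OneFlow sources.\n// SizedBufferView ends with a zero-length `data` array, so a view and its\n// payload live in one allocation: sizeof(SizedBufferView) header bytes followed\n// by `capacity` payload bytes.\n#include <cstdlib>\n#include <cstring>\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/sized_buffer_view.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(SizedBufferView, header_plus_payload) {\n  const size_t capacity = 64;\n  char* raw = static_cast<char*>(std::malloc(sizeof(SizedBufferView) + capacity));\n  auto* view = reinterpret_cast<SizedBufferView*>(raw);\n  view->capacity = capacity;\n  const char msg[] = \"hello\";\n  view->size = sizeof(msg);  // valid bytes, including the terminating NUL\n  std::memcpy(view->data, msg, view->size);\n  ASSERT_LE(view->size, view->capacity);\n  ASSERT_STREQ(view->data, \"hello\");\n  std::free(raw);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },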
  {
    "path": "oneflow/core/common/small_vector.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_SMALL_VECTOR_H_\n#define ONEFLOW_CORE_COMMON_SMALL_VECTOR_H_\n\n#include \"llvm/ADT/SmallVector.h\"\n#include \"oneflow/core/common/op_args_reserved_size.h\"\n#include \"oneflow/core/common/check.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, size_t N = kOpArgsReservedSize>\nclass small_vector : public llvm::SmallVector<T, N> {\n  using Base = llvm::SmallVector<T, N>;\n\n public:\n  constexpr static size_t kInitialSize = N;\n  // https://stackoverflow.com/questions/27954940/a-using-statement-compiles-with-g-fails-compilation-with-clang\n  using Base::Base;\n\n  typename Base::reference at(typename Base::size_type idx) {\n    GLOGCHECK(idx < Base::size());\n    return (*this)[idx];\n  }\n  typename Base::const_reference at(typename Base::size_type idx) const {\n    GLOGCHECK(idx < Base::size());\n    return (*this)[idx];\n  }\n  typename Base::reference operator[](typename Base::size_type idx) { return this->data()[idx]; }\n  typename Base::const_reference operator[](typename Base::size_type idx) const {\n    return this->data()[idx];\n  }\n  typename Base::const_iterator cbegin() const {\n    return (typename Base::const_iterator)this->BeginX;\n  }\n  typename Base::const_iterator cend() const {\n    return (typename Base::const_iterator)(this->BeginX) + Base::size();\n  }\n  typename Base::const_iterator cbegin() { return (typename Base::const_iterator)this->BeginX; }\n  typename Base::const_iterator cend() {\n    return (typename Base::const_iterator)(this->BeginX) + Base::size();\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_SMALL_VECTOR_H_\n"
  },
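  {
    "path": "oneflow/core/common/small_vector_example_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// NOTE: illustrative usage sketch, not part of the upstream OneFlow sources.\n// small_vector keeps up to N elements in inline storage, like llvm::SmallVector,\n// and spills to the heap beyond that; at() adds a bounds check on top of\n// operator[].\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/small_vector.h\"\n\nnamespace oneflow {\nnamespace test {\n\nstatic_assert(small_vector<int, 4>::kInitialSize == 4, \"inline capacity is the template argument\");\n\nTEST(small_vector, inline_then_heap) {\n  small_vector<int, 4> vec;\n  for (int i = 0; i < 8; ++i) { vec.push_back(i); }  // grows past the inline capacity of 4\n  ASSERT_EQ(vec.size(), 8);\n  ASSERT_EQ(vec.at(3), 3);\n  ASSERT_EQ(vec[7], 7);\n  ASSERT_EQ(*vec.cbegin(), 0);\n  ASSERT_EQ(*(vec.cend() - 1), 7);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },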
  {
    "path": "oneflow/core/common/spin_counter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <chrono>\n#include \"oneflow/core/common/spin_counter.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/common/foreign_lock_helper.h\"\n\nnamespace oneflow {\n\nMaybe<void> SpinCounter::WaitUntilCntEqualZero() const {\n  return Singleton<ForeignLockHelper>::Get()->WithScopedRelease([&]() -> Maybe<void> {\n    while (cnt_val_ > 0) {}\n    return Maybe<void>::Ok();\n  });\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/spin_counter.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_SPIN_COUNTER_H_\n#define ONEFLOW_CORE_COMMON_SPIN_COUNTER_H_\n\n#include <atomic>\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nclass SpinCounter final {\n public:\n  SpinCounter() = delete;\n  SpinCounter(const SpinCounter&) = delete;\n  SpinCounter(SpinCounter&&) = delete;\n  ~SpinCounter() = default;\n\n  explicit SpinCounter(int64_t cnt_val) : cnt_val_(cnt_val) {}\n\n  int64_t Decrease() { return --cnt_val_; }\n  void Reset(int64_t cnt_val) { cnt_val_ = cnt_val; }\n  Maybe<void> WaitUntilCntEqualZero() const;\n\n private:\n  std::atomic<int64_t> cnt_val_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_SPIN_COUNTER_H_\n"
  },
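  {
    "path": "oneflow/core/common/spin_counter_example_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// NOTE: illustrative usage sketch, not part of the upstream OneFlow sources.\n// A SpinCounter starts at the number of pending events and each event calls\n// Decrease() once. WaitUntilCntEqualZero() spins through the ForeignLockHelper\n// singleton (see spin_counter.cpp), which this sketch does not set up, so only\n// Decrease()/Reset() are exercised here.\n#include <thread>\n#include <vector>\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/spin_counter.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(SpinCounter, decrease_and_reset) {\n  SpinCounter counter(4);\n  std::vector<std::thread> workers;\n  for (int i = 0; i < 4; ++i) {\n    workers.emplace_back([&counter]() { counter.Decrease(); });\n  }\n  for (auto& worker : workers) { worker.join(); }\n  counter.Reset(1);                  // a counter can be reused for another round\n  ASSERT_EQ(counter.Decrease(), 0);  // Decrease() returns the decremented value\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },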
  {
    "path": "oneflow/core/common/static_check.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_STATIC_CHECK_H_\n#define ONEFLOW_CORE_COMMON_STATIC_CHECK_H_\n\n#include \"type_traits.h\"\n\nnamespace oneflow {\n\nnamespace private_details {\n\ntemplate<template<typename> class Predicator>\nstruct StaticReduce {\n  template<typename... Args>\n  struct All;\n  template<typename Void>\n  struct All<Void> {\n    static_assert(std::is_same<Void, void>::value, \"\");\n    static constexpr bool value = true;\n  };\n  template<typename Void, typename T, typename... Args>\n  struct All<Void, T, Args...> {\n    static constexpr bool value = Predicator<T>::value && All<Void, Args...>::value;\n  };\n\n  template<typename... Args>\n  struct Any;\n  template<typename Void>\n  struct Any<Void> {\n    static_assert(std::is_same<Void, void>::value, \"\");\n    static constexpr bool value = false;\n  };\n  template<typename Void, typename T, typename... Args>\n  struct Any<Void, T, Args...> {\n    static constexpr bool value = Predicator<T>::value || Any<Void, Args...>::value;\n  };\n};\n\n}  // namespace private_details\n\ntemplate<template<typename> class Predicator, typename... Args>\nstruct StaticAll {\n  static constexpr bool value =\n      private_details::StaticReduce<Predicator>::template All<void, Args...>::value;\n};\n\ntemplate<template<typename> class Predicator, typename... Args>\nstruct StaticAny {\n  static constexpr bool value =\n      private_details::StaticReduce<Predicator>::template Any<void, Args...>::value;\n};\n\ntemplate<typename T>\nstruct IsOutArg {\n  static constexpr bool value =\n      (std::is_reference<T>::value\n       && !std::is_const<typename std::remove_reference<T>::type>::value)\n      || (std::is_pointer<T>::value\n          && !std::is_const<typename std::remove_pointer<T>::type>::value);\n};\n\ntemplate<typename T>\nstruct IsDecayedScalarType {\n  static constexpr bool value = IsScalarType<typename std::decay<T>::type>::value;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_STATIC_CHECK_H_\n"
  },
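  {
    "path": "oneflow/core/common/static_check_example_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// NOTE: illustrative compile-time sketch, not part of the upstream OneFlow\n// sources. StaticAll/StaticAny fold a predicate over a parameter pack and\n// IsOutArg flags mutable reference/pointer parameters; everything is verified\n// with static_assert, so the test body itself is empty.\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/static_check.h\"\n\nnamespace oneflow {\nnamespace test {\n\nstatic_assert(IsOutArg<int*>::value, \"a mutable pointer is an out-arg\");\nstatic_assert(IsOutArg<int&>::value, \"a mutable reference is an out-arg\");\nstatic_assert(!IsOutArg<const int&>::value, \"a const reference is read-only\");\nstatic_assert(StaticAny<IsOutArg, const int&, int*>::value, \"at least one out-arg\");\nstatic_assert(!StaticAll<IsOutArg, int&, const int&>::value, \"not every arg is an out-arg\");\n\nTEST(StaticCheck, compile_time_only) { SUCCEED(); }\n\n}  // namespace test\n}  // namespace oneflow\n"
  },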
  {
    "path": "oneflow/core/common/static_global.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_STATIC_GLOBAL_H_\n#define ONEFLOW_CORE_COMMON_STATIC_GLOBAL_H_\n\n#include <mutex>\n#include \"oneflow/core/common/decorator.h\"\n\nnamespace oneflow {\n\ntemplate<typename... Args>\nstruct StaticGlobalCopiable;\n\ntemplate<typename RetT>\nstruct StaticGlobalCopiable<RetT> {\n  template<RetT (*func)()>\n  static RetT Call() {\n    static RetT value = func();\n    return value;\n  }\n};\n\ntemplate<typename RetT, typename Arg0>\nstruct StaticGlobalCopiable<RetT, Arg0> {\n  template<RetT (*func)(Arg0)>\n  static RetT Call(Arg0 arg0) {\n    using KeyT = typename std::decay<Arg0>::type;\n    using MappedT = typename std::decay<RetT>::type;\n    static std::mutex mutex;\n    static std::unordered_map<KeyT, MappedT> map;\n    {\n      std::unique_lock<std::mutex> lock(mutex);\n      auto iter = map.find(arg0);\n      if (iter != map.end()) { return iter->second; }\n    }\n    auto obj = func(arg0);\n    {\n      std::unique_lock<std::mutex> lock(mutex);\n      return map.emplace(arg0, std::move(obj)).first->second;\n    }\n  }\n\n private:\n  static_assert(!IsOutArg<Arg0>::value, \"\");\n  static_assert(!StaticAny<IsOutArg, Arg0>::value, \"\");\n};\n\ntemplate<typename RetT, typename Arg0, typename Arg1, typename... Args>\nstruct StaticGlobalCopiable<RetT, Arg0, Arg1, Args...> {\n  template<RetT (*func)(Arg0, Arg1, Args...)>\n  static RetT Call(Arg0 arg0, Arg1 arg1, Args... args) {\n    using KeyT0 = typename std::decay<Arg0>::type;\n    using KeyT1 = typename std::decay<Arg1>::type;\n    using KeyT = std::tuple<KeyT0, KeyT1, typename std::decay<Args>::type...>;\n    using MappedT = typename std::decay<RetT>::type;\n    static std::mutex mutex;\n    static std::unordered_map<KeyT, MappedT> map;\n    const auto& key = KeyT(arg0, arg1, args...);\n    {\n      std::unique_lock<std::mutex> lock(mutex);\n      auto iter = map.find(key);\n      if (iter != map.end()) { return iter->second; }\n    }\n    auto obj = func(arg0, arg1, args...);\n    {\n      std::unique_lock<std::mutex> lock(mutex);\n      return map.emplace(key, std::move(obj)).first->second;\n    }\n  }\n\n private:\n  static_assert(!StaticAny<IsOutArg, Arg0, Arg1, Args...>::value, \"\");\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_STATIC_GLOBAL_H_\n"
  },
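  {
    "path": "oneflow/core/common/static_global_example_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// NOTE: illustrative usage sketch, not part of the upstream OneFlow sources.\n// StaticGlobalCopiable<RetT, Args...>::Call<&func> memoizes func in a\n// function-local static map keyed by the (decayed) arguments, so func runs at\n// most once per distinct key; MakeGreeting and its call counter are inventions\n// of this example.\n#include <string>\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/static_global.h\"\n\nnamespace oneflow {\nnamespace test {\nnamespace {\n\nint call_count = 0;\n\nstd::string MakeGreeting(int id) {\n  ++call_count;\n  return \"greeting-\" + std::to_string(id);\n}\n\n}  // namespace\n\nTEST(StaticGlobalCopiable, memoizes_per_argument) {\n  auto* cached = &StaticGlobalCopiable<std::string, int>::Call<&MakeGreeting>;\n  ASSERT_EQ(cached(1), \"greeting-1\");\n  ASSERT_EQ(cached(1), \"greeting-1\");  // served from the map, MakeGreeting is not re-run\n  ASSERT_EQ(call_count, 1);\n  ASSERT_EQ(cached(2), \"greeting-2\");  // a new key triggers exactly one more real call\n  ASSERT_EQ(call_count, 2);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },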
  {
    "path": "oneflow/core/common/steady_vector.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_STEADY_VECTOR_H_\n#define ONEFLOW_CORE_COMMON_STEADY_VECTOR_H_\n\n#include <memory>\n#include <array>\n#include <mutex>\n#include <cmath>\n#include \"oneflow/core/common/throw.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, int N = 20>\nclass SteadyVector {\n public:\n  SteadyVector() : size_(0) {}\n  ~SteadyVector() = default;\n\n  using value_type = const T;\n  using size_type = size_t;\n\n  // thread safe.\n  size_t size() const { return size_.load(std::memory_order_acquire); }\n\n  // thread safe.\n  const T& at(size_t index) const {\n    CHECK_GE(index, 0);\n    CHECK_LT(index, size_);\n    return (*this)[index];\n  }\n\n  // thread safe.\n  const T& operator[](size_t index) const {\n    int gran = 0;\n    size_t start = 0;\n    GetGranularityAndStart(index, &gran, &start);\n    return granularity2data_[gran].get()[index - start];\n  }\n\n  // `index` should be <= size()\n  void SetOrAdd(size_t index, T value) {\n    std::unique_lock<std::mutex> lock(mutex_);\n    size_t size = size_.load(std::memory_order_relaxed);\n    CHECK_LE(index, size) << \"index out of range\";\n    if (index == size) {\n      int granularity = GetGranularity(size);\n      if (size + 1 == (1 << granularity)) {\n        CHECK_LT(granularity, N);\n        granularity2data_[granularity].reset(new T[1 << granularity]);\n      }\n      *Mutable(index) = std::move(value);\n      size_.fetch_add(1, std::memory_order_release);\n    } else {\n      *Mutable(index) = std::move(value);\n    }\n  }\n\n  void push_back(const T& elem) { SetOrAdd(size_, elem); }\n\n private:\n  T* Mutable(size_t index) {\n    int gran = 0;\n    size_t start = 0;\n    GetGranularityAndStart(index, &gran, &start);\n    return &granularity2data_[gran].get()[index - start];\n  }\n\n  static void GetGranularityAndStart(size_t index, int* gran, size_t* start) {\n    *gran = GetGranularity(index);\n    *start = (1 << *gran) - 1;\n  }\n\n#ifdef __GNUC__\n#define LOG2(x) ((unsigned)(8 * sizeof(unsigned long long) - __builtin_clzll((x)) - 1))\n#else\n#define LOG2(x) std::log2(x)\n#endif\n\n  static int GetGranularity(size_t index) { return LOG2(index + 1); }\n\n#undef LOG2\n\n  std::atomic<size_t> size_;\n  std::mutex mutex_;\n  std::array<std::unique_ptr<T[]>, N> granularity2data_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_STEADY_VECTOR_H_\n"
  },
  {
    "path": "oneflow/core/common/steady_vector_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/steady_vector.h\"\n\nnamespace oneflow {\nnamespace test {\n\nvoid TestSteadyVector(int granularity) {\n  CHECK_GT(granularity, 0);\n  SteadyVector<int> vec;\n  ASSERT_EQ(vec.size(), 0);\n  for (int i = 0; i < (1 << granularity); ++i) {\n    vec.push_back(i);\n    ASSERT_EQ(vec.at(i), i);\n    ASSERT_EQ(vec.size(), i + 1);\n  }\n}\n\nTEST(SteadyVector, simple) { TestSteadyVector(6); }\n\n}  // namespace test\n}  // namespace oneflow\n"
  },
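  {
    "path": "oneflow/core/common/steady_vector_example_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// NOTE: illustrative concurrency sketch, not part of the upstream OneFlow\n// sources. SteadyVector documents size() and element access as thread safe\n// while writers are serialized by a mutex, so a single writer may append while\n// a reader consumes every slot below the published size().\n#include <thread>\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/steady_vector.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(SteadyVector, single_writer_one_reader) {\n  constexpr size_t total = 1024;\n  SteadyVector<int> vec;\n  std::thread writer([&vec] {\n    for (size_t i = 0; i < total; ++i) { vec.push_back(static_cast<int>(i)); }\n  });\n  std::thread reader([&vec] {\n    size_t seen = 0;\n    while (seen < total) {\n      const size_t size = vec.size();  // acquire-load pairs with the writer's release store\n      for (size_t i = seen; i < size; ++i) { EXPECT_EQ(vec.at(i), static_cast<int>(i)); }\n      seen = size;\n    }\n  });\n  writer.join();\n  reader.join();\n  ASSERT_EQ(vec.size(), total);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },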
  {
    "path": "oneflow/core/common/str_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <glog/logging.h>\n#include <random>\n#include \"oneflow/core/common/str_util.h\"\n\nnamespace oneflow {\n\nnamespace internal {\n\nstd::string JoinPathImpl(std::initializer_list<std::string> paths) {\n  std::string result;\n  for (const std::string& path : paths) {\n    if (path.empty()) continue;\n    if (result.empty()) {\n      result = path;\n      continue;\n    }\n    if (result[result.size() - 1] == '/') {\n      if (IsAbsolutePath(path)) {\n        result.append(path.substr(1));\n      } else {\n        result.append(path);\n      }\n    } else {\n      if (IsAbsolutePath(path)) {\n        result.append(path);\n      } else {\n        result += (\"/\" + path);\n      }\n    }\n  }\n  return result;\n}\n\nstd::string GetHashKeyImpl(std::initializer_list<int> integers) {\n  std::string result = \"\";\n  for (int integer : integers) { result += std::to_string(integer) + \",\"; }\n  return result;\n}\n\n}  // namespace internal\n\nconst char* StrToToken(const char* text, const std::string& delims, std::string* token) {\n  token->clear();\n  while (*text != '\\0' && delims.find(*text) != std::string::npos) { text++; }\n  while (*text != '\\0' && delims.find(*text) == std::string::npos) { token->push_back(*text++); }\n  return text;\n}\n\nvoid Split(const std::string& text, const std::string& delims,\n           std::function<void(std::string&&)> Func) {\n  size_t token_start = 0;\n  if (text.empty()) { return; }\n  for (size_t i = 0; i < text.size() + 1; ++i) {\n    if ((i == text.size()) || (delims.find(text[i]) != std::string::npos)) {\n      Func(text.substr(token_start, i - token_start));\n      token_start = i + 1;\n    }\n  }\n}\n\nstd::string Dirname(const std::string& path) {\n  size_t found = path.rfind('/');\n  if (found == std::string::npos) { return \"\"; }\n  if (found == 0) { return \"/\"; }\n  return path.substr(0, found);\n}\n\nstd::string Basename(const std::string& path) {\n  size_t found = path.rfind('/');\n  if (found == std::string::npos) { return path; }\n  return path.substr(found + 1);\n}\n\nstd::string CleanPath(const std::string& unclean_path) {\n  std::string path = unclean_path;\n  const char* src = path.c_str();\n  std::string::iterator dst = path.begin();\n\n  // Check for absolute path and determine initial backtrack limit.\n  const bool is_absolute_path = *src == '/';\n  if (is_absolute_path) {\n    *dst++ = *src++;\n    while (*src == '/') ++src;\n  }\n  std::string::const_iterator backtrack_limit = dst;\n\n  // Process all parts\n  while (*src) {\n    bool parsed = false;\n\n    if (src[0] == '.') {\n      //  1dot \".<whateverisnext>\", check for END or SEP.\n      if (src[1] == '/' || !src[1]) {\n        if (*++src) { ++src; }\n        parsed = true;\n      } else if (src[1] == '.' 
&& (src[2] == '/' || !src[2])) {\n        // 2dot END or SEP (\"..\" | \"../<whateverisnext>\").\n        src += 2;\n        if (dst != backtrack_limit) {\n          // We can backtrack the previous part\n          for (--dst; dst != backtrack_limit && dst[-1] != '/'; --dst) {\n            // Empty.\n          }\n        } else if (!is_absolute_path) {\n          // Failed to backtrack and we can't skip it either. Rewind and copy.\n          src -= 2;\n          *dst++ = *src++;\n          *dst++ = *src++;\n          if (*src) { *dst++ = *src; }\n          // We can never backtrack over a copied \"../\" part so set new limit.\n          backtrack_limit = dst;\n        }\n        if (*src) { ++src; }\n        parsed = true;\n      }\n    }\n\n    // If not parsed, copy entire part until the next SEP or EOS.\n    if (!parsed) {\n      while (*src && *src != '/') { *dst++ = *src++; }\n      if (*src) { *dst++ = *src++; }\n    }\n\n    // Skip consecutive SEP occurrences\n    while (*src == '/') { ++src; }\n  }\n\n  // Calculate and check the length of the cleaned path.\n  std::string::difference_type path_length = dst - path.begin();\n  if (path_length != 0) {\n    // Remove trailing '/' except if it is root path (\"/\" ==> path_length := 1)\n    if (path_length > 1 && path[path_length - 1] == '/') { --path_length; }\n    path.resize(path_length);\n  } else {\n    // The cleaned path is empty; assign \".\" as per the spec.\n    path.assign(1, '.');\n  }\n  return path;\n}\n\nvoid GetPrefixAndIndex(const std::string& prefix_and_idx, std::string* prefix, int32_t* index) {\n  const size_t underline_pos = prefix_and_idx.rfind('_');\n  CHECK_NE(underline_pos, std::string::npos);\n  CHECK_GT(underline_pos, 0);\n  CHECK_LT(underline_pos, prefix_and_idx.size() - 1);\n  *prefix = prefix_and_idx.substr(0, underline_pos);\n  *index = oneflow_cast<int32_t>(prefix_and_idx.substr(underline_pos + 1));\n  CHECK_GE(*index, 0);\n}\n\nbool TryGetPrefixAndIndex(const std::string& prefix_and_idx, std::string* prefix, int32_t* index) {\n  const size_t underline_pos = prefix_and_idx.rfind('_');\n  if (underline_pos == std::string::npos) { return false; }\n  if (underline_pos == 0) { return false; }\n  if (underline_pos == prefix_and_idx.size() - 1) { return false; }\n  *prefix = prefix_and_idx.substr(0, underline_pos);\n  std::string index_str = prefix_and_idx.substr(underline_pos + 1);\n  if (!IsStrInt(index_str)) { return false; }\n  *index = oneflow_cast<int32_t>(index_str);\n  return *index >= 0;\n}\n\nstd::string ToLower(const std::string& cap) {\n  std::string small;\n  small.resize(cap.size());\n  std::transform(cap.begin(), cap.end(), small.begin(),\n                 [](unsigned char c) { return std::tolower(c); });\n  return small;\n}\n\n// https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c\nstd::string GenAlphaNumericString(size_t len) {\n  static thread_local const std::string alphanum(\"0123456789\"\n                                                 \"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"\n                                                 \"abcdefghijklmnopqrstuvwxyz\");\n  std::string tmp_s;\n  tmp_s.reserve(len);\n\n  std::random_device rd{};\n  std::mt19937 mt(rd());\n  std::uniform_int_distribution<size_t> dist(0, alphanum.size() - 1);\n  for (size_t i = 0; i < len; ++i) { tmp_s += alphanum.at(dist(mt)); }\n  return tmp_s;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/str_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_STR_UTIL_H_\n#define ONEFLOW_CORE_COMMON_STR_UTIL_H_\n\n#include <functional>\n#include <string>\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\ninline bool IsStrInt(const std::string& s) {\n  if (s.empty() || (!isdigit(s[0]) && (s[0] != '-'))) { return false; }\n  char* end_ptr = nullptr;\n  strtoll(s.c_str(), &end_ptr, 0);\n  return (*end_ptr == 0);\n}\n\ninline std::string StrCat(const std::string& prefix, int64_t id) {\n  return prefix + std::to_string(id);\n}\n\ninline void StringReplace(std::string* str, char old_ch, char new_ch) {\n  for (size_t i = 0; i < str->size(); ++i) {\n    if (str->at(i) == old_ch) { str->at(i) = new_ch; }\n  }\n}\n\nconst char* StrToToken(const char* text, const std::string& delims, std::string* token);\n\nvoid Split(const std::string& text, const std::string& delims,\n           std::function<void(std::string&&)> Func);\n\ntemplate<typename T>\nvoid SplitAndParseAs(const std::string& text, const std::string& delims,\n                     std::function<void(T&&)> Func) {\n  Split(text, delims, [&Func](std::string&& s) { Func(oneflow_cast<T>(s)); });\n}\n\n// Return true if path is absolute.\ninline bool IsAbsolutePath(const std::string& path) { return !path.empty() && path[0] == '/'; }\n\nvoid GetPrefixAndIndex(const std::string& prefix_and_idx, std::string* prefix, int32_t* index);\n\nbool TryGetPrefixAndIndex(const std::string& prefix_and_idx, std::string* prefix, int32_t* index);\n\nnamespace internal {\n\nstd::string JoinPathImpl(std::initializer_list<std::string> paths);\n\nstd::string GetHashKeyImpl(std::initializer_list<int> integers);\n\n}  // namespace internal\n\n// Join multiple paths together, without introducing unnecessary path\n// separators.\n// For example:\n//\n//  Arguments                  | JoinPath\n//  ---------------------------+----------\n//  '/foo', 'bar'              | /foo/bar\n//  '/foo/', 'bar'             | /foo/bar\n//  '/foo', '/bar'             | /foo/bar\n//\n// Usage:\n// string path = JoinPath(\"/mydir\", filename);\n// string path = JoinPath(FLAGS_test_srcdir, filename);\n// string path = JoinPath(\"/full\", \"path\", \"to\", \"filename);\ntemplate<typename... T>\nstd::string JoinPath(const T&... args) {\n  return internal::JoinPathImpl({args...});\n}\n\n// Returns the part of the path before the final \"/\".  If there is a single\n// leading \"/\" in the path, the result will be the leading \"/\".  If there is\n// no \"/\" in the path, the result is the empty prefix of the input.\nstd::string Dirname(const std::string& path);\n\n// Returns the part of the path after the final \"/\".  If there is no\n// \"/\" in the path, the result is the same as the input.\nstd::string Basename(const std::string& path);\n\n// Collapse duplicate \"/\"s, resolve \"..\" and \".\" path elements, remove\n// trailing \"/\".\n//\n// NOTE: This respects relative vs. 
absolute paths, but does not\n// invoke any system calls (getcwd(2)) in order to resolve relative\n// paths with respect to the actual working directory.  That is, this is purely\n// string manipulation, completely independent of process state.\nstd::string CleanPath(const std::string& path);\n\ntemplate<typename... T>\nstd::string GetHashKey(const T&... args) {\n  return internal::GetHashKeyImpl({args...});\n}\n\nstd::string ToLower(const std::string& cap);\n\nstd::string GenAlphaNumericString(size_t len);\n\ntemplate<typename CallbackT>\nconst std::string& ReturnEmptyStr(const CallbackT& Callback) {\n  Callback();\n  static std::string empty{};\n  return empty;\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_STR_UTIL_H_\n"
  },
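  {
    "path": "oneflow/core/common/str_util_example_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// NOTE: illustrative usage sketch, not part of the upstream OneFlow sources;\n// the expected values follow the documented behavior of the helpers in\n// str_util.h.\n#include <string>\n#include <vector>\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/str_util.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(StrUtil, path_helpers) {\n  ASSERT_EQ(JoinPath(\"/foo\", \"bar\"), \"/foo/bar\");  // no duplicated separator\n  ASSERT_EQ(JoinPath(\"/foo/\", \"/bar\"), \"/foo/bar\");\n  ASSERT_EQ(CleanPath(\"/a/./b/../c/\"), \"/a/c\");  // resolves \".\" and \"..\", trims trailing \"/\"\n  ASSERT_EQ(Dirname(\"/a/b\"), \"/a\");\n  ASSERT_EQ(Basename(\"/a/b\"), \"b\");\n  ASSERT_EQ(Dirname(\"file\"), \"\");  // no \"/\" means an empty prefix\n}\n\nTEST(StrUtil, split) {\n  std::vector<std::string> tokens;\n  Split(\"a,b,,c\", \",\", [&tokens](std::string&& s) { tokens.emplace_back(std::move(s)); });\n  const std::vector<std::string> expected{\"a\", \"b\", \"\", \"c\"};\n  ASSERT_EQ(tokens, expected);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },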
  {
    "path": "oneflow/core/common/stream_type.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_STREAM_TYPE_H_\n#define ONEFLOW_CORE_COMMON_STREAM_TYPE_H_\n\n#include <functional>\n#include <array>\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/common/throw.h\"\n\nnamespace oneflow {\n\nenum class StreamType {\n  kInvalid = 0,\n  kCompute,\n  kHost2Device,\n  kDevice2Host,\n  kCcl,\n  kBarrier,\n  kCriticalSection,\n  kLazyJobLauncher,\n  kPinnedCompute\n};\n\ntemplate<typename DerivedT>\nstruct StreamTypeVisitor {\n  template<typename... Args>\n  static auto Visit(StreamType stream_type, Args&&... args) {\n    switch (stream_type) {\n      case StreamType::kInvalid: LOG(FATAL) << \"invalid stream type\";\n      case StreamType::kCompute: return DerivedT::VisitCompute(std::forward<Args>(args)...);\n      case StreamType::kHost2Device: return DerivedT::VisitHost2Device(std::forward<Args>(args)...);\n      case StreamType::kDevice2Host: return DerivedT::VisitDevice2Host(std::forward<Args>(args)...);\n      case StreamType::kCcl: return DerivedT::VisitCcl(std::forward<Args>(args)...);\n      case StreamType::kBarrier: return DerivedT::VisitBarrier(std::forward<Args>(args)...);\n      case StreamType::kCriticalSection:\n        return DerivedT::VisitCriticalSection(std::forward<Args>(args)...);\n      case StreamType::kLazyJobLauncher:\n        return DerivedT::VisitLazyJobLauncher(std::forward<Args>(args)...);\n      case StreamType::kPinnedCompute:\n        return DerivedT::VisitPinnedCompute(std::forward<Args>(args)...);\n    }\n    LOG(FATAL) << \"invalid stream type\";\n  }\n};\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::StreamType> final {\n  size_t operator()(const oneflow::StreamType& stream_type) const {\n    return static_cast<int>(stream_type);\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_COMMON_STREAM_TYPE_H_\n"
  },
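  {
    "path": "oneflow/core/common/stream_type_example_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// NOTE: illustrative usage sketch, not part of the upstream OneFlow sources.\n// StreamTypeVisitor is a CRTP dispatcher: a derived type supplies one static\n// Visit* function per StreamType case and Visit() routes to it; the name\n// visitor below is a hypothetical example.\n#include <string>\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/stream_type.h\"\n\nnamespace oneflow {\nnamespace test {\n\nstruct StreamTypeName final : public StreamTypeVisitor<StreamTypeName> {\n  static std::string VisitCompute() { return \"compute\"; }\n  static std::string VisitHost2Device() { return \"h2d\"; }\n  static std::string VisitDevice2Host() { return \"d2h\"; }\n  static std::string VisitCcl() { return \"ccl\"; }\n  static std::string VisitBarrier() { return \"barrier\"; }\n  static std::string VisitCriticalSection() { return \"critical_section\"; }\n  static std::string VisitLazyJobLauncher() { return \"lazy_job_launcher\"; }\n  static std::string VisitPinnedCompute() { return \"pinned_compute\"; }\n};\n\nTEST(StreamTypeVisitor, dispatch) {\n  ASSERT_EQ(StreamTypeName::Visit(StreamType::kCompute), \"compute\");\n  ASSERT_EQ(StreamTypeName::Visit(StreamType::kCcl), \"ccl\");\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },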
  {
    "path": "oneflow/core/common/stride.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <numeric>\n\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/common/constant.h\"\n#include \"oneflow/core/common/protobuf.h\"\n\nnamespace oneflow {\n\nStride::Stride(const ShapeView& shape) {\n  const int64_t ndim = shape.NumAxes();\n  resize(ndim);\n  if (ndim > 0 && shape.elem_cnt() > 0) {\n    std::exclusive_scan(shape.rbegin(), shape.rend(), rbegin(), (int64_t)1, std::multiplies<>{});\n  } else if (ndim > 0 && shape.elem_cnt() == 0) {\n    // 0-size shape\n    small_vector<int64_t, kMaxNumDims> tmp_shape(ndim);\n    for (int64_t i = 0; i < ndim; ++i) { tmp_shape[i] = shape.At(i) > 0 ? shape.At(i) : 1; }\n    std::exclusive_scan(tmp_shape.rbegin(), tmp_shape.rend(), rbegin(), (int64_t)1,\n                        std::multiplies<>{});\n  }\n}\n\nStride::Stride(const Shape& shape) {\n  if (shape.is_initialized()) {\n    ShapeView shape_view(shape);\n    new (this) Stride(shape_view);\n  }\n}\n\nStride::Stride(const std::shared_ptr<Shape>& shape) : Stride(*shape) {}\n\nStride::Stride(const Int64ListProto& stride_proto)\n    : DimVector(stride_proto.dim().begin(), stride_proto.dim().end()) {}\n\nStride& Stride::CheckNumAxesIdenticalAndAssign(const Stride& stride) {\n  CHECK_EQ(size(), stride.size());\n  assign(stride);\n  return *this;\n}\n\nstd::string Stride::ToString() const {\n  std::stringstream ss;\n  int32_t idx = 0;\n  ss << \"(\";\n  for (int64_t dim : *this) {\n    ss << dim;\n    if (++idx != this->size() || this->size() == 1) { ss << \",\"; }\n  }\n  ss << \")\";\n  return ss.str();\n}\n\nvoid Stride::ToProto(Int64ListProto* ret) const {\n  *(ret->mutable_dim()) = PbRf<int64_t>(begin(), end());\n}\n\nstd::ostream& operator<<(std::ostream& out, const Stride& stride) {\n  out << stride.ToString();\n  return out;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/stride.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_FRAMEWORK_STRIDE_H_\n#define ONEFLOW_CORE_FRAMEWORK_STRIDE_H_\n\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/common/sequential.pb.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nclass Int64ListProto;\n\nclass Stride final : public DimVector {\n public:\n  Stride() = default;\n  using DimVector::DimVector;\n  explicit Stride(const ShapeView& shape);\n  explicit Stride(const Shape& shape);\n  explicit Stride(const std::shared_ptr<Shape>& shape);\n  explicit Stride(const Int64ListProto& stride_proto);\n  Stride& CheckNumAxesIdenticalAndAssign(const Stride& stride);\n  ~Stride() = default;\n\n  std::string ToString() const;\n  void ToProto(Int64ListProto*) const;\n};\n\nstd::ostream& operator<<(std::ostream& out, const Stride& stride);\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::Stride> {\n  size_t operator()(const oneflow::Stride& stride) const {\n    size_t ret = stride.size();\n    FOR_RANGE(int, i, 0, stride.size()) { oneflow::AddHash(&ret, stride.at(i)); }\n    return ret;\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_STRIDE_H_\n"
  },
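  {
    "path": "oneflow/core/common/stride_example_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// NOTE: illustrative usage sketch, not part of the upstream OneFlow sources.\n// For a contiguous row-major tensor the stride of an axis is the product of\n// all later extents; Stride(const ShapeView&) computes this with an exclusive\n// scan from the right, so shape (2, 3, 4) yields strides (12, 4, 1).\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/common/shape.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(Stride, row_major_contiguous) {\n  const Shape shape({2, 3, 4});\n  const Stride stride(shape);\n  ASSERT_EQ(stride.size(), 3);\n  ASSERT_EQ(stride.at(0), 12);  // one step on axis 0 skips 3 * 4 elements\n  ASSERT_EQ(stride.at(1), 4);   // one step on axis 1 skips 4 elements\n  ASSERT_EQ(stride.at(2), 1);   // the innermost axis is contiguous\n  ASSERT_EQ(stride.ToString(), \"(12,4,1)\");\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },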
  {
    "path": "oneflow/core/common/switch_func.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_SWITCH_FUNC_H_\n#define ONEFLOW_CORE_COMMON_SWITCH_FUNC_H_\n\n#include \"oneflow/core/common/preprocessor.h\"\n#include <tuple>\n#include <utility>\n\ntemplate<typename... Args>\nauto SwitchCase(Args&&... args) {\n  return std::make_tuple(std::forward<Args>(args)...);\n}\n\n#define DEFINE_STATIC_SWITCH_FUNC(return_type, func_name, make_switch_entry, ctrv_seq, ...) \\\n  DEFINE_STATIC_SWITCH_FUNC_FROM_TUPLE(return_type, func_name, make_switch_entry,           \\\n                                       OF_PP_CAT((ctrv_seq, ##__VA_ARGS__), ))\n\n#define DEFINE_STATIC_SWITCH_FUNC_FROM_TUPLE(return_type, func_name, make_switch_entry,        \\\n                                             ctrv_seq_tuple)                                   \\\n  template<typename... Args>                                                                   \\\n  static return_type Switch##func_name(                                                        \\\n      const OF_PP_I_CTRV_SEQ_TUPLE2STD_TUPLE_TYPE(ctrv_seq_tuple) & switch_tuple,              \\\n      Args && ... args) {                                                                      \\\n    static const std::map<OF_PP_I_CTRV_SEQ_TUPLE2STD_TUPLE_TYPE(ctrv_seq_tuple),               \\\n                          std::function<return_type(Args&&...)>>                               \\\n        case_handlers{OF_PP_I_MAKE_ALL_SWITCH_ENTRIES_FROM_TUPLE(make_switch_entry, func_name, \\\n                                                                 Args, ctrv_seq_tuple)};       \\\n    return case_handlers.at(switch_tuple)(std::forward<Args>(args)...);                        \\\n  }\n\n// CTRV: Compile-time Token and Runtime Value pair,\n// CTRV example: (float, DataType::kFloat)\n// TYPED_CTRV_SEQ example: (DataType, ((float, DataType::kFloat)))\n\n#define MAKE_DATA_TYPE_CTRV_SEQ(data_type_seq) MAKE_TYPED_CTRV_SEQ(DataType, data_type_seq)\n#define MAKE_DEVICE_TYPE_CTRV_SEQ(device_type_seq) \\\n  MAKE_TYPED_CTRV_SEQ(DeviceType,                  \\\n                      OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ, device_type_seq))\n#define MAKE_NDIM_CTRV_SEQ(ndim_seq) \\\n  MAKE_TYPED_CTRV_SEQ(int32_t, OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ, ndim_seq))\n\n#define MAKE_STRINGIZED_DATA_TYPE_CTRV(data_type_pair) \\\n  (OF_PP_PAIR_FIRST(data_type_pair), OF_PP_STRINGIZE(OF_PP_PAIR_FIRST(data_type_pair)))\n#define MAKE_STRINGIZED_DATA_TYPE_CTRV_SEQ(data_type_seq) \\\n  (std::string, OF_PP_SEQ_MAP(MAKE_STRINGIZED_DATA_TYPE_CTRV, data_type_seq))\n\n#define MAKE_TYPED_CTRV_SEQ(runtime_value_type, ctrv_pair_seq) (runtime_value_type, ctrv_pair_seq)\n\n//  internal preprocessor macros\n\n#define OF_PP_I_MAKE_SWITCH_ENTRY_MAP_PAIR(switch_case, func_args_type, func) \\\n  {switch_case,                                                               \\\n   [](func_args_type&&... 
args) { return func(std::forward<func_args_type>(args)...); }},\n\n#define OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ(x) OF_PP_MAKE_TUPLE_SEQ(x, x)\n\n#define OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_1(make_template_func, func_name, func_args_type, \\\n                                         switch_case_pair0)                             \\\n  OF_PP_I_MAKE_SWITCH_ENTRY_MAP_PAIR(                                                   \\\n      SwitchCase(OF_PP_PAIR_SECOND(switch_case_pair0)), func_args_type,                 \\\n      make_template_func(func_name, OF_PP_PAIR_FIRST(switch_case_pair0)))\n\n#define OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_2(make_template_func, func_name, func_args_type,       \\\n                                         switch_case_pair0, switch_case_pair1)                \\\n  OF_PP_I_MAKE_SWITCH_ENTRY_MAP_PAIR(                                                         \\\n      SwitchCase(OF_PP_PAIR_SECOND(switch_case_pair0), OF_PP_PAIR_SECOND(switch_case_pair1)), \\\n      func_args_type,                                                                         \\\n      make_template_func(func_name, OF_PP_PAIR_FIRST(switch_case_pair0),                      \\\n                         OF_PP_PAIR_FIRST(switch_case_pair1)))\n\n#define OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_3(make_template_func, func_name, func_args_type,           \\\n                                         switch_case_pair0, switch_case_pair1, switch_case_pair2) \\\n  OF_PP_I_MAKE_SWITCH_ENTRY_MAP_PAIR(                                                             \\\n      SwitchCase(OF_PP_PAIR_SECOND(switch_case_pair0), OF_PP_PAIR_SECOND(switch_case_pair1),      \\\n                 OF_PP_PAIR_SECOND(switch_case_pair2)),                                           \\\n      func_args_type,                                                                             \\\n      make_template_func(func_name, OF_PP_PAIR_FIRST(switch_case_pair0),                          \\\n                         OF_PP_PAIR_FIRST(switch_case_pair1),                                     \\\n                         OF_PP_PAIR_FIRST(switch_case_pair2)))\n\n#define OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_4(make_template_func, func_name, func_args_type,            \\\n                                         switch_case_pair0, switch_case_pair1, switch_case_pair2,  \\\n                                         switch_case_pair3)                                        \\\n  OF_PP_I_MAKE_SWITCH_ENTRY_MAP_PAIR(                                                              \\\n      SwitchCase(OF_PP_PAIR_SECOND(switch_case_pair0), OF_PP_PAIR_SECOND(switch_case_pair1),       \\\n                 OF_PP_PAIR_SECOND(switch_case_pair2), OF_PP_PAIR_SECOND(switch_case_pair3)),      \\\n      func_args_type,                                                                              \\\n      make_template_func(func_name, OF_PP_PAIR_FIRST(switch_case_pair0),                           \\\n                         OF_PP_PAIR_FIRST(switch_case_pair1), OF_PP_PAIR_FIRST(switch_case_pair2), \\\n                         OF_PP_PAIR_FIRST(switch_case_pair3)))\n\n#define OF_PP_I_MAKE_ALL_SWITCH_ENTRIES_FROM_TUPLE(make_switch_entry, func_name, args_type, t) \\\n  OF_PP_FORCE(OF_PP_CAT(OF_PP_I_MAKE_ALL_SWITCH_ENTRIES_FROM_TUPLE_, OF_PP_TUPLE_SIZE(t))(     \\\n      make_switch_entry, func_name, args_type, t))\n\n#define OF_PP_I_MAKE_ALL_SWITCH_ENTRIES_FROM_TUPLE_1(make_switch_entry, func_name, args_type, \\\n                                                     ctrv_seq_tuple)                      
    \\\n  OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_1, (make_switch_entry),     \\\n                                   (func_name), (args_type),                                  \\\n                                   OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(0, ctrv_seq_tuple)))\n#define OF_PP_I_MAKE_ALL_SWITCH_ENTRIES_FROM_TUPLE_2(make_switch_entry, func_name, args_type, \\\n                                                     ctrv_seq_tuple)                          \\\n  OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_2, (make_switch_entry),     \\\n                                   (func_name), (args_type),                                  \\\n                                   OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(0, ctrv_seq_tuple)),    \\\n                                   OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(1, ctrv_seq_tuple)))\n#define OF_PP_I_MAKE_ALL_SWITCH_ENTRIES_FROM_TUPLE_3(make_switch_entry, func_name, args_type, \\\n                                                     ctrv_seq_tuple)                          \\\n  OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_3, (make_switch_entry),     \\\n                                   (func_name), (args_type),                                  \\\n                                   OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(0, ctrv_seq_tuple)),    \\\n                                   OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(1, ctrv_seq_tuple)),    \\\n                                   OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(2, ctrv_seq_tuple)))\n#define OF_PP_I_MAKE_ALL_SWITCH_ENTRIES_FROM_TUPLE_4(make_switch_entry, func_name, args_type, \\\n                                                     ctrv_seq_tuple)                          \\\n  OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_4, (make_switch_entry),     \\\n                                   (func_name), (args_type),                                  \\\n                                   OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(0, ctrv_seq_tuple)),    \\\n                                   OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(1, ctrv_seq_tuple)),    \\\n                                   OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(2, ctrv_seq_tuple)),    \\\n                                   OF_PP_PAIR_SECOND(OF_PP_TUPLE_ELEM(3, ctrv_seq_tuple)))\n\n#define OF_PP_I_CTRV_SEQ_TUPLE2STD_TUPLE_TYPE(t) \\\n  OF_PP_FORCE(OF_PP_CAT(OF_PP_I_CTRV_SEQ_TUPLE2STD_TUPLE_TYPE_, OF_PP_TUPLE_SIZE(t))(t))\n\n#define OF_PP_I_CTRV_SEQ_TUPLE2STD_TUPLE_TYPE_1(t) \\\n  std::tuple<OF_PP_PAIR_FIRST(OF_PP_TUPLE_ELEM(0, t))>\n#define OF_PP_I_CTRV_SEQ_TUPLE2STD_TUPLE_TYPE_2(t) \\\n  std::tuple<OF_PP_PAIR_FIRST(OF_PP_TUPLE_ELEM(0, t)), OF_PP_PAIR_FIRST(OF_PP_TUPLE_ELEM(1, t))>\n#define OF_PP_I_CTRV_SEQ_TUPLE2STD_TUPLE_TYPE_3(t)                                               \\\n  std::tuple<OF_PP_PAIR_FIRST(OF_PP_TUPLE_ELEM(0, t)), OF_PP_PAIR_FIRST(OF_PP_TUPLE_ELEM(1, t)), \\\n             OF_PP_PAIR_FIRST(OF_PP_TUPLE_ELEM(2, t))>\n#define OF_PP_I_CTRV_SEQ_TUPLE2STD_TUPLE_TYPE_4(t)                                               \\\n  std::tuple<OF_PP_PAIR_FIRST(OF_PP_TUPLE_ELEM(0, t)), OF_PP_PAIR_FIRST(OF_PP_TUPLE_ELEM(1, t)), \\\n             OF_PP_PAIR_FIRST(OF_PP_TUPLE_ELEM(2, t)), OF_PP_PAIR_FIRST(OF_PP_TUPLE_ELEM(3, t))>\n\n#endif  // ONEFLOW_CORE_COMMON_SWITCH_FUNC_H_\n"
  },
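  {
    "path": "oneflow/core/common/switch_func_example_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// NOTE: illustrative usage sketch, not part of the upstream OneFlow sources.\n// DEFINE_STATIC_SWITCH_FUNC builds a map from runtime tag values (here\n// DataType) to instantiations of a function template, so a runtime value\n// selects a compile-time specialization; SizeOf and its switch-entry macro are\n// local to this example (the generated code also needs <map> and <functional>).\n#include <cstddef>\n#include <functional>\n#include <map>\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/switch_func.h\"\n#include \"oneflow/core/common/data_type.h\"\n\nnamespace oneflow {\nnamespace test {\n\ntemplate<typename T>\nsize_t SizeOf() {\n  return sizeof(T);\n}\n\n#define MAKE_SIZE_OF_SWITCH_ENTRY(func_name, T) func_name<T>\nDEFINE_STATIC_SWITCH_FUNC(size_t, SizeOf, MAKE_SIZE_OF_SWITCH_ENTRY,\n                          MAKE_TYPED_CTRV_SEQ(DataType, ((float, DataType::kFloat))((double, DataType::kDouble))));\n#undef MAKE_SIZE_OF_SWITCH_ENTRY\n\nTEST(SwitchFunc, dispatch_on_data_type) {\n  ASSERT_EQ(SwitchSizeOf(SwitchCase(DataType::kFloat)), sizeof(float));\n  ASSERT_EQ(SwitchSizeOf(SwitchCase(DataType::kDouble)), sizeof(double));\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },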
  {
    "path": "oneflow/core/common/symbol.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_SYMBOL_H_\n#define ONEFLOW_CORE_COMMON_SYMBOL_H_\n\n#include <mutex>\n#include <memory>\n#include <unordered_map>\n#include <unordered_set>\n#include \"oneflow/core/common/type_traits.h\"\n#include \"oneflow/core/common/check.h\"\n#include \"oneflow/core/common/hash_eq_trait_ptr.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstruct SymbolUtil;\n\ntemplate<typename T>\nclass Symbol final {\n public:\n  Symbol() : ptr_(nullptr) {}\n  Symbol(const T& obj) : ptr_(GetOrCreatePtr(obj)) {}\n  Symbol(const Symbol& rhs) = default;\n  Symbol(Symbol&& rhs) = default;\n  ~Symbol() = default;\n\n  explicit operator bool() const { return ptr_ != nullptr; }\n  const T* operator->() const { return ptr_; }\n  const T& operator*() const { return *ptr_; }\n  bool operator==(const Symbol<T>& rhs) const { return ptr_ == rhs.ptr_; }\n  bool operator!=(const Symbol<T>& rhs) const { return !(*this == rhs); }\n  size_t hash_value() const { return std::hash<const T*>()(ptr_); }\n\n  Symbol& operator=(const Symbol& other) {\n    ptr_ = other.ptr_;\n    return *this;\n  }\n  void reset() { ptr_ = nullptr; }\n  void reset(const T& obj) { ptr_ = GetOrCreatePtr(obj); }\n\n  const std::shared_ptr<const T>& shared_from_symbol() const;\n\n private:\n  template<typename SymbolT>\n  friend struct SymbolUtil;\n  static const T* GetOrCreatePtr(const T& obj);\n\n  const T* ptr_;\n};\n\ntemplate<typename T>\nstruct IsScalarType<Symbol<T>> final {\n  static const bool value = true;\n};\n\ntemplate<typename T>\nstruct SymbolUtil final {\n  using SymbolMap = std::unordered_map<HashEqTraitPtr<const T>, std::shared_ptr<const T>>;\n\n  static SymbolMap* GlobalSymbolMap() {\n    static SymbolMap symbol_map;\n    return &symbol_map;\n  }\n\n  static std::mutex* GlobalSymbolMapMutex() {\n    static std::mutex mutex;\n    return &mutex;\n  }\n\n  static SymbolMap* ThreadLocalSymbolMap() {\n    static thread_local SymbolMap thread_local_symbol_map;\n    return &thread_local_symbol_map;\n  }\n\n  static std::unordered_set<const T*>* ThreadLocalSymbolPtrSet() {\n    static thread_local std::unordered_set<const T*> thread_local_symbol_ptr_set;\n    return &thread_local_symbol_ptr_set;\n  }\n\n  template<typename SymbolMap::iterator (*GetIter4ObjectAndHashValue)(const T&, size_t)>\n  static const std::shared_ptr<const T>& LocalThreadGetOr(const T& obj) {\n    auto* thread_local_symbol_map = ThreadLocalSymbolMap();\n    size_t hash_value = std::hash<T>()(obj);\n    HashEqTraitPtr<const T> obj_ptr_wraper(&obj, hash_value);\n    const auto& local_iter = thread_local_symbol_map->find(obj_ptr_wraper);\n    if (local_iter != thread_local_symbol_map->end()) { return local_iter->second; }\n    const auto& iter = GetIter4ObjectAndHashValue(obj, hash_value);\n    (*thread_local_symbol_map)[iter->first] = iter->second;\n    GLOGCHECK(ThreadLocalSymbolPtrSet()->emplace(iter->second.get()).second);\n    return 
iter->second;\n  }\n\n  static typename SymbolMap::iterator FindGlobalSymbol(const T& obj, size_t hash_value) {\n    HashEqTraitPtr<const T> new_obj_ptr_wraper(&obj, hash_value);\n    auto* symbol_map = GlobalSymbolMap();\n    std::unique_lock<std::mutex> lock(*GlobalSymbolMapMutex());\n    const auto& iter = symbol_map->find(new_obj_ptr_wraper);\n    GLOGCHECK(iter != symbol_map->end());\n    return iter;\n  }\n\n  static const std::shared_ptr<const T>& SharedFromObject(const T& obj) {\n    return LocalThreadGetOr<FindGlobalSymbol>(obj);\n  }\n\n  static typename SymbolMap::iterator CreateGlobalSymbol(const T& obj, size_t hash_value) {\n    std::shared_ptr<const T> ptr(new T(obj));\n    HashEqTraitPtr<const T> new_obj_ptr_wraper(ptr.get(), hash_value);\n    std::unique_lock<std::mutex> lock(*GlobalSymbolMapMutex());\n    return GlobalSymbolMap()->emplace(new_obj_ptr_wraper, ptr).first;\n  }\n\n  static const std::shared_ptr<const T>& GetOrCreatePtr(const T& obj) {\n    return LocalThreadGetOr<CreateGlobalSymbol>(obj);\n  }\n};\n\ntemplate<typename T>\nconst std::shared_ptr<const T>& Symbol<T>::shared_from_symbol() const {\n  if (this->ptr_ == nullptr) {\n    static auto* none = new std::shared_ptr<const T>();\n    return *none;\n  }\n  return SymbolUtil<T>::SharedFromObject(*this->ptr_);\n}\n\ntemplate<typename T>\nconst T* Symbol<T>::GetOrCreatePtr(const T& obj) {\n  return SymbolUtil<T>::GetOrCreatePtr(obj).get();\n}\n\ntemplate<typename T>\nSymbol<T> SymbolOf(const T& obj) {\n  return Symbol<T>(obj);\n}\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<typename T>\nstruct hash<oneflow::Symbol<T>> final {\n  size_t operator()(const oneflow::Symbol<T>& symbol) const { return symbol.hash_value(); }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_COMMON_SYMBOL_H_\n"
  },
  {
    "path": "oneflow/core/common/symbol_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\nnamespace test {\n\nnamespace detail {\n\nclass SymObject {\n public:\n  SymObject(const std::string& name) : name_(name) {}\n\n  const std::string& name() const { return name_; }\n\n  bool operator==(const SymObject& other) const { return name_ == other.name_; }\n\n private:\n  std::string name_;\n};\n\n}  // namespace detail\n\nTEST(Symbol, shared_from_symbol) {\n  Symbol<detail::SymObject> symbol(detail::SymObject(\"SymbolObjectFoo\"));\n  ASSERT_TRUE(symbol.shared_from_symbol().get()\n              == SymbolOf(detail::SymObject(\"SymbolObjectFoo\")).shared_from_symbol().get());\n}\n\n}  // namespace test\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::test::detail::SymObject> final {\n  size_t operator()(const oneflow::test::detail::SymObject& sym_object) const {\n    return std::hash<std::string>()(sym_object.name());\n  }\n};\n\n}  // namespace std\n"
  },
  {
    "path": "oneflow/core/common/tensor_buffer.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/tensor_buffer.h\"\n#include \"oneflow/core/memory/memory_allocator.h\"\n\nnamespace oneflow {\n\nnamespace detail {\n\nstatic constexpr double kDefaultGrowthFactor = 1.0f;\nstatic constexpr double kDefaultShrinkFactor = 0.7f;\nstatic constexpr size_t kDefaultTensorBufferAlignedSize = 1024;\n\nsize_t GetTensorBufferAlignedSize(size_t origin_size, double factor) {\n  static size_t aligned_size =\n      ParseIntegerFromEnv(\"ONEFLOW_TENSOR_BUFFER_ALIGNED_SIZE\", kDefaultTensorBufferAlignedSize);\n  return RoundUp(static_cast<size_t>(origin_size * factor), aligned_size);\n}\n\nsize_t GetTensorBufferGrowthSize(size_t origin_size) {\n  static double factor =\n      ParseFloatFromEnv(\"ONEFLOW_TENSOR_BUFFER_GROWTH_FACTOR\", kDefaultGrowthFactor);\n  return GetTensorBufferAlignedSize(origin_size, factor);\n}\n\nsize_t GetTensorBufferShrinkSize(size_t origin_size) {\n  static double factor =\n      ParseFloatFromEnv(\"ONEFLOW_TENSOR_BUFFER_SHRINK_FACTOR\", kDefaultShrinkFactor);\n  return GetTensorBufferAlignedSize(origin_size, factor);\n}\n\nvoid CheckTensorBufferDataType(DataType val) {\n  CHECK(val != DataType::kTensorBuffer && val != DataType::kOFRecord)\n      << \"TensorBuffer only support POD as internal data type.\";\n}\n\nvoid TensorBufferImpl::Reset(const Shape& shape, DataType dtype) {\n  int64_t elem_cnt = shape.elem_cnt();\n  if (dtype == DataType::kInvalidDataType || elem_cnt == 0) { return; }\n  CheckTensorBufferDataType(dtype);\n\n  if (shape == shape_ && dtype == data_type_) { return; }\n\n  shape_ = shape;\n  data_type_ = dtype;\n\n  size_t new_buffer_size = elem_cnt * GetSizeOfDataType(dtype);\n  Reserve(new_buffer_size);\n}\n\nvoid TensorBufferImpl::Reset(const Shape& shape) { Reset(shape, data_type_); }\n\nvoid TensorBufferImpl::Reset(DataType dtype) {\n  CheckTensorBufferDataType(dtype);\n  if (dtype == DataType::kInvalidDataType) {\n    Reset();\n  } else {\n    Reset(shape_, dtype);\n  }\n}\n\nvoid TensorBufferImpl::Reset() {\n  shape_ = Shape();\n  data_type_ = DataType::kInvalidDataType;\n  DeallocateBuffer();\n}\n\nvoid TensorBufferImpl::AllocateBuffer(size_t size) {\n  CHECK(buffer_ == nullptr);\n  buffer_ = MemoryAllocatorImpl::AllocateUnPinnedHostMem(size);\n  buffer_size_ = size;\n}\n\nvoid TensorBufferImpl::DeallocateBuffer() {\n  if (buffer_) { MemoryAllocatorImpl::DeallocateUnPinnedHostMem(buffer_); }\n  buffer_ = nullptr;\n  buffer_size_ = 0;\n}\n\nvoid TensorBufferImpl::Reserve(size_t new_size) {\n  if (new_size > buffer_size_) {\n    size_t growth_size = std::max(new_size, GetTensorBufferGrowthSize(new_size));\n    DeallocateBuffer();\n    AllocateBuffer(growth_size);\n  } else {\n    size_t shrink_size = GetTensorBufferShrinkSize(buffer_size_);\n    if (new_size <= shrink_size) {\n      DeallocateBuffer();\n      AllocateBuffer(shrink_size);\n    }\n  }\n}\n\nvoid TensorBufferImpl::CopyFrom(const TensorBufferImpl* src) 
{\n  if (src == this) { return; }\n  Reset(src->shape(), src->data_type());\n  // Copy only the bytes that are valid in src: this buffer may have been kept\n  // larger than src's by the growth/shrink policy in Reserve().\n  memcpy(buffer_, src->buffer(), std::min(buffer_size_, src->buffer_size()));\n}\n\nvoid TensorBufferImpl::Swap(TensorBufferImpl* other) {\n  std::swap(buffer_, other->buffer_);\n  std::swap(buffer_size_, other->buffer_size_);\n  std::swap(shape_, other->shape_);\n  std::swap(data_type_, other->data_type_);\n}\n\n}  // namespace detail\n\nTensorBuffer::~TensorBuffer() {\n  if (auto* pool = TensorBufferPool::TryGet()) { pool->Deallocate(&impl_); }\n}\n\nTensorBuffer::TensorBuffer(const Shape& shape, DataType dtype) { Allocate(shape, dtype); }\n\nTensorBuffer& TensorBuffer::operator=(TensorBuffer&& other) noexcept {\n  impl_ = std::move(other.impl_);\n  return *this;\n}\n\nvoid TensorBuffer::Allocate(const Shape& shape, DataType dtype) {\n  CHECK(!is_allocated());\n  if (auto* pool = TensorBufferPool::TryGet()) {\n    pool->Allocate(&impl_, shape, dtype);\n  } else {\n    impl_.reset(new detail::TensorBufferImpl(shape, dtype));\n  }\n}\n\nvoid TensorBuffer::Reset(const Shape& shape, DataType dtype) {\n  if (is_allocated()) {\n    impl_->Reset(shape, dtype);\n  } else {\n    Allocate(shape, dtype);\n  }\n}\n\nvoid TensorBuffer::Reset(const Shape& shape) {\n  CHECK(is_allocated()) << \"TensorBuffer is not allocated\";\n  impl_->Reset(shape);\n}\n\nvoid TensorBuffer::Reset(DataType dtype) {\n  CHECK(is_allocated()) << \"TensorBuffer is not allocated\";\n  impl_->Reset(dtype);\n}\n\nvoid TensorBuffer::Reset() {\n  if (impl_) { impl_->Reset(); }\n}\n\nconst Shape& TensorBuffer::shape() const {\n  CHECK(is_allocated()) << \"TensorBuffer is not allocated\";\n  return impl_->shape();\n}\n\nDataType TensorBuffer::data_type() const {\n  CHECK(is_allocated()) << \"TensorBuffer is not allocated\";\n  return impl_->data_type();\n}\n\nvoid* TensorBuffer::raw_data() {\n  CHECK(is_allocated()) << \"TensorBuffer is not allocated\";\n  return impl_->buffer();\n}\n\nconst void* TensorBuffer::raw_data() const {\n  CHECK(is_allocated()) << \"TensorBuffer is not allocated\";\n  return const_cast<detail::TensorBufferImpl*>(impl_.get())->buffer();\n}\n\nvoid TensorBuffer::CopyFrom(const TensorBuffer& src) {\n  CHECK(src.is_allocated()) << \"TensorBuffer src is not allocated\";\n  if (!is_allocated()) { Allocate(src.shape(), src.data_type()); }\n  impl_->CopyFrom(src.impl_.get());\n}\n\nvoid TensorBuffer::Swap(TensorBuffer& other) { std::swap(impl_, other.impl_); }\n\nnamespace {\n\nconstexpr size_t kDefaultPoolSizeBase = 64;\nconstexpr double kDefaultPoolSizeFactor = 2.0;\nconstexpr size_t kDefaultThreadLocalCacheSize = 64;\n\nsize_t GetTensorBufferPoolSize(size_t base = kDefaultPoolSizeBase) {\n  static double factor =\n      ParseFloatFromEnv(\"ONEFLOW_TENSOR_BUFFER_POOL_SIZE_FACTOR\", kDefaultPoolSizeFactor);\n  return static_cast<size_t>(std::ceil(base * factor));\n}\n\nsize_t GetTensorBufferPoolThreadLocalCacheSize() {\n  static size_t cache_size = ParseIntegerFromEnv(\n      \"ONEFLOW_TENSOR_BUFFER_POOL_THREAD_LOCAL_CACHE_SIZE\", kDefaultThreadLocalCacheSize);\n  return cache_size;\n}\n\n}  // namespace\n\nTensorBufferPool::TensorBufferPool()\n    : thread_local_cache_size_(GetTensorBufferPoolThreadLocalCacheSize()),\n      pool_size_(GetTensorBufferPoolSize()) {\n  auto& thread_local_cache = ThreadLocalCache();\n  thread_local_cache.reserve(thread_local_cache_size_);\n  global_free_list_.reserve(pool_size_);\n}\n\nvoid TensorBufferPool::Allocate(ItemT* item, const Shape& shape, DataType dtype) {\n  CHECK(!(*item)) << \"TensorBuffer is already allocated\";\n  
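// Two-level pooling: the thread-local cache is tried first without taking the\n  // lock; it is refilled in bulk from the mutex-protected global free list.\n  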
auto& thread_local_cache = ThreadLocalCache();\n  if (thread_local_cache.empty() && thread_local_cache_size_ > 0) {\n    std::unique_lock<std::mutex> lck(mtx_);\n    if (!global_free_list_.empty()) {\n      // fetch up to half of thread_local_cache_size_ tensor buffers from the global free list\n      size_t fetches = thread_local_cache_size_ / 2;\n      auto begin = global_free_list_.size() >= fetches ? (global_free_list_.end() - fetches)\n                                                       : global_free_list_.begin();\n      for (auto it = begin; it < global_free_list_.end(); ++it) {\n        thread_local_cache.push_back(std::move(*it));\n      }\n      global_free_list_.erase(begin, global_free_list_.end());\n    }\n  }\n\n  if (thread_local_cache.empty()) {\n    item->reset(new detail::TensorBufferImpl(shape, dtype));\n  } else {\n    *item = std::move(thread_local_cache.back());\n    thread_local_cache.pop_back();\n    (*item)->Reset(shape, dtype);\n  }\n}\n\nvoid TensorBufferPool::Deallocate(ItemT* item) {\n  if (!(*item)) { return; }\n  auto& thread_local_cache = ThreadLocalCache();\n  if (thread_local_cache.size() < thread_local_cache_size_) {\n    thread_local_cache.push_back(std::move(*item));\n  } else {\n    size_t releases = thread_local_cache.size() / 2;\n    {\n      std::unique_lock<std::mutex> lck(mtx_);\n      if (global_free_list_.size() < pool_size_) {\n        global_free_list_.push_back(std::move(*item));\n        // return up to half of the thread-local cache to the global free list\n        while (global_free_list_.size() < pool_size_ && releases > 0) {\n          global_free_list_.push_back(std::move(thread_local_cache.back()));\n          thread_local_cache.pop_back();\n          releases--;\n        }\n      }\n    }\n    // drop any earmarked buffers that did not fit into the (full) global free list\n    thread_local_cache.resize(thread_local_cache.size() - releases);\n  }\n  if (*item) { item->reset(); }\n}\n\nvoid TensorBufferPool::IncreasePoolSizeByBase(size_t base) {\n  std::unique_lock<std::mutex> lck(mtx_);\n  pool_size_ += GetTensorBufferPoolSize(base);\n  if (pool_size_ > global_free_list_.capacity()) { global_free_list_.reserve(pool_size_); }\n  if (pool_size_ < global_free_list_.size()) { global_free_list_.resize(pool_size_); }\n}\n\nvoid TensorBufferPool::DecreasePoolSizeByBase(size_t base) {\n  std::unique_lock<std::mutex> lck(mtx_);\n  size_t dec = GetTensorBufferPoolSize(base);\n  CHECK_GE(pool_size_, dec) << \"pool_size \" << pool_size_ << \" decreased by \" << dec\n                            << \" would underflow\";\n  pool_size_ -= dec;\n  if (pool_size_ > global_free_list_.capacity()) { global_free_list_.reserve(pool_size_); }\n  if (pool_size_ < global_free_list_.size()) { global_free_list_.resize(pool_size_); }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/tensor_buffer.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_TENSOR_BUFFER_H_\n#define ONEFLOW_CORE_COMMON_TENSOR_BUFFER_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/common/data_type.h\"\n\nnamespace oneflow {\n\nnamespace detail {\n\nclass TensorBufferImpl final {\n public:\n  TensorBufferImpl()\n      : shape_(Shape()),\n        data_type_(DataType::kInvalidDataType),\n        buffer_(nullptr),\n        buffer_size_(0) {}\n  TensorBufferImpl(const Shape& shape, DataType dtype)\n      : shape_(Shape()), data_type_(DataType::kInvalidDataType), buffer_(nullptr), buffer_size_(0) {\n    Reset(shape, dtype);\n  }\n  ~TensorBufferImpl() { DeallocateBuffer(); }\n  OF_DISALLOW_COPY_AND_MOVE(TensorBufferImpl);\n\n  void Reset(const Shape& shape, DataType dtype);\n  void Reset(const Shape& shape);\n  void Reset(DataType dtype);\n  void Reset();\n\n  void CopyFrom(const TensorBufferImpl* src);\n  void Swap(TensorBufferImpl* other);\n\n  const Shape& shape() const { return shape_; }\n  DataType data_type() const { return data_type_; }\n\n  void* buffer() { return buffer_; }\n  const void* buffer() const { return buffer_; }\n  size_t buffer_size() const { return buffer_size_; }\n\n private:\n  void AllocateBuffer(size_t size);\n  void DeallocateBuffer();\n  void Reserve(size_t new_size);\n\n  Shape shape_;\n  DataType data_type_;\n\n  void* buffer_;\n  size_t buffer_size_;\n};\n\n}  // namespace detail\n\nclass TensorBuffer final {\n public:\n  TensorBuffer() = default;\n  ~TensorBuffer();\n\n  TensorBuffer(const Shape& shape, DataType dtype);\n\n  TensorBuffer(const TensorBuffer&) = delete;\n  TensorBuffer& operator=(const TensorBuffer&) = delete;\n\n  TensorBuffer(TensorBuffer&& other) noexcept : impl_(std::move(other.impl_)) {}\n  TensorBuffer& operator=(TensorBuffer&& other) noexcept;\n\n  bool is_allocated() const { return bool(impl_); }\n  const Shape& shape() const;\n  ShapeView shape_view() const { return shape(); }\n  DataType data_type() const;\n  int64_t elem_cnt() const { return shape().elem_cnt(); }\n  size_t nbytes() const { return elem_cnt() * GetSizeOfDataType(data_type()); }\n\n  void Reset(const Shape& shape, DataType dtype);\n  void Reset(const Shape& shape);\n  void Reset(DataType dtype);\n  void Reset();\n\n  // backward compatible interface and will be deprecated in future\n  void Resize(const Shape& shape, DataType dtype) { Reset(shape, dtype); }\n\n  void CopyFrom(const TensorBuffer& src);\n  void Swap(TensorBuffer& other);\n\n  template<typename T = void>\n  T* mut_data() {\n    if (raw_data() == nullptr) { return nullptr; }\n    CheckDataType<T>(data_type());\n    return static_cast<T*>(raw_data());\n  }\n\n  template<typename T = void>\n  const T* data() const {\n    if (raw_data() == nullptr) { return nullptr; }\n    CheckDataType<T>(data_type());\n    return static_cast<const 
T*>(raw_data());\n  }\n\n private:\n  friend class TensorBufferPool;\n\n  void Allocate(const Shape& shape, DataType dtype);\n  void* raw_data();\n  const void* raw_data() const;\n\n  std::unique_ptr<detail::TensorBufferImpl> impl_;\n};\n\n#define BUFFER_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(TensorBuffer, DataType::kTensorBuffer)\n\ntemplate<>\nstruct GetDataType<TensorBuffer> : std::integral_constant<DataType, DataType::kTensorBuffer> {};\ninline TensorBuffer GetTypeByDataType(std::integral_constant<DataType, DataType::kTensorBuffer>) {\n  return {};\n}\n\nclass TensorBufferPool final {\n public:\n  using ItemT = std::unique_ptr<detail::TensorBufferImpl>;\n  using ListT = std::vector<ItemT>;\n\n  static TensorBufferPool* Get() {\n    auto& ptr = GetPtr();\n    CHECK(ptr) << \"TensorBufferPool has not been created\";\n    return ptr.get();\n  }\n\n  static TensorBufferPool* TryGet() {\n    auto& ptr = GetPtr();\n    return ptr.get();\n  }\n\n  static void New() {\n    auto& ptr = GetPtr();\n    CHECK(!ptr) << \"TensorBufferPool has already been created\";\n    ptr.reset(new TensorBufferPool());\n  }\n\n  static void Delete() {\n    auto& ptr = GetPtr();\n    if (ptr) { ptr.reset(); }\n  }\n\n  ~TensorBufferPool() = default;\n  OF_DISALLOW_COPY_AND_MOVE(TensorBufferPool);\n\n  void Allocate(ItemT* item, const Shape& shape, DataType dtype);\n  void Deallocate(ItemT* item);\n\n  void IncreasePoolSizeByBase(size_t base);\n  void DecreasePoolSizeByBase(size_t base);\n\n private:\n  static std::unique_ptr<TensorBufferPool>& GetPtr() {\n    static std::unique_ptr<TensorBufferPool> ptr;\n    return ptr;\n  }\n\n  static ListT& ThreadLocalCache() {\n    thread_local ListT thread_local_cache;\n    return thread_local_cache;\n  }\n\n  TensorBufferPool();\n\n  size_t thread_local_cache_size_;\n  size_t pool_size_;\n\n  ListT global_free_list_;\n  std::mutex mtx_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_TENSOR_BUFFER_H_\n"
  },
  {
    "path": "oneflow/core/common/tensor_desc.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/tensor_desc.h\"\n#include \"oneflow/core/register/blob_desc.pb.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nTensorDesc& TensorDesc::operator=(const TensorDesc& rhs) {\n  this->set_shape(rhs.shape());\n  this->set_stride(rhs.stride());\n  this->set_data_type(rhs.data_type());\n  this->set_is_dynamic(rhs.is_dynamic());\n  this->set_memory_format(rhs.memory_format());\n  return *this;\n}\n\nbool TensorDesc::operator==(const TensorDesc& rhs) const {\n  return (this->shape() == rhs.shape()) && (this->stride() == rhs.stride())\n         && (this->data_type() == rhs.data_type()) && (this->is_dynamic() == rhs.is_dynamic())\n         && (this->memory_format() == rhs.memory_format());\n}\n\nNaiveTensorDesc::NaiveTensorDesc(const NaiveTensorDesc& rhs) { *this = rhs; }\n\nNaiveTensorDesc::NaiveTensorDesc(const BlobDescProto& proto) { *this = proto; }\n\nNaiveTensorDesc& NaiveTensorDesc::operator=(const BlobDescProto& proto) {\n  data_type_ = proto.data_type();\n  shape_ = Shape(proto.shape());\n  stride_ = Stride(proto.stride());\n  is_dynamic_ = proto.is_dynamic();\n  memory_format_ = proto.memory_format();\n  return *this;\n}\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/tensor_desc.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_TENSOR_DESC_H_\n#define ONEFLOW_CORE_COMMON_TENSOR_DESC_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/memory_format.pb.h\"\n\nnamespace oneflow {\n\nclass BlobDescProto;\n\nnamespace user_op {\n\nclass TensorDesc {\n public:\n  virtual ~TensorDesc() = default;\n  TensorDesc& operator=(const TensorDesc& rhs);\n  bool operator==(const TensorDesc&) const;\n\n  virtual const Shape& shape() const = 0;\n  virtual void set_shape(const Shape& shape) = 0;\n  virtual const Stride& stride() const = 0;\n  virtual void set_stride(const Stride& stride) = 0;\n  virtual DataType data_type() const = 0;\n  virtual void set_data_type(DataType data_type) = 0;\n\n  virtual bool is_dynamic() const = 0;\n  virtual void set_is_dynamic(bool is_dynamic) = 0;\n\n  virtual MemoryFormat memory_format() const = 0;\n  virtual void set_memory_format(MemoryFormat memory_format) = 0;\n\n protected:\n  TensorDesc() = default;\n};\n\nclass NaiveTensorDesc final : public TensorDesc {\n public:\n  NaiveTensorDesc() = default;\n  ~NaiveTensorDesc() override = default;\n  NaiveTensorDesc(const NaiveTensorDesc&);\n  NaiveTensorDesc(const BlobDescProto&);\n\n  NaiveTensorDesc& operator=(const BlobDescProto&);\n\n  const Shape& shape() const override { return shape_; }\n  void set_shape(const Shape& shape) override { shape_ = shape; }\n  const Stride& stride() const override { return stride_; }\n  void set_stride(const Stride& stride) override { stride_ = stride; }\n  DataType data_type() const override { return data_type_; }\n  void set_data_type(DataType data_type) override { data_type_ = data_type; }\n\n  bool is_dynamic() const override { return is_dynamic_; }\n  void set_is_dynamic(bool is_dynamic) override { is_dynamic_ = is_dynamic; }\n\n  MemoryFormat memory_format() const override { return memory_format_; }\n  void set_memory_format(MemoryFormat memory_format) override { memory_format_ = memory_format; }\n\n private:\n  Shape shape_;\n  Stride stride_;\n  DataType data_type_;\n  bool is_dynamic_;\n  MemoryFormat memory_format_;\n};\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_TENSOR_DESC_H_\n"
  },
  {
    "path": "oneflow/core/common/tensor_meta.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/tensor_meta.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/common/shape_view.h\"\n\nnamespace oneflow {\nnamespace one {\n\nMutTensorMeta::MutTensorMeta()\n    : TensorMeta(kInvalidDataType, MemoryFormat::kContiguous),\n      shape_(std::make_shared<const Shape>()),\n      stride_(std::make_shared<const Stride>()) {}\n\nMutTensorMeta::MutTensorMeta(const std::shared_ptr<const Shape>& shape, DataType dtype,\n                             MemoryFormat memory_format)\n    : TensorMeta(dtype, memory_format),\n      shape_(std::make_shared<const Shape>(*shape)),\n      stride_(std::make_shared<const Stride>(*shape)) {}\n\nMutTensorMeta::MutTensorMeta(const std::shared_ptr<const Shape>& shape,\n                             const std::shared_ptr<const Stride>& stride, DataType dtype,\n                             MemoryFormat memory_format)\n    : TensorMeta(dtype, memory_format),\n      shape_(std::make_shared<const Shape>(*shape)),\n      stride_(std::make_shared<const Stride>(*stride)) {}\n\nMutTensorMeta::MutTensorMeta(const Shape& shape, DataType dtype, MemoryFormat memory_format)\n    : TensorMeta(dtype, memory_format),\n      shape_(std::make_shared<const Shape>(shape)),\n      stride_(std::make_shared<const Stride>(shape)) {}\n\nMutTensorMeta::MutTensorMeta(const Shape& shape, const Stride& stride, DataType dtype,\n                             MemoryFormat memory_format)\n    : TensorMeta(dtype, memory_format),\n      shape_(std::make_shared<const Shape>(shape)),\n      stride_(std::make_shared<const Stride>(stride)) {}\n\nbool MutTensorMeta::operator==(const MutTensorMeta& other) const {\n  // It's correct to ignore is_dynamic_ field.\n  return *this->shape_ptr() == *other.shape_ptr() && this->dtype() == other.dtype()\n         && this->memory_format() == other.memory_format() && this->stride() == other.stride();\n}\n\nsize_t MutTensorMeta::CalcHashValue() const {\n  // It's correct to ignore is_dynamic_ field.\n  return Hash(*shape_ptr(), dtype(), memory_format(), stride());\n}\n\nConstTensorMeta::ConstTensorMeta()\n    : TensorMeta(kInvalidDataType, MemoryFormat::kContiguous),\n      shape_(SymbolOf(Shape())),\n      stride_(SymbolOf(Stride())) {}\n\nConstTensorMeta::ConstTensorMeta(Symbol<Shape> shape, DataType dtype, MemoryFormat memory_format)\n    : TensorMeta(dtype, memory_format), shape_(shape), stride_(SymbolOf(Stride(*shape))) {}\n\nConstTensorMeta::ConstTensorMeta(Symbol<Shape> shape, Symbol<Stride> stride, DataType dtype,\n                                 MemoryFormat memory_format)\n    : TensorMeta(dtype, memory_format), shape_(shape), stride_(stride) {}\n\nbool ConstTensorMeta::operator==(const ConstTensorMeta& other) const {\n  // It's correct to ignore is_dynamic_ field.\n  return *this->shape_ptr() == 
*other.shape_ptr() && this->dtype() == other.dtype()\n         && this->memory_format() == other.memory_format() && this->stride() == other.stride();\n}\n\nsize_t ConstTensorMeta::CalcHashValue() const {\n  // It's correct to ignore is_dynamic_ field.\n  return Hash(*shape_ptr(), dtype(), memory_format(), stride());\n}\n\nLocalTensorMeta::LocalTensorMeta()\n    : ConstTensorMeta(SymbolOf(Shape()), SymbolOf(Stride()), DataType::kInvalidDataType,\n                      MemoryFormat::kContiguous),\n      device_(Symbol<Device>()) {}\n\nLocalTensorMeta::LocalTensorMeta(Symbol<Shape> shape, DataType dtype, MemoryFormat memory_format,\n                                 Symbol<Device> device)\n    : ConstTensorMeta(shape, SymbolOf(Stride(*shape)), dtype, memory_format), device_(device) {}\n\nLocalTensorMeta::LocalTensorMeta(Symbol<Shape> shape, Symbol<Stride> stride, DataType dtype,\n                                 MemoryFormat memory_format, Symbol<Device> device)\n    : ConstTensorMeta(shape, stride, dtype, memory_format), device_(device) {}\n\nLocalTensorMeta::LocalTensorMeta(Symbol<Shape> shape, Symbol<Stride> stride, DataType dtype,\n                                 MemoryFormat memory_format, Symbol<Device> device,\n                                 const bool is_view)\n    : ConstTensorMeta(shape, stride, dtype, memory_format), device_(device), is_view_(is_view) {}\n\nbool LocalTensorMeta::operator==(const LocalTensorMeta& other) const {\n  // It's correct to ignore is_dynamic_ field.\n  return *this->shape_ptr() == *other.shape_ptr() && this->dtype() == other.dtype()\n         && this->memory_format() == other.memory_format() && this->device() == other.device()\n         && this->stride() == other.stride();\n}\n\nsize_t LocalTensorMeta::CalcHashValue() const {\n  // It's correct to ignore is_dynamic_ field.\n  return Hash(*shape_ptr(), dtype(), memory_format(), device(), stride());\n}\n\nMutLocalTensorMeta::MutLocalTensorMeta()\n    : MutTensorMeta(std::make_shared<const Shape>(), std::make_shared<const Stride>(),\n                    kInvalidDataType, MemoryFormat::kContiguous),\n      device_(Symbol<Device>()) {}\n\nMutLocalTensorMeta::MutLocalTensorMeta(const std::shared_ptr<const Shape>& shape, DataType dtype,\n                                       MemoryFormat memory_format, Symbol<Device> device)\n    : MutTensorMeta(shape, std::make_shared<const Stride>(*shape), dtype, memory_format),\n      device_(device) {}\n\nMutLocalTensorMeta::MutLocalTensorMeta(const std::shared_ptr<const Shape>& shape,\n                                       const std::shared_ptr<const Stride>& stride, DataType dtype,\n                                       MemoryFormat memory_format, Symbol<Device> device)\n    : MutTensorMeta(shape, stride, dtype, memory_format), device_(device) {}\n\nMutLocalTensorMeta::MutLocalTensorMeta(const Shape& shape, DataType dtype,\n                                       MemoryFormat memory_format, Symbol<Device> device)\n    : MutTensorMeta(shape, Stride(shape), dtype, memory_format), device_(device) {}\n\nMutLocalTensorMeta::MutLocalTensorMeta(const Shape& shape, const Stride& stride, DataType dtype,\n                                       MemoryFormat memory_format, Symbol<Device> device)\n    : MutTensorMeta(shape, stride, dtype, memory_format), device_(device) {}\n\nbool MutLocalTensorMeta::operator==(const MutLocalTensorMeta& other) const {\n  // It's correct to ignore is_dynamic_ field.\n  return *this->shape_ptr() == *other.shape_ptr() && this->dtype() == other.dtype()\n 
        && this->memory_format() == other.memory_format() && *this->device() == *other.device()\n         && this->stride() == other.stride();\n}\n\nsize_t MutLocalTensorMeta::CalcHashValue() const {\n  // It's correct to ignore is_dynamic_ field.\n  return Hash(*shape_ptr(), dtype(), memory_format(), *device(), stride());\n}\n\nbool GlobalTensorMeta::operator==(const GlobalTensorMeta& other) const {\n  // It's correct to ignore is_dynamic_ field.\n  return *this->shape_ptr() == *other.shape_ptr() && this->dtype() == other.dtype()\n         && this->memory_format() == other.memory_format() && this->nd_sbp() == other.nd_sbp()\n         && this->parallel_desc() == other.parallel_desc();\n}\n\nsize_t GlobalTensorMeta::CalcHashValue() const {\n  return Hash(*shape_ptr(), dtype(), memory_format(), nd_sbp(), parallel_desc());\n}\n\nbool IsContiguous(const Shape& shape, const Stride& stride) {\n  if (!shape.is_initialized()) { return true; }\n  return IsContiguous(ShapeView(shape), stride);\n}\n\nbool IsContiguous(const ShapeView& shape_view, const Stride& stride) {\n  if (shape_view.NumAxes() < 1 || shape_view.elem_cnt() <= 1) { return true; }\n  int64_t dim = shape_view.NumAxes();\n  int64_t expected_stride = 1;\n  bool contig_if_nonempty = true;\n  for (int64_t i = dim - 1; i >= 0; --i) {\n    // Contiguous by default when any dim is equal to zero\n    // https://stackoverflow.com/questions/31681324/identify-contiguous-segments-of-a-non-contiguous-numpy-array\n    if (shape_view.At(i) == 0) { return true; }\n    if (contig_if_nonempty && shape_view.At(i) != 1) {\n      if (stride.at(i) != expected_stride) { contig_if_nonempty = false; }\n      expected_stride *= shape_view.At(i);\n    }\n  }\n  return contig_if_nonempty;\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/tensor_meta.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_COMMON_TENSOR_META_H_\n#define ONEFLOW_COMMON_TENSOR_META_H_\n\n#include <memory>\n#include \"oneflow/core/common/tensor_desc.h\"\n#include \"oneflow/core/common/symbol.h\"\n\nnamespace oneflow {\n\nclass NdSbp;\nclass Shape;\nclass Stride;\nclass Device;\nclass ParallelDesc;\n\nnamespace one {\n\nbool IsContiguous(const Shape& shape, const Stride& stride);\nbool IsContiguous(const ShapeView& shape_view, const Stride& stride);\n\nclass TensorMeta : public user_op::TensorDesc {\n public:\n  TensorMeta(DataType dtype, MemoryFormat memory_format)\n      : data_type_(dtype), is_dynamic_(false), memory_format_(memory_format) {}\n  TensorMeta(const TensorMeta& other) = default;\n  TensorMeta(TensorMeta&&) = default;\n  virtual ~TensorMeta() = default;\n\n  virtual const std::shared_ptr<const Shape>& shape_ptr() const = 0;\n  virtual const std::shared_ptr<const Stride>& stride_ptr() const = 0;\n  virtual bool is_contiguous() const = 0;\n\n  DataType dtype() const { return data_type_; }\n  DataType data_type() const override { return data_type_; }\n  bool is_dynamic() const override { return is_dynamic_; }\n  MemoryFormat memory_format() const override { return memory_format_; }\n\n  virtual void set_shape(const Shape& shape) override { PRINT_BUG_PROMPT_AND_ABORT(); }\n  virtual void set_stride(const Stride& stride) override { PRINT_BUG_PROMPT_AND_ABORT(); }\n  virtual void set_data_type(DataType data_type) override { PRINT_BUG_PROMPT_AND_ABORT(); }\n  virtual void set_is_dynamic(bool is_dynamic) override { PRINT_BUG_PROMPT_AND_ABORT(); }\n  virtual void set_memory_format(MemoryFormat memory_format) override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n  }\n\n protected:\n  DataType data_type_;\n  bool is_dynamic_;\n  MemoryFormat memory_format_;\n};\n\nclass MutTensorMeta : public TensorMeta {\n public:\n  // uninitialized MutTensorMeta.\n  MutTensorMeta();\n  MutTensorMeta(const MutTensorMeta& other)\n      : TensorMeta(other),\n        shape_(std::make_shared<const Shape>(*other.shape_)),\n        stride_(std::make_shared<const Stride>(*other.stride_)) {}\n  MutTensorMeta(const std::shared_ptr<const Shape>& shape, DataType dtype,\n                MemoryFormat memory_format);\n  MutTensorMeta(const std::shared_ptr<const Shape>& shape,\n                const std::shared_ptr<const Stride>& stride, DataType dtype,\n                MemoryFormat memory_format);\n  MutTensorMeta(const Shape& shape, DataType dtype, MemoryFormat memory_format);\n  MutTensorMeta(const Shape& shape, const Stride& stride, DataType dtype,\n                MemoryFormat memory_format);\n  virtual ~MutTensorMeta() = default;\n\n  const std::shared_ptr<const Shape>& shape_ptr() const override { return shape_; }\n  const std::shared_ptr<const Stride>& stride_ptr() const override { return stride_; }\n  const Shape& shape() const override { return *shape_; }\n  const Stride& stride() const override { return 
*stride_; }\n  bool is_contiguous() const override { return IsContiguous(*shape_, *stride_); }\n\n  void set_shape(const Shape& shape) override { *const_cast<Shape*>(shape_.get()) = shape; }\n  void set_stride(const Stride& stride) override { *const_cast<Stride*>(stride_.get()) = stride; }\n  void set_data_type(DataType data_type) override { data_type_ = data_type; }\n  void set_is_dynamic(bool is_dynamic) override { is_dynamic_ = is_dynamic; }\n  void set_memory_format(MemoryFormat memory_format) override { memory_format_ = memory_format; }\n\n  bool operator==(const MutTensorMeta& other) const;\n  size_t CalcHashValue() const;\n\n  MutTensorMeta& operator=(const MutTensorMeta& other) {\n    this->data_type_ = other.data_type_;\n    this->is_dynamic_ = other.is_dynamic_;\n    this->memory_format_ = other.memory_format_;\n    this->shape_ = std::make_shared<const Shape>(*other.shape_);\n    this->stride_ = std::make_shared<const Stride>(*other.stride_);\n    return *this;\n  }\n\n protected:\n  std::shared_ptr<const Shape> shape_;\n  std::shared_ptr<const Stride> stride_;\n};\n\nclass ConstTensorMeta : public TensorMeta {\n public:\n  // uninitialized ConstTensorMeta.\n  ConstTensorMeta();\n  ConstTensorMeta(const ConstTensorMeta&) = default;\n  ConstTensorMeta(Symbol<Shape> shape, DataType dtype, MemoryFormat memory_format);\n  ConstTensorMeta(Symbol<Shape> shape, Symbol<Stride> stride, DataType dtype,\n                  MemoryFormat memory_format);\n  ConstTensorMeta(const Shape& shape, DataType dtype, MemoryFormat memory_format)\n      : ConstTensorMeta(SymbolOf(shape), dtype, memory_format) {}\n  ConstTensorMeta(const Shape& shape, const Stride& stride, DataType dtype,\n                  MemoryFormat memory_format)\n      : ConstTensorMeta(SymbolOf(shape), SymbolOf(stride), dtype, memory_format) {}\n\n  virtual ~ConstTensorMeta() = default;\n\n  const std::shared_ptr<const Shape>& shape_ptr() const override {\n    return shape_.shared_from_symbol();\n  }\n  const std::shared_ptr<const Stride>& stride_ptr() const override {\n    return stride_.shared_from_symbol();\n  }\n  const Shape& shape() const override { return *shape_; }\n  const Stride& stride() const override { return *stride_; }\n  bool is_contiguous() const override { return IsContiguous(*shape_, *stride_); }\n\n  bool operator==(const ConstTensorMeta& other) const;\n  size_t CalcHashValue() const;\n\n  ConstTensorMeta& operator=(const ConstTensorMeta& other) {\n    this->data_type_ = other.data_type_;\n    this->is_dynamic_ = other.is_dynamic_;\n    this->memory_format_ = other.memory_format_;\n    this->shape_ = other.shape_;\n    this->stride_ = other.stride_;\n    return *this;\n  }\n\n protected:\n  Symbol<Shape> shape_;\n  Symbol<Stride> stride_;\n};\n\nclass LocalTensorMeta : public ConstTensorMeta {\n public:\n  // uninitialized LocalTensorMeta.\n  LocalTensorMeta();\n  LocalTensorMeta(const LocalTensorMeta&) = default;\n  LocalTensorMeta(Symbol<Shape> shape, DataType dtype, MemoryFormat memory_format,\n                  Symbol<Device> device);\n  LocalTensorMeta(Symbol<Shape> shape, Symbol<Stride> stride, DataType dtype,\n                  MemoryFormat memory_format, Symbol<Device> device);\n  LocalTensorMeta(Symbol<Shape> shape, Symbol<Stride> stride, DataType dtype,\n                  MemoryFormat memory_format, Symbol<Device> device, bool is_view);\n  LocalTensorMeta(const Shape& shape, DataType dtype, MemoryFormat memory_format,\n                  Symbol<Device> device)\n      : LocalTensorMeta(SymbolOf(shape), 
dtype, memory_format, device) {}\n  LocalTensorMeta(const Shape& shape, const Stride& stride, DataType dtype,\n                  MemoryFormat memory_format, Symbol<Device> device)\n      : LocalTensorMeta(SymbolOf(shape), SymbolOf(stride), dtype, memory_format, device) {}\n  LocalTensorMeta(const Shape& shape, const Stride& stride, DataType dtype,\n                  MemoryFormat memory_format, Symbol<Device> device, const bool is_view)\n      : LocalTensorMeta(SymbolOf(shape), SymbolOf(stride), dtype, memory_format, device, is_view) {}\n  virtual ~LocalTensorMeta() = default;\n\n  const Symbol<Device>& device() const { return device_; }\n  bool is_view() const { return is_view_; }\n\n  bool operator==(const LocalTensorMeta& other) const;\n  size_t CalcHashValue() const;\n\n  LocalTensorMeta& operator=(const LocalTensorMeta& other) = default;\n\n private:\n  Symbol<Device> device_;\n  bool is_view_ = false;\n};\n\nclass MutLocalTensorMeta : public MutTensorMeta {\n public:\n  // uninitialized MutLocalTensorMeta.\n  MutLocalTensorMeta();\n  MutLocalTensorMeta(const MutLocalTensorMeta&) = default;\n  MutLocalTensorMeta(const std::shared_ptr<const Shape>& shape, DataType dtype,\n                     MemoryFormat memory_format, Symbol<Device> device);\n  MutLocalTensorMeta(const std::shared_ptr<const Shape>& shape,\n                     const std::shared_ptr<const Stride>& stride, DataType dtype,\n                     MemoryFormat memory_format, Symbol<Device> device);\n  MutLocalTensorMeta(const Shape& shape, DataType dtype, MemoryFormat memory_format,\n                     Symbol<Device> device);\n  MutLocalTensorMeta(const Shape& shape, const Stride& stride, DataType dtype,\n                     MemoryFormat memory_format, Symbol<Device> device);\n  virtual ~MutLocalTensorMeta() = default;\n\n  const Symbol<Device>& device() const { return device_; }\n\n  Symbol<Device>* mut_device() { return &device_; }\n\n  bool operator==(const MutLocalTensorMeta& other) const;\n  size_t CalcHashValue() const;\n\n  MutLocalTensorMeta& operator=(const MutLocalTensorMeta& other) = default;\n\n private:\n  Symbol<Device> device_;\n};\n\nclass GlobalTensorMeta : public ConstTensorMeta {\n public:\n  GlobalTensorMeta(Symbol<Shape> shape, DataType dtype, MemoryFormat memory_format,\n                   Symbol<NdSbp> nd_sbp, Symbol<ParallelDesc> parallel_desc)\n      : ConstTensorMeta(shape, dtype, memory_format),\n        nd_sbp_(nd_sbp),\n        parallel_desc_(parallel_desc) {}\n  GlobalTensorMeta(const Shape& shape, DataType dtype, MemoryFormat memory_format,\n                   Symbol<NdSbp> nd_sbp, Symbol<ParallelDesc> parallel_desc)\n      : GlobalTensorMeta(SymbolOf(shape), dtype, memory_format, nd_sbp, parallel_desc) {}\n  GlobalTensorMeta(const GlobalTensorMeta&) = default;\n  GlobalTensorMeta(GlobalTensorMeta&&) = default;\n  virtual ~GlobalTensorMeta() = default;\n\n  bool operator==(const GlobalTensorMeta& other) const;\n\n  Symbol<NdSbp> nd_sbp() const { return nd_sbp_; }\n  Symbol<ParallelDesc> parallel_desc() const { return parallel_desc_; }\n\n  size_t CalcHashValue() const;\n\n private:\n  Symbol<NdSbp> nd_sbp_;\n  Symbol<ParallelDesc> parallel_desc_;\n};\n\n}  // namespace one\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::one::LocalTensorMeta> final {\n  size_t operator()(const oneflow::one::LocalTensorMeta& local_tensor_meta) const {\n    return local_tensor_meta.CalcHashValue();\n  }\n};\n\ntemplate<>\nstruct hash<oneflow::one::GlobalTensorMeta> final {\n  
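// Delegates to GlobalTensorMeta::CalcHashValue() so GlobalTensorMeta can be\n  // used as a key in unordered containers.\n  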
size_t operator()(const oneflow::one::GlobalTensorMeta& global_tensor_meta) const {\n    return global_tensor_meta.CalcHashValue();\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_COMMON_TENSOR_META_H_\n"
  },
  {
    "path": "oneflow/core/common/test_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_TEST_UTIL_H_\n#define ONEFLOW_CORE_COMMON_TEST_UTIL_H_\n\n#ifndef final\n#define final\n#endif\n\n#ifndef private\n#define private public\n#endif\n\n#include <gmock/gmock.h>\n#include <gtest/gtest.h>\n\n#endif  // ONEFLOW_CORE_COMMON_TEST_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/common/thread_local_guard.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_THREAD_LOCAL_GUARD_H_\n#define ONEFLOW_CORE_COMMON_THREAD_LOCAL_GUARD_H_\n\n#include <memory>\n#include \"oneflow/core/common/optional.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass ThreadLocalGuard {\n public:\n  ThreadLocalGuard() {\n    old_value_ = *MutThreadLocalValue();\n    *MutThreadLocalValue() = Optional<T>();\n  }\n  explicit ThreadLocalGuard(const T& value) {\n    old_value_ = *MutThreadLocalValue();\n    *MutThreadLocalValue() = Optional<T>(value);\n  }\n  ~ThreadLocalGuard() { *MutThreadLocalValue() = old_value_; }\n\n  static const Optional<T>& Current() { return *MutThreadLocalValue(); }\n\n private:\n  static Optional<T>* MutThreadLocalValue() {\n    static thread_local Optional<T> value{};\n    return &value;\n  }\n\n  Optional<T> old_value_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_THREAD_LOCAL_GUARD_H_\n"
  },
  {
    "path": "oneflow/core/common/thread_local_guard_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/thread_local_guard.h\"\n\nnamespace oneflow {\nnamespace test {\n\ntemplate<typename T>\nvoid Assert(const T& value0, const T& value1) {\n  ASSERT_FALSE(ThreadLocalGuard<T>::Current().has_value());\n  {\n    ThreadLocalGuard<T> guard(value0);\n    ASSERT_TRUE(ThreadLocalGuard<T>::Current().has_value());\n  }\n  {\n    ThreadLocalGuard<T> guard(value0);\n    ASSERT_TRUE(ThreadLocalGuard<T>::Current().has_value());\n    T value = CHECK_JUST(ThreadLocalGuard<T>::Current());\n    ASSERT_EQ(value, value0);\n  }\n  {\n    ThreadLocalGuard<T> guard(value1);\n    ASSERT_TRUE(ThreadLocalGuard<T>::Current().has_value());\n    const auto& value = CHECK_JUST(ThreadLocalGuard<T>::Current());\n    ASSERT_EQ(value, value1);\n  }\n  {\n    ThreadLocalGuard<T> guard(value0);\n    ASSERT_TRUE(ThreadLocalGuard<T>::Current().has_value());\n    {\n      const auto& value = CHECK_JUST(ThreadLocalGuard<T>::Current());\n      ASSERT_EQ(value, value0);\n    }\n    {\n      ThreadLocalGuard<T> nested_guard(value1);\n      ASSERT_TRUE(ThreadLocalGuard<T>::Current().has_value());\n      const auto& value = CHECK_JUST(ThreadLocalGuard<T>::Current());\n      ASSERT_EQ(value, value1);\n    }\n    {\n      const auto& value = CHECK_JUST(ThreadLocalGuard<T>::Current());\n      ASSERT_EQ(value, value0);\n    }\n  }\n}\n\nTEST(ThreadLocalGuard, bool) { Assert<bool>(true, false); }\n\n}  // namespace test\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/throw.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_THROW_H_\n#define ONEFLOW_CORE_COMMON_THROW_H_\n\n#include <glog/logging.h>\n#include \"oneflow/core/common/error.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/common/to_string.h\"\n\nnamespace oneflow {\n\nnamespace details {\n\nstruct Throw final {\n  [[noreturn]] void operator=(Error&& error) { ThrowError(error.stacked_error()); }\n};\n\n}  // namespace details\n\n}  // namespace oneflow\n\n#define PRINT_BUG_PROMPT_AND_ABORT() LOG(FATAL) << kOfBugIssueUploadPrompt\n\n// use CHECK_XX_OR_THROW instead of glog CHECK to get more information of stack when check failed\n#undef CHECK\n#undef CHECK_LT\n#undef CHECK_LE\n#undef CHECK_EQ\n#undef CHECK_NE\n#undef CHECK_GT\n#undef CHECK_GE\n\n#define CHECK CHECK_OR_THROW\n#define CHECK_LT CHECK_LT_OR_THROW\n#define CHECK_LE CHECK_LE_OR_THROW\n#define CHECK_EQ CHECK_EQ_OR_THROW\n#define CHECK_NE CHECK_NE_OR_THROW\n#define CHECK_GT CHECK_GT_OR_THROW\n#define CHECK_GE CHECK_GE_OR_THROW\n\n#define THROW(err_type)                                                                    \\\n  ::oneflow::details::Throw() =                                                            \\\n      ::oneflow::Error::err_type().AddStackFrame([](const char* function) {                \\\n        thread_local static auto frame =                                                   \\\n            ::oneflow::SymbolOf(::oneflow::ErrorStackFrame(__FILE__, __LINE__, function)); \\\n        return frame;                                                                      \\\n      }(__FUNCTION__))\n\n// use __FILE__ __LINE__ etc. 
macros to get the last frame, so this macro can show\n// the file name and line where CHECK_OR_THROW is located even if there is no debug info\n#define CHECK_OR_THROW_INTERNAL(expr, error_msg)                                      \\\n  if (!(expr))                                                                        \\\n  ::oneflow::details::Throw() =                                                       \\\n      ::oneflow::Error::CheckFailedError()                                            \\\n          .AddStackFrame([](const char* function) {                                   \\\n            thread_local static auto frame = ::oneflow::SymbolOf(                     \\\n                ::oneflow::ErrorStackFrame(__FILE__, __LINE__, function, error_msg)); \\\n            return frame;                                                             \\\n          }(__FUNCTION__))                                                            \\\n          .GetStackTrace()\n\n#define CHECK_OR_THROW(expr)                                           \\\n  CHECK_OR_THROW_INTERNAL(expr, OF_PP_STRINGIZE(CHECK_OR_THROW(expr))) \\\n      << \"Check failed: (\" << OF_PP_STRINGIZE(expr) << \") \"\n\n#define CHECK_EQ_OR_THROW(lhs, rhs)                                                     \\\n  CHECK_OR_THROW_INTERNAL((lhs) == (rhs), OF_PP_STRINGIZE(CHECK_EQ_OR_THROW(lhs, rhs))) \\\n      << \"Check failed: \"                                                               \\\n      << \"(\" << ::oneflow::ToStringIfApplicable(lhs)                                    \\\n      << \" == \" << ::oneflow::ToStringIfApplicable(rhs) << \"): \"\n\n#define CHECK_GE_OR_THROW(lhs, rhs)                                                     \\\n  CHECK_OR_THROW_INTERNAL((lhs) >= (rhs), OF_PP_STRINGIZE(CHECK_GE_OR_THROW(lhs, rhs))) \\\n      << \"Check failed: \"                                                               \\\n      << \"(\" << ::oneflow::ToStringIfApplicable(lhs)                                    \\\n      << \" >= \" << ::oneflow::ToStringIfApplicable(rhs) << \"): \"\n\n#define CHECK_GT_OR_THROW(lhs, rhs)                                                    \\\n  CHECK_OR_THROW_INTERNAL((lhs) > (rhs), OF_PP_STRINGIZE(CHECK_GT_OR_THROW(lhs, rhs))) \\\n      << \"Check failed: \"                                                              \\\n      << \"(\" << ::oneflow::ToStringIfApplicable(lhs) << \" > \"                          \\\n      << ::oneflow::ToStringIfApplicable(rhs) << \"): \"\n\n#define CHECK_LE_OR_THROW(lhs, rhs)                                                     \\\n  CHECK_OR_THROW_INTERNAL((lhs) <= (rhs), OF_PP_STRINGIZE(CHECK_LE_OR_THROW(lhs, rhs))) \\\n      << \"Check failed: \"                                                               \\\n      << \"(\" << ::oneflow::ToStringIfApplicable(lhs)                                    \\\n      << \" <= \" << ::oneflow::ToStringIfApplicable(rhs) << \"): \"\n\n#define CHECK_LT_OR_THROW(lhs, rhs)                                                    \\\n  CHECK_OR_THROW_INTERNAL((lhs) < (rhs), OF_PP_STRINGIZE(CHECK_LT_OR_THROW(lhs, rhs))) \\\n      << \"Check failed: \"                                                              \\\n      << \"(\" << ::oneflow::ToStringIfApplicable(lhs) << \" < \"                          \\\n      << ::oneflow::ToStringIfApplicable(rhs) << \"): \"\n\n#define CHECK_NE_OR_THROW(lhs, rhs)                                                     \\\n  CHECK_OR_THROW_INTERNAL((lhs) != (rhs), 
OF_PP_STRINGIZE(CHECK_NE_OR_THROW(lhs, rhs))) \\\n      << \"Check failed: \"                                                               \\\n      << \"(\" << ::oneflow::ToStringIfApplicable(lhs)                                    \\\n      << \" != \" << ::oneflow::ToStringIfApplicable(rhs) << \"): \"\n\n#define CHECK_STREQ_OR_THROW(lhs, rhs) CHECK_EQ_OR_THROW(std::string(lhs), std::string(rhs))\n\n#define CHECK_STRNE_OR_THROW(lhs, rhs) CHECK_NE_OR_THROW(std::string(lhs), std::string(rhs))\n\n#define CHECK_NOTNULL_OR_THROW(ptr) CHECK_OR_THROW(ptr != nullptr)\n\n#define CHECK_ISNULL_OR_THROW(ptr) CHECK_OR_THROW(ptr == nullptr)\n\n#define TODO_THEN_THROW()                                                                  \\\n  ::oneflow::details::Throw() =                                                            \\\n      ::oneflow::Error::TodoError().AddStackFrame([](const char* function) {               \\\n        thread_local static auto frame =                                                   \\\n            ::oneflow::SymbolOf(::oneflow::ErrorStackFrame(__FILE__, __LINE__, function)); \\\n        return frame;                                                                      \\\n      }(__FUNCTION__))\n\n#define UNIMPLEMENTED_THEN_THROW()                                                         \\\n  ::oneflow::details::Throw() =                                                            \\\n      ::oneflow::Error::UnimplementedError().AddStackFrame([](const char* function) {      \\\n        thread_local static auto frame =                                                   \\\n            ::oneflow::SymbolOf(::oneflow::ErrorStackFrame(__FILE__, __LINE__, function)); \\\n        return frame;                                                                      \\\n      }(__FUNCTION__))\n\n#endif  // ONEFLOW_CORE_COMMON_THROW_H_\n"
  },
  {
    "path": "oneflow/core/common/to_string.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_TO_STRING_H_\n#define ONEFLOW_CORE_COMMON_TO_STRING_H_\n\n#include <string>\n#include \"oneflow/core/common/type_traits.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\ninline std::string ToString(const T& value) {\n  return std::to_string(value);\n}\n\ntemplate<typename T>\ninline std::string ToStringIfApplicable(const T& value) {\n  if constexpr (printable<T>()) {\n    std::stringstream ss;\n    ss << value;\n    return ss.str();\n  } else {\n    return \"<non-printable>\";\n  }\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_TO_STRING_H_\n"
  },
  {
    "path": "oneflow/core/common/tuple_hash.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_TUPLE_HASH_H_\n#define ONEFLOW_CORE_COMMON_TUPLE_HASH_H_\n\n#include <tuple>\n#include <utility>\n#include \"oneflow/core/common/util.h\"\n\nnamespace std {\n\ntemplate<typename... T>\nstruct hash<std::tuple<T...>> final {\n  size_t operator()(const std::tuple<T...>& val) const {\n    return do_hash(val, std::index_sequence_for<T...>{});\n  }\n\n private:\n  template<size_t... I>\n  size_t do_hash(const std::tuple<T...>& val, std::index_sequence<I...>) const {\n    return oneflow::Hash<T...>(std::get<I>(val)...);\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_COMMON_TUPLE_HASH_H_\n"
  },
  {
    "path": "oneflow/core/common/type_traits.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_TYPE_TRAITS_H_\n#define ONEFLOW_CORE_COMMON_TYPE_TRAITS_H_\n#include <type_traits>\n#if defined(WITH_CUDA)\n#include <cuda_fp16.h>\n#include <cuda.h>\n#endif\n#include \"oneflow/core/common/bfloat16.h\"\n#include <half.hpp>\n#include <complex>\n\nnamespace std {\n\n#if __GNUG__ && __GNUC__ < 5 && !__clang__\n// copied from\n// https://llvm.org/doxygen/type__traits_8h_source.html\nnamespace detail {\n/// Internal utility to detect trivial copy construction.\ntemplate<typename T>\nunion copy_construction_triviality_helper {\n  T t;\n  copy_construction_triviality_helper() = default;\n  copy_construction_triviality_helper(const copy_construction_triviality_helper&) = default;\n  ~copy_construction_triviality_helper() = default;\n};\n/// Internal utility to detect trivial move construction.\ntemplate<typename T>\nunion move_construction_triviality_helper {\n  T t;\n  move_construction_triviality_helper() = default;\n  move_construction_triviality_helper(move_construction_triviality_helper&&) = default;\n  ~move_construction_triviality_helper() = default;\n};\n\ntemplate<class T>\nunion trivial_helper {\n  T t;\n};\n\n}  // end namespace detail\n\n// is_trivially_copyable\n// An implementation of `std::is_trivially_copyable` since STL version\n// is not equally supported by all compilers, especially GCC 4.8.\n// Uniform implementation of this trait is important for ABI compatibility\n// as it has an impact on SmallVector's ABI (among others).\ntemplate<typename T>\nclass is_trivially_copyable {\n  // copy constructors\n  static constexpr bool has_trivial_copy_constructor =\n      std::is_copy_constructible<detail::trivial_helper<T>>::value;\n  static constexpr bool has_deleted_copy_constructor = !std::is_copy_constructible<T>::value;\n\n  // move constructors\n  static constexpr bool has_trivial_move_constructor =\n      std::is_move_constructible<detail::trivial_helper<T>>::value;\n  static constexpr bool has_deleted_move_constructor = !std::is_move_constructible<T>::value;\n\n  // copy assign\n  static constexpr bool has_trivial_copy_assign =\n      is_copy_assignable<detail::trivial_helper<T>>::value;\n  static constexpr bool has_deleted_copy_assign = !is_copy_assignable<T>::value;\n\n  // move assign\n  static constexpr bool has_trivial_move_assign =\n      is_move_assignable<detail::trivial_helper<T>>::value;\n  static constexpr bool has_deleted_move_assign = !is_move_assignable<T>::value;\n\n  // destructor\n  static constexpr bool has_trivial_destructor =\n      std::is_destructible<detail::trivial_helper<T>>::value;\n\n public:\n  static constexpr bool value = has_trivial_destructor\n                                && (has_deleted_move_assign || has_trivial_move_assign)\n                                && (has_deleted_move_constructor || has_trivial_move_constructor)\n                                && (has_deleted_copy_assign || 
has_trivial_copy_assign)\n                                && (has_deleted_copy_constructor || has_trivial_copy_constructor);\n\n#ifdef HAVE_STD_IS_TRIVIALLY_COPYABLE\n  static_assert(\n      value == std::is_trivially_copyable<T>::value,\n      \"inconsistent behavior between llvm:: and std:: implementation of is_trivially_copyable\");\n#endif\n};\ntemplate<typename T>\nclass is_trivially_copyable<T*> : public true_type {};\n#endif\n\n}  // namespace std\n\nnamespace oneflow {\n\n// Type Trait: IsScalarType\n\ntemplate<typename T, typename Enable = void>\nstruct IsScalarType final {\n  static const bool value = std::is_scalar<T>::value;\n};\n\ntemplate<typename T>\nstruct IsScalarType<\n    T, typename std::enable_if<\n           std::is_same<bfloat16, typename std::remove_cv<T>::type>::value\n           || std::is_same<half_float::half, typename std::remove_cv<T>::type>::value\n#ifdef WITH_CUDA\n           || std::is_same<half, typename std::remove_cv<T>::type>::value\n#endif  // WITH_CUDA\n           || std::is_same<std::complex<float>, typename std::remove_cv<T>::type>::value\n           || std::is_same<std::complex<double>, typename std::remove_cv<T>::type>::value>::type>\n    final {\n  static const bool value = true;\n};\n\nnamespace detail {\n\ntemplate<typename T>\nusing remove_cvref_t = typename std::remove_cv<typename std::remove_reference<T>::type>::type;\n\ntemplate<typename T, typename Enabled = void>\nstruct ScalarOrConstRef;\n\ntemplate<typename T>\nstruct ScalarOrConstRef<T, typename std::enable_if<std::is_scalar<T>::value>::type> {\n  using type = T;\n};\n\ntemplate<typename T>\nstruct ScalarOrConstRef<T, typename std::enable_if<!std::is_scalar<T>::value>::type> {\n  using type = const T&;\n};\n\ntemplate<typename T>\nconstexpr auto printable(int)\n    -> decltype(std::declval<std::stringstream>() << std::declval<T>(), bool()) {\n  return true;\n}\n\ntemplate<typename T>\nconstexpr bool printable(...) {\n  return false;\n}\n\n}  // namespace detail\n\ntemplate<typename T>\nusing scalar_or_const_ref_t = typename detail::ScalarOrConstRef<T>::type;\n\ntemplate<typename T>\nconstexpr bool printable() {\n  return detail::printable<T>(0);\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_TYPE_TRAITS_H_\n"
  },
  {
    "path": "oneflow/core/common/util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include <cfenv>\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/common/platform.h\"\n#include <csignal>\n#include <limits>\n\n#ifdef __linux__\n#include <sys/sysinfo.h>\n#include <unistd.h>\n#endif\n\nnamespace oneflow {\n\n#define DEFINE_ONEFLOW_STR2INT_CAST(dst_type, cast_func) \\\n  template<>                                             \\\n  dst_type oneflow_cast(const std::string& s) {          \\\n    char* end_ptr = nullptr;                             \\\n    dst_type ret = cast_func(s.c_str(), &end_ptr, 0);    \\\n    CHECK_EQ(*end_ptr, '\\0');                            \\\n    return ret;                                          \\\n  }\n\nDEFINE_ONEFLOW_STR2INT_CAST(long, strtol);\nDEFINE_ONEFLOW_STR2INT_CAST(unsigned long, strtoul);\nDEFINE_ONEFLOW_STR2INT_CAST(long long, strtoll);\nDEFINE_ONEFLOW_STR2INT_CAST(unsigned long long, strtoull);\n\nDEFINE_ONEFLOW_STR2INT_CAST(signed char, strtol);\nDEFINE_ONEFLOW_STR2INT_CAST(short, strtol);\nDEFINE_ONEFLOW_STR2INT_CAST(int, strtol);\n\nDEFINE_ONEFLOW_STR2INT_CAST(unsigned char, strtoul);\nDEFINE_ONEFLOW_STR2INT_CAST(unsigned short, strtoul);\nDEFINE_ONEFLOW_STR2INT_CAST(unsigned int, strtoul);\n\ntemplate<>\nfloat oneflow_cast(const std::string& s) {\n  char* end_ptr = nullptr;\n  float ret = strtof(s.c_str(), &end_ptr);\n  CHECK_EQ(*end_ptr, '\\0');\n  return ret;\n}\n\ntemplate<>\ndouble oneflow_cast(const std::string& s) {\n  char* end_ptr = nullptr;\n  double ret = strtod(s.c_str(), &end_ptr);\n  CHECK_EQ(*end_ptr, '\\0');\n  return ret;\n}\n\n#ifdef OF_PLATFORM_POSIX\n// COMMAND(feenableexcept(FE_ALL_EXCEPT & ~FE_INEXACT & ~FE_UNDERFLOW));\n#endif\n\n// If the interrupt during object malloc is changed to exit, the exit function indicates a normal\n// exit, triggering the object destructor function and then triggering object free. Since there is a\n// lock in malloc, if malloc and free obtain the same lock, it can cause a deadlock, which prevents\n// the process from exiting. After calling abort, the OS forces the program to exit,\n// relying on the OS to do resource cleanup, which can avoid the deadlock issue.\n// Process inability to exit can be more troublesome than potential resource leaks. 
If we find that\n// abort causes unreleased resources later, we can use exit in a local scope rather than globally.\n// Reference: https://github.com/Oneflow-Inc/OneTeam/issues/1954\nvoid AbortSignalHandler(int signal) { std::abort(); }\nCOMMAND(std::signal(SIGINT, AbortSignalHandler));\n\nsize_t GetAvailableCpuMemSize() {\n#if defined(__linux__)\n  std::ifstream mem_info(\"/proc/meminfo\");\n  CHECK(mem_info.good()) << \"can't open file: /proc/meminfo\";\n  std::string line;\n  while (std::getline(mem_info, line).good()) {\n    std::string token;\n    const char* p = line.c_str();\n    p = StrToToken(p, \" \", &token);\n    if (token != \"MemAvailable:\") { continue; }\n    CHECK_NE(*p, '\\0');\n    p = StrToToken(p, \" \", &token);\n    size_t mem_available = oneflow_cast<size_t>(token);\n    CHECK_NE(*p, '\\0');\n    p = StrToToken(p, \" \", &token);\n    CHECK_EQ(token, \"kB\");\n    return mem_available * 1024;\n  }\n  return sysconf(_SC_PAGESIZE) * sysconf(_SC_AVPHYS_PAGES);\n#elif defined(__APPLE__)\n  // macOS will eagerly make use of all memory so there is no point querying it\n  return std::numeric_limits<size_t>::max();\n#else\n  UNIMPLEMENTED();\n  return 0;\n#endif\n}\n\nbool IsKernelSafeInt32(int64_t n) { return n <= GetMaxVal<int32_t>() / 2; }\n\nnamespace {\n\nbool CaseInsensitiveStringEquals(const std::string& lhs, const std::string& rhs) {\n  return lhs.size() == rhs.size()\n         && std::equal(lhs.begin(), lhs.end(), rhs.begin(),\n                       [](char a, char b) { return std::tolower(a) == std::tolower(b); });\n}\n\nbool StringToBool(const std::string& str) {\n  return CaseInsensitiveStringEquals(str, \"1\") || CaseInsensitiveStringEquals(str, \"true\")\n         || CaseInsensitiveStringEquals(str, \"yes\") || CaseInsensitiveStringEquals(str, \"on\")\n         || CaseInsensitiveStringEquals(str, \"y\");\n}\n\nbool StringToInteger(const std::string& str, int64_t* value) {\n  char* end;\n  int64_t v = std::strtoll(str.data(), &end, 10);\n  if (end == str.data()) {\n    return false;\n  } else {\n    *value = v;\n    return true;\n  }\n}\n\nbool StringToFloat(const std::string& str, double* value) {\n  char* end = nullptr;\n  double v = std::strtof(str.data(), &end);\n  if (end == str.data()) {\n    return false;\n  } else {\n    *value = v;\n    return true;\n  }\n}\n\n}  // namespace\n\nbool ParseBooleanFromEnv(const std::string& env_var, bool default_value) {\n  const char* env_p = std::getenv(env_var.c_str());\n  if (env_p == nullptr) {\n    return default_value;\n  } else {\n    return StringToBool(env_p);\n  }\n}\n\nint64_t ParseIntegerFromEnv(const std::string& env_var, int64_t default_value) {\n  const char* env_p = std::getenv(env_var.c_str());\n  if (env_p == nullptr) { return default_value; }\n  int64_t value;\n  if (StringToInteger(env_p, &value)) {\n    return value;\n  } else {\n    return default_value;\n  }\n}\n\ndouble ParseFloatFromEnv(const std::string& env_var, double default_value) {\n  const char* env_p = std::getenv(env_var.c_str());\n  if (env_p == nullptr) { return default_value; }\n  double value = default_value;\n  StringToFloat(env_p, &value);\n  return value;\n}\n\nstd::string GetStringFromEnv(const std::string& env_var, const std::string& default_value) {\n  const char* env_p = std::getenv(env_var.c_str());\n  if (env_p == nullptr) {\n    return default_value;\n  } else {\n    return env_p;\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_UTIL_H_\n#define ONEFLOW_CORE_COMMON_UTIL_H_\n\n#include \"oneflow/core/common/preprocessor.h\"\n\n#include \"oneflow/core/common/throw.h\"\n#include <algorithm>\n#include <atomic>\n#include <condition_variable>\n#include <forward_list>\n#include <fstream>\n#include <functional>\n#include <iostream>\n#include <list>\n#include <memory>\n#include <mutex>\n#include <queue>\n#include <random>\n#include <thread>\n#include <utility>\n#include <cfenv>\n#include <complex>\n\n#include \"oneflow/core/common/hash_container.h\"\n#include \"oneflow/core/common/meta_util.hpp\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/common/hash.h\"\n#include \"oneflow/core/common/cpp_attribute.h\"\n#include \"fmt/format.h\"\n#include \"fmt/ranges.h\"\n\n#define CHECK_ISNULL(e) CHECK((e) == nullptr)\n\nnamespace fmt {\ntemplate<typename T>\nstruct formatter<std::complex<T>> : formatter<std::string_view> {\n  template<typename FormatContext>\n  auto format(const std::complex<T>& c, FormatContext& ctx) {\n    return formatter<std::string_view>::format(fmt::format(\"({}+{}j)\", c.real(), c.imag()), ctx);\n  }\n};\n}  // namespace fmt\n\ntemplate<class T>\nstd::ostream& operator<<(std::ostream& os, const std::vector<T>& v) {\n  os << fmt::format(\"{}\", v);\n  return os;\n}\n\nnamespace oneflow {\n\n#define OF_DISALLOW_COPY(ClassName)     \\\n  ClassName(const ClassName&) = delete; \\\n  ClassName& operator=(const ClassName&) = delete\n\n#define OF_DISALLOW_MOVE(ClassName) \\\n  ClassName(ClassName&&) = delete;  \\\n  ClassName& operator=(ClassName&&) = delete\n\n#define OF_DISALLOW_COPY_AND_MOVE(ClassName) \\\n  OF_DISALLOW_COPY(ClassName);               \\\n  OF_DISALLOW_MOVE(ClassName)\n\n#define UNIMPLEMENTED() LOG(FATAL) << \"UNIMPLEMENTED\"\n\n#define TODO() LOG(FATAL) << \"TODO\"\n\n#define OF_COMMA ,\n\n#define DEFINE_STATIC_VAR(type, name) \\\n  static type* name() {               \\\n    static type var;                  \\\n    return &var;                      \\\n  }\n\n#define COMMAND(...)                                                
\\\n  namespace {                                                       \\\n  struct OF_PP_CAT(CommandT, __LINE__) {                            \\\n    OF_PP_CAT(CommandT, __LINE__)() { __VA_ARGS__; }                \\\n  };                                                                \\\n  OF_PP_CAT(CommandT, __LINE__) OF_PP_CAT(g_command_var, __LINE__); \\\n  }\n\ntemplate<typename T>\nbool operator==(const std::weak_ptr<T>& lhs, const std::weak_ptr<T>& rhs) {\n  return lhs.lock().get() == rhs.lock().get();\n}\n\ntemplate<typename T>\nvoid SortAndRemoveDuplication(std::vector<T>* vec) {\n  std::sort(vec->begin(), vec->end());\n  auto unique_it = std::unique(vec->begin(), vec->end());\n  vec->erase(unique_it, vec->end());\n}\n\ninline std::string NewUniqueId() {\n  static std::atomic<int64_t> counter(0);\n  return std::to_string(counter.fetch_add(1, std::memory_order_relaxed));\n}\n\ntemplate<typename K, typename V>\nvoid EraseIf(HashMap<K, V>* hash_map, std::function<bool(typename HashMap<K, V>::iterator)> cond) {\n  for (auto it = hash_map->begin(); it != hash_map->end();) {\n    if (cond(it)) {\n      hash_map->erase(it++);\n    } else {\n      ++it;\n    }\n  }\n}\n\ntemplate<typename T>\ntypename std::enable_if<std::is_enum<T>::value, std::ostream&>::type operator<<(\n    std::ostream& out_stream, const T& x) {\n  out_stream << static_cast<int>(x);\n  return out_stream;\n}\n\ntemplate<typename OutType, typename InType>\nOutType oneflow_cast(const InType&);\n\ninline uint32_t NewRandomSeed() {\n  static std::mt19937 gen{std::random_device{}()};\n  return gen();\n}\n\n#define DIM_SEQ           \\\n  OF_PP_MAKE_TUPLE_SEQ(1) \\\n  OF_PP_MAKE_TUPLE_SEQ(2) \\\n  OF_PP_MAKE_TUPLE_SEQ(3) \\\n  OF_PP_MAKE_TUPLE_SEQ(4) OF_PP_MAKE_TUPLE_SEQ(5) OF_PP_MAKE_TUPLE_SEQ(6) OF_PP_MAKE_TUPLE_SEQ(7)\n\n#define BOOL_SEQ (true)(false)\n\n#define FOR_RANGE(type, i, begin, end) for (type i = (begin), __end = (end); i < __end; ++i)\n#define FOR_EACH(it, container) for (auto it = container.begin(); it != container.end(); ++it)\n\ninline double GetCurTime() {\n  return std::chrono::high_resolution_clock::now().time_since_epoch().count();\n}\n\nconst size_t kHostAlignSize = 64;\nconst size_t kCudaAlignSize = 512;\nconst size_t kCudaMemAllocAlignSize = 512;\nconst int32_t kBlobBodyAlignSize = 512;\nconst int32_t kBlobHeaderAlignSize = 64;\n\ninline size_t RoundUp(size_t n, size_t val) { return (n + val - 1) / val * val; }\n\ninline size_t GetCudaAlignedSize(size_t size) { return RoundUp(size, kCudaAlignSize); }\n\nsize_t GetAvailableCpuMemSize();\n\ntemplate<typename T>\nvoid Erase(T& container, const std::function<bool(const typename T::value_type&)>& NeedErase,\n           const std::function<void(const typename T::value_type&)>& EraseElementHandler) {\n  auto iter = container.begin();\n  auto erase_from = container.end();\n  while (iter != erase_from) {\n    if (NeedErase(*iter)) {\n      --erase_from;\n      if (iter == erase_from) { break; }\n      std::swap(*iter, *erase_from);\n    } else {\n      ++iter;\n    }\n  }\n  for (; iter != container.end(); ++iter) { EraseElementHandler(*iter); }\n  if (erase_from != container.end()) { container.erase(erase_from, container.end()); }\n}\n\ntemplate<typename T>\nvoid Erase(T& container, const std::function<bool(const typename T::value_type&)>& NeedErase) {\n  Erase<T>(container, NeedErase, [](const typename T::value_type&) {});\n}\n\n#if defined(__GNUC__)\n#define ALWAYS_INLINE __attribute__((always_inline))\n#elif defined(__CUDACC__)\n#define ALWAYS_INLINE 
__forceinline__\n#else\n#define ALWAYS_INLINE inline\n#endif\n\nbool IsKernelSafeInt32(int64_t n);\n\nclass RoundModeGuard final {\n public:\n  RoundModeGuard(int mode) {\n    saved_mode_ = std::fegetround();\n    CHECK_EQ(std::fesetround(mode), 0);\n  }\n  ~RoundModeGuard() { std::fesetround(saved_mode_); }\n\n private:\n  int saved_mode_;\n};\n\nbool ParseBooleanFromEnv(const std::string& env_var, bool default_value);\n\nint64_t ParseIntegerFromEnv(const std::string& env_var, int64_t default_value);\n\ndouble ParseFloatFromEnv(const std::string& env_var, double default_value);\n\nstd::string GetStringFromEnv(const std::string& env_var, const std::string& default_value);\n\n#define OF_PREDICT_TRUE likely\n#define OF_PREDICT_FALSE unlikely\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/common/wrap_dim_utils.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <bitset>\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\n// align with pytorch: `c10/core/WrapDimMinimal.h`\nstatic inline Maybe<int64_t> maybe_wrap_dim(int64_t dim, int64_t dim_post_expr,\n                                            bool wrap_scalar = true) {\n  if (dim_post_expr <= 0) {\n    if (!wrap_scalar) {\n      return Error::RuntimeError()\n             << \"dimension specified as \" << dim << \" but tensor has no dimensions\";\n    }\n    dim_post_expr = 1;  // this will make range [-1, 0]\n  }\n\n  int64_t min = -dim_post_expr;\n  int64_t max = dim_post_expr - 1;\n  if (dim < min || dim > max) {\n    return Error::IndexError() << \"Dimension out of range (expected to be in range of [\" << min\n                               << \", \" << max << \"], but got \" << dim << \")\";\n  }\n  if (dim < 0) dim += dim_post_expr;\n  return dim;\n}\n\n// align with pytorch: `aten/src/ATen/WrapDimUtilsMulti.h`\nconstexpr size_t dim_bitset_size = 64;\n\nstatic inline Maybe<std::bitset<dim_bitset_size>> dim_list_to_bitset(\n    const std::vector<int32_t>& dims, int64_t ndims) {\n  CHECK_LE_OR_RETURN(ndims, (int64_t)dim_bitset_size)\n      << Error::RuntimeError() << \"Only tensors with up to \" << dim_bitset_size\n      << \" dims are supported\";\n  std::bitset<dim_bitset_size> seen;\n  for (int32_t i = 0; i < dims.size(); i++) {\n    size_t dim = JUST(maybe_wrap_dim(dims[i], ndims));\n    CHECK_OR_RETURN_ERROR(!seen[dim]) << Error::RuntimeError() << \"The dim \" << dim\n                                      << \" appears multiple times in the list of dims\";\n    seen[dim] = true;\n  }\n  return seen;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/common/zero_only_zip.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_ZERO_ONLY_ZIP_H_\n#define ONEFLOW_CORE_COMMON_ZERO_ONLY_ZIP_H_\n\n#include <memory>\n#include \"oneflow/core/common/sized_buffer_view.h\"\n\nnamespace oneflow {\n\nstruct ZeroOnlyZipUtil final {\n  void ZipToSizedBuffer(const char* data, size_t size, SizedBufferView* sized_buffer);\n  void UnzipToExpectedSize(const SizedBufferView& size_buffer, char* data, size_t expected_size);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_COMMON_ZERO_ONLY_ZIP_H_\n"
  },
  {
    "path": "oneflow/core/control/bootstrap_client.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CONTROL_BOOTSTRAP_CLIENT_H_\n#define ONEFLOW_CORE_CONTROL_BOOTSTRAP_CLIENT_H_\n\n#include \"oneflow/core/control/rpc_client.h\"\n#include \"oneflow/core/job/env_desc.h\"\n\nnamespace oneflow {\n\nclass BootstrapClient : public RpcClient {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BootstrapClient);\n  virtual ~BootstrapClient() override = default;\n\n protected:\n  friend class Singleton<BootstrapClient>;\n  BootstrapClient() = default;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CONTROL_BOOTSTRAP_CLIENT_H_\n"
  },
  {
    "path": "oneflow/core/control/bootstrap_server.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CONTROL_BOOTSTRAP_SERVER_H_\n#define ONEFLOW_CORE_CONTROL_BOOTSTRAP_SERVER_H_\n\n#include \"oneflow/core/control/rpc_server.h\"\n#include \"oneflow/core/job/env_desc.h\"\n\nnamespace oneflow {\n\nclass BootstrapServer : public RpcServer {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BootstrapServer);\n  BootstrapServer() = default;\n  virtual ~BootstrapServer() override = default;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CONTROL_BOOTSTRAP_SERVER_H_\n"
  },
  {
    "path": "oneflow/core/control/control.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage LoadServerRequest {\n  required string addr = 1;\n  optional int64 rank = 2 [default = -1];\n}\n\nmessage LoadServerResponse {\n}\n\nmessage BarrierRequest {\n  required string name = 1;\n  required int32 num = 2;\n}\n\nmessage BarrierResponse {\n}\n\nenum TryLockResult {\n  kLocked = 0;\n  kDone = 1;\n  kDoing = 2;\n}\n\nmessage TryLockRequest {\n  required string name = 1;\n}\n\nmessage TryLockResponse {\n  required TryLockResult result = 1;\n}\n\nmessage NotifyDoneRequest {\n  required string name = 1;\n}\n\nmessage NotifyDoneResponse {\n}\n\nmessage WaitUntilDoneRequest {\n  required string name = 1;\n}\n\nmessage WaitUntilDoneResponse {\n}\n\nmessage PushKVRequest {\n  required string key = 1;\n  required bytes val = 2;\n}\n\nmessage PushKVResponse {\n}\n\nmessage ClearKVRequest {\n  required string key = 1;\n}\n\nmessage ClearKVResponse {\n}\n\nmessage PullKVRequest {\n  required string key = 1;\n}\n\nmessage PullKVResponse {\n  required bytes val = 1;\n}\n\nmessage ClearRequest {\n}\n\nmessage ClearResponse {\n}\n\nmessage IncreaseCountRequest {\n  required string key = 1;\n  required int32 val = 2;\n}\n\nmessage IncreaseCountResponse {\n  required int32 val = 1;\n}\n\nmessage EraseCountRequest {\n  required string key = 1;\n}\n\nmessage EraseCountResponse {\n}\n"
  },
  {
    "path": "oneflow/core/control/ctrl_bootstrap.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <map>\n#include <set>\n#include \"oneflow/core/control/ctrl_bootstrap.h\"\n#include \"oneflow/core/control/worker_process_info.pb.h\"\n#include \"oneflow/core/control/host_list_bootstrap_server.h\"\n#include \"oneflow/core/control/host_list_bootstrap_client.h\"\n#include \"oneflow/core/control/rank_info_bootstrap_server.h\"\n#include \"oneflow/core/control/rank_info_bootstrap_client.h\"\n\nnamespace oneflow {\n\nMaybe<void> CtrlBootstrap::InitProcessCtx(int64_t port, ProcessCtx* ret_process_ctx) {\n  std::vector<WorkerProcessInfo> worker_process_info_list;\n  worker_process_info_list.reserve(world_size());\n  if (rank() == 0) {\n    WorkerProcessInfo worker_process_info;\n    {\n      worker_process_info.set_rank(rank());\n      worker_process_info.set_port(port);\n      JUST(SetCurrentHostByMaster(&worker_process_info));\n    }\n    worker_process_info_list.emplace_back(worker_process_info);\n    for (int64_t world_rank = 1; world_rank < world_size(); ++world_rank) {\n      std::string key = std::string(\"GetWorkerProcessInfo\") + std::to_string(world_rank);\n      WorkerProcessInfo cur_work_process_info;\n      mut_bootstrap_client()->PullMasterKV(key, &cur_work_process_info);\n      CHECK_EQ_OR_RETURN(world_rank, worker_process_info_list.size());\n      CHECK_EQ_OR_RETURN(world_rank, cur_work_process_info.rank());\n      worker_process_info_list.emplace_back(cur_work_process_info);\n    }\n  } else {\n    std::string key = std::string(\"GetWorkerProcessInfo\") + std::to_string(rank());\n    WorkerProcessInfo cur_work_process_info;\n    {\n      cur_work_process_info.set_rank(rank());\n      cur_work_process_info.set_port(port);\n      JUST(SetCurrentHostByWorker(&cur_work_process_info));\n    }\n    mut_bootstrap_client()->PushMasterKV(key, cur_work_process_info);\n  }\n\n  mut_bootstrap_client()->Barrier(__FILE__ \":\" OF_PP_STRINGIZE(__LINE__));\n\n  if (rank() == 0) {\n    ret_process_ctx->set_rank(rank());\n    ret_process_ctx->mutable_ctrl_addr()->Clear();\n    for (const auto& worker_process_info : worker_process_info_list) {\n      Address* addr = ret_process_ctx->mutable_ctrl_addr()->Add();\n      if (worker_process_info.has_host()) { addr->set_host(worker_process_info.host()); }\n      addr->set_port(worker_process_info.port());\n      JUST(SetHostByMaster(addr, worker_process_info.rank()));\n    }\n    JUST(SetNodeSize(ret_process_ctx));\n    mut_bootstrap_client()->PushMasterKV(\"BroadcastProcessCtx\", *ret_process_ctx);\n  } else {\n    mut_bootstrap_client()->PullMasterKV(\"BroadcastProcessCtx\", ret_process_ctx);\n    ret_process_ctx->set_rank(rank());\n  }\n\n  mut_bootstrap_client()->Barrier(__FILE__ \":\" OF_PP_STRINGIZE(__LINE__));\n\n  VLOG(2) << \"\\n\" << ret_process_ctx->DebugString();\n  return Maybe<void>::Ok();\n}\n\nHostListCtrlBootstrap::HostListCtrlBootstrap(const EnvDesc& env_desc) : CtrlBootstrap() {\n  bootstrap_server_.reset(new 
HostListBootstrapServer(env_desc));\n  bootstrap_client_.reset(new HostListBootstrapClient(env_desc));\n  bootstrap_client_->Barrier(__FILE__ \":\" OF_PP_STRINGIZE(__LINE__));\n  host_ = bootstrap_server_->this_machine_addr();\n  rank_ = env_desc.GetMachineId(host_);\n  world_size_ = env_desc.TotalMachineNum();\n}\n\nHostListCtrlBootstrap::~HostListCtrlBootstrap() {\n  bootstrap_client_.reset();\n  bootstrap_server_.reset();\n}\n\nMaybe<void> HostListCtrlBootstrap::SetHostByMaster(Address* addr, int64_t world_rank) const {\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> HostListCtrlBootstrap::SetCurrentHostByMaster(\n    WorkerProcessInfo* worker_process_info) const {\n  worker_process_info->set_host(host());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> HostListCtrlBootstrap::SetCurrentHostByWorker(\n    WorkerProcessInfo* worker_process_info) const {\n  worker_process_info->set_host(host());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> HostListCtrlBootstrap::SetNodeSize(ProcessCtx* process_ctx) const {\n  process_ctx->set_node_size(world_size());\n  return Maybe<void>::Ok();\n}\n\nBootstrapServer* HostListCtrlBootstrap::mut_bootstrap_server() { return bootstrap_server_.get(); }\nBootstrapClient* HostListCtrlBootstrap::mut_bootstrap_client() { return bootstrap_client_.get(); }\n\nRankInfoCtrlBootstrap::RankInfoCtrlBootstrap(const BootstrapConf& bootstrap_conf)\n    : CtrlBootstrap(), bootstrap_conf_(bootstrap_conf) {\n  bootstrap_server_.reset(new RankInfoBootstrapServer(bootstrap_conf));\n  bootstrap_client_.reset(new RankInfoBootstrapClient(bootstrap_conf));\n  bootstrap_client_->Barrier(__FILE__ \":\" OF_PP_STRINGIZE(__LINE__));\n  master_host_ = bootstrap_conf.master_addr().host();\n  rank_ = bootstrap_conf.rank();\n  world_size_ = bootstrap_conf.world_size();\n}\n\nRankInfoCtrlBootstrap::~RankInfoCtrlBootstrap() {\n  bootstrap_client_.reset();\n  bootstrap_server_.reset();\n}\n\nMaybe<void> RankInfoCtrlBootstrap::SetHostByMaster(Address* addr, int64_t world_rank) const {\n  if (addr->has_host()) { return Maybe<void>::Ok(); }\n  const auto& rank2host = JUST(bootstrap_server_->rank2host());\n  CHECK_EQ_OR_RETURN(rank2host.size(), world_size());\n  CHECK_GE_OR_RETURN(world_rank, 0);\n  CHECK_LT_OR_RETURN(world_rank, rank2host.size());\n  addr->set_host(rank2host.at(world_rank));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RankInfoCtrlBootstrap::SetCurrentHostByMaster(\n    WorkerProcessInfo* worker_process_info) const {\n  CHECK_EQ_OR_RETURN(rank(), 0);\n  if (bootstrap_conf_.has_host()) {\n    worker_process_info->set_host(bootstrap_conf_.host());\n  } else {\n    worker_process_info->set_host(master_host_);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RankInfoCtrlBootstrap::SetCurrentHostByWorker(\n    WorkerProcessInfo* worker_process_info) const {\n  CHECK_NE_OR_RETURN(rank(), 0);\n  if (bootstrap_conf_.has_host()) { worker_process_info->set_host(bootstrap_conf_.host()); }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RankInfoCtrlBootstrap::SetNodeSize(ProcessCtx* process_ctx) const {\n  if (bootstrap_conf_.has_node_size()) {\n    CHECK_EQ_OR_RETURN(world_size() % bootstrap_conf_.node_size(), 0);\n    process_ctx->set_node_size(bootstrap_conf_.node_size());\n    return Maybe<void>::Ok();\n  }\n  const auto& rank2host = JUST(bootstrap_server_->rank2host());\n  std::set<std::string> no_duplicated_host;\n  for (const auto& host : rank2host) { no_duplicated_host.insert(host); }\n  CHECK_EQ_OR_RETURN(world_size() % no_duplicated_host.size(), 0);\n  
process_ctx->set_node_size(no_duplicated_host.size());\n  return Maybe<void>::Ok();\n}\n\nBootstrapServer* RankInfoCtrlBootstrap::mut_bootstrap_server() { return bootstrap_server_.get(); }\nBootstrapClient* RankInfoCtrlBootstrap::mut_bootstrap_client() { return bootstrap_client_.get(); }\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/control/ctrl_bootstrap.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CONTROL_CTRL_BOOTSTRAP_H_\n#define ONEFLOW_CORE_CONTROL_CTRL_BOOTSTRAP_H_\n\n#include \"oneflow/core/control/ctrl_bootstrap.pb.h\"\n#include \"oneflow/core/job/env_desc.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nclass ProcessCtx;\nclass WorkerProcessInfo;\nclass BootstrapServer;\nclass BootstrapClient;\n\nclass CtrlBootstrap {\n public:\n  virtual ~CtrlBootstrap() {}\n\n  Maybe<void> InitProcessCtx(int64_t port, ProcessCtx* ret_process_ctx);\n\n protected:\n  virtual int64_t rank() const = 0;\n  virtual int64_t world_size() const = 0;\n  virtual Maybe<void> SetHostByMaster(Address*, int64_t world_rank) const = 0;\n  virtual Maybe<void> SetCurrentHostByMaster(WorkerProcessInfo*) const = 0;\n  virtual Maybe<void> SetCurrentHostByWorker(WorkerProcessInfo*) const = 0;\n  virtual Maybe<void> SetNodeSize(ProcessCtx* process_ctx) const = 0;\n\n  virtual BootstrapServer* mut_bootstrap_server() = 0;\n  virtual BootstrapClient* mut_bootstrap_client() = 0;\n\n  CtrlBootstrap() = default;\n};\n\nclass HostListBootstrapServer;\nclass HostListBootstrapClient;\n\nclass HostListCtrlBootstrap final : public CtrlBootstrap {\n public:\n  explicit HostListCtrlBootstrap(const EnvDesc& env_desc);\n  ~HostListCtrlBootstrap() override;\n\n private:\n  int64_t rank() const override { return rank_; }\n  int64_t world_size() const override { return world_size_; }\n\n  std::string host() const { return host_; }\n\n  Maybe<void> SetHostByMaster(Address*, int64_t world_rank) const override;\n  Maybe<void> SetCurrentHostByMaster(WorkerProcessInfo*) const override;\n  Maybe<void> SetCurrentHostByWorker(WorkerProcessInfo*) const override;\n  Maybe<void> SetNodeSize(ProcessCtx* process_ctx) const override;\n\n  BootstrapServer* mut_bootstrap_server() override;\n  BootstrapClient* mut_bootstrap_client() override;\n\n  // Uses shared_ptr and forward declaration to avoid `#include ...`\n  std::shared_ptr<HostListBootstrapServer> bootstrap_server_;\n  std::shared_ptr<HostListBootstrapClient> bootstrap_client_;\n\n  std::string host_;\n  int64_t rank_;\n  int64_t world_size_;\n};\n\nclass RankInfoBootstrapServer;\nclass RankInfoBootstrapClient;\n\nclass RankInfoCtrlBootstrap final : public CtrlBootstrap {\n public:\n  explicit RankInfoCtrlBootstrap(const BootstrapConf& bootstrap_conf);\n  ~RankInfoCtrlBootstrap() override;\n\n private:\n  int64_t rank() const override { return rank_; }\n  int64_t world_size() const override { return world_size_; }\n\n  Maybe<void> SetHostByMaster(Address*, int64_t world_rank) const override;\n  Maybe<void> SetCurrentHostByMaster(WorkerProcessInfo*) const override;\n  Maybe<void> SetCurrentHostByWorker(WorkerProcessInfo*) const override;\n  Maybe<void> SetNodeSize(ProcessCtx* process_ctx) const override;\n\n  BootstrapServer* mut_bootstrap_server() override;\n  BootstrapClient* mut_bootstrap_client() override;\n\n  // Uses shared_ptr 
and forward declaration to avoid `#include ...`\n  std::shared_ptr<RankInfoBootstrapServer> bootstrap_server_;\n  std::shared_ptr<RankInfoBootstrapClient> bootstrap_client_;\n\n  std::string master_host_;\n  BootstrapConf bootstrap_conf_;\n  int64_t rank_;\n  int64_t world_size_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CONTROL_CTRL_BOOTSTRAP_H_\n"
  },
  {
    "path": "oneflow/core/control/ctrl_bootstrap.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage Address {\n  required string host = 1;\n  required int32 port = 2;\n}\n\nmessage ProcessCtx {\n  repeated Address ctrl_addr = 1;\n  required int64 rank = 2;\n  required int64 node_size = 3;\n}\n\nmessage BootstrapConf {\n  required Address master_addr = 1;\n  required int64 rank = 2;\n  required int64 world_size = 3;\n  optional string host = 4;\n  optional int32 ctrl_port = 5 [default = -1];\n  optional int64 node_size = 6 [default = -1];\n}\n"
  },
  {
    "path": "oneflow/core/control/ctrl_call.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CONTROL_CTRL_CALL_H_\n#define ONEFLOW_CORE_CONTROL_CTRL_CALL_H_\n\n#include \"oneflow/core/control/ctrl_service.h\"\n\nnamespace oneflow {\n\nclass CtrlCallIf {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CtrlCallIf);\n  virtual ~CtrlCallIf() = default;\n\n  virtual void Process() = 0;\n  virtual void SendResponse() = 0;\n\n protected:\n  CtrlCallIf() = default;\n\n private:\n};\n\ntemplate<CtrlMethod ctrl_method>\nclass CtrlCall final : public CtrlCallIf {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CtrlCall);\n  CtrlCall() : status_(Status::kBeforeHandleRequest), responder_(&server_ctx_) {}\n  ~CtrlCall() = default;\n\n  static constexpr const size_t value = (size_t)ctrl_method;\n\n  const CtrlRequest<ctrl_method>& request() const { return request_; }\n  CtrlRequest<ctrl_method>* mut_request() { return &request_; }\n  CtrlResponse<ctrl_method>* mut_response() { return &response_; }\n  grpc::ServerContext* mut_server_ctx() { return &server_ctx_; }\n  const grpc::ServerContext& server_ctx() const { return server_ctx_; }\n  grpc::ServerAsyncResponseWriter<CtrlResponse<ctrl_method>>* mut_responder() {\n    return &responder_;\n  }\n  void set_request_handler(std::function<void()> val) { request_handler_ = val; }\n\n  void Process() override {\n    switch (status_) {\n      case Status::kBeforeHandleRequest: {\n        request_handler_();\n        return;\n      }\n      case Status::kBeforeDelete: {\n        delete this;\n        return;\n      }\n    }\n  }\n\n  void SendResponse() override {\n    responder_.Finish(response_, grpc::Status::OK, this);\n    status_ = Status::kBeforeDelete;\n  }\n\n private:\n  enum class Status { kBeforeHandleRequest, kBeforeDelete };\n\n  Status status_;\n  CtrlRequest<ctrl_method> request_;\n  CtrlResponse<ctrl_method> response_;\n  grpc::ServerContext server_ctx_;\n  grpc::ServerAsyncResponseWriter<CtrlResponse<ctrl_method>> responder_;\n  std::function<void()> request_handler_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CONTROL_CTRL_CALL_H_\n"
  },
  {
    "path": "oneflow/core/control/ctrl_client.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/control/ctrl_client.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n#define GRPC_CHECK(x) CHECK_EQ(x.error_code(), grpc::StatusCode::OK)\n\n}  // namespace\n\nGrpcCtrlClient::~GrpcCtrlClient() { StopHeartbeat(); }\n\nGrpcCtrlClient::GrpcCtrlClient(const ProcessCtx& process_ctx) : process_ctx_(process_ctx) {\n  rpc_client_.ReserveStubsOfSize(process_ctx.ctrl_addr_size());\n  for (int64_t i = 0; i < process_ctx.ctrl_addr_size(); ++i) {\n    const Address& address = process_ctx.ctrl_addr(i);\n    auto new_stub = CtrlService::NewStub(address.host() + \":\" + std::to_string(address.port()));\n    rpc_client_.AddStub(std::move(new_stub));\n    rpc_client_.LoadServer(address.host(), rpc_client_.GetStubAt(i));\n  }\n  need_heartbeat_thread_stop_ = false;\n  heartbeat_thread_ = std::thread([this]() {\n    std::mt19937 gen(NewRandomSeed());\n    std::uniform_int_distribution<int32_t> sleep_second_dis(7, 13);\n    LoadServerRequest request;\n    LoadServerResponse response;\n    while (true) {\n      const auto wait_duration = std::chrono::seconds(sleep_second_dis(gen));\n      {\n        std::unique_lock<std::mutex> lck(need_heartbeat_thread_stop_mtx_);\n        const bool stopped = need_heartbeat_thread_stop_cv_.wait_for(\n            lck, wait_duration, [&]() { return need_heartbeat_thread_stop_; });\n        if (stopped) { break; }\n      }\n      for (size_t i = 0; i < rpc_client_.GetStubSize(); ++i) {\n        grpc::ClientContext client_ctx;\n        request.set_addr(this->process_ctx().ctrl_addr(i).host());\n        GRPC_CHECK(rpc_client_.GetStubAt(i)->CallMethod<CtrlMethod::kLoadServer>(\n            &client_ctx, request, &response))\n            << \"Machine \" << i << \" lost\";\n      }\n    }\n  });\n}\n\nvoid GrpcCtrlClient::Barrier(const std::string& barrier_name) { rpc_client_.Barrier(barrier_name); }\n\nvoid GrpcCtrlClient::Barrier(const std::string& barrier_name, int32_t barrier_num) {\n  rpc_client_.Barrier(barrier_name, barrier_num);\n}\n\nTryLockResult GrpcCtrlClient::TryLock(const std::string& name) { return rpc_client_.TryLock(name); }\n\nvoid GrpcCtrlClient::NotifyDone(const std::string& name) { rpc_client_.NotifyDone(name); }\n\nvoid GrpcCtrlClient::WaitUntilDone(const std::string& name) { rpc_client_.WaitUntilDone(name); }\n\nvoid GrpcCtrlClient::PushKV(const std::string& k, const std::string& v) {\n  rpc_client_.PushKV(k, v);\n}\n\nvoid GrpcCtrlClient::PushKV(const std::string& k, const PbMessage& msg) {\n  rpc_client_.PushKV(k, msg);\n}\n\nvoid GrpcCtrlClient::PushKV(const std::string& k, std::function<void(std::string*)> VSetter) {\n  rpc_client_.PushKV(k, VSetter);\n}\n\nvoid GrpcCtrlClient::PushMasterKV(const std::string& k, const PbMessage& msg) {\n  rpc_client_.PushMasterKV(k, msg);\n}\n\nvoid GrpcCtrlClient::ClearKV(const std::string& k) { rpc_client_.ClearKV(k); }\n\nvoid GrpcCtrlClient::ClearMasterKV(const std::string& k) { 
rpc_client_.ClearMasterKV(k); }\n\nvoid GrpcCtrlClient::PullKV(const std::string& k, std::string* v) { rpc_client_.PullKV(k, v); }\n\nvoid GrpcCtrlClient::PullKV(const std::string& k, PbMessage* msg) { rpc_client_.PullKV(k, msg); }\n\nvoid GrpcCtrlClient::PullKV(const std::string& k, std::function<void(const std::string&)> VGetter) {\n  rpc_client_.PullKV(k, VGetter);\n}\n\nvoid GrpcCtrlClient::PullMasterKV(const std::string& k, PbMessage* msg) {\n  rpc_client_.PullMasterKV(k, msg);\n}\n\nvoid GrpcCtrlClient::Clear() { rpc_client_.Clear(); }\n\nint32_t GrpcCtrlClient::IncreaseCount(const std::string& k, int32_t v) {\n  return rpc_client_.IncreaseCount(k, v);\n}\n\nvoid GrpcCtrlClient::EraseCount(const std::string& k) { rpc_client_.EraseCount(k); }\n\nvoid GrpcCtrlClient::StopHeartbeat() {\n  bool already_stopped = false;\n  {\n    std::unique_lock<std::mutex> lck(need_heartbeat_thread_stop_mtx_);\n    already_stopped = need_heartbeat_thread_stop_;\n    need_heartbeat_thread_stop_ = true;\n    need_heartbeat_thread_stop_cv_.notify_all();\n  }\n  if (!already_stopped) { heartbeat_thread_.join(); }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/control/ctrl_client.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CONTROL_CTRL_CLIENT_H_\n#define ONEFLOW_CORE_CONTROL_CTRL_CLIENT_H_\n\n#include \"oneflow/core/rpc/include/ctrl.h\"\n\n#endif  // ONEFLOW_CORE_CONTROL_CTRL_CLIENT_H_\n"
  },
  {
    "path": "oneflow/core/control/ctrl_server.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/control/ctrl_server.h\"\n#include \"oneflow/core/control/ctrl_bootstrap.pb.h\"\n#include \"oneflow/core/job/env_desc.h\"\n#include \"grpc/grpc_posix.h\"\n\nnamespace oneflow {\n\nCtrlServer::CtrlServer(int ctrl_port) : RpcServer(), port_(ctrl_port) {\n  Init();\n  grpc::ServerBuilder server_builder;\n  server_builder.SetMaxMessageSize(INT_MAX);\n  int bound_port = 0;\n  server_builder.AddListeningPort(\"0.0.0.0:\" + std::to_string(port_),\n                                  grpc::InsecureServerCredentials(), &bound_port);\n  grpc_service_.reset(new CtrlService::AsyncService);\n  server_builder.RegisterService(grpc_service_.get());\n  cq_ = server_builder.AddCompletionQueue();\n  grpc_server_ = server_builder.BuildAndStart();\n  if (port() != 0) {\n    CHECK_EQ(port(), bound_port) << \"Port \" << port() << \" is unavailable\";\n  } else {\n    port_ = bound_port;\n    CHECK_NE(port(), 0);\n  }\n  LOG(INFO) << \"CtrlServer listening on \"\n            << \"0.0.0.0:\" + std::to_string(port());\n  loop_thread_ = std::thread(&CtrlServer::HandleRpcs, this);\n}\n\nCtrlServer::CtrlServer() : CtrlServer(0) {}\n\nvoid CtrlServer::OnLoadServer(CtrlCall<CtrlMethod::kLoadServer>* call) {\n  call->SendResponse();\n  EnqueueRequest<CtrlMethod::kLoadServer>();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/control/ctrl_server.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CONTROL_CTRL_SERVER_H_\n#define ONEFLOW_CORE_CONTROL_CTRL_SERVER_H_\n\n#ifdef RPC_BACKEND_GRPC\n\n#include \"oneflow/core/control/rpc_server.h\"\n\nnamespace oneflow {\n\nclass CtrlServer final : public RpcServer {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CtrlServer);\n  ~CtrlServer() override {}\n\n  CtrlServer();\n  // port may be configured in bootstrap_conf\n  CtrlServer(int ctrl_port);\n\n  int64_t port() const { return port_; }\n\n private:\n  void OnLoadServer(CtrlCall<CtrlMethod::kLoadServer>* call) override;\n  int port_;\n};\n\n}  // namespace oneflow\n\n#endif  // RPC_BACKEND_GRPC\n\n#endif  // ONEFLOW_CORE_CONTROL_CTRL_SERVER_H_\n"
  },
  {
    "path": "oneflow/core/control/ctrl_service.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/control/ctrl_service.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<size_t method_index>\nconst grpc::internal::RpcMethod BuildOneRpcMethod(std::shared_ptr<grpc::ChannelInterface> channel) {\n  return grpc::internal::RpcMethod(GetMethodName(static_cast<CtrlMethod>(method_index)),\n                                   grpc::internal::RpcMethod::NORMAL_RPC, channel);\n}\n\ntemplate<size_t... method_indices>\nstd::array<const grpc::internal::RpcMethod, kCtrlMethodNum> BuildRpcMethods(\n    std::index_sequence<method_indices...>, std::shared_ptr<grpc::ChannelInterface> channel) {\n  return {BuildOneRpcMethod<method_indices>(channel)...};\n}\n\nconstexpr int64_t kDefaultGrpcMaxMessageByteSize = -1;\n\n}  // namespace\n\nCtrlService::Stub::Stub(std::shared_ptr<grpc::ChannelInterface> channel)\n    : rpcmethods_(BuildRpcMethods(std::make_index_sequence<kCtrlMethodNum>{}, channel)),\n      channel_(channel) {}\n\nstd::unique_ptr<CtrlService::Stub> CtrlService::NewStub(const std::string& addr) {\n  grpc::ChannelArguments ch_args;\n  int64_t max_msg_byte_size =\n      ParseIntegerFromEnv(\"ONEFLOW_GRPC_MAX_MESSAGE_BYTE_SIZE\", kDefaultGrpcMaxMessageByteSize);\n  ch_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, max_msg_byte_size);\n  return std::make_unique<Stub>(\n      grpc::CreateCustomChannel(addr, grpc::InsecureChannelCredentials(), ch_args));\n}\n\nCtrlService::AsyncService::AsyncService() {\n  for (int32_t i = 0; i < kCtrlMethodNum; ++i) {\n    AddMethod(new grpc::internal::RpcServiceMethod(GetMethodName(static_cast<CtrlMethod>(i)),\n                                                   grpc::internal::RpcMethod::NORMAL_RPC, nullptr));\n    grpc::Service::MarkMethodAsync(i);\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/control/ctrl_service.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CONTROL_CTRL_SERVICE_H_\n#define ONEFLOW_CORE_CONTROL_CTRL_SERVICE_H_\n\n#include <grpc++/grpc++.h>\n#include <grpc++/impl/codegen/async_stream.h>\n#include <grpc++/impl/codegen/async_unary_call.h>\n#include <grpc++/impl/codegen/proto_utils.h>\n#include <grpc++/impl/codegen/rpc_method.h>\n#include <grpc++/impl/codegen/service_type.h>\n#include <grpc++/impl/codegen/status.h>\n#include <grpc++/impl/codegen/stub_options.h>\n#include <grpc++/impl/codegen/sync_stream.h>\n#include <grpc++/impl/codegen/client_unary_call.h>\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/control/control.pb.h\"\n#include \"oneflow/core/rpc/include/base.h\"\n\nnamespace oneflow {\n\nclass CtrlService final {\n public:\n  class Stub final {\n   public:\n    Stub(std::shared_ptr<grpc::ChannelInterface> channel);\n\n    template<CtrlMethod ctrl_method>\n    grpc::Status CallMethod(grpc::ClientContext* context, const CtrlRequest<ctrl_method>& request,\n                            CtrlResponse<ctrl_method>* response) {\n      return grpc::internal::BlockingUnaryCall(channel_.get(),\n                                               rpcmethods_.at(static_cast<size_t>(ctrl_method)),\n                                               context, request, response);\n    }\n\n   private:\n    std::array<const grpc::internal::RpcMethod, kCtrlMethodNum> rpcmethods_;\n\n    std::shared_ptr<grpc::ChannelInterface> channel_;\n  };\n\n  static std::unique_ptr<Stub> NewStub(const std::string& addr);\n\n  class AsyncService final : public grpc::Service {\n   public:\n    AsyncService();\n    ~AsyncService() = default;\n    using grpc::Service::RequestAsyncUnary;\n  };\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CONTROL_CTRL_SERVICE_H_\n"
  },
  {
    "path": "oneflow/core/control/ctrl_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/job/env.pb.h\"\n#include \"oneflow/core/control/ctrl_client.h\"\n#include \"oneflow/core/control/ctrl_server.h\"\n#include \"oneflow/core/control/ctrl_bootstrap.h\"\n#include \"oneflow/core/control/ctrl_util.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n\n#ifdef OF_PLATFORM_POSIX\n\n#include <netinet/in.h>\n#include <netinet/tcp.h>\n#include <sys/socket.h>\n#include <sys/types.h>\n#include <arpa/inet.h>\n\nnamespace oneflow {\n\nnamespace {\n\nEnvProto GetEnvProto(int port) {\n  EnvProto ret;\n  auto* machine0 = ret.add_machine();\n  machine0->set_id(0);\n  machine0->set_addr(\"127.0.0.1\");\n  ret.set_ctrl_port(port);\n  return ret;\n}\n\nResource GetResource() {\n  Resource ret;\n  ret.set_machine_num(1);\n  ret.set_cpu_device_num(1);\n  ret.set_comm_net_worker_num(1);\n  return ret;\n}\n\n}  // namespace\n\n#ifdef RPC_BACKEND_GRPC\nTEST(CtrlServer, new_delete) {\n  int port = CtrlUtil().FindAvailablePort();\n  if (port == -1) { return; }\n  EnvProto env_proto = GetEnvProto(port);\n  Singleton<EnvDesc>::New(env_proto);\n  Singleton<CtrlServer>::New();\n  Singleton<ProcessCtx>::New();\n  CHECK_JUST(\n      HostListCtrlBootstrap(*Singleton<EnvDesc>::Get())\n          .InitProcessCtx(Singleton<CtrlServer>::Get()->port(), Singleton<ProcessCtx>::Get()));\n  auto* client = new GrpcCtrlClient(*Singleton<ProcessCtx>::Get());\n  Singleton<CtrlClient>::SetAllocated(client);\n  Singleton<ResourceDesc, ForEnv>::New(GetResource(), GlobalProcessCtx::NumOfProcessPerNode());\n  Singleton<ResourceDesc, ForSession>::New(GetResource(), GlobalProcessCtx::NumOfProcessPerNode());\n\n  // do test\n  // OF_ENV_BARRIER();\n\n  Singleton<ResourceDesc, ForSession>::Delete();\n  Singleton<ResourceDesc, ForEnv>::Delete();\n  Singleton<CtrlClient>::Delete();\n  Singleton<ProcessCtx>::Delete();\n  Singleton<CtrlServer>::Delete();\n  Singleton<EnvDesc>::Delete();\n}\n#endif  // RPC_BACKEND_GRPC\n\n}  // namespace oneflow\n\n#endif  // OF_PLATFORM_POSIX\n"
  },
  {
    "path": "oneflow/core/control/ctrl_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/platform.h\"\n#ifdef OF_PLATFORM_POSIX\n#include <netinet/in.h>\n#include <netinet/tcp.h>\n#include <sys/socket.h>\n#include <sys/types.h>\n#include <arpa/inet.h>\n#endif  // OF_PLATFORM_POSIX\n\n#include \"oneflow/core/control/ctrl_util.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/data_type.h\"\n\nnamespace oneflow {\n\n#ifdef OF_PLATFORM_POSIX\n\nnamespace {\n\nsockaddr_in GetSockAddr(const std::string& addr, uint16_t port) {\n  sockaddr_in sa;\n  sa.sin_family = AF_INET;\n  sa.sin_port = htons(port);\n  PCHECK(inet_pton(AF_INET, addr.c_str(), &(sa.sin_addr)) == 1);\n  return sa;\n}\n\n}  // namespace\n\nint CtrlUtil::FindAvailablePort() const {\n  int sock = socket(AF_INET, SOCK_STREAM, 0);\n\n  for (uint16_t port = 10000; port < GetMaxVal<uint16_t>(); ++port) {\n    sockaddr_in sa = GetSockAddr(\"0.0.0.0\", port);\n    int bind_result = bind(sock, reinterpret_cast<sockaddr*>(&sa), sizeof(sa));\n    if (bind_result == 0) {\n      shutdown(sock, SHUT_RDWR);\n      close(sock);\n      return port;\n    }\n  }\n  return -1;\n}\n\n#else\n\nint CtrlUtil::FindAvailablePort() const { UNIMPLEMENTED(); }\n\n#endif  // OF_PLATFORM_POSIX\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/control/ctrl_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CONTROL_CTR_TEST_H_\n#define ONEFLOW_CORE_CONTROL_CTR_TEST_H_\n\nnamespace oneflow {\n\nclass CtrlUtil {\n public:\n  CtrlUtil() = default;\n  ~CtrlUtil() = default;\n\n  int FindAvailablePort() const;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CONTROL_CTR_TEST_H_\n"
  },
  {
    "path": "oneflow/core/control/global_process_ctx.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CONTROL_GLOBAL_PROCESS_CTX_H_\n#define ONEFLOW_CORE_CONTROL_GLOBAL_PROCESS_CTX_H_\n\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\n#endif  // ONEFLOW_CORE_CONTROL_GLOBAL_PROCESS_CTX_H_\n"
  },
  {
    "path": "oneflow/core/control/host_list_bootstrap_client.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/control/host_list_bootstrap_client.h\"\n#include \"oneflow/core/job/env_desc.h\"\n\nnamespace oneflow {\n\nHostListBootstrapClient::HostListBootstrapClient(const EnvDesc& env_desc) {\n  stubs_.reserve(env_desc.TotalMachineNum());\n  int32_t port = -1;\n  std::string addr = \"\";\n  for (int64_t i = 0; i < env_desc.TotalMachineNum(); ++i) {\n    const Machine& mchn = env_desc.machine(i);\n    port = (mchn.ctrl_port_agent() != -1) ? (mchn.ctrl_port_agent()) : env_desc.ctrl_port();\n    addr = mchn.addr() + \":\" + std::to_string(port);\n    stubs_.emplace_back(CtrlService::NewStub(addr));\n    LoadServer(mchn.addr(), stubs_[i].get());\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/control/host_list_bootstrap_client.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CONTROL_HOST_LIST_BOOTSTRAP_CLIENT_H_\n#define ONEFLOW_CORE_CONTROL_HOST_LIST_BOOTSTRAP_CLIENT_H_\n\n#include \"oneflow/core/control/bootstrap_client.h\"\n#include \"oneflow/core/job/env_desc.h\"\n\nnamespace oneflow {\n\nclass HostListBootstrapClient final : public BootstrapClient {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(HostListBootstrapClient);\n  ~HostListBootstrapClient() override = default;\n\n  HostListBootstrapClient(const EnvDesc& env_desc);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CONTROL_HOST_LIST_BOOTSTRAP_CLIENT_H_\n"
  },
  {
    "path": "oneflow/core/control/host_list_bootstrap_server.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/control/host_list_bootstrap_server.h\"\n#include \"grpc/grpc_posix.h\"\n\nnamespace oneflow {\n\nHostListBootstrapServer::HostListBootstrapServer(const EnvDesc& env_desc)\n    : BootstrapServer(), is_first_connect_(true), this_machine_addr_(\"\") {\n  Init();\n  int port = env_desc.ctrl_port();\n  grpc::ServerBuilder server_builder;\n  server_builder.SetMaxMessageSize(INT_MAX);\n  int bound_port = 0;\n  server_builder.AddListeningPort(\"0.0.0.0:\" + std::to_string(port),\n                                  grpc::InsecureServerCredentials(), &bound_port);\n  grpc_service_.reset(new CtrlService::AsyncService);\n  server_builder.RegisterService(grpc_service_.get());\n  cq_ = server_builder.AddCompletionQueue();\n  grpc_server_ = server_builder.BuildAndStart();\n  CHECK_EQ(port, bound_port) << \"Port \" << port << \" is unavailable\";\n  LOG(INFO) << \"HostListBootstrapServer listening on \"\n            << \"0.0.0.0:\" + std::to_string(port);\n  loop_thread_ = std::thread(&HostListBootstrapServer::HandleRpcs, this);\n}\n\nvoid HostListBootstrapServer::OnLoadServer(CtrlCall<CtrlMethod::kLoadServer>* call) {\n  if (this->is_first_connect_) {\n    this->this_machine_addr_ = call->request().addr();\n    this->is_first_connect_ = false;\n  } else {\n    CHECK_EQ(call->request().addr(), this->this_machine_addr_);\n  }\n  call->SendResponse();\n  EnqueueRequest<CtrlMethod::kLoadServer>();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/control/host_list_bootstrap_server.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CONTROL_HOST_LIST_BOOTSTRAP_SERVER_H_\n#define ONEFLOW_CORE_CONTROL_HOST_LIST_BOOTSTRAP_SERVER_H_\n\n#include \"oneflow/core/control/bootstrap_server.h\"\n#include \"oneflow/core/job/env_desc.h\"\n\nnamespace oneflow {\n\nclass HostListBootstrapServer final : public BootstrapServer {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(HostListBootstrapServer);\n  ~HostListBootstrapServer() override = default;\n\n  HostListBootstrapServer(const EnvDesc& env_desc);\n  const std::string& this_machine_addr() { return this_machine_addr_; }\n\n private:\n  void OnLoadServer(CtrlCall<CtrlMethod::kLoadServer>* call) override;\n\n  bool is_first_connect_;\n  std::string this_machine_addr_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CONTROL_HOST_LIST_BOOTSTRAP_SERVER_H_\n"
  },
  {
    "path": "oneflow/core/control/rank_info_bootstrap_client.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/control/rank_info_bootstrap_client.h\"\n\nnamespace oneflow {\n\nRankInfoBootstrapClient::RankInfoBootstrapClient(const BootstrapConf& bootstrap_conf) {\n  stubs_.reserve(bootstrap_conf.world_size());\n  const auto& master_addr = bootstrap_conf.master_addr();\n  const std::string& host = master_addr.host() + \":\" + std::to_string(master_addr.port());\n  stubs_.emplace_back(CtrlService::NewStub(host));\n  LoadServerRequest request;\n  request.set_addr(master_addr.host());\n  request.set_rank(bootstrap_conf.rank());\n  LoadServer(request, stubs_[0].get());\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/control/rank_info_bootstrap_client.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CONTROL_RANK_INFO_BOOTSTRAP_CLIENT_H_\n#define ONEFLOW_CORE_CONTROL_RANK_INFO_BOOTSTRAP_CLIENT_H_\n\n#include \"oneflow/core/control/bootstrap_client.h\"\n#include \"oneflow/core/control/ctrl_bootstrap.pb.h\"\n#include \"oneflow/core/job/env_desc.h\"\n\nnamespace oneflow {\n\nclass RankInfoBootstrapClient final : public BootstrapClient {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RankInfoBootstrapClient);\n  ~RankInfoBootstrapClient() override = default;\n\n  RankInfoBootstrapClient(const BootstrapConf& bootstrap_conf);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CONTROL_RANK_INFO_BOOTSTRAP_CLIENT_H_\n"
  },
  {
    "path": "oneflow/core/control/rank_info_bootstrap_server.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <thread>\n#include <mutex>\n#include <chrono>\n#include \"grpc/grpc_posix.h\"\n#include \"oneflow/core/common/env_var/bootstrap.h\"\n#include \"oneflow/core/control/rank_info_bootstrap_server.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::string GetHostFromUri(const std::string& uri) {\n  size_t first_delimiter_pos = uri.find(\":\");\n  CHECK_NE(first_delimiter_pos, std::string::npos);\n  const std::string& protocol_family = uri.substr(0, first_delimiter_pos);\n  CHECK_EQ(protocol_family, \"ipv4\");\n  size_t second_delimiter_pos = uri.rfind(\":\");\n  return uri.substr(first_delimiter_pos + 1, second_delimiter_pos - first_delimiter_pos - 1);\n}\n\nint64_t rpc_bootstrap_server_sleep_seconds() {\n  static const int64_t rpc_bootstrap_server_sleep_seconds =\n      EnvInteger<ONEFLOW_RPC_BOOTSTRAP_SERVER_SLEEP_SECONDS>();\n  return rpc_bootstrap_server_sleep_seconds;\n}\n\nint64_t rpc_bootstrap_server_max_retry_times() {\n  static const int64_t rpc_bootstrap_server_max_retry_times =\n      EnvInteger<ONEFLOW_RPC_BOOTSTRAP_SERVER_MAX_RETRY_TIMES>();\n  return rpc_bootstrap_server_max_retry_times;\n}\n\n}  // namespace\n\nRankInfoBootstrapServer::RankInfoBootstrapServer(const BootstrapConf& bootstrap_conf)\n    : BootstrapServer(), port_(0), world_size_(bootstrap_conf.world_size()) {\n  Init();\n  const int64_t rank = bootstrap_conf.rank();\n  int p = (rank == 0 ? 
bootstrap_conf.master_addr().port() : 0);\n  grpc::ServerBuilder server_builder;\n  server_builder.SetMaxMessageSize(INT_MAX);\n  server_builder.AddListeningPort(\"0.0.0.0:\" + std::to_string(p), grpc::InsecureServerCredentials(),\n                                  &port_);\n  grpc_service_.reset(new CtrlService::AsyncService);\n  server_builder.RegisterService(grpc_service_.get());\n  cq_ = server_builder.AddCompletionQueue();\n  grpc_server_ = server_builder.BuildAndStart();\n  if (rank == 0) { CHECK_EQ(p, port()) << \"Port \" << p << \" is unavailable\"; }\n  LOG(INFO) << \"RankInfoBootstrapServer listening on \"\n            << \"0.0.0.0:\" + std::to_string(port());\n  loop_thread_ = std::thread(&RankInfoBootstrapServer::HandleRpcs, this);\n  if (rank == 0) {\n    rank2host_ = std::make_shared<std::vector<std::string>>(world_size_, \"\");\n    // NOTE: check_thread_ monitors the RankInfoBootstrapServer status on rank 0.\n    // If the number of ready ranks equals the total rank count (world_size), the status is ok;\n    // otherwise some ranks' servers have not been created successfully yet.\n    check_thread_ = std::thread(&RankInfoBootstrapServer::CheckServerStatus, this);\n  }\n}\n\nvoid RankInfoBootstrapServer::CheckServerStatus() {\n  bool status_ok = false;\n  int64_t skip_warning_times = 1;\n  int64_t retry_idx = 0;\n  // returns how many entries of rank2host_ have been filled in by OnLoadServer\n  auto GetValidRank2HostSize = [](const std::shared_ptr<std::vector<std::string>>& rank2host) {\n    int64_t valid_size = 0;\n    for (int64_t i = 0; i < rank2host->size(); ++i) {\n      if (rank2host->at(i) == \"\") { continue; }\n      valid_size += 1;\n    }\n    return valid_size;\n  };\n\n  for (; retry_idx < rpc_bootstrap_server_max_retry_times(); ++retry_idx) {\n    std::this_thread::sleep_for(std::chrono::seconds(rpc_bootstrap_server_sleep_seconds()));\n    int64_t valid_size = 0;\n    {\n      std::lock_guard<std::mutex> lock(lock_);\n      valid_size = GetValidRank2HostSize(rank2host_);\n    }\n    CHECK_LE(valid_size, world_size_);\n    if (valid_size == world_size_) {\n      status_ok = true;\n      break;\n    } else {\n      if (retry_idx >= skip_warning_times) {\n        LOG(WARNING) << \"BootstrapServer not ready: the rpc servers on some ranks have not been \"\n                        \"created successfully. Check failed \"\n                     << retry_idx + 1 << \" time(s), total ranks (world_size): \" << world_size_\n                     << \", ready ranks: \" << valid_size;\n      }\n    }\n  }\n\n  if (!status_ok) {\n    LOG(FATAL) << \"CheckServerStatus() failed: the rpc servers on some ranks are not ready. \"\n                  \"Please check whether the processes on all ranks were \"\n                  \"created successfully.\";\n  }\n}\n\nMaybe<const std::vector<std::string>&> RankInfoBootstrapServer::rank2host() const {\n  CHECK_NOTNULL(rank2host_.get());\n  return *rank2host_;\n}\n\nvoid RankInfoBootstrapServer::OnLoadServer(CtrlCall<CtrlMethod::kLoadServer>* call) {\n  int64_t rank = call->request().rank();\n  CHECK_GE(rank, 0);\n  CHECK_LT(rank, world_size_);\n  std::lock_guard<std::mutex> lock(lock_);\n  // Lazily create the table under the lock; on rank 0 the constructor already created it.\n  if (!rank2host_) { rank2host_ = std::make_shared<std::vector<std::string>>(world_size_); }\n  rank2host_->at(rank) = GetHostFromUri(call->server_ctx().peer());\n  call->SendResponse();\n  EnqueueRequest<CtrlMethod::kLoadServer>();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/control/rank_info_bootstrap_server.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CONTROL_RANK_INFO_BOOTSTRAP_SERVER_H_\n#define ONEFLOW_CORE_CONTROL_RANK_INFO_BOOTSTRAP_SERVER_H_\n\n#include \"oneflow/core/control/bootstrap_server.h\"\n#include \"oneflow/core/control/ctrl_bootstrap.pb.h\"\n#include \"oneflow/core/job/env_desc.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nclass RankInfoBootstrapServer final : public BootstrapServer {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RankInfoBootstrapServer);\n  ~RankInfoBootstrapServer() override {\n    if (check_thread_.joinable()) { check_thread_.join(); }\n  }\n\n  RankInfoBootstrapServer(const BootstrapConf& bootstrap_conf);\n\n  int64_t port() const { return port_; }\n  Maybe<const std::vector<std::string>&> rank2host() const;\n\n private:\n  void OnLoadServer(CtrlCall<CtrlMethod::kLoadServer>* call) override;\n  void CheckServerStatus();\n\n  int port_;\n  const int64_t world_size_;\n  std::mutex lock_;\n  std::thread check_thread_;\n  // use std::shared_ptr as std::optional\n  std::shared_ptr<std::vector<std::string>> rank2host_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CONTROL_RANK_INFO_BOOTSTRAP_SERVER_H_\n"
  },
  {
    "path": "oneflow/core/control/rpc_client.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/control/rpc_client.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/env_desc.h\"\n#include \"oneflow/core/common/env_var/bootstrap.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nint64_t rpc_client_max_retry_times() {\n  static const int64_t rpc_client_max_retry_times =\n      EnvInteger<ONEFLOW_RPC_CLIENT_MAX_RETRY_TIMES>();\n  return rpc_client_max_retry_times;\n}\n\nint64_t rpc_client_sleep_seconds() {\n  static const int64_t rpc_client_sleep_seconds = EnvInteger<ONEFLOW_RPC_CLIENT_SLEEP_SECONDS>();\n  return rpc_client_sleep_seconds;\n}\n\n#define GRPC_CHECK(x) CHECK_EQ(x.error_code(), grpc::StatusCode::OK)\n\ntemplate<CtrlMethod ctrl_method>\nclass ClientCall final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ClientCall);\n  ClientCall() = default;\n  ~ClientCall() = default;\n\n  CtrlRequest<ctrl_method>* mut_request() { return &request_; }\n  const CtrlResponse<ctrl_method>& response() const { return response_; }\n  void operator()(CtrlService::Stub* stub) {\n    grpc::ClientContext client_ctx;\n    GRPC_CHECK(stub->CallMethod<ctrl_method>(&client_ctx, request_, &response_));\n  }\n\n private:\n  CtrlRequest<ctrl_method> request_;\n  CtrlResponse<ctrl_method> response_;\n};\n\n}  // namespace\n\nvoid RpcClient::Barrier(const std::string& barrier_name) {\n  Barrier(barrier_name, Singleton<EnvDesc>::Get()->TotalMachineNum());\n}\n\nvoid RpcClient::Barrier(const std::string& barrier_name, int32_t barrier_num) {\n  ClientCall<CtrlMethod::kBarrier> call;\n  call.mut_request()->set_name(barrier_name);\n  call.mut_request()->set_num(barrier_num);\n  call(GetMasterStub());\n}\n\nTryLockResult RpcClient::TryLock(const std::string& name) {\n  {\n    std::unique_lock<std::mutex> lck(done_names_mtx_);\n    if (done_names_.find(name) != done_names_.end()) { return TryLockResult::kDone; }\n  }\n  ClientCall<CtrlMethod::kTryLock> call;\n  call.mut_request()->set_name(name);\n  call(GetResponsibleStub(name));\n  if (call.response().result() == TryLockResult::kDone) {\n    std::unique_lock<std::mutex> lck(done_names_mtx_);\n    done_names_.insert(name);\n  }\n  return call.response().result();\n}\n\nvoid RpcClient::NotifyDone(const std::string& name) {\n  ClientCall<CtrlMethod::kNotifyDone> call;\n  call.mut_request()->set_name(name);\n  call(GetResponsibleStub(name));\n}\n\nvoid RpcClient::WaitUntilDone(const std::string& name) {\n  ClientCall<CtrlMethod::kWaitUntilDone> call;\n  call.mut_request()->set_name(name);\n  call(GetResponsibleStub(name));\n}\n\nvoid RpcClient::PushKV(const std::string& k, std::function<void(std::string*)> VSetter) {\n  ClientCall<CtrlMethod::kPushKV> call;\n  call.mut_request()->set_key(k);\n  VSetter(call.mut_request()->mutable_val());\n  call(GetResponsibleStub(k));\n}\n\nvoid RpcClient::PushMasterKV(const std::string& k, std::function<void(std::string*)> VSetter) {\n  ClientCall<CtrlMethod::kPushKV> 
call;\n  call.mut_request()->set_key(k);\n  VSetter(call.mut_request()->mutable_val());\n  call(GetMasterStub());\n}\n\nvoid RpcClient::PushKV(const std::string& k, const std::string& v) {\n  PushKV(k, [&](std::string* o) { *o = v; });\n}\n\nvoid RpcClient::PushKV(const std::string& k, const PbMessage& msg) {\n  PushKV(k, [&](std::string* o) { msg.SerializeToString(o); });\n}\n\nvoid RpcClient::PushMasterKV(const std::string& k, const PbMessage& msg) {\n  PushMasterKV(k, [&](std::string* o) { msg.SerializeToString(o); });\n}\n\nvoid RpcClient::ClearKV(const std::string& k) {\n  ClientCall<CtrlMethod::kClearKV> call;\n  call.mut_request()->set_key(k);\n  call(GetResponsibleStub(k));\n}\n\nvoid RpcClient::ClearMasterKV(const std::string& k) {\n  ClientCall<CtrlMethod::kClearKV> call;\n  call.mut_request()->set_key(k);\n  call(GetMasterStub());\n}\n\nvoid RpcClient::PullKV(const std::string& k, std::function<void(const std::string&)> VGetter) {\n  ClientCall<CtrlMethod::kPullKV> call;\n  call.mut_request()->set_key(k);\n  call(GetResponsibleStub(k));\n  VGetter(call.response().val());\n}\n\nvoid RpcClient::PullMasterKV(const std::string& k,\n                             std::function<void(const std::string&)> VGetter) {\n  ClientCall<CtrlMethod::kPullKV> call;\n  call.mut_request()->set_key(k);\n  call(GetMasterStub());\n  VGetter(call.response().val());\n}\n\nvoid RpcClient::PullKV(const std::string& k, std::string* v) {\n  PullKV(k, [&](const std::string& i) { *v = i; });\n}\n\nvoid RpcClient::PullKV(const std::string& k, PbMessage* msg) {\n  PullKV(k, [&](const std::string& i) { msg->ParseFromString(i); });\n}\n\nvoid RpcClient::PullMasterKV(const std::string& k, PbMessage* msg) {\n  PullMasterKV(k, [&](const std::string& i) { msg->ParseFromString(i); });\n}\n\nvoid RpcClient::Clear() {\n  ClientCall<CtrlMethod::kClear> call;\n  call(GetThisStub());\n  std::unique_lock<std::mutex> lck(done_names_mtx_);\n  done_names_.clear();\n}\n\nint32_t RpcClient::IncreaseCount(const std::string& k, int32_t v) {\n  ClientCall<CtrlMethod::kIncreaseCount> call;\n  call.mut_request()->set_key(k);\n  call.mut_request()->set_val(v);\n  call(GetResponsibleStub(k));\n  return call.response().val();\n}\n\nvoid RpcClient::EraseCount(const std::string& k) {\n  ClientCall<CtrlMethod::kEraseCount> call;\n  call.mut_request()->set_key(k);\n  call(GetResponsibleStub(k));\n}\n\nvoid RpcClient::LoadServer(const std::string& server_addr, CtrlService::Stub* stub) {\n  LoadServerRequest request;\n  request.set_addr(server_addr);\n  return LoadServer(request, stub);\n}\n\nvoid RpcClient::LoadServer(const LoadServerRequest& request, CtrlService::Stub* stub) {\n  int32_t retry_idx = 0;\n  int32_t skip_warning_times = 3;\n  for (; retry_idx < rpc_client_max_retry_times(); ++retry_idx) {\n    grpc::ClientContext client_ctx;\n    LoadServerResponse response;\n    grpc::Status st = stub->CallMethod<CtrlMethod::kLoadServer>(&client_ctx, request, &response);\n    if (st.error_code() == grpc::StatusCode::OK) {\n      VLOG(3) << \"LoadServer \" << request.addr() << \" Successful at \" << retry_idx + 1 << \" times\";\n      break;\n    } else if (st.error_code() == grpc::StatusCode::UNAVAILABLE) {\n      if (retry_idx >= skip_warning_times) {\n        LOG(WARNING) << \"LoadServer \" << request.addr() << \" Failed at \" << retry_idx + 1\n                     << \" times\"\n                     << \" error_code: \" << st.error_code()\n                     << \" error_message: \" << st.error_message();\n      }\n      
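// UNAVAILABLE usually means the server process is not up yet; back off before retrying.\n      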
std::this_thread::sleep_for(std::chrono::seconds(rpc_client_sleep_seconds()));\n      continue;\n    } else {\n      LOG(FATAL) << st.error_message();\n    }\n  }\n  CHECK_LT(retry_idx, rpc_client_max_retry_times());\n}\n\nCtrlService::Stub* RpcClient::GetThisStub() { return stubs_[GlobalProcessCtx::Rank()].get(); }\n\nCtrlService::Stub* RpcClient::GetResponsibleStub(const std::string& key) {\n  int64_t machine_id =\n      (std::hash<std::string>{}(key)) % Singleton<EnvDesc>::Get()->TotalMachineNum();\n  return stubs_[machine_id].get();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/control/rpc_client.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CONTROL_RPC_CLIENT_H_\n#define ONEFLOW_CORE_CONTROL_RPC_CLIENT_H_\n\n#include \"oneflow/core/lazy/actor/actor_message.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/control/ctrl_service.h\"\n#include \"oneflow/core/job/global_for.h\"\n\nnamespace oneflow {\n\nclass RpcClient {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RpcClient);\n  RpcClient() = default;\n  virtual ~RpcClient() = default;\n\n  void Barrier(const std::string& barrier_name);\n  void Barrier(const std::string& barrier_name, int32_t barrier_num);\n\n  TryLockResult TryLock(const std::string& name);\n  void NotifyDone(const std::string& name);\n  void WaitUntilDone(const std::string& name);\n\n  void PushKV(const std::string& k, std::function<void(std::string*)> VSetter);\n  void PushKV(const std::string& k, const std::string& v);\n  void PushKV(const std::string& k, const PbMessage& msg);\n  void PushMasterKV(const std::string& k, const PbMessage& msg);\n  template<typename T>\n  typename std::enable_if<std::is_arithmetic<T>::value>::type PushKVT(const std::string& k, T v) {\n    PushKV(k, std::to_string(v));\n  }\n\n  void ClearKV(const std::string& k);\n  void ClearMasterKV(const std::string& k);\n\n  void PullKV(const std::string& k, std::function<void(const std::string&)> VGetter);\n  void PullKV(const std::string& k, std::string* v);\n  void PullKV(const std::string& k, PbMessage* msg);\n  void PullMasterKV(const std::string& k, PbMessage* msg);\n  template<typename T>\n  typename std::enable_if<std::is_arithmetic<T>::value>::type PullKVT(const std::string& k, T* v) {\n    std::string v_str;\n    PullKV(k, &v_str);\n    *v = oneflow_cast<T>(v_str);\n  }\n\n  void Clear();\n\n  int32_t IncreaseCount(const std::string& k, int32_t v);\n  int32_t IncreaseCount(const std::string& k) { return IncreaseCount(k, 1); }\n  void EraseCount(const std::string& k);\n  void LoadServer(const std::string& server_addr, CtrlService::Stub* stub);\n  void LoadServer(const LoadServerRequest& request, CtrlService::Stub* stub);\n  void PushMasterKV(const std::string& k, std::function<void(std::string*)> VSetter);\n  void PullMasterKV(const std::string& k, std::function<void(const std::string&)> VGetter);\n  CtrlService::Stub* GetMasterStub() { return stubs_[0].get(); }\n  CtrlService::Stub* GetThisStub();\n  CtrlService::Stub* GetResponsibleStub(const std::string& key);\n  CtrlService::Stub* GetStubAt(int64_t i) { return stubs_[i].get(); };\n  size_t GetStubSize() { return stubs_.size(); };\n  void ReserveStubsOfSize(int64_t n) { stubs_.reserve(n); };\n  void AddStub(std::unique_ptr<CtrlService::Stub> s) { stubs_.emplace_back(std::move(s)); };\n\n  std::vector<std::unique_ptr<CtrlService::Stub>> stubs_;\n  std::mutex done_names_mtx_;\n  HashSet<std::string> done_names_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CONTROL_RPC_CLIENT_H_\n"
  },
  {
    "path": "oneflow/core/control/rpc_server.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/control/rpc_server.h\"\n#include \"oneflow/core/job/env_desc.h\"\n#include \"grpc/grpc_posix.h\"\n\nnamespace oneflow {\n\nRpcServer::~RpcServer() {\n  // NOTE(chengcheng): This enqueues a special event (with a null tag) that causes\n  // the completion queue to be shut down on the polling thread.\n  grpc::Alarm alarm(cq_.get(), gpr_now(GPR_CLOCK_MONOTONIC), nullptr);\n  loop_thread_.join();\n}\n\nvoid RpcServer::HandleRpcs() {\n  EnqueueRequests();\n\n  void* tag = nullptr;\n  bool ok = false;\n  // NOTE(chengcheng): The is_shutdown bool flag make sure that 'ok = false' occurs ONLY after\n  // cq_->Shutdown() for security check.\n  bool is_shutdown = false;\n  // NOTE(chengcheng): The final end is that cq_->Next() get false and cq_ is empty with no item.\n  while (cq_->Next(&tag, &ok)) {\n    auto call = static_cast<CtrlCallIf*>(tag);\n    if (!ok) {\n      // NOTE(chengcheng): After call grpc_server_->Shutdown() and cq_->Shutdown(),\n      // there will trigger some cancel tag items on each RPC. And cq_->Next() can get these tag\n      // with ok = false. Then delete the tag with CtrlCallIf pointer for recovery.\n      CHECK(is_shutdown);\n      CHECK(call);\n      delete call;\n      continue;\n    }\n    if (call) {\n      call->Process();\n    } else {\n      // NOTE(chengcheng): A null `call` indicates that this is the shutdown alarm.\n      CHECK(!is_shutdown);\n      is_shutdown = true;\n      grpc_server_->Shutdown();\n      cq_->Shutdown();\n\n      // NOTE(chengcheng): You CANNOT use code 'break;' in this block because that\n      // there still be items in the cq_.\n      // 'break;'\n    }\n  }\n}\n\nvoid RpcServer::Init() {\n  Add([this](CtrlCall<CtrlMethod::kLoadServer>* call) { OnLoadServer(call); });\n\n  Add([this](CtrlCall<CtrlMethod::kBarrier>* call) {\n    const std::string& barrier_name = call->request().name();\n    int32_t barrier_num = call->request().num();\n    auto barrier_call_it = barrier_calls_.find(barrier_name);\n    if (barrier_call_it == barrier_calls_.end()) {\n      barrier_call_it =\n          barrier_calls_\n              .emplace(barrier_name, std::make_pair(std::list<CtrlCallIf*>{}, barrier_num))\n              .first;\n    }\n    CHECK_EQ(barrier_num, barrier_call_it->second.second) << barrier_name;\n    barrier_call_it->second.first.emplace_back(call);\n    if (barrier_call_it->second.first.size() == barrier_call_it->second.second) {\n      for (CtrlCallIf* pending_call : barrier_call_it->second.first) {\n        pending_call->SendResponse();\n      }\n      barrier_calls_.erase(barrier_call_it);\n    }\n\n    EnqueueRequest<CtrlMethod::kBarrier>();\n  });\n\n  Add([this](CtrlCall<CtrlMethod::kTryLock>* call) {\n    const std::string& lock_name = call->request().name();\n    auto name2lock_status_it = name2lock_status_.find(lock_name);\n    if (name2lock_status_it == name2lock_status_.end()) {\n      
call->mut_response()->set_result(TryLockResult::kLocked);\n      auto waiting_until_done_calls = new std::list<CtrlCallIf*>;\n      CHECK(name2lock_status_.emplace(lock_name, waiting_until_done_calls).second);\n    } else {\n      if (name2lock_status_it->second) {\n        call->mut_response()->set_result(TryLockResult::kDoing);\n      } else {\n        call->mut_response()->set_result(TryLockResult::kDone);\n      }\n    }\n    call->SendResponse();\n    EnqueueRequest<CtrlMethod::kTryLock>();\n  });\n\n  Add([this](CtrlCall<CtrlMethod::kNotifyDone>* call) {\n    const std::string& lock_name = call->request().name();\n    auto name2lock_status_it = name2lock_status_.find(lock_name);\n    auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(name2lock_status_it->second);\n    for (CtrlCallIf* waiting_call : *waiting_calls) { waiting_call->SendResponse(); }\n    delete waiting_calls;\n    name2lock_status_it->second = nullptr;\n    call->SendResponse();\n    EnqueueRequest<CtrlMethod::kNotifyDone>();\n  });\n\n  Add([this](CtrlCall<CtrlMethod::kWaitUntilDone>* call) {\n    const std::string& lock_name = call->request().name();\n    void* lock_status = name2lock_status_.at(lock_name);\n    if (lock_status) {\n      auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(lock_status);\n      waiting_calls->emplace_back(call);\n    } else {\n      call->SendResponse();\n    }\n    EnqueueRequest<CtrlMethod::kWaitUntilDone>();\n  });\n\n  Add([this](CtrlCall<CtrlMethod::kPushKV>* call) {\n    const std::string& k = call->request().key();\n    const std::string& v = call->request().val();\n    CHECK(kv_.emplace(k, v).second);\n\n    auto pending_kv_calls_it = pending_kv_calls_.find(k);\n    if (pending_kv_calls_it != pending_kv_calls_.end()) {\n      for (auto pending_call : pending_kv_calls_it->second) {\n        pending_call->mut_response()->set_val(v);\n        pending_call->SendResponse();\n      }\n      pending_kv_calls_.erase(pending_kv_calls_it);\n    }\n    call->SendResponse();\n    EnqueueRequest<CtrlMethod::kPushKV>();\n  });\n\n  Add([this](CtrlCall<CtrlMethod::kClearKV>* call) {\n    const std::string& k = call->request().key();\n    CHECK_EQ(kv_.erase(k), 1);\n    CHECK(pending_kv_calls_.find(k) == pending_kv_calls_.end());\n    call->SendResponse();\n    EnqueueRequest<CtrlMethod::kClearKV>();\n  });\n\n  Add([this](CtrlCall<CtrlMethod::kPullKV>* call) {\n    const std::string& k = call->request().key();\n    auto kv_it = kv_.find(k);\n    if (kv_it != kv_.end()) {\n      call->mut_response()->set_val(kv_it->second);\n      call->SendResponse();\n    } else {\n      pending_kv_calls_[k].emplace_back(call);\n    }\n    EnqueueRequest<CtrlMethod::kPullKV>();\n  });\n\n  Add([this](CtrlCall<CtrlMethod::kClear>* call) {\n    name2lock_status_.clear();\n    kv_.clear();\n    CHECK(pending_kv_calls_.empty()) << \"size(): \" << pending_kv_calls_.size()\n                                     << \", begin()->key: \" << pending_kv_calls_.begin()->first;\n    call->SendResponse();\n    EnqueueRequest<CtrlMethod::kClear>();\n  });\n\n  Add([this](CtrlCall<CtrlMethod::kIncreaseCount>* call) {\n    int32_t& count = count_[call->request().key()];\n    count += call->request().val();\n    call->mut_response()->set_val(count);\n    call->SendResponse();\n    EnqueueRequest<CtrlMethod::kIncreaseCount>();\n  });\n\n  Add([this](CtrlCall<CtrlMethod::kEraseCount>* call) {\n    CHECK_EQ(count_.erase(call->request().key()), 1);\n    call->SendResponse();\n    
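// Re-arm the completion queue so the next kEraseCount request can be served.\n    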
EnqueueRequest<CtrlMethod::kEraseCount>();\n  });\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/control/rpc_server.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CONTROL_RPC_SERVER_H_\n#define ONEFLOW_CORE_CONTROL_RPC_SERVER_H_\n\n#include <grpc++/alarm.h>\n#include <grpc++/server_builder.h>\n#include \"oneflow/core/control/ctrl_call.h\"\n#include \"oneflow/core/common/function_traits.h\"\n\nnamespace oneflow {\n\nnamespace {\ntemplate<size_t... Idx>\nstatic std::tuple<std::function<void(CtrlCall<(CtrlMethod)Idx>*)>...> GetHandlerTuple(\n    std::index_sequence<Idx...>) {\n  return {};\n}\n\n}  // namespace\n\nclass RpcServer {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RpcServer);\n  virtual ~RpcServer();\n\n protected:\n  RpcServer() {}\n  void HandleRpcs();\n  void Init();\n\n  void EnqueueRequests() {\n    for_each_i(handlers_, helper{this}, std::make_index_sequence<kCtrlMethodNum>{});\n  }\n\n  template<CtrlMethod kMethod>\n  void EnqueueRequest() {\n    constexpr const size_t I = (size_t)kMethod;\n    auto handler = std::get<I>(handlers_);\n    auto call = new CtrlCall<(CtrlMethod)I>();\n    call->set_request_handler(std::bind(handler, call));\n    grpc_service_->RequestAsyncUnary(I, call->mut_server_ctx(), call->mut_request(),\n                                     call->mut_responder(), cq_.get(), cq_.get(), call);\n  }\n\n  template<typename F>\n  void Add(F f) {\n    using args_type = typename function_traits<F>::args_type;\n    using arg_type =\n        typename std::remove_pointer<typename std::tuple_element<0, args_type>::type>::type;\n\n    std::get<arg_type::value>(handlers_) = std::move(f);\n  }\n\n  virtual void OnLoadServer(CtrlCall<CtrlMethod::kLoadServer>* call) = 0;\n\n  struct helper {\n    helper(RpcServer* s) : s_(s) {}\n    template<typename T, typename V>\n    void operator()(const T& t, V) {\n      s_->EnqueueRequest<(CtrlMethod)V::value>();\n    }\n\n    RpcServer* s_;\n  };\n\n  using HandlerTuple = decltype(GetHandlerTuple(std::make_index_sequence<kCtrlMethodNum>{}));\n\n  HandlerTuple handlers_;\n  std::unique_ptr<CtrlService::AsyncService> grpc_service_;\n  std::unique_ptr<grpc::ServerCompletionQueue> cq_;\n  std::unique_ptr<grpc::Server> grpc_server_;\n  std::thread loop_thread_;\n  // Barrier\n  HashMap<std::string, std::pair<std::list<CtrlCallIf*>, int32_t>> barrier_calls_;\n  // TryLock, NotifyDone, WaitUntilDone\n  HashMap<std::string, void*> name2lock_status_;\n  // PushKV, ClearKV, PullKV\n  HashMap<std::string, std::string> kv_;\n  HashMap<std::string, std::list<CtrlCall<CtrlMethod::kPullKV>*>> pending_kv_calls_;\n  // IncreaseCount, EraseCount\n  HashMap<std::string, int32_t> count_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CONTROL_RPC_SERVER_H_\n"
  },
  {
    "path": "oneflow/core/control/worker_process_info.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage WorkerProcessInfo {\n  required int64 rank = 1;\n  required int64 port = 2;\n  optional string host = 3;\n}\n"
  },
  {
    "path": "oneflow/core/cuda/atomic.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CUDA_ATOMIC_H_\n#define ONEFLOW_CORE_CUDA_ATOMIC_H_\n\n#if defined(__CUDACC__)\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cstdint>\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\nnamespace oneflow {\n\nnamespace cuda {\n\nnamespace atomic {\n\nnamespace internal {\n\ntemplate<typename T, typename U>\nstruct CastCASImpl {\n  __device__ __forceinline__ T operator()(T* address, T compare, T val, bool* success) const {\n    static_assert(sizeof(T) == sizeof(U), \"\");\n    U assumed = *(reinterpret_cast<U*>(&compare));\n    U ret = atomicCAS(reinterpret_cast<U*>(address), assumed, *(reinterpret_cast<U*>(&val)));\n    *success = (ret == assumed);\n    return *(reinterpret_cast<T*>(&ret));\n  }\n};\n\n#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__))\n\ntemplate<typename T>\nstruct CastCASImpl<T, unsigned short int> {\n  __device__ __forceinline__ T operator()(T* address, T compare, T val, bool* success) const {\n    static_assert(sizeof(T) == sizeof(unsigned short int), \"\");\n    size_t offset = reinterpret_cast<size_t>(address) & 0x2;\n    unsigned int* address_as_ui =\n        reinterpret_cast<unsigned int*>(reinterpret_cast<char*>(address) - offset);\n    unsigned int old = *address_as_ui;\n    unsigned int assumed = *(reinterpret_cast<unsigned short int*>(&compare));\n    unsigned int newval = *(reinterpret_cast<unsigned short int*>(&val));\n\n    assumed = offset ? (old & 0xffff) | (assumed << 16) : (old & 0xffff0000) | assumed;\n    newval = offset ? (old & 0xffff) | (newval << 16) : (old & 0xffff0000) | newval;\n\n    unsigned int ret = atomicCAS(address_as_ui, assumed, newval);\n    *success = (ret == assumed);\n    ret = offset ? 
(ret >> 16) : (ret & 0xffff);\n    return *(reinterpret_cast<T*>(&ret));\n  }\n};\n\n#endif  // __CUDA_ARCH__\n\ntemplate<typename T>\n__device__ __forceinline__ typename std::enable_if<sizeof(T) == sizeof(unsigned int), T>::type\nCASImpl(T* address, T compare, T val, bool* success) {\n  return CastCASImpl<T, unsigned int>()(address, compare, val, success);\n}\n\ntemplate<typename T>\n__device__ __forceinline__\n    typename std::enable_if<sizeof(T) == sizeof(unsigned long long int), T>::type\n    CASImpl(T* address, T compare, T val, bool* success) {\n  return CastCASImpl<T, unsigned long long int>()(address, compare, val, success);\n}\n\ntemplate<typename T>\n__device__ __forceinline__ typename std::enable_if<sizeof(T) == sizeof(unsigned short int), T>::type\nCASImpl(T* address, T compare, T val, bool* success) {\n  return CastCASImpl<T, unsigned short int>()(address, compare, val, success);\n}\n\n__device__ __forceinline__ int CASImpl(int* address, int compare, int val, bool* success) {\n  int ret = atomicCAS(address, compare, val);\n  *success = (ret == compare);\n  return ret;\n}\n\n__device__ __forceinline__ unsigned int CASImpl(unsigned int* address, unsigned int compare,\n                                                unsigned int val, bool* success) {\n  unsigned int ret = atomicCAS(address, compare, val);\n  *success = (ret == compare);\n  return ret;\n}\n\n__device__ __forceinline__ unsigned long long int CASImpl(unsigned long long int* address,\n                                                          unsigned long long int compare,\n                                                          unsigned long long int val,\n                                                          bool* success) {\n  unsigned long long int ret = atomicCAS(address, compare, val);\n  *success = (ret == compare);\n  return ret;\n}\n\n#if __CUDA_ARCH__ >= 700\n\n__device__ __forceinline__ unsigned short int CASImpl(unsigned short int* address,\n                                                      unsigned short int compare,\n                                                      unsigned short int val, bool* success) {\n  unsigned short int ret = atomicCAS(address, compare, val);\n  *success = (ret == compare);\n  return ret;\n}\n\n#endif  // __CUDA_ARCH__ >= 700\n\ntemplate<typename T>\nstruct AddOp {\n  __device__ __forceinline__ T operator()(T a, T b) { return a + b; }\n};\n\ntemplate<typename T, template<typename> class BinaryOp>\n__device__ __forceinline__ T AtomicCASBinaryImpl(T* address, T val) {\n  T old = *address;\n  T assumed;\n  bool success = false;\n  do {\n    assumed = old;\n    old = CASImpl(address, assumed, BinaryOp<T>()(old, val), &success);\n  } while (!success);\n  return old;\n}\n\ntemplate<typename T>\n__device__ __forceinline__ T AddImpl(T* address, T val) {\n  return AtomicCASBinaryImpl<T, AddOp>(address, val);\n}\n\n__device__ __forceinline__ int AddImpl(int* address, int val) { return atomicAdd(address, val); }\n\n__device__ __forceinline__ unsigned int AddImpl(unsigned int* address, unsigned int val) {\n  return atomicAdd(address, val);\n}\n\n__device__ __forceinline__ unsigned long long int AddImpl(unsigned long long int* address,\n                                                          unsigned long long int val) {\n  return atomicAdd(address, val);\n}\n\n__device__ __forceinline__ uint64_t AddImpl(uint64_t* address, uint64_t val) {\n  static_assert(sizeof(uint64_t) == sizeof(unsigned long long int), \"\");\n  return 
static_cast<uint64_t>(atomicAdd(reinterpret_cast<unsigned long long int*>(address),\n                                         static_cast<unsigned long long int>(val)));\n}\n\n__device__ __forceinline__ float AddImpl(float* address, float val) {\n  return atomicAdd(address, val);\n}\n\n#if __CUDA_ARCH__ >= 600\n\n__device__ __forceinline__ double AddImpl(double* address, double val) {\n  return atomicAdd(address, val);\n}\n\n__device__ __forceinline__ half2 AddImpl(half2* address, half2 val) {\n  return atomicAdd(address, val);\n}\n\n#endif  // __CUDA_ARCH__ >= 600\n\n#if __CUDA_ARCH__ >= 700\n\n__device__ __forceinline__ half AddImpl(half* address, half val) { return atomicAdd(address, val); }\n\n#endif  // __CUDA_ARCH__ >= 700\n\n#if __CUDA_ARCH__ >= 800\n\n__device__ __forceinline__ nv_bfloat16 AddImpl(nv_bfloat16* address, nv_bfloat16 val) {\n  return atomicAdd(address, val);\n}\n\n__device__ __forceinline__ nv_bfloat162 AddImpl(nv_bfloat162* address, nv_bfloat162 val) {\n  return atomicAdd(address, val);\n}\n\n#endif  // __CUDA_ARCH__ >= 800\n\n#if __CUDA_ARCH__ < 530\n\n__device__ __forceinline__ half2 AddImpl(half2* address, half2 val) {\n  __trap();\n  return val;\n}\n\n#endif  // __CUDA_ARCH__ < 530\n\n}  // namespace internal\n\ntemplate<typename T, typename U>\n__device__ __forceinline__ typename std::enable_if<!std::is_same<T, U>::value, T>::type Cast(U v) {\n  return static_cast<T>(v);\n}\n\ntemplate<typename T, typename U>\n__device__ __forceinline__ typename std::enable_if<std::is_same<T, U>::value, T>::type Cast(U v) {\n  return v;\n}\n\ntemplate<typename T, typename U, typename V>\n__device__ __forceinline__ T CAS(T* address, U compare, V val) {\n  bool success = false;\n  return internal::CASImpl(address, Cast<T>(compare), Cast<T>(val), &success);\n}\n\ntemplate<typename T, typename U>\n__device__ __forceinline__ T Add(T* address, U val) {\n  return internal::AddImpl(address, Cast<T>(val));\n}\n\n__device__ __forceinline__ float Mul(int32_t* address, const int32_t val) {\n  int32_t old = *address, assumed;\n  do {\n    assumed = old;\n    old = atomicCAS(address, assumed, val * assumed);\n  } while (assumed != old);\n  return old;\n}\n\n__device__ __forceinline__ float Mul(uint32_t* address, const uint32_t val) {\n  uint32_t old = *address, assumed;\n  do {\n    assumed = old;\n    old = atomicCAS(address, assumed, val * assumed);\n  } while (assumed != old);\n  return old;\n}\n\n__device__ __forceinline__ float Mul(uint64_t* address, const uint64_t val) {\n  static_assert(sizeof(uint64_t) == sizeof(unsigned long long int), \"\");\n  unsigned long long int old = *reinterpret_cast<unsigned long long int*>(address), assumed;\n  do {\n    assumed = old;\n    old = atomicCAS(reinterpret_cast<unsigned long long int*>(address), assumed,\n                    static_cast<unsigned long long int>(val) * assumed);\n  } while (assumed != old);\n  return old;\n}\n\n__device__ __forceinline__ float Mul(float* address, const float val) {\n  int32_t* address_as_int = reinterpret_cast<int32_t*>(address);\n  int32_t old = *address_as_int, assumed;\n  do {\n    assumed = old;\n    old = atomicCAS(address_as_int, assumed, __float_as_int(val * __int_as_float(assumed)));\n  } while (assumed != old);\n  return __int_as_float(old);\n}\n\n__device__ __forceinline__ float Mul(double* address, const double val) {\n  unsigned long long int* address_as_ull = reinterpret_cast<unsigned long long int*>(address);\n  unsigned long long int old = *address_as_ull, assumed;\n  do {\n    assumed = old;\n    
old = atomicCAS(address_as_ull, assumed,\n                    __double_as_longlong(val * __longlong_as_double(assumed)));\n  } while (assumed != old);\n  return __longlong_as_double(old);\n}\n\n__device__ __forceinline__ float Max(float* address, const float val) {\n  int* address_as_i = (int*)address;\n  int old = *address_as_i;\n  int assumed = 0;\n  do {\n    assumed = old;\n    old = atomicCAS(address_as_i, assumed, __float_as_int(fmaxf(val, __int_as_float(assumed))));\n  } while (assumed != old);\n  return __int_as_float(old);\n}\n\n__device__ __forceinline__ double Max(double* address, const double val) {\n  unsigned long long int* address_as_i = (unsigned long long int*)address;\n  unsigned long long int old = *address_as_i;\n  unsigned long long int assumed = 0;\n  do {\n    assumed = old;\n    old = atomicCAS(address_as_i, assumed,\n                    __double_as_longlong(fmax(val, __longlong_as_double(assumed))));\n  } while (assumed != old);\n  return __longlong_as_double(old);\n}\n\n// FastAdd is referenced from\n// https://github.com/pytorch/pytorch/blob/396c3b1d88d7624938a2bb0b287f2a19f1e89bb4/aten/src/ATen/native/cuda/KernelUtils.cuh#L29\n#if defined(__CUDACC__)\ntemplate<typename T, typename std::enable_if<std::is_same<half, T>::value>::type* = nullptr>\n__device__ __forceinline__ void FastSpecializedAtomicAdd(T* base, size_t offset,\n                                                         const size_t length, T value) {\n#if ((defined(CUDA_VERSION) && (CUDA_VERSION < 10000)) \\\n     || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))\n  cuda::atomic::Add(reinterpret_cast<half*>(base) + offset, static_cast<half>(value));\n#else\n  // Accounts for the chance base falls on an odd 16 bit alignment (ie, not 32 bit aligned)\n  __half* target_addr = reinterpret_cast<__half*>(base + offset);\n  bool low_byte = (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__half2) == 0);\n\n  if (low_byte && offset < (length - 1)) {\n    __half2 value2;\n    value2.x = value;\n    value2.y = __float2half_rz(0);\n    cuda::atomic::Add(reinterpret_cast<__half2*>(target_addr), value2);\n\n  } else if (!low_byte && offset > 0) {\n    __half2 value2;\n    value2.x = __float2half_rz(0);\n    value2.y = value;\n    cuda::atomic::Add(reinterpret_cast<__half2*>(target_addr - 1), value2);\n\n  } else {\n    cuda::atomic::Add(reinterpret_cast<__half*>(base) + offset, static_cast<__half>(value));\n  }\n#endif\n}\n\ntemplate<typename T, typename std::enable_if<!std::is_same<half, T>::value>::type* = nullptr>\n__device__ __forceinline__ void FastSpecializedAtomicAdd(T* base, size_t offset,\n                                                         const size_t length, T value) {\n  cuda::atomic::Add(base + offset, value);\n}\n\ntemplate<class T>\n__device__ __forceinline__ void FastAdd(T* base, size_t offset, const size_t length, T value) {\n  FastSpecializedAtomicAdd(base, offset, length, value);\n}\n#endif\n\n}  // namespace atomic\n\n}  // namespace cuda\n\n}  // namespace oneflow\n\n#endif  // defined(__CUDACC__)\n\n#endif  // ONEFLOW_CORE_CUDA_ATOMIC_H_\n"
  },
  {
    "path": "oneflow/core/cuda/elementwise.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CUDA_ELEMENTWISE_H_\n#define ONEFLOW_CORE_CUDA_ELEMENTWISE_H_\n\n#include <cuda_runtime.h>\n#include <cstdint>\n#include <algorithm>\n#include <type_traits>\n\nnamespace oneflow {\n\nnamespace cuda {\n\nnamespace elementwise {\n\nconstexpr int kBlockSize = 256;\nconstexpr int kNumWaves = 32;\n\ninline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) {\n  int dev;\n  {\n    cudaError_t err = cudaGetDevice(&dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  int sm_count;\n  {\n    cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  int tpm;\n  {\n    cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  *num_blocks = std::max<int>(1, std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,\n                                                   sm_count * tpm / kBlockSize * kNumWaves));\n  return cudaSuccess;\n}\n\ntemplate<typename T, int pack_size>\nstruct GetPackType {\n  using type = typename std::aligned_storage<pack_size * sizeof(T), pack_size * sizeof(T)>::type;\n};\n\ntemplate<typename T, int pack_size>\nusing PackType = typename GetPackType<T, pack_size>::type;\n\ntemplate<typename T, int pack_size>\nunion Pack {\n  static_assert(sizeof(PackType<T, pack_size>) == sizeof(T) * pack_size, \"\");\n  __device__ Pack() {\n    // do nothing\n  }\n  PackType<T, pack_size> storage;\n  T elem[pack_size];\n};\n\ntemplate<typename T, int pack_size>\nstruct alignas(sizeof(T) * pack_size) Packed {\n  __device__ Packed() {\n    // do nothing\n  }\n  union {\n    T elem[pack_size];\n  };\n};\n\nconstexpr int kMaxPackBytes = 128 / 8;\nconstexpr int kMaxPackSize = 8;\n\nconstexpr int Min(int a, int b) { return a < b ? a : b; }\n\ntemplate<typename T>\nconstexpr int PackSize() {\n  return Min(kMaxPackBytes / sizeof(T), kMaxPackSize);\n}\n\ntemplate<typename T, typename U, typename... Args>\nconstexpr int PackSize() {\n  return Min(PackSize<T>(), PackSize<U, Args...>());\n}\n\ntemplate<typename T>\nclass HasApply2 {\n  typedef char one;\n  struct two {\n    char x[2];\n  };\n\n  template<typename C>\n  static one test(decltype(&C::Apply2));\n  template<typename C>\n  static two test(...);\n\n public:\n  enum { value = sizeof(test<T>(0)) == sizeof(char) };\n};\n\ntemplate<int pack_size, typename FunctorT, typename R, typename... IN>\n__device__ typename std::enable_if<HasApply2<FunctorT>::value == true && pack_size % 2 == 0,\n                                   Packed<R, pack_size>>::type\nApplyPack(const FunctorT& functor, const Packed<IN, pack_size>... in) {\n  Packed<R, pack_size> ret;\n#pragma unroll\n  for (int j = 0; j < pack_size; j += 2) { functor.Apply2(ret.elem + j, (in.elem + j)...); }\n  return ret;\n}\n\ntemplate<int pack_size, typename FunctorT, typename R, typename... 
IN>\n__device__ typename std::enable_if<HasApply2<FunctorT>::value == false || pack_size % 2 != 0,\n                                   Packed<R, pack_size>>::type\nApplyPack(const FunctorT& functor, const Packed<IN, pack_size>... in) {\n  Packed<R, pack_size> ret;\n#pragma unroll\n  for (int j = 0; j < pack_size; ++j) { ret.elem[j] = functor((in.elem[j])...); }\n  return ret;\n}\n\ntemplate<int pack_size, typename FactoryT, typename R, typename... IN>\n__global__ void __launch_bounds__(kBlockSize)\n    ApplyGeneric(FactoryT factory, int64_t n_pack, Packed<R, pack_size>* pack_r,\n                 const Packed<IN, pack_size>*... pack_in, int64_t n_tail, R* tail_r,\n                 const IN*... tail_in) {\n  auto functor = factory();\n  const int global_tid = blockIdx.x * kBlockSize + threadIdx.x;\n  for (int64_t i = global_tid; i < n_pack; i += blockDim.x * gridDim.x) {\n    pack_r[i] = ApplyPack<pack_size, decltype(functor), R, IN...>(functor, (pack_in[i])...);\n  }\n  if (global_tid < n_tail) { tail_r[global_tid] = functor((tail_in[global_tid])...); }\n}\n\ntemplate<typename FunctorT>\nstruct SimpleFactory {\n  explicit SimpleFactory(FunctorT functor) : tpl(functor) {}\n  __device__ FunctorT operator()() const { return tpl; }\n\n private:\n  FunctorT tpl;\n};\n\ntemplate<size_t pack_size>\nbool IsAlignedForPack() {\n  return true;\n}\n\ntemplate<size_t pack_size, typename T, typename... Args>\nbool IsAlignedForPack(const T* ptr, const Args*... others) {\n  return reinterpret_cast<uintptr_t>(ptr) % sizeof(Pack<T, pack_size>) == 0\n         && IsAlignedForPack<pack_size, Args...>(others...);\n}\n\ntemplate<size_t pack_size, typename FactoryT, typename R, typename... IN>\ncudaError_t LaunchKernel(FactoryT factory, int64_t n, R* r, const IN*... in, cudaStream_t stream) {\n  const int64_t n_pack = n / pack_size;\n  const int64_t tail_offset = n_pack * pack_size;\n  const int64_t n_tail = n - tail_offset;\n  int num_blocks;\n  {\n    cudaError_t err = GetNumBlocks(n_pack, &num_blocks);\n    if (err != cudaSuccess) { return err; }\n  }\n  ApplyGeneric<pack_size, FactoryT, R, IN...><<<num_blocks, kBlockSize, 0, stream>>>(\n      factory, n_pack, reinterpret_cast<Packed<R, pack_size>*>(r),\n      (reinterpret_cast<const Packed<IN, pack_size>*>(in))..., n_tail, r + tail_offset,\n      (in + tail_offset)...);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename FactoryT, typename R, typename... IN>\nstruct GenericLauncher {\n  static cudaError_t Launch(FactoryT factory, int64_t n, R* r, const IN*... 
in,\n                            cudaStream_t stream) {\n    constexpr int max_pack_size = PackSize<R, IN...>();\n    if (IsAlignedForPack<max_pack_size, R, IN...>(r, in...)) {\n      return LaunchKernel<max_pack_size, FactoryT, R, IN...>(factory, n, r, in..., stream);\n    } else {\n      return LaunchKernel<1, FactoryT, R, IN...>(factory, n, r, in..., stream);\n    }\n  }\n};\n\ntemplate<typename FactoryT, typename R, typename A>\ninline cudaError_t UnaryWithFactory(FactoryT factory, int64_t n, R* r, const A* a,\n                                    cudaStream_t stream) {\n  return GenericLauncher<FactoryT, R, A>::Launch(factory, n, r, a, stream);\n}\n\ntemplate<typename FunctorT, typename R, typename A>\ninline cudaError_t Unary(FunctorT functor, int64_t n, R* r, const A* a, cudaStream_t stream) {\n  return UnaryWithFactory(SimpleFactory<FunctorT>(functor), n, r, a, stream);\n}\n\ntemplate<typename FactoryT, typename R, typename A, typename B>\ninline cudaError_t BinaryWithFactory(FactoryT factory, int64_t n, R* r, const A* a, const B* b,\n                                     cudaStream_t stream) {\n  return GenericLauncher<FactoryT, R, A, B>::Launch(factory, n, r, a, b, stream);\n}\n\ntemplate<typename FunctorT, typename R, typename A, typename B>\ninline cudaError_t Binary(FunctorT functor, int64_t n, R* r, const A* a, const B* b,\n                          cudaStream_t stream) {\n  return BinaryWithFactory(SimpleFactory<FunctorT>(functor), n, r, a, b, stream);\n}\n\ntemplate<typename FactoryT, typename R, typename A, typename B, typename C>\ninline cudaError_t TernaryWithFactory(FactoryT factory, int64_t n, R* r, const A* a, const B* b,\n                                      const C* c, cudaStream_t stream) {\n  return GenericLauncher<FactoryT, R, A, B, C>::Launch(factory, n, r, a, b, c, stream);\n}\n\ntemplate<typename FunctorT, typename R, typename A, typename B, typename C>\ninline cudaError_t Ternary(FunctorT functor, int64_t n, R* r, const A* a, const B* b, const C* c,\n                           cudaStream_t stream) {\n  return TernaryWithFactory(SimpleFactory<FunctorT>(functor), n, r, a, b, c, stream);\n}\n\n}  // namespace elementwise\n\n}  // namespace cuda\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CUDA_ELEMENTWISE_H_\n"
  },
  {
    "path": "oneflow/core/cuda/layer_norm.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_CUDA_LAYER_NORM_H_\n#define ONEFLOW_CORE_CUDA_LAYER_NORM_H_\n\n#include <cub/cub.cuh>\n#include <math_constants.h>\n#include <assert.h>\n\nnamespace oneflow {\n\nnamespace cuda {\n\nnamespace layer_norm {\n\nconstexpr int kWarpSize = 32;\n\ntemplate<typename T>\nstruct SumOp {\n  __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a + b; }\n};\n\ntemplate<typename T>\nstruct MaxOp {\n  __device__ __forceinline__ T operator()(const T& a, const T& b) const { return max(a, b); }\n};\n\ntemplate<template<typename> class ReductionOp, typename T, int thread_group_width = kWarpSize>\n__inline__ __device__ T WarpAllReduce(T val) {\n  for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {\n    val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask, thread_group_width));\n  }\n  return val;\n}\n\ntemplate<template<typename> class ReductionOp, typename T, int block_size>\n__inline__ __device__ T BlockAllReduce(T val) {\n  typedef cub::BlockReduce<T, block_size> BlockReduce;\n  __shared__ typename BlockReduce::TempStorage temp_storage;\n  __shared__ T result_broadcast;\n  T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());\n  if (threadIdx.x == 0) { result_broadcast = result; }\n  __syncthreads();\n  return result_broadcast;\n}\n\ntemplate<typename T>\n__inline__ __device__ T Div(T a, T b);\n\ntemplate<>\n__inline__ __device__ float Div<float>(float a, float b) {\n#ifdef OF_LAYER_NORM_USE_FAST_MATH\n  return __fdividef(a, b);\n#else\n  return a / b;\n#endif\n}\n\ntemplate<>\n__inline__ __device__ double Div<double>(double a, double b) {\n  return a / b;\n}\n\ntemplate<typename T>\n__inline__ __device__ T Rsqrt(T x);\n\ntemplate<>\n__inline__ __device__ float Rsqrt<float>(float x) {\n#ifdef OF_LAYER_NORM_USE_FAST_MATH\n  return __frsqrt_rn(x);\n#else\n  return rsqrt(x);\n#endif\n}\n\ntemplate<>\n__inline__ __device__ double Rsqrt<double>(double x) {\n  return rsqrt(x);\n}\n\ntemplate<class Func>\ninline cudaError_t GetNumBlocks(Func func, int64_t block_size, size_t dynamic_smem_size,\n                                int64_t max_blocks, int64_t waves, int* num_blocks) {\n  int dev;\n  {\n    cudaError_t err = cudaGetDevice(&dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  int sm_count;\n  {\n    cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  int max_active_blocks;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, func,\n                                                                    block_size, dynamic_smem_size);\n  }\n  *num_blocks =\n      std::max<int>(1, std::min<int64_t>(max_blocks, sm_count * max_active_blocks * waves));\n  return cudaSuccess;\n}\n\ntemplate<typename T>\nstruct DefaultComputeType {\n  using type = T;\n};\n\ntemplate<>\nstruct 
DefaultComputeType<half> {\n  using type = float;\n};\n\n#if CUDA_VERSION >= 11000\ntemplate<>\nstruct DefaultComputeType<nv_bfloat16> {\n  using type = float;\n};\n#endif  // CUDA_VERSION >= 11000\n\ntemplate<typename T>\nclass HasCanPackAs {\n  typedef char one;\n  struct two {\n    char x[2];\n  };\n\n  template<typename C>\n  static one test(decltype(&C::CanPackAs));\n  template<typename C>\n  static two test(...);\n\n public:\n  enum { value = sizeof(test<T>(0)) == sizeof(char) };\n};\n\ntemplate<typename T>\ntypename std::enable_if<HasCanPackAs<T>::value == true, bool>::type CanPackAs(T t,\n                                                                              size_t pack_size) {\n  return t.CanPackAs(pack_size);\n}\n\ntemplate<typename T>\ntypename std::enable_if<HasCanPackAs<T>::value == false, bool>::type CanPackAs(T t,\n                                                                               size_t pack_size) {\n  return true;\n}\n\ntemplate<typename T, int N>\nstruct GetPackType {\n  using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;\n};\n\ntemplate<typename T, int N>\nusing PackType = typename GetPackType<T, N>::type;\n\ntemplate<typename T, int N>\nunion Pack {\n  static_assert(sizeof(PackType<T, N>) == sizeof(T) * N, \"\");\n  __device__ Pack() {\n    // do nothing\n  }\n  PackType<T, N> storage;\n  T elem[N];\n};\n\ntemplate<typename SRC, typename DST>\nstruct DirectLoad {\n  using LoadType = DST;\n  DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {}\n  template<int N>\n  __device__ void load(DST* dst, int64_t row, int64_t col) const {\n    Pack<SRC, N> pack;\n    const int64_t offset = (row * row_size + col) / N;\n    pack.storage = *(reinterpret_cast<const PackType<SRC, N>*>(src) + offset);\n#pragma unroll\n    for (int i = 0; i < N; ++i) { dst[i] = static_cast<DST>(pack.elem[i]); }\n  }\n  const SRC* src;\n  int64_t row_size;\n};\n\ntemplate<typename SRC, typename DST>\nstruct DirectStore {\n  DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {}\n  template<int N>\n  __device__ void store(const SRC* src, int64_t row, int64_t col) {\n    Pack<DST, N> pack;\n    const int64_t offset = (row * row_size + col) / N;\n#pragma unroll\n    for (int i = 0; i < N; ++i) { pack.elem[i] = static_cast<DST>(src[i]); }\n    *(reinterpret_cast<PackType<DST, N>*>(dst) + offset) = pack.storage;\n  }\n  DST* dst;\n  int64_t row_size;\n};\n\ntemplate<typename T>\ninline __device__ void WelfordCombine(T val, T* mean, T* m2, T* count) {\n  // Use Welford's online algorithm to compute mean and variance\n  // For more details you can refer to:\n  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm\n  *count += 1;\n  T delta1 = val - *mean;\n  *mean += Div(delta1, *count);\n  T delta2 = val - *mean;\n  *m2 += delta1 * delta2;\n}\n\ntemplate<typename T>\ninline __device__ void WelfordCombine(T b_mean, T b_m2, T b_count, T* mean, T* m2, T* count) {\n  if (b_count == 0) { return; }\n  T new_count = *count + b_count;\n  T nb_over_n = Div(b_count, new_count);\n  T delta = b_mean - *mean;\n  *mean += delta * nb_over_n;\n  *m2 += b_m2 + delta * delta * (*count) * nb_over_n;\n  *count = new_count;\n}\n\n
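// Shuffle-based tree reduction: each __shfl_down_sync step merges two Welford\n// partial states, so after log2(thread_group_width) steps lane 0 holds the\n// statistics of the whole thread group.\ntemplate<typename T, int thread_group_width = kWarpSize>\n__inline__ __device__ void WelfordWarpReduce(T thread_mean, T thread_m2, T thread_count, T* mean,\n                                             T* m2, T* count) {\n  *mean = thread_mean;\n  *m2 = thread_m2;\n  *count = 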
thread_count;\n  for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {\n    T b_mean = __shfl_down_sync(0xffffffff, *mean, mask, thread_group_width);\n    T b_m2 = __shfl_down_sync(0xffffffff, *m2, mask, thread_group_width);\n    T b_count = __shfl_down_sync(0xffffffff, *count, mask, thread_group_width);\n    WelfordCombine(b_mean, b_m2, b_count, mean, m2, count);\n  }\n}\n\ntemplate<typename T, int thread_group_width = kWarpSize>\n__inline__ __device__ void WelfordWarpAllReduce(T thread_mean, T thread_m2, T thread_count, T* mean,\n                                                T* m2, T* count) {\n  WelfordWarpReduce<T, thread_group_width>(thread_mean, thread_m2, thread_count, mean, m2, count);\n  *mean = __shfl_sync(0xffffffff, *mean, 0, thread_group_width);\n  *m2 = __shfl_sync(0xffffffff, *m2, 0, thread_group_width);\n  *count = __shfl_sync(0xffffffff, *count, 0, thread_group_width);\n}\n\ntemplate<typename T>\n__inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T thread_count,\n                                                 T* result_mean, T* result_m2, T* result_count) {\n  __shared__ T mean_shared[kWarpSize];\n  __shared__ T m2_shared[kWarpSize];\n  __shared__ T count_shared[kWarpSize];\n  __shared__ T mean_result_broadcast;\n  __shared__ T m2_result_broadcast;\n  __shared__ T count_result_broadcast;\n  const int lid = threadIdx.x % kWarpSize;\n  const int wid = threadIdx.x / kWarpSize;\n  T warp_mean = 0;\n  T warp_m2 = 0;\n  T warp_count = 0;\n  WelfordWarpReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count);\n  __syncthreads();\n  if (lid == 0) {\n    mean_shared[wid] = warp_mean;\n    m2_shared[wid] = warp_m2;\n    count_shared[wid] = warp_count;\n  }\n  __syncthreads();\n  if (wid == 0) {\n    if (threadIdx.x < blockDim.x / kWarpSize) {\n      warp_mean = mean_shared[lid];\n      warp_m2 = m2_shared[lid];\n      warp_count = count_shared[lid];\n    } else {\n      warp_mean = static_cast<T>(0);\n      warp_m2 = static_cast<T>(0);\n      warp_count = static_cast<T>(0);\n    }\n    __syncwarp();\n    T block_mean = 0;\n    T block_m2 = 0;\n    T block_count = 0;\n    WelfordWarpReduce(warp_mean, warp_m2, warp_count, &block_mean, &block_m2, &block_count);\n    if (lid == 0) {\n      mean_result_broadcast = block_mean;\n      m2_result_broadcast = block_m2;\n      count_result_broadcast = block_count;\n    }\n  }\n  __syncthreads();\n  *result_mean = mean_result_broadcast;\n  *result_m2 = m2_result_broadcast;\n  *result_count = count_result_broadcast;\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size,\n         int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,\n         int rows_per_access, bool padding>\n__global__ void LayerNormWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols,\n                                  const double epsilon, ComputeType* mean,\n                                  ComputeType* inv_variance) {\n  using LoadType = typename LOAD::LoadType;\n  static_assert(max_cols_per_thread % pack_size == 0, \"\");\n  static_assert(min_cols_per_thread % pack_size == 0, \"\");\n  static_assert(thread_group_width <= kWarpSize, \"\");\n  static_assert(kWarpSize % thread_group_width == 0, \"\");\n  constexpr int max_num_packs = max_cols_per_thread / pack_size;\n  constexpr int min_num_packs = min_cols_per_thread / pack_size;\n  assert(cols <= max_cols_per_thread * thread_group_width);\n  ComputeType 
buf[rows_per_access][max_cols_per_thread];\n  const int64_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;\n  const int64_t num_global_thread_group = gridDim.x * blockDim.y;\n  const int64_t lane_id = threadIdx.x;\n  const int64_t step = num_global_thread_group * rows_per_access;\n  for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) {\n    ComputeType thread_mean[rows_per_access];\n    ComputeType thread_m2[rows_per_access];\n    ComputeType thread_count[rows_per_access];\n#pragma unroll\n    for (int row_id = 0; row_id < rows_per_access; ++row_id) {\n      thread_mean[row_id] = 0;\n      thread_m2[row_id] = 0;\n      thread_count[row_id] = 0;\n      ComputeType* row_buf = buf[row_id];\n#pragma unroll\n      for (int pack_id = 0; pack_id < min_num_packs; ++pack_id) {\n        const int col = (pack_id * thread_group_width + lane_id) * pack_size;\n        const int pack_offset = pack_id * pack_size;\n        LoadType pack[pack_size];\n        load.template load<pack_size>(pack, row + row_id, col);\n#pragma unroll\n        for (int i = 0; i < pack_size; ++i) {\n          row_buf[pack_offset + i] = static_cast<ComputeType>(pack[i]);\n          WelfordCombine(row_buf[pack_offset + i], thread_mean + row_id, thread_m2 + row_id,\n                         thread_count + row_id);\n        }\n      }\n      for (int pack_id = min_num_packs; pack_id < max_num_packs; ++pack_id) {\n        const int col = (pack_id * thread_group_width + lane_id) * pack_size;\n        const int pack_offset = pack_id * pack_size;\n        if (!padding || col < cols) {\n          LoadType pack[pack_size];\n          load.template load<pack_size>(pack, row + row_id, col);\n#pragma unroll\n          for (int i = 0; i < pack_size; ++i) {\n            row_buf[pack_offset + i] = static_cast<ComputeType>(pack[i]);\n            WelfordCombine(row_buf[pack_offset + i], thread_mean + row_id, thread_m2 + row_id,\n                           thread_count + row_id);\n          }\n        } else {\n#pragma unroll\n          for (int i = 0; i < pack_size; ++i) { row_buf[pack_offset + i] = 0; }\n        }\n      }\n    }\n    ComputeType warp_mean[rows_per_access];\n    ComputeType warp_m2[rows_per_access];\n    ComputeType warp_count[rows_per_access];\n#pragma unroll\n    for (int row_id = 0; row_id < rows_per_access; ++row_id) {\n      int global_row_id = row + row_id;\n      ComputeType* row_buf = buf[row_id];\n      WelfordWarpAllReduce<ComputeType, thread_group_width>(\n          thread_mean[row_id], thread_m2[row_id], thread_count[row_id], warp_mean + row_id,\n          warp_m2 + row_id, warp_count + row_id);\n      ComputeType row_mean = warp_mean[row_id];\n      ComputeType row_variance =\n          max(Div(warp_m2[row_id], warp_count[row_id]), static_cast<ComputeType>(0.0));\n      ComputeType row_inv_var = Rsqrt(row_variance + static_cast<ComputeType>(epsilon));\n      if (lane_id == 0) {\n        mean[global_row_id] = row_mean;\n        inv_variance[global_row_id] = row_inv_var;\n      }\n#pragma unroll\n      for (int i = 0; i < max_cols_per_thread; ++i) {\n        row_buf[i] = (row_buf[i] - row_mean) * row_inv_var;\n      }\n#pragma unroll\n      for (int i = 0; i < min_num_packs; ++i) {\n        const int col = (i * thread_group_width + lane_id) * pack_size;\n        store.template store<pack_size>(row_buf + i * pack_size, global_row_id, col);\n      }\n#pragma unroll\n      for (int i = min_num_packs; i < max_num_packs; ++i) {\n        const int col = (i * 
thread_group_width + lane_id) * pack_size;\n        if (!padding || col < cols) {\n          store.template store<pack_size>(row_buf + i * pack_size, global_row_id, col);\n        }\n      }\n    }\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size,\n         int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,\n         int rows_per_access, bool padding>\ninline cudaError_t LaunchLayerNormWarpImpl(cudaStream_t stream, LOAD load, STORE store,\n                                           const int64_t rows, const int64_t cols,\n                                           const double epsilon, ComputeType* mean,\n                                           ComputeType* inv_variance) {\n  constexpr int block_size = 128;\n  constexpr int waves = 32;\n  static_assert(block_size % thread_group_width == 0, \"\");\n  constexpr int thread_groups_per_block = block_size / thread_group_width;\n  dim3 block_dim(thread_group_width, thread_groups_per_block);\n  const int64_t num_blocks =\n      (rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;\n  int grid_dim_x;\n  {\n    cudaError_t err = GetNumBlocks(\n        LayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread,\n                          min_cols_per_thread, thread_group_width, rows_per_access, padding>,\n        block_size, 0, num_blocks, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  LayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread, min_cols_per_thread,\n                    thread_group_width, rows_per_access, padding>\n      <<<grid_dim_x, block_dim, 0, stream>>>(load, store, rows, cols, epsilon, mean, inv_variance);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size,\n         int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,\n         int rows_per_access>\ninline cudaError_t DispatchLayerNormWarpImplPadding(cudaStream_t stream, LOAD load, STORE store,\n                                                    const int64_t rows, const int64_t cols,\n                                                    const double epsilon, ComputeType* mean,\n                                                    ComputeType* inv_variance) {\n  if (cols == max_cols_per_thread * thread_group_width) {\n    // When not padding, min_cols_per_thread must equal max_cols_per_thread, so pass\n    // max_cols_per_thread as both the min_cols_per_thread and max_cols_per_thread parameters.\n    return LaunchLayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread,\n                                   max_cols_per_thread, thread_group_width, rows_per_access, false>(\n        stream, load, store, rows, cols, epsilon, mean, inv_variance);\n  } else {\n    return LaunchLayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread,\n                                   min_cols_per_thread, thread_group_width, rows_per_access, true>(\n        stream, load, store, rows, cols, epsilon, mean, inv_variance);\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size>\ntypename std::enable_if<pack_size == 1, cudaError_t>::type DispatchLayerNormWarpImplCols(\n    cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,\n    const double epsilon, ComputeType* mean, ComputeType* inv_variance) {\n  if (cols <= 0) { return cudaErrorInvalidValue; }\n
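// The dispatch ladder below picks the narrowest thread group that still covers\n// cols; when rows is even, rows_per_access = 2 lets one thread group handle\n// two rows per iteration.\n#define \\\n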
DEFINE_ONE_ELIF(thread_group_width)                                                      \\\n  else if (cols <= (thread_group_width)*pack_size) {                                             \\\n    if (rows % 2 == 0) {                                                                         \\\n      return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, 0, \\\n                                              thread_group_width, 2>(                            \\\n          stream, load, store, rows, cols, epsilon, mean, inv_variance);                         \\\n    } else {                                                                                     \\\n      return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, 0, \\\n                                              thread_group_width, 1>(                            \\\n          stream, load, store, rows, cols, epsilon, mean, inv_variance);                         \\\n    }                                                                                            \\\n  }\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n#define DEFINE_ONE_ELIF(max_col, min_col)                                                          \\\n  else if (cols <= (max_col)*kWarpSize) {                                                          \\\n    return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, max_col, min_col, \\\n                                            kWarpSize, 1>(stream, load, store, rows, cols,         \\\n                                                          epsilon, mean, inv_variance);            \\\n  }\n  DEFINE_ONE_ELIF(2, 1)\n  DEFINE_ONE_ELIF(4, 2)\n  DEFINE_ONE_ELIF(8, 4)\n  DEFINE_ONE_ELIF(12, 8)\n  DEFINE_ONE_ELIF(16, 12)\n  DEFINE_ONE_ELIF(20, 16)\n  DEFINE_ONE_ELIF(24, 20)\n  DEFINE_ONE_ELIF(28, 24)\n  DEFINE_ONE_ELIF(32, 28)\n#undef DEFINE_ONE_ELIF\n  else {\n    return cudaErrorInvalidValue;\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size>\ntypename std::enable_if<pack_size == 2, cudaError_t>::type DispatchLayerNormWarpImplCols(\n    cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,\n    const double epsilon, ComputeType* mean, ComputeType* inv_variance) {\n  if (cols <= 0) { return cudaErrorInvalidValue; }\n#define DEFINE_ONE_ELIF(thread_group_width)                                                      \\\n  else if (cols <= (thread_group_width)*pack_size) {                                             \\\n    if (rows % 2 == 0) {                                                                         \\\n      return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, 0, \\\n                                              thread_group_width, 2>(                            \\\n          stream, load, store, rows, cols, epsilon, mean, inv_variance);                         \\\n    } else {                                                                                     \\\n      return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, 0, \\\n                                              thread_group_width, 1>(                            \\\n          stream, load, store, rows, cols, epsilon, mean, inv_variance);                         \\\n    }                                                                                            \\\n 
 }\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n#define DEFINE_ONE_ELIF(max_col, min_col)                                                          \\\n  else if ((cols <= (max_col)*kWarpSize) && (cols > (min_col)*kWarpSize)) {                        \\\n    return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, max_col, min_col, \\\n                                            kWarpSize, 1>(stream, load, store, rows, cols,         \\\n                                                          epsilon, mean, inv_variance);            \\\n  }\n  DEFINE_ONE_ELIF(4, 2)\n  DEFINE_ONE_ELIF(8, 4)\n  DEFINE_ONE_ELIF(12, 8)\n  DEFINE_ONE_ELIF(16, 12)\n  DEFINE_ONE_ELIF(20, 16)\n  DEFINE_ONE_ELIF(24, 20)\n  DEFINE_ONE_ELIF(28, 24)\n  DEFINE_ONE_ELIF(32, 28)\n#undef DEFINE_ONE_ELIF\n  else {\n    return cudaErrorInvalidValue;\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\nstruct DispatchLayerNormWarpImplPackSize {\n  cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,\n                         const int64_t cols, const double epsilon, ComputeType* mean,\n                         ComputeType* inv_variance) {\n    if (cols % 2 == 0 && CanPackAs<LOAD>(load, 2) && CanPackAs<STORE>(store, 2)) {\n      return DispatchLayerNormWarpImplCols<LOAD, STORE, ComputeType, 2>(\n          stream, load, store, rows, cols, epsilon, mean, inv_variance);\n    } else {\n      return DispatchLayerNormWarpImplCols<LOAD, STORE, ComputeType, 1>(\n          stream, load, store, rows, cols, epsilon, mean, inv_variance);\n    }\n  }\n};\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ninline cudaError_t DispatchLayerNormWarpImpl(cudaStream_t stream, LOAD load, STORE store,\n                                             const int64_t rows, const int64_t cols,\n                                             const double epsilon, ComputeType* mean,\n                                             ComputeType* inv_variance) {\n  return DispatchLayerNormWarpImplPackSize<LOAD, STORE, ComputeType>()(\n      stream, load, store, rows, cols, epsilon, mean, inv_variance);\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>\n__global__ void LayerNormBlockSMemImpl(LOAD load, STORE store, const int64_t rows,\n                                       const int64_t cols, const double epsilon, ComputeType* mean,\n                                       ComputeType* inv_variance) {\n  using LoadType = typename LOAD::LoadType;\n  extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];\n  auto* buf = reinterpret_cast<LoadType*>(shared_buf);\n  const int tid = threadIdx.x;\n  assert(cols % pack_size == 0);\n  const int num_packs = static_cast<int>(cols) / pack_size;\n  for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {\n    ComputeType thread_mean = 0;\n    ComputeType thread_m2 = 0;\n    ComputeType thread_count = 0;\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      LoadType pack[pack_size];\n      load.template load<pack_size>(pack, row, pack_id * pack_size);\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) {\n        buf[i * num_packs + pack_id] = pack[i];\n        WelfordCombine(static_cast<ComputeType>(pack[i]), &thread_mean, &thread_m2, &thread_count);\n      }\n    }\n    ComputeType row_mean = 0;\n    ComputeType row_m2 = 0;\n    ComputeType row_count = 0;\n    
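// Merge the per-thread Welford partials across the whole block; afterwards\n    // every thread sees the complete statistics for this row.\n    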
WelfordBlockAllReduce<ComputeType>(thread_mean, thread_m2, thread_count, &row_mean, &row_m2,\n                                       &row_count);\n    ComputeType row_variance = max(Div(row_m2, row_count), static_cast<ComputeType>(0.0));\n    ComputeType row_inv_var = Rsqrt(row_variance + static_cast<ComputeType>(epsilon));\n    if (threadIdx.x == 0) {\n      mean[row] = row_mean;\n      inv_variance[row] = row_inv_var;\n    }\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      ComputeType pack[pack_size];\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) {\n        pack[i] = (static_cast<ComputeType>(buf[i * num_packs + pack_id]) - row_mean) * row_inv_var;\n      }\n      store.template store<pack_size>(pack, row, pack_id * pack_size);\n    }\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>\ninline cudaError_t LaunchLayerNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store,\n                                                int smem, const int64_t rows, const int64_t cols,\n                                                const double epsilon, ComputeType* mean,\n                                                ComputeType* inv_variance) {\n  constexpr int waves = 32;\n  int grid_dim_x;\n  {\n    cudaError_t err =\n        GetNumBlocks(LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size>,\n                     block_size, smem, rows, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size>\n      <<<grid_dim_x, block_size, smem, stream>>>(load, store, rows, cols, epsilon, mean,\n                                                 inv_variance);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename Func>\ncudaError_t MaximizeDynamicSharedMemorySize(Func func, const int max_smem_size) {\n  cudaFuncAttributes attr{};\n  cudaError_t err = cudaFuncGetAttributes(&attr, func);\n  if (err != cudaSuccess) { return err; }\n  constexpr int reserved_smem = 1024;  // 1K\n  return cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize,\n                              max_smem_size - attr.sharedSizeBytes - reserved_smem);\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size>\ninline cudaError_t TryDispatchLayerNormBlockSMemImplBlockSize(\n    cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,\n    const double epsilon, ComputeType* mean, ComputeType* inv_variance, bool* success) {\n  constexpr int block_size_conf_1 = 128;\n  constexpr int block_size_conf_2 = 256;\n  constexpr int block_size_conf_3 = 512;\n  constexpr int block_size_conf_4 = 1024;\n\n  int dev = 0;\n  {\n    cudaError_t err = cudaGetDevice(&dev);\n    if (err != cudaSuccess) { return err; }\n  }\n\n  int sm_count = 0;\n  {\n    cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);\n    if (err != cudaSuccess) { return err; }\n  }\n\n  static const bool max_smem_configed = [=]() {\n    int max_smem_size = 0;\n    cudaError_t err =\n        cudaDeviceGetAttribute(&max_smem_size, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);\n    if (err != cudaSuccess) { return false; }\n\n    err = MaximizeDynamicSharedMemorySize(\n        LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1>,\n        max_smem_size);\n    if (err != cudaSuccess) { return false; }\n    err = MaximizeDynamicSharedMemorySize(\n       
 LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2>,\n        max_smem_size);\n    if (err != cudaSuccess) { return false; }\n    err = MaximizeDynamicSharedMemorySize(\n        LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3>,\n        max_smem_size);\n    if (err != cudaSuccess) { return false; }\n    err = MaximizeDynamicSharedMemorySize(\n        LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4>,\n        max_smem_size);\n    if (err != cudaSuccess) { return false; }\n\n    return true;\n  }();\n\n  const size_t smem = cols * sizeof(typename LOAD::LoadType);\n\n  int max_active_blocks_conf_1;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks_conf_1,\n        LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1>,\n        block_size_conf_1, smem);\n    if (err != cudaSuccess) { return err; }\n  }\n  if (max_active_blocks_conf_1 <= 0) {\n    *success = false;\n    return cudaSuccess;\n  }\n\n  int max_active_blocks_conf_4;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks_conf_4,\n        LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4>,\n        block_size_conf_4, smem);\n    if (err != cudaSuccess) { return err; }\n  }\n\n  if (max_active_blocks_conf_4 == max_active_blocks_conf_1\n      || (max_active_blocks_conf_4 > 0 && rows <= sm_count)) {\n    *success = true;\n    return LaunchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4>(\n        stream, load, store, smem, rows, cols, epsilon, mean, inv_variance);\n  }\n\n  int max_active_blocks_conf_3;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks_conf_3,\n        LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3>,\n        block_size_conf_3, smem);\n    if (err != cudaSuccess) { return err; }\n  }\n  if (max_active_blocks_conf_3 == max_active_blocks_conf_1\n      || (max_active_blocks_conf_3 > 0 && rows <= sm_count)) {\n    *success = true;\n    return LaunchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3>(\n        stream, load, store, smem, rows, cols, epsilon, mean, inv_variance);\n  }\n\n  int max_active_blocks_conf_2;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks_conf_2,\n        LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2>,\n        block_size_conf_2, smem);\n    if (err != cudaSuccess) { return err; }\n  }\n  if (max_active_blocks_conf_2 == max_active_blocks_conf_1\n      || (max_active_blocks_conf_2 > 0 && rows <= sm_count)) {\n    *success = true;\n    return LaunchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2>(\n        stream, load, store, smem, rows, cols, epsilon, mean, inv_variance);\n  }\n\n  *success = true;\n  return LaunchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1>(\n      stream, load, store, smem, rows, cols, epsilon, mean, inv_variance);\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\nstruct TryDispatchLayerNormBlockSMemImplPackSize {\n  cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,\n                         const int64_t cols, const double epsilon, ComputeType* mean,\n                         ComputeType* 
inv_variance, bool* success) {\n    if (cols % 4 == 0 && CanPackAs<LOAD>(load, 4) && CanPackAs<STORE>(store, 4)) {\n      return TryDispatchLayerNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 4>(\n          stream, load, store, rows, cols, epsilon, mean, inv_variance, success);\n    } else if (cols % 2 == 0 && CanPackAs<LOAD>(load, 2) && CanPackAs<STORE>(store, 2)) {\n      return TryDispatchLayerNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 2>(\n          stream, load, store, rows, cols, epsilon, mean, inv_variance, success);\n    } else {\n      return TryDispatchLayerNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 1>(\n          stream, load, store, rows, cols, epsilon, mean, inv_variance, success);\n    }\n  }\n};\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ninline cudaError_t TryDispatchLayerNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store,\n                                                     const int64_t rows, const int64_t cols,\n                                                     const double epsilon, ComputeType* mean,\n                                                     ComputeType* inv_variance, bool* success) {\n  return TryDispatchLayerNormBlockSMemImplPackSize<LOAD, STORE, ComputeType>()(\n      stream, load, store, rows, cols, epsilon, mean, inv_variance, success);\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>\n__global__ void __launch_bounds__(1024)\n    LayerNormBlockUncachedImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols,\n                               const double epsilon, ComputeType* mean, ComputeType* inv_variance) {\n  using LoadType = typename LOAD::LoadType;\n  const int tid = threadIdx.x;\n  assert(cols % pack_size == 0);\n  const int num_packs = static_cast<int>(cols) / pack_size;\n  for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {\n    ComputeType thread_mean = 0;\n    ComputeType thread_m2 = 0;\n    ComputeType thread_count = 0;\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      LoadType pack[pack_size];\n      load.template load<pack_size>(pack, row, pack_id * pack_size);\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) {\n        WelfordCombine(static_cast<ComputeType>(pack[i]), &thread_mean, &thread_m2, &thread_count);\n      }\n    }\n    ComputeType row_mean = 0;\n    ComputeType row_m2 = 0;\n    ComputeType row_count = 0;\n    WelfordBlockAllReduce<ComputeType>(thread_mean, thread_m2, thread_count, &row_mean, &row_m2,\n                                       &row_count);\n    ComputeType row_variance = max(Div(row_m2, row_count), static_cast<ComputeType>(0.0));\n    ComputeType row_inv_var = Rsqrt(row_variance + static_cast<ComputeType>(epsilon));\n    if (threadIdx.x == 0) {\n      mean[row] = row_mean;\n      inv_variance[row] = row_inv_var;\n    }\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      LoadType pack[pack_size];\n      ComputeType dst_pack[pack_size];\n      const int pack_offset = pack_id * pack_size;\n      load.template load<pack_size>(pack, row, pack_offset);\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) {\n        dst_pack[i] = (static_cast<ComputeType>(pack[i]) - row_mean) * row_inv_var;\n      }\n      store.template store<pack_size>(dst_pack, row, pack_offset);\n    }\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size>\ninline cudaError_t 
LaunchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,\n                                                    const int64_t rows, const int64_t cols,\n                                                    const double epsilon, ComputeType* mean,\n                                                    ComputeType* inv_variance) {\n  constexpr int block_size = 1024;\n  constexpr int waves = 32;\n  int grid_dim_x;\n  {\n    cudaError_t err =\n        GetNumBlocks(LayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size>,\n                     block_size, 0, rows, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  LayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size>\n      <<<grid_dim_x, block_size, 0, stream>>>(load, store, rows, cols, epsilon, mean, inv_variance);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\nstruct DispatchLayerNormBlockUncachedImplPackSize {\n  cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,\n                         const int64_t cols, const double epsilon, ComputeType* mean,\n                         ComputeType* inv_variance) {\n    if (cols % 4 == 0 && CanPackAs<LOAD>(load, 4) && CanPackAs<STORE>(store, 4)) {\n      return LaunchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, 4>(\n          stream, load, store, rows, cols, epsilon, mean, inv_variance);\n    } else if (cols % 2 == 0 && CanPackAs<LOAD>(load, 2) && CanPackAs<STORE>(store, 2)) {\n      return LaunchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, 2>(\n          stream, load, store, rows, cols, epsilon, mean, inv_variance);\n    } else {\n      return LaunchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, 1>(\n          stream, load, store, rows, cols, epsilon, mean, inv_variance);\n    }\n  }\n};\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ninline cudaError_t DispatchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,\n                                                      const int64_t rows, const int64_t cols,\n                                                      const double epsilon, ComputeType* mean,\n                                                      ComputeType* inv_variance) {\n  return DispatchLayerNormBlockUncachedImplPackSize<LOAD, STORE, ComputeType>()(\n      stream, load, store, rows, cols, epsilon, mean, inv_variance);\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ninline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type\nDispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,\n                  const int64_t cols, const double epsilon, ComputeType* mean,\n                  ComputeType* inv_variance) {\n  if (cols <= 1024) {\n    return DispatchLayerNormWarpImpl<LOAD, STORE, ComputeType>(stream, load, store, rows, cols,\n                                                               epsilon, mean, inv_variance);\n  } else {\n    bool dispatch_smem_impl_success;\n    {\n      cudaError_t err = TryDispatchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType>(\n          stream, load, store, rows, cols, epsilon, mean, inv_variance,\n          &dispatch_smem_impl_success);\n      if (err != cudaSuccess) { return err; }\n    }\n    if (!dispatch_smem_impl_success) {\n      return DispatchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType>(\n          stream, load, store, rows, cols, 
epsilon, mean, inv_variance);\n    }\n    return cudaSuccess;\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ninline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type\nDispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,\n                  const int64_t cols, const double epsilon, ComputeType* mean,\n                  ComputeType* inv_variance) {\n  return DispatchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType>(\n      stream, load, store, rows, cols, epsilon, mean, inv_variance);\n}\n\n/*\nLayerNormGrad dx:\nnormalized = (x - mean) * inv_var\nsum_stats1 = sum(scaled_dy)\nsum_stats2 = sum(scaled_dy * normalized)\ndx = cols * scaled_dy - sum_stats1 - normalized * sum_stats2\ndx *= inv_var / cols\n*/\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,\n         int pack_size, int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,\n         int rows_per_access>\n__global__ void LayerNormGradWarpImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,\n                                      const ComputeType* mean, const ComputeType* inv_variance,\n                                      const int64_t rows, const int64_t cols) {\n  using LoadTypeX = typename LOAD_X::LoadType;\n  using LoadTypeDy = typename LOAD_SCALED_DY::LoadType;\n  static_assert(max_cols_per_thread % pack_size == 0, \"\");\n  static_assert(min_cols_per_thread % pack_size == 0, \"\");\n  constexpr int max_num_packs = max_cols_per_thread / pack_size;\n  constexpr int min_num_packs = min_cols_per_thread / pack_size;\n  assert(cols <= max_cols_per_thread * thread_group_width);\n  static_assert(thread_group_width <= kWarpSize, \"\");\n  static_assert(kWarpSize % thread_group_width == 0, \"\");\n  ComputeType normalized_buf[rows_per_access][max_cols_per_thread];\n  ComputeType dy_buf[rows_per_access][max_cols_per_thread];\n  const ComputeType one_over_cols = static_cast<ComputeType>(1.0) / static_cast<ComputeType>(cols);\n  const int64_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;\n  const int64_t num_global_thread_group = gridDim.x * blockDim.y;\n  const int lane_id = threadIdx.x;\n  const int64_t step = num_global_thread_group * rows_per_access;\n  for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) {\n    ComputeType sum_stats1[rows_per_access];\n    ComputeType sum_stats2[rows_per_access];\n    ComputeType inv_variance_buf[rows_per_access];\n#pragma unroll\n    for (int row_id = 0; row_id < rows_per_access; ++row_id) {\n      const int global_row_id = row + row_id;\n      ComputeType mean_val = mean[global_row_id];\n      inv_variance_buf[row_id] = inv_variance[global_row_id];\n      sum_stats1[row_id] = 0;\n      sum_stats2[row_id] = 0;\n      ComputeType* row_normalized_buf = normalized_buf[row_id];\n      ComputeType* row_dy_buf = dy_buf[row_id];\n#pragma unroll\n      for (int pack_id = 0; pack_id < min_num_packs; ++pack_id) {\n        const int col = (pack_id * thread_group_width + lane_id) * pack_size;\n        const int pack_offset = pack_id * pack_size;\n        LoadTypeX pack_x[pack_size];\n        LoadTypeDy pack_dy[pack_size];\n        load_x.template load<pack_size>(pack_x, global_row_id, col);\n        load_scaled_dy.template load<pack_size>(pack_dy, global_row_id, col);\n#pragma unroll\n        for (int i = 0; i < pack_size; ++i) {\n          const int col_id = pack_offset + i;\n          // row_normalized_buf stores the normalized value of x\n          row_normalized_buf[col_id] =\n              (static_cast<ComputeType>(pack_x[i]) - mean_val) * inv_variance_buf[row_id];\n          row_dy_buf[col_id] = static_cast<ComputeType>(pack_dy[i]);\n          sum_stats1[row_id] += row_dy_buf[col_id];\n          sum_stats2[row_id] += row_dy_buf[col_id] * row_normalized_buf[col_id];\n        }\n      }\n#pragma unroll\n      for (int pack_id = min_num_packs; pack_id < max_num_packs; ++pack_id) {\n        const int col = (pack_id * thread_group_width + lane_id) * pack_size;\n        const int pack_offset = pack_id * pack_size;\n        if (col < cols) {\n          LoadTypeX pack_x[pack_size];\n          LoadTypeDy pack_dy[pack_size];\n          load_x.template load<pack_size>(pack_x, global_row_id, col);\n          load_scaled_dy.template load<pack_size>(pack_dy, global_row_id, col);\n#pragma unroll\n          for (int i = 0; i < pack_size; ++i) {\n            const int col_id = pack_offset + i;\n            // row_normalized_buf stores the normalized value of x\n            row_normalized_buf[col_id] =\n                (static_cast<ComputeType>(pack_x[i]) - mean_val) * inv_variance_buf[row_id];\n            row_dy_buf[col_id] = static_cast<ComputeType>(pack_dy[i]);\n            sum_stats1[row_id] += row_dy_buf[col_id];\n            sum_stats2[row_id] += row_dy_buf[col_id] * row_normalized_buf[col_id];\n          }\n        }\n      }\n    }\n    ComputeType warp_sum_stats1[rows_per_access];\n    ComputeType warp_sum_stats2[rows_per_access];\n#pragma unroll\n    for (int row_id = 0; row_id < rows_per_access; ++row_id) {\n      warp_sum_stats1[row_id] =\n          WarpAllReduce<SumOp, ComputeType, thread_group_width>(sum_stats1[row_id]);\n      warp_sum_stats2[row_id] =\n          WarpAllReduce<SumOp, ComputeType, thread_group_width>(sum_stats2[row_id]);\n    }\n#pragma unroll\n    for (int row_id = 0; row_id < rows_per_access; ++row_id) {\n      const int global_row_id = row + row_id;\n      ComputeType* row_normalized_buf = normalized_buf[row_id];\n      ComputeType* row_dy_buf = dy_buf[row_id];\n      const ComputeType inv_variance_over_cols = inv_variance_buf[row_id] * one_over_cols;\n#pragma unroll\n      for (int pack_id = 0; pack_id < min_num_packs; ++pack_id) {\n        const int col = (pack_id * thread_group_width + lane_id) * pack_size;\n        const int pack_offset = pack_id * pack_size;\n        for (int i = 0; i < pack_size; ++i) {\n          const int col_id = pack_offset + i;\n          row_dy_buf[col_id] = (cols * row_dy_buf[col_id] - warp_sum_stats1[row_id]\n                                - row_normalized_buf[col_id] * warp_sum_stats2[row_id])\n                               * inv_variance_over_cols;\n        }\n        store.template store<pack_size>(row_dy_buf + pack_offset, global_row_id, col);\n      }\n#pragma unroll\n      for (int pack_id = min_num_packs; pack_id < max_num_packs; ++pack_id) {\n        const int col = (pack_id * thread_group_width + lane_id) * pack_size;\n        if (col < cols) {\n          const int pack_offset = pack_id * pack_size;\n          for (int i = 0; i < pack_size; ++i) {\n            const int col_id = pack_offset + i;\n            row_dy_buf[col_id] = (cols * row_dy_buf[col_id] - warp_sum_stats1[row_id]\n                                  - row_normalized_buf[col_id] * warp_sum_stats2[row_id])\n                                 * inv_variance_over_cols;\n          }\n          store.template store<pack_size>(row_dy_buf + pack_offset, global_row_id, col);\n        }\n      }\n    }\n  
}\n}\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,\n         int pack_size, int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,\n         int rows_per_access>\ninline cudaError_t LaunchLayerNormGradWarpImpl(cudaStream_t stream, LOAD_X load_x,\n                                               LOAD_SCALED_DY load_scaled_dy, STORE store,\n                                               const ComputeType* mean,\n                                               const ComputeType* inv_variance, const int64_t rows,\n                                               const int64_t cols) {\n  constexpr int block_size = 128;\n  constexpr int waves = 32;\n  static_assert(block_size % thread_group_width == 0, \"\");\n  constexpr int thread_groups_per_block = block_size / thread_group_width;\n  dim3 block_dim(thread_group_width, thread_groups_per_block);\n  const int64_t num_blocks =\n      (rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;\n  int grid_dim_x;\n  {\n    cudaError_t err =\n        GetNumBlocks(LayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                           max_cols_per_thread, min_cols_per_thread,\n                                           thread_group_width, rows_per_access>,\n                     block_size, 0, num_blocks, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  LayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size, max_cols_per_thread,\n                        min_cols_per_thread, thread_group_width, rows_per_access>\n      <<<grid_dim_x, block_dim, 0, stream>>>(load_x, load_scaled_dy, store, mean, inv_variance,\n                                             rows, cols);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,\n         int pack_size, int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,\n         int rows_per_access>\ninline cudaError_t DispatchLayerNormGradWarpImplPadding(cudaStream_t stream, LOAD_X load_x,\n                                                        LOAD_SCALED_DY load_scaled_dy, STORE store,\n                                                        const ComputeType* mean,\n                                                        const ComputeType* inv_variance,\n                                                        const int64_t rows, const int64_t cols) {\n  if (cols == max_cols_per_thread * thread_group_width) {\n    // When not padding, min_cols_per_thread must equal max_cols_per_thread, so pass\n    // max_cols_per_thread as both the min_cols_per_thread and max_cols_per_thread parameters.\n    return LaunchLayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                       max_cols_per_thread, max_cols_per_thread, thread_group_width,\n                                       rows_per_access>(stream, load_x, load_scaled_dy, store, mean,\n                                                        inv_variance, rows, cols);\n  } else {\n    return LaunchLayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                       max_cols_per_thread, min_cols_per_thread, thread_group_width,\n                                       rows_per_access>(stream, load_x, load_scaled_dy, store, mean,\n                                                        inv_variance, rows, cols);\n  
}\n}\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,\n         int pack_size>\ntypename std::enable_if<pack_size == 1, cudaError_t>::type DispatchLayerNormGradWarpImplCols(\n    cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,\n    const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows,\n    const int64_t cols) {\n  if (cols <= 0) { return cudaErrorInvalidValue; }\n#define DEFINE_ONE_ELIF(thread_group_width)                                                        \\\n  else if (cols <= (thread_group_width)*pack_size) {                                               \\\n    if (rows % 2 == 0) {                                                                           \\\n      return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,      \\\n                                                  pack_size, pack_size, 0, thread_group_width, 2>( \\\n          stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);                  \\\n    } else {                                                                                       \\\n      return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,      \\\n                                                  pack_size, pack_size, 0, thread_group_width, 1>( \\\n          stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);                  \\\n    }                                                                                              \\\n  }\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n#define DEFINE_ONE_ELIF(max_col, min_col)                                                   \\\n  else if (cols <= (max_col)*kWarpSize) {                                                   \\\n    return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \\\n                                                pack_size, max_col, min_col, kWarpSize, 1>( \\\n        stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);             \\\n  }\n  DEFINE_ONE_ELIF(2, 1)\n  DEFINE_ONE_ELIF(4, 2)\n  DEFINE_ONE_ELIF(8, 4)\n  DEFINE_ONE_ELIF(12, 8)\n  DEFINE_ONE_ELIF(16, 12)\n  DEFINE_ONE_ELIF(20, 16)\n  DEFINE_ONE_ELIF(24, 20)\n  DEFINE_ONE_ELIF(28, 24)\n  DEFINE_ONE_ELIF(32, 28)\n#undef DEFINE_ONE_ELIF\n  else {\n    return cudaErrorInvalidValue;\n  }\n}\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>\nstruct DispatchLayerNormGradWarpImplPackSize {\n  cudaError_t operator()(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,\n                         STORE store, const ComputeType* mean, const ComputeType* inv_variance,\n                         const int64_t rows, const int64_t cols) {\n    return DispatchLayerNormGradWarpImplCols<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, 1>(\n        stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);\n  }\n};\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>\ninline cudaError_t DispatchLayerNormGradWarpImpl(cudaStream_t stream, LOAD_X load_x,\n                                                 LOAD_SCALED_DY load_scaled_dy, STORE store,\n                                                 const ComputeType* mean,\n                                                 const ComputeType* inv_variance,\n                            
                     const int64_t rows, const int64_t cols) {\n  return DispatchLayerNormGradWarpImplPackSize<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType>()(\n      stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);\n}\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,\n         int pack_size, int block_size>\n__global__ void LayerNormGradBlockSMemImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,\n                                           STORE store, const ComputeType* mean,\n                                           const ComputeType* inv_variance, const int64_t rows,\n                                           const int64_t cols) {\n  using LoadTypeX = typename LOAD_X::LoadType;\n  using LoadTypeDy = typename LOAD_SCALED_DY::LoadType;\n  extern __shared__ __align__(sizeof(double)) unsigned char grad_shared_buf[];\n  auto* normalized_buf = reinterpret_cast<LoadTypeX*>(grad_shared_buf);\n  auto* dy_buf = reinterpret_cast<LoadTypeDy*>(normalized_buf + cols);\n  const int tid = threadIdx.x;\n  assert(cols % pack_size == 0);\n  const int num_packs = static_cast<int>(cols) / pack_size;\n  const ComputeType one_over_cols = static_cast<ComputeType>(1.0) / static_cast<ComputeType>(cols);\n  for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {\n    ComputeType sum_stats1 = 0;\n    ComputeType sum_stats2 = 0;\n    const ComputeType mean_val = mean[row];\n    const ComputeType inv_variance_val = inv_variance[row];\n    const ComputeType inv_variance_over_cols = inv_variance_val * one_over_cols;\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      LoadTypeX x_pack[pack_size];\n      LoadTypeDy dy_pack[pack_size];\n      load_x.template load<pack_size>(x_pack, row, pack_id * pack_size);\n      load_scaled_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) {\n        const int buf_offset = i * num_packs + pack_id;\n        ComputeType normalized =\n            (static_cast<ComputeType>(x_pack[i]) - mean_val) * inv_variance_val;\n        normalized_buf[buf_offset] = static_cast<LoadTypeX>(normalized);\n        dy_buf[buf_offset] = dy_pack[i];\n        sum_stats1 += static_cast<ComputeType>(dy_pack[i]);\n        sum_stats2 += static_cast<ComputeType>(dy_pack[i]) * normalized;\n      }\n    }\n    const ComputeType row_sum_stats1 = BlockAllReduce<SumOp, ComputeType, block_size>(sum_stats1);\n    const ComputeType row_sum_stats2 = BlockAllReduce<SumOp, ComputeType, block_size>(sum_stats2);\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      ComputeType pack[pack_size];\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) {\n        const int buf_offset = i * num_packs + pack_id;\n        pack[i] = (cols * static_cast<ComputeType>(dy_buf[buf_offset]) - row_sum_stats1\n                   - static_cast<ComputeType>(normalized_buf[buf_offset]) * row_sum_stats2)\n                  * inv_variance_over_cols;\n      }\n      store.template store<pack_size>(pack, row, pack_id * pack_size);\n    }\n  }\n}\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,\n         int pack_size, int block_size>\ninline cudaError_t LaunchLayerNormGradBlockSMemImpl(cudaStream_t stream, LOAD_X load_x,\n                                                    LOAD_SCALED_DY load_scaled_dy, STORE store,\n                                                    const ComputeType* mean,\n            
                                        const ComputeType* inv_variance, int smem,\n                                                    const int64_t rows, const int64_t cols) {\n  constexpr int waves = 32;\n  int grid_dim_x;\n  {\n    cudaError_t err = GetNumBlocks(LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE,\n                                                              ComputeType, pack_size, block_size>,\n                                   block_size, smem, rows, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size, block_size>\n      <<<grid_dim_x, block_size, smem, stream>>>(load_x, load_scaled_dy, store, mean, inv_variance,\n                                                 rows, cols);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,\n         int pack_size>\ninline cudaError_t TryDispatchLayerNormGradBlockSMemImplBlockSize(\n    cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,\n    const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows,\n    const int64_t cols, bool* success) {\n  constexpr int block_size_conf_1 = 128;\n  constexpr int block_size_conf_2 = 256;\n  constexpr int block_size_conf_3 = 512;\n  constexpr int block_size_conf_4 = 1024;\n\n  int dev = 0;\n  {\n    cudaError_t err = cudaGetDevice(&dev);\n    if (err != cudaSuccess) { return err; }\n  }\n\n  int sm_count = 0;\n  {\n    cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);\n    if (err != cudaSuccess) { return err; }\n  }\n\n  static const bool max_smem_configed = [=]() {\n    int max_smem_size = 0;\n    cudaError_t err =\n        cudaDeviceGetAttribute(&max_smem_size, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);\n    if (err != cudaSuccess) { return false; }\n\n    err = MaximizeDynamicSharedMemorySize(\n        LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                   block_size_conf_1>,\n        max_smem_size);\n    if (err != cudaSuccess) { return false; }\n    err = MaximizeDynamicSharedMemorySize(\n        LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                   block_size_conf_2>,\n        max_smem_size);\n    if (err != cudaSuccess) { return false; }\n    err = MaximizeDynamicSharedMemorySize(\n        LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                   block_size_conf_3>,\n        max_smem_size);\n    if (err != cudaSuccess) { return false; }\n    err = MaximizeDynamicSharedMemorySize(\n        LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                   block_size_conf_4>,\n        max_smem_size);\n    if (err != cudaSuccess) { return false; }\n\n    return true;\n  }();\n\n  using LoadTypeX = typename LOAD_X::LoadType;\n  using LoadTypeDy = typename LOAD_SCALED_DY::LoadType;\n  const size_t smem = cols * (sizeof(LoadTypeX) + sizeof(LoadTypeDy));\n\n  int max_active_blocks_conf_1;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks_conf_1,\n        LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                   block_size_conf_1>,\n        block_size_conf_1, smem);\n    if (err 
!= cudaSuccess) { return err; }\n  }\n  if (max_active_blocks_conf_1 <= 0) {\n    *success = false;\n    return cudaSuccess;\n  }\n\n  int max_active_blocks_conf_4;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks_conf_4,\n        LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                   block_size_conf_4>,\n        block_size_conf_4, smem);\n    if (err != cudaSuccess) { return err; }\n  }\n  if (max_active_blocks_conf_4 == max_active_blocks_conf_1\n      || (max_active_blocks_conf_4 > 0 && rows <= sm_count)) {\n    *success = true;\n    return LaunchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                            block_size_conf_4>(\n        stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols);\n  }\n\n  int max_active_blocks_conf_3;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks_conf_3,\n        LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                   block_size_conf_3>,\n        block_size_conf_3, smem);\n    if (err != cudaSuccess) { return err; }\n  }\n  if (max_active_blocks_conf_3 == max_active_blocks_conf_1\n      || (max_active_blocks_conf_3 > 0 && rows <= sm_count)) {\n    *success = true;\n    return LaunchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                            block_size_conf_3>(\n        stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols);\n  }\n\n  int max_active_blocks_conf_2;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks_conf_2,\n        LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                   block_size_conf_2>,\n        block_size_conf_2, smem);\n    if (err != cudaSuccess) { return err; }\n  }\n  if (max_active_blocks_conf_2 == max_active_blocks_conf_1\n      || (max_active_blocks_conf_2 > 0 && rows <= sm_count)) {\n    *success = true;\n    return LaunchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                            block_size_conf_2>(\n        stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols);\n  }\n\n  *success = true;\n  return LaunchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                          block_size_conf_1>(stream, load_x, load_scaled_dy, store,\n                                                             mean, inv_variance, smem, rows, cols);\n}\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>\nstruct TryDispatchLayerNormGradBlockSMemImplPackSize {\n  cudaError_t operator()(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,\n                         STORE store, const ComputeType* mean, const ComputeType* inv_variance,\n                         const int64_t rows, const int64_t cols, bool* success) {\n    if (cols % 2 == 0 && CanPackAs<LOAD_X>(load_x, 2)\n        && CanPackAs<LOAD_SCALED_DY>(load_scaled_dy, 2) && CanPackAs<STORE>(store, 2)) {\n      return TryDispatchLayerNormGradBlockSMemImplBlockSize<LOAD_X, LOAD_SCALED_DY, STORE,\n                                                            ComputeType, 2>(\n          stream, 
load_x, load_scaled_dy, store, mean, inv_variance, rows, cols, success);\n    } else {\n      return TryDispatchLayerNormGradBlockSMemImplBlockSize<LOAD_X, LOAD_SCALED_DY, STORE,\n                                                            ComputeType, 1>(\n          stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols, success);\n    }\n  }\n};\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>\ninline cudaError_t TryDispatchLayerNormGradBlockSMemImpl(cudaStream_t stream, LOAD_X load_x,\n                                                         LOAD_SCALED_DY load_scaled_dy, STORE store,\n                                                         const ComputeType* mean,\n                                                         const ComputeType* inv_variance,\n                                                         const int64_t rows, const int64_t cols,\n                                                         bool* success) {\n  return TryDispatchLayerNormGradBlockSMemImplPackSize<LOAD_X, LOAD_SCALED_DY, STORE,\n                                                       ComputeType>()(\n      stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols, success);\n}\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,\n         int pack_size, int block_size>\n__global__ void LayerNormGradBlockUncachedImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,\n                                               STORE store, const ComputeType* mean,\n                                               const ComputeType* inv_variance, const int64_t rows,\n                                               const int64_t cols) {\n  using LoadTypeX = typename LOAD_X::LoadType;\n  using LoadTypeDy = typename LOAD_SCALED_DY::LoadType;\n  const int tid = threadIdx.x;\n  assert(cols % pack_size == 0);\n  const int num_packs = static_cast<int>(cols) / pack_size;\n  const ComputeType one_over_cols = static_cast<ComputeType>(1.0) / static_cast<ComputeType>(cols);\n  for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {\n    const ComputeType mean_val = mean[row];\n    const ComputeType inv_variance_val = inv_variance[row];\n    const ComputeType inv_variance_over_cols = inv_variance_val * one_over_cols;\n    ComputeType sum_stats1 = 0;\n    ComputeType sum_stats2 = 0;\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      const int pack_offset = pack_id * pack_size;\n      LoadTypeX x_pack[pack_size];\n      LoadTypeDy dy_pack[pack_size];\n      load_x.template load<pack_size>(x_pack, row, pack_offset);\n      load_scaled_dy.template load<pack_size>(dy_pack, row, pack_offset);\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) {\n        sum_stats1 += static_cast<ComputeType>(dy_pack[i]);\n        sum_stats2 += static_cast<ComputeType>(dy_pack[i])\n                      * (static_cast<ComputeType>(x_pack[i]) - mean_val) * inv_variance_val;\n      }\n    }\n    const ComputeType row_sum_stats1 = BlockAllReduce<SumOp, ComputeType, block_size>(sum_stats1);\n    const ComputeType row_sum_stats2 = BlockAllReduce<SumOp, ComputeType, block_size>(sum_stats2);\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      const int pack_offset = pack_id * pack_size;\n      LoadTypeX x_pack[pack_size];\n      LoadTypeDy dy_pack[pack_size];\n      ComputeType dx_pack[pack_size];\n      load_x.template load<pack_size>(x_pack, row, pack_offset);\n      
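// x and scaled dy are re-read from global memory: the uncached variant keeps no smem copy\n      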
load_scaled_dy.template load<pack_size>(dy_pack, row, pack_offset);\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) {\n        dx_pack[i] =\n            (cols * static_cast<ComputeType>(dy_pack[i]) - row_sum_stats1\n             - (static_cast<ComputeType>(x_pack[i]) - mean_val) * inv_variance_val * row_sum_stats2)\n            * inv_variance_over_cols;\n      }\n      store.template store<pack_size>(dx_pack, row, pack_offset);\n    }\n  }\n}\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,\n         int pack_size, int block_size>\ninline cudaError_t LaunchLayerNormGradBlockUncachedImpl(cudaStream_t stream, LOAD_X load_x,\n                                                        LOAD_SCALED_DY load_scaled_dy, STORE store,\n                                                        const ComputeType* mean,\n                                                        const ComputeType* inv_variance,\n                                                        const int64_t rows, const int64_t cols) {\n  constexpr int waves = 32;\n  int grid_dim_x;\n  {\n    cudaError_t err =\n        GetNumBlocks(LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,\n                                                    pack_size, block_size>,\n                     block_size, 0, rows, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size, block_size>\n      <<<grid_dim_x, block_size, 0, stream>>>(load_x, load_scaled_dy, store, mean, inv_variance,\n                                              rows, cols);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,\n         int pack_size>\ninline cudaError_t TryDispatchLaunchLayerNormGradBlockUncachedImplBlockSize(\n    cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,\n    const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows,\n    const int64_t cols) {\n  int max_active_blocks = 0;\n  constexpr int block_size_conf_1 = 1024;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks,\n        LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                       block_size_conf_1>,\n        block_size_conf_1, 0);\n    if (err != cudaSuccess) { return err; }\n    if (max_active_blocks > 0) {\n      return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,\n                                                  pack_size, block_size_conf_1>(\n          stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);\n    }\n  }\n  constexpr int block_size_conf_2 = 512;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks,\n        LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                       block_size_conf_2>,\n        block_size_conf_2, 0);\n    if (err != cudaSuccess) { return err; }\n    if (max_active_blocks > 0) {\n      return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,\n                                                  pack_size, block_size_conf_2>(\n          stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);\n    }\n  }\n  constexpr int block_size_conf_3 = 256;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n
        &max_active_blocks,\n        LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                       block_size_conf_3>,\n        block_size_conf_3, 0);\n    if (err != cudaSuccess) { return err; }\n    if (max_active_blocks > 0) {\n      return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,\n                                                  pack_size, block_size_conf_3>(\n          stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);\n    }\n  }\n  constexpr int block_size_conf_4 = 128;\n  return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,\n                                              block_size_conf_4>(\n      stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);\n}\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>\nstruct DispatchLayerNormGradBlockUncachedImplPackSize {\n  cudaError_t operator()(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,\n                         STORE store, const ComputeType* mean, const ComputeType* inv_variance,\n                         const int64_t rows, const int64_t cols) {\n    if (cols % 2 == 0 && CanPackAs<LOAD_X>(load_x, 2)\n        && CanPackAs<LOAD_SCALED_DY>(load_scaled_dy, 2) && CanPackAs<STORE>(store, 2)\n        && cols > kWarpSize) {\n      return TryDispatchLaunchLayerNormGradBlockUncachedImplBlockSize<LOAD_X, LOAD_SCALED_DY, STORE,\n                                                                      ComputeType, 2>(\n          stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);\n    } else {\n      return TryDispatchLaunchLayerNormGradBlockUncachedImplBlockSize<LOAD_X, LOAD_SCALED_DY, STORE,\n                                                                      ComputeType, 1>(\n          stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);\n    }\n  }\n};\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>\ninline cudaError_t DispatchLayerNormGradBlockUncachedImpl(cudaStream_t stream, LOAD_X load_x,\n                                                          LOAD_SCALED_DY load_scaled_dy,\n                                                          STORE store, const ComputeType* mean,\n                                                          const ComputeType* inv_variance,\n                                                          const int64_t rows, const int64_t cols) {\n  return DispatchLayerNormGradBlockUncachedImplPackSize<LOAD_X, LOAD_SCALED_DY, STORE,\n                                                        ComputeType>()(\n      stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);\n}\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>\ninline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type\nDispatchLayerNormGrad(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,\n                      STORE store, const ComputeType* mean, const ComputeType* inv_variance,\n                      const int64_t rows, const int64_t cols) {\n  if (cols <= 1024) {\n    return DispatchLayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType>(\n        stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);\n  } else {\n    bool dispatch_smem_impl_success = false;\n    {\n      cudaError_t err =\n          TryDispatchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, 
STORE, ComputeType>(\n              stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols,\n              &dispatch_smem_impl_success);\n      if (err != cudaSuccess) { return err; }\n    }\n    if (!dispatch_smem_impl_success) {\n      return DispatchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType>(\n          stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);\n    }\n    return cudaSuccess;\n  }\n}\n\ntemplate<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>\ninline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type\nDispatchLayerNormGrad(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,\n                      STORE store, const ComputeType* mean, const ComputeType* inv_variance,\n                      const int64_t rows, const int64_t cols) {\n  return DispatchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType>(\n      stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);\n}\n\n}  // namespace layer_norm\n\n}  // namespace cuda\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CUDA_LAYER_NORM_H_\n"
  },
  {
    "path": "oneflow/core/cuda/rms_norm.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_CUDA_RMS_NORM_H_\n#define ONEFLOW_CORE_CUDA_RMS_NORM_H_\n\n#include \"oneflow/core/cuda/layer_norm.cuh\"\n\nnamespace oneflow {\nnamespace cuda {\nnamespace rms_norm {\n\nconstexpr int kWarpSize = 32;\n\ntemplate<typename T>\n__inline__ __device__ T WarpReduceSum(T val) {\n  for (int mask = 16; mask > 0; mask /= 2) { val += __shfl_down_sync(0xffffffff, val, mask); }\n  return val;\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size,\n         int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,\n         int rows_per_access, bool padding>\n__global__ void RmsNormWarpImpl(LOAD load, STORE store, const int nrow, const int ncol,\n                                const double eps, ComputeType* inv_rms) {\n  static_assert(max_cols_per_thread % pack_size == 0, \"\");\n  static_assert(min_cols_per_thread % pack_size == 0, \"\");\n  static_assert(thread_group_width <= kWarpSize, \"\");\n  static_assert(kWarpSize % thread_group_width == 0, \"\");\n  constexpr int max_packs = max_cols_per_thread / pack_size;\n  constexpr int min_packs = min_cols_per_thread / pack_size;\n  assert(ncol <= max_cols_per_thread * thread_group_width);\n\n  ComputeType buf[rows_per_access][max_cols_per_thread];\n  const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;\n  const int num_global_thread_groups = gridDim.x * blockDim.y;\n  for (int row_i = global_thread_group_id; row_i < nrow; row_i += num_global_thread_groups) {\n    ComputeType thread_square_sum[rows_per_access];\n#pragma unroll\n    for (int row_j = 0; row_j < rows_per_access; ++row_j) {\n      thread_square_sum[row_j] = 0;\n      ComputeType* row_buf = buf[row_j];\n      const int row = row_i * rows_per_access + row_j;\n#pragma unroll\n      for (int pack_i = 0; pack_i < min_packs; ++pack_i) {\n        const int pack_offset = pack_i * pack_size;\n        const int col = (pack_i * thread_group_width + threadIdx.x) * pack_size;\n        load.template load<pack_size>(row_buf + pack_offset, row, col);\n#pragma unroll\n        for (int pack_j = 0; pack_j < pack_size; ++pack_j) {\n          thread_square_sum[row_j] += row_buf[pack_offset + pack_j] * row_buf[pack_offset + pack_j];\n        }\n      }\n#pragma unroll\n      for (int pack_i = min_packs; pack_i < max_packs; ++pack_i) {\n        const int pack_offset = pack_i * pack_size;\n        const int col = (pack_i * thread_group_width + threadIdx.x) * pack_size;\n        if (!padding || col < ncol) {\n          load.template load<pack_size>(row_buf + pack_offset, row, col);\n#pragma unroll\n          for (int pack_j = 0; pack_j < pack_size; ++pack_j) {\n            thread_square_sum[row_j] +=\n                row_buf[pack_offset + pack_j] * row_buf[pack_offset + pack_j];\n          }\n        } else {\n#pragma unroll\n          for (int pack_j = 0; pack_j < pack_size; ++pack_j) {\n            row_buf[pack_i * 
 pack_size + pack_j] = 0;\n          }\n        }\n      }\n    }\n    ComputeType warp_square_sum[rows_per_access];\n#pragma unroll\n    for (int row_j = 0; row_j < rows_per_access; ++row_j) {\n      const int row = row_i * rows_per_access + row_j;\n      ComputeType* row_buf = buf[row_j];\n      warp_square_sum[row_j] =\n          layer_norm::WarpAllReduce<layer_norm::SumOp, ComputeType, thread_group_width>(\n              thread_square_sum[row_j]);\n      ComputeType row_square_mean =\n          layer_norm::Div(warp_square_sum[row_j], static_cast<ComputeType>(ncol));\n      ComputeType row_inv_rms = layer_norm::Rsqrt(row_square_mean + static_cast<ComputeType>(eps));\n      if (threadIdx.x == 0) { inv_rms[row] = row_inv_rms; }\n#pragma unroll\n      for (int col = 0; col < max_cols_per_thread; ++col) { row_buf[col] *= row_inv_rms; }\n#pragma unroll\n      for (int pack_i = 0; pack_i < min_packs; ++pack_i) {\n        const int col = (pack_i * thread_group_width + threadIdx.x) * pack_size;\n        store.template store<pack_size>(row_buf + pack_i * pack_size, row, col);\n      }\n#pragma unroll\n      for (int pack_i = min_packs; pack_i < max_packs; ++pack_i) {\n        const int col = (pack_i * thread_group_width + threadIdx.x) * pack_size;\n        if (!padding || col < ncol) {\n          store.template store<pack_size>(row_buf + pack_i * pack_size, row, col);\n        }\n      }\n    }\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size,\n         int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,\n         int rows_per_access, bool padding>\ncudaError_t LaunchRmsNormWarpImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t nrow,\n                                  const int64_t ncol, const double eps, ComputeType* inv_rms) {\n  constexpr int block_size = 128;\n  constexpr int waves = 32;\n  static_assert(block_size % thread_group_width == 0, \"\");\n  constexpr int thread_groups_per_block = block_size / thread_group_width;\n  const int64_t num_blocks =\n      (nrow / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;\n  int grid_dim_x;\n  {\n    cudaError_t err = layer_norm::GetNumBlocks(\n        RmsNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread,\n                        min_cols_per_thread, thread_group_width, rows_per_access, padding>,\n        block_size, 0, num_blocks, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  dim3 block_dim(thread_group_width, thread_groups_per_block);\n  RmsNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread, min_cols_per_thread,\n                  thread_group_width, rows_per_access, padding>\n      <<<grid_dim_x, block_dim, 0, stream>>>(load, store, static_cast<int>(nrow),\n                                             static_cast<int>(ncol), eps, inv_rms);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size,\n         int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,\n         int rows_per_access>\ncudaError_t DispatchLaunchRmsNormWarpImplPadding(cudaStream_t stream, LOAD load, STORE store,\n                                                 const int64_t nrow, const int64_t ncol,\n                                                 const double eps, ComputeType* inv_rms) {\n  if (ncol == max_cols_per_thread * thread_group_width) {\n    // when not padding, min_cols_per_thread must equal max_cols_per_thread, pass\n   
 // max_cols_per_thread as min_cols_per_thread and max_cols_per_thread param.\n    return LaunchRmsNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread,\n                                 max_cols_per_thread, thread_group_width, rows_per_access, false>(\n        stream, load, store, nrow, ncol, eps, inv_rms);\n  } else {\n    return LaunchRmsNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread,\n                                 min_cols_per_thread, thread_group_width, rows_per_access, true>(\n        stream, load, store, nrow, ncol, eps, inv_rms);\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size>\ntypename std::enable_if<pack_size == 1, cudaError_t>::type DispatchLaunchRmsNormWarpImplCols(\n    cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol,\n    const double eps, ComputeType* inv_rms) {\n  if (ncol <= 0) { return cudaErrorInvalidValue; }\n#define DEFINE_ONE_ELIF(thread_group_width)                                                       \\\n  else if (ncol <= (thread_group_width)*pack_size) {                                              \\\n    if (nrow % 2 == 0) {                                                                          \\\n      return DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \\\n                                                  0, thread_group_width, 2>(                      \\\n          stream, load, store, nrow, ncol, eps, inv_rms);                                         \\\n    } else {                                                                                      \\\n      return DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \\\n                                                  0, thread_group_width, 1>(                      \\\n          stream, load, store, nrow, ncol, eps, inv_rms);                                         \\\n    }                                                                                             \\\n  }\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n#define DEFINE_ONE_ELIF(max_col, min_col)                                                         \\\n  else if (ncol <= (max_col)*kWarpSize) {                                                         \\\n    return DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, max_col,     \\\n                                                min_col, kWarpSize, 1>(stream, load, store, nrow, \\\n                                                                       ncol, eps, inv_rms);       \\\n  }\n  DEFINE_ONE_ELIF(2, 1)\n  DEFINE_ONE_ELIF(4, 2)\n  DEFINE_ONE_ELIF(8, 4)\n  DEFINE_ONE_ELIF(12, 8)\n  DEFINE_ONE_ELIF(16, 12)\n  DEFINE_ONE_ELIF(20, 16)\n  DEFINE_ONE_ELIF(24, 20)\n  DEFINE_ONE_ELIF(28, 24)\n  DEFINE_ONE_ELIF(32, 28)\n#undef DEFINE_ONE_ELIF\n  else {\n    return cudaErrorInvalidValue;\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size>\ntypename std::enable_if<pack_size == 2, cudaError_t>::type DispatchLaunchRmsNormWarpImplCols(\n    cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol,\n    const double eps, ComputeType* inv_rms) {\n  if (ncol <= 0) { return cudaErrorInvalidValue; }\n#define DEFINE_ONE_ELIF(thread_group_width)                                                       \\\n  else if (ncol <= (thread_group_width)*pack_size) {        
                                      \\\n    if (nrow % 2 == 0) {                                                                          \\\n      return DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \\\n                                                  0, thread_group_width, 2>(                      \\\n          stream, load, store, nrow, ncol, eps, inv_rms);                                         \\\n    } else {                                                                                      \\\n      return DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \\\n                                                  0, thread_group_width, 1>(                      \\\n          stream, load, store, nrow, ncol, eps, inv_rms);                                         \\\n    }                                                                                             \\\n  }\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n#define DEFINE_ONE_ELIF(max_col, min_col)                                                         \\\n  else if ((ncol <= (max_col)*kWarpSize) && (ncol > (min_col)*kWarpSize)) {                       \\\n    return DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, max_col,     \\\n                                                min_col, kWarpSize, 1>(stream, load, store, nrow, \\\n                                                                       ncol, eps, inv_rms);       \\\n  }\n  DEFINE_ONE_ELIF(4, 2)\n  DEFINE_ONE_ELIF(8, 4)\n  DEFINE_ONE_ELIF(12, 8)\n  DEFINE_ONE_ELIF(16, 12)\n  DEFINE_ONE_ELIF(20, 16)\n  DEFINE_ONE_ELIF(24, 20)\n  DEFINE_ONE_ELIF(28, 24)\n  DEFINE_ONE_ELIF(32, 28)\n#undef DEFINE_ONE_ELIF\n  else {\n    return cudaErrorInvalidValue;\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ncudaError_t DispatchLaunchRmsNormWarpImplPackSize(cudaStream_t stream, LOAD load, STORE store,\n                                                  const int64_t nrow, const int64_t ncol,\n                                                  const double eps, ComputeType* inv_rms) {\n  if (ncol % 2 == 0 && layer_norm::CanPackAs<LOAD>(load, 2)\n      && layer_norm::CanPackAs<STORE>(store, 2)) {\n    return DispatchLaunchRmsNormWarpImplCols<LOAD, STORE, ComputeType, 2>(stream, load, store, nrow,\n                                                                          ncol, eps, inv_rms);\n  } else {\n    return DispatchLaunchRmsNormWarpImplCols<LOAD, STORE, ComputeType, 1>(stream, load, store, nrow,\n                                                                          ncol, eps, inv_rms);\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ncudaError_t DispatchLaunchRmsNormWarpImpl(cudaStream_t stream, LOAD load, STORE store,\n                                          const int64_t nrow, const int64_t ncol, const double eps,\n                                          ComputeType* inv_rms) {\n  return DispatchLaunchRmsNormWarpImplPackSize(stream, load, store, nrow, ncol, eps, inv_rms);\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>\n__global__ void RmsNormBlockSMemImpl(LOAD load, STORE store, const int nrow, const int ncol,\n                                     const double eps, ComputeType* inv_rms) {\n  extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];\n  auto* buf = 
reinterpret_cast<ComputeType*>(shared_buf);\n  assert(ncol % pack_size == 0);\n  const int num_packs = ncol / pack_size;\n  for (int row = blockIdx.x; row < nrow; row += gridDim.x) {\n    ComputeType thread_square_sum = 0;\n    for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += block_size) {\n      ComputeType pack[pack_size];\n      const int col = pack_i * pack_size;\n      load.template load<pack_size>(pack, row, col);\n#pragma unroll\n      for (int pack_j = 0; pack_j < pack_size; ++pack_j) {\n        buf[pack_i * pack_size + pack_j] = pack[pack_j];\n        thread_square_sum += pack[pack_j] * pack[pack_j];\n      }\n    }\n    ComputeType row_square_sum =\n        layer_norm::BlockAllReduce<layer_norm::SumOp, ComputeType, block_size>(thread_square_sum);\n    ComputeType row_square_mean = layer_norm::Div(row_square_sum, static_cast<ComputeType>(ncol));\n    ComputeType row_inv_rms = layer_norm::Rsqrt(row_square_mean + static_cast<ComputeType>(eps));\n    if (threadIdx.x == 0) { inv_rms[row] = row_inv_rms; }\n    for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += block_size) {\n      ComputeType pack[pack_size];\n#pragma unroll\n      for (int pack_j = 0; pack_j < pack_size; ++pack_j) {\n        pack[pack_j] = buf[pack_i * pack_size + pack_j] * row_inv_rms;\n      }\n      const int col = pack_i * pack_size;\n      store.template store<pack_size>(pack, row, col);\n    }\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>\ncudaError_t LaunchRmsNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store,\n                                       size_t smem_size, const int64_t nrow, const int64_t ncol,\n                                       const double eps, ComputeType* inv_rms) {\n  constexpr int waves = 32;\n  int grid_dim_x;\n  {\n    cudaError_t err = layer_norm::GetNumBlocks(\n        RmsNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size>, block_size,\n        smem_size, nrow, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  RmsNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size>\n      <<<grid_dim_x, block_size, smem_size, stream>>>(load, store, nrow, ncol, eps, inv_rms);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size>\ncudaError_t TryDispatchLaunchRmsNormBlockSMemImplBlockSize(cudaStream_t stream, LOAD load,\n                                                           STORE store, const int64_t nrow,\n                                                           const int64_t ncol, const double eps,\n                                                           ComputeType* inv_rms, bool* success) {\n  constexpr int block_size_conf_1 = 128;\n  constexpr int block_size_conf_2 = 256;\n  constexpr int block_size_conf_3 = 512;\n  constexpr int block_size_conf_4 = 1024;\n  const size_t smem_size = ncol * sizeof(ComputeType);\n  int max_active_blocks = 0;\n  int num_blocks = 0;\n\n#define SELECT_BLOCK_SIZE_CONF(block_size_conf)                                                  \\\n  {                                                                                              \\\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(                             \\\n        &num_blocks, RmsNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf>, \\\n        block_size_conf, smem_size);                                                             \\\n    if (err != 
cudaSuccess) { return err; }                                                      \\\n    if (max_active_blocks == 0) {                                                                \\\n      if (num_blocks <= max_active_blocks) {                                                     \\\n        *success = false;                                                                        \\\n        return cudaSuccess;                                                                      \\\n      }                                                                                          \\\n      max_active_blocks = num_blocks;                                                            \\\n    } else {                                                                                     \\\n      if (num_blocks == max_active_blocks) {                                                     \\\n        *success = true;                                                                         \\\n        return LaunchRmsNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf>( \\\n            stream, load, store, smem_size, nrow, ncol, eps, inv_rms);                           \\\n      }                                                                                          \\\n    }                                                                                            \\\n  }\n\n  SELECT_BLOCK_SIZE_CONF(block_size_conf_1)\n  SELECT_BLOCK_SIZE_CONF(block_size_conf_4)\n  SELECT_BLOCK_SIZE_CONF(block_size_conf_3)\n  SELECT_BLOCK_SIZE_CONF(block_size_conf_2)\n#undef SELECT_BLOCK_SIZE_CONF\n\n  *success = true;\n  return LaunchRmsNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1>(\n      stream, load, store, smem_size, nrow, ncol, eps, inv_rms);\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ncudaError_t TryDispatchLaunchRmsNormBlockSMemImplPackSize(cudaStream_t stream, LOAD load,\n                                                          STORE store, const int64_t nrow,\n                                                          const int64_t ncol, const double eps,\n                                                          ComputeType* inv_rms, bool* success) {\n  if (ncol % 4 == 0 && layer_norm::CanPackAs<LOAD>(load, 4)\n      && layer_norm::CanPackAs<STORE>(store, 4)) {\n    return TryDispatchLaunchRmsNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 4>(\n        stream, load, store, nrow, ncol, eps, inv_rms, success);\n  } else if (ncol % 2 == 0 && layer_norm::CanPackAs<LOAD>(load, 2)\n             && layer_norm::CanPackAs<STORE>(store, 2)) {\n    return TryDispatchLaunchRmsNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 2>(\n        stream, load, store, nrow, ncol, eps, inv_rms, success);\n  } else {\n    return TryDispatchLaunchRmsNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 1>(\n        stream, load, store, nrow, ncol, eps, inv_rms, success);\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ncudaError_t TryDispatchLaunchRmsNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store,\n                                                  const int64_t nrow, const int64_t ncol,\n                                                  const double eps, ComputeType* inv_rms,\n                                                  bool* success) {\n  return TryDispatchLaunchRmsNormBlockSMemImplPackSize(stream, load, store, nrow, ncol, eps,\n                                                       inv_rms, 
success);\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>\n__global__ void RmsNormBlockUncachedImpl(LOAD load, STORE store, const int nrow, const int ncol,\n                                         const double eps, ComputeType* inv_rms) {\n  assert(ncol % pack_size == 0);\n  const int num_packs = ncol / pack_size;\n  for (int row = blockIdx.x; row < nrow; row += gridDim.x) {\n    ComputeType thread_square_sum = 0;\n    for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += block_size) {\n      ComputeType pack[pack_size];\n      const int col = pack_i * pack_size;\n      load.template load<pack_size>(pack, row, col);\n#pragma unroll\n      for (int pack_j = 0; pack_j < pack_size; ++pack_j) {\n        thread_square_sum += pack[pack_j] * pack[pack_j];\n      }\n    }\n    ComputeType row_square_sum =\n        layer_norm::BlockAllReduce<layer_norm::SumOp, ComputeType, block_size>(thread_square_sum);\n    ComputeType row_square_mean = layer_norm::Div(row_square_sum, static_cast<ComputeType>(ncol));\n    ComputeType row_inv_rms = layer_norm::Rsqrt(row_square_mean + static_cast<ComputeType>(eps));\n    if (threadIdx.x == 0) { inv_rms[row] = row_inv_rms; }\n    for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += block_size) {\n      ComputeType pack[pack_size];\n      const int col = pack_i * pack_size;\n      load.template load<pack_size>(pack, row, col);\n#pragma unroll\n      for (int pack_j = 0; pack_j < pack_size; ++pack_j) {\n        pack[pack_j] = pack[pack_j] * row_inv_rms;\n      }\n      store.template store<pack_size>(pack, row, col);\n    }\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size>\ncudaError_t LaunchRmsNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,\n                                           const int64_t nrow, const int64_t ncol, const double eps,\n                                           ComputeType* inv_rms) {\n  constexpr int block_size = 1024;\n  constexpr int waves = 32;\n  int grid_dim_x;\n  {\n    cudaError_t err = layer_norm::GetNumBlocks(\n        RmsNormBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size>, block_size, 0,\n        nrow, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  RmsNormBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size>\n      <<<grid_dim_x, block_size, 0, stream>>>(load, store, nrow, ncol, eps, inv_rms);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ncudaError_t DispatchLaunchRmsNormBlockUncachedImplPackSize(cudaStream_t stream, LOAD load,\n                                                           STORE store, const int64_t nrow,\n                                                           const int64_t ncol, const double eps,\n                                                           ComputeType* inv_rms) {\n  if (ncol % 4 == 0 && layer_norm::CanPackAs<LOAD>(load, 4)\n      && layer_norm::CanPackAs<STORE>(store, 4)) {\n    return LaunchRmsNormBlockUncachedImpl<LOAD, STORE, ComputeType, 4>(stream, load, store, nrow,\n                                                                       ncol, eps, inv_rms);\n  } else if (ncol % 2 == 0 && layer_norm::CanPackAs<LOAD>(load, 2)\n             && layer_norm::CanPackAs<STORE>(store, 2)) {\n    return LaunchRmsNormBlockUncachedImpl<LOAD, STORE, ComputeType, 2>(stream, load, store, nrow,\n                                                                       ncol, eps, 
inv_rms);\n  } else {\n    return LaunchRmsNormBlockUncachedImpl<LOAD, STORE, ComputeType, 1>(stream, load, store, nrow,\n                                                                       ncol, eps, inv_rms);\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ncudaError_t DispatchLaunchRmsNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,\n                                                   const int64_t nrow, const int64_t ncol,\n                                                   const double eps, ComputeType* inv_rms) {\n  return DispatchLaunchRmsNormBlockUncachedImplPackSize(stream, load, store, nrow, ncol, eps,\n                                                        inv_rms);\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ntypename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type LaunchRmsNorm(\n    cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol,\n    const double eps, ComputeType* inv_rms) {\n  if (ncol <= 1024) {\n    return DispatchLaunchRmsNormWarpImpl(stream, load, store, nrow, ncol, eps, inv_rms);\n  } else {\n    bool dispatch_smem_impl_success = false;\n    {\n      cudaError_t err = TryDispatchLaunchRmsNormBlockSMemImpl(stream, load, store, nrow, ncol, eps,\n                                                              inv_rms, &dispatch_smem_impl_success);\n      if (err != cudaSuccess) { return err; }\n    }\n    if (!dispatch_smem_impl_success) {\n      return DispatchLaunchRmsNormBlockUncachedImpl(stream, load, store, nrow, ncol, eps, inv_rms);\n    }\n    return cudaSuccess;\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ntypename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type LaunchRmsNorm(\n    cudaStream_t stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol,\n    const double eps, ComputeType* inv_rms) {\n  return DispatchLaunchRmsNormBlockUncachedImpl(stream, load, store, nrow, ncol, eps, inv_rms);\n}\n\ntemplate<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,\n         int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,\n         int rows_per_access>\n__global__ void RmsNormGradWarpImpl(const int nrow, const int ncol, LOAD_X load_x, LOAD_DY load_dy,\n                                    STORE store, const ComputeType* inv_rms) {\n  static_assert(max_cols_per_thread % pack_size == 0, \"\");\n  static_assert(min_cols_per_thread % pack_size == 0, \"\");\n  static_assert(thread_group_width <= kWarpSize, \"\");\n  static_assert(kWarpSize % thread_group_width == 0, \"\");\n  assert(ncol <= max_cols_per_thread * thread_group_width);\n\n  constexpr int max_packs = max_cols_per_thread / pack_size;\n  constexpr int min_packs = min_cols_per_thread / pack_size;\n\n  ComputeType normalized_buf[rows_per_access][max_cols_per_thread];\n  ComputeType dy_buf[rows_per_access][max_cols_per_thread];\n\n  const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;\n  const int num_global_thread_group = gridDim.x * blockDim.y;\n  for (int row_i = global_thread_group_id; row_i < nrow; row_i += num_global_thread_group) {\n    ComputeType sum_stats[rows_per_access];\n    ComputeType inv_rms_buf[rows_per_access];\n#pragma unroll\n    for (int row_j = 0; row_j < rows_per_access; ++row_j) {\n      const int global_row = row_i * rows_per_access + row_j;\n      sum_stats[row_j] = 0;\n      inv_rms_buf[row_j] = 
inv_rms[global_row];\n      ComputeType* row_normalized_buf = normalized_buf[row_j];\n      ComputeType* row_dy_buf = dy_buf[row_j];\n#pragma unroll\n      for (int pack_i = 0; pack_i < min_packs; ++pack_i) {\n        const int pack_offset = pack_i * pack_size;\n        const int global_col = (pack_i * thread_group_width + threadIdx.x) * pack_size;\n        load_x.template load<pack_size>(row_normalized_buf + pack_offset, global_row, global_col);\n        load_dy.template load<pack_size>(row_dy_buf + pack_offset, global_row, global_col);\n#pragma unroll\n        for (int pack_j = 0; pack_j < pack_size; ++pack_j) {\n          const int col = pack_offset + pack_j;\n          row_normalized_buf[col] = row_normalized_buf[col] * inv_rms_buf[row_j];\n          sum_stats[row_j] += row_dy_buf[col] * row_normalized_buf[col];\n        }\n      }\n#pragma unroll\n      for (int pack_i = min_packs; pack_i < max_packs; ++pack_i) {\n        const int pack_offset = pack_i * pack_size;\n        const int global_col = (pack_i * thread_group_width + threadIdx.x) * pack_size;\n        if (global_col < ncol) {\n          load_x.template load<pack_size>(row_normalized_buf + pack_offset, global_row, global_col);\n          load_dy.template load<pack_size>(row_dy_buf + pack_offset, global_row, global_col);\n#pragma unroll\n          for (int pack_j = 0; pack_j < pack_size; ++pack_j) {\n            const int col = pack_offset + pack_j;\n            row_normalized_buf[col] = row_normalized_buf[col] * inv_rms_buf[row_j];\n            sum_stats[row_j] += row_dy_buf[col] * row_normalized_buf[col];\n          }\n        }\n      }\n    }\n    ComputeType warp_sum_stats[rows_per_access];\n#pragma unroll\n    for (int row_j = 0; row_j < rows_per_access; ++row_j) {\n      warp_sum_stats[row_j] =\n          layer_norm::WarpAllReduce<layer_norm::SumOp, ComputeType, thread_group_width>(\n              sum_stats[row_j]);\n    }\n#pragma unroll\n    for (int row_j = 0; row_j < rows_per_access; ++row_j) {\n      const int global_row = row_i * rows_per_access + row_j;\n      ComputeType* row_normalized_buf = normalized_buf[row_j];\n      ComputeType* row_dy_buf = dy_buf[row_j];\n#pragma unroll\n      for (int pack_i = 0; pack_i < min_packs; ++pack_i) {\n        const int pack_offset = pack_i * pack_size;\n        const int global_col = (pack_i * thread_group_width + threadIdx.x) * pack_size;\n        for (int pack_j = 0; pack_j < pack_size; ++pack_j) {\n          const int col = pack_offset + pack_j;\n          const ComputeType norm_val =\n              layer_norm::Div(row_normalized_buf[col], static_cast<ComputeType>(ncol));\n          row_dy_buf[col] =\n              (row_dy_buf[col] - norm_val * warp_sum_stats[row_j]) * inv_rms_buf[row_j];\n        }\n        store.template store<pack_size>(row_dy_buf + pack_offset, global_row, global_col);\n      }\n#pragma unroll\n      for (int pack_i = min_packs; pack_i < max_packs; ++pack_i) {\n        const int pack_offset = pack_i * pack_size;\n        const int global_col = (pack_i * thread_group_width + threadIdx.x) * pack_size;\n        if (global_col < ncol) {\n          for (int pack_j = 0; pack_j < pack_size; ++pack_j) {\n            const int col = pack_offset + pack_j;\n            const ComputeType norm_val =\n                layer_norm::Div(row_normalized_buf[col], static_cast<ComputeType>(ncol));\n            row_dy_buf[col] =\n                (row_dy_buf[col] - norm_val * warp_sum_stats[row_j]) * inv_rms_buf[row_j];\n          }\n          store.template 
store<pack_size>(row_dy_buf + pack_offset, global_row, global_col);\n        }\n      }\n    }\n  }\n}\n\ntemplate<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,\n         int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,\n         int rows_per_access>\ncudaError_t LaunchRmsNormGradWarpImpl(cudaStream_t stream, const int nrow, const int ncol,\n                                      LOAD_X load_x, LOAD_DY load_dy, STORE store,\n                                      const ComputeType* inv_rms) {\n  constexpr int block_size = 128;\n  constexpr int waves = 32;\n  static_assert(block_size % thread_group_width == 0, \"\");\n  constexpr int thread_groups_per_block = block_size / thread_group_width;\n  const int64_t num_blocks =\n      (nrow / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;\n  int grid_dim_x;\n  {\n    cudaError_t err = layer_norm::GetNumBlocks(\n        RmsNormGradWarpImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, max_cols_per_thread,\n                            min_cols_per_thread, thread_group_width, rows_per_access>,\n        block_size, 0, num_blocks, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  dim3 block_dim(thread_group_width, thread_groups_per_block);\n  RmsNormGradWarpImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, max_cols_per_thread,\n                      min_cols_per_thread, thread_group_width, rows_per_access>\n      <<<grid_dim_x, block_dim, 0, stream>>>(nrow, ncol, load_x, load_dy, store, inv_rms);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size>\ntypename std::enable_if<pack_size == 1, cudaError_t>::type DispatchLaunchRmsNormGradWarpImplCols(\n    cudaStream_t stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy,\n    STORE store, const ComputeType* inv_rms) {\n  if (ncol <= 0) { return cudaErrorInvalidValue; }\n#define DEFINE_ONE_ELIF(thread_group_width)                                                       \\\n  else if (ncol <= (thread_group_width)*pack_size) {                                              \\\n    if (nrow % 2 == 0) {                                                                          \\\n      return LaunchRmsNormGradWarpImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, pack_size, \\\n                                       0, thread_group_width, 2>(stream, nrow, ncol, load_x,      \\\n                                                                 load_dy, store, inv_rms);        \\\n    } else {                                                                                      \\\n      return LaunchRmsNormGradWarpImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, pack_size, \\\n                                       0, thread_group_width, 1>(stream, nrow, ncol, load_x,      \\\n                                                                 load_dy, store, inv_rms);        \\\n    }                                                                                             \\\n  }\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n#define DEFINE_ONE_ELIF(max_col, min_col)                                                        \\\n  else if (ncol <= (max_col)*kWarpSize) {                                                        \\\n    return LaunchRmsNormGradWarpImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, max_col,    
\\\n                                     min_col, kWarpSize, 1>(stream, nrow, ncol, load_x, load_dy, \\\n                                                            store, inv_rms);                     \\\n  }\n  DEFINE_ONE_ELIF(2, 1)\n  DEFINE_ONE_ELIF(4, 2)\n  DEFINE_ONE_ELIF(8, 4)\n  DEFINE_ONE_ELIF(12, 8)\n  DEFINE_ONE_ELIF(16, 12)\n  DEFINE_ONE_ELIF(20, 16)\n  DEFINE_ONE_ELIF(24, 20)\n  DEFINE_ONE_ELIF(28, 24)\n  DEFINE_ONE_ELIF(32, 28)\n#undef DEFINE_ONE_ELIF\n  else {\n    return cudaErrorInvalidValue;\n  }\n}\n\ntemplate<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType>\ncudaError_t DispatchLaunchRmsNormGradWarpImplPackSize(cudaStream_t stream, const int64_t nrow,\n                                                      const int64_t ncol, LOAD_X load_x,\n                                                      LOAD_DY load_dy, STORE store,\n                                                      const ComputeType* inv_rms) {\n  return DispatchLaunchRmsNormGradWarpImplCols<LOAD_X, LOAD_DY, STORE, ComputeType, 1>(\n      stream, nrow, ncol, load_x, load_dy, store, inv_rms);\n}\n\ntemplate<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,\n         int block_size>\n__global__ void RmsNormGradBlockSMemImpl(const int nrow, const int ncol, LOAD_X load_x,\n                                         LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) {\n  extern __shared__ __align__(sizeof(double)) unsigned char dyn_smem[];\n  // dynamic shared memory for caching x and dy\n  auto* normalized_buf = reinterpret_cast<ComputeType*>(dyn_smem);\n  auto* dy_buf = normalized_buf + ncol;\n  assert(ncol % pack_size == 0);\n  const int num_packs = ncol / pack_size;\n  for (int row = blockIdx.x; row < nrow; row += gridDim.x) {\n    ComputeType sum_stats = 0;\n    const ComputeType inv_rms_val = inv_rms[row];\n    for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += blockDim.x) {\n      ComputeType x_pack[pack_size];\n      ComputeType dy_pack[pack_size];\n      const int pack_offset = pack_i * pack_size;\n      load_x.template load<pack_size>(x_pack, row, pack_offset);\n      load_dy.template load<pack_size>(dy_pack, row, pack_offset);\n#pragma unroll\n      for (int pack_j = 0; pack_j < pack_size; ++pack_j) {\n        const int col = pack_offset + pack_j;\n        normalized_buf[col] = x_pack[pack_j] * inv_rms_val;\n        dy_buf[col] = dy_pack[pack_j];\n        sum_stats += dy_buf[col] * normalized_buf[col];\n      }\n    }\n    const ComputeType row_sum_stats =\n        layer_norm::BlockAllReduce<layer_norm::SumOp, ComputeType, block_size>(sum_stats);\n    for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += blockDim.x) {\n      ComputeType pack[pack_size];\n      const int pack_offset = pack_i * pack_size;\n#pragma unroll\n      for (int pack_j = 0; pack_j < pack_size; ++pack_j) {\n        const int col = pack_offset + pack_j;\n        const ComputeType norm_val =\n            layer_norm::Div(normalized_buf[col], static_cast<ComputeType>(ncol));\n        pack[pack_j] = (dy_buf[col] - norm_val * row_sum_stats) * inv_rms_val;\n      }\n      store.template store<pack_size>(pack, row, pack_offset);\n    }\n  }\n}\n\ntemplate<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,\n         int block_size>\ncudaError_t LaunchRmsNormGradBlockSMemImpl(cudaStream_t stream, const int64_t nrow,\n                                           const int64_t ncol, const size_t smem_size,\n                   
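// smem_size: bytes for two ComputeType row caches (normalized x and dy)\n                   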
                        LOAD_X load_x, LOAD_DY load_dy, STORE store,\n                                           const ComputeType* inv_rms) {\n  constexpr int waves = 32;\n  int grid_dim_x;\n  {\n    cudaError_t err = layer_norm::GetNumBlocks(\n        RmsNormGradBlockSMemImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, block_size>,\n        block_size, smem_size, nrow, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  RmsNormGradBlockSMemImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, block_size>\n      <<<grid_dim_x, block_size, smem_size, stream>>>(\n          static_cast<int>(nrow), static_cast<int>(ncol), load_x, load_dy, store, inv_rms);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size>\ncudaError_t TryDispatchLaunchRmsNormGradBlockSMemImplBlockSize(\n    cudaStream_t stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy,\n    STORE store, const ComputeType* inv_rms, bool* success) {\n  constexpr int block_size_conf_1 = 128;\n  constexpr int block_size_conf_2 = 256;\n  constexpr int block_size_conf_3 = 512;\n  constexpr int block_size_conf_4 = 1024;\n  const size_t smem_size = ncol * sizeof(ComputeType) * 2;  // ncol * 2 for caching x and dy both\n  int max_active_blocks = 0;\n  int num_blocks = 0;\n\n#define SELECT_BLOCK_SIZE_CONF(block_size_conf)                                                    \\\n  {                                                                                                \\\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(                               \\\n        &num_blocks,                                                                               \\\n        RmsNormGradBlockSMemImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, block_size_conf>, \\\n        block_size_conf, smem_size);                                                               \\\n    if (err != cudaSuccess) { return err; }                                                        \\\n    if (max_active_blocks == 0) {                                                                  \\\n      if (num_blocks <= max_active_blocks) {                                                       \\\n        *success = false;                                                                          \\\n        return cudaSuccess;                                                                        \\\n      }                                                                                            \\\n      max_active_blocks = num_blocks;                                                              \\\n    } else {                                                                                       \\\n      if (num_blocks == max_active_blocks) {                                                       \\\n        *success = true;                                                                           \\\n        return LaunchRmsNormGradBlockSMemImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size,      \\\n                                              block_size_conf>(stream, nrow, ncol, smem_size,      \\\n                                                               load_x, load_dy, store, inv_rms);   \\\n      }                                                                                            \\\n    }                                                                                              \\\n  }\n\n  
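// Probe order: conf_1 establishes the occupancy baseline; then prefer the largest block size\n  // (conf_4 -> conf_3 -> conf_2) that still reaches the same occupancy, else fall back to conf_1.\n  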
SELECT_BLOCK_SIZE_CONF(block_size_conf_1)\n  SELECT_BLOCK_SIZE_CONF(block_size_conf_4)\n  SELECT_BLOCK_SIZE_CONF(block_size_conf_3)\n  SELECT_BLOCK_SIZE_CONF(block_size_conf_2)\n#undef SELECT_BLOCK_SIZE_CONF\n\n  *success = true;\n  return LaunchRmsNormGradBlockSMemImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size,\n                                        block_size_conf_1>(stream, nrow, ncol, smem_size, load_x,\n                                                           load_dy, store, inv_rms);\n}\n\ntemplate<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType>\ncudaError_t TryDispatchLaunchRmsNormGradBlockSMemImplPackSize(\n    cudaStream_t stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy,\n    STORE store, const ComputeType* inv_rms, bool* success) {\n  if (ncol % 2 == 0 && layer_norm::CanPackAs<LOAD_X>(load_x, 2)\n      && layer_norm::CanPackAs<LOAD_DY>(load_dy, 2) && layer_norm::CanPackAs<STORE>(store, 2)) {\n    return TryDispatchLaunchRmsNormGradBlockSMemImplBlockSize<LOAD_X, LOAD_DY, STORE, ComputeType,\n                                                              2>(stream, nrow, ncol, load_x,\n                                                                 load_dy, store, inv_rms, success);\n  } else {\n    return TryDispatchLaunchRmsNormGradBlockSMemImplBlockSize<LOAD_X, LOAD_DY, STORE, ComputeType,\n                                                              1>(stream, nrow, ncol, load_x,\n                                                                 load_dy, store, inv_rms, success);\n  }\n}\n\ntemplate<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,\n         int block_size>\n__global__ void RmsNormGradBlockUncachedImpl(const int nrow, const int ncol, LOAD_X load_x,\n                                             LOAD_DY load_dy, STORE store,\n                                             const ComputeType* inv_rms) {\n  assert(ncol % pack_size == 0);\n  const int num_packs = ncol / pack_size;\n  for (int row = blockIdx.x; row < nrow; row += gridDim.x) {\n    const ComputeType inv_rms_val = inv_rms[row];\n    ComputeType sum_stats = 0;\n    for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += blockDim.x) {\n      ComputeType x_pack[pack_size];\n      ComputeType dy_pack[pack_size];\n      const int pack_offset = pack_i * pack_size;\n      load_x.template load<pack_size>(x_pack, row, pack_offset);\n      load_dy.template load<pack_size>(dy_pack, row, pack_offset);\n#pragma unroll\n      for (int pack_j = 0; pack_j < pack_size; ++pack_j) {\n        sum_stats += dy_pack[pack_j] * x_pack[pack_j] * inv_rms_val;\n      }\n    }\n    const ComputeType row_sum_stats =\n        layer_norm::BlockAllReduce<layer_norm::SumOp, ComputeType, block_size>(sum_stats);\n    for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += blockDim.x) {\n      ComputeType x_pack[pack_size];\n      ComputeType dy_pack[pack_size];\n      const int pack_offset = pack_i * pack_size;\n      load_x.template load<pack_size>(x_pack, row, pack_offset);\n      load_dy.template load<pack_size>(dy_pack, row, pack_offset);\n#pragma unroll\n      for (int pack_j = 0; pack_j < pack_size; ++pack_j) {\n        const ComputeType norm_val =\n            layer_norm::Div(x_pack[pack_j] * inv_rms_val, static_cast<ComputeType>(ncol));\n        dy_pack[pack_j] = (dy_pack[pack_j] - norm_val * row_sum_stats) * inv_rms_val;\n      }\n      store.template store<pack_size>(dy_pack, row, pack_offset);\n    }\n  
}\n}\n\ntemplate<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,\n         int block_size>\ncudaError_t LaunchRmsNormGradBlockUncachedImpl(cudaStream_t stream, const int64_t nrow,\n                                               const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy,\n                                               STORE store, const ComputeType* inv_rms) {\n  constexpr int waves = 32;\n  int grid_dim_x;\n  {\n    cudaError_t err = layer_norm::GetNumBlocks(\n        RmsNormGradBlockUncachedImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, block_size>,\n        block_size, 0, nrow, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  RmsNormGradBlockUncachedImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, block_size>\n      <<<grid_dim_x, block_size, 0, stream>>>(nrow, ncol, load_x, load_dy, store, inv_rms);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size>\ncudaError_t DispatchLaunchRmsNormGradBlockUncachedImplBlockSize(cudaStream_t stream,\n                                                                const int64_t nrow,\n                                                                const int64_t ncol, LOAD_X load_x,\n                                                                LOAD_DY load_dy, STORE store,\n                                                                const ComputeType* inv_rms) {\n  constexpr int block_size_conf_1 = 128;\n  constexpr int block_size_conf_2 = 256;\n  constexpr int block_size_conf_3 = 512;\n  constexpr int block_size_conf_4 = 1024;\n  int max_active_blocks = 0;\n\n#define SELECT_BLOCK_SIZE_CONF(block_size_conf)                                                 \\\n  {                                                                                             \\\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(                            \\\n        &max_active_blocks,                                                                     \\\n        RmsNormGradBlockUncachedImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size,            \\\n                                     block_size_conf>,                                          \\\n        block_size_conf, 0);                                                                    \\\n    if (err != cudaSuccess) { return err; }                                                     \\\n    if (max_active_blocks > 0) {                                                                \\\n      return LaunchRmsNormGradBlockUncachedImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, \\\n                                                block_size_conf>(stream, nrow, ncol, load_x,    \\\n                                                                 load_dy, store, inv_rms);      \\\n    }                                                                                           \\\n  }\n\n  SELECT_BLOCK_SIZE_CONF(block_size_conf_4)\n  SELECT_BLOCK_SIZE_CONF(block_size_conf_3)\n  SELECT_BLOCK_SIZE_CONF(block_size_conf_2)\n  SELECT_BLOCK_SIZE_CONF(block_size_conf_1)\n#undef SELECT_BLOCK_SIZE_CONF\n\n  return cudaErrorInvalidValue;\n}\n\ntemplate<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType>\ncudaError_t DispatchLaunchRmsNormGradBlockUncachedImplPackSize(cudaStream_t stream,\n                                                               const int64_t nrow,\n                                                
               const int64_t ncol, LOAD_X load_x,\n                                                               LOAD_DY load_dy, STORE store,\n                                                               const ComputeType* inv_rms) {\n  if (ncol % 2 == 0 && layer_norm::CanPackAs<LOAD_X>(load_x, 2)\n      && layer_norm::CanPackAs<LOAD_DY>(load_dy, 2) && layer_norm::CanPackAs<STORE>(store, 2)\n      && ncol > kWarpSize) {\n    return DispatchLaunchRmsNormGradBlockUncachedImplBlockSize<LOAD_X, LOAD_DY, STORE, ComputeType,\n                                                               2>(stream, nrow, ncol, load_x,\n                                                                  load_dy, store, inv_rms);\n  } else {\n    return DispatchLaunchRmsNormGradBlockUncachedImplBlockSize<LOAD_X, LOAD_DY, STORE, ComputeType,\n                                                               1>(stream, nrow, ncol, load_x,\n                                                                  load_dy, store, inv_rms);\n  }\n}\n\ntemplate<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType>\ntypename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type\nLaunchRmsNormGrad(cudaStream_t stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x,\n                  LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) {\n  if (ncol <= 1024) {\n    return DispatchLaunchRmsNormGradWarpImplPackSize(stream, nrow, ncol, load_x, load_dy, store,\n                                                     inv_rms);\n  } else {\n    bool dispatch_smem_impl_success = false;\n    {\n      cudaError_t err = TryDispatchLaunchRmsNormGradBlockSMemImplPackSize(\n          stream, nrow, ncol, load_x, load_dy, store, inv_rms, &dispatch_smem_impl_success);\n      if (err != cudaSuccess) { return err; }\n    }\n    if (!dispatch_smem_impl_success) {\n      return DispatchLaunchRmsNormGradBlockUncachedImplPackSize(stream, nrow, ncol, load_x, load_dy,\n                                                                store, inv_rms);\n    }\n    return cudaSuccess;\n  }\n}\n\ntemplate<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType>\ntypename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type\nLaunchRmsNormGrad(cudaStream_t stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x,\n                  LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) {\n  return DispatchLaunchRmsNormGradBlockUncachedImplPackSize(stream, nrow, ncol, load_x, load_dy,\n                                                            store, inv_rms);\n}\n\ntemplate<int nproc_per_thread, typename T, typename ComputeType>\n__global__ void RmsNormParamGrad(int nrow, int ncol, const T* __restrict__ dy,\n                                 const T* __restrict__ x, const ComputeType* __restrict__ inv_rms,\n                                 T* __restrict__ b_weight_grad) {\n  __shared__ ComputeType dweight[kWarpSize][kWarpSize + 1];\n  ComputeType dweight_sum[nproc_per_thread];\n#pragma unroll\n  for (int i = 0; i < nproc_per_thread; ++i) { dweight_sum[i] = 0; }\n  const int col = blockIdx.x * blockDim.x + threadIdx.x;\n  if (col < ncol) {\n    // one grid-stride wave per traversal (needed when nrow > kWarpSize * gridDim.y)\n    for (int j = blockIdx.y * kWarpSize + threadIdx.y; j < nrow; j += kWarpSize * gridDim.y) {\n#pragma unroll\n      for (int i = 0; i < nproc_per_thread; ++i) {\n        int row = j + i * blockDim.y;\n        if (row < nrow) {\n          int offset = row * ncol + col;\n          const ComputeType dy_val = static_cast<ComputeType>(dy[offset]);\n          const ComputeType x_val = static_cast<ComputeType>(x[offset]);\n          const ComputeType inv_rms_val = inv_rms[row];\n          // accumulate the per-thread partial weight-grad sum across waves\n          dweight_sum[i] += dy_val * x_val * inv_rms_val;\n        }\n      }\n    }\n  }\n  // stage the partial sums in shared memory; each thread fills one element in\n  // each of its nproc_per_thread tile rows\n#pragma unroll\n  for (int i = 0; i < nproc_per_thread; ++i) {\n    dweight[i * blockDim.y + threadIdx.y][threadIdx.x] = dweight_sum[i];\n  }\n  __syncthreads();\n  // transposed access lets a warp reduce the rows of the tile within the block\n#pragma unroll\n  for (int i = 0; i < nproc_per_thread; ++i) {\n    // lane 0 of each warp holds the reduced sum of the rows; across the\n    // iterations each lane-0 thread writes nproc_per_thread output columns\n    const int row_in_block = threadIdx.y + i * blockDim.y;\n    const int col = blockIdx.x * blockDim.x + row_in_block;\n    if (col < ncol) {\n      // each warp reduces all rows of one column\n      ComputeType dweight_val = dweight[threadIdx.x][row_in_block];\n      ComputeType global_dweight = WarpReduceSum<ComputeType>(dweight_val);\n      if (threadIdx.x == 0) {\n        const int offset = blockIdx.y * ncol + col;\n        b_weight_grad[offset] = global_dweight;\n      }\n    }\n  }\n}\n\ntemplate<int nproc_per_thread, typename T>\ncudaError_t GetGrid2Dim(const int64_t nrow, const int64_t ncol, int block_dim_x, int block_dim_y,\n                        int* grid_dim_x, int* grid_dim_y) {\n  const int tile_size = block_dim_x;\n  if (nproc_per_thread * block_dim_y != tile_size) { return cudaErrorInvalidValue; }\n  *grid_dim_x = (ncol + tile_size - 1) / tile_size;\n  const int num_blocks_y = (nrow + tile_size - 1) / tile_size;\n\n  using ComputeType = typename layer_norm::DefaultComputeType<T>::type;\n  cudaError_t err = layer_norm::GetNumBlocks(RmsNormParamGrad<nproc_per_thread, T, ComputeType>,\n                                             block_dim_x * block_dim_y, /*dynamic_smem_size*/ 0,\n                                             num_blocks_y, /*waves*/ 1, grid_dim_y);\n  if (err != cudaSuccess) { return err; }\n  return cudaSuccess;\n}\n\n}  // namespace rms_norm\n}  // namespace cuda\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CUDA_RMS_NORM_H_\n"
  },
  {
    "path": "oneflow/core/cuda/softmax.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_CUDA_SOFTMAX_H_\n#define ONEFLOW_CORE_CUDA_SOFTMAX_H_\n\n#include <cub/cub.cuh>\n#include <math_constants.h>\n#include <assert.h>\n#include <cuda.h>\n\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\n\nnamespace oneflow {\n\nnamespace cuda {\n\nnamespace softmax {\n\nconstexpr int kWarpSize = 32;\n\ntemplate<typename T>\nstruct SumOp {\n  __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a + b; }\n};\n\ntemplate<typename T>\nstruct MaxOp {\n  __device__ __forceinline__ T operator()(const T& a, const T& b) const { return max(a, b); }\n};\n\ntemplate<template<typename> class ReductionOp, typename T, int thread_group_width = kWarpSize>\n__inline__ __device__ T WarpAllReduce(T val) {\n  for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {\n    val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));\n  }\n  return val;\n}\n\ntemplate<template<typename> class ReductionOp, typename T, int block_size>\n__inline__ __device__ T BlockAllReduce(T val) {\n  typedef cub::BlockReduce<T, block_size> BlockReduce;\n  __shared__ typename BlockReduce::TempStorage temp_storage;\n  __shared__ T result_broadcast;\n  T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());\n  if (threadIdx.x == 0) { result_broadcast = result; }\n  __syncthreads();\n  return result_broadcast;\n}\n\ntemplate<typename T>\n__inline__ __device__ T Inf();\n\ntemplate<>\n__inline__ __device__ float Inf<float>() {\n  return CUDART_INF_F;\n}\n\ntemplate<>\n__inline__ __device__ double Inf<double>() {\n  return CUDART_INF;\n}\n\ntemplate<typename T>\n__inline__ __device__ T Exp(T x);\n\ntemplate<>\n__inline__ __device__ float Exp<float>(float x) {\n#ifdef OF_SOFTMAX_USE_FAST_MATH\n  return __expf(x);\n#else\n  return exp(x);\n#endif\n}\n\ntemplate<>\n__inline__ __device__ double Exp<double>(double x) {\n  return exp(x);\n}\n\ntemplate<typename T>\n__inline__ __device__ T Div(T a, T b);\n\ntemplate<>\n__inline__ __device__ float Div<float>(float a, float b) {\n#ifdef OF_SOFTMAX_USE_FAST_MATH\n  return __fdividef(a, b);\n#else\n  return a / b;\n#endif\n}\n\ntemplate<>\n__inline__ __device__ double Div<double>(double a, double b) {\n  return a / b;\n}\n\ntemplate<typename T>\n__inline__ __device__ T Log(T x);\n\ntemplate<>\n__inline__ __device__ float Log<float>(float x) {\n#ifdef OF_SOFTMAX_USE_FAST_MATH\n  return __logf(x);\n#else\n  return log(x);\n#endif\n}\ntemplate<>\n__inline__ __device__ double Log<double>(double x) {\n  return log(x);\n}\n\ninline cudaError_t GetNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves,\n                                int* num_blocks) {\n  int dev;\n  {\n    cudaError_t err = cudaGetDevice(&dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  int sm_count;\n  {\n    cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);\n 
   if (err != cudaSuccess) { return err; }\n  }\n  int tpm;\n  {\n    cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  *num_blocks =\n      std::max<int>(1, std::min<int64_t>(max_blocks, sm_count * tpm / block_size * waves));\n  return cudaSuccess;\n}\n\ntemplate<typename T>\nstruct DefaultComputeType {\n  using type = T;\n};\n\ntemplate<>\nstruct DefaultComputeType<half> {\n  using type = float;\n};\n\n#if CUDA_VERSION >= 11000\ntemplate<>\nstruct DefaultComputeType<nv_bfloat16> {\n  using type = float;\n};\n#endif  // CUDA_VERSION >= 11000\n\ntemplate<typename T, int N>\nstruct GetPackType {\n  using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;\n};\n\ntemplate<typename T, int N>\nusing PackType = typename GetPackType<T, N>::type;\n\ntemplate<typename T, int N>\nunion Pack {\n  static_assert(sizeof(PackType<T, N>) == sizeof(T) * N, \"\");\n  __device__ Pack() {\n    // do nothing\n  }\n  PackType<T, N> storage;\n  T elem[N];\n};\n\ntemplate<typename SRC, typename DST>\nstruct DirectLoad {\n  DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {}\n  template<int N>\n  __device__ void load(DST* dst, int64_t row, int64_t col) const {\n    Pack<SRC, N> pack;\n    const int64_t offset = (row * row_size + col) / N;\n    pack.storage = *(reinterpret_cast<const PackType<SRC, N>*>(src) + offset);\n#pragma unroll\n    for (int i = 0; i < N; ++i) { dst[i] = static_cast<DST>(pack.elem[i]); }\n  }\n  const SRC* src;\n  int64_t row_size;\n};\n\ntemplate<typename SRC, typename DST>\nstruct DirectStore {\n  DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {}\n  template<int N>\n  __device__ void store(const SRC* src, int64_t row, int64_t col) {\n    Pack<DST, N> pack;\n    const int64_t offset = (row * row_size + col) / N;\n#pragma unroll\n    for (int i = 0; i < N; ++i) { pack.elem[i] = static_cast<DST>(src[i]); }\n    *(reinterpret_cast<PackType<DST, N>*>(dst) + offset) = pack.storage;\n  }\n  DST* dst;\n  int64_t row_size;\n};\n\nenum class Algorithm {\n  kSoftmax = 0,\n  kLogSoftmax = 1,\n};\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,\n         int thread_group_width, int rows_per_access, bool padding, Algorithm algorithm>\n__global__ void SoftmaxWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols) {\n  static_assert(cols_per_thread % pack_size == 0, \"\");\n  static_assert(thread_group_width <= kWarpSize, \"\");\n  static_assert(kWarpSize % thread_group_width == 0, \"\");\n  constexpr int num_packs = cols_per_thread / pack_size;\n  assert(cols <= cols_per_thread * thread_group_width);\n  ComputeType buf[rows_per_access][cols_per_thread];\n  const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;\n  const int num_global_thread_group = gridDim.x * blockDim.y;\n  const int lane_id = threadIdx.x;\n  const int64_t step = num_global_thread_group * rows_per_access;\n  for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) {\n    ComputeType thread_max[rows_per_access];\n#pragma unroll\n    for (int row_id = 0; row_id < rows_per_access; ++row_id) {\n      thread_max[row_id] = -Inf<ComputeType>();\n      ComputeType* row_buf = buf[row_id];\n#pragma unroll\n      for (int pack_id = 0; pack_id < num_packs; ++pack_id) {\n        const int pack_offset = pack_id * pack_size;\n        const int col = (pack_id * 
thread_group_width + lane_id) * pack_size;\n        if (!padding || col < cols) {\n          load.template load<pack_size>(row_buf + pack_offset, row + row_id, col);\n#pragma unroll\n          for (int i = 0; i < pack_size; ++i) {\n            thread_max[row_id] = max(thread_max[row_id], row_buf[pack_offset + i]);\n          }\n        } else {\n#pragma unroll\n          for (int i = 0; i < pack_size; ++i) { row_buf[pack_offset + i] = -Inf<ComputeType>(); }\n        }\n      }\n    }\n    ComputeType warp_max[rows_per_access];\n#pragma unroll\n    for (int row_id = 0; row_id < rows_per_access; ++row_id) {\n      warp_max[row_id] = WarpAllReduce<MaxOp, ComputeType, thread_group_width>(thread_max[row_id]);\n    }\n    ComputeType thread_sum[rows_per_access];\n#pragma unroll\n    for (int row_id = 0; row_id < rows_per_access; ++row_id) {\n      thread_sum[row_id] = 0;\n      ComputeType* row_buf = buf[row_id];\n#pragma unroll\n      for (int i = 0; i < cols_per_thread; ++i) {\n        if (algorithm == Algorithm::kSoftmax) {\n          row_buf[i] = Exp(row_buf[i] - warp_max[row_id]);\n          thread_sum[row_id] += row_buf[i];\n        } else if (algorithm == Algorithm::kLogSoftmax) {\n          row_buf[i] -= warp_max[row_id];\n          thread_sum[row_id] += Exp(row_buf[i]);\n        } else {\n          __trap();\n        }\n      }\n    }\n    ComputeType warp_sum[rows_per_access];\n#pragma unroll\n    for (int row_id = 0; row_id < rows_per_access; ++row_id) {\n      warp_sum[row_id] = WarpAllReduce<SumOp, ComputeType, thread_group_width>(thread_sum[row_id]);\n    }\n#pragma unroll\n    for (int row_id = 0; row_id < rows_per_access; ++row_id) {\n      ComputeType* row_buf = buf[row_id];\n#pragma unroll\n      for (int i = 0; i < cols_per_thread; ++i) {\n        if (algorithm == Algorithm::kSoftmax) {\n          row_buf[i] = Div(row_buf[i], warp_sum[row_id]);\n        } else if (algorithm == Algorithm::kLogSoftmax) {\n          row_buf[i] -= Log(warp_sum[row_id]);\n        } else {\n          __trap();\n        }\n      }\n#pragma unroll\n      for (int i = 0; i < num_packs; ++i) {\n        const int col = (i * thread_group_width + lane_id) * pack_size;\n        if (!padding || col < cols) {\n          store.template store<pack_size>(row_buf + i * pack_size, row + row_id, col);\n        }\n      }\n    }\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,\n         int thread_group_width, int rows_per_access, bool padding, Algorithm algorithm>\ninline cudaError_t LaunchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store,\n                                         const int64_t rows, const int64_t cols) {\n  constexpr int block_size = 128;\n  constexpr int waves = 32;\n  static_assert(block_size % thread_group_width == 0, \"\");\n  constexpr int thread_groups_per_block = block_size / thread_group_width;\n  dim3 block_dim(thread_group_width, thread_groups_per_block);\n  const int64_t num_blocks =\n      (rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;\n  int grid_dim_x;\n  {\n    cudaError_t err = GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  SoftmaxWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread, thread_group_width,\n                  rows_per_access, padding, algorithm>\n      <<<grid_dim_x, block_dim, 0, stream>>>(load, store, rows, cols);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD, typename STORE, 
typename ComputeType, int pack_size, int cols_per_thread,\n         int thread_group_width, int rows_per_access, Algorithm algorithm>\ninline cudaError_t DispatchSoftmaxWarpImplPadding(cudaStream_t stream, LOAD load, STORE store,\n                                                  const int64_t rows, const int64_t cols) {\n  if (cols == cols_per_thread * thread_group_width) {\n    return LaunchSoftmaxWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread,\n                                 thread_group_width, rows_per_access, false, algorithm>(\n        stream, load, store, rows, cols);\n  } else {\n    return LaunchSoftmaxWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread,\n                                 thread_group_width, rows_per_access, true, algorithm>(\n        stream, load, store, rows, cols);\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>\ntypename std::enable_if<pack_size == 1, cudaError_t>::type DispatchSoftmaxWarpImplCols(\n    cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) {\n  if (cols <= 0) { return cudaErrorInvalidValue; }\n#define DEFINE_ONE_ELIF(thread_group_width)                                                        \\\n  else if (cols <= (thread_group_width)*pack_size) {                                               \\\n    if (rows % 2 == 0) {                                                                           \\\n      return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size,        \\\n                                            thread_group_width, 2, algorithm>(stream, load, store, \\\n                                                                              rows, cols);         \\\n    } else {                                                                                       \\\n      return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size,        \\\n                                            thread_group_width, 1, algorithm>(stream, load, store, \\\n                                                                              rows, cols);         \\\n    }                                                                                              \\\n  }\n  DEFINE_ONE_ELIF(1)\n  DEFINE_ONE_ELIF(2)\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n#define DEFINE_ONE_ELIF(col)                                                                      \\\n  else if (cols <= (col)*kWarpSize) {                                                             \\\n    return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, kWarpSize, 1, \\\n                                          algorithm>(stream, load, store, rows, cols);            \\\n  }\n  DEFINE_ONE_ELIF(2)\n  DEFINE_ONE_ELIF(3)\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(5)\n  DEFINE_ONE_ELIF(6)\n  DEFINE_ONE_ELIF(7)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(9)\n  DEFINE_ONE_ELIF(10)\n  DEFINE_ONE_ELIF(11)\n  DEFINE_ONE_ELIF(12)\n  DEFINE_ONE_ELIF(13)\n  DEFINE_ONE_ELIF(14)\n  DEFINE_ONE_ELIF(15)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(17)\n  DEFINE_ONE_ELIF(18)\n  DEFINE_ONE_ELIF(19)\n  DEFINE_ONE_ELIF(20)\n  DEFINE_ONE_ELIF(21)\n  DEFINE_ONE_ELIF(22)\n  DEFINE_ONE_ELIF(23)\n  DEFINE_ONE_ELIF(24)\n  DEFINE_ONE_ELIF(25)\n  DEFINE_ONE_ELIF(26)\n  DEFINE_ONE_ELIF(27)\n  DEFINE_ONE_ELIF(28)\n  DEFINE_ONE_ELIF(29)\n  DEFINE_ONE_ELIF(30)\n  
DEFINE_ONE_ELIF(31)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n  else {\n    return cudaErrorInvalidValue;\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>\ntypename std::enable_if<pack_size == 2, cudaError_t>::type DispatchSoftmaxWarpImplCols(\n    cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) {\n  if (cols <= 0) { return cudaErrorInvalidValue; }\n#define DEFINE_ONE_ELIF(thread_group_width)                                                        \\\n  else if (cols <= (thread_group_width)*pack_size) {                                               \\\n    if (rows % 2 == 0) {                                                                           \\\n      return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size,        \\\n                                            thread_group_width, 2, algorithm>(stream, load, store, \\\n                                                                              rows, cols);         \\\n    } else {                                                                                       \\\n      return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size,        \\\n                                            thread_group_width, 1, algorithm>(stream, load, store, \\\n                                                                              rows, cols);         \\\n    }                                                                                              \\\n  }\n  DEFINE_ONE_ELIF(1)\n  DEFINE_ONE_ELIF(2)\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n#define DEFINE_ONE_ELIF(col)                                                                      \\\n  else if (cols <= (col)*kWarpSize) {                                                             \\\n    return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, kWarpSize, 1, \\\n                                          algorithm>(stream, load, store, rows, cols);            \\\n  }\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(6)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(10)\n  DEFINE_ONE_ELIF(12)\n  DEFINE_ONE_ELIF(14)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(18)\n  DEFINE_ONE_ELIF(20)\n  DEFINE_ONE_ELIF(22)\n  DEFINE_ONE_ELIF(24)\n  DEFINE_ONE_ELIF(26)\n  DEFINE_ONE_ELIF(28)\n  DEFINE_ONE_ELIF(30)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n  else {\n    return cudaErrorInvalidValue;\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>\nstruct DispatchSoftmaxWarpImplPackSize {\n  cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,\n                         const int64_t cols) {\n    if (cols % 2 == 0) {\n      return DispatchSoftmaxWarpImplCols<LOAD, STORE, ComputeType, 2, algorithm>(stream, load,\n                                                                                 store, rows, cols);\n    } else {\n      return DispatchSoftmaxWarpImplCols<LOAD, STORE, ComputeType, 1, algorithm>(stream, load,\n                                                                                 store, rows, cols);\n    }\n  }\n};\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>\ninline cudaError_t DispatchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store,\n                                           const int64_t rows, const int64_t 
cols) {\n  return DispatchSoftmaxWarpImplPackSize<LOAD, STORE, ComputeType, algorithm>()(stream, load, store,\n                                                                                rows, cols);\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size,\n         Algorithm algorithm>\n__global__ void SoftmaxBlockSMemImpl(LOAD load, STORE store, const int64_t rows,\n                                     const int64_t cols) {\n  extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];\n  auto* buf = reinterpret_cast<ComputeType*>(shared_buf);\n  const int tid = threadIdx.x;\n  assert(cols % pack_size == 0);\n  const int num_packs = cols / pack_size;\n  for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {\n    ComputeType thread_max = -Inf<ComputeType>();\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      ComputeType pack[pack_size];\n      load.template load<pack_size>(pack, row, pack_id * pack_size);\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) {\n        buf[i * num_packs + pack_id] = pack[i];\n        thread_max = max(thread_max, pack[i]);\n      }\n    }\n    const ComputeType row_max = BlockAllReduce<MaxOp, ComputeType, block_size>(thread_max);\n    ComputeType thread_sum = 0;\n    for (int col = tid; col < cols; col += block_size) {\n      if (algorithm == Algorithm::kSoftmax) {\n        const ComputeType exp_x = Exp(buf[col] - row_max);\n        buf[col] = exp_x;\n        thread_sum += exp_x;\n      } else {\n        const ComputeType x = buf[col] - row_max;\n        buf[col] = x;\n        thread_sum += Exp(x);\n      }\n    }\n    const ComputeType row_sum = BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum);\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      ComputeType pack[pack_size];\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) {\n        if (algorithm == Algorithm::kSoftmax) {\n          pack[i] = Div(buf[i * num_packs + pack_id], row_sum);\n        } else if (algorithm == Algorithm::kLogSoftmax) {\n          pack[i] = buf[i * num_packs + pack_id] - Log(row_sum);\n        } else {\n          __trap();\n        }\n      }\n      store.template store<pack_size>(pack, row, pack_id * pack_size);\n    }\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size,\n         Algorithm algorithm>\ninline cudaError_t LaunchSoftmaxBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store, int smem,\n                                              const int64_t rows, const int64_t cols) {\n  constexpr int waves = 32;\n  int grid_dim_x;\n  {\n    cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size, algorithm>\n      <<<grid_dim_x, block_size, smem, stream>>>(load, store, rows, cols);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>\ninline cudaError_t TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream, LOAD load,\n                                                            STORE store, const int64_t rows,\n                                                            const int64_t cols, bool* success) {\n  constexpr int block_size_conf_1 = 128;\n  constexpr int block_size_conf_2 = 256;\n  constexpr int block_size_conf_3 = 512;\n  constexpr int 
block_size_conf_4 = 1024;\n  const size_t smem = cols * sizeof(ComputeType);\n  int max_active_blocks_conf_1;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks_conf_1,\n        SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1, algorithm>,\n        block_size_conf_1, smem);\n    if (err != cudaSuccess) { return err; }\n  }\n  if (max_active_blocks_conf_1 <= 0) {\n    *success = false;\n    return cudaSuccess;\n  }\n  int max_active_blocks_conf_4;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks_conf_4,\n        SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4, algorithm>,\n        block_size_conf_4, smem);\n    if (err != cudaSuccess) { return err; }\n  }\n  if (max_active_blocks_conf_4 == max_active_blocks_conf_1) {\n    *success = true;\n    return LaunchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4,\n                                      algorithm>(stream, load, store, smem, rows, cols);\n  }\n  int max_active_blocks_conf_3;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks_conf_3,\n        SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3, algorithm>,\n        block_size_conf_3, smem);\n    if (err != cudaSuccess) { return err; }\n  }\n  if (max_active_blocks_conf_3 == max_active_blocks_conf_1) {\n    *success = true;\n    return LaunchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3,\n                                      algorithm>(stream, load, store, smem, rows, cols);\n  }\n  int max_active_blocks_conf_2;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks_conf_2,\n        SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2, algorithm>,\n        block_size_conf_2, smem);\n    if (err != cudaSuccess) { return err; }\n  }\n  if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {\n    *success = true;\n    return LaunchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2,\n                                      algorithm>(stream, load, store, smem, rows, cols);\n  }\n  *success = true;\n  return LaunchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1,\n                                    algorithm>(stream, load, store, smem, rows, cols);\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>\nstruct TryDispatchSoftmaxBlockSMemImplPackSize {\n  cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,\n                         const int64_t cols, bool* success) {\n    if (cols % 2 == 0) {\n      return TryDispatchSoftmaxBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 2, algorithm>(\n          stream, load, store, rows, cols, success);\n    } else {\n      return TryDispatchSoftmaxBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 1, algorithm>(\n          stream, load, store, rows, cols, success);\n    }\n  }\n};\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>\ninline cudaError_t TryDispatchSoftmaxBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store,\n                                                   const int64_t rows, const int64_t cols,\n                                                   bool* success) {\n  return TryDispatchSoftmaxBlockSMemImplPackSize<LOAD, 
STORE, ComputeType, algorithm>()(\n      stream, load, store, rows, cols, success);\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size,\n         Algorithm algorithm>\n__global__ void SoftmaxBlockUncachedImpl(LOAD load, STORE store, const int64_t rows,\n                                         const int64_t cols) {\n  const int tid = threadIdx.x;\n  assert(cols % pack_size == 0);\n  const int num_packs = cols / pack_size;\n  for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {\n    ComputeType thread_max = -Inf<ComputeType>();\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      ComputeType pack[pack_size];\n      load.template load<pack_size>(pack, row, pack_id * pack_size);\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) { thread_max = max(thread_max, pack[i]); }\n    }\n    const ComputeType row_max = BlockAllReduce<MaxOp, ComputeType, block_size>(thread_max);\n    ComputeType thread_sum = 0;\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      ComputeType pack[pack_size];\n      load.template load<pack_size>(pack, row, pack_id * pack_size);\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) { thread_sum += Exp(pack[i] - row_max); }\n    }\n    const ComputeType row_sum = BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum);\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      ComputeType pack[pack_size];\n      load.template load<pack_size>(pack, row, pack_id * pack_size);\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) {\n        if (algorithm == Algorithm::kSoftmax) {\n          pack[i] = Div(Exp(pack[i] - row_max), row_sum);\n        } else if (algorithm == Algorithm::kLogSoftmax) {\n          pack[i] = (pack[i] - row_max) - Log(row_sum);\n        } else {\n          __trap();\n        }\n      }\n      store.template store<pack_size>(pack, row, pack_id * pack_size);\n    }\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>\ninline cudaError_t LaunchSoftmaxBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,\n                                                  const int64_t rows, const int64_t cols) {\n  constexpr int block_size = 1024;\n  constexpr int waves = 32;\n  int grid_dim_x;\n  {\n    cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  SoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size, algorithm>\n      <<<grid_dim_x, block_size, 0, stream>>>(load, store, rows, cols);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>\nstruct DispatchSoftmaxBlockUncachedImplPackSize {\n  cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,\n                         const int64_t cols) {\n    if (cols % 2 == 0) {\n      return LaunchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, 2, algorithm>(\n          stream, load, store, rows, cols);\n    } else {\n      return LaunchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, 1, algorithm>(\n          stream, load, store, rows, cols);\n    }\n  }\n};\n\ntemplate<typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>\ninline cudaError_t DispatchSoftmaxBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,\n                                                    const int64_t rows, 
const int64_t cols) {\n  return DispatchSoftmaxBlockUncachedImplPackSize<LOAD, STORE, ComputeType, algorithm>()(\n      stream, load, store, rows, cols);\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ninline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type\nDispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,\n                const int64_t cols) {\n  if (cols < 1024) {\n    return DispatchSoftmaxWarpImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(\n        stream, load, store, rows, cols);\n  } else {\n    bool dispatch_smem_impl_success;\n    {\n      cudaError_t err =\n          TryDispatchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(\n              stream, load, store, rows, cols, &dispatch_smem_impl_success);\n      if (err != cudaSuccess) { return err; }\n    }\n    if (!dispatch_smem_impl_success) {\n      return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(\n          stream, load, store, rows, cols);\n    }\n    return cudaSuccess;\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ninline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type\nDispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,\n                const int64_t cols) {\n  return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(\n      stream, load, store, rows, cols);\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ninline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type\nDispatchLogSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,\n                   const int64_t cols) {\n  if (cols <= 1024) {\n    return DispatchSoftmaxWarpImpl<LOAD, STORE, ComputeType, Algorithm::kLogSoftmax>(\n        stream, load, store, rows, cols);\n  } else {\n    bool dispatch_smem_impl_success;\n    {\n      cudaError_t err =\n          TryDispatchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, Algorithm::kLogSoftmax>(\n              stream, load, store, rows, cols, &dispatch_smem_impl_success);\n      if (err != cudaSuccess) { return err; }\n    }\n    if (!dispatch_smem_impl_success) {\n      return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, Algorithm::kLogSoftmax>(\n          stream, load, store, rows, cols);\n    }\n    return cudaSuccess;\n  }\n}\n\ntemplate<typename LOAD, typename STORE, typename ComputeType>\ninline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type\nDispatchLogSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,\n                   const int64_t cols) {\n  return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, Algorithm::kLogSoftmax>(\n      stream, load, store, rows, cols);\n}\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,\n         int cols_per_thread, int thread_group_width, int rows_per_access, bool padding,\n         Algorithm algorithm>\n__global__ void SoftmaxGradWarpImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows,\n                                    const int64_t cols) {\n  static_assert(cols_per_thread % pack_size == 0, \"\");\n  constexpr int pack_per_thread = cols_per_thread / pack_size;\n  assert(cols <= cols_per_thread * thread_group_width);\n  static_assert(thread_group_width <= kWarpSize, \"\");\n  
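// y and dy for each row are staged entirely in registers; each thread group\n  // cooperatively handles rows_per_access rows per step.\n  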
static_assert(kWarpSize % thread_group_width == 0, \"\");\n  ComputeType y_buf[rows_per_access][cols_per_thread];\n  ComputeType dy_buf[rows_per_access][cols_per_thread];\n  const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;\n  const int num_global_thread_group = gridDim.x * blockDim.y;\n  const int lane_id = threadIdx.x;\n  const int64_t step = num_global_thread_group * rows_per_access;\n  for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) {\n    ComputeType thread_sum[rows_per_access];\n#pragma unroll\n    for (int row_id = 0; row_id < rows_per_access; ++row_id) {\n      thread_sum[row_id] = 0;\n      ComputeType* row_y_buf = y_buf[row_id];\n      ComputeType* row_dy_buf = dy_buf[row_id];\n#pragma unroll\n      for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) {\n        const int pack_offset = pack_id * pack_size;\n        const int col = (pack_id * thread_group_width + lane_id) * pack_size;\n        if (!padding || col < cols) {\n          load_y.template load<pack_size>(row_y_buf + pack_offset, row + row_id, col);\n          load_dy.template load<pack_size>(row_dy_buf + pack_offset, row + row_id, col);\n#pragma unroll\n          for (int i = 0; i < pack_size; ++i) {\n            if (algorithm == Algorithm::kSoftmax) {\n              thread_sum[row_id] += row_y_buf[pack_offset + i] * row_dy_buf[pack_offset + i];\n            } else if (algorithm == Algorithm::kLogSoftmax) {\n              thread_sum[row_id] += row_dy_buf[pack_offset + i];\n            } else {\n              __trap();\n            }\n          }\n        }\n      }\n    }\n    ComputeType warp_sum[rows_per_access];\n#pragma unroll\n    for (int row_id = 0; row_id < rows_per_access; ++row_id) {\n      warp_sum[row_id] = WarpAllReduce<SumOp, ComputeType, thread_group_width>(thread_sum[row_id]);\n    }\n#pragma unroll\n    for (int row_id = 0; row_id < rows_per_access; ++row_id) {\n      ComputeType* row_y_buf = y_buf[row_id];\n      ComputeType* row_dy_buf = dy_buf[row_id];\n#pragma unroll\n      for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) {\n        const int pack_offset = pack_id * pack_size;\n        const int col = (pack_id * thread_group_width + lane_id) * pack_size;\n        if (!padding || col < cols) {\n          for (int i = 0; i < pack_size; ++i) {\n            if (algorithm == Algorithm::kSoftmax) {\n              row_dy_buf[pack_offset + i] =\n                  (row_dy_buf[pack_offset + i] - warp_sum[row_id]) * row_y_buf[pack_offset + i];\n            } else if (algorithm == Algorithm::kLogSoftmax) {\n              row_dy_buf[pack_offset + i] -= Exp(row_y_buf[pack_offset + i]) * warp_sum[row_id];\n            } else {\n              __trap();\n            }\n          }\n          store.template store<pack_size>(row_dy_buf + pack_offset, row + row_id, col);\n        }\n      }\n    }\n  }\n}\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,\n         int cols_per_thread, int thread_group_width, int rows_per_access, bool padding,\n         Algorithm algorithm>\ninline cudaError_t LaunchSoftmaxGradWarpImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy,\n                                             STORE store, const int64_t rows, const int64_t cols) {\n  constexpr int block_size = 128;\n  constexpr int waves = 32;\n  static_assert(block_size % thread_group_width == 0, \"\");\n  constexpr int thread_groups_per_block = block_size / thread_group_width;\n  dim3 
block_dim(thread_group_width, thread_groups_per_block);\n  const int64_t num_blocks =\n      (rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;\n  int grid_dim_x;\n  {\n    cudaError_t err = GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  SoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, cols_per_thread,\n                      thread_group_width, rows_per_access, padding, algorithm>\n      <<<grid_dim_x, block_dim, 0, stream>>>(load_y, load_dy, store, rows, cols);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,\n         int cols_per_thread, int thread_group_width, int rows_per_access, Algorithm algorithm>\ninline cudaError_t DispatchSoftmaxGradWarpImplPadding(cudaStream_t stream, LOAD_Y load_y,\n                                                      LOAD_DY load_dy, STORE store,\n                                                      const int64_t rows, const int64_t cols) {\n  if (cols == cols_per_thread * thread_group_width) {\n    return LaunchSoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,\n                                     cols_per_thread, thread_group_width, rows_per_access, false,\n                                     algorithm>(stream, load_y, load_dy, store, rows, cols);\n  } else {\n    return LaunchSoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,\n                                     cols_per_thread, thread_group_width, rows_per_access, true,\n                                     algorithm>(stream, load_y, load_dy, store, rows, cols);\n  }\n}\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,\n         Algorithm algorithm>\ntypename std::enable_if<pack_size == 1, cudaError_t>::type DispatchSoftmaxGradWarpImplCols(\n    cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows,\n    const int64_t cols) {\n  if (cols <= 0) { return cudaErrorInvalidValue; }\n#define DEFINE_ONE_ELIF(thread_group_width)                                                     \\\n  else if (cols <= (thread_group_width)*pack_size) {                                            \\\n    if (rows % 2 == 0) {                                                                        \\\n      return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, \\\n                                                pack_size, thread_group_width, 2, algorithm>(   \\\n          stream, load_y, load_dy, store, rows, cols);                                          \\\n    } else {                                                                                    \\\n      return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, \\\n                                                pack_size, thread_group_width, 1, algorithm>(   \\\n          stream, load_y, load_dy, store, rows, cols);                                          \\\n    }                                                                                           \\\n  }\n  DEFINE_ONE_ELIF(1)\n  DEFINE_ONE_ELIF(2)\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n#define DEFINE_ONE_ELIF(col)                                                                       \\\n  else if (cols <= (col)*kWarpSize) {                                      
                        \\\n    return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, col, \\\n                                              kWarpSize, 1, algorithm>(stream, load_y, load_dy,    \\\n                                                                       store, rows, cols);         \\\n  }\n  DEFINE_ONE_ELIF(2)\n  DEFINE_ONE_ELIF(3)\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(5)\n  DEFINE_ONE_ELIF(6)\n  DEFINE_ONE_ELIF(7)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(9)\n  DEFINE_ONE_ELIF(10)\n  DEFINE_ONE_ELIF(11)\n  DEFINE_ONE_ELIF(12)\n  DEFINE_ONE_ELIF(13)\n  DEFINE_ONE_ELIF(14)\n  DEFINE_ONE_ELIF(15)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(17)\n  DEFINE_ONE_ELIF(18)\n  DEFINE_ONE_ELIF(19)\n  DEFINE_ONE_ELIF(20)\n  DEFINE_ONE_ELIF(21)\n  DEFINE_ONE_ELIF(22)\n  DEFINE_ONE_ELIF(23)\n  DEFINE_ONE_ELIF(24)\n  DEFINE_ONE_ELIF(25)\n  DEFINE_ONE_ELIF(26)\n  DEFINE_ONE_ELIF(27)\n  DEFINE_ONE_ELIF(28)\n  DEFINE_ONE_ELIF(29)\n  DEFINE_ONE_ELIF(30)\n  DEFINE_ONE_ELIF(31)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n  else {\n    return cudaErrorInvalidValue;\n  }\n}\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,\n         Algorithm algorithm>\ntypename std::enable_if<pack_size == 2, cudaError_t>::type DispatchSoftmaxGradWarpImplCols(\n    cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows,\n    const int64_t cols) {\n  if (cols <= 0) { return cudaErrorInvalidValue; }\n#define DEFINE_ONE_ELIF(thread_group_width)                                                     \\\n  else if (cols <= (thread_group_width)*pack_size) {                                            \\\n    if (rows % 2 == 0) {                                                                        \\\n      return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, \\\n                                                pack_size, thread_group_width, 2, algorithm>(   \\\n          stream, load_y, load_dy, store, rows, cols);                                          \\\n    } else {                                                                                    \\\n      return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, \\\n                                                pack_size, thread_group_width, 1, algorithm>(   \\\n          stream, load_y, load_dy, store, rows, cols);                                          \\\n    }                                                                                           \\\n  }\n  DEFINE_ONE_ELIF(1)\n  DEFINE_ONE_ELIF(2)\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n#define DEFINE_ONE_ELIF(col)                                                                       \\\n  else if (cols <= (col)*kWarpSize) {                                                              \\\n    return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, col, \\\n                                              kWarpSize, 1, algorithm>(stream, load_y, load_dy,    \\\n                                                                       store, rows, cols);         \\\n  }\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(6)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(10)\n  DEFINE_ONE_ELIF(12)\n  DEFINE_ONE_ELIF(14)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(18)\n  DEFINE_ONE_ELIF(20)\n  DEFINE_ONE_ELIF(22)\n  DEFINE_ONE_ELIF(24)\n  
DEFINE_ONE_ELIF(26)\n  DEFINE_ONE_ELIF(28)\n  DEFINE_ONE_ELIF(30)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n  else {\n    return cudaErrorInvalidValue;\n  }\n}\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,\n         Algorithm algorithm>\nstruct DispatchSoftmaxGradWarpImplPackSize {\n  cudaError_t operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,\n                         const int64_t rows, const int64_t cols) {\n    if (cols % 2 == 0) {\n      return DispatchSoftmaxGradWarpImplCols<LOAD_Y, LOAD_DY, STORE, ComputeType, 2, algorithm>(\n          stream, load_y, load_dy, store, rows, cols);\n    } else {\n      return DispatchSoftmaxGradWarpImplCols<LOAD_Y, LOAD_DY, STORE, ComputeType, 1, algorithm>(\n          stream, load_y, load_dy, store, rows, cols);\n    }\n  }\n};\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,\n         Algorithm algorithm>\ninline cudaError_t DispatchSoftmaxGradWarpImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy,\n                                               STORE store, const int64_t rows,\n                                               const int64_t cols) {\n  return DispatchSoftmaxGradWarpImplPackSize<LOAD_Y, LOAD_DY, STORE, ComputeType, algorithm>()(\n      stream, load_y, load_dy, store, rows, cols);\n}\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,\n         int block_size, Algorithm algorithm>\n__global__ void SoftmaxGradBlockSMemImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store,\n                                         const int64_t rows, const int64_t cols) {\n  extern __shared__ __align__(sizeof(double)) unsigned char grad_shared_buf[];\n  auto* y_buf = reinterpret_cast<ComputeType*>(grad_shared_buf);\n  auto* dy_buf = y_buf + cols;\n  const int tid = threadIdx.x;\n  assert(cols % pack_size == 0);\n  const int num_packs = cols / pack_size;\n  for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {\n    ComputeType thread_sum = 0;\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      ComputeType y_pack[pack_size];\n      ComputeType dy_pack[pack_size];\n      load_y.template load<pack_size>(y_pack, row, pack_id * pack_size);\n      load_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) {\n        y_buf[i * num_packs + pack_id] = y_pack[i];\n        dy_buf[i * num_packs + pack_id] = dy_pack[i];\n        if (algorithm == Algorithm::kSoftmax) {\n          thread_sum += y_pack[i] * dy_pack[i];\n        } else if (algorithm == Algorithm::kLogSoftmax) {\n          thread_sum += dy_pack[i];\n        } else {\n          __trap();\n        }\n      }\n    }\n    const ComputeType row_sum = BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum);\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      ComputeType pack[pack_size];\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) {\n        if (algorithm == Algorithm::kSoftmax) {\n          pack[i] = (dy_buf[i * num_packs + pack_id] - row_sum) * y_buf[i * num_packs + pack_id];\n        } else if (algorithm == Algorithm::kLogSoftmax) {\n          pack[i] = dy_buf[i * num_packs + pack_id] - Exp(y_buf[i * num_packs + pack_id]) * row_sum;\n        } else {\n          __trap();\n        }\n      }\n      store.template store<pack_size>(pack, row, pack_id * pack_size);\n    }\n  }\n}\n\ntemplate<typename 
LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,\n         int block_size, Algorithm algorithm>\ninline cudaError_t LaunchSoftmaxGradBlockSMemImpl(cudaStream_t stream, LOAD_Y load_y,\n                                                  LOAD_DY load_dy, STORE store, int smem,\n                                                  const int64_t rows, const int64_t cols) {\n  constexpr int waves = 32;\n  int grid_dim_x;\n  {\n    cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, block_size, algorithm>\n      <<<grid_dim_x, block_size, smem, stream>>>(load_y, load_dy, store, rows, cols);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,\n         Algorithm algorithm>\ninline cudaError_t TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t stream, LOAD_Y load_y,\n                                                                LOAD_DY load_dy, STORE store,\n                                                                const int64_t rows,\n                                                                const int64_t cols, bool* success) {\n  constexpr int block_size_conf_1 = 128;\n  constexpr int block_size_conf_2 = 256;\n  constexpr int block_size_conf_3 = 512;\n  constexpr int block_size_conf_4 = 1024;\n  const size_t smem = cols * sizeof(ComputeType) * 2;\n  int max_active_blocks_conf_1;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks_conf_1,\n        SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, block_size_conf_1,\n                                 algorithm>,\n        block_size_conf_1, smem);\n    if (err != cudaSuccess) { return err; }\n  }\n  if (max_active_blocks_conf_1 <= 0) {\n    *success = false;\n    return cudaSuccess;\n  }\n  int max_active_blocks_conf_4;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks_conf_4,\n        SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, block_size_conf_4,\n                                 algorithm>,\n        block_size_conf_4, smem);\n    if (err != cudaSuccess) { return err; }\n  }\n  if (max_active_blocks_conf_4 == max_active_blocks_conf_1) {\n    *success = true;\n    return LaunchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,\n                                          block_size_conf_4, algorithm>(stream, load_y, load_dy,\n                                                                        store, smem, rows, cols);\n  }\n  int max_active_blocks_conf_3;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks_conf_3,\n        SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, block_size_conf_3,\n                                 algorithm>,\n        block_size_conf_3, smem);\n    if (err != cudaSuccess) { return err; }\n  }\n  if (max_active_blocks_conf_3 == max_active_blocks_conf_1) {\n    *success = true;\n    return LaunchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,\n                                          block_size_conf_3, algorithm>(stream, load_y, load_dy,\n                                                                        store, smem, rows, cols);\n  }\n  int max_active_blocks_conf_2;\n  {\n    
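// Probe 256 threads last: the 1024- and 512-thread configs above failed to match\n    // the occupancy of 128-thread blocks; use 256 if it matches, else fall back to 128.\n    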
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks_conf_2,\n        SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, block_size_conf_2,\n                                 algorithm>,\n        block_size_conf_2, smem);\n    if (err != cudaSuccess) { return err; }\n  }\n  if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {\n    *success = true;\n    return LaunchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,\n                                          block_size_conf_2, algorithm>(stream, load_y, load_dy,\n                                                                        store, smem, rows, cols);\n  }\n  *success = true;\n  return LaunchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,\n                                        block_size_conf_1, algorithm>(stream, load_y, load_dy,\n                                                                      store, smem, rows, cols);\n}\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,\n         Algorithm algorithm>\nstruct TryDispatchSoftmaxGradBlockSMemImplPackSize {\n  cudaError_t operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,\n                         const int64_t rows, const int64_t cols, bool* success) {\n    if (cols % 2 == 0) {\n      return TryDispatchSoftmaxGradBlockSMemImplBlockSize<LOAD_Y, LOAD_DY, STORE, ComputeType, 2,\n                                                          algorithm>(stream, load_y, load_dy, store,\n                                                                     rows, cols, success);\n    } else {\n      return TryDispatchSoftmaxGradBlockSMemImplBlockSize<LOAD_Y, LOAD_DY, STORE, ComputeType, 1,\n                                                          algorithm>(stream, load_y, load_dy, store,\n                                                                     rows, cols, success);\n    }\n  }\n};\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,\n         Algorithm algorithm>\ninline cudaError_t TryDispatchSoftmaxGradBlockSMemImpl(cudaStream_t stream, LOAD_Y load_y,\n                                                       LOAD_DY load_dy, STORE store,\n                                                       const int64_t rows, const int64_t cols,\n                                                       bool* success) {\n  return TryDispatchSoftmaxGradBlockSMemImplPackSize<LOAD_Y, LOAD_DY, STORE, ComputeType,\n                                                     algorithm>()(stream, load_y, load_dy, store,\n                                                                  rows, cols, success);\n}\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,\n         int block_size, Algorithm algorithm>\n__global__ void SoftmaxGradBlockUncachedImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store,\n                                             const int64_t rows, const int64_t cols) {\n  const int tid = threadIdx.x;\n  assert(cols % pack_size == 0);\n  const int num_packs = cols / pack_size;\n  for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {\n    ComputeType thread_sum = 0;\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      ComputeType y_pack[pack_size];\n      ComputeType dy_pack[pack_size];\n      load_y.template load<pack_size>(y_pack, row, pack_id * pack_size);\n      load_dy.template 
load<pack_size>(dy_pack, row, pack_id * pack_size);\n\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) {\n        if (algorithm == Algorithm::kSoftmax) {\n          thread_sum += y_pack[i] * dy_pack[i];\n        } else if (algorithm == Algorithm::kLogSoftmax) {\n          thread_sum += dy_pack[i];\n        } else {\n          __trap();\n        }\n      }\n    }\n    const ComputeType row_sum = BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum);\n    for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {\n      ComputeType y_pack[pack_size];\n      ComputeType dy_pack[pack_size];\n      load_y.template load<pack_size>(y_pack, row, pack_id * pack_size);\n      load_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) {\n        if (algorithm == Algorithm::kSoftmax) {\n          dy_pack[i] = (dy_pack[i] - row_sum) * y_pack[i];\n        } else if (algorithm == Algorithm::kLogSoftmax) {\n          dy_pack[i] -= Exp(y_pack[i]) * row_sum;\n        } else {\n          __trap();\n        }\n      }\n      store.template store<pack_size>(dy_pack, row, pack_id * pack_size);\n    }\n  }\n}\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,\n         Algorithm algorithm>\ninline cudaError_t LaunchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, LOAD_Y load_y,\n                                                      LOAD_DY load_dy, STORE store,\n                                                      const int64_t rows, const int64_t cols) {\n  constexpr int block_size = 1024;\n  constexpr int waves = 32;\n  int grid_dim_x;\n  {\n    cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);\n    if (err != cudaSuccess) { return err; }\n  }\n  SoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, block_size,\n                               algorithm>\n      <<<grid_dim_x, block_size, 0, stream>>>(load_y, load_dy, store, rows, cols);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,\n         Algorithm algorithm>\nstruct DispatchSoftmaxGradBlockUncachedImplPackSize {\n  cudaError_t operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,\n                         const int64_t rows, const int64_t cols) {\n    if (cols % 2 == 0 && cols > kWarpSize) {\n      return LaunchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, 2, algorithm>(\n          stream, load_y, load_dy, store, rows, cols);\n    } else {\n      return LaunchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, 1, algorithm>(\n          stream, load_y, load_dy, store, rows, cols);\n    }\n  }\n};\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,\n         Algorithm algorithm>\ninline cudaError_t DispatchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, LOAD_Y load_y,\n                                                        LOAD_DY load_dy, STORE store,\n                                                        const int64_t rows, const int64_t cols) {\n  return DispatchSoftmaxGradBlockUncachedImplPackSize<LOAD_Y, LOAD_DY, STORE, ComputeType,\n                                                      algorithm>()(stream, load_y, load_dy, store,\n                                                                   rows, cols);\n}\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename 
ComputeType>\ninline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type\nDispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,\n                    const int64_t rows, const int64_t cols) {\n  if (cols <= 1024) {\n    return DispatchSoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, Algorithm::kSoftmax>(\n        stream, load_y, load_dy, store, rows, cols);\n  } else {\n    bool dispatch_smem_impl_success;\n    {\n      cudaError_t err = TryDispatchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,\n                                                            Algorithm::kSoftmax>(\n          stream, load_y, load_dy, store, rows, cols, &dispatch_smem_impl_success);\n      if (err != cudaSuccess) { return err; }\n    }\n    if (!dispatch_smem_impl_success) {\n      return DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,\n                                                  Algorithm::kSoftmax>(stream, load_y, load_dy,\n                                                                       store, rows, cols);\n    }\n    return cudaSuccess;\n  }\n}\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType>\ninline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type\nDispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,\n                    const int64_t rows, const int64_t cols) {\n  return DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,\n                                              Algorithm::kSoftmax>(stream, load_y, load_dy, store,\n                                                                   rows, cols);\n}\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType>\ninline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type\nDispatchLogSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,\n                       const int64_t rows, const int64_t cols) {\n  if (cols <= 1024) {\n    return DispatchSoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, Algorithm::kLogSoftmax>(\n        stream, load_y, load_dy, store, rows, cols);\n  } else {\n    bool dispatch_smem_impl_success;\n    {\n      cudaError_t err = TryDispatchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,\n                                                            Algorithm::kLogSoftmax>(\n          stream, load_y, load_dy, store, rows, cols, &dispatch_smem_impl_success);\n      if (err != cudaSuccess) { return err; }\n    }\n    if (!dispatch_smem_impl_success) {\n      return DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,\n                                                  Algorithm::kLogSoftmax>(stream, load_y, load_dy,\n                                                                          store, rows, cols);\n    }\n    return cudaSuccess;\n  }\n}\n\ntemplate<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType>\ninline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type\nDispatchLogSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,\n                       const int64_t rows, const int64_t cols) {\n  return DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,\n                                              Algorithm::kLogSoftmax>(stream, load_y, load_dy,\n                    
                                                  store, rows, cols);\n}\n\n}  // namespace softmax\n\n}  // namespace cuda\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CUDA_SOFTMAX_H_\n"
  },
  {
    "path": "oneflow/core/cuda/unique.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_CUDA_UNIQUE_H_\n#define ONEFLOW_CORE_CUDA_UNIQUE_H_\n\n#include <cub/cub.cuh>\n#include <device_launch_parameters.h>\n#include \"oneflow/core/common/permutation_iterator.h\"\n#include \"oneflow/core/common/not_equal_to_previous_adjacent_iterator.h\"\n\nnamespace oneflow {\n\nnamespace cuda {\n\nnamespace unique {\n\nusing Flag = uint32_t;\nstatic constexpr Flag kDefault = 0x0;\nstatic constexpr Flag kInputSorted = 0x1;\nstatic constexpr Flag kOutputInverseIndices = 0x1 << 1;\nstatic constexpr Flag kOutputCounts = 0x1 << 2;\n\nnamespace {\n\nconstexpr size_t kCudaAlignSize = 512;\n\n__device__ __host__ __forceinline__ size_t GetCudaAlignedSize(size_t size) {\n  return (size + kCudaAlignSize - 1) / kCudaAlignSize * kCudaAlignSize;\n}\n\ntemplate<typename T>\n__device__ __host__ __forceinline__ T* PtrOffset(void* ptr, size_t offset) {\n  return reinterpret_cast<T*>(reinterpret_cast<unsigned char*>(ptr) + offset);\n}\n\n__device__ __host__ __forceinline__ size_t max(size_t a, size_t b) { return a > b ? a : b; }\n\ntemplate<typename Key, typename Index>\ncudaError_t DoUnique(size_t n, const Key* sorted_in, Key* unique, Index* num_unique,\n                     void* workspace, size_t* workspace_size, cudaStream_t stream) {\n  size_t ws = *workspace_size;\n  cudaError_t err = cub::DeviceSelect::Unique<const Key*, Key*, Index*>(\n      workspace, ws, sorted_in, unique, num_unique, n, stream);\n  if (err != cudaSuccess) { return err; }\n  if (*workspace_size == 0) { *workspace_size = ws; }\n  return cudaSuccess;\n}\n\ntemplate<typename Key, typename Index>\ncudaError_t DoUniqueWithCounts(size_t n, const Key* sorted_in, Key* unique, Index* num_unique,\n                               Index* counts, void* workspace, size_t* workspace_size,\n                               cudaStream_t stream) {\n  size_t ws = *workspace_size;\n  cudaError_t err = cub::DeviceRunLengthEncode::Encode<const Key*, Key*, Index*, Index*>(\n      workspace, ws, sorted_in, unique, counts, num_unique, n, stream);\n  if (err != cudaSuccess) { return err; }\n  if (*workspace_size == 0) { *workspace_size = ws; }\n  return cudaSuccess;\n}\n\ntemplate<typename Key, typename Index>\ncudaError_t DispatchOutputCounts(Flag flag, size_t n, const Key* sorted_in, Key* unique,\n                                 Index* num_unique, Index* counts, void* workspace,\n                                 size_t* workspace_size, cudaStream_t stream) {\n  size_t ws = *workspace_size;\n  if ((flag & kOutputCounts) != 0) {\n    cudaError_t err = DoUniqueWithCounts<Key, Index>(n, sorted_in, unique, num_unique, counts,\n                                                     workspace, &ws, stream);\n    if (err != cudaSuccess) { return err; }\n  } else {\n    cudaError_t err =\n        DoUnique<Key, Index>(n, sorted_in, unique, num_unique, workspace, &ws, stream);\n    if (err != cudaSuccess) { return err; }\n  }\n  if 
(*workspace_size == 0) { *workspace_size = ws; }\n  return cudaSuccess;\n}\n\ntemplate<typename Key, typename Index, typename InverseIndicesIter>\ncudaError_t DoGenInverseIndices(size_t n, const Key* sorted_in,\n                                InverseIndicesIter inverse_indices_iter, void* workspace,\n                                size_t* workspace_size, cudaStream_t stream) {\n  size_t ws = *workspace_size;\n  NotEqualToPreviousAdjacentIterator<Index, Key> unique_counting_iter(sorted_in, 0);\n  cudaError_t err =\n      cub::DeviceScan::InclusiveSum<decltype(unique_counting_iter), InverseIndicesIter>(\n          workspace, ws, unique_counting_iter, inverse_indices_iter, n, stream);\n  if (err != cudaSuccess) { return err; }\n  if (*workspace_size == 0) { *workspace_size = ws; }\n  return cudaSuccess;\n}\n\ntemplate<typename Key, typename Index, typename InverseIndicesIter>\ncudaError_t DispatchOutputInverseIndices(Flag flag, size_t n, const Key* sorted_in, Key* unique,\n                                         Index* num_unique, InverseIndicesIter inverse_indices_iter,\n                                         Index* counts, void* workspace, size_t* workspace_size,\n                                         cudaStream_t stream) {\n  size_t dispatch_with_counts_ws = *workspace_size;\n  size_t do_gen_inverse_indices_ws = *workspace_size;\n  {\n    cudaError_t err =\n        DispatchOutputCounts<Key, Index>(flag, n, sorted_in, unique, num_unique, counts, workspace,\n                                         &dispatch_with_counts_ws, stream);\n    if (err != cudaSuccess) { return err; }\n  }\n  if ((flag & kOutputInverseIndices) != 0) {\n    cudaError_t err = DoGenInverseIndices<Key, Index, InverseIndicesIter>(\n        n, sorted_in, inverse_indices_iter, workspace, &do_gen_inverse_indices_ws, stream);\n    if (err != cudaSuccess) { return err; }\n  }\n  if (*workspace_size == 0) {\n    *workspace_size = max(dispatch_with_counts_ws, do_gen_inverse_indices_ws);\n  }\n  return cudaSuccess;\n}\n\ntemplate<typename T>\n__global__ void IotaKernel(size_t n, T* out) {\n  for (T i = blockIdx.x * blockDim.x + threadIdx.x, step = blockDim.x * gridDim.x; i < n;\n       i += step) {\n    out[i] = i;\n  }\n}\n\ntemplate<typename Key, typename Index>\ncudaError_t DoSort(size_t n, const Key* in, Key* sorted, Index* sorted_indices, void* workspace,\n                   size_t* workspace_size, cudaStream_t stream) {\n  Index* indices;\n  const size_t indices_size = GetCudaAlignedSize(n * sizeof(Index));\n  void* sort_workspace;\n  size_t sort_ws;\n  if (*workspace_size == 0) {\n    indices = nullptr;\n    sort_workspace = nullptr;\n    sort_ws = 0;\n  } else {\n    if (*workspace_size <= indices_size) { return cudaErrorInvalidValue; }\n    indices = PtrOffset<Index>(workspace, 0);\n    sort_workspace = PtrOffset<Index>(workspace, indices_size);\n    sort_ws = *workspace_size - indices_size;\n  }\n  if (*workspace_size != 0) {\n    const int block_size = 1024;\n    const int num_blocks = static_cast<int>((n + block_size - 1) / block_size);\n    IotaKernel<Index><<<num_blocks, block_size, 0, stream>>>(n, indices);\n  }\n  cudaError_t err = cub::DeviceRadixSort::SortPairs<Key, Index>(\n      sort_workspace, sort_ws, in, sorted, indices, sorted_indices, n, 0, sizeof(Key) * 8, stream);\n  if (err != cudaSuccess) { return err; }\n  if (*workspace_size == 0) { *workspace_size = indices_size + sort_ws; }\n  return cudaSuccess;\n}\n\ntemplate<typename Key, typename Index>\ncudaError_t DispatchInputSorted(Flag flag, 
size_t n, const Key* in, Key* unique, Index* num_unique,\n                                Index* inverse_indices, Index* counts, void* workspace,\n                                size_t* workspace_size, cudaStream_t stream) {\n  if ((flag & kInputSorted) != 0) {\n    return DispatchOutputInverseIndices<Key, Index, Index*>(flag, n, in, unique, num_unique,\n                                                            inverse_indices, counts, workspace,\n                                                            workspace_size, stream);\n  } else {\n    const size_t sorted_in_size = GetCudaAlignedSize(n * sizeof(Key));\n    const size_t sorted_indices_size = GetCudaAlignedSize(n * sizeof(Index));\n    const size_t sort_buffer_size = sorted_in_size + sorted_indices_size;\n    Key* sorted_in;\n    Index* sorted_indices;\n    size_t do_sort_ws;\n    void* do_sort_workspace;\n    size_t do_inverse_indices_ws;\n    void* do_inverse_indices_workspace;\n    if (*workspace_size == 0) {\n      sorted_in = nullptr;\n      sorted_indices = nullptr;\n      do_sort_ws = 0;\n      do_sort_workspace = nullptr;\n      do_inverse_indices_ws = 0;\n      do_inverse_indices_workspace = nullptr;\n    } else {\n      if (*workspace_size <= sort_buffer_size) { return cudaErrorInvalidValue; }\n      sorted_in = PtrOffset<Key>(workspace, 0);\n      sorted_indices = PtrOffset<Index>(workspace, sorted_in_size);\n      do_sort_ws = *workspace_size - sort_buffer_size;\n      do_sort_workspace = PtrOffset<void>(workspace, sort_buffer_size);\n      do_inverse_indices_ws = do_sort_ws;\n      do_inverse_indices_workspace = do_sort_workspace;\n    }\n    {\n      cudaError_t err = DoSort<Key, Index>(n, in, sorted_in, sorted_indices, do_sort_workspace,\n                                           &do_sort_ws, stream);\n      if (err != cudaSuccess) { return err; }\n    }\n    PermutationIterator<Index, Index*, Index*> inverse_indices_iter(inverse_indices,\n                                                                    sorted_indices);\n    {\n      cudaError_t err = DispatchOutputInverseIndices<Key, Index, decltype(inverse_indices_iter)>(\n          flag, n, sorted_in, unique, num_unique, inverse_indices_iter, counts,\n          do_inverse_indices_workspace, &do_inverse_indices_ws, stream);\n      if (err != cudaSuccess) { return err; }\n    }\n    if (*workspace_size == 0) {\n      *workspace_size = sort_buffer_size + max(do_sort_ws, do_inverse_indices_ws);\n    }\n    return cudaSuccess;\n  }\n}\n\n}  // namespace\n\ntemplate<typename Key, typename Index>\ncudaError_t Launch(Flag flag, size_t n, const Key* in, Key* unique, Index* num_unique,\n                   Index* inverse_indices, Index* counts, void* workspace, size_t workspace_size,\n                   cudaStream_t stream) {\n  if (workspace_size == 0) { return cudaErrorInvalidValue; }\n  return DispatchInputSorted<Key, Index>(flag, n, in, unique, num_unique, inverse_indices, counts,\n                                         workspace, &workspace_size, stream);\n}\n\ntemplate<typename Key, typename Index>\ncudaError_t GetWorkspaceSize(Flag flag, size_t n, size_t* workspace_size) {\n  *workspace_size = 0;\n  return DispatchInputSorted<Key, Index>(flag, n, nullptr, nullptr, nullptr, nullptr, nullptr,\n                                         nullptr, workspace_size, 0);\n}\n\n}  // namespace unique\n\n}  // namespace cuda\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_CUDA_UNIQUE_H_\n"
  },
  {
    "path": "oneflow/core/device/cuda_pseudo_bfloat16.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_DEVICE_CUDA_PSEUDO_BFLOAT16_H_\n#define ONEFLOW_CORE_DEVICE_CUDA_PSEUDO_BFLOAT16_H_\n\n#ifdef WITH_CUDA\n\n#include <cuda.h>\n#include <cuda_runtime_api.h>\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif\n\n#if CUDA_VERSION >= 11000 && CUDA_VERSION <= 12010 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n\n#define DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_OPERATOR(op)                \\\n  __device__ __forceinline__ __nv_bfloat16 operator op(const __nv_bfloat16& lh,   \\\n                                                       const __nv_bfloat16& rh) { \\\n    return __float2bfloat16(__bfloat162float(lh) op __bfloat162float(rh));        \\\n  }\n\nDEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_OPERATOR(+)\nDEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_OPERATOR(-)\nDEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_OPERATOR(*)\nDEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_OPERATOR(/)\n\n#undef DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_OPERATOR\n\n#define DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_FUNC(func)                       \\\n  __device__ __forceinline__ __nv_bfloat16 __h##func(const __nv_bfloat16 a,            \\\n                                                     const __nv_bfloat16 b) {          \\\n    return __float2bfloat16(__f##func##_rn(__bfloat162float(a), __bfloat162float(b))); \\\n  }\n\nDEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_FUNC(add)\nDEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_FUNC(div)\nDEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_FUNC(mul)\nDEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_FUNC(sub)\n\n#undef DEFINE_CUDA_PSEUDO_BFLOAT16_BFLOAT162_ARITHMETIC_BINARY_FUNC\n\n#define DEFINE_CUDA_PSEUDO_BFLOAT16_BFLOAT162_ARITHMETIC_BINARY_FUNC(func)         \\\n  __device__ __forceinline__ __nv_bfloat162 __h##func##2(const __nv_bfloat162 a,   \\\n                                                         const __nv_bfloat162 b) { \\\n    __nv_bfloat162 ret;                                                            \\\n    ret.x = __h##func(a.x, b.x);                                                   \\\n    ret.y = __h##func(a.y, b.y);                                                   \\\n    return ret;                                                                    \\\n  }\n\nDEFINE_CUDA_PSEUDO_BFLOAT16_BFLOAT162_ARITHMETIC_BINARY_FUNC(add)\nDEFINE_CUDA_PSEUDO_BFLOAT16_BFLOAT162_ARITHMETIC_BINARY_FUNC(div)\nDEFINE_CUDA_PSEUDO_BFLOAT16_BFLOAT162_ARITHMETIC_BINARY_FUNC(mul)\nDEFINE_CUDA_PSEUDO_BFLOAT16_BFLOAT162_ARITHMETIC_BINARY_FUNC(sub)\n\n#undef DEFINE_CUDA_PSEUDO_BFLOAT16_BFLOAT162_ARITHMETIC_BINARY_FUNC\n\n#define DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_ASSIGNMENT_OPERATOR(op)             \\\n  __device__ __forceinline__ __nv_bfloat16& operator op(__nv_bfloat16& lh,         \\\n                                                        const __nv_bfloat16& rh) { \\\n    float lhv = __bfloat162float(lh);         
                                     \\\n    lhv op __bfloat162float(rh);                                                   \\\n    lh = __float2bfloat16(lhv);                                                    \\\n    return lh;                                                                     \\\n  }\n\nDEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_ASSIGNMENT_OPERATOR(+=)\nDEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_ASSIGNMENT_OPERATOR(-=)\nDEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_ASSIGNMENT_OPERATOR(*=)\nDEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_ASSIGNMENT_OPERATOR(/=)\n\n#undef DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_ASSIGNMENT_OPERATOR\n\n__device__ __forceinline__ __nv_bfloat16& operator++(__nv_bfloat16& h) {\n  h = __float2bfloat16(__bfloat162float(h) + 1);\n  return h;\n}\n\n__device__ __forceinline__ __nv_bfloat16& operator--(__nv_bfloat16& h) {\n  h = __float2bfloat16(__bfloat162float(h) - 1);\n  return h;\n}\n\n__device__ __forceinline__ __nv_bfloat16 operator++(__nv_bfloat16& h, int) {\n  __nv_bfloat16 ret = h;\n  h = __float2bfloat16(__bfloat162float(h) + 1);\n  return ret;\n}\n\n__device__ __forceinline__ __nv_bfloat16 operator--(__nv_bfloat16& h, int) {\n  __nv_bfloat16 ret = h;\n  h = __float2bfloat16(__bfloat162float(h) - 1);\n  return ret;\n}\n\n__device__ __forceinline__ __nv_bfloat16 operator+(const __nv_bfloat16& h) { return h; }\n\n__device__ __forceinline__ __nv_bfloat16 operator-(const __nv_bfloat16& h) {\n  return __float2bfloat16(-__bfloat162float(h));\n}\n\n__device__ __forceinline__ __nv_bfloat16 __hneg(const __nv_bfloat16 a) { return -a; }\n\n#define DEFINE_CUDA_PSEUDO_BFLOAT16_COMPARISON_BINARY_OPERATOR(op)                                \\\n  __device__ __forceinline__ bool operator op(const __nv_bfloat16& lh, const __nv_bfloat16& rh) { \\\n    return __bfloat162float(lh) op __bfloat162float(rh);                                          \\\n  }\n\nDEFINE_CUDA_PSEUDO_BFLOAT16_COMPARISON_BINARY_OPERATOR(==)\nDEFINE_CUDA_PSEUDO_BFLOAT16_COMPARISON_BINARY_OPERATOR(!=)\nDEFINE_CUDA_PSEUDO_BFLOAT16_COMPARISON_BINARY_OPERATOR(>)\nDEFINE_CUDA_PSEUDO_BFLOAT16_COMPARISON_BINARY_OPERATOR(<)\nDEFINE_CUDA_PSEUDO_BFLOAT16_COMPARISON_BINARY_OPERATOR(>=)\nDEFINE_CUDA_PSEUDO_BFLOAT16_COMPARISON_BINARY_OPERATOR(<=)\n\n#undef DEFINE_CUDA_PSEUDO_BFLOAT16_COMPARISON_BINARY_OPERATOR\n\n__device__ __forceinline__ bool __heq(const __nv_bfloat16 a, const __nv_bfloat16 b) {\n  return a == b;\n}\n__device__ __forceinline__ bool __hge(const __nv_bfloat16 a, const __nv_bfloat16 b) {\n  return a >= b;\n}\n__device__ __forceinline__ bool __hgt(const __nv_bfloat16 a, const __nv_bfloat16 b) {\n  return a > b;\n}\n__device__ __forceinline__ bool __hle(const __nv_bfloat16 a, const __nv_bfloat16 b) {\n  return a <= b;\n}\n__device__ __forceinline__ bool __hlt(const __nv_bfloat16 a, const __nv_bfloat16 b) {\n  return a < b;\n}\n__device__ __forceinline__ bool __hne(const __nv_bfloat16 a, const __nv_bfloat16 b) {\n  return a != b;\n}\n__device__ __forceinline__ __nv_bfloat16 __hmax(const __nv_bfloat16 a, const __nv_bfloat16 b) {\n  return a > b ? a : b;\n}\n__device__ __forceinline__ __nv_bfloat16 __hmin(const __nv_bfloat16 a, const __nv_bfloat16 b) {\n  return a > b ? 
b : a;\n}\n\n#define DEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(func)                         \\\n  __device__ __forceinline__ __nv_bfloat16 h##func(const __nv_bfloat16 h) { \\\n    return __float2bfloat16(func##f(__bfloat162float(h)));                  \\\n  }\n\nDEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(cos)\nDEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(exp)\nDEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(exp10)\nDEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(exp2)\nDEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(log)\nDEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(log10)\nDEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(log2)\n\n__device__ __forceinline__ __nv_bfloat16 hrcp(const __nv_bfloat16 h) {\n  return __float2bfloat16(1.0f / __bfloat162float(h));\n}\n\nDEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(rsqrt)\nDEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(sin)\nDEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC(sqrt)\n\n#undef DEFINE_CUDA_PSEUDO_BFLOAT16_MATH_FUNC\n\n#endif  // CUDA_VERSION >= 11000 && CUDA_VERSION <= 12010 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_CORE_DEVICE_CUDA_PSEUDO_BFLOAT16_H_\n"
  },
  {
    "path": "oneflow/core/device/cuda_pseudo_half.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_DEVICE_CUDA_PSEUDO_HALF_H_\n#define ONEFLOW_CORE_DEVICE_CUDA_PSEUDO_HALF_H_\n\n#ifdef WITH_CUDA\n\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime_api.h>\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530\n\n#define DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_OPERATOR(op)                        \\\n  __device__ __forceinline__ __half operator op(const __half& lh, const __half& rh) { \\\n    return __float2half(__half2float(lh) op __half2float(rh));                        \\\n  }\n\nDEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_OPERATOR(+)\nDEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_OPERATOR(-)\nDEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_OPERATOR(*)\nDEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_OPERATOR(/)\n\n#undef DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_OPERATOR\n\n#define DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_FUNC(func)                    \\\n  __device__ __forceinline__ __half __h##func(const __half a, const __half b) { \\\n    return __float2half(__f##func##_rn(__half2float(a), __half2float(b)));      \\\n  }\n\nDEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_FUNC(add)\nDEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_FUNC(div)\nDEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_FUNC(mul)\nDEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_BINARY_FUNC(sub)\n\n#undef DEFINE_CUDA_PSEUDO_HALF_HALF2_ARITHMETIC_BINARY_FUNC\n\n#define DEFINE_CUDA_PSEUDO_HALF_HALF2_ARITHMETIC_BINARY_FUNC(func)                    \\\n  __device__ __forceinline__ __half2 __h##func##2(const __half2 a, const __half2 b) { \\\n    __half2 ret;                                                                      \\\n    ret.x = __h##func(a.x, b.x);                                                      \\\n    ret.y = __h##func(a.y, b.y);                                                      \\\n    return ret;                                                                       \\\n  }\n\nDEFINE_CUDA_PSEUDO_HALF_HALF2_ARITHMETIC_BINARY_FUNC(add)\nDEFINE_CUDA_PSEUDO_HALF_HALF2_ARITHMETIC_BINARY_FUNC(div)\nDEFINE_CUDA_PSEUDO_HALF_HALF2_ARITHMETIC_BINARY_FUNC(mul)\nDEFINE_CUDA_PSEUDO_HALF_HALF2_ARITHMETIC_BINARY_FUNC(sub)\n\n#undef DEFINE_CUDA_PSEUDO_HALF_HALF2_ARITHMETIC_BINARY_FUNC\n\n#define DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_ASSIGNMENT_OPERATOR(op)               \\\n  __device__ __forceinline__ __half& operator op(__half& lh, const __half& rh) { \\\n    float lhv = __half2float(lh);                                                \\\n    lhv op __half2float(rh);                                                     \\\n    lh = __float2half(lhv);                                                      \\\n    return lh;                                                                   \\\n  
}\n\nDEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_ASSIGNMENT_OPERATOR(+=)\nDEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_ASSIGNMENT_OPERATOR(-=)\nDEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_ASSIGNMENT_OPERATOR(*=)\nDEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_ASSIGNMENT_OPERATOR(/=)\n\n#undef DEFINE_CUDA_PSEUDO_HALF_ARITHMETIC_ASSIGNMENT_OPERATOR\n\n__device__ __forceinline__ __half& operator++(__half& h) {\n  h = __float2half(__half2float(h) + 1);\n  return h;\n}\n\n__device__ __forceinline__ __half& operator--(__half& h) {\n  h = __float2half(__half2float(h) - 1);\n  return h;\n}\n\n__device__ __forceinline__ __half operator++(__half& h, int) {\n  __half ret = h;\n  h = __float2half(__half2float(h) + 1);\n  return ret;\n}\n\n__device__ __forceinline__ __half operator--(__half& h, int) {\n  __half ret = h;\n  h = __float2half(__half2float(h) - 1);\n  return ret;\n}\n\n__device__ __forceinline__ __half operator+(const __half& h) { return h; }\n\n__device__ __forceinline__ __half operator-(const __half& h) {\n  return __float2half(-__half2float(h));\n}\n\n__device__ __forceinline__ __half __hneg(const __half a) { return -a; }\n\n#define DEFINE_CUDA_PSEUDO_HALF_COMPARISON_BINARY_OPERATOR(op)                      \\\n  __device__ __forceinline__ bool operator op(const __half& lh, const __half& rh) { \\\n    return __half2float(lh) op __half2float(rh);                                    \\\n  }\n\nDEFINE_CUDA_PSEUDO_HALF_COMPARISON_BINARY_OPERATOR(==)\nDEFINE_CUDA_PSEUDO_HALF_COMPARISON_BINARY_OPERATOR(!=)\nDEFINE_CUDA_PSEUDO_HALF_COMPARISON_BINARY_OPERATOR(>)\nDEFINE_CUDA_PSEUDO_HALF_COMPARISON_BINARY_OPERATOR(<)\nDEFINE_CUDA_PSEUDO_HALF_COMPARISON_BINARY_OPERATOR(>=)\nDEFINE_CUDA_PSEUDO_HALF_COMPARISON_BINARY_OPERATOR(<=)\n\n#undef DEFINE_CUDA_PSEUDO_HALF_COMPARISON_BINARY_OPERATOR\n\n__device__ __forceinline__ bool __heq(const __half a, const __half b) { return a == b; }\n__device__ __forceinline__ bool __hge(const __half a, const __half b) { return a >= b; }\n__device__ __forceinline__ bool __hgt(const __half a, const __half b) { return a > b; }\n__device__ __forceinline__ bool __hle(const __half a, const __half b) { return a <= b; }\n__device__ __forceinline__ bool __hlt(const __half a, const __half b) { return a < b; }\n__device__ __forceinline__ bool __hne(const __half a, const __half b) { return a != b; }\n__device__ __forceinline__ __half __hmax(const __half a, const __half b) { return a > b ? a : b; }\n__device__ __forceinline__ __half __hmin(const __half a, const __half b) { return a < b ? a : b; }\n\n#define DEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(func)               \\\n  __device__ __forceinline__ __half h##func(const __half h) { \\\n    return __float2half(func##f(__half2float(h)));            \\\n  }\n\nDEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(cos)\nDEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(exp)\nDEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(exp10)\nDEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(exp2)\nDEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(log)\nDEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(log10)\nDEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(log2)\n\n__device__ __forceinline__ __half hrcp(const __half h) {\n  return __float2half(1.0f / __half2float(h));\n}\n\nDEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(rsqrt)\nDEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(sin)\nDEFINE_CUDA_PSEUDO_HALF_MATH_FUNC(sqrt)\n\n#undef DEFINE_CUDA_PSEUDO_HALF_MATH_FUNC\n\n#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_CORE_DEVICE_CUDA_PSEUDO_HALF_H_\n"
  },
  {
    "path": "oneflow/core/device/cuda_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <mutex>\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/hardware/node_device_descriptor_manager.h\"\n#include \"oneflow/core/hardware/cuda_device_descriptor.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/job/env_global_objects_scope.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/platform/include/pthread_fork.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n\n#ifdef WITH_CUDA\n\n#include <cuda.h>\n\n#endif  // WITH_CUDA\n\nnamespace oneflow {\n\n#ifdef WITH_CUDA\n\nconst char* CublasGetErrorString(cublasStatus_t error) {\n  switch (error) {\n    case CUBLAS_STATUS_SUCCESS: return \"CUBLAS_STATUS_SUCCESS\";\n    case CUBLAS_STATUS_NOT_INITIALIZED: return \"CUBLAS_STATUS_NOT_INITIALIZED\";\n    case CUBLAS_STATUS_ALLOC_FAILED: return \"CUBLAS_STATUS_ALLOC_FAILED\";\n    case CUBLAS_STATUS_INVALID_VALUE: return \"CUBLAS_STATUS_INVALID_VALUE\";\n    case CUBLAS_STATUS_ARCH_MISMATCH: return \"CUBLAS_STATUS_ARCH_MISMATCH\";\n    case CUBLAS_STATUS_MAPPING_ERROR: return \"CUBLAS_STATUS_MAPPING_ERROR\";\n    case CUBLAS_STATUS_EXECUTION_FAILED: return \"CUBLAS_STATUS_EXECUTION_FAILED\";\n    case CUBLAS_STATUS_INTERNAL_ERROR: return \"CUBLAS_STATUS_INTERNAL_ERROR\";\n#if CUDA_VERSION >= 6000\n    case CUBLAS_STATUS_NOT_SUPPORTED: return \"CUBLAS_STATUS_NOT_SUPPORTED\";\n#endif\n#if CUDA_VERSION >= 6050\n    case CUBLAS_STATUS_LICENSE_ERROR: return \"CUBLAS_STATUS_LICENSE_ERROR\";\n#endif\n    default: return \"Unknown cublas status\";\n  }\n}\n\nconst char* CurandGetErrorString(curandStatus_t error) {\n  switch (error) {\n    case CURAND_STATUS_SUCCESS: return \"CURAND_STATUS_SUCCESS\";\n    case CURAND_STATUS_VERSION_MISMATCH: return \"CURAND_STATUS_VERSION_MISMATCH\";\n    case CURAND_STATUS_NOT_INITIALIZED: return \"CURAND_STATUS_NOT_INITIALIZED\";\n    case CURAND_STATUS_ALLOCATION_FAILED: return \"CURAND_STATUS_ALLOCATION_FAILED\";\n    case CURAND_STATUS_TYPE_ERROR: return \"CURAND_STATUS_TYPE_ERROR\";\n    case CURAND_STATUS_OUT_OF_RANGE: return \"CURAND_STATUS_OUT_OF_RANGE\";\n    case CURAND_STATUS_LENGTH_NOT_MULTIPLE: return \"CURAND_STATUS_LENGTH_NOT_MULTIPLE\";\n    case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: return \"CURAND_STATUS_DOUBLE_PRECISION_REQUIRED\";\n    case CURAND_STATUS_LAUNCH_FAILURE: return \"CURAND_STATUS_LAUNCH_FAILURE\";\n    case CURAND_STATUS_PREEXISTING_FAILURE: return \"CURAND_STATUS_PREEXISTING_FAILURE\";\n    case CURAND_STATUS_INITIALIZATION_FAILED: return \"CURAND_STATUS_INITIALIZATION_FAILED\";\n    case CURAND_STATUS_ARCH_MISMATCH: return \"CURAND_STATUS_ARCH_MISMATCH\";\n    case CURAND_STATUS_INTERNAL_ERROR: return \"CURAND_STATUS_INTERNAL_ERROR\";\n    default: return \"Unknown curand status\";\n  }\n}\n\nconst char* CuFFTGetErrorString(cufftResult_t 
error) {\n  switch (error) {\n    case CUFFT_SUCCESS: return \"CUFFT_SUCCESS\";\n    case CUFFT_INVALID_PLAN: return \"CUFFT_INVALID_PLAN\";\n    case CUFFT_ALLOC_FAILED: return \"CUFFT_ALLOC_FAILED\";\n    case CUFFT_INVALID_TYPE: return \"CUFFT_INVALID_TYPE\";\n    case CUFFT_INVALID_VALUE: return \"CUFFT_INVALID_VALUE\";\n    case CUFFT_INTERNAL_ERROR: return \"CUFFT_INTERNAL_ERROR\";\n    case CUFFT_EXEC_FAILED: return \"CUFFT_EXEC_FAILED\";\n    case CUFFT_SETUP_FAILED: return \"CUFFT_SETUP_FAILED\";\n    case CUFFT_INVALID_SIZE: return \"CUFFT_INVALID_SIZE\";\n    case CUFFT_UNALIGNED_DATA: return \"CUFFT_UNALIGNED_DATA\";\n    case CUFFT_INCOMPLETE_PARAMETER_LIST: return \"CUFFT_INCOMPLETE_PARAMETER_LIST\";\n    case CUFFT_INVALID_DEVICE: return \"CUFFT_INVALID_DEVICE\";\n    case CUFFT_PARSE_ERROR: return \"CUFFT_PARSE_ERROR\";\n    case CUFFT_NO_WORKSPACE: return \"CUFFT_NO_WORKSPACE\";\n    case CUFFT_NOT_IMPLEMENTED: return \"CUFFT_NOT_IMPLEMENTED\";\n    case CUFFT_NOT_SUPPORTED: return \"CUFFT_NOT_SUPPORTED\";\n    default: return \"Unknown cufft status\";\n  }\n}\n\n#if CUDA_VERSION >= 11000\nconst char* CusovlerGetErrorString(cusolverStatus_t error) {\n  switch (error) {\n    case CUSOLVER_STATUS_SUCCESS: return \"CUSOLVER_STATUS_SUCCESS\";\n    case CUSOLVER_STATUS_NOT_INITIALIZED: return \"CUSOLVER_STATUS_NOT_INITIALIZED\";\n    case CUSOLVER_STATUS_ALLOC_FAILED: return \"CUSOLVER_STATUS_ALLOC_FAILED\";\n    case CUSOLVER_STATUS_INVALID_VALUE: return \"CUSOLVER_STATUS_INVALID_VALUE\";\n    case CUSOLVER_STATUS_ARCH_MISMATCH: return \"CUSOLVER_STATUS_ARCH_MISMATCH\";\n    case CUSOLVER_STATUS_EXECUTION_FAILED: return \"CUSOLVER_STATUS_EXECUTION_FAILED\";\n    case CUSOLVER_STATUS_INTERNAL_ERROR: return \"CUSOLVER_STATUS_INTERNAL_ERROR\";\n    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:\n      return \"CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED\";\n    default: return \"Unknown cusolver status\";\n  }\n}\n#endif\n\n#if CUDA_VERSION >= 10020\n\nconst char* NvjpegGetErrorString(nvjpegStatus_t error) {\n  switch (error) {\n    case NVJPEG_STATUS_SUCCESS: return \"NVJPEG_STATUS_SUCCESS\";\n    case NVJPEG_STATUS_NOT_INITIALIZED: return \"NVJPEG_STATUS_NOT_INITIALIZED\";\n    case NVJPEG_STATUS_INVALID_PARAMETER: return \"NVJPEG_STATUS_INVALID_PARAMETER\";\n    case NVJPEG_STATUS_BAD_JPEG: return \"NVJPEG_STATUS_BAD_JPEG\";\n    case NVJPEG_STATUS_JPEG_NOT_SUPPORTED: return \"NVJPEG_STATUS_JPEG_NOT_SUPPORTED\";\n    case NVJPEG_STATUS_ALLOCATOR_FAILURE: return \"NVJPEG_STATUS_ALLOCATOR_FAILURE\";\n    case NVJPEG_STATUS_EXECUTION_FAILED: return \"NVJPEG_STATUS_EXECUTION_FAILED\";\n    case NVJPEG_STATUS_ARCH_MISMATCH: return \"NVJPEG_STATUS_ARCH_MISMATCH\";\n    case NVJPEG_STATUS_INTERNAL_ERROR: return \"NVJPEG_STATUS_INTERNAL_ERROR\";\n    case NVJPEG_STATUS_IMPLEMENTATION_NOT_SUPPORTED:\n      return \"NVJPEG_STATUS_IMPLEMENTATION_NOT_SUPPORTED\";\n    default: return \"Unknown nvjpeg status\";\n  }\n}\n\n#endif\n\nsize_t GetAvailableGpuMemSize(int dev_id) {\n  cudaDeviceProp prop{};\n  cudaGetDeviceProperties(&prop, dev_id);\n  return prop.totalGlobalMem;\n}\n\nnamespace {\n\nstd::function<cudaError_t(void**, size_t)> GetCudaMallocHostFn(int32_t dev) {\n  auto default_fn = [](void** ptr, size_t size) { return cudaMallocHost(ptr, size); };\n  auto manager = Singleton<hardware::NodeDeviceDescriptorManager>::Get();\n  if (manager == nullptr) { return default_fn; }\n  auto node_desc = manager->GetLocalNodeDeviceDescriptor();\n  auto cuda_device = 
std::dynamic_pointer_cast<const hardware::CudaDeviceDescriptor>(\n      node_desc->GetDevice(hardware::kCudaDeviceDescriptorClassName, dev));\n  if (!cuda_device) { return default_fn; }\n  auto saved_affinity = node_desc->Topology()->GetMemoryAffinity();\n  if (!saved_affinity) { return default_fn; }\n  auto device_affinity =\n      node_desc->Topology()->GetMemoryAffinityByPCIBusID(cuda_device->PCIBusID());\n  if (!device_affinity) { return default_fn; }\n  return [device_affinity, saved_affinity, node_desc, default_fn](void** ptr, size_t size) {\n    node_desc->Topology()->SetMemoryAffinity(device_affinity);\n    cudaError_t err = default_fn(ptr, size);\n    node_desc->Topology()->SetMemoryAffinity(saved_affinity);\n    return err;\n  };\n}\n\n}  // namespace\n\ncudaError_t NumaAwareCudaMallocHost(int32_t dev, void** ptr, size_t size) {\n  auto fn = GetCudaMallocHostFn(dev);\n  return fn(ptr, size);\n}\n\nCudaCurrentDeviceGuard::CudaCurrentDeviceGuard(int32_t dev_id) {\n  CHECK(!pthread_fork::IsForkedSubProcess()) << pthread_fork::kOfCudaNotSupportInForkedSubProcess;\n  OF_CUDA_CHECK(cudaGetDevice(&saved_dev_id_));\n  OF_CUDA_CHECK(cudaSetDevice(dev_id));\n}\n\nCudaCurrentDeviceGuard::CudaCurrentDeviceGuard() { OF_CUDA_CHECK(cudaGetDevice(&saved_dev_id_)); }\n\nCudaCurrentDeviceGuard::~CudaCurrentDeviceGuard() { OF_CUDA_CHECK(cudaSetDevice(saved_dev_id_)); }\n\nCublasMathModeGuard::CublasMathModeGuard(cublasHandle_t handle, cublasMath_t new_mode)\n    : CublasMathModeGuard(handle) {\n  SetMathMode(new_mode);\n}\n\nCublasMathModeGuard::CublasMathModeGuard(cublasHandle_t handle) : handle_(handle) {\n  OF_CUBLAS_CHECK(cublasGetMathMode(handle_, &saved_mode_));\n  new_mode_ = saved_mode_;\n}\n\nCublasMathModeGuard::~CublasMathModeGuard() {\n  if (new_mode_ != saved_mode_) { OF_CUBLAS_CHECK(cublasSetMathMode(handle_, saved_mode_)); }\n}\n\nvoid CublasMathModeGuard::SetMathMode(cublasMath_t new_mode) {\n  new_mode_ = new_mode;\n  if (new_mode_ != saved_mode_) { OF_CUBLAS_CHECK(cublasSetMathMode(handle_, new_mode_)); }\n}\n\nvoid CudaSynchronize(int device_id) {\n  CudaCurrentDeviceGuard dev_guard(device_id);\n  OF_CUDA_CHECK(cudaDeviceSynchronize());\n}\n\nvoid SetCudaDeviceIndex(int device_id) { OF_CUDA_CHECK(cudaSetDevice(device_id)); }\n\nint GetCudaDeviceIndex() { return GlobalProcessCtx::LocalRank(); }\n\nint GetCudaDeviceCount() {\n  /* static */ int cuda_device_count = 0;\n  OF_CUDA_CHECK(cudaGetDeviceCount(&cuda_device_count));\n  return cuda_device_count;\n}\n\n// NOTE(lixiang): Get the memory of the current device.\nMaybe<double> GetCUDAMemoryUsed() {\n  JUST(vm::CurrentRankSync());\n\n  int deviceCount = 0;\n  cudaError_t error_id = cudaGetDeviceCount(&deviceCount);\n  if (error_id != cudaSuccess) {\n    return Error::RuntimeError() << \"Error: GetCUDAMemoryUsed fails :\"\n                                 << cudaGetErrorString(error_id);\n  }\n\n  CHECK_OR_RETURN(deviceCount > 0) << \"GPU device does not exist\";\n\n  size_t gpu_total_size;\n  size_t gpu_free_size;\n\n  cudaError_t cuda_status = cudaMemGetInfo(&gpu_free_size, &gpu_total_size);\n\n  CHECK_OR_RETURN(cudaSuccess == cuda_status)\n      << \"Error: GetCUDAMemoryUsed fails :\" << cudaGetErrorString(cuda_status);\n\n  double total_memory = double(gpu_total_size) / (1024.0 * 1024.0);\n  double free_memory = double(gpu_free_size) / (1024.0 * 1024.0);\n  return (total_memory - free_memory);\n}\n\nstatic std::once_flag prop_init_flag;\nstatic std::vector<cudaDeviceProp> device_props;\n\nvoid InitDevicePropVectorSize() {\n  int 
device_count = GetCudaDeviceCount();\n  device_props.resize(device_count);\n}\n\nvoid InitDeviceProperties(int device_id) {\n  std::call_once(prop_init_flag, InitDevicePropVectorSize);\n  cudaDeviceProp prop{};\n  OF_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id));\n  device_props[device_id] = prop;\n}\n\ncudaDeviceProp* GetDeviceProperties(int device_id) {\n  InitCudaContextOnce(device_id);\n  return &device_props[device_id];\n}\n\nvoid InitCudaContextOnce(int device_id) {\n  static int device_count = GetCudaDeviceCount();\n  static std::vector<std::once_flag> init_flags = std::vector<std::once_flag>(device_count);\n  if (LazyMode::is_enabled()) { return; }\n  if (device_id == -1) { device_id = GetCudaDeviceIndex(); }\n  std::call_once(init_flags[device_id], [&]() {\n    OF_CUDA_CHECK(cudaSetDevice(device_id));\n    OF_CUDA_CHECK(cudaDeviceSynchronize());\n    InitDeviceProperties(device_id);\n  });\n}\n\ncudaError_t CudaDriverGetPrimaryCtxActive(int dev, int* active) {\n#if CUDA_VERSION >= 11030\n  CUdevice cu_device{};\n  {\n    CUresult (*fnCuDeviceGet)(CUdevice*, int) = nullptr;\n    cudaError_t err =\n        cudaGetDriverEntryPoint(\"cuDeviceGet\", (void**)&fnCuDeviceGet, cudaEnableDefault);\n    if (err != cudaSuccess) { return err; }\n    CUresult result = fnCuDeviceGet(&cu_device, dev);\n    if (result == CUDA_SUCCESS) {\n      // do nothing\n    } else if (result == CUresult::CUDA_ERROR_INVALID_DEVICE) {\n      return cudaErrorInvalidDevice;\n    } else {\n      return cudaErrorUnknown;\n    }\n  }\n  {\n    CUresult (*fnCuDevicePrimaryCtxGetState)(CUdevice, unsigned int*, int*) = nullptr;\n    cudaError_t err = cudaGetDriverEntryPoint(\n        \"cuDevicePrimaryCtxGetState\", (void**)&fnCuDevicePrimaryCtxGetState, cudaEnableDefault);\n    if (err != cudaSuccess) { return err; }\n    unsigned int flags{};\n    CUresult result = fnCuDevicePrimaryCtxGetState(cu_device, &flags, active);\n    if (result == CUDA_SUCCESS) {\n      return cudaSuccess;\n    } else {\n      return cudaErrorUnknown;\n    }\n  }\n#else\n  return cudaErrorNotSupported;\n#endif  // CUDA_VERSION < 11030\n}\n\n#endif  // WITH_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/device/cuda_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_DEVICE_CUDA_UTIL_H_\n#define ONEFLOW_CORE_DEVICE_CUDA_UTIL_H_\n\n#include \"oneflow/core/common/data_type.h\"\n\n#ifdef WITH_CUDA\n\n#include <cublas_v2.h>\n#if CUDA_VERSION >= 11000\n#include <cusolverDn.h>\n#endif\n#include <cuda.h>\n#if CUDA_VERSION >= 10010\n#include <cublasLt.h>\n#endif\n#include <cuda_runtime.h>\n#include <cudnn.h>\n#include <curand.h>\n#include <cufft.h>\n#include <nccl.h>\n#include <cuda_fp16.h>\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\n#include \"oneflow/core/device/cuda_pseudo_half.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\n#if CUDA_VERSION >= 10020\n\n#include <nvjpeg.h>\n\n#endif\n\nnamespace oneflow {\n\nconst char* CublasGetErrorString(cublasStatus_t error);\n\nconst char* CurandGetErrorString(curandStatus_t error);\n\nconst char* CuFFTGetErrorString(cufftResult_t error);\n\n#if CUDA_VERSION >= 11000\nconst char* CusovlerGetErrorString(cusolverStatus_t error);\n#endif\n\n#if CUDA_VERSION >= 10020\n\nconst char* NvjpegGetErrorString(nvjpegStatus_t error);\n\n#endif\n\n#define OF_CUDA_CHECK(condition)                                                               \\\n  for (cudaError_t _of_cuda_check_status = (condition); _of_cuda_check_status != cudaSuccess;) \\\n  LOG(FATAL) << \"Check failed: \" #condition \" : \" << cudaGetErrorString(_of_cuda_check_status) \\\n             << \" (\" << _of_cuda_check_status << \") \"\n\n#define OF_CUDNN_CHECK(condition)                                                                \\\n  for (cudnnStatus_t _of_cudnn_check_status = (condition);                                       \\\n       _of_cudnn_check_status != CUDNN_STATUS_SUCCESS;)                                          \\\n  LOG(FATAL) << \"Check failed: \" #condition \" : \" << cudnnGetErrorString(_of_cudnn_check_status) \\\n             << \" (\" << _of_cudnn_check_status << \") \"\n\n#define OF_CUBLAS_CHECK(condition)                                                                 \\\n  for (cublasStatus_t _of_cublas_check_status = (condition);                                       \\\n       _of_cublas_check_status != CUBLAS_STATUS_SUCCESS;)                                          \\\n  LOG(FATAL) << \"Check failed: \" #condition \" : \" << CublasGetErrorString(_of_cublas_check_status) \\\n             << \" (\" << _of_cublas_check_status << \") \"\n\n#define OF_CUFFT_CHECK(condition)                                                                \\\n  for (cufftResult_t _of_cufft_check_status = (condition);                                       \\\n       _of_cufft_check_status != CUFFT_SUCCESS;)                                                 \\\n  LOG(FATAL) << \"Check failed: \" #condition \" : \" << CuFFTGetErrorString(_of_cufft_check_status) \\\n             << \" (\" << _of_cufft_check_status << \") \"\n\n#if CUDA_VERSION >= 11000\n#define 
OF_CUSOLVER_CHECK(condition)                                        \\\n  for (cusolverStatus_t _of_cusolver_check_status = (condition);            \\\n       _of_cusolver_check_status != CUSOLVER_STATUS_SUCCESS;)               \\\n    LOG(FATAL) << \"Check failed: \" #condition \" : \"                         \\\n               << CusovlerGetErrorString(_of_cusolver_check_status) << \" (\" \\\n               << _of_cusolver_check_status << \") \";\n#endif\n\n#define OF_CURAND_CHECK(condition)                                                                 \\\n  for (curandStatus_t _of_curand_check_status = (condition);                                       \\\n       _of_curand_check_status != CURAND_STATUS_SUCCESS;)                                          \\\n  LOG(FATAL) << \"Check failed: \" #condition \" : \" << CurandGetErrorString(_of_curand_check_status) \\\n             << \" (\" << _of_curand_check_status << \") \"\n\n#define OF_NCCL_CHECK(condition)                                                                \\\n  for (ncclResult_t _of_nccl_check_status = (condition); _of_nccl_check_status != ncclSuccess;) \\\n  LOG(FATAL) << \"Check failed: \" #condition \" : \" << ncclGetErrorString(_of_nccl_check_status)  \\\n             << \" (\" << _of_nccl_check_status << \"). \"                                          \\\n             << \"To see more detail, please run OneFlow with system variable NCCL_DEBUG=INFO\"\n\n#define OF_NCCL_CHECK_OR_RETURN(condition)                                                         \\\n  for (ncclResult_t _of_nccl_check_status = (condition); _of_nccl_check_status != ncclSuccess;)    \\\n  return Error::CheckFailedError().AddStackFrame([](const char* function) {                        \\\n    thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function));      \\\n    return frame;                                                                                  \\\n  }(__FUNCTION__))                                                                                 \\\n         << \"Check failed: \" #condition \" : \" << ncclGetErrorString(_of_nccl_check_status) << \" (\" \\\n         << _of_nccl_check_status << \") \"\n\n#if CUDA_VERSION >= 10020\n\n#define OF_NVJPEG_CHECK(condition)                                                                 \\\n  for (nvjpegStatus_t _of_nvjpeg_check_status = (condition);                                       \\\n       _of_nvjpeg_check_status != NVJPEG_STATUS_SUCCESS;)                                          \\\n  LOG(FATAL) << \"Check failed: \" #condition \" : \" << NvjpegGetErrorString(_of_nvjpeg_check_status) \\\n             << \" (\" << _of_nvjpeg_check_status << \") \"\n\n#endif\n\n// CUDA: grid stride looping\n#define CUDA_1D_KERNEL_LOOP(i, n)                                                                 \\\n  for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x, step = blockDim.x * gridDim.x; i < (n); \\\n       i += step)\n\n#define CUDA_1D_KERNEL_LOOP_T(type, i, n)                                                      \\\n  for (type i = blockIdx.x * blockDim.x + threadIdx.x, step = blockDim.x * gridDim.x; i < (n); \\\n       i += step)\n\nconst int32_t kCudaThreadsNumPerBlock = 512;\nconst int32_t kCudaMaxBlocksNum = 8192;\nconst int32_t kCudaWarpSize = 32;\n\n// 48KB, max byte size of shared memory per thread block\n// TODO: limit of shared memory should be different for different arch\nconst int32_t kCudaMaxSharedMemoryByteSize = 48 << 10;\n\ninline int64_t 
BlocksNum4ThreadsNum(const int64_t n) {\n  CHECK_GT(n, 0);\n  return std::min((n + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock,\n                  static_cast<int64_t>(kCudaMaxBlocksNum));\n}\n\n#define RUN_CUDA_KERNEL(func, stream, elem_cnt, ...) \\\n  stream->As<ep::CudaStream>()->LaunchKernel(func, elem_cnt, 1, __VA_ARGS__)\n\nsize_t GetAvailableGpuMemSize(int dev_id);\n\ncudaError_t NumaAwareCudaMallocHost(int32_t dev, void** ptr, size_t size);\n\nclass CudaCurrentDeviceGuard final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaCurrentDeviceGuard);\n  explicit CudaCurrentDeviceGuard(int32_t dev_id);\n  CudaCurrentDeviceGuard();\n  ~CudaCurrentDeviceGuard();\n\n private:\n  int32_t saved_dev_id_ = -1;\n};\n\nclass CublasMathModeGuard final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CublasMathModeGuard);\n  CublasMathModeGuard(cublasHandle_t handle, cublasMath_t new_mode);\n  explicit CublasMathModeGuard(cublasHandle_t handle);\n  ~CublasMathModeGuard();\n\n  void SetMathMode(cublasMath_t new_mode);\n\n private:\n  cublasHandle_t handle_{};\n  cublasMath_t saved_mode_{};\n  cublasMath_t new_mode_{};\n};\n\nint GetCudaDeviceIndex();\n\nint GetCudaDeviceCount();\n\nMaybe<double> GetCUDAMemoryUsed();\n\ncudaDeviceProp* GetDeviceProperties(int device_id);\n\nvoid SetCudaDeviceIndex(int device_id);\n\nvoid CudaSynchronize(int device_id);\n\nvoid InitCudaContextOnce(int device_id);\n\ncudaError_t CudaDriverGetPrimaryCtxActive(int dev, int* active);\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_CORE_DEVICE_CUDA_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/device/cudnn_conv_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/core/device/cudnn_conv_util.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/common/cached_caller.h\"\n#include \"oneflow/core/operator/operator_util.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename algo_t>\nalgo_t GetDefaultAlgo();\n\ntemplate<>\ncudnnConvolutionFwdAlgo_t GetDefaultAlgo<cudnnConvolutionFwdAlgo_t>() {\n  return CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;\n}\n\ntemplate<>\ncudnnConvolutionBwdDataAlgo_t GetDefaultAlgo<cudnnConvolutionBwdDataAlgo_t>() {\n  return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;\n}\n\ntemplate<>\ncudnnConvolutionBwdFilterAlgo_t GetDefaultAlgo<cudnnConvolutionBwdFilterAlgo_t>() {\n  return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;\n}\n\nsize_t ByteSize4Tensor(const int* dims, int ndim, cudnnDataType_t data_type) {\n  size_t byte_size = GetCudnnDataTypeByteSize(data_type);\n  FOR_RANGE(int, i, 0, ndim) { byte_size *= dims[i]; }\n  return byte_size;\n}\n\ntemplate<typename perf_t, typename algo_t>\nvoid SetAlgo4Perf(const CudnnConvArgs& args, CudnnConvResource* res, perf_t* algo_perf,\n                  algo_t algo) {\n  algo_perf->algo = algo;\n  if (args.params.data_type == CUDNN_DATA_HALF) {\n    algo_perf->mathType = CUDNN_TENSOR_OP_MATH;\n  } else {\n    algo_perf->mathType = CUDNN_DEFAULT_MATH;\n  }\n  OF_CUDNN_CHECK(GetCudnnConvWorkspaceSize(args, res, algo_perf->algo, &(algo_perf->memory)));\n  algo_perf->status = CUDNN_STATUS_SUCCESS;\n}\n\ntemplate<typename perf_t>\nperf_t GetBestAlgorithm(const CudnnConvArgs& args, CudnnConvResource* res,\n                        const std::vector<perf_t>& perf_vec) {\n  using algo_t = decltype(std::declval<perf_t>().algo);\n  if (perf_vec.size() == 0) {\n    LOG(WARNING) << \"There is no result with \"\n                 << (args.heuristic ? 
\"heuristic searching way.\" : \"exhaustive searching way.\")\n                 << \" (max_workspace_size=\" << args.params.max_ws_size << \")\"\n                 << \" Use default algo(\" << GetDefaultAlgo<algo_t>() << \") instead.\";\n    perf_t perf;\n    SetAlgo4Perf(args, res, &perf, GetDefaultAlgo<algo_t>());\n    return perf;\n  }\n\n  int found_algo_idx = -1;\n  FOR_RANGE(size_t, i, 0, perf_vec.size()) {\n    // Note: Shouldn't all returned results be successful?\n    CHECK_EQ(perf_vec[i].status, CUDNN_STATUS_SUCCESS);\n    if (perf_vec[i].memory > args.params.max_ws_size) { continue; }\n    if (args.deterministic && perf_vec[i].determinism == CUDNN_NON_DETERMINISTIC) { continue; }\n    found_algo_idx = i;\n    break;\n  }\n\n  if (found_algo_idx == -1) {\n    LOG(WARNING) << \"Cannot find any algorithm meets requirements (max_workspace_size=\"\n                 << args.params.max_ws_size << \", determinism=\" << args.deterministic << \") using \"\n                 << (args.heuristic ? \"heuristic searching way.\" : \"exhaustive searching way.\")\n                 << \" Using default algo(\" << GetDefaultAlgo<algo_t>() << \") instead.\";\n    perf_t algo_perf;\n    SetAlgo4Perf(args, res, &algo_perf, GetDefaultAlgo<algo_t>());\n    return algo_perf;\n  }\n\n  if (found_algo_idx != 0) {\n    LOG(WARNING) << \"Currently available alogrithm (algo=\" << perf_vec[found_algo_idx].algo\n                 << \", require memory=\" << perf_vec[found_algo_idx].memory\n                 << \", idx=\" << found_algo_idx\n                 << \") meeting requirments (max_workspace_size=\" << args.params.max_ws_size\n                 << \", determinism=\" << args.deterministic\n                 << \") is not fastest. Fastest algorithm (\" << perf_vec[0].algo\n                 << \") requires memory \" << perf_vec[0].memory;\n  }\n\n#if CUDNN_VERSION < 7500\n  // google [blacklist fft algorithms for strided dgrad]\n  if (std::is_same<decltype(perf_vec[found_algo_idx].algo), cudnnConvolutionBwdDataAlgo_t>::value) {\n    int stride_dim = args.params.x_ndim - 2;\n    bool blacklist =\n        std::any_of(std::begin(args.params.stride), std::begin(args.params.stride) + stride_dim,\n                    [](int n) { return n != 1; });\n    if (blacklist\n        && (static_cast<cudnnConvolutionBwdDataAlgo_t>(perf_vec[found_algo_idx].algo)\n                == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING\n            || static_cast<cudnnConvolutionBwdDataAlgo_t>(perf_vec[found_algo_idx].algo)\n                   == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {\n      perf_t algo_perf;\n      SetAlgo4Perf(args, res, &algo_perf, GetDefaultAlgo<algo_t>());\n      return algo_perf;\n    }\n  }\n#endif\n\n  return perf_vec.at(found_algo_idx);\n}\n\ntemplate<typename perf_t>\nperf_t CudnnConvAlgoGetOrInfer(const CudnnConvParams& params,\n                               const std::function<perf_t(const CudnnConvParams&)>& InferFn,\n                               CudnnConvAlgoCache::Store<perf_t>* store, std::mutex* mutex) {\n  const size_t cache_size =\n      Singleton<ResourceDesc, ForSession>::Get()->thread_local_cache_max_size();\n  auto InferWithCache = [&](const CudnnConvParams& p) -> perf_t {\n    CudnnConvParams params_without_ws = p;\n    params_without_ws.max_ws_size = 0;\n    std::unique_lock<std::mutex> lock(*mutex);\n    const auto& key_it = store->find(params_without_ws);\n    if (key_it != store->cend()) {\n      const auto& perf_it = std::find_if(\n          key_it->second.cbegin(), key_it->second.cend(),\n     
     [&](const std::pair<size_t, perf_t>& pair) {\n            // There might be a case that only memory size pair.second.memory was required for the\n            // best algorithm even though a workspace pair.first supplied\n            return pair.second.memory <= p.max_ws_size /* for memory safety */\n                   && pair.first >= p.max_ws_size /* a case with larger workspace infered before */;\n          });\n      if (perf_it != key_it->second.cend()) { return perf_it->second; }\n    }\n    perf_t perf = InferFn(p);\n    (*store)[params_without_ws].emplace_back(std::make_pair(p.max_ws_size, perf));\n    return perf;\n  };\n  return ThreadLocalCachedCall(cache_size, InferWithCache, params);\n}\n\n}  // namespace\n\ntemplate<>\ncudnnConvolutionFwdAlgoPerf_t CudnnConvAlgoCache::Remember(\n    const CudnnConvParams& params,\n    const std::function<cudnnConvolutionFwdAlgoPerf_t(const CudnnConvParams&)>& InferFn) {\n  return CudnnConvAlgoGetOrInfer<cudnnConvolutionFwdAlgoPerf_t>(params, InferFn, &fwd_algo_store_,\n                                                                &fwd_algo_store_mutex_);\n}\n\ntemplate<>\ncudnnConvolutionBwdDataAlgoPerf_t CudnnConvAlgoCache::Remember(\n    const CudnnConvParams& params,\n    const std::function<cudnnConvolutionBwdDataAlgoPerf_t(const CudnnConvParams&)>& InferFn) {\n  return CudnnConvAlgoGetOrInfer<cudnnConvolutionBwdDataAlgoPerf_t>(\n      params, InferFn, &bwd_data_algo_store_, &bwd_data_algo_store_mutex_);\n}\n\ntemplate<>\ncudnnConvolutionBwdFilterAlgoPerf_t CudnnConvAlgoCache::Remember(\n    const CudnnConvParams& params,\n    const std::function<cudnnConvolutionBwdFilterAlgoPerf_t(const CudnnConvParams&)>& InferFn) {\n  return CudnnConvAlgoGetOrInfer<cudnnConvolutionBwdFilterAlgoPerf_t>(\n      params, InferFn, &bwd_filter_algo_store_, &bwd_filter_algo_cache_mutex_);\n}\n\nCudnnConvDesc::~CudnnConvDesc() { OF_CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(val_)); }\n\nCudnnConvDesc::CudnnConvDesc(const DataType compute_type, const DataType data_type,\n                             const ShapeView& in_blob_shape, const user_op::InferContext& ctx) {\n  int32_t opkernel_dim = in_blob_shape.NumAxes() - 2;\n  OF_CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&val_));\n  const auto& padding_before = ctx.Attr<std::vector<int32_t>>(\"padding_before\");\n  const auto& strides = ctx.Attr<std::vector<int32_t>>(\"strides\");\n  const auto& dilation_rate = ctx.Attr<std::vector<int32_t>>(\"dilation_rate\");\n  if (opkernel_dim == 2) {\n    OF_CUDNN_CHECK(cudnnSetConvolution2dDescriptor(\n        val_, padding_before.at(0), padding_before.at(1), strides.at(0), strides.at(1),\n        dilation_rate.at(0), dilation_rate.at(1), CUDNN_CROSS_CORRELATION,\n        GetCudnnDataType(compute_type)));\n  } else if (opkernel_dim == 1) {\n    OF_CUDNN_CHECK(cudnnSetConvolution2dDescriptor(val_, padding_before.at(0), 0, strides.at(0), 1,\n                                                   dilation_rate.at(0), 1, CUDNN_CROSS_CORRELATION,\n                                                   GetCudnnDataType(compute_type)));\n  } else {\n    OF_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(\n        val_, opkernel_dim, padding_before.data(), strides.data(), dilation_rate.data(),\n        CUDNN_CROSS_CORRELATION, GetCudnnDataType(compute_type)));\n  }\n  const int32_t groups = ctx.Attr<int32_t>(\"groups\");\n  if (groups != 1) { OF_CUDNN_CHECK(cudnnSetConvolutionGroupCount(val_, groups)); }\n  bool use_tensor_op_math;\n  if (GetCudnnDataType(data_type) == 
CUDNN_DATA_HALF) {\n    use_tensor_op_math = true;\n#if CUDNN_VERSION >= 8100\n  } else if (GetCudnnDataType(data_type) == CUDNN_DATA_BFLOAT16) {\n    use_tensor_op_math = true;\n#endif\n  } else {\n    use_tensor_op_math = false;\n  }\n  if (use_tensor_op_math) {\n    OF_CUDNN_CHECK(cudnnSetConvolutionMathType(val_, CUDNN_TENSOR_OP_MATH));\n  }\n}\n\nCudnnConvDesc::CudnnConvDesc(const DataType compute_type, const DataType data_type,\n                             const ShapeView& in_blob_shape,\n                             const user_op::KernelComputeContext& ctx) {\n  int32_t opkernel_dim = in_blob_shape.NumAxes() - 2;\n  OF_CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&val_));\n  const auto& padding_before = ctx.Attr<std::vector<int32_t>>(\"padding_before\");\n  const auto& strides = ctx.Attr<std::vector<int32_t>>(\"strides\");\n  const auto& dilation_rate = ctx.Attr<std::vector<int32_t>>(\"dilation_rate\");\n  if (opkernel_dim == 2) {\n    OF_CUDNN_CHECK(cudnnSetConvolution2dDescriptor(\n        val_, padding_before.at(0), padding_before.at(1), strides.at(0), strides.at(1),\n        dilation_rate.at(0), dilation_rate.at(1), CUDNN_CROSS_CORRELATION,\n        GetCudnnDataType(compute_type)));\n  } else if (opkernel_dim == 1) {\n    OF_CUDNN_CHECK(cudnnSetConvolution2dDescriptor(val_, padding_before.at(0), 0, strides.at(0), 1,\n                                                   dilation_rate.at(0), 1, CUDNN_CROSS_CORRELATION,\n                                                   GetCudnnDataType(compute_type)));\n  } else {\n    OF_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(\n        val_, opkernel_dim, padding_before.data(), strides.data(), dilation_rate.data(),\n        CUDNN_CROSS_CORRELATION, GetCudnnDataType(compute_type)));\n  }\n  const int32_t groups = ctx.Attr<int32_t>(\"groups\");\n  if (groups != 1) { OF_CUDNN_CHECK(cudnnSetConvolutionGroupCount(val_, groups)); }\n  bool use_tensor_op_math;\n  if (GetCudnnDataType(data_type) == CUDNN_DATA_HALF) {\n    use_tensor_op_math = true;\n#if CUDNN_VERSION >= 8100\n  } else if (GetCudnnDataType(data_type) == CUDNN_DATA_BFLOAT16) {\n    use_tensor_op_math = true;\n#endif\n  } else {\n    use_tensor_op_math = false;\n  }\n  if (use_tensor_op_math) {\n    OF_CUDNN_CHECK(cudnnSetConvolutionMathType(val_, CUDNN_TENSOR_OP_MATH));\n  }\n}\n\nCudnnConvArgs::CudnnConvArgs(const user_op::InferContext& ctx, DataType x_data_type,\n                             const ShapeView& x_shape, DataType w_data_type,\n                             const ShapeView& w_shape, DataType y_data_type,\n                             const ShapeView& y_shape, const std::string& data_format,\n                             size_t max_workspace_size, bool heuristic_search,\n                             bool use_deterministic_algo_only, bool enable_pseudo_half)\n    : xdesc(x_data_type, x_shape, data_format),\n      ydesc(y_data_type, y_shape, data_format),\n      wdesc(w_data_type, w_shape, data_format),\n      cdesc(GetConvDescDataType(x_data_type, enable_pseudo_half), x_data_type, x_shape, ctx),\n      heuristic(heuristic_search),\n      deterministic(use_deterministic_algo_only) {\n  std::memset(&params, 0, sizeof(CudnnConvParams));\n  OF_CUDNN_CHECK(cudnnGetTensorNdDescriptor(xdesc.Get(), CudnnConvParams::kTensorMaxDims,\n                                            &params.x_data_type, &params.x_ndim, params.x_dims,\n                                            params.x_strides));\n  OF_CUDNN_CHECK(cudnnGetTensorNdDescriptor(ydesc.Get(), CudnnConvParams::kTensorMaxDims,\n  
                                          &params.y_data_type, &params.y_ndim, params.y_dims,\n                                            params.y_strides));\n  OF_CUDNN_CHECK(cudnnGetFilterNdDescriptor(wdesc.Get(), CudnnConvParams::kTensorMaxDims,\n                                            &params.w_data_type, &params.w_format, &params.w_ndim,\n                                            params.w_dims));\n  cudnnConvolutionMode_t mode;\n  int conv_dim_size = 0;\n  OF_CUDNN_CHECK(cudnnGetConvolutionNdDescriptor(cdesc.Get(), CudnnConvParams::kConvMaxDims,\n                                                 &conv_dim_size, params.padding, params.stride,\n                                                 params.dilation, &mode, &params.data_type));\n  CHECK_EQ(params.x_data_type, params.w_data_type);\n  CHECK_EQ(params.x_ndim, params.w_ndim);\n  CHECK_EQ(conv_dim_size + 2, params.x_ndim);\n  OF_CUDNN_CHECK(cudnnGetConvolutionGroupCount(cdesc.Get(), &params.groups));\n  params.max_ws_size = max_workspace_size;\n}\n\nCudnnConvArgs::CudnnConvArgs(const user_op::KernelComputeContext& ctx, DataType x_data_type,\n                             const ShapeView& x_shape, DataType w_data_type,\n                             const ShapeView& w_shape, DataType y_data_type,\n                             const ShapeView& y_shape, const std::string& data_format,\n                             size_t max_workspace_size, bool heuristic_search,\n                             bool use_deterministic_algo_only, bool enable_pseudo_half)\n    : xdesc(x_data_type, x_shape, data_format),\n      ydesc(y_data_type, y_shape, data_format),\n      wdesc(w_data_type, w_shape, data_format),\n      cdesc(GetConvDescDataType(x_data_type, enable_pseudo_half), x_data_type, x_shape, ctx),\n      heuristic(heuristic_search),\n      deterministic(use_deterministic_algo_only) {\n  std::memset(&params, 0, sizeof(CudnnConvParams));\n  OF_CUDNN_CHECK(cudnnGetTensorNdDescriptor(xdesc.Get(), CudnnConvParams::kTensorMaxDims,\n                                            &params.x_data_type, &params.x_ndim, params.x_dims,\n                                            params.x_strides));\n  OF_CUDNN_CHECK(cudnnGetTensorNdDescriptor(ydesc.Get(), CudnnConvParams::kTensorMaxDims,\n                                            &params.y_data_type, &params.y_ndim, params.y_dims,\n                                            params.y_strides));\n  OF_CUDNN_CHECK(cudnnGetFilterNdDescriptor(wdesc.Get(), CudnnConvParams::kTensorMaxDims,\n                                            &params.w_data_type, &params.w_format, &params.w_ndim,\n                                            params.w_dims));\n  cudnnConvolutionMode_t mode;\n  int conv_dim_size = 0;\n  OF_CUDNN_CHECK(cudnnGetConvolutionNdDescriptor(cdesc.Get(), CudnnConvParams::kConvMaxDims,\n                                                 &conv_dim_size, params.padding, params.stride,\n                                                 params.dilation, &mode, &params.data_type));\n  CHECK_EQ(params.x_data_type, params.w_data_type);\n  CHECK_EQ(params.x_ndim, params.w_ndim);\n  CHECK_EQ(conv_dim_size + 2, params.x_ndim);\n  OF_CUDNN_CHECK(cudnnGetConvolutionGroupCount(cdesc.Get(), &params.groups));\n  params.max_ws_size = max_workspace_size;\n}\n\nManagedCudnnConvResource::ManagedCudnnConvResource(const CudnnConvArgs& args)\n    : handle_(nullptr), x_dptr_(nullptr), w_dptr_(nullptr), y_dptr_(nullptr), ws_dptr_(nullptr) {\n  x_byte_size_ = ByteSize4Tensor(args.params.x_dims, args.params.x_ndim, 
args.params.x_data_type);\n  w_byte_size_ = ByteSize4Tensor(args.params.w_dims, args.params.w_ndim, args.params.w_data_type);\n  y_byte_size_ = ByteSize4Tensor(args.params.y_dims, args.params.y_ndim, args.params.y_data_type);\n  ws_byte_size_ = args.params.max_ws_size;\n}\n\nManagedCudnnConvResource::~ManagedCudnnConvResource() {\n  if (handle_ != nullptr) {\n    Singleton<CudnnHandlePool>::Get()->Put(handle_);\n    handle_ = nullptr;\n  }\n  if (x_dptr_ != nullptr) { OF_CUDA_CHECK(cudaFree(x_dptr_)); }\n  if (w_dptr_ != nullptr) { OF_CUDA_CHECK(cudaFree(w_dptr_)); }\n  if (y_dptr_ != nullptr) { OF_CUDA_CHECK(cudaFree(y_dptr_)); }\n  if (ws_dptr_ != nullptr) { OF_CUDA_CHECK(cudaFree(ws_dptr_)); }\n}\n\ncudnnHandle_t ManagedCudnnConvResource::cudnn_handle() {\n  if (handle_ == nullptr) { handle_ = Singleton<CudnnHandlePool>::Get()->Get(); }\n  return handle_;\n}\n\nvoid* ManagedCudnnConvResource::x_mut_dptr() {\n  if (x_dptr_ == nullptr) { OF_CUDA_CHECK(cudaMalloc(&x_dptr_, x_byte_size_)); }\n  return x_dptr_;\n}\n\nvoid* ManagedCudnnConvResource::w_mut_dptr() {\n  if (w_dptr_ == nullptr) { OF_CUDA_CHECK(cudaMalloc(&w_dptr_, w_byte_size_)); }\n  return w_dptr_;\n}\n\nvoid* ManagedCudnnConvResource::y_mut_dptr() {\n  if (y_dptr_ == nullptr) { OF_CUDA_CHECK(cudaMalloc(&y_dptr_, y_byte_size_)); }\n  return y_dptr_;\n}\n\nconst void* ManagedCudnnConvResource::x_const_dptr() const {\n  return const_cast<ManagedCudnnConvResource*>(this)->x_mut_dptr();\n}\n\nconst void* ManagedCudnnConvResource::w_const_dptr() const {\n  return const_cast<ManagedCudnnConvResource*>(this)->w_mut_dptr();\n}\n\nconst void* ManagedCudnnConvResource::y_const_dptr() const {\n  return const_cast<ManagedCudnnConvResource*>(this)->y_mut_dptr();\n}\n\nvoid* ManagedCudnnConvResource::ws_dptr() {\n  if (ws_dptr_ == nullptr) { OF_CUDA_CHECK(cudaMalloc(&ws_dptr_, ws_byte_size_)); }\n  return ws_dptr_;\n}\n\nbool operator==(const CudnnConvParams& a, const CudnnConvParams& b) {\n  auto ptr1 = reinterpret_cast<const uint8_t*>(&a);\n  auto ptr2 = reinterpret_cast<const uint8_t*>(&b);\n  return memcmp(ptr1, ptr2, sizeof(CudnnConvParams)) == 0;\n}\n\nDataType GetConvDescDataType(DataType data_type, bool pseudo_half) {\n  if (data_type == DataType::kFloat16 && pseudo_half) {\n    return DataType::kFloat;\n  } else if (data_type == DataType::kBFloat16) {\n    return DataType::kFloat;\n  }\n  return data_type;\n}\n\ncudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res,\n                                        cudnnConvolutionFwdAlgo_t algo, size_t* sz) {\n  return cudnnGetConvolutionForwardWorkspaceSize(res->cudnn_handle(), args.xdesc.Get(),\n                                                 args.wdesc.Get(), args.cdesc.Get(),\n                                                 args.ydesc.Get(), algo, sz);\n}\n\ncudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res,\n                                        cudnnConvolutionBwdDataAlgo_t algo, size_t* sz) {\n  return cudnnGetConvolutionBackwardDataWorkspaceSize(res->cudnn_handle(), args.wdesc.Get(),\n                                                      args.ydesc.Get(), args.cdesc.Get(),\n                                                      args.xdesc.Get(), algo, sz);\n}\n\ncudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res,\n                                        cudnnConvolutionBwdFilterAlgo_t algo, size_t* sz) {\n  return 
cudnnGetConvolutionBackwardFilterWorkspaceSize(res->cudnn_handle(), args.xdesc.Get(),\n                                                        args.ydesc.Get(), args.cdesc.Get(),\n                                                        args.wdesc.Get(), algo, sz);\n}\n\ntemplate<>\nstruct CudnnConvAlgorithmSearch<cudnnConvolutionFwdAlgoPerf_t> {\n  using perf_t = cudnnConvolutionFwdAlgoPerf_t;\n\n  static int GetAlgoMaxCount(CudnnConvResource* res) {\n    int max_algo_cnt = 0;\n    OF_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithmMaxCount(res->cudnn_handle(), &max_algo_cnt));\n    return max_algo_cnt;\n  }\n\n  static void HeuristicSearch(const CudnnConvArgs& args, CudnnConvResource* res,\n                              std::vector<perf_t>* perf_vec) {\n    int found_algo_cnt = 0;\n    perf_vec->resize(GetAlgoMaxCount(res));\n    OF_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(\n        res->cudnn_handle(), args.xdesc.Get(), args.wdesc.Get(), args.cdesc.Get(), args.ydesc.Get(),\n        perf_vec->size(), &found_algo_cnt, perf_vec->data()));\n    // vector::resize does not affect the first found_algo_cnt elements.\n    perf_vec->resize(found_algo_cnt);\n  }\n\n  static void ExhaustiveSearch(const CudnnConvArgs& args, CudnnConvResource* res,\n                               std::vector<perf_t>* perf_vec) {\n    int found_algo_cnt = 0;\n    perf_vec->resize(GetAlgoMaxCount(res));\n    OF_CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx(\n        res->cudnn_handle(), args.xdesc.Get(), res->x_const_dptr(), args.wdesc.Get(),\n        res->w_const_dptr(), args.cdesc.Get(), args.ydesc.Get(), res->y_mut_dptr(),\n        perf_vec->size(), &found_algo_cnt, perf_vec->data(), res->ws_dptr(),\n        args.params.max_ws_size));\n    // vector::resize does not affect the first found_algo_cnt elements.\n    perf_vec->resize(found_algo_cnt);\n  }\n};\n\ntemplate<>\nstruct CudnnConvAlgorithmSearch<cudnnConvolutionBwdDataAlgoPerf_t> {\n  using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;\n\n  static int GetAlgoMaxCount(CudnnConvResource* res) {\n    int max_algo_cnt = 0;\n    OF_CUDNN_CHECK(\n        cudnnGetConvolutionBackwardDataAlgorithmMaxCount(res->cudnn_handle(), &max_algo_cnt));\n    return max_algo_cnt;\n  }\n\n  static void HeuristicSearch(const CudnnConvArgs& args, CudnnConvResource* res,\n                              std::vector<perf_t>* perf_vec) {\n    int found_algo_cnt = 0;\n    perf_vec->resize(GetAlgoMaxCount(res));\n    OF_CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm_v7(\n        res->cudnn_handle(), args.wdesc.Get(), args.ydesc.Get(), args.cdesc.Get(), args.xdesc.Get(),\n        perf_vec->size(), &found_algo_cnt, perf_vec->data()));\n    // vector::resize does not affect the first found_algo_cnt elements.\n    perf_vec->resize(found_algo_cnt);\n  }\n\n  static void ExhaustiveSearch(const CudnnConvArgs& args, CudnnConvResource* res,\n                               std::vector<perf_t>* perf_vec) {\n    int found_algo_cnt = 0;\n    perf_vec->resize(GetAlgoMaxCount(res));\n    OF_CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx(\n        res->cudnn_handle(), args.wdesc.Get(), res->w_const_dptr(), args.ydesc.Get(),\n        res->y_const_dptr(), args.cdesc.Get(), args.xdesc.Get(), res->x_mut_dptr(),\n        perf_vec->size(), &found_algo_cnt, perf_vec->data(), res->ws_dptr(),\n        args.params.max_ws_size));\n    // vector::resize does not affect the first found_algo_cnt elements.\n    perf_vec->resize(found_algo_cnt);\n  }\n};\n\ntemplate<>\nstruct 
CudnnConvAlgorithmSearch<cudnnConvolutionBwdFilterAlgoPerf_t> {\n  using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;\n\n  static int GetAlgoMaxCount(CudnnConvResource* res) {\n    int max_algo_cnt = 0;\n    OF_CUDNN_CHECK(\n        cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(res->cudnn_handle(), &max_algo_cnt));\n    return max_algo_cnt;\n  }\n\n  static void HeuristicSearch(const CudnnConvArgs& args, CudnnConvResource* res,\n                              std::vector<perf_t>* perf_vec) {\n    int found_algo_cnt = 0;\n    perf_vec->resize(GetAlgoMaxCount(res));\n    OF_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm_v7(\n        res->cudnn_handle(), args.xdesc.Get(), args.ydesc.Get(), args.cdesc.Get(), args.wdesc.Get(),\n        perf_vec->size(), &found_algo_cnt, perf_vec->data()));\n    // vector::resize does not affect the first found_algo_cnt elements.\n    perf_vec->resize(found_algo_cnt);\n  }\n\n  static void ExhaustiveSearch(const CudnnConvArgs& args, CudnnConvResource* res,\n                               std::vector<perf_t>* perf_vec) {\n    int found_algo_cnt = 0;\n    perf_vec->resize(GetAlgoMaxCount(res));\n    OF_CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx(\n        res->cudnn_handle(), args.xdesc.Get(), res->x_const_dptr(), args.ydesc.Get(),\n        res->y_const_dptr(), args.cdesc.Get(), args.wdesc.Get(), res->w_mut_dptr(),\n        perf_vec->size(), &found_algo_cnt, perf_vec->data(), res->ws_dptr(),\n        args.params.max_ws_size));\n    // vector::resize does not affect the first found_algo_cnt elements.\n    perf_vec->resize(found_algo_cnt);\n  }\n};\n\ntemplate<typename perf_t>\nperf_t FindCudnnConvAlgorithm(CudnnConvArgs* args) {\n  ManagedCudnnConvResource res(*args);\n  return FindCudnnConvAlgorithmWithResource<perf_t>(args, &res);\n}\n\ntemplate<typename perf_t>\nperf_t FindCudnnConvAlgorithmWithResource(CudnnConvArgs* args, CudnnConvResource* res) {\n  auto Infer = [args, res](const CudnnConvParams& params) {\n    std::vector<perf_t> perf_vec;\n    if (args->heuristic) {\n      CudnnConvAlgorithmSearch<perf_t>::HeuristicSearch(*args, res, &perf_vec);\n    } else {\n      CudnnConvAlgorithmSearch<perf_t>::ExhaustiveSearch(*args, res, &perf_vec);\n    }\n    return GetBestAlgorithm<perf_t>(*args, res, perf_vec);\n  };\n  return Singleton<CudnnConvAlgoCache>::Get()->Remember<perf_t>(args->params, Infer);\n}\n\ntemplate<typename perf_t, typename algo_t>\nperf_t GetCudnnConvAlgorithmPerference(CudnnConvArgs* args, algo_t algo) {\n  ManagedCudnnConvResource res(*args);\n  return GetCudnnConvAlgorithmPerferenceWithResource<perf_t>(args, &res, algo);\n}\n\ntemplate<typename perf_t, typename algo_t>\nperf_t GetCudnnConvAlgorithmPerferenceWithResource(CudnnConvArgs* args, CudnnConvResource* res,\n                                                   algo_t algo) {\n  perf_t perf;\n  SetAlgo4Perf(*args, res, &perf, algo);\n  return perf;\n}\n\n#define EXPLICIT_INSTANTIAT_CUDNN_CONV_ALGORITHM_INTERFACE(perf_t)                        \\\n  template perf_t FindCudnnConvAlgorithm(CudnnConvArgs*);                                 \\\n  template perf_t FindCudnnConvAlgorithmWithResource(CudnnConvArgs*, CudnnConvResource*); \\\n  template perf_t GetCudnnConvAlgorithmPerference(CudnnConvArgs*,                         \\\n                                                  decltype(std::declval<perf_t>().algo)); \\\n  template perf_t GetCudnnConvAlgorithmPerferenceWithResource(                            \\\n      CudnnConvArgs*, CudnnConvResource*, 
decltype(std::declval<perf_t>().algo));\n\nEXPLICIT_INSTANTIAT_CUDNN_CONV_ALGORITHM_INTERFACE(cudnnConvolutionFwdAlgoPerf_t)\nEXPLICIT_INSTANTIAT_CUDNN_CONV_ALGORITHM_INTERFACE(cudnnConvolutionBwdDataAlgoPerf_t)\nEXPLICIT_INSTANTIAT_CUDNN_CONV_ALGORITHM_INTERFACE(cudnnConvolutionBwdFilterAlgoPerf_t)\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/core/device/cudnn_conv_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_DEVICE_CUDNN_CONV_UTIL_H_\n#define ONEFLOW_CORE_DEVICE_CUDNN_CONV_UTIL_H_\n\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/device/cudnn_util.h\"\n#include \"oneflow/core/common/protobuf.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nclass KernelComputeContext;\nclass InferContext;\n\n}  // namespace user_op\n\nclass CudnnConvDesc final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudnnConvDesc);\n  CudnnConvDesc() = delete;\n  ~CudnnConvDesc();\n\n  CudnnConvDesc(const DataType compute_type, const DataType data_type,\n                const ShapeView& in_blob_shape, const user_op::InferContext& ctx);\n\n  CudnnConvDesc(const DataType compute_type, const DataType data_type,\n                const ShapeView& in_blob_shape, const user_op::KernelComputeContext& ctx);\n\n  const cudnnConvolutionDescriptor_t& Get() const { return val_; }\n\n private:\n  cudnnConvolutionDescriptor_t val_;\n};\n\nstruct CudnnConvParams {\n  static constexpr size_t kTensorMaxDims = 5;\n  static constexpr size_t kConvMaxDims = 3;\n\n  cudnnDataType_t x_data_type;\n  cudnnDataType_t w_data_type;\n  cudnnDataType_t y_data_type;\n  cudnnDataType_t data_type;\n  cudnnTensorFormat_t w_format;\n  int x_ndim;\n  int w_ndim;\n  int y_ndim;\n  int x_dims[kTensorMaxDims];\n  int x_strides[kTensorMaxDims];\n  int y_dims[kTensorMaxDims];\n  int y_strides[kTensorMaxDims];\n  int w_dims[kTensorMaxDims];\n  int padding[kConvMaxDims];\n  int stride[kConvMaxDims];\n  int dilation[kConvMaxDims];\n  size_t max_ws_size;\n  int groups;\n};\n\nstruct CudnnConvArgs final {\n  CudnnConvParams params;\n  CudnnTensorDesc xdesc;\n  CudnnTensorDesc ydesc;\n  CudnnFilterDesc wdesc;\n  CudnnConvDesc cdesc;\n  bool heuristic;\n  bool deterministic;\n\n  OF_DISALLOW_COPY_AND_MOVE(CudnnConvArgs);\n  CudnnConvArgs(const user_op::InferContext& ctx, DataType x_data_type, const ShapeView& x_shape,\n                DataType w_data_type, const ShapeView& w_shape, DataType y_data_type,\n                const ShapeView& y_shape, const std::string& data_format, size_t max_workspace_size,\n                bool heuristic_search, bool use_deterministic_algo_only, bool enable_pseudo_half);\n  CudnnConvArgs(const user_op::KernelComputeContext& ctx, DataType x_data_type,\n                const ShapeView& x_shape, DataType w_data_type, const ShapeView& w_shape,\n                DataType y_data_type, const ShapeView& y_shape, const std::string& data_format,\n                size_t max_workspace_size, bool heuristic_search, bool use_deterministic_algo_only,\n                bool enable_pseudo_half);\n};\n\nclass CudnnConvResource {\n public:\n  CudnnConvResource() = default;\n  virtual ~CudnnConvResource() = default;\n  virtual cudnnHandle_t cudnn_handle() = 0;\n  virtual void* w_mut_dptr() = 0;\n  virtual void* x_mut_dptr() = 0;\n  virtual void* y_mut_dptr() = 0;\n  virtual const void* w_const_dptr() const = 0;\n  virtual const void* 
x_const_dptr() const = 0;\n  virtual const void* y_const_dptr() const = 0;\n  virtual void* ws_dptr() = 0;\n};\n\nclass AllocatedCudnnConvResource final : public CudnnConvResource {\n public:\n  AllocatedCudnnConvResource(cudnnHandle_t handle, void* x_dptr, void* w_dptr, void* y_dptr,\n                             void* ws_dptr)\n      : handle_(handle), x_dptr_(x_dptr), w_dptr_(w_dptr), y_dptr_(y_dptr), ws_dptr_(ws_dptr) {}\n  ~AllocatedCudnnConvResource() = default;\n  cudnnHandle_t cudnn_handle() override { return handle_; }\n  const void* x_const_dptr() const override { return x_dptr_; }\n  const void* w_const_dptr() const override { return w_dptr_; }\n  const void* y_const_dptr() const override { return y_dptr_; }\n  void* x_mut_dptr() override { return x_dptr_; }\n  void* w_mut_dptr() override { return w_dptr_; }\n  void* y_mut_dptr() override { return y_dptr_; }\n  void* ws_dptr() override { return ws_dptr_; }\n\n private:\n  cudnnHandle_t handle_;\n  void* x_dptr_;\n  void* w_dptr_;\n  void* y_dptr_;\n  void* ws_dptr_;\n};\n\nclass ManagedCudnnConvResource final : public CudnnConvResource {\n public:\n  ManagedCudnnConvResource(const CudnnConvArgs& args);\n  ~ManagedCudnnConvResource() override;\n  cudnnHandle_t cudnn_handle() override;\n  void* x_mut_dptr() override;\n  void* w_mut_dptr() override;\n  void* y_mut_dptr() override;\n  const void* x_const_dptr() const override;\n  const void* w_const_dptr() const override;\n  const void* y_const_dptr() const override;\n  void* ws_dptr() override;\n\n private:\n  cudnnHandle_t handle_;\n  void* x_dptr_;\n  void* w_dptr_;\n  void* y_dptr_;\n  void* ws_dptr_;\n  size_t x_byte_size_;\n  size_t w_byte_size_;\n  size_t y_byte_size_;\n  size_t ws_byte_size_;\n};\n\nbool operator==(const CudnnConvParams& a, const CudnnConvParams& b);\nDataType GetConvDescDataType(DataType data_type, bool pseudo_half);\n\ntemplate<typename perf_t>\nstruct CudnnConvAlgorithmSearch;\n\ncudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res,\n                                        cudnnConvolutionFwdAlgo_t algo, size_t* sz);\ncudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res,\n                                        cudnnConvolutionBwdDataAlgo_t algo, size_t* sz);\ncudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res,\n                                        cudnnConvolutionBwdFilterAlgo_t algo, size_t* sz);\n\ntemplate<typename perf_t>\nperf_t FindCudnnConvAlgorithm(CudnnConvArgs* args);\n\ntemplate<typename perf_t>\nperf_t FindCudnnConvAlgorithmWithResource(CudnnConvArgs* args, CudnnConvResource* res);\n\ntemplate<typename perf_t, typename algo_t>\nperf_t GetCudnnConvAlgorithmPerference(CudnnConvArgs* args, algo_t algo);\n\ntemplate<typename perf_t, typename algo_t>\nperf_t GetCudnnConvAlgorithmPerferenceWithResource(CudnnConvArgs* args, CudnnConvResource* res,\n                                                   algo_t algo);\n\n}  // namespace oneflow\n\nnamespace std {\n\n// Hashing machinery for CudnnConvParams\n// see https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function\ntemplate<>\nstruct hash<oneflow::CudnnConvParams> final {\n  // CudnnConvParams must be a POD because we read out its memory\n  // contents as raw bytes when hashing\n  static_assert(std::is_pod<oneflow::CudnnConvParams>::value, \"CudnnConvParams is not POD\");\n\n  size_t operator()(const oneflow::CudnnConvParams& params) const {\n
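    // FNV-1a (32-bit): start from the offset basis, then XOR in each byte and\n    // multiply by the FNV prime.\n    const auto* ptr = reinterpret_cast<const 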
uint8_t*>(&params);\n    uint32_t value = 0x811C9DC5;\n    for (int i = 0; i < (int)sizeof(oneflow::CudnnConvParams); ++i) {\n      value ^= ptr[i];\n      value *= 0x01000193;\n    }\n    return (size_t)value;\n  }\n};\n\n}  // namespace std\n\nnamespace oneflow {\n\nclass CudnnConvAlgoCache final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudnnConvAlgoCache);\n  CudnnConvAlgoCache() = default;\n  ~CudnnConvAlgoCache() = default;\n\n  template<typename perf_t>\n  using WorkspaceSizeAndPerfT = std::pair<size_t, perf_t>;\n  template<typename perf_t>\n  using Store = HashMap<CudnnConvParams, std::list<WorkspaceSizeAndPerfT<perf_t>>>;\n\n  template<typename perf_t>\n  perf_t Remember(const CudnnConvParams& params,\n                  const std::function<perf_t(const CudnnConvParams& param)>& InferFn);\n\n private:\n  Store<cudnnConvolutionFwdAlgoPerf_t> fwd_algo_store_;\n  std::mutex fwd_algo_store_mutex_;\n  Store<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_algo_store_;\n  std::mutex bwd_data_algo_store_mutex_;\n  Store<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filter_algo_store_;\n  std::mutex bwd_filter_algo_cache_mutex_;\n};\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_CORE_DEVICE_CUDNN_CONV_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/device/cudnn_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/device/cudnn_util.h\"\n\nnamespace oneflow {\n\n#ifdef WITH_CUDA\n\ncudnnDataType_t GetCudnnDataType(DataType val) {\n#define MAKE_ENTRY(type_cpp, type_cudnn) \\\n  if (val == GetDataType<type_cpp>::value) { return type_cudnn; }\n  OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CUDNN_DATA_TYPE_SEQ);\n#undef MAKE_ENTRY\n#if CUDNN_VERSION >= 8100\n  if (val == kBFloat16) { return CUDNN_DATA_BFLOAT16; }\n#endif\n  UNIMPLEMENTED();\n}\n\nCudnnTensorDesc::CudnnTensorDesc() { OF_CUDNN_CHECK(cudnnCreateTensorDescriptor(&val_)); }\nCudnnTensorDesc::~CudnnTensorDesc() { OF_CUDNN_CHECK(cudnnDestroyTensorDescriptor(val_)); }\nCudnnTensorDesc::CudnnTensorDesc(cudnnTensorFormat_t format, DataType data_type, int n, int c,\n                                 int h, int w) {\n  OF_CUDNN_CHECK(cudnnCreateTensorDescriptor(&val_));\n  OF_CUDNN_CHECK(cudnnSetTensor4dDescriptor(val_, format, GetCudnnDataType(data_type), n, c, h, w));\n}\nCudnnTensorDesc::CudnnTensorDesc(DataType data_type, int dims, const int* dim, const int* stride) {\n  OF_CUDNN_CHECK(cudnnCreateTensorDescriptor(&val_));\n  OF_CUDNN_CHECK(cudnnSetTensorNdDescriptor(val_, GetCudnnDataType(data_type), dims, dim, stride));\n}\nCudnnTensorDesc::CudnnTensorDesc(DataType data_type, const ShapeView& shape,\n                                 const std::string& data_format) {\n  OF_CUDNN_CHECK(cudnnCreateTensorDescriptor(&val_));\n  cudnnTensorFormat_t cudnn_data_format;\n  if (data_format == \"channels_first\") {\n    cudnn_data_format = CUDNN_TENSOR_NCHW;\n  } else if (data_format == \"channels_last\") {\n    cudnn_data_format = CUDNN_TENSOR_NHWC;\n  } else {\n    UNIMPLEMENTED();\n  }\n\n  if (shape.NumAxes() == 3) {\n    int data_num = static_cast<int>(shape.At(0));\n    int channels = data_format == \"channels_first\" ? static_cast<int>(shape.At(1))\n                                                   : static_cast<int>(shape.At(2));\n    int kernel_h = data_format == \"channels_first\" ? static_cast<int>(shape.At(2))\n                                                   : static_cast<int>(shape.At(1));\n    int kernel_w = 1;\n    OF_CUDNN_CHECK(cudnnSetTensor4dDescriptor(val_, cudnn_data_format, GetCudnnDataType(data_type),\n                                              data_num, channels, kernel_h, kernel_w));\n  } else if (shape.NumAxes() == 4) {\n    int data_num = static_cast<int>(shape.At(0));\n    int channels = data_format == \"channels_first\" ? static_cast<int>(shape.At(1))\n                                                   : static_cast<int>(shape.At(3));\n    int kernel_h = data_format == \"channels_first\" ? static_cast<int>(shape.At(2))\n                                                   : static_cast<int>(shape.At(1));\n    int kernel_w = data_format == \"channels_first\" ? 
static_cast<int>(shape.At(3))\n                                                   : static_cast<int>(shape.At(2));\n    OF_CUDNN_CHECK(cudnnSetTensor4dDescriptor(val_, cudnn_data_format, GetCudnnDataType(data_type),\n                                              data_num, channels, kernel_h, kernel_w));\n  } else {\n    std::vector<int> tensor_dim({shape.ptr(), shape.ptr() + shape.NumAxes()});\n    std::vector<int> stride_of_tensor(shape.NumAxes(), 1);\n    for (int32_t i = shape.NumAxes() - 2; i >= 0; --i) {\n      stride_of_tensor[i] = stride_of_tensor[i + 1] * shape.At(i + 1);\n    }\n\n    OF_CUDNN_CHECK(cudnnSetTensorNdDescriptor(val_, GetCudnnDataType(data_type), shape.NumAxes(),\n                                              tensor_dim.data(), stride_of_tensor.data()));\n  }\n}\n\nCudnnFilterDesc::~CudnnFilterDesc() { OF_CUDNN_CHECK(cudnnDestroyFilterDescriptor(val_)); }\n\nCudnnFilterDesc::CudnnFilterDesc(DataType data_type, const ShapeView& shape,\n                                 const std::string& data_format) {\n  OF_CUDNN_CHECK(cudnnCreateFilterDescriptor(&val_));\n  cudnnTensorFormat_t cudnn_data_format;\n  if (data_format == \"channels_first\") {\n    cudnn_data_format = CUDNN_TENSOR_NCHW;\n  } else if (data_format == \"channels_last\") {\n    cudnn_data_format = CUDNN_TENSOR_NHWC;\n  } else {\n    UNIMPLEMENTED();\n  }\n\n  if (shape.NumAxes() == 3) {\n    int filters = static_cast<int>(shape.At(0));\n    int c = data_format == \"channels_first\" ? static_cast<int>(shape.At(1))\n                                            : static_cast<int>(shape.At(2));\n    int kernel_h = data_format == \"channels_first\" ? static_cast<int>(shape.At(2))\n                                                   : static_cast<int>(shape.At(1));\n    int kernel_w = 1;\n    OF_CUDNN_CHECK(cudnnSetFilter4dDescriptor(val_, GetCudnnDataType(data_type), cudnn_data_format,\n                                              filters, c, kernel_h, kernel_w));\n  } else if (shape.NumAxes() == 4) {\n    int filters = static_cast<int>(shape.At(0));\n    int kernel_h = data_format == \"channels_first\" ? static_cast<int>(shape.At(2))\n                                                   : static_cast<int>(shape.At(1));\n    int kernel_w = data_format == \"channels_first\" ? static_cast<int>(shape.At(3))\n                                                   : static_cast<int>(shape.At(2));\n    int c = data_format == \"channels_first\" ? 
static_cast<int>(shape.At(1))\n                                            : static_cast<int>(shape.At(3));\n    OF_CUDNN_CHECK(cudnnSetFilter4dDescriptor(val_, GetCudnnDataType(data_type), cudnn_data_format,\n                                              filters, c, kernel_h, kernel_w));\n  } else {\n    std::vector<int> dims({shape.ptr(), shape.ptr() + shape.NumAxes()});\n    OF_CUDNN_CHECK(cudnnSetFilterNdDescriptor(val_, GetCudnnDataType(data_type), cudnn_data_format,\n                                              dims.size(), dims.data()));\n  }\n}\n\nCudnnActivationDesc::CudnnActivationDesc(cudnnActivationMode_t mode,\n                                         cudnnNanPropagation_t relu_nan_opt, double coef) {\n  OF_CUDNN_CHECK(cudnnCreateActivationDescriptor(&val_));\n  OF_CUDNN_CHECK(cudnnSetActivationDescriptor(val_, mode, relu_nan_opt, coef));\n}\n\nCudnnActivationDesc::~CudnnActivationDesc() {\n  OF_CUDNN_CHECK(cudnnDestroyActivationDescriptor(val_));\n}\n\nsize_t GetCudnnDataTypeByteSize(cudnnDataType_t data_type) {\n  size_t byte_size = 0;\n  switch (data_type) {\n    case CUDNN_DATA_FLOAT:\n    case CUDNN_DATA_INT32:\n    case CUDNN_DATA_INT8x4:\n    case CUDNN_DATA_UINT8x4: {\n      byte_size = 4;\n      break;\n    }\n    case CUDNN_DATA_DOUBLE: {\n      byte_size = 8;\n      break;\n    }\n    case CUDNN_DATA_HALF: {\n      byte_size = 2;\n      break;\n    }\n    case CUDNN_DATA_INT8:\n    case CUDNN_DATA_UINT8: {\n      byte_size = 1;\n      break;\n    }\n#if CUDNN_VERSION > 7200\n    case CUDNN_DATA_INT8x32: {\n      byte_size = 32;\n      break;\n    }\n#endif\n#if CUDNN_VERSION >= 8100\n    case CUDNN_DATA_BFLOAT16: {\n      byte_size = 2;\n      break;\n    }\n#endif\n    default: {\n      UNIMPLEMENTED();\n    }\n  }\n  return byte_size;\n}\n\nCudnnHandlePool::~CudnnHandlePool() {\n  for (auto& pair : handle_list_map_) {\n    int64_t device_id = pair.first;\n    auto& handle_list = pair.second;\n    CudaCurrentDeviceGuard guard(device_id);\n    while (!handle_list.empty()) {\n      cudnnHandle_t handle = handle_list.back();\n      handle_list.pop_back();\n      OF_CUDNN_CHECK(cudnnDestroy(handle));\n    }\n  }\n  handle_list_map_.clear();\n}\n\ncudnnHandle_t CudnnHandlePool::Get() {\n  int device_id;\n  OF_CUDA_CHECK(cudaGetDevice(&device_id));\n  {\n    std::unique_lock<std::mutex> lock(mutex_);\n    std::vector<cudnnHandle_t>& handle_list = handle_list_map_[device_id];\n    if (!handle_list.empty()) {\n      cudnnHandle_t handle = handle_list.back();\n      handle_list.pop_back();\n      return handle;\n    }\n  }\n  cudnnHandle_t handle;\n  OF_CUDNN_CHECK(cudnnCreate(&handle));\n  return handle;\n}\n\nvoid CudnnHandlePool::Put(cudnnHandle_t handle) {\n  int device_id;\n  OF_CUDA_CHECK(cudaGetDevice(&device_id));\n  std::unique_lock<std::mutex> lock(mutex_);\n  std::vector<cudnnHandle_t>& handle_list = handle_list_map_[device_id];\n  handle_list.push_back(handle);\n}\n\n#endif  // WITH_CUDA\n\ntemplate<typename T>\nconst void* CudnnSPOnePtr() {\n  static const float fval = 1.0f;\n  static const double dval = 1.0;\n  const void* ret = std::is_same<T, double>::value ? static_cast<const void*>(&dval)\n                                                   : static_cast<const void*>(&fval);\n  return ret;\n}\n\ntemplate<typename T>\nconst void* CudnnSPZeroPtr() {\n  static const float fval = 0.0f;\n  static const double dval = 0.0;\n  const void* ret = std::is_same<T, double>::value ? 
static_cast<const void*>(&dval)\n                                                   : static_cast<const void*>(&fval);\n  return ret;\n}\n\ntemplate const void* CudnnSPOnePtr<float>();\ntemplate const void* CudnnSPOnePtr<double>();\ntemplate const void* CudnnSPOnePtr<float16>();\n\ntemplate const void* CudnnSPZeroPtr<float>();\ntemplate const void* CudnnSPZeroPtr<double>();\ntemplate const void* CudnnSPZeroPtr<float16>();\n\nconst void* CudnnSPOnePtr(const DataType dtype) {\n  if (dtype == kDouble) {\n    return CudnnSPOnePtr<double>();\n  } else if (dtype == kFloat) {\n    return CudnnSPOnePtr<float>();\n  } else if (dtype == kFloat16) {\n    return CudnnSPOnePtr<float16>();\n  } else if (dtype == kBFloat16) {\n    // NOTE(guoran): kBFloat16 uses the float OnePtr\n    return CudnnSPOnePtr<float>();\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nconst void* CudnnSPZeroPtr(const DataType dtype) {\n  if (dtype == kDouble) {\n    return CudnnSPZeroPtr<double>();\n  } else if (dtype == kFloat) {\n    return CudnnSPZeroPtr<float>();\n  } else if (dtype == kFloat16) {\n    return CudnnSPZeroPtr<float16>();\n  } else if (dtype == kBFloat16) {\n    // NOTE(guoran): kBFloat16 uses the float ZeroPtr\n    return CudnnSPZeroPtr<float>();\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/device/cudnn_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_DEVICE_CUDNN_UTIL_H_\n#define ONEFLOW_CORE_DEVICE_CUDNN_UTIL_H_\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/shape_view.h\"\n\n#ifdef WITH_CUDA\n\n#include \"cudnn.h\"\n\nnamespace oneflow {\n\n#define CUDNN_DATA_TYPE_SEQ                       \\\n  OF_PP_MAKE_TUPLE_SEQ(float, CUDNN_DATA_FLOAT)   \\\n  OF_PP_MAKE_TUPLE_SEQ(float16, CUDNN_DATA_HALF)  \\\n  OF_PP_MAKE_TUPLE_SEQ(double, CUDNN_DATA_DOUBLE) \\\n  OF_PP_MAKE_TUPLE_SEQ(int8_t, CUDNN_DATA_INT8)   \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, CUDNN_DATA_INT32)\n\ncudnnDataType_t GetCudnnDataType(DataType);\n\ntemplate<typename T>\nstruct CudnnDataType;\n\n#define SPECIALIZE_CUDNN_DATA_TYPE(type_cpp, type_cudnn) \\\n  template<>                                             \\\n  struct CudnnDataType<type_cpp> : std::integral_constant<cudnnDataType_t, type_cudnn> {};\nOF_PP_FOR_EACH_TUPLE(SPECIALIZE_CUDNN_DATA_TYPE, CUDNN_DATA_TYPE_SEQ);\n#undef SPECIALIZE_CUDNN_DATA_TYPE\n\nclass CudnnTensorDesc final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudnnTensorDesc);\n  CudnnTensorDesc();\n  ~CudnnTensorDesc();\n\n  CudnnTensorDesc(cudnnTensorFormat_t, DataType, int n, int c, int h, int w);\n  CudnnTensorDesc(DataType data_type, int dims, const int* dim, const int* stride);\n  CudnnTensorDesc(DataType data_type, const ShapeView& shape, const std::string& data_format);\n\n  const cudnnTensorDescriptor_t& Get() const { return val_; }\n\n private:\n  cudnnTensorDescriptor_t val_;\n};\n\nclass CudnnFilterDesc final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudnnFilterDesc);\n  CudnnFilterDesc() = delete;\n  ~CudnnFilterDesc();\n\n  CudnnFilterDesc(DataType data_type, const ShapeView& shape, const std::string& data_format);\n\n  const cudnnFilterDescriptor_t& Get() const { return val_; }\n\n private:\n  cudnnFilterDescriptor_t val_;\n};\n\nclass CudnnActivationDesc final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudnnActivationDesc);\n  CudnnActivationDesc() = delete;\n  ~CudnnActivationDesc();\n\n  CudnnActivationDesc(cudnnActivationMode_t mode, cudnnNanPropagation_t relu_nan_opt, double coef);\n\n  const cudnnActivationDescriptor_t& Get() const { return val_; }\n\n private:\n  cudnnActivationDescriptor_t val_;\n};\n\nsize_t GetCudnnDataTypeByteSize(cudnnDataType_t data_type);\n\n// SP for scaling parameter\ntemplate<typename T>\nconst void* CudnnSPOnePtr();\n\ntemplate<typename T>\nconst void* CudnnSPZeroPtr();\n\nconst void* CudnnSPOnePtr(const DataType dtype);\n\nconst void* CudnnSPZeroPtr(const DataType dtype);\n\nclass CudnnHandlePool {\n public:\n  CudnnHandlePool() = default;\n  ~CudnnHandlePool();\n  cudnnHandle_t Get();\n  void Put(cudnnHandle_t handle);\n\n private:\n  std::mutex mutex_;\n  HashMap<int64_t, std::vector<cudnnHandle_t>> handle_list_map_;\n};\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_CORE_DEVICE_CUDNN_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/device/device_id.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/device_id.h\"\n\nnamespace oneflow {\n\nnamespace {\nconstexpr size_t kInt32Bits = sizeof(int32_t) * CHAR_BIT;\n\nconstexpr size_t kDeviceIndexShift = 0;\nconstexpr size_t kDeviceTypeShift = kDeviceIndexShift + DeviceId::kDeviceIndexBits;\nconstexpr size_t kRankShift = kDeviceTypeShift + DeviceId::kDeviceTypeBits;\n\nstatic_assert(kRankShift + DeviceId::kRankBits < kInt32Bits, \"\");\n\n}  // namespace\n\nint64_t EncodeDeviceIdToInt64(const DeviceId& device_id) {\n  int64_t id = static_cast<int64_t>(device_id.device_index());\n  id |= static_cast<int64_t>(device_id.device_type()) << kDeviceTypeShift;\n  id |= static_cast<int64_t>(device_id.rank()) << kRankShift;\n  return id;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/device/device_id.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_DEVICE_DEVICE_ID_H_\n#define ONEFLOW_CORE_DEVICE_DEVICE_ID_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/device_type.h\"\n\nnamespace oneflow {\n\n// DeviceId encoding (bits)\n// | reserved   |   node_index               | device_type | device_index  |\n// | --- 1 ---- | ----------- 19 ----------- | ---- 5 ---- | ----- 7 ----- |\n// |                               DeviceId                                |\n// | ------------------------------- 32 ---------------------------------- |\n\nclass DeviceId {\n public:\n  using rank_t = uint32_t;\n  using device_type_t = uint32_t;\n  using device_index_t = uint32_t;\n\n  constexpr static size_t kRankBits = 16;\n  constexpr static size_t kDeviceTypeBits = 5;\n  constexpr static size_t kDeviceIndexBits = 7;\n  constexpr static rank_t kMaxRank = (rank_t{1} << kRankBits) - rank_t{1};\n  constexpr static device_type_t kMaxDeviceTypeVal =\n      (device_type_t{1} << kDeviceTypeBits) - device_type_t{1};\n  constexpr static device_index_t kMaxDeviceIndex =\n      (device_index_t{1} << kDeviceIndexBits) - device_index_t{1};\n\n  DeviceId(rank_t rank, DeviceType device_type, device_index_t device_index)\n      : rank_(rank),\n        device_type_(static_cast<device_type_t>(device_type)),\n        device_index_(device_index) {\n    CHECK_LE(rank_, kMaxRank);\n    CHECK_LE(device_type_, kMaxDeviceTypeVal);\n    CHECK_LE(device_index_, kMaxDeviceIndex);\n  }\n\n  rank_t rank() const { return rank_; }\n  DeviceType device_type() const { return static_cast<DeviceType>(device_type_); }\n  device_index_t device_index() const { return device_index_; }\n\n  bool operator==(const DeviceId& rhs) const {\n    return rank_ == rhs.rank_ && device_type_ == rhs.device_type_\n           && device_index_ == rhs.device_index_;\n  }\n\n  bool operator!=(const DeviceId& rhs) const { return !(*this == rhs); }\n\n  size_t hash() const {\n    size_t hash = std::hash<rank_t>{}(rank_);\n    HashCombine(&hash, std::hash<device_type_t>{}(device_type_));\n    HashCombine(&hash, std::hash<device_index_t>{}(device_index_));\n    return hash;\n  }\n\n private:\n  rank_t rank_;\n  device_type_t device_type_;\n  device_index_t device_index_;\n};\n\nint64_t EncodeDeviceIdToInt64(const DeviceId& device_id);\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::DeviceId> {\n  size_t operator()(const oneflow::DeviceId& device_id) const { return device_id.hash(); }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_DEVICE_DEVICE_ID_H_\n"
  },
  {
    "path": "oneflow/core/device/ep_based_event_record.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_DEVICE_EP_BASED_EVENT_RECORD_H_\n#define ONEFLOW_CORE_DEVICE_EP_BASED_EVENT_RECORD_H_\n\n#include \"oneflow/core/device/event_record.h\"\n#include \"oneflow/core/ep/include/active_device_guard.h\"\n\nnamespace oneflow {\n\nclass EpBasedEventRecord : public EventRecord {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(EpBasedEventRecord);\n  EpBasedEventRecord(ep::Event* event, ep::Device* device) : event_(event), device_(device) {}\n  ~EpBasedEventRecord() {\n    ep::ActiveDeviceGuard guard(device_);\n    device_->DestroyEvent(event_);\n  };\n\n  static std::shared_ptr<EventRecord> MakeEventRecord(ep::Stream* stream) {\n    ep::Device* device = stream->device();\n    ep::ActiveDeviceGuard guard(device);\n    ep::Event* event = device->CreateEvent();\n    stream->RecordEvent(event);\n    return std::make_shared<EpBasedEventRecord>(event, device);\n  }\n\n  bool QueryDone() const override {\n    ep::ActiveDeviceGuard guard(device_);\n    bool done = CHECK_JUST(event_->QueryDone());\n    return done;\n  }\n\n private:\n  ep::Event* event_;\n  ep::Device* device_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_DEVICE_EP_BASED_EVENT_RECORD_H_\n"
  },
  {
    "path": "oneflow/core/device/event_record.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_DEVICE_EVENT_RECORD_H_\n#define ONEFLOW_CORE_DEVICE_EVENT_RECORD_H_\n\n#include <atomic>\n#include <memory>\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nclass EventRecord {\n public:\n  EventRecord(const EventRecord&) = delete;\n  EventRecord(EventRecord&&) = delete;\n  EventRecord& operator=(const EventRecord&) = delete;\n  EventRecord& operator=(EventRecord&&) = delete;\n  virtual ~EventRecord() = default;\n\n  virtual bool QueryDone() const = 0;\n\n  EventRecord() = default;\n};\n\nclass NaiveEventRecord final : public EventRecord {\n public:\n  NaiveEventRecord(const NaiveEventRecord&) = delete;\n  NaiveEventRecord(NaiveEventRecord&&) = delete;\n  NaiveEventRecord& operator=(const NaiveEventRecord&) = delete;\n  NaiveEventRecord& operator=(NaiveEventRecord&&) = delete;\n\n  NaiveEventRecord() = default;\n  ~NaiveEventRecord() override = default;\n\n  bool QueryDone() const override { return true; }\n};\n\nclass SharedEventRecord final : public EventRecord {\n public:\n  SharedEventRecord(const SharedEventRecord&) = delete;\n  SharedEventRecord(SharedEventRecord&&) = delete;\n  SharedEventRecord& operator=(const SharedEventRecord&) = delete;\n  SharedEventRecord& operator=(SharedEventRecord&&) = delete;\n\n  SharedEventRecord() : EventRecord(), inited_(false) {}\n  ~SharedEventRecord() override = default;\n\n  bool QueryDone() const override { return inited_ && event_record_->QueryDone(); }\n\n  void Init(const std::shared_ptr<EventRecord>& event_record) {\n    // No lock needed. This function will be called only one time.\n    // In most cases, errors will be successfully detected by CHECK\n    // even though run in different threads.\n    CHECK(!inited_);\n    event_record_ = event_record;\n    inited_ = true;\n  }\n  void TryInit(const std::shared_ptr<EventRecord>& event_record) {\n    if (!inited_) { Init(event_record); }\n  }\n\n private:\n  std::atomic<bool> inited_;\n  std::shared_ptr<EventRecord> event_record_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_DEVICE_EVENT_RECORD_H_\n"
  },
  {
    "path": "oneflow/core/device/nccl_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/nccl_util.h\"\n\nnamespace oneflow {\n\n#ifdef WITH_CUDA\n\nstd::string NcclUniqueIdToString(const ncclUniqueId& unique_id) {\n  return std::string(unique_id.internal, NCCL_UNIQUE_ID_BYTES);\n}\n\nvoid NcclUniqueIdFromString(const std::string& str, ncclUniqueId* unique_id) {\n  CHECK_EQ(str.size(), NCCL_UNIQUE_ID_BYTES);\n  memcpy(unique_id->internal, str.data(), NCCL_UNIQUE_ID_BYTES);\n}\n\n#endif  // WITH_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/device/nccl_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_DEVICE_NCCL_UTIL_H_\n#define ONEFLOW_CORE_DEVICE_NCCL_UTIL_H_\n\n#include \"oneflow/core/register/blob.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n\n#ifdef WITH_CUDA\n\n#include <cuda.h>\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\n\n#endif  // WITH_CUDA\n\nnamespace oneflow {\n\n#ifdef WITH_CUDA\n\ninline ncclDataType_t GetNcclDataType(const DataType& dt) {\n  switch (dt) {\n#define NCCL_DATA_TYPE_CASE(dtype) \\\n  case DataType::k##dtype: return ncclDataType_t::nccl##dtype\n    NCCL_DATA_TYPE_CASE(Char);\n    NCCL_DATA_TYPE_CASE(Float);\n    NCCL_DATA_TYPE_CASE(Double);\n    NCCL_DATA_TYPE_CASE(Int8);\n    NCCL_DATA_TYPE_CASE(Int32);\n    NCCL_DATA_TYPE_CASE(Int64);\n    NCCL_DATA_TYPE_CASE(Float16);\n    case DataType::kBool: return ncclDataType_t::ncclUint8;\n#if defined(__CUDA_BF16_TYPES_EXIST__) && NCCL_VERSION_CODE >= 21003\n    case DataType::kBFloat16: return ncclBfloat16;\n#endif\n    case DataType::kUInt8: return ncclUint8;\n    case DataType::kUInt32: return ncclUint32;\n    case DataType::kUInt64: return ncclUint64;\n    default: UNIMPLEMENTED();\n  }\n  return ncclDataType_t::ncclFloat;\n}\n\nstd::string NcclUniqueIdToString(const ncclUniqueId& unique_id);\n\nvoid NcclUniqueIdFromString(const std::string& str, ncclUniqueId* unique_id);\n\n#define HAS_NCCL_SEND_RECV NCCL_VERSION_CODE > 2700\n\n#endif  // WITH_CUDA\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_DEVICE_NCCL_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/eager/call_context.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/eager/call_context.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n\nnamespace oneflow {\nnamespace eager {\nnamespace {\n\nvm::WeakEagerBlobObjectList shared_to_weak(const vm::EagerBlobObjectList& shared_list) {\n  vm::WeakEagerBlobObjectList ret;\n  ret.reserve(shared_list.size());\n  for (const auto& shared : shared_list) { ret.emplace_back(shared); }\n  return ret;\n}\n\n}  // namespace\nDtrCallContext::DtrCallContext(const CallContext& call_ctx)\n    : composed_attrs_(call_ctx.composed_attrs()),\n      inputs_(call_ctx.inputs()),\n      outputs_(shared_to_weak(call_ctx.outputs())),\n      global_tensor_infer_result_(call_ctx.global_tensor_infer_result()),\n      op_interp_ctx_(call_ctx.op_interp_ctx()),\n      tmp_tensor_(call_ctx.tmp_tensor()) {\n  for (const auto& x : call_ctx.outputs()) {\n    ebo_infos_.push_back(EBOInfo{std::make_shared<MemoryCase>(x->mem_case()), x->tensor_meta(),\n                                 x->mut_tensor_meta(), x->data_type(), x->memory_format()});\n  }\n}\n\nCallContext::CallContext(const DtrCallContext& dtr_call_ctx)\n    : composed_attrs_(dtr_call_ctx.composed_attrs_),\n      inputs_(dtr_call_ctx.inputs_),\n      global_tensor_infer_result_(dtr_call_ctx.global_tensor_infer_result_),\n      op_interp_ctx_(dtr_call_ctx.op_interp_ctx_),\n      tmp_tensor_(dtr_call_ctx.tmp_tensor_) {\n  for (int i = 0; i < dtr_call_ctx.outputs_.size(); ++i) {\n    const auto& weak = dtr_call_ctx.outputs_[i];\n    if (weak.expired()) {\n      LOG(INFO) << \"index: \" << i << \" is expired\";\n      outputs_.push_back(std::make_shared<vm::EagerBlobObject>(\n          dtr_call_ctx.ebo_infos_[i].mem_case, dtr_call_ctx.ebo_infos_[i].local_tensor_meta,\n          dtr_call_ctx.ebo_infos_[i].dynamic_local_tensor_meta,\n          dtr_call_ctx.ebo_infos_[i].data_type, dtr_call_ctx.ebo_infos_[i].memory_format,\n          std::make_shared<vm::TensorStorage>(\n              true, dtr_call_ctx.ebo_infos_[i].local_tensor_meta->device())));\n    } else {\n      outputs_.push_back(weak.lock());\n    }\n  }\n}\n\nCallContext DtrCallContext::ToCallContext() const { return CallContext(*this); }\n\n}  // namespace eager\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/eager/call_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EAGER_CALL_CONTEXT_H_\n#define ONEFLOW_CORE_EAGER_CALL_CONTEXT_H_\n\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/framework/op_interpreter.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/common/small_vector.h\"\n\nnamespace oneflow {\n\nnamespace one {\n\nclass StatefulLocalOpKernel;\nclass GlobalTensorInferResult;\n\n}  // namespace one\n\nnamespace eager {\n\nclass TmpTensor final : public user_op::Tensor {\n public:\n  explicit TmpTensor(const std::shared_ptr<MemoryCase>& mem_case)\n      : mem_case_(mem_case), tmp_buffer_size_(0), tmp_buffer_ptr_(nullptr) {}\n  ~TmpTensor() = default;\n  TmpTensor(const TmpTensor& other)\n      : mem_case_(other.mem_case_),\n        tmp_buffer_size_(other.tmp_buffer_size_),\n        tmp_buffer_ptr_(other.tmp_buffer_ptr_) {\n    CHECK_ISNULL(tmp_buffer_ptr_);\n  }\n  TmpTensor(TmpTensor&&) = delete;\n  TmpTensor& operator=(const TmpTensor& other) = delete;\n  TmpTensor& operator=(TmpTensor&&) = delete;\n\n  ShapeView shape_view() const override { return ShapeView(&tmp_buffer_size_, 1); }\n  MutShapeView mut_shape_view() override { return MutShapeView(&tmp_buffer_size_, 1); }\n  const Stride& stride() const override {\n    UNIMPLEMENTED() << \"TmpTensor::stride() is not implemented.\";\n  }\n  DataType data_type() const override { return DataType::kChar; }\n  MemoryFormat memory_format() const override { return MemoryFormat::kContiguous; }\n  const MemoryCase& mem_case() const override { return *mem_case_; }\n  const void* raw_dptr() const override { return tmp_buffer_ptr_; }\n  void* mut_raw_dptr() override { return tmp_buffer_ptr_; }\n\n  int64_t tmp_buffer_size() const { return tmp_buffer_size_; }\n  void set_tmp_buffer_size(int64_t val) { tmp_buffer_size_ = val; }\n\n  char* mut_tmp_buffer_ptr() { return tmp_buffer_ptr_; }\n\n  void set_tmp_buffer_ptr(char* ptr) { tmp_buffer_ptr_ = ptr; }\n\n private:\n  std::shared_ptr<MemoryCase> mem_case_;\n  int64_t tmp_buffer_size_;\n  char* tmp_buffer_ptr_;\n};\n\nclass DtrCallContext;\n\nclass CallContext {\n public:\n  CallContext(ComposedAttrMap composed_attrs, vm::EagerBlobObjectList inputs,\n              vm::EagerBlobObjectList outputs,\n              const std::shared_ptr<const one::GlobalTensorInferResult>& global_tensor_infer_result,\n              const one::OpExprInterpContext& op_interp_ctx,\n              const std::shared_ptr<MemoryCase>& mem_case)\n      : composed_attrs_(std::move(composed_attrs)),\n        inputs_(std::move(inputs)),\n        outputs_(std::move(outputs)),\n        global_tensor_infer_result_(global_tensor_infer_result),\n        op_interp_ctx_(op_interp_ctx),\n        tmp_tensor_(mem_case) {}\n  explicit CallContext(const DtrCallContext&);\n\n  ~CallContext() = default;\n\n  const ComposedAttrMap& 
composed_attrs() const { return composed_attrs_; }\n  const vm::EagerBlobObjectList& inputs() const { return inputs_; }\n  const vm::EagerBlobObjectList& outputs() const { return outputs_; }\n  vm::EagerBlobObjectList& mut_inputs() { return inputs_; }\n  vm::EagerBlobObjectList& mut_outputs() { return outputs_; }\n  const std::shared_ptr<const one::GlobalTensorInferResult>& global_tensor_infer_result() const {\n    return global_tensor_infer_result_;\n  }\n  const one::OpExprInterpContext& op_interp_ctx() const { return op_interp_ctx_; }\n  TmpTensor* mut_tmp_tensor() { return &tmp_tensor_; }\n  const TmpTensor& tmp_tensor() const { return tmp_tensor_; }\n\n private:\n  const ComposedAttrMap composed_attrs_;\n  vm::EagerBlobObjectList inputs_;\n  vm::EagerBlobObjectList outputs_;\n  const std::shared_ptr<const one::GlobalTensorInferResult> global_tensor_infer_result_;\n  const one::OpExprInterpContext op_interp_ctx_;\n  TmpTensor tmp_tensor_;\n};\n\nclass DtrCallContext {\n public:\n  explicit DtrCallContext(const CallContext& call_ctx);\n  CallContext ToCallContext() const;\n  vm::EagerBlobObjectList& mut_inputs() { return inputs_; }\n  vm::WeakEagerBlobObjectList& mut_outputs() { return outputs_; }\n  friend class CallContext;\n\n private:\n  struct EBOInfo {\n    const std::shared_ptr<MemoryCase> mem_case;\n    const Symbol<one::LocalTensorMeta> local_tensor_meta;\n    const std::shared_ptr<const one::MutLocalTensorMeta> dynamic_local_tensor_meta;\n    const DataType data_type;\n    const MemoryFormat memory_format;\n  };\n  using EBOInfoList = small_vector<EBOInfo, vm::WeakEagerBlobObjectList::kInitialSize>;\n\n  const ComposedAttrMap composed_attrs_;\n  vm::EagerBlobObjectList inputs_;\n  vm::WeakEagerBlobObjectList outputs_;\n  EBOInfoList ebo_infos_;\n  const std::shared_ptr<const one::GlobalTensorInferResult> global_tensor_infer_result_;\n  const one::OpExprInterpContext op_interp_ctx_;\n  TmpTensor tmp_tensor_;\n};\n\n}  // namespace eager\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EAGER_CALL_CONTEXT_H_\n"
  },
  {
    "path": "oneflow/core/eager/dev_vm_dep_object_consume_mode.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_EAGER_DEV_VM_DEP_OBJECT_CONSUME_MODE_H_\n#define ONEFLOW_CORE_EAGER_DEV_VM_DEP_OBJECT_CONSUME_MODE_H_\n\nnamespace oneflow {\nnamespace one {\n\nenum class DevVmDepObjectConsumeMode {\n  NONE,\n  MUTABLE,\n};\n\ninline DevVmDepObjectConsumeMode* CurrentDevVmDepObjectConsumeMode() {\n  static thread_local DevVmDepObjectConsumeMode mode_ = DevVmDepObjectConsumeMode::MUTABLE;\n  return &mode_;\n}\n\nclass DevVmDepObjectConsumeModeGuard {\n public:\n  DevVmDepObjectConsumeModeGuard(DevVmDepObjectConsumeMode mode)\n      : prev_mode_(*CurrentDevVmDepObjectConsumeMode()) {\n    *CurrentDevVmDepObjectConsumeMode() = mode;\n  }\n  ~DevVmDepObjectConsumeModeGuard() { *CurrentDevVmDepObjectConsumeMode() = prev_mode_; }  // NOLINT\n\n private:\n  DevVmDepObjectConsumeMode prev_mode_;\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EAGER_DEV_VM_DEP_OBJECT_CONSUME_MODE_H_\n"
  },
  {
    "path": "oneflow/core/eager/eager_blob_object.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/vm/allocator.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/framework/shut_down_util.h\"\n#include \"oneflow/core/common/shape_vec.h\"\n#include \"oneflow/core/common/tensor_meta.h\"\n\nnamespace oneflow {\n\nnamespace vm {\n\nEagerBlobObject::EagerBlobObject(\n    const std::shared_ptr<MemoryCase>& mem_case,\n    const Symbol<one::LocalTensorMeta>& static_local_tensor_meta,\n    const std::shared_ptr<const one::MutLocalTensorMeta>& dynamic_local_tensor_meta,\n    DataType data_type, MemoryFormat memory_format,\n    const std::shared_ptr<TensorStorage>& tensor_storage,\n    const intrusive::shared_ptr<LocalDepObject>& dep_object)\n    : is_dynamic_(false),\n      mem_case_(mem_case),\n      data_type_(data_type),\n      memory_format_(memory_format),\n      storage_offset_(0),\n      tensor_storage_(tensor_storage),\n      compute_local_dep_object_(dep_object),\n      static_local_tensor_meta_(static_local_tensor_meta),\n      dynamic_local_tensor_meta_(dynamic_local_tensor_meta) {\n  CHECK(static_cast<bool>(tensor_storage));\n}\n\n// user_op::TensorDesc overrides\nconst Shape& EagerBlobObject::shape() const {\n  if (dynamic_local_tensor_meta_) {\n    return dynamic_local_tensor_meta_->shape();\n  } else {\n    return static_local_tensor_meta_->shape();\n  }\n}\nconst Stride& EagerBlobObject::stride() const {\n  if (dynamic_local_tensor_meta_) {\n    return dynamic_local_tensor_meta_->stride();\n  } else {\n    return static_local_tensor_meta_->stride();\n  }\n}\n\nvoid EagerBlobObject::set_shape(const Shape& shape) {\n  CHECK(dynamic_local_tensor_meta_);\n  std::const_pointer_cast<one::MutLocalTensorMeta>(dynamic_local_tensor_meta_)->set_shape(shape);\n}\nvoid EagerBlobObject::set_stride(const Stride& stride) {\n  CHECK(dynamic_local_tensor_meta_);\n  std::const_pointer_cast<one::MutLocalTensorMeta>(dynamic_local_tensor_meta_)->set_stride(stride);\n}\n\nMutShapeView EagerBlobObject::mut_shape_view() {\n  CHECK(dynamic_local_tensor_meta_);\n  return *const_cast<Shape*>(dynamic_local_tensor_meta_->shape_ptr().get());\n}\n\nstd::shared_ptr<const Shape> EagerBlobObject::shape_ptr() const {\n  if (dynamic_local_tensor_meta_) {\n    return dynamic_local_tensor_meta_->shape_ptr();\n  } else {\n    return static_local_tensor_meta_->shape_ptr();\n  }\n}\nstd::shared_ptr<const Stride> EagerBlobObject::stride_ptr() const {\n  if (dynamic_local_tensor_meta_) {\n    return dynamic_local_tensor_meta_->stride_ptr();\n  } else {\n    return static_local_tensor_meta_->stride_ptr();\n  }\n}\n\nint64_t EagerBlobObject::storage_offset() const { return storage_offset_; }\n\nvoid EagerBlobObject::set_storage_offset(const int64_t offset) { storage_offset_ = offset; }\n\nMaybe<bool> EagerBlobObject::TryAllocateBlobBodyMemory(vm::Allocator* 
allocator) {\n  size_t required_body_bytes = AlignedByteSizeOfBlobBody();\n  if (required_body_bytes == 0) {\n    CHECK_ISNULL_OR_RETURN(tensor_storage_->blob_dptr());\n  } else if (tensor_storage_->blob_dptr() != nullptr) {\n    CHECK_GE_OR_RETURN(tensor_storage_->blob_bytes(), ByteSizeOfBlobBody())\n        << \"This blob has already been allocated memory, but less than the required space.\";\n  } else {\n    char* dptr = nullptr;\n    JUST(allocator->Allocate(&dptr, required_body_bytes));\n    // Hand the new buffer to tensor_storage_, with a deleter bound to this allocator.\n    const auto& Free = [allocator, required_body_bytes](char* dptr) {\n      if (IsShuttingDown()) { return; }\n      allocator->Deallocate(dptr, required_body_bytes);\n    };\n    tensor_storage_->set_blob_dptr(std::unique_ptr<char, std::function<void(char*)>>(dptr, Free),\n                                   required_body_bytes);\n    InitNonPODTypeEagerBlobObjectIfNeed(tensor_storage_->non_pod_allocator(), this);\n    return true;\n  }\n  return false;\n}\n\nconst void* EagerBlobObject::raw_dptr() const {\n  char* ptr = tensor_storage_->blob_dptr();\n  if (tensor_storage_->blob_bytes() > 0) { CHECK_NOTNULL(ptr); }\n  return ptr + storage_offset_ * GetSizeOfDataType(data_type_);\n}\n\nMaybe<void> EagerBlobObject::DeallocateBlobDataPtr() {\n  tensor_storage_->Release();\n  return Maybe<void>::Ok();\n}\n\nvoid EagerBlobObject::RegisterStorageDeleteHook(const std::function<void()>& hook) {\n  tensor_storage_->RegisterStorageDeleteHook(hook);\n}\n\nconst Optional<Symbol<::oneflow::Stream>>& EagerBlobObject::producer_stream() const {\n  return tensor_storage_->producer_stream();\n}\n\nMaybe<void> EagerBlobObject::init_producer_stream(Symbol<::oneflow::Stream> producer_stream) {\n  return tensor_storage_->init_producer_stream(producer_stream);\n}\n\nconst Optional<Symbol<::oneflow::Stream>>& EagerBlobObject::last_used_stream() const {\n  return tensor_storage_->last_used_stream();\n}\n\nvoid EagerBlobObject::set_last_used_stream(Symbol<::oneflow::Stream> last_used_stream) {\n  tensor_storage_->set_last_used_stream(last_used_stream);\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/eager/eager_blob_object.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EAGER_EAGER_BLOB_OBJECT_H_\n#define ONEFLOW_CORE_EAGER_EAGER_BLOB_OBJECT_H_\n\n#include <utility>\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/op_args_reserved_size.h\"\n#include \"oneflow/core/eager/local_dep_object.h\"\n#include \"oneflow/core/memory/memory_allocator.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/framework/tensor_methods.h\"\n#include \"oneflow/core/framework/user_op_tensor.h\"\n#include \"oneflow/core/common/tensor_desc.h\"\n#include \"oneflow/core/register/blob.h\"\n\nnamespace oneflow {\n\nnamespace one {\n\nclass LocalTensorMeta;\nclass MutLocalTensorMeta;\n\n}  // namespace one\n\nnamespace vm {\n\nclass Allocator;\n\nclass EagerBlobObject final : public user_op::Tensor,\n                              public user_op::TensorDesc,\n                              public std::enable_shared_from_this<EagerBlobObject> {\n public:\n  EagerBlobObject(const EagerBlobObject&) = delete;\n  EagerBlobObject(EagerBlobObject&&) = delete;\n  EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case,\n                  const Symbol<one::LocalTensorMeta>& static_local_tensor_meta,\n                  const std::shared_ptr<const one::MutLocalTensorMeta>& dynamic_local_tensor_meta,\n                  DataType data_type, MemoryFormat memory_format,\n                  const std::shared_ptr<TensorStorage>& tensor_storage)\n      : EagerBlobObject(mem_case, static_local_tensor_meta, dynamic_local_tensor_meta, data_type,\n                        memory_format, tensor_storage, intrusive::shared_ptr<LocalDepObject>()) {}\n  EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case,\n                  const Symbol<one::LocalTensorMeta>& static_local_tensor_meta,\n                  const std::shared_ptr<const one::MutLocalTensorMeta>& dynamic_local_tensor_meta,\n                  DataType data_type, MemoryFormat memory_format,\n                  const std::shared_ptr<TensorStorage>& tensor_storage,\n                  const intrusive::shared_ptr<LocalDepObject>& dep_object);\n\n  ~EagerBlobObject() { tensor_storage_.reset(); }\n\n  const std::shared_ptr<const one::MutLocalTensorMeta>& mut_tensor_meta() {\n    return dynamic_local_tensor_meta_;\n  }\n  // Getters\n  const Symbol<one::LocalTensorMeta>& tensor_meta() const { return static_local_tensor_meta_; }\n\n  // user_op::TensorDesc overrides\n  const Shape& shape() const override;\n  const Stride& stride() const override;\n  DataType data_type() const override { return data_type_; }\n  bool is_dynamic() const override { return is_dynamic_; }\n  MemoryFormat memory_format() const override { return memory_format_; }\n\n  void set_shape(const Shape& shape) override;\n  void set_stride(const Stride& stride) override;\n  void set_data_type(DataType data_type) 
override { data_type_ = data_type; }\n  void set_is_dynamic(bool is_dynamic) override { is_dynamic_ = is_dynamic; }\n  void set_memory_format(MemoryFormat memory_format) override { memory_format_ = memory_format; }\n\n  // user_op::Tensor overrides\n  ShapeView shape_view() const override { return shape(); }\n  MutShapeView mut_shape_view() override;\n  const MemoryCase& mem_case() const override { return *mem_case_; }\n  const void* raw_dptr() const override;\n  void* mut_raw_dptr() override { return const_cast<void*>(raw_dptr()); }\n\n  int64_t storage_offset() const;\n  void set_storage_offset(const int64_t offset);\n\n  // Returns true if allocate successfully.\n  Maybe<bool> TryAllocateBlobBodyMemory(vm::Allocator* allocator);\n  Maybe<void> DeallocateBlobDataPtr();\n  void RegisterStorageDeleteHook(const std::function<void()>& hook);\n\n  Maybe<LocalDepObject*> compute_local_dep_object() const {\n    CHECK_NOTNULL_OR_RETURN(compute_local_dep_object_.get());\n    return compute_local_dep_object_.get();\n  }\n\n  std::shared_ptr<TensorStorage>& tensor_storage() { return tensor_storage_; }\n\n  const Optional<Symbol<::oneflow::Stream>>& producer_stream() const;\n  Maybe<void> init_producer_stream(Symbol<::oneflow::Stream> producer_stream);\n\n  const Optional<Symbol<::oneflow::Stream>>& last_used_stream() const;\n  void set_last_used_stream(Symbol<::oneflow::Stream> last_used_stream);\n\n  std::shared_ptr<const Shape> shape_ptr() const;\n  std::shared_ptr<const Stride> stride_ptr() const;\n\n  size_t ByteSizeOfBlobBody() const {\n    const size_t elem_cnt = shape().elem_cnt();\n    if (elem_cnt == 0) { return 0; }\n    size_t max_offset = 0;\n    for (size_t i = 0; i < shape().NumAxes(); ++i) {\n      max_offset += (shape().at(i) - 1) * stride().at(i);\n    }\n    size_t capacity = max_offset + 1;\n    // TODO(liujuncheng): remove this\n    capacity = std::max<size_t>(capacity, elem_cnt);\n    return capacity * GetSizeOfDataType(data_type_);\n  }\n  size_t AlignedByteSizeOfBlobBody() const {\n    return RoundUp(ByteSizeOfBlobBody(), kBlobBodyAlignSize);\n  }\n  size_t ByteSizeOfBlobHeader() const { return shape().NumAxes() * sizeof(int64_t); }\n  size_t AlignedByteSizeOfBlobHeader() const {\n    return RoundUp(ByteSizeOfBlobHeader(), kBlobHeaderAlignSize);\n  }\n\n  const char* header_ptr() const { return reinterpret_cast<const char*>(shape().dim_vec().data()); }\n  char* mut_header_ptr() {\n    return reinterpret_cast<char*>(const_cast<int64_t*>(shape().dim_vec().data()));\n  }\n\n  void set_input_of_view_op(std::shared_ptr<EagerBlobObject> input) {\n    input_of_view_op_ = std::move(input);\n  }\n\n private:\n  bool is_dynamic_;\n  std::shared_ptr<MemoryCase> mem_case_;\n  DataType data_type_;\n  MemoryFormat memory_format_;\n  int64_t storage_offset_;\n  std::shared_ptr<TensorStorage> tensor_storage_;\n  intrusive::shared_ptr<LocalDepObject> compute_local_dep_object_;\n\n  Symbol<one::LocalTensorMeta> static_local_tensor_meta_;\n  std::shared_ptr<const one::MutLocalTensorMeta> dynamic_local_tensor_meta_;\n  // for rematerialization (i.e. Coop/DTR)\n  std::shared_ptr<EagerBlobObject> input_of_view_op_;\n};\n\nusing EagerBlobObjectList = small_vector<std::shared_ptr<vm::EagerBlobObject>, kOpArgsReservedSize>;\nusing WeakEagerBlobObjectList = small_vector<std::weak_ptr<vm::EagerBlobObject>>;\nusing EagerBlobObjectListPtr = std::shared_ptr<const EagerBlobObjectList>;\n\n}  // namespace vm\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EAGER_EAGER_BLOB_OBJECT_H_\n"
  },
  {
    "path": "oneflow/core/eager/local_dep_object.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/eager/local_dep_object.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/static_global.h\"\n\nnamespace oneflow {\n\nintrusive::shared_ptr<LocalDepObject> NewLocalDepObject() {\n  return intrusive::make_shared<LocalDepObject>();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/eager/local_dep_object.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_LOCAL_DEP_OBJECT_H_\n#define ONEFLOW_CORE_FRAMEWORK_LOCAL_DEP_OBJECT_H_\n\n#include \"oneflow/core/intrusive/intrusive.h\"\n#include \"oneflow/core/vm/vm_object.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/small_vector.h\"\n#include \"oneflow/core/common/op_args_reserved_size.h\"\n#include \"oneflow/core/framework/device.h\"\n\nnamespace oneflow {\n\n// LocalDepObject helps VirtualMachineEngine building instruction edges\nusing LocalDepObject = vm::Dependence;\n\nusing DependenceVector = small_vector<LocalDepObject*>;\n\nintrusive::shared_ptr<LocalDepObject> NewLocalDepObject();\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_LOCAL_DEP_OBJECT_H_\n"
  },
  {
    "path": "oneflow/core/eager/tensor_storage.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/common/env_var/remat.h\"\n#include \"oneflow/core/vm/op_call_instruction_policy.h\"\n#include \"oneflow/core/vm/remat/disjoint_set.h\"\n#include \"oneflow/core/vm/remat/env.h\"\n#include \"oneflow/core/vm/remat/util.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n\nnamespace oneflow {\nnamespace vm {\nnamespace {\nint64_t unique_id() {\n  static size_t id = 0;\n  return id++;\n}\n\n}  // namespace\n\nTensorStorage::TensorStorage(bool is_allocated_in_vm, Symbol<Device> device)\n    : blob_bytes_(0),\n      device_(device),\n      non_pod_allocator_(std::make_unique<MemoryAllocator>()),\n      producer_stream_(NullOpt),\n      last_used_stream_(NullOpt),\n      is_allocated_in_vm_(is_allocated_in_vm) {}\n\nSymbol<Device> TensorStorage::device() const { return device_; }\n\nTensorStorage::~TensorStorage() {\n  for (const auto& hook : storage_delete_hooks_) { hook(); }\n}\n\nvoid TensorStorage::_Release() {\n  non_pod_allocator_.reset();\n  blob_dptr_.reset();\n}\n\nvoid TensorStorage::Release() { return _Release(); }\n\nMaybe<void> TensorStorage::init_producer_stream(Symbol<::oneflow::Stream> producer_stream) {\n  CHECK_OR_RETURN(!producer_stream_.has_value());\n  producer_stream_ = producer_stream;\n  return Maybe<void>::Ok();\n}\n\nRematableTensorStorage::RematableTensorStorage(Symbol<Device> device)\n    : TensorStorage(true, device),\n      node(std::make_shared<remat::DisjNode>(0)),\n      id_(unique_id()),\n      num_pinned_(0),\n      last_access_time_(0),\n      compute_time_(0) {\n  VLOG(1) << \"create rematable storage \" << id_;\n}\n\nRematableTensorStorage::~RematableTensorStorage() {\n  // We must call _Release before destruction or the release will be\n  // called in base class's destructor and causes segfault.\n  // Time order:\n  // 1. ~RematableTensorStorage destructs its members\n  // 2. 
~TensorStorage, Allocator::Deallocate, which uses RematableTensorStorage members\n  _Release();\n  if (compute_op_) { Singleton<remat::Env>::Get()->remove_compute_op(compute_op_.get()); }\n  VLOG(1) << \"delete storage \" << id_;\n}\n\nvoid RematableTensorStorage::LogEviction(bool eager_eviction) const {\n  Singleton<remat::Env>::Get()->add_eviction_num(eager_eviction);\n  VLOG(1) << \"evict storage \" << id_ << \", compute op type: \" << compute_op_type_name()\n          << \", eager_eviction: \" << eager_eviction;\n}\n\nvoid RematableTensorStorage::Remat() {\n  if (is_in_memory()) { return; }\n  auto stream = CHECK_JUST(GetDefaultStreamByDevice(device_));\n  auto* vm_stream = CHECK_JUST(Singleton<VirtualMachine>::Get()->GetVmStream(stream));\n  auto op = compute_op();\n  CHECK_JUST(Recompute(&op, vm_stream));\n}\n\nvoid RematableTensorStorage::Evict(bool eager_eviction) {\n  CHECK(!is_eviction_disabled());\n  LogEviction(eager_eviction);\n  return _Release();\n}\n\nvoid RematableTensorStorage::Release() {\n  CHECK(device_->rematable());\n  if (is_eviction_disabled()) { return; }\n  return Evict(true);\n}\n\nstd::vector<std::string> random_ops{\"uniform\", \"uniform_int\", \"normal\", \"randperm\"};\n\nbool RematableTensorStorage::is_evictable() const {\n  return compute_op_ != nullptr\n         && std::find(random_ops.begin(), random_ops.end(), compute_op_type_name())\n                == random_ops.end()\n         && !eviction_disabled_;\n}\n\nOpCallInstructionPolicy RematableTensorStorage::compute_op() const {\n  CHECK_NOTNULL(compute_op_);\n  return OpCallInstructionPolicy(*compute_op_);\n}\n\nstd::shared_ptr<DtrOpCallInstructionPolicy> RematableTensorStorage::dtr_compute_op() const {\n  return compute_op_;\n}\n\nvoid RematableTensorStorage::Pin() {\n  ++num_pinned_;\n  VLOG(3) << \"pin storage \" << id_ << \", num_pinned: \" << num_pinned_;\n}\n\nvoid RematableTensorStorage::Unpin() {\n  CHECK_GT(num_pinned_, 0);\n  --num_pinned_;\n  VLOG(3) << \"unpin storage \" << id_ << \", num_pinned: \" << num_pinned_;\n}\n\nvoid RematableTensorStorage::clear_compute_op() {\n  if (compute_op_ == nullptr) { return; }\n  VLOG(1) << \"clear_compute_op: \" << id_;\n  Singleton<remat::Env>::Get()->remove_compute_op(compute_op_.get());\n  compute_op_ = nullptr;\n  compute_time_ = -1;\n}\n\nvoid RematableTensorStorage::set_compute_op(\n    const std::shared_ptr<DtrOpCallInstructionPolicy>& compute_op, double compute_time) {\n  CHECK_ISNULL(compute_op_);\n  compute_op_ = compute_op;\n  VLOG(1) << \"set_compute_op: \" << id_ << \", compute op: \" << compute_op.get();\n  Singleton<remat::Env>::Get()->ops.push_back(CHECK_NOTNULL(compute_op_.get()));\n  compute_time_ = compute_time;\n}\n\nstd::string RematableTensorStorage::compute_op_type_name() const {\n  if (is_eviction_disabled()) { return \"eviction_disabled\"; }\n  if (compute_op_) { return compute_op_->opkernel().op_type_name(); }\n  return \"None\";\n}\n\nvoid RematableTensorStorage::Access() {\n  last_access_time_ = Singleton<remat::Env>::Get()->time_now();\n}\n\nMaybe<double> RematableTensorStorage::cost(size_t override_size) const {\n  CHECK_OR_RETURN(!is_eviction_disabled());\n  const double time_since_last_access =\n      Singleton<remat::Env>::Get()->time_now() - last_access_time_;\n  size_t size = 1;\n  if (EnvBool<ONEFLOW_REMAT_HEURISTIC_DTE>() || EnvBool<ONEFLOW_REMAT_HEURISTIC_DTR>()) {\n    size = override_size == 0 ? blob_bytes_ : override_size;\n  }\n  return (EnvBool<ONEFLOW_REMAT_NEIGHBOR>() ? 
approx_neighbor_cost() : compute_time_)\n         / time_since_last_access / static_cast<double>(size);\n}\n\ndouble RematableTensorStorage::approx_neighbor_cost() const {\n  const auto cal_cost = [](const auto& eager_blob_objects) {\n    double all_cost = 0;\n    for (int i = 0; i < eager_blob_objects.size(); ++i) {\n      const auto& tmp = eager_blob_objects[i];\n      if (auto storage = std::dynamic_pointer_cast<RematableTensorStorage>(tmp->tensor_storage());\n          !storage->is_in_memory()) {\n        double tmp_cost = remat::DisjointSet::find_father(storage->node)->compute_time();\n        if (tmp_cost < storage->compute_time()) { tmp_cost = storage->compute_time(); }\n        all_cost += tmp_cost;\n      }\n    }\n    return all_cost;\n  };\n  const auto compute_op = this->compute_op();\n  return cal_cost(compute_op.inputs()) + cal_cost(compute_op.outputs()) + compute_time_;\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/eager/tensor_storage.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EAGER_TENSOR_STORAGE_H_\n#define ONEFLOW_CORE_EAGER_TENSOR_STORAGE_H_\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/memory/memory_allocator.h\"\n#include \"oneflow/core/framework/stream.h\"\n\nnamespace oneflow {\nnamespace remat {\nclass DisjNode;\n}\n\nnamespace vm {\n\nclass OpCallInstructionPolicy;\nclass DtrOpCallInstructionPolicy;\n\nclass TensorStorage {\n public:\n  explicit TensorStorage(bool is_allocated_in_vm, Symbol<Device> device);\n  OF_DISALLOW_COPY_AND_MOVE(TensorStorage);\n\n  virtual ~TensorStorage();\n\n  bool is_allocated_in_vm() const { return is_allocated_in_vm_; }\n\n  size_t blob_bytes() const { return blob_bytes_; }\n\n  char* blob_dptr() { return blob_dptr_.get(); }\n\n  MemoryAllocator* non_pod_allocator() { return non_pod_allocator_.get(); }\n\n  void set_blob_dptr(std::unique_ptr<char, std::function<void(char*)>>&& blob_dptr, size_t bytes) {\n    blob_dptr_ = std::move(blob_dptr);\n    blob_bytes_ = bytes;\n    is_initialized_ = true;\n  }\n\n  const Optional<Symbol<::oneflow::Stream>>& producer_stream() const { return producer_stream_; }\n  Maybe<void> init_producer_stream(Symbol<::oneflow::Stream> producer_stream);\n\n  const Optional<Symbol<::oneflow::Stream>>& last_used_stream() const { return last_used_stream_; }\n  void set_last_used_stream(Symbol<::oneflow::Stream> last_used_stream) {\n    last_used_stream_ = last_used_stream;\n  }\n\n  void _Release();\n  virtual void Release();\n\n  void RegisterStorageDeleteHook(const std::function<void()>& hook) {\n    storage_delete_hooks_.emplace_back(hook);\n  }\n  Symbol<Device> device() const;\n\n protected:\n  std::unique_ptr<char, std::function<void(char*)>> blob_dptr_;\n  size_t blob_bytes_;\n  bool is_initialized_ = false;\n  Symbol<Device> device_;\n\n private:\n  std::unique_ptr<MemoryAllocator> non_pod_allocator_;\n  Optional<Symbol<::oneflow::Stream>> producer_stream_;\n  Optional<Symbol<::oneflow::Stream>> last_used_stream_;\n  std::vector<std::function<void()>> storage_delete_hooks_;\n  bool is_allocated_in_vm_;\n};\n\nclass RematableTensorStorage final : public TensorStorage {\n public:\n  explicit RematableTensorStorage(Symbol<Device> device);\n  OF_DISALLOW_COPY_AND_MOVE(RematableTensorStorage);\n  ~RematableTensorStorage() override;\n\n  void set_compute_op(const std::shared_ptr<DtrOpCallInstructionPolicy>& compute_op,\n                      double compute_time);\n  void clear_compute_op();\n  OpCallInstructionPolicy compute_op() const;\n  std::shared_ptr<DtrOpCallInstructionPolicy> dtr_compute_op() const;\n  void Release() override;\n  void Remat();\n  void Evict(bool eager_eviction);\n  void Pin();\n  void Unpin();\n  void Access();\n  bool is_in_memory() const { return blob_bytes_ == 0 || blob_dptr_ != nullptr; }\n  bool is_pinned() const { return num_pinned() > 0; }\n  int32_t num_pinned() const 
{ return num_pinned_; }\n  bool is_evictable() const;\n  void set_eviction_disabled(bool disabled) { eviction_disabled_ = disabled; }\n  bool is_eviction_disabled() const { return eviction_disabled_; }\n  int64_t id() const { return id_; }\n  Maybe<double> cost(size_t override_size) const;\n  double approx_neighbor_cost() const;\n  std::string compute_op_type_name() const;\n  bool is_initialized() const { return is_initialized_; }\n  void set_initialized() { is_initialized_ = true; }\n  bool is_needed_by_backward() const { return is_needed_by_backward_; }\n  void set_needed_by_backward() { is_needed_by_backward_ = true; }\n  double compute_time() const { return compute_time_; }\n  std::shared_ptr<remat::DisjNode> node;\n\n private:\n  int64_t id_{};\n  size_t num_pinned_{};\n  bool eviction_disabled_ = false;\n  double last_access_time_{};\n  double compute_time_{};\n  std::shared_ptr<DtrOpCallInstructionPolicy> compute_op_;\n  bool is_needed_by_backward_ = false;\n\n  void LogEviction(bool eager_eviction) const;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EAGER_TENSOR_STORAGE_H_\n"
  },
  {
    "path": "oneflow/core/embedding/cache.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/embedding/cache.h\"\n#include \"oneflow/core/embedding/full_cache.h\"\n#include \"oneflow/core/embedding/lru_cache.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\nstd::unique_ptr<Cache> NewCache(const CacheOptions& options) {\n#ifdef WITH_CUDA\n  CHECK_GT(options.key_size, 0);\n  CHECK_GT(options.value_size, 0);\n  CHECK_GT(options.capacity, 0);\n  if (options.policy == CacheOptions::Policy::kLRU) {\n    return NewLruCache(options);\n  } else if (options.policy == CacheOptions::Policy::kFull) {\n    return NewFullCache(options);\n  } else {\n    UNIMPLEMENTED();\n    return nullptr;\n  }\n#else\n  UNIMPLEMENTED();\n  return nullptr;\n#endif  // WITH_CUDA\n}\n\n}  // namespace embedding\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/embedding/cache.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EMBEDDING_CACHE_H_\n#define ONEFLOW_CORE_EMBEDDING_CACHE_H_\n\n#include \"oneflow/core/embedding/kv_iterator.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/common/data_type.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\nstruct CacheOptions {\n  enum class Policy {\n    kLRU,\n    kFull,\n  };\n  enum class MemoryKind {\n    kDevice,\n    kHost,\n  };\n  Policy policy = Policy::kLRU;\n  MemoryKind value_memory_kind = MemoryKind::kDevice;\n  uint64_t capacity{};\n  uint32_t key_size{};\n  uint32_t value_size{};\n  DataType value_type{};\n  float load_factor = 0.75;\n};\n\nclass Cache {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Cache);\n  Cache() = default;\n  virtual ~Cache() = default;\n\n  virtual uint32_t KeySize() const = 0;\n  virtual uint32_t ValueSize() const = 0;\n  virtual DataType ValueType() const = 0;\n  virtual uint32_t MaxQueryLength() const = 0;\n  virtual void ReserveQueryLength(uint32_t query_length) = 0;\n  virtual uint64_t Capacity() const = 0;\n  virtual uint64_t DumpCapacity() const { return Capacity(); }\n  virtual CacheOptions::Policy Policy() const = 0;\n  virtual void Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing,\n                    void* missing_keys, uint32_t* missing_indices) = 0;\n  virtual void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values,\n                   uint32_t* n_missing, void* missing_keys, uint32_t* missing_indices) = 0;\n  virtual void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values,\n                   uint8_t* mask) {\n    UNIMPLEMENTED();\n  }\n  virtual void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,\n                   uint32_t* n_evicted, void* evicted_keys, void* evicted_values) = 0;\n  virtual void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys,\n                                  const void* values, const void* update, const float* lr,\n                                  float scale, uint32_t* n_evicted, void* evicted_keys,\n                                  void* evicted_values) {\n    UNIMPLEMENTED();\n  }\n  virtual void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,\n                    uint32_t* n_dumped, void* keys, void* values) = 0;\n\n  virtual void ClearDirtyFlags() = 0;\n\n  virtual void Clear() = 0;\n};\n\nstd::unique_ptr<Cache> NewCache(const CacheOptions& options);\n\n}  // namespace embedding\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EMBEDDING_CACHE_H_\n"
  },
  {
    "path": "oneflow/core/embedding/cache_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/embedding/cache.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\nnamespace {\n\n#ifdef WITH_CUDA\n\nbool HasCudaDevice() {\n  int device_count = 0;\n  if (cudaGetDeviceCount(&device_count) != cudaSuccess) { return false; }\n  if (device_count <= 0) { return false; }\n  return true;\n}\n\nvoid TestCache(Cache* cache, uint32_t line_size) {\n  std::unique_ptr<ep::DeviceManagerRegistry> device_manager_registry(\n      new ep::DeviceManagerRegistry());\n  auto device = device_manager_registry->GetDevice(DeviceType::kCUDA, 0);\n  ep::Stream* stream = device->CreateStream();\n\n  std::unordered_set<int64_t> in_cache;\n  const size_t n_iter = 32;\n  const uint32_t n_keys = 1024;\n  int64_t* d_keys;\n  int64_t* keys;\n  uint32_t* d_n_missing;\n  uint32_t* n_missing;\n  int64_t* d_missing_keys;\n  int64_t* missing_keys;\n  uint32_t* d_missing_indices;\n  uint32_t* missing_indices;\n  float* d_values;\n  float* values;\n  float* d_evicted_values;\n  float* evicted_values;\n  uint32_t* d_n_evicted;\n  uint32_t* n_evicted;\n  int64_t* d_evicted_keys;\n  int64_t* evicted_keys;\n  uint8_t* mask;\n  const size_t keys_size = n_keys * sizeof(int64_t);\n  OF_CUDA_CHECK(cudaMalloc(&d_keys, keys_size));\n  OF_CUDA_CHECK(cudaMallocHost(&keys, keys_size));\n  OF_CUDA_CHECK(cudaMalloc(&d_n_missing, sizeof(uint32_t)));\n  OF_CUDA_CHECK(cudaMallocHost(&n_missing, sizeof(uint32_t)));\n  OF_CUDA_CHECK(cudaMalloc(&d_missing_keys, keys_size));\n  OF_CUDA_CHECK(cudaMallocHost(&missing_keys, keys_size));\n  const size_t indices_size = n_keys * sizeof(uint32_t);\n  OF_CUDA_CHECK(cudaMalloc(&d_missing_indices, indices_size));\n  OF_CUDA_CHECK(cudaMallocHost(&missing_indices, indices_size));\n  const size_t values_size = n_keys * line_size * sizeof(float);\n  OF_CUDA_CHECK(cudaMalloc(&d_values, values_size));\n  OF_CUDA_CHECK(cudaMallocHost(&values, values_size));\n  OF_CUDA_CHECK(cudaMalloc(&d_evicted_values, values_size));\n  OF_CUDA_CHECK(cudaMallocHost(&evicted_values, values_size));\n  OF_CUDA_CHECK(cudaMalloc(&d_n_evicted, sizeof(uint32_t)));\n  OF_CUDA_CHECK(cudaMallocHost(&n_evicted, sizeof(uint32_t)));\n  OF_CUDA_CHECK(cudaMalloc(&d_evicted_keys, keys_size));\n  OF_CUDA_CHECK(cudaMallocHost(&evicted_keys, keys_size));\n  OF_CUDA_CHECK(cudaMalloc(&mask, n_keys));\n  std::vector<int64_t> random_keys(n_keys * 32);\n  std::iota(random_keys.begin(), random_keys.end(), 1);\n  std::random_device rd;\n  std::mt19937 g(rd());\n  for (size_t iter = 0; iter < n_iter; ++iter) {\n    std::shuffle(random_keys.begin(), random_keys.end(), g);\n    std::copy(random_keys.begin(), random_keys.begin() + n_keys, keys);\n    uint32_t expect_n_missing = 0;\n    std::unordered_set<int64_t> expect_missing_keys_set;\n    std::unordered_set<uint32_t> 
expect_missing_indices_set;\n    std::unordered_set<int64_t> keys_set;\n    for (size_t i = 0; i < n_keys; ++i) {\n      keys_set.emplace(keys[i]);\n      if (in_cache.count(keys[i]) == 0) {\n        expect_missing_keys_set.emplace(keys[i]);\n        expect_missing_indices_set.emplace(i);\n        expect_n_missing += 1;\n      }\n    }\n    // test\n    OF_CUDA_CHECK(cudaMemcpy(d_keys, keys, keys_size, cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaDeviceSynchronize());\n    cache->Test(stream, n_keys, d_keys, d_n_missing, d_missing_keys, d_missing_indices);\n    OF_CUDA_CHECK(cudaDeviceSynchronize());\n    OF_CUDA_CHECK(cudaMemcpy(n_missing, d_n_missing, sizeof(uint32_t), cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaMemcpy(missing_keys, d_missing_keys, keys_size, cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaMemcpy(missing_indices, d_missing_indices, indices_size, cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaDeviceSynchronize());\n    ASSERT_EQ(*n_missing, expect_n_missing);\n    std::unordered_set<int64_t> test_missing_keys_set;\n    std::unordered_set<uint32_t> test_missing_indices_set;\n    for (size_t i = 0; i < *n_missing; ++i) {\n      test_missing_keys_set.emplace(missing_keys[i]);\n      test_missing_indices_set.emplace(missing_indices[i]);\n      ASSERT_EQ(keys[missing_indices[i]], missing_keys[i]);\n    }\n    ASSERT_EQ(test_missing_keys_set, expect_missing_keys_set);\n    ASSERT_EQ(test_missing_indices_set, expect_missing_indices_set);\n\n    // get\n    OF_CUDA_CHECK(cudaDeviceSynchronize());\n    if (cache->Policy() == CacheOptions::Policy::kFull) {\n      cache->Get(stream, n_keys, d_keys, d_values, mask);\n    }\n    cache->Get(stream, n_keys, d_keys, d_values, d_n_missing, d_missing_keys, d_missing_indices);\n    OF_CUDA_CHECK(cudaDeviceSynchronize());\n    OF_CUDA_CHECK(cudaMemcpy(n_missing, d_n_missing, sizeof(uint32_t), cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaMemcpy(missing_keys, d_missing_keys, keys_size, cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaMemcpy(missing_indices, d_missing_indices, indices_size, cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaMemcpy(values, d_values, values_size, cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaDeviceSynchronize());\n    ASSERT_EQ(*n_missing, expect_n_missing);\n    std::unordered_set<int64_t> get_missing_keys_set;\n    std::unordered_set<uint32_t> get_missing_indices_set;\n    for (size_t i = 0; i < *n_missing; ++i) {\n      get_missing_keys_set.emplace(missing_keys[i]);\n      get_missing_indices_set.emplace(missing_indices[i]);\n      ASSERT_EQ(keys[missing_indices[i]], missing_keys[i]);\n    }\n    ASSERT_EQ(get_missing_keys_set, expect_missing_keys_set);\n    ASSERT_EQ(get_missing_indices_set, expect_missing_indices_set);\n    for (size_t i = 0; i < n_keys; ++i) {\n      if (get_missing_keys_set.count(keys[i]) == 0) {\n        for (size_t j = 0; j < line_size; ++j) {\n          ASSERT_EQ(values[i * line_size + j], static_cast<float>(keys[i] * line_size + j))\n              << \"iter \" << iter << \" i \" << i << \" j \" << j;\n        }\n      }\n    }\n\n    // put\n    for (size_t i = 0; i < n_keys; ++i) {\n      for (size_t j = 0; j < line_size; ++j) {\n        values[i * line_size + j] = static_cast<float>(keys[i] * line_size + j);\n      }\n    }\n    OF_CUDA_CHECK(cudaMemcpy(d_values, values, values_size, cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaDeviceSynchronize());\n    cache->Put(stream, n_keys, d_keys, d_values, d_n_evicted, d_evicted_keys, d_evicted_values);\n    OF_CUDA_CHECK(cudaDeviceSynchronize());\n    
OF_CUDA_CHECK(cudaMemcpy(n_evicted, d_n_evicted, sizeof(uint32_t), cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaMemcpy(evicted_keys, d_evicted_keys, keys_size, cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaMemcpy(evicted_values, d_evicted_values, values_size, cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaDeviceSynchronize());\n    for (size_t i = 0; i < *n_evicted; ++i) {\n      ASSERT_TRUE(in_cache.count(evicted_keys[i]) > 0 || keys_set.count(evicted_keys[i]) > 0);\n      for (size_t j = 0; j < line_size; ++j) {\n        ASSERT_EQ(evicted_values[i * line_size + j],\n                  static_cast<float>(evicted_keys[i] * line_size + j));\n      }\n    }\n    for (size_t i = 0; i < n_keys; ++i) { in_cache.emplace(keys[i]); }\n    for (size_t i = 0; i < *n_evicted; ++i) { in_cache.erase(evicted_keys[i]); }\n  }\n  const uint64_t dump_capacity = cache->DumpCapacity();\n  for (size_t start_key_index = 0; start_key_index < dump_capacity; start_key_index += n_keys) {\n    cache->Dump(stream, start_key_index, std::min(start_key_index + n_keys, dump_capacity),\n                d_n_evicted, d_evicted_keys, d_evicted_values);\n    OF_CUDA_CHECK(cudaDeviceSynchronize());\n    OF_CUDA_CHECK(cudaMemcpy(n_evicted, d_n_evicted, sizeof(uint32_t), cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaMemcpy(evicted_keys, d_evicted_keys, keys_size, cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaMemcpy(evicted_values, d_evicted_values, values_size, cudaMemcpyDefault));\n    for (size_t i = 0; i < *n_evicted; ++i) {\n      ASSERT_TRUE(in_cache.count(evicted_keys[i]) > 0);\n      in_cache.erase(evicted_keys[i]);\n      for (size_t j = 0; j < line_size; ++j) {\n        ASSERT_EQ(evicted_values[i * line_size + j],\n                  static_cast<float>(evicted_keys[i] * line_size + j));\n      }\n    }\n  }\n  CHECK_EQ(in_cache.size(), 0);\n  OF_CUDA_CHECK(cudaFree(d_keys));\n  OF_CUDA_CHECK(cudaFreeHost(keys));\n  OF_CUDA_CHECK(cudaFree(d_n_missing));\n  OF_CUDA_CHECK(cudaFreeHost(n_missing));\n  OF_CUDA_CHECK(cudaFree(d_missing_keys));\n  OF_CUDA_CHECK(cudaFreeHost(missing_keys));\n  OF_CUDA_CHECK(cudaFree(d_missing_indices));\n  OF_CUDA_CHECK(cudaFreeHost(missing_indices));\n  OF_CUDA_CHECK(cudaFree(d_values));\n  OF_CUDA_CHECK(cudaFreeHost(values));\n  OF_CUDA_CHECK(cudaFree(d_evicted_values));\n  OF_CUDA_CHECK(cudaFreeHost(evicted_values));\n  OF_CUDA_CHECK(cudaFree(d_n_evicted));\n  OF_CUDA_CHECK(cudaFreeHost(n_evicted));\n  OF_CUDA_CHECK(cudaFree(d_evicted_keys));\n  OF_CUDA_CHECK(cudaFreeHost(evicted_keys));\n  OF_CUDA_CHECK(cudaFree(mask));\n  device->DestroyStream(stream);\n}\n\nTEST(Cache, FullCache) {\n  if (!HasCudaDevice()) { return; }\n\n  CacheOptions options{};\n  options.policy = CacheOptions::Policy::kFull;\n  const uint32_t line_size = 128;\n  options.value_size = 512;\n  options.capacity = 65536;\n  options.key_size = 8;\n  options.value_memory_kind = CacheOptions::MemoryKind::kDevice;\n  std::unique_ptr<Cache> cache(NewCache(options));\n  cache->ReserveQueryLength(65536);\n  TestCache(cache.get(), line_size);\n}\n\nTEST(Cache, LruCache) {\n  if (!HasCudaDevice()) { return; }\n\n  CacheOptions options{};\n  options.policy = CacheOptions::Policy::kLRU;\n  const uint32_t line_size = 128;\n  options.value_size = 512;\n  options.capacity = 65536;\n  options.key_size = 8;\n  options.value_memory_kind = CacheOptions::MemoryKind::kDevice;\n\n  std::unique_ptr<Cache> cache(NewCache(options));\n  cache->ReserveQueryLength(65536);\n  TestCache(cache.get(), line_size);\n}\n\n#endif  // WITH_CUDA\n\n}  // namespace\n\n}  
// namespace embedding\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/embedding/cached_key_value_store.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/embedding/cached_key_value_store.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\nnamespace {\ntemplate<typename Key, typename Elem>\n__global__ void PostStoreGetKernel(uint32_t num_cache_missing, uint32_t num_store_missing,\n                                   uint32_t num_elems_per_value,\n                                   const uint32_t* cache_missing_indices,\n                                   const uint32_t* store_missing_indices, const Elem* store_values,\n                                   Elem* values, uint32_t* missing_indices) {\n  const uint32_t num_cache_missing_elem = num_cache_missing * num_elems_per_value;\n  CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_cache_missing_elem) {\n    const uint32_t value_index = i / num_elems_per_value;\n    const uint32_t elem_index = i - value_index * num_elems_per_value;\n    values[cache_missing_indices[value_index] * num_elems_per_value + elem_index] = store_values[i];\n  }\n  CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_store_missing) {\n    missing_indices[i] = cache_missing_indices[store_missing_indices[i]];\n  }\n}\n\ntemplate<typename Key, typename Elem>\nclass CacheKeyValueStoreImpl : public KeyValueStore {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CacheKeyValueStoreImpl);\n  CacheKeyValueStoreImpl(std::unique_ptr<KeyValueStore>&& store, std::unique_ptr<Cache>&& cache)\n      : store_(std::move(store)), cache_(std::move(cache)), synced_(true), max_query_length_(0) {\n    OF_CUDA_CHECK(cudaGetDevice(&device_index_));\n    CHECK_EQ(store_->KeySize(), cache_->KeySize());\n    CHECK_EQ(store_->ValueSize(), cache_->ValueSize());\n    OF_CUDA_CHECK(cudaMalloc(&num_buffer_, sizeof(uint32_t)));\n    OF_CUDA_CHECK(cudaMallocHost(&host_num_buffer_, sizeof(uint32_t)));\n    num_elems_per_value_ = store_->ValueSize() / sizeof(Elem);\n  }\n  ~CacheKeyValueStoreImpl() {\n    CudaCurrentDeviceGuard guard(device_index_);\n    OF_CUDA_CHECK(cudaFree(num_buffer_));\n    OF_CUDA_CHECK(cudaFreeHost(host_num_buffer_));\n    if (max_query_length_ != 0) {\n      OF_CUDA_CHECK(cudaFree(keys_buffer_));\n      OF_CUDA_CHECK(cudaFree(values_buffer_));\n      OF_CUDA_CHECK(cudaFree(indices_buffer0_));\n      OF_CUDA_CHECK(cudaFree(indices_buffer1_));\n    }\n    cache_.reset();\n    store_.reset();\n  }\n\n  uint32_t KeySize() const override { return store_->KeySize(); }\n  uint32_t ValueSize() const override { return store_->ValueSize(); }\n  uint32_t MaxQueryLength() const override { return max_query_length_; }\n\n  void ReserveQueryLength(uint32_t query_length) override {\n    CudaCurrentDeviceGuard guard(device_index_);\n    if (query_length <= max_query_length_) { return; }\n    if (query_length > cache_->MaxQueryLength()) { cache_->ReserveQueryLength(query_length); }\n    if (query_length > store_->MaxQueryLength()) { 
store_->ReserveQueryLength(query_length); }\n    if (max_query_length_ != 0) {\n      OF_CUDA_CHECK(cudaFree(keys_buffer_));\n      OF_CUDA_CHECK(cudaFree(values_buffer_));\n      OF_CUDA_CHECK(cudaFree(indices_buffer0_));\n      OF_CUDA_CHECK(cudaFree(indices_buffer1_));\n    }\n    OF_CUDA_CHECK(cudaMalloc(&keys_buffer_, query_length * store_->KeySize()));\n    OF_CUDA_CHECK(cudaMalloc(&values_buffer_, query_length * store_->ValueSize()));\n    OF_CUDA_CHECK(cudaMalloc(&indices_buffer0_, query_length * sizeof(uint32_t)));\n    OF_CUDA_CHECK(cudaMalloc(&indices_buffer1_, query_length * sizeof(uint32_t)));\n    max_query_length_ = query_length;\n  }\n\n  void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,\n           uint32_t* n_missing, uint32_t* missing_indices) override;\n  void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,\n           uint8_t* mask) override;\n  void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;\n  void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,\n                          const void* update, const float* lr, float scale) override;\n  bool IsFusionSupported() override {\n    return cache_->Policy() == CacheOptions::Policy::kFull\n           && cache_->ValueType() == DataType::kFloat;\n  }\n  bool SnapshotExists(const std::string& name) override;\n  void LoadSnapshot(const std::string& name) override;\n  void SaveSnapshot(const std::string& name) override;\n  void LoadSnapshot(const std::string& name,\n                    const std::function<void(KVIterator* iter)>& Hook) override;\n\n private:\n  void SyncCacheToStore();\n\n  std::unique_ptr<KeyValueStore> store_;\n  std::unique_ptr<Cache> cache_;\n\n  uint32_t* num_buffer_{};\n  uint32_t* host_num_buffer_{};\n  Key* keys_buffer_{};\n  Elem* values_buffer_{};\n  uint32_t* indices_buffer0_{};\n  uint32_t* indices_buffer1_{};\n  int device_index_{};\n  uint32_t max_query_length_;\n  uint32_t num_elems_per_value_{};\n  std::recursive_mutex mutex_;\n  bool synced_;\n};\n\ntemplate<typename Key, typename Elem>\nvoid CacheKeyValueStoreImpl<Key, Elem>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,\n                                            void* values, uint32_t* n_missing,\n                                            uint32_t* missing_indices) {\n  std::lock_guard<std::recursive_mutex> lock(mutex_);\n  auto cuda_stream = stream->As<ep::CudaStream>();\n  if (cache_->Policy() == CacheOptions::Policy::kFull) {\n    cache_->Get(stream, num_keys, keys, values, n_missing, keys_buffer_, missing_indices);\n    return;\n  } else {\n    cache_->Get(stream, num_keys, keys, values, num_buffer_, keys_buffer_, indices_buffer0_);\n  }\n  OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), cudaMemcpyDefault,\n                                cuda_stream->cuda_stream()));\n  CHECK_JUST(cuda_stream->Sync());\n  const uint32_t num_cache_missing = *host_num_buffer_;\n  if (num_cache_missing == 0) {\n    OF_CUDA_CHECK(cudaMemsetAsync(n_missing, 0, sizeof(uint32_t),\n                                  stream->As<ep::CudaStream>()->cuda_stream()));\n    return;\n  }\n  store_->Get(stream, num_cache_missing, keys_buffer_, values_buffer_, n_missing, indices_buffer1_);\n  OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, n_missing, sizeof(uint32_t), cudaMemcpyDefault,\n                                cuda_stream->cuda_stream()));\n  
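// sync so num_store_missing is valid on the host before PostStoreGetKernel is launched below\n  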
CHECK_JUST(cuda_stream->Sync());\n  const uint32_t num_store_missing = *host_num_buffer_;\n  RUN_CUDA_KERNEL((PostStoreGetKernel<Key, Elem>), stream, num_cache_missing * num_elems_per_value_,\n                  num_cache_missing, num_store_missing, num_elems_per_value_, indices_buffer0_,\n                  indices_buffer1_, values_buffer_, static_cast<Elem*>(values), missing_indices);\n}\n\ntemplate<typename Key, typename Elem>\nvoid CacheKeyValueStoreImpl<Key, Elem>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,\n                                            void* values, uint8_t* mask) {\n  std::lock_guard<std::recursive_mutex> lock(mutex_);\n  if (cache_->Policy() == CacheOptions::Policy::kFull) {\n    cache_->Get(stream, num_keys, keys, values, mask);\n    return;\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename Key, typename Elem>\nvoid CacheKeyValueStoreImpl<Key, Elem>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,\n                                            const void* values) {\n  std::lock_guard<std::recursive_mutex> lock(mutex_);\n  synced_ = false;\n  auto cuda_stream = stream->As<ep::CudaStream>();\n  if (cache_->Policy() != CacheOptions::Policy::kFull) {\n    OF_CUDA_CHECK(cudaMemsetAsync(num_buffer_, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));\n  }\n  cache_->Put(stream, num_keys, keys, values, num_buffer_, keys_buffer_, values_buffer_);\n  if (cache_->Policy() == CacheOptions::Policy::kFull) { return; }\n  OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), cudaMemcpyDefault,\n                                cuda_stream->cuda_stream()));\n  CHECK_JUST(cuda_stream->Sync());\n  store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_);\n}\n\ntemplate<typename Key, typename Elem>\nvoid CacheKeyValueStoreImpl<Key, Elem>::FusedHalfUpdatePut(ep::Stream* stream, uint32_t num_keys,\n                                                           const void* keys, const void* values,\n                                                           const void* update, const float* lr,\n                                                           float scale) {\n  std::lock_guard<std::recursive_mutex> lock(mutex_);\n  if (cache_->Policy() != CacheOptions::Policy::kFull) {\n    OF_CUDA_CHECK(cudaMemsetAsync(num_buffer_, 0, sizeof(uint32_t),\n                                  stream->As<ep::CudaStream>()->cuda_stream()));\n  }\n  if (cache_->Policy() != CacheOptions::Policy::kFull || cache_->ValueType() != DataType::kFloat) {\n    UNIMPLEMENTED();\n  }\n  synced_ = false;\n  cache_->FusedHalfUpdatePut(stream, num_keys, keys, values, update, lr, scale, num_buffer_,\n                             keys_buffer_, values_buffer_);\n}\n\ntemplate<typename Key, typename Elem>\nbool CacheKeyValueStoreImpl<Key, Elem>::SnapshotExists(const std::string& name) {\n  return store_->SnapshotExists(name);\n}\n\ntemplate<typename Key, typename Elem>\nvoid CacheKeyValueStoreImpl<Key, Elem>::LoadSnapshot(const std::string& name) {\n  LoadSnapshot(name, nullptr);\n}\n\ntemplate<typename Key, typename Elem>\nvoid CacheKeyValueStoreImpl<Key, Elem>::LoadSnapshot(\n    const std::string& name, const std::function<void(KVIterator* iter)>& Hook) {\n  CudaCurrentDeviceGuard guard(device_index_);\n  std::lock_guard<std::recursive_mutex> lock(mutex_);\n  CHECK_GT(max_query_length_, 0);\n  cache_->Clear();\n  auto device =\n      Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, device_index_);\n  CHECK(device);\n  auto* 
stream = device->CreateStream();\n  store_->LoadSnapshot(name, [&](KVIterator* iter) {\n    if (cache_->Policy() == CacheOptions::Policy::kFull) {\n      auto* cuda_stream = stream->As<ep::CudaStream>();\n      while (true) {\n        iter->NextN(stream, max_query_length_, num_buffer_, keys_buffer_, values_buffer_);\n        OF_CUDA_CHECK(cudaDeviceSynchronize());\n        OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),\n                                      cudaMemcpyDefault, cuda_stream->cuda_stream()));\n        CHECK_JUST(stream->Sync());\n        if (*host_num_buffer_ == 0) { return; }\n        cache_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_, num_buffer_, nullptr,\n                    nullptr);\n      }\n    }\n    if (Hook) {\n      iter->Reset();\n      Hook(iter);\n    }\n  });\n  device->DestroyStream(stream);\n}\n\ntemplate<typename Key, typename Elem>\nvoid CacheKeyValueStoreImpl<Key, Elem>::SaveSnapshot(const std::string& name) {\n  CudaCurrentDeviceGuard guard(device_index_);\n  std::lock_guard<std::recursive_mutex> lock(mutex_);\n  SyncCacheToStore();\n  store_->SaveSnapshot(name);\n}\n\ntemplate<typename Key, typename Elem>\nvoid CacheKeyValueStoreImpl<Key, Elem>::SyncCacheToStore() {\n  if (synced_) { return; }\n  CudaCurrentDeviceGuard guard(device_index_);\n  auto device =\n      Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, device_index_);\n  CHECK(device);\n  auto* stream = device->CreateStream();\n  auto* cuda_stream = stream->As<ep::CudaStream>();\n  const uint64_t dump_capacity = cache_->DumpCapacity();\n  CHECK_GT(max_query_length_, 0);\n  for (uint64_t start_key_index = 0; start_key_index < dump_capacity;\n       start_key_index += max_query_length_) {\n    cache_->Dump(stream, start_key_index,\n                 std::min(start_key_index + max_query_length_, dump_capacity), num_buffer_,\n                 keys_buffer_, values_buffer_);\n    OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),\n                                  cudaMemcpyDefault, cuda_stream->cuda_stream()));\n    CHECK_JUST(stream->Sync());\n    if (*host_num_buffer_ == 0) { continue; }\n    store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_);\n    CHECK_JUST(stream->Sync());\n  }\n  cache_->ClearDirtyFlags();\n  device->DestroyStream(stream);\n  synced_ = true;\n}\n\ntemplate<typename Key>\nstd::unique_ptr<KeyValueStore> DispatchElemType(std::unique_ptr<KeyValueStore>&& store,\n                                                std::unique_ptr<Cache>&& cache) {\n  const uint32_t value_size = store->ValueSize();\n  if (value_size % sizeof(uint4) == 0) {\n    return std::unique_ptr<KeyValueStore>(\n        new CacheKeyValueStoreImpl<Key, uint4>(std::move(store), std::move(cache)));\n  } else if (value_size % sizeof(uint64_t) == 0) {\n    return std::unique_ptr<KeyValueStore>(\n        new CacheKeyValueStoreImpl<Key, uint64_t>(std::move(store), std::move(cache)));\n  } else if (value_size % sizeof(uint32_t) == 0) {\n    return std::unique_ptr<KeyValueStore>(\n        new CacheKeyValueStoreImpl<Key, uint32_t>(std::move(store), std::move(cache)));\n  } else if (value_size % sizeof(uint16_t) == 0) {\n    return std::unique_ptr<KeyValueStore>(\n        new CacheKeyValueStoreImpl<Key, uint16_t>(std::move(store), std::move(cache)));\n  } else {\n    return std::unique_ptr<KeyValueStore>(\n        new CacheKeyValueStoreImpl<Key, uint8_t>(std::move(store), std::move(cache)));\n  
}\n}\n\nstd::unique_ptr<KeyValueStore> DispatchKeyType(std::unique_ptr<KeyValueStore>&& store,\n                                               std::unique_ptr<Cache>&& cache) {\n  const uint32_t key_size = store->KeySize();\n  if (key_size == 4) {\n    return DispatchElemType<uint32_t>(std::move(store), std::move(cache));\n  } else if (key_size == 8) {\n    return DispatchElemType<uint64_t>(std::move(store), std::move(cache));\n  } else {\n    UNIMPLEMENTED();\n    return nullptr;\n  }\n}\n\n}  // namespace\n\nstd::unique_ptr<KeyValueStore> NewCachedKeyValueStore(std::unique_ptr<KeyValueStore>&& store,\n                                                      std::unique_ptr<Cache>&& cache) {\n  return DispatchKeyType(std::move(store), std::move(cache));\n}\n\n}  // namespace embedding\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/embedding/cached_key_value_store.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EMBEDDING_CACHED_KEY_VALUE_STORE_H_\n#define ONEFLOW_CORE_EMBEDDING_CACHED_KEY_VALUE_STORE_H_\n\n#include \"oneflow/core/embedding/key_value_store.h\"\n#include \"oneflow/core/embedding/cache.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\nstd::unique_ptr<KeyValueStore> NewCachedKeyValueStore(std::unique_ptr<KeyValueStore>&& store,\n                                                      std::unique_ptr<Cache>&& cache);\n\n}  // namespace embedding\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EMBEDDING_CACHED_KEY_VALUE_STORE_H_\n"
  },
  {
    "path": "oneflow/core/embedding/embedding_manager.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/embedding/embedding_manager.h\"\n#include \"oneflow/core/embedding/persistent_table_key_value_store.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/embedding/cached_key_value_store.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\n#ifdef WITH_CUDA\n\nconstexpr size_t kDefaultMaxQueryLength = 131072;\n\nconstexpr int64_t kRingBufferSize = 8;\n\nstruct IdStatistics {\n  IdStatistics() : final_num_unique(0), iter(-1) {}\n  uint32_t final_num_unique;\n  std::vector<uint32_t> num_unique_matrix;\n  int64_t iter;\n};\n\n#if CUDA_VERSION >= 11020\n\nclass DynamicTmpBufferAllocator final : public TmpBufferAllocator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DynamicTmpBufferAllocator);\n  DynamicTmpBufferAllocator(cudaStream_t stream, cudaMemPool_t pool)\n      : stream_(stream), mem_pool_(pool) {}\n  ~DynamicTmpBufferAllocator() override = default;\n\n  void Allocate(void** ptr, size_t size) override {\n    OF_CUDA_CHECK(cudaMallocFromPoolAsync(ptr, GetCudaAlignedSize(size), mem_pool_, stream_));\n  }\n  void Free(void* ptr) override { OF_CUDA_CHECK(cudaFreeAsync(ptr, stream_)); }\n\n private:\n  cudaStream_t stream_{};\n  cudaMemPool_t mem_pool_{};\n};\n\nclass DynamicAllocationEmbeddingState final : public EmbeddingState {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DynamicAllocationEmbeddingState);\n  DynamicAllocationEmbeddingState()\n      : lookup_values_(nullptr),\n        lookup_values_size_(0),\n        has_lookup_values_(false),\n        lookup_embeddings_(nullptr),\n        lookup_embeddings_size_(0),\n        has_lookup_embeddings_(false),\n        updated_values_(nullptr),\n        iter_(-1) {\n    OF_CUDA_CHECK(cudaGetDevice(&device_index_));\n    id_statistics_vec_.resize(kRingBufferSize);\n    cudaMemPoolProps poolProps = {};\n    poolProps.allocType = cudaMemAllocationTypePinned;\n    poolProps.handleTypes = cudaMemHandleTypePosixFileDescriptor;\n    poolProps.location.type = cudaMemLocationTypeDevice;\n    poolProps.location.id = device_index_;\n    cudaMemPoolCreate(&mem_pool_, &poolProps);\n    uint64_t threshold = UINT64_MAX;\n    cudaMemPoolSetAttribute(mem_pool_, cudaMemPoolAttrReleaseThreshold, &threshold);\n  }\n  ~DynamicAllocationEmbeddingState() {\n    CudaCurrentDeviceGuard guard(device_index_);\n    if (has_lookup_values_) { OF_CUDA_CHECK(cudaFree(lookup_values_)); }\n    if (has_lookup_embeddings_) { OF_CUDA_CHECK(cudaFree(lookup_embeddings_)); }\n    OF_CUDA_CHECK(cudaMemPoolDestroy(mem_pool_));\n  }\n\n  std::unique_ptr<TmpBufferAllocator> NewTmpBufferAllocator(\n      user_op::KernelComputeContext* ctx) override {\n    return std::make_unique<DynamicTmpBufferAllocator>(\n        ctx->stream()->As<ep::CudaStream>()->cuda_stream(), mem_pool_);\n  }\n\n  void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    iter_ = iter;\n    cudaStream_t cuda_stream = 
ctx->stream()->As<ep::CudaStream>()->cuda_stream();\n    user_op::Tensor* unique_values = ctx->Tensor4ArgNameAndIndex(\"unique_values\", 0);\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n    uint32_t num_unique = this->GetIdNumUnique(iter);\n    size_t lookup_values_size =\n        GetCudaAlignedSize(num_unique * line_size * GetSizeOfDataType(unique_values->data_type()));\n    if (!has_lookup_values_ || lookup_values_size_ < lookup_values_size) {\n      if (has_lookup_values_) { OF_CUDA_CHECK(cudaFreeAsync(lookup_values_, cuda_stream)); }\n      OF_CUDA_CHECK(\n          cudaMallocFromPoolAsync(&lookup_values_, lookup_values_size, mem_pool_, cuda_stream));\n      has_lookup_values_ = true;\n      lookup_values_size_ = lookup_values_size;\n    }\n    if (ctx->has_output(\"embeddings\", 0)) {\n      user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex(\"embeddings\", 0);\n      const size_t lookup_embeddings_size = GetCudaAlignedSize(\n          num_unique * embedding_size * GetSizeOfDataType(embeddings->data_type()));\n      if (!has_lookup_embeddings_ || lookup_embeddings_size_ < lookup_embeddings_size) {\n        if (has_lookup_embeddings_) {\n          OF_CUDA_CHECK(cudaFreeAsync(lookup_embeddings_, cuda_stream));\n        }\n        OF_CUDA_CHECK(cudaMallocFromPoolAsync(&lookup_embeddings_, lookup_embeddings_size,\n                                              mem_pool_, cuda_stream));\n        has_lookup_embeddings_ = true;\n        lookup_embeddings_size_ = lookup_embeddings_size;\n      }\n    } else {\n      lookup_embeddings_ = nullptr;\n    }\n  }\n\n  void* LookupUniqueValues(int64_t iter) override {\n    CHECK_EQ(iter_, iter);\n    CHECK(has_lookup_values_);\n    return lookup_values_;\n  }\n\n  void* LookupEmbeddings(int64_t iter) override {\n    CHECK_EQ(iter_, iter);\n    CHECK(has_lookup_embeddings_);\n    return lookup_embeddings_;\n  }\n\n  void OnEmbeddingLookupEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    // do nothing\n  }\n\n  void OnEmbeddingGatherStart(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    // do nothing\n  }\n\n  const void* EmbeddingGatherIn(int64_t iter) override {\n    if (has_lookup_embeddings_) {\n      return lookup_embeddings_;\n    } else {\n      CHECK(has_lookup_values_);\n      return lookup_values_;\n    }\n  }\n\n  void OnEmbeddingGatherEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    // do nothing\n  }\n\n  void OnEmbeddingShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    // do nothing\n  }\n\n  const void* EmbeddingShuffleCurRankEmbeddings(int64_t iter) override {\n    if (has_lookup_embeddings_) {\n      return lookup_embeddings_;\n    } else {\n      CHECK(has_lookup_values_);\n      return lookup_values_;\n    }\n  }\n\n  void OnEmbeddingShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    // do nothing\n  }\n\n  void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    const user_op::Tensor* updated_unique_embeddings =\n        ctx->Tensor4ArgNameAndIndex(\"updated_unique_embeddings\", 0);\n    const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n    uint32_t num_unique = this->GetIdNumUnique(iter);\n    size_t update_values_size = GetCudaAlignedSize(\n        num_unique * line_size * GetSizeOfDataType(updated_unique_embeddings->data_type()));\n    
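// stream-ordered allocation; OnEmbeddingPutEnd releases it with cudaFreeAsync\n    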
OF_CUDA_CHECK(cudaMallocFromPoolAsync(&updated_values_, update_values_size, mem_pool_,\n                                          ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n  }\n\n  const void* EmbeddingUpdateUniqueEmbeddings(int64_t iter) override {\n    CHECK_EQ(iter_, iter);\n    CHECK(has_lookup_values_);\n    return lookup_values_;\n  }\n\n  void* EmbeddingUpdateUpdatedUniqueEmbeddings(int64_t iter) override {\n    CHECK_EQ(iter_, iter);\n    return updated_values_;\n  }\n\n  void OnEmbeddingUpdateEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    // do nothing\n  }\n\n  void OnEmbeddingPutStart(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    // do nothing\n  }\n\n  const void* EmbeddingPutUniqueEmbeddings(int64_t iter) override {\n    CHECK_EQ(iter_, iter);\n    return updated_values_;\n  }\n\n  void OnEmbeddingPutEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    OF_CUDA_CHECK(\n        cudaFreeAsync(updated_values_, ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n  }\n\n  void OnEmbeddingFusedUpdatePutStart(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    // do nothing\n  }\n\n  const void* EmbeddingFusedUpdatePutUniqueEmbeddings(int64_t iter) override {\n    CHECK_EQ(iter_, iter);\n    CHECK(has_lookup_values_);\n    return lookup_values_;\n  }\n\n  void OnEmbeddingFusedUpdatePutEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    // do nothing\n  }\n\n  void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) override {\n    std::unique_lock<std::mutex> lock(mutex_);\n    int64_t index = iter % kRingBufferSize;\n    id_statistics_vec_.at(index).final_num_unique = final_num_unique;\n    id_statistics_vec_.at(index).iter = iter;\n  }\n\n  void SetIdNumUniqueMatrix(const std::vector<uint32_t>& num_unique_matrix, int64_t iter) override {\n    std::unique_lock<std::mutex> lock(mutex_);\n    int64_t index = iter % kRingBufferSize;\n    id_statistics_vec_.at(index).num_unique_matrix = num_unique_matrix;\n    id_statistics_vec_.at(index).iter = iter;\n  }\n\n  uint32_t GetIdNumUnique(int64_t iter) override {\n    std::unique_lock<std::mutex> lock(mutex_);\n    int64_t index = iter % kRingBufferSize;\n    const IdStatistics& statistics = id_statistics_vec_.at(index);\n    CHECK_EQ(statistics.iter, iter)\n        << \"saved iter: \" << statistics.iter << \" current iter: \" << iter;\n    return statistics.final_num_unique;\n  }\n\n  const std::vector<uint32_t>& GetIdNumUniqueMatrix(int64_t iter) override {\n    std::unique_lock<std::mutex> lock(mutex_);\n    int64_t index = iter % kRingBufferSize;\n    const IdStatistics& statistics = id_statistics_vec_.at(index);\n    CHECK_EQ(statistics.iter, iter)\n        << \"saved iter: \" << statistics.iter << \" current iter: \" << iter;\n    return statistics.num_unique_matrix;\n  }\n\n private:\n  void* lookup_values_;\n  size_t lookup_values_size_;\n  bool has_lookup_values_;\n  void* lookup_embeddings_;\n  size_t lookup_embeddings_size_;\n  bool has_lookup_embeddings_;\n  void* updated_values_;\n  int64_t iter_;\n  std::vector<IdStatistics> id_statistics_vec_;\n  int device_index_{};\n  cudaMemPool_t mem_pool_{};\n  std::mutex mutex_;\n};\n\n#endif\n\nclass StaticTmpBufferAllocator final : public TmpBufferAllocator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(StaticTmpBufferAllocator);\n  StaticTmpBufferAllocator(void* ptr, size_t size) : ptr_(ptr), offset_(0), size_(size) {}\n  ~StaticTmpBufferAllocator() override = default;\n\n  
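// Bump-pointer allocation: hand out GetCudaAlignedSize-aligned slices of the pre-sized buffer;\n  // Free is a no-op because the whole tmp buffer is reclaimed at once.\n  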
void Allocate(void** ptr, size_t size) override {\n    CHECK(ptr_ != nullptr);\n    CHECK_GE(offset_, 0);\n    size_t aligned_size = GetCudaAlignedSize(size);\n    CHECK_LE(offset_ + aligned_size, size_);\n    *ptr = reinterpret_cast<char*>(ptr_) + offset_;\n    offset_ += aligned_size;\n  }\n\n  void Free(void* ptr) override {\n    // do nothing\n  }\n\n private:\n  void* ptr_;\n  int64_t offset_;\n  size_t size_;\n};\n\nclass StaticAllocationEmbeddingState final : public EmbeddingState {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(StaticAllocationEmbeddingState);\n  StaticAllocationEmbeddingState()\n      : lookup_unique_values_(nullptr),\n        lookup_embeddings_(nullptr),\n        has_lookup_embeddings_(false),\n        embedding_shuffle_cur_rank_embeddings_(nullptr),\n        embedding_update_unique_embeddings_(nullptr),\n        embedding_update_updated_unique_embeddings_(nullptr),\n        embedding_put_unique_embeddings_(nullptr),\n        embedding_fused_update_put_unique_embeddings_(nullptr) {\n    id_statistics_vec_.resize(kRingBufferSize);\n  }\n  ~StaticAllocationEmbeddingState() override = default;\n\n  std::unique_ptr<TmpBufferAllocator> NewTmpBufferAllocator(\n      user_op::KernelComputeContext* ctx) override {\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    return std::make_unique<StaticTmpBufferAllocator>(tmp_buffer->mut_dptr(),\n                                                      tmp_buffer->shape_view().elem_cnt());\n  }\n\n  void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    user_op::Tensor* unique_values = ctx->Tensor4ArgNameAndIndex(\"unique_values\", 0);\n    lookup_unique_values_ = unique_values->mut_dptr();\n    if (ctx->has_output(\"embeddings\", 0)) {\n      user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex(\"embeddings\", 0);\n      has_lookup_embeddings_ = true;\n      lookup_embeddings_ = embeddings->mut_dptr();\n    }\n  }\n\n  void* LookupUniqueValues(int64_t iter) override { return lookup_unique_values_; }\n\n  void* LookupEmbeddings(int64_t iter) override {\n    CHECK(has_lookup_embeddings_);\n    return lookup_embeddings_;\n  }\n\n  void OnEmbeddingLookupEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    lookup_unique_values_ = nullptr;\n    lookup_embeddings_ = nullptr;\n    has_lookup_embeddings_ = false;\n  }\n\n  void OnEmbeddingGatherStart(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    embedding_gather_in_ = in->dptr();\n  }\n\n  const void* EmbeddingGatherIn(int64_t iter) override { return embedding_gather_in_; }\n\n  void OnEmbeddingGatherEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    embedding_gather_in_ = nullptr;\n  }\n\n  void OnEmbeddingShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    const user_op::Tensor* cur_rank_embeddings =\n        ctx->Tensor4ArgNameAndIndex(\"cur_rank_embeddings\", 0);\n    embedding_shuffle_cur_rank_embeddings_ = cur_rank_embeddings->dptr();\n  }\n\n  const void* EmbeddingShuffleCurRankEmbeddings(int64_t iter) override {\n    return embedding_shuffle_cur_rank_embeddings_;\n  }\n\n  void OnEmbeddingShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    embedding_shuffle_cur_rank_embeddings_ = nullptr;\n  }\n\n  void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    const user_op::Tensor* unique_embeddings = 
ctx->Tensor4ArgNameAndIndex(\"unique_embeddings\", 0);\n    user_op::Tensor* updated_unique_embeddings =\n        ctx->Tensor4ArgNameAndIndex(\"updated_unique_embeddings\", 0);\n    embedding_update_unique_embeddings_ = unique_embeddings->dptr();\n    embedding_update_updated_unique_embeddings_ = updated_unique_embeddings->mut_dptr();\n  }\n\n  const void* EmbeddingUpdateUniqueEmbeddings(int64_t iter) override {\n    return embedding_update_unique_embeddings_;\n  }\n\n  void* EmbeddingUpdateUpdatedUniqueEmbeddings(int64_t iter) override {\n    return embedding_update_updated_unique_embeddings_;\n  }\n\n  void OnEmbeddingUpdateEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    embedding_update_unique_embeddings_ = nullptr;\n    embedding_update_updated_unique_embeddings_ = nullptr;\n  }\n\n  void OnEmbeddingPutStart(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex(\"unique_embeddings\", 0);\n    embedding_put_unique_embeddings_ = unique_embeddings->dptr();\n  }\n\n  const void* EmbeddingPutUniqueEmbeddings(int64_t iter) override {\n    return embedding_put_unique_embeddings_;\n  }\n\n  void OnEmbeddingPutEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    embedding_put_unique_embeddings_ = nullptr;\n  }\n\n  void OnEmbeddingFusedUpdatePutStart(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex(\"unique_embeddings\", 0);\n    embedding_fused_update_put_unique_embeddings_ = unique_embeddings->dptr();\n  }\n\n  const void* EmbeddingFusedUpdatePutUniqueEmbeddings(int64_t iter) override {\n    return embedding_fused_update_put_unique_embeddings_;\n  }\n\n  void OnEmbeddingFusedUpdatePutEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {\n    embedding_fused_update_put_unique_embeddings_ = nullptr;\n  }\n\n  void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) override {\n    std::unique_lock<std::mutex> lock(mutex_);\n    int64_t index = iter % kRingBufferSize;\n    id_statistics_vec_.at(index).final_num_unique = final_num_unique;\n    id_statistics_vec_.at(index).iter = iter;\n  }\n\n  void SetIdNumUniqueMatrix(const std::vector<uint32_t>& num_unique_matrix, int64_t iter) override {\n    std::unique_lock<std::mutex> lock(mutex_);\n    int64_t index = iter % kRingBufferSize;\n    id_statistics_vec_.at(index).num_unique_matrix = num_unique_matrix;\n    id_statistics_vec_.at(index).iter = iter;\n  }\n\n  uint32_t GetIdNumUnique(int64_t iter) override {\n    std::unique_lock<std::mutex> lock(mutex_);\n    int64_t index = iter % kRingBufferSize;\n    const IdStatistics& statistics = id_statistics_vec_.at(index);\n    CHECK_EQ(statistics.iter, iter)\n        << \"saved iter: \" << statistics.iter << \" current iter: \" << iter;\n    return statistics.final_num_unique;\n  }\n\n  const std::vector<uint32_t>& GetIdNumUniqueMatrix(int64_t iter) override {\n    std::unique_lock<std::mutex> lock(mutex_);\n    int64_t index = iter % kRingBufferSize;\n    const IdStatistics& statistics = id_statistics_vec_.at(index);\n    CHECK_EQ(statistics.iter, iter)\n        << \"saved iter: \" << statistics.iter << \" current iter: \" << iter;\n    return statistics.num_unique_matrix;\n  }\n\n  void* lookup_unique_values_;\n  void* lookup_embeddings_;\n  bool has_lookup_embeddings_;\n  const void* embedding_gather_in_;\n  const void* embedding_shuffle_cur_rank_embeddings_;\n  const void* 
embedding_update_unique_embeddings_;\n  void* embedding_update_updated_unique_embeddings_;\n  const void* embedding_put_unique_embeddings_;\n  const void* embedding_fused_update_put_unique_embeddings_;\n  std::vector<IdStatistics> id_statistics_vec_;\n  std::mutex mutex_;\n};\n\nEmbeddingState* EmbeddingManager::GetEmbeddingState(const std::string& embedding_name,\n                                                    int64_t rank_id) {\n  std::pair<std::string, int64_t> map_key = std::make_pair(embedding_name, rank_id);\n  std::unique_lock<std::mutex> lock(mutex_);\n  auto it = embedding_state_map_.find(map_key);\n  // for the id shuffle test, there is no need to create the table in advance\n  if (it == embedding_state_map_.end()) {\n    LOG(INFO) << \"create embedding state: \" << embedding_name << \"-\" << rank_id;\n    if (UseDynamicMemoryAllocation()) {\n#if CUDA_VERSION >= 11020\n      it =\n          embedding_state_map_.emplace(map_key, std::make_unique<DynamicAllocationEmbeddingState>())\n              .first;\n#else\n      UNIMPLEMENTED();\n#endif\n    } else {\n      it = embedding_state_map_.emplace(map_key, std::make_unique<StaticAllocationEmbeddingState>())\n               .first;\n    }\n  }\n  return it->second.get();\n}\n\nKeyValueStore* EmbeddingManager::GetKeyValueStore(const std::string& embedding_name,\n                                                  int64_t rank_id) {\n  std::pair<std::string, int64_t> map_key = std::make_pair(embedding_name, rank_id);\n  std::unique_lock<std::mutex> lock(mutex_);\n  auto it = key_value_store_map_.find(map_key);\n  CHECK(it != key_value_store_map_.end())\n      << \"Cannot find embedding: \" << embedding_name << \"-\" << rank_id;\n  return it->second.get();\n}\n\nvoid EmbeddingManager::CreateKeyValueStore(const KeyValueStoreOptions& key_value_store_options,\n                                           int64_t local_rank_id, int64_t rank_id,\n                                           int64_t world_size) {\n  CudaCurrentDeviceGuard guard(local_rank_id);\n  const std::string& name = key_value_store_options.Name();\n  const uint32_t line_size = key_value_store_options.LineSize();\n  std::pair<std::string, int64_t> map_key = std::make_pair(name, rank_id);\n  std::unique_lock<std::mutex> lock(mutex_);\n\n  std::unique_ptr<KeyValueStore> store;\n  PersistentTableKeyValueStoreOptions options{};\n  const std::vector<std::string>& persistent_table_paths =\n      key_value_store_options.PersistentTablePaths();\n  CHECK_EQ(persistent_table_paths.size(), world_size);\n  options.table_options.path = persistent_table_paths.at(rank_id);\n  options.table_options.value_size = line_size * key_value_store_options.ValueTypeSize();\n  options.table_options.key_size = key_value_store_options.KeyTypeSize();\n  options.table_options.physical_block_size =\n      key_value_store_options.PersistentTablePhysicalBlockSize();\n  options.table_options.target_chunk_size_mb = 4 * 1024;\n  options.table_options.capacity_hint = key_value_store_options.PersistentTableCapacityHint();\n  store = NewPersistentTableKeyValueStore(options);\n  const std::vector<CacheOptions>& cache_options = key_value_store_options.GetCachesOptions();\n  // wrap the persistent store with caches from the innermost (last) level to the outermost (first)\n  for (int i = cache_options.size() - 1; i >= 0; --i) {\n    std::unique_ptr<Cache> cache = NewCache(cache_options.at(i));\n    store = NewCachedKeyValueStore(std::move(store), std::move(cache));\n  }\n  store->ReserveQueryLength(kDefaultMaxQueryLength);\n  CHECK(key_value_store_map_.emplace(map_key, std::move(store)).second)\n      << \"Can't create an embedding with 
the same name as an existing embedding, name: \" << name;\n\n  if (UseDynamicMemoryAllocation()) {\n#if CUDA_VERSION >= 11020\n    CHECK(embedding_state_map_.emplace(map_key, std::make_unique<DynamicAllocationEmbeddingState>())\n              .second)\n        << \"Can't create an embedding state with the same name as an existing embedding, name: \"\n        << name;\n#else\n    UNIMPLEMENTED();\n#endif\n  } else {\n    CHECK(embedding_state_map_.emplace(map_key, std::make_unique<StaticAllocationEmbeddingState>())\n              .second)\n        << \"Can't create an embedding state with the same name as an existing embedding, name: \"\n        << name;\n  }\n}\n\nvoid EmbeddingManager::SaveSnapshot(const std::string& embedding_name, int64_t local_rank_id,\n                                    int64_t rank_id, const std::string& snapshot_name) {\n  CudaCurrentDeviceGuard guard(local_rank_id);\n  std::pair<std::string, int64_t> map_key = std::make_pair(embedding_name, rank_id);\n  std::unique_lock<std::mutex> lock(mutex_);\n\n  auto it = key_value_store_map_.find(map_key);\n  CHECK(it != key_value_store_map_.end())\n      << \"Cannot find embedding: \" << embedding_name << \"-\" << rank_id;\n  it->second->SaveSnapshot(snapshot_name);\n}\n\nvoid EmbeddingManager::LoadSnapshot(const std::string& embedding_name, int64_t local_rank_id,\n                                    int64_t rank_id, const std::string& snapshot_name) {\n  CudaCurrentDeviceGuard guard(local_rank_id);\n  std::pair<std::string, int64_t> map_key = std::make_pair(embedding_name, rank_id);\n  std::unique_lock<std::mutex> lock(mutex_);\n  auto it = key_value_store_map_.find(map_key);\n  CHECK(it != key_value_store_map_.end())\n      << \"Cannot find embedding: \" << embedding_name << \"-\" << rank_id;\n  if (it->second->SnapshotExists(snapshot_name)) {\n    it->second->LoadSnapshot(snapshot_name);\n  } else {\n    LOG(ERROR) << \"The embedding \" << embedding_name << \"-\" << rank_id\n               << \" exists, but it has no snapshot named \" << snapshot_name << \".\";\n  }\n}\n\n#endif  // WITH_CUDA\n\n}  // namespace embedding\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/embedding/embedding_manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EMBEDDING_EMBEDDING_MANAGER_H_\n#define ONEFLOW_CORE_EMBEDDING_EMBEDDING_MANAGER_H_\n\n#include \"oneflow/core/device/cuda_util.h\"\n\n#include \"oneflow/core/embedding/key_value_store.h\"\n#include \"oneflow/core/embedding/key_value_store_options.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\ninline bool UseDynamicMemoryAllocation() {\n  static bool use_dynamic_memory_allocation =\n      ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION\", false);\n#if CUDA_VERSION >= 11020\n  return use_dynamic_memory_allocation;\n#else\n  if (use_dynamic_memory_allocation) {\n    LOG(WARNING)\n        << \"Dynamic memory allocation only support when cuda_version greater equal than 11.2. \";\n  }\n  return false;\n#endif\n}\n\ninline bool UseEmbeddingShuffleP2PKernel(DataType embedding_dtype, DataType idx_dtype) {\n  static bool use_embedding_shuffle_p2p_env =\n      ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_EMBEDDING_SHUFFLE_USE_P2P\", false);\n  static bool add_id_shuffle_copy_out_env =\n      ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_ADD_ID_SHUFFLE_COPY_OUT\", true);\n  static bool enable_quantized_comm =\n      ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\", false);\n  if (use_embedding_shuffle_p2p_env) {\n    if (embedding_dtype != DataType::kFloat16 || idx_dtype != DataType::kUInt32) {\n      // p2p kernel only registered kFloat16 and kUint32.\n      return false;\n    }\n    if (!add_id_shuffle_copy_out_env) {\n      // when not enable id shuffle copy out, the ptrs change every iter.\n      return false;\n    }\n    if (enable_quantized_comm) {\n      // p2p kernel not support quantize comm.\n      return false;\n    }\n    if (UseDynamicMemoryAllocation()) {\n      // p2p kernel not support dynamic memory allocation.\n      return false;\n    }\n  }\n#if CUDA_VERSION >= 11030\n  return use_embedding_shuffle_p2p_env;\n#else\n  if (use_embedding_shuffle_p2p_env) {\n    LOG(WARNING)\n        << \"embedding shuffle p2p kernel only support when cuda_version greater equal than 11.3. 
\";\n  }\n  return false;\n#endif\n}\n\ninline bool UseEmbeddingGradientShuffleP2PKernel(DataType embedding_dtype, DataType idx_dtype) {\n  static bool use_embedding_gradient_shuffle_p2p_env =\n      ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_EMBEDDING_GRADIENT_SHUFFLE_USE_P2P\", false);\n  static bool add_id_shuffle_copy_out_env =\n      ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_ADD_ID_SHUFFLE_COPY_OUT\", true);\n  static bool enable_quantized_comm =\n      ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\", false);\n  if (use_embedding_gradient_shuffle_p2p_env) {\n    if (embedding_dtype != DataType::kFloat16 || idx_dtype != DataType::kUInt32) {\n      // p2p kernel only registered kFloat16 and kUint32.\n      return false;\n    }\n    if (!add_id_shuffle_copy_out_env) {\n      // when not enable id shuffle copy out, the ptrs change every iter.\n      return false;\n    }\n    if (enable_quantized_comm) {\n      // p2p kernel not support quantize comm.\n      return false;\n    }\n    if (UseDynamicMemoryAllocation()) {\n      // p2p kernel not support dynamic memory allocation.\n      return false;\n    }\n  }\n#if CUDA_VERSION >= 11030\n  return use_embedding_gradient_shuffle_p2p_env;\n#else\n  if (use_embedding_gradient_shuffle_p2p_env) {\n    LOG(WARNING) << \"embedding gradient shuffle p2p kernel only support when cuda_version greater \"\n                    \"equal than 11.3. \";\n  }\n  return false;\n#endif\n}\n\n#ifdef WITH_CUDA\n\nclass TmpBufferAllocator {\n public:\n  TmpBufferAllocator() = default;\n  virtual ~TmpBufferAllocator() = default;\n\n  virtual void Allocate(void** ptr, size_t size) = 0;\n  virtual void Free(void* ptr) = 0;\n};\n\nclass EmbeddingState {\n public:\n  EmbeddingState() = default;\n  virtual ~EmbeddingState() = default;\n\n  virtual std::unique_ptr<TmpBufferAllocator> NewTmpBufferAllocator(\n      user_op::KernelComputeContext* ctx) = 0;\n\n  virtual void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;\n  virtual void* LookupUniqueValues(int64_t iter) = 0;\n  virtual void* LookupEmbeddings(int64_t iter) = 0;\n  virtual void OnEmbeddingLookupEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;\n\n  virtual void OnEmbeddingGatherStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;\n  virtual const void* EmbeddingGatherIn(int64_t iter) = 0;\n  virtual void OnEmbeddingGatherEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;\n\n  virtual void OnEmbeddingShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;\n  virtual const void* EmbeddingShuffleCurRankEmbeddings(int64_t iter) = 0;\n  virtual void OnEmbeddingShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;\n\n  virtual void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;\n  virtual const void* EmbeddingUpdateUniqueEmbeddings(int64_t iter) = 0;\n  virtual void* EmbeddingUpdateUpdatedUniqueEmbeddings(int64_t iter) = 0;\n  virtual void OnEmbeddingUpdateEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;\n\n  virtual void OnEmbeddingPutStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;\n  virtual const void* EmbeddingPutUniqueEmbeddings(int64_t iter) = 0;\n  virtual void OnEmbeddingPutEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;\n\n  virtual void OnEmbeddingFusedUpdatePutStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;\n  virtual const void* EmbeddingFusedUpdatePutUniqueEmbeddings(int64_t iter) = 0;\n  virtual void 
OnEmbeddingFusedUpdatePutEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;\n\n  virtual void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) = 0;\n  virtual void SetIdNumUniqueMatrix(const std::vector<uint32_t>& num_unique_matrix,\n                                    int64_t iter) = 0;\n  virtual uint32_t GetIdNumUnique(int64_t iter) = 0;\n  virtual const std::vector<uint32_t>& GetIdNumUniqueMatrix(int64_t iter) = 0;\n};\n\nclass EmbeddingManager final {\n public:\n  EmbeddingManager() = default;\n  ~EmbeddingManager() = default;\n\n  void SaveSnapshot(const std::string& embedding_name, int64_t local_rank_id, int64_t rank_id,\n                    const std::string& snapshot_name);\n  void LoadSnapshot(const std::string& embedding_name, int64_t local_rank_id, int64_t rank_id,\n                    const std::string& snapshot_name);\n\n  KeyValueStore* GetKeyValueStore(const std::string& embedding_name, int64_t rank_id);\n  EmbeddingState* GetEmbeddingState(const std::string& embedding_name, int64_t rank_id);\n  void CreateKeyValueStore(const KeyValueStoreOptions& options, int64_t local_rank_id,\n                           int64_t rank_id, int64_t world_size);\n\n private:\n  HashMap<std::pair<std::string, int64_t>, std::unique_ptr<KeyValueStore>> key_value_store_map_;\n  HashMap<std::pair<std::string, int64_t>, std::unique_ptr<EmbeddingState>> embedding_state_map_;\n  std::mutex mutex_;\n};\n\n#endif  // WITH_CUDA\n\n}  // namespace embedding\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EMBEDDING_EMBEDDING_MANAGER_H_\n"
  },
  {
    "path": "oneflow/core/embedding/full_cache.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/embedding/full_cache.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/embedding/hash_functions.cuh\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\nusing Key32 = unsigned int;\nusing Key64 = unsigned long long int;\nusing Key128 = ulonglong2;\n\nnamespace {\n\ntemplate<typename Key, typename Index, bool dump_dirty_only>\n__device__ bool TryGetOrInsert(Key* entry_key, volatile Index* entry_index, bool* entry_dirty_flag,\n                               Index* table_size, Key key, Index* out) {\n  Key key_hi = (key | 0x1);\n  Key key_lo = (key & 0x1);\n  Index index_plus_one = 0;\n  Key old_entry_key = cuda::atomic::CAS(entry_key, static_cast<Key>(0), key_hi);\n  while (index_plus_one == 0) {\n    if (old_entry_key == static_cast<Key>(0)) {\n      Index index = cuda::atomic::Add(table_size, static_cast<Index>(1));\n      index_plus_one = index + 1;\n      *entry_index = ((index_plus_one << 1U) | key_lo);\n      *out = index_plus_one;\n      if (dump_dirty_only) {\n        bool entry_flag_val = *entry_dirty_flag;\n        if (!entry_flag_val) { *entry_dirty_flag = true; }\n      }\n      return true;\n    } else if (old_entry_key == key_hi) {\n      const Index entry_index_val = *entry_index;\n      if (entry_index_val == 0) {\n        // do nothing\n      } else if ((entry_index_val & 0x1) == key_lo) {\n        *out = (entry_index_val >> 1U);\n        if (dump_dirty_only) {\n          bool entry_flag_val = *entry_dirty_flag;\n          if (!entry_flag_val) { *entry_dirty_flag = true; }\n        }\n        return true;\n      } else {\n        return false;\n      }\n    } else {\n      return false;\n    }\n  }\n  return false;\n}\n\ntemplate<typename Key, typename Index, bool dump_dirty_only>\n__device__ bool GetOrInsertOne(const size_t capacity, Key* table_keys, Index* table_indices,\n                               bool* table_dirty_flags, Index* table_size, Key key, size_t hash,\n                               Index* out) {\n  const size_t start_idx = hash % capacity;\n  for (size_t count = 0; count < capacity; ++count) {\n    const size_t idx = (start_idx + count) % capacity;\n    Key* entry_key = table_keys + idx;\n    Index* entry_index = table_indices + idx;\n    bool* entry_dirty_flag = dump_dirty_only ? 
table_dirty_flags + idx : nullptr;\n    if (TryGetOrInsert<Key, Index, dump_dirty_only>(entry_key, entry_index, entry_dirty_flag,\n                                                    table_size, key, out)) {\n      return true;\n    }\n  }\n  return false;\n}\n\ntemplate<typename Key, typename Index>\n__device__ bool GetOne(const size_t capacity, Key* table_keys, Index* table_indices, Key key,\n                       size_t hash, Index* out) {\n  const size_t start_idx = hash % capacity;\n  for (size_t count = 0; count < capacity; ++count) {\n    const size_t idx = (start_idx + count) % capacity;\n    Key entry_key = table_keys[idx];\n    Key entry_index = table_indices[idx];\n    Key key_hi = (key | 0x1);\n    Key key_lo = (key & 0x1);\n    if (entry_key == 0) { break; }\n    if (entry_key == key_hi) {\n      if ((entry_index & 0x1) == key_lo) {\n        *out = (entry_index >> 1U);\n        return true;\n      }\n    }\n  }\n  *out = 0;\n  return false;\n}\n\ntemplate<typename Key, typename Index, bool dump_dirty_only>\n__global__ void OrdinalEncodeKernel(uint64_t capacity, Key* table_keys, Index* table_indices,\n                                    bool* table_dirty_flags, Index* table_size, uint32_t num_keys,\n                                    const Key* keys, Index* context) {\n  CUDA_1D_KERNEL_LOOP(i, num_keys) {\n    Key key = keys[i];\n    uint64_t hash = FullCacheHash()(key);\n    bool success = GetOrInsertOne<Key, Index, dump_dirty_only>(\n        capacity, table_keys, table_indices, table_dirty_flags, table_size, key, hash, context + i);\n    assert(success);\n  }\n}\n\ntemplate<typename Key, typename Index>\n__global__ void OrdinalEncodeLookupKernel(uint64_t capacity, Key* table_keys, Index* table_indices,\n                                          uint32_t num_keys, const Key* keys, Index* context) {\n  CUDA_1D_KERNEL_LOOP(i, num_keys) {\n    Key key = keys[i];\n    uint64_t hash = FullCacheHash()(key);\n    GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, context + i);\n  }\n}\n\ntemplate<typename Key, typename Index, bool dump_dirty_only>\n__global__ void OrdinalEncodeDumpKernel(const Key* table_keys, const Index* table_indices,\n                                        const bool* table_dirty_flags, uint64_t start_key_index,\n                                        uint64_t end_key_index, uint32_t* n_dumped, Key* keys,\n                                        Index* context) {\n  CUDA_1D_KERNEL_LOOP(i, (end_key_index - start_key_index)) {\n    Key entry_key = table_keys[i + start_key_index];\n    Index entry_index = table_indices[i + start_key_index];\n    bool dump_flag = (entry_index != 0);\n    if (dump_dirty_only) {\n      bool entry_dirty_flag = table_dirty_flags[i + start_key_index];\n      dump_flag = (dump_flag && entry_dirty_flag);\n    }\n    if (dump_flag) {\n      uint32_t index = cuda::atomic::Add(n_dumped, static_cast<uint32_t>(1));\n      keys[index] = ((entry_key ^ 0x1) | (entry_index & 0x1));\n      context[index] = (entry_index >> 1U);\n    }\n  }\n}\n\ntemplate<typename Key, typename Elem, typename Index, bool return_value>\n__global__ void LookupKernel(uint32_t value_length, const Elem* cache_values,\n                             uint32_t values_elem_cnt, const Key* keys, const Index* context,\n                             Elem* values, uint32_t* n_missing, Key* missing_keys,\n                             uint32_t* missing_indices) {\n  CUDA_1D_KERNEL_LOOP(i, values_elem_cnt) {\n    const uint64_t key_id = i / value_length;\n    const 
uint64_t ctx = context[key_id];\n    const uint64_t row_id = ctx - 1;\n    const uint64_t col_id = i - key_id * value_length;\n    if (ctx == 0) {\n      const Key missing_key = keys[key_id];\n      if (col_id == 0) {\n        const uint32_t old_n_missing = cuda::atomic::Add(n_missing, static_cast<uint32_t>(1));\n        missing_keys[old_n_missing] = missing_key;\n        missing_indices[old_n_missing] = key_id;\n      }\n      continue;\n    }\n    if (return_value) { values[i] = cache_values[row_id * value_length + col_id]; }\n  }\n}\n\ntemplate<typename Key, typename Elem, typename Index, uint32_t block_size>\n__global__ void EncodeLookupKernel(uint32_t value_length, const Elem* cache_values,\n                                   uint32_t values_elem_cnt, const Key* keys, const Index* context,\n                                   Elem* values, uint32_t* n_missing, Key* missing_keys,\n                                   uint32_t* missing_indices, const size_t capacity,\n                                   Key* table_keys, Index* table_indices) {\n  constexpr uint32_t warp_size = 32;\n  constexpr uint32_t n_warp_per_block = block_size / warp_size;\n  const uint32_t warp_id = threadIdx.x / warp_size;\n  const uint32_t lane_id = threadIdx.x % warp_size;\n  const uint32_t global_warp_id = blockIdx.x * n_warp_per_block + warp_id;\n  const uint32_t global_n_warp = gridDim.x * n_warp_per_block;\n  const uint32_t n_keys = values_elem_cnt / value_length;\n  __shared__ Key batch_keys[n_warp_per_block][warp_size];\n  __shared__ Index batch_row_ids[n_warp_per_block][warp_size];\n  __shared__ Key batch_missing_keys[n_warp_per_block][warp_size];\n  __shared__ uint32_t batch_missing_indices[n_warp_per_block][warp_size];\n  __shared__ uint32_t batch_n_missing[n_warp_per_block];\n  for (uint32_t batch_start = global_warp_id * warp_size; batch_start < n_keys;\n       batch_start += global_n_warp * warp_size) {\n    const uint32_t batch_n_key = min(n_keys - batch_start, warp_size);\n    if (lane_id == 0) { batch_n_missing[warp_id] = 0; }\n    __syncwarp();\n    const uint32_t key_offset = batch_start + lane_id;\n    if (key_offset < n_keys) {\n      const Key key = keys[batch_start + lane_id];\n      batch_keys[warp_id][lane_id] = key;  // stage the key so the copy loop below can read it\n      const uint64_t hash = FullCacheHash()(key);\n      Index row;\n      GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, &row);\n      batch_row_ids[warp_id][lane_id] = row;\n      if (row == 0) {\n        const uint32_t batch_missing_idx = atomicAdd(batch_n_missing + warp_id, 1);\n        batch_missing_keys[warp_id][batch_missing_idx] = key;\n        batch_missing_indices[warp_id][batch_missing_idx] = key_offset;\n      }\n    }\n    __syncwarp();\n    const uint32_t batch_n_missing_t = batch_n_missing[warp_id];\n    if (lane_id == 0) {\n      const uint32_t old_n_missing =\n          cuda::atomic::Add(n_missing, static_cast<uint32_t>(batch_n_missing_t));\n      batch_n_missing[warp_id] = old_n_missing;\n    }\n    __syncwarp();\n    if (lane_id < batch_n_missing_t) {\n      missing_keys[batch_n_missing[warp_id] + lane_id] = batch_missing_keys[warp_id][lane_id];\n      missing_indices[batch_n_missing[warp_id] + lane_id] = batch_missing_indices[warp_id][lane_id];\n    }\n    for (int i = 0; i < batch_n_key; ++i) {\n      const Key key = batch_keys[warp_id][i];\n      const int64_t row = batch_row_ids[warp_id][i];\n      if (row == 0) { continue; }\n      for (int col = lane_id; col < value_length; col += warp_size) {\n        values[(batch_start + i) * value_length + col] =\n            
cache_values[(row - 1) * value_length + col];\n      }\n    }\n    __syncwarp();\n  }\n}\n\ntemplate<typename T, size_t pack_size>\nstruct alignas(sizeof(T) * pack_size) Pack {\n  T elem[pack_size];\n};\n\ntemplate<typename Key, typename Elem, typename Index, uint32_t block_size, uint32_t pack_size>\n__global__ void EncodeLookupMaskKernel(uint32_t value_length, const Elem* __restrict__ cache_values,\n                                       uint32_t values_elem_cnt, const Key* __restrict__ keys,\n                                       const Index* __restrict__ context, Elem* __restrict__ values,\n                                       uint8_t* __restrict__ mask, const size_t capacity,\n                                       Key* __restrict__ table_keys,\n                                       Index* __restrict__ table_indices) {\n  const uint32_t packed_cols = value_length / pack_size;\n  auto* packed_values = reinterpret_cast<Pack<Elem, pack_size>*>(values);\n  const auto* packed_cache_values = reinterpret_cast<const Pack<Elem, pack_size>*>(cache_values);\n  constexpr uint32_t warp_size = 32;\n  constexpr uint32_t n_warp_per_block = block_size / warp_size;\n  const uint32_t warp_id = threadIdx.x / warp_size;\n  const uint32_t lane_id = threadIdx.x % warp_size;\n  const uint32_t global_warp_id = blockIdx.x * n_warp_per_block + warp_id;\n  const uint32_t global_n_warp = gridDim.x * n_warp_per_block;\n  const uint32_t n_keys = values_elem_cnt / value_length;\n  __shared__ Index batch_row_ids[n_warp_per_block][warp_size];\n  for (uint32_t batch_start = global_warp_id * warp_size; batch_start < n_keys;\n       batch_start += global_n_warp * warp_size) {\n    const uint32_t batch_n_key = min(n_keys - batch_start, warp_size);\n    const uint32_t key_offset = batch_start + lane_id;\n    if (key_offset < n_keys) {\n      const Key key = keys[batch_start + lane_id];\n      const uint64_t hash = FullCacheHash()(key);\n      Index row;\n      GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, &row);\n      batch_row_ids[warp_id][lane_id] = row;\n      mask[key_offset] = row > 0;\n    }\n    __syncwarp();\n    for (int i = 0; i < batch_n_key; ++i) {\n      const int64_t row = batch_row_ids[warp_id][i];\n      if (row == 0) { continue; }\n#pragma unroll 4\n      for (int col = lane_id; col < packed_cols; col += warp_size) {\n        packed_values[(batch_start + i) * packed_cols + col] =\n            packed_cache_values[(row - 1) * packed_cols + col];\n      }\n    }\n    __syncwarp();\n  }\n}\n\ntemplate<typename Elem, typename Index, size_t pack_size>\n__global__ void UpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt,\n                             const Index* context, const Elem* values) {\n  const int packed_values_elem_cnt = values_elem_cnt / pack_size;\n  const uint32_t packed_elem_cnt = value_length / pack_size;\n  auto* packed_cache_values = reinterpret_cast<Pack<Elem, pack_size>*>(cache_values);\n  auto* packed_values = reinterpret_cast<const Pack<Elem, pack_size>*>(values);\n  CUDA_1D_KERNEL_LOOP(i, packed_values_elem_cnt) {\n    const uint64_t key_id = i / packed_elem_cnt;\n    const uint64_t ctx = context[key_id];\n    if (ctx == 0) { continue; }\n    const uint64_t row_id = ctx - 1;\n    const uint64_t col_id = i - key_id * packed_elem_cnt;\n    packed_cache_values[row_id * packed_elem_cnt + col_id] = packed_values[i];\n  
}\n}\n\ntemplate<typename Elem, typename Index, size_t pack_size>\n__global__ typename std::enable_if<std::is_same<Elem, float>::value, void>::type\nFusedHalfUpdateKernel(uint32_t value_length, Elem* __restrict__ cache_values,\n                      uint32_t values_elem_cnt, const Index* __restrict__ context,\n                      const Elem* __restrict__ values, const half* __restrict__ update,\n                      const float* __restrict__ lr, float scale) {\n  const int packed_values_elem_cnt = values_elem_cnt / pack_size;\n  const uint32_t packed_elem_cnt = value_length / pack_size;\n  auto* packed_cache_values = reinterpret_cast<Pack<Elem, pack_size>*>(cache_values);\n  auto* packed_values = reinterpret_cast<const Pack<Elem, pack_size>*>(values);\n  auto* packed_update = reinterpret_cast<const Pack<half, pack_size>*>(update);\n  const float alpha = -*lr * scale;\n  CUDA_1D_KERNEL_LOOP(i, packed_values_elem_cnt) {\n    const uint64_t key_id = i / packed_elem_cnt;\n    const uint64_t ctx = context[key_id];\n    if (ctx == 0) { continue; }\n    const uint64_t row_id = ctx - 1;\n    const uint64_t col_id = i - key_id * packed_elem_cnt;\n    Pack<Elem, pack_size> m = packed_values[i];\n    Pack<half, pack_size> u = packed_update[i];\n    for (size_t j = 0; j < pack_size; ++j) { m.elem[j] += static_cast<Elem>(u.elem[j]) * alpha; }\n    packed_cache_values[row_id * packed_elem_cnt + col_id] = m;\n  }\n}\n\ntemplate<typename Elem, typename Index, size_t pack_size>\n__global__ typename std::enable_if<!std::is_same<Elem, float>::value, void>::type\nFusedHalfUpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt,\n                      const Index* context, const Elem* values, const half* update, const float* lr,\n                      float scale) {\n  __trap();\n}\n\ntemplate<typename Key, typename Elem, typename Index>\n__global__ void DumpValueKernel(uint32_t value_length, const uint32_t* n_dumped,\n                                const Index* context, const Elem* cache_values, Elem* values) {\n  CUDA_1D_KERNEL_LOOP(i, *n_dumped * value_length) {\n    const uint64_t key_id = i / value_length;\n    const uint64_t ctx = context[key_id];\n    const uint64_t row_id = ctx - 1;\n    const uint64_t col_id = i - key_id * value_length;\n    values[i] = cache_values[row_id * value_length + col_id];\n  }\n}\n\ntemplate<typename Key, typename Index>\nclass OrdinalEncoder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(OrdinalEncoder);\n  explicit OrdinalEncoder(uint64_t capacity, float load_factor, bool if_dump_dirty)\n      : capacity_(capacity),\n        table_capacity_(capacity / load_factor),\n        if_dump_dirty_(if_dump_dirty) {\n    OF_CUDA_CHECK(cudaGetDevice(&device_index_));\n    OF_CUDA_CHECK(cudaMalloc(&table_size_, sizeof(Index)));\n    OF_CUDA_CHECK(cudaMallocHost(&table_size_host_, sizeof(Index)));\n    OF_CUDA_CHECK(cudaMalloc(&table_keys_, table_capacity_ * sizeof(Key)));\n    OF_CUDA_CHECK(cudaMalloc(&table_indices_, table_capacity_ * sizeof(Index)));\n    if (if_dump_dirty_) {\n      OF_CUDA_CHECK(cudaMalloc(&table_dirty_flags_, table_capacity_ * sizeof(bool)));\n    }\n    Clear();\n  }\n  ~OrdinalEncoder() {\n    CudaCurrentDeviceGuard guard(device_index_);\n    OF_CUDA_CHECK(cudaFree(table_size_));\n    OF_CUDA_CHECK(cudaFreeHost(table_size_host_));\n    OF_CUDA_CHECK(cudaFree(table_keys_));\n    OF_CUDA_CHECK(cudaFree(table_indices_));\n    if (if_dump_dirty_) { OF_CUDA_CHECK(cudaFree(table_dirty_flags_)); }\n  }\n\n  template<bool insert, bool 
dump_dirty_only>\n  void Encode(ep::Stream* stream, uint32_t num_keys, const Key* keys, Index* context) {\n    if (insert) {\n      RUN_CUDA_KERNEL((OrdinalEncodeKernel<Key, Index, dump_dirty_only>), stream, num_keys,\n                      table_capacity_, table_keys_, table_indices_, table_dirty_flags_, table_size_,\n                      num_keys, keys, context);\n    } else {\n      RUN_CUDA_KERNEL((OrdinalEncodeLookupKernel<Key, Index>), stream, num_keys, table_capacity_,\n                      table_keys_, table_indices_, num_keys, keys, context);\n    }\n  }\n\n  void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,\n            uint32_t* n_dumped, Key* keys, Index* context) {\n    OF_CUDA_CHECK(cudaMemsetAsync(n_dumped, 0, sizeof(uint32_t),\n                                  stream->As<ep::CudaStream>()->cuda_stream()));\n    RUN_CUDA_KERNEL((OrdinalEncodeDumpKernel<Key, Index, false>), stream,\n                    end_key_index - start_key_index, table_keys_, table_indices_,\n                    table_dirty_flags_, start_key_index, end_key_index, n_dumped, keys, context);\n  }\n\n  void DumpDirtyOnly(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,\n                     uint32_t* n_dumped, Key* keys, Index* context) {\n    OF_CUDA_CHECK(cudaMemsetAsync(n_dumped, 0, sizeof(uint32_t),\n                                  stream->As<ep::CudaStream>()->cuda_stream()));\n    RUN_CUDA_KERNEL((OrdinalEncodeDumpKernel<Key, Index, true>), stream,\n                    end_key_index - start_key_index, table_keys_, table_indices_,\n                    table_dirty_flags_, start_key_index, end_key_index, n_dumped, keys, context);\n  }\n\n  void ClearDirtyFlags() {\n    if (if_dump_dirty_) {\n      OF_CUDA_CHECK(cudaMemset(table_dirty_flags_, 0, table_capacity_ * sizeof(bool)));\n    }\n  }\n\n  void Clear() {\n    OF_CUDA_CHECK(cudaMemset(table_size_, 0, sizeof(Index)));\n    OF_CUDA_CHECK(cudaMemset(table_keys_, 0, table_capacity_ * sizeof(Key)));\n    OF_CUDA_CHECK(cudaMemset(table_indices_, 0, table_capacity_ * sizeof(Index)));\n    if (if_dump_dirty_) {\n      OF_CUDA_CHECK(cudaMemset(table_dirty_flags_, 0, table_capacity_ * sizeof(bool)));\n    }\n  }\n\n  uint64_t TableCapacity() const { return table_capacity_; }\n\n  Key* table_keys() const { return table_keys_; }\n\n  Index* table_indices() const { return table_indices_; }\n\n private:\n  int device_index_{};\n  Key* table_keys_;\n  Index* table_indices_;\n  bool* table_dirty_flags_;\n  uint64_t capacity_;\n  uint64_t table_capacity_;\n  bool if_dump_dirty_;\n  Index* table_size_{};\n  Index* table_size_host_{};\n};\n\ntemplate<typename Key, typename Elem, typename Index, size_t pack_size>\nclass CacheImpl : public Cache {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CacheImpl);\n  explicit CacheImpl(const CacheOptions& options)\n      : if_dump_dirty_(ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_DUMP_DIRTY_ONLY\", false)),\n        encoder_(options.capacity, options.load_factor, if_dump_dirty_),\n        device_index_(-1),\n        options_(options),\n        max_query_length_(0) {\n    OF_CUDA_CHECK(cudaGetDevice(&device_index_));\n    const uint64_t values_size = options.capacity * options.value_size;\n    if (options.value_memory_kind == CacheOptions::MemoryKind::kDevice) {\n      OF_CUDA_CHECK(cudaMalloc(&values_, values_size));\n    } else if (options.value_memory_kind == CacheOptions::MemoryKind::kHost) {\n      if (ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION\", 
false)) {\n        OF_CUDA_CHECK(cudaMallocHost(&values_, values_size));\n      } else {\n        OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast<void**>(&values_),\n                                              values_size));\n      }\n    } else {\n      UNIMPLEMENTED();\n    }\n    num_elem_per_value_ = options_.value_size / sizeof(Elem);\n  }\n  ~CacheImpl() {\n    CudaCurrentDeviceGuard guard(device_index_);\n    if (options_.value_memory_kind == CacheOptions::MemoryKind::kDevice) {\n      OF_CUDA_CHECK(cudaFree(values_));\n    } else if (options_.value_memory_kind == CacheOptions::MemoryKind::kHost) {\n      OF_CUDA_CHECK(cudaFreeHost(values_));\n    } else {\n      UNIMPLEMENTED();\n    }\n    if (max_query_length_ > 0) { OF_CUDA_CHECK(cudaFree(encoding_buffer_)); }\n  }\n\n  uint64_t Capacity() const override { return options_.capacity; }\n  uint64_t DumpCapacity() const override { return encoder_.TableCapacity(); }\n  uint32_t KeySize() const override { return options_.key_size; }\n\n  uint32_t ValueSize() const override { return options_.value_size; }\n\n  DataType ValueType() const override { return options_.value_type; }\n\n  uint32_t MaxQueryLength() const override { return max_query_length_; }\n\n  void ReserveQueryLength(uint32_t query_length) override {\n    CudaCurrentDeviceGuard guard(device_index_);\n    if (query_length <= max_query_length_) { return; }\n    if (max_query_length_ > 0) { OF_CUDA_CHECK(cudaFree(encoding_buffer_)); }\n    OF_CUDA_CHECK(cudaMalloc(&encoding_buffer_, query_length * sizeof(uint64_t)));\n    max_query_length_ = query_length;\n  }\n\n  CacheOptions::Policy Policy() const override { return CacheOptions::Policy::kFull; }\n\n  void Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing,\n            void* missing_keys, uint32_t* missing_indices) override;\n\n  void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint32_t* n_missing,\n           void* missing_keys, uint32_t* missing_indices) override;\n\n  void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values,\n           uint8_t* mask) override;\n\n  void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,\n           uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override;\n\n  void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,\n                          const void* update, const float* lr, float scale, uint32_t* n_evicted,\n                          void* evicted_keys, void* evicted_values) override;\n  void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,\n            uint32_t* n_dumped, void* keys, void* values) override;\n\n  void ClearDirtyFlags() override;\n\n  void Clear() override;\n\n private:\n  bool if_dump_dirty_;\n  OrdinalEncoder<Key, Index> encoder_;\n  int device_index_;\n  uint32_t num_elem_per_value_{};\n  Elem* values_;\n  Index* encoding_buffer_{};\n  CacheOptions options_;\n  uint32_t max_query_length_;\n};\n\ntemplate<typename Key, typename Elem, typename Index, size_t pack_size>\nvoid CacheImpl<Key, Elem, Index, pack_size>::Test(ep::Stream* stream, uint32_t n_keys,\n                                                  const void* keys, uint32_t* n_missing,\n                                                  void* missing_keys, uint32_t* missing_indices) {\n  OF_CUDA_CHECK(\n      cudaMemsetAsync(n_missing, 0, sizeof(uint32_t), 
stream->As<ep::CudaStream>()->cuda_stream()));\n  if (n_keys == 0) { return; }\n  CHECK_LE(n_keys, max_query_length_);\n  if (if_dump_dirty_) {\n    encoder_.template Encode<false, true>(stream, n_keys, static_cast<const Key*>(keys),\n                                          encoding_buffer_);\n  } else {\n    encoder_.template Encode<false, false>(stream, n_keys, static_cast<const Key*>(keys),\n                                           encoding_buffer_);\n  }\n  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;\n  RUN_CUDA_KERNEL((LookupKernel<Key, Elem, Index, false>), stream, values_elem_cnt,\n                  num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),\n                  encoding_buffer_, nullptr, n_missing, static_cast<Key*>(missing_keys),\n                  missing_indices);\n}\n\ntemplate<typename Key, typename Elem, typename Index, size_t pack_size>\nvoid CacheImpl<Key, Elem, Index, pack_size>::Get(ep::Stream* stream, uint32_t n_keys,\n                                                 const void* keys, void* values,\n                                                 uint32_t* n_missing, void* missing_keys,\n                                                 uint32_t* missing_indices) {\n  OF_CUDA_CHECK(\n      cudaMemsetAsync(n_missing, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));\n  if (n_keys == 0) { return; }\n  CHECK_LE(n_keys, max_query_length_);\n  constexpr uint32_t block_size = 128;\n  uint32_t grid_size = (n_keys + block_size - 1) / block_size;\n  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;\n  EncodeLookupKernel<Key, Elem, Index, block_size>\n      <<<grid_size, block_size, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),\n          encoding_buffer_, static_cast<Elem*>(values), n_missing, static_cast<Key*>(missing_keys),\n          missing_indices, encoder_.TableCapacity(), encoder_.table_keys(),\n          encoder_.table_indices());\n}\n\ntemplate<typename Key, typename Elem, typename Index, size_t pack_size>\nvoid CacheImpl<Key, Elem, Index, pack_size>::Get(ep::Stream* stream, uint32_t n_keys,\n                                                 const void* keys, void* values, uint8_t* mask) {\n  if (n_keys == 0) { return; }\n  CHECK_LE(n_keys, max_query_length_);\n  constexpr uint32_t block_size = 128;\n  uint32_t grid_size = (n_keys + block_size - 1) / block_size;\n  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;\n  EncodeLookupMaskKernel<Key, Elem, Index, block_size, pack_size>\n      <<<grid_size, block_size, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),\n          encoding_buffer_, static_cast<Elem*>(values), mask, encoder_.TableCapacity(),\n          encoder_.table_keys(), encoder_.table_indices());\n}\n\ntemplate<typename Key, typename Elem, typename Index, size_t pack_size>\nvoid CacheImpl<Key, Elem, Index, pack_size>::Put(ep::Stream* stream, uint32_t n_keys,\n                                                 const void* keys, const void* values,\n                                                 uint32_t* n_evicted, void* evicted_keys,\n                                                 void* evicted_values) {\n  if (n_keys == 0) { return; }\n  CHECK_LE(n_keys, max_query_length_);\n  if (if_dump_dirty_) {\n    encoder_.template Encode<true, true>(stream, n_keys, static_cast<const 
Key*>(keys),\n                                         encoding_buffer_);\n  } else {\n    encoder_.template Encode<true, false>(stream, n_keys, static_cast<const Key*>(keys),\n                                          encoding_buffer_);\n  }\n  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;\n  RUN_CUDA_KERNEL((UpdateKernel<Elem, Index, pack_size>), stream, values_elem_cnt / pack_size,\n                  num_elem_per_value_, values_, values_elem_cnt, encoding_buffer_,\n                  static_cast<const Elem*>(values));\n}\n\ntemplate<typename Key, typename Elem, typename Index, size_t pack_size>\nvoid CacheImpl<Key, Elem, Index, pack_size>::FusedHalfUpdatePut(\n    ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values, const void* update,\n    const float* lr, float scale, uint32_t* n_evicted, void* evicted_keys, void* evicted_values) {\n  if (!std::is_same<Elem, float>::value) { UNIMPLEMENTED(); }\n  if (n_keys == 0) { return; }\n  CHECK_LE(n_keys, max_query_length_);\n  if (if_dump_dirty_) {\n    encoder_.template Encode<true, true>(stream, n_keys, static_cast<const Key*>(keys),\n                                         encoding_buffer_);\n  } else {\n    encoder_.template Encode<true, false>(stream, n_keys, static_cast<const Key*>(keys),\n                                          encoding_buffer_);\n  }\n  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;\n  RUN_CUDA_KERNEL((FusedHalfUpdateKernel<Elem, Index, pack_size>), stream,\n                  values_elem_cnt / pack_size, num_elem_per_value_, values_, values_elem_cnt,\n                  encoding_buffer_, static_cast<const Elem*>(values),\n                  static_cast<const half*>(update), lr, scale);\n}\n\ntemplate<typename Key, typename Elem, typename Index, size_t pack_size>\nvoid CacheImpl<Key, Elem, Index, pack_size>::Dump(ep::Stream* stream, uint64_t start_key_index,\n                                                  uint64_t end_key_index, uint32_t* n_dumped,\n                                                  void* keys, void* values) {\n  if (if_dump_dirty_) {\n    encoder_.DumpDirtyOnly(stream, start_key_index, end_key_index, n_dumped,\n                           static_cast<Key*>(keys), encoding_buffer_);\n  } else {\n    encoder_.Dump(stream, start_key_index, end_key_index, n_dumped, static_cast<Key*>(keys),\n                  encoding_buffer_);\n  }\n  RUN_CUDA_KERNEL((DumpValueKernel<Key, Elem, Index>), stream,\n                  num_elem_per_value_ * (end_key_index - start_key_index), num_elem_per_value_,\n                  n_dumped, encoding_buffer_, values_, static_cast<Elem*>(values));\n}\n\ntemplate<typename Key, typename Elem, typename Index, size_t pack_size>\nvoid CacheImpl<Key, Elem, Index, pack_size>::ClearDirtyFlags() {\n  encoder_.ClearDirtyFlags();\n}\n\ntemplate<typename Key, typename Elem, typename Index, size_t pack_size>\nvoid CacheImpl<Key, Elem, Index, pack_size>::Clear() {\n  encoder_.Clear();\n}\n\ntemplate<typename Key, typename Index>\nstd::unique_ptr<Cache> DispatchValueType(const CacheOptions& options) {\n  if (options.value_type == DataType::kFloat) {\n    const size_t value_elem_cnt = options.value_size / sizeof(float);\n    const size_t half_warp = 16;\n    if (value_elem_cnt % 4 == 0 && value_elem_cnt / 4 > half_warp) {\n      return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 4>(options));\n    } else if (value_elem_cnt % 2 == 0 && value_elem_cnt / 2 > half_warp) {\n      return std::unique_ptr<Cache>(new CacheImpl<Key, float, 
Index, 2>(options));\n    } else {\n      return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 1>(options));\n    }\n  } else if (options.value_size % sizeof(ulonglong2) == 0) {\n    return std::unique_ptr<Cache>(new CacheImpl<Key, ulonglong2, Index, 1>(options));\n  } else if (options.value_size % sizeof(uint64_t) == 0) {\n    return std::unique_ptr<Cache>(new CacheImpl<Key, uint64_t, Index, 1>(options));\n  } else if (options.value_size % sizeof(uint32_t) == 0) {\n    return std::unique_ptr<Cache>(new CacheImpl<Key, uint32_t, Index, 1>(options));\n  } else if (options.value_size % sizeof(uint16_t) == 0) {\n    return std::unique_ptr<Cache>(new CacheImpl<Key, uint16_t, Index, 1>(options));\n  } else {\n    return std::unique_ptr<Cache>(new CacheImpl<Key, uint8_t, Index, 1>(options));\n  }\n}\n\ntemplate<typename Index>\nstd::unique_ptr<Cache> DispatchKeyType(const CacheOptions& options) {\n  if (options.key_size == sizeof(Key32)) {\n    return DispatchValueType<Key32, Index>(options);\n  } else if (options.key_size == sizeof(Key64)) {\n    return DispatchValueType<Key64, Index>(options);\n  } else {\n    UNIMPLEMENTED();\n    return nullptr;\n  }\n}\n\nstd::unique_ptr<Cache> DispatchIndexType(const CacheOptions& options) {\n  const int64_t table_capacity = static_cast<double>(options.capacity) / options.load_factor;\n  if (table_capacity >= (1ULL << 31ULL)) {\n    return DispatchKeyType<uint64_t>(options);\n  } else {\n    return DispatchKeyType<uint32_t>(options);\n  }\n}\n\n}  // namespace\n\nstd::unique_ptr<Cache> NewFullCache(const CacheOptions& options) {\n  return DispatchIndexType(options);\n}\n\n}  // namespace embedding\n\n}  // namespace oneflow\n"
  },
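Note on the ordinal encoding in full_cache.cu above: the table never stores a raw key. GetOrInsertOne and GetOne keep `key | 0x1` in `table_keys` (so a stored key always has its low bit set and a zero entry unambiguously means "empty"), while bit 0 of the stored index preserves the key's real low bit and the remaining bits hold the 1-based row ordinal (a context of 0 means "missing"). The standalone host sketch below is not part of the sources; it only round-trips that packing with the same expressions OrdinalEncodeDumpKernel uses.

#include <cassert>
#include <cstdint>
#include <initializer_list>

int main() {
  for (uint64_t key : {0ULL, 1ULL, 42ULL, 43ULL}) {
    const uint64_t row = 7;  // 1-based ordinal assigned at insert time (illustrative)
    // Stored form: key with the low bit forced to 1, and an index carrying
    // the key's real parity in bit 0 with the ordinal above it.
    const uint64_t entry_key = key | 0x1;
    const uint64_t entry_index = (row << 1) | (key & 0x1);
    // Recovery, exactly as in OrdinalEncodeDumpKernel: xor clears the forced
    // bit, then the saved parity bit is OR-ed back in.
    assert(((entry_key ^ 0x1) | (entry_index & 0x1)) == key);
    assert((entry_index >> 1U) == row);
  }
  return 0;
}

This packing is also why OrdinalEncodeDumpKernel can use `entry_index != 0` as its occupancy test: a live slot always stores a nonzero 1-based ordinal.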
  {
    "path": "oneflow/core/embedding/full_cache.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EMBEDDING_FULL_CACHE_H_\n#define ONEFLOW_CORE_EMBEDDING_FULL_CACHE_H_\n\n#include \"oneflow/core/embedding/cache.h\"\n#include \"oneflow/core/common/data_type.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\n#ifdef WITH_CUDA\n\nstd::unique_ptr<Cache> NewFullCache(const CacheOptions& options);\n\n#endif  // WITH_CUDA\n\n}  // namespace embedding\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EMBEDDING_FULL_CACHE_H_\n"
  },
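full_cache.h exposes only the factory; everything else is selected through CacheOptions by the dispatch chain in full_cache.cu (element type, pack size, key width, index width). A minimal construction sketch; the concrete numbers are assumptions for illustration, not defaults:

#include <cstdint>
#include <memory>
#include "oneflow/core/embedding/full_cache.h"

// Builds a device-resident full cache for 128-dim float rows. Every field
// written here is one the dispatch code in full_cache.cu actually reads;
// the values themselves are illustrative.
std::unique_ptr<oneflow::embedding::Cache> MakeFullCache() {
  using namespace oneflow::embedding;
  CacheOptions options{};
  options.policy = CacheOptions::Policy::kFull;
  options.value_memory_kind = CacheOptions::MemoryKind::kDevice;
  options.key_size = sizeof(uint64_t);       // dispatches to Key64
  options.value_size = 128 * sizeof(float);  // one embedding row
  options.value_type = oneflow::DataType::kFloat;
  options.capacity = 1 << 20;   // max number of cached rows
  options.load_factor = 0.75f;  // hash table gets capacity / load_factor slots
  return NewFullCache(options);
}

With these numbers, DispatchValueType picks the pack-size-4 float specialization (128 % 4 == 0 and 128 / 4 > 16) and DispatchIndexType picks uint32_t indices, since capacity / load_factor stays below 2^31.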
  {
    "path": "oneflow/core/embedding/hash_functions.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_H_\n#define ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_H_\n\n#include <stdint.h>\n#include \"oneflow/core/common/data_type.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\nnamespace {\n\n// From https://github.com/Cyan4973/xxHash/blob/dev/xxhash.h\nstatic const uint64_t PRIME64_1 =\n    0x9E3779B185EBCA87ULL;  // 0b1001111000110111011110011011000110000101111010111100101010000111\nstatic const uint64_t PRIME64_2 =\n    0xC2B2AE3D27D4EB4FULL;  // 0b1100001010110010101011100011110100100111110101001110101101001111\nstatic const uint64_t PRIME64_3 =\n    0x165667B19E3779F9ULL;  // 0b0001011001010110011001111011000110011110001101110111100111111001\nstatic const uint64_t PRIME64_4 =\n    0x85EBCA77C2B2AE63ULL;  // 0b1000010111101011110010100111011111000010101100101010111001100011\nstatic const uint64_t PRIME64_5 =\n    0x27D4EB2F165667C5ULL;  // 0b0010011111010100111010110010111100010110010101100110011111000101\n\n#define XXH_rotl64(x, r) (((x) << (r)) | ((x) >> (64 - (r))))\n\nOF_DEVICE_FUNC uint64_t XXH64_round(uint64_t acc, uint64_t input) {\n  acc += input * PRIME64_2;\n  acc = XXH_rotl64(acc, 31);\n  acc *= PRIME64_1;\n  return acc;\n}\n\nOF_DEVICE_FUNC uint64_t xxh64_uint64(uint64_t v, uint64_t seed) {\n  uint64_t acc = seed + PRIME64_5;\n  acc += sizeof(uint64_t);\n  acc = acc ^ XXH64_round(0, v);\n  acc = XXH_rotl64(acc, 27) * PRIME64_1;\n  acc = acc + PRIME64_4;\n  acc ^= (acc >> 33);\n  acc = acc * PRIME64_2;\n  acc = acc ^ (acc >> 29);\n  acc = acc * PRIME64_3;\n  acc = acc ^ (acc >> 32);\n  return acc;\n}\n\nstatic const size_t kShardingHashSeed = 1;\nstatic const size_t kLocalUniqueHashSeed = 2;\nstatic const size_t kGlobalUniqueHashSeed = 3;\nstatic const size_t kFullCacheHashSeed = 4;\nstatic const size_t kLruCacheHashSeed = 5;\n\n}  // namespace\n\nstruct ShardingHash {\n  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kShardingHashSeed); }\n  OF_DEVICE_FUNC size_t operator()(uint32_t v) { return xxh64_uint64(v, kShardingHashSeed); }\n  OF_DEVICE_FUNC size_t operator()(int32_t v) {\n    return xxh64_uint64(static_cast<uint32_t>(v), kShardingHashSeed);\n  }\n  OF_DEVICE_FUNC size_t operator()(int64_t v) {\n    return xxh64_uint64(static_cast<uint64_t>(v), kShardingHashSeed);\n  }\n};\n\nstruct LocalUniqueHash {\n  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kLocalUniqueHashSeed); }\n};\n\nstruct GlobalUniqueHash {\n  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kGlobalUniqueHashSeed); }\n};\n\nstruct FullCacheHash {\n  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kFullCacheHashSeed); }\n};\n\nstruct LruCacheHash {\n  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kLruCacheHashSeed); }\n};\n\n}  // namespace embedding\n}  // namespace oneflow\n#endif  // ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_H_\n"
  },
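hash_functions.cuh derives five unrelated hash families from a single one-input xxHash64 core purely by seeding. The host program below restates xxh64_uint64 with the same constants and rounds (duplicated here only for illustration; the bucket counts are arbitrary) to show how one key maps to independent buckets per use case:

#include <cstdint>
#include <cstdio>

static const uint64_t PRIME64_1 = 0x9E3779B185EBCA87ULL;
static const uint64_t PRIME64_2 = 0xC2B2AE3D27D4EB4FULL;
static const uint64_t PRIME64_3 = 0x165667B19E3779F9ULL;
static const uint64_t PRIME64_4 = 0x85EBCA77C2B2AE63ULL;
static const uint64_t PRIME64_5 = 0x27D4EB2F165667C5ULL;

static uint64_t rotl64(uint64_t x, int r) { return (x << r) | (x >> (64 - r)); }

static uint64_t round64(uint64_t acc, uint64_t input) {
  acc += input * PRIME64_2;
  acc = rotl64(acc, 31);
  return acc * PRIME64_1;
}

// Same steps as xxh64_uint64 in the header, minus OF_DEVICE_FUNC.
static uint64_t xxh64_uint64(uint64_t v, uint64_t seed) {
  uint64_t acc = seed + PRIME64_5 + sizeof(uint64_t);
  acc ^= round64(0, v);
  acc = rotl64(acc, 27) * PRIME64_1 + PRIME64_4;
  acc ^= acc >> 33;
  acc *= PRIME64_2;
  acc ^= acc >> 29;
  acc *= PRIME64_3;
  acc ^= acc >> 32;
  return acc;
}

int main() {
  const uint64_t key = 12345;
  printf("sharding bucket:  %llu\n",
         (unsigned long long)(xxh64_uint64(key, /*kShardingHashSeed=*/1) % 4));
  printf("full-cache slot:  %llu\n",
         (unsigned long long)(xxh64_uint64(key, /*kFullCacheHashSeed=*/4) % 1024));
  return 0;
}

Keeping the seeds distinct matters: if sharding and the full cache shared a seed, keys co-located on one shard would probe correlated positions in that shard's table whenever the table size shares factors with the shard count.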
  {
    "path": "oneflow/core/embedding/key_value_store.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EMBEDDING_KEY_VALUE_STORE_H_\n#define ONEFLOW_CORE_EMBEDDING_KEY_VALUE_STORE_H_\n\n#include \"oneflow/core/embedding/kv_iterator.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\nclass KeyValueStore {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(KeyValueStore);\n  KeyValueStore() = default;\n  virtual ~KeyValueStore() = default;\n\n  virtual uint32_t KeySize() const = 0;\n  virtual uint32_t ValueSize() const = 0;\n  virtual uint32_t MaxQueryLength() const = 0;\n  virtual void ReserveQueryLength(uint32_t query_length) = 0;\n\n  virtual void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,\n                   uint32_t* n_missing, uint32_t* missing_indices) = 0;\n  virtual void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,\n                   uint8_t* mask) {\n    UNIMPLEMENTED();\n  }\n  virtual void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) = 0;\n  virtual void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys,\n                                  const void* values, const void* update, const float* lr,\n                                  float scale) {\n    UNIMPLEMENTED();\n  }\n  virtual bool IsFusionSupported() { return false; }\n  virtual bool SnapshotExists(const std::string& name) = 0;\n  virtual void LoadSnapshot(const std::string& name) = 0;\n  virtual void LoadSnapshot(const std::string& name,\n                            const std::function<void(KVIterator* iter)>& Hook) = 0;\n  virtual void SaveSnapshot(const std::string& name) = 0;\n};\n\n}  // namespace embedding\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EMBEDDING_KEY_VALUE_STORE_H_\n"
  },
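The Get/Put pair above implies a miss-handling protocol rather than enforcing one: Get fills `values` in place for resident keys and reports the rest through `n_missing`/`missing_indices`. A hypothetical caller-side helper (the name, signature, and the initialization step are inventions for this sketch) showing the intended call order:

#include "oneflow/core/embedding/key_value_store.h"

// Sketch only: all pointers are device buffers sized by the caller, and the
// "initialize the missing rows" step is left as a comment because it is
// model-specific.
void GetOrInit(oneflow::embedding::KeyValueStore* store, oneflow::ep::Stream* stream,
               uint32_t n_keys, const void* keys, void* values, uint32_t* n_missing,
               uint32_t* missing_indices) {
  store->Get(stream, n_keys, keys, values, n_missing, missing_indices);
  // Synchronize the stream, copy *n_missing to the host, and for each i in
  // [0, *n_missing) fill the row at missing_indices[i] in `values` with an
  // initializer of the caller's choice. Then persist the completed rows:
  store->Put(stream, n_keys, keys, values);
}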
  {
    "path": "oneflow/core/embedding/key_value_store_options.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_EMBEDDING_KEY_VALUE_STORE_OPTIONS_H_\n#define ONEFLOW_EMBEDDING_KEY_VALUE_STORE_OPTIONS_H_\n#include \"nlohmann/json.hpp\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/embedding/cache.h\"\n\nnamespace oneflow {\nnamespace embedding {\n\nnamespace {\n\nvoid ParseCacheOptions(const nlohmann::json& cache_obj, CacheOptions* cache_options) {\n  CHECK_GT(cache_options->key_size, 0);\n  CHECK_GT(cache_options->value_size, 0);\n  CHECK(cache_obj.contains(\"policy\"));\n  CHECK(cache_obj[\"policy\"].is_string());\n  std::string policy = cache_obj[\"policy\"].get<std::string>();\n  if (policy == \"lru\") {\n    cache_options->policy = CacheOptions::Policy::kLRU;\n  } else if (policy == \"full\") {\n    cache_options->policy = CacheOptions::Policy::kFull;\n  } else {\n    UNIMPLEMENTED() << \"Unsupported cache policy\";\n  }\n  int64_t capacity = 0;\n  if (cache_obj.contains(\"capacity\")) {\n    CHECK(cache_obj[\"capacity\"].is_number());\n    capacity = cache_obj[\"capacity\"].get<int64_t>();\n  }\n  if (cache_obj.contains(\"cache_memory_budget_mb\")) {\n    CHECK(cache_obj[\"cache_memory_budget_mb\"].is_number());\n    int64_t cache_memory_budget_mb = cache_obj[\"cache_memory_budget_mb\"].get<int64_t>();\n    if (cache_memory_budget_mb > 0) {\n      CHECK_EQ(capacity, 0) << \"when set capacity, must not set cache_memory_budget_mb\";\n      capacity = cache_memory_budget_mb * 1024 * 1024 / cache_options->value_size;\n    }\n  }\n  CHECK_GT(capacity, 0) << \"capacity or cache_memory_budget_mb must be set\";\n  // add an extra_capacity to avoid crash by uneven partition.\n  const int64_t extra_capacity = capacity * 0.05;\n  cache_options->capacity = capacity + (extra_capacity > 4096 ? 
extra_capacity : 4096);\n  CHECK(cache_obj.contains(\"value_memory_kind\"));\n  CHECK(cache_obj[\"value_memory_kind\"].is_string());\n  std::string value_memory_kind = cache_obj[\"value_memory_kind\"].get<std::string>();\n  if (value_memory_kind == \"device\") {\n    cache_options->value_memory_kind = CacheOptions::MemoryKind::kDevice;\n  } else if (value_memory_kind == \"host\") {\n    cache_options->value_memory_kind = CacheOptions::MemoryKind::kHost;\n  } else {\n    UNIMPLEMENTED() << \"Unsupported cache value_memory_kind\";\n  }\n}\n\n}  // namespace\n\nclass KeyValueStoreOptions final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreOptions);\n  explicit KeyValueStoreOptions(const std::string& json_serialized) {\n    auto json_object = nlohmann::json::parse(json_serialized);\n\n    CHECK(json_object.contains(\"key_type_size\"));\n    CHECK(json_object[\"key_type_size\"].is_number());\n    key_type_size_ = json_object[\"key_type_size\"].get<int64_t>();\n\n    CHECK(json_object.contains(\"value_type_size\"));\n    CHECK(json_object[\"value_type_size\"].is_number());\n    std::string value_type_name = json_object[\"value_type\"];\n    if (value_type_name == \"oneflow.float\" || value_type_name == \"oneflow.float32\") {\n      value_type_ = DataType::kFloat;\n    } else {\n      UNIMPLEMENTED();\n    }\n    value_type_size_ = json_object[\"value_type_size\"].get<int64_t>();\n\n    CHECK(json_object.contains(\"parallel_num\"));\n    CHECK(json_object[\"parallel_num\"].is_number());\n    const int64_t parallel_num = json_object[\"parallel_num\"].get<int64_t>();\n\n    CHECK(json_object.contains(\"name\"));\n    CHECK(json_object[\"name\"].is_string());\n    name_ = json_object[\"name\"].get<std::string>();\n\n    CHECK(json_object.contains(\"storage_dim\"));\n    CHECK(json_object[\"storage_dim\"].is_number());\n    line_size_ = json_object[\"storage_dim\"].get<int64_t>();\n\n    CHECK(json_object.contains(\"kv_store\"));\n    auto kv_store = json_object[\"kv_store\"];\n\n    auto caches = kv_store[\"caches\"];\n    if (caches != nlohmann::detail::value_t::null && caches.size() > 0) {\n      CHECK(caches.is_array());\n      cache_options_.resize(caches.size());\n      for (int i = 0; i < caches.size(); ++i) {\n        cache_options_.at(i).key_size = key_type_size_;\n        cache_options_.at(i).value_size = value_type_size_ * line_size_;\n        cache_options_.at(i).value_type = value_type_;\n        ParseCacheOptions(caches.at(i), &cache_options_.at(i));\n      }\n    }\n\n    CHECK(kv_store.contains(\"persistent_table\"));\n    auto persistent_table = kv_store[\"persistent_table\"];\n    CHECK(persistent_table.contains(\"path\"));\n    auto path = persistent_table[\"path\"];\n    CHECK(path.is_array() || path.is_string());\n    if (path.is_array()) {\n      CHECK_EQ(path.size(), parallel_num);\n      for (int i = 0; i < path.size(); ++i) {\n        CHECK(path.at(i).is_string());\n        persistent_table_paths_.push_back(path.at(i).get<std::string>());\n      }\n    } else {\n      std::string root_path = path.get<std::string>();\n      const std::string& num_rank = std::to_string(parallel_num);\n      const int64_t rank_id_suffix_length = num_rank.size();\n      for (int i = 0; i < parallel_num; ++i) {\n        const std::string& rank_id = std::to_string(i);\n        const std::string rank_i_path = root_path + \"/\"\n                                        + std::string(rank_id_suffix_length - rank_id.size(), '0')\n                                        + rank_id + \"-\" + 
num_rank;\n        persistent_table_paths_.push_back(rank_i_path);\n      }\n    }\n    CHECK(persistent_table.contains(\"physical_block_size\"));\n    CHECK(persistent_table[\"physical_block_size\"].is_number());\n    persistent_table_physical_block_size_ = persistent_table[\"physical_block_size\"].get<int64_t>();\n    if (persistent_table.contains(\"capacity_hint\")) {\n      CHECK(persistent_table[\"capacity_hint\"].is_number());\n      persistent_table_capacity_hint_ = persistent_table[\"capacity_hint\"].get<int64_t>();\n    } else {\n      persistent_table_capacity_hint_ = 0;\n    }\n  }\n  ~KeyValueStoreOptions() = default;\n  int64_t KeyTypeSize() const { return key_type_size_; }\n  int64_t ValueTypeSize() const { return value_type_size_; }\n  DataType ValueType() const { return value_type_; }\n  const std::string& Name() const { return name_; }\n  int64_t LineSize() const { return line_size_; }\n  const std::vector<CacheOptions>& GetCachesOptions() const { return cache_options_; }\n  const std::vector<std::string>& PersistentTablePaths() const { return persistent_table_paths_; }\n  int64_t PersistentTablePhysicalBlockSize() const { return persistent_table_physical_block_size_; }\n  int64_t PersistentTableCapacityHint() const { return persistent_table_capacity_hint_; }\n  bool IsFullCache() const {\n    if (cache_options_.size() > 0 && cache_options_.at(0).policy == CacheOptions::Policy::kFull) {\n      return true;\n    }\n    return false;\n  }\n\n private:\n  int64_t key_type_size_;\n  int64_t value_type_size_;\n  DataType value_type_;\n  std::string name_;\n  int64_t line_size_;\n  std::vector<std::string> persistent_table_paths_;\n  int64_t persistent_table_physical_block_size_;\n  int64_t persistent_table_capacity_hint_;\n  std::vector<CacheOptions> cache_options_;\n};\n\n}  // namespace embedding\n}  // namespace oneflow\n#endif  // ONEFLOW_EMBEDDING_KEY_VALUE_STORE_OPTIONS_H_\n"
  },
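An invented configuration covering every field the KeyValueStoreOptions constructor and ParseCacheOptions check; the paths and sizes are examples only. Because `path` is a string and `parallel_num` is 2, the constructor derives the per-rank directories `/tmp/emb/0-2` and `/tmp/emb/1-2`; `cache_memory_budget_mb` is given instead of `capacity`, so the cache capacity is derived from the budget divided by the 512-byte row (4-byte floats times storage_dim 128).

#include "oneflow/core/embedding/key_value_store_options.h"

void ParseExample() {
  // Every key below is one the parsing code explicitly reads; when
  // cache_memory_budget_mb is set, capacity must be absent (or 0).
  const char* json = R"({
    "name": "sparse_embedding",
    "key_type_size": 8,
    "value_type": "oneflow.float",
    "value_type_size": 4,
    "storage_dim": 128,
    "parallel_num": 2,
    "kv_store": {
      "caches": [
        {
          "policy": "full",
          "cache_memory_budget_mb": 1024,
          "value_memory_kind": "device"
        }
      ],
      "persistent_table": {
        "path": "/tmp/emb",
        "physical_block_size": 512
      }
    }
  })";
  oneflow::embedding::KeyValueStoreOptions options(json);
}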
  {
    "path": "oneflow/core/embedding/key_value_store_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/embedding/persistent_table_key_value_store.h\"\n#include \"oneflow/core/embedding/cached_key_value_store.h\"\n#include \"oneflow/core/embedding/mock_key_value_store.h\"\n#include \"oneflow/core/embedding/cache.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/embedding/posix_file.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\nnamespace {\n\n#ifdef WITH_CUDA\n\nstd::string CreateTempDirectory() {\n  const char* tmp_env = getenv(\"TMPDIR\");\n  const char* tmp_dir = tmp_env == nullptr ? \"/tmp\" : tmp_env;\n  std::string tpl = std::string(tmp_dir) + \"/test_kv_XXXXXX\";\n  char* path = mkdtemp(const_cast<char*>(tpl.c_str()));\n  PCHECK(path != nullptr);\n  return std::string(path);\n}\n\nbool HasCudaDevice() {\n  int device_count = 0;\n  if (cudaGetDeviceCount(&device_count) != cudaSuccess) { return false; }\n  if (device_count <= 0) { return false; }\n  return true;\n}\n\nvoid TestKeyValueStore(KeyValueStore* store, size_t num_embeddings, size_t test_embeddings,\n                       size_t embedding_vec_size) {\n  auto device = Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, 0);\n  ep::Stream* stream = device->CreateStream();\n\n  store->SaveSnapshot(\"init\");\n\n  uint64_t* keys = nullptr;\n  float* values = nullptr;\n  float* values1 = nullptr;\n  uint64_t* keys_host = nullptr;\n  float* values_host = nullptr;\n  uint64_t* context = nullptr;\n  uint32_t* n_missing = nullptr;\n  uint32_t* host_n_missing = nullptr;\n  uint64_t* missing_keys = nullptr;\n  uint32_t* missing_indices = nullptr;\n  size_t keys_size = sizeof(uint64_t) * num_embeddings;\n  size_t values_size = sizeof(float) * embedding_vec_size * num_embeddings;\n  size_t context_size = sizeof(uint64_t) * num_embeddings;\n  const size_t batch_size = 128;\n  OF_CUDA_CHECK(cudaMalloc(&keys, keys_size));\n  OF_CUDA_CHECK(cudaMalloc(&values, values_size));\n  OF_CUDA_CHECK(cudaMalloc(&values1, values_size));\n  OF_CUDA_CHECK(cudaMalloc(&context, context_size));\n  OF_CUDA_CHECK(cudaMallocHost(&keys_host, keys_size));\n  OF_CUDA_CHECK(cudaMallocHost(&values_host, values_size));\n  OF_CUDA_CHECK(cudaMallocHost(&host_n_missing, sizeof(uint32_t)));\n  OF_CUDA_CHECK(cudaMalloc(&missing_keys, batch_size * sizeof(uint64_t)));\n  OF_CUDA_CHECK(cudaMalloc(&missing_indices, batch_size * sizeof(uint32_t)));\n  OF_CUDA_CHECK(cudaMalloc(&n_missing, sizeof(uint32_t)));\n  for (size_t i = 0; i < num_embeddings; ++i) {\n    uint64_t key = i + 1;\n    keys_host[i] = key;\n    for (size_t j = 0; j < embedding_vec_size; j++) {\n      values_host[i * embedding_vec_size + j] = key;\n    }\n  }\n  OF_CUDA_CHECK(cudaMemcpy(keys, keys_host, keys_size, cudaMemcpyDefault));\n  OF_CUDA_CHECK(cudaMemcpy(values, values_host, values_size, cudaMemcpyDefault));\n\n  
store->Put(stream, 0, keys, values);\n  OF_CUDA_CHECK(cudaDeviceSynchronize());\n  OF_CUDA_CHECK(cudaGetLastError());\n\n  for (size_t offset = 0; offset < test_embeddings; offset += batch_size) {\n    const size_t num_keys = std::min(batch_size, test_embeddings - offset);\n    store->Get(stream, num_keys, keys + offset, values1 + offset * embedding_vec_size, n_missing,\n               missing_indices);\n    OF_CUDA_CHECK(cudaMemcpy(host_n_missing, n_missing, sizeof(uint32_t), cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaDeviceSynchronize());\n    ASSERT_EQ(*host_n_missing, num_keys);\n    store->Put(stream, num_keys, keys + offset, values + offset * embedding_vec_size);\n  }\n\n  OF_CUDA_CHECK(cudaDeviceSynchronize());\n\n  store->SaveSnapshot(\"final\");\n\n  OF_CUDA_CHECK(cudaMemset(values_host, 0, values_size));\n  OF_CUDA_CHECK(cudaMemset(values, 0, values_size));\n  for (size_t offset = 0; offset < test_embeddings; offset += batch_size) {\n    const size_t num_keys = std::min(batch_size, test_embeddings - offset);\n    store->Get(stream, num_keys, keys + offset, values + offset * embedding_vec_size, n_missing,\n               missing_indices);\n    OF_CUDA_CHECK(cudaMemcpy(host_n_missing, n_missing, sizeof(uint32_t), cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaDeviceSynchronize());\n    ASSERT_EQ(*host_n_missing, 0);\n  }\n  OF_CUDA_CHECK(cudaMemcpy(values_host, values, values_size, cudaMemcpyDefault));\n  OF_CUDA_CHECK(cudaDeviceSynchronize());\n  for (size_t i = 0; i < test_embeddings; ++i) {\n    uint64_t key = keys_host[i];\n    for (size_t j = 0; j < embedding_vec_size; j++) {\n      ASSERT_EQ(values_host[i * embedding_vec_size + j], key);\n    }\n  }\n\n  store->LoadSnapshot(\"init\");\n\n  for (size_t offset = 0; offset < test_embeddings; offset += batch_size) {\n    const size_t num_keys = std::min(batch_size, test_embeddings - offset);\n    store->Get(stream, num_keys, keys + offset, values1 + offset * embedding_vec_size, n_missing,\n               missing_indices);\n    OF_CUDA_CHECK(cudaMemcpy(host_n_missing, n_missing, sizeof(uint32_t), cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaDeviceSynchronize());\n    ASSERT_EQ(*host_n_missing, num_keys);\n  }\n\n  store->LoadSnapshot(\"final\");\n\n  OF_CUDA_CHECK(cudaMemset(values_host, 0, values_size));\n  OF_CUDA_CHECK(cudaMemset(values, 0, values_size));\n  for (size_t offset = 0; offset < test_embeddings; offset += batch_size) {\n    const size_t num_keys = std::min(batch_size, test_embeddings - offset);\n    store->Get(stream, num_keys, keys + offset, values + offset * embedding_vec_size, n_missing,\n               missing_indices);\n    OF_CUDA_CHECK(cudaMemcpy(host_n_missing, n_missing, sizeof(uint32_t), cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaDeviceSynchronize());\n    ASSERT_EQ(*host_n_missing, 0);\n  }\n  OF_CUDA_CHECK(cudaMemcpy(values_host, values, values_size, cudaMemcpyDefault));\n  OF_CUDA_CHECK(cudaDeviceSynchronize());\n  for (size_t i = 0; i < test_embeddings; ++i) {\n    uint64_t key = keys_host[i];\n    for (size_t j = 0; j < embedding_vec_size; j++) {\n      ASSERT_EQ(values_host[i * embedding_vec_size + j], key);\n    }\n  }\n\n  OF_CUDA_CHECK(cudaDeviceSynchronize());\n  OF_CUDA_CHECK(cudaGetLastError());\n  OF_CUDA_CHECK(cudaFree(keys));\n  OF_CUDA_CHECK(cudaFree(values));\n  OF_CUDA_CHECK(cudaFree(values1));\n  OF_CUDA_CHECK(cudaFreeHost(keys_host));\n  OF_CUDA_CHECK(cudaFreeHost(values_host));\n  OF_CUDA_CHECK(cudaFreeHost(host_n_missing));\n  OF_CUDA_CHECK(cudaFree(n_missing));\n  
OF_CUDA_CHECK(cudaFree(missing_keys));\n  OF_CUDA_CHECK(cudaFree(missing_indices));\n  CHECK_JUST(stream->Sync());\n  device->DestroyStream(stream);\n}\n\nTEST(PersistentTableKeyValueStore, PersistentTableKeyValueStore) {\n  if (!HasCudaDevice()) { return; }\n  Singleton<ep::DeviceManagerRegistry>::New();\n  PersistentTableKeyValueStoreOptions options{};\n  uint32_t value_length = 128;\n\n  std::string path = CreateTempDirectory();\n  options.table_options.path = path;\n  options.table_options.value_size = value_length * sizeof(float);\n  options.table_options.key_size = GetSizeOfDataType(DataType::kUInt64);\n  options.table_options.physical_block_size = 512;\n\n  std::unique_ptr<KeyValueStore> store = NewPersistentTableKeyValueStore(options);\n  store->ReserveQueryLength(128);\n  TestKeyValueStore(store.get(), 1024, 1024, value_length);\n  store.reset();\n  PosixFile::RecursiveDelete(path);\n  Singleton<ep::DeviceManagerRegistry>::Delete();\n}\n\nTEST(CachedKeyValueStore, LRU) {\n  if (!HasCudaDevice()) { return; }\n  Singleton<ep::DeviceManagerRegistry>::New();\n  PersistentTableKeyValueStoreOptions store_options{};\n  std::string path = CreateTempDirectory();\n  store_options.table_options.path = path;\n  uint32_t value_length = 128;\n  store_options.table_options.value_size = value_length * sizeof(float);\n  store_options.table_options.key_size = GetSizeOfDataType(DataType::kUInt64);\n  store_options.table_options.physical_block_size = 512;\n  std::unique_ptr<KeyValueStore> store = NewPersistentTableKeyValueStore(store_options);\n  CacheOptions cache_options{};\n  cache_options.policy = CacheOptions::Policy::kLRU;\n  cache_options.value_memory_kind = CacheOptions::MemoryKind::kDevice;\n  cache_options.value_size = 512;\n  cache_options.capacity = 512;\n  cache_options.key_size = 8;\n  std::unique_ptr<Cache> cache = NewCache(cache_options);\n  std::unique_ptr<KeyValueStore> cached_store =\n      NewCachedKeyValueStore(std::move(store), std::move(cache));\n  cached_store->ReserveQueryLength(128);\n  TestKeyValueStore(cached_store.get(), 1024, 1024, value_length);\n  cached_store.reset();\n  PosixFile::RecursiveDelete(path);\n  Singleton<ep::DeviceManagerRegistry>::Delete();\n}\n\nTEST(CachedKeyValueStore, Full) {\n  if (!HasCudaDevice()) { return; }\n  Singleton<ep::DeviceManagerRegistry>::New();\n  PersistentTableKeyValueStoreOptions store_options{};\n  std::string path = CreateTempDirectory();\n  store_options.table_options.path = path;\n  uint32_t value_length = 128;\n  store_options.table_options.value_size = value_length * sizeof(float);\n  store_options.table_options.key_size = GetSizeOfDataType(DataType::kUInt64);\n  store_options.table_options.physical_block_size = 512;\n  std::unique_ptr<KeyValueStore> store = NewPersistentTableKeyValueStore(store_options);\n  CacheOptions cache_options{};\n  cache_options.policy = CacheOptions::Policy::kFull;\n  cache_options.value_memory_kind = CacheOptions::MemoryKind::kHost;\n  cache_options.value_size = 512;\n  cache_options.capacity = 1024 * 2;\n  cache_options.key_size = 8;\n  std::unique_ptr<Cache> cache = NewCache(cache_options);\n  std::unique_ptr<KeyValueStore> cached_store =\n      NewCachedKeyValueStore(std::move(store), std::move(cache));\n  cached_store->ReserveQueryLength(128);\n  TestKeyValueStore(cached_store.get(), 1024, 1024, value_length);\n  cached_store.reset();\n  PosixFile::RecursiveDelete(path);\n  Singleton<ep::DeviceManagerRegistry>::Delete();\n}\n\nTEST(MockKeyValueStore, Mock) {\n  if (!HasCudaDevice()) { return; 
}\n  Singleton<ep::DeviceManagerRegistry>::New();\n  MockKeyValueStoreOptions store_options{};\n  std::string path = CreateTempDirectory();\n  uint32_t value_length = 128;\n  store_options.value_size = value_length * sizeof(float);\n  store_options.key_size = GetSizeOfDataType(DataType::kUInt64);\n  std::unique_ptr<KeyValueStore> store = NewMockKeyValueStore(store_options);\n  store->ReserveQueryLength(128);\n  TestKeyValueStore(store.get(), 1024, 1024, value_length);\n  store.reset();\n  PosixFile::RecursiveDelete(path);\n  Singleton<ep::DeviceManagerRegistry>::Delete();\n}\n\n#endif  // WITH_CUDA\n\n}  // namespace\n\n}  // namespace embedding\n\n}  // namespace oneflow\n"
  },
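The suite above reaches the full cache only through CachedKeyValueStore. A hypothetical direct test of the mask-returning Cache::Get overload from full_cache.cu, following the suite's allocation pattern (the key fill, the preceding Put, and the value assertions are abbreviated to comments):

TEST(FullCache, GetWithMaskSketch) {
  if (!HasCudaDevice()) { return; }
  Singleton<ep::DeviceManagerRegistry>::New();
  auto device = Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, 0);
  ep::Stream* stream = device->CreateStream();
  CacheOptions options{};
  options.policy = CacheOptions::Policy::kFull;
  options.value_memory_kind = CacheOptions::MemoryKind::kDevice;
  options.key_size = sizeof(uint64_t);
  options.value_size = 4 * sizeof(float);
  options.value_type = DataType::kFloat;
  options.capacity = 1024;
  options.load_factor = 0.75;  // explicit for clarity; the suite relies on the default
  std::unique_ptr<Cache> cache = NewCache(options);
  cache->ReserveQueryLength(16);
  uint64_t* keys = nullptr;
  float* values = nullptr;
  uint8_t* mask = nullptr;
  OF_CUDA_CHECK(cudaMalloc(&keys, 16 * sizeof(uint64_t)));
  OF_CUDA_CHECK(cudaMalloc(&values, 16 * options.value_size));
  OF_CUDA_CHECK(cudaMalloc(&mask, 16));
  // ... fill `keys`, Put a subset of them, then:
  cache->Get(stream, 16, keys, values, mask);  // mask[i] == 1 iff keys[i] hit
  CHECK_JUST(stream->Sync());
  // ... copy `mask`/`values` to the host and assert on the subset ...
  OF_CUDA_CHECK(cudaFree(keys));
  OF_CUDA_CHECK(cudaFree(values));
  OF_CUDA_CHECK(cudaFree(mask));
  device->DestroyStream(stream);
  Singleton<ep::DeviceManagerRegistry>::Delete();
}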
  {
    "path": "oneflow/core/embedding/kv_iterator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EMBEDDING_KV_ITERATOR_H_\n#define ONEFLOW_CORE_EMBEDDING_KV_ITERATOR_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\nclass KVIterator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(KVIterator);\n  KVIterator() = default;\n  virtual ~KVIterator() = default;\n\n  virtual void NextN(ep::Stream* stream, uint32_t n_request, uint32_t* n_result, void* keys,\n                     void* values) = 0;\n  virtual void Reset() = 0;\n};\n\n}  // namespace embedding\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EMBEDDING_KV_ITERATOR_H_\n"
  },
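KVIterator is consumed through the hook overload of KeyValueStore::LoadSnapshot above. A hypothetical hook body; the batch size, the buffer handling, and the assumption that NextN reports zero results once the iterator is exhausted are all illustrative:

#include <cstdint>
#include "oneflow/core/embedding/kv_iterator.h"

// Drains the iterator in fixed-size batches. `keys` and `values` must be
// caller-provided buffers with room for one batch; reading *n_result is only
// safe after the stream work completes, which is abbreviated to a comment.
void DrainIterator(oneflow::embedding::KVIterator* iter, oneflow::ep::Stream* stream,
                   uint32_t batch_size, void* keys, void* values, uint32_t* n_result) {
  iter->Reset();
  while (true) {
    iter->NextN(stream, batch_size, n_result, keys, values);
    // ... synchronize `stream` so *n_result and the buffers are visible ...
    if (*n_result == 0) { break; }  // assumed exhaustion signal
    // ... consume *n_result key/value pairs from keys/values ...
  }
}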
  {
    "path": "oneflow/core/embedding/lru_cache.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n// Inspired by https://github.com/NVIDIA-Merlin/HugeCTR/blob/master/gpu_cache/src/nv_gpu_cache.cu\n\n#include \"oneflow/core/embedding/lru_cache.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/embedding/hash_functions.cuh\"\n#include <new>\n#include <cuda.h>\n\n#if CUDA_VERSION >= 11000 && ((!defined(__CUDA_ARCH__)) || (__CUDA_ARCH__ >= 700)) \\\n    && !(defined(__clang__) && defined(__CUDA__))\n#include <cuda/std/semaphore>\n#endif\n\nnamespace oneflow {\n\nnamespace embedding {\n\nnamespace {\n\nconstexpr int kWarpSize = 32;\nconstexpr int kNumWarpPerBlock = 4;\nconstexpr int kBlockSize = kNumWarpPerBlock * kWarpSize;\nconstexpr uint32_t kFullMask = 0xFFFFFFFFU;\n\nep::CudaLaunchConfig GetLaunchConfig(uint32_t n_keys) {\n  return ep::CudaLaunchConfig((n_keys + kNumWarpPerBlock - 1) / kNumWarpPerBlock,\n                              kWarpSize * kNumWarpPerBlock, 0);\n}\n\nstruct ThreadContext {\n  __device__ ThreadContext() {\n    const uint32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;\n    global_warp_id = global_thread_id / kWarpSize;\n    warp_id_in_block = global_warp_id % kNumWarpPerBlock;  // NOLINT\n    num_warps = gridDim.x * kNumWarpPerBlock;              // NOLINT\n    lane_id = global_thread_id % kWarpSize;\n  }\n\n  uint32_t global_warp_id;\n  uint32_t warp_id_in_block;\n  uint32_t num_warps;\n  uint32_t lane_id;\n};\n\nclass WarpMutexAtomicImpl {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(WarpMutexAtomicImpl);\n  __device__ WarpMutexAtomicImpl() : flag_(0) {}\n  __device__ ~WarpMutexAtomicImpl() = default;\n\n  __device__ void Lock(const ThreadContext& thread_ctx) {\n    if (thread_ctx.lane_id == 0) {\n      while (atomicCAS(&flag_, 0, 1) != 0)\n        ;\n    }\n    __threadfence();\n    __syncwarp();\n  }\n\n  __device__ void Unlock(const ThreadContext& thread_ctx) {\n    __syncwarp();\n    __threadfence();\n    if (thread_ctx.lane_id == 0) { atomicExch(&flag_, 0); }\n  }\n\n private:\n  int32_t flag_;\n};\n\n#if CUDA_VERSION >= 11000 && ((!defined(__CUDA_ARCH__)) || (__CUDA_ARCH__ >= 700)) \\\n    && !(defined(__clang__) && defined(__CUDA__))\n\nclass WarpMutexSemaphoreImpl {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(WarpMutexSemaphoreImpl);\n  __device__ WarpMutexSemaphoreImpl() : semaphore_(1) {}\n  __device__ ~WarpMutexSemaphoreImpl() = default;\n\n  __device__ void Lock(const ThreadContext& thread_ctx) {\n    if (thread_ctx.lane_id == 0) { semaphore_.acquire(); }\n    __syncwarp();\n  }\n\n  __device__ void Unlock(const ThreadContext& thread_ctx) {\n    __syncwarp();\n    if (thread_ctx.lane_id == 0) { semaphore_.release(); }\n  }\n\n private:\n  cuda::binary_semaphore<cuda::thread_scope_device> semaphore_;\n};\n\n#endif\n\ntemplate<typename Key, typename Elem>\nstruct LruCacheContext {\n  Key* keys;\n  Elem* lines;\n  uint8_t* ages;\n  void* mutex;\n  uint64_t n_set;\n  uint32_t line_size;\n  CacheOptions::MemoryKind 
value_memory_kind;\n};\n\n__global__ void InitCacheSetMutex(uint32_t n_set, void* mutex) {\n#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700 && !(defined(__clang__) && defined(__CUDA__))\n  using WarpMutex = WarpMutexSemaphoreImpl;\n#else\n  using WarpMutex = WarpMutexAtomicImpl;\n#endif  // CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700 && !(defined(__clang__) &&\n        // defined(__CUDA__))\n  const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (idx < n_set) { new (reinterpret_cast<WarpMutex*>(mutex) + idx) WarpMutex; }\n}\n\ntemplate<typename Key, typename Elem>\nvoid ClearLruCacheContext(LruCacheContext<Key, Elem>* ctx) {\n  OF_CUDA_CHECK(cudaMemset(ctx->keys, 0, ctx->n_set * kWarpSize * sizeof(Key)));\n  OF_CUDA_CHECK(cudaMemset(ctx->ages, 0, ctx->n_set * kWarpSize * sizeof(uint8_t)));\n  InitCacheSetMutex<<<(ctx->n_set - 1 + 256) / 256, 256>>>(ctx->n_set, ctx->mutex);\n}\n\ntemplate<typename Key, typename Elem>\nvoid InitLruCacheContext(const CacheOptions& options, LruCacheContext<Key, Elem>* ctx) {\n  const size_t keys_size_per_set = kWarpSize * sizeof(Key);\n  const uint32_t line_size = options.value_size / sizeof(Elem);\n  const size_t lines_size_per_set = kWarpSize * line_size * sizeof(Elem);\n  const size_t ages_size_per_set = kWarpSize * sizeof(uint8_t);\n  int device = 0;\n  OF_CUDA_CHECK(cudaGetDevice(&device));\n  int major = 0;\n  OF_CUDA_CHECK(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));\n  size_t mutex_size_per_set = 0;\n#if CUDA_VERSION >= 11000 && !(defined(__clang__) && defined(__CUDA__))\n  if (major >= 7) {\n#if !defined(__CUDA_ARCH__)\n    mutex_size_per_set = sizeof(WarpMutexSemaphoreImpl);\n#else\n    UNIMPLEMENTED();\n#endif\n  } else {\n    mutex_size_per_set = sizeof(WarpMutexAtomicImpl);\n  }\n#else\n  mutex_size_per_set = sizeof(WarpMutexAtomicImpl);\n#endif  // CUDA_VERSION >= 11000 && !(defined(__clang__) && defined(__CUDA__))\n  const size_t n_set = (options.capacity - 1 + kWarpSize) / kWarpSize;\n  CHECK_GT(n_set, 0);\n  ctx->n_set = n_set;\n  ctx->line_size = line_size;\n  const size_t keys_size = n_set * keys_size_per_set;\n  OF_CUDA_CHECK(cudaMalloc(&(ctx->keys), keys_size));\n  const size_t lines_size = n_set * lines_size_per_set;\n  if (options.value_memory_kind == CacheOptions::MemoryKind::kDevice) {\n    OF_CUDA_CHECK(cudaMalloc(&(ctx->lines), lines_size));\n  } else if (options.value_memory_kind == CacheOptions::MemoryKind::kHost) {\n    if (ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION\", false)) {\n      OF_CUDA_CHECK(cudaMallocHost(&(ctx->lines), lines_size));\n    } else {\n      OF_CUDA_CHECK(\n          NumaAwareCudaMallocHost(device, reinterpret_cast<void**>(&ctx->lines), lines_size));\n    }\n  } else {\n    UNIMPLEMENTED();\n  }\n  ctx->value_memory_kind = options.value_memory_kind;\n  const size_t ages_size = n_set * ages_size_per_set;\n  OF_CUDA_CHECK(cudaMalloc(&(ctx->ages), ages_size));\n  const size_t mutex_size = n_set * mutex_size_per_set;\n  OF_CUDA_CHECK(cudaMalloc(&(ctx->mutex), mutex_size));\n\n  ClearLruCacheContext(ctx);\n}\n\ntemplate<typename Key, typename Elem>\nvoid DestroyLruCacheContext(LruCacheContext<Key, Elem>* ctx) {\n  OF_CUDA_CHECK(cudaFree(ctx->keys));\n  if (ctx->value_memory_kind == CacheOptions::MemoryKind::kDevice) {\n    OF_CUDA_CHECK(cudaFree(ctx->lines));\n  } else if (ctx->value_memory_kind == CacheOptions::MemoryKind::kHost) {\n    OF_CUDA_CHECK(cudaFreeHost(ctx->lines));\n  } else {\n    UNIMPLEMENTED();\n  }\n  
OF_CUDA_CHECK(cudaFree(ctx->ages));\n  OF_CUDA_CHECK(cudaFree(ctx->mutex));\n}\n\ntemplate<typename Key, typename Elem>\nstruct SetContext {\n#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700 && !(defined(__clang__) && defined(__CUDA__))\n  using WarpMutex = WarpMutexSemaphoreImpl;\n#else\n  using WarpMutex = WarpMutexAtomicImpl;\n#endif  // CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700 && !(defined(__clang__) &&\n        // defined(__CUDA__))\n  __device__ SetContext(const LruCacheContext<Key, Elem>& ctx, uint32_t set_id)\n      : keys(ctx.keys + set_id * kWarpSize),\n        mutex(reinterpret_cast<WarpMutex*>(ctx.mutex) + set_id),\n        ages(ctx.ages + set_id * kWarpSize),\n        lines(ctx.lines + static_cast<size_t>(set_id) * kWarpSize * ctx.line_size) {}\n\n  __device__ int Lookup(const ThreadContext& thread_ctx, Key key) {\n    const Key lane_key = keys[thread_ctx.lane_id];\n    const int lane_age = ages[thread_ctx.lane_id];\n    const bool lane_hit = (lane_key == key && lane_age != 0);\n    const unsigned hit_mask = __ballot_sync(kFullMask, lane_hit);\n    if (hit_mask != 0) {\n      return __ffs(static_cast<int>(hit_mask)) - 1;\n    } else {\n      return -1;\n    }\n  }\n\n  __device__ void Read(const LruCacheContext<Key, Elem>& cache_ctx, const ThreadContext& thread_ctx,\n                       int way, Elem* line) {\n    const Elem* from_line = lines + way * cache_ctx.line_size;\n    for (int i = thread_ctx.lane_id; i < cache_ctx.line_size; i += kWarpSize) {\n      line[i] = from_line[i];\n    }\n  }\n\n  __device__ int InsertWithoutEvicting(const LruCacheContext<Key, Elem>& cache_ctx,\n                                       const ThreadContext& thread_ctx, Key key) {\n    int insert_way = -1;\n    const Key lane_key = keys[thread_ctx.lane_id];\n    int lane_age = ages[thread_ctx.lane_id];\n    const unsigned hit_mask = __ballot_sync(kFullMask, lane_key == key && lane_age != 0);\n    if (hit_mask != 0) {\n      insert_way = __ffs(static_cast<int>(hit_mask)) - 1;\n      const int insert_way_age = __shfl_sync(kFullMask, lane_age, insert_way);\n      if (lane_age > insert_way_age) {\n        lane_age -= 1;\n      } else if (thread_ctx.lane_id == insert_way) {\n        lane_age = kWarpSize;\n      }\n      __syncwarp();\n    }\n    if (insert_way == -1) {\n      const unsigned valid_mask = __ballot_sync(kFullMask, lane_age != 0);\n      if (valid_mask != kFullMask) {\n        insert_way = __popc(static_cast<int>(valid_mask));\n        if (lane_age > 0) {\n          lane_age -= 1;\n        } else if (thread_ctx.lane_id == insert_way) {\n          lane_age = kWarpSize;\n          keys[insert_way] = key;\n        }\n        __syncwarp();\n      }\n    }\n    if (insert_way != -1) { ages[thread_ctx.lane_id] = lane_age; }\n    return insert_way;\n  }\n\n  __device__ void Evict(const LruCacheContext<Key, Elem>& cache_ctx,\n                        const ThreadContext& thread_ctx, Key key, int* way, Key* evicted_key) {\n    const Key lane_key = keys[thread_ctx.lane_id];\n    int lane_age = ages[thread_ctx.lane_id];\n    const int insert_way = __ffs(__ballot_sync(kFullMask, lane_age == 1)) - 1;\n    *evicted_key = __shfl_sync(kFullMask, lane_key, insert_way);\n    if (thread_ctx.lane_id == insert_way) {\n      keys[insert_way] = key;\n      lane_age = kWarpSize;\n    } else if (lane_age > 1) {\n      lane_age -= 1;\n    }\n    __syncwarp();\n    ages[thread_ctx.lane_id] = lane_age;\n    *way = insert_way;\n  }\n\n  __device__ void Write(const LruCacheContext<Key, Elem>& cache_ctx,\n      
                  const ThreadContext& thread_ctx, int way, const Elem* line) {\n    Elem* to_line = lines + way * cache_ctx.line_size;\n    for (int i = thread_ctx.lane_id; i < cache_ctx.line_size; i += kWarpSize) {\n      to_line[i] = line[i];\n    }\n  }\n\n  __device__ void Lock(const ThreadContext& thread_ctx) { mutex->Lock(thread_ctx); }\n\n  __device__ void Unlock(const ThreadContext& thread_ctx) { mutex->Unlock(thread_ctx); }\n\n  Key* keys;\n  Elem* lines;\n  uint8_t* ages;\n  WarpMutex* mutex;\n};\n\ntemplate<typename Key, typename Elem, bool test_only>\n__global__ void GetKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_keys, const Key* keys,\n                          Elem* values, uint32_t* n_missing_keys, Key* missing_keys,\n                          uint32_t* missing_indices) {\n  ThreadContext thread_ctx{};\n  __shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];\n  __shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];\n  for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_keys;\n       batch_offset += thread_ctx.num_warps * kWarpSize) {\n    const uint32_t n_batch_keys = min(kWarpSize, num_keys - batch_offset);\n    if (thread_ctx.lane_id < n_batch_keys) {\n      const Key key = keys[batch_offset + thread_ctx.lane_id];\n      const size_t hash = LruCacheHash()(key);\n      const uint32_t set_id = hash % cache_ctx.n_set;\n      block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;\n      block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;\n    }\n    __syncwarp();\n    uint32_t n_warp_missing = 0;\n    Key warp_missing_key = 0;\n    uint32_t warp_missing_index = 0;\n    for (uint32_t i = 0; i < n_batch_keys; ++i) {\n      const uint32_t key_idx = batch_offset + i;\n      const Key key = block_keys[thread_ctx.warp_id_in_block][i];\n      const size_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];\n      SetContext<Key, Elem> set_ctx(cache_ctx, set_id);\n      const int way = set_ctx.Lookup(thread_ctx, key);\n      if (way < 0) {\n        if (thread_ctx.lane_id == n_warp_missing) {\n          warp_missing_key = key;\n          warp_missing_index = key_idx;\n        }\n        __syncwarp();\n        n_warp_missing += 1;\n      } else if (!test_only) {\n        set_ctx.Read(cache_ctx, thread_ctx, way, values + key_idx * cache_ctx.line_size);\n      }\n    }\n    if (n_warp_missing > 0) {\n      uint32_t base_missing_idx = 0;\n      if (thread_ctx.lane_id == 0) { base_missing_idx = atomicAdd(n_missing_keys, n_warp_missing); }\n      __syncwarp();\n      base_missing_idx = __shfl_sync(kFullMask, base_missing_idx, 0);\n      if (thread_ctx.lane_id < n_warp_missing) {\n        missing_keys[base_missing_idx + thread_ctx.lane_id] = warp_missing_key;\n        missing_indices[base_missing_idx + thread_ctx.lane_id] = warp_missing_index;\n      }\n      __syncwarp();\n    }\n    __syncwarp();\n  }\n}\n\ntemplate<typename Key, typename Elem>\n__global__ void PutWithoutEvictingKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_keys,\n                                         const Key* keys, const Elem* values, uint32_t* n_missing,\n                                         Key* missing_keys, uint32_t* missing_indices) {\n  ThreadContext thread_ctx{};\n  __shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];\n  __shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];\n  for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_keys;\n       batch_offset 
+= thread_ctx.num_warps * kWarpSize) {\n    const uint32_t n_batch_keys = min(kWarpSize, num_keys - batch_offset);\n    if (thread_ctx.lane_id < n_batch_keys) {\n      const Key key = keys[batch_offset + thread_ctx.lane_id];\n      const size_t hash = LruCacheHash()(key);\n      const uint32_t set_id = hash % cache_ctx.n_set;\n      block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;\n      block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;\n    }\n    __syncwarp();\n    uint32_t n_warp_missing = 0;\n    Key warp_missing_key = 0;\n    uint32_t warp_missing_index = 0;\n    for (uint32_t i = 0; i < n_batch_keys; ++i) {\n      const uint32_t key_idx = batch_offset + i;\n      const Key key = block_keys[thread_ctx.warp_id_in_block][i];\n      const size_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];\n      SetContext<Key, Elem> set_ctx(cache_ctx, set_id);\n      set_ctx.Lock(thread_ctx);\n      const int insert_way = set_ctx.InsertWithoutEvicting(cache_ctx, thread_ctx, key);\n      if (insert_way >= 0) {\n        set_ctx.Write(cache_ctx, thread_ctx, insert_way, values + cache_ctx.line_size * key_idx);\n      } else {\n        // lane n_warp_missing stages the n-th missing key of this warp\n        if (thread_ctx.lane_id == n_warp_missing) {\n          warp_missing_key = key;\n          warp_missing_index = key_idx;\n        }\n        __syncwarp();\n        n_warp_missing += 1;\n      }\n      set_ctx.Unlock(thread_ctx);\n    }\n    if (n_warp_missing > 0) {\n      uint32_t base_missing_idx = 0;\n      if (thread_ctx.lane_id == 0) { base_missing_idx = atomicAdd(n_missing, n_warp_missing); }\n      __syncwarp();\n      base_missing_idx = __shfl_sync(kFullMask, base_missing_idx, 0);\n      if (thread_ctx.lane_id < n_warp_missing) {\n        missing_keys[base_missing_idx + thread_ctx.lane_id] = warp_missing_key;\n        missing_indices[base_missing_idx + thread_ctx.lane_id] = warp_missing_index;\n      }\n      __syncwarp();\n    }\n  }\n}\n\ntemplate<typename Key, typename Elem>\n__global__ void EvictKernel(LruCacheContext<Key, Elem> cache_ctx, const Key* keys,\n                            const uint32_t* indices, const Elem* values, const uint32_t* n_evict,\n                            Key* evicted_keys, Elem* evicted_values) {\n  ThreadContext thread_ctx{};\n  uint32_t num_evict = *n_evict;\n  __shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];\n  __shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];\n  for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_evict;\n       batch_offset += thread_ctx.num_warps * kWarpSize) {\n    const uint32_t n_batch_keys = min(kWarpSize, num_evict - batch_offset);\n    if (thread_ctx.lane_id < n_batch_keys) {\n      const Key key = keys[batch_offset + thread_ctx.lane_id];\n      const size_t hash = LruCacheHash()(key);\n      const uint32_t set_id = hash % cache_ctx.n_set;\n      block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;\n      block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;\n    }\n    __syncwarp();\n    for (uint32_t i = 0; i < n_batch_keys; ++i) {\n      const uint32_t key_idx = batch_offset + i;\n      const Key key = block_keys[thread_ctx.warp_id_in_block][i];\n      const uint32_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];\n      SetContext<Key, Elem> set_ctx(cache_ctx, set_id);\n      set_ctx.Lock(thread_ctx);\n      int evicted_way = -1;\n      Key evicted_key = 0;\n      set_ctx.Evict(cache_ctx, thread_ctx, key, &evicted_way, 
&evicted_key);\n      if (thread_ctx.lane_id == 0) { evicted_keys[key_idx] = evicted_key; }\n      __syncwarp();\n      set_ctx.Read(cache_ctx, thread_ctx, evicted_way,\n                   evicted_values + cache_ctx.line_size * key_idx);\n      set_ctx.Write(cache_ctx, thread_ctx, evicted_way,\n                    values + cache_ctx.line_size * indices[key_idx]);\n      set_ctx.Unlock(thread_ctx);\n    }\n  }\n}\n\ntemplate<typename Key, typename Elem>\n__global__ void DumpKernel(LruCacheContext<Key, Elem> cache_ctx, size_t start_key_index,\n                           size_t end_key_index, uint32_t* n_dumped, Key* keys, Elem* values) {\n  ThreadContext thread_ctx{};\n  __shared__ Key warp_keys[kNumWarpPerBlock][kWarpSize];\n  __shared__ uint8_t warp_ages[kNumWarpPerBlock][kWarpSize];\n  for (uint32_t warp_start_key_index = start_key_index + thread_ctx.global_warp_id * kWarpSize;\n       warp_start_key_index < end_key_index;\n       warp_start_key_index += thread_ctx.num_warps * kWarpSize) {\n    Key lane_key = 0;\n    uint8_t lane_age = 0;\n    if (warp_start_key_index + thread_ctx.lane_id < end_key_index) {\n      lane_key = cache_ctx.keys[warp_start_key_index + thread_ctx.lane_id];\n      lane_age = cache_ctx.ages[warp_start_key_index + thread_ctx.lane_id];\n    }\n    __syncwarp();\n    warp_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_key;\n    warp_ages[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_age;\n    const int key_count = __popc(__ballot_sync(kFullMask, lane_age != 0));\n    if (key_count == 0) { continue; }\n    uint32_t offset = 0;\n    if (thread_ctx.lane_id == 0) { offset = atomicAdd(n_dumped, key_count); }\n    offset = __shfl_sync(kFullMask, offset, 0);\n    __syncwarp();\n    for (uint32_t i = 0; i < kWarpSize; ++i) {\n      const Key key = warp_keys[thread_ctx.warp_id_in_block][i];\n      const uint8_t age = warp_ages[thread_ctx.warp_id_in_block][i];\n      if (age == 0) { continue; }\n      if (thread_ctx.lane_id == 0) { keys[offset] = key; }\n      __syncwarp();\n      for (uint32_t j = thread_ctx.lane_id; j < cache_ctx.line_size; j += kWarpSize) {\n        values[offset * cache_ctx.line_size + j] =\n            cache_ctx\n                .lines[static_cast<size_t>(warp_start_key_index + i) * cache_ctx.line_size + j];\n      }\n      __syncwarp();\n      offset += 1;\n    }\n  }\n}\n\ntemplate<typename Key, typename Elem>\nclass LruCache : public Cache {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LruCache);\n  explicit LruCache(const CacheOptions& options)\n      : device_index_{},\n        max_query_length_(0),\n        query_indices_buffer_(nullptr),\n        query_keys_buffer_(nullptr),\n        value_type_(options.value_type) {\n    OF_CUDA_CHECK(cudaGetDevice(&device_index_));\n    InitLruCacheContext(options, &ctx_);\n  }\n  ~LruCache() override {\n    CudaCurrentDeviceGuard guard(device_index_);\n    if (max_query_length_ != 0) {\n      OF_CUDA_CHECK(cudaFree(query_indices_buffer_));\n      OF_CUDA_CHECK(cudaFree(query_keys_buffer_));\n    }\n    DestroyLruCacheContext(&ctx_);\n  }\n\n  uint32_t KeySize() const override { return sizeof(Key); }\n  uint32_t ValueSize() const override { return sizeof(Elem) * ctx_.line_size; }\n  DataType ValueType() const override { return value_type_; }\n  uint64_t Capacity() const override { return ctx_.n_set * kWarpSize; }\n  uint32_t MaxQueryLength() const override { return max_query_length_; }\n\n  void ReserveQueryLength(uint32_t query_length) override {\n    CudaCurrentDeviceGuard guard(device_index_);\n  
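  // Grow-only: keep the existing buffers when they can already hold the query.\n  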
  if (query_length <= max_query_length_) { return; }\n    if (max_query_length_ != 0) {\n      OF_CUDA_CHECK(cudaFree(query_indices_buffer_));\n      OF_CUDA_CHECK(cudaFree(query_keys_buffer_));\n    }\n    OF_CUDA_CHECK(cudaMalloc(&query_indices_buffer_, query_length * sizeof(uint32_t)));\n    OF_CUDA_CHECK(cudaMalloc(&query_keys_buffer_, query_length * sizeof(Key)));\n    max_query_length_ = query_length;\n  }\n\n  CacheOptions::Policy Policy() const override { return CacheOptions::Policy::kLRU; }\n\n  void Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing,\n            void* missing_keys, uint32_t* missing_indices) override {\n    CHECK_LE(n_keys, max_query_length_);\n    auto cuda_stream = stream->As<ep::CudaStream>();\n    OF_CUDA_CHECK(cudaMemsetAsync(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));\n    if (n_keys == 0) { return; }\n    cuda_stream->LaunchKernel(GetKernel<Key, Elem, true>, GetLaunchConfig(n_keys), ctx_, n_keys,\n                              static_cast<const Key*>(keys), nullptr, n_missing,\n                              static_cast<Key*>(missing_keys), missing_indices);\n  }\n\n  using Cache::Get;\n  void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint32_t* n_missing,\n           void* missing_keys, uint32_t* missing_indices) override {\n    CHECK_LE(n_keys, max_query_length_);\n    auto cuda_stream = stream->As<ep::CudaStream>();\n    OF_CUDA_CHECK(cudaMemsetAsync(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));\n    if (n_keys == 0) { return; }\n    cuda_stream->LaunchKernel(GetKernel<Key, Elem, false>, GetLaunchConfig(n_keys), ctx_, n_keys,\n                              static_cast<const Key*>(keys), static_cast<Elem*>(values), n_missing,\n                              static_cast<Key*>(missing_keys), missing_indices);\n  }\n\n  void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,\n           uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override {\n    CHECK_LE(n_keys, max_query_length_);\n    auto cuda_stream = stream->As<ep::CudaStream>();\n    OF_CUDA_CHECK(cudaMemsetAsync(n_evicted, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));\n    if (n_keys == 0) { return; }\n    cuda_stream->LaunchKernel(PutWithoutEvictingKernel<Key, Elem>, GetLaunchConfig(n_keys), ctx_,\n                              n_keys, static_cast<const Key*>(keys),\n                              static_cast<const Elem*>(values), n_evicted, query_keys_buffer_,\n                              query_indices_buffer_);\n    cuda_stream->LaunchKernel(EvictKernel<Key, Elem>, GetLaunchConfig(n_keys), ctx_,\n                              query_keys_buffer_, query_indices_buffer_,\n                              static_cast<const Elem*>(values), n_evicted,\n                              static_cast<Key*>(evicted_keys), static_cast<Elem*>(evicted_values));\n  }\n\n  void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,\n            uint32_t* n_dumped, void* keys, void* values) override {\n    auto cuda_stream = stream->As<ep::CudaStream>();\n    OF_CUDA_CHECK(cudaMemsetAsync(n_dumped, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));\n    const uint64_t max_dump_keys = end_key_index - start_key_index;\n    cuda_stream->LaunchKernel(\n        DumpKernel<Key, Elem>,\n        ep::CudaLaunchConfig((max_dump_keys + kNumWarpPerBlock - 1) / kNumWarpPerBlock, kBlockSize,\n                             0),\n        ctx_, start_key_index, end_key_index, 
n_dumped, static_cast<Key*>(keys),\n        static_cast<Elem*>(values));\n  }\n\n  void ClearDirtyFlags() override {\n    // do nothing.\n    return;\n  }\n\n  void Clear() override { ClearLruCacheContext<Key, Elem>(&ctx_); }\n\n private:\n  int device_index_;\n  uint32_t max_query_length_;\n  LruCacheContext<Key, Elem> ctx_;\n  uint32_t* query_indices_buffer_;\n  Key* query_keys_buffer_;\n  DataType value_type_;\n};\n\ntemplate<typename Key>\nstd::unique_ptr<Cache> DispatchValueType(const CacheOptions& options) {\n  if (options.value_size % sizeof(ulonglong2) == 0) {\n    return std::unique_ptr<Cache>(new LruCache<Key, ulonglong2>(options));\n  } else if (options.value_size % sizeof(uint64_t) == 0) {\n    return std::unique_ptr<Cache>(new LruCache<Key, uint64_t>(options));\n  } else if (options.value_size % sizeof(uint32_t) == 0) {\n    return std::unique_ptr<Cache>(new LruCache<Key, uint32_t>(options));\n  } else if (options.value_size % sizeof(uint16_t) == 0) {\n    return std::unique_ptr<Cache>(new LruCache<Key, uint16_t>(options));\n  } else {\n    return std::unique_ptr<Cache>(new LruCache<Key, uint8_t>(options));\n  }\n}\n\nstd::unique_ptr<Cache> DispatchKeyType(const CacheOptions& options) {\n  if (options.key_size == sizeof(uint32_t)) {\n    return DispatchValueType<uint32_t>(options);\n  } else if (options.key_size == sizeof(uint64_t)) {\n    return DispatchValueType<uint64_t>(options);\n  } else {\n    UNIMPLEMENTED();\n    return nullptr;\n  }\n}\n\n}  // namespace\n\nstd::unique_ptr<Cache> NewLruCache(const CacheOptions& options) { return DispatchKeyType(options); }\n\n}  // namespace embedding\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/embedding/lru_cache.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EMBEDDING_LRU_CACHE_H_\n#define ONEFLOW_CORE_EMBEDDING_LRU_CACHE_H_\n\n#include \"oneflow/core/embedding/cache.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\nstd::unique_ptr<Cache> NewLruCache(const CacheOptions& options);\n\n}  // namespace embedding\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EMBEDDING_LRU_CACHE_H_\n"
  },
  {
    "path": "oneflow/core/embedding/mock_key_value_store.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/embedding/mock_key_value_store.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\nnamespace {\n\ntemplate<typename Key>\nclass IteratorImpl : public KVIterator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(IteratorImpl);\n  IteratorImpl(HashMap<Key, std::string>* store, uint32_t key_size, uint32_t value_size,\n               uint32_t max_query_length, void* host_keys_buffer, void* host_values_buffer,\n               uint32_t* host_num_buffer)\n      : store_(store),\n        pos_(store->begin()),\n        key_size_(key_size),\n        value_size_(value_size),\n        max_query_length_(max_query_length),\n        host_keys_buffer_(host_keys_buffer),\n        host_values_buffer_(host_values_buffer),\n        host_num_buffer_(host_num_buffer) {}\n  ~IteratorImpl() override = default;\n\n  void NextN(ep::Stream* stream, uint32_t n_request, uint32_t* n_result, void* keys,\n             void* values) override {\n    CHECK_LE(n_request, max_query_length_);\n    auto cuda_stream = stream->As<ep::CudaStream>();\n    CHECK_JUST(cuda_stream->Sync());\n    *host_num_buffer_ = 0;\n    while (*host_num_buffer_ < n_request && pos_ != store_->end()) {\n      reinterpret_cast<Key*>(host_keys_buffer_)[*host_num_buffer_] = pos_->first;\n      std::memcpy(reinterpret_cast<char*>(host_values_buffer_) + *host_num_buffer_ * value_size_,\n                  pos_->second.data(), value_size_);\n    }\n    OF_CUDA_CHECK(cudaMemcpyAsync(n_result, host_num_buffer_, sizeof(uint32_t), cudaMemcpyDefault,\n                                  cuda_stream->cuda_stream()));\n    const uint32_t num_keys = *host_num_buffer_;\n    if (num_keys != 0) {\n      OF_CUDA_CHECK(cudaMemcpyAsync(keys, host_keys_buffer_, num_keys * key_size_,\n                                    cudaMemcpyDefault, cuda_stream->cuda_stream()));\n      OF_CUDA_CHECK(cudaMemcpyAsync(values, host_values_buffer_, num_keys * value_size_,\n                                    cudaMemcpyDefault, cuda_stream->cuda_stream()));\n    }\n  }\n\n  void Reset() override { pos_ = store_->begin(); }\n\n private:\n  HashMap<Key, std::string>* store_;\n  typename HashMap<Key, std::string>::iterator pos_;\n  uint32_t key_size_;\n  uint32_t value_size_;\n  uint32_t max_query_length_;\n  void* host_keys_buffer_;\n  void* host_values_buffer_;\n  uint32_t* host_num_buffer_;\n};\n\ntemplate<typename Key>\nclass KeyValueStoreImpl : public KeyValueStore {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreImpl);\n  explicit KeyValueStoreImpl(const MockKeyValueStoreOptions& options)\n      : device_index_(-1), max_query_length_(0) {\n    OF_CUDA_CHECK(cudaGetDevice(&device_index_));\n    key_size_ = options.key_size;\n    value_size_ = options.value_size;\n    OF_CUDA_CHECK(NumaAwareCudaMallocHost(\n        device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * max_query_length_));\n    
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,\n                                          reinterpret_cast<void**>(&host_query_values_),\n                                          value_size_ * max_query_length_));\n    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast<void**>(&host_n_missing_),\n                                          sizeof(uint32_t)));\n    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,\n                                          reinterpret_cast<void**>(&host_missing_indices_),\n                                          sizeof(uint32_t) * max_query_length_));\n  }\n  ~KeyValueStoreImpl() {\n    CudaCurrentDeviceGuard guard(device_index_);\n    if (max_query_length_ != 0) {\n      OF_CUDA_CHECK(cudaFreeHost(host_query_keys_));\n      OF_CUDA_CHECK(cudaFreeHost(host_query_values_));\n      OF_CUDA_CHECK(cudaFreeHost(host_missing_indices_));\n    }\n    OF_CUDA_CHECK(cudaFreeHost(host_n_missing_));\n  }\n\n  uint32_t KeySize() const override { return key_size_; }\n\n  uint32_t ValueSize() const override { return value_size_; }\n\n  uint32_t MaxQueryLength() const override { return max_query_length_; }\n\n  void ReserveQueryLength(uint32_t query_length) override {\n    CudaCurrentDeviceGuard guard(device_index_);\n    if (query_length <= max_query_length_) { return; }\n    if (max_query_length_ != 0) {\n      OF_CUDA_CHECK(cudaFreeHost(host_query_keys_));\n      OF_CUDA_CHECK(cudaFreeHost(host_query_values_));\n      OF_CUDA_CHECK(cudaFreeHost(host_missing_indices_));\n    }\n    OF_CUDA_CHECK(NumaAwareCudaMallocHost(\n        device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * query_length));\n    OF_CUDA_CHECK(NumaAwareCudaMallocHost(\n        device_index_, reinterpret_cast<void**>(&host_query_values_), value_size_ * query_length));\n    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,\n                                          reinterpret_cast<void**>(&host_missing_indices_),\n                                          sizeof(uint32_t) * query_length));\n    max_query_length_ = query_length;\n  }\n\n  using KeyValueStore::Get;\n  void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,\n           uint32_t* n_missing, uint32_t* missing_indices) override;\n  void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;\n  bool SnapshotExists(const std::string& name) override;\n  void LoadSnapshot(const std::string& name) override;\n  void LoadSnapshot(const std::string& name,\n                    const std::function<void(KVIterator* iter)>& Hook) override;\n  void SaveSnapshot(const std::string& name) override;\n\n private:\n  int device_index_;\n  uint32_t max_query_length_;\n  uint32_t key_size_;\n  uint32_t value_size_;\n  Key* host_query_keys_{};\n  uint8_t* host_query_values_{};\n  uint32_t* host_n_missing_{};\n  uint32_t* host_missing_indices_{};\n  HashMap<Key, std::string> store_;\n  HashMap<std::string, HashMap<Key, std::string>> snapshots_;\n  std::mutex mutex_;\n};\n\ntemplate<typename Key>\nvoid KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,\n                                 void* values, uint32_t* n_missing, uint32_t* missing_indices) {\n  std::lock_guard<std::mutex> lock(mutex_);\n  auto cuda_stream = stream->As<ep::CudaStream>();\n  CHECK_LE(num_keys, max_query_length_);\n  if (num_keys == 0) {\n    OF_CUDA_CHECK(cudaMemsetAsync(n_missing, 0, sizeof(uint32_t),\n                                  
stream->As<ep::CudaStream>()->cuda_stream()));\n    return;\n  }\n  OF_CUDA_CHECK(cudaMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, cudaMemcpyDefault,\n                                cuda_stream->cuda_stream()));\n  CHECK_JUST(cuda_stream->Sync());\n  *host_n_missing_ = 0;\n  for (uint32_t i = 0; i < num_keys; ++i) {\n    auto it = store_.find(host_query_keys_[i]);\n    if (it != store_.end()) {\n      std::memcpy(host_query_values_ + i * value_size_, it->second.data(), value_size_);\n    } else {\n      host_missing_indices_[*host_n_missing_] = i;\n      *host_n_missing_ += 1;\n    }\n  }\n  OF_CUDA_CHECK(cudaMemcpyAsync(values, host_query_values_, num_keys * value_size_,\n                                cudaMemcpyDefault, cuda_stream->cuda_stream()));\n  OF_CUDA_CHECK(cudaMemcpyAsync(n_missing, host_n_missing_, sizeof(uint32_t), cudaMemcpyDefault,\n                                cuda_stream->cuda_stream()));\n  OF_CUDA_CHECK(cudaMemcpyAsync(missing_indices, host_missing_indices_,\n                                (*host_n_missing_) * sizeof(uint32_t), cudaMemcpyDefault,\n                                cuda_stream->cuda_stream()));\n}\n\ntemplate<typename Key>\nvoid KeyValueStoreImpl<Key>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,\n                                 const void* values) {\n  std::lock_guard<std::mutex> lock(mutex_);\n  auto cuda_stream = stream->As<ep::CudaStream>();\n  CHECK_LE(num_keys, max_query_length_);\n  if (num_keys == 0) { return; }\n  OF_CUDA_CHECK(cudaMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, cudaMemcpyDefault,\n                                cuda_stream->cuda_stream()));\n  OF_CUDA_CHECK(cudaMemcpyAsync(host_query_values_, values, value_size_ * num_keys,\n                                cudaMemcpyDefault, cuda_stream->cuda_stream()));\n  CHECK_JUST(cuda_stream->Sync());\n  for (uint32_t i = 0; i < num_keys; ++i) {\n    store_[host_query_keys_[i]] = std::string(\n        reinterpret_cast<const char*>(host_query_values_) + i * value_size_, value_size_);\n  }\n}\n\ntemplate<typename Key>\nbool KeyValueStoreImpl<Key>::SnapshotExists(const std::string& name) {\n  return snapshots_.find(name) != snapshots_.end();\n}\n\ntemplate<typename Key>\nvoid KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name) {\n  CudaCurrentDeviceGuard guard(device_index_);\n  LoadSnapshot(name, nullptr);\n}\n\ntemplate<typename Key>\nvoid KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name,\n                                          const std::function<void(KVIterator* iter)>& Hook) {\n  CudaCurrentDeviceGuard guard(device_index_);\n  store_ = snapshots_[name];\n  if (Hook) {\n    IteratorImpl<Key> iterator(&store_, KeySize(), ValueSize(), max_query_length_, host_query_keys_,\n                               host_query_values_, host_n_missing_);\n    Hook(&iterator);\n  }\n}\n\ntemplate<typename Key>\nvoid KeyValueStoreImpl<Key>::SaveSnapshot(const std::string& name) {\n  CudaCurrentDeviceGuard guard(device_index_);\n  snapshots_[name] = store_;\n}\n\n}  // namespace\n\nstd::unique_ptr<KeyValueStore> NewMockKeyValueStore(const MockKeyValueStoreOptions& options) {\n  if (options.key_size == sizeof(uint64_t)) {\n    return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint64_t>(options));\n  } else if (options.key_size == sizeof(uint32_t)) {\n    return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint32_t>(options));\n  } else {\n    UNIMPLEMENTED();\n    return nullptr;\n  }\n}\n\n}  // namespace embedding\n\n}  
// namespace oneflow\n"
  },
  {
    "path": "oneflow/core/embedding/mock_key_value_store.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EMBEDDING_MOCK_KEY_VALUE_STORE_H_\n#define ONEFLOW_CORE_EMBEDDING_MOCK_KEY_VALUE_STORE_H_\n\n#include \"oneflow/core/embedding/key_value_store.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\n#ifdef WITH_CUDA\n\nstruct MockKeyValueStoreOptions {\n  uint32_t key_size = 0;\n  uint32_t value_size = 0;\n};\n\nstd::unique_ptr<KeyValueStore> NewMockKeyValueStore(const MockKeyValueStoreOptions& options);\n\n#endif  // WITH_CUDA\n\n}  // namespace embedding\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EMBEDDING_MOCK_KEY_VALUE_STORE_H_\n"
  },
  {
    "path": "oneflow/core/embedding/persistent_table.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/embedding/persistent_table.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/embedding/hash_functions.cuh\"\n\n#ifdef __linux__\n\n#include \"oneflow/core/common/channel.h\"\n#include \"oneflow/core/embedding/posix_file.h\"\n#include \"oneflow/core/common/blocking_counter.h\"\n#include <robin_hood.h>\n#include <fcntl.h>\n#include <sys/mman.h>\n#include <dirent.h>\n#include <sys/syscall.h>\n#include <linux/aio_abi.h>\n#include <unistd.h>\n\n#endif  // __linux__\n\nnamespace oneflow {\n\nnamespace embedding {\n\n#ifdef __linux__\n\nnamespace {\n\nconstexpr uint32_t kDefaultNumWorkerThreads = 4;\nconstexpr uint32_t kRingQueueDepth = 128;\nconstexpr uint32_t kRingSubmitBatch = 32;\nconstexpr uint32_t kAioQueueDepth = 128;\nconstexpr uint32_t kChunkNameSuffixLength = 12;\nconstexpr char const* kKeyFileNamePrefix = \"key-\";\nconstexpr char const* kIndexFileNamePrefix = \"index-\";\nconstexpr char const* kValueFileNamePrefix = \"value-\";\nconstexpr char const* kLockFileName = \"LOCK\";\nconstexpr char const* kKeySizeFileName = \"KEY_SIZE\";\nconstexpr char const* kValueSizeFileName = \"VALUE_SIZE\";\nconstexpr char const* kPhysicalBlockSizeFileName = \"PHYSICAL_BLOCK_SIZE\";\nconstexpr char const* kNumLogicalBlocksPerChunkFileName = \"NUM_LOGICAL_BLOCKS_PER_CHUNK\";\nconstexpr char const* kKeysDirName = \"keys\";\nconstexpr char const* kValuesDirName = \"values\";\nconstexpr char const* kSnapshotsDirName = \"snapshots\";\nconstexpr char const* kSnapshotListFileName = \"LIST\";\nconstexpr size_t kParallelForStride = 256;\n\ntemplate<typename T>\nT* BytesOffset(T* ptr, size_t bytes) {\n  return reinterpret_cast<T*>(\n      const_cast<unsigned char*>((reinterpret_cast<const unsigned char*>(ptr) + bytes)));\n}\n\nvoid MemcpyOffset(void* dst, size_t dst_off, const void* src, size_t src_off, size_t n) {\n  std::memcpy(BytesOffset(dst, dst_off), BytesOffset(src, src_off), n);\n}\n\nvoid InitOrCheckMetaValue(const std::string& pathname, int64_t expected, bool init) {\n  bool exists = PosixFile::FileExists(pathname);\n  if (init) {\n    CHECK(!exists) << pathname;\n    std::ofstream ofs(pathname);\n    ofs << expected << std::endl;\n  } else {\n    CHECK(exists);\n    std::ifstream ifs(pathname);\n    int64_t value = 0;\n    ifs >> value;\n    if (value != expected) { LOG(FATAL) << \"Check failed: \" << pathname; }\n  }\n}\n\nstd::string GetChunkName(uint64_t chunk_id) {\n  const std::string chunk_name_wo_leading_zero = std::to_string(chunk_id);\n  CHECK_LE(chunk_name_wo_leading_zero.size(), kChunkNameSuffixLength);\n  return std::string(kChunkNameSuffixLength - chunk_name_wo_leading_zero.size(), '0')\n         + chunk_name_wo_leading_zero;\n}\n\nuint64_t GetChunkId(const std::string& chunk_name) {\n  size_t pos = 0;\n  const uint64_t chunk_id = std::stoull(chunk_name, &pos, 10);\n  CHECK_EQ(pos, kChunkNameSuffixLength);\n  return 
chunk_id;\n}\n\nuint64_t GetChunkId(const std::string& filename, const std::string& prefix) {\n  CHECK_EQ(filename.compare(0, prefix.size(), prefix), 0);\n  return GetChunkId(filename.substr(prefix.size()));\n}\n\nvoid ListChunkFiles(const std::string& base, const std::string& prefix,\n                    std::unordered_map<uint64_t, std::string>* chunks) {\n  DIR* dir = opendir(base.c_str());\n  PCHECK(dir != nullptr);\n  struct dirent* ent = nullptr;\n  while ((ent = readdir(dir)) != nullptr) {\n    if (strlen(ent->d_name) != prefix.size() + kChunkNameSuffixLength) { continue; }\n    if (strncmp(ent->d_name, prefix.c_str(), prefix.size()) != 0) { continue; }\n    const uint64_t chunk_id = GetChunkId(ent->d_name + prefix.size());\n    CHECK(chunks->emplace(chunk_id, PosixFile::JoinPath(base, ent->d_name)).second);\n  }\n  PCHECK(closedir(dir) == 0);\n}\n\nuint32_t GetLogicalBlockSize(uint32_t physical_block_size, uint32_t value_size) {\n  return physical_block_size >= value_size ? physical_block_size\n                                           : RoundUp(value_size, physical_block_size);\n}\n\nclass AlignedBuffer final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AlignedBuffer);\n  explicit AlignedBuffer(size_t alignment) : alignment_(alignment), size_(0) {}\n  ~AlignedBuffer() = default;\n\n  void Resize(size_t new_size) {\n    if (new_size > size_) {\n      // aligned_alloc requires the size to be a multiple of the alignment, and\n      // the returned memory must be released with free(), not delete.\n      ptr_.reset(static_cast<char*>(aligned_alloc(alignment_, RoundUp(new_size, alignment_))));\n      size_ = new_size;\n    }\n  }\n\n  void* ptr() { return ptr_.get(); }\n\n private:\n  struct FreeDeleter {\n    void operator()(char* p) const { free(p); }\n  };\n\n  size_t alignment_;\n  size_t size_;\n  std::unique_ptr<char, FreeDeleter> ptr_;\n};\n\ntemplate<typename Key>\nclass ChunkIteratorImpl : public PersistentTable::Iterator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ChunkIteratorImpl);\n  ChunkIteratorImpl(uint32_t value_size, uint32_t logical_block_size, uint32_t num_values_per_block,\n                    uint64_t num_values_per_chunk, uint64_t chunk_id, uint64_t n,\n                    const Key* chunk_keys, const uint64_t* chunk_indices, const void* chunk_values)\n      : pos_(0),\n        value_size_(value_size),\n        logical_block_size_(logical_block_size),\n        num_values_per_block_(num_values_per_block),\n        num_values_per_chunk_(num_values_per_chunk),\n        n_(n),\n        chunk_keys_(chunk_keys),\n        chunk_indices_(chunk_indices),\n        chunk_values_(chunk_values),\n        chunk_index_offset_(chunk_id * num_values_per_chunk_) {}\n  ~ChunkIteratorImpl() override = default;\n\n  void Next(uint32_t num_keys, uint32_t* return_keys, void* keys, void* values) override {\n    uint32_t count = 0;\n    while (count < num_keys && pos_ != n_) {\n      const uint64_t index_in_chunk = chunk_indices_[pos_] - chunk_index_offset_;\n      static_cast<Key*>(keys)[count] = chunk_keys_[index_in_chunk];\n      const uint64_t block_in_chunk = index_in_chunk / num_values_per_block_;\n      const uint32_t index_in_block = index_in_chunk - block_in_chunk * num_values_per_block_;\n      const uint32_t value_offset =\n          block_in_chunk * logical_block_size_ + index_in_block * value_size_;\n      std::memcpy(static_cast<char*>(values) + count * value_size_,\n                  static_cast<const char*>(chunk_values_) + value_offset, value_size_);\n      count++;\n      pos_++;\n    }\n    *return_keys = count;\n  }\n\n  void Reset() override { pos_ = 0; }\n\n private:\n  uint64_t pos_;\n  uint32_t value_size_;\n  uint32_t logical_block_size_;\n  uint32_t num_values_per_block_;\n  uint64_t num_values_per_chunk_;\n  uint64_t n_;\n  const 
Key* chunk_keys_;\n  const uint64_t* chunk_indices_;\n  const void* chunk_values_;\n  uint64_t chunk_index_offset_;\n};\n\nclass AioEngine final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AioEngine);\n  AioEngine() : ctx_{}, num_readings_(0) {\n    PCHECK(syscall(__NR_io_setup, kAioQueueDepth, &ctx_) >= 0);\n    cbs_.resize(kAioQueueDepth);\n    cbs_ptr_.resize(kAioQueueDepth);\n    for (uint32_t i = 0; i < kAioQueueDepth; ++i) { cbs_ptr_[i] = &cbs_[i]; }\n    events_.resize(kAioQueueDepth);\n  }\n  ~AioEngine() {\n    WaitUntilDone();\n    PCHECK(syscall(__NR_io_destroy, ctx_) >= 0);\n  }\n\n  void AsyncPread(int fd, void* buf, size_t count, off_t offset) {\n    if (num_readings_ == kAioQueueDepth) { WaitUntilDone(); }\n    struct iocb* cb = &cbs_.at(num_readings_);\n    cb->aio_fildes = fd;\n    cb->aio_lio_opcode = IOCB_CMD_PREAD;\n    cb->aio_reqprio = 0;\n    cb->aio_buf = reinterpret_cast<uintptr_t>(buf);\n    cb->aio_nbytes = count;\n    cb->aio_offset = offset;\n    const long nr = 1;\n    PCHECK(syscall(__NR_io_submit, ctx_, nr, &cbs_ptr_.at(num_readings_)) >= 0);\n    num_readings_ += 1;\n  }\n\n  void WaitUntilDone() {\n    if (num_readings_ != 0) {\n      PCHECK(syscall(__NR_io_getevents, ctx_, num_readings_, num_readings_, events_.data(), nullptr)\n             >= 0);\n      for (long i = 0; i < num_readings_; ++i) { CHECK_GT(events_.at(i).res, 0); }\n      num_readings_ = 0;\n    }\n  }\n\n private:\n  aio_context_t ctx_;\n  long num_readings_;\n  std::vector<struct iocb> cbs_;\n  std::vector<struct iocb*> cbs_ptr_;\n  std::vector<struct io_event> events_;\n};\n\nconstexpr size_t kCacheLineSize = 64;\n\ntemplate<typename Engine>\nusing IoTask = std::function<void(Engine* engine)>;\n\ntemplate<typename Engine>\nusing ForRange = std::function<void(Engine* engine, size_t start, size_t end)>;\n\ntemplate<typename Engine>\nclass Worker final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Worker);\n  Worker() { thread_ = std::thread(&Worker<Engine>::PullTask, this); }\n  ~Worker() {\n    Shutdown();\n    thread_.join();\n  }\n\n  void Schedule(IoTask<Engine> task) { tasks_.Send(std::move(task)); }\n\n  void Shutdown() { tasks_.Close(); }\n\n private:\n  void PullTask() {\n    while (true) {\n      IoTask<Engine> task;\n      const ChannelStatus status = tasks_.Receive(&task);\n      if (status == ChannelStatus::kChannelStatusErrorClosed) { break; }\n      CHECK_EQ(status, ChannelStatus::kChannelStatusSuccess);\n      task(&engine_);\n    }\n  }\n  Channel<IoTask<Engine>> tasks_;\n  Engine engine_;\n  std::thread thread_;\n};\n\ntemplate<typename Key, typename Engine>\nclass SnapshotIteratorImpl;\n\ntemplate<typename Key, typename Engine>\nclass PersistentTableImpl : public PersistentTable {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PersistentTableImpl);\n  explicit PersistentTableImpl(const PersistentTableOptions& options);\n  ~PersistentTableImpl() override;\n\n  uint32_t KeySize() const override { return key_size_; }\n\n  uint32_t ValueSize() const override { return value_size_; }\n\n  uint32_t LogicalBlockSize() const override;\n  void GetBlocks(uint32_t num_keys, const void* keys, void* blocks, uint32_t* offsets) override;\n  void Get(uint32_t num_keys, const void* keys, void* values, uint32_t* n_missing,\n           uint32_t* missing_indices) override;\n  void PutBlocks(uint32_t num_keys, const void* keys, const void* blocks) override;\n  void Put(uint32_t num_keys, const void* keys, const void* values) override;\n  bool SnapshotExists(const std::string& name) override;\n  void 
LoadSnapshot(const std::string& name) override;\n  void LoadSnapshot(const std::string& name,\n                    const std::function<void(Iterator* iter)>& Hook) override;\n  void SaveSnapshot(const std::string& name) override;\n  Iterator* ReadSnapshot(const std::string& name) override;\n\n private:\n  friend class SnapshotIteratorImpl<Key, Engine>;\n  std::string KeyFilePath(uint64_t chunk_id) const;\n  std::string ValueFilePath(uint64_t chunk_id) const;\n  std::string IndexFilePath(const std::string& name, uint64_t chunk_id) const;\n  std::string SnapshotDirPath(const std::string& name) const;\n  std::string SnapshotListFilePath(const std::string& name) const;\n  void LoadSnapshotImpl(const std::string& name);\n  void SaveSnapshotImpl(const std::string& name);\n  void ParallelFor(size_t total, const ForRange<Engine>& for_range);\n\n  std::string root_dir_;\n  std::string keys_dir_;\n  std::string values_dir_;\n  std::string snapshots_dir_;\n  uint32_t key_size_;\n  uint32_t value_size_;\n  uint64_t num_logical_blocks_per_chunk_;\n  uint64_t num_values_per_chunk_;\n  uint32_t num_values_per_block_;\n  uint32_t physical_block_size_;\n  uint32_t logical_block_size_;\n\n  std::vector<std::unique_ptr<Worker<Engine>>> workers_;\n\n  std::vector<uint32_t> offsets_buffer_;\n  AlignedBuffer blocks_buffer_;\n\n  std::recursive_mutex mutex_;\n  uint64_t physical_table_size_;\n  robin_hood::unordered_flat_map<Key, uint64_t> row_id_mapping_;\n  std::vector<PosixFile> value_files_;\n  PosixFile writable_key_file_;\n  uint64_t writable_key_file_chunk_id_;\n  PosixFileLockGuard lock_;\n  bool read_only_;\n};\n\ntemplate<typename Key, typename Engine>\nPersistentTableImpl<Key, Engine>::PersistentTableImpl(const PersistentTableOptions& options)\n    : root_dir_(options.path),\n      key_size_(options.key_size),\n      value_size_(options.value_size),\n      physical_block_size_(options.physical_block_size),\n      logical_block_size_(GetLogicalBlockSize(options.physical_block_size, value_size_)),\n      blocks_buffer_(options.physical_block_size),\n      writable_key_file_chunk_id_(-1),\n      read_only_(options.read_only) {\n  const uint64_t capacity_hint = ParseIntegerFromEnv(\n      \"ONEFLOW_ONE_EMBEDDING_PERSISTENT_TABLE_CAPACITY_HINT\", options.capacity_hint);\n  if (capacity_hint > 0) { row_id_mapping_.reserve(capacity_hint); }\n  PosixFile::RecursiveCreateDirectory(options.path, 0755);\n  const std::string lock_filename = PosixFile::JoinPath(options.path, kLockFileName);\n  const bool init = !PosixFile::FileExists(lock_filename);\n  if (read_only_) {\n    CHECK(!init) << \"A table must already be initialized to be opened in read-only mode\";\n  } else {\n    lock_ = PosixFileLockGuard(PosixFile(lock_filename, O_CREAT | O_RDWR, 0644));\n  }\n  const uint64_t target_chunk_size = options.target_chunk_size_mb * 1024 * 1024;\n  CHECK_GE(target_chunk_size, logical_block_size_);\n  num_logical_blocks_per_chunk_ = target_chunk_size / logical_block_size_;\n  num_values_per_block_ = logical_block_size_ / value_size_;\n  num_values_per_chunk_ = num_values_per_block_ * num_logical_blocks_per_chunk_;\n  InitOrCheckMetaValue(PosixFile::JoinPath(options.path, kKeySizeFileName), key_size_, init);\n  InitOrCheckMetaValue(PosixFile::JoinPath(options.path, kValueSizeFileName), value_size_, init);\n  InitOrCheckMetaValue(PosixFile::JoinPath(options.path, kPhysicalBlockSizeFileName),\n                       options.physical_block_size, init);\n  InitOrCheckMetaValue(PosixFile::JoinPath(options.path, 
kNumLogicalBlocksPerChunkFileName),\n                       num_logical_blocks_per_chunk_, init);\n  keys_dir_ = PosixFile::JoinPath(options.path, kKeysDirName);\n  values_dir_ = PosixFile::JoinPath(options.path, kValuesDirName);\n  snapshots_dir_ = PosixFile::JoinPath(options.path, kSnapshotsDirName);\n  if (init) {\n    PosixFile::RecursiveCreateDirectory(keys_dir_, 0755);\n    PosixFile::RecursiveCreateDirectory(values_dir_, 0755);\n  }\n  const uint32_t num_workers = ParseIntegerFromEnv(\n      \"ONEFLOW_ONE_EMBEDDING_PERSISTENT_TABLE_NUM_WORKERS\", kDefaultNumWorkerThreads);\n  workers_.resize(num_workers);\n  for (uint32_t tid = 0; tid < workers_.size(); ++tid) {\n    workers_.at(tid).reset(new Worker<Engine>);\n  }\n  std::unordered_map<uint64_t, std::string> chunks;\n  ListChunkFiles(values_dir_, kValueFileNamePrefix, &chunks);\n  for (auto& chunk : chunks) {\n    if (value_files_.size() <= chunk.first) { value_files_.resize(chunk.first + 1); }\n    CHECK_EQ(value_files_.at(chunk.first).fd(), -1);\n    const int flags = read_only_ ? (O_RDONLY | O_DIRECT) : (O_RDWR | O_DIRECT);\n    PosixFile value_file(chunk.second, flags, 0644);\n    value_files_.at(chunk.first) = std::move(value_file);\n  }\n  if (!value_files_.empty()) {\n    physical_table_size_ = ((value_files_.size() - 1) * num_logical_blocks_per_chunk_\n                            + value_files_.back().Size() / logical_block_size_)\n                           * num_values_per_block_;\n  } else {\n    physical_table_size_ = 0;\n  }\n}\n\ntemplate<typename Key, typename Engine>\nPersistentTableImpl<Key, Engine>::~PersistentTableImpl() {\n  for (uint32_t tid = 0; tid < workers_.size(); ++tid) { workers_.at(tid)->Shutdown(); }\n}\n\ntemplate<typename Key, typename Engine>\nuint32_t PersistentTableImpl<Key, Engine>::LogicalBlockSize() const {\n  return logical_block_size_;\n}\n\ntemplate<typename Key, typename Engine>\nvoid PersistentTableImpl<Key, Engine>::GetBlocks(uint32_t num_keys, const void* keys, void* blocks,\n                                                 uint32_t* offsets) {\n  std::lock_guard<std::recursive_mutex> lock(mutex_);\n  ParallelFor(num_keys, [&](Engine* engine, size_t start, size_t end) {\n    for (uint64_t i = start; i < end; ++i) {\n      const Key key = static_cast<const Key*>(keys)[i];\n      auto it = row_id_mapping_.find(key);\n      if (it == row_id_mapping_.end()) {\n        offsets[i] = logical_block_size_;\n      } else {\n        const uint64_t id = it->second;\n        const uint64_t block_id = id / num_values_per_block_;\n        const uint32_t id_in_block = id - block_id * num_values_per_block_;\n        const uint32_t offset_in_block = id_in_block * value_size_;\n        const uint64_t chunk_id = block_id / num_logical_blocks_per_chunk_;\n        const uint64_t block_in_chunk = block_id - chunk_id * num_logical_blocks_per_chunk_;\n        const uint64_t block_offset = block_in_chunk * logical_block_size_;\n        PosixFile& file = value_files_.at(chunk_id);\n        offsets[i] = offset_in_block;\n        engine->AsyncPread(file.fd(), BytesOffset(blocks, i * logical_block_size_),\n                           logical_block_size_, block_offset);\n      }\n    }\n  });\n}\n\ntemplate<typename Key, typename Engine>\nvoid PersistentTableImpl<Key, Engine>::Get(uint32_t num_keys, const void* keys, void* values,\n                                           uint32_t* n_missing, uint32_t* missing_indices) {\n  std::lock_guard<std::recursive_mutex> lock(mutex_);\n  offsets_buffer_.resize(num_keys);\n  
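// Read whole logical blocks into either the caller's buffer (when it is already\n  // block-sized and block-aligned) or a staging buffer that is copied out below.\n  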
void* blocks_ptr = nullptr;\n  if (value_size_ == logical_block_size_\n      && reinterpret_cast<uintptr_t>(values) % physical_block_size_ == 0) {\n    blocks_ptr = values;\n  } else {\n    blocks_buffer_.Resize(num_keys * logical_block_size_);\n    blocks_ptr = blocks_buffer_.ptr();\n  }\n  GetBlocks(num_keys, keys, blocks_ptr, offsets_buffer_.data());\n  uint32_t missing_count = 0;\n  for (uint32_t i = 0; i < num_keys; ++i) {\n    if (offsets_buffer_.at(i) == logical_block_size_) {\n      missing_indices[missing_count] = i;\n      missing_count += 1;\n    } else {\n      // copy out of the staging buffer unless the blocks were read directly\n      // into values\n      if (blocks_ptr != values) {\n        MemcpyOffset(values, i * value_size_, blocks_ptr,\n                     (i * logical_block_size_) + offsets_buffer_[i], value_size_);\n      }\n    }\n  }\n  *n_missing = missing_count;\n}\n\ntemplate<typename Key, typename Engine>\nvoid PersistentTableImpl<Key, Engine>::PutBlocks(uint32_t num_keys, const void* keys,\n                                                 const void* blocks) {\n  CHECK(!read_only_);\n  std::lock_guard<std::recursive_mutex> lock(mutex_);\n  const uint32_t num_blocks = RoundUp(num_keys, num_values_per_block_) / num_values_per_block_;\n  const uint32_t num_padded_keys = num_blocks * num_values_per_block_;\n  const uint64_t start_index = physical_table_size_;\n  physical_table_size_ += num_padded_keys;\n  CHECK_EQ(start_index % num_values_per_block_, 0);\n  const uint64_t start_block_id = start_index / num_values_per_block_;\n  uint64_t written_blocks = 0;\n  const uint64_t block_keys_size = num_values_per_block_ * sizeof(Key);\n  BlockingCounter bc(1);\n  workers_.at(0)->Schedule([&](Engine*) {\n    while (written_blocks < num_blocks) {\n      const uint64_t batch_start_block_id = start_block_id + written_blocks;\n      const uint64_t batch_chunk_id = batch_start_block_id / num_logical_blocks_per_chunk_;\n      if (batch_chunk_id == value_files_.size()) {\n        value_files_.emplace_back(ValueFilePath(batch_chunk_id), O_CREAT | O_RDWR | O_DIRECT, 0644);\n      } else {\n        CHECK_LE(batch_chunk_id, value_files_.size());\n      }\n      if ((!writable_key_file_.IsOpen()) || writable_key_file_chunk_id_ != batch_chunk_id) {\n        writable_key_file_ = PosixFile(KeyFilePath(batch_chunk_id), O_CREAT | O_RDWR, 0644);\n        // remember which chunk the cached key file belongs to so it is not\n        // reopened for every batch\n        writable_key_file_chunk_id_ = batch_chunk_id;\n      }\n      PosixFile& value_file = value_files_.at(batch_chunk_id);\n      const uint64_t block_id_in_chunk =\n          batch_start_block_id - batch_chunk_id * num_logical_blocks_per_chunk_;\n      const uint64_t blocks_to_write =\n          std::min(num_blocks - written_blocks,\n                   (batch_chunk_id + 1) * num_logical_blocks_per_chunk_ - batch_start_block_id);\n      const uint64_t values_bytes = blocks_to_write * logical_block_size_;\n      const uint64_t values_offset_in_file = block_id_in_chunk * logical_block_size_;\n      CHECK_LE(value_file.Size(), values_offset_in_file);\n      value_file.Truncate(values_offset_in_file + values_bytes);\n      PCHECK(pwrite(value_file.fd(), BytesOffset(blocks, written_blocks * logical_block_size_),\n                    values_bytes, values_offset_in_file)\n             == values_bytes);\n      const uint64_t keys_offset_in_file = block_id_in_chunk * block_keys_size;\n      writable_key_file_.Truncate(keys_offset_in_file + blocks_to_write * block_keys_size);\n      const uint64_t keys_bytes = std::min(num_keys - written_blocks * num_values_per_block_,\n                                           blocks_to_write * num_values_per_block_)\n                                  * 
sizeof(Key);\n      PCHECK(pwrite(writable_key_file_.fd(), BytesOffset(keys, written_blocks * block_keys_size),\n                    keys_bytes, keys_offset_in_file)\n             == keys_bytes);\n      written_blocks += blocks_to_write;\n    }\n    bc.Decrease();\n  });\n  for (uint64_t i = 0; i < num_keys; ++i) {\n    row_id_mapping_[static_cast<const Key*>(keys)[i]] = start_index + i;\n  }\n  bc.WaitForeverUntilCntEqualZero();\n}\n\ntemplate<typename Key, typename Engine>\nvoid PersistentTableImpl<Key, Engine>::Put(uint32_t num_keys, const void* keys,\n                                           const void* values) {\n  CHECK(!read_only_);\n  std::lock_guard<std::recursive_mutex> lock(mutex_);\n  const void* blocks_ptr = nullptr;\n  if (value_size_ == logical_block_size_\n      && reinterpret_cast<uintptr_t>(values) % physical_block_size_ == 0) {\n    blocks_ptr = values;\n  } else {\n    // number of logical blocks, not values: round num_keys up to whole blocks\n    // before sizing the staging buffer\n    const uint32_t num_blocks = RoundUp(num_keys, num_values_per_block_) / num_values_per_block_;\n    blocks_buffer_.Resize(num_blocks * logical_block_size_);\n    for (uint32_t i = 0; i < num_keys; i += num_values_per_block_) {\n      const uint32_t block_id = i / num_values_per_block_;\n      const uint32_t copy_size = (num_keys - i) < num_values_per_block_\n                                     ? (num_keys - i) * value_size_\n                                     : logical_block_size_;\n      MemcpyOffset(blocks_buffer_.ptr(), block_id * logical_block_size_, values, i * value_size_,\n                   copy_size);\n    }\n    blocks_ptr = blocks_buffer_.ptr();\n  }\n  PutBlocks(num_keys, keys, blocks_ptr);\n}\n\ntemplate<typename Key, typename Engine>\nstd::string PersistentTableImpl<Key, Engine>::KeyFilePath(uint64_t chunk_id) const {\n  return PosixFile::JoinPath(keys_dir_, kKeyFileNamePrefix + GetChunkName(chunk_id));\n}\n\ntemplate<typename Key, typename Engine>\nstd::string PersistentTableImpl<Key, Engine>::ValueFilePath(uint64_t chunk_id) const {\n  return PosixFile::JoinPath(values_dir_, kValueFileNamePrefix + GetChunkName(chunk_id));\n}\n\ntemplate<typename Key, typename Engine>\nstd::string PersistentTableImpl<Key, Engine>::IndexFilePath(const std::string& name,\n                                                            uint64_t chunk_id) const {\n  return PosixFile::JoinPath(SnapshotDirPath(name), kIndexFileNamePrefix + GetChunkName(chunk_id));\n}\n\ntemplate<typename Key, typename Engine>\nstd::string PersistentTableImpl<Key, Engine>::SnapshotDirPath(const std::string& name) const {\n  return PosixFile::JoinPath(snapshots_dir_, name);\n}\n\ntemplate<typename Key, typename Engine>\nstd::string PersistentTableImpl<Key, Engine>::SnapshotListFilePath(const std::string& name) const {\n  return PosixFile::JoinPath(SnapshotDirPath(name), kSnapshotListFileName);\n}\n\ntemplate<typename Key, typename Engine>\nvoid PersistentTableImpl<Key, Engine>::LoadSnapshotImpl(const std::string& name) {\n  std::lock_guard<std::recursive_mutex> lock(mutex_);\n  const std::string snapshot_base = SnapshotDirPath(name);\n  const std::string snapshot_list = SnapshotListFilePath(name);\n  row_id_mapping_.clear();\n  std::ifstream list_if(snapshot_list);\n  std::string index_filename;\n  while (std::getline(list_if, index_filename)) {\n    const uint64_t chunk_id = GetChunkId(index_filename, kIndexFileNamePrefix);\n    PosixFile index_file(PosixFile::JoinPath(snapshot_base, index_filename), O_RDONLY, 0644);\n    const size_t index_file_size = index_file.Size();\n    CHECK_EQ(index_file_size % sizeof(uint64_t), 0);\n    if (index_file_size == 0) 
{ return; }\n    const size_t n_entries = index_file_size / sizeof(uint64_t);\n    PosixMappedFile mapped_index(std::move(index_file), index_file_size, PROT_READ);\n    PosixFile key_file(KeyFilePath(chunk_id), O_RDONLY, 0644);\n    PosixMappedFile mapped_key(std::move(key_file), key_file.Size(), PROT_READ);\n    const uint64_t* indices = static_cast<const uint64_t*>(mapped_index.ptr());\n    const Key* keys = static_cast<const Key*>(mapped_key.ptr());\n    const uint64_t chunk_start_index = chunk_id * num_values_per_chunk_;\n    row_id_mapping_.reserve(row_id_mapping_.size() + n_entries);\n    for (size_t i = 0; i < n_entries; ++i) {\n      CHECK(row_id_mapping_.emplace(keys[indices[i] - chunk_start_index], indices[i]).second);\n    }\n  }\n}\n\ntemplate<typename Key, typename Engine>\nvoid PersistentTableImpl<Key, Engine>::SaveSnapshotImpl(const std::string& name) {\n  CHECK(!read_only_);\n  std::lock_guard<std::recursive_mutex> lock(mutex_);\n  PosixFile::RecursiveCreateDirectory(SnapshotDirPath(name), 0755);\n  std::ofstream list_ofs(SnapshotListFilePath(name));\n  if (row_id_mapping_.empty()) { return; }\n  std::vector<PosixMappedFile> index_files(value_files_.size());\n  std::vector<uint64_t> counters(value_files_.size());\n  const uint64_t max_index_file_size = num_values_per_chunk_ * sizeof(uint64_t);\n  for (const auto& pair : row_id_mapping_) {\n    const uint64_t chunk_id = pair.second / num_values_per_chunk_;\n    CHECK(chunk_id < value_files_.size());\n    if (index_files[chunk_id].ptr() == nullptr) {\n      PosixFile snapshot_file(IndexFilePath(name, chunk_id), O_CREAT | O_RDWR, 0644);\n      snapshot_file.Truncate(max_index_file_size);\n      index_files[chunk_id] =\n          PosixMappedFile(std::move(snapshot_file), max_index_file_size, PROT_READ | PROT_WRITE);\n    }\n    uint64_t* indices = static_cast<uint64_t*>(index_files[chunk_id].ptr());\n    uint64_t& count = counters[chunk_id];\n    CHECK_LT(count, num_values_per_chunk_);\n    indices[count] = pair.second;\n    count += 1;\n  }\n  for (size_t i = 0; i < value_files_.size(); ++i) {\n    const uint64_t count = counters[i];\n    if (count > 0) {\n      index_files[i].file().Truncate(count * sizeof(uint64_t));\n      list_ofs << kIndexFileNamePrefix + GetChunkName(i) << std::endl;\n    } else {\n      CHECK(index_files[i].ptr() == nullptr);\n    }\n  }\n}\n\ntemplate<typename Key, typename Engine>\nbool PersistentTableImpl<Key, Engine>::SnapshotExists(const std::string& name) {\n  std::lock_guard<std::recursive_mutex> lock(mutex_);\n  return PosixFile::FileExists(SnapshotListFilePath(name));\n}\n\ntemplate<typename Key, typename Engine>\nvoid PersistentTableImpl<Key, Engine>::LoadSnapshot(const std::string& name) {\n  LoadSnapshotImpl(name);\n}\n\ntemplate<typename Key, typename Engine>\nvoid PersistentTableImpl<Key, Engine>::LoadSnapshot(\n    const std::string& name, const std::function<void(Iterator* iter)>& Hook) {\n  std::lock_guard<std::recursive_mutex> lock(mutex_);\n  int mmap_flags = MAP_SHARED;\n  if (ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_PERSISTENT_TABLE_SNAPSHOT_LOAD_MAP_POPULATE\",\n                          true)) {\n    mmap_flags |= MAP_POPULATE;\n  }\n  const std::string snapshot_base = SnapshotDirPath(name);\n  const std::string snapshot_list = SnapshotListFilePath(name);\n  row_id_mapping_.clear();\n  std::ifstream list_if(snapshot_list);\n  std::string index_filename;\n  while (std::getline(list_if, index_filename)) {\n    const uint64_t chunk_id = GetChunkId(index_filename, 
kIndexFileNamePrefix);\n    PosixFile index_file(PosixFile::JoinPath(snapshot_base, index_filename), O_RDONLY, 0644);\n    const size_t index_file_size = index_file.Size();\n    CHECK_EQ(index_file_size % sizeof(uint64_t), 0);\n    if (index_file_size == 0) { return; }\n    const size_t n_entries = index_file_size / sizeof(uint64_t);\n    PosixMappedFile mapped_index(std::move(index_file), index_file_size, PROT_READ, mmap_flags);\n    PosixFile key_file(KeyFilePath(chunk_id), O_RDONLY, 0644);\n    PosixMappedFile mapped_key(std::move(key_file), key_file.Size(), PROT_READ, mmap_flags);\n    const uint64_t* indices = static_cast<const uint64_t*>(mapped_index.ptr());\n    const Key* keys = static_cast<const Key*>(mapped_key.ptr());\n    const uint64_t chunk_start_index = chunk_id * num_values_per_chunk_;\n    row_id_mapping_.reserve(row_id_mapping_.size() + n_entries);\n    for (size_t i = 0; i < n_entries; ++i) {\n      CHECK(row_id_mapping_.emplace(keys[indices[i] - chunk_start_index], indices[i]).second);\n    }\n    if (Hook) {\n      PosixFile value_file(ValueFilePath(chunk_id), O_RDONLY, 0644);\n      PosixMappedFile mapped_value(std::move(value_file), value_file.Size(), PROT_READ, mmap_flags);\n      ChunkIteratorImpl<Key> chunk_iterator(value_size_, logical_block_size_, num_values_per_block_,\n                                            num_values_per_chunk_, chunk_id, n_entries, keys,\n                                            indices, mapped_value.ptr());\n      Hook(&chunk_iterator);\n    }\n  }\n}\n\ntemplate<typename Key, typename Engine>\nvoid PersistentTableImpl<Key, Engine>::SaveSnapshot(const std::string& name) {\n  SaveSnapshotImpl(name);\n}\n\ntemplate<typename Key, typename Engine>\nPersistentTable::Iterator* PersistentTableImpl<Key, Engine>::ReadSnapshot(const std::string& name) {\n  return new SnapshotIteratorImpl<Key, Engine>(this, name, value_size_, logical_block_size_,\n                                               num_values_per_block_, num_values_per_chunk_);\n}\n\ntemplate<typename Key, typename Engine>\nvoid PersistentTableImpl<Key, Engine>::ParallelFor(size_t total,\n                                                   const ForRange<Engine>& for_range) {\n  BlockingCounter bc(workers_.size());\n  std::atomic<size_t> counter(0);\n  for (size_t i = 0; i < workers_.size(); ++i) {\n    workers_.at(i)->Schedule([&](Engine* engine) {\n      while (true) {\n        const size_t start = counter.fetch_add(kParallelForStride, std::memory_order_relaxed);\n        if (start >= total) { break; }\n        const size_t next_start = start + kParallelForStride;\n        const size_t end = std::min(next_start, total);\n        for_range(engine, start, end);\n      }\n      engine->WaitUntilDone();\n      bc.Decrease();\n    });\n  }\n  bc.WaitForeverUntilCntEqualZero();\n}\n\ntemplate<typename Key, typename Engine>\nclass SnapshotIteratorImpl : public PersistentTable::Iterator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SnapshotIteratorImpl);\n  SnapshotIteratorImpl(PersistentTableImpl<Key, Engine>* table, const std::string& snapshot_name,\n                       uint32_t value_size, uint32_t logical_block_size,\n                       uint32_t num_values_per_block, uint64_t num_values_per_chunk)\n      : table_(table),\n        snapshot_name_(snapshot_name),\n        value_size_(value_size),\n        logical_block_size_(logical_block_size),\n        num_values_per_block_(num_values_per_block),\n        num_values_per_chunk_(num_values_per_chunk),\n        current_chunk_(0) {\n    
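// The snapshot list file records the name of the per-chunk index file of every non-empty chunk.\n    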
const std::string snapshot_list = table_->SnapshotListFilePath(snapshot_name);\n    std::ifstream list_if(snapshot_list);\n    std::string index_filename;\n    while (std::getline(list_if, index_filename)) { indices_names_.push_back(index_filename); }\n  }\n  ~SnapshotIteratorImpl() override = default;\n\n  void Next(uint32_t num_keys, uint32_t* return_keys, void* keys, void* values) override {\n    *return_keys = 0;\n    while (current_chunk_ < indices_names_.size()) {\n      if (!chunk_iterator_) {\n        const std::string snapshot_base = table_->SnapshotDirPath(snapshot_name_);\n        const uint64_t chunk_id = GetChunkId(indices_names_[current_chunk_], kIndexFileNamePrefix);\n        PosixFile index_file(PosixFile::JoinPath(snapshot_base, indices_names_[current_chunk_]),\n                             O_RDONLY, 0644);\n        const size_t index_file_size = index_file.Size();\n        CHECK_EQ(index_file_size % sizeof(uint64_t), 0);\n        if (index_file_size == 0) {\n          current_chunk_ += 1;\n          continue;\n        }\n        const size_t n_entries = index_file_size / sizeof(uint64_t);\n        indices_file_.reset(new PosixMappedFile(std::move(index_file), index_file_size, PROT_READ));\n        PosixFile key_file(table_->KeyFilePath(chunk_id), O_RDONLY, 0644);\n        keys_file_.reset(new PosixMappedFile(std::move(key_file), key_file.Size(), PROT_READ));\n        PosixFile value_file(table_->ValueFilePath(chunk_id), O_RDONLY, 0644);\n        values_file_.reset(\n            new PosixMappedFile(std::move(value_file), value_file.Size(), PROT_READ));\n        chunk_iterator_.reset(new ChunkIteratorImpl<Key>(\n            value_size_, logical_block_size_, num_values_per_block_, num_values_per_chunk_,\n            chunk_id, n_entries, static_cast<const Key*>(keys_file_->ptr()),\n            static_cast<const uint64_t*>(indices_file_->ptr()), values_file_->ptr()));\n      }\n      chunk_iterator_->Next(num_keys, return_keys, keys, values);\n      if (*return_keys == 0) {\n        chunk_iterator_.reset();\n        keys_file_.reset();\n        values_file_.reset();\n        indices_file_.reset();\n        current_chunk_ += 1;\n        continue;\n      } else {\n        return;\n      }\n    }\n  }\n\n  void Reset() override { UNIMPLEMENTED(); }\n\n private:\n  PersistentTableImpl<Key, Engine>* table_;\n  std::string snapshot_name_;\n  uint32_t value_size_;\n  uint32_t logical_block_size_;\n  uint32_t num_values_per_block_;\n  uint64_t num_values_per_chunk_;\n  size_t current_chunk_;\n  std::vector<std::string> indices_names_;\n  std::unique_ptr<PosixMappedFile> keys_file_;\n  std::unique_ptr<PosixMappedFile> values_file_;\n  std::unique_ptr<PosixMappedFile> indices_file_;\n  std::unique_ptr<ChunkIteratorImpl<Key>> chunk_iterator_;\n};\n\ntemplate<typename Engine>\nstd::unique_ptr<PersistentTable> DispatchKeyType(const PersistentTableOptions& options) {\n  if (options.key_size == 4) {\n    return std::unique_ptr<PersistentTable>(new PersistentTableImpl<uint32_t, Engine>(options));\n  } else if (options.key_size == 8) {\n    return std::unique_ptr<PersistentTable>(new PersistentTableImpl<uint64_t, Engine>(options));\n  } else {\n    UNIMPLEMENTED();\n    return nullptr;\n  }\n}\n\nstd::unique_ptr<PersistentTable> DispatchEngine(const PersistentTableOptions& options) {\n  return DispatchKeyType<AioEngine>(options);\n}\n\n}  // namespace\n\n#endif  // __linux__\n\nstd::unique_ptr<PersistentTable> NewPersistentTable(const PersistentTableOptions& options) {\n#ifdef __linux__\n  
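// Validate the options up front: a storage path is required and all sizes must be non-zero.\n  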
CHECK(!options.path.empty());\n  CHECK_GT(options.value_size, 0);\n  CHECK_GT(options.target_chunk_size_mb, 0);\n  CHECK_GT(options.physical_block_size, 0);\n  CHECK_GT(options.key_size, 0);\n  return DispatchEngine(options);\n#else\n  UNIMPLEMENTED();\n  return nullptr;\n#endif  // __linux__\n}\n\n}  // namespace embedding\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/embedding/persistent_table.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EMBEDDING_PERSISTENT_TABLE_H_\n#define ONEFLOW_CORE_EMBEDDING_PERSISTENT_TABLE_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\nstruct PersistentTableOptions {\n  std::string path;\n  uint32_t key_size = 0;\n  uint32_t value_size = 0;\n  uint64_t target_chunk_size_mb = 4 * 1024;\n  uint16_t physical_block_size = 4096;\n  uint64_t capacity_hint = 0;\n  bool read_only = false;\n};\n\nclass PersistentTable {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PersistentTable);\n  PersistentTable() = default;\n  virtual ~PersistentTable() = default;\n\n  class Iterator {\n   public:\n    OF_DISALLOW_COPY_AND_MOVE(Iterator);\n    Iterator() = default;\n    virtual ~Iterator() = default;\n\n    virtual void Next(uint32_t n_request, uint32_t* n_result, void* keys, void* values) = 0;\n    virtual void Reset() = 0;\n  };\n\n  virtual uint32_t KeySize() const = 0;\n  virtual uint32_t ValueSize() const = 0;\n  virtual uint32_t LogicalBlockSize() const = 0;\n  virtual void GetBlocks(uint32_t num_keys, const void* keys, void* blocks, uint32_t* offsets) = 0;\n  virtual void Get(uint32_t num_keys, const void* keys, void* values, uint32_t* n_missing,\n                   uint32_t* missing_indices) = 0;\n  virtual void PutBlocks(uint32_t num_keys, const void* keys, const void* blocks) = 0;\n  virtual void Put(uint32_t num_keys, const void* keys, const void* values) = 0;\n  virtual bool SnapshotExists(const std::string& name) = 0;\n  virtual void LoadSnapshot(const std::string& name) = 0;\n  virtual void LoadSnapshot(const std::string& name,\n                            const std::function<void(Iterator* iter)>& Hook) = 0;\n  virtual void SaveSnapshot(const std::string& name) = 0;\n  virtual Iterator* ReadSnapshot(const std::string& name) = 0;\n};\n\nstd::unique_ptr<PersistentTable> NewPersistentTable(const PersistentTableOptions& options);\n\n}  // namespace embedding\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EMBEDDING_PERSISTENT_TABLE_H_\n"
  },
  {
    "path": "oneflow/core/embedding/persistent_table_key_value_store.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/embedding/persistent_table_key_value_store.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/embedding/persistent_table.h\"\n#include <robin_hood.h>\n#include <fcntl.h>\n#include <sys/mman.h>\n#include <sys/stat.h>\n#include <dirent.h>\n\nnamespace oneflow {\n\nnamespace embedding {\n\nnamespace {\n\nclass IteratorImpl : public KVIterator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(IteratorImpl);\n  IteratorImpl(PersistentTable::Iterator* base_iter, uint32_t key_size, uint32_t value_size,\n               uint32_t max_query_length, void* host_keys_buffer, void* host_values_buffer,\n               uint32_t* host_num_buffer)\n      : base_iter_(base_iter),\n        key_size_(key_size),\n        value_size_(value_size),\n        max_query_length_(max_query_length),\n        host_keys_buffer_(host_keys_buffer),\n        host_values_buffer_(host_values_buffer),\n        host_num_buffer_(host_num_buffer) {}\n  ~IteratorImpl() override = default;\n\n  void NextN(ep::Stream* stream, uint32_t n_request, uint32_t* n_result, void* keys,\n             void* values) override {\n    CHECK_LE(n_request, max_query_length_);\n    auto cuda_stream = stream->As<ep::CudaStream>();\n    CHECK_JUST(cuda_stream->Sync());\n    base_iter_->Next(n_request, host_num_buffer_, host_keys_buffer_, host_values_buffer_);\n    OF_CUDA_CHECK(cudaMemcpyAsync(n_result, host_num_buffer_, sizeof(uint32_t), cudaMemcpyDefault,\n                                  cuda_stream->cuda_stream()));\n    const uint32_t num_keys = *host_num_buffer_;\n    if (num_keys != 0) {\n      OF_CUDA_CHECK(cudaMemcpyAsync(keys, host_keys_buffer_, num_keys * key_size_,\n                                    cudaMemcpyDefault, cuda_stream->cuda_stream()));\n      OF_CUDA_CHECK(cudaMemcpyAsync(values, host_values_buffer_, num_keys * value_size_,\n                                    cudaMemcpyDefault, cuda_stream->cuda_stream()));\n    }\n  }\n\n  void Reset() override { base_iter_->Reset(); }\n\n private:\n  PersistentTable::Iterator* base_iter_;\n  uint32_t key_size_;\n  uint32_t value_size_;\n  uint32_t max_query_length_;\n  void* host_keys_buffer_;\n  void* host_values_buffer_;\n  uint32_t* host_num_buffer_;\n};\n\ntemplate<typename Key>\nclass KeyValueStoreImpl : public KeyValueStore {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreImpl);\n  explicit KeyValueStoreImpl(const PersistentTableKeyValueStoreOptions& options)\n      : device_index_(-1), max_query_length_(0) {\n    OF_CUDA_CHECK(cudaGetDevice(&device_index_));\n    key_size_ = options.table_options.key_size;\n    value_size_ = options.table_options.value_size;\n    table_ = NewPersistentTable(options.table_options);\n    OF_CUDA_CHECK(NumaAwareCudaMallocHost(\n        device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * max_query_length_));\n    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,\n                
                          reinterpret_cast<void**>(&host_query_values_),\n                                          value_size_ * max_query_length_));\n    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast<void**>(&host_n_missing_),\n                                          sizeof(uint32_t)));\n    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,\n                                          reinterpret_cast<void**>(&host_missing_indices_),\n                                          sizeof(uint32_t) * max_query_length_));\n  }\n  ~KeyValueStoreImpl() {\n    CudaCurrentDeviceGuard guard(device_index_);\n    if (max_query_length_ != 0) {\n      OF_CUDA_CHECK(cudaFreeHost(host_query_keys_));\n      OF_CUDA_CHECK(cudaFreeHost(host_query_values_));\n      OF_CUDA_CHECK(cudaFreeHost(host_missing_indices_));\n    }\n    OF_CUDA_CHECK(cudaFreeHost(host_n_missing_));\n  }\n\n  uint32_t KeySize() const override { return key_size_; }\n\n  uint32_t ValueSize() const override { return value_size_; }\n\n  uint32_t MaxQueryLength() const override { return max_query_length_; }\n\n  void ReserveQueryLength(uint32_t query_length) override {\n    CudaCurrentDeviceGuard guard(device_index_);\n    if (query_length <= max_query_length_) { return; }\n    if (max_query_length_ != 0) {\n      OF_CUDA_CHECK(cudaFreeHost(host_query_keys_));\n      OF_CUDA_CHECK(cudaFreeHost(host_query_values_));\n      OF_CUDA_CHECK(cudaFreeHost(host_missing_indices_));\n    }\n    OF_CUDA_CHECK(NumaAwareCudaMallocHost(\n        device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * query_length));\n    OF_CUDA_CHECK(NumaAwareCudaMallocHost(\n        device_index_, reinterpret_cast<void**>(&host_query_values_), value_size_ * query_length));\n    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,\n                                          reinterpret_cast<void**>(&host_missing_indices_),\n                                          sizeof(uint32_t) * query_length));\n    max_query_length_ = query_length;\n  }\n\n  using KeyValueStore::Get;\n  void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,\n           uint32_t* n_missing, uint32_t* missing_indices) override;\n  void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;\n  bool SnapshotExists(const std::string& name) override;\n  void LoadSnapshot(const std::string& name) override;\n  void LoadSnapshot(const std::string& name,\n                    const std::function<void(KVIterator* iter)>& Hook) override;\n  void SaveSnapshot(const std::string& name) override;\n\n private:\n  int device_index_;\n  uint32_t max_query_length_;\n  uint32_t key_size_;\n  uint32_t value_size_;\n  Key* host_query_keys_{};\n  uint8_t* host_query_values_{};\n  uint32_t* host_n_missing_{};\n  uint32_t* host_missing_indices_{};\n\n  std::mutex mutex_;\n  std::unique_ptr<PersistentTable> table_;\n};\n\ntemplate<typename Key>\nvoid KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,\n                                 void* values, uint32_t* n_missing, uint32_t* missing_indices) {\n  std::lock_guard<std::mutex> lock(mutex_);\n  auto cuda_stream = stream->As<ep::CudaStream>();\n  CHECK_LE(num_keys, max_query_length_);\n  if (num_keys == 0) {\n    OF_CUDA_CHECK(cudaMemsetAsync(n_missing, 0, sizeof(uint32_t),\n                                  stream->As<ep::CudaStream>()->cuda_stream()));\n    return;\n  }\n  OF_CUDA_CHECK(cudaMemcpyAsync(host_query_keys_, keys, key_size_ * 
num_keys, cudaMemcpyDefault,\n                                cuda_stream->cuda_stream()));\n  CHECK_JUST(cuda_stream->Sync());\n\n  table_->Get(num_keys, host_query_keys_, host_query_values_, host_n_missing_,\n              host_missing_indices_);\n\n  OF_CUDA_CHECK(cudaMemcpyAsync(values, host_query_values_, num_keys * value_size_,\n                                cudaMemcpyDefault, cuda_stream->cuda_stream()));\n  OF_CUDA_CHECK(cudaMemcpyAsync(n_missing, host_n_missing_, sizeof(uint32_t), cudaMemcpyDefault,\n                                cuda_stream->cuda_stream()));\n  OF_CUDA_CHECK(cudaMemcpyAsync(missing_indices, host_missing_indices_,\n                                (*host_n_missing_) * sizeof(uint32_t), cudaMemcpyDefault,\n                                cuda_stream->cuda_stream()));\n}\n\ntemplate<typename Key>\nvoid KeyValueStoreImpl<Key>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,\n                                 const void* values) {\n  std::lock_guard<std::mutex> lock(mutex_);\n  auto cuda_stream = stream->As<ep::CudaStream>();\n  CHECK_LE(num_keys, max_query_length_);\n  if (num_keys == 0) { return; }\n  OF_CUDA_CHECK(cudaMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, cudaMemcpyDefault,\n                                cuda_stream->cuda_stream()));\n  OF_CUDA_CHECK(cudaMemcpyAsync(host_query_values_, values, value_size_ * num_keys,\n                                cudaMemcpyDefault, cuda_stream->cuda_stream()));\n  CHECK_JUST(cuda_stream->Sync());\n  table_->Put(num_keys, host_query_keys_, host_query_values_);\n}\n\ntemplate<typename Key>\nbool KeyValueStoreImpl<Key>::SnapshotExists(const std::string& name) {\n  return table_->SnapshotExists(name);\n}\n\ntemplate<typename Key>\nvoid KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name) {\n  CudaCurrentDeviceGuard guard(device_index_);\n  LoadSnapshot(name, nullptr);\n}\n\ntemplate<typename Key>\nvoid KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name,\n                                          const std::function<void(KVIterator* iter)>& Hook) {\n  CudaCurrentDeviceGuard guard(device_index_);\n  if (Hook) {\n    table_->LoadSnapshot(name, [&](PersistentTable::Iterator* chunk_iterator) {\n      IteratorImpl iterator(chunk_iterator, KeySize(), ValueSize(), max_query_length_,\n                            host_query_keys_, host_query_values_, host_n_missing_);\n      Hook(&iterator);\n    });\n  } else {\n    table_->LoadSnapshot(name);\n  }\n}\n\ntemplate<typename Key>\nvoid KeyValueStoreImpl<Key>::SaveSnapshot(const std::string& name) {\n  CudaCurrentDeviceGuard guard(device_index_);\n  table_->SaveSnapshot(name);\n}\n\n}  // namespace\n\nstd::unique_ptr<KeyValueStore> NewPersistentTableKeyValueStore(\n    const PersistentTableKeyValueStoreOptions& options) {\n  if (options.table_options.key_size == sizeof(uint64_t)) {\n    return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint64_t>(options));\n  } else if (options.table_options.key_size == sizeof(uint32_t)) {\n    return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint32_t>(options));\n  } else {\n    UNIMPLEMENTED();\n    return nullptr;\n  }\n}\n\n}  // namespace embedding\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/embedding/persistent_table_key_value_store.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EMBEDDING_PERSISTENT_TABLE_KEY_VALUE_STORE_H_\n#define ONEFLOW_CORE_EMBEDDING_PERSISTENT_TABLE_KEY_VALUE_STORE_H_\n\n#include \"oneflow/core/embedding/key_value_store.h\"\n#include \"oneflow/core/embedding/persistent_table.h\"\n\nnamespace oneflow {\n\nnamespace embedding {\n\n#ifdef WITH_CUDA\n\nstruct PersistentTableKeyValueStoreOptions {\n  PersistentTableOptions table_options{};\n};\n\nstd::unique_ptr<KeyValueStore> NewPersistentTableKeyValueStore(\n    const PersistentTableKeyValueStoreOptions& options);\n\n#endif  // WITH_CUDA\n\n}  // namespace embedding\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EMBEDDING_PERSISTENT_TABLE_KEY_VALUE_STORE_H_\n"
  },
  {
    "path": "oneflow/core/embedding/posix_file.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EMBEDDING_POSIX_FILE_H_\n#define ONEFLOW_CORE_EMBEDDING_POSIX_FILE_H_\n\n#ifdef __linux__\n\n#include <fcntl.h>\n#include <sys/types.h>\n#include <sys/stat.h>\n#include <unistd.h>\n#include <sys/ioctl.h>\n#include <linux/fs.h>\n#include <sys/mman.h>\n#include <libgen.h>\n#include <dirent.h>\n#include <sys/file.h>\n\nnamespace oneflow {\n\nnamespace embedding {\n\nclass PosixFile final {\n public:\n  PosixFile() : fd_(-1), size_(0) {}\n  PosixFile(const std::string& pathname, int flags, mode_t mode)\n      : PosixFile(pathname.c_str(), flags, mode) {}\n  PosixFile(const char* pathname, int flags, mode_t mode) : PosixFile() {\n    fd_ = open(pathname, flags, mode);\n    PCHECK(fd_ != -1);\n    struct stat sb {};\n    PCHECK(fstat(fd_, &sb) == 0);\n    size_ = sb.st_size;\n  }\n  PosixFile(PosixFile&& other) noexcept : PosixFile() { *this = std::move(other); }\n  PosixFile(const PosixFile&) = delete;\n  ~PosixFile() { Close(); }\n\n  PosixFile& operator=(PosixFile&& other) noexcept {\n    this->Close();\n    fd_ = other.fd_;\n    other.fd_ = -1;\n    size_ = other.size_;\n    other.size_ = 0;\n    return *this;\n  }\n  PosixFile& operator=(const PosixFile&) = delete;\n\n  int fd() { return fd_; }\n\n  bool IsOpen() { return fd_ != -1; }\n\n  void Close() {\n    if (IsOpen()) {\n      PCHECK(close(fd_) == 0);\n      fd_ = -1;\n    }\n  }\n\n  size_t Size() { return size_; }\n\n  void Truncate(size_t new_size) {\n    CHECK(IsOpen());\n    if (new_size == size_) { return; }\n    PCHECK(ftruncate(fd_, new_size) == 0);\n    size_ = new_size;\n  }\n\n  static bool FileExists(const std::string& pathname) {\n    return access(pathname.c_str(), F_OK) == 0;\n  }\n\n  static std::string JoinPath(const std::string& a, const std::string& b) { return a + \"/\" + b; }\n\n  static void RecursiveCreateDirectory(const std::string& pathname, mode_t mode) {\n    while (true) {\n      struct stat sb {};\n      if (stat(pathname.c_str(), &sb) == 0) {\n        CHECK(S_ISDIR(sb.st_mode)) << \"Could not create directory: '\" << pathname\n                                   << \"' already exists and is not a directory.\";\n        return;\n      } else {\n        PCHECK(errno == ENOENT) << \"Could not create directory '\" << pathname << \"'.\";\n        if (lstat(pathname.c_str(), &sb) == 0) {\n          LOG(FATAL) << \"Could not create directory: '\" << pathname << \"' is a broken link.\";\n        } else {\n          PCHECK(errno == ENOENT) << \"Could not create directory '\" << pathname << \"'.\";\n        }\n        std::vector<char> dirname_input(pathname.size() + 1);\n        std::memcpy(dirname_input.data(), pathname.c_str(), pathname.size() + 1);\n        const std::string parent = dirname(dirname_input.data());\n        RecursiveCreateDirectory(parent, mode);\n        if (mkdir(pathname.c_str(), mode) == 0) {\n          return;\n        } else {\n          PCHECK(errno == 
EEXIST) << \"Could not create directory '\" << pathname << \"'.\";\n        }\n      }\n    }\n  }\n\n  static void RecursiveDelete(const std::string& pathname) {\n    struct stat sb {};\n    if (stat(pathname.c_str(), &sb) == 0) {\n      if (S_ISDIR(sb.st_mode)) {\n        DIR* dir = opendir(pathname.c_str());\n        PCHECK(dir != nullptr);\n        struct dirent* ent = nullptr;\n        while ((ent = readdir(dir)) != nullptr) {\n          if (strcmp(ent->d_name, \".\") == 0 || strcmp(ent->d_name, \"..\") == 0) { continue; }\n          RecursiveDelete(pathname + \"/\" + ent->d_name);\n        }\n        PCHECK(closedir(dir) == 0);\n        PCHECK(rmdir(pathname.c_str()) == 0);\n      } else {\n        PCHECK(unlink(pathname.c_str()) == 0);\n      }\n    } else {\n      PCHECK(errno == ENOENT);\n    }\n  }\n\n private:\n  int fd_;\n  size_t size_;\n};\n\nclass PosixMappedFile final {\n public:\n  PosixMappedFile() : file_(), ptr_(nullptr) {}\n  PosixMappedFile(PosixFile&& file, size_t size, int prot, int flags)\n      : file_(std::move(file)), ptr_(nullptr) {\n    CHECK_NE(file_.fd(), -1);\n    void* ptr = mmap(nullptr, size, prot, flags, file_.fd(), 0);\n    PCHECK(ptr != MAP_FAILED);\n    ptr_ = ptr;\n  }\n  PosixMappedFile(PosixFile&& file, size_t size, int prot)\n      : PosixMappedFile(std::move(file), size, prot, MAP_SHARED) {}\n  PosixMappedFile(PosixMappedFile&& other) noexcept : PosixMappedFile() {\n    *this = std::move(other);\n  }\n  PosixMappedFile(const PosixMappedFile&) = delete;\n  ~PosixMappedFile() { Unmap(); }\n\n  PosixMappedFile& operator=(PosixMappedFile&& other) noexcept {\n    Unmap();\n    this->file_ = std::move(other.file_);\n    this->ptr_ = other.ptr_;\n    other.ptr_ = nullptr;\n    return *this;\n  }\n  PosixMappedFile& operator=(const PosixMappedFile&) = delete;\n\n  void* ptr() { return ptr_; }\n\n  PosixFile& file() { return file_; }\n\n private:\n  void Unmap() {\n    if (ptr_ != nullptr) { PCHECK(munmap(ptr_, file_.Size()) == 0); }\n  }\n  PosixFile file_;\n  void* ptr_;\n};\n\nclass PosixFileLockGuard final {\n public:\n  OF_DISALLOW_COPY(PosixFileLockGuard);\n  explicit PosixFileLockGuard() : file_() {}\n  explicit PosixFileLockGuard(PosixFile&& file) : file_(std::move(file)) {\n    CHECK_NE(file_.fd(), -1);\n    Lock();\n  }\n  PosixFileLockGuard(PosixFileLockGuard&& other) noexcept { *this = std::move(other); }\n  PosixFileLockGuard& operator=(PosixFileLockGuard&& other) noexcept {\n    Unlock();\n    file_ = std::move(other.file_);\n    return *this;\n  }\n  ~PosixFileLockGuard() { Unlock(); }\n\n private:\n  void Lock() {\n    if (file_.fd() != -1) {\n      struct flock f {};\n      f.l_type = F_WRLCK;\n      f.l_whence = SEEK_SET;\n      f.l_start = 0;\n      f.l_len = 0;\n      PCHECK(fcntl(file_.fd(), F_SETLK, &f) == 0);\n    }\n  }\n  void Unlock() {\n    if (file_.fd() != -1) {\n      struct flock f {};\n      f.l_type = F_UNLCK;\n      f.l_whence = SEEK_SET;\n      f.l_start = 0;\n      f.l_len = 0;\n      PCHECK(fcntl(file_.fd(), F_SETLK, &f) == 0);\n    }\n  }\n\n  PosixFile file_;\n};\n\n}  // namespace embedding\n\n}  // namespace oneflow\n\n#endif  // __linux__\n\n#endif  // ONEFLOW_CORE_EMBEDDING_POSIX_FILE_H_\n"
  },
  {
    "path": "oneflow/core/ep/common/active_device_guard.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/active_device_guard.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nActiveDeviceGuard::ActiveDeviceGuard(Device* device) : device_manager_(device->device_manager()) {\n  saved_active_device_ = device_manager_->GetActiveDeviceIndex();\n  device->SetAsActiveDevice();\n}\n\nActiveDeviceGuard::~ActiveDeviceGuard() {\n  device_manager_->SetActiveDeviceByIndex(saved_active_device_);\n}\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/common/device.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/device.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nEvent* Device::CreateEvent() {\n  Event* event = nullptr;\n  this->CreateEvents(&event, 1);\n  return event;\n}\n\nvoid Device::DestroyEvent(Event* event) { this->DestroyEvents(&event, 1); }\n\nbool Device::IsStreamOrderedMemoryAllocationSupported() const { return false; }\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/common/device_manager_registry.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/ep/include/device_manager.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass DeviceManagerRegistry::Impl {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Impl);\n  explicit Impl(DeviceManagerRegistry* registry) : registry_(registry) {\n    managers_.resize(DeviceType_ARRAYSIZE);\n  }\n  ~Impl() = default;\n\n  DeviceManager* GetDeviceManagerOrNull(DeviceType device_type) {\n    std::lock_guard<std::mutex> lock(mutex_);\n    if (!managers_.at(device_type)) {\n      std::lock_guard<std::mutex> factories_lock(*factories_mutex());\n      auto& factory = factories()->at(device_type);\n      if (factory) {\n        managers_.at(device_type) = factory->NewDeviceManager(registry_);\n      } else {\n        return nullptr;\n      }\n    }\n    return managers_.at(device_type).get();\n  }\n\n  DeviceManager* GetDeviceManager(DeviceType device_type) {\n    return CHECK_NOTNULL(GetDeviceManagerOrNull(device_type));\n  }\n\n  std::shared_ptr<Device> GetDevice(DeviceType device_type, size_t device_index) {\n    return GetDeviceManager(device_type)->GetDevice(device_index);\n  }\n\n  size_t GetDeviceCount(DeviceType device_type) {\n    DeviceManager* manager = GetDeviceManagerOrNull(device_type);\n    if (manager == nullptr) {\n      return 0;\n    } else {\n      return manager->GetDeviceCount();\n    }\n  }\n\n  size_t GetDeviceCount(const std::string& device_type_name) {\n    return GetDeviceCount(GetDeviceTypeByDeviceTypeName(device_type_name));\n  }\n\n  static void DumpVersionInfo() {\n    std::lock_guard<std::mutex> factories_lock(*factories_mutex());\n    for (auto& factory : *factories()) {\n      if (factory) { factory->DumpVersionInfo(); }\n    }\n  }\n\n  static std::string GetDeviceTypeNameByDeviceType(DeviceType device_type) {\n    static thread_local std::vector<std::string> device_type2device_type_name(DeviceType_ARRAYSIZE);\n    {\n      const std::string& name = device_type2device_type_name.at(device_type);\n      if (!name.empty()) { return name; }\n    }\n    std::lock_guard<std::mutex> factories_lock(*factories_mutex());\n    if (factories()->size() <= device_type) { return \"\"; }\n    auto& factory = factories()->at(device_type);\n    if (!factory) {\n      return \"\";\n    } else {\n      std::string name = factory->device_type_name();\n      device_type2device_type_name.at(device_type) = name;\n      return name;\n    }\n  }\n\n  static DeviceType GetDeviceTypeByDeviceTypeName(const std::string& device_type_name) {\n    static thread_local HashMap<std::string, DeviceType> device_type_name2device_type;\n    {\n      auto it = device_type_name2device_type.find(device_type_name);\n      if (it != device_type_name2device_type.end()) { return it->second; }\n    }\n    std::lock_guard<std::mutex> factories_lock(*factories_mutex());\n    auto it = 
device_type_name2device_type_map()->find(device_type_name);\n    if (it == device_type_name2device_type_map()->end()) {\n      return DeviceType::kInvalidDevice;\n    } else {\n      device_type_name2device_type[device_type_name] = it->second;\n      return it->second;\n    }\n  }\n\n  static void RegisterDeviceManagerFactory(std::unique_ptr<DeviceManagerFactory>&& factory) {\n    CHECK(factory);\n    const DeviceType device_type = factory->device_type();\n    std::lock_guard<std::mutex> lock(*factories_mutex());\n    factories()->resize(DeviceType_ARRAYSIZE);\n    CHECK(!factories()->at(device_type));\n    const std::string device_type_name = factory->device_type_name();\n    CHECK(!device_type_name.empty());\n    CHECK(device_type_name2device_type_map()->emplace(device_type_name, device_type).second);\n    factories()->at(device_type) = std::move(factory);\n  }\n\n  static std::set<DeviceType> GetRegisteredDeviceTypes() {\n    std::lock_guard<std::mutex> lock(*factories_mutex());\n    std::set<DeviceType> types;\n    for (auto& factory : *factories()) {\n      if (factory) { types.insert(factory->device_type()); }\n    }\n    return types;\n  }\n\n  static bool IsDeviceTypeRegistered(DeviceType device_type) {\n    std::lock_guard<std::mutex> lock(*factories_mutex());\n    return factories()->at(device_type).operator bool();\n  }\n\n private:\n  static HashMap<std::string, DeviceType>* device_type_name2device_type_map() {\n    static HashMap<std::string, DeviceType> device_type_name2device_type;\n    return &device_type_name2device_type;\n  }\n\n  static std::vector<std::unique_ptr<DeviceManagerFactory>>* factories() {\n    static std::vector<std::unique_ptr<DeviceManagerFactory>> factories_vec;\n    return &factories_vec;\n  }\n\n  static std::mutex* factories_mutex() {\n    static std::mutex mutex;\n    return &mutex;\n  }\n\n  std::mutex mutex_;\n  std::vector<std::unique_ptr<DeviceManager>> managers_;\n  DeviceManagerRegistry* registry_;\n};\n\nDeviceManagerRegistry::DeviceManagerRegistry() { impl_.reset(new Impl(this)); }\n\nDeviceManagerRegistry::~DeviceManagerRegistry() = default;\n\nDeviceManager* DeviceManagerRegistry::GetDeviceManager(DeviceType device_type) {\n  return impl_->GetDeviceManager(device_type);\n}\n\nDeviceManager* DeviceManagerRegistry::GetDeviceManagerOrNull(DeviceType device_type) {\n  return impl_->GetDeviceManagerOrNull(device_type);\n}\n\nstd::shared_ptr<Device> DeviceManagerRegistry::GetDevice(DeviceType device_type,\n                                                         size_t device_index) {\n  return impl_->GetDevice(device_type, device_index);\n}\n\nsize_t DeviceManagerRegistry::GetDeviceCount(DeviceType device_type) {\n  return impl_->GetDeviceCount(device_type);\n}\n\nsize_t DeviceManagerRegistry::GetDeviceCount(const std::string& device_type_name) {\n  return impl_->GetDeviceCount(device_type_name);\n}\n\n/*static*/ void DeviceManagerRegistry::RegisterDeviceManagerFactory(\n    std::unique_ptr<DeviceManagerFactory>&& factory) {\n  Impl::RegisterDeviceManagerFactory(std::move(factory));\n}\n\n/*static*/ void DeviceManagerRegistry::DumpVersionInfo() { Impl::DumpVersionInfo(); }\n\n/*static*/ std::string DeviceManagerRegistry::GetDeviceTypeNameByDeviceType(\n    DeviceType device_type) {\n  return Impl::GetDeviceTypeNameByDeviceType(device_type);\n}\n\n/*static*/ DeviceType DeviceManagerRegistry::GetDeviceTypeByDeviceTypeName(\n    const std::string& device_type_name) {\n  return Impl::GetDeviceTypeByDeviceTypeName(device_type_name);\n}\n\n/*static*/ 
std::set<DeviceType> DeviceManagerRegistry::GetRegisteredDeviceTypes() {\n  return Impl::GetRegisteredDeviceTypes();\n}\n\n/*static*/ bool DeviceManagerRegistry::IsDeviceTypeRegistered(DeviceType device_type) {\n  return Impl::IsDeviceTypeRegistered(device_type);\n}\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/common/onednn.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_COMMON_ONEDNN_H_\n#define ONEFLOW_CORE_EP_COMMON_ONEDNN_H_\n\n#ifdef WITH_ONEDNN\n\n#include \"oneflow/core/common/env_var/env_var.h\"\n\nnamespace oneflow {\n\nDEFINE_ENV_BOOL(ONEFLOW_ENABLE_ONEDNN_OPTS, true);\n\nnamespace ep {\nnamespace primitive {\n\ninline bool OneDnnIsEnabled() { return EnvBool<ONEFLOW_ENABLE_ONEDNN_OPTS>(); }\n\n}  // namespace primitive\n}  // namespace ep\n}  // namespace oneflow\n\n#endif  // WITH_ONEDNN\n\n#endif  // ONEFLOW_CORE_EP_COMMON_ONEDNN_H_\n"
  },
  {
    "path": "oneflow/core/ep/common/primitive/add.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/add.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nvoid Add::Launch(Stream* stream, const void* src0, const void* src1, void* dst, size_t count) {\n  const void* srcs[2];\n  srcs[0] = src0;\n  srcs[1] = src1;\n  Launch(stream, srcs, 2, dst, count);\n}\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/common/primitive/batch_matmul.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/batch_matmul.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_matmul.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\nclass BatchMatmulImpl : public BatchMatmul {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BatchMatmulImpl);\n  BatchMatmulImpl(BlasTransposeType transpose_a, BlasTransposeType transpose_b,\n                  std::unique_ptr<BroadcastMatmul>&& broadcast_matmul)\n      : transpose_a_(transpose_a),\n        transpose_b_(transpose_b),\n        broadcast_matmul_(std::move(broadcast_matmul)) {}\n  ~BatchMatmulImpl() override = default;\n\n  void Launch(Stream* stream, size_t batch_size, size_t m, size_t n, size_t k, Scalar alpha,\n              const void* a, const void* b, Scalar beta, void* c) override {\n    int64_t a_dims[3];\n    int64_t b_dims[3];\n    int64_t c_dims[3];\n    a_dims[0] = batch_size;\n    b_dims[0] = batch_size;\n    c_dims[0] = batch_size;\n    if (transpose_a_ == BlasTransposeType::N) {\n      a_dims[1] = m;\n      a_dims[2] = k;\n    } else if (transpose_a_ == BlasTransposeType::T) {\n      a_dims[1] = k;\n      a_dims[2] = m;\n    } else {\n      UNIMPLEMENTED();\n    }\n    if (transpose_b_ == BlasTransposeType::N) {\n      b_dims[1] = k;\n      b_dims[2] = n;\n    } else if (transpose_b_ == BlasTransposeType::T) {\n      b_dims[1] = n;\n      b_dims[2] = k;\n    } else {\n      UNIMPLEMENTED();\n    }\n    c_dims[1] = m;\n    c_dims[2] = n;\n    broadcast_matmul_->Launch(stream, alpha, 3, a_dims, a, 3, b_dims, b, beta, 3, c_dims, c);\n  }\n\n private:\n  BlasTransposeType transpose_a_;\n  BlasTransposeType transpose_b_;\n  std::unique_ptr<BroadcastMatmul> broadcast_matmul_;\n};\n\ntemplate<DeviceType device_type>\nclass BatchMatmulFactoryImpl : public BatchMatmulFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BatchMatmulFactoryImpl);\n  BatchMatmulFactoryImpl() = default;\n  ~BatchMatmulFactoryImpl() override = default;\n\n  std::unique_ptr<BatchMatmul> New(DataType data_type, BlasTransposeType transpose_a,\n                                   BlasTransposeType transpose_b) override {\n    auto broadcast_matmul =\n        NewPrimitive<BroadcastMatmulFactory>(device_type, data_type, transpose_a, transpose_b, 3);\n    if (!broadcast_matmul) { return nullptr; }\n    return std::make_unique<BatchMatmulImpl>(transpose_a, transpose_b, std::move(broadcast_matmul));\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, BatchMatmulFactory,\n                           BatchMatmulFactoryImpl<DeviceType::kCPU>);\n\n#ifdef WITH_CUDA\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BatchMatmulFactory,\n                           BatchMatmulFactoryImpl<DeviceType::kCUDA>);\n#endif  // WITH_CUDA\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/common/primitive/binary_functor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PRIMITIVE_COMMON_BINARY_FUNCTOR_H_\n#define ONEFLOW_CORE_PRIMITIVE_COMMON_BINARY_FUNCTOR_H_\n\n#include \"oneflow/core/ep/include/primitive/binary_op.h\"\n#include \"oneflow/core/ep/common/primitive/unary_functor.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include <cmath>\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\ntemplate<DeviceType device, BinaryOp binary_op, typename Src, typename Dst>\nstruct BinaryFunctor;\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kAdd, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 + src1); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kSub, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 - src1); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kMul, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 * src1); }\n};\n\ntemplate<DeviceType device>\nstruct BinaryFunctor<device, BinaryOp::kMul, bool, bool> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(bool src0, bool src1) const { return src0 && src1; }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kDiv, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 / src1); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kMax, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {\n    return static_cast<Dst>(src0 > src1 ? src0 : src1);\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kMin, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {\n    return static_cast<Dst>(src0 < src1 ? 
src0 : src1);\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kBitwiseAnd, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 & src1); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kBitwiseOr, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 | src1); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kBitwiseXor, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 ^ src1); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kEqual, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 == src1); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kNotEqual, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 != src1); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kLessThan, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 < src1); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kLessEqual, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 <= src1); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kGreaterThan, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 > src1); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kGreaterEqual, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 >= src1); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kIsCloseEqualNan, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1)\n      : atol(attr0.Value<float>()), rtol(attr1.Value<float>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {\n    bool close = src0 == src1;\n    close |= (std::isnan(src0) and std::isnan(src1));\n    if (atol == 0 and rtol == 0) return close;\n    Src allowed_error = static_cast<Src>(atol) + abs(static_cast<Src>(rtol) * src1);\n    Src actual_error = abs(src0 - src1);\n    close |= (std::isfinite(actual_error) and (actual_error <= allowed_error));\n    return close;\n  }\n  float atol, rtol;\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kIsClose, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar 
attr1)\n      : atol(attr0.Value<float>()), rtol(attr1.Value<float>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {\n    bool close = src0 == src1;\n    if (atol == 0 and rtol == 0) return close;\n    Src allowed_error = static_cast<Src>(atol) + abs(static_cast<Src>(rtol) * src1);\n    Src actual_error = abs(src0 - src1);\n    close |= (std::isfinite(actual_error) and (actual_error <= allowed_error));\n    return close;\n  }\n  float atol, rtol;\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kLogicalAnd, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 && src1); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kLogicalOr, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 || src1); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kLogicalXor, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {\n    return static_cast<bool>(src0) != static_cast<bool>(src1);\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kFmod, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 % src1); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kFloorDiv, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return src0 / src1; }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kTruncDiv, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 / src1); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kFloorMod, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {\n    Src trunc_mod = src0 % src1;\n    return (trunc_mod != static_cast<Src>(0))\n                   && ((src1 < static_cast<Src>(0)) != (trunc_mod < static_cast<Src>(0)))\n               ? 
trunc_mod + src1\n               : trunc_mod;\n  }\n};\n\n// Note: C++ '%' truncates toward zero, so the kFloorMod functor above shifts the remainder by\n// src1 whenever its sign differs from the divisor's, matching Python-style floor division.\n// Unsigned operands never need that fix-up, hence the plain '%' in the specializations below.\ntemplate<DeviceType device>\nstruct BinaryFunctor<device, BinaryOp::kFloorMod, uint8_t, uint8_t> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC uint8_t operator()(uint8_t src0, uint8_t src1) const { return src0 % src1; }\n};\n\ntemplate<DeviceType device>\nstruct BinaryFunctor<device, BinaryOp::kFloorMod, uint32_t, uint32_t> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC uint32_t operator()(uint32_t src0, uint32_t src1) const { return src0 % src1; }\n};\n\ntemplate<DeviceType device>\nstruct BinaryFunctor<device, BinaryOp::kFloorMod, uint64_t, uint64_t> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC uint64_t operator()(uint64_t src0, uint64_t src1) const { return src0 % src1; }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kScalarBasePowerGrad, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : scalar_operand(attr0.Value<Src>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {\n    return scalar_operand * (pow(src0, scalar_operand - static_cast<Src>(1))) * src1;\n  }\n  Src scalar_operand;\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kScalarExpPowerGrad, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : scalar_operand(attr0.Value<Src>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {\n    return (pow(scalar_operand, src0)) * log(scalar_operand) * src1;\n  }\n  Src scalar_operand;\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kIdentityBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return static_cast<Dst>(dy); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kEluBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value<double>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return (x > static_cast<Src>(0)) ? static_cast<Dst>(dy)\n                                     : static_cast<Dst>(dy * alpha * (exp(x)));\n  }\n  const Src alpha;\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kCeluBackwardWithDyY, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1)\n      : inv_alpha(1.0f / attr0.Value<double>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const {\n    return static_cast<Dst>((y > static_cast<Src>(0))\n                                ? 
dy\n                                : dy * static_cast<Src>(y * inv_alpha + static_cast<Src>(1)));\n  }\n  const Src inv_alpha;\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kHardswishBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    if (x <= static_cast<Src>(-3)) {\n      return static_cast<Dst>(0);\n    } else if (x >= static_cast<Src>(3)) {\n      return static_cast<Dst>(dy);\n    } else {\n      return static_cast<Dst>(((x / static_cast<Src>(3)) + static_cast<Src>(0.5)) * dy);\n    }\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kHardsigmoidBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return static_cast<Dst>((x <= static_cast<Src>(-3) || x >= static_cast<Src>(3))\n                                ? static_cast<Src>(0)\n                                : dy / static_cast<Src>(6));\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kHardshrinkBackwardWithDyY, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const {\n    return static_cast<Dst>(y == static_cast<Src>(0) ? 0 : dy);\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kHardtanhBackwardWithDyY, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1)\n      : min_val(attr0.Value<float>()), max_val(attr1.Value<float>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const {\n    return static_cast<Dst>((y == min_val || y == max_val) ? static_cast<Src>(0) : dy);\n  }\n\n  const Src min_val;\n  const Src max_val;\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kLeakyReluBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value<float>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return static_cast<Dst>((x > static_cast<Src>(0)) ? dy : dy * alpha);\n  }\n  const Src alpha;\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kMishBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    Src sp = log(static_cast<Src>(1) + exp(x));\n    Src grad_sp = static_cast<Src>(1) - exp(-sp);\n    Src tsp = (exp(sp) - exp(-sp)) / (exp(sp) + exp(-sp));\n    Src grad_tsp = (static_cast<Src>(1) - tsp * tsp) * grad_sp;\n    return static_cast<Dst>(dy * (x * grad_tsp + tsp));\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kReluBackwardWithDyY, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const {\n    return static_cast<Dst>((y <= static_cast<Src>(0.0)) ? static_cast<Src>(0.0) : dy);\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kReluBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return static_cast<Dst>((x <= static_cast<Src>(0.0)) ? 
static_cast<Src>(0.0) : dy);\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kSeluBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return static_cast<Dst>((x > static_cast<Src>(0)) ? scale * dy : dy * scale * alpha * (exp(x)));\n  }\n  const Src scale = 1.0507009873554804934193349852946;\n  const Src alpha = 1.6732632423543772848170429916717;\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kSiluBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    Src sig = static_cast<Src>(1) / (static_cast<Src>(1) + exp(-x));\n    return static_cast<Dst>(dy * (sig * (static_cast<Src>(1) + x * (static_cast<Src>(1) - sig))));\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kSoftsignBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    Src val = (static_cast<Src>(1) + abs(x));\n    return static_cast<Dst>(dy / (val * val));\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kSoftplusBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1)\n      : beta(attr0.Value<double>()), threshold(attr1.Value<double>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    Src z = exp(x * beta);\n    return static_cast<Dst>((x * beta) > threshold ? dy : dy * z / (z + static_cast<Src>(1.0)));\n  }\n  const Src beta;\n  const Src threshold;\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kSoftshrinkBackwardWithDyY, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value<double>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const {\n    return static_cast<Dst>(y == static_cast<Src>(0) ? 0 : dy);\n  }\n  const Src alpha;\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kThresholdBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : threshold(attr0.Value<double>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return static_cast<Dst>((x <= threshold) ? 
0 : dy);\n  }\n  const Src threshold;\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kAbsBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    const Src zero = static_cast<Src>(0.0);\n    if (x == zero) {\n      return zero;\n    } else if (x < zero) {\n      return -dy;\n    } else {\n      return dy;\n    }\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kAcosBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * -rsqrt(static_cast<Src>(1.0) - x * x);\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kAcoshBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * rsqrt(x * x - static_cast<Src>(1.0));\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kAsinBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * rsqrt(static_cast<Src>(1.0) - x * x);\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kAsinhBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * rsqrt(static_cast<Src>(1.0) + x * x);\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kAtanBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    const Src one = static_cast<Src>(1.0);\n    return dy * (one / (one + x * x));\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kAtanhBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    const Src one = static_cast<Src>(1.0);\n    return dy * (one / (one - x * x));\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kCosBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * (-sin(x)); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kCoshBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * sinh(x); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kErfBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * static_cast<Src>(2.0) * rsqrt(static_cast<Src>(M_PI)) * exp(-x * x);\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kErfcBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC 
Dst operator()(Src dy, Src x) const {\n    return dy * -static_cast<Src>(2.0) * rsqrt(static_cast<Src>(M_PI)) * exp(-x * x);\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kExpBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * exp(x); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kExp2BackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * exp2(x) * log(static_cast<Src>(2.0));\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kExpm1BackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * exp(x); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kLogBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * (static_cast<Src>(1.0) / x); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kLog2BackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * (static_cast<Src>(1.0) / (x * log(static_cast<Src>(2.0))));\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kLog10BackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * (static_cast<Src>(1.0) / (x * log(static_cast<Src>(10.0))));\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kLog1pBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * (static_cast<Src>(1.0) / (x + static_cast<Src>(1.0)));\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kLogSigmoidBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * (static_cast<Src>(1.0) / (exp(x) + static_cast<Src>(1.0)));\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kReciprocalBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * (-static_cast<Src>(1.0) / (x * x));\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kReciprocalNoNanBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    if (abs(x) <= static_cast<Src>(0.0)) { return static_cast<Dst>(0.0); }\n    return dy * (-static_cast<Src>(1.0) / (x * x));\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kRsqrtBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar 
attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * (static_cast<Src>(-1.0) / (static_cast<Src>(2.0) * sqrt(x * x * x)));\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kSigmoidBackwardWithDyY, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const { return dy * (y * (1.0 - y)); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kSigmoidBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    using UnaryOp = oneflow::ep::primitive::UnaryOp;\n    using UnaryFunctor = oneflow::ep::primitive::UnaryFunctor<device, UnaryOp::kSigmoid, Dst, Src>;\n    auto uf = UnaryFunctor(0, 0);\n    Src y = uf(x);\n    return dy * (y * (static_cast<Src>(1.0) - y));\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kSinBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * cos(x); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kSinhBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * cosh(x); }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kSqrtBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * static_cast<Src>(0.5) / sqrt(x);\n  }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kSquareBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * static_cast<Src>(2.0) * x; }\n};\n\ntemplate<DeviceType device, typename Src, typename Dst>\nstruct BinaryFunctor<device, BinaryOp::kTanBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    const Src cos_val = cos(x);\n    return dy * (static_cast<Src>(1.0) / (cos_val * cos_val));\n  }\n};\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PRIMITIVE_COMMON_BINARY_FUNCTOR_H_\n"
  },
  {
    "path": "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_BINARY\n#define ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_BINARY\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n#include \"oneflow/core/ep/include/primitive/binary_op.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/ep/common/primitive/util.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace broadcast_elementwise_binary {\n\nconstexpr size_t kMaxNumDims = 8;\n\ninline bool IsDimsEquals(size_t num_src0_dims, const int64_t* src0_dims, size_t num_src1_dims,\n                         const int64_t* src1_dims) {\n  if (num_src0_dims != num_src1_dims) { return false; }\n  for (size_t i = 0; i < num_src1_dims; ++i) {\n    if (src0_dims[i] != src1_dims[i]) { return false; }\n  }\n  return true;\n}\n\n#define BINARY_MATH_OP_SEQ_0           \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAdd) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSub) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMul) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDiv) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMax)\n\n#define BINARY_MATH_OP_SEQ_1                \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMin)      \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kPow)      \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFmod)     \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFloorDiv) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kTruncDiv)\n\n#define BINARY_MATH_OP_SEQ_2                           \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFloorMod)            \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kScalarBasePowerGrad) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kScalarExpPowerGrad)\n\n#define BINARY_COMPLEX_MATH_OP_SEQ     \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAdd) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSub) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMul) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDiv)\n\n#define BINARY_MATH_OP_SEQ \\\n  BINARY_MATH_OP_SEQ_0     \\\n  BINARY_MATH_OP_SEQ_1     \\\n  BINARY_MATH_OP_SEQ_2\n\n#define BINARY_COMPARISION_OP_SEQ_0         \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEqual)    \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kNotEqual) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessThan) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessEqual)\n\n#define BINARY_COMPARISION_OP_SEQ_1                \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterThan)     \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterEqual)    \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kIsCloseEqualNan) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kIsClose)\n\n#define BINARY_COMPARISION_OP_SEQ \\\n  BINARY_COMPARISION_OP_SEQ_0     \\\n  BINARY_COMPARISION_OP_SEQ_1\n\n#define BINARY_COMPLEX_COMPARISION_OP_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEqual)  \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kNotEqual)\n\n#define BINARY_LOGICAL_OP_SEQ                 \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalAnd) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalOr)  \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalXor)\n\n#define 
BINARY_BITWISE_OP_SEQ                 \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kBitwiseAnd) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kBitwiseOr)  \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kBitwiseXor)\n\n#define BINARY_MATH_FLOATING_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kZeta)\n\n#define BINARY_ACTIVATION_BACKWARD_OP_SEQ_0                   \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kIdentityBackwardWithDyX)    \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEluBackwardWithDyX)         \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kCeluBackwardWithDyY)        \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGeluBackwardWithDyX)        \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kHardswishBackwardWithDyX)   \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kHardsigmoidBackwardWithDyX) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kHardshrinkBackwardWithDyY)  \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kHardtanhBackwardWithDyY)\n\n#define BINARY_ACTIVATION_BACKWARD_OP_SEQ_1                 \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLeakyReluBackwardWithDyX) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMishBackwardWithDyX)      \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kReluBackwardWithDyY)      \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kReluBackwardWithDyX)      \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSeluBackwardWithDyX)      \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSiluBackwardWithDyX)      \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSoftsignBackwardWithDyX)  \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSoftplusBackwardWithDyX)\n\n#define BINARY_ACTIVATION_BACKWARD_OP_SEQ_2                  \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSoftshrinkBackwardWithDyY) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kTanhBackwardWithDyY)       \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kThresholdBackwardWithDyX)  \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFastGeluBackwardWithDyX)   \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kQuickGeluBackwardWithDyX)  \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSquareReLUBackwardWithDyX)\n\n#define BINARY_ACTIVATION_BACKWARD_OP_SEQ \\\n  BINARY_ACTIVATION_BACKWARD_OP_SEQ_0     \\\n  BINARY_ACTIVATION_BACKWARD_OP_SEQ_1     \\\n  BINARY_ACTIVATION_BACKWARD_OP_SEQ_2\n\n#define BINARY_MATH_BACKWARD_OP_SEQ_0                   \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAbsBackwardWithDyX)   \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAcosBackwardWithDyX)  \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAcoshBackwardWithDyX) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAsinBackwardWithDyX)  \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAsinhBackwardWithDyX) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAtanBackwardWithDyX)  \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAtanhBackwardWithDyX) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kCosBackwardWithDyX)\n\n#define BINARY_MATH_BACKWARD_OP_SEQ_1                     \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kCoshBackwardWithDyX)    \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kErfBackwardWithDyX)     \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kErfcBackwardWithDyX)    \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kExpBackwardWithDyX)     \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kExp2BackwardWithDyX)    \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kExpm1BackwardWithDyX)   \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLgammaBackwardWithDyX)  \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDigammaBackwardWithDyX) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogBackwardWithDyX)     \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLog2BackwardWithDyX)\n\n#define BINARY_MATH_BACKWARD_OP_SEQ_2                             \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLog10BackwardWithDyX)           \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLog1pBackwardWithDyX)           \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogSigmoidBackwardWithDyX)     
 \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kReciprocalBackwardWithDyX)      \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kReciprocalNoNanBackwardWithDyX) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kRsqrtBackwardWithDyX)           \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSinBackwardWithDyX)             \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSigmoidBackwardWithDyY)\n\n#define BINARY_MATH_BACKWARD_OP_SEQ_3                     \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSigmoidBackwardWithDyX) \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSinhBackwardWithDyX)    \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSqrtBackwardWithDyX)    \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSquareBackwardWithDyX)  \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kTanBackwardWithDyX)\n\n#define BINARY_MATH_BACKWARD_OP_SEQ_COMPLEX OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSqrtBackwardWithDyX)\n\n#define BINARY_MATH_BACKWARD_OP_SEQ \\\n  BINARY_MATH_BACKWARD_OP_SEQ_0     \\\n  BINARY_MATH_BACKWARD_OP_SEQ_1     \\\n  BINARY_MATH_BACKWARD_OP_SEQ_2     \\\n  BINARY_MATH_BACKWARD_OP_SEQ_3\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_BINARY\n"
  },
  {
    "path": "oneflow/core/ep/common/primitive/broadcast_elementwise_unary.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_UNARY\n#define ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_UNARY\n\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_unary.h\"\n#include \"oneflow/core/ep/include/primitive/fast_integer_math.h\"\n#include \"oneflow/core/ep/common/primitive/util.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace broadcast_elementwise_unary {\n\nconstexpr size_t kMaxNumDims = 8;\n\ntemplate<typename T, int N>\nclass IndexToOffsetWithStrideCalculator {\n public:\n  IndexToOffsetWithStrideCalculator() {}\n\n  OF_DEVICE_FUNC explicit IndexToOffsetWithStrideCalculator(const T* strides) {\n    InitStrides(strides, N);\n  }\n\n  template<typename U>\n  OF_DEVICE_FUNC explicit IndexToOffsetWithStrideCalculator(const U* strides) {\n    T strides_arr[N];\n    for (int i = 0; i < N; ++i) { strides_arr[i] = strides[i]; }\n    InitStrides(strides_arr, N);\n  }\n\n  OF_DEVICE_FUNC explicit IndexToOffsetWithStrideCalculator(const T* strides, int n) {\n    InitStrides(strides, n);\n  }\n\n  template<typename U>\n  OF_DEVICE_FUNC explicit IndexToOffsetWithStrideCalculator(const U* strides, int n) {\n    T strides_arr[N];\n    for (int i = 0; i < N; ++i) {\n      if (i < n) { strides_arr[i] = strides[i]; }\n    }\n    InitStrides(strides_arr, n);\n  }\n\n  ~IndexToOffsetWithStrideCalculator() = default;\n\n  OF_DEVICE_FUNC T NdIndexToOffset(const T* index) const {\n    T offset = 0;\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n    for (int i = 0; i < N - 1; ++i) { offset += index[i] * stride_[i]; }\n    offset += index[N - 1];\n    return offset;\n  }\n\n  OF_DEVICE_FUNC T NdIndexToOffset(const T* index, int n) const {\n    assert(n <= N);\n    T offset = 0;\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n    for (int i = 0; i < N; ++i) {\n      if (i < n) { offset += index[i] * stride_[i]; }\n    }\n    return offset;\n  }\n\n  OF_DEVICE_FUNC constexpr int Size() const { return N; }\n\n private:\n  OF_DEVICE_FUNC void InitStrides(const T* strides, const int n) {\n    for (int i = n; i < N; ++i) { stride_[i] = 1; }\n    for (int i = n - 1; i >= 0; --i) { stride_[i] = strides[i]; }\n  }\n\n  T stride_[N];\n};\n\ntemplate<typename T, int N>\nclass OffsetToIndexWithStrideCalculator {\n public:\n  OffsetToIndexWithStrideCalculator() {}\n\n  OF_DEVICE_FUNC explicit OffsetToIndexWithStrideCalculator(const T* dims) {\n    InitFastIntegerMath(dims, N);\n  }\n\n  template<typename U>\n  OF_DEVICE_FUNC explicit OffsetToIndexWithStrideCalculator(const U* dims) {\n    T dims_arr[N];\n    for (int i = 0; i < N; ++i) { dims_arr[i] = dims[i]; }\n    InitFastIntegerMath(dims_arr, N);\n  }\n\n  OF_DEVICE_FUNC explicit OffsetToIndexWithStrideCalculator(const T* dims, int n) {\n    InitFastIntegerMath(dims, n);\n  }\n\n  template<typename U>\n  OF_DEVICE_FUNC explicit OffsetToIndexWithStrideCalculator(const U* dims, 
int n) {\n    T dims_arr[N];\n    for (int i = 0; i < N; ++i) {\n      if (i < n) { dims_arr[i] = dims[i]; }\n    }\n    InitFastIntegerMath(dims_arr, n);\n  }\n\n  ~OffsetToIndexWithStrideCalculator() = default;\n\n  OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index) const {\n    T remaining = offset;\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n    for (int i = 0; i < N - 1; ++i) {\n      const T idx = math_helper_[i].divides(remaining);\n      index[i] = idx;\n      remaining = remaining - math_helper_[i].mul(idx);\n    }\n    index[N - 1] = remaining;\n  }\n\n  OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index, int n) const {\n    assert(n <= N);\n    T remaining = offset;\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n    for (int i = 0; i < N; ++i) {\n      if (i == n - 1) { break; }\n      if (i < n - 1) {\n        const T idx = math_helper_[i].divides(remaining);\n        index[i] = idx;\n        remaining = remaining - math_helper_[i].mul(idx);\n      }\n    }\n    index[n - 1] = remaining;\n  }\n\n  OF_DEVICE_FUNC T divides(T remaining, int64_t i) const {\n    return math_helper_[i].divides(remaining);\n  }\n\n  OF_DEVICE_FUNC T mul(T idx, int64_t i) const { return math_helper_[i].mul(idx); }\n\n  OF_DEVICE_FUNC constexpr int Size() const { return N; }\n\n private:\n  OF_DEVICE_FUNC void InitFastIntegerMath(const T* dims, const int n) {\n    T stride_arr[N];\n    for (int i = n - 1; i < N; ++i) {\n      stride_arr[i] = 1;\n      math_helper_[i] = FastIntegerMath<T>(1);\n    }\n    for (int i = n - 2; i >= 0; --i) {\n      stride_arr[i] = dims[i + 1] * stride_arr[i + 1];\n      math_helper_[i] = FastIntegerMath<T>(stride_arr[i]);\n    }\n  }\n  FastIntegerMath<T> math_helper_[N];\n};\n\n#define UNARY_IDENTITY_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIdentity)\n#define BROADCAST_ELEMENTWISE_CAST_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCast)\n\n}  // namespace broadcast_elementwise_unary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_UNARY\n"
  },
  {
    "path": "oneflow/core/ep/common/primitive/broadcast_matmul.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_COMMON_PRIMITIVE_BROADCAST_MATMUL_H_\n#define ONEFLOW_CORE_EP_COMMON_PRIMITIVE_BROADCAST_MATMUL_H_\n\n#include \"oneflow/core/ep/include/primitive/broadcast_matmul.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/framework/dtype.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace broadcast_matmul {\n\ninline void Simplify(size_t num_a_dims, const int64_t* a_dims, size_t num_b_dims,\n                     const int64_t* b_dims, size_t num_c_dims, const int64_t* c_dims,\n                     BlasTransposeType transpose_a, BlasTransposeType transpose_b, int64_t* m,\n                     int64_t* n, int64_t* k, int64_t* num_batch_dims, int64_t* broadcast_batch_dims,\n                     int64_t* a_batch_dims, int64_t* b_batch_dims, int64_t* c_batch_dims) {\n  CHECK_GE(num_a_dims, 2);\n  CHECK_GE(num_b_dims, 2);\n  CHECK_GE(num_c_dims, 2);\n  if (transpose_a == BlasTransposeType::N) {\n    *m = a_dims[num_a_dims - 2];\n    *k = a_dims[num_a_dims - 1];\n  } else if (transpose_a == BlasTransposeType::T) {\n    *m = a_dims[num_a_dims - 1];\n    *k = a_dims[num_a_dims - 2];\n  } else {\n    UNIMPLEMENTED();\n  }\n  CHECK_GT(*m, 0);\n  CHECK_GT(*k, 0);\n  if (transpose_b == BlasTransposeType::N) {\n    CHECK_EQ(b_dims[num_b_dims - 2], *k);\n    *n = b_dims[num_b_dims - 1];\n  } else if (transpose_b == BlasTransposeType::T) {\n    CHECK_EQ(b_dims[num_b_dims - 1], *k);\n    *n = b_dims[num_b_dims - 2];\n  } else {\n    UNIMPLEMENTED();\n  }\n  CHECK_GT(*n, 0);\n  CHECK_EQ(c_dims[num_c_dims - 2], *m);\n  CHECK_EQ(c_dims[num_c_dims - 1], *n);\n  const size_t num_max_batch_dims = std::max(std::max(num_a_dims, num_b_dims), num_c_dims) - 2;\n  auto MakeGetBatchDim = [num_max_batch_dims](size_t num_dims, const int64_t* dims) {\n    const int64_t num_batch_dims = num_dims - 2;\n    const int64_t num_padding_dims = num_max_batch_dims - num_batch_dims;\n    return [num_padding_dims, dims](size_t index) {\n      return index < num_padding_dims ? 
1 : dims[index - num_padding_dims];\n    };\n  };\n  auto GetABatchDim = MakeGetBatchDim(num_a_dims, a_dims);\n  auto GetBBatchDim = MakeGetBatchDim(num_b_dims, b_dims);\n  auto GetCBatchDim = MakeGetBatchDim(num_c_dims, c_dims);\n  *num_batch_dims = 0;\n  bool prev_broadcast_a = false;\n  bool prev_broadcast_b = false;\n  bool prev_broadcast_c = false;\n  // Walk the (left-padded) batch dims: size-1 dims are dropped and adjacent dims with the same\n  // broadcast pattern are merged, leaving as few batch dims as possible.\n  for (int64_t i = 0; i < num_max_batch_dims; ++i) {\n    const int64_t a_dim = GetABatchDim(i);\n    const int64_t b_dim = GetBBatchDim(i);\n    const int64_t c_dim = GetCBatchDim(i);\n    const int64_t broadcast_dim = std::max(std::max(a_dim, b_dim), c_dim);\n    CHECK_GT(broadcast_dim, 0);\n    const bool broadcast_a = (a_dim == 1);\n    const bool broadcast_b = (b_dim == 1);\n    const bool broadcast_c = (c_dim == 1);\n    CHECK((a_dim == broadcast_dim) || broadcast_a);\n    CHECK((b_dim == broadcast_dim) || broadcast_b);\n    CHECK((c_dim == broadcast_dim) || broadcast_c);\n    if (broadcast_dim == 1) {\n      continue;\n    } else if (*num_batch_dims != 0\n               && (prev_broadcast_a == broadcast_a && prev_broadcast_b == broadcast_b\n                   && prev_broadcast_c == broadcast_c)) {\n      a_batch_dims[*num_batch_dims - 1] *= a_dim;\n      b_batch_dims[*num_batch_dims - 1] *= b_dim;\n      c_batch_dims[*num_batch_dims - 1] *= c_dim;\n      broadcast_batch_dims[*num_batch_dims - 1] *= broadcast_dim;\n    } else {\n      a_batch_dims[*num_batch_dims] = a_dim;\n      b_batch_dims[*num_batch_dims] = b_dim;\n      c_batch_dims[*num_batch_dims] = c_dim;\n      broadcast_batch_dims[*num_batch_dims] = broadcast_dim;\n      *num_batch_dims += 1;\n      prev_broadcast_a = broadcast_a;\n      prev_broadcast_b = broadcast_b;\n      prev_broadcast_c = broadcast_c;\n    }\n  }\n  // When b alone is broadcast along the innermost batch dim and a is not transposed, that dim\n  // can be folded into m.\n  if (*num_batch_dims >= 1 && a_batch_dims[*num_batch_dims - 1] != 1\n      && b_batch_dims[*num_batch_dims - 1] == 1 && c_batch_dims[*num_batch_dims - 1] != 1\n      && transpose_a == BlasTransposeType::N) {\n    *m *= a_batch_dims[*num_batch_dims - 1];\n    *num_batch_dims -= 1;\n  }\n}\n\ntemplate<size_t max_num_dims, typename Func>\nvoid ForEachMatmul(DataType data_type, size_t m, size_t n, size_t k, Scalar beta,\n                   size_t num_batch_dims, const int64_t* broadcast_batch_dims,\n                   const int64_t* a_batch_dims, const int64_t* b_batch_dims,\n                   const int64_t* c_batch_dims, const void* a, const void* b, void* c, Func func) {\n  if (num_batch_dims == 0) {\n    func(a, b, c, beta);\n    return;\n  }\n  const size_t size_of_data_type = GetSizeOfDataType(data_type);\n  const size_t stride_a = m * k * size_of_data_type;\n  const size_t stride_b = k * n * size_of_data_type;\n  const size_t stride_c = m * n * size_of_data_type;\n  int64_t broadcast_batch_count = 1;\n  for (int64_t i = 0; i < num_batch_dims; ++i) { broadcast_batch_count *= broadcast_batch_dims[i]; }\n  NdIndexOffsetHelper<int64_t, max_num_dims> broadcast_index_helper(broadcast_batch_dims,\n                                                                    num_batch_dims);\n  NdIndexOffsetHelper<int64_t, max_num_dims> a_index_helper(a_batch_dims, num_batch_dims);\n  NdIndexOffsetHelper<int64_t, max_num_dims> b_index_helper(b_batch_dims, num_batch_dims);\n  NdIndexOffsetHelper<int64_t, max_num_dims> c_index_helper(c_batch_dims, num_batch_dims);\n  int64_t a_batch_index[max_num_dims]{};\n  int64_t b_batch_index[max_num_dims]{};\n  int64_t c_batch_index[max_num_dims]{};\n  int64_t broadcast_batch_index[max_num_dims]{};\n  bool init_c = true;\n  for 
(int64_t broadcast_batch_id = 0; broadcast_batch_id < broadcast_batch_count;\n       ++broadcast_batch_id) {\n    broadcast_index_helper.OffsetToNdIndex(broadcast_batch_id, broadcast_batch_index);\n    for (int64_t i = 0; i < num_batch_dims; ++i) {\n      if (a_batch_dims[i] == 1) {\n        a_batch_index[i] = 0;\n      } else {\n        a_batch_index[i] = broadcast_batch_index[i];\n      }\n      if (b_batch_dims[i] == 1) {\n        b_batch_index[i] = 0;\n      } else {\n        b_batch_index[i] = broadcast_batch_index[i];\n      }\n      if (c_batch_dims[i] == 1) {\n        c_batch_index[i] = 0;\n        if (broadcast_batch_index[i] != 0) { init_c = false; }\n      } else {\n        c_batch_index[i] = broadcast_batch_index[i];\n      }\n    }\n    const int64_t a_batch_id = a_index_helper.NdIndexToOffset(a_batch_index);\n    const int64_t b_batch_id = b_index_helper.NdIndexToOffset(b_batch_index);\n    const int64_t c_batch_id = c_index_helper.NdIndexToOffset(c_batch_index);\n    const void* a_ptr = static_cast<const unsigned char*>(a) + a_batch_id * stride_a;\n    const void* b_ptr = static_cast<const unsigned char*>(b) + b_batch_id * stride_b;\n    void* c_ptr = static_cast<unsigned char*>(c) + c_batch_id * stride_c;\n    const Scalar batch_beta = init_c ? beta : Scalar(1);\n    func(a_ptr, b_ptr, c_ptr, batch_beta);\n  }\n}\n\nnamespace internal {\n\nnamespace {\n\nvoid LaunchBroadcastMatmul(Stream* stream, DataType data_type, BlasTransposeType transpose_a,\n                           BlasTransposeType transpose_b, int64_t num_batch_dims,\n                           const int64_t* broadcast_batch_dims, const int64_t* a_batch_dims,\n                           const int64_t* b_batch_dims, const int64_t* c_batch_dims, int64_t m,\n                           int64_t n, int64_t k, Scalar alpha, const void* a, const void* b,\n                           Scalar beta, void* c);\n\ntemplate<size_t max_num_dims>\nclass BroadcastMatmulImpl : public BroadcastMatmul {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmulImpl);\n  BroadcastMatmulImpl(DataType data_type, BlasTransposeType transpose_a,\n                      BlasTransposeType transpose_b)\n      : data_type_(data_type), transpose_a_(transpose_a), transpose_b_(transpose_b) {}\n  ~BroadcastMatmulImpl() override = default;\n\n  void Launch(Stream* stream, Scalar alpha, size_t num_a_dims, const int64_t* a_dims, const void* a,\n              size_t num_b_dims, const int64_t* b_dims, const void* b, Scalar beta,\n              size_t num_c_dims, const int64_t* c_dims, void* c) override {\n    CHECK_LE(num_a_dims, max_num_dims);\n    CHECK_LE(num_b_dims, max_num_dims);\n    CHECK_LE(num_c_dims, max_num_dims);\n    int64_t m = 0;\n    int64_t n = 0;\n    int64_t k = 0;\n    int64_t num_batch_dims = 0;\n    int64_t broadcast_batch_dims[max_num_dims]{};\n    int64_t a_batch_dims[max_num_dims]{};\n    int64_t b_batch_dims[max_num_dims]{};\n    int64_t c_batch_dims[max_num_dims]{};\n    Simplify(num_a_dims, a_dims, num_b_dims, b_dims, num_c_dims, c_dims, transpose_a_, transpose_b_,\n             &m, &n, &k, &num_batch_dims, broadcast_batch_dims, a_batch_dims, b_batch_dims,\n             c_batch_dims);\n    LaunchBroadcastMatmul(stream, data_type_, transpose_a_, transpose_b_, num_batch_dims,\n                          broadcast_batch_dims, a_batch_dims, b_batch_dims, c_batch_dims, m, n, k,\n                          alpha, a, b, beta, c);\n  }\n\n private:\n  DataType data_type_;\n  BlasTransposeType transpose_a_;\n  BlasTransposeType 
transpose_b_;\n};\n\n}  // namespace\n\n}  // namespace internal\n\n}  // namespace broadcast_matmul\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_COMMON_PRIMITIVE_BROADCAST_MATMUL_H_\n"
  },
  {
    "path": "oneflow/core/ep/common/primitive/broadcast_simplify_dims_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/common/primitive/util.h\"\n#include <gtest/gtest.h>\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\ntemplate<size_t max_num_dims>\nvoid TestSimplifyBroadcastDims(size_t num_src0_dims, const int64_t* src0_dims, size_t num_src1_dims,\n                               const int64_t* src1_dims, size_t expected_num_dims,\n                               const int64_t* expected_src0_dims, const int64_t* expected_src1_dims,\n                               const int64_t* expected_dst_dims) {\n  size_t simplified_num_dims = 0;\n  int64_t simplified_src0_dims[max_num_dims]{};\n  int64_t simplified_src1_dims[max_num_dims]{};\n  int64_t simplified_dst_dims[max_num_dims]{};\n  SimplifyBroadcastDims<max_num_dims>(num_src0_dims, src0_dims, num_src1_dims, src1_dims,\n                                      &simplified_num_dims, simplified_src0_dims,\n                                      simplified_src1_dims, simplified_dst_dims);\n  ASSERT_EQ(simplified_num_dims, expected_num_dims);\n  for (size_t i = 0; i < simplified_num_dims; ++i) {\n    ASSERT_EQ(simplified_src0_dims[i], expected_src0_dims[i]);\n    ASSERT_EQ(simplified_src1_dims[i], expected_src1_dims[i]);\n    ASSERT_EQ(simplified_dst_dims[i], expected_dst_dims[i]);\n  }\n}\n\nTEST(Broadcast, SimplifyBroadcastDims) {\n  constexpr size_t max_num_dims = 8;\n\n  const size_t num_src0_dims_1 = 4;\n  const size_t num_src1_dims_1 = 5;\n  int64_t src0_dims_1[max_num_dims]{2, 5, 10, 5};\n  int64_t src1_dims_1[max_num_dims]{5, 1, 5, 10, 1};\n  const size_t simplified_num_dims_1 = 4;\n  int64_t simplified_src0_dims_1[max_num_dims]{1, 2, 50, 5};\n  int64_t simplified_src1_dims_1[max_num_dims]{5, 1, 50, 1};\n  int64_t simplified_dst_dims_1[max_num_dims]{5, 2, 50, 5};\n  TestSimplifyBroadcastDims<max_num_dims>(\n      num_src0_dims_1, src0_dims_1, num_src1_dims_1, src1_dims_1, simplified_num_dims_1,\n      simplified_src0_dims_1, simplified_src1_dims_1, simplified_dst_dims_1);\n\n  const size_t num_src0_dims_2 = 4;\n  const size_t num_src1_dims_2 = 1;\n  int64_t src0_dims_2[max_num_dims]{10, 5, 1, 5};\n  int64_t src1_dims_2[max_num_dims]{5};\n  const size_t simplified_num_dims_2 = 2;\n  int64_t simplified_src0_dims_2[max_num_dims]{50, 5};\n  int64_t simplified_src1_dims_2[max_num_dims]{1, 5};\n  int64_t simplified_dst_dims_2[max_num_dims]{50, 5};\n  TestSimplifyBroadcastDims<max_num_dims>(\n      num_src0_dims_2, src0_dims_2, num_src1_dims_2, src1_dims_2, simplified_num_dims_2,\n      simplified_src0_dims_2, simplified_src1_dims_2, simplified_dst_dims_2);\n\n  const size_t num_src0_dims_3 = 4;\n  const size_t num_src1_dims_3 = 1;\n  int64_t src0_dims_3[max_num_dims]{2, 5, 10, 5};\n  int64_t src1_dims_3[max_num_dims]{1};\n  const size_t simplified_num_dims_3 = 1;\n  int64_t simplified_src0_dims_3[max_num_dims]{500};\n  int64_t simplified_src1_dims_3[max_num_dims]{1};\n  int64_t 
simplified_dst_dims_3[max_num_dims]{500};\n  TestSimplifyBroadcastDims<max_num_dims>(\n      num_src0_dims_3, src0_dims_3, num_src1_dims_3, src1_dims_3, simplified_num_dims_3,\n      simplified_src0_dims_3, simplified_src1_dims_3, simplified_dst_dims_3);\n}\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/common/primitive/constant_pad.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PRIMITIVE_COMMON_CONSTANT_PAD_H_\n#define ONEFLOW_CORE_PRIMITIVE_COMMON_CONSTANT_PAD_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n#include \"oneflow/core/ep/include/primitive/fast_integer_math.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace {\n\nconstexpr int32_t kMaxNumDims = 8;\n\nconstexpr int32_t Min(int32_t a, int32_t b) { return a < b ? a : b; }\nconstexpr int32_t kMaxPackBytes = 128 / 8;\n\ntemplate<typename T>\nconstexpr int32_t GetMaxPackSize() {\n  return Min(kMaxPackBytes / sizeof(T), 8);\n}\n\ntemplate<typename T, int pack_size>\nstruct GetPackType {\n  using type = typename std::aligned_storage<pack_size * sizeof(T), pack_size * sizeof(T)>::type;\n};\n\ntemplate<typename T, int pack_size>\nusing PackType = typename GetPackType<T, pack_size>::type;\n\ntemplate<typename T, size_t pack_size>\nunion Pack {\n  static_assert(sizeof(PackType<T, pack_size>) == sizeof(T) * pack_size, \"\");\n  explicit OF_DEVICE_FUNC Pack(T value) {\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n    for (int i = 0; i < pack_size; i++) { elem[i] = value; }\n  }\n  T elem[pack_size];\n  PackType<T, pack_size> storage;\n};\n\ntemplate<typename T>\nT GetValue(Scalar value) {\n  return value.Value<T>();\n}\n\ntemplate<typename T, int N>\nclass OffsetToIndexCalculator {\n public:\n  OffsetToIndexCalculator() {}\n  template<class... Ts>\n  OF_DEVICE_FUNC explicit OffsetToIndexCalculator(T d0, Ts... 
dims) {\n    constexpr int n = 1 + sizeof...(dims);\n    static_assert(n <= N, \"\");\n    T dims_arr[n] = {d0, static_cast<T>(dims)...};\n    InitFastIntegerMath(dims_arr, n);\n  }\n\n  OF_DEVICE_FUNC explicit OffsetToIndexCalculator(const T* dims) { InitFastIntegerMath(dims, N); }\n\n  template<typename U>\n  OF_DEVICE_FUNC explicit OffsetToIndexCalculator(const U* dims) {\n    T dims_arr[N];\n    for (int i = 0; i < N; ++i) { dims_arr[i] = dims[i]; }\n    InitFastIntegerMath(dims_arr, N);\n  }\n\n  OF_DEVICE_FUNC explicit OffsetToIndexCalculator(const T* dims, int n) {\n    InitFastIntegerMath(dims, n);\n  }\n\n  template<typename U>\n  OF_DEVICE_FUNC explicit OffsetToIndexCalculator(const U* dims, int n) {\n    T dims_arr[N];\n    for (int i = 0; i < N; ++i) {\n      if (i < n) { dims_arr[i] = dims[i]; }\n    }\n    InitFastIntegerMath(dims_arr, n);\n  }\n\n  ~OffsetToIndexCalculator() = default;\n\n  OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index) const {\n    T remaining = offset;\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n    for (int i = 0; i < N - 1; ++i) {\n      const T idx = math_helper_[i].divides(remaining);\n      index[i] = idx;\n      remaining = remaining - math_helper_[i].mul(idx);\n    }\n    index[N - 1] = remaining;\n  }\n\n  OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T* index, int n) const {\n    assert(n <= N);\n    T remaining = offset;\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n    for (int i = 0; i < N; ++i) {\n      if (i == n - 1) { break; }\n      if (i < n - 1) {\n        const T idx = math_helper_[i].divides(remaining);\n        index[i] = idx;\n        remaining = remaining - math_helper_[i].mul(idx);\n      }\n    }\n    index[n - 1] = remaining;\n  }\n\n  template<class... Ts>\n  OF_DEVICE_FUNC void OffsetToNdIndex(T offset, T& d0, Ts&... 
others) const {\n    constexpr int n = 1 + sizeof...(others);\n    static_assert(n <= N, \"\");\n    T* index[n] = {&d0, &others...};\n    T remaining = offset;\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n    for (int i = 0; i < n - 1; ++i) {\n      const T idx = math_helper_[i].divides(remaining);\n      *index[i] = idx;\n      remaining = remaining - math_helper_[i].mul(idx);\n    }\n    if (n == N) {\n      *index[n - 1] = remaining;\n    } else {\n      *index[n - 1] = math_helper_[n - 1].divides(remaining);\n    }\n  }\n\n  OF_DEVICE_FUNC constexpr int Size() const { return N; }\n\n private:\n  OF_DEVICE_FUNC void InitFastIntegerMath(const T* dims, const int n) {\n    T stride_arr[N];\n    for (int i = n - 1; i < N; ++i) {\n      stride_arr[i] = 1;\n      math_helper_[i] = FastIntegerMath<T>(1);\n    }\n    for (int i = n - 2; i >= 0; --i) {\n      stride_arr[i] = dims[i + 1] * stride_arr[i + 1];\n      math_helper_[i] = FastIntegerMath<T>(stride_arr[i]);\n    }\n  }\n  FastIntegerMath<T> math_helper_[N];\n};\n\ntemplate<size_t num_dims, typename IndexType>\nstruct ConstantPadParams {\n  NdIndexOffsetHelper<IndexType, num_dims> src_index_helper;\n  OffsetToIndexCalculator<IndexType, num_dims> dst_index_helper;\n  IndexType valid_start[num_dims];\n  IndexType valid_end[num_dims];\n  IndexType elem_cnt{};\n  const void* src{};\n  void* dst{};\n};\n\ntemplate<size_t max_pack_size>\nsize_t GetLaunchPackSize(size_t num_dims, void* dst, const int64_t* dst_dims, const void* src,\n                         const int64_t* src_dims, const int64_t* padding_before,\n                         const int64_t* padding_after) {\n  static_assert(max_pack_size > 0 && (max_pack_size & (max_pack_size - 1)) == 0, \"\");\n  const int64_t last_dst_dim_size = dst_dims[num_dims - 1];\n  const int64_t last_src_dim_size = src_dims[num_dims - 1];\n  const int64_t last_padding_before_size = padding_before[num_dims - 1];\n  const int64_t last_padding_after_size = padding_after[num_dims - 1];\n  auto src_ptr = reinterpret_cast<std::uintptr_t>(src);\n  auto dst_ptr = reinterpret_cast<std::uintptr_t>(dst);\n  for (size_t size = max_pack_size; size > 1; size /= 2) {\n    if (last_dst_dim_size % size == 0 && last_src_dim_size % size == 0\n        && last_padding_before_size % size == 0 && last_padding_after_size % size == 0\n        && src_ptr % size == 0 && dst_ptr % size == 0) {\n      return size;\n    }\n  }\n  return 1;\n}\n\nvoid SimplifyPadDims(size_t num_dims, const int64_t* src_dims, const int64_t* padding_before,\n                     const int64_t* padding_after, size_t* simplified_num_dims,\n                     int64_t* simplified_dst_dims, int64_t* simplified_src_dims,\n                     int64_t* simplified_padding_before, int64_t* simplified_padding_after) {\n  CHECK_NE(num_dims, 0);\n  size_t valid_num_dims = 0;\n  FOR_RANGE(size_t, i, 0, num_dims) {\n    const int64_t dst_dim = src_dims[i] + padding_before[i] + padding_after[i];\n    if ((i != 0) && (padding_before[i] == 0 && padding_after[i] == 0)) {\n      simplified_dst_dims[valid_num_dims - 1] *= dst_dim;\n      simplified_src_dims[valid_num_dims - 1] *= src_dims[i];\n      simplified_padding_before[valid_num_dims - 1] *= src_dims[i];\n      simplified_padding_after[valid_num_dims - 1] *= src_dims[i];\n    } else {\n      simplified_dst_dims[valid_num_dims] = dst_dim;\n      simplified_src_dims[valid_num_dims] = src_dims[i];\n      simplified_padding_before[valid_num_dims] = padding_before[i];\n      simplified_padding_after[valid_num_dims] = 
padding_after[i];\n      valid_num_dims += 1;\n    }\n  }\n  *simplified_num_dims = valid_num_dims;\n}\n\n}  // namespace\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PRIMITIVE_COMMON_CONSTANT_PAD_H_\n"
  },
  {
    "path": "oneflow/core/ep/common/primitive/copy_nd.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_PRIMITIVE_COMMON_COPY_ND_H_\n#define ONEFLOW_CORE_PRIMITIVE_COMMON_COPY_ND_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\ntemplate<size_t num_dims, typename IndexType>\nstruct CopyNdKernelParams {\n  NdIndexOffsetHelper<IndexType, num_dims> src_index_helper;\n  NdIndexOffsetHelper<IndexType, num_dims> dst_index_helper;\n  NdIndexOffsetHelper<IndexType, num_dims> copy_index_helper;\n  IndexType dst_pos[num_dims];\n  IndexType src_pos[num_dims];\n  IndexType count{};\n  const void* src{};\n  void* dst{};\n};\n\ntemplate<size_t max_movement_size>\nsize_t GetMovementSize(size_t elem_size, size_t num_dims, void* dst, const int64_t* dst_dims,\n                       const int64_t* dst_pos, const void* src, const int64_t* src_dims,\n                       const int64_t* src_pos, const int64_t* extent) {\n  static_assert(max_movement_size > 0 && (max_movement_size & (max_movement_size - 1)) == 0, \"\");\n  CHECK_GT(elem_size, 0);\n  CHECK_EQ((elem_size & (elem_size - 1)), 0);\n  CHECK_EQ(max_movement_size % elem_size, 0);\n  const int64_t last_dst_dim_size = dst_dims[num_dims - 1] * elem_size;\n  const int64_t last_dst_pos = dst_pos[num_dims - 1] * elem_size;\n  const int64_t last_src_dim_size = src_dims[num_dims - 1] * elem_size;\n  const int64_t last_src_pos = src_pos[num_dims - 1] * elem_size;\n  const int64_t last_extent = extent[num_dims - 1] * elem_size;\n  auto src_ptr = reinterpret_cast<std::uintptr_t>(src);\n  auto dst_ptr = reinterpret_cast<std::uintptr_t>(dst);\n  for (size_t size = max_movement_size; size > elem_size; size /= 2) {\n    if (last_dst_dim_size % size == 0 && last_dst_pos % size == 0 && last_src_dim_size % size == 0\n        && last_src_pos % size == 0 && last_extent % size == 0 && src_ptr % size == 0\n        && dst_ptr % size == 0) {\n      return size;\n    }\n  }\n  return elem_size;\n}\n\nvoid SimplifyCopyNdDims(size_t num_dims, const int64_t* dst_dims, const int64_t* dst_pos,\n                        const int64_t* src_dims, const int64_t* src_pos, const int64_t* extent,\n                        size_t* simplified_num_dims, int64_t* simplified_dst_dims,\n                        int64_t* simplified_dst_pos, int64_t* simplified_src_dims,\n                        int64_t* simplified_src_pos, int64_t* simplified_extent) {\n  CHECK_NE(num_dims, 0);\n  size_t valid_num_dims = 0;\n  FOR_RANGE(size_t, i, 0, num_dims) {\n    if ((i != 0) && (dst_dims[i] == src_dims[i]) && (dst_dims[i] == extent[i]) && (src_pos[i] == 0)\n        && (dst_pos[i] == 0)) {\n      simplified_dst_dims[valid_num_dims - 1] *= extent[i];\n      simplified_dst_pos[valid_num_dims - 1] *= extent[i];\n      simplified_src_dims[valid_num_dims - 1] *= extent[i];\n      simplified_src_pos[valid_num_dims - 1] *= 
extent[i];\n      simplified_extent[valid_num_dims - 1] *= extent[i];\n    } else {\n      simplified_dst_dims[valid_num_dims] = dst_dims[i];\n      simplified_dst_pos[valid_num_dims] = dst_pos[i];\n      simplified_src_dims[valid_num_dims] = src_dims[i];\n      simplified_src_pos[valid_num_dims] = src_pos[i];\n      simplified_extent[valid_num_dims] = extent[i];\n      valid_num_dims += 1;\n    }\n  }\n  *simplified_num_dims = valid_num_dims;\n}\n\nconstexpr size_t kMaxMovementSize = 16;\nconstexpr size_t kMaxNumDims = 8;\n\ntemplate<size_t num_dims, size_t movement_size, typename IndexType>\nvoid LaunchKernel(Stream* stream, CopyNdKernelParams<num_dims, IndexType> params);\n\ntemplate<size_t num_dims, size_t movement_size, typename IndexType>\nvoid LaunchKernel(Stream* stream, void* dst, const int64_t* dst_dims, const int64_t* dst_pos,\n                  const void* src, const int64_t* src_dims, const int64_t* src_pos,\n                  const int64_t* extent, size_t count) {\n  CopyNdKernelParams<num_dims, IndexType> params;\n  params.dst_index_helper = NdIndexOffsetHelper<IndexType, num_dims>(dst_dims);\n  params.src_index_helper = NdIndexOffsetHelper<IndexType, num_dims>(src_dims);\n  params.copy_index_helper = NdIndexOffsetHelper<IndexType, num_dims>(extent);\n  for (size_t i = 0; i < num_dims; ++i) {\n    params.dst_pos[i] = dst_pos[i];\n    params.src_pos[i] = src_pos[i];\n  }\n  params.src = src;\n  params.dst = dst;\n  params.count = static_cast<IndexType>(count);\n  LaunchKernel<num_dims, movement_size, IndexType>(stream, params);\n}\n\ntemplate<size_t num_dims, size_t movement_size>\nvoid DispatchIndexType(Stream* stream, void* dst, const int64_t* dst_dims, const int64_t* dst_pos,\n                       const void* src, const int64_t* src_dims, const int64_t* src_pos,\n                       const int64_t* extent) {\n  size_t count = 1;\n  for (size_t i = 0; i < num_dims; ++i) { count *= extent[i]; }\n  if (count < GetMaxVal<int32_t>()) {\n    LaunchKernel<num_dims, movement_size, int32_t>(stream, dst, dst_dims, dst_pos, src, src_dims,\n                                                   src_pos, extent, count);\n  } else {\n    LaunchKernel<num_dims, movement_size, int64_t>(stream, dst, dst_dims, dst_pos, src, src_dims,\n                                                   src_pos, extent, count);\n  }\n}\n\ntemplate<size_t num_dims>\nvoid DispatchMovementSize(Stream* stream, size_t movement_size, void* dst, const int64_t* dst_dims,\n                          const int64_t* dst_pos, const void* src, const int64_t* src_dims,\n                          const int64_t* src_pos, const int64_t* extent) {\n  void (*func)(Stream* /*stream*/, void* /*dst*/, const int64_t* /*dst_dims*/,\n               const int64_t* /*dst_pos*/, const void* /*src*/, const int64_t* /*src_dims*/,\n               const int64_t* /*src_pos*/, const int64_t* /*extent*/) = nullptr;\n  if (movement_size == 1) {\n    func = DispatchIndexType<num_dims, 1>;\n  } else if (movement_size == 2) {\n    func = DispatchIndexType<num_dims, 2>;\n  } else if (movement_size == 4) {\n    func = DispatchIndexType<num_dims, 4>;\n  } else if (movement_size == 8) {\n    func = DispatchIndexType<num_dims, 8>;\n  } else if (movement_size == 16) {\n    func = DispatchIndexType<num_dims, 16>;\n  } else {\n    UNIMPLEMENTED();\n  }\n  func(stream, dst, dst_dims, dst_pos, src, src_dims, src_pos, extent);\n}\n\nvoid LaunchWithSimplified(Stream* stream, size_t movement_size, size_t num_dims, void* dst,\n                          const 
int64_t* dst_dims, const int64_t* dst_pos, const void* src,\n                          const int64_t* src_dims, const int64_t* src_pos, const int64_t* extent) {\n  void (*func)(Stream* /*stream*/, size_t /*movement_size*/, void* /*dst*/,\n               const int64_t* /*dst_dims*/, const int64_t* /*dst_pos*/, const void* /*src*/,\n               const int64_t* /*src_dims*/, const int64_t* /*src_pos*/, const int64_t* /*extent*/) =\n      nullptr;\n  if (num_dims == 1) {\n    func = DispatchMovementSize<1>;\n  } else if (num_dims == 2) {\n    func = DispatchMovementSize<2>;\n  } else if (num_dims == 3) {\n    func = DispatchMovementSize<3>;\n  } else if (num_dims == 4) {\n    func = DispatchMovementSize<4>;\n  } else if (num_dims == 5) {\n    func = DispatchMovementSize<5>;\n  } else if (num_dims == 6) {\n    func = DispatchMovementSize<6>;\n  } else if (num_dims == 7) {\n    func = DispatchMovementSize<7>;\n  } else if (num_dims == 8) {\n    func = DispatchMovementSize<8>;\n  } else {\n    UNIMPLEMENTED();\n  }\n  func(stream, movement_size, dst, dst_dims, dst_pos, src, src_dims, src_pos, extent);\n}\n\ntemplate<size_t max_movement_size>\nvoid SimplifyCopyNd(size_t num_dims, const int64_t* dst_dims, const int64_t* dst_pos,\n                    const int64_t* src_dims, const int64_t* src_pos, const int64_t* extent,\n                    size_t* simplified_num_dims, int64_t* simplified_dst_dims,\n                    int64_t* simplified_dst_pos, int64_t* simplified_src_dims,\n                    int64_t* simplified_src_pos, int64_t* simplified_extent, size_t elem_size,\n                    void* dst, const void* src, size_t* movement_size) {\n  SimplifyCopyNdDims(num_dims, dst_dims, dst_pos, src_dims, src_pos, extent, simplified_num_dims,\n                     simplified_dst_dims, simplified_dst_pos, simplified_src_dims,\n                     simplified_src_pos, simplified_extent);\n  *movement_size = GetMovementSize<max_movement_size>(\n      elem_size, *simplified_num_dims, dst, simplified_dst_dims, simplified_dst_pos, src,\n      simplified_src_dims, simplified_src_pos, simplified_extent);\n  size_t movement_elem_num = *movement_size / elem_size;\n  simplified_dst_dims[*simplified_num_dims - 1] /= movement_elem_num;\n  simplified_dst_pos[*simplified_num_dims - 1] /= movement_elem_num;\n  simplified_src_dims[*simplified_num_dims - 1] /= movement_elem_num;\n  simplified_src_pos[*simplified_num_dims - 1] /= movement_elem_num;\n  simplified_extent[*simplified_num_dims - 1] /= movement_elem_num;\n}\n\nvoid SimplifyThenLaunch(Stream* stream, DataType data_type, size_t num_dims, void* dst,\n                        const int64_t* dst_dims, const int64_t* dst_pos, const void* src,\n                        const int64_t* src_dims, const int64_t* src_pos, const int64_t* extent) {\n  CHECK_GT(num_dims, 0) << \"num_dims must be greater than 0\";\n  CHECK_LE(num_dims, kMaxNumDims);\n  size_t simplified_num_dims = 0;\n  int64_t simplified_dst_dims[kMaxNumDims];\n  int64_t simplified_dst_pos[kMaxNumDims];\n  int64_t simplified_src_dims[kMaxNumDims];\n  int64_t simplified_src_pos[kMaxNumDims];\n  int64_t simplified_extent[kMaxNumDims];\n  size_t movement_size = 0;\n  SimplifyCopyNd<kMaxMovementSize>(num_dims, dst_dims, dst_pos, src_dims, src_pos, extent,\n                                   &simplified_num_dims, simplified_dst_dims, simplified_dst_pos,\n                                   simplified_src_dims, simplified_src_pos, simplified_extent,\n                                   GetSizeOfDataType(data_type), dst, 
src, &movement_size);\n  LaunchWithSimplified(stream, movement_size, simplified_num_dims, dst, simplified_dst_dims,\n                       simplified_dst_pos, src, simplified_src_dims, simplified_src_pos,\n                       simplified_extent);\n}\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PRIMITIVE_COMMON_COPY_ND_H_\n"
  },
  {
    "path": "oneflow/core/ep/common/primitive/elementwise_unary.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_COMMON_PRIMITIVE_ELEMENTWISE_UNARY_H_\n#define ONEFLOW_CORE_EP_COMMON_PRIMITIVE_ELEMENTWISE_UNARY_H_\n\n#include \"oneflow/core/ep/include/primitive/elementwise_unary.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\n#define UNARY_MATH_OP_SEQ              \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRelu) \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIdentity)\n\n#define UNARY_FLOATING_MATH_OP_SEQ                \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kElu)             \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCelu)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kGelu)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardSwish)       \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardSigmoid)     \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardShrink)      \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kHardTanh)        \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLeakyRelu)       \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kMish)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSelu)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSilu)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSoftShrink)      \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSoftSign)        \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSoftPlus)        \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kTanh)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kThreshold)       \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAbs)             \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAcos)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAcosh)           \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAsin)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAsinh)           \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAtan)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAtanh)           \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCeil)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCos)             \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCosh)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kDigamma)         \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kTrigamma)        \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kErf)             \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kErfc)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kExp)             \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kExp2)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kExpm1)           \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kFloor)           \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLgamma)          \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLog)             \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLog2)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLog10)           \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLog1p)           \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLogSigmoid)      \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNegative)        \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kReciprocal)      \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kReciprocalNoNan) \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRint)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRound)           \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRsqrt)           
\\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSigmoid)         \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSign)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSin)             \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSinh)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSqrt)            \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSquare)          \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kTan)             \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kTrunc)           \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNotEqualZero)    \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNanAssign)       \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kFastGelu)        \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kQuickGelu)       \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSquareReLU)\n\n#define UNARY_COMPLEX_C2C_OP_SEQ       \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kConj) \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSqrt) \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNegative)\n\n#define UNARY_COMPLEX_C2R_OP_SEQ       \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kReal) \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kImag)\n\n#define UNARY_COMPLEX_R2C_OP_SEQ           \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRealGrad) \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kImagGrad)\n\n#define UNARY_INT_MATH_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAbs)\n\n#define UNARY_LOGICAL_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLogicalNot)\n\n#define UNARY_BITWISE_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kBitwiseNot)\n\n#define UNARY_UTILS_OP_SEQ              \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIsInf) \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIsNan) \\\n  OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIsFinite)\n\n}  // namespace primitive\n}  // namespace ep\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_COMMON_PRIMITIVE_ELEMENTWISE_UNARY_H_\n"
  },
  {
    "path": "oneflow/core/ep/common/primitive/matmul.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include \"oneflow/core/ep/include/primitive/batch_matmul.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\nclass MatmulImpl : public Matmul {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MatmulImpl);\n  explicit MatmulImpl(std::unique_ptr<BatchMatmul>&& batch_matmul)\n      : batch_matmul_(std::move(batch_matmul)) {}\n  ~MatmulImpl() override = default;\n\n  void Launch(Stream* stream, size_t m, size_t n, size_t k, Scalar alpha, const void* a,\n              const void* b, Scalar beta, void* c) override {\n    batch_matmul_->Launch(stream, 1, m, n, k, alpha, a, b, beta, c);\n  }\n\n private:\n  std::unique_ptr<BatchMatmul> batch_matmul_;\n};\n\ntemplate<DeviceType device_type>\nclass MatmulFactoryImpl : public MatmulFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MatmulFactoryImpl);\n  MatmulFactoryImpl() = default;\n  ~MatmulFactoryImpl() override = default;\n\n  std::unique_ptr<Matmul> New(DataType data_type, BlasTransposeType transpose_a,\n                              BlasTransposeType transpose_b) override {\n    auto batch_matmul =\n        NewPrimitive<BatchMatmulFactory>(device_type, data_type, transpose_a, transpose_b);\n    if (!batch_matmul) { return nullptr; }\n    return std::make_unique<MatmulImpl>(std::move(batch_matmul));\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, MatmulFactory, MatmulFactoryImpl<DeviceType::kCPU>);\n\n#ifdef WITH_CUDA\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MatmulFactory, MatmulFactoryImpl<DeviceType::kCUDA>);\n#endif  // WITH_CUDA\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/common/primitive/permute.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_COMMON_PRIMITIVE_PERMUTE_H_\n#define ONEFLOW_CORE_EP_COMMON_PRIMITIVE_PERMUTE_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace permute {\n\ntemplate<size_t max_movement_size>\nsize_t GetMovementSize(size_t elem_size, size_t num_dims, const int64_t* src_dims, const void* src,\n                       const int* permutation, void* dst) {\n  static_assert(max_movement_size > 0 && (max_movement_size & (max_movement_size - 1)) == 0, \"\");\n  CHECK_GT(elem_size, 0);\n  CHECK_EQ((elem_size & (elem_size - 1)), 0);\n  CHECK_EQ(max_movement_size % elem_size, 0);\n  if (permutation[num_dims - 1] == num_dims - 1) {\n    const int64_t last_dim_size = src_dims[num_dims - 1] * elem_size;\n    auto src_ptr = reinterpret_cast<std::uintptr_t>(src);\n    auto dst_ptr = reinterpret_cast<std::uintptr_t>(dst);\n    for (size_t size = max_movement_size; size > elem_size; size /= 2) {\n      if (last_dim_size % size == 0 && src_ptr % size == 0 && dst_ptr % size == 0) { return size; }\n    }\n  }\n  return elem_size;\n}\n\ntemplate<size_t max_num_dims>\nvoid SimplifyPermutation(size_t num_dims, const int64_t* src_dims, const int* permutation,\n                         size_t* simplified_num_dims, int64_t* simplified_src_dims,\n                         int* simplified_permutation) {\n  CHECK_NE(num_dims, 0);\n  int64_t coalesced_dims[max_num_dims];\n  size_t start_permutation_index = 0;\n  while (start_permutation_index < num_dims) {\n    const size_t start_dim_index = permutation[start_permutation_index];\n    coalesced_dims[start_dim_index] = src_dims[start_dim_index];\n    size_t end_permutation_index = start_permutation_index + 1;\n    while (end_permutation_index < num_dims\n           && permutation[end_permutation_index] == permutation[end_permutation_index - 1] + 1) {\n      const size_t end_dim_index = permutation[end_permutation_index];\n      coalesced_dims[start_dim_index] *= src_dims[end_dim_index];\n      coalesced_dims[end_dim_index] = 1;\n      end_permutation_index += 1;\n    }\n    start_permutation_index = end_permutation_index;\n  }\n  size_t valid_num_dims = 0;\n  int mapping[max_num_dims];\n  for (size_t i = 0; i < num_dims; ++i) {\n    const int src_dim = coalesced_dims[i];\n    if (src_dim == 1) {\n      mapping[i] = -1;\n    } else {\n      mapping[i] = valid_num_dims;\n      simplified_src_dims[valid_num_dims] = src_dim;\n      valid_num_dims += 1;\n    }\n  }\n  if (valid_num_dims == 0) {\n    *simplified_num_dims = 1;\n    simplified_src_dims[0] = 1;\n    simplified_permutation[0] = 0;\n  } else {\n    *simplified_num_dims = valid_num_dims;\n    size_t permutation_index = 0;\n    for (size_t i = 0; i < num_dims; ++i) {\n      const int mapped = mapping[permutation[i]];\n      if (mapped >= 0) {\n        
simplified_permutation[permutation_index] = mapped;\n        permutation_index += 1;\n      }\n    }\n  }\n}\n\ntemplate<size_t max_num_dims, size_t max_movement_size>\nvoid SimplifyPermutation(size_t num_dims, const int64_t* src_dims, const int* permutation,\n                         size_t* simplified_num_dims, int64_t* simplified_src_dims,\n                         int* simplified_permutation, size_t elem_size, const void* src, void* dst,\n                         size_t* movement_size) {\n  const size_t pre_simplified_movement_size =\n      GetMovementSize<max_movement_size>(elem_size, num_dims, src_dims, src, permutation, dst);\n  int64_t tmp_dims[max_num_dims];\n  for (size_t i = 0; i < num_dims; ++i) { tmp_dims[i] = src_dims[i]; }\n  tmp_dims[num_dims - 1] /= (pre_simplified_movement_size / elem_size);\n  SimplifyPermutation<max_num_dims>(num_dims, tmp_dims, permutation, simplified_num_dims,\n                                    simplified_src_dims, simplified_permutation);\n  *movement_size =\n      GetMovementSize<max_movement_size>(pre_simplified_movement_size, *simplified_num_dims,\n                                         simplified_src_dims, src, simplified_permutation, dst);\n  simplified_src_dims[*simplified_num_dims - 1] /= (*movement_size / pre_simplified_movement_size);\n}\n\n}  // namespace permute\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_COMMON_PRIMITIVE_PERMUTE_H_\n"
  },
  {
    "path": "oneflow/core/ep/common/primitive/permute_impl.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_COMMON_PRIMITIVE_PERMUTE_IMPL_H_\n#define ONEFLOW_CORE_EP_COMMON_PRIMITIVE_PERMUTE_IMPL_H_\n\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"oneflow/core/ep/common/primitive/permute.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace permute {\n\nnamespace internal {\n\nnamespace {\n\ntemplate<size_t num_dims, typename IndexType>\nstruct PermuteKernelParams {\n  NdIndexOffsetHelper<IndexType, num_dims> src_index_helper;\n  NdIndexOffsetHelper<IndexType, num_dims> dst_index_helper;\n  int permutation[num_dims]{};\n  IndexType count{};\n  const void* src{};\n  void* dst{};\n};\n\nconstexpr size_t kMaxMovementSize = 16;\nconstexpr size_t kMaxNumDims = 8;\n\ntemplate<size_t num_dims, typename IndexType>\nPermuteKernelParams<num_dims, IndexType> MakePermuteParams(const int64_t* src_dims, const void* src,\n                                                           const int* permutation, void* dst,\n                                                           size_t count) {\n  PermuteKernelParams<num_dims, IndexType> params;\n  params.src_index_helper = NdIndexOffsetHelper<IndexType, num_dims>(src_dims);\n  int64_t dst_dims[num_dims];\n  for (size_t i = 0; i < num_dims; ++i) { dst_dims[i] = src_dims[permutation[i]]; }\n  params.dst_index_helper = NdIndexOffsetHelper<IndexType, num_dims>(dst_dims);\n  for (size_t i = 0; i < num_dims; ++i) { params.permutation[i] = permutation[i]; }\n  params.src = src;\n  params.dst = dst;\n  params.count = static_cast<IndexType>(count);\n  return params;\n}\n\ntemplate<size_t num_dims, size_t movement_size, typename IndexType>\nvoid LaunchKernel(Stream* stream, const int64_t* src_dims, const void* src, const int* permutation,\n                  void* dst, size_t count);\n\ntemplate<size_t num_dims, size_t movement_size>\nvoid DispatchIndexType(Stream* stream, const int64_t* src_dims, const void* src,\n                       const int* permutation, void* dst) {\n  size_t count = 1;\n  for (size_t i = 0; i < num_dims; ++i) { count *= src_dims[i]; }\n  if (count < GetMaxVal<int32_t>()) {\n    LaunchKernel<num_dims, movement_size, int32_t>(stream, src_dims, src, permutation, dst, count);\n  } else {\n    LaunchKernel<num_dims, movement_size, int64_t>(stream, src_dims, src, permutation, dst, count);\n  }\n}\n\ntemplate<size_t num_dims>\nvoid DispatchMovementSize(Stream* stream, size_t movement_size, const int64_t* src_dims,\n                          const void* src, const int* permutation, void* dst) {\n  void (*func)(Stream* /*stream*/, const int64_t* /*src_dims*/, const void* /*src*/,\n               const int* /*permutation*/, void* /*dst*/) = nullptr;\n  if (movement_size == 1) {\n    func = DispatchIndexType<num_dims, 1>;\n  } else if (movement_size == 2) {\n    func = DispatchIndexType<num_dims, 2>;\n  } else if (movement_size == 4) {\n    func = DispatchIndexType<num_dims, 
4>;\n  } else if (movement_size == 8) {\n    func = DispatchIndexType<num_dims, 8>;\n  } else if (movement_size == 16) {\n    func = DispatchIndexType<num_dims, 16>;\n  } else {\n    UNIMPLEMENTED();\n  }\n  func(stream, src_dims, src, permutation, dst);\n}\n\nvoid LaunchWithSimplified(Stream* stream, size_t movement_size, size_t num_dims,\n                          const int64_t* src_dims, const void* src, const int* permutation,\n                          void* dst) {\n  void (*func)(Stream* /*stream*/, size_t /*movement_size*/, const int64_t* /*src_dims*/,\n               const void* /*src*/, const int* /*permutation*/, void* /*dst*/) = nullptr;\n  if (num_dims == 1) {\n    func = DispatchMovementSize<1>;\n  } else if (num_dims == 2) {\n    func = DispatchMovementSize<2>;\n  } else if (num_dims == 3) {\n    func = DispatchMovementSize<3>;\n  } else if (num_dims == 4) {\n    func = DispatchMovementSize<4>;\n  } else if (num_dims == 5) {\n    func = DispatchMovementSize<5>;\n  } else if (num_dims == 6) {\n    func = DispatchMovementSize<6>;\n  } else if (num_dims == 7) {\n    func = DispatchMovementSize<7>;\n  } else if (num_dims == 8) {\n    func = DispatchMovementSize<8>;\n  } else {\n    UNIMPLEMENTED();\n  }\n  func(stream, movement_size, src_dims, src, permutation, dst);\n}\n\nvoid SimplifyThenLaunch(Stream* stream, DataType data_type, size_t num_dims,\n                        const int64_t* src_dims, const void* src, const int* permutation,\n                        void* dst) {\n  CHECK_LE(num_dims, kMaxNumDims);\n  CHECK_GT(num_dims, 0);\n  size_t simplified_num_dims = 0;\n  int64_t simplified_src_dims[kMaxNumDims];\n  int simplified_permutation[kMaxNumDims];\n  size_t movement_size = 0;\n  SimplifyPermutation<kMaxNumDims, kMaxMovementSize>(\n      num_dims, src_dims, permutation, &simplified_num_dims, simplified_src_dims,\n      simplified_permutation, GetSizeOfDataType(data_type), src, dst, &movement_size);\n  LaunchWithSimplified(stream, movement_size, simplified_num_dims, simplified_src_dims, src,\n                       simplified_permutation, dst);\n}\n\n}  // namespace\n\n}  // namespace internal\n\n}  // namespace permute\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_COMMON_PRIMITIVE_PERMUTE_IMPL_H_\n"
  },
  {
    "path": "oneflow/core/ep/common/primitive/permute_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/common/primitive/permute.h\"\n#include <gtest/gtest.h>\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace permute {\n\nnamespace {\n\ntemplate<size_t max_num_dims>\nvoid TestSimplifyPermutation(size_t num_dims, const int64_t* src_dims, const int* permutation,\n                             size_t expected_num_dims, const int64_t* expected_src_dims,\n                             const int* expected_permutation) {\n  size_t simplified_num_dims = 0;\n  int64_t simplified_src_dims[max_num_dims]{};\n  int simplified_permutation[max_num_dims]{};\n  SimplifyPermutation<max_num_dims>(num_dims, src_dims, permutation, &simplified_num_dims,\n                                    simplified_src_dims, simplified_permutation);\n  ASSERT_EQ(simplified_num_dims, expected_num_dims);\n  for (size_t i = 0; i < simplified_num_dims; ++i) {\n    ASSERT_EQ(simplified_src_dims[i], expected_src_dims[i]);\n    ASSERT_EQ(simplified_permutation[i], expected_permutation[i]);\n  }\n}\n\nTEST(Permute, SimplifyPermutation) {\n  constexpr size_t max_num_dims = 8;\n\n  const size_t num_dims_1 = 5;\n  int64_t src_dims_1[max_num_dims]{1, 2, 2, 1, 2};\n  int permutation_1[max_num_dims]{0, 1, 3, 4, 2};\n  const size_t simplified_num_dims_1 = 3;\n  int64_t simplified_src_dims_1[max_num_dims]{2, 2, 2};\n  int simplified_permutation_1[max_num_dims]{0, 2, 1};\n  TestSimplifyPermutation<max_num_dims>(num_dims_1, src_dims_1, permutation_1,\n                                        simplified_num_dims_1, simplified_src_dims_1,\n                                        simplified_permutation_1);\n\n  const size_t num_dims_2 = 4;\n  int64_t src_dims_2[max_num_dims]{5, 6, 7, 8};\n  int permutation_2[max_num_dims]{2, 3, 0, 1};\n  const size_t simplified_num_dims_2 = 2;\n  int64_t simplified_src_dims_2[max_num_dims]{5 * 6, 7 * 8};\n  int simplified_permutation_2[max_num_dims]{1, 0};\n  TestSimplifyPermutation<max_num_dims>(num_dims_2, src_dims_2, permutation_2,\n                                        simplified_num_dims_2, simplified_src_dims_2,\n                                        simplified_permutation_2);\n\n  const size_t num_dims_3 = 4;\n  int64_t src_dims_3[max_num_dims]{5, 6, 7, 8};\n  int permutation_3[max_num_dims]{0, 1, 2, 3};\n  const size_t simplified_num_dims_3 = 1;\n  int64_t simplified_src_dims_3[max_num_dims]{5 * 6 * 7 * 8};\n  int simplified_permutation_3[max_num_dims]{0};\n  TestSimplifyPermutation<max_num_dims>(num_dims_3, src_dims_3, permutation_3,\n                                        simplified_num_dims_3, simplified_src_dims_3,\n                                        simplified_permutation_3);\n}\n\n}  // namespace\n\n}  // namespace permute\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/common/primitive/unary_functor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_COMMON_PRIMITIVE_UNARY_FUNCTOR_H_\n#define ONEFLOW_CORE_EP_COMMON_PRIMITIVE_UNARY_FUNCTOR_H_\n\n#include \"oneflow/core/ep/include/primitive/unary_op.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/scalar.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\ntemplate<DeviceType device, UnaryOp unary_op, typename Dst, typename Src>\nstruct UnaryFunctor;\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kIdentity, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(src); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kElu, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value<double>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Dst>(\n        (src > static_cast<Src>(0.0)) ? src : alpha * (exp(src) - static_cast<Src>(1)));\n  }\n  const Src alpha;\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kCelu, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1)\n      : alpha(attr0.Value<double>()), inv_alpha(1.0f / attr0.Value<double>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Dst>(\n        (src > static_cast<Src>(0.0)) ? 
src : alpha * (exp(src * inv_alpha) - static_cast<Src>(1)));\n  }\n  const Src alpha;\n  const Src inv_alpha;\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kHardSwish, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    if (src <= static_cast<Src>(-3)) {\n      return static_cast<Dst>(0);\n    } else if (src >= static_cast<Src>(3)) {\n      return static_cast<Dst>(src);\n    } else {\n      return static_cast<Dst>((src * (src + static_cast<Src>(3))) / static_cast<Src>(6));\n    }\n  }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kHardSigmoid, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    if (src <= static_cast<Src>(-3)) {\n      return static_cast<Dst>(0);\n    } else if (src >= static_cast<Src>(3)) {\n      return static_cast<Dst>(1);\n    } else {\n      return static_cast<Dst>(src / static_cast<Src>(6) + static_cast<Src>(0.5));\n    }\n  }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kHardShrink, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : lambd(attr0.Value<double>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return (src <= lambd && src >= -lambd) ? static_cast<Dst>(0) : static_cast<Dst>(src);\n  }\n\n  const Src lambd;\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kHardTanh, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1)\n      : min_val(attr0.Value<double>()), max_val(attr1.Value<double>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    if (src <= min_val) {\n      return static_cast<Dst>(min_val);\n    } else if (src >= max_val) {\n      return static_cast<Dst>(max_val);\n    } else {\n      return static_cast<Dst>(src);\n    }\n  }\n\n  const Src min_val;\n  const Src max_val;\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kLeakyRelu, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value<float>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Dst>((src > static_cast<Src>(0.0)) ? 
src : alpha * src);\n  }\n  const Src alpha;\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kMish, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    Src soft_plus_val = log(static_cast<Src>(1) + exp(src));\n    Src exp_val = exp(soft_plus_val);\n    Src neg_exp_val = exp(-soft_plus_val);\n    Src tanh_val = (exp_val - neg_exp_val) / (exp_val + neg_exp_val);\n    return static_cast<Dst>(src * tanh_val);\n  }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kRelu, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    const Src zero_val = static_cast<Src>(0.0);\n    if (src <= zero_val) {\n      return static_cast<Dst>(zero_val);\n    } else {\n      return static_cast<Dst>(src);\n    }\n  }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kSilu, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Dst>(src / (static_cast<Src>(1) + exp(-src)));\n  }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kSelu, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Dst>((src > static_cast<Src>(0.0))\n                                ? src * scale\n                                : scale * alpha * (exp(src) - static_cast<Src>(1)));\n  }\n  const Src scale = 1.0507009873554804934193349852946;\n  const Src alpha = 1.6732632423543772848170429916717;\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kSoftSign, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Dst>(src / (static_cast<Src>(1) + abs(src)));\n  }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kSoftPlus, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1)\n      : beta(attr0.Value<double>()), threshold(attr1.Value<double>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Dst>(\n        (src * beta) > threshold ? src : log(static_cast<Src>(1.0) + exp(src * beta)) / beta);\n  }\n\n  const Src beta;\n  const Src threshold;\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kSoftShrink, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value<double>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    if (src <= alpha && src >= -alpha) {\n      return static_cast<Dst>(0);\n    } else if (src > alpha) {\n      return static_cast<Dst>(src - alpha);\n    } else {\n      return static_cast<Dst>(src + alpha);\n    }\n  }\n  const Src alpha;\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kThreshold, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1)\n      : threshold(attr0.Value<double>()), value(attr1.Value<double>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Dst>((src <= threshold) ? 
value : src);\n  }\n  const Src threshold;\n  const Src value;\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kLogicalNot, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(!src); }\n};\n\ntemplate<DeviceType device, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kIsInf, bool, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(Src src) const { return false; }\n};\n\ntemplate<DeviceType device, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kIsNan, bool, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(Src src) const { return false; }\n};\n\ntemplate<DeviceType device, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kIsFinite, bool, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(Src src) const { return true; }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kTrunc, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1);\n  OF_DEVICE_FUNC Dst operator()(Src src) const;\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kAbs, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(abs(src)); }\n};\n\ntemplate<DeviceType device>\nstruct UnaryFunctor<device, UnaryOp::kAbs, uint8_t, uint8_t> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC uint8_t operator()(uint8_t src) const { return src; }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kExp, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(exp(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kExp2, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(exp2(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kAcos, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(acos(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kAcosh, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(acosh(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kAsin, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(asin(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kAsinh, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(asinh(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kAtan, Dst, Src> {\n  
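// elementwise arc tangent\n  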
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(atan(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kAtanh, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(atanh(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kCeil, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(ceil(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kCos, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(cos(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kCosh, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(cosh(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kErf, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(erf(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kErfc, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(erfc(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kExpm1, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(expm1(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kFloor, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(floor(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kLgamma, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(lgamma(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kLog, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(log(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kLog2, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(log2(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kLog10, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(log10(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kLog1p, Dst, Src> {\n  
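// log1p(x) = log(1 + x), which stays accurate for small |x|\n  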
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(log1p(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kLogSigmoid, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Dst>(-log(static_cast<Src>(1.0) + exp(-src)));\n  }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kNegative, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(-src); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kReciprocal, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Dst>(static_cast<Src>(1.0) / src);\n  }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kReciprocalNoNan, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    if (abs(src) <= static_cast<Src>(0.0)) { return static_cast<Dst>(0.0); }\n    return static_cast<Dst>(static_cast<Src>(1.0) / src);\n  }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kRint, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(rint(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kRound, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(nearbyint(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kRsqrt, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(rsqrt(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kSigmoid, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Dst>(static_cast<Src>(1.0) / (static_cast<Src>(1.0) + exp(-src)));\n  }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kSign, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    const Src zero = static_cast<Src>(0.0);\n    if (src > zero) {\n      return static_cast<Dst>(1.0);\n    } else if (src < zero) {\n      return static_cast<Dst>(-1.0);\n    } else {\n      return static_cast<Dst>(0.0);\n    }\n  }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kSin, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(sin(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kSinh, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) 
const { return static_cast<Dst>(sinh(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kSqrt, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(sqrt(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kSquare, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(src * src); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kTan, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(tan(src)); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kNotEqualZero, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Dst>(src != static_cast<Src>(0.0));\n  }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kNanAssign, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return std::isnan(src) ? static_cast<Dst>(0.0) : src;\n  }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kCast, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(src); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kBitwiseNot, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(~src); }\n};\n\ntemplate<DeviceType device, typename Dst>\nstruct UnaryFunctor<device, UnaryOp::kBitwiseNot, Dst, bool> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(bool src) const { return static_cast<Dst>(!src); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kConj, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{src.real(), -src.imag()}; }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kReal, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(src.real()); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kImag, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(src.imag()); }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kRealGrad, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{src, 0.0}; }\n};\n\ntemplate<DeviceType device, typename Dst, typename Src>\nstruct UnaryFunctor<device, UnaryOp::kImagGrad, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  
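// lift a real-valued gradient into the imaginary part: src -> 0 + src * i\n  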
OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{0.0, src}; }\n};\n\n}  // namespace primitive\n}  // namespace ep\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_COMMON_PRIMITIVE_UNARY_FUNCTOR_H_\n"
  },
  {
    "path": "oneflow/core/ep/common/primitive/util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_COMMON_PRIMITIVE_UTIL_H_\n#define ONEFLOW_CORE_EP_COMMON_PRIMITIVE_UTIL_H_\n\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/ep/include/primitive/unary_op.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\ninline size_t GetElementCount(size_t num_dims, const int64_t* dims) {\n  size_t count = 1;\n  for (size_t i = 0; i < num_dims; ++i) { count *= dims[i]; }\n  return count;\n}\n\ntemplate<typename T>\nbool IsPackSizeSupported(const size_t pack_size, size_t num_dims, const int64_t* dims,\n                         const void* ptr) {\n  return (dims[num_dims - 1] % pack_size == 0)\n         && (reinterpret_cast<std::uintptr_t>(ptr) % (pack_size * sizeof(T)) == 0);\n}\n\ninline void CheckInplace(size_t num_dims, const int64_t* src_dims_or_strides, const void* src,\n                         const int64_t* dst_dims_or_strides, const void* dst) {\n  if (src == dst) {\n    for (int64_t i = 0; i < num_dims; ++i) {\n      CHECK_EQ(src_dims_or_strides[i], dst_dims_or_strides[i]);\n    }\n  }\n}\n\ntemplate<size_t max_num_dims>\ninline void SimplifyBroadcastDims(size_t num_src_dims, const int64_t* src_dims,\n                                  const int64_t* src_strides, size_t num_dst_dims,\n                                  const int64_t* dst_dims, const int64_t* dst_strides,\n                                  size_t* simplified_num_dims, int64_t* simplified_src_dims,\n                                  int64_t* simplified_src_strides, int64_t* simplified_dst_dims,\n                                  int64_t* simplified_dst_strides) {\n  *simplified_num_dims = 0;\n  std::pair<int64_t, size_t> sorted_dst_strides[max_num_dims];\n  int64_t new_dst_dims[max_num_dims];\n  int64_t new_src_dims[max_num_dims];\n  int64_t new_dst_strides[max_num_dims];\n  int64_t new_src_strides[max_num_dims];\n  for (size_t i = 0; i < num_dst_dims; i++) { sorted_dst_strides[i] = {dst_strides[i], i}; }\n  std::sort(sorted_dst_strides, sorted_dst_strides + num_dst_dims,\n            [](auto pair1, auto pair2) { return pair1.first > pair2.first; });\n  const int64_t num_src_padding_dims = num_dst_dims - num_src_dims;\n  // dimension completion\n  int64_t expanded_src_dims[max_num_dims];\n  int64_t expanded_src_strides[max_num_dims];\n  for (int64_t i = num_dst_dims - 1; i >= 0; i--) {\n    expanded_src_dims[i] = i < num_src_padding_dims ? 1 : src_dims[i - num_src_padding_dims];\n    expanded_src_strides[i] = i < num_src_padding_dims ? 
0 : src_strides[i - num_src_padding_dims];\n  }\n  // dimension permutation\n  for (int64_t i = num_dst_dims - 1; i >= 0; i--) {\n    size_t idx = sorted_dst_strides[i].second;\n    new_dst_dims[i] = dst_dims[idx];\n    new_dst_strides[i] = dst_strides[idx];\n    new_src_dims[i] = expanded_src_dims[idx];\n    new_src_strides[i] = expanded_src_strides[idx];\n  }\n  // dimension merge\n  bool prev_broadcast_src = false;\n  for (int64_t i = 0; i < num_dst_dims; ++i) {\n    const bool broadcast_src = (new_src_dims[i] == 1);\n    if (new_dst_dims[i] == 1) {\n      continue;\n    } else if (*simplified_num_dims != 0 && prev_broadcast_src == broadcast_src\n               && (new_src_strides[i - 1] == new_src_strides[i] * new_src_dims[i])\n               && (new_dst_strides[i - 1] == new_dst_strides[i] * new_dst_dims[i])) {\n      simplified_src_dims[*simplified_num_dims - 1] *= new_src_dims[i];\n      simplified_dst_dims[*simplified_num_dims - 1] *= new_dst_dims[i];\n      simplified_src_strides[*simplified_num_dims - 1] = new_src_strides[i];\n      simplified_dst_strides[*simplified_num_dims - 1] = new_dst_strides[i];\n    } else {\n      simplified_src_dims[*simplified_num_dims] = new_src_dims[i];\n      simplified_dst_dims[*simplified_num_dims] = new_dst_dims[i];\n      simplified_src_strides[*simplified_num_dims] = new_src_strides[i];\n      simplified_dst_strides[*simplified_num_dims] = new_dst_strides[i];\n      *simplified_num_dims += 1;\n      prev_broadcast_src = broadcast_src;\n    }\n  }\n  if (*simplified_num_dims == 0) {\n    simplified_src_dims[0] = 1;\n    simplified_dst_dims[0] = 1;\n    simplified_src_strides[0] = 1;\n    simplified_dst_strides[0] = 1;\n    *simplified_num_dims = 1;\n  }\n}\n\ninline void SimplifyBroadcastDims(size_t num_a_dims, const int64_t* a_dims, size_t num_b_dims,\n                                  const int64_t* b_dims, size_t num_c_dims, const int64_t* c_dims,\n                                  size_t* simplified_num_dims, int64_t* simplified_broadcast_dims,\n                                  int64_t* simplified_a_dims, int64_t* simplified_b_dims,\n                                  int64_t* simplified_c_dims) {\n  const size_t num_max_dims = std::max(num_a_dims, num_b_dims);\n  auto MakeGetDim = [num_max_dims](size_t num_dims, const int64_t* dims) {\n    const int64_t num_padding_dims = num_max_dims - num_dims;\n    return [num_padding_dims, dims](size_t index) {\n      return index < num_padding_dims ? 
1 : dims[index - num_padding_dims];\n    };\n  };\n  auto GetADim = MakeGetDim(num_a_dims, a_dims);\n  auto GetBDim = MakeGetDim(num_b_dims, b_dims);\n  auto GetCDim = MakeGetDim(num_c_dims, c_dims);\n  *simplified_num_dims = 0;\n  bool prev_broadcast_a = false;\n  bool prev_broadcast_b = false;\n  bool prev_broadcast_c = false;\n  for (int64_t i = 0; i < num_max_dims; ++i) {\n    const int64_t a_dim = GetADim(i);\n    const int64_t b_dim = GetBDim(i);\n    const int64_t c_dim = GetCDim(i);\n    const int64_t broadcast_dim = std::max(std::max(a_dim, b_dim), c_dim);\n    CHECK_GT(broadcast_dim, 0);\n    const bool broadcast_a = (a_dim == 1);\n    const bool broadcast_b = (b_dim == 1);\n    const bool broadcast_c = (c_dim == 1);\n    CHECK((a_dim == broadcast_dim) || broadcast_a);\n    CHECK((b_dim == broadcast_dim) || broadcast_b);\n    CHECK((c_dim == broadcast_dim) || broadcast_c);\n    if (broadcast_dim == 1) {\n      continue;\n    } else if (*simplified_num_dims != 0\n               && (prev_broadcast_a == broadcast_a && prev_broadcast_b == broadcast_b\n                   && prev_broadcast_c == broadcast_c)) {\n      simplified_a_dims[*simplified_num_dims - 1] *= a_dim;\n      simplified_b_dims[*simplified_num_dims - 1] *= b_dim;\n      simplified_c_dims[*simplified_num_dims - 1] *= c_dim;\n      simplified_broadcast_dims[*simplified_num_dims - 1] *= broadcast_dim;\n    } else {\n      simplified_a_dims[*simplified_num_dims] = a_dim;\n      simplified_b_dims[*simplified_num_dims] = b_dim;\n      simplified_c_dims[*simplified_num_dims] = c_dim;\n      simplified_broadcast_dims[*simplified_num_dims] = broadcast_dim;\n      *simplified_num_dims += 1;\n      prev_broadcast_a = broadcast_a;\n      prev_broadcast_b = broadcast_b;\n      prev_broadcast_c = broadcast_c;\n    }\n  }\n  if (*simplified_num_dims == 0) {\n    simplified_a_dims[0] = 1;\n    simplified_b_dims[0] = 1;\n    simplified_c_dims[0] = 1;\n    *simplified_num_dims = 1;\n  }\n}\n\ntemplate<size_t max_num_dims>\ninline void SimplifyBroadcastDims(size_t num_src0_dims, const int64_t* src0_dims,\n                                  size_t num_src1_dims, const int64_t* src1_dims,\n                                  size_t* simplified_num_dims, int64_t* simplified_src0_dims,\n                                  int64_t* simplified_src1_dims, int64_t* simplified_dst_dims) {\n  size_t src0_count = GetElementCount(num_src0_dims, src0_dims);\n  size_t src1_count = GetElementCount(num_src1_dims, src1_dims);\n  if (src0_count == 1 || src1_count == 1) {\n    *simplified_num_dims = 1;\n    simplified_src0_dims[0] = src0_count;\n    simplified_src1_dims[0] = src1_count;\n    simplified_dst_dims[0] = std::max(src0_count, src1_count);\n    return;\n  }\n  int64_t dst_dims[max_num_dims];\n  int64_t broadcast_dims[max_num_dims];\n  const size_t num_dst_dims = std::max(num_src0_dims, num_src1_dims);\n  for (int64_t i = 0; i < num_dst_dims; ++i) {\n    const int64_t num_src0_padding_dims = num_dst_dims - num_src0_dims;\n    const int64_t num_src1_padding_dims = num_dst_dims - num_src1_dims;\n    size_t src0_dim = i < num_src0_padding_dims ? 1 : src0_dims[i - num_src0_padding_dims];\n    size_t src1_dim = i < num_src1_padding_dims ? 
1 : src1_dims[i - num_src1_padding_dims];\n    dst_dims[i] = std::max(src0_dim, src1_dim);\n  }\n  SimplifyBroadcastDims(num_src0_dims, src0_dims, num_src1_dims, src1_dims, num_dst_dims, dst_dims,\n                        simplified_num_dims, broadcast_dims, simplified_src0_dims,\n                        simplified_src1_dims, simplified_dst_dims);\n  for (int64_t i = 0; i < *simplified_num_dims; ++i) {\n    CHECK_EQ(broadcast_dims[i], simplified_dst_dims[i]);\n  }\n}\n\ntemplate<size_t max_num_dims>\ninline bool InferPermutable(size_t simplified_num_dims, const int64_t* simplified_src_strides,\n                            const int64_t* simplified_dst_strides,\n                            const int64_t* simplified_src_dims, const int64_t* simplified_dst_dims,\n                            int* permutation_list, int64_t* permutation_src_dims,\n                            UnaryOp unary_op) {\n  if (unary_op != UnaryOp::kIdentity) { return false; }\n\n  // all dims of src & dst should be the same\n  for (size_t i = 0; i < simplified_num_dims; i++) {\n    if (simplified_src_dims[i] != simplified_dst_dims[i]) { return false; }\n  }\n\n  // only simplified_src_strides need to be sorted, simplified_dst_strides has been sorted in\n  // SimplifyBroadcastDims\n  std::pair<int64_t, size_t> sorted_src_strides[max_num_dims];\n  for (size_t i = 0; i < simplified_num_dims; i++) {\n    sorted_src_strides[i] = {simplified_src_strides[i], i};\n  }\n  std::sort(sorted_src_strides, sorted_src_strides + simplified_num_dims,\n            [](auto pair1, auto pair2) { return pair1.first > pair2.first; });\n\n  // src & dst has to be filled with numbers without strides\n  if (sorted_src_strides[simplified_num_dims - 1].first != 1) { return false; }\n  for (size_t i = simplified_num_dims - 1; i > 0; i--) {\n    if (sorted_src_strides[i - 1].first\n        != sorted_src_strides[i].first * simplified_src_dims[sorted_src_strides[i].second]) {\n      return false;\n    }\n  }\n\n  if (simplified_dst_strides[simplified_num_dims - 1] != 1) { return false; }\n  for (size_t i = simplified_num_dims - 1; i > 0; i--) {\n    if (simplified_dst_strides[i - 1] != simplified_dst_strides[i] * simplified_dst_dims[i]) {\n      return false;\n    }\n  }\n\n  for (size_t j = 0; j < simplified_num_dims; j++) {\n    permutation_list[j] = sorted_src_strides[j].second;\n    permutation_src_dims[j] = simplified_src_dims[sorted_src_strides[j].second];\n  }\n\n  return true;\n}\n\ntemplate<typename T, typename D>\nstd::unique_ptr<T> NewPrimitiveFromHandlers(\n    const std::map<D, std::function<std::unique_ptr<T>()>>& handlers, const D& key) {\n  const auto iter = handlers.find(key);\n  if (iter != handlers.end()) { return iter->second(); }\n  return nullptr;\n}\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_COMMON_PRIMITIVE_UTIL_H_\n"
  },
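  {
    "path": "examples/ep_sketches/simplify_broadcast_dims_sketch.cpp",
    "content": "/*\nIllustrative sketch, not part of the OneFlow source tree: a standalone demo of\nthe dimension-simplification idea used by SimplifyBroadcastDims in util.h\nabove. Shapes are right-aligned with size-1 padding, then adjacent dimensions\nwhose broadcast pattern does not change are merged, e.g. [1, 2, 8] vs\n[4, 2, 8] compacts to [1, 16] vs [4, 16]. All names below are local to the\nsketch.\n*/\n#include <algorithm>\n#include <cstddef>\n#include <cstdint>\n#include <iostream>\n#include <vector>\n\n// Pads dims on the left with 1s so it has ndim entries.\nstd::vector<int64_t> PadLeft(const std::vector<int64_t>& dims, size_t ndim) {\n  std::vector<int64_t> out(ndim, 1);\n  std::copy(dims.begin(), dims.end(), out.end() - dims.size());\n  return out;\n}\n\nint main() {\n  std::vector<int64_t> a{1, 2, 8};\n  std::vector<int64_t> b{4, 2, 8};\n  const size_t ndim = std::max(a.size(), b.size());\n  a = PadLeft(a, ndim);\n  b = PadLeft(b, ndim);\n  std::vector<int64_t> sa;\n  std::vector<int64_t> sb;\n  bool prev_broadcast_a = false;\n  bool prev_broadcast_b = false;\n  for (size_t i = 0; i < ndim; ++i) {\n    if (std::max(a[i], b[i]) == 1) { continue; }  // size-1 everywhere, contributes nothing\n    const bool broadcast_a = (a[i] == 1);\n    const bool broadcast_b = (b[i] == 1);\n    if (!sa.empty() && broadcast_a == prev_broadcast_a && broadcast_b == prev_broadcast_b) {\n      sa.back() *= a[i];  // same broadcast pattern: fold into the previous dim\n      sb.back() *= b[i];\n    } else {\n      sa.push_back(a[i]);\n      sb.push_back(b[i]);\n      prev_broadcast_a = broadcast_a;\n      prev_broadcast_b = broadcast_b;\n    }\n  }\n  std::cout << \"a:\";\n  for (int64_t d : sa) { std::cout << \" \" << d; }\n  std::cout << std::endl << \"b:\";\n  for (int64_t d : sb) { std::cout << \" \" << d; }\n  std::cout << std::endl;  // prints a: 1 16 and b: 4 16\n  return 0;\n}\n"
  },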
  {
    "path": "oneflow/core/ep/common/primitive/where.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_COMMON_PRIMITIVE_WHERE_H_\n#define ONEFLOW_CORE_EP_COMMON_PRIMITIVE_WHERE_H_\n\n#include \"oneflow/core/ep/include/primitive/where.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/ep/common/primitive/util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\nconstexpr size_t kMaxNumDims = 8;\n\ntemplate<typename R, typename Cond, typename X, typename Y>\nstruct WhereElemwiseFunctor {\n  OF_DEVICE_FUNC WhereElemwiseFunctor() {}\n\n  OF_DEVICE_FUNC R operator()(Cond cond, X x, Y y) const { return cond ? x : y; }\n};\n\ntemplate<typename T, typename CondT>\nusing WhereFunctor = WhereElemwiseFunctor<T, CondT, T, T>;\n\ntemplate<size_t NDIM, typename IndexType>\nstruct BroadcastElementwiseWhereParams {\n  NdIndexOffsetHelper<IndexType, NDIM> cond_index_helper;\n  NdIndexOffsetHelper<IndexType, NDIM> x_index_helper;\n  NdIndexOffsetHelper<IndexType, NDIM> y_index_helper;\n  NdIndexOffsetHelper<IndexType, NDIM> z_index_helper;\n  IndexType cond_index_mask[NDIM];\n  IndexType x_index_mask[NDIM];\n  IndexType y_index_mask[NDIM];\n  IndexType elem_cnt{};\n  const void* cond{};\n  const void* x{};\n  const void* y{};\n  void* z{};\n};\n\ntemplate<typename T, int pack_size>\nstruct alignas(sizeof(T) * pack_size) Packed {\n  OF_DEVICE_FUNC Packed() {\n    // do nothing\n  }\n  union {\n    T elem[pack_size];\n  };\n};\n\ninline bool IsDimsEquals(size_t ndim, const int64_t* a_dims, const int64_t* b_dims) {\n  for (size_t i = 0; i < ndim; ++i) {\n    if (a_dims[i] != b_dims[i]) { return false; }\n  }\n  return true;\n}\n\n// Calculate compact broadcast dimensions\n// For example:\n//   [1, 2, 8] and [4, 2, 8] can be compacted to [1, 16] and [4, 16]\n//   [4, 1, 8] and [8] -> [4, 8] and [1, 8]\n//   [1, 1, 8] and [4, 2, 8] -> [1, 8] and [8, 8]\n// after compacting, cond, x, y will have the same number of dims,\n// z_dims is the broadcast dims of compacted cond, x, y dims.\ninline void GetCompactBroadcastDims(const size_t num_cond_ndims, const int64_t* cond_dims,\n                                    const size_t num_x_dims, const int64_t* x_dims,\n                                    const size_t num_y_dims, const int64_t* y_dims,\n                                    size_t* compact_num_dims, int64_t* compact_cond_dims,\n                                    int64_t* compact_x_dims, int64_t* compact_y_dims,\n                                    int64_t* compact_z_dims) {\n  size_t max_num_dims = std::max(std::max(num_x_dims, num_y_dims), num_cond_ndims);\n  CHECK_LE(max_num_dims, kMaxNumDims);\n\n  auto MakeGetDimSize = [max_num_dims](size_t ndim, const int64_t* dims) {\n    size_t lpad = max_num_dims - ndim;\n    return [lpad, dims](int dim) -> int64_t { return dim < lpad ? 
1 : dims[dim - lpad]; };\n  };\n  auto GetCondDimSize = MakeGetDimSize(num_cond_ndims, cond_dims);\n  auto GetXDimSize = MakeGetDimSize(num_x_dims, x_dims);\n  auto GetYDimSize = MakeGetDimSize(num_y_dims, y_dims);\n\n  size_t& num_dims = *compact_num_dims;\n  num_dims = 0;\n  bool cond_pred_dim_broadcast = false;\n  bool x_pred_dim_broadcast = false;\n  bool y_pred_dim_broadcast = false;\n  for (int i = 0; i < max_num_dims; ++i) {\n    int64_t cond_dim_size = GetCondDimSize(i);\n    int64_t x_dim_size = GetXDimSize(i);\n    int64_t y_dim_size = GetYDimSize(i);\n    int64_t dim_size = std::max(std::max(x_dim_size, y_dim_size), cond_dim_size);\n    if (dim_size == 1) { continue; }\n    bool cond_broadcast = (cond_dim_size == 1);\n    bool x_broadcast = (x_dim_size == 1);\n    bool y_broadcast = (y_dim_size == 1);\n    if (*compact_num_dims > 0 && cond_broadcast == cond_pred_dim_broadcast\n        && x_broadcast == x_pred_dim_broadcast && y_broadcast == y_pred_dim_broadcast) {\n      compact_cond_dims[num_dims - 1] *= cond_dim_size;\n      compact_x_dims[num_dims - 1] *= x_dim_size;\n      compact_y_dims[num_dims - 1] *= y_dim_size;\n      compact_z_dims[num_dims - 1] *= dim_size;\n    } else {\n      compact_cond_dims[num_dims] = cond_dim_size;\n      compact_x_dims[num_dims] = x_dim_size;\n      compact_y_dims[num_dims] = y_dim_size;\n      compact_z_dims[num_dims] = dim_size;\n      num_dims += 1;\n      cond_pred_dim_broadcast = cond_broadcast;\n      x_pred_dim_broadcast = x_broadcast;\n      y_pred_dim_broadcast = y_broadcast;\n    }\n  }\n}\n\ntemplate<typename T, typename CondT, typename IndexT, size_t ndim, size_t cond_pack_size,\n         size_t x_pack_size, size_t y_pack_size>\nvoid LaunchKernel(Stream* stream, const int64_t* cond_dims, const int64_t* x_dims,\n                  const int64_t* y_dims, const int64_t* z_dims, const CondT* cond, const T* x,\n                  const T* y, T* z);\n\ntemplate<typename T, typename CondT>\nvoid LaunchScalarKernel(Stream* stream, const CondT* cond, const T* x, const T* y, T* z);\n\ntemplate<typename T, typename CondT, size_t ndim, size_t cond_pack_size, size_t x_pack_size,\n         size_t y_pack_size>\nvoid LaunchByDispatchIndexType(Stream* stream, int64_t* cond_dims, int64_t* x_dims, int64_t* y_dims,\n                               int64_t* z_dims, const CondT* cond, const T* x, const T* y, T* z) {\n  const size_t elem_cnt = GetElementCount(ndim, z_dims);\n  if (elem_cnt < GetMaxVal<int32_t>()) {\n    return LaunchKernel<T, CondT, int32_t, ndim, cond_pack_size, x_pack_size, y_pack_size>(\n        stream, cond_dims, x_dims, y_dims, z_dims, cond, x, y, z);\n  } else {\n    return LaunchKernel<T, CondT, int64_t, ndim, cond_pack_size, x_pack_size, y_pack_size>(\n        stream, cond_dims, x_dims, y_dims, z_dims, cond, x, y, z);\n  }\n}\n\ntemplate<typename T, typename CondT, size_t ndim, size_t max_pack_size>\nsize_t GetPackSize(const int64_t* cond_dims, const int64_t* x_dims, const int64_t* y_dims,\n                   const int64_t* z_dims, const CondT* cond, const T* x, const T* y, const T* z) {\n  static_assert(max_pack_size > 0 && (max_pack_size & (max_pack_size - 1)) == 0, \"\");\n  CHECK_GT(z_dims[ndim - 1], 1);\n  for (size_t pack_size = max_pack_size; pack_size >= 2; pack_size /= 2) {\n    if (!IsPackSizeSupported<T>(pack_size, ndim, z_dims, z)) { continue; }\n    if (x_dims[ndim - 1] != 1 && !IsPackSizeSupported<T>(pack_size, ndim, x_dims, x)) { continue; }\n    if (y_dims[ndim - 1] != 1 && !IsPackSizeSupported<T>(pack_size, ndim, 
y_dims, y)) { continue; }\n    if (cond_dims[ndim - 1] != 1 && !IsPackSizeSupported<CondT>(pack_size, ndim, cond_dims, cond)) {\n      continue;\n    }\n    return pack_size;\n  }\n  return 1;\n}\n\ntemplate<typename T, typename CondT, size_t ndim>\nvoid LaunchByDispatchPackSize(Stream* stream, int64_t* cond_dims, int64_t* x_dims, int64_t* y_dims,\n                              int64_t* z_dims, const CondT* cond, const T* x, const T* y, T* z) {\n  static_assert(ndim > 0, \"\");\n  constexpr size_t kMaxPackSize = 4;\n  size_t pack_size =\n      GetPackSize<T, CondT, ndim, kMaxPackSize>(cond_dims, x_dims, y_dims, z_dims, cond, x, y, z);\n  size_t cond_pack_size = 1;\n  size_t x_pack_size = 1;\n  size_t y_pack_size = 1;\n  if (pack_size > 1) {\n    if (cond_dims[ndim - 1] != 1) {\n      cond_dims[ndim - 1] /= pack_size;\n      cond_pack_size = pack_size;\n    }\n    if (x_dims[ndim - 1] != 1) {\n      x_dims[ndim - 1] /= pack_size;\n      x_pack_size = pack_size;\n    }\n    if (y_dims[ndim - 1] != 1) {\n      y_dims[ndim - 1] /= pack_size;\n      y_pack_size = pack_size;\n    }\n    z_dims[ndim - 1] /= pack_size;\n  }\n\n#define IF(cp, xp, yp)                                                                       \\\n  if (cond_pack_size == cp && x_pack_size == xp && y_pack_size == yp) {                      \\\n    LaunchByDispatchIndexType<T, CondT, ndim, cp, xp, yp>(stream, cond_dims, x_dims, y_dims, \\\n                                                          z_dims, cond, x, y, z);            \\\n  }\n#define ELIF(cp, xp, yp) else IF(cp, xp, yp)\n#define ELSE         \\\n  else {             \\\n    UNIMPLEMENTED(); \\\n  }\n\n  if (pack_size == 1) {\n    IF(1, 1, 1)\n    ELSE\n  } else if (pack_size == 2) {\n    IF(2, 2, 2)\n    ELIF(1, 2, 2)\n    ELIF(1, 2, 1)\n    ELIF(1, 1, 2)\n    ELIF(2, 1, 2)\n    ELIF(2, 1, 1)\n    ELIF(2, 2, 1)\n    ELSE\n  } else if (pack_size == 4) {\n    IF(4, 4, 4)\n    ELIF(1, 4, 4)\n    ELIF(1, 4, 1)\n    ELIF(1, 1, 4)\n    ELIF(4, 1, 4)\n    ELIF(4, 1, 1)\n    ELIF(4, 4, 1)\n    ELSE\n  }\n  ELSE\n\n#undef IF\n#undef ELIF\n#undef ELSE\n}\n\ntemplate<typename T, typename CondT>\nvoid LaunchByDispatchNDim(Stream* stream, size_t ndim, int64_t* cond_dims, int64_t* x_dims,\n                          int64_t* y_dims, int64_t* z_dims, const CondT* cond, const T* x,\n                          const T* y, T* z) {\n#define ELIF(n)                                                                                  \\\n  else if (ndim == n) {                                                                          \\\n    LaunchByDispatchPackSize<T, CondT, n>(stream, cond_dims, x_dims, y_dims, z_dims, cond, x, y, \\\n                                          z);                                                    \\\n  }\n#define ELSE         \\\n  else {             \\\n    UNIMPLEMENTED(); \\\n  }\n\n  if (ndim == 0) { LaunchScalarKernel<T, CondT>(stream, cond, x, y, z); }\n  ELIF(1)\n  ELIF(2)\n  ELIF(3)\n  ELIF(4)\n  ELSE\n\n#undef IF\n#undef ELIF\n#undef ELSE\n}\n\ntemplate<template<typename, typename> class Prim>\nstd::unique_ptr<Where> NewWhere(DataType cond_type, DataType data_type, size_t max_num_dims) {\n  if (max_num_dims > kMaxNumDims) { return nullptr; }\n\n  const size_t data_type_size = GetSizeOfDataType(data_type);\n\n#define IF(ctype, dtype_size)                                              \\\n  if (cond_type == ctype && data_type_size == dtype_size) {                \\\n    using T = typename std::aligned_storage<dtype_size, dtype_size>::type; \\\n    
using CondT = DataTypeToType<ctype>;                                   \\\n    return std::unique_ptr<Where>(new Prim<T, CondT>());                   \\\n  }\n#define ELIF(ctype, dtype_size) else IF(ctype, dtype_size)\n#define ELSE        \\\n  else {            \\\n    return nullptr; \\\n  }\n\n  IF(DataType::kBool, 1)\n  ELIF(DataType::kBool, 2)\n  ELIF(DataType::kBool, 4)\n  ELIF(DataType::kBool, 8)\n  ELIF(DataType::kInt32, 1)\n  ELIF(DataType::kInt32, 2)\n  ELIF(DataType::kInt32, 4)\n  ELIF(DataType::kInt32, 8)\n  ELIF(DataType::kInt64, 1)\n  ELIF(DataType::kInt64, 2)\n  ELIF(DataType::kInt64, 4)\n  ELIF(DataType::kInt64, 8)\n  ELSE\n\n#undef IF\n#undef ELIF\n#undef ELSE\n}\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_COMMON_PRIMITIVE_WHERE_H_\n"
  },
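  {
    "path": "examples/ep_sketches/where_pack_size_sketch.cpp",
    "content": "/*\nIllustrative sketch, not part of the OneFlow source tree: the pack-size probe\nbehind GetPackSize in where.h above, built on the IsPackSizeSupported helper\nfrom util.h. A pack of pack_size elements is only usable when the innermost\ndimension divides evenly by it and the pointer is aligned to the packed\nelement size, so vectorized loads and stores stay in bounds and aligned.\n*/\n#include <cstddef>\n#include <cstdint>\n#include <iostream>\n\ntemplate<typename T>\nbool IsPackSizeSupported(size_t pack_size, size_t num_dims, const int64_t* dims,\n                         const void* ptr) {\n  return (dims[num_dims - 1] % pack_size == 0)\n         && (reinterpret_cast<std::uintptr_t>(ptr) % (pack_size * sizeof(T)) == 0);\n}\n\nint main() {\n  alignas(16) float buf[64] = {};\n  const int64_t dims[2] = {4, 16};\n  // Try the largest power-of-two pack first and halve until one fits; here a\n  // pack of 4 floats (16 bytes) is accepted immediately.\n  size_t chosen = 1;\n  for (size_t pack = 4; pack >= 2; pack /= 2) {\n    if (IsPackSizeSupported<float>(pack, 2, dims, buf)) {\n      chosen = pack;\n      break;\n    }\n  }\n  std::cout << \"pack_size = \" << chosen << std::endl;  // prints pack_size = 4\n  return 0;\n}\n"
  },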
  {
    "path": "oneflow/core/ep/cpu/cpu_device.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/mem_util.h\"\n#include \"oneflow/core/ep/cpu/cpu_device.h\"\n#include \"oneflow/core/ep/cpu/cpu_event.h\"\n#include \"oneflow/core/ep/cpu/cpu_stream.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nvoid CpuDevice::SetAsActiveDevice() {}\n\nStream* CpuDevice::CreateStream() { return new CpuStream(this); }\n\nvoid CpuDevice::DestroyStream(Stream* stream) { delete stream; }\n\nvoid CpuDevice::CreateEvents(Event** events, size_t count) {\n  for (size_t i = 0; i < count; ++i) { events[i] = new CpuEvent(); }\n}\n\nvoid CpuDevice::DestroyEvents(Event** events, size_t count) {\n  for (size_t i = 0; i < count; ++i) { delete events[i]; }\n}\n\nMaybe<void> CpuDevice::Alloc(const AllocationOptions& options, void** ptr, size_t size) {\n  if (options.HasPinnedDevice()) {\n    auto device =\n        this->device_manager()->registry()->GetDevice(options.GetPinnedDeviceType(),    // NOLINT\n                                                      options.GetPinnedDeviceIndex());  // NOLINT\n    CHECK_OR_RETURN(device);\n    JUST(device->AllocPinned(options, ptr, size));\n  } else {\n    *ptr = aligned_alloc(kMaxAlignmentRequirement, RoundUp(size, kMaxAlignmentRequirement));\n    if (*ptr == nullptr) {\n      return Error::RuntimeError()\n             << \"CPU can't allocate memory. Tried to allocate \" << FormatMemSize(size);\n    }\n  }\n  memset(*ptr, 0, size);\n  return Maybe<void>::Ok();\n}\n\nvoid CpuDevice::Free(const AllocationOptions& options, void* ptr) {\n  if (options.HasPinnedDevice()) {\n    auto device =\n        this->device_manager()->registry()->GetDevice(options.GetPinnedDeviceType(),    // NOLINT\n                                                      options.GetPinnedDeviceIndex());  // NOLINT\n    CHECK(device);\n    return device->FreePinned(options, ptr);\n  } else {\n    free(ptr);  // NOLINT\n  }\n}\n\nMaybe<void> CpuDevice::AllocPinned(const AllocationOptions& options, void** ptr, size_t size) {\n  AllocationOptions new_options = options;\n  new_options.ClearPinnedDevice();\n  return Alloc(new_options, ptr, size);\n}\n\nvoid CpuDevice::FreePinned(const AllocationOptions& options, void* ptr) {\n  AllocationOptions new_options = options;\n  new_options.ClearPinnedDevice();\n  return Free(new_options, ptr);\n}\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
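  {
    "path": "examples/ep_sketches/aligned_alloc_sketch.cpp",
    "content": "/*\nIllustrative sketch, not part of the OneFlow source tree: why CpuDevice::Alloc\nabove rounds the requested size up before allocating. C++17's\nstd::aligned_alloc (POSIX aligned_alloc on older toolchains) requires the size\nto be a multiple of the alignment, so a 100-byte request at 512-byte alignment\nmust be padded to 512. The alignment constant below is a stand-in for\nkMaxAlignmentRequirement.\n*/\n#include <cstddef>\n#include <cstdlib>\n#include <iostream>\n\nsize_t RoundUp(size_t n, size_t align) { return (n + align - 1) / align * align; }\n\nint main() {\n  const size_t alignment = 512;\n  const size_t requested = 100;\n  void* ptr = std::aligned_alloc(alignment, RoundUp(requested, alignment));\n  if (ptr == nullptr) {\n    std::cerr << \"allocation failed\" << std::endl;\n    return 1;\n  }\n  std::cout << \"allocated \" << RoundUp(requested, alignment) << \" bytes\" << std::endl;\n  std::free(ptr);\n  return 0;\n}\n"
  },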
  {
    "path": "oneflow/core/ep/cpu/cpu_device.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_CPU_CPU_DEVICE_H_\n#define ONEFLOW_CORE_EP_CPU_CPU_DEVICE_H_\n\n#include \"oneflow/core/ep/include/device.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass CpuDevice : public Device {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuDevice);\n  explicit CpuDevice(DeviceManager* device_manager)\n      : device_manager_(device_manager), num_threads_(1) {}\n  ~CpuDevice() override = default;\n\n  void SetAsActiveDevice() override;\n  void Reset() override {}\n  void SetNumThreads(size_t num_threads) { num_threads_ = num_threads; }\n  size_t GetNumThreads() { return num_threads_; }\n\n  DeviceType device_type() const override { return DeviceType::kCPU; }\n  size_t device_index() const override { return 0; }\n  DeviceManager* device_manager() const override { return device_manager_; }\n\n  Stream* CreateStream() override;\n  void DestroyStream(Stream* stream) override;\n\n  void CreateEvents(Event** events, size_t count) override;\n  void DestroyEvents(Event** events, size_t count) override;\n\n  Maybe<void> Alloc(const AllocationOptions& options, void** ptr, size_t size) override;\n  void Free(const AllocationOptions& options, void* ptr) override;\n  Maybe<void> AllocPinned(const AllocationOptions& options, void** ptr, size_t size) override;\n  void FreePinned(const AllocationOptions& options, void* ptr) override;\n\n private:\n  DeviceManager* device_manager_;\n  size_t num_threads_;\n};\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_CPU_CPU_DEVICE_H_\n"
  },
  {
    "path": "oneflow/core/ep/cpu/cpu_device_manager.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cpu/cpu_device_manager.h\"\n#include \"oneflow/core/ep/cpu/cpu_device.h\"\n#include \"oneflow/core/ep/cpu/cpu_random_generator.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nCpuDeviceManager::CpuDeviceManager(DeviceManagerRegistry* registry)\n    : device_num_threads_(1), registry_(registry) {}\n\nCpuDeviceManager::~CpuDeviceManager() = default;\n\nDeviceManagerRegistry* CpuDeviceManager::registry() const { return registry_; }\n\nstd::shared_ptr<Device> CpuDeviceManager::GetDevice(size_t device_index) {\n  std::lock_guard<std::mutex> lock(device_mutex_);\n  if (!device_) { device_.reset(new CpuDevice(this)); }\n  device_->SetNumThreads(device_num_threads_);\n  return device_;\n}\n\nsize_t CpuDeviceManager::GetDeviceCount(size_t /*primary_device_index*/) { return 1; }\n\nsize_t CpuDeviceManager::GetDeviceCount() { return 1; }\n\nsize_t CpuDeviceManager::GetActiveDeviceIndex() { return 0; }\n\nvoid CpuDeviceManager::SetActiveDeviceByIndex(size_t device_index) {}\n\nvoid CpuDeviceManager::SetDeviceNumThreads(size_t num_threads) {\n  device_num_threads_ = num_threads;\n}\n\nstd::shared_ptr<RandomGenerator> CpuDeviceManager::CreateRandomGenerator(uint64_t seed,\n                                                                         size_t device_index) {\n  return std::make_shared<CPUGenerator>(seed, device_index);\n}\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
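  {
    "path": "examples/ep_sketches/lazy_device_sketch.cpp",
    "content": "/*\nIllustrative sketch, not part of the OneFlow source tree: the mutex-guarded\nlazy construction used by CpuDeviceManager::GetDevice above. The single device\nobject is created on first request and the same shared_ptr is handed out on\nevery later call; the lock makes concurrent first calls safe.\n*/\n#include <iostream>\n#include <memory>\n#include <mutex>\n\nstruct Device {};\n\nclass Manager {\n public:\n  std::shared_ptr<Device> GetDevice() {\n    std::lock_guard<std::mutex> lock(mutex_);\n    if (!device_) { device_ = std::make_shared<Device>(); }  // created exactly once\n    return device_;\n  }\n\n private:\n  std::mutex mutex_;\n  std::shared_ptr<Device> device_;\n};\n\nint main() {\n  Manager m;\n  auto a = m.GetDevice();\n  auto b = m.GetDevice();\n  std::cout << (a == b ? \"same instance\" : \"different instances\") << std::endl;\n  return 0;\n}\n"
  },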
  {
    "path": "oneflow/core/ep/cpu/cpu_device_manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_CPU_CPU_DEVICE_MANAGER_H_\n#define ONEFLOW_CORE_EP_CPU_CPU_DEVICE_MANAGER_H_\n\n#include \"oneflow/core/ep/include/device_manager.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass CpuDevice;\n\nclass CpuDeviceManager : public DeviceManager {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuDeviceManager);\n  explicit CpuDeviceManager(DeviceManagerRegistry* registry);\n  ~CpuDeviceManager() override;\n\n  DeviceManagerRegistry* registry() const override;\n  std::shared_ptr<Device> GetDevice(size_t device_index) override;\n  size_t GetDeviceCount(size_t primary_device_index) override;\n  size_t GetDeviceCount() override;\n  size_t GetActiveDeviceIndex() override;\n  void SetActiveDeviceByIndex(size_t device_index) override;\n  void SetDeviceNumThreads(size_t num_threads);\n\n  std::shared_ptr<RandomGenerator> CreateRandomGenerator(uint64_t seed,\n                                                         size_t device_index) override;\n\n private:\n  size_t device_num_threads_;\n  std::mutex device_mutex_;\n  std::shared_ptr<CpuDevice> device_;\n  DeviceManagerRegistry* registry_;\n};\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_CPU_CPU_DEVICE_MANAGER_H_\n"
  },
  {
    "path": "oneflow/core/ep/cpu/cpu_device_manager_factory.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/device_manager_factory.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/ep/cpu/cpu_device_manager.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace {\n\nclass CpuDeviceManagerFactory : public DeviceManagerFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuDeviceManagerFactory);\n  CpuDeviceManagerFactory() = default;\n  ~CpuDeviceManagerFactory() override = default;\n\n  std::unique_ptr<DeviceManager> NewDeviceManager(DeviceManagerRegistry* registry) override {\n    return std::make_unique<CpuDeviceManager>(registry);\n  }\n\n  DeviceType device_type() const override { return DeviceType::kCPU; }\n\n  std::string device_type_name() const override { return \"cpu\"; }\n};\n\nCOMMAND(DeviceManagerRegistry::RegisterDeviceManagerFactory(\n    std::make_unique<CpuDeviceManagerFactory>()))\n\n}  // namespace\n\nnamespace {\n\nclass MockDeviceManagerFactory : public DeviceManagerFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MockDeviceManagerFactory);\n  MockDeviceManagerFactory() = default;\n  ~MockDeviceManagerFactory() override = default;\n\n  std::unique_ptr<DeviceManager> NewDeviceManager(DeviceManagerRegistry* registry) override {\n    return std::make_unique<CpuDeviceManager>(registry);\n  }\n\n  DeviceType device_type() const override { return DeviceType::kMockDevice; }\n\n  std::string device_type_name() const override { return \"mock\"; }\n};\n\nCOMMAND(DeviceManagerRegistry::RegisterDeviceManagerFactory(\n    std::make_unique<MockDeviceManagerFactory>()))\n\n}  // namespace\n\nnamespace {\n\nclass MetaDeviceManagerFactory : public DeviceManagerFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MetaDeviceManagerFactory);\n  MetaDeviceManagerFactory() = default;\n  ~MetaDeviceManagerFactory() override = default;\n\n  std::unique_ptr<DeviceManager> NewDeviceManager(DeviceManagerRegistry* registry) override {\n    return std::make_unique<CpuDeviceManager>(registry);\n  }\n\n  DeviceType device_type() const override { return DeviceType::kMeta; }\n\n  std::string device_type_name() const override { return \"meta\"; }\n};\n\nCOMMAND(DeviceManagerRegistry::RegisterDeviceManagerFactory(\n    std::make_unique<MetaDeviceManagerFactory>()))\n\n}  // namespace\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
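  {
    "path": "examples/ep_sketches/factory_registry_sketch.cpp",
    "content": "/*\nIllustrative sketch, not part of the OneFlow source tree: the self-registration\npattern behind the COMMAND(...) calls above, which run their argument during\nstatic initialization. Each factory inserts itself into a global map keyed by\ndevice type, and later lookups go through the map. The function-local static\nmap avoids the static initialization order problem.\n*/\n#include <functional>\n#include <iostream>\n#include <map>\n#include <memory>\n#include <string>\n\nstruct DeviceManager {\n  virtual ~DeviceManager() = default;\n  virtual std::string name() const = 0;\n};\n\nusing Factory = std::function<std::unique_ptr<DeviceManager>()>;\n\nstd::map<std::string, Factory>& Registry() {\n  static std::map<std::string, Factory> registry;\n  return registry;\n}\n\nbool Register(const std::string& type, Factory f) {\n  Registry()[type] = std::move(f);\n  return true;\n}\n\nstruct CpuManager : DeviceManager {\n  std::string name() const override { return \"cpu\"; }\n};\n\n// Imitates COMMAND(...): the registration runs before main() is entered.\nconst bool cpu_registered =\n    Register(\"cpu\", [] { return std::unique_ptr<DeviceManager>(new CpuManager()); });\n\nint main() {\n  auto it = Registry().find(\"cpu\");\n  if (it != Registry().end()) { std::cout << it->second()->name() << std::endl; }\n  return 0;\n}\n"
  },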
  {
    "path": "oneflow/core/ep/cpu/cpu_event.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cpu/cpu_event.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nMaybe<bool> CpuEvent::QueryDone() { return Maybe<bool>(true); }\n\nMaybe<void> CpuEvent::Sync() { return Maybe<void>::Ok(); }\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cpu/cpu_event.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_CPU_CPU_EVENT_H_\n#define ONEFLOW_CORE_EP_CPU_CPU_EVENT_H_\n\n#include \"oneflow/core/ep/include/event.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass CpuEvent : public Event {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuEvent);\n  CpuEvent() = default;\n  ~CpuEvent() override = default;\n\n  Maybe<bool> QueryDone() override;\n  Maybe<void> Sync() override;\n};\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_CPU_CPU_EVENT_H_\n"
  },
  {
    "path": "oneflow/core/ep/cpu/cpu_random_generator.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cpu/cpu_random_generator.h\"\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/str_util.h\"\n\nnamespace oneflow {\nnamespace ep {\n\nstruct CPUGeneratorState {\n  static constexpr int64_t state_size = std::mt19937::state_size;  // 624\n  int64_t states[state_size] = {};\n  int64_t seed = 0;\n};\nconstexpr int64_t CPUGeneratorState::state_size;\n\nvoid CPUGenerator::set_current_seed(uint64_t seed) {\n  seed_ = seed;\n  engine_.seed(seed_);\n  torch_engine_ = pytorch_mt19937_engine(seed);\n}\n\nsize_t CPUGenerator::GetStateSize() const { return sizeof(CPUGeneratorState); }\n\nvoid CPUGenerator::GetState(size_t state_size, void* state) const {\n  CHECK_EQ_OR_THROW(state_size, GetStateSize())\n      << \"state size of cpu generator should be equal to \" << GetStateSize();\n  CPUGeneratorState local_state;\n  std::stringstream ss;\n  ss << engine_;\n  std::vector<std::string> splits;\n  Split(ss.str(), \" \", [&](std::string&& s) { splits.emplace_back(s); });\n  // The last element in `splits` indicates state size, not state.\n  if (splits.size() != CPUGeneratorState::state_size + 1) {\n    return THROW(RuntimeError) << \"std::mt19937 state size should be \"\n                               << CPUGeneratorState::state_size << \", but got \"\n                               << splits.size() - 1;\n  }\n  for (int i = 0; i < CPUGeneratorState::state_size; ++i) {\n    local_state.states[i] = std::atoll(splits[i].data());\n  }\n  local_state.seed = current_seed();\n  memcpy(state, &local_state, sizeof(CPUGeneratorState));\n}\n\nvoid CPUGenerator::SetState(size_t state_size, const void* state) {\n  CHECK_EQ_OR_THROW(state_size, GetStateSize())\n      << \"state size of cpu generator should be equal to \" << GetStateSize();\n  const CPUGeneratorState* local_state = static_cast<const CPUGeneratorState*>(state);\n  seed_ = local_state->seed;\n  std::stringstream ss;\n  for (int i = 0; i < CPUGeneratorState::state_size; ++i) { ss << local_state->states[i] << \" \"; }\n  ss << CPUGeneratorState::state_size;\n  ss >> engine_;\n}\n\ntemplate<>\nstd::string GetRandomGeneratorDeviceTypeName<CPUGenerator>() {\n  return \"cpu\";\n}\n\n}  // namespace ep\n}  // namespace oneflow\n"
  },
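  {
    "path": "examples/ep_sketches/mt19937_state_roundtrip_sketch.cpp",
    "content": "/*\nIllustrative sketch, not part of the OneFlow source tree: the serialization\ntrick used by CPUGenerator::GetState/SetState above. std::mt19937 defines\nstream insertion and extraction for its internal state; on common standard\nlibraries this is the 624 state words followed by one position counter, which\nis the extra trailing token the code above skips. The state therefore\nround-trips through plain text.\n*/\n#include <iostream>\n#include <random>\n#include <sstream>\n\nint main() {\n  std::mt19937 a(42);\n  a();  // advance so the state is non-trivial\n  a();\n\n  std::stringstream ss;\n  ss << a;  // serialize the engine state as space-separated integers\n\n  std::mt19937 b;\n  ss >> b;  // restore into a fresh engine\n\n  // Both engines now continue from the same point in the sequence.\n  std::cout << (a() == b() ? \"match\" : \"mismatch\") << std::endl;\n  return 0;\n}\n"
  },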
  {
    "path": "oneflow/core/ep/cpu/cpu_random_generator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_CPU_RANDOM_GENERATOR_H_\n#define ONEFLOW_CORE_EP_CPU_RANDOM_GENERATOR_H_\n\n#include <array>\n#include <cmath>\n#include <math.h>\n#include <random>\n#include <mutex>\n\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/ep/include/random_generator.h\"\n\nnamespace oneflow {\nnamespace ep {\n\n// NOTE(Liang Depeng): The following implementation of mt19937 is modified from\n//                     https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/MT19937RNGEngine.h\n//                     in order to make distribution related cpu kernels to have the same output as\n//                     pytorch when setting the same seed.\nconstexpr int MERSENNE_STATE_N = 624;\nconstexpr int MERSENNE_STATE_M = 397;\nconstexpr uint32_t MATRIX_A = 0x9908b0df;\nconstexpr uint32_t UMASK = 0x80000000;\nconstexpr uint32_t LMASK = 0x7fffffff;\n\nstruct pytorch_mt19937_data_pod {\n  uint64_t seed_;\n  int left_;\n  bool seeded_;\n  uint32_t next_;\n  std::array<uint32_t, MERSENNE_STATE_N> state_;\n};\n\nclass pytorch_mt19937_engine {\n public:\n  inline explicit pytorch_mt19937_engine(uint64_t seed = 5489) { init_with_uint32(seed); }\n\n  inline pytorch_mt19937_data_pod data() const { return data_; }\n\n  inline void set_data(pytorch_mt19937_data_pod data) { data_ = data; }\n\n  inline uint64_t seed() const { return data_.seed_; }\n\n  inline bool is_valid() {\n    if ((data_.seeded_ == true) && (data_.left_ > 0 && data_.left_ <= MERSENNE_STATE_N)\n        && (data_.next_ <= MERSENNE_STATE_N)) {\n      return true;\n    }\n    return false;\n  }\n\n  inline uint32_t operator()() {\n    uint32_t y;\n\n    if (--(data_.left_) == 0) { next_state(); }\n    y = *(data_.state_.data() + data_.next_++);\n    y ^= (y >> 11);\n    y ^= (y << 7) & 0x9d2c5680;\n    y ^= (y << 15) & 0xefc60000;\n    y ^= (y >> 18);\n\n    return y;\n  }\n\n private:\n  pytorch_mt19937_data_pod data_;\n\n  inline void init_with_uint32(uint64_t seed) {\n    data_.seed_ = seed;\n    data_.seeded_ = true;\n    data_.state_[0] = seed & 0xffffffff;\n    for (int j = 1; j < MERSENNE_STATE_N; ++j) {\n      data_.state_[j] = (1812433253 * (data_.state_[j - 1] ^ (data_.state_[j - 1] >> 30)) + j);\n    }\n    data_.left_ = 1;\n    data_.next_ = 0;\n  }\n\n  inline uint32_t mix_bits(uint32_t u, uint32_t v) { return (u & UMASK) | (v & LMASK); }\n\n  inline uint32_t twist(uint32_t u, uint32_t v) {\n    return (mix_bits(u, v) >> 1) ^ (v & 1 ? 
MATRIX_A : 0);\n  }\n\n  inline void next_state() {\n    uint32_t* p = data_.state_.data();\n    data_.left_ = MERSENNE_STATE_N;\n    data_.next_ = 0;\n\n    for (int j = MERSENNE_STATE_N - MERSENNE_STATE_M + 1; --j; p++) {\n      *p = p[MERSENNE_STATE_M] ^ twist(p[0], p[1]);\n    }\n\n    for (int j = MERSENNE_STATE_M; --j; p++) {\n      *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], p[1]);\n    }\n\n    *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], data_.state_[0]);\n  }\n};\n\nclass CPUGenerator : public RandomGenerator {\n public:\n  explicit CPUGenerator(uint64_t seed, int device_index)\n      : RandomGenerator(), seed_(seed), engine_(seed), torch_engine_(seed) {}\n\n  virtual ~CPUGenerator() = default;\n\n  uint64_t current_seed() const override { return seed_; }\n  void set_current_seed(uint64_t seed) override;\n\n  std::mt19937& engine() { return engine_; }\n\n  pytorch_mt19937_engine& torch_engine() { return torch_engine_; }\n\n  std::string device_type_name() const override { return \"cpu\"; }\n  int64_t device_index() const override { return 0; }\n\n  size_t GetStateSize() const override;\n  void GetState(size_t state_size, void* state) const override;\n  void SetState(size_t state_size, const void* state) override;\n\n public:\n  mutable std::mutex mutex_;\n  uint64_t seed_;\n  std::mt19937 engine_;\n  // TODO(Liang Depeng): needed to implement the get_state/set_state of pytorch_mt_19937_engine\n  //                     refer to\n  //                     https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/CPUGenerator.cpp#L206\n  pytorch_mt19937_engine torch_engine_;\n};\n\n}  // namespace ep\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_CPU_RANDOM_GENERATOR_H_\n"
  },
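  {
    "path": "examples/ep_sketches/mt19937_seeding_sketch.cpp",
    "content": "/*\nIllustrative sketch, not part of the OneFlow source tree: the seeding\nrecurrence shared by pytorch_mt19937_engine::init_with_uint32 above and\nstd::mt19937. Both fill the 624-word state from the seed with the Knuth\nmultiplier 1812433253, which is why the two engines can agree for seeds that\nfit in 32 bits. The comparison against the serialized std::mt19937 state\nassumes the usual textual layout (state words first), so treat it as a sanity\ncheck rather than a guarantee.\n*/\n#include <cstdint>\n#include <iostream>\n#include <random>\n#include <sstream>\n\nint main() {\n  const uint32_t seed = 5489;  // std::mt19937's default seed\n  uint32_t state[624];\n  state[0] = seed;\n  for (uint32_t j = 1; j < 624; ++j) {\n    state[j] = 1812433253u * (state[j - 1] ^ (state[j - 1] >> 30)) + j;\n  }\n\n  std::mt19937 engine(seed);\n  std::stringstream ss;\n  ss << engine;\n\n  bool match = true;\n  for (int j = 0; j < 8; ++j) {\n    uint64_t word = 0;\n    ss >> word;\n    if (word != state[j]) { match = false; }\n  }\n  std::cout << (match ? \"first state words match\" : \"layouts differ\") << std::endl;\n  return 0;\n}\n"
  },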
  {
    "path": "oneflow/core/ep/cpu/cpu_stream.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cpu/cpu_stream.h\"\n#include \"oneflow/core/thread/thread_runtime_factory.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nDeviceType CpuStream::device_type() const { return DeviceType::kCPU; }\n\nCpuDevice* CpuStream::device() const { return device_; }\n\nMaybe<void> CpuStream::Sync() { return Maybe<void>::Ok(); }\n\nvoid CpuStream::RecordEvent(Event* /*event*/) {}\n\nMaybe<void> CpuStream::InitThreadRuntime() {\n  const auto thread_runtime_type = GetStringFromEnv(\"OF_THREADING_RUNTIME\", [] {\n    if (thread::IsTbbEnabled()) { return \"TBB\"; }\n    if (thread::IsOmpEnabled()) { return \"OMP\"; }\n    return \"SEQ\";\n  }());\n  thread_runtime_ = JUST(thread::RuntimeFactory::Create(thread_runtime_type));\n  return Maybe<void>::Ok();\n}\n\n#ifdef WITH_ONEDNN\n\nconst std::unique_ptr<ep::OneDnnExecutor>& CpuStream::onednn_executor() const {\n  return onednn_executor_;\n}\n\n#endif\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
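  {
    "path": "examples/ep_sketches/env_runtime_selection_sketch.cpp",
    "content": "/*\nIllustrative sketch, not part of the OneFlow source tree: the environment\noverride pattern used by CpuStream::InitThreadRuntime above. The runtime name\nis taken from OF_THREADING_RUNTIME when the variable is set; otherwise a\ncompiled-in default is used. In the real code the fallback is computed from\nwhich runtimes were built in (TBB, then OMP, then sequential); here it is a\nfixed string, and GetStringFromEnv is a local stand-in for the OneFlow helper.\n*/\n#include <cstdlib>\n#include <iostream>\n#include <string>\n\nstd::string GetStringFromEnv(const std::string& name, const std::string& fallback) {\n  const char* value = std::getenv(name.c_str());\n  return value != nullptr ? std::string(value) : fallback;\n}\n\nint main() {\n  const std::string runtime = GetStringFromEnv(\"OF_THREADING_RUNTIME\", \"SEQ\");\n  std::cout << \"threading runtime: \" << runtime << std::endl;\n  return 0;\n}\n"
  },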
  {
    "path": "oneflow/core/ep/cpu/cpu_stream.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_CPU_CPU_STREAM_H_\n#define ONEFLOW_CORE_EP_CPU_CPU_STREAM_H_\n\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/ep/cpu/cpu_device.h\"\n#include \"oneflow/core/thread/thread_runtime_factory.h\"\n\n#ifdef WITH_ONEDNN\n#include <oneapi/dnnl/dnnl.hpp>\n#endif\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass CpuNumThreadsGuard {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuNumThreadsGuard);\n#if WITH_TBB\n  explicit CpuNumThreadsGuard(size_t num_threads)\n      : global_thread_limit(tbb::global_control::max_allowed_parallelism, num_threads) {}\n  ~CpuNumThreadsGuard() {}\n#elif WITH_OMP\n  explicit CpuNumThreadsGuard(size_t num_threads) : set_num_threads_(num_threads) {\n    saved_num_threads_ = omp_get_max_threads();\n    omp_set_num_threads(set_num_threads_);\n  }\n  ~CpuNumThreadsGuard() { omp_set_num_threads(saved_num_threads_); }\n#endif\n\n private:\n#if WITH_TBB\n  tbb::global_control global_thread_limit;\n#elif WITH_OMP\n  size_t set_num_threads_;\n  size_t saved_num_threads_;\n#endif\n};\n\n#ifdef WITH_ONEDNN\n\nclass OneDnnExecutor;\n\n#endif\n\nclass CpuStream : public Stream {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuStream);\n\n  explicit CpuStream(CpuDevice* device) : device_(device) {\n    CHECK_JUST(InitThreadRuntime());\n#ifdef WITH_ONEDNN\n    onednn_executor_ = std::make_unique<ep::OneDnnExecutor>(this);\n#endif\n  }\n\n  ~CpuStream() override = default;\n\n  DeviceType device_type() const override;\n  CpuDevice* device() const override;\n  Maybe<void> Sync() override;\n  void RecordEvent(Event* event) override;\n\n  template<typename F>\n  void ParallelFor(int64_t begin, int64_t end, const F& func) {\n    ParallelFor(begin, end, func, kParallelForDefaultGrain);\n  }\n\n  template<typename F>\n  void ParallelFor(int64_t begin, int64_t end, const F& func, size_t grain_size) {\n    thread_runtime_->ParallelFor(begin, end, func, device()->GetNumThreads(), grain_size);\n  }\n\n#ifdef WITH_ONEDNN\n  const std::unique_ptr<ep::OneDnnExecutor>& onednn_executor() const;\n#endif\n\n private:\n  CpuDevice* device_;\n  static constexpr size_t kParallelForDefaultGrain = 32768;\n  std::shared_ptr<thread::RuntimeBase> thread_runtime_;\n\n  Maybe<void> InitThreadRuntime();\n\n#ifdef WITH_ONEDNN\n  std::unique_ptr<ep::OneDnnExecutor> onednn_executor_;\n#endif\n};\n\n#ifdef WITH_ONEDNN\n\nclass OneDnnExecutor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(OneDnnExecutor);\n\n  OneDnnExecutor() = delete;\n\n  explicit OneDnnExecutor(CpuStream* cpu_stream) : cpu_stream_(cpu_stream) {\n    engine_.reset(new dnnl::engine(dnnl::engine::kind::cpu, 0));\n    stream_.reset(new dnnl::stream(*engine_));\n  }\n\n  ~OneDnnExecutor() = default;\n\n  template<typename F>\n  void Launch(const F& f) {\n    CpuNumThreadsGuard guard(cpu_stream_->device()->GetNumThreads());\n    f(engine_.get(), stream_.get());\n    stream_->wait();\n  }\n\n private:\n  
CpuStream* cpu_stream_ = nullptr;\n  std::unique_ptr<dnnl::engine> engine_;\n  std::unique_ptr<dnnl::stream> stream_;\n};\n\n#endif\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_CPU_CPU_STREAM_H_\n"
  },
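  {
    "path": "examples/ep_sketches/parallel_for_sketch.cpp",
    "content": "/*\nIllustrative sketch, not part of the OneFlow source tree: the chunking that\nCpuStream::ParallelFor above delegates to its thread runtime. The range\n[begin, end) is split into chunks of at least grain_size iterations (compare\nkParallelForDefaultGrain = 32768) so tiny ranges are not oversplit, and each\nchunk is handed to a worker. This toy version uses std::thread directly\ninstead of the pluggable RuntimeBase.\n*/\n#include <algorithm>\n#include <cstdint>\n#include <iostream>\n#include <thread>\n#include <vector>\n\ntemplate<typename F>\nvoid ParallelFor(int64_t begin, int64_t end, const F& func, size_t num_threads,\n                 size_t grain_size) {\n  const int64_t total = end - begin;\n  if (total <= 0) { return; }\n  const int64_t nt = static_cast<int64_t>(num_threads);\n  // No chunk smaller than grain_size, and no more chunks than threads.\n  const int64_t chunk = std::max<int64_t>(static_cast<int64_t>(grain_size), (total + nt - 1) / nt);\n  std::vector<std::thread> workers;\n  for (int64_t s = begin; s < end; s += chunk) {\n    const int64_t e = std::min(end, s + chunk);\n    workers.emplace_back([=] { func(s, e); });  // disjoint ranges, no data race\n  }\n  for (auto& t : workers) { t.join(); }\n}\n\nint main() {\n  std::vector<int64_t> data(100000, 1);\n  ParallelFor(\n      0, static_cast<int64_t>(data.size()),\n      [&](int64_t b, int64_t e) {\n        for (int64_t i = b; i < e; ++i) { data[i] += 1; }\n      },\n      4, 32768);\n  std::cout << \"data[0] = \" << data[0] << \", data.back() = \" << data.back() << std::endl;\n  return 0;\n}\n"
  },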
  {
    "path": "oneflow/core/ep/cpu/primitive/add.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/add.h\"\n#include \"oneflow/core/ep/cpu/primitive/type_seq.h\"\n#include \"oneflow/core/ep/cpu/cpu_stream.h\"\n#include \"oneflow/core/ep/common/primitive/util.h\"\n#include \"oneflow/core/ep/common/onednn.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\ntemplate<typename T, size_t arity>\nvoid AddCpu(const T* const* srcs, T* dst, size_t count) {\n  for (size_t i = 0; i < count; ++i) {\n    T sum = T(0);\n    for (size_t a = 0; a < arity; ++a) { sum += srcs[a][i]; }\n    dst[i] = sum;\n  }\n}\n\ntemplate<typename T>\nvoid AddCpu(const T* const* srcs, size_t arity, T* dst, size_t count) {\n  for (size_t i = 0; i < count; ++i) {\n    T sum = T(0);\n    for (size_t a = 0; a < arity; ++a) { sum += srcs[a][i]; }\n    dst[i] = sum;\n  }\n}\n\ntemplate<typename T>\nclass AddDefaultImpl : public Add {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AddDefaultImpl);\n  AddDefaultImpl() = default;\n  ~AddDefaultImpl() override = default;\n\n  using Add::Launch;\n  void Launch(Stream* stream, const void* const* srcs, size_t arity, void* dst,\n              size_t count) override {\n#define ONE_IF(a)                                                                            \\\n  if (arity == a) {                                                                          \\\n    AddCpu<T, a>(reinterpret_cast<const T* const*>(srcs), reinterpret_cast<T*>(dst), count); \\\n  }\n#define ONE_ELIF(a) else ONE_IF(a)\n#define ONE_ELSE                                                                                 \\\n  else {                                                                                         \\\n    AddCpu<T>(reinterpret_cast<const T* const*>(srcs), arity, reinterpret_cast<T*>(dst), count); \\\n  }\n    ONE_IF(0)\n    ONE_ELIF(1)\n    ONE_ELIF(2)\n    ONE_ELIF(3)\n    ONE_ELIF(4)\n    ONE_ELIF(5)\n    ONE_ELIF(6)\n    ONE_ELIF(7)\n    ONE_ELIF(8)\n    ONE_ELSE\n  }\n};\n\n#ifdef WITH_ONEDNN\nclass AddOneDnnImpl : public Add {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AddOneDnnImpl);\n  explicit AddOneDnnImpl(dnnl::memory::data_type type) : type_onednn_(type){};\n  ~AddOneDnnImpl() override = default;\n\n  using Add::Launch;\n  void Launch(Stream* stream, const void* const* srcs, size_t arity, void* dst,\n              size_t count) override {\n    if (arity < 2) {\n      // TODO: arity 0 and 1\n      UNIMPLEMENTED() << \"Addn only supports summation of 2 or more tensors\";\n    } else if (arity == 2) {\n      if (srcs[1] == dst && srcs[0] != dst) {\n        LOG(FATAL) << \"Only the first parameter can be operated inplace\";\n      }\n    } else {\n      for (int i = 2; i < arity; i++) {\n        if (srcs[i] == dst) { LOG(FATAL) << \"Only the first parameter can be operated inplace\"; }\n      }\n    }\n\n    stream->As<CpuStream>()->onednn_executor()->Launch(\n        [&](dnnl::engine* onednn_engine, 
dnnl::stream* onednn_stream) {\n          dnnl::memory::dims src_dims = {static_cast<dnnl::memory::dim>(count)};\n          std::vector<dnnl::memory::desc> src_md;\n          std::vector<dnnl::memory> src_mem;\n          src_md.reserve(arity);\n          src_mem.reserve(arity);\n\n          for (int i = 0; i < arity; i++) {\n            auto md = dnnl::memory::desc(src_dims, type_onednn_, dnnl::memory::format_tag::x);\n            auto mem = dnnl::memory(md, *onednn_engine, (void*)(srcs)[i]);\n            src_md.emplace_back(md);\n            src_mem.emplace_back(mem);\n          }\n\n          std::vector<float> scales(arity, 1.0);\n          auto sum_pd = dnnl::sum::primitive_desc(scales, src_md, *onednn_engine);\n          auto sum_prim = dnnl::sum(sum_pd);\n          auto dst_mem = dnnl::memory(sum_pd.dst_desc(), *onednn_engine, dst);\n          std::unordered_map<int, dnnl::memory> sum_args{{DNNL_ARG_DST, dst_mem}};\n          for (int i = 0; i < arity; ++i) {\n            sum_args.insert({DNNL_ARG_MULTIPLE_SRC + i, src_mem[i]});\n          }\n\n          sum_prim.execute(*onednn_stream, sum_args);\n        });\n  }\n\n private:\n  dnnl::memory::data_type type_onednn_;\n};\n\n#endif\n\ntemplate<typename T>\nstd::unique_ptr<Add> NewAdd() {\n  return std::unique_ptr<Add>(new AddDefaultImpl<T>());\n}\n\n#ifdef WITH_ONEDNN\n\ntemplate<dnnl::memory::data_type type_onednn>\nstd::unique_ptr<Add> NewOneDnnAdd() {\n  return std::unique_ptr<Add>(new AddOneDnnImpl(type_onednn));\n}\n\n#endif\n\n#define CPU_PRIMITIVE_ADD_ONEDNN_TYPE_SEQ \\\n  CPU_PRIMITIVE_ONEDNN_INT8_TYPE_SEQ      \\\n  CPU_PRIMITIVE_ONEDNN_UINT8_TYPE_SEQ     \\\n  CPU_PRIMITIVE_ONEDNN_INT32_TYPE_SEQ     \\\n  CPU_PRIMITIVE_ONEDNN_FLOAT_TYPE_SEQ     \\\n  CPU_PRIMITIVE_ONEDNN_FLOAT16_TYPE_SEQ   \\\n  CPU_PRIMITIVE_ONEDNN_BFLOAT16_TYPE_SEQ\n\n#define CPU_PRIMITIVE_ADD_DEFAULT_TYPE_SEQ \\\n  CPU_PRIMITIVE_BOOL_TYPE_SEQ              \\\n  CPU_PRIMITIVE_CHAR_TYPE_SEQ              \\\n  CPU_PRIMITIVE_DOUBLE_TYPE_SEQ            \\\n  CPU_PRIMITIVE_INT64_TYPE_SEQ\n\nclass AddFactoryImpl : public AddFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AddFactoryImpl);\n  AddFactoryImpl() = default;\n  ~AddFactoryImpl() override = default;\n\n  std::unique_ptr<Add> New(DataType data_type) override {\n#define MAKE_NEW_ADD_ENTRY(type_cpp, type_proto) {type_proto, NewAdd<type_cpp>},\n\n    static const std::map<DataType, std::function<std::unique_ptr<Add>()>> new_add_handle{\n        OF_PP_FOR_EACH_TUPLE(MAKE_NEW_ADD_ENTRY, CPU_PRIMITIVE_ALL_TYPE_SEQ)};\n\n#undef MAKE_NEW_ADD_ENTRY\n#ifdef WITH_ONEDNN\n\n#define MAKE_NEW_ONEDNN_ADD_ENTRY(type_onednn, type_proto) {type_proto, NewOneDnnAdd<type_onednn>},\n\n    static const std::map<DataType, std::function<std::unique_ptr<Add>()>> new_add_onednn_handle{\n        OF_PP_FOR_EACH_TUPLE(MAKE_NEW_ONEDNN_ADD_ENTRY, CPU_PRIMITIVE_ADD_ONEDNN_TYPE_SEQ)};\n\n#undef MAKE_NEW_ONEDNN_ADD_ENTRY\n\n    if (OneDnnIsEnabled()) {\n      auto add_primitive = NewPrimitiveFromHandlers(new_add_onednn_handle, data_type);\n      if (add_primitive) { return add_primitive; }\n    }\n\n#endif\n    return NewPrimitiveFromHandlers(new_add_handle, data_type);\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, AddFactory, AddFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
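  {
    "path": "examples/ep_sketches/add_arity_dispatch_sketch.cpp",
    "content": "/*\nIllustrative sketch, not part of the OneFlow source tree: the arity dispatch\nidea behind AddDefaultImpl::Launch above. Common arities are routed to an\ninstantiation with a compile-time arity so the inner summation loop has a\nknown trip count and can be unrolled; everything else falls back to the\nruntime-arity loop. Only arities 2 and 3 are wired up in this sketch.\n*/\n#include <cstddef>\n#include <iostream>\n\ntemplate<typename T, size_t arity>\nvoid AddCpu(const T* const* srcs, T* dst, size_t count) {\n  for (size_t i = 0; i < count; ++i) {\n    T sum = T(0);\n    for (size_t a = 0; a < arity; ++a) { sum += srcs[a][i]; }  // trip count fixed at compile time\n    dst[i] = sum;\n  }\n}\n\ntemplate<typename T>\nvoid AddCpu(const T* const* srcs, size_t arity, T* dst, size_t count) {\n  for (size_t i = 0; i < count; ++i) {\n    T sum = T(0);\n    for (size_t a = 0; a < arity; ++a) { sum += srcs[a][i]; }\n    dst[i] = sum;\n  }\n}\n\ntemplate<typename T>\nvoid Add(const T* const* srcs, size_t arity, T* dst, size_t count) {\n  if (arity == 2) {\n    AddCpu<T, 2>(srcs, dst, count);\n  } else if (arity == 3) {\n    AddCpu<T, 3>(srcs, dst, count);\n  } else {\n    AddCpu<T>(srcs, arity, dst, count);  // generic fallback\n  }\n}\n\nint main() {\n  float a[4] = {1, 2, 3, 4};\n  float b[4] = {10, 20, 30, 40};\n  float dst[4] = {};\n  const float* srcs[2] = {a, b};\n  Add<float>(srcs, 2, dst, 4);\n  std::cout << dst[0] << \" \" << dst[3] << std::endl;  // prints 11 44\n  return 0;\n}\n"
  },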
  {
    "path": "oneflow/core/ep/cpu/primitive/binary_functor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/common/primitive/binary_functor.h\"\n#include \"oneflow/core/ep/cpu/primitive/unary_functor.h\"\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kPow, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return std::pow(src0, src1); }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kPow, float16, float16> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {\n    return static_cast<float16>(std::pow(static_cast<float>(src0), static_cast<float>(src1)));\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFmod, float, float> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC float operator()(float src0, float src1) const { return std::fmod(src0, src1); }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFmod, double, double> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC double operator()(double src0, double src1) const { return std::fmod(src0, src1); }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFmod, float16, float16> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {\n    return static_cast<float16>(std::fmod(static_cast<float>(src0), static_cast<float>(src1)));\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFmod, bfloat16, bfloat16> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src0, bfloat16 src1) const {\n    return std::fmod(src0, src1);\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorDiv, float, float> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC float operator()(float src0, float src1) const { return std::floor(src0 / src1); }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorDiv, double, double> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC double operator()(double src0, double src1) const {\n    return std::floor(src0 / src1);\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorDiv, float16, float16> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {\n    return static_cast<float16>(std::floor(static_cast<float>(src0) / static_cast<float>(src1)));\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorDiv, bfloat16, bfloat16> {\n 
 OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src0, bfloat16 src1) const {\n    return std::floor(src0 / src1);\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTruncDiv, float, float> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC float operator()(float src0, float src1) const { return std::trunc(src0 / src1); }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTruncDiv, double, double> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC double operator()(double src0, double src1) const {\n    return std::trunc(src0 / src1);\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTruncDiv, float16, float16> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {\n    return static_cast<float16>(std::trunc(static_cast<float>(src0) / static_cast<float>(src1)));\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTruncDiv, bfloat16, bfloat16> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src0, bfloat16 src1) const {\n    return std::trunc(src0 / src1);\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, float, float> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC float operator()(float src0, float src1) const {\n    float trunc_mod = std::fmod(src0, src1);\n    return (trunc_mod != static_cast<float>(0))\n                   && ((src1 < static_cast<float>(0)) != (trunc_mod < static_cast<float>(0)))\n               ? trunc_mod + src1\n               : trunc_mod;\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, double, double> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC double operator()(double src0, double src1) const {\n    double trunc_mod = std::fmod(src0, src1);\n    return (trunc_mod != static_cast<double>(0))\n                   && ((src1 < static_cast<double>(0)) != (trunc_mod < static_cast<double>(0)))\n               ? 
trunc_mod + src1\n               : trunc_mod;\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, float16, float16> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}\n  BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, float, float> float_functor;\n\n  OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {\n    return static_cast<float16>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, bfloat16, bfloat16> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}\n  BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, float, float> float_functor;\n\n  OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src0, bfloat16 src1) const {\n    return static_cast<bfloat16>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarBasePowerGrad, float16, float16> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : scalar_operand(attr0.Value<float>()) {}\n\n  OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {\n    return static_cast<float16>(\n        scalar_operand\n        * (std::pow(static_cast<float>(src0), scalar_operand - static_cast<float>(1)))\n        * static_cast<float>(src1));\n  }\n  float scalar_operand;\n};\n\ntemplate<typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, int, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}\n  BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;\n\n  OF_DEVICE_FUNC Dst operator()(int src0, int src1) const {\n    return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));\n  }\n};\n\ntemplate<typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, int8_t, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}\n  BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;\n\n  OF_DEVICE_FUNC Dst operator()(int8_t src0, int8_t src1) const {\n    return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));\n  }\n};\n\ntemplate<typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, uint8_t, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}\n  BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;\n\n  OF_DEVICE_FUNC Dst operator()(uint8_t src0, uint8_t src1) const {\n    return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));\n  }\n};\n\ntemplate<typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, int64_t, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}\n  BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;\n\n  OF_DEVICE_FUNC Dst operator()(int64_t src0, int64_t src1) const {\n    return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));\n  }\n};\n\ntemplate<typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kScalarExpPowerGrad, float16, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : 
scalar_operand(attr0.Value<float>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(float16 src0, float16 src1) const {\n    return static_cast<Dst>(std::pow(scalar_operand, static_cast<float>(src0))\n                            * std::log(scalar_operand) * static_cast<float>(src1));\n  }\n  float scalar_operand;\n};\n\ntemplate<typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kSqrtBackwardWithDyX, std::complex<float>, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(std::complex<float> dy, std::complex<float> x) const {\n    return dy * static_cast<std::complex<float>>(0.5) / std::conj(std::sqrt(x));\n  }\n};\n\ntemplate<typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kSqrtBackwardWithDyX, std::complex<double>, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(std::complex<double> dy, std::complex<double> x) const {\n    return dy * static_cast<std::complex<double>>(0.5) / std::conj(std::sqrt(x));\n  }\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kGeluBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return static_cast<Dst>(\n        0.5 * (1.0 + std::erf(inv_sqrt2 * x) + x * coef * std::exp(-0.5 * x * x)) * dy);\n  }\n\n  Src inv_sqrt2 = std::sqrt(0.5);\n  Src coef = std::sqrt(2.0 / std::acos(-1.0));\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFastGeluBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    // ref to: https://mlfromscratch.com/activation-functions-explained/#gelu\n    const Src one = static_cast<Src>(1);\n    const Src half = static_cast<Src>(0.5);\n    const Src pow3 = x * x * x;\n    const Src tanh_out = std::tanh(alpha * (x + beta * pow3));\n    const Src dtanh = alpha * (half * x + beta * static_cast<Src>(1.5) * pow3);\n    return dy * (half + half * tanh_out + dtanh * (one - tanh_out * tanh_out));\n  }\n\n private:\n  static constexpr Src alpha = static_cast<Src>(0.7978845608028654);\n  static constexpr Src beta = static_cast<Src>(0.044714998453855515);\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kQuickGeluBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    const Src one = static_cast<Src>(1.0);\n    const Src sigmoid = one / (one + exp(-x * alpha));\n    return dy * (sigmoid + alpha * x * (sigmoid * (one - sigmoid)));\n  }\n\n private:\n  static constexpr Src alpha = static_cast<Src>(1.702);\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kSquareReLUBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return static_cast<Dst>((x > static_cast<Src>(0.0)) ? 
static_cast<Src>(2.0) * x * dy\n                                                        : static_cast<Src>(0.0));\n  }\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTanhBackwardWithDyY, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const {\n    return static_cast<Dst>(dy * (static_cast<Src>(1.0) - y * y));\n  }\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kAcosBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * -(static_cast<Src>(1.0) / sqrt(static_cast<Src>(1.0) - x * x));\n  }\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kAcoshBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy / sqrt(x * x - static_cast<Src>(1.0));\n  }\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kAsinBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * (static_cast<Src>(1.0) / sqrt(static_cast<Src>(1.0) - x * x));\n  }\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kAsinhBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * (static_cast<Src>(1.0) / sqrt(static_cast<Src>(1.0) + x * x));\n  }\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kErfBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * static_cast<Src>(2.0) * (static_cast<Src>(1.0) / sqrt(static_cast<Src>(M_PI)))\n           * exp(-x * x);\n  }\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kErfcBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return dy * static_cast<Src>(-2.0) * (static_cast<Src>(1.0) / sqrt(static_cast<Src>(M_PI)))\n           * exp(-x * x);\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kDigammaBackwardWithDyX, float, float> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC float operator()(float dy, float x) const {\n    ep::primitive::UnaryFunctor<DeviceType::kCPU, UnaryOp::kTrigamma, float, float>\n        trigamma_functor(0, 0);\n    float trigamma_result = trigamma_functor(x);\n    return trigamma_result * dy;\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kDigammaBackwardWithDyX, double, double> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC double operator()(double dy, double x) const {\n    ep::primitive::UnaryFunctor<DeviceType::kCPU, UnaryOp::kTrigamma, double, double>\n        trigamma_functor(0, 0);\n    double trigamma_result = trigamma_functor(x);\n    return trigamma_result * dy;\n  }\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kLgammaBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar 
attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    ep::primitive::UnaryFunctor<DeviceType::kCPU, UnaryOp::kDigamma, Src, Dst> digamma_functor(0,\n                                                                                               0);\n    Dst digamma_result = digamma_functor(x);\n    return digamma_result * dy;\n  }\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCPU, BinaryOp::kZeta, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src x, Src q) const {\n    // ref\n    // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/Math.h#L235-L309\n    // Hurwitz zeta via direct summation plus an Euler-Maclaurin tail; A[] holds the expansion\n    // coefficients (cf. the Cephes zeta implementation the PyTorch version derives from).\n    const Src MACHEP = Src{1.11022302462515654042E-16};\n    constexpr Src zero = Src{0.0};\n    constexpr Src half = Src{0.5};\n    constexpr Src one = Src{1.0};\n    static const Src A[] = {\n        12.0,\n        -720.0,\n        30240.0,\n        -1209600.0,\n        47900160.0,\n        -1.8924375803183791606e9, /*1.307674368e12/691*/\n        7.47242496e10,\n        -2.950130727918164224e12,  /*1.067062284288e16/3617*/\n        1.1646782814350067249e14,  /*5.109094217170944e18/43867*/\n        -4.5979787224074726105e15, /*8.028576626982912e20/174611*/\n        1.8152105401943546773e17,  /*1.5511210043330985984e23/854513*/\n        -7.1661652561756670113e18  /*1.6938241367317436694528e27/236364091*/\n    };\n    int i = 0;\n    Src a, b, k, s, t, w;\n    if (x == one) { return std::numeric_limits<Dst>::infinity(); }\n\n    if (x < one) { return std::numeric_limits<Dst>::quiet_NaN(); }\n\n    if (q <= zero) {\n      if (q == std::floor(q)) { return std::numeric_limits<Dst>::infinity(); }\n      if (x != std::floor(x)) { return std::numeric_limits<Dst>::quiet_NaN(); }\n    }\n\n    s = std::pow(q, -x);\n    a = q;\n    i = 0;\n    b = zero;\n    while ((i < 9) || (a <= Src{9.0})) {\n      i += 1;\n      a += one;\n      b = std::pow(a, -x);\n      s += b;\n      if ((-MACHEP * s < b) && (b < MACHEP * s)) { return static_cast<Dst>(s); }\n    }\n\n    w = a;\n    s += b * w / (x - one);\n    s -= half * b;\n    a = one;\n    k = zero;\n    for (int i = 0; i < 12; i++) {\n      a *= x + k;\n      b /= w;\n      t = a * b / A[i];\n      s = s + t;\n      t = std::fabs(t / s);\n      if (t < MACHEP) { return static_cast<Dst>(s); }\n      k += one;\n      a *= x + k;\n      b /= w;\n      k += one;\n    }\n    return static_cast<Dst>(s);\n  }\n};\n\n#define SPECIALIZATION_CPU_BINARY_FUNCTOR(op, type)                                          \\\n  template<>                                                                                 \\\n  struct BinaryFunctor<DeviceType::kCPU, op, type, type> {                                   \\\n    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : int_functor(attr0, attr1) {}  \\\n                                                                                             \\\n    BinaryFunctor<DeviceType::kCPU, op, int, int> int_functor;                               \\\n    OF_DEVICE_FUNC type operator()(type src0, type src1) const {                             \\\n      return static_cast<type>(int_functor(static_cast<int>(src0), static_cast<int>(src1))); \\\n    }                                                                                        \\\n  };\n\nSPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kPow, bool);\nSPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFmod, bool);\nSPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, 
bool);\nSPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kTruncDiv, bool);\nSPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, bool);\nSPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad, bool);\nSPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad, bool);\nSPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kPow, char);\nSPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFmod, char);\nSPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, char);\nSPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kTruncDiv, char);\nSPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, char);\nSPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad, char);\nSPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad, char);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/ep/common//primitive/constant_pad.h\"\n#include \"oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h\"\n#include \"oneflow/core/ep/cpu/primitive/binary_functor.h\"\n#include \"oneflow/core/ep/cpu/primitive/type_seq.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/ep/cpu/cpu_stream.h\"\n#include \"oneflow/core/ep/cpu/cpu_device.h\"\n#include \"oneflow/core/ep/common/primitive/util.h\"\n#include \"oneflow/core/ep/common/onednn.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\nnamespace {\n\ntemplate<typename T>\nT GetValue(Scalar value) {\n  return value.Value<T>();\n}\n\ntemplate<>\nfloat16 GetValue<float16>(Scalar value) {\n  return static_cast<float16>(GetValue<float>(value));\n}\n\ntemplate<>\nbfloat16 GetValue<bfloat16>(Scalar value) {\n  return static_cast<bfloat16>(GetValue<float>(value));\n}\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nstruct BinaryLhsScalarFunctor {\n  BinaryLhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1)\n      : scalar(scalar), functor(attr0, attr1) {}\n  Dst operator()(Src src) const { return functor(scalar, src); }\n  const Src scalar;\n  BinaryFunctor<DeviceType::kCPU, binary_op, Src, Dst> functor;\n};\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nstruct BinaryRhsScalarFunctor {\n  BinaryRhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1)\n      : scalar(scalar), functor(attr0, attr1) {}\n  Dst operator()(Src src) const { return functor(src, scalar); }\n  const Src scalar;\n  BinaryFunctor<DeviceType::kCPU, binary_op, Src, Dst> functor;\n};\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nvoid LaunchElementwise(CpuStream* cpu_stream, size_t simplified_num_dims,\n                       const int64_t* simplified_src0_dims, const Src* src0,\n                       const int64_t* simplified_src1_dims, const Src* src1, Dst* dst, Scalar attr0,\n                       Scalar attr1) {\n  const int64_t elem_cnt = GetElementCount(simplified_num_dims, simplified_src0_dims);\n  auto functor = BinaryFunctor<DeviceType::kCPU, binary_op, Src, Dst>(attr0, attr1);\n  cpu_stream->ParallelFor(0, elem_cnt, [functor, src0, src1, dst](int64_t begin, int64_t end) {\n    for (int64_t i = begin; i < end; i++) { dst[i] = functor(src0[i], src1[i]); }\n  });\n}\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nvoid LaunchBinaryLhsScalar(CpuStream* cpu_stream, Src src0_value, size_t src1_elem_cnt,\n                           const Src* src1, Dst* dst, Scalar attr0, Scalar attr1) {\n  auto functor = BinaryLhsScalarFunctor<binary_op, Src, Dst>(src0_value, attr0, attr1);\n  cpu_stream->ParallelFor(0, 
src1_elem_cnt, [functor, src1, dst](int64_t begin, int64_t end) {\n    for (int64_t i = begin; i < end; i++) { dst[i] = functor(src1[i]); }\n  });\n}\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nvoid LaunchBinaryRhsScalar(CpuStream* cpu_stream, Src src1_value, size_t src0_elem_cnt,\n                           const Src* src0, Dst* dst, Scalar attr0, Scalar attr1) {\n  auto functor = BinaryRhsScalarFunctor<binary_op, Src, Dst>(src1_value, attr0, attr1);\n  cpu_stream->ParallelFor(0, src0_elem_cnt, [functor, src0, dst](int64_t begin, int64_t end) {\n    for (int64_t i = begin; i < end; i++) { dst[i] = functor(src0[i]); }\n  });\n}\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nvoid LaunchRowWithMatrix(CpuStream* cpu_stream, const int64_t* simplified_src0_dims,\n                         const Src* src0, const int64_t* simplified_src1_dims, const Src* src1,\n                         Dst* dst, Scalar attr0, Scalar attr1) {\n  int64_t rows = simplified_src1_dims[0];\n  int64_t cols = simplified_src0_dims[1];\n  auto functor = BinaryFunctor<DeviceType::kCPU, binary_op, Src, Dst>(attr0, attr1);\n  cpu_stream->ParallelFor(\n      0, rows,\n      [functor, src0, src1, dst, cols](int64_t begin, int64_t end) {\n        for (int64_t row_idx = begin; row_idx < end; row_idx++) {\n          const Src* src1_row = src1 + row_idx * cols;\n          Dst* dst_row = dst + row_idx * cols;\n          for (int64_t col_idx = 0; col_idx < cols; col_idx++) {\n            dst_row[col_idx] = functor(src0[col_idx], src1_row[col_idx]);\n          }\n        }\n      },\n      1);\n}\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nvoid LaunchMatrixWithRow(CpuStream* cpu_stream, const int64_t* simplified_src0_dims,\n                         const Src* src0, const int64_t* simplified_src1_dims, const Src* src1,\n                         Dst* dst, Scalar attr0, Scalar attr1) {\n  int64_t rows = simplified_src0_dims[0];\n  int64_t cols = simplified_src1_dims[1];\n  auto functor = BinaryFunctor<DeviceType::kCPU, binary_op, Src, Dst>(attr0, attr1);\n  cpu_stream->ParallelFor(\n      0, rows,\n      [functor, src0, src1, dst, cols](int64_t begin, int64_t end) {\n        for (int64_t row_idx = begin; row_idx < end; row_idx++) {\n          const Src* src0_row = src0 + row_idx * cols;\n          Dst* dst_row = dst + row_idx * cols;\n          for (int64_t col_idx = 0; col_idx < cols; col_idx++) {\n            dst_row[col_idx] = functor(src0_row[col_idx], src1[col_idx]);\n          }\n        }\n      },\n      1);\n}\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nvoid LaunchColWithMatrix(CpuStream* cpu_stream, const int64_t* simplified_src0_dims,\n                         const Src* src0, const int64_t* simplified_src1_dims, const Src* src1,\n                         Dst* dst, Scalar attr0, Scalar attr1) {\n  int64_t rows = simplified_src0_dims[0];\n  int64_t cols = simplified_src1_dims[1];\n  auto functor = BinaryFunctor<DeviceType::kCPU, binary_op, Src, Dst>(attr0, attr1);\n  cpu_stream->ParallelFor(\n      0, rows,\n      [functor, src0, src1, dst, cols](int64_t begin, int64_t end) {\n        for (int64_t row_idx = begin; row_idx < end; row_idx++) {\n          const Src* src1_row = src1 + row_idx * cols;\n          Dst* dst_row = dst + row_idx * cols;\n          for (int64_t col_idx = 0; col_idx < cols; col_idx++) {\n            dst_row[col_idx] = functor(src0[row_idx], src1_row[col_idx]);\n          }\n        }\n      },\n      1);\n}\n\ntemplate<BinaryOp binary_op, 
typename Src, typename Dst>\nvoid LaunchMatrixWithCol(CpuStream* cpu_stream, const int64_t* simplified_src0_dims,\n                         const Src* src0, const int64_t* simplified_src1_dims, const Src* src1,\n                         Dst* dst, Scalar attr0, Scalar attr1) {\n  int64_t rows = simplified_src1_dims[0];\n  int64_t cols = simplified_src0_dims[1];\n  auto functor = BinaryFunctor<DeviceType::kCPU, binary_op, Src, Dst>(attr0, attr1);\n  cpu_stream->ParallelFor(\n      0, rows,\n      [functor, src0, src1, dst, cols](int64_t begin, int64_t end) {\n        for (int64_t row_idx = begin; row_idx < end; row_idx++) {\n          const Src* src0_row = src0 + row_idx * cols;\n          Dst* dst_row = dst + row_idx * cols;\n          for (int64_t col_idx = 0; col_idx < cols; col_idx++) {\n            dst_row[col_idx] = functor(src0_row[col_idx], src1[row_idx]);\n          }\n        }\n      },\n      1);\n}\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst, typename IndexType>\nvoid LaunchGeneral(CpuStream* cpu_stream, size_t simplified_num_dims,\n                   const int64_t* simplified_src0_dims, const Src* src0,\n                   const int64_t* simplified_src1_dims, const Src* src1,\n                   const int64_t* simplified_dst_dims, Dst* dst, int64_t dst_elem_cnt, Scalar attr0,\n                   Scalar attr1) {\n  auto functor = BinaryFunctor<DeviceType::kCPU, binary_op, Src, Dst>(attr0, attr1);\n  cpu_stream->ParallelFor(\n      0, dst_elem_cnt,\n      [functor, src0, src1, dst, simplified_num_dims, simplified_src0_dims, simplified_src1_dims,\n       simplified_dst_dims](int64_t begin, int64_t end) {\n        auto src0_index_helper =\n            NdIndexOffsetHelper<IndexType, kMaxNumDims>(simplified_src0_dims, simplified_num_dims);\n        auto src1_index_helper =\n            NdIndexOffsetHelper<IndexType, kMaxNumDims>(simplified_src1_dims, simplified_num_dims);\n        auto dst_index_helper = OffsetToIndexCalculator<IndexType, kMaxNumDims>(\n            simplified_dst_dims, simplified_num_dims);\n        IndexType src0_index[kMaxNumDims];\n        IndexType src1_index[kMaxNumDims];\n        IndexType dst_index[kMaxNumDims];\n        for (IndexType offset = begin; offset < end; offset++) {\n          dst_index_helper.OffsetToNdIndex(offset, dst_index, simplified_num_dims);\n          for (int i = 0; i < kMaxNumDims; i++) {\n            if (i < simplified_num_dims) {\n              src0_index[i] = (simplified_src0_dims[i] != 1) ? dst_index[i] : 0;\n              src1_index[i] = (simplified_src1_dims[i] != 1) ? 
dst_index[i] : 0;\n            } else {\n              src0_index[i] = 0;\n              src1_index[i] = 0;\n            }\n          }\n          const IndexType src0_offset =\n              src0_index_helper.NdIndexToOffset(src0_index, simplified_num_dims);\n          const IndexType src1_offset =\n              src1_index_helper.NdIndexToOffset(src1_index, simplified_num_dims);\n          dst[offset] = functor(src0[src0_offset], src1[src1_offset]);\n        }\n      });\n}\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nvoid LaunchGeneralDispatchIndexType(CpuStream* cpu_stream, size_t simplified_num_dims,\n                                    const int64_t* simplified_src0_dims, const Src* src0,\n                                    const int64_t* simplified_src1_dims, const Src* src1,\n                                    const int64_t* simplified_dst_dims, Dst* dst, Scalar attr0,\n                                    Scalar attr1) {\n  const int64_t dst_elem_cnt = GetElementCount(simplified_num_dims, simplified_dst_dims);\n  if (dst_elem_cnt < (GetMaxVal<int32_t>() / 2)) {\n    LaunchGeneral<binary_op, Src, Dst, int32_t>(\n        cpu_stream, simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1,\n        simplified_dst_dims, dst, dst_elem_cnt, attr0, attr1);\n  } else {\n    LaunchGeneral<binary_op, Src, Dst, int64_t>(\n        cpu_stream, simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1,\n        simplified_dst_dims, dst, dst_elem_cnt, attr0, attr1);\n  }\n}\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nvoid DispatchLaunch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const Src* src0,\n                    size_t num_src1_dims, const int64_t* src1_dims, const Src* src1, Dst* dst,\n                    Scalar attr0, Scalar attr1) {\n  auto* cpu_stream = stream->As<CpuStream>();\n  size_t simplified_num_dims = 0;\n  int64_t simplified_src0_dims[kMaxNumDims];\n  int64_t simplified_src1_dims[kMaxNumDims];\n  int64_t simplified_dst_dims[kMaxNumDims];\n  SimplifyBroadcastDims<kMaxNumDims>(num_src0_dims, src0_dims, num_src1_dims, src1_dims,\n                                     &simplified_num_dims, simplified_src0_dims,\n                                     simplified_src1_dims, simplified_dst_dims);\n  CheckInplace(simplified_num_dims, simplified_src0_dims, src0, simplified_dst_dims, dst);\n  CheckInplace(simplified_num_dims, simplified_src1_dims, src1, simplified_dst_dims, dst);\n  if (IsDimsEquals(simplified_num_dims, simplified_src0_dims, simplified_num_dims,\n                   simplified_src1_dims)) {\n    LaunchElementwise<binary_op, Src, Dst>(cpu_stream, simplified_num_dims, simplified_src0_dims,\n                                           src0, simplified_src1_dims, src1, dst, attr0, attr1);\n  } else {\n    if (simplified_num_dims == 1 && simplified_src0_dims[0] == 1) {\n      LaunchBinaryLhsScalar<binary_op, Src, Dst>(cpu_stream, *src0, simplified_src1_dims[0], src1,\n                                                 dst, attr0, attr1);\n    } else if (simplified_num_dims == 1 && simplified_src1_dims[0] == 1) {\n      LaunchBinaryRhsScalar<binary_op, Src, Dst>(cpu_stream, *src1, simplified_src0_dims[0], src0,\n                                                 dst, attr0, attr1);\n    } else if (simplified_num_dims == 2 && simplified_src0_dims[0] == 1\n               && simplified_src0_dims[1] == simplified_src1_dims[1]) {\n      LaunchRowWithMatrix<binary_op, Src, Dst>(cpu_stream, 
simplified_src0_dims, src0,\n                                               simplified_src1_dims, src1, dst, attr0, attr1);\n    } else if (simplified_num_dims == 2 && simplified_src1_dims[0] == 1\n               && simplified_src0_dims[1] == simplified_src1_dims[1]) {\n      LaunchMatrixWithRow<binary_op, Src, Dst>(cpu_stream, simplified_src0_dims, src0,\n                                               simplified_src1_dims, src1, dst, attr0, attr1);\n    } else if (simplified_num_dims == 2 && simplified_src0_dims[1] == 1\n               && simplified_src0_dims[0] == simplified_src1_dims[0]) {\n      LaunchColWithMatrix<binary_op, Src, Dst>(cpu_stream, simplified_src0_dims, src0,\n                                               simplified_src1_dims, src1, dst, attr0, attr1);\n    } else if (simplified_num_dims == 2 && simplified_src1_dims[1] == 1\n               && simplified_src0_dims[0] == simplified_src1_dims[0]) {\n      LaunchMatrixWithCol<binary_op, Src, Dst>(cpu_stream, simplified_src0_dims, src0,\n                                               simplified_src1_dims, src1, dst, attr0, attr1);\n    } else {\n      LaunchGeneralDispatchIndexType<binary_op, Src, Dst>(\n          cpu_stream, simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1,\n          simplified_dst_dims, dst, attr0, attr1);\n    }\n  }\n}\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nclass BroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryImpl);\n  BroadcastElementwiseBinaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}\n  ~BroadcastElementwiseBinaryImpl() override = default;\n\n  void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims,\n              const void* src1_ptr, void* dst_ptr) override {\n    auto* cpu_stream = stream->As<CpuStream>();\n    const size_t elem_cnt = GetElementCount(num_src1_dims, src1_dims);\n    Dst* dst = reinterpret_cast<Dst*>(dst_ptr);\n    const Src* src1 = reinterpret_cast<const Src*>(src1_ptr);\n    LaunchBinaryLhsScalar<binary_op, Src, Dst>(cpu_stream, GetValue<Src>(src0), elem_cnt, src1, dst,\n                                               attr0, attr1);\n  }\n  void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0_ptr,\n              Scalar src1, void* dst_ptr) override {\n    auto* cpu_stream = stream->As<CpuStream>();\n    const size_t elem_cnt = GetElementCount(num_src0_dims, src0_dims);\n    Dst* dst = reinterpret_cast<Dst*>(dst_ptr);\n    const Src* src0 = reinterpret_cast<const Src*>(src0_ptr);\n    LaunchBinaryRhsScalar<binary_op, Src, Dst>(cpu_stream, GetValue<Src>(src1), elem_cnt, src0, dst,\n                                               attr0, attr1);\n  }\n  void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,\n              size_t num_src1_dims, const int64_t* src1_dims, const void* src1,\n              void* dst) override {\n    DispatchLaunch<binary_op, Src, Dst>(\n        stream, num_src0_dims, src0_dims, reinterpret_cast<const Src*>(src0), num_src1_dims,\n        src1_dims, reinterpret_cast<const Src*>(src1), reinterpret_cast<Dst*>(dst), attr0, attr1);\n  }\n\n private:\n  Scalar attr0, attr1;\n};\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nstd::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary(Scalar attr0,\n                                                                          
Scalar attr1) {\n  return std::unique_ptr<BroadcastElementwiseBinary>(\n      new BroadcastElementwiseBinaryImpl<binary_op, Src, Dst>(attr0, attr1));\n}\n\n#define NDARRAY_BINARY_TYPE_SEQ \\\n  CPU_PRIMITIVE_BOOL_TYPE_SEQ   \\\n  CPU_PRIMITIVE_INT8_TYPE_SEQ   \\\n  CPU_PRIMITIVE_UINT8_TYPE_SEQ  \\\n  CPU_PRIMITIVE_INT32_TYPE_SEQ  \\\n  CPU_PRIMITIVE_INT64_TYPE_SEQ  \\\n  CPU_PRIMITIVE_FLOAT_TYPE_SEQ  \\\n  CPU_PRIMITIVE_DOUBLE_TYPE_SEQ \\\n  CPU_PRIMITIVE_FLOAT16_TYPE_SEQ\n\n#ifdef WITH_ONEDNN\n\nuint32_t OnednnFormatTagMap[kMaxNumDims] = {dnnl_a,     dnnl_ab,     dnnl_abc,     dnnl_abcd,\n                                            dnnl_abcde, dnnl_abcdef, dnnl_abcdefg, dnnl_abcdefgh};\n\ninline void OneDnnBroadcastDims(dnnl::memory::dims* src0, size_t num_src0_dims,\n                                const int64_t* src0_dims, dnnl::memory::dims* src1,\n                                size_t num_src1_dims, const int64_t* src1_dims,\n                                dnnl::memory::dims& dst) {\n  const int64_t num_dims = dst.size();\n  const int64_t num_src0_padding_dims = num_dims - num_src0_dims;\n  const int64_t num_src1_padding_dims = num_dims - num_src1_dims;\n  for (int64_t i = 0; i < num_dims; i++) {\n    int64_t src0_dim = i < num_src0_padding_dims ? 1 : src0_dims[i - num_src0_padding_dims];\n    int64_t src1_dim = i < num_src1_padding_dims ? 1 : src1_dims[i - num_src1_padding_dims];\n    CHECK((src0_dim == src1_dim || src0_dim == 1 || src1_dim == 1));\n    (*src0)[i] = src0_dim;\n    (*src1)[i] = src1_dim;\n    dst[i] = std::max(src0_dim, src1_dim);\n  }\n}\n\ntemplate<typename T, dnnl::algorithm algorithm, dnnl::memory::data_type src_onednn,\n         dnnl::memory::data_type dst_onednn>\nclass OneDnnBroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(OneDnnBroadcastElementwiseBinaryImpl);\n  OneDnnBroadcastElementwiseBinaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}\n  ~OneDnnBroadcastElementwiseBinaryImpl() override = default;\n\n  void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims,\n              const void* src1, void* dst) override {\n    T scalar_val = GetValue<T>(src0);\n    const int64_t src0_dims = 1;\n    Launch(stream, 1, &src0_dims, &scalar_val, num_src1_dims, src1_dims, src1, dst);\n  }\n  void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,\n              Scalar src1, void* dst) override {\n    T scalar_val = GetValue<T>(src1);\n    const int64_t src1_dims = 1;\n    Launch(stream, num_src0_dims, src0_dims, src0, 1, &src1_dims, &scalar_val, dst);\n  }\n  void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,\n              size_t num_src1_dims, const int64_t* src1_dims, const void* src1,\n              void* dst) override {\n    stream->As<CpuStream>()->onednn_executor()->Launch([&](dnnl::engine* onednn_engine,\n                                                           dnnl::stream* onednn_stream) {\n      // oneDNN does not optimize 3-d tensors in our experiments, so expand to\n      // 4-d if needed.\n      // Note that only the oneDNN \"internal\" dims are affected; the shape of\n      // the oneflow tensor (including the output tensor) remains unchanged.\n      size_t num_dims = std::max(std::max(num_src0_dims, num_src1_dims), static_cast<size_t>(4));\n      dnnl::memory::dims src_0_dims(num_dims);\n      dnnl::memory::dims src_1_dims(num_dims);\n      dnnl::memory::dims 
dst_dims(num_dims);\n      const void* onednn_src0 = nullptr;\n      const void* onednn_src1 = nullptr;\n\n      // OneDNN inplace operations only support src_0\n      if (src1 == dst) {\n        onednn_src0 = src1;\n        onednn_src1 = src0;\n        OneDnnBroadcastDims(&src_0_dims, num_src1_dims, src1_dims, &src_1_dims, num_src0_dims,\n                            src0_dims, dst_dims);\n      } else {\n        onednn_src0 = src0;\n        onednn_src1 = src1;\n        OneDnnBroadcastDims(&src_0_dims, num_src0_dims, src0_dims, &src_1_dims, num_src1_dims,\n                            src1_dims, dst_dims);\n      }\n\n      CheckInplace(num_dims, src_0_dims.data(), onednn_src0, dst_dims.data(), dst);\n      CheckInplace(num_dims, src_1_dims.data(), onednn_src1, dst_dims.data(), dst);\n\n      auto src_0_md = dnnl::memory::desc(\n          src_0_dims, src_onednn,\n          static_cast<dnnl::memory::format_tag>(OnednnFormatTagMap[num_dims - 1]));\n      auto src_1_md = dnnl::memory::desc(\n          src_1_dims, src_onednn,\n          static_cast<dnnl::memory::format_tag>(OnednnFormatTagMap[num_dims - 1]));\n      auto dst_md = dnnl::memory::desc(\n          dst_dims, dst_onednn,\n          static_cast<dnnl::memory::format_tag>(OnednnFormatTagMap[num_dims - 1]));\n\n      auto src_0_mem = dnnl::memory(src_0_md, *onednn_engine, (void*)onednn_src0);\n      auto src_1_mem = dnnl::memory(src_1_md, *onednn_engine, (void*)onednn_src1);\n      auto dst_mem = dnnl::memory(dst_md, *onednn_engine, dst);\n\n      auto binary_d = dnnl::binary::desc(algorithm, src_0_md, src_1_md, dst_md);\n      auto binary_pd = dnnl::binary::primitive_desc(binary_d, *onednn_engine);\n      auto binary_prim = dnnl::binary(binary_pd);\n\n      binary_prim.execute(\n          *onednn_stream,\n          {{DNNL_ARG_SRC_0, src_0_mem}, {DNNL_ARG_SRC_1, src_1_mem}, {DNNL_ARG_DST, dst_mem}});\n    });\n  }\n\n private:\n  Scalar attr0, attr1;\n};\n\n#define CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ                               \\\n  OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool, bool) \\\n  OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f32, DataType::kFloat, float)\n\n// OneDNN binary op does not support s32\n// CPU_PRIMITIVE_ONEDNN_INT32_TYPE_SEQ\n\n#define CPU_PRIMITIVE_BINARY_ONEDNN_UNIMPLEMENTED_TYPE_SEQ \\\n  CPU_PRIMITIVE_FLOAT16_TYPE_SEQ                           \\\n  CPU_PRIMITIVE_DOUBLE_TYPE_SEQ                            \\\n  CPU_PRIMITIVE_INT8_TYPE_SEQ                              \\\n  CPU_PRIMITIVE_UINT8_TYPE_SEQ                             \\\n  CPU_PRIMITIVE_INT32_TYPE_SEQ                             \\\n  CPU_PRIMITIVE_INT64_TYPE_SEQ\n\n#define BINARY_ONEDNN_ADD OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAdd, dnnl::algorithm::binary_add)\n#define BINARY_ONEDNN_SUB OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSub, dnnl::algorithm::binary_sub)\n#define BINARY_ONEDNN_MUL OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMul, dnnl::algorithm::binary_mul)\n#define BINARY_ONEDNN_DIV OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDiv, dnnl::algorithm::binary_div)\n#define BINARY_ONEDNN_MAX OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMax, dnnl::algorithm::binary_max)\n#define BINARY_ONEDNN_MIN OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMin, dnnl::algorithm::binary_min)\n\n#define BINARY_ONEDNN_EQ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEqual, dnnl::algorithm::binary_eq)\n#define BINARY_ONEDNN_NE OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kNotEqual, dnnl::algorithm::binary_ne)\n#define BINARY_ONEDNN_LT OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessThan, dnnl::algorithm::binary_lt)\n#define BINARY_ONEDNN_LE 
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessEqual, dnnl::algorithm::binary_le)\n#define BINARY_ONEDNN_GT OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterThan, dnnl::algorithm::binary_gt)\n#define BINARY_ONEDNN_GE OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterEqual, dnnl::algorithm::binary_ge)\n\n#define BINARY_MATH_OP_ONEDNN_PAIR \\\n  BINARY_ONEDNN_ADD                \\\n  BINARY_ONEDNN_SUB                \\\n  BINARY_ONEDNN_MUL                \\\n  BINARY_ONEDNN_DIV                \\\n  BINARY_ONEDNN_MAX                \\\n  BINARY_ONEDNN_MIN\n\n#define BINARY_LOGICAL_COMPARISION_OP_ONEDNN_PAIR \\\n  BINARY_ONEDNN_EQ                                \\\n  BINARY_ONEDNN_NE                                \\\n  BINARY_ONEDNN_LT                                \\\n  BINARY_ONEDNN_LE                                \\\n  BINARY_ONEDNN_GT                                \\\n  BINARY_ONEDNN_GE\n\n#define BINARY_LOGICAL_COMPARISION_OP_ONEDNN_UNIMPLEMENTED \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalAnd, AND)         \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalOr, OR)           \\\n  OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalXor, XOR)\n\ntemplate<typename T, dnnl::algorithm algorithm, dnnl::memory::data_type src_onednn,\n         dnnl::memory::data_type dst_onednn>\nstd::unique_ptr<BroadcastElementwiseBinary> NewOneDnnBroadcastElementwiseBinary(Scalar attr0,\n                                                                                Scalar attr1) {\n  return std::unique_ptr<BroadcastElementwiseBinary>(\n      new OneDnnBroadcastElementwiseBinaryImpl<T, algorithm, src_onednn, dst_onednn>(attr0, attr1));\n}\n\n#define MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op_pair, data_type_pair) \\\n  {std::make_tuple(OF_PP_PAIR_FIRST(binary_op_pair), OF_PP_PAIR_SECOND(data_type_pair),         \\\n                   OF_PP_PAIR_SECOND(data_type_pair)),                                          \\\n   NewOneDnnBroadcastElementwiseBinary<                                                         \\\n       OF_PP_PAIR_THIRD(data_type_pair), OF_PP_PAIR_SECOND(binary_op_pair),                     \\\n       OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>},\n\n#define MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY(         \\\n    binary_op_pair, src_data_type_pair, dst_data_type_pair)                                 \\\n  {std::make_tuple(OF_PP_PAIR_FIRST(binary_op_pair), OF_PP_PAIR_SECOND(src_data_type_pair), \\\n                   OF_PP_PAIR_SECOND(dst_data_type_pair)),                                  \\\n   NewOneDnnBroadcastElementwiseBinary<                                                     \\\n       OF_PP_PAIR_THIRD(src_data_type_pair), OF_PP_PAIR_SECOND(binary_op_pair),             \\\n       OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>},\n\n#endif  // WITH_ONEDNN\n\nclass BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryFactoryImpl);\n  BroadcastElementwiseBinaryFactoryImpl() = default;\n  ~BroadcastElementwiseBinaryFactoryImpl() override = default;\n\n  std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp op, DataType src_type, DataType dst_type,\n                                                  size_t max_num_dims) override {\n    return New(op, src_type, dst_type, max_num_dims, Scalar(), Scalar());\n  }\n\n  std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp op, DataType src_type, DataType dst_type,\n                       
                           size_t max_num_dims, Scalar attr0) override {\n    return New(op, src_type, dst_type, max_num_dims, attr0, Scalar());\n  }\n\n  std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp binary_op, DataType src_type,\n                                                  DataType dst_type, size_t max_num_dims,\n                                                  Scalar attr0, Scalar attr1) override {\n    if (max_num_dims > kMaxNumDims) { return nullptr; }\n#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \\\n  {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair),                    \\\n                   OF_PP_PAIR_SECOND(data_type_pair)),                              \\\n   NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair),       \\\n                                 OF_PP_PAIR_FIRST(data_type_pair)>},\n\n#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY(      \\\n    binary_op, src_data_type_pair, dst_data_type_pair)                            \\\n  {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(src_data_type_pair),              \\\n                   OF_PP_PAIR_SECOND(dst_data_type_pair)),                        \\\n   NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), \\\n                                 OF_PP_PAIR_FIRST(dst_data_type_pair)>},\n\n#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, data_type_pair) \\\n  {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair),                               \\\n                   OF_PP_PAIR_SECOND(data_type_pair)),                                         \\\n   NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair),                  \\\n                                 OF_PP_PAIR_FIRST(data_type_pair)>},\n\n    static const std::map<\n        std::tuple<BinaryOp, DataType, DataType>,\n        std::function<std::unique_ptr<BroadcastElementwiseBinary>(Scalar, Scalar)>>\n        new_broadcast_elementwise_binary_handle{\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                             BINARY_MATH_OP_SEQ, NDARRAY_BINARY_TYPE_SEQ)\n\n                OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                                 BINARY_COMPLEX_MATH_OP_SEQ,\n                                                 CPU_PRIMITIVE_COMPLEX_TYPE_SEQ)\n\n                    OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                        MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_BITWISE_OP_SEQ,\n                        CPU_PRIMITIVE_INT_TYPE_SEQ CPU_PRIMITIVE_BOOL_TYPE_SEQ)\n\n                        OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                            MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                            BINARY_MATH_FLOATING_OP_SEQ, CPU_PRIMITIVE_FLOATING_TYPE_SEQ)\n\n                            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                                MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,\n                                BINARY_LOGICAL_OP_SEQ BINARY_COMPARISION_OP_SEQ,\n                                NDARRAY_BINARY_TYPE_SEQ, CPU_PRIMITIVE_BOOL_TYPE_SEQ)\n\n                                OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                                    MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,\n                                   
 BINARY_COMPLEX_COMPARISION_OP_SEQ,\n                                    CPU_PRIMITIVE_COMPLEX_TYPE_SEQ, CPU_PRIMITIVE_BOOL_TYPE_SEQ)\n\n                                    OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                                        MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,\n                                        BINARY_ACTIVATION_BACKWARD_OP_SEQ,\n                                        CPU_PRIMITIVE_FLOATING_TYPE_SEQ)\n\n                                        OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                                            MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                            BINARY_MATH_BACKWARD_OP_SEQ,\n                                            CPU_PRIMITIVE_FLOATING_TYPE_SEQ)\n\n                                            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                                                MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                                BINARY_MATH_BACKWARD_OP_SEQ_COMPLEX,\n                                                CPU_PRIMITIVE_COMPLEX_TYPE_SEQ)};\n\n#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY\n#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY\n\n#ifdef WITH_ONEDNN\n    static const std::map<\n        std::tuple<BinaryOp, DataType, DataType>,\n        std::function<std::unique_ptr<BroadcastElementwiseBinary>(Scalar, Scalar)>>\n        new_broadcast_elementwise_binary_onednn_handle{\n            // For oneDNN binary math ops\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, BINARY_MATH_OP_ONEDNN_PAIR,\n                CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ)\n            // For oneDNN comparison binary ops\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,\n                BINARY_LOGICAL_COMPARISION_OP_ONEDNN_PAIR, CPU_PRIMITIVE_BINARY_ONEDNN_TYPE_SEQ,\n                CPU_PRIMITIVE_ONEDNN_BOOl_TYPE_SEQ)};\n\n#undef MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY\n#undef MAKE_NEW_ONEDNN_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY\n    if (OneDnnIsEnabled()) {\n      const auto iter = new_broadcast_elementwise_binary_onednn_handle.find(\n          std::make_tuple(binary_op, src_type, dst_type));\n      if (iter != new_broadcast_elementwise_binary_onednn_handle.end()) {\n        return iter->second(attr0, attr1);\n      }\n    }\n\n#endif\n    const auto iter = new_broadcast_elementwise_binary_handle.find(\n        std::make_tuple(binary_op, src_type, dst_type));\n    if (iter != new_broadcast_elementwise_binary_handle.end()) {\n      return iter->second(attr0, attr1);\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, BroadcastElementwiseBinaryFactory,\n                           BroadcastElementwiseBinaryFactoryImpl);\n\n}  // namespace\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cpu/primitive/broadcast_elementwise_unary.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/ep/common/primitive/broadcast_elementwise_unary.h\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"oneflow/core/ep/cpu/primitive/unary_functor.h\"\n#include \"oneflow/core/ep/cpu/primitive/type_seq.h\"\n#include \"oneflow/core/ep/cpu/cpu_stream.h\"\n#include \"oneflow/core/ep/cpu/cpu_device.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_unary {\n\nnamespace {\n\n#define CPU_PRIMITIVE_CAST_REAL_TYPE_SEQ \\\n  CPU_PRIMITIVE_INT16_TYPE_SEQ           \\\n  CPU_PRIMITIVE_NATIVE_TYPE_SEQ          \\\n  CPU_PRIMITIVE_FLOAT16_TYPE_SEQ         \\\n  CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ\n\nbool IsContiguous(size_t num_dims, const int64_t* dims, const int64_t* strides) {\n  for (int i = num_dims - 1; i >= 0; i--) {\n    if ((i == num_dims - 1 && strides[i] != 1)\n        || (i != num_dims - 1 && strides[i] != dims[i + 1] * strides[i + 1])) {\n      return false;\n    }\n  }\n  return true;\n}\n\ntemplate<UnaryOp unary_op, typename Src, typename Dst>\nvoid LaunchScalarFill(CpuStream* stream, Dst* dst, const Src* src, size_t count, size_t stride,\n                      Scalar attr0, Scalar attr1) {\n  auto functor = UnaryFunctor<DeviceType::kCPU, unary_op, Dst, Src>(attr0, attr1);\n  Dst scalar_value = functor(*src);\n  stream->ParallelFor(0, count, [dst, stride, scalar_value](int64_t begin, int64_t end) {\n    for (int64_t i = begin; i < end; i++) { dst[i * stride] = scalar_value; }\n  });\n}\n\ntemplate<UnaryOp unary_op, typename Src, typename Dst>\nvoid LaunchTensorFill(CpuStream* stream, Dst* dst, const Src* src, size_t count, size_t dst_stride,\n                      size_t src_stride, Scalar attr0, Scalar attr1) {\n  auto functor = UnaryFunctor<DeviceType::kCPU, unary_op, Dst, Src>(attr0, attr1);\n  stream->ParallelFor(0, count,\n                      [functor, src, dst, src_stride, dst_stride](int64_t begin, int64_t end) {\n                        for (int64_t i = begin; i < end; i++) {\n                          dst[i * dst_stride] = functor(src[i * src_stride]);\n                        }\n                      });\n}\n\ntemplate<UnaryOp unary_op, typename Src, typename Dst>\nvoid LaunchGeneral(CpuStream* stream, Dst* dst, const Src* src, size_t num_dims,\n                   const int64_t* dst_dims, const int64_t* src_dims, const int64_t* dst_stride,\n                   const int64_t* src_stride, Scalar attr0, Scalar attr1) {\n  bool contiguous_output = IsContiguous(num_dims, dst_dims, dst_stride);\n  const int64_t elem_cnt = GetElementCount(num_dims, dst_dims);\n  auto functor = UnaryFunctor<DeviceType::kCPU, unary_op, Dst, Src>(attr0, attr1);\n  stream->ParallelFor(\n      0, elem_cnt,\n      [functor, src, dst, num_dims, src_dims, dst_dims, src_stride, dst_stride, contiguous_output](\n          int64_t begin, int64_t end) {\n        auto 
src_index_to_offset_helper =\n            IndexToOffsetWithStrideCalculator<int64_t, kMaxNumDims>(src_stride, num_dims);\n        auto dst_offset_to_index_helper =\n            OffsetToIndexWithStrideCalculator<int64_t, kMaxNumDims>(dst_dims, num_dims);\n        auto dst_index_to_offset_helper =\n            IndexToOffsetWithStrideCalculator<int64_t, kMaxNumDims>(dst_stride, num_dims);\n        int64_t src_index[kMaxNumDims];\n        int64_t dst_index[kMaxNumDims];\n        for (int64_t offset = begin; offset < end; offset++) {\n          dst_offset_to_index_helper.OffsetToNdIndex(offset, dst_index, num_dims);\n          for (int i = 0; i < kMaxNumDims; i++) {\n            if (i < num_dims) {\n              src_index[i] = (src_dims[i] != 1) ? dst_index[i] : 0;\n            } else {\n              src_index[i] = 0;\n            }\n          }\n          const int64_t src_offset =\n              src_index_to_offset_helper.NdIndexToOffset(src_index, num_dims);\n          if (!contiguous_output) {\n            const int64_t dst_offset =\n                dst_index_to_offset_helper.NdIndexToOffset(dst_index, num_dims);\n            dst[dst_offset] = functor(src[src_offset]);\n          } else {\n            dst[offset] = functor(src[src_offset]);\n          }\n        }\n      });\n}\n\ntemplate<UnaryOp unary_op, typename Src, DataType src_type, typename Dst, DataType dst_type>\nclass BroadcastElementwiseUnaryImpl : public BroadcastElementwiseUnary {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnaryImpl);\n  BroadcastElementwiseUnaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}\n  ~BroadcastElementwiseUnaryImpl() override = default;\n\n  void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims, const void* src,\n              size_t num_dst_dims, const int64_t* dst_dims, void* dst) override {\n    CHECK_GT(num_src_dims, 0) << \"num_src_dims must be greater than 0\";\n    CHECK_GT(num_dst_dims, 0) << \"num_dst_dims must be greater than 0\";\n    int64_t src_strides[kMaxNumDims];\n    int64_t dst_strides[kMaxNumDims];\n    // init contiguous (row-major) strides from the dims\n    for (int i = num_src_dims - 1; i < kMaxNumDims; ++i) { src_strides[i] = 1; }\n    for (int i = num_src_dims - 2; i >= 0; --i) {\n      src_strides[i] = src_dims[i + 1] * src_strides[i + 1];\n    }\n\n    for (int i = num_dst_dims - 1; i < kMaxNumDims; ++i) { dst_strides[i] = 1; }\n    for (int i = num_dst_dims - 2; i >= 0; --i) {\n      dst_strides[i] = dst_dims[i + 1] * dst_strides[i + 1];\n    }\n    Launch(stream, num_src_dims, src_dims, src_strides, src, num_dst_dims, dst_dims, dst_strides,\n           dst);\n  }\n\n  void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims,\n              const int64_t* src_strides, const void* src_ptr, size_t num_dst_dims,\n              const int64_t* dst_dims, const int64_t* dst_strides, void* dst_ptr) override {\n    CHECK_GT(num_src_dims, 0) << \"num_src_dims must be greater than 0\";\n    CHECK_GT(num_dst_dims, 0) << \"num_dst_dims must be greater than 0\";\n    auto* cpu_stream = stream->As<CpuStream>();\n    Dst* dst = reinterpret_cast<Dst*>(dst_ptr);\n    const Src* src = reinterpret_cast<const Src*>(src_ptr);\n    size_t simplified_num_dims = 0;\n    int permutation_list[kMaxNumDims];\n    int64_t permutation_src_dims[kMaxNumDims];\n    int64_t simplified_src_dims[kMaxNumDims];\n    int64_t simplified_dst_dims[kMaxNumDims];\n    int64_t simplified_src_strides[kMaxNumDims];\n    int64_t simplified_dst_strides[kMaxNumDims];\n
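    // SimplifyBroadcastDims collapses adjacent dims where possible, so the dispatch below\n    // sees the smallest equivalent problem (scalar fill, 1-D fill, permute, or general).\n    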
SimplifyBroadcastDims<kMaxNumDims>(num_src_dims, src_dims, src_strides, num_dst_dims, dst_dims,\n                                       dst_strides, &simplified_num_dims, simplified_src_dims,\n                                       simplified_src_strides, simplified_dst_dims,\n                                       simplified_dst_strides);\n    bool permutable = InferPermutable<kMaxNumDims>(\n        simplified_num_dims, simplified_src_strides, simplified_dst_strides, simplified_src_dims,\n        simplified_dst_dims, permutation_list, permutation_src_dims, unary_op);\n    std::unique_ptr<Permute> permute =\n        NewPrimitive<PermuteFactory>(DeviceType::kCPU, simplified_num_dims);\n    CheckInplace(simplified_num_dims, simplified_src_dims, src, simplified_dst_dims, dst);\n    CheckInplace(simplified_num_dims, simplified_src_strides, src, simplified_dst_strides, dst);\n    if (simplified_num_dims == 1 && simplified_src_dims[0] == 1) {\n      const int64_t elem_cnt = simplified_dst_dims[0];\n      const int64_t dst_stride = simplified_dst_strides[0];\n      LaunchScalarFill<unary_op, Src, Dst>(cpu_stream, dst, src, elem_cnt, dst_stride, attr0,\n                                           attr1);\n    } else if (simplified_num_dims == 1) {\n      const int64_t elem_cnt = simplified_src_dims[0];\n      const int64_t src_stride = simplified_src_strides[0];\n      const int64_t dst_stride = simplified_dst_strides[0];\n      LaunchTensorFill<unary_op, Src, Dst>(cpu_stream, dst, src, elem_cnt, dst_stride, src_stride,\n                                           attr0, attr1);\n    } else if (permutable && src_type == dst_type && permute) {\n      permute->Launch(stream, dst_type, simplified_num_dims, permutation_src_dims, src_ptr,\n                      permutation_list, dst_ptr);\n    } else {\n      // fall back to normal cases\n      LaunchGeneral<unary_op, Src, Dst>(\n          cpu_stream, dst, src, simplified_num_dims, simplified_dst_dims, simplified_src_dims,\n          simplified_dst_strides, simplified_src_strides, attr0, attr1);\n    }\n  }\n\n protected:\n  Scalar attr0, attr1;\n};\n\ntemplate<UnaryOp unary_op, typename Src, DataType src_type, typename Dst, DataType dst_type>\nstd::unique_ptr<BroadcastElementwiseUnary> NewBroadcastElementwiseUnary(Scalar attr0,\n                                                                        Scalar attr1) {\n  return std::unique_ptr<BroadcastElementwiseUnary>(\n      new BroadcastElementwiseUnaryImpl<unary_op, Src, src_type, Dst, dst_type>(attr0, attr1));\n}\n\nclass BroadcastElementwiseUnaryFactoryImpl : public BroadcastElementwiseUnaryFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnaryFactoryImpl);\n  BroadcastElementwiseUnaryFactoryImpl() = default;\n  ~BroadcastElementwiseUnaryFactoryImpl() override = default;\n\n  std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp op, DataType src_type, DataType dst_type,\n                                                 size_t max_num_dims) override {\n    return New(op, src_type, dst_type, max_num_dims, Scalar(), Scalar());\n  }\n\n  std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp op, DataType src_type, DataType dst_type,\n                                                 size_t max_num_dims, Scalar attr0) override {\n    return New(op, src_type, dst_type, max_num_dims, attr0, Scalar());\n  }\n\n  std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp unary_op, DataType src_type,\n                                                 DataType dst_type, size_t max_num_dims,\n    
                                             Scalar attr0, Scalar attr1) override {\n    if (max_num_dims > kMaxNumDims) { return nullptr; }\n#define MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair)          \\\n  {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)),  \\\n   NewBroadcastElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(dtype_pair),                      \\\n                                OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_FIRST(dtype_pair), \\\n                                OF_PP_PAIR_SECOND(dtype_pair)>},\n\n#define MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY(unary_op, src_dtype_pair, dst_dtype_pair) \\\n  {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(src_dtype_pair),                              \\\n                   OF_PP_PAIR_SECOND(dst_dtype_pair)),                                       \\\n   NewBroadcastElementwiseUnary<                                                             \\\n       unary_op, OF_PP_PAIR_FIRST(src_dtype_pair), OF_PP_PAIR_SECOND(src_dtype_pair),        \\\n       OF_PP_PAIR_FIRST(dst_dtype_pair), OF_PP_PAIR_SECOND(dst_dtype_pair)>},\n\n    static const std::map<std::tuple<UnaryOp, DataType, DataType>,\n                          std::function<std::unique_ptr<BroadcastElementwiseUnary>(Scalar, Scalar)>>\n        new_broadcast_elementwise_unary_handle{\n            // For All Type OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY,\n                                             UNARY_IDENTITY_SEQ, CPU_PRIMITIVE_ALL_TYPE_SEQ)\n\n            // For Cast OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY, BROADCAST_ELEMENTWISE_CAST_OP_SEQ,\n                CPU_PRIMITIVE_CAST_REAL_TYPE_SEQ, CPU_PRIMITIVE_CAST_REAL_TYPE_SEQ)\n                OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                    MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY, BROADCAST_ELEMENTWISE_CAST_OP_SEQ,\n                    CPU_PRIMITIVE_COMPLEX_TYPE_SEQ, CPU_PRIMITIVE_COMPLEX_TYPE_SEQ)\n                    OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY,\n                                                     BROADCAST_ELEMENTWISE_CAST_OP_SEQ,\n                                                     CPU_PRIMITIVE_CAST_REAL_TYPE_SEQ,\n                                                     CPU_PRIMITIVE_COMPLEX_TYPE_SEQ)};\n\n#undef MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY\n#undef MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY\n\n    const auto iter =\n        new_broadcast_elementwise_unary_handle.find(std::make_tuple(unary_op, src_type, dst_type));\n    if (iter != new_broadcast_elementwise_unary_handle.end()) {\n      return iter->second(attr0, attr1);\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, BroadcastElementwiseUnaryFactory,\n                           BroadcastElementwiseUnaryFactoryImpl);\n\n}  // namespace\n}  // namespace broadcast_elementwise_unary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cpu/primitive/broadcast_matmul.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_matmul.h\"\n#include \"oneflow/core/ep/common/primitive/broadcast_matmul.h\"\n#include \"oneflow/core/common/blas.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace broadcast_matmul {\n\nnamespace internal {\n\nnamespace {\n\nconstexpr size_t kMaxNumDims = 8;\n\nCBLAS_TRANSPOSE GetCblasTranspose(BlasTransposeType transpose_type, DataType data_type) {\n  if (transpose_type == BlasTransposeType::N) {\n    return CblasNoTrans;\n  } else if (transpose_type == BlasTransposeType::T) {\n    return DType(data_type).is_complex() ? CblasConjTrans : CblasTrans;\n  } else {\n    UNIMPLEMENTED();\n    return CblasNoTrans;\n  }\n}\n\ntemplate<typename T, typename std::enable_if<\n                         !(std::is_same<T, std::complex<float>>::value\n                           || std::is_same<T, std::complex<double>>::value)>::type* = nullptr>\nvoid CblasMatmul(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, int m, int n, int k, T alpha,\n                 const T* a, const T* b, T beta, T* c) {\n  int lda = 0;\n  if (trans_a == CblasNoTrans) {\n    lda = k;\n  } else if (trans_a == CblasTrans || trans_a == CblasConjTrans) {\n    lda = m;\n  } else {\n    UNIMPLEMENTED();\n  }\n  int ldb = 0;\n  if (trans_b == CblasNoTrans) {\n    ldb = n;\n  } else if (trans_b == CblasTrans || trans_b == CblasConjTrans) {\n    ldb = k;\n  } else {\n    UNIMPLEMENTED();\n  }\n  const int ldc = n;\n  cblas_gemm<T>(CblasRowMajor, trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);\n}\n\ntemplate<typename T,\n         typename std::enable_if<std::is_same<T, std::complex<float>>::value\n                                 || std::is_same<T, std::complex<double>>::value>::type* = nullptr>\nvoid CblasMatmul(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, int m, int n, int k, T alpha,\n                 const T* a, const T* b, T beta, T* c) {\n  int lda = 0;\n  if (trans_a == CblasNoTrans) {\n    lda = k;\n  } else if (trans_a == CblasTrans || trans_a == CblasConjTrans) {\n    lda = m;\n  } else {\n    UNIMPLEMENTED();\n  }\n  int ldb = 0;\n  if (trans_b == CblasNoTrans) {\n    ldb = n;\n  } else if (trans_b == CblasTrans || trans_b == CblasConjTrans) {\n    ldb = k;\n  } else {\n    UNIMPLEMENTED();\n  }\n  const int ldc = n;\n  cblas_gemm<T>(CblasRowMajor, trans_a, trans_b, m, n, k, reinterpret_cast<const void*>(&alpha),\n                reinterpret_cast<const void*>(a), lda, reinterpret_cast<const void*>(b), ldb,\n                reinterpret_cast<const void*>(&beta), reinterpret_cast<void*>(c), ldc);\n}\n\ntemplate<typename T>\nvoid LaunchCblasBroadcastMatmul(Stream* /*stream*/, DataType data_type,\n                                BlasTransposeType transpose_a, BlasTransposeType transpose_b,\n                                int64_t num_batch_dims, const 
int64_t* broadcast_batch_dims,\n                                const int64_t* a_batch_dims, const int64_t* b_batch_dims,\n                                const int64_t* c_batch_dims, int64_t m, int64_t n, int64_t k,\n                                Scalar alpha, const void* a, const void* b, Scalar beta, void* c) {\n  const CBLAS_TRANSPOSE cblas_trans_a = GetCblasTranspose(transpose_a, data_type);\n  const CBLAS_TRANSPOSE cblas_trans_b = GetCblasTranspose(transpose_b, data_type);\n  const T alpha_value = alpha.Value<T>();\n  auto func = [&](const void* batch_a, const void* batch_b, void* batch_c, Scalar batch_beta) {\n    const T beta_value = batch_beta.Value<T>();\n    CblasMatmul<T>(cblas_trans_a, cblas_trans_b, m, n, k, alpha_value,\n                   static_cast<const T*>(batch_a), static_cast<const T*>(batch_b), beta_value,\n                   static_cast<T*>(batch_c));\n  };\n  ForEachMatmul<kMaxNumDims>(data_type, m, n, k, beta, num_batch_dims, broadcast_batch_dims,\n                             a_batch_dims, b_batch_dims, c_batch_dims, a, b, c, func);\n}\n\nvoid LaunchBroadcastMatmul(Stream* stream, DataType data_type, BlasTransposeType transpose_a,\n                           BlasTransposeType transpose_b, int64_t num_batch_dims,\n                           const int64_t* broadcast_batch_dims, const int64_t* a_batch_dims,\n                           const int64_t* b_batch_dims, const int64_t* c_batch_dims, int64_t m,\n                           int64_t n, int64_t k, Scalar alpha, const void* a, const void* b,\n                           Scalar beta, void* c) {\n  if (data_type == DataType::kFloat) {\n    LaunchCblasBroadcastMatmul<float>(stream, data_type, transpose_a, transpose_b, num_batch_dims,\n                                      broadcast_batch_dims, a_batch_dims, b_batch_dims,\n                                      c_batch_dims, m, n, k, alpha, a, b, beta, c);\n  } else if (data_type == DataType::kDouble) {\n    LaunchCblasBroadcastMatmul<double>(stream, data_type, transpose_a, transpose_b, num_batch_dims,\n                                       broadcast_batch_dims, a_batch_dims, b_batch_dims,\n                                       c_batch_dims, m, n, k, alpha, a, b, beta, c);\n  } else if (data_type == DataType::kComplex64) {\n    LaunchCblasBroadcastMatmul<std::complex<float>>(\n        stream, data_type, transpose_a, transpose_b, num_batch_dims, broadcast_batch_dims,\n        a_batch_dims, b_batch_dims, c_batch_dims, m, n, k, alpha, a, b, beta, c);\n  } else if (data_type == DataType::kComplex128) {\n    LaunchCblasBroadcastMatmul<std::complex<double>>(\n        stream, data_type, transpose_a, transpose_b, num_batch_dims, broadcast_batch_dims,\n        a_batch_dims, b_batch_dims, c_batch_dims, m, n, k, alpha, a, b, beta, c);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nclass BroadcastMatmulFactoryImpl : public BroadcastMatmulFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmulFactoryImpl);\n  BroadcastMatmulFactoryImpl() = default;\n  ~BroadcastMatmulFactoryImpl() override = default;\n\n  std::unique_ptr<BroadcastMatmul> New(DataType data_type, BlasTransposeType transpose_a,\n                                       BlasTransposeType transpose_b,\n                                       size_t max_num_dims) override {\n    if (max_num_dims > kMaxNumDims) { return nullptr; }\n    if (data_type == DataType::kFloat || data_type == DataType::kDouble\n        || data_type == DataType::kComplex64 || data_type == DataType::kComplex128) {\n      return 
std::make_unique<BroadcastMatmulImpl<kMaxNumDims>>(data_type, transpose_a,\n                                                                transpose_b);\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, BroadcastMatmulFactory, BroadcastMatmulFactoryImpl);\n\n}  // namespace\n\n}  // namespace internal\n\n}  // namespace broadcast_matmul\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cpu/primitive/cast.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#include \"oneflow/core/ep/cpu/primitive/type_seq.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\ntemplate<typename From, typename To, typename = void>\nstruct CpuCastFunctor {\n  static void Call(const From* from, To* to, size_t count) {\n    for (size_t i = 0; i < count; ++i) { to[i] = static_cast<To>(from[i]); }\n  }\n};\n\ntemplate<typename To>\nstruct CpuCastFunctor<bfloat16, To,\n                      typename std::enable_if<!(std::is_same<To, bfloat16>::value)>::type> {\n  static void Call(const bfloat16* from, To* to, size_t count) {\n    for (size_t i = 0; i < count; ++i) { to[i] = static_cast<To>(static_cast<float>(from[i])); }\n  }\n};\n\ntemplate<typename From>\nstruct CpuCastFunctor<From, bfloat16,\n                      typename std::enable_if<!(std::is_same<From, bfloat16>::value)>::type> {\n  static void Call(const From* from, bfloat16* to, size_t count) {\n    for (size_t i = 0; i < count; ++i) { to[i] = bfloat16(static_cast<float>(from[i])); }\n  }\n};\n\ntemplate<typename From, typename To>\nclass CastImpl : public Cast {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CastImpl);\n  CastImpl() = default;\n  ~CastImpl() override = default;\n\n  void Launch(Stream* stream, const void* from, void* to, size_t count) override {\n    CpuCastFunctor<From, To>::Call(reinterpret_cast<const From*>(from), reinterpret_cast<To*>(to),\n                                   count);\n  }\n};\n\ntemplate<typename From, typename To>\nstd::unique_ptr<Cast> NewCast() {\n  return std::unique_ptr<Cast>(new CastImpl<From, To>());\n}\n\n#define CPU_PRIMITIVE_CAST_TYPE_SEQ \\\n  CPU_PRIMITIVE_BOOL_TYPE_SEQ       \\\n  CPU_PRIMITIVE_CHAR_TYPE_SEQ       \\\n  CPU_PRIMITIVE_INT8_TYPE_SEQ       \\\n  CPU_PRIMITIVE_UINT8_TYPE_SEQ      \\\n  CPU_PRIMITIVE_INT32_TYPE_SEQ      \\\n  CPU_PRIMITIVE_UINT32_TYPE_SEQ     \\\n  CPU_PRIMITIVE_INT64_TYPE_SEQ      \\\n  CPU_PRIMITIVE_UINT64_TYPE_SEQ     \\\n  CPU_PRIMITIVE_FLOAT_TYPE_SEQ      \\\n  CPU_PRIMITIVE_DOUBLE_TYPE_SEQ     \\\n  CPU_PRIMITIVE_FLOAT16_TYPE_SEQ    \\\n  CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ\n\nclass CastFactoryImpl : public CastFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CastFactoryImpl);\n  CastFactoryImpl() = default;\n  ~CastFactoryImpl() override = default;\n\n  std::unique_ptr<Cast> New(DataType from, DataType to) override {\n#define MAKE_NEW_CAST_ENTRY(from_pair, to_pair)                              \\\n  {std::make_pair(OF_PP_PAIR_SECOND(from_pair), OF_PP_PAIR_SECOND(to_pair)), \\\n   NewCast<OF_PP_PAIR_FIRST(from_pair), OF_PP_PAIR_FIRST(to_pair)>},\n\n    static const std::map<std::pair<DataType, DataType>, std::function<std::unique_ptr<Cast>()>>\n        new_cast_handle{OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n            MAKE_NEW_CAST_ENTRY, CPU_PRIMITIVE_CAST_TYPE_SEQ, CPU_PRIMITIVE_CAST_TYPE_SEQ)};\n\n#undef MAKE_NEW_CAST_ENTRY\n\n    const 
auto it = new_cast_handle.find(std::make_pair(from, to));\n    if (it != new_cast_handle.end()) {\n      return it->second();\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, CastFactory, CastFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cpu/primitive/constant_pad.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/constant_pad.h\"\n#include \"oneflow/core/ep/common/primitive/constant_pad.h\"\n#include \"oneflow/core/ep/cpu/primitive/type_seq.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace {\n\ntemplate<size_t num_dims, typename IndexType, typename StorageType>\nvoid ConstantPadKernel(ConstantPadParams<num_dims, IndexType> params, StorageType packed_pad_val) {\n  const StorageType* src = reinterpret_cast<const StorageType*>(params.src);\n  StorageType* dst = reinterpret_cast<StorageType*>(params.dst);\n  IndexType src_index[num_dims];\n  IndexType dst_index[num_dims];\n  for (IndexType linear_index = 0; linear_index < params.elem_cnt; ++linear_index) {\n    params.dst_index_helper.OffsetToNdIndex(linear_index, dst_index);\n    bool if_pad = false;\n    for (int i = 0; i < num_dims; i++) {\n      if (dst_index[i] >= params.valid_start[i] && dst_index[i] < params.valid_end[i]) {\n        src_index[i] = dst_index[i] - params.valid_start[i];\n      } else {\n        if_pad = true;\n        break;\n      }\n    }\n    StorageType dst_val = packed_pad_val;\n    if (!if_pad) {\n      const IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);\n      dst_val = src[src_offset];\n    }\n    dst[linear_index] = dst_val;\n  }\n}\n\ntemplate<>\nfloat16 GetValue<float16>(Scalar value) {\n  return static_cast<float16>(GetValue<float>(value));\n}\n\ntemplate<>\nbfloat16 GetValue<bfloat16>(Scalar value) {\n  return static_cast<bfloat16>(GetValue<float>(value));\n}\n\ntemplate<size_t num_dims, typename IndexType, typename StorageType>\nvoid LaunchKernel(ConstantPadParams<num_dims, IndexType> params, StorageType packed_pad_val) {\n  ConstantPadKernel<num_dims, IndexType, StorageType>(params, packed_pad_val);\n}\n\ntemplate<size_t num_dims, typename IndexType, typename StorageType>\nvoid LaunchKernel(void* dst, const int64_t* dst_dims, const void* src, const int64_t* src_dims,\n                  const int64_t* padding_before, const int64_t* padding_after,\n                  StorageType packed_pad_val, size_t elem_cnt) {\n  ConstantPadParams<num_dims, IndexType> params;\n  params.dst_index_helper = OffsetToIndexCalculator<IndexType, num_dims>(dst_dims);\n  params.src_index_helper = NdIndexOffsetHelper<IndexType, num_dims>(src_dims);\n  params.dst = dst;\n  params.src = src;\n  for (int i = 0; i < num_dims; i++) {\n    params.valid_start[i] = padding_before[i];\n    params.valid_end[i] = dst_dims[i] - padding_after[i];\n  }\n  params.elem_cnt = elem_cnt;\n  LaunchKernel<num_dims, IndexType, StorageType>(params, packed_pad_val);\n}\n\ntemplate<size_t num_dims, typename StorageType>\nvoid DispatchIndexType(void* dst, const int64_t* dst_dims, const void* src, const int64_t* src_dims,\n                       const int64_t* padding_before, const int64_t* padding_after,\n                       StorageType 
packed_pad_val, size_t elem_cnt) {\n  if (elem_cnt < GetMaxVal<int32_t>()) {\n    LaunchKernel<num_dims, int32_t, StorageType>(dst, dst_dims, src, src_dims, padding_before,\n                                                 padding_after, packed_pad_val, elem_cnt);\n  } else {\n    LaunchKernel<num_dims, int64_t, StorageType>(dst, dst_dims, src, src_dims, padding_before,\n                                                 padding_after, packed_pad_val, elem_cnt);\n  }\n}\n\ntemplate<size_t num_dims, typename T>\nvoid DispatchPackSize(void* dst, int64_t* dst_dims, const void* src, int64_t* src_dims,\n                      int64_t* padding_before, int64_t* padding_after, T pad_val) {\n  constexpr int32_t max_packsize = GetMaxPackSize<T>();\n  size_t launch_pack_size = GetLaunchPackSize<max_packsize>(num_dims, dst, dst_dims, src, src_dims,\n                                                            padding_before, padding_after);\n\n  dst_dims[num_dims - 1] /= launch_pack_size;\n  src_dims[num_dims - 1] /= launch_pack_size;\n  padding_before[num_dims - 1] /= launch_pack_size;\n  padding_after[num_dims - 1] /= launch_pack_size;\n\n  size_t elem_cnt = 1;\n  for (int i = 0; i < num_dims; i++) { elem_cnt *= dst_dims[i]; }\n\n  if (launch_pack_size == 1) {\n    Pack<T, 1> packed_pad_val(pad_val);\n    DispatchIndexType<num_dims, PackType<T, 1>>(dst, dst_dims, src, src_dims, padding_before,\n                                                padding_after, packed_pad_val.storage, elem_cnt);\n  } else if (launch_pack_size == 2) {\n    Pack<T, 2> packed_pad_val(pad_val);\n    DispatchIndexType<num_dims, PackType<T, 2>>(dst, dst_dims, src, src_dims, padding_before,\n                                                padding_after, packed_pad_val.storage, elem_cnt);\n  } else if (launch_pack_size == 4) {\n    Pack<T, 4> packed_pad_val(pad_val);\n    DispatchIndexType<num_dims, PackType<T, 4>>(dst, dst_dims, src, src_dims, padding_before,\n                                                padding_after, packed_pad_val.storage, elem_cnt);\n  } else if (launch_pack_size == 8) {\n    Pack<T, 8> packed_pad_val(pad_val);\n    DispatchIndexType<num_dims, PackType<T, 8>>(dst, dst_dims, src, src_dims, padding_before,\n                                                padding_after, packed_pad_val.storage, elem_cnt);\n  } else if (launch_pack_size == 16) {\n    Pack<T, 16> packed_pad_val(pad_val);\n    DispatchIndexType<num_dims, PackType<T, 16>>(dst, dst_dims, src, src_dims, padding_before,\n                                                 padding_after, packed_pad_val.storage, elem_cnt);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T>\nvoid LaunchWithSimplified(size_t num_dims, void* dst, int64_t* dst_dims, const void* src,\n                          int64_t* src_dims, int64_t* padding_before, int64_t* padding_after,\n                          T pad_val) {\n  void (*func)(void* /*dst*/, int64_t* /*dst_dims*/, const void* /*src*/, int64_t* /*src_dims*/,\n               int64_t* /*padding_before*/, int64_t* /*padding_after*/, T) = nullptr;\n  if (num_dims == 1) {\n    func = DispatchPackSize<1, T>;\n  } else if (num_dims == 2) {\n    func = DispatchPackSize<2, T>;\n  } else if (num_dims == 3) {\n    func = DispatchPackSize<3, T>;\n  } else if (num_dims == 4) {\n    func = DispatchPackSize<4, T>;\n  } else if (num_dims == 5) {\n    func = DispatchPackSize<5, T>;\n  } else if (num_dims == 6) {\n    func = DispatchPackSize<6, T>;\n  } else if (num_dims == 7) {\n    func = DispatchPackSize<7, T>;\n  } else if 
(num_dims == 8) {\n    func = DispatchPackSize<8, T>;\n  } else {\n    UNIMPLEMENTED();\n  }\n  func(dst, dst_dims, src, src_dims, padding_before, padding_after, pad_val);\n}\n\ntemplate<typename T>\nvoid SimplifyThenLaunch(size_t num_dims, const int64_t* src_dims, const void* src,\n                        const int64_t* padding_before, const int64_t* padding_after, T pad_val,\n                        void* dst) {\n  CHECK_GT(num_dims, 0) << \"num_dims must be greater than 0\";\n  CHECK_LE(num_dims, kMaxNumDims);\n  int64_t simplified_dst_dims[kMaxNumDims];\n  int64_t simplified_src_dims[kMaxNumDims];\n  int64_t simplified_padding_before[kMaxNumDims];\n  int64_t simplified_padding_after[kMaxNumDims];\n  size_t simplified_num_dims = 1;\n  SimplifyPadDims(num_dims, src_dims, padding_before, padding_after, &simplified_num_dims,\n                  simplified_dst_dims, simplified_src_dims, simplified_padding_before,\n                  simplified_padding_after);\n  LaunchWithSimplified<T>(simplified_num_dims, dst, simplified_dst_dims, src, simplified_src_dims,\n                          simplified_padding_before, simplified_padding_after, pad_val);\n}\n\ntemplate<typename T>\nclass ConstantPadImpl : public ConstantPad {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ConstantPadImpl);\n  ConstantPadImpl() = default;\n  ~ConstantPadImpl() override = default;\n\n  void Launch(Stream* stream, size_t num_dims, const int64_t* src_dims, const void* src,\n              const int64_t* padding_before, const int64_t* padding_after, Scalar pad_val,\n              void* dst) override {\n    SimplifyThenLaunch<T>(num_dims, src_dims, src, padding_before, padding_after,\n                          GetValue<T>(pad_val), dst);\n  }\n};\n\ntemplate<typename T>\nstd::unique_ptr<ConstantPad> NewConstantPad() {\n  return std::unique_ptr<ConstantPad>(new ConstantPadImpl<T>());\n}\n\nclass ConstantPadFactoryImpl : public ConstantPadFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ConstantPadFactoryImpl);\n  ConstantPadFactoryImpl() = default;\n  ~ConstantPadFactoryImpl() override = default;\n\n  std::unique_ptr<ConstantPad> New(DataType data_type) override {\n#define MAKE_NEW_CONSTANT_PAD_ENTRY(type_cpp, type_proto) {type_proto, NewConstantPad<type_cpp>},\n\n    static const std::map<DataType, std::function<std::unique_ptr<ConstantPad>()>>\n        new_constant_pad_handle{\n            OF_PP_FOR_EACH_TUPLE(MAKE_NEW_CONSTANT_PAD_ENTRY, CPU_PRIMITIVE_ALL_TYPE_SEQ)};\n\n#undef MAKE_NEW_CONSTANT_PAD_ENTRY\n\n    const auto it = new_constant_pad_handle.find(data_type);\n    if (it != new_constant_pad_handle.end()) {\n      return it->second();\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, ConstantPadFactory, ConstantPadFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cpu/primitive/copy_nd.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/copy_nd.h\"\n#include \"oneflow/core/ep/common/primitive/copy_nd.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\ntemplate<size_t num_dims, size_t movement_size, typename IndexType>\nvoid CopyNdKernel(CopyNdKernelParams<num_dims, IndexType> params) {\n  using T = typename std::aligned_storage<movement_size, movement_size>::type;\n  const T* src = reinterpret_cast<const T*>(params.src);\n  T* dst = reinterpret_cast<T*>(params.dst);\n  for (IndexType i = 0; i < params.count; ++i) {\n    IndexType copy_index[num_dims];\n    IndexType src_index[num_dims];\n    IndexType dst_index[num_dims];\n    params.copy_index_helper.OffsetToNdIndex(i, copy_index);\n    for (size_t j = 0; j < num_dims; ++j) {\n      src_index[j] = params.src_pos[j] + copy_index[j];\n      dst_index[j] = params.dst_pos[j] + copy_index[j];\n    }\n    const IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);\n    const IndexType dst_offset = params.dst_index_helper.NdIndexToOffset(dst_index);\n    dst[dst_offset] = src[src_offset];\n  }\n}\n\ntemplate<size_t num_dims, size_t movement_size, typename IndexType>\nvoid LaunchKernel(Stream* stream, CopyNdKernelParams<num_dims, IndexType> params) {\n  CopyNdKernel<num_dims, movement_size, IndexType>(params);\n}\n\nclass CopyNdImpl : public CopyNd {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CopyNdImpl);\n  CopyNdImpl() = default;\n  ~CopyNdImpl() = default;\n\n  void Launch(Stream* stream, DataType data_type, size_t num_dims, void* dst,\n              const int64_t* dst_dims, const int64_t* dst_pos, const void* src,\n              const int64_t* src_dims, const int64_t* src_pos,\n              const int64_t* extent) const override {\n    SimplifyThenLaunch(stream, data_type, num_dims, dst, dst_dims, dst_pos, src, src_dims, src_pos,\n                       extent);\n  }\n};\n\nclass CopyNdFactoryImpl : public CopyNdFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CopyNdFactoryImpl);\n  CopyNdFactoryImpl() = default;\n  ~CopyNdFactoryImpl() override = default;\n\n  std::unique_ptr<CopyNd> New(size_t max_num_dims) override {\n    if (max_num_dims <= kMaxNumDims) {\n      return std::unique_ptr<CopyNd>(new CopyNdImpl());\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, CopyNdFactory, CopyNdFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cpu/primitive/elementwise_unary.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/common/primitive/elementwise_unary.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/ep/cpu/primitive/unary_functor.h\"\n#include \"oneflow/core/ep/cpu/cpu_stream.h\"\n#include \"oneflow/core/ep/cpu/cpu_device.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\ntemplate<UnaryOp unary_op, typename Src, typename Dst>\nclass ElementwiseUnaryImpl : public ElementwiseUnary {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryImpl);\n  ElementwiseUnaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}\n  ~ElementwiseUnaryImpl() override = default;\n\n  void Launch(Stream* stream, const void* src_ptr, void* dst_ptr, size_t count) override {\n    CpuStream* cpu_stream = stream->As<CpuStream>();\n\n    Dst* dst = reinterpret_cast<Dst*>(dst_ptr);\n    const Src* src = reinterpret_cast<const Src*>(src_ptr);\n    auto functor = UnaryFunctor<DeviceType::kCPU, unary_op, Dst, Src>(attr0, attr1);\n    cpu_stream->ParallelFor(0, count, [functor, src, dst](int64_t begin, int64_t end) {\n      for (int64_t i = begin; i < end; i++) { dst[i] = functor(src[i]); }\n    });\n  }\n\n protected:\n  Scalar attr0, attr1;\n};\n\ntemplate<UnaryOp unary_op, typename Src, typename Dst>\nstd::unique_ptr<ElementwiseUnary> NewElementwiseUnary(Scalar attr0, Scalar attr1) {\n  return std::unique_ptr<ElementwiseUnary>(\n      new ElementwiseUnaryImpl<unary_op, Src, Dst>(attr0, attr1));\n}\n\nclass ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryFactoryImpl);\n  ElementwiseUnaryFactoryImpl() = default;\n  ~ElementwiseUnaryFactoryImpl() override = default;\n\n  std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type,\n                                        DataType dst_dtype) override {\n    return New(unary_op, src_type, dst_dtype, Scalar(), Scalar());\n  }\n\n  std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type, DataType dst_dtype,\n                                        Scalar attr0) override {\n    return New(unary_op, src_type, dst_dtype, attr0, Scalar());\n  }\n\n  std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type, DataType dst_dtype,\n                                        Scalar attr0, Scalar attr1) override {\n#define MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair)                   \\\n  {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)), \\\n   NewElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(dtype_pair), OF_PP_PAIR_FIRST(dtype_pair)>},\n\n#define MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, src_type_pair, dst_dtype_pair)  \\\n  {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(src_type_pair), OF_PP_PAIR_SECOND(dst_dtype_pair)), \\\n   NewElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(src_type_pair),               
                   \\\n                       OF_PP_PAIR_FIRST(dst_dtype_pair)>},\n\n    static const std::map<std::tuple<UnaryOp, DataType, DataType>,\n                          std::function<std::unique_ptr<ElementwiseUnary>(Scalar, Scalar)>>\n        new_elementwise_unary_handle{\n            // For All Type OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,\n                                             UNARY_MATH_OP_SEQ, CPU_PRIMITIVE_NATIVE_TYPE_SEQ)\n            // For Float Type OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_FLOATING_MATH_OP_SEQ,\n                CPU_PRIMITIVE_FLOATING_TYPE_SEQ CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ)\n\n            // For Complex Type OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,\n                                             UNARY_COMPLEX_C2C_OP_SEQ,\n                                             CPU_PRIMITIVE_COMPLEX_TYPE_SEQ)\n\n                OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                    MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_C2R_OP_SEQ,\n                    CPU_PRIMITIVE_COMPLEX_TYPE_SEQ, CPU_PRIMITIVE_FLOATING_TYPE_SEQ)\n\n                    OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                        MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_R2C_OP_SEQ,\n                        CPU_PRIMITIVE_FLOATING_TYPE_SEQ, CPU_PRIMITIVE_COMPLEX_TYPE_SEQ)\n\n            // For Int Type OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,\n                                             UNARY_INT_MATH_OP_SEQ, CPU_PRIMITIVE_INT_TYPE_SEQ)\n\n            // For Bitwise OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,\n                                             UNARY_BITWISE_OP_SEQ,\n                                             CPU_PRIMITIVE_INT_TYPE_SEQ CPU_PRIMITIVE_BOOL_TYPE_SEQ)\n\n            // For Utils OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,\n                                             UNARY_UTILS_OP_SEQ, UTIL_OPS_DATA_TYPE_SEQ,\n                                             CPU_PRIMITIVE_BOOL_TYPE_SEQ)\n\n            // For Logical OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,\n                                             UNARY_LOGICAL_OP_SEQ, CPU_PRIMITIVE_NATIVE_TYPE_SEQ,\n                                             CPU_PRIMITIVE_BOOL_TYPE_SEQ)};\n\n#undef MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY\n\n#undef MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY\n\n    const auto it =\n        new_elementwise_unary_handle.find(std::make_tuple(unary_op, src_type, dst_dtype));\n    if (it != new_elementwise_unary_handle.end()) {\n      return it->second(attr0, attr1);\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, ElementwiseUnaryFactory, ElementwiseUnaryFactoryImpl);\n\n}  // namespace\n}  // namespace primitive\n}  // namespace ep\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cpu/primitive/fill.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/ep/cpu/primitive/type_seq.h\"\n#include \"oneflow/core/common/scalar.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\ntemplate<typename T>\nT GetValue(Scalar value) {\n  return value.Value<T>();\n}\n\ntemplate<>\nfloat16 GetValue<float16>(Scalar value) {\n  return static_cast<float16>(GetValue<float>(value));\n}\n\ntemplate<>\nbfloat16 GetValue<bfloat16>(Scalar value) {\n  return static_cast<bfloat16>(GetValue<float>(value));\n}\n\ntemplate<typename T>\nclass FillImpl : public Fill {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(FillImpl);\n  FillImpl() = default;\n  ~FillImpl() override = default;\n\n  void Launch(Stream* stream, void* dst, Scalar value, size_t count) override {\n    std::fill_n(reinterpret_cast<T*>(dst), count, GetValue<T>(value));\n  }\n};\n\ntemplate<typename T>\nstd::unique_ptr<Fill> NewFill() {\n  return std::unique_ptr<Fill>(new FillImpl<T>());\n}\n\nclass FillFactoryImpl : public FillFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(FillFactoryImpl);\n  FillFactoryImpl() = default;\n  ~FillFactoryImpl() override = default;\n\n  std::unique_ptr<Fill> New(DataType data_type) override {\n#define MAKE_NEW_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewFill<type_cpp>},\n\n    static const std::map<DataType, std::function<std::unique_ptr<Fill>()>> new_fill_handle{\n        OF_PP_FOR_EACH_TUPLE(MAKE_NEW_FILL_ENTRY,\n                             CPU_PRIMITIVE_ALL_TYPE_SEQ CPU_PRIMITIVE_INT16_TYPE_SEQ)};\n#undef MAKE_NEW_ADD_ENTRY\n    const auto it = new_fill_handle.find(data_type);\n    if (it != new_fill_handle.end()) {\n      return it->second();\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, FillFactory, FillFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cpu/primitive/memcpy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\nclass MemcpyImpl : public Memcpy {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MemcpyImpl);\n  MemcpyImpl() = default;\n  ~MemcpyImpl() = default;\n\n  void Launch(Stream* stream, void* dst, const void* src, size_t count) {\n    if (dst == src) { return; }\n    std::memcpy(dst, src, count);\n  }\n};\n\nclass MemcpyFactoryImpl : public MemcpyFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MemcpyFactoryImpl);\n  MemcpyFactoryImpl() = default;\n  ~MemcpyFactoryImpl() override = default;\n\n  std::unique_ptr<Memcpy> New(MemcpyKind kind) override { return std::make_unique<MemcpyImpl>(); }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, MemcpyFactory, MemcpyFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cpu/primitive/memset.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\nclass MemsetImpl : public Memset {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MemsetImpl);\n  MemsetImpl() = default;\n  ~MemsetImpl() = default;\n\n  void Launch(Stream* stream, void* ptr, int value, size_t count) {\n    std::memset(ptr, value, count);\n  }\n};\n\nclass MemsetFactoryImpl : public MemsetFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MemsetFactoryImpl);\n  MemsetFactoryImpl() = default;\n  ~MemsetFactoryImpl() override = default;\n\n  std::unique_ptr<Memset> New() override { return std::make_unique<MemsetImpl>(); }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, MemsetFactory, MemsetFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cpu/primitive/permute.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"oneflow/core/ep/common/primitive/permute_impl.h\"\n#include \"oneflow/core/ep/cpu/cpu_stream.h\"\n#include \"oneflow/core/ep/cpu/cpu_device.h\"\n#include \"oneflow/core/ep/common/onednn.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace permute {\n\nnamespace internal {\n\nnamespace {\n\ntemplate<size_t num_dims, size_t movement_size, typename IndexType>\nvoid PermuteKernel(PermuteKernelParams<num_dims, IndexType> params) {\n  using T = typename std::aligned_storage<movement_size, movement_size>::type;\n  const T* src = reinterpret_cast<const T*>(params.src);\n  T* dst = reinterpret_cast<T*>(params.dst);\n  for (IndexType i = 0; i < params.count; ++i) {\n    IndexType src_index[num_dims];\n    IndexType dst_index[num_dims];\n    params.dst_index_helper.OffsetToNdIndex(i, dst_index);\n    for (size_t dim = 0; dim < num_dims; ++dim) {\n      src_index[params.permutation[dim]] = dst_index[dim];\n    }\n    IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);\n    dst[i] = src[src_offset];\n  }\n}\n\ntemplate<size_t num_dims, size_t movement_size, typename IndexType>\nvoid LaunchKernel(Stream* stream, const int64_t* src_dims, const void* src, const int* permutation,\n                  void* dst, size_t count) {\n  PermuteKernelParams<num_dims, IndexType> params =\n      MakePermuteParams<num_dims, IndexType>(src_dims, src, permutation, dst, count);\n  PermuteKernel<num_dims, movement_size, IndexType>(params);\n}\nclass PermuteImpl : public Permute {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PermuteImpl);\n  PermuteImpl() = default;\n  ~PermuteImpl() override = default;\n\n  using Permute::Launch;\n  void Launch(Stream* stream, DataType data_type, size_t num_dims, const int64_t* src_dims,\n              const void* src, const int* permutation, void* dst) override {\n    SimplifyThenLaunch(stream, data_type, num_dims, src_dims, src, permutation, dst);\n  }\n};\n\n#ifdef WITH_ONEDNN\nconstexpr size_t kMaxOneDnnMovementSize = 4;\nconstexpr size_t kMaxOneDnnMapSize = 5;\nuint32_t OnednnDatatypeTagMap[kMaxOneDnnMapSize] = {0, dnnl_u8, dnnl_f16, 0, dnnl_s32};\nclass OneDnnPermuteImpl : public Permute {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(OneDnnPermuteImpl);\n  OneDnnPermuteImpl() = default;\n  ~OneDnnPermuteImpl() override = default;\n\n  using Permute::Launch;\n  void Launch(Stream* stream, DataType data_type, size_t num_dims, const int64_t* src_dims,\n              const void* src, const int* permutation, void* dst) override {\n    CHECK_LE(num_dims, kMaxNumDims);\n    CHECK_GT(num_dims, 0);\n\n    stream->As<CpuStream>()->onednn_executor()->Launch([&](dnnl::engine* onednn_engine,\n                                                           dnnl::stream* onednn_stream) {\n      size_t onednn_num_dims = num_dims;\n      dnnl::memory::dims onednn_dims(kMaxNumDims + 1, 
0);\n      dnnl::memory::dims onednn_permute(kMaxNumDims + 1, 0);\n      dnnl::memory::dims src_stride(kMaxNumDims + 1, 0);\n      dnnl::memory::dims dst_stride(kMaxNumDims + 1, 0);\n      for (int64_t dim = onednn_num_dims - 1; dim >= 0; dim--) {\n        onednn_dims[dim] = src_dims[dim];\n        onednn_permute[dim] = permutation[dim];\n      }\n      size_t movement_size = GetSizeOfDataType(data_type);\n      if (movement_size > kMaxOneDnnMovementSize) {\n        onednn_dims[onednn_num_dims] = movement_size / kMaxOneDnnMovementSize;\n        onednn_permute[onednn_num_dims] = onednn_num_dims;\n        onednn_num_dims = onednn_num_dims + 1;\n        movement_size = kMaxOneDnnMovementSize;\n      }\n      onednn_dims.resize(onednn_num_dims);\n\n      src_stride[onednn_num_dims - 1] = 1;\n      dst_stride[onednn_permute[onednn_num_dims - 1]] = 1;\n      for (int64_t i = onednn_num_dims - 2; i >= 0; i--) {\n        src_stride[i] = src_stride[i + 1] * onednn_dims[i + 1];\n        dst_stride[onednn_permute[i]] =\n            dst_stride[onednn_permute[i + 1]] * onednn_dims[onednn_permute[i + 1]];\n      }\n\n      dnnl::memory::data_type onednn_data_type =\n          static_cast<dnnl::memory::data_type>(OnednnDatatypeTagMap[movement_size]);\n      // The reorder primitive requires the source and destination tensors to have the same shape.\n      // Implicit broadcasting is not supported.\n      auto src_mem_desc = dnnl::memory::desc(onednn_dims, onednn_data_type, src_stride);\n      auto dst_mem_desc = dnnl::memory::desc(onednn_dims, onednn_data_type, dst_stride);\n      auto src_mem = dnnl::memory(src_mem_desc, *onednn_engine, const_cast<void*>(src));\n      auto dst_mem = dnnl::memory(dst_mem_desc, *onednn_engine, dst);\n      auto reorder_primitive_desc =\n          dnnl::reorder::primitive_desc(*onednn_engine, src_mem_desc, *onednn_engine, dst_mem_desc);\n      auto reorder_primitive = dnnl::reorder(reorder_primitive_desc);\n\n      reorder_primitive.execute(*onednn_stream, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}});\n    });\n  }\n};\n\n#endif  // WITH_ONEDNN\n\nclass PermuteFactoryImpl : public PermuteFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PermuteFactoryImpl);\n  PermuteFactoryImpl() = default;\n  ~PermuteFactoryImpl() override = default;\n\n  std::unique_ptr<Permute> New(size_t max_num_dims) override {\n    if (max_num_dims <= kMaxNumDims) {\n#ifdef WITH_ONEDNN\n      if (OneDnnIsEnabled()) { return std::unique_ptr<Permute>(new OneDnnPermuteImpl()); }\n#endif\n      return std::unique_ptr<Permute>(new PermuteImpl());\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, PermuteFactory, PermuteFactoryImpl);\n\n}  // namespace\n\n}  // namespace internal\n\n}  // namespace permute\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cpu/primitive/softmax.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/softmax.h\"\n#include \"oneflow/core/ep/include/primitive/log_softmax.h\"\n#include \"oneflow/core/ep/cpu/primitive/type_seq.h\"\n#include \"oneflow/core/ep/cpu/cpu_stream.h\"\n#include \"oneflow/core/ep/cpu/cpu_device.h\"\n#include \"oneflow/core/ep/common/primitive/util.h\"\n#include \"oneflow/core/ep/common/onednn.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\nenum class Algorithm {\n  kSoftmax,\n  kLogSoftmax,\n};\n\ntemplate<Algorithm algorithm, typename T>\nvoid SoftmaxCpu(size_t rows, size_t cols, const T* x, T* y) {\n  for (size_t i = 0; i < rows; ++i) {\n    size_t row_offset = i * cols;\n    const T* row_x = x + row_offset;\n    T* row_y = y + row_offset;\n    const T row_max = *std::max_element(row_x, row_x + cols);\n    T row_sum = 0;\n    for (size_t j = 0; j < cols; ++j) {\n      if (algorithm == Algorithm::kSoftmax) {\n        T exp_x = std::exp(row_x[j] - row_max);\n        row_sum += exp_x;\n        row_y[j] = exp_x;\n      } else if (algorithm == Algorithm::kLogSoftmax) {\n        row_y[j] = row_x[j] - row_max;\n        row_sum += std::exp(row_y[j]);\n      } else {\n        UNIMPLEMENTED();\n      }\n    }\n    for (size_t j = 0; j < cols; ++j) {\n      if (algorithm == Algorithm::kSoftmax) {\n        row_y[j] /= row_sum;\n      } else if (algorithm == Algorithm::kLogSoftmax) {\n        row_y[j] -= std::log(row_sum);\n      } else {\n        UNIMPLEMENTED();\n      }\n    }\n  }\n}\n\ntemplate<typename SoftmaxBase, Algorithm algorithm, typename T>\nclass SoftmaxImpl : public SoftmaxBase {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SoftmaxImpl);\n  SoftmaxImpl() = default;\n  ~SoftmaxImpl() override = default;\n\n  void Launch(Stream* stream, size_t rows, size_t cols, const void* x, void* y) override {\n    SoftmaxCpu<algorithm, T>(rows, cols, reinterpret_cast<const T*>(x), reinterpret_cast<T*>(y));\n  }\n};\n\n#ifdef WITH_ONEDNN\n\ntemplate<class OneDnnSoftmax, dnnl::memory::data_type data_type>\nvoid SoftmaxOneDnn(Stream* stream, size_t rows, size_t cols, const void* x, void* y) {\n  stream->As<CpuStream>()->onednn_executor()->Launch(\n      [&](dnnl::engine* onednn_engine, dnnl::stream* onednn_stream) {\n        dnnl::memory::dims src_dims = {static_cast<dnnl::memory::dim>(rows),\n                                       static_cast<dnnl::memory::dim>(cols)};\n\n        auto src_md = dnnl::memory::desc(src_dims, data_type, dnnl::memory::format_tag::nc);\n        auto src_mem = dnnl::memory(src_md, *onednn_engine, const_cast<void*>(x));\n        auto dst_mem = dnnl::memory(src_md, *onednn_engine, y);\n        auto softmax_d = typename OneDnnSoftmax::desc(dnnl::prop_kind::forward, src_md, 1);\n        auto softmax_pd = typename OneDnnSoftmax::primitive_desc(softmax_d, *onednn_engine);\n        auto softmax_prim = OneDnnSoftmax(softmax_pd);\n\n        
softmax_prim.execute(*onednn_stream, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}});\n      });\n}\n\ntemplate<typename SoftmaxBase, Algorithm algorithm, dnnl::memory::data_type data_type>\nclass OneDnnSoftmaxImpl;\n\n#define CPU_PRIMITIVE_SOFTMAX_ONEDNN_IMPL(oneflow_algorithm, onednn_algorithm)               \\\n  template<typename SoftmaxBase, dnnl::memory::data_type data_type>                          \\\n  class OneDnnSoftmaxImpl<SoftmaxBase, oneflow_algorithm, data_type> : public SoftmaxBase {  \\\n   public:                                                                                   \\\n    OF_DISALLOW_COPY_AND_MOVE(OneDnnSoftmaxImpl);                                            \\\n    OneDnnSoftmaxImpl() = default;                                                           \\\n    ~OneDnnSoftmaxImpl() override = default;                                                 \\\n                                                                                             \\\n    using OneDnnClass = onednn_algorithm;                                                    \\\n    void Launch(Stream* stream, size_t rows, size_t cols, const void* x, void* y) override { \\\n      SoftmaxOneDnn<OneDnnClass, data_type>(stream, rows, cols, x, y);                       \\\n    }                                                                                        \\\n  }\n\nCPU_PRIMITIVE_SOFTMAX_ONEDNN_IMPL(Algorithm::kSoftmax, dnnl::softmax_forward);\nCPU_PRIMITIVE_SOFTMAX_ONEDNN_IMPL(Algorithm::kLogSoftmax, dnnl::logsoftmax_forward);\n#undef CPU_PRIMITIVE_SOFTMAX_ONEDNN_IMPL\n\ntemplate<typename SoftmaxBase, Algorithm algorithm, dnnl::memory::data_type data_type>\nstd::unique_ptr<SoftmaxBase> NewOneDnnSoftmax() {\n  return std::unique_ptr<SoftmaxBase>(new OneDnnSoftmaxImpl<SoftmaxBase, algorithm, data_type>());\n}\n\n#endif  // WITH_ONEDNN\n\ntemplate<typename SoftmaxBase, Algorithm algorithm, typename T>\nstd::unique_ptr<SoftmaxBase> NewSoftmax() {\n  return std::unique_ptr<SoftmaxBase>(new SoftmaxImpl<SoftmaxBase, algorithm, T>());\n}\n\ntemplate<typename FactoryBase, typename SoftmaxBase, Algorithm algorithm>\nclass GenericSoftmaxFactoryImpl : public FactoryBase {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxFactoryImpl);\n  GenericSoftmaxFactoryImpl() = default;\n  ~GenericSoftmaxFactoryImpl() override = default;\n\n  std::unique_ptr<SoftmaxBase> New(DataType data_type) override {\n#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \\\n  {type_proto, NewSoftmax<SoftmaxBase, algorithm, type_cpp>},\n\n    static const std::map<DataType, std::function<std::unique_ptr<SoftmaxBase>()>>\n        new_softmax_handle{\n            OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CPU_PRIMITIVE_FLOATING_TYPE_SEQ)};\n\n#undef MAKE_NEW_SOFTMAX_ENTRY\n\n#ifdef WITH_ONEDNN\n\n    if (OneDnnIsEnabled() && data_type == DataType::kFloat) {\n      static std::function<std::unique_ptr<SoftmaxBase>()> onednn_softmax =\n          NewOneDnnSoftmax<SoftmaxBase, algorithm, dnnl::memory::data_type::f32>;\n      return onednn_softmax();\n    }\n\n#endif\n    return NewPrimitiveFromHandlers(new_softmax_handle, data_type);\n  }\n};\n\nusing SoftmaxFactoryImpl = GenericSoftmaxFactoryImpl<SoftmaxFactory, Softmax, Algorithm::kSoftmax>;\nusing LogSoftmaxFactoryImpl =\n    GenericSoftmaxFactoryImpl<LogSoftmaxFactory, LogSoftmax, Algorithm::kLogSoftmax>;\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, SoftmaxFactory, SoftmaxFactoryImpl);\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, LogSoftmaxFactory, 
LogSoftmaxFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cpu/primitive/softmax_backward.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/softmax_backward.h\"\n#include \"oneflow/core/ep/include/primitive/log_softmax_backward.h\"\n#include \"oneflow/core/ep/cpu/primitive/type_seq.h\"\n#include \"oneflow/core/ep/cpu/cpu_stream.h\"\n#include \"oneflow/core/ep/cpu/cpu_device.h\"\n#include \"oneflow/core/ep/common/onednn.h\"\n#include \"oneflow/core/ep/common/primitive/util.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\nenum class Algorithm {\n  kSoftmax,\n  kLogSoftmax,\n};\n\ntemplate<Algorithm algorithm, typename T>\nvoid SoftmaxBackwardCpu(size_t rows, size_t cols, const T* y, const T* dy, T* dx) {\n  for (size_t i = 0; i < rows; ++i) {\n    size_t row_offset = i * cols;\n    const T* row_y = y + row_offset;\n    const T* row_dy = dy + row_offset;\n    T* row_dx = dx + row_offset;\n    T row_sum = 0;\n    for (size_t j = 0; j < cols; ++j) {\n      if (algorithm == Algorithm::kSoftmax) {\n        row_sum += row_y[j] * row_dy[j];\n      } else if (algorithm == Algorithm::kLogSoftmax) {\n        row_sum += row_dy[j];\n      } else {\n        UNIMPLEMENTED();\n      }\n    }\n    for (size_t j = 0; j < cols; ++j) {\n      if (algorithm == Algorithm::kSoftmax) {\n        row_dx[j] = (row_dy[j] - row_sum) * row_y[j];\n      } else if (algorithm == Algorithm::kLogSoftmax) {\n        row_dx[j] = row_dy[j] - std::exp(row_y[j]) * row_sum;\n      } else {\n        UNIMPLEMENTED();\n      }\n    }\n  }\n}\n\ntemplate<typename SoftmaxBackwardBase, Algorithm algorithm, typename T>\nclass SoftmaxBackwardImpl : public SoftmaxBackwardBase {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SoftmaxBackwardImpl);\n  SoftmaxBackwardImpl() = default;\n  ~SoftmaxBackwardImpl() override = default;\n\n  void Launch(Stream* stream, size_t rows, size_t cols, const void* y, const void* dy,\n              void* dx) override {\n    SoftmaxBackwardCpu<algorithm, T>(rows, cols, reinterpret_cast<const T*>(y),\n                                     reinterpret_cast<const T*>(dy), reinterpret_cast<T*>(dx));\n  }\n};\n\n#ifdef WITH_ONEDNN\n\ntemplate<class OneDnnSoftmaxBackward, class OneDnnSoftmaxForward, dnnl::memory::data_type data_type>\nvoid SoftmaxBackwardOneDnn(Stream* stream, size_t rows, size_t cols, const void* y, const void* dy,\n                           void* dx) {\n  stream->As<CpuStream>()->onednn_executor()->Launch([&](dnnl::engine* onednn_engine,\n                                                         dnnl::stream* onednn_stream) {\n    dnnl::memory::dims src_dims = {static_cast<dnnl::memory::dim>(rows),\n                                   static_cast<dnnl::memory::dim>(cols)};\n    // Input and output parameters of the same data type\n    auto same_md = dnnl::memory::desc(src_dims, data_type, dnnl::memory::format_tag::nc);\n    // Backward memory\n    auto dst_mem = dnnl::memory(same_md, *onednn_engine, const_cast<void*>(y));\n    auto diff_dst_mem = 
dnnl::memory(same_md, *onednn_engine, const_cast<void*>(dy));\n    // Forward primitive description\n    auto forward_desc = typename OneDnnSoftmaxForward::desc(dnnl::prop_kind::forward, same_md, 1);\n    auto forward_prim_desc =\n        typename OneDnnSoftmaxForward::primitive_desc(forward_desc, *onednn_engine);\n    // Backward primitive description\n    auto diff_src_mem = dnnl::memory(same_md, *onednn_engine, dx);\n    auto backward_desc = typename OneDnnSoftmaxBackward::desc(same_md, same_md, 1);\n    auto backward_prim_desc = typename OneDnnSoftmaxBackward::primitive_desc(\n        backward_desc, *onednn_engine, forward_prim_desc);\n    auto backward_prim = OneDnnSoftmaxBackward(backward_prim_desc);\n\n    backward_prim.execute(*onednn_stream, {{DNNL_ARG_DIFF_DST, diff_dst_mem},\n                                           {DNNL_ARG_DST, dst_mem},\n                                           {DNNL_ARG_DIFF_SRC, diff_src_mem}});\n  });\n}\n\ntemplate<typename SoftmaxBackwardBase, Algorithm algorithm, dnnl::memory::data_type data_type>\nclass OneDnnSoftmaxBackwardImpl;\n\n#define CPU_PRIMITIVE_SOFTMAX_ONEDNN_IMPL(oneflow_algorithm, onednn_backward_algorithm,      \\\n                                          onednn_forward_algorithm)                          \\\n  template<typename SoftmaxBackwardBase, dnnl::memory::data_type data_type>                  \\\n  class OneDnnSoftmaxBackwardImpl<SoftmaxBackwardBase, oneflow_algorithm, data_type>         \\\n      : public SoftmaxBackwardBase {                                                         \\\n   public:                                                                                   \\\n    OF_DISALLOW_COPY_AND_MOVE(OneDnnSoftmaxBackwardImpl);                                    \\\n    OneDnnSoftmaxBackwardImpl() = default;                                                   \\\n    ~OneDnnSoftmaxBackwardImpl() override = default;                                         \\\n                                                                                             \\\n    void Launch(Stream* stream, size_t rows, size_t cols, const void* y, const void* dy,     \\\n                void* dx) override {                                                         \\\n      SoftmaxBackwardOneDnn<onednn_backward_algorithm, onednn_forward_algorithm, data_type>( \\\n          stream, rows, cols, y, dy, dx);                                                    \\\n    }                                                                                        \\\n  }\n\nCPU_PRIMITIVE_SOFTMAX_ONEDNN_IMPL(Algorithm::kSoftmax, dnnl::softmax_backward,\n                                  dnnl::softmax_forward);\nCPU_PRIMITIVE_SOFTMAX_ONEDNN_IMPL(Algorithm::kLogSoftmax, dnnl::logsoftmax_backward,\n                                  dnnl::logsoftmax_forward);\n#undef CPU_PRIMITIVE_SOFTMAX_ONEDNN_IMPL\n\ntemplate<typename SoftmaxBackwardBase, Algorithm algorithm, dnnl::memory::data_type data_type>\nstd::unique_ptr<SoftmaxBackwardBase> NewOneDnnSoftmaxBackward() {\n  return std::unique_ptr<SoftmaxBackwardBase>(\n      new OneDnnSoftmaxBackwardImpl<SoftmaxBackwardBase, algorithm, data_type>());\n}\n\n#endif  // WITH_ONEDNN\n\ntemplate<typename SoftmaxBackwardBase, Algorithm algorithm, typename T>\nstd::unique_ptr<SoftmaxBackwardBase> NewSoftmaxBackward() {\n  return std::unique_ptr<SoftmaxBackwardBase>(\n      new SoftmaxBackwardImpl<SoftmaxBackwardBase, algorithm, T>());\n}\n\ntemplate<typename BackwardFactoryBase, typename SoftmaxBackwardBase, Algorithm 
algorithm>\nclass GenericSoftmaxBackwardFactoryImpl : public BackwardFactoryBase {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxBackwardFactoryImpl);\n  GenericSoftmaxBackwardFactoryImpl() = default;\n  ~GenericSoftmaxBackwardFactoryImpl() override = default;\n\n  std::unique_ptr<SoftmaxBackwardBase> New(DataType data_type) override {\n#define MAKE_NEW_SOFTMAX_BACKWARD_ENTRY(type_cpp, type_proto) \\\n  {type_proto, NewSoftmaxBackward<SoftmaxBackwardBase, algorithm, type_cpp>},\n    static const std::map<DataType, std::function<std::unique_ptr<SoftmaxBackwardBase>()>>\n        new_softmax_backward_handle{\n            OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_BACKWARD_ENTRY, CPU_PRIMITIVE_FLOATING_TYPE_SEQ)};\n#undef MAKE_NEW_SOFTMAX_BACKWARD_ENTRY\n\n#ifdef WITH_ONEDNN\n    if (OneDnnIsEnabled() && data_type == DataType::kFloat) {\n      static std::function<std::unique_ptr<SoftmaxBackwardBase>()> onednn_f32_softmax_backward =\n          NewOneDnnSoftmaxBackward<SoftmaxBackwardBase, algorithm, dnnl::memory::data_type::f32>;\n      return onednn_f32_softmax_backward();\n    }\n#endif\n    return NewPrimitiveFromHandlers(new_softmax_backward_handle, data_type);\n  }\n};\n\nusing SoftmaxBackwardFactoryImpl =\n    GenericSoftmaxBackwardFactoryImpl<SoftmaxBackwardFactory, SoftmaxBackward, Algorithm::kSoftmax>;\nusing LogSoftmaxBackwardFactoryImpl =\n    GenericSoftmaxBackwardFactoryImpl<LogSoftmaxBackwardFactory, LogSoftmaxBackward,\n                                      Algorithm::kLogSoftmax>;\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, SoftmaxBackwardFactory, SoftmaxBackwardFactoryImpl);\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, LogSoftmaxBackwardFactory,\n                           LogSoftmaxBackwardFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
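The `SoftmaxBackwardCpu` kernel above applies the closed-form gradients `dx = (dy - sum(y * dy)) * y` (softmax) and `dx = dy - exp(y) * sum(dy)` (log-softmax). The following standalone sketch is an editor's illustration, not part of the OneFlow sources: it checks the softmax closed form on one row against a central finite difference of the loss `L(x) = sum_j dy[j] * softmax(x)[j]`.

```cpp
// Editor's sketch: verify dx = (dy - sum(y * dy)) * y against a numerical Jacobian.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<double> Softmax(const std::vector<double>& x) {
  const double max = *std::max_element(x.begin(), x.end());
  std::vector<double> y(x.size());
  double sum = 0;
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = std::exp(x[i] - max);  // subtract the max for numerical stability
    sum += y[i];
  }
  for (double& v : y) { v /= sum; }
  return y;
}

int main() {
  const std::vector<double> x = {0.1, -0.4, 1.7, 0.8};
  const std::vector<double> dy = {0.3, -1.2, 0.5, 0.9};
  const std::vector<double> y = Softmax(x);

  // Closed form used by SoftmaxBackwardCpu: dx[i] = (dy[i] - row_sum) * y[i].
  double row_sum = 0;
  for (size_t j = 0; j < y.size(); ++j) { row_sum += y[j] * dy[j]; }

  const double eps = 1e-6;
  for (size_t i = 0; i < x.size(); ++i) {
    const double closed = (dy[i] - row_sum) * y[i];
    // Central difference of L(x) = sum_j dy[j] * softmax(x)[j] w.r.t. x[i].
    std::vector<double> xp = x, xm = x;
    xp[i] += eps;
    xm[i] -= eps;
    const std::vector<double> yp = Softmax(xp), ym = Softmax(xm);
    double lp = 0, lm = 0;
    for (size_t j = 0; j < y.size(); ++j) {
      lp += dy[j] * yp[j];
      lm += dy[j] * ym[j];
    }
    std::printf("dx[%zu]: closed=% .8f numeric=% .8f\n", i, closed, (lp - lm) / (2 * eps));
  }
  return 0;
}
```

The two columns should agree to roughly the finite-difference error (~1e-8 here), which is what makes the single `row_sum` pass per row sufficient.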
  {
    "path": "oneflow/core/ep/cpu/primitive/tensor_fill.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/tensor_fill.h\"\n#include \"oneflow/core/ep/cpu/primitive/type_seq.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\ntemplate<typename T>\nclass TensorFillImpl : public TensorFill {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TensorFillImpl);\n  TensorFillImpl() = default;\n  ~TensorFillImpl() override = default;\n\n  void Launch(Stream* stream, const void* src, void* dst, size_t count) override {\n    const T* value = reinterpret_cast<const T*>(src);\n    std::fill_n(reinterpret_cast<T*>(dst), count, value[0]);\n  }\n};\n\ntemplate<typename T>\nstd::unique_ptr<TensorFill> NewTensorFill() {\n  return std::unique_ptr<TensorFill>(new TensorFillImpl<T>());\n}\n\nclass TensorFillFactoryImpl : public TensorFillFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TensorFillFactoryImpl);\n  TensorFillFactoryImpl() = default;\n  ~TensorFillFactoryImpl() override = default;\n\n  std::unique_ptr<TensorFill> New(DataType data_type) override {\n#define MAKE_NEW_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewTensorFill<type_cpp>},\n\n    static const std::map<DataType, std::function<std::unique_ptr<TensorFill>()>> new_fill_handle{\n        OF_PP_FOR_EACH_TUPLE(MAKE_NEW_FILL_ENTRY, CPU_PRIMITIVE_ALL_TYPE_SEQ)};\n#undef MAKE_NEW_ADD_ENTRY\n    const auto it = new_fill_handle.find(data_type);\n    if (it != new_fill_handle.end()) {\n      return it->second();\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, TensorFillFactory, TensorFillFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cpu/primitive/type_seq.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_CPU_PRIMITIVE_TYPE_SEQ_H_\n#define ONEFLOW_CORE_EP_CPU_PRIMITIVE_TYPE_SEQ_H_\n\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include <half.hpp>\n\n#ifdef WITH_ONEDNN\n#include \"oneapi/dnnl/dnnl.hpp\"\n#endif\n\n#define CPU_PRIMITIVE_BOOL_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool)\n#define CPU_PRIMITIVE_CHAR_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar)\n#define CPU_PRIMITIVE_INT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8)\n#define CPU_PRIMITIVE_INT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int16_t, DataType::kInt16)\n#define CPU_PRIMITIVE_UINT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8)\n#define CPU_PRIMITIVE_INT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)\n#define CPU_PRIMITIVE_UINT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32)\n#define CPU_PRIMITIVE_INT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)\n#define CPU_PRIMITIVE_UINT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64)\n#define CPU_PRIMITIVE_FLOAT_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)\n#define CPU_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)\n#define CPU_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16)\n#define CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bfloat16, DataType::kBFloat16)\n#define CPU_PRIMITIVE_COMPLEX64_TYPE_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(std::complex<float>, DataType::kComplex64)\n#define CPU_PRIMITIVE_COMPLEX128_TYPE_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(std::complex<double>, DataType::kComplex128)\n\n#define CPU_PRIMITIVE_ONEDNN_BOOl_TYPE_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kBool)\n#define CPU_PRIMITIVE_ONEDNN_INT8_TYPE_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::s8, DataType::kInt8)\n#define CPU_PRIMITIVE_ONEDNN_UINT8_TYPE_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kUInt8)\n#define CPU_PRIMITIVE_ONEDNN_INT32_TYPE_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::s32, DataType::kInt32)\n#define CPU_PRIMITIVE_ONEDNN_FLOAT_TYPE_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f32, DataType::kFloat)\n#define CPU_PRIMITIVE_ONEDNN_FLOAT16_TYPE_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f16, DataType::kFloat16)\n#define CPU_PRIMITIVE_ONEDNN_BFLOAT16_TYPE_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::bf16, DataType::kBFloat16)\n\n#define CPU_PRIMITIVE_NATIVE_TYPE_SEQ \\\n  CPU_PRIMITIVE_BOOL_TYPE_SEQ         \\\n  CPU_PRIMITIVE_CHAR_TYPE_SEQ         \\\n  CPU_PRIMITIVE_INT8_TYPE_SEQ         \\\n  CPU_PRIMITIVE_UINT8_TYPE_SEQ        \\\n  CPU_PRIMITIVE_INT32_TYPE_SEQ        \\\n  CPU_PRIMITIVE_UINT32_TYPE_SEQ       \\\n  CPU_PRIMITIVE_INT64_TYPE_SEQ        \\\n  CPU_PRIMITIVE_UINT64_TYPE_SEQ       \\\n  CPU_PRIMITIVE_FLOAT_TYPE_SEQ        \\\n  
CPU_PRIMITIVE_DOUBLE_TYPE_SEQ\n\n#define CPU_PRIMITIVE_COMPLEX_TYPE_SEQ \\\n  CPU_PRIMITIVE_COMPLEX64_TYPE_SEQ     \\\n  CPU_PRIMITIVE_COMPLEX128_TYPE_SEQ\n\n#define CPU_PRIMITIVE_ALL_TYPE_SEQ \\\n  CPU_PRIMITIVE_NATIVE_TYPE_SEQ    \\\n  CPU_PRIMITIVE_FLOAT16_TYPE_SEQ   \\\n  CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ  \\\n  CPU_PRIMITIVE_COMPLEX_TYPE_SEQ\n\n#define CPU_PRIMITIVE_FLOATING_TYPE_SEQ \\\n  CPU_PRIMITIVE_FLOAT_TYPE_SEQ          \\\n  CPU_PRIMITIVE_DOUBLE_TYPE_SEQ\n\n#define CPU_PRIMITIVE_INT_TYPE_SEQ \\\n  CPU_PRIMITIVE_INT8_TYPE_SEQ      \\\n  CPU_PRIMITIVE_UINT8_TYPE_SEQ     \\\n  CPU_PRIMITIVE_INT32_TYPE_SEQ     \\\n  CPU_PRIMITIVE_INT64_TYPE_SEQ\n\n#define UTIL_OPS_DATA_TYPE_SEQ \\\n  CPU_PRIMITIVE_INT8_TYPE_SEQ  \\\n  CPU_PRIMITIVE_UINT8_TYPE_SEQ \\\n  CPU_PRIMITIVE_INT32_TYPE_SEQ \\\n  CPU_PRIMITIVE_INT64_TYPE_SEQ \\\n  CPU_PRIMITIVE_FLOAT_TYPE_SEQ \\\n  CPU_PRIMITIVE_DOUBLE_TYPE_SEQ\n\n#endif  // ONEFLOW_CORE_EP_CPU_PRIMITIVE_TYPE_SEQ_H_\n"
  },
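The `*_TYPE_SEQ` macros above are X-macro style lists: each element pairs a C++ type with its runtime `DataType` tag, and `OF_PP_FOR_EACH_TUPLE` stamps out one dispatch-table entry per element (as in the factories in the neighboring files). Here is a minimal standalone sketch of the same pattern using plain macros rather than OneFlow's preprocessor library; the enum and names are illustrative only.

```cpp
// Editor's sketch of the type-seq / X-macro dispatch pattern.
#include <cstdio>
#include <functional>
#include <map>

enum class DataTypeTag { kFloat, kDouble, kInt32 };

// Each entry pairs a C++ type with its runtime tag, like OF_PP_MAKE_TUPLE_SEQ.
#define FLOATING_TYPE_SEQ(X)      \
  X(float, DataTypeTag::kFloat)   \
  X(double, DataTypeTag::kDouble)

template<typename T>
void PrintSizeOf() { std::printf("element size: %zu\n", sizeof(T)); }

int main() {
  // Stamp out one handler per (type, tag) pair, as the factories above do.
#define MAKE_ENTRY(type_cpp, type_tag) {type_tag, PrintSizeOf<type_cpp>},
  static const std::map<DataTypeTag, std::function<void()>> handlers{
      FLOATING_TYPE_SEQ(MAKE_ENTRY)};
#undef MAKE_ENTRY

  handlers.at(DataTypeTag::kDouble)();  // prints "element size: 8"
  return 0;
}
```

The payoff of the pattern is that adding a type to the seq automatically extends every dispatch table built from it, with no per-factory edits.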
  {
    "path": "oneflow/core/ep/cpu/primitive/unary_functor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/common/primitive/unary_functor.h\"\n#include \"oneflow/core/ep/cpu/primitive/type_seq.h\"\n#include \"oneflow/core/common/math_util.h\"\n\nnamespace oneflow {\nnamespace ep {\nnamespace primitive {\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kGelu, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Src>(0.5) * src * (static_cast<Src>(1.0) + std::erf(inv_sqrt2 * src));\n  }\n  Src inv_sqrt2 = std::sqrt(0.5);\n};\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kFastGelu, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    // ref to: https://mlfromscratch.com/activation-functions-explained/#gelu\n    const Src half = static_cast<Src>(0.5);\n    const Src one = static_cast<Src>(1);\n    const Src tanh_in = alpha * (src + beta * src * src * src);\n    return half * src * (one + std::tanh(tanh_in));\n  }\n\n private:\n  // constant ref to:\n  // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/testdata/transform/fusion/fast_gelu.py\n  static constexpr Src alpha = static_cast<Src>(0.7978845608028654);\n  static constexpr Src beta = static_cast<Src>(0.044714998453855515);\n};\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kQuickGelu, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    const Src sigmoid =\n        static_cast<Dst>(static_cast<Src>(1.0) / (static_cast<Src>(1.0) + exp(-src * alpha)));\n    return src * sigmoid;\n  }\n\n private:\n  static constexpr Src alpha = static_cast<Src>(1.702);\n};\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kSquareReLU, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Dst>((src > static_cast<Src>(0.0)) ? 
src * src : 0);\n  }\n};\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kTanh, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return std::tanh(src); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsInf, bool, float> {\n  UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(float src) const { return std::isinf(src); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsInf, bool, double> {\n  UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(double src) const { return std::isinf(src); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsNan, bool, float> {\n  UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(float src) const { return std::isnan(src); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsNan, bool, double> {\n  UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(double src) const { return std::isnan(src); }\n};\n\ntemplate<typename Src>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsFinite, bool, Src> {\n  UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(Src src) const { return std::isfinite(src); }\n};\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kTrunc, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(std::trunc(src)); }\n};\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kRsqrt, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Dst>(static_cast<Src>(1.0) / static_cast<Src>(std::sqrt(src)));\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kDigamma, float, float> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC float operator()(float src) const {\n    // references\n    // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/Math.h#L434-L487\n    const auto& calc_digamma = [](float x) {\n      std::function<float(float)> compute;\n      compute = [&](float x) {\n        static float PSI_10 = 2.25175258906672110764f;\n        if (x == 0) {\n          // As per C++ standard for gamma related functions and SciPy,\n          // If the argument is ±0, ±∞ is returned\n          return std::copysign(INFINITY, -x);\n        }\n\n        bool x_is_integer = x == truncf(x);\n        if (x < 0) {\n          if (x_is_integer) {\n            // As per C++ standard for gamma related functions and SciPy,\n            // If the argument is a negative integer, NaN is returned\n            return std::numeric_limits<float>::quiet_NaN();\n          }\n          // Extracts the fractional part of x as r, since tan(pi * r) is more numerically\n          // accurate than tan(pi * x). 
While these operations are mathematically equivalent\n          // since both x and r are in radians and tan() has a periodicity of pi, in practice\n          // the computation of pi * x is a source of error (when |x| > 1).\n          double q, r;\n          r = std::modf(x, &q);\n          float pi_over_tan_pi_x = (float)(pi<double> / tan(pi<double> * r));\n          return compute(1 - x) - pi_over_tan_pi_x;\n        }\n\n        // Push x to be >= 10\n        float result = 0;\n        while (x < 10) {\n          result -= 1 / x;\n          x += 1;\n        }\n        if (x == 10) { return result + PSI_10; }\n\n        // Compute asymptotic digamma\n        static const float A[] = {\n            8.33333333333333333333E-2f,  -2.10927960927960927961E-2f, 7.57575757575757575758E-3f,\n            -4.16666666666666666667E-3f, 3.96825396825396825397E-3f,  -8.33333333333333333333E-3f,\n            8.33333333333333333333E-2f,\n        };\n\n        float y = 0;\n        if (x < 1.0e17f) {\n          float z = 1 / (x * x);\n          float polevl_result = 0;\n          for (int i = 0; i <= 6; i++) { polevl_result = polevl_result * z + A[i]; }\n          y = z * polevl_result;\n        }\n        return result + logf(x) - (0.5f / x) - y;\n      };\n\n      return compute(x);\n    };\n\n    return calc_digamma(src);\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kDigamma, double, double> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC double operator()(double src) const {\n    // references\n    // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/Math.h#L376-L428\n    const auto& calc_digamma = [](double x) {\n      std::function<double(double)> compute;\n      compute = [&](double x) {\n        static double PSI_10 = 2.25175258906672110764;\n        if (x == 0) {\n          // As per C++ standard for gamma related functions and SciPy,\n          // If the argument is ±0, ±∞ is returned\n          return std::copysign(INFINITY, -x);\n        }\n\n        bool x_is_integer = x == trunc(x);\n        if (x < 0) {\n          if (x_is_integer) {\n            // As per C++ standard for gamma related functions and SciPy,\n            // If the argument is a negative integer, NaN is returned\n            return std::numeric_limits<double>::quiet_NaN();\n          }\n          // Extracts the fractional part of x as r, since tan(pi * r) is more numerically\n          // accurate than tan(pi * x). 
While these operations are mathematically equivalent\n          // since both x and r are in radians and tan() has a periodicity of pi, in practice\n          // the computation of pi * x is a source of error (when |x| > 1).\n          double q, r;\n          r = std::modf(x, &q);\n          return compute(1 - x) - pi<double> / tan(pi<double> * r);\n        }\n\n        // Push x to be >= 10\n        double result = 0;\n        while (x < 10) {\n          result -= 1 / x;\n          x += 1;\n        }\n        if (x == 10) { return result + PSI_10; }\n\n        // Compute asymptotic digamma\n        static const double A[] = {\n            8.33333333333333333333E-2,  -2.10927960927960927961E-2, 7.57575757575757575758E-3,\n            -4.16666666666666666667E-3, 3.96825396825396825397E-3,  -8.33333333333333333333E-3,\n            8.33333333333333333333E-2,\n        };\n\n        double y = 0;\n        if (x < 1.0e17) {\n          double z = 1.0 / (x * x);\n          // y = z * polevl(z, A, 6);\n\n          double polevl_result = 0;\n          for (int i = 0; i <= 6; i++) { polevl_result = polevl_result * z + A[i]; }\n          y = z * polevl_result;\n        }\n        return result + log(x) - (0.5 / x) - y;\n      };\n\n      return compute(x);\n    };\n\n    return calc_digamma(src);\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kTrigamma, double, double> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC double operator()(double x) const {\n    // references\n    // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/Math.h#L336-L352\n    double sign = +1;\n    double result = 0;\n    if (x < 0.5) {\n      sign = -1;\n      const double sin_pi_x = sin(pi<double> * x);\n      result -= (pi<double> * pi<double>) / (sin_pi_x * sin_pi_x);\n      x = 1 - x;\n    }\n    for (int i = 0; i < 6; ++i) {\n      result += 1 / (x * x);\n      x += 1;\n    }\n    const double ixx = 1 / (x * x);\n    result += (1 + 1 / (2 * x) + ixx * (1. / 6 - ixx * (1. / 30 - ixx * (1. 
/ 42)))) / x;\n    return sign * result;\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kTrigamma, float, float> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC float operator()(float x) const {\n    // references\n    // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/Math.h#L354-L370\n    float sign = +1;\n    float result = 0;\n    if (x < 0.5f) {\n      sign = -1;\n      const float sin_pi_x = sinf(pi<float> * x);\n      result -= (pi<float> * pi<float>) / (sin_pi_x * sin_pi_x);\n      x = 1 - x;\n    }\n    for (int i = 0; i < 6; ++i) {\n      result += 1 / (x * x);\n      x += 1;\n    }\n    const float ixx = 1 / (x * x);\n    result += (1 + 1 / (2 * x) + ixx * (1.f / 6 - ixx * (1.f / 30 - ixx * (1.f / 42)))) / x;\n    return sign * result;\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kAbs, bfloat16, bfloat16> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src) const { return std::abs(src); }\n};\n\n#define SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(op)                                        \\\n  template<>                                                                                 \\\n  struct UnaryFunctor<DeviceType::kCPU, op, bfloat16, bfloat16> {                            \\\n    OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \\\n                                                                                             \\\n    UnaryFunctor<DeviceType::kCPU, op, float, float> float_functor;                          \\\n    OF_DEVICE_FUNC bfloat16 operator()(bfloat16 src) const {                                 \\\n      return bfloat16(float_functor(static_cast<float>(src)));                               \\\n    }                                                                                        \\\n  
};\n\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kElu);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCelu);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kGelu);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSwish);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSigmoid);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardShrink);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardTanh);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLeakyRelu);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kMish);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSelu);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSilu);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftShrink);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftSign);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftPlus);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTanh);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kThreshold);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAcos);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAcosh);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAsin);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAsinh);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAtan);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAtanh);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCeil);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCos);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCosh);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kErf);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kErfc);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kExp);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kExp2);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kExpm1);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFloor);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLgamma);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog2);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog1p);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLogSigmoid);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kRint);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kRound);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kRsqrt);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSigmoid);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSin);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSinh);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSqrt);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSquare);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTan);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kReciprocalNoNan);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNotEqualZero);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFastGelu);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kQuickGelu);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSquareReLU);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kDigamma);\nSPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTrigamma);\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsInf, bool, bfloat16> {\n  UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(bfloat16 src) const { return std::isinf(src); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsNan, bool, bfloat16> {\n  
UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(bfloat16 src) const { return std::isnan(src); }\n};\n\n// avoid warning: narrowing conversion\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kRealGrad, std::complex<float>, double> {\n  UnaryFunctor(Scalar attr0, Scalar attr1) {}\n  std::complex<float> operator()(double src) const {\n    return std::complex<float>{static_cast<float>(src), 0.0f};\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCPU, UnaryOp::kImagGrad, std::complex<float>, double> {\n  UnaryFunctor(Scalar attr0, Scalar attr1) {}\n  std::complex<float> operator()(double src) const {\n    return std::complex<float>{0.0f, static_cast<float>(src)};\n  }\n};\n\n}  // namespace primitive\n}  // namespace ep\n}  // namespace oneflow\n"
  },
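The digamma functors above follow PyTorch's reference implementation (linked in the code): negative arguments are reflected via psi(1 - x) - psi(x) = pi / tan(pi * x), the recurrence psi(x + 1) = psi(x) + 1/x pushes x up to >= 10, and an asymptotic series finishes the job. The standalone sketch below is an editor's sanity check of those two identities using a crude central difference of `std::lgamma`; it is illustrative only and not how the functors compute anything.

```cpp
// Editor's sketch: check the identities behind the digamma functor with a
// finite difference of lgamma (psi is the derivative of log-gamma).
#include <cmath>
#include <cstdio>

static double DigammaFd(double x) {
  const double h = 1e-5;
  return (std::lgamma(x + h) - std::lgamma(x - h)) / (2 * h);
}

int main() {
  const double pi = 3.14159265358979323846;
  // Known value: psi(1) == -EulerGamma.
  std::printf("psi(1)         = % .6f (expect -0.577216)\n", DigammaFd(1.0));
  // Recurrence used to push x up to >= 10: psi(x + 1) = psi(x) + 1/x.
  double x = 3.25;
  std::printf("recurrence gap = % .2e\n", DigammaFd(x + 1) - (DigammaFd(x) + 1 / x));
  // Reflection used for x < 0: psi(1 - x) - psi(x) = pi / tan(pi * x).
  x = 0.3;
  std::printf("reflection gap = % .2e\n",
              (DigammaFd(1 - x) - DigammaFd(x)) - pi / std::tan(pi * x));
  return 0;
}
```

Both "gap" lines should print values near zero, which is exactly what licenses the reflection-plus-recurrence structure of the functor.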
  {
    "path": "oneflow/core/ep/cpu/primitive/where.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/ep/include/primitive/where.h\"\n#include \"oneflow/core/ep/common/primitive/where.h\"\n#include \"oneflow/core/ep/cpu/cpu_stream.h\"\n\nnamespace oneflow {\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\ntemplate<typename T, typename CondT, typename IndexT, size_t ndim, size_t cond_pack_size,\n         size_t x_pack_size, size_t y_pack_size>\nvoid BroadcastElementwiseWhereKernel(CpuStream* cpu_stream,\n                                     const BroadcastElementwiseWhereParams<ndim, IndexT>& params) {\n  constexpr size_t _pack_size = (x_pack_size > y_pack_size) ? x_pack_size : y_pack_size;\n  constexpr size_t pack_size = (cond_pack_size > _pack_size) ? cond_pack_size : _pack_size;\n  static_assert(cond_pack_size == pack_size || cond_pack_size == 1, \"\");\n  static_assert(x_pack_size == pack_size || x_pack_size == 1, \"\");\n  static_assert(y_pack_size == pack_size || y_pack_size == 1, \"\");\n\n  const auto* cond_pack = reinterpret_cast<const Packed<CondT, cond_pack_size>*>(params.cond);\n  const auto* x_pack = reinterpret_cast<const Packed<T, x_pack_size>*>(params.x);\n  const auto* y_pack = reinterpret_cast<const Packed<T, y_pack_size>*>(params.y);\n  auto* z_pack = reinterpret_cast<Packed<T, pack_size>*>(params.z);\n\n  WhereFunctor<T, CondT> where_fn{};\n\n  cpu_stream->ParallelFor(0, params.elem_cnt, [&](int64_t begin, int64_t end) {\n    IndexT cond_index[ndim];\n    IndexT x_index[ndim];\n    IndexT y_index[ndim];\n    IndexT z_index[ndim];\n\n    for (IndexT offset = begin; offset < end; offset++) {\n      params.z_index_helper.OffsetToNdIndex(offset, z_index);\n      for (size_t i = 0; i < ndim; ++i) {\n        cond_index[i] = params.cond_index_mask[i] * z_index[i];\n        x_index[i] = params.x_index_mask[i] * z_index[i];\n        y_index[i] = params.y_index_mask[i] * z_index[i];\n      }\n      const IndexT cond_offset = params.cond_index_helper.NdIndexToOffset(cond_index);\n      const IndexT x_offset = params.x_index_helper.NdIndexToOffset(x_index);\n      const IndexT y_offset = params.y_index_helper.NdIndexToOffset(y_index);\n\n      for (size_t j = 0; j < pack_size; ++j) {\n        const CondT cond_val = (cond_pack_size == pack_size) ? cond_pack[cond_offset].elem[j]\n                                                             : cond_pack[cond_offset].elem[0];\n        const T x_val =\n            (x_pack_size == pack_size) ? x_pack[x_offset].elem[j] : x_pack[x_offset].elem[0];\n        const T y_val =\n            (y_pack_size == pack_size) ? 
y_pack[y_offset].elem[j] : y_pack[y_offset].elem[0];\n        z_pack[offset].elem[j] = where_fn(static_cast<bool>(cond_val), x_val, y_val);\n      }\n    }\n  });\n}\n\ntemplate<typename T, typename CondT>\nvoid ScalarWhereKernel(const CondT* cond, const T* x, const T* y, T* z) {\n  WhereFunctor<T, CondT> where_fn{};\n  *z = where_fn(*cond, *x, *y);\n}\n\ntemplate<typename T, typename CondT, typename IndexT, size_t ndim, size_t cond_pack_size,\n         size_t x_pack_size, size_t y_pack_size>\nvoid LaunchKernel(Stream* stream, const int64_t* cond_dims, const int64_t* x_dims,\n                  const int64_t* y_dims, const int64_t* z_dims, const CondT* cond, const T* x,\n                  const T* y, T* z) {\n  static_assert(ndim > 0, \"\");\n  BroadcastElementwiseWhereParams<ndim, IndexT> params;\n  params.cond_index_helper = NdIndexOffsetHelper<IndexT, ndim>(cond_dims);\n  params.x_index_helper = NdIndexOffsetHelper<IndexT, ndim>(x_dims);\n  params.y_index_helper = NdIndexOffsetHelper<IndexT, ndim>(y_dims);\n  params.z_index_helper = NdIndexOffsetHelper<IndexT, ndim>(z_dims);\n  for (size_t i = 0; i < ndim; ++i) {\n    params.cond_index_mask[i] = (cond_dims[i] == 1) ? 0 : 1;\n    params.x_index_mask[i] = (x_dims[i] == 1) ? 0 : 1;\n    params.y_index_mask[i] = (y_dims[i] == 1) ? 0 : 1;\n  }\n  params.elem_cnt = static_cast<IndexT>(GetElementCount(ndim, z_dims));\n  params.cond = cond;\n  params.x = x;\n  params.y = y;\n  params.z = z;\n\n  auto* cpu_stream = stream->As<CpuStream>();\n  BroadcastElementwiseWhereKernel<T, CondT, IndexT, ndim, cond_pack_size, x_pack_size, y_pack_size>(\n      cpu_stream, params);\n}\n\ntemplate<typename T, typename CondT>\nvoid LaunchScalarKernel(Stream* stream, const CondT* cond, const T* x, const T* y, T* z) {\n  ScalarWhereKernel(cond, x, y, z);\n}\n\ntemplate<typename T, typename CondT>\nclass WhereImpl : public Where {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(WhereImpl);\n  explicit WhereImpl() = default;\n  ~WhereImpl() override = default;\n\n  void Launch(Stream* stream, size_t num_cond_dims, const int64_t* cond_dims, const void* cond,\n              size_t num_x_dims, const int64_t* x_dims, const void* x, size_t num_y_dims,\n              const int64_t* y_dims, const void* y, void* z) override {\n    size_t compact_num_dims = 0;\n    int64_t compact_cond_dims[kMaxNumDims] = {};\n    int64_t compact_x_dims[kMaxNumDims] = {};\n    int64_t compact_y_dims[kMaxNumDims] = {};\n    int64_t compact_z_dims[kMaxNumDims] = {};\n    GetCompactBroadcastDims(num_cond_dims, cond_dims, num_x_dims, x_dims, num_y_dims, y_dims,\n                            &compact_num_dims, compact_cond_dims, compact_x_dims, compact_y_dims,\n                            compact_z_dims);\n    LaunchByDispatchNDim(stream, compact_num_dims, compact_cond_dims, compact_x_dims,\n                         compact_y_dims, compact_z_dims, static_cast<const CondT*>(cond),\n                         static_cast<const T*>(x), static_cast<const T*>(y), static_cast<T*>(z));\n  }\n};\n\nclass WhereFactoryImpl : public WhereFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(WhereFactoryImpl);\n  WhereFactoryImpl() = default;\n  ~WhereFactoryImpl() override = default;\n\n  std::unique_ptr<Where> New(DataType cond_type, DataType data_type, size_t max_num_dims) override {\n    return NewWhere<WhereImpl>(cond_type, data_type, max_num_dims);\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, WhereFactory, WhereFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n}  // namespace 
oneflow\n"
  },
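The broadcast kernel above never materializes expanded inputs: for each dimension it stores a 0/1 mask (`0` where the input extent is 1), multiplies the output coordinate by the mask, and converts the masked coordinate back to an input offset. A minimal standalone sketch of that masking for a 2D case, independent of OneFlow's `NdIndexOffsetHelper` (all names here are illustrative):

```cpp
// Editor's sketch: broadcast index masking as used in BroadcastElementwiseWhereKernel.
// A dimension of extent 1 gets mask 0, so every output coordinate along it
// collapses to input coordinate 0.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t z_dims[2] = {2, 3};  // output shape
  const int64_t x_dims[2] = {1, 3};  // x is broadcast along dim 0
  int64_t x_mask[2];
  for (int i = 0; i < 2; ++i) { x_mask[i] = (x_dims[i] == 1) ? 0 : 1; }

  for (int64_t r = 0; r < z_dims[0]; ++r) {
    for (int64_t c = 0; c < z_dims[1]; ++c) {
      // Mask the broadcast coordinate, then compute a row-major offset into x.
      const int64_t xr = x_mask[0] * r;
      const int64_t xc = x_mask[1] * c;
      const int64_t x_offset = xr * x_dims[1] + xc;
      std::printf("z(%lld,%lld) reads x offset %lld\n", (long long)r, (long long)c,
                  (long long)x_offset);
    }
  }
  return 0;
}
```

Rows 0 and 1 of `z` read the same three offsets of `x`, which is the whole broadcast in two multiplies per dimension.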
  {
    "path": "oneflow/core/ep/cuda/cuda_device.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/mem_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_device.h\"\n#include \"oneflow/core/ep/cuda/cuda_event.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\n#ifdef WITH_CUDA\n\n#include <cuda.h>\n#include <cuda_fp16.h>\n\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace {\n\nconstexpr size_t kDefaultConstBufElementCount = 1024 * 1024;\n\ntemplate<typename T>\nvoid CreateConstBuffer(void** buf, T value, size_t n) {\n  OF_CUDA_CHECK(cudaMalloc(buf, n * sizeof(T)));\n  std::vector<T> host(n, value);\n  OF_CUDA_CHECK(cudaMemcpy(*buf, host.data(), n * sizeof(T), cudaMemcpyDefault));\n}\n\n}  // namespace\n\nCudaDevice::CudaDevice(int device_index, DeviceManager* device_manager)\n    : device_index_(device_index),\n      event_flags_{},\n      properties_{},\n      device_manager_(device_manager),\n      const_buf_elem_cnt_(0),\n      const_zeros_buffer_(nullptr),\n      const_ones_buffer_fp32_(nullptr),\n      const_ones_buffer_fp16_(nullptr),\n      const_ones_buffer_bf16_(nullptr) {\n  CudaCurrentDeviceGuard guard(device_index_);\n  OF_CUDA_CHECK(cudaGetDeviceProperties(&properties_, device_index_));\n  {\n    const char* env_name = \"ONEFLOW_EP_CUDA_DEVICE_FLAGS\";\n    if (std::getenv(env_name) != nullptr) {\n      const unsigned int flags = ParseIntegerFromEnv(env_name, 0);\n      OF_CUDA_CHECK(cudaSetDeviceFlags(flags));\n    }\n  }\n  event_flags_ = cudaEventDisableTiming;\n  if (ParseBooleanFromEnv(\"ONEFLOW_STREAM_CUDA_EVENT_FLAG_BLOCKING_SYNC\", false)) {\n    event_flags_ |= cudaEventBlockingSync;\n  }\n  const_buf_elem_cnt_ = ParseIntegerFromEnv(\"ONEFLOW_EP_CUDA_CONST_BUFFER_ELEMENT_COUNT\",\n                                            kDefaultConstBufElementCount);\n  if (const_buf_elem_cnt_ > 0) {\n    CreateConstBuffer<float>(&const_zeros_buffer_, static_cast<float>(0), const_buf_elem_cnt_);\n    CreateConstBuffer<float>(&const_ones_buffer_fp32_, static_cast<float>(1.0),\n                             const_buf_elem_cnt_);\n    CreateConstBuffer<half>(&const_ones_buffer_fp16_, static_cast<half>(1.0), const_buf_elem_cnt_);\n#if CUDA_VERSION >= 11000\n    CreateConstBuffer<nv_bfloat16>(&const_ones_buffer_bf16_, static_cast<nv_bfloat16>(1.0),\n                                   const_buf_elem_cnt_);\n#endif  // CUDA_VERSION >= 11000\n  }\n#if CUDA_VERSION >= 11020\n  if (ParseBooleanFromEnv(\"ONEFLOW_EP_CUDA_ENABLE_STREAM_ORDERED_MEMORY_ALLOCATOR\", false)) {\n    int memory_pools_supported = 0;\n    cudaError_t err = cudaDeviceGetAttribute(&memory_pools_supported,\n                                             cudaDevAttrMemoryPoolsSupported, device_index_);\n    if (err == cudaSuccess && memory_pools_supported) {\n      cudaMemPoolProps mem_pool_props = {};\n      mem_pool_props.allocType = cudaMemAllocationTypePinned;\n      mem_pool_props.handleTypes = 
cudaMemHandleTypePosixFileDescriptor;\n      mem_pool_props.location.type = cudaMemLocationTypeDevice;\n      mem_pool_props.location.id = device_index_;\n      OF_CUDA_CHECK(cudaMemPoolCreate(&mem_pool_, &mem_pool_props));\n      uint64_t threshold = UINT64_MAX;\n      OF_CUDA_CHECK(\n          cudaMemPoolSetAttribute(mem_pool_, cudaMemPoolAttrReleaseThreshold, &threshold));\n      int disabled = 0;\n      OF_CUDA_CHECK(\n          cudaMemPoolSetAttribute(mem_pool_, cudaMemPoolReuseFollowEventDependencies, &disabled));\n      OF_CUDA_CHECK(\n          cudaMemPoolSetAttribute(mem_pool_, cudaMemPoolReuseAllowOpportunistic, &disabled));\n      OF_CUDA_CHECK(\n          cudaMemPoolSetAttribute(mem_pool_, cudaMemPoolReuseAllowInternalDependencies, &disabled));\n    }\n    if (err != cudaSuccess) { (void)cudaGetLastError(); }\n  }\n#endif  // CUDA_VERSION >= 11020\n}\n\nCudaDevice::~CudaDevice() {\n  CudaCurrentDeviceGuard guard(device_index_);\n  for (auto* event : events_) { delete event; }\n  OF_CUDA_CHECK(cudaFree(const_zeros_buffer_));\n  OF_CUDA_CHECK(cudaFree(const_ones_buffer_fp32_));\n  OF_CUDA_CHECK(cudaFree(const_ones_buffer_fp16_));\n  OF_CUDA_CHECK(cudaFree(const_ones_buffer_bf16_));\n#if CUDA_VERSION >= 11020\n  if (mem_pool_) { OF_CUDA_CHECK(cudaMemPoolDestroy(mem_pool_)); }\n#endif  // CUDA_VERSION >= 11020\n}\n\nvoid CudaDevice::SetAsActiveDevice() { OF_CUDA_CHECK(cudaSetDevice(device_index_)); }\n\nvoid CudaDevice::Reset() {\n  SetAsActiveDevice();\n  OF_CUDA_CHECK(cudaDeviceReset());\n}\n\nStream* CudaDevice::CreateStream() {\n  CudaCurrentDeviceGuard guard(device_index_);\n  return new CudaStream(this);\n}\n\nvoid CudaDevice::DestroyStream(Stream* stream) {\n  CudaCurrentDeviceGuard guard(device_index_);\n  delete stream;\n}\n\nvoid CudaDevice::CreateEvents(Event** events, size_t count) {\n  size_t copied = 0;\n  {\n    std::lock_guard<std::mutex> lock(events_mutex_);\n    copied = std::min(count, events_.size());\n    size_t offset = events_.size() - copied;\n    std::copy(events_.begin() + offset, events_.end(), events);\n    events_.resize(offset);\n  }\n  if (copied != count) {\n    CudaCurrentDeviceGuard guard(device_index_);\n    for (size_t i = copied; i < count; ++i) { events[i] = new CudaEvent(event_flags_); }\n  }\n}\n\nvoid CudaDevice::DestroyEvents(Event** events, size_t count) {\n  std::lock_guard<std::mutex> lock(events_mutex_);\n  events_.insert(events_.end(), events, events + count);\n}\n\nMaybe<void> CudaDevice::Alloc(const AllocationOptions& options, void** ptr, size_t size) {\n  CudaCurrentDeviceGuard guard(device_index_);\n  CHECK(!options.HasPinnedDevice());\n  cudaError_t err = cudaMalloc(ptr, size);\n  if (err != cudaSuccess) {\n    if (err == cudaErrorMemoryAllocation) {\n      // NOTE: return an out-of-memory error so the VM will try to shrink memory and rerun\n      return Error::OutOfMemoryError()\n             << \"CUDA \" << cudaGetErrorString(err) << \". 
Tried to allocate \" << FormatMemSize(size);\n    }\n    return Error::RuntimeError() << cudaGetErrorString(err);\n  } else {\n    return Maybe<void>::Ok();\n  }\n}\n\nvoid CudaDevice::Free(const AllocationOptions& attr, void* ptr) {\n  CudaCurrentDeviceGuard guard(device_index_);\n  OF_CUDA_CHECK(cudaFree(ptr));\n}\n\nMaybe<void> CudaDevice::AllocPinned(const AllocationOptions& options, void** ptr, size_t size) {\n  CudaCurrentDeviceGuard guard(device_index_);\n  cudaError_t err = NumaAwareCudaMallocHost(device_index_, ptr, size);\n  if (err != cudaSuccess) {\n    return Error::RuntimeError() << cudaGetErrorString(err);\n  } else {\n    return Maybe<void>::Ok();\n  }\n}\n\nvoid CudaDevice::FreePinned(const AllocationOptions& options, void* ptr) {\n  CudaCurrentDeviceGuard guard(device_index_);\n  OF_CUDA_CHECK(cudaFreeHost(ptr));\n}\n\nbool CudaDevice::IsStreamOrderedMemoryAllocationSupported() const {\n#if CUDA_VERSION >= 11020\n  return mem_pool_ != nullptr;\n#else\n  return false;\n#endif  // CUDA_VERSION >= 11020\n}\n\n#if CUDA_VERSION >= 11020\ncudaMemPool_t CudaDevice::mem_pool() { return mem_pool_; }\n#endif  // CUDA_VERSION >= 11020\n\nconst cudaDeviceProp& CudaDevice::properties() const { return properties_; }\n\nconst void* CudaDevice::GetConstZeros(DataType data_type, size_t n) const {\n  if (GetSizeOfDataType(data_type) * n\n      <= GetSizeOfDataType(DataType::kFloat) * const_buf_elem_cnt_) {\n    return const_zeros_buffer_;\n  } else {\n    return nullptr;\n  }\n}\n\nconst void* CudaDevice::GetConstOnes(DataType data_type, size_t n) const {\n  if (n <= const_buf_elem_cnt_) {\n    if (data_type == DataType::kFloat) {\n      return const_ones_buffer_fp32_;\n    } else if (data_type == DataType::kFloat16) {\n      return const_ones_buffer_fp16_;\n    } else if (data_type == DataType::kBFloat16) {\n      return const_ones_buffer_bf16_;\n    } else {\n      return nullptr;\n    }\n  } else {\n    return nullptr;\n  }\n}\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
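The constructor above creates an explicit `cudaMemPool_t` with an unbounded release threshold and the reuse heuristics disabled, so the framework keeps full control over reuse. As a hedged illustration (not OneFlow code), this is how such a pool is typically consumed through the CUDA >= 11.2 stream-ordered allocator; it assumes device 0 and reduces error handling to a macro for brevity.

```cpp
// Editor's sketch: stream-ordered allocation from an explicit memory pool.
#include <cuda_runtime.h>
#include <cstdio>

#define CHECK_CUDA(expr)                                       \
  do {                                                         \
    cudaError_t _e = (expr);                                   \
    if (_e != cudaSuccess) {                                   \
      std::printf("CUDA error: %s\n", cudaGetErrorString(_e)); \
      return 1;                                                \
    }                                                          \
  } while (0)

int main() {
  cudaMemPoolProps props{};
  props.allocType = cudaMemAllocationTypePinned;
  props.location.type = cudaMemLocationTypeDevice;
  props.location.id = 0;  // assumption: device 0
  cudaMemPool_t pool = nullptr;
  CHECK_CUDA(cudaMemPoolCreate(&pool, &props));

  cudaStream_t stream = nullptr;
  CHECK_CUDA(cudaStreamCreate(&stream));

  // Allocation and free are ordered on the stream, not on the host.
  void* buf = nullptr;
  CHECK_CUDA(cudaMallocFromPoolAsync(&buf, 1 << 20, pool, stream));
  CHECK_CUDA(cudaMemsetAsync(buf, 0, 1 << 20, stream));
  CHECK_CUDA(cudaFreeAsync(buf, stream));
  CHECK_CUDA(cudaStreamSynchronize(stream));

  CHECK_CUDA(cudaStreamDestroy(stream));
  CHECK_CUDA(cudaMemPoolDestroy(pool));
  return 0;
}
```

Because the free is stream-ordered, the memory can be handed to the next allocation on the same stream without a device-wide synchronization, which is the point of routing allocations through a pool.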
  {
    "path": "oneflow/core/ep/cuda/cuda_device.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_CUDA_CUDA_DEVICE_H_\n#define ONEFLOW_CORE_EP_CUDA_CUDA_DEVICE_H_\n\n#include \"oneflow/core/ep/include/device.h\"\n#include \"oneflow/core/common/data_type.h\"\n\n#ifdef WITH_CUDA\n\n#include <cuda_runtime.h>\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass CudaDevice : public Device {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaDevice);\n  explicit CudaDevice(int device_index, DeviceManager* device_manager);\n  ~CudaDevice() override;\n\n  void SetAsActiveDevice() override;\n  void Reset() override;\n\n  DeviceType device_type() const override { return DeviceType::kCUDA; }\n  size_t device_index() const override { return device_index_; }\n  DeviceManager* device_manager() const override { return device_manager_; }\n\n  Stream* CreateStream() override;\n  void DestroyStream(Stream* stream) override;\n\n  void CreateEvents(Event** events, size_t count) override;\n  void DestroyEvents(Event** events, size_t count) override;\n\n  Maybe<void> Alloc(const AllocationOptions& options, void** ptr, size_t size) override;\n  void Free(const AllocationOptions& options, void* ptr) override;\n  Maybe<void> AllocPinned(const AllocationOptions& options, void** ptr, size_t size) override;\n  void FreePinned(const AllocationOptions& options, void* ptr) override;\n  bool IsStreamOrderedMemoryAllocationSupported() const override;\n\n#if CUDA_VERSION >= 11020\n  cudaMemPool_t mem_pool();\n#endif  // CUDA_VERSION >= 11020\n  const cudaDeviceProp& properties() const;\n\n  const void* GetConstZeros(DataType data_type, size_t n) const;\n  const void* GetConstOnes(DataType data_type, size_t n) const;\n\n private:\n  int device_index_;\n  std::mutex events_mutex_;\n  std::vector<Event*> events_;\n  unsigned int event_flags_;\n  cudaDeviceProp properties_;\n  DeviceManager* device_manager_;\n  int64_t const_buf_elem_cnt_;\n  void* const_zeros_buffer_;\n  void* const_ones_buffer_fp32_;\n  void* const_ones_buffer_fp16_;\n  void* const_ones_buffer_bf16_;\n#if CUDA_VERSION >= 11020\n  cudaMemPool_t mem_pool_{};\n#endif  // CUDA_VERSION >= 11020\n};\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_CORE_EP_CUDA_CUDA_DEVICE_H_\n"
  },
  {
    "path": "oneflow/core/ep/cuda/cuda_device_manager.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/cuda_device_manager.h\"\n#include \"oneflow/core/ep/cuda/cuda_random_generator.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n\n#ifdef WITH_CUDA\n\nnamespace oneflow {\n\nnamespace ep {\n\nCudaDeviceManager::CudaDeviceManager(DeviceManagerRegistry* registry) : registry_(registry) {}\nCudaDeviceManager::~CudaDeviceManager() = default;\n\nDeviceManagerRegistry* CudaDeviceManager::registry() const { return registry_; }\n\nstd::shared_ptr<Device> CudaDeviceManager::GetDevice(size_t device_index) {\n  std::lock_guard<std::mutex> lock(devices_mutex_);\n  if (device_index < devices_.size() && devices_.at(device_index)) {\n    return devices_.at(device_index);\n  }\n  auto device = std::make_shared<CudaDevice>(device_index, this);\n  if (device_index >= devices_.size()) { devices_.resize(device_index + 1); }\n  devices_.at(device_index) = device;\n  return device;\n}\n\nsize_t CudaDeviceManager::GetDeviceCount(size_t primary_device_index) {\n  CudaCurrentDeviceGuard guard(primary_device_index);\n  return this->GetDeviceCount();\n}\n\nsize_t CudaDeviceManager::GetDeviceCount() {\n  int count = 0;\n  cudaError_t err = cudaGetDeviceCount(&count);\n  if (err == cudaErrorNoDevice || err == cudaErrorInsufficientDriver) { return 0; }\n  OF_CUDA_CHECK(err);\n  return count;\n}\n\nsize_t CudaDeviceManager::GetActiveDeviceIndex() {\n  int device = 0;\n  OF_CUDA_CHECK(cudaGetDevice(&device));\n  return static_cast<size_t>(device);\n}\n\nvoid CudaDeviceManager::SetActiveDeviceByIndex(size_t device_index) {\n  OF_CUDA_CHECK(cudaSetDevice(static_cast<int>(device_index)));\n}\n\nstd::shared_ptr<RandomGenerator> CudaDeviceManager::CreateRandomGenerator(uint64_t seed,\n                                                                          size_t device_index) {\n  return std::make_shared<CUDAGenerator>(seed, device_index);\n}\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/core/ep/cuda/cuda_device_manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_CUDA_CUDA_DEVICE_MANAGER_H_\n#define ONEFLOW_CORE_EP_CUDA_CUDA_DEVICE_MANAGER_H_\n\n#include \"oneflow/core/ep/include/device_manager.h\"\n\n#ifdef WITH_CUDA\n\nnamespace oneflow {\nnamespace ep {\n\nclass CudaDevice;\n\nclass CudaDeviceManager : public DeviceManager {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaDeviceManager);\n  CudaDeviceManager(DeviceManagerRegistry* registry);\n  ~CudaDeviceManager() override;\n\n  DeviceManagerRegistry* registry() const override;\n  std::shared_ptr<Device> GetDevice(size_t device_index) override;\n  size_t GetDeviceCount(size_t primary_device_index) override;\n  size_t GetDeviceCount() override;\n  size_t GetActiveDeviceIndex() override;\n  void SetActiveDeviceByIndex(size_t device_index) override;\n  bool IsStreamWaitEventSupported() const override { return true; }\n\n  std::shared_ptr<RandomGenerator> CreateRandomGenerator(uint64_t seed,\n                                                         size_t device_index) override;\n\n private:\n  std::mutex devices_mutex_;\n  std::vector<std::shared_ptr<CudaDevice>> devices_;\n  DeviceManagerRegistry* registry_;\n};\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_CORE_EP_CUDA_CUDA_DEVICE_MANAGER_H_\n"
  },
  {
    "path": "oneflow/core/ep/cuda/cuda_device_manager_factory.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/device_manager_factory.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/ep/cuda/cuda_device_manager.h\"\n\n#ifdef WITH_CUDA\n\n#include <cuda_runtime.h>\n#include <cudnn.h>\n#include <nccl.h>\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace {\n\nstd::string GetCudaVersionString(int version) {\n  return std::to_string(version / 1000) + \".\" + std::to_string((version % 1000) / 10);\n}\n\nbool GetCudnnVersion(libraryPropertyType type, int* version) {\n  cudnnStatus_t status = cudnnGetProperty(type, version);\n  if (status == CUDNN_STATUS_SUCCESS) {\n    return true;\n  } else {\n    LOG(ERROR) << \"Failed to get cuDNN version: \" << cudnnGetErrorString(status);\n    return false;\n  }\n}\n\nbool GetCudnnVersionString(std::string* version) {\n  int version_major = 0;\n  int version_minor = 0;\n  int version_patch = 0;\n  if (!GetCudnnVersion(libraryPropertyType::MAJOR_VERSION, &version_major)) { return false; }\n  if (!GetCudnnVersion(libraryPropertyType::MINOR_VERSION, &version_minor)) { return false; }\n  if (!GetCudnnVersion(libraryPropertyType::PATCH_LEVEL, &version_patch)) { return false; }\n  *version = std::to_string(version_major) + \".\" + std::to_string(version_minor) + \".\"\n             + std::to_string(version_patch);\n  return true;\n}\n\nvoid CudaDumpVersionInfo() {\n  {\n    int cuda_runtime_version = 0;\n    cudaError_t err = cudaRuntimeGetVersion(&cuda_runtime_version);\n    if (err == cudaSuccess) {\n      LOG(INFO) << \"CUDA runtime version: \" << GetCudaVersionString(cuda_runtime_version);\n    } else {\n      LOG(ERROR) << \"Failed to get cuda runtime version: \" << cudaGetErrorString(err);\n    }\n  }\n\n  {\n    std::string cudnn_version_string;\n    if (GetCudnnVersionString(&cudnn_version_string)) {\n      LOG(INFO) << \"cuDNN version: \" << cudnn_version_string;\n    }\n  }\n\n  {\n    int nccl_version = 0;\n    ncclResult_t result = ncclGetVersion(&nccl_version);\n    if (result == ncclSuccess) {\n      int nccl_version_major =\n          (nccl_version >= 20900) ? (nccl_version / 10000) : (nccl_version / 1000);\n      int nccl_version_minor =\n          (nccl_version >= 20900) ? 
(nccl_version % 10000) / 100 : (nccl_version % 1000) / 100;\n      int nccl_version_patch = (nccl_version % 100);\n      LOG(INFO) << \"NCCL version: \" << nccl_version_major << \".\" << nccl_version_minor << \".\"\n                << nccl_version_patch;\n    } else {\n      LOG(ERROR) << \"Failed to get NCCL version: \" << ncclGetErrorString(result);\n    }\n  }\n}\n\nclass CudaDeviceManagerFactory : public DeviceManagerFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaDeviceManagerFactory);\n  CudaDeviceManagerFactory() = default;\n  ~CudaDeviceManagerFactory() override = default;\n\n  std::unique_ptr<DeviceManager> NewDeviceManager(DeviceManagerRegistry* registry) override {\n    return std::make_unique<CudaDeviceManager>(registry);\n  }\n\n  DeviceType device_type() const override { return DeviceType::kCUDA; }\n\n  std::string device_type_name() const override { return \"cuda\"; }\n\n  void DumpVersionInfo() const override { CudaDumpVersionInfo(); }\n};\n\nCOMMAND(DeviceManagerRegistry::RegisterDeviceManagerFactory(\n    std::make_unique<CudaDeviceManagerFactory>()))\n\n}  // namespace\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
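The branching on `20900` in `CudaDumpVersionInfo` reflects NCCL's change of integer version encoding at 2.9: before 2.9 the code is `major * 1000 + minor * 100 + patch`, afterwards `major * 10000 + minor * 100 + patch`. A tiny worked example of the same decode, as an editor's illustration:

```cpp
// Editor's sketch: decoding NCCL integer version codes on both sides of 2.9.
#include <cstdio>

static void Decode(int v) {
  const int major = (v >= 20900) ? v / 10000 : v / 1000;
  const int minor = (v >= 20900) ? (v % 10000) / 100 : (v % 1000) / 100;
  const int patch = v % 100;
  std::printf("%d -> %d.%d.%d\n", v, major, minor, patch);
}

int main() {
  Decode(2804);   // pre-2.9 encoding: prints 2804 -> 2.8.4
  Decode(21205);  // post-2.9 encoding: prints 21205 -> 2.12.5
  return 0;
}
```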
  {
    "path": "oneflow/core/ep/cuda/cuda_event.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/cuda_event.h\"\n\n#ifdef WITH_CUDA\n\nnamespace oneflow {\n\nnamespace ep {\n\nCudaEvent::CudaEvent(unsigned int flags) : cuda_event_{} {\n  OF_CUDA_CHECK(cudaEventCreateWithFlags(&cuda_event_, flags));\n}\n\nCudaEvent::~CudaEvent() { OF_CUDA_CHECK(cudaEventDestroy(cuda_event_)); }\n\nMaybe<bool> CudaEvent::QueryDone() {\n  cudaError_t err = cudaEventQuery(cuda_event_);\n  if (err == cudaSuccess) {\n    return Maybe<bool>(true);\n  } else if (err == cudaErrorNotReady) {\n    return Maybe<bool>(false);\n  } else {\n    return Error::RuntimeError() << cudaGetErrorString(err);\n  }\n}\n\nMaybe<void> CudaEvent::Sync() {\n  cudaError_t err = cudaEventSynchronize(cuda_event_);\n  if (err == cudaSuccess) {\n    return Maybe<void>::Ok();\n  } else {\n    return Error::RuntimeError() << cudaGetErrorString(err);\n  }\n}\n\ncudaEvent_t CudaEvent::cuda_event() { return cuda_event_; }\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
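
CudaEvent::QueryDone folds the three cudaEventQuery outcomes into Maybe<bool>: cudaSuccess becomes true, cudaErrorNotReady becomes false, and any other status surfaces as a RuntimeError. A hedged usage sketch, assuming OneFlow's CHECK_JUST macro is available and the event was already recorded via CudaStream::RecordEvent:

#include "oneflow/core/ep/cuda/cuda_event.h"

// Sketch: busy-polling an ep::CudaEvent until its recorded work completes
// (illustration only; Sync() is the blocking alternative and is usually
// what callers want).
void SpinUntilDone(oneflow::ep::CudaEvent* event) {
  // QueryDone(): cudaSuccess -> true, cudaErrorNotReady -> false,
  // any other status -> RuntimeError carried by Maybe.
  while (!CHECK_JUST(event->QueryDone())) {}
}
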
  {
    "path": "oneflow/core/ep/cuda/cuda_event.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_CUDA_CUDA_EVENT_H_\n#define ONEFLOW_CORE_EP_CUDA_CUDA_EVENT_H_\n\n#include \"oneflow/core/ep/include/event.h\"\n\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/device/cuda_util.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass CudaEvent : public Event {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaEvent);\n  explicit CudaEvent(unsigned int flags);\n  ~CudaEvent() override;\n\n  Maybe<bool> QueryDone() override;\n  Maybe<void> Sync() override;\n\n  cudaEvent_t cuda_event();\n\n private:\n  cudaEvent_t cuda_event_;\n};\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_CORE_EP_CUDA_CUDA_EVENT_H_\n"
  },
  {
    "path": "oneflow/core/ep/cuda/cuda_matmul_mode.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/ep/cuda/cuda_matmul_mode.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace {\n\nbool* GetMatmulAllowTF32() {\n  static bool matmul_allow_tf32 = true;\n  return &matmul_allow_tf32;\n}\n\nbool* GetMatmulAllowFP16ReducedPrecisionReducton() {\n  static bool matmul_allow_fp16_reduced_precision_reduction = true;\n  return &matmul_allow_fp16_reduced_precision_reduction;\n}\n\n}  // namespace\n\nbool CudaMatmulMode::is_matmul_allow_tf32() { return *GetMatmulAllowTF32(); }\n\nvoid CudaMatmulMode::set_matmul_allow_tf32(bool matmul_allow_tf32) {\n  *GetMatmulAllowTF32() = matmul_allow_tf32;\n}\n\nbool CudaMatmulMode::is_matmul_allow_fp16_reduced_precision_reduction() {\n  return *GetMatmulAllowFP16ReducedPrecisionReducton();\n}\n\nvoid CudaMatmulMode::set_matmul_allow_fp16_reduced_precision_reduction(\n    bool matmul_allow_fp16_reduced_precision_reduction) {\n  *GetMatmulAllowFP16ReducedPrecisionReducton() = matmul_allow_fp16_reduced_precision_reduction;\n}\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
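
The getters above use function-local statics (the Meyers-singleton idiom): initialization is lazy and thread-safe since C++11, and returning a pointer lets the getter and setter share one storage location without a namespace-scope global whose initialization order could be problematic. Note that the writes themselves are unsynchronized, matching the semantics above. A minimal sketch of the idiom with a hypothetical flag name:

#include <cstdio>

// Minimal sketch of the function-local-static flag idiom; GetExampleFlag is
// a hypothetical name, not part of OneFlow.
bool* GetExampleFlag() {
  static bool flag = true;  // constructed once; initialization is thread-safe since C++11
  return &flag;
}

int main() {
  std::printf("%d\n", *GetExampleFlag());  // 1: the default
  *GetExampleFlag() = false;               // setter and getter share one storage
  std::printf("%d\n", *GetExampleFlag());  // 0
  return 0;
}
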
  {
    "path": "oneflow/core/ep/cuda/cuda_matmul_mode.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_EP_CUDA_MATMUL_MODE_H_\n#define ONEFLOW_CORE_EP_CUDA_MATMUL_MODE_H_\n\nnamespace oneflow {\nnamespace ep {\n\nstruct CudaMatmulMode {\n  static bool is_matmul_allow_tf32();\n  static void set_matmul_allow_tf32(bool matmul_allow_tf32);\n  static bool is_matmul_allow_fp16_reduced_precision_reduction();\n  static void set_matmul_allow_fp16_reduced_precision_reduction(\n      bool matmul_allow_fp16_reduced_precision_reduction);\n};\n\n}  // namespace ep\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_CUDA_MATMUL_MODE_H_\n"
  },
  {
    "path": "oneflow/core/ep/cuda/cuda_random_generator.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/ep/cuda/cuda_random_generator.h\"\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include <cuda.h>\n#include <cuda_runtime.h>\n\nnamespace oneflow {\nnamespace ep {\n\nnamespace {\n\nint GetThreadNum(const cudaDeviceProp& prop) {\n  switch (prop.major) {\n    case 3:  // Kepler\n      return 2 * 192;\n    case 5:  // Maxwell\n      return 2 * 128;\n    case 6:  // Pascal\n      if ((prop.minor == 1) || (prop.minor == 2)) {\n        return 2 * 128;\n      } else {\n        return 2 * 64;\n      }\n    case 7:  // Volta and Turing\n      return 2 * 64;\n    default: return 2 * 64;\n  }\n}\n\n}  // namespace\n\nCUDAGenerator::CUDAGenerator(uint64_t seed, int device_index)\n    : RandomGenerator(), seed_(seed), device_index_(device_index), philox_offset_per_thread_(0) {\n  int device_count;\n  OF_CUDA_CHECK(cudaGetDeviceCount(&device_count));\n  CHECK_LT_OR_THROW(device_index, device_count)\n      << \"only \" << device_count << \" cuda devices are visible.\";\n  cudaDeviceProp prop;  // NOLINT\n  OF_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_index));\n  max_block_num_ = prop.multiProcessorCount;\n  max_thread_num_ = GetThreadNum(prop);\n}\n\nvoid CUDAGenerator::set_current_seed(uint64_t seed) {\n  seed_ = seed;\n  philox_offset_per_thread_ = 0;\n}\n\nstd::tuple<uint64_t, dim3, dim3> CUDAGenerator::CalcExecutionPolicy(int64_t total_elements,\n                                                                    ep::CudaStream* stream) {\n  // NOTE(Liang Depeng): the implementation is modified from\n  // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/DistributionTemplates.h\n\n  const uint64_t numel = static_cast<uint64_t>(total_elements);\n  const uint32_t block_size = 256;  // block_size_bound\n  // number of randoms given by distributions like curand_uniform4, curand_uniform2_double\n  // used in calculating philox offset.\n  const uint32_t curand4_engine_calls = 4;\n  const uint32_t unroll = curand4_engine_calls;\n  dim3 dim_block(block_size);\n  dim3 grid((numel + block_size - 1) / block_size);\n  uint32_t blocks_per_sm = stream->device_properties().maxThreadsPerMultiProcessor / block_size;\n  grid.x = std::min(\n      static_cast<uint32_t>(stream->device_properties().multiProcessorCount) * blocks_per_sm,\n      grid.x);\n  // number of times random will be generated per thread, to offset philox counter in thc random\n  // state\n  uint64_t counter_offset =\n      ((numel - 1) / (block_size * grid.x * unroll) + 1) * curand4_engine_calls;\n  return std::make_tuple(counter_offset, grid, dim_block);\n}\n\n// NOTE(Liang Depeng): The implementation of ` CUDAGenerator::get_philox_offset` is modified\n// from\n//      https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/CUDAGenerator.cpp#L269\n//      in order 
to make distribution-related CUDA kernels produce the same output as PyTorch\n//      when setting the same seed.\nuint64_t CUDAGenerator::get_philox_offset(uint64_t increment) {\n  std::lock_guard<std::mutex> lock(mutex_);\n  // rounds increment up to the nearest multiple of 4\n  increment = ((increment + 3) / 4) * 4;\n  CHECK_EQ(this->philox_offset_per_thread_ % 4, 0);\n  uint64_t offset = this->philox_offset_per_thread_;\n  this->philox_offset_per_thread_ += increment;\n  return offset;\n}\n\n// NOTE: The RNG state comprises the seed and an offset used for Philox.\n// The following line exists only to align with PyTorch; it has no practical\n// effect and is kept in PyTorch purely for backward compatibility.\n// For more details please refer to:\n// https://github.com/pytorch/pytorch/blob/v1.13.1/aten/src/ATen/cuda/CUDAGenerator.cpp#L152\nstatic constexpr size_t states_size = 200 * sizeof(4120);\nstatic constexpr size_t seed_size = sizeof(uint64_t);\nstatic constexpr size_t offset_size = sizeof(int64_t);\nstatic constexpr size_t total_size = states_size + seed_size + offset_size;\n\nsize_t CUDAGenerator::GetStateSize() const { return total_size; }\n\nvoid CUDAGenerator::GetState(size_t state_size, void* state) const {\n  CHECK_EQ_OR_THROW(state_size, GetStateSize())\n      << \"the state size of cuda generator should be equal to \" << GetStateSize();\n  memset(static_cast<uint8_t*>(state), -1, states_size);\n  memcpy(static_cast<uint8_t*>(state) + states_size, &seed_, seed_size);\n  memcpy(static_cast<uint8_t*>(state) + states_size + seed_size, &philox_offset_per_thread_,\n         offset_size);\n}\n\nvoid CUDAGenerator::SetState(size_t state_size, const void* state) {\n  CHECK_EQ_OR_THROW(state_size, GetStateSize())\n      << \"the state size of cuda generator should be equal to \" << GetStateSize();\n  const uint8_t* data = static_cast<const uint8_t*>(state);\n  seed_ = *((uint64_t*)(data + states_size));\n  philox_offset_per_thread_ = *((uint64_t*)(data + states_size + seed_size));\n}\n\ntemplate<>\nstd::string GetRandomGeneratorDeviceTypeName<CUDAGenerator>() {\n  return \"cuda\";\n}\n\n}  // namespace ep\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
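
To make the CalcExecutionPolicy arithmetic concrete: each thread consumes counter_offset Philox states, computed by rounding the per-thread element count up to a multiple of curand4_engine_calls (4), and get_philox_offset likewise rounds every increment up to a multiple of 4 so offsets stay aligned. A worked sketch with assumed device properties (80 SMs, 2048 resident threads per SM; these numbers are illustrative, not from the source):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Worked example of the counter_offset arithmetic in CalcExecutionPolicy,
// with assumed device properties standing in for cudaDeviceProp.
int main() {
  const uint64_t numel = 1 << 20;   // 1,048,576 elements
  const uint32_t block_size = 256;  // block_size_bound
  const uint32_t unroll = 4;        // curand4_engine_calls
  uint32_t grid = (numel + block_size - 1) / block_size;  // 4096 blocks requested
  const uint32_t blocks_per_sm = 2048 / block_size;       // 8 blocks per SM
  grid = std::min(80U * blocks_per_sm, grid);             // capped at 640 blocks
  const uint64_t counter_offset =
      ((numel - 1) / (static_cast<uint64_t>(block_size) * grid * unroll) + 1) * unroll;
  // Each thread makes two rounds of curand_uniform4-style calls: 2 * 4 = 8.
  std::printf("grid=%u counter_offset=%llu\n", grid,
              static_cast<unsigned long long>(counter_offset));
  return 0;
}
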
  {
    "path": "oneflow/core/ep/cuda/cuda_random_generator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_CUDA_RANDOM_GENERATOR_H_\n#define ONEFLOW_CORE_EP_CUDA_RANDOM_GENERATOR_H_\n\n#ifdef WITH_CUDA\n\n#include <mutex>\n#include <curand.h>\n#include <curand_kernel.h>\n\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/ep/include/random_generator.h\"\n\nnamespace oneflow {\nnamespace ep {\n\nclass CUDAGenerator : public RandomGenerator {\n public:\n  explicit CUDAGenerator(uint64_t seed, int device_index);\n  virtual ~CUDAGenerator() = default;\n\n  int32_t max_block_num() const { return max_block_num_; }\n  int32_t max_thread_num() const { return max_thread_num_; }\n\n  uint64_t current_seed() const override { return seed_; }\n  void set_current_seed(uint64_t seed) override;\n\n  std::string device_type_name() const override { return \"cuda\"; }\n  int64_t device_index() const override { return device_index_; }\n\n  size_t GetStateSize() const override;\n  void GetState(size_t state_size, void* state) const override;\n  void SetState(size_t state_size, const void* state) override;\n\n  std::tuple<uint64_t, dim3, dim3> CalcExecutionPolicy(int64_t total_elements, CudaStream* stream);\n\n  uint64_t get_philox_offset(uint64_t increment);\n\n public:\n  mutable std::mutex mutex_;\n\n private:\n  uint64_t seed_;\n  int64_t device_index_;\n  int32_t max_block_num_;\n  int32_t max_thread_num_;\n  uint64_t philox_offset_per_thread_ = 0;\n};\n\n}  // namespace ep\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_CORE_EP_CUDA_RANDOM_GENERATOR_H_\n"
  },
  {
    "path": "oneflow/core/ep/cuda/cuda_stream.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/hardware/node_device_descriptor_manager.h\"\n#include \"oneflow/core/hardware/cuda_device_descriptor.h\"\n#include \"oneflow/core/ep/cuda/cuda_event.h\"\n#include \"oneflow/core/ep/cuda/cuda_device.h\"\n\n#ifdef WITH_CUDA\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace {\n\nconstexpr size_t kDefaultWorkspaceSizeMb = 4;  // 4M\n\nvoid SetAffinityByDevice(int dev_id) {\n  auto node_device_desc_mgr = Singleton<hardware::NodeDeviceDescriptorManager>::Get();\n  if (node_device_desc_mgr == nullptr) { return; }\n  auto node_device_desc = node_device_desc_mgr->GetLocalNodeDeviceDescriptor();\n  auto cuda_device = std::dynamic_pointer_cast<const hardware::CudaDeviceDescriptor>(\n      node_device_desc->GetDevice(hardware::kCudaDeviceDescriptorClassName, dev_id));\n  if (!cuda_device) { return; }\n  node_device_desc->Topology()->SetCPUAffinityByPCIBusID(cuda_device->PCIBusID());\n  node_device_desc->Topology()->SetMemoryAffinityByPCIBusID(cuda_device->PCIBusID());\n}\n\nvoid CheckVersionCompatibility(int compiletime_major, int compiletime_minor, int runtime_major,\n                               int runtime_minor, const std::string& name) {\n  if (runtime_major != compiletime_major || runtime_minor < compiletime_minor) {\n    LOG(WARNING) << \"Runtime version \" << runtime_major << \".\" << runtime_minor << \" of \" << name\n                 << \" incompatible with compiletime version \" << compiletime_major << \".\"\n                 << compiletime_minor << \".\";\n  }\n}\n\nvoid CheckCudaRuntimeVersion() {\n#if !defined(CUDART_VERSION)\n#error\n#endif  // !defined(CUDART_VERSION)\n  const int compiletime_major = CUDART_VERSION / 1000;\n  const int compiletime_minor = CUDART_VERSION % 1000 / 10;\n  int runtime_version = 0;\n  OF_CUDA_CHECK(cudaRuntimeGetVersion(&runtime_version));\n  const int runtime_major = runtime_version / 1000;\n  const int runtime_minor = runtime_version % 1000 / 10;\n  CheckVersionCompatibility(compiletime_major, compiletime_minor, runtime_major, runtime_minor,\n                            \"CUDA Runtime\");\n}\n\nvoid CheckCublasVersion(cublasHandle_t handle) {\n#if CUDA_VERSION >= 10020\n#if (!defined(CUBLAS_VER_MAJOR)) || (!defined(CUBLAS_VER_MINOR))\n#error\n#endif  // (!defined(CUBLAS_VER_MAJOR)) || (!defined(CUBLAS_VER_MINOR))\n  int runtime_version = 0;\n  OF_CUBLAS_CHECK(cublasGetVersion(handle, &runtime_version));\n  int runtime_major = 0;\n  int runtime_minor = 0;\n  if (runtime_version >= 100000) {\n    runtime_major = runtime_version / 10000;\n    runtime_minor = runtime_version % 10000 / 100;\n  } else {\n    runtime_major = runtime_version / 1000;\n    runtime_minor = runtime_version % 1000 / 100;\n  }\n  CheckVersionCompatibility(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, runtime_major, 
runtime_minor,\n                            \"cuBLAS\");\n#endif  // CUDA_VERSION >= 10020\n}\n\nvoid CheckCudnnVersion() {\n#if (!defined(CUDNN_MAJOR)) || (!defined(CUDNN_MINOR))\n#error\n#endif  // (!defined(CUDNN_MAJOR)) || (!defined(CUDNN_MINOR))\n  int runtime_major = 0;\n  int runtime_minor = 0;\n  OF_CUDNN_CHECK(cudnnGetProperty(libraryPropertyType::MAJOR_VERSION, &runtime_major));\n  OF_CUDNN_CHECK(cudnnGetProperty(libraryPropertyType::MINOR_VERSION, &runtime_minor));\n  CheckVersionCompatibility(CUDNN_MAJOR, CUDNN_MINOR, runtime_major, runtime_minor, \"cuDNN\");\n}\n\n}  // namespace\n\n#ifdef WITH_CUDA_GRAPHS\n\nCudaGraphExecutable::CudaGraphExecutable() : graph_exec_(nullptr), dev_(-1) {}\n\nCudaGraphExecutable::~CudaGraphExecutable() { Reset(); }\n\nvoid CudaGraphExecutable::Update(cudaGraph_t graph) {\n  int dev = -1;\n  OF_CUDA_CHECK(cudaGetDevice(&dev));\n  if (dev != dev_) { Reset(); }\n  dev_ = dev;\n  if (graph_exec_ != nullptr) {\n#if CUDA_VERSION < 12000\n    cudaGraphExecUpdateResult update_result{};\n    cudaGraphNode_t error_node = nullptr;\n    OF_CUDA_CHECK(cudaGraphExecUpdate(graph_exec_, graph, &error_node, &update_result));\n    if (update_result == cudaGraphExecUpdateSuccess) { return; }\n#else\n    cudaGraphExecUpdateResultInfo update_result{};\n    OF_CUDA_CHECK(cudaGraphExecUpdate(graph_exec_, graph, &update_result));\n    if (update_result.result == cudaGraphExecUpdateSuccess) { return; }\n#endif  // CUDA_VERSION < 12000\n  }\n  Reset();\n  OF_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph, NULL, NULL, 0));\n}\n\nvoid CudaGraphExecutable::Launch(cudaStream_t stream) const {\n  OF_CUDA_CHECK(cudaGraphLaunch(graph_exec_, stream));\n}\n\nbool CudaGraphExecutable::IsInstantiated() const { return graph_exec_ != nullptr; }\n\nvoid CudaGraphExecutable::Reset() {\n  if (graph_exec_ != nullptr) {\n    CudaCurrentDeviceGuard guard(dev_);\n    OF_CUDA_CHECK(cudaGraphExecDestroy(graph_exec_));\n  }\n}\n\n#endif  // WITH_CUDA_GRAPHS\n\nCudaStream::CudaStream(CudaDevice* device)\n    : device_index_(device->device_index()), device_(device) {\n  CudaCurrentDeviceGuard guard(device_index_);\n\n  const bool need_check_version = []() {\n    static std::atomic<bool> version_checked(false);\n    return version_checked.exchange(true) == false;\n  }();\n\n  if (need_check_version) { CheckCudaRuntimeVersion(); }\n\n  // cuda_stream\n  const char* stream_flags_env_name = \"ONEFLOW_EP_CUDA_STREAM_FLAGS\";\n  if (std::getenv(stream_flags_env_name) != nullptr) {\n    const unsigned int stream_flags = ParseIntegerFromEnv(stream_flags_env_name, 0);\n    OF_CUDA_CHECK(cudaStreamCreateWithFlags(&cuda_stream_, stream_flags));\n  } else {\n    OF_CUDA_CHECK(cudaStreamCreate(&cuda_stream_));\n  }\n  // cublas_handle\n  OF_CUBLAS_CHECK(cublasCreate(&cublas_handle_));\n  OF_CUBLAS_CHECK(cublasSetStream(cublas_handle_, cuda_stream_));\n  if (need_check_version) { CheckCublasVersion(cublas_handle_); }\n#if CUDA_VERSION >= 10010\n  // cublas_lt_handle\n  OF_CUBLAS_CHECK(cublasLtCreate(&cublas_lt_handle_));\n#endif\n#if CUBLAS_VERSION >= 11000\n  if (ParseBooleanFromEnv(\"ONEFLOW_EP_CUDA_ENABLE_TF32_EXECUTION\", true)) {\n    OF_CUBLAS_CHECK(cublasSetMathMode(cublas_handle_, CUBLAS_TF32_TENSOR_OP_MATH));\n  }\n#endif  // CUBLAS_VERSION >= 11000\n  // cusolver_dn_handle\n#if CUDA_VERSION >= 11000\n  OF_CUSOLVER_CHECK(cusolverDnCreate(&cusolver_dn_handle_));\n  OF_CUSOLVER_CHECK(cusolverDnSetStream(cusolver_dn_handle_, cuda_stream_));\n#endif\n  workspace_size_ =\n      
ParseIntegerFromEnv(\"ONEFLOW_EP_CUDA_CUBLAS_WORKSPACE_SIZE_MB\", kDefaultWorkspaceSizeMb)\n      * 1024 * 1024;\n  OF_CUDA_CHECK(cudaMalloc(&workspace_, workspace_size_));\n#if CUBLAS_VERSION >= 11200\n  OF_CUBLAS_CHECK(cublasSetWorkspace(cublas_handle_, workspace_, workspace_size_));\n#endif  // CUBLAS_VERSION >= 11200\n  // cudnn_handle\n  OF_CUDNN_CHECK(cudnnCreate(&cudnn_handle_));\n  OF_CUDNN_CHECK(cudnnSetStream(cudnn_handle_, cuda_stream_));\n  if (need_check_version) { CheckCudnnVersion(); }\n}\n\nCudaStream::~CudaStream() {\n  CudaCurrentDeviceGuard guard(device_index_);\n  OF_CUDA_CHECK(cudaStreamSynchronize(cuda_stream_));\n  OF_CUDNN_CHECK(cudnnDestroy(cudnn_handle_));\n  OF_CUBLAS_CHECK(cublasDestroy(cublas_handle_));\n#if CUDA_VERSION >= 11000\n  OF_CUSOLVER_CHECK(cusolverDnDestroy(cusolver_dn_handle_));\n#endif\n#if CUDA_VERSION >= 10010\n  OF_CUBLAS_CHECK(cublasLtDestroy(cublas_lt_handle_));\n#endif\n  OF_CUDA_CHECK(cudaStreamDestroy(cuda_stream_));\n  OF_CUDA_CHECK(cudaFree(workspace_));\n}\n\nMaybe<void> CudaStream::OnExecutionContextSetup() {\n  OF_CUDA_CHECK(cudaSetDevice(device_index_));\n  SetAffinityByDevice(device_index_);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CudaStream::OnExecutionContextTeardown() { return Maybe<void>::Ok(); }\n\nDeviceType CudaStream::device_type() const { return DeviceType::kCUDA; }\n\nCudaDevice* CudaStream::device() const { return device_; }\n\nMaybe<void> CudaStream::Sync() {\n  cudaError_t err = cudaStreamSynchronize(cuda_stream_);\n  if (err == cudaSuccess) {\n    return Maybe<void>::Ok();\n  } else {\n    return Error::RuntimeError() << cudaGetErrorString(err) << \" (\" << err << \") \";\n  }\n}\n\nvoid CudaStream::RecordEvent(Event* event) {\n  auto* cuda_event = static_cast<CudaEvent*>(event);  // NOLINT\n  OF_CUDA_CHECK(cudaEventRecord(cuda_event->cuda_event(), cuda_stream_));\n}\n\nvoid CudaStream::WaitEvent(Event* event) {\n  auto* cuda_event = static_cast<CudaEvent*>(event);  // NOLINT\n  OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream_, cuda_event->cuda_event(), 0));\n}\n\nMaybe<void> CudaStream::GetAsyncError() {\n  cudaError_t err = cudaGetLastError();\n  if (err == cudaSuccess) {\n    return Maybe<void>::Ok();\n  } else {\n    return Error::RuntimeError() << cudaGetErrorString(err) << \" (\" << err << \") \";\n  }\n}\n\nMaybe<void> CudaStream::AllocAsync(void** ptr, size_t size) {\n#if CUDA_VERSION >= 11020\n  if (!device_->IsStreamOrderedMemoryAllocationSupported()) { UNIMPLEMENTED_THEN_RETURN(); }\n  cudaError_t err = cudaMallocFromPoolAsync(ptr, size, device_->mem_pool(), cuda_stream_);\n  if (err == cudaSuccess) {\n    return Maybe<void>::Ok();\n  } else {\n    return Error::RuntimeError() << cudaGetErrorString(err) << \" (\" << err << \") \";\n  }\n#else\n  UNIMPLEMENTED_THEN_RETURN();\n#endif  // CUDA_VERSION >= 11020\n}\n\nMaybe<void> CudaStream::FreeAsync(void* ptr) {\n#if CUDA_VERSION >= 11020\n  if (!device_->IsStreamOrderedMemoryAllocationSupported()) { UNIMPLEMENTED_THEN_RETURN(); }\n  cudaError_t err = cudaFreeAsync(ptr, cuda_stream_);\n  if (err == cudaSuccess) {\n    return Maybe<void>::Ok();\n  } else {\n    return Error::RuntimeError() << cudaGetErrorString(err) << \" (\" << err << \") \";\n  }\n#else\n  UNIMPLEMENTED_THEN_RETURN();\n#endif  // CUDA_VERSION >= 11020\n}\n\ncudaStream_t CudaStream::cuda_stream() const { return cuda_stream_; }\n\ncublasHandle_t CudaStream::cublas_handle() const { return cublas_handle_; }\n\n#if CUDA_VERSION >= 11000\ncusolverDnHandle_t CudaStream::cusolver_dn_handle() const 
{ return cusolver_dn_handle_; }\n#endif\n\n#if CUDA_VERSION >= 10010\ncublasLtHandle_t CudaStream::cublas_lt_handle() const { return cublas_lt_handle_; }\n#endif\n\nvoid* CudaStream::cublas_workspace() const { return workspace_; }\n\nsize_t CudaStream::cublas_workspace_size() const { return workspace_size_; }\n\ncudnnHandle_t CudaStream::cudnn_handle() const { return cudnn_handle_; }\n\nconst cudaDeviceProp& CudaStream::device_properties() const { return device_->properties(); }\n\nint CudaStream::cuda_arch() const {\n  return device_->properties().major * 100 + device_->properties().minor * 10;\n}\n\n#ifdef WITH_CUDA_GRAPHS\n\nvoid CudaStream::BeginGraphCapture() {\n  CHECK(!is_graph_capturing_);\n  is_graph_capturing_ = true;\n  OF_CUDA_CHECK(cudaStreamBeginCapture(cuda_stream_, cudaStreamCaptureModeThreadLocal));\n}\n\nvoid CudaStream::EndGraphCapture(CudaGraphExecutable* executable) {\n  cudaGraph_t graph = nullptr;\n  OF_CUDA_CHECK(cudaStreamEndCapture(cuda_stream_, &graph));\n  executable->Update(graph);\n  OF_CUDA_CHECK(cudaGraphDestroy(graph));\n  is_graph_capturing_ = false;\n}\n\nbool CudaStream::IsGraphCapturing() const { return is_graph_capturing_; }\n\nvoid CudaStream::LaunchGraph(const CudaGraphExecutable* executable) {\n  executable->Launch(cuda_stream_);\n}\n\n#endif  // WITH_CUDA_GRAPHS\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
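
The rule in CheckVersionCompatibility is deliberately one-sided: the runtime library must have the same major version as the headers the binary was compiled against and an equal-or-newer minor; a newer minor is fine, while an older minor or any major mismatch only logs a warning rather than failing. A small sketch of the predicate with illustrative version pairs:

#include <cstdio>

// Sketch of the predicate behind CheckVersionCompatibility; the version
// pairs below are illustrative.
bool IsCompatible(int ct_major, int ct_minor, int rt_major, int rt_minor) {
  return rt_major == ct_major && rt_minor >= ct_minor;
}

int main() {
  std::printf("%d\n", IsCompatible(11, 4, 11, 8));  // 1: newer minor is fine
  std::printf("%d\n", IsCompatible(11, 4, 11, 2));  // 0: older minor warns
  std::printf("%d\n", IsCompatible(11, 4, 12, 0));  // 0: major mismatch warns
  return 0;
}
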
  {
    "path": "oneflow/core/ep/cuda/cuda_stream.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_CUDA_CUDA_STREAM_H_\n#define ONEFLOW_CORE_EP_CUDA_CUDA_STREAM_H_\n\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/ep/cuda/cuda_device.h\"\n\n#ifdef WITH_CUDA\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n#if CUDA_VERSION >= 11000\n#define WITH_CUDA_GRAPHS\n#endif  // CUDA_VERSION >= 11000\n\n#include \"oneflow/core/device/cuda_util.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass CudaDevice;\n\n#ifdef WITH_CUDA_GRAPHS\n\nclass CudaGraphExecutable {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaGraphExecutable);\n  CudaGraphExecutable();\n  ~CudaGraphExecutable();\n\n  void Update(cudaGraph_t graph);\n  void Launch(cudaStream_t stream) const;\n  bool IsInstantiated() const;\n\n private:\n  void Reset();\n\n  cudaGraphExec_t graph_exec_;\n  int dev_;\n};\n\n#endif  // WITH_CUDA_GRAPHS\n\nstruct CudaLaunchConfig {\n  dim3 grid_dim;\n  dim3 block_dim;\n  size_t shared_mem_size;\n  CudaLaunchConfig() : grid_dim{}, block_dim{}, shared_mem_size(0) {}\n\n  CudaLaunchConfig(unsigned int grid_size, unsigned int block_size, size_t shared_mem_size)\n      : grid_dim(grid_size), block_dim(block_size), shared_mem_size(shared_mem_size) {}\n};\n\nclass CudaStream : public Stream {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaStream);\n  explicit CudaStream(CudaDevice* device);\n  ~CudaStream() override;\n\n  static constexpr uint32_t kDefaultBlockSize = 256;\n\n  DeviceType device_type() const override;\n  CudaDevice* device() const override;\n  Maybe<void> Sync() override;\n  void RecordEvent(Event* event) override;\n  void WaitEvent(Event* event) override;\n  Maybe<void> GetAsyncError() override;\n\n  Maybe<void> AllocAsync(void** ptr, size_t size) override;\n  Maybe<void> FreeAsync(void* ptr) override;\n\n  Maybe<void> OnExecutionContextSetup() override;\n  Maybe<void> OnExecutionContextTeardown() override;\n\n  cudaStream_t cuda_stream() const;\n  cublasHandle_t cublas_handle() const;\n#if CUDA_VERSION >= 11000\n  cusolverDnHandle_t cusolver_dn_handle() const;\n#endif\n\n#if CUDA_VERSION >= 10010\n\n  cublasLtHandle_t cublas_lt_handle() const;\n\n#endif\n\n  cudnnHandle_t cudnn_handle() const;\n  void* cublas_workspace() const;\n  size_t cublas_workspace_size() const;\n  const cudaDeviceProp& device_properties() const;\n  int cuda_arch() const;\n\n  void InitLaunchConfigWithWaves(CudaLaunchConfig* config, size_t elem_cnt, size_t block_size,\n                                 size_t max_waves) const {\n    const uint32_t max_grid_size = max_waves * device_properties().multiProcessorCount\n                                   * (device_properties().maxThreadsPerMultiProcessor / block_size);\n    const uint32_t grid_size =\n        std::min<uint32_t>(max_grid_size, (elem_cnt + block_size - 1) / block_size);\n    config->grid_dim = dim3(grid_size);\n    config->block_dim = dim3(block_size);\n    config->shared_mem_size = 0;\n  }\n\n#ifdef 
__CUDACC__\n  template<typename... Params, typename... Args>\n  void LaunchKernel(void (*kernel)(Params...), const CudaLaunchConfig& launch_config,\n                    Args... args) {\n    kernel<<<launch_config.grid_dim, launch_config.block_dim, launch_config.shared_mem_size,\n             cuda_stream()>>>(args...);\n  }\n\n  template<typename... Params, typename... Args>\n  void LaunchKernel(void (*kernel)(Params...), size_t elem_cnt, size_t max_waves, Args... args) {\n    constexpr uint32_t block_size = kDefaultBlockSize;\n    CudaLaunchConfig config{};\n    InitLaunchConfigWithWaves(&config, elem_cnt, block_size, max_waves);\n    LaunchKernel(kernel, config, args...);\n  }\n\n  template<typename... Params, typename... Args>\n  void LaunchKernelDefaultWaves(void (*kernel)(Params...), size_t elem_cnt, Args... args) {\n    const size_t default_waves = 32;\n    LaunchKernel(kernel, elem_cnt, default_waves, args...);\n  }\n#endif  // __CUDACC__\n\n#ifdef WITH_CUDA_GRAPHS\n  void BeginGraphCapture();\n  void EndGraphCapture(CudaGraphExecutable* executable);\n  bool IsGraphCapturing() const;\n  void LaunchGraph(const CudaGraphExecutable* executable);\n#endif  // WITH_CUDA_GRAPHS\n\n private:\n  cudaStream_t cuda_stream_{};\n  cublasHandle_t cublas_handle_{};\n#if CUDA_VERSION >= 11000\n  cusolverDnHandle_t cusolver_dn_handle_{};\n#endif\n\n#if CUDA_VERSION >= 10010\n\n  cublasLtHandle_t cublas_lt_handle_{};\n\n#endif\n\n  cudnnHandle_t cudnn_handle_{};\n  int device_index_;\n  void* workspace_{};\n  size_t workspace_size_{};\n#ifdef WITH_CUDA_GRAPHS\n  bool is_graph_capturing_{};\n#endif  // WITH_CUDA_GRAPHS\n  CudaDevice* device_;\n};\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_CORE_EP_CUDA_CUDA_STREAM_H_\n"
  },
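
InitLaunchConfigWithWaves sizes the grid in "waves": one wave is the number of blocks the device can keep resident at once (SM count times maxThreadsPerMultiProcessor divided by block_size), and the grid is capped at max_waves of those, so huge element counts do not turn into unboundedly large grids. A worked sketch with assumed device properties (108 SMs, 2048 resident threads per SM; illustrative values only):

#include <algorithm>
#include <cstdio>

// Worked example of InitLaunchConfigWithWaves' grid arithmetic with assumed
// device properties standing in for cudaDeviceProp.
int main() {
  const size_t elem_cnt = 1000000;
  const size_t block_size = 256;
  const size_t max_waves = 4;
  const size_t sm_count = 108;
  const size_t threads_per_sm = 2048;
  const size_t blocks_per_wave = sm_count * (threads_per_sm / block_size);  // 864
  const size_t max_grid = max_waves * blocks_per_wave;                      // 3456
  const size_t grid =
      std::min(max_grid, (elem_cnt + block_size - 1) / block_size);  // min(3456, 3907)
  std::printf("grid=%zu\n", grid);  // 3456
  return 0;
}
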
  {
    "path": "oneflow/core/ep/cuda/primitive/add.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/add.h\"\n#include \"oneflow/core/ep/cuda/primitive/type_seq.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/device/cuda_pseudo_bfloat16.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\ntemplate<typename... Args>\nstruct AddFunctor;\n\ntemplate<typename T>\nstruct AddFunctor<T> {\n  __device__ T operator()(T x) const { return x; }\n};\n\ntemplate<typename T, typename U, typename... Args>\nstruct AddFunctor<T, U, Args...> {\n  __device__ T operator()(T x0, U x1, Args... xs) const {\n    return x0 + AddFunctor<U, Args...>()(x1, xs...);\n  }\n};\n\ntemplate<typename U, typename... Args>\nstruct AddFunctor<cuComplex, U, Args...> {\n  __device__ cuComplex operator()(cuComplex x0, U x1, Args... xs) const {\n    cuComplex xn = AddFunctor<U, Args...>()(x1, xs...);\n    return cuComplex{x0.x + xn.x, x0.y + xn.y};\n  }\n};\n\ntemplate<typename U, typename... Args>\nstruct AddFunctor<cuDoubleComplex, U, Args...> {\n  __device__ cuDoubleComplex operator()(cuDoubleComplex x0, U x1, Args... xs) const {\n    cuDoubleComplex xn = AddFunctor<U, Args...>()(x1, xs...);\n    return cuDoubleComplex{x0.x + xn.x, x0.y + xn.y};\n  }\n};\n\ntemplate<typename T, typename... Args>\n__global__ void AddGpu(const Args*... srcs, T* dst, size_t count) {\n  CUDA_1D_KERNEL_LOOP_T(size_t, i, count) { dst[i] = AddFunctor<Args...>()(srcs[i]...); }\n}\n\ntemplate<typename T, typename... Args>\nvoid LaunchAddGpu(cudaStream_t stream, const Args*... 
srcs, T* dst, size_t count) {\n  AddGpu<T, Args...>\n      <<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0, stream>>>(srcs..., dst, count);\n}\n\ntemplate<typename T>\nvoid DispatchLaunch(cudaStream_t stream, const T* const* srcs, size_t arity, T* dst, size_t count) {\n  if (arity == 0) {\n    OF_CUDA_CHECK(cudaMemsetAsync(dst, 0, count * sizeof(T), stream));\n  } else if (arity == 1) {\n    OF_CUDA_CHECK(cudaMemcpyAsync(dst, srcs[0], count * sizeof(T), cudaMemcpyDefault, stream));\n  } else if (arity == 2) {\n    OF_CUDA_CHECK((cuda::elementwise::Binary<AddFunctor<T, T>, T, T, T>(\n        AddFunctor<T, T>(), count, dst, srcs[0], srcs[1], stream)));\n  } else if (arity == 3) {\n    OF_CUDA_CHECK((cuda::elementwise::Ternary<AddFunctor<T, T, T>, T, T, T, T>(\n        AddFunctor<T, T, T>(), count, dst, srcs[0], srcs[1], srcs[2], stream)));\n  } else if (arity == 4) {\n    LaunchAddGpu<T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], dst, count);\n  } else if (arity == 5) {\n    LaunchAddGpu<T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4], dst, count);\n  } else if (arity == 6) {\n    LaunchAddGpu<T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4], srcs[5],\n                                      dst, count);\n  } else if (arity == 7) {\n    LaunchAddGpu<T, T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4],\n                                         srcs[5], srcs[6], dst, count);\n  } else if (arity == 8) {\n    LaunchAddGpu<T, T, T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4],\n                                            srcs[5], srcs[6], srcs[7], dst, count);\n  } else {\n    DispatchLaunch(stream, srcs + 7, arity - 7, dst, count);\n    LaunchAddGpu<T, T, T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4],\n                                            srcs[5], srcs[6], dst, dst, count);\n  }\n}\n\ntemplate<typename T>\nclass AddImpl : public Add {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AddImpl);\n  AddImpl() = default;\n  ~AddImpl() override = default;\n\n  using Add::Launch;\n  void Launch(Stream* stream, const void* const* srcs, size_t arity, void* dst,\n              size_t count) override {\n    cudaStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();\n    DispatchLaunch(cuda_stream, reinterpret_cast<const T* const*>(srcs), arity,\n                   reinterpret_cast<T*>(dst), count);\n  }\n};\n\ntemplate<typename T>\nstd::unique_ptr<Add> NewAdd() {\n  return std::unique_ptr<Add>(new AddImpl<T>());\n}\n\nclass AddFactoryImpl : public AddFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AddFactoryImpl);\n  AddFactoryImpl() = default;\n  ~AddFactoryImpl() override = default;\n\n  std::unique_ptr<Add> New(DataType data_type) override {\n#define MAKE_NEW_ADD_ENTRY(type_cpp, type_proto) {type_proto, NewAdd<type_cpp>},\n\n    static const std::map<DataType, std::function<std::unique_ptr<Add>()>> new_add_handle{\n        OF_PP_FOR_EACH_TUPLE(MAKE_NEW_ADD_ENTRY,\n                             CUDA_PRIMITIVE_REAL_TYPE_SEQ CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ)};\n\n#undef MAKE_NEW_ADD_ENTRY\n\n    const auto it = new_add_handle.find(data_type);\n    if (it != new_add_handle.end()) {\n      return it->second();\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, AddFactory, AddFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
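
DispatchLaunch above folds an arbitrary arity into fixed-size launches: 0 is a memset, 1 a memcpy, 2 and 3 go through the elementwise Binary/Ternary primitives, 4 through 8 hit the variadic AddGpu kernel, and anything larger first reduces srcs[7..] into dst recursively and then issues one final 8-input launch over srcs[0..6] plus dst. A scalar host-side sketch of the same recursion (FoldAdd is hypothetical; plain C++ loops stand in for the CUDA kernels):

#include <cstddef>
#include <cstdio>
#include <vector>

// Scalar sketch of the >8-arity folding in DispatchLaunch.
void FoldAdd(const float* const* srcs, size_t arity, float* dst, size_t count) {
  if (arity <= 8) {  // small arities map to a single launch in the real code
    for (size_t i = 0; i < count; ++i) {
      float acc = 0.0f;
      for (size_t a = 0; a < arity; ++a) { acc += srcs[a][i]; }
      dst[i] = acc;
    }
    return;
  }
  FoldAdd(srcs + 7, arity - 7, dst, count);  // fold the tail into dst first
  for (size_t i = 0; i < count; ++i) {       // then one 8-input pass: srcs[0..6] + dst
    float acc = dst[i];
    for (size_t a = 0; a < 7; ++a) { acc += srcs[a][i]; }
    dst[i] = acc;
  }
}

int main() {
  std::vector<float> ones(4, 1.0f), dst(4, 0.0f);
  const float* srcs[10];
  for (auto& s : srcs) { s = ones.data(); }  // ten inputs, all ones
  FoldAdd(srcs, 10, dst.data(), dst.size());
  std::printf("%g\n", dst[0]);  // 10
  return 0;
}
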
  {
    "path": "oneflow/core/ep/cuda/primitive/binary_functor.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <cuComplex.h>\n#include \"oneflow/core/ep/common/primitive/binary_functor.h\"\n#include \"oneflow/core/ep/cuda/primitive/unary_functor.cuh\"\n\nnamespace oneflow {\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return pow(src0, src1); }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFmod, float, float> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC float operator()(float src0, float src1) const { return fmod(src0, src1); }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFmod, double, double> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC double operator()(double src0, double src1) const { return fmod(src0, src1); }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFloorDiv, float, float> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC float operator()(float src0, float src1) const { return floor(src0 / src1); }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFloorDiv, double, double> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC double operator()(double src0, double src1) const { return floor(src0 / src1); }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTruncDiv, float, float> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC float operator()(float src0, float src1) const { return truncf(src0 / src1); }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTruncDiv, double, double> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC double operator()(double src0, double src1) const { return trunc(src0 / src1); }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFloorMod, float, float> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC float operator()(float src0, float src1) const {\n    float trunc_mod = fmod(src0, src1);\n    return (trunc_mod != static_cast<float>(0))\n                   && ((src1 < static_cast<float>(0)) != (trunc_mod < static_cast<float>(0)))\n               ? 
trunc_mod + src1\n               : trunc_mod;\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFloorMod, double, double> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC double operator()(double src0, double src1) const {\n    double trunc_mod = fmod(src0, src1);\n    return (trunc_mod != static_cast<double>(0))\n                   && ((src1 < static_cast<double>(0)) != (trunc_mod < static_cast<double>(0)))\n               ? trunc_mod + src1\n               : trunc_mod;\n  }\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kGeluBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {\n#if defined(__CUDA_ARCH__)\n    coef = sqrt(static_cast<Src>(2.0) / acos(static_cast<Src>(-1.0)));\n#else\n    coef = std::sqrt(static_cast<Src>(2.0) / std::acos(static_cast<Src>(-1.0)));\n#endif\n  }\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return static_cast<Src>(0.5)\n           * (static_cast<Src>(1.0) + erf(static_cast<Src>(M_SQRT1_2) * x)\n              + x * coef * exp(static_cast<Src>(-0.5) * x * x))\n           * dy;\n  }\n  Src coef;\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFastGeluBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    // ref to: https://mlfromscratch.com/activation-functions-explained/#gelu\n    const Src one = static_cast<Src>(1);\n    const Src half = static_cast<Src>(0.5);\n    const Src pow3 = x * x * x;\n    const Src tanh_out = std::tanh(alpha * (x + beta * pow3));\n    const Src dtanh = alpha * (half * x + beta * static_cast<Src>(1.5) * pow3);\n    return dy * (half + half * tanh_out + dtanh * (one - tanh_out * tanh_out));\n  }\n\n private:\n  const Src alpha = static_cast<Src>(0.7978845608028654);\n  const Src beta = static_cast<Src>(0.044714998453855515);\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kQuickGeluBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    const Src one = static_cast<Src>(1.0);\n    const Src sigmoid = one / (one + exp(-x * alpha));\n    return dy * (sigmoid + alpha * x * (sigmoid * (one - sigmoid)));\n  }\n\n private:\n  const Src alpha = static_cast<Src>(1.702);\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kSquareReLUBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    return static_cast<Dst>((x > static_cast<Src>(0.0)) ? 
static_cast<Src>(2.0) * x * dy\n                                                        : static_cast<Src>(0.0));\n  }\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTanhBackwardWithDyY, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const {\n    return static_cast<Dst>(dy * (static_cast<Src>(1.0) - y * y));\n  }\n};\n\ntemplate<typename Dst>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kScalarExpPowerGrad, int, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}\n  BinaryFunctor<DeviceType::kCUDA, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;\n\n  OF_DEVICE_FUNC Dst operator()(int src0, int src1) const {\n    return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));\n  }\n};\n\ntemplate<typename Dst>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kScalarExpPowerGrad, int8_t, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}\n  BinaryFunctor<DeviceType::kCUDA, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;\n\n  OF_DEVICE_FUNC Dst operator()(int8_t src0, int8_t src1) const {\n    return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));\n  }\n};\n\ntemplate<typename Dst>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kScalarExpPowerGrad, uint8_t, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}\n  BinaryFunctor<DeviceType::kCUDA, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;\n\n  OF_DEVICE_FUNC Dst operator()(uint8_t src0, uint8_t src1) const {\n    return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));\n  }\n};\n\ntemplate<typename Dst>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kScalarExpPowerGrad, int64_t, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}\n  BinaryFunctor<DeviceType::kCUDA, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;\n\n  OF_DEVICE_FUNC Dst operator()(int src0, int src1) const {\n    return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));\n  }\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kAtanhBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    const Src one = static_cast<Src>(1.0);\n    return dy * one / (one - static_cast<Src>(pow(x, 2)));\n  }\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kIsCloseEqualNan, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1)\n      : atol(attr0.Value<float>()), rtol(attr1.Value<float>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {\n    bool close = src0 == src1;\n    close |= (isnan(src0) and isnan(src1));\n    if (atol == 0 and rtol == 0) return close;\n    Src allowed_error = static_cast<Src>(atol) + abs(static_cast<Src>(rtol) * src1);\n    Src actual_error = abs(src0 - src1);\n    close |= (isfinite(actual_error) and (actual_error <= allowed_error));\n    return close;\n  }\n  float atol, rtol;\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kIsClose, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, 
Scalar attr1)\n      : atol(attr0.Value<float>()), rtol(attr1.Value<float>()) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {\n    bool close = src0 == src1;\n    if (atol == 0 and rtol == 0) return close;\n    Src allowed_error = static_cast<Src>(atol) + abs(static_cast<Src>(rtol) * src1);\n    Src actual_error = abs(src0 - src1);\n    close |= (isfinite(actual_error) and (actual_error <= allowed_error));\n    return close;\n  }\n  float atol, rtol;\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kDigammaBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    ep::primitive::UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTrigamma, Src, Dst> trigamma_functor(\n        0, 0);\n    Src trigamma_result = trigamma_functor(x);\n    return trigamma_result * dy;\n  }\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kLgammaBackwardWithDyX, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {\n    ep::primitive::UnaryFunctor<DeviceType::kCUDA, UnaryOp::kDigamma, Src, Dst> digamma_functor(0,\n                                                                                                0);\n    Dst digamma_result = digamma_functor(x);\n    return digamma_result * dy;\n  }\n};\n\ntemplate<typename Src, typename Dst>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kZeta, Src, Dst> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC Dst operator()(Src x, Src q) const {\n    // ref\n    // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/Math.cuh#L302-L384\n    const Src MACHEP{1.11022302462515654042E-16};\n    constexpr Src zero{0};\n    constexpr Src half{0.5};\n    constexpr Src one{1};\n    static const Src A[] = {\n        12.0,\n        -720.0,\n        30240.0,\n        -1209600.0,\n        47900160.0,\n        -1.8924375803183791606e9, /*1.307674368e12/691*/\n        7.47242496e10,\n        -2.950130727918164224e12,  /*1.067062284288e16/3617*/\n        1.1646782814350067249e14,  /*5.109094217170944e18/43867*/\n        -4.5979787224074726105e15, /*8.028576626982912e20/174611*/\n        1.8152105401943546773e17,  /*1.5511210043330985984e23/854513*/\n        -7.1661652561756670113e18  /*1.6938241367317436694528e27/236364091*/\n    };\n\n    int i = 0;\n    Src a, b, k, s, t, w;\n\n    // Short-circuits x -> +infty\n    if (x == one) { return INFINITY; }\n\n    // Short-circuits x < 1 -> NaN\n    if (x < one) { return NAN; }\n\n    // Short-circuits negative q integers map to +infty,\n    //   negative q non-integers map to NaN\n    if (q <= zero) {\n      if (q == floor(q)) { return INFINITY; }\n      if (x != floor(x)) { return NAN; }\n    }\n\n    s = pow(q, -x);\n    a = q;\n    i = 0;\n    b = zero;\n    while ((i < 9) || (a <= Src{9.0})) {\n      i += 1;\n      a += one;\n      b = pow(a, -x);\n      s += b;\n      if ((-MACHEP * s < b) && (b < MACHEP * s)) { return s; }\n    }\n\n    w = a;\n    s += b * w / (x - one);\n    s -= half * b;\n    a = one;\n    k = zero;\n    for (int i = 0; i < 12; i++) {\n      a *= x + k;\n      b /= w;\n      t = a * b / A[i];\n      s = s + t;\n      t = fabs(t / s);\n\n      if (t < MACHEP) { return s; }\n\n      k += one;\n      a *= x + k;\n      b /= w;\n      k += one;\n    }\n\n    return s;\n  }\n};\n\n#define 
SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(op, type)                            \\\n  template<typename Dst>                                                                      \\\n  struct BinaryFunctor<DeviceType::kCUDA, op, type, Dst> {                                    \\\n    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \\\n    OF_DEVICE_FUNC Dst operator()(type src0, type src1) const {                               \\\n      return float_functor(static_cast<float>(src0), static_cast<float>(src1));               \\\n    }                                                                                         \\\n    BinaryFunctor<DeviceType::kCUDA, op, float, Dst> float_functor;                           \\\n  };\nSPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, bool);\nSPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, int);\nSPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, char);\nSPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, int8_t);\nSPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, uint8_t);\nSPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, int64_t);\nSPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, bool);\nSPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, int);\nSPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, char);\nSPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, int8_t);\nSPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, uint8_t);\nSPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, int64_t);\n\n/*********nv_bfloat16_kernel*******/\n\n#if CUDA_VERSION >= 11000\n\n#define SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(op)                                     \\\n  template<>                                                                                  \\\n  struct BinaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> {                     \\\n    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \\\n                                                                                              \\\n    BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor;                         \\\n    OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const {         \\\n      return __float2bfloat16(float_functor(__bfloat162float(src0), __bfloat162float(src1))); \\\n    }                                                                                         \\\n  
};\n\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kPow);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFmod);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFloorDiv);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTruncDiv);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFloorMod);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kZeta);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kIdentityBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyY);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardsigmoidBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardtanhBackwardWithDyY);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLeakyReluBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyY);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFastGeluBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kQuickGeluBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSquareReLUBackwardWithDyX);\n\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAcosBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAcoshBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAsinBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAsinhBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCosBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCoshBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kErfBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kErfcBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kExpBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kExp2BackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kExpm1BackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLog2BackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLog10BackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLogSigmoidBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kReciprocalNoNanBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kRsqrtBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BF
LOAT16_BINARY_FUNCTOR(BinaryOp::kSinBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSinhBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSqrtBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSigmoidBackwardWithDyY);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSigmoidBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAtanhBackwardWithDyX);\nSPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLgammaBackwardWithDyX);\n\n#define SPECIALIZATION_BFLOAT16_COMPARISON_BINARY_FUNCTOR(op)                                 \\\n  template<typename Dst>                                                                      \\\n  struct BinaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, Dst> {                             \\\n    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \\\n    BinaryFunctor<DeviceType::kCUDA, op, float, Dst> float_functor;                           \\\n    OF_DEVICE_FUNC Dst operator()(nv_bfloat16 src0, nv_bfloat16 src1) const {                 \\\n      return float_functor(__bfloat162float(src0), __bfloat162float(src1));                   \\\n    }                                                                                         \\\n  };\nSPECIALIZATION_BFLOAT16_COMPARISON_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan)\nSPECIALIZATION_BFLOAT16_COMPARISON_BINARY_FUNCTOR(BinaryOp::kIsClose)\n\n#endif  // CUDA_VERSION >= 11000\n\n#define SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(op)                                         \\\n  template<>                                                                                  \\\n  struct BinaryFunctor<DeviceType::kCUDA, op, half, half> {                                   \\\n    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \\\n                                                                                              \\\n    BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor;                         \\\n    OF_DEVICE_FUNC half operator()(half src0, half src1) const {                              \\\n      return __float2half(float_functor(__half2float(src0), __half2float(src1)));             \\\n    }                                                                                         \\\n  
};\n\n// Half counterpart: each op is evaluated by the float functor and narrowed back to half.\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kPow);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFmod);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFloorDiv);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTruncDiv);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFloorMod);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kZeta);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyY);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyY);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFastGeluBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kQuickGeluBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSquareReLUBackwardWithDyX);\n\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAcosBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAcoshBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAsinBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAsinhBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCosBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCoshBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kErfBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kErfcBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kExpBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kExp2BackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kExpm1BackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kLog2BackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kLog10BackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kLogSigmoidBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kReciprocalNoNanBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kRsqrtBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSinBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSinhBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSqrtBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSigmoidBackwardWithDyY);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSigmoidBackwardWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAtanhBackwa
rdWithDyX);\nSPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kLgammaBackwardWithDyX);\n\n#define SPECIALIZATION_HALF_COMPARISON_BINARY_FUNCTOR(op)                                     \\\n  template<typename Dst>                                                                      \\\n  struct BinaryFunctor<DeviceType::kCUDA, op, half, Dst> {                                    \\\n    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \\\n    BinaryFunctor<DeviceType::kCUDA, op, float, Dst> float_functor;                           \\\n    OF_DEVICE_FUNC Dst operator()(half src0, half src1) const {                               \\\n      return float_functor(__half2float(src0), __half2float(src1));                           \\\n    }                                                                                         \\\n  };\n\nSPECIALIZATION_HALF_COMPARISON_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan)\nSPECIALIZATION_HALF_COMPARISON_BINARY_FUNCTOR(BinaryOp::kIsClose)\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kMul, cuComplex, cuComplex> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuComplex operator()(cuComplex src0, cuComplex src1) const {\n    return cuCmulf(src0, src1);\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kDiv, cuComplex, cuComplex> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuComplex operator()(cuComplex src0, cuComplex src1) const {\n    return cuCdivf(src0, src1);\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kMul, cuDoubleComplex, cuDoubleComplex> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuDoubleComplex operator()(cuDoubleComplex src0, cuDoubleComplex src1) const {\n    return cuCmul(src0, src1);\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kDiv, cuDoubleComplex, cuDoubleComplex> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuDoubleComplex operator()(cuDoubleComplex src0, cuDoubleComplex src1) const {\n    return cuCdiv(src0, src1);\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kSqrtBackwardWithDyX, cuComplex, cuComplex> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : unary_functor(attr0, attr1) {}\n  UnaryFunctor<DeviceType::kCUDA, UnaryOp::kSqrt, cuComplex, cuComplex> unary_functor;\n  OF_DEVICE_FUNC cuComplex operator()(cuComplex dy, cuComplex x) const {\n    // dy / (2 * sqrt(x).conj())\n    cuComplex y = unary_functor(x);\n    return cuCdivf(dy, cuComplex{2.0f * y.x, -2.0f * y.y});\n  }\n};\n\ntemplate<>\nstruct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kSqrtBackwardWithDyX, cuDoubleComplex,\n                     cuDoubleComplex> {\n  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : unary_functor(attr0, attr1) {}\n  UnaryFunctor<DeviceType::kCUDA, UnaryOp::kSqrt, cuDoubleComplex, cuDoubleComplex> unary_functor;\n  OF_DEVICE_FUNC cuDoubleComplex operator()(cuDoubleComplex dy, cuDoubleComplex x) const {\n    // dy / (2 * sqrt(x).conj())\n    cuDoubleComplex y = unary_functor(x);\n    return cuCdiv(dy, cuDoubleComplex{2.0 * y.x, -2.0 * y.y});\n  }\n};\n\n#define SPECIALIZATION_COMPLEX_ARITHMETIC_BINARY_FUNCTOR(op, complex_type, real_type)        \\\n  template<>                                                                                 \\\n  struct BinaryFunctor<DeviceType::kCUDA, op, complex_type, 
complex_type> {                  \\\n    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : real_functor(attr0, attr1) {} \\\n    BinaryFunctor<DeviceType::kCUDA, op, real_type, real_type> real_functor;                 \\\n    OF_DEVICE_FUNC complex_type operator()(complex_type src0, complex_type src1) const {     \\\n      return complex_type{real_functor(src0.x, src1.x), real_functor(src0.y, src1.y)};       \\\n    }                                                                                        \\\n  };\n\nSPECIALIZATION_COMPLEX_ARITHMETIC_BINARY_FUNCTOR(BinaryOp::kAdd, cuComplex, float);\nSPECIALIZATION_COMPLEX_ARITHMETIC_BINARY_FUNCTOR(BinaryOp::kSub, cuComplex, float);\nSPECIALIZATION_COMPLEX_ARITHMETIC_BINARY_FUNCTOR(BinaryOp::kAdd, cuDoubleComplex, double);\nSPECIALIZATION_COMPLEX_ARITHMETIC_BINARY_FUNCTOR(BinaryOp::kSub, cuDoubleComplex, double);\n\n#define SPECIALIZATION_COMPLEX_EQAUL_BINARY_FUNCTOR(complex_type, real_type)                 \\\n  template<typename Dst>                                                                     \\\n  struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kEqual, complex_type, Dst> {             \\\n    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : real_functor(attr0, attr1) {} \\\n    BinaryFunctor<DeviceType::kCUDA, BinaryOp::kEqual, real_type, Dst> real_functor;         \\\n    OF_DEVICE_FUNC Dst operator()(complex_type src0, complex_type src1) const {              \\\n      return static_cast<Dst>(real_functor(src0.x, src1.x) && real_functor(src0.y, src1.y)); \\\n    }                                                                                        \\\n  };\nSPECIALIZATION_COMPLEX_EQAUL_BINARY_FUNCTOR(cuComplex, float);\nSPECIALIZATION_COMPLEX_EQAUL_BINARY_FUNCTOR(cuDoubleComplex, double);\n\n#define SPECIALIZATION_COMPLEX_NOT_EQAUL_BINARY_FUNCTOR(complex_type, real_type)             \\\n  template<typename Dst>                                                                     \\\n  struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kNotEqual, complex_type, Dst> {          \\\n    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : real_functor(attr0, attr1) {} \\\n    BinaryFunctor<DeviceType::kCUDA, BinaryOp::kNotEqual, real_type, Dst> real_functor;      \\\n    OF_DEVICE_FUNC Dst operator()(complex_type src0, complex_type src1) const {              \\\n      return static_cast<Dst>(real_functor(src0.x, src1.x) || real_functor(src0.y, src1.y)); \\\n    }                                                                                        \\\n  };\nSPECIALIZATION_COMPLEX_NOT_EQAUL_BINARY_FUNCTOR(cuComplex, float);\nSPECIALIZATION_COMPLEX_NOT_EQAUL_BINARY_FUNCTOR(cuDoubleComplex, double);\n\n#define SPECIALIZATION_GPU_BINARY_FUNCTOR(op, type)                                          \\\n  template<>                                                                                 \\\n  struct BinaryFunctor<DeviceType::kCUDA, op, type, type> {                                  \\\n    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : int_functor(attr0, attr1) {}  \\\n                                                                                             \\\n    BinaryFunctor<DeviceType::kCUDA, op, int, int> int_functor;                              \\\n    OF_DEVICE_FUNC type operator()(type src0, type src1) const {                             \\\n      return static_cast<type>(int_functor(static_cast<int>(src0), static_cast<int>(src1))); \\\n    }                                          
                                              \\\n  };\n\nSPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kPow, bool);\nSPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFmod, bool);\nSPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, bool);\nSPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kTruncDiv, bool);\nSPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, bool);\nSPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad, bool);\nSPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad, bool);\nSPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kPow, char);\nSPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFmod, char);\nSPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, char);\nSPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kTruncDiv, char);\nSPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, char);\nSPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad, char);\nSPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad, char);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h\"\n#include \"oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h\"\n#include \"oneflow/core/ep/cuda/primitive/type_seq.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/ep/cuda/primitive/binary_functor.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nstd::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary(Scalar attr0,\n                                                                          Scalar attr1);\n\nnamespace {\n\nclass BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryFactoryImpl);\n  BroadcastElementwiseBinaryFactoryImpl() = default;\n  ~BroadcastElementwiseBinaryFactoryImpl() override = default;\n\n  std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp op, DataType src_type, DataType dst_type,\n                                                  size_t max_num_dims) override {\n    return New(op, src_type, dst_type, max_num_dims, Scalar(), Scalar());\n  }\n\n  std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp op, DataType src_type, DataType dst_type,\n                                                  size_t max_num_dims, Scalar attr0) override {\n    return New(op, src_type, dst_type, max_num_dims, attr0, Scalar());\n  }\n\n  std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp binary_op, DataType src_type,\n                                                  DataType dst_type, size_t max_num_dims,\n                                                  Scalar attr0, Scalar attr1) override {\n    if (max_num_dims > kMaxNumDims) { return nullptr; }\n#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \\\n  {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair),                    \\\n                   OF_PP_PAIR_SECOND(data_type_pair)),                              \\\n   NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair),       \\\n                                 OF_PP_PAIR_FIRST(data_type_pair)>},\n\n#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY(      \\\n    binary_op, src_data_type_pair, dst_data_type_pair)                            \\\n  {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(src_data_type_pair),              \\\n                   OF_PP_PAIR_SECOND(dst_data_type_pair)),                        \\\n   NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), \\\n                                 OF_PP_PAIR_FIRST(dst_data_type_pair)>},\n\n#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, data_type_pair) \\\n  
{std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair),                               \\\n                   OF_PP_PAIR_SECOND(data_type_pair)),                                         \\\n   NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair),                  \\\n                                 OF_PP_PAIR_FIRST(data_type_pair)>},\n\n    static const std::map<\n        std::tuple<BinaryOp, DataType, DataType>,\n        std::function<std::unique_ptr<BroadcastElementwiseBinary>(Scalar, Scalar)>>\n        new_broadcast_elementwise_binary_handle{\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                             BINARY_MATH_OP_SEQ, CUDA_PRIMITIVE_REAL_TYPE_SEQ)\n\n                OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                                 BINARY_COMPLEX_MATH_OP_SEQ,\n                                                 CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ)\n\n                    OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                        MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,\n                        BINARY_COMPARISION_OP_SEQ BINARY_LOGICAL_OP_SEQ,\n                        CUDA_PRIMITIVE_REAL_TYPE_SEQ, CUDA_PRIMITIVE_BOOL_TYPE_SEQ)\n\n                        OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                            MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                            BINARY_MATH_FLOATING_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)\n\n                            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                                MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,\n                                BINARY_COMPLEX_COMPARISION_OP_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ,\n                                CUDA_PRIMITIVE_BOOL_TYPE_SEQ)\n\n                                OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                                    MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,\n                                    BINARY_ACTIVATION_BACKWARD_OP_SEQ,\n                                    CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)\n\n                                    OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                                        MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,\n                                        BINARY_MATH_BACKWARD_OP_SEQ,\n                                        CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)\n\n                                        OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                                            MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                            BINARY_MATH_BACKWARD_OP_SEQ_COMPLEX,\n                                            CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ)\n\n                                            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                                                MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                                BINARY_BITWISE_OP_SEQ,\n                                                CUDA_PRIMITIVE_INT_TYPE_SEQ\n                                                    CUDA_PRIMITIVE_BOOL_TYPE_SEQ)};\n\n#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY\n#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY\n\n    const auto it = new_broadcast_elementwise_binary_handle.find(\n        std::make_tuple(binary_op, src_type, dst_type));\n    if 
(it != new_broadcast_elementwise_binary_handle.end()) {\n      return it->second(attr0, attr1);\n    } else {\n      // No functor registered for this (op, src_type, dst_type) combination.\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastElementwiseBinaryFactory,\n                           BroadcastElementwiseBinaryFactoryImpl);\n}  // namespace\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive//broadcast_elementwise_binary.h\"\n#include \"oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h\"\n#include \"oneflow/core/ep/cuda/primitive/type_seq.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/ep/cuda/primitive/binary_functor.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\nnamespace {\n\ntemplate<typename T, int N>\nstruct GetPackType {\n  using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;\n};\n\ntemplate<typename T, int N>\nusing PackType = typename GetPackType<T, N>::type;\n\ntemplate<typename T, int N>\nunion Pack {\n  static_assert(sizeof(PackType<T, N>) == sizeof(T) * N, \"\");\n  OF_DEVICE_FUNC Pack() {\n    // do nothing\n  }\n  PackType<T, N> storage;\n  T elem[N];\n};\n\ntemplate<size_t max_dims, typename IndexType>\nstruct BroadcastElementwiseBinaryParams {\n  NdIndexOffsetHelper<IndexType, max_dims> src0_index_helper;\n  NdIndexOffsetHelper<IndexType, max_dims> src1_index_helper;\n  NdIndexOffsetHelper<IndexType, max_dims> dst_index_helper;\n  size_t num_dims;\n  IndexType src0_index_mask[max_dims];\n  IndexType src1_index_mask[max_dims];\n  IndexType count{};\n  const void* src0{};\n  const void* src1{};\n  void* dst{};\n  Scalar attr0;\n  Scalar attr1;\n};\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst, size_t max_dims, size_t src0_pack_size,\n         size_t src1_pack_size, typename IndexType>\n__global__ void BroadcastElementwiseBinaryGpu(\n    BroadcastElementwiseBinaryParams<max_dims, IndexType> params) {\n  constexpr size_t dst_pack_size =\n      src0_pack_size > src1_pack_size ? 
src0_pack_size : src1_pack_size;\n  static_assert(src0_pack_size == dst_pack_size || src0_pack_size == 1, \"\");\n  static_assert(src1_pack_size == dst_pack_size || src1_pack_size == 1, \"\");\n\n  const PackType<Src, src0_pack_size>* src0 =\n      reinterpret_cast<const PackType<Src, src0_pack_size>*>(params.src0);\n  const PackType<Src, src1_pack_size>* src1 =\n      reinterpret_cast<const PackType<Src, src1_pack_size>*>(params.src1);\n  PackType<Dst, dst_pack_size>* dst = reinterpret_cast<PackType<Dst, dst_pack_size>*>(params.dst);\n\n  IndexType src0_index[max_dims];\n  IndexType src1_index[max_dims];\n  IndexType dst_index[max_dims];\n  size_t num_dims = params.num_dims;\n  CUDA_1D_KERNEL_LOOP_T(IndexType, offset, params.count) {\n    params.dst_index_helper.OffsetToNdIndex(offset, dst_index, num_dims);\n#pragma unroll\n    for (int i = 0; i < max_dims; ++i) {\n      if (i < num_dims) {\n        src0_index[i] = params.src0_index_mask[i] * dst_index[i];\n        src1_index[i] = params.src1_index_mask[i] * dst_index[i];\n      } else {\n        src0_index[i] = 0;\n        src1_index[i] = 0;\n      }\n    }\n    const IndexType src0_offset = params.src0_index_helper.NdIndexToOffset(src0_index, num_dims);\n    const IndexType src1_offset = params.src1_index_helper.NdIndexToOffset(src1_index, num_dims);\n    Pack<Src, src0_pack_size> src0_pack;\n    src0_pack.storage = src0[src0_offset];\n    Pack<Src, src1_pack_size> src1_pack;\n    src1_pack.storage = src1[src1_offset];\n    Pack<Dst, dst_pack_size> dst_pack;\n    BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst> functor(params.attr0, params.attr1);\n#pragma unroll\n    for (int j = 0; j < dst_pack_size; ++j) {\n      const Src src0_val =\n          (src0_pack_size == dst_pack_size) ? src0_pack.elem[j] : src0_pack.elem[0];\n      const Src src1_val =\n          (src1_pack_size == dst_pack_size) ? src1_pack.elem[j] : src1_pack.elem[0];\n      dst_pack.elem[j] = functor(src0_val, src1_val);\n    }\n    dst[offset] = dst_pack.storage;\n  }\n}\n\ntemplate<BinaryOp op, typename T, typename R, size_t max_dims, size_t src0_pack_size,\n         size_t src1_pack_size, typename IndexType>\nvoid LaunchKernel(Stream* stream, int num_dims, const int64_t* src0_dims, const void* src0,\n                  const int64_t* src1_dims, const void* src1, const int64_t* dst_dims, void* dst,\n                  size_t count, Scalar attr0, Scalar attr1) {\n  BroadcastElementwiseBinaryParams<max_dims, IndexType> params;\n  for (size_t i = 0; i < num_dims; ++i) {\n    params.src0_index_mask[i] = (src0_dims[i] == 1) ? 0 : 1;\n    params.src1_index_mask[i] = (src1_dims[i] == 1) ? 
0 : 1;\n  }\n  params.src0_index_helper = NdIndexOffsetHelper<IndexType, max_dims>(src0_dims, num_dims);\n  params.src1_index_helper = NdIndexOffsetHelper<IndexType, max_dims>(src1_dims, num_dims);\n  params.dst_index_helper = NdIndexOffsetHelper<IndexType, max_dims>(dst_dims, num_dims);\n  params.num_dims = num_dims;\n  params.src0 = src0;\n  params.src1 = src1;\n  params.dst = dst;\n  params.count = static_cast<IndexType>(count);\n  params.attr0 = attr0;\n  params.attr1 = attr1;\n  auto* cuda_stream = stream->As<CudaStream>();\n  BroadcastElementwiseBinaryGpu<op, T, R, max_dims, src0_pack_size, src1_pack_size, IndexType>\n      <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0,\n         cuda_stream->cuda_stream()>>>(params);\n}\n\ntemplate<BinaryOp op, typename T, typename R, size_t max_dims, size_t src0_pack_size,\n         size_t src1_pack_size>\nvoid DispatchIndexType(Stream* stream, size_t num_dims, const int64_t* src0_dims, const void* src0,\n                       const int64_t* src1_dims, const void* src1, const int64_t* dst_dims,\n                       void* dst, Scalar attr0, Scalar attr1) {\n  size_t count = GetElementCount(num_dims, dst_dims);\n  if (count < GetMaxVal<int32_t>()) {\n    LaunchKernel<op, T, R, max_dims, src0_pack_size, src1_pack_size, int32_t>(\n        stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, count, attr0, attr1);\n  } else {\n    LaunchKernel<op, T, R, max_dims, src0_pack_size, src1_pack_size, int64_t>(\n        stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, count, attr0, attr1);\n  }\n}\n\ntemplate<BinaryOp op, typename T, typename R, size_t max_dims>\nvoid DispatchPackSize(Stream* stream, size_t src0_pack_size, size_t src1_pack_size, size_t num_dims,\n                      const int64_t* src0_dims, const void* src0, const int64_t* src1_dims,\n                      const void* src1, const int64_t* dst_dims, void* dst, Scalar attr0,\n                      Scalar attr1) {\n  void (*func)(Stream* /*stream*/, size_t /*num_dims*/, const int64_t* /*src0_dims*/,\n               const void* /*src0*/, const int64_t* /*src1_dims*/, const void* /*src1*/,\n               const int64_t* /*dst_dims*/, void* /*dst*/, Scalar /*attr0*/, Scalar /*attr1*/) =\n      nullptr;\n  if (src0_pack_size == 1 && src1_pack_size == 1) {\n    func = DispatchIndexType<op, T, R, max_dims, 1, 1>;\n  } else if (src0_pack_size == 4 && src1_pack_size == 4) {\n    func = DispatchIndexType<op, T, R, max_dims, 4, 4>;\n  } else if (src0_pack_size == 1 && src1_pack_size == 4) {\n    func = DispatchIndexType<op, T, R, max_dims, 1, 4>;\n  } else if (src0_pack_size == 4 && src1_pack_size == 1) {\n    func = DispatchIndexType<op, T, R, max_dims, 4, 1>;\n  } else {\n    UNIMPLEMENTED();\n  }\n  func(stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, attr0, attr1);\n}\n\ntemplate<BinaryOp op, typename T, typename R>\nvoid DispatchNumDims(Stream* stream, size_t src0_pack_size, size_t src1_pack_size, size_t num_dims,\n                     const int64_t* src0_dims, const void* src0, const int64_t* src1_dims,\n                     const void* src1, const int64_t* dst_dims, void* dst, Scalar attr0,\n                     Scalar attr1) {\n  void (*func)(Stream* /*stream*/, size_t /*src0_pack_size*/, size_t /*src1_pack_size*/,\n               size_t /*num_dims*/, const int64_t* /*src0_dims*/, const void* /*src0*/,\n               const int64_t* /*src1_dims*/, const void* /*src1*/, const int64_t* /*dst_dims*/,\n               void* 
/*dst*/, Scalar /*attr0*/, Scalar /*attr1*/) = nullptr;\n  CHECK_NE(num_dims, 1);\n  if (num_dims == 2) {\n    func = DispatchPackSize<op, T, R, 2>;\n  } else if (num_dims == 3) {\n    func = DispatchPackSize<op, T, R, 3>;\n  } else if (num_dims == 4) {\n    func = DispatchPackSize<op, T, R, 4>;\n  } else if (num_dims <= 8) {\n    func = DispatchPackSize<op, T, R, 8>;\n  } else {\n    UNIMPLEMENTED();\n  }\n  func(stream, src0_pack_size, src1_pack_size, num_dims, src0_dims, src0, src1_dims, src1, dst_dims,\n       dst, attr0, attr1);\n}\n\ntemplate<size_t max_pack_size, typename T, typename R>\nsize_t GetPackSize(size_t num_src_dims, const int64_t* src0_dims, const void* src0,\n                   const int64_t* src1_dims, const void* src1, void* dst) {\n  static_assert(max_pack_size > 0 && (max_pack_size & (max_pack_size - 1)) == 0, \"\");\n  CHECK(src0_dims[num_src_dims - 1] != 1 || src1_dims[num_src_dims - 1] != 1);\n  auto dst_ptr = reinterpret_cast<std::uintptr_t>(dst);\n  for (size_t pack_size = max_pack_size; pack_size > 2; pack_size /= 2) {\n    bool is_src0_supported = (src0_dims[num_src_dims - 1] == 1)\n                             || IsPackSizeSupported<T>(pack_size, num_src_dims, src0_dims, src0);\n    bool is_src1_supported = (src1_dims[num_src_dims - 1] == 1)\n                             || IsPackSizeSupported<T>(pack_size, num_src_dims, src1_dims, src1);\n    if (is_src0_supported && is_src1_supported && (dst_ptr % (pack_size * sizeof(R))) == 0) {\n      return pack_size;\n    }\n  }\n  return 1;\n}\n\nconstexpr size_t kMaxPackSize = 4;\n\ntemplate<BinaryOp op, typename T, typename R>\nvoid LaunchWithSimplified(Stream* stream, size_t simplified_num_dims, int64_t* simplified_src0_dims,\n                          const void* src0, int64_t* simplified_src1_dims, const void* src1,\n                          int64_t* simplified_dst_dims, void* dst, Scalar attr0, Scalar attr1) {\n  CHECK_LE(simplified_num_dims, kMaxNumDims);\n  size_t pack_size = GetPackSize<kMaxPackSize, T, R>(simplified_num_dims, simplified_src0_dims,\n                                                     src0, simplified_src1_dims, src1, dst);\n  size_t src0_pack_size = 1;\n  size_t src1_pack_size = 1;\n  if (simplified_src0_dims[simplified_num_dims - 1] != 1) {\n    simplified_src0_dims[simplified_num_dims - 1] /= pack_size;\n    src0_pack_size = pack_size;\n  }\n  if (simplified_src1_dims[simplified_num_dims - 1] != 1) {\n    simplified_src1_dims[simplified_num_dims - 1] /= pack_size;\n    src1_pack_size = pack_size;\n  }\n  simplified_dst_dims[simplified_num_dims - 1] /= pack_size;\n  DispatchNumDims<op, T, R>(stream, src0_pack_size, src1_pack_size, simplified_num_dims,\n                            simplified_src0_dims, src0, simplified_src1_dims, src1,\n                            simplified_dst_dims, dst, attr0, attr1);\n}\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nstruct BinaryLhsScalarFunctor {\n  __host__ __device__ BinaryLhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1)\n      : scalar(scalar), functor(attr0, attr1) {}\n  __device__ Dst operator()(Src src) const { return functor(scalar, src); }\n  const Src scalar;\n  BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst> functor;\n};\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nstruct BinaryRhsScalarFunctor {\n  __host__ __device__ BinaryRhsScalarFunctor(Src scalar, Scalar attr0, Scalar attr1)\n      : scalar(scalar), functor(attr0, attr1) {}\n  __device__ Dst operator()(Src src) const { return functor(src, 
scalar); }\n  const Src scalar;\n  BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst> functor;\n};\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nstruct BinaryLhsScalarPtrFunctorFactory {\n  __host__ __device__ BinaryLhsScalarPtrFunctorFactory(const Src* scalar_ptr, Scalar attr0,\n                                                       Scalar attr1)\n      : scalar_ptr(scalar_ptr), attr0(attr0), attr1(attr1) {}\n  __device__ BinaryLhsScalarFunctor<binary_op, Src, Dst> operator()() const {\n    return BinaryLhsScalarFunctor<binary_op, Src, Dst>(*scalar_ptr, attr0, attr1);\n  }\n  const Src* scalar_ptr;\n  Scalar attr0, attr1;\n};\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nstruct BinaryRhsScalarPtrFunctorFactory {\n  __host__ __device__ explicit BinaryRhsScalarPtrFunctorFactory(const Src* scalar_ptr, Scalar attr0,\n                                                                Scalar attr1)\n      : scalar_ptr(scalar_ptr), attr0(attr0), attr1(attr1) {}\n  __device__ BinaryRhsScalarFunctor<binary_op, Src, Dst> operator()() const {\n    return BinaryRhsScalarFunctor<binary_op, Src, Dst>(*scalar_ptr, attr0, attr1);\n  }\n  const Src* scalar_ptr;\n  Scalar attr0, attr1;\n};\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nvoid DispatchLaunch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const Src* src0,\n                    size_t num_src1_dims, const int64_t* src1_dims, const Src* src1, Dst* dst,\n                    Scalar attr0, Scalar attr1) {\n  auto* cuda_stream = stream->As<CudaStream>();\n  size_t simplified_num_dims = 0;\n  int64_t simplified_src0_dims[kMaxNumDims];\n  int64_t simplified_src1_dims[kMaxNumDims];\n  int64_t simplified_dst_dims[kMaxNumDims];\n  SimplifyBroadcastDims<kMaxNumDims>(num_src0_dims, src0_dims, num_src1_dims, src1_dims,\n                                     &simplified_num_dims, simplified_src0_dims,\n                                     simplified_src1_dims, simplified_dst_dims);\n  CheckInplace(simplified_num_dims, simplified_src0_dims, src0, simplified_dst_dims, dst);\n  CheckInplace(simplified_num_dims, simplified_src1_dims, src1, simplified_dst_dims, dst);\n  if (IsDimsEquals(simplified_num_dims, simplified_src0_dims, simplified_num_dims,\n                   simplified_src1_dims)) {\n    const int64_t elem_cnt = GetElementCount(simplified_num_dims, simplified_src0_dims);\n    OF_CUDA_CHECK((cuda::elementwise::Binary(\n        BinaryFunctor<DeviceType::kCUDA, binary_op, Src, Dst>(attr0, attr1), elem_cnt, dst, src0,\n        src1, cuda_stream->cuda_stream())));\n  } else {\n    if (simplified_num_dims == 1 && simplified_src0_dims[0] == 1) {\n      OF_CUDA_CHECK((cuda::elementwise::UnaryWithFactory(\n          BinaryLhsScalarPtrFunctorFactory<binary_op, Src, Dst>(src0, attr0, attr1),\n          simplified_src1_dims[0], dst, src1, cuda_stream->cuda_stream())));\n    } else if (simplified_num_dims == 1 && simplified_src1_dims[0] == 1) {\n      OF_CUDA_CHECK((cuda::elementwise::UnaryWithFactory(\n          BinaryRhsScalarPtrFunctorFactory<binary_op, Src, Dst>(src1, attr0, attr1),\n          simplified_src0_dims[0], dst, src0, cuda_stream->cuda_stream())));\n    } else {\n      LaunchWithSimplified<binary_op, Src, Dst>(stream, simplified_num_dims, simplified_src0_dims,\n                                                src0, simplified_src1_dims, src1,\n                                                simplified_dst_dims, dst, attr0, attr1);\n    }\n  }\n}\n\ntemplate<typename T>\nT GetValue(Scalar 
value) {\n  return value.Value<T>();\n}\n\ntemplate<>\nhalf GetValue<half>(Scalar value) {\n  return static_cast<half>(GetValue<float>(value));\n}\n\ntemplate<>\ncuComplex GetValue<cuComplex>(Scalar value) {\n  const std::complex<float> cpp_value = GetValue<std::complex<float>>(value);\n  return cuFloatComplex{cpp_value.real(), cpp_value.imag()};\n}\n\ntemplate<>\ncuDoubleComplex GetValue<cuDoubleComplex>(Scalar value) {\n  const std::complex<double> cpp_value = GetValue<std::complex<double>>(value);\n  return cuDoubleComplex{cpp_value.real(), cpp_value.imag()};\n}\n\n#if CUDA_VERSION >= 11000\n\ntemplate<>\nnv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {\n  return static_cast<nv_bfloat16>(GetValue<float>(value));\n}\n\n#endif  // CUDA_VERSION >= 11000\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nclass BroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryImpl);\n  BroadcastElementwiseBinaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}\n  ~BroadcastElementwiseBinaryImpl() override = default;\n\n  void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims,\n              const void* src1, void* dst) override {\n    auto* cuda_stream = stream->As<CudaStream>();\n    const size_t elem_cnt = GetElementCount(num_src1_dims, src1_dims);\n    OF_CUDA_CHECK((cuda::elementwise::Unary(\n        BinaryLhsScalarFunctor<binary_op, Src, Dst>(GetValue<Src>(src0), attr0, attr1), elem_cnt,\n        reinterpret_cast<Dst*>(dst), reinterpret_cast<const Src*>(src1),\n        cuda_stream->cuda_stream())));\n  }\n  void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,\n              Scalar src1, void* dst) override {\n    auto* cuda_stream = stream->As<CudaStream>();\n    const size_t elem_cnt = GetElementCount(num_src0_dims, src0_dims);\n    OF_CUDA_CHECK((cuda::elementwise::Unary(\n        BinaryRhsScalarFunctor<binary_op, Src, Dst>(GetValue<Src>(src1), attr0, attr1), elem_cnt,\n        reinterpret_cast<Dst*>(dst), reinterpret_cast<const Src*>(src0),\n        cuda_stream->cuda_stream())));\n  }\n  void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0,\n              size_t num_src1_dims, const int64_t* src1_dims, const void* src1,\n              void* dst) override {\n    DispatchLaunch<binary_op, Src, Dst>(\n        stream, num_src0_dims, src0_dims, reinterpret_cast<const Src*>(src0), num_src1_dims,\n        src1_dims, reinterpret_cast<const Src*>(src1), reinterpret_cast<Dst*>(dst), attr0, attr1);\n  }\n\n private:\n  Scalar attr0, attr1;\n};\n\n}  // namespace\n\ntemplate<BinaryOp binary_op, typename Src, typename Dst>\nstd::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary(Scalar attr0,\n                                                                          Scalar attr1) {\n  return std::unique_ptr<BroadcastElementwiseBinary>(\n      new BroadcastElementwiseBinaryImpl<binary_op, Src, Dst>(attr0, attr1));\n}\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \\\n      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                 
BINARY_MATH_FLOATING_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_activation_grad_0.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op,      \\\n                                                                           data_type_pair) \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \\\n      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,\n                                 BINARY_ACTIVATION_BACKWARD_OP_SEQ_0,\n                                 CUDA_PRIMITIVE_FLOATING_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_activation_grad_1.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op,      \\\n                                                                           data_type_pair) \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \\\n      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,\n                                 BINARY_ACTIVATION_BACKWARD_OP_SEQ_1,\n                                 CUDA_PRIMITIVE_FLOATING_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_activation_grad_2.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op,      \\\n                                                                           data_type_pair) \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \\\n      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,\n                                 BINARY_ACTIVATION_BACKWARD_OP_SEQ_2,\n                                 CUDA_PRIMITIVE_FLOATING_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_bitwise.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_BITWISE_ENTRY(binary_op, data_type_pair) \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<         \\\n      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(         \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_BITWISE_ENTRY,\n                                 BINARY_BITWISE_OP_SEQ,\n                                 CUDA_PRIMITIVE_INT_TYPE_SEQ CUDA_PRIMITIVE_BOOL_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_comparision_0.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY(                       \\\n    binary_op, src_data_type_pair, dst_data_type_pair)                                        \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<         \\\n      binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>( \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY,\n                                 BINARY_COMPARISION_OP_SEQ_0, CUDA_PRIMITIVE_REAL_TYPE_SEQ,\n                                 CUDA_PRIMITIVE_BOOL_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_comparision_1.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY(                       \\\n    binary_op, src_data_type_pair, dst_data_type_pair)                                        \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<         \\\n      binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>( \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY,\n                                 BINARY_COMPARISION_OP_SEQ_1, CUDA_PRIMITIVE_REAL_TYPE_SEQ,\n                                 CUDA_PRIMITIVE_BOOL_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_comparision_complex.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h\"\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n#include \"oneflow/core/ep/cuda/primitive/type_seq.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY(                       \\\n    binary_op, src_data_type_pair, dst_data_type_pair)                                        \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<         \\\n      binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>( \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY,\n                                 BINARY_COMPLEX_COMPARISION_OP_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ,\n                                 CUDA_PRIMITIVE_BOOL_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_logical.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY(binary_op, src_data_type_pair, \\\n                                                                   dst_data_type_pair)            \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<             \\\n      binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>(     \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY,\n                                 BINARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_REAL_TYPE_SEQ,\n                                 CUDA_PRIMITIVE_BOOL_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math_0.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \\\n      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                 BINARY_MATH_OP_SEQ_0, CUDA_PRIMITIVE_REAL_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math_1.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \\\n      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                 BINARY_MATH_OP_SEQ_1, CUDA_PRIMITIVE_REAL_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math_2.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \\\n      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                 BINARY_MATH_OP_SEQ_2, CUDA_PRIMITIVE_REAL_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math_complex.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \\\n      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                 BINARY_COMPLEX_MATH_OP_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/broadcast_elementwise_unary.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/ep/common/primitive/broadcast_elementwise_unary.h\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"oneflow/core/ep/cuda/primitive/unary_functor.cuh\"\n#include \"oneflow/core/ep/cuda/primitive/type_seq.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_unary {\n\nnamespace {\n\n#define CUDA_PRIMITIVE_CAST_REAL_TYPE_SEQ \\\n  CUDA_PRIMITIVE_INT16_TYPE_SEQ           \\\n  CUDA_PRIMITIVE_UINT32_TYPE_SEQ          \\\n  CUDA_PRIMITIVE_REAL_TYPE_SEQ\n\nconstexpr size_t kMaxPackSize = 4;\n\ntemplate<size_t max_pack_size, typename Src, typename Dst>\nsize_t GetPackSize(size_t num_dims, const int64_t* src_dims, const void* src,\n                   const int64_t* dst_dims, const void* dst) {\n  static_assert(max_pack_size > 0 && (max_pack_size & (max_pack_size - 1)) == 0, \"\");\n  for (size_t pack_size = max_pack_size; pack_size > 2; pack_size /= 2) {\n    bool is_src_supported = IsPackSizeSupported<Src>(pack_size, num_dims, src_dims, src);\n    bool is_dst_supported = IsPackSizeSupported<Dst>(pack_size, num_dims, dst_dims, dst);\n    if (is_src_supported && is_dst_supported) { return pack_size; }\n  }\n  return 1;\n}\n\ntemplate<typename Src, typename Dst, size_t max_dims, typename IndexType>\nstruct BroadcastElementwiseUnaryParams {\n  OffsetToIndexWithStrideCalculator<IndexType, max_dims> dst_offset_to_index_helper;\n  size_t num_dims;\n  int64_t src_strides[max_dims];\n  int64_t dst_strides[max_dims];\n  IndexType src_index_mask[max_dims];\n  IndexType count{};\n  const Src* src{};\n  Dst* dst{};\n  bool dst_is_contiguous;\n  Scalar attr0;\n  Scalar attr1;\n};\n\ntemplate<UnaryOp unary_op, typename Src, typename Dst>\nstruct UnaryScalarFunctor {\n  __host__ __device__ explicit UnaryScalarFunctor(Src scalar) : scalar(scalar) {}\n  __device__ Dst operator()() const {\n    return UnaryFunctor<DeviceType::kCUDA, unary_op, Dst, Src>()(scalar);\n  }\n  const Src scalar;\n};\n\ntemplate<UnaryOp unary_op, typename Src, typename Dst>\nstruct UnaryScalarPtrFunctorFactory {\n  __host__ __device__ explicit UnaryScalarPtrFunctorFactory(const Src* scalar_ptr)\n      : scalar_ptr(scalar_ptr) {}\n  __device__ UnaryScalarFunctor<unary_op, Src, Dst> operator()() const {\n    return UnaryScalarFunctor<unary_op, Src, Dst>(*scalar_ptr);\n  }\n  const Src* scalar_ptr;\n};\n\ntemplate<UnaryOp op, typename Src, typename Dst, size_t max_dims, size_t pack_size,\n         typename IndexType>\n__global__ void BroadcastElementwiseUnaryGpu(\n    BroadcastElementwiseUnaryParams<Src, Dst, max_dims, IndexType> params) {\n  using LoadPack = cuda::elementwise::Packed<Src, pack_size>;\n  using StorePack = cuda::elementwise::Packed<Dst, pack_size>;\n  const LoadPack* src = reinterpret_cast<const LoadPack*>(params.src);\n  StorePack* dst = 
reinterpret_cast<StorePack*>(params.dst);\n\n  size_t num_dims = params.num_dims;\n  const int64_t* src_strides = params.src_strides;\n  const int64_t* dst_strides = params.dst_strides;\n  auto functor = UnaryFunctor<DeviceType::kCUDA, op, Dst, Src>(params.attr0, params.attr1);\n\n  CUDA_1D_KERNEL_LOOP_T(IndexType, offset, params.count) {\n    IndexType src_offset = 0;\n    IndexType dst_offset = 0;\n    IndexType remaining = offset;\n#pragma unroll\n    for (int i = 0; i < max_dims; ++i) {\n      if (i < num_dims - 1) {\n        IndexType dst_index = params.dst_offset_to_index_helper.divides(remaining, i);\n        remaining = remaining - params.dst_offset_to_index_helper.mul(dst_index, i);\n        dst_offset += dst_index * dst_strides[i];\n        src_offset += params.src_index_mask[i] * dst_index * src_strides[i];\n      } else if (i == num_dims - 1) {\n        dst_offset += remaining * dst_strides[i];\n        src_offset += params.src_index_mask[i] * remaining * src_strides[i];\n      } else {\n        break;\n      }\n    }\n\n    LoadPack src_pack = src[src_offset];\n    StorePack dst_pack;\n#pragma unroll\n    for (int j = 0; j < pack_size; ++j) { dst_pack.elem[j] = functor(src_pack.elem[j]); }\n    dst[dst_offset] = dst_pack;\n  }\n}\n\ntemplate<UnaryOp op, typename Src, typename Dst, size_t max_dims, size_t pack_size,\n         typename IndexType>\nvoid LaunchKernel(CudaStream* stream, size_t num_dims, const int64_t* src_dims,\n                  const int64_t* src_strides, const Src* src, const int64_t* dst_dims,\n                  const int64_t* dst_strides, Dst* dst, bool continuous_output, Scalar attr0,\n                  Scalar attr1, size_t count) {\n  BroadcastElementwiseUnaryParams<Src, Dst, max_dims, IndexType> params;\n  for (size_t i = 0; i < num_dims; ++i) {\n    params.src_index_mask[i] = (src_dims[i] == 1) ? 
0 : 1;\n    params.src_strides[i] = src_strides[i];\n    params.dst_strides[i] = dst_strides[i];\n  }\n  params.dst_offset_to_index_helper =\n      OffsetToIndexWithStrideCalculator<IndexType, max_dims>(dst_dims, num_dims);\n  params.num_dims = num_dims;\n  params.src = src;\n  params.dst = dst;\n  params.count = static_cast<IndexType>(count);\n  params.attr0 = attr0;\n  params.attr1 = attr1;\n  params.dst_is_contiguous = continuous_output;\n\n  BroadcastElementwiseUnaryGpu<op, Src, Dst, max_dims, pack_size, IndexType>\n      <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, stream->cuda_stream()>>>(\n          params);\n}\n\ntemplate<UnaryOp op, typename Src, typename Dst, size_t max_dims, size_t pack_size>\nvoid DispatchIndexType(CudaStream* stream, size_t num_dims, const int64_t* src_dims,\n                       const int64_t* src_strides, const Src* src, const int64_t* dst_dims,\n                       const int64_t* dst_strides, Dst* dst, bool continuous_output, Scalar attr0,\n                       Scalar attr1) {\n  size_t count = GetElementCount(num_dims, dst_dims);\n  if (count < GetMaxVal<int32_t>() / 2) {\n    LaunchKernel<op, Src, Dst, max_dims, pack_size, int32_t>(\n        stream, num_dims, src_dims, src_strides, src, dst_dims, dst_strides, dst, continuous_output,\n        attr0, attr1, count);\n  } else {\n    LaunchKernel<op, Src, Dst, max_dims, pack_size, int64_t>(\n        stream, num_dims, src_dims, src_strides, src, dst_dims, dst_strides, dst, continuous_output,\n        attr0, attr1, count);\n  }\n}\n\ntemplate<UnaryOp op, typename Src, typename Dst, size_t max_dims>\nvoid DispatchPackSize(CudaStream* stream, size_t pack_size, size_t num_dims,\n                      const int64_t* src_dims, const int64_t* src_strides, const Src* src,\n                      const int64_t* dst_dims, const int64_t* dst_strides, Dst* dst,\n                      bool continuous_output, Scalar attr0, Scalar attr1) {\n  void (*func)(CudaStream* /*stream*/, size_t /*num_dims*/, const int64_t* /*src_dims*/,\n               const int64_t* /*src_strides*/, const Src* /*src*/, const int64_t* /*dst_dims*/,\n               const int64_t* /*dst_strides*/, Dst* /*dst*/, bool /*continuous_output*/,\n               Scalar /*attr0*/, Scalar /*attr1*/) = nullptr;\n  if (pack_size == 1) {\n    func = DispatchIndexType<op, Src, Dst, max_dims, 1>;\n  } else if (pack_size == 4) {\n    func = DispatchIndexType<op, Src, Dst, max_dims, 4>;\n  } else {\n    UNIMPLEMENTED();\n  }\n  func(stream, num_dims, src_dims, src_strides, src, dst_dims, dst_strides, dst, continuous_output,\n       attr0, attr1);\n}\n\ntemplate<UnaryOp op, typename Src, typename Dst>\nvoid DispatchNumDims(CudaStream* stream, size_t pack_size, size_t num_dims, const int64_t* src_dims,\n                     const int64_t* src_strides, const Src* src, const int64_t* dst_dims,\n                     const int64_t* dst_strides, Dst* dst, bool continuous_output, Scalar attr0,\n                     Scalar attr1) {\n  void (*func)(CudaStream* /*stream*/, size_t /*pack_size*/, size_t /*num_dims*/,\n               const int64_t* /*src_dims*/, const int64_t* /*src_strides*/, const Src* /*src*/,\n               const int64_t* /*dst_dims*/, const int64_t* /*dst_strides*/, Dst* /*dst*/,\n               bool /*continuous_output*/, Scalar /*attr0*/, Scalar /*attr1*/) = nullptr;\n  if (num_dims == 1) {\n    func = DispatchPackSize<op, Src, Dst, 1>;\n  } else if (num_dims == 2) {\n    func = DispatchPackSize<op, Src, Dst, 2>;\n  } else if 
(num_dims == 3) {\n    func = DispatchPackSize<op, Src, Dst, 3>;\n  } else if (num_dims == 4) {\n    func = DispatchPackSize<op, Src, Dst, 4>;\n  } else if (num_dims <= kMaxNumDims) {\n    func = DispatchPackSize<op, Src, Dst, kMaxNumDims>;\n  } else {\n    UNIMPLEMENTED();\n  }\n  func(stream, pack_size, num_dims, src_dims, src_strides, src, dst_dims, dst_strides, dst,\n       continuous_output, attr0, attr1);\n}\n\ntemplate<UnaryOp op, typename Src, typename Dst>\nvoid LaunchWithSimplified(CudaStream* stream, size_t simplified_num_dims,\n                          int64_t* simplified_src_dims, int64_t* simplified_src_strides,\n                          const Src* src, int64_t* simplified_dst_dims,\n                          int64_t* simplified_dst_strides, Dst* dst, Scalar attr0, Scalar attr1) {\n  CHECK_LE(simplified_num_dims, kMaxNumDims);\n  bool src_enable_pack = (simplified_src_strides[simplified_num_dims - 1] == 1);\n  bool dst_enable_pack = (simplified_dst_strides[simplified_num_dims - 1] == 1);\n  size_t pack_size = 1;\n  if (src_enable_pack && dst_enable_pack) {\n    pack_size = GetPackSize<kMaxPackSize, Src, Dst>(simplified_num_dims, simplified_src_dims, src,\n                                                    simplified_dst_dims, dst);\n  }\n  bool continuous_output = true;\n  for (int i = simplified_num_dims - 1; i >= 0; i--) {\n    if ((i == simplified_num_dims - 1 && simplified_dst_strides[i] != 1)\n        || (i != simplified_num_dims - 1\n            && simplified_dst_strides[i]\n                   != simplified_dst_strides[i + 1] * simplified_dst_dims[i + 1])) {\n      continuous_output = false;\n      break;\n    }\n  }\n  simplified_src_dims[simplified_num_dims - 1] /= pack_size;\n  simplified_dst_dims[simplified_num_dims - 1] /= pack_size;\n  for (int i = 0; i < simplified_num_dims - 1; i++) {\n    simplified_src_strides[i] /= pack_size;\n    simplified_dst_strides[i] /= pack_size;\n  }\n  DispatchNumDims<op, Src, Dst>(stream, pack_size, simplified_num_dims, simplified_src_dims,\n                                simplified_src_strides, src, simplified_dst_dims,\n                                simplified_dst_strides, dst, continuous_output, attr0, attr1);\n}\n\ntemplate<UnaryOp op, typename Src, typename Dst, size_t pack, bool tail>\n__global__ void LaunchFillKernel(UnaryFunctor<DeviceType::kCUDA, op, Dst, Src> functor, Dst* dst,\n                                 const Src* src, size_t pack_count, size_t count, size_t tail_count,\n                                 Dst* tail_dst) {\n  using StorePack = cuda::elementwise::Packed<Dst, pack>;\n  StorePack pack_value;\n  Dst value = functor(*src);\n#pragma unroll\n  for (size_t i = 0; i < pack; ++i) { pack_value.elem[i] = value; }\n  StorePack* pack_dst = reinterpret_cast<StorePack*>(dst);\n  CUDA_1D_KERNEL_LOOP_T(size_t, i, pack_count) { pack_dst[i] = pack_value; }\n  if (tail) {\n    CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = value; }\n  }\n}\n\ntemplate<UnaryOp op, typename Src, typename Dst, size_t pack>\ntypename std::enable_if<(pack != 0), void>::type LaunchPackFill(CudaStream* stream, Dst* dst,\n                                                                const Src* src, size_t count,\n                                                                Scalar attr0, Scalar attr1) {\n  const size_t pack_count = count / pack;\n  const size_t tail_offset = pack_count * pack;\n  const size_t tail_count = count - tail_offset;\n  auto functor = UnaryFunctor<DeviceType::kCUDA, op, Dst, Src>(attr0, attr1);\n  
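// Both branches launch the same grid; the boolean template argument selects, at\n  // compile time, whether the kernel also writes the unpacked tail elements.\n  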
if (tail_count > 0) {\n    LaunchFillKernel<op, Src, Dst, pack, true>\n        <<<BlocksNum4ThreadsNum(pack_count), kCudaThreadsNumPerBlock, 0, stream->cuda_stream()>>>(\n            functor, dst, src, pack_count, count, tail_count, dst + tail_offset);\n  } else {\n    LaunchFillKernel<op, Src, Dst, pack, false>\n        <<<BlocksNum4ThreadsNum(pack_count), kCudaThreadsNumPerBlock, 0, stream->cuda_stream()>>>(\n            functor, dst, src, pack_count, count, tail_count, dst + tail_offset);\n  }\n}\n\ntemplate<UnaryOp op, typename Src, typename Dst, size_t pack>\ntypename std::enable_if<(pack == 0), void>::type LaunchPackFill(CudaStream* stream, Dst* dst,\n                                                                const Src* src, size_t count,\n                                                                Scalar attr0, Scalar attr1) {\n  LOG(FATAL) << \"wrong alignment\";\n}\n\ntemplate<UnaryOp op, typename Src, typename Dst>\nvoid LaunchFill(CudaStream* stream, Dst* dst, const Src* src, size_t count, Scalar attr0,\n                Scalar attr1) {\n  auto uintptr = reinterpret_cast<std::uintptr_t>(dst);\n  if (uintptr % 16 == 0 && count * sizeof(Dst) >= 16) {\n    LaunchPackFill<op, Src, Dst, 16 / sizeof(Dst)>(stream, dst, src, count, attr0, attr1);\n  } else if (uintptr % 8 == 0 && count * sizeof(Dst) >= 8) {\n    LaunchPackFill<op, Src, Dst, 8 / sizeof(Dst)>(stream, dst, src, count, attr0, attr1);\n  } else if (uintptr % 4 == 0 && count * sizeof(Dst) >= 4) {\n    LaunchPackFill<op, Src, Dst, 4 / sizeof(Dst)>(stream, dst, src, count, attr0, attr1);\n  } else if (uintptr % 2 == 0 && count * sizeof(Dst) >= 2) {\n    LaunchPackFill<op, Src, Dst, 2 / sizeof(Dst)>(stream, dst, src, count, attr0, attr1);\n  } else {\n    // For multi-byte Dst, 1 / sizeof(Dst) is 0, selecting the (pack == 0) overload\n    // above, which aborts: a pointer aligned below sizeof(Dst) is a usage error.\n    LaunchPackFill<op, Src, Dst, 1 / sizeof(Dst)>(stream, dst, src, count, attr0, attr1);\n  }\n}\n\ntemplate<UnaryOp unary_op, typename Src, DataType src_type, typename Dst, DataType dst_type>\nclass BroadcastElementwiseUnaryImpl : public BroadcastElementwiseUnary {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnaryImpl);\n  BroadcastElementwiseUnaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}\n  ~BroadcastElementwiseUnaryImpl() override = default;\n\n  void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims, const void* src,\n              size_t num_dst_dims, const int64_t* dst_dims, void* dst) override {\n    CHECK_GT(num_src_dims, 0) << \"num_src_dims must be greater than 0\";\n    CHECK_GT(num_dst_dims, 0) << \"num_dst_dims must be greater than 0\";\n    int64_t src_strides[kMaxNumDims];\n    int64_t dst_strides[kMaxNumDims];\n    // initialize contiguous strides from the dims\n    for (int i = num_src_dims - 1; i < kMaxNumDims; ++i) { src_strides[i] = 1; }\n    for (int i = num_src_dims - 2; i >= 0; --i) {\n      src_strides[i] = src_dims[i + 1] * src_strides[i + 1];\n    }\n\n    for (int i = num_dst_dims - 1; i < kMaxNumDims; ++i) { dst_strides[i] = 1; }\n    for (int i = num_dst_dims - 2; i >= 0; --i) {\n      dst_strides[i] = dst_dims[i + 1] * dst_strides[i + 1];\n    }\n    Launch(stream, num_src_dims, src_dims, src_strides, src, num_dst_dims, dst_dims, dst_strides,\n           dst);\n  }\n\n  void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims,\n              const int64_t* src_strides, const void* src_ptr, size_t num_dst_dims,\n              const int64_t* dst_dims, const int64_t* dst_strides, void* dst_ptr) override {\n    CHECK_GT(num_src_dims, 0) << \"num_src_dims must be greater than 0\";\n    CHECK_GT(num_dst_dims, 0) << 
\"num_dst_dims must greater than 0\";\n    auto* cuda_stream = stream->As<CudaStream>();\n    Dst* dst = reinterpret_cast<Dst*>(dst_ptr);\n    const Src* src = reinterpret_cast<const Src*>(src_ptr);\n    size_t simplified_num_dims = 0;\n    int permutation_list[kMaxNumDims];\n    int64_t permutation_src_dims[kMaxNumDims];\n    int64_t simplified_src_dims[kMaxNumDims];\n    int64_t simplified_dst_dims[kMaxNumDims];\n    int64_t simplified_src_strides[kMaxNumDims];\n    int64_t simplified_dst_strides[kMaxNumDims];\n    SimplifyBroadcastDims<kMaxNumDims>(num_src_dims, src_dims, src_strides, num_dst_dims, dst_dims,\n                                       dst_strides, &simplified_num_dims, simplified_src_dims,\n                                       simplified_src_strides, simplified_dst_dims,\n                                       simplified_dst_strides);\n    bool permutable = InferPermutable<kMaxNumDims>(\n        simplified_num_dims, simplified_src_strides, simplified_dst_strides, simplified_src_dims,\n        simplified_dst_dims, permutation_list, permutation_src_dims, unary_op);\n    std::unique_ptr<Permute> permute =\n        NewPrimitive<PermuteFactory>(DeviceType::kCUDA, simplified_num_dims);\n    CheckInplace(simplified_num_dims, simplified_src_dims, src, simplified_dst_dims, dst);\n    CheckInplace(simplified_num_dims, simplified_src_strides, src, simplified_dst_strides, dst);\n    if (simplified_num_dims == 1 && simplified_src_dims[0] == 1) {\n      const int64_t elem_cnt = simplified_dst_dims[0];\n      LaunchFill<unary_op, Src, Dst>(cuda_stream, dst, src, elem_cnt, attr0, attr1);\n    } else if (simplified_num_dims == 1 && simplified_src_strides[0] == 1\n               && simplified_dst_strides[0] == 1) {\n      const int64_t elem_cnt = simplified_src_dims[0];\n      auto functor = UnaryFunctor<DeviceType::kCUDA, unary_op, Dst, Src>(attr0, attr1);\n      OF_CUDA_CHECK((cuda::elementwise::Unary<decltype(functor), Dst, Src>(\n          functor, elem_cnt, dst, src, cuda_stream->cuda_stream())));\n    } else if (permutable && src_type == dst_type && permute) {\n      permute->Launch(stream, dst_type, simplified_num_dims, permutation_src_dims, src_ptr,\n                      permutation_list, dst_ptr);\n    } else {\n      // fall back to normal cases\n      LaunchWithSimplified<unary_op, Src, Dst>(\n          cuda_stream, simplified_num_dims, simplified_src_dims, simplified_src_strides, src,\n          simplified_dst_dims, simplified_dst_strides, dst, attr0, attr1);\n    }\n  }\n\n protected:\n  Scalar attr0, attr1;\n};\n\ntemplate<UnaryOp unary_op, typename Src, DataType src_type, typename Dst, DataType dst_type>\nstd::unique_ptr<BroadcastElementwiseUnary> NewBroadcastElementwiseUnary(Scalar attr0,\n                                                                        Scalar attr1) {\n  return std::unique_ptr<BroadcastElementwiseUnary>(\n      new BroadcastElementwiseUnaryImpl<unary_op, Src, src_type, Dst, dst_type>(attr0, attr1));\n}\n\nclass BroadcastElementwiseUnaryFactoryImpl : public BroadcastElementwiseUnaryFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnaryFactoryImpl);\n  BroadcastElementwiseUnaryFactoryImpl() = default;\n  ~BroadcastElementwiseUnaryFactoryImpl() override = default;\n\n  std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp op, DataType src_type, DataType dst_type,\n                                                 size_t max_num_dims) override {\n    return New(op, src_type, dst_type, max_num_dims, Scalar(), Scalar());\n  }\n\n  
std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp op, DataType src_type, DataType dst_type,\n                                                 size_t max_num_dims, Scalar attr0) override {\n    return New(op, src_type, dst_type, max_num_dims, attr0, Scalar());\n  }\n\n  std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp unary_op, DataType src_type,\n                                                 DataType dst_type, size_t max_num_dims,\n                                                 Scalar attr0, Scalar attr1) override {\n    if (max_num_dims > kMaxNumDims) { return nullptr; }\n#define MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair)          \\\n  {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)),  \\\n   NewBroadcastElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(dtype_pair),                      \\\n                                OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_FIRST(dtype_pair), \\\n                                OF_PP_PAIR_SECOND(dtype_pair)>},\n\n#define MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY(unary_op, src_dtype_pair, dst_dtype_pair) \\\n  {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(src_dtype_pair),                              \\\n                   OF_PP_PAIR_SECOND(dst_dtype_pair)),                                       \\\n   NewBroadcastElementwiseUnary<                                                             \\\n       unary_op, OF_PP_PAIR_FIRST(src_dtype_pair), OF_PP_PAIR_SECOND(src_dtype_pair),        \\\n       OF_PP_PAIR_FIRST(dst_dtype_pair), OF_PP_PAIR_SECOND(dst_dtype_pair)>},\n\n    static const std::map<std::tuple<UnaryOp, DataType, DataType>,\n                          std::function<std::unique_ptr<BroadcastElementwiseUnary>(Scalar, Scalar)>>\n        new_broadcast_elementwise_unary_handle{\n            // For All Type OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY,\n                                             UNARY_IDENTITY_SEQ, CUDA_PRIMITIVE_REAL_TYPE_SEQ)\n\n                OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                    MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY, UNARY_IDENTITY_SEQ,\n                    CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ)\n\n            // For Cast OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY, BROADCAST_ELEMENTWISE_CAST_OP_SEQ,\n                CUDA_PRIMITIVE_CAST_REAL_TYPE_SEQ,\n                CUDA_PRIMITIVE_CAST_REAL_TYPE_SEQ CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ)\n\n                OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                    MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY, BROADCAST_ELEMENTWISE_CAST_OP_SEQ,\n                    CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ)\n\n        };\n\n#undef MAKE_NEW_BROADCAST_ELEMENTWISE_UNARY_ENTRY\n#undef MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY\n\n    const auto iter =\n        new_broadcast_elementwise_unary_handle.find(std::make_tuple(unary_op, src_type, dst_type));\n    if (iter != new_broadcast_elementwise_unary_handle.end()) {\n      return iter->second(attr0, attr1);\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastElementwiseUnaryFactory,\n                           BroadcastElementwiseUnaryFactoryImpl);\n\n}  // namespace\n}  // namespace broadcast_elementwise_unary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/broadcast_matmul.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_matmul.h\"\n#include \"oneflow/core/ep/common/primitive/broadcast_matmul.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/ep/cuda/cuda_matmul_mode.h\"\n#include <cuda.h>\n#include <cuComplex.h>\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace broadcast_matmul {\n\nnamespace internal {\n\nnamespace {\n\nconstexpr size_t kMaxNumDims = 8;\n\nOptional<cudaDataType_t> OptCudaDataType(DataType data_type) {\n  switch (data_type) {\n    case kFloat: return CUDA_R_32F;\n    case kDouble: return CUDA_R_64F;\n    case kFloat16: return CUDA_R_16F;\n    case kComplex64: return CUDA_C_32F;\n    case kComplex128: return CUDA_C_64F;\n#if CUDA_VERSION >= 11000\n    case kBFloat16: return CUDA_R_16BF;\n#endif  // CUDA_VERSION >= 11000\n    default: return NullOpt;\n  }\n}\n\ncudaDataType_t GetCudaDataType(DataType data_type) {\n  auto cuda_data_type = OptCudaDataType(data_type);\n  CHECK(cuda_data_type.has_value());\n  return cuda_data_type.value_or(CUDA_R_32F);\n}\n\nunion CublasScalarParameter {\n  double d;\n  float s;\n  half h;\n  cuComplex c;\n  cuDoubleComplex z;\n};\n\nCublasScalarParameter GetCublasScalarParameter(Scalar scalar, cublasComputeType_t compute_type) {\n  CublasScalarParameter sp{};\n  if (compute_type == CUBLAS_COMPUTE_64F) {\n    sp.d = scalar.Value<double>();\n  } else if (compute_type == CUBLAS_COMPUTE_32F_PEDANTIC\n             || compute_type == CUBLAS_COMPUTE_32F_FAST_TF32\n             || compute_type == CUBLAS_COMPUTE_32F) {\n    sp.s = scalar.Value<float>();\n  } else if (compute_type == CUBLAS_COMPUTE_16F) {\n    sp.h = static_cast<half>(scalar.Value<float>());\n  } else {\n    UNIMPLEMENTED();\n  }\n  return sp;\n}\n\ncudaDataType_t GetCublasScalarType(DataType data_type) {\n  switch (data_type) {\n    case kFloat: return CUDA_R_32F;\n    case kDouble: return CUDA_R_64F;\n    case kComplex64: return CUDA_C_32F;\n    case kComplex128: return CUDA_C_64F;\n    default: return CUDA_R_32F;\n  }\n}\n\ncublasComputeType_t GetComputeType(DataType data_type, CudaStream* cuda_stream) {\n  switch (data_type) {\n    case kFloat: {\n      if (CudaMatmulMode::is_matmul_allow_tf32()) {\n        return CUBLAS_COMPUTE_32F_FAST_TF32;\n      } else {\n        // Starting with cuBLAS version 11.0.0, the library will automatically make use of Tensor\n        // Core capabilities wherever possible, unless they are explicitly disabled by selecting\n        // pedantic compute modes in cuBLAS\n        return CUBLAS_COMPUTE_32F_PEDANTIC;\n      }\n    }\n    case kDouble: return CUBLAS_COMPUTE_64F;\n    case kFloat16: {\n      if (cuda_stream->device_properties().major >= 5) {\n        return CUBLAS_COMPUTE_32F;\n      } else {\n        return 
CUBLAS_COMPUTE_16F;\n      }\n    }\n    case kComplex64: {\n      if (CudaMatmulMode::is_matmul_allow_tf32()) {\n        return CUBLAS_COMPUTE_32F_FAST_TF32;\n      } else {\n        return CUBLAS_COMPUTE_32F_PEDANTIC;\n      }\n    }\n    case kComplex128: return CUBLAS_COMPUTE_64F;\n#if CUDA_VERSION >= 11000\n    case kBFloat16: return CUBLAS_COMPUTE_32F;\n#endif  // CUDA_VERSION >= 11000\n    default: UNIMPLEMENTED(); return CUBLAS_COMPUTE_32F;\n  }\n}\n\nvoid LaunchBroadcastMatmul(Stream* stream, DataType data_type, BlasTransposeType transpose_a,\n                           BlasTransposeType transpose_b, int64_t num_batch_dims,\n                           const int64_t* broadcast_batch_dims, const int64_t* a_batch_dims,\n                           const int64_t* b_batch_dims, const int64_t* c_batch_dims, int64_t m,\n                           int64_t n, int64_t k, Scalar alpha, const void* a, const void* b,\n                           Scalar beta, void* c) {\n  auto* cuda_stream = stream->As<CudaStream>();\n  const auto cuda_data_type = GetCudaDataType(data_type);\n  const auto compute_type = GetComputeType(data_type, cuda_stream);\n  const auto sp_alpha = GetCublasScalarParameter(alpha, compute_type);\n  const auto GetCublasOperation = [](BlasTransposeType transpose_type, DataType data_type) {\n    if (transpose_type == BlasTransposeType::N) {\n      return CUBLAS_OP_N;\n    } else if (transpose_type == BlasTransposeType::T) {\n      return DType(data_type).is_complex() ? CUBLAS_OP_C : CUBLAS_OP_T;\n    } else {\n      UNIMPLEMENTED();\n      return CUBLAS_OP_N;\n    }\n  };\n  // cuBLAS is column-major, so compute c^T = b^T * a^T: the roles of a and b (and of\n  // m and n) are swapped when filling in the cuBLAS arguments below.\n  const cublasOperation_t cublas_trans_a = GetCublasOperation(transpose_b, data_type);\n  const cublasOperation_t cublas_trans_b = GetCublasOperation(transpose_a, data_type);\n  const int cublas_m = n;\n  const int cublas_n = m;\n  const int cublas_k = k;\n  int cublas_lda = 0;\n  if (transpose_b == BlasTransposeType::N) {\n    cublas_lda = n;\n  } else if (transpose_b == BlasTransposeType::T) {\n    cublas_lda = k;\n  } else {\n    UNIMPLEMENTED();\n  }\n  int cublas_ldb = 0;\n  if (transpose_a == BlasTransposeType::N) {\n    cublas_ldb = k;\n  } else if (transpose_a == BlasTransposeType::T) {\n    cublas_ldb = m;\n  } else {\n    UNIMPLEMENTED();\n  }\n  const int cublas_ldc = n;\n\n  CublasMathModeGuard guard(cuda_stream->cublas_handle());\n  if (data_type == DataType::kFloat16) {\n#if CUDA_VERSION < 11000\n    guard.SetMathMode(CUBLAS_TENSOR_OP_MATH);\n#else\n    cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;\n    // Set the disallow flag only when reduced-precision reduction is NOT allowed.\n    if (cuda_stream->device_properties().major >= 5\n        && !CudaMatmulMode::is_matmul_allow_fp16_reduced_precision_reduction()) {\n      cublas_flags = static_cast<cublasMath_t>(cublas_flags\n                                               | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);\n    }\n    guard.SetMathMode(cublas_flags);\n#endif  // CUDA_VERSION < 11000\n  }\n#if CUDA_VERSION >= 11000\n  cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;\n#else\n  cublasGemmAlgo_t algo =\n      (data_type == DataType::kFloat16) ? 
CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;\n#endif\n\n  if (num_batch_dims == 1 && c_batch_dims[0] != 1) {\n    const void* cublas_a = b;\n    const void* cublas_b = a;\n    void* cublas_c = c;\n    const int64_t a_batch_count = a_batch_dims[0];\n    const int64_t b_batch_count = b_batch_dims[0];\n    CHECK(a_batch_count == 1 || b_batch_count == 1 || a_batch_count == b_batch_count);\n    CHECK_GT(a_batch_count, 0);\n    CHECK_GT(b_batch_count, 0);\n    const int batch_count = std::max(a_batch_count, b_batch_count);\n    const long long int cublas_stride_a = b_batch_count == 1 ? 0 : cublas_m * cublas_k;\n    const long long int cublas_stride_b = a_batch_count == 1 ? 0 : cublas_k * cublas_n;\n    const long long int cublas_stride_c = cublas_m * cublas_n;\n    const auto sp_beta = GetCublasScalarParameter(beta, compute_type);\n    OF_CUBLAS_CHECK(cublasGemmStridedBatchedEx(\n        cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n, cublas_k,\n        &sp_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_stride_a, cublas_b, cuda_data_type,\n        cublas_ldb, cublas_stride_b, &sp_beta, cublas_c, cuda_data_type, cublas_ldc,\n        cublas_stride_c, batch_count, compute_type, algo));\n  } else {\n    auto func = [&](const void* batch_a, const void* batch_b, void* batch_c, Scalar batch_beta) {\n      const auto sp_beta = GetCublasScalarParameter(batch_beta, compute_type);\n      const void* cublas_a = batch_b;\n      const void* cublas_b = batch_a;\n      void* cublas_c = batch_c;\n      OF_CUBLAS_CHECK(cublasGemmEx(\n          cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n,\n          cublas_k, &sp_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_b, cuda_data_type,\n          cublas_ldb, &sp_beta, cublas_c, cuda_data_type, cublas_ldc, compute_type, algo));\n    };\n    ForEachMatmul<kMaxNumDims>(data_type, m, n, k, beta, num_batch_dims, broadcast_batch_dims,\n                               a_batch_dims, b_batch_dims, c_batch_dims, a, b, c, func);\n  }\n}\n\nclass BroadcastMatmulFactoryImpl : public BroadcastMatmulFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmulFactoryImpl);\n  BroadcastMatmulFactoryImpl() = default;\n  ~BroadcastMatmulFactoryImpl() override = default;\n\n  std::unique_ptr<BroadcastMatmul> New(DataType data_type, BlasTransposeType transpose_a,\n                                       BlasTransposeType transpose_b,\n                                       size_t max_num_dims) override {\n    auto cuda_data_type = OptCudaDataType(data_type);\n    if (max_num_dims <= kMaxNumDims && cuda_data_type.has_value()) {\n      return std::make_unique<BroadcastMatmulImpl<kMaxNumDims>>(data_type, transpose_a,\n                                                                transpose_b);\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastMatmulFactory, BroadcastMatmulFactoryImpl);\n\n}  // namespace\n\n}  // namespace internal\n\n}  // namespace broadcast_matmul\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/cast.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#include \"oneflow/core/ep/cuda/primitive/type_seq.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\ntemplate<typename To, typename From, typename = void>\nstruct CastFunctor {\n  __device__ To operator()(From from) const { return static_cast<To>(from); }\n};\n\ntemplate<typename To>\nstruct CastFunctor<To, half, typename std::enable_if<!std::is_same<To, half>::value>::type> {\n  __device__ To operator()(half from) const { return static_cast<To>(static_cast<float>(from)); }\n\n  __device__ void Apply2(To* to, const half* from) const {\n    const float2 f2 = __half22float2(*reinterpret_cast<const half2*>(from));\n    to[0] = static_cast<To>(f2.x);\n    to[1] = static_cast<To>(f2.y);\n  }\n};\n\ntemplate<typename From>\nstruct CastFunctor<half, From, typename std::enable_if<!std::is_same<From, half>::value>::type> {\n  __device__ half operator()(From from) const {\n    return static_cast<half>(static_cast<float>(from));\n  }\n\n  __device__ void Apply2(half* to, const From* from) const {\n    float2 f2;\n    f2.x = static_cast<float>(from[0]);\n    f2.y = static_cast<float>(from[1]);\n    *reinterpret_cast<half2*>(to) = __float22half2_rn(f2);\n  }\n};\n\n#if CUDA_VERSION >= 11000\n\ntemplate<typename To>\nstruct CastFunctor<To, nv_bfloat16,\n                   typename std::enable_if<!(std::is_same<To, nv_bfloat16>::value\n                                             || std::is_same<To, half>::value)>::type> {\n  __device__ To operator()(nv_bfloat16 from) const {\n    return static_cast<To>(static_cast<float>(from));\n  }\n};\n\ntemplate<typename From>\nstruct CastFunctor<nv_bfloat16, From,\n                   typename std::enable_if<!(std::is_same<From, nv_bfloat16>::value\n                                             || std::is_same<From, half>::value)>::type> {\n  __device__ nv_bfloat16 operator()(From from) const {\n    return static_cast<nv_bfloat16>(static_cast<float>(from));\n  }\n};\n\n#endif  // CUDA_VERSION >= 11000\n\ntemplate<typename From, typename To>\nclass CastImpl : public Cast {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CastImpl);\n  explicit CastImpl() = default;\n  ~CastImpl() override = default;\n\n  void Launch(Stream* stream, const void* from, void* to, size_t count) override {\n    auto* cuda_stream = stream->As<CudaStream>();\n    OF_CUDA_CHECK((cuda::elementwise::Unary<CastFunctor<To, From>, To, From>(\n        CastFunctor<To, From>(), count, reinterpret_cast<To*>(to),\n        reinterpret_cast<const From*>(from), cuda_stream->cuda_stream())));\n  }\n};\n\ntemplate<typename From, typename To>\nstd::unique_ptr<Cast> NewCast() {\n  return std::unique_ptr<Cast>(new CastImpl<From, To>());\n}\n\n#define CUDA_PRIMITIVE_CAST_TYPE_SEQ \\\n  CUDA_PRIMITIVE_BOOL_TYPE_SEQ       \\\n  
CUDA_PRIMITIVE_CHAR_TYPE_SEQ       \\\n  CUDA_PRIMITIVE_INT8_TYPE_SEQ       \\\n  CUDA_PRIMITIVE_UINT8_TYPE_SEQ      \\\n  CUDA_PRIMITIVE_INT32_TYPE_SEQ      \\\n  CUDA_PRIMITIVE_UINT32_TYPE_SEQ     \\\n  CUDA_PRIMITIVE_INT64_TYPE_SEQ      \\\n  CUDA_PRIMITIVE_UINT64_TYPE_SEQ     \\\n  CUDA_PRIMITIVE_FLOAT_TYPE_SEQ      \\\n  CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ     \\\n  CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ    \\\n  CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ\n\nclass CastFactoryImpl : public CastFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CastFactoryImpl);\n  CastFactoryImpl() = default;\n  ~CastFactoryImpl() override = default;\n\n  std::unique_ptr<Cast> New(DataType from, DataType to) override {\n#define MAKE_NEW_CAST_ENTRY(from_pair, to_pair)                              \\\n  {std::make_pair(OF_PP_PAIR_SECOND(from_pair), OF_PP_PAIR_SECOND(to_pair)), \\\n   NewCast<OF_PP_PAIR_FIRST(from_pair), OF_PP_PAIR_FIRST(to_pair)>},\n\n    static const std::map<std::pair<DataType, DataType>, std::function<std::unique_ptr<Cast>()>>\n        new_cast_handle{OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n            MAKE_NEW_CAST_ENTRY, CUDA_PRIMITIVE_CAST_TYPE_SEQ, CUDA_PRIMITIVE_CAST_TYPE_SEQ)};\n\n#undef MAKE_NEW_CAST_ENTRY\n\n    const auto it = new_cast_handle.find(std::make_pair(from, to));\n    if (it != new_cast_handle.end()) {\n      return it->second();\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, CastFactory, CastFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/constant_pad.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/constant_pad.h\"\n#include \"oneflow/core/ep/common/primitive/constant_pad.h\"\n#include \"oneflow/core/ep/cuda/primitive/type_seq.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <cuda_runtime.h>\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace {\n\ntemplate<size_t num_dims, typename IndexType, typename StorageType>\n__global__ void ConstantPadKernel(ConstantPadParams<num_dims, IndexType> params,\n                                  StorageType packed_pad_val) {\n  const StorageType* src = reinterpret_cast<const StorageType*>(params.src);\n  StorageType* dst = reinterpret_cast<StorageType*>(params.dst);\n  IndexType src_index[num_dims];\n  IndexType dst_index[num_dims];\n  CUDA_1D_KERNEL_LOOP_T(IndexType, linear_index, params.elem_cnt) {\n    params.dst_index_helper.OffsetToNdIndex(linear_index, dst_index);\n    bool if_pad = false;\n#pragma unroll\n    for (int i = 0; i < num_dims; i++) {\n      if (dst_index[i] >= params.valid_start[i] && dst_index[i] < params.valid_end[i]) {\n        src_index[i] = dst_index[i] - params.valid_start[i];\n      } else {\n        if_pad = true;\n        break;\n      }\n    }\n    StorageType dst_val = packed_pad_val;\n    if (!if_pad) {\n      const IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);\n      dst_val = src[src_offset];\n    }\n    dst[linear_index] = dst_val;\n  }\n}\n\ntemplate<>\ncuComplex GetValue<cuComplex>(Scalar value) {\n  const std::complex<float> cpp_value = GetValue<std::complex<float>>(value);\n  return cuComplex{cpp_value.real(), cpp_value.imag()};\n}\n\ntemplate<>\ncuDoubleComplex GetValue<cuDoubleComplex>(Scalar value) {\n  const std::complex<double> cpp_value = GetValue<std::complex<double>>(value);\n  return cuDoubleComplex{cpp_value.real(), cpp_value.imag()};\n}\n\ntemplate<>\nhalf GetValue<half>(Scalar value) {\n  return static_cast<half>(GetValue<float>(value));\n}\n\n#if CUDA_VERSION >= 11000\n\ntemplate<>\nnv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {\n  return static_cast<nv_bfloat16>(GetValue<float>(value));\n}\n\n#endif  // CUDA_VERSION >= 11000\n\ntemplate<size_t num_dims, typename IndexType, typename StorageType>\nvoid LaunchKernel(Stream* stream, ConstantPadParams<num_dims, IndexType> params,\n                  StorageType packed_pad_val, size_t elem_cnt) {\n  stream->As<CudaStream>()->LaunchKernelDefaultWaves(\n      (ConstantPadKernel<num_dims, IndexType, StorageType>), elem_cnt, params, packed_pad_val);\n}\n\ntemplate<size_t num_dims, typename IndexType, typename StorageType>\nvoid LaunchKernel(Stream* stream, void* dst, const int64_t* dst_dims, const void* src,\n                  const int64_t* src_dims, const int64_t* padding_before,\n                  const int64_t* padding_after, StorageType packed_pad_val, size_t elem_cnt) {\n  ConstantPadParams<num_dims, IndexType> params;\n  
params.dst_index_helper = OffsetToIndexCalculator<IndexType, num_dims>(dst_dims);\n  params.src_index_helper = NdIndexOffsetHelper<IndexType, num_dims>(src_dims);\n  params.dst = dst;\n  params.src = src;\n  for (int i = 0; i < num_dims; i++) {\n    params.valid_start[i] = padding_before[i];\n    params.valid_end[i] = dst_dims[i] - padding_after[i];\n  }\n  params.elem_cnt = elem_cnt;\n  LaunchKernel<num_dims, IndexType, StorageType>(stream, params, packed_pad_val, elem_cnt);\n}\n\ntemplate<size_t num_dims, typename StorageType>\nvoid DispatchIndexType(Stream* stream, void* dst, const int64_t* dst_dims, const void* src,\n                       const int64_t* src_dims, const int64_t* padding_before,\n                       const int64_t* padding_after, StorageType packed_pad_val, size_t elem_cnt) {\n  if (elem_cnt < GetMaxVal<int32_t>()) {\n    LaunchKernel<num_dims, int32_t, StorageType>(stream, dst, dst_dims, src, src_dims,\n                                                 padding_before, padding_after, packed_pad_val,\n                                                 elem_cnt);\n  } else {\n    LaunchKernel<num_dims, int64_t, StorageType>(stream, dst, dst_dims, src, src_dims,\n                                                 padding_before, padding_after, packed_pad_val,\n                                                 elem_cnt);\n  }\n}\n\ntemplate<size_t num_dims, typename T>\nvoid DispatchPackSize(Stream* stream, void* dst, int64_t* dst_dims, const void* src,\n                      int64_t* src_dims, int64_t* padding_before, int64_t* padding_after,\n                      T pad_val) {\n  constexpr int32_t max_packsize = GetMaxPackSize<T>();\n  size_t launch_pack_size = GetLaunchPackSize<max_packsize>(num_dims, dst, dst_dims, src, src_dims,\n                                                            padding_before, padding_after);\n\n  dst_dims[num_dims - 1] /= launch_pack_size;\n  src_dims[num_dims - 1] /= launch_pack_size;\n  padding_before[num_dims - 1] /= launch_pack_size;\n  padding_after[num_dims - 1] /= launch_pack_size;\n\n  size_t elem_cnt = 1;\n  for (int i = 0; i < num_dims; i++) { elem_cnt *= dst_dims[i]; }\n  if (launch_pack_size == 1) {\n    Pack<T, 1> packed_pad_val(pad_val);\n    DispatchIndexType<num_dims, PackType<T, 1>>(stream, dst, dst_dims, src, src_dims,\n                                                padding_before, padding_after,\n                                                packed_pad_val.storage, elem_cnt);\n  } else if (launch_pack_size == 2) {\n    Pack<T, 2> packed_pad_val(pad_val);\n    DispatchIndexType<num_dims, PackType<T, 2>>(stream, dst, dst_dims, src, src_dims,\n                                                padding_before, padding_after,\n                                                packed_pad_val.storage, elem_cnt);\n  } else if (launch_pack_size == 4) {\n    Pack<T, 4> packed_pad_val(pad_val);\n    DispatchIndexType<num_dims, PackType<T, 4>>(stream, dst, dst_dims, src, src_dims,\n                                                padding_before, padding_after,\n                                                packed_pad_val.storage, elem_cnt);\n  } else if (launch_pack_size == 8) {\n    Pack<T, 8> packed_pad_val(pad_val);\n    DispatchIndexType<num_dims, PackType<T, 8>>(stream, dst, dst_dims, src, src_dims,\n                                                padding_before, padding_after,\n                                                packed_pad_val.storage, elem_cnt);\n  } else if (launch_pack_size == 16) {\n    Pack<T, 16> 
packed_pad_val(pad_val);\n    DispatchIndexType<num_dims, PackType<T, 16>>(stream, dst, dst_dims, src, src_dims,\n                                                 padding_before, padding_after,\n                                                 packed_pad_val.storage, elem_cnt);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T>\nvoid LaunchWithSimplified(Stream* stream, size_t num_dims, void* dst, int64_t* dst_dims,\n                          const void* src, int64_t* src_dims, int64_t* padding_before,\n                          int64_t* padding_after, T pad_val) {\n  void (*func)(Stream* /*stream*/, void* /*dst*/, int64_t* /*dst_dims*/, const void* /*src*/,\n               int64_t* /*src_dims*/, int64_t* /*padding_before*/, int64_t* /*padding_after*/, T) =\n      nullptr;\n  if (num_dims == 1) {\n    func = DispatchPackSize<1, T>;\n  } else if (num_dims == 2) {\n    func = DispatchPackSize<2, T>;\n  } else if (num_dims == 3) {\n    func = DispatchPackSize<3, T>;\n  } else if (num_dims == 4) {\n    func = DispatchPackSize<4, T>;\n  } else if (num_dims == 5) {\n    func = DispatchPackSize<5, T>;\n  } else if (num_dims == 6) {\n    func = DispatchPackSize<6, T>;\n  } else if (num_dims == 7) {\n    func = DispatchPackSize<7, T>;\n  } else if (num_dims == 8) {\n    func = DispatchPackSize<8, T>;\n  } else {\n    UNIMPLEMENTED();\n  }\n  func(stream, dst, dst_dims, src, src_dims, padding_before, padding_after, pad_val);\n}\n\ntemplate<typename T>\nvoid SimplifyThenLaunch(Stream* stream, size_t num_dims, const int64_t* src_dims, const void* src,\n                        const int64_t* padding_before, const int64_t* padding_after, T pad_val,\n                        void* dst) {\n  CHECK_GT(num_dims, 0) << \"num_dims must be greater than 0\";\n  CHECK_LE(num_dims, kMaxNumDims);\n  int64_t simplified_dst_dims[kMaxNumDims];\n  int64_t simplified_src_dims[kMaxNumDims];\n  int64_t simplified_padding_before[kMaxNumDims];\n  int64_t simplified_padding_after[kMaxNumDims];\n  size_t simplified_num_dims = 1;\n  SimplifyPadDims(num_dims, src_dims, padding_before, padding_after, &simplified_num_dims,\n                  simplified_dst_dims, simplified_src_dims, simplified_padding_before,\n                  simplified_padding_after);\n  LaunchWithSimplified<T>(stream, simplified_num_dims, dst, simplified_dst_dims, src,\n                          simplified_src_dims, simplified_padding_before, simplified_padding_after,\n                          pad_val);\n}\n\ntemplate<typename T>\nclass ConstantPadImpl : public ConstantPad {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ConstantPadImpl);\n  ConstantPadImpl() = default;\n  ~ConstantPadImpl() override = default;\n\n  void Launch(Stream* stream, size_t num_dims, const int64_t* src_dims, const void* src,\n              const int64_t* padding_before, const int64_t* padding_after, Scalar pad_val,\n              void* dst) override {\n    SimplifyThenLaunch<T>(stream, num_dims, src_dims, src, padding_before, padding_after,\n                          GetValue<T>(pad_val), dst);\n  }\n};\n\ntemplate<typename T>\nstd::unique_ptr<ConstantPad> NewConstantPad() {\n  return std::unique_ptr<ConstantPad>(new ConstantPadImpl<T>());\n}\n\nclass ConstantPadFactoryImpl : public ConstantPadFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ConstantPadFactoryImpl);\n  ConstantPadFactoryImpl() = default;\n  ~ConstantPadFactoryImpl() override = default;\n\n  std::unique_ptr<ConstantPad> New(DataType data_type) override {\n#define MAKE_NEW_CONSTANT_PAD_ENTRY(type_cpp, type_proto) 
{type_proto, NewConstantPad<type_cpp>},\n\n    static const std::map<DataType, std::function<std::unique_ptr<ConstantPad>()>>\n        new_constant_pad_handle{\n            OF_PP_FOR_EACH_TUPLE(MAKE_NEW_CONSTANT_PAD_ENTRY,\n                                 CUDA_PRIMITIVE_REAL_TYPE_SEQ CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ)};\n\n#undef MAKE_NEW_CONSTANT_PAD_ENTRY\n\n    const auto it = new_constant_pad_handle.find(data_type);\n    if (it != new_constant_pad_handle.end()) {\n      return it->second();\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, ConstantPadFactory, ConstantPadFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/copy_nd.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/ep/include/primitive/copy_nd.h\"\n#include \"oneflow/core/ep/common/primitive/copy_nd.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <cuda_runtime.h>\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\ntemplate<size_t num_dims, size_t movement_size, typename IndexType>\n__global__ void CopyNdKernel(CopyNdKernelParams<num_dims, IndexType> params) {\n  using T = typename std::aligned_storage<movement_size, movement_size>::type;\n  const T* src = reinterpret_cast<const T*>(params.src);\n  T* dst = reinterpret_cast<T*>(params.dst);\n  IndexType copy_index[num_dims];\n  IndexType src_index[num_dims];\n  IndexType dst_index[num_dims];\n  CUDA_1D_KERNEL_LOOP_T(IndexType, i, params.count) {\n    params.copy_index_helper.OffsetToNdIndex(i, copy_index);\n#pragma unroll\n    for (size_t j = 0; j < num_dims; ++j) {\n      src_index[j] = params.src_pos[j] + copy_index[j];\n      dst_index[j] = params.dst_pos[j] + copy_index[j];\n    }\n    const IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);\n    const IndexType dst_offset = params.dst_index_helper.NdIndexToOffset(dst_index);\n    dst[dst_offset] = src[src_offset];\n  }\n}\n\ntemplate<size_t num_dims, size_t movement_size, typename IndexType>\nvoid LaunchKernel(Stream* stream, CopyNdKernelParams<num_dims, IndexType> params) {\n  cudaStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();\n  CopyNdKernel<num_dims, movement_size, IndexType>\n      <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);\n}\n\nclass CopyNdImpl : public CopyNd {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CopyNdImpl);\n  CopyNdImpl() = default;\n  ~CopyNdImpl() override = default;\n\n  void Launch(Stream* stream, DataType data_type, size_t num_dims, void* dst,\n              const int64_t* dst_dims, const int64_t* dst_pos, const void* src,\n              const int64_t* src_dims, const int64_t* src_pos,\n              const int64_t* extent) const override {\n    SimplifyThenLaunch(stream, data_type, num_dims, dst, dst_dims, dst_pos, src, src_dims, src_pos,\n                       extent);\n  }\n};\n\nclass CopyNdFactoryImpl : public CopyNdFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CopyNdFactoryImpl);\n  CopyNdFactoryImpl() = default;\n  ~CopyNdFactoryImpl() override = default;\n\n  std::unique_ptr<CopyNd> New(size_t max_num_dims) override {\n    if (max_num_dims <= kMaxNumDims) {\n      return std::unique_ptr<CopyNd>(new CopyNdImpl());\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, CopyNdFactory, CopyNdFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/elementwise_unary.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/common/primitive/elementwise_unary.h\"\n#include \"oneflow/core/ep/cuda/primitive/unary_functor.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\ntemplate<UnaryOp unary_op, typename Src, typename Dst>\nclass ElementwiseUnaryImpl : public ElementwiseUnary {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryImpl);\n  ElementwiseUnaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}\n  ~ElementwiseUnaryImpl() override = default;\n\n  void Launch(Stream* stream, const void* src, void* dst, size_t count) override {\n    auto* cuda_stream = stream->As<CudaStream>();\n    auto functor = UnaryFunctor<DeviceType::kCUDA, unary_op, Dst, Src>(attr0, attr1);\n    OF_CUDA_CHECK((cuda::elementwise::Unary<decltype(functor), Dst, Src>(\n        functor, count, reinterpret_cast<Dst*>(dst), reinterpret_cast<const Src*>(src),\n        cuda_stream->cuda_stream())));\n  }\n\n protected:\n  Scalar attr0, attr1;\n};\n\ntemplate<UnaryOp unary_op, typename Src, typename Dst>\nstd::unique_ptr<ElementwiseUnary> NewElementwiseUnary(Scalar attr0, Scalar attr1) {\n  return std::unique_ptr<ElementwiseUnary>(\n      new ElementwiseUnaryImpl<unary_op, Src, Dst>(attr0, attr1));\n}\n\nclass ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryFactoryImpl);\n  ElementwiseUnaryFactoryImpl() = default;\n  ~ElementwiseUnaryFactoryImpl() override = default;\n\n  std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type,\n                                        DataType dst_dtype) override {\n    return New(unary_op, src_type, dst_dtype, Scalar(), Scalar());\n  }\n\n  std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type, DataType dst_dtype,\n                                        Scalar attr0) override {\n    return New(unary_op, src_type, dst_dtype, attr0, Scalar());\n  }\n\n  std::unique_ptr<ElementwiseUnary> New(UnaryOp unary_op, DataType src_type, DataType dst_dtype,\n                                        Scalar attr0, Scalar attr1) override {\n#define MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair)                   \\\n  {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)), \\\n   NewElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(dtype_pair), OF_PP_PAIR_FIRST(dtype_pair)>},\n\n#define MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY(unary_op, src_type_pair, dst_dtype_pair)  \\\n  {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(src_type_pair), OF_PP_PAIR_SECOND(dst_dtype_pair)), \\\n   NewElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(src_type_pair),                                  \\\n                       OF_PP_PAIR_FIRST(dst_dtype_pair)>},\n\n    static const std::map<std::tuple<UnaryOp, DataType, DataType>,\n                          
std::function<std::unique_ptr<ElementwiseUnary>(Scalar, Scalar)>>\n        new_elementwise_unary_handle{\n            // For All Type OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,\n                                             UNARY_MATH_OP_SEQ, CUDA_PRIMITIVE_REAL_TYPE_SEQ)\n            // For Float Type OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,\n                                             UNARY_FLOATING_MATH_OP_SEQ,\n                                             CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)\n\n            // For Complex Type OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,\n                                             UNARY_COMPLEX_C2C_OP_SEQ,\n                                             CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ)\n\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_C2R_OP_SEQ,\n                CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)\n\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_R2C_OP_SEQ,\n                CUDA_PRIMITIVE_FLOATING_TYPE_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ)\n\n            // For Int Type OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,\n                                             UNARY_INT_MATH_OP_SEQ, CUDA_PRIMITIVE_INT_TYPE_SEQ)\n\n            // For Utils OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,\n                                             UNARY_UTILS_OP_SEQ, UTIL_OPS_DATA_TYPE_SEQ,\n                                             CUDA_PRIMITIVE_BOOL_TYPE_SEQ)\n\n            // For Logical OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,\n                                             UNARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_REAL_TYPE_SEQ,\n                                             CUDA_PRIMITIVE_BOOL_TYPE_SEQ)\n\n            // For Bitwise OP\n            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n                MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_BITWISE_OP_SEQ,\n                CUDA_PRIMITIVE_INT_TYPE_SEQ CUDA_PRIMITIVE_BOOL_TYPE_SEQ)};\n\n#undef MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY\n\n#undef MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY\n    const auto it =\n        new_elementwise_unary_handle.find(std::make_tuple(unary_op, src_type, dst_dtype));\n    if (it != new_elementwise_unary_handle.end()) {\n      return it->second(attr0, attr1);\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, ElementwiseUnaryFactory, ElementwiseUnaryFactoryImpl);\n\n}  // namespace\n}  // namespace primitive\n}  // namespace ep\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/fill.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/ep/cuda/primitive/type_seq.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\ntemplate<size_t size>\nusing Storage = typename std::aligned_storage<size, size>::type;\n\ntemplate<typename T, size_t pack>\nunion Pack {\n  static constexpr size_t size = sizeof(T) * pack;\n  explicit __device__ __host__ Pack(T value) {\n    static_assert(sizeof(Pack) == size, \"\");\n    static_assert(alignof(Pack) == size, \"\");\n#pragma unroll\n    for (size_t i = 0; i < pack; ++i) { elem[i] = value; }\n  }\n  T elem[pack];\n  Storage<size> storage;\n};\n\ntemplate<typename T, size_t pack>\n__global__ void FillGpu(T* dst, T value, size_t count) {\n  const size_t pack_count = count / pack;\n  Pack<T, pack> pack_value(value);\n  auto* pack_dst = reinterpret_cast<decltype(pack_value.storage)*>(dst);\n  CUDA_1D_KERNEL_LOOP_T(size_t, i, pack_count) { pack_dst[i] = pack_value.storage; }\n  T* tail_dst = dst + pack_count * pack;\n  const size_t tail_count = count - pack_count * pack;\n  CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = value; }\n}\n\ntemplate<typename T>\nT GetValue(Scalar value) {\n  return value.Value<T>();\n}\n\ntemplate<>\nhalf GetValue<half>(Scalar value) {\n  return static_cast<half>(GetValue<float>(value));\n}\n\ntemplate<>\ncuComplex GetValue<cuComplex>(Scalar value) {\n  const std::complex<float> cpp_value = GetValue<std::complex<float>>(value);\n  return cuComplex{cpp_value.real(), cpp_value.imag()};\n}\n\ntemplate<>\ncuDoubleComplex GetValue<cuDoubleComplex>(Scalar value) {\n  const std::complex<double> cpp_value = GetValue<std::complex<double>>(value);\n  return cuDoubleComplex{cpp_value.real(), cpp_value.imag()};\n}\n\n#if CUDA_VERSION >= 11000\n\ntemplate<>\nnv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {\n  return static_cast<nv_bfloat16>(GetValue<float>(value));\n}\n\n#endif  // CUDA_VERSION >= 11000\n\ntemplate<typename T, size_t pack>\ntypename std::enable_if<(pack != 0), void>::type LaunchPackFill(cudaStream_t stream, T* dst,\n                                                                T value, size_t count) {\n  FillGpu<T, pack>\n      <<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0, stream>>>(dst, value, count);\n}\n\ntemplate<typename T, size_t pack>\ntypename std::enable_if<(pack == 0), void>::type LaunchPackFill(cudaStream_t stream, T* dst,\n                                                                T value, size_t count) {\n  LOG(FATAL) << \"wrong alignment\";\n}\n\ntemplate<typename T>\nvoid LaunchFill(cudaStream_t stream, T* dst, T value, size_t count) {\n  auto uintptr = reinterpret_cast<std::uintptr_t>(dst);\n  if (uintptr % 16 == 0) {\n    LaunchPackFill<T, 16 / sizeof(T)>(stream, dst, value, count);\n  } else if (uintptr % 8 == 0) {\n    LaunchPackFill<T, 8 / 
sizeof(T)>(stream, dst, value, count);\n  } else if (uintptr % 4 == 0) {\n    LaunchPackFill<T, 4 / sizeof(T)>(stream, dst, value, count);\n  } else if (uintptr % 2 == 0) {\n    LaunchPackFill<T, 2 / sizeof(T)>(stream, dst, value, count);\n  } else {\n    LaunchPackFill<T, 1 / sizeof(T)>(stream, dst, value, count);\n  }\n}\n\ntemplate<typename T>\nclass FillImpl : public Fill {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(FillImpl);\n  FillImpl() = default;\n  ~FillImpl() override = default;\n\n  void Launch(Stream* stream, void* dst, Scalar value, size_t count) override {\n    cudaStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();\n    LaunchFill<T>(cuda_stream, reinterpret_cast<T*>(dst), GetValue<T>(value), count);\n  }\n};\n\ntemplate<typename T>\nstd::unique_ptr<Fill> NewFill() {\n  return std::unique_ptr<Fill>(new FillImpl<T>());\n}\n\nclass FillFactoryImpl : public FillFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(FillFactoryImpl);\n  FillFactoryImpl() = default;\n  ~FillFactoryImpl() override = default;\n\n  std::unique_ptr<Fill> New(DataType data_type) override {\n#define MAKE_NEW_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewFill<type_cpp>},\n\n    static const std::map<DataType, std::function<std::unique_ptr<Fill>()>> new_fill_handle{\n        OF_PP_FOR_EACH_TUPLE(MAKE_NEW_FILL_ENTRY,\n                             CUDA_PRIMITIVE_REAL_TYPE_SEQ CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ\n                                 CUDA_PRIMITIVE_INT16_TYPE_SEQ)};\n\n#undef MAKE_NEW_FILL_ENTRY\n\n    const auto it = new_fill_handle.find(data_type);\n    if (it != new_fill_handle.end()) {\n      return it->second();\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, FillFactory, FillFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/math_elementwise_unary_math_grad_0.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \\\n      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                 BINARY_MATH_BACKWARD_OP_SEQ_0, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/math_elementwise_unary_math_grad_1.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \\\n      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                 BINARY_MATH_BACKWARD_OP_SEQ_1, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/math_elementwise_unary_math_grad_2.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \\\n      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                 BINARY_MATH_BACKWARD_OP_SEQ_2, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/math_elementwise_unary_math_grad_3.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \\\n      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                 BINARY_MATH_BACKWARD_OP_SEQ_3, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/math_elementwise_unary_math_grad_complex.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\nnamespace broadcast_elementwise_binary {\n\n#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \\\n  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \\\n      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \\\n      Scalar attr0, Scalar attr1);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,\n                                 BINARY_MATH_BACKWARD_OP_SEQ_COMPLEX,\n                                 CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ);\n\n}  // namespace broadcast_elementwise_binary\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/memcpy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <cuda_runtime.h>\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\nclass MemcpyImpl : public Memcpy {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MemcpyImpl);\n  MemcpyImpl() = default;\n  ~MemcpyImpl() override = default;\n\n  void Launch(Stream* stream, void* dst, const void* src, size_t count) override {\n    if (dst == src) { return; }\n    auto* cuda_stream = stream->As<CudaStream>();\n    OF_CUDA_CHECK(cudaMemcpyAsync(dst, src, count, cudaMemcpyDefault, cuda_stream->cuda_stream()));\n  }\n};\n\nclass MemcpyFactoryImpl : public MemcpyFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MemcpyFactoryImpl);\n  MemcpyFactoryImpl() = default;\n  ~MemcpyFactoryImpl() override = default;\n\n  std::unique_ptr<Memcpy> New(MemcpyKind kind) override {\n    return std::unique_ptr<Memcpy>(new MemcpyImpl());\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MemcpyFactory, MemcpyFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/memset.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <cuda_runtime.h>\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\nclass MemsetImpl : public Memset {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MemsetImpl);\n  MemsetImpl() = default;\n  ~MemsetImpl() override = default;\n\n  void Launch(Stream* stream, void* ptr, int value, size_t count) override {\n    auto* cuda_stream = stream->As<CudaStream>();\n    OF_CUDA_CHECK(cudaMemsetAsync(ptr, value, count, cuda_stream->cuda_stream()));\n  }\n};\n\nclass MemsetFactoryImpl : public MemsetFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MemsetFactoryImpl);\n  MemsetFactoryImpl() = default;\n  ~MemsetFactoryImpl() override = default;\n\n  std::unique_ptr<Memset> New() override { return std::unique_ptr<Memset>(new MemsetImpl()); }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, MemsetFactory, MemsetFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/permute.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"oneflow/core/ep/common/primitive/permute_impl.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <cuda_runtime.h>\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace permute {\n\nnamespace internal {\n\nnamespace {\n\nconstexpr int32_t kMov4TileSize = 32;\nconstexpr int32_t kMov2TileSize = 64;\nconstexpr int32_t kBlockRows = 8;\n\ntemplate<size_t num_dims, size_t movement_size, typename IndexType>\n__global__ void PermuteKernel(PermuteKernelParams<num_dims, IndexType> params) {\n  using T = typename std::aligned_storage<movement_size, movement_size>::type;\n  const T* src = reinterpret_cast<const T*>(params.src);\n  T* dst = reinterpret_cast<T*>(params.dst);\n  IndexType src_index[num_dims];\n  IndexType dst_index[num_dims];\n  CUDA_1D_KERNEL_LOOP_T(IndexType, i, params.count) {\n    params.dst_index_helper.OffsetToNdIndex(i, dst_index);\n#pragma unroll\n    for (size_t dim = 0; dim < num_dims; ++dim) {\n      src_index[params.permutation[dim]] = dst_index[dim];\n    }\n    IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);\n    dst[i] = src[src_offset];\n  }\n}\n\n// (B, X, Y) -> (B, Y, X)\n// refer from https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/\ntemplate<size_t num_dims, size_t movement_size, size_t tile_size, typename IndexType>\n__global__ void BatchTransposeKernel(const void* src_ptr, void* dst_ptr, IndexType rows,\n                                     IndexType cols, IndexType num_tile_rows,\n                                     IndexType num_tile_cols, int32_t block_nums) {\n  const IndexType src_rows = rows;\n  const IndexType src_cols = cols;\n  const IndexType dst_rows = cols;\n  const IndexType dst_cols = rows;\n\n  using T = typename std::aligned_storage<movement_size, movement_size>::type;\n  __shared__ T tile[tile_size][tile_size + 1];  // To avoid bank conflict.\n\n  const T* src = reinterpret_cast<const T*>(src_ptr);\n  T* dst = reinterpret_cast<T*>(dst_ptr);\n\n  IndexType batch_num_tile = num_tile_rows * num_tile_cols;\n  for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) {\n    const IndexType batch_index = i / batch_num_tile;  // the index of batch.\n    const IndexType tile_index =\n        i - batch_index * batch_num_tile;  // equal to i % (num_tile_rows*num_tile_cols). the\n                                           // flatten index of tile in a batch.\n\n    const IndexType tile_row_index =\n        tile_index / num_tile_cols;  // the row index of tile in a batch.\n    const IndexType tile_col_index =\n        tile_index\n        - tile_row_index\n              * num_tile_cols;  // equal to k % num_tile_cols. 
the col index of tile in a batch.\n\n    const IndexType offset = batch_index * src_rows * src_cols;\n    {\n      IndexType col_in_tile = threadIdx.x;\n      IndexType col_in_matrix = tile_col_index * tile_size + threadIdx.x;\n#pragma unroll\n      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;\n           row_in_tile += kBlockRows) {\n        IndexType row_in_matrix = row_in_tile + tile_row_index * tile_size;\n        if (col_in_matrix < src_cols && row_in_matrix < src_rows) {\n          tile[row_in_tile][col_in_tile] = src[offset + row_in_matrix * src_cols + col_in_matrix];\n        }\n      }\n    }\n    __syncthreads();\n    {\n      IndexType col_in_tile = threadIdx.x;\n      IndexType col_in_matrix = tile_row_index * tile_size + threadIdx.x;\n#pragma unroll\n      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;\n           row_in_tile += kBlockRows) {\n        IndexType row_in_matrix = row_in_tile + tile_col_index * tile_size;\n        if (col_in_matrix < dst_cols && row_in_matrix < dst_rows) {\n          dst[offset + row_in_matrix * dst_cols + col_in_matrix] = tile[col_in_tile][row_in_tile];\n        }\n      }\n    }\n    __syncthreads();\n  }\n}\n\n/*\nHere is a Movementsie=2 version of Batch Transpose.\nWhen the H W can be divided by 2. we can read data use movementsize=4, and write back as\nmovementsize=4.\n*/\ntemplate<size_t num_dims, size_t tile_size, typename IndexType>\n__global__ void BatchTransposeMovement2Kernel(const void* src_ptr, void* dst_ptr, IndexType rows,\n                                              IndexType cols, IndexType num_tile_rows,\n                                              IndexType num_tile_cols, int32_t block_nums) {\n  const IndexType src_rows = rows;\n  const IndexType src_cols = cols;\n  const IndexType dst_rows = cols;\n  const IndexType dst_cols = rows;\n\n  static_assert(tile_size % 2 == 0, \"\");\n  using T_MOV2 = typename std::aligned_storage<2, 2>::type;\n  using T_MOV4 = typename std::aligned_storage<4, 4>::type;\n\n  const T_MOV4* src = reinterpret_cast<const T_MOV4*>(src_ptr);\n  T_MOV4* dst = reinterpret_cast<T_MOV4*>(dst_ptr);\n\n  // Use union structure to process Load and Store.\n  __shared__ union {\n    T_MOV2 tile_m2[tile_size][tile_size + 2];      // half [64][66]\n    T_MOV4 tile_m4[tile_size][tile_size / 2 + 1];  // half2 [64][33]\n  } tile_mem;\n\n  IndexType batch_num_tile = num_tile_rows * num_tile_cols;\n  for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) {\n    const IndexType batch_index = i / batch_num_tile;  // the index of batch.\n    const IndexType tile_index =\n        i - batch_index * batch_num_tile;  // equal to i % (num_tile_rows*num_tile_cols). the\n                                           // flatten index of tile in a batch.\n\n    const IndexType tile_row_index =\n        tile_index / num_tile_cols;  // the row index of tile in a batch.\n    const IndexType tile_col_index =\n        tile_index\n        - tile_row_index\n              * num_tile_cols;  // equal to k % num_tile_cols. 
the col index of tile in a batch.\n\n    const IndexType offset = batch_index * src_rows * src_cols;\n    {\n      IndexType col_in_tile = threadIdx.x;\n      IndexType col_in_matrix = tile_col_index * tile_size + threadIdx.x * 2;\n#pragma unroll\n      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;\n           row_in_tile += kBlockRows) {\n        IndexType row_in_matrix = row_in_tile + tile_row_index * tile_size;\n        if (col_in_matrix < src_cols && row_in_matrix < src_rows) {\n          tile_mem.tile_m4[row_in_tile][col_in_tile] =\n              src[(offset + row_in_matrix * src_cols + col_in_matrix) / 2];\n        }\n      }\n    }\n    __syncthreads();\n    {\n      IndexType col_in_tile = threadIdx.x;\n      IndexType col_in_matrix = tile_row_index * tile_size + threadIdx.x * 2;\n#pragma unroll\n      for (IndexType row_in_tile = threadIdx.y; row_in_tile < tile_size;\n           row_in_tile += kBlockRows) {\n        IndexType row_in_matrix = row_in_tile + tile_col_index * tile_size;\n        union {\n          T_MOV4 m4;\n          T_MOV2 m2[2];\n        } tmp_storage;\n\n        if (col_in_matrix < dst_cols && row_in_matrix < dst_rows) {\n          tmp_storage.m2[0] = tile_mem.tile_m2[col_in_tile * 2][row_in_tile];\n          tmp_storage.m2[1] = tile_mem.tile_m2[col_in_tile * 2 + 1][row_in_tile];\n          dst[(offset + row_in_matrix * dst_cols + col_in_matrix) / 2] = tmp_storage.m4;\n        }\n      }\n    }\n    __syncthreads();\n  }\n}\n\ntemplate<size_t num_dims, size_t movement_size, size_t tile_size, typename IndexType>\nvoid LaunchBatchTransposeKernel(cudaStream_t& cuda_stream,\n                                const PermuteKernelParams<num_dims, IndexType>& params,\n                                const IndexType& num_batches, const IndexType& rows,\n                                const IndexType& cols) {\n  IndexType num_tile_rows = (rows + tile_size - 1) / tile_size;\n  IndexType num_tile_cols = (cols + tile_size - 1) / tile_size;\n  const int32_t block_nums = num_batches * num_tile_rows * num_tile_cols;\n  int32_t launched_block_nums = std::min(block_nums, kCudaMaxBlocksNum);\n  if (tile_size == kMov2TileSize) {\n    const int32_t half2_thread = tile_size / 2;  // cause each thread process two half elements.\n    BatchTransposeMovement2Kernel<num_dims, kMov2TileSize, IndexType>\n        <<<launched_block_nums, dim3(half2_thread, kBlockRows), 0, cuda_stream>>>(\n            params.src, params.dst, rows, cols, num_tile_rows, num_tile_cols,\n            block_nums);  // Set threads num as 32x8 cause each threads\n                          // process 4 elements to 64x66 half share memory.\n  } else {\n    BatchTransposeKernel<num_dims, movement_size, tile_size, IndexType>\n        <<<launched_block_nums, dim3(tile_size, kBlockRows), 0, cuda_stream>>>(\n            params.src, params.dst, rows, cols, num_tile_rows, num_tile_cols, block_nums);\n  }\n}\n\ntemplate<size_t tile_size, typename IndexType>\nbool CheckIfGreaterEqualThanTileSize(const IndexType& rows, const IndexType& cols) {\n  if (rows < tile_size || cols < tile_size) { return false; }\n  return true;\n}\n\ntemplate<size_t num_dims, size_t tile_size, typename IndexType>\nbool CheckLaunchBatchTranspose(const int* permutation, const IndexType& num_batches,\n                               const IndexType& rows, const IndexType& cols) {\n  if (CheckIfGreaterEqualThanTileSize<tile_size, IndexType>(rows, cols)) {\n    if (num_batches == 1 && permutation[1] == 0 && permutation[0] == 1) {\n      // 
2d tensor case: (0, 1) -> (1, 0)\n      return true;\n    } else if (num_dims == 3 && permutation[2] == 1 && permutation[1] == 2) {\n      // 3d tensor case: (0, 1, 2) -> (0, 2, 1)\n      return true;\n    } else {\n      return false;\n    }\n  }\n  return false;\n}\n\ntemplate<typename IndexType, size_t movement_size>\nbool CheckUseMov2(const IndexType& rows, const IndexType& cols, const void* src, void* dst) {\n  auto src_ptr = reinterpret_cast<std::uintptr_t>(src);\n  auto dst_ptr = reinterpret_cast<std::uintptr_t>(dst);\n  return (movement_size == 2) && (rows % 2 == 0) && (cols % 2 == 0) && (src_ptr % 4 == 0)\n         && (dst_ptr % 4 == 0);\n}\n\ntemplate<size_t num_dims, typename IndexType>\nvoid InferBatchTransposeShape(const int64_t* src_dims, IndexType* num_batches, IndexType* rows,\n                              IndexType* cols) {\n  if (num_dims == 2) {\n    *num_batches = 1;\n    *rows = src_dims[0];\n    *cols = src_dims[1];\n  } else {\n    *num_batches = src_dims[0];\n    *rows = src_dims[1];\n    *cols = src_dims[2];\n  }\n}\n\ntemplate<size_t num_dims, size_t movement_size, typename IndexType>\nvoid LaunchKernel(Stream* stream, const int64_t* src_dims, const void* src, const int* permutation,\n                  void* dst, size_t count) {\n  PermuteKernelParams<num_dims, IndexType> params =\n      MakePermuteParams<num_dims, IndexType>(src_dims, src, permutation, dst, count);\n  cudaStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();\n\n  if (num_dims == 2 || num_dims == 3) {\n    IndexType num_batches;\n    IndexType rows;\n    IndexType cols;\n    InferBatchTransposeShape<num_dims, IndexType>(src_dims, &num_batches, &rows, &cols);\n    if (CheckLaunchBatchTranspose<num_dims, kMov4TileSize>(params.permutation, num_batches, rows,\n                                                           cols)) {\n      if (CheckUseMov2<IndexType, movement_size>(rows, cols, src, dst)) {\n        LaunchBatchTransposeKernel<num_dims, 2, kMov2TileSize, IndexType>(cuda_stream, params,\n                                                                          num_batches, rows, cols);\n      } else {\n        LaunchBatchTransposeKernel<num_dims, movement_size, kMov4TileSize, IndexType>(\n            cuda_stream, params, num_batches, rows, cols);\n      }\n    } else {\n      if (params.count == 0) { return; }\n      PermuteKernel<num_dims, movement_size, IndexType>\n          <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);\n    }\n  } else {\n    if (params.count == 0) { return; }\n    PermuteKernel<num_dims, movement_size, IndexType>\n        <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);\n  }\n}\n\nclass PermuteImpl : public Permute {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PermuteImpl);\n  PermuteImpl() = default;\n  ~PermuteImpl() override = default;\n\n  using Permute::Launch;\n  void Launch(Stream* stream, DataType data_type, size_t num_dims, const int64_t* src_dims,\n              const void* src, const int* permutation, void* dst) override {\n    SimplifyThenLaunch(stream, data_type, num_dims, src_dims, src, permutation, dst);\n  }\n};\n\nclass PermuteFactoryImpl : public PermuteFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PermuteFactoryImpl);\n  PermuteFactoryImpl() = default;\n  ~PermuteFactoryImpl() override = default;\n\n  std::unique_ptr<Permute> New(size_t max_num_dims) override {\n    if (max_num_dims <= kMaxNumDims) {\n      return std::unique_ptr<Permute>(new PermuteImpl());\n 
   } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, PermuteFactory, PermuteFactoryImpl);\n\n}  // namespace\n\n}  // namespace internal\n\n}  // namespace permute\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/softmax.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/softmax.h\"\n#include \"oneflow/core/ep/include/primitive/log_softmax.h\"\n#include \"oneflow/core/ep/cuda/primitive/type_seq.h\"\n#include \"oneflow/core/cuda/softmax.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\nenum class Algorithm {\n  kSoftmax,\n  kLogSoftmax,\n};\n\ntemplate<Algorithm algorithm, typename T>\nvoid SoftmaxGpu(cudaStream_t cuda_stream, size_t rows, size_t cols, const T* x, T* y) {\n  using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;\n  oneflow::cuda::softmax::DirectLoad<T, ComputeType> load(x, cols);\n  oneflow::cuda::softmax::DirectStore<ComputeType, T> store(y, cols);\n  if (algorithm == Algorithm::kSoftmax) {\n    OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(\n        cuda_stream, load, store, rows, cols)));\n  } else if (algorithm == Algorithm::kLogSoftmax) {\n    OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmax<decltype(load), decltype(store), ComputeType>(\n        cuda_stream, load, store, rows, cols)));\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename SoftmaxBase, Algorithm algorithm, typename T>\nclass SoftmaxImpl : public SoftmaxBase {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SoftmaxImpl);\n  SoftmaxImpl() = default;\n  ~SoftmaxImpl() override = default;\n\n  void Launch(Stream* stream, size_t rows, size_t cols, const void* x, void* y) override {\n    cudaStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();\n    SoftmaxGpu<algorithm, T>(cuda_stream, rows, cols, reinterpret_cast<const T*>(x),\n                             reinterpret_cast<T*>(y));\n  }\n};\n\ntemplate<typename SoftmaxBase, Algorithm algorithm, typename T>\nstd::unique_ptr<SoftmaxBase> NewSoftmax() {\n  return std::unique_ptr<SoftmaxBase>(new SoftmaxImpl<SoftmaxBase, algorithm, T>());\n}\n\ntemplate<typename FactoryBase, typename SoftmaxBase, Algorithm algorithm>\nclass GenericSoftmaxFactoryImpl : public FactoryBase {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxFactoryImpl);\n  GenericSoftmaxFactoryImpl() = default;\n  ~GenericSoftmaxFactoryImpl() override = default;\n\n  std::unique_ptr<SoftmaxBase> New(DataType data_type) override {\n#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \\\n  {type_proto, NewSoftmax<SoftmaxBase, algorithm, type_cpp>},\n\n    static const std::map<DataType, std::function<std::unique_ptr<SoftmaxBase>()>>\n        new_softmax_handle{\n            OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};\n\n#undef MAKE_NEW_SOFTMAX_ENTRY\n\n    const auto it = new_softmax_handle.find(data_type);\n    if (it != new_softmax_handle.end()) {\n      return it->second();\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nusing SoftmaxFactoryImpl = GenericSoftmaxFactoryImpl<SoftmaxFactory, Softmax, 
Algorithm::kSoftmax>;\nusing LogSoftmaxFactoryImpl =\n    GenericSoftmaxFactoryImpl<LogSoftmaxFactory, LogSoftmax, Algorithm::kLogSoftmax>;\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, SoftmaxFactory, SoftmaxFactoryImpl);\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, LogSoftmaxFactory, LogSoftmaxFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/softmax_backward.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/softmax_backward.h\"\n#include \"oneflow/core/ep/include/primitive/log_softmax_backward.h\"\n#include \"oneflow/core/ep/cuda/primitive/type_seq.h\"\n#include \"oneflow/core/cuda/softmax.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\nenum class Algorithm {\n  kSoftmax,\n  kLogSoftmax,\n};\n\ntemplate<Algorithm algorithm, typename T>\nvoid SoftmaxBackwardGpu(cudaStream_t cuda_stream, size_t rows, size_t cols, const T* y, const T* dy,\n                        T* dx) {\n  using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;\n  cuda::softmax::DirectLoad<T, ComputeType> load_y(y, cols);\n  cuda::softmax::DirectLoad<T, ComputeType> load_dy(dy, cols);\n  cuda::softmax::DirectStore<ComputeType, T> store(dx, cols);\n  if (algorithm == Algorithm::kSoftmax) {\n    OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad<decltype(load_y), decltype(load_dy),\n                                                      decltype(store), ComputeType>(\n        cuda_stream, load_y, load_dy, store, rows, cols)));\n  } else if (algorithm == Algorithm::kLogSoftmax) {\n    OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmaxGrad<decltype(load_y), decltype(load_dy),\n                                                         decltype(store), ComputeType>(\n        cuda_stream, load_y, load_dy, store, rows, cols)));\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename SoftmaxBackwardBase, Algorithm algorithm, typename T>\nclass SoftmaxBackwardImpl : public SoftmaxBackwardBase {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SoftmaxBackwardImpl);\n  SoftmaxBackwardImpl() = default;\n  ~SoftmaxBackwardImpl() override = default;\n\n  void Launch(Stream* stream, size_t rows, size_t cols, const void* y, const void* dy,\n              void* dx) override {\n    cudaStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();\n    SoftmaxBackwardGpu<algorithm, T>(cuda_stream, rows, cols, reinterpret_cast<const T*>(y),\n                                     reinterpret_cast<const T*>(dy), reinterpret_cast<T*>(dx));\n  }\n};\n\ntemplate<typename SoftmaxBackwardBase, Algorithm algorithm, typename T>\nstd::unique_ptr<SoftmaxBackwardBase> NewSoftmaxBackward() {\n  return std::unique_ptr<SoftmaxBackwardBase>(\n      new SoftmaxBackwardImpl<SoftmaxBackwardBase, algorithm, T>());\n}\n\ntemplate<typename BackwardFactoryBase, typename SoftmaxBackwardBase, Algorithm algorithm>\nclass GenericSoftmaxBackwardFactoryImpl : public BackwardFactoryBase {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(GenericSoftmaxBackwardFactoryImpl);\n  GenericSoftmaxBackwardFactoryImpl() = default;\n  ~GenericSoftmaxBackwardFactoryImpl() override = default;\n\n  std::unique_ptr<SoftmaxBackwardBase> New(DataType data_type) override {\n#define MAKE_NEW_SOFTMAX_ENTRY(type_cpp, type_proto) \\\n  {type_proto, 
NewSoftmaxBackward<SoftmaxBackwardBase, algorithm, type_cpp>},\n\n    static const std::map<DataType, std::function<std::unique_ptr<SoftmaxBackwardBase>()>>\n        new_softmax_backward_handle{\n            OF_PP_FOR_EACH_TUPLE(MAKE_NEW_SOFTMAX_ENTRY, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};\n\n#undef MAKE_NEW_SOFTMAX_ENTRY\n\n    const auto it = new_softmax_backward_handle.find(data_type);\n    if (it != new_softmax_backward_handle.end()) {\n      return it->second();\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nusing SoftmaxBackwardFactoryImpl =\n    GenericSoftmaxBackwardFactoryImpl<SoftmaxBackwardFactory, SoftmaxBackward, Algorithm::kSoftmax>;\nusing LogSoftmaxBackwardFactoryImpl =\n    GenericSoftmaxBackwardFactoryImpl<LogSoftmaxBackwardFactory, LogSoftmaxBackward,\n                                      Algorithm::kLogSoftmax>;\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, SoftmaxBackwardFactory, SoftmaxBackwardFactoryImpl);\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, LogSoftmaxBackwardFactory,\n                           LogSoftmaxBackwardFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/tensor_fill.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/tensor_fill.h\"\n#include \"oneflow/core/ep/cuda/primitive/type_seq.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\ntemplate<size_t size>\nusing Storage = typename std::aligned_storage<size, size>::type;\n\ntemplate<typename T, size_t pack>\nunion Pack {\n  static constexpr size_t size = sizeof(T) * pack;\n  explicit __device__ __host__ Pack(const T value) {\n    static_assert(sizeof(Pack) == size, \"\");\n    static_assert(alignof(Pack) == size, \"\");\n#pragma unroll\n    for (size_t i = 0; i < pack; ++i) { elem[i] = value; }\n  }\n  T elem[pack];\n  Storage<size> storage;\n};\n\ntemplate<typename T, size_t pack>\n__global__ void TensorFillGpu(T* dst, const T* value, size_t count) {\n  const size_t pack_count = count / pack;\n  const T fill_value = value[0];\n  Pack<T, pack> pack_value(fill_value);\n  auto* pack_dst = reinterpret_cast<decltype(pack_value.storage)*>(dst);\n  CUDA_1D_KERNEL_LOOP_T(size_t, i, pack_count) { pack_dst[i] = pack_value.storage; }\n  T* tail_dst = dst + pack_count * pack;\n  const size_t tail_count = count - pack_count * pack;\n  CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = fill_value; }\n}\n\ntemplate<typename T, size_t pack>\ntypename std::enable_if<(pack != 0), void>::type LaunchPackTensorFill(cudaStream_t stream, T* dst,\n                                                                      const T* value,\n                                                                      size_t count) {\n  TensorFillGpu<T, pack>\n      <<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0, stream>>>(dst, value, count);\n}\n\ntemplate<typename T, size_t pack>\ntypename std::enable_if<(pack == 0), void>::type LaunchPackTensorFill(cudaStream_t stream, T* dst,\n                                                                      const T* value,\n                                                                      size_t count) {\n  LOG(FATAL) << \"wrong alignment\";\n}\n\ntemplate<typename T>\nvoid LaunchTensorFill(cudaStream_t stream, T* dst, const T* value, size_t count) {\n  auto uintptr = reinterpret_cast<std::uintptr_t>(dst);\n  if (uintptr % 16 == 0) {\n    LaunchPackTensorFill<T, 16 / sizeof(T)>(stream, dst, value, count);\n  } else if (uintptr % 8 == 0) {\n    LaunchPackTensorFill<T, 8 / sizeof(T)>(stream, dst, value, count);\n  } else if (uintptr % 4 == 0) {\n    LaunchPackTensorFill<T, 4 / sizeof(T)>(stream, dst, value, count);\n  } else if (uintptr % 2 == 0) {\n    LaunchPackTensorFill<T, 2 / sizeof(T)>(stream, dst, value, count);\n  } else {\n    LaunchPackTensorFill<T, 1 / sizeof(T)>(stream, dst, value, count);\n  }\n}\n\ntemplate<typename T>\nclass TensorFillImpl : public TensorFill {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TensorFillImpl);\n  TensorFillImpl() = default;\n  ~TensorFillImpl() 
override = default;\n\n  void Launch(Stream* stream, const void* src, void* dst, size_t count) override {\n    cudaStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();\n    const T* value = reinterpret_cast<const T*>(src);\n    LaunchTensorFill<T>(cuda_stream, reinterpret_cast<T*>(dst), value, count);\n  }\n};\n\ntemplate<typename T>\nstd::unique_ptr<TensorFill> NewTensorFill() {\n  return std::unique_ptr<TensorFill>(new TensorFillImpl<T>());\n}\n\nclass TensorFillFactoryImpl : public TensorFillFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TensorFillFactoryImpl);\n  TensorFillFactoryImpl() = default;\n  ~TensorFillFactoryImpl() override = default;\n\n  std::unique_ptr<TensorFill> New(DataType data_type) override {\n#define MAKE_NEW_TENSOR_FILL_ENTRY(type_cpp, type_proto) {type_proto, NewTensorFill<type_cpp>},\n\n    static const std::map<DataType, std::function<std::unique_ptr<TensorFill>()>> new_fill_handle{\n        OF_PP_FOR_EACH_TUPLE(MAKE_NEW_TENSOR_FILL_ENTRY, CUDA_PRIMITIVE_REAL_TYPE_SEQ)};\n\n#undef MAKE_NEW_TENSOR_FILL_ENTRY\n\n    const auto it = new_fill_handle.find(data_type);\n    if (it != new_fill_handle.end()) {\n      return it->second();\n    } else {\n      return nullptr;\n    }\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, TensorFillFactory, TensorFillFactoryImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/type_seq.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_\n#define ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_\n\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/common/data_type.h\"\n\n#ifdef WITH_CUDA\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuComplex.h>\n\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\n\n#define CUDA_PRIMITIVE_BOOL_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool)\n#define CUDA_PRIMITIVE_CHAR_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar)\n#define CUDA_PRIMITIVE_INT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8)\n#define CUDA_PRIMITIVE_UINT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8)\n#define CUDA_PRIMITIVE_INT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int16_t, DataType::kInt16)\n#define CUDA_PRIMITIVE_INT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)\n#define CUDA_PRIMITIVE_UINT32_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32)\n#define CUDA_PRIMITIVE_INT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)\n#define CUDA_PRIMITIVE_UINT64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64)\n#define CUDA_PRIMITIVE_FLOAT_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)\n#define CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)\n#define CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)\n#define CUDA_PRIMITIVE_COMPLEX64_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(cuComplex, DataType::kComplex64)\n#define CUDA_PRIMITIVE_COMPLEX128_TYPE_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(cuDoubleComplex, DataType::kComplex128)\n\n#if CUDA_VERSION >= 11000\n#define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16)\n#else\n#define CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ\n#endif  // CUDA_VERSION >= 11000\n\n#define CUDA_PRIMITIVE_REAL_TYPE_SEQ \\\n  CUDA_PRIMITIVE_BOOL_TYPE_SEQ       \\\n  CUDA_PRIMITIVE_CHAR_TYPE_SEQ       \\\n  CUDA_PRIMITIVE_INT8_TYPE_SEQ       \\\n  CUDA_PRIMITIVE_UINT8_TYPE_SEQ      \\\n  CUDA_PRIMITIVE_INT32_TYPE_SEQ      \\\n  CUDA_PRIMITIVE_INT64_TYPE_SEQ      \\\n  CUDA_PRIMITIVE_FLOAT_TYPE_SEQ      \\\n  CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ     \\\n  CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ    \\\n  CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ\n\n#define CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ \\\n  CUDA_PRIMITIVE_COMPLEX64_TYPE_SEQ     \\\n  CUDA_PRIMITIVE_COMPLEX128_TYPE_SEQ\n\n#define CUDA_PRIMITIVE_FLOATING_TYPE_SEQ \\\n  CUDA_PRIMITIVE_FLOAT_TYPE_SEQ          \\\n  CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ         \\\n  CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ        \\\n  CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ\n\n#define CUDA_PRIMITIVE_INT_TYPE_SEQ \\\n  CUDA_PRIMITIVE_UINT8_TYPE_SEQ     \\\n  CUDA_PRIMITIVE_INT8_TYPE_SEQ      \\\n  CUDA_PRIMITIVE_INT32_TYPE_SEQ     \\\n  CUDA_PRIMITIVE_INT64_TYPE_SEQ\n\n#define UTIL_OPS_DATA_TYPE_SEQ    \\\n  CUDA_PRIMITIVE_INT8_TYPE_SEQ    \\\n  CUDA_PRIMITIVE_UINT8_TYPE_SEQ   
\\\n  CUDA_PRIMITIVE_INT32_TYPE_SEQ   \\\n  CUDA_PRIMITIVE_INT64_TYPE_SEQ   \\\n  CUDA_PRIMITIVE_FLOAT_TYPE_SEQ   \\\n  CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ  \\\n  CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \\\n  CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/unary_functor.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_EP_CUDA_PRIMITIVE_UNARY_FUNCTOR_CUH\n#define ONEFLOW_CORE_EP_CUDA_PRIMITIVE_UNARY_FUNCTOR_CUH\n#include \"oneflow/core/ep/common/primitive/unary_functor.h\"\n#include \"oneflow/core/ep/cuda/primitive/type_seq.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <cuda.h>\n#include \"oneflow/core/common/math_util.h\"\n\nnamespace oneflow {\nnamespace ep {\nnamespace primitive {\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kGelu, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Src>(0.5) * src\n           * (static_cast<Src>(1.0) + erf(static_cast<Src>(M_SQRT1_2) * src));\n  }\n};\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    // ref to: https://mlfromscratch.com/activation-functions-explained/#gelu\n    const Src half = static_cast<Src>(0.5);\n    const Src one = static_cast<Src>(1);\n    const Src tanh_in = alpha * (src + beta * src * src * src);\n    return half * src * (one + tanh(tanh_in));\n  }\n\n private:\n  // constant ref to:\n  // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/testdata/transform/fusion/fast_gelu.py\n  static constexpr Src alpha = static_cast<Src>(0.7978845608028654);\n  static constexpr Src beta = static_cast<Src>(0.044714998453855515);\n};\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kQuickGelu, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    const Src sigmoid =\n        static_cast<Dst>(static_cast<Src>(1.0) / (static_cast<Src>(1.0) + exp(-src * alpha)));\n    return src * sigmoid;\n  }\n\n private:\n  static constexpr Src alpha = static_cast<Src>(1.702);\n};\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kSquareReLU, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const {\n    return static_cast<Dst>((src > static_cast<Src>(0.0)) ? 
src * src : 0);\n  }\n};\n\nnamespace unary_functor_internal {\n\nnamespace {\n\nOF_DEVICE_FUNC\nfloat TanhApprox(float x) {\n#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)\n  float r;\n  asm(\"tanh.approx.f32 %0,%1; \\n\\t\" : \"=f\"(r) : \"f\"(x));\n  return r;\n#else\n  return tanhf(x);\n#endif  // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)\n}\n\n}  // namespace\n\n}  // namespace unary_functor_internal\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}\n\n  OF_DEVICE_FUNC half operator()(half src) const {\n#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)\n    const float tanh_in =\n        __half2float(__float2half_rn(alpha) * (src + __float2half_rn(beta) * src * src * src));\n    const float tanh_out = unary_functor_internal::TanhApprox(tanh_in);\n    return __float2half_rn(0.5F) * src * (__float2half_rn(1.0F) + __float2half_rn(tanh_out));\n#else\n    return static_cast<half>(float_functor(static_cast<float>(src)));\n#endif  // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)\n  }\n\n#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)\n  __device__ void Apply2(half* dst, const half* src) const {\n    const half2 src2 = *(reinterpret_cast<const half2*>(src));\n    const float2 tanh_in = __half22float2(__hmul2(\n        __float2half2_rn(alpha),\n        __hadd2(src2, __hmul2(__hmul2(__hmul2(__float2half2_rn(beta), src2), src2), src2))));\n    float2 tanh_out;\n    tanh_out.x = unary_functor_internal::TanhApprox(tanh_in.x);\n    tanh_out.y = unary_functor_internal::TanhApprox(tanh_in.y);\n    const half2 dst2 = __hmul2(__hmul2(__float2half2_rn(0.5F), src2),\n                               __hadd2(__float2half2_rn(1.0F), __float22half2_rn(tanh_out)));\n    *reinterpret_cast<half2*>(dst) = dst2;\n  }\n#endif  // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)\n\n private:\n  static constexpr float alpha = 0.7978845608028654F;\n  static constexpr float beta = 0.044714998453855515F;\n  UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, float, float> float_functor;\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTanh, float, float> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC float operator()(float src) const { return tanhf(src); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTanh, double, double> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC double operator()(double src) const { return tanh(src); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTanh, half, half> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC half operator()(half src) const { return __float2half(tanhf(__half2float(src))); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, half> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(half src) const { return isinf(__half2float(src)); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, float> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(float src) const { return isinf(src); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, double> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(double src) const { return isinf(src); 
}\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, half> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(half src) const { return isnan(__half2float(src)); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, float> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(float src) const { return isnan(src); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, double> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(double src) const { return isnan(src); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsFinite, bool, half> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(half src) const { return isfinite(__half2float(src)); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsFinite, bool, float> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(float src) const { return isfinite(src); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsFinite, bool, double> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(double src) const { return isfinite(src); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTrunc, half, half> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n  __device__ half operator()(half src) const { return htrunc(src); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTrunc, float, float> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC float operator()(float src) const { return truncf(src); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTrunc, double, double> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC double operator()(double src) const { return trunc(src); }\n};\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kDigamma, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src in) const {\n    // references\n    // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/Math.cuh#L3029-L3090\n    static const double PI_f64 = 3.14159265358979323846;\n    const Src PSI_10 = 2.25175258906672110764;\n    const Src A[] = {\n        8.33333333333333333333E-2,  -2.10927960927960927961E-2, 7.57575757575757575758E-3,\n        -4.16666666666666666667E-3, 3.96825396825396825397E-3,  -8.33333333333333333333E-3,\n        8.33333333333333333333E-2,\n    };\n\n    Src x = static_cast<Src>(in);\n    if (x == static_cast<Src>(0)) {\n      // As per C++ standard for gamma related functions and SciPy,\n      // If the argument is ±0, ±∞ is returned\n      return std::copysign(static_cast<Src>(INFINITY), -x);\n    }\n\n    bool x_is_integer = x == trunc(x);\n    Src result = static_cast<Src>(0);\n    if (x < 0) {\n      if (x_is_integer) {\n        // As per C++ standard for gamma related functions and SciPy,\n        // If the argument is a negative integer, NaN is returned\n        return static_cast<Src>(NAN);\n      }\n      // Extracts the fractional part of x as r, since tan(pi * r) is more numerically\n      // accurate than tan(pi * x). 
While these operations are mathematically equivalent\n      // since both x and r are in radians and tan() has a periodicity of pi, in practice\n      // the computation of pi * x is a source of error (when |x| > 1).\n      double q, r;\n      r = modf(static_cast<double>(x), &q);\n      result = static_cast<Src>(-PI_f64 / tan(PI_f64 * r));\n      x = static_cast<Src>(1) - x;\n    }\n\n    while (x < 10) {\n      result -= static_cast<Src>(1) / x;\n      x += 1;\n    }\n    if (x == static_cast<Src>(10)) { return static_cast<Src>(result + PSI_10); }\n\n    Src y = 0;\n    if (x < 1.0e17) {\n      Src z = static_cast<Src>(1) / (x * x);\n\n      Src polevl_result = 0;\n      for (int i = 0; i <= 6; i++) { polevl_result = polevl_result * z + A[i]; }\n      y = z * polevl_result;\n    }\n\n    return static_cast<Src>(log(x) - (static_cast<Src>(0.5) / x) - y + result);\n  }\n};\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTrigamma, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src x) const {\n    // references\n    // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/Math.cuh#L387-L410\n    const Src PI{3.14159265358979323846};\n    Src sign = 1;\n    Src result = 0;\n\n    if (x < Src{0.5}) {\n      sign = -1;\n      Src sin_pi_x = sin(PI * x);\n      result -= (PI * PI) / (sin_pi_x * sin_pi_x);\n      x = 1 - x;\n    }\n\n    for (int i = 0; i < 6; ++i) {\n      result += Src{1} / (x * x);\n      x += 1;\n    }\n\n    const Src one{1};\n    const Src ixx = one / (x * x);\n    result += (one + one / (Src{2} * x)\n               + ixx * (one / Src{6} - ixx * (one / Src{30} - ixx * (one / Src{42}))))\n              / x;\n    return sign * result;\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kAbs, half, half> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  __device__ half operator()(half src) const {\n    return __hlt(src, static_cast<half>(0)) ? __hneg(src) : src;\n  }\n};\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kNanAssign, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return isnan(src) ? 
static_cast<Dst>(0.0) : src; }\n};\n\n#if CUDA_VERSION >= 11000\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kAbs, nv_bfloat16, nv_bfloat16> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  __device__ nv_bfloat16 operator()(nv_bfloat16 src) const {\n#if __CUDA_ARCH__ >= 800\n    return __habs(src);\n#else\n    return __float2bfloat16(abs(__bfloat162float(src)));\n#endif  // __CUDA_ARCH__ >= 800\n  }\n};\n#endif  // CUDA_VERSION >= 11000\n\n/*********half dtype support*********/\ntemplate<typename Dst>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, Dst, half> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(half src) const { return static_cast<Dst>(__half2float(src)); }\n};\n\ntemplate<typename Src>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, half, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC half operator()(Src src) const { return __float2half(static_cast<float>(src)); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, half, half> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC half operator()(half src) const { return src; }\n};\n\n/*********nv_bfloat16 dtype support*********/\n#if CUDA_VERSION >= 11000\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, nv_bfloat16, half> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC nv_bfloat16 operator()(half src) const {\n    return __float2bfloat16(__half2float(src));\n  }\n};\n\ntemplate<typename Dst>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, Dst, nv_bfloat16> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(nv_bfloat16 src) const {\n    return static_cast<Dst>(__bfloat162float(src));\n  }\n};\n\ntemplate<typename Src>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, nv_bfloat16, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC nv_bfloat16 operator()(Src src) const {\n    return __float2bfloat16(static_cast<float>(src));\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, half, nv_bfloat16> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC half operator()(nv_bfloat16 src) const {\n    return __float2half(__bfloat162float(src));\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, nv_bfloat16, nv_bfloat16> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src) const { return src; }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, cuComplex, nv_bfloat16> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuComplex operator()(nv_bfloat16 src) const {\n    return make_cuComplex((__bfloat162float(src)), 0.0);\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, cuDoubleComplex, nv_bfloat16> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuDoubleComplex operator()(nv_bfloat16 src) const {\n    return make_cuDoubleComplex(static_cast<double>(__bfloat162float(src)), 0.0);\n  }\n};\n\n#endif  // CUDA_VERSION >= 11000\n\n#define SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(op)                                         \\\n  template<>                                                                                 \\\n  struct 
UnaryFunctor<DeviceType::kCUDA, op, half, half> {                                   \\\n    OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \\\n                                                                                             \\\n    UnaryFunctor<DeviceType::kCUDA, op, float, float> float_functor;                         \\\n    OF_DEVICE_FUNC half operator()(half src) const {                                         \\\n      return __float2half(float_functor(__half2float(src)));                                 \\\n    }                                                                                        \\\n  };\n\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kElu);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCelu);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kGelu);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kMish);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSelu);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSilu);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSoftSign);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSoftPlus);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kAcos);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kAcosh);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kAsin);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kAsinh);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kAtan);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kAtanh);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCeil);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCos);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCosh);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kDigamma);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kTrigamma);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kErf);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kErfc);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kExp);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kExp2);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kExpm1);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kFloor);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kLgamma);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kLog);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kLog2);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kLog10);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kLog1p);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kLogSigmoid);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kRint);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kRound);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kRsqrt);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSigmoid);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSin);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSinh);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSqrt);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSquare);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kTan);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kReciprocalNoNan);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kNotEqualZero);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kNanAssign);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kQuickGelu);\nSPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSquareReLU);\n\n/*********nv_bfloat16_kernel*******/\n\n#if CUDA_VERSION >= 11000\n\n#define SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(op)                              
       \\\n  template<>                                                                                 \\\n  struct UnaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> {                     \\\n    OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \\\n                                                                                             \\\n    UnaryFunctor<DeviceType::kCUDA, op, float, float> float_functor;                         \\\n    OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src) const {                           \\\n      return __float2bfloat16(float_functor(__bfloat162float(src)));                         \\\n    }                                                                                        \\\n  };\n\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kElu);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCelu);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kGelu);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSwish);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardSigmoid);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardShrink);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kHardTanh);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLeakyRelu);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kMish);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSelu);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSilu);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftShrink);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftSign);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftPlus);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTanh);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kThreshold);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAcos);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAcosh);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAsin);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAsinh);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAtan);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kAtanh);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCeil);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCos);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kCosh);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kErf);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kErfc);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kExp);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kExp2);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kExpm1);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFloor);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLgamma);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog2);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog10);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLog1p);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kLogSigmoid);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kRint);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kRound);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kRsqrt);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSigmoid);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSin);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::
kSinh);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSqrt);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSquare);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTan);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kReciprocalNoNan);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNotEqualZero);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNanAssign);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFastGelu);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kQuickGelu);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSquareReLU);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kDigamma);\nSPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTrigamma);\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, nv_bfloat16> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isinf(__bfloat162float(src)); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, nv_bfloat16> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isnan(__bfloat162float(src)); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsFinite, bool, nv_bfloat16> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isfinite(__bfloat162float(src)); }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTrunc, nv_bfloat16, nv_bfloat16> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n  __device__ nv_bfloat16 operator()(nv_bfloat16 src) const {\n#if __CUDA_ARCH__ >= 800\n    return htrunc(src);\n#else\n    return __float2bfloat16(truncf(__bfloat162float(src)));\n#endif  // __CUDA_ARCH__ >= 800\n  }\n};\n\n#endif  // CUDA_VERSION >= 11000\n\n/*********float complex dtype support*********/\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kConj, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{src.x, -src.y}; }\n};\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kReal, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(src.x); }\n};\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kImag, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(src.y); }\n};\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kRealGrad, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{src, 0.0}; }\n};\n\ntemplate<typename Dst, typename Src>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kImagGrad, Dst, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{0.0, src}; }\n};\n\n// avoid warning: narrowing conversion\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kRealGrad, cuComplex, double> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuComplex operator()(double src) const {\n    return 
cuComplex{static_cast<float>(src), 0.0f};\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kImagGrad, cuComplex, double> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuComplex operator()(double src) const {\n    return cuComplex{0.0f, static_cast<float>(src)};\n  }\n};\n\ntemplate<typename Src>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, cuComplex, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuComplex operator()(Src src) const {\n    return make_cuComplex(static_cast<float>(src), 0.0);\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, cuComplex, cuComplex> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuComplex operator()(cuComplex src) const { return src; }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, cuComplex, cuDoubleComplex> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuComplex operator()(cuDoubleComplex src) const {\n    return cuComplexDoubleToFloat(src);\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, cuComplex, half> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuComplex operator()(half src) const {\n    return make_cuComplex((__half2float(src)), 0.0);\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIdentity, cuComplex, cuComplex> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuComplex operator()(cuComplex src) const { return src; }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kConj, cuComplex, cuComplex> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC cuComplex operator()(cuComplex src) const { return cuComplex{src.x, -src.y}; }\n};\n\n// reference: thrust: `thrust/detail/complex/csqrtf.h:csqrtf`\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kSqrt, cuComplex, cuComplex> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuComplex operator()(cuComplex src) const {\n    float a = src.x, b = src.y;\n    float t = 0.0f;\n    int scale = 1;\n    cuComplex result;\n\n    /* We risk spurious overflow for components >= FLT_MAX / (1 + sqrt(2)). */\n    const float THRESH = 1.40949553037932e+38f;\n\n    /* Handle special cases. */\n    if (src.x == 0.0f && src.y == float()) {\n      // return (complex<float>(0, b));\n      return (cuComplex{0.0f, b});\n    }\n\n    // FLT_MIN*2\n    const float low_thresh = 2.35098870164458e-38f;\n    scale = 0;\n\n    if (fabsf(a) >= THRESH || fabsf(b) >= THRESH) {\n      /* Scale to avoid overflow. */\n      a *= 0.25f;\n      b *= 0.25f;\n      scale = 1;\n    } else if (fabsf(a) <= low_thresh && fabsf(b) <= low_thresh) {\n      /* Scale to avoid underflow. */\n      a *= 4.f;\n      b *= 4.f;\n      scale = 2;\n    }\n\n    /* Algorithm 312, CACM vol 10, Oct 1967. */\n    if (a >= 0.0f) {\n      t = sqrtf((a + hypotf(a, b)) * 0.5f);\n      // result = complex<float>(t, b / (2.0f * t));\n      result.x = t;\n      result.y = b / (2.0f * t);\n    } else {\n      t = sqrtf((-a + hypotf(a, b)) * 0.5f);\n      // result = complex<float>(fabsf(b) / (2.0f * t), copysignf(t, b));\n      result.x = fabsf(b) / (2.0f * t);\n      result.y = copysignf(t, b);\n    }\n\n    /* Rescale. 
*/\n    if (scale == 1) {\n      // return (result * 2.0f);\n      result.x *= 2.0f;\n      result.y *= 2.0f;\n    } else if (scale == 2) {\n      // return (result * 0.5f);\n      result.x *= 0.5f;\n      result.y *= 0.5f;\n    }\n\n    return (result);\n  }\n};\n\n/*********double complex dtype support*********/\ntemplate<typename Src>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, cuDoubleComplex, Src> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuDoubleComplex operator()(Src src) const {\n    return make_cuDoubleComplex(static_cast<double>(src), 0.0);\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, cuDoubleComplex, cuDoubleComplex> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuDoubleComplex operator()(cuDoubleComplex src) const { return src; }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, cuDoubleComplex, cuComplex> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuDoubleComplex operator()(cuComplex src) const {\n    return cuComplexFloatToDouble(src);\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, cuDoubleComplex, half> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuDoubleComplex operator()(half src) const {\n    return make_cuDoubleComplex(static_cast<double>(__half2float(src)), 0.0);\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIdentity, cuDoubleComplex, cuDoubleComplex> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuDoubleComplex operator()(cuDoubleComplex src) const { return src; }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kConj, cuDoubleComplex, cuDoubleComplex> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n  OF_DEVICE_FUNC cuDoubleComplex operator()(cuDoubleComplex src) const {\n    return cuDoubleComplex{src.x, -src.y};\n  }\n};\n\ntemplate<>\nstruct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kSqrt, cuDoubleComplex, cuDoubleComplex> {\n  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}\n\n  OF_DEVICE_FUNC cuDoubleComplex operator()(cuDoubleComplex src) const {\n    double a = src.x, b = src.y;\n    double t = 0.0;\n    int scale = 1;\n    cuDoubleComplex result;\n\n    /* We risk spurious overflow for components >= DBL_MAX / (1 + sqrt(2)). */\n    const double THRESH = 7.446288774449766337959726e+307;\n\n    /* Handle special cases. */\n    if (src.x == 0.0 && src.y == double()) {\n      // return (complex<double>(0, b));\n      return (cuDoubleComplex{0.0, b});\n    }\n\n    // DBL_MIN*2\n    const double low_thresh = 4.450147717014402766180465e-308;\n    scale = 0;\n\n    if (fabs(a) >= THRESH || fabs(b) >= THRESH) {\n      /* Scale to avoid overflow. */\n      a *= 0.25;\n      b *= 0.25;\n      scale = 1;\n    } else if (fabs(a) <= low_thresh && fabs(b) <= low_thresh) {\n      /* Scale to avoid underflow. */\n      a *= 4.0;\n      b *= 4.0;\n      scale = 2;\n    }\n\n    /* Algorithm 312, CACM vol 10, Oct 1967. */\n    if (a >= 0.0) {\n      t = sqrt((a + hypot(a, b)) * 0.5);\n      // result = complex<double>(t, b / (2.0 * t));\n      result.x = t;\n      result.y = b / (2 * t);\n    } else {\n      t = sqrt((-a + hypot(a, b)) * 0.5);\n      // result = complex<double>(fabs(b) / (2.0 * t), copysign(t, b));\n      result.x = fabs(b) / (2 * t);\n      result.y = copysign(t, b);\n    }\n\n    /* Rescale. 
*/\n    if (scale == 1) {\n      // return (result * 2.0);\n      result.x *= 2.0;\n      result.y *= 2.0;\n    } else if (scale == 2) {\n      // return (result * 0.5);\n      result.x *= 0.5;\n      result.y *= 0.5;\n    }\n\n    return (result);\n  }\n};\n\n#define SPECIALIZATION_COMPLEX_ARITHMETIC_UNARY_FUNCTOR(op, complex_type, real_type)        \\\n  template<>                                                                                \\\n  struct UnaryFunctor<DeviceType::kCUDA, op, complex_type, complex_type> {                  \\\n    OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : real_functor(attr0, attr1) {} \\\n    UnaryFunctor<DeviceType::kCUDA, op, real_type, real_type> real_functor;                 \\\n    OF_DEVICE_FUNC complex_type operator()(complex_type src) const {                        \\\n      return complex_type{real_functor(src.x), real_functor(src.y)};                        \\\n    }                                                                                       \\\n  };\n\nSPECIALIZATION_COMPLEX_ARITHMETIC_UNARY_FUNCTOR(UnaryOp::kNegative, cuComplex, float);\nSPECIALIZATION_COMPLEX_ARITHMETIC_UNARY_FUNCTOR(UnaryOp::kNegative, cuDoubleComplex, double);\n\n}  // namespace primitive\n}  // namespace ep\n}  // namespace oneflow\n#endif  // ONEFLOW_CORE_EP_CUDA_PRIMITIVE_UNARY_FUNCTOR_CUH\n"
  },
  {
    "path": "oneflow/core/ep/cuda/primitive/where.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/where.h\"\n#include \"oneflow/core/ep/common/primitive/where.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n\nnamespace oneflow {\nnamespace ep {\nnamespace primitive {\n\nnamespace {\n\nusing cuda::elementwise::GetNumBlocks;\nusing cuda::elementwise::kBlockSize;\n\ntemplate<typename T, typename CondT, typename IndexT, size_t ndim, size_t cond_pack_size,\n         size_t x_pack_size, size_t y_pack_size>\n__global__ void BroadcastElementwiseWhereCudaKernel(\n    BroadcastElementwiseWhereParams<ndim, IndexT> params) {\n  constexpr size_t _pack_size = (x_pack_size > y_pack_size) ? x_pack_size : y_pack_size;\n  constexpr size_t pack_size = (cond_pack_size > _pack_size) ? cond_pack_size : _pack_size;\n  static_assert(cond_pack_size == pack_size || cond_pack_size == 1, \"\");\n  static_assert(x_pack_size == pack_size || x_pack_size == 1, \"\");\n  static_assert(y_pack_size == pack_size || y_pack_size == 1, \"\");\n  constexpr bool cond_pack_one = !(cond_pack_size == pack_size);\n  constexpr bool x_pack_one = !(x_pack_size == pack_size);\n  constexpr bool y_pack_one = !(y_pack_size == pack_size);\n\n  const auto* cond_pack_ptr = reinterpret_cast<const Packed<CondT, cond_pack_size>*>(params.cond);\n  const auto* x_pack_ptr = reinterpret_cast<const Packed<T, x_pack_size>*>(params.x);\n  const auto* y_pack_ptr = reinterpret_cast<const Packed<T, y_pack_size>*>(params.y);\n  auto* z_pack_ptr = reinterpret_cast<Packed<T, pack_size>*>(params.z);\n\n  IndexT cond_index[ndim];\n  IndexT x_index[ndim];\n  IndexT y_index[ndim];\n  IndexT z_index[ndim];\n\n  WhereFunctor<T, CondT> where_fn{};\n\n  CUDA_1D_KERNEL_LOOP_T(IndexT, offset, params.elem_cnt) {\n    params.z_index_helper.OffsetToNdIndex(offset, z_index);\n#pragma unroll\n    for (size_t i = 0; i < ndim; ++i) {\n      cond_index[i] = params.cond_index_mask[i] * z_index[i];\n      x_index[i] = params.x_index_mask[i] * z_index[i];\n      y_index[i] = params.y_index_mask[i] * z_index[i];\n    }\n    const IndexT cond_offset = params.cond_index_helper.NdIndexToOffset(cond_index);\n    const IndexT x_offset = params.x_index_helper.NdIndexToOffset(x_index);\n    const IndexT y_offset = params.y_index_helper.NdIndexToOffset(y_index);\n    Packed<CondT, cond_pack_size> cond_pack = cond_pack_ptr[cond_offset];\n    Packed<T, x_pack_size> x_pack = x_pack_ptr[x_offset];\n    Packed<T, y_pack_size> y_pack = y_pack_ptr[y_offset];\n    Packed<T, pack_size> z_pack;\n#pragma unroll\n    for (size_t j = 0; j < pack_size; ++j) {\n      const CondT cond_val = cond_pack_one ? cond_pack.elem[0] : cond_pack.elem[j];\n      const T x_val = x_pack_one ? x_pack.elem[0] : x_pack.elem[j];\n      const T y_val = y_pack_one ? 
y_pack.elem[0] : y_pack.elem[j];\n      z_pack.elem[j] = where_fn(cond_val, x_val, y_val);\n    }\n    z_pack_ptr[offset] = z_pack;\n  }\n}\n\ntemplate<typename T, typename CondT, typename IndexT, size_t ndim, size_t cond_pack_size,\n         size_t x_pack_size, size_t y_pack_size>\ncudaError_t LaunchCudaKernel(cudaStream_t stream, const int64_t* cond_dims, const int64_t* x_dims,\n                             const int64_t* y_dims, const int64_t* z_dims, const CondT* cond,\n                             const T* x, const T* y, T* z) {\n  BroadcastElementwiseWhereParams<ndim, IndexT> params;\n  params.cond_index_helper = NdIndexOffsetHelper<IndexT, ndim>(cond_dims);\n  params.x_index_helper = NdIndexOffsetHelper<IndexT, ndim>(x_dims);\n  params.y_index_helper = NdIndexOffsetHelper<IndexT, ndim>(y_dims);\n  params.z_index_helper = NdIndexOffsetHelper<IndexT, ndim>(z_dims);\n  for (size_t i = 0; i < ndim; ++i) {\n    params.cond_index_mask[i] = (cond_dims[i] == 1) ? 0 : 1;\n    params.x_index_mask[i] = (x_dims[i] == 1) ? 0 : 1;\n    params.y_index_mask[i] = (y_dims[i] == 1) ? 0 : 1;\n  }\n  params.elem_cnt = static_cast<IndexT>(GetElementCount(ndim, z_dims));\n  params.cond = cond;\n  params.x = x;\n  params.y = y;\n  params.z = z;\n\n  int num_blocks;\n  {\n    cudaError_t err = GetNumBlocks(params.elem_cnt, &num_blocks);\n    if (err != cudaSuccess) { return err; }\n  }\n  BroadcastElementwiseWhereCudaKernel<T, CondT, IndexT, ndim, cond_pack_size, x_pack_size,\n                                      y_pack_size><<<num_blocks, kBlockSize, 0, stream>>>(params);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename T, typename CondT, typename IndexT, size_t ndim, size_t cond_pack_size,\n         size_t x_pack_size, size_t y_pack_size>\nvoid LaunchKernel(Stream* stream, const int64_t* cond_dims, const int64_t* x_dims,\n                  const int64_t* y_dims, const int64_t* z_dims, const CondT* cond, const T* x,\n                  const T* y, T* z) {\n  static_assert(ndim > 0, \"\");\n  auto cuda_stream = stream->As<CudaStream>()->cuda_stream();\n  OF_CUDA_CHECK((LaunchCudaKernel<T, CondT, IndexT, ndim, cond_pack_size, x_pack_size, y_pack_size>(\n      cuda_stream, cond_dims, x_dims, y_dims, z_dims, cond, x, y, z)));\n}\n\ntemplate<typename T, typename CondT>\nvoid LaunchScalarKernel(Stream* stream, const CondT* cond, const T* x, const T* y, T* z) {\n  // should dispatch to elemwise ternary\n  UNIMPLEMENTED();\n}\n\ntemplate<typename T, typename CondT>\nvoid LaunchElemwiseTernary(CudaStream* stream, int64_t elem_cnt, const CondT* cond, const T* x,\n                           const T* y, T* z) {\n  cudaStream_t cuda_stream = stream->cuda_stream();\n\n  WhereElemwiseFunctor<T, CondT, T, T> where_fn{};\n  OF_CUDA_CHECK((cuda::elementwise::Ternary<decltype(where_fn), T, CondT, T, T>(\n      where_fn, elem_cnt, z, cond, x, y, cuda_stream)));\n}\n\ntemplate<typename T, typename CondT>\nclass WhereCudaImpl : public Where {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(WhereCudaImpl);\n  explicit WhereCudaImpl() = default;\n  ~WhereCudaImpl() override = default;\n\n  void Launch(Stream* stream, size_t num_cond_dims, const int64_t* cond_dims, const void* cond,\n              size_t num_x_dims, const int64_t* x_dims, const void* x, size_t num_y_dims,\n              const int64_t* y_dims, const void* y, void* z) override {\n    size_t compact_num_dims = 0;\n    int64_t compact_cond_dims[kMaxNumDims] = {};\n    int64_t compact_x_dims[kMaxNumDims] = {};\n    int64_t compact_y_dims[kMaxNumDims] = {};\n    int64_t 
compact_z_dims[kMaxNumDims] = {};\n    GetCompactBroadcastDims(num_cond_dims, cond_dims, num_x_dims, x_dims, num_y_dims, y_dims,\n                            &compact_num_dims, compact_cond_dims, compact_x_dims, compact_y_dims,\n                            compact_z_dims);\n\n    if (IsDimsEquals(compact_num_dims, compact_z_dims, compact_cond_dims)\n        && IsDimsEquals(compact_num_dims, compact_z_dims, compact_x_dims)\n        && IsDimsEquals(compact_num_dims, compact_z_dims, compact_y_dims)) {\n      // elementwise\n      const size_t elem_cnt = GetElementCount(compact_num_dims, compact_z_dims);\n      LaunchElemwiseTernary(stream->As<CudaStream>(), elem_cnt, static_cast<const CondT*>(cond),\n                            static_cast<const T*>(x), static_cast<const T*>(y), static_cast<T*>(z));\n    } else {\n      // broadcast\n      LaunchByDispatchNDim(stream, compact_num_dims, compact_cond_dims, compact_x_dims,\n                           compact_y_dims, compact_z_dims, static_cast<const CondT*>(cond),\n                           static_cast<const T*>(x), static_cast<const T*>(y), static_cast<T*>(z));\n    }\n  }\n};\n\nclass WhereFactoryCudaImpl : public WhereFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(WhereFactoryCudaImpl);\n  WhereFactoryCudaImpl() = default;\n  ~WhereFactoryCudaImpl() override = default;\n\n  std::unique_ptr<Where> New(DataType cond_type, DataType data_type, size_t max_num_dims) override {\n    return NewWhere<WhereCudaImpl>(cond_type, data_type, max_num_dims);\n  }\n};\n\nREGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, WhereFactory, WhereFactoryCudaImpl);\n\n}  // namespace\n\n}  // namespace primitive\n}  // namespace ep\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/include/active_device_guard.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_ACTIVE_DEVICE_GUARD_H_\n#define ONEFLOW_CORE_EP_ACTIVE_DEVICE_GUARD_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/ep/include/device.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass DeviceManager;\n\nclass ActiveDeviceGuard {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ActiveDeviceGuard);\n  explicit ActiveDeviceGuard(Device* device);\n  ~ActiveDeviceGuard();\n\n private:\n  size_t saved_active_device_;\n  DeviceManager* device_manager_;\n};\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_ACTIVE_DEVICE_GUARD_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/allocation_options.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_ALLOCATION_ATTRIBUTE_H_\n#define ONEFLOW_CORE_EP_ALLOCATION_ATTRIBUTE_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/device_type.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass AllocationOptions {\n public:\n  AllocationOptions()\n      : pinned_device_type_(DeviceType::kInvalidDevice),\n        pinned_device_index_{},\n        numa_node_affinity_(-1) {}\n  ~AllocationOptions() = default;\n\n  bool HasPinnedDevice() const { return pinned_device_type_ != DeviceType::kInvalidDevice; }\n\n  DeviceType GetPinnedDeviceType() const {\n    CHECK(HasPinnedDevice());\n    return pinned_device_type_;\n  }\n\n  size_t GetPinnedDeviceIndex() const {\n    CHECK(HasPinnedDevice());\n    return pinned_device_index_;\n  }\n\n  void SetPinnedDevice(DeviceType device_type, size_t device_index) {\n    CHECK(!HasPinnedDevice());\n    CHECK_NE(device_type, DeviceType::kInvalidDevice);\n    pinned_device_type_ = device_type;\n    pinned_device_index_ = device_index;\n  }\n\n  void ClearPinnedDevice() { pinned_device_type_ = DeviceType::kInvalidDevice; }\n\n  bool HasNumaNodeAffinity() const { return numa_node_affinity_ >= 0; }\n\n  size_t GetNumaNodeAffinity() const {\n    CHECK(HasNumaNodeAffinity());\n    return numa_node_affinity_;\n  }\n\n  void SetNumaNodeAffinity(size_t numa_node) { numa_node_affinity_ = numa_node; }\n\n  void ClearNumaNodeAffinity() { numa_node_affinity_ = -1; }\n\n private:\n  DeviceType pinned_device_type_;\n  size_t pinned_device_index_;\n  int32_t numa_node_affinity_;\n};\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_ALLOCATION_ATTRIBUTE_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/device.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_DEVICE_H_\n#define ONEFLOW_CORE_EP_DEVICE_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/ep/include/event.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/ep/include/allocation_options.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nconstexpr size_t kMaxAlignmentRequirement = 512;\n\nclass DeviceManager;\n\nclass Device {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Device);\n  Device() = default;\n  virtual ~Device() = default;\n\n  virtual void SetAsActiveDevice() = 0;\n  virtual void Reset() = 0;\n\n  virtual DeviceType device_type() const = 0;\n  virtual size_t device_index() const = 0;\n  virtual DeviceManager* device_manager() const = 0;\n\n  virtual Stream* CreateStream() = 0;\n  virtual void DestroyStream(Stream* stream) = 0;\n\n  virtual Event* CreateEvent();\n  virtual void DestroyEvent(Event* event);\n  virtual void CreateEvents(Event** events, size_t count) = 0;\n  virtual void DestroyEvents(Event** events, size_t count) = 0;\n\n  virtual Maybe<void> Alloc(const AllocationOptions& options, void** ptr, size_t size) = 0;\n  virtual void Free(const AllocationOptions& options, void* ptr) = 0;\n  virtual Maybe<void> AllocPinned(const AllocationOptions& options, void** ptr, size_t size) = 0;\n  virtual void FreePinned(const AllocationOptions& options, void* ptr) = 0;\n  virtual bool IsStreamOrderedMemoryAllocationSupported() const;\n};\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_DEVICE_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/device_manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_DEVICE_MANAGER_H_\n#define ONEFLOW_CORE_EP_DEVICE_MANAGER_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/ep/include/device.h\"\n#include \"oneflow/core/ep/include/random_generator.h\"\n#include \"oneflow/core/common/auto_registration_factory.h\"\n#include \"oneflow/core/common/device_type.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass DeviceManagerRegistry;\n\nclass DeviceManager {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DeviceManager);\n  DeviceManager() = default;\n  virtual ~DeviceManager() = default;\n\n  virtual DeviceManagerRegistry* registry() const = 0;\n  virtual std::shared_ptr<Device> GetDevice(size_t device_index) = 0;\n  virtual size_t GetDeviceCount(size_t primary_device_index) = 0;\n  virtual size_t GetDeviceCount() = 0;\n  virtual size_t GetActiveDeviceIndex() = 0;\n  virtual void SetActiveDeviceByIndex(size_t device_index) = 0;\n  virtual bool IsStreamWaitEventSupported() const { return false; }\n\n  virtual std::shared_ptr<RandomGenerator> CreateRandomGenerator(uint64_t seed,\n                                                                 size_t device_index) = 0;\n};\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_DEVICE_MANAGER_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/device_manager_factory.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_DEVICE_MANAGER_FACTORY_H_\n#define ONEFLOW_CORE_EP_DEVICE_MANAGER_FACTORY_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/ep/include/device_manager.h\"\n#include \"oneflow/core/common/device_type.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass DeviceManagerRegistry;\n\nclass DeviceManagerFactory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DeviceManagerFactory);\n  DeviceManagerFactory() = default;\n  virtual ~DeviceManagerFactory() = default;\n\n  virtual std::unique_ptr<DeviceManager> NewDeviceManager(DeviceManagerRegistry* registry) = 0;\n  virtual DeviceType device_type() const = 0;\n  virtual std::string device_type_name() const = 0;\n  virtual void DumpVersionInfo() const {}\n};\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_DEVICE_MANAGER_FACTORY_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/device_manager_registry.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_DEVICE_MANAGER_REGISTRY_H_\n#define ONEFLOW_CORE_EP_DEVICE_MANAGER_REGISTRY_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/ep/include/device_manager.h\"\n#include \"oneflow/core/ep/include/device_manager_factory.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass DeviceManagerRegistry {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DeviceManagerRegistry);\n  DeviceManagerRegistry();\n  ~DeviceManagerRegistry();\n\n  DeviceManager* GetDeviceManager(DeviceType device_type);\n  DeviceManager* GetDeviceManagerOrNull(DeviceType device_type);\n  std::shared_ptr<Device> GetDevice(DeviceType device_type, size_t device_index);\n  size_t GetDeviceCount(DeviceType device_type);\n  size_t GetDeviceCount(const std::string& device_type_name);\n\n  static void RegisterDeviceManagerFactory(std::unique_ptr<DeviceManagerFactory>&& factory);\n  static void DumpVersionInfo();\n  static std::string GetDeviceTypeNameByDeviceType(DeviceType device_type);\n  static DeviceType GetDeviceTypeByDeviceTypeName(const std::string& device_type_name);\n  static std::set<DeviceType> GetRegisteredDeviceTypes();\n  static bool IsDeviceTypeRegistered(DeviceType device_type);\n\n private:\n  class Impl;\n  std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_DEVICE_MANAGER_REGISTRY_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/event.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_EVENT_H_\n#define ONEFLOW_CORE_EP_EVENT_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass Event {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Event);\n  Event() = default;\n  virtual ~Event() = default;\n\n  virtual Maybe<bool> QueryDone() = 0;\n  virtual Maybe<void> Sync() = 0;\n};\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_EVENT_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/add.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_ADD_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_ADD_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass Add : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Add);\n  Add() = default;\n  ~Add() override = default;\n\n  virtual void Launch(Stream* stream, const void* const* srcs, size_t arity, void* dst,\n                      size_t count) = 0;\n  virtual void Launch(Stream* stream, const void* src0, const void* src1, void* dst, size_t count);\n};\n\nclass AddFactory : public Factory<Add> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AddFactory);\n  AddFactory() = default;\n  ~AddFactory() override = default;\n\n  virtual std::unique_ptr<Add> New(DataType data_type) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_ADD_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/batch_matmul.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_BATCH_MATMUL_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_BATCH_MATMUL_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n#include \"oneflow/core/ep/include/primitive/blas.h\"\n#include \"oneflow/core/common/scalar.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass BatchMatmul : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BatchMatmul);\n  BatchMatmul() = default;\n  ~BatchMatmul() override = default;\n\n  virtual void Launch(Stream* stream, size_t batch_size, size_t m, size_t n, size_t k, Scalar alpha,\n                      const void* a, const void* b, Scalar beta, void* c) = 0;\n};\n\nclass BatchMatmulFactory : public Factory<BatchMatmul> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BatchMatmulFactory);\n  BatchMatmulFactory() = default;\n  ~BatchMatmulFactory() override = default;\n\n  virtual std::unique_ptr<BatchMatmul> New(DataType data_type, BlasTransposeType transpose_a,\n                                           BlasTransposeType transpose_b) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_BATCH_MATMUL_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/binary_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_BINARY_OP_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_BINARY_OP_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nenum class BinaryOp {\n  // Math\n  kAdd,\n  kSub,\n  kMul,\n  kDiv,\n  kMax,\n  kMin,\n  kPow,\n  kFmod,\n  kFloorDiv,\n  kTruncDiv,\n  kFloorMod,\n  kScalarBasePowerGrad,\n  kScalarExpPowerGrad,\n  kZeta,\n  // Comparision\n  kEqual,\n  kNotEqual,\n  kLessThan,\n  kLessEqual,\n  kGreaterThan,\n  kGreaterEqual,\n  kIsClose,\n  kIsCloseEqualNan,\n  // Logical\n  kLogicalAnd,\n  kLogicalOr,\n  kLogicalXor,\n  // Bitwise\n  kBitwiseAnd,\n  kBitwiseOr,\n  kBitwiseXor,\n  // Unary Backward\n  kIdentityBackwardWithDyX,\n  kEluBackwardWithDyX,\n  kCeluBackwardWithDyY,\n  kGeluBackwardWithDyX,\n  kHardswishBackwardWithDyX,\n  kHardsigmoidBackwardWithDyX,\n  kHardshrinkBackwardWithDyY,\n  kHardtanhBackwardWithDyY,\n  kLeakyReluBackwardWithDyX,\n  kMishBackwardWithDyX,\n  kReluBackwardWithDyY,\n  kReluBackwardWithDyX,\n  kSeluBackwardWithDyX,\n  kSiluBackwardWithDyX,\n  kSoftsignBackwardWithDyX,\n  kSoftplusBackwardWithDyX,\n  kSoftshrinkBackwardWithDyY,\n  kTanhBackwardWithDyY,\n  kThresholdBackwardWithDyX,\n  kSigmoidBackwardWithDyY,\n  kSigmoidBackwardWithDyX,\n  kAbsBackwardWithDyX,\n  kAcosBackwardWithDyX,\n  kAcoshBackwardWithDyX,\n  kAsinBackwardWithDyX,\n  kAsinhBackwardWithDyX,\n  kAtanBackwardWithDyX,\n  kAtanhBackwardWithDyX,\n  kCosBackwardWithDyX,\n  kCoshBackwardWithDyX,\n  kErfBackwardWithDyX,\n  kErfcBackwardWithDyX,\n  kExpBackwardWithDyX,\n  kExp2BackwardWithDyX,\n  kExpm1BackwardWithDyX,\n  kLgammaBackwardWithDyX,\n  kDigammaBackwardWithDyX,\n  kLogBackwardWithDyX,\n  kLog2BackwardWithDyX,\n  kLog10BackwardWithDyX,\n  kLog1pBackwardWithDyX,\n  kLogSigmoidBackwardWithDyX,\n  kReciprocalBackwardWithDyX,\n  kReciprocalNoNanBackwardWithDyX,\n  kRsqrtBackwardWithDyX,\n  kSinBackwardWithDyX,\n  kSinhBackwardWithDyX,\n  kSqrtBackwardWithDyX,\n  kSquareBackwardWithDyX,\n  kTanBackwardWithDyX,\n  kFastGeluBackwardWithDyX,\n  kQuickGeluBackwardWithDyX,\n  kSquareReLUBackwardWithDyX,\n};\n\n}\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_BINARY_OP_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/blas.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_BLAS_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_BLAS_H_\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nenum class BlasTransposeType {\n  N = 0,\n  T,\n};\n\n}\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_BLAS_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_ELEMENTWISE_BINARY_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_ELEMENTWISE_BINARY_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n#include \"oneflow/core/ep/include/primitive/binary_op.h\"\n#include \"oneflow/core/common/scalar.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass BroadcastElementwiseBinary : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinary);\n  BroadcastElementwiseBinary() = default;\n  ~BroadcastElementwiseBinary() override = default;\n\n  virtual void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims,\n                      const void* src0, size_t num_src1_dims, const int64_t* src1_dims,\n                      const void* src1, void* dst) = 0;\n  virtual void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims,\n                      const void* src1, void* dst) = 0;\n  virtual void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims,\n                      const void* src0, Scalar src1, void* dst) = 0;\n};\n\nclass BroadcastElementwiseBinaryFactory : public Factory<BroadcastElementwiseBinary> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryFactory);\n  BroadcastElementwiseBinaryFactory() = default;\n  ~BroadcastElementwiseBinaryFactory() override = default;\n\n  virtual std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp op, DataType src_type,\n                                                          DataType dst_type,\n                                                          size_t max_num_dims) = 0;\n\n  virtual std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp op, DataType src_type,\n                                                          DataType dst_type, size_t max_num_dims,\n                                                          Scalar attr0) = 0;\n\n  virtual std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp op, DataType src_type,\n                                                          DataType dst_type, size_t max_num_dims,\n                                                          Scalar attr0, Scalar attr1) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_ELEMENTWISE_BINARY_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/broadcast_elementwise_unary.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_ELEMENTWISE_UNARY_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_ELEMENTWISE_UNARY_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n#include \"oneflow/core/ep/include/primitive/unary_op.h\"\n#include \"oneflow/core/common/scalar.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass BroadcastElementwiseUnary : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnary);\n  BroadcastElementwiseUnary() = default;\n  ~BroadcastElementwiseUnary() override = default;\n\n  virtual void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims,\n                      const int64_t* src_strides, const void* src, size_t num_dst_dims,\n                      const int64_t* dst_dims, const int64_t* dst_strides, void* dst) = 0;\n\n  virtual void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims, const void* src,\n                      size_t num_dst_dims, const int64_t* dst_dims, void* dst) = 0;\n};\n\nclass BroadcastElementwiseUnaryFactory : public Factory<BroadcastElementwiseUnary> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnaryFactory);\n  BroadcastElementwiseUnaryFactory() = default;\n  ~BroadcastElementwiseUnaryFactory() override = default;\n\n  virtual std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp op, DataType src_type,\n                                                         DataType dst_type,\n                                                         size_t max_num_dims) = 0;\n\n  virtual std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp op, DataType src_type,\n                                                         DataType dst_type, size_t max_num_dims,\n                                                         Scalar attr0) = 0;\n\n  virtual std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp op, DataType src_type,\n                                                         DataType dst_type, size_t max_num_dims,\n                                                         Scalar attr0, Scalar attr1) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_ELEMENTWISE_UNARY_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/broadcast_matmul.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_MATMUL_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_MATMUL_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n#include \"oneflow/core/ep/include/primitive/blas.h\"\n#include \"oneflow/core/common/scalar.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass BroadcastMatmul : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmul);\n  BroadcastMatmul() = default;\n  ~BroadcastMatmul() override = default;\n\n  virtual void Launch(Stream* stream, Scalar alpha, size_t num_a_dims, const int64_t* a_dims,\n                      const void* a, size_t num_b_dims, const int64_t* b_dims, const void* b,\n                      Scalar beta, size_t num_c_dims, const int64_t* c_dims, void* c) = 0;\n};\n\nclass BroadcastMatmulFactory : public Factory<BroadcastMatmul> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmulFactory);\n  BroadcastMatmulFactory() = default;\n  ~BroadcastMatmulFactory() override = default;\n\n  virtual std::unique_ptr<BroadcastMatmul> New(DataType data_type, BlasTransposeType transpose_a,\n                                               BlasTransposeType transpose_b,\n                                               size_t max_num_dims) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_BROADCAST_MATMUL_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/cast.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_CAST_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_CAST_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass Cast : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Cast);\n  Cast() = default;\n  ~Cast() override = default;\n\n  virtual void Launch(Stream* stream, const void* from, void* to, size_t count) = 0;\n};\n\nclass CastFactory : public Factory<Cast> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CastFactory);\n  CastFactory() = default;\n  ~CastFactory() override = default;\n\n  virtual std::unique_ptr<Cast> New(DataType from, DataType to) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_CAST_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/constant_pad.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_CONSTANT_PAD_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_CONSTANT_PAD_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n#include \"oneflow/core/common/scalar.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nclass ConstantPad : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ConstantPad);\n  ConstantPad() = default;\n  ~ConstantPad() override = default;\n\n  virtual void Launch(Stream* stream, size_t num_dims, const int64_t* src_dims, const void* src,\n                      const int64_t* padding_before, const int64_t* padding_after, Scalar pad_val,\n                      void* dst) = 0;\n};\n\nclass ConstantPadFactory : public Factory<ConstantPad> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ConstantPadFactory);\n  ConstantPadFactory() = default;\n  ~ConstantPadFactory() override = default;\n\n  virtual std::unique_ptr<ConstantPad> New(DataType data_type) = 0;\n};\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/copy_nd.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_COPY_ND_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_COPY_ND_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass CopyNd : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CopyNd);\n  CopyNd() = default;\n  ~CopyNd() override = default;\n\n  virtual void Launch(Stream* stream, DataType data_type, size_t num_dims, void* dst,\n                      const int64_t* dst_dims, const int64_t* dst_pos, const void* src,\n                      const int64_t* src_dims, const int64_t* src_pos,\n                      const int64_t* extent) const = 0;\n};\n\nclass CopyNdFactory : public Factory<CopyNd> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CopyNdFactory);\n  CopyNdFactory() = default;\n  ~CopyNdFactory() override = default;\n\n  virtual std::unique_ptr<CopyNd> New(size_t max_num_dims) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_COPY_ND_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/elementwise_unary.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_ELEMENTWISE_UNARY_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_ELEMENTWISE_UNARY_H_\n\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n#include \"oneflow/core/ep/include/primitive/unary_op.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass ElementwiseUnary : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnary);\n  ElementwiseUnary() = default;\n  ~ElementwiseUnary() override = default;\n\n  virtual void Launch(Stream* stream, const void* src, void* dst, size_t count) = 0;\n};\n\nclass ElementwiseUnaryFactory : public Factory<ElementwiseUnary> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ElementwiseUnaryFactory);\n  ElementwiseUnaryFactory() = default;\n  ~ElementwiseUnaryFactory() override = default;\n\n  virtual std::unique_ptr<ElementwiseUnary> New(UnaryOp op, DataType src_type,\n                                                DataType dst_type) = 0;\n\n  virtual std::unique_ptr<ElementwiseUnary> New(UnaryOp op, DataType src_type, DataType dst_type,\n                                                Scalar attr0) = 0;\n\n  virtual std::unique_ptr<ElementwiseUnary> New(UnaryOp op, DataType src_type, DataType dst_type,\n                                                Scalar attr0, Scalar attr1) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_ELEMENTWISE_UNARY_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/fast_integer_math.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_FAST_INTEGER_MATH_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_FAST_INTEGER_MATH_H_\n#include \"oneflow/core/common/data_type.h\"\n#include <cassert>\n\nnamespace oneflow {\n\n/*\n  Copyright microsoft/onnxruntime\n  https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cuda/shared_inc/fast_divmod.h\n*/\ntemplate<typename T>\nstruct FastIntegerMath {\n  OF_DEVICE_FUNC FastIntegerMath() {}\n  OF_DEVICE_FUNC explicit FastIntegerMath(T operand) {\n#if defined(__CUDA_ARCH__)\n    int leading_zeroes = __clzll(operand);\n#else\n    int leading_zeroes = __builtin_clz(operand);\n#endif\n    bool is_power_2 = ((operand & (operand - 1)) == 0);\n    if (is_power_2) {\n      log2_operand_ = 31 - leading_zeroes;\n    } else {\n      log2_operand_ = -1;  // Set as flag.\n    }\n    operand_ = operand == 0 ? 1 : operand;\n    assert(operand_ >= 1 && operand_ <= GetMaxVal<T>());\n  }\n\n  OF_DEVICE_FUNC T divides(T n) const {\n    if (log2_operand_ >= 0) {\n      return n >> log2_operand_;\n    } else {\n      return n / operand_;\n    }\n  }\n\n  OF_DEVICE_FUNC T mod(T n) const { return n - divides(n) * operand_; }\n  OF_DEVICE_FUNC T mul(T n) const {\n    if (log2_operand_ >= 0) {\n      return n << log2_operand_;\n    } else {\n      return n * operand_;\n    }\n  }\n  OF_DEVICE_FUNC T add(T n) const { return n + operand_; }\n  OF_DEVICE_FUNC T sub(T n) const { return n - operand_; }\n  OF_DEVICE_FUNC void divmod(T n, T* q, T* r) const {\n    *q = divides(n);\n    *r = n - *q * operand_;\n  }\n\n  T operand_;\n  int32_t log2_operand_;\n};\n\ntemplate<>\nstruct FastIntegerMath<int32_t> {\n  OF_DEVICE_FUNC FastIntegerMath() {}\n\n  OF_DEVICE_FUNC explicit FastIntegerMath(const int32_t operand) {\n    operand_ = operand == 0 ? 
1 : operand;\n    assert(operand_ >= 1 && operand_ <= GetMaxVal<uint32_t>());\n    for (l_ = 0; l_ < 32; l_++)\n      if ((1U << l_) >= operand_) break;\n\n    uint64_t one = 1;\n    uint64_t m = ((one << 32) * ((one << l_) - operand_)) / operand_ + 1;\n    M_ = static_cast<uint32_t>(m);\n    assert(M_ > 0 && M_ == m);\n  }\n\n  OF_DEVICE_FUNC int32_t divides(const int32_t n) const {\n#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)\n    uint32_t t = __umulhi(M_, n);\n    return (t + n) >> l_;\n#else\n    // Using uint64_t for t, then t + n won't overflow.\n    uint64_t t = ((uint64_t)M_ * n) >> 32;\n    return static_cast<int>((t + n) >> l_);\n#endif\n  }\n\n  OF_DEVICE_FUNC int32_t mod(int32_t n) const { return n - divides(n) * operand_; }\n  OF_DEVICE_FUNC int32_t mul(int32_t n) const { return n * operand_; }\n  OF_DEVICE_FUNC int32_t add(int32_t n) const { return n + operand_; }\n  OF_DEVICE_FUNC int32_t sub(int32_t n) const { return n - operand_; }\n  OF_DEVICE_FUNC void divmod(int32_t n, int32_t* q, int32_t* r) const {\n    *q = divides(n);\n    *r = n - *q * operand_;\n  }\n\n  uint32_t operand_;\n  uint32_t M_;  // m' in the paper.\n  uint32_t l_;  // l_ = ceil(log2(d_))\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_FAST_INTEGER_MATH_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/fill.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_FILL_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_FILL_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n#include \"oneflow/core/common/scalar.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass Fill : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Fill);\n  Fill() = default;\n  ~Fill() override = default;\n\n  virtual void Launch(Stream* stream, void* dst, Scalar value, size_t count) = 0;\n};\n\nclass FillFactory : public Factory<Fill> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(FillFactory);\n  FillFactory() = default;\n  ~FillFactory() override = default;\n\n  virtual std::unique_ptr<Fill> New(DataType data_type) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_FILL_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/log_softmax.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_LOG_SOFTMAX_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_LOG_SOFTMAX_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass LogSoftmax : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LogSoftmax);\n  LogSoftmax() = default;\n  ~LogSoftmax() override = default;\n\n  virtual void Launch(Stream* stream, size_t rows, size_t cols, const void* x, void* y) = 0;\n};\n\nclass LogSoftmaxFactory : public Factory<LogSoftmax> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LogSoftmaxFactory);\n  LogSoftmaxFactory() = default;\n  ~LogSoftmaxFactory() override = default;\n\n  virtual std::unique_ptr<LogSoftmax> New(DataType data_type) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_LOG_SOFTMAX_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/log_softmax_backward.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_LOG_SOFTMAX_BACKWARD_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_LOG_SOFTMAX_BACKWARD_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass LogSoftmaxBackward : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LogSoftmaxBackward);\n  LogSoftmaxBackward() = default;\n  ~LogSoftmaxBackward() override = default;\n\n  virtual void Launch(Stream* stream, size_t rows, size_t cols, const void* y, const void* dy,\n                      void* dx) = 0;\n};\n\nclass LogSoftmaxBackwardFactory : public Factory<LogSoftmaxBackward> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LogSoftmaxBackwardFactory);\n  LogSoftmaxBackwardFactory() = default;\n  ~LogSoftmaxBackwardFactory() override = default;\n\n  virtual std::unique_ptr<LogSoftmaxBackward> New(DataType data_type) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_LOG_SOFTMAX_BACKWARD_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/matmul.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_MATMUL_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_MATMUL_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n#include \"oneflow/core/ep/include/primitive/blas.h\"\n#include \"oneflow/core/common/scalar.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass Matmul : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Matmul);\n  Matmul() = default;\n  ~Matmul() override = default;\n\n  virtual void Launch(Stream* stream, size_t m, size_t n, size_t k, Scalar alpha, const void* a,\n                      const void* b, Scalar beta, void* c) = 0;\n};\n\nclass MatmulFactory : public Factory<Matmul> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MatmulFactory);\n  MatmulFactory() = default;\n  ~MatmulFactory() override = default;\n\n  virtual std::unique_ptr<Matmul> New(DataType data_type, BlasTransposeType transpose_a,\n                                      BlasTransposeType transpose_b) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_MATMUL_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/memcpy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_MEMCPY_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_MEMCPY_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nenum class MemcpyKind {\n  kAuto = 0,\n  kHtoD,\n  kDtoH,\n  kDtoD,\n};\n\nclass Memcpy : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Memcpy);\n  Memcpy() = default;\n  ~Memcpy() override = default;\n\n  virtual void Launch(Stream* stream, void* dst, const void* src, size_t count) = 0;\n};\n\nclass MemcpyFactory : public Factory<Memcpy> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MemcpyFactory);\n  MemcpyFactory() = default;\n  ~MemcpyFactory() override = default;\n\n  virtual std::unique_ptr<Memcpy> New(MemcpyKind kind) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_MEMCPY_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/memset.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_MEMSET_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_MEMSET_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass Memset : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Memset);\n  Memset() = default;\n  ~Memset() override = default;\n\n  virtual void Launch(Stream* stream, void* ptr, int value, size_t count) = 0;\n};\n\nclass MemsetFactory : public Factory<Memset> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MemsetFactory);\n  MemsetFactory() = default;\n  ~MemsetFactory() override = default;\n\n  virtual std::unique_ptr<Memset> New() = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_MEMSET_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/one_hot.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_ONE_HOT_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_ONE_HOT_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n#include \"oneflow/core/common/scalar.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass OneHot : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(OneHot);\n  OneHot() = default;\n  ~OneHot() override = default;\n\n  virtual void Launch(Stream* stream, const void* indices, void* out, Scalar on_value,\n                      Scalar off_value, size_t num_indices, size_t lower_bound,\n                      size_t upper_bound) = 0;\n};\n\nclass OneHotFactory : public Factory<OneHot> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(OneHotFactory);\n  OneHotFactory() = default;\n  ~OneHotFactory() override = default;\n\n  virtual std::unique_ptr<OneHot> New(DataType indices_type, DataType out_type) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_ONE_HOT_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/permute.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_PERMUTE_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_PERMUTE_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass Permute : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Permute);\n  Permute() = default;\n  ~Permute() override = default;\n\n  virtual void Launch(Stream* stream, DataType data_type, size_t num_dims, const int64_t* src_dims,\n                      const void* src, const int* permutation, void* dst) = 0;\n};\n\nclass PermuteFactory : public Factory<Permute> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PermuteFactory);\n  PermuteFactory() = default;\n  ~PermuteFactory() override = default;\n\n  virtual std::unique_ptr<Permute> New(size_t max_num_dims) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_PERMUTE_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/primitive.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_PRIMITIVE_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_PRIMITIVE_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/auto_registration_factory.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Primitive);\n  Primitive() = default;\n  virtual ~Primitive() = default;\n};\n\ntemplate<typename PrimitiveT>\nclass Factory {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Factory);\n  Factory() = default;\n  virtual ~Factory() = default;\n\n  using PrimitiveType = PrimitiveT;\n};\n\ntemplate<typename FactoryType, typename... Args>\nstatic std::unique_ptr<typename FactoryType::PrimitiveType> NewPrimitive(DeviceType device_type,\n                                                                         Args&&... args) {\n  if (!IsClassRegistered<DeviceType, FactoryType>(device_type)) { return nullptr; }\n  std::unique_ptr<FactoryType> factory = NewObjUniquePtr<DeviceType, FactoryType>(device_type);\n  if (!factory) { return nullptr; }\n  return factory->New(std::forward<Args>(args)...);\n}\n\n#define REGISTER_PRIMITIVE_FACTORY(device, Base, Derived) \\\n  REGISTER_CLASS(DeviceType, device, Base, Derived)\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_PRIMITIVE_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/softmax.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_SOFTMAX_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_SOFTMAX_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass Softmax : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Softmax);\n  Softmax() = default;\n  ~Softmax() override = default;\n\n  virtual void Launch(Stream* stream, size_t rows, size_t cols, const void* x, void* y) = 0;\n};\n\nclass SoftmaxFactory : public Factory<Softmax> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SoftmaxFactory);\n  SoftmaxFactory() = default;\n  ~SoftmaxFactory() override = default;\n\n  virtual std::unique_ptr<Softmax> New(DataType data_type) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_SOFTMAX_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/softmax_backward.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_SOFTMAX_BACKWARD_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_SOFTMAX_BACKWARD_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass SoftmaxBackward : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SoftmaxBackward);\n  SoftmaxBackward() = default;\n  ~SoftmaxBackward() override = default;\n\n  virtual void Launch(Stream* stream, size_t rows, size_t cols, const void* y, const void* dy,\n                      void* dx) = 0;\n};\n\nclass SoftmaxBackwardFactory : public Factory<SoftmaxBackward> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SoftmaxBackwardFactory);\n  SoftmaxBackwardFactory() = default;\n  ~SoftmaxBackwardFactory() override = default;\n\n  virtual std::unique_ptr<SoftmaxBackward> New(DataType data_type) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_SOFTMAX_BACKWARD_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/tensor_fill.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_TENSOR_FILL_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_TENSOR_FILL_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nclass TensorFill : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TensorFill);\n  TensorFill() = default;\n  ~TensorFill() override = default;\n\n  virtual void Launch(Stream* stream, const void* src, void* dst, size_t count) = 0;\n};\n\nclass TensorFillFactory : public Factory<TensorFill> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TensorFillFactory);\n  TensorFillFactory() = default;\n  ~TensorFillFactory() override = default;\n\n  virtual std::unique_ptr<TensorFill> New(DataType data_type) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_TENSOR_FILL_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/unary_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_UNARY_OP_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_UNARY_OP_H_\n\nnamespace oneflow {\n\nnamespace ep {\nnamespace primitive {\n\nenum class UnaryOp {\n  kIdentity,\n  // activation op\n  kElu,\n  kCelu,\n  kRelu,\n  kGelu,\n  kHardSwish,\n  kHardSigmoid,\n  kHardShrink,\n  kHardTanh,\n  kLeakyRelu,\n  kMish,\n  kSelu,\n  kSilu,\n  kSoftShrink,\n  kSoftSign,\n  kSoftPlus,\n  kTanh,\n  kThreshold,\n  kFastGelu,\n  kQuickGelu,\n  kSquareReLU,\n  // math op\n  kAbs,\n  kAcos,\n  kAcosh,\n  kAsin,\n  kAsinh,\n  kAtan,\n  kAtanh,\n  kCeil,\n  kCos,\n  kCosh,\n  kDigamma,\n  kTrigamma,\n  kErf,\n  kErfc,\n  kExp,\n  kExp2,\n  kExpm1,\n  kFloor,\n  kLgamma,\n  kLog,\n  kLog2,\n  kLog10,\n  kLog1p,\n  kLogSigmoid,\n  kNegative,\n  kReciprocal,\n  kReciprocalNoNan,\n  kRint,\n  kRound,\n  kRsqrt,\n  kSigmoid,\n  kSign,\n  kSin,\n  kSinh,\n  kSqrt,\n  kSquare,\n  kTan,\n  kTrunc,\n  kNotEqualZero,\n  // logical op\n  kLogicalNot,\n\n  // cast op\n  kCast,\n\n  // utils op\n  kIsInf,\n  kIsNan,\n  kIsFinite,\n  kNanAssign,\n\n  // bitwise op\n  kBitwiseNot,\n\n  // complex op\n  kConj,\n  kReal,\n  kImag,\n  kRealGrad,\n  kImagGrad\n};\n\n}\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_UNARY_OP_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/primitive/where.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_PRIMITIVE_WHERE_H_\n#define ONEFLOW_CORE_EP_PRIMITIVE_WHERE_H_\n\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n\nnamespace oneflow {\nnamespace ep {\nnamespace primitive {\n\nclass Where : public Primitive {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Where);\n  Where() = default;\n  ~Where() override = default;\n\n  virtual void Launch(Stream* stream, size_t num_cond_dims, const int64_t* cond_dims,\n                      const void* cond, size_t num_x_dims, const int64_t* x_dims, const void* x,\n                      size_t num_y_dims, const int64_t* y_dims, const void* y, void* z) = 0;\n};\n\nclass WhereFactory : public Factory<Where> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(WhereFactory);\n  WhereFactory() = default;\n  ~WhereFactory() override = default;\n\n  virtual std::unique_ptr<Where> New(DataType cond_type, DataType data_type,\n                                     size_t max_num_dims) = 0;\n};\n\n}  // namespace primitive\n}  // namespace ep\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_PRIMITIVE_WHERE_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/random_generator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_RANDOM_GENERATOR_H_\n#define ONEFLOW_CORE_EP_RANDOM_GENERATOR_H_\n\n#include <string>\n\nnamespace oneflow {\nnamespace ep {\n\nclass RandomGenerator {\n public:\n  RandomGenerator() = default;\n  virtual ~RandomGenerator() = default;\n\n  virtual uint64_t current_seed() const = 0;\n  virtual void set_current_seed(uint64_t seed) = 0;\n\n  virtual std::string device_type_name() const = 0;\n  virtual int64_t device_index() const = 0;\n\n  virtual size_t GetStateSize() const = 0;\n  virtual void GetState(size_t state_size, void* state) const = 0;\n  virtual void SetState(size_t state_size, const void* state) = 0;\n};\n\ntemplate<typename T>\nstd::string GetRandomGeneratorDeviceTypeName();\n\n}  // namespace ep\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_RANDOM_GENERATOR_H_\n"
  },
  {
    "path": "oneflow/core/ep/include/stream.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_STREAM_H_\n#define ONEFLOW_CORE_EP_STREAM_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/ep/include/event.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass Device;\n\nclass Stream {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Stream);\n  Stream() = default;\n  virtual ~Stream() = default;\n\n  virtual DeviceType device_type() const = 0;\n  virtual Device* device() const = 0;\n  virtual Maybe<void> Sync() = 0;\n  virtual void RecordEvent(Event* event) = 0;\n  virtual void WaitEvent(Event* event) { UNIMPLEMENTED(); }\n  virtual Maybe<void> GetAsyncError() { return Maybe<void>::Ok(); }\n\n  virtual Maybe<void> AllocAsync(void** ptr, size_t size) { UNIMPLEMENTED_THEN_RETURN(); }\n  virtual Maybe<void> FreeAsync(void* ptr) { UNIMPLEMENTED_THEN_RETURN(); }\n  template<typename T>\n  Maybe<void> AllocAsync(T** ptr, size_t size) {\n    return AllocAsync(reinterpret_cast<void**>(ptr), size);\n  }\n\n  virtual Maybe<void> OnExecutionContextSetup() { return Maybe<void>::Ok(); }\n  virtual Maybe<void> OnExecutionContextTeardown() { return Maybe<void>::Ok(); }\n\n  template<typename T>\n  T* As() {\n    return static_cast<T*>(this);\n  }\n};\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_STREAM_H_\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/add_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/add.h\"\n#include <Eigen/Core>\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\nnamespace {\n\ntemplate<DataType data_type, typename T, size_t n>\nvoid TestAdd(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types) {\n  constexpr size_t max_arity = 10;\n  using Matrix = Eigen::Matrix<T, 1, n>;\n  std::vector<Matrix> srcs(max_arity);\n  std::vector<Matrix> dsts(max_arity);\n  for (size_t i = 0; i < max_arity; ++i) {\n    srcs[i] = Matrix::Random();\n    if (i == 0) {\n      dsts[i] = Matrix::Zero();\n    } else {\n      dsts[i] = srcs[i - 1] + dsts[i - 1];\n    }\n  }\n  const size_t vector_size = n * sizeof(T);\n  for (const auto& device_type : device_types) {\n    auto device = registry->GetDevice(device_type, 0);\n    std::vector<void*> host_srcs(max_arity);\n    std::vector<void*> device_srcs(max_arity);\n    std::vector<void*> host_dsts(max_arity);\n    std::vector<void*> device_dsts(max_arity);\n    AllocationOptions pinned_options;\n    pinned_options.SetPinnedDevice(device_type, 0);\n    AllocationOptions device_options;\n    for (size_t i = 0; i < max_arity; ++i) {\n      CHECK_JUST(device->AllocPinned(pinned_options, &host_srcs[i], vector_size));\n      CHECK_JUST(device->AllocPinned(pinned_options, &host_dsts[i], vector_size));\n      CHECK_JUST(device->Alloc(device_options, &device_srcs[i], vector_size));\n      CHECK_JUST(device->Alloc(device_options, &device_dsts[i], vector_size));\n    }\n    ep::test::StreamGuard stream(device.get());\n    std::unique_ptr<Add> add = NewPrimitive<AddFactory>(device_type, data_type);\n    ASSERT_TRUE(add.operator bool());\n    std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    ASSERT_TRUE(d2h.operator bool());\n    ASSERT_TRUE(h2d.operator bool());\n    for (size_t i = 0; i < max_arity; ++i) {\n      std::memcpy(host_srcs[i], srcs[i].data(), vector_size);\n      h2d->Launch(stream.stream(), device_srcs[i], host_srcs[i], vector_size);\n    }\n    for (size_t i = 2; i < max_arity; ++i) {\n      add->Launch(stream.stream(), device_srcs.data(), i, device_dsts.at(i), n);\n    }\n    for (size_t i = 2; i < max_arity; ++i) {\n      d2h->Launch(stream.stream(), host_dsts[i], device_dsts[i], vector_size);\n    }\n    CHECK_JUST(stream.stream()->Sync());\n    for (size_t i = 2; i < max_arity; ++i) {\n      auto res = Eigen::Map<Matrix, Eigen::Unaligned>(reinterpret_cast<T*>(host_dsts[i]), n);\n      ASSERT_TRUE(dsts[i].template isApprox(res));\n    }\n    for (size_t i = 0; i < max_arity; ++i) {\n      
device->FreePinned(pinned_options, host_srcs[i]);\n      device->FreePinned(pinned_options, host_dsts[i]);\n      device->Free(device_options, device_srcs[i]);\n      device->Free(device_options, device_dsts[i]);\n    }\n  }\n}\n\n}  // namespace\n\nTEST_F(PrimitiveTest, TestAdd) {\n  TestAdd<DataType::kDouble, double, 1024>(&device_manager_registry_, available_device_types_);\n  TestAdd<DataType::kFloat, float, 1024>(&device_manager_registry_, available_device_types_);\n  TestAdd<DataType::kFloat16, Eigen::half, 1024>(&device_manager_registry_,\n                                                 available_device_types_);\n}\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/batch_matmul_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/batch_matmul.h\"\n#include <unsupported/Eigen/CXX11/Tensor>\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\nnamespace {\n\ntemplate<DataType data_type, typename T>\nvoid TestBatchMatmul(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n                     int batch_size, int m, int k, int n, bool transpose_a, bool transpose_b) {\n  using Matrix = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;\n  Eigen::Tensor<T, 3, Eigen::RowMajor> in_a_buffer(batch_size, m, k);\n  Eigen::Tensor<T, 3, Eigen::RowMajor> in_b_buffer(batch_size, k, n);\n  Eigen::Tensor<T, 3, Eigen::RowMajor> out_c_buffer(batch_size, m, n);\n  in_a_buffer.setRandom();\n  in_b_buffer.setRandom();\n  for (int i = 0; i < batch_size; ++i) {\n    Eigen::Map<Matrix, Eigen::Unaligned> a(in_a_buffer.data() + i * m * k, m, k);\n    Eigen::Map<Matrix, Eigen::Unaligned> b(in_b_buffer.data() + i * k * n, k, n);\n    Eigen::Map<Matrix, Eigen::Unaligned> c(out_c_buffer.data() + i * m * n, m, n);\n    c = a * b;\n  }\n  int64_t a_size = batch_size * m * k * sizeof(T);\n  int64_t b_size = batch_size * k * n * sizeof(T);\n  int64_t c_size = batch_size * m * n * sizeof(T);\n\n  Eigen::array<int, 3> shuffling({0, 2, 1});\n  Eigen::Tensor<T, 3, Eigen::RowMajor> in_a_transposed = in_a_buffer.shuffle(shuffling);\n  Eigen::Tensor<T, 3, Eigen::RowMajor> in_b_transposed = in_b_buffer.shuffle(shuffling);\n\n  for (const auto& device_type : device_types) {\n    if (device_type == DeviceType::kCPU && data_type == DataType::kFloat16) {\n      // CPU matmul not support float16\n      continue;\n    }\n    auto device = registry->GetDevice(device_type, 0);\n    ep::test::PinnedMemoryGuard input_a(device.get(), a_size);\n    ep::test::PinnedMemoryGuard input_b(device.get(), b_size);\n    if (transpose_a) {\n      std::memcpy(input_a.ptr(), in_a_transposed.data(), a_size);\n    } else {\n      std::memcpy(input_a.ptr(), in_a_buffer.data(), a_size);\n    }\n    if (transpose_b) {\n      std::memcpy(input_b.ptr(), in_b_transposed.data(), b_size);\n    } else {\n      std::memcpy(input_b.ptr(), in_b_buffer.data(), b_size);\n    }\n    ep::test::PinnedMemoryGuard output(device.get(), c_size);\n    ep::test::DeviceMemoryGuard device_a(device.get(), a_size);\n    ep::test::DeviceMemoryGuard device_b(device.get(), b_size);\n    ep::test::DeviceMemoryGuard device_c(device.get(), c_size);\n    ep::test::StreamGuard stream(device.get());\n    std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    
const auto trans_a = transpose_a ? BlasTransposeType::T : BlasTransposeType::N;\n    const auto trans_b = transpose_b ? BlasTransposeType::T : BlasTransposeType::N;\n    std::unique_ptr<BatchMatmul> batch_matmul =\n        NewPrimitive<BatchMatmulFactory>(device_type, data_type, trans_a, trans_b);\n    ASSERT_TRUE(d2h.operator bool());\n    ASSERT_TRUE(h2d.operator bool());\n    ASSERT_TRUE(batch_matmul.operator bool());\n    h2d->Launch(stream.stream(), device_a.ptr(), input_a.ptr(), a_size);\n    h2d->Launch(stream.stream(), device_b.ptr(), input_b.ptr(), b_size);\n    batch_matmul->Launch(stream.stream(), batch_size, m, n, k, 1.0, device_a.ptr(), device_b.ptr(),\n                         0.0, device_c.ptr());\n    d2h->Launch(stream.stream(), output.ptr(), device_c.ptr(), c_size);\n    CHECK_JUST(stream.stream()->Sync());\n    Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, Eigen::Unaligned> eigen_out(\n        out_c_buffer.data(), out_c_buffer.size());\n    Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, Eigen::Unaligned> of_out(\n        reinterpret_cast<T*>(output.ptr()), out_c_buffer.size());\n    ASSERT_TRUE(eigen_out.template isApprox(of_out, static_cast<T>(0.001)));\n  }\n}\n\ntemplate<DataType data_type, typename T>\nvoid TestBatchMatmul(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n                     int batch_size, int m, int k, int n) {\n  TestBatchMatmul<data_type, T>(registry, device_types, batch_size, m, k, n, false, false);\n  TestBatchMatmul<data_type, T>(registry, device_types, batch_size, m, k, n, true, false);\n  TestBatchMatmul<data_type, T>(registry, device_types, batch_size, m, k, n, false, true);\n  TestBatchMatmul<data_type, T>(registry, device_types, batch_size, m, k, n, true, true);\n}\n\ntemplate<DataType data_type, typename T>\nvoid TestBatchMatmul(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types) {\n  TestBatchMatmul<data_type, T>(registry, device_types, 10, 64, 16, 8);\n  TestBatchMatmul<data_type, T>(registry, device_types, 12, 16, 7, 12);\n}\n\n}  // namespace\n\nTEST_F(PrimitiveTest, TestBatchMatmul) {\n  TestBatchMatmul<DataType::kDouble, double>(&device_manager_registry_, available_device_types_);\n  TestBatchMatmul<DataType::kFloat, float>(&device_manager_registry_, available_device_types_);\n  TestBatchMatmul<DataType::kFloat16, Eigen::half>(&device_manager_registry_,\n                                                   available_device_types_);\n}\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/binary_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h\"\n#include <Eigen/Core>\n#include <unsupported/Eigen/CXX11/Tensor>\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\nnamespace {\n\ntemplate<typename T>\nScalar GetScalar(const T& value) {\n  return Scalar(value);\n}\n\ntemplate<>\nScalar GetScalar<Eigen::half>(const Eigen::half& value) {\n  return Scalar(static_cast<float>(value));\n}\n\ntemplate<BinaryOp binary_op, DataType src_data_type, typename Src, DataType dst_data_type,\n         typename Dst>\nvoid TestElementwiseBroadcastBinary(DeviceManagerRegistry* registry,\n                                    const std::set<DeviceType>& device_types, int test_type) {\n  const int num_axes = 4;\n  const int broadcast_dim0 = 16;\n  const int broadcast_dim1 = 3;\n  const int broadcast_dim2 = 4;\n  const int broadcast_dim3 = 8;\n  bool is_broadcast = false;\n  bool left_scalar = false;\n  bool right_scalar = false;\n  if (test_type == 0) {\n    // do nothing\n  } else if (test_type == 1) {\n    is_broadcast = true;\n  } else if (test_type == 2) {\n    left_scalar = true;\n  } else if (test_type == 3) {\n    right_scalar = true;\n  } else {\n    UNIMPLEMENTED();\n  }\n  const int a_dim0 = left_scalar ? 1 : broadcast_dim0;\n  const int a_dim1 = left_scalar ? 1 : broadcast_dim1;\n  const int a_dim2 = left_scalar ? 1 : broadcast_dim2;\n  const int a_dim3 = left_scalar ? 1 : (is_broadcast ? 1 : broadcast_dim3);\n  const int b_dim0 = right_scalar ? 1 : broadcast_dim0;\n  const int b_dim1 = right_scalar ? 1 : (is_broadcast ? 1 : broadcast_dim1);\n  const int b_dim2 = right_scalar ? 1 : broadcast_dim2;\n  const int b_dim3 = right_scalar ? 1 : broadcast_dim3;\n  const int a_broadcast0 = left_scalar ? broadcast_dim0 : 1;\n  const int a_broadcast1 = left_scalar ? broadcast_dim1 : 1;\n  const int a_broadcast2 = left_scalar ? broadcast_dim2 : 1;\n  const int a_broadcast3 = left_scalar ? broadcast_dim3 : (is_broadcast ? broadcast_dim3 : 1);\n  const int b_broadcast0 = right_scalar ? broadcast_dim0 : 1;\n  const int b_broadcast1 = right_scalar ? broadcast_dim1 : (is_broadcast ? broadcast_dim1 : 1);\n  const int b_broadcast2 = right_scalar ? broadcast_dim2 : 1;\n  const int b_broadcast3 = right_scalar ? 
broadcast_dim3 : 1;\n  const Eigen::array<int, 4> a_broadcast = {a_broadcast0, a_broadcast1, a_broadcast2, a_broadcast3};\n  const Eigen::array<int, 4> b_broadcast = {b_broadcast0, b_broadcast1, b_broadcast2, b_broadcast3};\n  Eigen::Tensor<Src, 4, Eigen::RowMajor> a(a_dim0, a_dim1, a_dim2, a_dim3);\n  Eigen::Tensor<Src, 4, Eigen::RowMajor> b(b_dim0, b_dim1, b_dim2, b_dim3);\n  Eigen::Tensor<Dst, 4, Eigen::RowMajor> c(broadcast_dim0, broadcast_dim1, broadcast_dim2,\n                                           broadcast_dim3);\n  a.setRandom();\n  b.setRandom();\n  if (binary_op == BinaryOp::kAdd) {\n    c = (a.broadcast(a_broadcast) + b.broadcast(b_broadcast)).template cast<Dst>();\n  } else if (binary_op == BinaryOp::kSub) {\n    c = (a.broadcast(a_broadcast) - b.broadcast(b_broadcast)).template cast<Dst>();\n  } else if (binary_op == BinaryOp::kMul) {\n    c = (a.broadcast(a_broadcast) * b.broadcast(b_broadcast)).template cast<Dst>();\n  } else if (binary_op == BinaryOp::kDiv) {\n    Eigen::Tensor<Src, 4, Eigen::RowMajor> constant_value(b_dim0, b_dim1, b_dim2, b_dim3);\n    // avoid division by zero: % 127 + 1 keeps the divisor in [1, 127]\n    if (src_data_type == kInt8 || src_data_type == kUInt8) {\n      int rand_value = std::rand() % 127 + 1;\n      constant_value.setConstant(static_cast<Src>(rand_value));\n      b = constant_value;\n    } else {\n      constant_value.setConstant(static_cast<Src>(1));\n      b += constant_value;\n    }\n    c = (a.broadcast(a_broadcast) / b.broadcast(b_broadcast)).template cast<Dst>();\n  } else if (binary_op == BinaryOp::kMax) {\n    c = (a.broadcast(a_broadcast).cwiseMax(b.broadcast(b_broadcast))).template cast<Dst>();\n  } else if (binary_op == BinaryOp::kMin) {\n    c = (a.broadcast(a_broadcast).cwiseMin(b.broadcast(b_broadcast))).template cast<Dst>();\n  } else if (binary_op == BinaryOp::kEqual) {\n    c = (a.broadcast(a_broadcast) == b.broadcast(b_broadcast)).template cast<Dst>();\n  } else if (binary_op == BinaryOp::kNotEqual) {\n    c = (a.broadcast(a_broadcast) != b.broadcast(b_broadcast)).template cast<Dst>();\n  } else if (binary_op == BinaryOp::kLessThan) {\n    c = (a.broadcast(a_broadcast) < b.broadcast(b_broadcast)).template cast<Dst>();\n  } else if (binary_op == BinaryOp::kLessEqual) {\n    c = (a.broadcast(a_broadcast) <= b.broadcast(b_broadcast)).template cast<Dst>();\n  } else if (binary_op == BinaryOp::kGreaterThan) {\n    c = (a.broadcast(a_broadcast) > b.broadcast(b_broadcast)).template cast<Dst>();\n  } else if (binary_op == BinaryOp::kGreaterEqual) {\n    c = (a.broadcast(a_broadcast) >= b.broadcast(b_broadcast)).template cast<Dst>();\n  } else if (binary_op == BinaryOp::kLogicalAnd) {\n    c = (a.broadcast(a_broadcast).template cast<bool>()\n         && b.broadcast(b_broadcast).template cast<bool>())\n            .template cast<Dst>();\n  } else if (binary_op == BinaryOp::kLogicalOr) {\n    c = (a.broadcast(a_broadcast).template cast<bool>()\n         || b.broadcast(b_broadcast).template cast<bool>())\n            .template cast<Dst>();\n  } else if (binary_op == BinaryOp::kLogicalXor) {\n    c = (a.broadcast(a_broadcast).template cast<bool>()\n         ^ b.broadcast(b_broadcast).template cast<bool>())\n            .template cast<Dst>();\n  } else {\n    UNIMPLEMENTED();\n  }\n  std::vector<int64_t> a_dims = {a.dimension(0), a.dimension(1), a.dimension(2), a.dimension(3)};\n  std::vector<int64_t> b_dims = {b.dimension(0), b.dimension(1), b.dimension(2), b.dimension(3)};\n  std::vector<int64_t> c_dims = {c.dimension(0), c.dimension(1), c.dimension(2), c.dimension(3)};\n  int64_t 
a_size = a.size() * sizeof(Src);\n  int64_t b_size = b.size() * sizeof(Src);\n  int64_t c_size = c.size() * sizeof(Dst);\n\n  for (const auto& device_type : device_types) {\n    auto device = registry->GetDevice(device_type, 0);\n    ep::test::PinnedMemoryGuard input_a(device.get(), a_size);\n    ep::test::PinnedMemoryGuard input_b(device.get(), b_size);\n    std::memcpy(input_a.ptr(), a.data(), a_size);\n    std::memcpy(input_b.ptr(), b.data(), b_size);\n\n    ep::test::PinnedMemoryGuard output(device.get(), c_size);\n    ep::test::DeviceMemoryGuard device_a(device.get(), a_size);\n    ep::test::DeviceMemoryGuard device_b(device.get(), b_size);\n    ep::test::DeviceMemoryGuard device_c(device.get(), c_size);\n    ep::test::StreamGuard stream(device.get());\n    std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    std::unique_ptr<BroadcastElementwiseBinary> binary =\n        NewPrimitive<BroadcastElementwiseBinaryFactory>(device_type, binary_op, src_data_type,\n                                                        dst_data_type, num_axes);\n    ASSERT_TRUE(d2h.operator bool());\n    ASSERT_TRUE(h2d.operator bool());\n    ASSERT_TRUE(binary.operator bool());\n    h2d->Launch(stream.stream(), device_a.ptr(), input_a.ptr(), a_size);\n    h2d->Launch(stream.stream(), device_b.ptr(), input_b.ptr(), b_size);\n    if (left_scalar) {\n      Src a_value = *reinterpret_cast<Src*>(input_a.ptr());\n      binary->Launch(stream.stream(), GetScalar(a_value), num_axes, b_dims.data(), device_b.ptr(),\n                     device_c.ptr());\n    } else if (right_scalar) {\n      Src b_value = *reinterpret_cast<Src*>(input_b.ptr());\n      binary->Launch(stream.stream(), num_axes, a_dims.data(), device_a.ptr(), GetScalar(b_value),\n                     device_c.ptr());\n    } else {\n      binary->Launch(stream.stream(), num_axes, a_dims.data(), device_a.ptr(), num_axes,\n                     b_dims.data(), device_b.ptr(), device_c.ptr());\n    }\n    d2h->Launch(stream.stream(), output.ptr(), device_c.ptr(), c_size);\n    CHECK_JUST(stream.stream()->Sync());\n\n    Eigen::Map<Eigen::Matrix<Dst, 1, Eigen::Dynamic>, Eigen::Unaligned> eigen_out(c.data(),\n                                                                                  c.size());\n    Eigen::Map<Eigen::Matrix<Dst, 1, Eigen::Dynamic>, Eigen::Unaligned> of_out(\n        reinterpret_cast<Dst*>(output.ptr()), c.size());\n    ASSERT_TRUE(eigen_out.template isApprox(of_out));\n  }\n}\n\ntemplate<BinaryOp binary_op, DataType src_data_type, typename Src, DataType dst_data_type,\n         typename Dst>\nvoid TestElementwiseBroadcastBinary(DeviceManagerRegistry* registry,\n                                    const std::set<DeviceType>& device_types) {\n  TestElementwiseBroadcastBinary<binary_op, src_data_type, Src, dst_data_type, Dst>(\n      registry, device_types, 0);\n  TestElementwiseBroadcastBinary<binary_op, src_data_type, Src, dst_data_type, Dst>(\n      registry, device_types, 1);\n  TestElementwiseBroadcastBinary<binary_op, src_data_type, Src, dst_data_type, Dst>(\n      registry, device_types, 2);\n  TestElementwiseBroadcastBinary<binary_op, src_data_type, Src, dst_data_type, Dst>(\n      registry, device_types, 3);\n}\n\ntemplate<BinaryOp binary_op>\nvoid TestComputeBinary(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types) {\n  TestElementwiseBroadcastBinary<binary_op, 
DataType::kInt8, int8_t, DataType::kInt8, int8_t>(\n      registry, device_types);\n  TestElementwiseBroadcastBinary<binary_op, DataType::kUInt8, uint8_t, DataType::kUInt8, uint8_t>(\n      registry, device_types);\n  TestElementwiseBroadcastBinary<binary_op, DataType::kInt32, int32_t, DataType::kInt32, int32_t>(\n      registry, device_types);\n  TestElementwiseBroadcastBinary<binary_op, DataType::kInt64, int64_t, DataType::kInt64, int64_t>(\n      registry, device_types);\n  TestElementwiseBroadcastBinary<binary_op, DataType::kDouble, double, DataType::kDouble, double>(\n      registry, device_types);\n  TestElementwiseBroadcastBinary<binary_op, DataType::kFloat, float, DataType::kFloat, float>(\n      registry, device_types);\n  TestElementwiseBroadcastBinary<binary_op, DataType::kFloat16, Eigen::half, DataType::kFloat16,\n                                 Eigen::half>(registry, device_types);\n}\n\ntemplate<BinaryOp binary_op>\nvoid TestLogicalBinary(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types) {\n  TestElementwiseBroadcastBinary<binary_op, DataType::kInt8, int8_t, DataType::kBool, bool>(\n      registry, device_types);\n  TestElementwiseBroadcastBinary<binary_op, DataType::kUInt8, uint8_t, DataType::kBool, bool>(\n      registry, device_types);\n  TestElementwiseBroadcastBinary<binary_op, DataType::kInt32, int32_t, DataType::kBool, bool>(\n      registry, device_types);\n  TestElementwiseBroadcastBinary<binary_op, DataType::kInt64, int64_t, DataType::kBool, bool>(\n      registry, device_types);\n  TestElementwiseBroadcastBinary<binary_op, DataType::kDouble, double, DataType::kBool, bool>(\n      registry, device_types);\n  TestElementwiseBroadcastBinary<binary_op, DataType::kFloat, float, DataType::kBool, bool>(\n      registry, device_types);\n  TestElementwiseBroadcastBinary<binary_op, DataType::kFloat16, Eigen::half, DataType::kBool, bool>(\n      registry, device_types);\n}\n\n}  // namespace\n\nTEST_F(PrimitiveTest, TestBinary) {\n  TestComputeBinary<BinaryOp::kAdd>(&device_manager_registry_, available_device_types_);\n  TestComputeBinary<BinaryOp::kSub>(&device_manager_registry_, available_device_types_);\n  TestComputeBinary<BinaryOp::kMul>(&device_manager_registry_, available_device_types_);\n  TestComputeBinary<BinaryOp::kDiv>(&device_manager_registry_, available_device_types_);\n  TestComputeBinary<BinaryOp::kMax>(&device_manager_registry_, available_device_types_);\n  TestComputeBinary<BinaryOp::kMin>(&device_manager_registry_, available_device_types_);\n  TestLogicalBinary<BinaryOp::kEqual>(&device_manager_registry_, available_device_types_);\n  TestLogicalBinary<BinaryOp::kNotEqual>(&device_manager_registry_, available_device_types_);\n  TestLogicalBinary<BinaryOp::kLessThan>(&device_manager_registry_, available_device_types_);\n  TestLogicalBinary<BinaryOp::kLessEqual>(&device_manager_registry_, available_device_types_);\n  TestLogicalBinary<BinaryOp::kGreaterThan>(&device_manager_registry_, available_device_types_);\n  TestLogicalBinary<BinaryOp::kGreaterEqual>(&device_manager_registry_, available_device_types_);\n  TestLogicalBinary<BinaryOp::kLogicalAnd>(&device_manager_registry_, available_device_types_);\n  TestLogicalBinary<BinaryOp::kLogicalOr>(&device_manager_registry_, available_device_types_);\n  TestLogicalBinary<BinaryOp::kLogicalXor>(&device_manager_registry_, available_device_types_);\n}\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/broadcast_matmul_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_matmul.h\"\n#include <unsupported/Eigen/CXX11/Tensor>\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\nnamespace {\n\ntemplate<DataType data_type, typename T>\nvoid TestBroadcastMatmul(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n                         int batch_size, int m, int k, int n, bool transpose_a, bool transpose_b,\n                         bool broadcast_a, bool broadcast_b, bool reduce_c) {\n  using Matrix = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;\n  CHECK((!broadcast_a) || (!broadcast_b));\n  int a_batch_dims = broadcast_a ? 1 : batch_size;\n  int b_batch_dims = broadcast_b ? 1 : batch_size;\n  int c_batch_dims = reduce_c ? 1 : batch_size;\n  Eigen::Tensor<T, 3, Eigen::RowMajor> in_a_buffer(a_batch_dims, m, k);\n  Eigen::Tensor<T, 3, Eigen::RowMajor> in_b_buffer(b_batch_dims, k, n);\n  Eigen::Tensor<T, 3, Eigen::RowMajor> out_c_buffer(c_batch_dims, m, n);\n  Eigen::Tensor<T, 3, Eigen::RowMajor> broadcast_c_buffer(batch_size, m, n);\n  in_a_buffer.setRandom();\n  in_b_buffer.setRandom();\n  for (int i = 0; i < batch_size; ++i) {\n    int64_t a_offset = broadcast_a ? 0 : i * m * k;\n    int64_t b_offset = broadcast_b ? 0 : i * k * n;\n    Eigen::Map<Matrix, Eigen::Unaligned> a(in_a_buffer.data() + a_offset, m, k);\n    Eigen::Map<Matrix, Eigen::Unaligned> b(in_b_buffer.data() + b_offset, k, n);\n    Eigen::Map<Matrix, Eigen::Unaligned> c(broadcast_c_buffer.data() + i * m * n, m, n);\n    c = a * b;\n  }\n  if (reduce_c) {\n    Eigen::array<int, 1> reduce_dim = {0};\n    out_c_buffer = broadcast_c_buffer.sum(reduce_dim).eval().reshape(out_c_buffer.dimensions());\n  } else {\n    out_c_buffer = broadcast_c_buffer;\n  }\n  int64_t a_size = a_batch_dims * m * k * sizeof(T);\n  int64_t b_size = b_batch_dims * k * n * sizeof(T);\n  int64_t c_size = c_batch_dims * m * n * sizeof(T);\n  Eigen::array<int, 3> shuffling({0, 2, 1});\n  Eigen::Tensor<T, 3, Eigen::RowMajor> in_a_transposed = in_a_buffer.shuffle(shuffling);\n  Eigen::Tensor<T, 3, Eigen::RowMajor> in_b_transposed = in_b_buffer.shuffle(shuffling);\n\n  size_t num_a_dims = broadcast_a ? 2 : 3;\n  std::vector<int64_t> a_dims;\n  if (!broadcast_a) { a_dims.push_back(batch_size); }\n  if (transpose_a) {\n    a_dims.push_back(k);\n    a_dims.push_back(m);\n  } else {\n    a_dims.push_back(m);\n    a_dims.push_back(k);\n  }\n  size_t num_b_dims = broadcast_b ? 
2 : 3;\n  std::vector<int64_t> b_dims;\n  if (!broadcast_b) { b_dims.push_back(batch_size); }\n  if (transpose_b) {\n    b_dims.push_back(n);\n    b_dims.push_back(k);\n  } else {\n    b_dims.push_back(k);\n    b_dims.push_back(n);\n  }\n  size_t num_c_dims = reduce_c ? 2 : 3;\n  std::vector<int64_t> c_dims;\n  if (!reduce_c) { c_dims.push_back(batch_size); }\n  c_dims.push_back(m);\n  c_dims.push_back(n);\n\n  for (const auto& device_type : device_types) {\n    if (device_type == DeviceType::kCPU && data_type == DataType::kFloat16) {\n      // CPU matmul does not support float16\n      continue;\n    }\n    auto device = registry->GetDevice(device_type, 0);\n    ep::test::PinnedMemoryGuard input_a(device.get(), a_size);\n    ep::test::PinnedMemoryGuard input_b(device.get(), b_size);\n    if (transpose_a) {\n      std::memcpy(input_a.ptr(), in_a_transposed.data(), a_size);\n    } else {\n      std::memcpy(input_a.ptr(), in_a_buffer.data(), a_size);\n    }\n    if (transpose_b) {\n      std::memcpy(input_b.ptr(), in_b_transposed.data(), b_size);\n    } else {\n      std::memcpy(input_b.ptr(), in_b_buffer.data(), b_size);\n    }\n    ep::test::PinnedMemoryGuard output(device.get(), c_size);\n    ep::test::DeviceMemoryGuard device_a(device.get(), a_size);\n    ep::test::DeviceMemoryGuard device_b(device.get(), b_size);\n    ep::test::DeviceMemoryGuard device_c(device.get(), c_size);\n    ep::test::StreamGuard stream(device.get());\n    std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    const auto trans_a = transpose_a ? BlasTransposeType::T : BlasTransposeType::N;\n    const auto trans_b = transpose_b ? BlasTransposeType::T : BlasTransposeType::N;\n    std::unique_ptr<BroadcastMatmul> broadcast_matmul =\n        NewPrimitive<BroadcastMatmulFactory>(device_type, data_type, trans_a, trans_b, 3);\n    ASSERT_TRUE(d2h.operator bool());\n    ASSERT_TRUE(h2d.operator bool());\n    ASSERT_TRUE(broadcast_matmul.operator bool());\n    h2d->Launch(stream.stream(), device_a.ptr(), input_a.ptr(), a_size);\n    h2d->Launch(stream.stream(), device_b.ptr(), input_b.ptr(), b_size);\n    broadcast_matmul->Launch(stream.stream(), 1.0, num_a_dims, a_dims.data(), device_a.ptr(),\n                             num_b_dims, b_dims.data(), device_b.ptr(), 0.0, num_c_dims,\n                             c_dims.data(), device_c.ptr());\n    d2h->Launch(stream.stream(), output.ptr(), device_c.ptr(), c_size);\n    CHECK_JUST(stream.stream()->Sync());\n    Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, Eigen::Unaligned> eigen_out(\n        out_c_buffer.data(), out_c_buffer.size());\n    Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, Eigen::Unaligned> of_out(\n        reinterpret_cast<T*>(output.ptr()), out_c_buffer.size());\n    ASSERT_TRUE(eigen_out.template isApprox(of_out, static_cast<T>(0.001)));\n  }\n}\n\ntemplate<DataType data_type, typename T>\nvoid TestBroadcastMatmul(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n                         int m, int k, int n, bool transpose_a, bool transpose_b) {\n  TestBroadcastMatmul<data_type, T>(registry, device_types, 10, m, k, n, transpose_a, transpose_b,\n                                    false, false, true);\n  TestBroadcastMatmul<data_type, T>(registry, device_types, 10, m, k, n, transpose_a, transpose_b,\n                                    false, false, false);\n  TestBroadcastMatmul<data_type, 
T>(registry, device_types, 10, m, k, n, transpose_a, transpose_b,\n                                    false, true, true);\n  TestBroadcastMatmul<data_type, T>(registry, device_types, 10, m, k, n, transpose_a, transpose_b,\n                                    false, true, false);\n  TestBroadcastMatmul<data_type, T>(registry, device_types, 12, m, k, n, transpose_a, transpose_b,\n                                    true, false, true);\n  TestBroadcastMatmul<data_type, T>(registry, device_types, 12, m, k, n, transpose_a, transpose_b,\n                                    true, false, false);\n}\n\ntemplate<DataType data_type, typename T>\nvoid TestBroadcastMatmul(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n                         int m, int k, int n) {\n  TestBroadcastMatmul<data_type, T>(registry, device_types, m, k, n, false, false);\n  TestBroadcastMatmul<data_type, T>(registry, device_types, m, k, n, true, false);\n  TestBroadcastMatmul<data_type, T>(registry, device_types, m, k, n, false, true);\n  TestBroadcastMatmul<data_type, T>(registry, device_types, m, k, n, true, true);\n}\n\ntemplate<DataType data_type, typename T>\nvoid TestBroadcastMatmul(DeviceManagerRegistry* registry,\n                         const std::set<DeviceType>& device_types) {\n  TestBroadcastMatmul<data_type, T>(registry, device_types, 64, 16, 8);\n  TestBroadcastMatmul<data_type, T>(registry, device_types, 16, 7, 12);\n}\n\n}  // namespace\n\nTEST_F(PrimitiveTest, TestBroadcastMatmul) {\n  TestBroadcastMatmul<DataType::kDouble, double>(&device_manager_registry_,\n                                                 available_device_types_);\n  TestBroadcastMatmul<DataType::kFloat, float>(&device_manager_registry_, available_device_types_);\n  TestBroadcastMatmul<DataType::kFloat16, Eigen::half>(&device_manager_registry_,\n                                                       available_device_types_);\n}\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/cast_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#include <unsupported/Eigen/CXX11/Tensor>\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\nnamespace {\n\ntemplate<DataType src_data_type, typename Src, DataType dst_data_type, typename Dst>\nvoid TestCast(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n              int elem_cnt) {\n  if (src_data_type == dst_data_type) { return; }\n  if (dst_data_type == kFloat16 && src_data_type != kFloat) { return; }\n  const int src_data_size = elem_cnt * sizeof(Src);\n  const int dst_data_size = elem_cnt * sizeof(Dst);\n  Eigen::Tensor<Src, 1, Eigen::RowMajor> cast_in(elem_cnt);\n  Eigen::Tensor<Dst, 1, Eigen::RowMajor> cast_out(elem_cnt);\n  cast_in.setRandom();\n  cast_out = cast_in.template cast<Dst>();\n\n  for (const auto& device_type : device_types) {\n    auto device = registry->GetDevice(device_type, 0);\n    ep::test::PinnedMemoryGuard input(device.get(), src_data_size);\n    ep::test::PinnedMemoryGuard output(device.get(), dst_data_size);\n    std::memcpy(input.ptr(), cast_in.data(), src_data_size);\n    ep::test::DeviceMemoryGuard device_in(device.get(), src_data_size);\n    ep::test::DeviceMemoryGuard device_out(device.get(), dst_data_size);\n    ep::test::StreamGuard stream(device.get());\n    std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n    ASSERT_TRUE(h2d.operator bool());\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    ASSERT_TRUE(d2h.operator bool());\n    h2d->Launch(stream.stream(), device_in.ptr(), input.ptr(), src_data_size);\n    std::unique_ptr<Cast> cast =\n        NewPrimitive<CastFactory>(device_type, src_data_type, dst_data_type);\n    ASSERT_TRUE(cast.operator bool());\n    cast->Launch(stream.stream(), device_in.ptr(), device_out.ptr(), elem_cnt);\n    d2h->Launch(stream.stream(), output.ptr(), device_out.ptr(), dst_data_size);\n    CHECK_JUST(stream.stream()->Sync());\n    Eigen::Map<Eigen::Matrix<Dst, 1, Eigen::Dynamic>, Eigen::Unaligned> eigen_out(cast_out.data(),\n                                                                                  cast_out.size());\n    Eigen::Map<Eigen::Matrix<Dst, 1, Eigen::Dynamic>, Eigen::Unaligned> of_out(\n        reinterpret_cast<Dst*>(output.ptr()), cast_out.size());\n    ASSERT_TRUE(eigen_out.template isApprox(of_out));\n  }\n}\n\ntemplate<DataType src_data_type, typename Src>\nvoid TestCast(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n              int elem_cnt) {\n  TestCast<src_data_type, Src, DataType::kBool, bool>(registry, device_types, elem_cnt);\n  
TestCast<src_data_type, Src, DataType::kInt8, int8_t>(registry, device_types, elem_cnt);\n  TestCast<src_data_type, Src, DataType::kUInt8, uint8_t>(registry, device_types, elem_cnt);\n  TestCast<src_data_type, Src, DataType::kInt32, int32_t>(registry, device_types, elem_cnt);\n  TestCast<src_data_type, Src, DataType::kInt64, int64_t>(registry, device_types, elem_cnt);\n  TestCast<src_data_type, Src, DataType::kFloat, float>(registry, device_types, elem_cnt);\n  TestCast<src_data_type, Src, DataType::kDouble, double>(registry, device_types, elem_cnt);\n  TestCast<src_data_type, Src, DataType::kFloat16, Eigen::half>(registry, device_types, elem_cnt);\n}\n\nvoid TestCast(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n              int elem_cnt) {\n  TestCast<DataType::kBool, bool>(registry, device_types, elem_cnt);\n  TestCast<DataType::kInt8, int8_t>(registry, device_types, elem_cnt);\n  TestCast<DataType::kUInt8, uint8_t>(registry, device_types, elem_cnt);\n  TestCast<DataType::kInt32, int32_t>(registry, device_types, elem_cnt);\n  TestCast<DataType::kInt64, int64_t>(registry, device_types, elem_cnt);\n  TestCast<DataType::kFloat, float>(registry, device_types, elem_cnt);\n  TestCast<DataType::kDouble, double>(registry, device_types, elem_cnt);\n  TestCast<DataType::kFloat16, Eigen::half>(registry, device_types, elem_cnt);\n}\n\n}  // namespace\n\nTEST_F(PrimitiveTest, TestCast) {\n  std::vector<int> elem_cnts = {1024, 3193, 5765};\n  for (int i = 0; i < elem_cnts.size(); ++i) {\n    TestCast(&device_manager_registry_, available_device_types_, elem_cnts.at(i));\n  }\n}\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/constant_pad_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/constant_pad.h\"\n#include <Eigen/Core>\n#include <unsupported/Eigen/CXX11/Tensor>\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\ntemplate<typename T, DataType dtype>\nvoid TestConstantPad2d(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n                       const int dims[2], const std::vector<int64_t> padding_before,\n                       const std::vector<int64_t> padding_after) {\n  using EigenVec = Eigen::Matrix<T, 1, Eigen::Dynamic>;\n  int in_elem_cnt = 1;\n  int out_elem_cnt = 1;\n  for (int i = 0; i < 2; i++) {\n    in_elem_cnt *= dims[i];\n    out_elem_cnt *= (dims[i] + padding_before[i] + padding_after[i]);\n  }\n  const int in_matrix_size = in_elem_cnt * sizeof(T);\n  const int out_matrix_size = out_elem_cnt * sizeof(T);\n\n  for (const auto& device_type : device_types) {\n    Eigen::Tensor<T, 2, Eigen::RowMajor> mat(dims[0], dims[1]);\n\n    mat.setRandom();\n    auto device = registry->GetDevice(device_type, 0);\n\n    ep::test::PinnedMemoryGuard host_src(device.get(), in_matrix_size);\n    ep::test::PinnedMemoryGuard host_dst(device.get(), out_matrix_size);\n    ep::test::DeviceMemoryGuard device_src(device.get(), in_matrix_size);\n    ep::test::DeviceMemoryGuard device_dst(device.get(), out_matrix_size);\n\n    ep::test::StreamGuard stream(device.get());\n    std::unique_ptr<ConstantPad> constant_pad =\n        NewPrimitive<ConstantPadFactory>(device_type, dtype);\n    ASSERT_TRUE(constant_pad.operator bool());\n    std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    ASSERT_TRUE(d2h.operator bool());\n    ASSERT_TRUE(h2d.operator bool());\n    T* mat_data = mat.data();\n\n    std::memcpy(host_src.ptr(), mat_data, in_matrix_size);\n    h2d->Launch(stream.stream(), device_src.ptr<T>(), host_src.ptr<T>(), in_matrix_size);\n    const int64_t src_dims[2] = {dims[0], dims[1]};\n    constant_pad->Launch(stream.stream(), /*num_dims=*/2, src_dims, device_src.ptr<T>(),\n                         padding_before.data(), padding_after.data(), Scalar(0),\n                         device_dst.ptr<T>());\n    d2h->Launch(stream.stream(), host_dst.ptr<T>(), device_dst.ptr<T>(), out_matrix_size);\n    CHECK_JUST(stream.stream()->Sync());\n\n    Eigen::array<std::pair<int, int>, 2> paddings;\n    for (int i = 0; i < 2; i++) {\n      paddings[i] = std::make_pair(padding_before[i], padding_after[i]);\n    }\n\n    Eigen::Tensor<T, 2, Eigen::RowMajor> mat_padded = mat.pad(paddings);\n    auto eigen_padded_res = Eigen::Map<EigenVec, Eigen::Unaligned>(\n        
reinterpret_cast<T*>(mat_padded.data()), out_elem_cnt);\n    auto constant_pad_primitive_res =\n        Eigen::Map<EigenVec, Eigen::Unaligned>(host_dst.ptr<T>(), out_elem_cnt);\n    ASSERT_TRUE(eigen_padded_res.template isApprox(constant_pad_primitive_res));\n  }\n}\n\ntemplate<typename T, DataType dtype>\nvoid TestConstantPadNegative2d(DeviceManagerRegistry* registry,\n                               const std::set<DeviceType>& device_types, const int dims[2],\n                               const std::vector<int64_t> padding_before,\n                               const std::vector<int64_t> padding_after) {\n  using EigenVec = Eigen::Matrix<T, 1, Eigen::Dynamic>;\n  int in_elem_cnt = 1;\n  int out_elem_cnt = 1;\n  int offsets[2];\n  int extents[2];\n\n  for (int i = 0; i < 2; i++) {\n    in_elem_cnt *= dims[i];\n    out_elem_cnt *= (dims[i] + padding_before[i] + padding_after[i]);\n    offsets[i] = -padding_before[i];\n    extents[i] = dims[i] + padding_before[i] + padding_after[i];\n  }\n  const int in_matrix_size = in_elem_cnt * sizeof(T);\n  const int out_matrix_size = out_elem_cnt * sizeof(T);\n\n  for (const auto& device_type : device_types) {\n    Eigen::Tensor<T, 2, Eigen::RowMajor> mat(dims[0], dims[1]);\n\n    mat.setRandom();\n    auto device = registry->GetDevice(device_type, 0);\n\n    ep::test::PinnedMemoryGuard host_src(device.get(), in_matrix_size);\n    ep::test::PinnedMemoryGuard host_dst(device.get(), out_matrix_size);\n    ep::test::DeviceMemoryGuard device_src(device.get(), in_matrix_size);\n    ep::test::DeviceMemoryGuard device_dst(device.get(), out_matrix_size);\n\n    ep::test::StreamGuard stream(device.get());\n    std::unique_ptr<ConstantPad> constant_pad =\n        NewPrimitive<ConstantPadFactory>(device_type, dtype);\n    ASSERT_TRUE(constant_pad.operator bool());\n    std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    ASSERT_TRUE(d2h.operator bool());\n    ASSERT_TRUE(h2d.operator bool());\n    T* mat_data = mat.data();\n\n    std::memcpy(host_src.ptr(), mat_data, in_matrix_size);\n    h2d->Launch(stream.stream(), device_src.ptr<T>(), host_src.ptr<T>(), in_matrix_size);\n    const int64_t src_dims[2] = {dims[0], dims[1]};\n    constant_pad->Launch(stream.stream(), /*num_dims=*/2, src_dims, device_src.ptr<T>(),\n                         padding_before.data(), padding_after.data(), Scalar(0),\n                         device_dst.ptr<T>());\n    d2h->Launch(stream.stream(), host_dst.ptr<T>(), device_dst.ptr<T>(), out_matrix_size);\n    CHECK_JUST(stream.stream()->Sync());\n\n    Eigen::array<Eigen::Index, 2> slice_offsets = {offsets[0], offsets[1]};\n    Eigen::array<Eigen::Index, 2> slice_extents = {extents[0], extents[1]};\n    Eigen::Tensor<T, 2, Eigen::RowMajor> mat_padded = mat.slice(slice_offsets, slice_extents);\n    auto eigen_padded_res = Eigen::Map<EigenVec, Eigen::Unaligned>(\n        reinterpret_cast<T*>(mat_padded.data()), out_elem_cnt);\n    auto constant_pad_primitive_res =\n        Eigen::Map<EigenVec, Eigen::Unaligned>(host_dst.ptr<T>(), out_elem_cnt);\n    ASSERT_TRUE(eigen_padded_res.template isApprox(constant_pad_primitive_res));\n  }\n}\n\nTEST_F(PrimitiveTest, TestConstantPadPrimitive2d) {\n  const int32_t dims1[2] = {4, 4};\n  const int32_t dims2[2] = {10, 3};\n  const int32_t dims3[2] = {31, 4};\n  const int32_t dims4[2] = {6, 8};\n  const int32_t dims5[2] = {4, 11};\n\n  const std::vector<int64_t> 
padding_before1 = {1, 1};\n  const std::vector<int64_t> padding_after1 = {1, 1};\n  const std::vector<int64_t> padding_before2 = {1, 2};\n  const std::vector<int64_t> padding_after2 = {2, 1};\n  const std::vector<int64_t> padding_before3 = {2, 1};\n  const std::vector<int64_t> padding_after3 = {1, 2};\n  const std::vector<int64_t> padding_before4 = {3, 1};\n  const std::vector<int64_t> padding_after4 = {1, 3};\n  const std::vector<int64_t> padding_before5 = {1, 3};\n  const std::vector<int64_t> padding_after5 = {3, 1};\n\n  TestConstantPad2d<float, DataType::kFloat>(&device_manager_registry_, available_device_types_,\n                                             dims1, padding_before1, padding_after1);\n  TestConstantPad2d<double, DataType::kDouble>(&device_manager_registry_, available_device_types_,\n                                               dims2, padding_before2, padding_after2);\n  TestConstantPad2d<int32_t, DataType::kInt32>(&device_manager_registry_, available_device_types_,\n                                               dims3, padding_before3, padding_after3);\n  TestConstantPad2d<int64_t, DataType::kInt64>(&device_manager_registry_, available_device_types_,\n                                               dims4, padding_before4, padding_after4);\n  TestConstantPad2d<Eigen::half, DataType::kFloat16>(\n      &device_manager_registry_, available_device_types_, dims5, padding_before5, padding_after5);\n}\n\nTEST_F(PrimitiveTest, TestConstantPadPrimitiveNegative2d) {\n  const int32_t dims1[2] = {7, 9};\n  const int32_t dims2[2] = {10, 7};\n  const int32_t dims3[2] = {12, 11};\n  const int32_t dims4[2] = {6, 8};\n  const int32_t dims5[2] = {4, 11};\n\n  const std::vector<int64_t> padding_before1 = {-1, -1};\n  const std::vector<int64_t> padding_after1 = {-1, -1};\n  const std::vector<int64_t> padding_before2 = {-2, 0};\n  const std::vector<int64_t> padding_after2 = {0, -1};\n  const std::vector<int64_t> padding_before3 = {-2, -1};\n  const std::vector<int64_t> padding_after3 = {-1, -2};\n  const std::vector<int64_t> padding_before4 = {-1, 0};\n  const std::vector<int64_t> padding_after4 = {0, -1};\n  const std::vector<int64_t> padding_before5 = {-1, -3};\n  const std::vector<int64_t> padding_after5 = {0, -1};\n\n  TestConstantPadNegative2d<float, DataType::kFloat>(\n      &device_manager_registry_, available_device_types_, dims1, padding_before1, padding_after1);\n  TestConstantPadNegative2d<double, DataType::kDouble>(\n      &device_manager_registry_, available_device_types_, dims2, padding_before2, padding_after2);\n  TestConstantPadNegative2d<int32_t, DataType::kInt32>(\n      &device_manager_registry_, available_device_types_, dims3, padding_before3, padding_after3);\n  TestConstantPadNegative2d<int64_t, DataType::kInt64>(\n      &device_manager_registry_, available_device_types_, dims4, padding_before4, padding_after4);\n  TestConstantPadNegative2d<Eigen::half, DataType::kFloat16>(\n      &device_manager_registry_, available_device_types_, dims5, padding_before5, padding_after5);\n}\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/copy_nd_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/copy_nd.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\nnamespace {\n\ntemplate<DataType data_type, typename T>\nvoid TestCopyNd(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n                int64_t num_dims) {\n  std::vector<int64_t> src_dims(num_dims, 0);\n  std::vector<int64_t> src_pos(num_dims, 0);\n  std::vector<int64_t> dst_pos(num_dims, 0);\n  std::vector<int64_t> dst_dims(num_dims, 0);\n  std::vector<int64_t> extent(num_dims, 0);\n  int64_t src_elem = 1;\n  int64_t dst_elem = 1;\n  for (int i = 0; i < num_dims; ++i) {\n    int64_t rand_dim = 8 + std::rand() % 32;\n    int64_t rand_pos = std::rand() % 16;\n    src_dims.at(i) = rand_dim;\n    dst_pos.at(i) = rand_pos;\n    dst_dims.at(i) = rand_pos + rand_dim;\n    extent.at(i) = rand_dim;\n    src_elem *= src_dims.at(i);\n    dst_elem *= dst_dims.at(i);\n  }\n  int64_t src_size = src_elem * sizeof(T);\n  int64_t dst_size = dst_elem * sizeof(T);\n\n  for (const auto& device_type : device_types) {\n    auto device = registry->GetDevice(device_type, 0);\n    ep::test::PinnedMemoryGuard input(device.get(), src_size);\n    ep::test::PinnedMemoryGuard output(device.get(), src_size);\n    ep::test::DeviceMemoryGuard device0(device.get(), src_size);\n    ep::test::DeviceMemoryGuard device1(device.get(), dst_size);\n    for (size_t i = 0; i < src_elem; ++i) { *(input.ptr<T>() + i) = static_cast<T>(i); }\n    ep::test::StreamGuard stream(device.get());\n    std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n    ASSERT_TRUE(h2d.operator bool());\n    std::unique_ptr<CopyNd> copy_nd = NewPrimitive<CopyNdFactory>(device_type, num_dims);\n    ASSERT_TRUE(copy_nd.operator bool());\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    ASSERT_TRUE(d2h.operator bool());\n    std::unique_ptr<Memset> memset = NewPrimitive<MemsetFactory>(device_type);\n    ASSERT_TRUE(memset.operator bool());\n    h2d->Launch(stream.stream(), device0.ptr(), input.ptr(), src_size);\n    // contiguous device0 to noncontiguous device1\n    copy_nd->Launch(stream.stream(), data_type, num_dims, device1.ptr(), dst_dims.data(),\n                    dst_pos.data(), device0.ptr(), src_dims.data(), src_pos.data(), extent.data());\n    // memset device0\n    memset->Launch(stream.stream(), device0.ptr(), 0x55, src_size);\n    // noncontiguous device1 to contiguous device0\n    copy_nd->Launch(stream.stream(), data_type, num_dims, device0.ptr(), src_dims.data(),\n                    src_pos.data(), device1.ptr(), dst_dims.data(), dst_pos.data(), extent.data());\n    
d2h->Launch(stream.stream(), output.ptr(), device0.ptr(), src_size);\n    CHECK_JUST(stream.stream()->Sync());\n    for (size_t i = 0; i < src_elem; ++i) {\n      ASSERT_EQ(*(input.ptr<T>() + i), *(output.ptr<T>() + i));\n    }\n  }\n}\n\n}  // namespace\n\nTEST_F(PrimitiveTest, TestCopyNd) {\n  for (int i = 1; i < 6; ++i) {\n    TestCopyNd<DataType::kDouble, double>(&device_manager_registry_, available_device_types_, i);\n    TestCopyNd<DataType::kFloat, float>(&device_manager_registry_, available_device_types_, i);\n    TestCopyNd<DataType::kInt8, int8_t>(&device_manager_registry_, available_device_types_, i);\n    TestCopyNd<DataType::kInt32, int32_t>(&device_manager_registry_, available_device_types_, i);\n    TestCopyNd<DataType::kInt64, int64_t>(&device_manager_registry_, available_device_types_, i);\n  }\n}\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/elementwise_unary_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/elementwise_unary.h\"\n#include <Eigen/Core>\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\ntemplate<typename Src, typename Dst>\nstruct ReluFunctor {\n  Dst operator()(Src src) {\n    if (src > zero_val) { return src; }\n    return zero_val;\n  }\n\n  Src zero_val = static_cast<Src>(0.0);\n};\n\ntemplate<typename Src, typename Dst>\nstruct GeluFunctor {\n  Dst operator()(Src src) {\n    return static_cast<Dst>(0.5) * src * (static_cast<Src>(1.0) + std::erf(inv_sqrt2 * src));\n  }\n  Src inv_sqrt2 = std::sqrt(0.5);\n};\n\ntemplate<typename Src, typename Dst>\nstruct TanhFunctor {\n  Dst operator()(Src src) { return static_cast<Dst>(std::tanh(src)); }\n};\n\ntemplate<typename Src, typename Dst>\nstruct LogicalNotFunctor {\n  Dst operator()(Src src) { return static_cast<Dst>(!src); }\n};\n\ntemplate<typename Src, typename Dst, typename FunctorT>\nvoid EigenElementwise(FunctorT functor, Src* src, Dst* dst, const size_t elem_cnt) {\n  for (int idx = 0; idx < elem_cnt; idx++) { dst[idx] = functor(src[idx]); }\n}\n\ntemplate<typename Src, typename Dst, DataType SrcType, DataType DstType,\n         ep::primitive::UnaryOp unary_op, template<typename A, typename B> class FunctorClass>\nvoid TestElementwise(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n                     const size_t elem_cnt, Scalar attr0 = Scalar(), Scalar attr1 = Scalar()) {\n  for (const auto& device_type : device_types) {\n    auto device = registry->GetDevice(device_type, 0);\n    using EigenSrcVec = Eigen::Matrix<Src, 1, Eigen::Dynamic>;\n    using EigenDstVec = Eigen::Matrix<Dst, 1, Eigen::Dynamic>;\n\n    const size_t src_data_size = elem_cnt * sizeof(Src);\n    const size_t dst_data_size = elem_cnt * sizeof(Dst);\n    EigenSrcVec eigen_src(elem_cnt);\n    EigenDstVec eigen_dst(elem_cnt);\n    eigen_src.setRandom();\n    eigen_dst.setZero();\n\n    ep::test::PinnedMemoryGuard host_src(device.get(), elem_cnt * sizeof(Src));\n    ep::test::PinnedMemoryGuard host_dst(device.get(), elem_cnt * sizeof(Dst));\n    ep::test::DeviceMemoryGuard device_src(device.get(), elem_cnt * sizeof(Src));\n    ep::test::DeviceMemoryGuard device_dst(device.get(), elem_cnt * sizeof(Dst));\n\n    ep::test::StreamGuard stream(device.get());\n    std::unique_ptr<ElementwiseUnary> elementwise_primitive = NewPrimitive<ElementwiseUnaryFactory>(\n        device_type, unary_op, /*src_type=*/SrcType, /*dst_type=*/DstType, attr0, attr1);\n    ASSERT_TRUE(elementwise_primitive.operator bool());\n    std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    
ASSERT_TRUE(d2h.operator bool());\n    ASSERT_TRUE(h2d.operator bool());\n    Src* eigen_src_data = eigen_src.data();\n    std::memcpy(host_src.ptr(), eigen_src_data, src_data_size);\n    h2d->Launch(stream.stream(), device_src.ptr<Src>(), host_src.ptr<Src>(), src_data_size);\n    elementwise_primitive->Launch(stream.stream(), device_src.ptr<Src>(), device_dst.ptr<Dst>(),\n                                  elem_cnt);\n    d2h->Launch(stream.stream(), host_dst.ptr<Dst>(), device_dst.ptr<Dst>(), dst_data_size);\n    CHECK_JUST(stream.stream()->Sync());\n\n    FunctorClass<Src, Dst> functor{};\n    EigenElementwise<Src, Dst, FunctorClass<Src, Dst>>(functor, eigen_src.data(), eigen_dst.data(),\n                                                       elem_cnt);\n    auto elementwise_primitive_res =\n        Eigen::Map<EigenDstVec, Eigen::Unaligned>(host_dst.ptr<Dst>(), elem_cnt);\n    ASSERT_TRUE(eigen_dst.template isApprox(elementwise_primitive_res));\n  }\n}\n\nTEST_F(PrimitiveTest, TestElementwisePrimitive) {\n  // Test Relu\n  TestElementwise<float, float, DataType::kFloat, DataType::kFloat, ep::primitive::UnaryOp::kRelu,\n                  ReluFunctor>(&device_manager_registry_, available_device_types_, 16);\n  TestElementwise<double, double, DataType::kDouble, DataType::kDouble,\n                  ep::primitive::UnaryOp::kRelu, ReluFunctor>(&device_manager_registry_,\n                                                              available_device_types_, 32);\n  TestElementwise<int32_t, int32_t, DataType::kInt32, DataType::kInt32,\n                  ep::primitive::UnaryOp::kRelu, ReluFunctor>(&device_manager_registry_,\n                                                              available_device_types_, 64);\n  TestElementwise<int64_t, int64_t, DataType::kInt64, DataType::kInt64,\n                  ep::primitive::UnaryOp::kRelu, ReluFunctor>(&device_manager_registry_,\n                                                              available_device_types_, 128);\n\n  // Test Gelu\n  TestElementwise<float, float, DataType::kFloat, DataType::kFloat, ep::primitive::UnaryOp::kGelu,\n                  GeluFunctor>(&device_manager_registry_, available_device_types_, 32);\n  TestElementwise<double, double, DataType::kDouble, DataType::kDouble,\n                  ep::primitive::UnaryOp::kGelu, GeluFunctor>(&device_manager_registry_,\n                                                              available_device_types_, 128);\n\n  // Test Tanh\n  TestElementwise<float, float, DataType::kFloat, DataType::kFloat, ep::primitive::UnaryOp::kTanh,\n                  TanhFunctor>(&device_manager_registry_, available_device_types_, 32);\n  TestElementwise<double, double, DataType::kDouble, DataType::kDouble,\n                  ep::primitive::UnaryOp::kTanh, TanhFunctor>(&device_manager_registry_,\n                                                              available_device_types_, 128);\n\n  // Test Logical Not\n  TestElementwise<float, bool, DataType::kFloat, DataType::kBool,\n                  ep::primitive::UnaryOp::kLogicalNot, LogicalNotFunctor>(\n      &device_manager_registry_, available_device_types_, 32);\n  TestElementwise<double, bool, DataType::kDouble, DataType::kBool,\n                  ep::primitive::UnaryOp::kLogicalNot, LogicalNotFunctor>(\n      &device_manager_registry_, available_device_types_, 64);\n  TestElementwise<int8_t, bool, DataType::kInt8, DataType::kBool,\n                  ep::primitive::UnaryOp::kLogicalNot, LogicalNotFunctor>(\n      &device_manager_registry_, 
available_device_types_, 16);\n  TestElementwise<int32_t, bool, DataType::kInt32, DataType::kBool,\n                  ep::primitive::UnaryOp::kLogicalNot, LogicalNotFunctor>(\n      &device_manager_registry_, available_device_types_, 128);\n  TestElementwise<int64_t, bool, DataType::kInt64, DataType::kBool,\n                  ep::primitive::UnaryOp::kLogicalNot, LogicalNotFunctor>(\n      &device_manager_registry_, available_device_types_, 96);\n}\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/fill_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <type_traits>\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n\n#ifdef WITH_CUDA\n#include <cuda.h>\n#include <cuda_fp16.h>\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\n#endif  // WITH_CUDA\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\nnamespace {\n\ntemplate<DataType data_type, typename T>\nvoid TestFill(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types, size_t n) {\n  const size_t vector_size = n * sizeof(T);\n  for (const auto& device_type : device_types) {\n#ifdef WITH_CUDA\n#if CUDA_VERSION >= 11000\n    if (device_type == DeviceType::kCPU && data_type == DataType::kBFloat16) { continue; }\n#endif  // CUDA_VERSION >= 11000\n#endif  // WITH_CUDA\n    auto device = registry->GetDevice(device_type, 0);\n    ep::test::DeviceMemoryGuard device_mem(device.get(), vector_size);\n    ep::test::PinnedMemoryGuard host_mem(device.get(), vector_size);\n    ep::test::StreamGuard stream(device.get());\n\n    std::unique_ptr<Fill> fill = NewPrimitive<FillFactory>(device_type, data_type);\n    ASSERT_TRUE(fill.operator bool());\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    ASSERT_TRUE(d2h.operator bool());\n\n    fill->Launch(stream.stream(), device_mem.ptr(), Scalar(15.0), n);\n    d2h->Launch(stream.stream(), host_mem.ptr(), device_mem.ptr(), vector_size);\n    CHECK_JUST(stream.stream()->Sync());\n    for (size_t i = 0; i < n; ++i) {\n      ASSERT_EQ(*reinterpret_cast<T*>(host_mem.ptr<T>() + i), static_cast<T>(15.0));\n    }\n\n    fill->Launch(stream.stream(), device_mem.ptr(), Scalar(0), n);\n    d2h->Launch(stream.stream(), host_mem.ptr(), device_mem.ptr(), vector_size);\n    CHECK_JUST(stream.stream()->Sync());\n\n    for (size_t i = 0; i < n; ++i) {\n#ifdef WITH_CUDA\n      if constexpr (std::is_same_v<T, half>) {\n        ASSERT_EQ(*reinterpret_cast<T*>(host_mem.ptr<T>() + i), __float2half(0.0));\n#if CUDA_VERSION >= 11000\n      } else if constexpr (std::is_same_v<T, nv_bfloat16>) {\n        ASSERT_EQ(*reinterpret_cast<T*>(host_mem.ptr<T>() + i), __float2bfloat16(0.0));\n#endif  // CUDA_VERSION >= 11000\n      } else {\n        ASSERT_EQ(*reinterpret_cast<T*>(host_mem.ptr<T>() + i), static_cast<T>(0));\n      }\n#else\n      ASSERT_EQ(*reinterpret_cast<T*>(host_mem.ptr<T>() + i), static_cast<T>(0));\n#endif  // WITH_CUDA\n    }\n  }\n}\n\n}  // namespace\n\nTEST_F(PrimitiveTest, TestFill) {\n  TestFill<DataType::kChar, char>(&device_manager_registry_, available_device_types_, 1024);\n  TestFill<DataType::kDouble, double>(&device_manager_registry_, available_device_types_, 1024);\n  
TestFill<DataType::kFloat, float>(&device_manager_registry_, available_device_types_, 1024);\n  TestFill<DataType::kInt8, int8_t>(&device_manager_registry_, available_device_types_, 1024);\n  TestFill<DataType::kInt32, int32_t>(&device_manager_registry_, available_device_types_, 1024);\n  TestFill<DataType::kInt64, int64_t>(&device_manager_registry_, available_device_types_, 1024);\n  TestFill<DataType::kUInt8, uint8_t>(&device_manager_registry_, available_device_types_, 1024);\n#ifdef WITH_CUDA\n  TestFill<DataType::kFloat16, half>(&device_manager_registry_, available_device_types_, 1024);\n#if CUDA_VERSION >= 11000\n  TestFill<DataType::kBFloat16, nv_bfloat16>(&device_manager_registry_, available_device_types_,\n                                             1024);\n#endif  // CUDA_VERSION >= 11000\n#endif  // WITH_CUDA\n  TestFill<DataType::kBool, bool>(&device_manager_registry_, available_device_types_, 1024);\n}\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/matmul_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include <unsupported/Eigen/CXX11/Tensor>\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\nnamespace {\n\ntemplate<DataType data_type, typename T>\nvoid TestMatmul(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types, int m,\n                int k, int n, bool transpose_a, bool transpose_b) {\n  using Matrix = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;\n  Matrix a = Matrix::Random(m, k);\n  Matrix b = Matrix::Random(k, n);\n  Matrix c = a * b;\n  Matrix a_transpose = a.transpose();\n  Matrix b_transpose = b.transpose();\n\n  int64_t a_size = m * k * sizeof(T);\n  int64_t b_size = k * n * sizeof(T);\n  int64_t c_size = m * n * sizeof(T);\n\n  for (const auto& device_type : device_types) {\n    if (device_type == DeviceType::kCPU && data_type == DataType::kFloat16) {\n      // CPU matmul not support float16\n      continue;\n    }\n    auto device = registry->GetDevice(device_type, 0);\n    ep::test::PinnedMemoryGuard input_a(device.get(), a_size);\n    ep::test::PinnedMemoryGuard input_b(device.get(), b_size);\n    if (transpose_a) {\n      std::memcpy(input_a.ptr(), a_transpose.data(), a_size);\n    } else {\n      std::memcpy(input_a.ptr(), a.data(), a_size);\n    }\n    if (transpose_b) {\n      std::memcpy(input_b.ptr(), b_transpose.data(), b_size);\n    } else {\n      std::memcpy(input_b.ptr(), b.data(), b_size);\n    }\n    ep::test::PinnedMemoryGuard output(device.get(), c_size);\n    ep::test::DeviceMemoryGuard device_a(device.get(), a_size);\n    ep::test::DeviceMemoryGuard device_b(device.get(), b_size);\n    ep::test::DeviceMemoryGuard device_c(device.get(), c_size);\n    ep::test::StreamGuard stream(device.get());\n    std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    const auto trans_a = transpose_a ? BlasTransposeType::T : BlasTransposeType::N;\n    const auto trans_b = transpose_b ? 
BlasTransposeType::T : BlasTransposeType::N;\n    std::unique_ptr<Matmul> matmul =\n        NewPrimitive<MatmulFactory>(device_type, data_type, trans_a, trans_b);\n    ASSERT_TRUE(d2h.operator bool());\n    ASSERT_TRUE(h2d.operator bool());\n    ASSERT_TRUE(matmul.operator bool());\n    h2d->Launch(stream.stream(), device_a.ptr(), input_a.ptr(), a_size);\n    h2d->Launch(stream.stream(), device_b.ptr(), input_b.ptr(), b_size);\n    matmul->Launch(stream.stream(), m, n, k, 1.0, device_a.ptr(), device_b.ptr(), 0.0,\n                   device_c.ptr());\n    d2h->Launch(stream.stream(), output.ptr(), device_c.ptr(), c_size);\n    CHECK_JUST(stream.stream()->Sync());\n    auto res = Eigen::Map<Matrix, Eigen::Unaligned>(reinterpret_cast<T*>(output.ptr()), m, n);\n    ASSERT_TRUE(c.template isApprox(res, static_cast<T>(0.001)));\n  }\n}\n\ntemplate<DataType data_type, typename T>\nvoid TestMatmul(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types, int m,\n                int k, int n) {\n  TestMatmul<data_type, T>(registry, device_types, m, k, n, false, false);\n  TestMatmul<data_type, T>(registry, device_types, m, k, n, true, false);\n  TestMatmul<data_type, T>(registry, device_types, m, k, n, false, true);\n  TestMatmul<data_type, T>(registry, device_types, m, k, n, true, true);\n}\n\ntemplate<DataType data_type, typename T>\nvoid TestMatmul(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types) {\n  TestMatmul<data_type, T>(registry, device_types, 64, 16, 8);\n  TestMatmul<data_type, T>(registry, device_types, 16, 7, 12);\n}\n\n}  // namespace\n\nTEST_F(PrimitiveTest, TestMatmul) {\n  TestMatmul<DataType::kDouble, double>(&device_manager_registry_, available_device_types_);\n  TestMatmul<DataType::kFloat, float>(&device_manager_registry_, available_device_types_);\n  TestMatmul<DataType::kFloat16, Eigen::half>(&device_manager_registry_, available_device_types_);\n}\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/memcpy_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\nTEST_F(PrimitiveTest, TestMemcpy) {\n  const size_t test_elem = 1024 * 1024;\n  const size_t test_size = test_elem * sizeof(float);\n  for (const auto& device_type : available_device_types_) {\n    auto device = device_manager_registry_.GetDevice(device_type, 0);\n    ep::test::PinnedMemoryGuard input(device.get(), test_size);\n    ep::test::PinnedMemoryGuard output(device.get(), test_size);\n    ep::test::DeviceMemoryGuard device0(device.get(), test_size);\n    ep::test::DeviceMemoryGuard device1(device.get(), test_size);\n    for (size_t i = 0; i < test_elem; ++i) { *(input.ptr<float>() + i) = i; }\n    ep::test::StreamGuard stream(device.get());\n    std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n    ASSERT_TRUE(h2d.operator bool());\n    std::unique_ptr<Memcpy> d2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoD);\n    ASSERT_TRUE(d2d.operator bool());\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    ASSERT_TRUE(d2h.operator bool());\n    h2d->Launch(stream.stream(), device0.ptr(), input.ptr(), test_size);\n    d2d->Launch(stream.stream(), device1.ptr(), device0.ptr(), test_size);\n    d2h->Launch(stream.stream(), output.ptr(), device1.ptr(), test_size);\n    CHECK_JUST(stream.stream()->Sync());\n    for (size_t i = 0; i < test_elem; ++i) {\n      ASSERT_EQ(*(input.ptr<float>() + i), *(output.ptr<float>() + i));\n    }\n  }\n}\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/memset_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\nTEST_F(PrimitiveTest, TestMemset) {\n  const size_t test_size = 1024 * 1024;\n  for (const auto& device_type : available_device_types_) {\n    auto device = device_manager_registry_.GetDevice(device_type, 0);\n    ep::test::DeviceMemoryGuard device_mem(device.get(), test_size);\n    ep::test::PinnedMemoryGuard host_mem(device.get(), test_size);\n    ep::test::StreamGuard stream(device.get());\n    std::unique_ptr<Memset> memset = NewPrimitive<MemsetFactory>(device_type);\n    ASSERT_TRUE(memset.operator bool());\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    ASSERT_TRUE(d2h.operator bool());\n    memset->Launch(stream.stream(), device_mem.ptr(), 0x55, test_size);\n    d2h->Launch(stream.stream(), host_mem.ptr(), device_mem.ptr(), test_size);\n    CHECK_JUST(stream.stream()->Sync());\n    for (size_t i = 0; i < test_size; ++i) { ASSERT_EQ(*(host_mem.ptr<char>() + i), 0x55); }\n    memset->Launch(stream.stream(), device_mem.ptr(), 0, test_size);\n    d2h->Launch(stream.stream(), host_mem.ptr(), device_mem.ptr(), test_size);\n    CHECK_JUST(stream.stream()->Sync());\n    for (size_t i = 0; i < test_size; ++i) { ASSERT_EQ(*(host_mem.ptr<char>() + i), 0); }\n  }\n}\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/permute_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include <Eigen/Core>\n#include <unsupported/Eigen/CXX11/Tensor>\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\ntemplate<typename T, DataType dtype, int NumDims>\nvoid TestPermute2D(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n                   const int dims[NumDims], const int permutation_list[NumDims]) {\n  using EigenVec = Eigen::Matrix<T, 1, Eigen::Dynamic>;\n  const int elem_cnt = dims[0] * dims[1];\n  const int matrix_size = elem_cnt * sizeof(T);\n\n  for (const auto& device_type : device_types) {\n    Eigen::Tensor<T, NumDims, Eigen::RowMajor> mat(dims[0], dims[1]);\n    mat.setRandom();\n    auto device = registry->GetDevice(device_type, 0);\n\n    ep::test::PinnedMemoryGuard host_src(device.get(), matrix_size);\n    ep::test::PinnedMemoryGuard host_dst(device.get(), matrix_size);\n    ep::test::DeviceMemoryGuard device_src(device.get(), matrix_size);\n    ep::test::DeviceMemoryGuard device_dst(device.get(), matrix_size);\n\n    ep::test::StreamGuard stream(device.get());\n    std::unique_ptr<Permute> permute =\n        NewPrimitive<PermuteFactory>(device_type, /*max_num_dims=*/NumDims);\n    ASSERT_TRUE(permute.operator bool());\n    std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    ASSERT_TRUE(d2h.operator bool());\n    ASSERT_TRUE(h2d.operator bool());\n    T* mat_data = mat.data();\n    std::memcpy(host_src.ptr(), mat_data, matrix_size);\n    h2d->Launch(stream.stream(), device_src.ptr<T>(), host_src.ptr<T>(), matrix_size);\n    const int64_t src_dims[NumDims] = {dims[0], dims[1]};\n    permute->Launch(stream.stream(), dtype, /*num_dims=*/NumDims, src_dims, device_src.ptr<T>(),\n                    permutation_list, device_dst.ptr<T>());\n    d2h->Launch(stream.stream(), host_dst.ptr<T>(), device_dst.ptr<T>(), matrix_size);\n    CHECK_JUST(stream.stream()->Sync());\n\n    Eigen::array<int, NumDims> shuffle_index({permutation_list[0], permutation_list[1]});\n    Eigen::Tensor<T, NumDims, Eigen::RowMajor> mat_transposed = mat.shuffle(shuffle_index);\n\n    auto eigen_transposed_res = Eigen::Map<EigenVec, Eigen::Unaligned>(\n        reinterpret_cast<T*>(mat_transposed.data()), elem_cnt);\n    auto permute_primitive_res =\n        Eigen::Map<EigenVec, Eigen::Unaligned>(host_dst.ptr<T>(), elem_cnt);\n    ASSERT_TRUE(eigen_transposed_res.template isApprox(permute_primitive_res));\n  }\n}\n\ntemplate<typename T, DataType dtype, int NumDims>\nvoid TestPermute3D(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n                   const int 
dims[NumDims], const int permutation_list[NumDims]) {\n  using EigenVec = Eigen::Matrix<T, 1, Eigen::Dynamic>;\n  const int elem_cnt = dims[0] * dims[1] * dims[2];\n  const int matrix_size = elem_cnt * sizeof(T);\n\n  for (const auto& device_type : device_types) {\n    Eigen::Tensor<T, NumDims, Eigen::RowMajor> mat(dims[0], dims[1], dims[2]);\n    mat.setRandom();\n    auto device = registry->GetDevice(device_type, 0);\n\n    ep::test::PinnedMemoryGuard host_src(device.get(), matrix_size);\n    ep::test::PinnedMemoryGuard host_dst(device.get(), matrix_size);\n    ep::test::DeviceMemoryGuard device_src(device.get(), matrix_size);\n    ep::test::DeviceMemoryGuard device_dst(device.get(), matrix_size);\n\n    ep::test::StreamGuard stream(device.get());\n    std::unique_ptr<Permute> permute =\n        NewPrimitive<PermuteFactory>(device_type, /*max_num_dims=*/NumDims);\n    ASSERT_TRUE(permute.operator bool());\n    std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    ASSERT_TRUE(d2h.operator bool());\n    ASSERT_TRUE(h2d.operator bool());\n    T* mat_data = mat.data();\n    std::memcpy(host_src.ptr(), mat_data, matrix_size);\n    h2d->Launch(stream.stream(), device_src.ptr<T>(), host_src.ptr<T>(), matrix_size);\n    const int64_t src_dims[NumDims] = {dims[0], dims[1], dims[2]};\n    permute->Launch(stream.stream(), dtype, /*num_dims=*/NumDims, src_dims, device_src.ptr<T>(),\n                    permutation_list, device_dst.ptr<T>());\n    d2h->Launch(stream.stream(), host_dst.ptr<T>(), device_dst.ptr<T>(), matrix_size);\n    CHECK_JUST(stream.stream()->Sync());\n\n    Eigen::array<int, NumDims> shuffle_index(\n        {permutation_list[0], permutation_list[1], permutation_list[2]});\n    Eigen::Tensor<T, NumDims, Eigen::RowMajor> mat_transposed = mat.shuffle(shuffle_index);\n\n    auto eigen_transposed_res = Eigen::Map<EigenVec, Eigen::Unaligned>(\n        reinterpret_cast<T*>(mat_transposed.data()), elem_cnt);\n    auto permute_primitive_res =\n        Eigen::Map<EigenVec, Eigen::Unaligned>(host_dst.ptr<T>(), elem_cnt);\n    ASSERT_TRUE(eigen_transposed_res.template isApprox(permute_primitive_res));\n  }\n}\n\nTEST_F(PrimitiveTest, TestBatchPermute) {\n  const int permutation_list[2] = {1, 0};\n  const int32_t dims0[2] = {2, 3};\n  const int32_t dims1[2] = {7, 9};\n  const int32_t dims2[2] = {10, 3};\n  const int32_t dims3[2] = {31, 4};\n  const int32_t dims4[2] = {6, 8};\n\n  TestPermute2D<float, DataType::kFloat, 2>(&device_manager_registry_, available_device_types_,\n                                            dims0, permutation_list);\n  TestPermute2D<double, DataType::kDouble, 2>(&device_manager_registry_, available_device_types_,\n                                              dims1, permutation_list);\n  TestPermute2D<int32_t, DataType::kInt32, 2>(&device_manager_registry_, available_device_types_,\n                                              dims2, permutation_list);\n  TestPermute2D<int64_t, DataType::kInt64, 2>(&device_manager_registry_, available_device_types_,\n                                              dims3, permutation_list);\n  TestPermute2D<Eigen::half, DataType::kFloat16, 2>(\n      &device_manager_registry_, available_device_types_, dims4, permutation_list);\n}\n\nTEST_F(PrimitiveTest, TestPermute) {\n  const int permutation_list0[3] = {0, 2, 1};\n  const int permutation_list1[3] = {1, 2, 0};\n  const int permutation_list2[3] = {1, 
0, 2};\n  const int permutation_list3[3] = {2, 1, 0};\n  const int permutation_list4[3] = {2, 0, 1};\n  const int32_t dims0[3] = {2, 3, 9};\n  const int32_t dims1[3] = {7, 9, 4};\n  const int32_t dims2[3] = {10, 3, 2};\n  const int32_t dims3[3] = {3, 7, 2};\n  const int32_t dims4[3] = {8, 2, 5};\n\n  TestPermute3D<float, DataType::kFloat, 3>(&device_manager_registry_, available_device_types_,\n                                            dims0, permutation_list0);\n  TestPermute3D<double, DataType::kDouble, 3>(&device_manager_registry_, available_device_types_,\n                                              dims1, permutation_list1);\n  TestPermute3D<int32_t, DataType::kInt32, 3>(&device_manager_registry_, available_device_types_,\n                                              dims2, permutation_list2);\n  TestPermute3D<int64_t, DataType::kInt64, 3>(&device_manager_registry_, available_device_types_,\n                                              dims3, permutation_list3);\n  TestPermute3D<Eigen::half, DataType::kFloat16, 3>(\n      &device_manager_registry_, available_device_types_, dims4, permutation_list4);\n}\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/primitive_test.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_TEST_PRIMITIVE_PRIMITIVE_TEST_\n#define ONEFLOW_CORE_EP_TEST_PRIMITIVE_PRIMITIVE_TEST_\n\n#include \"oneflow/core/ep/test/test_util.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\nclass PrimitiveTest : public ep::test::TestCase {};\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_TEST_PRIMITIVE_PRIMITIVE_TEST_\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/softmax_backward_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/softmax_backward.h\"\n#include \"oneflow/core/ep/include/primitive/log_softmax_backward.h\"\n#include <unsupported/Eigen/CXX11/Tensor>\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\nnamespace {\n\ntemplate<DataType data_type, typename T, typename ComputeType>\nvoid TestSoftmaxBackward(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n                         int num_rows, int num_cols, bool log_softmax) {\n  const int elem_cnt = num_rows * num_cols;\n  const int data_size = elem_cnt * sizeof(T);\n  Eigen::Tensor<T, 2, Eigen::RowMajor> softmax_y(num_rows, num_cols);\n  Eigen::Tensor<T, 2, Eigen::RowMajor> softmax_dy(num_rows, num_cols);\n  Eigen::Tensor<T, 2, Eigen::RowMajor> softmax_dx(num_rows, num_cols);\n  softmax_y.setRandom();\n  softmax_dy.setRandom();\n  Eigen::array<int, 1> reduce_dim = {1};\n  Eigen::array<int, 2> reduced_shape = {num_rows, 1};\n  Eigen::array<int, 2> broadcast_shape = {1, num_cols};\n\n  Eigen::Tensor<ComputeType, 2, Eigen::RowMajor> compute_y = softmax_y.template cast<ComputeType>();\n  Eigen::Tensor<ComputeType, 2, Eigen::RowMajor> compute_dy =\n      softmax_dy.template cast<ComputeType>();\n  Eigen::Tensor<ComputeType, 2, Eigen::RowMajor> compute_dx;\n\n  if (log_softmax) {\n    compute_dx =\n        compute_dy\n        - compute_y.exp()\n              * compute_dy.sum(reduce_dim).eval().reshape(reduced_shape).broadcast(broadcast_shape);\n  } else {\n    Eigen::Tensor<ComputeType, 2, Eigen::RowMajor> row_buf = compute_dy * compute_y;\n    compute_dx =\n        (compute_dy\n         - row_buf.sum(reduce_dim).eval().reshape(reduced_shape).broadcast(broadcast_shape))\n        * compute_y;\n  }\n  softmax_dx = compute_dx.template cast<T>();\n\n  for (const auto& device_type : device_types) {\n    if (device_type == DeviceType::kCPU && data_type == DataType::kFloat16) {\n      // CPU softmax not support float16\n      continue;\n    }\n    auto device = registry->GetDevice(device_type, 0);\n    ep::test::PinnedMemoryGuard input_y(device.get(), data_size);\n    ep::test::PinnedMemoryGuard input_dy(device.get(), data_size);\n    ep::test::PinnedMemoryGuard output_dx(device.get(), data_size);\n    std::memcpy(input_y.ptr(), softmax_y.data(), data_size);\n    std::memcpy(input_dy.ptr(), softmax_dy.data(), data_size);\n    ep::test::DeviceMemoryGuard device_in_y(device.get(), data_size);\n    ep::test::DeviceMemoryGuard device_in_dy(device.get(), data_size);\n    ep::test::DeviceMemoryGuard device_out_dx(device.get(), data_size);\n    ep::test::StreamGuard stream(device.get());\n    std::unique_ptr<Memcpy> h2d = 
NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n    ASSERT_TRUE(h2d.operator bool());\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    ASSERT_TRUE(d2h.operator bool());\n    h2d->Launch(stream.stream(), device_in_y.ptr(), input_y.ptr(), data_size);\n    h2d->Launch(stream.stream(), device_in_dy.ptr(), input_dy.ptr(), data_size);\n    if (log_softmax) {\n      std::unique_ptr<LogSoftmaxBackward> log_softmax =\n          NewPrimitive<LogSoftmaxBackwardFactory>(device_type, data_type);\n      ASSERT_TRUE(log_softmax.operator bool());\n      log_softmax->Launch(stream.stream(), num_rows, num_cols, device_in_y.ptr(),\n                          device_in_dy.ptr(), device_out_dx.ptr());\n    } else {\n      std::unique_ptr<SoftmaxBackward> softmax =\n          NewPrimitive<SoftmaxBackwardFactory>(device_type, data_type);\n      ASSERT_TRUE(softmax.operator bool());\n      softmax->Launch(stream.stream(), num_rows, num_cols, device_in_y.ptr(), device_in_dy.ptr(),\n                      device_out_dx.ptr());\n    }\n    d2h->Launch(stream.stream(), output_dx.ptr(), device_out_dx.ptr(), data_size);\n    CHECK_JUST(stream.stream()->Sync());\n    Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, Eigen::Unaligned> eigen_out(softmax_dx.data(),\n                                                                                softmax_dx.size());\n    Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, Eigen::Unaligned> of_out(\n        reinterpret_cast<T*>(output_dx.ptr()), softmax_dx.size());\n\n    ASSERT_TRUE(eigen_out.template isApprox(of_out, static_cast<T>(0.001)));\n  }\n}\n\nvoid TestSoftmaxBackward(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n                         int num_rows, int num_cols) {\n  TestSoftmaxBackward<DataType::kFloat, float, float>(registry, device_types, num_rows, num_cols,\n                                                      true);\n  TestSoftmaxBackward<DataType::kFloat, float, float>(registry, device_types, num_rows, num_cols,\n                                                      false);\n  TestSoftmaxBackward<DataType::kDouble, double, double>(registry, device_types, num_rows, num_cols,\n                                                         true);\n  TestSoftmaxBackward<DataType::kDouble, double, double>(registry, device_types, num_rows, num_cols,\n                                                         false);\n  TestSoftmaxBackward<DataType::kFloat16, Eigen::half, float>(registry, device_types, num_rows,\n                                                              num_cols, true);\n  TestSoftmaxBackward<DataType::kFloat16, Eigen::half, float>(registry, device_types, num_rows,\n                                                              num_cols, false);\n}\n\n}  // namespace\n\nTEST_F(PrimitiveTest, TestSoftmaxBackward) {\n  std::vector<int> num_rows = {32, 33, 512, 511};\n  std::vector<int> num_cols = {15, 16, 32, 768, 1536};\n  for (int i = 0; i < num_rows.size(); ++i) {\n    for (int j = 0; j < num_cols.size(); ++j) {\n      TestSoftmaxBackward(&device_manager_registry_, available_device_types_, num_rows.at(i),\n                          num_cols.at(j));\n    }\n  }\n}\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/softmax_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/softmax.h\"\n#include \"oneflow/core/ep/include/primitive/log_softmax.h\"\n#include <unsupported/Eigen/CXX11/Tensor>\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\nnamespace {\n\ntemplate<DataType data_type, typename T>\nvoid TestSoftmax(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n                 int num_rows, int num_cols, bool log_softmax) {\n  const int elem_cnt = num_rows * num_cols;\n  const int data_size = elem_cnt * sizeof(T);\n  Eigen::Tensor<T, 2, Eigen::RowMajor> softmax_in(num_rows, num_cols);\n  Eigen::Tensor<T, 2, Eigen::RowMajor> softmax_out(num_rows, num_cols);\n  softmax_in.setRandom();\n  Eigen::array<int, 1> reduce_dim = {1};\n  Eigen::array<int, 2> reduced_shape = {num_rows, 1};\n  Eigen::array<int, 2> broadcast_shape = {1, num_cols};\n\n  Eigen::Tensor<T, 2, Eigen::RowMajor> row_buf =\n      (softmax_in\n       - softmax_in.maximum(reduce_dim).eval().reshape(reduced_shape).broadcast(broadcast_shape));\n  if (log_softmax) {\n    softmax_out = row_buf\n                  - row_buf.exp()\n                        .sum(reduce_dim)\n                        .eval()\n                        .reshape(reduced_shape)\n                        .log()\n                        .broadcast(broadcast_shape);\n  } else {\n    row_buf = row_buf.exp();\n    softmax_out =\n        row_buf / row_buf.sum(reduce_dim).eval().reshape(reduced_shape).broadcast(broadcast_shape);\n  }\n\n  for (const auto& device_type : device_types) {\n    if (device_type == DeviceType::kCPU && data_type == DataType::kFloat16) {\n      // CPU softmax not support float16\n      continue;\n    }\n    auto device = registry->GetDevice(device_type, 0);\n    ep::test::PinnedMemoryGuard input(device.get(), data_size);\n    ep::test::PinnedMemoryGuard output(device.get(), data_size);\n    std::memcpy(input.ptr(), softmax_in.data(), data_size);\n    ep::test::DeviceMemoryGuard device_in(device.get(), data_size);\n    ep::test::DeviceMemoryGuard device_out(device.get(), data_size);\n    ep::test::StreamGuard stream(device.get());\n    std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n    ASSERT_TRUE(h2d.operator bool());\n    std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    ASSERT_TRUE(d2h.operator bool());\n    h2d->Launch(stream.stream(), device_in.ptr(), input.ptr(), data_size);\n    if (log_softmax) {\n      std::unique_ptr<LogSoftmax> log_softmax =\n          NewPrimitive<LogSoftmaxFactory>(device_type, data_type);\n      ASSERT_TRUE(log_softmax.operator bool());\n      log_softmax->Launch(stream.stream(), 
num_rows, num_cols, device_in.ptr(), device_out.ptr());\n    } else {\n      std::unique_ptr<Softmax> softmax = NewPrimitive<SoftmaxFactory>(device_type, data_type);\n      ASSERT_TRUE(softmax.operator bool());\n      softmax->Launch(stream.stream(), num_rows, num_cols, device_in.ptr(), device_out.ptr());\n    }\n    d2h->Launch(stream.stream(), output.ptr(), device_out.ptr(), data_size);\n    CHECK_JUST(stream.stream()->Sync());\n    Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, Eigen::Unaligned> eigen_out(softmax_out.data(),\n                                                                                softmax_out.size());\n    Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, Eigen::Unaligned> of_out(\n        reinterpret_cast<T*>(output.ptr()), softmax_out.size());\n    ASSERT_TRUE(eigen_out.template isApprox(of_out, static_cast<T>(0.001)));\n  }\n}\n\nvoid TestSoftmax(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n                 int num_rows, int num_cols) {\n  TestSoftmax<DataType::kFloat, float>(registry, device_types, num_rows, num_cols, true);\n  TestSoftmax<DataType::kFloat, float>(registry, device_types, num_rows, num_cols, false);\n  TestSoftmax<DataType::kDouble, double>(registry, device_types, num_rows, num_cols, true);\n  TestSoftmax<DataType::kDouble, double>(registry, device_types, num_rows, num_cols, false);\n  TestSoftmax<DataType::kFloat16, Eigen::half>(registry, device_types, num_rows, num_cols, true);\n  TestSoftmax<DataType::kFloat16, Eigen::half>(registry, device_types, num_rows, num_cols, false);\n}\n\n}  // namespace\n\nTEST_F(PrimitiveTest, TestSoftmax) {\n  std::vector<int> num_rows = {32, 33, 512, 511};\n  std::vector<int> num_cols = {15, 16, 32, 768, 1536};\n  for (int i = 0; i < num_rows.size(); ++i) {\n    for (int j = 0; j < num_cols.size(); ++j) {\n      TestSoftmax(&device_manager_registry_, available_device_types_, num_rows.at(i),\n                  num_cols.at(j));\n    }\n  }\n}\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/unary_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/elementwise_unary.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_unary.h\"\n#include <Eigen/Core>\n#include <unsupported/Eigen/CXX11/Tensor>\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace primitive {\n\nnamespace test {\n\nnamespace {\n\ntemplate<UnaryOp unary_op, DataType src_data_type, typename Src, DataType dst_data_type,\n         typename Dst>\nvoid TestElementwiseBroadcastUnary(DeviceManagerRegistry* registry,\n                                   const std::set<DeviceType>& device_types) {\n  const std::vector<int> num_src_axes = {1, 4, 1, 4, 4};\n  const std::vector<int> num_dst_axes = {4, 4, 1, 4, 4};\n\n  const std::vector<std::vector<int64_t>> a_dims_vec = {\n      {1, 1, 1, 1}, {1, 3, 2, 4}, {1, 1, 1, 1}, {1, 2, 3, 4}, {1, 2, 3, 4}};\n  const std::vector<std::vector<int64_t>> broadcast_dims_vec = {\n      {2, 3, 2, 4}, {2, 3, 2, 4}, {1, 1, 1, 1}, {1, 2, 3, 4}, {1, 2, 3, 4}};\n  const std::vector<std::vector<int64_t>> a_broadcasts_vec = {\n      {2, 3, 2, 4}, {2, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}};\n\n  const std::vector<std::vector<int64_t>> a_strides_vec = {\n      {0, 0, 0, 0},\n      {a_dims_vec[1][1] * a_dims_vec[1][2] * a_dims_vec[1][3], a_dims_vec[1][2] * a_dims_vec[1][3],\n       a_dims_vec[1][3], 1},\n      {0, 0, 0, 0},\n      {a_dims_vec[3][1] * a_dims_vec[3][2] * a_dims_vec[3][3], a_dims_vec[3][2] * a_dims_vec[3][3],\n       a_dims_vec[3][3], 1},\n      {a_dims_vec[4][1] * a_dims_vec[4][2] * a_dims_vec[4][3], a_dims_vec[4][2] * a_dims_vec[4][3],\n       a_dims_vec[4][3], 1}};\n  const std::vector<std::vector<int64_t>> c_strides_vec = {\n      {broadcast_dims_vec[0][1] * broadcast_dims_vec[0][2] * broadcast_dims_vec[0][3],\n       broadcast_dims_vec[0][2] * broadcast_dims_vec[0][3], broadcast_dims_vec[0][3], 1},\n      {broadcast_dims_vec[1][2] * broadcast_dims_vec[1][3],\n       broadcast_dims_vec[1][0] * broadcast_dims_vec[1][2] * broadcast_dims_vec[1][3], 1,\n       broadcast_dims_vec[1][2]},\n      {0, 0, 0, 0},\n      {broadcast_dims_vec[3][1] * broadcast_dims_vec[3][2] * broadcast_dims_vec[3][3],\n       broadcast_dims_vec[3][2], 1, broadcast_dims_vec[3][1] * broadcast_dims_vec[3][2]},\n      {1, broadcast_dims_vec[4][0], broadcast_dims_vec[4][0] * broadcast_dims_vec[4][1],\n       broadcast_dims_vec[4][0] * broadcast_dims_vec[4][1] * broadcast_dims_vec[4][2]}};\n\n  for (int i = 0; i < 5; i++) {\n    const std::vector<int64_t>& a_dims = a_dims_vec[i];\n    const std::vector<int64_t>& c_dims = broadcast_dims_vec[i];\n    const Eigen::array<int64_t, 4> a_broadcast = {a_broadcasts_vec[i][0], a_broadcasts_vec[i][1],\n                                                  
a_broadcasts_vec[i][2], a_broadcasts_vec[i][3]};\n    Eigen::Tensor<Src, 4, Eigen::RowMajor> a(a_dims[0], a_dims[1], a_dims[2], a_dims[3]);\n\n    const std::vector<int64_t>& a_strides = a_strides_vec[i];\n    const std::vector<int64_t>& c_strides = c_strides_vec[i];\n\n    a.setRandom();\n\n    Eigen::Tensor<Src, 4, Eigen::RowMajor> t = a.broadcast(a_broadcast);\n    Eigen::Tensor<Dst, 4, Eigen::RowMajor> broadcast_a = t.template cast<Dst>();\n\n    const int64_t a_size = a.size() * sizeof(Src);\n    const int64_t c_count =\n        std::accumulate(c_dims.begin(), c_dims.end(), 1, std::multiplies<int64_t>());\n    const int64_t c_size = c_count * sizeof(Dst);\n    const int64_t broadcast_a_size = broadcast_a.size() * sizeof(Dst);\n\n    ASSERT_TRUE(c_size == broadcast_a_size);\n\n    for (const auto& device_type : device_types) {\n      // broadcast a with non-broadcast elementwise unary primitive\n      auto device = registry->GetDevice(device_type, 0);\n      ep::test::PinnedMemoryGuard input_broadcast_a(device.get(), broadcast_a_size);\n      std::memcpy(input_broadcast_a.ptr(), broadcast_a.data(), broadcast_a_size);\n\n      ep::test::PinnedMemoryGuard broadcast_output(device.get(), c_size);\n      ep::test::DeviceMemoryGuard device_broadcast_a(device.get(), broadcast_a_size);\n      ep::test::DeviceMemoryGuard device_broadcast_c(device.get(), c_size);\n      ep::test::StreamGuard stream(device.get());\n      std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n      std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n      std::unique_ptr<ElementwiseUnary> unary = NewPrimitive<ElementwiseUnaryFactory>(\n          device_type, unary_op, src_data_type, dst_data_type);\n      ASSERT_TRUE(d2h.operator bool());\n      ASSERT_TRUE(h2d.operator bool());\n      ASSERT_TRUE(unary.operator bool());\n      h2d->Launch(stream.stream(), device_broadcast_a.ptr(), input_broadcast_a.ptr(),\n                  broadcast_a_size);\n      unary->Launch(stream.stream(), device_broadcast_a.ptr(), device_broadcast_c.ptr(),\n                    c_count);  // c_count is the element count\n      d2h->Launch(stream.stream(), broadcast_output.ptr(), device_broadcast_c.ptr(),\n                  c_size);  // c_size is in bytes\n      CHECK_JUST(stream.stream()->Sync());\n\n      ep::test::PinnedMemoryGuard input_a(device.get(), a_size);\n      std::memcpy(input_a.ptr(), a.data(), a_size);\n\n      ep::test::PinnedMemoryGuard output(device.get(), c_size);\n      ep::test::DeviceMemoryGuard device_a(device.get(), a_size);\n      ep::test::DeviceMemoryGuard device_c(device.get(), c_size);\n      std::unique_ptr<BroadcastElementwiseUnary> broadcast_unary =\n          NewPrimitive<BroadcastElementwiseUnaryFactory>(device_type, unary_op, src_data_type,\n                                                         dst_data_type,\n                                                         MAX(num_src_axes[i], num_dst_axes[i]));\n      ASSERT_TRUE(broadcast_unary.operator bool());\n      h2d->Launch(stream.stream(), device_a.ptr(), input_a.ptr(), a_size);\n\n      broadcast_unary->Launch(stream.stream(), num_src_axes[i], a_dims.data(), a_strides.data(),\n                              device_a.ptr(), num_dst_axes[i], c_dims.data(), c_strides.data(),\n                              device_c.ptr());\n      d2h->Launch(stream.stream(), output.ptr(), device_c.ptr(), c_size);\n      CHECK_JUST(stream.stream()->Sync());\n\n      Dst thresh = 1e-4;\n      bool res = 
true;\n\n      std::vector<int64_t> a_broadcast_strides;\n      for (int j = num_dst_axes[i] - 1; j >= 0; j--) {\n        if (j == num_dst_axes[i] - 1) {\n          a_broadcast_strides.push_back(1);\n        } else {\n          a_broadcast_strides.insert(a_broadcast_strides.begin(),\n                                     a_broadcast_strides[0] * a_dims[j + 1] * a_broadcast[j + 1]);\n        }\n      }\n\n      for (int i0 = 0; i0 < c_dims[0]; i0++) {\n        for (int i1 = 0; i1 < c_dims[1]; i1++) {\n          for (int i2 = 0; i2 < c_dims[2]; i2++) {\n            for (int i3 = 0; i3 < c_dims[3]; i3++) {\n// fully parenthesized so the macro is correct for expressions such as a - b\n#define ABS(x) (((x) > 0) ? (x) : (-(x)))\n              const size_t src_index = a_broadcast_strides[0] * i0 + a_broadcast_strides[1] * i1\n                                       + a_broadcast_strides[2] * i2 + a_broadcast_strides[3] * i3;\n              const size_t dst_index =\n                  c_strides[0] * i0 + c_strides[1] * i1 + c_strides[2] * i2 + c_strides[3] * i3;\n              if (ABS(reinterpret_cast<Dst*>(broadcast_output.ptr())[src_index]\n                      - reinterpret_cast<Dst*>(output.ptr())[dst_index])\n                  > thresh) {\n                res = false;\n              }\n#undef ABS\n            }\n          }\n        }\n      }\n      ASSERT_TRUE(res);\n    }\n  }\n}\n\ntemplate<DataType src_data_type, typename Src, DataType dst_data_type, typename Dst>\nvoid TestElementwiseBroadcastUnaryBatchPermute(DeviceManagerRegistry* registry,\n                                               const std::set<DeviceType>& device_types) {\n  const std::vector<int64_t>& a_dims = {5, 2};\n  const std::vector<int64_t>& c_dims = {5, 2};\n  Eigen::Tensor<Src, 2, Eigen::RowMajor> a(5, 4);\n\n  const std::vector<std::vector<int64_t>>& a_strides = {{4, 1}, {2, 1}};\n  const std::vector<std::vector<int64_t>>& c_strides = {{1, 5}, {1, 10}};\n\n  a.setRandom();\n\n  const int64_t a_size = a.size() * sizeof(Src);\n  const int64_t c_count =\n      std::accumulate(c_dims.begin(), c_dims.end(), 1, std::multiplies<int64_t>());\n  const int64_t c_size = MAX(c_count, a.size()) * sizeof(Dst);\n\n  for (int i = 0; i < a_strides.size(); i++) {\n    auto& a_stride = a_strides[i];\n    auto& c_stride = c_strides[i];\n    for (const auto& device_type : device_types) {\n      // broadcast a with non-broadcast elementwise unary primitive\n      auto device = registry->GetDevice(device_type, 0);\n      ep::test::StreamGuard stream(device.get());\n\n      ep::test::PinnedMemoryGuard input_a(device.get(), a_size);\n      std::memcpy(input_a.ptr(), a.data(), a_size);\n\n      ep::test::PinnedMemoryGuard output(device.get(), c_size);\n      ep::test::DeviceMemoryGuard device_a(device.get(), a_size);\n      ep::test::DeviceMemoryGuard device_c(device.get(), c_size);\n      std::unique_ptr<Memcpy> h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n      std::unique_ptr<Memcpy> d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n      std::unique_ptr<BroadcastElementwiseUnary> broadcast_unary =\n          NewPrimitive<BroadcastElementwiseUnaryFactory>(device_type, UnaryOp::kIdentity,\n                                                         src_data_type, dst_data_type, 2);\n      ASSERT_TRUE(broadcast_unary.operator bool());\n      ASSERT_TRUE(d2h.operator bool());\n      ASSERT_TRUE(h2d.operator bool());\n      h2d->Launch(stream.stream(), device_a.ptr(), input_a.ptr(), a_size);\n\n      broadcast_unary->Launch(stream.stream(), 2, a_dims.data(), a_stride.data(), 
device_a.ptr(), 2,\n                              c_dims.data(), c_stride.data(), device_c.ptr());\n\n      d2h->Launch(stream.stream(), output.ptr(), device_c.ptr(), c_size);\n      CHECK_JUST(stream.stream()->Sync());\n\n      Dst thresh = 1e-4;\n      bool res = true;\n\n      for (int i0 = 0; i0 < c_dims[0]; i0++) {\n        for (int i1 = 0; i1 < c_dims[1]; i1++) {\n#define ABS(x) (((x) > 0) ? (x) : (-(x)))\n          const size_t src_index = a_stride[0] * i0 + a_stride[1] * i1;\n          const size_t dst_index = c_stride[0] * i0 + c_stride[1] * i1;\n          if (ABS(reinterpret_cast<Dst*>(input_a.ptr())[src_index]\n                  - reinterpret_cast<Dst*>(output.ptr())[dst_index])\n              > thresh) {\n            res = false;\n          }\n#undef ABS\n        }\n      }\n      ASSERT_TRUE(res);\n    }\n  }\n}\n\n}  // namespace\n\nTEST_F(PrimitiveTest, TestUnary) {\n  TestElementwiseBroadcastUnary<UnaryOp::kIdentity, DataType::kFloat, float, DataType::kFloat,\n                                float>(&device_manager_registry_, available_device_types_);\n  TestElementwiseBroadcastUnaryBatchPermute<DataType::kFloat, float, DataType::kFloat, float>(\n      &device_manager_registry_, available_device_types_);\n}\n\n}  // namespace test\n\n}  // namespace primitive\n\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/primitive/where_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/test/primitive/primitive_test.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/where.h\"\n#include \"oneflow/core/common/data_type.h\"\n\n#include <gtest/gtest.h>\n#include <unsupported/Eigen/CXX11/Tensor>\n#include <sstream>\n\nnamespace oneflow {\n\ntemplate<>\nstruct GetDataType<Eigen::half> : std::integral_constant<DataType, DataType::kFloat16> {};\n\nnamespace ep {\nnamespace primitive {\nnamespace test {\n\nnamespace {\n\ntemplate<typename dims_type>\nstd::string DimsToString(const dims_type& dims, const std::string& name) {\n  std::ostringstream ss;\n  ss << name << \"=(\";\n  for (size_t i = 0; i < dims.size(); ++i) {\n    if (i > 0) { ss << \", \"; }\n    ss << dims[i];\n  }\n  ss << \")\";\n  return ss.str();\n};\n\ntemplate<typename T, typename CondT, size_t ndim>\nvoid TestWhere(const std::vector<Device*>& devices, size_t num_cond_dims, const int64_t* cond_dims,\n               size_t num_x_dims, const int64_t* x_dims, size_t num_y_dims, const int64_t* y_dims) {\n  ASSERT_TRUE(num_cond_dims <= ndim);\n  ASSERT_TRUE(num_x_dims <= ndim);\n  ASSERT_TRUE(num_y_dims <= ndim);\n\n  std::array<int64_t, ndim> broadcast_dims{};\n  std::array<int64_t, ndim> broadcast_cond_dims{};\n  std::array<int64_t, ndim> broadcast_x_dims{};\n  std::array<int64_t, ndim> broadcast_y_dims{};\n  std::array<int64_t, ndim> extend_cond_dims{};\n  std::array<int64_t, ndim> extend_x_dims{};\n  std::array<int64_t, ndim> extend_y_dims{};\n  for (size_t i = 0; i < ndim; ++i) {\n    size_t cond_lpad = ndim - num_cond_dims;\n    size_t x_lpad = ndim - num_x_dims;\n    size_t y_lpad = ndim - num_y_dims;\n    int64_t cond_dim = (i < cond_lpad) ? 1 : cond_dims[i - cond_lpad];\n    int64_t x_dim = (i < x_lpad) ? 1 : x_dims[i - x_lpad];\n    int64_t y_dim = (i < y_lpad) ? 1 : y_dims[i - y_lpad];\n    int64_t max_dim = std::max(x_dim, y_dim);\n    max_dim = std::max(max_dim, cond_dim);\n    ASSERT_TRUE((cond_dim == 1 || cond_dim == max_dim) && (x_dim == 1 || x_dim == max_dim)\n                && (y_dim == 1 || y_dim == max_dim));\n    broadcast_dims[i] = max_dim;\n    broadcast_cond_dims[i] = (cond_dim == max_dim) ? 1 : max_dim;\n    broadcast_x_dims[i] = (x_dim == max_dim) ? 1 : max_dim;\n    broadcast_y_dims[i] = (y_dim == max_dim) ? 
1 : max_dim;\n    extend_cond_dims[i] = cond_dim;\n    extend_x_dims[i] = x_dim;\n    extend_y_dims[i] = y_dim;\n  }\n\n  size_t cond_size = std::accumulate(extend_cond_dims.begin(), extend_cond_dims.end(), 1,\n                                     std::multiplies<int64_t>());\n  size_t x_size =\n      std::accumulate(extend_x_dims.begin(), extend_x_dims.end(), 1, std::multiplies<int64_t>());\n  size_t y_size =\n      std::accumulate(extend_y_dims.begin(), extend_y_dims.end(), 1, std::multiplies<int64_t>());\n  size_t z_size =\n      std::accumulate(broadcast_dims.begin(), broadcast_dims.end(), 1, std::multiplies<int64_t>());\n  size_t cond_byte_size = cond_size * sizeof(CondT);\n  size_t x_byte_size = x_size * sizeof(T);\n  size_t y_byte_size = y_size * sizeof(T);\n  size_t z_byte_size = z_size * sizeof(T);\n\n  // Eigen reference result\n  Eigen::Tensor<T, ndim, Eigen::RowMajor> tensor_c(extend_cond_dims);\n  Eigen::Tensor<T, ndim, Eigen::RowMajor> tensor_x(extend_x_dims);\n  Eigen::Tensor<T, ndim, Eigen::RowMajor> tensor_y(extend_y_dims);\n  tensor_c.setRandom();\n  tensor_x.setRandom();\n  tensor_y.setRandom();\n  tensor_c = tensor_c.unaryExpr([](T x) -> T { return x > T{0} ? T{1} : T{0}; });\n  Eigen::Tensor<CondT, ndim, Eigen::RowMajor> tensor_cond = tensor_c.template cast<CondT>();\n  auto broadcast_c = tensor_cond.broadcast(broadcast_cond_dims);\n  auto broadcast_x = tensor_x.broadcast(broadcast_x_dims);\n  auto broadcast_y = tensor_y.broadcast(broadcast_y_dims);\n  Eigen::Tensor<T, ndim, Eigen::RowMajor> tensor_z = broadcast_c.select(broadcast_x, broadcast_y);\n  ASSERT_TRUE(tensor_z.size() == z_size) << tensor_z.size() << \" vs. \" << z_size << \", \";\n\n  // test on devices\n  for (auto* device : devices) {\n    if (device->device_type() == DeviceType::kCPU && GetDataType<T>() == DataType::kFloat16) {\n      // CPU where does not support float16\n      continue;\n    }\n\n    ep::test::PinnedMemoryGuard host_cond(device, cond_byte_size);\n    ep::test::PinnedMemoryGuard host_x(device, x_byte_size);\n    ep::test::PinnedMemoryGuard host_y(device, y_byte_size);\n    ep::test::DeviceMemoryGuard cond(device, cond_byte_size);\n    ep::test::DeviceMemoryGuard x(device, x_byte_size);\n    ep::test::DeviceMemoryGuard y(device, y_byte_size);\n    ep::test::DeviceMemoryGuard z(device, z_byte_size);\n    ep::test::PinnedMemoryGuard host_z(device, z_byte_size);\n\n    std::memcpy(host_cond.ptr(), tensor_cond.data(), cond_byte_size);\n    std::memcpy(host_x.ptr(), tensor_x.data(), x_byte_size);\n    std::memcpy(host_y.ptr(), tensor_y.data(), y_byte_size);\n\n    ep::test::StreamGuard stream(device);\n    auto h2d = NewPrimitive<MemcpyFactory>(device->device_type(), MemcpyKind::kHtoD);\n    auto d2h = NewPrimitive<MemcpyFactory>(device->device_type(), MemcpyKind::kDtoH);\n    auto where = NewPrimitive<WhereFactory>(device->device_type(), GetDataType<CondT>(),\n                                            GetDataType<T>(), ndim);\n    ASSERT_TRUE(d2h.operator bool());\n    ASSERT_TRUE(h2d.operator bool());\n    ASSERT_TRUE(where.operator bool());\n\n    h2d->Launch(stream.stream(), cond.ptr(), host_cond.ptr(), cond_byte_size);\n    h2d->Launch(stream.stream(), x.ptr(), host_x.ptr(), x_byte_size);\n    h2d->Launch(stream.stream(), y.ptr(), host_y.ptr(), y_byte_size);\n    where->Launch(stream.stream(), num_cond_dims, cond_dims, cond.ptr(), num_x_dims, x_dims,\n                  x.ptr(), num_y_dims, y_dims, y.ptr(), z.ptr());\n    d2h->Launch(stream.stream(), host_z.ptr(), z.ptr(), z_byte_size);\n    
CHECK_JUST(stream.stream()->Sync());\n\n    Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, Eigen::Unaligned> eigen_out(tensor_z.data(),\n                                                                                tensor_z.size());\n    Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, Eigen::Unaligned> of_out(\n        reinterpret_cast<T*>(host_z.ptr()), z_size);\n    ASSERT_TRUE(eigen_out.template isApprox(of_out));\n  }\n}\n\ntemplate<typename T, typename CondT, size_t ndim>\nvoid TestWhere(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types,\n               const std::vector<int64_t>& cond_dims, const std::vector<int64_t>& x_dims,\n               const std::vector<int64_t>& y_dims) {\n  std::vector<Device*> devices;\n  for (const auto& device_type : device_types) {\n    auto device = registry->GetDevice(device_type, 0);\n    ASSERT_TRUE(device);\n    devices.push_back(device.get());\n  }\n  TestWhere<T, CondT, ndim>(devices, cond_dims.size(), cond_dims.data(), x_dims.size(),\n                            x_dims.data(), y_dims.size(), y_dims.data());\n}\n\ntemplate<typename T, typename = void>\nstruct random {};\n\ntemplate<>\nstruct random<bool, void> {\n  bool operator()() {\n    static std::default_random_engine e;\n    static std::uniform_int_distribution<> dis(0, 1);\n    return static_cast<bool>(dis(e));\n  }\n};\n\ntemplate<typename T>\nstruct random<T, std::enable_if_t<std::is_integral<T>::value>> {\n  T operator()() {\n    static std::default_random_engine e;\n    static std::normal_distribution<> dis(0, 2);\n    return dis(e);\n  }\n};\n\ntemplate<>\nstruct random<Eigen::half, void> {\n  Eigen::half operator()() {\n    static std::default_random_engine e;\n    static std::uniform_real_distribution<> dis(-1, 1);\n    return Eigen::half{dis(e)};\n  }\n};\n\ntemplate<typename T>\nstruct random<T, std::enable_if_t<std::is_floating_point<T>::value>> {\n  T operator()() {\n    static std::default_random_engine e;\n    static std::uniform_real_distribution<> dis(-1, 1);\n    return dis(e);\n  }\n};\n\ntemplate<typename T, typename CondT>\nvoid TestScalarWhere(DeviceManagerRegistry* registry, const std::set<DeviceType>& device_types) {\n  for (const auto& device_type : device_types) {\n    auto device_ptr = registry->GetDevice(device_type, 0);\n    ASSERT_TRUE(device_ptr);\n    Device* device = device_ptr.get();\n\n    CondT cond = random<bool>()();\n    T x = random<T>()();\n    T y = random<T>()();\n    T z = cond ? 
x : y;\n\n    ep::test::PinnedMemoryGuard host_cond(device, sizeof(CondT));\n    ep::test::PinnedMemoryGuard host_x(device, sizeof(T));\n    ep::test::PinnedMemoryGuard host_y(device, sizeof(T));\n    ep::test::DeviceMemoryGuard device_cond(device, sizeof(CondT));\n    ep::test::DeviceMemoryGuard device_x(device, sizeof(T));\n    ep::test::DeviceMemoryGuard device_y(device, sizeof(T));\n    ep::test::DeviceMemoryGuard device_z(device, sizeof(T));\n    ep::test::PinnedMemoryGuard host_z(device, sizeof(T));\n\n    std::memcpy(host_cond.ptr(), &cond, sizeof(CondT));\n    std::memcpy(host_x.ptr(), &x, sizeof(T));\n    std::memcpy(host_y.ptr(), &y, sizeof(T));\n\n    ep::test::StreamGuard stream(device);\n    auto h2d = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kHtoD);\n    auto d2h = NewPrimitive<MemcpyFactory>(device_type, MemcpyKind::kDtoH);\n    auto where = NewPrimitive<WhereFactory>(device_type, GetDataType<CondT>(), GetDataType<T>(), 0);\n    ASSERT_TRUE(d2h.operator bool());\n    ASSERT_TRUE(h2d.operator bool());\n    ASSERT_TRUE(where.operator bool());\n\n    h2d->Launch(stream.stream(), device_cond.ptr(), host_cond.ptr(), sizeof(CondT));\n    h2d->Launch(stream.stream(), device_x.ptr(), host_x.ptr(), sizeof(T));\n    h2d->Launch(stream.stream(), device_y.ptr(), host_y.ptr(), sizeof(T));\n    where->Launch(stream.stream(), 0, nullptr, device_cond.ptr(), 0, nullptr, device_x.ptr(), 0,\n                  nullptr, device_y.ptr(), device_z.ptr());\n    d2h->Launch(stream.stream(), host_z.ptr(), device_z.ptr(), sizeof(T));\n    CHECK_JUST(stream.stream()->Sync());\n\n    ASSERT_TRUE(*host_z.ptr<T>() == z);\n  }\n}\n\n}  // namespace\n\nTEST_F(PrimitiveTest, TestWhere) {\n  TestWhere<float, bool, 2>(&device_manager_registry_, available_device_types_, {4, 8}, {4, 8},\n                            {4, 8});\n  TestWhere<bool, bool, 2>(&device_manager_registry_, available_device_types_, {4, 1}, {1, 8},\n                           {1, 8});\n  TestWhere<uint8_t, bool, 2>(&device_manager_registry_, available_device_types_, {4, 1}, {1, 8},\n                              {1, 8});\n  TestWhere<int32_t, bool, 2>(&device_manager_registry_, available_device_types_, {4, 1}, {1, 8},\n                              {1, 8});\n  TestWhere<Eigen::half, bool, 2>(&device_manager_registry_, available_device_types_, {4, 1},\n                                  {1, 8}, {1, 8});\n  TestWhere<double, bool, 2>(&device_manager_registry_, available_device_types_, {4, 1}, {1, 8},\n                             {1, 8});\n  TestWhere<bool, int32_t, 2>(&device_manager_registry_, available_device_types_, {1, 8}, {4, 8},\n                              {1});\n  TestWhere<int32_t, int32_t, 2>(&device_manager_registry_, available_device_types_, {1, 8}, {4, 8},\n                                 {1});\n  TestWhere<float, int32_t, 2>(&device_manager_registry_, available_device_types_, {1, 8}, {4, 8},\n                               {1});\n  TestWhere<Eigen::half, int32_t, 2>(&device_manager_registry_, available_device_types_, {1, 8},\n                                     {4, 8}, {1});\n  TestWhere<double, int32_t, 2>(&device_manager_registry_, available_device_types_, {1, 8}, {4, 8},\n                                {1});\n  TestWhere<float, bool, 2>(&device_manager_registry_, available_device_types_, {1, 6}, {2, 6},\n                            {2, 1});\n  TestWhere<float, bool, 2>(&device_manager_registry_, available_device_types_, {3, 7}, {3, 1},\n                            {1, 7});\n  TestWhere<float, bool, 
3>(&device_manager_registry_, available_device_types_, {1, 4, 8},\n                            {4, 1, 8}, {1, 1, 8});\n  TestWhere<float, bool, 3>(&device_manager_registry_, available_device_types_, {1, 4, 8},\n                            {4, 4, 8}, {1});\n  TestWhere<float, bool, 4>(&device_manager_registry_, available_device_types_, {2, 1, 4, 8},\n                            {1, 3, 4, 1}, {4, 8});\n  TestScalarWhere<bool, bool>(&device_manager_registry_, available_device_types_);\n  TestScalarWhere<float, bool>(&device_manager_registry_, available_device_types_);\n  TestScalarWhere<Eigen::half, bool>(&device_manager_registry_, available_device_types_);\n  TestScalarWhere<double, bool>(&device_manager_registry_, available_device_types_);\n  TestScalarWhere<int32_t, bool>(&device_manager_registry_, available_device_types_);\n  TestScalarWhere<bool, int32_t>(&device_manager_registry_, available_device_types_);\n  TestScalarWhere<float, int32_t>(&device_manager_registry_, available_device_types_);\n  TestScalarWhere<Eigen::half, int32_t>(&device_manager_registry_, available_device_types_);\n  TestScalarWhere<double, int32_t>(&device_manager_registry_, available_device_types_);\n  TestScalarWhere<int32_t, int32_t>(&device_manager_registry_, available_device_types_);\n}\n\n}  // namespace test\n}  // namespace primitive\n}  // namespace ep\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ep/test/test_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_EP_TEST_TEST_UTIL_\n#define ONEFLOW_CORE_EP_TEST_TEST_UTIL_\n\n#include <gtest/gtest.h>\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nnamespace test {\n\nclass TestCase : public ::testing::Test {\n protected:\n  void SetUp() override {\n    for (const auto& device_type : device_manager_registry_.GetRegisteredDeviceTypes()) {\n      // ignore mock device\n      if (device_type == DeviceType::kMockDevice || device_type == DeviceType::kMeta) { continue; }\n      if (device_manager_registry_.GetDeviceManager(device_type)->GetDeviceCount() > 0) {\n        available_device_types_.insert(device_type);\n      }\n    }\n  }\n  void TearDown() override {\n    // do nothing\n  }\n  DeviceManagerRegistry device_manager_registry_;\n  std::set<DeviceType> available_device_types_;\n};\n\nclass DeviceMemoryGuard {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DeviceMemoryGuard);\n  DeviceMemoryGuard(Device* device, size_t size) : device_(device), options_{} {\n    CHECK_JUST(device_->Alloc(options_, &ptr_, size));\n  }\n\n  ~DeviceMemoryGuard() { device_->Free(options_, ptr_); }\n\n  template<typename T = void>\n  T* ptr() {\n    return reinterpret_cast<T*>(ptr_);\n  }\n\n private:\n  Device* device_;\n  AllocationOptions options_;\n  void* ptr_{};\n};\n\nclass PinnedMemoryGuard {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PinnedMemoryGuard);\n  PinnedMemoryGuard(Device* device, size_t size) : device_(device) {\n    options_.SetPinnedDevice(device->device_type(), 0);\n    CHECK_JUST(device_->AllocPinned(options_, &ptr_, size));\n  }\n\n  ~PinnedMemoryGuard() { device_->FreePinned(options_, ptr_); }\n\n  template<typename T = void>\n  T* ptr() {\n    return reinterpret_cast<T*>(ptr_);\n  }\n\n private:\n  AllocationOptions options_;\n  Device* device_;\n  void* ptr_{};\n};\n\nclass StreamGuard {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(StreamGuard);\n  explicit StreamGuard(Device* device) : device_(device) {\n    stream_ = device_->CreateStream();\n    CHECK_JUST(stream_->OnExecutionContextSetup());\n  }\n\n  ~StreamGuard() {\n    CHECK_JUST(stream_->OnExecutionContextTeardown());\n    device_->DestroyStream(stream_);\n  }\n\n  Stream* stream() { return stream_; }\n\n private:\n  Device* device_;\n  Stream* stream_;\n};\n\n}  // namespace test\n\n}  // namespace ep\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_EP_TEST_TEST_UTIL_\n"
  },
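  {
    "path": "oneflow/core/ep/test/test_util_usage_sketch.cpp",
    "content": "/*\nEditor's note: hypothetical usage sketch, NOT part of the OneFlow source tree.\nIt shows how the RAII helpers declared in oneflow/core/ep/test/test_util.h\nabove (StreamGuard, PinnedMemoryGuard, DeviceMemoryGuard) are meant to be\ncombined; the file name and the function CopyThroughDeviceSketch are\nillustrative, and obtaining the Device* is left to the caller, mirroring the\nTestCase fixture.\n*/\n#include <cstring>\n#include \"oneflow/core/ep/test/test_util.h\"\n\nnamespace oneflow {\nnamespace ep {\nnamespace test {\n\n// Sets up a stream plus pinned-host and device buffers for `n` floats. All\n// resources are released in reverse declaration order when the guards go out\n// of scope, even if an assertion fails mid-test.\nvoid CopyThroughDeviceSketch(Device* device, size_t n) {\n  StreamGuard stream(device);\n  PinnedMemoryGuard host_src(device, n * sizeof(float));\n  PinnedMemoryGuard host_dst(device, n * sizeof(float));\n  DeviceMemoryGuard device_buf(device, n * sizeof(float));\n  for (size_t i = 0; i < n; ++i) { host_src.ptr<float>()[i] = static_cast<float>(i); }\n  // A real test would launch Memcpy primitives on stream.stream() to route\n  // host_src -> device_buf -> host_dst; that machinery lives outside\n  // test_util.h, so it is elided here.\n  std::memcpy(host_dst.ptr(), host_src.ptr(), n * sizeof(float));\n}\n\n}  // namespace test\n}  // namespace ep\n}  // namespace oneflow\n"
  },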
  {
    "path": "oneflow/core/framework/arg_tuple.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/arg_tuple.h\"\n#include <glog/logging.h>\n\nnamespace oneflow {\n\nnamespace {\n\nstd::pair<std::string, int> GetPair(const std::string& bn) {\n  int32_t index = 0;\n  const size_t pos = bn.rfind('_');\n  if (pos != std::string::npos) { index = std::stoi(bn.substr(pos + 1)); }\n  return std::make_pair(bn.substr(0, pos), index);\n}\n\nvoid InitArgName2BnIndex2TensorTupleIndex(\n    const std::vector<std::pair<std::string, int32_t>>& indexed_arg_pairs,\n    std::unordered_map<std::string, std::vector<int32_t>>* arg_name2bn_index2tensor_tuple_index) {\n  for (int i = 0; i < indexed_arg_pairs.size(); i++) {\n    const auto& pair = indexed_arg_pairs.at(i);\n    const std::string& arg_name = pair.first;\n    const int32_t bn_index = pair.second;\n    // vector is auto created by [] if arg_name doesn't exist in map\n    auto* bn_index2tensor_tuple_index = &(*arg_name2bn_index2tensor_tuple_index)[arg_name];\n    CHECK_EQ(bn_index2tensor_tuple_index->size(), bn_index)\n        << \"Duplicate index of \" << arg_name << \": \" << bn_index;\n    bn_index2tensor_tuple_index->emplace_back(i);\n  }\n}\n\n}  // namespace\n\nArgTuple::ArgTuple(const std::vector<std::string>& indexed_bns) : indexed_bns_(indexed_bns) {\n  indexed_arg_name_and_index_.reserve(indexed_bns.size());\n  for (const auto& bn : indexed_bns) { indexed_arg_name_and_index_.emplace_back(GetPair(bn)); }\n  InitArgName2BnIndex2TensorTupleIndex(indexed_arg_name_and_index_,\n                                       &arg_name2bn_index2tensor_tuple_index_);\n  for (int i = 0; i < indexed_bns.size(); ++i) {\n    bn_in_op2tensor_tuple_index_[indexed_bns.at(i)] = i;\n  }\n}\n\nint32_t ArgTuple::TensorTupleIndex4ArgNameAndIndex(const std::string& name, int32_t index) const {\n  const auto& map = arg_name2bn_index2tensor_tuple_index_;\n  const auto& iter = map.find(name);\n  if (iter == map.end()) { return -1; }\n  const auto& vec = iter->second;\n  return vec.at(index);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/arg_tuple.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_ARG_TUPLE_H_\n#define ONEFLOW_CORE_FRAMEWORK_ARG_TUPLE_H_\n\n#include <string>\n#include <vector>\n#include <unordered_map>\n\nnamespace oneflow {\n\nclass ArgTuple final {\n public:\n  explicit ArgTuple(const std::vector<std::string>& indexed_bns);\n  ~ArgTuple() = default;\n\n  std::size_t size() const { return indexed_bns_.size(); }\n\n  const std::vector<std::string>& indexed_bns() const { return indexed_bns_; }\n  const std::vector<std::pair<std::string, int32_t>>& indexed_arg_name_and_index() const {\n    return indexed_arg_name_and_index_;\n  }\n  const std::unordered_map<std::string, std::vector<int32_t>>&\n  arg_name2bn_index2tensor_tuple_index() const {\n    return arg_name2bn_index2tensor_tuple_index_;\n  }\n  const std::unordered_map<std::string, int32_t>& bn_in_op2tensor_tuple_index() const {\n    return bn_in_op2tensor_tuple_index_;\n  }\n\n  // return -1 if not found\n  int32_t TensorTupleIndex4ArgNameAndIndex(const std::string& name, int32_t index) const;\n\n private:\n  std::vector<std::string> indexed_bns_;\n  std::vector<std::pair<std::string, int32_t>> indexed_arg_name_and_index_;\n  std::unordered_map<std::string, std::vector<int32_t>> arg_name2bn_index2tensor_tuple_index_;\n  std::unordered_map<std::string, int32_t> bn_in_op2tensor_tuple_index_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_ARG_TUPLE_H_\n"
  },
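  {
    "path": "oneflow/core/framework/arg_tuple_usage_sketch.cpp",
    "content": "/*\nEditor's note: hypothetical usage sketch, NOT part of the OneFlow source tree.\nIt demonstrates the ArgTuple API declared above: indexed blob names such as\n\"in_0\"/\"in_1\" are split at the last '_' into (arg_name, index) pairs, and\nTensorTupleIndex4ArgNameAndIndex maps such a pair back to its position in the\nflat tensor tuple. The file and function names are illustrative.\n*/\n#include <cassert>\n#include \"oneflow/core/framework/arg_tuple.h\"\n\nnamespace oneflow {\n\nvoid ArgTupleUsageSketch() {\n  // \"in_0\" and \"in_1\" share the arg name \"in\"; \"out_0\" is a separate arg.\n  ArgTuple args({\"in_0\", \"in_1\", \"out_0\"});\n  assert(args.size() == 3);\n  // Flat-tuple positions follow the order of the indexed_bns vector.\n  assert(args.TensorTupleIndex4ArgNameAndIndex(\"in\", 0) == 0);\n  assert(args.TensorTupleIndex4ArgNameAndIndex(\"in\", 1) == 1);\n  assert(args.TensorTupleIndex4ArgNameAndIndex(\"out\", 0) == 2);\n  // Unknown arg names report -1 rather than throwing.\n  assert(args.TensorTupleIndex4ArgNameAndIndex(\"missing\", 0) == -1);\n}\n\n}  // namespace oneflow\n"
  },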
  {
    "path": "oneflow/core/framework/attr_map.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <fmt/ranges.h>\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/attr_value.h\"\n#include \"oneflow/core/framework/attr_value_accessor.h\"\n#include \"oneflow/core/framework/user_op_attr.pb.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n\nnamespace oneflow {\n\nAttrMap::AttrInternal::AttrInternal()\n    : max_size(0),\n      size(0),\n      hash_value(0),\n      ordered_attr_names(std::make_shared<OrderedStringList<8>>()) {}\n\nAttrMap::AttrInternal::AttrInternal(\n    size_t _max_size, size_t _size, size_t _hash_value,\n    const std::shared_ptr<OrderedStringList<8>>& _ordered_attr_names)\n    : max_size(_max_size),\n      size(_size),\n      hash_value(_hash_value),\n      ordered_attr_names(_ordered_attr_names) {}\n\nAttrMap::AttrMap() : internal_(std::make_shared<AttrMap::AttrInternal>()) {}\n\nAttrMap::AttrMap(const MutableAttrMap& other)\n    : internal_(std::make_shared<AttrMap::AttrInternal>(other.max_size(), /*size*/ 0,\n                                                        /*hash_value*/ 0,\n                                                        other.ordered_attr_names())) {\n  internal_->attrs.resize(internal_->max_size);\n  for (int i = 0; i < internal_->max_size; ++i) {\n    internal_->attrs[i].second = other.valid_masks()[i];\n    if (other.valid_masks()[i]) {\n      ++(internal_->size);\n      internal_->attrs[i].first = other.attrs()[i];\n      // compute hash code\n      HashCombine(&internal_->hash_value, other.attrs()[i]->hash_value());\n    }\n  }\n}\n\nAttrMap::AttrMap(const UserOpConf& user_conf)\n    : internal_(std::make_shared<AttrMap::AttrInternal>()) {\n  for (const auto& kv : user_conf.attr()) {\n    auto cpp_attr_value = user_op::AttrValueUtil::ToCppAttrValue(kv.second);\n    if (cpp_attr_value.IsOk()) {\n      ++(internal_->size);\n      internal_->ordered_attr_names->emplace_back(kv.first);\n      internal_->attrs.emplace_back(CHECK_JUST(cpp_attr_value), true);\n      // compute hash code\n      HashCombine(&internal_->hash_value, internal_->attrs.back().first->hash_value());\n    } else {\n      LOG(ERROR) << user_conf.DebugString()\n                 << \" failed to convert to cpp attr value, key: \" << kv.first;\n    }\n  }\n  internal_->max_size = internal_->size;\n}\n\nAttrMap& AttrMap::operator=(const AttrMap& other) {\n  internal_ = other.internal_;\n  return *this;\n}\n\nbool AttrMap::operator==(const AttrMap& other) const {\n  if (internal_->size != other.internal_->size\n      || internal_->hash_value != other.internal_->hash_value) {\n    return false;\n  }\n  for (int i = 0; i < std::min(internal_->size, other.internal_->size); ++i) {\n    if (internal_->attrs[i].second != other.internal_->attrs[i].second) { return false; }\n    if (internal_->attrs[i].second) {\n      if ((*internal_->ordered_attr_names)[i] != 
(*other.internal_->ordered_attr_names)[i]) {\n        return false;\n      }\n      if (*(internal_->attrs[i].first) != *(other.internal_->attrs[i].first)) { return false; }\n    }\n  }\n  return true;\n}\n\ntemplate<typename T>\nMaybe<const T&> AttrMap::GetAttr(const std::string& attr_name) const {\n  const auto& attr = Attr4Name(attr_name);\n  CHECK_OR_RETURN(attr) << Error::InvalidValueError()\n                        << \"no attribute found. attribute name: \" << attr_name;\n  const auto* ptr = dynamic_cast<const user_op::TypedAttrVal<T>*>(attr.get());\n  CHECK_NOTNULL_OR_RETURN(ptr) << Error::RuntimeError() << \"Ptr should be non-null\";\n  return ptr->val();\n}\n\nconst std::shared_ptr<const user_op::AttrVal>& AttrMap::Attr4Name(\n    const std::string& attr_name) const {\n  int idx = internal_->ordered_attr_names->order(attr_name);\n  if (idx >= 0) { return internal_->attrs[idx].first; }\n  static const std::shared_ptr<const user_op::AttrVal> none;\n  return none;\n}\n\nbool AttrMap::Has(const std::string& attr_name) const { return Attr4Name(attr_name) != nullptr; }\n\nAttrMap::const_iterator::const_iterator(size_t pos, const AttrMap::AttrInternal* internal)\n    : pos_(pos), internal_(internal) {\n  UpdateKV();\n}\n\nAttrMap::const_iterator& AttrMap::const_iterator::operator++() {\n  ++pos_;\n  UpdateKV();\n  return *this;\n}\n\nvoid AttrMap::const_iterator::UpdateKV() {\n  while (pos_ < internal_->max_size) {\n    if (internal_->attrs[pos_].second) { break; }\n    ++pos_;\n  }\n  if (pos_ < internal_->max_size) {\n    kv_.first = (*internal_->ordered_attr_names)[pos_];\n    kv_.second = internal_->attrs[pos_].first;\n  }\n}\n\nstd::string ComposedAttrMap::ToString() const {\n  std::vector<std::string> results;\n  for (const auto& attr : prior_) {\n    results.emplace_back(fmt::format(\"{}={}\", attr.first, attr.second->ToString()));\n  }\n  for (const auto& attr : base_) {\n    if (prior_.Has(attr.first)) { continue; }\n    results.emplace_back(fmt::format(\"{}={}\", attr.first, attr.second->ToString()));\n  }\n  return fmt::format(\"{}\", fmt::join(results, \", \"));\n}\n\nAttrMap MakeAttrMapFromUserOpConf(const UserOpConf& user_conf) { return AttrMap(user_conf); }\n\ntemplate<typename T>\nMaybe<const T&> ComposedAttrMap::GetAttr(const std::string& attr_name) const {\n  const auto& attr = Attr4Name(attr_name);\n  CHECK_OR_RETURN(attr) << Error::InvalidValueError()\n                        << \"no attribute found. attribute name: \" << attr_name;\n  const auto* ptr = dynamic_cast<const user_op::TypedAttrVal<T>*>(attr.get());\n  CHECK_NOTNULL_OR_RETURN(ptr) << Error::RuntimeError() << \"Ptr should be non-null\";\n  return ptr->val();\n}\n\nconst std::shared_ptr<const user_op::AttrVal>& ComposedAttrMap::Attr4Name(\n    const std::string& attr_name) const {\n  const auto& prior_attr = prior_.Attr4Name(attr_name);\n  if (prior_attr) { return prior_attr; }\n  return base_.Attr4Name(attr_name);\n}\n\nbool ComposedAttrMap::Has(const std::string& attr_name) const {\n  return Attr4Name(attr_name) != nullptr;\n}\n\n#define DEFINE_ATTR_VALUE_MAP_GET_ATTR(field, T, attr_type)                         \\\n  template Maybe<const T&> AttrMap::GetAttr<T>(const std::string& attr_name) const; \\\n  template Maybe<const T&> ComposedAttrMap::GetAttr<T>(const std::string& attr_name) const;\n\nOF_PP_FOR_EACH_TUPLE(DEFINE_ATTR_VALUE_MAP_GET_ATTR, ATTR_SEQ);\n#undef DEFINE_ATTR_VALUE_MAP_GET_ATTR\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/attr_map.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_ATTR_MAP_H_\n#define ONEFLOW_CORE_FRAMEWORK_ATTR_MAP_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/small_vector.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\nclass AttrVal;\n}\nclass AttrValue;\nclass MutableAttrMap;\nclass UserOpConf;\n\ntemplate<int N>\nclass OrderedStringList;\n\nclass AttrMap final {\n public:\n  AttrMap();\n  AttrMap(const MutableAttrMap& other);\n  AttrMap(const UserOpConf& user_conf);\n\n  AttrMap(const AttrMap&) = default;\n  AttrMap(AttrMap&&) = default;\n  ~AttrMap() = default;\n\n  bool Has(const std::string& attr_name) const;\n\n  template<typename T>\n  Maybe<const T&> GetAttr(const std::string& attr_name) const;\n\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(const std::string& attr_name) const;\n\n  AttrMap& operator=(const AttrMap& other);\n\n  bool operator==(const AttrMap& other) const;\n\n  size_t size() const { return internal_->size; }\n  bool empty() const { return internal_->size > 0; }\n\n  size_t hash_value() const { return internal_->hash_value; }\n\n  struct AttrInternal {\n    AttrInternal();\n    AttrInternal(size_t max_size, size_t size, size_t hash_value,\n                 const std::shared_ptr<OrderedStringList<8>>& ordered_attr_names);\n    size_t max_size;\n    size_t size;\n    size_t hash_value;\n    std::shared_ptr<OrderedStringList<8>> ordered_attr_names;\n    small_vector<std::pair<std::shared_ptr<const user_op::AttrVal>, bool>, 8> attrs;\n  };\n\n  class const_iterator {\n   public:\n    using const_reference = const std::pair<std::string, std::shared_ptr<const user_op::AttrVal>>&;\n    using const_pointer = const std::pair<std::string, std::shared_ptr<const user_op::AttrVal>>*;\n\n    const_iterator(size_t pos, const AttrInternal* internal);\n    ~const_iterator() = default;\n\n    const_reference operator*() const { return kv_; }\n    const_pointer operator->() const { return &kv_; }\n\n    const_iterator& operator++();\n    bool operator==(const const_iterator& x) const {\n      return pos_ == x.pos_ && internal_ == x.internal_;\n    }\n    bool operator!=(const const_iterator& x) const { return !(*this == x); }\n\n   private:\n    void UpdateKV();\n\n    size_t pos_;\n    const AttrInternal* internal_;\n    std::pair<std::string, std::shared_ptr<const user_op::AttrVal>> kv_;\n  };\n\n  const_iterator begin() const { return const_iterator(0, internal_.get()); }\n  const_iterator end() const { return const_iterator(internal_->max_size, internal_.get()); }\n\n private:\n  std::shared_ptr<AttrInternal> internal_;\n};\n\nAttrMap MakeAttrMapFromUserOpConf(const UserOpConf& user_conf);\n\nclass ComposedAttrMap final {\n public:\n  ComposedAttrMap(const AttrMap& base) : base_(base) {}\n  ComposedAttrMap(const AttrMap& prior, const AttrMap& base) : prior_(prior), base_(base) {}\n\n  template<typename T>\n  Maybe<const T&> GetAttr(const 
std::string& attr_name) const;\n\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(const std::string& attr_name) const;\n\n  bool Has(const std::string& attr_name) const;\n\n  void ResetPrior(const AttrMap& prior) { prior_ = prior; }\n  void ResetBase(const AttrMap& base) { base_ = base; }\n\n  std::string ToString() const;\n\n private:\n  AttrMap prior_;\n  AttrMap base_;\n};\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::AttrMap> final {\n  size_t operator()(const oneflow::AttrMap& attr_map) const { return attr_map.hash_value(); }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_ATTR_MAP_H_\n"
  },
  {
    "path": "oneflow/core/framework/attr_map_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/attr_value.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(AttrMap, basic) {\n  auto& mut_attr_map = THREAD_CACHED_MUTABLE_ATTR_MAP(\"zero\", \"one\", \"zeros\", \"ones\");\n  mut_attr_map.SetAllAttrs(static_cast<int32_t>(0), static_cast<int64_t>(1),\n                           std::vector<int32_t>{0}, std::vector<int64_t>{1});\n  AttrMap attr_map(mut_attr_map);\n  {\n    const auto& val = CHECK_JUST(attr_map.GetAttr<int32_t>(\"zero\"));\n    ASSERT_EQ(val, 0);\n  }\n  {\n    const auto& val = CHECK_JUST(attr_map.GetAttr<int64_t>(\"one\"));\n    ASSERT_EQ(val, 1);\n  }\n  {\n    const auto& val = CHECK_JUST(attr_map.GetAttr<std::vector<int32_t>>(\"zeros\"));\n    ASSERT_EQ(val.size(), 1);\n  }\n  {\n    const auto& val = CHECK_JUST(attr_map.GetAttr<std::vector<int32_t>>(\"zeros\"));\n    ASSERT_EQ(val.at(0), 0);\n  }\n  {\n    const auto& val = CHECK_JUST(attr_map.GetAttr<std::vector<int64_t>>(\"ones\"));\n    ASSERT_EQ(val.size(), 1);\n  }\n  {\n    const auto& val = CHECK_JUST(attr_map.GetAttr<std::vector<int64_t>>(\"ones\"));\n    ASSERT_EQ(val.at(0), 1);\n  }\n}\n\nTEST(AttrMap, hash_value) {\n  HashMap<AttrMap, int32_t> attr_map2int_value;\n  auto& mut_attr_map = THREAD_CACHED_MUTABLE_ATTR_MAP(\"zero\", \"one\", \"zeros\", \"ones\");\n  mut_attr_map.SetAllAttrs(static_cast<int32_t>(0), static_cast<int64_t>(1),\n                           std::vector<int32_t>{0}, std::vector<int64_t>{1});\n  ASSERT_EQ(AttrMap(mut_attr_map).hash_value(), AttrMap(mut_attr_map).hash_value());\n  ASSERT_TRUE(AttrMap(mut_attr_map) == AttrMap(mut_attr_map));\n}\n\nTEST(AttrMap, hash_map) {\n  HashMap<AttrMap, int32_t> attr_map2int_value;\n  auto& mut_attr_map = THREAD_CACHED_MUTABLE_ATTR_MAP(\"zero\", \"one\", \"zeros\", \"ones\");\n  attr_map2int_value[AttrMap(mut_attr_map)] = 0;\n  ASSERT_EQ(attr_map2int_value.at(AttrMap(mut_attr_map)), 0);\n  mut_attr_map.SetAttr<0>(static_cast<int32_t>(0));\n  attr_map2int_value[AttrMap(mut_attr_map)] = 1;\n  ASSERT_EQ(attr_map2int_value.at(AttrMap(mut_attr_map)), 1);\n  mut_attr_map.SetAttr<1>(static_cast<int64_t>(1));\n  attr_map2int_value[AttrMap(mut_attr_map)] = 2;\n  ASSERT_EQ(attr_map2int_value.at(AttrMap(mut_attr_map)), 2);\n  mut_attr_map.SetAttr<2>(std::vector<int32_t>{0});\n  attr_map2int_value[AttrMap(mut_attr_map)] = 3;\n  ASSERT_EQ(attr_map2int_value.at(AttrMap(mut_attr_map)), 3);\n  mut_attr_map.SetAttr<3>(std::vector<int64_t>{1});\n  attr_map2int_value[AttrMap(mut_attr_map)] = 4;\n  ASSERT_EQ(attr_map2int_value.at(AttrMap(mut_attr_map)), 4);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },
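  {
    "path": "oneflow/core/framework/composed_attr_map_test_sketch.cpp",
    "content": "/*\nEditor's note: hypothetical test sketch, NOT part of the OneFlow source tree.\nIt complements attr_map_test.cpp above by exercising ComposedAttrMap, whose\nlookup rule (prior first, then base) is implemented in attr_map.cpp. The attr\nnames \"alpha\"/\"beta\"/\"gamma\" are illustrative.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(ComposedAttrMap, prior_overrides_base) {\n  auto& base_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"alpha\", \"beta\");\n  base_attrs.SetAllAttrs(static_cast<int32_t>(1), static_cast<int32_t>(2));\n  auto& prior_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"beta\");\n  prior_attrs.SetAllAttrs(static_cast<int32_t>(20));\n  ComposedAttrMap composed(AttrMap(prior_attrs), AttrMap(base_attrs));\n  // \"alpha\" only exists in the base map; \"beta\" is shadowed by the prior map.\n  ASSERT_EQ(CHECK_JUST(composed.GetAttr<int32_t>(\"alpha\")), 1);\n  ASSERT_EQ(CHECK_JUST(composed.GetAttr<int32_t>(\"beta\")), 20);\n  ASSERT_TRUE(composed.Has(\"alpha\"));\n  ASSERT_FALSE(composed.Has(\"gamma\"));\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },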
  {
    "path": "oneflow/core/framework/attr_value.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_value.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nconst T& AttrValueCast(const user_op::AttrVal& attr_val) {\n  const auto* typed_attr = dynamic_cast<const user_op::TypedAttrValIf<T>*>(&attr_val);\n  return CHECK_NOTNULL(typed_attr)->val();\n}\n\ntemplate<typename T>\nstd::shared_ptr<user_op::AttrVal> CastAttrValue(const T& attr_val) {\n  return std::make_shared<user_op::TypedAttrVal<T>>(attr_val);\n}\n\ntemplate<typename T>\nstd::shared_ptr<user_op::AttrVal> CastAttrValue(const T* attr_val) {\n  return std::make_shared<user_op::TypedAttrValRef<T>>(attr_val);\n}\n\ntemplate<typename T>\nsize_t HashTypedAttrVal(const T& val) {\n  return std::hash<T>()(val);\n}\n\n#define INITIALIZE_ATTR_VALUE_CAST(field, T, attr_type)                        \\\n  template const T& AttrValueCast(const user_op::AttrVal& attr_val);           \\\n  template std::shared_ptr<user_op::AttrVal> CastAttrValue(const T& attr_val); \\\n  template std::shared_ptr<user_op::AttrVal> CastAttrValue(const T* attr_val); \\\n  template size_t HashTypedAttrVal(const T& attr_val);\n\nOF_PP_FOR_EACH_TUPLE(INITIALIZE_ATTR_VALUE_CAST, ATTR_SEQ)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/attr_value.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_ATTR_VALUE_H_\n#define ONEFLOW_CORE_FRAMEWORK_ATTR_VALUE_H_\n\n#include <complex>\n#include \"fmt/core.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/user_op_attr.pb.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/hash.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/protobuf.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nsize_t HashTypedAttrVal(const T& val);\n\nnamespace user_op {\n\n// SEQ\n#define BASIC_ATTR_SEQ                                         \\\n  OF_PP_MAKE_TUPLE_SEQ(at_int32, int32_t, AttrType::kAtInt32)  \\\n  OF_PP_MAKE_TUPLE_SEQ(at_int64, int64_t, AttrType::kAtInt64)  \\\n  OF_PP_MAKE_TUPLE_SEQ(at_bool, bool, AttrType::kAtBool)       \\\n  OF_PP_MAKE_TUPLE_SEQ(at_float, float, AttrType::kAtFloat)    \\\n  OF_PP_MAKE_TUPLE_SEQ(at_double, double, AttrType::kAtDouble) \\\n  OF_PP_MAKE_TUPLE_SEQ(at_string, std::string, AttrType::kAtString)\n\n#define ENUM_ATTR_SEQ                                                 \\\n  OF_PP_MAKE_TUPLE_SEQ(at_data_type, DataType, AttrType::kAtDataType) \\\n  OF_PP_MAKE_TUPLE_SEQ(at_memory_format, MemoryFormat, AttrType::kAtMemoryFormat)\n\n#define MESSAGE_ATTR_SEQ                                    \\\n  OF_PP_MAKE_TUPLE_SEQ(at_shape, Shape, AttrType::kAtShape) \\\n  OF_PP_MAKE_TUPLE_SEQ(at_stride, Stride, AttrType::kAtStride)\n\n#define BYTES_ATTR_SEQ OF_PP_MAKE_TUPLE_SEQ(at_bytes, std::vector<char>, AttrType::kAtBytes)\n\n#define LIST_BASIC_ATTR_SEQ                                                         \\\n  OF_PP_MAKE_TUPLE_SEQ(at_list_int32, std::vector<int32_t>, AttrType::kAtListInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(at_list_int64, std::vector<int64_t>, AttrType::kAtListInt64) \\\n  OF_PP_MAKE_TUPLE_SEQ(at_list_float, std::vector<float>, AttrType::kAtListFloat)\n\n#define LIST_ENUM_ATTR_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(at_list_data_type, std::vector<DataType>, AttrType::kAtListDataType)\n\n#define LIST_MESSAGE_ATTR_SEQ                                                     \\\n  OF_PP_MAKE_TUPLE_SEQ(at_list_shape, std::vector<Shape>, AttrType::kAtListShape) \\\n  OF_PP_MAKE_TUPLE_SEQ(at_list_stride, std::vector<Stride>, AttrType::kAtListStride)\n\n#define LIST_STRING_ATTR_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(at_list_string, std::vector<std::string>, AttrType::kAtListString)\n\n#define DEVICE_ATTR_SEQ OF_PP_MAKE_TUPLE_SEQ(at_device, Symbol<Device>, AttrType::kAtDevice)\n\n#define COMPLEX_DOUBLE_ATTR_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(at_complex_double, std::complex<double>, AttrType::kAtComplexDouble)\n\n#define ATTR_SEQ        \\\n  BASIC_ATTR_SEQ        \\\n  ENUM_ATTR_SEQ         \\\n  MESSAGE_ATTR_SEQ      \\\n  BYTES_ATTR_SEQ        \\\n  LIST_BASIC_ATTR_SEQ   \\\n  LIST_ENUM_ATTR_SEQ    \\\n  LIST_MESSAGE_ATTR_SEQ \\\n  
LIST_STRING_ATTR_SEQ  \\\n  DEVICE_ATTR_SEQ       \\\n  COMPLEX_DOUBLE_ATTR_SEQ\n\n// Type Trait: GetAttrType, GetCppType\n\ntemplate<typename T>\nstruct GetAttrType;\n\ntemplate<AttrType AttrT>\nstruct GetCppType;\n\n#define SPECIALIZE_GET_ATTR_TYPE(field, type_cpp, type_proto)                     \\\n  template<>                                                                      \\\n  struct GetAttrType<type_cpp> : std::integral_constant<AttrType, type_proto> {}; \\\n  template<>                                                                      \\\n  struct GetCppType<type_proto> {                                                 \\\n    typedef type_cpp type;                                                        \\\n  };\nOF_PP_FOR_EACH_TUPLE(SPECIALIZE_GET_ATTR_TYPE, ATTR_SEQ);\n#undef SPECIALIZE_GET_ATTR_TYPE\n\nclass AttrVal {\n public:\n  AttrVal() = default;\n  virtual ~AttrVal() = default;\n\n  virtual AttrType type() const = 0;\n  virtual size_t hash_value() const = 0;\n  virtual std::string ToString() const = 0;\n\n  virtual const void* Ptr() const = 0;\n  virtual bool operator==(const AttrVal& other) const = 0;\n  bool operator!=(const AttrVal& other) const { return !(*this == other); }\n\n private:\n  OF_DISALLOW_COPY_AND_MOVE(AttrVal);\n};\n\ntemplate<typename T>\nclass TypedAttrValIf : public AttrVal {\n public:\n  virtual const T& val() const = 0;\n  size_t hash_value() const override { return std::hash<T>()(val()); }\n  std::string ToString() const override { return fmt::format(\"{}\", val()); }\n\n  AttrType type() const override { return GetAttrType<T>::value; }\n\n  bool operator==(const AttrVal& other) const override {\n    if (other.type() != GetAttrType<T>::value) { return false; }\n    return *static_cast<const T*>(Ptr()) == *static_cast<const T*>(other.Ptr());\n  }\n};\n\ntemplate<typename T>\nclass TypedAttrVal final : public TypedAttrValIf<T> {\n public:\n  TypedAttrVal(T v) : val_(v) {}\n  ~TypedAttrVal() = default;\n\n  const T& val() const override { return val_; }\n  const void* Ptr() const override { return static_cast<const void*>(&val_); }\n\n  size_t hash_value() const override { return std::hash<T>()(val_); }\n\n private:\n  OF_DISALLOW_COPY_AND_MOVE(TypedAttrVal);\n\n  T val_;\n};\n\ntemplate<typename T>\nclass TypedAttrValRef final : public TypedAttrValIf<T> {\n public:\n  TypedAttrValRef(const T* v) : val_(v) {}\n  ~TypedAttrValRef() = default;\n\n  const T& val() const override { return *val_; }\n  const void* Ptr() const override { return static_cast<const void*>(val_); }\n\n  size_t hash_value() const override { return std::hash<T>()(*val_); }\n\n private:\n  OF_DISALLOW_COPY_AND_MOVE(TypedAttrValRef);\n\n  const T* val_;\n};\n\n}  // namespace user_op\n\ntemplate<typename T>\nconst T& AttrValueCast(const user_op::AttrVal& val);\n\ntemplate<typename T>\nstd::shared_ptr<user_op::AttrVal> CastAttrValue(const T& attr_val);\n\ntemplate<typename T>\nstd::shared_ptr<user_op::AttrVal> CastAttrValue(const T* attr_val);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_ATTR_VALUE_H_\n"
  },
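  {
    "path": "oneflow/core/framework/attr_value_usage_sketch.cpp",
    "content": "/*\nEditor's note: hypothetical sketch, NOT part of the OneFlow source tree. It\nillustrates the two AttrVal flavors declared in attr_value.h above:\nTypedAttrVal<T> owns a copy of the value, while TypedAttrValRef<T> only points\nat caller-owned storage; CastAttrValue selects between them by overload. The\nfile and function names are illustrative.\n*/\n#include <cassert>\n#include <memory>\n#include \"oneflow/core/framework/attr_value.h\"\n\nnamespace oneflow {\n\nvoid AttrValUsageSketch() {\n  const int32_t forty_two = 42;\n  // CastAttrValue(const T&) builds an owning TypedAttrVal<int32_t>.\n  std::shared_ptr<user_op::AttrVal> owned = CastAttrValue(forty_two);\n  // CastAttrValue(const T*) builds a non-owning TypedAttrValRef<int32_t>;\n  // forty_two must outlive it.\n  std::shared_ptr<user_op::AttrVal> ref = CastAttrValue(&forty_two);\n  // operator== only inspects the attr type and the pointed-to values, so\n  // ownership does not matter for equality.\n  assert(*owned == *ref);\n  assert(owned->type() == AttrType::kAtInt32);\n  // AttrValueCast recovers the typed value through TypedAttrValIf<T>.\n  assert(AttrValueCast<int32_t>(*owned) == 42);\n}\n\n}  // namespace oneflow\n"
  },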
  {
    "path": "oneflow/core/framework/attr_value_accessor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/attr_value_accessor.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/framework/attr_value.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\n// Basic and Enum Attr\n#define BASIC_AND_ENUM_ATTR_SEQ_ENTRY(field, cpp_type, attr_type)                        \\\n  template<>                                                                             \\\n  cpp_type AttrValueAccessor<cpp_type>::Attr(const AttrValue& val) {                     \\\n    CHECK(val.has_##field());                                                            \\\n    return val.field();                                                                  \\\n  }                                                                                      \\\n  template<>                                                                             \\\n  void AttrValueAccessor<cpp_type>::Attr(const cpp_type& cpp_val, AttrValue* attr_val) { \\\n    attr_val->set_##field(cpp_val);                                                      \\\n  }\n\n#define BASIC_AND_ENUM_ATTR_SEQ \\\n  BASIC_ATTR_SEQ                \\\n  ENUM_ATTR_SEQ\n\nOF_PP_FOR_EACH_TUPLE(BASIC_AND_ENUM_ATTR_SEQ_ENTRY, BASIC_AND_ENUM_ATTR_SEQ)\n\n#undef BASIC_AND_ENUM_ATTR_SEQ\n#undef BASIC_AND_ENUM_ATTR_SEQ_ENTRY\n\n// Customized Message Attr\ntemplate<>\nShape AttrValueAccessor<Shape>::Attr(const AttrValue& val) {\n  return Shape(val.at_shape());\n}\n\ntemplate<>\nvoid AttrValueAccessor<Shape>::Attr(const Shape& cpp_val, AttrValue* attr_val) {\n  cpp_val.ToProto(attr_val->mutable_at_shape());\n}\n\ntemplate<>\nStride AttrValueAccessor<Stride>::Attr(const AttrValue& val) {\n  return Stride(val.at_stride());\n}\n\ntemplate<>\nvoid AttrValueAccessor<Stride>::Attr(const Stride& cpp_val, AttrValue* attr_val) {\n  cpp_val.ToProto(attr_val->mutable_at_stride());\n}\n\ntemplate<>\nSymbol<Device> AttrValueAccessor<Symbol<Device>>::Attr(const AttrValue& val) {\n  auto pb_device = val.at_device();\n  return CHECK_JUST(Device::New(*CHECK_JUST(DeviceTag4DeviceType(pb_device.device_type())),\n                                pb_device.device_id(), pb_device.rematable()));\n}\n\ntemplate<>\nvoid AttrValueAccessor<Symbol<Device>>::Attr(const Symbol<Device>& cpp_val, AttrValue* attr_val) {\n  attr_val->mutable_at_device()->set_device_type(cpp_val->enum_type());\n  attr_val->mutable_at_device()->set_device_id(cpp_val->device_id());\n  attr_val->mutable_at_device()->set_rematable(cpp_val->rematable());\n}\n\ntemplate<>\nstd::vector<char> AttrValueAccessor<std::vector<char>>::Attr(const AttrValue& val) {\n  return std::vector<char>(val.at_bytes().begin(), 
val.at_bytes().end());\n}\n\ntemplate<>\nvoid AttrValueAccessor<std::vector<char>>::Attr(const std::vector<char>& cpp_val,\n                                                AttrValue* attr_val) {\n  attr_val->mutable_at_bytes()->assign(cpp_val.data(), cpp_val.size());\n}\n\n// List of Basic Attr\n#define LIST_BASIC_ATTR_SEQ_ENTRY(field, cpp_type, attr_type)                                   \\\n  template<>                                                                                    \\\n  cpp_type AttrValueAccessor<cpp_type>::Attr(const AttrValue& val) {                            \\\n    return PbRf2StdVec<cpp_type::value_type>(val.field().val());                                \\\n  }                                                                                             \\\n  template<>                                                                                    \\\n  void AttrValueAccessor<cpp_type>::Attr(const cpp_type& cpp_val, AttrValue* attr_val) {        \\\n    *(attr_val->mutable_##field()->mutable_val()) = StdVec2PbRf<cpp_type::value_type>(cpp_val); \\\n  }\n\nOF_PP_FOR_EACH_TUPLE(LIST_BASIC_ATTR_SEQ_ENTRY, LIST_BASIC_ATTR_SEQ)\n\n#undef LIST_BASIC_ATTR_SEQ_ENTRY\n\n// List of Enum Attr\n#define LIST_ENUM_ATTR_SEQ_ENTRY(field, cpp_type, attr_type)                                   \\\n  template<>                                                                                   \\\n  cpp_type AttrValueAccessor<cpp_type>::Attr(const AttrValue& val) {                           \\\n    std::vector<cpp_type::value_type> ret;                                                     \\\n    ret.reserve(val.field().val_size());                                                       \\\n    for (const auto& value : val.field().val()) {                                              \\\n      ret.emplace_back(static_cast<cpp_type::value_type>(value));                              \\\n    }                                                                                          \\\n    return ret;                                                                                \\\n  }                                                                                            \\\n  template<>                                                                                   \\\n  void AttrValueAccessor<cpp_type>::Attr(const cpp_type& cpp_val, AttrValue* attr_val) {       \\\n    using proto_type = std::remove_reference_t<decltype(attr_val->field().val())>::value_type; \\\n    std::vector<proto_type> vec;                                                               \\\n    vec.reserve(cpp_val.size());                                                               \\\n    for (const auto& value : cpp_val) { vec.emplace_back(static_cast<proto_type>(value)); }    \\\n    *(attr_val->mutable_##field()->mutable_val()) = StdVec2PbRf<proto_type>(vec);              \\\n  }\n\nOF_PP_FOR_EACH_TUPLE(LIST_ENUM_ATTR_SEQ_ENTRY, LIST_ENUM_ATTR_SEQ)\n\n#undef LIST_ENUM_ATTR_SEQ_ENTRY\n\n// List of Customized Message Attr\ntemplate<>\nstd::vector<Shape> AttrValueAccessor<std::vector<Shape>>::Attr(const AttrValue& val) {\n  std::vector<Shape> ret;\n  ret.reserve(val.at_list_shape().val_size());\n  for (const auto& value : val.at_list_shape().val()) { ret.emplace_back(value); }\n  return ret;\n}\ntemplate<>\nvoid AttrValueAccessor<std::vector<Shape>>::Attr(const std::vector<Shape>& cpp_val,\n                                                 AttrValue* attr_val) {\n  
attr_val->mutable_at_list_shape()->clear_val();\n  FOR_RANGE(int32_t, i, 0, cpp_val.size()) {\n    cpp_val.at(i).ToProto(attr_val->mutable_at_list_shape()->add_val());\n  }\n}\ntemplate<>\nstd::vector<Stride> AttrValueAccessor<std::vector<Stride>>::Attr(const AttrValue& val) {\n  std::vector<Stride> ret;\n  ret.reserve(val.at_list_stride().val_size());\n  for (const auto& value : val.at_list_stride().val()) { ret.emplace_back(value); }\n  return ret;\n}\ntemplate<>\nvoid AttrValueAccessor<std::vector<Stride>>::Attr(const std::vector<Stride>& cpp_val,\n                                                  AttrValue* attr_val) {\n  attr_val->mutable_at_list_stride()->clear_val();\n  FOR_RANGE(int32_t, i, 0, cpp_val.size()) {\n    cpp_val.at(i).ToProto(attr_val->mutable_at_list_stride()->add_val());\n  }\n}\n// List of String Attr\ntemplate<>\nstd::vector<std::string> AttrValueAccessor<std::vector<std::string>>::Attr(const AttrValue& val) {\n  return PbRpf2StdVec<std::string>(val.at_list_string().val());\n}\ntemplate<>\nvoid AttrValueAccessor<std::vector<std::string>>::Attr(const std::vector<std::string>& cpp_val,\n                                                       AttrValue* attr_val) {\n  *(attr_val->mutable_at_list_string()->mutable_val()) = StdVec2PbRpf<std::string>(cpp_val);\n}\n// ComplexDouble Attr\ntemplate<>\nstd::complex<double> AttrValueAccessor<std::complex<double>>::Attr(const AttrValue& val) {\n  std::complex<double> ret{val.at_complex_double().real(), val.at_complex_double().imag()};\n  return ret;\n}\ntemplate<>\nvoid AttrValueAccessor<std::complex<double>>::Attr(const std::complex<double>& cpp_val,\n                                                   AttrValue* attr_val) {\n  attr_val->mutable_at_complex_double()->set_real(cpp_val.real());\n  attr_val->mutable_at_complex_double()->set_imag(cpp_val.imag());\n}\n\ntemplate<typename ProtoT>\nMaybe<AttrVal> MakeCppAttrValueFromProtoAttrValue(const ProtoT& attr_value) {\n  switch (static_cast<int>(attr_value.value_case())) {\n#define MAKE_ENTRY(field, T, attr_type)       \\\n  case static_cast<int>(attr_type):           \\\n    return std::static_pointer_cast<AttrVal>( \\\n        std::make_shared<TypedAttrVal<T>>(AttrValueAccessor<T>::Attr(attr_value)));\n    OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ATTR_SEQ);\n#undef MAKE_ENTRY\n    default: OF_UNIMPLEMENTED();\n  }\n}\n\n/* static */ Maybe<AttrVal> AttrValueUtil::ToCppAttrValue(const AttrValue& proto_attr_value) {\n  return MakeCppAttrValueFromProtoAttrValue(proto_attr_value);\n}\n\n/* static */ Maybe<void> AttrValueUtil::ToProtoAttrValue(const AttrVal& cpp_attr_value,\n                                                         AttrValue* attr_value) {\n  if (false) {\n// clang-format off\n#define MAKE_ENTRY(field, cpp_type, attr_type)                                        \\\n  }                                                                                   \\\n  else if (dynamic_cast<const TypedAttrValIf<cpp_type>*>(&cpp_attr_value) != nullptr) { \\\n    const auto* ptr = dynamic_cast<const TypedAttrValIf<cpp_type>*>(&cpp_attr_value);   \\\n    AttrValueAccessor<cpp_type>::Attr(ptr->val(), attr_value);\n    OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ATTR_SEQ);\n#undef MAKE_ENTRY\n    // clang-format on\n  } else {\n    OF_UNIMPLEMENTED();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/attr_value_accessor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_ATTR_VAL_ACCESSOR_H_\n#define ONEFLOW_CORE_FRAMEWORK_ATTR_VAL_ACCESSOR_H_\n\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nclass AttrValue;\n\nnamespace user_op {\n\ntemplate<typename T>\nstruct AttrValueAccessor final {\n  static T Attr(const AttrValue&);\n  static void Attr(const T&, AttrValue*);\n};\n\nclass AttrVal;\n\nstruct AttrValueUtil final {\n  static Maybe<AttrVal> ToCppAttrValue(const AttrValue& proto_attr_value);\n  static Maybe<void> ToProtoAttrValue(const AttrVal& cpp_attr_value, AttrValue* attr_value);\n};\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_ATTR_VAL_ACCESSOR_H_\n"
  },
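  {
    "path": "oneflow/core/framework/attr_value_accessor_usage_sketch.cpp",
    "content": "/*\nEditor's note: hypothetical sketch, NOT part of the OneFlow source tree. It\nshows the proto <-> C++ round trip that AttrValueAccessor (declared above)\nprovides, using the customized message attr Shape handled in\nattr_value_accessor.cpp. The file and function names are illustrative.\n*/\n#include <cassert>\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/framework/attr_value_accessor.h\"\n#include \"oneflow/core/framework/user_op_attr.pb.h\"\n\nnamespace oneflow {\n\nvoid ShapeAttrRoundTripSketch() {\n  const Shape shape({2, 3, 4});\n  AttrValue proto_val;\n  // C++ -> proto: serializes into the at_shape submessage.\n  user_op::AttrValueAccessor<Shape>::Attr(shape, &proto_val);\n  // proto -> C++: rebuilds an equal Shape from the submessage.\n  const Shape restored = user_op::AttrValueAccessor<Shape>::Attr(proto_val);\n  assert(restored == shape);\n}\n\n}  // namespace oneflow\n"
  },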
  {
    "path": "oneflow/core/framework/auto_random_generator.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/auto_random_generator.h\"\n\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#include \"oneflow/core/platform/include/pthread_fork.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstruct AutoGeneratorState {\n  uint64_t seed = 0;\n  int64_t num = 0;\n  int64_t device_tag_length = 0;\n  int64_t state_length = 0;\n  // std::vector<int64_t> state_sizes[num];\n  // std::vector<uint8_t> device_tags[device_tag_length];\n  // std::vector<uint8_t> states[state_sizes[0] + state_sizes[1] + ... + state_sizes[num - 1]]\n};\n\nvoid AutoGenerator::set_current_seed(uint64_t seed) {\n  std::lock_guard<std::mutex> lock(mutex_);\n  seed_ = seed;\n  for (const auto& it : generators_) {\n    if (unlikely(pthread_fork::IsForkedSubProcess() && it.first->type() != \"cpu\")) { continue; }\n    it.second->set_current_seed(seed);\n  }\n}\n\nsize_t AutoGenerator::GetStateSize() const {\n  std::lock_guard<std::mutex> lock(mutex_);\n  size_t state_size = sizeof(AutoGeneratorState) + generators_.size() * sizeof(uint64_t);\n  std::stringstream ss;\n  auto it = generators_.begin();\n  if (it != generators_.end()) {\n    ss << it->second->device_type_name() << \":\" << it->second->device_index();\n    ++it;\n  }\n  for (; it != generators_.end(); ++it) {\n    ss << \",\" << it->second->device_type_name() << \":\" << it->second->device_index();\n  }\n  state_size += ss.str().size();\n  for (const auto& it : generators_) { state_size += it.second->GetStateSize(); }\n  return state_size;\n}\n\nvoid AutoGenerator::GetState(size_t state_size, void* state) const {\n  std::lock_guard<std::mutex> lock(mutex_);\n  AutoGeneratorState state_info;\n  state_info.seed = current_seed();\n  state_info.num = generators_.size();\n  state_info.state_length = 0;\n  std::vector<int64_t> state_sizes;\n  state_sizes.reserve(generators_.size());\n\n  for (auto it = generators_.begin(); it != generators_.end(); ++it) {\n    state_sizes.emplace_back(it->second->GetStateSize());\n    state_info.state_length += state_sizes.back();\n  }\n  std::stringstream ss;\n  auto it = generators_.begin();\n  if (it != generators_.end()) {\n    ss << it->second->device_type_name() << \":\" << it->second->device_index();\n    ++it;\n  }\n  for (; it != generators_.end(); ++it) {\n    ss << \",\" << it->second->device_type_name() << \":\" << it->second->device_index();\n  }\n\n  std::string device_tags = ss.str();\n  state_info.device_tag_length = device_tags.size();\n  size_t total_size = sizeof(AutoGeneratorState) + state_info.num * sizeof(int64_t)\n                      + state_info.device_tag_length + state_info.state_length;\n  CHECK_EQ_OR_THROW(state_size, total_size)\n      << \"the state size of 
the auto generator should be equal to \" << total_size;\n  {\n    uint8_t* data = static_cast<uint8_t*>(state);\n    memcpy(data, &state_info, sizeof(AutoGeneratorState));\n    data += sizeof(AutoGeneratorState);\n    memcpy(data, state_sizes.data(), state_info.num * sizeof(int64_t));\n    data += state_info.num * sizeof(int64_t);\n    memcpy(data, device_tags.data(), state_info.device_tag_length);\n    data += state_info.device_tag_length;\n    int i = 0;\n    for (auto it = generators_.begin(); it != generators_.end(); ++it, ++i) {\n      it->second->GetState(state_sizes[i], data);\n      data += state_sizes[i];\n    }\n  }\n}\n\nvoid AutoGenerator::SetState(size_t state_size, const void* state) {\n  AutoGeneratorState state_info;\n  const uint8_t* data = static_cast<const uint8_t*>(state);\n  memcpy(reinterpret_cast<void*>(&state_info), data, sizeof(AutoGeneratorState));\n  if (state_size\n      != sizeof(AutoGeneratorState) + state_info.num * sizeof(int64_t)\n             + state_info.device_tag_length + state_info.state_length) {\n    return THROW(RuntimeError) << \"Invalid auto generator state: size does not match.\";\n  }\n  data += sizeof(AutoGeneratorState);\n  std::vector<int64_t> state_sizes(state_info.num);\n  std::vector<const void*> state_data(state_info.num);\n  memcpy(state_sizes.data(), data, state_info.num * sizeof(int64_t));\n  data += state_info.num * sizeof(int64_t);\n  std::string device_tags;\n  device_tags.resize(state_info.device_tag_length);\n  memcpy(const_cast<char*>(device_tags.data()), data, state_info.device_tag_length);\n  data += state_info.device_tag_length;\n\n  for (int i = 0; i < state_info.num; ++i) {\n    state_data[i] = data;\n    data += state_sizes[i];\n  }\n  // set current seed.\n  set_current_seed(state_info.seed);\n\n  std::vector<std::string> splits;\n  Split(device_tags, \",\", [&](std::string&& s) { splits.emplace_back(s); });\n  if (splits.size() != state_info.num) {\n    return THROW(RuntimeError) << \"Invalid auto generator state. The number of states is \"\n                               << state_info.num << \", but the number of device tags is \"\n                               << splits.size();\n  }\n  std::lock_guard<std::mutex> lock(mutex_);\n\n  for (int i = 0; i < splits.size(); ++i) {\n    const auto& device = CHECK_JUST(Device::ParseAndNew(splits[i]));\n    auto generator = CHECK_JUST(GetOrCreate(device->type(), device->device_id()));\n    generator->SetState(state_sizes[i], state_data[i]);\n  }\n}\n\nMaybe<ep::RandomGenerator> AutoGenerator::GetOrCreate(const std::string& device, int device_index) {\n  if (device_index == -1) { device_index = (device == \"cpu\" ? 0 : GlobalProcessCtx::LocalRank()); }\n  std::lock_guard<std::mutex> lock(mutex_);\n  auto device_key = JUST(Device::New(device, device_index));\n  auto it = generators_.find(device_key);\n  if (it == generators_.end()) {\n    auto device_type = ep::DeviceManagerRegistry::GetDeviceTypeByDeviceTypeName(device);\n    if (device_type == DeviceType::kInvalidDevice) {\n      return Error::RuntimeError() << \"Expected one of \" << PrintGeneratorAvailableDevices()\n                                   << \" device type at start of device string: \" << device;\n    }\n    auto device_mgr = Singleton<ep::DeviceManagerRegistry>::Get()->GetDeviceManager(device_type);\n    it = generators_.emplace(device_key, device_mgr->CreateRandomGenerator(seed_, device_index))\n             .first;\n  }\n  return it->second;\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/auto_random_generator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_AUTO_RANDOM_GENERATOR_H_\n#define ONEFLOW_CORE_FRAMEWORK_AUTO_RANDOM_GENERATOR_H_\n\n#include <mutex>\n#include <unordered_map>\n#include <vector>\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/ep/include/random_generator.h\"\n#include \"oneflow/core/framework/device.h\"\n\nnamespace oneflow {\nnamespace one {\n\nclass AutoGenerator : public ep::RandomGenerator {\n public:\n  AutoGenerator(uint64_t seed) : seed_(seed) {}\n  virtual ~AutoGenerator() = default;\n\n  uint64_t current_seed() const override { return seed_; }\n  void set_current_seed(uint64_t seed) override;\n\n  std::string device_type_name() const override { return \"auto\"; }\n  int64_t device_index() const override { return 0; }\n\n  size_t GetStateSize() const override;\n  void GetState(size_t state_size, void* state) const override;\n  void SetState(size_t state_size, const void* state) override;\n\n  Maybe<ep::RandomGenerator> GetOrCreate(const std::string& device, int device_index);\n\n  template<typename T>\n  Maybe<T> GetOrCreate(int device_index) {\n    return std::dynamic_pointer_cast<T>(\n        JUST(GetOrCreate(ep::GetRandomGeneratorDeviceTypeName<T>(), device_index)));\n  }\n\n private:\n  mutable std::mutex mutex_;\n  uint64_t seed_;\n  std::unordered_map<Symbol<Device>, std::shared_ptr<ep::RandomGenerator>> generators_;\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_AUTO_RANDOM_GENERATOR_H_\n"
  },
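  {
    "path": "oneflow/core/framework/auto_random_generator_layout_sketch.cpp",
    "content": "/*\nEditor's note: hypothetical sketch, NOT part of the OneFlow source tree. It\nspells out the flat byte layout that AutoGenerator::GetState and SetState\n(see auto_random_generator.cpp above) agree on, which the comments inside the\nprivate AutoGeneratorState struct only hint at. All names here are\nillustrative mirrors of that struct.\n*/\n#include <cstddef>\n#include <cstdint>\n#include <string>\n#include <vector>\n\n// The serialized buffer is, in order:\n//   [header]        sizeof(AutoGeneratorStateSketch) bytes\n//   [state_sizes]   num * sizeof(int64_t), one entry per child generator\n//   [device_tags]   device_tag_length bytes, e.g. \"cpu:0,cuda:0\" (12 bytes,\n//                   comma separated, no terminator)\n//   [child states]  the children's own state blobs back to back,\n//                   state_length bytes in total\nstruct AutoGeneratorStateSketch {\n  uint64_t seed = 0;\n  int64_t num = 0;\n  int64_t device_tag_length = 0;\n  int64_t state_length = 0;\n};\n\n// Reproduces the size bookkeeping of GetStateSize/GetState; SetState rejects\n// any buffer whose size differs from this sum.\nsize_t TotalStateSizeSketch(const std::vector<int64_t>& child_sizes,\n                            const std::string& device_tags) {\n  size_t total = sizeof(AutoGeneratorStateSketch);\n  total += child_sizes.size() * sizeof(int64_t);  // state_sizes block\n  total += device_tags.size();                    // device_tags block\n  for (int64_t s : child_sizes) { total += s; }   // concatenated child states\n  return total;\n}\n"
  },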
  {
    "path": "oneflow/core/framework/autocast.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/framework/autocast.h\"\n#include \"oneflow/core/job_rewriter/auto_mixed_precision.h\"\n#include \"oneflow/core/job_rewriter/auto_mixed_precision_lists.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace autocast {\n\nnamespace {\n\nbool* autocast_enabled() {\n  static thread_local bool autocast_enabled = false;\n  return &autocast_enabled;\n}\nDeviceType* autocast_device_type() {\n  static thread_local DeviceType autocast_device_type = kCUDA;\n  return &autocast_device_type;\n}\nSymbol<DType>* autocast_dtype() {\n  static thread_local Symbol<DType> autocast_dtype = DType::Float16();\n  return &autocast_dtype;\n}\nSymbol<DType>* autocast_cpu_dtype() {\n  static thread_local Symbol<DType> autocast_cpu_dtype = DType::BFloat16();\n  return &autocast_cpu_dtype;\n}\nSymbol<DType>* autocast_gpu_dtype() {\n  static thread_local Symbol<DType> autocast_gpu_dtype = DType::Float16();\n  return &autocast_gpu_dtype;\n}\nbool* cache_enabled() {\n  static thread_local bool cache_enabled = true;\n  return &cache_enabled;\n}\n\ninline Symbol<DType> get_lower_precision_fp_from_device_type(DeviceType device_type) {\n  if (device_type == DeviceType::kCPU) { return get_autocast_cpu_dtype(); };\n  return get_autocast_gpu_dtype();\n}\n\n// The structure below is referenced from PyTorch:\n// https://github.com/pytorch/pytorch/blob/41d79695907cd4105b8e7167cf8a57ba48e1f079/aten/src/ATen/autocast_mode.cpp#L60-L63\n// The weakref keeps the source's TensorImpl from being deleted.  We need to because we're\n// using the source TensorImpl* as the key.  If it were deleted, another random Tensor could\n// be allocated whose TensorImpl* happened to have the same value.  
This TensorImpl* would\n// then mistakenly hit in cache: a rare, intermittent, unpredictable bug.\nusing val_type = std::pair<std::weak_ptr<one::Tensor>, std::shared_ptr<one::Tensor>>;\nusing key_type = std::pair<const one::EagerLocalTensorImpl*, DataType>;\nusing cached_map = std::unordered_map<key_type, val_type>;\n\nstd::unordered_map<key_type, val_type>* cached_casts() {\n  static thread_local std::unordered_map<key_type, val_type> cached_casts;\n  return &cached_casts;\n}\n\n}  // namespace\n\nbool is_enabled() { return *autocast_enabled(); }\nvoid set_enabled(bool enabled) { *autocast_enabled() = enabled; }\n\nDeviceType get_autocast_device_type() { return *autocast_device_type(); }\nvoid set_autocast_device_type(DeviceType device_type) { *autocast_device_type() = device_type; }\n\nSymbol<DType> get_autocast_dtype() { return *autocast_dtype(); }\nSymbol<DType> get_autocast_cpu_dtype() { return *autocast_cpu_dtype(); }\nSymbol<DType> get_autocast_gpu_dtype() { return *autocast_gpu_dtype(); }\n\nvoid set_autocast_dtype(Symbol<DType> dtype) { *autocast_dtype() = dtype; }\nvoid set_autocast_cpu_dtype(Symbol<DType> dtype) { *autocast_cpu_dtype() = dtype; }\nvoid set_autocast_gpu_dtype(Symbol<DType> dtype) { *autocast_gpu_dtype() = dtype; }\n\nbool is_autocast_cache_enabled() { return *cache_enabled(); }\nvoid set_autocast_cache_enabled(bool enabled) { *cache_enabled() = enabled; }\n\nMaybe<one::Tensor> cached_cast(const std::shared_ptr<one::Tensor>& tensor, Symbol<DType> cast_type,\n                               DeviceType device_type) {\n  bool use_cache = (is_autocast_cache_enabled() && tensor->requires_grad()\n                    && cast_type == get_lower_precision_fp_from_device_type(device_type)\n                    && tensor->dtype()->data_type() == DataType::kFloat && tensor->is_leaf()\n                    && !tensor->is_view());\n  if (use_cache) {\n    auto it = cached_casts()->find(\n        std::make_pair(JUST(tensor->mut_eager_local_tensor_impl()), cast_type->data_type()));\n    if (it == cached_casts()->end() || it->second.first.lock() == nullptr) {\n      const std::shared_ptr<one::Tensor>& result =\n          JUST(one::functional::To(tensor, cast_type, /*copy*/ false));\n      if (it == cached_casts()->end()) {\n        cached_casts()->emplace(\n            std::make_pair(JUST(tensor->mut_eager_local_tensor_impl()), cast_type->data_type()),\n            std::make_pair(tensor->weak_from_this(), result));\n      } else {\n        it->second.first = tensor->weak_from_this();\n        it->second.second = result;\n      }\n      return result;\n    } else {\n      return it->second.second;\n    }\n  } else {\n    return one::functional::To(tensor, cast_type, /*copy*/ false);\n  }\n};\n\nvoid clear_cache() { cached_casts()->clear(); }\n\nAutoCastColor AutoCastMeta::autocast_color() const { return autocast_color_; }\n\nvoid AutoCastMeta::set_autocast_color(AutoCastColor color) { autocast_color_ = color; }\n\nbool AutoCastMeta::is_autocast_eligible(DeviceType device_type, Symbol<DType> dtype) const {\n  int device_index = static_cast<int>(device_type);\n  if (is_autocast_eligible_.size() > device_index) {\n    int dtype_index = static_cast<int>(dtype->data_type());\n    if (is_autocast_eligible_[device_index].size() > dtype_index) {\n      return is_autocast_eligible_[device_index][dtype_index];\n    }\n  }\n  return false;\n}\n\nvoid AutoCastMeta::set_autocast_eligible(DeviceType device_type, Symbol<DType> dtype) {\n  int device_index = static_cast<int>(device_type);\n  while 
(is_autocast_eligible_.size() <= device_index) {\n    is_autocast_eligible_.resize(device_index + 1);\n  }\n  int dtype_index = static_cast<int>(dtype->data_type());\n  while (is_autocast_eligible_[device_index].size() <= dtype_index) {\n    is_autocast_eligible_[device_index].resize(dtype_index + 1);\n  }\n  is_autocast_eligible_[device_index][dtype_index] = true;\n}\n\nbool AutoCastMeta::is_args_autocast_eligible(int arg_index) const {\n  CHECK_LT_OR_THROW(arg_index, is_args_autocast_eligible_.size());  // NOLINT\n  return is_args_autocast_eligible_[arg_index];\n}\n\nconst std::vector<bool>& AutoCastMeta::is_args_autocast_eligible() const {\n  return is_args_autocast_eligible_;\n}\n\nvoid AutoCastMeta::set_arg_autocast_eligible(int arg_index) {\n  CHECK_LT_OR_THROW(arg_index, is_args_autocast_eligible_.size());  // NOLINT\n  is_args_autocast_eligible_[arg_index] = true;\n}\n\nstd::shared_ptr<AutoCastMeta> MakeAutoCastMeta(\n    const std::string& op_type_name,\n    const std::vector<std::pair<std::string, int32_t>>& input_args) {\n  auto autocast_meta = std::make_shared<AutoCastMeta>(input_args.size());\n  if (AutoMixedPrecisionLists::WhiteList().count(op_type_name)) {\n    autocast_meta->set_autocast_color(kWhite);\n  } else if (AutoMixedPrecisionLists::GrayList().count(op_type_name)) {\n    autocast_meta->set_autocast_color(kGray);\n  } else if (AutoMixedPrecisionLists::ClearList().count(op_type_name)) {\n    autocast_meta->set_autocast_color(kClear);\n  } else {\n    autocast_meta->set_autocast_color(kBlack);\n  }\n  for (int i = 0; i < input_args.size(); ++i) {\n    if (!amp::IsNoCast(op_type_name, input_args[i])) {\n      autocast_meta->set_arg_autocast_eligible(i);\n    }\n  }\n  // autocast only supports the following device type(s) and low precision type(s):\n  //   - device type: CUDA\n  //   - low precision type: half, bfloat16\n  static std::vector<DeviceType> autocast_device_types{kCUDA};\n  static std::vector<Symbol<DType>> autocast_dtypes{DType::Float16(), DType::BFloat16()};\n\n  if (autocast_meta->autocast_color() != kBlack) {\n    for (auto device_type : autocast_device_types) {\n      for (auto dtype : autocast_dtypes) {\n        autocast_meta->set_autocast_eligible(device_type, dtype);\n      }\n    }\n  }\n  return autocast_meta;\n}\n\n}  // namespace autocast\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/autocast.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_AUTOCAST_H_\n#define ONEFLOW_CORE_FRAMEWORK_AUTOCAST_H_\n\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/tensor.h\"\n\nnamespace oneflow {\nnamespace autocast {\n\nbool is_enabled();\nvoid set_enabled(bool enabled);\n\nDeviceType get_autocast_device_type();\nvoid set_autocast_device_type(DeviceType device_type);\n\nSymbol<DType> get_autocast_dtype();\nSymbol<DType> get_autocast_cpu_dtype();\nSymbol<DType> get_autocast_gpu_dtype();\n\nvoid set_autocast_dtype(Symbol<DType> dtype);\nvoid set_autocast_cpu_dtype(Symbol<DType> dtype);\nvoid set_autocast_gpu_dtype(Symbol<DType> dtype);\n\nbool is_autocast_cache_enabled();\nvoid set_autocast_cache_enabled(bool enabled);\nvoid clear_cache();\nMaybe<one::Tensor> cached_cast(const std::shared_ptr<one::Tensor>& tensor, Symbol<DType> cast_type,\n                               DeviceType device_type);\n\nenum AutoCastColor { kNoColor, kWhite, kGray, kClear, kBlack };\n\nclass AutoCastMeta final {\n public:\n  AutoCastMeta() : AutoCastMeta(0) {}\n  explicit AutoCastMeta(int args_num)\n      : autocast_color_(kNoColor), is_args_autocast_eligible_(args_num, false) {}\n\n  AutoCastColor autocast_color() const;\n\n  bool is_autocast_eligible(DeviceType device_type, Symbol<DType> dtype) const;\n\n  bool is_args_autocast_eligible(int arg_index) const;\n  const std::vector<bool>& is_args_autocast_eligible() const;\n\n  void set_autocast_color(AutoCastColor color);\n  void set_autocast_eligible(DeviceType device_type, Symbol<DType> dtype);\n  void set_arg_autocast_eligible(int arg_index);\n\n private:\n  AutoCastColor autocast_color_;\n  std::vector<std::vector<bool>> is_autocast_eligible_;\n  std::vector<bool> is_args_autocast_eligible_;\n};\n\nstd::shared_ptr<AutoCastMeta> MakeAutoCastMeta(\n    const std::string& op_type_name,\n    const std::vector<std::pair<std::string, int32_t>>& input_args);\n\n}  // namespace autocast\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_AUTOCAST_H_\n"
  },
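  {
    "path": "docs/examples/autocast_usage_sketch.cpp",
    "content": "// NOTE: hypothetical usage sketch added for illustration; NOT part of the OneFlow source tree.\n// It exercises only the thread-local autocast switches declared in\n// oneflow/core/framework/autocast.h. The AutoCastModeGuard class below is an assumption of this\n// sketch (mirroring the RAII pattern of GlobalParamGradSyncMode), not an existing OneFlow API.\n#include \"oneflow/core/framework/autocast.h\"\n\nnamespace oneflow {\nnamespace autocast {\n\n// Turns autocast on for the current scope and restores the previous state on destruction.\nclass AutoCastModeGuard final {\n public:\n  explicit AutoCastModeGuard(bool enabled) : prev_enabled_(is_enabled()) { set_enabled(enabled); }\n  ~AutoCastModeGuard() { set_enabled(prev_enabled_); }\n\n private:\n  bool prev_enabled_;\n};\n\ninline void ExampleConfigureAutocast() {\n  AutoCastModeGuard guard(/*enabled=*/true);\n  // Run autocast-eligible CUDA ops in float16 while the guard is alive.\n  set_autocast_device_type(DeviceType::kCUDA);\n  set_autocast_gpu_dtype(DType::Float16());\n  // cached_cast() memoizes casts keyed by (EagerLocalTensorImpl*, DataType); drop the memoized\n  // entries once the autocast region ends.\n  clear_cache();\n}\n\n}  // namespace autocast\n}  // namespace oneflow\n"
  },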
  {
    "path": "oneflow/core/framework/compute_complexity_fn_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_COMPUTE_COMPLEXITY_FN_CONTEXT_H_\n#define ONEFLOW_CORE_FRAMEWORK_COMPUTE_COMPLEXITY_FN_CONTEXT_H_\n\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n\nnamespace oneflow {\n\nclass Shape;\n\nnamespace user_op {\n\nclass UserOpDefWrapper;\n\nclass ComputeComplexityFnContext {\n public:\n  virtual ~ComputeComplexityFnContext() = default;\n\n  virtual const TensorDesc* TensorDesc4ArgNameAndIndex(const std::string&, int32_t) = 0;\n  virtual const Shape& Shape4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n  virtual DataType Dtype4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n  virtual const std::vector<std::pair<std::string, int32_t>>& inputs() const = 0;\n  virtual const std::vector<std::pair<std::string, int32_t>>& outputs() const = 0;\n  virtual const NdSbp NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const = 0;\n  virtual const NdSbpSignature* GetNdSbpSignature() const = 0;\n\n  template<typename T>\n  T Attr(const std::string& attr_name) const {\n    return conf_.attr<T>(attr_name);\n  }\n\n  virtual const ParallelDesc& parallel_desc() const = 0;\n  virtual bool IsDynamic4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n\n  const UserOpConfWrapper& user_op_conf() const { return conf_; }\n\n protected:\n  explicit ComputeComplexityFnContext(UserOpConfWrapper&& conf) : conf_(std::move(conf)) {}\n\n private:\n  UserOpConfWrapper conf_;\n};\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_COMPUTE_COMPLEXITY_FN_CONTEXT_H_\n"
  },
  {
    "path": "oneflow/core/framework/config_def.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/config_def.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<ConfigDefType config_def_type>\nConfigDef* MutGlobalConfigDef() {\n  static ConfigDef config_def;\n  return &config_def;\n}\n\ntemplate<ConfigDefType config_def_type>\nAttrValue* AddAttrDef(const std::string& name, const std::string& description) {\n  auto* name2flag_def = MutGlobalConfigDef<config_def_type>()->mutable_attr_name2attr_def();\n  CHECK(name2flag_def->find(name) == name2flag_def->end()) << \"Duplicate attribute: \" << name;\n  auto* flag_def = &(*name2flag_def)[name];\n  flag_def->set_name(name);\n  flag_def->set_description(description);\n  return flag_def->mutable_default_val();\n}\n\n}  // namespace\n\nconst ConfigDef& GlobalEnvConfigDef() { return *MutGlobalConfigDef<kEnvConfigDefType>(); }\nconst ConfigDef& GlobalSessionConfigDef() { return *MutGlobalConfigDef<kSessionConfigDefType>(); }\nconst ConfigDef& GlobalFunctionConfigDef() { return *MutGlobalConfigDef<kFunctionConfigDefType>(); }\nconst ConfigDef& GlobalScopeConfigDef() { return *MutGlobalConfigDef<kScopeConfigDefType>(); }\n\ntemplate<ConfigDefType config_def_type>\nconst ConfigDefBuidler<config_def_type>& ConfigDefBuidler<config_def_type>::Bool(\n    const std::string& name, bool default_val, const std::string& description) const {\n  AddAttrDef<config_def_type>(name, description)->set_at_bool(default_val);\n  return *this;\n}\n\ntemplate<ConfigDefType config_def_type>\nconst ConfigDefBuidler<config_def_type>& ConfigDefBuidler<config_def_type>::Int64(\n    const std::string& name, int64_t default_val, const std::string& description) const {\n  AddAttrDef<config_def_type>(name, description)->set_at_int64(default_val);\n  return *this;\n}\n\ntemplate<ConfigDefType config_def_type>\nconst ConfigDefBuidler<config_def_type>& ConfigDefBuidler<config_def_type>::Double(\n    const std::string& name, double default_val, const std::string& description) const {\n  AddAttrDef<config_def_type>(name, description)->set_at_double(default_val);\n  return *this;\n}\n\ntemplate<ConfigDefType config_def_type>\nconst ConfigDefBuidler<config_def_type>& ConfigDefBuidler<config_def_type>::String(\n    const std::string& name, const std::string& default_val, const std::string& description) const {\n  AddAttrDef<config_def_type>(name, description)->set_at_string(default_val);\n  return *this;\n}\n\ntemplate<ConfigDefType config_def_type>\nconst ConfigDefBuidler<config_def_type>& ConfigDefBuidler<config_def_type>::ListInt64(\n    const std::string& name, const std::vector<int64_t>& default_val,\n    const std::string& description) const {\n  auto* list = AddAttrDef<config_def_type>(name, description)->mutable_at_list_int64();\n  *list->mutable_val() = {default_val.begin(), default_val.end()};\n  return *this;\n}\n\ntemplate struct 
ConfigDefBuidler<kEnvConfigDefType>;\ntemplate struct ConfigDefBuidler<kSessionConfigDefType>;\ntemplate struct ConfigDefBuidler<kFunctionConfigDefType>;\ntemplate struct ConfigDefBuidler<kScopeConfigDefType>;\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/config_def.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_CONFIG_DEF_H_\n#define ONEFLOW_CORE_JOB_CONFIG_DEF_H_\n\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/framework/user_op_attr.pb.h\"\n#include \"oneflow/core/framework/config_def.pb.h\"\n\nnamespace oneflow {\n\ntemplate<ConfigDefType config_def_type>\nstruct ConfigDefBuidler final {\n  const ConfigDefBuidler& Bool(const std::string& name, bool default_val,\n                               const std::string& description) const;\n  const ConfigDefBuidler& Int64(const std::string& name, int64_t default_val,\n                                const std::string& description) const;\n  const ConfigDefBuidler& Double(const std::string& name, double default_val,\n                                 const std::string& description) const;\n  const ConfigDefBuidler& String(const std::string& name, const std::string& default_val,\n                                 const std::string& description) const;\n\n  const ConfigDefBuidler& ListInt64(const std::string& name,\n                                    const std::vector<int64_t>& default_val,\n                                    const std::string& description) const;\n};\n\n#define REGISTER_ENV_CONFIG_DEF() REGISTER_CONFIG_DEF(kEnvConfigDefType)\n#define REGISTER_SESSION_CONFIG_DEF() REGISTER_CONFIG_DEF(kSessionConfigDefType)\n#define REGISTER_FUNCTION_CONFIG_DEF() REGISTER_CONFIG_DEF(kFunctionConfigDefType)\n#define REGISTER_SCOPE_CONFIG_DEF() REGISTER_CONFIG_DEF(kScopeConfigDefType)\n\n#define REGISTER_CONFIG_DEF(config_def_type)                                                    \\\n  static ConfigDefBuidler<config_def_type> OF_PP_CAT(g_##config_def_type##_def_, __COUNTER__) = \\\n      ConfigDefBuidler<config_def_type>()\n\nconst ConfigDef& GlobalEnvConfigDef();\nconst ConfigDef& GlobalSessionConfigDef();\nconst ConfigDef& GlobalFunctionConfigDef();\nconst ConfigDef& GlobalScopeConfigDef();\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_CONFIG_DEF_H_\n"
  },
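  {
    "path": "docs/examples/config_def_usage_sketch.cpp",
    "content": "// NOTE: hypothetical usage sketch added for illustration; NOT part of the OneFlow source tree.\n// It shows how the REGISTER_*_CONFIG_DEF() macros from oneflow/core/framework/config_def.h are\n// meant to be used; all flag names below are invented.\n#include \"oneflow/core/framework/config_def.h\"\n\nnamespace oneflow {\n\n// Each chained builder call registers one AttrDef (name, default value, description) into the\n// global function-scope ConfigDef during static initialization.\nREGISTER_FUNCTION_CONFIG_DEF()\n    .Bool(\"example_enable_fusion\", false, \"hypothetical bool flag\")\n    .Int64(\"example_chunk_size\", 1024, \"hypothetical int64 flag\")\n    .String(\"example_tag\", \"default\", \"hypothetical string flag\");\n\n// The registered defaults are then visible through the global accessor.\ninline size_t ExampleNumFunctionFlags() {\n  return GlobalFunctionConfigDef().attr_name2attr_def().size();\n}\n\n}  // namespace oneflow\n"
  },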
  {
    "path": "oneflow/core/framework/config_def.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/framework/user_op_attr.proto\";\n\nenum ConfigDefType {\n  kEnvConfigDefType = 1;\n  kSessionConfigDefType = 2;\n  kFunctionConfigDefType = 3;\n  kScopeConfigDefType = 4;\n}\n\nmessage ConfigDef {\n  map<string, AttrDef> attr_name2attr_def = 1;\n}\n"
  },
  {
    "path": "oneflow/core/framework/consistency_check.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cstring>\n#include \"oneflow/core/framework/consistency_check.h\"\n#include \"oneflow/core/intrusive/flat_msg.h\"\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/core/job/rank_group_scope.h\"\n#include \"oneflow/core/framework/synced_symbol_map.h\"\n#include \"oneflow/core/framework/sync_symbol_nd_sbp.h\"\n#include \"oneflow/core/framework/sync_symbol_parallel_desc.h\"\n#include \"oneflow/core/common/constant.h\"\n#include \"oneflow/core/common/check_level.h\"\n#include \"oneflow/core/framework/sync_symbol_global_tensor_meta.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstruct FlatMetaInfoConsistency;\n\nclass CheckMetaInfoConsistencyAsyncTransportCtx : public AsyncTransportCtx {\n public:\n  CheckMetaInfoConsistencyAsyncTransportCtx(const TransportToken& transport_token,\n                                            const Symbol<ParallelDesc>& placement,\n                                            const Optional<Symbol<NdSbp>>& nd_sbp,\n                                            const Optional<Symbol<NdSbp>>& grad_nd_sbp)\n      : AsyncTransportCtx(transport_token),\n        placement_(placement),\n        nd_sbp_(nd_sbp),\n        grad_nd_sbp_(grad_nd_sbp) {}\n\n  ~CheckMetaInfoConsistencyAsyncTransportCtx() override = default;\n\n  Maybe<void> PrepareSendBufferAndCallback(int64_t rank, void** buffer, std::size_t* size,\n                                           std::function<void()>* Callback) override;\n\n  Maybe<void> PrepareRecvBufferAndCallback(int64_t rank, void** buffer, std::size_t* size,\n                                           std::function<void()>* Callback) override;\n\n  Maybe<void> Check() const;\n\n private:\n  Symbol<ParallelDesc> placement_;\n  Optional<Symbol<NdSbp>> nd_sbp_;\n  Optional<Symbol<NdSbp>> grad_nd_sbp_;\n  std::shared_ptr<FlatMetaInfoConsistency> flat_meta_info_consistency_;\n};\n\n// clang-format off\nFLAT_MSG_BEGIN(FlatMetaInfoConsistency);\n public:\n  static Maybe<FlatMetaInfoConsistency> New() {\n    const auto& consistency = std::make_shared<FlatMetaInfoConsistency>();\n    consistency->clear();\n    return consistency;\n  }\n  static Maybe<FlatMetaInfoConsistency> New(const Symbol<ParallelDesc>& placement,\n    const Optional<Symbol<NdSbp>>& nd_sbp, const Optional<Symbol<NdSbp>>& grad_nd_sbp) {\n    const auto& consistency = std::make_shared<FlatMetaInfoConsistency>();\n    consistency->clear();\n    JUST(consistency->Init(placement, nd_sbp, grad_nd_sbp));\n    return consistency;\n  }\n\n  Maybe<void> Check(const Symbol<ParallelDesc>& placement,\n    const Optional<Symbol<NdSbp>>& nd_sbp, const Optional<Symbol<NdSbp>>& grad_nd_sbp) {\n    \n    const auto& this_placement =\n        JUST(SyncedSymbolMap<ParallelDesc>::Symbol4SyncedSymbolId(\n            this->placement_symbol_id()));\n    CHECK_OR_RETURN(this_placement == placement) << Error::RuntimeError() 
<< \"Each rank must have the same input placement\";\n    CHECK_EQ_OR_RETURN(nd_sbp.has_value(), this->has_nd_sbp_symbol_id()) << Error::RuntimeError()  << \"Either all ranks have sbp or not\";\n    if (this->has_nd_sbp_symbol_id()) {\n      const auto& that_nd_sbp =\n          JUST(SyncedSymbolMap<NdSbp>::Symbol4SyncedSymbolId(\n              this->nd_sbp_symbol_id()));\n      const auto& this_nd_sbp = JUST(nd_sbp);\n      CHECK_OR_RETURN(this_nd_sbp == that_nd_sbp) << Error::RuntimeError() << \"Each rank must have the same input sbp\";\n    }\n    CHECK_EQ_OR_RETURN(grad_nd_sbp.has_value(), this->has_grad_nd_sbp_symbol_id()) << Error::RuntimeError() << \"Either all ranks have grad sbp or not\";\n    if (this->has_grad_nd_sbp_symbol_id()) {\n       const auto& that_grad_nd_sbp =\n          JUST(SyncedSymbolMap<NdSbp>::Symbol4SyncedSymbolId(\n              this->grad_nd_sbp_symbol_id()));\n      const auto& this_grad_nd_sbp = JUST(grad_nd_sbp);\n      CHECK_OR_RETURN(this_grad_nd_sbp == that_grad_nd_sbp)<< Error::RuntimeError() << \"Each rank must have same input grad sbp\";\n    }\n    return Maybe<void>::Ok();\n  }\n private:\n  Maybe<void> Init(const Symbol<ParallelDesc>& placement, const Optional<Symbol<NdSbp>>& nd_sbp,\n    const Optional<Symbol<NdSbp>>& grad_nd_sbp) {\n    this->set_placement_symbol_id(\n        JUST(SyncedSymbolMap<ParallelDesc>::FindOrSync(placement, &SyncSymbolParallelDesc)));\n    if (nd_sbp.has_value()) {\n      this->set_nd_sbp_symbol_id(\n          JUST(SyncedSymbolMap<NdSbp>::FindOrSync(JUST(nd_sbp), &SyncSymbolNdSbp)));\n    }\n    if (grad_nd_sbp.has_value()) {\n      this->set_grad_nd_sbp_symbol_id(\n          JUST(SyncedSymbolMap<NdSbp>::FindOrSync(JUST(grad_nd_sbp), &SyncSymbolNdSbp)));\n    }\n    return Maybe<void>::Ok();\n  }\n  FLAT_MSG_DEFINE_OPTIONAL(uint64_t, placement_symbol_id);\n  FLAT_MSG_DEFINE_OPTIONAL(uint64_t, nd_sbp_symbol_id);\n  FLAT_MSG_DEFINE_OPTIONAL(uint64_t, grad_nd_sbp_symbol_id);\nFLAT_MSG_END(FlatMetaInfoConsistency);\n// clang-format on\n\nMaybe<void> CheckMetaInfoConsistencyAsyncTransportCtx::PrepareSendBufferAndCallback(\n    int64_t rank, void** buffer, std::size_t* size, std::function<void()>* Callback) {\n  const auto& meta_info_consistency =\n      JUST(FlatMetaInfoConsistency::New(placement_, nd_sbp_, grad_nd_sbp_));\n  *buffer = meta_info_consistency.get();\n  *size = sizeof(FlatMetaInfoConsistency);\n  *Callback = [meta_info_consistency] {};\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckMetaInfoConsistencyAsyncTransportCtx::PrepareRecvBufferAndCallback(\n    int64_t rank, void** buffer, std::size_t* size, std::function<void()>* Callback) {\n  const auto& flat_meta_info_consistency = JUST(FlatMetaInfoConsistency::New());\n  *buffer = flat_meta_info_consistency.get();\n  *size = sizeof(FlatMetaInfoConsistency);\n  *Callback = [flat_meta_info_consistency]() {};\n  flat_meta_info_consistency_ = flat_meta_info_consistency;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckMetaInfoConsistencyAsyncTransportCtx::Check() const {\n  if (!flat_meta_info_consistency_) { return Maybe<void>::Ok(); }\n  JUST(flat_meta_info_consistency_->Check(placement_, nd_sbp_, grad_nd_sbp_));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> DataConsistencyCheck(const void* buffer_ptr, size_t buffer_size,\n                                 Symbol<ParallelDesc> placement) {\n  if (!placement->containing_current_rank() || placement->parallel_num() == 1) {\n    return Maybe<void>::Ok();\n  }\n\n  const auto& rank_group = 
JUST(RankGroup::New(placement));\n\n  std::vector<char> recv_buffer(buffer_size);\n  char* recv_ptr = recv_buffer.data();\n\n  TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));\n  NaiveAsyncTransportCtx ctx(\n      transport_token,\n      [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n        *buffer = const_cast<void*>(buffer_ptr);\n        *size = buffer_size;\n        *Cb = [] {};\n        return Maybe<void>::Ok();\n      },\n      [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n        *buffer = recv_ptr;\n        *size = buffer_size;\n        *Cb = [] {};\n        return Maybe<void>::Ok();\n      });\n  JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));\n  JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));\n  JUST_MSG(ctx.WaitDone(), kAsymmetricCodeErrorMsg);\n  CHECK_OR_RETURN(std::memcmp(buffer_ptr, reinterpret_cast<const void*>(recv_ptr), buffer_size)\n                  == 0)\n      << Error::RuntimeError() << \"Each rank must have same input sequence or numpy array\";\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<void> MetaInfoConsistencyCheckUtil(const Symbol<ParallelDesc>& placement,\n                                         const Optional<Symbol<NdSbp>>& nd_sbp,\n                                         const Optional<Symbol<NdSbp>>& grad_nd_sbp) {\n  const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());\n  const auto& transport_token =\n      JUST(TransportToken::NewTransportToken(kTransportTokenTypeCheckRankGroupConsistency));\n  const auto& ctx = std::make_shared<CheckMetaInfoConsistencyAsyncTransportCtx>(\n      transport_token, placement, nd_sbp, grad_nd_sbp);\n  JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, ctx.get()));\n  JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, ctx.get()));\n  JUST_MSG(ctx->WaitDone(), kAsymmetricCodeErrorMsg);\n  JUST(ctx->Check());\n  return Maybe<void>::Ok();\n}\n\nint64_t* MutThreadLocalMetaInfoConsistencyCheckDepth() {\n  static thread_local int64_t recursive_depth = 0;\n  return &recursive_depth;\n}\n\ninline bool IsMetaInfoConsistencyCheckDisable() {\n  return *MutThreadLocalMetaInfoConsistencyCheckDepth() > 1;\n}\n\n}  // namespace\n\nNonRecursiveMetaInfoConsistencyCheckScope::NonRecursiveMetaInfoConsistencyCheckScope() {\n  auto* recursive_depth = MutThreadLocalMetaInfoConsistencyCheckDepth();\n  ++*recursive_depth;\n}\n\nNonRecursiveMetaInfoConsistencyCheckScope::~NonRecursiveMetaInfoConsistencyCheckScope() {\n  auto* recursive_depth = MutThreadLocalMetaInfoConsistencyCheckDepth();\n  --*recursive_depth;\n}\n\nMaybe<void> MetaInfoConsistencyCheck(const Symbol<ParallelDesc>& placement,\n                                     const Optional<Symbol<NdSbp>>& nd_sbp,\n                                     const Optional<Symbol<NdSbp>>& grad_nd_sbp,\n                                     const size_t debug_level, bool force_check) {\n  if ((IsEnvEnabled(debug_level) || force_check) && !IsMetaInfoConsistencyCheckDisable()) {\n    JUST(MetaInfoConsistencyCheckUtil(placement, nd_sbp, grad_nd_sbp));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MetaInfoConsistencyCheck(const Symbol<ParallelDesc>& placement,\n                                     const Optional<Symbol<NdSbp>>& nd_sbp,\n                                     const size_t debug_level, bool force_check) {\n  if 
((IsEnvEnabled(debug_level) || force_check) && !IsMetaInfoConsistencyCheckDisable()) {\n    JUST(MetaInfoConsistencyCheckUtil(placement, nd_sbp, Optional<Symbol<NdSbp>>()));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MetaInfoConsistencyCheck(const Symbol<ParallelDesc>& placement,\n                                     const std::vector<Symbol<SbpParallel>>& sbp_tuple,\n                                     const std::vector<Symbol<SbpParallel>>& grad_sbp_tuple,\n                                     const size_t debug_level, bool force_check) {\n  Optional<Symbol<NdSbp>> nd_sbp;\n  Optional<Symbol<NdSbp>> grad_nd_sbp;\n  if (!sbp_tuple.empty()) { nd_sbp = JUST(GetNdSbp(sbp_tuple)); }\n  if (!grad_sbp_tuple.empty()) { grad_nd_sbp = JUST(GetNdSbp(grad_sbp_tuple)); }\n  JUST(MetaInfoConsistencyCheck(placement, nd_sbp, grad_nd_sbp, debug_level, force_check));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MetaInfoConsistencyCheck(const Symbol<ParallelDesc>& placement,\n                                     const std::vector<Symbol<SbpParallel>>& sbp_tuple,\n                                     const size_t debug_level, bool force_check) {\n  Optional<Symbol<NdSbp>> nd_sbp;\n  Optional<Symbol<NdSbp>> grad_nd_sbp;\n  if (!sbp_tuple.empty()) { nd_sbp = JUST(GetNdSbp(sbp_tuple)); }\n  JUST(MetaInfoConsistencyCheck(placement, nd_sbp, grad_nd_sbp, debug_level, force_check));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/consistency_check.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_DATA_CONSISTENCY_CHECK_H_\n#define ONEFLOW_CORE_FRAMEWORK_DATA_CONSISTENCY_CHECK_H_\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/common/tensor_meta.h\"\n\nnamespace oneflow {\n\nclass NonRecursiveMetaInfoConsistencyCheckScope final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NonRecursiveMetaInfoConsistencyCheckScope);\n  NonRecursiveMetaInfoConsistencyCheckScope();\n  ~NonRecursiveMetaInfoConsistencyCheckScope();\n};\n\nMaybe<void> DataConsistencyCheck(const void* buffer_ptr, size_t buffer_size,\n                                 Symbol<ParallelDesc> placement);\n\nMaybe<void> MetaInfoConsistencyCheck(const Symbol<ParallelDesc>& placement,\n                                     const Optional<Symbol<NdSbp>>& nd_sbp,\n                                     const Optional<Symbol<NdSbp>>& grad_nd_sbp,\n                                     const size_t debug_level, bool force_check);\n\nMaybe<void> MetaInfoConsistencyCheck(const Symbol<ParallelDesc>& placement,\n                                     const Optional<Symbol<NdSbp>>& nd_sbp,\n                                     const size_t debug_level, bool force_check);\n\nMaybe<void> MetaInfoConsistencyCheck(const Symbol<ParallelDesc>& placement,\n                                     const std::vector<Symbol<SbpParallel>>& sbp_tuple,\n                                     const std::vector<Symbol<SbpParallel>>& grad_sbp_tuple,\n                                     const size_t debug_level, bool force_check);\n\nMaybe<void> MetaInfoConsistencyCheck(const Symbol<ParallelDesc>& placement,\n                                     const std::vector<Symbol<SbpParallel>>& sbp_tuple,\n                                     const size_t debug_level, bool force_check);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_DATA_CONSISTENCY_CHECK_H_\n"
  },
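  {
    "path": "docs/examples/consistency_check_usage_sketch.cpp",
    "content": "// NOTE: hypothetical usage sketch added for illustration; NOT part of the OneFlow source tree.\n// It shows the intended call pattern of the checks declared in\n// oneflow/core/framework/consistency_check.h. Both checks are collective: every rank in the\n// current rank group must reach them in the same order, so this helper is assumed to run\n// symmetrically on all ranks.\n#include <vector>\n#include \"oneflow/core/framework/consistency_check.h\"\n\nnamespace oneflow {\n\nMaybe<void> ExampleCheckGlobalTensorArgs(const Symbol<ParallelDesc>& placement,\n                                         const std::vector<Symbol<SbpParallel>>& sbp_tuple,\n                                         const void* data, size_t size) {\n  // Depth > 1 disables nested meta-info checks triggered by anything called inside this scope.\n  NonRecursiveMetaInfoConsistencyCheckScope scope;\n  // Compare placement/sbp meta info across ranks; force_check bypasses the debug-level gate.\n  JUST(MetaInfoConsistencyCheck(placement, sbp_tuple, /*debug_level=*/0, /*force_check=*/true));\n  // Compare the raw argument bytes around the rank ring.\n  JUST(DataConsistencyCheck(data, size, placement));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },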
  {
    "path": "oneflow/core/framework/device.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <fmt/core.h>\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/env_global_objects_scope.h\"\n#include \"oneflow/core/memory/memory_case_util.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/to_string.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nvoid CheckDeviceType(const std::string& type) {\n  if (!TRY(DeviceType4DeviceTag(type)).IsOk()) {\n    std::string error_msg = \"Expected one of \" + PrintAvailableDevices()\n                            + \" device type at start of device string: \" + type;\n    throw std::runtime_error(error_msg);\n  }\n}\n\n}  // namespace\n\nDevice::Device(const std::string& type, int64_t device_id, bool rematable)\n    : type_(type),\n      enum_type_(kInvalidDevice),\n      device_id_(device_id),\n      rematable_(rematable),\n      hash_value_(Hash(type, device_id, rematable)) {}\n\nMaybe<void> Device::Init() {\n  if (type_ == \"auto\") { return Maybe<void>::Ok(); }\n  enum_type_ = JUST(DeviceType4DeviceTag(type()));\n  {\n    DeviceType dev_type = enum_type_;\n    if (dev_type == kMockDevice) { dev_type = DeviceType::kCPU; }\n    mem_case_ = memory::MakeMemCaseShared(enum_type_, device_id_);\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Device>> Device::New(const std::string& type, int64_t device_id,\n                                               bool rematable) {\n  CHECK_GE_OR_RETURN(device_id, 0)\n      << Error::InvalidValueError() << \"Device ID should be non-negative\";\n  static thread_local HashMap<std::tuple<std::string, int, bool>, Symbol<Device>> map;\n  auto key = std::make_tuple(type, device_id, rematable);\n  auto iter = map.find(key);\n  if (iter == map.end()) {\n    Device device(type, device_id, rematable);\n    JUST(device.Init());\n    iter = map.emplace(key, SymbolOf(device)).first;\n  }\n  return iter->second;\n}\n\n/* static */ Maybe<Symbol<Device>> Device::New(const std::string& type, int64_t device_id) {\n  return New(type, device_id, false);\n}\n\n/* static */ Maybe<Symbol<Device>> Device::New(const std::string& type) {\n  return New(type, GlobalProcessCtx::LocalRank());\n}\n\n/* static */ Maybe<Symbol<Device>> Device::ParseAndNew(const std::string& device_str) {\n  static thread_local HashMap<std::string, Symbol<Device>> map;\n  auto iter = map.find(device_str);\n  if (iter == map.end()) {\n    auto [type, device_id, rematable] = *JUST(ParseDeviceString(device_str));\n    CheckDeviceType(type);\n    if (device_id == -1) { device_id = GlobalProcessCtx::LocalRank(); }\n    Device device(type, device_id, rematable);\n    JUST(device.Init());\n    iter 
= map.emplace(device_str, SymbolOf(device)).first;\n  }\n  return iter->second;\n}\n\nstd::string Device::ToRepr() const {\n  auto rematable_suffix = \"\";\n  if (rematable_) { rematable_suffix = \", rematable=True\"; }\n  return fmt::format(\"device(type='{}', index={}{})\", type_, device_id_, rematable_suffix);\n}\n\nstd::ostream& operator<<(std::ostream& os, Symbol<Device> device) {\n  os << device->ToRepr();\n  return os;\n}\n\nstd::string Device::ToString() const {\n  auto rematable_suffix = \"\";\n  if (rematable_) { rematable_suffix = \"+remat\"; }\n  return fmt::format(\"{}:{}{}\", type_, device_id_, rematable_suffix);\n}\n\nMaybe<Symbol<Device>> Device::MakeDeviceByParallelDesc(const ParallelDesc& parallel_desc) {\n  const std::string& type = parallel_desc.device_tag();\n  std::vector<std::string> machine_device_ids;\n  machine_device_ids.reserve(parallel_desc.parallel_conf().device_name().size());\n  for (const auto& item : parallel_desc.parallel_conf().device_name()) {\n    machine_device_ids.emplace_back(item);\n  }\n  CHECK_EQ_OR_RETURN(machine_device_ids.size(), 1)\n      << Error::InvalidValueError() << \"Number of machine device should be one\";\n  const std::string& machine_device_id = machine_device_ids.at(0);\n  size_t pos = machine_device_id.find(':');\n  CHECK_NE_OR_RETURN(pos, std::string::npos)\n      << Error::InvalidValueError() << \"Invalid device ID: \" << machine_device_id;\n  std::string device_id = machine_device_id.substr(pos + 1);\n  CHECK_EQ_OR_RETURN(device_id.find('-'), std::string::npos)\n      << Error::InvalidValueError() << \"Device ID should be non-negative\";\n  CHECK_OR_RETURN(IsStrInt(device_id))\n      << Error::InvalidValueError() << \"Device ID is not integer: \" << device_id;\n  return Device::New(type, std::stoi(device_id));\n}\n\nnamespace {\n\nMaybe<Symbol<ParallelDesc>> RawGetPlacement(const Device& device) {\n  std::string machine_device_id =\n      \"@\" + std::to_string(GlobalProcessCtx::Rank()) + \":\" + std::to_string(device.device_id());\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(device.type());\n  parallel_conf.add_device_name(machine_device_id);\n  return SymbolOf(ParallelDesc(parallel_conf));\n}\n\nMaybe<Symbol<ParallelDesc>> RawPlacement4Device(Symbol<Device> device) {\n  return RawGetPlacement(*device);\n}\n\n}  // namespace\n\ndecltype(Device::GetPlacement) Device::GetPlacement =\n    DECORATE(&RawGetPlacement, ThreadLocalCopiable);\ndecltype(Placement4Device) Placement4Device = DECORATE(&RawPlacement4Device, ThreadLocal);\n\nMaybe<std::tuple<std::string, int, bool>> ParseDeviceString(std::string device_str) {\n  bool rematable = false;\n  if (device_str.size() > 6 && device_str.substr(device_str.size() - 6, 6) == \"+remat\") {\n    rematable = true;\n    device_str = device_str.substr(0, device_str.size() - 6);\n  }\n  std::string::size_type pos = device_str.find(':');\n  if (pos == std::string::npos) {\n    return std::make_tuple(device_str, -1, rematable);\n  } else {\n    std::string index_str = device_str.substr(pos + 1);\n    CHECK_OR_RETURN(IsStrInt(index_str))\n        << Error::InvalidValueError() << \"Invalid device tag \" << device_str;\n    return std::make_tuple(device_str.substr(0, pos), std::stoi(index_str), rematable);\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/device.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_DEVICE_H_\n#define ONEFLOW_CORE_FRAMEWORK_DEVICE_H_\n\n#include <fmt/core.h>\n#include <fmt/ostream.h>\n#include <memory>\n#include <string>\n#include <unordered_set>\n#include \"oneflow/core/common/device_type.pb.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/optional.h\"\n\nnamespace oneflow {\n\nclass ParallelDesc;\nclass MemoryCase;\n\ninline size_t GetInstructionHighWaterMark() { return 40000; }\ninline size_t GetInstructionLowWaterMark() { return 20000; }\n\nclass Device final {\n public:\n  Device(const Device&) = default;\n  Device(Device&&) = default;\n  ~Device() = default;\n  Device& operator=(const Device&) = delete;\n  const std::string& type() const { return type_; }\n  DeviceType enum_type() const { return enum_type_; }\n  int64_t device_id() const { return device_id_; }\n  bool rematable() const { return rematable_; }\n  std::string ToString() const;\n  std::string ToRepr() const;\n  size_t hash_value() const { return hash_value_; }\n  bool operator==(const Device& device) const {\n    return type_ == device.type() && device_id_ == device.device_id()\n           && rematable_ == device.rematable();\n  }\n  bool operator!=(const Device& device) const { return !operator==(device); }\n  const std::shared_ptr<MemoryCase>& mem_case() const { return mem_case_; }\n\n  static Maybe<Symbol<Device>> New(const std::string& type, int64_t device_id, bool rematable);\n  static Maybe<Symbol<Device>> New(const std::string& type, int64_t device_id);\n  static Maybe<Symbol<Device>> New(const std::string& type);\n  static Maybe<Symbol<Device>> ParseAndNew(const std::string& type_or_type_with_device_id);\n\n  static Maybe<Symbol<Device>> MakeDeviceByParallelDesc(const ParallelDesc& parallel_desc);\n\n  static Maybe<Symbol<ParallelDesc>> (*GetPlacement)(const Device& device);\n\n private:\n  Device(const std::string& type, int64_t device_id, bool rematable);\n  Maybe<void> Init();\n\n  const std::string type_;\n  DeviceType enum_type_;\n  const int64_t device_id_;\n  bool rematable_;\n  const size_t hash_value_;\n  std::shared_ptr<MemoryCase> mem_case_;\n};\n\nstd::ostream& operator<<(std::ostream& os, Symbol<Device> device);\n\nextern Maybe<Symbol<ParallelDesc>> (*Placement4Device)(Symbol<Device> device);\n\nMaybe<std::tuple<std::string, int, bool>> ParseDeviceString(std::string device_str);\n\n}  // namespace oneflow\n\ntemplate<>\nstruct fmt::formatter<oneflow::Symbol<oneflow::Device>> : ostream_formatter {};\n\nnamespace std {\ntemplate<>\nstruct hash<oneflow::Device> final {\n  size_t operator()(const oneflow::Device& device) const { return device.hash_value(); }\n};\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_DEVICE_H_\n"
  },
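  {
    "path": "docs/examples/device_string_usage_sketch.cpp",
    "content": "// NOTE: hypothetical usage sketch added for illustration; NOT part of the OneFlow source tree.\n// It walks through the device-string grammar implemented in oneflow/core/framework/device.cpp,\n// \"<type>[:<index>][+remat]\", assuming a CUDA-enabled build.\n#include <tuple>\n#include \"oneflow/core/framework/device.h\"\n\nnamespace oneflow {\n\nMaybe<void> ExampleParseDevices() {\n  // \"cuda:1+remat\" -> type \"cuda\", index 1, rematable == true.\n  auto [type, index, rematable] = *JUST(ParseDeviceString(\"cuda:1+remat\"));\n  CHECK_OR_RETURN(type == \"cuda\" && index == 1 && rematable) << \"unexpected parse result\";\n\n  // A bare type yields index -1 here; ParseAndNew() then substitutes the local rank.\n  auto [cpu_type, cpu_index, cpu_remat] = *JUST(ParseDeviceString(\"cpu\"));\n  CHECK_OR_RETURN(cpu_type == \"cpu\" && cpu_index == -1 && !cpu_remat) << \"unexpected parse result\";\n\n  // Device symbols are cached per thread; ToString() round-trips the \"+remat\" suffix.\n  Symbol<Device> device = JUST(Device::ParseAndNew(\"cuda:1+remat\"));\n  CHECK_OR_RETURN(device->ToString() == \"cuda:1+remat\") << \"unexpected device string\";\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },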
  {
    "path": "oneflow/core/framework/dtype.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"half.hpp\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/switch_func.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/data_type_seq.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/framework/dtype.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstd::size_t GetDataTypeBytes() {\n  return sizeof(T);\n}\n\n#define MAKE_DATA_TYPE_BYTES_SWITCH_ENTRY(func_name, T) func_name<T>\nDEFINE_STATIC_SWITCH_FUNC(\n    std::size_t, GetDataTypeBytes, MAKE_DATA_TYPE_BYTES_SWITCH_ENTRY,\n    MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ));\n\nclass DTypeMeta final {\n public:\n  DTypeMeta(const std::string& name, bool is_signed, bool is_integer, bool is_floating_point,\n            bool is_complex)\n      : name_(name),\n        is_signed_(is_signed),\n        is_integer_(is_integer),\n        is_floating_point_(is_floating_point),\n        is_complex_(is_complex) {}\n  DTypeMeta(const DTypeMeta&) = default;\n  DTypeMeta(DTypeMeta&) = default;\n  ~DTypeMeta() = default;\n\n  const std::string& name() const { return name_; }\n  bool is_signed() const { return is_signed_; }\n  bool is_integer() const { return is_integer_; }\n  bool is_floating_point() const { return is_floating_point_; }\n  bool is_complex() const { return is_complex_; }\n\n private:\n  const std::string name_;\n  const bool is_signed_;\n  const bool is_integer_;\n  const bool is_floating_point_;\n  const bool is_complex_;\n};\n\nMaybe<const DTypeMeta&> DTypeMeta4DataType(DataType data_type) {\n  static const HashMap<DataType, DTypeMeta> data_type2dtype_meta{\n      {DataType::kInvalidDataType,\n       DTypeMeta(\"oneflow.invalid_data_type\", false, false, false, false)},\n      {DataType::kChar, DTypeMeta(\"oneflow.char\", false, false, false, false)},\n      {DataType::kFloat16, DTypeMeta(\"oneflow.float16\", true, false, true, false)},\n      {DataType::kFloat, DTypeMeta(\"oneflow.float32\", true, false, true, false)},\n      {DataType::kDouble, DTypeMeta(\"oneflow.float64\", true, false, true, false)},\n      {DataType::kInt8, DTypeMeta(\"oneflow.int8\", true, true, false, false)},\n      {DataType::kInt16, DTypeMeta(\"oneflow.int16\", true, true, false, false)},\n      {DataType::kInt32, DTypeMeta(\"oneflow.int32\", true, true, false, false)},\n      {DataType::kInt64, DTypeMeta(\"oneflow.int64\", true, true, false, false)},\n      {DataType::kInt128, DTypeMeta(\"oneflow.int128\", true, true, false, false)},\n      {DataType::kUInt8, DTypeMeta(\"oneflow.uint8\", false, true, false, false)},\n      {DataType::kUInt16, DTypeMeta(\"oneflow.uint16\", false, true, false, false)},\n      {DataType::kUInt32, DTypeMeta(\"oneflow.uint32\", false, true, false, false)},\n      {DataType::kUInt64, DTypeMeta(\"oneflow.uint64\", false, 
true, false, false)},\n      {DataType::kUInt128, DTypeMeta(\"oneflow.uint128\", false, true, false, false)},\n      {DataType::kOFRecord, DTypeMeta(\"oneflow.of_record\", false, false, false, false)},\n      {DataType::kTensorBuffer, DTypeMeta(\"oneflow.tensor_buffer\", false, false, false, false)},\n      {DataType::kBFloat16, DTypeMeta(\"oneflow.bfloat16\", true, false, true, false)},\n      {DataType::kBool, DTypeMeta(\"oneflow.bool\", false, false, false, false)},\n      {DataType::kComplex32, DTypeMeta(\"oneflow.complex32\", false, false, false, true)},\n      {DataType::kComplex64, DTypeMeta(\"oneflow.complex64\", false, false, false, true)},\n      {DataType::kComplex128, DTypeMeta(\"oneflow.complex128\", false, false, false, true)},\n  };\n  return MapAt(data_type2dtype_meta, data_type);\n}\n\n}  // namespace\n\nMaybe<const Symbol<DType>&> DType::Get(DataType data_type) {\n  static HashMap<DataType, const Symbol<DType>> data_type2dtype{\n#define MAKE_ENTRY(data_type) {OF_PP_CAT(DataType::k, data_type), data_type()},\n      OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, DTYPE_SEQ)\n#undef MAKE_ENTRY\n  };\n  return MapAt(data_type2dtype, data_type);\n}\n\nMaybe<size_t> DType::bytes() const {\n  // DataType::OFRecord and DataType::TensorBuffer don't have fixed byte size\n  if (data_type() == DataType::kInvalidDataType || data_type() == DataType::kOFRecord\n      || data_type() == DataType::kTensorBuffer) {\n    OF_UNIMPLEMENTED();\n  }\n  return SwitchGetDataTypeBytes(SwitchCase(data_type()));\n}\n\nbool DType::is_signed() const { return CHECK_JUST(DTypeMeta4DataType(data_type_)).is_signed(); }\n\nbool DType::is_complex() const { return CHECK_JUST(DTypeMeta4DataType(data_type_)).is_complex(); }\n\n/*\n  The order of datatype is:\n  0    1    2    3    4    5    6    7    8    9    10   11   12   13   14   15   16   17   18   19   20   21\n  iv   c1   f4   f8   i1   i4   i8   u1   re   f2   bu   bf   b1   u2   u4   u8   u16  i2   i16  cp4  cp8  cp16\n  The priority order of datatype is:\n  iv < b1 < u1 < c1 < i1 < u2 < i2 < u4 < i4 < u8 < i8 < u16 < i16 < f2 < f4 < f8 < cp4 < cp8 < cp16 < bf < re < bu.\n*/\nconst int DType::priority_order[DataType_ARRAYSIZE] = {0,  /*kInvalid*/\n                                                       3,  /*kChar*/\n                                                       14, /*kFloat32*/\n                                                       15, /*kDouble*/\n                                                       4,  /*kInt8*/\n                                                       8,  /*kInt32*/\n                                                       10, /*kInt64*/\n                                                       2,  /*kUInt8*/\n                                                       20, /*kOFRecord*/\n                                                       13, /*kFloat16*/\n                                                       21, /*kTensorBuffer*/\n                                                       19, /*kBFloat16*/\n                                                       1,  /*kBool*/\n                                                       5,  /*kUInt16*/\n                                                       7,  /*kUInt32*/\n                                                       9,  /*kUInt64*/\n                                                       11, /*kUInt128*/\n                                                       6,  /*kInt16*/\n                                                       12, /*kInt128*/\n                                                       16, /*kComplex32*/\n                                                       17, /*kComplex64*/\n                                                       18 /*kComplex128*/};\n\nbool DType::is_integer() const { return CHECK_JUST(DTypeMeta4DataType(data_type_)).is_integer(); }\n\nbool DType::is_floating_point() const {\n  return CHECK_JUST(DTypeMeta4DataType(data_type_)).is_floating_point();\n}\n\nconst std::string& DType::name() const { return CHECK_JUST(DTypeMeta4DataType(data_type_)).name(); }\n\n#define DEFINE_GET_DATA_TYPE_FUNCTION(data_type)                                   \\\n  const Symbol<DType>& DType::data_type() {                                        \\\n    static const auto& dtype = SymbolOf(DType(OF_PP_CAT(DataType::k, data_type))); \\\n    return dtype;                                                                  \\\n  }\nOF_PP_FOR_EACH_TUPLE(DEFINE_GET_DATA_TYPE_FUNCTION, DTYPE_SEQ)\n#undef DEFINE_GET_DATA_TYPE_FUNCTION\n\nSymbol<DType> promoteTypes(const Symbol<DType> a, const Symbol<DType> b) {\n  const Symbol<DType> iv = CHECK_JUST(DType::Get(DataType::kInvalidDataType));\n  const Symbol<DType> c1 = CHECK_JUST(DType::Get(DataType::kChar));\n  const Symbol<DType> f4 = CHECK_JUST(DType::Get(DataType::kFloat));\n  const Symbol<DType> f8 = CHECK_JUST(DType::Get(DataType::kDouble));\n  const Symbol<DType> i1 = CHECK_JUST(DType::Get(DataType::kInt8));\n  const Symbol<DType> i4 = CHECK_JUST(DType::Get(DataType::kInt32));\n  const Symbol<DType> i8 = CHECK_JUST(DType::Get(DataType::kInt64));\n  const Symbol<DType> u1 = CHECK_JUST(DType::Get(DataType::kUInt8));\n  const Symbol<DType> re = CHECK_JUST(DType::Get(DataType::kOFRecord));\n  const Symbol<DType> f2 = CHECK_JUST(DType::Get(DataType::kFloat16));\n  const Symbol<DType> bu = CHECK_JUST(DType::Get(DataType::kTensorBuffer));\n  const Symbol<DType> bf = CHECK_JUST(DType::Get(DataType::kBFloat16));\n  const Symbol<DType> b1 = CHECK_JUST(DType::Get(DataType::kBool));\n  const Symbol<DType> u2 = CHECK_JUST(DType::Get(DataType::kUInt16));\n  const Symbol<DType> u4 = CHECK_JUST(DType::Get(DataType::kUInt32));\n  const Symbol<DType> u8 = CHECK_JUST(DType::Get(DataType::kUInt64));\n  const Symbol<DType> u16 = CHECK_JUST(DType::Get(DataType::kUInt128));\n  const Symbol<DType> i2 = CHECK_JUST(DType::Get(DataType::kInt16));\n  const Symbol<DType> i16 = CHECK_JUST(DType::Get(DataType::kInt128));\n  const Symbol<DType> cp4 = CHECK_JUST(DType::Get(DataType::kComplex32));\n  const Symbol<DType> cp8 = CHECK_JUST(DType::Get(DataType::kComplex64));\n  const Symbol<DType> cp16 = CHECK_JUST(DType::Get(DataType::kComplex128));\n\n  /* It is consistent with data_type.proto (except kInvalidDataType, kOFRecord and kTensorBuffer)\n    kInvalidDataType = 0;\n    kChar = 1;\n    kFloat = 2;\n    kDouble = 3;\n    kInt8 = 4;\n    kInt32 = 5;\n    kInt64 = 6;\n    kUInt8 = 7;\n    kOFRecord = 8;\n    kFloat16 = 9;\n    kTensorBuffer = 10;\n    kBFloat16 = 11;\n    kBool = 12;\n    kUInt16 = 13;\n    kUInt32 = 14;\n    kUInt64 = 15;\n    kUInt128 = 16;\n    kInt16 = 17;\n    kInt128 = 18;\n    kComplex32 = 19;\n    kComplex64 = 20;\n    kComplex128 = 21;\n\n    The priority order of datatype is:\n    iv < b1 < u1 < c1 < i1 < u2 < i2 < u4 < i4 < u8 < i8 < u16 < i16 < f2 < f4 < f8 < cp4 < cp8 <\n    cp16 < bf < re < bu.\n\n    When adding int8 and uint8, the result needs to be promoted to int16, etc.\n    But for int8 + uint128 the promoted type would be int256, which does not exist, 
so we set as Invalid.\n\n    The new DataType should be add in the end of proto, and the Loopup table should be maintained as\n    right priority (author:zhengzekang).\n  */\n\n  // clang-format off\n  static const Symbol<DType> _promoteTypesLookup[DataType_ARRAYSIZE][DataType_ARRAYSIZE] = {\n      /*          iv   c1   f4   f8   i1   i4   i8   u1   re   f2   bu   bf   b1   u2   u4   u8   u16   i2   i16   cp4   cp8   cp16 */\n      /* iv */   {iv,  c1,  f4,  f8,  i1,  i4,  i8,  u1,  re,  f2,  bu,  bf,  b1,  u2,  u4,  u8,  u16,  i2,  i16,  cp4,  cp8,  cp16},\n      /* c1 */   {c1,  c1,  f4,  f8,  i1,  i4,  i8,  c1,  iv,  f2,  iv,  bf,  c1,  u2,  u4,  u8,  u16,  i2,  i16,  iv,   cp8,  cp16},\n      /* f4 */   {f4,  f4,  f4,  f8,  f4,  f4,  f4,  f4,  iv,  f4,  iv,  bf,  f4,  f4,  f4,  f4,  f4,   f4,  f4,   iv,   cp8,  cp16},\n      /* f8 */   {f8,  f8,  f8,  f8,  f8,  f8,  f8,  f8,  iv,  f8,  iv,  bf,  f8,  f8,  f8,  f8,  f8,   f8,  f8,   iv,   cp8,  cp16},\n      /* i1 */   {i1,  i1,  f4,  f8,  i1,  i4,  i8,  i2,  iv,  f2,  iv,  bf,  i1,  i4,  i8,  i16, iv,   i2,  i16,  iv,   cp8,  cp16},\n      /* i4 */   {i4,  i4,  f4,  f8,  i4,  i4,  i8,  i4,  iv,  f2,  iv,  bf,  i4,  i4,  i8,  i16, iv,   i4,  i16,  iv,   cp8,  cp16},\n      /* i8 */   {i8,  i8,  f4,  f8,  i8,  i8,  i8,  i8,  iv,  f2,  iv,  bf,  i8,  i8,  i8,  i16, iv,   i8,  i16,  iv,   cp8,  cp16},\n      /* u1 */   {u1,  c1,  f4,  f8,  i2,  i4,  i8,  u1,  iv,  f2,  iv,  bf,  u1,  u2,  u4,  u8,  u16,  i2,  i16,  iv,   cp8,  cp16},\n      /* re */   {iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,   iv,  iv,   iv,   iv,   iv},\n      /* f2 */   {f2,  f2,  f4,  f8,  f2,  f2,  f2,  f2,  iv,  f2,  iv,  bf,  f2,  f2,  f2,  f2,  iv,   f2,  f2,   iv,   cp8,  cp16},\n      /* bu */   {iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  bu,  iv,  iv,  iv,  iv,  iv,  iv,   iv,  iv,   iv,   iv,   iv},\n      /* bf */   {bf,  bf,  bf,  bf,  bf,  bf,  bf,  bf,  iv,  bf,  iv,  bf,  bf,  bf,  bf,  bf,  iv,   bf,  bf,   iv,   cp8,  cp16},\n      /* b1 */   {b1,  c1,  f4,  f8,  i1,  i4,  i8,  u1,  iv,  f2,  iv,  bf,  b1,  u2,  u4,  u8,  u16,  i2,  i16,  iv,   cp8,  cp16},\n      /* u2 */   {u2,  u2,  f4,  f8,  i4,  i4,  i8,  u2,  iv,  f2,  iv,  bf,  u2,  u2,  u4,  u8,  u16,  i4,  i16,  iv,   cp8,  cp16},\n      /* u4 */   {u4,  u4,  f4,  f8,  i8,  i8,  i8,  u4,  iv,  f2,  iv,  bf,  u4,  u4,  u4,  u8,  u16,  i8,  i16,  iv,   cp8,  cp16},\n      /* u8 */   {u8,  u8,  f4,  f8,  i16, i16, i16, u8,  iv,  f2,  iv,  bf,  u8,  u8,  u8,  u8,  u16,  i16, i16,  iv,   cp8,  cp16},\n      /* u16 */  {u16, u16, f4,  f8,  iv,  iv,  iv,  u16, iv,  f2,  iv,  bf,  u16, u16, u16, u16, u16,  iv,  iv,   iv,   cp8,  cp16},\n      /* i2 */   {i2,  i2,  f4,  f8,  i2,  i4,  i8,  i2,  iv,  f2,  iv,  bf,  i2,  i4,  i8,  i16, iv,   i2,  i16,  iv,   cp8,  cp16},\n      /* i16 */  {i16, i16, f4,  f8,  i16, i16, i16, i16, iv,  f2,  iv,  bf,  i16, i16, i16, i16, iv,   i16, i16,  iv,   cp8,  cp16},\n      /* cp4 */  {iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,  iv,   iv,  iv,   cp4,  cp8,  cp16},\n      /* cp8 */  {cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, iv,  cp8, iv,  cp8, cp8, cp8, cp8, cp8, cp8,  cp8, cp8,  cp8,  cp8,  cp16},\n      /* cp16 */ {cp16,cp16,cp16,cp16,cp16,cp16,cp16,cp16,iv,  cp16,iv,  cp16,cp16,cp16,cp16,cp16,cp16, cp16,cp16, cp16, cp16, cp16}};\n  // clang-format on\n  return _promoteTypesLookup[static_cast<int>(a->data_type())][static_cast<int>(b->data_type())];\n}\n\nnamespace {\n\nstd::mutex 
default_dtype_mutex;\nSymbol<DType>* GetMutDefaultDTypeSymbol() {\n  static Symbol<DType> default_dtype = CHECK_JUST(DType::Get(DataType::kFloat));\n  return &default_dtype;\n}\n\n}  // namespace\n\nMaybe<void> SetDefaultDType(const Symbol<DType>& dtype) {\n  std::lock_guard<std::mutex> lock(default_dtype_mutex);\n  CHECK_OR_RETURN(dtype->is_floating_point())\n      << \"only floating-point types are supported as the default type\";\n  *GetMutDefaultDTypeSymbol() = dtype;\n  return Maybe<void>::Ok();\n}\n\nSymbol<DType> GetDefaultDType() {\n  std::lock_guard<std::mutex> lock(default_dtype_mutex);\n  return *GetMutDefaultDTypeSymbol();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/dtype.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_DTYPE_H_\n#define ONEFLOW_CORE_FRAMEWORK_DTYPE_H_\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/symbol.h\"\n\nnamespace oneflow {\n\n#define DTYPE_SEQ                       \\\n  OF_PP_MAKE_TUPLE_SEQ(InvalidDataType) \\\n  OF_PP_MAKE_TUPLE_SEQ(Bool)            \\\n  OF_PP_MAKE_TUPLE_SEQ(Char)            \\\n  OF_PP_MAKE_TUPLE_SEQ(Float16)         \\\n  OF_PP_MAKE_TUPLE_SEQ(Float)           \\\n  OF_PP_MAKE_TUPLE_SEQ(Double)          \\\n  OF_PP_MAKE_TUPLE_SEQ(Int8)            \\\n  OF_PP_MAKE_TUPLE_SEQ(Int32)           \\\n  OF_PP_MAKE_TUPLE_SEQ(Int64)           \\\n  OF_PP_MAKE_TUPLE_SEQ(UInt8)           \\\n  OF_PP_MAKE_TUPLE_SEQ(OFRecord)        \\\n  OF_PP_MAKE_TUPLE_SEQ(TensorBuffer)    \\\n  OF_PP_MAKE_TUPLE_SEQ(BFloat16)        \\\n  OF_PP_MAKE_TUPLE_SEQ(UInt16)          \\\n  OF_PP_MAKE_TUPLE_SEQ(UInt32)          \\\n  OF_PP_MAKE_TUPLE_SEQ(UInt64)          \\\n  OF_PP_MAKE_TUPLE_SEQ(UInt128)         \\\n  OF_PP_MAKE_TUPLE_SEQ(Int16)           \\\n  OF_PP_MAKE_TUPLE_SEQ(Int128)          \\\n  OF_PP_MAKE_TUPLE_SEQ(Complex32)       \\\n  OF_PP_MAKE_TUPLE_SEQ(Complex64)       \\\n  OF_PP_MAKE_TUPLE_SEQ(Complex128)\n\nclass DType final {\n public:\n  DType(const DType&) = default;\n  DType(DType&&) = delete;\n  explicit DType(DataType data_type) : data_type_(data_type) {}\n  ~DType() = default;\n\n  bool operator==(const DType& other) const { return this->data_type() == other.data_type(); }\n\n  DataType data_type() const { return data_type_; }\n  bool is_signed() const;\n  bool is_complex() const;\n  bool is_integer() const;\n  bool is_floating_point() const;\n  const std::string& name() const;\n  Maybe<size_t> bytes() const;\n\n  static Maybe<const Symbol<DType>&> Get(DataType);\n  static const int priority_order[DataType_ARRAYSIZE];\n\n#define DECLARE_GET_DATA_TYPE_FUNCTION(data_type) static const Symbol<DType>& data_type();\n  OF_PP_FOR_EACH_TUPLE(DECLARE_GET_DATA_TYPE_FUNCTION, DTYPE_SEQ)\n#undef DECLARE_GET_DATA_TYPE_FUNCTION\n\n private:\n  DataType data_type_;\n};\n\nSymbol<DType> promoteTypes(const Symbol<DType> a, const Symbol<DType> b);\n\nMaybe<void> SetDefaultDType(const Symbol<DType>& dtype);\nSymbol<DType> GetDefaultDType();\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::DType> final {\n  size_t operator()(const oneflow::DType& dtype) const {\n    return static_cast<size_t>(dtype.data_type());\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_DTYPE_H_\n"
  },
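  {
    "path": "docs/examples/dtype_promotion_usage_sketch.cpp",
    "content": "// NOTE: hypothetical usage sketch added for illustration; NOT part of the OneFlow source tree.\n// It spells out a few rows of the type-promotion lattice implemented in\n// oneflow/core/framework/dtype.cpp (CHECK comes from glog via oneflow/core/common/util.h).\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/dtype.h\"\n\nnamespace oneflow {\n\ninline void ExamplePromoteTypes() {\n  // int8 + uint8 has no common 8-bit type, so the pair widens to int16.\n  CHECK(promoteTypes(DType::Int8(), DType::UInt8()) == DType::Int16());\n  // Mixing integer and floating point promotes to the floating-point side.\n  CHECK(promoteTypes(DType::Int32(), DType::Float()) == DType::Float());\n  // int8 + uint128 would need int256, which does not exist, so the table yields the invalid type.\n  CHECK(promoteTypes(DType::Int8(), DType::UInt128()) == DType::InvalidDataType());\n}\n\n}  // namespace oneflow\n"
  },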
  {
    "path": "oneflow/core/framework/eager_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_VM_UTIL_H_\n#define ONEFLOW_CORE_FRAMEWORK_VM_UTIL_H_\n#endif  // ONEFLOW_CORE_FRAMEWORK_VM_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/framework/framework.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_FRAMEWORK_H_\n#define ONEFLOW_CORE_FRAMEWORK_FRAMEWORK_H_\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/framework/util.h\"\n\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/sbp_context.h\"\n#include \"oneflow/core/framework/infer_output_blob_time_shape_fn_context.h\"\n#include \"oneflow/core/framework/infer_nd_sbp_fn_context.h\"\n#include \"oneflow/core/framework/compute_complexity_fn_context.h\"\n#include \"oneflow/core/framework/get_nd_sbp_signature_list_context.h\"\n#include \"oneflow/core/framework/user_op_hob.h\"\n\n#include \"oneflow/core/common/tensor_desc.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/framework/user_op_def.h\"\n#include \"oneflow/core/framework/multi_thread.h\"\n#include \"oneflow/core/framework/to_string.h\"\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_FRAMEWORK_H_\n"
  },
  {
    "path": "oneflow/core/framework/get_nd_sbp_signature_list_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_COMPUTE_GET_ND_SBP_SIGNATURE_LIST_CONTEXT_H_\n#define ONEFLOW_CORE_FRAMEWORK_COMPUTE_GET_ND_SBP_SIGNATURE_LIST_CONTEXT_H_\n\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/framework/user_op_registry.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n\nnamespace oneflow {\n\nclass Shape;\n\nnamespace user_op {\n\nclass UserOpDefWrapper;\n\nclass GetNdSbpSignatureListContext {\n public:\n  virtual ~GetNdSbpSignatureListContext() = default;\n\n  virtual void AddNdSbpSignature(NdSbpSignature&) = 0;\n  virtual std::vector<NdSbpSignature>* MutNdSbpSignatureList() = 0;\n  virtual const Shape& parallel_hierarchy() = 0;\n  virtual const Shape& BlobShape4InputArgNameAndIndex(const std::string& arg_name,\n                                                      int32_t index) const = 0;\n  template<typename T>\n  T Attr(const std::string& attr_name) const {\n    return conf_.attr<T>(attr_name);\n  }\n\n  const UserOpConfWrapper& user_op_conf() const { return conf_; }\n\n protected:\n  explicit GetNdSbpSignatureListContext(UserOpConfWrapper&& conf) : conf_(std::move(conf)) {}\n\n private:\n  UserOpConfWrapper conf_;\n};\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_COMPUTE_GET_ND_SBP_SIGNATURE_LIST_CONTEXT_H_\n"
  },
  {
    "path": "oneflow/core/framework/global_param_grad_sync_mode.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/global_param_grad_sync_mode.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool* GetThreadLocalGradSyncMode() {\n  static thread_local bool g_grad_mode = true;\n  return &g_grad_mode;\n}\n\n}  // namespace\n\nbool GlobalGradSyncMode::is_enabled() { return *GetThreadLocalGradSyncMode(); }\n\nvoid GlobalGradSyncMode::set_enabled(bool enabled) { *GetThreadLocalGradSyncMode() = enabled; }\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/global_param_grad_sync_mode.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_FRAMEWORK_GLOBAL_PARAM_GRAD_SYNC_MODE_\n#define ONEFLOW_CORE_FRAMEWORK_GLOBAL_PARAM_GRAD_SYNC_MODE_\n\nnamespace oneflow {\n\nstruct GlobalGradSyncMode {\n  static bool is_enabled();\n  static void set_enabled(bool enabled);\n};\n\nclass GlobalParamGradSyncMode {\n public:\n  GlobalParamGradSyncMode(bool enabled) : prev_mode_(GlobalGradSyncMode::is_enabled()) {\n    GlobalGradSyncMode::set_enabled(enabled);\n  }\n  ~GlobalParamGradSyncMode() { GlobalGradSyncMode::set_enabled(prev_mode_); }\n  bool prev_mode() const { return prev_mode_; }\n\n private:\n  bool prev_mode_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_GLOBAL_PARAM_GRAD_SYNC_MODE_\n"
  },
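global_param_grad_sync_mode.h above is a save/set/restore RAII guard over a thread-local flag, so its intended use is scoped: construct to override the mode, and let the destructor restore the previous value. A minimal sketch, assuming compilation inside the OneFlow tree; RunWithoutGradSync is a hypothetical caller:

```cpp
#include "oneflow/core/framework/global_param_grad_sync_mode.h"

namespace oneflow {

// Hypothetical caller: temporarily disable param-grad sync on this thread only.
void RunWithoutGradSync() {
  GlobalParamGradSyncMode guard(false);  // saves the previous mode, then sets enabled=false
  // ... work that must not trigger param-grad sync ...
}  // guard's destructor restores the previous thread-local mode here

}  // namespace oneflow
```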
  {
    "path": "oneflow/core/framework/global_tensor_infer_cache.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/global_tensor_infer_cache.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/env_var/eager.h\"\n\nnamespace oneflow {\nnamespace one {\n\nnamespace {\n\nbool OptionalEqual(const Optional<Symbol<NdSbp>>& lhs, const Optional<Symbol<NdSbp>>& rhs) {\n  if (lhs.has_value() != rhs.has_value()) { return false; }\n  if (!lhs.has_value()) { return true; }\n  return CHECK_JUST(lhs) == CHECK_JUST(rhs);\n}\n\n}  // namespace\n\nsize_t InputGlobalTensorMeta::hash_value() const {\n  size_t hash_value = std::hash<Symbol<GlobalTensorMeta>>()(tensor_meta());\n  if (consumer_nd_sbp_constraint().has_value()) {\n    AddHash(&hash_value, CHECK_JUST(consumer_nd_sbp_constraint()));\n  }\n  return hash_value;\n}\n\nbool InputGlobalTensorMeta::operator==(const InputGlobalTensorMeta& other) const {\n  return this->tensor_meta() == other.tensor_meta()\n         && OptionalEqual(this->consumer_nd_sbp_constraint(), other.consumer_nd_sbp_constraint());\n}\n\nvoid InputGlobalTensorMeta::assign(Symbol<GlobalTensorMeta> tensor_meta,\n                                   const Optional<Symbol<NdSbp>>& consumer_nd_sbp_constraint) {\n  tensor_meta_ = tensor_meta;\n  consumer_nd_sbp_constraint_ = consumer_nd_sbp_constraint;\n}\n\nsize_t GlobalTensorMetaInferArgs::hash_value() const {\n  size_t hash_value = std::hash<AttrMap>()(attrs_);\n  const auto& tensor_meta_hash_functor = std::hash<InputGlobalTensorMeta>();\n  for (const auto& tensor_meta : input_global_tensor_metas_) {\n    HashCombine(&hash_value, tensor_meta_hash_functor(tensor_meta));\n  }\n  return hash_value;\n}\n\nsize_t SrcOpGlobalTensorMetaInferArgs::hash_value() const {\n  size_t hash_value = std::hash<AttrMap>()(attrs_);\n  AddHash(&hash_value, parallel_desc_);\n  AddHash(&hash_value, nd_sbp_);\n  return hash_value;\n}\n\nbool GlobalTensorMetaInferArgs::operator==(const GlobalTensorMetaInferArgs& other) const {\n  return this->attrs_ == other.attrs_\n         && this->input_global_tensor_metas_ == other.input_global_tensor_metas_;\n}\n\nbool SrcOpGlobalTensorMetaInferArgs::operator==(const SrcOpGlobalTensorMetaInferArgs& other) const {\n  return this->attrs_ == other.attrs_ && this->parallel_desc_ == other.parallel_desc_\n         && this->nd_sbp_ == other.nd_sbp_;\n}\n\nMaybe<void> GlobalTensorMetaInferArgs::MakeNdSbpConstraints(\n    const UserOpExpr& user_op_expr, NdSbpSignature* nd_sbp_signature) const {\n  const auto& input_arg_tuple = *user_op_expr.input_arg_tuple();\n  auto* map = nd_sbp_signature->mutable_bn_in_op2nd_sbp();\n  for (int i = 0; i < input_arg_tuple.size(); ++i) 
{\n    const auto& constraint = input_global_tensor_metas_[i].consumer_nd_sbp_constraint();\n    if (constraint.has_value()) { (*map)[input_arg_tuple.indexed_bns().at(i)] = *JUST(constraint); }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GlobalTensorMetaInferArgs::MakeInputBlobDescs(const UserOpExpr& user_op_expr,\n                                                          std::vector<BlobDesc>* blob_descs) const {\n  CHECK_OR_RETURN(blob_descs->empty());\n  const auto& input_arg_tuple = *user_op_expr.input_arg_tuple();\n  blob_descs->reserve(input_arg_tuple.size());\n  for (int i = 0; i < input_arg_tuple.size(); ++i) {\n    const auto& tensor_meta = *input_global_tensor_metas_[i].tensor_meta();\n    blob_descs->emplace_back(tensor_meta.shape(), tensor_meta.stride(), tensor_meta.data_type(),\n                             tensor_meta.memory_format());\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GlobalTensorMetaInferArgs::MakeNdSbpInferHints(\n    const UserOpExpr& user_op_expr, const std::vector<BlobDesc>& blob_descs,\n    std::vector<NdSbpInferHint>* hints) const {\n  CHECK_OR_RETURN(hints->empty());\n  const auto& input_arg_tuple = *user_op_expr.input_arg_tuple();\n  hints->reserve(input_arg_tuple.size());\n  for (int i = 0; i < input_arg_tuple.size(); ++i) {\n    const auto& tensor_meta = *input_global_tensor_metas_[i].tensor_meta();\n    const auto* parallel_desc = &*tensor_meta.parallel_desc();\n    const auto* blob_desc = &blob_descs.at(i);\n    const auto* nd_sbp = &*tensor_meta.nd_sbp();\n    hints->emplace_back(parallel_desc, blob_desc, nd_sbp);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<GlobalTensorMetaInferArgs> GlobalTensorMetaInferArgs::New(const AttrMap& attrs,\n                                                                const TensorTuple& input_tensors) {\n  std::shared_ptr<GlobalTensorMetaInferArgs> infer_args(new GlobalTensorMetaInferArgs());\n  infer_args->attrs_ = attrs;\n  infer_args->input_global_tensor_metas_.resize(input_tensors.size());\n  JUST(infer_args->InitInputGlobalTensorMetas(input_tensors));\n  return infer_args;\n}\n\nMaybe<SrcOpGlobalTensorMetaInferArgs> SrcOpGlobalTensorMetaInferArgs::New(\n    const AttrMap& attrs, Symbol<ParallelDesc> parallel_desc, Symbol<NdSbp> nd_sbp) {\n  std::shared_ptr<SrcOpGlobalTensorMetaInferArgs> infer_args(new SrcOpGlobalTensorMetaInferArgs());\n  infer_args->attrs_ = attrs;\n  infer_args->parallel_desc_ = parallel_desc;\n  infer_args->nd_sbp_ = nd_sbp;\n  return infer_args;\n}\n\nMaybe<void> GlobalTensorMetaInferArgs::InitInputGlobalTensorMetas(\n    const TensorTuple& input_tensors) {\n  for (int i = 0; i < input_tensors.size(); ++i) {\n    const auto& tensor = *input_tensors.at(i);\n    const auto& tensor_meta = JUST(tensor.global_tensor_meta());\n    const auto& constraint = JUST(tensor.consumer_nd_sbp_constraint());\n    input_global_tensor_metas_[i].assign(tensor_meta, constraint);\n  }\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<Operator> MakeOp(const UserOpExpr& user_op_expr, const AttrMap& attrs,\n                       const std::string& device_tag) {\n  OperatorConf op_conf;\n  JUST(user_op_expr.BuildOpConf(&op_conf, attrs));\n  DeviceType device_type = JUST(DeviceType4DeviceTag(device_tag));\n  return JUST(ConstructOp(op_conf, device_type));\n}\n\nMaybe<void> CheckInputParallelDescIdentical(const GlobalTensorMetaInferArgs& infer_args,\n                                            const UserOpExpr& user_op_expr) {\n  if (infer_args.input_global_tensor_metas().empty()) { return 
Maybe<void>::Ok(); }\n  Symbol<ParallelDesc> default_parallel_desc;\n  for (int i = 0; i < infer_args.input_global_tensor_metas().size(); ++i) {\n    if (user_op_expr.IsHostMemoryInput(i)) { continue; }\n    default_parallel_desc =\n        JUST(VectorAt(infer_args.input_global_tensor_metas(), i)).tensor_meta()->parallel_desc();\n    break;\n  }\n\n  for (int i = 0; i < infer_args.input_global_tensor_metas().size(); ++i) {\n    if (user_op_expr.IsHostMemoryInput(i)) { continue; }\n    CHECK_OR_RETURN(\n        default_parallel_desc\n        == JUST(VectorAt(infer_args.input_global_tensor_metas(), i)).tensor_meta()->parallel_desc())\n        << Error::RuntimeError()\n        << \"Expected all tensors to be on the same placement, but found \"\n           \"at least two placements, \"\n        << *JUST(PlacementToString(default_parallel_desc)) << \" (positional 0) and \"\n        << *JUST(PlacementToString(JUST(VectorAt(infer_args.input_global_tensor_metas(), i))\n                                       .tensor_meta()\n                                       ->parallel_desc()))\n        << \" (positional \" << i << \")!\";\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckIsDeviceSupportedByOp(const ParallelDesc& parallel_desc,\n                                       const std::string& op_type_name) {\n  if (IsCpuOnly(op_type_name)) { CHECK_EQ_OR_RETURN(parallel_desc.device_tag(), \"cpu\"); }\n  return Maybe<void>::Ok();\n}\n\nclass UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStreamInferContext {\n public:\n  UserOpExprDeviceAndStreamInferContext(const UserOpExpr* user_op_expr,\n                                        const GlobalTensorMetaInferArgs* infer_args)\n      : user_op_expr_(user_op_expr),\n        composed_attrs_(infer_args->attrs(), user_op_expr->base_attrs()),\n        in_tensor_devices_(user_op_expr_->input_size()),\n        out_tensor_devices_(user_op_expr_->output_size()) {\n    for (int i = 0; i < user_op_expr_->input_size(); ++i) {\n      const auto& parallel_desc =\n          infer_args->input_global_tensor_metas().at(i).tensor_meta()->parallel_desc();\n      in_tensor_devices_.at(i) = CHECK_JUST(GetTensorDevice(parallel_desc));\n    }\n  }\n\n  const std::vector<std::pair<std::string, int32_t>>& inputs() const override {\n    return user_op_expr_->indexed_input_pairs();\n  }\n\n  const std::vector<std::pair<std::string, int32_t>>& outputs() const override {\n    return user_op_expr_->indexed_output_pairs();\n  }\n\n  Symbol<Device>* OutputTensorDevice4ArgNameAndIndex(const std::string& name,\n                                                     int64_t index) override {\n    const auto& arg_tuple = *user_op_expr_->output_arg_tuple();\n    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);\n    CHECK_GE(tuple_index, 0);\n    CHECK_LT(tuple_index, user_op_expr_->output_size());\n    return &out_tensor_devices_.at(tuple_index);\n  }\n\n  Symbol<Device> InputTensorDevice4ArgNameAndIndex(const std::string& name,\n                                                   int64_t index) const override {\n    const auto& arg_tuple = *user_op_expr_->input_arg_tuple();\n    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);\n    CHECK_GE(tuple_index, 0);\n    CHECK_LT(tuple_index, user_op_expr_->input_size());\n    return in_tensor_devices_.at(tuple_index);\n  }\n\n private:\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return 
composed_attrs_.Attr4Name(attr_name);\n  }\n  const UserOpExpr* user_op_expr_;\n  const ComposedAttrMap composed_attrs_;\n  std::vector<Symbol<Device>> in_tensor_devices_;\n  std::vector<Symbol<Device>> out_tensor_devices_;\n};\n\n}  // namespace\n\n/* static */ Maybe<Symbol<Stream>> GlobalTensorInferCache::InferDeviceAndStream(\n    const UserOpExpr& user_op_expr, const GlobalTensorMetaInferArgs& infer_args) {\n  if (!user_op_expr.device_and_stream_infer_fn()) {\n    Symbol<ParallelDesc> parallel_desc =\n        infer_args.input_global_tensor_metas()[0].tensor_meta()->parallel_desc();\n    return GetDefaultStreamByPlacement(parallel_desc);\n  } else {\n    UserOpExprDeviceAndStreamInferContext device_and_stream_ctx(&user_op_expr, &infer_args);\n    return TRY(user_op_expr.device_and_stream_infer_fn()(&device_and_stream_ctx));\n  }\n}\n\n/* static */ Maybe<const GlobalTensorInferResult> GlobalTensorInferCache::Infer(\n    const UserOpExpr& user_op_expr, const GlobalTensorMetaInferArgs& infer_args) {\n  CHECK_GT_OR_RETURN(infer_args.input_global_tensor_metas().size(), 0);  // NOLINT\n  Symbol<ParallelDesc> parallel_desc =\n      infer_args.input_global_tensor_metas()[0].tensor_meta()->parallel_desc();\n  JUST(CheckInputParallelDescIdentical(infer_args, user_op_expr));\n  JUST(CheckIsDeviceSupportedByOp(*parallel_desc, user_op_expr.op_type_name()));\n  std::vector<OpArgMutGlobalTensorMeta> output_mut_metas(user_op_expr.output_size());\n  {\n    // Infer OpArgMutGlobalTensorMeta.\n    const auto& input_metas = infer_args.input_global_tensor_metas();\n    JUST(user_op_expr.InferLogicalTensorDesc(\n        infer_args.attrs(), parallel_desc,\n        [&](int32_t i) { return &*input_metas.at(i).tensor_meta(); },\n        [&](int32_t i) { return output_mut_metas.at(i).mut_tensor_meta(); }));\n  }\n  const auto& op = JUST(MakeOp(user_op_expr, infer_args.attrs(), parallel_desc->device_tag()));\n  JUST(op->FillOpParallelDesc(parallel_desc.shared_from_symbol()));\n  JUST(op->InferParallelSignatureIf());\n  {\n    // Infer parallel distribution.\n    NdSbpSignature nd_sbp_constraints;\n    JUST(infer_args.MakeNdSbpConstraints(user_op_expr, &nd_sbp_constraints));\n    std::vector<BlobDesc> blob_descs;\n    JUST(infer_args.MakeInputBlobDescs(user_op_expr, &blob_descs));\n    std::vector<NdSbpInferHint> pd_infer_hints;\n    JUST(infer_args.MakeNdSbpInferHints(user_op_expr, blob_descs, &pd_infer_hints));\n    const auto& input_arg_tuple = *user_op_expr.input_arg_tuple();\n    const auto& NdSbpInferHint4Ibn = [&](const std::string& ibn) -> Maybe<const NdSbpInferHint*> {\n      int32_t input_index = input_arg_tuple.bn_in_op2tensor_tuple_index().at(ibn);\n      CHECK_GE_OR_RETURN(input_index, 0);\n      CHECK_LT_OR_RETURN(input_index, pd_infer_hints.size());\n      return &pd_infer_hints.at(input_index);\n    };\n    // The inferred results can be retrieved by op->NdSbp4BnInOp(obn).\n    JUST(op->InferNdSbpSignatureIf(nd_sbp_constraints, *parallel_desc, NdSbpInferHint4Ibn));\n  }\n  auto result = std::make_unique<GlobalTensorInferResult>(user_op_expr.input_size(),\n                                                          user_op_expr.output_size());\n  auto* input_metas = result->mut_input_tensor_metas();\n  for (int32_t i = 0; i < user_op_expr.input_size(); ++i) {\n    const auto& old_global_tensor_meta = infer_args.input_global_tensor_metas()[i].tensor_meta();\n    const auto& ibn = user_op_expr.input_arg_tuple()->indexed_bns().at(i);\n    const auto& nd_sbp = SymbolOf(*JUST(op->NdSbp4BnInOp(ibn)));\n    
GlobalTensorMeta global_tensor_meta(\n        old_global_tensor_meta->shape(), old_global_tensor_meta->dtype(),\n        old_global_tensor_meta->memory_format(), nd_sbp, old_global_tensor_meta->parallel_desc());\n    (*input_metas)[i] = SymbolOf(global_tensor_meta);\n  }\n  auto* output_metas = result->mut_output_tensor_metas();\n  for (int32_t i = 0; i < user_op_expr.output_size(); ++i) {\n    const auto& output_mut_meta = output_mut_metas.at(i);\n    const auto& shape = output_mut_meta.tensor_meta().shape();\n    DataType data_type = output_mut_meta.tensor_meta().data_type();\n    MemoryFormat memory_format = output_mut_meta.tensor_meta().memory_format();\n    const auto& obn = user_op_expr.output_arg_tuple()->indexed_bns().at(i);\n    const auto& nd_sbp = SymbolOf(*JUST(op->NdSbp4BnInOp(obn)));\n    GlobalTensorMeta tensor_meta(shape, data_type, memory_format, nd_sbp, parallel_desc);\n    output_metas->at(i) = SymbolOf(tensor_meta);\n  }\n  result->set_stream(JUST(InferDeviceAndStream(user_op_expr, infer_args)));\n  return std::shared_ptr<const GlobalTensorInferResult>(std::move(result));\n}\n\n/* static */ Maybe<const GlobalTensorInferResult> GlobalTensorInferCache::Infer(\n    const UserOpExpr& user_op_expr, const SrcOpGlobalTensorMetaInferArgs& infer_args) {\n  Symbol<ParallelDesc> parallel_desc = infer_args.parallel_desc();\n  JUST(CheckIsDeviceSupportedByOp(*parallel_desc, user_op_expr.op_type_name()));\n  std::vector<OpArgMutGlobalTensorMeta> output_mut_metas(user_op_expr.output_size());\n  {\n    // Infer OpArgMutGlobalTensorMeta.\n    const auto& GetInputTensorMeta = [](int32_t i) {\n      UNIMPLEMENTED();\n      return nullptr;\n    };\n    JUST(user_op_expr.InferLogicalTensorDesc(\n        infer_args.attrs(), parallel_desc, GetInputTensorMeta,\n        [&](int32_t i) { return output_mut_metas.at(i).mut_tensor_meta(); }));\n  }\n  auto result = std::make_unique<GlobalTensorInferResult>(user_op_expr.input_size(),\n                                                          user_op_expr.output_size());\n  auto* output_metas = result->mut_output_tensor_metas();\n  for (int32_t i = 0; i < user_op_expr.output_size(); ++i) {\n    const auto& output_mut_meta = output_mut_metas.at(i);\n    const auto& shape = output_mut_meta.tensor_meta().shape();\n    DataType data_type = output_mut_meta.tensor_meta().data_type();\n    MemoryFormat memory_format = output_mut_meta.tensor_meta().memory_format();\n    const auto& nd_sbp = infer_args.nd_sbp();\n    GlobalTensorMeta tensor_meta(shape, data_type, memory_format, nd_sbp, parallel_desc);\n    output_metas->at(i) = SymbolOf(tensor_meta);\n  }\n  result->set_stream(JUST(GetDefaultStreamByPlacement(parallel_desc)));\n  return std::shared_ptr<const GlobalTensorInferResult>(std::move(result));\n}\n\nMaybe<const GlobalTensorInferResult> GlobalTensorInferCache::GetOrInfer(\n    const GlobalTensorMetaInferArgs& infer_args) {\n  auto iter = cache_.find(infer_args);\n  if (iter == cache_.end()) {\n    if (unlikely(cache_.size() >= ThreadLocalEnvInteger<ONEFLOW_EAGER_TENSOR_INFER_CACHE_SIZE>())) {\n      cache_.clear();\n    }\n    const auto& user_op_expr = user_op_expr_.lock();\n    CHECK_OR_RETURN(static_cast<bool>(user_op_expr));\n    const auto& output_tensor_metas = JUST(Infer(*user_op_expr, infer_args));\n    iter = cache_.emplace(infer_args, output_tensor_metas).first;\n  }\n  return iter->second;\n}\n\nMaybe<const GlobalTensorInferResult> GlobalTensorInferCache::GetOrInfer(\n    const SrcOpGlobalTensorMetaInferArgs& infer_args) {\n  auto iter = 
src_op_cache_.find(infer_args);\n  if (iter == src_op_cache_.end()) {\n    if (unlikely(src_op_cache_.size()\n                 >= ThreadLocalEnvInteger<ONEFLOW_EAGER_TENSOR_INFER_CACHE_SIZE>())) {\n      src_op_cache_.clear();\n    }\n    const auto& user_op_expr = user_op_expr_.lock();\n    CHECK_OR_RETURN(static_cast<bool>(user_op_expr));\n    const auto& output_tensor_metas = JUST(Infer(*user_op_expr, infer_args));\n    iter = src_op_cache_.emplace(infer_args, output_tensor_metas).first;\n  }\n  return iter->second;\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/global_tensor_infer_cache.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_GLOBAL_TENSOR_INFER_CACHE_H_\n#define ONEFLOW_CORE_FRAMEWORK_GLOBAL_TENSOR_INFER_CACHE_H_\n\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/common/tensor_meta.h\"\n#include \"oneflow/core/register/blob_desc.h\"\n#include \"oneflow/core/job/nd_sbp_infer_hint.h\"\n\nnamespace oneflow {\n\nclass NdSbp;\n\nclass ParallelDesc;\n\nnamespace one {\n\nclass GlobalTensorMeta;\n\nclass InputGlobalTensorMeta final {\n public:\n  InputGlobalTensorMeta() : tensor_meta_(), consumer_nd_sbp_constraint_() {}\n  InputGlobalTensorMeta(Symbol<GlobalTensorMeta> tensor_meta,\n                        const Optional<Symbol<NdSbp>>& consumer_nd_sbp_constraint)\n      : tensor_meta_(tensor_meta), consumer_nd_sbp_constraint_(consumer_nd_sbp_constraint) {}\n\n  InputGlobalTensorMeta(const InputGlobalTensorMeta&) = default;\n  InputGlobalTensorMeta(InputGlobalTensorMeta&&) = default;\n  ~InputGlobalTensorMeta() = default;\n\n  size_t hash_value() const;\n  bool operator==(const InputGlobalTensorMeta& other) const;\n  Symbol<GlobalTensorMeta> tensor_meta() const { return tensor_meta_; }\n  const Optional<Symbol<NdSbp>>& consumer_nd_sbp_constraint() const {\n    return consumer_nd_sbp_constraint_;\n  }\n  void assign(Symbol<GlobalTensorMeta> tensor_meta,\n              const Optional<Symbol<NdSbp>>& consumer_nd_sbp_constraint);\n\n private:\n  Symbol<GlobalTensorMeta> tensor_meta_;\n  Optional<Symbol<NdSbp>> consumer_nd_sbp_constraint_;\n};\n\nclass TensorTuple;\nclass UserOpExpr;\n\nclass GlobalTensorMetaInferArgs final {\n public:\n  GlobalTensorMetaInferArgs(const GlobalTensorMetaInferArgs&) = default;\n  GlobalTensorMetaInferArgs(GlobalTensorMetaInferArgs&&) = default;\n  ~GlobalTensorMetaInferArgs() = default;\n\n  const std::vector<InputGlobalTensorMeta>& input_global_tensor_metas() const {\n    return input_global_tensor_metas_;\n  }\n  const AttrMap& attrs() const { return attrs_; }\n\n  size_t hash_value() const;\n\n  bool operator==(const GlobalTensorMetaInferArgs& other) const;\n\n  Maybe<void> MakeNdSbpConstraints(const UserOpExpr& user_op_expr,\n                                   NdSbpSignature* nd_sbp_signature) const;\n\n  Maybe<void> MakeInputBlobDescs(const UserOpExpr& user_op_expr,\n                                 std::vector<BlobDesc>* blob_descs) const;\n\n  Maybe<void> MakeNdSbpInferHints(const UserOpExpr& user_op_expr,\n                                  const std::vector<BlobDesc>& blob_descs,\n                                  std::vector<NdSbpInferHint>* hints) const;\n\n  static Maybe<GlobalTensorMetaInferArgs> New(const AttrMap& attrs,\n                                              const TensorTuple& 
input_tensors);\n\n private:\n  GlobalTensorMetaInferArgs() = default;\n  Maybe<void> InitInputGlobalTensorMetas(const TensorTuple& input_tensors);\n\n  AttrMap attrs_;\n  std::vector<InputGlobalTensorMeta> input_global_tensor_metas_;\n};\n\nclass SrcOpGlobalTensorMetaInferArgs final {\n public:\n  SrcOpGlobalTensorMetaInferArgs(const SrcOpGlobalTensorMetaInferArgs&) = default;\n  SrcOpGlobalTensorMetaInferArgs(SrcOpGlobalTensorMetaInferArgs&&) = default;\n  ~SrcOpGlobalTensorMetaInferArgs() = default;\n\n  Symbol<ParallelDesc> parallel_desc() const { return parallel_desc_; }\n  Symbol<NdSbp> nd_sbp() const { return nd_sbp_; }\n  const AttrMap& attrs() const { return attrs_; }\n\n  size_t hash_value() const;\n\n  bool operator==(const SrcOpGlobalTensorMetaInferArgs& other) const;\n\n  static Maybe<SrcOpGlobalTensorMetaInferArgs> New(const AttrMap& attrs,\n                                                   Symbol<ParallelDesc> parallel_desc,\n                                                   Symbol<NdSbp> nd_sbp);\n\n private:\n  SrcOpGlobalTensorMetaInferArgs() = default;\n\n  AttrMap attrs_;\n  Symbol<ParallelDesc> parallel_desc_;\n  Symbol<NdSbp> nd_sbp_;\n};\n\nclass OpArgMutGlobalTensorMeta final {\n public:\n  OpArgMutGlobalTensorMeta()\n      : tensor_meta_(std::make_shared<Shape>(), DataType::kInvalidDataType,\n                     MemoryFormat::kContiguous) {}\n\n  OpArgMutGlobalTensorMeta(const OpArgMutGlobalTensorMeta&) = default;\n  OpArgMutGlobalTensorMeta(OpArgMutGlobalTensorMeta&&) = default;\n  ~OpArgMutGlobalTensorMeta() = default;\n\n  const TensorMeta& tensor_meta() const { return tensor_meta_; }\n\n  TensorMeta* mut_tensor_meta() { return &tensor_meta_; }\n\n private:\n  MutTensorMeta tensor_meta_;\n};\n\n}  // namespace one\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::one::InputGlobalTensorMeta> final {\n  size_t operator()(const oneflow::one::InputGlobalTensorMeta& val) const {\n    return val.hash_value();\n  }\n};\n\ntemplate<>\nstruct hash<oneflow::one::GlobalTensorMetaInferArgs> final {\n  size_t operator()(const oneflow::one::GlobalTensorMetaInferArgs& val) const {\n    return val.hash_value();\n  }\n};\n\ntemplate<>\nstruct hash<oneflow::one::SrcOpGlobalTensorMetaInferArgs> final {\n  size_t operator()(const oneflow::one::SrcOpGlobalTensorMetaInferArgs& val) const {\n    return val.hash_value();\n  }\n};\n\n}  // namespace std\n\nnamespace oneflow {\nnamespace one {\n\nclass GlobalTensorInferResult final {\n public:\n  GlobalTensorInferResult(size_t input_size, size_t output_size)\n      : input_tensor_metas_(input_size), output_tensor_metas_(output_size) {}\n  GlobalTensorInferResult(const GlobalTensorInferResult&) = delete;\n  GlobalTensorInferResult(GlobalTensorInferResult&&) = delete;\n  ~GlobalTensorInferResult() = default;\n\n  const std::vector<Symbol<GlobalTensorMeta>>& input_tensor_metas() const {\n    return input_tensor_metas_;\n  }\n  const std::vector<Symbol<GlobalTensorMeta>>& output_tensor_metas() const {\n    return output_tensor_metas_;\n  }\n\n  std::vector<Symbol<GlobalTensorMeta>>* mut_input_tensor_metas() { return &input_tensor_metas_; }\n  std::vector<Symbol<GlobalTensorMeta>>* mut_output_tensor_metas() { return &output_tensor_metas_; }\n\n  const Symbol<Stream>& stream() const { return stream_; }\n  void set_stream(const Symbol<Stream>& stream) { stream_ = stream; }\n\n private:\n  std::vector<Symbol<GlobalTensorMeta>> input_tensor_metas_;\n  std::vector<Symbol<GlobalTensorMeta>> output_tensor_metas_;\n  
Symbol<Stream> stream_;\n};\n\nclass GlobalTensorInferCache final {\n public:\n  GlobalTensorInferCache(const std::shared_ptr<const UserOpExpr>& user_op_expr)\n      : user_op_expr_(user_op_expr) {}\n\n  Maybe<const GlobalTensorInferResult> GetOrInfer(const GlobalTensorMetaInferArgs& infer_args);\n\n  static Maybe<const GlobalTensorInferResult> Infer(const UserOpExpr& user_op_expr,\n                                                    const GlobalTensorMetaInferArgs& infer_args);\n\n  Maybe<const GlobalTensorInferResult> GetOrInfer(const SrcOpGlobalTensorMetaInferArgs& infer_args);\n\n  static Maybe<const GlobalTensorInferResult> Infer(\n      const UserOpExpr& user_op_expr, const SrcOpGlobalTensorMetaInferArgs& infer_args);\n\n private:\n  static Maybe<Symbol<Stream>> InferDeviceAndStream(const UserOpExpr& user_op_expr,\n                                                    const GlobalTensorMetaInferArgs& infer_args);\n\n  std::weak_ptr<const UserOpExpr> user_op_expr_;\n  HashMap<GlobalTensorMetaInferArgs, std::shared_ptr<const GlobalTensorInferResult>> cache_;\n  HashMap<SrcOpGlobalTensorMetaInferArgs, std::shared_ptr<const GlobalTensorInferResult>>\n      src_op_cache_;\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_GLOBAL_TENSOR_INFER_CACHE_H_\n"
  },
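Tying the header above to its .cpp: callers hash the op's attributes and per-input GlobalTensorMeta (plus any sbp constraint) into a GlobalTensorMetaInferArgs key, then ask the per-UserOpExpr cache for the memoized GlobalTensorInferResult. A sketch of that call path using only APIs declared above; InferOutputs itself is a hypothetical helper:

```cpp
#include "oneflow/core/framework/global_tensor_infer_cache.h"

namespace oneflow {
namespace one {

// Hypothetical helper: infer output tensor metas for one eager global-op call.
Maybe<const GlobalTensorInferResult> InferOutputs(GlobalTensorInferCache* cache,
                                                  const AttrMap& attrs,
                                                  const TensorTuple& inputs) {
  // Build the hashable key: per-input GlobalTensorMeta + optional nd-sbp constraint.
  const auto& infer_args = JUST(GlobalTensorMetaInferArgs::New(attrs, inputs));
  // A cache hit returns the memoized result; a miss runs GlobalTensorInferCache::Infer
  // (and, per the .cpp above, clears the cache first if it outgrew
  // ONEFLOW_EAGER_TENSOR_INFER_CACHE_SIZE).
  return cache->GetOrInfer(*infer_args);
}

}  // namespace one
}  // namespace oneflow
```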
  {
    "path": "oneflow/core/framework/id_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/id_util.h\"\n\nnamespace oneflow {\n\nMaybe<std::string> UniqueStr(const std::string& prefix) { return prefix + NewUniqueId(); }\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/id_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_ID_UTIL_H_\n#define ONEFLOW_CORE_FRAMEWORK_ID_UTIL_H_\n\n#include <string>\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nMaybe<std::string> UniqueStr(const std::string& prefix);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_ID_UTIL_H_\n"
  },
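id_util above is tiny: UniqueStr appends a process-wide unique id (NewUniqueId) to the given prefix. A one-function sketch of a typical call site; NewOpName is hypothetical:

```cpp
#include "oneflow/core/framework/id_util.h"

namespace oneflow {

// Hypothetical: generate a fresh op name such as "my_op_<unique-id>".
Maybe<std::string> NewOpName() { return UniqueStr("my_op_"); }

}  // namespace oneflow
```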
  {
    "path": "oneflow/core/framework/infer_nd_sbp_fn_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_INFER_ND_SBP_FN_CONTEXT_H_\n#define ONEFLOW_CORE_FRAMEWORK_INFER_ND_SBP_FN_CONTEXT_H_\n\n#include \"oneflow/core/framework/user_op_conf.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nclass InferNdSbpFnContext {\n public:\n  InferNdSbpFnContext() = default;\n  virtual ~InferNdSbpFnContext() = default;\n  InferNdSbpFnContext(const InferNdSbpFnContext&) = delete;\n  virtual const TensorDesc& LogicalTensorDesc4InputArgNameAndIndex(\n      const std::string& input_arg_name, int32_t index) const = 0;\n  virtual NdSbp* NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) = 0;\n  virtual const NdSbp& NdSbpHint4InputArgNameAndIndex(const std::string& arg_name,\n                                                      int32_t index) const = 0;\n  virtual const NdSbpSignature& nd_sbp_constraints() const = 0;\n  virtual const UserOpConfWrapper& user_op_conf() const = 0;\n  virtual int64_t parallel_num() const = 0;\n  virtual const Shape& parallel_hierarchy() = 0;\n  virtual const std::vector<std::pair<std::string, int32_t>>& inputs() const = 0;\n  virtual const std::vector<std::pair<std::string, int32_t>>& outputs() const = 0;\n};\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_INFER_ND_SBP_FN_CONTEXT_H_\n"
  },
  {
    "path": "oneflow/core/framework/infer_output_blob_time_shape_fn_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_INFER_OUTPUT_BLOB_TIME_SHAPE_FN_CONTEXT_H_\n#define ONEFLOW_CORE_FRAMEWORK_INFER_OUTPUT_BLOB_TIME_SHAPE_FN_CONTEXT_H_\n\n#include \"oneflow/core/framework/user_op_conf.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nclass InferOutputBlobTimeShapeFnContext {\n public:\n  InferOutputBlobTimeShapeFnContext() = default;\n  virtual ~InferOutputBlobTimeShapeFnContext() = default;\n  InferOutputBlobTimeShapeFnContext(const InferOutputBlobTimeShapeFnContext&) = delete;\n\n  virtual const Shape& TimeShape4InputArgNameAndIndex(const std::string& arg_name,\n                                                      int32_t index) = 0;\n  virtual const UserOpConfWrapper& user_op_conf() const = 0;\n  virtual Shape* mut_output_blob_time_shape() = 0;\n};\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_INFER_OUTPUT_BLOB_TIME_SHAPE_FN_CONTEXT_H_\n"
  },
  {
    "path": "oneflow/core/framework/infer_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/framework/user_op_def.pb.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/framework/attr_value.h\"\n#include \"oneflow/core/framework/user_op_def.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/framework/attr_value_accessor.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nMaybe<void> TensorDescInferFnUtil::Unchanged(InferContext* ctx) {\n  const TensorDesc* first_tensor_desc = nullptr;\n  for (size_t i = 0; i < ctx->inputs().size(); ++i) {\n    const std::pair<std::string, int32_t>& input_arg = ctx->inputs().at(i);\n    if (first_tensor_desc) {\n      const TensorDesc& tensor_desc = ctx->InputTensorDesc(input_arg.first, input_arg.second);\n      CHECK_EQ_OR_RETURN(tensor_desc.shape(), first_tensor_desc->shape())\n          << Error::RuntimeError() << \"Tensor descriptions should have the same shape: expected \"\n          << first_tensor_desc->shape() << \" but got \" << tensor_desc.shape();\n    } else {\n      first_tensor_desc = &ctx->InputTensorDesc(input_arg.first, input_arg.second);\n    }\n  }\n  for (size_t i = 0; i < ctx->outputs().size(); ++i) {\n    const std::pair<std::string, int32_t>& output_arg = ctx->outputs().at(i);\n    ctx->SetOutputIsDynamic(output_arg.first, output_arg.second,                           // NOLINT\n                            first_tensor_desc->is_dynamic());                              // NOLINT\n    ctx->SetOutputShape(output_arg.first, output_arg.second, first_tensor_desc->shape());  // NOLINT\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> TensorDescInferFnUtil::UnchangedDataType(InferContext* ctx) {\n  const TensorDesc* first_tensor_desc = nullptr;\n  for (size_t i = 0; i < ctx->inputs().size(); ++i) {\n    const std::pair<std::string, int32_t>& input_arg = ctx->inputs().at(i);\n    if (first_tensor_desc) {\n      const TensorDesc& tensor_desc = ctx->InputTensorDesc(input_arg.first, input_arg.second);\n      CHECK_EQ_OR_RETURN(tensor_desc.data_type(), first_tensor_desc->data_type())\n          << Error::TypeError() << \"Tensor descriptions should have the same type. 
Expected \"\n          << DataType_Name(first_tensor_desc->data_type()) << \", but got \"\n          << DataType_Name(tensor_desc.data_type());\n    } else {\n      first_tensor_desc = &ctx->InputTensorDesc(input_arg.first, input_arg.second);\n    }\n  }\n  for (size_t i = 0; i < ctx->outputs().size(); ++i) {\n    const std::pair<std::string, int32_t>& output_arg = ctx->outputs().at(i);\n    ctx->SetOutputDType(output_arg.first, output_arg.second,  // NOLINT\n                        first_tensor_desc->data_type());      // NOLINT\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> TensorDescInferFnUtil::InOutCorrespond(InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->inputs().size(), ctx->outputs().size())\n      << Error::InvalidValueError()\n      << \"Different input and output size. Input size :\" << ctx->inputs().size()\n      << \", output size: \" << ctx->outputs().size();\n  for (size_t i = 0; i < ctx->inputs().size(); ++i) {\n    const auto& input_arg = ctx->inputs().at(i);\n    const auto& output_arg = ctx->outputs().at(i);\n    *ctx->MutOutputTensorDesc(output_arg.first, output_arg.second) =\n        ctx->InputTensorDesc(input_arg.first, input_arg.second);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckAttrFnUtil::NoCheck(const UserOpDefWrapper&, const UserOpConfWrapper&) {\n  return Maybe<void>::Ok();\n}\n\nsize_t TmpSizeInferFnUtil::ZeroTmpSize(InferContext*) { return 0; }\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/infer_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_INFER_UTIL_H_\n#define ONEFLOW_CORE_FRAMEWORK_INFER_UTIL_H_\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/common/tensor_desc.h\"\n#include \"oneflow/core/job/placement.pb.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n\nnamespace oneflow {\n\nclass Shape;\nclass JobDesc;\nclass Device;\n\nnamespace user_op {\nclass AttrVal;\n}  // namespace user_op\n\ntemplate<typename T>\nextern const T& AttrValueCast(const user_op::AttrVal& val);\n\nnamespace user_op {\n\nclass UserOpDefWrapper;\n\nclass InferContext {\n public:\n  virtual ~InferContext() = default;\n\n  virtual const TensorDesc& InputTensorDesc(const std::string&, int32_t) const = 0;\n  virtual const TensorDesc& OutputTensorDesc(const std::string&, int32_t) const = 0;\n  virtual TensorDesc* MutOutputTensorDesc(const std::string&, int32_t) = 0;\n  virtual const TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string&,\n                                                              int32_t) const = 0;\n  virtual const Shape& InputShape(const std::string&, int32_t) const = 0;\n  virtual const Shape& OutputShape(const std::string&, int32_t) const = 0;\n  virtual void SetOutputShape(const std::string&, int32_t, const Shape&) = 0;\n  virtual const Shape& Shape4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n  virtual void SetShape4ArgNameAndIndex(const std::string&, int32_t, const Shape&) = 0;\n  virtual const Stride& InputStride(const std::string&, int32_t) const = 0;\n  virtual const Stride& OutputStride(const std::string&, int32_t) const = 0;\n  virtual void SetOutputStride(const std::string&, int32_t, const Stride&) = 0;\n  virtual const Stride& Stride4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n  virtual void SetStride4ArgNameAndIndex(const std::string&, int32_t, const Stride&) = 0;\n  virtual DataType InputDType(const std::string&, int32_t) const = 0;\n  virtual DataType OutputDType(const std::string&, int32_t) const = 0;\n  virtual void SetOutputDType(const std::string&, int32_t, DataType) = 0;\n  virtual DataType Dtype4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n  virtual void SetDtype4ArgNameAndIndex(const std::string&, int32_t, DataType) = 0;\n  virtual MemoryFormat InputMemoryFormat(const std::string& arg_name, int32_t index) const = 0;\n  virtual MemoryFormat OutputMemoryFormat(const std::string& arg_name, int32_t index) const = 0;\n  virtual void SetOutputMemoryFormat(const std::string& arg_name, int32_t index,\n                                     MemoryFormat memory_format) = 0;\n  virtual MemoryFormat MemoryFormat4ArgNameAndIndex(const std::string& arg_name,\n                                                    int32_t index) const = 0;\n  virtual void SetMemoryFormat4ArgNameAndIndex(const std::string& arg_name, int32_t 
index,\n                                               MemoryFormat memory_format) = 0;\n\n  virtual const std::vector<std::pair<std::string, int32_t>>& inputs() const = 0;\n  virtual const std::vector<std::pair<std::string, int32_t>>& outputs() const = 0;\n  virtual const std::string& input(const std::string& arg_name, int32_t index) const = 0;\n  virtual const std::string& output(const std::string& arg_name, int32_t index) const = 0;\n  virtual bool has_input(const std::string& arg_name, int32_t index) const = 0;\n  virtual bool has_output(const std::string& arg_name, int32_t index) const = 0;\n  virtual int32_t input_size(const std::string& arg_name) const = 0;\n  virtual int32_t output_size(const std::string& arg_name) const = 0;\n  virtual const std::string& op_name() const = 0;\n  virtual const std::string& op_type_name() const = 0;\n  virtual const std::string& op_loc() const = 0;\n\n  template<typename T>\n  const T& Attr(const std::string& attr_name) const {\n    return AttrValueCast<T>(*Attr4Name(attr_name));\n  }\n\n  virtual const ParallelContext& parallel_ctx() const = 0;\n  virtual const ParallelDesc& parallel_desc() const = 0;\n\n  virtual const JobDesc* job_desc() const {\n    UNIMPLEMENTED();\n    return nullptr;\n  };\n  virtual const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n\n  virtual const NdSbp& NdSbp4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n\n  virtual bool InputIsDynamic(const std::string&, int32_t) const = 0;\n  virtual bool OutputIsDynamic(const std::string&, int32_t) const = 0;\n  virtual void SetOutputIsDynamic(const std::string&, int32_t, bool) = 0;\n  virtual bool IsDynamic4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n  virtual void SetIsDynamic4ArgNameAndIndex(const std::string&, int32_t, bool) = 0;\n\n  virtual int64_t parallel_num() const = 0;\n\n protected:\n  InferContext() = default;\n  InferContext(const InferContext&) = delete;\n  virtual const std::shared_ptr<const AttrVal>& Attr4Name(const std::string& attr_name) const = 0;\n};\n\nclass DeviceAndStreamInferContext {\n public:\n  virtual ~DeviceAndStreamInferContext() = default;\n\n  template<typename T>\n  const T& Attr(const std::string& attr_name) const {\n    return AttrValueCast<T>(*Attr4Name(attr_name));\n  }\n\n  virtual const std::vector<std::pair<std::string, int32_t>>& inputs() const = 0;\n  virtual const std::vector<std::pair<std::string, int32_t>>& outputs() const = 0;\n\n  virtual Symbol<Device>* OutputTensorDevice4ArgNameAndIndex(const std::string&, int64_t) = 0;\n\n  virtual Symbol<Device> InputTensorDevice4ArgNameAndIndex(const std::string&, int64_t) const = 0;\n\n protected:\n  DeviceAndStreamInferContext() = default;\n  virtual const std::shared_ptr<const AttrVal>& Attr4Name(const std::string& attr_name) const = 0;\n};\n\nstruct TensorDescInferFnUtil {\n  static Maybe<void> Unchanged(InferContext*);\n  static Maybe<void> UnchangedDataType(InferContext*);\n  static Maybe<void> InOutCorrespond(InferContext*);\n};\n\nstruct CheckAttrFnUtil {\n  static Maybe<void> NoCheck(const UserOpDefWrapper&, const UserOpConfWrapper&);\n};\n\nstruct TmpSizeInferFnUtil {\n  static size_t ZeroTmpSize(InferContext*);\n};\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_INFER_UTIL_H_\n"
  },
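infer_util.h above defines the InferContext that a user op's tensor-desc inference function receives, with TensorDescInferFnUtil supplying the stock policies. A sketch of a custom inference function written only against methods declared in this header; the op's "in"/"out" argument names and the function itself are illustrative, not part of the original sources:

```cpp
#include "oneflow/core/framework/infer_util.h"

namespace oneflow {
namespace user_op {

// Hypothetical policy: output keeps the input's shape and dtype but is never dynamic.
Maybe<void> MyStaticUnchangedInferFn(InferContext* ctx) {
  const TensorDesc& in = ctx->InputTensorDesc("in", 0);  // ("in", 0) is illustrative
  ctx->SetOutputShape("out", 0, in.shape());
  ctx->SetOutputDType("out", 0, in.data_type());
  ctx->SetOutputIsDynamic("out", 0, false);
  return Maybe<void>::Ok();
}

}  // namespace user_op
}  // namespace oneflow
```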
  {
    "path": "oneflow/core/framework/instructions_builder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <atomic>\n#include <thread>\n#include <chrono>\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/stream_guard.h\"\n#include \"oneflow/core/framework/symbol_storage_util.h\"\n#include \"oneflow/core/device/event_record.h\"\n#include \"oneflow/core/framework/parallel_conf_util.h\"\n#include \"oneflow/core/operator/op_node_signature.pb.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/framework/id_util.h\"\n#include \"oneflow/core/framework/scope_util.h\"\n#include \"oneflow/core/framework/session_util.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/blocking_counter.h\"\n#include \"oneflow/core/common/env_var/vm.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/vm/access_blob_arg_cb_instruction_policy.h\"\n#include \"oneflow/core/vm/ep_record_event_instruction_policy.h\"\n#include \"oneflow/core/vm/op_call_instruction_policy.h\"\n#include \"oneflow/core/vm/barrier_instruction_policy.h\"\n#include \"oneflow/core/vm/critical_section_instruction_policy.h\"\n#include \"oneflow/core/vm/release_tensor_instruction_policy.h\"\n#include \"oneflow/core/vm/lazy_job_instruction_policy.h\"\n#include \"oneflow/core/vm/global_sync_instruction_policy.h\"\n#include \"oneflow/core/vm/op_call_instruction_policy.h\"\n#include \"oneflow/core/vm/stream_wait_instruction_policy.h\"\n#include \"oneflow/core/vm/stream_record_event_instruction_policy.h\"\n#include \"oneflow/core/vm/stream_wait_event_instruction_policy.h\"\n#include \"oneflow/core/vm/sync_access_instruction_policy.h\"\n#include \"oneflow/core/vm/touch_tensors_instruction_policy.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/framework/global_tensor_infer_cache.h\"\n#include \"oneflow/core/eager/local_dep_object.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/framework/stream_need_soft_sync.h\"\n#include \"oneflow/core/framework/stream_is_comm_net_stream.h\"\n#include \"oneflow/core/framework/stream_support_stream_wait.h\"\n#include \"oneflow/core/framework/stream_on_independent_thread.h\"\n#include \"oneflow/core/job/env_desc.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"oneflow/core/platform/include/pthread_fork.h\"\n#include \"oneflow/core/vm/allocate_tensor_instruction_policy.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<Symbol<Stream>> RawGetCriticalSectionStream() {\n  return Stream::New(JUST(Device::New(\"cpu\")), StreamType::kCriticalSection);\n}\n\nstatic constexpr auto* GetCriticalSectionStream =\n    DECORATE(&RawGetCriticalSectionStream, ThreadLocal);\n\nMaybe<Symbol<Stream>> RawGetLazyJobLauncherStream() {\n  return 
Stream::New(JUST(Device::New(\"cpu\")), StreamType::kLazyJobLauncher);\n}\n\nstatic constexpr auto* GetLazyJobLauncherStream =\n    DECORATE(&RawGetLazyJobLauncherStream, ThreadLocal);\n\n}  // namespace\n\n// clang-format off\n// Job e.g.:\n//                                    [wait_and_send_ids]\n//                                             |\n//                                             V\n//                                             |\n//                         +-------------------+\n//                         |                   |\n//                         V             [cpu_decoder]\n//                         |                   |\n//             [critical_section_wait]         V\n//                         |                   |\n//                         V            [forward_ops...]\n//                         |                   |\n//                         |                   V\n//                         +-------------------+\n//                                             |\n//                                        [copy_loss]\n//                                             |\n//                                             +-----------------------+\n//                                             |                       |\n//                                             V                       V\n//                                             |                       |\n//                                     [backward_ops...]               |\n//                                             |                       |\n//                                             V            [critical_section_callback]\n//                                             |                       |\n//                                     [optimizer_ops...]              V\n//                                             |                       |\n//                                             V                       |\n//                                             |                       |\n//                                             +-----------------------+\n//                                             |\n//                                     [callback_notifier]\n//\n//\n// clang-format on\n// critical_section_wait is a blocking opkernel which waits for a tick signal from the instruction\n// CriticalSectionBegin.\n// critical_section_callback is a non-blocking opkernel which notifies the instruction\n// CriticalSectionEnd when it is done.\nMaybe<void> InstructionsBuilder::LaunchLazyJob(const vm::EagerBlobObjectListPtr& inputs,\n                                               const vm::EagerBlobObjectListPtr& outputs,\n                                               const vm::EagerBlobObjectListPtr& parameters,\n                                               const std::shared_ptr<NNGraphIf>& nn_graph) {\n  JUST(SoftSyncNNGraphBuffers(inputs, nn_graph));\n  JUST(SoftSyncNNGraphBuffers(outputs, nn_graph));\n  JUST(SoftSyncNNGraphBuffers(parameters, nn_graph));\n  {\n    // instruction chain: [CriticalSectionBegin] -> [CriticalSectionEnd]\n    // LaunchLazyJob instructions are launched independently of the instruction chains\n    // [CriticalSectionBegin] -> [CriticalSectionEnd]\n    const auto& input_op_name2end_event_record =\n        std::make_shared<HashMap<std::string, std::shared_ptr<SharedEventRecord>>>();\n    {\n      for (const auto& op_name : nn_graph->inputs_op_names()) {\n        const auto& event_record = std::make_shared<SharedEventRecord>();\n        CHECK_OR_RETURN(input_op_name2end_event_record->emplace(op_name, event_record).second)\n            << Error::RuntimeError() << \"Duplicate Op name \" << op_name;\n      }\n\n      auto stream = JUST(GetCriticalSectionStream());\n      auto* vm_stream = JUST(Singleton<VirtualMachine>::Get()->GetVmStream(stream));\n      auto instruction = intrusive::make_shared<vm::Instruction>(\n          vm_stream, std::make_shared<vm::InputCriticalSectionBeginInstructionPolicy>(\n                         nn_graph, inputs, input_op_name2end_event_record, vm_stream));\n      instruction_list_->EmplaceBack(std::move(instruction));\n    }\n    const auto& output_op_name2end_event_record =\n        std::make_shared<HashMap<std::string, std::shared_ptr<SharedEventRecord>>>();\n    {\n      for (const auto& op_name : nn_graph->outputs_op_names()) {\n        const auto& event_record = std::make_shared<SharedEventRecord>();\n        CHECK_OR_RETURN(output_op_name2end_event_record->emplace(op_name, event_record).second)\n            << Error::RuntimeError() << \"Duplicate Op name \" << op_name;\n      }\n      auto stream = JUST(GetCriticalSectionStream());\n      auto* vm_stream = JUST(Singleton<VirtualMachine>::Get()->GetVmStream(stream));\n      auto instruction = intrusive::make_shared<vm::Instruction>(\n          vm_stream, std::make_shared<vm::OutputCriticalSectionBeginInstructionPolicy>(\n                         nn_graph, outputs, output_op_name2end_event_record, vm_stream));\n      instruction_list_->EmplaceBack(std::move(instruction));\n    }\n    {\n      auto stream = JUST(GetLazyJobLauncherStream());\n      auto* vm_stream = JUST(Singleton<VirtualMachine>::Get()->GetVmStream(stream));\n      auto instruction = intrusive::make_shared<vm::Instruction>(\n          vm_stream, 
std::make_shared<vm::LaunchLazyJobInstructionPolicy>(nn_graph, parameters));\n      instruction_list_->EmplaceBack(std::move(instruction));\n    }\n    auto stream = JUST(GetCriticalSectionStream());\n    auto* vm_stream = JUST(Singleton<VirtualMachine>::Get()->GetVmStream(stream));\n    for (int i = 0; i < nn_graph->inputs_op_names().size(); ++i) {\n      const auto& eager_blob_object = inputs->at(i);\n      const auto& op_name = nn_graph->inputs_op_names().at(i);\n      const auto& event_record = JUST(MapAt(*input_op_name2end_event_record, op_name));\n      auto instruction = intrusive::make_shared<vm::Instruction>(\n          vm_stream, std::make_shared<vm::InputCriticalSectionEndInstructionPolicy>(\n                         eager_blob_object, event_record, vm_stream));\n      instruction_list_->EmplaceBack(std::move(instruction));\n    }\n    for (int i = 0; i < nn_graph->outputs_op_names().size(); ++i) {\n      const auto& eager_blob_object = outputs->at(i);\n      const auto& op_name = nn_graph->outputs_op_names().at(i);\n      const auto& event_record = JUST(MapAt(*output_op_name2end_event_record, op_name));\n      auto instruction = intrusive::make_shared<vm::Instruction>(\n          vm_stream, std::make_shared<vm::OutputCriticalSectionEndInstructionPolicy>(\n                         eager_blob_object, event_record, vm_stream));\n      instruction_list_->EmplaceBack(std::move(instruction));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InstructionsBuilder::SoftSyncNNGraphBuffers(\n    const vm::EagerBlobObjectListPtr& eager_blob_objects,\n    const std::shared_ptr<NNGraphIf>& nn_graph) {\n  const auto& stream = JUST(GetCriticalSectionStream());\n  JUST(SoftSyncStream(*eager_blob_objects, stream));\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nint64_t NewSymbolId() {\n  static std::atomic<int64_t> cnt(0);\n  return cnt.fetch_add(1, std::memory_order_relaxed);\n}\n\n}  // namespace\n\nMaybe<JobDesc> InstructionsBuilder::GetJobConfSymbol(const JobConfigProto& job_conf) {\n  return Singleton<symbol::Storage<JobDesc>>::Get()->FindOrCreate(job_conf, &NewSymbolId);\n}\n\nMaybe<ParallelDesc> InstructionsBuilder::GetParallelDescSymbol(const ParallelConf& parallel_conf) {\n  return Singleton<symbol::Storage<ParallelDesc>>::Get()->FindOrCreate(parallel_conf, &NewSymbolId);\n}\n\nMaybe<Scope> InstructionsBuilder::GetScopeSymbol(const ScopeProto& scope_proto) {\n  return Singleton<symbol::Storage<Scope>>::Get()->FindOrCreate(scope_proto, &NewSymbolId);\n}\n\nMaybe<OperatorConfSymbol> InstructionsBuilder::GetOpConfSymbol(const OperatorConf& op_conf) {\n  return Singleton<symbol::Storage<OperatorConfSymbol>>::Get()->FindOrCreate(op_conf, &NewSymbolId);\n}\n\nMaybe<Scope> InstructionsBuilder::BuildInitialScope(\n    int64_t session_id, const JobConfigProto& job_conf, const std::string& device_tag,\n    const std::vector<std::string>& machine_device_ids, const std::shared_ptr<Shape>& hierarchy,\n    bool is_local) {\n  ScopeProto scope_proto;\n  scope_proto.set_session_id(session_id);\n  std::shared_ptr<JobDesc> job_conf_sym = JUST(GetJobConfSymbol(job_conf));\n  scope_proto.set_job_desc_symbol_id(JUST(job_conf_sym->symbol_id()));\n  std::shared_ptr<ParallelConf> parallel_conf =\n      JUST(MakeParallelConf(device_tag, machine_device_ids, hierarchy));\n  std::shared_ptr<ParallelDesc> device_parallel_desc_sym =\n      JUST(GetParallelDescSymbol(*parallel_conf));\n  scope_proto.set_device_parallel_desc_symbol_id(JUST(device_parallel_desc_sym->symbol_id()));\n  parallel_conf = 
JUST(MakeParallelConf(\"cpu\", machine_device_ids, hierarchy));\n  std::shared_ptr<ParallelDesc> host_parallel_desc_sym =\n      JUST(GetParallelDescSymbol(*parallel_conf));\n  scope_proto.set_host_parallel_desc_symbol_id(JUST(host_parallel_desc_sym->symbol_id()));\n  if (is_local) {\n    scope_proto.mutable_opt_local_parallel_conf()->mutable_local_parallel();\n  } else {\n    scope_proto.mutable_opt_local_parallel_conf()->clear_local_parallel();\n  }\n  return GetScopeSymbol(scope_proto);\n}\n\nMaybe<Scope> InstructionsBuilder::BuildInitialScopeWithPlacement(int64_t session_id,\n                                                                 const JobConfigProto& job_conf,\n                                                                 Symbol<ParallelDesc> placement,\n                                                                 bool is_local) {\n  ScopeProto scope_proto;\n  scope_proto.set_session_id(session_id);\n  std::shared_ptr<JobDesc> job_conf_sym = JUST(GetJobConfSymbol(job_conf));\n  scope_proto.set_job_desc_symbol_id(JUST(job_conf_sym->symbol_id()));\n\n  std::shared_ptr<ParallelDesc> device_parallel_desc_sym =\n      JUST(GetParallelDescSymbol(placement->parallel_conf()));\n  scope_proto.set_device_parallel_desc_symbol_id(JUST(device_parallel_desc_sym->symbol_id()));\n\n  Symbol<ParallelDesc> new_placement = JUST(ReplaceDeviceType(placement, DeviceType::kCPU));\n  std::shared_ptr<ParallelDesc> host_parallel_desc_sym =\n      JUST(GetParallelDescSymbol(new_placement->parallel_conf()));\n  scope_proto.set_host_parallel_desc_symbol_id(JUST(host_parallel_desc_sym->symbol_id()));\n  if (is_local) {\n    scope_proto.mutable_opt_local_parallel_conf()->mutable_local_parallel();\n  } else {\n    scope_proto.mutable_opt_local_parallel_conf()->clear_local_parallel();\n  }\n  return GetScopeSymbol(scope_proto);\n}\n\nMaybe<Scope> InstructionsBuilder::BuildScopeWithNewParallelDesc(\n    const std::shared_ptr<Scope>& scope, const std::string& device_tag,\n    const std::vector<std::string>& machine_device_ids, const std::shared_ptr<Shape>& hierarchy) {\n  const auto SetScopeProto = [this, &device_tag, &machine_device_ids,\n                              &hierarchy](const std::shared_ptr<ScopeProto>& scope_proto) {\n    std::shared_ptr<ParallelConf> parallel_conf =\n        CHECK_JUST(MakeParallelConf(device_tag, machine_device_ids, hierarchy));\n    std::shared_ptr<ParallelDesc> device_parallel_desc_sym =\n        CHECK_JUST(GetParallelDescSymbol(*parallel_conf));\n    parallel_conf = CHECK_JUST(MakeParallelConf(\"cpu\", machine_device_ids, hierarchy));\n    std::shared_ptr<ParallelDesc> host_parallel_desc_sym =\n        CHECK_JUST(GetParallelDescSymbol(*parallel_conf));\n    scope_proto->set_device_parallel_desc_symbol_id(\n        CHECK_JUST(device_parallel_desc_sym->symbol_id()));\n    scope_proto->set_host_parallel_desc_symbol_id(CHECK_JUST(host_parallel_desc_sym->symbol_id()));\n  };\n\n  return BuildScopeByProtoSetter(scope, SetScopeProto);\n}\n\nMaybe<Scope> InstructionsBuilder::BuildScopeWithNewParallelConf(const std::shared_ptr<Scope>& scope,\n                                                                const ParallelConf& parallel_conf) {\n  const std::shared_ptr<std::tuple<std::string, std::vector<std::string>,\n                                   std::shared_ptr<ShapeProto>>>& tag_and_dev_ids_and_hierarchy =\n      JUST(GetDeviceTagAndMachineDeviceIdsAndHierarchy(parallel_conf));\n  std::shared_ptr<Shape> hierarchy;\n  if (std::get<2>(*tag_and_dev_ids_and_hierarchy)) {\n   
 hierarchy.reset(new Shape(parallel_conf.hierarchy()));\n  }\n  return BuildScopeWithNewParallelDesc(scope, std::get<0>(*tag_and_dev_ids_and_hierarchy),\n                                       std::get<1>(*tag_and_dev_ids_and_hierarchy), hierarchy);\n}\n\nMaybe<Scope> InstructionsBuilder::BuildScopeWithNewIsLocal(const std::shared_ptr<Scope>& scope,\n                                                           bool is_local) {\n  const auto SetScopeProto = [is_local](const std::shared_ptr<ScopeProto>& scope_proto) {\n    if (is_local) {\n      scope_proto->mutable_opt_local_parallel_conf()->mutable_local_parallel();\n    } else {\n      scope_proto->mutable_opt_local_parallel_conf()->clear_local_parallel();\n    }\n  };\n\n  return BuildScopeByProtoSetter(scope, SetScopeProto);\n}\n\nMaybe<Scope> InstructionsBuilder::BuildScopeWithNewScopeName(const std::shared_ptr<Scope>& scope,\n                                                             const std::string& scope_name) {\n  const auto SetScopeProto = [&scope_name](const std::shared_ptr<ScopeProto>& scope_proto) {\n    scope_proto->add_scope_op_name_prefixes(scope_name);\n  };\n\n  return BuildScopeByProtoSetter(scope, SetScopeProto);\n}\n\nMaybe<Scope> InstructionsBuilder::BuildScopeByProtoSetter(\n    const std::shared_ptr<Scope>& scope,\n    const std::function<void(const std::shared_ptr<ScopeProto>&)>& Setter) {\n  std::shared_ptr<ScopeProto> scope_proto = JUST(scope->MakeChildScopeProto());\n  Setter(scope_proto);\n  return GetScopeSymbol(*scope_proto);\n}\n\nMaybe<Scope> InstructionsBuilder::BuildScopeByProtoStrSetter(\n    const std::shared_ptr<Scope>& scope,\n    const std::function<std::string(const std::string&)>& StrSetter) {\n  std::shared_ptr<ScopeProto> scope_proto = JUST(scope->MakeChildScopeProto());\n  std::string serialized_scope_proto = PbMessage2TxtString(*scope_proto);\n  std::string new_serialized_scope_proto = StrSetter(serialized_scope_proto);\n  CHECK_OR_RETURN(TxtString2PbMessage(new_serialized_scope_proto, scope_proto.get()))\n      << Error::RuntimeError() << \"scope_proto parse failed\";\n  return GetScopeSymbol(*scope_proto);\n}\n\nMaybe<void> InstructionsBuilder::Call(const std::shared_ptr<one::StatefulOpKernel>& opkernel,\n                                      vm::EagerBlobObjectList&& input_eager_blob_objects,\n                                      vm::EagerBlobObjectList&& output_eager_blob_objects,\n                                      const one::OpExprInterpContext& ctx, Symbol<Stream> stream) {\n  return Call(opkernel, std::move(input_eager_blob_objects), std::move(output_eager_blob_objects),\n              nullptr, ctx, stream);\n}\n\nMaybe<void> InstructionsBuilder::AllocateTensors(const vm::EagerBlobObjectList& eager_blob_objects,\n                                                 Symbol<Stream> stream) {\n  // Try to soft-sync eager blob objects which already have memory allocated.\n  JUST(SoftSyncStream(eager_blob_objects, stream));\n  auto* vm_stream = JUST(Singleton<VirtualMachine>::Get()->GetVmStream(stream));\n  const auto& instruction_policy =\n      std::make_shared<vm::AllocateTensorInstructionPolicy>(eager_blob_objects, vm_stream);\n  auto instruction = intrusive::make_shared<vm::Instruction>(vm_stream, instruction_policy);\n  instruction_list_->EmplaceBack(std::move(instruction));\n  for (const auto& eager_blob_object : eager_blob_objects) {\n    if (!eager_blob_object->producer_stream().has_value()) {\n      JUST(eager_blob_object->init_producer_stream(stream));\n    }\n    
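// NOTE: producer_stream is initialized only once (by the first stream that allocates this\n    // tensor), while last_used_stream is refreshed on every access so that later instructions\n    // can soft-sync against the most recent user.\n    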
eager_blob_object->set_last_used_stream(stream);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InstructionsBuilder::Call(\n    const std::shared_ptr<one::StatefulOpKernel>& opkernel,\n    vm::EagerBlobObjectList&& input_eager_blob_objects,\n    vm::EagerBlobObjectList&& output_eager_blob_objects,\n    const std::shared_ptr<const one::GlobalTensorInferResult>& global_tensor_infer_result,\n    const one::OpExprInterpContext& ctx, Symbol<Stream> stream) {\n  stream = JUST(StreamGuard::TryConvertStream(stream));\n  Symbol<Stream> allocator_stream = JUST(GetAllocatorStream(stream));\n  if (stream != allocator_stream) {\n    JUST(AllocateTensors(output_eager_blob_objects, allocator_stream));\n  }\n  JUST(SoftSyncStream(output_eager_blob_objects, stream));\n  JUST(SoftSyncStream(input_eager_blob_objects, stream));\n  for (const auto& output : output_eager_blob_objects) {\n    if (!output->producer_stream().has_value()) { JUST(output->init_producer_stream(stream)); }\n    output->set_last_used_stream(stream);\n  }\n  auto* vm_stream = JUST(Singleton<VirtualMachine>::Get()->GetVmStream(stream));\n  auto instruction = intrusive::make_shared<vm::Instruction>(\n      vm_stream, JUST(vm::OpCallInstructionPolicy::New(\n                     vm_stream, opkernel, std::move(input_eager_blob_objects),\n                     std::move(output_eager_blob_objects), global_tensor_infer_result, ctx,\n                     *one::CurrentDevVmDepObjectConsumeMode())));\n  instruction_list_->EmplaceBack(std::move(instruction));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InstructionsBuilder::ReleaseTensor(\n    const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n  const auto& last_used_stream = JUST(eager_blob_object->last_used_stream());\n  const auto& producer_stream = JUST(eager_blob_object->producer_stream());\n  if (pthread_fork::IsForkedSubProcess()\n      && producer_stream->device()->enum_type() != DeviceType::kCPU) {\n    return Maybe<void>::Ok();\n  }\n  Optional<Symbol<Stream>> stream{};\n  if (*one::CurrentDevVmDepObjectConsumeMode() == one::DevVmDepObjectConsumeMode::NONE) {\n    stream = Optional<Symbol<Stream>>(NullOpt);\n  } else if (IsCommNetStream::Visit(last_used_stream->stream_type())) {\n    // Disable inter-device instruction sequentialization for tensors used by a communication\n    // stream. It is not acceptable for the cuda compute stream to be blocked by the cuda nccl\n    // stream.\n    stream = Optional<Symbol<Stream>>(NullOpt);\n  } else if (IsCommNetStream::Visit(producer_stream->stream_type())) {\n    // Disable inter-device instruction sequentialization for tensors produced by a communication\n    // stream.\n    stream = Optional<Symbol<Stream>>(NullOpt);\n  } else {\n    stream = producer_stream;\n  }\n  struct EnableStreamWaitOnReleaseTensor final\n      : public StreamTypeVisitor<EnableStreamWaitOnReleaseTensor> {\n    static bool VisitCompute() { return true; }\n    static bool VisitHost2Device() { return true; }\n    static bool VisitDevice2Host() { return true; }\n    static bool VisitCcl() { return false; }\n    static bool VisitBarrier() { return false; }\n    static bool VisitCriticalSection() { return false; }\n    static bool VisitLazyJobLauncher() { return false; }\n    static bool VisitPinnedCompute() { return VisitCompute(); }\n  };\n  const auto& EnableStreamWait = [&] {\n    if (last_used_stream->device() != producer_stream->device()) { return false; }\n    if (last_used_stream->stream_type() == producer_stream->stream_type()) { return true; }\n    return 
EnableStreamWaitOnReleaseTensor::Visit(last_used_stream->stream_type())\n           && EnableStreamWaitOnReleaseTensor::Visit(producer_stream->stream_type());\n  };\n  if (last_used_stream != producer_stream) {\n    if (stream.has_value() && EnableStreamWait()) {\n      JUST(SoftSyncStreamBetween({JUST(eager_blob_object->compute_local_dep_object())},\n                                 last_used_stream, JUST(stream)));\n    } else {\n      JUST(RecordEvent({JUST(eager_blob_object->compute_local_dep_object())}, last_used_stream));\n    }\n    eager_blob_object->set_last_used_stream(producer_stream);\n  }\n  auto vm_stream = stream.map([](Symbol<Stream> stream) -> vm::Stream* {\n    return CHECK_JUST(Singleton<VirtualMachine>::Get()->GetVmStream(stream));\n  });\n  StreamType stream_type = producer_stream->stream_type();\n  auto instruction = intrusive::make_shared<vm::Instruction>(\n      JUST(Singleton<VirtualMachine>::Get()->GetVmStream(producer_stream)),\n      JUST(vm::MakeReleaseTensorInstructionPolicy::Visit(stream_type, eager_blob_object,\n                                                         vm_stream)));\n  instruction_list_->EmplaceBack(std::move(instruction));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InstructionsBuilder::TouchTensors(\n    const vm::EagerBlobObjectListPtr& eager_blob_objects) {\n  Symbol<Device> device = JUST(Device::New(\"cpu\"));\n  Symbol<Stream> stream = JUST(GetDefaultStreamByDevice(device));\n  return TouchTensors(eager_blob_objects, stream);\n}\n\nMaybe<void> InstructionsBuilder::TouchTensors(const vm::EagerBlobObjectListPtr& eager_blob_objects,\n                                              Symbol<Stream> stream) {\n  JUST(SoftSyncStream(*eager_blob_objects, stream));\n  auto instruction = intrusive::make_shared<vm::Instruction>(\n      JUST(Singleton<VirtualMachine>::Get()->GetVmStream(stream)),\n      std::make_unique<vm::TouchTensorsInstructionPolicy>(*eager_blob_objects));\n  instruction_list_->EmplaceBack(std::move(instruction));\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\ntemplate<typename T>\nusing SmallSet = small_vector<T>;\n\ntemplate<typename T>\nstd::pair<typename SmallSet<T>::iterator, bool> SmallSetInsert(SmallSet<T>* vec, const T& elem) {\n  for (auto iter = vec->begin(); iter != vec->end(); ++iter) {\n    if (*iter == elem) { return std::make_pair(iter, false); }\n  }\n  vec->push_back(elem);\n  return std::make_pair(vec->end() - 1, true);\n}\n\ntemplate<typename DoEachT>\nMaybe<void> ForEachEagerBlobObjectsNeedingSoftSync(\n    const vm::EagerBlobObjectList& eager_blob_objects, Symbol<Stream> stream,\n    const DoEachT& DoEach) {\n  if (eager_blob_objects.size() <= kOpArgsReservedSize) {\n    for (const auto& eager_blob_object : eager_blob_objects) {\n      const auto& opt_last_used_stream = eager_blob_object->last_used_stream();\n      if (unlikely(!opt_last_used_stream.has_value())) { continue; }\n      const auto& last_used_stream = JUST(opt_last_used_stream);\n      if (last_used_stream != stream) {\n        small_vector<intrusive::shared_ptr<LocalDepObject>> dep_objects{\n            intrusive::shared_ptr<LocalDepObject>(\n                JUST(eager_blob_object->compute_local_dep_object()))};\n        JUST(DoEach(last_used_stream, std::move(dep_objects)));\n      }\n    }\n  } else {\n    SmallSet<Symbol<Stream>> last_used_streams;\n    for (const auto& eager_blob_object : eager_blob_objects) {\n      const auto& opt_last_used_stream = eager_blob_object->last_used_stream();\n      if 
(unlikely(!opt_last_used_stream.has_value())) { continue; }\n      const auto& last_used_stream = JUST(opt_last_used_stream);\n      if (last_used_stream != stream) { SmallSetInsert(&last_used_streams, last_used_stream); }\n    }\n    for (const auto& last_used_stream : last_used_streams) {\n      small_vector<intrusive::shared_ptr<LocalDepObject>> dep_objects{};\n      for (const auto& eager_blob_object : eager_blob_objects) {\n        const auto& opt_stream = eager_blob_object->last_used_stream();\n        if (unlikely(!opt_stream.has_value())) { continue; }\n        if (JUST(opt_stream) == last_used_stream) {\n          dep_objects.emplace_back(JUST(eager_blob_object->compute_local_dep_object()));\n        }\n      }\n      JUST(DoEach(last_used_stream, std::move(dep_objects)));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> InstructionsBuilder::SoftSyncStream(const vm::EagerBlobObjectList& eager_blob_objects,\n                                                Symbol<Stream> stream) {\n  JUST(ForEachEagerBlobObjectsNeedingSoftSync(\n      eager_blob_objects, stream,\n      [&](Symbol<Stream> last_used_stream, auto&& dep_objects) -> Maybe<void> {\n        return SoftSyncStreamBetween(std::move(dep_objects), last_used_stream, stream);\n      }));\n  for (const auto& eager_blob_object : eager_blob_objects) {\n    eager_blob_object->set_last_used_stream(stream);\n  }\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nbool SupportingStreamWait(Symbol<Stream> from_stream, Symbol<Stream> to_stream) {\n  if (from_stream->device() == to_stream->device()\n      && from_stream->stream_type() == to_stream->stream_type()\n      && from_stream->thread_uid() == to_stream->thread_uid()) {\n    CHECK(from_stream == to_stream);\n  }\n  if (unlikely(!ThreadLocalEnvBool<ONEFLOW_VM_ENABLE_STREAM_WAIT>())) { return false; }\n  DeviceType from_device_type = from_stream->device()->enum_type();\n  DeviceType to_device_type = to_stream->device()->enum_type();\n  return from_stream->device() == to_stream->device() && from_stream->support_wait_event()\n         && to_stream->support_wait_event()\n         && StreamSupportStreamWait::Visit(from_stream->stream_type(), from_device_type)\n         && StreamSupportStreamWait::Visit(to_stream->stream_type(), to_device_type)\n         && !StreamOnIndependentThread::Visit(from_stream->stream_type())\n         && !StreamOnIndependentThread::Visit(to_stream->stream_type());\n}\n\n}  // namespace\n\nMaybe<void> InstructionsBuilder::SoftSyncStreamBetween(\n    small_vector<intrusive::shared_ptr<LocalDepObject>>&& dependences, Symbol<Stream> from_stream,\n    Symbol<Stream> to_stream) {\n  CHECK(from_stream != to_stream) << \"synchronization is unnecessary\";\n  if (SupportingStreamWait(from_stream, to_stream)) {\n    JUST(StreamWait(std::move(dependences), from_stream, to_stream));\n  } else {\n    JUST(RecordEvent(std::move(dependences), from_stream));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InstructionsBuilder::StreamWait(\n    small_vector<intrusive::shared_ptr<LocalDepObject>>&& dependences, Symbol<Stream> from_stream,\n    Symbol<Stream> to_stream) {\n  auto* from_vm_stream = JUST(Singleton<VirtualMachine>::Get()->GetVmStream(from_stream));\n  auto* to_vm_stream = JUST(Singleton<VirtualMachine>::Get()->GetVmStream(to_stream));\n  if (from_vm_stream->mut_thread_ctx() != to_vm_stream->mut_thread_ctx()) {\n    
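// The two vm streams are processed by different thread contexts, so the fused StreamWait\n    // instruction used in the else-branch below cannot be scheduled across both; emit a\n    // record-event instruction on the from-stream and a matching wait-event instruction on the\n    // to-stream instead.\n    auto stream_record_event =\n        std::make_shared<vm::StreamRecordEventInstructionPolicy>(dependences);\n    auto 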
record_instruction =\n        intrusive::make_shared<vm::Instruction>(from_vm_stream, stream_record_event);\n    instruction_list_->EmplaceBack(std::move(record_instruction));\n    auto stream_wait_event =\n        std::make_shared<vm::StreamWaitEventInstructionPolicy>(dependences, stream_record_event);\n    auto wait_instruction =\n        intrusive::make_shared<vm::Instruction>(to_vm_stream, stream_wait_event);\n    instruction_list_->EmplaceBack(std::move(wait_instruction));\n  } else {\n    auto instruction = intrusive::make_shared<vm::Instruction>(\n        to_vm_stream, std::make_unique<vm::StreamWaitInstructionPolicy>(\n                          std::move(dependences), from_vm_stream, to_vm_stream));\n    instruction_list_->EmplaceBack(std::move(instruction));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InstructionsBuilder::RecordEvent(\n    small_vector<intrusive::shared_ptr<LocalDepObject>>&& compute_local_dep_objects,\n    Symbol<Stream> last_used_stream) {\n  DeviceType device_type = last_used_stream->device()->enum_type();\n  if (!NeedSoftSync::Visit(last_used_stream->stream_type(), device_type)) {\n    return Maybe<void>::Ok();\n  }\n  std::string modifier = \"mut\";\n  StreamType stream_type = last_used_stream->stream_type();\n  auto instruction = intrusive::make_shared<vm::Instruction>(\n      JUST(Singleton<VirtualMachine>::Get()->GetVmStream(last_used_stream)),\n      JUST(GetRecordEventInstructionPolicy::Visit(stream_type, device_type,\n                                                  std::move(compute_local_dep_objects), modifier)));\n  instruction_list_->EmplaceBack(std::move(instruction));\n  return Maybe<void>::Ok();\n}\n\ntemplate<typename T>\nMaybe<void> InstructionsBuilder::SyncAccessBlobByCallback(\n    const T tensor, const std::shared_ptr<BlockingThenBusy>& btb,\n    const std::function<void(ep::Stream*, const std::shared_ptr<vm::EagerBlobObject>&)>& Callback,\n    const std::string& modifier) {\n  // We want to balance the cpu overhead and the notification latency.\n  //\n  // balanced timeline here:\n  //\n  //   B: blocking wait\n  //   W: wake up\n  //   S: spin wait\n  //\n  //   vm thread:    |<--------------- prev ops ------------------>|<- Callback() ->|\n  //\n  //   main thread:  |<-------------------- B -------------------->|<- W ->|<- S  ->|\n  //\n  // bad timeline with more notification latency:\n  //\n  //   B: blocking wait\n  //   W: wake up\n  //   S: spin wait\n  //\n  //   vm thread:    |<--------------- prev ops ------------------>|<- Callback() ->|\n  //\n  //   main thread:  |<---------------------------- B ----------------------------->|<- W ->|\n  //\n  // bad timeline with more cpu overhead:\n  //\n  //   B: blocking wait\n  //   W: wake up\n  //   S: spin wait\n  //\n  //   vm thread:    |<--------------- prev ops ------------------>|<- Callback() ->|\n  //                 |                                             |                |\n  //   main thread:  |<---------------------------- S ----------------------------->|\n\n  
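// CallbackWrapper realizes the balanced timeline above: Notify() ends the main thread's\n  // blocking phase as soon as this instruction starts executing, and Decrease() releases the\n  // short spin phase once Callback has finished.\n  const auto& CallbackWrapper = [btb, Callback](\n                                    ep::Stream* stream,\n                                    const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n    btb->mut_notifier()->Notify();\n    Callback(stream, eager_blob_object);\n    btb->mut_spin_counter()->Decrease();\n  };\n  return AccessBlobByCallback(tensor, CallbackWrapper, modifier);\n}\n\ntemplate Maybe<void> InstructionsBuilder::SyncAccessBlobByCallback(\n    const 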
std::shared_ptr<one::LocalTensor> tensor, const std::shared_ptr<BlockingThenBusy>& btb,\n    const std::function<void(ep::Stream*, const std::shared_ptr<vm::EagerBlobObject>&)>& Callback,\n    const std::string& modifier);\n\ntemplate Maybe<void> InstructionsBuilder::SyncAccessBlobByCallback(\n    const one::EagerLocalTensorImpl* tensor, const std::shared_ptr<BlockingThenBusy>& btb,\n    const std::function<void(ep::Stream*, const std::shared_ptr<vm::EagerBlobObject>&)>& Callback,\n    const std::string& modifier);\n\nnamespace {\n\nMaybe<Symbol<Device>> GetDevice(const std::shared_ptr<one::LocalTensor>& tensor) {\n  return tensor->device();  // return Maybe<Symbol<Device>>\n}\n\nMaybe<Symbol<Device>> GetDevice(const one::EagerLocalTensorImpl* tensor) {\n  return tensor->device();  // return const Symbol<Device>&\n}\n\ntemplate<typename T>\nMaybe<Symbol<Stream>> GetAccessStream(const T tensor) {\n  Symbol<Device> device = JUST(GetDevice(tensor));\n  // Do not use producer_stream or last_used_stream.\n  // Bug case when using producer_stream or last_used_stream:\n  //\n  // ```python\n  // tensor = oneflow.ones((1024, 1024, 1024), device='cuda').cpu()\n  // ndarray = tensor.numpy() # share memory\n  //\n  // ```\n  // `ndarray` may not be ones because instruction AccessBlobByCallback is prescheduled before\n  // oneflow.ones has actually finished.\n  Symbol<Stream> stream = JUST(GetDefaultStreamByDevice(device));\n  return StreamGuard::TryConvertStream(stream);\n}\n\n}  // namespace\n\ntemplate<typename T>\nMaybe<void> InstructionsBuilder::AccessBlobByCallback(\n    const T tensor,\n    const std::function<void(ep::Stream*, const std::shared_ptr<vm::EagerBlobObject>&)>& callback,\n    const std::string& modifier) {\n  const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object = JUST(tensor->eager_blob_object());\n  Symbol<Stream> stream = JUST(GetAccessStream(tensor));\n  JUST(SoftSyncStream({eager_blob_object}, stream));\n  auto instruction = intrusive::make_shared<vm::Instruction>(\n      // Never replace `stream` with producer_stream or last_used_stream.\n      JUST(Singleton<VirtualMachine>::Get()->GetVmStream(stream)),\n      std::make_shared<vm::AccessBlobArgCbInstructionPolicy>(eager_blob_object, callback,\n                                                             modifier));\n  instruction_list_->EmplaceBack(std::move(instruction));\n  return Maybe<void>::Ok();\n}\n\ntemplate Maybe<void> InstructionsBuilder::AccessBlobByCallback(\n    const std::shared_ptr<one::LocalTensor> tensor,\n    const std::function<void(ep::Stream*, const std::shared_ptr<vm::EagerBlobObject>&)>& callback,\n    const std::string& modifier);\n\ntemplate Maybe<void> InstructionsBuilder::AccessBlobByCallback(\n    const one::EagerLocalTensorImpl* tensor,\n    const std::function<void(ep::Stream*, const std::shared_ptr<vm::EagerBlobObject>&)>& callback,\n    const std::string& modifier);\n\nnamespace {\n\nMaybe<Symbol<Stream>> GetBarrierStream() {\n  auto device = JUST(Device::New(\"cpu\"));\n  return Stream::New(device, StreamType::kBarrier);\n}\n\n}  // namespace\n\nMaybe<void> InstructionsBuilder::GlobalSync() {\n  auto stream = JUST(GetBarrierStream());\n  auto instruction = intrusive::make_shared<vm::Instruction>(\n      JUST(Singleton<VirtualMachine>::Get()->GetVmStream(stream)),\n      std::make_shared<vm::GlobalSyncInstructionPolicy>());\n  instruction_list_->PushBack(instruction.Mutable());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InstructionsBuilder::Barrier(const std::function<void()>& 
Callback) {\n  auto stream = JUST(GetBarrierStream());\n  auto instruction = intrusive::make_shared<vm::Instruction>(\n      JUST(Singleton<VirtualMachine>::Get()->GetVmStream(stream)),\n      std::make_shared<vm::BarrierInstructionPolicy>(Callback));\n  instruction_list_->PushBack(instruction.Mutable());\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\ntemplate<typename InstructionPolicyT>\nMaybe<vm::Instruction*> MutThreadLocalInstruction(Symbol<Stream> stream) {\n  static thread_local std::vector<intrusive::shared_ptr<vm::Instruction>> vec;\n  if (unlikely(stream->unique_stream_id() >= vec.size())) {\n    vec.resize(stream->unique_stream_id() + 1);\n  }\n  auto* instruction_ptr = &vec[stream->unique_stream_id()];\n  if (static_cast<bool>(*instruction_ptr) && (*instruction_ptr)->ref_cnt() != 1) {\n    // This instruction must not be reused because it is still held by other threads.\n    instruction_ptr->Reset();\n  }\n  if (unlikely(!static_cast<bool>(*instruction_ptr))) {\n    *instruction_ptr = intrusive::make_shared<vm::Instruction>(\n        JUST(Singleton<VirtualMachine>::Get()->GetVmStream(stream)),\n        std::make_shared<InstructionPolicyT>());\n  }\n  return instruction_ptr->Mutable();\n}\n\n}  // namespace\n\ntemplate<typename T, typename InstructionPolicyT>\nMaybe<void> SyncAccessSmallMem(char* mem_ptr, size_t bytes, const T tensor) {\n  static thread_local vm::InstructionList instruction_list;\n  static thread_local InstructionsBuilder instructions_builder(&instruction_list);\n  const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object = JUST(tensor->eager_blob_object());\n  const Symbol<Stream> stream = JUST(GetAccessStream(tensor));\n  if (eager_blob_object->last_used_stream().has_value()\n      && stream != JUST(eager_blob_object->last_used_stream())) {\n    // Synchronize stream.\n    JUST(instructions_builder.SoftSyncStream({eager_blob_object}, stream));\n  }\n  InstructionPolicyT* instruction_policy = nullptr;\n  {\n    // Construct instruction.\n    auto* instruction = JUST(MutThreadLocalInstruction<InstructionPolicyT>(stream));\n    instruction_policy =\n        static_cast<InstructionPolicyT*>(instruction->mut_instruction_policy());  // NOLINT\n    instruction_policy->Reset(mem_ptr, bytes, eager_blob_object.get());\n    instruction_list.PushBack(instruction);\n  }\n  // Dispatch instructions.\n  JUST(vm::Run(&instruction_list));\n  {\n    // This thread should do a blocking wait if and only if there is a lot of workload on the\n    // worker thread. When the workload is small, we get better performance by skipping\n    // cond_.notify_xxx, which costs about 2us to 3us.\n    auto* virtual_machine = JUST(SingletonMaybe<VirtualMachine>());\n    static constexpr int kSkipBlockingThreshold = 2;\n    if (virtual_machine->flying_instruction_cnt() < kSkipBlockingThreshold) {\n      // skip pthread_cond_broadcast on worker thread.\n      instruction_policy->mut_btb()->mut_notifier()->Notify();\n    }\n  }\n  // wait until done.\n  JUST(instruction_policy->mut_btb()->WaitUntilCntEqualZero(\n      VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));\n  return Maybe<void>::Ok();\n}\n\n
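// Public entry declared in instructions_builder.h: synchronously read a small region of tensor\n// memory through the fast path above.\ntemplate<typename T>\nMaybe<void> SyncReadSmallMem(char* mem_ptr, size_t bytes, const T tensor) {\n  return SyncAccessSmallMem<T, vm::SyncReadInstructionPolicy>(mem_ptr, bytes, tensor);\n}\n\ntemplate Maybe<void> SyncReadSmallMem(char* mem_ptr, size_t bytes,\n                                      const std::shared_ptr<one::LocalTensor> tensor);\n\ntemplate Maybe<void> SyncReadSmallMem(char* 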
mem_ptr, size_t bytes,\n                                      const one::EagerLocalTensorImpl* tensor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/instructions_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_INSTRUCTIONS_BUILDER_H_\n#define ONEFLOW_CORE_FRAMEWORK_INSTRUCTIONS_BUILDER_H_\n\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/eager/local_dep_object.h\"\n#include \"oneflow/core/framework/op_interpreter.h\"\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/job/scope.pb.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/blocking_then_busy.h\"\n#include \"oneflow/core/operator/op_conf_symbol.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n\nnamespace oneflow {\n\nnamespace one {\nclass StatefulOpKernel;\nclass TensorTuple;\nclass LocalTensor;\nclass GlobalTensorInferResult;\n}  // namespace one\n\nclass NNGraphIf;\n\nclass SharedEventRecord;\n\nclass InstructionsBuilder : public std::enable_shared_from_this<InstructionsBuilder> {\n public:\n  InstructionsBuilder(const InstructionsBuilder&) = delete;\n  InstructionsBuilder(InstructionsBuilder&&) = delete;\n  explicit InstructionsBuilder(vm::InstructionList* instruction_list)\n      : instruction_list_(instruction_list) {}\n  ~InstructionsBuilder() { instruction_list_->Clear(); }\n\n  const vm::InstructionList& instruction_list() const { return *instruction_list_; }\n\n  vm::InstructionList* mut_instruction_list() { return instruction_list_; }\n\n  // Build VM execution instructions with NNGraph's inputs/outputs/parameters for NNGraph execution.\n  Maybe<void> LaunchLazyJob(const vm::EagerBlobObjectListPtr& inputs,\n                            const vm::EagerBlobObjectListPtr& outputs,\n                            const vm::EagerBlobObjectListPtr& parameters,\n                            const std::shared_ptr<NNGraphIf>& nn_graph);\n\n  // soft sync for inputs/outputs buffers of NNGraph\n  Maybe<void> SoftSyncNNGraphBuffers(const vm::EagerBlobObjectListPtr& eager_blob_objects,\n                                     const std::shared_ptr<NNGraphIf>& nn_graph);\n\n  Maybe<JobDesc> GetJobConfSymbol(const JobConfigProto& job_conf);\n\n  Maybe<ParallelDesc> GetParallelDescSymbol(const ParallelConf& parallel_conf);\n\n  Maybe<Scope> GetScopeSymbol(const ScopeProto& scope_proto);\n\n  Maybe<OperatorConfSymbol> GetOpConfSymbol(const OperatorConf& op_conf);\n\n  Maybe<void> ReleaseTensor(const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object);\n\n  Maybe<void> TouchTensors(const vm::EagerBlobObjectListPtr& eager_blob_objects);\n\n  Maybe<void> TouchTensors(const vm::EagerBlobObjectListPtr& eager_blob_objects,\n                           Symbol<Stream> stream);\n\n  template<typename T>\n  Maybe<void> SyncAccessBlobByCallback(\n      const T tensor, const std::shared_ptr<BlockingThenBusy>& btb,\n      const 
std::function<void(ep::Stream*, const std::shared_ptr<vm::EagerBlobObject>&)>& Callback,\n      const std::string& modifier);\n\n  template<typename T>\n  Maybe<void> AccessBlobByCallback(\n      const T tensor,\n      const std::function<void(ep::Stream*, const std::shared_ptr<vm::EagerBlobObject>&)>& callback,\n      const std::string& modifier);\n\n  Maybe<void> GlobalSync();\n  Maybe<void> Barrier(const std::function<void()>& callback);\n\n  Maybe<Scope> BuildInitialScope(int64_t session_id, const JobConfigProto& job_conf,\n                                 const std::string& device_tag,\n                                 const std::vector<std::string>& machine_device_ids,\n                                 const std::shared_ptr<Shape>& hierarchy, bool is_local);\n\n  Maybe<Scope> BuildInitialScopeWithPlacement(int64_t session_id, const JobConfigProto& job_conf,\n                                              Symbol<ParallelDesc> placement, bool is_local);\n\n  Maybe<Scope> BuildScopeWithNewParallelDesc(const std::shared_ptr<Scope>& scope,\n                                             const std::string& device_tag,\n                                             const std::vector<std::string>& machine_device_ids,\n                                             const std::shared_ptr<Shape>& hierarchy);\n\n  Maybe<Scope> BuildScopeWithNewParallelConf(const std::shared_ptr<Scope>& scope,\n                                             const ParallelConf& parallel_conf);\n\n  Maybe<Scope> BuildScopeWithNewIsLocal(const std::shared_ptr<Scope>& scope, bool is_local);\n\n  Maybe<Scope> BuildScopeWithNewScopeName(const std::shared_ptr<Scope>& scope,\n                                          const std::string& scope_name);\n\n  Maybe<Scope> BuildScopeByProtoSetter(\n      const std::shared_ptr<Scope>& scope,\n      const std::function<void(const std::shared_ptr<ScopeProto>&)>& Setter);\n\n  Maybe<Scope> BuildScopeByProtoStrSetter(\n      const std::shared_ptr<Scope>& scope,\n      const std::function<std::string(const std::string&)>& StrSetter);\n\n  Maybe<void> Call(const std::shared_ptr<one::StatefulOpKernel>& opkernel,\n                   vm::EagerBlobObjectList&& input_eager_blob_objects,\n                   vm::EagerBlobObjectList&& output_eager_blob_objects,\n                   const one::OpExprInterpContext& ctx, Symbol<Stream> stream);\n\n  Maybe<void> Call(\n      const std::shared_ptr<one::StatefulOpKernel>& opkernel,\n      vm::EagerBlobObjectList&& input_eager_blob_objects,\n      vm::EagerBlobObjectList&& output_eager_blob_objects,\n      const std::shared_ptr<const one::GlobalTensorInferResult>& global_tensor_infer_result,\n      const one::OpExprInterpContext& ctx, Symbol<Stream> stream);\n\n  Maybe<void> SoftSyncStream(const vm::EagerBlobObjectList& eager_blob_objects,\n                             Symbol<Stream> stream);\n\n private:\n  Maybe<void> AllocateTensors(const vm::EagerBlobObjectList& eager_blob_objects,\n                              Symbol<Stream> stream);\n\n  Maybe<void> SoftSyncStreamBetween(\n      small_vector<intrusive::shared_ptr<LocalDepObject>>&& dependences, Symbol<Stream> from_stream,\n      Symbol<Stream> to_stream);\n\n  Maybe<void> StreamWait(small_vector<intrusive::shared_ptr<LocalDepObject>>&& dependences,\n                         Symbol<Stream> from_stream, Symbol<Stream> to_stream);\n\n  Maybe<void> RecordEvent(\n      small_vector<intrusive::shared_ptr<LocalDepObject>>&& compute_local_dep_objects,\n      Symbol<Stream> stream);\n\n  vm::InstructionList* 
instruction_list_;\n};\n\n// Build VM instructions with an InstructionsBuilder, then run them from the physical (per-rank\n// local) view.\n
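// A minimal usage sketch (illustrative only; GlobalSync is one of the builder methods declared\n// above):\n//\n//   JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n//     return builder->GlobalSync();\n//   }));\ntemplate<typename CallbackT>\nMaybe<void> PhysicalRun(const CallbackT& Build) {\n  vm::InstructionList instruction_list;\n  InstructionsBuilder instructions_builder(&instruction_list);\n  JUST(Build(&instructions_builder));\n  JUST(vm::Run(instructions_builder.mut_instruction_list()));\n  return Maybe<void>::Ok();\n}\n\ntemplate<typename T>\nMaybe<void> SyncReadSmallMem(char* mem_ptr, size_t bytes, const T tensor);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_INSTRUCTIONS_BUILDER_H_\n"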
  },
  {
    "path": "oneflow/core/framework/layout.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/layout.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n\nnamespace oneflow {\n\nSymbol<Layout> Layout::Get(LayoutType layout_type) {\n  static const HashMap<LayoutType, Symbol<Layout>> layout_type2layout{\n#define MAKE_ENTRY(layout_type) {OF_PP_CAT(LayoutType::k, layout_type), layout_type()},\n      OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, LAYOUT_SEQ)\n#undef MAKE_ENTRY\n  };\n  return layout_type2layout.at(layout_type);\n}\n\nconst std::string& GetLayoutTypeName(LayoutType layout_type) {\n  static const HashMap<LayoutType, std::string> layout_type2name{\n      {LayoutType::kStrided, \"oneflow.strided\"}};\n  return layout_type2name.at(layout_type);\n};\n\nconst std::string& Layout::name() const { return GetLayoutTypeName(layout_type_); }\n\n#define DEFINE_GET_LAYOUT_TYPE_FUNCTION(layout_type)                                     \\\n  Symbol<Layout> Layout::layout_type() {                                                 \\\n    static const auto& layout = SymbolOf(Layout(OF_PP_CAT(LayoutType::k, layout_type))); \\\n    return layout;                                                                       \\\n  }\nOF_PP_FOR_EACH_TUPLE(DEFINE_GET_LAYOUT_TYPE_FUNCTION, LAYOUT_SEQ)\n#undef DEFINE_GET_LAYOUT_TYPE_FUNCTION\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/layout.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_LAYOUT_H_\n#define ONEFLOW_CORE_FRAMEWORK_LAYOUT_H_\n#include <string>\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nenum class LayoutType {\n  kStrided,\n};\n\n#define LAYOUT_SEQ OF_PP_MAKE_TUPLE_SEQ(Strided)\n\nclass Layout final {\n public:\n  Layout(const Layout&) = default;\n  Layout(Layout&&) = delete;\n  explicit Layout(LayoutType layout_type) : layout_type_(layout_type) {}\n  ~Layout() = default;\n\n  bool operator==(const Layout& other) const { return this->layout_type() == other.layout_type(); }\n\n  const std::string& name() const;\n\n  LayoutType layout_type() const { return layout_type_; }\n  static Symbol<Layout> Get(LayoutType);\n#define DECLARE_GET_LAYOUT_TYPE_FUNCTION(layout_type) static Symbol<Layout> layout_type();\n  OF_PP_FOR_EACH_TUPLE(DECLARE_GET_LAYOUT_TYPE_FUNCTION, LAYOUT_SEQ)\n#undef DECLARE_GET_LAYOUT_TYPE_FUNCTION\n\n private:\n  LayoutType layout_type_;\n};\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::Layout> final {\n  size_t operator()(const oneflow::Layout& layout) const {\n    return static_cast<size_t>(layout.layout_type());\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_LAYOUT_H_\n"
  },
  {
    "path": "oneflow/core/framework/load_library.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/load_library.h\"\n\n#include <dlfcn.h>\n\nnamespace oneflow {\n\nMaybe<void> LoadLibrary(const std::string& lib_path) {\n  void* handle = dlopen(lib_path.c_str(), RTLD_NOW | RTLD_LOCAL);\n  CHECK_OR_RETURN(handle) << \" LoadLibrary ERROR! Cannot load library file: \" + lib_path\n                          << \" the Error is: \" << dlerror();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/load_library.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_LOAD_LIBRARY_H_\n#define ONEFLOW_CORE_FRAMEWORK_LOAD_LIBRARY_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nMaybe<void> LoadLibrary(const std::string& lib_path);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_LOAD_LIBRARY_H_\n"
  },
  {
    "path": "oneflow/core/framework/local_tensor_infer_cache.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/local_tensor_infer_cache.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/env_var/eager.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nnamespace {\n\nMaybe<void> CheckIsDeviceSupportedByOp(const Device& device, const std::string& op_type_name) {\n  if (IsCpuOnly(op_type_name)) { CHECK_EQ_OR_RETURN(device.type(), \"cpu\"); }  // NOLINT\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckInputDeviceIdentical(const LocalTensorMetaInferArgs& infer_args,\n                                      Symbol<Device> default_device,\n                                      const UserOpExpr& user_op_expr) {\n  for (int i = 0; i < infer_args.input_local_tensor_metas().size(); ++i) {\n    if (user_op_expr.IsHostMemoryInput(i)) { continue; }\n    CHECK_OR_RETURN(default_device\n                    == JUST(VectorAt(infer_args.input_local_tensor_metas(), i))->device())\n        << Error::RuntimeError()\n        << \"Expected all tensors to be on the same device, but found \"\n           \"at least two devices, \"\n        << default_device->ToString() << \" (positional 0) and \"\n        << JUST(VectorAt(infer_args.input_local_tensor_metas(), i))->device()->ToString()\n        << \" (positional \" << i << \")!\";\n  }\n  return Maybe<void>::Ok();\n}\n\nclass UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStreamInferContext {\n public:\n  UserOpExprDeviceAndStreamInferContext(const UserOpExpr* user_op_expr,\n                                        const LocalTensorMetaInferArgs& infer_args,\n                                        OpArgsVector<MutLocalTensorMeta>* output_tensor_metas)\n      : user_op_expr_(user_op_expr),\n        composed_attrs_(infer_args.attrs(), user_op_expr->base_attrs()),\n        infer_args_(infer_args),\n        output_tensor_metas_(output_tensor_metas) {}\n\n  const std::vector<std::pair<std::string, int32_t>>& inputs() const override {\n    return user_op_expr_->indexed_input_pairs();\n  }\n\n  const std::vector<std::pair<std::string, int32_t>>& outputs() const override {\n    return user_op_expr_->indexed_output_pairs();\n  }\n\n  Symbol<Device>* OutputTensorDevice4ArgNameAndIndex(const std::string& name,\n                                                     int64_t index) override {\n    const auto& arg_tuple = *user_op_expr_->output_arg_tuple();\n    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);\n    CHECK_GE(tuple_index, 0) << \"tuple index should be non-negative, but got \" << tuple_index;\n    CHECK_LT(tuple_index, user_op_expr_->output_size())\n        << \"tuple index \" << tuple_index << \" should be 
less than output size \"\n        << user_op_expr_->output_size();\n    return output_tensor_metas_->at(tuple_index).mut_device();\n  }\n\n  Symbol<Device> InputTensorDevice4ArgNameAndIndex(const std::string& name,\n                                                   int64_t index) const override {\n    const auto& arg_tuple = *user_op_expr_->input_arg_tuple();\n    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);\n    CHECK_GE(tuple_index, 0) << \"tuple index should be non-negative, but got \" << tuple_index;\n    CHECK_LT(tuple_index, user_op_expr_->input_size())\n        << \"tuple index \" << tuple_index << \" should be less than input size \"\n        << user_op_expr_->input_size();\n    return infer_args_.input_local_tensor_metas().at(tuple_index)->device();\n  }\n\n private:\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return composed_attrs_.Attr4Name(attr_name);\n  }\n  const UserOpExpr* user_op_expr_;\n  const ComposedAttrMap composed_attrs_;\n  const LocalTensorMetaInferArgs& infer_args_;\n  OpArgsVector<MutLocalTensorMeta>* output_tensor_metas_;\n};\n\nMaybe<Symbol<Stream>> InferDeviceAndStream(const UserOpExpr& user_op_expr,\n                                           const Symbol<Device>& default_device,\n                                           const LocalTensorMetaInferArgs& infer_args,\n                                           OpArgsVector<MutLocalTensorMeta>* output_tensor_metas) {\n  Symbol<Stream> stream;\n  if (!user_op_expr.has_device_and_stream_infer_fn()) {\n    stream = JUST(GetDefaultStreamByDevice(default_device));\n    for (int i = 0; i < user_op_expr.output_size(); i++) {\n      auto& tensor_meta = output_tensor_metas->at(i);\n      *tensor_meta.mut_device() = default_device;\n    }\n  } else {\n    if (!user_op_expr.device_and_stream_infer_fn()) {\n      Symbol<Device> device = infer_args.input_local_tensor_metas().at(0)->device();\n      stream = JUST(GetDefaultStreamByDevice(device));\n    } else {\n      UserOpExprDeviceAndStreamInferContext device_and_stream_ctx(&user_op_expr, infer_args,\n                                                                  output_tensor_metas);\n      stream = JUST(user_op_expr.device_and_stream_infer_fn()(&device_and_stream_ctx));\n    }\n  }\n  return stream;\n}\n\n}  // namespace\n\nsize_t LocalTensorMetaInferArgs::hash_value() const {\n  size_t hash_value = std::hash<AttrMap>()(attrs_);\n  HashCombine(&hash_value, std::hash<Symbol<Device>>()(default_device_));\n  const auto& tensor_meta_hash_functor = std::hash<Symbol<LocalTensorMeta>>();\n  for (const auto& tensor_meta : input_local_tensor_metas_) {\n    HashCombine(&hash_value, tensor_meta_hash_functor(tensor_meta));\n  }\n  return hash_value;\n}\n\nbool LocalTensorMetaInferArgs::operator==(const LocalTensorMetaInferArgs& other) const {\n  return this->attrs_ == other.attrs_ && this->default_device_ == other.default_device_\n         && this->input_local_tensor_metas_ == other.input_local_tensor_metas_;\n}\n\nMaybe<void> LocalTensorMetaInferArgs::Init(const AttrMap& attrs, Symbol<Device> default_device,\n                                           const TensorTuple& input_tensors) {\n  this->attrs_ = attrs;\n  this->default_device_ = default_device;\n  this->input_local_tensor_metas_.resize(input_tensors.size());\n  JUST(this->InitInputLocalTensorMetas(input_tensors));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> 
LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTuple& input_tensors) {\n  for (int i = 0; i < input_tensors.size(); ++i) {\n    input_local_tensor_metas_.at(i) = JUST(input_tensors.at(i)->local_tensor_meta());\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<const LocalTensorInferResult> LocalTensorInferCache::Infer(\n    const UserOpExpr& user_op_expr, const LocalTensorMetaInferArgs& infer_args) {\n  const auto& default_device = infer_args.default_device();\n  JUST(CheckInputDeviceIdentical(infer_args, default_device, user_op_expr));\n  JUST(CheckIsDeviceSupportedByOp(*default_device, user_op_expr.op_type_name()));\n\n  auto result = std::make_unique<LocalTensorInferResult>(user_op_expr.output_size());\n\n  OpArgsVector<MutLocalTensorMeta> output_mut_metas(user_op_expr.output_size());\n  // Infer devices\n  Symbol<Stream> stream =\n      JUST(InferDeviceAndStream(user_op_expr, default_device, infer_args, &output_mut_metas));\n  result->set_stream(stream);\n\n  {\n    const auto& GetInputTensorMeta = [&](int32_t i) -> const TensorMeta* {\n      return infer_args.input_local_tensor_metas().at(i).shared_from_symbol().get();\n    };\n    JUST(user_op_expr.InferPhysicalTensorDesc(\n        infer_args.attrs(), stream->device()->type(), GetInputTensorMeta,\n        [&](int32_t i) -> TensorMeta* { return &output_mut_metas.at(i); }));\n  }\n\n  auto* mut_output_tensor_metas = result->mut_output_tensor_metas();\n  for (int32_t i = 0; i < user_op_expr.output_size(); ++i) {\n    if (!JUST(user_op_expr.SupportNonContiguous())) {\n      Stride stride(output_mut_metas.at(i).shape());\n      output_mut_metas.at(i).set_stride(stride);\n    }\n    CHECK_OR_RETURN(static_cast<bool>(output_mut_metas.at(i).device()))\n        << Error::RuntimeError() << \"device not inferred\";\n    mut_output_tensor_metas->at(i) = SymbolOf(\n        LocalTensorMeta(output_mut_metas.at(i).shape(), output_mut_metas.at(i).stride(),\n                        output_mut_metas.at(i).data_type(), output_mut_metas.at(i).memory_format(),\n                        output_mut_metas.at(i).device()));\n  }\n  return std::shared_ptr<const LocalTensorInferResult>(std::move(result));\n}\n\nMaybe<const LocalTensorInferResult> LocalTensorInferCache::GetOrInfer(\n    const LocalTensorMetaInferArgs& infer_args) {\n  if (ThreadLocalEnvBool<ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE>()) {\n    auto iter = cache_.find(infer_args);\n    if (iter == cache_.end()) {\n      
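// Cache miss. Eviction is a wholesale clear once the cache grows past the configured cap: a\n      // simple way to bound memory at the cost of occasional re-inference.\n      if (unlikely(cache_.size()\n                   >= ThreadLocalEnvInteger<ONEFLOW_EAGER_TENSOR_INFER_CACHE_SIZE>())) {\n        cache_.clear();\n      }\n      const auto& user_op_expr = user_op_expr_.lock();\n      CHECK_OR_RETURN(static_cast<bool>(user_op_expr));  // NOLINT\n      const auto& output_tensor_metas = JUST(Infer(*user_op_expr, infer_args));\n      iter = cache_.emplace(infer_args, output_tensor_metas).first;\n    }\n    return iter->second;\n  } else {\n    const auto& user_op_expr = user_op_expr_.lock();\n    return JUST(Infer(*user_op_expr, infer_args));\n  }\n}\n\n}  // namespace one\n}  // namespace oneflow\n"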
  },
  {
    "path": "oneflow/core/framework/local_tensor_infer_cache.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_LOCAL_TENSOR_INFER_CACHE_H_\n#define ONEFLOW_CORE_FRAMEWORK_LOCAL_TENSOR_INFER_CACHE_H_\n\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/op_args_vector.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/common/tensor_meta.h\"\n\nnamespace oneflow {\n\nclass Device;\n\nnamespace one {\n\nclass TensorTuple;\nclass UserOpExpr;\n\nclass LocalTensorMetaInferArgs final {\n public:\n  LocalTensorMetaInferArgs() = default;\n  LocalTensorMetaInferArgs(const LocalTensorMetaInferArgs&) = default;\n  LocalTensorMetaInferArgs(LocalTensorMetaInferArgs&&) = default;\n  ~LocalTensorMetaInferArgs() = default;\n\n  const OpArgsVector<Symbol<LocalTensorMeta>>& input_local_tensor_metas() const {\n    return input_local_tensor_metas_;\n  }\n  const AttrMap& attrs() const { return attrs_; }\n\n  const Symbol<Device>& default_device() const { return default_device_; }\n\n  size_t hash_value() const;\n\n  bool operator==(const LocalTensorMetaInferArgs& other) const;\n\n  Maybe<void> Init(const AttrMap& attrs, Symbol<Device> default_device,\n                   const TensorTuple& input_tensors);\n\n private:\n  Maybe<void> InitInputLocalTensorMetas(const TensorTuple& input_tensors);\n\n  AttrMap attrs_;\n  Symbol<Device> default_device_;\n  OpArgsVector<Symbol<LocalTensorMeta>> input_local_tensor_metas_;\n};\n\n}  // namespace one\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::one::LocalTensorMetaInferArgs> final {\n  size_t operator()(const oneflow::one::LocalTensorMetaInferArgs& val) const {\n    return val.hash_value();\n  }\n};\n\n}  // namespace std\n\nnamespace oneflow {\nnamespace one {\n\nclass LocalTensorInferResult final {\n public:\n  LocalTensorInferResult(size_t output_size) : output_tensor_metas_(output_size) {}\n  LocalTensorInferResult(const LocalTensorInferResult&) = delete;\n  LocalTensorInferResult(LocalTensorInferResult&&) = delete;\n  ~LocalTensorInferResult() = default;\n\n  const OpArgsVector<Symbol<LocalTensorMeta>>& output_tensor_metas() const {\n    return output_tensor_metas_;\n  }\n  OpArgsVector<Symbol<LocalTensorMeta>>* mut_output_tensor_metas() { return &output_tensor_metas_; }\n\n  const Symbol<Stream>& stream() const { return stream_; }\n  void set_stream(const Symbol<Stream>& stream) { stream_ = stream; }\n\n private:\n  OpArgsVector<Symbol<LocalTensorMeta>> output_tensor_metas_;\n  Symbol<Stream> stream_;\n};\n\nclass LocalTensorInferCache final {\n public:\n  LocalTensorInferCache(const std::shared_ptr<const UserOpExpr>& user_op_expr)\n      : user_op_expr_(user_op_expr) {}\n\n  Maybe<const LocalTensorInferResult> GetOrInfer(const LocalTensorMetaInferArgs& infer_args);\n\n private:\n  static Maybe<const 
LocalTensorInferResult> Infer(const UserOpExpr& user_op_expr,\n                                                   const LocalTensorMetaInferArgs& infer_args);\n\n  std::weak_ptr<const UserOpExpr> user_op_expr_;\n  HashMap<LocalTensorMetaInferArgs, std::shared_ptr<const LocalTensorInferResult>> cache_;\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_LOCAL_TENSOR_INFER_CACHE_H_\n"
  },
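The header above pairs a hashable argument struct (a `hash_value()` member, `operator==`, and a `std::hash` specialization) with a HashMap-backed `GetOrInfer`. Below is a minimal, self-contained sketch of that caching pattern; `InferArgs`, `InferResult`, and their fields are illustrative stand-ins, not OneFlow types.

```cpp
#include <cstddef>
#include <functional>
#include <memory>
#include <unordered_map>

// Stand-in for LocalTensorMetaInferArgs: hashable and equality-comparable.
struct InferArgs {
  int device_id;
  long elem_cnt;
  size_t hash_value() const {
    return std::hash<int>()(device_id) ^ (std::hash<long>()(elem_cnt) << 1);
  }
  bool operator==(const InferArgs& other) const {
    return device_id == other.device_id && elem_cnt == other.elem_cnt;
  }
};

// Same trick as the header: route std::hash through the member hash.
namespace std {
template<>
struct hash<InferArgs> {
  size_t operator()(const InferArgs& v) const { return v.hash_value(); }
};
}  // namespace std

struct InferResult {
  long out_elem_cnt;
};

class InferCache {
 public:
  // Runs the expensive inference only on the first call with equal args;
  // later calls return the cached, immutable result.
  std::shared_ptr<const InferResult> GetOrInfer(const InferArgs& args) {
    auto it = cache_.find(args);
    if (it == cache_.end()) {
      auto result = std::make_shared<InferResult>(InferResult{args.elem_cnt});
      it = cache_.emplace(args, std::move(result)).first;
    }
    return it->second;
  }

 private:
  std::unordered_map<InferArgs, std::shared_ptr<const InferResult>> cache_;
};
```
Sharing the result as `shared_ptr<const ...>` is what lets cache hits be handed out without copying while keeping them immutable.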
  {
    "path": "oneflow/core/framework/multi_client_session_context.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/buffer_manager.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/framework/multi_client_session_context.h\"\n#include \"oneflow/core/framework/load_library.h\"\n#include \"oneflow/core/job/id_state.h\"\n#include \"oneflow/core/job/resource.pb.h\"\n#include \"oneflow/core/job/version.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/id_manager.h\"\n#include \"oneflow/core/job/job_instance.h\"\n#include \"oneflow/core/job/critical_section_instance.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx_mgr.h\"\n#include \"oneflow/core/job/runtime_context.h\"\n#include \"oneflow/core/job/runtime_job_descs.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/core/memory/memory_allocator.h\"\n#include \"oneflow/core/register/register_manager.h\"\n#include \"oneflow/user/summary/events_writer.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/memory/chunk_manager.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/job/collective_boxing/scheduler.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n#include \"oneflow/core/framework/variable_tensor_mgr.h\"\n#ifdef WITH_CUDA\n#include <cuda.h>\n#endif  // WITH_CUDA\n\nnamespace oneflow {\n\nnamespace {\n\nint32_t GetCpuDeviceNum() { return std::thread::hardware_concurrency(); }\n\n}  // namespace\n\nMultiClientSessionContext::MultiClientSessionContext(\n    const std::shared_ptr<EnvGlobalObjectsScope>& env_ctx)\n    : env_ctx_(env_ctx) {\n  CHECK(Singleton<MultiClientSessionContext>::Get() == nullptr)\n      << \"Duplicate multi client session context\";\n  Singleton<MultiClientSessionContext>::SetAllocated(this);\n}\n\nMultiClientSessionContext::~MultiClientSessionContext() {\n  CHECK_JUST(TryClose());\n  if (Singleton<MultiClientSessionContext>::Get() != nullptr) {\n    Singleton<MultiClientSessionContext>::SetAllocated(nullptr);\n  }\n  env_ctx_.reset();\n}\n\nMaybe<void> MultiClientSessionContext::TryInit(const ConfigProto& config_proto) {\n  if (!is_inited_) {\n    DumpVersionInfo();\n\n    Resource resource = config_proto.resource();\n\n    {\n      // NOTE(chengcheng):\n      //   In multi-client, user can NOT config cpu_device_num.\n      //\n      //   cpu_device_num is a confusing name, it should be explained as:\n      //       in current rank, assign CPU actor compute stream in this optional range.\n      //       That is, the number of independent CPU devices that can be abstracted from\n      //       this machine and this process.\n      //\n      //   NOTE: cpu_device_num NOT necessarily equal to the num of process\n      //       on this machine.\n      resource.set_machine_num(GlobalProcessCtx::NodeSize());\n      resource.set_cpu_device_num(GetCpuDeviceNum());\n    }\n\n    // NOTE(chengcheng): detele 
first because in EnvGlobalObjectScope has created ResourceDesc.\n    if (Singleton<ResourceDesc, ForSession>::Get() != nullptr) {\n      // TODO(chengcheng): reorganize dependency of all Global objects.\n      Singleton<ResourceDesc, ForSession>::Delete();\n    }\n    Singleton<ResourceDesc, ForSession>::New(resource, GlobalProcessCtx::NumOfProcessPerNode());\n    Singleton<IDMgr>::New();\n    Singleton<TaskStreamIndexManager>::New();\n    // TODO(chengcheng): refactor JobBuildAndInferCtxMgr\n    Singleton<LazyJobBuildAndInferCtxMgr>::New();\n\n    {\n      // NOTE(chengcheng): init runtime global objects\n      Singleton<BufferMgr<std::shared_ptr<JobInstance>>>::New();\n      Singleton<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::New();\n      Singleton<RuntimeCtx>::New();\n      Singleton<MemoryAllocator>::New();\n      Singleton<ChunkMgr>::New();\n      Singleton<RegstMgr>::New();\n      Singleton<ActorMsgBus>::New();\n      Singleton<ThreadMgr>::New();\n      Singleton<RuntimeJobDescs>::New();\n      Singleton<summary::EventsWriter>::New();\n      Singleton<boxing::collective::Scheduler>::New();\n      Singleton<VariableTensorMgr>::New();\n    }\n\n    is_inited_ = true;\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MultiClientSessionContext::TryInit(const std::string& config_proto_str) {\n  ConfigProto config_proto;\n  CHECK_OR_RETURN(TxtString2PbMessage(config_proto_str, &config_proto))\n      << Error::RuntimeError() << \"failed to parse config_proto: \" << config_proto_str;\n  return TryInit(config_proto);\n}\n\nMaybe<void> MultiClientSessionContext::UpdateResource(const Resource& reso_proto) {\n  CHECK_OR_RETURN(is_inited_) << Error::RuntimeError()\n                              << \" session must be inited when updating resource.\";\n  CHECK_NOTNULL_OR_RETURN((Singleton<ResourceDesc, ForSession>::Get()))\n      << Error::RuntimeError() << \"ResourceDesc get failed!\";\n  Singleton<ResourceDesc, ForSession>::Get()->Update(reso_proto);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MultiClientSessionContext::UpdateResource(const std::string& reso_proto_str) {\n  Resource reso_proto;\n  CHECK_OR_RETURN(TxtString2PbMessage(reso_proto_str, &reso_proto))\n      << Error::RuntimeError() << \"failed to parse config_proto: \" << reso_proto_str;\n  return UpdateResource(reso_proto);\n}\n\nMaybe<void> MultiClientSessionContext::TryClose() {\n  if (is_inited_) {\n    VLOG(1) << \"Try to delete multi client session context.\" << std::endl;\n    {\n      // NOTE(chengcheng): delete runtime global objects\n      Singleton<boxing::collective::Scheduler>::Delete();\n      Singleton<summary::EventsWriter>::Delete();\n      Singleton<RuntimeJobDescs>::Delete();\n      Singleton<ThreadMgr>::Delete();\n      Singleton<ActorMsgBus>::Delete();\n      Singleton<RegstMgr>::Delete();\n      Singleton<ChunkMgr>::Delete();\n      Singleton<MemoryAllocator>::Delete();\n      Singleton<RuntimeCtx>::Delete();\n      Singleton<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::Delete();\n      Singleton<BufferMgr<std::shared_ptr<JobInstance>>>::Delete();\n      Singleton<VariableTensorMgr>::Delete();\n    }\n\n    Singleton<LazyJobBuildAndInferCtxMgr>::Delete();\n    Singleton<TaskStreamIndexManager>::Delete();\n    Singleton<IDMgr>::Delete();\n\n    // TODO(chengcheng): remove template ForEnv and ForSession\n    Singleton<ResourceDesc, ForSession>::Delete();\n    // NOTE(chengcheng): New after delete because in EnvGlobalObjectScope once created ResourceDesc.\n    Singleton<ResourceDesc, 
ForSession>::New(Singleton<ResourceDesc, ForEnv>::Get()->resource(),\n                                             GlobalProcessCtx::NumOfProcessPerNode());\n    VLOG(1) << \"Finish delete multi client session context.\" << std::endl;\n    is_inited_ = false;\n  }\n  return Maybe<void>::Ok();\n}\n\nvoid MultiClientSessionContext::StoreFreeEagerTensorWithNameByGraphName(\n    const std::string& graph_name, const std::shared_ptr<one::Tensor>& tensor,\n    const std::string& tensor_name) {\n  auto it = graph_name2free_eager_tensors_.find(graph_name);\n  if (it == graph_name2free_eager_tensors_.end()) {\n    it = graph_name2free_eager_tensors_\n             .emplace(graph_name,\n                      std::vector<std::pair<std::string, std::shared_ptr<one::Tensor>>>())\n             .first;\n  }\n  it->second.emplace_back(std::make_pair(tensor_name, tensor));\n}\n\nconst std::vector<std::pair<std::string, std::shared_ptr<one::Tensor>>>&\nMultiClientSessionContext::GetFreeEagerTensorNamePairByGraphName(const std::string& graph_name) {\n  auto it = graph_name2free_eager_tensors_.find(graph_name);\n  if (it == graph_name2free_eager_tensors_.end()) {\n    it = graph_name2free_eager_tensors_\n             .emplace(graph_name,\n                      std::vector<std::pair<std::string, std::shared_ptr<one::Tensor>>>())\n             .first;\n  }\n  return it->second;\n}\n\nvoid MultiClientSessionContext::RemoveGraphFreeEagerTensors(const std::string& graph_name) {\n  graph_name2free_eager_tensors_.erase(graph_name);\n}\n\nIdState MultiClientSessionContext::GetIdState() {\n  CHECK(Singleton<IDMgr>::Get() != nullptr);\n  CHECK(Singleton<TaskStreamIndexManager>::Get() != nullptr);\n  CHECK(Singleton<LazyJobBuildAndInferCtxMgr>::Get() != nullptr);\n  IdState id_state;\n\n  id_state.job_id_state_ = Singleton<LazyJobBuildAndInferCtxMgr>::Get()->GetJobIdCount();\n  Singleton<IDMgr>::Get()->SaveIdAndTaskIndex(&id_state);\n  Singleton<TaskStreamIndexManager>::Get()->GetTaskStreamIndex(&id_state.stream_index_state_);\n  return id_state;\n}\n\nvoid MultiClientSessionContext::SetIdState(const IdState& id_state) {\n  CHECK(Singleton<IDMgr>::Get() != nullptr);\n  CHECK(Singleton<TaskStreamIndexManager>::Get() != nullptr);\n  CHECK(Singleton<LazyJobBuildAndInferCtxMgr>::Get() != nullptr);\n  Singleton<IDMgr>::Get()->TryUpdateIdAndTaskIndex(&id_state);\n  Singleton<TaskStreamIndexManager>::Get()->TryUpdateTaskStreamIndex(id_state.stream_index_state_);\n  Singleton<LazyJobBuildAndInferCtxMgr>::Get()->TryUpdateJobIdCount(id_state.job_id_state_);\n}\n\n}  // namespace oneflow\n"
  },
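Both `StoreFreeEagerTensorWithNameByGraphName` and `GetFreeEagerTensorNamePairByGraphName` above create an empty vector on first access, so lookups for unknown graph names never fail. A self-contained sketch of that emplace-on-miss bookkeeping, with a stand-in `Tensor` type:

```cpp
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct Tensor {};  // stand-in for one::Tensor
using NamedTensors = std::vector<std::pair<std::string, std::shared_ptr<Tensor>>>;

class FreeEagerTensorStore {
 public:
  void Store(const std::string& graph_name, const std::shared_ptr<Tensor>& tensor,
             const std::string& tensor_name) {
    // emplace is a no-op when the key exists, so this inserts an empty
    // vector only on the first use of graph_name.
    auto it = graph_name2tensors_.emplace(graph_name, NamedTensors()).first;
    it->second.emplace_back(tensor_name, tensor);
  }

  // Always returns a valid (possibly empty) vector for the graph name.
  const NamedTensors& Get(const std::string& graph_name) {
    return graph_name2tensors_.emplace(graph_name, NamedTensors()).first->second;
  }

  void Remove(const std::string& graph_name) { graph_name2tensors_.erase(graph_name); }

 private:
  std::unordered_map<std::string, NamedTensors> graph_name2tensors_;
};
```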
  {
    "path": "oneflow/core/framework/multi_client_session_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_MULTI_CLIENT_SESSION_CONTEXT_H_\n#define ONEFLOW_CORE_FRAMEWORK_MULTI_CLIENT_SESSION_CONTEXT_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/id_state.h\"\n#include \"oneflow/core/job/job_set.pb.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/job/env_global_objects_scope.h\"\n\nnamespace oneflow {\n\nclass MultiClientSessionContext {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MultiClientSessionContext);\n  explicit MultiClientSessionContext(const std::shared_ptr<EnvGlobalObjectsScope>&);\n  ~MultiClientSessionContext();\n\n  Maybe<void> TryInit(const ConfigProto& config_proto);\n  Maybe<void> TryInit(const std::string& config_proto_str);\n  Maybe<void> UpdateResource(const Resource& reso_proto);\n  Maybe<void> UpdateResource(const std::string& reso_proto_str);\n\n  Maybe<void> TryClose();\n\n  // NOTE(chengcheng): for nn.Graph catch free EagerTensor in Graph.build().\n  //   NNGraph should NOT hold ANY shared_ptr<Tensor> because NNGraph will send to VM stream in\n  //   RunLazyNNGraphInstruction, the tensor in NNGraph will Never be released for hold in VM\n  //   instrunction and compute stream. So we store free EagerTensor in MultiClientSessionContext,\n  //   and will be release in NNGraph destructor.\n  void StoreFreeEagerTensorWithNameByGraphName(const std::string& graph_name,\n                                               const std::shared_ptr<one::Tensor>& tensor,\n                                               const std::string& tensor_name);\n  const std::vector<std::pair<std::string, std::shared_ptr<one::Tensor>>>&\n  GetFreeEagerTensorNamePairByGraphName(const std::string& graph_name);\n  void RemoveGraphFreeEagerTensors(const std::string& graph_name);\n\n  IdState GetIdState();\n  void SetIdState(const IdState& id_state);\n\n private:\n  bool is_inited_ = false;\n  std::shared_ptr<EnvGlobalObjectsScope> env_ctx_;\n  HashMap<std::string, std::vector<std::pair<std::string, std::shared_ptr<one::Tensor>>>>\n      graph_name2free_eager_tensors_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_MULTI_CLIENT_SESSION_CONTEXT_H_\n"
  },
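`TryInit` and `TryClose` are both guarded by `is_inited_`, which makes them idempotent and lets the destructor call `TryClose` unconditionally. A minimal sketch of that lifecycle shape (names illustrative, the session-scope singleton setup elided):

```cpp
#include <iostream>

class SessionLike {
 public:
  ~SessionLike() { TryClose(); }  // safe: TryClose is a no-op when not inited

  void TryInit() {
    if (!is_inited_) {
      std::cout << "allocate session-scope singletons\n";
      is_inited_ = true;
    }
  }

  void TryClose() {
    if (is_inited_) {
      std::cout << "delete session-scope singletons in reverse order\n";
      is_inited_ = false;
    }
  }

 private:
  bool is_inited_ = false;
};

int main() {
  SessionLike session;
  session.TryInit();
  session.TryInit();   // no-op: already initialized
  session.TryClose();  // releases once; the destructor's TryClose is then a no-op
}
```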
  {
    "path": "oneflow/core/framework/multi_thread.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/multi_thread.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nvoid MultiThreadLoopInOpKernel(size_t num, std::function<void(size_t i)> Callback) {\n  MultiThreadLoop(num, Callback);\n}\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/multi_thread.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_MULTI_THREAD_H_\n#define ONEFLOW_CORE_FRAMEWORK_MULTI_THREAD_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nvoid MultiThreadLoopInOpKernel(size_t num, std::function<void(size_t i)> Callback);\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_MULTI_THREAD_H_\n"
  },
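`MultiThreadLoopInOpKernel` simply forwards to `MultiThreadLoop`, which splits `[0, num)` across worker threads and invokes the callback once per index. A rough `std::thread` approximation of that call contract follows; the real helper dispatches onto OneFlow's thread manager rather than spawning threads per call.

```cpp
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

void NaiveMultiThreadLoop(size_t num, const std::function<void(size_t)>& Callback,
                          size_t thread_num = std::thread::hardware_concurrency()) {
  if (thread_num == 0) { thread_num = 1; }
  std::vector<std::thread> workers;
  workers.reserve(thread_num);
  for (size_t t = 0; t < thread_num; ++t) {
    workers.emplace_back([=, &Callback]() {
      // Strided partition: thread t handles indices t, t + thread_num, ...
      for (size_t i = t; i < num; i += thread_num) { Callback(i); }
    });
  }
  for (auto& w : workers) { w.join(); }
}
```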
  {
    "path": "oneflow/core/framework/mutable_attr_map.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_CACHED_ATTR_MAP_H_\n#define ONEFLOW_CORE_FRAMEWORK_CACHED_ATTR_MAP_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/small_vector.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/attr_value.h\"\n#include \"oneflow/core/framework/attr_value_accessor.h\"\n#include \"oneflow/core/framework/ordered_string_list.h\"\n#include \"oneflow/core/framework/user_op_attr.pb.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n\nnamespace oneflow {\n\nclass MutableAttrMap {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MutableAttrMap);\n\n  explicit MutableAttrMap(const std::vector<std::string>& attr_names)\n      : max_size_(attr_names.size()),\n        valid_masks_(max_size_, 0),\n        ordered_attr_names_(std::make_shared<OrderedStringList<8>>()) {\n    for (const auto& attr_name : attr_names) { ordered_attr_names_->emplace_back(attr_name); }\n    attrs_.resize(max_size_);\n  }\n\n  ~MutableAttrMap() = default;\n\n  size_t max_size() const { return max_size_; }\n\n  const std::shared_ptr<OrderedStringList<8>>& ordered_attr_names() const {\n    return ordered_attr_names_;\n  }\n  const small_vector<bool, 8>& valid_masks() const { return valid_masks_; }\n  const small_vector<std::shared_ptr<user_op::AttrVal>, 8>& attrs() const { return attrs_; }\n\n  inline void reset() {\n    // mark all cached attributes as illegal values\n    memset(valid_masks_.data(), 0, max_size_);\n  }\n\n  template<typename T>\n  inline void SetAttr(const char* attr_name, const T& attr_val) {\n    auto idx = ordered_attr_names_->order(attr_name);\n    CHECK_OR_THROW(idx != -1) << \"has no attribute named \" << attr_name;\n    SetAttrNoThrow(idx, attr_val);\n  }\n\n  template<int I, typename T>\n  inline void SetAttr(const T& attr_val) {\n    CHECK_LT_OR_THROW(I, max_size_)\n        << \"index \" << I << \" is out of bound, and the max size is \" << max_size_;\n    SetAttrNoThrow(I, attr_val);\n  }\n\n  template<typename... Args>\n  inline void SetAllAttrs(Args&&... 
args) {\n    CHECK_EQ_OR_THROW(sizeof...(args), max_size_)\n        << \"requires \" << max_size_ << \" arguments, but got \" << sizeof...(args);\n    SetAttrNoThrow<Args...>(std::forward<Args>(args)...,\n                            std::make_index_sequence<sizeof...(args)>{});\n  }\n\n private:\n  template<typename T, typename std::enable_if<!std::is_same<T, NullOptType>::value\n                                                   && !internal::IsOptional<T>::value,\n                                               int>::type = 0>\n  inline void SetAttrNoThrow(int idx, const T& attr_val) {\n    valid_masks_[idx] = true;\n    if (!attrs_[idx] /*|| attrs_[idx]->type() != user_op::GetAttrType<T>::value*/\n        || *static_cast<const T*>(attrs_[idx]->Ptr()) != attr_val) {\n      attrs_[idx] = std::make_shared<user_op::TypedAttrVal<T>>(attr_val);\n    }\n  }\n\n  template<typename T, typename std::enable_if<internal::IsOptional<T>::value, int>::type = 0>\n  inline void SetAttrNoThrow(int idx, const T& attr_val) {\n    if (attr_val) {\n      using U = typename T::value_type;\n      SetAttrNoThrow(idx, attr_val.value_or(U()));\n    }\n  }\n\n  template<typename T, typename std::enable_if<std::is_same<T, NullOptType>::value, int>::type = 0>\n  inline void SetAttrNoThrow(int idx, const T&) {}\n\n  template<typename... Args, size_t... I>\n  inline void SetAttrNoThrow(Args&&... args, std::index_sequence<I...>) {\n    (SetAttrNoThrow(I, std::forward<Args>(args)), ...);\n  }\n\n  // The actual count of all attributes\n  size_t max_size_;\n  small_vector<bool, 8> valid_masks_;\n  small_vector<std::shared_ptr<user_op::AttrVal>, 8> attrs_;\n  // The ordered attribute names are fixed once constructed and should be shared\n  // between other AttrMaps\n  std::shared_ptr<OrderedStringList<8>> ordered_attr_names_;\n};\n\n#define THREAD_CACHED_MUTABLE_ATTR_MAP(...)                                          \\\n  []() -> MutableAttrMap& {                                                          \\\n    thread_local static MutableAttrMap attrs(std::vector<std::string>{__VA_ARGS__}); \\\n    attrs.reset();                                                                   \\\n    return attrs;                                                                    \\\n  }()\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_MUTABLE_ATTR_MAP_H_\n"
  },
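A hypothetical call site for the macro above (the attribute name "dim" and the helper function are made up, not OneFlow API): the `thread_local` MutableAttrMap is constructed once per thread with a fixed attribute order, `reset()` on every call, and then refilled positionally by `SetAllAttrs`.

```cpp
#include <cstdint>

#include "oneflow/core/framework/mutable_attr_map.h"

namespace oneflow {

// Hypothetical helper: prepares the attrs for an op with a single
// int64 attribute named "dim".
void PrepareDimAttr(int64_t dim) {
  // Built once per thread; the attribute order is fixed at first evaluation.
  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dim");
  // Exactly one argument per declared attribute name, in declaration order.
  attrs.SetAllAttrs(dim);
  // `attrs` can now be handed to the op interpreter for this call.
}

}  // namespace oneflow
```
Avoiding a fresh map per call is the point of the thread-local cache: repeated calls with unchanged values also reuse the stored `TypedAttrVal` objects.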
  {
    "path": "oneflow/core/framework/nd_sbp.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<std::vector<std::string>> FindOrCreateNdSbpString(Symbol<NdSbp> nd_sbp) {\n  static thread_local auto* nd_sbp2nd_sbp_str =\n      new HashMap<Symbol<NdSbp>, std::shared_ptr<std::vector<std::string>>>();\n  auto iter = nd_sbp2nd_sbp_str->find(nd_sbp);\n  if (iter == nd_sbp2nd_sbp_str->end()) {\n    std::shared_ptr<std::vector<std::string>> nd_sbp_str =\n        std::make_shared<std::vector<std::string>>(nd_sbp->sbp_parallel_size());\n    for (int64_t i = 0; i < nd_sbp_str->size(); ++i) {\n      nd_sbp_str->at(i) = SbpParallelToString(nd_sbp->sbp_parallel(i));\n    }\n    iter = nd_sbp2nd_sbp_str->emplace(nd_sbp, nd_sbp_str).first;\n  }\n  return iter->second;\n}\n\nMaybe<void> GetDualSbpParallel(const SbpParallel& sbp_parallel, SbpParallel* dual_sbp_parallel) {\n  if (sbp_parallel.has_split_parallel()) {\n    *dual_sbp_parallel = sbp_parallel;\n  } else if (sbp_parallel.has_broadcast_parallel()) {\n    dual_sbp_parallel->mutable_partial_sum_parallel();\n  } else if (sbp_parallel.has_partial_sum_parallel()) {\n    dual_sbp_parallel->mutable_broadcast_parallel();\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<Symbol<NdSbp>> GetDualNdSbp(Symbol<NdSbp> nd_sbp) {\n  static thread_local HashMap<Symbol<NdSbp>, Symbol<NdSbp>> map;\n  auto iter = map.find(nd_sbp);\n  if (iter == map.end()) {\n    NdSbp dual_nd_sbp;\n    auto* mut_sbp_parallel = dual_nd_sbp.mutable_sbp_parallel();\n    for (const auto& sbp_parallel : nd_sbp->sbp_parallel()) {\n      JUST(GetDualSbpParallel(sbp_parallel, mut_sbp_parallel->Add()));\n    }\n    iter = map.emplace(nd_sbp, SymbolOf(dual_nd_sbp)).first;\n  }\n  return iter->second;\n}\n\nMaybe<std::vector<std::string>> GetNdSbpStrList(const std::vector<Symbol<SbpParallel>>& sbp_list) {\n  return FindOrCreateNdSbpString(JUST(GetNdSbp(sbp_list)));\n}\n\nMaybe<std::vector<std::string>> GetNdSbpStrList(Symbol<NdSbp> nd_sbp) {\n  return FindOrCreateNdSbpString(nd_sbp);\n}\n\nMaybe<std::vector<std::string>> GetDualNdSbpStrList(Symbol<NdSbp> nd_sbp) {\n  return GetNdSbpStrList(JUST(GetDualNdSbp(nd_sbp)));\n}\n\nnamespace private_details {\n\nMaybe<Symbol<NdSbp>> RawGetNdSbp(const std::vector<Symbol<SbpParallel>>& sbp_list) {\n  CHECK_OR_RETURN(!sbp_list.empty())\n      << Error::InvalidValueError() << \"sbp_list should be non-empty\";\n  NdSbp nd_sbp;\n  for (const auto& sbp : sbp_list) { *(nd_sbp.mutable_sbp_parallel()->Add()) = *sbp; }\n  return SymbolOf(nd_sbp);\n}\n\nMaybe<std::vector<Symbol<SbpParallel>>> RawGetSbpList(Symbol<NdSbp> nd_sbp) {\n  const auto& vec = std::make_shared<std::vector<Symbol<SbpParallel>>>();\n  CHECK_OR_RETURN(!nd_sbp->sbp_parallel().empty())\n      << Error::InvalidValueError() << \"sbp_parallel 
should be non-empty\";\n  for (const auto& sbp_parallel : nd_sbp->sbp_parallel()) {\n    vec->emplace_back(SymbolOf(sbp_parallel));\n  }\n  return vec;\n}\n\nbool RawContainSplitSbp(Symbol<NdSbp> nd_sbp) {\n  for (int32_t i = 0; i < nd_sbp->sbp_parallel_size(); ++i) {\n    if (nd_sbp->sbp_parallel(i).has_split_parallel()) { return true; }\n  }\n  return false;\n}\n\nMaybe<std::vector<Symbol<SbpParallel>>> RawNdSbpReplacePartialByBroadcast(\n    const std::vector<Symbol<SbpParallel>>& sbp_list) {\n  auto result = std::make_shared<std::vector<Symbol<SbpParallel>>>(sbp_list.size());\n  for (int i = 0; i < sbp_list.size(); ++i) {\n    const auto& sbp = sbp_list[i];\n    if (sbp->has_partial_sum_parallel()) {\n      (*result)[i] = JUST(MakeBroadcastSbpParallel());\n    } else {\n      (*result)[i] = sbp;\n    }\n  }\n  return result;\n}\n\n}  // namespace private_details\n\nconst std::vector<Symbol<SbpParallel>>& GetNoneSbpList() {\n  static thread_local std::vector<Symbol<SbpParallel>> none;\n  return none;\n}\n\nstd::string SbpToString(Symbol<SbpParallel> sbp_sym) { return SbpToString(*sbp_sym); }\n\nstd::string NdSbpToString(Symbol<NdSbp> nd_sbp_sym) { return NdSbpToString(*nd_sbp_sym); }\n\nstd::string SbpToString(const SbpParallel& sbp) {\n  std::ostringstream ss;\n  if (sbp.has_broadcast_parallel()) {\n    ss << \"B\";\n  } else if (sbp.has_partial_sum_parallel()) {\n    ss << \"P\";\n  } else if (sbp.has_split_parallel()) {\n    ss << \"S(\" << std::to_string(sbp.split_parallel().axis()) << \")\";\n  } else {\n    UNIMPLEMENTED();\n  }\n  return ss.str();\n}\n\nstd::string NdSbpToString(const NdSbp& nd_sbp) {\n  std::ostringstream ss;\n  ss << \"(\";\n  for (size_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) {\n    if (i > 0) { ss << \", \"; }\n    ss << SbpToString(nd_sbp.sbp_parallel(i));\n  }\n  ss << \")\";\n  return ss.str();\n}\n\nMaybe<Symbol<NdSbp>> SetSbpAtAxis(Symbol<NdSbp> nd_sbp, Symbol<SbpParallel> sbp, int axis) {\n  return SetSbpAtAxis(*nd_sbp, *sbp, axis);\n}\n\nMaybe<Symbol<NdSbp>> SetSbpAtAxis(const NdSbp& nd_sbp, const SbpParallel& sbp, int axis) {\n  CHECK_LT_OR_RETURN(axis, nd_sbp.sbp_parallel_size())\n      << Error::RuntimeError() << \"Expected axis to be less than the size of sbp list (\"\n      << nd_sbp.sbp_parallel_size() << \"), but got \" << axis;\n  NdSbp out_nd_sbp = nd_sbp;\n  *out_nd_sbp.mutable_sbp_parallel(axis) = sbp;\n  return SymbolOf(out_nd_sbp);\n}\n\nMaybe<Symbol<NdSbp>> SbpToNdSbp(Symbol<SbpParallel> sbp) { return SbpToNdSbp(*sbp); }\n\nMaybe<Symbol<NdSbp>> SbpToNdSbp(const SbpParallel& sbp) {\n  NdSbp out_nd_sbp;\n  *out_nd_sbp.add_sbp_parallel() = sbp;\n  return SymbolOf(out_nd_sbp);\n}\n\n// If an nd sbp can be converted to a 1d sbp.\nbool Is1dSbp(const NdSbp& nd_sbp) {\n  if (nd_sbp.sbp_parallel_size() == 0) { return false; }\n  // Equivalent to\n  // return std::all_of(nd_sbp.sbp_parallel().begin() + 1, nd_sbp.sbp_parallel().end(),\n  //                    [&](const auto& sbp) { return sbp == nd_sbp.sbp_parallel(0); });\n  for (int32_t i = 1; i < nd_sbp.sbp_parallel_size(); i++) {\n    if (nd_sbp.sbp_parallel(0) != nd_sbp.sbp_parallel(i)) { return false; }\n  }\n  return true;\n}\n\n}  // namespace oneflow\n"
  },
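`GetDualSbpParallel` above encodes the dual mapping: split is self-dual, while broadcast and partial-sum are duals of each other. Restated as a standalone table with a stand-in enum (not OneFlow's SbpParallel proto):

```cpp
enum class Sbp { kSplit, kBroadcast, kPartialSum };

constexpr Sbp Dual(Sbp sbp) {
  return sbp == Sbp::kBroadcast    ? Sbp::kPartialSum
         : sbp == Sbp::kPartialSum ? Sbp::kBroadcast
                                   : sbp;  // split keeps its axis
}

int main() {
  static_assert(Dual(Sbp::kBroadcast) == Sbp::kPartialSum, "B <-> P");
  static_assert(Dual(Sbp::kPartialSum) == Sbp::kBroadcast, "P <-> B");
  static_assert(Dual(Dual(Sbp::kSplit)) == Sbp::kSplit, "dual is an involution");
}
```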
  {
    "path": "oneflow/core/framework/nd_sbp.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_ND_SBP_H_\n#define ONEFLOW_CORE_FRAMEWORK_ND_SBP_H_\n\n#include <vector>\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n\nnamespace oneflow {\n\nMaybe<Symbol<NdSbp>> GetDualNdSbp(Symbol<NdSbp> nd_sbp);\n\nMaybe<Symbol<NdSbp>> GetDualNdSbp(Symbol<NdSbp> sbp_list);\n\nMaybe<std::vector<std::string>> GetNdSbpStrList(const std::vector<Symbol<SbpParallel>>& sbp_list);\n\nMaybe<std::vector<std::string>> GetNdSbpStrList(Symbol<NdSbp> nd_sbp);\n\nMaybe<std::vector<std::string>> GetDualNdSbpStrList(Symbol<NdSbp> nd_sbp);\n\nMaybe<std::vector<std::string>> GetDualNdSbpStrList(Symbol<NdSbp> nd_sbp);\n\nnamespace private_details {\n\nMaybe<Symbol<NdSbp>> RawGetNdSbp(const std::vector<Symbol<SbpParallel>>& sbp_list);\nMaybe<std::vector<Symbol<SbpParallel>>> RawGetSbpList(Symbol<NdSbp> nd_sbp);\nbool RawContainSplitSbp(Symbol<NdSbp> nd_sbp);\n\nMaybe<std::vector<Symbol<SbpParallel>>> RawNdSbpReplacePartialByBroadcast(\n    const std::vector<Symbol<SbpParallel>>& sbp_list);\n\n}  // namespace private_details\n\nstatic constexpr auto* GetNdSbp = DECORATE(&private_details::RawGetNdSbp, ThreadLocalCopiable);\nstatic constexpr auto* GetSbpList = DECORATE(&private_details::RawGetSbpList, ThreadLocal);\nstatic constexpr auto* ContainSplitSbp =\n    DECORATE(&private_details::RawContainSplitSbp, ThreadLocal);\nconst std::vector<Symbol<SbpParallel>>& GetNoneSbpList();\n\nstatic constexpr auto* NdSbpReplacePartialByBroadcast =\n    DECORATE(&private_details::RawNdSbpReplacePartialByBroadcast, ThreadLocalCachedCopiable);\n\nstd::string SbpToString(Symbol<SbpParallel> sbp_sym);\nstd::string NdSbpToString(Symbol<NdSbp> nd_sbp_sym);\nstd::string SbpToString(const SbpParallel& sbp);\nstd::string NdSbpToString(const NdSbp& nd_sbp);\n\nMaybe<Symbol<NdSbp>> SetSbpAtAxis(Symbol<NdSbp> nd_sbp, Symbol<SbpParallel> sbp, int axis);\nMaybe<Symbol<NdSbp>> SetSbpAtAxis(const NdSbp& nd_sbp, const SbpParallel& sbp, int axis);\n\nMaybe<Symbol<NdSbp>> SbpToNdSbp(Symbol<SbpParallel> sbp);\nMaybe<Symbol<NdSbp>> SbpToNdSbp(const SbpParallel& sbp);\n\n// If an nd sbp can be converted to a 1d sbp.\nbool Is1dSbp(const NdSbp& nd_sbp);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_ND_SBP_H_\n"
  },
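For reference, the string forms produced by `SbpToString` / `NdSbpToString` (as implemented in nd_sbp.cpp) are "B", "P", "S(axis)", and a parenthesized, comma-separated list for an nd-sbp. A standalone illustration with stand-in types:

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Stand-in for SbpParallel: 'B', 'P', or 'S' with a split axis.
struct MySbp {
  char kind;
  int split_axis;
};

std::string MySbpToString(const MySbp& sbp) {
  if (sbp.kind == 'S') { return "S(" + std::to_string(sbp.split_axis) + ")"; }
  return std::string(1, sbp.kind);
}

std::string MyNdSbpToString(const std::vector<MySbp>& nd_sbp) {
  std::ostringstream ss;
  ss << "(";
  for (size_t i = 0; i < nd_sbp.size(); ++i) {
    if (i > 0) { ss << ", "; }
    ss << MySbpToString(nd_sbp[i]);
  }
  ss << ")";
  return ss.str();
}

int main() {
  // A 2d sbp: split on axis 0 along the first hierarchy dim, broadcast along the second.
  std::cout << MyNdSbpToString({{'S', 0}, {'B', -1}}) << "\n";  // prints "(S(0), B)"
}
```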
  {
    "path": "oneflow/core/framework/nn_graph.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/nn_graph.h\"\n#include \"oneflow/core/common/buffer_manager.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n#include \"oneflow/core/common/hash_container.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/common/cost_util.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/control/ctrl_client.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/scope_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_name_scope.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/job/compiler.h\"\n#include \"oneflow/core/job/rank_compiler.h\"\n#include \"oneflow/core/graph/task_graph.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx_mgr.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/job_instance.h\"\n#include \"oneflow/core/job/critical_section_instance.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/job/plan_util.h\"\n#include \"oneflow/core/job/utils/progress_bar.h\"\n#include \"oneflow/core/job_rewriter/job_completer.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"oneflow/core/framework/variable_tensor_mgr.h\"\n#include \"oneflow/core/common/env_var/env_var.h\"\n#include \"oneflow/core/job/compile_mode.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<bool> GetTensorValidInCurRank(const std::shared_ptr<one::Tensor>& tensor) {\n  if (tensor->is_global()) {\n    const auto& parallel_id = JUST(GetParallelId4CurrentProcessCtx(JUST(tensor->parallel_desc())));\n    if (parallel_id->has_value()) {\n      return true;\n    } else {\n      return false;\n    }\n  } else {\n    return true;\n  }\n}\n\nMaybe<std::string> GetTensorMetaString(const std::shared_ptr<one::Tensor>& tensor) {\n  std::string ret = \"shape=\" + tensor->shape()->ToString() + \", dtype=\" + tensor->dtype()->name();\n  if (tensor->is_global()) {\n    ret += \", placement=\" + *JUST(PlacementToString(JUST(tensor->parallel_desc())));\n    ret += \", nd_sbp=\" + NdSbpToString(JUST(tensor->nd_sbp()));\n  } else {\n    ret += \", device=\" + JUST(tensor->device())->ToString();\n  }\n  return ret;\n}\n\ntemplate<typename 
T>\nMaybe<void> MakeEagerBlobObjectList(vm::EagerBlobObjectList* blob_list, const T& tensor_list) {\n  blob_list->reserve(tensor_list.size());\n  for (const auto& tensor : tensor_list) {\n    CHECK_OR_RETURN(tensor->is_eager())\n        << Error::RuntimeError() << \"Tensors in nn.Graph should be eager\";\n    if (tensor->is_global()) {\n      blob_list->emplace_back(JUST(JUST(tensor->cur_rank_phy_tensor())->eager_blob_object()));\n    } else {\n      blob_list->emplace_back(JUST(tensor->eager_blob_object()));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nNNGraph::~NNGraph() {\n  VLOG(1) << \"Graph destructor Try to close c nn graph name \" << name_ << \".\" << std::endl;\n  CHECK_JUST(Close());\n}\n\nMaybe<void> NNGraph::Close() {\n  if (!is_closed_) {\n    VLOG(1) << \"Try to close c nn graph name \" << name_ << \".\" << std::endl;\n    CloseRuntimeBuffers();\n    runtime_.reset();\n    session_ctx_->RemoveGraphFreeEagerTensors(name_);\n    VLOG(1) << \"Finish close c nn graph name \" << name_ << \".\" << std::endl;\n\n    session_ctx_.reset();\n    is_closed_ = true;\n  }\n  return Maybe<void>::Ok();\n}\n\nconst std::vector<std::string>& NNGraph::inputs_op_names() const { return inputs_op_names_; }\n\nconst std::vector<std::string>& NNGraph::outputs_op_names() const { return outputs_op_names_; }\n\nconst std::vector<bool>& NNGraph::inputs_valid() const { return input_tensors_valid_; }\n\nconst std::vector<bool>& NNGraph::outputs_valid() const { return output_tensors_valid_; }\n\nconst std::vector<std::string>& NNGraph::inputs_tensor_meta_str() const {\n  return inputs_tensor_meta_str_;\n}\n\nconst std::vector<std::string>& NNGraph::outputs_tensor_meta_str() const {\n  return outputs_tensor_meta_str_;\n}\n\nint64_t NNGraph::variable_op_size() const { return variable_op_names_.size(); }\n\nconst std::shared_ptr<vm::EagerBlobObjectList>& NNGraph::var_blobs() const {\n  return variable_op_blobs_;\n}\n\nMaybe<void> NNGraph::RegisterAdditionalVarOpNamesAndTensorsToBeLoaded(\n    const std::vector<std::string>& additional_var_names,\n    const std::vector<std::shared_ptr<one::Tensor>>& additional_var_tensors) {\n  CHECK_EQ_OR_RETURN(additional_var_names.size(), additional_var_tensors.size())\n      << Error::RuntimeError()\n      << \"Number of additional variable names and tensors mismatch. \"\n         \"Size of variable names: \"\n      << additional_var_names.size() << \", size of tensors: \" << additional_var_tensors.size();\n  CHECK_OR_RETURN(additional_variable_op_tobe_loaded_name2tensor_.empty())\n      << Error::RuntimeError()\n      << \"The additional variables (states in Optimizer or LRScheduler) of nn.Graph \" << name_\n      << \" are registered repeatedly.\";\n  FOR_RANGE(size_t, i, 0, additional_var_names.size()) {\n    CHECK_OR_RETURN(additional_variable_op_tobe_loaded_name2tensor_\n                        .emplace(JUST(VectorAt(additional_var_names, i)),\n                                 JUST(VectorAt(additional_var_tensors, i)))\n                        .second)\n        << Error::RuntimeError() << \"Duplicate variable name: \" << additional_var_names[i];\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NNGraph::RegisterInputOpNamesAndTensors(\n    const std::vector<std::string>& inputs_op_names,\n    const std::vector<std::shared_ptr<one::Tensor>>& input_tensors) {\n  CHECK_EQ_OR_RETURN(inputs_op_names.size(), input_tensors.size())\n      << Error::RuntimeError()\n      << \"Number of input op names and tensors mismatch. 
\"\n         \"Size of op names: \"\n      << inputs_op_names.size() << \", size of tensors: \" << input_tensors.size();\n  CHECK_OR_RETURN(inputs_op_names_.empty())\n      << Error::RuntimeError() << \"The input tensors of nn.Graph \" << name_\n      << \" are registered repeatedly.\";\n  CHECK_OR_RETURN(input_tensors_valid_.empty())\n      << Error::RuntimeError() << \"The input tensors of nn.Graph \" << name_\n      << \" are registered repeatedly.\";\n  CHECK_OR_RETURN(inputs_tensor_meta_str_.empty())\n      << Error::RuntimeError() << \"The input tensors of nn.Graph \" << name_\n      << \" are registered repeatedly.\";\n  inputs_op_names_.assign(inputs_op_names.begin(), inputs_op_names.end());\n  input_tensors_valid_.reserve(input_tensors.size());\n  inputs_tensor_meta_str_.reserve(input_tensors.size());\n  for (const auto& input_tensor : input_tensors) {\n    input_tensors_valid_.emplace_back(JUST(GetTensorValidInCurRank(input_tensor)));\n    inputs_tensor_meta_str_.emplace_back(*JUST(GetTensorMetaString(input_tensor)));\n  }\n  CHECK_EQ_OR_RETURN(input_tensors_valid_.size(), input_tensors.size());  // NOLINE\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NNGraph::RegisterOutputOpNamesAndTensors(\n    const std::vector<std::string>& outputs_op_names,\n    const std::vector<std::shared_ptr<one::Tensor>>& output_tensors) {\n  CHECK_EQ_OR_RETURN(outputs_op_names.size(), output_tensors.size())\n      << \"Number of output op names and tensors mismatch \"\n         \"Size of op names: \"\n      << outputs_op_names.size() << \", size of tensors: \" << output_tensors.size();\n  CHECK_OR_RETURN(outputs_op_names_.empty())\n      << Error::RuntimeError() << \"The output tensors of nn.Graph \" << name_\n      << \" are registered repeatedly.\";\n  CHECK_OR_RETURN(output_tensors_valid_.empty())\n      << Error::RuntimeError() << \"The output tensors of nn.Graph \" << name_\n      << \" are registered repeatedly.\";\n  CHECK_OR_RETURN(outputs_tensor_meta_str_.empty())\n      << Error::RuntimeError() << \"The output tensors of nn.Graph \" << name_\n      << \" are registered repeatedly.\";\n  outputs_op_names_.assign(outputs_op_names.begin(), outputs_op_names.end());\n  output_tensors_valid_.reserve(output_tensors.size());\n  outputs_tensor_meta_str_.reserve(output_tensors.size());\n  for (const auto& output_tensor : output_tensors) {\n    output_tensors_valid_.emplace_back(JUST(GetTensorValidInCurRank(output_tensor)));\n    outputs_tensor_meta_str_.emplace_back(*JUST(GetTensorMetaString(output_tensor)));\n  }\n  CHECK_EQ_OR_RETURN(output_tensors_valid_.size(), output_tensors.size());  // NOLINT\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NNGraph::RegisterVariableOpNamesAndTensors(\n    const std::vector<std::string>& variable_op_names,\n    const std::vector<std::shared_ptr<one::Tensor>>& variable_tensors) {\n  JUST(vm::CurrentRankSync());\n  CHECK_EQ_OR_RETURN(variable_op_names.size(), variable_tensors.size())\n      << \"Number of variable names and tensors mismatch. 
\"\n         \"Size of variable names: \"\n      << variable_op_names.size() << \", size of tensors: \" << variable_tensors.size();\n  CHECK_ISNULL_OR_RETURN(variable_op_blobs_);\n  variable_op_blobs_ = std::make_shared<vm::EagerBlobObjectList>();\n  JUST(MakeEagerBlobObjectList(variable_op_blobs_.get(), variable_tensors));\n  for (int32_t i = 0; i < variable_op_names.size(); ++i) {\n    const std::shared_ptr<one::Tensor>& var = variable_tensors[i];\n    CHECK_OR_RETURN(var->is_eager())\n        << Error::InvalidValueError() << \"Tensor variable to register in nn.Graph should be eager\";\n    const std::string& var_name = variable_op_names.at(i);\n    CHECK_OR_RETURN(!var_name.empty()) << Error::InvalidValueError() << \"Empty variable name\";\n    CHECK_OR_RETURN(variable_op_name2tensor_.emplace(var_name, var).second)\n        << Error::RuntimeError() << \"Duplicate variable name: \" << var_name;\n    CHECK_OR_RETURN(variable_op_names_.insert(var_name).second)\n        << Error::RuntimeError() << \"Duplicate variable name: \" << var_name;\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NNGraph::RegisterFreeEagerTensorsToVariableOpNames() {\n  JUST(vm::CurrentRankSync());\n  const auto& free_eager_tensors = session_ctx_->GetFreeEagerTensorNamePairByGraphName(name_);\n  for (const auto& pair : free_eager_tensors) {\n    const std::string& var_name = pair.first;\n    const std::shared_ptr<one::Tensor>& var = pair.second;\n    CHECK_OR_RETURN(var->is_eager())\n        << Error::RuntimeError() << \"Free tensor variable to register in nn.Graph should be eager\";\n    CHECK_OR_RETURN(!var_name.empty()) << Error::RuntimeError() << \"Empty variable name\";\n    CHECK_OR_RETURN(variable_op_name2tensor_.emplace(var_name, var).second)\n        << Error::RuntimeError() << \"Duplicate variable name: \" << var_name;\n    CHECK_OR_RETURN(additional_variable_op_name_.insert(var_name).second)\n        << Error::RuntimeError() << \"Duplicate variable name: \" << var_name;\n    CHECK_OR_RETURN(variable_op_names_.insert(var_name).second)\n        << Error::RuntimeError() << \"Duplicate variable name: \" << var_name;\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<std::vector<std::string>> NNGraph::GetAdditionalVarOpNames() const {\n  std::vector<std::string> names;\n  for (const auto& iter : additional_variable_op_name_) { names.push_back(iter); }\n  return names;\n}\n\nMaybe<std::vector<std::shared_ptr<one::Tensor>>> NNGraph::GetAdditionalVarOpTensors() const {\n  std::vector<std::shared_ptr<one::Tensor>> tensors;\n  for (const auto& iter : additional_variable_op_name_) {\n    auto find_iter = variable_op_name2tensor_.find(iter);\n    CHECK_OR_RETURN(find_iter != variable_op_name2tensor_.end())\n        << Error::RuntimeError() << \"Additional variable op name \" << iter << \" not found.\";\n    tensors.push_back(find_iter->second);\n  }\n  return tensors;\n}\n\nMaybe<void> NNGraph::RegisterNewVariableOpInJobPass() {\n  OpGraph op_graph(job_);\n  JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {\n    if (op_node->op().op_conf().has_variable_conf() == false) { return Maybe<void>::Ok(); }\n    const Operator& variable_op = op_node->op();\n    const VariableOpConf& var_conf = variable_op.op_conf().variable_conf();\n    const std::string& var_name = variable_op.op_name();\n    CHECK_OR_RETURN(var_conf.has_initializer())\n        << Error::RuntimeError() << \"nn.Graph ONLY support variable op with initializer conf.\";\n    if (var_conf.initializer().has_constant_conf()\n        || 
var_conf.initializer().has_constant_int_conf() /* variable ops inserted by system */) {\n      CHECK_OR_RETURN(variable_op_names_.insert(var_name).second)\n          << Error::RuntimeError() << \"Variable_op_name: \" << var_name\n          << \" has been added in nn.Graph: \" << name_;\n      CHECK_OR_RETURN(\n          variable_op_name2tensor_.insert({var_name, std::shared_ptr<one::Tensor>()}).second)\n          << Error::RuntimeError() << \"Variable Tensor with op_name: \" << var_name\n          << \" has been added in nn.Graph: \" << name_;\n      CHECK_OR_RETURN(additional_variable_op_name_.insert(var_name).second)\n          << Error::RuntimeError() << \"Variable Tensor with op_name: \" << var_name\n          << \" has been added in nn.Graph: \" << name_;\n    } else /* variable ops from user code */ {\n      CHECK_OR_RETURN(var_conf.initializer().has_empty_conf())\n          << Error::RuntimeError() << \"nn.Graph ONLY supports variable_op with empty conf, \"\n          << \"because the variable is initialized by an eager tensor. \"\n          << \"The erroneous variable conf is: \" << variable_op.op_conf().DebugString()\n          << \" in nn.Graph \" << name_;\n      CHECK_OR_RETURN(variable_op_names_.find(var_name) != variable_op_names_.end())\n          << Error::RuntimeError() << var_name\n          << \" must be a variable created in nn.Graph: \" << name_;\n    }\n    return Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NNGraph::DeleteOutdatedVariableInVariableTensorMgr() {\n  const auto& var_get_func = [&]() -> Maybe<std::set<std::string>> {\n    std::set<std::string> variable_names_;\n    OpGraph op_graph(job_);\n    JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {\n      if (op_node->op().op_conf().has_variable_conf() == false) { return Maybe<void>::Ok(); }\n      variable_names_.insert(op_node->op().op_name());\n      return Maybe<void>::Ok();\n    }));\n    return variable_names_;\n  };\n  std::set<std::string> variable_names = *JUST(var_get_func());\n\n  auto mgr = Singleton<VariableTensorMgr>::Get();\n  for (auto& name : mgr->DumpNames()) {\n    if (variable_names.find(name) == variable_names.end()) { mgr->Delete(name); }\n  }\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\n// A templated function that broadcasts data from the master process to worker processes in a\n// multi-threaded manner. 
Returns push/pull keys only in the master process.\ntemplate<typename X, typename Y>\nstd::set<std::string> MultiThreadBroadcastFromMasterToWorkers(size_t world_size,\n                                                              const std::string& prefix,\n                                                              const X& master_data,\n                                                              Y* worker_data) {\n  const size_t thread_num = ThreadLocalEnvInteger<ONEFLOW_LAZY_COMPILE_RPC_THREAD_NUM>();\n  const size_t split_num = std::sqrt(world_size);\n  BalancedSplitter bs(world_size, split_num);\n  std::set<std::string> keys;\n  if (GlobalProcessCtx::IsThisProcessMaster()) {\n    std::mutex mtx4keys;\n    std::string data;\n    master_data.SerializeToString(&data);\n    MultiThreadLoop(\n        split_num,\n        [&](int i) {\n          std::string key = prefix + std::to_string(i);\n          Singleton<CtrlClient>::Get()->PushKV(key, data);\n          std::lock_guard<std::mutex> lock(mtx4keys);\n          CHECK(keys.insert(key).second);\n        },\n        thread_num);\n  } else {\n    const int64_t bs_index = bs.GetRangeIndexForVal(GlobalProcessCtx::Rank());\n    std::string key = prefix + std::to_string(bs_index);\n    Singleton<CtrlClient>::Get()->PullKV(key, worker_data);\n  }\n  return keys;\n}\n\n// A templated function that pushes data from the master process to each worker process using the\n// control client. The function takes as input a prefix for the key used to store the data in the\n// control client, a pointer to the data to be pushed, and a callable object PrepareEach that\n// preprocesses the worker's data. Returns push/pull keys only in the master process.\ntemplate<typename T, typename PrepareEachT>\nstd::set<std::string> MultiThreadPushFromMasterToWorkers(const std::string& prefix, T* data,\n                                                         const PrepareEachT& PrepareEach) {\n  const size_t thread_num = ThreadLocalEnvInteger<ONEFLOW_LAZY_COMPILE_RPC_THREAD_NUM>();\n  constexpr int kWorkerStartRank = 1;\n  std::set<std::string> keys{};\n  if (GlobalProcessCtx::IsThisProcessMaster()) {\n    std::mutex mtx4keys;\n    MultiThreadLoop(\n        GlobalProcessCtx::WorldSize(),\n        [&](int i) {\n          if (i < kWorkerStartRank) { return; }\n          T worker_data;\n          std::string key = prefix + std::to_string(i);\n          PrepareEach(&worker_data, i);\n          Singleton<CtrlClient>::Get()->PushKV(key, worker_data);\n          std::lock_guard<std::mutex> lock(mtx4keys);\n          CHECK(keys.emplace(key).second) << \"redundant push key: \" << key;\n        },\n        thread_num);\n  } else {\n    Singleton<CtrlClient>::Get()->PullKV(prefix + std::to_string(GlobalProcessCtx::Rank()), data);\n  }\n  return keys;\n}\n\nvoid DumpCalculationPassName(Job* job) {\n  for (int i = 0; i < job->net().op_size(); ++i) {\n    auto* op_conf = job->mutable_net()->mutable_op(i);\n    if (op_conf->has_scope_symbol_id()) {\n      const auto& scope = Singleton<symbol::Storage<Scope>>::Get()->Get(op_conf->scope_symbol_id());\n      op_conf->set_calculation_pass_name(scope.scope_proto().calculation_pass_name());\n    }\n  }\n}\n\n}  // namespace\n\n// The main logic of separation plan compilation. Each rank (process) compiles its related task\n// nodes. 
This can reduce plan compile time and avoid transporting a large plan protobuf.\n// When the master compiles the full plan, some plan protos are much larger than 1GB, but protobuf has a\n// 2GB limitation and large files are slow to transport. So we must do separation plan compilation when\n// the total rank num is large.\n// Separation plan compilation is done by:\n//   a. Master broadcasts the job (or logical graph) to all workers, making all ranks use the same job.\n//   b. Master compiles the BoxingTaskGraph and broadcasts it to all workers. BoxingTaskGraph needs to be\n//      built on the master rank.\n//   c. Each rank compiles its related task nodes with RankCompiler. RankCompiler compiles with the\n//      BoxingTaskGraph and the job.\nMaybe<void> NNGraph::MasterAndWorkerRanksCompile() {\n  // Separation compile mode only works with nccl-use-compute-stream and logical chain enabled.\n  CHECK_OR_RETURN(EnableLogicalChain())\n      << Error::RuntimeError()\n      << \"nn.Graph separate compilation needs to work with logical chain enabled.\";\n  // Note that nccl-use-compute-stream mode has no need to generate a CollectiveBoxingPlan.\n  CHECK_OR_RETURN((Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()))\n      << Error::RuntimeError()\n      << \"nn.Graph separate compilation needs to work with nccl using compute stream enabled.\";\n\n  std::set<std::string> push_pull_keys{};\n  const auto& MergeCommKeys = [&](std::set<std::string>&& keys) {\n    push_pull_keys.insert(keys.begin(), keys.end());\n  };\n  if (GlobalProcessCtx::IsThisProcessMaster()) { DumpCalculationPassName(&job_); }\n\n  // a. Master broadcasts the job (or logical graph) to all workers, making all ranks use the same job.\n  const size_t world_size = GlobalProcessCtx::WorldSize();\n  MergeCommKeys(MultiThreadBroadcastFromMasterToWorkers(\n      world_size, name_ + std::string(__FUNCTION__) + \"_job\", job_, &job_));\n  OpGraphSingletonGuard op_graph_guard(job_);\n  size_t rank = GlobalProcessCtx::Rank();\n\n  // b. Master compiles the BoxingTaskGraph and broadcasts it to all workers. BoxingTaskGraph needs to be\n  //    built on the master rank.\n  auto boxing_task_graph_proto = std::make_shared<BoxingTaskGraphProto>();\n  std::shared_ptr<BoxingTaskGraph> boxing_task_graph;\n  if (GlobalProcessCtx::IsThisProcessMaster()) {\n    const auto& ParallelLoop = [](size_t work_num, const std::function<void(size_t)>& Work) {\n      MultiThreadLoop(work_num, Work, -1);\n    };\n    boxing_task_graph = JUST(BoxingTaskGraph::New(ParallelLoop));\n    boxing_task_graph->ToProto([](TaskNode*) { return true; }, boxing_task_graph_proto.get());\n    if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {\n      TeePersistentLogStream::Create(\"boxing_task_\" + name_ + \"_plan\" + std::to_string(0))\n          ->Write(*boxing_task_graph_proto);\n    }\n  }\n  const auto& PrepareWorkerBoxingTaskGraphProto = [&](BoxingTaskGraphProto* proto, int64_t i) {\n    boxing_task_graph->ToProto(\n        [i](TaskNode* task_node) { return BoxingTaskGraph::SelectTaskNodeByRank(task_node, i); },\n        proto);\n    if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {\n      TeePersistentLogStream::Create(\"boxing_task_\" + name_ + \"_plan\" + std::to_string(i))\n          ->Write(*proto);\n    }\n  };\n  MergeCommKeys(MultiThreadPushFromMasterToWorkers(\n      name_ + std::string(__FUNCTION__) + \"_boxing_task_graph\", boxing_task_graph_proto.get(),\n      PrepareWorkerBoxingTaskGraphProto));\n\n  // c. Each rank compiles its related task nodes with RankCompiler. 
RankCompiler compiles with the\n//    BoxingTaskGraph and the job.\n  auto* plan = &plan_;\n  CHECK_JUST(RankCompiler(boxing_task_graph_proto, rank).Compile(variable_op_names_, &job_, plan));\n  PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan(plan, variable_op_names_);\n\n  if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {\n    TeePersistentLogStream::Create(\"job_\" + name_ + \"_plan\" + std::to_string(rank))->Write(*plan);\n    PlanUtil::ToDotFile(*plan, \"job_\" + name_ + \"_plan_\" + std::to_string(rank) + \".dot\");\n  }\n  PlanUtil::GenRegisterHint(plan);\n  PlanUtil::DumpCtrlRegstInfoToPlan(plan);\n  PlanUtil::PlanMemoryLog(&plan_, name_);\n  if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {\n    PlanUtil::GenLightPlan(&plan_, name_, rank);\n  }\n  OF_SESSION_BARRIER();\n  for (const auto& k : push_pull_keys) { Singleton<CtrlClient>::Get()->ClearKV(k); }\n  OF_SESSION_BARRIER();\n  return Maybe<void>::Ok();\n}\n\n// Master compiles the full plan.\nMaybe<void> NNGraph::NaiveCompile() {\n  auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true);\n  if (GlobalProcessCtx::IsThisProcessMaster()) {\n    auto sub_compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true);\n    // TODO(chengcheng): new memory reused by chunk\n    Compiler().Compile(&job_, &plan_);\n    sub_compile_tc->Count(\"[PlanCompile]\" + name_ + \" GenerateBasePlan\", 1);\n    PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan(&plan_, variable_op_names_);\n    sub_compile_tc->Count(\"[PlanCompile]\" + name_ + \" GenMemBlockAndChunk\", 1);\n    PlanUtil::GenRegisterHint(&plan_);\n    sub_compile_tc->Count(\"[PlanCompile]\" + name_ + \" GenRegisterHint\", 1);\n    // TODO(chengcheng): test collective boxing for multi-job.\n    PlanUtil::GenCollectiveBoxingPlan(&job_, &plan_);\n    // PlanUtil::SetForceInplaceMemBlock(&plan_); NOTE(chengcheng): only for ssp.\n    sub_compile_tc->Count(\"[PlanCompile]\" + name_ + \" GenCollectiveBoxingPlan\", 1);\n    PlanUtil::DumpCtrlRegstInfoToPlan(&plan_);\n    sub_compile_tc->Count(\"[PlanCompile]\" + name_ + \" DumpCtrlRegstInfoToPlan\", 1);\n    PlanUtil::PlanMemoryLog(&plan_, name_);\n    if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {\n      PlanUtil::GenLightPlan(&plan_, name_);\n    }\n    sub_compile_tc->Count(\"[GraphCompile]\" + name_ + \" GenMemAndLightPlanLog\", 1, true);\n  }\n  compile_tc->Count(\"[GraphCompile]\" + name_ + \" CompilePlan\", 0);\n  if (GlobalProcessCtx::WorldSize() > 1) {\n    std::string plan_name = \"plan:\" + job_name();\n    if (GlobalProcessCtx::IsThisProcessMaster()) {\n      // TODO(chengcheng): split plan for each rank.\n      Singleton<CtrlClient>::Get()->PushKV(plan_name, plan_);\n    } else {\n      Singleton<CtrlClient>::Get()->PullKV(plan_name, &plan_);\n    }\n    OF_SESSION_BARRIER();\n    if (GlobalProcessCtx::IsThisProcessMaster()) {\n      Singleton<CtrlClient>::Get()->ClearKV(plan_name);\n    }\n  }\n  compile_tc->Count(\"[GraphCompile]\" + name_ + \" SyncPlan\", 0, true);\n  return Maybe<void>::Ok();\n}\n\n// There are four plan compilation modes, with the first mode \"master compilation\" (default) and the\n// fourth mode \"rank separation compilation\" being the ones actually used.\nMaybe<void> NNGraph::CompilePlanForRuntime() {\n  // A global variable to get graph configurations.\n  auto current_graph_config = std::make_unique<GlobalJobDescScope>(job_.job_conf(), job_id());\n  auto compile_tc = 
std::make_unique<CostCounter<std::chrono::seconds>>(true, true);\n  typedef Maybe<void> (NNGraph::*CompileMethodT)();\n  struct GetCompileMethod final : public CompileModeVisitor<GetCompileMethod> {\n    static CompileMethodT VisitNaive() {\n      // Master rank compiles the full plan.\n      return &NNGraph::NaiveCompile;\n    }\n    static CompileMethodT VisitRankPerProcess() {\n      // Multiple processes (ranks) run separation compile.\n      return &NNGraph::MasterAndWorkerRanksCompile;\n    }\n    static CompileMethodT VisitInValid() { return nullptr; }\n  };\n  JUST((this->*GetCompileMethod::Visit(JUST(CurrentCompileMode())))());\n  compile_tc->Count(\"[GraphCompile]\" + name_ + \" CompileAndSyncPlan\", 0);\n  PlanUtil::PopulateOpAttribute(&plan_, plan_.job_id2op_attribute_ref_table());\n  compile_tc->Count(\"[GraphCompile]\" + name_ + \" PopulateOpAttribute\", 0);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NNGraph::InitRuntime() {\n  CHECK_OR_RETURN(!runtime_inited_)\n      << Error::RuntimeError() << \"nn.Graph runtime is already initialized\";\n\n  auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true);\n  NewRuntimeBuffers();\n\n  JUST(GetVariableRealBlobAfterSyncPlan());\n\n  // NOTE(strint): Do memory shrink to free cached memory in eager VM before graph runtime init.\n  JUST(vm::CurrentRankSync());\n  auto* vm = JUST(SingletonMaybe<VirtualMachine>());\n  JUST(vm->ShrinkAllMem());\n\n  if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {\n    auto cur_rank = GlobalProcessCtx::Rank();\n    auto plan_name = \"job_\" + name_ + \"_plan\";\n    if (JUST(CurrentCompileMode()) != CompileMode::kNaive) {\n      plan_name += std::to_string(cur_rank);\n    }\n    if (cur_rank == 0 || JUST(CurrentCompileMode()) != CompileMode::kNaive) {\n      TeePersistentLogStream::Create(plan_name)->Write(plan_);\n      PlanUtil::ToDotFile(plan_, plan_name + \".dot\");\n    }\n  }\n\n  runtime_.reset(new Runtime(plan_, variable_op_name2eager_blob_object_));\n  compile_tc->Count(\"[GraphCompile]\" + name_ + \" InitRuntime\", 0, true);\n  JUST(LogProgress(\"[GraphCompile]\" + name_ + \" Done\", true));\n\n  runtime_inited_ = true;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NNGraph::AlignStatesAfterLogicalGraphCompile() {\n  auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true);\n  JUST(RegisterFreeEagerTensorsToVariableOpNames());\n  JUST(RegisterNewVariableOpInJobPass());\n  JUST(DeleteOutdatedVariableInVariableTensorMgr());\n  // NOTE(chengcheng): TensorNameScope needs to be cleared after current graph is built.\n  one::TensorNameScope::Global()->Clear();\n  // Clear all backward pass scope\n  ClearAllBackwardPassScope();\n  compile_tc->Count(\"[GraphCompile]\" + name_ + \" AlignStates\", 0);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NNGraph::CompleteLogicalGraphForRuntime() {\n  auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true);\n  // A global variable to get graph configurations.\n  auto current_graph_config = std::make_unique<GlobalJobDescScope>(job_.job_conf(), job_id());\n  // NOTE(chengcheng): run the job completer on each rank.\n  JUST(JobCompleter::Complete(&job_));\n  compile_tc->Count(\"[GraphCompile]\" + name_ + \" CompleteJob\", 0);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NNGraph::BuildWithNewInputFromSharedGraph(\n    const std::vector<std::string>& shared_inputs_op_names,\n    const std::vector<std::shared_ptr<one::Tensor>>& new_input_tensors,\n    const 
std::vector<std::string>& shared_op_names_from_ordered_original_graph,\n    const std::string& new_serialized_original_job) {\n  CHECK_EQ_OR_RETURN(shared_inputs_op_names.size(), new_input_tensors.size());  // NOLINT\n  auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true);\n  // Register inputs.\n  JUST(RegisterInputOpNamesAndTensors(shared_inputs_op_names, new_input_tensors));\n\n  // Generate new input tensor getter.\n  HashMap<std::string, std::shared_ptr<one::Tensor>> input_name2tensor;\n  for (int64_t idx = 0; idx < shared_inputs_op_names.size(); ++idx) {\n    input_name2tensor.emplace(shared_inputs_op_names[idx], new_input_tensors[idx]);\n  }\n  const auto& InputTensor4Name =\n      [&input_name2tensor](const std::string& op_name) -> Maybe<std::shared_ptr<one::Tensor>> {\n    auto iter = input_name2tensor.find(op_name);\n    CHECK_OR_RETURN(iter != input_name2tensor.end())\n        << \"Can't find input tensor of \" << op_name << \".\";\n    return iter->second;\n  };\n\n  // Generate new OperatorConf getter.\n  Job new_build_original_job;\n  CHECK_OR_RETURN(new_build_original_job.ParseFromString(new_serialized_original_job))\n      << \"nn.Graph \" << name_ << \" failed to parse the job proto of the new build graph.\";\n  CHECK_EQ_OR_RETURN(new_build_original_job.net().op_size(),\n                     shared_op_names_from_ordered_original_graph.size())\n      << \"nn.Graph \" << name_\n      << \" new_build_original_job op size and shared_op_names_from_ordered_original_graph \"\n      << \"size are not equal.\";\n  HashMap<std::string, const OperatorConf*> shared_op_name2_new_op;\n  for (int64_t op_idx = 0; op_idx < shared_op_names_from_ordered_original_graph.size(); ++op_idx) {\n    // Assume that the new graph and the shared graph from nn.Graph.build have the same op order.\n    const auto& op = new_build_original_job.mutable_net()->mutable_op()->at(op_idx);\n    shared_op_name2_new_op.emplace(shared_op_names_from_ordered_original_graph[op_idx], &op);\n  }\n  const auto& NewOp4SharedOpName =\n      [&shared_op_name2_new_op](const std::string& shared_op_name) -> Maybe<const OperatorConf*> {\n    auto iter = shared_op_name2_new_op.find(shared_op_name);\n    if (iter == shared_op_name2_new_op.end()) {\n      VLOG(1) << \"Can't find new traced operator conf for op \" << shared_op_name\n              << \" in the shared graph from the base graph. 
This op is not shared between graphs.\";\n      return nullptr;\n    }\n    return iter->second;\n  };\n\n  // A global variable to get graph configurations.\n  auto current_graph_config = std::make_unique<GlobalJobDescScope>(job_.job_conf(), job_id());\n  // NOTE(chengcheng): run the job completer on each rank.\n  JUST(JobCompleter::UpdateSharedGraphForNewInput(&job_, InputTensor4Name, NewOp4SharedOpName));\n  compile_tc->Count(\"[GraphCompile]\" + name_ + \" CompleteJob\", 0);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NNGraph::CompileAndInitRuntime() {\n  JUST(AlignStatesAfterLogicalGraphCompile());\n  JUST(CompleteLogicalGraphForRuntime());\n  JUST(CompilePlanForRuntime());\n  JUST(InitRuntime());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NNGraph::GetVariableRealBlobAfterSyncPlan() {\n  CHECK_OR_RETURN(variable_op_name2eager_blob_object_.empty())\n      << Error::RuntimeError() << kOfBugIssueUploadPrompt;\n  JUST(vm::CurrentRankSync());\n  // Create or rebuild each variable, then get its real blob.\n  for (const std::string& var_name : variable_op_names_) {\n    auto iter = variable_op_name2tensor_.find(var_name);\n    CHECK_OR_RETURN(iter != variable_op_name2tensor_.end())\n        << Error::RuntimeError() << \"variable op name \" << var_name << \" not found.\";\n    std::shared_ptr<one::Tensor> tensor = iter->second;\n    vm::EagerBlobObject* var_blob = nullptr;\n    if (plan_.job_id2op_attribute_ref_table().at(job_id_).op_name2op_attribute().find(var_name)\n        == plan_.job_id2op_attribute_ref_table().at(job_id_).op_name2op_attribute().end()) {\n      // Deal with variable tensors not used in the nn.Graph build.\n      CHECK_OR_RETURN(tensor != NULL)\n          << Error::RuntimeError() << \"The tensor of \" << var_name\n          << \" is not used in the job, so it is not created by nn.Graph and cannot be NULL.\";\n      if (tensor->is_global()) {\n        const std::shared_ptr<one::LocalTensor> local_var = JUST(tensor->cur_rank_phy_tensor());\n        var_blob = JUST(local_var->eager_blob_object()).get();\n      } else {\n        var_blob = JUST(tensor->eager_blob_object()).get();\n      }\n    } else if (/*is_null=*/!tensor) {\n      // Deal with tensors which are not in the nn.Module.\n      // We call these tensors additional variables.\n      const auto& op_attribute =\n          plan_.job_id2op_attribute_ref_table().at(job_id_).op_name2op_attribute().at(var_name);\n      // NOTE(chengcheng): handle constant variables created by job passes.\n      Symbol<ParallelDesc> placement(op_attribute.parallel_conf_signature().op_parallel_conf());\n      NdSbp nd_sbp(NdSbpSignature(op_attribute.nd_sbp_signature()).bn_in_op2nd_sbp().at(\"out\"));\n      const BlobDesc blob_desc(\n          op_attribute.logical_blob_desc_signature().bn_in_op2blob_desc().at(\"out\"));\n      DType dtype(blob_desc.data_type());\n      std::shared_ptr<std::vector<Symbol<SbpParallel>>> sbp_tuple =\n          JUST(GetSbpList(Symbol<NdSbp>(nd_sbp)));\n\n      auto load_tensor_iter = additional_variable_op_tobe_loaded_name2tensor_.find(var_name);\n      if (load_tensor_iter == additional_variable_op_tobe_loaded_name2tensor_.end()) {\n        // Create an additional variable tensor.\n        Scalar value;\n        const VariableOpConf& var_conf = op_attribute.op_conf().variable_conf();\n        if (var_conf.initializer().has_constant_conf()) {\n          value = var_conf.initializer().constant_conf().value();\n        } else if (var_conf.initializer().has_constant_int_conf()) {\n          value = 
var_conf.initializer().constant_int_conf().value();\n        } else {\n          OF_UNIMPLEMENTED();\n        }\n        // NOTE(chengcheng): A new EagerTensor needs LazyMode set to false.\n        auto lazy_mode_disabled_guard = LazyMode::Guard(/*is_enabled*/ false);\n        tensor = JUST(one::functional::GlobalConstant(blob_desc.shape(), value,\n                                                      Symbol<DType>(dtype), placement, *sbp_tuple));\n        JUST(vm::CurrentRankSync());\n        VLOG(2) << \"Lazy nn.Graph name \" << name_ << \" op: \" << op_attribute.op_conf().name()\n                << \" created in JobPass, nn.Graph has created an eager tensor for this variable.\\n\";\n      } else {\n        // Load an additional variable tensor.\n        auto lazy_mode_disabled_guard = LazyMode::Guard(/*is_enabled*/ false);\n        std::vector<Symbol<SbpParallel>> grad_sbp_tuple;\n        // Cast to global from a local or global tensor.\n        bool check_meta = !load_tensor_iter->second->is_global();\n        tensor = JUST(one::functional::ToGlobal(load_tensor_iter->second, placement, *sbp_tuple,\n                                                grad_sbp_tuple, check_meta, /*copy=*/false));\n        JUST(vm::CurrentRankSync());\n        VLOG(2) << \"Lazy nn.Graph name \" << name_ << \" op: \" << op_attribute.op_conf().name()\n                << \" created in JobPass, nn.Graph has loaded the tensor from state dict for this \"\n                   \"variable.\\n\";\n      }\n      // Register\n      JUST(MapAt(variable_op_name2tensor_, var_name)) = tensor;\n      // NOTE(chengcheng): The tensor is held by the session context just to keep it valid for\n      // the graph's lifetime.\n      session_ctx_->StoreFreeEagerTensorWithNameByGraphName(name_, tensor, var_name);\n\n      const std::shared_ptr<one::LocalTensor> local_var = JUST(tensor->cur_rank_phy_tensor());\n      var_blob = JUST(local_var->eager_blob_object()).get();\n    } else if (tensor->is_global()) {\n      // Deal with tensors which need to change sbp.\n      NdSbpSignature var_nd_sbp_signature = NdSbpSignature(plan_.job_id2op_attribute_ref_table()\n                                                               .at(job_id_)\n                                                               .op_name2op_attribute()\n                                                               .at(var_name)\n                                                               .nd_sbp_signature());\n      NdSbp optimized_nd_sbp = var_nd_sbp_signature.bn_in_op2nd_sbp().at(\"out\");\n      // Change the variable tensor's impl with the new sbp when a job pass has changed its sbp.\n      if (*JUST(tensor->nd_sbp()) != optimized_nd_sbp) {\n        VLOG(2) << \"Graph with name \" << name_ << \" variable with name `\" << var_name\n                << \"` changes its sbp from \" << NdSbpToString(*JUST(tensor->nd_sbp())) << \" to \"\n                << NdSbpToString(optimized_nd_sbp) << \" after compile optimization.\";\n        std::vector<Symbol<SbpParallel>> optimized_sbp_parallels;\n        for (int i = 0; i < optimized_nd_sbp.sbp_parallel_size(); ++i) {\n          optimized_sbp_parallels.emplace_back(optimized_nd_sbp.sbp_parallel(i));\n        }\n        {\n          auto lazy_mode_disabled_guard = LazyMode::Guard(/* is_enabled */ false);\n          const auto& new_tensor = JUST(one::functional::ToGlobal(\n              tensor, JUST(tensor->parallel_desc()), optimized_sbp_parallels, {},\n              /* check_meta */ false, /*copy=*/false));\n          
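// e.g. an sbp optimized from split(0) to broadcast was re-cast above; the sync below\n          // makes sure the cast is done before set_data swaps in the new impl.\n          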
JUST(vm::CurrentRankSync());\n          // Use the tensor.set_data interface to make a new TensorImpl replacing the old one.\n          JUST(tensor->set_data(new_tensor));\n        }\n      }\n      const std::shared_ptr<one::LocalTensor> local_var = JUST(tensor->cur_rank_phy_tensor());\n      var_blob = JUST(local_var->eager_blob_object()).get();\n    } else {\n      var_blob = JUST(tensor->eager_blob_object()).get();\n    }\n    CHECK_OR_RETURN(var_blob != nullptr) << Error::RuntimeError() << kOfBugIssueUploadPrompt;\n    CHECK_OR_RETURN(variable_op_name2eager_blob_object_.emplace(var_name, var_blob).second)\n        << Error::RuntimeError() << kOfBugIssueUploadPrompt;\n  }\n  // Initialize or check mem_ptr_for_allocation_computation_pipelining by TouchTensors instruction.\n  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    auto eager_blob_objects = std::make_shared<vm::EagerBlobObjectList>();\n    for (const auto& pair : variable_op_name2eager_blob_object_) {\n      eager_blob_objects->push_back(pair.second->shared_from_this());\n    }\n    return builder->TouchTensors(eager_blob_objects);\n  }));\n  JUST(vm::CurrentRankSync());\n  // Clear after loading additional variables is finished.\n  additional_variable_op_tobe_loaded_name2tensor_.clear();\n  return Maybe<void>::Ok();\n}\n\nvoid NNGraph::NewRuntimeBuffers() {\n  // NOTE(chengcheng):\n  //   1. The BufferSize comes from job_conf.concurrency_width configured by the user (default = 128).\n  //   2. In pipeline parallelism, this value must exceed the number of pipeline stages.\n  size_t concurrency_width = job_.job_conf().concurrency_width();\n  {\n    auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<JobInstance>>>::Get();\n    buffer_mgr->NewBuffer(GetSourceTickBufferName(name_), concurrency_width);\n    buffer_mgr->NewBuffer(GetCallbackNotifierBufferName(name_), concurrency_width);\n  }\n  {\n    auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::Get();\n    buffer_mgr->NewBuffer(GetInputCriticalSectionWaitBufferName(name_), concurrency_width);\n    buffer_mgr->NewBuffer(GetInputCriticalSectionCallbackBufferName(name_), concurrency_width);\n    buffer_mgr->NewBuffer(GetOutputCriticalSectionWaitBufferName(name_), concurrency_width);\n    buffer_mgr->NewBuffer(GetOutputCriticalSectionCallbackBufferName(name_), concurrency_width);\n    for (const std::string& input_op_name : inputs_op_names_) {\n      buffer_mgr->NewBuffer(GetInputBufferName(name_, input_op_name), concurrency_width);\n    }\n    for (const std::string& output_op_name : outputs_op_names_) {\n      buffer_mgr->NewBuffer(GetOutputBufferName(name_, output_op_name), concurrency_width);\n    }\n  }\n}\n\nvoid NNGraph::CloseRuntimeBuffers() {\n  if (runtime_inited_) {\n    {\n      auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::Get();\n      for (const std::string& output_op_name : outputs_op_names_) {\n        buffer_mgr->Get(GetOutputBufferName(name_, output_op_name))->Close();\n      }\n      for (const std::string& input_op_name : inputs_op_names_) {\n        buffer_mgr->Get(GetInputBufferName(name_, input_op_name))->Close();\n      }\n      buffer_mgr->Get(GetOutputCriticalSectionCallbackBufferName(name_))->Close();\n      buffer_mgr->Get(GetOutputCriticalSectionWaitBufferName(name_))->Close();\n      buffer_mgr->Get(GetInputCriticalSectionCallbackBufferName(name_))->Close();\n      buffer_mgr->Get(GetInputCriticalSectionWaitBufferName(name_))->Close();\n    }\n    {\n      
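// JobInstance buffers are closed last, reversing the creation order in NewRuntimeBuffers().\n      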
auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<JobInstance>>>::Get();\n      buffer_mgr->Get(GetCallbackNotifierBufferName(name_))->Close();\n      buffer_mgr->Get(GetSourceTickBufferName(name_))->Close();\n    }\n  }\n}\n\nMaybe<void> RunLazyNNGraph(const one::TensorTuple& inputs, const one::TensorTuple& outputs,\n                           const std::shared_ptr<NNGraph>& nn_graph) {\n  CHECK_EQ_OR_RETURN(inputs.size(), nn_graph->inputs_op_names().size())\n      << Error::RuntimeError()\n      << \"Number of inputs and NNGraph::inputs_op_names mismatch. \"\n         \"Size of inputs: \"\n      << inputs.size()\n      << \", size of NNGraph::inputs_op_names: \" << nn_graph->inputs_op_names().size();\n  CHECK_EQ_OR_RETURN(outputs.size(), nn_graph->outputs_op_names().size())\n      << Error::RuntimeError()\n      << \"Number of outputs and NNGraph::outputs_op_names mismatch. \"\n         \"Size of outputs: \"\n      << outputs.size()\n      << \", size of NNGraph::outputs_op_names: \" << nn_graph->outputs_op_names().size();\n  // NOTE(chengcheng):\n  //   parameters are not used in LaunchLazyJobInstruction;\n  //   the parameters arg is all variable tensors held by nn.Graph,\n  //   but NNGraph::variable_op_size may count FreeEagerTensor as a special variable op.\n  CHECK_LE_OR_RETURN(nn_graph->var_blobs()->size(), nn_graph->variable_op_size())\n      << Error::RuntimeError() << \"Parameter size should be less than or equal to variable size\";\n  for (int i = 0; i < inputs.size(); ++i) {\n    // TODO(chengcheng, liufengwei):\n    //   use TensorMeta.to_string and equal.\n    std::string tensor_meta_str = *JUST(GetTensorMetaString(inputs.at(i)));\n    const std::string& static_meta_str = nn_graph->inputs_tensor_meta_str().at(i);\n    CHECK_OR_RETURN(static_meta_str == tensor_meta_str)\n        << Error::RuntimeError()\n        << \"nn.Graph ONLY accepts static input tensor meta, please check whether your input \"\n        << \"tensor meta at each step is the same as that of the first graph call.\\nThe expected \"\n        << \"tensor meta is: \" << static_meta_str\n        << \", but the actual tensor meta is: \" << tensor_meta_str << \". 
The input index is \" << i\n        << \".\";\n  }\n  for (int i = 0; i < outputs.size(); ++i) {\n    CHECK_OR_RETURN(nn_graph->outputs_tensor_meta_str().at(i)\n                    == *JUST(GetTensorMetaString(outputs.at(i))))\n        << Error::RuntimeError() << \"Output tensor meta string mismatch\";\n  }\n  vm::EagerBlobObjectList input_blobs;\n  vm::EagerBlobObjectList output_blobs;\n  JUST(MakeEagerBlobObjectList(&input_blobs, inputs));\n  JUST(MakeEagerBlobObjectList(&output_blobs, outputs));\n  const auto& input_blob_list_ptr =\n      std::make_shared<const vm::EagerBlobObjectList>(std::move(input_blobs));\n  const auto& output_blob_list_ptr =\n      std::make_shared<const vm::EagerBlobObjectList>(std::move(output_blobs));\n  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    return builder->LaunchLazyJob(input_blob_list_ptr, output_blob_list_ptr, nn_graph->var_blobs(),\n                                  nn_graph);\n  }));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SoftSyncNNGraphBuffers(const one::TensorTuple& buffers,\n                                   const std::shared_ptr<NNGraph>& nn_graph) {\n  const auto& eager_blob_objects = std::make_shared<vm::EagerBlobObjectList>();\n  JUST(MakeEagerBlobObjectList(eager_blob_objects.get(), buffers));\n  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    return builder->SoftSyncNNGraphBuffers(eager_blob_objects, nn_graph);\n  }));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/nn_graph.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_H_\n#define ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/nn_graph_if.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/multi_client_session_context.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/core/job/runtime.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n\nnamespace oneflow {\n\nclass Blob;\n\nclass NNGraph final : public NNGraphIf {\n public:\n  explicit NNGraph(const std::string& name, const Job& job, int64_t job_id,\n                   const std::shared_ptr<MultiClientSessionContext>& session_ctx)\n      : name_(name),\n        job_(job),\n        job_id_(job_id),\n        session_ctx_(session_ctx),\n        runtime_inited_(false),\n        is_closed_(false),\n        run_cnt_(0) {}\n  explicit NNGraph(const std::string& name, const Plan& plan, int64_t job_id,\n                   const std::shared_ptr<MultiClientSessionContext>& session_ctx)\n      : name_(name),\n        job_id_(job_id),\n        session_ctx_(session_ctx),\n        plan_(plan),\n        runtime_inited_(false),\n        is_closed_(false),\n        run_cnt_(0) {}\n  OF_DISALLOW_COPY_AND_MOVE(NNGraph);\n  ~NNGraph();\n\n  const std::string& job_name() const override { return name_; }\n  const Job& job() const { return job_; }\n  void restore_job(const Job& job) { job_ = job; }\n  int64_t job_id() const { return job_id_; }\n  void restore_job_id(int64_t job_id) { job_id_ = job_id; }\n  const Plan& plan() const { return plan_; }\n  void restore_plan(const Plan& plan) { plan_ = plan; }\n  const std::vector<std::string>& inputs_op_names() const override;\n  const std::vector<std::string>& outputs_op_names() const override;\n  const std::vector<bool>& inputs_valid() const override;\n  const std::vector<bool>& outputs_valid() const override;\n  const std::vector<std::string>& inputs_tensor_meta_str() const;\n  const std::vector<std::string>& outputs_tensor_meta_str() const;\n  int64_t variable_op_size() const;\n  const std::shared_ptr<vm::EagerBlobObjectList>& var_blobs() const;\n  int64_t run_cnt() const override { return run_cnt_; }\n  void NextRunCnt() override { run_cnt_++; }\n\n  Maybe<void> RegisterAdditionalVarOpNamesAndTensorsToBeLoaded(\n      const std::vector<std::string>& additional_var_names,\n      const std::vector<std::shared_ptr<one::Tensor>>& additional_var_tensors);\n  Maybe<void> RegisterInputOpNamesAndTensors(\n      const std::vector<std::string>& inputs_op_names,\n      const std::vector<std::shared_ptr<one::Tensor>>& input_tensors);\n  Maybe<void> RegisterOutputOpNamesAndTensors(\n      const std::vector<std::string>& outputs_op_names,\n      const 
std::vector<std::shared_ptr<one::Tensor>>& output_tensors);\n  Maybe<void> RegisterVariableOpNamesAndTensors(\n      const std::vector<std::string>& variable_op_names,\n      const std::vector<std::shared_ptr<one::Tensor>>& variable_tensors);\n  Maybe<std::vector<std::string>> GetAdditionalVarOpNames() const;\n  Maybe<std::vector<std::shared_ptr<one::Tensor>>> GetAdditionalVarOpTensors() const;\n  // After logical graph compilation, some state variables should be cleaned or built.\n  Maybe<void> AlignStatesAfterLogicalGraphCompile();\n  // Add special operators into the logical graph for the lazy runtime.\n  Maybe<void> CompleteLogicalGraphForRuntime();\n  // Build graph with new inputs from a completed job of a shared graph.\n  Maybe<void> BuildWithNewInputFromSharedGraph(\n      const std::vector<std::string>& shared_inputs_op_names,\n      const std::vector<std::shared_ptr<one::Tensor>>& new_input_tensors,\n      const std::vector<std::string>& shared_op_names_from_ordered_original_graph,\n      const std::string& new_serialized_original_job);\n  // Generate the execution plan for the OneFlow lazy runtime, which is actor-based.\n  Maybe<void> CompilePlanForRuntime();\n  // Initialize the lazy runtime.\n  Maybe<void> InitRuntime();\n  Maybe<void> CompileAndInitRuntime();\n  Maybe<void> Close();\n  const auto variable_op_name2tensor() const { return variable_op_name2tensor_; }\n  std::vector<std::shared_ptr<one::UserOpExpr>> cached_op_exprs;\n\n private:\n  // Compile the full task graph for all ranks and then broadcast to all ranks.\n  Maybe<void> NaiveCompile();\n  // Each rank compiles its own task graph.\n  Maybe<void> MasterAndWorkerRanksCompile();\n  Maybe<void> RegisterFreeEagerTensorsToVariableOpNames();\n  Maybe<void> RegisterNewVariableOpInJobPass();\n  Maybe<void> DeleteOutdatedVariableInVariableTensorMgr();\n  Maybe<void> GetVariableRealBlobAfterSyncPlan();\n\n  void NewRuntimeBuffers();\n  void CloseRuntimeBuffers();\n\n  std::string name_;\n  Job job_;\n  int64_t job_id_;\n  std::shared_ptr<MultiClientSessionContext> session_ctx_;\n  std::vector<std::string> inputs_op_names_;\n  std::vector<std::string> outputs_op_names_;\n  std::vector<bool> input_tensors_valid_;\n  std::vector<bool> output_tensors_valid_;\n  std::vector<std::string> inputs_tensor_meta_str_;\n  std::vector<std::string> outputs_tensor_meta_str_;\n  HashMap<std::string, std::shared_ptr<one::Tensor>> variable_op_name2tensor_;\n  // Additional variables are variables other than model states, such as states in\n  // optimizers/lr schedulers or free eager tensors.\n  HashSet<std::string> additional_variable_op_name_;\n  // Additional state tensors loaded from the state dict;\n  // they will be loaded into the job after the plan is generated.\n  HashMap<std::string, std::shared_ptr<one::Tensor>>\n      additional_variable_op_tobe_loaded_name2tensor_;\n  HashMap<std::string, vm::EagerBlobObject*> variable_op_name2eager_blob_object_;\n  HashSet<std::string> variable_op_names_;\n  std::shared_ptr<vm::EagerBlobObjectList> variable_op_blobs_;\n  Plan plan_;\n  // TODO(chengcheng): temp impl using runtime now, need reimplement for dynamic multi nn.Graph.\n  std::unique_ptr<Runtime> runtime_;\n  bool runtime_inited_;\n  bool is_closed_;\n  int64_t run_cnt_;\n};\n\nMaybe<void> RunLazyNNGraph(const one::TensorTuple& inputs, const one::TensorTuple& outputs,\n                           const std::shared_ptr<NNGraph>& nn_graph);\n\nMaybe<void> SoftSyncNNGraphBuffers(const one::TensorTuple& buffers,\n                                   const 
std::shared_ptr<NNGraph>& nn_graph);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_H_\n"
  },
  {
    "path": "oneflow/core/framework/nn_graph_if.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_IF_H_\n#define ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_IF_H_\n\n#include <string>\n#include <vector>\n\n#include \"oneflow/core/common/symbol.h\"\n\nnamespace oneflow {\n\nclass Device;\n\nclass NNGraphIf {\n public:\n  virtual ~NNGraphIf() = default;\n\n  virtual const std::string& job_name() const = 0;\n  virtual const std::vector<std::string>& inputs_op_names() const = 0;\n  virtual const std::vector<std::string>& outputs_op_names() const = 0;\n  virtual const std::vector<bool>& inputs_valid() const = 0;\n  virtual const std::vector<bool>& outputs_valid() const = 0;\n  virtual int64_t run_cnt() const = 0;\n  virtual void NextRunCnt() = 0;\n\n protected:\n  NNGraphIf() = default;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_IF_H_\n"
  },
  {
    "path": "oneflow/core/framework/op_builder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_builder.h\"\n\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/framework/attr_value.h\"\n#include \"oneflow/core/framework/attr_value_accessor.h\"\n#include \"oneflow/core/framework/id_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstatic constexpr char PositionalPlaceholderPrefix[] = \"^Placeholder_\";\n\nOpBuilder::OpBuilder(const std::string& op_type_name) {\n  *(proto_.mutable_op_type_name()) = op_type_name;\n  op_name_ = *CHECK_JUST(UniqueStr(op_type_name));\n}\n\nOpBuilder::OpBuilder(const std::string& op_type_name, const std::string& op_name)\n    : op_name_(op_name) {\n  *(proto_.mutable_op_type_name()) = op_type_name;\n}\n\nMaybe<OpBuilder&> OpBuilder::MaybeInput(const std::string& input_name, const int count) {\n  CHECK_GT_OR_RETURN(count, 0);\n  CHECK_EQ_OR_RETURN(proto_.input().count(input_name), 0)\n      << \"The Input \" << input_name << \" has been specified more than once.\";\n  proto_.add_input_order(input_name);\n  auto* input_list = &((*(proto_.mutable_input()))[input_name]);\n  for (int i = 0; i < count; ++i) {\n    const std::string& tensor_name =\n        op_name_ + \"/\" + PositionalPlaceholderPrefix + std::to_string(input_pos_++);\n    input_list->mutable_s()->Add()->assign(tensor_name);\n    indexed_ibns_.emplace_back(input_name + \"_\" + std::to_string(i));\n  }\n  CHECK_EQ_OR_RETURN(proto_.input().size(), proto_.input_order().size());\n  return *this;\n}\n\nOpBuilder& OpBuilder::Input(const std::string& input_name) {\n  return CHECK_JUST(MaybeInput(input_name, 1));\n}\nOpBuilder& OpBuilder::Input(const std::string& input_name, const int count) {\n  return CHECK_JUST(MaybeInput(input_name, count));\n}\n\nMaybe<OpBuilder&> OpBuilder::MaybeOutput(const std::string& output_name, const int count) {\n  CHECK_GT_OR_RETURN(count, 0);\n  CHECK_EQ_OR_RETURN(proto_.output().count(output_name), 0)\n      << \"The output \" << output_name << \" has been specified more than once.\";\n  proto_.add_output_order(output_name);\n  auto* output_list = &((*(proto_.mutable_output()))[output_name]);\n  for (int i = 0; i < count; ++i) {\n    const std::string& tensor_name = op_name_ + \"/\" + output_name + \"_\" + std::to_string(i);\n    output_list->mutable_s()->Add()->assign(tensor_name);\n    indexed_obns_.emplace_back(output_name + \"_\" + std::to_string(i));\n  }\n  CHECK_EQ_OR_RETURN(proto_.output().size(), proto_.output_order().size());\n  return *this;\n}\n\nOpBuilder& OpBuilder::Output(const std::string& output_name) {\n  return CHECK_JUST(MaybeOutput(output_name, 1));\n}\n\nOpBuilder& OpBuilder::Output(const std::string& output_name, const int count) {\n  return CHECK_JUST(MaybeOutput(output_name, count));\n}\n\ntemplate<>\nMaybe<OpBuilder&> OpBuilder::MaybeAttr(const std::string& attr_name, const AttrValue& attr_value) {\n  (*(proto_.mutable_attr()))[attr_name] = attr_value;\n  return 
*this;\n}\n\ntemplate<>\nOpBuilder& OpBuilder::Attr(const std::string& attr_name, const AttrValue& attr_value) {\n  return CHECK_JUST(MaybeAttr<AttrValue>(attr_name, attr_value));\n}\n\n#define DEFINE_OP_BUILDER_ATTR_FUNC(field, cpp_type, attr_type)                             \\\n  template<>                                                                                \\\n  Maybe<OpBuilder&> OpBuilder::MaybeAttr<cpp_type>(const std::string& attr_name,            \\\n                                                   const cpp_type& val) {                   \\\n    AttrValue attr_val;                                                                     \\\n    user_op::AttrValueAccessor<cpp_type>::Attr(val, &attr_val);                             \\\n    return this->MaybeAttr<AttrValue>(attr_name, attr_val);                                 \\\n  }                                                                                         \\\n                                                                                            \\\n  template<>                                                                                \\\n  OpBuilder& OpBuilder::Attr<cpp_type>(const std::string& attr_name, const cpp_type& val) { \\\n    return CHECK_JUST(MaybeAttr<cpp_type>(attr_name, val));                                 \\\n  }\n\nOF_PP_FOR_EACH_TUPLE(DEFINE_OP_BUILDER_ATTR_FUNC, ATTR_SEQ)\n#undef DEFINE_OP_BUILDER_ATTR_FUNC\n\nMaybe<UserOpExpr> OpBuilder::Build() {\n  return UserOpExpr::New(op_name_, std::move(proto_), indexed_ibns_, indexed_obns_);\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/op_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_OP_BUILDER_H_\n#define ONEFLOW_CORE_FRAMEWORK_OP_BUILDER_H_\n\n#include <string>\n\n#include \"oneflow/core/framework/op_expr.h\"\n\nnamespace oneflow {\nnamespace one {\n\n// The op builder for UserOp.\n// Note that the internal proto will be moved if the Build method is called.\n// Therefore, please make sure that the Build method be called at last, and do not perform any\n// operations on this builder instance after the calling.\nclass OpBuilder {\n public:\n  OpBuilder() = delete;\n  explicit OpBuilder(const std::string& op_type_name);\n  explicit OpBuilder(const std::string& op_type_name, const std::string& op_name);\n  virtual ~OpBuilder() = default;\n\n  Maybe<OpBuilder&> MaybeInput(const std::string& input_name, const int count);\n  OpBuilder& Input(const std::string& input_name);\n  OpBuilder& Input(const std::string& input_name, const int count);\n\n  Maybe<OpBuilder&> MaybeOutput(const std::string& output_name, const int count);\n  OpBuilder& Output(const std::string& output_name);\n  OpBuilder& Output(const std::string& output_name, const int count);\n\n  template<typename T>\n  Maybe<OpBuilder&> MaybeAttr(const std::string& attr_name, const T& attr_value);\n\n  template<typename T>\n  OpBuilder& Attr(const std::string& attr_name, const T& attr_value);\n\n  Maybe<UserOpExpr> Build();\n\n private:\n  std::string op_name_;\n  UserOpConf proto_;\n\n  int input_pos_ = 0;\n  std::vector<std::string> indexed_ibns_;\n  std::vector<std::string> indexed_obns_;\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_OP_BUILDER_H_\n"
  },
  {
    "path": "oneflow/core/framework/op_definition.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_OP_DEFINITION_H_\n#define ONEFLOW_CORE_FRAMEWORK_OP_DEFINITION_H_\n#include <string>\n\n#include \"oneflow/core/common/hash_container.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\nclass AttrVal;\n}  // namespace user_op\nusing AttrVal = user_op::AttrVal;\n\nclass OpDefinitionBase {\n public:\n  virtual ~OpDefinitionBase() = default;\n  virtual Maybe<AttrVal> Attr(const std::string& attr_name) const = 0;\n  virtual const HashSet<std::string>& AttributeNames() const = 0;\n\n protected:\n  OpDefinitionBase() = default;\n};\n\ntemplate<typename Derived>\nclass OpDefinition : public OpDefinitionBase {\n public:\n  virtual ~OpDefinition() = default;\n  const HashSet<std::string>& AttributeNames() const override { return Derived::AttrNames(); }\n\n protected:\n  OpDefinition() : OpDefinitionBase() {}\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_OP_DEFINITION_H_\n"
  },
  {
    "path": "oneflow/core/framework/op_expr.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <memory>\n#include \"oneflow/core/common/error.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/common/auto_registration_factory.h\"\n#include \"oneflow/core/framework/attr_value_accessor.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/op_interpreter/dispatch_frame.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n#include \"oneflow/core/framework/local_tensor_infer_cache.h\"\n#include \"oneflow/core/framework/global_tensor_infer_cache.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/user/kernels/stateful_opkernel.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nMaybe<autocast::AutoCastMeta> OpExpr::GetOrCreateAutoCastMeta() const {\n  static auto autocast_meta = std::make_shared<autocast::AutoCastMeta>();\n  return autocast_meta;\n}\n\nBuiltinOpExpr::BuiltinOpExpr(const std::string& op_name,\n                             const std::vector<std::string>& indexed_ibns,\n                             const std::vector<std::string>& indexed_obns)\n    : op_name_(op_name),\n      input_arg_tuple_(new ArgTuple(indexed_ibns)),\n      output_arg_tuple_(new ArgTuple(indexed_obns)) {}\n\n#define DEFINE_BUILTIN_OPEXPR_OP(T, op_type, disable_grad, support_non_contiguous)      \\\n  template<>                                                                            \\\n  const std::string& BuiltinOpExprImpl<T>::op_type_name() const {                       \\\n    static const std::string& name(op_type);                                            \\\n    return name;                                                                        \\\n  }                                                                                     \\\n  template<>                                                                            \\\n  Maybe<bool> BuiltinOpExprImpl<T>::IsGradDisabled() const {                            \\\n    return disable_grad;                                                                \\\n  }                                                                                     \\\n  template<>                                                                            \\\n  Maybe<bool> BuiltinOpExprImpl<T>::SupportNonContiguous() const {                      \\\n    return support_non_contiguous;                                                      \\\n  }                                                                                     \\\n  template<>                                                                            \\\n  Maybe<autocast::AutoCastMeta> BuiltinOpExprImpl<T>::GetOrCreateAutoCastMeta() const { \\\n    return OpExpr::GetOrCreateAutoCastMeta();                                           \\\n  
}\n\nDEFINE_BUILTIN_OPEXPR_OP(FeedInputOpConf, \"feed_input\", false, false);\nDEFINE_BUILTIN_OPEXPR_OP(FeedVariableOpConf, \"feed_variable\", false, false);\nDEFINE_BUILTIN_OPEXPR_OP(FetchOutputOpConf, \"fetch_output\", false, false);\nDEFINE_BUILTIN_OPEXPR_OP(ImageDecoderRandomCropResizeOpConf, \"image_gpu_decode\", true, false);\nDEFINE_BUILTIN_OPEXPR_OP(VariableOpConf, \"variable\", true, false);\nDEFINE_BUILTIN_OPEXPR_OP(CastToLocalOpConf, \"cast_to_local\", false, false);\nDEFINE_BUILTIN_OPEXPR_OP(CastFromLocalOpConf, \"cast_from_local\", false, false);\nDEFINE_BUILTIN_OPEXPR_OP(DistributeSplitOpConf, \"distribute_split\", false, false);\nDEFINE_BUILTIN_OPEXPR_OP(DistributeCloneOpConf, \"distribute_clone\", false, false);\nDEFINE_BUILTIN_OPEXPR_OP(DistributeConcatOpConf, \"distribute_concat\", false, false);\nDEFINE_BUILTIN_OPEXPR_OP(DistributeAddOpConf, \"distribute_add\", false, false);\n\n#undef DEFINE_BUILTIN_OPEXPR_OP\n\ntemplate<>\nconst std::string& BuiltinOpExprImpl<UserOpConf>::op_type_name() const {\n  return op_proto_.op_type_name();\n}\n\nconst std::string& GlobalToGlobalOpExpr::op_type_name() const {\n  static const std::string kOpTypeName = \"global_to_global\";\n  return kOpTypeName;\n}\n\nconst std::string& LocalToGlobalOpExpr::op_type_name() const {\n  static const std::string kOpTypeName = \"local_to_global\";\n  return kOpTypeName;\n}\n\nconst std::string& GlobalToLocalOpExpr::op_type_name() const {\n  static const std::string kOpTypeName = \"global_to_local\";\n  return kOpTypeName;\n}\n\ntemplate<>\nMaybe<void> BuiltinOpExprImpl<UserOpConf>::BuildOpConf(OperatorConf* op_conf,\n                                                       const AttrMap& attrs) const {\n  *(op_conf->mutable_name()) = op_name_;\n  *(op_conf->mutable_user_conf()) = op_proto_;\n  *(op_conf->mutable_loc()) = DispatchFrame::get_str();\n  auto* user_op_conf = op_conf->mutable_user_conf();\n  for (const auto& it : attrs) {\n    AttrValue attr_val;\n    JUST(user_op::AttrValueUtil::ToProtoAttrValue(*it.second, &attr_val));\n    (*(user_op_conf->mutable_attr()))[it.first] = attr_val;\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<StatefulOpKernel> UserOpExpr::MutKernel4Stream(Symbol<Stream> stream) const {\n  const auto& it = stream2kernel_.find(stream);\n  if (it != stream2kernel_.end()) { return it->second; }\n\n  std::shared_ptr<OperatorConf> op_conf = std::make_shared<OperatorConf>();\n  JUST(BuildOpConf(op_conf.get(), {}));\n  op_conf->set_device_tag(stream->device()->type());\n  auto parallel_desc = JUST(Placement4Device(stream->device())).shared_from_symbol();\n  const auto& opkernel = JUST(StatefulOpKernel::New(op_conf, stream, base_attrs(), parallel_desc,\n                                                    input_arg_tuple(), output_arg_tuple()));\n  stream2kernel_.emplace(stream, opkernel);\n  return opkernel;\n}\n\ntemplate<>\nMaybe<bool> BuiltinOpExprImpl<UserOpConf>::IsGradDisabled() const {\n  const auto* registry =\n      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(proto().op_type_name());\n  CHECK_NOTNULL_OR_RETURN(registry);\n  return registry->no_grad;\n}\n\ntemplate<>\nMaybe<bool> BuiltinOpExprImpl<UserOpConf>::SupportNonContiguous() const {\n  const auto* registry =\n      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(proto().op_type_name());\n  CHECK_NOTNULL_OR_RETURN(registry)\n      << \"The op(operation) \" << proto().op_type_name()\n      << \" is not found. 
Please check whether it has been registered correctly.\";\n  return registry->non_contiguous_supported;\n}\n\ntemplate<>\nMaybe<OpExprGradClosure> BuiltinOpExprImpl<UserOpConf>::GetOrCreateOpGradClosure() const {\n  if (!op_grad_func_.get()) {\n    CHECK_OR_RETURN((IsClassRegistered<std::string, OpExprGradFunctionIf>(proto().op_type_name())))\n        << \"The gradient function for op \" << proto().op_type_name()\n        << \" is not found. Please check whether it has been implemented and registered correctly.\";\n    op_grad_func_.reset(NewObj<std::string, OpExprGradFunctionIf>(proto().op_type_name()));\n    JUST(op_grad_func_->Init(*this));\n  }\n  return std::make_shared<OpExprGradClosure>(op_grad_func_);\n}\n\ntemplate<>\nMaybe<autocast::AutoCastMeta> BuiltinOpExprImpl<UserOpConf>::GetOrCreateAutoCastMeta() const {\n  if (!autocast_meta_) {\n    autocast_meta_ =\n        autocast::MakeAutoCastMeta(proto().op_type_name(), this->indexed_input_pairs());\n  }\n  return autocast_meta_;\n}\n\nnamespace {\n\nclass UserOpExprInferContext : public user_op::InferContext {\n public:\n  UserOpExprInferContext(const UserOpExpr* user_op_expr, const AttrMap& attrs,\n                         const std::string& device_tag,\n                         const std::function<const TensorMeta*(int32_t)>& TensorMeta4InputIndex,\n                         const std::function<TensorMeta*(int32_t)>& TensorMeta4OutputIndex)\n      : user_op_expr_(user_op_expr),\n        composed_attrs_(attrs, user_op_expr->base_attrs()),\n        tensor_meta4input_index_(TensorMeta4InputIndex),\n        tensor_meta4output_index_(TensorMeta4OutputIndex) {\n    loc_ = DispatchFrame::get_str();\n  }\n  virtual ~UserOpExprInferContext() override = default;\n\n  const std::vector<std::pair<std::string, int32_t>>& inputs() const override {\n    return user_op_expr_->indexed_input_pairs();\n  }\n\n  const std::vector<std::pair<std::string, int32_t>>& outputs() const override {\n    return user_op_expr_->indexed_output_pairs();\n  }\n\n  const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name,\n                                             int32_t index) const override {\n    return *TensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n  const user_op::TensorDesc& OutputTensorDesc(const std::string& arg_name,\n                                              int32_t index) const override {\n    return *TensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n  user_op::TensorDesc* MutOutputTensorDesc(const std::string& name, int32_t index) override {\n    return MutTensorDesc4ArgNameAndIndex(name, index);\n  }\n\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& name,\n                                                        int32_t index) const {\n    {\n      const auto& arg_tuple = *user_op_expr_->output_arg_tuple();\n      int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);\n      if (tuple_index >= 0) { return tensor_meta4output_index_(tuple_index); }\n    }\n    {\n      const auto& arg_tuple = *user_op_expr_->input_arg_tuple();\n      int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);\n      if (tuple_index >= 0) { return tensor_meta4input_index_(tuple_index); }\n    }\n    return nullptr;\n  }\n\n  user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) {\n    {\n      const auto& arg_tuple = *user_op_expr_->output_arg_tuple();\n      int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);\n    
  if (tuple_index >= 0) {\n        TensorMeta* tensor_meta_ptr = tensor_meta4output_index_(tuple_index);\n        CHECK_NOTNULL(dynamic_cast<MutTensorMeta*>(tensor_meta_ptr));\n        return tensor_meta_ptr;\n      }\n    }\n    {\n      const auto& arg_tuple = *user_op_expr_->input_arg_tuple();\n      int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);\n      if (tuple_index >= 0) {\n        const TensorMeta* tensor_meta_ptr = tensor_meta4input_index_(tuple_index);\n        CHECK_NOTNULL(dynamic_cast<const MutTensorMeta*>(tensor_meta_ptr));\n        return const_cast<TensorMeta*>(tensor_meta_ptr);\n      }\n    }\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return nullptr;\n  }\n\n  const Shape& InputShape(const std::string& name, int32_t index) const override {\n    const auto& arg_tuple = *user_op_expr_->input_arg_tuple();\n    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);\n    CHECK_GE(tuple_index, 0);\n    return tensor_meta4input_index_(tuple_index)->shape();\n  }\n\n  const Shape& OutputShape(const std::string& name, int32_t index) const override {\n    const auto& arg_tuple = *user_op_expr_->output_arg_tuple();\n    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);\n    CHECK_GE(tuple_index, 0);\n    return tensor_meta4input_index_(tuple_index)->shape();\n  }\n\n  void SetOutputShape(const std::string& name, int32_t index, const Shape& shape) override {\n    const auto& arg_tuple = *user_op_expr_->output_arg_tuple();\n    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);\n    CHECK_GE(tuple_index, 0);\n    TensorMeta* tensor_meta_ptr = tensor_meta4output_index_(tuple_index);\n    CHECK_NOTNULL(dynamic_cast<MutTensorMeta*>(tensor_meta_ptr));\n    return tensor_meta_ptr->set_shape(shape);\n  }\n\n  const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    return TensorDesc4ArgNameAndIndex(arg_name, index)->shape();\n  }\n\n  void SetShape4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                const Shape& shape) override {\n    return MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_shape(shape);\n  }\n\n  const Stride& InputStride(const std::string& name, int32_t index) const override {\n    const auto& arg_tuple = *user_op_expr_->input_arg_tuple();\n    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);\n    CHECK_GE(tuple_index, 0);\n    return tensor_meta4input_index_(tuple_index)->stride();\n  }\n\n  const Stride& OutputStride(const std::string& name, int32_t index) const override {\n    const auto& arg_tuple = *user_op_expr_->output_arg_tuple();\n    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);\n    CHECK_GE(tuple_index, 0);\n    return tensor_meta4output_index_(tuple_index)->stride();\n  }\n\n  void SetOutputStride(const std::string& name, int32_t index, const Stride& stride) override {\n    const auto& arg_tuple = *user_op_expr_->output_arg_tuple();\n    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);\n    CHECK_GE(tuple_index, 0);\n    TensorMeta* tensor_meta_ptr = tensor_meta4output_index_(tuple_index);\n    CHECK_NOTNULL(dynamic_cast<MutTensorMeta*>(tensor_meta_ptr));\n    return tensor_meta_ptr->set_stride(stride);\n  }\n\n  const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    return TensorDesc4ArgNameAndIndex(arg_name, 
index)->stride();\n  }\n\n  void SetStride4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                 const Stride& stride) override {\n    return MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_stride(stride);\n  }\n\n  DataType InputDType(const std::string& arg_name, int32_t index) const override {\n    return Dtype4ArgNameAndIndex(arg_name, index);\n  }\n  DataType OutputDType(const std::string& arg_name, int32_t index) const override {\n    return Dtype4ArgNameAndIndex(arg_name, index);\n  }\n  void SetOutputDType(const std::string& arg_name, int32_t index, DataType data_type) override {\n    return SetDtype4ArgNameAndIndex(arg_name, index, data_type);\n  }\n  DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    return TensorDesc4ArgNameAndIndex(arg_name, index)->data_type();\n  }\n  void SetDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                DataType data_type) override {\n    return MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_data_type(data_type);\n  }\n\n  MemoryFormat InputMemoryFormat(const std::string& arg_name, int32_t index) const override {\n    return MemoryFormat4ArgNameAndIndex(arg_name, index);\n  }\n  MemoryFormat OutputMemoryFormat(const std::string& arg_name, int32_t index) const override {\n    return MemoryFormat4ArgNameAndIndex(arg_name, index);\n  }\n  void SetOutputMemoryFormat(const std::string& arg_name, int32_t index,\n                             MemoryFormat memory_format) override {\n    return SetMemoryFormat4ArgNameAndIndex(arg_name, index, memory_format);\n  }\n  MemoryFormat MemoryFormat4ArgNameAndIndex(const std::string& arg_name,\n                                            int32_t index) const override {\n    return TensorDesc4ArgNameAndIndex(arg_name, index)->memory_format();\n  }\n  void SetMemoryFormat4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                       MemoryFormat memory_format) override {\n    MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_memory_format(memory_format);\n  }\n\n  bool InputIsDynamic(const std::string& arg_name, int32_t index) const override {\n    return IsDynamic4ArgNameAndIndex(arg_name, index);\n  }\n  bool OutputIsDynamic(const std::string& arg_name, int32_t index) const override {\n    return IsDynamic4ArgNameAndIndex(arg_name, index);\n  }\n  void SetOutputIsDynamic(const std::string& arg_name, int32_t index, bool is_dynamic) override {\n    return SetIsDynamic4ArgNameAndIndex(arg_name, index, is_dynamic);\n  }\n  bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    return TensorDesc4ArgNameAndIndex(arg_name, index)->is_dynamic();\n  }\n  void SetIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                    bool is_dynamic) override {\n    return MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_is_dynamic(is_dynamic);\n  }\n  const std::string& input(const std::string& arg_name, int32_t index) const override {\n    const auto& arg_tuple = *user_op_expr_->input_arg_tuple();\n    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(arg_name, index);\n    CHECK_GE(tuple_index, 0);\n    return arg_tuple.indexed_bns().at(tuple_index);\n  }\n  const std::string& output(const std::string& arg_name, int32_t index) const override {\n    const auto& arg_tuple = *user_op_expr_->output_arg_tuple();\n    int32_t tuple_index = 
arg_tuple.TensorTupleIndex4ArgNameAndIndex(arg_name, index);\n    CHECK_GE(tuple_index, 0);\n    return arg_tuple.indexed_bns().at(tuple_index);\n  }\n  bool has_input(const std::string& arg_name, int32_t index) const override {\n    const auto& arg_tuple = *user_op_expr_->input_arg_tuple();\n    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(arg_name, index);\n    return tuple_index >= 0;\n  }\n  bool has_output(const std::string& arg_name, int32_t index) const override {\n    const auto& arg_tuple = *user_op_expr_->output_arg_tuple();\n    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(arg_name, index);\n    return tuple_index >= 0;\n  }\n  int32_t input_size(const std::string& arg_name) const override {\n    const auto& arg_tuple = *user_op_expr_->input_arg_tuple();\n    return arg_tuple.arg_name2bn_index2tensor_tuple_index().at(arg_name).size();\n  }\n  int32_t output_size(const std::string& arg_name) const override {\n    const auto& arg_tuple = *user_op_expr_->output_arg_tuple();\n    return arg_tuple.arg_name2bn_index2tensor_tuple_index().at(arg_name).size();\n  }\n  const std::string& op_name() const override { return user_op_expr_->op_name(); }\n  const std::string& op_type_name() const override { return user_op_expr_->op_type_name(); }\n  const std::string& op_loc() const override { return loc_; }\n\n private:\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return composed_attrs_.Attr4Name(attr_name);\n  }\n  const UserOpExpr* user_op_expr_;\n  const ComposedAttrMap composed_attrs_;\n  const std::function<const TensorMeta*(int32_t)>& tensor_meta4input_index_;\n  const std::function<TensorMeta*(int32_t)>& tensor_meta4output_index_;\n  std::string loc_;\n};\n\nnamespace {\n\nSymbol<NdSbp> Get1DBroadcastNdSbp() {\n  NdSbp broadcast_nd_sbp;\n  broadcast_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();\n  return SymbolOf(broadcast_nd_sbp);\n}\n\nauto* CachedGet1DBroadcastNdSbp = DECORATE(&Get1DBroadcastNdSbp, ThreadLocalCached);\n\n}  // namespace\n\nclass UserOpExprPhysicalInferContext final : public UserOpExprInferContext {\n public:\n  UserOpExprPhysicalInferContext(\n      const UserOpExpr* user_op_expr, const AttrMap& attrs, const std::string& device_tag,\n      const std::function<const TensorMeta*(int32_t)>& TensorMeta4InputIndex,\n      const std::function<TensorMeta*(int32_t)>& TensorMeta4OutputIndex)\n      : UserOpExprInferContext(user_op_expr, attrs, device_tag, TensorMeta4InputIndex,\n                               TensorMeta4OutputIndex),\n        parallel_desc_(CHECK_JUST(GetParallelDescOfThisRank(device_tag))) {\n    parallel_ctx_.set_parallel_id(0);\n    parallel_ctx_.set_parallel_num(1);\n  }\n  ~UserOpExprPhysicalInferContext() override = default;\n\n  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& name,\n                                                               int32_t index) const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return nullptr;\n  }\n\n  const ParallelContext& parallel_ctx() const override { return parallel_ctx_; }\n  const ParallelDesc& parallel_desc() const override { return *parallel_desc_; }\n  const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& name,\n                                                 int32_t index) const override {\n    CHECK_NOTNULL(TensorDesc4ArgNameAndIndex(name, index));\n    return CachedGet1DBroadcastNdSbp()->sbp_parallel(0);\n  }\n  const 
NdSbp& NdSbp4ArgNameAndIndex(const std::string& name, int32_t index) const override {\n    CHECK_NOTNULL(TensorDesc4ArgNameAndIndex(name, index));\n    return *(CachedGet1DBroadcastNdSbp());\n  }\n  int64_t parallel_num() const override { return 1; }\n\n private:\n  // these member vars are only used for physical infer\n  Symbol<ParallelDesc> parallel_desc_;\n  ParallelContext parallel_ctx_;\n};\n\nclass UserOpExprLogicalInferContext final : public UserOpExprInferContext {\n public:\n  UserOpExprLogicalInferContext(\n      const UserOpExpr* user_op_expr, const AttrMap& attrs, Symbol<ParallelDesc> parallel_desc,\n      const std::function<const TensorMeta*(int32_t)>& TensorMeta4InputIndex,\n      const std::function<TensorMeta*(int32_t)>& TensorMeta4OutputIndex)\n      : UserOpExprInferContext(user_op_expr, attrs, parallel_desc->device_tag(),\n                               TensorMeta4InputIndex, TensorMeta4OutputIndex),\n        parallel_desc_(parallel_desc) {\n    const auto& opt_parallel_id = CHECK_JUST(GetParallelId4CurrentProcessCtx(parallel_desc_));\n    // Default parallel_id = -1, which will not cause bad effects because it will never be used in\n    // LogicalTensorDescInfer.\n    int64_t parallel_id = -1;\n    if (opt_parallel_id->has_value()) { parallel_id = CHECK_JUST(*opt_parallel_id); }\n    parallel_ctx_.set_parallel_id(parallel_id);\n    parallel_ctx_.set_parallel_num(parallel_desc_->parallel_num());\n  }\n  ~UserOpExprLogicalInferContext() override = default;\n\n  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& name,\n                                                               int32_t index) const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return nullptr;\n  }\n\n  const ParallelContext& parallel_ctx() const override { return parallel_ctx_; }\n  const ParallelDesc& parallel_desc() const override { return *parallel_desc_; }\n  const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& name,\n                                                 int32_t index) const override {\n    const GlobalTensorMeta* tensor_meta =\n        dynamic_cast<const GlobalTensorMeta*>(TensorDesc4ArgNameAndIndex(name, index));\n    Symbol<NdSbp> nd_sbp = tensor_meta->nd_sbp();\n    CHECK_EQ(nd_sbp->sbp_parallel_size(), 1);\n    return nd_sbp->sbp_parallel(0);\n  }\n  const NdSbp& NdSbp4ArgNameAndIndex(const std::string& name, int32_t index) const override {\n    const GlobalTensorMeta* tensor_meta =\n        dynamic_cast<const GlobalTensorMeta*>(TensorDesc4ArgNameAndIndex(name, index));\n    return *tensor_meta->nd_sbp();\n  }\n  int64_t parallel_num() const override { return parallel_desc_->parallel_num(); }\n\n private:\n  Symbol<ParallelDesc> parallel_desc_;\n  ParallelContext parallel_ctx_;\n};\n\nclass UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStreamInferContext {\n public:\n  UserOpExprDeviceAndStreamInferContext(const UserOpExpr* user_op_expr, const AttrMap& attrs,\n                                        const TensorTuple& input_tensors,\n                                        TensorTuple* output_tensors)\n      : user_op_expr_(user_op_expr),\n        composed_attrs_(attrs, user_op_expr->base_attrs()),\n        input_tensors_(&input_tensors),\n        output_tensors_(output_tensors) {}\n\n  const std::vector<std::pair<std::string, int32_t>>& inputs() const override {\n    return user_op_expr_->indexed_input_pairs();\n  }\n\n  const std::vector<std::pair<std::string, int32_t>>& outputs() const override {\n    return 
user_op_expr_->indexed_output_pairs();\n  }\n\n  Symbol<Device>* OutputTensorDevice4ArgNameAndIndex(const std::string& name,\n                                                     int64_t index) override {\n    const auto& arg_tuple = *user_op_expr_->output_arg_tuple();\n    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);\n    CHECK_GE(tuple_index, 0);\n    return CHECK_JUST(output_tensors_->at(tuple_index)->mut_device());\n  }\n\n  Symbol<Device> InputTensorDevice4ArgNameAndIndex(const std::string& name,\n                                                   int64_t index) const override {\n    const auto& arg_tuple = *user_op_expr_->input_arg_tuple();\n    int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index);\n    CHECK_GE(tuple_index, 0);\n    return CHECK_JUST(input_tensors_->at(tuple_index)->device());\n  }\n\n private:\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return composed_attrs_.Attr4Name(attr_name);\n  }\n  const UserOpExpr* user_op_expr_;\n  const ComposedAttrMap composed_attrs_;\n  const TensorTuple* input_tensors_;\n  TensorTuple* output_tensors_;\n};\n\n}  // namespace\n\nUserOpExpr::UserOpExpr(const std::string& op_name, UserOpConf&& proto, const AttrMap& base_attrs,\n                       const std::vector<std::string>& indexed_ibns,\n                       const std::vector<std::string>& indexed_obns)\n    : BuiltinOpExprImpl<UserOpConf>(op_name, std::move(proto), indexed_ibns, indexed_obns),\n      base_attrs_(base_attrs) {}\n\nMaybe<void> UserOpExpr::Init(const std::shared_ptr<const UserOpExpr>& self) {\n  const auto& op_type_name = op_proto_.op_type_name();\n  const auto* registry = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type_name);\n  CHECK_NOTNULL_OR_RETURN(registry);\n  logical_tensor_desc_infer_fn_ = registry->logical_tensor_desc_infer_fn;\n  CHECK_OR_RETURN(static_cast<bool>(logical_tensor_desc_infer_fn_))\n      << Error::RuntimeError() << \"registry->logical_tensor_desc_infer_fn failed.\";\n  physical_tensor_desc_infer_fn_ = registry->physical_tensor_desc_infer_fn;\n  CHECK_OR_RETURN(static_cast<bool>(physical_tensor_desc_infer_fn_))\n      << Error::RuntimeError() << \"registry->physical_tensor_desc_infer_fn failed.\";\n  dtype_infer_fn_ = registry->data_type_infer_fn;\n  CHECK_OR_RETURN(static_cast<bool>(dtype_infer_fn_))\n      << Error::RuntimeError() << \"registry->data_type_infer_fn failed.\";\n  if (registry->device_and_stream_infer_fn) {\n    device_and_stream_infer_fn_ = registry->device_and_stream_infer_fn;\n  }\n  local_tensor_infer_cache_.reset(new LocalTensorInferCache(self));\n  global_tensor_infer_cache_.reset(new GlobalTensorInferCache(self));\n  const auto& indexed_input_pairs = this->indexed_input_pairs();\n  for (int32_t i = 0; i < indexed_input_pairs.size(); ++i) {\n    const auto& input_pair = JUST(VectorAt(indexed_input_pairs, i));\n    if (user_op::UserOpHostMemoryInputRegistry::Get().IsHostMemoryInput4Op(\n            op_type_name, input_pair.first, input_pair.second)) {\n      host_memory_input_ids_.emplace_back(i);\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<UserOpExpr> UserOpExpr::New(const std::string& op_name, UserOpConf&& op_proto,\n                                               const std::vector<std::string>& indexed_ibns,\n                                               const std::vector<std::string>& indexed_obns) {\n  
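// Construction is two-phase: the local/global tensor infer caches need a shared_ptr to the\n  // expr itself, which the constructor cannot provide, so New() builds the object first and\n  // then calls Init(self). A minimal usage sketch (the op/blob names here are hypothetical):\n  //\n  //   UserOpConf conf;\n  //   conf.set_op_type_name(\"relu\");\n  //   auto relu_op = JUST(UserOpExpr::New(\"relu-0\", std::move(conf), {\"in_0\"}, {\"out_0\"}));\n  //\n  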
JUST(AddAttrDefaultValueAndCheckValid(&op_proto));\n  AttrMap base_attrs = MakeAttrMapFromUserOpConf(op_proto);\n  std::shared_ptr<UserOpExpr> op_expr(\n      new UserOpExpr(op_name, std::move(op_proto), base_attrs, indexed_ibns, indexed_obns));\n  JUST(op_expr->Init(op_expr));\n  return op_expr;\n}\n\nMaybe<void> UserOpExpr::InferPhysicalTensorDesc(\n    const AttrMap& attrs, const std::string& device_tag,\n    const std::function<const TensorMeta*(int32_t)>& TensorMeta4InputIndex,\n    const std::function<TensorMeta*(int32_t)>& TensorMeta4OutputIndex) const {\n  UserOpExprPhysicalInferContext infer_ctx(this, attrs, device_tag, TensorMeta4InputIndex,\n                                           TensorMeta4OutputIndex);\n  JUST(physical_tensor_desc_infer_fn_(&infer_ctx));\n  JUST(dtype_infer_fn_(&infer_ctx));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> UserOpExpr::InferLogicalTensorDesc(\n    const AttrMap& attrs, Symbol<ParallelDesc> parallel_desc,\n    const std::function<const TensorMeta*(int32_t)>& TensorMeta4InputIndex,\n    const std::function<TensorMeta*(int32_t)>& TensorMeta4OutputIndex) const {\n  UserOpExprLogicalInferContext infer_ctx(this, attrs, parallel_desc, TensorMeta4InputIndex,\n                                          TensorMeta4OutputIndex);\n  JUST(logical_tensor_desc_infer_fn_(&infer_ctx));\n  JUST(dtype_infer_fn_(&infer_ctx));\n  return Maybe<void>::Ok();\n}\n\nMaybe<Symbol<Stream>> UserOpExpr::InferDeviceAndStream(const AttrMap& attrs,\n                                                       const TensorTuple& input_tensors,\n                                                       TensorTuple* output_tensors) const {\n  CHECK_OR_RETURN(static_cast<bool>(device_and_stream_infer_fn_));\n  UserOpExprDeviceAndStreamInferContext device_infer_ctx(this, attrs, input_tensors,\n                                                         output_tensors);\n  return TRY(device_and_stream_infer_fn_(&device_infer_ctx));\n}\n\nGlobalToGlobalOpExpr::GlobalToGlobalOpExpr(const Optional<Symbol<NdSbp>>& grad_nd_sbp)\n    : grad_nd_sbp_(grad_nd_sbp) {}\n\n/* static */ Maybe<GlobalToGlobalOpExpr> GlobalToGlobalOpExpr::New(\n    const Optional<Symbol<NdSbp>>& grad_nd_sbp) {\n  auto* ptr = new GlobalToGlobalOpExpr(grad_nd_sbp);\n  return std::shared_ptr<GlobalToGlobalOpExpr>(ptr);\n}\n\nCastGlobalOpExpr::CastGlobalOpExpr(const std::string& op_name) : op_name_(op_name) {}\n\nLocalToGlobalOpExpr::LocalToGlobalOpExpr(const std::string& op_name) : CastGlobalOpExpr(op_name) {}\n\n/* static */ Maybe<LocalToGlobalOpExpr> LocalToGlobalOpExpr::New(const std::string& op_name) {\n  return std::shared_ptr<LocalToGlobalOpExpr>(new LocalToGlobalOpExpr(op_name));\n}\n\nGlobalToLocalOpExpr::GlobalToLocalOpExpr(const std::string& op_name) : CastGlobalOpExpr(op_name) {}\n\n/* static */ Maybe<GlobalToLocalOpExpr> GlobalToLocalOpExpr::New(const std::string& op_name) {\n  return std::shared_ptr<GlobalToLocalOpExpr>(new GlobalToLocalOpExpr(op_name));\n}\n\ntemplate<>\nMaybe<void> BuiltinOpExprImpl<FeedInputOpConf>::BuildOpConf(OperatorConf* op_conf,\n                                                            const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(attrs.size(), 0);\n  *(op_conf->mutable_name()) = op_name_;\n  *(op_conf->mutable_feed_input_conf()) = op_proto_;\n  *(op_conf->mutable_loc()) = DispatchFrame::get_str();\n  return Maybe<void>::Ok();\n}\n\ntemplate<>\nMaybe<OpExprGradClosure> BuiltinOpExprImpl<FeedInputOpConf>::GetOrCreateOpGradClosure() const {\n  if (!op_grad_func_.get()) {\n    
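// The grad function is looked up once in the auto-registration factory (populated via\n    // REGISTER_OP_EXPR_GRAD_FUNCTION) and cached in op_grad_func_; every closure created below\n    // shares this single stateless instance.\n    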
op_grad_func_.reset(NewObj<std::string, OpExprGradFunctionIf>(\"graph_feed_and_fetch\"));\n    CHECK_NOTNULL_OR_RETURN(op_grad_func_.get());  // NOLINT\n    JUST(op_grad_func_->Init(*this));\n  }\n  return std::make_shared<OpExprGradClosure>(op_grad_func_);\n}\n\ntemplate<>\nMaybe<void> BuiltinOpExprImpl<FeedVariableOpConf>::BuildOpConf(OperatorConf* op_conf,\n                                                               const AttrMap& attrs) const {\n  *(op_conf->mutable_name()) = op_name_;\n  *(op_conf->mutable_feed_variable_conf()) = op_proto_;\n  *(op_conf->mutable_loc()) = DispatchFrame::get_str();\n  return Maybe<void>::Ok();\n}\n\ntemplate<>\nMaybe<OpExprGradClosure> BuiltinOpExprImpl<FeedVariableOpConf>::GetOrCreateOpGradClosure() const {\n  if (!op_grad_func_.get()) {\n    op_grad_func_.reset(NewObj<std::string, OpExprGradFunctionIf>(\"graph_feed_and_fetch\"));\n    CHECK_NOTNULL_OR_RETURN(op_grad_func_.get());  // NOLINT\n    JUST(op_grad_func_->Init(*this));\n  }\n  return std::make_shared<OpExprGradClosure>(op_grad_func_);\n}\n\ntemplate<>\nMaybe<void> BuiltinOpExprImpl<FetchOutputOpConf>::BuildOpConf(OperatorConf* op_conf,\n                                                              const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(attrs.size(), 0);\n  *(op_conf->mutable_name()) = op_name_;\n  *(op_conf->mutable_fetch_output_conf()) = op_proto_;\n  *(op_conf->mutable_loc()) = DispatchFrame::get_str();\n  return Maybe<void>::Ok();\n}\n\ntemplate<>\nMaybe<OpExprGradClosure> BuiltinOpExprImpl<FetchOutputOpConf>::GetOrCreateOpGradClosure() const {\n  if (!op_grad_func_.get()) {\n    op_grad_func_.reset(NewObj<std::string, OpExprGradFunctionIf>(\"graph_feed_and_fetch\"));\n    CHECK_NOTNULL_OR_RETURN(op_grad_func_.get());  // NOLINT\n    JUST(op_grad_func_->Init(*this));\n  }\n  return std::make_shared<OpExprGradClosure>(op_grad_func_);\n}\n\ntemplate<>\nMaybe<void> BuiltinOpExprImpl<ImageDecoderRandomCropResizeOpConf>::BuildOpConf(\n    OperatorConf* op_conf, const AttrMap& attrs) const {\n  *(op_conf->mutable_name()) = op_name_;\n  *(op_conf->mutable_image_decoder_random_crop_resize_conf()) = op_proto_;\n  *(op_conf->mutable_loc()) = DispatchFrame::get_str();\n  auto* proto = op_conf->mutable_image_decoder_random_crop_resize_conf();\n  proto->set_target_width(JUST(attrs.GetAttr<int64_t>(\"target_width\")));\n  proto->set_target_height(JUST(attrs.GetAttr<int64_t>(\"target_height\")));\n  proto->set_num_workers(JUST(attrs.GetAttr<int64_t>(\"num_workers\")));\n  proto->set_max_num_pixels(JUST(attrs.GetAttr<int64_t>(\"max_num_pixels\")));\n  proto->set_warmup_size(JUST(attrs.GetAttr<int64_t>(\"warmup_size\")));\n  proto->set_seed(JUST(attrs.GetAttr<int64_t>(\"seed\")));\n  proto->set_num_attempts(JUST(attrs.GetAttr<int64_t>(\"num_attempts\")));\n  proto->set_random_area_min(JUST(attrs.GetAttr<float>(\"random_area_min\")));\n  proto->set_random_area_max(JUST(attrs.GetAttr<float>(\"random_area_max\")));\n  proto->set_random_aspect_ratio_min(JUST(attrs.GetAttr<float>(\"random_aspect_ratio_min\")));\n  proto->set_random_aspect_ratio_max(JUST(attrs.GetAttr<float>(\"random_aspect_ratio_max\")));\n  return Maybe<void>::Ok();\n}\n\ntemplate<>\nMaybe<OpExprGradClosure>\nBuiltinOpExprImpl<ImageDecoderRandomCropResizeOpConf>::GetOrCreateOpGradClosure() const {\n  UNIMPLEMENTED_THEN_RETURN();\n}\n\ntemplate<>\nMaybe<void> BuiltinOpExprImpl<VariableOpConf>::BuildOpConf(OperatorConf* op_conf,\n                                                           const AttrMap& attrs) const {\n  
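// Each of these system-op specializations follows the same pattern: require that no runtime\n  // attrs are passed (where none are expected), then stamp the op name, the stored proto, and\n  // the current DispatchFrame location into the OperatorConf.\n  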
CHECK_EQ_OR_RETURN(attrs.size(), 0);\n  *(op_conf->mutable_name()) = op_name_;\n  *(op_conf->mutable_variable_conf()) = op_proto_;\n  *(op_conf->mutable_loc()) = DispatchFrame::get_str();\n  return Maybe<void>::Ok();\n}\n\ntemplate<>\nMaybe<OpExprGradClosure> BuiltinOpExprImpl<VariableOpConf>::GetOrCreateOpGradClosure() const {\n  UNIMPLEMENTED_THEN_RETURN();\n}\n\ntemplate<>\nMaybe<void> BuiltinOpExprImpl<CastToLocalOpConf>::BuildOpConf(OperatorConf* op_conf,\n                                                              const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(attrs.size(), 0);\n  *(op_conf->mutable_name()) = op_name_;\n  *(op_conf->mutable_cast_to_local_conf()) = op_proto_;\n  *(op_conf->mutable_loc()) = DispatchFrame::get_str();\n  return Maybe<void>::Ok();\n}\n\ntemplate<>\nMaybe<OpExprGradClosure> BuiltinOpExprImpl<CastToLocalOpConf>::GetOrCreateOpGradClosure() const {\n  UNIMPLEMENTED_THEN_RETURN();\n}\n\ntemplate<>\nMaybe<void> BuiltinOpExprImpl<CastFromLocalOpConf>::BuildOpConf(OperatorConf* op_conf,\n                                                                const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(attrs.size(), 0);\n  *(op_conf->mutable_name()) = op_name_;\n  *(op_conf->mutable_cast_from_local_conf()) = op_proto_;\n  *(op_conf->mutable_loc()) = DispatchFrame::get_str();\n  return Maybe<void>::Ok();\n}\n\ntemplate<>\nMaybe<OpExprGradClosure> BuiltinOpExprImpl<CastFromLocalOpConf>::GetOrCreateOpGradClosure() const {\n  UNIMPLEMENTED_THEN_RETURN();\n}\n\nMaybe<OpExprGradClosure> GlobalToGlobalOpExpr::GetOrCreateOpGradClosure() const {\n  if (!op_grad_func_.get()) {\n    op_grad_func_.reset(NewObj<std::string, OpExprGradFunctionIf>(\"global_to_global\"));\n    CHECK_NOTNULL_OR_RETURN(op_grad_func_.get());\n    JUST(op_grad_func_->Init(*this));\n  }\n  return std::make_shared<OpExprGradClosure>(op_grad_func_);\n}\n\nMaybe<OpExprGradClosure> LocalToGlobalOpExpr::GetOrCreateOpGradClosure() const {\n  if (!op_grad_func_.get()) {\n    op_grad_func_.reset(NewObj<std::string, OpExprGradFunctionIf>(\"local_to_global\"));\n    CHECK_NOTNULL_OR_RETURN(op_grad_func_.get());\n    JUST(op_grad_func_->Init(*this));\n  }\n  return std::make_shared<OpExprGradClosure>(op_grad_func_);\n}\n\nMaybe<OpExprGradClosure> GlobalToLocalOpExpr::GetOrCreateOpGradClosure() const {\n  if (!op_grad_func_.get()) {\n    op_grad_func_.reset(NewObj<std::string, OpExprGradFunctionIf>(\"global_to_local\"));\n    CHECK_NOTNULL_OR_RETURN(op_grad_func_.get());\n    JUST(op_grad_func_->Init(*this));\n  }\n  return std::make_shared<OpExprGradClosure>(op_grad_func_);\n}\n\ntemplate<>\nMaybe<void> BuiltinOpExprImpl<DistributeSplitOpConf>::BuildOpConf(OperatorConf* op_conf,\n                                                                  const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(attrs.size(), 0);\n  *(op_conf->mutable_name()) = op_name_;\n  *(op_conf->mutable_distribute_split_conf()) = op_proto_;\n  *(op_conf->mutable_loc()) = DispatchFrame::get_str();\n  return Maybe<void>::Ok();\n}\n\ntemplate<>\nMaybe<OpExprGradClosure> BuiltinOpExprImpl<DistributeSplitOpConf>::GetOrCreateOpGradClosure()\n    const {\n  UNIMPLEMENTED_THEN_RETURN();\n}\n\ntemplate<>\nMaybe<void> BuiltinOpExprImpl<DistributeCloneOpConf>::BuildOpConf(OperatorConf* op_conf,\n                                                                  const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(attrs.size(), 0);\n  *(op_conf->mutable_name()) = op_name_;\n  *(op_conf->mutable_distribute_clone_conf()) = op_proto_;\n  
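// loc records the current DispatchFrame string (typically the Python call site of the op),\n  // which is kept on the OperatorConf for debugging and error reporting.\n  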
*(op_conf->mutable_loc()) = DispatchFrame::get_str();\n  return Maybe<void>::Ok();\n}\n\ntemplate<>\nMaybe<OpExprGradClosure> BuiltinOpExprImpl<DistributeCloneOpConf>::GetOrCreateOpGradClosure()\n    const {\n  UNIMPLEMENTED_THEN_RETURN();\n}\n\ntemplate<>\nMaybe<void> BuiltinOpExprImpl<DistributeConcatOpConf>::BuildOpConf(OperatorConf* op_conf,\n                                                                   const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(attrs.size(), 0);\n  *(op_conf->mutable_name()) = op_name_;\n  *(op_conf->mutable_distribute_concat_conf()) = op_proto_;\n  *(op_conf->mutable_loc()) = DispatchFrame::get_str();\n  return Maybe<void>::Ok();\n}\n\ntemplate<>\nMaybe<OpExprGradClosure> BuiltinOpExprImpl<DistributeConcatOpConf>::GetOrCreateOpGradClosure()\n    const {\n  UNIMPLEMENTED_THEN_RETURN();\n}\n\ntemplate<>\nMaybe<void> BuiltinOpExprImpl<DistributeAddOpConf>::BuildOpConf(OperatorConf* op_conf,\n                                                                const AttrMap& attrs) const {\n  CHECK_EQ_OR_RETURN(attrs.size(), 0);\n  *(op_conf->mutable_name()) = op_name_;\n  *(op_conf->mutable_distribute_add_conf()) = op_proto_;\n  *(op_conf->mutable_loc()) = DispatchFrame::get_str();\n  return Maybe<void>::Ok();\n}\n\ntemplate<>\nMaybe<OpExprGradClosure> BuiltinOpExprImpl<DistributeAddOpConf>::GetOrCreateOpGradClosure() const {\n  UNIMPLEMENTED_THEN_RETURN();\n}\n\nMaybe<OpExprGradClosure> SelectTopNOpExpr::GetOrCreateOpGradClosure() const {\n  if (!op_grad_func_.get()) {\n    op_grad_func_.reset(NewObj<std::string, OpExprGradFunctionIf>(\"select_top_n\"));\n    CHECK_NOTNULL_OR_RETURN(op_grad_func_.get());\n    JUST(op_grad_func_->Init(*this));\n  }\n  return std::make_shared<OpExprGradClosure>(op_grad_func_);\n}\n\nvoid FunctionOpExpr::reset_state() const { state_.reset(new FunctionAutoGradCaptureState); }\n\nMaybe<OpExprGradClosure> FunctionOpExpr::GetOrCreateOpGradClosure() const {\n  if (!op_grad_func_) {\n    op_grad_func_.reset(new FunctionOpExprGradFunction(func_name_, backward_fn_));\n  }\n  return std::make_shared<OpExprGradClosure>(op_grad_func_, state_);\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/op_expr.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_OP_EXPR_H_\n#define ONEFLOW_CORE_FRAMEWORK_OP_EXPR_H_\n\n#include <string>\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/autocast.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/user_op_conf.pb.h\"\n#include \"oneflow/core/framework/user_op_registry.h\"\n#include \"oneflow/core/framework/arg_tuple.h\"\n#include \"oneflow/core/autograd/autograd_function.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/framework/op_interpreter/dispatch_frame.h\"\n\nnamespace oneflow {\nnamespace one {\n\nclass OpExprGradFunctionIf;\nclass OpExprGradClosure;\n\nclass OpExpr {\n public:\n  virtual ~OpExpr() = default;\n  virtual const std::string& op_type_name() const = 0;\n\n  virtual int input_size() const = 0;\n  virtual int output_size() const = 0;\n\n  virtual Maybe<bool> IsGradDisabled() const = 0;\n  virtual Maybe<bool> SupportNonContiguous() const = 0;\n\n  virtual Maybe<OpExprGradClosure> GetOrCreateOpGradClosure() const = 0;\n\n  virtual Maybe<autocast::AutoCastMeta> GetOrCreateAutoCastMeta() const;\n\n protected:\n  OpExpr() = default;\n};\n\nclass BuiltinOpExpr : public OpExpr {\n public:\n  explicit BuiltinOpExpr(const std::string& op_name, const std::vector<std::string>& indexed_ibns,\n                         const std::vector<std::string>& indexed_obns);\n\n  virtual ~BuiltinOpExpr() = default;\n\n  const std::string& op_name() const { return op_name_; }\n\n  int input_size() const override { return input_arg_tuple_->size(); }\n  int output_size() const override { return output_arg_tuple_->size(); }\n\n  const std::shared_ptr<const ArgTuple>& input_arg_tuple() const { return input_arg_tuple_; }\n  const std::shared_ptr<const ArgTuple>& output_arg_tuple() const { return output_arg_tuple_; }\n\n  const std::vector<std::string>& indexed_ibns() const { return input_arg_tuple_->indexed_bns(); }\n  const std::vector<std::string>& indexed_obns() const { return output_arg_tuple_->indexed_bns(); }\n  const std::vector<std::pair<std::string, int32_t>>& indexed_input_pairs() const {\n    return input_arg_tuple_->indexed_arg_name_and_index();\n  }\n  const std::vector<std::pair<std::string, int32_t>>& indexed_output_pairs() const {\n    return output_arg_tuple_->indexed_arg_name_and_index();\n  }\n\n  virtual Maybe<void> BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const = 0;\n\n protected:\n  std::string op_name_;\n  std::shared_ptr<const ArgTuple> input_arg_tuple_;\n  std::shared_ptr<const ArgTuple> output_arg_tuple_;\n};\n\nclass 
TensorMeta;\n\ntemplate<typename ProtoType>\nclass BuiltinOpExprImpl : public BuiltinOpExpr {\n public:\n  static Maybe<BuiltinOpExprImpl<ProtoType>> New(const std::string& op_name, ProtoType&& op_proto,\n                                                 const std::vector<std::string>& indexed_ibns,\n                                                 const std::vector<std::string>& indexed_obns) {\n    return std::shared_ptr<BuiltinOpExprImpl<ProtoType>>(\n        new BuiltinOpExprImpl<ProtoType>(op_name, std::move(op_proto), indexed_ibns, indexed_obns));\n  }\n\n  virtual ~BuiltinOpExprImpl() = default;\n\n  const ProtoType& proto() const { return op_proto_; }\n  ProtoType* mutable_proto() { return &op_proto_; }\n\n  const std::string& op_type_name() const override;\n\n  Maybe<bool> IsGradDisabled() const override;\n\n  Maybe<bool> SupportNonContiguous() const override;\n\n  Maybe<OpExprGradClosure> GetOrCreateOpGradClosure() const override;\n\n  Maybe<autocast::AutoCastMeta> GetOrCreateAutoCastMeta() const override;\n\n  Maybe<void> BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const override;\n\n protected:\n  explicit BuiltinOpExprImpl(const std::string& op_name, ProtoType&& op_proto,\n                             const std::vector<std::string>& indexed_ibns,\n                             const std::vector<std::string>& indexed_obns)\n      : BuiltinOpExpr(op_name, indexed_ibns, indexed_obns), op_proto_(std::move(op_proto)) {}\n\n  ProtoType op_proto_;\n  mutable std::shared_ptr<OpExprGradFunctionIf> op_grad_func_;\n  mutable std::shared_ptr<autocast::AutoCastMeta> autocast_meta_;\n};\n\nclass StatefulOpKernel;\nclass LocalTensorInferCache;\nclass GlobalTensorInferCache;\n\nclass UserOpExpr final : public BuiltinOpExprImpl<UserOpConf> {\n public:\n  UserOpExpr() = delete;\n  virtual ~UserOpExpr() = default;\n\n  static Maybe<UserOpExpr> New(const std::string& op_name, UserOpConf&& op_proto,\n                               const std::vector<std::string>& indexed_ibns,\n                               const std::vector<std::string>& indexed_obns);\n\n  const AttrMap& base_attrs() const { return base_attrs_; }\n\n  Maybe<StatefulOpKernel> MutKernel4Stream(Symbol<Stream> stream) const;\n\n  bool has_device_and_stream_infer_fn() const {\n    return static_cast<bool>(device_and_stream_infer_fn_);\n  }\n  const user_op::DeviceAndStreamInferFn& device_and_stream_infer_fn() const {\n    return device_and_stream_infer_fn_;\n  }\n\n  bool IsHostMemoryInput(int32_t input_index) const {\n    return std::find(host_memory_input_ids_.begin(), host_memory_input_ids_.end(), input_index)\n           != host_memory_input_ids_.end();\n  }\n\n  Maybe<void> InferPhysicalTensorDesc(\n      const AttrMap& attrs, const std::string& device_tag,\n      const std::function<const TensorMeta*(int32_t)>& TensorMeta4InputIndex,\n      const std::function<TensorMeta*(int32_t)>& TensorMeta4OutputIndex) const;\n\n  Maybe<void> InferLogicalTensorDesc(\n      const AttrMap& attrs, Symbol<ParallelDesc> parallel_desc,\n      const std::function<const TensorMeta*(int32_t)>& TensorMeta4InputIndex,\n      const std::function<TensorMeta*(int32_t)>& TensorMeta4OutputIndex) const;\n  Maybe<Symbol<Stream>> InferDeviceAndStream(const AttrMap& attrs, const TensorTuple& inputs,\n                                             TensorTuple* outputs) const;\n  LocalTensorInferCache* mut_local_tensor_infer_cache() const {\n    return local_tensor_infer_cache_.get();\n  }\n  GlobalTensorInferCache* mut_global_tensor_infer_cache() const 
{\n    return global_tensor_infer_cache_.get();\n  }\n\n private:\n  UserOpExpr(const std::string& op_name, UserOpConf&& proto, const AttrMap& base_attrs,\n             const std::vector<std::string>& indexed_ibns,\n             const std::vector<std::string>& indexed_obns);\n  Maybe<void> Init(const std::shared_ptr<const UserOpExpr>& self);\n  AttrMap base_attrs_;\n  user_op::TensorDescInferFn logical_tensor_desc_infer_fn_;\n  user_op::TensorDescInferFn physical_tensor_desc_infer_fn_;\n  user_op::DataTypeInferFn dtype_infer_fn_;\n  user_op::DeviceAndStreamInferFn device_and_stream_infer_fn_;\n  mutable HashMap<Symbol<Stream>, std::shared_ptr<StatefulOpKernel>> stream2kernel_;\n  std::shared_ptr<LocalTensorInferCache> local_tensor_infer_cache_;\n  std::shared_ptr<GlobalTensorInferCache> global_tensor_infer_cache_;\n  small_vector<int32_t> host_memory_input_ids_;\n};\n\nclass GlobalToGlobalOpExpr : public OpExpr {\n public:\n  virtual ~GlobalToGlobalOpExpr() = default;\n\n  static Maybe<GlobalToGlobalOpExpr> New(const Optional<Symbol<NdSbp>>& grad_nd_sbp);\n\n  const Optional<Symbol<NdSbp>>& grad_nd_sbp() const { return grad_nd_sbp_; }\n  const std::string& op_type_name() const override;\n  int input_size() const override { return 1; }\n  int output_size() const override { return 1; }\n\n  Maybe<bool> IsGradDisabled() const override { return false; }\n  Maybe<bool> SupportNonContiguous() const override { return false; }\n  Maybe<OpExprGradClosure> GetOrCreateOpGradClosure() const override;\n\n protected:\n  GlobalToGlobalOpExpr(const Optional<Symbol<NdSbp>>& grad_nd_sbp);\n\n  Optional<Symbol<NdSbp>> grad_nd_sbp_;  //  Reserved for configuring grad sbp\n  mutable std::shared_ptr<OpExprGradFunctionIf> op_grad_func_;\n};\n\nclass CastGlobalOpExpr : public OpExpr {\n public:\n  virtual ~CastGlobalOpExpr() = default;\n\n  const std::string& op_name() const { return op_name_; }\n  int input_size() const override { return 1; }\n  int output_size() const override { return 1; }\n\n  Maybe<bool> IsGradDisabled() const override { return false; }\n  Maybe<bool> SupportNonContiguous() const override { return false; }\n\n protected:\n  CastGlobalOpExpr(const std::string& op_name);\n\n  std::string op_name_;\n  mutable std::shared_ptr<OpExprGradFunctionIf> op_grad_func_;\n};\n\nclass LocalToGlobalOpExpr final : public CastGlobalOpExpr {\n public:\n  ~LocalToGlobalOpExpr() = default;\n\n  static Maybe<LocalToGlobalOpExpr> New(const std::string& op_name);\n\n  const std::string& op_type_name() const override;\n  Maybe<OpExprGradClosure> GetOrCreateOpGradClosure() const override;\n\n private:\n  LocalToGlobalOpExpr(const std::string& op_name);\n};\n\nclass GlobalToLocalOpExpr final : public CastGlobalOpExpr {\n public:\n  ~GlobalToLocalOpExpr() = default;\n\n  static Maybe<GlobalToLocalOpExpr> New(const std::string& op_name);\n\n  const std::string& op_type_name() const override;\n  Maybe<OpExprGradClosure> GetOrCreateOpGradClosure() const override;\n\n private:\n  GlobalToLocalOpExpr(const std::string& op_name);\n};\n\n// NOTE(chengcheng): For Lazy nn.Graph Feed/Fetch EagerTensor to/from LazyTensor.\nusing FeedInputOpExpr = BuiltinOpExprImpl<FeedInputOpConf>;\nusing FeedVariableOpExpr = BuiltinOpExprImpl<FeedVariableOpConf>;\nusing FetchOutputOpExpr = BuiltinOpExprImpl<FetchOutputOpConf>;\n\n// NOTE(chengcheng): Special SystemOp for image gpu decode.\nusing ImageDecoderRandomCropResizeOpExpr = BuiltinOpExprImpl<ImageDecoderRandomCropResizeOpConf>;\n\nusing VariableOpExpr = 
BuiltinOpExprImpl<VariableOpConf>;\nusing CastToLocalOpExpr = BuiltinOpExprImpl<CastToLocalOpConf>;\nusing CastFromLocalOpExpr = BuiltinOpExprImpl<CastFromLocalOpConf>;\nusing DistributeSplitOpExpr = BuiltinOpExprImpl<DistributeSplitOpConf>;\nusing DistributeCloneOpExpr = BuiltinOpExprImpl<DistributeCloneOpConf>;\nusing DistributeConcatOpExpr = BuiltinOpExprImpl<DistributeConcatOpConf>;\nusing DistributeAddOpExpr = BuiltinOpExprImpl<DistributeAddOpConf>;\n\nclass SelectTopNOpExpr final : public OpExpr {\n public:\n  static Maybe<SelectTopNOpExpr> New() {\n    return std::shared_ptr<SelectTopNOpExpr>(new SelectTopNOpExpr());\n  }\n\n  const std::string& op_type_name() const override {\n    static const std::string kOpTypeName = \"select_top_n\";\n    return kOpTypeName;\n  }\n\n  int input_size() const override {\n    UNIMPLEMENTED();\n    return 0;\n  }\n\n  int output_size() const override {\n    // output should be resized in apply function\n    return 0;\n  }\n\n  Maybe<bool> IsGradDisabled() const override { return false; }\n\n  Maybe<bool> SupportNonContiguous() const override { return false; }\n\n  Maybe<OpExprGradClosure> GetOrCreateOpGradClosure() const override;\n\n private:\n  SelectTopNOpExpr() = default;\n\n  mutable std::shared_ptr<OpExprGradFunctionIf> op_grad_func_;\n};\n\nclass AutoGradCaptureState;\n\nclass FunctionOpExpr final : public OpExpr {\n public:\n  using FType = AutogradFunctionBase::FType;\n  FunctionOpExpr() = delete;\n  static Maybe<FunctionOpExpr> New(const std::string& func_name, const FType& forward_fn,\n                                   const FType& backward_fn) {\n    return std::shared_ptr<FunctionOpExpr>(new FunctionOpExpr(func_name, forward_fn, backward_fn));\n  }\n\n  const std::string& op_type_name() const override { return func_name_; }\n\n  int input_size() const override {\n    PRINT_BUG_PROMPT_AND_ABORT() << \"You cannot get input_size here.\";\n    return 0;\n  }\n  int output_size() const override {\n    PRINT_BUG_PROMPT_AND_ABORT() << \"You cannot get output_size here.\";\n    return 0;\n  }\n\n  FType forward() const { return forward_fn_; }\n  FType backward() const { return backward_fn_; }\n\n  std::shared_ptr<FunctionAutoGradCaptureState> state() const { return state_; }\n  void reset_state() const;\n\n  Maybe<bool> IsGradDisabled() const override { return false; }\n  Maybe<bool> SupportNonContiguous() const override { return false; }\n  Maybe<OpExprGradClosure> GetOrCreateOpGradClosure() const override;\n\n private:\n  FunctionOpExpr(const std::string& func_name, const FType& forward_fn, const FType& backward_fn)\n      : forward_fn_(forward_fn), backward_fn_(backward_fn), func_name_(func_name) {}\n\n  FType forward_fn_;\n  FType backward_fn_;\n  std::string func_name_;\n  mutable std::shared_ptr<FunctionAutoGradCaptureState> state_;\n  mutable std::shared_ptr<OpExprGradFunctionIf> op_grad_func_;\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_OP_EXPR_H_\n"
  },
  {
    "path": "oneflow/core/framework/op_expr_grad_function.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/framework/saved_tensor_hooks.h\"\n\nnamespace oneflow {\nnamespace one {\n\nvoid AutoGradCaptureState::unpack() {\n  if (saved_tensors_.empty() && !hooks_.empty()) {\n    for (const auto& hook : hooks_) { saved_tensors_.push_back(hook->unpack()); }\n    hooks_.clear();\n  }\n}\n\nsize_t AutoGradCaptureState::SaveTensorForBackward(const std::shared_ptr<Tensor>& tensor) {\n  auto hook = []() -> std::unique_ptr<SavedTensorHook> {\n    if (auto* hook_creator = Singleton<SavedTensorHookCreator>::Get()) {\n      return hook_creator->new_saved_tensor_hook();\n    }\n    return nullptr;\n  }();\n  if (hook) {\n    hook->pack(tensor);\n    size_t offset = hooks_.size();\n    hooks_.push_back(std::move(hook));\n    return offset;\n  } else {\n    size_t offset = saved_tensors_.size();\n    if (tensor->is_local() && tensor->is_eager()) {\n      if (auto rematable_storage = std::dynamic_pointer_cast<vm::RematableTensorStorage>(\n              CHECK_JUST(tensor->eager_blob_object())->tensor_storage())) {\n        rematable_storage->set_needed_by_backward();\n      }\n    }\n    saved_tensors_.emplace_back(tensor);\n    return offset;\n  }\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/op_expr_grad_function.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_FRAMEWORK_OP_EXPR_GRAD_FUNCTION_H_\n#define ONEFLOW_CORE_FRAMEWORK_OP_EXPR_GRAD_FUNCTION_H_\n\n#include \"oneflow/core/autograd/autograd_captured_tensor.h\"\n#include \"oneflow/core/common/auto_registration_factory.h\"\n#include \"oneflow/core/common/op_args_vector.h\"\n#include \"oneflow/core/framework/op_interpreter.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"oneflow/core/framework/saved_tensor_hooks.h\"\n\nnamespace oneflow {\nnamespace one {\n\nstatic constexpr char kGradientOpSuffix[] = \".grad\";\n\nclass AutoGradCaptureState {\n public:\n  AutoGradCaptureState() = default;\n  virtual ~AutoGradCaptureState() = default;\n\n  void unpack();\n\n  const TensorTuple& SavedTensors() const { return saved_tensors_; }\n\n  size_t SaveTensorForBackward(const std::shared_ptr<Tensor>& tensor);\n\n public:\n  std::vector<bool> input_requires_grad;\n\n protected:\n  TensorTuple saved_tensors_;\n  small_vector<std::unique_ptr<SavedTensorHook>, TensorTuple::kInitialSize> hooks_;\n};\n\nclass FunctionAutoGradCaptureState final\n    : public AutoGradCaptureState,\n      public std::enable_shared_from_this<FunctionAutoGradCaptureState> {\n public:\n  FunctionAutoGradCaptureState() : pyobj_ptr_(nullptr, [](void*) {}) {}\n  using AutoGradCaptureState::SavedTensors;\n  using AutoGradCaptureState::SaveTensorForBackward;\n\n  void MarkNonDifferentiable(const std::shared_ptr<Tensor>& tensor) {\n    non_differentiable_tensors_.emplace(tensor.get());\n  }\n\n  HashSet<Tensor*> NonDifferentiableTensors() const { return non_differentiable_tensors_; }\n\n  std::shared_ptr<FunctionAutoGradCaptureState> GetSharedFromThis() { return shared_from_this(); }\n\n  // NOTE(wyg): Hold PyOjbect ptr to ensure getting the same object when casting to python.\n  // And decrease the reference count when C++ object is destructed to avoid memory leaking.\n  void* pyobject() const { return pyobj_ptr_.get(); }\n  void set_pyobject_ptr(std::unique_ptr<void, void (*)(void*)>&& pyobj_ptr) {\n    pyobj_ptr_ = std::move(pyobj_ptr);\n  }\n\n public:\n  std::vector<bool> input_requires_grad;\n\n private:\n  HashSet<Tensor*> non_differentiable_tensors_;\n  std::unique_ptr<void, void (*)(void*)> pyobj_ptr_;\n};\n\n// Stateless container base of the backward op exprs.\n// The backward op exprs should be contained in the derived class.\nclass OpExprGradFunctionIf {\n public:\n  virtual ~OpExprGradFunctionIf() = default;\n\n  virtual std::shared_ptr<AutoGradCaptureState> MakeCustomState() const = 0;\n\n  virtual Maybe<void> Init(const OpExpr& op) = 0;\n\n  // Capture forward inputs and outputs for backward.\n  virtual Maybe<void> CaptureIf(AutoGradCaptureState* ctx, const TensorTuple& inputs,\n                                const TensorTuple& outputs,\n                                const OpExprInterpContext& interp_ctx) const = 0;\n\n  virtual Maybe<void> ApplyIf(const 
AutoGradCaptureState* ctx, const TensorTuple& out_grads,\n                              TensorTuple* in_grads) const = 0;\n};\n\ntemplate<typename StateT>\nclass OpExprGradFunction : public OpExprGradFunctionIf {\n public:\n  std::shared_ptr<AutoGradCaptureState> MakeCustomState() const override {\n    return std::make_shared<StateT>();\n  }\n\n  Maybe<void> CaptureIf(AutoGradCaptureState* ctx, const TensorTuple& inputs,\n                        const TensorTuple& outputs,\n                        const OpExprInterpContext& interp_ctx) const override {\n    StateT* state = dynamic_cast<StateT*>(ctx);\n    CHECK_NOTNULL_OR_RETURN(state);\n    // Convert inputs and outputs from `Tensor` to `AutogradCapturedTensor` to avoid\n    // circular reference between `Tensor` and `FunctionNode`.\n    OF_PROFILER_RANGE_PUSH(\"init inputs\");\n    TensorTuple captured_inputs(inputs.size());\n    for (int i = 0; i < inputs.size(); ++i) {\n      captured_inputs[i] = JUST(AutogradCapturedTensor::MakeTensor(inputs.at(i)));\n    }\n    OF_PROFILER_RANGE_POP();\n    OF_PROFILER_RANGE_PUSH(\"init outputs\");\n    TensorTuple captured_outputs(outputs.size());\n    for (int i = 0; i < outputs.size(); ++i) {\n      captured_outputs[i] = JUST(AutogradCapturedTensor::MakeTensor(outputs.at(i)));\n    }\n    OF_PROFILER_RANGE_POP();\n    OF_PROFILER_RANGE_GUARD(\"Capture\");\n    return Capture(state, captured_inputs, captured_outputs, interp_ctx);\n  }\n\n  Maybe<void> ApplyIf(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,\n                      TensorTuple* in_grads) const override {\n    const StateT* state = dynamic_cast<const StateT*>(ctx);\n    CHECK_NOTNULL_OR_RETURN(state);\n    return Apply(state, out_grads, in_grads);\n  }\n\n protected:\n  virtual Maybe<void> Capture(StateT* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                              const OpExprInterpContext& interp_ctx) const {\n    return Capture(ctx, inputs, outputs, interp_ctx.attrs);\n  }\n\n  virtual Maybe<void> Capture(StateT* ctx, const TensorTuple& inputs, const TensorTuple& outputs,\n                              const AttrMap& attrs) const {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n\n  virtual Maybe<void> Apply(const StateT* ctx, const TensorTuple& out_grads,\n                            TensorTuple* in_grads) const = 0;\n\n  std::string GradientOpName(const std::string& prefix) const {\n    return prefix + std::string(kGradientOpSuffix);\n  }\n};\n\nclass FunctionOpExprGradFunction final : public OpExprGradFunctionIf {\n public:\n  using FType = AutogradFunctionBase::FType;\n  FunctionOpExprGradFunction(const std::string& func_name, const FType& backward_fn)\n      : backward_fn_(backward_fn), op_name_(func_name) {}\n\n  std::shared_ptr<AutoGradCaptureState> MakeCustomState() const override {\n    PRINT_BUG_PROMPT_AND_ABORT()\n        << \"You should not construct AutoGradCaptureState by calling this function\";\n    return std::make_shared<FunctionAutoGradCaptureState>();\n  }\n\n  Maybe<void> Init(const OpExpr& op) override {\n    // do nothing\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> CaptureIf(AutoGradCaptureState* ctx, const TensorTuple& inputs,\n                        const TensorTuple& outputs,\n                        const OpExprInterpContext& interp_ctx) const override {\n    FunctionAutoGradCaptureState* func_ctx = dynamic_cast<FunctionAutoGradCaptureState*>(ctx);\n    func_ctx->input_requires_grad.resize(inputs.size());\n    for (int i = 0; i < inputs.size(); ++i) {\n      
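// Record which inputs require grad; ApplyIf later uses this mask to size in_grads and to\n      // validate the gradients returned by backward_fn_.\n      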
func_ctx->input_requires_grad[i] = inputs.at(i)->requires_grad();\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> ApplyIf(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,\n                      TensorTuple* in_grads) const override {\n    const FunctionAutoGradCaptureState* func_ctx =\n        dynamic_cast<const FunctionAutoGradCaptureState*>(ctx);\n    CHECK_NOTNULL_OR_RETURN(func_ctx);\n    const std::shared_ptr<TensorTuple>& out = backward_fn_(\n        const_cast<FunctionAutoGradCaptureState*>(func_ctx)->GetSharedFromThis(), out_grads);\n    in_grads->resize(func_ctx->input_requires_grad.size());\n    CHECK_EQ_OR_RETURN(out->size(), in_grads->size())\n        << \"RuntimeError: function \" << op_name_\n        << \" returned an incorrect number of gradients (expected \" << in_grads->size() << \", got \"\n        << out->size() << \")\";\n    for (int i = 0; i < in_grads->size(); ++i) {\n      if (func_ctx->input_requires_grad[i]) {\n        if (!out->at(i)) {\n          return Error::RuntimeError()\n                 << \"autograd.Function named \" << op_name_ << \"'s inputs[\" << i\n                 << \"] requires grad but got None grad. Please use Tensor.detach() for this \"\n                    \"input.\";\n        }\n        in_grads->at(i) = out->at(i);\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n\n protected:\n  FType backward_fn_;\n  std::string op_name_;\n};\n\n// Stateful wrapper of the `OpExprGradFunction`.\nclass OpExprGradClosure {\n public:\n  // Use `shared_ptr` in order to keep `impl` alive even if the forward op has been released.\n  explicit OpExprGradClosure(const std::shared_ptr<OpExprGradFunctionIf>& impl)\n      : OpExprGradClosure(impl, impl->MakeCustomState()) {}\n  explicit OpExprGradClosure(const std::shared_ptr<OpExprGradFunctionIf>& impl,\n                             const std::shared_ptr<AutoGradCaptureState>& state)\n      : impl_(impl), state_(state) {}\n\n  virtual ~OpExprGradClosure() = default;\n\n  Maybe<void> Capture(const TensorTuple& inputs, const TensorTuple& outputs,\n                      const OpExprInterpContext& interp_ctx) const {\n    return impl_->CaptureIf(state_.get(), inputs, outputs, interp_ctx);\n  }\n\n  Maybe<void> Apply(const TensorTuple& out_grads, TensorTuple* in_grads) const {\n    state_->unpack();\n    return impl_->ApplyIf(state_.get(), out_grads, in_grads);\n  }\n\n  const std::shared_ptr<AutoGradCaptureState>& state() const { return state_; }\n\n private:\n  std::shared_ptr<OpExprGradFunctionIf> impl_;\n  std::shared_ptr<AutoGradCaptureState> state_;\n};\n\n#define REGISTER_OP_EXPR_GRAD_FUNCTION(op_type, op_grad) \\\n  REGISTER_CLASS_CREATOR(std::string, op_type, OpExprGradFunctionIf, ([]() { return new op_grad; }))\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_OP_EXPR_GRAD_FUNCTION_H_\n"
  },
  {
    "path": "oneflow/core/framework/op_interpreter/dispatch_frame.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_interpreter/dispatch_frame.h\"\n#include <string>\n\nnamespace oneflow {\n\n/* static */ std::string* DispatchFrame::get_str_ptr() {\n  static thread_local std::string frame_str = \"\";\n  return &frame_str;\n}\n\n/* static */ const std::string& DispatchFrame::get_str() { return *get_str_ptr(); }\n\n/* static */ void DispatchFrame::set_str(const std::string& str) { *get_str_ptr() = str; }\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/op_interpreter/dispatch_frame.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_DISPATCH_FRAME_H_\n#define ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_DISPATCH_FRAME_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nclass DispatchFrame {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DispatchFrame);\n  DispatchFrame() = delete;\n  ~DispatchFrame() = delete;\n\n  static const std::string& get_str();\n  static void set_str(const std::string& str);\n\n  class Guard {\n   public:\n    explicit Guard(const std::string& frame_str) : prev_frame_str_(DispatchFrame::get_str()) {\n      DispatchFrame::set_str(frame_str);\n    }\n    ~Guard() { DispatchFrame::set_str(prev_frame_str_); }\n\n   private:\n    std::string prev_frame_str_;\n  };\n\n private:\n  static std::string* get_str_ptr();\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_DISPATCH_FRAME_H_\n"
  },
  {
    "path": "oneflow/core/framework/op_interpreter/eager_global_op_interpreter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/framework/op_interpreter.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/scope_util.h\"\n#include \"oneflow/core/framework/session_util.h\"\n#include \"oneflow/core/framework/symbol_storage_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_name_scope.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/global_tensor_infer_cache.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/autograd/autograd_mode.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter_mgr.h\"\n#include \"oneflow/user/kernels/stateful_opkernel.h\"\n#include \"oneflow/core/framework/consistency_check.h\"\n#include \"oneflow/core/framework/tensor_rpc_util.h\"\n#include \"oneflow/core/framework/tensor_global_id.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/boxing/eager_boxing_logger.h\"\n#include \"oneflow/core/common/cpp_attribute.h\"\n\nnamespace oneflow {\nnamespace one {\n\nnamespace {\n\nbool IsEnvEnableGlobalInputsWithInConsistentPlacement() {\n  const bool env_enable_inconsistent_placement =\n      ParseBooleanFromEnv(\"ONEFLOW_ENABLE_GLOBAL_INPUTS_WITH_INCONSISTENT_PLACEMENT\", false);\n  return env_enable_inconsistent_placement;\n}\n\nMaybe<bool> IsInputsParallelDescIdentical(\n    const std::shared_ptr<GlobalTensorMetaInferArgs>& infer_args) {\n  if (infer_args->input_global_tensor_metas().empty()) { return true; }\n  Symbol<ParallelDesc> default_parallel_desc =\n      JUST(VectorAt(infer_args->input_global_tensor_metas(), 0)).tensor_meta()->parallel_desc();\n\n  for (int i = 1; i < infer_args->input_global_tensor_metas().size(); ++i) {\n    const auto& parallel_desc = JUST(VectorAt(infer_args->input_global_tensor_metas(), i))\n                                    .tensor_meta()\n                                    ->parallel_desc()\n                                    ->data();\n    if (!default_parallel_desc->EqualsIgnoringDeviceType(parallel_desc)) { return false; }\n  }\n  return true;\n}\n\nconstexpr auto* IsAllInputsParallelDescIdentical =\n    DECORATE(&IsInputsParallelDescIdentical, ThreadLocalCopiable);\n\nMaybe<int> MaxRankNumber(Symbol<ParallelDesc> placement) {\n  // Find max rank number of a tensor's placement\n  // e.g. 
tensor's placement is [[0,1,2],[2,3,4],[7,8,9]]\n  // then max rank number is 9\n  return placement->sorted_machine_ids().back();\n}\n\nconstexpr auto* GetMaxRankNumber = DECORATE(&MaxRankNumber, ThreadLocalCachedCopiable);\n\nMaybe<Symbol<ParallelDesc>> MaxRankTensorPlacement(\n    const std::shared_ptr<GlobalTensorMetaInferArgs>& infer_args) {\n  // Find the max rank tensor id in all input tensors.\n  // e.g. if there are three tensors in the inputs\n  //        tensor        parallel_desc\n  // inputs[0] tensor a    [0, 1, 2]\n  // inputs[1] tensor b    [3, 4, 5]\n  // inputs[2] tensor c    [2, 3, 4]\n  // then max rank number is 5, max rank tensor is b, max rank tensor id is 1\n  const auto& global_tensor_metas = infer_args->input_global_tensor_metas();\n  CHECK_OR_RETURN(global_tensor_metas.size() > 0);  // NOLINT\n  int64_t max_rank_tensor_id = 0;\n  int64_t max_rank = 0;\n  for (int64_t i = 0; i < global_tensor_metas.size(); ++i) {\n    int64_t tensor_max_rank = JUST(\n        GetMaxRankNumber(JUST(VectorAt(global_tensor_metas, i)).tensor_meta()->parallel_desc()));\n    if (tensor_max_rank >= max_rank) {\n      max_rank = tensor_max_rank;\n      max_rank_tensor_id = i;\n    }\n  }\n  return JUST(VectorAt(global_tensor_metas, max_rank_tensor_id)).tensor_meta()->parallel_desc();\n}\n\nconstexpr auto* GetMaxRankTensorPlacement =\n    DECORATE(&MaxRankTensorPlacement, ThreadLocalCachedCopiable);\n\nMaybe<Symbol<ParallelDesc>> GetParallelDesc(const TensorTuple& inputs,\n                                            const OpExprInterpContext& ctx,\n                                            const UserOpExpr& user_op_expr) {\n  if (!inputs.empty()) {\n    for (int32_t i = 0; i < inputs.size(); ++i) {\n      if (!user_op_expr.IsHostMemoryInput(i)) { return inputs.at(i)->parallel_desc(); }\n    }\n  }\n  return JUST(ctx.parallel_desc);\n}\n\nstd::string GetDynamicOpGlobalFailedDebugString(const UserOpExpr& user_op_expr,\n                                                const StatefulOpKernel& kernel) {\n  CHECK(!kernel.output_tuple_indexes4mut2_obns().empty());\n  std::string plentysuffix = kernel.output_tuple_indexes4mut2_obns().size() == 1 ? 
\"s\" : \"\";\n  std::stringstream ss;\n  ss << \"operator `\" << user_op_expr.op_type_name() << \"`\"\n     << \" does not support global mode because the shape\" << plentysuffix << \" of output tensor\"\n     << plentysuffix << \" \";\n  int i = 0;\n  for (const auto& out_index : kernel.output_tuple_indexes4mut2_obns()) {\n    if (i++ > 0) { ss << \", \"; }\n    ss << out_index;\n  }\n  ss << \" are not infered before op computation.\";\n  return ss.str();\n}\n\nMaybe<bool> IsAllZeroSizeTensorMeta(const std::vector<Symbol<GlobalTensorMeta>>& tensor_metas) {\n  if (tensor_metas.empty()) { return false; }\n  for (const auto& tensor_meta : tensor_metas) {\n    if (tensor_meta->shape().elem_cnt() != 0) { return false; }\n  }\n  return true;\n}\n\nconstexpr auto* CachedIsAllZeroSizeTensorMeta =\n    DECORATE(&IsAllZeroSizeTensorMeta, ThreadLocalCopiable);\n\nMaybe<Tensor> CalcBoxingOutput(const std::shared_ptr<Tensor>& input, Symbol<NdSbp> out_nd_sbp,\n                               Symbol<ParallelDesc> out_parallel_desc,\n                               bool current_rank_local_is_valid) {\n  const auto& logical_shape = input->shape();\n  // If the input is a tensor of size 0, construct the output directly.\n  if (unlikely(logical_shape->elem_cnt() == 0)) {\n    GlobalTensorMeta tensor_meta(*logical_shape, input->dtype()->data_type(),\n                                 input->memory_format(), out_nd_sbp, out_parallel_desc);\n    const auto& tensor_impl =\n        JUST(EagerGlobalTensorImpl::New(SymbolOf(tensor_meta), input->requires_grad(), false));\n    std::shared_ptr<Tensor> output = std::make_shared<GlobalTensor>(tensor_impl);\n    return output;\n  }\n  const auto* mgr = Singleton<EagerBoxingInterpreterManager>::Get();\n  // Eager boxing\n  const auto& in_nd_sbp = JUST(input->nd_sbp());\n  const auto& in_parallel_desc = JUST(input->parallel_desc());\n  const auto& boxing_interpreter = JUST(mgr->GetEagerBoxingInterpreter(\n      in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc, *logical_shape));\n  Singleton<const EagerBoxingLogger>::Get()->Log(\n      *JUST(boxing_interpreter->boxing_interpreter_status()), /* prefix */ \"\");\n  if (!current_rank_local_is_valid) { return input; }\n  const auto& output = JUST(boxing_interpreter->Interpret(input, in_nd_sbp, out_nd_sbp,\n                                                          in_parallel_desc, out_parallel_desc));\n  return output;\n}\n\nauto* GetBoxingOutput =\n    DECORATE(DECORATE(&CalcBoxingOutput, CheckGlobalTensorMeta), DisableRecusiveBoxingCall);\n\nMaybe<void> Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,\n                      TensorTuple* outputs, const OpExprInterpContext& ctx) {\n  CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size());\n  Symbol<oneflow::ParallelDesc> parallel_desc = JUST(GetParallelDesc(inputs, ctx, user_op_expr));\n  std::shared_ptr<const GlobalTensorInferResult> result;\n  NonRecursiveMetaInfoConsistencyCheckScope scope;\n  // extand lifetime of boxing outputs to the end of this function\n  TensorTuple boxing_inputs = inputs;\n  if (inputs.empty()) {\n    // check consistency placement and nd_sbp, do not check in non-src op because it is assumed\n    // that InferSbp in op is a deterministic algorithm\n    JUST(MetaInfoConsistencyCheck(parallel_desc, ctx.nd_sbp, 1, /* force_check */ false));\n    const auto& infer_args =\n        JUST(SrcOpGlobalTensorMetaInferArgs::New(ctx.attrs, parallel_desc, JUST(ctx.nd_sbp)));\n    result = 
JUST(user_op_expr.mut_global_tensor_infer_cache()->GetOrInfer(*infer_args));\n  } else {\n    for (int i = 0; i < outputs->size(); ++i) {\n      if ((*outputs)[i]) {\n        const auto& nd_sbp = JUST((*outputs)[i]->nd_sbp());\n        JUST((*outputs)[i]->set_consumer_nd_sbp_constraint(nd_sbp));\n      }\n    }\n    std::shared_ptr<GlobalTensorMetaInferArgs> infer_args =\n        JUST(GlobalTensorMetaInferArgs::New(ctx.attrs, boxing_inputs));\n    // is_identical is true when all input tensors have the same parallel_desc\n    const bool is_identical = JUST(IsAllInputsParallelDescIdentical(infer_args));\n    // if is_identical is false and the env var\n    // 'ONEFLOW_ENABLE_GLOBAL_INPUTS_WITH_INCONSISTENT_PLACEMENT' is set to true, traverse all\n    // input tensors with GetBoxingOutput(); during this process, each tensor is cast to global\n    // with the target parallel_desc\n    if (IsEnvEnableGlobalInputsWithInConsistentPlacement() && !is_identical) {\n      parallel_desc = JUST(GetMaxRankTensorPlacement(infer_args));\n      Optional<int64_t> parallel_id;\n      JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, &parallel_id));\n      for (int i = 0; i < inputs.size(); ++i) {\n        const auto& input = inputs.at(i);\n        Optional<int64_t> input_parallel_id;\n        JUST(GetTensorDevice4CurrentProcessCtx(JUST(input->parallel_desc()), &input_parallel_id));\n        const auto& final_input =\n            JUST(GetBoxingOutput(input, JUST(inputs[i]->nd_sbp()), parallel_desc,\n                                 input_parallel_id.has_value() || parallel_id.has_value()));\n\n        boxing_inputs[i] = final_input;\n      }\n      infer_args = JUST(GlobalTensorMetaInferArgs::New(ctx.attrs, boxing_inputs));\n    }\n    result = JUST(user_op_expr.mut_global_tensor_infer_cache()->GetOrInfer(*infer_args));\n  }\n\n  const auto& output_tensor_metas = result->output_tensor_metas();\n  Optional<int64_t> parallel_id;\n  const auto& tensor_device = JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, &parallel_id));\n  for (int i = 0; i < outputs->size(); ++i) {\n    if (!outputs->at(i)) {\n      const auto& tensor_impl = JUST(EagerGlobalTensorImpl::New(\n          output_tensor_metas[i], tensor_device, parallel_id, false, false));\n      (*outputs)[i].reset(new GlobalTensor(tensor_impl));\n    } else {\n      JUST((*outputs)[i]->set_consumer_nd_sbp_constraint(NullOpt));\n    }\n  }\n  // Do nothing if output_tensors has 0-size shape. 
Some ops (e.g. flow.cat) have 0-size inputs but\n  // non-0-size outputs, so this cannot be judged from the inputs alone.\n  if (unlikely(JUST(CachedIsAllZeroSizeTensorMeta(output_tensor_metas)))) {\n    return Maybe<void>::Ok();\n  }\n  // Run instruction Call\n  const auto& kernel = JUST(user_op_expr.MutKernel4Stream(result->stream()));\n  CHECK_EQ_OR_RETURN(kernel->output_tuple_indexes4mut2_obns().size(), 0)\n      << Error::UnimplementedError() << GetDynamicOpGlobalFailedDebugString(user_op_expr, *kernel);\n\n  vm::EagerBlobObjectList input_eager_blob_objects(boxing_inputs.size());\n  // extend the lifetime of boxing outputs to the end of this function\n  TensorTuple boxing_outputs;\n  for (int i = 0; i < boxing_inputs.size(); ++i) {\n    std::shared_ptr<Tensor> input = boxing_inputs.at(i);\n    const auto& inferred_input_meta = result->input_tensor_metas().at(i);\n    const auto& input_parallel_desc = JUST(input->parallel_desc());\n    CHECK_OR_RETURN(input_parallel_desc == inferred_input_meta->parallel_desc());\n    bool is_host_input = user_op_expr.IsHostMemoryInput(i);\n    Symbol<ParallelDesc> dst_parallel_desc =\n        is_host_input\n            ? JUST(ReplaceDeviceType(inferred_input_meta->parallel_desc(), DeviceType::kCPU))\n            : inferred_input_meta->parallel_desc();\n    if ((input_parallel_desc->parallel_num() != 1\n         && inferred_input_meta->nd_sbp() != JUST(input->nd_sbp()))\n        || input_parallel_desc->device_type() != dst_parallel_desc->device_type()) {\n      input = JUST(GetBoxingOutput(input, inferred_input_meta->nd_sbp(), dst_parallel_desc,\n                                   parallel_id.has_value()));\n      boxing_outputs.emplace_back(input);\n    }\n    const auto& local_tensor = JUST(input->cur_rank_phy_tensor());\n    input_eager_blob_objects.at(i) = JUST(local_tensor->eager_blob_object());\n  }\n  // Do nothing if the `parallel_desc` doesn't cover the current ProcessCtx.\n  if (!parallel_id.has_value()) { return Maybe<void>::Ok(); }\n  vm::EagerBlobObjectList output_eager_blob_objects(outputs->size());\n  for (int i = 0; i < outputs->size(); ++i) {\n    const auto& local_tensor = JUST(outputs->at(i)->cur_rank_phy_tensor());\n    output_eager_blob_objects.at(i) = JUST(local_tensor->eager_blob_object());\n  }\n  if (tensor_device->enum_type() == DeviceType::kMeta) { return Maybe<void>::Ok(); }\n  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    return builder->Call(kernel, std::move(input_eager_blob_objects),\n                         std::move(output_eager_blob_objects), result, ctx, result->stream());\n  }));\n  return Maybe<void>::Ok();\n}\n\nauto* InterpretThenInitGlobalId = DECORATE(&Interpret, NonRecursiveInitGlobalId);\n\n}  // namespace\n\nMaybe<void> EagerGlobalInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs,\n                                              TensorTuple* outputs,\n                                              const OpExprInterpContext& ctx) const {\n  return InterpretThenInitGlobalId(op_expr, inputs, outputs, ctx);\n}\n\nMaybe<void> EagerGlobalInterpreter::ApplyImpl(const VariableOpExpr& op_expr,\n                                              const TensorTuple& inputs, TensorTuple* outputs,\n                                              const OpExprInterpContext& ctx) const {\n  OF_UNIMPLEMENTED();\n}\n\nnamespace {\n\nstatic constexpr auto* RecursiveGetBoxingOutput =\n    DECORATE(&CalcBoxingOutput, CheckGlobalTensorMeta);\n\nMaybe<void> RawGlobalToGlobal(const 
GlobalToGlobalOpExpr& op_expr, const TensorTuple& inputs,\n                              TensorTuple* outputs, const OpExprInterpContext& ctx) {\n  CHECK_EQ_OR_RETURN(inputs.size(), 1);\n  CHECK_EQ_OR_RETURN(outputs->size(), 1);\n  const auto& input = inputs.at(0);\n  CHECK_OR_RETURN(input->is_global());  // NOLINT\n  CHECK_OR_RETURN(ctx.parallel_desc.has_value());\n  CHECK_OR_RETURN(ctx.nd_sbp.has_value());\n  const auto& in_parallel_desc = JUST(input->parallel_desc());\n  const auto& out_nd_sbp = JUST(ctx.nd_sbp);\n  const auto& out_parallel_desc = JUST(ctx.parallel_desc);\n  const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(in_parallel_desc));\n  const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_parallel_desc));\n  const auto& tensor =\n      JUST(RecursiveGetBoxingOutput(input, out_nd_sbp, out_parallel_desc,\n                                    in_parallel_id->has_value() || out_parallel_id->has_value()));\n  CHECK_OR_RETURN(tensor);\n  if (out_parallel_id->has_value()) {\n    const auto& nd_sbp = JUST(tensor->nd_sbp());\n    const auto& parallel_desc = JUST(tensor->parallel_desc());\n    CHECK_OR_RETURN(nd_sbp == out_nd_sbp)\n        << \". nd_sbp: \" << NdSbpToString(nd_sbp) << \", out_nd_sbp: \" << NdSbpToString(out_nd_sbp);\n    CHECK_OR_RETURN(parallel_desc == out_parallel_desc);\n    outputs->at(0) = tensor;\n  } else {\n    GlobalTensorMeta tensor_meta(*tensor->shape(), tensor->dtype()->data_type(),\n                                 tensor->memory_format(), out_nd_sbp, out_parallel_desc);\n    const auto& tensor_impl =\n        JUST(EagerGlobalTensorImpl::New(SymbolOf(tensor_meta), tensor->requires_grad(), false));\n    (*outputs)[0].reset(new GlobalTensor(tensor_impl));\n  }\n  CHECK_OR_RETURN(outputs->at(0));\n  return Maybe<void>::Ok();\n}\n\nstatic constexpr auto* GlobalToGlobal = DECORATE(&RawGlobalToGlobal, NonRecursiveInitGlobalId);\n\n}  // namespace\n\nMaybe<void> EagerGlobalInterpreter::ApplyImpl(const GlobalToGlobalOpExpr& op_expr,\n                                              const TensorTuple& inputs, TensorTuple* outputs,\n                                              const OpExprInterpContext& ctx) const {\n  JUST(GlobalToGlobal(op_expr, inputs, outputs, ctx));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> EagerGlobalInterpreter::ApplyImpl(const LocalToGlobalOpExpr& op_expr,\n                                              const TensorTuple& inputs, TensorTuple* outputs,\n                                              const OpExprInterpContext& ctx) const {\n  OF_UNIMPLEMENTED();\n}\n\nMaybe<void> EagerGlobalInterpreter::ApplyImpl(const GlobalToLocalOpExpr& op_expr,\n                                              const TensorTuple& inputs, TensorTuple* outputs,\n                                              const OpExprInterpContext& ctx) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), 1);\n  const auto& input_tensor = inputs.at(0);\n  const auto& local_tensor = JUST(JUST(input_tensor->cur_rank_phy_tensor())->detach());\n  bool requires_grad = autograd::GradMode::is_enabled() && input_tensor->requires_grad();\n  JUST(local_tensor->set_requires_grad(requires_grad));\n  local_tensor->set_is_leaf(!requires_grad);\n  (*outputs)[0] = local_tensor;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> EagerGlobalInterpreter::ApplyImpl(const CastToLocalOpExpr& op_expr,\n                                              const TensorTuple& inputs, TensorTuple* outputs,\n                                              const OpExprInterpContext& ctx) const 
{\n  OF_UNIMPLEMENTED();\n}\n\nMaybe<void> EagerGlobalInterpreter::ApplyImpl(const CastFromLocalOpExpr& op_expr,\n                                              const TensorTuple& inputs, TensorTuple* outputs,\n                                              const OpExprInterpContext& ctx) const {\n  OF_UNIMPLEMENTED();\n}\n\nMaybe<void> EagerGlobalInterpreter::ApplyImpl(const DistributeSplitOpExpr& op_expr,\n                                              const TensorTuple& inputs, TensorTuple* outputs,\n                                              const OpExprInterpContext& ctx) const {\n  OF_UNIMPLEMENTED();\n}\n\nMaybe<void> EagerGlobalInterpreter::ApplyImpl(const DistributeCloneOpExpr& op_expr,\n                                              const TensorTuple& inputs, TensorTuple* outputs,\n                                              const OpExprInterpContext& ctx) const {\n  OF_UNIMPLEMENTED();\n}\n\nMaybe<void> EagerGlobalInterpreter::ApplyImpl(const DistributeConcatOpExpr& op_expr,\n                                              const TensorTuple& inputs, TensorTuple* outputs,\n                                              const OpExprInterpContext& ctx) const {\n  OF_UNIMPLEMENTED();\n}\n\nMaybe<void> EagerGlobalInterpreter::ApplyImpl(const DistributeAddOpExpr& op_expr,\n                                              const TensorTuple& inputs, TensorTuple* outputs,\n                                              const OpExprInterpContext& ctx) const {\n  OF_UNIMPLEMENTED();\n}\n\nMaybe<void> EagerGlobalInterpreter::ApplyImpl(const SelectTopNOpExpr& op_expr,\n                                              const TensorTuple& inputs, TensorTuple* outputs,\n                                              const OpExprInterpContext& ctx) const {\n  OF_UNIMPLEMENTED();\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/device_type.pb.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/op_interpreter.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/scope_util.h\"\n#include \"oneflow/core/framework/session_util.h\"\n#include \"oneflow/core/framework/symbol_storage_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_name_scope.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/local_tensor_infer_cache.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/memory/memory_case_util.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/user/kernels/stateful_opkernel.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n#include \"oneflow/core/autograd/autograd_mode.h\"\n#include \"oneflow/core/framework/placement_sbp_util.h\"\n#include \"oneflow/core/framework/tensor_rpc_util.h\"\n#include \"oneflow/core/framework/tensor_global_id.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/id_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n\nnamespace oneflow {\nnamespace one {\n\nnamespace {\n\nMaybe<Symbol<Device>> RawGetDefaultCpuDevice() { return Device::New(\"cpu\"); }\n\nconstexpr auto* GetDefaultCpuDevice = DECORATE(&RawGetDefaultCpuDevice, ThreadLocal);\n\nMaybe<Symbol<Device>> GetDefaultDevice(const TensorTuple& inputs, const OpExprInterpContext& ctx,\n                                       const UserOpExpr& user_op_expr) {\n  if (!inputs.empty()) {\n    for (int32_t i = 0; i < inputs.size(); ++i) {\n      if (!user_op_expr.IsHostMemoryInput(i)) { return JUST(inputs.at(i)->device()); }\n    }\n  }\n  if (ctx.device.has_value()) {\n    return JUST(ctx.device);\n  } else {\n    return GetDefaultCpuDevice();\n  }\n}\n\nMaybe<EagerLocalTensorImpl*> TensorImpl4Tensor(const std::shared_ptr<Tensor>& tensor) {\n  CHECK_OR_RETURN(static_cast<bool>(tensor));\n  return tensor->mut_eager_local_tensor_impl();\n}\n\n}  // namespace\n\nMaybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,\n                           TensorTuple* outputs, const OpExprInterpContext& ctx) {\n  OF_PROFILER_RANGE_GUARD(\"NaiveInterpret\");\n  CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size());  // NOLINT\n  Symbol<Device> default_device = JUST(GetDefaultDevice(inputs, ctx, user_op_expr));\n  const 
std::shared_ptr<const LocalTensorInferResult> result =\n      JUST([&]() -> Maybe<const LocalTensorInferResult> {\n        LocalTensorMetaInferArgs infer_args;\n        JUST(infer_args.Init(ctx.attrs, default_device, inputs));\n        return JUST(user_op_expr.mut_local_tensor_infer_cache()->GetOrInfer(infer_args));\n      }());\n\n  vm::EagerBlobObjectList input_eager_blob_objects(inputs.size());\n  // extend the lifetime of host_inputs to the end of this function\n  TensorTuple host_inputs;\n  for (int i = 0; i < inputs.size(); i++) {\n    if (user_op_expr.IsHostMemoryInput(i)) {\n      const auto& host_input = JUST(functional::To(\n          inputs.at(i), Optional<Symbol<Device>>(JUST(GetDefaultCpuDevice())), NullOpt, false));\n      input_eager_blob_objects.at(i) = JUST(host_input->eager_blob_object());\n      host_inputs.emplace_back(host_input);\n    } else {\n      input_eager_blob_objects.at(i) = JUST(inputs.at(i)->eager_blob_object());\n    }\n  }\n\n  const auto& output_tensor_metas = result->output_tensor_metas();\n  vm::EagerBlobObjectList output_eager_blob_objects(outputs->size());\n\n  const auto& kernel = JUST(user_op_expr.MutKernel4Stream(result->stream()));\n\n  for (int i = 0; i < outputs->size(); i++) {\n    if (!outputs->at(i)) {\n      // NOTE: if the op supports stride (non-contiguous input), the output tensor's stride\n      // should be inferred in InferLogicalTensorDesc.\n      // Otherwise, it will be set here (according to the shape).\n      std::shared_ptr<MutLocalTensorMeta> mut_tensor_meta;\n      {\n        if (kernel->output_is_mut2_type(i)) {\n          mut_tensor_meta = std::make_shared<MutLocalTensorMeta>(\n              output_tensor_metas.at(i)->shape(), output_tensor_metas.at(i)->stride(),\n              output_tensor_metas.at(i)->dtype(), output_tensor_metas.at(i)->memory_format(),\n              output_tensor_metas.at(i)->device());\n        }\n      }\n      std::shared_ptr<EagerLocalTensorImpl> tensor_impl =\n          std::make_shared<EagerLocalTensorImpl>(false, false);\n      const auto& dep_object = NewLocalDepObject();\n      JUST(\n          tensor_impl->InitEagerBlobObject(output_tensor_metas.at(i), mut_tensor_meta, dep_object));\n      output_eager_blob_objects.at(i) = JUST(tensor_impl->eager_blob_object());\n      (*outputs)[i] = std::make_shared<LocalTensor>(tensor_impl);\n    } else {\n      const auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));\n      // output i is updated in place.\n      // check the TensorMeta of the infer result against the TensorMeta of output i.\n      CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape()                                 // NOLINT\n                      == output_tensor_metas.at(i)->shape())                              // NOLINT\n          << Error::RuntimeError() << tensor_impl->tensor_meta()->shape().ToString()      // NOLINT\n          << \" vs \"                                                                       // NOLINT\n          << output_tensor_metas.at(i)->shape().ToString();                               // NOLINT\n      CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype()                                 // NOLINT\n                      == output_tensor_metas.at(i)->dtype())                              // NOLINT\n          << Error::RuntimeError() << DataType_Name(tensor_impl->tensor_meta()->dtype())  // NOLINT\n          << \" vs \"                                                                       // NOLINT\n          << DataType_Name(output_tensor_metas.at(i)->dtype());                           // 
NOLINT\n      bool has_eager_blob_object = JUST(outputs->at(i)->has_eager_blob_object());\n      CHECK_OR_RETURN(has_eager_blob_object);  // NOLINT\n      output_eager_blob_objects.at(i) = JUST(outputs->at(i)->eager_blob_object());\n      // TODO(zhaoluyang):(thread_local TensorMeta set stride then check)\n      // CHECK_OR_RETURN(tensor_impl->tensor_meta()->stride() ==\n      // output_tensor_metas->at(i)->stride());\n    }\n  }\n\n  if (default_device->enum_type() == DeviceType::kMeta) { return Maybe<void>::Ok(); }\n\n  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    return builder->Call(kernel, std::move(input_eager_blob_objects),\n                         std::move(output_eager_blob_objects), ctx, result->stream());\n  }));\n  for (int64_t index : kernel->output_tuple_indexes4mut2_obns()) {\n    const auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(index)));\n    auto btb = std::make_shared<BlockingThenBusy>();\n    JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n      return builder->SyncAccessBlobByCallback(\n          tensor_impl, btb, [](ep::Stream* stream, const std::shared_ptr<vm::EagerBlobObject>&) {},\n          \"const\");\n    }));\n    JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));\n    const auto& mut_tensor_meta = const_cast<EagerLocalTensorImpl*>(tensor_impl)->mut_tensor_meta();\n    Symbol<LocalTensorMeta> new_tensor_meta = SymbolOf(LocalTensorMeta(\n        mut_tensor_meta->shape(), mut_tensor_meta->stride(), mut_tensor_meta->dtype(),\n        mut_tensor_meta->memory_format(), mut_tensor_meta->device()));\n    std::shared_ptr<EagerLocalTensorImpl> final_tensor_impl =\n        std::make_shared<EagerLocalTensorImpl>(JUST(tensor_impl->tensor_storage()),\n                                               JUST(tensor_impl->storage_offset()), false, false);\n    JUST(final_tensor_impl->InitEagerBlobObject(\n        new_tensor_meta,\n        JUST(JUST(outputs->at(index)->eager_blob_object())->compute_local_dep_object())));\n    JUST(JUST(outputs->at(index)->AsLocalTensor())->set_impl(final_tensor_impl));\n  }\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> EagerLocalInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs,\n                                             TensorTuple* outputs,\n                                             const OpExprInterpContext& ctx) const {\n  return NaiveInterpret(op_expr, inputs, outputs, ctx);\n}\n\nMaybe<void> EagerLocalInterpreter::ApplyImpl(const VariableOpExpr& op_expr,\n                                             const TensorTuple& inputs, TensorTuple* outputs,\n                                             const OpExprInterpContext& ctx) const {\n  OF_UNIMPLEMENTED();\n}\n\nstatic Maybe<void> BuildAndRunLocalCastInstruction(const BuiltinOpExpr& op_expr,\n                                                   const TensorTuple& inputs,\n                                                   TensorTuple* outputs) {\n  // TODO()\n  OF_UNIMPLEMENTED();\n}\n\nnamespace {\n\nMaybe<one::UserOpExpr> EagerCclBroadcast(Symbol<ParallelDesc> parallel_desc, int64_t root,\n                                         size_t size, const std::vector<Shape>& shape_list) {\n  return one::OpBuilder(\"eager_ccl_broadcast\", *JUST(UniqueStr(\"eager_ccl_broadcast\")))\n      .Input(\"in\", size)\n      .Output(\"out\", size)\n      .Attr<std::string>(\"parallel_conf\", PbMessage2TxtString(parallel_desc->parallel_conf()))\n      
.Attr<std::vector<Shape>>(\"shape_list\", shape_list)\n      .Attr<int64_t>(\"root\", root)\n      .Build();\n}\n\nauto* CachedEagerCclBroadcastOpExpr = DECORATE(&EagerCclBroadcast, ThreadLocalCachedCopiable);\n\n}  // namespace\n\nMaybe<Tensor> Broadcast(const std::shared_ptr<Tensor>& tensor, int64_t src_rank,\n                        Symbol<ParallelDesc> parallel_desc, bool inplace) {\n  CHECK_OR_RETURN(parallel_desc->containing_current_rank());\n  if (parallel_desc->parallel_num() == 1 /* no broadcast */) { return tensor; }\n  std::shared_ptr<UserOpExpr> op_expr =\n      JUST(CachedEagerCclBroadcastOpExpr(parallel_desc, src_rank, 1, {*tensor->shape()}));\n  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"root\");\n  attrs.SetAllAttrs(src_rank);\n  if (inplace) {\n    TensorTuple outputs{tensor};\n    JUST(OpInterpUtil::Dispatch(*op_expr, {tensor}, &outputs,\n                                one::OpExprInterpContext(attrs, parallel_desc)));\n    return tensor;\n  } else {\n    return JUST(OpInterpUtil::Dispatch<one::Tensor>(\n        *op_expr, {tensor}, one::OpExprInterpContext(attrs, parallel_desc)));\n  }\n}\n\nMaybe<TensorTuple> Broadcast(const TensorTuple& inputs, int64_t src_rank,\n                             Symbol<ParallelDesc> parallel_desc, bool inplace) {\n  CHECK_OR_RETURN(parallel_desc->containing_current_rank())\n      << \"Current rank is not contained in the placement argument\";\n  if (parallel_desc->parallel_num() == 1 /* no broadcast */) { return inputs; }\n  std::vector<Shape> shape_list;\n  for (const auto& tensor : inputs) { shape_list.emplace_back(*tensor->shape()); }\n  std::shared_ptr<UserOpExpr> op_expr =\n      JUST(CachedEagerCclBroadcastOpExpr(parallel_desc, src_rank, inputs.size(), shape_list));\n  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"root\");\n  attrs.SetAllAttrs(src_rank);\n  if (inplace) {\n    auto outputs = std::make_shared<TensorTuple>(inputs);\n    JUST(OpInterpUtil::Dispatch(*op_expr, inputs, outputs.get(),\n                                one::OpExprInterpContext(attrs, parallel_desc)));\n    return outputs;\n  } else {\n    return JUST(OpInterpUtil::Dispatch<one::TensorTuple>(\n        *op_expr, inputs, one::OpExprInterpContext(attrs, parallel_desc)));\n  }\n}\n\nnamespace {\n\nMaybe<Tensor> GetSyncedTensorIfBroadcast(const std::shared_ptr<Tensor>& tensor,\n                                         Symbol<ParallelDesc> parallel_desc, Symbol<NdSbp> nd_sbp,\n                                         bool inplace) {\n  Optional<int64_t> parallel_id;\n  JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, &parallel_id));\n  if (!parallel_id.has_value()) { return tensor; }\n  const auto& broadcast_parallel_desc = JUST(GetBroadcastSubParallelDesc(parallel_desc, nd_sbp));\n  int64_t root = JUST(broadcast_parallel_desc->MachineId4ParallelId(0));\n  if (broadcast_parallel_desc->parallel_num() > 1 && inplace && GlobalProcessCtx::Rank() == 0) {\n    LOG_FIRST_N(WARNING, 1)\n        << \"Casting a local tensor to a global tensor with Broadcast sbp will modify the data of \"\n           \"input! 
\"\n           \"If you want to keep the input local tensor unchanged, please set the arg copy to True.\";\n  }\n  return Broadcast(tensor, root, broadcast_parallel_desc, inplace);\n}\n\nMaybe<Shape> CalcPhysicalShape(Symbol<GlobalTensorMeta> global_tensor_meta) {\n  const auto& opt_parallel_id =\n      JUST(GetParallelId4CurrentProcessCtx(global_tensor_meta->parallel_desc()));\n  int64_t parallel_id = JUST(*opt_parallel_id);\n  return GetPhysicalShape(global_tensor_meta->shape(), *global_tensor_meta->nd_sbp(),\n                          *global_tensor_meta->parallel_desc(), parallel_id);\n}\n\nstatic constexpr auto* GetPhysicalShape = DECORATE(&CalcPhysicalShape, ThreadLocal);\n\nMaybe<Tensor> TryReshapeTensor(const std::shared_ptr<Tensor>& tensor,\n                               Symbol<GlobalTensorMeta> global_tensor_meta) {\n  CHECK_OR_RETURN(tensor->is_local());\n  const auto& physical_shape = JUST(GetPhysicalShape(global_tensor_meta));\n  if (*physical_shape == *tensor->shape()) { return tensor; }\n  CHECK_EQ_OR_RETURN(physical_shape->elem_cnt(), tensor->shape()->elem_cnt());\n  // TODO(lixinqi) inplace reshape.\n  return tensor;\n}\n\n}  // namespace\n\nMaybe<void> EagerLocalInterpreter::ApplyImpl(const GlobalToGlobalOpExpr& op_expr,\n                                             const TensorTuple& inputs, TensorTuple* outputs,\n                                             const OpExprInterpContext& ctx) const {\n  OF_UNIMPLEMENTED();\n}\n\nnamespace {\n\nMaybe<void> RawLocalToGlobal(const LocalToGlobalOpExpr& op_expr, const TensorTuple& inputs,\n                             TensorTuple* outputs, const OpExprInterpContext& ctx) {\n  std::shared_ptr<LocalTensor> input_local_tensor;\n  {\n    CHECK_EQ_OR_RETURN(inputs.size(), 1);\n    CHECK_OR_RETURN(!inputs[0]->is_global());  // NOLINT\n    const auto& input_tensor = JUST(inputs.at(0)->detach());\n    input_local_tensor = JUST(input_tensor->AsLocalTensor());\n    CHECK_OR_RETURN(input_local_tensor)\n        << Error::InvalidValueError() << \"Tensor Cast Error\";  // NOLINT\n    bool requires_grad = autograd::GradMode::is_enabled() && inputs.at(0)->requires_grad();\n    JUST(input_local_tensor->set_requires_grad(requires_grad));\n    input_local_tensor->set_is_leaf(!requires_grad);\n  }\n  std::shared_ptr<GlobalTensor> global_tensor;\n  {\n    CHECK_OR_RETURN(ctx.parallel_desc.has_value());\n    CHECK_OR_RETURN(ctx.nd_sbp.has_value());\n    const auto& nd_sbp = JUST(ctx.nd_sbp);\n    const auto& parallel_desc = JUST(ctx.parallel_desc);\n    const auto& logical_shape = JUST(ctx.attrs.GetAttr<Shape>(\"shape\"));\n    DataType dtype = JUST(ctx.attrs.GetAttr<DataType>(\"dtype\"));\n    // MemoryFormat memory_format = JUST(ctx.attrs.GetAttr<MemoryFormat>(\"memory_format\"));\n    GlobalTensorMeta tensor_meta(logical_shape, dtype, MemoryFormat::kContiguous, nd_sbp,\n                                 parallel_desc);\n    Optional<int64_t> parallel_id{};\n    const auto& device = JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, &parallel_id));\n    const auto& global_tensor_impl = JUST(EagerGlobalTensorImpl::New(\n        SymbolOf(tensor_meta), device, parallel_id, input_local_tensor->requires_grad(),\n        !input_local_tensor->requires_grad()));\n    global_tensor = std::make_shared<GlobalTensor>(global_tensor_impl);\n    if (parallel_id.has_value()) {\n      const auto& pyhsical_shape = JUST(GetPhysicalShape(tensor_meta));\n      const auto& input_local_tensor_shape = input_local_tensor->shape();\n      
CHECK_EQ_OR_RETURN(*physical_shape, *input_local_tensor_shape);      // NOLINT\n      CHECK_OR_RETURN(dtype == input_local_tensor->dtype()->data_type());  // NOLINT\n      global_tensor_impl->reset_cur_rank_phy_tensor(input_local_tensor);\n    }\n  }\n  (*outputs)[0] = global_tensor;\n  return Maybe<void>::Ok();\n}\n\nstatic constexpr auto* LocalToGlobal = DECORATE(&RawLocalToGlobal, NonRecursiveInitGlobalId);\n\n}  // namespace\n\nMaybe<void> EagerLocalInterpreter::ApplyImpl(const LocalToGlobalOpExpr& op_expr,\n                                             const TensorTuple& inputs, TensorTuple* outputs,\n                                             const OpExprInterpContext& ctx) const {\n  bool sync_data = JUST(ctx.attrs.GetAttr<bool>(\"sync_data\"));\n  JUST(LocalToGlobal(op_expr, inputs, outputs, ctx));\n  const auto& global_tensor = JUST((*outputs)[0]->AsGlobalTensor());\n  JUST(WithConsistencyChecked(global_tensor, [&]() -> Maybe<void> {\n    if (IsGlobalTensorMetaCheckDisabled()) { return Maybe<void>::Ok(); }\n    const auto& parallel_desc = JUST(ctx.parallel_desc);\n    const auto& parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc));\n    if (!parallel_id->has_value()) { return Maybe<void>::Ok(); }\n    const auto& nd_sbp = JUST(ctx.nd_sbp);\n    const auto& tensor_meta = JUST(global_tensor->global_tensor_meta());\n    const auto& local_tensor = JUST(global_tensor->cur_rank_phy_tensor());\n    const auto& reshaped_tensor = JUST(TryReshapeTensor(local_tensor, tensor_meta));\n    std::shared_ptr<Tensor> synced_tensor = reshaped_tensor;\n    if (sync_data) {\n      bool inplace = JUST(ctx.attrs.GetAttr<bool>(\"inplace_when_sync_data\"));\n      synced_tensor =\n          JUST(GetSyncedTensorIfBroadcast(reshaped_tensor, parallel_desc, nd_sbp, inplace));\n    }\n    auto* global_tensor_impl = reinterpret_cast<EagerGlobalTensorImpl*>(global_tensor->mut_impl());\n    CHECK_NOTNULL_OR_RETURN(global_tensor_impl);\n    global_tensor_impl->reset_cur_rank_phy_tensor(JUST(synced_tensor->AsLocalTensor()));\n    return Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> EagerLocalInterpreter::ApplyImpl(const GlobalToLocalOpExpr& op_expr,\n                                             const TensorTuple& inputs, TensorTuple* outputs,\n                                             const OpExprInterpContext& ctx) const {\n  OF_UNIMPLEMENTED();\n}\n\nMaybe<void> EagerLocalInterpreter::ApplyImpl(const CastToLocalOpExpr& op_expr,\n                                             const TensorTuple& inputs, TensorTuple* outputs,\n                                             const OpExprInterpContext& ctx) const {\n  return BuildAndRunLocalCastInstruction(op_expr, inputs, outputs);\n}\n\nMaybe<void> EagerLocalInterpreter::ApplyImpl(const CastFromLocalOpExpr& op_expr,\n                                             const TensorTuple& inputs, TensorTuple* outputs,\n                                             const OpExprInterpContext& ctx) const {\n  return BuildAndRunLocalCastInstruction(op_expr, inputs, outputs);\n}\n\nstatic Maybe<void> BuildAndRunDistributeSplitOrCloneInstruction(const BuiltinOpExpr& op_expr,\n                                                                const TensorTuple& inputs,\n                                                                TensorTuple* outputs) {\n  // TODO()\n  OF_UNIMPLEMENTED();\n}\n\nMaybe<void> EagerLocalInterpreter::ApplyImpl(const DistributeSplitOpExpr& op_expr,\n                                             const TensorTuple& 
inputs, TensorTuple* outputs,\n                                             const OpExprInterpContext& ctx) const {\n  return BuildAndRunDistributeSplitOrCloneInstruction(op_expr, inputs, outputs);\n}\n\nMaybe<void> EagerLocalInterpreter::ApplyImpl(const DistributeCloneOpExpr& op_expr,\n                                             const TensorTuple& inputs, TensorTuple* outputs,\n                                             const OpExprInterpContext& ctx) const {\n  return BuildAndRunDistributeSplitOrCloneInstruction(op_expr, inputs, outputs);\n}\n\nstatic Maybe<void> BuildAndRunDistributeConcatAndAddInstruction(const BuiltinOpExpr& op_expr,\n                                                                const TensorTuple& inputs,\n                                                                TensorTuple* outputs) {\n  // TODO()\n  OF_UNIMPLEMENTED();\n}\n\nMaybe<void> EagerLocalInterpreter::ApplyImpl(const DistributeConcatOpExpr& op_expr,\n                                             const TensorTuple& inputs, TensorTuple* outputs,\n                                             const OpExprInterpContext& ctx) const {\n  return BuildAndRunDistributeConcatAndAddInstruction(op_expr, inputs, outputs);\n}\n\nMaybe<void> EagerLocalInterpreter::ApplyImpl(const DistributeAddOpExpr& op_expr,\n                                             const TensorTuple& inputs, TensorTuple* outputs,\n                                             const OpExprInterpContext& ctx) const {\n  return BuildAndRunDistributeConcatAndAddInstruction(op_expr, inputs, outputs);\n}\n\nMaybe<void> EagerLocalInterpreter::ApplyImpl(const SelectTopNOpExpr& op_expr,\n                                             const TensorTuple& inputs, TensorTuple* outputs,\n                                             const OpExprInterpContext& ctx) const {\n  int top_n = JUST(ctx.attrs.GetAttr<int32_t>(\"top_n\"));\n  outputs->resize(top_n);\n  for (int i = 0; i < top_n; ++i) { (*outputs)[i] = JUST(JUST(VectorAt(inputs, i))->detach()); }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/op_interpreter/eager_local_op_interpreter.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/eager/eager_blob_object.h\"\n\nnamespace oneflow {\n\nclass Device;\nclass TensorTuple;\nclass ParallelDesc;\n\nnamespace one {\n\nclass Tensor;\n\nMaybe<Tensor> Broadcast(const std::shared_ptr<Tensor>& tensor, int64_t src_rank,\n                        Symbol<ParallelDesc> parallel_desc, bool inplace);\n\nMaybe<TensorTuple> Broadcast(const TensorTuple& inputs, int64_t src_rank,\n                             Symbol<ParallelDesc> parallel_desc, bool inplace);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_interpreter/lazy_op_interpreter.h\"\n\n#include <memory>\n#include \"oneflow/core/common/cpp_attribute.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/cpp_attribute.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/consistency_check.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/multi_client_session_context.h\"\n#include \"oneflow/core/framework/op_interpreter.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/scope_util.h\"\n#include \"oneflow/core/framework/session_util.h\"\n#include \"oneflow/core/framework/symbol_storage_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_name_scope.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/user_op_registry.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx_mgr.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nnamespace {\n\nMaybe<Tensor> BuildTensor(const OpAttribute& op_attribute, const std::string& bn_in_op,\n                          const std::shared_ptr<ParallelDesc>& parallel_desc, const bool is_lazy,\n                          const bool is_local) {\n  CHECK_OR_RETURN(op_attribute.has_logical_blob_desc_signature());  // NOLINT(maybe-need-error-msg)\n  const auto& blob_desc_sign_map = op_attribute.logical_blob_desc_signature().bn_in_op2blob_desc();\n  auto blob_desc_it = blob_desc_sign_map.find(bn_in_op);\n  CHECK_OR_RETURN(blob_desc_it != blob_desc_sign_map.end())\n      << \"blob_desc of \" << bn_in_op << \" not found in op \" << op_attribute.op_conf().name();\n\n  auto shape = std::make_shared<Shape>(blob_desc_it->second.shape());\n  auto stride = std::make_shared<Stride>(shape);\n  auto dtype = blob_desc_it->second.data_type();\n  auto memory_format = blob_desc_it->second.memory_format();\n  if (is_local) {\n    const auto& device = JUST(Device::MakeDeviceByParallelDesc(*parallel_desc));\n    const auto& tensor =\n        JUST(LocalTensor::MakeTensor(shape, stride, dtype, memory_format, device, is_lazy,\n                                     /* requires_grad= */ false, /* is_leaf= */ true));\n    return static_cast<std::shared_ptr<Tensor>>(tensor);\n  } else {\n    const auto& nd_sbp_sign_map = op_attribute.nd_sbp_signature().bn_in_op2nd_sbp();\n    auto nd_sbp_it 
= nd_sbp_sign_map.find(bn_in_op);\n    CHECK_OR_RETURN(nd_sbp_it != nd_sbp_sign_map.end())\n        << \"nd_sbp of \" << bn_in_op << \" not found in op \" << op_attribute.op_conf().name();\n    NdSbp nd_sbp(nd_sbp_it->second);\n    const auto& tensor = JUST(GlobalTensor::MakeTensor(\n        shape, dtype, memory_format, SymbolOf(nd_sbp), SymbolOf(*parallel_desc), is_lazy,\n        /*requires_grad=*/false, /*is_leaf=*/true));\n    return static_cast<std::shared_ptr<Tensor>>(tensor);\n  }\n}\n\nMaybe<void> CheckTensorMatchAttr(const std::shared_ptr<Tensor>& tensor,\n                                 const OpAttribute& op_attribute, const std::string& bn_in_op,\n                                 const std::shared_ptr<ParallelDesc>& parallel_desc,\n                                 const bool is_local) {\n  CHECK_EQ_OR_RETURN(tensor->is_local(), is_local);  // NOLINT(maybe-need-error-msg)\n\n  CHECK_OR_RETURN(op_attribute.has_logical_blob_desc_signature());  // NOLINT(maybe-need-error-msg)\n  const auto& blob_desc_sign_map = op_attribute.logical_blob_desc_signature().bn_in_op2blob_desc();\n  auto blob_desc_it = blob_desc_sign_map.find(bn_in_op);\n  CHECK_OR_RETURN(blob_desc_it != blob_desc_sign_map.end())\n      << \"blob_desc of \" << bn_in_op << \" not found in op \" << op_attribute.op_conf().name();\n\n  auto shape = std::make_shared<Shape>(blob_desc_it->second.shape());\n  auto dtype = blob_desc_it->second.data_type();\n  CHECK_EQ_OR_RETURN(*tensor->shape(), *shape);             // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(tensor->dtype()->data_type(), dtype);  // NOLINT(maybe-need-error-msg)\n\n  if (is_local) {\n    const auto& device = JUST(Device::MakeDeviceByParallelDesc(*parallel_desc));\n    CHECK_OR_RETURN(JUST(tensor->device()) == device);  // NOLINT(maybe-need-error-msg)\n  } else {\n    const auto& nd_sbp_sign_map = op_attribute.nd_sbp_signature().bn_in_op2nd_sbp();\n    auto nd_sbp_it = nd_sbp_sign_map.find(bn_in_op);\n    CHECK_OR_RETURN(nd_sbp_it != nd_sbp_sign_map.end())\n        << \"nd_sbp of \" << bn_in_op << \" not found in op \" << op_attribute.op_conf().name();\n    // Only check the nd_sbp if auto parallel is not enabled,\n    // since the semi-auto parallelism rule might be inconsistent with the auto-parallel\n    // strategy.\n    if (!GlobalJobDesc().enable_auto_parallel()) {\n      NdSbp nd_sbp(nd_sbp_it->second);\n      CHECK_OR_RETURN(JUST(tensor->nd_sbp()) == SymbolOf(nd_sbp))\n          << \"The input sbp is not valid for an inplace operation, please try to use non-inplace. \"\n          << NdSbpToString(JUST(tensor->nd_sbp())) << \" vs \" << NdSbpToString(nd_sbp);\n    }\n    CHECK_OR_RETURN(JUST(tensor->parallel_desc())  // NOLINT(maybe-need-error-msg)\n                    == SymbolOf(*parallel_desc));  // NOLINT(maybe-need-error-msg)\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<const std::string&> GetDeviceTagOfTensor(const std::shared_ptr<Tensor>& tensor) {\n  if (tensor->is_global()) { return JUST(tensor->parallel_desc())->device_tag(); }\n  return JUST(tensor->device())->type();\n}\n\nbool GetIsDynamicOfTensor(const std::shared_ptr<Tensor>& tensor) { return !tensor->is_global(); }\n\nMaybe<void> GenNdSbpByTensor(NdSbp* nd_sbp, const std::shared_ptr<Tensor>& tensor) {\n  nd_sbp->clear_sbp_parallel();\n  if (tensor->is_local()) {\n    // NOTE(chengcheng):\n    //   OneFlow Lazy is always global. 
LocalTensor is a special case of GlobalTensor\n    //   whose placement is only this rank and whose SbpParallel is Broadcast.\n    nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel();\n  } else {\n    *nd_sbp = *JUST(tensor->nd_sbp());\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GenVariableOpConfNdSbpStringByTensor(VariableOpConf* var_conf,\n                                                 const std::shared_ptr<Tensor>& tensor) {\n  var_conf->clear_nd_sbp();\n  if (tensor->is_local()) {\n    SbpParallel broadcast;\n    broadcast.mutable_broadcast_parallel();\n    var_conf->add_nd_sbp(SbpParallelToString(broadcast));\n  } else {\n    const NdSbp& nd_sbp = *JUST(tensor->nd_sbp());\n    for (const auto& sbp_parallel : nd_sbp.sbp_parallel()) {\n      var_conf->add_nd_sbp(SbpParallelToString(sbp_parallel));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<const ParallelDesc> GetParallelDescOfTensor(const std::shared_ptr<Tensor>& tensor) {\n  if (tensor->is_local()) {\n    const auto& device = JUST(tensor->device());\n    const auto& placement = JUST(Placement4Device(device));\n    return placement.shared_from_symbol();\n  } else {\n    return JUST(tensor->parallel_desc()).shared_from_symbol();\n  }\n}\n\nMaybe<Scope> NewScopeWithParallelConfAndCurScope(const ParallelConf& parallel_conf) {\n  std::shared_ptr<Scope> new_scope;\n  const auto& old_scope = JUST(GetCurrentScope());\n  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    new_scope = JUST(builder->BuildScopeWithNewParallelConf(old_scope, parallel_conf));\n    return Maybe<void>::Ok();\n  }));\n  // NOTE(chengcheng): need to sync the vm to get the scope right now\n  JUST(vm::CurrentRankSync());\n  CHECK_OR_RETURN(new_scope);  // NOLINT(maybe-need-error-msg)\n  return new_scope;\n}\n\nMaybe<Scope> NewScopeWithParallelDescByTensor(const std::shared_ptr<Tensor>& tensor) {\n  return NewScopeWithParallelConfAndCurScope(\n      JUST(GetParallelDescOfTensor(tensor))->parallel_conf());\n}\n\nMaybe<int32_t> GetGradAccStep() {\n  const auto& infer_ctx = JUST(GetCurInferCtx());\n  const auto& job_conf = infer_ctx->job().job_conf();\n  if (job_conf.has_train_conf() && job_conf.has_num_gradient_accumulation_steps()\n      && job_conf.num_gradient_accumulation_steps() > 1) {\n    return job_conf.num_gradient_accumulation_steps();\n  } else {\n    return 1;\n  }\n}\n\nMaybe<void> AddFreeEagerTensorToVariableOp(const std::shared_ptr<Tensor>& input_tensor) {\n  if (!input_tensor->is_contiguous()) {\n    LazyMode::Guard lazy_mode_disabled_guard(false);\n    JUST(functional::InplaceToContiguous(input_tensor));\n    JUST(vm::CurrentRankSync());\n  }\n\n  CHECK_OR_RETURN(input_tensor->is_eager());  // NOLINT(maybe-need-error-msg)\n  const std::string& empty_lbn = TensorNameScope::Global()->Lookup(input_tensor);\n  CHECK_OR_RETURN(empty_lbn.empty());  // NOLINT(maybe-need-error-msg)\n  std::shared_ptr<Scope> scope = JUST(NewScopeWithParallelDescByTensor(input_tensor));\n  OperatorConf op_conf;\n  op_conf.set_scope_symbol_id(JUST(scope->symbol_id()));\n  op_conf.set_device_tag(JUST(GetDeviceTagOfTensor(input_tensor)));\n  VariableOpConf* var_conf = op_conf.mutable_variable_conf();\n  var_conf->set_out(\"out\");\n  input_tensor->shape()->ToProto(var_conf->mutable_shape());\n  var_conf->set_data_type(input_tensor->dtype()->data_type());\n  // NOTE(chengcheng): VariableOpConf initializer_conf is useless because variable is inited\n  //   by EagerTensor.\n  var_conf->mutable_initializer()->mutable_empty_conf();\n  
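// Derive the variable op's nd_sbp from the tensor itself below: Broadcast for a local\n  // tensor, or the tensor's own nd_sbp for a global one.\n  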
JUST(GenVariableOpConfNdSbpStringByTensor(var_conf, input_tensor));\n  // NOTE(chengcheng): Free EagerTensor is not trainable\n  var_conf->set_trainable(false);\n\n  auto infer_ctx = JUST(GetCurInferCtx());\n  // NOTE(chengcheng): MUST reset unique op name before InferCtx::AddOp; FreeEagerTensor has no\n  //  name, so just create a unique name for it.\n  const std::string new_op_name = *JUST(infer_ctx->NewUniqueOpNameByFunctionalOpConf(op_conf));\n  op_conf.set_name(new_op_name);\n\n  VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name() << \" try to add op: \\n\"\n          << op_conf.DebugString() << std::endl;\n  OpAttribute op_attr = *JUST(infer_ctx->AddAndInferGlobalOp(op_conf));\n  VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name() << \" add op : \\n\"\n          << op_conf.name() << \" for FreeEagerTensor.\\n\";\n  VLOG(3) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name()\n          << \" infer and add op attr : \\n\"\n          << op_attr.DebugString() << \" for FreeEagerTensor.\\n\";\n\n  // NOTE(chengcheng): MUST store this tensor to MultiClientSessionContext for graph runtime bind.\n  const std::string graph_name = *JUST(JUST(GlobalJobBuildAndInferCtxMgr())->GetCurrentJobName());\n  const std::string lbn = GenLogicalBlobName(new_op_name, \"out\");\n  Singleton<MultiClientSessionContext>::Get()->StoreFreeEagerTensorWithNameByGraphName(\n      graph_name, input_tensor, new_op_name);\n\n  int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(op_conf));\n  auto blob_parallel_desc = JUST(GetSymbol<ParallelDesc>(parallel_desc_sym_id));\n\n  auto var_tensor = JUST(BuildTensor(op_attr, \"out\", blob_parallel_desc, /* is_lazy= */ true,\n                                     /* is_local= */ input_tensor->is_local()));\n  TensorNameScope::Global()->Record(var_tensor, lbn);\n\n  // NOTE(chengcheng): MUST record this eager_tensor name as new variable output lbn.\n  // NOTE(chengcheng): in GradAcc, FreeEagerTensor needs to insert a repeat op, but there is no need\n  //  to create a new tensor for the repeat op out. We just set the repeat lbn as this free eager\n  //  tensor's lbn.\n  auto repeat_tensor = JUST(GradAccTryInsertRepeatAfterVar(var_tensor));\n  const std::string& repeat_tensor_name = TensorNameScope::Global()->Lookup(repeat_tensor);\n  CHECK_OR_RETURN(!repeat_tensor_name.empty());  // NOLINT(maybe-need-error-msg)\n  TensorNameScope::Global()->Record(input_tensor, repeat_tensor_name);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<Tensor> GradAccTryInsertUnpackAfterInput(const std::shared_ptr<Tensor>& input) {\n  int32_t grad_acc_step = JUST(GetGradAccStep());\n  if (grad_acc_step > 1) {\n    // NOTE(chengcheng):\n    //   We assume that the input data is one mini-batch which contains multiple micro-batches.\n    //   So we need to unpack the input data into micro-batches.\n    VLOG(2)\n        << \" Current OneFlow nn.Graph grad acc semantics is different from Torch. \\n\"\n        << \" One call of nn.Graph in OneFlow indicates a mini-batch. 
When grad acc steps > 1, \\n\"\n        << \" the input tensor of nn.Graph will be unpacked along the 0th dim into multiple micro-batches \"\n        << \" and executed in order.\\n\";\n    const auto& infer_ctx = JUST(GetCurInferCtx());\n    const auto& input_lbn = TensorNameScope::Global()->Lookup(input);\n    VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name()\n            << \" add grad acc unpack op after input \" << input_lbn << std::endl;\n    return functional::GradAccUnpack(input, grad_acc_step);\n  } else {\n    return input;\n  }\n}\n\nMaybe<Tensor> GradAccTryInsertRepeatAfterVar(const std::shared_ptr<Tensor>& variable) {\n  int32_t grad_acc_step = JUST(GetGradAccStep());\n  if (grad_acc_step > 1) {\n    // NOTE(chengcheng):\n    //   We assume that one nn.Graph call is one mini-batch containing multiple\n    //   micro-batches. So we just repeat the variable tensor for each micro-batch.\n    VLOG(2)\n        << \" Current OneFlow nn.Graph grad acc semantics is different from Torch. \\n\"\n        << \" One call of nn.Graph in OneFlow indicates a mini-batch. When grad acc steps > 1, \\n\"\n        << \" the var tensor of nn.Graph will be executed repeatedly for multiple micro-batches. \\n\";\n    const auto& infer_ctx = JUST(GetCurInferCtx());\n    const auto& variable_lbn = TensorNameScope::Global()->Lookup(variable);\n    VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name()\n            << \" add grad acc repeat op after variable \" << variable_lbn << std::endl;\n    return functional::GradAccRepeat(variable, grad_acc_step);\n  } else {\n    return variable;\n  }\n}\n\nMaybe<Tensor> GradAccTryInsertPackBeforeOutput(const std::shared_ptr<Tensor>& output) {\n  int32_t grad_acc_step = JUST(GetGradAccStep());\n  if (grad_acc_step > 1) {\n    // NOTE(chengcheng):\n    //   We assume that one nn.Graph call is one mini-batch containing multiple\n    //   micro-batches. So we need to pack the per-micro-batch output tensors into one tensor.\n    VLOG(2)\n        << \" Current OneFlow nn.Graph grad acc semantics is different from Torch. \\n\"\n        << \" One call of nn.Graph in OneFlow indicates a mini-batch. When grad acc steps > 1, \\n\"\n        << \" the output tensor of nn.Graph will be packed into a big tensor along the 0th dim, after exec \\n\"\n        << \" for multiple micro-batches. \\n\";\n    const auto& infer_ctx = JUST(GetCurInferCtx());\n    const auto& output_lbn = TensorNameScope::Global()->Lookup(output);\n    VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name()\n            << \" add grad acc pack op before output \" << output_lbn << std::endl;\n    return functional::GradAccPack(output, grad_acc_step);\n  } else {\n    return output;\n  }\n}\n\nMaybe<void> GradAccTryInsertRepeatTickBeforeSource(\n    const std::shared_ptr<OperatorConf>& source_op_conf, bool is_local) {\n  int32_t grad_acc_step = JUST(GetGradAccStep());\n  if (grad_acc_step > 1) {\n    // NOTE(chengcheng):\n    //   We assume that one nn.Graph call is one mini-batch containing multiple\n    //   micro-batches. So we need to repeat the source op for each micro-batch.\n    VLOG(2)\n        << \" Current OneFlow nn.Graph grad acc semantics is different from Torch. \\n\"\n        << \" One call of nn.Graph in OneFlow indicates a mini-batch. 
When grad acc steps > 1, \\n\"\n        << \" the source op of nn.Graph will be executed repeatedly n times for multiple micro-batches.\\n\";\n    const auto& infer_ctx = JUST(GetCurInferCtx());\n    // Insert Tick\n    OperatorConf tick_conf{};\n    tick_conf.set_name(\"Sys-GradAcc-RepeatTick-DeviceTick-\" + source_op_conf->name());\n    tick_conf.set_device_tag(source_op_conf->device_tag());\n    tick_conf.mutable_device_tick_conf()->set_out(\"out\");\n    tick_conf.set_scope_symbol_id(source_op_conf->scope_symbol_id());\n    auto tick_lbn = GenLogicalBlobName(tick_conf.name(), tick_conf.device_tick_conf().out());\n    OpAttribute tick_op_attr = *JUST(infer_ctx->AddAndInferGlobalOp(tick_conf));\n    VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name() << \" add op: \\n\"\n            << tick_conf.DebugString() << std::endl;\n    VLOG(3) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name()\n            << \" infer and add op attr : \\n\"\n            << tick_op_attr.DebugString() << std::endl;\n\n    const auto& scope =\n        Singleton<symbol::Storage<Scope>>::Get()->Get(source_op_conf->scope_symbol_id());\n    int64_t parallel_desc_sym_id = JUST(scope.GetParallelDescSymbolId(tick_conf));\n    auto blob_parallel_desc = JUST(GetSymbol<ParallelDesc>(parallel_desc_sym_id));\n\n    auto tick_tensor = JUST(BuildTensor(tick_op_attr, tick_conf.device_tick_conf().out(),\n                                        blob_parallel_desc, /* is_lazy= */ true,\n                                        /* is_local= */ is_local));\n    TensorNameScope::Global()->Record(tick_tensor, tick_lbn);\n\n    VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name()\n            << \" add grad acc repeat op after tick op \" << tick_conf.name()\n            << \" and before source op \" << source_op_conf->name();\n    auto repeat_tensor = JUST(functional::GradAccRepeat(tick_tensor, grad_acc_step));\n    const std::string& repeat_tensor_name = TensorNameScope::Global()->Lookup(repeat_tensor);\n    CHECK_OR_RETURN(!repeat_tensor_name.empty());  // NOLINT(maybe-need-error-msg)\n    (*source_op_conf->mutable_user_conf()->mutable_input())[user_op::kUserSourceOpTickInputArgName]\n        .add_s(repeat_tensor_name);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LazyInterpreter::ApplyImpl(const FeedInputOpExpr& op_expr, const TensorTuple& inputs,\n                                       TensorTuple* outputs, const OpExprInterpContext& ctx) const {\n  // NOTE(chengcheng): inputs[0] is the EagerTensor\n  CHECK_EQ_OR_RETURN(inputs.size(), 1);         // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(op_expr.input_size(), 1);  // NOLINT(maybe-need-error-msg)\n  const std::shared_ptr<Tensor>& input_tensor = inputs.at(0);\n  CHECK_OR_RETURN(input_tensor->is_eager());  // NOLINT(maybe-need-error-msg)\n\n  std::shared_ptr<Scope> scope = JUST(NewScopeWithParallelDescByTensor(input_tensor));\n\n  OperatorConf op_conf;\n  op_conf.set_name(op_expr.op_name());  // construct by python nn.Graph\n  op_conf.set_scope_symbol_id(JUST(scope->symbol_id()));\n  op_conf.set_device_tag(JUST(GetDeviceTagOfTensor(input_tensor)));\n  // NOTE(chengcheng):\n  //   We construct InputOpConf instead of FeedInputOpConf because FeedInputOpExpr is JUST for\n  //   getting the input EagerTensor.\n  InputOpConf* input_conf = op_conf.mutable_input_conf();\n  input_conf->set_out(\"out\");\n  InterfaceBlobConf* blob_conf = input_conf->mutable_blob_conf();\n\n  
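// Fill the interface blob conf from the eager input tensor below: its shape, dtype and\n  // nd_sbp, so that the lazy input op matches the tensor that will be fed at runtime.\n  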
input_tensor->shape()->ToProto(blob_conf->mutable_shape());\n  blob_conf->set_data_type(input_tensor->dtype()->data_type());\n  // NOTE(chengcheng): is_dynamic=true has a conflict in the global lazy job even if world size is 1.\n  //     this flag will be removed in the future.\n  // blob_conf->set_is_dynamic(GetIsDynamicOfTensor(input_tensor));\n  blob_conf->set_is_dynamic(false);\n  JUST(GenNdSbpByTensor(blob_conf->mutable_nd_sbp(), input_tensor));\n\n  auto infer_ctx = JUST(GetCurInferCtx());\n  VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name()\n          << \" try to add op: \\n\" << op_conf.DebugString() << std::endl;\n  OpAttribute op_attr = *JUST(infer_ctx->AddAndInferGlobalOp(op_conf));\n  VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name() << \" add op : \\n\"\n          << op_conf.name() << std::endl;\n  VLOG(3) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name()\n          << \" infer and add op attr : \\n\"\n          << op_attr.DebugString() << std::endl;\n\n  int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(op_conf));\n  auto blob_parallel_desc = JUST(GetSymbol<ParallelDesc>(parallel_desc_sym_id));\n\n  // Check outputs num and setup output tensor properties.\n  CHECK_EQ_OR_RETURN(outputs->size(), 1);        // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(op_expr.output_size(), 1);  // NOLINT(maybe-need-error-msg)\n  CHECK_OR_RETURN(!(*outputs)[0]);               // NOLINT(maybe-need-error-msg)\n  const std::string obn = \"out\";  // NOTE(chengcheng): obn is NOT op_expr.indexed_obns\n  auto origin_input = JUST(BuildTensor(op_attr, obn, blob_parallel_desc, /* is_lazy= */ true,\n                                       /* is_local= */ input_tensor->is_local()));\n  TensorNameScope::Global()->Record(origin_input, GenLogicalBlobName(op_conf.name(), obn));\n  TensorNameScope::Global()->Record(input_tensor, GenLogicalBlobName(op_conf.name(), obn));\n\n  // NOTE: The input will then be unpacked in DispatchFeedInputOpExprFunctor\n  // if GradAcc is enabled\n  (*outputs)[0] = origin_input;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LazyInterpreter::ApplyImpl(const FeedVariableOpExpr& op_expr, const TensorTuple& inputs,\n                                       TensorTuple* outputs, const OpExprInterpContext& ctx) const {\n  // NOTE(chengcheng): inputs[0] is the EagerTensor\n  CHECK_EQ_OR_RETURN(inputs.size(), 1);         // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(op_expr.input_size(), 1);  // NOLINT(maybe-need-error-msg)\n  const std::shared_ptr<Tensor>& input_tensor = inputs.at(0);\n  CHECK_OR_RETURN(input_tensor->is_eager());  // NOLINT(maybe-need-error-msg)\n\n  std::shared_ptr<Scope> scope = JUST(NewScopeWithParallelDescByTensor(input_tensor));\n\n  OperatorConf op_conf;\n  op_conf.set_name(op_expr.op_name());  // construct by python nn.Graph\n  op_conf.set_scope_symbol_id(JUST(scope->symbol_id()));\n  op_conf.set_device_tag(JUST(GetDeviceTagOfTensor(input_tensor)));\n  // NOTE(chengcheng):\n  //   We construct VariableOpConf instead of FeedVariableOpConf because FeedVariableOpExpr is JUST\n  //   for getting the input EagerTensor.\n  VariableOpConf* var_conf = op_conf.mutable_variable_conf();\n  var_conf->set_out(\"out\");\n  input_tensor->shape()->ToProto(var_conf->mutable_shape());\n  var_conf->set_data_type(input_tensor->dtype()->data_type());\n  // NOTE(chengcheng): VariableOpConf initializer_conf is useless because variable is inited\n  //   by EagerTensor.\n  
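//   The empty_conf below only marks the initializer as intentionally empty.\n  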
var_conf->mutable_initializer()->mutable_empty_conf();\n  JUST(GenVariableOpConfNdSbpStringByTensor(var_conf, input_tensor));\n  if (!input_tensor->requires_grad()) { var_conf->set_trainable(false); }\n  if (input_tensor->requires_grad()) {\n    double l2 = JUST(ctx.attrs.GetAttr<double>(\"l2\"));\n    if (unlikely(l2 != 0.0)) { var_conf->mutable_regularizer()->mutable_l1_l2_conf()->set_l2(l2); }\n  }\n\n  auto infer_ctx = JUST(GetCurInferCtx());\n  VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name()\n          << \" try to add op: \\n\" << op_conf.DebugString() << std::endl;\n  OpAttribute op_attr = *JUST(infer_ctx->AddAndInferGlobalOp(op_conf));\n  VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name() << \" add op : \\n\"\n          << op_conf.name() << std::endl;\n  VLOG(3) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name()\n          << \" infer and add op attr : \\n\"\n          << op_attr.DebugString() << std::endl;\n\n  int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(op_conf));\n  auto blob_parallel_desc = JUST(GetSymbol<ParallelDesc>(parallel_desc_sym_id));\n\n  // Check outputs num and setup output tensor properties.\n  CHECK_EQ_OR_RETURN(outputs->size(), 1);        // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(op_expr.output_size(), 1);  // NOLINT(maybe-need-error-msg)\n  CHECK_OR_RETURN(!(*outputs)[0]);               // NOLINT(maybe-need-error-msg)\n\n  const std::string obn = \"out\";  // NOTE(chengcheng): obn is NOT op_expr.indexed_obns\n  auto origin_var = JUST(BuildTensor(op_attr, obn, blob_parallel_desc, /* is_lazy= */ true,\n                                     /* is_local= */ input_tensor->is_local()));\n  // NOTE(chengcheng): Record variable op output LazyTensor\n  TensorNameScope::Global()->Record(origin_var, GenLogicalBlobName(op_conf.name(), obn));\n  // NOTE(chengcheng): Record EagerTensor as variable tensor name\n  TensorNameScope::Global()->Record(input_tensor, GenLogicalBlobName(op_conf.name(), obn));\n\n  // NOTE: The output variable will then be repeated in DispatchFeedVariableOpExprFunctor\n  // if GradAcc is enabled\n  (*outputs)[0] = origin_var;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LazyInterpreter::ApplyImpl(const FetchOutputOpExpr& op_expr, const TensorTuple& inputs,\n                                       TensorTuple* outputs, const OpExprInterpContext& ctx) const {\n  // NOTE: The input has been packed in DispatchFetchOutputOpExprFunctor\n  // if GradAcc is enabled\n  // NOTE(chengcheng): inputs[0] is the LazyTensor\n  CHECK_EQ_OR_RETURN(inputs.size(), 1);         // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(op_expr.input_size(), 1);  // NOLINT(maybe-need-error-msg)\n  const std::shared_ptr<Tensor>& input_tensor = inputs.at(0);\n  std::string input_lbn = TensorNameScope::Global()->Lookup(input_tensor);\n  // A lazy tensor must have an lbn.\n  // An eager tensor may have an lbn if it has already been treated as an output of a variable op\n  // or an output of an inplace op.\n  if (input_lbn.empty()) {\n    CHECK_OR_RETURN(input_tensor->is_eager());  // NOLINT(maybe-need-error-msg)\n    // This output tensor is a new free eager tensor, so treat it as a new variable op output.\n    JUST(AddFreeEagerTensorToVariableOp(input_tensor));\n    input_lbn = TensorNameScope::Global()->Lookup(input_tensor);\n    CHECK_OR_RETURN(!input_lbn.empty());  // NOLINT(maybe-need-error-msg)\n  }\n  std::shared_ptr<Scope> scope = 
JUST(NewScopeWithParallelDescByTensor(input_tensor));\n\n  OperatorConf op_conf;\n  op_conf.set_name(op_expr.op_name());  // constructed by python nn.Graph\n  op_conf.set_scope_symbol_id(JUST(scope->symbol_id()));\n  op_conf.set_device_tag(JUST(GetDeviceTagOfTensor(input_tensor)));\n  // NOTE(chengcheng):\n  //   We construct OutputOpConf instead of FetchOutputOpConf because FetchOutputOpExpr is JUST\n  //   for getting the nn.Graph output LazyTensor.\n  OutputOpConf* output_conf = op_conf.mutable_output_conf();\n  output_conf->set_in(input_lbn);\n  output_conf->set_out(\"out\");\n  InterfaceBlobConf* blob_conf = output_conf->mutable_blob_conf();\n  input_tensor->shape()->ToProto(blob_conf->mutable_shape());\n  blob_conf->set_data_type(input_tensor->dtype()->data_type());\n  // NOTE(chengcheng): is_dynamic=true conflicts with the global lazy job even if world size is 1.\n  //     This flag will be removed in the future.\n  // blob_conf->set_is_dynamic(GetIsDynamicOfTensor(input_tensor));\n  blob_conf->set_is_dynamic(false);\n  JUST(GenNdSbpByTensor(blob_conf->mutable_nd_sbp(), input_tensor));\n\n  auto infer_ctx = JUST(GetCurInferCtx());\n  VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name() << \" try to add op: \\n\"\n          << op_conf.DebugString() << std::endl;\n  OpAttribute op_attr = *JUST(infer_ctx->AddAndInferGlobalOp(op_conf));\n  VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name() << \" add op: \\n\"\n          << op_conf.name() << std::endl;\n  VLOG(3) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name()\n          << \" inferred op attr: \\n\"\n          << op_attr.DebugString() << std::endl;\n\n  int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(op_conf));\n  auto blob_parallel_desc = JUST(GetSymbol<ParallelDesc>(parallel_desc_sym_id));\n\n  // Check outputs num and setup output tensor properties.\n  CHECK_EQ_OR_RETURN(outputs->size(), 1);        // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(op_expr.output_size(), 1);  // NOLINT(maybe-need-error-msg)\n  CHECK_OR_RETURN(!(*outputs)[0]);               // NOLINT(maybe-need-error-msg)\n  const std::string obn = \"out\";  // NOTE(chengcheng): obn is NOT op_expr.indexed_obns\n  (*outputs)[0] = JUST(BuildTensor(op_attr, obn, blob_parallel_desc, /* is_lazy= */ false,\n                                   /* is_local= */ input_tensor->is_local()));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LazyInterpreter::ApplyImpl(const ImageDecoderRandomCropResizeOpExpr& op_expr,\n                                       const TensorTuple& inputs, TensorTuple* outputs,\n                                       const OpExprInterpContext& ctx) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), 1);         // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(op_expr.input_size(), 1);  // NOLINT(maybe-need-error-msg)\n  const std::shared_ptr<Tensor>& input_tensor = inputs.at(0);\n  const std::string& input_lbn = TensorNameScope::Global()->Lookup(input_tensor);\n  CHECK_OR_RETURN(!input_lbn.empty());  // NOLINT(maybe-need-error-msg)\n\n  auto op_conf = JUST(OpInterpUtil::GenBuiltinOpConf(op_expr, ctx.attrs));\n  std::string device_tag;\n  if (IsCpuOnly(*op_conf)) {\n    device_tag = \"cpu\";\n  } else {\n    device_tag = \"cuda\";\n  }\n\n  ParallelConf parallel_conf = JUST(GetParallelDescOfTensor(input_tensor))->parallel_conf();\n  parallel_conf.set_device_tag(device_tag);  // NOTE(chengcheng): only gpu decode is supported.\n  const auto& scope = 
JUST(NewScopeWithParallelConfAndCurScope(parallel_conf));\n\n  op_conf->set_scope_symbol_id(JUST(scope->symbol_id()));\n  op_conf->set_device_tag(device_tag);\n\n  // NOTE(chengcheng): replace with the right input_lbn and obn\n  ReplaceInputLbnInOpCustomizedConf(op_conf.get(), /* ibn */ \"in\", input_lbn);\n  op_conf->mutable_image_decoder_random_crop_resize_conf()->set_out(\"out\");\n\n  auto infer_ctx = JUST(GetCurInferCtx());\n  // NOTE(chengcheng): MUST reset unique op name before InferCtx::AddOp\n  const std::string new_op_name = *JUST(infer_ctx->NewUniqueOpNameByFunctionalOpConf(*op_conf));\n  op_conf->set_name(new_op_name);\n  VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name() << \" try to add op: \\n\"\n          << op_conf->DebugString() << std::endl;\n  OpAttribute op_attr = *JUST(infer_ctx->AddAndInferGlobalOp(*op_conf));\n  VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name() << \" add op: \\n\"\n          << op_conf->name() << std::endl;\n  VLOG(3) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name()\n          << \" inferred op attr: \\n\"\n          << op_attr.DebugString() << std::endl;\n\n  int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(*op_conf));\n  auto blob_parallel_desc = JUST(GetSymbol<ParallelDesc>(parallel_desc_sym_id));\n\n  // Check outputs num and setup output tensor properties.\n  CHECK_EQ_OR_RETURN(outputs->size(), 1);        // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(op_expr.output_size(), 1);  // NOLINT(maybe-need-error-msg)\n  CHECK_OR_RETURN(!(*outputs)[0]);               // NOLINT(maybe-need-error-msg)\n  const std::string obn = \"out\";  // NOTE(chengcheng): obn is NOT op_expr.indexed_obns\n  (*outputs)[0] = JUST(BuildTensor(op_attr, obn, blob_parallel_desc, /* is_lazy= */ true,\n                                   /* is_local= */ input_tensor->is_local()));\n  TensorNameScope::Global()->Record((*outputs)[0], GenLogicalBlobName(new_op_name, obn));\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<void> LazyInterpreterApplyImplForSourceUserOpExpr(const UserOpExpr& op_expr,\n                                                        TensorTuple* outputs,\n                                                        const OpExprInterpContext& ctx) {\n  NonRecursiveMetaInfoConsistencyCheckScope non_scope;\n  bool is_local;\n  std::shared_ptr<const ParallelDesc> parallel_desc;\n  if (ctx.parallel_desc.has_value()) {\n    // NOTE(chengcheng): global\n    CHECK_OR_RETURN(!ctx.device.has_value());  // NOLINT(maybe-need-error-msg)\n    const auto& parallel_desc_sym = JUST(ctx.parallel_desc);\n    parallel_desc = parallel_desc_sym.shared_from_symbol();\n    JUST(MetaInfoConsistencyCheck(parallel_desc_sym, ctx.nd_sbp, 1, /* force_check */ false));\n    is_local = false;\n  } else {\n    // NOTE(chengcheng): local\n    CHECK_OR_RETURN(!ctx.nd_sbp.has_value());  // NOLINT(maybe-need-error-msg)\n    if (ctx.device.has_value()) {\n      const auto& device = JUST(ctx.device);\n      const auto& placement = JUST(Placement4Device(device));\n      parallel_desc = placement.shared_from_symbol();\n    } else {\n      // NOTE(chengcheng): if the functor did NOT set a device, use the cpu device by default.\n      const auto& device = JUST(Device::New(\"cpu\"));\n      const auto& placement = JUST(Placement4Device(device));\n      parallel_desc = placement.shared_from_symbol();\n    }\n    is_local = true;\n  }\n  const auto& parallel_conf = parallel_desc->parallel_conf();\n  const auto& scope = 
JUST(NewScopeWithParallelConfAndCurScope(parallel_conf));\n  auto op_conf = JUST(OpInterpUtil::GenBuiltinOpConf(op_expr, ctx.attrs));\n  op_conf->set_scope_symbol_id(JUST(scope->symbol_id()));\n  op_conf->set_device_tag(parallel_conf.device_tag());\n\n  auto infer_ctx = JUST(GetCurInferCtx());\n  // NOTE(chengcheng): MUST reset unique op name before InferCtx::AddOp\n  const std::string new_op_name = *JUST(infer_ctx->NewUniqueOpNameByFunctionalOpConf(*op_conf));\n  const std::string graph_name = infer_ctx->job().job_conf().job_name();\n\n  // NOTE(chengcheng): for UserOp, reset NOT only the op_name but also the output values.\n  op_conf->set_name(new_op_name);\n  for (auto& pair : *(op_conf->mutable_user_conf()->mutable_output())) {\n    auto& list_s = pair.second;\n    for (int i = 0; i < list_s.s_size(); ++i) {\n      std::string old_lbn = list_s.s(i);\n      LogicalBlobId old_lbi = GenLogicalBlobId(old_lbn);\n      // NOTE(chengcheng): MUST change the old_lbn to the new op name.\n      std::string new_lbn = GenLogicalBlobName(new_op_name, old_lbi.blob_name());\n      list_s.set_s(i, new_lbn);\n    }\n  }\n\n  JUST(GradAccTryInsertRepeatTickBeforeSource(op_conf, is_local));\n\n  VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name() << \" try to add op: \\n\"\n          << op_conf->DebugString() << std::endl;\n  OpAttribute op_attr = *JUST(infer_ctx->AddAndInferGlobalOp(*op_conf));\n  VLOG(2) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name() << \" add op: \\n\"\n          << op_conf->name() << std::endl;\n  VLOG(3) << \"Lazy nn.Graph name \" << infer_ctx->job().job_conf().job_name()\n          << \" inferred op attr: \\n\"\n          << op_attr.DebugString() << std::endl;\n\n  int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(*op_conf));\n  auto blob_parallel_desc = JUST(GetSymbol<ParallelDesc>(parallel_desc_sym_id));\n\n  // Check outputs num and setup output tensor properties.\n  CHECK_EQ_OR_RETURN(outputs->size(), op_expr.output_size());  // NOLINT(maybe-need-error-msg)\n  for (int i = 0; i < op_expr.output_size(); ++i) {\n    const std::string& obn = op_expr.indexed_obns().at(i);\n    if (!(*outputs)[i]) {\n      (*outputs)[i] =\n          JUST(BuildTensor(op_attr, obn, blob_parallel_desc, /* is_lazy= */ true, is_local));\n    } else {\n      VLOG(2) << \"Lazy nn.Graph name \" << graph_name << \" source op name \" << new_op_name\n              << \" runs with inplace.\";\n      const std::shared_ptr<Tensor>& inplace_out = (*outputs)[i];\n      JUST(CheckTensorMatchAttr(inplace_out, op_attr, obn, blob_parallel_desc, is_local));\n    }\n    TensorNameScope::Global()->Record((*outputs)[i], GenLogicalBlobName(new_op_name, obn));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LazyInterpreterApplyImplForCopyUserOpExpr(const UserOpExpr& op_expr,\n                                                      const TensorTuple& inputs,\n                                                      TensorTuple* outputs,\n                                                      const OpExprInterpContext& ctx) {\n  CHECK_OR_RETURN(op_expr.op_type_name() == \"copy\");  // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(inputs.size(), 1);               // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(op_expr.input_size(), 1);        // NOLINT(maybe-need-error-msg)\n  const std::shared_ptr<Tensor>& input_tensor = inputs.at(0);\n  std::string input_lbn = TensorNameScope::Global()->Lookup(input_tensor);\n  if (input_lbn.empty()) {\n    
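// A free eager tensor (created outside the graph) has no lbn recorded yet; registering it as\n    // a variable op assigns one, so the Lookup below succeeds.\n    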
JUST(AddFreeEagerTensorToVariableOp(input_tensor));\n    input_lbn = TensorNameScope::Global()->Lookup(input_tensor);\n  }\n  CHECK_OR_RETURN(!input_lbn.empty());  // NOLINT(maybe-need-error-msg)\n  auto device = JUST(ctx.attrs.GetAttr<Symbol<Device>>(\"device\"));\n\n  CHECK_EQ_OR_RETURN(outputs->size(), 1);        // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(op_expr.output_size(), 1);  // NOLINT(maybe-need-error-msg)\n  if (input_tensor->is_local()) {\n    (*outputs)[0] = JUST(LocalTensor::MakeTensor(\n        input_tensor->shape(), JUST(input_tensor->stride()), input_tensor->dtype()->data_type(),\n        input_tensor->memory_format(), device,\n        /* is_lazy= */ true,\n        /*requires_grad=*/false, /*is_leaf=*/true));\n  } else {\n    ParallelConf parallel_conf = JUST(input_tensor->parallel_desc())->parallel_conf();\n    parallel_conf.set_device_tag(device->type());\n    ParallelDesc parallel_desc(parallel_conf);\n    (*outputs)[0] = JUST(GlobalTensor::MakeTensor(\n        input_tensor->shape(), input_tensor->dtype()->data_type(), input_tensor->memory_format(),\n        JUST(input_tensor->nd_sbp()), SymbolOf(parallel_desc),\n        /* is_lazy= */ true,\n        /*requires_grad=*/false, /*is_leaf=*/true));\n  }\n  // NOTE(chengcheng): the output tensor lbn is the SAME as the input tensor's.\n  TensorNameScope::Global()->Record(outputs->at(0), input_lbn);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> LazyInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs,\n                                       TensorTuple* outputs, const OpExprInterpContext& ctx) const {\n  CHECK_EQ_OR_RETURN(inputs.size(), op_expr.input_size());  // NOLINT(maybe-need-error-msg)\n\n  // NOTE(chengcheng): Handle special UserOps such as:\n  //     1. [Source UserOp] : OFRecordReader, CoinFlip\n  //     2. [Change Placement/ParallelDesc UserOp] : to(copy)/to_global/parallel_cast\n  //     3. [Multi-Inputs & Different ParallelDesc for each input UserOp] : e.g. there are 2 inputs,\n  //             one from CPU and the other from GPU.\n  //     ..., etc.\n  //\n  //     Each special UserOp needs its own if-branch to infer:\n  //     1. op_conf: device_tag,\n  //     2. output tensor: is_local,\n  //     3. op_parallel_conf for building a new scope with parallel_desc,\n  //     4. output blob (different from tensor) -> parallel_conf,\n  //     5. 
whether the op needs to be added to JobBuildAndInferCtx (copy, for example, does NOT).\n  if (inputs.size() == 0) {\n    // NOTE(chengcheng): handle source UserOps like OFRecordReader, CoinFlip\n    return LazyInterpreterApplyImplForSourceUserOpExpr(op_expr, outputs, ctx);\n  }\n  if (op_expr.op_type_name() == \"copy\") {\n    // NOTE(chengcheng): handle the copy UserOp, which will NOT add an op to the job.\n    return LazyInterpreterApplyImplForCopyUserOpExpr(op_expr, inputs, outputs, ctx);\n  }\n\n  // NOTE(chengcheng):\n  //   A normal UserOp has inputs size >= 1, which is used to infer the parallel_desc.\n  CHECK_GE_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)\n  auto op_conf = JUST(OpInterpUtil::GenBuiltinOpConf(op_expr, ctx.attrs));\n  std::shared_ptr<Scope> scope = JUST(NewScopeWithParallelDescByTensor(JUST(VectorAt(inputs, 0))));\n  op_conf->set_scope_symbol_id(JUST(scope->symbol_id()));\n  const std::string device_tag = JUST(GetDeviceTagOfTensor(JUST(VectorAt(inputs, 0))));\n  const bool is_local = inputs.at(0)->is_local();\n  const std::shared_ptr<const ParallelDesc> parallel_desc =\n      JUST(GetParallelDescOfTensor(inputs.at(0)));\n\n  op_conf->set_device_tag(device_tag);\n  auto infer_ctx = JUST(GetCurInferCtx());\n  // NOTE(chengcheng): MUST reset unique op name before InferCtx::AddOp\n  const std::string new_op_name = *JUST(infer_ctx->NewUniqueOpNameByFunctionalOpConf(*op_conf));\n  const std::string graph_name = infer_ctx->job().job_conf().job_name();\n\n  for (int i = 0; i < inputs.size(); ++i) {\n    const auto& input_tensor = inputs.at(i);\n    CHECK_EQ_OR_RETURN(is_local, input_tensor->is_local());  // NOLINT(maybe-need-error-msg)\n    if (!op_expr.IsHostMemoryInput(i)) {\n      if (is_local) {\n        CHECK_OR_RETURN(device_tag == JUST(GetDeviceTagOfTensor(input_tensor)))\n            << Error::RuntimeError() << \"Lazy nn.Graph name: \" << graph_name\n            << \" encountered ERROR in module/op_name: \" << new_op_name\n            << \". Expected all tensors to be on the same device, but found at least two devices, \"\n            << JUST(JUST(VectorAt(inputs, 0))->device())->ToString() << \" (positional 0) and \"\n            << JUST(JUST(VectorAt(inputs, i))->device())->ToString() << \" (positional \" << i\n            << \")! Please use tensor.to() to move all the inputs onto the same device.\";\n      } else {\n        // TODO: Print out all the placements\n        CHECK_OR_RETURN(parallel_desc->Equals(*JUST(GetParallelDescOfTensor(input_tensor))))\n            << Error::RuntimeError() << \"Lazy nn.Graph name: \" << graph_name\n            << \" encountered ERROR in module/op_name: \" << new_op_name\n            << \". Expected all tensors to be on the same placement, but found at least two \"\n               \"placements, \"\n            << *JUST(PlacementToString(JUST(JUST(VectorAt(inputs, 0))->parallel_desc())))\n            << \" (positional 0) and \"\n            << *JUST(PlacementToString(JUST(JUST(VectorAt(inputs, i))->parallel_desc())))\n            << \" (positional \" << i\n            << \")! 
Please use tensor.to_global() to move all the inputs onto the same \"\n               \"placement.\";\n      }\n    }\n    const std::string& ibn = op_expr.indexed_ibns().at(i);\n    std::string lbn = TensorNameScope::Global()->Lookup(input_tensor);\n    if (lbn.empty()) {\n      JUST(AddFreeEagerTensorToVariableOp(input_tensor));\n      lbn = TensorNameScope::Global()->Lookup(input_tensor);\n    }\n    CHECK_OR_RETURN(!lbn.empty());  // NOLINT(maybe-need-error-msg)\n    ReplaceInputLbnInOpCustomizedConf(op_conf.get(), ibn, lbn);\n  }\n\n  // NOTE(chengcheng): for UserOp, reset NOT only the op_name but also the output values.\n  op_conf->set_name(new_op_name);\n  for (auto& pair : *(op_conf->mutable_user_conf()->mutable_output())) {\n    auto& list_s = pair.second;\n    for (int i = 0; i < list_s.s_size(); ++i) {\n      std::string old_lbn = list_s.s(i);\n      LogicalBlobId old_lbi = GenLogicalBlobId(old_lbn);\n      // NOTE(chengcheng): MUST change the old_lbn to the new op name.\n      std::string new_lbn = GenLogicalBlobName(new_op_name, old_lbi.blob_name());\n      list_s.set_s(i, new_lbn);\n    }\n  }\n\n  // Check outputs num and setup output tensor properties.\n  CHECK_EQ_OR_RETURN(outputs->size(), op_expr.output_size());  // NOLINT(maybe-need-error-msg)\n\n  // Disable boxing if the computation is inplace.\n  for (int i = 0; i < op_expr.output_size(); ++i) {\n    const auto& output = outputs->at(i);\n    if (output) {\n      const std::string& lbn = TensorNameScope::Global()->Lookup(output);\n      CHECK_OR_RETURN(!lbn.empty()) << \"The output at index \" << i\n                                    << \" has no tensor name, please check whether the inplaced \"\n                                       \"output is also an input of the operation \"\n                                    << new_op_name;\n      JUST(infer_ctx->DisableBoxing(lbn));\n    }\n  }\n  VLOG(2) << \"Lazy nn.Graph name \" << graph_name << \" try to add op: \\n\"\n          << op_conf->DebugString() << std::endl;\n  OpAttribute op_attr = *JUST(infer_ctx->AddAndInferGlobalOp(*op_conf));\n  VLOG(2) << \"Lazy nn.Graph name \" << graph_name << \" add op: \\n\" << op_conf->name() << std::endl;\n  VLOG(3) << \"Lazy nn.Graph name \" << graph_name << \" inferred op attr: \\n\"\n          << op_attr.DebugString() << std::endl;\n\n  int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(*op_conf));\n  auto blob_parallel_desc = JUST(GetSymbol<ParallelDesc>(parallel_desc_sym_id));\n  for (int i = 0; i < op_expr.output_size(); ++i) {\n    const std::string& obn = op_expr.indexed_obns().at(i);\n    if (!(*outputs)[i]) {\n      (*outputs)[i] =\n          JUST(BuildTensor(op_attr, obn, blob_parallel_desc, /* is_lazy= */ true, is_local));\n    } else {\n      VLOG(2) << \"Lazy nn.Graph name \" << graph_name << \" op name \" << new_op_name\n              << \" runs with inplace.\";\n      const std::shared_ptr<Tensor>& inplace_out = (*outputs)[i];\n      JUST(CheckTensorMatchAttr(inplace_out, op_attr, obn, blob_parallel_desc, is_local));\n    }\n    TensorNameScope::Global()->Record((*outputs)[i], GenLogicalBlobName(new_op_name, obn));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LazyInterpreter::ApplyImpl(const FunctionOpExpr& op_expr, const TensorTuple& inputs,\n                                       TensorTuple* outputs, const OpExprInterpContext&) const {\n  // Must reset ctx in each forward\n  op_expr.reset_state();\n  std::shared_ptr<FunctionAutoGradCaptureState> ctx = op_expr.state();\n  *outputs = 
*(op_expr.forward()(ctx, inputs));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LazyInterpreter::ApplyImpl(const GlobalToGlobalOpExpr& op_expr,\n                                       const TensorTuple& inputs, TensorTuple* outputs,\n                                       const OpExprInterpContext& ctx) const {\n  CHECK_EQ_OR_RETURN(op_expr.input_size(), 1);  // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(inputs.size(), 1);         // NOLINT(maybe-need-error-msg)\n  const auto& input_tensor = inputs[0];\n  CHECK_OR_RETURN(input_tensor->is_global());  // NOLINT(maybe-need-error-msg)\n\n  CHECK_OR_RETURN(ctx.parallel_desc.has_value());  // NOLINT(maybe-need-error-msg)\n  const auto& parallel_desc_sym = JUST(ctx.parallel_desc);\n  CHECK_OR_RETURN(ctx.nd_sbp.has_value());  // NOLINT(maybe-need-error-msg)\n  const auto& sbp_sym = JUST(ctx.nd_sbp);\n\n  std::string input_lbn = TensorNameScope::Global()->Lookup(input_tensor);\n  if (input_lbn.empty()) {\n    JUST(AddFreeEagerTensorToVariableOp(input_tensor));\n    input_lbn = TensorNameScope::Global()->Lookup(input_tensor);\n    CHECK_OR_RETURN(!input_lbn.empty());  // NOLINT(maybe-need-error-msg)\n  }\n\n  std::shared_ptr<Tensor> input_proxy;\n  if (!JUST(GetParallelDescOfTensor(input_tensor))\n           ->Equals(*parallel_desc_sym.shared_from_symbol())) {\n    // NOTE(zwx): The input tensor's parallel_desc is not equal to the op's, so create a proxy\n    // input whose parallel_desc is the same as the op's\n    input_proxy = JUST(GlobalTensor::MakeTensor(\n        input_tensor->shape(), input_tensor->dtype()->data_type(), input_tensor->memory_format(),\n        JUST(input_tensor->nd_sbp()), parallel_desc_sym,\n        /* is_lazy= */ true,\n        /*requires_grad=*/false, /*is_leaf=*/true));\n    TensorNameScope::Global()->Record(input_proxy, input_lbn);\n  }\n\n  CHECK_EQ_OR_RETURN(op_expr.output_size(), 1);  // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(outputs->size(), 1);        // NOLINT(maybe-need-error-msg)\n  CHECK_OR_RETURN(!(*outputs)[0]);               // NOLINT(maybe-need-error-msg)\n\n  if (!op_expr.grad_nd_sbp().has_value() && sbp_sym == JUST(input_tensor->nd_sbp())) {\n    // NOTE(chengcheng): if to_global ONLY changes the placement (nd_sbp and grad_nd_sbp are the\n    //    same), there is no need to build a hierarchical_parallel_cast op.\n    if (input_proxy) {\n      (*outputs)[0] = input_proxy;\n    } else {\n      (*outputs)[0] = input_tensor;\n    }\n    return Maybe<void>::Ok();\n  }\n\n  // build the parallel cast op expr\n  std::shared_ptr<std::vector<std::string>> sbp_list_ptr = JUST(GetNdSbpStrList(sbp_sym));\n  std::string grad_mode;\n  std::vector<std::string> grad_sbp_str_list;\n  if (op_expr.grad_nd_sbp().has_value()) {\n    grad_mode = \"manual\";\n    grad_sbp_str_list = *JUST(GetNdSbpStrList(JUST(op_expr.grad_nd_sbp())));\n  } else {\n    grad_mode = \"identity\";\n  }\n  std::shared_ptr<UserOpExpr> parallel_cast_op_expr =\n      JUST(OpBuilder(\"hierarchical_parallel_cast\", \"trivial_op_name\")\n               .Input(\"in\")\n               .Output(\"out\")\n               .Attr<std::vector<std::string>>(\"nd_sbp\", *sbp_list_ptr)\n               .Attr<std::string>(\"grad_mode\", grad_mode)\n               .Attr<std::vector<std::string>>(\"grad_nd_sbp\", grad_sbp_str_list)\n               .Build());\n\n  if (input_proxy) {\n    (*outputs)[0] =\n        JUST(OpInterpUtil::Dispatch<one::Tensor>(*parallel_cast_op_expr, {input_proxy}));\n  } else {\n    (*outputs)[0] =\n        
JUST(OpInterpUtil::Dispatch<one::Tensor>(*parallel_cast_op_expr, {input_tensor}));\n  }\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/op_interpreter/lazy_op_interpreter.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_name_scope.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\nnamespace one {\n\nMaybe<Tensor> GradAccTryInsertUnpackAfterInput(const std::shared_ptr<Tensor>& input);\nMaybe<Tensor> GradAccTryInsertRepeatAfterVar(const std::shared_ptr<Tensor>& variable);\nMaybe<Tensor> GradAccTryInsertPackBeforeOutput(const std::shared_ptr<Tensor>& output);\n\nMaybe<void> GradAccTryInsertRepeatTickBeforeSource(\n    const std::shared_ptr<OperatorConf>& source_op_conf, bool is_local);\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/op_interpreter/op_interpreter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_interpreter.h\"\n\n#include \"oneflow/core/autograd/autograd_engine.h\"\n#include \"oneflow/core/autograd/autograd_mode.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n\nnamespace oneflow {\nnamespace one {\n\nMaybe<void> LazyInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs,\n                                   TensorTuple* outputs, const OpExprInterpContext& ctx) const {\n#define APPLY_IF(op_type)                                              \\\n  if (const auto* op = dynamic_cast<const op_type##Expr*>(&op_expr)) { \\\n    return ApplyImpl(*op, inputs, outputs, ctx);                       \\\n  }\n\n  APPLY_IF(FeedInputOp);\n  APPLY_IF(FeedVariableOp);\n  APPLY_IF(FetchOutputOp);\n  APPLY_IF(UserOp);\n  APPLY_IF(GlobalToGlobalOp);\n  APPLY_IF(FunctionOp);\n  APPLY_IF(ImageDecoderRandomCropResizeOp);\n#undef APPLY_IF\n\n  OF_UNIMPLEMENTED() << \"The type \" << op_expr.op_type_name()\n                     << \" has not been supported in LazyInterpreter::Apply.\";\n}\n\nMaybe<void> EagerInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs,\n                                    TensorTuple* outputs, const OpExprInterpContext& ctx) const {\n  // In the op interpreter, judge whether to open the global mode to avoid recursion caused by\n  // GlobalMode.\n  // The global mode is enabled only if it was enabled and the current operation is a local\n  // operation.\n  auto global_mode_gurad = GlobalMode::Guard(GlobalMode::is_enabled() && is_local_,\n                                             GlobalMode::nd_sbp(), GlobalMode::parallel_desc());\n\n#define APPLY_IF(op_type)                                              \\\n  if (const auto* op = dynamic_cast<const op_type##Expr*>(&op_expr)) { \\\n    return ApplyImpl(*op, inputs, outputs, ctx);                       \\\n  }\n\n  APPLY_IF(UserOp);\n  APPLY_IF(VariableOp);\n  APPLY_IF(CastToLocalOp);\n  APPLY_IF(CastFromLocalOp);\n  APPLY_IF(GlobalToGlobalOp);\n  APPLY_IF(LocalToGlobalOp);\n  APPLY_IF(GlobalToLocalOp);\n  APPLY_IF(DistributeSplitOp);\n  APPLY_IF(DistributeCloneOp);\n  APPLY_IF(DistributeConcatOp);\n  APPLY_IF(DistributeAddOp);\n  APPLY_IF(FunctionOp);\n  APPLY_IF(SelectTopNOp)\n#undef APPLY_IF\n\n  OF_UNIMPLEMENTED() << \"The type \" << op_expr.op_type_name()\n                     << \" has not been supported in EagerInterpreter::Apply.\";\n}\n\nMaybe<void> EagerInterpreter::ApplyImpl(const FunctionOpExpr& op_expr, const TensorTuple& inputs,\n                                        TensorTuple* outputs, const 
OpExprInterpContext&) const {\n  // Must reset ctx in each forward\n  op_expr.reset_state();\n  std::shared_ptr<FunctionAutoGradCaptureState> ctx = op_expr.state();\n  *outputs = *(op_expr.forward()(ctx, inputs));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AutogradInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs,\n                                       TensorTuple* outputs, const OpExprInterpContext& ctx) const {\n  bool requires_grad = false;\n  if (autograd::GradMode::is_enabled() && !JUST(op_expr.IsGradDisabled())) {\n    requires_grad =\n        std::any_of(inputs.begin(), inputs.end(),\n                    [](const std::shared_ptr<Tensor>& tensor) { return tensor->requires_grad(); });\n  }\n  {\n    autograd::AutoGradMode mode(false);\n    JUST(internal_->Apply(op_expr, inputs, outputs, ctx));\n  }\n  // Lazy mode constructs the backward compute graph in passes, so autograd is disabled in lazy mode.\n  std::shared_ptr<OpExprGradClosure> grad_closure(nullptr);\n  if (requires_grad) {\n    OF_PROFILER_RANGE_PUSH(\"autograd.GetOrCreateOpGradClosure\");\n    grad_closure = JUST(op_expr.GetOrCreateOpGradClosure());\n    auto backward_fn = std::make_shared<BackwardFunction>();\n    backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,\n                            bool create_graph) -> Maybe<void> {\n      autograd::AutoGradMode mode(create_graph);\n      JUST(grad_closure->Apply(out_grads, in_grads));\n      return Maybe<void>::Ok();\n    };\n    backward_fn->status = [=]() { return grad_closure->state()->SavedTensors().size() > 0; };\n    OF_PROFILER_RANGE_POP();\n    OF_PROFILER_RANGE_PUSH(\"autograd.AddNode\");\n    JUST(GetThreadLocalAutogradEngine()->AddNode(op_expr.op_type_name() + \"Backward\", backward_fn,\n                                                 inputs, outputs));\n    OF_PROFILER_RANGE_POP();\n  }\n\n  if (requires_grad) {\n    OF_PROFILER_RANGE_GUARD(\"autograd.Capture\");\n    // Capture inputs and outputs after `AddNode` because the grad function\n    // node has been attached to them.\n    JUST(grad_closure->Capture(inputs, *outputs, ctx));\n  }\n  // Update outputs' autograd meta\n  // Note: if requires_grad is True, we will create a new autograd meta for each output\n  // in `AddNode` to support inplace operations, so the update should happen after\n  // `AddNode`\n  for (auto& output : *outputs) {\n    output->set_is_leaf(inputs.size() == 0 || !requires_grad);\n    // If the output `requires_grad` is true, it means that the output is inplaced.\n    // The output `requires_grad` should be determined by this:\n    //   - If the inplaced output `requires_grad` is true, then the autograd must be disabled,\n    //     so the output `requires_grad` should never be changed.\n    //   - If the inplaced output `requires_grad` is false, then the output `requires_grad`\n    //     should be inferred by autograd mode and inputs. 
For example,\n    //\n    //     >>> import oneflow as flow\n    //     >>> x = flow.ones(4, 4, requires_grad=False)\n    //     >>> y = flow.ones(4, 4, requires_grad=True)\n    //     >>> x += y\n    //     >>> x.requires_grad\n    //     True\n    //     >>> with flow.no_grad():\n    //     >>>    x += y\n    //     >>> x.requires_grad\n    //     False\n    //\n    //   - If there is no inplace, the output `requires_grad` should be inferred by autograd\n    //     mode and inputs.\n    if (!output->requires_grad()) {\n      JUST(output->set_requires_grad(\n          requires_grad && IsSupportRequireGradDataType(output->dtype()->data_type())));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/op_interpreter/op_interpreter_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include <memory>\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/tensor_impl.h\"\n#include \"oneflow/core/functional/tensor_processor.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx_mgr.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n\nnamespace oneflow {\nnamespace one {\n\nnamespace {\n\nstd::shared_ptr<AutogradInterpreter> BuildEagerInterpreter(const bool& is_local) {\n  std::shared_ptr<OpExprInterpreter> internal;\n  if (is_local) {\n    internal = std::make_shared<EagerLocalInterpreter>();\n  } else {\n    internal = std::make_shared<EagerGlobalInterpreter>();\n  }\n  return std::make_shared<AutogradInterpreter>(internal);\n}\n\nstd::shared_ptr<AutogradInterpreter> BuildLazyInterpreter() {\n  auto internal = std::make_shared<LazyInterpreter>();\n  return std::make_shared<AutogradInterpreter>(internal);\n}\n\nstd::string ErrorString4Inputs(const TensorTuple& inputs, const OpExpr& op_expr) {\n  std::stringstream error_str;\n  error_str << \"Got input tensors with inconsistent attributes!\\n\"\n            << \"op_type_name: \" << op_expr.op_type_name() << \"\\n\"\n            << \"attributes of inputs is:\\n\";\n  int32_t idx = 0;\n  for (const auto& tensor : inputs) {\n    if (tensor->is_local()) {\n      error_str << \"local\";\n    } else {\n      error_str << \"global\";\n    }\n    if (++idx != inputs.size()) { error_str << \", \"; }\n  }\n  return error_str.str();\n}\n\nMaybe<AutogradInterpreter> GetInterpreter(const TensorTuple& inputs, const OpExprInterpContext& ctx,\n                                          const OpExpr& op_expr) {\n  static const auto& g_lazy_interpreter = BuildLazyInterpreter();\n  static const auto& g_eager_global_interpreter = BuildEagerInterpreter(/*is_local=*/false);\n  static const auto& g_eager_local_interpreter = BuildEagerInterpreter(/*is_local=*/true);\n  bool is_local = true;\n  if (inputs.empty()) {\n    if (ctx.parallel_desc.has_value()) {\n      JUST(ctx.nd_sbp);\n      CHECK_OR_RETURN(!ctx.device.has_value());\n      is_local = false;\n    } else {\n      CHECK_OR_RETURN(!ctx.nd_sbp.has_value());\n    }\n  } else {\n    if (inputs[0]->is_global()) {\n      if (inputs.size() == 1) {\n        // do nothing\n      } else if (inputs.size() == 2) {\n        CHECK_OR_RETURN(inputs[1]->is_global())      // NOLINT\n            << ErrorString4Inputs(inputs, op_expr);  // unroll loop for efficiency\n      } else if (inputs.size() == 3) {\n        CHECK_OR_RETURN(inputs[1]->is_global())\n            << ErrorString4Inputs(inputs, op_expr);  // unroll loop for efficiency\n        
CHECK_OR_RETURN(inputs[2]->is_global())\n            << ErrorString4Inputs(inputs, op_expr);  // unroll loop for efficiency\n      } else {\n        for (const auto& tensor : inputs) {\n          CHECK_OR_RETURN(tensor->is_global()) << ErrorString4Inputs(inputs, op_expr);\n        }\n      }\n      is_local = false;\n    } else {\n      if (inputs.size() == 1) {\n        // do nothing\n      } else if (inputs.size() == 2) {\n        CHECK_OR_RETURN(inputs.at(1)->is_local())\n            << ErrorString4Inputs(inputs, op_expr);  // unroll loop for efficiency\n      } else if (inputs.size() == 3) {\n        CHECK_OR_RETURN(inputs.at(1)->is_local())\n            << ErrorString4Inputs(inputs, op_expr);  // unroll loop for efficiency\n        CHECK_OR_RETURN(inputs.at(2)->is_local())\n            << ErrorString4Inputs(inputs, op_expr);  // unroll loop for efficiency\n      } else {\n        for (const auto& tensor : inputs) {\n          CHECK_OR_RETURN(tensor->is_local()) << ErrorString4Inputs(inputs, op_expr);\n        }\n      }\n    }\n  }\n  if (!LazyMode::is_enabled()) {\n    if (is_local) {\n      return g_eager_local_interpreter;\n    } else {\n      return g_eager_global_interpreter;\n    }\n  } else {\n    return g_lazy_interpreter;\n  }\n}\n\n}  // namespace\n\ntemplate<>\n/* static */ Maybe<TensorTuple> OpInterpUtil::Dispatch<TensorTuple>(\n    const OpExpr& op_expr, const TensorTuple& inputs, const OpExprInterpContext& ctx) {\n  OF_PROFILER_RANGE_GUARD(\"Dispatch\");\n  auto outputs = std::make_shared<TensorTuple>(op_expr.output_size());\n  JUST(Dispatch(op_expr, inputs, outputs.get(), ctx));\n  return outputs;\n}\n\ntemplate<>\n/* static */ Maybe<Tensor> OpInterpUtil::Dispatch<Tensor>(const OpExpr& op_expr,\n                                                          const TensorTuple& inputs,\n                                                          const OpExprInterpContext& ctx) {\n  OF_PROFILER_RANGE_GUARD(\"Dispatch\");\n  return JUST(Dispatch<TensorTuple>(op_expr, inputs, ctx))->at(0);\n}\n\n/* static */ Maybe<void> OpInterpUtil::Dispatch(const OpExpr& op_expr, const TensorTuple& inputs,\n                                                TensorTuple* outputs,\n                                                const OpExprInterpContext& ctx) {\n  OF_PROFILER_RANGE_GUARD(\"Dispatch\");\n  functional::TensorProcessorPipe processor(inputs, outputs);\n  if (autocast::is_enabled()) {\n    JUST(processor.Apply<functional::TensorAutoCastProcessor>(\n        *JUST(op_expr.GetOrCreateAutoCastMeta())));\n  }\n  JUST(processor.Apply<functional::TensorLayoutProcessor>(JUST(op_expr.SupportNonContiguous())));\n  return JUST(GetInterpreter(processor.inputs(), ctx, op_expr))\n      ->Apply(op_expr, processor.inputs(), processor.outputs(), ctx);\n}\n\n/* static */ Maybe<OperatorConf> OpInterpUtil::GenBuiltinOpConf(const BuiltinOpExpr& op_expr,\n                                                                const AttrMap& attrs) {\n  auto op_conf = std::make_shared<OperatorConf>();\n  JUST(op_expr.BuildOpConf(op_conf.get(), attrs));\n  return op_conf;\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/op_interpreter/op_interpreter_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_UTIL_H_\n#define ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_UTIL_H_\n\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter.h\"\n#include \"oneflow/core/framework/scope_util.h\"\n#include \"oneflow/core/framework/session_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n\nnamespace oneflow {\nnamespace one {\n\nclass OpInterpUtil {\n public:\n  template<typename T>\n  static Maybe<T> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, const AttrMap& attrs) {\n    return Dispatch<T>(op_expr, inputs, OpExprInterpContext(attrs));\n  }\n\n  template<typename T>\n  static Maybe<T> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs) {\n    return Dispatch<T>(op_expr, inputs, OpExprInterpContext(AttrMap{}));\n  }\n\n  template<typename T>\n  static Maybe<T> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs,\n                           const OpExprInterpContext& ctx);\n\n  static Maybe<void> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs,\n                              TensorTuple* outputs, const AttrMap& attrs) {\n    return Dispatch(op_expr, inputs, outputs, OpExprInterpContext(attrs));\n  }\n\n  static Maybe<void> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs,\n                              TensorTuple* outputs) {\n    return Dispatch(op_expr, inputs, outputs, OpExprInterpContext(AttrMap{}));\n  }\n\n  static Maybe<void> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs,\n                              TensorTuple* outputs, const OpExprInterpContext& ctx);\n\n  static Maybe<OperatorConf> GenBuiltinOpConf(const BuiltinOpExpr& op_expr, const AttrMap& attrs);\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/framework/op_interpreter.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_H_\n#define ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_H_\n\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/common/optional.h\"\n\nnamespace oneflow {\n\nclass Device;\nclass ParallelDesc;\nclass NdSbp;\n\nnamespace one {\n\nstruct OpExprInterpContext {\n  OpExprInterpContext(const AttrMap& attrs_arg) : attrs(attrs_arg) {}\n  OpExprInterpContext(const AttrMap& attrs_arg, Symbol<Device> device_arg)\n      : attrs(attrs_arg), device(device_arg) {}\n  OpExprInterpContext(const AttrMap& attrs_arg, std::shared_ptr<user_op::OpKernelState> state_arg)\n      : attrs(attrs_arg), state(state_arg) {}\n  OpExprInterpContext(const AttrMap& attrs_arg, Symbol<Device> device_arg,\n                      std::shared_ptr<user_op::OpKernelState> state_arg)\n      : attrs(attrs_arg), device(device_arg), state(state_arg) {}\n  OpExprInterpContext(const AttrMap& attrs_arg, Symbol<ParallelDesc> parallel_desc_arg)\n      : attrs(attrs_arg), parallel_desc(parallel_desc_arg) {}\n  OpExprInterpContext(const AttrMap& attrs_arg, Symbol<ParallelDesc> parallel_desc_arg,\n                      Symbol<NdSbp> nd_sbp_arg)\n      : attrs(attrs_arg), parallel_desc(parallel_desc_arg), nd_sbp(nd_sbp_arg) {}\n  OpExprInterpContext(const AttrMap& attrs_arg, Symbol<ParallelDesc> parallel_desc_arg,\n                      Symbol<NdSbp> nd_sbp_arg, std::shared_ptr<user_op::OpKernelState> state_arg)\n      : attrs(attrs_arg), parallel_desc(parallel_desc_arg), nd_sbp(nd_sbp_arg), state(state_arg) {}\n\n  AttrMap attrs;\n  Optional<Symbol<Device>> device;               // for local op\n  Optional<Symbol<ParallelDesc>> parallel_desc;  // for global op\n  Optional<Symbol<NdSbp>> nd_sbp;                // for global op\n  std::shared_ptr<user_op::OpKernelState> state;\n};\n\nclass OpExprInterpreter {\n public:\n  OpExprInterpreter() = default;\n  virtual ~OpExprInterpreter() = default;\n\n  Maybe<void> Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs,\n                    const AttrMap& attrs) const {\n    return Apply(op, inputs, outputs, OpExprInterpContext(attrs));\n  }\n\n  Maybe<void> Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs) const {\n    return Apply(op, inputs, outputs, AttrMap{});\n  }\n\n  virtual Maybe<void> Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs,\n                            const OpExprInterpContext& ctx) const = 0;\n};\n\n#define FOR_EACH_BUILTIN_OPS(_macro) \\\n  _macro(UserOp);                    \\\n  _macro(SelectTopNOp);              \\\n  _macro(VariableOp);                \\\n  _macro(CastToLocalOp);             \\\n  _macro(CastFromLocalOp);           \\\n  
_macro(GlobalToGlobalOp);          \\\n  _macro(LocalToGlobalOp);           \\\n  _macro(GlobalToLocalOp);           \\\n  _macro(DistributeSplitOp);         \\\n  _macro(DistributeCloneOp);         \\\n  _macro(DistributeConcatOp);        \\\n  _macro(DistributeAddOp);\n\n#define DECLARE_NORMAL_APPLY_FUNC(op_type)                                               \\\n  virtual Maybe<void> ApplyImpl(const op_type##Expr& op_expr, const TensorTuple& inputs, \\\n                                TensorTuple* outputs, const OpExprInterpContext& ctx) const\n\n#define DECLARE_PURE_VIRTUAL_APPLY_FUNC(op_type) DECLARE_NORMAL_APPLY_FUNC(op_type) = 0;\n\n#define DECLARE_OVERRIDE_APPLY_FUNC(op_type)                                     \\\n  Maybe<void> ApplyImpl(const op_type##Expr& op_expr, const TensorTuple& inputs, \\\n                        TensorTuple* outputs, const OpExprInterpContext& ctx) const override;\n\nclass LazyInterpreter : public OpExprInterpreter {\n public:\n  LazyInterpreter() : OpExprInterpreter() {}\n  virtual ~LazyInterpreter() = default;\n\n  Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs,\n                    const OpExprInterpContext& ctx) const override;\n\n private:\n  DECLARE_NORMAL_APPLY_FUNC(UserOp);\n  DECLARE_NORMAL_APPLY_FUNC(FeedInputOp);\n  DECLARE_NORMAL_APPLY_FUNC(FeedVariableOp);\n  DECLARE_NORMAL_APPLY_FUNC(FetchOutputOp);\n  DECLARE_NORMAL_APPLY_FUNC(FunctionOp);\n  DECLARE_NORMAL_APPLY_FUNC(GlobalToGlobalOp);\n  DECLARE_NORMAL_APPLY_FUNC(ImageDecoderRandomCropResizeOp);\n};\n\nclass EagerInterpreter : public OpExprInterpreter {\n public:\n  explicit EagerInterpreter(bool is_local) : OpExprInterpreter(), is_local_(is_local) {}\n  virtual ~EagerInterpreter() = default;\n\n  Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs,\n                    const OpExprInterpContext& ctx) const override;\n\n protected:\n  // NOTE(lixiang): To ensure the correctness of GlobalMode, records whether this interpreter\n  // handles local operations; it is initialized to true when using EagerLocalInterpreter.\n  //   Used by Maybe<void> EagerInterpreter::Apply.\n  bool is_local_;\n\n private:\n  FOR_EACH_BUILTIN_OPS(DECLARE_PURE_VIRTUAL_APPLY_FUNC);\n  DECLARE_NORMAL_APPLY_FUNC(FunctionOp);\n};\n\nclass EagerGlobalInterpreter : public EagerInterpreter {\n public:\n  EagerGlobalInterpreter() : EagerInterpreter(false) {}\n  virtual ~EagerGlobalInterpreter() = default;\n\n private:\n  FOR_EACH_BUILTIN_OPS(DECLARE_OVERRIDE_APPLY_FUNC);\n};\n\nclass EagerLocalInterpreter : public EagerInterpreter {\n public:\n  EagerLocalInterpreter() : EagerInterpreter(true) {}\n  virtual ~EagerLocalInterpreter() = default;\n\n private:\n  FOR_EACH_BUILTIN_OPS(DECLARE_OVERRIDE_APPLY_FUNC);\n};\n\n#undef DECLARE_OVERRIDE_APPLY_FUNC\n#undef DECLARE_PURE_VIRTUAL_APPLY_FUNC\n#undef DECLARE_NORMAL_APPLY_FUNC\n#undef FOR_EACH_BUILTIN_OPS\n\nclass AutogradInterpreter {\n public:\n  AutogradInterpreter() = delete;\n  explicit AutogradInterpreter(const std::shared_ptr<OpExprInterpreter>& internal)\n      : internal_(internal) {}\n\n  virtual ~AutogradInterpreter() = default;\n\n  Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs,\n                    const AttrMap& attrs) const {\n    return Apply(op_expr, inputs, outputs, OpExprInterpContext(attrs));\n  }\n\n  Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs) const {\n    return Apply(op_expr, inputs, outputs, OpExprInterpContext(AttrMap{}));\n  }\n\n  
Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs,\n                    const OpExprInterpContext& ctx) const;\n\n private:\n  std::shared_ptr<OpExprInterpreter> internal_;\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_H_\n"
  },
  {
    "path": "oneflow/core/framework/op_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/framework/attr_value_accessor.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nvoid OpKernel::InferShape(KernelInferContext* ctx) const {\n  InferContext* op_infer_ctx = ctx->MutOpInferContext();\n  CHECK_NOTNULL(op_infer_ctx);\n  ctx->GetOpInferFn()(op_infer_ctx);\n  for (const auto& arg_pair : ctx->outputs()) {\n    const Shape& shape = op_infer_ctx->OutputShape(arg_pair.first, arg_pair.second);\n    auto mut_shape_view = ctx->MutShapeView4ArgNameAndIndex(arg_pair.first, arg_pair.second);\n    mut_shape_view.set_shape(shape);\n  }\n}\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/op_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_OP_KERNEL_H_\n#define ONEFLOW_CORE_FRAMEWORK_OP_KERNEL_H_\n\n#include <memory>\n\n#include \"oneflow/core/common/throw.h\"\n\n#include \"oneflow/core/framework/util.h\"\n#include \"oneflow/core/framework/user_op_tensor.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/framework/user_op_registry.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/job/placement.pb.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\nclass JobDesc;\n\nnamespace user_op {\n\nclass KernelInitContext {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(KernelInitContext);\n  virtual ~KernelInitContext() = default;\n\n  virtual ep::Stream* stream() = 0;\n\n  virtual DeviceType device_type() const = 0;\n  virtual const ParallelContext& parallel_ctx() const = 0;\n  virtual const TensorDesc* TensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n  virtual const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n  virtual const TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string&,\n                                                              int32_t) const = 0;\n  virtual const ParallelDesc& parallel_desc() const = 0;\n  virtual const NdSbp& NdSbp4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n\n  virtual const std::vector<std::pair<std::string, int32_t>>& inputs() const = 0;\n  virtual const std::vector<std::pair<std::string, int32_t>>& outputs() const = 0;\n\n  const std::string& input(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().input(arg_name, index);\n  }\n  const std::string& output(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().output(arg_name, index);\n  }\n  bool has_input(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().has_input(arg_name, index);\n  }\n  bool has_output(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().has_output(arg_name, index);\n  }\n  int32_t input_size(const std::string& arg_name) const {\n    return user_op_conf().input_size(arg_name);\n  }\n  int32_t output_size(const std::string& arg_name) const {\n    return user_op_conf().output_size(arg_name);\n  }\n  const std::string& op_name() const { return user_op_conf().op_name(); }\n  const std::string& op_type_name() const { return user_op_conf().op_type_name(); }\n  const OperatorConf& op_conf() const { return user_op_conf().op_conf(); }\n\n  template<typename T>\n  const T& Attr(const std::string& attr_name) const {\n    return AttrValueCast<T>(*Attr4Name(attr_name));\n  }\n\n  template<typename T>\n  const T& attr(const std::string& attr_name) const;\n\n protected:\n  KernelInitContext() = default;\n\n  virtual const UserOpConfWrapper& 
user_op_conf() const = 0;\n  virtual const std::shared_ptr<const AttrVal>& Attr4Name(const std::string& attr_name) const = 0;\n};\n\nclass KernelCacheContext {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(KernelCacheContext);\n  virtual ~KernelCacheContext() = default;\n\n  virtual ep::Stream* stream() = 0;\n\n  virtual DeviceType device_type() const = 0;\n  virtual const ParallelContext& parallel_ctx() const = 0;\n  virtual const TensorDesc* TensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n  virtual const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n  virtual const TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string&,\n                                                              int32_t) const = 0;\n  virtual const ParallelDesc& parallel_desc() const = 0;\n  virtual const NdSbp& NdSbp4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n\n  virtual const std::vector<std::pair<std::string, int32_t>>& inputs() const = 0;\n  virtual const std::vector<std::pair<std::string, int32_t>>& outputs() const = 0;\n\n  const std::string& input(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().input(arg_name, index);\n  }\n  const std::string& output(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().output(arg_name, index);\n  }\n  bool has_input(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().has_input(arg_name, index);\n  }\n  bool has_output(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().has_output(arg_name, index);\n  }\n  int32_t input_size(const std::string& arg_name) const {\n    return user_op_conf().input_size(arg_name);\n  }\n  int32_t output_size(const std::string& arg_name) const {\n    return user_op_conf().output_size(arg_name);\n  }\n  const std::string& op_name() const { return user_op_conf().op_name(); }\n  const std::string& op_type_name() const { return user_op_conf().op_type_name(); }\n  const OperatorConf& op_conf() const { return user_op_conf().op_conf(); }\n\n  template<typename T>\n  const T& Attr(const std::string& attr_name) const {\n    return AttrValueCast<T>(*Attr4Name(attr_name));\n  }\n\n  template<typename T>\n  const T& attr(const std::string& attr_name) const;\n\n protected:\n  KernelCacheContext() = default;\n\n  virtual const UserOpConfWrapper& user_op_conf() const = 0;\n  virtual const std::shared_ptr<const AttrVal>& Attr4Name(const std::string& attr_name) const = 0;\n};\n\nclass KernelInferContext {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(KernelInferContext);\n  virtual ~KernelInferContext() = default;\n\n  virtual const std::vector<std::pair<std::string, int32_t>>& inputs() const = 0;\n  virtual const std::vector<std::pair<std::string, int32_t>>& outputs() const = 0;\n  virtual const TensorDesc* TensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n  virtual DeviceType device_type() const = 0;\n  virtual const ParallelContext& parallel_ctx() const = 0;\n\n  virtual ep::Stream* stream() = 0;\n  virtual Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t arg_index) = 0;\n  virtual ShapeView ShapeView4ArgNameAndIndex(const std::string& arg_name, int32_t arg_index) = 0;\n  virtual MutShapeView MutShapeView4ArgNameAndIndex(const std::string& arg_name,\n                                                    int32_t arg_index) = 0;\n\n  const std::string& input(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().input(arg_name, 
index);\n  }\n  const std::string& output(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().output(arg_name, index);\n  }\n  bool has_input(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().has_input(arg_name, index);\n  }\n  bool has_output(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().has_output(arg_name, index);\n  }\n  int32_t input_size(const std::string& arg_name) const {\n    return user_op_conf().input_size(arg_name);\n  }\n  int32_t output_size(const std::string& arg_name) const {\n    return user_op_conf().output_size(arg_name);\n  }\n  const std::string& op_name() const { return user_op_conf().op_name(); }\n  const std::string& op_type_name() const { return user_op_conf().op_type_name(); }\n\n  template<typename T>\n  const T& Attr(const std::string& attr_name) const {\n    return AttrValueCast<T>(*Attr4Name(attr_name));\n  }\n\n  virtual InferContext* MutOpInferContext() {\n    UNIMPLEMENTED();\n    return nullptr;\n  }\n  virtual const TensorDescInferFn& GetOpInferFn() const {\n    UNIMPLEMENTED();\n    static TensorDescInferFn empty_fn;\n    return empty_fn;\n  }\n\n protected:\n  KernelInferContext() = default;\n\n  virtual const UserOpConfWrapper& user_op_conf() const = 0;\n  virtual const std::shared_ptr<const AttrVal>& Attr4Name(const std::string& attr_name) const = 0;\n};\n\nclass Tensor;\n\nclass KernelComputeContext {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(KernelComputeContext);\n  virtual ~KernelComputeContext() = default;\n\n  virtual Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) = 0;\n  virtual ep::Stream* stream() = 0;\n\n  virtual const TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                       int32_t index) const = 0;\n  virtual DeviceType device_type() const = 0;\n  virtual const ParallelContext& parallel_ctx() const = 0;\n\n  virtual const std::vector<std::pair<std::string, int32_t>>& inputs() const = 0;\n  virtual const std::vector<std::pair<std::string, int32_t>>& outputs() const = 0;\n  const std::string& input(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().input(arg_name, index);\n  }\n  const std::string& output(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().output(arg_name, index);\n  }\n  bool has_input(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().has_input(arg_name, index);\n  }\n  bool has_output(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().has_output(arg_name, index);\n  }\n  int32_t input_size(const std::string& arg_name) const {\n    return user_op_conf().input_size(arg_name);\n  }\n  int32_t output_size(const std::string& arg_name) const {\n    return user_op_conf().output_size(arg_name);\n  }\n  const std::string& op_name() const { return user_op_conf().op_name(); }\n  const std::string& op_type_name() const { return user_op_conf().op_type_name(); }\n\n  template<typename T>\n  const T& Attr(const std::string& attr_name) const {\n    return AttrValueCast<T>(*Attr4Name(attr_name));\n  }\n\n protected:\n  KernelComputeContext() = default;\n\n  virtual const UserOpConfWrapper& user_op_conf() const = 0;\n\n  virtual const std::shared_ptr<const AttrVal>& Attr4Name(const std::string& attr_name) const = 0;\n};\n\nclass OpKernelState {\n public:\n  virtual ~OpKernelState() = default;\n\n protected:\n  OpKernelState() = default;\n};\n\nclass 
OpKernelCache {\n public:\n  virtual ~OpKernelCache() = default;\n\n  static const int32_t kAllMayChanged = 0;\n  static const int32_t kShapeNotChanged = 1 << 0;\n  static const int32_t kAttrNotChanged = 1 << 1;\n\n protected:\n  OpKernelCache() = default;\n};\n\nclass OpKernel;\n\ntemplate<typename T, typename... Args>\nOpKernel* NewOpKernel(Args&&... args);\n\nclass OpKernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(OpKernel);\n  virtual ~OpKernel() = default;\n\n  virtual std::shared_ptr<OpKernelState> CreateOpKernelState(KernelInitContext* ctx) const {\n    return std::shared_ptr<OpKernelState>();\n  }\n\n  virtual std::shared_ptr<OpKernelCache> InitOpKernelCache(KernelCacheContext* ctx) const {\n    return std::shared_ptr<OpKernelCache>();\n  }\n\n  virtual void InitOpKernelCacheWithFlags(KernelCacheContext* ctx, int8_t flag,\n                                          std::shared_ptr<OpKernelCache>* cache_ptr) const {\n    *cache_ptr = InitOpKernelCache(ctx);\n  }\n\n  virtual void Compute(KernelComputeContext* ctx, OpKernelState*, const OpKernelCache*) const {\n    Compute(ctx);\n  }\n  virtual void Compute(KernelComputeContext* ctx) const {\n    LOG(WARNING) << ctx->op_name() << \" :UNIMPLEMENTED\";\n  }\n  virtual void InferShape(KernelInferContext* ctx) const;\n  virtual bool AlwaysComputeWhenAllOutputsEmpty() const = 0;\n  virtual bool IsKernelLaunchSynchronized() const { return true; }\n\n  bool has_state_or_cache() const { return has_state_or_cache_; }\n\n protected:\n  OpKernel() : has_state_or_cache_(true) {}\n\n private:\n  template<typename T, typename... Args>\n  friend OpKernel* NewOpKernel(Args&&... args);\n  bool has_state_or_cache_;\n};\n\ntemplate<typename T, typename... Args>\nOpKernel* NewOpKernel(Args&&... args) {\n  OpKernel* ptr = new T(std::forward<Args>(args)...);\n  ptr->has_state_or_cache_ = !(std::is_same<decltype(&OpKernel::CreateOpKernelState),\n                                            decltype(&T::CreateOpKernelState)>::value\n                               && std::is_same<decltype(&OpKernel::InitOpKernelCache),\n                                               decltype(&T::InitOpKernelCache)>::value\n                               && std::is_same<decltype(&OpKernel::InitOpKernelCacheWithFlags),\n                                               decltype(&T::InitOpKernelCacheWithFlags)>::value);\n  return ptr;\n}\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif\n"
  },
  {
    "path": "oneflow/core/framework/op_kernel_infer_cache.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/op_kernel_infer_cache.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nOpKernelInferCache::OpKernelInferCache(const KernelConf& kernel_conf, const void* scope) {\n  const OperatorConf& op_conf = kernel_conf.op_attribute().op_conf();\n  std::shared_ptr<Operator> op = CHECK_JUST(ConstructOp(op_conf));\n  cache_key_.scope = scope;\n  cache_key_.op_conf_sym = op->GetOpConfWithoutOpNameAndLbn();\n  cache_key_.ibn_idx2shape_sym.resize(op->input_bns().size());\n  cache_key_.dtype_signature_sym = SymbolOf(kernel_conf.dtype_signature());\n}\n\nbool OpKernelInferCache::IsCacheHit() const {\n  size_t hash_value = std::hash<KeyType>()(cache_key_);\n  HashEqTraitPtr<const KeyType> ptr_wrapper(&cache_key_, hash_value);\n  return cached_key2value_.find(ptr_wrapper) != cached_key2value_.end();\n}\n\nOpKernelInferCache::ValueType OpKernelInferCache::GetCacheValue() const {\n  size_t hash_value = std::hash<KeyType>()(cache_key_);\n  HashEqTraitPtr<const KeyType> ptr_wrapper(&cache_key_, hash_value);\n  CHECK(cached_key2value_.find(ptr_wrapper) != cached_key2value_.end());\n  return cached_key2value_.at(ptr_wrapper);\n}\n\nvoid OpKernelInferCache::UpdateCacheKey(KernelInferContext* ctx) {\n  auto GetSymbolOfShape = [&](const std::string& arg_name, int32_t arg_index) -> Symbol<Shape> {\n    Shape shape;\n    ctx->ShapeView4ArgNameAndIndex(arg_name, arg_index).ToShape(&shape);\n    return SymbolOf(shape);\n  };\n  const auto& inputs = ctx->inputs();\n  FOR_RANGE(int, i, 0, inputs.size()) {\n    const auto& arg_pair = inputs.at(i);\n    cache_key_.ibn_idx2shape_sym.at(i) = GetSymbolOfShape(arg_pair.first, arg_pair.second);\n  }\n}\n\nvoid OpKernelInferCache::UpdateCacheValue(KernelInferContext* ctx) {\n  // TODO: make max size configurable\n  if (cached_key2value_.size() >= kReleaseInIndependentThreadThreshold) { Reset(); }\n  auto* cache_value = new OpInferCacheValue();\n  cache_value->obn_idx2shape_sym.resize(ctx->outputs().size());\n  FOR_RANGE(int, i, 0, ctx->outputs().size()) {\n    const auto& out_arg_pair = ctx->outputs().at(i);\n    const ShapeView& out_shape_view =\n        ctx->ShapeView4ArgNameAndIndex(out_arg_pair.first, out_arg_pair.second);\n    Shape out_shape;\n    out_shape_view.ToShape(&out_shape);\n    cache_value->obn_idx2shape_sym.at(i).reset(out_shape);\n  }\n  KeyType* new_key = new KeyType(cache_key_);\n  key_storage_.emplace_back(new_key);\n  size_t hash_value = std::hash<KeyType>()(cache_key_);\n  HashEqTraitPtr<const KeyType> ptr_wrapper(new_key, hash_value);\n  CHECK(cached_key2value_.emplace(ptr_wrapper, ValueType(cache_value)).second);\n}\n\nvoid OpKernelInferCache::Reset() {\n  CHECK_EQ(cached_key2value_.size(), key_storage_.size());\n  HashMap to_release_key2values;\n  KeyStorage to_release_key_storage;\n  std::swap(cached_key2value_, 
to_release_key2values);\n  std::swap(key_storage_, to_release_key_storage);\n  if (to_release_key2values.size() <= kReleaseInIndependentThreadThreshold) {\n    to_release_key2values.clear();\n    to_release_key_storage.clear();\n  } else {\n    // Release large caches in a detached background thread; a joinable std::thread\n    // would call std::terminate() in its destructor.\n    std::thread(\n        [](HashMap&& cache, KeyStorage&& key_storage) {\n          cache.clear();\n          key_storage.clear();\n        },\n        std::move(to_release_key2values), std::move(to_release_key_storage))\n        .detach();\n  }\n}\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/op_kernel_infer_cache.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_OP_KERNEL_INFER_CACHE_H_\n#define ONEFLOW_CORE_FRAMEWORK_OP_KERNEL_INFER_CACHE_H_\n\n#include \"oneflow/core/operator/op_infer_cache.h\"\n#include \"oneflow/core/common/hash_eq_trait_ptr.h\"\n#include \"oneflow/core/kernel/kernel.pb.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nclass KernelInferContext;\n\nclass OpKernelInferCache final {\n public:\n  using KeyType = OpInferCacheKey;\n  using ValueType = std::shared_ptr<const OpInferCacheValue>;\n  using HashMap = std::unordered_map<HashEqTraitPtr<const KeyType>, ValueType>;\n  using KeyStorage = std::list<std::unique_ptr<KeyType>>;\n  static constexpr size_t kReleaseInIndependentThreadThreshold = 4096;\n\n  OpKernelInferCache(const KernelConf& kernel_conf, const void* scope);\n  ~OpKernelInferCache() = default;\n\n  bool IsCacheHit() const;\n  ValueType GetCacheValue() const;\n  void UpdateCacheKey(KernelInferContext* ctx);\n  void UpdateCacheValue(KernelInferContext* ctx);\n  void Reset();\n\n private:\n  KeyType cache_key_;\n  HashMap cached_key2value_;\n  KeyStorage key_storage_;\n};\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_OP_KERNEL_INFER_CACHE_H_\n"
  },
  {
    "path": "oneflow/core/framework/ordered_string_list.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_ORDERED_STRING_LIST_H_\n#define ONEFLOW_CORE_FRAMEWORK_ORDERED_STRING_LIST_H_\n\n#include \"llvm/ADT/StringRef.h\"\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/small_vector.h\"\n\nnamespace oneflow {\n\ntemplate<int N>\nclass OrderedStringList {\n public:\n  OrderedStringList() = default;\n\n  size_t size() const { return strings_.size(); }\n\n  void emplace_back(llvm::StringRef s) {\n    strings_.emplace_back(std::make_shared<std::string>(s.str()));\n    order_.emplace(*strings_.back(), order_.size());\n  }\n\n  int order(llvm::StringRef s) const {\n    const auto& it = order_.find(s);\n    if (it == order_.end()) { return -1; }\n    return it->second;\n  }\n\n  const std::string& operator[](int idx) { return *(strings_[idx]); }\n\n private:\n  struct Hash {\n    size_t operator()(llvm::StringRef val) const {\n      return HashCombine(val.size(), val.size() > 0 ? static_cast<size_t>(val.data()[0] - '0') : 0);\n    }\n  };\n  HashMap<llvm::StringRef, int, Hash> order_;\n  // Use shared_ptr to prevent the appended element from being freed when the\n  // vector increases\n  small_vector<std::shared_ptr<std::string>, N> strings_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_ORDERED_STRING_LIST_H_\n"
  },
  {
    "path": "oneflow/core/framework/parallel_conf_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/framework/parallel_conf_util.h\"\n#include \"oneflow/core/common/shape.pb.h\"\n\nnamespace oneflow {\n\nMaybe<std::tuple<std::string, std::vector<std::string>, std::shared_ptr<ShapeProto>>>\nGetDeviceTagAndMachineDeviceIdsAndHierarchy(const ParallelConf& parallel_conf) {\n  std::vector<std::string> machine_device_ids;\n  machine_device_ids.reserve(parallel_conf.device_name().size());\n  for (const std::string& device_name : parallel_conf.device_name()) {\n    machine_device_ids.emplace_back(device_name);\n  }\n  std::shared_ptr<ShapeProto> hierarchy;\n  if (parallel_conf.has_hierarchy()) { hierarchy.reset(new ShapeProto(parallel_conf.hierarchy())); }\n  return std::make_tuple(parallel_conf.device_tag(), machine_device_ids, hierarchy);\n}\n\nMaybe<ParallelConf> MakeParallelConf(const std::string& device_tag,\n                                     const std::vector<std::string>& machine_device_ids,\n                                     const std::shared_ptr<Shape>& hierarchy) {\n  std::shared_ptr<ParallelConf> parallel_conf = std::make_shared<ParallelConf>();\n  parallel_conf->set_device_tag(device_tag);\n  for (const std::string& machine_device_id : machine_device_ids) {\n    size_t pos = machine_device_id.find(':');\n    CHECK_NE_OR_RETURN(pos, std::string::npos) << \"device_name: \" << machine_device_id;\n    std::string machine_id = machine_device_id.substr(0, pos);\n    CHECK_OR_RETURN(\n        (IsStrInt(machine_id) || (machine_id[0] == '@' && IsStrInt(machine_id.substr(1)))))\n        << \" machine_id: \" << machine_id;\n    std::string device_id = machine_device_id.substr(pos + 1);\n    size_t minus_pos = device_id.rfind('-');\n    if (minus_pos == std::string::npos) {\n      CHECK_OR_RETURN(IsStrInt(device_id));\n    } else {\n      std::string min_id = device_id.substr(0, minus_pos);\n      CHECK_OR_RETURN(IsStrInt(min_id));\n      std::string max_id = device_id.substr(minus_pos + 1);\n      CHECK_OR_RETURN(IsStrInt(max_id));\n    }\n    parallel_conf->add_device_name(machine_device_id);\n    if (hierarchy) {\n      ShapeProto proto;\n      hierarchy->ToProto(&proto);\n      parallel_conf->mutable_hierarchy()->CopyFrom(proto);\n    }\n  }\n  return parallel_conf;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/parallel_conf_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_PRARLLEL_CONF_UTIL_H_\n#define ONEFLOW_CORE_FRAMEWORK_PRARLLEL_CONF_UTIL_H_\n\n#include <utility>\n#include <vector>\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/job/placement.pb.h\"\n#include \"oneflow/core/common/shape.h\"\n\nnamespace oneflow {\n\nMaybe<std::tuple<std::string, std::vector<std::string>, std::shared_ptr<ShapeProto>>>\nGetDeviceTagAndMachineDeviceIdsAndHierarchy(const ParallelConf& parallel_conf);\n\nMaybe<ParallelConf> MakeParallelConf(const std::string& device_tag,\n                                     const std::vector<std::string>& machine_device_ids,\n                                     const std::shared_ptr<Shape>& hierarchy);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_PRARLLEL_CONF_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/framework/parallel_conf_util_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include <algorithm>\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/parallel_conf_util.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(ParallelConfUtil, MakeParallelConfSuccess) {\n  std::string device_tag = \"cpu\";\n  std::vector<std::string> machine_device_ids;\n  machine_device_ids.emplace_back(\"0:0-3\");\n  machine_device_ids.emplace_back(\"1:0-3\");\n  auto parallel_conf = CHECK_JUST(MakeParallelConf(device_tag, machine_device_ids, nullptr));\n  ASSERT_EQ(parallel_conf->device_tag(), \"cpu\");\n  ASSERT_EQ(parallel_conf->device_name().size(), 2);\n  ASSERT_EQ(parallel_conf->has_hierarchy(), false);\n}\n\nTEST(ParallelConfUtil, MakeParallelConfError) {\n  std::string device_tag = \"cpu\";\n  std::vector<std::string> machine_device_ids;\n  machine_device_ids.emplace_back(\"0:0-3\");\n  machine_device_ids.emplace_back(\"1:0-\");\n  auto parallel_conf = TRY(MakeParallelConf(device_tag, machine_device_ids, nullptr));\n  ASSERT_EQ(parallel_conf.error()->has_check_failed_error(), true);\n}\n\nTEST(ParallelConfUtil, GetDeviceTagAndMachineDeviceIdsAndHierarchy) {\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  parallel_conf.add_device_name(\"0:0-1\");\n  parallel_conf.add_device_name(\"0:2-3\");\n  parallel_conf.add_device_name(\"1:0-1\");\n  parallel_conf.add_device_name(\"1:2-3\");\n  parallel_conf.mutable_hierarchy()->add_dim(2);\n  parallel_conf.mutable_hierarchy()->add_dim(4);\n  std::tuple<std::string, std::vector<std::string>, std::shared_ptr<ShapeProto>>\n      tag_and_dev_ids_and_hierarchy =\n          *CHECK_JUST(GetDeviceTagAndMachineDeviceIdsAndHierarchy(parallel_conf));\n  std::string device_tag = std::get<0>(tag_and_dev_ids_and_hierarchy);\n  std::vector<std::string> machine_device_ids = std::get<1>(tag_and_dev_ids_and_hierarchy);\n  std::shared_ptr<ShapeProto> hierarchy = std::get<2>(tag_and_dev_ids_and_hierarchy);\n  ASSERT_EQ(device_tag, \"cpu\");\n  ASSERT_NE(std::count(machine_device_ids.begin(), machine_device_ids.end(), \"0:0-1\"), 0);\n  ASSERT_NE(std::count(machine_device_ids.begin(), machine_device_ids.end(), \"0:2-3\"), 0);\n  ASSERT_NE(std::count(machine_device_ids.begin(), machine_device_ids.end(), \"1:0-1\"), 0);\n  ASSERT_NE(std::count(machine_device_ids.begin(), machine_device_ids.end(), \"1:2-3\"), 0);\n  ASSERT_EQ(std::count(machine_device_ids.begin(), machine_device_ids.end(), \"2:0-3\"), 0);\n  ASSERT_EQ(hierarchy->dim(0), 2);\n  ASSERT_EQ(hierarchy->dim(1), 4);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/placed_nd_sbp.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/placed_nd_sbp.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/common/decorator.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<Symbol<PlacedNdSbp>> RawNew(const Symbol<NdSbp>& nd_sbp,\n                                  const Symbol<ParallelDesc>& placement) {\n  CHECK_OR_RETURN(nd_sbp);\n  CHECK_OR_RETURN(placement);\n  CHECK_GT_OR_RETURN(nd_sbp->sbp_parallel_size(), 0);\n  CHECK_EQ_OR_RETURN(nd_sbp->sbp_parallel_size(), placement->hierarchy()->NumAxes());\n  return SymbolOf(PlacedNdSbp(nd_sbp, placement));\n}\n\n}  // namespace\n\ndecltype(PlacedNdSbp::New) PlacedNdSbp::New = DECORATE(&RawNew, ThreadLocal);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/placed_nd_sbp.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_PLACED_ND_SBP_H_\n#define ONEFLOW_CORE_FRAMEWORK_PLACED_ND_SBP_H_\n\n#include <functional>\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nclass NdSbp;\nclass ParallelDesc;\n\nclass PlacedNdSbp final {\n public:\n  PlacedNdSbp(const Symbol<NdSbp>& nd_sbp, const Symbol<ParallelDesc>& placement)\n      : nd_sbp_(nd_sbp), placement_(placement) {}\n  ~PlacedNdSbp() = default;\n\n  static Maybe<Symbol<PlacedNdSbp>> (*New)(const Symbol<NdSbp>&, const Symbol<ParallelDesc>&);\n\n  const Symbol<NdSbp>& nd_sbp() const { return nd_sbp_; }\n  const Symbol<ParallelDesc>& placement() const { return placement_; }\n\n  bool operator==(const PlacedNdSbp& other) const {\n    return this->nd_sbp_ == other.nd_sbp_ && this->placement_ == other.placement_;\n  }\n\n private:\n  Symbol<NdSbp> nd_sbp_;\n  Symbol<ParallelDesc> placement_;\n};\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::PlacedNdSbp> final {\n  size_t operator()(const oneflow::PlacedNdSbp& placed_nd_sbp) const {\n    return oneflow::Hash(placed_nd_sbp.nd_sbp(), placed_nd_sbp.placement());\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_PLACED_ND_SBP_H_\n"
  },
  {
    "path": "oneflow/core/framework/placement_sbp_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <tuple>\n#include <algorithm>\n#include \"oneflow/core/framework/placement_sbp_util.h\"\n#include \"oneflow/core/framework/placed_nd_sbp.h\"\n#include \"oneflow/core/common/tensor_meta.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/math_util.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\n\nnamespace private_details {\n\nnamespace {\n\nusing IndexVector = DimVector;\n\nMaybe<void> GetIndexesFromOffset(const Stride& strides, int64_t offset, IndexVector* indexes) {\n  indexes->resize(strides.size());\n  for (int i = 0; i < strides.size(); ++i) {\n    indexes->at(i) = offset / strides.at(i);\n    offset = offset % strides.at(i);\n  }\n  CHECK_EQ_OR_RETURN(offset, 0);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetOffsetFromIndexes(const Stride& strides, const IndexVector& indexes,\n                                 int64_t* offset) {\n  CHECK_EQ_OR_RETURN(strides.size(), indexes.size())\n      << Error::RuntimeError() << \"Expected size of strides to match that of indexes\";\n  *offset = 0;\n  for (int i = 0; i < strides.size(); ++i) { *offset += indexes.at(i) * strides.at(i); }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetSelectedIndex2OriginIndex(\n    const IndexVector& indexes, const std::vector<int>& axis2is_selected,\n    std::function<void(const DimVector&, DimVector*)>* SelectedIndex2OriginIndex) {\n  CHECK_EQ_OR_RETURN(axis2is_selected.size(), indexes.size());\n  *SelectedIndex2OriginIndex = [=](const DimVector& broadcast, DimVector* origin) {\n    origin->resize(indexes.size());\n    for (int i = 0; i < indexes.size(); ++i) {\n      origin->at(i) = axis2is_selected.at(i) ? 
broadcast.at(i) : indexes.at(i);\n    }\n  };\n  return Maybe<void>::Ok();\n}\n\nMaybe<const Shape> GetSelectedShape(const Shape& hierarchy_shape,\n                                    const std::vector<int>& axis2is_selected) {\n  CHECK_EQ_OR_RETURN(hierarchy_shape.NumAxes(), axis2is_selected.size());\n  DimVector dim_vec = hierarchy_shape.dim_vec();\n  for (int i = 0; i < axis2is_selected.size(); ++i) {\n    if (!axis2is_selected.at(i)) { dim_vec.at(i) = 1; }\n  }\n  return std::make_shared<const Shape>(dim_vec);\n}\n\nMaybe<Symbol<std::vector<int>>> CalcAxis2IsBroadcast(Symbol<NdSbp> nd_sbp) {\n  std::vector<int> axis2is_selected(nd_sbp->sbp_parallel_size());\n  for (int i = 0; i < axis2is_selected.size(); ++i) {\n    axis2is_selected.at(i) = nd_sbp->sbp_parallel(i).has_broadcast_parallel();\n  }\n  return SymbolOf(axis2is_selected);\n}\n\nstatic auto* GetAxis2IsBroadcast = DECORATE(&CalcAxis2IsBroadcast, ThreadLocal);\n\nMaybe<Symbol<ParallelDesc>> CalcSelectedSubParallelDesc(Symbol<ParallelDesc> parallel_desc,\n                                                        Symbol<std::vector<int>> axis2is_selected) {\n  const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc));\n  int64_t parallel_id = JUST(*opt_parallel_id);\n  const auto& hierarchy_shape = *parallel_desc->hierarchy();\n  const auto& broadcast_parallel_ids =\n      JUST(GetSelectedParallelIds(hierarchy_shape, *axis2is_selected, parallel_id));\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(parallel_desc->device_tag());\n  bool found_parallel_id = false;\n  for (int64_t i : *broadcast_parallel_ids) {\n    found_parallel_id = found_parallel_id || (i == parallel_id);\n    int64_t machine_id = JUST(parallel_desc->MachineId4ParallelId(i));\n    int64_t device_id = JUST(parallel_desc->DeviceId4ParallelId(i));\n    parallel_conf.add_device_name(std::string(\"@\") + std::to_string(machine_id) + \":\"\n                                  + std::to_string(device_id));\n  }\n  CHECK_OR_RETURN(found_parallel_id);\n  return SymbolOf(ParallelDesc(parallel_conf));\n}\n\nstatic auto* GetSelectedSubParallelDesc = DECORATE(&CalcSelectedSubParallelDesc, ThreadLocal);\n\n}  // namespace\n\nMaybe<Symbol<ParallelDesc>> CalcSubParallelDesc4Axis(Symbol<ParallelDesc> parallel_desc, int axis) {\n  const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc));\n  int64_t parallel_id = JUST(*opt_parallel_id);\n  const auto& hierarchy_shape = *parallel_desc->hierarchy();\n  Stride hierarchy_strides(hierarchy_shape);\n\n  int64_t index = CalcIndex4Axis(parallel_id, hierarchy_strides, axis);\n\n  int64_t stride = hierarchy_strides.at(axis);\n\n  int64_t start_parallel_id = parallel_id - index * stride;\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(parallel_desc->device_tag());\n  for (int64_t i = 0; i < hierarchy_shape.At(axis); ++i) {\n    int64_t id = start_parallel_id + i * stride;\n    int64_t machine_id = JUST(parallel_desc->MachineId4ParallelId(id));\n    int64_t device_id = JUST(parallel_desc->DeviceId4ParallelId(id));\n    parallel_conf.add_device_name(std::string(\"@\") + std::to_string(machine_id) + \":\"\n                                  + std::to_string(device_id));\n  }\n  return SymbolOf(ParallelDesc(parallel_conf));\n}\n\nMaybe<std::vector<int64_t>> GetSelectedParallelIds(const Shape& hierarchy_shape,\n                                                   const std::vector<int>& axis2is_selected,\n                                                   int64_t parallel_id) 
{\n  CHECK_EQ_OR_RETURN(hierarchy_shape.NumAxes(), axis2is_selected.size());\n  Stride hierarchy_strides(hierarchy_shape);\n  IndexVector indexes{};\n  JUST(GetIndexesFromOffset(hierarchy_strides, parallel_id, &indexes));\n  std::function<void(const DimVector&, DimVector*)> SelectedIndex2OriginIndex;\n  JUST(GetSelectedIndex2OriginIndex(indexes, axis2is_selected, &SelectedIndex2OriginIndex));\n  const auto& broadcast_shape = JUST(GetSelectedShape(hierarchy_shape, axis2is_selected));\n  Stride broadcast_strides(*broadcast_shape);\n  const auto& origin_offsets = std::make_shared<std::vector<int64_t>>(broadcast_shape->elem_cnt());\n  for (int64_t i = 0; i < broadcast_shape->elem_cnt(); ++i) {\n    IndexVector broadcast_indexes{};\n    JUST(GetIndexesFromOffset(broadcast_strides, i, &broadcast_indexes));\n    IndexVector origin_indexes{};\n    SelectedIndex2OriginIndex(broadcast_indexes, &origin_indexes);\n    int64_t origin_offset = -1;\n    JUST(GetOffsetFromIndexes(hierarchy_strides, origin_indexes, &origin_offset));\n    origin_offsets->at(i) = origin_offset;\n  }\n  return origin_offsets;\n}\n\nMaybe<Symbol<ParallelDesc>> GetBroadcastSubParallelDesc(Symbol<ParallelDesc> parallel_desc,\n                                                        Symbol<NdSbp> nd_sbp) {\n  const auto& axis2is_selected = JUST(GetAxis2IsBroadcast(nd_sbp));\n  return GetSelectedSubParallelDesc(parallel_desc, axis2is_selected);\n}\n\nnamespace {\n\nMaybe<Symbol<NdSbp>> MakeNdSbp(const SbpParallel& sbp) {\n  NdSbp nd_sbp;\n  nd_sbp.mutable_sbp_parallel()->Add()->CopyFrom(sbp);\n  return SymbolOf(nd_sbp);\n}\n\nMaybe<void> InitShapeAxis2NdSbpIndexes(\n    Symbol<NdSbp> nd_sbp, std::vector<std::vector<int64_t>>* shape_axis2nd_sbp_indexes) {\n  for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) {\n    const auto& sbp = nd_sbp->sbp_parallel(i);\n    if (sbp.has_split_parallel()) {\n      int64_t axis = sbp.split_parallel().axis();\n      CHECK_GE_OR_RETURN(axis, 0);\n      CHECK_LT_OR_RETURN(axis, shape_axis2nd_sbp_indexes->size());\n      shape_axis2nd_sbp_indexes->at(axis).emplace_back(i);\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckSplitAxisExpandable(\n    const Shape& hierarchy, const std::vector<std::vector<int64_t>>& shape_axis2src_nd_sbp_indexes,\n    const std::vector<std::vector<int64_t>>& shape_axis2dst_nd_sbp_indexes) {\n  const auto& GetHierarchyDim = [&](int64_t axis) { return hierarchy.At(axis); };\n  for (int i = 0; i < shape_axis2src_nd_sbp_indexes.size(); ++i) {\n    const auto& src_nd_sbp_indexes = JUST(VectorAt(shape_axis2src_nd_sbp_indexes, i));\n    if (src_nd_sbp_indexes.empty()) { continue; }\n    const auto& dst_nd_sbp_indexes = JUST(VectorAt(shape_axis2dst_nd_sbp_indexes, i));\n    if (dst_nd_sbp_indexes.empty()) { continue; }\n    std::vector<int64_t> src_nd_sbp_dims{};\n    src_nd_sbp_dims.reserve(src_nd_sbp_indexes.size());\n    std::transform(src_nd_sbp_indexes.begin(), src_nd_sbp_indexes.end(),\n                   std::back_inserter(src_nd_sbp_dims), GetHierarchyDim);\n    std::vector<int64_t> dst_nd_sbp_dims{};\n    dst_nd_sbp_dims.reserve(dst_nd_sbp_indexes.size());\n    std::transform(dst_nd_sbp_indexes.begin(), dst_nd_sbp_indexes.end(),\n                   std::back_inserter(dst_nd_sbp_dims), GetHierarchyDim);\n    CHECK_OR_RETURN(src_nd_sbp_dims == dst_nd_sbp_dims) << Error::BoxingNotSupportedError();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InitShapAxis2ExpandedDim(\n    std::vector<DimVector>* shape_axis2expanded_dims, const Shape& shape, const 
Shape& hierarchy,\n    const std::vector<std::vector<int64_t>>& shape_axis2src_nd_sbp_indexes,\n    const std::vector<std::vector<int64_t>>& shape_axis2dst_nd_sbp_indexes) {\n  std::vector<DimVector> shape_axis2required_dim(shape.NumAxes());\n  for (int i = 0; i < shape.NumAxes(); ++i) {\n    const auto& src_nd_sbp_indexes = shape_axis2src_nd_sbp_indexes.at(i);\n    const auto& dst_nd_sbp_indexes = shape_axis2dst_nd_sbp_indexes.at(i);\n    int64_t max_used_cnt = std::max<size_t>(src_nd_sbp_indexes.size(), dst_nd_sbp_indexes.size());\n    for (int j = 0; j < max_used_cnt; ++j) {\n      if (j < src_nd_sbp_indexes.size() && j < dst_nd_sbp_indexes.size()) {\n        int64_t m = hierarchy.At(src_nd_sbp_indexes.at(j));\n        int64_t n = hierarchy.At(dst_nd_sbp_indexes.at(j));\n        shape_axis2required_dim.at(i).emplace_back(Lcm(m, n));\n      } else if (j < src_nd_sbp_indexes.size()) {\n        shape_axis2required_dim.at(i).emplace_back(hierarchy.At(src_nd_sbp_indexes.at(j)));\n      } else if (j < dst_nd_sbp_indexes.size()) {\n        shape_axis2required_dim.at(i).emplace_back(hierarchy.At(dst_nd_sbp_indexes.at(j)));\n      } else {\n        UNIMPLEMENTED_THEN_RETURN();\n      }\n    }\n  }\n  for (int i = 0; i < shape.NumAxes(); ++i) {\n    int64_t total_dim = shape.At(i);\n    shape_axis2expanded_dims->at(i).clear();\n    if (JUST(VectorAt(shape_axis2required_dim, i)).empty()\n        || JUST(VectorAt(shape_axis2required_dim, i)).size() == 1) {\n      shape_axis2expanded_dims->at(i).emplace_back(total_dim);\n    } else {\n      Shape inner_shape(shape_axis2required_dim.at(i));\n      CHECK_EQ_OR_RETURN(total_dim % inner_shape.elem_cnt(), 0)\n          << \"dim \" << total_dim << \" (axis \" << i << \" in shape \" << shape.ToString() << \")\"\n          << \" cannot be reshaped into expanded shape \" << inner_shape.ToString();\n      auto* dim_vec = &shape_axis2expanded_dims->at(i);\n      *dim_vec = shape_axis2required_dim.at(i);\n      dim_vec->at(dim_vec->size() - 1) *= total_dim / inner_shape.elem_cnt();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<const Shape> Flatten(const std::vector<DimVector>& shape_axis2expanded_dims) {\n  DimVector dim_vec;\n  for (const auto& expanded_dims : shape_axis2expanded_dims) {\n    CHECK_OR_RETURN(!expanded_dims.empty());\n    dim_vec.insert(dim_vec.end(), expanded_dims.begin(), expanded_dims.end());\n  }\n  return std::make_shared<const Shape>(dim_vec);\n}\n\nMaybe<void> InitOldAxis2NewAxisOffset(std::vector<int64_t>* old_axis2new_axis_offset,\n                                      const std::vector<DimVector>& shape_axis2expanded_dims) {\n  for (int i = 0, offset = 0; i < shape_axis2expanded_dims.size(); ++i) {\n    old_axis2new_axis_offset->at(i) = offset;\n    offset += shape_axis2expanded_dims.at(i).size();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<Symbol<NdSbp>> ShiftSplitAxis(\n    Symbol<NdSbp> nd_sbp, const std::vector<std::vector<int64_t>>& shape_axis2nd_sbp_indexes,\n    const std::vector<int64_t>& old_axis2new_axis_offset) {\n  CHECK_EQ_OR_RETURN(shape_axis2nd_sbp_indexes.size(), old_axis2new_axis_offset.size());\n  NdSbp new_nd_sbp(*nd_sbp);\n  for (int axis = 0; axis < shape_axis2nd_sbp_indexes.size(); ++axis) {\n    int64_t offset = old_axis2new_axis_offset.at(axis);\n    for (int64_t j = 0; j < shape_axis2nd_sbp_indexes.at(axis).size(); ++j) {\n      int64_t nd_sbp_index = shape_axis2nd_sbp_indexes.at(axis).at(j);\n      CHECK_GE_OR_RETURN(nd_sbp_index, 0);\n      CHECK_LT_OR_RETURN(nd_sbp_index, 
new_nd_sbp.sbp_parallel_size());\n      auto* sbp_parallel = new_nd_sbp.mutable_sbp_parallel(nd_sbp_index);\n      CHECK_OR_RETURN(sbp_parallel->has_split_parallel());\n      CHECK_EQ_OR_RETURN(sbp_parallel->split_parallel().axis(), axis);\n      sbp_parallel->mutable_split_parallel()->set_axis(offset + j);\n    }\n  }\n  return SymbolOf(new_nd_sbp);\n}\n\n}  // namespace\n\nMaybe<std::tuple<std::shared_ptr<const Shape>, Symbol<NdSbp>, Symbol<NdSbp>>>\nCalcDecomposableEquivalentShapeAndNdSbpPair(const Shape& shape, const Shape& hierarchy,\n                                            Symbol<NdSbp> src_nd_sbp, Symbol<NdSbp> dst_nd_sbp) {\n  CHECK_EQ_OR_RETURN(src_nd_sbp->sbp_parallel_size(), dst_nd_sbp->sbp_parallel_size());\n  std::vector<std::vector<int64_t>> shape_axis2src_nd_sbp_indexes(shape.NumAxes());\n  JUST(InitShapeAxis2NdSbpIndexes(src_nd_sbp, &shape_axis2src_nd_sbp_indexes));\n  std::vector<std::vector<int64_t>> shape_axis2dst_nd_sbp_indexes(shape.NumAxes());\n  JUST(InitShapeAxis2NdSbpIndexes(dst_nd_sbp, &shape_axis2dst_nd_sbp_indexes));\n  std::vector<DimVector> shape_axis2expanded_dims(shape.NumAxes());\n  CHECK_EQ_OR_RETURN(hierarchy.NumAxes(), src_nd_sbp->sbp_parallel_size());\n  JUST(CheckSplitAxisExpandable(hierarchy, shape_axis2src_nd_sbp_indexes,\n                                shape_axis2dst_nd_sbp_indexes));\n  JUST(InitShapAxis2ExpandedDim(&shape_axis2expanded_dims, shape, hierarchy,\n                                shape_axis2src_nd_sbp_indexes, shape_axis2dst_nd_sbp_indexes));\n  std::shared_ptr<const Shape> new_shape = JUST(Flatten(shape_axis2expanded_dims));\n  CHECK_EQ_OR_RETURN(new_shape->elem_cnt(), shape.elem_cnt());\n  std::vector<int64_t> old_axis2new_axis_offset(shape.NumAxes());\n  JUST(InitOldAxis2NewAxisOffset(&old_axis2new_axis_offset, shape_axis2expanded_dims));\n  Symbol<NdSbp> new_src_nd_sbp =\n      JUST(ShiftSplitAxis(src_nd_sbp, shape_axis2src_nd_sbp_indexes, old_axis2new_axis_offset));\n  Symbol<NdSbp> new_dst_nd_sbp =\n      JUST(ShiftSplitAxis(dst_nd_sbp, shape_axis2dst_nd_sbp_indexes, old_axis2new_axis_offset));\n  return std::make_tuple(new_shape, new_src_nd_sbp, new_dst_nd_sbp);\n}\n\nnamespace {\n\n// nd_sbp is called decomposable if no particular axis is used to split the tensor more than once.\n// e.g.\n// 1) (S0, S1) is decomposable.\n// 2) (S0, S0) is not decomposable.\n// 3) (S1, S1) is not decomposable.\n// although `nd_sbp (S0, S0) on shape (4, 4)` is not decomposable, it can be transformed into the\n// decomposable form `nd_sbp (S0, S1) on shape (2, 2, 4)`.\nMaybe<std::pair<Symbol<one::GlobalTensorMeta>, Symbol<NdSbp>>> CalcDecomposableEquivalent(\n    Symbol<one::GlobalTensorMeta> tensor_meta, Symbol<NdSbp> dst_nd_sbp) {\n  std::shared_ptr<const Shape> shape = tensor_meta->shape_ptr();\n  Symbol<NdSbp> src_nd_sbp = tensor_meta->nd_sbp();\n  const auto& hierarchy = tensor_meta->parallel_desc()->hierarchy();\n  std::tie(shape, src_nd_sbp, dst_nd_sbp) = *JUST(\n      CalcDecomposableEquivalentShapeAndNdSbpPair(*shape, *hierarchy, src_nd_sbp, dst_nd_sbp));\n\n  one::GlobalTensorMeta decomposable_tensor_meta(*shape, tensor_meta->dtype(),\n                                                 tensor_meta->memory_format(), src_nd_sbp,\n                                                 tensor_meta->parallel_desc());\n  return std::make_pair(SymbolOf(decomposable_tensor_meta), dst_nd_sbp);\n}\n\nstatic constexpr auto* GetDecomposableEquivalent =\n    DECORATE(&CalcDecomposableEquivalent, ThreadLocal);\n\nMaybe<void> 
InitDstNdSbpAxis2ExclusiveSrcNdSbpAxis(\n    HashMap<int64_t, int64_t>* dst_nd_sbp_axis2exclusive_src_nd_sbp_axis, Symbol<NdSbp> src_nd_sbp,\n    Symbol<NdSbp> dst_nd_sbp) {\n  HashMap<int64_t, int64_t> split_axis2src_nd_sbp_axis;\n  for (int i = 0; i < src_nd_sbp->sbp_parallel_size(); ++i) {\n    const auto& sbp_parallel = src_nd_sbp->sbp_parallel(i);\n    if (sbp_parallel.has_split_parallel()) {\n      split_axis2src_nd_sbp_axis[sbp_parallel.split_parallel().axis()] = i;\n    }\n  }\n  for (int i = 0; i < dst_nd_sbp->sbp_parallel_size(); ++i) {\n    const auto& sbp_parallel = dst_nd_sbp->sbp_parallel(i);\n    if (sbp_parallel.has_split_parallel()) {\n      int64_t axis = sbp_parallel.split_parallel().axis();\n      const auto& iter = split_axis2src_nd_sbp_axis.find(axis);\n      if (iter != split_axis2src_nd_sbp_axis.end() && iter->second != i) {\n        (*dst_nd_sbp_axis2exclusive_src_nd_sbp_axis)[i] = iter->second;\n      }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MakeExclusiveSrcNdSbpAxis4DstNdSbpAxis(\n    std::function<Maybe<Optional<int64_t>>(int64_t)>* ExclusiveSrcNdSbpAxis4DstNdSbpAxis,\n    Symbol<NdSbp> src_nd_sbp, Symbol<NdSbp> dst_nd_sbp) {\n  CHECK_EQ_OR_RETURN(src_nd_sbp->sbp_parallel_size(), dst_nd_sbp->sbp_parallel_size());\n  HashMap<int64_t, int64_t> split_axis2src_nd_sbp_axis;\n  for (int i = 0; i < src_nd_sbp->sbp_parallel_size(); ++i) {\n    const auto& sbp_parallel = src_nd_sbp->sbp_parallel(i);\n    if (sbp_parallel.has_split_parallel()) {\n      int64_t split_axis = sbp_parallel.split_parallel().axis();\n      CHECK_OR_RETURN(split_axis2src_nd_sbp_axis.emplace(split_axis, i).second);\n    }\n  }\n  {\n    // check split_axis used only once.\n    HashMap<int64_t, int64_t> split_axis2dst_nd_sbp_axis;\n    for (int i = 0; i < dst_nd_sbp->sbp_parallel_size(); ++i) {\n      const auto& sbp_parallel = dst_nd_sbp->sbp_parallel(i);\n      if (sbp_parallel.has_split_parallel()) {\n        int64_t split_axis = sbp_parallel.split_parallel().axis();\n        CHECK_OR_RETURN(split_axis2dst_nd_sbp_axis.emplace(split_axis, i).second);\n      }\n    }\n  }\n  *ExclusiveSrcNdSbpAxis4DstNdSbpAxis = [split_axis2src_nd_sbp_axis, src_nd_sbp,\n                                         dst_nd_sbp](int64_t dst_axis) -> Maybe<Optional<int64_t>> {\n    CHECK_GE_OR_RETURN(dst_axis, 0);\n    CHECK_LT_OR_RETURN(dst_axis, dst_nd_sbp->sbp_parallel_size());\n    const auto& dst_sbp_parallel = dst_nd_sbp->sbp_parallel(dst_axis);\n    if (!dst_sbp_parallel.has_split_parallel()) { return Optional<int64_t>(); }\n    int64_t split_axis = dst_sbp_parallel.split_parallel().axis();\n    const auto& src_iter = split_axis2src_nd_sbp_axis.find(split_axis);\n    if (src_iter == split_axis2src_nd_sbp_axis.end()) { return Optional<int64_t>(); }\n    int64_t src_axis = src_iter->second;\n    CHECK_GE_OR_RETURN(src_axis, 0);\n    CHECK_LT_OR_RETURN(src_axis, dst_nd_sbp->sbp_parallel_size());\n    const auto& src_sbp_parallel = src_nd_sbp->sbp_parallel(src_axis);\n    CHECK_OR_RETURN(src_sbp_parallel.has_split_parallel());\n    CHECK_EQ_OR_RETURN(src_sbp_parallel.split_parallel().axis(), split_axis);\n    if (src_axis == dst_axis) { return Optional<int64_t>(); }\n    return Optional<int64_t>(src_axis);\n  };\n  return Maybe<void>::Ok();\n}\n\nMaybe<bool> IsNdSbpBoxingAcyclic(\n    int64_t num_axes,\n    const std::function<Maybe<Optional<int64_t>>(int64_t)>& ExclusiveSrcNdSbpAxis4DstNdSbpAxis) {\n  for (int start_axis = 0; start_axis < num_axes; ++start_axis) {\n    int64_t axis = start_axis;\n 
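   // Follow the chain dst-axis -> exclusive src-axis starting from start_axis; if the
    // walk revisits an axis (e.g. src (S0, S1) vs dst (S1, S0), where axes 0 and 1
    // depend on each other), the per-axis transformations cannot be ordered: return false.
 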
   HashSet<int64_t> visited_axes;\n    for (int i = 0; i < num_axes + 1; ++i) {\n      const auto& opt_axis = JUST(ExclusiveSrcNdSbpAxis4DstNdSbpAxis(axis));\n      if (!opt_axis->has_value()) { break; }\n      axis = JUST(*opt_axis);\n      if (!visited_axes.insert(axis).second) { return false; }\n    }\n  }\n  return true;\n}\n\nMaybe<void> InitNdSbpValidTransformationAxisSequence(\n    std::vector<int64_t>* nd_sbp_axis_sequence, Symbol<NdSbp> src_nd_sbp, Symbol<NdSbp> dst_nd_sbp,\n    const std::function<Maybe<Optional<int64_t>>(int64_t)>& ExclusiveSrcNdSbpAxis4DstNdSbpAxis) {\n  CHECK_EQ_OR_RETURN(src_nd_sbp->sbp_parallel_size(), dst_nd_sbp->sbp_parallel_size());\n  int64_t num_axes = src_nd_sbp->sbp_parallel_size();\n  HashSet<int64_t> handled_axes;\n  nd_sbp_axis_sequence->reserve(num_axes);\n  const auto& HasNoExclusiveSrcNdSbpAxis = [&](int64_t axis) -> Maybe<bool> {\n    const auto& opt_src_axis = JUST(ExclusiveSrcNdSbpAxis4DstNdSbpAxis(axis));\n    if (!opt_src_axis->has_value()) { return true; }\n    return handled_axes.count(JUST(*opt_src_axis)) > 0;\n  };\n  for (int i = 0; i < num_axes; ++i) {\n    for (int axis = 0; axis < num_axes; ++axis) {\n      if (handled_axes.count(axis) == 0 && JUST(HasNoExclusiveSrcNdSbpAxis(axis))) {\n        if (!(src_nd_sbp->sbp_parallel(axis) == dst_nd_sbp->sbp_parallel(axis))) {\n          nd_sbp_axis_sequence->emplace_back(axis);\n        }\n        handled_axes.insert(axis);\n      }\n    }\n  }\n  CHECK_EQ_OR_RETURN(handled_axes.size(), num_axes);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<bool> IsNdSbpBoxingAcyclic(Symbol<NdSbp> src_nd_sbp, Symbol<NdSbp> dst_nd_sbp) {\n  std::function<Maybe<Optional<int64_t>>(int64_t)> ExclusiveSrcNdSbpAxis4DstNdSbpAxis;\n  JUST(MakeExclusiveSrcNdSbpAxis4DstNdSbpAxis(&ExclusiveSrcNdSbpAxis4DstNdSbpAxis, src_nd_sbp,\n                                              dst_nd_sbp));\n  return IsNdSbpBoxingAcyclic(src_nd_sbp->sbp_parallel_size(), ExclusiveSrcNdSbpAxis4DstNdSbpAxis);\n}\n\nMaybe<std::vector<int64_t>> GetNdSbpValidTransformationAxisSequence(Symbol<NdSbp> src_nd_sbp,\n                                                                    Symbol<NdSbp> dst_nd_sbp) {\n  std::function<Maybe<Optional<int64_t>>(int64_t)> ExclusiveSrcNdSbpAxis4DstNdSbpAxis;\n  JUST(MakeExclusiveSrcNdSbpAxis4DstNdSbpAxis(&ExclusiveSrcNdSbpAxis4DstNdSbpAxis, src_nd_sbp,\n                                              dst_nd_sbp));\n  bool is_acyclic = JUST(\n      IsNdSbpBoxingAcyclic(src_nd_sbp->sbp_parallel_size(), ExclusiveSrcNdSbpAxis4DstNdSbpAxis));\n  CHECK_OR_RETURN(is_acyclic) << Error::UnimplementedError()\n                              << \"cyclic split axis boxing is not supported\";\n  std::vector<int64_t> nd_sbp_axis_sequence;\n  JUST(InitNdSbpValidTransformationAxisSequence(&nd_sbp_axis_sequence, src_nd_sbp, dst_nd_sbp,\n                                                ExclusiveSrcNdSbpAxis4DstNdSbpAxis));\n  return nd_sbp_axis_sequence;\n}\n\nstd::string GetCyclicBoxingDebugString(\n    Symbol<NdSbp> src_nd_sbp, Symbol<NdSbp> dst_nd_sbp,\n    const std::function<Maybe<Optional<int64_t>>(int64_t)>& ExclusiveSrcNdSbpAxis4DstNdSbpAxis) {\n  CHECK_EQ(src_nd_sbp->sbp_parallel_size(), dst_nd_sbp->sbp_parallel_size());\n  std::stringstream ss;\n  ss << \"cyclic split axis boxing is not supported. \"\n     << \"src_nd_sbp: \" << NdSbpToString(src_nd_sbp) << \", dst_nd_sbp: \" << NdSbpToString(dst_nd_sbp)\n     << \". 
\"\n     << \"dst_nd_sbp axis to exclusive src_nd_sbp axis: \";\n  ss << \"[\";\n  for (int i = 0; i < src_nd_sbp->sbp_parallel_size(); ++i) {\n    const auto& opt_axis = CHECK_JUST(ExclusiveSrcNdSbpAxis4DstNdSbpAxis(i));\n    if (i) { ss << \", \"; }\n    if (opt_axis->has_value()) {\n      ss << CHECK_JUST(*opt_axis);\n    } else {\n      ss << \"None\";\n    }\n  }\n  ss << \"]\";\n  return ss.str();\n}\n\nMaybe<Shape> GetPhysicalShape(const Shape& shape, Symbol<NdSbp> nd_sbp,\n                              Symbol<ParallelDesc> parallel_desc) {\n  const auto& parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc));\n  return GetPhysicalShape(shape, *nd_sbp, *parallel_desc, JUST(*parallel_id));\n}\n\nMaybe<Shape> GetSubLogicalShape(Symbol<one::GlobalTensorMeta> tensor_meta,\n                                Symbol<ParallelDesc> sub_parallel_desc, Symbol<NdSbp> sub_nd_sbp) {\n  CHECK_EQ_OR_RETURN(sub_nd_sbp->sbp_parallel_size(), 1);  // NOLINT(maybe-need-error-msg)\n  const auto& logical_shape = tensor_meta->shape();\n  const auto& physical_shape =\n      JUST(GetPhysicalShape(logical_shape, tensor_meta->nd_sbp(), tensor_meta->parallel_desc()));\n\n  std::shared_ptr<Shape> sub_logical_shape = std::make_shared<Shape>(*physical_shape);\n  if (sub_nd_sbp->sbp_parallel(0).has_split_parallel()) {\n    const int64_t split_axis = sub_nd_sbp->sbp_parallel(0).split_parallel().axis();\n    sub_logical_shape->Set(split_axis, logical_shape.At(split_axis));\n  }\n  return sub_logical_shape;\n}\n\nMaybe<Symbol<one::GlobalTensorMeta>> CalcSubGlobalTensorMeta(\n    Symbol<one::GlobalTensorMeta> tensor_meta, Symbol<ParallelDesc> sub_parallel_desc,\n    Symbol<NdSbp> sub_nd_sbp) {\n  CHECK_EQ_OR_RETURN(sub_nd_sbp->sbp_parallel_size(), 1);  // NOLINT(maybe-need-error-msg)\n  const auto& logical_shape = JUST(GetSubLogicalShape(tensor_meta, sub_parallel_desc, sub_nd_sbp));\n  one::GlobalTensorMeta sub_global_tensor_meta(*logical_shape, tensor_meta->dtype(),\n                                               tensor_meta->memory_format(), sub_nd_sbp,\n                                               sub_parallel_desc);\n  return SymbolOf(sub_global_tensor_meta);\n}\n\nstatic constexpr auto* GetSubGlobalTensorMeta = DECORATE(&CalcSubGlobalTensorMeta, ThreadLocal);\n\nMaybe<Symbol<NdSbp>> ReplaceNdSbpComponent(Symbol<NdSbp> nd_sbp, int64_t axis,\n                                           Symbol<NdSbp> component) {\n  CHECK_GE_OR_RETURN(axis, 0);\n  CHECK_LT_OR_RETURN(axis, nd_sbp->sbp_parallel_size());\n  CHECK_EQ_OR_RETURN(component->sbp_parallel_size(), 1);\n  NdSbp new_nd_sbp(*nd_sbp);\n  *new_nd_sbp.mutable_sbp_parallel(axis) = component->sbp_parallel(0);\n  return SymbolOf(new_nd_sbp);\n}\n\nMaybe<Symbol<one::GlobalTensorMeta>> ReplaceNdSbp(Symbol<one::GlobalTensorMeta> tensor_meta,\n                                                  Symbol<NdSbp> nd_sbp) {\n  one::GlobalTensorMeta new_tensor_meta(tensor_meta->shape(), tensor_meta->dtype(),\n                                        tensor_meta->memory_format(), nd_sbp,\n                                        tensor_meta->parallel_desc());\n  return SymbolOf(new_tensor_meta);\n}\n\nMaybe<std::vector<NaiveBoxingTransformation>> DecomposeIntoNaiveTransformations(\n    Symbol<one::GlobalTensorMeta> tensor_meta, Symbol<NdSbp> dst_nd_sbp) {\n  std::tie(tensor_meta, dst_nd_sbp) = *JUST(GetDecomposableEquivalent(tensor_meta, dst_nd_sbp));\n  const auto& parallel_desc = tensor_meta->parallel_desc();\n  const auto& src_nd_sbp = tensor_meta->nd_sbp();\n  
CHECK_EQ_OR_RETURN(src_nd_sbp->sbp_parallel_size(), dst_nd_sbp->sbp_parallel_size());\n  std::vector<int64_t> nd_sbp_axis_sequence;\n  {\n    std::function<Maybe<Optional<int64_t>>(int64_t)> ExclusiveSrcNdSbpAxis4DstNdSbpAxis;\n    JUST(MakeExclusiveSrcNdSbpAxis4DstNdSbpAxis(&ExclusiveSrcNdSbpAxis4DstNdSbpAxis, src_nd_sbp,\n                                                dst_nd_sbp));\n    bool is_acyclic = JUST(\n        IsNdSbpBoxingAcyclic(src_nd_sbp->sbp_parallel_size(), ExclusiveSrcNdSbpAxis4DstNdSbpAxis));\n    CHECK_OR_RETURN(is_acyclic) << Error::UnimplementedError()\n                                << GetCyclicBoxingDebugString(src_nd_sbp, dst_nd_sbp,\n                                                              ExclusiveSrcNdSbpAxis4DstNdSbpAxis);\n    JUST(InitNdSbpValidTransformationAxisSequence(&nd_sbp_axis_sequence, src_nd_sbp, dst_nd_sbp,\n                                                  ExclusiveSrcNdSbpAxis4DstNdSbpAxis));\n  }\n  const auto& transformations = std::make_shared<std::vector<NaiveBoxingTransformation>>();\n  for (int axis : nd_sbp_axis_sequence) {\n    const auto& src_sbp = src_nd_sbp->sbp_parallel(axis);\n    const auto& dst_sbp = dst_nd_sbp->sbp_parallel(axis);\n    if (src_sbp == dst_sbp) { continue; }\n    std::vector<int> axis2selected(src_nd_sbp->sbp_parallel_size());\n    axis2selected[axis] = 1;\n    const auto& sub_parallel_desc =\n        JUST(GetSelectedSubParallelDesc(parallel_desc, SymbolOf(axis2selected)));\n    const auto& sub_src_nd_sbp = JUST(MakeNdSbp(src_sbp));\n    const auto& sub_dst_nd_sbp = JUST(MakeNdSbp(dst_sbp));\n    const auto& sub_global_tensor_meta =\n        JUST(GetSubGlobalTensorMeta(tensor_meta, sub_parallel_desc, sub_src_nd_sbp));\n    const auto& new_src_nd_sbp =\n        JUST(ReplaceNdSbpComponent(tensor_meta->nd_sbp(), axis, sub_dst_nd_sbp));\n    tensor_meta = JUST(ReplaceNdSbp(tensor_meta, new_src_nd_sbp));\n    transformations->emplace_back(NaiveBoxingTransformation{\n        .global_tensor_meta = sub_global_tensor_meta,\n        .dst_nd_sbp = sub_dst_nd_sbp,\n    });\n  }\n  return transformations;\n}\n\n}  // namespace private_details\n\nnamespace {\n\nMaybe<std::unordered_map<int64_t, Symbol<ParallelDesc>>> CalcBroadcastGroup(\n    Symbol<ParallelDesc> src_parallel_desc, Symbol<ParallelDesc> dst_parallel_desc,\n    bool allow_across_node) {\n  CHECK_EQ_OR_RETURN(src_parallel_desc->parallel_num(),\n                     src_parallel_desc->sorted_machine_ids().size());\n  CHECK_EQ_OR_RETURN(dst_parallel_desc->parallel_num(),\n                     dst_parallel_desc->sorted_machine_ids().size());\n  CHECK_EQ_OR_RETURN(src_parallel_desc->device_type(), dst_parallel_desc->device_type());\n  CHECK_LE_OR_RETURN(src_parallel_desc->parallel_num(), dst_parallel_desc->parallel_num());\n  const auto& src_process_ids = src_parallel_desc->sorted_machine_ids();\n  HashMap<int64_t, std::vector<int64_t>> process_id2group{};\n  HashMap<int64_t, std::vector<int64_t>> node_id2src_process_id{};\n  for (int64_t process_id : src_process_ids) {\n    std::vector<int64_t> vec{process_id};\n    CHECK_OR_RETURN(process_id2group.emplace(process_id, vec).second);\n    CHECK_OR_RETURN(dst_parallel_desc->ContainingMachineId(process_id));\n    node_id2src_process_id[GlobalProcessCtx::NodeId(process_id)].emplace_back(process_id);\n  }\n  std::vector<int64_t> remainder_process_ids{};\n  remainder_process_ids.reserve(dst_parallel_desc->sorted_machine_ids().size());\n  HashMap<int64_t, int64_t> node_id2counter{};\n  for (int64_t process_id : 
dst_parallel_desc->sorted_machine_ids()) {\n    if (!src_parallel_desc->ContainingMachineId(process_id)) {\n      const auto& node_iter = node_id2src_process_id.find(GlobalProcessCtx::NodeId(process_id));\n      if (node_iter == node_id2src_process_id.end()) {\n        CHECK_OR_RETURN(allow_across_node)\n            << Error::UnimplementedError() << \"\\n----[src_placement]----\\n\"\n            << src_parallel_desc->parallel_conf().DebugString() << \"\\n----[dst_placement]----\\n\"\n            << dst_parallel_desc->parallel_conf().DebugString();\n        // handle `process_id` later.\n        remainder_process_ids.emplace_back(process_id);\n      } else {\n        // Put `process_id` into one of the groups on the same node, balancing group sizes.\n        int64_t node_id = node_iter->first;\n        const auto& src_process_ids = node_iter->second;\n        int64_t src_process_index = (node_id2counter[node_id]++) % src_process_ids.size();\n        int64_t src_process_id = src_process_ids.at(src_process_index);\n        JUST(MapAt(process_id2group, src_process_id)).emplace_back(process_id);\n      }\n    }\n  }\n  // put remainder process ids into src groups.\n  for (int i = 0; i < remainder_process_ids.size(); ++i) {\n    int64_t src_process_id = src_process_ids.at(i % src_process_ids.size());\n    JUST(MapAt(process_id2group, src_process_id))\n        .emplace_back(JUST(oneflow::VectorAt(remainder_process_ids, i)));\n  }\n  const auto& map = std::make_shared<std::unordered_map<int64_t, Symbol<ParallelDesc>>>();\n  for (const auto& pair : process_id2group) {\n    const auto& group = pair.second;\n    ParallelConf parallel_conf;\n    parallel_conf.set_device_tag(dst_parallel_desc->parallel_conf().device_tag());\n    for (int64_t process_id : group) {\n      const auto& device_ids = dst_parallel_desc->sorted_dev_phy_ids(process_id);\n      CHECK_EQ_OR_RETURN(device_ids.size(), 1);\n      parallel_conf.add_device_name(std::string(\"@\") + std::to_string(process_id) + \":\"\n                                    + std::to_string(device_ids.at(0)));\n    }\n    const auto& parallel_desc = SymbolOf(ParallelDesc(parallel_conf));\n    for (int64_t process_id : group) {\n      CHECK_OR_RETURN(map->emplace(process_id, parallel_desc).second);\n    }\n  }\n  return map;\n}\nauto* CachedBroadcastGroup = DECORATE(&CalcBroadcastGroup, ThreadLocal);\n\nMaybe<void> RawCheckIsNdSbpBoxingAcyclic(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {\n  using namespace private_details;\n  const auto& src_nd_sbp = in->nd_sbp();\n  const auto& dst_nd_sbp = out->nd_sbp();\n  std::function<Maybe<Optional<int64_t>>(int64_t)> ExclusiveSrcNdSbpAxis4DstNdSbpAxis;\n  JUST(MakeExclusiveSrcNdSbpAxis4DstNdSbpAxis(&ExclusiveSrcNdSbpAxis4DstNdSbpAxis, src_nd_sbp,\n                                              dst_nd_sbp));\n  bool is_acyclic = JUST(\n      IsNdSbpBoxingAcyclic(src_nd_sbp->sbp_parallel_size(), ExclusiveSrcNdSbpAxis4DstNdSbpAxis));\n  CHECK_OR_RETURN(is_acyclic) << Error::UnimplementedError()\n                              << GetCyclicBoxingDebugString(src_nd_sbp, dst_nd_sbp,\n                                                            ExclusiveSrcNdSbpAxis4DstNdSbpAxis);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RawCheckIsNdSbpBoxingAcyclicWithDecompose(Symbol<PlacedNdSbp> in,\n                                                      Symbol<PlacedNdSbp> out,\n                                                      const Shape& logical_shape) {\n  using namespace private_details;\n  Symbol<NdSbp> src_nd_sbp = in->nd_sbp();\n  
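// Unlike RawCheckIsNdSbpBoxingAcyclic above, first rewrite the shape/nd-sbp pair into
  // its decomposable equivalent so that a split axis reused across hierarchy dimensions
  // (e.g. (S0, S0)) maps onto distinct axes before running the cycle check.
  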
Symbol<NdSbp> dst_nd_sbp = out->nd_sbp();\n  const auto& hierarchy = in->placement()->hierarchy();\n  std::shared_ptr<const Shape> shape;\n\n  std::tie(shape, src_nd_sbp, dst_nd_sbp) = *JUST(CalcDecomposableEquivalentShapeAndNdSbpPair(\n      logical_shape, *hierarchy, src_nd_sbp, dst_nd_sbp));\n\n  std::function<Maybe<Optional<int64_t>>(int64_t)> ExclusiveSrcNdSbpAxis4DstNdSbpAxis;\n  JUST(MakeExclusiveSrcNdSbpAxis4DstNdSbpAxis(&ExclusiveSrcNdSbpAxis4DstNdSbpAxis, src_nd_sbp,\n                                              dst_nd_sbp));\n  bool is_acyclic = JUST(\n      IsNdSbpBoxingAcyclic(src_nd_sbp->sbp_parallel_size(), ExclusiveSrcNdSbpAxis4DstNdSbpAxis));\n  CHECK_OR_RETURN(is_acyclic) << Error::UnimplementedError()\n                              << GetCyclicBoxingDebugString(src_nd_sbp, dst_nd_sbp,\n                                                            ExclusiveSrcNdSbpAxis4DstNdSbpAxis);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nint64_t CalcIndex4Axis(int64_t offset, const Stride& stride, int axis) {\n  CHECK_LT(axis, stride.size()) << \"Expected axis (\" << axis << \") to be less than size of stride (\"\n                                << stride.size() << \")\";\n  if (axis == 0) {\n    return offset / stride.at(0);\n  } else {\n    return offset % stride.at(axis - 1) / stride.at(axis);\n  }\n}\n\ndecltype(CheckIsNdSbpBoxingAcyclic) CheckIsNdSbpBoxingAcyclic =\n    DECORATE(&RawCheckIsNdSbpBoxingAcyclic, ThreadLocal);\n\ndecltype(CheckIsNdSbpBoxingAcyclicWithDecompose) CheckIsNdSbpBoxingAcyclicWithDecompose =\n    DECORATE(&RawCheckIsNdSbpBoxingAcyclicWithDecompose, ThreadLocalCopiable);\n\nMaybe<std::unordered_map<int64_t, Symbol<ParallelDesc>>> GetBroadcastGroup(\n    Symbol<ParallelDesc> src_parallel_desc, Symbol<ParallelDesc> dst_parallel_desc) {\n  return CachedBroadcastGroup(src_parallel_desc, dst_parallel_desc, true);\n}\n\nMaybe<std::unordered_map<int64_t, Symbol<ParallelDesc>>> GetBroadcastGroupWithoutAcrossNode(\n    Symbol<ParallelDesc> src_parallel_desc, Symbol<ParallelDesc> dst_parallel_desc) {\n  return CachedBroadcastGroup(src_parallel_desc, dst_parallel_desc, false);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/placement_sbp_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_PLACEMENT_SBP_UTIL_H_\n#define ONEFLOW_CORE_FRAMEWORK_PLACEMENT_SBP_UTIL_H_\n\n#include <unordered_map>\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/common/stride.h\"\n\nnamespace oneflow {\n\nclass Shape;\nclass Stride;\nclass ParallelDesc;\nclass PlacedNdSbp;\n\nnamespace one {\n\nclass GlobalTensorMeta;\n\n}\n\n// 1) src_nd_sbp.sbp_parallel_size() == 1\n// 2) dst_nd_sbp.sbp_parallel_size() == 1\nstruct NaiveBoxingTransformation {\n  Symbol<one::GlobalTensorMeta> global_tensor_meta;\n  Symbol<NdSbp> dst_nd_sbp;\n};\n\nnamespace private_details {\n\nMaybe<std::vector<int64_t>> GetSelectedParallelIds(const Shape& hierarchy_shape,\n                                                   const std::vector<int>& axis2is_selected,\n                                                   int64_t parallel_id);\n\nMaybe<std::tuple<std::shared_ptr<const Shape>, Symbol<NdSbp>, Symbol<NdSbp>>>\nCalcDecomposableEquivalentShapeAndNdSbpPair(const Shape& shape, const Shape& hierarchy,\n                                            Symbol<NdSbp> src_nd_sbp, Symbol<NdSbp> dst_nd_sbp);\n\nMaybe<Symbol<ParallelDesc>> GetBroadcastSubParallelDesc(Symbol<ParallelDesc> parallel_desc,\n                                                        Symbol<NdSbp> nd_sbp);\n\nMaybe<std::vector<NaiveBoxingTransformation>> DecomposeIntoNaiveTransformations(\n    Symbol<one::GlobalTensorMeta> tensor_meta, Symbol<NdSbp> dst_nd_sbp);\n\nMaybe<bool> IsNdSbpBoxingAcyclic(Symbol<NdSbp> src_nd_sbp, Symbol<NdSbp> dst_nd_sbp);\n\nMaybe<std::vector<int64_t>> GetNdSbpValidTransformationAxisSequence(Symbol<NdSbp> src_nd_sbp,\n                                                                    Symbol<NdSbp> dst_nd_sbp);\n\nMaybe<Symbol<one::GlobalTensorMeta>> CalcSubGlobalTensorMeta(\n    Symbol<one::GlobalTensorMeta> tensor_meta, Symbol<ParallelDesc> sub_parallel_desc,\n    Symbol<NdSbp> sub_nd_sbp);\n\nMaybe<Symbol<ParallelDesc>> CalcSubParallelDesc4Axis(Symbol<ParallelDesc> parallel_desc, int axis);\n\n}  // namespace private_details\n\nextern Maybe<void> (*CheckIsNdSbpBoxingAcyclic)(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out);\n\nextern Maybe<void> (*CheckIsNdSbpBoxingAcyclicWithDecompose)(Symbol<PlacedNdSbp> in,\n                                                             Symbol<PlacedNdSbp> out,\n                                                             const Shape& logical_shape);\n\nint64_t CalcIndex4Axis(int64_t offset, const Stride& stride, int axis);\n\nstatic constexpr auto* GetSubGlobalTensorMeta =\n    DECORATE(&private_details::CalcSubGlobalTensorMeta, ThreadLocal);\n\nstatic constexpr auto* GetBroadcastSubParallelDesc =\n    DECORATE(&private_details::GetBroadcastSubParallelDesc, ThreadLocal);\n\nstatic constexpr auto* 
DecomposeIntoNaiveTransformations =\n    DECORATE(&private_details::DecomposeIntoNaiveTransformations, ThreadLocal);\n\nstatic constexpr auto* CalcSubParallelDesc4Axis =\n    DECORATE(&private_details::CalcSubParallelDesc4Axis, ThreadLocal);\n\nMaybe<std::unordered_map<int64_t, Symbol<ParallelDesc>>> GetBroadcastGroup(\n    Symbol<ParallelDesc> src_parallel_desc, Symbol<ParallelDesc> dst_parallel_desc);\n\nMaybe<std::unordered_map<int64_t, Symbol<ParallelDesc>>> GetBroadcastGroupWithoutAcrossNode(\n    Symbol<ParallelDesc> src_parallel_desc, Symbol<ParallelDesc> dst_parallel_desc);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_PLACEMENT_SBP_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/framework/placement_sbp_util_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/framework/placement_sbp_util.h\"\n#include \"oneflow/core/common/tensor_meta.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/control/ctrl_bootstrap.pb.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n\nnamespace oneflow {\nnamespace test {\n\nnamespace {\n\nstruct GlobaProcessCtxScope final {\n  GlobaProcessCtxScope(GlobaProcessCtxScope&) = default;\n  GlobaProcessCtxScope(GlobaProcessCtxScope&&) = default;\n  GlobaProcessCtxScope& operator=(GlobaProcessCtxScope&) = default;\n  GlobaProcessCtxScope& operator=(GlobaProcessCtxScope&&) = default;\n  GlobaProcessCtxScope(int64_t node_size, int64_t world_size) {\n    Singleton<ProcessCtx>::New();\n    auto* ctx = Singleton<ProcessCtx>::Get();\n    for (int i = 0; i < world_size; ++i) { ctx->mutable_ctrl_addr()->Add(); }\n    ctx->set_rank(0);\n    ctx->set_node_size(node_size);\n  }\n  ~GlobaProcessCtxScope() { Singleton<ProcessCtx>::Delete(); }\n};\n\n}  // namespace\n\nTEST(GetSelectedParallelIds, 1d_broadcast) {\n  int64_t parallel_size = 4;\n  Shape hierarchy_shape(DimVector{parallel_size});\n  std::vector<int> axis2is_selected{true};\n  const auto& expected = std::vector<int64_t>{0, 1, 2, 3};\n  for (int i = 0; i < parallel_size; ++i) {\n    const auto& broadcast_parallel_ids =\n        CHECK_JUST(private_details::GetSelectedParallelIds(hierarchy_shape, axis2is_selected, i));\n    ASSERT_TRUE(*broadcast_parallel_ids == expected);\n  }\n}\n\nTEST(GetSelectedParallelIds, 1d_nonbroadcast) {\n  int64_t parallel_size = 4;\n  Shape hierarchy_shape(DimVector{parallel_size});\n  std::vector<int> axis2is_selected{false};\n  for (int i = 0; i < parallel_size; ++i) {\n    const auto& broadcast_parallel_ids =\n        CHECK_JUST(private_details::GetSelectedParallelIds(hierarchy_shape, axis2is_selected, i));\n    const auto& expected = std::vector<int64_t>{i};\n    ASSERT_TRUE(*broadcast_parallel_ids == expected);\n  }\n}\n\nTEST(GetSelectedParallelIds, 2d_broadcast_broadcast) {\n  int64_t parallel_size = 4;\n  Shape hierarchy_shape(DimVector{parallel_size, parallel_size});\n  std::vector<int> axis2is_selected{true, true};\n  std::vector<int64_t> expected{};\n  for (int i = 0; i < parallel_size * parallel_size; ++i) { expected.emplace_back(i); }\n  for (int i = 0; i < parallel_size * parallel_size; ++i) {\n    const auto& broadcast_parallel_ids =\n        CHECK_JUST(private_details::GetSelectedParallelIds(hierarchy_shape, axis2is_selected, i));\n    ASSERT_TRUE(*broadcast_parallel_ids == expected);\n  }\n}\n\nTEST(GetSelectedParallelIds, 2d_nonbroadcast_nonbroadcast) {\n  int64_t parallel_size = 4;\n  Shape hierarchy_shape(DimVector{parallel_size, parallel_size});\n  std::vector<int> axis2is_selected{false, false};\n  for (int i = 0; i < parallel_size * parallel_size; ++i) {\n    
const auto& broadcast_parallel_ids =\n        CHECK_JUST(private_details::GetSelectedParallelIds(hierarchy_shape, axis2is_selected, i));\n    const auto& expected = std::vector<int64_t>{i};\n    ASSERT_TRUE(*broadcast_parallel_ids == expected);\n  }\n}\n\nTEST(GetSelectedParallelIds, 2d_broadcast_nonbroadcast) {\n  int64_t parallel_size = 4;\n  Shape hierarchy_shape(DimVector{parallel_size, parallel_size});\n  std::vector<int> axis2is_selected{true, false};\n  for (int i = 0; i < parallel_size; ++i) {\n    for (int j = 0; j < parallel_size; ++j) {\n      std::vector<int64_t> expected{};\n      for (int k = 0; k < parallel_size; ++k) { expected.emplace_back(k * parallel_size + j); }\n      int64_t parallel_id = i * parallel_size + j;\n      const auto& broadcast_parallel_ids = CHECK_JUST(\n          private_details::GetSelectedParallelIds(hierarchy_shape, axis2is_selected, parallel_id));\n      ASSERT_TRUE(*broadcast_parallel_ids == expected);\n    }\n  }\n}\n\nTEST(GetSelectedParallelIds, 2d_nonbroadcast_broadcast) {\n  int64_t parallel_size = 4;\n  Shape hierarchy_shape(DimVector{parallel_size, parallel_size});\n  std::vector<int> axis2is_selected{false, true};\n  for (int i = 0; i < parallel_size; ++i) {\n    std::vector<int64_t> expected{};\n    for (int j = 0; j < parallel_size; ++j) { expected.emplace_back(i * parallel_size + j); }\n    for (int j = 0; j < parallel_size; ++j) {\n      int64_t parallel_id = i * parallel_size + j;\n      const auto& broadcast_parallel_ids = CHECK_JUST(\n          private_details::GetSelectedParallelIds(hierarchy_shape, axis2is_selected, parallel_id));\n      ASSERT_TRUE(*broadcast_parallel_ids == expected);\n    }\n  }\n}\n\nnamespace {\n\nvoid InitSbpParallel(SbpParallel* sbp_parallel, const std::string& sbp_tag) {\n  CHECK(sbp_tag.size() == 1 || sbp_tag.size() == 2);\n  if (sbp_tag[0] == 'S') {\n    CHECK_EQ(sbp_tag.size(), 2);\n    int64_t axis = sbp_tag[1] - '0';\n    sbp_parallel->mutable_split_parallel()->set_axis(axis);\n  } else if (sbp_tag == \"B\") {\n    sbp_parallel->mutable_broadcast_parallel();\n  } else if (sbp_tag == \"P\") {\n    sbp_parallel->mutable_partial_sum_parallel();\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename... Args>\nSymbol<NdSbp> GetNdSbp(Args... 
sbps) {\n  NdSbp nd_sbp;\n  for (const auto& sbp : std::vector<std::string>{sbps...}) {\n    InitSbpParallel(nd_sbp.mutable_sbp_parallel()->Add(), sbp);\n  }\n  return SymbolOf(nd_sbp);\n}\n\nSymbol<one::GlobalTensorMeta> MakeGlobalTensorMeta(Symbol<ParallelDesc> parallel_desc,\n                                                   Symbol<NdSbp> nd_sbp) {\n  auto shape = Shape(DimVector{256, 256});\n  one::GlobalTensorMeta tensor_meta(shape, DataType::kInt32, MemoryFormat::kContiguous, nd_sbp,\n                                    parallel_desc);\n  return SymbolOf(tensor_meta);\n}\n\n}  // namespace\n\nTEST(DecomposeIntoNaiveTransformations, decompose_axis0) {\n  GlobalProcessCtxScope scope(2, 8);\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  parallel_conf.add_device_name(\"0:0-3\");\n  parallel_conf.add_device_name(\"1:0-3\");\n  parallel_conf.mutable_hierarchy()->add_dim(2);\n  parallel_conf.mutable_hierarchy()->add_dim(4);\n  const auto& parallel_desc = SymbolOf(ParallelDesc(parallel_conf));\n  const auto& src_nd_sbp = GetNdSbp(\"P\", \"B\");\n  const auto& dst_nd_sbp = GetNdSbp(\"S0\", \"B\");\n  const auto& tensor_meta = MakeGlobalTensorMeta(parallel_desc, src_nd_sbp);\n  const auto& transformations =\n      CHECK_JUST(private_details::DecomposeIntoNaiveTransformations(tensor_meta, dst_nd_sbp));\n  ASSERT_EQ(transformations->size(), 1);\n  ParallelConf expected_parallel_conf;\n  expected_parallel_conf.set_device_tag(\"cpu\");\n  expected_parallel_conf.add_device_name(std::string(\"0:0\"));\n  expected_parallel_conf.add_device_name(std::string(\"1:0\"));\n  const auto& expected_parallel_desc = SymbolOf(ParallelDesc(expected_parallel_conf));\n  const auto& ctensor_meta = transformations->at(0).global_tensor_meta;\n  ASSERT_TRUE(ctensor_meta->parallel_desc() == expected_parallel_desc);\n  ASSERT_EQ(ctensor_meta->nd_sbp()->sbp_parallel_size(), 1);\n  ASSERT_EQ(transformations->at(0).dst_nd_sbp->sbp_parallel_size(), 1);\n  ASSERT_TRUE(ctensor_meta->nd_sbp()->sbp_parallel(0).has_partial_sum_parallel());\n  ASSERT_TRUE(transformations->at(0).dst_nd_sbp->sbp_parallel(0).has_split_parallel());\n  ASSERT_EQ(transformations->at(0).dst_nd_sbp->sbp_parallel(0).split_parallel().axis(), 0);\n}\n\nTEST(DecomposeIntoNaiveTransformations, decompose_axis1) {\n  GlobalProcessCtxScope scope(2, 8);\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  parallel_conf.add_device_name(\"0:0-3\");\n  parallel_conf.add_device_name(\"1:0-3\");\n  parallel_conf.mutable_hierarchy()->add_dim(2);\n  parallel_conf.mutable_hierarchy()->add_dim(4);\n  const auto& parallel_desc = SymbolOf(ParallelDesc(parallel_conf));\n  const auto& src_nd_sbp = GetNdSbp(\"S0\", \"P\");\n  const auto& dst_nd_sbp = GetNdSbp(\"S0\", \"S1\");\n  const auto& tensor_meta = MakeGlobalTensorMeta(parallel_desc, src_nd_sbp);\n  const auto& transformations =\n      CHECK_JUST(private_details::DecomposeIntoNaiveTransformations(tensor_meta, dst_nd_sbp));\n  ASSERT_EQ(transformations->size(), 1);\n  ParallelConf expected_parallel_conf;\n  expected_parallel_conf.set_device_tag(\"cpu\");\n  expected_parallel_conf.add_device_name(\"0:0-3\");\n  const auto& expected_parallel_desc = SymbolOf(ParallelDesc(expected_parallel_conf));\n  const auto& ctensor_meta = transformations->at(0).global_tensor_meta;\n  ASSERT_TRUE(ctensor_meta->parallel_desc() == expected_parallel_desc);\n  ASSERT_EQ(ctensor_meta->nd_sbp()->sbp_parallel_size(), 1);\n  ASSERT_EQ(transformations->at(0).dst_nd_sbp->sbp_parallel_size(), 1);\n  
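// Only hierarchy axis 1 changes ((S0, P) -> (S0, S1)), so the decomposed transformation is\n  // the 1-d boxing P -> S1 within each group of four devices.\n  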
ASSERT_TRUE(ctensor_meta->nd_sbp()->sbp_parallel(0).has_partial_sum_parallel());\n  ASSERT_TRUE(transformations->at(0).dst_nd_sbp->sbp_parallel(0).has_split_parallel());\n  ASSERT_EQ(transformations->at(0).dst_nd_sbp->sbp_parallel(0).split_parallel().axis(), 1);\n}\n\nTEST(DecomposeIntoNaiveTransformations, decompose_two_axes) {\n  GlobalProcessCtxScope scope(2, 8);\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  parallel_conf.add_device_name(\"0:0-1\");\n  parallel_conf.add_device_name(\"1:0-1\");\n  parallel_conf.mutable_hierarchy()->add_dim(2);\n  parallel_conf.mutable_hierarchy()->add_dim(2);\n  const auto& parallel_desc = SymbolOf(ParallelDesc(parallel_conf));\n  const auto& src_nd_sbp = GetNdSbp(\"S0\", \"P\");\n  const auto& dst_nd_sbp = GetNdSbp(\"B\", \"S0\");\n  const auto& tensor_meta = MakeGlobalTensorMeta(parallel_desc, src_nd_sbp);\n  const auto& transformations =\n      CHECK_JUST(private_details::DecomposeIntoNaiveTransformations(tensor_meta, dst_nd_sbp));\n  ASSERT_EQ(transformations->size(), 2);\n  {\n    ParallelConf expected_parallel_conf;\n    expected_parallel_conf.set_device_tag(\"cpu\");\n    expected_parallel_conf.add_device_name(std::string(\"0:0\"));\n    expected_parallel_conf.add_device_name(std::string(\"1:0\"));\n    const auto& expected_parallel_desc = SymbolOf(ParallelDesc(expected_parallel_conf));\n    const auto& ctensor_meta = transformations->at(0).global_tensor_meta;\n    ASSERT_TRUE(ctensor_meta->parallel_desc() == expected_parallel_desc);\n    ASSERT_EQ(ctensor_meta->nd_sbp()->sbp_parallel_size(), 1);\n    ASSERT_EQ(transformations->at(0).dst_nd_sbp->sbp_parallel_size(), 1);\n    ASSERT_TRUE(ctensor_meta->nd_sbp()->sbp_parallel(0).has_split_parallel());\n    ASSERT_TRUE(transformations->at(0).dst_nd_sbp->sbp_parallel(0).has_broadcast_parallel());\n    ASSERT_EQ(ctensor_meta->nd_sbp()->sbp_parallel(0).split_parallel().axis(), 0);\n  }\n  {\n    ParallelConf expected_parallel_conf;\n    expected_parallel_conf.set_device_tag(\"cpu\");\n    expected_parallel_conf.add_device_name(\"0:0-1\");\n    const auto& expected_parallel_desc = SymbolOf(ParallelDesc(expected_parallel_conf));\n    const auto& ctensor_meta = transformations->at(1).global_tensor_meta;\n    ASSERT_TRUE(ctensor_meta->parallel_desc() == expected_parallel_desc);\n    ASSERT_EQ(ctensor_meta->nd_sbp()->sbp_parallel_size(), 1);\n    ASSERT_EQ(transformations->at(1).dst_nd_sbp->sbp_parallel_size(), 1);\n    ASSERT_TRUE(ctensor_meta->nd_sbp()->sbp_parallel(0).has_partial_sum_parallel());\n    ASSERT_TRUE(transformations->at(1).dst_nd_sbp->sbp_parallel(0).has_split_parallel());\n    ASSERT_EQ(transformations->at(1).dst_nd_sbp->sbp_parallel(0).split_parallel().axis(), 0);\n  }\n}\n\nTEST(CalcDecomposableEquivalentShapeAndNdSbpPair, naive) {\n  Shape shape(DimVector{4, 4});\n  Shape hierarchy(DimVector{4, 4});\n  const auto& src_nd_sbp = GetNdSbp(\"S0\", \"S1\");\n  const auto& dst_nd_sbp = GetNdSbp(\"B\", \"P\");\n  const auto& maybe_tuple = TRY(private_details::CalcDecomposableEquivalentShapeAndNdSbpPair(\n      shape, hierarchy, src_nd_sbp, dst_nd_sbp));\n  ASSERT_TRUE(maybe_tuple.IsOk());\n  const auto& tuple = CHECK_JUST(maybe_tuple);\n  ASSERT_TRUE(*std::get<0>(*tuple) == shape);\n  ASSERT_TRUE(std::get<1>(*tuple) == src_nd_sbp);\n  ASSERT_TRUE(std::get<2>(*tuple) == dst_nd_sbp);\n}\n\nTEST(CalcDecomposableEquivalentShapeAndNdSbpPair, expand_src) {\n  Shape shape(DimVector{16, 4});\n  Shape hierarchy(DimVector{4, 4});\n  const auto& src_nd_sbp = GetNdSbp(\"S0\", 
\"S0\");\n  const auto& dst_nd_sbp = GetNdSbp(\"B\", \"P\");\n  const auto& maybe_tuple = TRY(private_details::CalcDecomposableEquivalentShapeAndNdSbpPair(\n      shape, hierarchy, src_nd_sbp, dst_nd_sbp));\n  ASSERT_TRUE(maybe_tuple.IsOk());\n  const auto& tuple = CHECK_JUST(maybe_tuple);\n  ASSERT_TRUE(*std::get<0>(*tuple) == Shape(DimVector{4, 4, 4}));\n  ASSERT_TRUE(std::get<1>(*tuple) == GetNdSbp(\"S0\", \"S1\"));\n  ASSERT_TRUE(std::get<2>(*tuple) == dst_nd_sbp);\n}\n\nTEST(CalcDecomposableEquivalentShapeAndNdSbpPair, expand_failed) {\n  Shape shape(DimVector{32, 4});\n  Shape hierarchy(DimVector{4, 4, 4});\n  const auto& src_nd_sbp = GetNdSbp(\"S0\", \"S0\", \"S0\");\n  const auto& dst_nd_sbp = GetNdSbp(\"P\", \"S0\", \"S1\");\n  const auto& maybe_tuple = TRY(private_details::CalcDecomposableEquivalentShapeAndNdSbpPair(\n      shape, hierarchy, src_nd_sbp, dst_nd_sbp));\n  ASSERT_FALSE(maybe_tuple.IsOk());\n}\n\nTEST(IsNdSbpBoxingAcyclic, yes) {\n  const auto& src_nd_sbp = GetNdSbp(\"S0\", \"S1\", \"S2\");\n  const auto& dst_nd_sbp = GetNdSbp(\"S1\", \"S2\", \"S3\");\n  const auto& maybe_acyclic = TRY(private_details::IsNdSbpBoxingAcyclic(src_nd_sbp, dst_nd_sbp));\n  ASSERT_TRUE(maybe_acyclic.IsOk());\n  ASSERT_TRUE(CHECK_JUST(maybe_acyclic));\n}\n\nTEST(IsNdSbpBoxingAcyclic, ring) {\n  const auto& src_nd_sbp = GetNdSbp(\"S0\", \"S1\", \"S2\");\n  const auto& dst_nd_sbp = GetNdSbp(\"S1\", \"S2\", \"S0\");\n  const auto& maybe_acyclic = TRY(private_details::IsNdSbpBoxingAcyclic(src_nd_sbp, dst_nd_sbp));\n  ASSERT_TRUE(maybe_acyclic.IsOk());\n  ASSERT_FALSE(CHECK_JUST(maybe_acyclic));\n}\n\nTEST(IsNdSbpBoxingAcyclic, partial_ring) {\n  const auto& src_nd_sbp = GetNdSbp(\"B\", \"S0\", \"S1\", \"S2\", \"S5\");\n  const auto& dst_nd_sbp = GetNdSbp(\"P\", \"S1\", \"S2\", \"S0\", \"S4\");\n  const auto& maybe_acyclic = TRY(private_details::IsNdSbpBoxingAcyclic(src_nd_sbp, dst_nd_sbp));\n  ASSERT_TRUE(maybe_acyclic.IsOk());\n  ASSERT_FALSE(CHECK_JUST(maybe_acyclic));\n}\n\nTEST(IsNdSbpBoxingAcyclic, dag) {\n  const auto& src_nd_sbp = GetNdSbp(\"S0\", \"S1\", \"S2\");\n  const auto& dst_nd_sbp = GetNdSbp(\"S1\", \"S2\", \"S3\");\n  const auto& maybe_acyclic = TRY(private_details::IsNdSbpBoxingAcyclic(src_nd_sbp, dst_nd_sbp));\n  ASSERT_TRUE(maybe_acyclic.IsOk());\n  ASSERT_TRUE(CHECK_JUST(maybe_acyclic));\n}\n\nTEST(GetNdSbpValidTransformationAxisSequence, naive) {\n  const auto& src_nd_sbp = GetNdSbp(\"S0\", \"S1\", \"S2\");\n  const auto& dst_nd_sbp = GetNdSbp(\"S0\", \"B\", \"S2\");\n  const auto& maybe_axis_seq =\n      TRY(private_details::GetNdSbpValidTransformationAxisSequence(src_nd_sbp, dst_nd_sbp));\n  ASSERT_TRUE(maybe_axis_seq.IsOk());\n  const auto& axis_seq = CHECK_JUST(maybe_axis_seq);\n  ASSERT_TRUE(*axis_seq == std::vector<int64_t>{1});\n}\n\nTEST(GetNdSbpValidTransformationAxisSequence, 2d) {\n  const auto& src_nd_sbp = GetNdSbp(\"B\", \"S0\");\n  const auto& dst_nd_sbp = GetNdSbp(\"S0\", \"S1\");\n  const auto& maybe_axis_seq =\n      TRY(private_details::GetNdSbpValidTransformationAxisSequence(src_nd_sbp, dst_nd_sbp));\n  ASSERT_TRUE(maybe_axis_seq.IsOk());\n  const auto& axis_seq = CHECK_JUST(maybe_axis_seq);\n  ASSERT_TRUE(*axis_seq == (std::vector<int64_t>{1, 0}));\n}\n\nTEST(GetNdSbpValidTransformationAxisSequence, 3d) {\n  const auto& src_nd_sbp = GetNdSbp(\"S0\", \"S1\", \"S2\");\n  const auto& dst_nd_sbp = GetNdSbp(\"S1\", \"S2\", \"S3\");\n  const auto& maybe_axis_seq =\n      TRY(private_details::GetNdSbpValidTransformationAxisSequence(src_nd_sbp, dst_nd_sbp));\n 
 ASSERT_TRUE(maybe_axis_seq.IsOk());\n  const auto& axis_seq = CHECK_JUST(maybe_axis_seq);\n  ASSERT_TRUE(*axis_seq == (std::vector<int64_t>{2, 1, 0}));\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/placement_utils.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/placement_utils.h\"\n#include \"oneflow/core/framework/parallel_conf_util.h\"\n\nnamespace oneflow {\n\nMaybe<Symbol<ParallelDesc>> ReplacePlacementDeviceTag(Symbol<ParallelDesc> parallel_desc,\n                                                      const std::string& device_type) {\n  ParallelConf parallel_conf = parallel_desc->parallel_conf();\n  parallel_conf.set_device_tag(device_type);\n  std::shared_ptr<ParallelDesc> out_parallel_desc;\n  JUST(PhysicalRun(\n      [&out_parallel_desc, &parallel_conf](InstructionsBuilder* builder) -> Maybe<void> {\n        out_parallel_desc = JUST(builder->GetParallelDescSymbol(parallel_conf));\n        return Maybe<void>::Ok();\n      }));\n  return SymbolOf(*out_parallel_desc);\n}\n\nMaybe<void> TouchGlobalTensor(const std::shared_ptr<one::Tensor>& tensor) {\n  CHECK_OR_RETURN(tensor->is_global());  // NOLINT\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/placement_utils.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef _ONEFLOW_CORE_FRAMEWORK_PLACEMENT_UTILS_H_\n#define _ONEFLOW_CORE_FRAMEWORK_PLACEMENT_UTILS_H_\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/framework/tensor_rpc_util.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n\nnamespace oneflow {\n\nMaybe<Symbol<ParallelDesc>> ReplacePlacementDeviceTag(Symbol<ParallelDesc> parallel_desc,\n                                                      const std::string& device_type);\n\nMaybe<void> TouchGlobalTensor(const std::shared_ptr<one::Tensor>& tensor);\n\nconstexpr auto* CheckMetaConsistency = DECORATE(&TouchGlobalTensor, CheckGlobalTensorMeta);\n\n}  // namespace oneflow\n\n#endif\n"
  },
  {
    "path": "oneflow/core/framework/random_generator.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/random_generator.h\"\n\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/framework/auto_random_generator.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/tensor_util.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/platform/include/pthread_fork.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n\nnamespace oneflow {\nnamespace one {\n\nnamespace {\n\nuint64_t GetNonDeterministicRandom() {\n  std::random_device rd;\n  // limit to 53 bits to ensure unique representation in double\n  auto s = ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF;\n  return s;\n}\n\nMaybe<void> CPUSynchronize() {\n  if (Singleton<VirtualMachine>::Get() != nullptr) { return vm::CurrentRankSync(); }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nGenerator::Generator(const std::shared_ptr<ep::RandomGenerator>& internal) : internal_(internal) {}\n\nuint64_t Generator::current_seed() const { return internal_->current_seed(); }\n\nvoid Generator::add_children_generator(Symbol<ParallelDesc> placement, Symbol<NdSbp> nd_sbp,\n                                       const std::shared_ptr<Generator>& generator) {\n  children_generators_.emplace(std::make_pair(placement, nd_sbp), generator);\n}\n\nconst HashMap<std::pair<Symbol<ParallelDesc>, Symbol<NdSbp>>, std::shared_ptr<one::Generator>>&\nGenerator::children_generators() const {\n  return children_generators_;\n}\n\nvoid Generator::set_current_seed(uint64_t seed) {\n  CHECK_JUST(CPUSynchronize());\n  internal_->set_current_seed(seed);\n  for (auto pair : children_generators_) {\n    uint64_t rank_seed = seed;\n    if (pair.first.first->parallel_num() > 1) {\n      CHECK_JUST(one::functional::BroadcastSeedToAllRanks(&seed, /*root=*/0));  // NOLINT\n      rank_seed =\n          CHECK_JUST(GetRandomSeedForRank(*(pair.first.first), *(pair.first.second),  // NOLINT\n                                          seed,                                       // NOLINT\n                                          GlobalProcessCtx::Rank()));                 // NOLINT\n    }\n    pair.second->set_current_seed(rank_seed);\n  }\n}\n\nuint64_t Generator::seed() {\n  uint64_t seed = GetNonDeterministicRandom();\n  set_current_seed(seed);\n  return seed;\n}\n\nMaybe<Symbol<Device>> Generator::device() const {\n  return Device::New(internal_->device_type_name(), internal_->device_index());\n}\n\nMaybe<Tensor> Generator::GetState() const {\n  JUST(CPUSynchronize());\n  int64_t state_size 
= internal_->GetStateSize();\n  std::vector<uint8_t> state_data(state_size);\n  internal_->GetState(state_size, state_data.data());\n  const auto& device = JUST(Device::New(\"cpu\"));\n  const auto& state = JUST(functional::Empty(Shape{state_size}, DType::UInt8(), device,\n                                             /*requires_grad=*/false, /*pin_memory=*/false));\n  const auto& callback = [&](ep::Stream*,\n                             const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n    memcpy(eager_blob_object->mut_dptr(), state_data.data(), state_size);\n  };\n  JUST(SyncAccessTensorWithTimeOut(state, callback, \"mut\"));\n  return state;\n}\n\nMaybe<void> Generator::SetState(const std::shared_ptr<Tensor>& state) {\n  const auto& device = JUST(state->device());\n  if (device->type() != \"cpu\") {\n    return Error::RuntimeError() << \"Generator state should be a host tensor.\";\n  }\n  if (state->dtype() != DType::UInt8()) {\n    return Error::RuntimeError() << \"Generator state should have dtype=flow.uint8.\";\n  }\n  size_t state_size = state->shape()->elem_cnt();\n  std::vector<uint8_t> state_data(state_size);\n  const auto& callback = [&](ep::Stream*,\n                             const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n    memcpy(state_data.data(), eager_blob_object->dptr(), state_size);\n  };\n  JUST(SyncAccessTensorWithTimeOut(state, callback, \"const\"));\n  JUST(CPUSynchronize());\n  internal_->SetState(state_size, state_data.data());\n  return Maybe<void>::Ok();\n}\n\nMaybe<Generator> DefaultGenerator(const std::string& device, int device_index) {\n  static auto* default_auto_generator =\n      dynamic_cast<AutoGenerator*>(JUST(DefaultAutoGenerator())->internal().get());\n  if (device_index == -1) { device_index = (device == \"cpu\" ? 
0 : GlobalProcessCtx::LocalRank()); }\n  return std::make_shared<Generator>(\n      JUST(default_auto_generator->GetOrCreate(device, device_index)));\n}\n\nMaybe<Generator> DefaultAutoGenerator() {\n  // Skip destruction to avoid calling symbols in other dynamic libraries when the global object is\n  // released.\n  static auto default_auto_generator = std::make_shared<Generator>(std::shared_ptr<AutoGenerator>(\n      new AutoGenerator(GetNonDeterministicRandom()), [](AutoGenerator*) {}));\n  return default_auto_generator;\n}\n\nMaybe<Generator> DefaultCPUGenerator() {\n  static auto default_cpu_generator = JUST(DefaultGenerator(\"cpu\", 0));\n  return default_cpu_generator;\n}\n\nMaybe<Generator> DefaultCUDAGenerator(int device_index) {\n#ifdef WITH_CUDA\n  static int device_count = GetCudaDeviceCount();\n#else\n  static int device_count = 0;\n#endif  // WITH_CUDA\n  static std::vector<std::once_flag> init_flags(device_count);\n  static std::vector<std::shared_ptr<Generator>> default_cuda_generator(device_count);\n\n  if (device_index == -1) { device_index = GlobalProcessCtx::LocalRank(); }\n  CHECK_OR_RETURN(device_index >= 0 && device_index < device_count)\n      << \"Invalid device index \" << device_index;\n  std::call_once(init_flags[device_index], [&]() {\n    default_cuda_generator[device_index] = CHECK_JUST(DefaultGenerator(\"cuda\", device_index));\n  });\n  return default_cuda_generator.at(device_index);\n}\n\nMaybe<Generator> MakeAutoGenerator() {\n  return std::make_shared<Generator>(std::make_shared<AutoGenerator>(default_rng_seed_val));\n}\n\nMaybe<Generator> MakeCPUGenerator() {\n  static auto device_mgr =\n      Singleton<ep::DeviceManagerRegistry>::Get()->GetDeviceManager(DeviceType::kCPU);\n  return std::make_shared<Generator>(device_mgr->CreateRandomGenerator(default_rng_seed_val, 0));\n}\n\nMaybe<Generator> MakeCUDAGenerator(int device_index) {\n  static auto device_mgr =\n      Singleton<ep::DeviceManagerRegistry>::Get()->GetDeviceManager(DeviceType::kCUDA);\n  if (device_index == -1) { device_index = GlobalProcessCtx::LocalRank(); }\n  return std::make_shared<Generator>(\n      device_mgr->CreateRandomGenerator(default_rng_seed_val, device_index));\n}\n\nMaybe<void> ManualSeedAllCudaGenerator(uint64_t seed) {\n#ifdef WITH_CUDA\n  static int device_count = GetCudaDeviceCount();\n  FOR_RANGE(int, device_id, 0, device_count) {\n    const auto& cuda_gen = JUST(DefaultCUDAGenerator(device_id));\n    cuda_gen->set_current_seed(seed);\n  }\n#endif  // WITH_CUDA\n  return Maybe<void>::Ok();\n}\n\nMaybe<Generator> MakeGenerator(const std::string& device, int device_index) {\n  if (device == \"auto\") {\n    return std::make_shared<Generator>(std::make_shared<AutoGenerator>(default_rng_seed_val));\n  }\n  auto device_type = ep::DeviceManagerRegistry::GetDeviceTypeByDeviceTypeName(device);\n  if (device_type == DeviceType::kInvalidDevice) {\n    return Error::RuntimeError() << \"Expected one of \" << PrintGeneratorAvailableDevices()\n                                 << \" device type at start of device string: \" << device;\n  }\n  auto device_mgr = Singleton<ep::DeviceManagerRegistry>::Get()->GetDeviceManager(device_type);\n  if (device_index == -1) { device_index = (device == \"cpu\" ? 
0 : GlobalProcessCtx::LocalRank()); }\n  return std::make_shared<Generator>(\n      device_mgr->CreateRandomGenerator(default_rng_seed_val, device_index));\n}\n\nMaybe<Generator> DefaultGenerator(DeviceType device, int device_index) {\n  return DefaultGenerator(*JUST(DeviceTag4DeviceType(device)), device_index);\n}\n\nMaybe<Generator> MakeGenerator(DeviceType device, int device_index) {\n  return MakeGenerator(*JUST(DeviceTag4DeviceType(device)), device_index);\n}\n\nMaybe<Generator> ManualSeed(uint64_t seed) {\n  const auto& default_auto_generator = JUST(DefaultAutoGenerator());\n  default_auto_generator->set_current_seed(seed);\n  return default_auto_generator;\n}\n\nMaybe<void> ManualSeed(uint64_t seed, const std::string& device, int device_index) {\n  JUST(DefaultGenerator(device, device_index))->set_current_seed(seed);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ManualSeed(uint64_t seed, DeviceType device, int device_index) {\n  return ManualSeed(seed, *JUST(DeviceTag4DeviceType(device)), device_index);\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/random_generator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_RANDOM_GENERATOR_H_\n#define ONEFLOW_CORE_FRAMEWORK_RANDOM_GENERATOR_H_\n\n#include <mutex>\n\n#include \"oneflow/core/ep/include/random_generator.h\"\n#include \"oneflow/core/framework/auto_random_generator.h\"\n#include \"oneflow/core/framework/device.h\"\n\n#include \"oneflow/core/ep/cpu/cpu_random_generator.h\"\n#include \"oneflow/core/ep/cuda/cuda_random_generator.h\"\n#include \"oneflow/core/common/hash_container.h\"\n\nnamespace oneflow {\n\nclass NdSbp;\n\nnamespace one {\n\n// The default seed is selected to be a large number\n// with good distribution of 0s and 1s in bit representation.\nstatic constexpr uint64_t default_rng_seed_val = 67280421310721;\n\nclass Tensor;\n\nclass Generator final {\n public:\n  explicit Generator(const std::shared_ptr<ep::RandomGenerator>& internal);\n\n  ~Generator() = default;\n\n  void set_current_seed(uint64_t seed);\n\n  uint64_t current_seed() const;\n\n  void add_children_generator(Symbol<ParallelDesc> placement, Symbol<NdSbp> nd_sbp,\n                              const std::shared_ptr<Generator>& generator);\n  const HashMap<std::pair<Symbol<ParallelDesc>, Symbol<NdSbp>>, std::shared_ptr<one::Generator>>&\n  children_generators() const;\n\n  // Reset current generator by a non-deterministic random seed, and returns it.\n  uint64_t seed();\n\n  Maybe<Symbol<Device>> device() const;\n\n  Maybe<Tensor> GetState() const;\n  Maybe<void> SetState(const std::shared_ptr<Tensor>& state);\n\n  const std::shared_ptr<ep::RandomGenerator>& internal() const { return internal_; }\n\n  template<typename T>\n  Maybe<T> Get(int device_index = -1) const {\n    if (auto* internal = dynamic_cast<AutoGenerator*>(internal_.get())) {\n      return internal->GetOrCreate<T>(device_index);\n    }\n    auto internal = std::dynamic_pointer_cast<T>(internal_);\n    CHECK_NOTNULL_OR_RETURN(internal);\n    if (device_index != -1) {\n      CHECK_EQ_OR_RETURN(device_index, internal->device_index())\n          << \"Invalid device index \" << device_index << \" since the generator's device index is \"\n          << internal->device_index();\n    }\n    return internal;\n  }\n\n private:\n  mutable std::mutex mutex_;\n  std::shared_ptr<ep::RandomGenerator> internal_;\n  // children generator for eager global mode\n  HashMap<std::pair<Symbol<ParallelDesc>, Symbol<NdSbp>>,  // NOLINT\n          std::shared_ptr<one::Generator>>                 // NOLINT\n      children_generators_;                                // NOLINT\n};\n\nMaybe<Generator> MakeGenerator(const std::string& device, int device_index = -1);\nMaybe<Generator> MakeGenerator(DeviceType device, int device_index = -1);\n\nMaybe<Generator> MakeAutoGenerator();\nMaybe<Generator> MakeCPUGenerator();\nMaybe<Generator> MakeCUDAGenerator();\n\nMaybe<Generator> DefaultAutoGenerator();\nMaybe<Generator> DefaultCPUGenerator();\nMaybe<Generator> DefaultCUDAGenerator(int device_index = 
-1);\n\nMaybe<Generator> DefaultGenerator(const std::string& device, int device_index = -1);\nMaybe<Generator> DefaultGenerator(DeviceType device, int device_index = -1);\n\nMaybe<Generator> ManualSeed(uint64_t seed);\n\nMaybe<void> ManualSeed(uint64_t seed, const std::string& device, int device_index = -1);\nMaybe<void> ManualSeed(uint64_t seed, DeviceType device, int device_index = -1);\n\nMaybe<void> ManualSeedAllCudaGenerator(uint64_t seed);\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_RANDOM_GENERATOR_H_\n"
  },
  {
    "path": "oneflow/core/framework/rank_group_rpc_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <memory>\n#include <chrono>\n#include \"oneflow/core/framework/rank_group_rpc_util.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/job/rank_group_scope.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/thread/thread_global_id.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\n\nMaybe<NaiveAsyncTransportCtx> CheckTransportToken(Symbol<RankGroup> rank_group) {\n  const auto& transport_token =\n      JUST(TransportToken::NewTransportToken(kTransportTokenTypeCheckRankGroupConsistency));\n  const auto& PrepareBuffer = [](void** buffer, std::size_t* size,\n                                 std::function<void()>* Callback) -> Maybe<void> {\n    const auto& placeholder = std::make_shared<uint32_t>();\n    *buffer = placeholder.get();\n    *size = sizeof(uint32_t);\n    *Callback = [placeholder]() {};\n    return Maybe<void>::Ok();\n  };\n  const auto& ctx =\n      std::make_shared<NaiveAsyncTransportCtx>(transport_token, PrepareBuffer, PrepareBuffer);\n  JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, ctx.get()));\n  JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, ctx.get()));\n  return ctx;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/rank_group_rpc_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_PLACEMENT_RPC_UTIL_H_\n#define ONEFLOW_CORE_FRAMEWORK_PLACEMENT_RPC_UTIL_H_\n\n#include \"oneflow/core/framework/transport_token.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/job/rank_group.h\"\n\nnamespace oneflow {\n\nMaybe<NaiveAsyncTransportCtx> CheckTransportToken(Symbol<RankGroup> rank_group);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_PLACEMENT_RPC_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/framework/saved_tensor_hooks.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_SAVED_TENSOR_HOOKS_H_\n#define ONEFLOW_CORE_FRAMEWORK_SAVED_TENSOR_HOOKS_H_\n\n#include \"oneflow/core/framework/tensor.h\"\n\nnamespace oneflow {\nnamespace one {\nclass SavedTensorHook {\n public:\n  virtual ~SavedTensorHook() = default;\n  virtual void pack(const std::shared_ptr<Tensor>& tensor) = 0;\n  virtual std::shared_ptr<Tensor> unpack() = 0;\n};\n\nclass SavedTensorHookCreator {\n public:\n  virtual ~SavedTensorHookCreator() = default;\n  virtual std::unique_ptr<SavedTensorHook> new_saved_tensor_hook() const = 0;\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_SAVED_TENSOR_HOOKS_H_\n"
  },
  {
    "path": "oneflow/core/framework/sbp_context.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/sbp_context.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ninline void SplitImpl(SbpSignature* sbp_sign, const std::string& bn, int64_t axis) {\n  (*sbp_sign->mutable_bn_in_op2sbp_parallel())[bn].mutable_split_parallel()->set_axis(axis);\n}\n\ninline void BroadcastImpl(SbpSignature* sbp_sign, const std::string& bn) {\n  (*sbp_sign->mutable_bn_in_op2sbp_parallel())[bn].mutable_broadcast_parallel();\n}\n\ninline void PartialSumImpl(SbpSignature* sbp_sign, const std::string& bn) {\n  (*sbp_sign->mutable_bn_in_op2sbp_parallel())[bn].mutable_partial_sum_parallel();\n}\n\n}  // namespace\n\nnamespace user_op {\n\nUserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::Split(const OpArg& op_arg, int64_t axis) {\n  SplitImpl(&sbp_sig_tmp_, GenRepeatedBn(op_arg.name(), op_arg.index()), axis);\n  return *this;\n}\n\nUserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::Split(const std::vector<OpArg>& op_args,\n                                                            int64_t axis) {\n  for (const auto& op_arg : op_args) { Split(op_arg, axis); }\n  return *this;\n}\n\nUserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::Split(\n    const std::vector<std::pair<std::string, int32_t>>& args, int64_t axis) {\n  for (const auto& pair : args) {\n    SplitImpl(&sbp_sig_tmp_, GenRepeatedBn(pair.first, pair.second), axis);\n  }\n  return *this;\n}\n\nUserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::Broadcast(const OpArg& op_arg) {\n  BroadcastImpl(&sbp_sig_tmp_, GenRepeatedBn(op_arg.name(), op_arg.index()));\n  return *this;\n}\n\nUserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::Broadcast(const std::vector<OpArg>& op_args) {\n  for (const auto& op_arg : op_args) { Broadcast(op_arg); }\n  return *this;\n}\n\nUserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::Broadcast(\n    const std::vector<std::pair<std::string, int32_t>>& op_args) {\n  for (const auto& pair : op_args) {\n    BroadcastImpl(&sbp_sig_tmp_, GenRepeatedBn(pair.first, pair.second));\n  }\n  return *this;\n}\n\nUserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::PartialSum(const OpArg& op_arg) {\n  PartialSumImpl(&sbp_sig_tmp_, GenRepeatedBn(op_arg.name(), op_arg.index()));\n  return *this;\n}\n\nUserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::PartialSum(\n    const std::vector<OpArg>& op_args) {\n  for (const auto& op_arg : op_args) { PartialSum(op_arg); }\n  return *this;\n}\n\nUserOpSbpSignatureBuilder& UserOpSbpSignatureBuilder::PartialSum(\n    const std::vector<std::pair<std::string, int32_t>>& op_args) {\n  for (const auto& pair : op_args) {\n    PartialSumImpl(&sbp_sig_tmp_, GenRepeatedBn(pair.first, pair.second));\n  }\n  return *this;\n}\n\nMaybe<void> GetSbpFnUtil::DefaultBroadcastToBroadcast(SbpContext* ctx) { return 
Maybe<void>::Ok(); }\n\nMaybe<void> GetSbpFnUtil::SplitForEachAxis(SbpContext* ctx) {\n  const auto& inputs = ctx->inputs();\n  CHECK_GE_OR_RETURN(inputs.size(), 1)\n      << \"GetSbpFnUtil::SplitForEachAxis requires at least one input\";\n  int64_t num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(inputs.at(0).first, inputs.at(0).second)\n          .shape()\n          .NumAxes();\n  for (const auto& pair : inputs) {\n    CHECK_EQ_OR_RETURN(\n        num_axes,\n        ctx->LogicalTensorDesc4InputArgNameAndIndex(pair.first, pair.second).shape().NumAxes());\n  }\n  for (int64_t axis = 0; axis < num_axes; ++axis) {\n    ctx->NewBuilder().Split(inputs, axis).Split(ctx->outputs(), axis).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferNdSbp4SrcOp(user_op::InferNdSbpFnContext* ctx, const SbpParallel& default_sbp) {\n  const Shape& hierarchy = ctx->parallel_hierarchy();\n  const auto& sbp_str_list = ctx->user_op_conf().attr<std::vector<std::string>>(\"nd_sbp\");\n\n  // src op may have tick inputs whose sbp should be broadcast\n  for (const auto& input_arg : ctx->inputs()) {\n    NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex(input_arg.first, input_arg.second);\n    FOR_RANGE(int, i, 0, hierarchy.NumAxes()) {\n      input_nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel();\n    }\n  }\n\n  for (const auto& output_arg : ctx->outputs()) {\n    NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex(output_arg.first, output_arg.second);\n    size_t nd_sbp_size = sbp_str_list.size();\n    if (nd_sbp_size == 0) {\n      nd_sbp_size = hierarchy.NumAxes();\n    } else {\n      CHECK_EQ_OR_RETURN(nd_sbp_size, hierarchy.NumAxes());\n    }\n    FOR_RANGE(size_t, i, 0, nd_sbp_size) {\n      SbpParallel* sbp = output_nd_sbp->add_sbp_parallel();\n      if (sbp_str_list.size() == 0) {\n        *sbp = default_sbp;\n      } else {\n        CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_str_list[i], sbp));\n      }\n      CHECK_OR_RETURN(sbp->has_split_parallel() || sbp->has_broadcast_parallel());\n    }\n  }\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SetSrcOpNdSbp(const NdSbpSignature& nd_sbp_sig, const std::string& blob_name,\n                          OperatorConf* op_conf) {\n  CHECK_OR_RETURN(nd_sbp_sig.bn_in_op2nd_sbp().find(blob_name)\n                  != nd_sbp_sig.bn_in_op2nd_sbp().end())\n      << \"blob `\" << blob_name << \"` can't be found in NdSbp signature: \"\n      << nd_sbp_sig.DebugString();\n  const auto& nd_sbp = nd_sbp_sig.bn_in_op2nd_sbp().at(blob_name);\n  std::vector<std::string> nd_sbp_str_list = *JUST(GetNdSbpStrList(nd_sbp));\n  CHECK_OR_RETURN(op_conf->has_user_conf())\n      << \"user_op::SetSrcOpNdSbp is only used to set user op conf\";\n  CHECK_OR_RETURN(op_conf->user_conf().attr().find(\"nd_sbp\") != op_conf->user_conf().attr().end())\n      << op_conf->name() << \" has no attr named `nd_sbp`\";\n  *op_conf->mutable_user_conf()\n       ->mutable_attr()\n       ->at(\"nd_sbp\")\n       .mutable_at_list_string()\n       ->mutable_val() = {nd_sbp_str_list.begin(), nd_sbp_str_list.end()};\n  return Maybe<void>::Ok();\n}\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/sbp_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_SBP_CONTEXT_H_\n#define ONEFLOW_CORE_FRAMEWORK_SBP_CONTEXT_H_\n\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/framework/infer_nd_sbp_fn_context.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nclass TensorDesc;\n\nclass UserOpSbpSignatureBuilder final {\n public:\n  UserOpSbpSignatureBuilder(SbpSignatureList* sbp_sig_list) : sbp_sig_list_(sbp_sig_list) {}\n\n  UserOpSbpSignatureBuilder& Split(const OpArg& op_arg, int64_t axis);\n  UserOpSbpSignatureBuilder& Split(const std::vector<OpArg>& op_args, int64_t axis);\n  UserOpSbpSignatureBuilder& Split(const std::vector<std::pair<std::string, int32_t>>& op_args,\n                                   int64_t axis);\n\n  UserOpSbpSignatureBuilder& Broadcast(const OpArg& op_arg);\n  UserOpSbpSignatureBuilder& Broadcast(const std::vector<OpArg>& op_args);\n  UserOpSbpSignatureBuilder& Broadcast(const std::vector<std::pair<std::string, int32_t>>& op_args);\n\n  UserOpSbpSignatureBuilder& PartialSum(const OpArg& op_arg);\n  UserOpSbpSignatureBuilder& PartialSum(const std::vector<OpArg>& op_args);\n  UserOpSbpSignatureBuilder& PartialSum(\n      const std::vector<std::pair<std::string, int32_t>>& op_args);\n\n  void Build() { *(sbp_sig_list_->mutable_sbp_signature()->Add()) = sbp_sig_tmp_; }\n\n private:\n  SbpSignatureList* sbp_sig_list_;\n  SbpSignature sbp_sig_tmp_;\n};\n\nclass SbpContextBase {\n public:\n  SbpContextBase() = default;\n  virtual ~SbpContextBase() = default;\n\n  virtual const TensorDesc& LogicalTensorDesc4InputArgNameAndIndex(\n      const std::string& input_arg_name, int32_t index) const = 0;\n  virtual const std::vector<std::pair<std::string, int32_t>>& inputs() const = 0;\n  virtual const std::vector<std::pair<std::string, int32_t>>& outputs() const = 0;\n\n  virtual DeviceType device_type() const = 0;\n  virtual int64_t parallel_num() const = 0;\n\n  template<typename T>\n  T Attr(const std::string& attr_name) const {\n    return user_op_conf().attr<T>(attr_name);\n  }\n  virtual const UserOpConfWrapper& user_op_conf() const = 0;\n};\n\nclass SbpContext : public SbpContextBase {\n public:\n  SbpContext() = default;\n  ~SbpContext() override = default;\n\n  // hierarchy value is the value at the dimension corresponding to the current SBP\n  // For example, 2 machines, 4 gpus per machine, hierarchy = [2, 4]\n  // Suppose we have nd_sbp = (S0, B)\n  // The hierarchy value corresponding to S0 is 2\n  // The hierarchy value corresponding to B is 4.\n  virtual int64_t hierarchy_value() const = 0;\n  virtual UserOpSbpSignatureBuilder NewBuilder() = 0;\n};\n\nclass InferSbpSignatureFnContext : public SbpContextBase {\n public:\n  InferSbpSignatureFnContext() = default;\n  ~InferSbpSignatureFnContext() override = default;\n\n  virtual SbpSignature* mutable_sbp_signature() = 0;\n  virtual const SbpSignature& 
sbp_signature_conf() const = 0;\n  virtual const SbpParallel& SbpParallelHint4InputArgNameAndIndex(const std::string& input_arg_name,\n                                                                  int32_t index) const = 0;\n};\n\nstruct GetSbpFnUtil {\n  static Maybe<void> DefaultBroadcastToBroadcast(SbpContext*);\n  static Maybe<void> SplitForEachAxis(SbpContext*);\n};\n\nMaybe<void> InferNdSbp4SrcOp(user_op::InferNdSbpFnContext* ctx, const SbpParallel& default_sbp);\nMaybe<void> SetSrcOpNdSbp(const NdSbpSignature& nd_sbp_sig, const std::string& blob_name,\n                          OperatorConf* op_conf);\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_SBP_CONTEXT_H_\n"
  },
  {
    "path": "oneflow/core/framework/sbp_infer_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/auto_parallel/algorithm_util.h\"\n#include \"oneflow/core/auto_parallel/boxing_collector.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter_mgr.h\"\n#include \"oneflow/core/common/device_type.pb.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/sbp_parallel.pb.h\"\n#include \"oneflow/core/register/blob_desc.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstatic const double kUnsupportedBoxing = GetMaxVal<float>();\n\n// check whether the sbp_parallel is legal\nbool CheckSbpParallel(const SbpParallel& sbp_parallel) {\n  return sbp_parallel.has_split_parallel() || sbp_parallel.has_broadcast_parallel()\n         || sbp_parallel.has_partial_sum_parallel();\n}\n\n// check whether the nd_sbp is legal\nbool CheckNdSbp(const NdSbp& nd_sbp) {\n  if (nd_sbp.sbp_parallel_size() <= 0) { return false; }\n  for (const auto& sbp : nd_sbp.sbp_parallel()) {\n    if (!CheckSbpParallel(sbp)) { return false; }\n  }\n  return true;\n}\n\ndouble Penalty4PartialInConsumer(double logical_blob_size, int32_t producer_parallel_num,\n                                 int32_t consumer_parallel_num) {\n  static const int64_t penalty4partial_in_consumer_tag =\n      ParseIntegerFromEnv(\"ONEFLOW_PENALTY_FOR_PARTIAL_IN_CONSUMER_POLICY\", 2);\n  if (penalty4partial_in_consumer_tag == Penalty4PartialInConsumerTag::kSlight) {\n    return 1.0;\n  } else if (penalty4partial_in_consumer_tag == Penalty4PartialInConsumerTag::kMiddle) {\n    return 4 * logical_blob_size * (producer_parallel_num + consumer_parallel_num);\n  } else {\n    return kUnsupportedBoxing;\n  }\n}\n\nint32_t Ratio4Sbp(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc,\n                  const std::function<bool(const SbpParallel&)>& classifier) {\n  int32_t ratio = 1;\n  for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) {\n    if (classifier(nd_sbp.sbp_parallel(sbp_id))) { ratio *= parallel_desc.hierarchy()->At(sbp_id); }\n  }\n  return ratio;\n}\n\nMaybe<double> ComputCopyCostBetweenTwoSbpParallel(const SbpParallel& producer_sbp_parallel,\n                                                  const SbpParallel& consumer_sbp_parallel,\n                                                  const BlobDesc& logical_blob_desc,\n                                                  bool on_same_devices,\n                                                  int32_t producer_parallel_num,\n                                                  int32_t consumer_parallel_num) {\n  if (!(CheckSbpParallel(producer_sbp_parallel) && 
CheckSbpParallel(consumer_sbp_parallel))) {\n    return Error::RuntimeError() << \"Illegal sbp parallel has been found.\";\n  }\n\n  // Not supporting S->P for lazy boxing now.\n  if (LazyMode::is_enabled()) {\n    if (consumer_sbp_parallel.has_partial_sum_parallel()\n        && producer_sbp_parallel.has_split_parallel()) {\n      return kUnsupportedBoxing;\n    }\n  }\n\n  // NOTE: A tensor placed on cpu with a consumer operator that accepts cuda inputs would be\n  // transferred to cuda later. We might not have correct parallel description at this moment.\n  if (on_same_devices && producer_parallel_num == consumer_parallel_num) {\n    // Same sbp, no cost: S->S, B->B, P->P\n    if (producer_sbp_parallel == consumer_sbp_parallel) { return 0.0; }\n    double logical_blob_size = TotalByteSize4BlobDesc(logical_blob_desc);\n    // S->P for eager. It should be 0 as well.\n    // NOTE: Similar to B->P, we just make the other part 0. You can consider P as S(i) for an\n    // arbitrary i.\n    // ? -> P\n    if (consumer_sbp_parallel.has_partial_sum_parallel()) {\n      return Penalty4PartialInConsumer(logical_blob_size, producer_parallel_num,\n                                       consumer_parallel_num);\n    }\n    // B->S\n    if (producer_sbp_parallel.has_broadcast_parallel()) { return 1.0; }\n\n    // has S\n    if (consumer_sbp_parallel.has_split_parallel() || producer_sbp_parallel.has_split_parallel()) {\n      if (consumer_sbp_parallel.has_split_parallel()\n          && producer_sbp_parallel.has_split_parallel()) {\n        // S(0)->S(1), S(1)->S(0), etc.\n        return logical_blob_size * (producer_parallel_num - 1) / producer_parallel_num;\n      } else {\n        // P->S, S->B\n        return logical_blob_size * (producer_parallel_num - 1);\n      }\n    }\n    // P->B\n    return 2 * logical_blob_size * (producer_parallel_num - 1);\n  } else {\n    // Not supporting P->P for different placement\n    if (LazyMode::is_enabled()) {\n      if (consumer_sbp_parallel.has_partial_sum_parallel()\n          && producer_sbp_parallel.has_partial_sum_parallel()) {\n        return kUnsupportedBoxing;\n      }\n    }\n\n    double logical_blob_size = TotalByteSize4BlobDesc(logical_blob_desc);\n    double overall_cost = logical_blob_size;\n    // ? -> B\n    if (consumer_sbp_parallel.has_broadcast_parallel()) {\n      overall_cost += (consumer_parallel_num - 1) * logical_blob_size;\n    }\n    // P -> ?\n    if (producer_sbp_parallel.has_partial_sum_parallel()) {\n      overall_cost += (producer_parallel_num - 1) * logical_blob_size;\n    }\n    // ? -> P\n    if (consumer_sbp_parallel.has_partial_sum_parallel()) {\n      overall_cost += Penalty4PartialInConsumer(logical_blob_size, producer_parallel_num,\n                                                consumer_parallel_num);\n    }\n    // For B->S, S->S, overall_cost == logical_blob_size;\n    return overall_cost;\n  }\n}\n
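\n// Cost magnitudes of the same-device branch above, taking producer_parallel_num = 4 and a\n// T-byte blob: S(0)->S(1) costs 3T/4, P->S costs 3T, P->B costs 6T, and B->S costs 1 (an\n// almost free local slice).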
\n// Compute the copy cost between two SBPs.\n// Either the SBPs themselves differ or they live on different devices.\ndouble ComputCopyCostBetweenTwoDiffSbpParallel(const SbpParallel& producer_sbp_parallel,\n                                               const SbpParallel& consumer_sbp_parallel,\n                                               double logical_blob_size, double parallel_num,\n                                               bool on_same_devices) {\n  // Not supporting S->P for now.\n  if (consumer_sbp_parallel.has_partial_sum_parallel()\n      && producer_sbp_parallel.has_split_parallel()) {\n    return kUnsupportedBoxing;\n  }\n  if (on_same_devices) {\n    // B->P\n    if (consumer_sbp_parallel.has_partial_sum_parallel()) {\n      return Penalty4PartialInConsumer(logical_blob_size, parallel_num, parallel_num);\n    }\n    // B->S\n    if (producer_sbp_parallel.has_broadcast_parallel()) { return 1; }\n    // has S\n    if (consumer_sbp_parallel.has_split_parallel() || producer_sbp_parallel.has_split_parallel()) {\n      if (consumer_sbp_parallel.has_split_parallel()\n          && producer_sbp_parallel.has_split_parallel()) {\n        // S(0)->S(1), S(1)->S(0), etc.\n        return logical_blob_size * (parallel_num - 1) / parallel_num;\n      } else {\n        // P->S, S->B\n        return logical_blob_size * (parallel_num - 1);\n      }\n    }\n    // P->B (= P->S + S->B)\n    return 2 * logical_blob_size * (parallel_num - 1);\n  } else {\n    // They have the same hierarchy at the transfer dimension.\n    double overall_cost = logical_blob_size;\n    // ? -> B\n    if (consumer_sbp_parallel.has_broadcast_parallel()) {\n      overall_cost += logical_blob_size * (parallel_num - 1);\n    }\n    // P -> ?\n    if (producer_sbp_parallel.has_partial_sum_parallel()) {\n      overall_cost += logical_blob_size * (parallel_num - 1);\n    }\n    if (consumer_sbp_parallel.has_partial_sum_parallel()) {\n      overall_cost += Penalty4PartialInConsumer(logical_blob_size, parallel_num, parallel_num);\n    }\n    // For B->P, B->S, S->S, overall_cost == logical_blob_size;\n    return overall_cost;\n  }\n}\n\nMaybe<double> ComputCopyCostBetweenTwoNdSbp(const NdSbp& producer_nd_sbp,\n                                            const NdSbp& consumer_nd_sbp, double logical_blob_size,\n                                            const Shape& hierarchy, bool on_same_devices) {\n  if (hierarchy.NumAxes() != 2) { return kUnsupportedBoxing; }\n  const auto& producer_sbp_size = producer_nd_sbp.sbp_parallel_size();\n  const auto& consumer_sbp_size = consumer_nd_sbp.sbp_parallel_size();\n  // At least one of the SBPs should have size 2\n  CHECK_OR_RETURN((producer_sbp_size == 1 && consumer_sbp_size == 2)\n                  || (producer_sbp_size == 2 && consumer_sbp_size == 1)\n                  || (producer_sbp_size == 2 && consumer_sbp_size == 2))\n      << \"Not supporting such boxing type. Check if we have bugs in auto parallel.\";\n
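  // Example: with hierarchy [2, 3], (B, P) -> (B, B) agrees at dim 0 (B), so the blob size is\n  // first scaled by 2 (the data is duplicated across that dim) and the remaining 1D transfer\n  // P->B is charged across the 3 devices of dim 1.\n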
  for (int32_t dim_same_sbp = 0; dim_same_sbp < 2; dim_same_sbp++) {\n    // If the nd_sbp only has size 1, then make its dimension 0\n    int32_t dim_producer = dim_same_sbp;\n    if (producer_sbp_size == 1) { dim_producer = 0; }\n    int32_t dim_consumer = dim_same_sbp;\n    if (consumer_sbp_size == 1) { dim_consumer = 0; }\n    // The SBP parallels are the same at dimension (dim_same_sbp)\n    if (producer_nd_sbp.sbp_parallel(dim_producer) == consumer_nd_sbp.sbp_parallel(dim_consumer)) {\n      if (!producer_nd_sbp.sbp_parallel(dim_producer).has_split_parallel()) {\n        logical_blob_size *= hierarchy.At(dim_same_sbp);\n      }\n      // The SBP parallels are different at dimension (dim_diff_sbp)\n      int32_t dim_diff_sbp = 1 - dim_same_sbp;\n      // If the nd_sbp only has size 1, then make its dimension 0.\n      // Since we have already done this before, we just maintain the value.\n      // Otherwise, switch the dimension to dim_diff_sbp\n      if (producer_sbp_size == 2) { dim_producer = dim_diff_sbp; }\n      if (consumer_sbp_size == 2) { dim_consumer = dim_diff_sbp; }\n      // Splitting at the same dimension needs special care!\n      // Not supported by nccl\n      if (dim_diff_sbp == 0\n          && producer_nd_sbp.sbp_parallel(dim_producer)\n                 != consumer_nd_sbp.sbp_parallel(dim_consumer)\n          && (NdSbpAllSameSplitParallel(producer_nd_sbp)\n              || NdSbpAllSameSplitParallel(consumer_nd_sbp))) {\n        return kUnsupportedBoxing;\n      }\n      return ComputCopyCostBetweenTwoDiffSbpParallel(\n          producer_nd_sbp.sbp_parallel(dim_producer), consumer_nd_sbp.sbp_parallel(dim_consumer),\n          logical_blob_size, hierarchy.At(dim_diff_sbp), on_same_devices);\n    }\n  }\n  return kUnsupportedBoxing;\n}\n\nMaybe<double> ComputeEagerCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel,\n                                               const NdSbp& consumer_sbp_parallel,\n                                               const BlobDesc& logical_blob_desc,\n                                               const ParallelDesc& producer_parallel_desc,\n                                               const ParallelDesc& consumer_parallel_desc,\n                                               bool requires_same_sbp) {\n  if (!(CheckNdSbp(producer_sbp_parallel) && CheckNdSbp(consumer_sbp_parallel))) {\n    return Error::RuntimeError() << \"Illegal sbp parallel has been found.\";\n  }\n\n  // TODO: get copy cost from each EagerBoxingInterpreter\n  if (!TRY(Singleton<EagerBoxingInterpreterManager>::Get()->GetEagerBoxingInterpreter(\n               producer_sbp_parallel, consumer_sbp_parallel, producer_parallel_desc,\n               consumer_parallel_desc, logical_blob_desc.shape()))\n           .IsOk()) {\n    return kUnsupportedBoxing;\n  }\n\n  bool on_same_devices = producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc);\n\n  // Reduce before cost computation\n  Shape reduced_in_hierarchy;\n  NdSbp reduced_in_nd_sbp;\n  Shape reduced_out_hierarchy;\n  NdSbp reduced_out_nd_sbp;\n  InOutParallelDimReduce(*producer_parallel_desc.hierarchy(), *consumer_parallel_desc.hierarchy(),\n                         producer_sbp_parallel, consumer_sbp_parallel, &reduced_in_hierarchy,\n                         &reduced_out_hierarchy, &reduced_in_nd_sbp, &reduced_out_nd_sbp,\n                         logical_blob_desc.shape());\n\n
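  // Example of the reduction: for a logical shape divisible by 4, [2, 2]: (S0, S0) reduces to\n  // [4]: S0 and [2, 2]: (B, B) reduces to [4]: B, so equivalent layouts compare as equal below.\n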
  bool same_nd_sbp = reduced_in_nd_sbp == reduced_out_nd_sbp;\n  // Same sbp is always supported.\n  if (same_nd_sbp && on_same_devices && reduced_in_hierarchy == reduced_out_hierarchy) {\n    return 0.0;\n  }\n  if (requires_same_sbp) { return kUnsupportedBoxing; }\n\n  int32_t in_dim = reduced_in_hierarchy.NumAxes();\n  int32_t out_dim = reduced_out_hierarchy.NumAxes();\n  // We support different hierarchy for 1D sbp\n  if (in_dim == 1 && out_dim == 1) {\n    return ComputCopyCostBetweenTwoSbpParallel(\n        reduced_in_nd_sbp.sbp_parallel(0), reduced_out_nd_sbp.sbp_parallel(0), logical_blob_desc,\n        on_same_devices, reduced_in_hierarchy.elem_cnt(), reduced_out_hierarchy.elem_cnt());\n  }\n\n  double total_cost = 1.0;\n  if (on_same_devices && reduced_in_hierarchy == reduced_out_hierarchy) {\n    // NOTE: After analysis, the transfer cost increases if we split the same dimension.\n    // Example 1: (S(1), S(0), S(1), S(0)) -> (S(0), S(0), S(0), S(0))\n    // Example 2: (B, S(0)) -> (S(0), S(0))\n    // The cost would be (1-1/n)T, where n is the product of the hierarchy values in those\n    // splitting dimensions. To give a more precise cost, we add an upper bound of that lost\n    // cost back for simplification.\n    bool normal_case = true;\n    // nd to nd\n    for (int32_t i = 0; i < in_dim; ++i) {\n      const auto& in_sbp = reduced_in_nd_sbp.sbp_parallel(i);\n      const auto& out_sbp = reduced_out_nd_sbp.sbp_parallel(i);\n      // There is a bug here: (B, S0) -> (S0, S0) will give a cost of 0.\n      // Actually it is (1-1/m)T for hierarchy (n, m)\n      // TODO: Fix this after supporting all sbp combinations for eager.\n      total_cost += JUST(ComputCopyCostBetweenTwoSbpParallel(\n          in_sbp, out_sbp, logical_blob_desc, on_same_devices, reduced_in_hierarchy.elem_cnt(),\n          reduced_out_hierarchy.elem_cnt()));\n      // Add the penalty for P in the consumer\n      if (out_sbp.has_partial_sum_parallel() && (in_sbp != out_sbp)) {\n        total_cost += Penalty4PartialInConsumer(TotalByteSize4BlobDesc(logical_blob_desc),\n                                                producer_parallel_desc.parallel_num(),\n                                                consumer_parallel_desc.parallel_num());\n      }\n      // Detect cases that split the same dimension before this split\n      if (normal_case && in_sbp.has_split_parallel() && in_sbp == out_sbp) {\n        for (int32_t j = 0; j < i; j++) {\n          const auto& in_sbp_j = reduced_in_nd_sbp.sbp_parallel(j);\n          const auto& out_sbp_j = reduced_out_nd_sbp.sbp_parallel(j);\n          // in_sbp == out_sbp in this situation\n          if ((in_sbp_j != out_sbp_j) && (in_sbp_j == in_sbp || out_sbp_j == in_sbp)) {\n            normal_case = false;\n            break;\n          }\n        }\n      }\n    }\n    // Add the cost for the special case\n    if (!normal_case) { total_cost += TotalByteSize4BlobDesc(logical_blob_desc); }\n  } else {\n    double logical_blob_size = TotalByteSize4BlobDesc(logical_blob_desc);\n    {\n      double in_cost = 1.0;\n      for (int32_t i = 0; i < in_dim; ++i) {\n        // P -> ?\n        if (reduced_in_nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) {\n          in_cost *= reduced_in_hierarchy.At(i);\n        }\n      }\n      total_cost += logical_blob_size * in_cost;\n    }\n
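    // Example: hierarchy [2, 3] with nd-sbp (P, P) gives in_cost = 6, i.e. the partial results\n    // of all 6 devices have to be sent out.\n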
    {\n      double out_cost = 1.0;\n      for (int32_t i = 0; i < out_dim; ++i) {\n        // ? -> B\n        if (reduced_out_nd_sbp.sbp_parallel(i).has_broadcast_parallel()) {\n          out_cost *= reduced_out_hierarchy.At(i);\n        }\n        // Add the penalty for P in the consumer\n        if (reduced_out_nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) {\n          total_cost +=\n              Penalty4PartialInConsumer(logical_blob_size, producer_parallel_desc.parallel_num(),\n                                        consumer_parallel_desc.parallel_num());\n        }\n      }\n      total_cost += logical_blob_size * out_cost;\n    }\n  }\n  return total_cost;\n}\n\nusing CopyCostFunc = Maybe<double>(const NdSbp&, const NdSbp&, const BlobDesc&, const ParallelDesc&,\n                                   const ParallelDesc&, bool);\nMaybe<CopyCostFunc*> GetComputeCopyCostFunc() {\n  if (LazyMode::is_enabled()) {\n    return &ComputeCopyCostWithMiddleNodes;\n  } else {\n    return &ComputeEagerCopyCostBetweenNdSbp;\n  }\n}\n\n// Replace the hierarchy and then create a new parallel description\nvoid ReplaceHierarchy4ParallelDesc(const ParallelDesc& old_parallel_desc,\n                                   const Shape& new_hierarchy, ParallelDesc* new_parallel_desc) {\n  if (*old_parallel_desc.hierarchy() == new_hierarchy) {\n    *new_parallel_desc = old_parallel_desc;\n  } else {\n    ParallelConf new_parallel_conf = old_parallel_desc.parallel_conf();\n    new_hierarchy.ToProto(new_parallel_conf.mutable_hierarchy());\n    *new_parallel_desc = ParallelDesc(new_parallel_conf);\n  }\n}\n\n// We cannot simply merge two identical splits\n// For example, shape = [6], we are trying to merge [2, 2]: (S0, S0) -> [4]: S0\n// For each rank, [4]: S0 has number of data: 2, 2, 1, 1\n// For each rank, [2]: S0 has number of data: 3, 3\n// For each rank, [2, 2]: (S0, S0) has number of data: 2, 1, 2, 1\n// Thus {[2, 2]: (S0, S0)} != {[4]: S0} for shape [6]\n// However {[2, 2]: (S0, S0)} == {[4]: S0} for shape [4], [5], [7], [8]\n// More specifically, {[a, b]: (Si, Si)} == {[a*b]: Si} if and only if\n// shape value % (a * b) is 0, 1, or a*b - 1\nbool CanMergeSplit(int32_t shape_value, int32_t merged_split_hierarchy_value) {\n  int32_t remainder = shape_value % merged_split_hierarchy_value;\n  if (remainder <= 1 || remainder == merged_split_hierarchy_value - 1) {\n    return true;\n  } else {\n    return false;\n  }\n}\n\n}  // namespace\n\nint32_t PartialRatio4Producer(const NdSbp& sbp_producer,\n                              const ParallelDesc& producer_parallel_desc) {\n  return Ratio4Sbp(sbp_producer, producer_parallel_desc, &SbpParallel::has_partial_sum_parallel);\n}\n\nint32_t BroadcastRatio4Consumer(const NdSbp& sbp_consumer,\n                                const ParallelDesc& consumer_parallel_desc) {\n  return Ratio4Sbp(sbp_consumer, consumer_parallel_desc, &SbpParallel::has_broadcast_parallel);\n}\n\nvoid NdSbpDimReduce(const Shape& hierarchy, const NdSbp& nd_sbp, Shape* reduced_hierarchy,\n                    NdSbp* reduced_nd_sbp, const Shape& logical_shape) {\n  NdSbpsDimReduce(hierarchy, {&nd_sbp}, reduced_hierarchy, {reduced_nd_sbp}, logical_shape);\n}\n\nvoid NdSbpsDimReduce(const Shape& hierarchy, const std::vector<const NdSbp*>& nd_sbps,\n                     Shape* reduced_hierarchy, const std::vector<NdSbp*>& reduced_nd_sbps,\n                     const Shape& logical_shape) {\n  int32_t sbp_num = nd_sbps.size();\n
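  // All nd_sbps share one hierarchy; a dimension is merged into the previous one only if every\n  // nd_sbp agrees there. Example: hierarchy [1, 4, 2] with (S0, B, B) reduces to [8]: B, while\n  // [2, 2] with (S0, S1) stays [2, 2]: (S0, S1).\n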
  // Speed up for 1d sbp\n  if (hierarchy.NumAxes() == 1) {\n    *reduced_hierarchy = hierarchy;\n    for (int32_t index = 0; index < sbp_num; index++) {\n      if (hierarchy.elem_cnt() == 1) {\n        reduced_nd_sbps[index]->add_sbp_parallel()->mutable_broadcast_parallel();\n      } else {\n        *reduced_nd_sbps[index] = *nd_sbps[index];\n      }\n    }\n    return;\n  }\n  reduced_hierarchy->clear();\n  for (auto& reduced_nd_sbp : reduced_nd_sbps) { reduced_nd_sbp->clear_sbp_parallel(); }\n  // At this moment, if we have [2, 4, 3, 7]: (S0, S1, S0, S0) for logical shape [601, 301, 999]\n  // We hold the split when accessing the current dimension\n  // and defer the true splitting until we reach the next step\n  // dim = 0, split_axis2holding_reduced_shapes: {(0: 601)}, last split axis = -1\n  // dim = 1, split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 301)}, last split axis = 0\n  // dim = 2, split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 75, 76)}, last split axis = 1\n  // dim = 3, at this moment, last split axis (0) == current split axis (0),\n  // dim = 3, but judging 300 % (3 * 7) = 6 fails the CanMergeSplit(), not merging\n  // dim = 3, split_axis2holding_reduced_shapes: {(0: 100, 101), (1: 75, 76)}, last split axis = 0\n  std::vector<HashMap<int32_t, HashSet<int32_t>>> index2split_axis2holding_reduced_shapes(sbp_num);\n  std::vector<std::vector<int32_t>> index2last_holding_reduced_shapes(sbp_num);\n  std::vector<int32_t> last_split_axises(sbp_num, -1);\n  std::vector<int32_t> indexes(sbp_num);\n  for (int32_t index = 0; index < sbp_num; index++) { indexes[index] = index; }\n  auto add_to_reduced_sbp_hierarchy = [&](int32_t hierarchy_dim) {\n    // Clear the last holding split axis\n    for (int32_t index = 0; index < sbp_num; index++) {\n      auto& split_axis2holding_reduced_shapes = index2split_axis2holding_reduced_shapes[index];\n      auto& last_holding_reduced_shapes = index2last_holding_reduced_shapes[index];\n      auto& last_split_axis = last_split_axises[index];\n      auto& nd_sbp = nd_sbps[index];\n      auto& reduced_nd_sbp = reduced_nd_sbps[index];\n      if (last_split_axis >= 0) {\n        auto& holding_reduced_shapes = split_axis2holding_reduced_shapes[last_split_axis];\n        holding_reduced_shapes.clear();\n        for (int32_t last_holding_reduced_shape : last_holding_reduced_shapes) {\n          int32_t quotient = last_holding_reduced_shape / reduced_hierarchy->back();\n          if (last_holding_reduced_shape % reduced_hierarchy->back() != 0) {\n            holding_reduced_shapes.insert(quotient + 1);\n          }\n          holding_reduced_shapes.insert(quotient);\n        }\n      }\n      // Add a new sbp_parallel and a new hierarchy dimension\n      const auto& curr_sbp_parallel = nd_sbp->sbp_parallel(hierarchy_dim);\n      *reduced_nd_sbp->add_sbp_parallel() = curr_sbp_parallel;\n      // Hold the current split shape\n      if (curr_sbp_parallel.has_split_parallel()) {\n        last_holding_reduced_shapes.clear();\n        last_split_axis = curr_sbp_parallel.split_parallel().axis();\n        auto it = split_axis2holding_reduced_shapes.find(last_split_axis);\n        if (it == split_axis2holding_reduced_shapes.end()) {\n          // Looking at a dimension which has never been split before\n          // Shape: [601, ...], sbp: (S0, ...)\n          last_holding_reduced_shapes.push_back(logical_shape.At(last_split_axis));\n        } else {\n          // This dimension has been split before\n          // Shape: [601, 301, ...], sbp: (S0, S1, B, S0, ...), hierarchy: [2, 3, 100, 7, ...]\n          // Looking at i = 3, we hold the second S0, but 601 has already been split by the first S0.\n          // 
split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 100, 101)}\n          last_holding_reduced_shapes.assign(it->second.begin(), it->second.end());\n        }\n      } else {\n        last_split_axis = -1;\n      }\n    }\n    // Add a new hierarchy dimension\n    reduced_hierarchy->emplace_back(hierarchy.At(hierarchy_dim));\n  };\n  for (int32_t hierarchy_dim = 0; hierarchy_dim < hierarchy.NumAxes(); hierarchy_dim++) {\n    // Drop those dimensions with hierarchy value 1\n    if (hierarchy.At(hierarchy_dim) == 1) { continue; }\n    if (reduced_hierarchy->empty()) {\n      // Empty hierarchy, add to the back\n      add_to_reduced_sbp_hierarchy(hierarchy_dim);\n      continue;\n    }\n    if (std::all_of(indexes.begin(), indexes.end(), [&](int32_t index) {\n          // reduced_hierarchy->size() == reduced_nd_sbps[index]->sbp_parallel_size()\n          // Basically, current nd sbp == reduced nd sbp.back()\n          return nd_sbps[index]->sbp_parallel(hierarchy_dim)\n                 == reduced_nd_sbps[index]->sbp_parallel(reduced_hierarchy->size() - 1);\n        })) {\n      int32_t merged_hierarchy_value = reduced_hierarchy->back() * hierarchy.At(hierarchy_dim);\n      // Two dimensions with the same sbp can always be merged for B or P.\n      // If sbp = S, we need to make sure that all the held shape values can still be split\n      if (std::all_of(indexes.begin(), indexes.end(), [&](int32_t index) {\n            return !nd_sbps[index]->sbp_parallel(hierarchy_dim).has_split_parallel()\n                   || std::all_of(index2last_holding_reduced_shapes[index].begin(),\n                                  index2last_holding_reduced_shapes[index].end(), [&](int32_t i) {\n                                    return CanMergeSplit(i, merged_hierarchy_value);\n                                  });\n          })) {\n        // Merge sbp and hierarchy\n        reduced_hierarchy->back() = merged_hierarchy_value;\n        continue;\n      }\n    }\n    // Cannot merge, add to the back\n    add_to_reduced_sbp_hierarchy(hierarchy_dim);\n  }\n  // [1, 1, ..., 1]: Any --> [1]: (B)\n  if (reduced_hierarchy->empty()) {\n    reduced_hierarchy->emplace_back(hierarchy.At(0));\n    for (auto& reduced_nd_sbp : reduced_nd_sbps) {\n      reduced_nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel();\n    }\n  }\n}\n\nvoid NdSbpDimReduce(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp,\n                    ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp,\n                    const Shape& logical_shape) {\n  // Speed up for 1d sbp\n  if (parallel_desc.hierarchy()->NumAxes() == 1) {\n    *reduced_parallel_desc = parallel_desc;\n    *reduced_nd_sbp = nd_sbp;\n    return;\n  }\n  Shape reduced_hierarchy;\n  NdSbpDimReduce(*parallel_desc.hierarchy(), nd_sbp, &reduced_hierarchy, reduced_nd_sbp,\n                 logical_shape);\n\n  ReplaceHierarchy4ParallelDesc(parallel_desc, reduced_hierarchy, reduced_parallel_desc);\n}\n\nvoid InOutParallelDimReduce(const Shape& in_hierarchy, const Shape& out_hierarchy,\n                            const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,\n                            Shape* reduced_in_hierarchy, Shape* reduced_out_hierarchy,\n                            NdSbp* reduced_in_nd_sbp, NdSbp* reduced_out_nd_sbp,\n                            const Shape& logical_shape) {\n  if (in_hierarchy == out_hierarchy) {\n    // [2, 4]: (S0, S0) -> [2, 4]: (S0, S1)\n    NdSbpsDimReduce(in_hierarchy, {&in_nd_sbp, &out_nd_sbp}, reduced_in_hierarchy,\n                    {reduced_in_nd_sbp, 
reduced_out_nd_sbp}, logical_shape);\n    *reduced_out_hierarchy = *reduced_in_hierarchy;\n  } else {\n    // [2, 4]: (S0, S0) -> [4, 2]: (S0, S1)\n    // [2, 4]: (S0, S0) -> [3, 3]: (S0, S1)\n    NdSbpDimReduce(in_hierarchy, in_nd_sbp, reduced_in_hierarchy, reduced_in_nd_sbp, logical_shape);\n    NdSbpDimReduce(out_hierarchy, out_nd_sbp, reduced_out_hierarchy, reduced_out_nd_sbp,\n                   logical_shape);\n\n    // Sbp of 3d or higher dimension would use general basic communication\n    // Only looks at 1d to 2d or 2d to 1d\n    if (reduced_in_hierarchy->NumAxes() + reduced_out_hierarchy->NumAxes() == 3\n        && reduced_in_hierarchy->elem_cnt() == reduced_out_hierarchy->elem_cnt()) {\n      if (reduced_in_hierarchy->NumAxes() == 1) {\n        // [8]: S0 -> [4, 2]: (S0, S1)\n        // [8]: B -> [2, 4]: (S0, S1)\n        const auto& in_sbp_parallel = reduced_in_nd_sbp->sbp_parallel(0);\n        if (!in_sbp_parallel.has_split_parallel()\n            || CanMergeSplit(logical_shape.At(in_sbp_parallel.split_parallel().axis()),\n                             reduced_in_hierarchy->elem_cnt())) {\n          // Change [8]: S0 -> [4, 2]: (S0, S1) to [4, 2]: (S0, S0) -> [4, 2]: (S0, S1)\n          // Change [8]: B -> [2, 4]: (S0, S1) to [2, 4]: (B, B) -> [2, 4]: (S0, S1)\n          *reduced_in_nd_sbp->add_sbp_parallel() = in_sbp_parallel;\n          *reduced_in_hierarchy = *reduced_out_hierarchy;\n        }\n      } else {\n        // [2, 3]: (S0, P) -> [6]: S0\n        // [3, 4]: (B, S1) -> [12]: B\n        const auto& out_sbp_parallel = reduced_out_nd_sbp->sbp_parallel(0);\n        if (!out_sbp_parallel.has_split_parallel()\n            || CanMergeSplit(logical_shape.At(out_sbp_parallel.split_parallel().axis()),\n                             reduced_out_hierarchy->elem_cnt())) {\n          // Change [2, 3]: (S0, P) -> [6]: S0 to [2, 3]: (S0, P) -> [2, 3]: (S0, S0)\n          // Change [3, 4]: (B, S1) -> [12]: B to [3, 4]: (B, S1) -> [3, 4]: (B, B)\n          *reduced_out_nd_sbp->add_sbp_parallel() = out_sbp_parallel;\n          *reduced_out_hierarchy = *reduced_in_hierarchy;\n        }\n      }\n    }\n  }\n}\n\nvoid InOutParallelDimReduce(const ParallelDesc& in_parallel_desc,\n                            const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp,\n                            const NdSbp& out_nd_sbp, ParallelDesc* reduced_in_parallel_desc,\n                            ParallelDesc* reduced_out_parallel_desc, NdSbp* reduced_in_nd_sbp,\n                            NdSbp* reduced_out_nd_sbp, const Shape& logical_shape) {\n  // Speed up for 1d sbp\n  if (in_parallel_desc.hierarchy()->NumAxes() == 1\n      && out_parallel_desc.hierarchy()->NumAxes() == 1) {\n    *reduced_in_parallel_desc = in_parallel_desc;\n    *reduced_out_parallel_desc = out_parallel_desc;\n    *reduced_in_nd_sbp = in_nd_sbp;\n    *reduced_out_nd_sbp = out_nd_sbp;\n  } else {\n    Shape reduced_in_hierarchy;\n    Shape reduced_out_hierarchy;\n    InOutParallelDimReduce(*in_parallel_desc.hierarchy(), *out_parallel_desc.hierarchy(), in_nd_sbp,\n                           out_nd_sbp, &reduced_in_hierarchy, &reduced_out_hierarchy,\n                           reduced_in_nd_sbp, reduced_out_nd_sbp, logical_shape);\n    ReplaceHierarchy4ParallelDesc(in_parallel_desc, reduced_in_hierarchy, reduced_in_parallel_desc);\n    ReplaceHierarchy4ParallelDesc(out_parallel_desc, reduced_out_hierarchy,\n                                  reduced_out_parallel_desc);\n  }\n}\n\nint64_t TotalByteSize4BlobDesc(const 
BlobDesc& logical_blob_desc) {\n  return logical_blob_desc.shape().elem_cnt() * GetSizeOfDataType(logical_blob_desc.data_type());\n}\n\nint64_t MaxByteSize4BlobDescSbp(const BlobDesc& logical_blob_desc, const NdSbp& nd_sbp,\n                                const Shape& hierarchy) {\n  Shape blob_shape = logical_blob_desc.shape();\n  for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) {\n    const auto& sbp = nd_sbp.sbp_parallel(sbp_id);\n    if (sbp.has_split_parallel()) {\n      int32_t split_axis = sbp.split_parallel().axis();\n      blob_shape.Set(split_axis, CeilQuotient(blob_shape.At(split_axis), hierarchy.At(sbp_id)));\n    }\n  }\n  return blob_shape.elem_cnt() * GetSizeOfDataType(logical_blob_desc.data_type());\n}\n\nMaybe<double> ComputeLazyCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel,\n                                              const NdSbp& consumer_sbp_parallel,\n                                              const BlobDesc& logical_blob_desc,\n                                              const ParallelDesc& producer_parallel_desc,\n                                              const ParallelDesc& consumer_parallel_desc,\n                                              bool requires_same_sbp) {\n  if (!(CheckNdSbp(producer_sbp_parallel) && CheckNdSbp(consumer_sbp_parallel))) {\n    return Error::RuntimeError() << \"Illegal sbp parallel has been found.\";\n  }\n  bool on_same_devices = producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc);\n\n  // Reduce before cost computation\n  Shape reduced_in_hierarchy;\n  NdSbp reduced_in_nd_sbp;\n  Shape reduced_out_hierarchy;\n  NdSbp reduced_out_nd_sbp;\n  InOutParallelDimReduce(*producer_parallel_desc.hierarchy(), *consumer_parallel_desc.hierarchy(),\n                         producer_sbp_parallel, consumer_sbp_parallel, &reduced_in_hierarchy,\n                         &reduced_out_hierarchy, &reduced_in_nd_sbp, &reduced_out_nd_sbp,\n                         logical_blob_desc.shape());\n  int32_t in_dim = reduced_in_hierarchy.NumAxes();\n  int32_t out_dim = reduced_out_hierarchy.NumAxes();\n  // Not supporting n-D sbp with n >= 3\n  // TODO: Support it in the future\n  if (std::min(in_dim, out_dim) <= 0 || std::max(in_dim, out_dim) >= 3) {\n    return kUnsupportedBoxing;\n  }\n\n  bool same_nd_sbp = reduced_in_nd_sbp == reduced_out_nd_sbp;\n  // Same sbp is always supported.\n  if (same_nd_sbp && on_same_devices && reduced_in_hierarchy == reduced_out_hierarchy) {\n    return 0.0;\n  }\n  if (requires_same_sbp) { return kUnsupportedBoxing; }\n\n  // We support different hierarchy for 1D sbp\n  if (in_dim == 1 && out_dim == 1) {\n    return GetTransferCost()\n           + JUST(ComputCopyCostBetweenTwoSbpParallel(\n               reduced_in_nd_sbp.sbp_parallel(0), reduced_out_nd_sbp.sbp_parallel(0),\n               logical_blob_desc, on_same_devices, reduced_in_hierarchy.elem_cnt(),\n               reduced_out_hierarchy.elem_cnt()));\n  }\n\n#ifdef WITH_CUDA\n  static const bool enable_general_basic_communication =\n      ParseBooleanFromEnv(\"ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION\", false);\n  // Use a general basic communication if no P in the consumer\n  if ((((Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()\n         && producer_parallel_desc == consumer_parallel_desc)\n        || enable_general_basic_communication)\n       && !NdSbpHasPartialParallel(consumer_sbp_parallel))\n      && producer_parallel_desc.device_type() == DeviceType::kCUDA\n      && 
consumer_parallel_desc.device_type() == DeviceType::kCUDA) {\n    return Cost4GeneralBasicCommunication(producer_sbp_parallel, consumer_sbp_parallel,\n                                          logical_blob_desc, producer_parallel_desc,\n                                          consumer_parallel_desc)\n           + GetTransferCost();\n  }\n#endif  // WITH_CUDA\n\n  // Not supporting different hierarchy without general basic communication\n  if (reduced_in_hierarchy.elem_cnt() != reduced_out_hierarchy.elem_cnt()) {\n    return kUnsupportedBoxing;\n  }\n\n  double logical_blob_size = TotalByteSize4BlobDesc(logical_blob_desc);\n\n  if (in_dim == 2 && out_dim == 2) {\n    // Not supporting different hierarchy\n    // TODO: Support it in the future\n    if (reduced_in_hierarchy != reduced_out_hierarchy) { return kUnsupportedBoxing; }\n    return GetTransferCost()\n           + JUST(ComputCopyCostBetweenTwoNdSbp(reduced_in_nd_sbp, reduced_out_nd_sbp,\n                                                logical_blob_size, reduced_in_hierarchy,\n                                                on_same_devices));\n  }\n\n  // (in_dim == 2 && out_dim == 1) || (in_dim == 1 && out_dim == 2)\n  if (in_dim == 2 && out_dim == 1) {\n    return GetTransferCost()\n           + JUST(ComputCopyCostBetweenTwoNdSbp(reduced_in_nd_sbp, reduced_out_nd_sbp,\n                                                logical_blob_size, reduced_in_hierarchy,\n                                                on_same_devices));\n  }\n\n  if (in_dim == 1 && out_dim == 2) {\n    return GetTransferCost()\n           + JUST(ComputCopyCostBetweenTwoNdSbp(reduced_in_nd_sbp, reduced_out_nd_sbp,\n                                                logical_blob_size, reduced_out_hierarchy,\n                                                on_same_devices));\n  }\n\n  return Error::RuntimeError()\n         << \"Should not reach here. 
Something went wrong in ComputeLazyCopyCostBetweenNdSbp() in \"\n            \"sbp_infer_util.cpp.\";\n}\n\ndouble GetValidMaxCopyCost() {\n  // We suppose that the valid copy cost range is [0, FloatMax*0.8]\n  static const double kValidMaxCopyCost = kUnsupportedBoxing * 0.8;\n  return kValidMaxCopyCost;\n}\n\ndouble GetTransferCost() {\n  // Every transfer has a cost, except for identical parallel description and sbp\n  static const double kTransferCost = ParseFloatFromEnv(\"AUTO_PARALLEL_TRANSFER_COST\", 1.65e4);\n  return kTransferCost;\n}\n\nvoid ResizeNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t size) {\n  for (auto& pair : *nd_sbp_sig.mutable_bn_in_op2nd_sbp()) {\n    if (pair.second.sbp_parallel_size() > size) { pair.second.clear_sbp_parallel(); }\n    while (pair.second.sbp_parallel_size() < size) { pair.second.add_sbp_parallel(); }\n  }\n}\n\nvoid SetNdSbpSignature(NdSbpSignature* nd_sbp_signature, const SbpSignature& sbp_signature,\n                       int32_t sbp_axis) {\n  for (const auto& pair : sbp_signature.bn_in_op2sbp_parallel()) {\n    *((*nd_sbp_signature->mutable_bn_in_op2nd_sbp())[pair.first].mutable_sbp_parallel(sbp_axis)) =\n        pair.second;\n  }\n}\n\nvoid DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dims,\n                          const Shape& hierarchy,\n                          const HashMap<int32_t, SbpSignatureList>& hierarchy_value2sbp_sig_list,\n                          std::vector<NdSbpSignature>* nd_sbp_sig_list) {\n  if (depth == dims) {\n    nd_sbp_sig_list->push_back(nd_sbp_sig);\n  } else {\n    for (const auto& sbp_signature :\n         hierarchy_value2sbp_sig_list.at(hierarchy.At(depth)).sbp_signature()) {\n      SetNdSbpSignature(&nd_sbp_sig, sbp_signature, depth);\n      DfsGetNdSbpSignature(nd_sbp_sig, depth + 1, dims, hierarchy, hierarchy_value2sbp_sig_list,\n                           nd_sbp_sig_list);\n    }\n  }\n}\n\nnamespace {\n\n// Give a measure value to an NdSbp for sorting\nsize_t MesureNdSbp(const NdSbp& nd_sbp) {\n  // Digits start from 1: B (1), P (2), plus up to 8 split axes (3 to 10)\n  constexpr size_t kMaxSplitAxis = 8;\n  constexpr size_t kCarryDigit = kMaxSplitAxis + 3;\n  size_t value = 0;\n  for (int i = 0; i < nd_sbp.sbp_parallel_size(); ++i) {\n    size_t cur_dim_value = 0;\n    const auto& sbp = nd_sbp.sbp_parallel(i);\n    if (sbp.has_broadcast_parallel()) {\n      cur_dim_value = 1;\n    } else if (sbp.has_partial_sum_parallel()) {\n      cur_dim_value = 2;\n    } else if (sbp.has_split_parallel()) {\n      CHECK_LT(sbp.split_parallel().axis(), kMaxSplitAxis);\n      // from 3 to 10\n      cur_dim_value = 3 + sbp.split_parallel().axis();\n    } else {\n      UNIMPLEMENTED();\n    }\n    value = value * kCarryDigit + cur_dim_value;\n  }\n  return value;\n}\n\nsize_t MesureNdSbpSignature(const NdSbpSignature& nd_sbp_sig, const std::vector<std::string>& bns) {\n  // Big enough for the 2d-sbp signature set;\n  // if we want to extend to 3d-sbp, consider increasing it to 170\n  constexpr size_t kCarryDigit = 97;\n  size_t value = 0;\n  for (size_t i = 0; i < bns.size(); ++i) {\n    auto nd_sbp_it = nd_sbp_sig.bn_in_op2nd_sbp().find(bns[i]);\n    CHECK(nd_sbp_it != nd_sbp_sig.bn_in_op2nd_sbp().end())\n        << \"can't find bn (\" << bns[i] << \") in \" << PbMessage2TxtString(nd_sbp_sig);\n    size_t cur_arg_value = MesureNdSbp(nd_sbp_it->second);\n    CHECK_LE(value + cur_arg_value / kCarryDigit, std::numeric_limits<size_t>::max() / kCarryDigit);\n    value = value * kCarryDigit + cur_arg_value;\n  }\n  return value;\n}\n\n}  // namespace\n
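\n// Deduplication relies on MesureNdSbpSignature: each blob's nd-sbp is encoded in base 11 per\n// dimension (B = 1, P = 2, S(i) = 3 + i; e.g. (S0, B) -> (3 + 0) * 11 + 1 = 34), and the blobs\n// are then combined in base 97. Equal encodings mean equal signatures, and the std::map order\n// below makes the deduplicated output deterministic.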
\nvoid DeduplicateNdSbpSignatureList(std::vector<NdSbpSignature>* nd_sbp_sig_list,\n                                   const std::vector<std::string>& bns) {\n  if (bns.size() > 8) { return; }\n  std::map<size_t, NdSbpSignature> value2nd_sbp_sig;\n  for (auto& nd_sbp_sig : *nd_sbp_sig_list) {\n    size_t order_value = MesureNdSbpSignature(nd_sbp_sig, bns);\n    if (value2nd_sbp_sig.find(order_value) == value2nd_sbp_sig.end()) {\n      value2nd_sbp_sig.emplace(order_value, std::move(nd_sbp_sig));\n    }\n  }\n  nd_sbp_sig_list->clear();\n  for (auto& nd_sbp_pair : value2nd_sbp_sig) {\n    nd_sbp_sig_list->emplace_back(std::move(nd_sbp_pair.second));\n  }\n}\n\n// Compute storage per device for given NdSbp\ndouble Storage4NdSbp(const NdSbp& nd_sbp, Shape& logical_shape, const Shape& parallel_hierarchy) {\n  if (nd_sbp.sbp_parallel_size() == 1) {\n    double logical_blob_size = logical_shape.elem_cnt();\n    // Checking 1D sbp\n    const auto& sbp_parallel = nd_sbp.sbp_parallel(0);\n    if (sbp_parallel.has_split_parallel()) {\n      const int64_t axis = sbp_parallel.split_parallel().axis();\n      if (axis >= logical_shape.NumAxes()) { return kUnsupportedBoxing; }\n      if (logical_shape.At(axis) < parallel_hierarchy.At(0)) { return kUnsupportedBoxing; }\n      logical_blob_size /= parallel_hierarchy.At(0);\n    }\n    return logical_blob_size;\n  } else {\n    for (int32_t dim_sbp = 0; dim_sbp < nd_sbp.sbp_parallel_size(); ++dim_sbp) {\n      const auto& sbp_parallel = nd_sbp.sbp_parallel(dim_sbp);\n      if (sbp_parallel.has_split_parallel()) {\n        // Split axis and store result back to logical shape\n        const int64_t axis = sbp_parallel.split_parallel().axis();\n        if (axis >= logical_shape.NumAxes()) { return kUnsupportedBoxing; }\n        // Assume a perfectly even split when counting the storage\n        if (logical_shape.At(axis) < parallel_hierarchy.At(dim_sbp)) { return kUnsupportedBoxing; }\n        logical_shape.Set(axis, logical_shape.At(axis) / parallel_hierarchy.At(dim_sbp));\n      }\n    }\n    return logical_shape.elem_cnt();\n  }\n}\n\n// Judge whether an NdSbp could be applied on a tensor with given logical shape\n// True means this NdSbp is not valid.\nMaybe<bool> FilterNdSbpByLogicalShape(const NdSbp& nd_sbp, Shape& logical_shape,\n                                      const Shape& parallel_hierarchy) {\n  return Storage4NdSbp(nd_sbp, logical_shape, parallel_hierarchy) > GetValidMaxCopyCost();\n}\n\nMaybe<double> ComputeCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel,\n                                          const NdSbp& consumer_sbp_parallel,\n                                          const BlobDesc& logical_blob_desc,\n                                          const ParallelDesc& producer_parallel_desc,\n                                          const ParallelDesc& consumer_parallel_desc,\n                                          bool requires_same_sbp) {\n  return JUST(GetComputeCopyCostFunc())(producer_sbp_parallel, consumer_sbp_parallel,\n                                        logical_blob_desc, producer_parallel_desc,\n                                        consumer_parallel_desc, requires_same_sbp);\n}\n\nMaybe<double> ComputeCopyCostWithMiddleNodes(const NdSbp& producer_sbp_parallel,\n                                             const NdSbp& consumer_sbp_parallel,\n                                             const BlobDesc& logical_blob_desc,\n                                             const ParallelDesc& 
producer_parallel_desc,\n                                             const ParallelDesc& consumer_parallel_desc,\n                                             bool requires_same_sbp) {\n  // In 90% of the transfers, the producer and the consumer have the same parallel description.\n  // We need to speed this case up and give an approximation of the cost\n  if (producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc)) {\n    // [2, 2]: (S0, S1) -> [2, 2]: (S0, S1)\n    if (*producer_parallel_desc.hierarchy() == *consumer_parallel_desc.hierarchy()\n        && producer_sbp_parallel == consumer_sbp_parallel) {\n      return 0.0;\n    }\n    // Reduce before cost computation\n    Shape reduced_in_hierarchy;\n    NdSbp reduced_in_nd_sbp;\n    Shape reduced_out_hierarchy;\n    NdSbp reduced_out_nd_sbp;\n    InOutParallelDimReduce(*producer_parallel_desc.hierarchy(), *consumer_parallel_desc.hierarchy(),\n                           producer_sbp_parallel, consumer_sbp_parallel, &reduced_in_hierarchy,\n                           &reduced_out_hierarchy, &reduced_in_nd_sbp, &reduced_out_nd_sbp,\n                           logical_blob_desc.shape());\n\n    // [2, 2]: (B, B) -> [4]: B\n    if (reduced_in_hierarchy == reduced_out_hierarchy && reduced_in_nd_sbp == reduced_out_nd_sbp) {\n      return 1.0;\n    }\n  }\n  if (requires_same_sbp) { return kUnsupportedBoxing; }\n#ifdef WITH_CUDA\n  static const bool enable_general_basic_communication =\n      ParseBooleanFromEnv(\"ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION\", false);\n  // Use a general basic communication if no P in the consumer\n  if ((((Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()\n         && producer_parallel_desc == consumer_parallel_desc)\n        || enable_general_basic_communication)\n       && !NdSbpHasPartialParallel(consumer_sbp_parallel))\n      && producer_parallel_desc.device_type() == DeviceType::kCUDA\n      && consumer_parallel_desc.device_type() == DeviceType::kCUDA) {\n    return Cost4GeneralBasicCommunication(producer_sbp_parallel, consumer_sbp_parallel,\n                                          logical_blob_desc, producer_parallel_desc,\n                                          consumer_parallel_desc)\n           + GetTransferCost();\n  }\n#endif  // WITH_CUDA\n\n  // Initialize boxing collector\n  constexpr int32_t kRegularMaxSplitAxes = 6;\n  static thread_local BoxingCollector boxing_collector(kRegularMaxSplitAxes);\n  std::vector<NdSbp> middle_sbps;\n  // Ask for middle nodes\n  int32_t diag_node = 0;\n  JUST(boxing_collector.AskSbpCombination(\n      producer_sbp_parallel, consumer_sbp_parallel, logical_blob_desc, producer_parallel_desc,\n      consumer_parallel_desc, /*is_customized=*/false, middle_sbps, &diag_node,\n      /*compute_cost=*/true));\n  // Parameters\n  double total_cost = 0.0;\n  // Set up the information of the first node in the first connection\n  const NdSbp* pre_nd_sbp = &producer_sbp_parallel;\n  const ParallelDesc* pre_parallel_desc = &producer_parallel_desc;\n  const ParallelDesc* middle_parallel_desc = nullptr;\n
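  // Illustration: with middle nodes [m1, m2], the transfer chain is\n  //   producer -> m1 -> m2 -> consumer,\n  // and total_cost accumulates the three pairwise lazy copy costs.\n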
  // Connection for the next middle node\n  for (int32_t middle_node_id = 0; middle_node_id < middle_sbps.size(); middle_node_id++) {\n    const auto& middle_sbp = middle_sbps[middle_node_id];\n    if (middle_node_id < diag_node) {\n      middle_parallel_desc = &producer_parallel_desc;\n    } else {\n      middle_parallel_desc = &consumer_parallel_desc;\n    }\n    // Middle nodes before the diagonal node use the parallel description of the producer and\n    // the rest use that of the consumer, following the same procedure in\n    // boxing_with_middle_nodes.cpp\n    // TODO: Needs more effort if dealing with different placement\n    total_cost += JUST(ComputeLazyCopyCostBetweenNdSbp(*pre_nd_sbp, middle_sbp, logical_blob_desc,\n                                                       *pre_parallel_desc, *middle_parallel_desc,\n                                                       requires_same_sbp));\n    // Set up the information of the first node in the next connection\n    pre_nd_sbp = &middle_sbp;\n    pre_parallel_desc = middle_parallel_desc;\n  }\n  // Connection between the last middle node and consumer\n  total_cost += JUST(ComputeLazyCopyCostBetweenNdSbp(*pre_nd_sbp, consumer_sbp_parallel,\n                                                     logical_blob_desc, *pre_parallel_desc,\n                                                     consumer_parallel_desc, requires_same_sbp));\n\n  return total_cost;\n}\n\n// Decide the priority to infer sbp\ndouble ComputeSbpInferPriority(const NdSbp& producer_nd_sbp, const NdSbp& consumer_nd_sbp,\n                               const ParallelDesc& producer_parallel_desc,\n                               const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp,\n                               const Shape& logical_shape) {\n  if (producer_nd_sbp == consumer_nd_sbp && producer_parallel_desc == consumer_parallel_desc) {\n    // Highest priority: this blob has the same placement and sbp on both the producer and\n    // consumer\n    return 0.0;\n  }\n  // Reduce before cost computation\n  Shape reduced_in_hierarchy;\n  NdSbp reduced_in_nd_sbp;\n  Shape reduced_out_hierarchy;\n  NdSbp reduced_out_nd_sbp;\n  InOutParallelDimReduce(*producer_parallel_desc.hierarchy(), *consumer_parallel_desc.hierarchy(),\n                         producer_nd_sbp, consumer_nd_sbp, &reduced_in_hierarchy,\n                         &reduced_out_hierarchy, &reduced_in_nd_sbp, &reduced_out_nd_sbp,\n                         logical_shape);\n\n  if (requires_same_sbp) {\n    // This blob does not support boxing\n    if (reduced_in_nd_sbp == reduced_out_nd_sbp && reduced_in_hierarchy == reduced_out_hierarchy\n        && producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc)) {\n      // Normal priority: No transfer occurs but we have different sbp\n      // For example: [1]:S0 -> [1]:B\n      // [1, 2]:(P, S0) -> [1, 2]:(S0, S0)\n      return 1.0;\n    } else {\n      // Penalty: this blob has different placements and sbps but does not support boxing\n      return 2.0;\n    }\n  } else {\n    // This blob supports boxing\n    if (producer_nd_sbp.sbp_parallel_size() == consumer_nd_sbp.sbp_parallel_size()) {\n      if (producer_nd_sbp == consumer_nd_sbp) {\n        // Highest priority: this blob has the same sbp on both the producer and consumer\n        // Not just [0-3] -> [4-7], but also cpu:[0] -> cuda:[0-3]\n        return 0.0;\n      }\n    } else {\n      if (reduced_in_nd_sbp == reduced_out_nd_sbp) {\n        // Highest priority: this blob has the same sbp on both the producer and consumer\n        // [2, 2]: (S0, S0) -> [2]: S0\n        // (learning rate) [1]: B -> [2, 2]: (B, B)\n        return 0.0;\n      }\n    }\n    // Normal priority: a transfer might occur\n    // Or might not: [1, 2]: (P, S0) -> [1, 2]: (B, S0)\n    // No transfer but not highest priority\n    return 1.0;\n  }\n}\n\n// The transfer ratio for general basic communication\n// Cost = ratio * data amount\n// When we reach this function, either producer_sbp_parallel != consumer_sbp_parallel\n// or producer_parallel_desc != consumer_parallel_desc\n
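// Example: (P, S0) -> (B, S0) on one [2, 2] hierarchy gives producer_partial_ratio = 2 and\n// consumer_broadcast_ratio = 2; the loop below finds the matching S0 column, so\n// intersection_ratio = 2 and the cost is (2 + 2 - 2) * T for a T-byte blob.\n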
double Cost4GeneralBasicCommunication(const NdSbp& producer_sbp_parallel,\n                                      const NdSbp& consumer_sbp_parallel,\n                                      const BlobDesc& logical_blob_desc,\n                                      const ParallelDesc& producer_parallel_desc,\n                                      const ParallelDesc& consumer_parallel_desc) {\n  // The upper bound of the amount of the transferred data\n  int32_t producer_partial_ratio =\n      PartialRatio4Producer(producer_sbp_parallel, producer_parallel_desc);\n  int32_t consumer_broadcast_ratio =\n      BroadcastRatio4Consumer(consumer_sbp_parallel, consumer_parallel_desc);\n  // More intersection on the same devices\n  bool on_same_devices = producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc);\n  // approximate intersection ratio\n  double intersection_ratio = 1.0;\n  // (?, P, ?)->(Si, Sj)->(?, B, ?), two-step transfer\n  if (producer_partial_ratio > 1 && consumer_broadcast_ratio > 1) {\n    if (on_same_devices) {\n      // Pure P in the producer or B in the consumer\n      // (P, P, P) -> ? or ? -> (B, B)\n      if (producer_partial_ratio == producer_parallel_desc.parallel_num()\n          || consumer_broadcast_ratio == consumer_parallel_desc.parallel_num()) {\n        // There are some cases to which this ratio does not apply;\n        // we just take the most likely one\n        // For example: (P, S0) -> (B, B) for 1-D blob with machine hierarchy [n, m]\n        // The path should be (P, S0) -> (S0, S0) -> (B, B)\n        // true intersection ratio = 1/m + 1\n        intersection_ratio = 2.0;\n      } else {\n        // sbp_consumer = (B, Si) or (Si, B)\n        for (int32_t sbp_id = 0; sbp_id < std::min(producer_sbp_parallel.sbp_parallel_size(),\n                                                   consumer_sbp_parallel.sbp_parallel_size());\n             sbp_id++) {\n          if (consumer_sbp_parallel.sbp_parallel(sbp_id).has_split_parallel()) {\n            const auto& producer_sbp4sbp_id = producer_sbp_parallel.sbp_parallel(sbp_id);\n            // (B, P) or (Si, P) -> (Si, B)\n            // (P, B) or (P, Si) -> (B, Si)\n            if (producer_sbp4sbp_id.has_broadcast_parallel()\n                || producer_sbp4sbp_id == consumer_sbp_parallel.sbp_parallel(sbp_id)) {\n              intersection_ratio = 2.0;\n              break;\n            }\n          }\n        }\n        // Check whether the intersection ratio has been assigned the value 2.0\n        if (intersection_ratio == 1.0) {\n          // The true intersection ratio ranges from 0 to 2;\n          // we just take a middle point of the range as the approximation\n          // For example: (P, S0) -> (S0, B), Path: (P, S0) -> (S1, S0) -> (S0, B)\n          // true intersection ratio = 1 + 1/m\n          // For example: (P, S0) -> (S1, B), Path: (P, S0) -> (S1, S0) -> (S1, B)\n          // true intersection ratio = 1 + 1\n          // For example: (P, S0) -> (B, S0), with a 1D blob\n          // true intersection ratio = (n+p-1)/nm + (n+p-1)/nm\n          // For example: (S0, P) -> (B, S0), Path: (S0, P) -> (S0, S1) -> (B, S0)\n          // true intersection ratio = 1 + 1/n\n\n          // We use the approximation 1 + (1/n + 1/m)/2\n          intersection_ratio = 1.0 + 0.5 / producer_parallel_desc.hierarchy()->At(0)\n                               + 0.5 / producer_parallel_desc.hierarchy()->At(1);\n        }\n
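        // e.g. hierarchy [2, 4]: intersection_ratio = 1 + 0.5 / 2 + 0.5 / 4 = 1.375\n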
      }\n    }\n    // Otherwise, on different devices\n    // intersection_ratio = 1.0;\n  } else {\n    // No P in the producer or no B in the consumer, one-step transfer\n    if (on_same_devices) {\n      // We use simulation for nD sbp with n=1,2,3,...\n      TensorSliceView in_second_slice =\n          GetTensorSliceView4ParallelId(*producer_parallel_desc.hierarchy(), producer_sbp_parallel,\n                                        logical_blob_desc.shape(), /*parallel_id=*/1);\n      TensorSliceView out_second_slice =\n          GetTensorSliceView4ParallelId(*consumer_parallel_desc.hierarchy(), consumer_sbp_parallel,\n                                        logical_blob_desc.shape(), /*parallel_id=*/1);\n      const TensorSliceView& intersection = in_second_slice.Intersect(out_second_slice);\n      // The intersection ratio is designed for two steps.\n      // However, we only have one step here, so we increase the ratio by 1.0\n      // to eliminate the unused step\n      intersection_ratio += std::min(\n          1.0, (double)(intersection.shape().elem_cnt() * producer_parallel_desc.parallel_num())\n                   / logical_blob_desc.shape().elem_cnt());\n    }\n    // Otherwise, on different devices\n    // intersection_ratio = 1.0;\n  }\n  // Subtract the intersection part\n  return (producer_partial_ratio + consumer_broadcast_ratio - intersection_ratio)\n         * TotalByteSize4BlobDesc(logical_blob_desc);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/sbp_infer_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_FRAMEWORK_SBP_INFER_UTIL_H_\n#define ONEFLOW_CORE_FRAMEWORK_SBP_INFER_UTIL_H_\n\n#include \"oneflow/core/job/sbp_parallel.h\"\n\nnamespace oneflow {\n\nenum SbpInferRuleTag : int {\n  kAllMatch = 1,   // All match first, then lowest cost\n  kMatchAMAP = 2,  // Match as much as possible\n  kMinCost = 3     // Lowest cost\n};\n\nenum Penalty4PartialInConsumerTag : int {\n  kSlight = 1,  // Slight penalty\n  kMiddle = 2,  // Make sure we do not select P in the consumer\n  kStrict = 3   // Not allow a transfer to P\n};\n\n// [2, 3, 4, 5, 9, 100, 8]: (P, S0, P, P, B, S1, P)\n// partial ratio = 2 * 4 * 5 * 8\nint32_t PartialRatio4Producer(const NdSbp& sbp_producer,\n                              const ParallelDesc& producer_parallel_desc);\n\n// [2, 3, 4, 5, 9, 100, 8]: (P, S0, B, P, B, S1, P)\n// broadcast ratio = 4 * 9\nint32_t BroadcastRatio4Consumer(const NdSbp& sbp_consumer,\n                                const ParallelDesc& consumer_parallel_desc);\n\nvoid NdSbpDimReduce(const Shape& hierarchy, const NdSbp& nd_sbp, Shape* reduced_hierarchy,\n                    NdSbp* reduced_nd_sbp, const Shape& logical_shape);\nvoid NdSbpsDimReduce(const Shape& hierarchy, const std::vector<const NdSbp*>& nd_sbps,\n                     Shape* reduced_hierarchy, const std::vector<NdSbp*>& reduced_nd_sbps,\n                     const Shape& logical_shape);\nvoid NdSbpDimReduce(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp,\n                    ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp,\n                    const Shape& logical_shape);\n\nvoid InOutParallelDimReduce(const Shape& in_hierarchy, const Shape& out_hierarchy,\n                            const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,\n                            Shape* reduced_in_hierarchy, Shape* reduced_out_hierarchy,\n                            NdSbp* reduced_in_nd_sbp, NdSbp* reduced_out_nd_sbp,\n                            const Shape& logical_shape);\nvoid InOutParallelDimReduce(const ParallelDesc& in_parallel_desc,\n                            const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp,\n                            const NdSbp& out_nd_sbp, ParallelDesc* reduced_in_parallel_desc,\n                            ParallelDesc* reduced_out_parallel_desc, NdSbp* reduced_in_nd_sbp,\n                            NdSbp* reduced_out_nd_sbp, const Shape& logical_shape);\n\ndouble GetValidMaxCopyCost();\n\ndouble GetTransferCost();\n\nvoid ResizeNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t size);\n\nvoid SetNdSbpSignature(NdSbpSignature* nd_sbp_signature, const SbpSignature& sbp_signature,\n                       int32_t sbp_axis);\n\nvoid DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dims,\n                          const Shape& hierarchy,\n                          const HashMap<int32_t, SbpSignatureList>& hierarchy_value2sbp_sig_list,\n   
                       std::vector<NdSbpSignature>* nd_sbp_sig_list);\n\nvoid DeduplicateNdSbpSignatureList(std::vector<NdSbpSignature>* nd_sbp_sig_list,\n                                   const std::vector<std::string>& bns);\n\n// Compute storage for given NdSbp\ndouble Storage4NdSbp(const NdSbp& nd_sbp, Shape& logical_shape, const Shape& parallel_hierarchy);\n\n// Judge whether an NdSbp could be applied on a tensor with given logical shape\nMaybe<bool> FilterNdSbpByLogicalShape(const NdSbp& nd_sbp, Shape& logical_shape,\n                                      const Shape& parallel_hierarchy);\n\n// TODO: Unify lazy and eager boxing\nMaybe<double> ComputeCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel,\n                                          const NdSbp& consumer_sbp_parallel,\n                                          const BlobDesc& logical_blob_desc,\n                                          const ParallelDesc& producer_parallel_desc,\n                                          const ParallelDesc& consumer_parallel_desc,\n                                          bool requires_same_sbp);\n\n// Cost for boxing in lazy\nMaybe<double> ComputeLazyCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel,\n                                              const NdSbp& consumer_sbp_parallel,\n                                              const BlobDesc& logical_blob_desc,\n                                              const ParallelDesc& producer_parallel_desc,\n                                              const ParallelDesc& consumer_parallel_desc,\n                                              bool requires_same_sbp);\n\n// The public interface for computing cost\n// It uses the middle nodes algorithm.\nMaybe<double> ComputeCopyCostWithMiddleNodes(const NdSbp& producer_sbp_parallel,\n                                             const NdSbp& consumer_sbp_parallel,\n                                             const BlobDesc& logical_blob_desc,\n                                             const ParallelDesc& producer_parallel_desc,\n                                             const ParallelDesc& consumer_parallel_desc,\n                                             bool requires_same_sbp);\n\n// Decide the priority to infer sbp\n// 0: highest priority\n// 1.0: normal priority\n// 2.0: Penalty, the same as infinity\ndouble ComputeSbpInferPriority(const NdSbp& producer_sbp_parallel,\n                               const NdSbp& consumer_sbp_parallel,\n                               const ParallelDesc& producer_parallel_desc,\n                               const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp,\n                               const Shape& logical_shape);\n\n// The transfer ratio for general basic communication\n// Cost = ratio * data amount\ndouble Cost4GeneralBasicCommunication(const NdSbp& producer_sbp_parallel,\n                                      const NdSbp& consumer_sbp_parallel,\n                                      const BlobDesc& logical_blob_desc,\n                                      const ParallelDesc& producer_parallel_desc,\n                                      const ParallelDesc& consumer_parallel_desc);\n\nint64_t TotalByteSize4BlobDesc(const BlobDesc& logical_blob_desc);\nint64_t MaxByteSize4BlobDescSbp(const BlobDesc& logical_blob_desc, const NdSbp& nd_sbp,\n                                const Shape& hierarchy);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_SBP_INFER_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/framework/sbp_infer_util_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n\n#include <gtest/gtest.h>\n\nnamespace oneflow {\nnamespace test {\n\nnamespace {\n\nbool ParseNdSbpSignatureFromString(const std::string& nd_sbp_signature_str,\n                                   NdSbpSignature& nd_sbp_signature) {\n  auto* bn2nd_sbp = nd_sbp_signature.mutable_bn_in_op2nd_sbp();\n  std::string arg_name = \"in\";\n  bool meet_nd_sbp_group = false;\n  bool meet_split = false;\n  int nd_sbp_group_id = 0;\n  std::vector<std::string> nd_sbp_str_group;\n  size_t pos = 0;\n  while (pos < nd_sbp_signature_str.size()) {\n    const char& c = nd_sbp_signature_str[pos];\n    pos++;\n    if (c == ' ') {\n      continue;\n    } else if (c == '(') {\n      if (!meet_nd_sbp_group) {\n        // enter a nd-sbp group\n        meet_nd_sbp_group = true;\n        nd_sbp_str_group.emplace_back();\n        continue;\n      } else {\n        // meet left parentheses of S(x)\n        meet_split = true;\n      }\n    } else if (c == ')') {\n      if (meet_split) {\n        // meet right parentheses of S(x)\n        meet_split = false;\n      } else if (meet_nd_sbp_group) {\n        // leave a nd-sbp group\n        meet_nd_sbp_group = false;\n        std::string bn = arg_name + \"_\" + std::to_string(nd_sbp_group_id);\n        if (!ParseNdSbpFromStringList(nd_sbp_str_group, &(*bn2nd_sbp)[bn])) { return false; }\n        nd_sbp_str_group.clear();\n        continue;\n      } else {\n        return false;\n      }\n    } else if (c == ',') {\n      if (meet_nd_sbp_group) {\n        nd_sbp_str_group.emplace_back();\n      } else {\n        nd_sbp_group_id += 1;\n      }\n      continue;\n    } else if (c == '-') {\n      if (pos < nd_sbp_signature_str.size() && nd_sbp_signature_str[pos] == '>') {\n        // in args parsing has finished, parse out args\n        arg_name = \"out\";\n        nd_sbp_group_id = 0;\n        // skip '>' in substr '->'\n        pos++;\n        continue;\n      } else {\n        return false;\n      }\n    } else {\n      // do nothing\n    }\n    nd_sbp_str_group.back() += c;\n  }\n  return true;\n}\n\nstd::string NdSbpSignature2String(const NdSbpSignature& nd_sbp_signature,\n                                  const std::vector<std::string>& inputs,\n                                  const std::vector<std::string>& outputs) {\n  std::ostringstream ss;\n  auto BnNdSbpToString = [&](const std::string& bn) {\n    auto iter = nd_sbp_signature.bn_in_op2nd_sbp().find(bn);\n    CHECK(iter != nd_sbp_signature.bn_in_op2nd_sbp().end());\n    ss << NdSbpToString(iter->second);\n  };\n  auto ArgsNdSbpToString = [&](const std::vector<std::string>& arg_bns) {\n    for (size_t i = 0; i < arg_bns.size(); ++i) {\n      if (i > 0) { ss << \", \"; }\n      BnNdSbpToString(arg_bns[i]);\n    }\n  };\n  ArgsNdSbpToString(inputs);\n  ss << \" -> \";\n  ArgsNdSbpToString(outputs);\n  return 
ss.str();\n}\n\nvoid TestDeduplicateNdSbpSignature(const std::vector<std::string>& nd_sbp_signature_str_list,\n                                   const std::vector<std::string>& input_bns,\n                                   const std::vector<std::string>& output_bns) {\n  // parse\n  std::vector<NdSbpSignature> nd_sbp_sig_list;\n  nd_sbp_sig_list.reserve(nd_sbp_signature_str_list.size());\n  for (const auto& nd_sbp_signature_str : nd_sbp_signature_str_list) {\n    nd_sbp_sig_list.emplace_back();\n    ASSERT_TRUE(ParseNdSbpSignatureFromString(nd_sbp_signature_str, nd_sbp_sig_list.back()));\n  }\n\n  // shuffle and repeat\n  std::random_device rd;\n  std::mt19937 gen(rd());\n  std::shuffle(nd_sbp_sig_list.begin(), nd_sbp_sig_list.end(), gen);\n  nd_sbp_sig_list.reserve(nd_sbp_sig_list.size() + nd_sbp_sig_list.size() / 2);\n  std::copy_n(nd_sbp_sig_list.begin(), nd_sbp_sig_list.size() / 2,\n              std::back_inserter(nd_sbp_sig_list));\n  std::shuffle(nd_sbp_sig_list.begin(), nd_sbp_sig_list.end(), gen);\n\n  // dedup and sort\n  std::vector<std::string> bns;\n  bns.insert(bns.end(), input_bns.begin(), input_bns.end());\n  bns.insert(bns.end(), output_bns.begin(), output_bns.end());\n  DeduplicateNdSbpSignatureList(&nd_sbp_sig_list, bns);\n\n  // compare\n  ASSERT_EQ(nd_sbp_signature_str_list.size(), nd_sbp_sig_list.size());\n  for (size_t i = 0; i < nd_sbp_sig_list.size(); ++i) {\n    auto nd_sbp_sig_result = NdSbpSignature2String(nd_sbp_sig_list[i], input_bns, output_bns);\n    ASSERT_EQ(nd_sbp_sig_result, nd_sbp_signature_str_list[i]);\n  }\n}\n\n}  // namespace\n\nTEST(SbpInferUtil, DeduplicateNdSbpSignatureList) {\n  TestDeduplicateNdSbpSignature(\n      {\n          \"(B, B) -> (B, B)\",\n          \"(B, P) -> (B, P)\",\n          \"(B, S(0)) -> (B, S(0))\",\n          \"(B, S(1)) -> (B, S(1))\",\n          \"(B, S(3)) -> (B, S(2))\",\n          \"(P, B) -> (P, B)\",\n          \"(P, P) -> (P, P)\",\n          \"(P, S(0)) -> (P, S(0))\",\n          \"(P, S(1)) -> (P, S(1))\",\n          \"(P, S(3)) -> (P, S(2))\",\n          \"(S(0), B) -> (S(0), B)\",\n          \"(S(0), P) -> (S(0), P)\",\n          \"(S(0), S(0)) -> (S(0), S(0))\",\n          \"(S(0), S(1)) -> (S(0), S(1))\",\n          \"(S(0), S(3)) -> (S(0), S(2))\",\n          \"(S(1), B) -> (S(1), B)\",\n          \"(S(1), P) -> (S(1), P)\",\n          \"(S(1), S(0)) -> (S(1), S(0))\",\n          \"(S(1), S(1)) -> (S(1), S(1))\",\n          \"(S(1), S(3)) -> (S(1), S(2))\",\n          \"(S(3), B) -> (S(2), B)\",\n          \"(S(3), P) -> (S(2), P)\",\n          \"(S(3), S(0)) -> (S(2), S(0))\",\n          \"(S(3), S(1)) -> (S(2), S(1))\",\n          \"(S(3), S(3)) -> (S(2), S(2))\",\n      },\n      {\"in_0\"}, {\"out_0\"});\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },
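  {
    "path": "oneflow/core/framework/sbp_infer_util_ratio_sketch.cpp",
    "content": "// Editorial sketch (hypothetical file, not part of the original sources). The comments on\n// PartialRatio4Producer and BroadcastRatio4Consumer in sbp_infer_util.h describe each ratio as the\n// product of the hierarchy dimensions whose SBP component is P (respectively B). This standalone\n// snippet restates that rule on plain std:: types and checks it against the worked example from\n// the header; the real functions operate on NdSbp and ParallelDesc instead.\n#include <cassert>\n#include <cstdint>\n#include <vector>\n\nenum class SbpKind { kSplit, kBroadcast, kPartialSum };\n\n// Product of the hierarchy dims whose SBP component matches `kind`.\nint64_t Ratio(const std::vector<int64_t>& hierarchy, const std::vector<SbpKind>& nd_sbp,\n              SbpKind kind) {\n  int64_t ratio = 1;\n  for (size_t i = 0; i < hierarchy.size(); ++i) {\n    if (nd_sbp[i] == kind) { ratio *= hierarchy[i]; }\n  }\n  return ratio;\n}\n\nint main() {\n  const std::vector<int64_t> hierarchy{2, 3, 4, 5, 9, 100, 8};\n  // (P, S0, P, P, B, S1, P): partial ratio = 2 * 4 * 5 * 8\n  const std::vector<SbpKind> producer{SbpKind::kPartialSum, SbpKind::kSplit, SbpKind::kPartialSum,\n                                      SbpKind::kPartialSum, SbpKind::kBroadcast, SbpKind::kSplit,\n                                      SbpKind::kPartialSum};\n  assert(Ratio(hierarchy, producer, SbpKind::kPartialSum) == 2 * 4 * 5 * 8);\n  // (P, S0, B, P, B, S1, P): broadcast ratio = 4 * 9\n  const std::vector<SbpKind> consumer{SbpKind::kPartialSum, SbpKind::kSplit, SbpKind::kBroadcast,\n                                      SbpKind::kPartialSum, SbpKind::kBroadcast, SbpKind::kSplit,\n                                      SbpKind::kPartialSum};\n  assert(Ratio(hierarchy, consumer, SbpKind::kBroadcast) == 4 * 9);\n  return 0;\n}\n"
  },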
  {
    "path": "oneflow/core/framework/scope_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <list>\n#include \"oneflow/core/framework/scope_util.h\"\n\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/session_util.h\"\n#include \"oneflow/core/job/job_conf.pb.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<Scope> MakeDefaultScope() {\n  JobConfigProto config_proto;\n  config_proto.mutable_predict_conf();\n  config_proto.set_job_name(\"\");\n  return MakeScope(config_proto, *JUST(Device::New(\"cpu\")));\n}\n\nstd::list<std::shared_ptr<Scope>>* ThreadLocalScopeStack() {\n  thread_local static std::list<std::shared_ptr<Scope>> scope_stack{CHECK_JUST(MakeDefaultScope())};\n  return &scope_stack;\n}\n\n}  // namespace\n\nMaybe<Scope> MakeScope(const JobConfigProto& config_proto, const Device& device) {\n  std::shared_ptr<Scope> scope;\n  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    int64_t session_id = JUST(GetDefaultSessionId());\n    std::string device_tag = \"cpu\";\n    std::string machine_ids = \"0\";\n    std::string device_ids = \"0\";\n    if (device.type() != \"cpu\") {\n      device_tag = device.type();\n      device_ids = std::to_string(device.device_id());\n    }\n    scope = JUST(builder->BuildInitialScope(session_id, config_proto, device_tag,\n                                            {machine_ids + \":\" + device_ids}, nullptr, false));\n    return Maybe<void>::Ok();\n  }));\n  return scope;\n}\n\nMaybe<Scope> MakeInitialScope(const JobConfigProto& job_conf, Symbol<ParallelDesc> placement,\n                              bool is_local) {\n  std::shared_ptr<Scope> scope;\n  JUST(PhysicalRun([&scope, &job_conf, placement,\n                    is_local](InstructionsBuilder* builder) -> Maybe<void> {\n    int64_t session_id = JUST(GetDefaultSessionId());\n    scope =\n        JUST(builder->BuildInitialScopeWithPlacement(session_id, job_conf, placement, is_local));\n    return Maybe<void>::Ok();\n  }));\n  return scope;\n}\n\nMaybe<Scope> GetCurrentScope() {\n  auto* scope_stack = ThreadLocalScopeStack();\n  CHECK_GT_OR_RETURN(scope_stack->size(), 0);\n  return scope_stack->back();\n}\n\nMaybe<void> InitThreadLocalScopeStack(const std::shared_ptr<Scope>& scope) {\n  auto* scope_stack = ThreadLocalScopeStack();\n  scope_stack->clear();\n  scope_stack->emplace_back(scope);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ThreadLocalScopeStackPush(const std::shared_ptr<Scope>& scope) {\n  auto* scope_stack = ThreadLocalScopeStack();\n  scope_stack->emplace_back(scope);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ThreadLocalScopeStackPop() {\n  auto* scope_stack = ThreadLocalScopeStack();\n  scope_stack->pop_back();\n  return Maybe<void>::Ok();\n}\n\nBackwardPassScopeGuard::BackwardPassScopeGuard() {\n  if (LazyMode::is_enabled()) {\n    const auto& scope 
= CHECK_JUST(GetCurrentScope());\n    if (scope) {\n      backward_pass_scope_ = CHECK_JUST(FindOrCreateBackwardPassScope(scope));\n      CHECK_JUST(ThreadLocalScopeStackPush(backward_pass_scope_));\n    }\n  }\n}\n\nBackwardPassScopeGuard::BackwardPassScopeGuard(const std::shared_ptr<Scope>& scope) {\n  if (scope && LazyMode::is_enabled()) {\n    backward_pass_scope_ = CHECK_JUST(FindOrCreateBackwardPassScope(scope));\n    CHECK_JUST(ThreadLocalScopeStackPush(backward_pass_scope_));\n  }\n}\n\nBackwardPassScopeGuard::~BackwardPassScopeGuard() {\n  if (backward_pass_scope_) { CHECK_JUST(ThreadLocalScopeStackPop()); }\n}\n\nclass BackwardPassScopeStorage {\n public:\n  std::mutex mutex;\n\n  static BackwardPassScopeStorage* Global() {\n    static BackwardPassScopeStorage instance;\n    return &instance;\n  }\n  HashMap<int64_t, std::shared_ptr<Scope>>& get() { return scopes_; }\n\n private:\n  HashMap<int64_t, std::shared_ptr<Scope>> scopes_;\n};\n\nextern const std::string kBackwardPass;\nMaybe<Scope> FindOrCreateBackwardPassScope(const std::shared_ptr<Scope>& scope) {\n  auto* storage = BackwardPassScopeStorage::Global();\n  auto& scopes = storage->get();\n  std::lock_guard<std::mutex> lock(storage->mutex);\n  auto it = scopes.find(JUST(scope->symbol_id()));\n  if (it != scopes.end()) { return it->second; }\n  auto scope_proto = JUST((scope->MakeChildScopeProto()));\n  scope_proto->set_calculation_pass_name(kBackwardPass);\n  std::shared_ptr<Scope> backward_pass_scope;\n  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    backward_pass_scope = JUST(builder->GetScopeSymbol(*scope_proto));\n    return Maybe<void>::Ok();\n  }));\n  scopes.emplace(JUST(scope->symbol_id()), backward_pass_scope);\n  return backward_pass_scope;\n}\n\nvoid ClearAllBackwardPassScope() {\n  auto* storage = BackwardPassScopeStorage::Global();\n  std::lock_guard<std::mutex> lock(storage->mutex);\n  storage->get().clear();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/scope_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_SCOPE_UTIL_H_\n#define ONEFLOW_CORE_FRAMEWORK_SCOPE_UTIL_H_\n\n#include <vector>\n#include \"oneflow/core/job/scope.h\"\n\nnamespace oneflow {\n\nMaybe<Scope> MakeScope(const JobConfigProto& config_proto, const Device& device);\n\nMaybe<Scope> MakeInitialScope(const JobConfigProto& job_conf, Symbol<ParallelDesc> placement,\n                              bool is_local);\n\nMaybe<Scope> GetCurrentScope();\n\nMaybe<void> InitThreadLocalScopeStack(const std::shared_ptr<Scope>& scope);\n\nMaybe<void> ThreadLocalScopeStackPush(const std::shared_ptr<Scope>& scope);\n\nMaybe<void> ThreadLocalScopeStackPop();\n\nclass BackwardPassScopeGuard {\n public:\n  BackwardPassScopeGuard();\n  explicit BackwardPassScopeGuard(const std::shared_ptr<Scope>& scope);\n  ~BackwardPassScopeGuard();\n\n private:\n  std::shared_ptr<Scope> backward_pass_scope_;\n};\n\nMaybe<Scope> FindOrCreateBackwardPassScope(const std::shared_ptr<Scope>& scope);\nvoid ClearAllBackwardPassScope();\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_SCOPE_UTIL_H_\n"
  },
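  {
    "path": "oneflow/core/framework/scope_util_stack_sketch.cpp",
    "content": "// Editorial sketch (hypothetical file, not part of the original sources). scope_util keeps a\n// thread-local stack of scopes: GetCurrentScope() reads the back of the stack, and\n// BackwardPassScopeGuard pushes a scope on construction and pops it on destruction. The snippet\n// below reproduces that RAII pattern on plain std:: types, without the Maybe<> error handling,\n// lazy-mode checks, or instruction-builder calls of the real code.\n#include <cassert>\n#include <memory>\n#include <string>\n#include <utility>\n#include <vector>\n\nstruct Scope {\n  std::string name;\n};\n\n// One stack per thread, seeded with a default scope, mirroring ThreadLocalScopeStack().\nstd::vector<std::shared_ptr<Scope>>* ThreadLocalScopeStack() {\n  thread_local std::vector<std::shared_ptr<Scope>> stack{std::make_shared<Scope>(Scope{\"default\"})};\n  return &stack;\n}\n\nstd::shared_ptr<Scope> GetCurrentScope() { return ThreadLocalScopeStack()->back(); }\n\n// Push a scope for the lifetime of the guard; pop it when the guard is destroyed.\nclass ScopeGuard {\n public:\n  explicit ScopeGuard(std::shared_ptr<Scope> scope) {\n    ThreadLocalScopeStack()->push_back(std::move(scope));\n  }\n  ~ScopeGuard() { ThreadLocalScopeStack()->pop_back(); }\n  ScopeGuard(const ScopeGuard&) = delete;\n  ScopeGuard& operator=(const ScopeGuard&) = delete;\n};\n\nint main() {\n  assert(GetCurrentScope()->name == \"default\");\n  {\n    ScopeGuard guard(std::make_shared<Scope>(Scope{\"backward_pass\"}));\n    assert(GetCurrentScope()->name == \"backward_pass\");\n  }\n  assert(GetCurrentScope()->name == \"default\");\n  return 0;\n}\n"
  },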
  {
    "path": "oneflow/core/framework/session_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/session_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::mutex* GlobalSessionUtilMutex() {\n  static std::mutex global_id2session_map_mutex;\n  return &global_id2session_map_mutex;\n}\n\nstd::vector<int64_t>* RegsiteredSessionIds() {\n  static std::vector<int64_t> default_sess_id;\n  return &default_sess_id;\n}\n\nMaybe<void> SetDefaultSessionId(int64_t val) {\n  std::vector<int64_t>* ids = RegsiteredSessionIds();\n  ids->emplace_back(val);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<int64_t> GetDefaultSessionId() {\n  std::unique_lock<std::mutex> lock(*GlobalSessionUtilMutex());\n  const auto& regsitered_ids = *(RegsiteredSessionIds());\n  CHECK_GT_OR_RETURN(regsitered_ids.size(), 0);\n  return regsitered_ids.back();\n}\n\nbool RegsterSessionId(int64_t session_id) {\n  std::unique_lock<std::mutex> lock(*GlobalSessionUtilMutex());\n  auto* regsitered_ids = RegsiteredSessionIds();\n  auto itor = std::find(regsitered_ids->begin(), regsitered_ids->end(), session_id);\n  if (itor != regsitered_ids->end()) { return false; }\n  regsitered_ids->push_back(session_id);\n  return true;\n}\n\nbool ClearSessionId(int64_t session_id) {\n  std::unique_lock<std::mutex> lock(*GlobalSessionUtilMutex());\n  auto* regsitered_ids = RegsiteredSessionIds();\n  auto itor = std::find(regsitered_ids->begin(), regsitered_ids->end(), session_id);\n  if (itor == regsitered_ids->end()) { return false; }\n  regsitered_ids->erase(itor);\n  return true;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/session_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_SESSION_UTIL_H_\n#define ONEFLOW_CORE_FRAMEWORK_SESSION_UTIL_H_\n\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nMaybe<int64_t> GetDefaultSessionId();\nbool RegsterSessionId(int64_t session_id);\nbool ClearSessionId(int64_t session_id);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_SESSION_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/framework/shut_down_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/shut_down_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::atomic<bool>* GetShuttingDown() {\n  static std::atomic<bool> shutting_down{false};\n  return &shutting_down;\n}\n\n}  // namespace\n\nbool IsShuttingDown() {\n  auto* shutting_down = GetShuttingDown();\n  bool is_interpreter_shutdown = *shutting_down;\n  return is_interpreter_shutdown;\n}\n\nvoid SetShuttingDown(bool arg_shutting_down) {\n  auto* shutting_down = GetShuttingDown();\n  *shutting_down = arg_shutting_down;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/shut_down_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_PYTHON_INTERPRETER_UTIL_H_\n#define ONEFLOW_CORE_FRAMEWORK_PYTHON_INTERPRETER_UTIL_H_\n\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nbool IsShuttingDown();\n\nvoid SetShuttingDown(bool arg_shutting_down = true);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_PYTHON_INTERPRETER_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/framework/stream.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/framework/stream_is_comm_net_stream.h\"\n#include \"oneflow/core/thread/thread_global_id.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/static_global.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/framework/stream_mgr.h\"\n#include \"oneflow/core/vm/stream_get_allocator_stream_type.h\"\n#include \"oneflow/core/ep/include/device_manager.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n\nnamespace oneflow {\n\nStream::Stream(Symbol<Device> device, StreamType stream_type, size_t thread_uid)\n    : device_(device),\n      stream_type_(stream_type),\n      thread_uid_(thread_uid),\n      unique_stream_id_(-1),\n      support_wait_event_(false) {\n  ep::DeviceManager* device_mgr =\n      Singleton<ep::DeviceManagerRegistry>::Get()->GetDeviceManagerOrNull(device->enum_type());\n  if (!device_mgr) { return; }\n  support_wait_event_ = device_mgr->IsStreamWaitEventSupported();\n}\n\nMaybe<void> Stream::Init(size_t unique_stream_id) {\n  unique_stream_id_ = unique_stream_id;\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<Symbol<Stream>> Stream::RawNew(Symbol<Device> device, StreamType stream_type,\n                                                size_t thread_uid) {\n  std::shared_ptr<Stream> stream(new Stream(device, stream_type, thread_uid));\n  return JUST(SingletonMaybe<StreamMgr>())\n      ->AddStreamSymbol(*stream, [&](size_t unique_stream_id) -> Maybe<Symbol<Stream>> {\n        JUST(stream->Init(unique_stream_id));\n        return SymbolOf(*stream);\n      });\n}\n\n/*static*/ Maybe<Symbol<Stream>> Stream::New(Symbol<Device> device, StreamType stream_type,\n                                             size_t thread_uid) {\n  constexpr auto* Make = DECORATE(&Stream::RawNew, ThreadLocalCopiable);\n  return Make(device, stream_type, thread_uid);\n}\n\nnamespace {\n\nMaybe<Symbol<Stream>> RawGetDefaultStreamByDevice(Symbol<Device> device) {\n  return Stream::New(device, StreamType::kCompute);\n}\n\nMaybe<Symbol<Stream>> RawGetDefaultStreamByPlacement(Symbol<ParallelDesc> parallel_desc) {\n  return RawGetDefaultStreamByDevice(JUST(GetTensorDevice(parallel_desc)));\n}\n\nMaybe<Symbol<Stream>> RawGetAllocatorStream(Symbol<Stream> stream) {\n  StreamType allocator_stream_type = JUST(GetAllocatorStreamType::Visit(stream->stream_type()));\n  if (allocator_stream_type == stream->stream_type()) { return stream; }\n  return Stream::New(stream->device(), allocator_stream_type, stream->thread_uid());\n}\n\n}  // namespace\n\nint64_t Stream::kDefaultStreamThreadUid = 0;\n\ndecltype(GetDefaultStreamByDevice) GetDefaultStreamByDevice =\n    DECORATE(&RawGetDefaultStreamByDevice, ThreadLocal);\n\ndecltype(GetDefaultStreamByPlacement) GetDefaultStreamByPlacement =\n    
DECORATE(&RawGetDefaultStreamByPlacement, ThreadLocal);\n\ndecltype(GetAllocatorStream) GetAllocatorStream = DECORATE(&RawGetAllocatorStream, ThreadLocal);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/stream.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_H_\n#define ONEFLOW_CORE_FRAMEWORK_STREAM_H_\n\n#include <functional>\n#include \"oneflow/core/common/stream_type.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/device.h\"\n\nnamespace oneflow {\n\nclass Stream final {\n public:\n  Stream(const Stream&) = default;\n  Stream(Stream&&) = default;\n  ~Stream() = default;\n\n  bool operator==(const Stream& that) const {\n    return this->device() == that.device() && this->stream_type() == that.stream_type()\n           && this->thread_uid() == that.thread_uid()\n           && this->support_wait_event() == that.support_wait_event();\n  }\n  bool operator!=(const Stream& that) const { return !(*this == that); }\n\n  static Maybe<Symbol<Stream>> New(Symbol<Device> device, StreamType stream_type) {\n    return New(device, stream_type, kDefaultStreamThreadUid);\n  }\n  static Maybe<Symbol<Stream>> New(Symbol<Device> device, StreamType stream_type,\n                                   size_t thread_uid);\n\n  Symbol<Device> device() const { return device_; }\n  StreamType stream_type() const { return stream_type_; }\n  size_t thread_uid() const { return thread_uid_; }\n  size_t unique_stream_id() const { return unique_stream_id_; }\n  bool support_wait_event() const { return support_wait_event_; }\n\n  static int64_t kDefaultStreamThreadUid;\n\n private:\n  Stream(Symbol<Device> device, StreamType stream_type, size_t thread_uid);\n\n  static Maybe<Symbol<Stream>> RawNew(Symbol<Device> device, StreamType stream_type,\n                                      size_t thread_uid);\n\n  Maybe<void> Init(size_t unique_stream_id);\n\n  Symbol<Device> device_;\n  StreamType stream_type_;\n  size_t thread_uid_;\n  size_t unique_stream_id_;\n  bool support_wait_event_;\n};\n\nextern Maybe<Symbol<Stream>> (*GetDefaultStreamByDevice)(Symbol<Device>);\nclass ParallelDesc;\nextern Maybe<Symbol<Stream>> (*GetDefaultStreamByPlacement)(Symbol<ParallelDesc>);\n\nextern Maybe<Symbol<Stream>> (*GetAllocatorStream)(Symbol<Stream>);\n\n}  // namespace oneflow\n\nnamespace std {\ntemplate<>\nstruct hash<oneflow::Stream> final {\n  size_t operator()(const oneflow::Stream& stream) const {\n    using namespace oneflow;\n    return Hash(stream.device(), stream.stream_type(), stream.thread_uid());\n  }\n};\n\n}  // namespace std\n#endif  // ONEFLOW_CORE_FRAMEWORK_STREAM_H_\n"
  },
  {
    "path": "oneflow/core/framework/stream_allocator_is_pinned.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_ALLOCATOR_IS_PINNED_H_\n#define ONEFLOW_CORE_FRAMEWORK_STREAM_ALLOCATOR_IS_PINNED_H_\n\n#include \"oneflow/core/common/stream_type.h\"\n\nnamespace oneflow {\n\nstruct IsStreamAllocatorPinned : public StreamTypeVisitor<IsStreamAllocatorPinned> {\n  static bool VisitCompute() { return false; }\n  static bool VisitHost2Device() { return false; }\n  static bool VisitDevice2Host() { return false; }\n  static bool VisitCcl() { return false; }\n  static bool VisitBarrier() { return false; }\n  static bool VisitCriticalSection() { return false; }\n  static bool VisitLazyJobLauncher() { return false; }\n  static bool VisitPinnedCompute() { return true; }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_STREAM_ALLOCATOR_IS_PINNED_H_\n"
  },
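  {
    "path": "oneflow/core/framework/stream_type_visitor_sketch.cpp",
    "content": "// Editorial sketch (hypothetical file, not part of the original sources). The visitor headers in\n// this directory (IsStreamAllocatorPinned, GetStreamTypeName, IsCommNetStream, NeedSoftSync,\n// StreamOnIndependentThread) share one pattern: a StreamTypeVisitor<Derived> base dispatches over\n// StreamType and forwards each case to a static VisitXxx member of Derived. The real enum and\n// visitor live in oneflow/core/common/stream_type.h; the miniature below only illustrates the\n// assumed dispatch shape with an invented MiniStreamType.\n#include <cassert>\n#include <utility>\n\nenum class MiniStreamType { kCompute, kCcl, kBarrier };\n\ntemplate<typename Derived>\nstruct MiniStreamTypeVisitor {\n  template<typename... Args>\n  static auto Visit(MiniStreamType stream_type, Args&&... args) {\n    switch (stream_type) {\n      case MiniStreamType::kCcl: return Derived::VisitCcl(std::forward<Args>(args)...);\n      case MiniStreamType::kBarrier: return Derived::VisitBarrier(std::forward<Args>(args)...);\n      case MiniStreamType::kCompute:\n      default: return Derived::VisitCompute(std::forward<Args>(args)...);\n    }\n  }\n};\n\n// One policy struct per question, mirroring IsCommNetStream.\nstruct MiniIsCommNetStream : public MiniStreamTypeVisitor<MiniIsCommNetStream> {\n  static bool VisitCompute() { return false; }\n  static bool VisitCcl() { return true; }\n  static bool VisitBarrier() { return false; }\n};\n\nint main() {\n  assert(MiniIsCommNetStream::Visit(MiniStreamType::kCcl));\n  assert(!MiniIsCommNetStream::Visit(MiniStreamType::kCompute));\n  return 0;\n}\n"
  },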
  {
    "path": "oneflow/core/framework/stream_get_stream_type_name.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_GET_STREAM_TYPE_NAME_H_\n#define ONEFLOW_CORE_FRAMEWORK_STREAM_GET_STREAM_TYPE_NAME_H_\n\n#include <string>\n#include \"oneflow/core/common/stream_type.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/framework/to_string.h\"\n\nnamespace oneflow {\n\nstruct GetStreamTypeName : public StreamTypeVisitor<GetStreamTypeName> {\n  static const char* VisitCompute() { return \"compute\"; }\n  static const char* VisitHost2Device() { return \"h2d\"; }\n  static const char* VisitDevice2Host() { return \"d2h\"; }\n  static const char* VisitCcl() { return \"ccl\"; }\n  static const char* VisitBarrier() { return \"barrier\"; }\n  static const char* VisitCriticalSection() { return \"critical_section\"; }\n  static const char* VisitLazyJobLauncher() { return \"lazy_job_launcher\"; }\n  static const char* VisitPinnedCompute() { return \"pinned_compute\"; }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_STREAM_GET_STREAM_TYPE_NAME_H_\n"
  },
  {
    "path": "oneflow/core/framework/stream_guard.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/stream_guard.h\"\n\nnamespace oneflow {\n\n/*static*/ Optional<StreamConverter>* StreamGuard::MutCurrent() {\n  static thread_local Optional<StreamConverter> current;\n  return &current;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/stream_guard.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_GUARD_H_\n#define ONEFLOW_CORE_FRAMEWORK_STREAM_GUARD_H_\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/env_var/stream.h\"\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/framework/stream_set.h\"\n#include \"oneflow/core/framework/stream_is_comm_net_stream.h\"\n#include \"oneflow/core/thread/thread_global_id.h\"\n\nnamespace oneflow {\n\nclass StreamConverter final {\n public:\n  explicit StreamConverter(const std::shared_ptr<StreamSet>& stream_set)\n      : stream_set_(stream_set) {}\n\n  Maybe<Symbol<Stream>> TryConvertStream(Symbol<Stream> stream) {\n    size_t thread_uid = stream_set_->worker_thread_id();\n    return Stream::New(stream->device(), stream->stream_type(), thread_uid);\n  }\n\n private:\n  const std::shared_ptr<StreamSet> stream_set_;\n};\n\nclass StreamGuard final {\n public:\n  explicit StreamGuard(const std::shared_ptr<StreamConverter>& stream_converter) {\n    old_value_ = Current();\n    *MutCurrent() = stream_converter;\n  }\n  ~StreamGuard() { *MutCurrent() = old_value_; }\n\n  static Maybe<Symbol<Stream>> TryConvertStream(Symbol<Stream> stream) {\n    if (!Current().has_value()) { return stream; }\n    return JUST(Current())->TryConvertStream(stream);\n  }\n\n private:\n  static const Optional<StreamConverter>& Current() { return *MutCurrent(); }\n  static Optional<StreamConverter>* MutCurrent();\n\n  Optional<StreamConverter> old_value_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_STREAM_GUARD_H_\n"
  },
  {
    "path": "oneflow/core/framework/stream_is_comm_net_stream.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_IS_COMM_NET_STREAM_H_\n#define ONEFLOW_CORE_FRAMEWORK_STREAM_IS_COMM_NET_STREAM_H_\n\n#include \"oneflow/core/common/stream_type.h\"\n\nnamespace oneflow {\n\nstruct IsCommNetStream final : public StreamTypeVisitor<IsCommNetStream> {\n  static bool VisitCompute() { return false; }\n  static bool VisitHost2Device() { return false; }\n  static bool VisitDevice2Host() { return false; }\n  static bool VisitCcl() { return true; }\n  static bool VisitBarrier() { return false; }\n  static bool VisitCriticalSection() { return false; }\n  static bool VisitLazyJobLauncher() { return false; }\n  static bool VisitPinnedCompute() { return VisitCompute(); }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_STREAM_IS_COMM_NET_STREAM_H_\n"
  },
  {
    "path": "oneflow/core/framework/stream_mgr.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/stream_mgr.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nMaybe<Symbol<Stream>> StreamMgr::AddStreamSymbol(\n    const Stream& stream,\n    const std::function<Maybe<Symbol<Stream>>(size_t unique_stream_id)>& CreateStreamSymbol) {\n  Symbol<Stream> stream_symbol;\n  std::unique_lock<std::mutex> lock(mutex_);\n  if (stream2unique_stream_id_.count(stream) > 0) {\n    size_t unique_stream_id = stream2unique_stream_id_[stream];\n    auto existed_stream_symbol = JUST(VectorAt(unique_stream_id2stream_symbol_, unique_stream_id));\n    stream_symbol = JUST(CreateStreamSymbol(unique_stream_id));\n    CHECK_OR_RETURN(existed_stream_symbol == stream_symbol)\n        << \"the result of current called CreateStreamSymbol is not the result of last called \"\n           \"CreateStreamSymbol\";\n  } else {\n    size_t unique_stream_id = unique_stream_id2stream_symbol_.size();\n    stream2unique_stream_id_[stream] = unique_stream_id;\n    stream_symbol = JUST(CreateStreamSymbol(unique_stream_id));\n    unique_stream_id2stream_symbol_.push_back(stream_symbol);\n    CHECK_OR_RETURN(unique_stream_id2stream_symbol_[unique_stream_id] == stream)\n        << \"the result of CreateStreamSymbol is no the symbol of `stream`\";\n    CHECK_EQ_OR_RETURN(unique_stream_id2stream_symbol_[unique_stream_id]->unique_stream_id(),\n                       unique_stream_id)\n        << \"unique_stream_id is wrongly initialized\";\n  }\n  return stream_symbol;\n}\n\nsize_t StreamMgr::UniqueStreamSize() const {\n  std::unique_lock<std::mutex> lock(mutex_);\n  return unique_stream_id2stream_symbol_.size();\n}\n\nMaybe<Symbol<Stream>> StreamMgr::GetStreamSymbol(size_t unique_stream_id) const {\n  std::unique_lock<std::mutex> lock(mutex_);\n  return JUST(VectorAt(unique_stream_id2stream_symbol_, unique_stream_id));\n}\n\nCOMMAND(Singleton<StreamMgr>::SetAllocated(new StreamMgr()));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/stream_mgr.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_MGR_H_\n#define ONEFLOW_CORE_FRAMEWORK_STREAM_MGR_H_\n\n#include <mutex>\n#include <functional>\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/framework/stream.h\"\n\nnamespace oneflow {\n\nclass StreamMgr final {\n public:\n  StreamMgr() = default;\n  ~StreamMgr() = default;\n\n  Maybe<Symbol<Stream>> AddStreamSymbol(\n      const Stream& stream,\n      const std::function<Maybe<Symbol<Stream>>(size_t unique_stream_id)>& CreateStreamSymbol);\n\n  size_t UniqueStreamSize() const;\n\n  Maybe<Symbol<Stream>> GetStreamSymbol(size_t unique_stream_id) const;\n\n private:\n  mutable std::mutex mutex_;\n  std::vector<Symbol<Stream>> unique_stream_id2stream_symbol_;\n  std::unordered_map<Stream, size_t> stream2unique_stream_id_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_STREAM_MGR_H_\n"
  },
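  {
    "path": "oneflow/core/framework/stream_mgr_interning_sketch.cpp",
    "content": "// Editorial sketch (hypothetical file, not part of the original sources). StreamMgr deduplicates\n// streams with a map from the stream value to a unique id plus a vector from id back to the\n// registered symbol, both guarded by one mutex. The snippet shows the same interning pattern for\n// plain strings; the real AddStreamSymbol additionally re-runs the factory callback and checks\n// that it reproduces the previously registered symbol.\n#include <cassert>\n#include <mutex>\n#include <string>\n#include <unordered_map>\n#include <vector>\n\nclass Interner {\n public:\n  // Return the unique id for `value`, registering it on first sight.\n  size_t Add(const std::string& value) {\n    std::unique_lock<std::mutex> lock(mutex_);\n    auto it = value2id_.find(value);\n    if (it != value2id_.end()) { return it->second; }\n    size_t id = id2value_.size();\n    value2id_[value] = id;\n    id2value_.push_back(value);\n    return id;\n  }\n\n  const std::string& Get(size_t id) const {\n    std::unique_lock<std::mutex> lock(mutex_);\n    return id2value_.at(id);\n  }\n\n private:\n  mutable std::mutex mutex_;\n  std::unordered_map<std::string, size_t> value2id_;\n  std::vector<std::string> id2value_;\n};\n\nint main() {\n  Interner interner;\n  size_t a = interner.Add(\"cuda:0/compute\");\n  size_t b = interner.Add(\"cuda:0/compute\");  // a duplicate keeps its original id\n  assert(a == b);\n  assert(interner.Get(a) == \"cuda:0/compute\");\n  return 0;\n}\n"
  },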
  {
    "path": "oneflow/core/framework/stream_need_soft_sync.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_NEED_SOFT_SYNC_H_\n#define ONEFLOW_CORE_FRAMEWORK_STREAM_NEED_SOFT_SYNC_H_\n\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/common/stream_type.h\"\n\nnamespace oneflow {\n\nstruct NeedSoftSync : public StreamTypeVisitor<NeedSoftSync> {\n  static bool VisitCompute(DeviceType device_type) { return device_type != kCPU; }\n  static bool VisitHost2Device(DeviceType) { return false; }\n  static bool VisitDevice2Host(DeviceType) { return false; }\n  static bool VisitCcl(DeviceType device_type) { return false; }\n  static bool VisitBarrier(DeviceType) { return false; }\n  static bool VisitCriticalSection(DeviceType) { return false; }\n  static bool VisitLazyJobLauncher(DeviceType) { return false; }\n  static bool VisitPinnedCompute(DeviceType device_type) { return VisitCompute(device_type); }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_STREAM_NEED_SOFT_SYNC_H_\n"
  },
  {
    "path": "oneflow/core/framework/stream_on_independent_thread.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_ON_INDEPENDENT_THREAD_H_\n#define ONEFLOW_CORE_FRAMEWORK_STREAM_ON_INDEPENDENT_THREAD_H_\n\n#include \"oneflow/core/common/stream_type.h\"\n\nnamespace oneflow {\n\nstruct StreamOnIndependentThread : public StreamTypeVisitor<StreamOnIndependentThread> {\n  static bool VisitCompute() { return false; }\n  static bool VisitHost2Device() { return false; }\n  static bool VisitDevice2Host() { return false; }\n  static bool VisitCcl() { return false; }\n  static bool VisitBarrier() { return false; }\n  static bool VisitCriticalSection() { return true; }\n  static bool VisitLazyJobLauncher() { return true; }\n  static bool VisitPinnedCompute() { return VisitCompute(); }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_STREAM_ON_INDEPENDENT_THREAD_H_\n"
  },
  {
    "path": "oneflow/core/framework/stream_set.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <vector>\n#include <mutex>\n#include <set>\n#include <map>\n#include \"oneflow/core/framework/stream_set.h\"\n#include \"oneflow/core/thread/thread_global_id.h\"\n#include \"oneflow/core/common/env_var/stream.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\n\nStreamSet::StreamSet(int64_t worker_thread_id) : worker_thread_id_(worker_thread_id) {}\n\nStreamSet::~StreamSet() {}\n\n/*static*/ Maybe<StreamSet> StreamSet::New(int64_t worker_thread_id) {\n  return std::shared_ptr<StreamSet>(new StreamSet(worker_thread_id));\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/stream_set.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_SET_H_\n#define ONEFLOW_CORE_FRAMEWORK_STREAM_SET_H_\n\n#include <unordered_map>\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/framework/stream.h\"\n\nnamespace oneflow {\n\nclass StreamSet final {\n public:\n  ~StreamSet();\n\n  static Maybe<StreamSet> New(int64_t worker_thread_id);\n\n  int64_t worker_thread_id() const { return worker_thread_id_; }\n\n private:\n  StreamSet(int64_t worker_thread_id);\n\n  int64_t worker_thread_id_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_STREAM_SET_H_\n"
  },
  {
    "path": "oneflow/core/framework/stream_support_stream_wait.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_SUPPORT_STREAM_WAIT_H_\n#define ONEFLOW_CORE_FRAMEWORK_STREAM_SUPPORT_STREAM_WAIT_H_\n\n#include \"oneflow/core/common/stream_type.h\"\n\nnamespace oneflow {\n\nstruct StreamSupportStreamWait : public StreamTypeVisitor<StreamSupportStreamWait> {\n  static bool VisitCompute(DeviceType device_type) { return Supported(device_type); }\n  static bool VisitHost2Device(DeviceType device_type) { return Supported(device_type); }\n  static bool VisitDevice2Host(DeviceType device_type) { return Supported(device_type); }\n  static bool VisitCcl(DeviceType device_type) { return Supported(device_type); }\n  static bool VisitBarrier(DeviceType device_type) { return false; }\n  static bool VisitCriticalSection(DeviceType device_type) { return false; }\n  static bool VisitLazyJobLauncher(DeviceType device_type) { return false; }\n  static bool VisitPinnedCompute(DeviceType device_type) { return VisitCompute(device_type); }\n\n private:\n  static bool Supported(DeviceType device_type) { return device_type == kCUDA; }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_STREAM_SUPPORT_STREAM_WAIT_H_\n"
  },
  {
    "path": "oneflow/core/framework/symbol_storage_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/operator/op_node_signature.pb.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/operator/op_conf_symbol.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n\nnamespace oneflow {\n\nCOMMAND(\n    Singleton<symbol::Storage<ParallelDesc>>::SetAllocated(new symbol::Storage<ParallelDesc>()));\nCOMMAND(Singleton<symbol::Storage<Scope>>::SetAllocated(new symbol::Storage<Scope>()));\nCOMMAND(Singleton<symbol::Storage<JobDesc>>::SetAllocated(new symbol::Storage<JobDesc>()));\nCOMMAND(Singleton<symbol::Storage<OperatorConfSymbol>>::SetAllocated(\n    new symbol::Storage<OperatorConfSymbol>()));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/symbol_storage_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_SYMBOL_STORAGE_H_\n#define ONEFLOW_CORE_FRAMEWORK_SYMBOL_STORAGE_H_\n\n#include \"oneflow/core/vm/symbol_storage.h\"\n\nnamespace oneflow {\n\ntemplate<typename SymbolT>\nMaybe<SymbolT> GetSymbol(int64_t symbol_id) {\n  const auto& symbol_storage = *Singleton<symbol::Storage<SymbolT>>::Get();\n  const auto& ptr = JUST(symbol_storage.MaybeGetPtr(symbol_id));\n  JUST(ptr->symbol_id());\n  return ptr;\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_SYMBOL_STORAGE_H_\n"
  },
  {
    "path": "oneflow/core/framework/sync_symbol_global_tensor_meta.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/sync_symbol_global_tensor_meta.h\"\n#include \"oneflow/core/framework/sync_symbol_parallel_desc.h\"\n#include \"oneflow/core/framework/sync_symbol_nd_sbp.h\"\n#include \"oneflow/core/framework/rank_group_rpc_util.h\"\n#include \"oneflow/core/common/tensor_meta.h\"\n#include \"oneflow/core/framework/synced_symbol_map.h\"\n#include \"oneflow/core/common/flat_shape.h\"\n\nnamespace oneflow {\n\nstruct FlatGlobalTensorMeta final {\n  static Maybe<FlatGlobalTensorMeta> New(uint64_t symbol_id,\n                                         Symbol<one::GlobalTensorMeta> global_tensor_meta) {\n    const auto& meta = std::make_shared<FlatGlobalTensorMeta>();\n    JUST(meta->Init(symbol_id, global_tensor_meta));\n    return meta;\n  }\n\n  Maybe<void> Init(uint64_t symbol_id, Symbol<one::GlobalTensorMeta> global_tensor_meta) {\n    this->symbol_id = symbol_id;\n    JUST(this->shape.Init(global_tensor_meta->shape()));\n    this->dtype = static_cast<int32_t>(global_tensor_meta->dtype());\n    this->is_dynamic = global_tensor_meta->is_dynamic();\n    this->nd_sbp =\n        JUST(SyncedSymbolMap<NdSbp>::FindOrSync(global_tensor_meta->nd_sbp(), &SyncSymbolNdSbp));\n    this->parallel_desc = JUST(SyncedSymbolMap<ParallelDesc>::FindOrSync(\n        global_tensor_meta->parallel_desc(), &SyncSymbolParallelDesc));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Check(uint64_t symbol_id, Symbol<one::GlobalTensorMeta> global_tensor_meta) {\n    CHECK_EQ_OR_RETURN(this->symbol_id, symbol_id);\n    JUST(this->shape.Check(global_tensor_meta->shape()));\n    CHECK_EQ_OR_RETURN(static_cast<DataType>(this->dtype), global_tensor_meta->dtype());  // NOLINT\n    CHECK_EQ_OR_RETURN(this->is_dynamic, global_tensor_meta->is_dynamic());               // NOLINT\n    const auto& nd_sbp = JUST(SyncedSymbolMap<NdSbp>::Symbol4SyncedSymbolId(this->nd_sbp));\n    CHECK_OR_RETURN(nd_sbp == global_tensor_meta->nd_sbp());  // NOLINT\n    const auto& parallel_desc =\n        JUST(SyncedSymbolMap<ParallelDesc>::Symbol4SyncedSymbolId(this->parallel_desc));\n    CHECK_OR_RETURN(parallel_desc == global_tensor_meta->parallel_desc());  // NOLINT\n    return Maybe<void>::Ok();\n  }\n\n  uint64_t symbol_id;\n  FlatShape shape;\n  int32_t dtype;\n  bool is_dynamic;\n  uint64_t nd_sbp;\n  uint64_t parallel_desc;\n};\n\nMaybe<void> SyncSymbolGlobalTensorMeta(uint64_t symbol_id,\n                                       Symbol<one::GlobalTensorMeta> global_tensor_meta) {\n  const auto& transport_token =\n      JUST(TransportToken::NewTransportToken(kTransportTokenTypeSyncSymbolGlobalTensorMeta));\n  const auto& recv_buffer = std::make_shared<FlatGlobalTensorMeta>();\n  NaiveAsyncTransportCtx ctx(\n      transport_token,\n      [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n        const auto& send_buffer = JUST(FlatGlobalTensorMeta::New(symbol_id, 
global_tensor_meta));\n        *buffer = send_buffer.get();\n        *size = sizeof(FlatGlobalTensorMeta);\n        *Cb = [send_buffer] {};\n        return Maybe<void>::Ok();\n      },\n      [recv_buffer](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n        *buffer = recv_buffer.get();\n        *size = sizeof(FlatGlobalTensorMeta);\n        *Cb = [recv_buffer] {};\n        return Maybe<void>::Ok();\n      });\n  const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());\n  JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));\n  JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));\n  JUST(ctx.WaitDone());\n  JUST(recv_buffer->Check(symbol_id, global_tensor_meta));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/sync_symbol_global_tensor_meta.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_GLOBAL_TENSOR_META_H_\n#define ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_GLOBAL_TENSOR_META_H_\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/core/framework/transport_token.h\"\n\nnamespace oneflow {\n\nnamespace one {\nclass GlobalTensorMeta;\n}\n\nMaybe<void> SyncSymbolGlobalTensorMeta(uint64_t symbol_id, Symbol<one::GlobalTensorMeta>);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_GLOBAL_TENSOR_META_H_\n"
  },
  {
    "path": "oneflow/core/framework/sync_symbol_nd_sbp.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/intrusive/flat_msg.h\"\n#include \"oneflow/core/framework/sync_symbol_nd_sbp.h\"\n#include \"oneflow/core/framework/rank_group_rpc_util.h\"\n#include \"oneflow/core/job/rank_group_scope.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/common/shape_vec.h\"\n#include \"oneflow/core/common/constant.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n// clang-format off\nFLAT_MSG_BEGIN(FlatSplitParallel);\n  FLAT_MSG_DEFINE_OPTIONAL(int64_t, axis);\nFLAT_MSG_END(FlatSplitParallel);\n\nFLAT_MSG_BEGIN(FlatBroadcastParallel);\nFLAT_MSG_END(FlatBroadcastParallel);\n\nFLAT_MSG_BEGIN(FlatPartialSumParallel);\nFLAT_MSG_END(FlatPartialSumParallel);\n\nFLAT_MSG_BEGIN(FlatSbpParallel);\n public:\n  Maybe<void> Init(const SbpParallel& sbp_parallel) {\n    if (sbp_parallel.has_split_parallel()) {\n      this->mutable_split_parallel()->set_axis(sbp_parallel.split_parallel().axis());\n    } else if (sbp_parallel.has_broadcast_parallel()) {\n      this->mutable_broadcast_parallel();\n    } else if (sbp_parallel.has_partial_sum_parallel()) {\n      this->mutable_partial_sum_parallel();\n    } else {\n      OF_UNIMPLEMENTED();\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Check(const SbpParallel& sbp_parallel) const {\n    if (sbp_parallel.has_split_parallel()) {\n      CHECK_EQ_OR_RETURN(this->split_parallel().axis(), sbp_parallel.split_parallel().axis());\n    } else if (sbp_parallel.has_broadcast_parallel()) {\n      CHECK_OR_RETURN(this->has_broadcast_parallel());\n    } else if (sbp_parallel.has_partial_sum_parallel()) {\n      CHECK_OR_RETURN(this->has_partial_sum_parallel());\n    } else {\n      OF_UNIMPLEMENTED();\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  FLAT_MSG_DEFINE_ONEOF(parallel_type,\n    FLAT_MSG_ONEOF_FIELD(FlatSplitParallel, split_parallel)\n    FLAT_MSG_ONEOF_FIELD(FlatBroadcastParallel, broadcast_parallel)\n    FLAT_MSG_ONEOF_FIELD(FlatPartialSumParallel, partial_sum_parallel));\nFLAT_MSG_END(FlatSbpParallel);\n\nFLAT_MSG_BEGIN(FlatNdSbp);\n public:\n  Maybe<void> Init(uint64_t symbol_id, Symbol<NdSbp> nd_sbp) {\n    this->set_symbol_id(symbol_id);\n    this->set_size(nd_sbp->sbp_parallel_size());\n    for (int i = 0; i < this->size(); ++i) {\n      const auto& sbp_parallel = nd_sbp->sbp_parallel(i);\n      JUST(this->mutable_sbp_parallel()->Mutable(i)->Init(sbp_parallel));\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Check(uint64_t symbol_id, Symbol<NdSbp> nd_sbp) const {\n    CHECK_EQ_OR_RETURN(this->symbol_id(), symbol_id);\n    CHECK_EQ_OR_RETURN(this->size(), nd_sbp->sbp_parallel_size());\n    for (int i = 0; i < this->size(); ++i) {\n      JUST(this->sbp_parallel().Get(i).Check(nd_sbp->sbp_parallel(i)));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  FLAT_MSG_DEFINE_OPTIONAL(uint64_t, symbol_id);\n  FLAT_MSG_DEFINE_OPTIONAL(size_t, size);\n  
FLAT_MSG_DEFINE_REPEATED(FlatSbpParallel, sbp_parallel, SHAPE_MAX_AXIS_SIZE);\nFLAT_MSG_END(FlatNdSbp);\n// clang-format on\n\nclass FlatNdSbpAsyncTransportCtx : public AsyncTransportCtx {\n public:\n  FlatNdSbpAsyncTransportCtx(const TransportToken& transport_token, uint64_t symbol_id,\n                             Symbol<NdSbp> nd_sbp)\n      : AsyncTransportCtx(transport_token), symbol_id_(symbol_id), nd_sbp_(nd_sbp) {}\n\n  ~FlatNdSbpAsyncTransportCtx() override {}\n\n  Maybe<void> PrepareSendBufferAndCallback(int64_t rank, void** buffer, std::size_t* size,\n                                           std::function<void()>* Callback) override {\n    const auto& flat_nd_sbp = std::make_shared<FlatNdSbp>();\n    JUST(flat_nd_sbp->Init(symbol_id_, nd_sbp_));\n    *buffer = flat_nd_sbp.get();\n    *size = sizeof(FlatNdSbp);\n    *Callback = [flat_nd_sbp]() {};\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> PrepareRecvBufferAndCallback(int64_t rank, void** buffer, std::size_t* size,\n                                           std::function<void()>* Callback) override {\n    const auto& flat_nd_sbp = std::make_shared<FlatNdSbp>();\n    *buffer = flat_nd_sbp.get();\n    *size = sizeof(FlatNdSbp);\n    *Callback = [flat_nd_sbp]() {};\n    flat_nd_sbp_ = flat_nd_sbp;\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Check() const {\n    CHECK_NOTNULL_OR_RETURN(flat_nd_sbp_.get());\n    JUST(flat_nd_sbp_->Check(symbol_id_, nd_sbp_));\n    return Maybe<void>::Ok();\n  }\n\n private:\n  uint64_t symbol_id_;\n  Symbol<NdSbp> nd_sbp_;\n  std::shared_ptr<FlatNdSbp> flat_nd_sbp_;\n};\n\n}  // namespace\n\nMaybe<void> SyncSymbolNdSbp(uint64_t symbol_id, Symbol<NdSbp> symbol) {\n  const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());\n  const auto& transport_token =\n      JUST(TransportToken::NewTransportToken(kTransportTokenTypeSyncSymbolNdSbp));\n  FlatNdSbpAsyncTransportCtx ctx(transport_token, symbol_id, symbol);\n  JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));\n  JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));\n  JUST_MSG(ctx.WaitDone(), kAsymmetricCodeErrorMsg);\n  JUST(ctx.Check());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/sync_symbol_nd_sbp.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_ND_SBP_H_\n#define ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_ND_SBP_H_\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/core/framework/transport_token.h\"\n\nnamespace oneflow {\n\nclass NdSbp;\n\nMaybe<void> SyncSymbolNdSbp(uint64_t symbol_id, Symbol<NdSbp>);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_ND_SBP_H_\n"
  },
  {
    "path": "oneflow/core/framework/sync_symbol_parallel_desc.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/sync_symbol_parallel_desc.h\"\n#include \"oneflow/core/framework/rank_group_rpc_util.h\"\n#include \"oneflow/core/job/rank_group_scope.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/common/constant.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstatic const int kLimitParallelConfString = 1024 * 64;\nstruct FlatParallelConf {\n  size_t available_size() const {\n    CHECK_GE(this->buffer_size, 0) << \"Buffer size should be non-negative\";\n    CHECK_LT(this->buffer_size, kLimitParallelConfString)\n        << \"Buffer size should be less than \" << kLimitParallelConfString;\n    return sizeof(FlatParallelConf) - kLimitParallelConfString + this->buffer_size;\n  }\n\n  size_t capacity() const { return sizeof(FlatParallelConf); }\n\n  static Maybe<FlatParallelConf> New(uint64_t symbol_id, Symbol<ParallelDesc> parallel_desc) {\n    const auto& data = std::make_shared<FlatParallelConf>();\n    JUST(data->Init(symbol_id, parallel_desc));\n    return data;\n  }\n\n  Maybe<void> Init(uint64_t symbol_id, Symbol<ParallelDesc> parallel_desc) {\n    const auto& parallel_conf = parallel_desc->parallel_conf();\n    int64_t byte_size = parallel_conf.ByteSize();\n    CHECK_LE_OR_RETURN(byte_size, kLimitParallelConfString)\n        << Error::InvalidValueError() << \"Byte size of parallel description should be less than \"\n        << kLimitParallelConfString << \", but got \" << byte_size;\n    this->symbol_id = symbol_id;\n    this->buffer_size = byte_size;\n    CHECK_OR_RETURN(parallel_conf.SerializeToArray(this->buffer, kLimitParallelConfString))\n        << Error::RuntimeError()\n        << \"Error serializing parallel description: \" << parallel_conf.ShortDebugString();\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Check(uint64_t symbol_id, Symbol<ParallelDesc> parallel_desc) const {\n    const auto& parallel_conf = parallel_desc->parallel_conf();\n    int64_t byte_size = parallel_conf.ByteSize();\n    const auto& debugString = parallel_conf.ShortDebugString();\n    CHECK_LE_OR_RETURN(byte_size, kLimitParallelConfString)\n        << Error::InvalidValueError() << \"Byte size of parallel description should be less than \"\n        << kLimitParallelConfString << \", but got \" << byte_size;\n    CHECK_EQ_OR_RETURN(this->symbol_id, symbol_id) << Error::RuntimeError() << \"expected symbol id \"\n                                                   << symbol_id << \", but got \" << this->symbol_id;\n    CHECK_EQ_OR_RETURN(this->buffer_size, byte_size)\n        << Error::RuntimeError() << \"Inconsistent parallel description: \" << debugString;\n    std::vector<char> serialized(byte_size);\n    CHECK_OR_RETURN(parallel_conf.SerializeToArray(serialized.data(), kLimitParallelConfString))\n        << Error::RuntimeError() << \"Error serializing parallel description: \" << debugString;\n    CHECK_EQ_OR_RETURN(std::memcmp(serialized.data(), 
this->buffer, byte_size), 0)\n        << Error::RuntimeError() << \"Inconsistent parallel description: \" << debugString;\n    return Maybe<void>::Ok();\n  }\n\n  uint64_t symbol_id;\n  uint64_t buffer_size;\n  char buffer[kLimitParallelConfString];\n};\n\n}  // namespace\n\nMaybe<void> SyncSymbolParallelDesc(uint64_t symbol_id, Symbol<ParallelDesc> parallel_desc) {\n  const auto& transport_token =\n      JUST(TransportToken::NewTransportToken(kTransportTokenTypeSyncSymbolParallelDesc));\n  const auto& recv_buffer = std::make_shared<FlatParallelConf>();\n  NaiveAsyncTransportCtx ctx(\n      transport_token,\n      [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n        const auto& send_buffer = JUST(FlatParallelConf::New(symbol_id, parallel_desc));\n        *buffer = send_buffer.get();\n        *size = send_buffer->available_size();\n        *Cb = [send_buffer] {};\n        return Maybe<void>::Ok();\n      },\n      [recv_buffer](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n        *buffer = recv_buffer.get();\n        *size = recv_buffer->capacity();\n        *Cb = [recv_buffer] {};\n        return Maybe<void>::Ok();\n      });\n  const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());\n  JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));\n  JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));\n  JUST_MSG(ctx.WaitDone(), kAsymmetricCodeErrorMsg);\n  JUST(recv_buffer->Check(symbol_id, parallel_desc));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/sync_symbol_parallel_desc.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_PARALLEL_DESC_H_\n#define ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_PARALLEL_DESC_H_\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/core/framework/transport_token.h\"\n\nnamespace oneflow {\n\nclass ParallelDesc;\n\nMaybe<void> SyncSymbolParallelDesc(uint64_t symbol_id, Symbol<ParallelDesc>);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_PARALLEL_DESC_H_\n"
  },
  {
    "path": "oneflow/core/framework/synced_symbol_map.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/synced_symbol_map.h\"\n\nnamespace oneflow {\n\nuint64_t GetAutoIncrementalSymbolId() {\n  static thread_local uint64_t id = 4096;\n  return id++;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/synced_symbol_map.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_SYNCED_SYMBOL_MAP_H_\n#define ONEFLOW_CORE_FRAMEWORK_SYNCED_SYMBOL_MAP_H_\n\n#include <unordered_map>\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/type_traits.h\"\n#include \"oneflow/core/job/rank_group_scope.h\"\n\nnamespace oneflow {\n\nuint64_t GetAutoIncrementalSymbolId();\n\ntemplate<typename T>\nstruct SyncedSymbolMap final {\n  template<typename SyncT>\n  static Maybe<uint64_t> FindOrSync(Symbol<T> symbol, const SyncT& Sync) {\n    auto* map = JUST(MutThreadLocalSymbol2SyncedSymbolId());\n    const auto& iter = map->find(symbol);\n    if (iter != map->end()) { return iter->second; }\n    uint64_t symbol_id = GetAutoIncrementalSymbolId();\n    JUST(Sync(symbol_id, symbol));\n    JUST(Emplace(symbol_id, symbol));\n    return symbol_id;\n  }\n\n  static Maybe<Symbol<T>> Symbol4SyncedSymbolId(uint64_t synced_symbol_id) {\n    auto* map = JUST(MutThreadLocalSyncedSymbolId2Symbol());\n    return JUST(MapAt(*map, synced_symbol_id));\n  }\n\n private:\n  static Maybe<void> Emplace(uint64_t synced_symbol_id, Symbol<T> symbol) {\n    auto* id2symbol = JUST(MutThreadLocalSyncedSymbolId2Symbol());\n    CHECK_OR_RETURN(id2symbol->emplace(synced_symbol_id, symbol).second);\n    auto* symbol2id = JUST(MutThreadLocalSymbol2SyncedSymbolId());\n    CHECK_OR_RETURN(symbol2id->emplace(symbol, synced_symbol_id).second);\n    return Maybe<void>::Ok();\n  }\n\n  static Maybe<std::unordered_map<uint64_t, Symbol<T>>*> MutThreadLocalSyncedSymbolId2Symbol() {\n    static thread_local auto* map =\n        new std::unordered_map<Symbol<RankGroup>, std::unordered_map<uint64_t, Symbol<T>>>();\n    const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());\n    return &(*map)[rank_group];\n  }\n\n  static Maybe<std::unordered_map<Symbol<T>, uint64_t>*> MutThreadLocalSymbol2SyncedSymbolId() {\n    static thread_local auto* map =\n        new std::unordered_map<Symbol<RankGroup>, std::unordered_map<Symbol<T>, uint64_t>>();\n    const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());\n    return &(*map)[rank_group];\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_SYNCED_SYMBOL_MAP_H_\n"
  },
  {
    "path": "oneflow/core/framework/tensor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/common/device_type.pb.h\"\n#include \"oneflow/core/framework/tensor_methods.h\"\n#include \"oneflow/core/framework/tensor_name_scope.h\"\n#include \"oneflow/core/framework/tensor_rpc_util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx_mgr.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/autograd/autograd_engine.h\"\n#include \"oneflow/core/framework/op_interpreter/eager_local_op_interpreter.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n\nnamespace oneflow {\n\nnamespace one {\n\nMaybe<void> Tensor::BorrowTensorName(const Tensor* other) const {\n  CHECK_OR_RETURN(other->is_lazy())\n      << Error::RuntimeError() << \"can not borrow tensor name from an eager tensor\";\n  const auto& lbn = TensorNameScope::Global()->Lookup(other);\n  CHECK_OR_RETURN(!lbn.empty()) << \"the input lazy tensor has no tensor name\";\n  TensorNameScope::Global()->Record(this, lbn);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Tensor::set_ref_tensor(const std::shared_ptr<Tensor>& ref) {\n  ref_tensor_ = ref;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Tensor::set_ref_index(const int64_t index) {\n  ref_index_ = index;\n  return Maybe<void>::Ok();\n}\n\nMaybe<LocalTensor> StaticZerosTensor::AsLocalTensor() {\n  CHECK_OR_RETURN(is_local());  // NOLINT\n  return std::dynamic_pointer_cast<LocalTensor>(\n      JUST(functional::Constant(*shape_, Scalar(0), CHECK_JUST(DType::Get(dtype_)), device_)));\n}\n\nParameter::Parameter(const std::shared_ptr<Tensor>& tensor, bool requires_grad)\n    : ProxyTensor<Parameter>(tensor) {\n  CHECK_JUST(this->tensor_->set_requires_grad(requires_grad));\n  if (tensor->is_local() && tensor->is_eager()) {\n    if (auto rematable_storage = std::dynamic_pointer_cast<vm::RematableTensorStorage>(\n            CHECK_JUST(tensor_->eager_blob_object())->tensor_storage());\n        rematable_storage != nullptr && tensor_->is_local() && tensor_->is_eager()) {\n      rematable_storage->set_eviction_disabled(true);\n    }\n  }\n}\nMaybe<void> Parameter::set_data(const std::shared_ptr<Tensor>& other) {\n  if (is_local() && is_eager()) {\n    auto rematable_storage = std::dynamic_pointer_cast<vm::RematableTensorStorage>(\n        CHECK_JUST(tensor_->eager_blob_object())->tensor_storage());\n    bool enable_remat = rematable_storage != nullptr && tensor_->is_local() && tensor_->is_eager();\n    if (enable_remat) { 
rematable_storage->set_eviction_disabled(false); }\n    JUST(tensor_->set_data(other));\n    if (enable_remat) { rematable_storage->set_eviction_disabled(true); }\n  } else {\n    JUST(tensor_->set_data(other));\n  }\n  return Maybe<void>::Ok();\n}\n\nstd::shared_ptr<Tensor> Parameter::contiguous() const {\n  const auto& tensor = std::const_pointer_cast<Tensor>(shared_from_this());\n  if (tensor_->is_contiguous()) { return tensor; }\n  return CHECK_JUST(functional::ToContiguous(tensor));\n}\n\nstd::shared_ptr<Tensor> Parameter::pin_memory() const {\n  std::shared_ptr<Tensor> tensor = std::const_pointer_cast<Tensor>(shared_from_this());\n  return CHECK_JUST(functional::PinMemory(tensor));\n}\n\n/* static */ Maybe<LocalTensor> LocalTensor::MakeTensor(const std::shared_ptr<const Shape>& shape,\n                                                        const std::shared_ptr<const Stride>& stride,\n                                                        DataType dtype, MemoryFormat memory_format,\n                                                        const Symbol<Device>& device, bool is_lazy,\n                                                        bool requires_grad, bool is_leaf) {\n  const auto& tensor_meta = SymbolOf(LocalTensorMeta(*shape, dtype, memory_format, device));\n  if (is_lazy) {\n    const auto& impl = std::make_shared<LazyLocalTensorImpl>(tensor_meta, requires_grad, is_leaf);\n    return std::make_shared<LocalTensor>(impl);\n  } else {\n    const auto& impl = std::make_shared<EagerLocalTensorImpl>(requires_grad, is_leaf);\n    const auto& dep_object = NewLocalDepObject();\n    JUST(impl->InitEagerBlobObject(tensor_meta, dep_object));\n    return std::make_shared<LocalTensor>(impl);\n  }\n}\n\nbool LocalTensor::is_cpu() const { return CHECK_JUST(device())->type() == \"cpu\"; }\nbool LocalTensor::is_cuda() const { return CHECK_JUST(device())->type() == \"cuda\"; }\n\nMaybe<Tensor> LocalTensor::detach() const {\n  std::shared_ptr<Tensor> tensor = std::make_shared<LocalTensor>(JUST(impl_->detach()));\n  if (this->is_lazy()) { JUST(tensor->BorrowTensorName(this)); }\n  return tensor;\n}\n\nstd::shared_ptr<Tensor> LocalTensor::contiguous() const {\n  std::shared_ptr<Tensor> tensor = std::const_pointer_cast<Tensor>(shared_from_this());\n  if (tensor->is_contiguous()) { return tensor; }\n  return CHECK_JUST(functional::ToContiguous(tensor));\n}\n\nstd::shared_ptr<Tensor> LocalTensor::pin_memory() const {\n  std::shared_ptr<Tensor> tensor = std::const_pointer_cast<Tensor>(shared_from_this());\n  return CHECK_JUST(functional::PinMemory(tensor));\n}\n\nMaybe<Tensor> LocalTensor::clone() const {\n  std::shared_ptr<Tensor> input = std::const_pointer_cast<Tensor>(shared_from_this());\n  const bool pin_memory = JUST(JUST(input->AsLocalTensor())->is_pinned());\n  return JUST(functional::Copy(input, JUST(this->device()), /*pin_memory=*/pin_memory));\n}\n\nMaybe<void> LocalTensor::set_data(const std::shared_ptr<Tensor>& other) {\n  CHECK_OR_RETURN(this->is_leaf()) << \"Can only set leaf tensor's data.\";\n  const auto& mirrored_tensor = std::dynamic_pointer_cast<LocalTensor>(JUST(other->detach()));\n  CHECK_NOTNULL_OR_RETURN(mirrored_tensor)\n      << \"Can not set a global tensor to the data of a local tensor\";\n  bool old_requires_grad = requires_grad();\n  impl_ = mirrored_tensor->impl_;\n  JUST(set_requires_grad(old_requires_grad));\n  grad_fn_node_ = nullptr;\n  if (other->is_lazy()) { JUST(this->BorrowTensorName(other.get())); }\n  return Maybe<void>::Ok();\n}\n\n#define 
TENSOR_OFFLOAD_CHECK(is_offloaded, msg)                    \\\n  if (is_cpu()) {                                                  \\\n    LOG(WARNING) << \"Only non-CPU tensors can be offloaded.\";      \\\n    return Maybe<void>::Ok();                                      \\\n  }                                                                \\\n  if (is_offloaded_ != is_offloaded) {                             \\\n    LOG(WARNING) << \"This tensor has already been \" << msg << \".\"; \\\n    return Maybe<void>::Ok();                                      \\\n  }\n\nMaybe<void> LocalTensor::offload() {\n  TENSOR_OFFLOAD_CHECK(false, \"offloaded\");\n\n  // Offload to CPU memory with a CPU tensor implementation.\n  int64_t device_id = JUST(this->device())->device_id();\n  std::shared_ptr<Tensor> cuda_tensor = shared_from_this();\n  auto offloaded_tensor =\n      JUST(functional::Copy(cuda_tensor, \"cpu\", device_id, /*pin_memory=*/JUST(is_pinned())));\n  JUST(vm::CurrentRankSync());\n\n  const auto& detached_tensor =\n      std::dynamic_pointer_cast<LocalTensor>(JUST(offloaded_tensor->detach()));\n  CHECK_NOTNULL_OR_RETURN(detached_tensor) << \"detached_tensor must be a local tensor.\";\n  offloaded_impl_ = detached_tensor->impl_;\n\n  // Release the CUDA memory; the metadata remains valid.\n  auto eager_blob_obj = JUST(JUST(impl_->mut_eager_local_tensor_impl())->eager_blob_object());\n  JUST(eager_blob_obj->DeallocateBlobDataPtr());\n\n  auto* vm = JUST(SingletonMaybe<VirtualMachine>());\n  JUST(vm->ShrinkAllMem());\n\n  is_offloaded_ = true;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LocalTensor::load() {\n  TENSOR_OFFLOAD_CHECK(true, \"loaded\");\n\n  // Load from CPU back to CUDA.\n  int64_t device_id = JUST(this->device())->device_id();\n  std::shared_ptr<Tensor> cpu_tensor = std::make_shared<LocalTensor>(offloaded_impl_);\n  auto loaded_tensor = JUST(functional::Copy(cpu_tensor, \"cuda\", device_id,\n                                             /*pin_memory=*/JUST(cpu_tensor->is_pinned())));\n  JUST(vm::CurrentRankSync());\n  JUST(set_data(loaded_tensor));\n\n  // Release the CPU memory.\n  cpu_tensor.reset();\n  offloaded_impl_.reset();\n  auto* vm = JUST(SingletonMaybe<VirtualMachine>());\n  JUST(vm->ShrinkAllMem());\n\n  is_offloaded_ = false;\n  return Maybe<void>::Ok();\n}\n\nstd::shared_ptr<Tensor> GlobalTensor::contiguous() const {\n  std::shared_ptr<Tensor> tensor = std::const_pointer_cast<Tensor>(shared_from_this());\n  if (tensor->is_contiguous()) { return tensor; }\n  return CHECK_JUST(functional::ToContiguous(tensor));\n}\n\nstd::shared_ptr<Tensor> GlobalTensor::pin_memory() const {\n  std::shared_ptr<Tensor> tensor = std::const_pointer_cast<Tensor>(shared_from_this());\n  return CHECK_JUST(functional::PinMemory(tensor));\n}\n\nMaybe<Tensor> GlobalTensor::clone() const {\n  std::shared_ptr<Tensor> input = std::const_pointer_cast<Tensor>(shared_from_this());\n  DisableCheckGlobalTensorMetaScope disable_meta_check{};\n  return JUST(functional::ToGlobal(input, JUST(parallel_desc()), *JUST(GetSbpList(JUST(nd_sbp()))),\n                                   /*grad_sbp_parallels=*/{}, /* sync_data */ true, /*copy=*/true));\n}\n\nMaybe<GlobalTensor> GlobalTensor::MakeTensor(const std::shared_ptr<const Shape>& shape,\n                                             DataType dtype, MemoryFormat memory_format,\n                                             Symbol<NdSbp> nd_sbp,\n                                             Symbol<ParallelDesc> parallel_desc, bool is_lazy,\n                                             bool 
requires_grad, bool is_leaf) {\n  std::shared_ptr<GlobalTensorImpl> impl;\n  Symbol<GlobalTensorMeta> global_tensor_meta(\n      GlobalTensorMeta(*shape, dtype, memory_format, nd_sbp, parallel_desc));\n  if (is_lazy) {\n    impl = std::make_shared<LazyGlobalTensorImpl>(global_tensor_meta, requires_grad, is_leaf);\n  } else {\n    impl = JUST(EagerGlobalTensorImpl::New(global_tensor_meta, requires_grad, is_leaf));\n  }\n  return std::make_shared<GlobalTensor>(impl);\n}\n\nbool GlobalTensor::is_cpu() const {\n  return CHECK_JUST(parallel_desc())->device_type() == DeviceType::kCPU;\n}\nbool GlobalTensor::is_cuda() const {\n  return CHECK_JUST(parallel_desc())->device_type() == DeviceType::kCUDA;\n}\n\nMaybe<Tensor> GlobalTensor::detach() const {\n  std::shared_ptr<Tensor> tensor = std::make_shared<GlobalTensor>(JUST(impl_->detach()));\n  if (this->is_lazy()) { JUST(tensor->BorrowTensorName(this)); }\n  return tensor;\n}\n\nMaybe<void> GlobalTensor::set_data(const std::shared_ptr<Tensor>& other) {\n  CHECK_OR_RETURN(this->is_leaf())\n      << \"Only leaf tensor's data can be set, because non-leaf tensor's data has been captured in \"\n         \"the backward graph in autograd.\";\n  const auto& global_tensor = std::dynamic_pointer_cast<GlobalTensor>(JUST(other->detach()));\n  CHECK_NOTNULL_OR_RETURN(global_tensor);  // NOLINT\n  JUST(WithConsistencyChecked(global_tensor, [&]() -> Maybe<void> { return Maybe<void>::Ok(); }));\n\n  bool old_requires_grad = requires_grad();\n  impl_ = global_tensor->impl_;\n  JUST(set_requires_grad(old_requires_grad));\n  grad_fn_node_ = nullptr;\n  if (other->is_lazy()) { JUST(this->BorrowTensorName(other.get())); }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GlobalTensor::offload() {\n  TENSOR_OFFLOAD_CHECK(false, \"offloaded\");\n\n  // Offload to CPU memory with a CPU tensor implementation.\n  std::shared_ptr<Tensor> cuda_tensor = shared_from_this();\n  auto offloaded_tensor = JUST(functional::Copy(cuda_tensor, \"cpu\", GlobalProcessCtx::LocalRank(),\n                                                /*pin_memory=*/false));\n  JUST(vm::ClusterSync());\n  const auto& detached_tensor =\n      std::dynamic_pointer_cast<GlobalTensor>(JUST(offloaded_tensor->detach()));\n  CHECK_NOTNULL_OR_RETURN(detached_tensor) << \"detached_tensor must be a global tensor.\";\n  offloaded_impl_ = detached_tensor->impl_;\n\n  // Release the CUDA memory; the metadata remains valid.\n  auto eager_blob_obj = JUST(JUST(impl_->cur_rank_phy_tensor())->eager_blob_object());\n  JUST(eager_blob_obj->DeallocateBlobDataPtr());\n\n  auto* vm = JUST(SingletonMaybe<VirtualMachine>());\n  JUST(vm->ShrinkAllMem());\n\n  is_offloaded_ = true;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GlobalTensor::load() {\n  TENSOR_OFFLOAD_CHECK(true, \"loaded\");\n\n  // Load from CPU back to CUDA.\n  std::shared_ptr<Tensor> cpu_tensor = std::make_shared<GlobalTensor>(offloaded_impl_);\n  auto loaded_tensor = JUST(functional::Copy(cpu_tensor, \"cuda\", GlobalProcessCtx::LocalRank(),\n                                             /*pin_memory=*/false));\n  JUST(vm::ClusterSync());\n  JUST(set_data(loaded_tensor));\n\n  // Release the CPU memory.\n  cpu_tensor.reset();\n  offloaded_impl_.reset();\n  auto* vm = JUST(SingletonMaybe<VirtualMachine>());\n  JUST(vm->ShrinkAllMem());\n\n  is_offloaded_ = false;\n  return Maybe<void>::Ok();\n}\n#undef TENSOR_OFFLOAD_CHECK\n\n}  // namespace one\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/tensor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_H_\n#define ONEFLOW_CORE_FRAMEWORK_TENSOR_H_\n\n#include <memory>\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/memory/memory_case.pb.h\"\n#include \"oneflow/core/framework/tensor_impl.h\"\n#include \"oneflow/core/framework/transport_token.h\"\n#include \"oneflow/core/common/error.h\"\n#include \"oneflow/core/autograd/autograd_engine.h\"\n#include \"oneflow/core/job/global_mode.h\"\n\nnamespace oneflow {\n\nclass NdSbp;\nclass Device;\n\nnamespace one {\n\nclass FunctionNode;\n\nclass GlobalTensor;\nclass LocalTensor;\n\nclass Tensor : public std::enable_shared_from_this<Tensor> {\n public:\n  virtual ~Tensor() = default;\n\n  // Getters\n  int64_t dim(int64_t index) const { return shape()->At(index); }\n  int64_t nelement() const { return shape()->elem_cnt(); }\n  int64_t ndim() const { return shape()->NumAxes(); }\n  Maybe<Tensor> ref_tensor() const { return ref_tensor_.lock(); }\n  int64_t ref_index() const { return ref_index_; }\n\n  virtual std::shared_ptr<const Shape> shape() const = 0;\n  virtual Symbol<DType> dtype() const = 0;\n  virtual Maybe<TransportToken> transport_token() const = 0;\n  virtual Maybe<Symbol<NdSbp>> nd_sbp() const = 0;\n  virtual Maybe<Symbol<ParallelDesc>> parallel_desc() const = 0;\n  virtual Maybe<Symbol<Device>> device() const = 0;\n  virtual Maybe<Symbol<Device>*> mut_device() = 0;\n  virtual bool is_cpu() const = 0;\n  virtual bool is_cuda() const = 0;\n  virtual bool is_global() const = 0;\n  virtual bool is_local() const { return !is_global(); }\n  virtual bool is_lazy() const = 0;\n  virtual bool is_eager() const { return !is_lazy(); }\n  virtual bool is_contiguous() const = 0;\n  virtual bool is_view() const = 0;\n  virtual Maybe<bool> is_pinned() const = 0;\n  virtual const TensorMeta& tensor_meta() const = 0;\n  virtual Maybe<Tensor> data() = 0;\n  virtual std::shared_ptr<Tensor> pin_memory() const = 0;\n  virtual Maybe<Symbol<LocalTensorMeta>> local_tensor_meta() const { OF_UNIMPLEMENTED(); }\n  virtual Maybe<Symbol<GlobalTensorMeta>> global_tensor_meta() const { OF_UNIMPLEMENTED(); }\n\n  // Getters valid only for EagerLocalTensor\n  virtual Maybe<EagerLocalTensorImpl*> mut_eager_local_tensor_impl() { OF_UNIMPLEMENTED(); }\n  virtual Maybe<vm::EagerBlobObject> eager_blob_object() const = 0;\n  virtual Maybe<LocalDepObject*> compute_local_dep_object() const = 0;\n  virtual Maybe<bool> has_eager_blob_object() const = 0;\n  virtual Maybe<TensorStorage> tensor_storage() const { OF_UNIMPLEMENTED(); }\n  virtual Maybe<const Stride> stride() const { OF_UNIMPLEMENTED(); }\n  virtual Maybe<int64_t> storage_offset() const { OF_UNIMPLEMENTED(); }\n  virtual MemoryFormat memory_format() const = 0;\n\n  // Getters/Setters valid only for EagerGlobalTensor\n  
virtual Maybe<const Optional<Symbol<NdSbp>>&> consumer_nd_sbp_constraint() const {\n    OF_UNIMPLEMENTED();\n  }\n  virtual Maybe<LocalTensor> cur_rank_phy_tensor() const { OF_UNIMPLEMENTED(); }\n  virtual Maybe<void> set_consumer_nd_sbp_constraint(const Optional<Symbol<NdSbp>>& val) {\n    OF_UNIMPLEMENTED();\n  }\n\n  // Getters for autograd\n  virtual bool requires_grad() const = 0;\n  virtual bool is_leaf() const = 0;\n  virtual bool retain_grad() const = 0;\n  virtual std::shared_ptr<const FunctionNode> grad_fn_node() const = 0;\n  virtual int32_t get_grad_fn_output_index() const = 0;\n  virtual Maybe<Tensor> acc_grad() const = 0;\n  virtual Maybe<TensorArg> current_grad() const = 0;\n  virtual Maybe<Tensor> detach() const = 0;\n  virtual Maybe<Tensor> clone() const = 0;\n  virtual std::shared_ptr<Tensor> contiguous() const = 0;\n\n  // Setters for autograd\n  virtual Maybe<void> set_requires_grad(bool requires_grad) = 0;\n  virtual Maybe<void> set_retain_grad(bool retain_grad) = 0;\n  virtual void set_grad_fn_node(const std::shared_ptr<FunctionNode>& grad_fn_node) = 0;\n  virtual std::shared_ptr<FunctionNode> mut_grad_fn_node() = 0;\n  virtual void set_grad_fn_output_index(int32_t idx) = 0;\n  virtual Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad) = 0;\n  virtual Maybe<Tensor> mut_acc_grad() = 0;\n  virtual void set_is_leaf(bool is_leaf) = 0;\n  virtual std::shared_ptr<const AutogradMeta> autograd_meta() const = 0;\n  virtual std::shared_ptr<AutogradMeta> mut_autograd_meta() = 0;\n  virtual void set_autograd_meta(const std::shared_ptr<AutogradMeta>& autograd_meta) = 0;\n\n  virtual user_op::TensorDesc* mut_tensor_meta() = 0;\n  virtual Maybe<void> set_data(const std::shared_ptr<Tensor>& other) = 0;\n\n  // For offloading between devices\n  virtual Maybe<void> offload() = 0;\n  virtual Maybe<void> load() = 0;\n  virtual Maybe<bool> is_offloaded() const = 0;\n\n  virtual Maybe<void> RegisterStorageDeleteHook(const std::function<void()>& hook) {\n    OF_UNIMPLEMENTED();\n  }\n  virtual Maybe<LocalTensor> AsLocalTensor() = 0;\n  virtual Maybe<GlobalTensor> AsGlobalTensor() = 0;\n\n  Maybe<void> BorrowTensorName(const Tensor* other) const;\n  Maybe<void> set_ref_tensor(const std::shared_ptr<Tensor>& ref);\n  Maybe<void> set_ref_index(const int64_t index);\n\n  // The same tensor instance should share the python object to ensure that\n  // their ids are consistent in Python. 
That is, if x and y hold the same tensor,\n  // then `id(x)` should equal `id(y)`.\n  void* pyobject() const { return pyobj_ptr_.get(); }\n  void set_pyobject_ptr(std::unique_ptr<void, void (*)(void*)>&& pyobj_ptr) {\n    pyobj_ptr_ = std::move(pyobj_ptr);\n  }\n  bool owns_pyobj() const { return owns_pyobj_; }\n  void set_owns_pyobj(bool owns_pyobj) { owns_pyobj_ = owns_pyobj; }\n\n protected:\n  Tensor()\n      : pyobj_ptr_(nullptr, [](void*) {}),\n        owns_pyobj_(false),\n        ref_tensor_(std::weak_ptr<Tensor>()),\n        ref_index_(0) {}\n\n private:\n  std::unique_ptr<void, void (*)(void*)> pyobj_ptr_;\n  bool owns_pyobj_;\n  std::weak_ptr<Tensor> ref_tensor_;\n  int64_t ref_index_;\n};\n\nclass StaticZerosTensor final : public Tensor {\n public:\n  static Maybe<StaticZerosTensor> MakeTensor(const std::shared_ptr<const Shape>& shape,\n                                             DataType dtype, MemoryFormat memory_format,\n                                             Symbol<Device> device) {\n    return std::shared_ptr<StaticZerosTensor>(\n        new StaticZerosTensor(shape, dtype, memory_format, device));\n  }\n  // Getters\n  std::shared_ptr<const Shape> shape() const override { return shape_; }\n  Symbol<DType> dtype() const override { return CHECK_JUST(DType::Get(dtype_)); }\n  Maybe<TransportToken> transport_token() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<Symbol<NdSbp>> nd_sbp() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<Symbol<ParallelDesc>> parallel_desc() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<Symbol<Device>> device() const override { return device_; }\n  Maybe<Symbol<Device>*> mut_device() override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  bool is_cpu() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return false;\n  }\n  bool is_cuda() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return false;\n  }\n  bool is_global() const override { return false; }\n  bool is_local() const override { return !is_global(); }\n  bool is_lazy() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return false;\n  }\n  bool is_eager() const override { return !is_lazy(); }\n  const TensorMeta& tensor_meta() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return *(TensorMeta*)nullptr;\n  }\n  Maybe<Tensor> data() override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  std::shared_ptr<Tensor> pin_memory() const override {\n    return std::const_pointer_cast<Tensor>(shared_from_this());\n  }\n  Maybe<Symbol<LocalTensorMeta>> local_tensor_meta() const override {\n    RETURN_ERROR_WITH_BUG_PROMPT();\n  }\n  Maybe<Symbol<GlobalTensorMeta>> global_tensor_meta() const override {\n    RETURN_ERROR_WITH_BUG_PROMPT();\n  }\n\n  // Getters valid only for EagerLocalTensor\n  Maybe<EagerLocalTensorImpl*> mut_eager_local_tensor_impl() override {\n    RETURN_ERROR_WITH_BUG_PROMPT();\n  }\n  Maybe<vm::EagerBlobObject> eager_blob_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<LocalDepObject*> compute_local_dep_object() const override {\n    RETURN_ERROR_WITH_BUG_PROMPT();\n  }\n  Maybe<bool> has_eager_blob_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<TensorStorage> tensor_storage() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<const Stride> stride() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<int64_t> storage_offset() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  MemoryFormat memory_format() const override { return memory_format_; }\n\n  
// Getters/Setters valid only for EagerGlobalTensor\n  Maybe<const Optional<Symbol<NdSbp>>&> consumer_nd_sbp_constraint() const override {\n    RETURN_ERROR_WITH_BUG_PROMPT();\n  }\n  Maybe<LocalTensor> cur_rank_phy_tensor() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<void> set_consumer_nd_sbp_constraint(const Optional<Symbol<NdSbp>>& val) override {\n    RETURN_ERROR_WITH_BUG_PROMPT();\n  }\n\n  // Getters for autograd\n  bool requires_grad() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return false;\n  }\n  bool is_leaf() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return false;\n  }\n  bool retain_grad() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return false;\n  }\n  bool is_contiguous() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return true;\n  }\n  bool is_view() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return false;\n  }\n  Maybe<bool> is_pinned() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  std::shared_ptr<const FunctionNode> grad_fn_node() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return nullptr;\n  }\n  int32_t get_grad_fn_output_index() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return 0;\n  }\n  Maybe<Tensor> acc_grad() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<TensorArg> current_grad() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<Tensor> detach() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<Tensor> clone() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  std::shared_ptr<Tensor> contiguous() const override {\n    return std::const_pointer_cast<Tensor>(shared_from_this());\n  }\n  // Setters for autograd\n  Maybe<void> set_requires_grad(bool requires_grad) override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> set_retain_grad(bool retain_grad) override {\n    RETURN_ERROR_WITH_BUG_PROMPT();\n    return Maybe<void>::Ok();\n  }\n  void set_grad_fn_node(const std::shared_ptr<FunctionNode>& grad_fn_node) override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n  }\n  void set_grad_fn_output_index(int32_t idx) override { PRINT_BUG_PROMPT_AND_ABORT(); }\n  std::shared_ptr<FunctionNode> mut_grad_fn_node() override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return *(std::shared_ptr<FunctionNode>*)nullptr;\n  }\n  Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad) override {\n    RETURN_ERROR_WITH_BUG_PROMPT();\n  }\n  Maybe<Tensor> mut_acc_grad() override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  void set_is_leaf(bool is_leaf) override { PRINT_BUG_PROMPT_AND_ABORT(); }\n  std::shared_ptr<const AutogradMeta> autograd_meta() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n  }\n  std::shared_ptr<AutogradMeta> mut_autograd_meta() override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return nullptr;\n  }\n  void set_autograd_meta(const std::shared_ptr<AutogradMeta>& autograd_meta) override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n  }\n\n  user_op::TensorDesc* mut_tensor_meta() override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return nullptr;\n  }\n  Maybe<void> set_data(const std::shared_ptr<Tensor>& other) override {\n    RETURN_ERROR_WITH_BUG_PROMPT();\n  }\n\n  Maybe<void> offload() override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<void> load() override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<bool> is_offloaded() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n\n  Maybe<LocalTensor> AsLocalTensor() override;\n  Maybe<GlobalTensor> AsGlobalTensor() override { 
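/* a StaticZerosTensor is always local, never global */ 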
RETURN_ERROR_WITH_BUG_PROMPT(); }\n\n private:\n  StaticZerosTensor(const std::shared_ptr<const Shape>& shape, DataType dtype,\n                    MemoryFormat memory_format, Symbol<Device> device)\n      : shape_(shape), dtype_(dtype), memory_format_(memory_format), device_(device) {}\n  const std::shared_ptr<const Shape> shape_;\n  DataType dtype_;\n  MemoryFormat memory_format_;\n  Symbol<Device> device_;\n};\n\ntemplate<typename DerivedT>\nclass TensorIf : public Tensor {\n public:\n  virtual ~TensorIf() = default;\n\n  // Getters for autograd\n  // acc_grad is tensor's accumulated grad in more than once backward operation,\n  // and current_grad is temporary grad to shared data with different FunctionNode\n  std::shared_ptr<const FunctionNode> grad_fn_node() const override { return grad_fn_node_; }\n  int32_t get_grad_fn_output_index() const override { return grad_fn_output_index_; }\n\n  // Setters for autograd\n  void set_grad_fn_node(const std::shared_ptr<FunctionNode>& grad_fn_node) override {\n    grad_fn_node_ = grad_fn_node;\n  }\n  std::shared_ptr<FunctionNode> mut_grad_fn_node() override { return grad_fn_node_; }\n  void set_grad_fn_output_index(int32_t idx) override { grad_fn_output_index_ = idx; }\n\n protected:\n  TensorIf() = default;\n  std::shared_ptr<FunctionNode> grad_fn_node_;\n  int32_t grad_fn_output_index_ = -1;\n};\n\ntemplate<typename DerivedT>\nclass ProxyTensor : public TensorIf<DerivedT> {\n public:\n  ProxyTensor(const std::shared_ptr<Tensor>& tensor) : tensor_(tensor) {\n    if (tensor->is_lazy()) { CHECK_JUST(this->BorrowTensorName(tensor.get())); }\n  }\n  virtual ~ProxyTensor() = default;\n\n  virtual std::shared_ptr<const Shape> shape() const override { return tensor_->shape(); }\n  virtual Symbol<DType> dtype() const override { return tensor_->dtype(); }\n  virtual Maybe<Symbol<NdSbp>> nd_sbp() const override { return tensor_->nd_sbp(); }\n  virtual Maybe<Symbol<ParallelDesc>> parallel_desc() const override {\n    return tensor_->parallel_desc();\n  }\n  virtual Maybe<Symbol<Device>> device() const override { return tensor_->device(); }\n  virtual Maybe<Symbol<Device>*> mut_device() override { return tensor_->mut_device(); }\n  virtual bool is_cpu() const override { return tensor_->is_cpu(); }\n  virtual bool is_cuda() const override { return tensor_->is_cuda(); }\n  virtual bool is_global() const override { return tensor_->is_global(); }\n  virtual bool is_local() const override { return tensor_->is_local(); }\n  virtual bool is_lazy() const override { return tensor_->is_lazy(); }\n  virtual bool is_eager() const override { return tensor_->is_eager(); }\n  virtual const TensorMeta& tensor_meta() const override { return tensor_->tensor_meta(); }\n  virtual Maybe<Symbol<LocalTensorMeta>> local_tensor_meta() const override {\n    return tensor_->local_tensor_meta();\n  }\n  virtual Maybe<Symbol<GlobalTensorMeta>> global_tensor_meta() const override {\n    return tensor_->global_tensor_meta();\n  }\n  virtual Maybe<Tensor> data() override { return tensor_->detach(); }\n  virtual std::shared_ptr<Tensor> pin_memory() const override { return tensor_->pin_memory(); }\n\n  // Must override grad_fn_node function. 
Otherwise grad_fn will belong to this proxy rather than to tensor_,\n  // and it will be wrong when Tensor.data() is used in operators.\n  virtual std::shared_ptr<const FunctionNode> grad_fn_node() const override {\n    return tensor_->grad_fn_node();\n  }\n  virtual void set_grad_fn_node(const std::shared_ptr<FunctionNode>& grad_fn_node) override {\n    tensor_->set_grad_fn_node(grad_fn_node);\n  }\n  virtual std::shared_ptr<FunctionNode> mut_grad_fn_node() override {\n    return tensor_->mut_grad_fn_node();\n  }\n\n  virtual Maybe<EagerLocalTensorImpl*> mut_eager_local_tensor_impl() override {\n    return tensor_->mut_eager_local_tensor_impl();\n  }\n  virtual Maybe<vm::EagerBlobObject> eager_blob_object() const override {\n    return tensor_->eager_blob_object();\n  }\n  virtual Maybe<LocalDepObject*> compute_local_dep_object() const override {\n    return tensor_->compute_local_dep_object();\n  }\n  virtual Maybe<bool> has_eager_blob_object() const override {\n    return tensor_->has_eager_blob_object();\n  }\n  virtual Maybe<TensorStorage> tensor_storage() const override { return tensor_->tensor_storage(); }\n  virtual Maybe<const Stride> stride() const override { return tensor_->stride(); }\n  virtual Maybe<int64_t> storage_offset() const override { return tensor_->storage_offset(); }\n  virtual MemoryFormat memory_format() const override { return tensor_->memory_format(); }\n\n  virtual Maybe<const Optional<Symbol<NdSbp>>&> consumer_nd_sbp_constraint() const override {\n    return tensor_->consumer_nd_sbp_constraint();\n  }\n  virtual Maybe<TransportToken> transport_token() const override {\n    return tensor_->transport_token();\n  }\n  virtual Maybe<LocalTensor> cur_rank_phy_tensor() const override {\n    return tensor_->cur_rank_phy_tensor();\n  }\n  virtual Maybe<void> set_consumer_nd_sbp_constraint(const Optional<Symbol<NdSbp>>& val) override {\n    return tensor_->set_consumer_nd_sbp_constraint(val);\n  }\n\n  virtual bool requires_grad() const override { return tensor_->requires_grad(); }\n  virtual bool is_leaf() const override { return tensor_->is_leaf(); }\n  virtual bool retain_grad() const override { return tensor_->retain_grad(); }\n  virtual bool is_contiguous() const override { return tensor_->is_contiguous(); }\n  virtual bool is_view() const override { return tensor_->is_view(); }\n  virtual Maybe<bool> is_pinned() const override { return tensor_->is_pinned(); }\n  virtual Maybe<Tensor> acc_grad() const override { return tensor_->acc_grad(); }\n  virtual Maybe<TensorArg> current_grad() const override { return tensor_->current_grad(); }\n  virtual Maybe<Tensor> detach() const override { return tensor_->detach(); }\n  virtual Maybe<Tensor> clone() const override { return tensor_->clone(); }\n\n  virtual Maybe<void> set_requires_grad(bool requires_grad) override {\n    return tensor_->set_requires_grad(requires_grad);\n  }\n  virtual Maybe<void> set_retain_grad(bool retain_grad) override {\n    return tensor_->set_retain_grad(retain_grad);\n  }\n  virtual Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad) override {\n    return tensor_->set_acc_grad(grad);\n  }\n  virtual Maybe<Tensor> mut_acc_grad() override { return tensor_->mut_acc_grad(); }\n  virtual void set_is_leaf(bool is_leaf) override { return tensor_->set_is_leaf(is_leaf); }\n  virtual std::shared_ptr<const AutogradMeta> autograd_meta() const override {\n    return tensor_->autograd_meta();\n  }\n  virtual std::shared_ptr<AutogradMeta> mut_autograd_meta() override {\n    return tensor_->mut_autograd_meta();\n  }\n  
virtual void set_autograd_meta(const std::shared_ptr<AutogradMeta>& autograd_meta) override {\n    return tensor_->set_autograd_meta(autograd_meta);\n  }\n\n  virtual user_op::TensorDesc* mut_tensor_meta() override { return tensor_->mut_tensor_meta(); }\n  virtual Maybe<void> set_data(const std::shared_ptr<Tensor>& other) override {\n    bool old_requires_grad = tensor_->requires_grad();\n    this->tensor_ = JUST(other->detach());\n    JUST(this->tensor_->set_requires_grad(old_requires_grad));\n    if (other->is_lazy()) { JUST(this->BorrowTensorName(other.get())); }\n    return Maybe<void>::Ok();\n  }\n\n  virtual Maybe<void> offload() override {\n    JUST(tensor_->offload());\n    return Maybe<void>::Ok();\n  }\n  virtual Maybe<void> load() override {\n    JUST(tensor_->load());\n    return Maybe<void>::Ok();\n  }\n  Maybe<bool> is_offloaded() const override { return JUST(tensor_->is_offloaded()); }\n\n  virtual Maybe<LocalTensor> AsLocalTensor() override {\n    if (const auto& local_tensor = std::dynamic_pointer_cast<LocalTensor>(tensor_)) {\n      return local_tensor;\n    }\n    RETURN_ERROR_WITH_BUG_PROMPT();\n  }\n\n  virtual Maybe<GlobalTensor> AsGlobalTensor() override {\n    if (const auto& global_tensor = std::dynamic_pointer_cast<GlobalTensor>(tensor_)) {\n      return global_tensor;\n    }\n    RETURN_ERROR_WITH_BUG_PROMPT();\n  }\n\n protected:\n  std::shared_ptr<Tensor> tensor_;\n};\n\nclass Parameter final : public ProxyTensor<Parameter> {\n public:\n  static Maybe<Parameter> MakeTensor(const std::shared_ptr<Tensor>& tensor, bool requires_grad) {\n    return std::shared_ptr<Parameter>(new Parameter(JUST(tensor->detach()), requires_grad));\n  }\n  bool is_leaf() const override { return true; }\n  std::shared_ptr<Tensor> contiguous() const override;\n  std::shared_ptr<Tensor> pin_memory() const override;\n  Maybe<void> set_data(const std::shared_ptr<Tensor>& other) override;\n\n private:\n  Parameter(const std::shared_ptr<Tensor>& tensor, bool requires_grad);\n};\n\nclass LocalTensor final : public TensorIf<LocalTensor> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LocalTensor);\n  LocalTensor() = default;\n  explicit LocalTensor(const std::shared_ptr<LocalTensorImpl>& impl) { impl_ = impl; }\n  ~LocalTensor() override = default;\n\n  // Getters\n  std::shared_ptr<const Shape> shape() const override { return impl_->shape(); }\n  Symbol<DType> dtype() const override { return CHECK_JUST(DType::Get(impl_->dtype())); }\n  Maybe<TransportToken> transport_token() const override {\n    OF_RUNTIME_ERROR() << \"Only global tensors have 'global_id'; the global id is used to \"\n                          \"synchronize ranks\";\n  }\n  Maybe<Symbol<NdSbp>> nd_sbp() const override {\n    OF_RUNTIME_ERROR()\n        << \"Local tensor has no sbp property. \"\n           \"sbp describes the distribution of a tensor in the oneflow distributed case; you can \"\n           \"refer to https://docs.oneflow.org/master/parallelism/03_global_tensor.html. \"\n           \"For example, create a global tensor like this: 'x = oneflow.tensor((2, 3), \"\n           \"placement=oneflow.placement(\\\"cuda\\\", {0: 0}), sbp=oneflow.sbp.broadcast)', then \"\n           \"'x.sbp' is 'oneflow.sbp.broadcast'\";\n  }\n  Maybe<Symbol<ParallelDesc>> parallel_desc() const override {\n    OF_RUNTIME_ERROR() << \"Only global tensors have 'placement'. Placement is used to describe \"\n                          \"the distribution of a global tensor across multiple GPUs. 
Please use \"\n                          \"'.device' for local tensors.\";\n  }\n  Maybe<Symbol<Device>> device() const override { return impl_->device(); }\n  Maybe<Symbol<Device>*> mut_device() override { return impl_->mut_device(); }\n  bool is_lazy() const override { return impl_->is_lazy(); }\n  bool is_global() const override { return false; }\n  bool is_cpu() const override;\n  bool is_cuda() const override;\n  std::shared_ptr<Tensor> contiguous() const override;\n\n  const TensorMeta& tensor_meta() const override { return *impl_->tensor_meta(); }\n  Maybe<Tensor> data() override { return this->detach(); }\n  std::shared_ptr<Tensor> pin_memory() const override;\n\n  // Getters valid only for EagerLocalTensor\n  Maybe<vm::EagerBlobObject> eager_blob_object() const override {\n    return impl_->eager_blob_object();\n  }\n  Maybe<LocalDepObject*> compute_local_dep_object() const override {\n    return impl_->compute_local_dep_object();\n  }\n  Maybe<TensorStorage> tensor_storage() const override { return impl_->tensor_storage(); }\n  Maybe<bool> has_eager_blob_object() const override { return impl_->has_eager_blob_object(); }\n  Maybe<const Stride> stride() const override { return impl_->stride(); }\n  Maybe<int64_t> storage_offset() const override { return impl_->storage_offset(); }\n  MemoryFormat memory_format() const override { return impl_->memory_format(); }\n\n  // Getters for autograd\n  Maybe<Tensor> acc_grad() const override { return impl_->acc_grad(); }\n  Maybe<TensorArg> current_grad() const override { return impl_->current_grad(); }\n  bool requires_grad() const override { return impl_->requires_grad(); }\n  bool is_leaf() const override { return impl_->is_leaf(); }\n  bool retain_grad() const override { return impl_->retain_grad(); }\n  bool is_contiguous() const override { return impl_->is_contiguous(); }\n  bool is_view() const override { return impl_->is_view(); }\n  Maybe<bool> is_pinned() const override { return impl_->is_pinned(); };\n\n  Maybe<Symbol<LocalTensorMeta>> local_tensor_meta() const override { return impl_->tensor_meta(); }\n\n  // Setters for autograd\n  Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad) override {\n    if (!grad_fn_node_ && requires_grad()) {\n      CHECK_OR_RETURN(is_leaf()) << \"only leaf tensor may have no grad_fn\";\n      AddAccumulateFunctionNode(shared_from_this());\n    }\n    return impl_->set_acc_grad(grad);\n  }\n  Maybe<void> set_requires_grad(bool requires_grad) override {\n    JUST(impl_->set_requires_grad(requires_grad));\n    if (!requires_grad) { set_grad_fn_node(nullptr); }\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> set_retain_grad(bool retain_grad) override {\n    return impl_->set_retain_grad(retain_grad);\n  }\n  Maybe<Tensor> mut_acc_grad() override { return impl_->mut_acc_grad(); }\n  void set_is_leaf(bool is_leaf) override { impl_->set_is_leaf(is_leaf); }\n  std::shared_ptr<const AutogradMeta> autograd_meta() const override {\n    return impl_->autograd_meta();\n  }\n  std::shared_ptr<AutogradMeta> mut_autograd_meta() override { return impl_->mut_autograd_meta(); }\n  void set_autograd_meta(const std::shared_ptr<AutogradMeta>& autograd_meta) override {\n    impl_->set_autograd_meta(autograd_meta);\n  }\n\n  // Operators for tensor\n  Maybe<Tensor> detach() const override;\n  Maybe<Tensor> clone() const override;\n\n  static Maybe<LocalTensor> MakeTensor(const std::shared_ptr<const Shape>& shape,\n                                       const std::shared_ptr<const Stride>& stride, DataType 
dtype,\n                                       MemoryFormat memory_format, const Symbol<Device>& device,\n                                       bool is_lazy, bool requires_grad, bool is_leaf);\n  LocalTensorImpl* mut_impl() { return impl_.get(); }\n  Maybe<EagerLocalTensorImpl*> mut_eager_local_tensor_impl() override {\n    return impl_->mut_eager_local_tensor_impl();\n  }\n  user_op::TensorDesc* mut_tensor_meta() override {\n    return std::const_pointer_cast<MutLocalTensorMeta>(impl_->mut_tensor_meta()).get();\n  }\n  Maybe<void> set_data(const std::shared_ptr<Tensor>& other) override;\n\n  Maybe<void> offload() override;\n  Maybe<void> load() override;\n  Maybe<bool> is_offloaded() const override { return is_offloaded_; }\n\n  Maybe<void> set_impl(std::shared_ptr<LocalTensorImpl> impl) {\n    impl_ = impl;\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> RegisterStorageDeleteHook(const std::function<void()>& hook) override {\n    return impl_->RegisterStorageDeleteHook(hook);\n  }\n\n  Maybe<LocalTensor> AsLocalTensor() override {\n    return std::dynamic_pointer_cast<LocalTensor>(shared_from_this());\n  }\n  Maybe<GlobalTensor> AsGlobalTensor() override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n\n private:\n  std::shared_ptr<LocalTensorImpl> impl_;\n  std::shared_ptr<LocalTensorImpl> offloaded_impl_;\n  bool is_offloaded_{false};\n};\n\nclass GlobalTensor final : public TensorIf<GlobalTensor> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(GlobalTensor);\n  GlobalTensor() = default;\n  explicit GlobalTensor(const std::shared_ptr<GlobalTensorImpl>& impl) { impl_ = impl; }\n  ~GlobalTensor() override = default;\n\n  // Getters\n  std::shared_ptr<const Shape> shape() const override { return impl_->shape(); }\n  Symbol<DType> dtype() const override { return CHECK_JUST(DType::Get(impl_->dtype())); }\n  Maybe<TransportToken> transport_token() const override { return impl_->transport_token(); }\n  Maybe<Symbol<NdSbp>> nd_sbp() const override { return impl_->nd_sbp(); }\n  Maybe<Symbol<ParallelDesc>> parallel_desc() const override { return impl_->parallel_desc(); }\n  Maybe<Symbol<Device>> device() const override {\n    if (GlobalMode::is_enabled()) {\n      auto global_mode_guard = GlobalMode::Guard(false);\n      const auto& device_tag = JUST(parallel_desc())->device_tag();\n      return JUST(Device::New(device_tag));\n    }\n    OF_RUNTIME_ERROR() << \"Only local tensors have 'device'. 
Please use \"\n                          \"'.placement' for global tensors.\";\n  }\n  Maybe<Symbol<Device>*> mut_device() override {\n    OF_RUNTIME_ERROR() << \"GlobalTensor has no mut_device property\";\n  }\n  bool is_lazy() const override { return impl_->is_lazy(); }\n  bool is_global() const override { return true; }\n  Maybe<const Optional<Symbol<NdSbp>>&> consumer_nd_sbp_constraint() const override {\n    return impl_->consumer_nd_sbp_constraint();\n  }\n  Maybe<LocalTensor> cur_rank_phy_tensor() const override { return impl_->cur_rank_phy_tensor(); }\n  bool is_cpu() const override;\n  bool is_cuda() const override;\n  std::shared_ptr<Tensor> contiguous() const override;\n  Maybe<Tensor> data() override { return this->detach(); }\n  Maybe<const Stride> stride() const override { return impl_->stride(); }\n  MemoryFormat memory_format() const override { return impl_->memory_format(); }\n  std::shared_ptr<Tensor> pin_memory() const override;\n\n  // Getters valid only for EagerLocalTensor\n  Maybe<vm::EagerBlobObject> eager_blob_object() const override {\n    return impl_->eager_blob_object();\n  }\n  Maybe<LocalDepObject*> compute_local_dep_object() const override {\n    return impl_->compute_local_dep_object();\n  }\n  const TensorMeta& tensor_meta() const override { return *impl_->tensor_meta(); }\n  Maybe<TensorStorage> tensor_storage() const override { return impl_->tensor_storage(); }\n  Maybe<bool> has_eager_blob_object() const override { return impl_->has_eager_blob_object(); }\n\n  // Setters\n  Maybe<void> set_consumer_nd_sbp_constraint(const Optional<Symbol<NdSbp>>& val) override {\n    impl_->set_consumer_nd_sbp_constraint(val);\n    return Maybe<void>::Ok();\n  }\n\n  // Getters for autograd\n  Maybe<Tensor> acc_grad() const override { return impl_->acc_grad(); }\n  Maybe<TensorArg> current_grad() const override { return impl_->current_grad(); }\n  bool requires_grad() const override { return impl_->requires_grad(); }\n  bool is_leaf() const override { return impl_->is_leaf(); }\n  bool retain_grad() const override { return impl_->retain_grad(); }\n  bool is_contiguous() const override { return impl_->is_contiguous(); }\n  bool is_view() const override { return impl_->is_view(); }\n  Maybe<bool> is_pinned() const override {\n    OF_RUNTIME_ERROR() << \"Global tensor has no is_pinned method\";\n  }\n\n  // Setters for autograd\n  Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad) override {\n    if (!grad_fn_node_ && requires_grad()) {\n      CHECK_OR_RETURN(is_leaf()) << \"only leaf tensor may have no grad_fn\";\n      AddAccumulateFunctionNode(shared_from_this());\n    }\n    return impl_->set_acc_grad(grad);\n  }\n  Maybe<Tensor> mut_acc_grad() override { return impl_->mut_acc_grad(); }\n  Maybe<void> set_requires_grad(bool requires_grad) override {\n    JUST(impl_->set_requires_grad(requires_grad));\n    if (!requires_grad) { set_grad_fn_node(nullptr); }\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> set_retain_grad(bool retain_grad) override {\n    return impl_->set_retain_grad(retain_grad);\n  }\n  void set_is_leaf(bool is_leaf) override { impl_->set_is_leaf(is_leaf); }\n  std::shared_ptr<const AutogradMeta> autograd_meta() const override {\n    return impl_->autograd_meta();\n  }\n  std::shared_ptr<AutogradMeta> mut_autograd_meta() override { return impl_->mut_autograd_meta(); }\n  void set_autograd_meta(const std::shared_ptr<AutogradMeta>& autograd_meta) override {\n    impl_->set_autograd_meta(autograd_meta);\n  }\n\n  // Operators for tensor\n  
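// Note: detach() returns a tensor that shares storage with this tensor but is cut off from\n  // the autograd graph (see EagerGlobalTensorImpl::detach in tensor_impl.cpp).\n  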
Maybe<Tensor> detach() const override;\n  Maybe<Tensor> clone() const override;\n\n  static Maybe<GlobalTensor> MakeTensor(const std::shared_ptr<const Shape>& shape, DataType dtype,\n                                        MemoryFormat memory_format, Symbol<NdSbp> nd_sbp,\n                                        Symbol<ParallelDesc> parallel_desc, bool is_lazy,\n                                        bool requires_grad, bool is_leaf);\n\n  GlobalTensorImpl* mut_impl() { return impl_.get(); }\n\n  Maybe<Symbol<GlobalTensorMeta>> global_tensor_meta() const override {\n    return impl_->tensor_meta();\n  }\n\n  user_op::TensorDesc* mut_tensor_meta() override { return impl_->mut_tensor_meta(); }\n  Maybe<void> set_data(const std::shared_ptr<Tensor>& other) override;\n\n  Maybe<void> offload() override;\n  Maybe<void> load() override;\n  Maybe<bool> is_offloaded() const override { return is_offloaded_; }\n\n  Maybe<LocalTensor> AsLocalTensor() override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<GlobalTensor> AsGlobalTensor() override {\n    return std::dynamic_pointer_cast<GlobalTensor>(shared_from_this());\n  }\n\n private:\n  std::shared_ptr<GlobalTensorImpl> impl_;\n  std::shared_ptr<GlobalTensorImpl> offloaded_impl_;\n  bool is_offloaded_{false};\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_TENSOR_H_\n"
  },
  {
    "path": "oneflow/core/framework/tensor_arg.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/tensor_arg.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nbool TensorArg::Empty() const { return !acc_tensor_; }\n\nvoid TensorArg::Release() { acc_tensor_.reset(); }\n\nMaybe<void> TensorArg::PushPartialTensor(const std::shared_ptr<Tensor>& partial_tensor) {\n  if (!acc_tensor_) {\n    acc_tensor_ = partial_tensor;\n  } else {\n    // Should not inplace accumulate grad. For example,\n    // >>> z = x + y\n    // >>> p = x / z\n    // >>> p.sum().backward()\n    //\n    // As we know that dx = dz + dp / z and dy = dz, so it will lead to wrong value\n    // for dy if dx is shared with dz.\n    acc_tensor_ =\n        JUST(functional::Add(partial_tensor, acc_tensor_, /*alpha=*/1, /*inplace=*/false));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<Tensor> TensorArg::GetAccTensor() const {\n  CHECK_OR_RETURN(Empty() == false) << \"Can not GetAccTensor because it is empty\";\n  return acc_tensor_;\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/tensor_arg.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_ARG_H_\n#define ONEFLOW_CORE_FRAMEWORK_TENSOR_ARG_H_\n\n#include <memory>\n#include <vector>\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/autograd/autograd_meta.h\"\n\nnamespace oneflow {\nnamespace one {\n\nclass Tensor;\n\n// This class will be used in TensorImpl and Autograd. It will share data with different\n// FunctionNodes.\nclass TensorArg final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TensorArg);\n  TensorArg() = default;\n  ~TensorArg() = default;\n\n  bool Empty() const;\n  void Release();\n  Maybe<void> PushPartialTensor(const std::shared_ptr<Tensor>& partial_tensor);\n  Maybe<Tensor> GetAccTensor() const;\n\n private:\n  std::shared_ptr<Tensor> acc_tensor_;\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_TENSOR_ARG_H_\n"
  },
  {
    "path": "oneflow/core/framework/tensor_global_id.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/transport_token.h\"\n#include \"oneflow/core/framework/tensor_global_id.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<std::shared_ptr<TransportToken>> RawGetMetaTransportToken() {\n  const auto& token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeMeta));\n  return std::make_shared<TransportToken>(token);\n}\nstatic constexpr auto* GetMetaTransportToken = DECORATE(&RawGetMetaTransportToken, ThreadLocal);\n\n}  // namespace\n\nMaybe<TransportToken> NewTensorGlobalId() { return ++**JUST(GetMetaTransportToken()); }\n\nnamespace one {\n\nint64_t* MutThreadLocalGlobalIdDepth() {\n  static thread_local int64_t recursive_depth = 0;\n  return &recursive_depth;\n}\n\nMaybe<void> InitGlobalId(TensorTuple* outputs) {\n  for (const auto& output : *outputs) {\n    CHECK_OR_RETURN(output);\n    const auto& global_tensor = JUST(output->AsGlobalTensor());\n    CHECK_OR_RETURN(global_tensor)\n        << Error::UnimplementedError() << \"global tensors suppported only.\";\n    const auto& transport_token = JUST(NewTensorGlobalId());\n    JUST(global_tensor->mut_impl()->set_transport_token(transport_token));\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace one\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/tensor_global_id.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_GLOBAL_ID_\n#define ONEFLOW_CORE_FRAMEWORK_TENSOR_GLOBAL_ID_\n\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nMaybe<TransportToken> NewTensorGlobalId();\n\nnamespace one {\n\nclass TensorTuple;\n\nint64_t* MutThreadLocalGlobalIdDepth();\nMaybe<void> InitGlobalId(TensorTuple* outputs);\n\ntemplate<typename... Args>\nstruct NonRecursiveInitGlobalId;\n\ntemplate<typename Arg0, typename Arg1, typename... Args>\nstruct NonRecursiveInitGlobalId<Maybe<void>, Arg0, Arg1, TensorTuple*, Args...> {\n  template<Maybe<void> (*func)(Arg0, Arg1, TensorTuple*, Args...)>\n  static Maybe<void> Call(Arg0 arg0, Arg1 arg1, TensorTuple* outputs, Args... args) {\n    auto* recursive_depth = MutThreadLocalGlobalIdDepth();\n    ++*recursive_depth;\n    Maybe<void> ret = func(arg0, arg1, outputs, args...);\n    --*recursive_depth;\n    if (*recursive_depth == 0 && ret.IsOk()) { JUST(InitGlobalId(outputs)); }\n    return ret;\n  }\n};\n\n}  // namespace one\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_TENSOR_GLOBAL_ID_\n"
  },
  {
    "path": "oneflow/core/framework/tensor_impl.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <type_traits>\n#include \"oneflow/core/common/blocking_then_busy.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/stream_type.h\"\n#include \"oneflow/core/common/tensor_meta.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/tensor_impl.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/eager/local_dep_object.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/framework/stream_allocator_is_pinned.h\"\n\nnamespace oneflow {\nnamespace one {\n\nMaybe<void> TensorImpl::set_requires_grad(bool requires_grad) {\n  if (requires_grad) {\n    const DataType tensor_dtype = dtype();\n    CHECK_OR_RETURN(IsSupportRequireGradDataType(tensor_dtype))\n        << \"RuntimeError: only Tensors of floating point or complex can require gradients\";\n  }\n  autograd_meta_->set_requires_grad(requires_grad);\n  return Maybe<void>::Ok();\n}\n\nMaybe<Tensor> TensorImpl::acc_grad() const { return autograd_meta_->acc_grad(); }\n\nMaybe<TensorArg> TensorImpl::current_grad() const { return autograd_meta_->current_grad(); }\n\nMaybe<void> TensorImpl::set_acc_grad(const std::shared_ptr<Tensor>& grad) {\n  return autograd_meta_->set_acc_grad(grad);\n}\n\nMaybe<Tensor> TensorImpl::mut_acc_grad() { return autograd_meta_->mut_acc_grad(); }\n\nMaybe<void> TensorImpl::set_retain_grad(bool retain_grad) {\n  if (!requires_grad() && retain_grad) {\n    return Error::RuntimeError() << \"Can't retain_grad on Tensor that has requires_grad=False\";\n  }\n  if (!is_leaf() && retain_grad) { autograd_meta_->set_retain_grad(retain_grad); }\n  return Maybe<void>::Ok();\n}\n\nMaybe<LocalTensorImpl> LazyLocalTensorImpl::detach() const {\n  auto detached_impl = std::make_shared<LazyLocalTensorImpl>(tensor_meta_, false, true);\n  return std::shared_ptr<LocalTensorImpl>(detached_impl);\n}\n\nEagerLocalTensorImpl::EagerLocalTensorImpl(const std::shared_ptr<TensorStorage>& tensor_storage,\n                                           int64_t storage_offset, bool requires_grad, bool is_leaf)\n    : LocalTensorImpl(requires_grad, is_leaf),\n      tensor_storage_(tensor_storage),\n      storage_offset_(storage_offset) {}\n\nEagerLocalTensorImpl::~EagerLocalTensorImpl() {}\n\nMaybe<void> EagerLocalTensorImpl::UpdateTensorStorage() {\n  const auto& eager_blob_object = 
eager_blob_object_;\n  tensor_storage_ = std::make_shared<TensorStorage>(eager_blob_object->tensor_storage());\n  tensor_storage_->set_releaser_hook([eager_blob_object](\n                                         const std::shared_ptr<vm::TensorStorage>&) {\n    auto ret = PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n      if (eager_blob_object->producer_stream().has_value()) {\n        JUST(builder->ReleaseTensor(eager_blob_object));\n      }\n      return Maybe<void>::Ok();\n    });\n    // We should not use CHECK_JUST here because it will throw an exception\n    // in destructor.\n    if (!ret.IsOk()) {\n      LOG(WARNING)\n          << \"Release hook gets an error. Release hooks are executed in destructor, so the error \"\n             \"is possibly only a secondary error caused by another unrelated exception.\";\n      LOG(WARNING) << \"======= Error message begin =======\";\n      LOG(WARNING) << ret.GetSerializedError();\n      LOG(WARNING) << \"======= Error message end =======\";\n    }\n  });\n  return Maybe<void>::Ok();\n}\n\nconst std::shared_ptr<const MutLocalTensorMeta>& EagerLocalTensorImpl::mut_tensor_meta() {\n  return eager_blob_object_->mut_tensor_meta();\n}\n// Getters\nconst Symbol<LocalTensorMeta>& EagerLocalTensorImpl::tensor_meta() const {\n  return eager_blob_object_->tensor_meta();\n}\n\nMaybe<LocalDepObject*> EagerLocalTensorImpl::compute_local_dep_object() const {\n  return JUST(eager_blob_object())->compute_local_dep_object();\n}\n\nMaybe<void> EagerLocalTensorImpl::InitEagerBlobObject(\n    const Symbol<one::LocalTensorMeta>& local_tensor_meta,\n    const std::shared_ptr<const one::MutLocalTensorMeta>& mut_local_tensor_meta,\n    const intrusive::shared_ptr<LocalDepObject>& dep_object) {\n  CHECK_OR_RETURN(static_cast<bool>(local_tensor_meta->device()));  // NOLINT\n  const auto& mem_case = local_tensor_meta->device()->mem_case();\n\n  if (tensor_storage_) {\n    auto tensor_storage = tensor_storage_->storage();\n    eager_blob_object_ = std::make_shared<vm::EagerBlobObject>(\n        mem_case, local_tensor_meta, mut_local_tensor_meta, local_tensor_meta->dtype(),\n        local_tensor_meta->memory_format(), tensor_storage, dep_object);\n  } else {\n    auto device = local_tensor_meta->device();\n    auto storage = device->rematable() ? 
std::make_shared<vm::RematableTensorStorage>(device)\n                                       : std::make_shared<vm::TensorStorage>(true, device);\n    const auto& eager_blob_object = std::make_shared<vm::EagerBlobObject>(\n        mem_case, local_tensor_meta, mut_local_tensor_meta, local_tensor_meta->dtype(),\n        local_tensor_meta->memory_format(), storage, dep_object);\n    JUST(set_eager_blob_object(eager_blob_object));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<bool> EagerLocalTensorImpl::is_pinned() const {\n  if (this->device() == JUST(Device::New(\"meta\"))) { return false; }\n  if (!eager_blob_object_) { return false; }\n  return IsStreamAllocatorPinned::Visit(JUST(eager_blob_object_->producer_stream())->stream_type());\n}\n\nMaybe<void> EagerLocalTensorImpl::set_eager_blob_object(\n    std::shared_ptr<vm::EagerBlobObject> eager_blob_object) {\n  eager_blob_object_ = eager_blob_object;\n  CHECK_OR_RETURN(eager_blob_object_->shape() == tensor_meta()->shape()) << kOfBugIssueUploadPrompt;\n  CHECK_OR_RETURN(eager_blob_object_->data_type() == tensor_meta()->dtype())\n      << kOfBugIssueUploadPrompt;\n  JUST(UpdateTensorStorage());\n  return Maybe<void>::Ok();\n}\n\nstd::shared_ptr<const Shape> EagerLocalTensorImpl::shape() const {\n  if (!eager_blob_object_) { return tensor_meta()->shape_ptr(); }\n  return eager_blob_object_->shape_ptr();\n}\n\nstd::shared_ptr<const Stride> EagerLocalTensorImpl::stride() const {\n  if (!eager_blob_object_) { return tensor_meta()->stride_ptr(); }\n  return eager_blob_object_->stride_ptr();\n}\n\nMemoryFormat EagerLocalTensorImpl::memory_format() const {\n  if (!eager_blob_object_) { return tensor_meta()->memory_format(); }\n  return eager_blob_object_->memory_format();\n}\n\nMaybe<LocalTensorImpl> EagerLocalTensorImpl::detach() const {\n  auto detached_impl = std::make_shared<EagerLocalTensorImpl>(tensor_storage_, false, true);\n  detached_impl->eager_blob_object_ = eager_blob_object_;\n  return std::shared_ptr<LocalTensorImpl>(detached_impl);\n}\n\nMaybe<void> EagerLocalTensorImpl::RegisterStorageDeleteHook(const std::function<void()>& hook) {\n  CHECK_OR_RETURN(eager_blob_object_) << \"EagerBlobObject has not initialized\";\n  eager_blob_object_->RegisterStorageDeleteHook(hook);\n  return Maybe<void>::Ok();\n}\n\nMaybe<GlobalTensorImpl> LazyGlobalTensorImpl::detach() const {\n  auto detached_impl = std::make_shared<LazyGlobalTensorImpl>(tensor_meta_, false, true);\n  return std::shared_ptr<GlobalTensorImpl>(detached_impl);\n}\n\nEagerGlobalTensorImpl::EagerGlobalTensorImpl(\n    Symbol<GlobalTensorMeta> global_tensor_meta,\n    const std::shared_ptr<LocalTensor>& cur_rank_phy_tensor)\n    : GlobalTensorImpl(global_tensor_meta, cur_rank_phy_tensor->requires_grad(),\n                       cur_rank_phy_tensor->is_leaf()),\n      cur_rank_phy_tensor_(cur_rank_phy_tensor) {}\n\n/* static */ Maybe<EagerGlobalTensorImpl> EagerGlobalTensorImpl::New(\n    Symbol<GlobalTensorMeta> global_tensor_meta, bool requires_grad, bool is_leaf) {\n  const auto& parallel_desc = global_tensor_meta->parallel_desc();\n  Optional<int64_t> parallel_id;\n  const auto& device = JUST(parallel_desc->GetTensorDevice4CurrentProcessCtx(&parallel_id));\n  return EagerGlobalTensorImpl::New(global_tensor_meta, device, parallel_id, requires_grad,\n                                    is_leaf);\n}\n\nnamespace {\n\nMaybe<Shape> GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp,\n                              const ParallelDesc& parallel_desc,\n                        
      const Optional<int64_t>& parallel_id) {\n  if (parallel_id.has_value()) {\n    return GetPhysicalShape(logical_shape, nd_sbp, parallel_desc, JUST(parallel_id));\n  } else {\n    return std::make_shared<Shape>(DimVector(logical_shape.NumAxes(), 0));\n  }\n}\n\n}  // namespace\n\n/* static */ Maybe<EagerGlobalTensorImpl> EagerGlobalTensorImpl::New(\n    Symbol<GlobalTensorMeta> global_tensor_meta, Symbol<Device> device,\n    const Optional<int64_t>& parallel_id, bool requires_grad, bool is_leaf) {\n  const auto& shape = global_tensor_meta->shape_ptr();\n  const auto& dtype = global_tensor_meta->dtype();\n  const auto& memory_format = global_tensor_meta->memory_format();\n  const auto& nd_sbp = global_tensor_meta->nd_sbp();\n  const auto& parallel_desc = global_tensor_meta->parallel_desc();\n  const auto& cur_rank_phy_shape =\n      JUST(GetPhysicalShape(*shape, *nd_sbp, *parallel_desc, parallel_id));\n  std::shared_ptr<LocalTensor> cur_rank_phy_tensor;\n  // If the `parallel_desc` doesn't cover the current ProcessCtx or the tensor has a 0-size shape,\n  // there is no need to compute through the corresponding opkernel; the tensor can be obtained\n  // directly through an empty op.\n  if (parallel_id.has_value() && shape->elem_cnt() != 0) {\n    const auto& cur_rank_phy_tensor_meta =\n        SymbolOf(LocalTensorMeta(*cur_rank_phy_shape, dtype, memory_format, device));\n    auto cur_rank_phy_tensor_impl = std::make_shared<EagerLocalTensorImpl>(requires_grad, is_leaf);\n    const auto& dep_object = NewLocalDepObject();\n    JUST(cur_rank_phy_tensor_impl->InitEagerBlobObject(cur_rank_phy_tensor_meta, dep_object));\n    cur_rank_phy_tensor = std::make_shared<LocalTensor>(cur_rank_phy_tensor_impl);\n  } else {\n    const auto& dtype_symbol = JUST(DType::Get(dtype));\n    const auto& empty =\n        JUST(functional::Empty(*cur_rank_phy_shape, dtype_symbol, device,\n                               /*requires_grad=*/requires_grad, /*pin_memory=*/false));\n    cur_rank_phy_tensor = JUST(empty->AsLocalTensor());\n    JUST(cur_rank_phy_tensor->set_requires_grad(requires_grad));\n    cur_rank_phy_tensor->set_is_leaf(is_leaf);\n  }\n  auto* tensor_impl = new EagerGlobalTensorImpl(global_tensor_meta, cur_rank_phy_tensor);\n  return std::shared_ptr<EagerGlobalTensorImpl>(tensor_impl);\n}\n\nMaybe<GlobalTensorImpl> EagerGlobalTensorImpl::detach() const {\n  auto detached_impl = std::shared_ptr<EagerGlobalTensorImpl>(new EagerGlobalTensorImpl(\n      tensor_meta_, JUST(JUST(cur_rank_phy_tensor_->detach())->AsLocalTensor())));\n  detached_impl->consumer_nd_sbp_constraint_ = consumer_nd_sbp_constraint_;\n  detached_impl->transport_token_ = transport_token_;\n  return std::shared_ptr<GlobalTensorImpl>(detached_impl);\n}\n\nstd::shared_ptr<const Stride> EagerGlobalTensorImpl::stride() const {\n  if (!cur_rank_phy_tensor_) { return tensor_meta()->stride_ptr(); }\n  return cur_rank_phy_tensor_->tensor_meta().stride_ptr();\n}\n\nMemoryFormat EagerGlobalTensorImpl::memory_format() const {\n  if (!cur_rank_phy_tensor_) { return tensor_meta()->memory_format(); }\n  return cur_rank_phy_tensor_->tensor_meta().memory_format();\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/tensor_impl.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_IMPL_H_\n#define ONEFLOW_CORE_FRAMEWORK_TENSOR_IMPL_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/framework/tensor_storage.h\"\n#include \"oneflow/core/common/tensor_desc.h\"\n#include \"oneflow/core/common/tensor_meta.h\"\n#include \"oneflow/core/framework/transport_token.h\"\n#include \"oneflow/core/autograd/autograd_meta.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/intrusive/intrusive.h\"\n#include \"oneflow/core/eager/local_dep_object.h\"\n\nnamespace oneflow {\n\nclass MemoryCase;\n\nclass Shape;\nclass Stride;\nclass Device;\n\nnamespace vm {\nclass EagerBlobObject;\nclass TensorStorage;\n}  // namespace vm\n\nnamespace one {\n\nclass Tensor;\nclass TensorArg;\n\nclass TensorImpl {\n public:\n  virtual ~TensorImpl() = default;\n\n  // Getters\n  virtual std::shared_ptr<const Shape> shape() const = 0;\n  virtual std::shared_ptr<const Stride> stride() const = 0;\n  virtual MemoryFormat memory_format() const = 0;\n  virtual DataType dtype() const = 0;\n  virtual bool is_lazy() const = 0;\n\n  // Getters valid only for EagerLocalTensorImpl\n  virtual Maybe<vm::EagerBlobObject> eager_blob_object() const = 0;\n  virtual Maybe<LocalDepObject*> compute_local_dep_object() const = 0;\n  virtual Maybe<TensorStorage> tensor_storage() const { OF_UNIMPLEMENTED(); }\n  virtual Maybe<bool> has_eager_blob_object() const = 0;\n  virtual Maybe<int64_t> storage_offset() const { OF_UNIMPLEMENTED(); }\n  virtual bool is_contiguous() const = 0;\n  virtual bool is_view() const = 0;\n  virtual Maybe<bool> is_pinned() const { OF_UNIMPLEMENTED(); }\n\n  // Getters for autograd\n  Maybe<Tensor> acc_grad() const;\n  Maybe<TensorArg> current_grad() const;\n  bool requires_grad() const { return autograd_meta_->requires_grad(); }\n  bool is_leaf() const { return autograd_meta_->is_leaf(); }\n  bool retain_grad() const { return autograd_meta_->retain_grad(); }\n\n  // Setters for autograd\n  Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad);\n  Maybe<Tensor> mut_acc_grad();\n  Maybe<void> set_requires_grad(bool requires_grad);\n  Maybe<void> set_retain_grad(bool retain_grad);\n\n  void set_is_leaf(bool is_leaf) { autograd_meta_->set_is_leaf(is_leaf); }\n\n  std::shared_ptr<const AutogradMeta> autograd_meta() const { return autograd_meta_; }\n  std::shared_ptr<AutogradMeta> mut_autograd_meta() { return autograd_meta_; }\n  void set_autograd_meta(const std::shared_ptr<AutogradMeta>& autograd_meta) {\n    autograd_meta_ = autograd_meta;\n  }\n\n  virtual Maybe<void> RegisterStorageDeleteHook(const std::function<void()>& hook) {\n    OF_UNIMPLEMENTED();\n  }\n\n protected:\n  TensorImpl(bool requires_grad, bool is_leaf)\n      : autograd_meta_(std::make_shared<AutogradMeta>(requires_grad, is_leaf)) {}\n\n protected:\n 
 std::shared_ptr<AutogradMeta> autograd_meta_;\n};\n\nclass EagerLocalTensorImpl;\nclass LocalTensorImpl : public TensorImpl {\n public:\n  virtual ~LocalTensorImpl() = default;\n\n  // Getters\n  DataType dtype() const override { return tensor_meta()->dtype(); }\n  const Symbol<Device>& device() const { return tensor_meta()->device(); }\n  bool is_contiguous() const override { return tensor_meta()->is_contiguous(); }\n  bool is_view() const override { return tensor_meta()->is_view(); }\n\n  virtual const Symbol<LocalTensorMeta>& tensor_meta() const = 0;\n  // Setters\n  virtual const std::shared_ptr<const MutLocalTensorMeta>& mut_tensor_meta() = 0;\n  Maybe<Symbol<Device>*> mut_device() {\n    return std::const_pointer_cast<MutLocalTensorMeta>(mut_tensor_meta())->mut_device();\n  }\n  virtual Maybe<EagerLocalTensorImpl*> mut_eager_local_tensor_impl() {\n    RETURN_ERROR_WITH_BUG_PROMPT();\n  }\n\n  virtual Maybe<LocalTensorImpl> detach() const { RETURN_ERROR_WITH_BUG_PROMPT(); }\n\n protected:\n  LocalTensorImpl(bool requires_grad, bool is_leaf) : TensorImpl(requires_grad, is_leaf) {}\n};\n\nclass LocalTensor;\n\nclass GlobalTensorImpl : public TensorImpl {\n public:\n  virtual ~GlobalTensorImpl() = default;\n\n  // Getters\n  std::shared_ptr<const Shape> shape() const override { return tensor_meta_->shape_ptr(); }\n  std::shared_ptr<const Stride> stride() const override { return tensor_meta_->stride_ptr(); }\n  MemoryFormat memory_format() const override { return tensor_meta_->memory_format(); }\n  DataType dtype() const override { return tensor_meta_->dtype(); }\n\n  Symbol<NdSbp> nd_sbp() const { return tensor_meta_->nd_sbp(); }\n  Symbol<ParallelDesc> parallel_desc() const { return tensor_meta_->parallel_desc(); }\n  const Optional<Symbol<NdSbp>>& consumer_nd_sbp_constraint() const {\n    return consumer_nd_sbp_constraint_;\n  }\n  virtual Maybe<LocalTensor> cur_rank_phy_tensor() const { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Symbol<GlobalTensorMeta> tensor_meta() const { return tensor_meta_; }\n\n  // Getters valid only for EagerLocalTensorImpl\n  Maybe<vm::EagerBlobObject> eager_blob_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<LocalDepObject*> compute_local_dep_object() const override {\n    RETURN_ERROR_WITH_BUG_PROMPT();\n  }\n  Maybe<bool> has_eager_blob_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n\n  // Setters\n  void set_consumer_nd_sbp_constraint(const Optional<Symbol<NdSbp>>& val) {\n    consumer_nd_sbp_constraint_ = val;\n  }\n\n  GlobalTensorMeta* mut_tensor_meta() {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return nullptr;\n  }\n\n  Maybe<TransportToken> transport_token() const { return JUST(transport_token_); }\n\n  Maybe<void> set_transport_token(const TransportToken& transport_token) {\n    transport_token_ = transport_token;\n    return Maybe<void>::Ok();\n  }\n\n  virtual Maybe<GlobalTensorImpl> detach() const { RETURN_ERROR_WITH_BUG_PROMPT(); }\n\n protected:\n  GlobalTensorImpl(Symbol<GlobalTensorMeta> tensor_meta, bool requires_grad, bool is_leaf)\n      : TensorImpl(requires_grad, is_leaf),\n        tensor_meta_(tensor_meta),\n        consumer_nd_sbp_constraint_(),\n        transport_token_() {}\n\n  Symbol<GlobalTensorMeta> tensor_meta_;\n  Optional<Symbol<NdSbp>> consumer_nd_sbp_constraint_;\n  Optional<TransportToken> transport_token_;\n};\n\nclass LazyLocalTensorImpl final : public LocalTensorImpl {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LazyLocalTensorImpl);\n  LazyLocalTensorImpl(const Symbol<LocalTensorMeta>& 
tensor_meta, bool requires_grad, bool is_leaf)\n      : LocalTensorImpl(requires_grad, is_leaf), tensor_meta_(tensor_meta) {}\n  ~LazyLocalTensorImpl() override = default;\n\n  // Getters\n  const Symbol<LocalTensorMeta>& tensor_meta() const override { return tensor_meta_; }\n  std::shared_ptr<const Shape> shape() const override { return tensor_meta()->shape_ptr(); }\n  std::shared_ptr<const Stride> stride() const override { return tensor_meta()->stride_ptr(); }\n  MemoryFormat memory_format() const override { return tensor_meta()->memory_format(); }\n\n  bool is_lazy() const override { return true; }\n  bool is_contiguous() const override {\n    // TODO:(zhaoluyang) default return true for now,\n    // but should return real status while stride/view mechanism is ready in lazy-local mode\n    return true;\n  }\n  bool is_view() const override { return false; }\n  Maybe<bool> is_pinned() const override { return false; }\n\n  const std::shared_ptr<const MutLocalTensorMeta>& mut_tensor_meta() override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n  }\n\n  // Getters valid only for EagerLocalTensorImpl\n  Maybe<vm::EagerBlobObject> eager_blob_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<LocalDepObject*> compute_local_dep_object() const override {\n    RETURN_ERROR_WITH_BUG_PROMPT();\n  }\n  Maybe<TensorStorage> tensor_storage() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<bool> has_eager_blob_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }\n  Maybe<LocalTensorImpl> detach() const override;\n\n private:\n  Symbol<LocalTensorMeta> tensor_meta_;\n};\n\nclass EagerLocalTensorImpl final : public LocalTensorImpl {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(EagerLocalTensorImpl);\n  EagerLocalTensorImpl()\n      : EagerLocalTensorImpl(std::shared_ptr<TensorStorage>(), 0, false, false) {}\n  EagerLocalTensorImpl(const std::shared_ptr<TensorStorage>& tensor_storage, bool requires_grad,\n                       bool is_leaf)\n      : EagerLocalTensorImpl(tensor_storage, 0, requires_grad, is_leaf) {}\n  EagerLocalTensorImpl(const std::shared_ptr<TensorStorage>& tensor_storage, int64_t storage_offset,\n                       bool requires_grad, bool is_leaf);\n\n  EagerLocalTensorImpl(bool requires_grad, bool is_leaf)\n      : EagerLocalTensorImpl(std::shared_ptr<TensorStorage>(), 0, requires_grad, is_leaf) {}\n  ~EagerLocalTensorImpl() override;\n\n  const std::shared_ptr<const MutLocalTensorMeta>& mut_tensor_meta() override;\n  // Getters\n  const Symbol<LocalTensorMeta>& tensor_meta() const override;\n  std::shared_ptr<const Shape> shape() const override;\n  std::shared_ptr<const Stride> stride() const override;\n  MemoryFormat memory_format() const override;\n\n  Maybe<LocalTensorImpl> detach() const override;\n  bool is_lazy() const override { return false; }\n  bool is_contiguous() const override { return tensor_meta()->is_contiguous(); }\n  bool is_view() const override { return tensor_meta()->is_view(); }\n  Maybe<bool> is_pinned() const override;\n\n  // Getters valid only for EagerLocalTensorImpl\n  Maybe<vm::EagerBlobObject> eager_blob_object() const override {\n    CHECK_OR_RETURN(eager_blob_object_);\n    return eager_blob_object_;\n  }\n  Maybe<LocalDepObject*> compute_local_dep_object() const override;\n  Maybe<TensorStorage> tensor_storage() const override {\n    CHECK_OR_RETURN(eager_blob_object_);\n    return tensor_storage_;\n  }\n  Maybe<bool> has_eager_blob_object() const override { return eager_blob_object_.get(); }\n  Maybe<int64_t> 
storage_offset() const override { return storage_offset_; }\n  // Setters\n  TensorStorage* mut_tensor_storage() { return tensor_storage_.get(); }\n  void set_storage_offset(int64_t offset) { storage_offset_ = offset; }\n\n  Maybe<void> InitEagerBlobObject(\n      const Symbol<one::LocalTensorMeta>& local_tensor_meta,\n      const std::shared_ptr<const one::MutLocalTensorMeta>& mut_local_tensor_meta,\n      const intrusive::shared_ptr<LocalDepObject>& dep_object);\n  Maybe<void> InitEagerBlobObject(const Symbol<one::LocalTensorMeta>& local_tensor_meta,\n                                  const intrusive::shared_ptr<LocalDepObject>& dep_object) {\n    JUST(InitEagerBlobObject(local_tensor_meta, std::shared_ptr<const one::MutLocalTensorMeta>(),\n                             dep_object));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<EagerLocalTensorImpl*> mut_eager_local_tensor_impl() override { return this; }\n\n  Maybe<void> RegisterStorageDeleteHook(const std::function<void()>& hook) override;\n\n private:\n  Maybe<void> UpdateTensorStorage();\n  Maybe<void> set_eager_blob_object(std::shared_ptr<vm::EagerBlobObject> eager_blob_object);\n\n  std::shared_ptr<TensorStorage> tensor_storage_;\n  int64_t storage_offset_;\n  std::shared_ptr<vm::EagerBlobObject> eager_blob_object_;\n};\n\nclass LazyGlobalTensorImpl final : public GlobalTensorImpl {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LazyGlobalTensorImpl);\n  LazyGlobalTensorImpl(Symbol<GlobalTensorMeta> global_tensor_meta, bool requires_grad,\n                       bool is_leaf)\n      : GlobalTensorImpl(global_tensor_meta, requires_grad, is_leaf) {}\n  ~LazyGlobalTensorImpl() override = default;\n\n  // Getters\n  bool is_lazy() const override { return true; }\n\n  bool is_contiguous() const override {\n    // TODO:(zhaoluyang) default return true for now,\n    // but should return real status while stride/view mechanism is ready in lazy-global mode\n    return true;\n  }\n\n  bool is_view() const override { return false; }\n\n  Maybe<GlobalTensorImpl> detach() const override;\n};\n\nclass EagerGlobalTensorImpl final : public GlobalTensorImpl {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(EagerGlobalTensorImpl);\n  ~EagerGlobalTensorImpl() override = default;\n\n  // Getters\n  std::shared_ptr<const Stride> stride() const override;\n  MemoryFormat memory_format() const override;\n\n  bool is_lazy() const override { return false; }\n\n  bool is_contiguous() const override {\n    // TODO:(zhaoluyang) default return true for now,\n    // but should return real status while stride/view mechanism is ready in eager-global mode\n    return true;\n  }\n  bool is_view() const override { return false; }\n\n  Maybe<LocalTensor> cur_rank_phy_tensor() const override { return cur_rank_phy_tensor_; }\n  void reset_cur_rank_phy_tensor(const std::shared_ptr<LocalTensor>& val) {\n    cur_rank_phy_tensor_ = val;\n  }\n\n  static Maybe<EagerGlobalTensorImpl> New(Symbol<GlobalTensorMeta> global_tensor_meta,\n                                          bool requires_grad, bool is_leaf);\n\n  static Maybe<EagerGlobalTensorImpl> New(Symbol<GlobalTensorMeta> global_tensor_meta,\n                                          Symbol<Device> device,\n                                          const Optional<int64_t>& parallel_id, bool requires_grad,\n                                          bool is_leaf);\n\n  Maybe<GlobalTensorImpl> detach() const override;\n\n private:\n  EagerGlobalTensorImpl(Symbol<GlobalTensorMeta> global_tensor_meta,\n                        const 
std::shared_ptr<LocalTensor>& cur_rank_phy_tensor);\n\n  std::shared_ptr<LocalTensor> cur_rank_phy_tensor_;\n};\n\n}  // namespace one\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_TENSOR_IMPL_H_\n"
  },
  {
    "path": "oneflow/core/framework/tensor_methods.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/tensor_methods.h\"\n#include \"oneflow/core/autograd/autograd_engine.h\"\n#include \"oneflow/core/autograd/autograd_mode.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/common/wrap_dim_utils.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace view {\n\n// NOTE: use env variable 'ONEFLOW_DISABLE_VIEW' control use view mechanism or not\n// If  set true, then do not use view mechanism(and view ops)\nbool IsEnvViewDisabled() {\n  static const bool env_view_disabled = ParseBooleanFromEnv(\"ONEFLOW_DISABLE_VIEW\", false);\n  return env_view_disabled;\n}\n\nbool IsViewApplicable(const std::shared_ptr<Tensor>& input) {\n  if (IsEnvViewDisabled()) { return false; }\n  // NOTE: only eager local tensor support view for now\n  // elem_cnt() >= 1  used to excluding 0 shape tensor\n  if (input->is_local() && !(LazyMode::is_enabled()) && input->shape()->elem_cnt() >= 1) {\n    return true;\n  }\n  return false;\n}\n\nstatic bool IsOverlappingMemorys(const std::vector<int64_t>& sizes,\n                                 const std::vector<int64_t>& strides) {\n  // reference: torch/csrc/autograd/FunctionsManual.cpp _maybe_overlapping_memory()\n  if (sizes.size() > 0) {\n    std::vector<std::size_t> argsort(sizes.size());\n    std::iota(argsort.begin(), argsort.end(), 0);\n    std::sort(argsort.begin(), argsort.end(),\n              [&](std::size_t i, std::size_t j) { return strides[i] < strides[j]; });\n    int64_t max_index_in_slice = 0;\n    for (auto i : argsort) {\n      auto stride_ = strides[i];\n      if (stride_ <= max_index_in_slice) { return true; }\n      max_index_in_slice += stride_ * (sizes[i] - 1);\n    }\n  }\n  return false;\n}\n\nstatic int64_t MinStorageSize(const std::vector<int64_t>& sizes,\n                              const std::vector<int64_t>& strides, int64_t storage_offset) {\n  int64_t storage_size = storage_offset + 1;\n  int64_t ndim = sizes.size();\n  for (size_t i = 0; i < ndim; i++) {\n    auto size_i = sizes[i];\n    if (size_i == 0) { return storage_offset; }\n    storage_size += (size_i - 1) * strides[i];\n  }\n  return storage_size;\n}\n\nMaybe<Tensor> BasicView(const std::shared_ptr<Tensor>& input, const Shape& target_shape,\n                        const int64_t storage_offset) {\n  /**\n   * This function provides basic view capabilities which\n   * accept input tensor with target shape, and return viewed tensor.\n   *\n   * The viewed 
tensor shares memory with the input tensor, and both of\n   * them are memory contiguous, but they have different shapes/strides.\n   */\n  Stride target_stride(target_shape);\n  return BasicView(input, target_shape, target_stride, storage_offset);\n}\n\nMaybe<Tensor> BasicView(const std::shared_ptr<Tensor>& input, const Shape& target_shape,\n                        const Stride& target_stride, const int64_t storage_offset) {\n  auto device = JUST(input->device());\n  auto tensor_meta =\n      SymbolOf(LocalTensorMeta(target_shape, target_stride, input->dtype()->data_type(),\n                               input->memory_format(), device, /*is_view=*/true));\n\n  CHECK_OR_RETURN(JUST(input->has_eager_blob_object()));\n  // new output tensor\n  const auto& blob_object = JUST(input->eager_blob_object());\n  bool requires_grad = (autograd::GradMode::is_enabled() && input->requires_grad());\n  auto tensor_impl = std::make_shared<EagerLocalTensorImpl>(JUST(input->tensor_storage()),\n                                                            storage_offset, requires_grad,\n                                                            /*is_leaf=*/!requires_grad);\n  JUST(\n      tensor_impl->InitEagerBlobObject(tensor_meta, JUST(blob_object->compute_local_dep_object())));\n\n  auto view_tensor = std::make_shared<LocalTensor>(tensor_impl);\n\n  const std::shared_ptr<vm::EagerBlobObject>& view_eager_blob_object =\n      JUST(view_tensor->eager_blob_object());\n  view_eager_blob_object->set_storage_offset(JUST(view_tensor->storage_offset()));\n  view_eager_blob_object->set_input_of_view_op(blob_object);\n  return std::static_pointer_cast<Tensor>(view_tensor);\n}\n\nMaybe<void> InplaceView(const std::shared_ptr<Tensor>& input, const Shape& target_shape,\n                        const Stride& target_stride, const int64_t storage_offset) {\n  Symbol<LocalTensorMeta> new_tensor_meta =\n      SymbolOf(LocalTensorMeta(target_shape, target_stride, input->dtype()->data_type(),\n                               input->memory_format(), JUST(input->device())));\n\n  bool requires_grad = (autograd::GradMode::is_enabled() && input->requires_grad());\n  std::shared_ptr<EagerLocalTensorImpl> new_tensor_impl = std::make_shared<EagerLocalTensorImpl>(\n      JUST(input->tensor_storage()), storage_offset, /*requires_grad=*/requires_grad,\n      /*is_leaf=*/!requires_grad);\n  JUST(new_tensor_impl->InitEagerBlobObject(\n      new_tensor_meta, JUST(JUST(input->eager_blob_object())->compute_local_dep_object())));\n  JUST(JUST(input->AsLocalTensor())->set_impl(new_tensor_impl));\n  return Maybe<void>::Ok();\n}\n\nMaybe<Tensor> Reshape(const std::shared_ptr<Tensor>& input, const Shape& target_shape) {\n  Stride target_stride(target_shape);\n  return Reshape(input, target_shape, target_stride);\n}\n\nMaybe<Tensor> Reshape(const std::shared_ptr<Tensor>& input, const Shape& target_shape,\n                      const Stride& target_stride) {\n  int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset());\n  std::shared_ptr<Tensor> output =\n      JUST(BasicView(input, target_shape, target_stride, storage_offset));\n\n  if (autograd::GradMode::is_enabled() && input->requires_grad()) {\n    Shape input_shape(input->shape()->dim_vec());\n    auto backward_fn = std::make_shared<BackwardFunction>();\n    backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,\n                            bool create_graph) -> Maybe<void> {\n      autograd::AutoGradMode mode(create_graph);\n      
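// The backward of a reshape view simply reshapes the incoming grad back to the original\n      // input shape.\n      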
CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n      in_grads->resize(1);\n      JUST(oneflow::VectorAt(*in_grads, 0)) =\n          JUST(functional::Reshape(JUST(oneflow::VectorAt(out_grads, 0)), input_shape));\n      return Maybe<void>::Ok();\n    };\n    backward_fn->status = []() { return false; };\n    TensorTuple outputs{output};\n    JUST(GetThreadLocalAutogradEngine()->AddNode(\"view::reshape_backward\", backward_fn, {input},\n                                                 &outputs));\n  }\n  return output;\n}\n\nMaybe<Tensor> Slice(const std::shared_ptr<Tensor>& input, const std::vector<int64_t>& starts,\n                    const std::vector<int64_t>& ends, const std::vector<int64_t>& steps) {\n  const auto& shape = input->shape();\n  const auto& strides = JUST(input->stride());\n  const int64_t ndim = starts.size();\n\n  CHECK_OR_RETURN(ndim == shape->NumAxes())\n      << Error::RuntimeError() << \"view::Slice(): starts size is expected to be \" << shape->NumAxes()\n      << \", but got \" << ndim;\n\n  CHECK_OR_RETURN(ends.size() == ndim && steps.size() == ndim)\n      << Error::RuntimeError() << \"view::Slice(): \" << (ends.size() != ndim ? \"ends\" : \"steps\")\n      << \" size is not equal to starts size.\";\n\n  DimVector target_dims(ndim);\n  Stride target_strides(ndim);\n  int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset());\n  for (int i = 0; i < ndim; ++i) {\n    int64_t step = std::min(steps[i], shape->At(i));\n    CHECK_OR_RETURN(step > 0) << Error::RuntimeError() << \"Step must be greater than zero.\";\n    int64_t start = std::min(starts[i], shape->At(i));\n    int64_t end = std::min(ends[i], shape->At(i));\n    if (start < 0) { start += shape->At(i); }\n    if (start < 0) { start = 0; }\n    if (end < 0) { end += shape->At(i); }\n    if (end < start) { end = start; }\n    int64_t length = start == end ? 
0 : (end - start + step - 1) / step;\n    target_dims[i] = length;\n    target_strides[i] = step * strides->at(i);\n    storage_offset += start * strides->at(i);\n  }\n\n  auto output = JUST(BasicView(input, Shape(target_dims), target_strides, storage_offset));\n  if (autograd::GradMode::is_enabled() && input->requires_grad()) {\n    const Shape in_shape = *input->shape();\n    auto backward_fn = std::make_shared<BackwardFunction>();\n    backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,\n                            bool create_graph) -> Maybe<void> {\n      autograd::AutoGradMode mode(create_graph);\n      CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n      in_grads->resize(1);\n      (*in_grads)[0] = JUST(functional::SliceGrad(out_grads[0], in_shape, starts, ends, steps));\n      return Maybe<void>::Ok();\n    };\n    backward_fn->status = []() { return true; };\n    TensorTuple outputs{output};\n    JUST(GetThreadLocalAutogradEngine()->AddNode(\"view::slice_backward\", backward_fn, {input},\n                                                 &outputs));\n  }\n  return output;\n}\n\nMaybe<Tensor> Unsqueeze(const std::shared_ptr<Tensor>& input, const int32_t expand_dim) {\n  const auto& shape = input->shape();\n  const auto& strides = JUST(input->stride());\n  const auto& ndim = shape->NumAxes();\n\n  DimVector target_dim_vec(ndim + 1);\n  Stride target_stride_vec(ndim + 1);\n\n  {\n    int cnt = 0;\n    for (int i = 0; i < ndim; i++) {\n      if (i == expand_dim) { cnt++; }\n      target_dim_vec[cnt] = shape->at(i);\n      target_stride_vec[cnt] = strides->at(i);\n      cnt++;\n    }\n    target_dim_vec[expand_dim] = 1;\n    target_stride_vec[expand_dim] =\n        expand_dim < ndim ? strides->at(expand_dim) * target_dim_vec.at(expand_dim + 1) : 1;\n  }\n\n  int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset());\n  std::shared_ptr<Tensor> output =\n      JUST(BasicView(input, Shape(target_dim_vec), target_stride_vec, storage_offset));\n\n  if (autograd::GradMode::is_enabled() && input->requires_grad()) {\n    auto backward_fn = std::make_shared<BackwardFunction>();\n    backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,\n                            bool create_graph) -> Maybe<void> {\n      autograd::AutoGradMode mode(create_graph);\n      CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n      in_grads->resize(1);\n      JUST(oneflow::VectorAt(*in_grads, 0)) =\n          JUST(functional::Reshape(JUST(oneflow::VectorAt(out_grads, 0)), *shape));\n      return Maybe<void>::Ok();\n    };\n    backward_fn->status = []() { return false; };\n    TensorTuple outputs{output};\n    JUST(GetThreadLocalAutogradEngine()->AddNode(\"view::unsqueeze_backward\", backward_fn, {input},\n                                                 &outputs));\n  }\n  return output;\n}\n\nMaybe<void> InplaceUnsqueeze(const std::shared_ptr<Tensor>& input, const int32_t expand_dim) {\n  const auto& shape = input->shape();\n  const auto& strides = JUST(input->stride());\n  const auto& ndim = shape->NumAxes();\n\n  DimVector target_dim_vec(ndim + 1);\n  Stride target_stride_vec(ndim + 1);\n\n  {\n    int cnt = 0;\n    for (int i = 0; i < ndim; i++) {\n      if (i == expand_dim) { cnt++; }\n      target_dim_vec[cnt] = shape->at(i);\n      target_stride_vec[cnt] = strides->at(i);\n      cnt++;\n    }\n    target_dim_vec[expand_dim] = 1;\n    target_stride_vec[expand_dim] =\n        expand_dim < 
ndim ? strides->at(expand_dim) * target_dim_vec.at(expand_dim + 1) : 1;\n  }\n\n  int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset());\n  JUST(view::InplaceView(input, Shape(target_dim_vec), target_stride_vec, storage_offset));\n\n  if (autograd::GradMode::is_enabled() && input->requires_grad()) {\n    auto backward_fn = std::make_shared<BackwardFunction>();\n    backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,\n                            bool create_graph) -> Maybe<void> {\n      autograd::AutoGradMode mode(create_graph);\n      CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n      in_grads->resize(1);\n      JUST(oneflow::VectorAt(*in_grads, 0)) =\n          JUST(functional::Reshape(JUST(oneflow::VectorAt(out_grads, 0)), *shape));\n      return Maybe<void>::Ok();\n    };\n    backward_fn->status = []() { return false; };\n    TensorTuple outputs{input};\n    JUST(GetThreadLocalAutogradEngine()->AddNode(\"view::inplace_unsqueeze_backward\", backward_fn,\n                                                 {input}, &outputs));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<Tensor> Squeeze(const std::shared_ptr<Tensor>& input,\n                      const std::vector<int32_t>& squeeze_dims) {\n  const auto& shape = input->shape();\n  const auto& strides = JUST(input->stride());\n  const int64_t ndim = shape->NumAxes();\n\n  const int target_ndim = ndim - squeeze_dims.size();\n  DimVector target_dim_vec(target_ndim);\n  Stride target_stride_vec(target_ndim);\n\n  {\n    int cnt = 0;\n    for (int i = 0; i < ndim; i++) {\n      if (find(squeeze_dims.begin(), squeeze_dims.end(), i) == squeeze_dims.end()) {\n        target_dim_vec[cnt] = shape->At(i);\n        target_stride_vec[cnt] = strides->at(i);\n        cnt++;\n      }\n    }\n  }\n\n  int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset());\n  std::shared_ptr<Tensor> output =\n      JUST(BasicView(input, Shape(target_dim_vec), target_stride_vec, storage_offset));\n\n  if (autograd::GradMode::is_enabled() && input->requires_grad()) {\n    auto backward_fn = std::make_shared<BackwardFunction>();\n    backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,\n                            bool create_graph) -> Maybe<void> {\n      autograd::AutoGradMode mode(create_graph);\n      CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n      in_grads->resize(1);\n      JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::Reshape(\n          JUST(oneflow::VectorAt(out_grads, 0)), Shape(input->shape()->dim_vec())));\n      return Maybe<void>::Ok();\n    };\n    backward_fn->status = []() { return true; };\n    TensorTuple outputs{output};\n    JUST(GetThreadLocalAutogradEngine()->AddNode(\"view::squeeze_backward\", backward_fn, {input},\n                                                 &outputs));\n  }\n  return output;\n}\n\nMaybe<void> InplaceSqueeze(const std::shared_ptr<Tensor>& input,\n                           const std::vector<int32_t>& squeeze_dims) {\n  const auto& shape = input->shape();\n  const auto& strides = JUST(input->stride());\n  const int64_t ndim = shape->NumAxes();\n\n  const int target_ndim = ndim - squeeze_dims.size();\n  DimVector target_dim_vec(target_ndim);\n  Stride target_stride_vec(target_ndim);\n\n  {\n    int cnt = 0;\n    for (int i = 0; i < ndim; i++) {\n      if (find(squeeze_dims.begin(), squeeze_dims.end(), i) == squeeze_dims.end()) {\n        target_dim_vec[cnt] = 
shape->At(i);\n        target_stride_vec[cnt] = strides->at(i);\n        cnt++;\n      }\n    }\n  }\n\n  int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset());\n  JUST(view::InplaceView(input, Shape(target_dim_vec), target_stride_vec, storage_offset));\n\n  if (autograd::GradMode::is_enabled() && input->requires_grad()) {\n    auto backward_fn = std::make_shared<BackwardFunction>();\n    backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,\n                            bool create_graph) -> Maybe<void> {\n      autograd::AutoGradMode mode(create_graph);\n      CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)\n      in_grads->resize(1);\n      JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::Reshape(\n          JUST(oneflow::VectorAt(out_grads, 0)), Shape(input->shape()->dim_vec())));\n      return Maybe<void>::Ok();\n    };\n    backward_fn->status = []() { return true; };\n    TensorTuple outputs{input};\n    JUST(GetThreadLocalAutogradEngine()->AddNode(\"view::inplace_squeeze_backward\", backward_fn,\n                                                 {input}, &outputs));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<Tensor> Expand(const std::shared_ptr<Tensor>& input, const Shape& expand_shape) {\n  const Shape& input_shape = *input->shape();\n  const Stride& input_stride = *JUST(input->stride());\n  // lpad is signed so that the following check can fail when expand_shape has fewer axes.\n  const int64_t lpad =\n      static_cast<int64_t>(expand_shape.size()) - static_cast<int64_t>(input_shape.size());\n  CHECK_GE_OR_RETURN(lpad, 0);  // NOLINT(maybe-need-error-msg)\n\n  Stride expand_stride(expand_shape.size(), 0);\n  std::vector<int32_t> reduce_dims;\n  reduce_dims.reserve(expand_shape.size());\n\n  for (int i = expand_shape.size() - 1; i >= 0; --i) {\n    int64_t dim = i < lpad ? 1 : input_shape[i - lpad];\n    if (dim == expand_shape[i]) {\n      if (i >= lpad) {\n        expand_stride[i] = input_stride[i - lpad];\n      } else if (i < expand_shape.size() - 1) {\n        expand_stride[i] = expand_stride[i + 1] * expand_shape[i + 1];\n      }\n    } else {\n      CHECK_EQ_OR_RETURN(dim, 1);  // NOLINT(maybe-need-error-msg)\n      reduce_dims.push_back(i);\n    }\n  }\n\n  if (input_shape.size() == 0) {\n    // handle scalar expand backward reduce dims\n    reduce_dims.clear();\n    for (int32_t axis = 0; axis < expand_shape.size(); ++axis) { reduce_dims.push_back(axis); }\n  }\n\n  int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset());\n  std::shared_ptr<Tensor> output =\n      JUST(BasicView(input, expand_shape, expand_stride, storage_offset));\n\n  if (autograd::GradMode::is_enabled() && input->requires_grad()) {\n    auto backward_fn = std::make_shared<BackwardFunction>();\n    backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,\n                            bool create_graph) -> Maybe<void> {\n      autograd::AutoGradMode mode(create_graph);\n      CHECK_EQ_OR_RETURN(out_grads.size(), 1)\n          << \"out grad size should be 1, but got \" << out_grads.size();\n      in_grads->resize(1);\n      in_grads->at(0) = out_grads[0];\n      bool keep_dims = (input_shape.size() > 0);\n      if (reduce_dims.size() > 0) {\n        in_grads->at(0) =\n            JUST(functional::ReduceSum(in_grads->at(0), reduce_dims, keep_dims, NullOpt));\n      }\n      if (lpad > 0 && keep_dims) {\n        in_grads->at(0) = JUST(functional::Flatten(in_grads->at(0), 0, lpad));\n      }\n      return Maybe<void>::Ok();\n    };\n    backward_fn->status = []() { return true; };\n    TensorTuple outputs{output};\n    
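// Register the backward node: the grad of expand is reduce-summed over the broadcast\n    // dimensions collected in reduce_dims above.\n    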
JUST(GetThreadLocalAutogradEngine()->AddNode(\"view::expand_backward\", backward_fn, {input},\n                                                 &outputs));\n  }\n  return output;\n}\n\nMaybe<void> InplaceExpand(const std::shared_ptr<Tensor>& input, const Shape& expand_shape) {\n  const Shape& input_shape = *input->shape();\n  const Stride& input_stride = *JUST(input->stride());\n  const int64_t lpad =\n      static_cast<int64_t>(expand_shape.size()) - static_cast<int64_t>(input_shape.size());\n  CHECK_GE_OR_RETURN(lpad, 0);  // NOLINT(maybe-need-error-msg)\n\n  Stride expand_stride(expand_shape.size(), 0);\n  std::vector<int32_t> reduce_dims;\n  reduce_dims.reserve(expand_shape.size());\n\n  for (int i = expand_shape.size() - 1; i >= 0; --i) {\n    int64_t dim = i < lpad ? 1 : input_shape[i - lpad];\n    if (dim == expand_shape[i]) {\n      if (i >= lpad) {\n        expand_stride[i] = input_stride[i - lpad];\n      } else if (i < expand_shape.size() - 1) {\n        expand_stride[i] = expand_stride[i + 1] * expand_shape[i + 1];\n      }\n    } else {\n      CHECK_EQ_OR_RETURN(dim, 1);  // NOLINT(maybe-need-error-msg)\n      reduce_dims.push_back(i);\n    }\n  }\n\n  if (input_shape.size() == 0) {\n    // handle scalar expand backward reduce dims\n    reduce_dims.clear();\n    for (int32_t axis = 0; axis < expand_shape.size(); ++axis) { reduce_dims.push_back(axis); }\n  }\n\n  int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset());\n  JUST(view::InplaceView(input, expand_shape, expand_stride, storage_offset));\n\n  if (autograd::GradMode::is_enabled() && input->requires_grad()) {\n    auto backward_fn = std::make_shared<BackwardFunction>();\n    backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,\n                            bool create_graph) -> Maybe<void> {\n      autograd::AutoGradMode mode(create_graph);\n      CHECK_EQ_OR_RETURN(out_grads.size(), 1)\n          << \"out grad size should be 1, but got \" << out_grads.size();\n      in_grads->resize(1);\n      in_grads->at(0) = out_grads[0];\n      bool keep_dims = (input_shape.size() > 0);\n      if (reduce_dims.size() > 0) {\n        in_grads->at(0) =\n            JUST(functional::ReduceSum(in_grads->at(0), reduce_dims, keep_dims, NullOpt));\n      }\n      if (lpad > 0 && keep_dims) {\n        in_grads->at(0) = JUST(functional::Flatten(in_grads->at(0), 0, lpad));\n      }\n      return Maybe<void>::Ok();\n    };\n    backward_fn->status = []() { return true; };\n    TensorTuple outputs{input};\n    JUST(GetThreadLocalAutogradEngine()->AddNode(\"view::inplace_expand_backward\", backward_fn,\n                                                 {input}, &outputs));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<Tensor> Narrow(const std::shared_ptr<Tensor>& input, const int64_t dim, const int64_t start,\n                     const int64_t length) {\n  const auto& shape = input->shape();\n  const auto& strides = JUST(input->stride());\n  const int64_t ndim = shape->NumAxes();\n  DimVector dim_vec;\n  dim_vec.insert(dim_vec.end(), shape->dim_vec().cbegin(), shape->dim_vec().cbegin() + dim);\n  dim_vec.insert(dim_vec.end(), length);\n  dim_vec.insert(dim_vec.end(), shape->dim_vec().cbegin() + dim + 1, shape->dim_vec().end());\n\n  int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset());\n  Shape target_shape(dim_vec);\n\n  Stride stride(ndim);\n  for (int i = 0; i < ndim; ++i) {\n    stride[i] = strides->at(i);\n    if (dim == i) { storage_offset += start * strides->at(i); }\n  }\n\n  auto output = JUST(BasicView(input, target_shape, stride, 
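/* storage_offset was shifted above by start * stride[dim] to point at the slice start */ 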
storage_offset));\n  if (autograd::GradMode::is_enabled() && input->requires_grad()) {\n    auto backward_fn = std::make_shared<BackwardFunction>();\n    backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,\n                            bool create_graph) -> Maybe<void> {\n      autograd::AutoGradMode mode(create_graph);\n      CHECK_EQ_OR_RETURN(out_grads.size(), 1)\n          << \"out grad size should be 1, but got \" << out_grads.size();\n      auto like =\n          JUST(functional::Empty(Shape(input->shape()->dim_vec()), input->dtype(),\n                                 JUST(input->device()), /*requires_grad=*/input->requires_grad(),\n                                 /*pin_memory=*/false));\n      in_grads->resize(1);\n      (*in_grads)[0] = JUST(functional::NarrowGrad(out_grads[0], like, dim, start, length));\n      return Maybe<void>::Ok();\n    };\n    backward_fn->status = []() { return true; };\n    TensorTuple outputs{output};\n    JUST(GetThreadLocalAutogradEngine()->AddNode(\"view::narrow_backward\", backward_fn, {input},\n                                                 &outputs));\n  }\n  return output;\n}\n\nMaybe<Tensor> AsStridedGrad(const std::shared_ptr<one::Tensor>& dy,\n                            const std::shared_ptr<one::Tensor>& input,\n                            const std::vector<int64_t>& sizes, const std::vector<int64_t>& strides,\n                            const int64_t storage_offset) {\n  CHECK_OR_RETURN(input->is_local()) << \"input must be local tensor.\";\n  // reference: torch/csrc/autograd/FunctionsManual.cpp\n  const size_t odim = dy->ndim();\n  std::vector<int64_t> out_sizes_, out_strides_;\n  out_sizes_.reserve(odim);\n  out_strides_.reserve(odim);\n  auto grad = dy;\n  for (int64_t i = odim - 1; i >= 0; i--) {\n    auto size_i = sizes[i];\n    auto stride_i = strides[i];\n    if (size_i == 0) {\n      // an empty output contributes no gradient; return zeros with the input's shape\n      return functional::Constant(*input->shape(), 0, grad->dtype(), JUST(grad->device()));\n    } else if (size_i == 1) {\n      grad = JUST(functional::Squeeze(grad, std::vector<int32_t>{int(i)}));\n    } else if (stride_i == 0) {\n      grad = JUST(functional::ReduceSum(grad, std::vector<int32_t>{int(i)}, false, NullOpt));\n    } else {\n      out_sizes_.insert(out_sizes_.begin(), size_i);\n      out_strides_.insert(out_strides_.begin(), stride_i);\n    }\n  }\n\n  // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A\n  // Strided Tensor ]\n  //              on output geometry\n  const bool out_maybe_overlap = IsOverlappingMemorys(out_sizes_, out_strides_);\n\n  // For input geometry,\n  //   check for size 0 dimensions,\n  //   skip size 1 dimensions,\n  // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A\n  // Strided Tensor ]\n  //              on input geometry\n  auto idim = input->ndim();\n  std::vector<int64_t> inp_sizes(input->shape()->begin(), input->shape()->end());\n  std::vector<int64_t> inp_strides(JUST(input->stride())->begin(), JUST(input->stride())->end());\n  std::vector<int64_t> inp_sizes_, inp_strides_;\n  inp_sizes_.reserve(idim);\n  inp_strides_.reserve(idim);\n  for (int64_t i = idim - 1; i >= 0; i--) {\n    auto size_i = inp_sizes[i];\n    auto stride_i = inp_strides[i];\n    if (size_i == 0) {\n      return functional::Constant(*input->shape(), 0, grad->dtype(), JUST(grad->device()));\n    } else if (size_i != 1) {\n      inp_sizes_.insert(inp_sizes_.begin(), size_i);\n      inp_strides_.insert(inp_strides_.begin(), stride_i);\n    }\n  }\n  // Step (1)~(4) 
for the algorithm in NOTE [ Detecting Memory Overlap Within A\n  // Strided Tensor ]\n  //              on input geometry\n  const bool inp_maybe_overlap = IsOverlappingMemorys(inp_sizes_, inp_strides_);\n\n  // Rest of this function implements\n  // Step (1)~(4) for the algorithm in NOTE [ as_strided Backward and\n  // layout-aware/agnostic autograd ]\n  // TODO: Raise if not all output values are visible in input geometry.\n  //       Technically speaking, if you treat those values as constants, not\n  //       raising is fine, and mathematically correct. However, these values\n  //       really are contained in some base tensor, and by treating them as\n  //       constants we are ignoring this tight dependency. Therefore, it is\n  //       more sensible to raise here.\n\n  // Step (1): create underlying tensor as \"storage\"\n  auto input_storage_offset = JUST(input->storage_offset());\n  auto shared_offset = std::min(input_storage_offset, storage_offset);\n  auto inp_effective_offset = input_storage_offset - shared_offset;\n  auto out_effective_offset = storage_offset - shared_offset;\n  auto base_size = std::max(MinStorageSize(inp_sizes_, inp_strides_, inp_effective_offset),\n                            MinStorageSize(out_sizes_, out_strides_, out_effective_offset));\n  auto storage =\n      JUST(functional::Constant(Shape({base_size}), 0, grad->dtype(), JUST(grad->device())));\n\n  std::shared_ptr<Tensor> flatten_full_indices;\n  if (inp_maybe_overlap || out_maybe_overlap) {\n    flatten_full_indices = JUST(functional::Arange(Scalar(0), Scalar(base_size), Scalar(1),\n                                                   DType::Int64(), JUST(grad->device())));\n  }\n\n  // Step (2): use output geometry to scatter gradients into storage\n  if (out_maybe_overlap) {\n    auto out_indices = JUST(functional::AsStrided(flatten_full_indices, out_sizes_, out_strides_,\n                                                  out_effective_offset));\n    storage = JUST(functional::IndexAddInplace(\n        storage, 0,\n        JUST(functional::Reshape(out_indices, Shape({out_indices->shape()->elem_cnt()}))),\n        JUST(functional::Reshape(grad, Shape({grad->shape()->elem_cnt()}))), Scalar(1.0)));\n  } else {\n    // assume that new tensors have 0 storage offset\n    // torch impl: storage.as_strided(out_sizes_, out_strides_, out_effective_offset)\n    //     .copy_(grad);\n    // TODO(wangyinggang): use functional::copy_ replace this TensorSetItem\n    storage = JUST(functional::AsStrided(storage, out_sizes_, out_strides_, out_effective_offset));\n    functional::TensorIndex ellipsis_index;\n    ellipsis_index.emplace_back(functional::detail::EllipsisIndex());\n    JUST(functional::TensorSetItem(storage, ellipsis_index, grad));\n  }\n\n  // Step (3): if input tensor has overlapping memory, divide scattered gradient\n  //           at storage[i] by the number of times i shows up in input geometry\n  if (inp_maybe_overlap) {\n    auto count =\n        JUST(functional::Constant(*storage->shape(), 0, storage->dtype(), JUST(storage->device())));\n    flatten_full_indices = JUST(functional::AsStrided(flatten_full_indices, inp_sizes_,\n                                                      inp_strides_, inp_effective_offset));\n    auto inp_indices = JUST(functional::Reshape(\n        flatten_full_indices, Shape({flatten_full_indices->shape()->elem_cnt()})));\n\n    auto ones = JUST(functional::Constant(Shape({1}), 1, grad->dtype(), JUST(grad->device())));\n    count = 
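/* tally how many times each storage element is referenced by the input view */ 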
JUST(functional::IndexAddInplace(count, 0, inp_indices, ones, Scalar(1.0)));\n    count = JUST(functional::Expand(count, *inp_indices->shape()));\n    storage = JUST(functional::Div(storage, count));  // this will give nan outside visible range\n  }\n\n  // Step (4): return as_strided view of the storage tensor with input geometry\n  return functional::AsStrided(storage, inp_sizes, inp_strides, inp_effective_offset);\n}\n\nMaybe<Tensor> AsStrided(const std::shared_ptr<one::Tensor>& input,\n                        const std::vector<int64_t>& sizes, const std::vector<int64_t>& strides,\n                        const int64_t storage_offset) {\n  DimVector dim_vec;\n  dim_vec.insert(dim_vec.end(), sizes.begin(), sizes.end());\n  Shape target_shape(dim_vec);\n  Stride stride(strides.begin(), strides.end());\n  auto output = JUST(view::BasicView(input, target_shape, stride, storage_offset));\n  if (autograd::GradMode::is_enabled() && input->requires_grad()) {\n    auto backward_fn = std::make_shared<BackwardFunction>();\n    backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,\n                            bool create_graph) -> Maybe<void> {\n      autograd::AutoGradMode mode(create_graph);\n      CHECK_EQ_OR_RETURN(out_grads.size(), 1)\n          << \"out grad size should be 1, but got \" << out_grads.size();\n      in_grads->resize(1);\n      (*in_grads)[0] = JUST(AsStridedGrad(out_grads[0], input, sizes, strides, storage_offset));\n      return Maybe<void>::Ok();\n    };\n    backward_fn->status = []() { return true; };\n    TensorTuple outputs{output};\n    JUST(GetThreadLocalAutogradEngine()->AddNode(\"view::as_strided_backward\", backward_fn, {input},\n                                                 &outputs));\n  }\n  return output;\n}\n\nMaybe<void> InplaceAsStrided(const std::shared_ptr<one::Tensor>& input,\n                             const std::vector<int64_t>& sizes, const std::vector<int64_t>& strides,\n                             const int64_t storage_offset) {\n  DimVector dim_vec;\n  dim_vec.insert(dim_vec.end(), sizes.begin(), sizes.end());\n  Shape target_shape(dim_vec);\n  Stride stride(strides.begin(), strides.end());\n  JUST(view::InplaceView(input, target_shape, stride, storage_offset));\n  if (autograd::GradMode::is_enabled() && input->requires_grad()) {\n    auto backward_fn = std::make_shared<BackwardFunction>();\n    backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,\n                            bool create_graph) -> Maybe<void> {\n      autograd::AutoGradMode mode(create_graph);\n      CHECK_EQ_OR_RETURN(out_grads.size(), 1)\n          << \"out grad size should be 1, but got \" << out_grads.size();\n      in_grads->resize(1);\n      (*in_grads)[0] = JUST(AsStridedGrad(out_grads[0], input, sizes, strides, storage_offset));\n      return Maybe<void>::Ok();\n    };\n    backward_fn->status = []() { return true; };\n    TensorTuple outputs{input};\n    JUST(GetThreadLocalAutogradEngine()->AddNode(\"view::inplace_as_strided_backward\", backward_fn,\n                                                 {input}, &outputs));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<Tensor> Transpose(const std::shared_ptr<Tensor>& input, const std::vector<int32_t>& permute) {\n  const auto& shape = input->shape();\n  const auto& strides = JUST(input->stride());\n  const int64_t ndim = shape->NumAxes();\n  int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset());\n\n  CHECK_EQ_OR_RETURN(permute.size(), ndim)\n      << 
\"permute size should be equal to input tensor's ndim, but got \" << permute.size();\n  auto positive_perm = permute;\n  for (auto i = 0; i < positive_perm.size(); i++) {\n    positive_perm[i] = JUST(maybe_wrap_dim(positive_perm[i], ndim));\n  }\n\n  DimVector target_dims(ndim);\n  Stride stride(ndim);\n  for (int i = 0; i < ndim; ++i) {\n    target_dims[i] = shape->At(permute[i]);\n    stride[i] = strides->at(permute[i]);\n  }\n\n  auto output = JUST(BasicView(input, Shape(target_dims), stride, storage_offset));\n  if (autograd::GradMode::is_enabled() && input->requires_grad()) {\n    auto backward_fn = std::make_shared<BackwardFunction>();\n    backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,\n                            bool create_graph) -> Maybe<void> {\n      std::vector<int32_t> grad_perm;\n      grad_perm.resize(ndim);\n      for (int i = 0; i < ndim; ++i) { grad_perm[permute[i]] = i; }\n      autograd::AutoGradMode mode(create_graph);\n      CHECK_EQ_OR_RETURN(out_grads.size(), 1)\n          << \"out grad size should be 1, but got \" << out_grads.size();\n      in_grads->resize(1);\n      (*in_grads)[0] = JUST(functional::Transpose(out_grads[0], grad_perm));\n      return Maybe<void>::Ok();\n    };\n    backward_fn->status = []() { return true; };\n    TensorTuple outputs{output};\n    JUST(GetThreadLocalAutogradEngine()->AddNode(\"view::transpose_backward\", backward_fn, {input},\n                                                 &outputs));\n  }\n  return output;\n}\n\nMaybe<Tensor> UnfoldTensor(const std::shared_ptr<Tensor>& input, const int32_t dimension,\n                           const int32_t size, const int32_t step) {\n  const auto& shape = input->shape();\n  const auto& stride = JUST(input->stride());\n  const int64_t ndim = shape->NumAxes();\n  int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset());\n\n  CHECK_GE_OR_RETURN(dimension, 0) << \"attibute dimension should be >= 0, but got \" << dimension;\n  CHECK_LE_OR_RETURN(dimension, ndim)\n      << \"attibute dimension should be <= input tensor's ndim, but got \" << dimension;\n\n  const int32_t max_size = ndim == 0 ? 1 : shape->At(dimension);\n  CHECK_GT_OR_RETURN(size, 0) << \"attibute size should be > 0, but got \" << size;\n  CHECK_LE_OR_RETURN(size, max_size)\n      << \"attibute size should be <= max_size(\" << max_size << \") but got \" << size;\n  CHECK_GT_OR_RETURN(step, 0) << \"attibute step should be > 0, but got \" << size;\n\n  DimVector out_shape(ndim + 1);\n  Stride out_stride(ndim + 1);\n  out_shape[ndim] = size;\n  out_stride[ndim] = ndim == 0 ? 
1 : stride->at(dimension);\n  for (int64_t d = 0; d < ndim; ++d) {\n    const int64_t in_size_at_d = shape->At(d);\n    if (d == dimension) {\n      out_shape.at(d) = (in_size_at_d - size) / step + 1;\n      out_stride.at(d) = step * stride->at(d);\n    } else {\n      out_shape.at(d) = in_size_at_d;\n      out_stride.at(d) = stride->at(d);\n    }\n  }\n  auto output = JUST(BasicView(input, Shape(out_shape), out_stride, storage_offset));\n\n  if (autograd::GradMode::is_enabled() && input->requires_grad()) {\n    auto backward_fn = std::make_shared<BackwardFunction>();\n    backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,\n                            bool create_graph) -> Maybe<void> {\n      autograd::AutoGradMode mode(create_graph);\n      CHECK_EQ_OR_RETURN(out_grads.size(), 1)\n          << \"out grad size should be 1, but got \" << out_grads.size();\n      in_grads->resize(1);\n      (*in_grads)[0] =\n          JUST(functional::UnfoldTensorGrad(out_grads[0], input, dimension, size, step));\n      return Maybe<void>::Ok();\n    };\n    backward_fn->status = []() { return true; };\n    TensorTuple outputs{output};\n    JUST(GetThreadLocalAutogradEngine()->AddNode(\"view::unfold_tensor_backward\", backward_fn,\n                                                 {input}, &outputs));\n  }\n\n  return output;\n}\n\nMaybe<Tensor> Diagonal(const std::shared_ptr<Tensor>& input, const int32_t offset,\n                       const int32_t dim1, const int32_t dim2) {\n  const auto& shape = input->shape();\n  const auto& stride = JUST(input->stride());\n  const int64_t ndim = shape->NumAxes();\n  int64_t storage_offset = JUST(JUST(input->AsLocalTensor())->storage_offset());\n\n  // infer output storage_offset\n  int64_t diag_size = 0;\n  if (offset >= 0) {\n    diag_size = std::max<int64_t>(std::min(shape->At(dim1), shape->At(dim2) - offset), 0);\n  } else {\n    diag_size = std::max<int64_t>(std::min(shape->At(dim1) + offset, shape->At(dim2)), 0);\n  }\n  if (diag_size == 0) {\n    // skip\n  } else if (offset >= 0) {\n    storage_offset += offset * stride->at(dim2);\n  } else {\n    storage_offset -= offset * stride->at(dim1);\n  }\n\n  CHECK_GE_OR_RETURN(ndim, 2) << \"input tensor's ndim should be >= 2, but got \" << ndim;\n  // infer output shape and stride\n  DimVector out_shape(shape->dim_vec());\n  Stride out_stride(*stride);\n  out_shape.erase(out_shape.begin() + std::max(dim1, dim2));\n  out_stride.erase(out_stride.begin() + std::max(dim1, dim2));\n  out_shape.erase(out_shape.begin() + std::min(dim1, dim2));\n  out_stride.erase(out_stride.begin() + std::min(dim1, dim2));\n  out_shape.emplace_back(diag_size);\n  out_stride.emplace_back(stride->at(dim1) + stride->at(dim2));\n\n  // generate view tensor\n  auto output = JUST(BasicView(input, Shape(out_shape), out_stride, storage_offset));\n  // autograd\n  if (autograd::GradMode::is_enabled() && input->requires_grad()) {\n    std::vector<int32_t> input_index{dim1, dim2};\n    for (int32_t i = 0; i < ndim; i++) {\n      if (i != dim1 && i != dim2) { input_index.push_back(i); }\n    }\n\n    auto backward_fn = std::make_shared<BackwardFunction>();\n    backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,\n                            bool create_graph) -> Maybe<void> {\n      autograd::AutoGradMode mode(create_graph);\n      CHECK_EQ_OR_RETURN(out_grads.size(), 1)\n          << \"out grad size should be 1, but got \" << out_grads.size();\n      in_grads->resize(1);\n      
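// Sketch of the recovery path (comment added for clarity): the input is permuted\n      // so that (dim1, dim2) come first, matching input_index built above, and\n      // DiagonalGrad scatters the incoming gradient back along that diagonal.\n      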
std::shared_ptr<one::Tensor> d_x = JUST(functional::Transpose(input, input_index));\n      (*in_grads)[0] = JUST(functional::DiagonalGrad(out_grads[0], d_x, offset));\n      return Maybe<void>::Ok();\n    };\n    backward_fn->status = []() { return true; };\n    TensorTuple outputs{output};\n    JUST(GetThreadLocalAutogradEngine()->AddNode(\"view::diagonal_backward\", backward_fn, {input},\n                                                 &outputs));\n  }\n\n  return output;\n}\n\n}  // namespace view\n\nMaybe<void> Touch(std::shared_ptr<Tensor> input, Symbol<Stream> stream) {\n  auto eager_blob_objects = std::make_shared<vm::EagerBlobObjectList>();\n  if (input->is_global()) { input = JUST(input->cur_rank_phy_tensor()); }\n  if (input) { eager_blob_objects->push_back(JUST(input->eager_blob_object())); }\n  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    return builder->TouchTensors(eager_blob_objects, stream);\n  }));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/tensor_methods.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_METHODS_H_\n#define ONEFLOW_CORE_FRAMEWORK_TENSOR_METHODS_H_\n\n#include \"oneflow/core/framework/tensor.h\"\n\nnamespace oneflow {\n\nclass Stream;\n\nnamespace one {\n\nclass Tensor;\n\nnamespace view {\n\nbool IsEnvViewDisabled();\n\nbool IsViewApplicable(const std::shared_ptr<Tensor>& input);\n\nstatic bool IsOverlappingMemorys(const std::vector<int64_t>& sizes,\n                                 const std::vector<int64_t>& strides);\n\nstatic int64_t MinStorageSize(const std::vector<int64_t>& sizes,\n                              const std::vector<int64_t>& strides, int64_t storage_offset);\n\nMaybe<Tensor> BasicView(const std::shared_ptr<Tensor>& input, const Shape& target_shape,\n                        const int64_t storage_offset);\n\nMaybe<Tensor> BasicView(const std::shared_ptr<Tensor>& input, const Shape& target_shape,\n                        const Stride& target_stride, const int64_t storage_offset);\n\nMaybe<void> InplaceView(const std::shared_ptr<Tensor>& input, const Shape& target_shape,\n                        const Stride& target_stride, int64_t const storage_offset);\n\nMaybe<Tensor> Reshape(const std::shared_ptr<Tensor>& input, const Shape& target_shape);\n\nMaybe<Tensor> Reshape(const std::shared_ptr<Tensor>& input, const Shape& target_shape,\n                      const Stride& target_stride);\n\nMaybe<Tensor> Slice(const std::shared_ptr<Tensor>& input, const std::vector<int64_t>& starts,\n                    const std::vector<int64_t>& ends, const std::vector<int64_t>& steps);\n\nMaybe<Tensor> Unsqueeze(const std::shared_ptr<Tensor>& input, const int32_t expand_dim);\n\nMaybe<void> InplaceUnsqueeze(const std::shared_ptr<Tensor>& input, const int32_t expand_dim);\n\nMaybe<Tensor> Squeeze(const std::shared_ptr<Tensor>& input,\n                      const std::vector<int32_t>& squeeze_dims);\n\nMaybe<void> InplaceSqueeze(const std::shared_ptr<Tensor>& input,\n                           const std::vector<int32_t>& squeeze_dims);\n\nMaybe<Tensor> Expand(const std::shared_ptr<Tensor>& input, const Shape& expand_shape);\n\nMaybe<void> InplaceExpand(const std::shared_ptr<Tensor>& input, const Shape& expand_shape);\n\nMaybe<Tensor> Narrow(const std::shared_ptr<Tensor>& input, const int64_t dim, const int64_t start,\n                     const int64_t length);\n\nMaybe<Tensor> AsStridedGrad(const std::shared_ptr<one::Tensor>& dy,\n                            const std::shared_ptr<one::Tensor>& input,\n                            const std::vector<int64_t>& sizes, const std::vector<int64_t>& strides,\n                            const int64_t storage_offset);\n\nMaybe<Tensor> AsStrided(const std::shared_ptr<one::Tensor>& input,\n                        const std::vector<int64_t>& sizes, const std::vector<int64_t>& strides,\n                        const int64_t storage_offset);\n\nMaybe<void> InplaceAsStrided(const 
std::shared_ptr<one::Tensor>& input,\n                             const std::vector<int64_t>& sizes, const std::vector<int64_t>& strides,\n                             const int64_t storage_offset);\n\nMaybe<Tensor> Transpose(const std::shared_ptr<Tensor>& input, const std::vector<int32_t>& permute);\n\nMaybe<Tensor> UnfoldTensor(const std::shared_ptr<Tensor>& input, const int32_t dimension,\n                           const int32_t size, const int32_t step);\n\nMaybe<Tensor> Diagonal(const std::shared_ptr<Tensor>& input, const int32_t offset,\n                       const int32_t dim1, const int32_t dim2);\n\n}  // namespace view\n\nMaybe<void> Touch(std::shared_ptr<Tensor> input, Symbol<Stream> stream);\n\n}  // namespace one\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_TENSOR_METHODS_H_\n"
  },
  {
    "path": "oneflow/core/framework/tensor_name_scope.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/tensor_name_scope.h\"\n#include <cstdint>\n\nnamespace oneflow {\nnamespace one {\n\n/* static */ TensorNameScope* TensorNameScope::Global() {\n  static TensorNameScope scope;\n  return &scope;\n}\n\nconst std::string& TensorNameScope::Lookup(const Tensor* tensor) const {\n  uint64_t key = reinterpret_cast<uint64_t>(tensor);\n  const auto* tensor_names = [&]() {\n    if (tensor->is_lazy()) { return &lazy_tensor_names_; }\n    return &eager_tensor_names_;\n  }();\n  std::lock_guard<std::mutex> lock(mutex_);\n  const auto& it = tensor_names->find(key);\n  if (it != tensor_names->end()) {\n    return it->second;\n  } else {\n    return default_tensor_name_;\n  }\n}\n\nconst std::string& TensorNameScope::Lookup(const std::shared_ptr<Tensor>& tensor) const {\n  return Lookup(tensor.get());\n}\n\nvoid TensorNameScope::Record(const Tensor* tensor, const std::string& name) {\n  uint64_t key = reinterpret_cast<uint64_t>(tensor);\n  auto* tensor_names = [&]() {\n    if (tensor->is_lazy()) { return &lazy_tensor_names_; }\n    return &eager_tensor_names_;\n  }();\n  std::lock_guard<std::mutex> lock(mutex_);\n  // We assume that the name of the tensor will be update more than once.\n  (*tensor_names)[key] = name;\n}\n\nvoid TensorNameScope::Record(const std::shared_ptr<Tensor>& tensor, const std::string& name) {\n  Record(tensor.get(), name);\n}\n\nvoid TensorNameScope::Clear() {\n  std::lock_guard<std::mutex> lock(mutex_);\n  lazy_tensor_names_.clear();\n  eager_tensor_names_.clear();\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/tensor_name_scope.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_NAME_SCOPE_H_\n#define ONEFLOW_CORE_FRAMEWORK_TENSOR_NAME_SCOPE_H_\n\n#include <string>\n\n#include \"oneflow/core/framework/tensor.h\"\n\nnamespace oneflow {\nnamespace one {\n\nclass TensorNameScope {\n public:\n  static TensorNameScope* Global();\n\n  const std::string& Lookup(const Tensor* tensor) const;\n  const std::string& Lookup(const std::shared_ptr<Tensor>& tensor) const;\n\n  void Record(const Tensor* tensor, const std::string& name);\n  void Record(const std::shared_ptr<Tensor>& tensor, const std::string& name);\n\n  void Clear();\n\n private:\n  TensorNameScope() : default_tensor_name_(\"\") {}\n  virtual ~TensorNameScope() = default;\n\n private:\n  mutable std::mutex mutex_;\n\n  std::string default_tensor_name_;\n  // uint64_t(Tensor*) -> the name of the tensor.\n  std::unordered_map<uint64_t, std::string> lazy_tensor_names_;\n  std::unordered_map<uint64_t, std::string> eager_tensor_names_;\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_TENSOR_NAME_SCOPE_H_\n"
  },
  {
    "path": "oneflow/core/framework/tensor_rpc_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <memory>\n#include \"oneflow/core/framework/tensor_rpc_util.h\"\n#include \"oneflow/core/framework/sync_symbol_global_tensor_meta.h\"\n#include \"oneflow/core/framework/sync_symbol_nd_sbp.h\"\n#include \"oneflow/core/framework/synced_symbol_map.h\"\n#include \"oneflow/core/framework/rank_group_rpc_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/common/flat_shape.h\"\n#include \"oneflow/core/common/shape_vec.h\"\n#include \"oneflow/core/intrusive/flat_msg.h\"\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/job/rank_group_scope.h\"\n#include \"oneflow/core/common/constant.h\"\n\nnamespace oneflow {\nnamespace private_details {\n\nstruct FlatTensorConsistency;\n\nclass CheckConsistencyAsyncTransportCtx : public AsyncTransportCtx {\n public:\n  CheckConsistencyAsyncTransportCtx(const TransportToken& transport_token,\n                                    Symbol<one::GlobalTensorMeta> tensor_meta,\n                                    const Optional<Symbol<NdSbp>>& consumer_nd_sbp_constraint,\n                                    const TransportToken& tensor_transport_token)\n      : AsyncTransportCtx(transport_token),\n        tensor_meta_(tensor_meta),\n        consumer_nd_sbp_constraint_(consumer_nd_sbp_constraint),\n        tensor_transport_token_(tensor_transport_token) {}\n\n  ~CheckConsistencyAsyncTransportCtx() override;\n\n  Maybe<void> PrepareSendBufferAndCallback(int64_t rank, void** buffer, std::size_t* size,\n                                           std::function<void()>* Callback) override;\n\n  Maybe<void> PrepareRecvBufferAndCallback(int64_t rank, void** buffer, std::size_t* size,\n                                           std::function<void()>* Callback) override;\n\n  Maybe<void> Check() const;\n\n private:\n  Symbol<one::GlobalTensorMeta> tensor_meta_;\n  Optional<Symbol<NdSbp>> consumer_nd_sbp_constraint_;\n  TransportToken tensor_transport_token_;\n  std::shared_ptr<FlatTensorConsistency> flat_tensor_consistency_;\n};\n\n// clang-format off\nFLAT_MSG_BEGIN(FlatTensorConsistency);\n public:\n  static Maybe<FlatTensorConsistency> New() {\n    const auto& consistency = std::make_shared<FlatTensorConsistency>();\n    consistency->clear();\n    return consistency;\n  }\n  static Maybe<FlatTensorConsistency> New(\n      Symbol<one::GlobalTensorMeta> tensor_meta,\n      const Optional<Symbol<NdSbp>>& consumer_nd_sbp_constraint,\n      const TransportToken& tensor_transport_token) {\n    const auto& consistency = std::make_shared<FlatTensorConsistency>();\n    consistency->clear();\n    JUST(consistency->Init(tensor_meta, consumer_nd_sbp_constraint, tensor_transport_token));\n    return consistency;\n  }\n\n  Maybe<void> Check(Symbol<one::GlobalTensorMeta> tensor_meta,\n    const Optional<Symbol<NdSbp>>& consumer_nd_sbp_constraint,\n                    const TransportToken& tensor_transport_token) {\n    const 
auto& this_synced_tensor_meta =\n        JUST(SyncedSymbolMap<one::GlobalTensorMeta>::Symbol4SyncedSymbolId(\n            this->synced_tensor_meta_symbol_id()));\n    CHECK_OR_RETURN(this_synced_tensor_meta == tensor_meta);\n    CHECK_EQ_OR_RETURN(consumer_nd_sbp_constraint.has_value(),\n                       this->has_consumer_nd_sbp_constraint_symbol_id());\n    if (this->has_consumer_nd_sbp_constraint_symbol_id()) {\n      const auto& that_rank_constraint =\n          JUST(SyncedSymbolMap<one::GlobalTensorMeta>::Symbol4SyncedSymbolId(\n            this->consumer_nd_sbp_constraint_symbol_id()))->nd_sbp();\n      const auto& this_rank_constraint = JUST(consumer_nd_sbp_constraint);\n      CHECK_OR_RETURN(this_rank_constraint == that_rank_constraint);\n    }\n    CHECK_EQ_OR_RETURN(this->tensor_transport_token(), tensor_transport_token);\n    return Maybe<void>::Ok();\n  }\n\n private:\n  Maybe<void> Init(Symbol<one::GlobalTensorMeta> tensor_meta,\n    const Optional<Symbol<NdSbp>>& consumer_nd_sbp_constraint,\n                   const TransportToken& tensor_transport_token) {\n    this->set_synced_tensor_meta_symbol_id(JUST(SyncedSymbolMap<one::GlobalTensorMeta>::FindOrSync(\n        tensor_meta, &SyncSymbolGlobalTensorMeta)));\n    if (consumer_nd_sbp_constraint.has_value()) {\n      const auto& this_rank_constraint = JUST(consumer_nd_sbp_constraint);\n      this->set_consumer_nd_sbp_constraint_symbol_id(\n        JUST(SyncedSymbolMap<NdSbp>::FindOrSync(\n              this_rank_constraint, &SyncSymbolNdSbp)));\n    } else {\n      this->clear_consumer_nd_sbp_constraint_symbol_id();\n    }\n    this->set_tensor_transport_token(static_cast<uint64_t>(tensor_transport_token));\n    return Maybe<void>::Ok();\n  }\n\n  FLAT_MSG_DEFINE_OPTIONAL(uint64_t, synced_tensor_meta_symbol_id);\n  FLAT_MSG_DEFINE_OPTIONAL(uint64_t, consumer_nd_sbp_constraint_symbol_id);\n  FLAT_MSG_DEFINE_OPTIONAL(uint64_t, tensor_transport_token);\nFLAT_MSG_END(FlatTensorConsistency);\n// clang-format on\n\nCheckConsistencyAsyncTransportCtx::~CheckConsistencyAsyncTransportCtx() {}\n\nMaybe<void> CheckConsistencyAsyncTransportCtx::PrepareSendBufferAndCallback(\n    int64_t rank, void** buffer, std::size_t* size, std::function<void()>* Callback) {\n  const auto& tensor_consistency = JUST(FlatTensorConsistency::New(\n      tensor_meta_, consumer_nd_sbp_constraint_, tensor_transport_token_));\n  *buffer = tensor_consistency.get();\n  *size = sizeof(FlatTensorConsistency);\n  *Callback = [tensor_consistency] {};\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckConsistencyAsyncTransportCtx::PrepareRecvBufferAndCallback(\n    int64_t rank, void** buffer, std::size_t* size, std::function<void()>* Callback) {\n  const auto& flat_tensor_consistency = JUST(FlatTensorConsistency::New());\n  *buffer = flat_tensor_consistency.get();\n  *size = sizeof(FlatTensorConsistency);\n  *Callback = [flat_tensor_consistency]() {};\n  flat_tensor_consistency_ = flat_tensor_consistency;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckConsistencyAsyncTransportCtx::Check() const {\n  if (!flat_tensor_consistency_) { return Maybe<void>::Ok(); }\n  JUST(flat_tensor_consistency_->Check(tensor_meta_, consumer_nd_sbp_constraint_,\n                                       tensor_transport_token_));\n  return Maybe<void>::Ok();\n}\n\nint64_t* MutThreadLocalTensorMetaCheckDepth() {\n  static thread_local int64_t depth = 0;\n  return &depth;\n}\n\nMaybe<CheckConsistencyAsyncTransportCtx> LaunchTensorMetaConsistencyCheck(\n    const one::Tensor& tensor) 
{\n  const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());\n  const auto& transport_token =\n      JUST(TransportToken::NewTransportToken(kTransportTokenTypeCheckTensorConsistency));\n  const auto& tensor_meta = JUST(tensor.global_tensor_meta());\n  const auto& constraint = JUST(tensor.consumer_nd_sbp_constraint());\n  const TransportToken& tensor_transport_token = JUST(tensor.transport_token());\n  const auto& ctx = std::make_shared<CheckConsistencyAsyncTransportCtx>(\n      transport_token, tensor_meta, constraint, tensor_transport_token);\n  JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, ctx.get()));\n  JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, ctx.get()));\n  return ctx;\n}\n\nMaybe<void> BusyWaitAndCheck(std::shared_ptr<CheckConsistencyAsyncTransportCtx>& ctx) {\n  JUST_MSG(ctx->WaitDone(), kAsymmetricCodeErrorMsg);\n  JUST(ctx->Check());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RunCallback(const std::shared_ptr<one::Tensor>& tensor,\n                        const std::function<Maybe<void>()>& Callback) {\n  return Callback();\n}\n\n}  // namespace private_details\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/tensor_rpc_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_RPC_UTIL_H_\n#define ONEFLOW_CORE_FRAMEWORK_TENSOR_RPC_UTIL_H_\n\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/common/check_level.h\"\n\nnamespace oneflow {\n\nnamespace private_details {\n\nclass CheckConsistencyAsyncTransportCtx;\n\nint64_t* MutThreadLocalTensorMetaCheckDepth();\n\nMaybe<CheckConsistencyAsyncTransportCtx> LaunchTensorMetaConsistencyCheck(\n    const one::Tensor& tensor);\n\nMaybe<void> BusyWaitAndCheck(std::shared_ptr<CheckConsistencyAsyncTransportCtx>& ctx);\n\nMaybe<void> RunCallback(const std::shared_ptr<one::Tensor>& tensor,\n                        const std::function<Maybe<void>()>& Callback);\n\n}  // namespace private_details\n\ninline bool IsGlobalTensorMetaCheckDisabled() {\n  return *private_details::MutThreadLocalTensorMetaCheckDepth() > 1;\n}\n\ntemplate<typename... Args>\nstruct CheckGlobalTensorMeta;\n\ntemplate<typename RetT, typename... Args>\nstruct CheckGlobalTensorMeta<RetT, const std::shared_ptr<one::Tensor>&, Args...> {\n  static_assert(is_maybe<RetT>::value, \"returned value type must be Maybe<T>.\");\n  template<RetT (*func)(const std::shared_ptr<one::Tensor>&, Args...)>\n  static RetT Call(const std::shared_ptr<one::Tensor>& tensor, Args... args) {\n    std::shared_ptr<private_details::CheckConsistencyAsyncTransportCtx> ctx;\n    static bool is_env_enabled_check = IsEnvEnabled(/* check_level */ 1);\n    int64_t* depth = private_details::MutThreadLocalTensorMetaCheckDepth();\n    if (*depth == 0 && is_env_enabled_check) {\n      ctx = JUST(private_details::LaunchTensorMetaConsistencyCheck(*tensor));\n    }\n    ++*depth;\n    RetT ret = func(tensor, args...);\n    --*depth;\n    // Always synchronize global tensor meta even if `func` failed.\n    if (*depth == 0 && is_env_enabled_check) { JUST(private_details::BusyWaitAndCheck(ctx)); }\n    return ret;\n  }\n};\n\nstruct DisableCheckGlobalTensorMetaScope final {\n  DisableCheckGlobalTensorMetaScope() { ++*private_details::MutThreadLocalTensorMetaCheckDepth(); }\n  ~DisableCheckGlobalTensorMetaScope() { --*private_details::MutThreadLocalTensorMetaCheckDepth(); }\n};\n\nstatic constexpr auto* WithConsistencyChecked =\n    DECORATE(&private_details::RunCallback, CheckGlobalTensorMeta);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_TENSOR_RPC_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/framework/tensor_storage.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/tensor_storage.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/framework/shut_down_util.h\"\n\nnamespace oneflow {\nnamespace one {\n\nTensorStorage::TensorStorage(const std::shared_ptr<vm::TensorStorage>& tensor_storage)\n    : storage_(tensor_storage) {}\n\nTensorStorage::~TensorStorage() {\n  if (!IsShuttingDown() && releaser_hook_) { (*releaser_hook_)(storage_); }\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/tensor_storage.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_STORAGE_H_\n#define ONEFLOW_CORE_FRAMEWORK_TENSOR_STORAGE_H_\n\n#include <memory>\n#include <functional>\n\nnamespace oneflow {\n\nclass ParallelDesc;\n\nnamespace vm {\n\nclass TensorStorage;\n\n}  // namespace vm\n\nnamespace one {\n\nclass TensorStorage final {\n public:\n  explicit TensorStorage(const std::shared_ptr<vm::TensorStorage>& tensor_storage);\n  ~TensorStorage();\n\n  using ReleaserHookT = std::function<void(const std::shared_ptr<vm::TensorStorage>&)>;\n\n  const std::shared_ptr<vm::TensorStorage> storage() const { return storage_; }\n\n  void set_releaser_hook(const ReleaserHookT& releaser_hook) {\n    releaser_hook_ = std::make_shared<ReleaserHookT>(releaser_hook);\n  }\n\n private:\n  std::shared_ptr<vm::TensorStorage> storage_;\n  std::shared_ptr<ReleaserHookT> releaser_hook_;\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_TENSOR_STORAGE_H_\n"
  },
  {
    "path": "oneflow/core/framework/tensor_tuple.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/tensor_tuple.h\"\n\nnamespace oneflow {\nnamespace one {\n\nTensorTuple::TensorTuple(std::vector<std::shared_ptr<Tensor>>::size_type size) { resize(size); }\n\nTensorTuple::TensorTuple(std::initializer_list<std::shared_ptr<Tensor>> init_list) {\n  for (const auto& tensor : init_list) { emplace_back(tensor); }\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/tensor_tuple.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_TUPLE_H_\n#define ONEFLOW_CORE_FRAMEWORK_TENSOR_TUPLE_H_\n\n#include <memory>\n#include <vector>\n#include \"oneflow/core/common/small_vector.h\"\n#include \"oneflow/core/common/op_args_reserved_size.h\"\n\nnamespace oneflow {\nnamespace one {\n\nclass Tensor;\n\nclass TensorTuple final : public small_vector<std::shared_ptr<Tensor>>,\n                          public std::enable_shared_from_this<TensorTuple> {\n public:\n  // TensorTuple(const TensorTuple&) = delete;\n  // TensorTuple(TensorTuple&) = delete;\n  TensorTuple() = default;\n  TensorTuple(std::vector<std::shared_ptr<Tensor>>::size_type size);\n  TensorTuple(std::initializer_list<std::shared_ptr<Tensor>> init_list);\n  ~TensorTuple() = default;\n};\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_TENSOR_TUPLE_H_\n"
  },
  {
    "path": "oneflow/core/framework/tensor_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/tensor_util.h\"\n#include \"oneflow/core/common/blocking_then_busy.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/tensor_name_scope.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx_mgr.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\n\nMaybe<void> SyncAccessTensorWithTimeOut(\n    const std::shared_ptr<Tensor>& tensor,\n    const std::function<void(ep::Stream*, const std::shared_ptr<vm::EagerBlobObject>&)>& Callback,\n    const std::string& modifier) {\n  auto btb = std::make_shared<BlockingThenBusy>();\n  auto local_tensor = JUST(tensor->AsLocalTensor());\n  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    return builder->SyncAccessBlobByCallback(local_tensor, btb, Callback, modifier);\n  }));\n  JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CopyLocalTensorDataTo(const std::shared_ptr<Tensor>& input, void* mem_ptr,\n                                  size_t size) {\n  CHECK_OR_RETURN(input->is_local());  // NOLINT\n  CHECK_OR_RETURN(input->is_contiguous()) << Error::RuntimeError() << kOfBugIssueUploadPrompt;\n  CHECK_EQ_OR_RETURN(input->shape()->elem_cnt() * JUST(input->dtype()->bytes()), size)\n      << Error::RuntimeError() << kOfBugIssueUploadPrompt;\n  if (input->nelement() == 1) { return GetItemInScalarTensor(input, mem_ptr, size); }\n  std::shared_ptr<one::LocalTensor> local_tensor = JUST(input->AsLocalTensor());\n  const auto& Callback = [&](ep::Stream* stream,\n                             const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n    SyncAutoMemcpy(stream, mem_ptr, eager_blob_object->dptr(), size, memory::MakeHostMemCase(),\n                   eager_blob_object->mem_case());\n  };\n  auto btb = std::make_shared<BlockingThenBusy>();\n  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    return builder->SyncAccessBlobByCallback(local_tensor, btb, Callback, \"const\");\n  }));\n  JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));\n  return Maybe<void>::Ok();\n}\n\nMaybe<Scope> GetTensorScope(const std::shared_ptr<Tensor>& tensor) {\n  CHECK_OR_RETURN(LazyMode::is_enabled())\n      << \"it's not allowed to access tensor scope in eager mode\";\n  const auto& lbn = TensorNameScope::Global()->Lookup(tensor);\n  CHECK_OR_RETURN(!lbn.empty()) << \"can not access tensor scope since it is not a lazy tensor or a \"\n                                   \"captured eager tensor in graph\";\n  const auto& infer_ctx = JUST(GetCurInferCtx());\n  auto lbi = GenLogicalBlobId(lbn);\n  const auto* op = 
JUST(infer_ctx->Op4OpName(lbi.op_name()));\n  return Singleton<symbol::Storage<Scope>>::Get()->MaybeGetPtr(op->op_conf().scope_symbol_id());\n}\n\nMaybe<void> GetItemInScalarTensor(const std::shared_ptr<Tensor>& scalar_tensor, void* scalar_ptr,\n                                  size_t size) {\n  CHECK_EQ_OR_RETURN(GetSizeOfDataType(scalar_tensor->dtype()->data_type()), size)\n      << \"invalid size\";\n  CHECK_OR_RETURN(scalar_tensor->is_eager()) << \"Only eager scalar tensor support GetItem.\";\n  CHECK_EQ_OR_RETURN(scalar_tensor->nelement(), 1)\n      << \"can only convert a tensor of size 1 to a Python scalar\";\n  std::shared_ptr<LocalTensor> local_tensor;\n  {\n    auto tensor = scalar_tensor;\n    if (tensor->is_global()) {\n      Symbol<ParallelDesc> parallel_desc;\n      {\n        const ParallelConf parallel_conf = GenParallelConfOfCpuOnAllRanks();\n        JUST(PhysicalRun(\n            [&parallel_desc, &parallel_conf](InstructionsBuilder* builder) -> Maybe<void> {\n              parallel_desc = SymbolOf(*JUST(builder->GetParallelDescSymbol(parallel_conf)));\n              return Maybe<void>::Ok();\n            }));\n      }\n      const auto& broadcast_sbp = JUST(MakeBroadcastSbpParallel());\n      tensor = JUST(functional::ToGlobal(tensor, parallel_desc, {broadcast_sbp}, /*grad_sbp=*/{},\n                                         /*check_meta=*/false, /*copy=*/false));\n      tensor = JUST(functional::GlobalToLocal(tensor, /*copy=*/false));\n    }\n    local_tensor = JUST(tensor->AsLocalTensor());\n  }\n  JUST(SyncReadSmallMem(reinterpret_cast<char*>(scalar_ptr), size, local_tensor));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/tensor_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_\n#define ONEFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_\n\n#include <functional>\n#include <string>\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/common/data_type.h\"\n\nnamespace oneflow {\n\nnamespace ep {\nclass Stream;\n}\n\nnamespace vm {\nclass EagerBlobObject;\n}\n\nnamespace one {\n\nclass Tensor;\n\nMaybe<void> SyncAccessTensorWithTimeOut(\n    const std::shared_ptr<Tensor>& tensor,\n    const std::function<void(ep::Stream*, const std::shared_ptr<vm::EagerBlobObject>&)>& callback,\n    const std::string& modifier);\n\nMaybe<void> CopyLocalTensorDataTo(const std::shared_ptr<Tensor>& input, void* mem_ptr, size_t size);\n\nMaybe<Scope> GetTensorScope(const std::shared_ptr<Tensor>& tensor);\n\nMaybe<void> GetItemInScalarTensor(const std::shared_ptr<Tensor>& scalar_tensor, void* scalar_ptr,\n                                  size_t size);\ntemplate<typename T>\nMaybe<T> GetItemInScalarTensor(const std::shared_ptr<Tensor>& scalar_tensor) {\n  T scalar{0};\n  if constexpr (GetDataType<T>() == kInt64) {\n    if (scalar_tensor->dtype()->data_type() == DataType::kInt8\n        || scalar_tensor->dtype()->data_type() == kUInt8) {\n      int8_t int8_integer = 0;\n      JUST(GetItemInScalarTensor(scalar_tensor, &int8_integer, sizeof(int8_t)));\n      scalar = static_cast<T>(int8_integer);\n    } else if (scalar_tensor->dtype()->data_type() == DataType::kInt16\n               || scalar_tensor->dtype()->data_type() == kUInt16) {\n      int16_t int16_integer = 0;\n      JUST(GetItemInScalarTensor(scalar_tensor, &int16_integer, sizeof(int16_t)));\n      scalar = static_cast<T>(int16_integer);\n    } else if (scalar_tensor->dtype()->data_type() == DataType::kInt32\n               || scalar_tensor->dtype()->data_type() == kUInt32) {\n      int32_t int32_integer = 0;\n      JUST(GetItemInScalarTensor(scalar_tensor, &int32_integer, sizeof(int32_t)));\n      scalar = static_cast<T>(int32_integer);\n    } else if (scalar_tensor->dtype()->data_type() == DataType::kInt64\n               || scalar_tensor->dtype()->data_type() == kUInt64) {\n      int64_t int64_integer = 0;\n      JUST(GetItemInScalarTensor(scalar_tensor, &int64_integer, sizeof(int64_t)));\n      scalar = static_cast<T>(int64_integer);\n    }\n  } else {\n    JUST(GetItemInScalarTensor(scalar_tensor, &scalar, sizeof(T)));\n  }\n  return scalar;\n}\n\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/framework/to_string.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <map>\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n\nnamespace oneflow {\n\nMaybe<std::string> DeviceTag4DeviceType(DeviceType device_type) {\n  auto device_tag = ep::DeviceManagerRegistry::GetDeviceTypeNameByDeviceType(device_type);\n  if (device_tag.empty()) {\n    return Error::DeviceTagNotFoundError() << \"invalid_device\";\n  } else {\n    return device_tag;\n  }\n}\n\nMaybe<DeviceType> DeviceType4DeviceTag(const std::string& device_tag) {\n  auto device_type = ep::DeviceManagerRegistry::GetDeviceTypeByDeviceTypeName(device_tag);\n  if (device_type == DeviceType::kInvalidDevice) {\n    return Error::DeviceTagNotFoundError() << \"device tag `\" << device_tag << \"' not found\";\n  } else {\n    return device_type;\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/to_string.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_TO_STRING_H_\n#define ONEFLOW_CORE_FRAMEWORK_TO_STRING_H_\n\n#include \"oneflow/core/common/to_string.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nMaybe<std::string> DeviceTag4DeviceType(DeviceType device_type);\nMaybe<DeviceType> DeviceType4DeviceTag(const std::string& device_tag);\n\ntemplate<>\ninline std::string ToString(const DataType& data_type) {\n  return DataType_Name(data_type);\n}\n\ntemplate<>\ninline std::string ToString(const DeviceType& device_type) {\n  return DeviceType_Name(device_type);\n}\n\ntemplate<>\ninline std::string ToString(const std::string& value) {\n  return value;\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_TO_STRING_H_\n"
  },
  {
    "path": "oneflow/core/framework/transport_token.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <array>\n#include \"oneflow/core/framework/transport_token.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/thread/thread_global_id.h\"\n#include \"oneflow/core/framework/rank_group_rpc_util.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<TransportToken> TransportToken::NewTransportToken(TransportTokenType type) {\n  int32_t thread_global_id = GetThisThreadGlobalId();\n  CHECK_GE_OR_RETURN(thread_global_id, 0);                             // NOLINT\n  CHECK_LT_OR_RETURN(thread_global_id, MaxNumberOfThreadGlobalUId());  // NOLINT\n  return TransportToken(type, thread_global_id);\n}\n\nMaybe<void> TransportToken::CheckThreadGlobalId() const {\n  int32_t thread_global_id = GetThisThreadGlobalId();\n  CHECK_EQ_OR_RETURN(thread_global_id, this->thread_global_id());  // NOLINT\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> TransportToken::set_src_rank(int64_t val) {\n  CHECK_GE_OR_RETURN(val, 0);\n  CHECK_LT_OR_RETURN(val, GetMaxVal<uint16_t>());\n  src_rank_ = val;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> TransportToken::set_dst_rank(int64_t val) {\n  CHECK_GE_OR_RETURN(val, 0);\n  CHECK_LT_OR_RETURN(val, GetMaxVal<uint16_t>());\n  dst_rank_ = val;\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/transport_token.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_RPC_TOKEN_H_\n#define ONEFLOW_CORE_FRAMEWORK_RPC_TOKEN_H_\n\n#include <functional>\n#include \"oneflow/core/common/type_traits.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/symbol.h\"\n\nnamespace oneflow {\n\nconst static int kTransportTokenTypeBit = 5;\nconst static int kTransportTokenThreadGlobalIdBit = 3;\n\nenum TransportTokenType {\n  // Begin\n  kTransportTokenTypeInvalid = 0,\n  kTransportTokenTypeData,  // e.g. for tensor data transportation\n  kTransportTokenTypeMeta,  // e.g. for consistent id generating\n  kTransportTokenTypeSyncSymbolParallelDesc,\n  kTransportTokenTypeSyncSymbolNdSbp,\n  kTransportTokenTypeSyncSymbolGlobalTensorMeta,\n  kTransportTokenTypeCheckRankGroupConsistency,\n  kTransportTokenTypeCheckTensorConsistency,\n  kTransportTokenTypeSyncLocalShapeDtype,\n  // End\n  kTransportTokenTypeSize,\n};\n\nstatic_assert(kTransportTokenTypeSize <= (1 << kTransportTokenTypeBit), \"\");\n\nclass TransportToken;\n\ntemplate<>\nstruct IsScalarType<TransportToken> final {\n  static const bool value = true;\n};\n\nclass TransportToken final {\n public:\n  TransportToken() : TransportToken(kTransportTokenTypeInvalid, 0) {}\n  TransportToken(const TransportToken&) = default;\n  TransportToken(TransportToken&) = default;\n  ~TransportToken() = default;\n\n  static Maybe<TransportToken> NewTransportToken(TransportTokenType type);\n\n  static constexpr size_t MaxNumberOfThreadGlobalUId() {\n    return (1 << kTransportTokenThreadGlobalIdBit);\n  }\n\n  Maybe<void> CheckThreadGlobalId() const;\n  bool operator==(const TransportToken& other) const {\n    return static_cast<uint64_t>(*this) == static_cast<uint64_t>(other);\n  }\n\n  // Getters\n  TransportTokenType type() const { return static_cast<TransportTokenType>(type_); }\n  int thread_global_id() const { return thread_global_id_; }\n  int32_t seq_id() const { return seq_id_; }\n\n  // Setters\n  Maybe<void> set_src_rank(int64_t val);\n  Maybe<void> set_dst_rank(int64_t val);\n\n  operator uint64_t() const { return *reinterpret_cast<const uint64_t*>(this); }\n\n  TransportToken& operator++() {\n    ++seq_id_;\n    return *this;\n  }\n\n private:\n  TransportToken(TransportTokenType type, uint8_t thread_global_id)\n      : src_rank_(0),\n        dst_rank_(0),\n        type_(static_cast<uint8_t>(type)),\n        thread_global_id_(thread_global_id),\n        seq_id_(0) {}\n\n  uint16_t src_rank_;\n  uint16_t dst_rank_;\n  uint8_t type_ : kTransportTokenTypeBit;  // TransportTokenType\n  uint8_t thread_global_id_ : kTransportTokenThreadGlobalIdBit;\n  uint32_t seq_id_ : (32 - kTransportTokenTypeBit - kTransportTokenThreadGlobalIdBit);\n};\nstatic_assert(sizeof(TransportToken) == sizeof(uint64_t), \"\");\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::TransportToken> {\n  size_t operator()(const oneflow::TransportToken& token) 
const {\n    return std::hash<uint64_t>()(static_cast<uint64_t>(token));\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_RPC_TOKEN_H_\n"
  },
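For orientation, TransportToken packs five fields into exactly 64 bits (16-bit src_rank, 16-bit dst_rank, 5-bit type, 3-bit thread_global_id, 24-bit seq_id); both the static_assert and the byte reinterpretation in operator uint64_t() depend on that layout. A standalone sketch of the same layout (not OneFlow code; sizes as laid out by Itanium-ABI compilers such as GCC/Clang):

    #include <cstdint>
    #include <cstdio>

    struct PackedToken {  // mirrors TransportToken's field order
      uint16_t src_rank;
      uint16_t dst_rank;
      uint8_t type : 5;              // room for 32 token types; cf. kTransportTokenTypeBit
      uint8_t thread_global_id : 3;  // cf. MaxNumberOfThreadGlobalUId() == 8
      uint32_t seq_id : 24;          // wraps after 2^24 applications of operator++
    };
    static_assert(sizeof(PackedToken) == sizeof(uint64_t), "must stay exactly 8 bytes");

    int main() {
      PackedToken t{};
      t.type = 1;                 // e.g. kTransportTokenTypeData
      t.seq_id = (1u << 24) - 1;  // largest representable sequence id
      std::printf("%zu bytes, seq_id=%u\n", sizeof(t), static_cast<unsigned>(t.seq_id));
      return 0;
    }

The uint64_t view is consequently endianness-dependent, which is acceptable here because the value is only ever used as an opaque key (hashing, equality, matching send/recv pairs within one cluster).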
  {
    "path": "oneflow/core/framework/transport_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <memory>\n#include <chrono>\n#include \"oneflow/core/framework/transport_token.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/transport/transport.h\"\n#include \"oneflow/core/thread/thread_global_id.h\"\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/spin_counter.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<Maybe<void> (*SendOrRecv)(const TransportToken&, int64_t, void*, std::size_t,\n                                   const std::function<void()>&),\n         Maybe<void> (AsyncTransportCtx::*Prepare)(int64_t, void**, std::size_t*,\n                                                   std::function<void()>*),\n         typename ForEachRankT>\nMaybe<void> AccessToOtherRanks(const ForEachRankT& ForEachRank, const TransportToken& token,\n                               AsyncTransportCtx* ctx) {\n  auto* blocking_counter = ctx->mut_blocking_counter();\n  JUST(ForEachRank([&, blocking_counter](int64_t rank) -> Maybe<void> {\n    if (rank == GlobalProcessCtx::Rank()) { return Maybe<void>::Ok(); }\n    blocking_counter->Increase();\n    void* buffer = nullptr;\n    std::size_t size = 0;\n    std::function<void()> Callback;\n    JUST((ctx->*Prepare)(rank, &buffer, &size, &Callback));\n    JUST(SendOrRecv(token, rank, buffer, size, [blocking_counter, Callback]() {\n      Callback();\n      blocking_counter->Decrease();\n    }));\n    return Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\ntemplate<Maybe<void> (*SendOrRecv)(const TransportToken&, int64_t, void*, std::size_t,\n                                   const std::function<void()>&),\n         Maybe<void> (AsyncTransportCtx::*Prepare)(int64_t, void**, std::size_t*,\n                                                   std::function<void()>*)>\nMaybe<void> AccessToAllOtherRanks(Symbol<RankGroup> rank_group, const TransportToken& token,\n                                  AsyncTransportCtx* ctx) {\n  const auto& ForEachRank = [&](const std::function<Maybe<void>(int64_t)>& DoEach) -> Maybe<void> {\n    return rank_group->ForEachRank(DoEach);\n  };\n  return AccessToOtherRanks<SendOrRecv, Prepare>(ForEachRank, token, ctx);\n}\n\ntemplate<Maybe<int64_t> (RankGroup::*GetPrevOrNext)() const,\n         Maybe<void> (*SendOrRecv)(const TransportToken&, int64_t, void*, std::size_t,\n                                   const std::function<void()>&),\n         Maybe<void> (AsyncTransportCtx::*Prepare)(int64_t, void**, std::size_t*,\n                                                   std::function<void()>*)>\nMaybe<void> AccessToNearbyRank(Symbol<RankGroup> rank_group, const TransportToken& token,\n                               AsyncTransportCtx* ctx) {\n  
CHECK_OR_RETURN(rank_group->ContainingCurrentRank());\n  const auto& ForEachRank = [&](const std::function<Maybe<void>(int64_t)>& DoEach) -> Maybe<void> {\n    return DoEach(JUST(((*rank_group).*GetPrevOrNext)()));\n  };\n  return AccessToOtherRanks<SendOrRecv, Prepare>(ForEachRank, token, ctx);\n}\n\nnamespace {\n\nMaybe<std::shared_ptr<TransportToken>> RawGetTransportToken(const TransportToken& token) {\n  CHECK_EQ_OR_RETURN(token.seq_id(), 0);\n  JUST(token.CheckThreadGlobalId());\n  auto auto_token = std::make_shared<TransportToken>(token);\n  return auto_token;\n}\n\nstatic constexpr auto* GetTransportToken = DECORATE(&RawGetTransportToken, ThreadLocal);\n\nMaybe<TransportToken> GetAutoIncrementalTransportToken(int64_t src_rank, int64_t dst_rank,\n                                                       TransportToken token) {\n  CHECK_EQ_OR_RETURN(token.seq_id(), 0);\n  JUST(token.set_src_rank(src_rank));\n  JUST(token.set_dst_rank(dst_rank));\n  return ++**JUST(GetTransportToken(token));\n}\n\n}  // namespace\n\nMaybe<void> Send(const TransportToken& token, int64_t rank, void* buffer, std::size_t size,\n                 const std::function<void()>& Callback) {\n#ifdef __linux__\n  int64_t src_rank = GlobalProcessCtx::Rank();\n  int64_t dst_rank = rank;\n  TransportToken send_token = JUST(GetAutoIncrementalTransportToken(src_rank, dst_rank, token));\n  auto* transport = JUST(SingletonMaybe<Transport>());\n  transport->Send(static_cast<uint64_t>(send_token), rank, buffer, size, Callback);\n  return Maybe<void>::Ok();\n#else\n  UNIMPLEMENTED();\n  return Maybe<void>::Ok();\n#endif  // __linux__\n}\n\nMaybe<void> Recv(const TransportToken& token, int64_t rank, void* buffer, std::size_t size,\n                 const std::function<void()>& Callback) {\n#ifdef __linux__\n  int64_t src_rank = rank;\n  int64_t dst_rank = GlobalProcessCtx::Rank();\n  TransportToken recv_token = JUST(GetAutoIncrementalTransportToken(src_rank, dst_rank, token));\n  auto* transport = JUST(SingletonMaybe<Transport>());\n  transport->Receive(static_cast<uint64_t>(recv_token), rank, buffer, size, Callback);\n  return Maybe<void>::Ok();\n#else\n  UNIMPLEMENTED();\n  return Maybe<void>::Ok();\n#endif  // __linux__\n}\n\n}  // namespace\n\n/*static*/ Maybe<void> TransportUtil::BroadcastToAllOtherRanks(Symbol<RankGroup> rank_group,\n                                                               const TransportToken& token,\n                                                               AsyncTransportCtx* ctx) {\n  CHECK_OR_RETURN(rank_group->ContainingCurrentRank());\n  JUST(AccessToAllOtherRanks<&Send, &AsyncTransportCtx::PrepareSendBufferAndCallback>(rank_group,\n                                                                                      token, ctx));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> TransportUtil::CollectFromAllOtherRanks(Symbol<RankGroup> rank_group,\n                                                               const TransportToken& token,\n                                                               AsyncTransportCtx* ctx) {\n  CHECK_OR_RETURN(rank_group->ContainingCurrentRank());\n  JUST(AccessToAllOtherRanks<&Recv, &AsyncTransportCtx::PrepareRecvBufferAndCallback>(rank_group,\n                                                                                      token, ctx));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> TransportUtil::BroadcastToOtherRanks(Symbol<RankGroup> src_rank_group,\n                                                            
Symbol<RankGroup> dst_rank_group,\n                                                            const TransportToken& token,\n                                                            AsyncTransportCtx* ctx) {\n  if (src_rank_group->ContainingCurrentRank()) {\n    JUST(AccessToAllOtherRanks<&Send, &AsyncTransportCtx::PrepareSendBufferAndCallback>(\n        dst_rank_group, token, ctx));\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> TransportUtil::CollectFromOtherRanks(Symbol<RankGroup> src_rank_group,\n                                                            Symbol<RankGroup> dst_rank_group,\n                                                            const TransportToken& token,\n                                                            AsyncTransportCtx* ctx) {\n  if (dst_rank_group->ContainingCurrentRank()) {\n    JUST(AccessToAllOtherRanks<&Recv, &AsyncTransportCtx::PrepareRecvBufferAndCallback>(\n        src_rank_group, token, ctx));\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> TransportUtil::SendToNextRankInRing(Symbol<RankGroup> rank_group,\n                                                           const TransportToken& token,\n                                                           AsyncTransportCtx* ctx) {\n  JUST(\n      AccessToNearbyRank<&RankGroup::GetNextRankInRing, &Send,\n                         &AsyncTransportCtx::PrepareSendBufferAndCallback>(rank_group, token, ctx));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> TransportUtil::ReceiveFromPrevRankInRing(Symbol<RankGroup> rank_group,\n                                                                const TransportToken& token,\n                                                                AsyncTransportCtx* ctx) {\n  JUST(\n      AccessToNearbyRank<&RankGroup::GetPrevRankInRing, &Recv,\n                         &AsyncTransportCtx::PrepareRecvBufferAndCallback>(rank_group, token, ctx));\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<int64_t> GetCurrentRankIndex(const std::vector<int64_t>& rank_heap) {\n  for (int i = 0; i < rank_heap.size(); ++i) {\n    if (rank_heap.at(i) == GlobalProcessCtx::Rank()) { return i; }\n  }\n  UNIMPLEMENTED_THEN_RETURN();\n}\n\n}  // namespace\n\n/*static*/ Maybe<void> TransportUtil::SendDataToChildrenInHeap(\n    const std::vector<int64_t>& rank_heap, const TransportToken& token, AsyncTransportCtx* ctx) {\n  int64_t current_rank_index = JUST(GetCurrentRankIndex(rank_heap));\n  const auto& ForEachRank = [&](const std::function<Maybe<void>(int64_t)>& DoEach) -> Maybe<void> {\n    int64_t left_index = current_rank_index * 2 + 1;\n    if (left_index < rank_heap.size()) { JUST(DoEach(rank_heap.at(left_index))); }\n    int64_t right_index = current_rank_index * 2 + 2;\n    if (right_index < rank_heap.size()) { JUST(DoEach(rank_heap.at(right_index))); }\n    return Maybe<void>::Ok();\n  };\n  return AccessToOtherRanks<&Send, &AsyncTransportCtx::PrepareSendBufferAndCallback>(ForEachRank,\n                                                                                     token, ctx);\n}\n\n/*static*/ Maybe<void> TransportUtil::ReceiveDataFromParentInHeap(\n    const std::vector<int64_t>& rank_heap, const TransportToken& token, AsyncTransportCtx* ctx) {\n  int64_t current_rank_index = JUST(GetCurrentRankIndex(rank_heap));\n  const auto& ForEachRank = [&](const std::function<Maybe<void>(int64_t)>& DoEach) -> Maybe<void> {\n    if (current_rank_index == 0) { return Maybe<void>::Ok(); }\n    return DoEach(rank_heap.at((current_rank_index 
- 1) / 2));\n  };\n  return AccessToOtherRanks<&Recv, &AsyncTransportCtx::PrepareRecvBufferAndCallback>(ForEachRank,\n                                                                                     token, ctx);\n}\n\n/*static*/ Maybe<void> TransportUtil::ReceiveDataFromRank(int64_t rank, const TransportToken& token,\n                                                          AsyncTransportCtx* ctx) {\n  const auto& ForEachRank = [&](const std::function<Maybe<void>(int64_t)>& DoEach) -> Maybe<void> {\n    return DoEach(rank);\n  };\n  return AccessToOtherRanks<&Recv, &AsyncTransportCtx::PrepareRecvBufferAndCallback>(ForEachRank,\n                                                                                     token, ctx);\n}\n\n/*static*/ Maybe<void> TransportUtil::SendDataToRank(int64_t rank, const TransportToken& token,\n                                                     AsyncTransportCtx* ctx) {\n  const auto& ForEachRank = [&](const std::function<Maybe<void>(int64_t)>& DoEach) -> Maybe<void> {\n    return DoEach(rank);\n  };\n  return AccessToOtherRanks<&Send, &AsyncTransportCtx::PrepareSendBufferAndCallback>(ForEachRank,\n                                                                                     token, ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/transport_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_RPC_UTIL_H_\n#define ONEFLOW_CORE_FRAMEWORK_RPC_UTIL_H_\n\n#include <atomic>\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/blocking_counter.h\"\n#include \"oneflow/core/framework/transport_token.h\"\n\nnamespace oneflow {\n\nclass AsyncTransportCtx {\n public:\n  explicit AsyncTransportCtx(const TransportToken& transport_token)\n      : transport_token_(transport_token), blocking_counter_(1) {}\n  virtual ~AsyncTransportCtx() = default;\n\n  const TransportToken& transport_token() const { return transport_token_; }\n  BlockingCounter* mut_blocking_counter() { return &blocking_counter_; }\n\n  Maybe<void> WaitDone() {\n    mut_blocking_counter()->Decrease();\n    return mut_blocking_counter()->WaitUntilCntEqualZero([]() -> Maybe<bool> { return true; });\n  }\n\n  virtual Maybe<void> PrepareSendBufferAndCallback(int64_t rank, void** buffer, std::size_t* size,\n                                                   std::function<void()>* Callback) = 0;\n\n  virtual Maybe<void> PrepareRecvBufferAndCallback(int64_t rank, void** buffer, std::size_t* size,\n                                                   std::function<void()>* Callback) = 0;\n\n private:\n  TransportToken transport_token_;\n  BlockingCounter blocking_counter_;\n};\n\nclass NaiveAsyncTransportCtx final : public AsyncTransportCtx {\n public:\n  NaiveAsyncTransportCtx(\n      const TransportToken& transport_token,\n      const std::function<Maybe<void>(void**, std::size_t*, std::function<void()>*)>& PrepareSend,\n      const std::function<Maybe<void>(void**, std::size_t*, std::function<void()>*)>& PrepareRecv)\n      : AsyncTransportCtx(transport_token),\n        prepare_send_(PrepareSend),\n        prepare_recv_(PrepareRecv) {}\n\n  NaiveAsyncTransportCtx(\n      const TransportToken& transport_token,\n      const std::function<Maybe<void>(void**, std::size_t*, std::function<void()>*)>& PrepareSend,\n      const std::function<Maybe<void>(int64_t, void**, std::size_t*, std::function<void()>*)>&\n          PrepareRecvWithRank)\n      : AsyncTransportCtx(transport_token),\n        prepare_send_(PrepareSend),\n        prepare_recv_with_rank_(PrepareRecvWithRank) {}\n\n  NaiveAsyncTransportCtx(\n      const TransportToken& transport_token,\n      const std::function<Maybe<void>(int64_t, void**, std::size_t*, std::function<void()>*)>&\n          PrepareSendWithRank,\n      const std::function<Maybe<void>(void**, std::size_t*, std::function<void()>*)>& PrepareRecv)\n      : AsyncTransportCtx(transport_token),\n        prepare_send_with_rank_(PrepareSendWithRank),\n        prepare_recv_(PrepareRecv) {}\n\n  NaiveAsyncTransportCtx(\n      const TransportToken& transport_token,\n      const std::function<Maybe<void>(int64_t, void**, std::size_t*, std::function<void()>*)>&\n          PrepareSendWithRank,\n      const 
std::function<Maybe<void>(int64_t, void**, std::size_t*, std::function<void()>*)>&\n          PrepareRecvWithRank)\n      : AsyncTransportCtx(transport_token),\n        prepare_send_with_rank_(PrepareSendWithRank),\n        prepare_recv_with_rank_(PrepareRecvWithRank) {}\n\n  ~NaiveAsyncTransportCtx() override = default;\n\n  Maybe<void> PrepareSendBufferAndCallback(int64_t rank, void** buffer, std::size_t* size,\n                                           std::function<void()>* Callback) override {\n    if (prepare_send_with_rank_) { return prepare_send_with_rank_(rank, buffer, size, Callback); }\n    return prepare_send_(buffer, size, Callback);\n  }\n\n  Maybe<void> PrepareRecvBufferAndCallback(int64_t rank, void** buffer, std::size_t* size,\n                                           std::function<void()>* Callback) override {\n    if (prepare_recv_with_rank_) { return prepare_recv_with_rank_(rank, buffer, size, Callback); }\n    return prepare_recv_(buffer, size, Callback);\n  }\n\n private:\n  std::function<Maybe<void>(void**, std::size_t*, std::function<void()>*)> prepare_send_;\n  std::function<Maybe<void>(int64_t, void**, std::size_t*, std::function<void()>*)>\n      prepare_send_with_rank_;\n  std::function<Maybe<void>(void**, std::size_t*, std::function<void()>*)> prepare_recv_;\n  std::function<Maybe<void>(int64_t, void**, std::size_t*, std::function<void()>*)>\n      prepare_recv_with_rank_;\n};\n\nclass RankGroup;\n\nstruct TransportUtil final {\n  static Maybe<void> SendToNextRankInRing(Symbol<RankGroup> rank_group, const TransportToken& token,\n                                          AsyncTransportCtx* ctx);\n  static Maybe<void> ReceiveFromPrevRankInRing(Symbol<RankGroup> rank_group,\n                                               const TransportToken& token, AsyncTransportCtx* ctx);\n\n  static Maybe<void> BroadcastToAllOtherRanks(Symbol<RankGroup> rank_group,\n                                              const TransportToken& token, AsyncTransportCtx* ctx);\n\n  static Maybe<void> CollectFromAllOtherRanks(Symbol<RankGroup> rank_group,\n                                              const TransportToken& token, AsyncTransportCtx* ctx);\n\n  static Maybe<void> BroadcastToOtherRanks(Symbol<RankGroup> src_rank_group,\n                                           Symbol<RankGroup> dst_rank_group,\n                                           const TransportToken& token, AsyncTransportCtx* ctx);\n\n  static Maybe<void> CollectFromOtherRanks(Symbol<RankGroup> src_rank_group,\n                                           Symbol<RankGroup> dst_rank_group,\n                                           const TransportToken& token, AsyncTransportCtx* ctx);\n\n  static Maybe<void> SendDataToChildrenInHeap(const std::vector<int64_t>& rank_heap,\n                                              const TransportToken& token, AsyncTransportCtx* ctx);\n  static Maybe<void> ReceiveDataFromParentInHeap(const std::vector<int64_t>& rank_heap,\n                                                 const TransportToken& token,\n                                                 AsyncTransportCtx* ctx);\n  static Maybe<void> ReceiveDataFromRank(int64_t rank, const TransportToken& token,\n                                         AsyncTransportCtx* ctx);\n  static Maybe<void> SendDataToRank(int64_t rank, const TransportToken& token,\n                                    AsyncTransportCtx* ctx);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_RPC_UTIL_H_\n"
  },
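Putting transport_util.h together with the Send/Recv plumbing in transport_util.cpp, a one-shot point-to-point send can be sketched as below. Assumptions: `SendBufferTo` is hypothetical, it runs on the sending rank, the peer mirrors it with TransportUtil::ReceiveDataFromRank into a buffer of the same size, and both sides construct their token identically (same type, same thread global id) so the packed token values match:

    Maybe<void> SendBufferTo(int64_t peer_rank, void* data, std::size_t bytes) {
      TransportToken token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
      NaiveAsyncTransportCtx ctx(
          token,
          /*PrepareSend=*/
          [&](void** buffer, std::size_t* size, std::function<void()>* Callback) -> Maybe<void> {
            *buffer = data;
            *size = bytes;
            *Callback = [] {};  // invoked once the transport layer has consumed the buffer
            return Maybe<void>::Ok();
          },
          /*PrepareRecv=*/
          [](void**, std::size_t*, std::function<void()>*) -> Maybe<void> {
            UNIMPLEMENTED_THEN_RETURN();  // this ctx only ever sends
          });
      JUST(TransportUtil::SendDataToRank(peer_rank, token, &ctx));
      return ctx.WaitDone();  // blocks until the async completion callback has fired
    }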
  {
    "path": "oneflow/core/framework/user_op_attr.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/common/shape.proto\";\nimport \"oneflow/core/common/sequential.proto\";\nimport \"oneflow/core/common/data_type.proto\";\nimport \"oneflow/core/common/device.proto\";\nimport \"oneflow/core/common/memory_format.proto\";\n\nenum AttrType {\n  kAtInt32 = 1;\n  kAtInt64 = 2;\n  kAtBool = 3;\n  kAtFloat = 4;\n  kAtDouble = 5;\n  kAtString = 6;\n  kAtShape = 7;\n  kAtDataType = 8;\n  kAtListInt32 = 9;\n  kAtListInt64 = 10;\n  kAtListFloat = 11;\n  kAtListDataType = 12;\n  kAtListShape = 13;\n  kAtListString = 14;\n  kAtStride = 15;\n  kAtListStride = 16;\n  kAtDevice = 17;\n  kAtComplexDouble = 18;\n  kAtMemoryFormat = 19;\n  kAtBytes = 20;\n}\n\nmessage AttrValue {\n  message ListInt32 {\n    repeated int32 val = 1;\n  }\n  message ListInt64 {\n    repeated int64 val = 1;\n  }\n  message ListFloat {\n    repeated float val = 1;\n  }\n  message ListDataType {\n    repeated DataType val = 1;\n  }\n  message ListShape {\n    repeated ShapeProto val = 1;\n  }\n  message ListStride {\n    repeated Int64ListProto val = 1;\n  }\n  // order and naming convention of the oneof field must be consistent with the enum AttrType\n  message ListString {\n    repeated string val = 1;\n  }\n  message ComplexDouble {\n    required double real = 1;\n    required double imag = 2;\n  }\n  oneof value {\n    int32 at_int32 = 1;\n    int64 at_int64 = 2;\n    bool at_bool = 3;\n    float at_float = 4;\n    double at_double = 5;\n    string at_string = 6;\n    ShapeProto at_shape = 7;\n    DataType at_data_type = 8;\n    ListInt32 at_list_int32 = 9;\n    ListInt64 at_list_int64 = 10;\n    ListFloat at_list_float = 11;\n    ListDataType at_list_data_type = 12;\n    ListShape at_list_shape = 13;\n    ListString at_list_string = 14;\n    Int64ListProto at_stride = 15;\n    ListStride at_list_stride = 16;\n    DeviceProto at_device = 17;\n    ComplexDouble at_complex_double = 18;\n    MemoryFormat at_memory_format = 19;\n    bytes at_bytes = 20;\n  }\n}\n\nmessage AttrDef {\n  required string name = 1;\n  required string description = 2;\n  required AttrValue default_val = 3;\n}\n"
  },
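The oneof comment above is not advisory: AddAttrDefaultValueAndCheckValid in user_op_conf.cpp (below) compares AttrType against AttrValue::value_case() by integer value, so each oneof field number must equal its AttrType enumerator. A small illustration against the generated C++ API:

    AttrValue v;
    v.mutable_at_list_int32()->add_val(0);
    v.mutable_at_list_int32()->add_val(2);
    // The generated oneof case constant and the AttrType enumerator are both 9.
    CHECK_EQ(static_cast<int>(v.value_case()), static_cast<int>(kAtListInt32));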
  {
    "path": "oneflow/core/framework/user_op_conf.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/register/blob_desc.h\"\n#include \"oneflow/core/framework/user_op_def.h\"\n#include \"oneflow/core/framework/attr_value.h\"\n#include \"oneflow/core/framework/attr_value_accessor.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nUserOpConfWrapper::UserOpConfWrapper(std::shared_ptr<const OperatorConf> op_conf)\n    : op_conf_(op_conf) {\n  CHECK(op_conf_);\n  CHECK(op_conf_->has_user_conf());\n  attrs_ = MakeAttrMapFromUserOpConf(op_conf_->user_conf());\n}\n\nUserOpConfWrapper::UserOpConfWrapper(const OperatorConf& op_conf)\n    : UserOpConfWrapper(std::make_shared<OperatorConf>(op_conf)) {}\n\nconst OperatorConf& UserOpConfWrapper::op_conf() const { return *op_conf_; }\n\nconst UserOpConf& UserOpConfWrapper::user_op_conf() const { return op_conf_->user_conf(); }\n\nconst std::string& UserOpConfWrapper::op_name() const { return op_conf_->name(); }\n\nconst std::string& UserOpConfWrapper::op_type_name() const {\n  return op_conf_->user_conf().op_type_name();\n}\n\nconst std::string& UserOpConfWrapper::input(const std::string& arg_name, int32_t index) const {\n  auto it = op_conf_->user_conf().input().find(arg_name);\n  CHECK(it != op_conf_->user_conf().input().end())\n      << \"arg_name: \" << arg_name << \", index: \" << index;\n  CHECK(index >= 0 && index < it->second.s_size());\n  return it->second.s(index);\n}\n\nconst std::string& UserOpConfWrapper::output(const std::string& arg_name, int32_t index) const {\n  auto it = op_conf_->user_conf().output().find(arg_name);\n  CHECK(it != op_conf_->user_conf().output().end())\n      << \"arg_name: \" << arg_name << \", index: \" << index;\n  CHECK(index >= 0 && index < it->second.s_size());\n  return it->second.s(index);\n}\n\nbool UserOpConfWrapper::has_input(const std::string& arg_name, int32_t index) const {\n  return input_size(arg_name) > index;\n}\n\nbool UserOpConfWrapper::has_output(const std::string& arg_name, int32_t index) const {\n  return output_size(arg_name) > index;\n}\n\nint32_t UserOpConfWrapper::input_size(const std::string& arg_name) const {\n  auto it = op_conf_->user_conf().input().find(arg_name);\n  if (it == op_conf_->user_conf().input().end()) { return 0; }\n  return it->second.s_size();\n}\n\nint32_t UserOpConfWrapper::output_size(const std::string& arg_name) const {\n  auto it = op_conf_->user_conf().output().find(arg_name);\n  if (it == op_conf_->user_conf().output().end()) { return 0; }\n  return it->second.s_size();\n}\n\nconst std::shared_ptr<const AttrVal>& UserOpConfWrapper::Attr4Name(\n    const std::string& attr_name) const {\n  const auto& attr = attrs_.Attr4Name(attr_name);\n  CHECK(attr.get() != nullptr) << \"attr_name: \" << attr_name;\n  return attr;\n}\n\n#define OP_WRAPPER_ATTR_MEMBER_FUNC(field, 
cpp_type, attr_type)                                    \\\n  template<>                                                                                       \\\n  UserOpConfWrapperBuilder& UserOpConfWrapperBuilder::Attr<cpp_type>(const std::string& attr_name, \\\n                                                                     const cpp_type& val) {        \\\n    AttrValue attr_val;                                                                            \\\n    AttrValueAccessor<cpp_type>::Attr(val, &attr_val);                                             \\\n    attr_.emplace(attr_name, attr_val);                                                            \\\n    return *this;                                                                                  \\\n  }\n\nOF_PP_FOR_EACH_TUPLE(OP_WRAPPER_ATTR_MEMBER_FUNC, ATTR_SEQ)\n\n#undef OP_WRAPPER_ATTR_MEMBER_FUNC\n\nUserOpWrapper::UserOpWrapper(\n    const OperatorConf& op,\n    const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4BnInOp,\n    const std::function<LogicalBlobId*(const std::string&)>& DiffLbi4BnInOp)\n    : conf_(op), diff_fn_(DiffLbi4BnInOp) {\n  auto InitTensorDescFromOpArgs = [&](const PbMap<std::string, UserOpConf_ListString>& args) {\n    for (const auto& pair : args) {\n      for (int32_t i = 0; i < pair.second.s_size(); ++i) {\n        std::string bn = GenRepeatedBn(pair.first, i);\n        const BlobDesc& blob_desc = LogicalBlobDesc4BnInOp(bn);\n        BlobDescProto proto;\n        blob_desc.ToProto(&proto);\n        NaiveTensorDesc tensor_desc(proto);\n        CHECK(bn2tensor_desc_.emplace(bn, tensor_desc).second);\n      }\n    }\n  };\n  InitTensorDescFromOpArgs(op.user_conf().input());\n  InitTensorDescFromOpArgs(op.user_conf().output());\n}\n\nconst TensorDesc& UserOpWrapper::arg_tensor_desc(const std::string& arg_name, int32_t index) const {\n  std::string bn = GenRepeatedBn(arg_name, index);\n  CHECK(bn2tensor_desc_.find(bn) != bn2tensor_desc_.end());\n  return bn2tensor_desc_.at(bn);\n}\n\nconst TensorDesc& UserOpWrapper::TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                            int32_t index) const {\n  return arg_tensor_desc(arg_name, index);\n}\n\nUserOpConfWrapperBuilder& UserOpConfWrapperBuilder::InputBind(\n    const std::string& arg_name, const std::string& logical_blob_name) {\n  if (input_.find(arg_name) == input_.end()) { input_order_.emplace_back(arg_name); }\n  input_[arg_name].emplace_back(logical_blob_name);\n  CHECK_EQ(input_.size(), input_order_.size());\n  return *this;\n}\n\nUserOpConfWrapperBuilder& UserOpConfWrapperBuilder::Input(const std::string& arg_name,\n                                                          const std::string& logical_blob_name) {\n  return InputBind(arg_name, logical_blob_name);\n}\n\nUserOpConfWrapperBuilder& UserOpConfWrapperBuilder::Output(const std::string& arg_name) {\n  return Output(arg_name, 1);\n}\n\nUserOpConfWrapperBuilder& UserOpConfWrapperBuilder::Output(const std::string& arg_name,\n                                                           int32_t num) {\n  CHECK(num >= 0);\n  if (output_.find(arg_name) == output_.end()) { output_order_.emplace_back(arg_name); }\n  output_[arg_name].resize(num);\n  for (int32_t i = 0; i < num; ++i) {\n    std::string bn = GenRepeatedBn(arg_name, i);\n    output_[arg_name].at(i) = GenLogicalBlobName(op_name_, bn);\n  }\n  CHECK_EQ(output_.size(), output_order_.size());\n  return 
*this;\n}\n\nUserOpConfWrapperBuilder& UserOpConfWrapperBuilder::ScopeSymbolId(int64_t scope_symbol_id) {\n  scope_symbol_id_.set_value(scope_symbol_id);\n  return *this;\n}\n\nUserOpConfWrapperBuilder& UserOpConfWrapperBuilder::DeviceTag(const std::string& device_tag) {\n  device_tag_ = device_tag;\n  return *this;\n}\n\nUserOpConfWrapper UserOpConfWrapperBuilder::Build() {\n  OperatorConf op_conf;\n  op_conf.set_name(op_name_);\n  if (!device_tag_.empty()) { op_conf.set_device_tag(device_tag_); }\n  if (scope_symbol_id_.has_value()) { op_conf.set_scope_symbol_id(scope_symbol_id_.value()); }\n  UserOpConf* user_conf = op_conf.mutable_user_conf();\n  user_conf->set_op_type_name(op_type_name_);\n  auto GenArgs = [&](const HashMap<std::string, std::vector<std::string>>& src,\n                     PbMap<std::string, UserOpConf_ListString>* arg_name2lbns) {\n    for (const auto& pair : src) {\n      *(*arg_name2lbns)[pair.first].mutable_s() = StdVec2PbRpf<std::string>(pair.second);\n    }\n  };\n  GenArgs(input_, user_conf->mutable_input());\n  GenArgs(output_, user_conf->mutable_output());\n  for (const auto& arg_name : input_order_) { user_conf->add_input_order(arg_name); }\n  for (const auto& arg_name : output_order_) { user_conf->add_output_order(arg_name); }\n  for (const auto& pair : attr_) { (*user_conf->mutable_attr())[pair.first] = pair.second; }\n  wrapper_ = UserOpConfWrapper(*CHECK_JUST(CheckAndCompleteUserOpConfImpl(op_conf)));\n  return wrapper_;\n}\n\n}  // namespace user_op\n\nMaybe<void> CheckArgDefIsValidInUserOpConf(\n    const OperatorConf& op_conf, const PbMap<std::string, UserOpConf_ListString>& arg_name2lbns,\n    const PbRpf<UserOpDef_ArgDef>& args) {\n  const std::string& op_name = op_conf.name();\n  const std::string& op_type_name = op_conf.user_conf().op_type_name();\n  HashSet<std::string> op_def_arg_names;\n  for (const auto& arg : args) {\n    int32_t arg_blob_num = 0;\n    if (arg_name2lbns.find(arg.name()) != arg_name2lbns.end()) {\n      arg_blob_num = arg_name2lbns.at(arg.name()).s_size();\n    }\n    if (arg_blob_num == 0) {\n      CHECK_OR_RETURN(arg.is_optional())\n          << \" op_name: \" << op_name << \" op_type_name: \" << op_type_name\n          << \" arg name: \" << arg.name() << \" in OpDef must have blob in op_conf: \\n\"\n          << op_conf.DebugString();\n    }\n    op_def_arg_names.insert(arg.name());\n  }\n  for (const auto& pair : arg_name2lbns) {\n    CHECK_OR_RETURN(op_def_arg_names.find(pair.first) != op_def_arg_names.end())\n        << \" op_name: \" << op_name << \" op_type_name: \" << op_type_name\n        << \" has not arg name: \" << pair.first << \" in OpDef\";\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckUserOpConfArgOrderValid(\n    const OperatorConf& op_conf, const PbMap<std::string, UserOpConf_ListString>& arg_name2lbns,\n    const PbRpf<std::string>& arg_order) {\n  CHECK_EQ_OR_RETURN(arg_name2lbns.size(), arg_order.size())\n      << \" op_conf: \" << op_conf.DebugString() << \" io order is not valid.\";\n  HashSet<std::string> arg_names;\n  for (const std::string& arg_name : arg_order) {\n    CHECK_OR_RETURN(arg_names.insert(arg_name).second)\n        << \" op_conf: \" << op_conf.DebugString() << \" io order is not valid.\";\n    CHECK_OR_RETURN(arg_name2lbns.find(arg_name) != arg_name2lbns.end())\n        << \" op_conf: \" << op_conf.DebugString() << \" io order is not valid.\";\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AddAttrDefaultValueAndCheckValid(const UserOpDef& op_def, UserOpConf* user_conf,\n  
                                           const std::string& error_msg_prefix) {\n  auto* attr_name2attr = user_conf->mutable_attr();\n  HashSet<std::string> op_def_attr_names;\n  for (const auto& attr : op_def.attr()) {\n    if (attr_name2attr->find(attr.name()) == attr_name2attr->end()) {\n      CHECK_OR_RETURN(attr.has_default_val())\n          << error_msg_prefix << \" op_type_name: \" << user_conf->op_type_name()\n          << \" must set attr val for attr_name: \" << attr.name();\n      (*attr_name2attr)[attr.name()] = attr.default_val();\n    }\n    op_def_attr_names.insert(attr.name());\n  }\n  for (const auto& pair : user_conf->attr()) {\n    CHECK_OR_RETURN(op_def_attr_names.find(pair.first) != op_def_attr_names.end())\n        << error_msg_prefix << \" op_type_name: \" << user_conf->op_type_name()\n        << \" has not attr_name: \" << pair.first << \" in OpDef\";\n  }\n  for (const auto& attr : op_def.attr()) {\n    CHECK_OR_RETURN(static_cast<int32_t>(attr.type())\n                    == static_cast<int32_t>(attr_name2attr->at(attr.name()).value_case()))\n        << error_msg_prefix << \" op_type_name: \" << user_conf->op_type_name()\n        << \" attr_name: \" << attr.name()\n        << \" has different attr type in OpDef and OpConf, it should be with type: \"\n        << AttrType_Name(attr.type());\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AddAttrDefaultValueAndCheckValid(UserOpConf* user_conf) {\n  const user_op::OpRegistryResult* val =\n      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(user_conf->op_type_name());\n  CHECK_OR_RETURN(val) << \" Cannot find op_type_name: \" << user_conf->op_type_name();\n  const UserOpDef& op_def = val->op_def;\n  return AddAttrDefaultValueAndCheckValid(op_def, user_conf, \"\");\n}\n\nMaybe<void> AddAttrDefaultValueAndCheckValid(const UserOpDef& op_def, OperatorConf* op_conf) {\n  UserOpConf* user_conf = op_conf->mutable_user_conf();\n  std::string error_msg_prefix = \" op_name: \" + op_conf->name();\n  return AddAttrDefaultValueAndCheckValid(op_def, user_conf, error_msg_prefix);\n}\n\nMaybe<long long> GetAttrTypeImpl(const std::string& op_type_name, const std::string& attr_name) {\n  const user_op::OpRegistryResult* val =\n      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type_name);\n  CHECK_OR_RETURN(val) << \" Cannot find op \" << op_type_name;\n  const UserOpDef& op_def = val->op_def;\n  for (int32_t i = 0; i < op_def.attr_size(); ++i) {\n    if (op_def.attr(i).name() == attr_name) { return op_def.attr(i).type(); }\n  }\n  CHECK_OR_RETURN(false) << \" Cannot find attr \" << attr_name << \" in op \" << op_type_name;\n}\n\nMaybe<OperatorConf> CheckAndCompleteUserOpConfImpl(const OperatorConf& op_conf) {\n  CHECK_OR_RETURN(op_conf.has_user_conf()) << \" Add default value only for user op\";\n  OperatorConf ret = op_conf;\n  UserOpConf* user_conf = ret.mutable_user_conf();\n  const user_op::OpRegistryResult* val =\n      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(user_conf->op_type_name());\n  CHECK_OR_RETURN(val) << \" Cannot find op_type_name: \" << user_conf->op_type_name();\n  const UserOpDef& op_def = val->op_def;\n\n  JUST(AddAttrDefaultValueAndCheckValid(op_def, &ret));\n  // check input and output valid\n  JUST(CheckArgDefIsValidInUserOpConf(op_conf, user_conf->input(), op_def.input()));\n  JUST(CheckArgDefIsValidInUserOpConf(op_conf, user_conf->output(), op_def.output()));\n  JUST(CheckUserOpConfArgOrderValid(op_conf, user_conf->input(), user_conf->input_order()));\n  
JUST(CheckUserOpConfArgOrderValid(op_conf, user_conf->output(), user_conf->output_order()));\n  // check attr valid by user\n  JUST(val->check_fn(user_op::UserOpDefWrapper(op_def), user_op::UserOpConfWrapper(ret)));\n  return ret;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/user_op_conf.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_USER_OP_CONF_H_\n#define ONEFLOW_CORE_FRAMEWORK_USER_OP_CONF_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/tensor_desc.h\"\n#include \"oneflow/core/framework/user_op_def.pb.h\"\n#include \"oneflow/core/framework/user_op_attr.pb.h\"\n#include \"oneflow/core/framework/user_op_conf.pb.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n\nnamespace oneflow {\n\nclass BlobDesc;\n\nnamespace user_op {\n\nclass OpArg final {\n public:\n  OpArg(std::string&& name, int32_t index) : name_(std::move(name)), index_(index) {}\n\n  const std::string& name() const { return name_; }\n  int32_t index() const { return index_; }\n\n private:\n  std::string name_;\n  int32_t index_;\n};\n\nclass AttrVal;\n\nclass UserOpConfWrapper final {\n public:\n  UserOpConfWrapper(const OperatorConf&);\n  UserOpConfWrapper(std::shared_ptr<const OperatorConf> op_conf);\n  const OperatorConf& op_conf() const;\n  const UserOpConf& user_op_conf() const;\n  const std::string& op_name() const;\n  const std::string& op_type_name() const;\n  const std::string& input(const std::string& arg_name, int32_t index) const;\n  const std::string& output(const std::string& arg_name, int32_t index) const;\n  bool has_input(const std::string& arg_name, int32_t index) const;\n  bool has_output(const std::string& arg_name, int32_t index) const;\n  int32_t input_size(const std::string& arg_name) const;\n  int32_t output_size(const std::string& arg_name) const;\n\n  template<typename T>\n  const T& attr(const std::string& attr_name) const {\n    return CHECK_JUST(attrs_.GetAttr<T>(attr_name));\n  }\n\n  template<typename T>\n  const T& attr_or_default(const std::string& attr_name, const T& default_val) const {\n    if (attrs_.Has(attr_name)) {\n      return CHECK_JUST(attrs_.GetAttr<T>(attr_name));\n    } else {\n      return default_val;\n    }\n  }\n\n  const std::shared_ptr<const AttrVal>& Attr4Name(const std::string& attr_name) const;\n\n private:\n  UserOpConfWrapper() = default;\n  friend class UserOpConfWrapperBuilder;\n\n  std::shared_ptr<const OperatorConf> op_conf_;\n  AttrMap attrs_;\n};\n\nclass UserOpWrapper final {\n public:\n  UserOpWrapper(const OperatorConf& op, const std::function<const BlobDesc&(const std::string&)>&,\n                const std::function<LogicalBlobId*(const std::string&)>&);\n\n public:\n  const UserOpConfWrapper& user_op_conf() const { return conf_; }\n  const OperatorConf& op_conf() const { return conf_.op_conf(); }\n  const std::string& op_name() const { return conf_.op_name(); }\n  const std::string& op_type_name() const { return conf_.op_type_name(); }\n\n  int32_t input_size(const std::string& arg_name) const { return conf_.input_size(arg_name); }\n  const std::string& input(const std::string& arg_name, int32_t index) const {\n    return 
conf_.input(arg_name, index);\n  }\n\n  int32_t output_size(const std::string& arg_name) const { return conf_.output_size(arg_name); }\n  const std::string& output(const std::string& arg_name, int32_t index) const {\n    return conf_.output(arg_name, index);\n  }\n\n  template<typename T>\n  T attr(const std::string& attr_name) const {\n    return conf_.attr<T>(attr_name);\n  }\n\n  template<typename T>\n  T attr_or_default(const std::string& attr_name, const T& default_val) const {\n    return conf_.attr_or_default<T>(attr_name, default_val);\n  }\n\n  const TensorDesc& arg_tensor_desc(const std::string& arg_name, int32_t index) const;\n  const TensorDesc& TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const;\n\n private:\n  UserOpConfWrapper conf_;\n  std::function<LogicalBlobId*(const std::string&)> diff_fn_;\n  HashMap<std::string, NaiveTensorDesc> bn2tensor_desc_;\n};\n\nclass UserOpConfWrapperBuilder final {\n public:\n  UserOpConfWrapperBuilder(const std::string& op_name) : op_name_(op_name) {}\n\n  UserOpConfWrapperBuilder& OpTypeName(const std::string& op_type_name) {\n    op_type_name_ = op_type_name;\n    return *this;\n  }\n  UserOpConfWrapperBuilder& Op(const std::string& op_type_name) { return OpTypeName(op_type_name); }\n\n  UserOpConfWrapperBuilder& InputBind(const std::string& arg_name,\n                                      const std::string& logical_blob_name);\n  UserOpConfWrapperBuilder& Input(const std::string& arg_name,\n                                  const std::string& logical_blob_name);\n\n  UserOpConfWrapperBuilder& Output(const std::string& arg_name, int32_t num);\n  UserOpConfWrapperBuilder& Output(const std::string& arg_name);\n\n  template<typename T>\n  UserOpConfWrapperBuilder& Attr(const std::string& attr_name, const T& val);\n\n  UserOpConfWrapperBuilder& ScopeSymbolId(int64_t scope_symbol_id);\n  UserOpConfWrapperBuilder& DeviceTag(const std::string& device_tag);\n\n  UserOpConfWrapper Build();\n\n private:\n  UserOpConfWrapper wrapper_;\n  std::string op_name_;\n  std::string op_type_name_;\n  HashMap<std::string, std::vector<std::string>> input_;\n  HashMap<std::string, std::vector<std::string>> output_;\n  HashMap<std::string, AttrValue> attr_;\n  std::vector<std::string> input_order_;\n  std::vector<std::string> output_order_;\n  OptInt64 scope_symbol_id_;\n  std::string device_tag_;\n};\n\n}  // namespace user_op\n\nMaybe<long long> GetAttrTypeImpl(const std::string& op_type_name, const std::string& attr_name);\nMaybe<OperatorConf> CheckAndCompleteUserOpConfImpl(const OperatorConf& op_conf);\nMaybe<void> AddAttrDefaultValueAndCheckValid(UserOpConf* user_conf);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_USER_OP_CONF_H_\n"
  },
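A builder sketch tying the wrapper classes together. The op type "relu" with input "in" and output "out" is an illustrative assumption; Build() only succeeds for an op_type_name that is actually registered, since it runs CheckAndCompleteUserOpConfImpl (attr defaults filled in, conf validated, CHECK_JUST on failure):

    user_op::UserOpConfWrapper op = user_op::UserOpConfWrapperBuilder("my_relu")
                                        .Op("relu")                     // op_type_name
                                        .Input("in", "producer/out_0")  // bind an upstream logical blob name
                                        .Output("out")                  // shorthand for Output("out", 1)
                                        .Build();
    const std::string& lbn = op.output("out", 0);  // "my_relu/out_0"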
  {
    "path": "oneflow/core/framework/user_op_conf.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/framework/user_op_attr.proto\";\n\nmessage UserOpConf {\n  message ListString {\n    repeated string s = 1;\n  }\n  required string op_type_name = 1;\n  map<string, ListString> input = 2;\n  map<string, ListString> output = 3;\n  map<string, AttrValue> attr = 4;\n  // NOTE(chengcheng): specify the input/output order according to the order called by\n  //   UserOpBuilder.\n  repeated string input_order = 5;\n  repeated string output_order = 6;\n}\n"
  },
  {
    "path": "oneflow/core/framework/user_op_def.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/user_op_def.h\"\n#include \"oneflow/core/framework/attr_value.h\"\n#include \"oneflow/core/framework/attr_value_accessor.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nUserOpDefWrapper::UserOpDefWrapper(const UserOpDef& def)\n    : def_(def), inputs_(), outputs_(), attrs_() {\n  for (int32_t i = 0; i < def_.input_size(); ++i) {\n    inputs_.emplace(def_.input(i).name(), def_.mutable_input(i));\n  }\n  for (int32_t i = 0; i < def_.output_size(); ++i) {\n    outputs_.emplace(def_.output(i).name(), def_.mutable_output(i));\n  }\n  for (int32_t i = 0; i < def_.attr_size(); ++i) {\n    attrs_.emplace(def_.attr(i).name(), def_.mutable_attr(i));\n  }\n}\n\nbool UserOpDefWrapper::IsInputArgName(const std::string& name) const {\n  return inputs_.find(name) != inputs_.end();\n}\n\nbool UserOpDefWrapper::IsOutputArgName(const std::string& name) const {\n  return outputs_.find(name) != outputs_.end();\n}\n\nbool UserOpDefWrapper::IsAttrName(const std::string& name) const {\n  return attrs_.find(name) != attrs_.end();\n}\n\nbool UserOpDefWrapper::IsArgOptional(const std::string& name) const {\n  const UserOpDef::ArgDef* arg_def = GetArgPointer(name);\n  CHECK_NOTNULL(arg_def);\n  return arg_def->is_optional();\n}\n\nconst UserOpDef::ArgDef* UserOpDefWrapper::GetArgPointer(const std::string& name) const {\n  auto it = inputs_.find(name);\n  if (it != inputs_.end()) { return it->second; }\n  it = outputs_.find(name);\n  if (it != outputs_.end()) { return it->second; }\n  return nullptr;\n}\n\nAttrType UserOpDefWrapper::GetAttrType(const std::string& name) const {\n  return attrs_.at(name)->type();\n}\n\nbool UserOpDefWrapper::AttrHasDefaultVal(const std::string& name) const {\n  return attrs_.at(name)->has_default_val();\n}\n\n#define ATTR_TYPE_SPECIALIZATION(field, cpp_type, attr_type)                              \\\n  template<>                                                                              \\\n  cpp_type UserOpDefWrapper::GetAttrDefaultVal<cpp_type>(const std::string& name) const { \\\n    CHECK(AttrHasDefaultVal(name));                                                       \\\n    const AttrValue& default_val = attrs_.at(name)->default_val();                        \\\n    CHECK_EQ(static_cast<int>(attr_type), default_val.value_case());                      \\\n    return AttrValueAccessor<cpp_type>::Attr(default_val);                                \\\n  }\n\nOF_PP_FOR_EACH_TUPLE(ATTR_TYPE_SPECIALIZATION, ATTR_SEQ)\n\n#undef ATTR_TYPE_SPECIALIZATION\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/user_op_def.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_USER_OP_DEF_WRAPPER_H_\n#define ONEFLOW_CORE_FRAMEWORK_USER_OP_DEF_WRAPPER_H_\n\n#include \"oneflow/core/framework/user_op_def.pb.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nclass UserOpDefWrapper final {\n public:\n  UserOpDefWrapper(const UserOpDef&);\n  ~UserOpDefWrapper() = default;\n  UserOpDefWrapper(const UserOpDefWrapper&) = delete;\n  UserOpDefWrapper(UserOpDefWrapper&&) = delete;\n\n  const std::string& name() const { return def_.name(); }\n\n  bool IsInputArgName(const std::string&) const;\n  bool IsOutputArgName(const std::string&) const;\n  bool IsAttrName(const std::string&) const;\n\n  bool IsArgOptional(const std::string&) const;\n\n  AttrType GetAttrType(const std::string&) const;\n  bool AttrHasDefaultVal(const std::string&) const;\n  template<typename T>\n  T GetAttrDefaultVal(const std::string&) const;\n\n private:\n  const UserOpDef::ArgDef* GetArgPointer(const std::string&) const;\n\n  UserOpDef def_;\n  HashMap<std::string, UserOpDef::ArgDef*> inputs_;\n  HashMap<std::string, UserOpDef::ArgDef*> outputs_;\n  HashMap<std::string, UserOpDef::AttrDef*> attrs_;\n};\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_USER_OP_DEF_WRAPPER_H_\n"
  },
  {
    "path": "oneflow/core/framework/user_op_def.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/framework/user_op_attr.proto\";\n\nmessage UserOpDef {\n  required string name = 1;\n\n  message ArgDef {\n    required string name = 1;\n    optional bool is_optional = 2 [default = false];\n  }\n  repeated ArgDef input = 2;\n  repeated ArgDef output = 3;\n\n  message AttrDef {\n    required string name = 1;\n    required AttrType type = 2;\n    optional AttrValue default_val = 3;\n  }\n  repeated AttrDef attr = 4;\n}\n"
  },
  {
    "path": "oneflow/core/framework/user_op_hob.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_USER_OP_HOB_H_\n#define ONEFLOW_CORE_FRAMEWORK_USER_OP_HOB_H_\n\n#include <sstream>\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/common/high_order_bool.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nALWAYS_INLINE inline auto HobTrue() {\n  std::ostringstream string_stream;\n  string_stream << \"\\\" always true \\\"\";\n  return hob::LiteralBool<KernelRegContext>(string_stream.str(), true);\n}\n\nALWAYS_INLINE inline auto HobFalse() {\n  std::ostringstream string_stream;\n  string_stream << \"\\\" always false \\\"\";\n  return hob::LiteralBool<KernelRegContext>(string_stream.str(), false);\n}\n\nALWAYS_INLINE inline auto HobDataType(const std::string& tensor_name, int tensor_idx) {\n  std::ostringstream string_stream;\n  string_stream << \"data_type of tensor \\'\" << tensor_name << \"\\'\";\n  return hob::make_custom(\n      string_stream.str(), [tensor_name, tensor_idx](const KernelRegContext& ctx) -> DataType {\n        const user_op::TensorDesc* desc = ctx.TensorDesc4ArgNameAndIndex(tensor_name, tensor_idx);\n        CHECK(desc != nullptr) << \"key `\" << tensor_name << \"_\" << tensor_idx << \"` not found.\";\n        return desc->data_type();\n      });\n}\n\nALWAYS_INLINE inline auto HobInputSize(const std::string& tensor_name) {\n  std::ostringstream string_stream;\n  string_stream << \"size of input \\'\" << tensor_name << \"\\'\";\n  return hob::make_custom(string_stream.str(),\n                          [tensor_name](const KernelRegContext& ctx) -> int32_t {\n                            return ctx.user_op_conf().input_size(tensor_name);\n                          });\n}\n\ntemplate<typename T>\nALWAYS_INLINE inline auto HobAttr(const std::string& attr_name) {\n  return hob::make_custom(attr_name, [attr_name](const user_op::KernelRegContext& ctx) -> const T& {\n    return ctx.Attr<T>(attr_name);\n  });\n}\n\nALWAYS_INLINE inline auto HobDeviceType() {\n  return hob::make_custom(\n      \"device_type\", [](const KernelRegContext& ctx) -> DeviceType { return ctx.device_type(); });\n}\n\nALWAYS_INLINE inline auto HobDeviceSubTag() {\n  return hob::make_custom(\"device_sub_tag\", [](const KernelRegContext& ctx) -> const std::string& {\n    return ctx.Attr<std::string>(\"device_sub_tag\");\n  });\n}\n\nALWAYS_INLINE inline auto HobEnvBool(const std::string& env_var, bool default_value) {\n  std::ostringstream string_stream;\n  string_stream << \"environment variable \\'\" << env_var << \"\\'\";\n  return hob::make_custom(string_stream.str(),\n                          [env_var, default_value](const KernelRegContext& ctx) -> bool {\n                            return ParseBooleanFromEnv(env_var, default_value);\n                          });\n}\n\n}  
\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_USER_OP_HOB_H_\n"
  },
  {
    "path": "oneflow/core/framework/user_op_kernel_registry.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/user_op_kernel_registry.h\"\n#include \"oneflow/core/framework/user_op_hob.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nOpKernelRegistry& OpKernelRegistry::Name(const std::string& op_type_name) {\n  result_.op_type_name = op_type_name;\n  return *this;\n}\n\nOpKernelRegistry& OpKernelRegistry::SetCreateFn(OpKernelCreateFn fn) {\n  result_.create_fn = std::move(fn);\n  return *this;\n}\n\nOpKernelRegistry& OpKernelRegistry::SetInferTmpSizeFn(InferTmpSizeFn fn) {\n  result_.infer_tmp_size_fn = std::move(fn);\n  return *this;\n}\n\nOpKernelRegistry& OpKernelRegistry::SetInplaceProposalFn(InplaceProposalFn fn) {\n  result_.inplace_proposal_fn = std::move(fn);\n  return *this;\n}\n\nOpKernelRegistry& OpKernelRegistry::SetPriority(int32_t priority) {\n  result_.priority = priority;\n  return *this;\n}\n\nMaybe<OpKernelRegistry&> OpKernelRegistry::Finish() {\n  CHECK_OR_RETURN(result_.create_fn != nullptr)\n      << \"No Create function for \" << result_.op_type_name;\n  result_.need_temp_storage = (result_.infer_tmp_size_fn != nullptr);\n  if (!result_.need_temp_storage) { result_.infer_tmp_size_fn = TmpSizeInferFnUtil::ZeroTmpSize; }\n  if (result_.inplace_proposal_fn == nullptr) {\n    result_.inplace_proposal_fn = [](const InferContext&, AddInplaceArgPair) {\n      return Maybe<void>::Ok();\n    };\n  }\n  if (result_.is_matched_hob == nullptr) {\n    static auto hob_true = std::make_shared<decltype(user_op::HobTrue())>(user_op::HobTrue());\n    result_.is_matched_hob = hob_true;\n  }\n  return *this;\n}\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/user_op_kernel_registry.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_USER_OP_KERNEL_REGISTRY_H_\n#define ONEFLOW_CORE_FRAMEWORK_USER_OP_KERNEL_REGISTRY_H_\n\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/job/placement.pb.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/common/high_order_bool.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nclass OpKernel;\nclass TensorDesc;\nclass InferContext;\n\nclass KernelRegContext {\n public:\n  virtual ~KernelRegContext() = default;\n\n  virtual DeviceType device_type() const = 0;\n  virtual const ParallelContext& parallel_ctx() const = 0;\n  virtual const TensorDesc* TensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0;\n\n  virtual const std::vector<std::pair<std::string, int32_t>>& inputs() const = 0;\n  virtual const std::vector<std::pair<std::string, int32_t>>& outputs() const = 0;\n\n  virtual const UserOpConfWrapper& user_op_conf() const = 0;\n\n  template<typename T>\n  const T& Attr(const std::string& attr_name) const {\n    return AttrValueCast<T>(*Attr4Name(attr_name));\n  }\n\n protected:\n  KernelRegContext() = default;\n  KernelRegContext(const KernelRegContext&) = delete;\n  virtual const std::shared_ptr<const AttrVal>& Attr4Name(const std::string& attr_name) const = 0;\n};\n\nusing OpKernelCreateFn = std::function<const OpKernel*()>;\nusing InferTmpSizeFn = std::function<size_t(InferContext*)>;\nusing AddInplaceArgPair = std::function<Maybe<void>(\n    const std::string& out_arg_name, int32_t out_arg_index, const std::string& in_arg_name,\n    int32_t in_arg_index, bool is_mutable)>;\nusing InplaceProposalFn = std::function<Maybe<void>(const InferContext&, AddInplaceArgPair)>;\nusing IsMatchedHob = std::shared_ptr<hob::BaseExpr<user_op::KernelRegContext, bool>>;\n\nconstexpr int kKernelPriorityFallback = -10;\nconstexpr int kKernelPriorityDefault = 0;\nconstexpr int kKernelPriorityOptimized = 10;\nconstexpr int kKernelPriorityExperimental = 100;\n\nstruct OpKernelRegistryResult {\n  std::string op_type_name;\n\n  OpKernelCreateFn create_fn;\n  bool need_temp_storage;\n  InferTmpSizeFn infer_tmp_size_fn;\n  InplaceProposalFn inplace_proposal_fn;\n  IsMatchedHob is_matched_hob;\n  int32_t priority = kKernelPriorityDefault;\n};\n\nclass OpKernelRegistry final {\n public:\n  OpKernelRegistry& Name(const std::string& op_type_name);\n\n  template<typename T>\n  OpKernelRegistry& SetCreateFn() {\n    return SetCreateFn([]() -> const OpKernel* { return NewOpKernel<T>(); });\n  }\n  template<typename T>\n  OpKernelRegistry& SetIsMatchedHob(const T& hob) {\n    result_.is_matched_hob = std::make_shared<T>(hob);\n    return *this;\n  }\n  OpKernelRegistry& SetInferTmpSizeFn(InferTmpSizeFn fn);\n  OpKernelRegistry& SetInplaceProposalFn(InplaceProposalFn fn);\n\n 
 Maybe<OpKernelRegistry&> Finish();\n  OpKernelRegistryResult GetResult() { return result_; }\n\n  OpKernelRegistry& SetCreateFn(OpKernelCreateFn fn);\n  OpKernelRegistry& SetPriority(int32_t priority);\n\n private:\n  OpKernelRegistryResult result_;\n};\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_USER_OP_KERNEL_REGISTRY_H_\n"
  },
  {
    "path": "oneflow/core/framework/user_op_registry.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/user_op_registry.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/attr_value.h\"\n#include \"oneflow/core/framework/attr_value_accessor.h\"\n#include \"oneflow/core/framework/sbp_context.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\nbool InsertIfNotExists(const std::string& name, HashSet<std::string>* unique_names) {\n  if (unique_names->find(name) != unique_names->end()) { return false; }\n  unique_names->emplace(name);\n  return true;\n}\n\n}  // namespace\n\nOpRegistry& OpRegistry::Name(const std::string& op_type_name) {\n  CHECK(InsertIfNotExists(op_type_name, &unique_names_));\n  result_.op_type_name = op_type_name;\n  return *this;\n}\n\nOpRegistry& OpRegistry::ArgImpl(bool is_input, const std::string& name, bool is_optional) {\n  CHECK(InsertIfNotExists(name, &unique_names_))\n      << \"op arg registered, name: \" << name << \", op: \" << result_.op_type_name;\n  UserOpDef::ArgDef arg_def;\n  {\n    arg_def.set_name(name);\n    arg_def.set_is_optional(is_optional);\n  }\n  if (is_input) {\n    *(result_.op_def.mutable_input()->Add()) = arg_def;\n  } else {\n    *(result_.op_def.mutable_output()->Add()) = arg_def;\n  }\n  return *this;\n}\n\n#define OP_REG_ARG_MEMBER_FUNC(name_prefix, is_input, is_optional) \\\n  OpRegistry& OpRegistry::name_prefix(const std::string& name) {   \\\n    return ArgImpl(is_input, name, is_optional);                   \\\n  }\n\nOP_REG_ARG_MEMBER_FUNC(Input, true, false)\nOP_REG_ARG_MEMBER_FUNC(OptionalInput, true, true)\nOP_REG_ARG_MEMBER_FUNC(Output, false, false)\nOP_REG_ARG_MEMBER_FUNC(OptionalOutput, false, true)\n\n#undef OP_REG_ARG_MEMBER_FUNC\n\nOpRegistry& OpRegistry::SupportCpuOnly() {\n  result_.cpu_only_supported = true;\n  return *this;\n}\n\nOpRegistry& OpRegistry::SupportNonContiguous() {\n  result_.non_contiguous_supported = true;\n  return *this;\n}\n\nOpRegistry& OpRegistry::NoGrad() {\n  result_.no_grad = true;\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetOutputBufferNum(int32_t num) {\n  result_.same_output_regst_num = num;\n  return *this;\n}\n\nOpRegistry& OpRegistry::Attr(const std::string& name, AttrType type) {\n  CHECK(InsertIfNotExists(name, &unique_names_));\n  UserOpDef::AttrDef attr_def;\n  attr_def.set_name(name);\n  attr_def.set_type(type);\n  *(result_.op_def.mutable_attr()->Add()) = attr_def;\n  return *this;\n}\n\nnamespace {\n\nvoid AddAttrWithDefault(OpRegistryResult* result, const std::string& name, AttrType type,\n                        std::function<void(UserOpDef::AttrDef*)> handler) {\n  UserOpDef::AttrDef attr_def;\n  attr_def.set_name(name);\n  attr_def.set_type(type);\n  handler(&attr_def);\n  
*(result->op_def.mutable_attr()->Add()) = std::move(attr_def);\n}\n\n}  // namespace\n\n#define ATTR_MEMBER_FUNC(field, cpp_type, attr_type)                                             \\\n  template<>                                                                                     \\\n  OpRegistry& OpRegistry::Attr<cpp_type>(const std::string& name, AttrType type,                 \\\n                                         const cpp_type& default_val) {                          \\\n    CHECK_EQ(type, attr_type);                                                                   \\\n    return DefaultedAttr(name, type, [default_val](UserOpDef::AttrDef* attr_def) {               \\\n      AttrValueAccessor<cpp_type>::Attr(default_val, attr_def->mutable_default_val());           \\\n    });                                                                                          \\\n  }                                                                                              \\\n  template<>                                                                                     \\\n  OpRegistry& OpRegistry::Attr<cpp_type>(const std::string& name, const cpp_type& default_val) { \\\n    return DefaultedAttr(                                                                        \\\n        name, GetAttrType<cpp_type>::value, [default_val](UserOpDef::AttrDef* attr_def) {        \\\n          AttrValueAccessor<cpp_type>::Attr(default_val, attr_def->mutable_default_val());       \\\n        });                                                                                      \\\n  }                                                                                              \\\n  template<>                                                                                     \\\n  OpRegistry& OpRegistry::Attr<cpp_type>(const std::string& name) {                              \\\n    return Attr<cpp_type>(name, cpp_type());                                                     \\\n  }\n\nOF_PP_FOR_EACH_TUPLE(ATTR_MEMBER_FUNC, ATTR_SEQ)\n\n#undef ATTR_MEMBER_FUNC\n\nOpRegistry& OpRegistry::DefaultedAttr(const std::string& name, AttrType type,\n                                      const std::function<void(UserOpDef::AttrDef*)>& SetDefault) {\n  CHECK(InsertIfNotExists(name, &unique_names_));\n  AddAttrWithDefault(&result_, name, type, SetDefault);\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetTensorDescInferFn(TensorDescInferFn tensor_desc_infer_fn) {\n  SetLogicalTensorDescInferFn(tensor_desc_infer_fn);\n  SetPhysicalTensorDescInferFn(tensor_desc_infer_fn);\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetLogicalTensorDescInferFn(TensorDescInferFn tensor_desc_infer_fn) {\n  result_.logical_tensor_desc_infer_fn = std::move(tensor_desc_infer_fn);\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetPhysicalTensorDescInferFn(TensorDescInferFn tensor_desc_infer_fn) {\n  result_.physical_tensor_desc_infer_fn = std::move(tensor_desc_infer_fn);\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetCheckAttrFn(CheckAttrFn fn) {\n  result_.check_fn = std::move(fn);\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetGetSbpFn(GetSbpFn get_sbp_fn) {\n  result_.get_sbp_fn = std::move(get_sbp_fn);\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetSbpSignatureInferFn(SbpSignatureInferFn sbp_signature_infer_fn) {\n  result_.sbp_signature_infer_fn = std::move(sbp_signature_infer_fn);\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetInputArgModifyFn(InputArgModifyFn input_arg_modify_fn) {\n  
result_.input_arg_modify_fn = std::move(input_arg_modify_fn);\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetOutputArgModifyFn(OutputArgModifyFn output_arg_modify_fn) {\n  result_.output_arg_modify_fn = std::move(output_arg_modify_fn);\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetOutputBlobTimeShapeInferFn(\n    OutputBlobTimeShapeInferFn output_blob_time_shape_infer_fn) {\n  result_.output_blob_time_shape_infer_fn = std::move(output_blob_time_shape_infer_fn);\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetNdSbpInferFn(NdSbpInferFn nd_sbp_infer_fn) {\n  result_.nd_sbp_infer_fn = std::move(nd_sbp_infer_fn);\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetDataTypeInferFn(DataTypeInferFn data_type_infer_fn) {\n  result_.data_type_infer_fn = std::move(data_type_infer_fn);\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetDeviceAndStreamInferFn(\n    DeviceAndStreamInferFn device_and_stream_infer_fn) {\n  result_.device_and_stream_infer_fn = std::move(device_and_stream_infer_fn);\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetComputeComplexityFn(ComputeComplexityFn compute_complexity_fn) {\n  result_.compute_complexity_fn = std::move(compute_complexity_fn);\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetGetNdSbpSignatureListFn(GetNdSbpSignatureListFn get_nd_sbp_list_fn) {\n  result_.get_nd_sbp_list_fn = std::move(get_nd_sbp_list_fn);\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetEnumerateNdSbpSignaturesFn(EnumerateNdSbpSignaturesFn fn) {\n  result_.enumerate_nd_sbp_signatures_fn = std::move(fn);\n  return *this;\n}\n\nOpRegistry& OpRegistry::SetDumpNdSbpSignatureForOpConfFn(\n    Operator::DumpNdSbpSignatureForOpConfFn fn) {\n  result_.dump_nd_sbp_signature_for_op_conf_fn = std::move(fn);\n  return *this;\n}\n\nMaybe<OpRegistry&> OpRegistry::Finish() {\n  CHECK_OR_RETURN(result_.logical_tensor_desc_infer_fn != nullptr)\n      << \"No TensorDescInfer function for \" << result_.op_type_name;\n  if (!result_.physical_tensor_desc_infer_fn) {\n    const auto& logical_fn = result_.logical_tensor_desc_infer_fn;\n    result_.physical_tensor_desc_infer_fn =\n        [logical_fn](user_op::InferContext* ctx) -> Maybe<void> {\n      if (ctx->parallel_num() == 1) {\n        logical_fn(ctx);\n      } else {\n        for (const auto& pair : ctx->inputs()) {\n          const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex(pair.first, pair.second);\n          const TensorDesc* in_logical =\n              ctx->LogicalTensorDesc4ArgNameAndIndex(pair.first, pair.second);\n          const TensorDesc& in_physical = ctx->InputTensorDesc(pair.first, pair.second);\n          CHECK_OR_RETURN(*JUST(GetPhysicalShape(in_logical->shape(), nd_sbp, ctx->parallel_desc(),\n                                                 ctx->parallel_ctx()))\n                          == in_physical.shape());\n        }\n        for (const auto& pair : ctx->outputs()) {\n          TensorDesc* desc = ctx->MutOutputTensorDesc(pair.first, pair.second);\n          *desc = *ctx->LogicalTensorDesc4ArgNameAndIndex(pair.first, pair.second);\n          const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex(pair.first, pair.second);\n          desc->set_shape(*JUST(\n              GetPhysicalShape(desc->shape(), nd_sbp, ctx->parallel_desc(), ctx->parallel_ctx())));\n          desc->set_stride(Stride(desc->shape()));\n        }\n      }\n      return Maybe<void>::Ok();\n    };\n  }\n  if (result_.check_fn == nullptr) { result_.check_fn = CheckAttrFnUtil::NoCheck; }\n  CHECK_OR_RETURN(result_.get_sbp_fn != nullptr) << \"No Sbp function for \" << 
result_.op_type_name;\n  if (result_.cpu_only_supported && result_.device_and_stream_infer_fn == nullptr) {\n    result_.device_and_stream_infer_fn =\n        [](DeviceAndStreamInferContext* ctx) -> Maybe<Symbol<Stream>> {\n      for (const auto& pair : ctx->inputs()) {\n        const Symbol<Device>& input_device =\n            ctx->InputTensorDevice4ArgNameAndIndex(pair.first, pair.second);\n        CHECK_EQ(input_device->type(), \"cpu\");\n      }\n      Symbol<Device> default_device;\n      {\n        if (ctx->inputs().size() != 0) {\n          const auto& first_input_name = ctx->inputs().begin()->first;\n          default_device = ctx->InputTensorDevice4ArgNameAndIndex(first_input_name, 0);\n        } else {\n          default_device = JUST(Device::New(\"cpu\"));\n        }\n      }\n      for (const auto& pair : ctx->outputs()) {\n        *ctx->OutputTensorDevice4ArgNameAndIndex(pair.first, pair.second) = default_device;\n      }\n      return Stream::New(default_device, StreamType::kCompute);\n    };\n  }\n  return *this;\n}\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/user_op_registry.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_OP_REGISTRY_H_\n#define ONEFLOW_CORE_FRAMEWORK_OP_REGISTRY_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/framework/user_op_def.pb.h\"\n#include \"oneflow/core/framework/user_op_attr.pb.h\"\n#include \"oneflow/core/framework/user_op_conf.pb.h\"\n#include \"oneflow/core/operator/op_attribute.pb.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass Device;\nclass Stream;\n\nnamespace user_op {\n\nclass UserOpDefWrapper;\nclass UserOpConfWrapper;\nclass InferContext;\nclass SbpContext;\nclass InferSbpSignatureFnContext;\nclass InferOutputBlobTimeShapeFnContext;\nclass InferNdSbpFnContext;\nclass DeviceAndStreamInferContext;\nclass ComputeComplexityFnContext;\nclass GetNdSbpSignatureListContext;\n\nusing CheckAttrFn = std::function<Maybe<void>(const UserOpDefWrapper&, const UserOpConfWrapper&)>;\nusing TensorDescInferFn = std::function<Maybe<void>(InferContext*)>;\nusing DataTypeInferFn = std::function<Maybe<void>(InferContext*)>;\nusing DeviceAndStreamInferFn = std::function<Maybe<Symbol<Stream>>(DeviceAndStreamInferContext*)>;\nusing GetSbpFn = std::function<Maybe<void>(SbpContext*)>;\nusing SbpSignatureInferFn = std::function<Maybe<void>(InferSbpSignatureFnContext*)>;\nusing InputArgModifier = InputBlobModifier;\nusing GetInputArgModifier =\n    std::function<InputArgModifier*(const std::string& in_arg_name, int32_t in_arg_index)>;\nusing InputArgModifyFn =\n    std::function<Maybe<void>(const GetInputArgModifier&, const UserOpConfWrapper&)>;\nusing OutputArgModifier = OutputBlobModifier;\nusing GetOutputArgModifier =\n    std::function<OutputArgModifier*(const std::string& out_arg_name, int32_t out_arg_index)>;\nusing OutputArgModifyFn =\n    std::function<Maybe<void>(const GetOutputArgModifier&, const UserOpConfWrapper&)>;\nusing OutputBlobTimeShapeInferFn = std::function<Maybe<void>(InferOutputBlobTimeShapeFnContext*)>;\nusing NdSbpInferFn = std::function<Maybe<void>(InferNdSbpFnContext*)>;\nusing ComputeComplexityFn = std::function<Maybe<double>(ComputeComplexityFnContext*)>;\n// TODO: set up another context\nusing GetNdSbpSignatureListFn = std::function<Maybe<void>(GetNdSbpSignatureListContext*)>;\nusing EnumerateNdSbpSignaturesFn = std::function<Maybe<void>(GetNdSbpSignatureListContext*)>;\n\nstruct OpRegistryResult {\n  OpRegistryResult()\n      : cpu_only_supported(false),\n        no_grad(false),\n        non_contiguous_supported(false),\n        same_output_regst_num(-1) {}\n  ~OpRegistryResult() = default;\n\n  std::string op_type_name;\n  bool cpu_only_supported;\n  bool no_grad;\n  bool non_contiguous_supported;\n  int32_t same_output_regst_num;\n  UserOpDef op_def;\n  CheckAttrFn check_fn;\n  TensorDescInferFn logical_tensor_desc_infer_fn;\n  TensorDescInferFn physical_tensor_desc_infer_fn;\n  
GetSbpFn get_sbp_fn;\n  SbpSignatureInferFn sbp_signature_infer_fn;\n  DataTypeInferFn data_type_infer_fn;\n  DeviceAndStreamInferFn device_and_stream_infer_fn;\n  // TODO(niuchong): move input_arg_modify_fn out of OpRegistryResult since it is more about\n  // performance other than op definition\n  InputArgModifyFn input_arg_modify_fn;\n  OutputArgModifyFn output_arg_modify_fn;\n  OutputBlobTimeShapeInferFn output_blob_time_shape_infer_fn;\n  NdSbpInferFn nd_sbp_infer_fn;\n  ComputeComplexityFn compute_complexity_fn;\n  GetNdSbpSignatureListFn get_nd_sbp_list_fn;\n  EnumerateNdSbpSignaturesFn enumerate_nd_sbp_signatures_fn;\n  Operator::DumpNdSbpSignatureForOpConfFn dump_nd_sbp_signature_for_op_conf_fn;\n};\n\nclass OpRegistry final {\n public:\n  OpRegistry& Name(const std::string& op_type_name);\n\n  OpRegistry& Input(const std::string& name);\n  OpRegistry& Input(const std::string& name, int32_t num);\n  OpRegistry& InputWithMinimum(const std::string& name, int32_t min_num);\n  OpRegistry& OptionalInput(const std::string& name);\n  OpRegistry& OptionalInput(const std::string& name, int32_t num);\n  OpRegistry& OptionalInputWithMinimum(const std::string& name, int32_t min_num);\n\n  OpRegistry& Output(const std::string& name);\n  OpRegistry& Output(const std::string& name, int32_t num);\n  OpRegistry& OutputWithMinimum(const std::string& name, int32_t min_num);\n  OpRegistry& OptionalOutput(const std::string& name);\n  OpRegistry& OptionalOutput(const std::string& name, int32_t num);\n  OpRegistry& OptionalOutputWithMinimum(const std::string& name, int32_t min_num);\n\n  OpRegistry& SupportCpuOnly();\n  OpRegistry& SupportNonContiguous();\n  OpRegistry& NoGrad();\n  OpRegistry& SetOutputBufferNum(int32_t num);\n\n  __attribute__((deprecated)) OpRegistry& Attr(const std::string& name, AttrType type);\n  template<typename T>\n  __attribute__((deprecated)) OpRegistry& Attr(const std::string& name, AttrType type,\n                                               const T& default_val);\n  template<typename T>\n  OpRegistry& Attr(const std::string& name, const T& default_val);\n  template<typename T>\n  OpRegistry& Attr(const std::string& name);\n\n  OpRegistry& SetTensorDescInferFn(TensorDescInferFn fn);\n  OpRegistry& SetLogicalTensorDescInferFn(TensorDescInferFn fn);\n  OpRegistry& SetPhysicalTensorDescInferFn(TensorDescInferFn fn);\n  OpRegistry& SetGetSbpFn(GetSbpFn fn);\n  OpRegistry& SetSbpSignatureInferFn(SbpSignatureInferFn fn);\n  OpRegistry& SetInputArgModifyFn(InputArgModifyFn fn);\n  OpRegistry& SetOutputArgModifyFn(OutputArgModifyFn fn);\n  OpRegistry& SetOutputBlobTimeShapeInferFn(OutputBlobTimeShapeInferFn fn);\n  OpRegistry& SetNdSbpInferFn(NdSbpInferFn fn);\n  OpRegistry& SetCheckAttrFn(CheckAttrFn fn);\n  OpRegistry& SetDataTypeInferFn(DataTypeInferFn fn);\n  OpRegistry& SetDeviceAndStreamInferFn(DeviceAndStreamInferFn fn);\n  OpRegistry& SetComputeComplexityFn(ComputeComplexityFn fn);\n  OpRegistry& SetGetNdSbpSignatureListFn(GetNdSbpSignatureListFn fn);\n  OpRegistry& SetEnumerateNdSbpSignaturesFn(EnumerateNdSbpSignaturesFn fn);\n  OpRegistry& SetDumpNdSbpSignatureForOpConfFn(Operator::DumpNdSbpSignatureForOpConfFn fn);\n\n  Maybe<OpRegistry&> Finish();\n  OpRegistryResult GetResult() { return result_; }\n\n private:\n  OpRegistry& ArgImpl(bool is_input, const std::string& name, bool is_optional);\n  OpRegistry& DefaultedAttr(const std::string& name, AttrType type,\n                            const std::function<void(UserOpDef::AttrDef*)>& SetDefault);\n\n private:\n  
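// Arg and attr names registered so far; used to reject duplicate names.\n  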
HashSet<std::string> unique_names_;\n  OpRegistryResult result_;\n};\n\nstatic const std::string kUserSourceOpTickInputArgName = \"UserSourceOpTickInput\";\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_OP_REGISTRY_H_\n"
  },
  {
    "path": "oneflow/core/framework/user_op_registry_manager.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/common/tensor_desc.h\"\n#include \"oneflow/core/kernel/kernel.pb.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/env_var/env_var.h\"\n\nnamespace oneflow {\n\nDEFINE_ENV_BOOL(ONEFLOW_KERNEL_ENABLE_PRIORITY_EXPERIMENTAL, false);\n\nnamespace user_op {\n\nUserOpRegistryMgr& UserOpRegistryMgr::Get() {\n  static UserOpRegistryMgr mgr;\n  return mgr;\n}\n\nOpRegistry UserOpRegistryMgr::CheckAndGetOpRegistry(const std::string& op_type_name) {\n  CHECK(!op_type_name.empty());\n  auto it = op_reg_result_.find(op_type_name);\n  CHECK(it == op_reg_result_.end());\n  return OpRegistry().Name(op_type_name);\n}\n\nMaybe<void> UserOpRegistryMgr::Register(OpRegistryResult result) {\n  CHECK_OR_RETURN(result.data_type_infer_fn);\n  CHECK_OR_RETURN(op_reg_result_.emplace(result.op_type_name, result).second);\n  return Maybe<void>::Ok();\n}\n\nconst OpRegistryResult* UserOpRegistryMgr::GetOpRegistryResult(const std::string& op_type_name) {\n  auto it = op_reg_result_.find(op_type_name);\n  if (it != op_reg_result_.end()) { return &(it->second); }\n  return nullptr;\n}\n\nOpKernelRegistry UserOpRegistryMgr::CheckAndGetOpKernelRegistry(const std::string& op_type_name) {\n  CHECK(!op_type_name.empty());\n  return OpKernelRegistry().Name(op_type_name);\n}\n\nMaybe<void> UserOpRegistryMgr::Register(OpKernelRegistryResult result) {\n  op_kernel_reg_result_[result.op_type_name].emplace_back(result);\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nstd::string GetErrorMsgOfSearchedOp(const KernelRegContext& ctx) {\n  const auto& op_conf = ctx.user_op_conf();\n  std::stringstream ss;\n  ss << \" The Info of OperatorConf are \"\n     << \"\\n op_name: \" << op_conf.op_name() << \"\\n op_type_name: \" << op_conf.op_type_name()\n     << \"\\n DeviceType_Name: \" << DeviceType_Name(ctx.device_type());\n  for (const auto& pair : ctx.inputs()) {\n    ss << \"\\n DataType_Name of \" << pair.first << \"_\" << pair.second << \": \"\n       << DataType_Name(ctx.TensorDesc4ArgNameAndIndex(pair.first, pair.second)->data_type());\n  }\n  for (const auto& pair : ctx.outputs()) {\n    ss << \"\\n DataType_Name of \" << pair.first << \"_\" << pair.second << \": \"\n       << DataType_Name(ctx.TensorDesc4ArgNameAndIndex(pair.first, pair.second)->data_type());\n  }\n  return ss.str();\n}\n\n}  // namespace\n\nMaybe<const OpKernelRegistryResult*> UserOpRegistryMgr::GetOpKernelRegistryResult(\n    const std::string& op_type_name, const KernelRegContext& ctx) {\n  auto it = op_kernel_reg_result_.find(op_type_name);\n  if (it == op_kernel_reg_result_.end()) {\n    return Error::OpKernelNotFoundError({})\n           << \"There is no kernel registered for Current OperatorConf. 
\"\n           << GetErrorMsgOfSearchedOp(ctx);\n  }\n\n  const OpKernelRegistryResult* ret = nullptr;\n  int32_t cur_priority = kKernelPriorityFallback;\n  const bool enable_priority_experimental = EnvBool<ONEFLOW_KERNEL_ENABLE_PRIORITY_EXPERIMENTAL>();\n  for (const auto& reg_val : it->second) {\n    if (reg_val.priority >= kKernelPriorityExperimental && (!enable_priority_experimental)) {\n      continue;\n    }\n    if (reg_val.is_matched_hob->get(ctx)) {\n      if (ret == nullptr || reg_val.priority > cur_priority) {\n        ret = &reg_val;\n        cur_priority = reg_val.priority;\n      } else if (ret != nullptr && reg_val.priority == cur_priority) {\n        LOG(WARNING)\n            << \"There are more than one kernels with same priority matching Current OperatorConf. \"\n            << GetErrorMsgOfSearchedOp(ctx);\n      } else {\n        // do nothing\n      }\n    }\n  }\n\n  if (ret == nullptr) {\n    std::vector<std::string> debug_msgs;\n    for (const auto& reg_val : it->second) {\n      debug_msgs.emplace_back(reg_val.is_matched_hob->DebugStr(ctx));\n    }\n    return Error::OpKernelNotFoundError(debug_msgs)\n           << \"Cannot find the kernel matching Current OperatorConf. \"\n           << GetErrorMsgOfSearchedOp(ctx);\n  }\n\n  return ret;\n}\n\nMaybe<bool> UserOpRegistryMgr::IsOpKernelRegistered(const std::string& op_type_name,\n                                                    const KernelRegContext& ctx) {\n  auto it = op_kernel_reg_result_.find(op_type_name);\n  if (it == op_kernel_reg_result_.end()) { return false; }\n  const bool enable_priority_experimental = EnvBool<ONEFLOW_KERNEL_ENABLE_PRIORITY_EXPERIMENTAL>();\n  for (const auto& reg_val : it->second) {\n    if (reg_val.priority >= kKernelPriorityExperimental && (!enable_priority_experimental)) {\n      continue;\n    }\n    if (reg_val.is_matched_hob->get(ctx)) { return true; }\n  }\n  return false;\n}\n\nUserOpHostMemoryInputRegistry& UserOpHostMemoryInputRegistry::Get() {\n  static UserOpHostMemoryInputRegistry mgr;\n  return mgr;\n}\n\nMaybe<void> UserOpHostMemoryInputRegistry::SetHostMemoryInput4Op(const std::string& op_type_name,\n                                                                 const std::string& arg_name,\n                                                                 int32_t index) {\n  auto it = op_type_name2host_memory_input_args_.find(op_type_name);\n  if (it == op_type_name2host_memory_input_args_.end()) {\n    auto pair = op_type_name2host_memory_input_args_.emplace(\n        op_type_name, small_vector<std::pair<std::string, int32_t>>());\n    CHECK_OR_RETURN(pair.second);\n    it = pair.first;\n  }\n  it->second.emplace_back(std::make_pair(arg_name, index));\n  return Maybe<void>::Ok();\n}\n\nbool UserOpHostMemoryInputRegistry::IsHostMemoryInput4Op(const std::string& op_type_name,\n                                                         const std::string& arg_name,\n                                                         int32_t index) const {\n  auto it = op_type_name2host_memory_input_args_.find(op_type_name);\n  if (it == op_type_name2host_memory_input_args_.end()) { return false; }\n  return std::find(it->second.begin(), it->second.end(), std::make_pair(arg_name, index))\n         != it->second.end();\n}\n\nbool UserOpHostMemoryInputRegistry::HasHostMemoryInput(const std::string& op_type_name) const {\n  return op_type_name2host_memory_input_args_.find(op_type_name)\n         != op_type_name2host_memory_input_args_.end();\n}\n\n}  // namespace user_op\n}  // 
namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/user_op_registry_manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_USER_OP_REGISTRY_MANAGER_H_\n#define ONEFLOW_CORE_FRAMEWORK_USER_OP_REGISTRY_MANAGER_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/user_op_registry.h\"\n#include \"oneflow/core/framework/user_op_kernel_registry.h\"\n#include \"oneflow/core/common/registry_error.h\"\n#include \"oneflow/core/common/op_args_reserved_size.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nclass UserOpRegistryMgr final {\n private:\n  UserOpRegistryMgr() {}\n\n public:\n  UserOpRegistryMgr(UserOpRegistryMgr const&) = delete;\n  UserOpRegistryMgr& operator=(UserOpRegistryMgr const&) = delete;\n  static UserOpRegistryMgr& Get();\n\n public:\n  OpRegistry CheckAndGetOpRegistry(const std::string& op_type_name);\n  Maybe<void> Register(OpRegistryResult result);\n  const OpRegistryResult* GetOpRegistryResult(const std::string& op_type_name);\n\n  OpKernelRegistry CheckAndGetOpKernelRegistry(const std::string& op_type_name);\n  Maybe<void> Register(OpKernelRegistryResult result);\n  Maybe<const OpKernelRegistryResult*> GetOpKernelRegistryResult(const std::string& op_type_name,\n                                                                 const KernelRegContext& ctx);\n  Maybe<bool> IsOpKernelRegistered(const std::string& op_type_name, const KernelRegContext& ctx);\n\n  const HashMap<std::string, OpRegistryResult>& GetAllOpRegistryResults() {\n    return op_reg_result_;\n  };\n\n private:\n  HashMap<std::string, OpRegistryResult> op_reg_result_;\n  HashMap<std::string, std::vector<OpKernelRegistryResult>> op_kernel_reg_result_;\n};\n\ntemplate<typename RegistryT>\nstruct UserOpRegisterTrigger final {\n  UserOpRegisterTrigger(RegistryT& registry) {\n    CatchRegistryError([&]() -> Maybe<void> {\n      return UserOpRegistryMgr::Get().Register(JUST(registry.Finish()).GetResult());\n    });\n  }\n};\n\nclass UserOpHostMemoryInputRegistry final {\n public:\n  UserOpHostMemoryInputRegistry(UserOpHostMemoryInputRegistry const&) = delete;\n  UserOpHostMemoryInputRegistry& operator=(UserOpHostMemoryInputRegistry const&) = delete;\n  ~UserOpHostMemoryInputRegistry() = default;\n\n  static UserOpHostMemoryInputRegistry& Get();\n\n  Maybe<void> SetHostMemoryInput4Op(const std::string& op_type_name, const std::string& arg_name,\n                                    int32_t index);\n  bool IsHostMemoryInput4Op(const std::string& op_type_name, const std::string& arg_name,\n                            int32_t index) const;\n\n  bool HasHostMemoryInput(const std::string& op_type_name) const;\n\n private:\n  UserOpHostMemoryInputRegistry() {}\n  HashMap<std::string, small_vector<std::pair<std::string, int32_t>>>\n      op_type_name2host_memory_input_args_;\n};\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#define REGISTER_OP_HOST_MEMORY_INPUT(op_type_name, arg_name, index)                      \\\n  
COMMAND(CHECK_JUST(user_op::UserOpHostMemoryInputRegistry::Get().SetHostMemoryInput4Op( \\\n      op_type_name, arg_name, index)));\n\n#define REGISTER_USER_OP(name)                                                                \\\n  static ::oneflow::user_op::UserOpRegisterTrigger<::oneflow::user_op::OpRegistry> OF_PP_CAT( \\\n      g_register_trigger, __COUNTER__) =                                                      \\\n      ::oneflow::user_op::UserOpRegistryMgr::Get().CheckAndGetOpRegistry(name)\n\n#define REGISTER_CPU_ONLY_USER_OP(name) REGISTER_USER_OP(name).SupportCpuOnly()\n\n#define REGISTER_NO_GRAD_USER_OP(name) REGISTER_USER_OP(name).NoGrad()\n\n#define REGISTER_NO_GRAD_CPU_ONLY_USER_OP(name) REGISTER_NO_GRAD_USER_OP(name).SupportCpuOnly()\n\n#define REGISTER_USER_KERNEL(name)                                                       \\\n  static ::oneflow::user_op::UserOpRegisterTrigger<::oneflow::user_op::OpKernelRegistry> \\\n      OF_PP_CAT(g_register_trigger, __COUNTER__) =                                       \\\n          ::oneflow::user_op::UserOpRegistryMgr::Get().CheckAndGetOpKernelRegistry(name)\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_USER_OP_REGISTRY_MANAGER_H_\n"
  },
  {
    "path": "oneflow/core/framework/user_op_tensor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_USER_OP_TENSOR_H_\n#define ONEFLOW_CORE_FRAMEWORK_USER_OP_TENSOR_H_\n\n#include <memory>\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/memory_format.pb.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/memory/memory_case.pb.h\"\n#include \"oneflow/core/common/error.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nclass Tensor {\n public:\n#pragma GCC diagnostic push\n#pragma GCC diagnostic ignored \"-Wnon-virtual-dtor\"\n  // NOTE: Performance will be degraded if the destructor is virtual.\n  //       So please do NOT implement custom destructor in any child classes of user_op::Tensor,\n  //       and every fields of child classes should be of POD type.\n  ~Tensor() = default;\n#pragma GCC diagnostic pop\n\n  virtual ShapeView shape_view() const = 0;\n  virtual MutShapeView mut_shape_view() = 0;\n  virtual const Stride& stride() const = 0;\n  virtual DataType data_type() const = 0;\n  virtual MemoryFormat memory_format() const = 0;\n  virtual const MemoryCase& mem_case() const = 0;\n  virtual const void* raw_dptr() const = 0;\n  virtual void* mut_raw_dptr() = 0;\n\n  template<typename T = void>\n  const T* dptr() const {\n    CheckDataType<T>();\n    return reinterpret_cast<const T*>(raw_dptr());\n  }\n\n  template<typename T = void>\n  T* mut_dptr() {\n    CheckDataType<T>();\n    return reinterpret_cast<T*>(mut_raw_dptr());\n  }\n\n protected:\n  template<typename T>\n  void CheckDataType() const {\n    LOG_IF(FATAL, (std::is_same<T, void>::value == false && std::is_same<T, char>::value == false\n                   && data_type() != DataType::kChar && data_type() != GetDataType<T>::value))\n        << \"tensor data_type mismatched. value: \" << DataType_Name(data_type())\n        << \", template T:\" << DataType_Name(GetDataType<T>::value);\n  }\n};\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_USER_OP_TENSOR_H_\n"
  },
  {
    "path": "oneflow/core/framework/util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_UTIL_H_\n#define ONEFLOW_CORE_FRAMEWORK_UTIL_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace std {\n\ntemplate<>\nstruct hash<std::pair<std::string, int32_t>> {\n  std::size_t operator()(const std::pair<std::string, int32_t>& p) const {\n    return oneflow::Hash(p.first, p.second);\n  }\n};\n\n}  // namespace std\n\nnamespace oneflow {\n\nnamespace user_op {}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/framework/variable_meta_info.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/common/shape.proto\";\nimport \"oneflow/core/common/data_type.proto\";\n\nmessage VariableMetaInfo {\n  required ShapeProto shape = 2;\n  required DataType data_type = 3;\n}\n"
  },
  {
    "path": "oneflow/core/framework/variable_tensor_mgr.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/variable_tensor_mgr.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\n\nMaybe<void> VariableTensorMgr::Set(const std::string& variable_op_name,\n                                   const std::shared_ptr<one::Tensor>& variable_tensor,\n                                   const Symbol<DType>& dtype) {\n  if (dtype && variable_tensor->dtype() != dtype) {\n    LazyMode::Guard guard{false};\n    variables_[variable_op_name] = JUST(one::functional::Cast(variable_tensor, dtype, false));\n  } else {\n    variables_[variable_op_name] = variable_tensor;\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<one::Tensor> VariableTensorMgr::Get(const std::string& variable_op_name,\n                                          const Symbol<DType>& dtype) {\n  if (variables_.find(variable_op_name) != variables_.end()) {\n    const auto variable_tensor = variables_[variable_op_name];\n    if (dtype && variable_tensor->dtype() != dtype) {\n      LazyMode::Guard guard{false};\n      return JUST(one::functional::Cast(variable_tensor, dtype, false));\n    }\n    return variable_tensor;\n  }\n  return std::shared_ptr<one::Tensor>(nullptr);\n}\n\nvoid VariableTensorMgr::Delete(const std::string& variable_op_name) {\n  if (variables_.find(variable_op_name) != variables_.end()) { variables_.erase(variable_op_name); }\n}\n\nMaybe<void> VariableTensorMgr::Fill(\n    const std::vector<std::string>& variable_op_names,\n    const std::vector<std::shared_ptr<one::Tensor>>& variable_tensors) {\n  CHECK_EQ_OR_THROW(variable_op_names.size(), variable_tensors.size())\n      << \"The number of variable op names is not equal with the number of variable tensors.\";\n\n  for (size_t i = 0; i < variable_op_names.size(); ++i) {\n    JUST(Set(JUST(oneflow::VectorAt(variable_op_names, i)),\n             JUST(oneflow::VectorAt(variable_tensors, i))));\n  }\n  return Maybe<void>::Ok();\n}\n\nstd::tuple<std::vector<std::string>, std::vector<std::shared_ptr<one::Tensor>>>\nVariableTensorMgr::Dump() {\n  std::vector<std::string> variable_op_names;\n  std::vector<std::shared_ptr<one::Tensor>> variable_tensors;\n  for (const auto& x : variables_) {\n    variable_op_names.push_back(x.first);\n    variable_tensors.push_back(x.second);\n  }\n  return std::make_tuple(variable_op_names, variable_tensors);\n}\n\nvoid VariableTensorMgr::Reset() {\n  std::map<std::string, std::shared_ptr<one::Tensor>>().swap(variables_);\n}\n\nstd::vector<std::string> VariableTensorMgr::DumpNames() {\n  std::vector<std::string> variable_op_names;\n  for (const auto& x : variables_) { variable_op_names.push_back(x.first); }\n  return 
variable_op_names;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/framework/variable_tensor_mgr.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FRAMEWORK_VARIABLE_TENSOR_MGR_H_\n#define ONEFLOW_CORE_FRAMEWORK_VARIABLE_TENSOR_MGR_H_\n\n#include <map>\n#include <memory>\n#include <tuple>\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/dtype.h\"\nnamespace oneflow {\n\ntemplate<typename T, typename Kind>\nclass Singleton;\nnamespace one {\n\nclass Tensor;\n\n}\n\nclass VariableTensorMgr final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(VariableTensorMgr);\n  ~VariableTensorMgr() = default;\n\n  Maybe<void> Set(const std::string& variable_op_name,\n                  const std::shared_ptr<one::Tensor>& variable_tensor,\n                  const Symbol<DType>& dtype = Symbol<DType>());\n  Maybe<one::Tensor> Get(const std::string& variable_op_name,\n                         const Symbol<DType>& dtype = Symbol<DType>());\n\n  void Delete(const std::string& variable_op_name);\n  Maybe<void> Fill(const std::vector<std::string>& variable_op_names,\n                   const std::vector<std::shared_ptr<one::Tensor>>& variable_tensors);\n  std::tuple<std::vector<std::string>, std::vector<std::shared_ptr<one::Tensor>>> Dump();\n  std::vector<std::string> DumpNames();\n  void Reset();\n\n private:\n  friend class Singleton<VariableTensorMgr>;\n  VariableTensorMgr() = default;\n\n  std::map<std::string, std::shared_ptr<one::Tensor>> variables_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FRAMEWORK_VARIABLE_TENSOR_MGR_H_\n"
  },
  {
    "path": "oneflow/core/functional/function_library.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FUNCTIONAL_FUNCTION_LIBRARY_H_\n#define ONEFLOW_CORE_FUNCTIONAL_FUNCTION_LIBRARY_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/wrap_dim_utils.h\"\n#include \"oneflow/core/functional/packed_functor.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/framework/tensor_methods.h\"\n#include \"oneflow/core/common/throw.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nclass FunctionLibrary {\n public:\n  virtual ~FunctionLibrary() = default;\n\n  template<typename T>\n  struct PackedFuncCreatorMap;\n\n  template<typename R, typename... Args>\n  struct PackedFuncCreatorMap<R(Args...)> {\n    using FunctorCreator = typename std::function<PackedFunctor<R(Args...)>()>;\n\n    static HashMap<std::string, FunctorCreator>* Get() {\n      static HashMap<std::string, FunctorCreator> functors;\n      return &functors;\n    }\n  };\n\n  template<typename Func>\n  void add_functor(const std::string& func_name, const Func& func) {\n    using func_type = typename function_traits<Func>::func_type;\n    add_functor_creator<Func>(\n        func_name, [=]() { return PackedFunctorMaker<func_type>::make(func_name, func); });\n  }\n\n  template<typename Func>\n  void add_one_functor(const std::string& func_name) {\n    using func_type = typename function_traits<Func>::func_type;\n    add_functor_creator<Func>(func_name, [=]() {\n      // Lazily construct functor since ops maybe have not been registered.\n      Func func;\n      return PackedFunctorMaker<func_type>::make(func_name, func);\n    });\n  }\n\n  template<typename... Fs>\n  void add_functor(const std::string& func_name) {\n    static_assert(sizeof...(Fs) > 0, \"at least one functor is expected\");\n    (add_one_functor<Fs>(func_name), ...);\n  }\n\n  template<typename R, typename... 
Args>\n  auto find(const std::string& func_name)\n      -> Maybe<PackedFunctor<typename PackedFunctorMaker<R(Args...)>::FType>> {\n    auto* functors = PackedFuncCreatorMap<typename PackedFunctorMaker<R(Args...)>::FType>::Get();\n    const auto& it = functors->find(func_name);\n    CHECK_OR_RETURN(it != functors->end())\n        << Error::RuntimeError() << \"Functor was not found for \\\"\" << func_name\n        << \"\\\", please check whether the functor has been registered correctly or not.\";\n    return it->second();\n  }\n\n  static FunctionLibrary* Global() {\n    static FunctionLibrary global_function_library;\n    return &global_function_library;\n  }\n\n private:\n  FunctionLibrary() = default;\n\n  template<typename Func, typename Creator>\n  void add_functor_creator(const std::string& func_name, Creator creator) {\n    using func_type = typename function_traits<Func>::func_type;\n    auto* functors = PackedFuncCreatorMap<typename PackedFunctorMaker<func_type>::FType>::Get();\n    CHECK_OR_THROW(functors->count(func_name) == 0)\n        << Error::RuntimeError() << \"The functor with name \" << func_name\n        << \" has been registered more than once.\";\n    functors->emplace(func_name, creator);\n  }\n};\n\n#define ONEFLOW_FUNCTION_LIBRARY(m) ONEFLOW_FUNCTION_LIBRARY_IMPL(m, __COUNTER__)\n#define ONEFLOW_FUNCTION_LIBRARY_IMPL(m, uuid)                                  \\\n  static void OF_PP_CAT(_oneflow_function_library_, uuid)(FunctionLibrary & m); \\\n  static int OF_PP_CAT(_oneflow_function_library_dummy_, uuid) = []() {         \\\n    FunctionLibrary* library = FunctionLibrary::Global();                       \\\n    OF_PP_CAT(_oneflow_function_library_, uuid)(*library);                      \\\n    return 0;                                                                   \\\n  }();                                                                          \\\n  void OF_PP_CAT(_oneflow_function_library_, uuid)(FunctionLibrary & m)\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FUNCTIONAL_FUNCTION_LIBRARY_H_\n"
  },
  {
    "path": "oneflow/core/functional/functional.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FUNCTIONAL_FUNCTIONAL_H_\n#define ONEFLOW_CORE_FUNCTIONAL_FUNCTIONAL_H_\n\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n\n#endif  // ONEFLOW_CORE_FUNCTIONAL_FUNCTIONAL_H_\n"
  },
  {
    "path": "oneflow/core/functional/functional_api.yaml",
    "content": "# Copyright 2020 The OneFlow Authors. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# The following data types are allowed,\n# {\n#   \"Tensor\", \"TensorTuple\", \"Scalar\", \"Int\", \"Int32\", \"Int64\", \"Float\", \"Double\", \"String\", \"Bool\",\n#   \"ScalarList\", \"IntList\", \"Int32List\", \"Int64List\", \"FloatList\", \"DoubleList\", \"StringList\",\n#   \"BoolList\", \"DataType\", \"Shape\", \"Generator\", \"TensorIndex\", \"Device\", \"Placement\",\n#   \"Sbp\", \"SbpList\", \"Layout\", \"MemoryFormat\",\n# }\n\n- name: \"add\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other, *, Scalar alpha=1, Bool inplace=False) => Add\",\n      \"Tensor (Tensor input, Scalar other, *, Scalar alpha=1, Bool inplace=False) => ScalarAdd\",\n      \"Tensor (Scalar input, Tensor other, *, Scalar alpha=1) => ScalarAdd\",\n      \"Tensor (TensorTuple inputs, *, Bool inplace=False) => Add\",\n    ]\n  bind_python: true\n\n# this api just for test host memory input\n- name: \"host_scalar_add_by_tensor\"\n  signature: \"Tensor (Tensor x, Tensor scalar) => HostScalarAddByTensor\"\n  bind_python: true\n\n- name: \"amin\"\n  signature: \"Tensor (Tensor input, Int32List[1] dim=None, Bool keepdim=False) => Amin\"\n  bind_python: True\n\n- name: \"sub\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other, *, Scalar alpha=1, Bool inplace=False) => Sub\",\n      \"Tensor (Tensor input, Scalar other, *, Scalar alpha=1, Bool inplace=False) => ScalarSub\",\n      \"Tensor (Scalar input, Tensor other, *, Scalar alpha=1) =>  ScalarSub\",\n    ]\n  bind_python: true\n\n- name: \"mul\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => Mul\",\n      \"Tensor (Tensor input, Scalar other, *, Bool inplace=False) => ScalarMul\",\n      \"Tensor (Scalar input, Tensor other) => ScalarMul\",\n    ]\n  bind_python: true\n\n- name: \"mul_\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => InplaceMul\",\n      \"Tensor (Tensor input, Scalar other) => InplaceScalarMul\",\n    ]\n  bind_python: true\n\n- name: \"addcmul\"\n  signature: \"Tensor (Tensor input, Tensor tensor1, Tensor tensor2, *, Scalar value=1) => Addcmul\"\n  bind_python: true\n\n- name: \"addcmul_\"\n  signature: \"Tensor (Tensor input, Tensor tensor1, Tensor tensor2, *, Scalar value=1) => InplaceAddcmul\"\n  bind_python: true\n\n- name: \"addcdiv\"\n  signature: \"Tensor (Tensor input, Tensor tensor1, Tensor tensor2, *, Scalar value=1) => AddCDiv\"\n  bind_python: true\n\n- name: \"addcdiv_\"\n  signature: \"Tensor (Tensor input, Tensor tensor1, Tensor tensor2, *, Scalar value=1) => InplaceAddCDiv\"\n  bind_python: true\n\n- name: \"div\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => Div\",\n      \"Tensor (Tensor input, Scalar other) => ScalarDiv\",\n      \"Tensor (Scalar input, Tensor other) => ScalarDiv\",\n      \"Tensor (Tensor input, Tensor other, *, String rounding_mode=None) => DivMode\",\n      \"Tensor (Tensor 
input, Scalar other, *, String rounding_mode=None) => ScalarDivMode\",\n      \"Tensor (Scalar input, Tensor other, *, String rounding_mode=None) => ScalarDivMode\",\n    ]\n  bind_python: true\n\n- name: \"div_\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => InplaceDiv\",\n      \"Tensor (Tensor input, Scalar other) => InplaceScalarDiv\",\n    ]\n  bind_python: true\n\n- name: \"div_grad\"\n  signature: \"Tensor (Tensor dz, Tensor z, Tensor y) => DivGrad\"\n  bind_python: False\n\n- name: \"equal\"\n  signature: \"Bool (Tensor input, Tensor other) => Equal\"\n  bind_python: true\n\n- name: \"broadcast_equal\"\n  signature: \n    [\n      \"Tensor (Tensor input, Tensor other) => BroadcastEqual\",\n      \"Tensor (Tensor input, Scalar other) => ScalarLogicalEqual\",\n      \"Tensor (Scalar input, Tensor other) => ScalarLogicalEqual\",\n    ]\n  bind_python: true\n\n- name: \"not_equal\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => BroadcastNotEqual\",\n      \"Tensor (Tensor input, Scalar other) => ScalarLogicalNotEqual\",\n      \"Tensor (Scalar input, Tensor other) => ScalarLogicalNotEqual\",\n    ]\n  bind_python: true\n\n- name: \"greater\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => BroadcastGreater\",\n      \"Tensor (Tensor input, Scalar other) => ScalarLogicalGreater\",\n      \"Tensor (Scalar input, Tensor other) => ScalarLogicalGreater\",\n    ]\n  bind_python: true\n\n- name: \"greater_\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => InplaceBroadcastGreater\",\n      \"Tensor (Tensor input, Scalar other) => InplaceScalarLogicalGreater\",\n    ]\n  bind_python: true\n\n- name: \"greater_equal\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => BroadcastGreaterEqual\",\n      \"Tensor (Tensor input, Scalar other) => ScalarLogicalGreaterEqual\",\n      \"Tensor (Scalar input, Tensor other) => ScalarLogicalGreaterEqual\",\n    ]\n  bind_python: true\n\n- name: \"logical_and\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => BroadcastLogicalAnd\",\n      \"Tensor (Tensor input, Scalar other) => ScalarLogicalAnd\",\n      \"Tensor (Scalar input, Tensor other) => ScalarLogicalAnd\",\n    ]\n  bind_python: true\n\n- name: \"logical_or\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => BroadcastLogicalOr\",\n      \"Tensor (Tensor input, Scalar other) => ScalarLogicalOr\",\n      \"Tensor (Scalar input, Tensor other) => ScalarLogicalOr\",\n    ]\n  bind_python: true\n\n- name: \"logical_not\"\n  signature: \"Tensor (Tensor input) => LogicalNot\"\n  bind_python: true\n\n- name: \"logical_xor\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => BroadcastLogicalXor\",\n      \"Tensor (Tensor input, Scalar other) => ScalarLogicalXor\",\n      \"Tensor (Scalar input, Tensor other) => ScalarLogicalXor\",\n    ]\n  bind_python: true\n\n- name: \"bitwise_and\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => BroadcastBitwiseAnd\",\n      \"Tensor (Tensor input, Scalar other) => ScalarBitwiseAnd\",\n      \"Tensor (Scalar input, Tensor other) => ScalarBitwiseAnd\",\n    ]\n  bind_python: true\n\n- name: \"bitwise_or\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => BroadcastBitwiseOr\",\n      \"Tensor (Tensor input, Scalar other) => ScalarBitwiseOr\",\n      \"Tensor (Scalar input, Tensor other) => ScalarBitwiseOr\",\n    ]\n  bind_python: true\n\n- name: \"bitwise_xor\"\n  signature:\n  
  [\n      \"Tensor (Tensor input, Tensor other) => BroadcastBitwiseXor\",\n      \"Tensor (Tensor input, Scalar other) => ScalarBitwiseXor\",\n      \"Tensor (Scalar input, Tensor other) => ScalarBitwiseXor\",\n    ]\n  bind_python: true\n\n- name: \"bitwise_not\"\n  signature: \"Tensor (Tensor input) => BitwiseNot\"\n  bind_python: true\n\n- name: \"less\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => BroadcastLess\",\n      \"Tensor (Tensor input, Scalar other) => ScalarLogicalLess\",\n      \"Tensor (Scalar input, Tensor other) => ScalarLogicalLess\",\n    ]\n  bind_python: True\n\n- name: \"less_equal\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => BroadcastLessEqual\",\n      \"Tensor (Tensor input, Scalar other) => ScalarLogicalLessEqual\",\n      \"Tensor (Scalar input, Tensor other) => ScalarLogicalLessEqual\",\n    ]\n  bind_python: True\n\n- name: \"pow\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor exponent) => Pow\",\n      \"Tensor (Tensor input, Scalar exponent, *, Bool inplace=False) => ScalarPow\",\n      \"Tensor (Tensor input, Scalar exponent) => ScalarPow\",\n      \"Tensor (Scalar exponent, Tensor input) => ScalarReversePow\",\n    ]\n  bind_python: True\n\n- name: \"pow_x_grad\"\n  signature: \"Tensor (Tensor x, Tensor y, Tensor dz) => PowXGrad\"\n  bind_python: False\n\n- name: \"pow_y_grad\"\n  signature: \"Tensor (Tensor x, Tensor y, Tensor dz) => PowYGrad\"\n  bind_python: False\n\n- name: \"searchsorted\"\n  signature:\n    [\n      \"Tensor (Tensor sorted_sequence, Tensor values, Bool out_int32=False, Bool right=False) => SearchSorted\",\n      \"Tensor (Tensor sorted_sequence, Scalar values, Bool out_int32=False, Bool right=False) => SearchSortedScalar\",\n    ]\n  bind_python: True\n\n- name: \"scalar_pow_grad\"\n  signature: \"Tensor (Tensor input, Tensor dy, Scalar exponent) => ScalarPowGrad\"\n  bind_python: False\n\n- name: \"scalar_reverse_pow_grad\"\n  signature: \"Tensor (Tensor input, Tensor dy, Scalar exponent) => ScalarReversePowGrad\"\n  bind_python: False\n\n- name: \"broadcast_pow\"\n  signature: \"Tensor (Tensor x, Tensor y) => BroadcastPow\"\n  bind_python: False\n\n- name: \"broadcast_pow_x_grad\"\n  signature: \"Tensor (Tensor x, Tensor y, Tensor dz) => BroadcastPowXGrad\"\n  bind_python: False\n\n- name: \"broadcast_pow_y_grad\"\n  signature: \"Tensor (Tensor x, Tensor y, Tensor dz) => BroadcastPowYGrad\"\n  bind_python: False\n\n- name: \"floor_divide\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => FloorDiv\",\n      \"Tensor (Tensor input, Scalar other, *, Bool inplace=False) => ScalarFloorDiv\",\n      \"Tensor (Tensor input, Scalar other) => ScalarFloorDiv\",\n    ]\n  bind_python: True\n\n- name: \"floordiv_x_grad\"\n  signature: \"Tensor (Tensor dz, Tensor x, Tensor y) => FloorDivXGrad\"\n  bind_python: False\n\n- name: \"floordiv_y_grad\"\n  signature: \"Tensor (Tensor dz, Tensor x, Tensor y) => FloorDivYGrad\"\n  bind_python: False\n\n- name: \"lerp\"\n  signature: \n    [\n      \"Tensor (Tensor start, Tensor end, Tensor weight) => Lerp\",\n      \"Tensor (Tensor start, Tensor end, Scalar weight) => ScalarLerp\"\n    ]\n  bind_python: True\n\n- name: \"lerp_\"\n  signature: \n    [\n      \"Tensor (Tensor start, Tensor end, Tensor weight) => InplaceLerp\",\n      \"Tensor (Tensor start, Tensor end, Scalar weight) => ScalarInplaceLerp\",\n    ]\n  bind_python: True\n\n- name: \"lerp_grad\"\n  signature: \"TensorTuple (Tensor start, Tensor end, Tensor 
weight, Tensor out_diff) => LerpGrad\"\n  bind_python: False\n\n- name: \"scalar_lerp_grad\"\n  signature: \"TensorTuple (Tensor start, Tensor end, Tensor out_diff, Scalar weight) => ScalarLerpGrad\"\n  bind_python: False\n\n- name: \"trunc_divide\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => TruncDiv\",\n      \"Tensor (Tensor input, Scalar other, *, Bool inplace=False) => ScalarTruncDiv\",\n    ]\n  bind_python: True\n\n- name: \"truncdiv_x_grad\"\n  signature: \"Tensor (Tensor dz, Tensor x, Tensor y) => TruncDivXGrad\"\n  bind_python: False\n\n- name: \"truncdiv_y_grad\"\n  signature: \"Tensor (Tensor dz, Tensor x, Tensor y) => TruncDivYGrad\"\n  bind_python: False\n\n- name: \"xdivy_x_grad\"\n  signature: \"Tensor (Tensor dz, Tensor x, Tensor y) => XdivyXGrad\"\n  bind_python: False\n\n- name: \"xdivy_y_grad\"\n  signature: \"Tensor (Tensor dz, Tensor x, Tensor y) => XdivyYGrad\"\n  bind_python: False\n\n- name: \"xlogy_x_grad\"\n  signature: \"Tensor (Tensor dz, Tensor x, Tensor y) => XlogyXGrad\"\n  bind_python: False\n\n- name: \"xlogy_y_grad\"\n  signature: \"Tensor (Tensor dz, Tensor x, Tensor y) => XlogyYGrad\"\n  bind_python: False\n\n- name: \"max\"\n  signature:\n    [\n      \"Tensor (Tensor input) => Max\",\n      \"Tensor (Tensor input, Tensor other) => Max\",\n      \"TensorTuple[values, indices] (Tensor input, Int32 dim, Bool keepdim=False) => Max\",\n    ]\n  bind_python: True\n\n- name: \"min\"\n  signature:\n    [\n      \"Tensor (Tensor input) => Min\",\n      \"TensorTuple[values, indices] (Tensor input, Int32 dim, Bool keepdim=False) => Min\",\n      \"Tensor (Tensor input, Tensor other) => Min\",\n    ]\n  bind_python: True\n\n- name: \"median\"\n  signature:\n    [\n      \"Tensor (Tensor input) => Median\",\n      \"TensorTuple[values, indices] (Tensor input, Int32 dim=-1, Bool keepdim=False) => MedianWithIndices\",\n    ]\n  bind_python: True\n\n- name: \"reduce_max\"\n  signature: \"Tensor (Tensor x, Int32List axis, Bool keepdim=False) => ReduceMax\"\n  bind_python: True\n\n- name: \"reduce_min\"\n  signature: \"Tensor (Tensor x, Int32List axis, Bool keepdim=False) => ReduceMin\"\n  bind_python: True\n\n- name: \"reduce_sum\"\n  signature:\n    [\n      \"Tensor (Tensor x, Int32List[1] dim, Bool keepdim=False, *, DataType dtype=None) => ReduceSum\",\n      \"Tensor (Tensor x, *, DataType dtype=None) => ReduceSumWhole\",\n    ]\n  bind_python: True\n\n- name: \"reduce_nansum\"\n  signature: [\n    \"Tensor (Tensor input, Int32List[1] dim, Bool keepdim=False, *, DataType dtype=None) => ReduceNanSum\",\n    \"Tensor (Tensor input, *, DataType dtype=None) => ReduceNanSumWhole\"\n  ]\n  bind_python: True\n\n- name: \"reduce_mean\"\n  signature:\n    [\n      \"Tensor (Tensor x, Int32List[1] dim, Bool keepdim=False) => ReduceMean\",\n      \"Tensor (Tensor x) => ReduceMeanWhole\",\n    ]\n  bind_python: True\n\n- name: \"reduce_all\"\n  signature:\n    [\n      \"Tensor (Tensor x, Int32List[1] dim, Bool keepdim=False) => ReduceAll\",\n      \"Tensor (Tensor x) => ReduceAllWhole\",\n    ]\n  bind_python: True\n\n- name: \"reduce_any\"\n  signature:\n    [\n      \"Tensor (Tensor x, Int32List[1] dim, Bool keepdim=False) => ReduceAny\",\n      \"Tensor (Tensor x) => ReduceAnyWhole\",\n    ]\n  bind_python: True\n\n- name: \"reduce_prod\"\n  signature:\n    [\n      \"Tensor (Tensor x, Int32List[1] dim, Bool keepdim=False, *, DataType dtype=None) => ReduceProd\",\n      \"Tensor (Tensor x, *, DataType dtype=None) => ReduceProdWhole\",\n    ]\n  
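# The reduce_* entries pair a dim-wise overload with a *Whole overload reducing\n  # over all elements; Int32List[1] presumably also accepts a bare int promoted to\n  # a one-element list. Hedged sketch (assuming the generated oneflow._C bindings):\n  #   flow._C.reduce_prod(x, [1], True)  # reduce over dim 1, keepdim\n  #   flow._C.reduce_prod(x)             # -> ReduceProdWhole\n  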
bind_python: True\n\n- name: \"reduce_min_device_stage\"\n  signature: \"TensorTuple (Tensor in, Int32List axis) => ReduceMinDeviceStage\"\n  bind_python: True\n\n- name: \"reduce_min_device_stage_grad\"\n  signature: \"Tensor (Tensor out_diff, Tensor mask, Tensor count, Int32List axis) => ReduceMinDeviceStageGrad\"\n  bind_python: False\n\n- name: \"reduce_max_device_stage\"\n  signature: \"TensorTuple (Tensor in, Int32List axis) => ReduceMaxDeviceStage\"\n  bind_python: True\n\n- name: \"reduce_max_device_stage_grad\"\n  signature: \"Tensor (Tensor out_diff, Tensor mask, Tensor count, Int32List axis) => ReduceMaxDeviceStageGrad\"\n  bind_python: False\n\n- name: \"reduce_min_global_stage\"\n  signature: \"TensorTuple (Tensor in, Tensor device_count, Int32List axis, Bool keepdims=False) => ReduceMinGlobalStage\"\n  bind_python: True\n\n- name: \"reduce_min_global_stage_grad\"\n  signature: \"Tensor (Tensor out_diff, Tensor mask, Tensor device_count, Int32List axis, Bool keepdims=False) => ReduceMinGlobalStageGrad\"\n  bind_python: False\n\n- name: \"reduce_max_global_stage\"\n  signature: \"TensorTuple (Tensor in, Tensor device_count, Int32List axis, Bool keepdims=False) => ReduceMaxGlobalStage\"\n  bind_python: True\n\n- name: \"reduce_max_global_stage_grad\"\n  signature: \"Tensor (Tensor out_diff, Tensor mask, Tensor device_count, Int32List axis, Bool keepdims=False) => ReduceMaxGlobalStageGrad\"\n  bind_python: False\n\n- name: \"logsumexp\"\n  signature: \"Tensor (Tensor x, Int32List[1] dim, Bool keepdim=False) => LogSumExp\"\n  bind_python: True\n\n- name: \"logaddexp\"\n  signature: \"Tensor (Tensor x, Tensor y) => LogAddExp\"\n  bind_python: True\n\n- name: \"quantile\"\n  signature: \n    [\n      'Tensor (Tensor input, Tensor q, Int64 dim=None, Bool keepdim=False, String interpolation=\"linear\", Bool ignore_nan=False) => Quantile',\n      'Tensor (Tensor input, Scalar q, Int64 dim=None, Bool keepdim=False, String interpolation=\"linear\", Bool ignore_nan=False) ==> ScalarQuantile'\n    ]\n  bind_python: True\n\n- name: \"transpose\"\n  signature:\n    [\n      \"Tensor (Tensor input, Int32List perm) => Transpose\",\n      \"Tensor (Tensor input, Int32 dim0, Int32 dim1) => Transpose2dim\",\n    ]\n  bind_python: True\n\n- name: \"as_strided\"\n  signature: \"Tensor (Tensor input, Int64List size, Int64List stride, Int64 storage_offset=0) => AsStrided\"\n  bind_python: True\n\n- name: \"as_strided_grad\"\n  signature: \"Tensor (Tensor dy, Tensor input, Int64List size, Int64List stride, Int64 storage_offset=0) => AsStridedGrad\"\n  bind_python: False\n\n- name: \"as_strided_\"\n  signature: \"Tensor (Tensor input, Int64List size, Int64List stride, Int64 storage_offset=0) => InplaceAsStrided\"\n  bind_python: True\n\n- name: \"select\"\n  signature: \"Tensor (Tensor input, Int32 dim, Int32 index) => Select\"\n  bind_python: True\n\n- name: \"swapaxes\"\n  signature: \"Tensor (Tensor input, Int32 dim0, Int32 dim1) => Swapaxes\"\n  bind_python: True\n\n- name: \"swapdims\"\n  signature: \"Tensor (Tensor input, Int32 dim0, Int32 dim1) => Swapdims\"\n  bind_python: True\n\n- name: \"amax\"\n  signature: \"Tensor (Tensor input, Int32List[1] dim=None, Bool keepdim=False) => Amax\"\n  bind_python: True\n\n- name: \"permute\"\n  signature: \"Tensor (Tensor input, Int32List dims) => Permute\"\n  bind_python: True\n\n- name: \"T\"\n  signature: \"Tensor (Tensor input) => TransposeAllDimProperty\"\n  bind_python: True\n\n- name: \"t\"\n  signature: \"Tensor (Tensor input) => 
TransposeAllDimFunction\"\n  bind_python: True\n\n- name: \"not_equal_zero\"\n  signature: \"Tensor (Tensor x) => NotEqualZero\"\n  bind_python: False\n\n- name: \"not_equal_zero_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => NotEqualZeroGrad\"\n  bind_python: False\n\n- name: \"reciprocal\"\n  signature: \"Tensor (Tensor x) => Reciprocal\"\n  bind_python: True\n\n- name: \"reciprocal_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => ReciprocalGrad\"\n  bind_python: False\n\n- name: \"reciprocal_no_nan\"\n  signature: \"Tensor (Tensor x) => ReciprocalNoNan\"\n  bind_python: True\n\n- name: \"reciprocal_no_nan_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => ReciprocalNoNanGrad\"\n  bind_python: False\n\n- name: \"image_flip\"\n  signature: \"Tensor (Tensor x, Tensor flip_code) => ImageFlip\"\n  bind_python: True\n\n- name: \"sin\"\n  signature: \"Tensor (Tensor x) => Sin\"\n  bind_python: True\n\n- name: \"sin_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => SinGrad\"\n  bind_python: False\n\n- name: \"sin_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) => SinGradGrad\"\n  bind_python: False\n\n- name: \"sin_\"\n  signature: \"Tensor (Tensor x) => Sin_\"\n  bind_python: True\n\n- name: \"cos\"\n  signature: \"Tensor (Tensor x) => Cos\"\n  bind_python: True\n\n- name: \"cos_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => CosGrad\"\n  bind_python: False\n\n- name: \"cos_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) => CosGradGrad\"\n  bind_python: False\n\n- name: \"cosh\"\n  signature: \"Tensor (Tensor x) => Cosh\"\n  bind_python: True\n\n- name: \"cosh_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => CoshGrad\"\n  bind_python: True\n\n- name: \"fmod\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor other) => BroadcastFMod\",\n      \"Tensor (Tensor input, Scalar other, *, Bool inplace=False) => ScalarFMod\",\n      \"Tensor (Tensor input, Scalar other) => ScalarFMod\",\n    ]\n  bind_python: true\n\n- name: \"log\"\n  signature: \"Tensor (Tensor x) => Log\"\n  bind_python: True\n\n- name: \"log_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => LogGrad\"\n  bind_python: False\n\n- name: \"log2\"\n  signature: \"Tensor (Tensor x) => Log2\"\n  bind_python: True\n\n- name: \"log2_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => Log2Grad\"\n  bind_python: False\n\n- name: \"log10\"\n  signature: \"Tensor (Tensor x) => Log10\"\n  bind_python: True\n\n- name: \"log10_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => Log10Grad\"\n  bind_python: False\n\n- name: \"sqrt\"\n  signature: \"Tensor (Tensor x) => Sqrt\"\n  bind_python: True\n\n- name: \"sqrt_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => SqrtGrad\"\n  bind_python: False\n\n- name: \"rsqrt\"\n  signature: \"Tensor (Tensor x) => Rsqrt\"\n  bind_python: True\n\n- name: \"rsqrt_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => RsqrtGrad\"\n  bind_python: False\n\n- name: \"square\"\n  signature: \"Tensor (Tensor x) => Square\"\n  bind_python: True\n\n- name: \"square_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => SquareGrad\"\n  bind_python: False\n\n- name: \"sqrt_square_sum\"\n  signature: \"Tensor (Tensor x) => SqrtSquareSum\"\n  bind_python: True\n\n- name: \"std\"\n  signature: \"Tensor (Tensor x, Int32List[1] dim=None, Bool unbiased=None, Bool keepdim=None) => StandardDeviation\"\n  bind_python: True\n\n- name: \"var\"\n  signature: \"Tensor (Tensor x, Int32List[1] dim=None, Bool unbiased=None, Bool keepdim=None) => 
Variance\"\n  bind_python: True\n\n- name: \"rms_layer_norm\"\n  signature: \"Tensor (Tensor hidden_states, Tensor weight, Float variance_epsilon) => RMSLayerNormalization\"\n  bind_python: True\n\n- name: \"relu\"\n  signature: \"Tensor (Tensor x, Bool inplace=False) => Relu\"\n  bind_python: True\n\n- name: \"relu_grad\"\n  signature: \"Tensor (Tensor dy, Tensor y) => ReluGrad\"\n  bind_python: False\n\n- name: \"hann_window\"\n  signature: [\n      \"Tensor (Int64 window_length, Bool periodic=True, *, Device device=None, DataType dtype=None,\n      Bool requires_grad=False) => HannWindow\",\n      \"Tensor (Int64 window_length, Bool periodic=True, *, Placement placement, SbpList sbp, DataType dtype=None,\n      Bool requires_grad=False) => GlobalHannWindow\",\n    ]\n  bind_python: True\n\n- name: \"hardtanh\"\n  signature: \"Tensor (Tensor x, Double min_val, Double max_val) => HardTanh\"\n  bind_python: True\n\n- name: \"hardtanh_grad\"\n  signature: \"Tensor (Tensor y, Tensor dy, Double min_val, Double max_val) => HardTanhGrad\"\n  bind_python: False\n\n- name: \"tan\"\n  signature: \"Tensor (Tensor x) => Tan\"\n  bind_python: True\n\n- name: \"tan_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => TanGrad\"\n  bind_python: True\n\n- name: \"tanh\"\n  signature: \"Tensor (Tensor x) => Tanh\"\n  bind_python: True\n\n- name: \"tanh_grad\"\n  signature: \"Tensor (Tensor y, Tensor dy) => TanhGrad\"\n  bind_python: True\n\n- name: \"threshold\"\n  signature: \"Tensor (Tensor x, *, Double threshold, Double value) => Threshold\"\n  bind_python: True\n\n- name: \"threshold_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy, Double threshold) => ThresholdGrad\"\n  bind_python: False\n\n- name: \"elu\"\n  signature: \"Tensor (Tensor x, Double alpha) => Elu\"\n  bind_python: True\n\n- name: \"elu_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy, Double alpha) => EluGrad\"\n  bind_python: False\n\n- name: \"celu\"\n  signature: \"Tensor (Tensor x, *, Double alpha=1.0, Bool inplace=False) => Celu\"\n  bind_python: True\n\n- name: \"celu_grad\"\n  signature: \"Tensor (Tensor y, Tensor dy, Double alpha=1.0) => CeluGrad\"\n  bind_python: False\n\n- name: \"gelu\"\n  signature: \"Tensor (Tensor x) => Gelu\"\n  bind_python: True\n\n- name: \"gelu_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x) => GeluGrad\"\n  bind_python: False\n\n- name: \"fast_gelu\"\n  signature: \"Tensor (Tensor x) => FastGelu\"\n  bind_python: True\n\n- name: \"fast_gelu_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x) => FastGeluGrad\"\n  bind_python: False\n\n- name: \"quick_gelu\"\n  signature: \"Tensor (Tensor x) => QuickGelu\"\n  bind_python: True\n\n- name: \"quick_gelu_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x) => QuickGeluGrad\"\n  bind_python: False\n\n- name: \"square_relu\"\n  signature: \"Tensor (Tensor x) => SquareReLU\"\n  bind_python: True\n\n- name: \"square_relu_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x) => SquareReLUGrad\"\n  bind_python: False\n\n- name: \"gelu_with_approximate\"\n  signature: 'Tensor (Tensor x, String approximate=\"none\") => GeluWithApproximate'\n  bind_python: True\n\n- name: \"glu\"\n  signature: \"Tensor (Tensor input, Int64 dim=-1) => Glu\"\n  bind_python: True\n\n- name: \"fused_glu\"\n  signature: \"Tensor (Tensor x, Tensor w, Tensor b=None, Tensor v=None, Tensor c=None, String activation=\\\"none\\\") => FusedGlu\"\n  bind_python: True\n\n- name: \"fused_glu_without_linear_grad\"\n  signature: \"TensorTuple (Tensor dy, Tensor matmul_wx, Tensor 
matmul_vx=None, String activation=\\\"none\\\") => FusedGluWithoutLinearGrad\"\n  bind_python: False\n\n- name: \"sigmoid\"\n  signature: \"Tensor (Tensor x) => Sigmoid\"\n  bind_python: True\n\n- name: \"sigmoid_grad\"\n  signature: \"Tensor (Tensor y, Tensor dy) => SigmoidGrad\"\n  bind_python: True\n\n- name: \"hardsigmoid\"\n  signature: \"Tensor (Tensor input, Bool inplace=False, *) => HardSigmoid\"\n  bind_python: True\n\n- name: \"hardsigmoid_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x) => HardSigmoidGrad\"\n  bind_python: False\n\n- name: \"hardshrink\"\n  signature: \"Tensor (Tensor x, *, Double lambd=0.5, Bool inplace=False) => HardShrink\"\n  bind_python: True\n\n- name: \"hardshrink_grad\"\n  signature: \"Tensor (Tensor y, Tensor dy, Double lambd=0.5) => HardShrinkGrad\"\n  bind_python: False\n\n- name: \"softmax\"\n  signature: \"Tensor (Tensor x, Int64 dim=None) => Softmax\"\n  bind_python: True\n\n- name: \"softmax_grad\"\n  signature: \"Tensor (Tensor dy, Tensor y) => SoftmaxGrad\"\n  bind_python: False\n\n- name: \"gumbel_softmax\"\n  signature: \"Tensor (Tensor x, Double tau=1., Int64 dim=None, Bool hard=False, Generator generator=None) => GumbelSoftmax\"\n  bind_python: True\n\n- name: \"log_softmax\"\n  signature: \"Tensor (Tensor x, Int64 dim=None) => LogSoftmax\"\n  bind_python: True\n\n- name: \"log_softmax_grad\"\n  signature: \"Tensor (Tensor dy, Tensor y) => LogSoftmaxGrad\"\n  bind_python: False\n\n- name: \"hardswish\"\n  signature: \"Tensor (Tensor x) => HardSwish\"\n  bind_python: True\n\n- name: \"hardswish_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x) => HardSwishGrad\"\n  bind_python: False\n\n- name: \"leaky_relu\"\n  signature: \"Tensor (Tensor x, Float alpha, Bool inplace=False) => LeakyRelu\"\n  bind_python: True\n\n- name: \"leaky_relu_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy, Float alpha) => LeakyReluGrad\"\n  bind_python: False\n\n- name: \"rrelu\"\n  signature: \"Tensor (Tensor x, Float lower=0.125, Float upper=0.3333333333333333, Bool training=False, Bool inplace=False) => RRelu\"\n  bind_python: True\n\n- name: \"rrelu_\"\n  signature: \"Tensor (Tensor x, Float lower=0.125, Float upper=0.3333333333333333, Bool training=False) => RReluInplace\"\n  bind_python: True\n\n- name: \"normal_\"\n  signature: \"Tensor (Tensor x, Float mean=0.0, Float std=1.0, Generator generator=None) => Normal_\"\n  bind_python: True\n  \n- name: \"normal\"\n  signature: [\n      \"Tensor (Tensor mean, Tensor std, *, Tensor out=None, \n      Generator generator=None, Bool requires_grad=False) => TensorTensorNormal\",\n      \"Tensor (Tensor mean, Float std=1.0, *, Tensor out=None, \n      Generator generator=None, Bool requires_grad=False) => TensorScalarNormal\",\n      \"Tensor (Float mean, Tensor std, *, Tensor out=None, \n      Generator generator=None, Bool requires_grad=False) => ScalarTensorNormal\",\n      \"Tensor (Float mean, Float std, Shape size, *, Tensor out=None, DataType dtype=None, Device device=None,\n      Generator generator=None, Bool requires_grad=False) => Normal\",\n      \"Tensor (Float mean, Float std, Int32 size, *, Tensor out=None, DataType dtype=None, Device device=None,\n      Generator generator=None, Bool requires_grad=False) => Normal2\",\n      \"Tensor (Float mean, Float std, Shape size, *, Tensor out=None, Placement placement, SbpList sbp, DataType dtype=None,\n      Generator generator=None, Bool requires_grad=False) => GlobalNormal\",\n      \"Tensor (Float mean, Float std, Int32 size, *, Tensor out=None, 
Placement placement, SbpList sbp, DataType dtype=None,\n      Generator generator=None, Bool requires_grad=False) => GlobalNormal2\",\n    ]\n  bind_python: True\n\n- name: \"normalization\"\n  signature:\n    \"Tensor (Tensor x, Tensor moving_mean=None, Tensor moving_variance=None,\n    Tensor gamma=None, Tensor beta=None, Int32 axis=1, Float epsilon=1e-5,\n    Float momentum=0.9, Bool is_training=False) => Normalization\"\n\n  bind_python: True\n\n- name: \"normalization_grad\"\n  signature:\n    \"TensorTuple (Tensor grad, Tensor x, Tensor mean, Tensor inv_variance,\n    Tensor gamma, Float epsilon, Int32 axis) => NormalizationGrad\"\n  bind_python: False\n\n- name: \"normalization_add_relu\"\n  signature:\n    \"Tensor (Tensor x, Tensor addend=None, Tensor moving_mean=None, Tensor moving_variance=None,\n    Tensor gamma, Tensor beta, Int32 axis=1, Float epsilon=1e-5,\n    Float momentum=0.9, Bool is_training=False) => NormalizationAddRelu\"\n  bind_python: True\n\n- name: \"normalization_add_relu_grad\"\n  signature:\n    \"TensorTuple (Tensor x, Tensor dy, Tensor moving_mean, Tensor moving_variance,\n    Tensor gamma, Tensor beta, Tensor reserve_space, Tensor y, Int32 axis=1,\n    Float epsilon=1e-5, Bool has_addend) => NormalizationAddReluGrad\"\n  bind_python: False\n\n- name: \"eye\"\n  signature:\n    [\n      \"Tensor (Scalar n, Scalar m=None, *, DataType dtype=kFloat, Device device=None, Bool requires_grad=False) => Eye\",\n      \"Tensor (Scalar n, Scalar m=None, *, DataType dtype=kFloat, String device, Bool requires_grad=False) => Eye\",\n      \"Tensor (Scalar n, Scalar m=None, *, DataType dtype=kFloat, Bool requires_grad=False, Placement placement, SbpList sbp) => Eye\",\n      \"Tensor (Scalar n, Scalar m=None, *, DataType dtype=kFloat, Bool requires_grad=False, Placement placement, Sbp sbp) => Eye\",\n    ]\n  bind_python: True\n\n- name: \"eye_\"\n  signature: \"Tensor (Tensor x) => EyeInplace\"\n  bind_python: True\n\n- name: \"erfinv\"\n  signature: \"Tensor (Tensor x) => Erfinv\"\n  bind_python: True\n\n- name: \"erfinv_\"\n  signature: \"Tensor (Tensor x) => ErfinvInplace\"\n  bind_python: True\n\n- name: \"arange\"\n  signature: [\n      \"Tensor (Scalar start, Scalar end, Scalar step=1, *, DataType dtype=None,\n      Device device=None) => Arange\",\n      \"Tensor (Scalar end, *, DataType dtype=None, Device device=None) => Arange\",\n    ]\n  bind_python: True\n\n- name: \"global_arange\"\n  signature: [\n      \"Tensor (Scalar start, Scalar end, Scalar step=1, *, DataType dtype=None,\n      Placement placement, SbpList sbp) => GlobalArange\",\n      \"Tensor (Scalar end, *, DataType dtype=None, Placement placement, SbpList sbp) => GlobalArange\",\n    ]\n  bind_python: True\n\n- name: \"flatten\"\n  signature: \"Tensor (Tensor x, Int32 start_dim=0, Int32 end_dim=-1) => Flatten\"\n  bind_python: True\n\n- name: \"argmax\"\n  signature: \"Tensor (Tensor x, Int32 dim=None, Bool keepdim=None, DataType dtype=None) => ArgMax\"\n  bind_python: True\n\n- name: \"argmin\"\n  signature: \"Tensor (Tensor x, Int32 dim=None, Bool keepdim=None, DataType dtype=None) => ArgMin\"\n  bind_python: True\n\n- name: \"argwhere\"\n  signature: \"TensorTuple (Tensor x, DataType dtype=kInt32) => ArgWhere\"\n  bind_python: True\n\n- name: \"nonzero\"\n  signature: \"TensorTuple (Tensor x, Bool as_tuple=False) => NonZero\"\n  bind_python: True\n\n- name: \"broadcast_like\"\n  signature: \"Tensor (Tensor x, Tensor like, Int32List broadcast_axes=[]) => BroadcastLike\"\n  bind_python: 
True\n\n- name: \"cast\"\n  signature: \"Tensor (Tensor x, DataType dtype, Bool pin_memory=False) => Cast\"\n  bind_python: True\n\n- name: \"global_tensor_constant\"\n  signature: \"Tensor (Shape shape, Tensor value, *, DataType dtype, Placement placement, SbpList sbp) => GlobalTensorConstant\"\n  bind_python: True\n\n- name: \"tensor_constant\"\n  signature: \"Tensor (Shape shape, Tensor value, *, DataType dtype, Device device=None) => TensorConstant\"\n  bind_python: True\n\n- name: \"constant\"\n  signature:\n    [\n      \"Tensor (Shape shape, Scalar value, *, DataType dtype, Device device=None) => Constant\",\n    ]\n  bind_python: True\n\n- name: \"global_constant\"\n  signature:\n    [\n      \"Tensor (Shape shape, Scalar value, *, DataType dtype, Placement placement, SbpList sbp) => GlobalConstant\",\n    ]\n  bind_python: True\n\n- name: \"empty\"\n  signature: \"Tensor (Shape shape, *, DataType dtype, Device device=None, Bool requires_grad=False, Bool pin_memory=False) => Empty\"\n  bind_python: True\n\n- name: \"empty_strided\"\n  signature: \"Tensor (Int64List shape, Int64List stride, DataType dtype=None, Device device=None, Bool requires_grad=False, Bool pin_memory=False) => EmptyStrided\"\n  bind_python: True\n\n- name: \"global_empty\"\n  signature:\n    [\n      \"Tensor (Shape shape, *, DataType dtype, Placement placement, SbpList sbp) => GlobalEmpty\",\n    ]\n  bind_python: True\n\n- name: \"zeros_like\"\n  signature: \"Tensor (Tensor x) => ZerosLike\"\n  bind_python: False\n\n- name: \"ones_like\"\n  signature: \"Tensor (Tensor x) => OnesLike\"\n  bind_python: False\n\n- name: \"full_like\"\n  signature: \"Tensor (Tensor x, Scalar fill_value) => FullLike\"\n  bind_python: False\n\n- name: \"bernoulli\"\n  signature:\n    [\n      \"Tensor (Tensor input, *, DataType dtype=kFloat, Generator generator=None, Bool inplace=False) => Bernoulli\",\n      \"Tensor (Tensor input, Double p, *, DataType dtype=kFloat, Generator generator=None, Bool inplace=False) => BernoulliProb\",\n    ]\n  bind_python: True\n\n- name: \"bernoulli_\"\n  signature:\n    [\n      \"Tensor (Tensor input, *, DataType dtype=kFloat, Generator generator=None) => BernoulliInplace\",\n      \"Tensor (Tensor input, Double p, *, DataType dtype=kFloat, Generator generator=None) => BernoulliProbInplace\",\n    ]\n  bind_python: True\n\n- name: \"concat\"\n  signature: \"Tensor (TensorTuple inputs, Int64 dim=0) => Concat\"\n  bind_python: True\n\n- name: \"bias_add\"\n  signature: \"Tensor (Tensor x, Tensor bias, Int32 axis=1) => BiasAdd\"\n  bind_python: True\n\n- name: \"conv1d\"\n  signature:\n    'Tensor (Tensor input, Tensor weight, Tensor bias=None, Int32List[1] stride=1,\n    Int32List[1] padding=0, Int32List[1] dilation=1, Int32 groups=1,\n    String channel_pos=\"channels_first\") => Conv1d'\n  bind_python: True\n\n- name: \"conv2d\"\n  signature:\n    'Tensor (Tensor input, Tensor weight, Tensor bias=None, Int32List[2] stride=1,\n    Int32List[2] padding=0, Int32List[2] dilation=1, Int32 groups=1,\n    String channel_pos=\"channels_first\") => Conv2d'\n  bind_python: True\n\n- name: \"conv3d\"\n  signature:\n    'Tensor (Tensor input, Tensor weight, Tensor bias=None, Int32List[3] stride=1,\n    Int32List[3] padding=0, Int32List[3] dilation=1, Int32 groups=1,\n    String channel_pos=\"channels_first\") => Conv3d'\n  bind_python: True\n\n- name: \"fake_quantization\"\n  signature:\n    \"Tensor (Tensor in, Tensor scale, Tensor zero_point, String quantization_formula,\n    Int32 quantization_bit, String 
quantization_scheme) => FakeQuantization\"\n  bind_python: True\n\n- name: \"quantization\"\n  signature:\n    \"Tensor (Tensor in, Tensor scale, Tensor zero_point, String quantization_formula,\n    Int32 quantization_bit, String quantization_scheme) => Quantization\"\n  bind_python: True\n\n- name: \"min_max_observer\"\n  signature:\n    \"TensorTuple (Tensor in, String quantization_formula, Int32 quantization_bit,\n    String quantization_scheme, Bool per_layer_quantization) => MinMaxObserver\"\n  bind_python: True\n\n- name: \"moving_average_min_max_observer\"\n  signature:\n    \"TensorTuple (Tensor in, Tensor current_train_step, Tensor moving_max, Tensor moving_min,\n    Bool training, Int64 stop_update_after_iters, String quantization_formula,\n    Int32 quantization_bit, String quantization_scheme, Float momentum) => MovingAverageMinMaxObserver\"\n  bind_python: True\n\n- name: \"groupwise_dequantize\"\n  signature:\n    'Tensor (Tensor in, Tensor scale, *, Tensor zero=None, Int32 num_bits=8, Bool symmetric=True, Int64 group_dim=-1, Int64 group_size=-1) => GroupwiseDequantize'\n  bind_python: True\n\n- name: \"fused_linear_with_groupwise_quantized_weight\"\n  signature: 'Tensor (Tensor x, Tensor w, Tensor w_scale, *, Tensor w_zero=None, Tensor b=None, Int32 num_bits=8, Bool symmetric=True, Int64 group_dim=-1, Int64 group_size=-1) => FusedLinearWithGroupwiseQuantizedWeight'\n  bind_python: True\n\n- name: \"conv_data_grad\"\n  signature:\n    'Tensor (Tensor dy, Tensor weight, Tensor x, Int32 num_spatial_dims,\n    Int32List kernel_size, Int32List strides, Int32List padding_before,\n    Int32List dilation_rate, Int32 groups=1,\n    String data_format=\"channels_first\") => ConvDataGrad'\n  bind_python: False\n\n- name: \"conv_filter_grad\"\n  signature:\n    'Tensor (Tensor dy, Tensor x, Int32 num_spatial_dims, Int32List kernel_size,\n    Int32List strides, Int32List padding_before, Int32List dilation_rate,\n    Int32 groups=1, String data_format=\"channels_first\") => ConvFilterGrad'\n  bind_python: False\n\n- name: \"conv_bias_grad\"\n  signature: 'Tensor (Tensor dy, Int32 num_spatial_dims,\n    String data_format=\"channels_first\") => ConvBiasGrad'\n  bind_python: False\n\n- name: \"deconv1d\"\n  signature:\n    'Tensor (Tensor input, Tensor weight, Tensor bias=None, Int32List[1] stride=1,\n    Int32List[1] padding=0, Int32List[1] output_padding=0, Int32 groups=1,\n    Int32List[1] dilation=1, String data_format=\"channels_first\") => Deconv1d'\n  bind_python: True\n\n- name: \"deconv2d\"\n  signature:\n    'Tensor (Tensor input, Tensor weight, Tensor bias=None, Int32List[2] stride=1,\n    Int32List[2] padding=0, Int32List[2] output_padding=0, Int32 groups=1,\n    Int32List[2] dilation=1, String data_format=\"channels_first\") => Deconv2d'\n  bind_python: True\n\n- name: \"deconv3d\"\n  signature:\n    'Tensor (Tensor input, Tensor weight, Tensor bias=None, Int32List[3] stride=1,\n    Int32List[3] padding=0, Int32List[3] output_padding=0, Int32 groups=1,\n    Int32List[3] dilation=1, String data_format=\"channels_first\") => Deconv3d'\n  bind_python: True\n\n- name: \"expand\"\n  signature: \"Tensor (Tensor x, Shape shape) => Expand\"\n  bind_python: True\n\n- name: \"repeat\"\n  signature: \"Tensor (Tensor input, Shape repeat_shape) => Repeat\"\n  bind_python: True\n\n- name: \"repeat_interleave_index\"\n  signature: \"Tensor (Tensor input, Tensor cumsum, Int32 dim) => RepeatInterLeaveIndex\"\n  bind_python: False\n\n- name: \"repeat_interleave\"\n  signature:\n    [\n      
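# Resolved by the type of repeats: an integer count selects RepeatInterLeaveInt,\n      # a per-element count tensor selects RepeatInterLeaveTensor (optionally with a\n      # precomputed output_size). Hedged sketch: flow._C.repeat_interleave(x, 2, 0)\n      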
\"Tensor (Tensor input, Int32 repeats, Int32 dim=None) => RepeatInterLeaveInt\",\n      \"Tensor (Tensor input, Tensor repeats, Int32 dim, Int32 output_size=None) => RepeatInterLeaveTensor\",\n    ]\n  bind_python: True\n\n- name: \"tile\"\n  signature: \"Tensor (Tensor input, Shape dims) => Tile\"\n  bind_python: True\n\n- name: \"roll\"\n  signature: \"Tensor (Tensor x, Int32List[1] shifts, Int32List[1] dims=None) => Roll\"\n  bind_python: True\n\n- name: \"expand_dims\"\n  signature: \"Tensor (Tensor input, Int32 dim) => ExpandDims\"\n  bind_python: True\n\n- name: \"unsqueeze\"\n  signature: \"Tensor (Tensor input, Int32 dim) => Unsqueeze\"\n  bind_python: True\n\n- name: \"unsqueeze_multiple\"\n  signature: \"Tensor (Tensor input, Int32List dim, Int32 dims) => UnsqueezeMultiple\"\n  bind_python: False\n\n- name: \"unsqueeze_\"\n  signature: \"Tensor (Tensor input, Int32 dim) => InplaceUnsqueeze\"\n  bind_python: True\n\n- name: \"squeeze\"\n  signature: \"Tensor (Tensor x, Int32List[1] dim=None) => Squeeze\"\n  bind_python: True\n\n- name: \"squeeze_\"\n  signature: \"Tensor (Tensor x, Int32List[1] dim=None) => InplaceSqueeze\"\n  bind_python: True\n\n- name: \"exp\"\n  signature: \"Tensor (Tensor x) => Exp\"\n  bind_python: True\n\n- name: \"exp2\"\n  signature: \"Tensor (Tensor x) => Exp2\"\n  bind_python: True \n\n- name: \"exp_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => ExpGrad\"\n  bind_python: False\n\n- name: \"exp2_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => Exp2Grad\"\n  bind_python: False\n\n- name: \"gather\"\n  signature: \"Tensor (Tensor x, Tensor indices, Int64 axis) => Gather\"\n  bind_python: True\n\n- name: \"dim_gather\"\n  signature: \" Tensor (Tensor input, Int64 dim, Tensor index, Bool sparse_grad=False) => DimGather\"\n  bind_python: True\n\n- name: \"embedding_renorm_\"\n  signature: \" Tensor (Tensor in, Tensor indices, Double max_norm, Double norm_type) => EmbeddingReNorm\"\n  bind_python: True\n\n- name: \"embedding\"\n  signature: \" Tensor (Tensor weight, Tensor indices, Int64 padding_idx=None, Bool scale_grad_by_freq=False) => Embedding\"\n  bind_python: True\n\n- name: \"embedding_grad\"\n  signature: \" Tensor (Tensor dy, Tensor weight, Tensor indices, Int64 padding_idx, Bool scale_grad_by_freq=False) => EmbeddingGrad\"\n  bind_python: False\n\n- name: \"arg_sort\"\n  signature: \"Tensor (Tensor in, String direction) => ArgSort\"\n  bind_python: True\n\n- name: \"gather_nd\"\n  signature: \"Tensor (Tensor params, Tensor indices) => GatherNd\"\n  bind_python: True\n\n- name: \"scatternd\"\n  signature: \"Tensor (Tensor indices, Tensor updates, Shape shape) => ScatterNd\"\n  bind_python: True\n\n- name: \"tensor_scatter_nd_update\"\n  signature: \"Tensor (Tensor tensor, Tensor indices, Tensor updates, Bool inplace=False) => TensorScatterNdUpdate\"\n  bind_python: True\n\n- name: \"scatterndlike\"\n  signature: \"Tensor (Tensor like, Tensor updates, Tensor indices) => ScatterNdLike\"\n  bind_python: True\n\n- name: \"matmul\"\n  signature:\n    \"Tensor (Tensor input, Tensor other, Bool transpose_a=False, Bool transpose_b=False,\n    Double alpha=1.0) => MatMul\"\n  bind_python: True\n\n- name: \"mm\"\n  signature: \"Tensor (Tensor input, Tensor mat2) => MatMulNoBroadCast\"\n  bind_python: True\n\n- name: \"fused_mlp\"\n  signature: \"Tensor (Tensor x, TensorTuple weights, TensorTuple biases, Bool skip_final_activation) => FusedMLP\"\n  bind_python: True\n\n- name: \"fused_matmul_bias\"\n  signature: \"Tensor (Tensor x, Tensor weight, 
Tensor bias, Tensor _add_to_output=None, Double alpha=1.0, Double beta=1.0) => FusedMatmulBias\"\n  bind_python: True\n\n- name: \"fused_mlp_grad\"\n  signature: \"TensorTuple (Tensor dy, Tensor x, TensorTuple weights, TensorTuple cublas_aux, TensorTuple hidden, FloatList alpha_list) => FusedMLPGrad\"\n  bind_python: False\n\n- name: \"cublas_bias_add_relu_matmul_grad\"\n  signature: \"TensorTuple (Tensor dy, Tensor weight, Tensor aux, Double alpha=1.0) => CublasBiasAddReluMatmulGrad\"\n  bind_python: False\n\n- name: \"cublas_matmul_bias_add_grad\"\n  signature: \"TensorTuple (Tensor dy, Tensor x) => CublasMatmulBiasAddGrad\"\n  bind_python: False\n\n- name: \"fused_matmul_bias_add_relu_dropout\"\n  signature: \"Tensor (Tensor x, TensorTuple weights, TensorTuple biases, Bool skip_final_activation, FloatList dropout_rate_list, Generator generator=None) => FusedMatmulBiasAddReluDropout\"\n  bind_python: True\n\n- name: \"fused_apply_rotary_emb\"\n  signature: 'Tensor (Tensor x, *, Tensor cos=None, Tensor sin=None, Tensor position_ids=None, String x_layout=\"BHMK\", String output_layout=None, String mode=\"plane\", Int64 tensor_index=None, Int64 k_size=None, Float base=1e4, Int64 rotary_size=None) => FusedApplyRotaryEmb'\n  bind_python: True\n\n- name: \"fused_relu_dropout_grad\"\n  signature: \"Tensor (Tensor dy, Tensor mask, Float scale) => FusedReluDropoutGrad\"\n  bind_python: False\n\n- name: \"broadcast_matmul_grad_b\"\n  signature: \"Tensor (Tensor a, Tensor b, Double alpha=1.0) => BroadcastMatmulGradB\"\n  bind_python: False\n\n- name: \"batch_matmul\"\n  signature:\n    \"Tensor (Tensor a, Tensor b, Bool transpose_a=False, Bool transpose_b=False,\n    Double alpha=1.0) => BatchMatMul\"\n  bind_python: True\n\n- name: \"baddbmm\"\n  signature:\n    \"Tensor (Tensor input, Tensor batch1, Tensor batch2, *, Double beta=1.0, Double alpha=1.0) => BaddBmm\"\n  bind_python: True\n\n- name: \"matrix_vector_product\"\n  signature: \"Tensor (Tensor input, Tensor vec) => MatrixVectorProduct\"\n  bind_python: True\n\n- name: \"matrix_vector_product_grad_a\"\n  signature: \"Tensor (Tensor dy, Tensor b) => MatrixVectorProductGradA\"\n  bind_python: False\n\n- name: \"matrix_vector_product_grad_b\"\n  signature: \"Tensor (Tensor dy, Tensor a) => MatrixVectorProductGradB\"\n  bind_python: False\n\n- name: \"vector_matrix_product\"\n  signature: \"Tensor (Tensor vec, Tensor input) => VectorMatrixProduct\"\n  bind_python: False\n\n- name: \"vector_matrix_product_grad_a\"\n  signature: \"Tensor (Tensor dy, Tensor b) => VectorMatrixProductGradA\"\n  bind_python: False\n\n- name: \"vector_matrix_product_grad_b\"\n  signature: \"Tensor (Tensor dy, Tensor a) => VectorMatrixProductGradB\"\n  bind_python: False\n\n- name: \"tensordot\"\n  signature:\n    [\n      \"Tensor (Tensor a, Tensor b, Int32List dims_a, Int32List dims_b) => TensorDot\",\n      \"Tensor (Tensor a, Tensor b, Int32 dims) => TensorDotIntDims\",\n    ]\n  bind_python: True\n\n- name: \"l1_loss\"\n  signature: 'Tensor(Tensor input, Tensor target, String reduction=\"mean\") => L1Loss'\n  bind_python: True\n\n- name: \"mse_loss\"\n  signature: 'Tensor(Tensor input, Tensor target, String reduction=\"mean\") => MseLoss'\n  bind_python: True\n\n- name: \"kl_div_loss\"\n  signature: 'Tensor(Tensor input, Tensor target, Bool log_target=False, String reduction=\"mean\") => KLDivLoss'\n  bind_python: True\n\n- name: \"kl_div_loss_grad\"\n  signature: \"Tensor(Tensor dy, Tensor input, Tensor target, Bool log_target) => KLDivLossGrad\"\n  bind_python: 
False\n\n- name: \"kl_div_loss_target_grad\"\n  signature: \"Tensor(Tensor dy, Tensor input, Tensor target, Bool log_target) => KLDivLossTargetGrad\"\n  bind_python: False\n\n- name: \"nll_loss\"\n  signature: \"Tensor(Tensor input, Tensor target, Tensor weight=None, Int64 ignore_index, String reduction) => NLLLoss\"\n  bind_python: True\n\n- name: \"nll_grad\"\n  signature: \"Tensor(Tensor out_grad, Tensor input, Tensor target, Tensor weight=None, Int64 ignore_index) => NLLGrad\"\n  bind_python: False\n\n- name: \"binary_cross_entropy_loss\"\n  signature: 'Tensor(Tensor input, Tensor target, Tensor weight=None, String reduction=\"mean\") => BinaryCrossEntropyLoss'\n  bind_python: True\n\n- name: \"binary_cross_entropy_loss_grad\"\n  signature: \"Tensor(Tensor dy, Tensor input, Tensor target, Tensor weight=None) => BinaryCrossEntropyLossGrad\"\n  bind_python: False\n\n- name: \"binary_cross_entropy_loss_target_grad\"\n  signature: \"Tensor(Tensor dy, Tensor input, Tensor target, Tensor weight=None) => BinaryCrossEntropyLossTargetGrad\"\n  bind_python: False\n\n- name: \"binary_cross_entropy_with_logits_loss\"\n  signature: 'Tensor(Tensor input, Tensor target, Tensor weight=None, Tensor pos_weight=None, String reduction=\"mean\") => BinaryCrossEntropyWithLogitsLoss'\n  bind_python: True\n\n- name: \"binary_cross_entropy_with_logits_loss_grad\"\n  signature: \"Tensor(Tensor dy, Tensor input, Tensor target, Tensor weight=None, Tensor pos_weight=None) => BinaryCrossEntropyWithLogitsLossGrad\"\n  bind_python: True\n\n- name: \"binary_cross_entropy_with_logits_loss_target_grad\"\n  signature: \"Tensor(Tensor dy, Tensor input, Tensor target, Tensor weight=None, Tensor pos_weight=None) => BinaryCrossEntropyWithLogitsLossTargetGrad\"\n  bind_python: False\n\n- name: \"binary_cross_entropy_with_logits_reduce_mean_loss_grad\"\n  signature: \"Tensor(Tensor dy, Tensor input, Tensor target) => BinaryCrossEntropyWithLogitsReduceMeanLossGrad\"\n  bind_python: False\n\n- name: \"binary_cross_entropy_with_logits_reduce_mean_loss_target_grad\"\n  signature: \"Tensor(Tensor dy, Tensor input, Tensor target) => BinaryCrossEntropyWithLogitsReduceMeanLossTargetGrad\"\n  bind_python: False\n\n- name: \"sparse_cross_entropy\"\n  signature: \"Tensor (Tensor prediction, Tensor label, Int64 depth) => SparseCrossEntropy\"\n  bind_python: True\n\n- name: \"sparse_cross_entropy_grad\"\n  signature: \"Tensor (Tensor prediction, Tensor label, Tensor dy, Int64 depth) => SparseCrossEntropyGrad\"\n  bind_python: False\n\n- name: \"distributed_sparse_cross_entropy\"\n  signature: \"Tensor (Tensor prediction, Tensor label, Int64 depth) => SparseCrossEntropyMs\"\n  bind_python: True\n\n- name: \"cross_entropy\"\n  signature: 'Tensor(Tensor input, Tensor target, Tensor weight=None, Int64 ignore_index=-100, String reduction=\"mean\", Double label_smoothing=0.0) => CrossEntropy'\n  bind_python: True\n\n- name: \"cross_entropy_label_smoothing\"\n  signature: 'Tensor(Tensor input, Tensor target, Tensor weight=None, Int64 ignore_index=-100, String reduction=\"mean\", Double label_smoothing=0.0) => CrossEntropyLabelSmoothing'\n  bind_python: False\n\n- name: \"cross_entropy_prob\"\n  signature: 'Tensor(Tensor input, Tensor target, Tensor weight=None, String reduction=\"mean\", Double label_smoothing=0.0) => CrossEntropyProb'\n  bind_python: False\n\n- name: \"distributed_sparse_cross_entropy_grad\"\n  signature: \"Tensor (Tensor prediction, Tensor label, Tensor dy, Int64 depth) => SparseCrossEntropyMsGrad\"\n  bind_python: False\n\n- 
name: \"sparse_softmax_cross_entropy\"\n  signature: \"Tensor (Tensor logits, Tensor label) => SparseSoftmaxCrossEntropy\"\n  bind_python: True\n\n- name: \"sparse_softmax_cross_entropy_grad\"\n  signature: \"Tensor (Tensor dy, Tensor prob, Tensor label, Int64 depth) => SparseSoftmaxCrossEntropyGrad\"\n  bind_python: False\n\n- name: \"sparse_softmax_cross_entropy_ms_grad\"\n  signature: \"Tensor (Tensor dy, Tensor prob, Tensor label, Int64 depth) => SparseSoftmaxCrossEntropyMsGrad\"\n  bind_python: False\n\n- name: \"softmax_cross_entropy\"\n  signature: \"Tensor (Tensor logits, Tensor label) => SoftmaxCrossEntropy\"\n  bind_python: True\n\n- name: \"softmax_cross_entropy_grad\"\n  signature: \"Tensor (Tensor dy, Tensor label, Tensor prob) => SoftmaxCrossEntropyGrad\"\n  bind_python: True\n\n- name: \"smooth_l1_loss\"\n  signature: \"Tensor (Tensor logits, Tensor label, Float beta, String reduction) => SmoothL1Loss\"\n  bind_python: True\n\n- name: \"smooth_l1_loss_grad\"\n  signature: \"Tensor (Tensor loss_grad, Tensor prediction, Tensor label, Float beta) => SmoothL1LossGrad\"\n  bind_python: False\n\n- name: \"combined_margin_loss\"\n  signature: \"Tensor (Tensor x, Tensor label, Float m1, Float m2, Float m3) => CombinedMarginLoss\"\n  bind_python: True\n\n- name: \"combined_margin_loss_grad\"\n  signature: \"Tensor (Tensor dy, Tensor label, Tensor theta, Float m1, Float m2, Float m3, Int64 depth) => CombinedMarginLossGrad\"\n  bind_python: False\n\n- name: \"triplet_margin_loss\"\n  signature: \"Tensor (Tensor anchor, Tensor positive, Tensor negative, *, Float margin, Float p, Float eps, Bool swap, String reduction) => TripletMarginLoss\"\n  bind_python: True\n\n- name: \"margin_ranking_loss\"\n  signature: \"Tensor (Tensor input_1, Tensor input_2, Tensor target, Float margin, String reduction) => MarginRankingLoss\"\n  bind_python: True\n\n- name: \"ctc_loss\"\n  signature: \"Tensor (Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, Int64 max_target_length, Int64 blank, Bool zero_infinity, String reduction) => CtcLoss\"\n  bind_python: True\n\n- name: \"affine_grid\"\n  signature: \"Tensor (Tensor theta, *, Shape size, Bool align_corners) => AffineGrid\"\n  bind_python: True\n\n- name: \"affine_grid_grad\"\n  signature: \"Tensor (Tensor dgrid, *, Shape size, Bool align_corners) => AffineGridGrad\"\n  bind_python: False\n\n- name: \"grid_sample\"\n  signature: \"Tensor (Tensor input, Tensor grid, *, String interpolation_mode, String padding_mode, Bool align_corners) => GridSample\"\n  bind_python: True\n\n- name: \"grid_sample_grad\"\n  signature: \"TensorTuple (Tensor doutput, Tensor input, Tensor grid, *, String interpolation_mode, String padding_mode, Bool align_corners) => GridSampleGrad\"\n  bind_python: False\n\n- name: \"where\"\n  signature:\n    [\n      \"Tensor (Tensor condition, Tensor x, Tensor y) => Where\",\n      \"Tensor (Tensor condition, Scalar x, Tensor y) => WhereScalarX\",\n      \"Tensor (Tensor condition, Tensor x, Scalar y) => WhereScalarY\",\n      \"Tensor (Tensor condition, Scalar x, Scalar y) => WhereScalarXY\",\n    ]\n  bind_python: true\n\n- name: \"masked_fill\"\n  signature: \"Tensor (Tensor input, Tensor mask, Scalar value) => MaskedFill\"\n  bind_python: true\n\n- name: \"masked_fill_\"\n  signature: \"Tensor (Tensor input, Tensor mask, Scalar value) => MaskedFillInplace\"\n  bind_python: true\n\n- name: \"movedim\"\n  signature:\n    [\n      \"Tensor (Tensor input, Int32 source, Int32 destination) => MovedimInt\",\n    
  \"Tensor (Tensor input, Int32List source, Int32List destination) => MovedimVec\",\n    ]\n  bind_python: True\n\n- name: \"tensor_split\"\n  signature:\n    [\n      \"TensorTuple (Tensor input, Int32 indices_or_sections, Int32 dim=0) => TensorSplitInt\",\n      \"TensorTuple (Tensor input, Int32List indices_or_sections, Int32 dim=0) => TensorSplitVec\",\n    ]\n  bind_python: True\n\n- name: \"hsplit\"\n  signature:\n    [\n      \"TensorTuple (Tensor input, Int32 indices_or_sections) => HsplitInt\",\n      \"TensorTuple (Tensor input, Int32List indices_or_sections) => HsplitVec\",\n    ]\n  bind_python: True\n\n- name: \"vsplit\"\n  signature:\n    [\n      \"TensorTuple (Tensor input, Int32 indices_or_sections) => VsplitInt\",\n      \"TensorTuple (Tensor input, Int32List indices_or_sections) => VsplitVec\",\n    ]\n  bind_python: True\n\n- name: \"negative\"\n  signature: \"Tensor (Tensor x) => Negative\"\n  bind_python: True\n\n- name: \"layer_norm_affine\"\n  signature:\n    \"Tensor (Tensor x, Tensor gamma, Tensor beta, Int64 begin_norm_axis,\n    Int64 begin_params_axis, Double epsilon) => LayerNormAffine\"\n  bind_python: True\n\n- name: \"skip_layer_norm\"\n  signature: \"Tensor (Tensor x, *, Tensor gamma=None, Tensor beta=None, Tensor bias=None, Tensor skip=None, Double epsilon=1e-5, Double alpha=1e1) => SkipLayerNorm\"\n  bind_python: True\n\n- name: \"layer_norm\"\n  signature: \"Tensor (Tensor x, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => LayerNorm\"\n  bind_python: True\n\n- name: \"layer_norm_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_norm_axis, Double epsilon) => LayerNormGrad\"\n  bind_python: False\n\n- name: \"layer_norm_affine_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Double epsilon) => LayerNormAffineGrad\"\n  bind_python: False\n\n- name: \"fuse_layer_norm_grad\"\n  signature: \"TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => FuseLayerNormGrad\"\n  bind_python: False\n\n- name: \"layer_norm_param_grad\"\n  signature: \"TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_params_axis) => LayerNormParamGrad\"\n  bind_python: False\n\n- name: \"rms_norm\"\n  signature: \"Tensor (Tensor x, Tensor weight=None, Shape normalized_shape, Float epsilon=1e-6) => RMSNorm\"\n  bind_python: True\n\n- name: \"rms_norm_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x, Tensor inv_rms, Tensor weight=None, Bool param_grad) => RMSNormGrad\"\n  bind_python: False\n\n- name: \"skip_rms_norm\"\n  signature: \"Tensor (Tensor x, *, Tensor weight=None, Tensor bias=None, Tensor skip=None, Double epsilon=1e-5, Double alpha=1e1) => SkipRMSNorm\"\n  bind_python: True\n\n- name: \"group_norm\"\n  signature:\n    'Tensor (Tensor x, Tensor gamma=None, Tensor beta=None, Bool affine, Int32 num_groups, Double epsilon, String data_format=\"channels_first\", String activation=\"none\") => GroupNorm'\n  bind_python: True\n\n- name: \"group_norm_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma=None, Int32 num_groups, Double epsilon) => GroupNormGrad\"\n  bind_python: False\n\n- name: \"group_norm_param_grad\"\n  signature: \"TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance) => GroupNormParamGrad\"\n  bind_python: False\n\n- name: \"avg_pool2d_nhwc\"\n  
signature:\n    'Tensor (Tensor x, Int32List kernel_size, Int32List stride, String padding,\n    Int32List padding_before, Int32List padding_after,\n    String data_format=\"channels_first\", Bool ceil_mode=False) => TFAvgPool2D'\n  bind_python: True\n\n- name: \"ctc_loss_grad\"\n  signature: \"Tensor (Tensor loss_grad, Tensor log_probs, Tensor targets,\n    Tensor input_lengths, Tensor target_lengths, Tensor loss, Tensor alpha, Int64 blank, Bool zero_infinity, Int64 max_target_length) => CtcLossGrad\"\n  bind_python: False\n\n- name: \"adaptive_avg_pool1d\"\n  signature: 'Tensor (Tensor x, Int64List[1] output_size, String data_format=\"channels_first\") => AdaptiveAvgPool1D'\n  bind_python: True\n\n- name: \"adaptive_avg_pool2d\"\n  signature: 'Tensor (Tensor x, Int64List[2] output_size, String data_format=\"channels_first\") => AdaptiveAvgPool2D'\n  bind_python: True\n\n- name: \"adaptive_avg_pool3d\"\n  signature: 'Tensor (Tensor x, Int64List[3] output_size, String data_format=\"channels_first\") => AdaptiveAvgPool3D'\n  bind_python: True\n\n- name: \"adaptive_pool_grad\"\n  signature: 'Tensor (Tensor x, Tensor dy, String mode, Int32 ndims, String data_format=\"channels_first\") => AdaptivePoolNdGrad'\n  bind_python: False\n\n- name: \"tf_pool_grad\"\n  signature:\n    \"Tensor (Tensor x, Tensor y, Tensor dy, String mode, Int32 ndims, String data_format,\n    String padding, Int32List padding_before, Int32List padding_after, Int32List pool_size,\n    Int32List strides, Bool ceil_mode) => TFPoolNdGrad\"\n  bind_python: False\n\n- name: \"max_pool1d\"\n  signature:\n    'TensorTuple (Tensor input, Int32List[1] kernel_size, Int32List[1] stride=None,\n    Int32List[1] padding=0, Int32List[1] dilation=1,\n    Bool return_indices=True, Bool ceil_mode=False,\n    String data_format=\"channels_first\") => MaxPool1D'\n  bind_python: True\n\n- name: \"max_pool2d\"\n  signature:\n    'TensorTuple (Tensor input, Int32List[2] kernel_size, Int32List[2] stride=None,\n    Int32List[2] padding=0, Int32List[2] dilation=1,\n    Bool return_indices=True, Bool ceil_mode=False,\n    String data_format=\"channels_first\") => MaxPool2D'\n  bind_python: True\n\n- name: \"max_pool3d\"\n  signature:\n    'TensorTuple (Tensor input, Int32List[3] kernel_size, Int32List[3] stride=None,\n    Int32List[3] padding=0, Int32List[3] dilation=1,\n    Bool return_indices=True, Bool ceil_mode=False,\n    String data_format=\"channels_first\") => MaxPool3D'\n  bind_python: True\n\n- name: \"max_pool_grad\"\n  signature: \"Tensor (Tensor x, Tensor indice, Tensor dy, Int32 ndims,\n    String data_format, Int32List padding, Int32List kernel_size,\n    Int32List stride, Int32List dilation, Bool return_indices, Bool ceil_mode) => MaxPoolNdGrad\"\n  bind_python: False\n\n- name: \"max_unpool1d\"\n  signature:\n    'Tensor (Tensor input, Tensor indices, Int32List[1] kernel_size, Int32List[1] stride=None,\n    Int32List[1] padding=0, Shape output_size=None) => MaxUnpool1D'\n  bind_python: True\n\n- name: \"max_unpool2d\"\n  signature:\n    'Tensor (Tensor input, Tensor indices, Int32List[2] kernel_size, Int32List[2] stride=None,\n    Int32List[2] padding=0, Shape output_size=None) => MaxUnpool2D'\n  bind_python: True\n\n- name: \"max_unpool3d\"\n  signature:\n    'Tensor (Tensor input, Tensor indices, Int32List[3] kernel_size, Int32List[3] stride=None,\n    Int32List[3] padding=0, Shape output_size=None) => MaxUnpool3D'\n  bind_python: True\n\n- name: \"max_unpool1d_grad\"\n  signature: \"Tensor (Tensor x, Tensor indice, Tensor dy) => 
MaxUnpool1dGrad\"\n  bind_python: False\n\n- name: \"max_unpool2d_grad\"\n  signature: \"Tensor (Tensor x, Tensor indice, Tensor dy) => MaxUnpool2dGrad\"\n  bind_python: False\n\n- name: \"max_unpool3d_grad\"\n  signature: \"Tensor (Tensor x, Tensor indice, Tensor dy) => MaxUnpool3dGrad\"\n  bind_python: False\n\n- name: \"prelu\"\n  signature: \"Tensor (Tensor x, Tensor alpha) => PRelu\"\n  bind_python: True\n\n- name: \"prelu_grad\"\n  signature: \"TensorTuple (Tensor dy, Tensor x, Tensor alpha) => PReluGrad\"\n  bind_python: False\n\n- name: \"reshape\"\n  signature: \"Tensor (Tensor x, Shape shape) => Reshape\"\n  bind_python: True\n\n- name: \"view\"\n  signature: \"Tensor (Tensor x, Shape shape) => View\"\n  bind_python: True\n\n- name: \"contiguous\"\n  signature: \"Tensor (Tensor input) => ToContiguous\"\n  bind_python: True\n\n- name: \"contiguous_\"\n  signature: \"Tensor (Tensor input) => InplaceToContiguous\"\n  bind_python: True\n\n- name: \"slice_view_1d_contiguous\"\n  signature: \"Tensor (Tensor x, Int64 start, Int64 end) => SliceView1dContiguous\"\n  bind_python: True\n\n- name: \"narrow\"\n  signature: \"Tensor (Tensor input, Int64 dim, Int64 start, Int64 length) => Narrow\"\n  bind_python: True\n\n- name: \"narrow_grad\"\n  signature: \"Tensor (Tensor dy, Tensor like, Int64 dim, Int64 start, Int64 length) => NarrowGrad\"\n  bind_python: False\n\n- name: \"slice\"\n  signature: \"Tensor (Tensor x, Int64List start, Int64List stop, Int64List step, Bool enable_view_slice=None) => Slice\"\n  bind_python: True\n\n- name: \"slice_update\"\n  signature: \"Tensor (Tensor ref, Tensor value, Int64List start, Int64List stop, Int64List step, Bool inplace=False) => SliceUpdate\"\n  bind_python: True\n\n- name: \"slice_grad\"\n  signature: \"Tensor (Tensor dy, Shape like_shape, Int64List start, Int64List stop, Int64List step) => SliceGrad\"\n  bind_python: False\n\n- name: \"copy\"\n  signature: [\n      \"Tensor (Tensor x, String device_type, Int64 device_id, Bool pin_memory=False) => Copy\",\n      \"Tensor (Tensor x, Device device, Bool pin_memory=False) => Copy\"\n  ]\n  bind_python: True\n\n- name: \"to\"\n  signature: [\n      # type of device must be string for global tensor to perform argument validation\n      \"Tensor (Tensor x, String device=None, DataType dtype=None, Bool copy=False) => To\",\n      \"Tensor (Tensor x, Device device=None, DataType dtype=None, Bool copy=False) => To\",\n      \"Tensor (Tensor x, DataType dtype=None, Bool copy=False) => To\",\n      \"Tensor (Tensor x, Tensor other, Bool copy=False) => To\",\n      \"Tensor (Tensor x, String device=None) => To\",\n      \"Tensor (Tensor x, *, MemoryFormat memory_format) => To\",\n    ]\n  bind_python: True\n\n- name: \"flip\"\n  signature: \"Tensor (Tensor x, Int32List[1] dims) => Flip\"\n  bind_python: True\n\n- name: \"upsample\"\n  signature:\n    'Tensor (Tensor x, Double height_scale, Double width_scale, Bool align_corners,\n    String interpolation, String data_format=\"channels_first\") => Upsample'\n  bind_python: True\n\n- name: \"upsample_grad\"\n  signature:\n    \"Tensor (Tensor dy, Tensor x, Double height_scale, Double width_scale, Bool align_corners,\n    String data_format, String interpolation) => UpsampleGrad\"\n  bind_python: False\n\n- name: \"upsample_linear_1d\"\n  signature:\n    'Tensor (Tensor x, Double scale_factor=0.0, Bool align_corners=False, Int64List[1] output_size=None,\n    String data_format=\"channels_first\") => UpsampleLinear1D'\n  bind_python: True\n\n- name: 
\"upsample_linear_1d_grad\"\n  signature:\n    'Tensor (Tensor dy, Tensor x, Double scale_factor=0.0, Bool align_corners=False, Int64List[1] output_size=None,\n    String data_format=\"channels_first\") => UpsampleLinear1DGrad'\n  bind_python: False\n\n- name: \"upsample_nearest_1d\"\n  signature:\n    'Tensor (Tensor x, Double scale_factor=0.0, Int64List[1] output_size=None,\n    String data_format=\"channels_first\") => UpsampleNearest1D'\n  bind_python: True\n\n- name: \"upsample_nearest_1d_grad\"\n  signature:\n    'Tensor (Tensor dy, Tensor x, Double scale_factor=0.0, Int64List[1] output_size=None,\n    String data_format=\"channels_first\") => UpsampleNearest1DGrad'\n  bind_python: False\n\n- name: \"upsample_nearest_2d\"\n  signature:\n    'Tensor (Tensor x, Double height_scale=0.0, Double width_scale=0.0, Int64List[2] output_size=None,\n    String data_format=\"channels_first\") => UpsampleNearest2D'\n  bind_python: True\n\n- name: \"upsample_nearest_2d_grad\"\n  signature:\n    'Tensor (Tensor dy, Tensor x, Double height_scale=0.0, Double width_scale=0.0, Int64List[2] output_size=None,\n    String data_format=\"channels_first\") => UpsampleNearest2DGrad'\n  bind_python: False\n\n- name: \"upsample_bilinear_2d\"\n  signature:\n    'Tensor (Tensor x, Double height_scale=0.0, Double width_scale=0.0, Bool align_corners=False, Int64List[2] output_size=None,\n    String data_format=\"channels_first\") => UpsampleBilinear2D'\n  bind_python: True\n\n- name: \"upsample_bilinear_2d_grad\"\n  signature:\n    'Tensor (Tensor dy, Tensor x, Double height_scale=0.0, Double width_scale=0.0, Bool align_corners=False, Int64List[2] output_size=None,\n    String data_format=\"channels_first\") => UpsampleBilinear2DGrad'\n  bind_python: False\n\n- name: \"upsample_bicubic_2d\"\n  signature:\n    'Tensor (Tensor x, Double height_scale=0.0, Double width_scale=0.0, Bool align_corners=False, Int64List[2] output_size=None,\n    String data_format=\"channels_first\") => UpsampleBicubic2D'\n  bind_python: True\n\n- name: \"upsample_bicubic_2d_grad\"\n  signature:\n    'Tensor (Tensor dy, Tensor x, Double height_scale=0.0, Double width_scale=0.0, Bool align_corners=False, Int64List[2] output_size=None,\n    String data_format=\"channels_first\") => UpsampleBicubic2DGrad'\n  bind_python: False\n\n- name: \"upsample_nearest_3d\"\n  signature:\n    'Tensor (Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Int64List[3] output_size=None,\n    String data_format=\"channels_first\") => UpsampleNearest3D'\n  bind_python: True\n\n- name: \"upsample_nearest_3d_grad\"\n  signature:\n    'Tensor (Tensor dy, Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Int64List[3] output_size=None,\n    String data_format=\"channels_first\") => UpsampleNearest3DGrad'\n  bind_python: False\n\n- name: \"upsample_trilinear_3d\"\n  signature:\n    'Tensor (Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Bool align_corners=False,\n    Int64List[3] output_size=None, String data_format=\"channels_first\") => UpsampleTrilinear3D'\n  bind_python: True\n\n- name: \"upsample_trilinear_3d_grad\"\n  signature:\n    'Tensor (Tensor dy, Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0,\n    Bool align_corners=False, Int64List[3] output_size=None, String data_format=\"channels_first\") => UpsampleTrilinear3DGrad'\n  bind_python: False\n\n- name: \"fused_get_boundding_boxes_coord\"\n  signature: \"TensorTuple (Tensor 
x1, Tensor y1, Tensor w1, Tensor h1, Tensor x2, Tensor y2, Tensor w2, Tensor h2) => FusedGetBounddingBoxesCoord\"\n  bind_python: True\n\n- name: \"fused_get_boundding_boxes_coord_grad\"\n  signature: \"TensorTuple (Tensor b1_x1_diff, Tensor b1_x2_diff, Tensor b1_y1_diff, Tensor b1_y2_diff, Tensor b2_x1_diff, Tensor b2_x2_diff, Tensor b2_y1_diff, Tensor b2_y2_diff) => FusedGetBounddingBoxesCoordGrad\"\n  bind_python: False\n\n- name: \"fused_get_ciou_result\"\n  signature: \"TensorTuple (Tensor v, Tensor iou, Tensor rho2, Tensor c2, Float eps) => FusedGetCiouResult\"\n  bind_python: True\n\n- name: \"fused_get_ciou_result_grad\"\n  signature: \"TensorTuple (Tensor dy, Tensor alpha, Tensor rho2, Tensor c2) => FusedGetCiouResultGrad\"\n  bind_python: False\n\n- name: \"fused_codegeex_qkv_reshape\"\n  signature: \"TensorTuple (Tensor query, Tensor key, Tensor value, Int32 num_attention_heads) => FusedCodegeexQkvReshape\"\n  bind_python: True\n\n- name: \"fused_get_iou\"\n  signature: \"Tensor (Tensor w1, Tensor h1, Tensor w2, Tensor h2, Tensor inter, Float eps) => FusedGetIou\"\n  bind_python: True\n\n- name: \"fused_get_iou_grad\"\n  signature: \"TensorTuple (Tensor diou, Tensor w1, Tensor h1, Tensor w2, Tensor h2, Tensor inter, Float eps) => FusedGetIouGrad\"\n  bind_python: False\n\n- name: \"abs\"\n  signature: \"Tensor (Tensor x) => Abs\"\n  bind_python: True\n\n- name: \"abs_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => AbsGrad\"\n  bind_python: False\n\n- name: \"acos\"\n  signature: \"Tensor (Tensor x) => Acos\"\n  bind_python: True\n\n- name: \"acos_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => AcosGrad\"\n  bind_python: False\n\n- name: \"acosh\"\n  signature: \"Tensor (Tensor x) => Acosh\"\n  bind_python: True\n\n- name: \"acosh_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => AcoshGrad\"\n  bind_python: False\n\n- name: \"asin\"\n  signature: \"Tensor (Tensor x) => Asin\"\n  bind_python: True\n\n- name: \"asin_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => AsinGrad\"\n  bind_python: False\n\n- name: \"asinh\"\n  signature: \"Tensor (Tensor x) => Asinh\"\n  bind_python: True\n\n- name: \"asinh_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => AsinhGrad\"\n  bind_python: False\n\n- name: \"atan\"\n  signature: \"Tensor (Tensor x) => Atan\"\n  bind_python: True\n\n- name: \"atan_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => AtanGrad\"\n  bind_python: False\n\n- name: \"atan2\"\n  signature: \"Tensor (Tensor input, Tensor other) => Atan2\"\n  bind_python: True\n\n- name: \"atan2_x_grad\"\n  signature: \"Tensor (Tensor dz, Tensor x, Tensor y) => Atan2XGrad\"\n  bind_python: False\n\n- name: \"atan2_y_grad\"\n  signature: \"Tensor (Tensor dz, Tensor x, Tensor y) => Atan2YGrad\"\n  bind_python: False\n\n- name: \"atanh\"\n  signature: \"Tensor (Tensor x) => Atanh\"\n  bind_python: True\n\n- name: \"atanh_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => AtanhGrad\"\n  bind_python: False\n\n- name: \"ceil\"\n  signature: \"Tensor (Tensor x) => Ceil\"\n  bind_python: True\n\n- name: \"ceil_\"\n  signature: \"Tensor (Tensor x) => Ceil_\"\n  bind_python: True\n\n- name: \"ceil_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => CeilGrad\"\n  bind_python: False\n\n- name: \"erf\"\n  signature: \"Tensor (Tensor x) => Erf\"\n  bind_python: True\n\n- name: \"erf_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => ErfGrad\"\n  bind_python: False\n\n- name: \"erfc\"\n  signature: \"Tensor (Tensor x) => Erfc\"\n  bind_python: True\n\n- name: 
\"erfc_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => ErfcGrad\"\n  bind_python: False\n\n- name: \"expm1\"\n  signature: \"Tensor (Tensor x) => Expm1\"\n  bind_python: True\n\n- name: \"expm1_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => Expm1Grad\"\n  bind_python: False\n\n- name: \"floor\"\n  signature: \"Tensor (Tensor x) => Floor\"\n  bind_python: True\n\n- name: \"floor_\"\n  signature: \"Tensor (Tensor x) => Floor_\"\n  bind_python: True\n\n- name: \"floor_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => FloorGrad\"\n  bind_python: False\n\n- name: \"lgamma\"\n  signature: \"Tensor (Tensor x) => Lgamma\"\n  bind_python: True\n\n- name: \"lgamma_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => LgammaGrad\"\n  bind_python: False\n\n- name: \"log1p\"\n  signature: \"Tensor (Tensor x) => Log1p\"\n  bind_python: True\n\n- name: \"log1p_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => Log1pGrad\"\n  bind_python: False\n\n- name: \"logsigmoid\"\n  signature: \"Tensor (Tensor x) => LogSigmoid\"\n  bind_python: True\n\n- name: \"logsigmoid_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => LogSigmoidGrad\"\n  bind_python: False\n\n- name: \"rint\"\n  signature: \"Tensor (Tensor x) => Rint\"\n  bind_python: True\n\n- name: \"rint_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => RintGrad\"\n  bind_python: False\n\n- name: \"round\"\n  signature: \"Tensor (Tensor x) => Round\"\n  bind_python: True\n\n- name: \"round_\"\n  signature: \"Tensor (Tensor x) => Round_\"\n  bind_python: True\n\n- name: \"round_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => RoundGrad\"\n  bind_python: False\n\n- name: \"sign\"\n  signature: \"Tensor (Tensor x) => Sign\"\n  bind_python: True\n\n- name: \"sign_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => SignGrad\"\n  bind_python: False\n\n- name: \"sinh\"\n  signature: \"Tensor (Tensor x) => Sinh\"\n  bind_python: True\n\n- name: \"sinh_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => SinhGrad\"\n  bind_python: False\n\n- name: \"softplus\"\n  signature: \"Tensor (Tensor x, Double beta=1.0, Double threshold=20.0) => Softplus\"\n  bind_python: True\n\n- name: \"softplus_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy, Double beta=1.0, Double threshold=20.0) => SoftplusGrad\"\n  bind_python: False\n\n- name: \"softshrink\"\n  signature: \"Tensor (Tensor x, *, Double alpha=0.5, Bool inplace=False) => SoftShrink\"\n  bind_python: True\n\n- name: \"softshrink_grad\"\n  signature: \"Tensor (Tensor y, Tensor dy, Double alpha=0.5) => SoftShrinkGrad\"\n  bind_python: False\n\n- name: \"one_hot\"\n  signature: \"Tensor (Tensor input, Int64 num_classes=-1, Scalar on_value=1, Scalar off_value=0) => OneHot\"\n  bind_python: True\n\n- name: \"unsorted_segment_sum_like\"\n  signature: \"Tensor (Tensor x, Tensor segment_ids, Tensor like, Int64 axis) => UnsortedSegmentSumLike\"\n  bind_python: True\n\n- name: \"unsorted_segment_sum\"\n  signature: \"Tensor (Tensor x, Tensor segment_ids, Int64 axis, Int64 num_segments) => UnsortedSegmentSum\"\n  bind_python: True\n\n- name: \"tril\"\n  signature: \"Tensor (Tensor x, Int64 diagonal=0) => Tril\"\n  bind_python: True\n\n- name: \"tril_\"\n  signature: \"Tensor (Tensor x, Int64 diagonal=0) => InplaceTril\"\n  bind_python: True\n\n- name: \"triu\"\n  signature: \"Tensor (Tensor x, Int64 diagonal=0) => Triu\"\n  bind_python: True\n\n- name: \"triu_\"\n  signature: \"Tensor (Tensor x, Int64 diagonal=0) => InplaceTriu\"\n  bind_python: True\n\n- name: \"clamp\"\n  signature: 
\"Tensor (Tensor input, Scalar min=None, Scalar max=None) => Clamp\"\n  bind_python: true\n\n- name: \"clamp_\"\n  signature: \"Tensor (Tensor input, Scalar min=None, Scalar max=None) => ClampInplace\"\n  bind_python: true\n\n- name: \"clamp_min\"\n  signature: \"Tensor (Tensor input, Scalar min) => ClampMin\"\n  bind_python: true\n\n- name: \"clamp_min_\"\n  signature: \"Tensor (Tensor input, Scalar min) => ClampMinInplace\"\n  bind_python: true\n\n- name: \"clamp_max\"\n  signature: \"Tensor (Tensor input, Scalar max) => ClampMax\"\n  bind_python: true\n\n- name: \"clamp_max_\"\n  signature: \"Tensor (Tensor input, Scalar min) => ClampMaxInplace\"\n  bind_python: true\n\n- name: \"clip\"\n  signature: [\"Tensor (Tensor input, Scalar min=None, Scalar max=None) => Clip\"]\n  bind_python: true\n\n- name: \"clip_\"\n  signature:\n    [\"Tensor (Tensor input, Scalar min=None, Scalar max=None) => ClipInplace\"]\n  bind_python: true\n\n- name: \"clamp_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x, Scalar min=None, Scalar max=None) => ClampGrad\"\n  bind_python: False\n\n- name: \"vector_norm\"\n  signature:\n    [\n      \"Tensor (Tensor input, Scalar ord=2, Int32List dim=None, Bool keepdim=False, *, DataType dtype=None) => VectorNorm\",\n      \"Tensor (Tensor input, Scalar ord=2, Scalar dim, Bool keepdim=False, *, DataType dtype=None) => VectorNorm\",\n    ]\n  bind_python: True\n\n- name: \"matrix_norm\"\n  signature:\n    [\n      \"Tensor (Tensor input, Scalar ord, Int32List dim, Bool keepdim=False, *, DataType dtype=None) => MatrixNorm\",\n      \"Tensor (Tensor input, String ord, Int32List dim, Bool keepdim=False, *, DataType dtype=None) => MatrixNorm\",\n    ]\n  bind_python: True\n\n- name: \"norm\"\n  signature:\n    [\n      \"Tensor (Tensor input, Scalar ord=None, Int32List dim=None, Bool keepdim=False, *, DataType dtype=None, Bool for_norm=False) => Norm\",\n      \"Tensor (Tensor input, String ord, Int32List dim=None, Bool keepdim=False, *, DataType dtype=None) => Norm\",\n      \"Tensor (Tensor input, Scalar ord=None, Scalar dim, Bool keepdim=False, *, DataType dtype=None) => ScalarNorm\",\n      \"Tensor (Tensor input, String ord, Scalar dim, Bool keepdim=False, *, DataType dtype=None) => ScalarNorm\",\n    ]\n  bind_python: True\n\n- name: \"inv\"\n  signature: \"Tensor (Tensor x) => Inv\"\n  bind_python: True\n\n- name: \"linalg_cross\"\n  signature: \"Tensor (Tensor input, Tensor other, Int64 dim=None) => LinalgCross\"\n  bind_python: True\n\n- name: \"det\"\n  signature: \"Tensor (Tensor x) => Det\"\n  bind_python: True\n\n- name: \"dropout\"\n  signature: \"Tensor (Tensor input, Float p=0.5, Bool training=True, Bool inplace=False, Generator generator=None, *, Tensor addend=None) => Dropout\"\n  bind_python: True\n\n- name: \"dropout_grad\"\n  signature: \"Tensor (Tensor dy, Tensor mask, Float scale) => DropoutGrad\"\n  bind_python: False\n\n- name: \"dropout1d\"\n  signature: \"Tensor (Tensor input, Float p=0.5, Bool training=True) => Dropout1d\"\n  bind_python: True\n\n- name: \"dropout2d\"\n  signature: \"Tensor (Tensor input, Float p=0.5, Bool training=True) => Dropout2d\"\n  bind_python: True\n\n- name: \"dropout3d\"\n  signature: \"Tensor (Tensor input, Float p=0.5, Bool training=True) => Dropout3d\"\n  bind_python: True\n\n- name: \"constant_pad\"\n  signature: \"Tensor (Tensor x, Int64List pad, Scalar value=0) => ConstantPad\"\n  bind_python: False\n\n- name: \"reflection_pad\"\n  signature: \"Tensor (Tensor x, Int64List pad) => ReflectionPad\"\n  bind_python: 
False\n\n- name: \"replication_pad\"\n  signature: \"Tensor (Tensor x, Int64List pad) => ReplicationPad\"\n  bind_python: False\n\n- name: \"pad\"\n  signature: 'Tensor (Tensor x, Int64List pad, String mode=\"constant\", Scalar value=0) => Pad'\n  bind_python: True\n\n- name: \"pad_grad\"\n  signature: 'Tensor (Tensor dy, Int64List pad, String mode=\"constant\", Scalar value=0) => PadGrad'\n  bind_python: False\n\n- name: \"silu\"\n  signature: \"Tensor (Tensor x) => Silu\"\n  bind_python: True\n\n- name: \"silu_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x) => SiluGrad\"\n  bind_python: False\n\n- name: \"mish\"\n  signature: \"Tensor (Tensor x) => Mish\"\n  bind_python: True\n\n- name: \"mish_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x) => MishGrad\"\n  bind_python: False\n\n- name: \"selu\"\n  signature: \"Tensor (Tensor x) => Selu\"\n  bind_python: True\n\n- name: \"selu_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x) => SeluGrad\"\n  bind_python: False\n\n- name: \"softsign\"\n  signature: \"Tensor (Tensor x) => SoftSign\"\n  bind_python: True\n\n- name: \"softsign_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x) => SoftSignGrad\"\n  bind_python: False\n\n- name: \"diag\"\n  signature: \"Tensor (Tensor x, Int32 diagonal=0) => Diag\"\n  bind_python: True\n\n- name: \"diag_grad\"\n  signature: \"Tensor (Tensor dy, Tensor in, Int32 diagonal=0) => DiagGrad\"\n  bind_python: False\n\n- name: \"diagonal\"\n  signature: \"Tensor (Tensor x, Int32 offset=0, Int32 dim1=0, Int32 dim2=1) => Diagonal\"\n  bind_python: True\n\n- name: \"diagonal_grad\"\n  signature: \"Tensor (Tensor dy, Tensor in, Int32 offset=0) => DiagonalGrad\"\n  bind_python: False\n\n- name: \"tensor_getitem\"\n  signature: \"Tensor (Tensor x, TensorIndex index) => TensorGetItem\"\n  bind_python: False\n\n- name: \"scatter\"\n  signature:\n    [\n      \"Tensor (Tensor input, Int32 dim, Tensor index, Tensor src, *, String reduce=None, Bool inplace=False) => DimScatter\",\n      \"Tensor (Tensor input, Int32 dim, Tensor index, Scalar src, *, String reduce=None, Bool inplace=False) => DimScatterScalar\",\n    ]\n  bind_python: True\n\n- name: \"scatter_update\"\n  signature:\n    [\n      \"Tensor (Tensor input, Int32 dim, Tensor index, Tensor src, *, Bool inplace=False) => DimScatterUpdate\",\n      \"Tensor (Tensor input, Int32 dim, Tensor index, Scalar src, *, Bool inplace=False) => DimScatterUpdateScalar\",\n    ]\n  bind_python: False\n\n- name: \"scatter_add\"\n  signature:\n    [\n      \"Tensor (Tensor input, Int32 dim, Tensor index, Tensor src, *, Bool inplace=False) => DimScatterAdd\",\n      \"Tensor (Tensor input, Int32 dim, Tensor index, Scalar src, *, Bool inplace=False) => DimScatterAddScalar\",\n    ]\n  bind_python: True\n\n- name: \"scatter_mul\"\n  signature:\n    [\n      \"Tensor (Tensor input, Int32 dim, Tensor index, Tensor src, *, Bool inplace=False) => DimScatterMul\",\n      \"Tensor (Tensor input, Int32 dim, Tensor index, Scalar src, *, Bool inplace=False) => DimScatterMulScalar\",\n    ]\n  bind_python: False\n\n- name: \"scatter_add_like\"\n  signature: \"Tensor (Tensor like, Int32 dim, Tensor index, Tensor src) => DimScatterAddLike\"\n  bind_python: False\n\n- name: \"tensor_setitem\"\n  signature: \"Void (Tensor x, TensorIndex index, Tensor value) => TensorSetItem\"\n  bind_python: True\n\n- name: \"avg_pool1d\"\n  signature:\n    'Tensor (Tensor input, Int32List[1] kernel_size, Int32List[1] stride=None,\n    Int32List[1] padding=0, Bool ceil_mode=False, Bool 
count_include_pad=True,\n    Int32 divisor_override=0, String data_format=\"channels_first\") => AvgPool1D'\n  bind_python: True\n\n- name: \"avg_pool2d\"\n  signature:\n    'Tensor (Tensor input, Int32List[2] kernel_size, Int32List[2] stride=None,\n    Int32List[2] padding=0, Bool ceil_mode=False, Bool count_include_pad=True,\n    Int32 divisor_override=0, String data_format=\"channels_first\") => AvgPool2D'\n  bind_python: True\n\n- name: \"avg_pool3d\"\n  signature:\n    'Tensor (Tensor input, Int32List[3] kernel_size, Int32List[3] stride=None,\n    Int32List[3] padding=0, Bool ceil_mode=False, Bool count_include_pad=True,\n    Int32 divisor_override=0, String data_format=\"channels_first\") => AvgPool3D'\n  bind_python: True\n\n- name: \"avg_pool_grad\"\n  signature:\n    \"Tensor (Tensor x, Tensor dy, Int32 ndims, String data_format, Int32List padding,\n    Int32List kernel_size, Int32List stride, Bool ceil_mode, Bool count_include_pad,\n    Int32 divisor_override=0) => AvgPoolNdGrad\"\n  bind_python: False\n\n- name: \"minimum\"\n  signature: \"Tensor (Tensor input, Tensor other) => Minimum\"\n  bind_python: True\n\n- name: \"maximum\"\n  signature: \"Tensor (Tensor input, Tensor other) => Maximum\"\n  bind_python: True\n\n- name: \"elementwise_min_grad\"\n  signature: \"TensorTuple (Tensor dz, Tensor x, Tensor y) => ElementwiseMinGrad\"\n  bind_python: False\n\n- name: \"elementwise_max_grad\"\n  signature: \"TensorTuple (Tensor dz, Tensor x, Tensor y) => ElementwiseMaxGrad\"\n  bind_python: False\n\n- name: \"stack\"\n  signature: \"Tensor (TensorTuple inputs, Int64 dim=0) => Stack\"\n  bind_python: True\n\n- name: \"stack_grad\"\n  signature: \"TensorTuple (Tensor x, TensorTuple like, Int64 axis) => StackGrad\"\n  bind_python: False\n\n- name: \"atleast_1d\"\n  signature:\n    [\n      \"Tensor (Tensor input) => AtLeast1D\",\n      \"TensorTuple (TensorTuple tensors) => AtLeast1D\",\n    ]\n  bind_python: True\n\n- name: \"atleast_2d\"\n  signature:\n    [\n      \"Tensor (Tensor input) => AtLeast2D\",\n      \"TensorTuple (TensorTuple tensors) => AtLeast2D\",\n    ]\n  bind_python: True\n\n- name: \"atleast_3d\"\n  signature:\n    [\n      \"Tensor (Tensor input) => AtLeast3D\",\n      \"TensorTuple (TensorTuple tensors) => AtLeast3D\",\n    ]\n  bind_python: True\n\n- name: \"hstack\"\n  signature: \"Tensor (TensorTuple tensors) => HStack\"\n  bind_python: True\n\n- name: \"vstack\"\n  signature: \"Tensor (TensorTuple tensors) => VStack\"\n  bind_python: True\n\n- name: \"dstack\"\n  signature: \"Tensor (TensorTuple tensors) => DStack\"\n  bind_python: True\n\n- name: \"column_stack\"\n  signature: \"Tensor (TensorTuple tensors) => ColumnStack\"\n  bind_python: True\n\n- name: \"row_stack\"\n  signature: \"Tensor (TensorTuple tensors) => RowStack\"\n  bind_python: True\n\n- name: \"local_to_global\"\n  signature: \"Tensor (Tensor x, Placement placement, SbpList sbp, Shape shape, DataType dtype, Bool sync_data, Bool copy=False) => LocalToGlobal\"\n  bind_python: False\n\n- name: \"to_global\"\n  signature: \"Tensor (Tensor x, Placement placement, SbpList sbp, SbpList grad_sbp, Bool check_meta, Bool copy=False) => ToGlobal\"\n  bind_python: True\n\n- name: \"to_local\"\n  signature: \"Tensor (Tensor x, Bool copy=False) => GlobalToLocal\"\n  bind_python: True\n\n- name: \"stream_touch\"\n  signature: \"Void (TensorTuple x) => StreamTouch\"\n  bind_python: True\n\n- name: \"comm_broadcast\"\n  signature:\n    [\n      \"Tensor (Tensor x, *, Int64 src_rank=0, Bool inplace=True) => 
CommBroadcast\",\n      \"TensorTuple (TensorTuple inputs, *, Int64 src_rank=0, Bool inplace=True) => CommBroadcastTensors\",\n    ]\n  bind_python: True\n\n- name: \"local_all_reduce\"\n  signature: \"Tensor (Tensor x, Bool inplace=False) => LocalAllReduce\"\n  bind_python: True\n\n- name: \"local_all_gather\"\n  signature: \"Tensor (Tensor output, Tensor input) => LocalAllGather\"\n  bind_python: True\n\n- name: \"local_reduce_scatter\"\n  signature: \"Tensor (Tensor output, Tensor input) => LocalReduceScatter\"\n  bind_python: True\n\n- name: \"local_reduce\"\n  signature: \"Tensor (Tensor x, *, Int64 dst=0, Bool inplace=True) => LocalReduce\"\n  bind_python: True\n\n- name: \"eager_p_to_b\"\n  signature: \"Tensor (Tensor x, Placement in_placement, Placement out_placement, Shape shape) => EagerPToB\"\n  bind_python: False\n\n- name: \"eager_b_to_s\"\n  signature: \"Tensor (Tensor x, Placement in_placement, Placement out_placement, SbpList out_sbp, Shape shape) => EagerBToS\"\n  bind_python: False\n\n- name: \"eager_s_to_b\"\n  signature: \"Tensor (Tensor x, Placement in_placement, Placement out_placement, SbpList in_sbp, Shape shape) => EagerSToB\"\n  bind_python: False\n\n- name: \"eager_naive_s_to_s\"\n  signature: \"Tensor (Tensor x, Placement in_placement, Placement out_placement, SbpList in_sbp, SbpList out_sbp, Shape shape) => EagerNaiveSToS\"\n  bind_python: False\n\n- name: \"eager_p_to_s\"\n  signature: \"Tensor (Tensor x, Placement in_placement, Placement out_placement, SbpList out_sbp, Shape shape) => EagerPToS\"\n  bind_python: False\n\n- name: \"eager_s_to_p\"\n  signature: \"Tensor (Tensor x, Placement in_placement, Placement out_placement, SbpList out_sbp, Shape shape) => EagerSToP\"\n  bind_python: False\n\n- name: \"global_all_reduce\"\n  signature: \"Tensor (Tensor x) => GlobalAllReduce\"\n  bind_python: False\n\n- name: \"global_reduce_scatter\"\n  signature: \"Tensor (Tensor x, String op_type) => GlobalReduceScatter\"\n  bind_python: False\n\n- name: \"global_all_gather\"\n  signature: \"Tensor (Tensor x) => GlobalAllGather\"\n  bind_python: False\n\n- name: \"global_s2s\"\n  signature: \"Tensor (Tensor x, SbpList out_sbp) => GlobalS2S\"\n  bind_python: False\n\n- name: \"select_top_n\"\n  signature: \"TensorTuple (TensorTuple inputs, Int32 n) => SelectTopN\"\n  bind_python: True\n\n- name: \"cast_like\"\n  signature: \"Tensor (Tensor x, Tensor like) => CastLike\"\n  bind_python: False\n\n- name: \"identity\"\n  signature: \"Tensor (Tensor in) => Identity\"\n  bind_python: True\n\n- name: \"amp_white_identity\"\n  signature: \"Tensor (Tensor in) => AmpWhiteIdentity\"\n  bind_python: True\n\n- name: \"amp_black_identity\"\n  signature: \"Tensor (Tensor in) => AmpBlackIdentity\"\n  bind_python: True\n\n- name: \"reshape_like\"\n  signature: \"Tensor (Tensor in, Tensor like) => ReshapeLike\"\n  bind_python: True\n\n- name: \"reduce_sum_like\"\n  signature: \"Tensor (Tensor in, Tensor like, Int32List axis) => ReduceSumLike\"\n  bind_python: True\n\n- name: \"broadcast_reduce_sum_like\"\n  signature: \"Tensor (Tensor in, Tensor like) => BroadcastReduceSumLike\"\n  bind_python: False\n\n- name: \"rand\"\n  signature: [\n      \"Tensor (Shape size, *, DataType dtype=None, Device device=None,\n      Generator generator=None, Bool requires_grad=False) => Rand\",\n      \"Tensor (Shape size, *, Placement placement, SbpList sbp, DataType dtype=None,\n      Generator generator=None, Bool requires_grad=False) => GlobalRand\",\n    ]\n  bind_python: True\n\n- name: \"randn\"\n  
signature: [\n      \"Tensor (Shape size, *, DataType dtype=None, Device device=None,\n      Generator generator=None, Bool requires_grad=False, Layout layout=kStrided) => RandN\",\n      \"Tensor (Shape size, *, Placement placement, SbpList sbp, DataType dtype=None,\n      Generator generator=None, Bool requires_grad=False) => GlobalRandN\",\n    ]\n  bind_python: True\n\n- name: \"randn_like\"\n  signature: [\n      \"Tensor (Tensor input, *, DataType dtype=None, Device device=None,\n      Generator generator=None, Bool requires_grad=False) => RandnLike\",\n      \"Tensor (Tensor input, *, Placement placement, SbpList sbp, DataType dtype=None,\n      Generator generator=None, Bool requires_grad=False) => GlobalRandnLike\",\n    ]\n  bind_python: True\n\n- name: \"randint\"\n  signature: [\n      \"Tensor (Int64 low, Int64 high, Shape size, *, DataType dtype=None,\n      Device device=None, Generator generator=None, Bool requires_grad=False) => RandInt\",\n      \"Tensor (Int64 high, Shape size, *, DataType dtype=None,\n      Device device=None, Generator generator=None, Bool requires_grad=False) => RandInt\",\n      \"Tensor (Int64 low, Int64 high, Shape size, *, Placement placement, SbpList sbp,\n      DataType dtype=None, Generator generator=None, Bool requires_grad=False) => GlobalRandInt\",\n      \"Tensor (Int64 high, Shape size, *, Placement placement, SbpList sbp,\n      DataType dtype=None, Generator generator=None, Bool requires_grad=False) => GlobalRandInt\",\n    ]\n  bind_python: True\n\n- name: \"randint_like\"\n  signature: [\n      \"Tensor (Tensor x, Int64 low, Int64 high, *, DataType dtype=None,\n      Device device=None, Generator generator=None, Bool requires_grad=False) => RandIntLike\",\n      \"Tensor (Tensor x, Int64 high, *, DataType dtype=None,\n      Device device=None, Generator generator=None, Bool requires_grad=False) => RandIntLike\",\n      \"Tensor (Tensor x, Int64 low, Int64 high, *, Placement placement, SbpList sbp,\n      DataType dtype=None, Generator generator=None, Bool requires_grad=False) => GlobalRandIntLike\",\n      \"Tensor (Tensor x, Int64 high, *, Placement placement, SbpList sbp,\n      DataType dtype=None, Generator generator=None, Bool requires_grad=False) => GlobalRandIntLike\",\n    ]\n  bind_python: True\n\n- name: \"randperm\"\n  signature:\n    [\n      \"Tensor (Int32 n, *, Generator generator=None, DataType dtype=kInt64, Device device=None, Bool requires_grad=False) => RandPerm\",\n      \"Tensor (Int32 n, *, Placement placement, SbpList sbp, Generator generator=None, DataType dtype=kInt64, Bool requires_grad=False) => GlobalRandPerm\",\n    ]\n  bind_python: True\n\n- name: \"unfold_tensor\"\n  signature: \"Tensor (Tensor x, Int32 dimension, Int32 size, Int32 step) => UnfoldTensor\"\n  bind_python: True\n\n- name: \"unfold_tensor_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x, Int32 dimension, Int32 size, Int32 step) => UnfoldTensorGrad\"\n  bind_python: False\n\n- name: \"unfold\"\n  signature:\n    'Tensor (Tensor x, Int32List[2] kernel_size, Int32List[2] dilation=1, Int32List[2] padding=0,\n    Int32List[2] stride=1, String data_format=\"channels_first\") => Unfold'\n  bind_python: True\n\n- name: \"fold\"\n  signature:\n    'Tensor (Tensor x, Int32List[1] output_size, Int32List[2] kernel_size, Int32List[2] dilation=1,\n    Int32List[2] padding=0, Int32List[2] stride=1, String data_format=\"channels_first\") => Fold'\n  bind_python: True\n
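\n# NOTE: when an entry declares a list of signatures, the Python binding tries the\n# overloads in declaration order until one matches the given arguments. A minimal\n# sketch for the two split overloads below (hypothetical usage, assuming the usual\n# oneflow._C binding prefix):\n#   import oneflow as flow\n#   x = flow.randn(5, 4)\n#   flow._C.split(x, 2)       # matches the Int64 overload     -> Split\n#   flow._C.split(x, [2, 3])  # matches the Int64List overload -> SplitWithSize\n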
\n- name: \"split\"\n  signature:\n    [\n      \"TensorTuple (Tensor x, Int64 split_size_or_sections, Int64 dim=0) => Split\",\n      \"TensorTuple (Tensor x, Int64List split_size_or_sections, Int64 dim=0) => SplitWithSize\",\n    ]\n  bind_python: True\n\n- name: \"unbind\"\n  signature: [\"TensorTuple (Tensor x, Int64 dim=0) => Unbind\"]\n  bind_python: True\n\n- name: \"chunk\"\n  signature: [\"TensorTuple (Tensor x, Int64 chunks, Int64 dim=0) => Chunk\"]\n  bind_python: True\n\n- name: \"split_like\"\n  signature: \"TensorTuple (Tensor x, TensorTuple like, Int64 axis) => SplitLike\"\n  bind_python: True\n\n- name: \"pairwise_distance\"\n  signature: \"Tensor (Tensor x1, Tensor x2, Float p=2.0, Double eps=1e-6, Bool keepdim=False) => PairwiseDistance\"\n  bind_python: True\n\n- name: \"cosine_similarity\"\n  signature: \"Tensor (Tensor x, Tensor y, Int32 dim=1, Double eps=1e-8) => CosineSimilarity\"\n  bind_python: True\n\n- name: \"normalize\"\n  signature: \"Tensor (Tensor input, Float p=2.0, Int32 dim=1, Float eps=1e-12, Bool use_l2_norm_kernel=True) => Normalize\"\n  bind_python: True\n\n- name: \"l2_normalize\"\n  signature: \"Tensor (Tensor input, Int32 axis=0, Float epsilon=1e-12) => L2Normalize\"\n  bind_python: False\n\n- name: \"l2_normalize_grad\"\n  signature: \"Tensor (Tensor dy, Tensor y, Tensor square_x_sum, Int32 axis, Float epsilon) => L2NormalizeGrad\"\n  bind_python: False\n\n- name: \"fused_self_attention\"\n  signature: \"TensorTuple (Tensor hidden_states, Int64 head_size=8, Float alpha=1.0) => FusedSelfAttention\"\n  bind_python: True\n\n- name: \"fused_self_attention_grad\"\n  signature: \"Tensor (Tensor query_mul_key_grad, Tensor value_grad, Tensor hidden_states, Float alpha=1.0) => FusedSelfAttentionGrad\"\n  bind_python: False\n\n- name: \"fused_scale_tril\"\n  signature: \"Tensor (Tensor x, Int64 diagonal=0, Scalar fill_value=0, Scalar scale=1) => FusedScaleTril\"\n  bind_python: True\n\n- name: \"fused_bias_add_gelu\"\n  signature: \"Tensor (Tensor a, Tensor b, *, Int32 axis) => FusedBiasAddGelu\"\n  bind_python: True\n\n- name: \"fused_bias_add_gelu_grad\"\n  signature: \"Tensor (Tensor a, Tensor b, Tensor dy, Int32 axis) => FusedBiasAddGeluGrad\"\n  bind_python: False\n\n- name: \"fused_bias_add_dropout\"\n  signature: \"Tensor (Tensor a, Tensor b, *, Float p=0.5, Int32 axis, Generator generator=None) => FusedBiasAddDropout\"\n  bind_python: True\n\n- name: \"fused_scale_mask_softmax\"\n  signature: \"Tensor (Tensor x, Tensor mask, *, Float fill_value=0.0, Float scale=1.0) => FusedScaleMaskSoftmax\"\n  bind_python: True\n\n- name: \"fused_scale_mask_softmax_grad\"\n  signature: \"Tensor (Tensor y, Tensor dy, Tensor mask, Float scale=1.0) => FusedScaleMaskSoftmaxGrad\"\n  bind_python: False\n\n- name: \"fused_scale_mask_softmax_dropout\"\n  signature: \"TensorTuple (Tensor x, Tensor mask, *, Float fill_value=0.0, Float scale=1.0, Float p=0.5, Bool training=True, Generator generator=None) => FusedScaleMaskSoftmaxDropout\"\n  bind_python: True\n\n- name: \"fused_scale_mask_softmax_dropout_grad\"\n  signature: \"Tensor (Tensor softmax_y, Tensor dy, Tensor mask, Tensor dropout_mask, Float scale=1.0, Float dropout_scale=1.0) => FusedScaleMaskSoftmaxDropoutGrad\"\n  bind_python: False\n\n- name: \"fused_scale_tril_softmax_mask_scale\"\n  signature: \"TensorTuple (Tensor a, *, Float p=0.5, Int64 diagonal, Float tril_scale_value, Float tril_fill_value=0.0, Generator generator=None) => FusedScaleTrilSoftmaxMaskScale\"\n  bind_python: True\n\n- name: \"fused_scale_tril_softmax_mask_scale_grad\"\n  signature: \"Tensor (Tensor softmax_y, Tensor 
dy, Tensor mask, Int64 diagonal, Float tril_scale_value, Float mask_scale_value) => FusedScaleTrilSoftmaxMaskScaleGrad\"\n  bind_python: False\n\n- name: \"fused_bias_add_scale_mask_softmax_dropout\"\n  signature: \"TensorTuple (Tensor x, Tensor bias, Tensor mask, *, Float fill_value=0.0, Float scale=1.0, Float p=0.5, Bool training=True, Generator generator=None) => FusedBiasAddScaleMaskSoftmaxDropout\"\n  bind_python: True\n\n- name: \"scaled_dot_product_attention\"\n  signature: \"Tensor (Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, Float dropout_p=0.0, Bool is_causal=False, Float scale=None, Int64 seed=0) => ScaledDotProductFlashAttention\"\n  bind_python: True\n\n- name: \"scaled_dot_product_attention_grad\"\n  signature: \"TensorTuple (Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor softmax_lse, Tensor rng_state, Float dropout_p=0.0, Bool is_causal=False, Float scale=0.0) => ScaledDotProductFlashAttentionGrad\"\n  bind_python: False\n\n- name: \"fused_multi_head_attention_inference\"\n  signature: \"Tensor (Tensor query, Tensor key, Tensor value, Int64 num_heads, Bool causal=False, Int64 query_hidden_slice_start=0, Int64 query_hidden_slice_end=-1, Int64 key_hidden_slice_start=0, Int64 key_hidden_slice_end=-1, Int64 value_hidden_slice_start=0, Int64 value_hidden_slice_end=-1, Tensor attn_bias=None, Int64 causal_diagonal_offset=0) => FusedMultiHeadAttentionInference\"\n  bind_python: True\n\n- name: \"fused_multi_head_attention_inference_v2\"\n  signature: 'Tensor (*, Tensor query, String query_layout, Int64 query_head_size=None, Tensor query_seq_start=None, Int64 query_max_seq_len=None, Tensor key=None, String key_layout=None, Tensor key_seq_start=None, Tensor key_seq_len=None, Int64 key_max_seq_len=None, Tensor value=None, String value_layout=None, Tensor attn_bias=None, String output_layout=\"BM(HK)\", Float scale=None, Bool causal=None, String attn_mask_type=None, Int64 causal_diagonal_offset=0) => FusedMultiHeadAttentionInferenceV2'\n  bind_python: True\n\n- name: \"fused_attention_concat_past_key_value\"\n  signature: 'TensorTuple (*, Tensor past_key=None, String past_key_layout, Tensor past_value=None, String past_value_layout, Tensor key, String key_layout, Tensor value, String value_layout, Int64 key_head_size=None) => FusedAttentionConcatPastKeyValue'\n  bind_python: True\n\n- name: \"fused_scale_mask_bias_softmax\"\n  signature: 'Tensor (Tensor x, Tensor mask, Tensor bias=None, Float scale=0.35355, Bool inplace=False) => FusedScaleMaskBiasSoftmax'\n  bind_python: True\n\n- name: \"fused_scale_mask_bias_softmax_grad\"\n  signature: 'Tensor (Tensor y, Tensor dy, Float scale=0.35355) => FusedScaleMaskBiasSoftmaxGrad'\n  bind_python: False\n\n- name: \"noncontiguous_binary_op\"\n  signature: 'Tensor (Tensor lhs, Tensor rhs, String op=\"add\", Bool inplace=False) => NonContiguousBinaryOp'\n  bind_python: True\n\n- name: \"noncontiguous_binary_op_grad\"\n  signature: 'TensorTuple (Tensor dy, Tensor lhs, Tensor rhs, String op=\"add\", Bool inplace=False) => NonContiguousBinaryOpGrad'\n  bind_python: False\n\n- name: \"fused_get_center_dist\"\n  signature: \"Tensor (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2) => FusedCenter\"\n  bind_python: True\n\n- name: \"fused_get_center_dist_grad\"\n  signature: \"TensorTuple (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2, Tensor rho2_diff) => FusedCenterGrad\"\n  
bind_python: False\n\n- name: \"fused_get_intersection_area\"\n  signature: \"Tensor (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2) => FusedGetIntersectionArea\"\n  bind_python: True\n\n- name: \"fused_get_intersection_area_grad\"\n  signature: \"TensorTuple (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2, Tensor inter_diff) => FusedGetIntersectionAreaGrad\"\n  bind_python: False\n\n- name: \"fused_get_ciou_diagonal_angle\"\n  signature: \"Tensor (Tensor w1, Tensor h1, Tensor w2, Tensor h2, Float eps) => FusedGetCiouDiagonalAngle\"\n  bind_python: True\n\n- name: \"fused_get_ciou_diagonal_angle_grad\"\n  signature: \"TensorTuple (Tensor w1, Tensor h1, Tensor w2, Tensor h2, Tensor v_diff, Float eps) => FusedGetCiouDiagonalAngleGrad\"\n  bind_python: False\n\n- name: \"fused_get_convex_diagonal_squared\"\n  signature: \"Tensor (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2, Float eps) => FusedGetConvexDiagonalSquared\"\n  bind_python: True\n\n- name: \"fused_get_convex_diagonal_squared_grad\"\n  signature: \"TensorTuple (Tensor c2_diff, Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2, Float eps) => FusedGetConvexDiagonalSquaredGrad\"\n  bind_python: False\n\n- name: \"grouped_matmul_bias\"\n  signature: \"TensorTuple (TensorTuple xs, TensorTuple weights, TensorTuple biases) => GroupedMatmulBias\"\n  bind_python: True\n\n- name: \"grouped_matmul\"\n  signature: \"TensorTuple (TensorTuple xs, TensorTuple weights) => GroupedMatmul\"\n  bind_python: True\n\n- name: \"send\"\n  signature: \"Void (Tensor input, Int64 dst, Bool send_meta=True) => Send\"\n  bind_python: True\n\n- name: \"recv\"\n  signature: \"Tensor (Int64 src, Shape shape=None, DataType dtype=None, Device device=None, *, Tensor out=None) => Recv\"\n  bind_python: True\n\n- name: \"batch_gather\"\n  signature: \"Tensor (Tensor in, Tensor indices) => BatchGather\"\n  bind_python: True\n\n- name: \"unsorted_batch_segment_sum\"\n  signature: \"Tensor (Tensor data, Tensor segment_ids, Int64 num_segments) => UnsortedBatchSegmentSum\"\n  bind_python: False\n\n- name: \"ctc_greedy_decoder\"\n  signature: \"TensorTuple (Tensor log_probs, Tensor input_lengths, Bool merge_repeated=True) => CtcGreedyDecoder\"\n  bind_python: True\n\n- name: \"distributed_partial_fc_sample_disable_boxing\"\n  signature: \"TensorTuple (Tensor sampled_weight_diff, Tensor sampled_label) => DistributedPariticalFCSampleDisableBoxing\"\n  bind_python: False\n\n- name: \"nms\"\n  signature: \"Tensor (Tensor x, Float iou_threshold, Int32 keep_n=-1) => Nms\"\n  bind_python: True\n\n- name: \"roi_align\"\n  signature: \"Tensor (Tensor x, Tensor rois, Float spatial_scale, Int32 pooled_h, Int32 pooled_w, Int32 sampling_ratio, Bool aligned) => RoiAlign\"\n  bind_python: True\n\n- name: \"roi_align_grad\"\n  signature: \"Tensor (Tensor dy, Tensor x_like, Tensor rois, Float spatial_scale, Int32 pooled_h, Int32 pooled_w, Int32 sampling_ratio, Bool aligned) => RoiAlignGrad\"\n  bind_python: False\n\n- name: \"meshgrid\"\n  signature: 'TensorTuple (TensorTuple tensors, String indexing=\"ij\") => Meshgrid'\n  bind_python: True\n\n- name: \"index_select\"\n  signature: \"Tensor (Tensor input, Int64 dim, Tensor index) => IndexSelect\"\n  bind_python: True\n\n- name: \"dot\"\n  signature: \"Tensor (Tensor input, Tensor other) => 
Dot\"\n  bind_python: True\n\n- name: \"fused_dot_feature_interaction\"\n  signature: 'Tensor (TensorTuple features, Tensor output_concat=None, Bool self_interaction=False, Int32 output_padding=0, String pooling=\"none\") => FusedDotFeatureInteraction'\n  bind_python: True\n\n- name: \"fused_dot_feature_interaction_grad\"\n  signature: 'TensorTuple (Tensor dy, TensorTuple features, Bool has_output_concat_grad=False, Bool self_interaction=False, Int32 output_concat_grad_dim=0, String pooling=\"none\") => FusedDotFeatureInteractionGrad'\n  bind_python: False\n\n- name: \"fused_cross_feature_interaction\"\n  signature: \"Tensor (Tensor x, Tensor weight, Tensor x_0, Tensor bias, String interaction_mode) => FusedCrossFeatureInteraction\"\n  bind_python: True\n\n- name: \"fused_cross_feature_interaction_v1_grad\"\n  signature: \"TensorTuple (Tensor dy, Tensor weight, Tensor x, Tensor x_0, Tensor matmul_result) => FusedCrossFeatureInteractionV1Grad\"\n  bind_python: False\n\n- name: \"fused_cross_feature_interaction_v2_grad\"\n  signature: \"TensorTuple (Tensor dy, Tensor weight, Tensor bias, Tensor x, Tensor x_0, Tensor matmul_result) => FusedCrossFeatureInteractionV2Grad\"\n  bind_python: False\n\n- name: \"tensor_buffer_to_tensor\"\n  signature: \"Tensor (Tensor input, Shape instance_shape, DataType dtype) => TensorBufferToTensor\"\n  bind_python: True\n\n- name: \"tensor_to_tensor_buffer\"\n  signature: \"Tensor (Tensor input, Int32 instance_dims) => TensorToTensorBuffer\"\n  bind_python: True\n\n- name: \"gen_tensor_buffer\"\n  signature: \"Tensor (Shape shape, ShapeList shape_list, FloatList value_list, DataType data_type, Bool dynamic_out) => GenTensorBuffer\"\n  bind_python: True\n\n- name: \"topk\"\n  signature: \"TensorTuple[values, indices] (Tensor input, Int32 k, Int32 dim=None, Bool largest=True, Bool sorted=True) => TopK\"\n  bind_python: True\n\n- name: \"in_top_k\"\n  signature: \"Tensor (Tensor targets, Tensor predictions, Int32 k) => InTopK\"\n  bind_python: True\n\n- name: \"cumsum\"\n  signature: \"Tensor (Tensor input, Int64 dim, *, DataType dtype=None) => Cumsum\"\n  bind_python: True\n\n- name: \"cumprod\"\n  signature: \"Tensor (Tensor input, Int64 dim, *, DataType dtype=None) => Cumprod\"\n  bind_python: True\n\n- name: \"cumprod_grad\"\n  signature: \"Tensor (Tensor input, Tensor y, Tensor x, Int64 dim) => CumprodGrad\"\n  bind_python: False\n\n- name: \"one_embedding_id_shuffle\"\n  signature: \"TensorTuple (Tensor ids, Tensor table_ids=None, Int32 num_tables=1, String embedding_name) => OneEmbeddingIdShuffle\"\n  bind_python: True\n\n- name: \"one_embedding_embedding_shuffle\"\n  signature: \"Tensor (Tensor cur_rank_embeddings, Tensor num_unique_matrix, Tensor cur_rank_inverse_indices, Tensor inverse_unique_partition_indices, String embedding_name) => OneEmbeddingEmbeddingShuffle\"\n  bind_python: True\n\n- name: \"one_embedding_embedding_gradient_shuffle\"\n  signature: \"Tensor (Tensor embedding_grad, Tensor num_unique_matrix, Tensor cur_rank_inverse_indices, Tensor inverse_unique_partition_indices, String embedding_name) => OneEmbeddingEmbeddingGradientShuffle\"\n  bind_python: True\n\n- name: \"one_embedding_lookup\"\n  signature: \"Tensor (Tensor num_unique_ids, Tensor unique_ids, Tensor table_ids, DataType dtype, DataType embedding_dtype, Int64 line_size, Int64 embedding_size, String embedding_name, String embedding_tables, String state_initializer, Int64 seed=0) => OneEmbeddingLookup\"\n  bind_python: True\n\n- name: \"one_embedding_fused_lookup\"\n  signature: 
\"Tensor (Tensor shadow, Tensor ids, Tensor table_ids=None, DataType dtype, String embedding_name, Int64 line_size, Int64 embedding_size, Bool is_full_cache, Int32 num_tables, String embedding_tables, Int64 padding_idx=None, Int64 seed=0) => OneEmbeddingFusedLookup\"\n  bind_python: True\n\n- name: \"one_embedding_fused_lookup_grad\"\n  signature: \"Void (Tensor ids, Tensor embedding_grad, String embedding_name, Int64 line_size, Int64 embedding_size) => OneEmbeddingFusedLookupGrad\"\n  bind_python: True\n\n- name: \"one_embedding_unique_key_value_pair\"\n  signature: \"TensorTuple (Tensor keys, Tensor values=None, Int32 num_tables, String embedding_name) => OneEmbeddingUniqueKeyValuePair\"\n  bind_python: True\n\n- name: \"one_embedding_embedding_put\"\n  signature: \"Void (Tensor num_unique_ids, Tensor unique_ids, Tensor unique_embeddings, String embedding_name, Int64 line_size) => OneEmbeddingEmbeddingPut\"\n  bind_python: True\n\n- name: \"one_embedding_sgd_update\"\n  signature: \"Tensor (Tensor num_unique_ids, Tensor unique_embeddings, Tensor embedding_grad, Tensor learning_rate=None, Tensor down_scale_by_tensor=None, Tensor skip_if=None, Float learning_rate_val, Double scale, Float weight_decay, Float momentum, Int64 line_size, Int64 embedding_size, String embedding_name) => OneEmbeddingSgdUpdate\"\n  bind_python: True\n\n- name: \"one_embedding_adam_update\"\n  signature: \"Tensor (Tensor num_unique_ids, Tensor unique_embeddings, Tensor embedding_grad, Tensor learning_rate=None, Tensor down_scale_by_tensor=None, Tensor skip_if=None, Tensor bias_correction1=None, Tensor bias_correction2=None, Float learning_rate_val, Double scale, Float weight_decay, Float beta1, Float beta2, Float bias_correction1_val, Float bias_correction2_val, Float epsilon, Bool do_bias_correction, Int64 line_size, Int64 embedding_size, String embedding_name) => OneEmbeddingAdamUpdate\"\n  bind_python: True\n\n- name: \"one_embedding_adagrad_update\"\n  signature: \"Tensor (Tensor num_unique_ids, Tensor unique_embeddings, Tensor embedding_grad, Tensor learning_rate=None, Tensor down_scale_by_tensor=None, Tensor skip_if=None, Tensor train_step=None, Int64 train_step_val, Float learning_rate_val, Double scale, Float weight_decay, Float lr_decay, Float epsilon, Int64 line_size, Int64 embedding_size, String embedding_name) => OneEmbeddingAdagradUpdate\"\n  bind_python: True\n\n- name: \"one_embedding_ftrl_update\"\n  signature: \"Tensor (Tensor num_unique_ids, Tensor unique_embeddings, Tensor embedding_grad, Tensor learning_rate=None, Tensor down_scale_by_tensor=None, Tensor skip_if=None, Float learning_rate_val, Double scale, Float weight_decay, Float lr_power, Float lambda1, Float lambda2, Float beta, Int64 line_size, Int64 embedding_size, String embedding_name) => OneEmbeddingFtrlUpdate\"\n  bind_python: True\n\n- name: \"einsum\"\n  signature: \"Tensor (String equation, TensorTuple operands) => EinSum\"\n  bind_python: True\n\n- name: \"pixel_shuffle\"\n  signature: \"Tensor (Tensor input, Int64 h_upscale_factor, Int64 w_upscale_factor) => PixelShuffle\"\n  bind_python: True\n\n- name: \"isnan\"\n  signature: \"Tensor (Tensor input) => IsNan\"\n  bind_python: True\n\n- name: \"isinf\"\n  signature: \"Tensor (Tensor input) => IsInf\"\n  bind_python: True\n\n- name: \"isfinite\"\n  signature: \"Tensor (Tensor input) => IsFinite\"\n  bind_python: True\n\n- name: \"depend\"\n  signature:\n    [\n      \"Tensor (Tensor input, Tensor depend) => Depend\",\n      \"Tensor (Tensor input, TensorTuple depends) => 
DependTuple\",\n    ]\n  bind_python: True\n\n- name: \"roc_auc_score\"\n  signature: \"Tensor (Tensor label, Tensor pred) => RocAucScore\"\n  bind_python: True\n\n- name: \"pin_memory\"\n  signature: \"Tensor (Tensor input) => PinMemory\"\n  bind_python: True\n\n- name: \"fill_\"\n  signature:\n    [\n      \"Tensor (Tensor in, Tensor value) => FillTensor\",\n      \"Tensor (Tensor in, Scalar value) => Fill\",\n    ]\n  bind_python: True\n\n- name: \"index_add\"\n  signature: \"Tensor (Tensor input, Int64 dim, Tensor index, Tensor source, Scalar alpha=1.0) => IndexAdd\"\n  bind_python: True\n\n- name: \"index_add_\"\n  signature: \"Tensor (Tensor input, Int64 dim, Tensor index, Tensor source, Scalar alpha=1.0) => IndexAddInplace\"\n  bind_python: True\n\n- name: \"rnn_tanh_cell\"\n  signature: \"Tensor (Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih=None, Tensor b_hh=None) => RnnTanhCell\"\n  bind_python: True\n\n- name: \"rnn_relu_cell\"\n  signature: \"Tensor (Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih=None, Tensor b_hh=None) => RnnReluCell\"\n  bind_python: True\n\n- name: \"lstm_cell\"\n  signature: \"TensorTuple (Tensor input, TensorTuple hx, Tensor w_ih, Tensor w_hh, Tensor b_ih=None, Tensor b_hh=None) => LstmCell\"\n  bind_python: True\n\n- name: \"gru_cell\"\n  signature: \"Tensor (Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih=None, Tensor b_hh=None) => GruCell\"\n  bind_python: True\n\n- name: \"_fused_gru_cell\"\n  signature: \"TensorTuple (Tensor igates, Tensor hgates, Tensor hx, Tensor b_ih=None, Tensor b_hh=None) => FusedGruCell\"\n  bind_python: False\n\n- name: \"_fused_gru_cell_grad\"\n  signature: \"TensorTuple (Tensor grad_hy, Tensor workspace, Bool has_bias, Bool hx_needs_grad) => FusedGruCellGrad\"\n  bind_python: False\n\n- name: \"_fused_lstm_cell\"\n  signature: \"TensorTuple (Tensor igates, Tensor hgates, Tensor cx, Tensor b_ih=None, Tensor b_hh=None) => FusedLstmCell\"\n  bind_python: False\n\n- name: \"_fused_lstm_cell_grad\"\n  signature: \"TensorTuple (Tensor grad_hy, Tensor grad_cy, Tensor cx, Tensor cy, Tensor workspace, Bool need_cx_grad, Bool has_bias) => FusedLstmCellGrad\"\n  bind_python: False\n\n- name: \"rnn_tanh\"\n  signature:\n    [\n      \"TensorTuple (Tensor input, Tensor hx, TensorTuple params, Bool has_biases, Int32 num_layers, Float dropout, Bool train, Bool bidirectional, Bool batch_first) => RnnTanhInput\",\n      \"TensorTuple (Tensor data, Tensor batch_sizes, Tensor hx, TensorTuple params, Bool has_biases, Int32 num_layers, Float dropout, Bool train, Bool bidirectional) => RnnTanhData\",\n    ]\n  bind_python: True\n\n- name: \"rnn_relu\"\n  signature:\n    [\n      \"TensorTuple (Tensor input, Tensor hx, TensorTuple params, Bool has_biases, Int32 num_layers, Float dropout, Bool train, Bool bidirectional, Bool batch_first) => RnnReluInput\",\n      \"TensorTuple (Tensor data, Tensor batch_sizes, Tensor hx, TensorTuple params, Bool has_biases, Int32 num_layers, Float dropout, Bool train, Bool bidirectional) => RnnReluData\",\n    ]\n  bind_python: True\n\n- name: \"lstm\"\n  signature:\n    [\n      \"TensorTuple (Tensor input, TensorTuple hx, TensorTuple params, Bool has_biases, Int32 num_layers, Float dropout, Bool train, Bool bidirectional, Bool batch_first) => LstmInput\",\n      \"TensorTuple (Tensor data, Tensor batch_sizes, TensorTuple hx, TensorTuple params, Bool has_biases, Int32 num_layers, Float dropout, Bool train, Bool bidirectional) => LstmData\",\n    ]\n  bind_python: 
True\n\n- name: \"gru\"\n  signature:\n    [\n      \"TensorTuple (Tensor input, Tensor hx, TensorTuple params, Bool has_biases, Int32 num_layers, Float dropout, Bool train, Bool bidirectional, Bool batch_first) => GruInput\",\n      \"TensorTuple (Tensor data, Tensor batch_sizes, Tensor hx, TensorTuple params, Bool has_biases, Int32 num_layers, Float dropout, Bool train, Bool bidirectional) => GruData\",\n    ]\n  bind_python: True\n\n- name: \"pack_padded_sequence\"\n  signature: \"TensorTuple (Tensor input, Tensor lengths, Bool batch_first) => PackPaddedSequence\"\n  bind_python: True\n\n- name: \"multi_tensor_sgd_update\"\n  signature: \"Void (TensorTuple model, TensorTuple model_diff, Double scale, Float weight_decay, Float learning_rate_val) => MultiTensorSgdUpdate\"\n  bind_python: True\n\n- name: \"multi_tensor_yolov5_weight_update\"\n  signature: \"Void (TensorTuple model, TensorTuple model_update, Float d) => MultiTensorYoloV5WeightUpdate\"\n  bind_python: True\n\n- name: \"multi_tensor_momentum_update\"\n  signature: \"Void (TensorTuple model, TensorTuple model_diff, TensorTuple momentum_buf, Double scale, Float weight_decay, Float learning_rate_val, Float momentum, Float dampening, Bool nesterov, Bool maximize) => MultiTensorMomentumUpdate\"\n  bind_python: True\n\n- name: \"multi_tensor_adam_update\"\n  signature: \"Void (TensorTuple model, TensorTuple model_diff, TensorTuple m, TensorTuple v, Float learning_rate_val, Float l2, Float beta1, Float beta2, Float bias_correction1_val, Float bias_correction2_val, Bool do_bias_correction, Double scale, Float weight_decay, Float epsilon) => MultiTensorAdamUpdate\"\n  bind_python: True\n\n- name: \"grad_acc_repeat\"\n  signature: \"Tensor (Tensor input, Int32 repeat_num) => GradAccRepeat\"\n  bind_python: False\n\n- name: \"grad_acc_collect\"\n  signature: \"Tensor (Tensor input, Int32 collect_num) => GradAccCollect\"\n  bind_python: False\n\n- name: \"grad_acc_pack\"\n  signature: \"Tensor (Tensor input, Int32 pack_num) => GradAccPack\"\n  bind_python: False\n\n- name: \"grad_acc_unpack\"\n  signature: \"Tensor (Tensor input, Int32 unpack_num) => GradAccUnpack\"\n  bind_python: False\n\n- name: \"trunc\"\n  signature: \"Tensor (Tensor input) => Trunc\"\n  bind_python: True\n\n- name: \"silu_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) => SiluGradGrad\"\n  bind_python: False\n\n- name: \"mish_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) => MishGradGrad\"\n  bind_python: False\n\n- name: \"selu_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) => SeluGradGrad\"\n  bind_python: False\n\n- name: \"softsign_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) => SoftSignGradGrad\"\n  bind_python: False\n\n- name: \"gelu_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) => GeluGradGrad\"\n  bind_python: False\n\n- name: \"hardsigmoid_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) => HardSigmoidGradGrad\"\n  bind_python: False\n\n- name: \"hardswish_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) => HardSwishGradGrad\"\n  bind_python: False\n\n- name: \"softplus_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx, Double beta=1.0, Double threshold=20.0) => SoftplusGradGrad\"\n  bind_python: False\n\n- name: \"elu_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx, Double alpha) => EluGradGrad\"\n  bind_python: False\n\n- name: \"celu_grad_grad\"\n  signature: \"Tensor (Tensor y, Tensor dydx, Double alpha) => CeluGradGrad\"\n  bind_python: False\n\n- 
name: \"batch_norm_stats\"\n  signature: \"TensorTuple (Tensor input, Int32 axis, Float eps) => BatchNormStats\"\n  bind_python: True\n\n- name: \"batch_norm_gather_stats_with_counts\"\n  signature: \"TensorTuple (Tensor input, Tensor mean, Tensor invstd, Tensor running_mean=None, Tensor running_var=None, Float momentum, Float eps, Tensor counts) => BatchNormGatherStatsWithCounts\"\n  bind_python: True\n\n- name: \"batch_norm_elemt\"\n  signature: \"Tensor (Tensor input, Tensor weight, Tensor bias, Tensor mean, Tensor invstd, Int32 axis, Float eps) => BatchNormElemt\"\n  bind_python: True\n\n- name: \"batch_norm_backward_reduce\"\n  signature: \"TensorTuple (Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Int32 axis) => BatchNormBackwardReduce\"\n  bind_python: True\n\n- name: \"batch_norm_backward_elemt\"\n  signature: \"Tensor (Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count, Int32 axis) => BatchNormBackwardElemt\"\n  bind_python: True\n\n- name: \"adaptive_max_pool1d\"\n  signature: 'TensorTuple (Tensor input, Int64List[1] output_size, String data_format=\"channels_first\") => AdaptiveMaxPool1D'\n  bind_python: True\n\n- name: \"adaptive_max_pool2d\"\n  signature: 'TensorTuple (Tensor input, Int64List[2] output_size, String data_format=\"channels_first\") => AdaptiveMaxPool2D'\n  bind_python: True\n\n- name: \"adaptive_max_pool3d\"\n  signature: 'TensorTuple (Tensor input, Int64List[3] output_size, String data_format=\"channels_first\") => AdaptiveMaxPool3D'\n  bind_python: True\n\n- name: \"adaptive_max_pool_grad\"\n  signature: 'Tensor (Tensor x, Tensor index, Tensor dy, Int32 ndims, String data_format=\"channels_first\") => AdaptiveMaxPoolNdGrad'\n  bind_python: False\n\n- name: \"tan_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  TanGradGrad\"\n  bind_python: False\n\n- name: \"sinh_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  SinhGradGrad\"\n  bind_python: False\n\n- name: \"cosh_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  CoshGradGrad\"\n  bind_python: False\n\n- name: \"tanh_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  TanhGradGrad\"\n  bind_python: False\n\n- name: \"acos_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  AcosGradGrad\"\n  bind_python: False\n\n- name: \"asin_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  AsinGradGrad\"\n  bind_python: False\n\n- name: \"atan_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  AtanGradGrad\"\n  bind_python: False\n\n- name: \"asinh_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  AsinhGradGrad\"\n  bind_python: False\n\n- name: \"acosh_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  AcoshGradGrad\"\n  bind_python: False\n\n- name: \"atanh_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  AtanhGradGrad\"\n  bind_python: False\n\n- name: \"erf_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  ErfGradGrad\"\n  bind_python: False\n\n- name: \"erfc_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  ErfcGradGrad\"\n  bind_python: False\n\n- name: \"exp_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  ExpGradGrad\"\n  bind_python: False\n\n- name: \"exp2_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  Exp2GradGrad\"\n  bind_python: False\n\n- name: \"expm1_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  Expm1GradGrad\"\n  bind_python: 
False\n\n- name: \"log_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  LogGradGrad\"\n  bind_python: False\n\n- name: \"logsigmoid_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  LogSigmoidGradGrad\"\n  bind_python: False\n\n- name: \"log2_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  Log2GradGrad\"\n  bind_python: False\n\n- name: \"log10_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  Log10GradGrad\"\n  bind_python: False\n\n- name: \"log1p_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  Log1pGradGrad\"\n  bind_python: False\n\n- name: \"reciprocal_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  ReciprocalGradGrad\"\n  bind_python: False\n\n- name: \"reciprocal_no_nan_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  ReciprocalNoNanGradGrad\"\n  bind_python: False\n\n- name: \"rsqrt_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  RsqrtGradGrad\"\n  bind_python: False\n\n- name: \"sqrt_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  SqrtGradGrad\"\n  bind_python: False\n\n- name: \"square_grad_grad\"\n  signature: \"Tensor (Tensor x, Tensor dydx) =>  SquareGradGrad\"\n  bind_python: False\n\n- name: \"sigmoid_grad_grad\"\n  signature: \"Tensor (Tensor y, Tensor dydx) =>  SigmoidGradGrad\"\n  bind_python: False\n\n- name: \"max_pool_grad_grad\"\n  signature: \"Tensor (Tensor dydx, Tensor indices, Int32 ndims) =>  MaxPoolNdGradGrad\"\n  bind_python: False\n\n- name: \"exponential_\"\n  signature: \"Tensor (Tensor x, Float lambd=1.0, Generator generator=None) => Exponential\"\n  bind_python: True\n\n- name: \"multinomial\"\n  signature: \"Tensor (Tensor x, Int32 num_samples, Bool replacement=False, Generator generator=None) => Multinomial\"\n  bind_python: True\n\n- name: \"max_pool_grad_grad\"  \n  signature: \"Tensor (Tensor dydx, Tensor indices, Int32 ndims) =>  MaxPoolNdGradGrad\" \n  bind_python: False\n\n- name: \"deform_conv2d\"\n  signature:\n    \"Tensor (Tensor input,Tensor weight,Tensor offset,Tensor mask,Tensor bias=None, Int32 stride_h,Int32 stride_w,Int32 pad_h,\n    Int32 pad_w,Int32 dilation_h,Int32 dilation_w,Int32 groups,Int32 offset_groups,Bool use_mask) => DeformConv2d\"\n  bind_python: True\n\n- name: \"deform_conv2d_input_grad\"\n  signature:\n    \"TensorTuple (Tensor output_grad,Tensor input,Tensor weight,Tensor offset,Tensor mask=None, Int32 stride_h,Int32 stride_w,Int32 pad_h,\n    Int32 pad_w,Int32 dilation_h,Int32 dilation_w,Int32 groups,Int32 offset_groups,Bool use_mask) => DeformConv2dInputGrad\"\n  bind_python: False\n\n- name: \"deform_conv2d_param_grad\"\n  signature:\n    \"Tensor (Tensor output_grad,Tensor input,Tensor weight,Tensor offset,Tensor mask, Int32 stride_h,Int32 stride_w,Int32 pad_h,\n    Int32 pad_w,Int32 dilation_h,Int32 dilation_w,Int32 groups,Int32 offset_groups,Bool use_mask) => DeformConv2dParamGrad\"\n  bind_python: False\n\n- name: \"broadcast_shapes\"\n  signature: \"Shape (ShapeList shapes) => BroadcastShapes\"\n  bind_python: True\n\n- name: \"broadcast_tensors\"\n  signature: \"TensorTuple (TensorTuple tensors) => BroadcastTensors\"\n  bind_python: True\n\n- name: \"broadcast_to\"\n  signature: \"Tensor (Tensor x, Shape shape) => BroadcastTo\"\n  bind_python: True\n\n- name: \"bincount\"\n  signature: \"Tensor (Tensor input, Tensor weights=None, Int64 minlength=None) => BinCount\"\n  bind_python: True\n\n- name: \"stft\"\n  signature: \n    'Tensor (Tensor input, Int64 n_fft,Int64 hop_length=None, 
\n- name: \"stft\"\n  signature:\n    'Tensor (Tensor input, Int64 n_fft, Int64 hop_length=None, Int64 win_length=None, Tensor window=None, Bool center=True, String pad_mode=\"reflect\", Bool normalized=False, Bool onesided=True, Bool return_complex=False) => Stft'\n  bind_python: True\n\n- name: \"fft_c2c\"\n  signature:\n    'Tensor (Tensor input, Int64List n=None, Int64List dims=None, Int32 norm_mode=0, Bool forward=True, Bool normalized=False) => FftC2C'\n  bind_python: False\n\n- name: \"fft_r2c\"\n  signature:\n    'Tensor (Tensor input, Int64List n=None, Int64List dims=None, Int32 norm_mode=0, Bool onesided=False, Bool forward=True, Bool normalized=False) => FftR2C'\n  bind_python: False\n\n- name: \"fft_c2r\"\n  signature:\n    'Tensor (Tensor input, Int64List n=None, Int64List dims=None, Int32 norm_mode=0, Bool forward=True, Bool normalized=False) => FftC2R'\n  bind_python: False\n\n- name: \"fft\"\n  signature:\n    'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => Fft'\n  bind_python: True\n\n- name: \"ifft\"\n  signature:\n    'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => IFft'\n  bind_python: True\n\n- name: \"fft2\"\n  signature:\n    'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => Fft2'\n  bind_python: True\n\n- name: \"ifft2\"\n  signature:\n    'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => IFft2'\n  bind_python: True\n\n- name: \"fftn\"\n  signature:\n    'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => FftN'\n  bind_python: True\n\n- name: \"ifftn\"\n  signature:\n    'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => IFftN'\n  bind_python: True\n\n- name: \"rfft\"\n  signature:\n    'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => RFft'\n  bind_python: True\n\n- name: \"irfft\"\n  signature:\n    'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => IRFft'\n  bind_python: True\n\n- name: \"rfft2\"\n  signature:\n    'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => RFft2'\n  bind_python: True\n\n- name: \"irfft2\"\n  signature:\n    'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => IRFft2'\n  bind_python: True\n\n- name: \"rfftn\"\n  signature:\n    'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => RFftN'\n  bind_python: True\n\n- name: \"irfftn\"\n  signature:\n    'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => IRFftN'\n  bind_python: True\n\n- name: \"hfft\"\n  signature:\n    'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => HFft'\n  bind_python: True\n\n- name: \"ihfft\"\n  signature:\n    'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => IHFft'\n  bind_python: True\n\n- name: \"hfft2\"\n  signature:\n    'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => HFft2'\n  bind_python: True\n\n- name: \"ihfft2\"\n  signature:\n    'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => IHFft2'\n  bind_python: True\n\n- name: \"hfftn\"\n  signature:\n    'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => HFftN'\n  bind_python: True\n\n- name: \"ihfftn\"\n  signature:\n    'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => IHFftN'\n  bind_python: True\n
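\n# NOTE: only the high-level fft/ifft/rfft/... wrappers above are Python-visible\n# (bind_python: True); fft_c2c / fft_r2c / fft_c2r stay internal and are the\n# complex-to-complex, real-to-complex and complex-to-real primitives the wrappers\n# are expected to lower onto. Hypothetical usage, assuming the oneflow._C prefix:\n#   import oneflow as flow\n#   flow._C.rfft(flow.randn(8))  # real input -> one-sided complex spectrum\n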
IsClose\"\n  bind_python: True\n\n- name: \"uniform_\"\n  signature: \"Tensor (Tensor x,Scalar from, Scalar to) => InplaceUniform\"\n  bind_python: True\n\n- name: \"fused_fast_gelu_mul\"\n  signature: \"Tensor (Tensor x, Tensor multiplier) => FusedFastGeluMul\"\n  bind_python: True\n\n- name: \"fused_fast_gelu_mul_grad\"\n  signature: \"TensorTuple (Tensor dy, Tensor x, Tensor multiplier) => FusedFastGeluMulGrad\"\n  bind_python: False\n\n- name: \"unique\"\n  signature: [\n    \"Tensor (Tensor x, Bool sorted=True, DataType dtype=kInt32) => Unique\",\n    \"TensorTuple (Tensor x, Bool sorted=True, Bool return_inverse=False, Bool return_counts=False, DataType dtype=kInt32) => UniqueWithCounts\"\n  ]\n  bind_python: True\n\n- name: \"fused_weighted_sum\"\n  signature: \"Tensor (TensorTuple in, FloatList weights, Float alpha=1.0) => FusedWeightedSum\"\n  bind_python: True\n\n- name: \"sort\"\n  signature: \"TensorTuple[values, indices] (Tensor input, Int32 dim=-1, Bool descending=False) => Sort\"\n  bind_python: True\n  \n- name: \"throw_error\"\n  signature: \"Tensor (Tensor input) => ThrowError\"\n  bind_python: True\n\n- name: \"mode\"\n  signature: \"TensorTuple[values, indices] (Tensor input, Int32 dim=-1, Bool keepdim=False) => Mode\"\n  bind_python: True\n\n- name: \"clone\"\n  signature: \"Tensor (Tensor input) => Clone\"\n  bind_python: True\n\n- name: \"real\"\n  signature: \"Tensor (Tensor x) => Real\"\n  bind_python: True\n\n- name: \"real_grad\"\n  signature: \"Tensor (Tensor dout) => RealGrad\"\n  bind_python: False\n\n- name: \"imag\"\n  signature: \"Tensor (Tensor x) => Imag\"\n  bind_python: True\n\n- name: \"imag_grad\"\n  signature: \"Tensor (Tensor dout) => ImagGrad\"\n  bind_python: False\n\n- name: \"conj\"\n  signature: \"Tensor (Tensor x) => Conj\"\n  bind_python: True\n\n- name: \"conj_physical\"\n  signature: \"Tensor (Tensor x) => ConjPhysical\"\n  bind_python: True\n\n- name: \"frac\"\n  signature: \"Tensor (Tensor x) => Frac\"\n  bind_python: True\n\n- name: \"frac_\"\n  signature: \"Tensor (Tensor x) => FracInplace\"\n  bind_python: True\n\n- name: \"digamma\"\n  signature: \"Tensor (Tensor x) => Digamma\"\n  bind_python: True\n\n- name: \"digamma_grad\"\n  signature: \"Tensor (Tensor x, Tensor dy) => DigammaGrad\"\n  bind_python: False\n\n- name: \"trigamma\"\n  signature: \"Tensor (Tensor x) => Trigamma\"\n  bind_python: False\n  \n- name: \"zeta\"\n  signature: [\n    \"Tensor (Tensor x, Tensor other) => BroadcastZeta\",\n    \"Tensor (Scalar x, Tensor other) => ZetaScalarTensor\",\n    \"Tensor (Tensor x, Scalar other) => ZetaTensorScalar\",\n  ]\n  bind_python: True\n\n- name: \"fused_clip_grad\"\n  signature: \"Tensor (TensorTuple model_diff, Float max_norm, Float norm_type) => FusedClipGrad\"\n  bind_python: True\n"
  },
  {
    "path": "oneflow/core/functional/impl/activation_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/error.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/impl/unary_functor.h\"\n#include \"oneflow/core/functional/impl/binary_functor.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_util.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/autograd/autograd_mode.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/user/kernels/distributions/common.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nclass ReluFunctor {\n public:\n  ReluFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"relu\").Input(\"x\", 1).Output(\"y\", 1).Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, bool inplace) const {\n    if (inplace) {\n      JUST(CheckInplaceValid(x));\n      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n      outputs->at(0) = x;\n      JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), AttrMap{}));\n      return outputs->at(0);\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {x});\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ReluGradFunctor : public BinaryFunctor {\n public:\n  ReluGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"relu_grad\").Input(\"dy\").Input(\"y\").Output(\"dx\").Build());\n  }\n};\n\nclass PReluFunctor {\n public:\n  PReluFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"prelu\").Input(\"x\").Input(\"alpha\").Output(\"y\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& alpha) const {\n    int num_params = alpha->dim(0);\n    CHECK_OR_RETURN(((num_params == 1) || (num_params == x->shape()->At(1))))\n        << Error::RuntimeError() << \"num_parameters in prelu must be 1 or \" << x->shape()->At(1);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, alpha});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass PReluGradFunctor {\n public:\n  PReluGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"prelu_grad\")\n                         .Input(\"dy\")\n                         .Input(\"x\")\n                         .Input(\"alpha\")\n                         
.Output(\"dx\")\n                         .Output(\"alpha_diff\")\n                         .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<Tensor>& dy, const std::shared_ptr<Tensor>& x,\n                                const std::shared_ptr<Tensor>& alpha) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"alpha_requires_grad\");\n    attrs.SetAllAttrs(alpha->requires_grad());\n    return OpInterpUtil::Dispatch<one::TensorTuple>(*op_, {dy, x, alpha}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass HardTanhFunctor {\n public:\n  HardTanhFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"hardtanh\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const double& min_val,\n                           const double& max_val) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"min_val\", \"max_val\");\n    attrs.SetAllAttrs(min_val, max_val);\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass HardTanhGradFunctor {\n public:\n  HardTanhGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"hardtanh_grad\").Input(\"y\").Input(\"dy\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& y,\n                           const std::shared_ptr<one::Tensor>& dy, const double& min_val,\n                           const double& max_val) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"min_val\", \"max_val\");\n    attrs.SetAllAttrs(min_val, max_val);\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {y, dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass EluFunctor {\n public:\n  EluFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"elu\").Input(\"in\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const double& alpha) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"alpha\");\n    attrs.SetAllAttrs(alpha);\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass EluGradFunctor {\n public:\n  EluGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"elu_grad\").Input(\"x\").Input(\"dy\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& dy, const double& alpha) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"alpha\");\n    attrs.SetAllAttrs(alpha);\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x, dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass CeluFunctor {\n public:\n  CeluFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"celu\").Input(\"in\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const double& alpha,\n                           bool inplace) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"alpha\");\n    attrs.SetAllAttrs(alpha);\n    if (inplace) {\n      JUST(CheckInplaceValid(x));\n      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n      (*outputs)[0] = x;\n      JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs));\n      return outputs->at(0);\n    } else {\n      return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass CeluGradFunctor {\n public:\n  
CeluGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"celu_grad\").Input(\"y\").Input(\"dy\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& y,\n                           const std::shared_ptr<one::Tensor>& dy, const double& alpha) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"alpha\");\n    attrs.SetAllAttrs(alpha);\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {y, dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GeluFunctor : public UnaryFunctor {\n public:\n  GeluFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"gelu\").Input(\"in\").Output(\"out\").Build()); }\n};\n\nclass GeluGradFunctor : public BinaryFunctor {\n public:\n  GeluGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"gelu_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n};\n\nclass FastGeluFunctor : public UnaryFunctor {\n public:\n  FastGeluFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fast_gelu\").Input(\"in\").Output(\"out\").Build());\n  }\n};\n\nclass FastGeluGradFunctor : public BinaryFunctor {\n public:\n  FastGeluGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fast_gelu_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n};\n\nclass QuickGeluFunctor : public UnaryFunctor {\n public:\n  QuickGeluFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"quick_gelu\").Input(\"x\").Output(\"y\").Build());\n  }\n};\n\nclass QuickGeluGradFunctor : public BinaryFunctor {\n public:\n  QuickGeluGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"quick_gelu_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n};\n\nclass SquareReLUFunctor : public UnaryFunctor {\n public:\n  SquareReLUFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"square_relu\").Input(\"x\").Output(\"y\").Build());\n  }\n};\n\nclass SquareReLUGradFunctor : public BinaryFunctor {\n public:\n  SquareReLUGradFunctor() {\n    op_ =\n        CHECK_JUST(one::OpBuilder(\"square_relu_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n};\n\nclass GluFunctor {\n public:\n  GluFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, int64_t dim) const {\n    const auto ndim = input->ndim();\n    CHECK_GT_OR_RETURN(ndim, 0) << Error::RuntimeError()\n                                << \"glu does not support scalars because halving size must be even\";\n    dim = JUST(maybe_wrap_dim(dim, ndim));\n    if (dim < 0) { dim += ndim; }\n    int64_t nc = input->dim(dim);\n    CHECK_EQ_OR_RETURN(nc % 2, 0) << Error::RuntimeError()\n                                  << \"Halving dimension must be even, but dimension \" << dim\n                                  << \" is size \" << nc;\n    nc = nc / 2;\n    std::vector<int64_t> split_sizes(2, nc);\n    const auto split_x = JUST(SplitWithSize(input, split_sizes, dim));\n    return sequence_function(functional::Sigmoid)\n        .then(std::bind(functional::Mul, (*split_x)[0], std::placeholders::_1))\n        .call((*split_x)[1]);\n  }\n};\n\nclass HardSigmoidFunctor {\n public:\n  HardSigmoidFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"hardsigmoid\").Input(\"in\").Output(\"out\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input, bool inplace) const {\n    if (inplace) {\n      JUST(CheckInplaceValid(input));\n      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n      outputs->at(0) = input;\n      JUST(OpInterpUtil::Dispatch(*op_, {input}, outputs.get(), AttrMap{}));\n     
 return outputs->at(0);\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {input});\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\nclass HardSigmoidGradFunctor : public BinaryFunctor {\n public:\n  HardSigmoidGradFunctor() {\n    op_ =\n        CHECK_JUST(one::OpBuilder(\"hardsigmoid_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n};\n\nclass HardShrinkFunctor {\n public:\n  HardShrinkFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"hardshrink\").Input(\"in\").Output(\"out\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, const double& lambd,\n                           bool inplace) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"lambd\");\n    attrs.SetAllAttrs(lambd);\n    if (inplace) {\n      JUST(CheckInplaceValid(x));\n      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n      JUST(oneflow::VectorAt(*outputs, 0)) = x;\n      JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs));\n      return JUST(oneflow::VectorAt(*outputs, 0));\n    } else {\n      return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass HardShrinkGradFunctor {\n public:\n  HardShrinkGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"hardshrink_grad\").Input(\"y\").Input(\"dy\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& y, const std::shared_ptr<Tensor>& dy,\n                           const double& lambd) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"lambd\");\n    attrs.SetAllAttrs(lambd);\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {y, dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SoftmaxFunctorBase {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const Optional<int64_t>& dim) const {\n    const auto input_shape = input->shape();\n    const int64_t num_axes = input_shape->NumAxes();\n\n    const auto get_dim = [num_axes]() -> int64_t {\n      const int64_t ndim = num_axes;\n      if (ndim == 0 || ndim == 1 || ndim == 3) {\n        return 0;\n      } else {\n        return 1;\n      }\n    };\n\n    int64_t dim_ = dim ? 
JUST(dim) : get_dim();\n    dim_ = JUST(maybe_wrap_dim(dim_, num_axes));\n    if (dim_ != num_axes - 1) {\n      std::vector<int> input_perm(input_shape->dim_vec().size(), 0);\n      for (size_t i = 1; i < input_perm.size(); ++i) { input_perm[i] = i; }\n      input_perm[dim_] = input_perm[input_perm.size() - 1];\n      input_perm[input_perm.size() - 1] = dim_;\n\n      return sequence_function(functional::Transpose)\n          .then([&](const std::shared_ptr<one::Tensor>& x) {\n            return OpInterpUtil::Dispatch<Tensor>(*op_, {x});\n          })\n          .then(std::bind(functional::Transpose, std::placeholders::_1, input_perm))\n          .call(input, input_perm);\n    }\n\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input});\n  }\n\n protected:\n  SoftmaxFunctorBase() = default;\n  virtual ~SoftmaxFunctorBase() = default;\n\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SoftmaxFunctor : public SoftmaxFunctorBase {\n public:\n  SoftmaxFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"softmax\").Input(\"in\").Output(\"out\").Build());\n  }\n};\n\nclass SoftmaxGradFunctor {\n public:\n  SoftmaxGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"softmax_grad\").Input(\"y\").Input(\"dy\").Output(\"dx\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {y, dy});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass LogSoftmaxFunctor : public SoftmaxFunctorBase {\n public:\n  LogSoftmaxFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"log_softmax\").Input(\"in\").Output(\"prob\").Build());\n  }\n};\n\nclass LogSoftmaxGradFunctor {\n public:\n  LogSoftmaxGradFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"log_softmax_grad\").Input(\"prob\").Input(\"dy\").Output(\"dx\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {y, dy});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GumbelSoftmaxFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in, const double& tau,\n                           const Optional<int64_t>& dim, bool hard,\n                           const Optional<one::Generator>& generator) const {\n    auto in_shape = in->shape();\n    auto device = JUST(in->device());\n    auto dtype = in->dtype();\n    const int64_t num_axes = in_shape->NumAxes();\n\n    const auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n    auto random_tensor =\n        JUST(functional::Rand(*in_shape.get(), dtype, device, gen, /*requires_grad=*/false));\n    // Sample Gumbel noise g = -log(-log(u)) with u ~ Uniform(0, 1).\n    auto gumbel_noise_tensor = JUST(functional::ScalarSub(\n        Scalar(0.0),\n        JUST(functional::Log(JUST(functional::ScalarSub(\n            Scalar(0.0), JUST(functional::Log(random_tensor)), /*alpha=*/1.0)))),\n        /*alpha=*/1.0));\n    auto gumbel_in_tensor = JUST(functional::ScalarDiv(\n        JUST(functional::Add(in, gumbel_noise_tensor, /*alpha=*/1.0, /*inplace=*/false)),\n        Scalar(tau)));\n\n    auto out_soft = JUST(functional::Softmax(gumbel_in_tensor, dim));\n    if (hard) {\n      const auto get_dim = [num_axes]() -> int64_t {\n        const int64_t ndim = num_axes;\n        if (ndim == 0 || ndim == 1 || ndim == 3) {\n          return 0;\n        } else {\n          return 1;\n        }\n      };\n\n      int64_t dim_ 
= dim ? JUST(dim) : get_dim();\n      dim_ = JUST(maybe_wrap_dim(dim_, num_axes));\n      auto out_max = JUST(functional::ArgMax(out_soft, dim_, /*keepdim=*/true, dtype));\n      auto index =\n          JUST(functional::To(out_max, JUST(DType::Get(DataType::kInt64)), /*copy=*/false));\n      auto zero = JUST(functional::ZerosLike(out_soft));\n      auto out_hard =\n          JUST(functional::DimScatterUpdateScalar(zero, dim_, index, 1.0, /*inplace=*/false));\n\n      // Straight-through trick: the forward value equals the one-hot out_hard, while the\n      // gradient flows through out_soft.\n      auto out_hard_has_grad =\n          functional::Add(JUST(functional::Sub(out_hard, JUST(out_soft->detach()), /*alpha=*/1.0,\n                                               /*inplace=*/false)),\n                          out_soft, /*alpha=*/1.0, /*inplace=*/false);\n      return out_hard_has_grad;\n    } else {\n      return out_soft;\n    }\n  }\n};\n\nclass HardSwishFunctor : public UnaryFunctor {\n public:\n  HardSwishFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"hardswish\").Input(\"in\").Output(\"out\").Build());\n  }\n};\n\nclass HardSwishGradFunctor : public BinaryFunctor {\n public:\n  HardSwishGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"hardswish_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n};\n\nclass LeakyReluFunctor {\n public:\n  LeakyReluFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"leaky_relu\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const float& alpha,\n                           bool inplace) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"alpha\");\n    attrs.SetAllAttrs(alpha);\n    if (inplace) {\n      JUST(CheckInplaceValid(x));\n      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n      JUST(oneflow::VectorAt(*outputs, 0)) = x;\n      JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs));\n      return JUST(oneflow::VectorAt(*outputs, 0));\n    } else {\n      return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass LeakyReluGradFunctor {\n public:\n  LeakyReluGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"leaky_relu_grad\").Input(\"x\").Input(\"dy\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& dy, const float& alpha) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"alpha\");\n    attrs.SetAllAttrs(alpha);\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x, dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass RReluFunctor {\n public:\n  RReluFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"rrelu\").Input(\"in\").Output(\"output\").Output(\"noise_data\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const float& lower,\n                           const float& upper, bool training, bool inplace) const {\n    if (!training) { return JUST(functional::LeakyRelu(x, ((lower + upper) / 2), inplace)); }\n\n    auto gen = JUST(\n        GetGeneratorForLazyOrGlobal(JUST(one::DefaultAutoGenerator()), LazyMode::is_enabled(), x));\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"seed\", \"lower\", \"upper\", \"training\");\n    attrs.SetAllAttrs(static_cast<int64_t>(gen->current_seed()), lower, upper, training);\n    const auto& state = std::make_shared<DistributionKernelState>(gen);\n\n    OpExprInterpContext ctx(attrs, state);\n    std::shared_ptr<TensorTuple> outputs = 
std::make_shared<TensorTuple>(2);\n    if (inplace) {\n      JUST(CheckInplaceValid(x));\n      outputs->at(0) = x;\n    }\n    JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), ctx));\n    return outputs->at(0);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass RReluInplaceFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const float& lower,\n                           const float& upper, bool training) const {\n    return JUST(functional::RRelu(x, lower, upper, training, true /*inplace*/));\n  }\n};\n\nclass SoftplusFunctor {\n public:\n  SoftplusFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"softplus\").Input(\"in\").Output(\"out\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, const double& beta,\n                           const double& threshold) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"beta\", \"threshold\");\n    attrs.SetAllAttrs(beta, threshold);\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SoftplusGradFunctor {\n public:\n  SoftplusGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"softplus_grad\").Input(\"x\").Input(\"dy\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, const std::shared_ptr<Tensor>& dy,\n                           const double& beta, const double& threshold) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"beta\", \"threshold\");\n    attrs.SetAllAttrs(beta, threshold);\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x, dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SiluFunctor : public UnaryFunctor {\n public:\n  SiluFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"silu\").Input(\"in\").Output(\"out\").Build()); }\n};\n\nclass SiluGradFunctor : public BinaryFunctor {\n public:\n  SiluGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"silu_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n};\n\nclass MishFunctor : public UnaryFunctor {\n public:\n  MishFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"mish\").Input(\"in\").Output(\"out\").Build()); }\n};\n\nclass MishGradFunctor : public BinaryFunctor {\n public:\n  MishGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"mish_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n};\n\nclass SeluFunctor : public UnaryFunctor {\n public:\n  SeluFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"selu\").Input(\"in\").Output(\"out\").Build()); }\n};\n\nclass SeluGradFunctor : public BinaryFunctor {\n public:\n  SeluGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"selu_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n};\n\nclass SoftSignFunctor : public UnaryFunctor {\n public:\n  SoftSignFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"softsign\").Input(\"in\").Output(\"out\").Build());\n  }\n};\n\nclass SoftSignGradFunctor : public BinaryFunctor {\n public:\n  SoftSignGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"softsign_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n};\n\nclass SoftShrinkFunctor {\n public:\n  SoftShrinkFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"softshrink\").Input(\"in\").Output(\"out\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, const double& alpha,\n                           bool inplace) const {\n    CHECK_GE_OR_RETURN(alpha, 0) << Error::RuntimeError()\n                                 
<< \"alpha must be greater or equal to 0, but found to be \" << alpha\n                                 << \".\";\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"alpha\");\n    attrs.SetAllAttrs(alpha);\n    if (inplace) {\n      JUST(CheckInplaceValid(x));\n      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n      JUST(oneflow::VectorAt(*outputs, 0)) = x;\n      JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs));\n      return JUST(oneflow::VectorAt(*outputs, 0));\n    } else {\n      return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ThresholdFunctor {\n public:\n  ThresholdFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"threshold\").Input(\"in\").Output(\"out\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, const double& threshold,\n                           const double& value) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"threshold_val\", \"value\");\n    attrs.SetAllAttrs(threshold, value);\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ThresholdGradFunctor {\n public:\n  ThresholdGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"threshold_grad\").Input(\"x\").Input(\"dy\").Output(\"dx\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, const std::shared_ptr<Tensor>& dy,\n                           const double& threshold) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"threshold_val\");\n    attrs.SetAllAttrs(threshold);\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x, dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SoftShrinkGradFunctor {\n public:\n  SoftShrinkGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"softshrink_grad\").Input(\"y\").Input(\"dy\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& y, const std::shared_ptr<Tensor>& dy,\n                           const double& alpha) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"alpha\");\n    attrs.SetAllAttrs(alpha);\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {y, dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FracFunctor {\n public:\n  FracFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"frac\").Input(\"x\").Output(\"y\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FracInplaceFunctor {\n public:\n  FracInplaceFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"frac\").Input(\"x\").Output(\"y\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    JUST(CheckInplaceValid(x));\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    outputs->at(0) = x;\n    JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), AttrMap{}));\n    return outputs->at(0);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\n}  // namespace impl\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<impl::FracFunctor>(\"Frac\");\n  m.add_functor<impl::FracInplaceFunctor>(\"FracInplace\");\n  m.add_functor<impl::ReluFunctor>(\"Relu\");\n  m.add_functor<impl::ReluGradFunctor>(\"ReluGrad\");\n  m.add_functor<impl::PReluFunctor>(\"PRelu\");\n  m.add_functor<impl::PReluGradFunctor>(\"PReluGrad\");\n  
m.add_functor<impl::HardTanhFunctor>(\"HardTanh\");\n  m.add_functor<impl::HardTanhGradFunctor>(\"HardTanhGrad\");\n  m.add_functor<impl::EluFunctor>(\"Elu\");\n  m.add_functor<impl::EluGradFunctor>(\"EluGrad\");\n  m.add_functor<impl::CeluFunctor>(\"Celu\");\n  m.add_functor<impl::CeluGradFunctor>(\"CeluGrad\");\n  m.add_functor<impl::GeluFunctor>(\"Gelu\");\n  m.add_functor<impl::GeluGradFunctor>(\"GeluGrad\");\n  m.add_functor<impl::FastGeluFunctor>(\"FastGelu\");\n  m.add_functor<impl::FastGeluGradFunctor>(\"FastGeluGrad\");\n  m.add_functor<impl::QuickGeluFunctor>(\"QuickGelu\");\n  m.add_functor<impl::QuickGeluGradFunctor>(\"QuickGeluGrad\");\n  m.add_functor<impl::SquareReLUFunctor>(\"SquareReLU\");\n  m.add_functor<impl::SquareReLUGradFunctor>(\"SquareReLUGrad\");\n  m.add_functor<impl::GluFunctor>(\"Glu\");\n  m.add_functor<impl::HardSigmoidFunctor>(\"HardSigmoid\");\n  m.add_functor<impl::HardSigmoidGradFunctor>(\"HardSigmoidGrad\");\n  m.add_functor<impl::HardShrinkFunctor>(\"HardShrink\");\n  m.add_functor<impl::HardShrinkGradFunctor>(\"HardShrinkGrad\");\n  m.add_functor<impl::SoftmaxFunctor>(\"Softmax\");\n  m.add_functor<impl::SoftmaxGradFunctor>(\"SoftmaxGrad\");\n  m.add_functor<impl::LogSoftmaxFunctor>(\"LogSoftmax\");\n  m.add_functor<impl::LogSoftmaxGradFunctor>(\"LogSoftmaxGrad\");\n  m.add_functor<impl::GumbelSoftmaxFunctor>(\"GumbelSoftmax\");\n  m.add_functor<impl::HardSwishFunctor>(\"HardSwish\");\n  m.add_functor<impl::HardSwishGradFunctor>(\"HardSwishGrad\");\n  m.add_functor<impl::LeakyReluFunctor>(\"LeakyRelu\");\n  m.add_functor<impl::LeakyReluGradFunctor>(\"LeakyReluGrad\");\n  m.add_functor<impl::RReluFunctor>(\"RRelu\");\n  m.add_functor<impl::RReluInplaceFunctor>(\"RReluInplace\");\n  m.add_functor<impl::SoftplusFunctor>(\"Softplus\");\n  m.add_functor<impl::SoftplusGradFunctor>(\"SoftplusGrad\");\n  m.add_functor<impl::SiluFunctor>(\"Silu\");\n  m.add_functor<impl::SiluGradFunctor>(\"SiluGrad\");\n  m.add_functor<impl::MishFunctor>(\"Mish\");\n  m.add_functor<impl::MishGradFunctor>(\"MishGrad\");\n  m.add_functor<impl::SeluFunctor>(\"Selu\");\n  m.add_functor<impl::SeluGradFunctor>(\"SeluGrad\");\n  m.add_functor<impl::SoftSignFunctor>(\"SoftSign\");\n  m.add_functor<impl::SoftSignGradFunctor>(\"SoftSignGrad\");\n  m.add_functor<impl::ThresholdFunctor>(\"Threshold\");\n  m.add_functor<impl::ThresholdGradFunctor>(\"ThresholdGrad\");\n  m.add_functor<impl::SoftShrinkFunctor>(\"SoftShrink\");\n  m.add_functor<impl::SoftShrinkGradFunctor>(\"SoftShrinkGrad\");\n};\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/array_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/autograd/autograd_mode.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/placement_utils.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n#include \"oneflow/core/functional/impl/unary_functor.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/job/global_mode.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/framework/tensor_util.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include <complex>\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\nnamespace impl {\n\nclass ArgMaxFunctor {\n public:\n  ArgMaxFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"argmax\").Input(\"in\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const Optional<int32_t>& dim,\n                           const Optional<bool>& keepdim,\n                           const Optional<Symbol<DType>>& dtype) const {\n    if (dim.has_value() == false) {\n      return SequenceFunction<Maybe<Tensor>()>([&]() { return Flatten(input, 0, -1); })\n          .then([&](const std::shared_ptr<one::Tensor>& x) {\n            return OpInterpUtil::Dispatch<Tensor>(*op_, {x});\n          })\n          .call();\n    }\n\n    int new_dim = JUST(dim);\n    const int32_t ndims = input->shape()->NumAxes();\n    new_dim = JUST(maybe_wrap_dim(new_dim, ndims));\n    if (new_dim < 0) { new_dim += ndims; }\n    const auto do_cast = [&](const std::shared_ptr<one::Tensor>& x) -> Maybe<Tensor> {\n      return Cast(x, JUST(dtype), /*pin_memory=*/false);\n    };\n\n    if (new_dim == ndims - 1) {\n      return SequenceFunction<Maybe<Tensor>()>(\n                 [&]() { return OpInterpUtil::Dispatch<Tensor>(*op_, {input}); })\n          .then_if(keepdim.has_value() && JUST(keepdim) == true,\n                   std::bind(ExpandDims, std::placeholders::_1, -1))\n          .then_if(dtype.has_value(), do_cast)\n          .call();\n    }\n\n    std::vector<int32_t> permute;\n    permute.reserve(ndims);\n    for (int32_t i = 0; i < ndims - 1; i++) { permute.emplace_back(i < new_dim ? 
i : i + 1); }\n    permute.emplace_back(new_dim);\n\n    std::vector<int32_t> permute_inv(ndims, -1);\n    for (int32_t i = 0; i < ndims; i++) { permute_inv[permute[i]] = i; }\n\n    std::vector<int32_t> squeeze_dim = {new_dim};\n\n    return SequenceFunction<Maybe<Tensor>()>([&]() { return Transpose(input, permute); })\n        .then([&](const std::shared_ptr<one::Tensor>& x) {\n          return OpInterpUtil::Dispatch<Tensor>(*op_, {x});\n        })\n        .then(std::bind(ExpandDims, std::placeholders::_1, -1))\n        .then(std::bind(Transpose, std::placeholders::_1, permute_inv))\n        .then_if((!keepdim.has_value()) || (keepdim.has_value() && JUST(keepdim) == false),\n                 std::bind(Squeeze, std::placeholders::_1, squeeze_dim))\n        .then_if(dtype.has_value(), do_cast)\n        .call();\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ArgMinFunctor {\n public:\n  ArgMinFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const Optional<int32_t>& dim,\n                           const Optional<bool>& keepdim,\n                           const Optional<Symbol<DType>>& dtype) const {\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.AddInputs({input}, DType::Float()).Apply());\n    const auto x = JUST(tensor_processor.GetInputs()).at(0);\n    return sequence_function(Negative)\n        .then(std::bind(ArgMax, std::placeholders::_1, dim, keepdim, dtype))\n        .call(x);\n  }\n};\n\nclass GlobalTensorConstantFunctor {\n public:\n  GlobalTensorConstantFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"tensor_constant\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const Shape& shape, const std::shared_ptr<one::Tensor>& value,\n                           const Symbol<DType>& dtype, const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp_tuple) const {\n    CHECK_OR_RETURN(value->ndim() <= 1 && value->nelement() == 1)\n        << \"Only a tensor with a single element or a scalar tensor is supported as value!\";\n    CHECK_OR_RETURN(value->is_global()) << \"The value tensor should be a global tensor\";\n    // NOTE: this op is a source op, so the value (scalar tensor) should not have autograd status.\n    autograd::AutoGradMode mode(false);\n    JUST(CheckDeviceIdsIsValid(placement));\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"shape\", \"dtype\", \"nd_sbp\");\n    attrs.SetAllAttrs(shape, dtype->data_type(), NullOpt);\n\n    auto dispatch_constant =\n        [&](const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {\n      std::vector<std::string> nd_sbp(sbp_tuple.size());\n      {\n        for (int i = 0; i < sbp_tuple.size(); ++i) {\n          nd_sbp[i] = SbpParallelToString(*sbp_tuple[i]);\n        }\n      }\n      attrs.SetAttr<2>(nd_sbp);\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {value}, attrs);\n    };\n    bool has_partial_parallel =\n        std::any_of(sbp_tuple.begin(), sbp_tuple.end(),\n                    [](const Symbol<SbpParallel>& sbp) { return sbp->has_partial_sum_parallel(); });\n    // The source op does not support Partial\n    if (has_partial_parallel) {\n      const auto& fixed_sbp_tuple = JUST(NdSbpReplacePartialByBroadcast(sbp_tuple));\n      const auto& tensor = JUST(dispatch_constant(*fixed_sbp_tuple));\n      return functional::ToGlobal(tensor, placement, sbp_tuple, {}, /* check_meta */ false,\n                 
                 /*copy*/ false);\n    } else {\n      return dispatch_constant(sbp_tuple);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass TensorConstantFunctor {\n public:\n  TensorConstantFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"tensor_constant\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const Shape& shape, const std::shared_ptr<one::Tensor>& value,\n                           const Symbol<DType>& dtype,\n                           const Optional<Symbol<Device>>& device) const {\n    CHECK_OR_RETURN(value->ndim() <= 1 && value->nelement() == 1)\n        << \"Only a tensor with a single element or a scalar tensor is supported as value!\";\n    // NOTE: this op is a source op, so the value (scalar tensor) should not have autograd status.\n    autograd::AutoGradMode mode(false);\n    if (GlobalMode::is_enabled()) {\n      auto global_mode_guard = GlobalMode::Guard(false);\n      return JUST(functional::GlobalTensorConstant(shape, value, dtype,\n                                                   GetGlobalParallelDescFromDevice(device),\n                                                   *JUST(GetSbpList(GlobalMode::nd_sbp()))));\n    }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"shape\", \"dtype\");\n    attrs.SetAllAttrs(shape, dtype->data_type());\n    if (device.has_value()) {\n      Symbol<Device> device_symbol = JUST(device);\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {value},\n                                            OpExprInterpContext(attrs, device_symbol));\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {value}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GlobalConstantFunctor {\n public:\n  GlobalConstantFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"constant\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const Shape& shape, const Scalar& value, const Symbol<DType>& dtype,\n                           const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp_tuple) const {\n    JUST(CheckDeviceIdsIsValid(placement));\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"shape\", \"dtype\", \"complex_value\",\n                                                 \"is_complex_value\", \"floating_value\",\n                                                 \"is_floating_value\", \"integer_value\", \"nd_sbp\");\n    if (IsComplexDataType(dtype->data_type())) {\n      attrs.SetAllAttrs(shape, dtype->data_type(), value.Value<std::complex<double>>(), true,\n                        NullOpt, false, NullOpt, NullOpt);\n    } else if (IsIntegralDataType(dtype->data_type())) {\n      attrs.SetAllAttrs(shape, dtype->data_type(), NullOpt, false, NullOpt, false,\n                        value.As<int64_t>(), NullOpt);\n    } else {\n      attrs.SetAllAttrs(shape, dtype->data_type(), NullOpt, false, value.As<double>(), true,\n                        NullOpt, NullOpt);\n    }\n\n    auto dispatch_constant =\n        [&](const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {\n      if (LazyMode::is_enabled()) {\n        std::vector<std::string> nd_sbp(sbp_tuple.size());\n        {\n          for (int i = 0; i < sbp_tuple.size(); ++i) {\n            nd_sbp[i] = SbpParallelToString(*sbp_tuple[i]);\n          }\n        }\n        attrs.SetAttr<7>(nd_sbp);\n      }\n      const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {},\n                                 
           OpExprInterpContext(attrs, placement, nd_sbp));\n    };\n    bool has_partial_parallel = [&]() {\n      for (const auto& sbp : sbp_tuple) {\n        if (sbp->has_partial_sum_parallel()) { return true; }\n      }\n      return false;\n    }();\n    // Since the source op does not support Partial, it is necessary to replace Partial\n    // with Broadcast first and then convert the result back to Partial.\n    if (has_partial_parallel) {\n      const auto& fixed_sbp_tuple = JUST(NdSbpReplacePartialByBroadcast(sbp_tuple));\n      const auto& tensor = JUST(dispatch_constant(*fixed_sbp_tuple));\n      return functional::ToGlobal(tensor, placement, sbp_tuple, {}, /* check_meta */ false,\n                                  /*copy*/ false);\n    } else {\n      return dispatch_constant(sbp_tuple);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ConstantFunctor {\n public:\n  ConstantFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"constant\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const Shape& shape, const Scalar& value, const Symbol<DType>& dtype,\n                           const Optional<Symbol<Device>>& device) const {\n    if (GlobalMode::is_enabled()) {\n      auto global_mode_guard = GlobalMode::Guard(false);\n      return JUST(functional::GlobalConstant(shape, value, dtype,\n                                             GetGlobalParallelDescFromDevice(device),\n                                             *JUST(GetSbpList(GlobalMode::nd_sbp()))));\n    }\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"shape\", \"dtype\", \"complex_value\", \"is_complex_value\",\n                                       \"floating_value\", \"is_floating_value\", \"integer_value\");\n    if (IsComplexDataType(dtype->data_type())) {\n      attrs.SetAllAttrs(shape, dtype->data_type(), value.Value<std::complex<double>>(), true,\n                        NullOpt, false, NullOpt);\n    } else if (IsIntegralDataType(dtype->data_type())) {\n      attrs.SetAllAttrs(shape, dtype->data_type(), NullOpt, false, NullOpt, false,\n                        value.As<int64_t>());\n    } else {\n      attrs.SetAllAttrs(shape, dtype->data_type(), NullOpt, false, value.As<double>(), true,\n                        NullOpt);\n    }\n    if (device.has_value()) {\n      Symbol<Device> device_symbol = JUST(device);\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {}, OpExprInterpContext(attrs, device_symbol));\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass EmptyFunctor {\n public:\n  EmptyFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"empty\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const Shape& shape, const Symbol<DType>& dtype,\n                           const Optional<Symbol<Device>>& device, const bool requires_grad,\n                           const bool pin_memory) const {\n    std::shared_ptr<Tensor> empty;\n    if (GlobalMode::is_enabled()) {\n      auto global_mode_guard = GlobalMode::Guard(false);\n      empty = JUST(functional::GlobalEmpty(shape, dtype, GetGlobalParallelDescFromDevice(device),\n                                           *JUST(GetSbpList(GlobalMode::nd_sbp()))));\n      if (dtype->is_floating_point()) { JUST(empty->set_requires_grad(requires_grad)); }\n      return empty;\n    }\n    Symbol<Device> device_symbol = device.value_or(JUST(Device::New(\"cpu\")));\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"shape\", \"dtype\", \"pin_memory\", 
\"device_type\", \"device_id\");\n    attrs.SetAllAttrs(shape, dtype->data_type(), pin_memory, device_symbol->type(),\n                      device_symbol->device_id());\n    if (device.has_value()) {\n      Symbol<Device> device_symbol = JUST(device);\n      empty =\n          JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {}, OpExprInterpContext(attrs, device_symbol)));\n    } else {\n      empty = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {}, attrs));\n    }\n\n    if (dtype->is_floating_point()) { JUST(empty->set_requires_grad(requires_grad)); }\n    return empty;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass EmptyStridedFunctor {\n public:\n  Maybe<Tensor> operator()(const std::vector<int64_t>& shape, const std::vector<int64_t>& stride,\n                           const Optional<Symbol<DType>>& dtype,\n                           const Optional<Symbol<Device>>& device, const bool requires_grad,\n                           const bool pin_memory) const {\n    Symbol<DType> data_type = GetDefaultDType();\n    if (dtype.has_value()) { data_type = JUST(dtype); }\n    auto empty = JUST(functional::Empty(Shape(shape), dtype.value_or(GetDefaultDType()), device,\n                                        requires_grad, pin_memory));\n    CHECK_OR_RETURN(view::IsViewApplicable(empty))\n        << \"oneflow.empty_strided() only support in eager local mode!\";\n    return view::AsStrided(empty, shape, stride, 1);\n  }\n};\n\nclass GlobalEmptyFunctor {\n public:\n  GlobalEmptyFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"empty\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const Shape& shape, const Symbol<DType>& dtype,\n                           const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp_tuple) const {\n    JUST(CheckDeviceIdsIsValid(placement));\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"shape\", \"dtype\", \"nd_sbp\");\n    if (LazyMode::is_enabled()) {\n      std::vector<std::string> nd_sbp(sbp_tuple.size());\n      {\n        for (int i = 0; i < sbp_tuple.size(); ++i) {\n          nd_sbp.at(i) = SbpParallelToString(*sbp_tuple.at(i));\n        }\n      }\n      attrs.SetAllAttrs(shape, dtype->data_type(), nd_sbp);\n    } else {\n      attrs.SetAllAttrs(shape, dtype->data_type(), NullOpt);\n    }\n    const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp));\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ZerosLikeFunctor : public UnaryFunctor {\n public:\n  ZerosLikeFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"zero_like\").Input(\"like\").Output(\"out\").Build());\n  }\n};\n\nclass OnesLikeFunctor : public UnaryFunctor {\n public:\n  OnesLikeFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"ones_like\").Input(\"like\").Output(\"out\").Build());\n  }\n};\n\nclass FullLikeFunctor {\n public:\n  FullLikeFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& fill_value) const {\n    std::shared_ptr<Tensor> out;\n    if (x->is_local()) {\n      out = JUST(functional::Empty(*(x->shape()), x->dtype(), JUST(x->device()),\n                                   /*requires_grad=*/false, /*pin_memory=*/false));\n    } else {\n      out = JUST(functional::GlobalEmpty(*(x->shape()), x->dtype(), JUST(x->parallel_desc()),\n                                         *JUST(private_details::RawGetSbpList(JUST(x->nd_sbp())))));\n    }\n    out = JUST(functional::Fill(out, 
fill_value));\n    return out;\n  }\n};\n\nclass FlattenFunctor {\n public:\n  FlattenFunctor() = default;\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int32_t& start_dim,\n                           const int32_t& end_dim) const {\n    const Shape& in_shape = *x->shape();\n    int32_t ndim = in_shape.size();\n\n    auto CheckAndWrapDim = [&](int32_t dim) -> Maybe<int32_t> {\n      // handle scalar\n      if (ndim == 0 && (dim == 0 || dim == -1)) { return 0; }\n      if (dim < -ndim || dim >= ndim) {\n        return Error::IndexError() << \"Dimension out of range (expected to be in range of [\"\n                                   << -ndim << \", \" << ndim - 1 << \"], but got \" << dim << \")\";\n      }\n      return dim >= 0 ? dim : dim + ndim;\n    };\n\n    // A negative dim (-n) indicates dimension ndim - n;\n    // for example, when ndim == 3, (-3) == (0), (-2) == (1), (-1) == (2)\n    int32_t true_start_dim = JUST(CheckAndWrapDim(start_dim));\n    int32_t true_end_dim = JUST(CheckAndWrapDim(end_dim));\n\n    if (true_start_dim > true_end_dim) {\n      return Error::RuntimeError() << \"flatten() has invalid args: start_dim (\" << start_dim\n                                   << \") cannot come after end_dim (\" << end_dim << \")\";\n    }\n\n    // identity when start_dim == end_dim\n    if (true_start_dim == true_end_dim) { return x; }\n\n    DimVector dim_vec{in_shape.begin(), in_shape.begin() + true_start_dim + 1};\n    for (int i = true_start_dim + 1; i <= true_end_dim; ++i) { dim_vec.back() *= in_shape[i]; }\n    dim_vec.insert(dim_vec.end(), in_shape.begin() + true_end_dim + 1, in_shape.end());\n    Shape reshape_shape{dim_vec};\n    CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), reshape_shape.elem_cnt())\n        << Error::RuntimeError() << \"invalid reshape from \" << in_shape.ToString() << \" to \"\n        << reshape_shape.ToString();\n    return JUST(Reshape(x, reshape_shape));\n  }\n};\n\nclass WhereFunctor {\n public:\n  WhereFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"where\").Input(\"condition\").Input(\"x\").Input(\"y\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& condition,\n                           const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {condition, x, y});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass WhereScalarXFunctor {\n public:\n  WhereScalarXFunctor() = default;\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& condition, const Scalar& scalar,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    std::shared_ptr<one::Tensor> x;\n    if (y->is_local()) {\n      x = JUST(functional::Constant(Shape({}), scalar, y->dtype(), JUST(y->device())));\n    } else {\n      const size_t sbp_ndim = JUST(y->nd_sbp())->sbp_parallel_size();\n      std::vector<Symbol<SbpParallel>> nd_sbp_vec;\n      nd_sbp_vec.reserve(sbp_ndim);\n      for (int i = 0; i < sbp_ndim; ++i) {\n        SbpParallel sbp;\n        sbp.mutable_broadcast_parallel();\n        nd_sbp_vec.push_back(SymbolOf(sbp));\n      }\n      const auto& parallel_desc = JUST(y->parallel_desc());\n      x = JUST(\n          functional::GlobalConstant(Shape({}), scalar, y->dtype(), parallel_desc, nd_sbp_vec));\n    }\n    return functional::Where(condition, x, y);\n  }\n};\n\nclass WhereScalarYFunctor {\n public:\n  WhereScalarYFunctor() = default;\n\n  
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& condition,\n                           const std::shared_ptr<one::Tensor>& x, const Scalar& scalar) const {\n    std::shared_ptr<one::Tensor> y;\n    if (x->is_local()) {\n      y = JUST(functional::Constant(Shape({}), scalar, x->dtype(), JUST(x->device())));\n    } else {\n      const size_t sbp_ndim = JUST(x->nd_sbp())->sbp_parallel_size();\n      std::vector<Symbol<SbpParallel>> nd_sbp_vec;\n      nd_sbp_vec.reserve(sbp_ndim);\n      for (int i = 0; i < sbp_ndim; ++i) {\n        SbpParallel sbp;\n        sbp.mutable_broadcast_parallel();\n        nd_sbp_vec.push_back(SymbolOf(sbp));\n      }\n      const auto& parallel_desc = JUST(x->parallel_desc());\n      y = JUST(\n          functional::GlobalConstant(Shape({}), scalar, x->dtype(), parallel_desc, nd_sbp_vec));\n    }\n    return functional::Where(condition, x, y);\n  }\n};\n\nclass WhereScalarXYFunctor {\n public:\n  WhereScalarXYFunctor() = default;\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& condition, const Scalar& x_scalar,\n                           const Scalar& y_scalar) const {\n    std::shared_ptr<one::Tensor> x;\n    std::shared_ptr<one::Tensor> y;\n    DataType dtype = DataType::kInvalidDataType;\n\n    if (x_scalar.IsBool() && y_scalar.IsBool()) {\n      dtype = DataType::kBool;\n    } else if (x_scalar.IsFloatingPoint() && y_scalar.IsFloatingPoint()) {\n      double x_val = x_scalar.As<double>();\n      double y_val = y_scalar.As<double>();\n      if (x_val >= GetMinVal<DataTypeToType<DataType::kFloat>>()\n          && x_val <= GetMaxVal<DataTypeToType<DataType::kFloat>>()\n          && y_val >= GetMinVal<DataTypeToType<DataType::kFloat>>()\n          && y_val <= GetMaxVal<DataTypeToType<DataType::kFloat>>()) {\n        dtype = DataType::kFloat;\n      } else {\n        dtype = DataType::kDouble;\n      }\n    } else if (x_scalar.IsIntegral() && y_scalar.IsIntegral()) {\n      if (x_scalar.IsUnsigned() && y_scalar.IsUnsigned()) {\n        uint64_t x_val = x_scalar.As<uint64_t>();\n        uint64_t y_val = y_scalar.As<uint64_t>();\n        if (x_val <= GetMaxVal<DataTypeToType<DataType::kUInt32>>()\n            && y_val <= GetMaxVal<DataTypeToType<DataType::kUInt32>>()) {\n          dtype = DataType::kUInt32;\n        } else {\n          dtype = DataType::kUInt64;\n        }\n      } else if (x_scalar.IsSigned() && y_scalar.IsSigned()) {\n        int64_t x_val = x_scalar.As<int64_t>();\n        int64_t y_val = y_scalar.As<int64_t>();\n        if (x_val >= GetMinVal<DataTypeToType<DataType::kInt32>>()\n            && x_val <= GetMaxVal<DataTypeToType<DataType::kInt32>>()\n            && y_val >= GetMinVal<DataTypeToType<DataType::kInt32>>()\n            && y_val <= GetMaxVal<DataTypeToType<DataType::kInt32>>()) {\n          dtype = DataType::kInt32;\n        } else {\n          dtype = DataType::kInt64;\n        }\n      } else {\n        UNIMPLEMENTED_THEN_RETURN()\n            << \"The x scalar and y scalar in Where should be signed or unsigned at the same time.\";\n      }\n    } else {\n      UNIMPLEMENTED_THEN_RETURN()\n          << \"The x scalar and y scalar in Where should be bool, float or int at the same time.\";\n    }\n\n    if (condition->is_local()) {\n      x = JUST(functional::Constant(Shape({}), x_scalar, DType(dtype), JUST(condition->device())));\n      y = JUST(functional::Constant(Shape({}), y_scalar, DType(dtype), JUST(condition->device())));\n    } else {\n      const size_t sbp_ndim = 
JUST(condition->nd_sbp())->sbp_parallel_size();\n      std::vector<Symbol<SbpParallel>> nd_sbp_vec;\n      nd_sbp_vec.reserve(sbp_ndim);\n      for (int i = 0; i < sbp_ndim; ++i) {\n        SbpParallel sbp;\n        sbp.mutable_broadcast_parallel();\n        nd_sbp_vec.push_back(SymbolOf(sbp));\n      }\n      const auto& parallel_desc = JUST(condition->parallel_desc());\n      x = JUST(\n          functional::GlobalConstant(Shape({}), x_scalar, DType(dtype), parallel_desc, nd_sbp_vec));\n      y = JUST(\n          functional::GlobalConstant(Shape({}), y_scalar, DType(dtype), parallel_desc, nd_sbp_vec));\n    }\n    return functional::Where(condition, x, y);\n  }\n};\n\nclass ArgWhereFunctor {\n public:\n  ArgWhereFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"argwhere\").Input(\"input\").Output(\"output\").Output(\"output_size\").Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x,\n                                const Symbol<DType>& dtype) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dtype\");\n    attrs.SetAllAttrs(dtype->data_type());\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass NonZeroFunctor {\n public:\n  NonZeroFunctor() {}\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x, bool as_tuple) const {\n    std::shared_ptr<one::Tensor> input = x;\n    if (as_tuple && input->ndim() == 0) { input = JUST(functional::Unsqueeze(input, 0)); }\n    int64_t ndim = input->ndim();\n    const auto& output_tuple =\n        JUST(functional::ArgWhere(input, JUST(DType::Get(DataType::kInt64))));\n    const std::shared_ptr<one::Tensor>& size = JUST(VectorAt(*output_tuple, 1));\n    CHECK_EQ_OR_RETURN(size->shape()->elem_cnt(), 1)\n        << Error::RuntimeError() << kOfBugIssueUploadPrompt;\n    CHECK_OR_RETURN(size->dtype() == JUST(DType::Get(DataType::kInt64)))\n        << Error::RuntimeError() << kOfBugIssueUploadPrompt;\n    int64_t size_val = -1;\n    {\n      if (size->is_global()) {\n        CHECK_OR_RETURN(JUST(size->parallel_desc())->parallel_num() == 1  // NOLINT\n                        || NdSbpIsAllBroadcast(*JUST(size->nd_sbp())));   // NOLINT\n      }\n      JUST(GetItemInScalarTensor(size->is_local() ? 
size : JUST(size->cur_rank_phy_tensor()),\n                                 &size_val, sizeof(size_val)));\n    }\n    std::vector<int64_t> start{0, 0};\n    std::vector<int64_t> stop{size_val, ndim};\n    std::vector<int64_t> step{1, 1};\n    const auto& output = JUST(\n        functional::Slice(output_tuple->at(0), start, stop, step, /*enable_view_slice=*/false));\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>();\n    if (as_tuple) {\n      const auto& transposed_output = JUST(functional::Transpose2dim(output, 1, 0));\n      for (int64_t i = 0; i < ndim; ++i) {\n        outputs->emplace_back(\n            JUST(functional::TensorGetItem(transposed_output, {functional::detail::IndexItem(i)})));\n      }\n    } else {\n      outputs->emplace_back(output);\n    }\n    return outputs;\n  }\n};\n\nclass BroadcastLikeFunctor {\n public:\n  BroadcastLikeFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"broadcast_like\").Input(\"x\").Input(\"like\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& like,\n                           const std::vector<int32_t>& broadcast_axes) const {\n    const Shape& x_shape = *x->shape();\n    const Shape& like_shape = *like->shape();\n    if (x_shape == like_shape) { return x; }\n    CHECK_GE_OR_RETURN(like_shape.NumAxes(), x_shape.NumAxes())\n        << Error::RuntimeError() << \"The number of sizes provided (\" << like_shape.NumAxes()\n        << \") must be greater or equal to the number of dimensions in the tensor (\"\n        << x_shape.NumAxes() << \")\"\n        << \". Target sizes: \" << like_shape.ToString() << \". Tensor sizes: \" << x_shape.ToString();\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"broadcast_axes\");\n    if (broadcast_axes.empty()) {\n      int64_t like_ndim = like_shape.NumAxes();\n      int64_t x_ndim = x_shape.NumAxes();\n      int64_t num_prepend = like_ndim - x_ndim;\n      std::vector<int64_t> prepend_shape(num_prepend, 1);\n      std::vector<int32_t> broadcast_axes;\n      for (int i = 0; i < x_ndim; ++i) { prepend_shape.emplace_back(x_shape.At(i)); }\n      for (int i = 0; i < num_prepend; ++i) { broadcast_axes.emplace_back(i); }\n      for (int i = num_prepend; i < prepend_shape.size(); ++i) {\n        if (prepend_shape[i] != like_shape.At(i)) {\n          if (prepend_shape[i] == 1) {\n            broadcast_axes.emplace_back(i);\n          } else {\n            return Error::RuntimeError() << \"The expanded size of the tensor \"\n                                         << \"(\" << like_shape.At(i) << \")\"\n                                         << \" must match the existing size (\" << prepend_shape[i]\n                                         << \") at non-singleton dimension \" << i\n                                         << \". Target sizes: \" << like_shape.ToString()\n                                         << \". 
Tensor sizes: \" << x_shape.ToString();\n          }\n        }\n      }\n      attrs.SetAllAttrs(broadcast_axes);\n    } else {\n      attrs.SetAllAttrs(broadcast_axes);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, JUST(like->detach())}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ConcatFunctor {\n public:\n  ConcatFunctor() {\n    ops_.resize(kMaxInputCount);\n    for (int n = 0; n < ops_.size(); ++n) {\n      ops_[n] = CHECK_JUST(one::OpBuilder(\"cat\").Input(\"in\", n + 1).Output(\"out\").Build());\n    }\n  }\n  Maybe<Tensor> operator()(const TensorTuple& inputs, const int64_t& dim) const {\n    const int64_t ninput = inputs.size();\n    int64_t axis = dim;\n    int64_t ndim = inputs[0]->ndim();\n    int64_t nelement = inputs[0]->nelement();\n    int64_t max_dim_size = 0;\n    CHECK_GE_OR_RETURN(ninput, 1) << Error::RuntimeError() << \"inputs size must greater than 0\";\n    axis = JUST(maybe_wrap_dim(axis, ndim));\n\n    const std::shared_ptr<const Shape>& shape = inputs[0]->shape();\n    for (const auto& input : inputs) {\n      if (nelement == 0 and ndim == 1) {\n        if (input->nelement() != 0 or input->ndim() != 1) {\n          ndim = input->ndim();\n          nelement = input->nelement();\n        } else {\n          continue;\n        }\n      } else if (input->nelement() != 0 or input->ndim() != 1) {\n        CHECK_OR_RETURN(input->ndim() == ndim)\n            << Error::RuntimeError() << \"Tensors must have same number of dimensions: got \" << ndim\n            << \" and \" << input->ndim() << \" is expected.\";\n      }\n      for (int i = 0; i < ndim; ++i) {\n        if (input->nelement() == 0 and input->ndim() == 1) { continue; }\n        if (axis == i) {\n          max_dim_size += input->shape()->At(i);\n        } else if (inputs[0]->nelement() != 0) {\n          CHECK_OR_RETURN(input->shape()->At(i) == shape->At(i))\n              << Error::RuntimeError() << \"Sizes of tensors must match except in dimension \" << axis\n              << \". Got \" << input->shape()->At(i) << \" and \" << shape->At(i)\n              << \" is expected in dimension \" << i << \".\";\n        }\n      }\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"max_dim_size\");\n    attrs.SetAllAttrs(axis, max_dim_size);\n    TensorTuple outputs;\n    for (int i = 0; i < ninput; i += kMaxInputCount) {\n      size_t size = (i + kMaxInputCount) < ninput ? 
kMaxInputCount : ninput - i;\n      TensorTuple partial_inputs(size);\n      TensorProcessor tensor_processor;\n      for (int j = 0; j < size; ++j) { partial_inputs[j] = inputs[i + j]; }\n      JUST(tensor_processor.PromoteInputsToCommonDtype(true)\n               .AddInputs(partial_inputs, inputs.at(i)->dtype())\n               .Apply());\n      TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n      outputs.emplace_back(\n          JUST(OpInterpUtil::Dispatch<Tensor>(*ops_[size - 1], input_tuple, attrs)));\n    }\n\n    if (outputs.size() == 1) { return outputs.at(0); }\n    return this->operator()(outputs, axis);\n  }\n\n private:\n  std::vector<std::shared_ptr<OpExpr>> ops_;\n};\n\nclass StackFunctor {\n public:\n  StackFunctor() {\n    ops_.resize(kMaxInputCount);\n    for (int n = 0; n < ops_.size(); ++n) {\n      ops_[n] = CHECK_JUST(one::OpBuilder(\"stack\").Input(\"in\", n + 1).Output(\"out\").Build());\n    }\n  }\n  Maybe<Tensor> operator()(const TensorTuple& inputs, const int64_t& dim) const {\n    const int64_t ninput = inputs.size();\n    int64_t ndims = inputs[0]->ndim();\n    int64_t stack_dim = dim;\n    stack_dim = JUST(maybe_wrap_dim(stack_dim, ndims + 1));\n    const std::shared_ptr<const Shape>& first_in_shape = inputs[0]->shape();\n    for (const auto& input : inputs) {\n      for (int i = 0; i < ndims; ++i) {\n        CHECK_OR_RETURN(input->shape()->At(i) == first_in_shape->At(i))\n            << Error::RuntimeError() << \"stack expects each tensor to be equal size, but got \"\n            << first_in_shape->ToString() << \" at first input and \" << input->shape()->ToString()\n            << \", which differs at dimension \" << i;\n      }\n    }\n    int64_t max_dim_size = ninput;\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"max_dim_size\");\n    attrs.SetAllAttrs(stack_dim, max_dim_size);\n    TensorTuple outputs;\n    for (int i = 0; i < ninput; i += kMaxInputCount) {\n      size_t size = (i + kMaxInputCount) < ninput ? 
kMaxInputCount : ninput - i;\n      TensorTuple partial_inputs(size);\n      for (int j = 0; j < size; ++j) { partial_inputs[j] = inputs[i + j]; }\n      if (partial_inputs.size() == 1) {\n        // Use ExpandDims functor for only one input\n        outputs.emplace_back(JUST(functional::ExpandDims(partial_inputs[0], dim)));\n      } else {\n        outputs.emplace_back(\n            JUST(OpInterpUtil::Dispatch<Tensor>(*ops_[size - 1], partial_inputs, attrs)));\n      }\n    }\n    if (outputs.size() == 1) { return outputs.at(0); }\n    return Concat(outputs, stack_dim);\n  }\n\n private:\n  std::vector<std::shared_ptr<OpExpr>> ops_;\n};\n\nclass StackGradFunctor {\n public:\n  StackGradFunctor() {\n    ops_.resize(kMaxInputCount);\n    for (int n = 1; n < ops_.size(); ++n) {\n      ops_[n] = CHECK_JUST(one::OpBuilder(\"stack_grad\")\n                               .Input(\"in\")\n                               .Input(\"like\", n + 1)\n                               .Output(\"out\", n + 1)\n                               .Build());\n    }\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x, const TensorTuple& like,\n                                const int64_t& axis) const {\n    CHECK_GE_OR_RETURN(like.size(), 2)\n        << Error::RuntimeError() << \"like.size() must not be less than 2, but got \" << like.size();\n    CHECK_LE_OR_RETURN(like.size(), kMaxInputCount)\n        << Error::RuntimeError() << \"like.size() must not be greater than \" << kMaxInputCount\n        << \", but got \" << like.size();\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\");\n    attrs.SetAllAttrs(axis);\n    TensorTuple inputs(like.size() + 1);\n    inputs[0] = x;\n    for (int i = 0; i < like.size(); ++i) { inputs[i + 1] = like[i]; }\n    return OpInterpUtil::Dispatch<TensorTuple>(*ops_.at(like.size() - 1), inputs, attrs);\n  }\n\n private:\n  std::vector<std::shared_ptr<OpExpr>> ops_;\n};\n\nclass AtLeast1DFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x) const {\n    if (x->ndim() == 0) {\n      return JUST(Reshape(x, {1}));\n    } else\n      return x;\n  }\n};\n\nclass AtLeast1DListFunctor {\n public:\n  Maybe<TensorTuple> operator()(const TensorTuple& inputs) const {\n    TensorTuple result = TensorTuple(inputs.size());\n    for (int32_t i = 0; i < inputs.size(); i++) {\n      result.at(i) = JUST(AtLeast1D(JUST(VectorAt(inputs, i))));\n    }\n    return result;\n  }\n};\n\nclass AtLeast2DFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x) const {\n    if (x->ndim() == 0) {\n      return JUST(Reshape(x, {1, 1}));\n    } else if (x->ndim() == 1) {\n      return JUST(Unsqueeze(x, 0));\n    } else\n      return x;\n  }\n};\n\nclass AtLeast2DListFunctor {\n public:\n  Maybe<TensorTuple> operator()(const TensorTuple& inputs) const {\n    TensorTuple result = TensorTuple(inputs.size());\n    for (int32_t i = 0; i < inputs.size(); i++) {\n      result.at(i) = JUST(AtLeast2D(JUST(VectorAt(inputs, i))));\n    }\n    return result;\n  }\n};\n\nclass AtLeast3DFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x) const {\n    if (x->ndim() == 0) {\n      return JUST(Reshape(x, {1, 1, 1}));\n    } else if (x->ndim() == 1) {\n      return JUST(Reshape(x, {1, x->shape()->At(0), 1}));\n    } else if (x->ndim() == 2) {\n      return JUST(Unsqueeze(x, -1));\n    } else\n      return x;\n  }\n};\n\nclass AtLeast3DListFunctor {\n public:\n  Maybe<TensorTuple> operator()(const TensorTuple& inputs) const {\n  
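  // Promote each input to at least 3 dimensions (see AtLeast3DFunctor above), preserving order.\n  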
  TensorTuple result = TensorTuple(inputs.size());\n    for (int32_t i = 0; i < inputs.size(); i++) {\n      result.at(i) = JUST(AtLeast3D(JUST(VectorAt(inputs, i))));\n    }\n    return result;\n  }\n};\n\nclass ColumnStackFunctor {\n public:\n  Maybe<Tensor> operator()(const TensorTuple& inputs) const {\n    std::shared_ptr<TensorTuple> new_inputs = std::make_shared<TensorTuple>(inputs.size());\n    for (int32_t i = 0; i < inputs.size(); i++) {\n      const auto& t = JUST(VectorAt(inputs, i));\n      if (t->ndim() <= 1)\n        new_inputs->at(i) = JUST(Reshape(t, {t->nelement(), 1}));\n      else\n        new_inputs->at(i) = t;\n    }\n    return HStack(*new_inputs);\n  }\n};\n\nclass HStackFunctor {\n public:\n  Maybe<Tensor> operator()(const TensorTuple& inputs) const {\n    std::shared_ptr<TensorTuple> new_inputs = JUST(AtLeast1D(inputs));\n    if (new_inputs->at(0)->ndim() == 1)\n      return Concat(*new_inputs, 0);\n    else\n      return Concat(*new_inputs, 1);\n  }\n};\n\nclass VStackFunctor {\n public:\n  Maybe<Tensor> operator()(const TensorTuple& inputs) const {\n    std::shared_ptr<TensorTuple> new_inputs = JUST(AtLeast2D(inputs));\n    return Concat(*new_inputs, 0);\n  }\n};\n\nclass RowStackFunctor {\n public:\n  Maybe<Tensor> operator()(const TensorTuple& inputs) const { return VStack(inputs); }\n};\n\nclass DStackFunctor {\n public:\n  Maybe<Tensor> operator()(const TensorTuple& inputs) const {\n    std::shared_ptr<TensorTuple> new_inputs = JUST(AtLeast3D(inputs));\n    return Concat(*new_inputs, 2);\n  }\n};\n\nclass ExpandFunctor {\n public:\n  ExpandFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"expand\").Input(\"in\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Shape& shape) const {\n    const Shape& in_shape = *x->shape();\n    int lpad = shape.size() - in_shape.size();\n    if (lpad < 0) {\n      return Error::RuntimeError()\n             << \"expand(tensor{\" << in_shape.ToString() << \"}, size=\" << shape.ToString()\n             << \"): the number of sizes provided (\" << shape.size() << \") \"\n             << \"must be greater or equal to the number of dimensions in the tensor (\"\n             << in_shape.size() << \")\";\n    }\n\n    DimVector expand_shape_vec = shape.dim_vec();\n    for (size_t i = 0; i < shape.size(); ++i) {\n      const auto& t_dim = shape[i];\n      if (t_dim < -1) {\n        return Error::RuntimeError() << \"Trying to create tensor with negative dimension \" << t_dim;\n      }\n      if (i >= lpad) {\n        const auto& dim = in_shape[i - lpad];\n        if (dim != 1 && t_dim != -1 && t_dim != dim) {\n          return Error::RuntimeError()\n                 << \"The expanded size of the tensor (\" << t_dim\n                 << \") must match the existing size (\" << dim << \") at non-singleton dimension \"\n                 << i << \". Target sizes: \" << shape.ToString()\n                 << \". 
Tensor sizes: \" << in_shape.ToString();\n        }\n        if (t_dim == -1) { expand_shape_vec[i] = dim; }\n      } else {\n        if (t_dim == -1) {\n          return Error::RuntimeError() << \"The expanded size of the tensor (-1) isn't allowed in a \"\n                                          \"leading, non-existing dimension \"\n                                       << i;\n        }\n      }\n    }\n\n    // if input tensor is eager local, then try return tensor's view\n    Shape expand_shape(expand_shape_vec);\n    if (view::IsViewApplicable(x)) { return view::Expand(x, expand_shape); }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"expand_shape\");\n    attrs.SetAllAttrs(expand_shape);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ExpandDimsFunctor {\n public:\n  ExpandDimsFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"expand_dims\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int32_t& dim) const {\n    int32_t expand_dim = dim;\n    const int32_t ndim = input->shape()->NumAxes();\n    expand_dim = JUST(maybe_wrap_dim(dim, ndim + 1));\n\n    if (view::IsViewApplicable(input)) { return view::Unsqueeze(input, expand_dim); }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\");\n    attrs.SetAllAttrs(expand_dim);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UnsqueezeMultipleFunctor {\n public:\n  UnsqueezeMultipleFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int32_t>& dim,\n                           const int32_t& n_dims) const {\n    if (dim.size() == 0 || x->ndim() == n_dims) {\n      return x;\n    } else if (dim.size() == 1) {\n      return JUST(functional::Unsqueeze(x, JUST(VectorAt(dim, 0))));\n    } else {\n      std::shared_ptr<Tensor> tensor = x;\n      const auto& dims_to_unsqueeze = JUST(dim_list_to_bitset(dim, n_dims));\n\n      // Unsqueeze is called several times to extend the dimension when the View mechanism is\n      // enabled. 
Otherwise, calculate the target shape and call reshape.\n      if (view::IsViewApplicable(tensor)) {\n        for (int32_t i = 0; i < n_dims; i++) {\n          if ((*dims_to_unsqueeze)[i]) { tensor = JUST(view::Unsqueeze(tensor, i)); }\n        }\n      } else {\n        std::vector<int64_t> target_dims(n_dims, 0);\n        int32_t tensor_index = 0;\n        for (int32_t i = 0; i < n_dims; i++) {\n          if ((*dims_to_unsqueeze)[i]) {\n            target_dims[i] = 1;\n          } else {\n            CHECK_LT_OR_RETURN(tensor_index, tensor->ndim());  // NOLINT(maybe-need-error-msg)\n            target_dims[i] = tensor->shape()->at(tensor_index);\n            tensor_index++;\n          }\n        }\n        Shape inferred_shape(DimVector(target_dims.begin(), target_dims.end()));\n        tensor = JUST(functional::Reshape(tensor, inferred_shape));\n      }\n      return tensor;\n    }\n  }\n};\n\nclass InplaceUnsqueezeFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int32_t& dim) const {\n    JUST(CheckInplaceValid(input));\n    const int64_t expand_dim = JUST(maybe_wrap_dim(dim, input->shape()->NumAxes() + 1));\n    CHECK_OR_RETURN(view::IsViewApplicable(input))\n        << \"inplace unsqueeze(tensor.unsqueeze_) is only supported in eager local mode!\";\n\n    JUST(view::InplaceUnsqueeze(input, expand_dim));\n    return input;\n  }\n};\n\nclass SqueezeFunctor {\n public:\n  SqueezeFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"squeeze\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const Optional<std::vector<int32_t>>& dim) const {\n    int32_t ndim = x->shape()->NumAxes();\n    std::vector<int32_t> squeeze_dims;\n    squeeze_dims.reserve(ndim);\n    if (dim.has_value()) {\n      std::vector<int32_t> dims = *JUST(dim);\n      for (int32_t dim_i : dims) {\n        dim_i = JUST(maybe_wrap_dim(dim_i, ndim));\n        if (x->shape()->At(dim_i) == 1) { squeeze_dims.emplace_back(dim_i); }\n      }\n    } else {\n      for (int i = 0; i < ndim; ++i) {\n        if (x->shape()->At(i) == 1) { squeeze_dims.emplace_back(i); }\n      }\n    }\n\n    if (view::IsViewApplicable(x)) { return view::Squeeze(x, squeeze_dims); }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axes\");\n    attrs.SetAllAttrs(squeeze_dims);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass InplaceSqueezeFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const Optional<std::vector<int32_t>>& dim) const {\n    JUST(CheckInplaceValid(input));\n    const int32_t ndim = input->shape()->NumAxes();\n    std::vector<int32_t> squeeze_dims;\n    squeeze_dims.reserve(ndim);\n    if (dim.has_value()) {\n      std::vector<int32_t> dims = *JUST(dim);\n      for (int32_t dim_i : dims) {\n        dim_i = JUST(maybe_wrap_dim(dim_i, ndim));\n        if (input->shape()->At(dim_i) == 1) { squeeze_dims.emplace_back(dim_i); }\n      }\n    } else {\n      for (int i = 0; i < ndim; ++i) {\n        if (input->shape()->At(i) == 1) { squeeze_dims.emplace_back(i); }\n      }\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axes\");\n    attrs.SetAllAttrs(squeeze_dims);\n\n    CHECK_OR_RETURN(view::IsViewApplicable(input))\n        << \"inplace squeeze(tensor.squeeze_) is only supported in eager local mode!\";\n\n    JUST(view::InplaceSqueeze(input, 
squeeze_dims));\n    return input;\n  }\n};\n\nclass RollFunctor {\n public:\n  RollFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"roll\").Input(\"in\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::vector<int32_t>& shifts,\n                           const Optional<std::vector<int32_t>>& dims) const {\n    std::vector<int32_t> actual_dims;\n    if (dims.has_value()) {\n      actual_dims = *JUST(dims);\n    } else {\n      actual_dims.emplace_back(-1);\n    }\n    CHECK_EQ_OR_RETURN(shifts.size(), actual_dims.size())\n        << Error::RuntimeError() << \"shifts and dimensions must align. shifts: \" << shifts.size()\n        << \", dims: \" << actual_dims.size();\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"shifts\", \"dims\");\n    attrs.SetAllAttrs(shifts, actual_dims);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GatherFunctor {\n public:\n  GatherFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"gather\").Input(\"in\").Input(\"indices\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& indices, const int64_t& axis) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\");\n    attrs.SetAllAttrs(axis);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, indices}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass DimGatherFunctor {\n public:\n  DimGatherFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"dim_gather\").Input(\"input\").Input(\"index\").Output(\"output\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int64_t& dim,\n                           const std::shared_ptr<one::Tensor>& index,\n                           const bool sparse_grad) const {\n    CHECK_OR_RETURN(index->dtype()->data_type() == kInt64 || index->dtype()->data_type() == kInt32)\n        << Error::RuntimeError() << \"gather(): Expected dtype int32 or int64 for index\";\n    CHECK_EQ_OR_RETURN(sparse_grad, false)\n        << Error::RuntimeError() << \"Only sparse_grad=False is supported for now!\";\n\n    int64_t new_dim = JUST(maybe_wrap_dim(dim, index->ndim()));\n    if (input->ndim() > 0 && index->ndim() > 0) {\n      CHECK_EQ_OR_RETURN(input->ndim(), index->ndim())\n          << Error::RuntimeError()\n          << \"Index tensor must have the same number of dimensions as input tensor\";\n    } else if (input->ndim() == 0) {\n      CHECK_LE_OR_RETURN(index->ndim(), 1)\n          << Error::RuntimeError()\n          << \"Index tensor must have the same number of dimensions as input tensor\";\n    } else {\n      CHECK_LE_OR_RETURN(input->ndim(), 1)\n          << Error::RuntimeError()\n          << \"Index tensor must have the same number of dimensions as input tensor\";\n    }\n    if (input->ndim() > 0 && index->ndim() > 0) {\n      FOR_RANGE(int32_t, i, 0, input->ndim()) {\n        if (i != new_dim) {\n          CHECK_LE_OR_RETURN(index->shape()->At(i), input->shape()->At(i))\n              << Error::RuntimeError() << \"Size does not match at dimension \" << i\n              << \", expected index \" << *(index->shape()) << \" to be smaller than self \"\n              << *(input->shape()) << \" apart from dimension \" << new_dim;\n        }\n      }\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dim\");\n    
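// Dispatch dim_gather with the wrapped (non-negative) dim passed as an int32 attribute.\n    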
attrs.SetAllAttrs(static_cast<int32_t>(new_dim));\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nenum class DimScatterType { kUpdate, kAdd, kMultiply };\n\ntemplate<DimScatterType T>\nstd::string DimScatterTypeToString() {\n  switch (T) {\n    case DimScatterType::kUpdate: return \"_update\";\n    case DimScatterType::kAdd: return \"_add\";\n    case DimScatterType::kMultiply: return \"_mul\";\n  }\n  return \"\";\n}\n\ntemplate<DimScatterType T>\nclass DimScatterFunctorImpl {\n public:\n  DimScatterFunctorImpl()\n      : op_(CHECK_JUST(one::OpBuilder(\"dim_scatter\" + DimScatterTypeToString<T>())\n                           .Input(\"input\")\n                           .Input(\"index\")\n                           .Input(\"src\")\n                           .Output(\"output\")\n                           .Build())) {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int32_t& dim,\n                           const std::shared_ptr<one::Tensor>& index,\n                           const std::shared_ptr<one::Tensor>& src, bool inplace) const {\n    const int32_t ndim = input->shape()->NumAxes();\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dim\");\n    attrs.SetAllAttrs(static_cast<int32_t>(JUST(maybe_wrap_dim(dim, ndim))));\n    if (inplace) {\n      JUST(CheckInplaceValid(input));\n      auto outputs = std::make_shared<TensorTuple>(1);\n      outputs->at(0) = input;\n      JUST(OpInterpUtil::Dispatch(*op_, {input, index, src}, outputs.get(), attrs));\n      return outputs->at(0);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index, src}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass DimScatterFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int32_t& dim,\n                           const std::shared_ptr<one::Tensor>& index,\n                           const std::shared_ptr<one::Tensor>& src,\n                           const Optional<std::string>& reduce, bool inplace) const {\n    if (reduce.has_value()) {\n      const std::string& reduce_str = *JUST(reduce);\n      if (reduce_str == \"add\") {\n        return DimScatterAdd(input, dim, index, src, inplace);\n      } else if (reduce_str == \"multiply\") {\n        return DimScatterMul(input, dim, index, src, inplace);\n      } else {\n        CHECK_OR_RETURN(false) << Error::RuntimeError() << \"Invalid reduce type: \" << reduce_str;\n      }\n    }\n    return functional::DimScatterUpdate(input, dim, index, src, inplace);\n  }\n};\n\ntemplate<DimScatterType T>\nclass DimScatterScalarFunctorImpl {\n public:\n  DimScatterScalarFunctorImpl()\n      : op_(CHECK_JUST(one::OpBuilder(\"dim_scatter\" + DimScatterTypeToString<T>() + \"_scalar\")\n                           .Input(\"input\")\n                           .Input(\"index\")\n                           .Output(\"output\")\n                           .Build())) {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int32_t& dim,\n                           const std::shared_ptr<one::Tensor>& index, const Scalar& src,\n                           bool inplace) const {\n    const int32_t ndim = input->shape()->NumAxes();\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dim\", \"src_scalar\");\n    attrs.SetAllAttrs(static_cast<int32_t>(JUST(maybe_wrap_dim(dim, ndim))), src.As<float>());\n    if (inplace) {\n      JUST(CheckInplaceValid(input));\n      auto 
outputs = std::make_shared<TensorTuple>(1);\n      outputs->at(0) = input;\n      JUST(OpInterpUtil::Dispatch(*op_, {input, index}, outputs.get(), attrs));\n      return outputs->at(0);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass DimScatterScalarFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int32_t& dim,\n                           const std::shared_ptr<one::Tensor>& index, const Scalar& src,\n                           const Optional<std::string>& reduce, bool inplace) const {\n    if (reduce.has_value()) {\n      const std::string& reduce_str = *JUST(reduce);\n      if (reduce_str == \"add\") {\n        return DimScatterAddScalar(input, dim, index, src, inplace);\n      } else if (reduce_str == \"multiply\") {\n        return DimScatterMulScalar(input, dim, index, src, inplace);\n      } else {\n        CHECK_OR_RETURN(false) << Error::RuntimeError() << \"Invalid reduce type: \" << reduce_str;\n      }\n    }\n    return functional::DimScatterUpdateScalar(input, dim, index, src, inplace);\n  }\n};\n\nclass DimScatterAddLikeFunctor {\n public:\n  DimScatterAddLikeFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"dim_scatter_add_like\")\n                         .Input(\"like\")\n                         .Input(\"index\")\n                         .Input(\"src\")\n                         .Output(\"output\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& like, const int32_t& dim,\n                           const std::shared_ptr<one::Tensor>& index,\n                           const std::shared_ptr<one::Tensor>& src) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dim\");\n    attrs.SetAllAttrs(dim);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {like, index, src}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ArgSortFunctor {\n public:\n  ArgSortFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"arg_sort\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in,\n                           const std::string& direction) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"direction\");\n    attrs.SetAllAttrs(direction);\n    CHECK_OR_RETURN(direction == \"ASCENDING\" || direction == \"DESCENDING\")\n        << Error::RuntimeError()\n        << \"expected the direction parameter value to be \\\"ASCENDING\\\" or \\\"DESCENDING\\\", \"\n        << \"but got \\\"\" << direction << \"\\\"\";\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {in}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SearchSortedFunctor {\n public:\n  SearchSortedFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"searchsorted\")\n                         .Input(\"sorted_sequence\")\n                         .Input(\"values\")\n                         .Output(\"out\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& sorted_sequence,\n                           const std::shared_ptr<one::Tensor>& values, bool out_int32,\n                           bool right) const {\n    // checks\n    CHECK_OR_RETURN(values->shape()->NumAxes() > 0)\n        << \"for searchsorted op, input values tensor should have positive dimension\";\n    CHECK_OR_RETURN(sorted_sequence->shape()->NumAxes() > 0)\n        << \"for 
searchsorted op, input sorted_sequence should have positive dimension, \"\n        << \"but got 0 dimension\";\n    CHECK_OR_RETURN(sorted_sequence->shape()->NumAxes() == 1\n                    || sorted_sequence->shape()->MatchBeforeLastDim(*(values->shape())))\n        << \"for searchsorted op, sorted_sequence should be 1 dimension or the first N-1 dimensions \"\n        << \"of boundaries tensor and input value tensor must match\";\n    if (out_int32) {\n      CHECK_OR_RETURN(sorted_sequence->shape()->At(sorted_sequence->shape()->NumAxes() - 1)\n                      < INT32_MAX)\n          << \"for searchsorted op, the size of input sorted_sequence's last dimension should \"\n          << \"be less than \" << INT32_MAX;\n    }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"out_int32\", \"right\");\n    attrs.SetAllAttrs(out_int32, right);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {sorted_sequence, values}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SearchSortedScalarFunctor {\n public:\n  SearchSortedScalarFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"searchsorted_scalar\").Input(\"sorted_sequence\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& sorted_sequence,\n                           const Scalar& values, bool out_int32, bool right) const {\n    // checks\n    CHECK_OR_RETURN(sorted_sequence->shape()->NumAxes() == 1)\n        << \"for searchsorted op, input value can be a scalar only when sorted_sequence tensor \"\n        << \"dimension is 1, but got sorted_sequence dim(\" << sorted_sequence->shape()->NumAxes()\n        << \")\";\n    if (out_int32) {\n      CHECK_OR_RETURN(sorted_sequence->shape()->At(sorted_sequence->shape()->NumAxes() - 1)\n                      < INT32_MAX)\n          << \"for searchsorted op, the size of input sorted_sequence's last dimension should \"\n          << \"be less than \" << INT32_MAX;\n    }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"out_int32\", \"right\", \"values\");\n    attrs.SetAllAttrs(out_int32, right, values.As<double>());\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {sorted_sequence}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GatherNdFunctor {\n public:\n  GatherNdFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"gather_nd\").Input(\"params\").Input(\"indices\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& params,\n                           const std::shared_ptr<one::Tensor>& indices) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {params, indices});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ScatterNdFunctor {\n public:\n  ScatterNdFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"scatter_nd\").Input(\"indices\").Input(\"updates\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& indices,\n                           const std::shared_ptr<one::Tensor>& updates, const Shape& shape) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"shape\");\n    attrs.SetAllAttrs(shape);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {indices, updates}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass TensorScatterNdUpdateFunctor {\n public:\n  TensorScatterNdUpdateFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"tensor_scatter_nd_update\")\n                         .Input(\"params\")\n                         
.Input(\"indices\")\n                         .Input(\"updates\")\n                         .Output(\"out\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& tensor,\n                           const std::shared_ptr<one::Tensor>& indices,\n                           const std::shared_ptr<one::Tensor>& updates, bool inplace) const {\n    CHECK_OR_RETURN(*tensor->dtype() == *updates->dtype())\n        << Error::RuntimeError() << \"The dtype of tensor and updates must be same.\";\n    std::shared_ptr<Tensor> contiguous_index = JUST(functional::ToContiguous(indices));\n    if (inplace) {\n      if (tensor->is_global()) {\n        // NOTE: global tensor_scatter_nd_update inplace must calculate on another tensor and assign\n        // back because of input's sbp limited\n        auto output =\n            JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {tensor, contiguous_index, updates}));\n        int64_t ndim = tensor->shape()->NumAxes();\n        // TODO: use inplace copy op to write back to origin tensor\n        std::vector<int64_t> start(ndim, 0);\n        std::vector<int64_t> stop(tensor->shape()->begin(), tensor->shape()->end());\n        std::vector<int64_t> step(ndim, 1);\n        return functional::SliceUpdate(tensor, output, start, stop, step, /*inplace=*/true);\n      } else {\n        JUST(CheckInplaceValid(tensor));\n        auto outputs = std::make_shared<TensorTuple>(1);\n        (*outputs)[0] = tensor;\n        JUST(OpInterpUtil::Dispatch(*op_, {tensor, contiguous_index, updates}, outputs.get()));\n        return (*outputs)[0];\n      }\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {tensor, contiguous_index, updates});\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ScatterNdLikeFunctor {\n public:\n  ScatterNdLikeFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"scatter_nd_like\")\n                         .Input(\"like\")\n                         .Input(\"updates\")\n                         .Input(\"indices\")\n                         .Output(\"out\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& like,\n                           const std::shared_ptr<one::Tensor>& updates,\n                           const std::shared_ptr<one::Tensor>& indices) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {like, updates, indices});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ReshapeFunctor {\n public:\n  ReshapeFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"reshape\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Shape& shape) const {\n    Shape infered_shape = *JUST(InferShapeUnspecifiedDim(x->shape()->Count(0), shape));\n\n    if (view::IsViewApplicable(x)) {\n      Optional<Stride> infered_stride =\n          ComputeStride(*(x->shape()), *JUST(x->stride()), infered_shape);\n      if (infered_stride.has_value()) {\n        return view::Reshape(x, infered_shape, *JUST(infered_stride));\n      }\n    }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"shape\");\n    attrs.SetAllAttrs(infered_shape);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ViewFunctor {\n public:\n  ViewFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"reshape\").Input(\"in\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Shape& 
shape) const {\n    Shape inferred_shape = *JUST(InferShapeUnspecifiedDim(x->shape()->Count(0), shape));\n    if (view::IsViewApplicable(x)) {\n      Optional<Stride> inferred_stride =\n          ComputeStride(*(x->shape()), *JUST(x->stride()), inferred_shape);\n      CHECK_OR_RETURN_ERROR(inferred_stride.has_value())\n          << Error::RuntimeError()\n          << \"view size is not compatible with input tensor's size and stride (at least one \"\n             \"dimension spans across two contiguous subspaces). Use .reshape(...) instead.\";\n      return view::Reshape(x, inferred_shape, *JUST(inferred_stride));\n    }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"shape\");\n    attrs.SetAllAttrs(inferred_shape);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ToContiguousFunctor {\n public:\n  ToContiguousFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"to_contiguous\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input) const {\n    if (input->is_global() || input->is_lazy()) { return input; }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass InplaceToContiguousFunctor {\n public:\n  InplaceToContiguousFunctor() {\n    assign_op_ = CHECK_JUST(one::OpBuilder(\"assign\").Input(\"ref\").Input(\"value\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input) const {\n    // TODO: use a dedicated \"inplace_to_contiguous\" op to replace assign\n    if (input->is_contiguous()) { return input; }\n\n    auto contiguous_tensor = JUST(functional::ToContiguous(input));\n    CHECK_OR_RETURN(input->is_local() && contiguous_tensor->is_local())\n        << \"Both ref and value must be local tensors.\";\n    const Stride stride(*input->shape());\n    // update stride\n    const auto& blob_object = JUST(input->eager_blob_object());\n    Symbol<LocalTensorMeta> old_tensor_meta = JUST(input->local_tensor_meta());\n\n    Symbol<LocalTensorMeta> new_tensor_meta =\n        SymbolOf(LocalTensorMeta(old_tensor_meta->shape(), stride, old_tensor_meta->dtype(),\n                                 old_tensor_meta->memory_format(), old_tensor_meta->device()));\n\n    std::shared_ptr<EagerLocalTensorImpl> final_tensor_impl =\n        std::make_shared<EagerLocalTensorImpl>(JUST(input->tensor_storage()),\n                                               JUST(input->storage_offset()),\n                                               input->requires_grad(), input->is_leaf());\n    JUST(final_tensor_impl->set_retain_grad(input->retain_grad()));\n    JUST(final_tensor_impl->InitEagerBlobObject(new_tensor_meta,\n                                                JUST(blob_object->compute_local_dep_object())));\n    JUST(JUST(input->AsLocalTensor())->set_impl(final_tensor_impl));\n\n    // assign contiguous tensor data\n    JUST(OpInterpUtil::Dispatch<TensorTuple>(*assign_op_, {input, contiguous_tensor}));\n    return input;\n  }\n\n private:\n  std::shared_ptr<OpExpr> assign_op_;\n};\n\nclass NarrowFunctor {\n public:\n  NarrowFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"narrow\").Input(\"in\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int64_t& dim,\n                           const int64_t& start, const int64_t& length) const {\n    int64_t narrow_dim = dim;\n    int64_t narrow_start = start;\n    const int64_t ndim = 
input->shape()->NumAxes();\n    CHECK_GT_OR_RETURN(ndim, 0) << Error::RuntimeError()\n                                << \"narrow() cannot be applied to a 0-dim tensor.\";\n    narrow_dim = JUST(maybe_wrap_dim(narrow_dim, ndim));\n    int64_t dim_length = input->shape()->At(narrow_dim);\n    CHECK_OR_RETURN((-dim_length <= start) && (start <= dim_length))\n        << Error::IndexError() << \"start out of range (expected to be in range of [\" << -dim_length\n        << \", \" << dim_length << \"], but got \" << start << \")\";\n    if (narrow_start < 0) { narrow_start += dim_length; }\n    CHECK_GE_OR_RETURN(dim_length, narrow_start + length)\n        << Error::RuntimeError() << \"start (\" << narrow_start << \") + length (\" << length\n        << \") exceeds dimension size (\" << dim_length << \")\";\n\n    if (view::IsViewApplicable(input)) {\n      return JUST(view::Narrow(input, narrow_dim, narrow_start, length));\n    }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dim\", \"start\", \"length\");\n    attrs.SetAllAttrs(narrow_dim, narrow_start, length);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass NarrowGradFunctor {\n public:\n  NarrowGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"narrow_grad\").Input(\"dy\").Input(\"like\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& like, const int64_t& dim,\n                           const int64_t& start, const int64_t& length) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dim\", \"start\", \"length\");\n    attrs.SetAllAttrs(dim, start, length);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, like}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SliceFunctor {\n public:\n  SliceFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"slice\").Input(\"x\").Output(\"y\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int64_t>& start,\n                           const std::vector<int64_t>& stop, const std::vector<int64_t>& step,\n                           const Optional<bool>& enable_view_slice) const {\n    if (view::IsViewApplicable(x) && enable_view_slice.value_or(false)) {\n      return view::Slice(x, start, stop, step);\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"start\", \"stop\", \"step\");\n    attrs.SetAllAttrs(start, stop, step);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n protected:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SliceUpdateFunctor {\n public:\n  SliceUpdateFunctor() {\n    op_ =\n        CHECK_JUST(one::OpBuilder(\"slice_update\").Input(\"ref\").Input(\"value\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& ref,\n                           const std::shared_ptr<one::Tensor>& value,\n                           const std::vector<int64_t>& start, const std::vector<int64_t>& stop,\n                           const std::vector<int64_t>& step, bool inplace) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"start\", \"stop\", \"step\");\n    attrs.SetAllAttrs(start, stop, step);\n\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.AddInputs({ref, value})\n             .PromoteInputsToCommonDtype(true, ref->dtype())\n             .Apply());\n\n    if (inplace) {\n      auto outputs = std::make_shared<TensorTuple>(1);\n      
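// In-place path: alias ref as the single output slot so slice_update writes into ref's storage.\n      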
JUST(CheckInplaceValid(ref));\n      JUST(VectorAt(*outputs, 0)) = ref;\n      JUST(OpInterpUtil::Dispatch(*op_, JUST(tensor_processor.GetInputs()), outputs.get(), attrs));\n      return JUST(VectorAt(*outputs, 0));\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*op_, JUST(tensor_processor.GetInputs()), attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SliceGradFunctor {\n public:\n  SliceGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"slice_grad\").Input(\"dy\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy, const Shape& like_shape,\n                           const std::vector<int64_t>& start, const std::vector<int64_t>& stop,\n                           const std::vector<int64_t>& step) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"like_shape\", \"start\", \"stop\", \"step\");\n    attrs.SetAllAttrs(like_shape, start, stop, step);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy}, attrs);\n  }\n\n protected:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UpsampleGradFunctor {\n public:\n  UpsampleGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"upsample_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x, const double& height_scale,\n                           const double& width_scale, const bool& align_corners,\n                           const std::string& data_format, const std::string& interpolation) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"height_scale\", \"width_scale\", \"align_corners\",\n                                                 \"interpolation\", \"data_format\");\n    attrs.SetAllAttrs(height_scale, width_scale, align_corners, interpolation, data_format);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass CopyToDeviceFunctor {\n public:\n  CopyToDeviceFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"copy\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, Symbol<Device> device,\n                           const bool pin_memory) const {\n    if (x->is_local()) {\n      if (auto x_device = JUST(x->device()); x_device != device && x_device->rematable()) {\n        std::dynamic_pointer_cast<vm::RematableTensorStorage>(\n            JUST(x->eager_blob_object())->tensor_storage())\n            ->Remat();\n      }\n    }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"device\", \"pin_memory\");\n    attrs.SetAllAttrs(device, pin_memory);\n\n    // Trigger the construction of device context in advance\n    if (device->enum_type() != DeviceType::kCPU) { TouchEpDevice(device); }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  void TouchEpDevice(Symbol<Device> device) const {\n    ep::DeviceManager* device_mgr =\n        Singleton<ep::DeviceManagerRegistry>::Get()->GetDeviceManagerOrNull(device->enum_type());\n    if (!device_mgr) { return; }\n    device_mgr->GetDevice(device->device_id());\n  }\n\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass CopyFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::string& device_type,\n                           const int64_t& device_id, const bool pin_memory) const {\n    return functional::Copy(x, JUST(Device::New(device_type, device_id)), 
pin_memory);\n  }\n};\n\nclass FlipFunctor {\n public:\n  FlipFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"flip\").Input(\"x\").Output(\"y\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::vector<int32_t>& dims) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dims\");\n    if (dims.empty()) {\n      attrs.SetAllAttrs(dims);\n    } else {\n      std::vector<int32_t> flip_dims = *JUST(CheckAxis(dims, x->ndim()));\n      attrs.SetAllAttrs(flip_dims);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UnfoldTensorFunctor {\n public:\n  UnfoldTensorFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"unfold_tensor\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int32_t& dimension,\n                           const int32_t& size, const int32_t& step) const {\n    // if the input tensor is eager local, then try to return a view of the tensor\n    if (view::IsViewApplicable(x)) { return view::UnfoldTensor(x, dimension, size, step); }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dimension\", \"size\", \"step\");\n    attrs.SetAllAttrs(dimension, size, step);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UnfoldTensorGradFunctor {\n public:\n  UnfoldTensorGradFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"unfold_tensor_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x, const int32_t& dimension,\n                           const int32_t& size, const int32_t& step) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dimension\", \"size\", \"step\");\n    attrs.SetAllAttrs(dimension, size, step);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UpsampleLinear1DFunctor {\n public:\n  UpsampleLinear1DFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"upsample_linear_1d\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const double& scale_factor,\n                           const bool& align_corners,\n                           const Optional<std::vector<int64_t>>& output_size,\n                           const std::string& data_format) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"scale_factor\", \"align_corners\", \"data_format\",\n                                                 \"output_size\");\n    if (output_size.has_value()) {\n      attrs.SetAllAttrs(scale_factor, align_corners, data_format, *JUST(output_size));\n    } else {\n      attrs.SetAllAttrs(scale_factor, align_corners, data_format, NullOpt);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UpsampleLinear1DGradFunctor {\n public:\n  UpsampleLinear1DGradFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"upsample_linear_1d_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x, const double& scale_factor,\n                           const bool& align_corners,\n                           
const Optional<std::vector<int64_t>>& output_size,\n                           const std::string& data_format) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"scale_factor\", \"align_corners\", \"data_format\",\n                                                 \"output_size\");\n    if (output_size.has_value()) {\n      attrs.SetAllAttrs(scale_factor, align_corners, data_format, *JUST(output_size));\n    } else {\n      attrs.SetAllAttrs(scale_factor, align_corners, data_format, NullOpt);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UpsampleNearest1DFunctor {\n public:\n  UpsampleNearest1DFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"upsample_nearest_1d\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const double& scale_factor,\n                           const Optional<std::vector<int64_t>>& output_size,\n                           const std::string& data_format) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"scale_factor\", \"data_format\", \"output_size\");\n    if (output_size) {\n      attrs.SetAllAttrs(scale_factor, data_format, *JUST(output_size));\n    } else {\n      attrs.SetAllAttrs(scale_factor, data_format, NullOpt);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UpsampleNearest1DGradFunctor {\n public:\n  UpsampleNearest1DGradFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"upsample_nearest_1d_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x, const double& scale_factor,\n                           const Optional<std::vector<int64_t>>& output_size,\n                           const std::string& data_format) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"scale_factor\", \"data_format\", \"output_size\");\n    if (output_size) {\n      attrs.SetAllAttrs(scale_factor, data_format, *JUST(output_size));\n    } else {\n      attrs.SetAllAttrs(scale_factor, data_format, NullOpt);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UpsampleNearest2DFunctor {\n public:\n  UpsampleNearest2DFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"upsample_nearest_2d\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const double& height_scale,\n                           const double& width_scale,\n                           const Optional<std::vector<int64_t>>& output_size,\n                           const std::string& data_format) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"height_scale\", \"width_scale\", \"data_format\", \"output_size\");\n    if (output_size) {\n      attrs.SetAllAttrs(height_scale, width_scale, data_format, *JUST(output_size));\n    } else {\n      attrs.SetAllAttrs(height_scale, width_scale, data_format, NullOpt);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UpsampleNearest2DGradFunctor {\n public:\n  UpsampleNearest2DGradFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"upsample_nearest_2d_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n  
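// Backward of nearest-2D upsampling: forwards the same scale/data_format/output_size attrs to the grad op.\n  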
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x, const double& height_scale,\n                           const double& width_scale,\n                           const Optional<std::vector<int64_t>>& output_size,\n                           const std::string& data_format) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"height_scale\", \"width_scale\", \"data_format\", \"output_size\");\n    if (output_size) {\n      attrs.SetAllAttrs(height_scale, width_scale, data_format, *JUST(output_size));\n    } else {\n      attrs.SetAllAttrs(height_scale, width_scale, data_format, NullOpt);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UpsampleBilinear2DFunctor {\n public:\n  UpsampleBilinear2DFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"upsample_bilinear_2d\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const double& height_scale,\n                           const double& width_scale, const bool& align_corners,\n                           const Optional<std::vector<int64_t>>& output_size,\n                           const std::string& data_format) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"height_scale\", \"width_scale\", \"align_corners\",\n                                                 \"data_format\", \"output_size\");\n    if (output_size) {\n      attrs.SetAllAttrs(height_scale, width_scale, align_corners, data_format, *JUST(output_size));\n    } else {\n      attrs.SetAllAttrs(height_scale, width_scale, align_corners, data_format, NullOpt);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UpsampleBilinear2DGradFunctor {\n public:\n  UpsampleBilinear2DGradFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"upsample_bilinear_2d_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x, const double& height_scale,\n                           const double& width_scale, const bool& align_corners,\n                           const Optional<std::vector<int64_t>>& output_size,\n                           const std::string& data_format) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"height_scale\", \"width_scale\", \"align_corners\",\n                                                 \"data_format\", \"output_size\");\n    if (output_size) {\n      attrs.SetAllAttrs(height_scale, width_scale, align_corners, data_format, *JUST(output_size));\n    } else {\n      attrs.SetAllAttrs(height_scale, width_scale, align_corners, data_format, NullOpt);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UpsampleBicubic2DFunctor {\n public:\n  UpsampleBicubic2DFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"upsample_bicubic_2d\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const double& height_scale,\n                           const double& width_scale, const bool& align_corners,\n                           const Optional<std::vector<int64_t>>& output_size,\n                           const std::string& data_format) const {\n    auto& attrs = 
THREAD_CACHED_MUTABLE_ATTR_MAP(\"height_scale\", \"width_scale\", \"align_corners\",\n                                                 \"data_format\", \"output_size\");\n    if (output_size) {\n      attrs.SetAllAttrs(height_scale, width_scale, align_corners, data_format, *JUST(output_size));\n    } else {\n      attrs.SetAllAttrs(height_scale, width_scale, align_corners, data_format, NullOpt);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UpsampleBicubic2DGradFunctor {\n public:\n  UpsampleBicubic2DGradFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"upsample_bicubic_2d_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x, const double& height_scale,\n                           const double& width_scale, const bool& align_corners,\n                           const Optional<std::vector<int64_t>>& output_size,\n                           const std::string& data_format) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"height_scale\", \"width_scale\", \"align_corners\",\n                                                 \"data_format\", \"output_size\");\n    if (output_size) {\n      attrs.SetAllAttrs(height_scale, width_scale, align_corners, data_format, *JUST(output_size));\n    } else {\n      attrs.SetAllAttrs(height_scale, width_scale, align_corners, data_format, NullOpt);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UpsampleNearest3DFunctor {\n public:\n  UpsampleNearest3DFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"upsample_nearest_3d\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const double& depth_scale,\n                           const double& height_scale, const double& width_scale,\n                           const Optional<std::vector<int64_t>>& output_size,\n                           const std::string& data_format) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"depth_scale\", \"height_scale\", \"width_scale\",\n                                                 \"data_format\", \"output_size\");\n    if (output_size) {\n      attrs.SetAllAttrs(depth_scale, height_scale, width_scale, data_format, *JUST(output_size));\n    } else {\n      attrs.SetAllAttrs(depth_scale, height_scale, width_scale, data_format, NullOpt);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UpsampleNearest3DGradFunctor {\n public:\n  UpsampleNearest3DGradFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"upsample_nearest_3d_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x, const double& depth_scale,\n                           const double& height_scale, const double& width_scale,\n                           const Optional<std::vector<int64_t>>& output_size,\n                           const std::string& data_format) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"depth_scale\", \"height_scale\", \"width_scale\",\n                                                 \"data_format\", \"output_size\");\n    if (output_size) {\n      
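// An explicit output_size is forwarded alongside the scale hints; which of the two the kernel\n      // ultimately honors is decided downstream of this functor (a reading of the code above, not\n      // a documented contract).\n      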
attrs.SetAllAttrs(depth_scale, height_scale, width_scale, data_format, *JUST(output_size));\n    } else {\n      attrs.SetAllAttrs(depth_scale, height_scale, width_scale, data_format, NullOpt);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UpsampleTrilinear3DFunctor {\n public:\n  UpsampleTrilinear3DFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"upsample_trilinear_3d\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const double& depth_scale,\n                           const double& height_scale, const double& width_scale,\n                           const bool& align_corners,\n                           const Optional<std::vector<int64_t>>& output_size,\n                           const std::string& data_format) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"depth_scale\", \"height_scale\", \"width_scale\",\n                                                 \"align_corners\", \"data_format\", \"output_size\");\n    if (output_size) {\n      attrs.SetAllAttrs(depth_scale, height_scale, width_scale, align_corners, data_format,\n                        *JUST(output_size));\n    } else {\n      attrs.SetAllAttrs(depth_scale, height_scale, width_scale, align_corners, data_format,\n                        NullOpt);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UpsampleTrilinear3DGradFunctor {\n public:\n  UpsampleTrilinear3DGradFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"upsample_trilinear_3d_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x, const double& depth_scale,\n                           const double& height_scale, const double& width_scale,\n                           const bool& align_corners,\n                           const Optional<std::vector<int64_t>>& output_size,\n                           const std::string& data_format) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"depth_scale\", \"height_scale\", \"width_scale\",\n                                                 \"align_corners\", \"data_format\", \"output_size\");\n    if (output_size) {\n      attrs.SetAllAttrs(depth_scale, height_scale, width_scale, align_corners, data_format,\n                        *JUST(output_size));\n    } else {\n      attrs.SetAllAttrs(depth_scale, height_scale, width_scale, align_corners, data_format,\n                        NullOpt);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UnsortedSegmentSumLikeFunctor {\n public:\n  UnsortedSegmentSumLikeFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"unsorted_segment_sum_like\")\n                         .Input(\"data\")\n                         .Input(\"segment_ids\")\n                         .Input(\"like\")\n                         .Output(\"out\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& segment_ids,\n                           const std::shared_ptr<one::Tensor>& like, const int64_t& axis) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\");\n    attrs.SetAllAttrs(axis);\n    
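// The \"like\" input supplies only the reference shape of the reduction output (hence the\n    // _like suffix); values from x are sum-scattered into that shape along \"axis\" according to\n    // segment_ids.\n    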
return OpInterpUtil::Dispatch<Tensor>(*op_, {x, segment_ids, like}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UnsortedSegmentSumFunctor {\n public:\n  UnsortedSegmentSumFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"unsorted_segment_sum\")\n                         .Input(\"data\")\n                         .Input(\"segment_ids\")\n                         .Output(\"out\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& segment_ids, const int64_t& axis,\n                           const int64_t& num_segments) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"num_segments\");\n    attrs.SetAllAttrs(axis, num_segments);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, segment_ids}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass TrilFunctor {\n public:\n  TrilFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"tril\").Input(\"in\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int64_t& diagonal) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"diagonal\", \"is_floating_fill_value\", \"integer_fill_value\");\n    attrs.SetAllAttrs(diagonal, false, static_cast<int64_t>(0));\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass InplaceTrilFunctor {\n public:\n  InplaceTrilFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"tril\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int64_t& diagonal) const {\n    JUST(CheckInplaceValid(x));\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"diagonal\");\n    attrs.SetAllAttrs(diagonal);\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    outputs->at(0) = x;\n    JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs));\n    return outputs->at(0);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass TriuFunctor {\n public:\n  TriuFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"triu\").Input(\"in\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int64_t& diagonal) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"diagonal\");\n    attrs.SetAllAttrs(diagonal);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass InplaceTriuFunctor {\n public:\n  InplaceTriuFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"triu\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int64_t& diagonal) const {\n    JUST(CheckInplaceValid(x));\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"diagonal\");\n    attrs.SetAllAttrs(diagonal);\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    outputs->at(0) = x;\n    JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs));\n    return outputs->at(0);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass DiagFunctor {\n public:\n  DiagFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"diag\").Input(\"in\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int32_t& diagonal) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"diagonal\");\n    attrs.SetAllAttrs(diagonal);\n    
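// As with torch.diag: a 1-D input produces a matrix with the input on the selected diagonal,\n    // while a 2-D input has that diagonal extracted (diagonal > 0 is above the main diagonal,\n    // diagonal < 0 below).\n    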
return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass DiagGradFunctor {\n public:\n  DiagGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"diag_grad\").Input(\"dy\").Input(\"in\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x, const int32_t& diagonal) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"diagonal\");\n    attrs.SetAllAttrs(diagonal);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass DiagonalFunctor {\n public:\n  DiagonalFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"diagonal\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int32_t& offset,\n                           const int32_t& dim1, const int32_t& dim2) const {\n    int64_t ndims = x->shape()->NumAxes();\n    int32_t p_dim1 = dim1;\n    int32_t p_dim2 = dim2;\n    p_dim1 = JUST(maybe_wrap_dim(p_dim1, ndims));\n    p_dim2 = JUST(maybe_wrap_dim(p_dim2, ndims));\n    CHECK_NE_OR_RETURN(p_dim1, p_dim2)\n        << Error::RuntimeError() << \"diagonal dimensions cannot be identical \" << dim1 << \", \"\n        << dim2;\n    if (view::IsViewApplicable(x)) {\n      return view::Diagonal(x, offset, p_dim1, p_dim2);\n    } else {\n      auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"offset\");\n      attrs.SetAllAttrs(offset);\n      std::vector<int32_t> input_index{p_dim1, p_dim2};\n      for (int32_t i = 0; i < ndims; i++) {\n        if (i != p_dim1 && i != p_dim2) { input_index.push_back(i); }\n      }\n      std::shared_ptr<one::Tensor> d_x = JUST(Transpose(x, input_index));\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {d_x}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass DiagonalGradFunctor {\n public:\n  DiagonalGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"diagonal_grad\").Input(\"dy\").Input(\"in\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x, const int32_t& offset) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"offset\");\n    attrs.SetAllAttrs(offset);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\n// Only for ddp gradient grouping\nclass SliceView1dContiguousFunctor {\n public:\n  SliceView1dContiguousFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, int64_t start,\n                           int64_t end) const {\n    if (view::IsViewApplicable(x)) { return JUST(view::Slice(x, {start}, {end}, {1})); }\n    return JUST(functional::Slice(x, {start}, {end}, {1}, /*enable_view_slice=*/true));\n  }\n};\n\nclass TensorGetItemFunctor {\n public:\n  TensorGetItemFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const TensorIndex& index) const {\n    if (x->is_local() && !(LazyMode::is_enabled()) && x->requires_grad() == false\n        && index.size() == 1 && index[0].IsInteger()) {\n      // NOTE: speed up in special case, e.g. 
dataloader (refer to PyTorch)\n      // PyTorch call chain: tensor getitem -> select -> as_strided\n      // OneFlow call chain: tensor getitem -> as_strided\n      return ApplySelectIndexing(x, index);\n    }\n\n    std::vector<detail::Slice> slice_indices;\n    TensorTuple tensor_indices;\n    std::vector<int64_t> target_dims;\n    std::vector<int64_t> expand_dims;\n    JUST(PrepareSliceIndices(index, *(x->shape()), &slice_indices, &tensor_indices, &expand_dims,\n                             &target_dims));\n\n    auto expand_input = x;\n    for (int i = 0; i < expand_dims.size(); ++i) {\n      int64_t dim = expand_dims.at(i);\n      expand_input = JUST(functional::ExpandDims(expand_input, dim + i));\n    }\n    int64_t ndims = expand_input->shape()->NumAxes();\n    CHECK_EQ_OR_RETURN(slice_indices.size(), ndims)\n        << Error::RuntimeError() << \"Failed to prepare slice indices.\";\n    Shape target_shape(DimVector(target_dims.begin(), target_dims.end()));\n\n    std::vector<int64_t> start(ndims), end(ndims), step(ndims);\n    for (int i = 0; i < ndims; ++i) {\n      const auto& slice = slice_indices.at(i);\n      start[i] = slice.start();\n      end[i] = slice.end();\n      step[i] = slice.step();\n    }\n    bool is_identity = [&]() {\n      if (target_shape.NumAxes() == 0) { return false; }\n      for (int i = 0; i < ndims; ++i) {\n        if (start[i] != 0 || end[i] != expand_input->shape()->At(i) || step[i] != 1) {\n          return false;\n        }\n      }\n      return true;\n    }();\n    std::shared_ptr<one::Tensor> result;\n    if (is_identity) {\n      result = expand_input;\n    } else {\n      result = JUST(Slice(expand_input, start, end, step, /*enable_view_slice=*/true));\n    }\n\n    Shape shape(DimVector(target_dims.begin(), target_dims.end()));\n    if (shape != *(result->shape())) { result = JUST(Reshape(result, shape)); }\n    if (!tensor_indices.empty()) {\n      JUST(UnifyInputAndIndicesOnDevice(x, tensor_indices));\n      result = JUST(ApplyAdvancedIndexing(result, tensor_indices));\n    }\n    return result;\n  }\n};\n\nclass TensorSetItemFunctor {\n public:\n  TensorSetItemFunctor() {}\n  Maybe<void> operator()(const std::shared_ptr<one::Tensor>& x, const TensorIndex& index,\n                         const std::shared_ptr<one::Tensor>& value) const {\n    std::vector<detail::Slice> slice_indices;\n    TensorTuple tensor_indices;\n    std::vector<int64_t> expand_dims;\n    std::vector<int64_t> target_dims;\n    JUST(PrepareSliceIndices(index, *(x->shape()), &slice_indices, &tensor_indices, &expand_dims,\n                             &target_dims));\n    auto expand_input = x;\n    if (!expand_dims.empty()) {\n      CHECK_OR_RETURN(view::IsViewApplicable(x)) << \"expand dims requires view to be enabled, \"\n                                                    \"please try to set ONEFLOW_DISABLE_VIEW=0\";\n      for (int i = 0; i < expand_dims.size(); ++i) {\n        int64_t dim = expand_dims[i];\n        expand_input = JUST(functional::ExpandDims(expand_input, dim + i));\n      }\n    }\n    int64_t ndims = expand_input->shape()->NumAxes();\n    CHECK_EQ_OR_RETURN(slice_indices.size(), ndims)\n        << Error::RuntimeError() << \"Failed to prepare slice indices.\";\n\n    Shape target_shape(DimVector(target_dims.begin(), target_dims.end()));\n    if (target_shape.Count(0) == 0) { return Maybe<void>::Ok(); }\n    const auto& value_shape = value->shape();\n    bool matched = [&]() {\n      for (int i = 0; i < value_shape->NumAxes() - 
target_shape.NumAxes(); ++i) {\n        if (value_shape->At(i) != 1) { return false; }\n      }\n      return true;\n    }();\n    CHECK_OR_RETURN(matched) << Error::RuntimeError() << \"Tensor size mismatch. Target sizes: \"\n                             << target_shape.ToString()\n                             << \", value sizes: \" << value_shape->ToString();\n    std::shared_ptr<one::Tensor> value_tensor(value);\n    // TODO: replace reshape with unsqueeze using the view mechanism.\n    // After this, each scalar tensor becomes a tensor with one dimension.\n    for (auto& tensor : tensor_indices) {\n      if (tensor && tensor->ndim() == 0) { tensor = JUST(functional::Reshape(tensor, Shape({1}))); }\n    }\n\n    DimVector slice_dims(ndims);\n    std::vector<int64_t> start(ndims), end(ndims), step(ndims);\n    for (int i = 0; i < ndims; ++i) {\n      const auto& slice = slice_indices[i];\n      start[i] = slice.start();\n      end[i] = slice.end();\n      step[i] = slice.step();\n      slice_dims[i] = (end[i] - start[i] + step[i] - 1) / step[i];\n    }\n    if (tensor_indices.empty()) {\n      Shape slice_shape(slice_dims);\n      if (slice_shape != *(value_tensor->shape())) {\n        // NOTE:\n        // 1. The value shape must be broadcastable to the target shape.\n        // 2. The slice shape must have the same element count as the target shape.\n        //\n        // So we first expand to target_shape and then reshape to slice_shape.\n        //\n        // For example:\n        // x = flow.rand(2, 3, 4)\n        // y = flow.rand(3)\n        // x[:, :, 1] = y\n        //\n        // value_shape = (3,), target_shape = (2, 3), slice_shape = (2, 3, 1)\n        // We must change the value shape to slice_shape since the SliceUpdate op is used.\n        if (target_shape != *(value_tensor->shape()) && target_shape.NumAxes() > 0) {\n          value_tensor = JUST(Expand(value_tensor, target_shape));\n        }\n        if (slice_shape != *(value_tensor->shape())) {\n          value_tensor = JUST(Reshape(value_tensor, slice_shape));\n        }\n      }\n      JUST(SliceUpdate(expand_input, value_tensor, start, end, step, /*inplace=*/true));\n    } else {\n      bool is_identity = [&]() {\n        if (target_shape.NumAxes() == 0) { return false; }\n        for (int i = 0; i < ndims; ++i) {\n          if (start[i] != 0 || end[i] != expand_input->shape()->At(i) || step[i] != 1) {\n            return false;\n          }\n        }\n        return true;\n      }();\n      std::shared_ptr<one::Tensor> result;\n      if (is_identity) {\n        result = expand_input;\n      } else {\n        if (expand_input->is_local()) {\n          CHECK_OR_RETURN(view::IsViewApplicable(expand_input))\n              << \"combined slice setitem requires view to be enabled, please try to set \"\n                 \"ONEFLOW_DISABLE_VIEW=0\";\n          result = JUST(Slice(expand_input, start, end, step, /*enable_view_slice=*/true));\n        } else {\n          // global tensor\n          result = JUST(Slice(expand_input, start, end, step, /*enable_view_slice=*/false));\n        }\n      }\n      const Shape& slice_result_shape = *(result->shape());\n      if (target_shape != slice_result_shape) {\n        result = JUST(functional::View(result, target_shape));\n      }\n\n      JUST(UnifyInputAndIndicesOnDevice(result, tensor_indices));\n      result = JUST(ApplyAdvancedIndexingUpdate(result, tensor_indices, value));\n\n      // Write the sliced tensor back to the original tensor.\n      if (result->is_global()) {\n        if 
(*result->shape() != slice_result_shape) {\n          CHECK_EQ_OR_RETURN(result->shape()->elem_cnt(), slice_result_shape.elem_cnt())\n              << Error::RuntimeError()\n              << \"Global tensor size mismatch. Target sizes: \" << slice_result_shape.ToString()\n              << \", value sizes: \" << result->shape()->ToString();\n          result = JUST(functional::View(result, slice_result_shape));\n        }\n        JUST(SliceUpdate(expand_input, result, start, end, step, /*inplace=*/true));\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nclass CastLikeFunctor {\n public:\n  CastLikeFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"cast_like\").Input(\"in\").Input(\"dtype_like\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& like) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, like});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ElementwiseMinimumGradFunctor {\n public:\n  ElementwiseMinimumGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"elementwise_minimum_backward\")\n                         .Input(\"dz\")\n                         .Input(\"x\")\n                         .Input(\"y\")\n                         .Output(\"dx\")\n                         .Output(\"dy\")\n                         .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dz,\n                                const std::shared_ptr<one::Tensor>& x,\n                                const std::shared_ptr<one::Tensor>& y) const {\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dz, x, y});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ElementwiseMaximumGradFunctor {\n public:\n  ElementwiseMaximumGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"elementwise_maximum_backward\")\n                         .Input(\"dz\")\n                         .Input(\"x\")\n                         .Input(\"y\")\n                         .Output(\"dx\")\n                         .Output(\"dy\")\n                         .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dz,\n                                const std::shared_ptr<one::Tensor>& x,\n                                const std::shared_ptr<one::Tensor>& y) const {\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dz, x, y});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass DivGradFunctor {\n public:\n  DivGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"broadcast_div_grad\")\n                         .Input(\"dz\")\n                         .Input(\"z\")\n                         .Input(\"y\")\n                         .Output(\"dy\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dz,\n                           const std::shared_ptr<one::Tensor>& z,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dz, z, y});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass BroadcastPowXGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y,\n                           const std::shared_ptr<one::Tensor>& dz) const {\n    auto y_sub_one = JUST(functional::ScalarSub(y, 1, /*alpha=*/1, /*inplace=*/false));\n    auto result = 
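// d(x^y)/dx = y * x^(y - 1): the chain below computes BroadcastPow(x, y - 1), multiplies by y\n        // and dz, then reduces the broadcast dims back to x's shape.\n        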
functional::sequence_function(functional::BroadcastPow)\n                      .then(std::bind(functional::Mul, std::placeholders::_1, y))\n                      .then(std::bind(functional::Mul, std::placeholders::_1, dz))\n                      .then(std::bind(functional::BroadcastReduceSumLike, std::placeholders::_1, x))\n                      .call(x, y_sub_one);\n    return result;\n  }\n};\n\nclass BroadcastPowYGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y,\n                           const std::shared_ptr<one::Tensor>& dz) const {\n    auto result =\n        functional::sequence_function(functional::BroadcastPow)\n            .then(std::bind(functional::Mul, std::placeholders::_1, JUST(functional::Log(x))))\n            .then(std::bind(functional::Mul, std::placeholders::_1, dz))\n            .then(std::bind(functional::BroadcastReduceSumLike, std::placeholders::_1, y))\n            .call(x, y);\n    return result;\n  }\n};\n\nclass IdentityFunctor {\n public:\n  IdentityFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"identity\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {in});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass AmpWhiteIdentityFunctor {\n public:\n  AmpWhiteIdentityFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"amp_white_identity\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {in});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass AmpBlackIdentityFunctor {\n public:\n  AmpBlackIdentityFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"amp_black_identity\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {in});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ReduceSumLikeFunctor {\n public:\n  ReduceSumLikeFunctor() {\n    op_ =\n        CHECK_JUST(one::OpBuilder(\"reduce_sum_like\").Input(\"x\").Input(\"like\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& like,\n                           const std::vector<int32_t>& axis) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\");\n    attrs.SetAllAttrs(axis);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, JUST(like->detach())}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass BroadcastReduceSumLikeFunctor {\n public:\n  BroadcastReduceSumLikeFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input,\n                           const std::shared_ptr<Tensor>& like) const {\n    const auto& in_shape = *(input->shape());\n    const auto& like_shape = *(like->shape());\n    if (in_shape != like_shape) {\n      const Shape& left_extended_shape =\n          CreateLeftExtendedShape(ShapeView(like_shape), in_shape.NumAxes());\n      if (in_shape == left_extended_shape) {\n        return JUST(ReshapeLike(input, like));\n      } else {\n        const AxisVector& broadcast_axis_vec = left_extended_shape.Axes4BroadcastTo(in_shape);\n        return JUST(ReduceSumLike(\n            input, like,\n            
std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()}));\n      }\n    }\n    return JUST(Identity(input));\n  }\n};\n\nclass SplitFunctor {\n public:\n  SplitFunctor() {}\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x,\n                                const int64_t& split_size_or_sections, const int64_t& dim) const {\n    int64_t axis = dim;\n    axis = JUST(maybe_wrap_dim(axis, x->ndim()));\n    CHECK_GE_OR_RETURN(split_size_or_sections, 0)\n        << Error::RuntimeError() << \"split expects split_size to be non-negative, but got split_size=\"\n        << split_size_or_sections;\n    int64_t dim_size = x->shape()->At(axis);\n    int64_t num_splits =\n        std::max<int64_t>((dim_size + split_size_or_sections - 1) / split_size_or_sections, 1);\n    TensorTuple splits(num_splits);\n    int64_t last_split_size =\n        split_size_or_sections - (split_size_or_sections * num_splits - dim_size);\n    for (int i = 0; i < num_splits; ++i) {\n      int64_t length = i < num_splits - 1 ? split_size_or_sections : last_split_size;\n      splits[i] = JUST(Narrow(x, axis, i * split_size_or_sections, length));\n    }\n    return splits;\n  }\n};\n\nclass UnbindFunctor {\n public:\n  UnbindFunctor() {}\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x, const int64_t& dim) const {\n    int32_t axis = dim;\n    const int32_t ndim = x->ndim();\n    axis = JUST(maybe_wrap_dim(axis, ndim));\n    int32_t dim_size = x->shape()->At(axis);\n    std::shared_ptr<TensorTuple> chunk_res = JUST(functional::Chunk(x, dim_size, axis));\n    TensorTuple unbinds(dim_size);\n    std::vector<int32_t> dims = {axis};\n    for (int i = 0; i < dim_size; ++i) {\n      unbinds[i] = JUST(functional::Squeeze((*chunk_res)[i], dims));\n    }\n    return unbinds;\n  }\n};\n\nclass ChunkFunctor {\n public:\n  ChunkFunctor() {}\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x, const int64_t& chunks,\n                                const int64_t& dim) const {\n    const int64_t ndim = x->ndim();\n    int64_t inferred_dim = dim;\n    CHECK_OR_RETURN(ndim > 0) << Error::RuntimeError()\n                              << \"chunk expects at least a 1-dimensional tensor.\";\n    CHECK_OR_RETURN(chunks > 0) << Error::RuntimeError()\n                                << \"chunk expects `chunks` to be greater than 0, got: \" << chunks;\n    inferred_dim = JUST(maybe_wrap_dim(inferred_dim, ndim));\n\n    const auto dim_size = x->shape()->At(inferred_dim);\n    int64_t split_size = (dim_size + chunks - 1) / chunks;\n    if (split_size == 0 && dim_size == 0) {\n      std::vector<int64_t> split_sizes(chunks, split_size);\n      split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size);\n      return functional::SplitWithSize(x, split_sizes, inferred_dim);\n    } else {\n      return functional::Split(x, split_size, inferred_dim);\n    }\n  }\n};\n\nclass SplitLikeFunctor {\n public:\n  SplitLikeFunctor() {\n    ops_.resize(kMaxInputCount);\n    for (int n = 1; n < ops_.size(); ++n) {\n      ops_[n] = CHECK_JUST(one::OpBuilder(\"split_like\")\n                               .Input(\"in\")\n                               .Input(\"like\", n + 1)\n                               .Output(\"out\", n + 1)\n                               .Build());\n    }\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x, const TensorTuple& like,\n                                const int64_t& axis) const {\n    CHECK_GE_OR_RETURN(like.size(), 2)\n        
<< Error::RuntimeError() << \"like.size() must not less than 2, but got \" << like.size();\n    CHECK_LE_OR_RETURN(like.size(), kMaxInputCount)\n        << Error::RuntimeError() << \"like.size() must not greater than \" << kMaxInputCount\n        << \", but got \" << like.size();\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\");\n    attrs.SetAllAttrs(axis);\n    TensorTuple inputs(like.size() + 1);\n    inputs[0] = x;\n    for (int i = 0; i < like.size(); ++i) { inputs[i + 1] = JUST(like[i]->detach()); }\n    return OpInterpUtil::Dispatch<TensorTuple>(*ops_.at(like.size() - 1), inputs, attrs);\n  }\n\n private:\n  std::vector<std::shared_ptr<OpExpr>> ops_;\n};\n\nclass SplitWithSizeFunctor {\n public:\n  SplitWithSizeFunctor() {}\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x,\n                                const std::vector<int64_t>& split_size_or_sections,\n                                const int64_t& dim) const {\n    int64_t axis = dim;\n    axis = JUST(maybe_wrap_dim(axis, x->ndim()));\n    int64_t dim_size = x->shape()->At(axis);\n    int64_t num_splits = split_size_or_sections.size();\n    TensorTuple splits(num_splits);\n    int64_t start_idx = 0;\n    for (int i = 0; i < num_splits; ++i) {\n      int64_t length = split_size_or_sections[i];\n      CHECK_GE_OR_RETURN(length, 0) << Error::RuntimeError()\n                                    << \"split_with_sizes expects split_sizes have only \"\n                                       \"non-negative entries, but split_sizes[\"\n                                    << i << \"] = \" << length;\n      splits[i] = JUST(Narrow(x, axis, start_idx, length));\n      start_idx += length;\n    }\n    CHECK_EQ_OR_RETURN(start_idx, dim_size)\n        << Error::RuntimeError() << \"split_with_sizes expects split_sizes to sum exactly to \"\n        << dim_size << \" (input tensor's size at dimension \" << axis << \"), \"\n        << \"but got sum(split_sizes)=\" << start_idx;\n    return splits;\n  }\n};\n\nclass BatchGatherFunctor {\n public:\n  BatchGatherFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"batch_gather\").Input(\"in\").Input(\"indices\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in,\n                           const std::shared_ptr<one::Tensor>& indices) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {in, indices});\n  }\n\n protected:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UnsortedBatchSegmentSumFunctor {\n public:\n  UnsortedBatchSegmentSumFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"unsorted_batch_segment_sum\")\n                         .Input(\"data\")\n                         .Input(\"segment_ids\")\n                         .Output(\"out\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& data,\n                           const std::shared_ptr<one::Tensor>& segment_ids,\n                           const int64_t& num_segments) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"num_segments\");\n    attrs.SetAllAttrs(num_segments);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {data, segment_ids}, attrs);\n  }\n\n protected:\n  std::shared_ptr<OpExpr> op_;\n};\n\ntemplate<bool inplace>\nclass MaskedFillFunctor {\n public:\n  MaskedFillFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"masked_fill\").Input(\"x\").Input(\"mask\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n             
              const std::shared_ptr<one::Tensor>& mask, const Scalar& value) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"float_operand\", \"has_float_operand\", \"int_operand\",\n                                       \"has_int_operand\", \"bool_operand\", \"has_bool_operand\");\n    if (IsFloatingDataType(x->dtype()->data_type())) {\n      attrs.SetAllAttrs(value.As<double>(), true, NullOpt, false, NullOpt, false);\n    } else if (IsIntegralDataType(x->dtype()->data_type())) {\n      attrs.SetAllAttrs(NullOpt, false, value.As<int64_t>(), true, NullOpt, false);\n    } else if (IsBoolDataType(x->dtype()->data_type())) {\n      attrs.SetAllAttrs(NullOpt, false, NullOpt, false, value.As<bool>(), true);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"Only floating, integral or boolean data types are supported.\";\n    }\n    const auto& x_shape = *(x->shape());\n    const auto& mask_shape = *(mask->shape());\n\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    if (inplace) {\n      JUST(CheckInplaceValid(x));\n      (*outputs)[0] = x;\n    }\n\n    if (x_shape != mask_shape) {\n      Shape max_shape = Shape::Ones(std::max(x_shape.NumAxes(), mask_shape.NumAxes()));\n      const Shape& x_extend_shape =\n          CreateLeftExtendedShape(ShapeView(x_shape), max_shape.NumAxes());\n      const Shape& mask_extend_shape =\n          CreateLeftExtendedShape(ShapeView(mask_shape), max_shape.NumAxes());\n      FOR_RANGE(int64_t, i, 0, max_shape.NumAxes()) {\n        max_shape.Set(i, std::max(x_extend_shape.At(i), mask_extend_shape.At(i)));\n      }\n      JUST(OpInterpUtil::Dispatch(*op_, {JUST(Expand(x, max_shape)), JUST(Expand(mask, max_shape))},\n                                  outputs.get(), attrs));\n      return outputs->at(0);\n    }\n\n    JUST(OpInterpUtil::Dispatch(*op_, {x, mask}, outputs.get(), attrs));\n    return outputs->at(0);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass MeshgridFunctor {\n public:\n  Maybe<TensorTuple> operator()(const TensorTuple& tensors, const std::string& indexing) const {\n    int size = tensors.size();\n    CHECK_GT_OR_RETURN(size, 0) << Error::RuntimeError()\n                                << \"meshgrid expects a non-empty TensorList\";\n    for (int i = 0; i < size - 1; ++i) {\n      const auto& cur_tensor = JUST(VectorAt(tensors, i));\n      const auto& next_tensor = JUST(VectorAt(tensors, i + 1));\n      CHECK_OR_RETURN(cur_tensor->dtype() == next_tensor->dtype())\n          << Error::RuntimeError() << \"meshgrid expects all tensors to have the same dtype\";\n      if (cur_tensor->is_local()) {\n        CHECK_OR_RETURN(next_tensor->is_local())\n            << Error::RuntimeError() << \"meshgrid expects all tensors to be local tensors\";\n        CHECK_OR_RETURN(JUST(cur_tensor->device())->type() == JUST(next_tensor->device())->type())\n            << Error::RuntimeError() << \"meshgrid expects all tensors to have the same device\";\n      } else {\n        CHECK_OR_RETURN(!next_tensor->is_local())\n            << Error::RuntimeError() << \"meshgrid expects all tensors to be global tensors\";\n        CHECK_OR_RETURN(JUST(cur_tensor->parallel_desc()) == JUST(next_tensor->parallel_desc()))\n            << Error::RuntimeError() << \"meshgrid expects all tensors to have the same placement\";\n      }\n    }\n\n    std::vector<std::shared_ptr<Tensor>> tensor_consts(tensors.begin(), tensors.end());\n\n    bool swap_first_and_second_tensors = false;\n    if (indexing == \"xy\") {\n      
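// Cartesian \"xy\" indexing is implemented on top of the \"ij\" path: swap the first two inputs\n      // here and swap the first two output grids back at the end of this function.\n      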
swap_first_and_second_tensors = (size >= 2);\n      if (swap_first_and_second_tensors) { std::swap(tensor_consts[0], tensor_consts[1]); }\n    } else {\n      CHECK_EQ_OR_RETURN(indexing, \"ij\") << Error::RuntimeError()\n                                         << \"meshgrid: indexing must be one of \\\"xy\\\" or \\\"ij\\\", \"\n                                            \"but received: \"\n                                         << indexing;\n    }\n\n    TensorTuple grids(size);\n    DimVector grids_vec(size);\n    for (int i = 0; i < size; ++i) {\n      CHECK_LE_OR_RETURN(tensor_consts[i]->shape()->NumAxes(), 1)\n          << Error::RuntimeError() << \"Expected scalar or 1D tensor in the tensor list but got \"\n          << tensor_consts[i]->shape()->NumAxes();\n      if (tensor_consts[i]->shape()->NumAxes() == 0) {\n        grids_vec[i] = 1;\n      } else {\n        grids_vec[i] = tensor_consts[i]->shape()->At(0);\n      }\n    }\n    Shape grids_shape(grids_vec);\n\n    DimVector view_shape_vec(size, 1);\n    Shape view_shape(view_shape_vec);\n    for (int i = 0; i < size; ++i) {\n      view_shape.Set(i, -1);\n      std::shared_ptr<one::Tensor> reshaped = JUST(Reshape(tensor_consts.at(i), view_shape));\n      grids[i] = JUST(Expand(reshaped, grids_shape));\n      view_shape.Set(i, 1);\n    }\n\n    if (swap_first_and_second_tensors) { std::swap(grids[0], grids[1]); }\n\n    return grids;\n  }\n};\n\nclass IndexSelectFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int64_t& dim,\n                           const std::shared_ptr<one::Tensor>& index) const {\n    const int64_t input_num_axes = input->shape()->NumAxes();\n    const int64_t index_num_axes = index->shape()->NumAxes();\n    CHECK_LE_OR_RETURN(index_num_axes, 1)\n        << Error::IndexError() << \"index_select(): Index is supposed to be a vector\";\n    bool index_dtype_flag =\n        (index->dtype()->data_type() == kInt32) || (index->dtype()->data_type() == kInt64);\n    CHECK_EQ_OR_RETURN(index_dtype_flag, true)\n        << Error::RuntimeError() << \"index_select(): Expected dtype int32 or int64 for index\";\n    int64_t new_dim = dim;\n    new_dim = JUST(maybe_wrap_dim(new_dim, input_num_axes));\n    return JUST(functional::Gather(input, index, new_dim));\n  }\n};\n\nnamespace {\n\nMaybe<Tensor> LocalTensorTo(const std::shared_ptr<Tensor>& x, Symbol<Device> device,\n                            const Symbol<DType>& dtype, const bool& copy) {\n  std::shared_ptr<Tensor> tensor = x;\n  if (device != JUST(x->device())) { tensor = JUST(Copy(tensor, device, /*pin_memory=*/false)); }\n  if (dtype != x->dtype()) { tensor = JUST(Cast(tensor, dtype, /*pin_memory=*/false)); }\n  if (copy && tensor == x) { tensor = JUST(Copy(tensor, device, /*pin_memory=*/false)); }\n  return tensor;\n}\n\nMaybe<Tensor> GlobalTensorTo(const std::shared_ptr<Tensor>& x, const std::string& device_type,\n                             const Symbol<DType>& dtype, const bool& copy) {\n  std::shared_ptr<Tensor> tensor;\n  auto input_placement = JUST(x->parallel_desc());\n  std::string input_device_tag = input_placement->device_tag();\n  if (input_device_tag == \"gpu\") { input_device_tag = \"cuda\"; }\n  if (device_type == input_device_tag) {\n    if (dtype == x->dtype()) {\n      return (copy ? 
JUST(x->clone()) : x);\n    } else {\n      return JUST(Cast(x, dtype, /*pin_memory=*/false));\n    }\n  }\n  if (LazyMode::is_enabled()) {\n    if (dtype != x->dtype()) { tensor = JUST(Cast(x, dtype, /*pin_memory=*/false)); }\n    if (device_type != JUST(x->parallel_desc())->device_tag()) {\n      tensor = JUST(Copy(tensor ? tensor : x, device_type, 0, /*pin_memory=*/false));\n    }\n    return tensor;\n  } else {\n    CheckMetaConsistency(x).GetOrThrow();\n    auto placement = JUST(ReplacePlacementDeviceTag(input_placement, device_type));\n    auto nd_sbp = JUST(x->nd_sbp());\n    std::vector<Symbol<SbpParallel>> sbp_tuple(nd_sbp->sbp_parallel().size());\n    for (int i = 0; i < sbp_tuple.size(); ++i) { sbp_tuple[i] = nd_sbp->sbp_parallel().Get(i); }\n    tensor = JUST(GlobalToLocal(x, /*copy=*/false));\n    Symbol<Device> device = JUST(Device::New(device_type));\n    tensor = JUST(LocalTensorTo(tensor, device, dtype, copy));\n    JUST(tensor->set_requires_grad(x->requires_grad()));\n    return JUST(LocalToGlobal(tensor, placement, sbp_tuple, *(x->shape()), dtype,\n                              /* sync_data */ true, /*copy=*/false));\n  }\n}\n\n}  // namespace\n\nclass ToFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input,\n                           const Optional<std::string>& device_,\n                           const Optional<Symbol<DType>>& dtype_, bool copy) const {\n    Symbol<DType> dtype = dtype_.value_or(input->dtype());\n    if (input->is_global()) {\n      std::string device_type = device_.value_or(JUST(input->parallel_desc())->device_tag());\n      CHECK_OR_RETURN(ep::DeviceManagerRegistry::GetDeviceTypeByDeviceTypeName(device_type)\n                      != DeviceType::kInvalidDevice)\n          << Error::RuntimeError()\n          << \"Only string device without device id (e.g. \\\"cpu\\\" or \\\"cuda\\\") is expected \"\n          << \"for global tensor, but got \" << device_.value_or(\"\");\n      return JUST(GlobalTensorTo(input, device_type, dtype, copy));\n    } else {\n      Symbol<Device> device =\n          device_\n              .map([](const std::shared_ptr<std::string>& str) -> Symbol<Device> {\n                return CHECK_JUST(Device::ParseAndNew(*str));\n              })\n              .value_or(JUST(input->device()));\n      return JUST(LocalTensorTo(input, device, dtype, copy));\n    }\n  }\n};\n\nclass To2Functor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input,\n                           const Optional<Symbol<Device>>& device_,\n                           const Optional<Symbol<DType>>& dtype_, bool copy) const {\n    if (input->is_global()) {\n      if (!device_.has_value()) {\n        std::string device_type = JUST(input->parallel_desc())->device_tag();\n        return JUST(GlobalTensorTo(input, device_type, dtype_.value_or(input->dtype()), copy));\n      } else {\n        if (!GlobalMode::is_enabled()) {\n          CHECK_OR_RETURN(!device_.has_value())\n              << Error::RuntimeError()\n              << \"Only string device without device id (e.g. 
\\\"cpu\\\" or \\\"cuda\\\") is expected \"\n              << \"for global tensor, but got \" << device_.value_or(Symbol<Device>())->ToRepr();\n        }\n        std::string device_type = device_.value_or(Symbol<Device>())->type();\n        return JUST(GlobalTensorTo(input, device_type, dtype_.value_or(input->dtype()), copy));\n      }\n    } else {\n      auto dtype = dtype_.value_or(input->dtype());\n      auto device = device_.value_or(JUST(input->device()));\n      return JUST(LocalTensorTo(input, device, dtype, copy));\n    }\n  }\n};\n\nclass To3Functor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input,\n                           const Optional<Symbol<DType>>& dtype_, bool copy) const {\n    Symbol<DType> dtype = dtype_.value_or(input->dtype());\n    if (input->is_global()) {\n      return GlobalTensorTo(input, JUST(input->parallel_desc())->device_tag(), dtype, copy);\n    } else {\n      auto device = JUST(input->device());\n      return LocalTensorTo(input, device, dtype, copy);\n    }\n  }\n};\n\nclass To4Functor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input,\n                           const std::shared_ptr<Tensor>& other, bool copy) const {\n    CHECK_OR_RETURN(!input->is_global() && !other->is_global())\n        << Error::RuntimeError()\n        << \"tensor.to(other) can only be called when tensor and other are local tensors\";\n    Symbol<DType> dtype = other->dtype();\n    Symbol<Device> device = JUST(other->device());\n    return LocalTensorTo(input, device, dtype, copy);\n  }\n};\n\nclass ToDeviceFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input,\n                           const Optional<std::string>& device_) const {\n    Symbol<DType> dtype = input->dtype();\n    const bool copy = false;\n    if (input->is_global()) {\n      std::string device_type = device_.value_or(JUST(input->parallel_desc())->device_tag());\n      CHECK_OR_RETURN(ep::DeviceManagerRegistry::GetDeviceTypeByDeviceTypeName(device_type)\n                      != DeviceType::kInvalidDevice)\n          << Error::RuntimeError()\n          << \"Only string device without device id (eg. 
\\\"cpu\\\" or \\\"cuda\\\") is expected \"\n          << \"for global tensor, but got \" << device_.value_or(\"\");\n      return JUST(GlobalTensorTo(input, device_type, dtype, copy));\n    } else {\n      Symbol<Device> device =\n          device_\n              .map([](const std::shared_ptr<std::string>& str) -> Symbol<Device> {\n                return CHECK_JUST(Device::ParseAndNew(*str));\n              })\n              .value_or(JUST(input->device()));\n      return JUST(LocalTensorTo(input, device, dtype, copy));\n    }\n  }\n};\n\nclass ToMemoryFormatFunctor {\n public:\n  ToMemoryFormatFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"convert_memory_format\").Input(\"in\").Output(\"out\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input, MemoryFormat memory_format) const {\n    if (input->memory_format() == memory_format) { return input; }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"memory_format\");\n    attrs.SetAllAttrs(memory_format);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass TopKFunctor {\n public:\n  TopKFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"top_k\").Input(\"in\").Output(\"out\").Build()); }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<Tensor>& input, const int32_t k,\n                                const Optional<int32_t>& dim, const bool largest,\n                                const bool sorted) const {\n    auto outputs = std::make_shared<TensorTuple>(2);\n    std::shared_ptr<Tensor> values;\n    std::shared_ptr<Tensor> indices;\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"k\", \"sorted\");\n    attrs.SetAllAttrs(k, sorted);\n\n    int32_t dim_value = dim.value_or(-1);\n    int32_t axis = dim_value;\n    axis = JUST(maybe_wrap_dim(axis, input->ndim()));\n    if (axis == input->ndim() - 1) {\n      if (largest) {\n        indices = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs));\n      } else {\n        auto neg_input = JUST(ScalarMul(input, -1, false));\n        indices = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {neg_input}, attrs));\n      }\n      values = JUST(DimGather(input, axis, indices, false));\n\n    } else {\n      auto perm = JUST(GetPermWhenTransposeAxisToLastDim(input->ndim(), dim_value));\n      auto x = JUST(Transpose(input, *perm));\n      if (largest) {\n        indices = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs));\n      } else {\n        auto neg_input = JUST(ScalarMul(x, -1, false));\n        indices = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {neg_input}, attrs));\n      }\n      auto inversed_perm = JUST(GetInversedPerm(*perm));\n      indices = JUST(Transpose(indices, *inversed_perm));\n      values = JUST(DimGather(input, axis, indices, false));\n    }\n    (*outputs)[0] = values;\n    (*outputs)[1] = indices;\n    return outputs;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass InTopKFunctor {\n public:\n  InTopKFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"in_top_k\").Input(\"targets\").Input(\"predictions\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& targets,\n                           const std::shared_ptr<Tensor>& predictions, int32_t k) const {\n    CHECK_EQ_OR_RETURN(targets->shape()->At(0), predictions->shape()->At(0))\n        << Error::RuntimeError() << \"The num of targets must equal the num of predictions\";\n    CHECK_EQ_OR_RETURN(targets->ndim(), 1)\n        << 
Error::RuntimeError() << \"The dimension of targets must be 1\";\n    CHECK_EQ_OR_RETURN(predictions->ndim(), 2)\n        << Error::RuntimeError() << \"The dimension of predictions must be 2\";\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"k\");\n    attrs.SetAllAttrs(k);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {targets, predictions}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass TensorBufferToTensorFunctor {\n public:\n  TensorBufferToTensorFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"tensor_buffer_to_tensor\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input, const Shape& instance_shape,\n                           const Symbol<DType>& dtype) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"instance_shape\", \"dtype\");\n    attrs.SetAllAttrs(instance_shape, dtype->data_type());\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass TensorToTensorBufferFunctor {\n public:\n  TensorToTensorBufferFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"tensor_to_tensor_buffer\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input, int32_t instance_dims) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"instance_dims\");\n    attrs.SetAllAttrs(instance_dims);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GenTensorBufferFunctor {\n public:\n  GenTensorBufferFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"gen_tensor_buffer\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const Shape& shape, const std::vector<Shape>& shape_list,\n                           const std::vector<float>& value_list, const Symbol<DType>& dtype,\n                           bool dynamic_out) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"shape\", \"shape_list\", \"value_list\", \"data_type\",\n                                                 \"dynamic_out\");\n    attrs.SetAllAttrs(shape, shape_list, value_list, dtype->data_type(), dynamic_out);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass RepeatFunctor {\n public:\n  RepeatFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const Shape& repeat_shape) const {\n    Shape input_shape = *(input->shape());\n    std::vector<int32_t> input_reshape_vec;\n    std::vector<int32_t> expand_shape_vec;\n    std::vector<int32_t> output_reshape_vec;\n\n    int32_t numaxes_diff = repeat_shape.NumAxes() - input_shape.NumAxes();\n    CHECK_GE_OR_RETURN(numaxes_diff, 0) << Error::RuntimeError()\n                                        << \"Number of dimensions of repeat dims can not be \"\n                                           \"smaller than number of dimensions of tensor\";\n\n    for (int32_t i = repeat_shape.NumAxes() - 1; i >= 0; i--) {\n      if (i >= numaxes_diff) {\n        int32_t input_shape_val = input_shape.At(i - numaxes_diff);\n        int32_t repeat_shape_val = repeat_shape.At(i);\n        if (repeat_shape_val > 1) {\n          if (input_shape_val > 1) {\n            input_reshape_vec.insert(input_reshape_vec.begin(), input_shape_val);\n            input_reshape_vec.insert(input_reshape_vec.begin(), 1);\n            expand_shape_vec.insert(expand_shape_vec.begin(), 
input_shape_val);\n            expand_shape_vec.insert(expand_shape_vec.begin(), repeat_shape_val);\n            output_reshape_vec.insert(output_reshape_vec.begin(),\n                                      repeat_shape_val * input_shape_val);\n          } else {\n            input_reshape_vec.insert(input_reshape_vec.begin(), input_shape_val);\n            expand_shape_vec.insert(expand_shape_vec.begin(), repeat_shape_val);\n            output_reshape_vec.insert(output_reshape_vec.begin(), repeat_shape_val);\n          }\n        } else {\n          input_reshape_vec.insert(input_reshape_vec.begin(), input_shape_val);\n          // For 0-size tensor, align with PyTorch.\n          if (repeat_shape_val == 0) {\n            expand_shape_vec.insert(expand_shape_vec.begin(), 0);\n            output_reshape_vec.insert(output_reshape_vec.begin(), 0);\n          } else {\n            expand_shape_vec.insert(expand_shape_vec.begin(), input_shape_val);\n            output_reshape_vec.insert(output_reshape_vec.begin(), input_shape_val);\n          }\n        }\n      } else {\n        expand_shape_vec.insert(expand_shape_vec.begin(), repeat_shape.At(i));\n        output_reshape_vec.insert(output_reshape_vec.begin(), repeat_shape.At(i));\n      }\n    }\n    Shape input_reshape(DimVector(input_reshape_vec.begin(), input_reshape_vec.end()));\n    Shape expand_shape(DimVector(expand_shape_vec.begin(), expand_shape_vec.end()));\n    Shape output_reshape(DimVector(output_reshape_vec.begin(), output_reshape_vec.end()));\n    std::shared_ptr<one::Tensor> reshaped_tensor = JUST(Reshape(input, input_reshape));\n    std::shared_ptr<one::Tensor> expanded_tensor = JUST(Expand(reshaped_tensor, expand_shape));\n    std::shared_ptr<one::Tensor> result = JUST(Reshape(expanded_tensor, output_reshape));\n    return result->contiguous();\n  }\n};\n\nclass RepeatInterLeaveIndexFunctor {\n public:\n  RepeatInterLeaveIndexFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"repeat_interleave\").Input(\"in\").Input(\"cumsum\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& cumsum,\n                           const int32_t& repeat_num) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"repeat_num\");\n    attrs.SetAllAttrs(static_cast<int64_t>(repeat_num));\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input, cumsum}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass RepeatInterLeaveIntFunctor {\n public:\n  RepeatInterLeaveIntFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int32_t& repeats,\n                           const Optional<int32_t>& dim) const {\n    CHECK_OR_RETURN(input->is_local())\n        << Error::RuntimeError() << \"repeat_interleave only supports local tensors for now\";\n    std::shared_ptr<one::Tensor> res;\n    if (!dim.has_value()) {\n      std::shared_ptr<one::Tensor> flatten_input = JUST(Flatten(input, 0, -1));\n      std::shared_ptr<one::Tensor> repeats_expand = JUST(\n          Expand(JUST(Constant(Shape{1}, Scalar(repeats), DType::Int32(), JUST(input->device()))),\n                 Shape{flatten_input->shape()->At(0)}));\n      std::shared_ptr<one::Tensor> cumsum = JUST(Cumsum(repeats_expand, 0, DType::Int32()));\n      int64_t output_size = flatten_input->shape()->At(0);\n      if (repeats > 0) { output_size *= repeats; }\n      res = JUST(IndexSelect(flatten_input, 0,\n                        
     JUST(RepeatInterLeaveIndex(repeats_expand, cumsum, output_size))));\n    } else {\n      int32_t dim_ = JUST(dim);\n      const auto& input_shape = input->shape();\n      const int64_t num_axes = input_shape->NumAxes();\n      dim_ = JUST(maybe_wrap_dim(dim_, num_axes));\n      std::shared_ptr<one::Tensor> repeats_expand = JUST(\n          Expand(JUST(Constant(Shape{1}, Scalar(repeats), DType::Int32(), JUST(input->device()))),\n                 Shape{input->shape()->At(dim_)}));\n      std::shared_ptr<one::Tensor> cumsum = JUST(Cumsum(repeats_expand, 0, DType::Int32()));\n      int64_t output_size = input->shape()->At(dim_);\n      if (repeats > 0) { output_size *= repeats; }\n      res = JUST(IndexSelect(input, dim_,\n                             JUST(RepeatInterLeaveIndex(repeats_expand, cumsum, output_size))));\n    }\n    return res;\n  }\n};\n\nclass RepeatInterLeaveTensorFunctor {\n public:\n  RepeatInterLeaveTensorFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& repeats, const int32_t& dim,\n                           const Optional<int32_t>& output_size) const {\n    CHECK_OR_RETURN(input->is_local())\n        << Error::RuntimeError() << \"repeat_interleave only supports local tensors for now\";\n    const auto repeats_shape = repeats->shape();\n    const int64_t repeat_num_axes = repeats_shape->NumAxes();\n    CHECK_OR_RETURN(repeat_num_axes == 1)\n        << Error::RuntimeError() << \"repeat_interleave only accepts a 1D vector of repeats\";\n    CHECK_OR_RETURN(repeats->dtype() == DType::Int64())\n        << Error::RuntimeError() << \"repeats has to be Long tensor\";\n\n    std::vector<int64_t> repeats_value(repeats_shape->elem_cnt());\n    if (!output_size.has_value()) {\n      const auto& callback = [&](ep::Stream* stream,\n                                 const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n        SyncAutoMemcpy(stream, repeats_value.data(), eager_blob_object->dptr(),\n                       repeats_value.size() * sizeof(int64_t), memory::MakeHostMemCase(),\n                       eager_blob_object->mem_case());\n      };\n      JUST(SyncAccessTensorWithTimeOut(repeats, callback, \"const\"));\n      for (const auto x : repeats_value) {\n        CHECK_OR_RETURN(x >= 0) << Error::RuntimeError() << \"repeats can not be negative\";\n      }\n    } else {\n      // With output_size given we can skip the device sync: the pre-sized zeros plus\n      // this single entry still sum to output_size in the accumulate below.\n      repeats_value.push_back(JUST(output_size));\n    }\n    int32_t dim_ = dim;\n    const auto& input_shape = input->shape();\n    const int64_t num_axes = input_shape->NumAxes();\n    dim_ = JUST(maybe_wrap_dim(dim_, num_axes));\n    CHECK_OR_RETURN(repeats_shape->At(0) == input->shape()->At(dim_))\n        << Error::RuntimeError() << \"repeats must have the same size as input along dim\";\n    std::shared_ptr<one::Tensor> cumsum = JUST(Cumsum(repeats, 0, DType::Int32()));\n    const int64_t output_size_value =\n        std::accumulate(repeats_value.begin(), repeats_value.end(), int64_t{0});\n    return JUST(\n        IndexSelect(input, dim_, JUST(RepeatInterLeaveIndex(repeats, cumsum, output_size_value))));\n  }\n};\n\nclass TileFunctor {\n public:\n  TileFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const Shape& dims) const {\n    std::vector<int32_t> new_dims_vec;\n    int32_t numaxes_diff = input->shape()->NumAxes() - dims.NumAxes();\n    for (int32_t i = dims.NumAxes() - 1; i >= 0; i--) {\n      CHECK_GE_OR_RETURN(dims.At(i), 0)\n          << 
Error::RuntimeError() << \"Trying to create tensor with negative dimension \"\n          << dims.At(i);\n      new_dims_vec.insert(new_dims_vec.begin(), dims.At(i));\n    }\n    for (int32_t i = 0; i < numaxes_diff; i++) { new_dims_vec.insert(new_dims_vec.begin(), 1); }\n    Shape new_dims(DimVector(new_dims_vec.begin(), new_dims_vec.end()));\n    return JUST(Repeat(input, new_dims));\n  }\n};\n\nclass TransposeAllDimPropertyFunctor {\n public:\n  TransposeAllDimPropertyFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    const int64_t ndim = x->ndim();\n    std::vector<int32_t> permute;\n    permute.resize(ndim);\n    std::iota(permute.begin(), permute.end(), 0);\n    std::reverse(permute.begin(), permute.end());\n    return Transpose(x, permute);\n  }\n};\n\nclass TransposeAllDimFunctionFunctor {\n public:\n  TransposeAllDimFunctionFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    const int64_t ndim = x->ndim();\n    CHECK_OR_RETURN(ndim <= 2) << Error::RuntimeError()\n                               << \"t() expects a tensor with <= 2 dimensions, but input tensor is \"\n                               << ndim << \"D\";\n    if (ndim == 0 || ndim == 1) { return x; }\n    return Transpose2dim(x, 0, 1);\n  }\n};\n\nclass ReshapeLikeFunctor {\n public:\n  ReshapeLikeFunctor() {\n    op_ =\n        CHECK_JUST(one::OpBuilder(\"reshape_like\").Input(\"in\").Input(\"like\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& like) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, JUST(like->detach())});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass PinMemoryFunctor {\n public:\n  PinMemoryFunctor() {\n    op_ =\n        CHECK_JUST(one::OpBuilder(\"slice_update\").Input(\"ref\").Input(\"value\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input) const {\n    // TODO(zhaoluyang): support global tensor.pin_memory()\n    CHECK_OR_RETURN(input->is_local() && !(LazyMode::is_enabled()))\n        << Error::RuntimeError() << \"Tensor.pin_memory() only supports local tensors for now!\";\n    // If the tensor is already pinned, just return it.\n    if (JUST(JUST(input->AsLocalTensor())->is_pinned())) { return input; }\n    auto shape = input->shape();\n    auto device = JUST(input->device());\n    const bool requires_grad = input->requires_grad();\n    CHECK_EQ_OR_RETURN(device->enum_type(), DeviceType::kCPU)\n        << Error::RuntimeError() << \"cannot pin tensor with device: \" << device->ToString()\n        << \", only dense CPU tensors can be pinned.\";\n\n    auto empty = JUST(functional::Empty(*shape.get(), input->dtype(), device, requires_grad,\n                                        /*pin_memory=*/true));\n    const int32_t ndim = input->ndim();\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"start\", \"stop\", \"step\");\n    if (ndim == 0) {\n      // TODO(wyg): use TensorSetItem after supporting non-requires_grad tensor inplace\n      // for 0-dim tensor\n      empty = JUST(functional::ExpandDims(empty, 0));              // expand to [1, ]\n      auto expand_input = JUST(functional::ExpandDims(input, 0));  // expand to [1, ]\n      attrs.SetAllAttrs(std::vector<int64_t>{0}, std::vector<int64_t>{1}, std::vector<int64_t>{1});\n      auto outputs = TensorTuple{empty};\n      JUST(OpInterpUtil::Dispatch(*op_, TensorTuple{empty, 
expand_input}, &outputs, attrs));\n      return outputs[0];\n    } else {\n      std::vector<int64_t> starts(ndim, 0);\n      std::vector<int64_t> stops(ndim);\n      std::vector<int64_t> steps(ndim, 1);\n      for (int i = 0; i < ndim; ++i) { stops[i] = input->shape()->At(i); }\n      attrs.SetAllAttrs(starts, stops, steps);\n      JUST(empty->set_requires_grad(requires_grad));\n      auto outputs = TensorTuple{empty};\n      JUST(OpInterpUtil::Dispatch(*op_, TensorTuple{empty, input}, &outputs, attrs));\n      return outputs[0];\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FillFunctor {\n public:\n  FillFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"fill_\").Input(\"in\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in, const Scalar& value) const {\n    JUST(CheckInplaceValid(in));\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"floating_value\", \"is_floating_value\", \"integral_value\");\n    if (IsFloatingDataType(in->dtype()->data_type())) {\n      attrs.SetAllAttrs(value.As<double>(), true, NullOpt);\n    } else if (IsIntegralDataType(in->dtype()->data_type())) {\n      attrs.SetAllAttrs(NullOpt, false, value.As<int64_t>());\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"Only floating point or integral data types are supported.\";\n    }\n    auto outputs = std::make_shared<TensorTuple>(1);\n    (*outputs)[0] = in;\n    JUST(OpInterpUtil::Dispatch(*op_, {in}, outputs.get(), attrs));\n    return (*outputs)[0];\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FillTensorFunctor {\n public:\n  FillTensorFunctor() {\n    op_ =\n        CHECK_JUST(one::OpBuilder(\"fill_tensor_\").Input(\"in\").Input(\"value\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in,\n                           const std::shared_ptr<one::Tensor>& value) const {\n    JUST(CheckInplaceValid(in));\n    const int64_t ndim = value->ndim();\n    CHECK_EQ_OR_RETURN(ndim, 0)\n        << Error::RuntimeError()\n        << \"fill_ only supports 0-dimension value tensor but got tensor with \" << ndim\n        << \" dimensions.\";\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.PromoteInputsToCommonDtype(true, in->dtype())\n             .AddInputs({in, value})\n             .Apply());\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    auto outputs = std::make_shared<TensorTuple>(1);\n    (*outputs)[0] = in;\n    JUST(OpInterpUtil::Dispatch(*op_, {input_tuple[0], input_tuple[1]}, outputs.get()));\n    return (*outputs)[0];\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass IndexAddFunctor {\n public:\n  IndexAddFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"index_add\")\n                         .Input(\"input\")\n                         .Input(\"index\")\n                         .Input(\"source\")\n                         .Output(\"output\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int64_t& dim,\n                           const std::shared_ptr<one::Tensor>& index,\n                           const std::shared_ptr<one::Tensor>& source, const Scalar& alpha) const {\n    CHECK_OR_RETURN(source->ndim() == 0 || index->shape()->Count(0) == source->shape()->At(dim))\n        << \"index_add_(): Number of indices (\" << index->shape()->Count(0)\n        << \") should be equal to source.size(dim) (\" << source->shape()->At(dim) << \")\";\n    
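// Indices must be integral (int32/int64) and match source.size(dim); a typical\n    // call (sketch, using this file's registered name): \n    //   auto y = JUST(functional::IndexAdd(x, /*dim=*/0, idx, src, /*alpha=*/1.0));\n    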
CHECK_OR_RETURN(index->dtype()->data_type() == DataType::kInt32\n                    || index->dtype()->data_type() == DataType::kInt64)\n        << \"Input(Index) holds the wrong type, it holds \"\n        << DataType_Name(index->dtype()->data_type())\n        << \", but int32_t or int64_t is required\";\n    const float alpha_value = alpha.As<float>();\n    int64_t dim_ = dim;\n    dim_ = JUST(maybe_wrap_dim(dim_, input->ndim()));\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dim\", \"alpha\");\n    attrs.SetAllAttrs(dim_, alpha_value);\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.PromoteInputsToCommonDtype(true, input->dtype())\n             .AddInputs({input, source})\n             .Apply());\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index, input_tuple.at(1)}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass IndexAddInplaceFunctor {\n public:\n  IndexAddInplaceFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"index_add\")\n                         .Input(\"input\")\n                         .Input(\"index\")\n                         .Input(\"source\")\n                         .Output(\"output\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int64_t& dim,\n                           const std::shared_ptr<one::Tensor>& index,\n                           const std::shared_ptr<one::Tensor>& source, const Scalar& alpha) const {\n    CHECK_OR_RETURN(source->ndim() == 0 || index->shape()->Count(0) == source->shape()->At(dim))\n        << \"index_add_(): Number of indices (\" << index->shape()->Count(0)\n        << \") should be equal to source.size(dim) (\" << source->shape()->At(dim) << \")\";\n    CHECK_OR_RETURN(index->dtype()->data_type() == DataType::kInt32\n                    || index->dtype()->data_type() == DataType::kInt64)\n        << \"Input(Index) holds the wrong type, it holds \"\n        << DataType_Name(index->dtype()->data_type())\n        << \", but int32_t or int64_t is required\";\n    const float alpha_value = alpha.As<float>();\n    int64_t dim_ = dim;\n    dim_ = JUST(maybe_wrap_dim(dim_, input->ndim()));\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dim\", \"alpha\");\n    attrs.SetAllAttrs(dim_, alpha_value);\n    JUST(CheckInplaceValid(input));\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    outputs->at(0) = input;\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.PromoteInputsToCommonDtype(true, input->dtype())\n             .AddInputs({input, source})\n             .Apply());\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    JUST(OpInterpUtil::Dispatch(*op_, {input, index, input_tuple.at(1)}, outputs.get(), attrs));\n    return outputs->at(0);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass BroadcastShapesFunctor {\n public:\n  Maybe<Shape> operator()(const std::vector<Shape>& shapes) const {\n    return InferUnifiedShapeForBroadcasting(shapes);\n  }\n};\n\nclass BroadcastTensorsFunctor {\n public:\n  Maybe<TensorTuple> operator()(const TensorTuple& tensors) const {\n    if (tensors.empty()) { return Error::RuntimeError() << \"tensors should not be empty.\"; }\n\n    Shape shape_to_broadcast;\n    std::deque<bool> need_to_broadcast;\n\n    std::tie(shape_to_broadcast, need_to_broadcast) =\n        
*JUST(InferUnifiedShapeForBroadcastingWithInfo([&tensors]() {\n          std::vector<Shape> shapes;\n          for (auto& x : tensors) { shapes.push_back(*x->shape()); }\n          return shapes;\n        }()));\n\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>();\n    for (size_t i = 0; i < tensors.size(); ++i) {\n      outputs->emplace_back(need_to_broadcast.at(i)  // NOLINT\n                                ? JUST(functional::Expand(tensors.at(i), shape_to_broadcast))\n                                : tensors.at(i));\n    }\n    return outputs;\n  }\n};\n\nclass BinCountFunctor {\n public:\n  BinCountFunctor() {\n    op_ = CHECK_JUST(OpBuilder(\"bincount\").Input(\"in\").Output(\"out\").Build());\n    weight_op_ =\n        CHECK_JUST(OpBuilder(\"bincount\").Input(\"in\").Input(\"weight\").Output(\"out\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input, const Optional<Tensor>& weight,\n                           const Optional<int64_t>& minlength) const {\n    CHECK_OR_RETURN(!input->dtype()->is_floating_point()) << \"bincount only supports integer tensors\";\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.AddInputs({input}, DType::Int64()).Apply());\n    const auto x = JUST(tensor_processor.GetInputs()).at(0);\n    std::shared_ptr<Tensor> local_tensor = x;\n    int64_t max = 0;\n\n    // Compute the value range: min must be non-negative, and max + 1 gives the bin count.\n    {\n      if (x->is_global()) { local_tensor = JUST(GlobalToLocal(x, false)); }\n      auto tensor_min = JUST(functional::Min(local_tensor));\n      int64_t min = 0;\n      const auto& callback_min =\n          [&](ep::Stream* stream, const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n            SyncAutoMemcpy(stream, &min, eager_blob_object->dptr(), sizeof(min),\n                           memory::MakeHostMemCase(), eager_blob_object->mem_case());\n          };\n      JUST(SyncAccessTensorWithTimeOut(tensor_min, callback_min, \"const\"));\n      CHECK_GE_OR_RETURN(min, 0) << \"bincount only supports 1-d non-negative integral inputs.\";\n\n      auto tensor_max = JUST(functional::Max(local_tensor));\n      const auto& callback_max =\n          [&](ep::Stream* stream, const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n            SyncAutoMemcpy(stream, &max, eager_blob_object->dptr(), sizeof(max),\n                           memory::MakeHostMemCase(), eager_blob_object->mem_case());\n          };\n      JUST(SyncAccessTensorWithTimeOut(tensor_max, callback_max, \"const\"));\n      max += 1;\n    }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"size\");\n    if (minlength) {\n      CHECK_GE_OR_RETURN(JUST(minlength), 0) << \"minlength should be >= 0\";\n      max = std::max(JUST(minlength), max);\n    }\n    attrs.SetAllAttrs(max);\n    if (weight) {\n      CHECK_EQ_OR_RETURN(JUST(weight)->nelement(), x->nelement())\n          << \"input and weights should have the same length\";\n      return OpInterpUtil::Dispatch<Tensor>(*weight_op_, {x, JUST(weight)}, attrs);\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> weight_op_;\n};\n\nclass UniqueFunctor {\n public:\n  UniqueFunctor() {\n    op_ = CHECK_JUST(\n        OpBuilder(\"unique\").Input(\"x\").Output(\"y\").Output(\"idx\").Output(\"num_unique\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, const bool sorted,\n                           const Symbol<DType>& dtype) const {\n    
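// The unique kernel returns a full-sized y plus its valid length in num_unique;\n    // roughly: y, idx, n = unique(flatten(x)); return y[0:n]. The slice below\n    // depends on syncing num_unique back to the host first.\n    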
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"out_idx\", \"sorted\");\n    DataType out_idx = dtype->data_type();\n    attrs.SetAllAttrs(out_idx, sorted);\n    std::shared_ptr<TensorTuple> output = JUST(\n        OpInterpUtil::Dispatch<TensorTuple>(*op_, {JUST(functional::Flatten(x, 0, -1))}, attrs));\n    int64_t num_unique = 0;\n    std::shared_ptr<Tensor> num_unique_tensor = output->at(2);\n    {\n      if (num_unique_tensor->is_global()) {\n        num_unique_tensor = JUST(GlobalToLocal(num_unique_tensor, false));\n      }\n      const auto& callback = [&](ep::Stream* stream,\n                                 const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n        // The copy width follows the out_idx dtype; num_unique is zero-initialized,\n        // so a narrower (e.g. int32) result still reads correctly on little-endian hosts.\n        SyncAutoMemcpy(stream, &num_unique, eager_blob_object->dptr(),\n                       GetSizeOfDataType(dtype->data_type()), memory::MakeHostMemCase(),\n                       eager_blob_object->mem_case());\n      };\n      JUST(SyncAccessTensorWithTimeOut(num_unique_tensor, callback, \"const\"));\n    }\n    return functional::Slice(output->at(0), /*start=*/{0}, /*end=*/{num_unique}, /*step=*/{1},\n                             false);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass UniqueWithCountsFunctor {\n public:\n  UniqueWithCountsFunctor() {\n    unique_op_ = CHECK_JUST(\n        OpBuilder(\"unique\").Input(\"x\").Output(\"y\").Output(\"idx\").Output(\"num_unique\").Build());\n    unique_with_counts_op_ = CHECK_JUST(OpBuilder(\"unique_with_counts\")\n                                            .Input(\"x\")\n                                            .Output(\"y\")\n                                            .Output(\"idx\")\n                                            .Output(\"num_unique\")\n                                            .Output(\"count\")\n                                            .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<Tensor>& x, const bool sorted,\n                                const bool return_inverse, const bool return_counts,\n                                const Symbol<DType>& dtype) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"out_idx\", \"sorted\");\n    attrs.SetAllAttrs(dtype->data_type(), sorted);\n    std::shared_ptr<TensorTuple> output;\n    if (return_counts) {\n      output = JUST(OpInterpUtil::Dispatch<TensorTuple>(\n          *unique_with_counts_op_, {JUST(functional::Flatten(x, 0, -1))}, attrs));\n    } else {\n      output = JUST(OpInterpUtil::Dispatch<TensorTuple>(\n          *unique_op_, {JUST(functional::Flatten(x, 0, -1))}, attrs));\n    }\n\n    int64_t num_unique = 0;\n    std::shared_ptr<Tensor> num_unique_tensor = output->at(2);\n    {\n      if (num_unique_tensor->is_global()) {\n        num_unique_tensor = JUST(GlobalToLocal(num_unique_tensor, false));\n      }\n      const auto& callback = [&](ep::Stream* stream,\n                                 const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n        SyncAutoMemcpy(stream, &num_unique, eager_blob_object->dptr(),\n                       GetSizeOfDataType(dtype->data_type()), memory::MakeHostMemCase(),\n                       eager_blob_object->mem_case());\n      };\n      JUST(SyncAccessTensorWithTimeOut(num_unique_tensor, callback, \"const\"));\n    }\n    auto result = std::make_shared<TensorTuple>();\n    const auto& y = JUST(\n        functional::Slice(output->at(0), /*start=*/{0}, /*end=*/{num_unique}, /*step=*/{1}, false));\n    result->emplace_back(y);\n    if (return_inverse) {\n      
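// idx maps every flattened input element to the position of its unique value;\n      // reshaping restores the original input shape for the inverse mapping.\n      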
result->emplace_back(JUST(functional::Reshape(output->at(1), *x->shape())));\n    }\n    if (return_counts) {\n      const auto count = JUST(functional::Slice(output->at(3), /*start=*/{0}, /*end=*/{num_unique},\n                                                /*step=*/{1}, false));\n      result->emplace_back(count);\n    }\n    return result;\n  }\n\n private:\n  std::shared_ptr<OpExpr> unique_op_;\n  std::shared_ptr<OpExpr> unique_with_counts_op_;\n};\n\nclass BaddBmmFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& batch1,\n                           const std::shared_ptr<one::Tensor>& batch2, const double& beta,\n                           const double& alpha) const {\n    const int32_t batch1_ndim = batch1->ndim();\n    const int32_t batch2_ndim = batch2->ndim();\n    CHECK_EQ_OR_RETURN(batch1_ndim, 3) << Error::RuntimeError() << \"batch1 must be a 3D tensor\";\n    CHECK_EQ_OR_RETURN(batch2_ndim, 3) << Error::RuntimeError() << \"batch2 must be a 3D tensor\";\n    CHECK_EQ_OR_RETURN(batch1->dim(0), batch2->dim(0))\n        << Error::RuntimeError() << \"batch1 and batch2 must have the same number of batches, got \"\n        << batch1->dim(0) << \" and \" << batch2->dim(0);\n    CHECK_EQ_OR_RETURN(batch1->dim(2), batch2->dim(1))\n        << \"Incompatible matrix sizes for bmm (\" << batch1->dim(1) << \"x\" << batch1->dim(2)\n        << \" and \" << batch2->dim(1) << \"x\" << batch2->dim(2) << \")\";\n\n    if (beta == 0.0) {\n      // In stable diffusion, the beta param is always 0.0, so we can avoid the add and mul ops\n      // to optimize speed and bandwidth in cuda.\n      return JUST(functional::BatchMatMul(batch1, batch2, false, false, alpha));\n    } else {\n      // TODO: add a fused kernel to optimize speed and bandwidth in cuda\n      return JUST(\n          functional::Add(JUST(functional::ScalarMul(beta, input)),\n                          JUST(functional::BatchMatMul(batch1, batch2, false, false, alpha)),\n                          /*alpha=*/1.0, /*inplace=*/false));\n    }\n  }\n};\n\nclass SortFunctor {\n public:\n  Maybe<TensorTuple> operator()(const std::shared_ptr<Tensor>& input, const int32_t& dim,\n                                const bool descending) const {\n    auto outputs = std::make_shared<TensorTuple>(2);\n    std::shared_ptr<Tensor> values;\n    std::shared_ptr<Tensor> indices;\n    int32_t axis = dim;\n    axis = JUST(maybe_wrap_dim(axis, input->ndim()));\n    std::string direction(\"ASCENDING\");\n    if (descending) { direction.assign(\"DESCENDING\"); }\n    if (axis == input->ndim() - 1) {\n      indices = JUST(ArgSort(input, direction));\n      values = JUST(DimGather(input, axis, indices, false));\n    } else {\n      std::shared_ptr<std::vector<int32_t>> perm =\n          JUST(GetPermWhenTransposeAxisToLastDim(input->ndim(), dim));\n      auto x = JUST(Transpose(input, *perm));\n      auto indices_temp = JUST(ArgSort(x, direction));\n      auto inversed_perm = JUST(GetInversedPerm(*perm));\n      indices = JUST(Transpose(indices_temp, *inversed_perm));\n      values = JUST(DimGather(input, axis, indices, false));\n    }\n    (*outputs)[0] = values;\n    (*outputs)[1] = indices;\n    return outputs;\n  }\n};\n\nclass CloneFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input) const { return input->clone(); }\n};\n\nclass FusedCodegeexQkvReshapeFunctor {\n public:\n  FusedCodegeexQkvReshapeFunctor() {\n    op_ = 
CHECK_JUST(one::OpBuilder(\"fused_codegeex_qkv_reshape\")\n                         .Input(\"query\")\n                         .Input(\"key\")\n                         .Input(\"value\")\n                         .Output(\"new_query\")\n                         .Output(\"new_key\")\n                         .Output(\"new_value\")\n                         .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& query,\n                                const std::shared_ptr<one::Tensor>& key,\n                                const std::shared_ptr<one::Tensor>& value,\n                                const int32_t num_attention_heads) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"num_attention_heads\");\n    attrs.SetAllAttrs(num_attention_heads);\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {query, key, value}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\n}  // namespace impl\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<impl::ArgMaxFunctor>(\"ArgMax\");\n  m.add_functor<impl::ArgMinFunctor>(\"ArgMin\");\n  m.add_functor<impl::GlobalTensorConstantFunctor>(\"GlobalTensorConstant\");\n  m.add_functor<impl::TensorConstantFunctor>(\"TensorConstant\");\n  m.add_functor<impl::GlobalConstantFunctor>(\"GlobalConstant\");\n  m.add_functor<impl::ConstantFunctor>(\"Constant\");\n  m.add_functor<impl::GlobalEmptyFunctor>(\"GlobalEmpty\");\n  m.add_functor<impl::EmptyFunctor>(\"Empty\");\n  m.add_functor<impl::EmptyStridedFunctor>(\"EmptyStrided\");\n  m.add_functor<impl::ZerosLikeFunctor>(\"ZerosLike\");\n  m.add_functor<impl::OnesLikeFunctor>(\"OnesLike\");\n  m.add_functor<impl::FullLikeFunctor>(\"FullLike\");\n  m.add_functor<impl::FlattenFunctor>(\"Flatten\");\n  m.add_functor<impl::FillFunctor>(\"Fill\");\n  m.add_functor<impl::FillTensorFunctor>(\"FillTensor\");\n  m.add_functor<impl::WhereFunctor>(\"Where\");\n  m.add_functor<impl::WhereScalarXFunctor>(\"WhereScalarX\");\n  m.add_functor<impl::WhereScalarYFunctor>(\"WhereScalarY\");\n  m.add_functor<impl::WhereScalarXYFunctor>(\"WhereScalarXY\");\n  m.add_functor<impl::ArgWhereFunctor>(\"ArgWhere\");\n  m.add_functor<impl::NonZeroFunctor>(\"NonZero\");\n  m.add_functor<impl::BroadcastLikeFunctor>(\"BroadcastLike\");\n  m.add_functor<impl::ConcatFunctor>(\"Concat\");\n  m.add_functor<impl::StackFunctor>(\"Stack\");\n  m.add_functor<impl::StackGradFunctor>(\"StackGrad\");\n  m.add_functor<impl::AtLeast1DFunctor>(\"AtLeast1D\");\n  m.add_functor<impl::AtLeast1DListFunctor>(\"AtLeast1D\");\n  m.add_functor<impl::AtLeast2DFunctor>(\"AtLeast2D\");\n  m.add_functor<impl::AtLeast2DListFunctor>(\"AtLeast2D\");\n  m.add_functor<impl::AtLeast3DFunctor>(\"AtLeast3D\");\n  m.add_functor<impl::AtLeast3DListFunctor>(\"AtLeast3D\");\n  m.add_functor<impl::HStackFunctor>(\"HStack\");\n  m.add_functor<impl::ColumnStackFunctor>(\"ColumnStack\");\n  m.add_functor<impl::VStackFunctor>(\"VStack\");\n  m.add_functor<impl::RowStackFunctor>(\"RowStack\");\n  m.add_functor<impl::DStackFunctor>(\"DStack\");\n  m.add_functor<impl::ExpandFunctor>(\"Expand\");\n  m.add_functor<impl::ExpandDimsFunctor>(\"ExpandDims\");\n  m.add_functor<impl::ExpandDimsFunctor>(\"Unsqueeze\");\n  m.add_functor<impl::UnsqueezeMultipleFunctor>(\"UnsqueezeMultiple\");\n  m.add_functor<impl::InplaceUnsqueezeFunctor>(\"InplaceUnsqueeze\");\n  m.add_functor<impl::SqueezeFunctor>(\"Squeeze\");\n  m.add_functor<impl::InplaceSqueezeFunctor>(\"InplaceSqueeze\");\n  m.add_functor<impl::RollFunctor>(\"Roll\");\n  
m.add_functor<impl::GatherFunctor>(\"Gather\");\n  m.add_functor<impl::DimGatherFunctor>(\"DimGather\");\n  m.add_functor<impl::ArgSortFunctor>(\"ArgSort\");\n  m.add_functor<impl::SearchSortedFunctor>(\"SearchSorted\");\n  m.add_functor<impl::SearchSortedScalarFunctor>(\"SearchSortedScalar\");\n  m.add_functor<impl::GatherNdFunctor>(\"GatherNd\");\n  m.add_functor<impl::ScatterNdFunctor>(\"ScatterNd\");\n  m.add_functor<impl::TensorScatterNdUpdateFunctor>(\"TensorScatterNdUpdate\");\n  m.add_functor<impl::ScatterNdLikeFunctor>(\"ScatterNdLike\");\n  m.add_functor<impl::ReshapeFunctor>(\"Reshape\");\n  m.add_functor<impl::ViewFunctor>(\"View\");\n  m.add_functor<impl::ToContiguousFunctor>(\"ToContiguous\");\n  m.add_functor<impl::InplaceToContiguousFunctor>(\"InplaceToContiguous\");\n  m.add_functor<impl::NarrowFunctor>(\"Narrow\");\n  m.add_functor<impl::NarrowGradFunctor>(\"NarrowGrad\");\n  m.add_functor<impl::SliceUpdateFunctor>(\"SliceUpdate\");\n  m.add_functor<impl::SliceFunctor>(\"Slice\");\n  m.add_functor<impl::SliceGradFunctor>(\"SliceGrad\");\n  m.add_functor<impl::SliceView1dContiguousFunctor>(\"SliceView1dContiguous\");\n  m.add_functor<impl::CopyFunctor, impl::CopyToDeviceFunctor>(\"Copy\");\n  m.add_functor<impl::FlipFunctor>(\"Flip\");\n  m.add_functor<impl::UnfoldTensorFunctor>(\"UnfoldTensor\");\n  m.add_functor<impl::UnfoldTensorGradFunctor>(\"UnfoldTensorGrad\");\n  m.add_functor<impl::UpsampleGradFunctor>(\"UpsampleGrad\");\n  m.add_functor<impl::UpsampleNearest2DFunctor>(\"UpsampleNearest2D\");\n  m.add_functor<impl::UpsampleNearest2DGradFunctor>(\"UpsampleNearest2DGrad\");\n  m.add_functor<impl::UpsampleBilinear2DFunctor>(\"UpsampleBilinear2D\");\n  m.add_functor<impl::UpsampleBilinear2DGradFunctor>(\"UpsampleBilinear2DGrad\");\n  m.add_functor<impl::UpsampleLinear1DFunctor>(\"UpsampleLinear1D\");\n  m.add_functor<impl::UpsampleLinear1DGradFunctor>(\"UpsampleLinear1DGrad\");\n  m.add_functor<impl::UpsampleNearest1DFunctor>(\"UpsampleNearest1D\");\n  m.add_functor<impl::UpsampleNearest1DGradFunctor>(\"UpsampleNearest1DGrad\");\n  m.add_functor<impl::UpsampleBicubic2DFunctor>(\"UpsampleBicubic2D\");\n  m.add_functor<impl::UpsampleBicubic2DGradFunctor>(\"UpsampleBicubic2DGrad\");\n  m.add_functor<impl::UpsampleNearest3DFunctor>(\"UpsampleNearest3D\");\n  m.add_functor<impl::UpsampleNearest3DGradFunctor>(\"UpsampleNearest3DGrad\");\n  m.add_functor<impl::UpsampleTrilinear3DFunctor>(\"UpsampleTrilinear3D\");\n  m.add_functor<impl::UpsampleTrilinear3DGradFunctor>(\"UpsampleTrilinear3DGrad\");\n  m.add_functor<impl::UnsortedSegmentSumLikeFunctor>(\"UnsortedSegmentSumLike\");\n  m.add_functor<impl::UnsortedSegmentSumFunctor>(\"UnsortedSegmentSum\");\n  m.add_functor<impl::TrilFunctor>(\"Tril\");\n  m.add_functor<impl::InplaceTrilFunctor>(\"InplaceTril\");\n  m.add_functor<impl::TriuFunctor>(\"Triu\");\n  m.add_functor<impl::InplaceTriuFunctor>(\"InplaceTriu\");\n  m.add_functor<impl::DiagFunctor>(\"Diag\");\n  m.add_functor<impl::DiagGradFunctor>(\"DiagGrad\");\n  m.add_functor<impl::DiagonalFunctor>(\"Diagonal\");\n  m.add_functor<impl::DiagonalGradFunctor>(\"DiagonalGrad\");\n  m.add_functor<impl::TensorGetItemFunctor>(\"TensorGetItem\");\n  m.add_functor<impl::DimScatterFunctorImpl<impl::DimScatterType::kUpdate>>(\"DimScatterUpdate\");\n  m.add_functor<impl::DimScatterFunctorImpl<impl::DimScatterType::kAdd>>(\"DimScatterAdd\");\n  m.add_functor<impl::DimScatterFunctorImpl<impl::DimScatterType::kMultiply>>(\"DimScatterMul\");\n  
m.add_functor<impl::DimScatterFunctor>(\"DimScatter\");\n  m.add_functor<impl::DimScatterScalarFunctorImpl<impl::DimScatterType::kUpdate>>(\n      \"DimScatterUpdateScalar\");\n  m.add_functor<impl::DimScatterScalarFunctorImpl<impl::DimScatterType::kAdd>>(\n      \"DimScatterAddScalar\");\n  m.add_functor<impl::DimScatterScalarFunctorImpl<impl::DimScatterType::kMultiply>>(\n      \"DimScatterMulScalar\");\n  m.add_functor<impl::DimScatterScalarFunctor>(\"DimScatterScalar\");\n  m.add_functor<impl::DimScatterAddLikeFunctor>(\"DimScatterAddLike\");\n\n  m.add_functor<impl::TensorSetItemFunctor>(\"TensorSetItem\");\n  m.add_functor<impl::CastLikeFunctor>(\"CastLike\");\n  m.add_functor<impl::ElementwiseMinimumGradFunctor>(\"ElementwiseMinGrad\");\n  m.add_functor<impl::ElementwiseMaximumGradFunctor>(\"ElementwiseMaxGrad\");\n  m.add_functor<impl::BroadcastPowXGradFunctor>(\"BroadcastPowXGrad\");\n  m.add_functor<impl::BroadcastPowYGradFunctor>(\"BroadcastPowYGrad\");\n  m.add_functor<impl::DivGradFunctor>(\"DivGrad\");\n  m.add_functor<impl::IdentityFunctor>(\"Identity\");\n  m.add_functor<impl::AmpWhiteIdentityFunctor>(\"AmpWhiteIdentity\");\n  m.add_functor<impl::AmpBlackIdentityFunctor>(\"AmpBlackIdentity\");\n  m.add_functor<impl::ReduceSumLikeFunctor>(\"ReduceSumLike\");\n  m.add_functor<impl::BroadcastReduceSumLikeFunctor>(\"BroadcastReduceSumLike\");\n  m.add_functor<impl::SplitFunctor>(\"Split\");\n  m.add_functor<impl::UnbindFunctor>(\"Unbind\");\n  m.add_functor<impl::ChunkFunctor>(\"Chunk\");\n  m.add_functor<impl::SplitLikeFunctor>(\"SplitLike\");\n  m.add_functor<impl::SplitWithSizeFunctor>(\"SplitWithSize\");\n  m.add_functor<impl::BatchGatherFunctor>(\"BatchGather\");\n  m.add_functor<impl::UnsortedBatchSegmentSumFunctor>(\"UnsortedBatchSegmentSum\");\n  m.add_functor<impl::MaskedFillFunctor<false>>(\"MaskedFill\");\n  m.add_functor<impl::MaskedFillFunctor<true>>(\"MaskedFillInplace\");\n  m.add_functor<impl::MeshgridFunctor>(\"Meshgrid\");\n  m.add_functor<impl::IndexSelectFunctor>(\"IndexSelect\");\n  m.add_functor<impl::ToFunctor, impl::To2Functor, impl::To3Functor, impl::To4Functor,\n                impl::ToDeviceFunctor, impl::ToMemoryFormatFunctor>(\"To\");\n  m.add_functor<impl::TopKFunctor>(\"TopK\");\n  m.add_functor<impl::InTopKFunctor>(\"InTopK\");\n  m.add_functor<impl::TensorToTensorBufferFunctor>(\"TensorToTensorBuffer\");\n  m.add_functor<impl::TensorBufferToTensorFunctor>(\"TensorBufferToTensor\");\n  m.add_functor<impl::GenTensorBufferFunctor>(\"GenTensorBuffer\");\n  m.add_functor<impl::RepeatFunctor>(\"Repeat\");\n  m.add_functor<impl::RepeatInterLeaveIndexFunctor>(\"RepeatInterLeaveIndex\");\n  m.add_functor<impl::RepeatInterLeaveIntFunctor>(\"RepeatInterLeaveInt\");\n  m.add_functor<impl::RepeatInterLeaveTensorFunctor>(\"RepeatInterLeaveTensor\");\n  m.add_functor<impl::TileFunctor>(\"Tile\");\n  m.add_functor<impl::TransposeAllDimPropertyFunctor>(\"TransposeAllDimProperty\");\n  m.add_functor<impl::TransposeAllDimFunctionFunctor>(\"TransposeAllDimFunction\");\n  m.add_functor<impl::ReshapeLikeFunctor>(\"ReshapeLike\");\n  m.add_functor<impl::PinMemoryFunctor>(\"PinMemory\");\n  m.add_functor<impl::BroadcastShapesFunctor>(\"BroadcastShapes\");\n  m.add_functor<impl::BroadcastTensorsFunctor>(\"BroadcastTensors\");\n  m.add_functor<impl::ExpandFunctor>(\"BroadcastTo\");  // BroadcastTo is an alias of Expand\n  m.add_functor<impl::BinCountFunctor>(\"BinCount\");\n  m.add_functor<impl::IndexAddFunctor>(\"IndexAdd\");\n  
m.add_functor<impl::IndexAddInplaceFunctor>(\"IndexAddInplace\");\n  m.add_functor<impl::UniqueFunctor>(\"Unique\");\n  m.add_functor<impl::UniqueWithCountsFunctor>(\"UniqueWithCounts\");\n  m.add_functor<impl::BaddBmmFunctor>(\"BaddBmm\");\n  m.add_functor<impl::SortFunctor>(\"Sort\");\n  m.add_functor<impl::CloneFunctor>(\"Clone\");\n  m.add_functor<impl::FusedCodegeexQkvReshapeFunctor>(\"FusedCodegeexQkvReshape\");\n};\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/binary_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/functional/impl/binary_functor.h\"\n\n#include \"oneflow/core/common/error.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_util.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nnamespace {\n\nbool IsCPUScalarTensor(const std::shared_ptr<Tensor>& tensor) {\n  return tensor->shape()->NumAxes() == 0\n         && TensorDeviceToString(tensor).find(\"cpu\") != std::string::npos;\n}\n\n}  // namespace\n\nstd::string TensorDeviceToString(const std::shared_ptr<Tensor>& tensor) {\n  if (tensor->is_global()) { return CHECK_JUST(tensor->parallel_desc())->device_tag(); }\n  return CHECK_JUST(tensor->device())->ToString();\n}\n\nMaybe<void> CastDeviceForCPUScalarTensor(std::shared_ptr<Tensor>& tensor,\n                                         std::shared_ptr<Tensor>& other, bool inplace) {\n  if (TensorDeviceToString(tensor) != TensorDeviceToString(other)) {\n    if (IsCPUScalarTensor(other)) {\n      other = JUST(functional::To(other, TensorDeviceToString(tensor)));\n    } else if (!inplace && IsCPUScalarTensor(tensor)) {\n      tensor = JUST(functional::To(tensor, TensorDeviceToString(other)));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nclass AddFunctor {\n public:\n  AddFunctor() {\n    add_op_ = CHECK_JUST(one::OpBuilder(\"add_n\").Input(\"in\", 2).Output(\"out\").Build());\n    broadcast_add_op_ =\n        CHECK_JUST(one::OpBuilder(\"broadcast_add\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& other, const Scalar& alpha,\n                           bool inplace) const {\n    auto input_tensor = input;\n    if (IsIntegralDataType(input_tensor->dtype()->data_type())\n        && IsIntegralDataType(other->dtype()->data_type()) && alpha.IsFloatingPoint()) {\n      return Error::RuntimeError()\n             << \"For integral input tensors, argument alpha must not be a floating point number.\";\n    }\n\n    bool input_static_zeros = IsStaticZerosTensor(input_tensor);\n    if (input_static_zeros || IsStaticZerosTensor(other)) {\n      CHECK_OR_RETURN(JUST(input_tensor->device()) == JUST(other->device()))\n     
     << Error::RuntimeError()\n          << \"Expected all tensors to be on the same device, but found at least two devices, \"\n          << JUST(input_tensor->device())->ToString() << \" and \"\n          << JUST(other->device())->ToString() << \"!\";\n      CHECK_OR_RETURN(*input_tensor->shape() == *other->shape())\n          << Error::RuntimeError() << \"The size of tensor a \" << input_tensor->shape()->ToString()\n          << \" must match the size of tensor b \" << other->shape()->ToString();\n      if (input_static_zeros) {\n        if ((alpha.IsIntegral() && alpha.Value<int64_t>() == 1)\n            || (alpha.IsFloatingPoint()\n                && std::fabs(alpha.Value<double>() - 1.0)\n                       < std::numeric_limits<double>::epsilon())) {\n          return other;\n        } else {\n          return JUST(functional::ScalarMul(alpha, other));\n        }\n      }\n      return input_tensor;\n    }\n\n    const OpExpr* op = nullptr;\n    Optional<Symbol<DType>> promote_dtype;\n    if (inplace) { promote_dtype = input_tensor->dtype(); }\n\n    TensorProcessor tensor_processor;\n    if ((alpha.IsIntegral() && alpha.Value<int64_t>() == 1)\n        || (alpha.IsFloatingPoint()\n            && std::fabs(alpha.Value<double>() - 1.0) < std::numeric_limits<double>::epsilon())) {\n      JUST(tensor_processor.PromoteInputsToCommonDtype(true, promote_dtype)\n               .AddInputs({input_tensor, other})\n               .Apply());\n    } else {\n      JUST(tensor_processor.PromoteInputsToCommonDtype(true, promote_dtype)\n               .AddInputs({input_tensor, JUST(functional::ScalarMul(alpha, other))})\n               .Apply());\n    }\n    TensorTuple input_vec = JUST(tensor_processor.GetInputs());\n    const std::shared_ptr<one::Tensor>& input_cast = input_vec[0];\n    const std::shared_ptr<one::Tensor>& other_cast = input_vec[1];\n    JUST(CastDeviceForCPUScalarTensor(input_vec[0], input_vec[1], inplace));\n\n    if (*input_cast->shape() == *other_cast->shape()) {\n      op = add_op_.get();\n    } else {\n      op = broadcast_add_op_.get();\n    }\n    if (inplace) {\n      JUST(CheckInplaceCastValid(input_tensor, input_cast));\n      JUST(CheckInplaceValid(input_tensor));\n      JUST(CheckInplaceShapeCanExpandTo(*other_cast->shape(), *input_cast->shape()));\n      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n      outputs->at(0) = input_cast;\n      JUST(OpInterpUtil::Dispatch(*op, input_vec, outputs.get()));\n      return outputs->at(0);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op, input_vec);\n  }\n\n private:\n  std::shared_ptr<OpExpr> add_op_;\n  std::shared_ptr<OpExpr> broadcast_add_op_;\n};\n\nclass BroadcastPowFunctor : public BinaryFloatFunctor {\n public:\n  BroadcastPowFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"broadcast_pow\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n};\n\nclass SubFunctor : public InplaceableBinaryFunctor {\n public:\n  SubFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"broadcast_sub\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& other, const Scalar& alpha,\n                           bool inplace) const {\n    if (IsIntegralDataType(input->dtype()->data_type())\n        && IsIntegralDataType(other->dtype()->data_type()) && alpha.IsFloatingPoint()) {\n      return Error::RuntimeError()\n             << \"For integral input tensors, argument alpha must 
not be a floating point number.\";\n    }\n    if ((alpha.IsIntegral() && alpha.Value<int64_t>() == 1)\n        || (alpha.IsFloatingPoint()\n            && std::fabs(alpha.Value<double>() - 1.0) < std::numeric_limits<double>::epsilon())) {\n      return InplaceableBinaryFunctor::operator()(input, other, inplace);\n    } else {\n      return InplaceableBinaryFunctor::operator()(input, JUST(functional::ScalarMul(alpha, other)),\n                                                  inplace);\n    }\n  }\n};\n\nclass MulFunctor {\n public:\n  MulFunctor() {\n    broadcast_mul_op_ =\n        CHECK_JUST(one::OpBuilder(\"broadcast_mul\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    auto tensor_x = x;\n    auto tensor_y = y;\n    JUST(CastDeviceForCPUScalarTensor(tensor_x, tensor_y, /*inplace=*/false));\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({tensor_x, tensor_y}).Apply());\n    TensorTuple input_vec = JUST(tensor_processor.GetInputs());\n\n    return OpInterpUtil::Dispatch<Tensor>(*broadcast_mul_op_, input_vec);\n  }\n\n private:\n  std::shared_ptr<OpExpr> broadcast_mul_op_;\n};\n\nclass InplaceMulFunctor {\n public:\n  InplaceMulFunctor() {\n    broadcast_mul_op_ =\n        CHECK_JUST(one::OpBuilder(\"broadcast_mul\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    TensorProcessor tensor_processor;\n    if (y->requires_grad()) {\n      JUST(tensor_processor.PromoteInputsToCommonDtype(true)\n               .AddInputs({JUST(Identity(x)), y})\n               .Apply());\n    } else {\n      JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x, y}).Apply());\n    }\n    const TensorTuple& input_vec = JUST(tensor_processor.GetInputs());\n    const std::shared_ptr<one::Tensor>& x_cast = input_vec.at(0);\n    const std::shared_ptr<one::Tensor>& y_cast = input_vec.at(1);\n    JUST(CheckInplaceValid(x));\n    JUST(CheckInplaceCastValid(x, x_cast));\n    JUST(CheckInplaceShapeCanExpandTo(*y_cast->shape(), *x_cast->shape()));\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    outputs->at(0) = x;\n    JUST(OpInterpUtil::Dispatch(*broadcast_mul_op_, input_vec, outputs.get()));\n    return outputs->at(0);\n  }\n\n private:\n  std::shared_ptr<OpExpr> broadcast_mul_op_;\n};\n\nclass AddcmulBaseFunctor {\n public:\n  AddcmulBaseFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& tensor1,\n                           const std::shared_ptr<one::Tensor>& tensor2, const Scalar& value,\n                           bool inplace) const {\n    return SequenceFunction<Maybe<Tensor>()>([&]() { return functional::Mul(tensor1, tensor2); })\n        .then([&](const auto& x) { return functional::ScalarMul(value, x); })\n        .then([&](const auto& x) { return functional::Add(input, x, /*alpha=*/1, inplace); })\n        .call();\n  }\n};\n\nclass AddcmulFunctor : public AddcmulBaseFunctor {\n public:\n  AddcmulFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& tensor1,\n                           const 
std::shared_ptr<one::Tensor>& tensor2, const Scalar& value) const {\n    return AddcmulBaseFunctor::operator()(input, tensor1, tensor2, value, /*inplace=*/false);\n  }\n};\n\nclass InplaceAddcmulFunctor : public AddcmulBaseFunctor {\n public:\n  InplaceAddcmulFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& tensor1,\n                           const std::shared_ptr<one::Tensor>& tensor2, const Scalar& value) const {\n    return AddcmulBaseFunctor::operator()(input, tensor1, tensor2, value, /*inplace=*/true);\n  }\n};\n\nclass DivFunctor : public BinaryFloatFunctor {\n public:\n  DivFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"broadcast_div\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n};\n\nclass DivFunctorMode {\n public:\n  DivFunctorMode() {}\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y,\n                           const Optional<std::string>& rounding_mode) const {\n    std::string rmode = rounding_mode.value_or(\"\");\n    if (rmode == \"floor\") {\n      return JUST(functional::FloorDiv(x, y));\n\n    } else if (rmode == \"trunc\") {\n      return JUST(functional::TruncDiv(x, y));\n    }\n    CHECK_OR_RETURN(rmode == \"\") << \"div expected rounding_mode to be one of None,\"\n                                    \" 'trunc', or 'floor' but found \"\n                                 << rmode;\n    return JUST(functional::Div(x, y));\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass InplaceDivFunctor {\n public:\n  InplaceDivFunctor() {\n    broadcast_div_op_ =\n        CHECK_JUST(one::OpBuilder(\"broadcast_div\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    auto tensor_x = x;\n    auto tensor_y = y;\n    JUST(CastDeviceForCPUScalarTensor(tensor_x, tensor_y, /*inplace=*/true));\n\n    // NOTE: div operator will cast inputs to float when dtype is integral\n    TensorProcessor tensor_processor;\n    TensorTuple tensor_processor_inputs;\n    {\n      if (tensor_y->requires_grad()) {\n        tensor_processor_inputs.assign({JUST(Identity(tensor_x)), tensor_y});\n      } else {\n        tensor_processor_inputs.assign({tensor_x, tensor_y});\n      }\n    }\n    if (promoteTypes(tensor_x->dtype(), tensor_y->dtype())->is_integer()) {\n      tensor_processor.AddInputs(tensor_processor_inputs, DType::Float());\n    } else {\n      tensor_processor.AddInputs(tensor_processor_inputs)\n          .PromoteInputsToCommonDtype(true)\n          .PromoteIntegerInputsToFloatDtype(true);\n    }\n    JUST(tensor_processor.Apply());\n\n    const TensorTuple& input_vec = JUST(tensor_processor.GetInputs());\n    const std::shared_ptr<one::Tensor>& x_cast = input_vec.at(0);\n    const std::shared_ptr<one::Tensor>& y_cast = input_vec.at(1);\n    JUST(CheckInplaceValid(x));\n    JUST(CheckInplaceCastValid(x, x_cast));\n    JUST(CheckInplaceShapeCanExpandTo(*y_cast->shape(), *x_cast->shape()));\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    outputs->at(0) = x;\n    JUST(OpInterpUtil::Dispatch(*broadcast_div_op_, input_vec, outputs.get()));\n    return outputs->at(0);\n  }\n\n private:\n  std::shared_ptr<OpExpr> broadcast_div_op_;\n};\n\nclass Atan2Functor : public BinaryFloatFunctor {\n public:\n  
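// Expands a size-1 operand to the other operand's shape before dispatch,\n  // e.g. (sketch) atan2(x, y) with y->nelement() == 1 first expands y to x->shape().\n  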
Atan2Functor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"atan2\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    const int64_t x_element = x->nelement();\n    const int64_t y_element = y->nelement();\n    CHECK_GT_OR_RETURN(x_element, 0)\n        << Error::RuntimeError() << \"the size of input should be > 0, but got \" << x_element;\n    CHECK_GT_OR_RETURN(y_element, 0)\n        << Error::RuntimeError() << \"the size of input should be > 0, but got \" << y_element;\n\n    if ((x_element != 1 && y_element != 1) && (x->shape()->NumAxes() == y->shape()->NumAxes())) {\n      return BinaryFloatFunctor::operator()(x, y);\n    }\n\n    auto broad_x_ = x;\n    auto broad_y_ = y;\n    if (x_element == 1) {\n      broad_x_ = JUST(functional::Expand(x, *y->shape()));\n    } else if (y_element == 1) {\n      broad_y_ = JUST(functional::Expand(y, *x->shape()));\n    } else if (x->shape()->NumAxes() != y->shape()->NumAxes()) {\n      return Error::RuntimeError() << \"The number of dimensions of tensor a (\"\n                                   << x->shape()->NumAxes()\n                                   << \") must match the number of dimensions of tensor b (\"\n                                   << y->shape()->NumAxes() << \")\";\n    } else {\n      // Unreachable: equal-rank, non-scalar operands were handled above.\n      return Error::RuntimeError() << \"atan2: unexpected operand shapes\";\n    }\n\n    return BinaryFloatFunctor::operator()(broad_x_, broad_y_);\n  }\n};\n\nclass PowFunctor : public BinaryFloatFunctor {\n public:\n  PowFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"pow\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    if (*x->shape() != *y->shape()) { return BroadcastPow(x, y); }\n    return BinaryFloatFunctor::operator()(x, y);\n  }\n};\n\nclass FloorDivFunctor : public BinaryFunctor {\n public:\n  FloorDivFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"floordiv\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    return BinaryFunctor::operator()(x, y);\n  }\n};\n\nclass TruncDivFunctor : public BinaryFunctor {\n public:\n  TruncDivFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"truncdiv\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    return BinaryFunctor::operator()(x, y);\n  }\n};\n\nclass LerpFunctor {\n public:\n  LerpFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"lerp\").Input(\"start\").Input(\"end\").Input(\"weight\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& start,\n                           const std::shared_ptr<one::Tensor>& end,\n                           const std::shared_ptr<one::Tensor>& weight) const {\n    const int64_t weight_elem_cnt = weight->nelement();\n    CHECK_EQ_OR_RETURN(start->shape()->NumAxes(), end->shape()->NumAxes())\n        << Error::RuntimeError() << \"expected dim \" << start->shape()->NumAxes()\n        << \" for `end` but got dim \" << end->shape()->NumAxes();\n    CHECK_EQ_OR_RETURN(start->dtype()->data_type(), weight->dtype()->data_type())\n        << 
Error::RuntimeError() << \"expected dtype \" << start->dtype()->name()\n        << \" for `weight` but got dtype \" << weight->dtype()->name();\n\n    auto broadcast_shape = *start->shape();\n    if (*start->shape() != *end->shape() || *start->shape() != *weight->shape()) {\n      broadcast_shape = *JUST(\n          InferUnifiedShapeForBroadcasting({*start->shape(), *end->shape(), *weight->shape()}));\n    }\n\n    if (weight_elem_cnt == 1 && weight->is_eager() && !weight->requires_grad()) {\n      std::shared_ptr<Tensor> cast_double_weight =\n          JUST(functional::Cast(weight, DType::Double(), /*pin_memory=*/false));\n      double weight_scalar = JUST(GetItemInScalarTensor<double>(cast_double_weight));\n      return functional::ScalarLerp(start, end, weight_scalar);\n    }\n\n    std::shared_ptr<Tensor> broadcast_start = start;\n    std::shared_ptr<Tensor> broadcast_end = end;\n    std::shared_ptr<Tensor> broadcast_weight = weight;\n    if (*start->shape() != broadcast_shape) {\n      broadcast_start = JUST(functional::Expand(start, broadcast_shape));\n    }\n    if (*end->shape() != broadcast_shape) {\n      broadcast_end = JUST(functional::Expand(end, broadcast_shape));\n    }\n    if (*weight->shape() != broadcast_shape) {\n      broadcast_weight = JUST(functional::Expand(weight, broadcast_shape));\n    }\n\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {broadcast_start, broadcast_end, broadcast_weight});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass InplaceLerpFunctor {\n public:\n  InplaceLerpFunctor() {\n    lerp_op_ = CHECK_JUST(\n        one::OpBuilder(\"lerp\").Input(\"start\").Input(\"end\").Input(\"weight\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& start,\n                           const std::shared_ptr<one::Tensor>& end,\n                           const std::shared_ptr<one::Tensor>& weight) const {\n    const int64_t weight_elem_cnt = weight->nelement();\n    CHECK_EQ_OR_RETURN(start->shape()->NumAxes(), end->shape()->NumAxes())\n        << Error::RuntimeError() << \"expected dim \" << start->shape()->NumAxes()\n        << \" for `end` but got dim \" << end->shape()->NumAxes();\n    CHECK_EQ_OR_RETURN(start->dtype()->data_type(), weight->dtype()->data_type())\n        << Error::RuntimeError() << \"expected dtype \" << start->dtype()->name()\n        << \" for `weight` but got dtype \" << weight->dtype()->name();\n\n    if (weight_elem_cnt == 1 && weight->is_eager() && !weight->requires_grad()) {\n      std::shared_ptr<Tensor> cast_double_weight =\n          JUST(functional::Cast(weight, DType::Double(), /*pin_memory=*/false));\n      double weight_scalar = JUST(GetItemInScalarTensor<double>(cast_double_weight));\n      JUST(functional::ScalarInplaceLerp(start, end, weight_scalar));\n      return start;\n    }\n\n    auto broadcast_shape = *start->shape();\n    if (*start->shape() != *end->shape() || *start->shape() != *weight->shape()) {\n      broadcast_shape = *JUST(\n          InferUnifiedShapeForBroadcasting({*start->shape(), *end->shape(), *weight->shape()}));\n    }\n\n    std::shared_ptr<one::Tensor> broadcast_start = JUST(Identity(start));\n    std::shared_ptr<one::Tensor> broadcast_end = JUST(Identity(end));\n    std::shared_ptr<one::Tensor> broadcast_weight = JUST(Identity(weight));\n    if (*start->shape() != broadcast_shape) {\n      broadcast_start = JUST(view::Expand(start, broadcast_shape));\n    }\n    if (*end->shape() != broadcast_shape) {\n      broadcast_end = 
JUST(view::Expand(end, broadcast_shape));\n    }\n    if (*weight->shape() != broadcast_shape) {\n      broadcast_weight = JUST(view::Expand(weight, broadcast_shape));\n    }\n\n    TensorProcessor tensor_processor;\n    if (broadcast_end->requires_grad() || broadcast_weight->requires_grad()) {\n      JUST(tensor_processor.PromoteInputsToCommonDtype(true)\n               .AddInputs({JUST(Identity(broadcast_start)), broadcast_end, broadcast_weight})\n               .Apply());\n    } else {\n      JUST(tensor_processor.PromoteInputsToCommonDtype(true)\n               .AddInputs({broadcast_start, broadcast_end, broadcast_weight})\n               .Apply());\n    }\n\n    const TensorTuple& input_vec = JUST(tensor_processor.GetInputs());\n    const std::shared_ptr<one::Tensor>& start_cast = input_vec.at(0);\n    const std::shared_ptr<one::Tensor>& end_cast = input_vec.at(1);\n    JUST(CheckInplaceValid(broadcast_start));\n    JUST(CheckInplaceCastValid(broadcast_start, start_cast));\n    JUST(CheckInplaceShapeCanExpandTo(*start_cast->shape(), *end_cast->shape()));\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    outputs->at(0) = start;\n    JUST(OpInterpUtil::Dispatch(*lerp_op_, input_vec, outputs.get()));\n    return outputs->at(0);\n  }\n\n private:\n  std::shared_ptr<OpExpr> lerp_op_;\n};\n\nclass LerpGradFunctor {\n public:\n  LerpGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"lerp_grad\")\n                         .Input(\"start\")\n                         .Input(\"end\")\n                         .Input(\"weight\")\n                         .Input(\"out_diff\")\n                         .Output(\"start_diff\")\n                         .Output(\"end_diff\")\n                         .Output(\"weight_diff\")\n                         .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& start,\n                                const std::shared_ptr<one::Tensor>& end,\n                                const std::shared_ptr<one::Tensor>& weight,\n                                const std::shared_ptr<one::Tensor>& out_diff) const {\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {start, end, weight, out_diff}, {});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass BroadcastFModFunctor : public BinaryFunctor {\n public:\n  BroadcastFModFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"broadcast_fmod\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n};\n\nclass BroadcastEqualFunctor : public BinaryFunctor {\n public:\n  BroadcastEqualFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"broadcast_equal\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n};\n\nclass EqualFunctor {\n public:\n  EqualFunctor() {\n    broadcast_equal_op_ =\n        CHECK_JUST(one::OpBuilder(\"broadcast_equal\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n  Maybe<bool> operator()(const std::shared_ptr<one::Tensor>& x,\n                         const std::shared_ptr<one::Tensor>& y) const {\n    if (*x->shape() != *y->shape()) { return false; }\n    if (x->nelement() == 0) { return true; }\n\n    std::shared_ptr<Tensor> output = JUST(\n        ReduceAllWhole(JUST(OpInterpUtil::Dispatch<Tensor>(*broadcast_equal_op_, {x, y}, {}))));\n    bool status = JUST(GetItemInScalarTensor<bool>(output));\n    return status;\n  }\n\n private:\n  std::shared_ptr<OpExpr> broadcast_equal_op_;\n};\n\nclass BroadcastNotEqualFunctor : public BinaryFunctor {\n public:\n  BroadcastNotEqualFunctor() {\n    op_ =\n        
CHECK_JUST(one::OpBuilder(\"broadcast_not_equal\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n};\n\nclass BroadcastGreaterFunctor : public BinaryFunctor {\n public:\n  BroadcastGreaterFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"broadcast_greater\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n};\n\nclass InplaceBroadcastGreaterFunctor {\n public:\n  InplaceBroadcastGreaterFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"broadcast_inplace_greater\").Input(\"x\").Input(\"y\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x, y}).Apply());\n    const TensorTuple& input_vec = JUST(tensor_processor.GetInputs());\n    const std::shared_ptr<one::Tensor>& x_cast = input_vec.at(0);\n    const std::shared_ptr<one::Tensor>& y_cast = input_vec.at(1);\n    JUST(CheckInplaceValid(x));\n    JUST(CheckInplaceCastValid(x, x_cast));\n    JUST(CheckInplaceShapeCanExpandTo(*y_cast->shape(), *x_cast->shape()));\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    outputs->at(0) = x;\n    JUST(OpInterpUtil::Dispatch(*op_, input_vec, outputs.get()));\n    return outputs->at(0);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass BroadcastGreaterEqualFunctor : public BinaryFunctor {\n public:\n  BroadcastGreaterEqualFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"broadcast_greater_equal\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n};\n\nclass BroadcastLogicalAndFunctor : public BinaryFunctor {\n public:\n  BroadcastLogicalAndFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"broadcast_logical_and\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n};\n\nclass BroadcastLogicalOrFunctor : public BinaryFunctor {\n public:\n  BroadcastLogicalOrFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"broadcast_logical_or\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n};\n\nclass BroadcastLogicalXorFunctor : public BinaryFunctor {\n public:\n  BroadcastLogicalXorFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"broadcast_logical_xor\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n};\n\nclass BroadcastBitwiseAndFunctor : public BinaryFunctor {\n public:\n  BroadcastBitwiseAndFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"broadcast_bitwise_and\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n};\n\nclass BroadcastBitwiseOrFunctor : public BinaryFunctor {\n public:\n  BroadcastBitwiseOrFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"broadcast_bitwise_or\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n};\n\nclass BroadcastBitwiseXorFunctor : public BinaryFunctor {\n public:\n  BroadcastBitwiseXorFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"broadcast_bitwise_xor\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n};\n\nclass BroadcastLessFunctor : public BinaryFunctor {\n public:\n  BroadcastLessFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"broadcast_less\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n};\n\nclass BroadcastLessEqualFunctor : public BinaryFunctor {\n public:\n  BroadcastLessEqualFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"broadcast_less_equal\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n};\n\nclass 
BroadcastIsCloseFunctor {\n public:\n  BroadcastIsCloseFunctor() {\n    eq_nan_op_ = CHECK_JUST(\n        one::OpBuilder(\"broadcast_isclose_eq_nan\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n    neq_nan_op_ = CHECK_JUST(\n        one::OpBuilder(\"broadcast_isclose_neq_nan\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y, const float atol,\n                           const float rtol, const bool equal_nan) const {\n    auto& attr = THREAD_CACHED_MUTABLE_ATTR_MAP(\"atol\", \"rtol\", \"equal_nan\");\n    attr.SetAllAttrs(atol, rtol, equal_nan);\n    if (equal_nan) {\n      return OpInterpUtil::Dispatch<Tensor>(*eq_nan_op_, {x, y}, attr);\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*neq_nan_op_, {x, y}, attr);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> eq_nan_op_;\n  std::shared_ptr<OpExpr> neq_nan_op_;\n};\n\nclass ScalarAddByTensorFunctor : public InplaceableBinaryFunctor {\n public:\n  ScalarAddByTensorFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"scalar_add_by_tensor\").Input(\"x\").Input(\"scalar\").Output(\"y\").Build());\n  }\n};\n\n// this functor is only for testing host memory input\nclass HostScalarAddByTensorFunctor {\n public:\n  HostScalarAddByTensorFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"host_scalar_add_by_tensor\").Input(\"x\").Input(\"scalar\").Output(\"y\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& scalar) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, scalar});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ScalarSubByTensorFunctor : public BinaryFunctor {\n public:\n  ScalarSubByTensorFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"scalar_sub_by_tensor\").Input(\"x\").Input(\"scalar\").Output(\"y\").Build());\n  }\n};\n\nclass ScalarMulByTensorFunctor : public BinaryFunctor {\n public:\n  ScalarMulByTensorFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"scalar_mul_by_tensor\").Input(\"x\").Input(\"scalar\").Output(\"y\").Build());\n  }\n};\n\nclass ScalarDivByTensorFunctor : public BinaryFunctor {\n public:\n  ScalarDivByTensorFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"scalar_div_by_tensor\").Input(\"x\").Input(\"scalar\").Output(\"y\").Build());\n  }\n};\n\nclass BroadcastZetaFunctor : public BinaryFloatFunctor {\n public:\n  BroadcastZetaFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"broadcast_zeta\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n};\n\nclass ZetaScalarTensorFunctor {\n public:\n  Maybe<Tensor> operator()(const Scalar x, const std::shared_ptr<one::Tensor>& y) const {\n    auto scalar_tensor = JUST(functional::FullLike(y, x));  // wrap scalar to tensor\n    return functional::BroadcastZeta(scalar_tensor, y);\n  }\n};\n\nclass ZetaTensorScalarFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar y) const {\n    auto scalar_tensor = JUST(functional::FullLike(x, y));  // wrap scalar to tensor\n    return functional::BroadcastZeta(x, scalar_tensor);\n  }\n};\n\n}  // namespace impl\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<impl::AddFunctor>(\"Add\");\n  m.add_functor<impl::AddcmulFunctor>(\"Addcmul\");\n  m.add_functor<impl::InplaceAddcmulFunctor>(\"InplaceAddcmul\");\n  
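// Each name registered here becomes a functional:: entry point; e.g. \"Div\" below backs the\n  // functional::Div call used by DivFunctorMode above.\n  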
m.add_functor<impl::Atan2Functor>(\"Atan2\");\n  m.add_functor<impl::SubFunctor>(\"Sub\");\n  m.add_functor<impl::MulFunctor>(\"Mul\");\n  m.add_functor<impl::InplaceMulFunctor>(\"InplaceMul\");\n  m.add_functor<impl::InplaceDivFunctor>(\"InplaceDiv\");\n  m.add_functor<impl::DivFunctor>(\"Div\");\n  m.add_functor<impl::DivFunctorMode>(\"DivMode\");\n  m.add_functor<impl::PowFunctor>(\"Pow\");\n  m.add_functor<impl::BroadcastPowFunctor>(\"BroadcastPow\");\n  m.add_functor<impl::BroadcastEqualFunctor>(\"BroadcastEqual\");\n  m.add_functor<impl::EqualFunctor>(\"Equal\");\n  m.add_functor<impl::BroadcastNotEqualFunctor>(\"BroadcastNotEqual\");\n  m.add_functor<impl::BroadcastGreaterFunctor>(\"BroadcastGreater\");\n  m.add_functor<impl::InplaceBroadcastGreaterFunctor>(\"InplaceBroadcastGreater\");\n  m.add_functor<impl::BroadcastGreaterEqualFunctor>(\"BroadcastGreaterEqual\");\n  m.add_functor<impl::BroadcastLogicalAndFunctor>(\"BroadcastLogicalAnd\");\n  m.add_functor<impl::BroadcastLogicalOrFunctor>(\"BroadcastLogicalOr\");\n  m.add_functor<impl::BroadcastLogicalXorFunctor>(\"BroadcastLogicalXor\");\n  m.add_functor<impl::BroadcastBitwiseAndFunctor>(\"BroadcastBitwiseAnd\");\n  m.add_functor<impl::BroadcastBitwiseOrFunctor>(\"BroadcastBitwiseOr\");\n  m.add_functor<impl::BroadcastBitwiseXorFunctor>(\"BroadcastBitwiseXor\");\n  m.add_functor<impl::BroadcastLessFunctor>(\"BroadcastLess\");\n  m.add_functor<impl::BroadcastLessEqualFunctor>(\"BroadcastLessEqual\");\n  m.add_functor<impl::ScalarAddByTensorFunctor>(\"ScalarAddByTensor\");\n  m.add_functor<impl::HostScalarAddByTensorFunctor>(\"HostScalarAddByTensor\");\n  m.add_functor<impl::ScalarSubByTensorFunctor>(\"ScalarSubByTensor\");\n  m.add_functor<impl::ScalarMulByTensorFunctor>(\"ScalarMulByTensor\");\n  m.add_functor<impl::ScalarDivByTensorFunctor>(\"ScalarDivByTensor\");\n  m.add_functor<impl::BroadcastFModFunctor>(\"BroadcastFMod\");\n  m.add_functor<impl::FloorDivFunctor>(\"FloorDiv\");\n  m.add_functor<impl::TruncDivFunctor>(\"TruncDiv\");\n  m.add_functor<impl::BroadcastIsCloseFunctor>(\"IsClose\");\n  m.add_functor<impl::LerpFunctor>(\"Lerp\");\n  m.add_functor<impl::InplaceLerpFunctor>(\"InplaceLerp\");\n  m.add_functor<impl::LerpGradFunctor>(\"LerpGrad\");\n  m.add_functor<impl::BroadcastZetaFunctor>(\"BroadcastZeta\");\n  m.add_functor<impl::ZetaScalarTensorFunctor>(\"ZetaScalarTensor\");\n  m.add_functor<impl::ZetaTensorScalarFunctor>(\"ZetaTensorScalar\");\n};\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/binary_functor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_FUNCTIONAL_IMPL_BINARY_FUNCTOR_H_\n#define ONEFLOW_CORE_FUNCTIONAL_IMPL_BINARY_FUNCTOR_H_\n\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n#include \"oneflow/core/functional/tensor_processor.h\"\n#include \"oneflow/core/functional/functional.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nstd::string TensorDeviceToString(const std::shared_ptr<Tensor>& tensor);\n\nMaybe<void> CastDeviceForCPUScalarTensor(std::shared_ptr<Tensor>& tensor,\n                                         std::shared_ptr<Tensor>& other, bool inplace);\n\nclass BinaryFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    auto tensor_x = x;\n    auto tensor_y = y;\n    JUST(CastDeviceForCPUScalarTensor(tensor_x, tensor_y, /*inplace=*/false));\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({tensor_x, tensor_y}).Apply());\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    return OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple);\n  }\n\n protected:\n  BinaryFunctor() = default;\n  virtual ~BinaryFunctor() = default;\n\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass BinaryFloatFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    auto tensor_x = x;\n    auto tensor_y = y;\n    JUST(CastDeviceForCPUScalarTensor(tensor_x, tensor_y, /*inplace=*/false));\n    TensorProcessor tensor_processor;\n    if (promoteTypes(tensor_x->dtype(), tensor_y->dtype())->is_integer()) {\n      tensor_processor.AddInputs({tensor_x, tensor_y}, DType::Float());\n    } else {\n      tensor_processor.AddInputs({tensor_x, tensor_y})\n          .PromoteInputsToCommonDtype(true)\n          .PromoteIntegerInputsToFloatDtype(true);\n    }\n    JUST(tensor_processor.Apply());\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    return OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple);\n  }\n\n protected:\n  BinaryFloatFunctor() = default;\n  virtual ~BinaryFloatFunctor() = default;\n\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass BinaryGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y,\n                           const std::shared_ptr<one::Tensor>& dz) const {\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x, y, dz}).Apply());\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    return 
OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple);\n  }\n\n protected:\n  BinaryGradFunctor() = default;\n  virtual ~BinaryGradFunctor() = default;\n\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass InplaceableBinaryFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y, bool inplace) const {\n    auto tensor_x = x;\n    auto tensor_y = y;\n    JUST(CastDeviceForCPUScalarTensor(tensor_x, tensor_y, inplace));\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({tensor_x, tensor_y}).Apply());\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    if (inplace) {\n      std::shared_ptr<one::Tensor>& x_cast = input_tuple.at(0);\n      std::shared_ptr<one::Tensor>& y_cast = input_tuple.at(1);\n      JUST(CheckInplaceCastValid(x, x_cast));\n      JUST(CheckInplaceShapeCanExpandTo(*y_cast->shape(), *x_cast->shape()));\n      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n      outputs->at(0) = x_cast;\n      JUST(OpInterpUtil::Dispatch(*op_, input_tuple, outputs.get()));\n      return outputs->at(0);\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple);\n    }\n  }\n\n protected:\n  InplaceableBinaryFunctor() = default;\n  virtual ~InplaceableBinaryFunctor() = default;\n\n  std::shared_ptr<OpExpr> op_;\n};\n\n}  // namespace impl\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FUNCTIONAL_IMPL_BINARY_FUNCTOR_H_\n"
  },
  {
    "path": "oneflow/core/functional/impl/binary_grad_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/functional/impl/binary_functor.h\"\n\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/user/ops/math_binary_elementwise_seq.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\n#define BINARY_ELEMENTWISE_GRAD_FUNCTOR(op_type_name, class_name, base)           \\\n  class class_name##XGradFunctor : public base {                                  \\\n   public:                                                                        \\\n    class_name##XGradFunctor() {                                                  \\\n      op_ = CHECK_JUST(one::OpBuilder(std::string(\"\") + op_type_name + \"_x_grad\") \\\n                           .Input(\"x\")                                            \\\n                           .Input(\"y\")                                            \\\n                           .Input(\"dz\")                                           \\\n                           .Output(\"dx\")                                          \\\n                           .Build());                                             \\\n    }                                                                             \\\n  };                                                                              \\\n  class class_name##YGradFunctor : public base {                                  \\\n   public:                                                                        \\\n    class_name##YGradFunctor() {                                                  \\\n      op_ = CHECK_JUST(one::OpBuilder(std::string(\"\") + op_type_name + \"_y_grad\") \\\n                           .Input(\"x\")                                            \\\n                           .Input(\"y\")                                            \\\n                           .Input(\"dz\")                                           \\\n                           .Output(\"dy\")                                          \\\n                           .Build());                                             \\\n    }                                                                             \\\n  };\n\n#define INSTANTIAT_BINARY_ELEMENTWISE_GRAD_FUNCTOR(op_type_name, class_name) \\\n  BINARY_ELEMENTWISE_GRAD_FUNCTOR(op_type_name, class_name, BinaryGradFunctor);\n\nOF_PP_FOR_EACH_TUPLE(INSTANTIAT_BINARY_ELEMENTWISE_GRAD_FUNCTOR, MATH_BINARY_ELEMENTWISE_FUNC_SEQ);\n}  // namespace impl\n\nusing namespace impl;\n\n#define ADD_BINARY_GRAD_FUNCTOR(class_name, functor_name)                            \\\n  m.add_functor<class_name##XGradFunctor>(std::string(\"\") + functor_name + \"XGrad\"); \\\n  m.add_functor<class_name##YGradFunctor>(std::string(\"\") + functor_name + \"YGrad\");\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  ADD_BINARY_GRAD_FUNCTOR(Pow, \"Pow\");\n  
ADD_BINARY_GRAD_FUNCTOR(Atan2, \"Atan2\");\n  ADD_BINARY_GRAD_FUNCTOR(FloorDiv, \"FloorDiv\");\n  ADD_BINARY_GRAD_FUNCTOR(TruncDiv, \"TruncDiv\");\n  ADD_BINARY_GRAD_FUNCTOR(Xdivy, \"Xdivy\");\n  ADD_BINARY_GRAD_FUNCTOR(Xlogy, \"Xlogy\");\n};\n\n#undef ADD_BINARY_GRAD_FUNCTOR\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/comm_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/id_util.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/attr_value.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/op_interpreter/eager_local_op_interpreter.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n#include \"oneflow/core/functional/impl/unary_functor.h\"\n#include \"oneflow/core/ccl/ccl.h\"\n#include \"oneflow/core/job/rank_group_scope.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/common/flat_shape.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nnamespace {\n\n#define OF_KERNEL_NOT_SUPPORT_ERROR(op_type, device_type)                                          \\\n  Error::RuntimeError() << op_type << \" not suport for the device (\"                               \\\n                        << DeviceType_Name(device_type) << \") because eager kernel of \" << op_type \\\n                        << \" is not registered\"\n\nclass EagerCclKernelRegContext final : public user_op::KernelRegContext {\n public:\n  explicit EagerCclKernelRegContext(DeviceType device_type) : device_type_(device_type) {}\n  ~EagerCclKernelRegContext() = default;\n\n  DeviceType device_type() const override { return device_type_; }\n  const ParallelContext& parallel_ctx() const override { PRINT_BUG_PROMPT_AND_ABORT(); }\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n  }\n  const std::vector<std::pair<std::string, int32_t>>& inputs() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n  }\n  const std::vector<std::pair<std::string, int32_t>>& outputs() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n  }\n\n  const user_op::UserOpConfWrapper& user_op_conf() const override { PRINT_BUG_PROMPT_AND_ABORT(); }\n\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n  }\n\n private:\n  DeviceType device_type_;\n};\n\nMaybe<bool> RawCheckCclKernelRegistered(const std::string& op_type_name, DeviceType device_type) {\n  EagerCclKernelRegContext reg_ctx(device_type);\n 
 return user_op::UserOpRegistryMgr::Get().IsOpKernelRegistered(op_type_name, reg_ctx);\n}\n\nstatic constexpr auto* CheckCclKernelRegistered =\n    DECORATE(&RawCheckCclKernelRegistered, ThreadLocalCachedCopiable);\n\nbool IsSplitSbp(Symbol<SbpParallel> sbp_parallel) { return sbp_parallel->has_split_parallel(); }\n\nMaybe<one::UserOpExpr> EagerCclAllReduce(Symbol<ParallelDesc> parallel_desc) {\n  CHECK_OR_RETURN(\n      JUST(CheckCclKernelRegistered(\"eager_ccl_all_reduce\", parallel_desc->device_type())))\n      << OF_KERNEL_NOT_SUPPORT_ERROR(\"AllReduce\", parallel_desc->device_type());\n  return one::OpBuilder(\"eager_ccl_all_reduce\", *JUST(UniqueStr(\"eager_ccl_all_reduce\")))\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<std::string>(\"parallel_conf\", PbMessage2TxtString(parallel_desc->parallel_conf()))\n      .Build();\n}\n\nstatic constexpr auto* CachedEagerCclAllReduceOpExpr = DECORATE(&EagerCclAllReduce, ThreadLocal);\n\nMaybe<one::UserOpExpr> EagerCclReduceScatter(Symbol<ParallelDesc> parallel_desc,\n                                             const std::string& op_type) {\n  CHECK_OR_RETURN(\n      JUST(CheckCclKernelRegistered(\"eager_ccl_reduce_scatter\", parallel_desc->device_type())))\n      << OF_KERNEL_NOT_SUPPORT_ERROR(\"ReduceScatter\", parallel_desc->device_type());\n  return one::OpBuilder(\"eager_ccl_reduce_scatter\", *JUST(UniqueStr(\"eager_ccl_reduce_scatter\")))\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<std::string>(\"parallel_conf\", PbMessage2TxtString(parallel_desc->parallel_conf()))\n      .Attr<std::string>(\"op_type\", op_type)\n      .Build();\n}\nstatic constexpr auto* CachedCclReduceScatterOpExpr =\n    DECORATE(&EagerCclReduceScatter, ThreadLocalCopiable);\n\nMaybe<one::UserOpExpr> EagerCclAllGather(Symbol<ParallelDesc> parallel_desc) {\n  CHECK_OR_RETURN(\n      JUST(CheckCclKernelRegistered(\"eager_ccl_all_gather\", parallel_desc->device_type())))\n      << OF_KERNEL_NOT_SUPPORT_ERROR(\"AllGather\", parallel_desc->device_type());\n  return one::OpBuilder(\"eager_ccl_all_gather\", *JUST(UniqueStr(\"eager_ccl_all_gather\")))\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<std::string>(\"parallel_conf\", PbMessage2TxtString(parallel_desc->parallel_conf()))\n      .Build();\n}\n\nstatic constexpr auto* CachedEagerCclAllGatherOpExpr = DECORATE(&EagerCclAllGather, ThreadLocal);\n\nMaybe<one::UserOpExpr> EagerCclS2S(Symbol<ParallelDesc> parallel_desc, Symbol<SbpParallel> src_sbp,\n                                   Symbol<SbpParallel> dst_sbp) {\n  return one::OpBuilder(\"eager_ccl_s2s\", *JUST(UniqueStr(\"eager_ccl_s2s\")))\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<int64_t>(\"in_split_axis\", src_sbp->split_parallel().axis())\n      .Attr<int64_t>(\"out_split_axis\", dst_sbp->split_parallel().axis())\n      .Attr<std::string>(\"parallel_conf\", PbMessage2TxtString(parallel_desc->parallel_conf()))\n      .Build();\n}\n\nauto* CachedEagerCclS2SOpExpr = DECORATE(&EagerCclS2S, ThreadLocal);\n\nMaybe<one::UserOpExpr> EagerCclReduce(Symbol<ParallelDesc> parallel_desc, int64_t root) {\n  CHECK_OR_RETURN(JUST(CheckCclKernelRegistered(\"eager_ccl_reduce\", parallel_desc->device_type())))\n      << OF_KERNEL_NOT_SUPPORT_ERROR(\"Reduce\", parallel_desc->device_type());\n  return one::OpBuilder(\"eager_ccl_reduce\", *JUST(UniqueStr(\"eager_ccl_reduce\")))\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<std::string>(\"parallel_conf\", PbMessage2TxtString(parallel_desc->parallel_conf()))\n      
.Attr<int64_t>(\"root\", root)\n      .Build();\n}\n\nauto* CachedEagerCclReduceOpExpr = DECORATE(&EagerCclReduce, ThreadLocal);\n\nMaybe<one::UserOpExpr> RankGroupAndDeviceType2AllReduceOpExpr(Symbol<RankGroup> rank_group,\n                                                              DeviceType device_type) {\n  CHECK_OR_RETURN(JUST(CheckCclKernelRegistered(\"eager_ccl_all_reduce\", device_type)))\n      << OF_KERNEL_NOT_SUPPORT_ERROR(\"AllReduce\", device_type);\n  const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(device_type, rank_group));\n  return one::OpBuilder(\"eager_ccl_all_reduce\")\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<std::string>(\"parallel_conf\", PbMessage2TxtString(parallel_desc->parallel_conf()))\n      .Build();\n}\n\nauto* CachedRankGroupAndDeviceType2AllReduceOpExpr =\n    DECORATE(&RankGroupAndDeviceType2AllReduceOpExpr, ThreadLocal);\n\nMaybe<one::UserOpExpr> RankGroupAndDeviceType2AllGatherOpExpr(Symbol<RankGroup> rank_group,\n                                                              DeviceType device_type) {\n  CHECK_OR_RETURN(JUST(CheckCclKernelRegistered(\"eager_ccl_all_gather\", device_type)))\n      << OF_KERNEL_NOT_SUPPORT_ERROR(\"AllGather\", device_type);\n  const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(device_type, rank_group));\n  return one::OpBuilder(\"eager_ccl_all_gather\")\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<std::string>(\"parallel_conf\", PbMessage2TxtString(parallel_desc->parallel_conf()))\n      .Build();\n}\n\nauto* CachedRankGroupAndDeviceType2AllGatherOpExpr =\n    DECORATE(&RankGroupAndDeviceType2AllGatherOpExpr, ThreadLocal);\n\nMaybe<one::UserOpExpr> RankGroupAndDeviceType2ReduceScatterOpExpr(Symbol<RankGroup> rank_group,\n                                                                  DeviceType device_type) {\n  CHECK_OR_RETURN(JUST(CheckCclKernelRegistered(\"eager_ccl_reduce_scatter\", device_type)))\n      << OF_KERNEL_NOT_SUPPORT_ERROR(\"ReduceScatter\", device_type);\n  const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(device_type, rank_group));\n  return one::OpBuilder(\"eager_ccl_reduce_scatter\", *JUST(UniqueStr(\"eager_ccl_reduce_scatter\")))\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<std::string>(\"parallel_conf\", PbMessage2TxtString(parallel_desc->parallel_conf()))\n      .Build();\n}\n\nauto* CachedRankGroupAndDeviceType2ReduceScatterOpExpr =\n    DECORATE(&RankGroupAndDeviceType2ReduceScatterOpExpr, ThreadLocal);\n\n#undef OF_KERNEL_NOT_SUPPORT_ERROR\n\n}  // namespace\n\nclass CommBroadcastFunctor {\n public:\n  CommBroadcastFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, int64_t src_rank,\n                           bool inplace) const {\n    const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());\n    DeviceType device_type = JUST(x->device())->enum_type();\n    const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(device_type, rank_group));\n    return one::Broadcast(x, src_rank, parallel_desc, inplace);\n  }\n};\n\nclass CommBroadcastTensorsFunctor {\n public:\n  CommBroadcastTensorsFunctor() = default;\n  Maybe<one::TensorTuple> operator()(const one::TensorTuple& inputs, int64_t src_rank,\n                                     bool inplace) const {\n    if (inputs.empty()) { return inputs; }\n    const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());\n    const auto& x = JUST(VectorAt(inputs, 0));\n    DeviceType device_type = 
JUST(x->device())->enum_type();\n    const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(device_type, rank_group));\n    return one::Broadcast(inputs, src_rank, parallel_desc, inplace);\n  }\n};\n\nnamespace {\n\nMaybe<one::UserOpExpr> RawStreamTouchFunctorOpExpr(size_t input_size) {\n  return one::OpBuilder(\"eager_ccl_touch\", *JUST(UniqueStr(\"eager_ccl_touch\")))\n      .Input(\"in\", input_size)\n      .Build();\n}\n\nstatic constexpr auto* StreamTouchFunctorOpExpr =\n    DECORATE(&RawStreamTouchFunctorOpExpr, ThreadLocal);\n\n}  // namespace\n\nclass StreamTouchFunctor {\n public:\n  StreamTouchFunctor() = default;\n  Maybe<void> operator()(const one::TensorTuple& inputs) const {\n    if (inputs.empty()) { return Maybe<void>::Ok(); }\n    std::shared_ptr<UserOpExpr> op_expr = JUST(StreamTouchFunctorOpExpr(inputs.size()));\n    TensorTuple outputs{};\n    JUST(OpInterpUtil::Dispatch(*op_expr, inputs, &outputs));\n    return Maybe<void>::Ok();\n  }\n};\n\nclass LocalAllReduceFunctor {\n public:\n  LocalAllReduceFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, bool inplace) const {\n    const auto& device = JUST(x->device());\n    CHECK_EQ_OR_RETURN(device->device_id(), GlobalProcessCtx::LocalRank());\n    const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());\n    DeviceType device_type = device->enum_type();\n    std::shared_ptr<OpExpr> op_expr =\n        JUST(CachedRankGroupAndDeviceType2AllReduceOpExpr(rank_group, device_type));\n    auto op_input = x;\n    if (const auto& static_zeros_tensor = std::dynamic_pointer_cast<StaticZerosTensor>(x)) {\n      op_input = std::dynamic_pointer_cast<Tensor>(JUST(static_zeros_tensor->AsLocalTensor()));\n    }\n    if (inplace) {\n      JUST(CheckInplaceValid(op_input));\n      TensorTuple outputs{op_input};\n      JUST(OpInterpUtil::Dispatch(*op_expr, {op_input}, &outputs));\n      return outputs[0];\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*op_expr, {op_input}, {});\n    }\n  }\n};\n\nclass GlobalAllReduceFunctor {\n public:\n  GlobalAllReduceFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    {\n      CHECK_OR_RETURN(x->is_global()) << \"Tensor is not global\";\n      CHECK_OR_RETURN(NdSbpIsAllPartialSum(*JUST(x->nd_sbp())))\n          << \"Tensor's sbp must be partial_sum\";\n    }\n    std::shared_ptr<OpExpr> op_expr = JUST(CachedEagerCclAllReduceOpExpr(JUST(x->parallel_desc())));\n    return JUST(OpInterpUtil::Dispatch<Tensor>(*op_expr, {x}));\n  }\n};\n\nclass GlobalReduceScatterFunctor {\n public:\n  GlobalReduceScatterFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::string& op_type) const {\n    {\n      CHECK_OR_RETURN(x->is_global());  // NOLINT\n      if (op_type == \"max\") {\n        CHECK_OR_RETURN(NdSbpIsAllBroadcast(*JUST(x->nd_sbp())))\n            << \"Tensor's sbp must be broadcast to get reduce_max\";\n        CHECK_EQ_OR_RETURN(JUST(x->parallel_desc())->device_type(), DeviceType::kCUDA)\n            << \"reduce_max only support CUDA\";\n      } else if (op_type == \"sum\") {\n        CHECK_OR_RETURN(NdSbpIsAllPartialSum(*JUST(x->nd_sbp())))\n            << \"Tensor's sbp must be partial_sum to get reduce_sum\";\n      } else {\n        UNIMPLEMENTED_THEN_RETURN();\n      }\n    }\n    std::shared_ptr<OpExpr> op_expr =\n        JUST(CachedCclReduceScatterOpExpr(JUST(x->parallel_desc()), op_type));\n    return 
JUST(OpInterpUtil::Dispatch<Tensor>(*op_expr, {x}));\n  }\n};\n\nclass LocalReduceScatterFunctor {\n public:\n  LocalReduceScatterFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& output,\n                           const std::shared_ptr<one::Tensor>& input) const {\n    DataType dtype_val = input->dtype()->data_type();\n    CHECK_EQ_OR_RETURN(input->shape()->elem_cnt(),\n                       output->nelement() * GlobalProcessCtx::WorldSize())\n        << Error::RuntimeError()\n        << \"input tensor size must be equal to world_size times output tensor size\";\n    CHECK_EQ_OR_RETURN(dtype_val, output->dtype()->data_type())\n        << Error::RuntimeError()\n        << \"output tensor must have the same type as input tensor\";\n    const Shape& shape = *output->shape();\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"output_shape\", \"output_dtype\");\n    attrs.SetAllAttrs(shape, dtype_val);\n    const auto& device = JUST(input->device());\n    CHECK_EQ_OR_RETURN(device->device_id(), GlobalProcessCtx::LocalRank());\n    const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());\n    DeviceType device_type = device->enum_type();\n    std::shared_ptr<OpExpr> op_expr =\n        JUST(CachedRankGroupAndDeviceType2ReduceScatterOpExpr(rank_group, device_type));\n    auto op_input = input;\n    if (const auto& static_zeros_tensor = std::dynamic_pointer_cast<StaticZerosTensor>(input)) {\n      op_input = std::dynamic_pointer_cast<Tensor>(JUST(static_zeros_tensor->AsLocalTensor()));\n    }\n    TensorTuple outputs{output};\n    JUST(OpInterpUtil::Dispatch(*op_expr, {op_input}, &outputs, attrs));\n    return outputs[0];\n  }\n};\n\nclass GlobalAllGatherFunctor {\n public:\n  GlobalAllGatherFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    {\n      CHECK_OR_RETURN(x->is_global()) << \"Tensor is not global\";\n      CHECK_OR_RETURN(NdSbpIsAllSplit(*JUST(x->nd_sbp()), 0))\n          << \"Tensor's sbp must be split to get all_gather\";\n    }\n    std::shared_ptr<OpExpr> op_expr = JUST(CachedEagerCclAllGatherOpExpr(JUST(x->parallel_desc())));\n    return JUST(OpInterpUtil::Dispatch<Tensor>(*op_expr, {x}));\n  }\n};\n\nclass LocalAllGatherFunctor {\n public:\n  LocalAllGatherFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& output,\n                           const std::shared_ptr<one::Tensor>& input) const {\n    DataType dtype_val = input->dtype()->data_type();\n    CHECK_EQ_OR_RETURN(input->shape()->elem_cnt() * GlobalProcessCtx::WorldSize(),\n                       output->nelement())\n        << Error::RuntimeError()\n        << \"output tensor size must be equal to world_size times input tensor size\";\n    CHECK_EQ_OR_RETURN(dtype_val, output->dtype()->data_type())\n        << Error::RuntimeError()\n        << \"output tensor must have the same type as input tensor\";\n    const Shape& shape = *output->shape();\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"output_shape\", \"output_dtype\");\n    attrs.SetAllAttrs(shape, dtype_val);\n    const auto& device = JUST(input->device());\n    CHECK_EQ_OR_RETURN(device->device_id(), GlobalProcessCtx::LocalRank());\n    const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());\n    DeviceType device_type = device->enum_type();\n    std::shared_ptr<OpExpr> op_expr =\n        JUST(CachedRankGroupAndDeviceType2AllGatherOpExpr(rank_group, device_type));\n    
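// Materialize StaticZerosTensor (a lazy all-zeros placeholder) into a real local tensor\n    // before dispatching to the all-gather kernel.\n    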
auto op_input = input;\n    if (const auto& static_zeros_tensor = std::dynamic_pointer_cast<StaticZerosTensor>(input)) {\n      op_input = std::dynamic_pointer_cast<Tensor>(JUST(static_zeros_tensor->AsLocalTensor()));\n    }\n    TensorTuple outputs{output};\n    JUST(OpInterpUtil::Dispatch(*op_expr, {op_input}, &outputs, attrs));\n    return outputs[0];\n  }\n};\n\nclass GlobalS2SFunctor {\n public:\n  GlobalS2SFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::vector<Symbol<SbpParallel>>& sbp_parallels) const {\n    Symbol<NdSbp> in_nd_sbp = JUST(x->nd_sbp());\n    Symbol<NdSbp> out_nd_sbp = JUST(GetNdSbp(sbp_parallels));\n    {\n      CHECK_OR_RETURN(x->is_global());  // NOLINT\n      CHECK_EQ_OR_RETURN(in_nd_sbp->sbp_parallel_size(), 1);\n      CHECK_OR_RETURN(IsSplitSbp(in_nd_sbp->sbp_parallel(0)));\n      CHECK_EQ_OR_RETURN(out_nd_sbp->sbp_parallel_size(), 1);\n      CHECK_OR_RETURN(IsSplitSbp(out_nd_sbp->sbp_parallel(0)));\n      CHECK_NE_OR_RETURN(in_nd_sbp->sbp_parallel(0).split_parallel().axis(),\n                         out_nd_sbp->sbp_parallel(0).split_parallel().axis());\n    }\n    std::shared_ptr<OpExpr> op_expr =\n        JUST(CachedEagerCclS2SOpExpr(JUST(x->parallel_desc()), SymbolOf(in_nd_sbp->sbp_parallel(0)),\n                                     SymbolOf(out_nd_sbp->sbp_parallel(0))));\n    return JUST(OpInterpUtil::Dispatch<Tensor>(*op_expr, {x}));\n  }\n};\n\nclass SendFunctor {\n public:\n  SendFunctor() { op_expr_ = CHECK_JUST(one::OpBuilder(\"send\").Input(\"in\").Build()); }\n  Maybe<void> operator()(const std::shared_ptr<one::Tensor>& x, int64_t dst, bool send_meta) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dst_process_id\");\n    attrs.SetAllAttrs(dst);\n    if (send_meta) {\n      std::shared_ptr<FlatShape> flat_shape = JUST(FlatShape::New(*x->shape()));\n      JUST(ccl::CpuSend(flat_shape.get(), sizeof(*flat_shape), dst));\n\n      DataType dtype = x->dtype()->data_type();\n      JUST(ccl::CpuSend(&dtype, sizeof(dtype), dst));\n\n      DeviceType device_type = JUST(Device::GetPlacement(*JUST(x->device())))->device_type();\n      JUST(ccl::CpuSend(&device_type, sizeof(device_type), dst));\n    }\n    JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_expr_, {x}, attrs));\n    return Maybe<void>::Ok();\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_expr_;\n};\n\nclass RecvFunctor {\n public:\n  RecvFunctor() { op_expr_ = CHECK_JUST(one::OpBuilder(\"recv\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(int64_t src, const Optional<Shape>& optional_shape,\n                           const Optional<Symbol<DType>>& optional_dtype,\n                           const Optional<Symbol<Device>>& optional_device,\n                           const Optional<one::Tensor>& out) const {\n    Shape shape;\n    DataType data_type = DataType::kInvalidDataType;\n    Symbol<Device> device;\n    if (optional_shape.has_value() && optional_dtype.has_value() && optional_device.has_value()) {\n      shape = *JUST(optional_shape);\n      data_type = JUST(optional_dtype)->data_type();\n      device = JUST(optional_device);\n    } else if (!optional_shape.has_value() && !optional_dtype.has_value()\n               && !optional_device.has_value()) {\n      FlatShape flat_shape{};\n      JUST(ccl::CpuRecv(&flat_shape, sizeof(flat_shape), src));\n      shape = *JUST(flat_shape.ToShape());\n\n      JUST(ccl::CpuRecv(&data_type, sizeof(data_type), src));\n\n      DeviceType device_type = 
DeviceType::kInvalidDevice;\n      JUST(ccl::CpuRecv(&device_type, sizeof(device_type), src));\n      device = JUST(Device::New(*JUST(DeviceTag4DeviceType(device_type))));\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"All or none of shape, dtype and device should have value.\";\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"src_process_id\", \"shape\", \"dtype\", \"device_type\",\n                                                 \"device_id\");\n    attrs.SetAllAttrs(src, shape, data_type, device->type(), device->device_id());\n    OpExprInterpContext op_expr_interp_context(attrs, device);\n\n    if (out.has_value()) {\n      std::shared_ptr<one::Tensor> out_tensor = JUST(out);\n      Symbol<Device> out_tensor_device = JUST(out_tensor->device());\n      CHECK_OR_RETURN(out_tensor_device == device);\n      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n      outputs->at(0) = out_tensor;\n      JUST(OpInterpUtil::Dispatch(*op_expr_, {}, outputs.get(), op_expr_interp_context));\n      return outputs->at(0);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_expr_, {}, op_expr_interp_context);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_expr_;\n};\n\nclass LocalReduceFunctor {\n public:\n  LocalReduceFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, int64_t dst, bool inplace) const {\n    const auto& device = JUST(x->device());\n    { CHECK_EQ_OR_RETURN(device->device_id(), GlobalProcessCtx::LocalRank()); }\n    static thread_local std::unordered_map<std::pair<Symbol<RankGroup>, Symbol<Device>>,\n                                           Symbol<ParallelDesc>>\n        rank_group_with_device2parallel_desc;\n    const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());\n    auto iter = rank_group_with_device2parallel_desc.find({rank_group, device});\n    Symbol<ParallelDesc> parallel_desc;\n    if (iter == rank_group_with_device2parallel_desc.end()) {\n      ParallelConf parallel_conf;\n      parallel_conf.set_device_tag(device->type());\n      JUST(rank_group->ForEachRank([&parallel_conf](int64_t rank) -> Maybe<void> {\n        parallel_conf.add_device_name(\"@\" + std::to_string(rank) + \":\"\n                                      + std::to_string(GlobalProcessCtx::LocalRank(rank)));\n        return Maybe<void>::Ok();\n      }));\n      parallel_desc = SymbolOf(ParallelDesc(parallel_conf));\n      rank_group_with_device2parallel_desc[{rank_group, device}] = parallel_desc;\n    } else {\n      parallel_desc = iter->second;\n    }\n    std::shared_ptr<OpExpr> op_expr = JUST(CachedEagerCclReduceOpExpr(parallel_desc, dst));\n    if (inplace) {\n      TensorTuple outputs{x};\n      JUST(OpInterpUtil::Dispatch(*op_expr, {x}, &outputs));\n      return x;\n    } else {\n      return JUST(OpInterpUtil::Dispatch<Tensor>(*op_expr, {x}));\n    }\n  }\n};\n\n}  // namespace impl\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<impl::StreamTouchFunctor>(\"StreamTouch\");\n  m.add_functor<impl::CommBroadcastFunctor>(\"CommBroadcast\");\n  m.add_functor<impl::CommBroadcastTensorsFunctor>(\"CommBroadcastTensors\");\n  m.add_functor<impl::LocalAllReduceFunctor>(\"LocalAllReduce\");\n  m.add_functor<impl::LocalAllGatherFunctor>(\"LocalAllGather\");\n  m.add_functor<impl::LocalReduceScatterFunctor>(\"LocalReduceScatter\");\n  m.add_functor<impl::GlobalAllReduceFunctor>(\"GlobalAllReduce\");\n  m.add_functor<impl::GlobalReduceScatterFunctor>(\"GlobalReduceScatter\");\n  
m.add_functor<impl::GlobalAllGatherFunctor>(\"GlobalAllGather\");\n  m.add_functor<impl::GlobalS2SFunctor>(\"GlobalS2S\");\n  m.add_functor<impl::SendFunctor>(\"Send\");\n  m.add_functor<impl::RecvFunctor>(\"Recv\");\n  m.add_functor<impl::LocalReduceFunctor>(\"LocalReduce\");\n};\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/common.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"fmt/core.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n#include \"oneflow/core/autograd/autograd_mode.h\"\n#include \"oneflow/core/common/wrap_dim_utils.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/ccl/ccl.h\"\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/common/small_vector.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/framework/tensor_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/memory/memory_case_util.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\nnamespace {\n\nMaybe<Shape> InferUnifiedShapeForBroadcasting(const Shape& input_shape, const Shape& other_shape) {\n  // same shapes need no broadcasting\n  if (input_shape == other_shape) { return input_shape; }\n\n  const auto unify_shapes_with_same_num_axes = [](const Shape& input_shape,\n                                                  const Shape& other_shape) -> Maybe<Shape> {\n    // num_axes.first == num_axes.second\n    Shape target;\n    for (size_t i = 0; i < input_shape.NumAxes() /* both input_shape and other_shape are ok */;\n         ++i) {\n      const auto num_in_curr_dim = std::make_pair(input_shape.At(i), other_shape.At(i));\n\n      // A = (2, ), B = (2, ), A[0] == B[0], so C = (2, )\n      if (num_in_curr_dim.first == num_in_curr_dim.second) {\n        target.push_back(num_in_curr_dim.first);\n        continue;\n      }\n\n      // A = (2, ), B = (3, ), A[0] != B[0] and A[0] != 1 and B[0] != 1, so raise RuntimeError\n      if (num_in_curr_dim.first != 1 && num_in_curr_dim.second != 1) {\n        return Error::RuntimeError()\n               << fmt::format(\"input and other can't be broadcasted to a single shape. [input's \"\n                              \"shape: {}, other's shape: {}].\",\n                              input_shape.ToString(), other_shape.ToString());\n      }\n\n      // A = (2, ), B = (1, ), A[0] != B[0] but B[0] == 1, so C = (2, )\n      target.push_back(\n          num_in_curr_dim.first == 1\n              ? 
num_in_curr_dim.second\n              : num_in_curr_dim.first);  // num_in_curr_dim.first and num_in_curr_dim.second can't\n                                         // be 1 at the same time\n    }\n    return target;\n  };\n\n  const int64_t input_num_axes = input_shape.NumAxes();\n  const int64_t other_num_axes = other_shape.NumAxes();\n\n  if (input_num_axes == other_num_axes) {\n    return unify_shapes_with_same_num_axes(input_shape, other_shape);\n  }\n\n  const int64_t unified_num_axes = std::max(input_num_axes, other_num_axes);\n\n  // shape = (3, 4) and unified_num_axes = 3 ==> shape will be (1, 3, 4)\n  const auto expand_shape_if_necessary = [unified_num_axes](const Shape& shape_to_expand) {\n    const int64_t shape_to_expand_num_axes = shape_to_expand.NumAxes();\n    if (shape_to_expand_num_axes < unified_num_axes) {\n      auto new_shape = Shape::Ones(unified_num_axes);\n      std::copy(shape_to_expand.begin(), shape_to_expand.end(),\n                new_shape.begin() + (unified_num_axes - shape_to_expand_num_axes));\n      return new_shape;\n    }\n    return shape_to_expand;\n  };\n\n  return unify_shapes_with_same_num_axes(expand_shape_if_necessary(input_shape),\n                                         expand_shape_if_necessary(other_shape));\n}\n\n}  // namespace\n\nbool IsStaticZerosTensor(const std::shared_ptr<Tensor>& x) {\n  return nullptr != std::dynamic_pointer_cast<StaticZerosTensor>(x);\n}\n\nbool IsInplaceValid(const std::shared_ptr<Tensor>& x) {\n  return !autograd::GradMode::is_enabled() || !(x->is_leaf() && x->requires_grad());\n}\n\nbool IsScalarTensor(const std::shared_ptr<Tensor>& x) {\n  return x->shape()->NumAxes() == 0 && x->shape()->elem_cnt() == 1;\n}\n\nMaybe<bool> ComputeNonOverlappingAndDense(const std::shared_ptr<Tensor>& x) {\n  // A function used to check whether the tensor is non-overlapping and dense, reference: (pytorch)\n  // c10/core/TensorImpl.cpp\n  const int64_t ndim = x->ndim();\n  const auto& shape = x->shape();\n  const auto& stride = JUST(x->stride());\n\n  // If 1D tensor and shape(0) < 2 or stride(0) == 1 then true\n  if (ndim == 1) { return shape->at(0) < 2 || stride->at(0) == 1; }\n  small_vector<int64_t, 5> perm;\n  perm.resize(ndim);\n  for (int64_t i = 0; i < ndim; ++i) { perm[i] = i; }\n  // Sort by strides, leaving 0 and 1 sized dims at the end of the array\n  std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {\n    if (shape->at(a) < 2) {\n      return false;\n    } else if (shape->at(b) < 2) {\n      return true;\n    }\n    return stride->at(a) < stride->at(b);\n  });\n  // Check if target stride == required stride\n  auto require_stride = 1;\n  for (int64_t i = 0; i < ndim; ++i) {\n    const auto size_perm_i = shape->at(perm[i]);\n    if (size_perm_i < 2) { return true; }\n    if (stride->at(perm[i]) != require_stride) { return false; }\n    require_stride *= size_perm_i;\n  }\n  return true;\n}\n\nMaybe<bool> IsNonOverlappingAndDense(const std::shared_ptr<Tensor>& x) {\n  // If the tensor is contiguous, or ComputeNonOverlappingAndDense returns true, its memory\n  // layout is non-overlapping and dense.\n  return x->is_contiguous() || JUST(ComputeNonOverlappingAndDense(x));\n}\n\nMaybe<std::vector<int32_t>> CheckAxis(const std::vector<int32_t>& axis, const int32_t& ndim) {\n  const int32_t naxis = axis.size();\n  int32_t reduce_ndim = naxis;\n  if (naxis == 0 || ndim == 0) { reduce_ndim = ndim; }\n  std::vector<int32_t> reduce_axis(reduce_ndim);\n  if (naxis == 0) {\n    std::iota(reduce_axis.begin(), 
reduce_axis.end(), 0);\n  } else {\n    JUST(dim_list_to_bitset(axis, ndim));  // check that each axis is a valid dimension index\n    for (int32_t i = 0; i < naxis; i++) {\n      if (i < reduce_ndim) { reduce_axis[i] = JUST(maybe_wrap_dim(axis[i], ndim)); }\n    }\n  }\n  return reduce_axis;\n}\n\nMaybe<void> CheckInplaceValid(const std::shared_ptr<Tensor>& x) {\n  CHECK_OR_RETURN(IsInplaceValid(x))\n      << Error::RuntimeError()\n      << \"a leaf Tensor that requires grad is being used in an in-place operation\";\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckInplaceCastValid(const std::shared_ptr<Tensor>& x,\n                                  const std::shared_ptr<Tensor>& x_cast) {\n  CHECK_OR_RETURN(*x->dtype() == *x_cast->dtype())\n      << Error::RuntimeError() << \"result type \" << x_cast->dtype()->name()\n      << \" can't be cast to the desired output type \" << x->dtype()->name();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckInplaceShapeCanExpandTo(const Shape& shape, const Shape& expand_shape) {\n  if (shape == expand_shape) { return Maybe<void>::Ok(); }\n\n  CHECK_OR_RETURN(expand_shape.NumAxes() >= shape.NumAxes())\n      << Error::RuntimeError() << \"Can not expand original shape \" << shape.ToString() << \" to \"\n      << expand_shape.ToString() << \" in an inplace operation\";\n\n  int shift = expand_shape.NumAxes() - shape.NumAxes();\n  for (int i = expand_shape.NumAxes() - 1; i >= 0; --i) {\n    int index = i - shift;\n    if (index >= 0) {\n      int dim_a = expand_shape.At(i);\n      int dim_b = shape.At(index);\n      // NOTE(lixiang): When corresponding dimensions of tensor a and tensor b differ in size,\n      // dim_a needs to be greater than or equal to 0, and dim_b should be equal to 1.\n      CHECK_OR_RETURN(!(dim_a != dim_b && (dim_a < 0 || dim_b != 1)))\n          << Error::RuntimeError() << \"Tensor with shape \" << expand_shape.ToString()\n          << \" doesn't match the broadcast shape in an inplace operation\";\n    } else {\n      // For a 0-size tensor, expand_shape.At(i) can be equal to 0.\n      CHECK_OR_RETURN(expand_shape.At(i) >= 0);  // NOLINT(maybe-need-error-msg)\n    }\n  }\n\n  return Maybe<void>::Ok();\n}\n\nOptional<Stride> ComputeStride(const Shape& shape, const Stride& stride,\n                               const Shape& target_shape) {\n  /*************************************************\n   * Description: in some cases the view operation is not allowed, so its validity needs to be\n   * checked; the check refers to torch (aten/src/ATen/native/TensorShape.cpp)\n   *************************************************/\n  if (stride.size() == 0) {\n    // for scalar input tensor\n    return Stride(target_shape.NumAxes(), 1);\n  }\n  int64_t elem_count = shape.elem_cnt();\n  int64_t ndim = shape.NumAxes();\n  int64_t tgt_ndim = target_shape.NumAxes();\n  DimVector shape_vec = shape.dim_vec();\n  DimVector tgt_shape_vec = target_shape.dim_vec();\n  if (elem_count == 0) { return NullOpt; }\n\n  int64_t view_d = tgt_ndim - 1;\n  int64_t chunk_base_stride = stride.back();\n  Stride target_stride(tgt_ndim);\n  // stride for each subspace in the chunk\n  // numel in current chunk\n  int64_t tensor_numel = 1;\n  int64_t view_numel = 1;\n  for (int64_t tensor_d = ndim - 1; tensor_d >= 0; tensor_d--) {\n    tensor_numel *= shape_vec[tensor_d];\n    // if end of tensor size chunk, check view\n    if ((tensor_d == 0)\n        || (shape_vec[tensor_d - 1] != 1\n            && stride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {\n      while (view_d >= 0 && (view_numel < 
tensor_numel || tgt_shape_vec[view_d] == 1)) {\n        target_stride[view_d] = view_numel * chunk_base_stride;\n        view_numel *= tgt_shape_vec[view_d];\n        view_d--;\n      }\n      if (view_numel != tensor_numel) { return NullOpt; }\n      if (tensor_d > 0) {\n        chunk_base_stride = stride[tensor_d - 1];\n        tensor_numel = 1;\n        view_numel = 1;\n      }\n    }\n  }\n  if (view_d != -1) { return NullOpt; }\n  return target_stride;\n}\n\nMaybe<Shape> InferShapeUnspecifiedDim(const int64_t& elem_count, const Shape& shape) {\n  int need_infer_axis = -1;\n  int64_t target_elem_count = 1;\n  for (int i = 0; i < shape.NumAxes(); ++i) {\n    if (shape.At(i) < -1) {\n      return Error::RuntimeError() << \"Invalid shape dimension \" << shape.At(i);\n    } else if (shape.At(i) == -1) {\n      CHECK_OR_RETURN_ERROR(need_infer_axis == -1)\n          << Error::RuntimeError() << \"only one dimension can be inferred\";\n      need_infer_axis = i;\n    } else {\n      target_elem_count *= shape.At(i);\n    }\n  }\n  Shape inferred_shape = shape;\n  if (need_infer_axis == -1) {\n    if (elem_count > 0) {\n      // For 0-size tensor, we don't need to check the element size.\n      CHECK_OR_RETURN_ERROR(target_elem_count == elem_count)\n          << Error::RuntimeError() << \"shape '\" << shape.ToString()\n          << \"' is invalid for input of size \" << elem_count;\n    }\n  } else {\n    inferred_shape.Set(need_infer_axis, elem_count / target_elem_count);\n    CHECK_OR_RETURN_ERROR(target_elem_count * inferred_shape.At(need_infer_axis) == elem_count)\n        << Error::RuntimeError() << \"shape '\" << shape.ToString()\n        << \"' is invalid for input of size \" << elem_count;\n  }\n  return inferred_shape;\n}\n\nMaybe<Shape> InferUnifiedShapeForBroadcasting(const std::vector<Shape>& shapes) {\n  if (shapes.empty()) { return Error::RuntimeError() << \"shapes should not be empty.\"; }\n  if (shapes.size() == 1) { return JUST(VectorAt(shapes, 0)); }\n\n  auto result =\n      *JUST(InferUnifiedShapeForBroadcasting(JUST(VectorAt(shapes, 0)), JUST(VectorAt(shapes, 1))));\n\n  // (1, 2) vs (3, 2) => (3, 2)\n  if (shapes.size() == 2) { return result; }\n\n  /*\n    (1, 3) vs (3, 1) vs (3, 1, 1)\n\n    1. (1, 3) vs (3, 1) => (3, 3)\n    2. (3, 3) vs (3, 1, 1) => (3, 3, 3)\n    3. 
(3, 3, 3) is the final result\n  */\n  for (auto iter = shapes.begin() + 2; iter != shapes.end(); ++iter) {\n    result = *JUST(InferUnifiedShapeForBroadcasting(result, *iter));\n  }\n  return result;\n}\n\n/*\n  if input shapes are [(1, 3), (3, 1), (3, 1, 1)]\n  will return ((3, 3, 3), [true, true, true])\n  means the shape to broadcast to is (3, 3, 3) and all three shapes need broadcasting\n*/\nMaybe<std::tuple<Shape, std::deque<bool>>> InferUnifiedShapeForBroadcastingWithInfo(\n    const std::vector<Shape>& shapes) {\n  const auto unified_shape = *JUST(InferUnifiedShapeForBroadcasting(shapes));\n  std::deque<bool> need_to_broadcast;\n  for (const auto& x : shapes) { need_to_broadcast.emplace_back(x != unified_shape); }\n  return std::make_tuple(unified_shape, need_to_broadcast);\n}\n\nMaybe<void> BroadcastSeedToAllRanks(uint64_t* seed, int64_t root) {\n  CHECK_NOTNULL_OR_RETURN(seed) << \"seed is not allowed to be nullptr\";\n  const auto& rank_group = JUST(RankGroup::DefaultRankGroup());\n  const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(DeviceType::kCPU, rank_group));\n  const auto& meta_transport_token =\n      JUST(TransportToken::NewTransportToken(kTransportTokenTypeMeta));\n  JUST(ccl::CpuBroadcast(seed, seed, sizeof(*seed), root, parallel_desc, meta_transport_token));\n  return Maybe<void>::Ok();\n}\n\nMaybe<std::vector<int32_t>> GetPermWhenTransposeAxisToLastDim(const int32_t& ndim,\n                                                              const int32_t& axis) {\n  auto wrap_dim = JUST(maybe_wrap_dim(axis, ndim));\n  std::vector<int32_t> perm(ndim);\n  for (int i = 0; i < ndim - 1; i++) {\n    if (i < wrap_dim) {\n      perm[i] = i;\n    } else {\n      perm[i] = i + 1;\n    }\n  }\n  perm[ndim - 1] = wrap_dim;\n  return perm;\n}\n\nMaybe<std::vector<int32_t>> GetInversedPerm(const std::vector<int32_t>& perm) {\n  std::vector<int32_t> inversed_perm(perm.size());\n  for (int i = 0; i < perm.size(); i++) { inversed_perm[perm[i]] = i; }\n  return inversed_perm;\n}\n\nMaybe<std::tuple<std::shared_ptr<Tensor>, bool>> batchify(const std::shared_ptr<Tensor>& input,\n                                                          const int64_t num_spatial_dims,\n                                                          const std::string& func_name) {\n  const int64_t dim_count_no_batch = num_spatial_dims + 1;\n  const int64_t dim_count_batch = dim_count_no_batch + 1;\n  const bool is_batched = (input->ndim() == dim_count_batch);\n  CHECK_EQ_OR_RETURN(input->ndim() == dim_count_no_batch || is_batched, true) << fmt::format(\n      \"Expected `{}`D (unbatched) or `{}`D (batched) input to `{}`, but got input of size: `{}`\",\n      dim_count_no_batch, dim_count_batch, func_name, input->shape()->DebugStr());\n  return std::make_tuple(is_batched ? 
input : JUST(functional::Unsqueeze(input, 0)), is_batched);\n}\n\ntemplate<typename T>\nT GetTensorItemValue(const std::shared_ptr<one::Tensor>& input) {\n  CHECK_EQ_OR_THROW(input->nelement(), 1) << \"Input tensor must have exactly one element\";\n  T value;\n  const auto& callback = [&](ep::Stream* stream,\n                             const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n    SyncAutoMemcpy(stream, &value, eager_blob_object->dptr(), sizeof(T), memory::MakeHostMemCase(),\n                   eager_blob_object->mem_case());\n  };\n  SyncAccessTensorWithTimeOut(input, callback, \"const\").GetOrThrow();\n  return value;\n}\n\nMaybe<void> CheckNormalTensorStd(const std::shared_ptr<one::Tensor>& std) {\n  CHECK_OR_RETURN(!std->dtype()->is_complex())\n      << \"normal expects standard deviation to be non-complex\";\n  if (std->nelement() > 0) {\n    auto std_check = CHECK_JUST(ScalarLogicalGreaterEqual(CHECK_JUST(Min(std)), Scalar(0.0)));\n    CHECK_OR_THROW(GetTensorItemValue<bool>(std_check))\n        << \"normal expects all elements of std >= 0.0\";\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckNormalTensorStd(const float std) {\n  CHECK_GE_OR_RETURN(std, 0.0) << \"normal expects std >= 0.0, but found std \" << std;\n  return Maybe<void>::Ok();\n}\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
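  {
    "path": "examples/hypothetical/broadcast_unify_sketch.cpp",
    "content": "/*\nNOTE: Hypothetical illustrative sketch added for exposition; this file is not part of the\nOneFlow source tree and its path and names are made up. It mirrors, on plain\nstd::vector<int64_t>, the two-shape broadcasting rule implemented by\nInferUnifiedShapeForBroadcasting in oneflow/core/functional/impl/common.cpp: pad the shorter\nshape with leading 1s, then require each axis pair to be equal or contain a 1 (the 1 side is\nexpanded).\n*/\n#include <algorithm>\n#include <cstdint>\n#include <iostream>\n#include <optional>\n#include <vector>\n\nusing Dims = std::vector<int64_t>;\n\n// Returns std::nullopt when the shapes cannot be broadcast together\n// (the real function reports this case as a RuntimeError through Maybe<Shape>).\nstd::optional<Dims> UnifyForBroadcast(Dims a, Dims b) {\n  const size_t n = std::max(a.size(), b.size());\n  // shape = (3, 4) and n = 3 ==> shape becomes (1, 3, 4), as in expand_shape_if_necessary\n  a.insert(a.begin(), n - a.size(), 1);\n  b.insert(b.begin(), n - b.size(), 1);\n  Dims target(n);\n  for (size_t i = 0; i < n; ++i) {\n    if (a[i] == b[i]) {\n      target[i] = a[i];\n    } else if (a[i] == 1 || b[i] == 1) {\n      target[i] = std::max(a[i], b[i]);  // a[i] and b[i] cannot both be 1 in this branch\n    } else {\n      return std::nullopt;  // e.g. (2,) vs (3,): not broadcastable\n    }\n  }\n  return target;\n}\n\nint main() {\n  // Reproduces the worked example in common.cpp:\n  // (1, 3) vs (3, 1) => (3, 3), then (3, 3) vs (3, 1, 1) => (3, 3, 3).\n  const auto step1 = UnifyForBroadcast({1, 3}, {3, 1});\n  const auto step2 = UnifyForBroadcast(*step1, {3, 1, 1});\n  const bool ok = step2 && *step2 == Dims{3, 3, 3} && !UnifyForBroadcast({2}, {3});\n  std::cout << (ok ? 0 : 1) << std::endl;\n  return 0;\n}\n"
  },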
  {
    "path": "oneflow/core/functional/impl/common.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FUNCTIONAL_IMPL_COMMON_H_\n#define ONEFLOW_CORE_FUNCTIONAL_IMPL_COMMON_H_\n\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"fmt/core.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nstatic constexpr size_t kMaxInputCount = 128;\nstatic constexpr size_t kMaxOutputCount = 128;\n\nbool IsStaticZerosTensor(const std::shared_ptr<Tensor>& x);\nbool IsInplaceValid(const std::shared_ptr<Tensor>& x);\nbool IsScalarTensor(const std::shared_ptr<Tensor>& x);\nMaybe<bool> ComputeNonOverlappingAndDense(const std::shared_ptr<Tensor>& x);\nMaybe<bool> IsNonOverlappingAndDense(const std::shared_ptr<Tensor>& x);\n\nMaybe<std::vector<int32_t>> CheckAxis(const std::vector<int32_t>& axis, const int32_t& ndim);\nMaybe<void> CheckInplaceValid(const std::shared_ptr<Tensor>& x);\nMaybe<void> CheckInplaceCastValid(const std::shared_ptr<Tensor>& x,\n                                  const std::shared_ptr<Tensor>& x_cast);\nMaybe<void> CheckInplaceShapeCanExpandTo(const Shape& shape, const Shape& expand_shape);\nOptional<Stride> ComputeStride(const Shape& shape, const Stride& stride, const Shape& target_shape);\nMaybe<Shape> InferShapeUnspecifiedDim(const int64_t& elem_count, const Shape& shape);\n\n// returns unified_shape\nMaybe<Shape> InferUnifiedShapeForBroadcasting(const std::vector<Shape>& shapes);\n// returns tuple<unified_shape, need_to_broadcasts>\nMaybe<std::tuple<Shape, std::deque<bool>>> InferUnifiedShapeForBroadcastingWithInfo(\n    const std::vector<Shape>& shapes);\n\nMaybe<void> BroadcastSeedToAllRanks(uint64_t* seed, int64_t root = 0);\n\nMaybe<std::vector<int32_t>> GetPermWhenTransposeAxisToLastDim(const int32_t& ndim,\n                                                              const int32_t& axis);\nMaybe<std::vector<int32_t>> GetInversedPerm(const std::vector<int32_t>& perm);\n\n// batchify function is referenced from\n// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Convolution.cpp#L729\nMaybe<std::tuple<std::shared_ptr<Tensor>, bool>> batchify(const std::shared_ptr<Tensor>& input,\n                                                          const int64_t num_spatial_dims,\n                                                          const std::string& func_name);\n// CheckNormalTensorStd function is referenced from\n// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/DistributionTemplates.h#L171-L182\nMaybe<void> CheckNormalTensorStd(const std::shared_ptr<one::Tensor>& std);\nMaybe<void> CheckNormalTensorStd(const float std);\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FUNCTIONAL_IMPL_COMMON_H_\n"
  },
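  {
    "path": "examples/hypothetical/infer_unspecified_dim_sketch.cpp",
    "content": "/*\nNOTE: Hypothetical illustrative sketch added for exposition; not part of the OneFlow source\ntree. It mirrors InferShapeUnspecifiedDim from oneflow/core/functional/impl/common.cpp on plain\nstd::vector<int64_t>: at most one target axis may be -1, and that axis is inferred so that the\ntarget element count matches the input element count.\n*/\n#include <cstdint>\n#include <iostream>\n#include <optional>\n#include <vector>\n\n// Returns std::nullopt for the cases the real function reports as RuntimeError.\nstd::optional<std::vector<int64_t>> InferUnspecifiedDim(int64_t elem_count,\n                                                        std::vector<int64_t> shape) {\n  int64_t infer_axis = -1;\n  int64_t known_elem_count = 1;\n  for (size_t i = 0; i < shape.size(); ++i) {\n    if (shape[i] < -1) { return std::nullopt; }       // invalid shape dimension\n    if (shape[i] == -1) {\n      if (infer_axis != -1) { return std::nullopt; }  // only one dimension can be inferred\n      infer_axis = static_cast<int64_t>(i);\n    } else {\n      known_elem_count *= shape[i];\n    }\n  }\n  if (infer_axis == -1) {\n    // For 0-size tensors the element count is not checked, as in the real function.\n    if (elem_count > 0 && known_elem_count != elem_count) { return std::nullopt; }\n    return shape;\n  }\n  // Sketch-level guard: avoid dividing by zero when a 0-size axis coexists with -1.\n  if (known_elem_count == 0) { return std::nullopt; }\n  shape[infer_axis] = elem_count / known_elem_count;\n  if (shape[infer_axis] * known_elem_count != elem_count) { return std::nullopt; }\n  return shape;\n}\n\nint main() {\n  // 24 elements reshaped to (2, -1, 3): the middle axis is inferred as 4.\n  const auto inferred = InferUnspecifiedDim(24, {2, -1, 3});\n  std::cout << (inferred && (*inferred)[1] == 4 ? 0 : 1) << std::endl;\n  return 0;\n}\n"
  },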
  {
    "path": "oneflow/core/functional/impl/dataset_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nclass ImageFlipFuntor {\n public:\n  ImageFlipFuntor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"image_flip\").Input(\"in\").Input(\"flip_code\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& flip_code) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, flip_code});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\n}  // namespace impl\n\nONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::ImageFlipFuntor>(\"ImageFlip\"); };\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/eye_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nclass EyeDevcieFunctor {\n public:\n  EyeDevcieFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"eye\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const Scalar& rows, const Optional<Scalar>& cols,\n                           const Symbol<DType>& dtype, const Optional<Symbol<Device>>& device,\n                           const bool& requires_grad) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"rows\", \"cols\", \"dtype\");\n    attrs.SetAllAttrs(rows.As<int64_t>(), cols.value_or(rows).As<int64_t>(), dtype->data_type());\n    OpExprInterpContext ctx(attrs);\n    ctx.device = device;\n    auto res = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {}, ctx));\n    JUST(res->set_requires_grad(requires_grad));\n    return res;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass EyeDeviceStrFunctor {\n public:\n  Maybe<Tensor> operator()(const Scalar& rows, const Optional<Scalar>& cols,\n                           const Symbol<DType>& dtype, const std::string& device,\n                           const bool& requires_grad) const {\n    const Symbol<Device>& dev = JUST(Device::ParseAndNew(device));\n    return JUST(functional::Eye(rows, cols, dtype, dev, requires_grad));\n  }\n};\n\nclass GlobalEyeSbpListFunctor {\n public:\n  GlobalEyeSbpListFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"eye\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const Scalar& rows, const Optional<Scalar>& cols,\n                           const Symbol<DType>& dtype, const bool& requires_grad,\n                           const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp_tuple) const {\n    CHECK_EQ_OR_RETURN(sbp_tuple.size(), placement->hierarchy()->NumAxes())\n        << \"len(sbp) == len(placement.hierarchy) required, but \"\n        << \"len(sbp)==\" << 
sbp_tuple.size() << \", \"\n        << \"len(placement.hierarchy)==\" << placement->hierarchy()->NumAxes();\n\n    FOR_RANGE(int32_t, i, 0, sbp_tuple.size()) {\n      CHECK_OR_RETURN(sbp_tuple.at(i)->has_broadcast_parallel())\n          << \"sbp of eye should be broadcast only\";\n    }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"rows\", \"cols\", \"dtype\", \"nd_sbp\");\n    if (LazyMode::is_enabled()) {\n      std::vector<std::string> nd_sbp(sbp_tuple.size());\n      {\n        for (int i = 0; i < sbp_tuple.size(); ++i) {\n          nd_sbp.at(i) = SbpParallelToString(*sbp_tuple.at(i));\n        }\n      }\n      attrs.SetAllAttrs(rows.As<int64_t>(), cols.value_or(rows).As<int64_t>(), dtype->data_type(),\n                        nd_sbp);\n    } else {\n      attrs.SetAllAttrs(rows.As<int64_t>(), cols.value_or(rows).As<int64_t>(), dtype->data_type(),\n                        NullOpt);\n    }\n\n    const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));\n    auto res = JUST(\n        OpInterpUtil::Dispatch<Tensor>(*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp)));\n    JUST(res->set_requires_grad(requires_grad));\n    return res;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GlobalEyeSbpFunctor {\n public:\n  Maybe<Tensor> operator()(const Scalar& rows, const Optional<Scalar>& cols,\n                           const Symbol<DType>& dtype, const bool& requires_grad,\n                           const Symbol<ParallelDesc>& placement,\n                           const Symbol<SbpParallel>& sbp) const {\n    std::vector<Symbol<SbpParallel>> sbp_tuple{sbp};\n    return JUST(functional::Eye(rows, cols, dtype, requires_grad, placement, sbp_tuple));\n  }\n};\n\n}  // namespace impl\n\nclass EyeInplaceFunctor {\n public:\n  EyeInplaceFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"eye\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    JUST(CheckInplaceValid(x));\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    outputs->at(0) = x;\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"rows\", \"cols\", \"dtype\");\n    attrs.SetAllAttrs(x->shape()->At(0), x->shape()->At(1), x->dtype()->data_type());\n    OpExprInterpContext ctx(attrs);\n    ctx.device = JUST(x->device());\n    JUST(OpInterpUtil::Dispatch(*op_, {}, outputs.get(), ctx));\n    return outputs->at(0);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nusing namespace impl;\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<EyeDeviceFunctor, EyeDeviceStrFunctor, GlobalEyeSbpListFunctor,\n                GlobalEyeSbpFunctor>(\"Eye\");\n  m.add_functor<EyeInplaceFunctor>(\"EyeInplace\");\n};\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
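  {
    "path": "examples/hypothetical/eye_pattern_sketch.cpp",
    "content": "/*\nNOTE: Hypothetical illustrative sketch added for exposition; not part of the OneFlow source\ntree. It shows, in plain C++, the element pattern produced by the eye op that the functors in\neye_functor.cpp dispatch: out(i, j) == 1 exactly when i == j on a rows x cols matrix, with cols\ndefaulting to rows as in cols.value_or(rows).\n*/\n#include <algorithm>\n#include <cstdint>\n#include <iostream>\n#include <vector>\n\n// Row-major rows x cols identity-like matrix; cols < 0 means default to rows.\nstd::vector<float> MakeEye(int64_t rows, int64_t cols = -1) {\n  if (cols < 0) { cols = rows; }\n  std::vector<float> out(static_cast<size_t>(rows * cols), 0.0f);\n  const int64_t diag_len = std::min(rows, cols);\n  for (int64_t i = 0; i < diag_len; ++i) { out[static_cast<size_t>(i * cols + i)] = 1.0f; }\n  return out;\n}\n\nint main() {\n  // A 2 x 3 eye matrix: [[1, 0, 0], [0, 1, 0]].\n  const auto m = MakeEye(2, 3);\n  for (int64_t i = 0; i < 2; ++i) {\n    for (int64_t j = 0; j < 3; ++j) { std::cout << m[static_cast<size_t>(i * 3 + j)] << ' '; }\n    std::cout << std::endl;\n  }\n  return 0;\n}\n"
  },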
  {
    "path": "oneflow/core/functional/impl/fused_attention_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"fmt/core.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/tensor_util.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n#include \"oneflow/core/functional/impl/unary_functor.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/user/kernels/random_mask_like_kernel.h\"\n#include \"oneflow/user/kernels/dropout_kernel.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/user/kernels/distributions/common.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nnamespace {\n\nMaybe<void> ParseDims(const std::string& name, const Shape& shape, const std::string& layout,\n                      const Optional<int64_t>& batch_size, const Optional<int64_t>& seq_len,\n                      const Optional<int64_t>& num_heads, const Optional<int64_t>& head_size,\n                      int64_t* b, int64_t* m, int64_t* h, int64_t* k, bool* bm_packed) {\n  if (shape.NumAxes() == 2) {\n    if (layout == \"(BM)(HK)\" || layout == \"(BM)(H2K)\" || layout == \"(BM)(H3K)\") {\n      *bm_packed = true;\n      CHECK_OR_RETURN(batch_size);\n      CHECK_OR_RETURN(seq_len);\n      *b = JUST(batch_size);\n      *m = JUST(seq_len);\n      int64_t packed_n = 0;\n      if (layout == \"(BM)(HK)\") {\n        packed_n = 1;\n      } else if (layout == \"(BM)(H2K)\") {\n        CHECK_NE_OR_RETURN(name, \"query\") << \"query_layout should not be '(BM)(H2K)'\";\n        packed_n = 2;\n      } else if (layout == \"(BM)(H3K)\") {\n        packed_n = 3;\n      } else {\n        UNIMPLEMENTED_THEN_RETURN();\n      }\n      const int64_t hidden_size = shape.At(1);\n      if (num_heads) {\n        const int64_t expected_h = JUST(num_heads);\n        const int64_t packed_h = packed_n * expected_h;\n        CHECK_EQ_OR_RETURN(hidden_size % packed_h, 0)\n            << \"The size of the last dimension of the \" << name\n            << \" tensor should be a multiple of \" << packed_h << \".\";\n        *h = expected_h;\n        *k = hidden_size / packed_h;\n      } else if (head_size) {\n        const int64_t expected_k = JUST(head_size);\n        const int64_t packed_k = expected_k * packed_n;\n        CHECK_EQ_OR_RETURN(hidden_size % packed_k, 0)\n            << \"The size of the last dimension of the \" << name\n            << \" tensor should be a multiple of \" << packed_k << \".\";\n        *h = hidden_size / packed_k;\n        *k = expected_k;\n      } else {\n        UNIMPLEMENTED_THEN_RETURN();\n      }\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << name\n                                  << \"_layout should be '(BM)(HK)', '(BM)(H2K)', or '(BM)(H3K)' \"\n                                     \"when the number of 
dimensions of \"\n                                  << name << \" tensor is 2.\";\n    }\n  } else if (shape.NumAxes() == 3) {\n    if (layout == \"BM(HK)\" || layout == \"MB(HK)\" || layout == \"BM(H2K)\" || layout == \"MB(H2K)\"\n        || layout == \"BM(H3K)\" || layout == \"MB(H3K)\") {\n      *bm_packed = false;\n      int64_t packed_n = 0;\n      if (layout == \"BM(HK)\") {\n        *b = shape.At(0);\n        *m = shape.At(1);\n        packed_n = 1;\n      } else if (layout == \"MB(HK)\") {\n        *b = shape.At(1);\n        *m = shape.At(0);\n        packed_n = 1;\n      } else if (layout == \"BM(H2K)\") {\n        CHECK_NE_OR_RETURN(name, \"query\") << \"query_layout should not be 'BM(H2K)'\";\n        *b = shape.At(0);\n        *m = shape.At(1);\n        packed_n = 2;\n      } else if (layout == \"MB(H2K)\") {\n        CHECK_NE_OR_RETURN(name, \"query\") << \"query_layout should not be 'MB(H2K)'\";\n        *b = shape.At(1);\n        *m = shape.At(0);\n        packed_n = 2;\n      } else if (layout == \"BM(H3K)\") {\n        *b = shape.At(0);\n        *m = shape.At(1);\n        packed_n = 3;\n      } else if (layout == \"MB(H3K)\") {\n        *b = shape.At(1);\n        *m = shape.At(0);\n        packed_n = 3;\n      } else {\n        UNIMPLEMENTED_THEN_RETURN();\n      }\n      const int64_t hidden_size = shape.At(2);\n      if (num_heads) {\n        const int64_t expected_h = JUST(num_heads);\n        const int64_t packed_h = packed_n * expected_h;\n        CHECK_EQ_OR_RETURN(hidden_size % packed_h, 0)\n            << \"The size of the last dimension of the \" << name\n            << \" tensor should be a multiple of \" << packed_h << \".\";\n        *h = expected_h;\n        *k = hidden_size / packed_h;\n      } else if (head_size) {\n        const int64_t expected_k = JUST(head_size);\n        const int64_t packed_k = expected_k * packed_n;\n        CHECK_EQ_OR_RETURN(hidden_size % packed_k, 0)\n            << \"The size of the last dimension of the \" << name\n            << \" tensor should be a multiple of \" << packed_k << \".\";\n        *h = hidden_size / packed_k;\n        *k = expected_k;\n      } else {\n        UNIMPLEMENTED_THEN_RETURN();\n      }\n    } else if (layout == \"(BM)HK\") {\n      *bm_packed = true;\n      CHECK_OR_RETURN(batch_size);\n      CHECK_OR_RETURN(seq_len);\n      *b = JUST(batch_size);\n      *m = JUST(seq_len);\n      *h = shape.At(1);\n      *k = shape.At(2);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN()\n          << name\n          << \"_layout should be 'BM(HK)', 'MB(HK)', 'BM(H2K)', 'MB(H2K)', 'BM(H3K)', \"\n             \"'MB(H3K)' or '(BM)HK' when the number of dimensions of \"\n          << name << \" tensor is 3.\";\n    }\n  } else if (shape.NumAxes() == 4) {\n    *bm_packed = false;\n    if (layout == \"BMHK\") {\n      *b = shape.At(0);\n      *m = shape.At(1);\n      *h = shape.At(2);\n      *k = shape.At(3);\n    } else if (layout == \"BHMK\") {\n      *b = shape.At(0);\n      *m = shape.At(2);\n      *h = shape.At(1);\n      *k = shape.At(3);\n    } else if (layout == \"MBHK\") {\n      *b = shape.At(1);\n      *m = shape.At(0);\n      *h = shape.At(2);\n      *k = shape.At(3);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN()\n          << name << \"_layout should be 'BMHK', 'BHMK' or 'MBHK' when the number of dimensions of \"\n          << name << \" tensor is 4.\";\n    }\n  } else {\n    UNIMPLEMENTED_THEN_RETURN() << \"The number of dimensions of the \" << name\n                                << \" tensor should be 3 
or 4\";\n  };\n  if (batch_size) {\n    const int64_t expected_b = JUST(batch_size);\n    CHECK_EQ_OR_RETURN(*b, expected_b)\n        << \"The size of dimension 'B' of \" << name << \" tensor should be \" << expected_b << \".\";\n  }\n  if (seq_len) {\n    const int64_t expected_m = JUST(seq_len);\n    CHECK_EQ_OR_RETURN(*m, expected_m)\n        << \"The size of dimension 'M' of \" << name << \" tensor should be \" << expected_m << \".\";\n  }\n  if (num_heads) {\n    const int64_t expected_h = JUST(num_heads);\n    CHECK_EQ_OR_RETURN(*h, expected_h)\n        << \"The size of dimension 'H' of \" << name << \" tensor should be \" << expected_h << \".\";\n  }\n  if (head_size) {\n    const int64_t expected_k = JUST(head_size);\n    CHECK_EQ_OR_RETURN(*k, expected_k)\n        << \"The size of dimension 'K' of \" << name << \" tensor should be \" << expected_k << \".\";\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ParseDims(const std::string& name, const Shape& shape, const std::string& layout,\n                      const Optional<int64_t>& num_heads, const Optional<int64_t>& head_size,\n                      int64_t* b, int64_t* m, int64_t* h, int64_t* k) {\n  bool bm_packed{};\n  return ParseDims(name, shape, layout, Optional<int64_t>(), Optional<int64_t>(), num_heads,\n                   head_size, b, m, h, k, &bm_packed);\n}\n\n}  // namespace\n\nclass FusedMultiHeadAttentionInferenceFunctor {\n public:\n  FusedMultiHeadAttentionInferenceFunctor() = default;\n  Maybe<Tensor> operator()(\n      const std::shared_ptr<one::Tensor>& query, const std::shared_ptr<one::Tensor>& key,\n      const std::shared_ptr<one::Tensor>& value, const int64_t& num_heads, const bool& causal,\n      const int64_t& query_hidden_slice_start, const int64_t& query_hidden_slice_end,\n      const int64_t& key_hidden_slice_start, const int64_t& key_hidden_slice_end,\n      const int64_t& value_hidden_slice_start, const int64_t& value_hidden_slice_end,\n      const Optional<one::Tensor>& attn_bias, const int64_t& causal_diagonal_offset) const {\n    CHECK_OR_RETURN(query_hidden_slice_start == 0 && key_hidden_slice_start == 0\n                    && value_hidden_slice_start == 0 && query_hidden_slice_end == -1\n                    && key_hidden_slice_end == -1 && value_hidden_slice_end == -1)\n        << \"The parameters 'query_hidden_slice_start', 'query_hidden_slice_end', \"\n           \"'key_hidden_slice_start', 'key_hidden_slice_end', 'value_hidden_slice_start', \"\n           \"'value_hidden_slice_end' have been deprecated.\";\n\n    const int64_t query_hidden_size = query->shape()->At(2);\n    CHECK_EQ_OR_RETURN(query_hidden_size % num_heads, 0)\n        << \"The hidden size of the query tensor should be a multiple of num_heads.\";\n    const int64_t query_head_size = query_hidden_size / num_heads;\n    return functional::FusedMultiHeadAttentionInferenceV2(\n        query, \"BM(HK)\", query_head_size, Optional<one::Tensor>(), Optional<int64_t>(), key,\n        \"BM(HK)\", Optional<one::Tensor>(), Optional<one::Tensor>(), Optional<int64_t>(), value,\n        \"BM(HK)\", attn_bias, \"BM(HK)\", Optional<float>(), causal, Optional<std::string>(),\n        causal_diagonal_offset);\n  }\n};\n\nclass FusedMultiHeadAttentionInferenceV2Functor {\n public:\n  struct OpExprCacheKey {\n    bool has_attn_bias = false;\n    bool has_seq_start = false;\n    bool has_key_seq_len = false;\n    bool operator==(const OpExprCacheKey& rhs) const {\n      return this->has_attn_bias == rhs.has_attn_bias && this->has_seq_start == 
rhs.has_seq_start\n             && this->has_key_seq_len == rhs.has_key_seq_len;\n    }\n  };\n  struct OpExprCacheKeyHash {\n    size_t operator()(const OpExprCacheKey& key) const {\n      return Hash(key.has_attn_bias, key.has_seq_start, key.has_key_seq_len);\n    }\n  };\n  using OpExprCache =\n      std::unordered_map<OpExprCacheKey, std::shared_ptr<OpExpr>, OpExprCacheKeyHash>;\n  FusedMultiHeadAttentionInferenceV2Functor() {\n    for (bool has_attn_bias : {false, true}) {\n      for (bool has_seq_start : {false, true}) {\n        for (bool has_key_seq_len : {false, true}) {\n          auto builder = one::OpBuilder(\"fused_multi_head_attention_inference\")\n                             .Input(\"query\")\n                             .Input(\"key\")\n                             .Input(\"value\");\n          if (has_attn_bias) { builder.Input(\"attn_bias\"); }\n          if (has_seq_start) { builder.Input(\"query_seq_start\").Input(\"key_seq_start\"); }\n          if (has_key_seq_len) { builder.Input(\"key_seq_len\"); }\n          auto op = CHECK_JUST(builder.Output(\"out\").Build());\n          OpExprCacheKey key;\n          key.has_attn_bias = has_attn_bias;\n          key.has_seq_start = has_seq_start;\n          key.has_key_seq_len = has_key_seq_len;\n          op_cache_.emplace(key, op);\n        }\n      }\n    }\n  }\n  Maybe<Tensor> operator()(\n      const std::shared_ptr<one::Tensor>& query, const std::string& query_layout,\n      const Optional<int64_t>& query_head_size, const Optional<one::Tensor>& query_seq_start,\n      const Optional<int64_t>& query_max_seq_len, const Optional<one::Tensor>& key,\n      const Optional<std::string>& key_layout, const Optional<one::Tensor>& key_seq_start,\n      const Optional<one::Tensor>& key_seq_len, const Optional<int64_t>& key_max_seq_len,\n      const Optional<one::Tensor>& value, const Optional<std::string>& value_layout,\n      const Optional<one::Tensor>& attn_bias, const std::string& output_layout,\n      const Optional<float>& scale, const Optional<bool>& causal,\n      const Optional<std::string>& attn_mask_type, const int64_t& causal_diagonal_offset) const {\n    std::string attn_mask_type_val = \"none\";\n    if (attn_mask_type) {\n      CHECK(!causal) << \"Only one of attn_mask_type and causal can be specified at the same time.\";\n      attn_mask_type_val = *JUST(attn_mask_type);\n      CHECK_OR_RETURN(attn_mask_type_val == \"none\" || attn_mask_type_val == \"causal_from_top_left\"\n                      || attn_mask_type_val == \"causal_from_bottom_right\")\n          << \"The value of attn_mask_type should be one of 'none', 'causal_from_top_left' or \"\n             \"'causal_from_bottom_right'\";\n    } else if (causal && JUST(causal)) {\n      attn_mask_type_val = \"causal_from_top_left\";\n    } else {\n      // do nothing\n    }\n    CHECK_GE_OR_RETURN(causal_diagonal_offset, 0)\n        << \"The value of causal_diagonal_offset should be greater or equal to 0.\";\n\n    Optional<int64_t> batch_size;\n    std::shared_ptr<one::Tensor> query_seq_start_tensor;\n    std::shared_ptr<one::Tensor> key_seq_start_tensor;\n    if (query_seq_start) {\n      CHECK_OR_RETURN(key_seq_start) << \"The tensors query_seq_start and key_seq_start should both \"\n                                        \"be None or both not be None at the same time.\";\n      CHECK_OR_RETURN(query_max_seq_len)\n          << \"query_max_seq_len should not be None when query_seq_start is not None.\";\n      CHECK_OR_RETURN(key_max_seq_len)\n          << 
\"key_max_seq_len should not be None when key_seq_start is not None.\";\n      query_seq_start_tensor = JUST(query_seq_start);\n      key_seq_start_tensor = JUST(key_seq_start);\n      CHECK_EQ_OR_RETURN(query_seq_start_tensor->shape()->NumAxes(), 1)\n          << \"The number of dimensions of query_seq_start tensor should be 1.\";\n      CHECK_OR_RETURN(*query_seq_start_tensor->shape() == *key_seq_start_tensor->shape())\n          << \"The shapes of the query_seq_start and key_seq_start tensors should match.\";\n      CHECK_GT_OR_RETURN(query_seq_start_tensor->shape()->At(0), 1)\n          << \"The size of query_seq_start should be greater than 1.\";\n      batch_size = query_seq_start_tensor->shape()->At(0) - 1;\n      if (key_seq_len) {\n        CHECK_EQ_OR_RETURN(JUST(key_seq_len)->shape()->NumAxes(), 1)\n            << \"The number of dimensions of key_seq_len tensor should be 1.\";\n        CHECK_EQ_OR_RETURN(JUST(key_seq_len)->shape()->At(0), JUST(batch_size))\n            << \"The size of the key_seq_len tensor should be \" << JUST(batch_size) << \".\";\n      }\n    } else {\n      CHECK_OR_RETURN(!key_seq_start)\n          << \"The tensors query_seq_start and key_seq_start should both \"\n             \"be None or both not be None at the same time.\";\n      CHECK_OR_RETURN(!key_seq_len)\n          << \"The key_seq_len tensor should be None when query_seq_start is None.\";\n    }\n    std::shared_ptr<one::Tensor> key_tensor;\n    std::string key_tensor_layout;\n    std::shared_ptr<one::Tensor> value_tensor;\n    std::string value_tensor_layout;\n\n    int64_t q_b = 0;\n    int64_t q_m = 0;\n    int64_t q_h = 0;\n    int64_t q_k = 0;\n    bool q_bm_packed = false;\n    JUST(ParseDims(\"query\", *query->shape(), query_layout, batch_size, query_max_seq_len,\n                   Optional<int64_t>(), query_head_size, &q_b, &q_m, &q_h, &q_k, &q_bm_packed));\n    CHECK_EQ_OR_RETURN(q_k % 8, 0)\n        << \"The size of dimension 'K' of the query tensor should be a multiple of 8.\";\n    if (q_bm_packed) {\n      CHECK_OR_RETURN(query_seq_start)\n          << \"The query_seq_start tensor should not be None when the query tensor is BM-Packed.\";\n    }\n\n    int64_t k_b = 0;\n    int64_t k_m = 0;\n    int64_t k_h = 0;\n    int64_t k_k = 0;\n    bool k_bm_packed = false;\n    if (key) {\n      key_tensor = JUST(key);\n      key_tensor_layout = *JUST(key_layout);\n      JUST(ParseDims(\"key\", *key_tensor->shape(), key_tensor_layout, q_b, key_max_seq_len,\n                     Optional<int64_t>(), q_k, &k_b, &k_m, &k_h, &k_k, &k_bm_packed));\n      CHECK_EQ_OR_RETURN(k_b, q_b) << \"The size of dimension 'B' of the key tensor should be the \"\n                                      \"same as that of the query tensor.\";\n      CHECK_EQ_OR_RETURN(k_h, q_h) << \"The size of dimension 'H' of the key tensor should be the \"\n                                      \"same as that of the query tensor.\";\n      CHECK_EQ_OR_RETURN(k_bm_packed, q_bm_packed)\n          << \"The query tensor and the key tensor should either both be BM-Packed or both not be \"\n             \"BM-Packed at the same time.\";\n\n    } else {\n      CHECK_OR_RETURN(query_layout == \"BM(H3K)\" || query_layout == \"MB(H3K)\")\n          << \"The value of query_layout should be 'BM(H3K)' or 'MB(H3K)' when the key tensor is \"\n             \"None.\";\n      key_tensor = query;\n      key_tensor_layout = query_layout;\n      k_b = q_b;\n      k_m = q_m;\n      k_h = q_h;\n      k_k = q_k;\n      k_bm_packed = q_bm_packed;\n    
}\n\n    int64_t v_b = 0;\n    int64_t v_m = 0;\n    int64_t v_h = 0;\n    int64_t v_k = 0;\n    bool v_bm_packed = false;\n    if (value) {\n      value_tensor = JUST(value);\n      value_tensor_layout = *JUST(value_layout);\n      JUST(ParseDims(\"value\", *value_tensor->shape(), value_tensor_layout, q_b, k_m, q_h,\n                     Optional<int64_t>(), &v_b, &v_m, &v_h, &v_k, &v_bm_packed));\n      CHECK_EQ_OR_RETURN(v_b, q_b) << \"The size of dimension 'B' of the value tensor should be the \"\n                                      \"same as that of the query tensor.\";\n      CHECK_EQ_OR_RETURN(v_m, k_m) << \"The size of dimension 'M' of the value tensor should be the \"\n                                      \"same as that of the key tensor.\";\n      CHECK_EQ_OR_RETURN(v_k % 8, 0)\n          << \"The size of dimension 'K' of the value tensor should be a multiple of 8.\";\n      CHECK_EQ_OR_RETURN(v_bm_packed, k_bm_packed)\n          << \"The key tensor and the value tensor should either both be BM-Packed or both not be \"\n             \"BM-Packed at the same time.\";\n\n    } else {\n      CHECK_OR_RETURN(key_tensor_layout == \"BM(H2K)\" || key_tensor_layout == \"MB(H2K)\"\n                      || key_tensor_layout == \"BM(H3K)\" || key_tensor_layout == \"MB(H3K)\")\n          << \"The value of key_layout should be 'BM(H3K)', 'MB(H3K)', 'BM(H2K)' or 'MB(H2K)' when \"\n             \"the value tensor is None.\";\n      value_tensor = key_tensor;\n      value_tensor_layout = key_tensor_layout;\n      v_b = k_b;\n      v_m = k_m;\n      v_h = k_h;\n      v_k = k_k;\n      v_bm_packed = k_bm_packed;\n    }\n\n    if (attn_bias) {\n      const auto attn_bias_shape = JUST(attn_bias)->shape();\n      const int64_t num_attn_bias_axes = attn_bias_shape->NumAxes();\n      CHECK_OR_RETURN(num_attn_bias_axes > 0 && num_attn_bias_axes <= 4)\n          << \"The number of dimensions of attn_bias should be greater than 0 and less than or \"\n             \"equal to 4.\";\n      CHECK_GE_OR_RETURN(attn_bias_shape->At(num_attn_bias_axes - 1), k_m)\n          << \"The size of the -1 dimension of attn_bias should be greater than or equal to the \"\n             \"dimension 'M' of the key tensor\";\n      CHECK_EQ_OR_RETURN(attn_bias_shape->At(num_attn_bias_axes - 1) % 8, 0)\n          << \"The size of the -1 dimension of attn_bias should be a multiple of 8.\";\n      if (num_attn_bias_axes >= 2) {\n        CHECK_OR_RETURN(attn_bias_shape->At(num_attn_bias_axes - 2) == 1\n                        || attn_bias_shape->At(num_attn_bias_axes - 2) >= q_m)\n            << \"The size of the -2 dimension of attn_bias should be greater than or equal to the \"\n               \"dimension 'M' of the query tensor or equal to 1.\";\n      }\n      if (num_attn_bias_axes >= 3) {\n        CHECK_OR_RETURN(attn_bias_shape->At(num_attn_bias_axes - 3) == 1\n                        || attn_bias_shape->At(num_attn_bias_axes - 3) == q_h)\n            << \"The size of the -3 dimension of attn_bias should be equal to the dimension 'H' of \"\n               \"the query tensor or equal to 1.\";\n      }\n      if (num_attn_bias_axes == 4) {\n        CHECK_OR_RETURN(attn_bias_shape->At(0) == 1 || attn_bias_shape->At(0) == q_b)\n            << \"The size of the -4 dimension of attn_bias should be equal to the dimension 'B' of \"\n               \"the query tensor or equal to 1.\";\n      }\n    }\n    const bool o_bm_packed = output_layout == \"(BM)(HK)\";\n    CHECK_EQ_OR_RETURN(o_bm_packed, q_bm_packed)\n        << \"The 
query tensor and the output tensor should either both be BM-Packed or both not be \"\n           \"BM-Packed at the same time.\";\n    std::string op_output_layout;\n    if (output_layout == \"BM(HK)\" || output_layout == \"(BM)(HK)\") {\n      op_output_layout = output_layout;\n    } else if (output_layout == \"MB(HK)\") {\n      if (q_b == 1) {\n        op_output_layout = output_layout;\n      } else {\n        op_output_layout = \"BM(HK)\";\n      }\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"output_layout should be 'BM(HK)', 'MB(HK)' or (BM)(HK)\";\n    }\n\n    double scale_value = 0.0;\n    if (scale) {\n      scale_value = JUST(scale);\n    } else {\n      scale_value = 1.0 / std::sqrt(static_cast<float>(q_k));\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"query_layout\", \"key_layout\", \"value_layout\",\n                                                 \"output_layout\", \"query_head_size\",\n                                                 \"attn_mask_type\", \"causal_diagonal_offset\",\n                                                 \"query_max_seq_len\", \"key_max_seq_len\", \"scale\");\n    attrs.SetAllAttrs(query_layout, key_tensor_layout, value_tensor_layout, op_output_layout, q_k,\n                      attn_mask_type_val, causal_diagonal_offset, query_max_seq_len.value_or(0),\n                      key_max_seq_len.value_or(0), scale_value);\n    OpExprCacheKey cache_key{};\n    std::vector<std::shared_ptr<one::Tensor>> inputs;\n    inputs.emplace_back(query);\n    inputs.emplace_back(key_tensor);\n    inputs.emplace_back(value_tensor);\n    if (attn_bias) {\n      inputs.emplace_back(JUST(attn_bias));\n      cache_key.has_attn_bias = true;\n    } else {\n      cache_key.has_attn_bias = false;\n    }\n    if (query_seq_start && key_seq_start) {\n      inputs.emplace_back(JUST(query_seq_start));\n      inputs.emplace_back(JUST(key_seq_start));\n      cache_key.has_seq_start = true;\n    } else {\n      cache_key.has_seq_start = false;\n    }\n    if (key_seq_len) {\n      inputs.emplace_back(JUST(key_seq_len));\n      cache_key.has_key_seq_len = true;\n    } else {\n      cache_key.has_key_seq_len = false;\n    }\n    auto it = op_cache_.find(cache_key);\n    CHECK_OR_RETURN(it != op_cache_.end());\n    TensorTuple input_tuple(inputs.size());\n    for (int i = 0; i < inputs.size(); ++i) { input_tuple[i] = std::move(inputs[i]); }\n    std::shared_ptr<one::Tensor> op_output =\n        JUST(OpInterpUtil::Dispatch<Tensor>(*it->second, input_tuple, attrs));\n    if (op_output_layout == output_layout) {\n      return op_output;\n    } else {\n      if (op_output_layout == \"BM(HK)\" && output_layout == \"MB(HK)\") {\n        return functional::Transpose(op_output, {1, 0, 2});\n      } else {\n        UNIMPLEMENTED_THEN_RETURN();\n      }\n    }\n  }\n\n private:\n  OpExprCache op_cache_;\n};\n\nclass FusedAttentionConcatPastKeyValueFunctor {\n public:\n  FusedAttentionConcatPastKeyValueFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_attention_concat_past_key_value\")\n                         .Input(\"key\")\n                         .Input(\"value\")\n                         .Input(\"past_key\")\n                         .Input(\"past_value\")\n                         .Output(\"output_key\")\n                         .Output(\"output_value\")\n                         .Build());\n    op_without_past_ = CHECK_JUST(one::OpBuilder(\"fused_attention_concat_past_key_value\")\n                                      .Input(\"key\")\n                  
                    .Input(\"value\")\n                                      .Output(\"output_key\")\n                                      .Output(\"output_value\")\n                                      .Build());\n  }\n  Maybe<TensorTuple> operator()(\n      const Optional<one::Tensor>& past_key, const std::string& past_key_layout,\n      const Optional<one::Tensor>& past_value, const std::string& past_value_layout,\n      const std::shared_ptr<one::Tensor>& key, const std::string& key_layout,\n      const std::shared_ptr<one::Tensor>& value, const std::string& value_layout,\n      const Optional<int64_t>& key_head_size) const {\n    int64_t k_b = 0;\n    int64_t k_m = 0;\n    int64_t k_h = 0;\n    int64_t k_k = 0;\n    JUST(ParseDims(\"key\", *key->shape(), key_layout, Optional<int64_t>(), key_head_size, &k_b, &k_m,\n                   &k_h, &k_k));\n\n    int64_t v_b = 0;\n    int64_t v_m = 0;\n    int64_t v_h = 0;\n    int64_t v_k = 0;\n    JUST(ParseDims(\"value\", *value->shape(), value_layout, k_h, k_k, &v_b, &v_m, &v_h, &v_k));\n    CHECK_EQ_OR_RETURN(v_b, k_b) << \"The size of dimension 'B' of the value tensor should be \"\n                                    \"the same as that of the key tensor.\";\n    CHECK_EQ_OR_RETURN(v_m, k_m) << \"The size of dimension 'M' of the value tensor should be the \"\n                                    \"same as that of the key tensor.\";\n\n    if (past_key) {\n      CHECK_OR_RETURN(past_value) << \"Tensor past_key and tensor past_value should both be None or \"\n                                     \"both not be None at the same time.\";\n      int64_t past_k_b = 0;\n      int64_t past_k_m = 0;\n      int64_t past_k_h = 0;\n      int64_t past_k_k = 0;\n      JUST(ParseDims(\"past_key\", *JUST(past_key)->shape(), past_key_layout, k_h, k_k, &past_k_b,\n                     &past_k_m, &past_k_h, &past_k_k));\n      CHECK_EQ_OR_RETURN(past_k_b, k_b)\n          << \"The size of dimension 'B' of the past_key tensor should be \"\n             \"the same as that of the key tensor.\";\n      int64_t past_v_b = 0;\n      int64_t past_v_m = 0;\n      int64_t past_v_h = 0;\n      int64_t past_v_k = 0;\n      JUST(ParseDims(\"past_value\", *JUST(past_value)->shape(), past_value_layout, k_h, k_k,\n                     &past_v_b, &past_v_m, &past_v_h, &past_v_k));\n      CHECK_EQ_OR_RETURN(past_v_b, k_b) << \"The size of dimension 'B' of the past_value tensor \"\n                                           \"should be the same as that of the key tensor.\";\n      CHECK_EQ_OR_RETURN(past_v_m, past_k_m)\n          << \"The size of dimension 'M' of the past_value tensor \"\n             \"should be the same as that of the past_key tensor.\";\n    } else {\n      CHECK_OR_RETURN(!past_value)\n          << \"Tensor past_key and tensor past_value should both be None or \"\n             \"both not be None at the same time.\";\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"past_key_layout\", \"past_value_layout\",\n                                                 \"key_layout\", \"value_layout\", \"key_head_size\");\n    attrs.SetAllAttrs(past_key_layout, past_value_layout, key_layout, value_layout, k_k);\n    if (past_key) {\n      return JUST(OpInterpUtil::Dispatch<TensorTuple>(\n          *op_, {key, value, JUST(past_key), JUST(past_value)}, attrs));\n    } else {\n      return JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_without_past_, {key, value}, attrs));\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> 
op_without_past_;\n};\n\nclass FusedApplyRotaryEmbFunctor {\n public:\n  FusedApplyRotaryEmbFunctor() {\n    op_with_position_sinuous_ = CHECK_JUST(one::OpBuilder(\"fused_apply_rotary_emb\")\n                                               .Input(\"x\")\n                                               .Input(\"cos\")\n                                               .Input(\"sin\")\n                                               .Input(\"position_ids\")\n                                               .Output(\"out\")\n                                               .Build());\n    op_with_position_ = CHECK_JUST(one::OpBuilder(\"fused_apply_rotary_emb\")\n                                       .Input(\"x\")\n                                       .Input(\"position_ids\")\n                                       .Output(\"out\")\n                                       .Build());\n    op_without_position_ = CHECK_JUST(one::OpBuilder(\"fused_apply_rotary_emb\")\n                                          .Input(\"x\")\n                                          .Input(\"cos\")\n                                          .Input(\"sin\")\n                                          .Output(\"out\")\n                                          .Build());\n    op_without_position_sinuous_ =\n        CHECK_JUST(one::OpBuilder(\"fused_apply_rotary_emb\").Input(\"x\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Optional<one::Tensor>& cos,\n                           const Optional<one::Tensor>& sin,\n                           const Optional<one::Tensor>& position_ids, const std::string& x_layout,\n                           const Optional<std::string>& output_layout, const std::string& mode,\n                           const Optional<int64_t>& tensor_index, const Optional<int64_t>& k_size,\n                           const float base, const Optional<int64_t>& rotary_size) const {\n    int64_t b = 0, m = 0, h = 0, k = 0;\n\n    if (tensor_index) {\n      CHECK_OR_RETURN((JUST(tensor_index) >= 0) && (JUST(tensor_index) <= 2))\n          << \"tensor_index should be within [0, 2]\";\n    }\n    CHECK_OR_RETURN((mode == \"interval\") || (mode == \"plane\"))\n        << \"mode should be \\\"interval\\\" or \\\"plane\\\"\";\n\n    JUST(ParseDims(\"x\", *x->shape(), x_layout, Optional<int64_t>(), k_size, &b, &m, &h, &k));\n\n    if (k_size) {\n      CHECK_EQ_OR_RETURN(JUST(k_size), k)\n          << \"k_size, if given, should be equal to K of cos, sin and x.\";\n    }\n    if (rotary_size) {\n      CHECK_LE_OR_RETURN(JUST(rotary_size), k) << \"rotary_size should be no more than k.\";\n    }\n\n    int64_t rotary_emb_dim = 1;\n\n    if (position_ids) {\n      CHECK_EQ_OR_RETURN(JUST(position_ids)->shape()->NumAxes(), 3)\n          << \"ndims of position_ids should be equal to 3, either in form of B1M or B2M.\";\n      CHECK_EQ_OR_RETURN(JUST(position_ids)->shape()->At(0), b)\n          << \"1st dim of position_ids should be equal to B.\";\n      CHECK_EQ_OR_RETURN(JUST(position_ids)->shape()->At(2), m)\n          << \"3rd dim of position_ids should be equal to M.\";\n      rotary_emb_dim = JUST(position_ids)->shape()->At(1);\n      CHECK_OR_RETURN(rotary_emb_dim == 1 || rotary_emb_dim == 2)\n          << \"2nd dim of position_ids should be 1 or 2.\";\n    }\n\n    const int64_t actual_rotary_size = rotary_size.value_or(k) / rotary_emb_dim;\n    CHECK_EQ_OR_RETURN(actual_rotary_size % 2, 0)\n        << \"k, or rotary_size if given, should be a multiple of 2 * 
rotary_encoding_dim.\";\n\n    if (cos && sin) {\n      CHECK_EQ_OR_RETURN(JUST(cos)->shape()->NumAxes(), 2)\n          << \"The number of dimensions of cos should be equal to 2.\";\n      CHECK_OR_RETURN(JUST(cos)->shape() == JUST(sin)->shape())\n          << \"Each dimension of cos & sin should be the same.\";\n      CHECK_EQ_OR_RETURN(JUST(cos)->shape()->At(1), actual_rotary_size)\n          << \"Dimension 1 of cos & sin should be equal to rotary_size // \"\n             \"rotary_embedding_dimension.\";\n    } else if (!cos && !sin) {\n      // do nothing\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"cos & sin should both be given or not given.\";\n    }\n\n    if (!position_ids) {\n      if (cos && sin) {\n        CHECK_GE_OR_RETURN(JUST(cos)->shape()->At(0), m)\n            << \"M of cos & sin should be no less than \"\n               \"M of x when position_ids is not \"\n               \"given.\";  // K of cos & sin is checked\n                          // inside ParseDims\n      }\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"x_layout\", \"output_layout\", \"mode\",\n                                                 \"tensor_index\", \"k_size\", \"base\", \"rotary_size\");\n    attrs.SetAllAttrs(x_layout, output_layout.value_or(x_layout), mode, tensor_index.value_or(0),\n                      k_size.value_or(k), base, rotary_size.value_or(k));\n\n    if (position_ids) {\n      if (cos && sin) {\n        return OpInterpUtil::Dispatch<Tensor>(*op_with_position_sinuous_,\n                                              {x, JUST(cos), JUST(sin), JUST(position_ids)}, attrs);\n      } else {\n        return OpInterpUtil::Dispatch<Tensor>(*op_with_position_, {x, JUST(position_ids)}, attrs);\n      }\n    } else {\n      if (cos && sin) {\n        return OpInterpUtil::Dispatch<Tensor>(*op_without_position_, {x, JUST(cos), JUST(sin)},\n                                              attrs);\n      } else {\n        return OpInterpUtil::Dispatch<Tensor>(*op_without_position_sinuous_, {x}, attrs);\n      }\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_with_position_;\n  std::shared_ptr<OpExpr> op_with_position_sinuous_;\n  std::shared_ptr<OpExpr> op_without_position_;\n  std::shared_ptr<OpExpr> op_without_position_sinuous_;\n};\n\n}  // namespace impl\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<impl::FusedMultiHeadAttentionInferenceFunctor>(\"FusedMultiHeadAttentionInference\");\n  m.add_functor<impl::FusedMultiHeadAttentionInferenceV2Functor>(\n      \"FusedMultiHeadAttentionInferenceV2\");\n  m.add_functor<impl::FusedAttentionConcatPastKeyValueFunctor>(\"FusedAttentionConcatPastKeyValue\");\n  m.add_functor<impl::FusedApplyRotaryEmbFunctor>(\"FusedApplyRotaryEmb\");\n}\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
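  {
    "path": "examples/hypothetical/parse_bmhk_layout_sketch.cpp",
    "content": "/*\nNOTE: Hypothetical illustrative sketch added for exposition; not part of the OneFlow source\ntree. It mirrors the rank-3, num_heads-given branch of ParseDims in\noneflow/core/functional/impl/fused_attention_functor.cpp: given a 3-D tensor shape and a layout\nstring such as BM(HK) or MB(H3K), recover (B, M, H, K). A packing factor after H (2 or 3) means\nthe last dimension stores that many K-sized slices per head, as in packed KV or QKV tensors.\n*/\n#include <cstdint>\n#include <iostream>\n#include <optional>\n#include <string>\n#include <vector>\n\nstruct Bmhk {\n  int64_t b, m, h, k;\n};\n\nstd::optional<Bmhk> ParseBmhk3d(const std::vector<int64_t>& shape, const std::string& layout,\n                                int64_t num_heads) {\n  if (shape.size() != 3 || num_heads <= 0) { return std::nullopt; }\n  int64_t packed_n = 0;      // how many tensors are packed into the last dimension\n  bool batch_first = false;  // true for BM(...) layouts, false for MB(...) layouts\n  if (layout == \"BM(HK)\") {\n    packed_n = 1;\n    batch_first = true;\n  } else if (layout == \"MB(HK)\") {\n    packed_n = 1;\n  } else if (layout == \"BM(H2K)\") {\n    packed_n = 2;\n    batch_first = true;\n  } else if (layout == \"MB(H2K)\") {\n    packed_n = 2;\n  } else if (layout == \"BM(H3K)\") {\n    packed_n = 3;\n    batch_first = true;\n  } else if (layout == \"MB(H3K)\") {\n    packed_n = 3;\n  } else {\n    return std::nullopt;\n  }\n  const int64_t hidden_size = shape[2];\n  const int64_t packed_h = packed_n * num_heads;\n  // The last dimension must be a multiple of packed_n * H, as the real ParseDims checks.\n  if (hidden_size % packed_h != 0) { return std::nullopt; }\n  Bmhk out{};\n  out.b = batch_first ? shape[0] : shape[1];\n  out.m = batch_first ? shape[1] : shape[0];\n  out.h = num_heads;\n  out.k = hidden_size / packed_h;\n  return out;\n}\n\nint main() {\n  // A BM(H3K)-packed QKV tensor: batch 4, seq 128, 8 heads of head size 64 => hidden 3 * 8 * 64.\n  const auto dims = ParseBmhk3d({4, 128, 3 * 8 * 64}, \"BM(H3K)\", 8);\n  const bool ok = dims && dims->b == 4 && dims->m == 128 && dims->h == 8 && dims->k == 64;\n  std::cout << (ok ? 0 : 1) << std::endl;\n  return 0;\n}\n"
  },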
  {
    "path": "oneflow/core/functional/impl/global_cast.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/consistency_check.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/framework/id_util.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/autograd/autograd_mode.h\"\n#include \"oneflow/core/autograd/autograd_engine.h\"\n#include \"oneflow/core/framework/tensor_rpc_util.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/rank_group_scope.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/framework/transport_token.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/core/framework/placement_sbp_util.h\"\n#include \"oneflow/core/intrusive/flat_msg.h\"\n#include \"oneflow/core/common/flat_shape.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/cpp_attribute.h\"\n#include \"oneflow/core/ccl/ccl.h\"\n#include \"oneflow/core/common/constant.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n#include \"oneflow/user/kernels/collective_communication/include/broadcast.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nnamespace {\n\n// NOTE: use env variable 'ONEFLOW_EAGER_LOCAL_TO_GLOBAL_BALANCED_OVERRIDE' indicate whether the\n// shape and dtype of input tensor on each rank is the same when cast local tensor to global tensor.\n// If set true, there will be no meta-information synchronization on each rank.\nOptional<bool> ParseEagerLocalToGlobalBalancedOverride() {\n  const char* env_p = std::getenv(\"ONEFLOW_EAGER_LOCAL_TO_GLOBAL_BALANCED_OVERRIDE\");\n  if (env_p == nullptr) {\n    return Optional<bool>();\n  } else {\n    return ParseBooleanFromEnv(\"ONEFLOW_EAGER_LOCAL_TO_GLOBAL_BALANCED_OVERRIDE\", false);\n  }\n}\n\nbool NeedSyncAndCheckShapeAndDtype(bool check_meta_hint) {\n  thread_local Optional<bool> eager_local_to_global_balanced_override =\n      ParseEagerLocalToGlobalBalancedOverride();\n  if (eager_local_to_global_balanced_override.has_value()) {\n    return IsInDebugMode() || !CHECK_JUST(eager_local_to_global_balanced_override);\n  } else {\n    return IsInDebugMode() || check_meta_hint;\n  }\n}\n\n// clang-format off\nFLAT_MSG_BEGIN(FlatShapeAndDataType);\n  // Methods\n  static Maybe<FlatShapeAndDataType> New() 
{\n    const auto& flat_shape_dtype = std::make_shared<FlatShapeAndDataType>();\n    flat_shape_dtype->clear();\n    return flat_shape_dtype;\n  }\n  static Maybe<FlatShapeAndDataType> New(const Shape& shape, DataType dtype) {\n    const auto& flat_shape_dtype = JUST(New());\n    JUST(flat_shape_dtype->mutable_shape()->Init(shape));\n    flat_shape_dtype->set_dtype(dtype);\n    return flat_shape_dtype;\n  }\n  Maybe<void> Check(const Shape& shape, DataType dtype) const {\n    JUST(this->shape().Check(shape));\n    CHECK_EQ_OR_RETURN(this->dtype(), dtype) << Error::RuntimeError()\n        << \"Expected all tensors on each rank to have the same dtype, but found \"\n            \"at least two dtypes, \" << DType(this->dtype()).name() << \" and \"\n        << DType(dtype).name() << \"!\";\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> Check(const FlatShapeAndDataType& flat_shape_dtype) const {\n    JUST(this->shape().Check(flat_shape_dtype.shape()));\n    CHECK_EQ_OR_RETURN(this->dtype(), flat_shape_dtype.dtype())\n        << Error::RuntimeError()\n        << \"Expected inputs on each rank to have the same dtype, but got at least two dtypes, \"\n        << DType(this->dtype()).name() << \" and \" << DType(flat_shape_dtype.dtype()).name();\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> ToShape(Shape* shape) const { return this->shape().ToShape(shape); }\n  Maybe<Shape> ToShape() const { return shape().ToShape(); }\n  int64_t At(int i) const { return shape().At(i); }\n  int64_t NumAxes() const { return shape().NumAxes(); }\n\n private:\n  // Fields\n  FLAT_MSG_DEFINE_OPTIONAL(FlatShape, shape);\n  FLAT_MSG_DEFINE_OPTIONAL(DataType, dtype);\nFLAT_MSG_END(FlatShapeAndDataType);\n// clang-format on\n\nMaybe<void> ShapeAndDataTypeConsistencyCheck(const Symbol<ParallelDesc>& placement,\n                                             const Shape& shape, DataType dtype) {\n  if (!placement->containing_current_rank() || placement->parallel_num() == 1) {\n    return Maybe<void>::Ok();\n  }\n\n  const auto& transport_token =\n      JUST(TransportToken::NewTransportToken(kTransportTokenTypeSyncLocalShapeDtype));\n  const auto& send_buffer = JUST(FlatShapeAndDataType::New(shape, dtype));\n  const auto& recv_buffer = JUST(FlatShapeAndDataType::New());\n  recv_buffer->clear();\n\n  NaiveAsyncTransportCtx ctx(\n      transport_token,\n      [send_buffer](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n        *buffer = send_buffer.get();\n        *size = sizeof(FlatShapeAndDataType);\n        *Cb = [send_buffer] {};\n        return Maybe<void>::Ok();\n      },\n      [recv_buffer](int64_t rank, void** buffer, std::size_t* size,\n                    std::function<void()>* Cb) -> Maybe<void> {\n        *buffer = recv_buffer.get();\n        *size = sizeof(FlatShapeAndDataType);\n        *Cb = [recv_buffer] {};\n        return Maybe<void>::Ok();\n      });\n  const auto& rank_group = JUST(RankGroup::New(placement));\n  JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));\n  JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));\n  JUST_MSG(ctx.WaitDone(), kAsymmetricCodeErrorMsg);\n  JUST(send_buffer->Check(*recv_buffer));\n  return Maybe<void>::Ok();\n}\n\nMaybe<HashMap<int64_t, std::shared_ptr<FlatShapeAndDataType>>> BroadcastGatherShapeAndDataType(\n    const Shape& shape, DataType dtype, Symbol<ParallelDesc> parallel_desc) {\n  const auto& transport_token =\n      
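// the token pairs this rank's broadcast with the matching collect below across all ranks\n      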
JUST(TransportToken::NewTransportToken(kTransportTokenTypeSyncLocalShapeDtype));\n  const auto& send_buffer = JUST(FlatShapeAndDataType::New(shape, dtype));\n  const auto& map = std::make_shared<HashMap<int64_t, std::shared_ptr<FlatShapeAndDataType>>>();\n  map->emplace(GlobalProcessCtx::Rank(), send_buffer);\n  NaiveAsyncTransportCtx ctx(\n      transport_token,\n      [send_buffer](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n        *buffer = send_buffer.get();\n        *size = sizeof(FlatShapeAndDataType);\n        *Cb = [send_buffer] {};\n        return Maybe<void>::Ok();\n      },\n      [map](int64_t rank, void** buffer, std::size_t* size,\n            std::function<void()>* Cb) -> Maybe<void> {\n        const auto& recv_buffer = JUST(FlatShapeAndDataType::New());\n        recv_buffer->clear();\n        *buffer = recv_buffer.get();\n        *size = sizeof(FlatShapeAndDataType);\n        *Cb = [recv_buffer] {};\n        CHECK_OR_RETURN(map->emplace(rank, recv_buffer).second);  // NOLINT(maybe-need-error-msg)\n        return Maybe<void>::Ok();\n      });\n  const auto& rank_group = JUST(RankGroup::New(parallel_desc));\n  JUST(TransportUtil::BroadcastToOtherRanks(rank_group, rank_group, transport_token, &ctx));\n  JUST(TransportUtil::CollectFromOtherRanks(rank_group, rank_group, transport_token, &ctx));\n  JUST_MSG(ctx.WaitDone(), kAsymmetricCodeErrorMsg);\n  return map;\n}\n\nMaybe<int64_t> FindRoot(Symbol<ParallelDesc> broadcast_parallel_desc,\n                        Symbol<ParallelDesc> src_parallel_desc) {\n  for (int64_t process_id : broadcast_parallel_desc->sorted_machine_ids()) {\n    if (src_parallel_desc->ContainingMachineId(process_id)) { return process_id; }\n  }\n  UNIMPLEMENTED_THEN_RETURN();\n}\n\nauto* CachedFindRoot = DECORATE(&FindRoot, ThreadLocal);\n\nMaybe<FlatShapeAndDataType> BroadcastShapeAndDtype(const Shape& shape, DataType dtype,\n                                                   Symbol<ParallelDesc> parallel_desc) {\n  const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());\n  const auto& rank_group_parallel_desc =\n      JUST(RankGroup::GetDefaultParallelDesc(parallel_desc->device_type(), rank_group));\n  const auto& process_id2broadcast_group =\n      JUST(GetBroadcastGroup(parallel_desc, rank_group_parallel_desc));\n  const auto& broadcast_parallel_desc =\n      JUST(MapAt(*process_id2broadcast_group, GlobalProcessCtx::Rank()));\n\n  const auto& in_flat_shape_dtype = JUST(FlatShapeAndDataType::New(shape, dtype));\n  const auto& out_flat_shape_dtype = JUST(FlatShapeAndDataType::New());\n  int64_t root = JUST(CachedFindRoot(broadcast_parallel_desc, parallel_desc));\n  const auto& transport_token =\n      JUST(TransportToken::NewTransportToken(kTransportTokenTypeSyncLocalShapeDtype));\n  JUST(ccl::CpuBroadcast(in_flat_shape_dtype.get(), out_flat_shape_dtype.get(),\n                         sizeof(FlatShapeAndDataType), root, broadcast_parallel_desc,\n                         transport_token));\n  return out_flat_shape_dtype;\n}\n\nMaybe<void> GetConcatenatedShapeAndCheckDtype(\n    Shape* logical_shape, DataType* dtype,\n    const HashMap<int64_t, std::shared_ptr<FlatShapeAndDataType>>& rank2flat_shape_dtype,\n    Symbol<ParallelDesc> parallel_desc, Symbol<NdSbp> nd_sbp) {\n  *dtype = rank2flat_shape_dtype.begin()->second->dtype();\n  HashMap<int64_t, std::shared_ptr<Shape>> rank2logical_shape;\n  for (const auto& pair : rank2flat_shape_dtype) {\n    rank2logical_shape.emplace(pair.first, 
JUST(pair.second->ToShape()));\n    CHECK_EQ_OR_RETURN(*dtype, pair.second->dtype())\n        << Error::RuntimeError()\n        << \"Expected all tensors on each rank to have the same dtype, but found \"\n           \"at least two dtypes, \"\n        << DType(*dtype).name() << \"(rank \" << rank2flat_shape_dtype.begin()->first << \") and \"\n        << DType(pair.second->dtype()).name() << \"(rank \" << pair.first << \")!\";\n  }\n  const auto& GetRankPhyShapeByParallelId = [&](Symbol<ParallelDesc> parallel_desc,\n                                                int64_t parallel_id) -> Maybe<Shape> {\n    int64_t machine_id = JUST(parallel_desc->MachineId4ParallelId(parallel_id));\n    return JUST(MapAt(rank2logical_shape, machine_id));\n  };\n  const auto& parallel_hierarchy = parallel_desc->hierarchy();\n  Stride parallel_stride(*parallel_hierarchy);\n  for (int32_t i = nd_sbp->sbp_parallel_size() - 1; i >= 0; --i) {\n    if (nd_sbp->sbp_parallel(i).has_split_parallel()) {\n      int64_t concat_axis = nd_sbp->sbp_parallel(i).split_parallel().axis();\n      int64_t group_size = parallel_hierarchy->Count(0, i);\n      int64_t stride = parallel_stride.at(i);\n      for (int group_id = 0; group_id < group_size; ++group_id) {\n        int64_t parallel_num_in_group = parallel_hierarchy->At(i);\n        for (int64_t stride_id = 0; stride_id < stride; ++stride_id) {\n          ParallelConf parallel_conf;\n          parallel_conf.set_device_tag(parallel_desc->device_tag());\n          int64_t start_parallel_id = group_id * parallel_num_in_group + stride_id;\n          for (int64_t parallel_id_in_group = 0; parallel_id_in_group < parallel_num_in_group;\n               ++parallel_id_in_group) {\n            int64_t id = start_parallel_id + parallel_id_in_group * stride;\n            int64_t machine_id = JUST(parallel_desc->MachineId4ParallelId(id));\n            int64_t device_id = JUST(parallel_desc->DeviceId4ParallelId(id));\n            parallel_conf.add_device_name(std::string(\"@\") + std::to_string(machine_id) + \":\"\n                                          + std::to_string(device_id));\n          }\n          Symbol<ParallelDesc> sub_parallel_desc = SymbolOf(ParallelDesc(parallel_conf));\n          std::shared_ptr<Shape> first_shape =\n              JUST(GetRankPhyShapeByParallelId(sub_parallel_desc, 0));\n          CHECK_GE_OR_RETURN(concat_axis, 0)\n              << Error::RuntimeError() << \"Split axis must not be negative, but got \" << concat_axis\n              << \"!\";\n          CHECK_LT_OR_RETURN(concat_axis, first_shape->NumAxes())\n              << Error::RuntimeError() << \"Split axis out of range (expected to be in range of [\"\n              << 0 << \", \" << first_shape->NumAxes() << \"), but got \" << concat_axis << \"!)\";\n\n          int64_t logical_concat_dim = first_shape->At(concat_axis);\n          for (int parallel_id = 1; parallel_id < sub_parallel_desc->parallel_num();\n               ++parallel_id) {\n            const auto& rank_shape =\n                JUST(GetRankPhyShapeByParallelId(sub_parallel_desc, parallel_id));\n            CHECK_EQ_OR_RETURN(rank_shape->NumAxes(), first_shape->NumAxes())\n                << Error::RuntimeError() << \"Sizes of tensors must match except in dimension \"\n                << concat_axis << \", but found \" << first_shape->ToString() << \"(rank \"\n                << JUST(sub_parallel_desc->MachineId4ParallelId(0)) << \") and \"\n                << rank_shape->ToString() << \"(rank \"\n                << 
JUST(sub_parallel_desc->MachineId4ParallelId(parallel_id)) << \")!\";\n            logical_concat_dim += rank_shape->At(concat_axis);\n          }\n\n          BalancedSplitter bs(logical_concat_dim, sub_parallel_desc->parallel_num());\n          CHECK_EQ_OR_RETURN(first_shape->At(concat_axis), bs.At(0).size())\n              << Error::RuntimeError() << \"Sizes of tensors in dimension \" << concat_axis\n              << \" must be the same or match the balanced split distribution. See \"\n                 \"https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/common/\"\n                 \"balanced_splitter.h \"\n                 \"for details of balanced split\";\n          first_shape->Set(concat_axis, logical_concat_dim);\n\n          for (int parallel_id = 1; parallel_id < sub_parallel_desc->parallel_num();\n               ++parallel_id) {\n            std::shared_ptr<Shape> rank_shape =\n                JUST(GetRankPhyShapeByParallelId(sub_parallel_desc, parallel_id));\n            for (int i = 0; i < first_shape->NumAxes(); ++i) {\n              if (i == concat_axis) {\n                CHECK_EQ_OR_RETURN(rank_shape->At(i), bs.At(parallel_id).size())\n                    << Error::RuntimeError() << \"Sizes of tensors in dimension \" << concat_axis\n                    << \" must be the same or match the balanced split distribution. See \"\n                       \"https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/common/\"\n                       \"balanced_splitter.h \"\n                       \"for details of balanced split\";\n              } else {\n                CHECK_EQ_OR_RETURN(rank_shape->At(i), first_shape->At(i))\n                    << Error::RuntimeError() << \"Sizes of tensors must match except in dimension \"\n                    << concat_axis << \". 
Expected size \" << first_shape->At(i) << \" but got size \"\n                    << rank_shape->At(i) << \" for tensor on rank \"\n                    << JUST(sub_parallel_desc->MachineId4ParallelId(parallel_id)) << \"!\";\n              }\n            }\n            rank_shape->Set(concat_axis, logical_concat_dim);\n          }\n        }\n      }\n    }\n  }\n  *logical_shape = *JUST(GetRankPhyShapeByParallelId(parallel_desc, 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetLogicalShapeAndDataType(Shape* logical_shape, DataType* /* in and out */ dtype,\n                                       std::shared_ptr<const Shape> physical_shape,\n                                       Symbol<ParallelDesc> parallel_desc, Symbol<NdSbp> nd_sbp,\n                                       bool sync_and_check_meta) {\n  if (!sync_and_check_meta) {\n    *logical_shape = *JUST(GetLogicalShape(*physical_shape, *nd_sbp, *parallel_desc));\n  } else {\n    if (ContainSplitSbp(nd_sbp)) {\n      *logical_shape = *physical_shape;\n      if (parallel_desc->containing_current_rank()) {\n        const auto& rank2flat_shape_dtype =\n            JUST(BroadcastGatherShapeAndDataType(*logical_shape, *dtype, parallel_desc));\n        JUST(GetConcatenatedShapeAndCheckDtype(logical_shape, dtype, *rank2flat_shape_dtype,\n                                               parallel_desc, nd_sbp));\n      }\n    } else {\n      *logical_shape = *physical_shape;\n      JUST(ShapeAndDataTypeConsistencyCheck(parallel_desc, *logical_shape, *dtype));\n    }\n  }\n  if (JUST(RankGroup::New(parallel_desc)) != JUST(RankGroupScope::CurrentRankGroup())) {\n    const auto& flat_shape_dtype =\n        JUST(BroadcastShapeAndDtype(*logical_shape, *dtype, parallel_desc));\n    *logical_shape = *JUST(flat_shape_dtype->ToShape());\n    *dtype = flat_shape_dtype->dtype();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckNdSbpValid(Symbol<NdSbp> nd_sbp, const Shape& logical_shape) {\n  for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) {\n    const auto& sbp_parallel = nd_sbp->sbp_parallel(i);\n    if (sbp_parallel.has_split_parallel()) {\n      CHECK_LT_OR_RETURN(sbp_parallel.split_parallel().axis(), logical_shape.NumAxes())\n          << Error::RuntimeError() << \"Split axis out of range (expected to be in range of [\" << 0\n          << \", \" << logical_shape.NumAxes() << \"), but got \"\n          << sbp_parallel.split_parallel().axis() << \"!)\";\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<one::OpExpr> RawGetGlobalToGlobalOpExpr(\n    const std::vector<Symbol<SbpParallel>>& grad_sbp_parallels) {\n  Optional<Symbol<NdSbp>> grad_nd_sbp;\n  if (!grad_sbp_parallels.empty()) { grad_nd_sbp = JUST(GetNdSbp(grad_sbp_parallels)); }\n  std::shared_ptr<one::OpExpr> op_expr = JUST(one::GlobalToGlobalOpExpr::New(grad_nd_sbp));\n  return op_expr;\n}\n\n}  // namespace\n\nstatic constexpr auto* GetGlobalToGlobalOpExpr =\n    DECORATE(&RawGetGlobalToGlobalOpExpr, ThreadLocalCopiable);\n\nMaybe<Tensor> GlobalToGlobal(const std::shared_ptr<Tensor>& x, Symbol<ParallelDesc> parallel_desc,\n                             const std::vector<Symbol<SbpParallel>>& sbp_parallels,\n                             const std::vector<Symbol<SbpParallel>>& grad_sbp_parallels,\n                             bool copy) {\n  const auto& global_tensor = JUST(x->AsGlobalTensor());\n  CHECK_NOTNULL_OR_RETURN(global_tensor) << \"global tensors supported only\";\n  const auto& nd_sbp = JUST(GetNdSbp(sbp_parallels));\n  JUST(CheckNdSbpValid(nd_sbp, 
*x->shape()));\n  std::shared_ptr<one::OpExpr> op;\n  if (unlikely(!LazyMode::is_enabled()\n               && JUST(x->parallel_desc())->hierarchy()->NumAxes()\n                      != parallel_desc->hierarchy()->NumAxes()\n               && grad_sbp_parallels.size() == 0)) {\n    op = JUST(GetGlobalToGlobalOpExpr(*JUST(GetSbpList(JUST(x->nd_sbp())))));\n  } else {\n    op = JUST(GetGlobalToGlobalOpExpr(grad_sbp_parallels));\n  }\n  if (!LazyMode::is_enabled() && JUST(x->nd_sbp()) == nd_sbp\n      && JUST(x->parallel_desc()) == parallel_desc\n      && (grad_sbp_parallels.size() == 0 || !autograd::GradMode::is_enabled())) {\n    if (copy) { return functional::Identity(x); }\n    return x;\n  }\n  const auto& tensor = JUST(OpInterpUtil::Dispatch<one::Tensor>(\n      *op, {global_tensor}, OpExprInterpContext(AttrMap{}, parallel_desc, nd_sbp)));\n  if (!LazyMode::is_enabled() && tensor != x && !IsGlobalTensorMetaCheckDisabled()) {\n    const auto& input_global_id = JUST(x->transport_token());\n    const auto& output_global_id = JUST(tensor->transport_token());\n    CHECK_NE_OR_RETURN(input_global_id, output_global_id);  // NOLINT(maybe-need-error-msg)\n  }\n  return tensor;\n}\n\nMaybe<Tensor> LocalToGlobal(const std::shared_ptr<Tensor>& x, Symbol<ParallelDesc> parallel_desc,\n                            const std::vector<Symbol<SbpParallel>>& sbp_parallels,\n                            const Optional<Shape>& opt_shape, const Optional<DataType>& opt_dtype,\n                            const std::shared_ptr<OpExpr>& op, bool check_meta_hint, bool sync_data,\n                            bool copy) {\n  CHECK_OR_RETURN(!x->is_lazy())\n      << Error::RuntimeError()\n      << \"local_tensor.to_global() is not supported within nn.Graph for now\";\n  CHECK_OR_RETURN(x->is_local()) << Error::RuntimeError() << \"only local tensors are supported\";\n  std::shared_ptr<one::Tensor> input = x->contiguous();\n  // copy to the right device first if the input's device type is wrong\n  if (JUST(input->device())->type() != parallel_desc->device_tag()) {\n    VLOG(2) << \"The device_type of the input tensor is different from the placement, now copying \"\n               \"it to \"\n            << parallel_desc->device_tag();\n    input = JUST(functional::Copy(x, parallel_desc->device_tag(), GlobalProcessCtx::LocalRank(),\n                                  /*pin_memory=*/false));\n  }\n  // copy to the default device of the current rank if the input's device type is right but the\n  // tensor is not on the default device\n  bool device_mismatch = JUST(input->device())->device_id() != GlobalProcessCtx::LocalRank();\n  if (copy || device_mismatch) {\n    if (device_mismatch) {\n      VLOG(2) << \"The tensor isn't on the default device of the current rank, now copying it to \"\n              << parallel_desc->device_tag() << \": \" << GlobalProcessCtx::LocalRank();\n    }\n    input = JUST(functional::Copy(x, parallel_desc->device_tag(), GlobalProcessCtx::LocalRank(),\n                                  /*pin_memory=*/false));\n  }\n  const auto& device = JUST(input->device());\n  CHECK_EQ_OR_RETURN(device->type(), parallel_desc->device_tag())\n      << Error::UnimplementedError() << \"tensor's device type must be the same as the placement.\";\n  CHECK_EQ_OR_RETURN(device->device_id(), GlobalProcessCtx::LocalRank())\n      << Error::UnimplementedError() << \"tensor must be on the default device of the current rank.\";\n  Symbol<NdSbp> nd_sbp = JUST(GetNdSbp(sbp_parallels));\n  DataType dtype = x->dtype()->data_type();\n\n  std::shared_ptr<Shape> shape = std::make_shared<Shape>();\n  if 
(opt_shape.has_value() && opt_dtype.has_value()) {\n    shape = JUST(opt_shape);\n    dtype = JUST(opt_dtype);\n  } else {\n    bool sync_and_check_meta = NeedSyncAndCheckShapeAndDtype(check_meta_hint);\n    JUST(GetLogicalShapeAndDataType(shape.get(), &dtype, x->shape(), parallel_desc, nd_sbp,\n                                    sync_and_check_meta));\n  }\n\n  auto& attrs =\n      THREAD_CACHED_MUTABLE_ATTR_MAP(\"shape\", \"dtype\", \"sync_data\", \"inplace_when_sync_data\");\n  attrs.SetAllAttrs(*shape, dtype, sync_data, !copy);\n  const auto& output = JUST(OpInterpUtil::Dispatch<one::Tensor>(\n      *op, {input}, OpExprInterpContext(attrs, parallel_desc, nd_sbp)));\n  return output;\n}\n\n}  //  namespace\n\nclass LocalToGlobalFunctor {\n public:\n  LocalToGlobalFunctor() {\n    op_ = CHECK_JUST(one::LocalToGlobalOpExpr::New(*CHECK_JUST(UniqueStr(\"local_to_global\"))));\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           Symbol<ParallelDesc> parallel_desc,\n                           const std::vector<Symbol<SbpParallel>>& sbp_parallels,\n                           const Shape& shape, const Symbol<DType>& dtype, bool sync_data,\n                           bool copy) const {\n    JUST(CheckDeviceIdsIsValid(parallel_desc));\n    NonRecursiveMetaInfoConsistencyCheckScope no_recursive_meta_info_consistency_check_scope;\n    JUST(MetaInfoConsistencyCheck(parallel_desc, sbp_parallels, 1, /* force_check */ false));\n    DisableCheckGlobalTensorMetaScope scope{};\n    std::shared_ptr<Tensor> tensor;\n    DeviceType device_type = parallel_desc->device_type();\n    if (ccl::IsBroadcastRegistered(device_type) || !sync_data || device_type == DeviceType::kMeta) {\n      tensor = JUST(LocalToGlobal(x, parallel_desc, sbp_parallels, shape, dtype->data_type(), op_,\n                                  /* check_meta */ false, sync_data, copy));\n    } else {\n      // The device has no collective broadcast registered, and local to global may need to\n      // synchronize data (through the broadcast API), so generate the global tensor with the\n      // corresponding cpu placement first, then convert the cpu global tensor to the desired\n      // placement.\n      Symbol<ParallelDesc> cpu_parallel_desc =\n          JUST(ReplaceDeviceType(parallel_desc, DeviceType::kCPU));\n      std::shared_ptr<Tensor> cpu_tensor =\n          JUST(LocalToGlobal(x, cpu_parallel_desc, sbp_parallels, shape, dtype->data_type(), op_,\n                             /* check_meta */ false, sync_data, copy));\n      tensor =\n          JUST(GlobalToGlobal(cpu_tensor, parallel_desc, sbp_parallels, GetNoneSbpList(), copy));\n    }\n    return tensor;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ToGlobalFunctor {\n public:\n  ToGlobalFunctor() {\n    local_to_global_op_ =\n        CHECK_JUST(one::LocalToGlobalOpExpr::New(*CHECK_JUST(UniqueStr(\"local_to_global\"))));\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           Symbol<ParallelDesc> parallel_desc,\n                           const std::vector<Symbol<SbpParallel>>& sbp_parallels,\n                           const std::vector<Symbol<SbpParallel>>& grad_sbp_parallels,\n                           bool check_meta, bool copy) const {\n    JUST(CheckDeviceIdsIsValid(parallel_desc));\n    NonRecursiveMetaInfoConsistencyCheckScope scope;\n    JUST(MetaInfoConsistencyCheck(parallel_desc, 
sbp_parallels, grad_sbp_parallels, 1,\n                                  /* force_check */ check_meta));\n    std::shared_ptr<Tensor> tensor;\n    if (x->is_global()) {\n      tensor = JUST(GlobalToGlobal(x, parallel_desc, sbp_parallels, grad_sbp_parallels, copy));\n    } else {\n      DeviceType device_type = parallel_desc->device_type();\n      if (ccl::IsBroadcastRegistered(device_type)) {\n        tensor = JUST(LocalToGlobal(x, parallel_desc, sbp_parallels, NullOpt, NullOpt,\n                                    local_to_global_op_, check_meta, /* sync_data */ true, copy));\n      } else {\n        // The device has no collective broadcast registered, and local to global may need to\n        // synchronize data (through the broadcast API), so generate the global tensor with the\n        // corresponding cpu placement first, then convert the cpu global tensor to the desired\n        // placement.\n        Symbol<ParallelDesc> cpu_parallel_desc =\n            JUST(ReplaceDeviceType(parallel_desc, DeviceType::kCPU));\n        std::shared_ptr<Tensor> cpu_tensor =\n            JUST(LocalToGlobal(x, cpu_parallel_desc, sbp_parallels, NullOpt, NullOpt,\n                               local_to_global_op_, check_meta, /* sync_data */ true, copy));\n        tensor =\n            JUST(GlobalToGlobal(cpu_tensor, parallel_desc, sbp_parallels, GetNoneSbpList(), copy));\n      }\n    }\n    return tensor;\n  }\n\n private:\n  std::shared_ptr<OpExpr> local_to_global_op_;\n};\n\nclass GlobalToLocalFunctor {\n public:\n  GlobalToLocalFunctor() {\n    op_ = CHECK_JUST(one::GlobalToLocalOpExpr::New(*CHECK_JUST(UniqueStr(\"global_to_local\"))));\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, bool copy) const {\n    CHECK_OR_RETURN(!x->is_lazy())\n        << Error::RuntimeError()\n        << \"global_tensor.to_local() is not supported within nn.Graph for now\";\n    CHECK_OR_RETURN(x->is_global())\n        << Error::RuntimeError() << \"Expected global tensor for to_local but got local tensor!\";\n    const auto& local_tensor = JUST(OpInterpUtil::Dispatch<one::Tensor>(*op_, {x}));\n    if (copy) { return local_tensor->clone(); }\n    return local_tensor;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\n}  // namespace impl\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<impl::LocalToGlobalFunctor>(\"LocalToGlobal\");\n  m.add_functor<impl::ToGlobalFunctor>(\"ToGlobal\");\n  m.add_functor<impl::GlobalToLocalFunctor>(\"GlobalToLocal\");\n}\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/gradient_accumulation_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nclass GradAccRepeatFunctor {\n public:\n  GradAccRepeatFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"repeat\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& in, int32_t repeat_num) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"repeat_num\");\n    attrs.SetAllAttrs(repeat_num);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {in}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GradAccCollectFunctor {\n public:\n  GradAccCollectFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"acc\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& in, int32_t collect_num) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"max_acc_num\");\n    attrs.SetAllAttrs(collect_num);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {in}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GradAccPackFunctor {\n public:\n  GradAccPackFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"pack\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& in, int32_t pack_num) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"pack_num\");\n    attrs.SetAllAttrs(pack_num);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {in}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GradAccUnpackFunctor {\n public:\n  GradAccUnpackFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"unpack\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& in, int32_t unpack_num) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"unpack_num\");\n    attrs.SetAllAttrs(unpack_num);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {in}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\n}  // namespace impl\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<impl::GradAccRepeatFunctor>(\"GradAccRepeat\");\n  m.add_functor<impl::GradAccCollectFunctor>(\"GradAccCollect\");\n  m.add_functor<impl::GradAccPackFunctor>(\"GradAccPack\");\n  m.add_functor<impl::GradAccUnpackFunctor>(\"GradAccUnpack\");\n}\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/higher_derivative_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <functional>\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n#include \"oneflow/core/functional/impl/unary_functor.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nclass SinGradGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    auto res = sequence_function(functional::Sin)\n                   .then(functional::Negative)\n                   .then(std::bind(functional::Mul, dydx, std::placeholders::_1))\n                   .call(x);\n    return res;\n  }\n};\n\nclass CosGradGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    auto res = sequence_function(functional::Cos)\n                   .then(functional::Negative)\n                   .then(std::bind(functional::Mul, dydx, std::placeholders::_1))\n                   .call(x);\n    return res;\n  }\n};\n\nclass TanGradGradFunctor {\n public:\n  // dx = 1/cos^2(x), ddx = 2*sinx/cos^3(x) = tan_grad(x)*tan(x)*2\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    auto r = sequence_function(functional::Mul)\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   return functional::ScalarMul(Scalar(2), input);\n                 })\n                 .call(JUST(functional::Tan(x)), JUST(functional::TanGrad(x, dydx)));\n    return r;\n  }\n};\n\nclass SinhGradGradFunctor {\n public:\n  // dx = cosh(x), ddx = sinh(x) = cosh_grad(x)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::CoshGrad(x, dydx);\n  }\n};\n\nclass CoshGradGradFunctor {\n public:\n  // dx = sinh(x), ddx = cosh(x) = sinh_grad(x)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::SinhGrad(x, dydx);\n  }\n};\n\nclass TanhGradGradFunctor {\n public:\n  // dx = sech^2(x), ddx = -2*sech^2(x)*tanh(x) = dydx*tanh(x)*(-2)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n               
            const std::shared_ptr<Tensor>& dydx) const {\n    auto r = sequence_function(functional::Mul)\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   return functional::ScalarMul(Scalar(-2), input);\n                 })\n                 .call(dydx, x);\n    return r;\n  }\n};\n\nclass AsinGradGradFunctor {\n public:\n  // dx = 1/sqrt(1-x*x)=rsqrt(1-x*x), ddx = rsqrt_grad(1-x*x)*(-2x)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    auto r = sequence_function(functional::Square)\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   return functional::ScalarSub(Scalar(1), input, /*alpha=*/1.0);\n                 })\n                 .then(std::bind(functional::RsqrtGrad, std::placeholders::_1, dydx))\n                 .then(std::bind(functional::Mul, std::placeholders::_1, x))\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   return functional::ScalarMul(Scalar(-2), input);\n                 })\n                 .call(x);\n    return r;\n  }\n};\nclass AcosGradGradFunctor {\n public:\n  // dx = -1/sqrt(1-x*x)=-rsqrt(1-x*x), ddx = rsqrt_grad(1-x*x)*(2x)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    auto r = sequence_function(functional::Square)\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   return functional::ScalarSub(Scalar(1), input, /*alpha=*/1.0);\n                 })\n                 .then(std::bind(functional::RsqrtGrad, std::placeholders::_1, dydx))\n                 .then(std::bind(functional::Mul, std::placeholders::_1, x))\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   return functional::ScalarMul(Scalar(2), input);\n                 })\n                 .call(x);\n    return r;\n  }\n};\n\nclass AtanGradGradFunctor {\n public:\n  // dx = 1/(1+x*x), ddx = reci_grad(1+x*x)*(2x)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    auto r = sequence_function(functional::Square)\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   return functional::ScalarAdd(Scalar(1), input, /*alpha=*/1.0);\n                 })\n                 .then(std::bind(functional::ReciprocalGrad, std::placeholders::_1, dydx))\n                 .then(std::bind(functional::Mul, std::placeholders::_1, x))\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   return functional::ScalarMul(Scalar(2), input);\n                 })\n                 .call(x);\n    return r;\n  }\n};\n\nclass AsinhGradGradFunctor {\n public:\n  // dx = 1/sqrt(1+x*x)=rsqrt(1+x*x), ddx = rsqrt_grad(1+x*x)*(2x)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    auto r = sequence_function(functional::Square)\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   return functional::ScalarAdd(Scalar(1), input, /*alpha=*/1.0);\n                 })\n                 .then(std::bind(functional::RsqrtGrad, std::placeholders::_1, dydx))\n                 .then(std::bind(functional::Mul, std::placeholders::_1, x))\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   
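// the final factor of 2 combines with the x multiplied in above: d(1+x*x)/dx = 2x\n                   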
return functional::ScalarMul(Scalar(2), input);\n                 })\n                 .call(x);\n    return r;\n  }\n};\n\nclass AcoshGradGradFunctor {\n public:\n  // dx = 1/sqrt(x*x-1)=rsqrt(x*x-1), ddx = rsqrt_grad(x*x-1)*(2x)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    auto r = sequence_function(functional::Square)\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   return functional::ScalarSub(input, Scalar(1), /*alpha=*/1.0,\n                                                /*inplace=*/false);\n                 })\n                 .then(std::bind(functional::RsqrtGrad, std::placeholders::_1, dydx))\n                 .then(std::bind(functional::Mul, std::placeholders::_1, x))\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   return functional::ScalarMul(Scalar(2), input);\n                 })\n                 .call(x);\n\n    return r;\n  }\n};\n\nclass AtanhGradGradFunctor {\n public:\n  // dx = 1/(1-x*x), ddx = reci_grad(1-x*x)*(-2x)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    auto r = sequence_function(functional::Square)\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   return functional::ScalarSub(Scalar(1), input, /*alpha=*/1.0);\n                 })\n                 .then(std::bind(functional::ReciprocalGrad, std::placeholders::_1, dydx))\n                 .then(std::bind(functional::Mul, std::placeholders::_1, x))\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   return functional::ScalarMul(Scalar(-2), input);\n                 })\n                 .call(x);\n    return r;\n  }\n};\n\nclass ErfGradGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::ScalarMul(Scalar(-2),\n                                 JUST(functional::Mul(x, JUST(functional::ErfGrad(x, dydx)))));\n  }\n};\n\nclass ErfcGradGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::ScalarMul(Scalar(-2),\n                                 JUST(functional::Mul(x, JUST(functional::ErfcGrad(x, dydx)))));\n  }\n};\n\nclass ExpGradGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::ExpGrad(x, dydx);\n  }\n};\n\nclass Exp2GradGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::ScalarMul(Scalar(std::log(2)), JUST(functional::Exp2Grad(x, dydx)));\n  }\n};\n\nclass Expm1GradGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::ExpGrad(x, dydx);\n  }\n};\n\nclass LogGradGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::ReciprocalGrad(x, dydx);\n  }\n};\n\nclass Log2GradGradFunctor {\n public:\n  // dx = 1/(x*ln2), ddx = 
1/ln2 * -1/(x*x)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::ScalarMul(Scalar(1.0 / std::log(2.0f)),\n                                 JUST(functional::ReciprocalGrad(x, dydx)));\n  }\n};\n\nclass Log10GradGradFunctor {\n public:\n  // dx = 1/(x*ln10), ddx = 1/ln10 * -1/(x*x)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::ScalarMul(Scalar(1.0 / std::log(10.0f)),\n                                 JUST(functional::ReciprocalGrad(x, dydx)));\n  }\n};\n\nclass Log1pGradGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::ReciprocalGrad(\n        JUST(functional::ScalarAdd(Scalar(1), x, /*alpha=*/Scalar(1))), dydx);\n  }\n};\n\nclass LogSigmoidGradGradFunctor {\n public:\n  // dx = exp(-x)/(1+exp(-x)), ddx = -exp(-x)/(1+exp(-x))^2 = -sigmoid_grad(x)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::Negative(JUST(functional::SigmoidGrad(JUST(functional::Sigmoid(x)), dydx)));\n  }\n};\n\nclass ReciprocalGradGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::Negative(JUST(functional::ScalarPowGrad(x, dydx, Scalar(-2))));\n  }\n};\n\nclass ReciprocalNoNanGradGradFunctor {\n public:\n  // dx = -pow(x,-2), ddx = -pow_grad(x,-2)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::Negative(JUST(functional::ScalarPowGrad(x, dydx, Scalar(-2))));\n  }\n};\n\nclass RsqrtGradGradFunctor {\n public:\n  // dx = -0.5*pow(x,-1.5), ddx = -0.5*pow_grad(x,-1.5)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::ScalarMul(Scalar(-0.5),\n                                 JUST(functional::ScalarPowGrad(x, dydx, Scalar(-1.5))));\n  }\n};\n\nclass SqrtGradGradFunctor {\n public:\n  // dx = 0.5*pow(x,-0.5), ddx = -0.25*pow(x,-1.5) = 0.5*rsqrt_grad(x)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::ScalarMul(Scalar(0.5), JUST(functional::RsqrtGrad(x, dydx)));\n  }\n};\n\nclass SquareGradGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::ScalarMul(2, dydx);\n  }\n};\n\nclass SigmoidGradGradFunctor {\n public:\n  // dy = y * (1 - y), ddy = 1 - 2*y\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& y,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::Mul(JUST(functional::ScalarSub(1, y, /*alpha=*/2)), dydx);\n  }\n};\n\nclass SiluGradGradFunctor {\n public:\n  // y     = x ∗ sigmoid(x)\n  // y'    = (sig(x) + x * sig_grad(x))\n  // y''   = (sig(x) + x*sig_grad(x))' = sig_grad(x)*(x+2-2*silu(x))\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    
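// chain: silu(x) -> 2 - 2*silu(x) -> x + 2 - 2*silu(x) -> sig_grad(x)*(...) -> dydx*(...)\n    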
auto res = functional::sequence_function(functional::Silu)\n                   .then([](const std::shared_ptr<Tensor>& input) {\n                     return functional::ScalarSub(Scalar(2.0), input, /*alpha=*/Scalar(2.0));\n                   })\n                   .then([&x](const std::shared_ptr<Tensor>& input) {\n                     return functional::Add(x, input, /*alpha=*/Scalar(1.0), /*inplace=*/false);\n                   })\n                   // SigmoidGrad takes y = sigmoid(x) as its first input, so compute sigmoid(x)\n                   // from x first.\n                   // TODO(zzk): Implement SigmoidGradXDy func.\n                   .then(std::bind(functional::SigmoidGrad, JUST(functional::Sigmoid(x)),\n                                   std::placeholders::_1))\n                   .then(std::bind(functional::Mul, dydx, std::placeholders::_1))\n                   .call(x);\n    return res;\n  }\n};\n\nclass SeluGradGradFunctor {\n public:\n  // y'' = scale * alpha * exp(x) (x < 0)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    auto condition = JUST(functional::ScalarLogicalLess(x, Scalar(0.0)));\n    auto res = functional::Where(condition, JUST(functional::SeluGrad(dydx, x)),\n                                 JUST(functional::ZerosLike(x)));\n    return res;\n  }\n};\n\nclass SoftSignGradGradFunctor {\n public:\n  // y = x/(1+abs(x)), y' = 1/(1+abs(x))^2, y'' = -2/(1+abs(x))^3*abs_grad(x)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    auto res = functional::sequence_function(functional::Abs)\n                   .then([](const std::shared_ptr<Tensor>& input) {\n                     return functional::ScalarAdd(Scalar(1.0), input, /*alpha=*/Scalar(1));\n                   })\n                   .then([](const std::shared_ptr<Tensor>& input) {\n                     return functional::ScalarPow(input, Scalar(-3), /*inplace=*/false);\n                   })\n                   .then([](const std::shared_ptr<Tensor>& input) {\n                     return functional::ScalarMul(Scalar(-2), input);\n                   })\n                   .then(std::bind(functional::AbsGrad, x, std::placeholders::_1))\n                   .then(std::bind(functional::Mul, dydx, std::placeholders::_1))\n                   .call(x);\n    return res;\n  }\n};\n\nclass HardSigmoidGradGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    return functional::ZerosLike(x);\n  }\n};\n\nclass HardSwishGradGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    // condition: -3 < x < 3. The original chained comparison tested a boolean mask against\n    // -3.0 (always true), so combine the two masks with LogicalAnd instead.\n    auto condition =\n        JUST(functional::LogicalAnd(JUST(functional::ScalarLogicalLess(x, Scalar(3.0))),\n                                    JUST(functional::ScalarLogicalGreater(x, Scalar(-3.0)))));\n    return functional::Where(condition, JUST(functional::ScalarDiv(dydx, Scalar(3.0))),\n                             JUST(functional::ZerosLike(x)));\n  }\n};\n\nclass SoftplusGradGradFunctor {\n public:\n  // beta*x <= threshold:\n  // y = 1/beta*ln(1+exp(beta*x)), y' = 1/(1+exp(beta*x))*exp(beta*x)\n  // y'' = beta*exp(beta*x)/(1+exp(beta*x))^2 = beta*sig(beta*x)(1-sig(beta*x))\n  //     = beta*sig_grad(beta*x)\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, const std::shared_ptr<Tensor>& dydx,\n      
                     const double& beta, const double& threshold) const {\n    auto beta_x = JUST(functional::ScalarMul(x, beta, /*inplace=*/false));\n    auto condition = JUST(functional::ScalarLogicalLess(beta_x, Scalar(threshold)));\n    auto zero_out = JUST(functional::ZerosLike(x));\n    auto res = functional::sequence_function(functional::Sigmoid)\n                   .then(std::bind(functional::SigmoidGrad, std::placeholders::_1, dydx))\n                   .then([&beta](const std::shared_ptr<Tensor>& input) {\n                     return functional::ScalarMul(Scalar(beta), input);\n                   })\n                   .then(std::bind(functional::Where, condition, std::placeholders::_1, zero_out))\n                   .call(beta_x);\n\n    return res;\n  }\n};\n\nclass EluGradGradFunctor {\n public:\n  // y = max(0,x) + min(0,alpha∗(exp(x)−1))\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, const std::shared_ptr<Tensor>& dydx,\n                           const double& alpha) const {\n    auto condition = JUST(functional::ScalarLogicalLess(x, Scalar(0.0)));\n    return functional::Where(condition, JUST(functional::EluGrad(x, dydx, alpha)),\n                             JUST(functional::ZerosLike(x)));\n  }\n};\n\nclass CeluGradGradFunctor {\n public:\n  // y = max(0,x) + min(0,alpha∗(exp(x/alpha)−1))\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& y, const std::shared_ptr<Tensor>& dydx,\n                           const double& alpha) const {\n    auto condition = JUST(functional::ScalarLogicalLess(y, Scalar(0)));\n    auto r = functional::Where(condition, JUST(functional::ScalarDiv(dydx, alpha)),\n                               JUST(functional::ZerosLike(y)));\n    return r;\n  }\n};\n\nclass MaxPoolNdGradGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& dydx,\n                           const std::shared_ptr<Tensor>& indices, const int ndims) const {\n    if (indices->nelement()) {\n      Shape view_shape(indices->shape()->begin(), indices->shape()->end() - ndims);\n      view_shape.push_back(-1);\n      auto indices_view = JUST(functional::Reshape(indices, view_shape));\n      auto outgrad_view = JUST(functional::Reshape(dydx, view_shape));\n      return functional::sequence_function(functional::DimGather)\n          .then(std::bind(functional::Reshape, std::placeholders::_1, *indices->shape()))\n          .call(outgrad_view, -1, indices_view, /*sparse_grad=*/false);\n    } else {\n      // empty inputs, return a 0-size tensor\n      return functional::ZerosLike(indices);\n    }\n  }\n};\n\nclass MishGradGradFunctor {\n public:\n  // y = x ∗ tanh(softplus(x))\n  // ddx = grad_tsp * sig * (2 + x * (1 + (-1 - 2 * tsp) * sig)), sig equals grad_sp here\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    const auto sig = JUST(functional::Sigmoid(x));\n    const auto sp = JUST(functional::Log1p(JUST(functional::Exp(x))));\n    const auto tanh_sp = JUST(functional::Tanh(sp));\n    const auto grad_tsp = JUST(functional::TanhGrad(tanh_sp, dydx));\n\n    auto r = functional::sequence_function(functional::Tanh)\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   return functional::ScalarAdd(-1, input, /*alpha=*/-2);\n                 })\n                 .then(std::bind(functional::Mul, std::placeholders::_1, sig))\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   return 
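/* -> 1 + (-1 - 2*tsp) * sig */ 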
functional::ScalarAdd(1, input, /*alpha=*/1);\n                 })\n                 .then(std::bind(functional::Mul, std::placeholders::_1, x))\n                 .then([](const std::shared_ptr<Tensor>& input) {\n                   return functional::ScalarAdd(2, input, /*alpha=*/1);\n                 })\n                 .then(std::bind(functional::Mul, std::placeholders::_1, sig))\n                 .then(std::bind(functional::Mul, std::placeholders::_1, grad_tsp))\n                 .call(sp);\n    return r;\n  }\n};\n\nclass GeluGradGradFunctor {\n public:\n  // y = gelu(x) = 0.5 * x * (1.0 + erf(sqrt(0.5) * x));\n  // dx = 0.5 * (1.0 + erf(sqrt(0.5)*x) + x * coef * exp(-0.5*x*x)) * dy, coef = sqrt(2.0/pi)\n  // ddx = coef * dy * exp(t) * (1+t), t = -0.5*x*x\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,\n                           const std::shared_ptr<Tensor>& dydx) const {\n    const auto& tmp = JUST(functional::ScalarMul(-0.5, JUST(functional::Square(x))));\n    const auto& tmp_add_one = JUST(functional::ScalarAdd(1, tmp, 1));\n    const Scalar coef = std::sqrt(2.0 / std::acos(-1.0));\n\n    auto r = functional::sequence_function(functional::Exp)\n                 .then(std::bind(functional::Mul, std::placeholders::_1, tmp_add_one))\n                 .then(std::bind(functional::Mul, std::placeholders::_1, dydx))\n                 .then([&coef](const std::shared_ptr<Tensor>& input) {\n                   return functional::ScalarMul(coef, input);\n                 })\n                 .call(tmp);\n    return r;\n  }\n};\n}  // namespace impl\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<impl::SinGradGradFunctor>(\"SinGradGrad\");\n  m.add_functor<impl::CosGradGradFunctor>(\"CosGradGrad\");\n  m.add_functor<impl::TanGradGradFunctor>(\"TanGradGrad\");\n  m.add_functor<impl::SinhGradGradFunctor>(\"SinhGradGrad\");\n  m.add_functor<impl::CoshGradGradFunctor>(\"CoshGradGrad\");\n  m.add_functor<impl::TanhGradGradFunctor>(\"TanhGradGrad\");\n  m.add_functor<impl::AsinGradGradFunctor>(\"AsinGradGrad\");\n  m.add_functor<impl::AcosGradGradFunctor>(\"AcosGradGrad\");\n  m.add_functor<impl::AtanGradGradFunctor>(\"AtanGradGrad\");\n  m.add_functor<impl::AsinhGradGradFunctor>(\"AsinhGradGrad\");\n  m.add_functor<impl::AcoshGradGradFunctor>(\"AcoshGradGrad\");\n  m.add_functor<impl::AtanhGradGradFunctor>(\"AtanhGradGrad\");\n  m.add_functor<impl::ErfGradGradFunctor>(\"ErfGradGrad\");\n  m.add_functor<impl::ErfcGradGradFunctor>(\"ErfcGradGrad\");\n  m.add_functor<impl::ExpGradGradFunctor>(\"ExpGradGrad\");\n  m.add_functor<impl::Exp2GradGradFunctor>(\"Exp2GradGrad\");\n  m.add_functor<impl::Expm1GradGradFunctor>(\"Expm1GradGrad\");\n  m.add_functor<impl::LogGradGradFunctor>(\"LogGradGrad\");\n  m.add_functor<impl::Log2GradGradFunctor>(\"Log2GradGrad\");\n  m.add_functor<impl::Log10GradGradFunctor>(\"Log10GradGrad\");\n  m.add_functor<impl::Log1pGradGradFunctor>(\"Log1pGradGrad\");\n  m.add_functor<impl::LogSigmoidGradGradFunctor>(\"LogSigmoidGradGrad\");\n  m.add_functor<impl::ReciprocalGradGradFunctor>(\"ReciprocalGradGrad\");\n  m.add_functor<impl::ReciprocalNoNanGradGradFunctor>(\"ReciprocalNoNanGradGrad\");\n  m.add_functor<impl::RsqrtGradGradFunctor>(\"RsqrtGradGrad\");\n  m.add_functor<impl::SqrtGradGradFunctor>(\"SqrtGradGrad\");\n  m.add_functor<impl::SquareGradGradFunctor>(\"SquareGradGrad\");\n  m.add_functor<impl::SigmoidGradGradFunctor>(\"SigmoidGradGrad\");\n  m.add_functor<impl::SiluGradGradFunctor>(\"SiluGradGrad\");\n  
m.add_functor<impl::SeluGradGradFunctor>(\"SeluGradGrad\");\n  m.add_functor<impl::SoftSignGradGradFunctor>(\"SoftSignGradGrad\");\n  m.add_functor<impl::HardSigmoidGradGradFunctor>(\"HardSigmoidGradGrad\");\n  m.add_functor<impl::HardSwishGradGradFunctor>(\"HardSwishGradGrad\");\n  m.add_functor<impl::SoftplusGradGradFunctor>(\"SoftplusGradGrad\");\n  m.add_functor<impl::EluGradGradFunctor>(\"EluGradGrad\");\n  m.add_functor<impl::CeluGradGradFunctor>(\"CeluGradGrad\");\n  m.add_functor<impl::MaxPoolNdGradGradFunctor>(\"MaxPoolNdGradGrad\");\n  m.add_functor<impl::MishGradGradFunctor>(\"MishGradGrad\");\n  m.add_functor<impl::GeluGradGradFunctor>(\"GeluGradGrad\");\n}\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/linalg_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"fmt/core.h\"\n#include \"oneflow/core/common/device_type.pb.h\"\n#include \"oneflow/core/common/error.h\"\n#include \"oneflow/core/common/error.pb.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/tensor_desc.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\nnamespace linalg {\n\nclass CrossFunctor {\n public:\n  CrossFunctor() {\n    op_ = CHECK_JUST(OpBuilder(\"linalg_cross\").Input(\"input\").Input(\"other\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& other,\n                           const Optional<int64_t>& dim) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dim\");\n\n    const auto do_dispatch_base_on_device = [&attrs, this](\n                                                const std::shared_ptr<one::Tensor>& input,\n                                                const std::shared_ptr<one::Tensor>& other,\n                                                const int64_t dim) -> Maybe<Tensor> {\n      DeviceType device{};\n\n      if (input->is_global()) {\n        device = JUST(input->parallel_desc())->device_type();\n      } else {\n        device = JUST(input->device())->enum_type();\n      }\n\n      const int64_t final_dim = input->ndim() - 1;\n\n      if (device == DeviceType::kCUDA && dim != final_dim) {\n        attrs.SetAllAttrs(final_dim);\n\n        std::vector<int> perm(input->ndim(), 0);\n        for (size_t i = 0; i < perm.size(); ++i) { perm[i] = static_cast<int>(i); }\n        std::swap(perm[dim], perm[final_dim]);\n        return functional::Transpose(\n            JUST(OpInterpUtil::Dispatch<Tensor>(*op_,\n                                                {JUST(functional::Transpose(input, perm)),\n                                                 JUST(functional::Transpose(other, perm))},\n                                                attrs)),\n            perm);\n      }\n\n      attrs.SetAllAttrs(dim);\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {input, other}, attrs);\n    };\n\n    Shape shape_to_broadcast;\n    std::deque<bool> need_to_broadcast;\n\n    std::tie(shape_to_broadcast, need_to_broadcast) =\n        *JUST(InferUnifiedShapeForBroadcastingWithInfo({*input->shape(), *other->shape()}));\n    
CHECK_EQ_OR_RETURN(need_to_broadcast.size(), 2)\n        << fmt::format(\"The number of boolean values to determine if the tensor is to be broadcast \"\n                       \"should be 2, but got {}\",\n                       need_to_broadcast.size());\n    const auto new_input =\n        need_to_broadcast[0] ? JUST(functional::Expand(input, shape_to_broadcast)) : input;\n    const auto new_other =\n        need_to_broadcast[1] ? JUST(functional::Expand(other, shape_to_broadcast)) : other;\n\n    if (!dim.has_value()) {\n      return do_dispatch_base_on_device(new_input, new_other,\n                                        JUST(FindValidDim(shape_to_broadcast)));\n    }\n\n    int64_t new_dim = JUST(dim);\n    if (new_dim < 0) { new_dim += shape_to_broadcast.NumAxes(); }\n    CHECK_EQ_OR_RETURN(shape_to_broadcast.At(new_dim), 3)\n        << Error::RuntimeError()\n        << fmt::format(\"the size of the specified dimension (which is {}) is not 3.\", JUST(dim));\n\n    return do_dispatch_base_on_device(new_input, new_other, new_dim);\n  }\n\n private:\n  Maybe<int64_t> FindValidDim(const Shape& shape) const {\n    int64_t valid_dim = -1;\n    const auto& dim_vec = shape.dim_vec();\n    for (size_t i = 0; i < dim_vec.size(); ++i) {\n      if (dim_vec[i] == 3) {\n        valid_dim = i;\n        break;\n      }\n    }\n    if (valid_dim == -1) { return Error::RuntimeError() << \"no dimension of size 3 in input.\"; }\n    return valid_dim;\n  }\n\n  std::shared_ptr<OpExpr> op_;\n};\n\n}  // namespace linalg\n}  // namespace impl\n\nusing namespace impl::linalg;\n\nONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<CrossFunctor>(\"LinalgCross\"); }\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/core/functional/impl/math_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/autograd/autograd_mode.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/impl/binary_functor.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/functional/tensor_processor.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\nnamespace impl {\n\nclass AddNFunctor {\n public:\n  AddNFunctor() {\n    op_.resize(kMaxInputCount /*the maximum number of inputs*/);\n    for (int n = 1; n < op_.size(); ++n) {\n      op_[n] = CHECK_JUST(one::OpBuilder(\"add_n\").Input(\"in\", n + 1).Output(\"out\").Build());\n    }\n  }\n  Maybe<Tensor> operator()(const TensorTuple& inputs, bool inplace) const {\n    CHECK_GE_OR_RETURN(inputs.size(), 2);\n    TensorTuple outputs;\n    for (int i = 0; i < inputs.size(); i += kMaxInputCount) {\n      size_t size = (i + kMaxInputCount) < inputs.size() ? 
kMaxInputCount : inputs.size() - i;\n      TensorTuple partial_inputs(size);\n      std::copy(inputs.begin() + i, inputs.begin() + i + size, partial_inputs.begin());\n      if (i == 0 && inplace) {\n        JUST(CheckInplaceValid(partial_inputs.at(0)));\n        std::shared_ptr<TensorTuple> outs = std::make_shared<TensorTuple>(1);\n        (*outs)[0] = partial_inputs[0];\n        JUST(OpInterpUtil::Dispatch(*op_.at(size - 1), partial_inputs, outs.get()));\n        outputs.emplace_back((*outs)[0]);\n      } else {\n        outputs.emplace_back(\n            JUST(OpInterpUtil::Dispatch<Tensor>(*op_.at(size - 1), partial_inputs)));\n      }\n    }\n    if (outputs.size() == 1) { return outputs.at(0); }\n    return this->operator()(outputs, inplace);\n  }\n\n private:\n  std::vector<std::shared_ptr<OpExpr>> op_;\n};\n\nclass ScalarMathBaseFunctor {\n public:\n  explicit ScalarMathBaseFunctor(std::string op_name) {\n    op_ = CHECK_JUST(one::OpBuilder(op_name).Input(\"in\").Output(\"out\").Build());\n  }\n  virtual ~ScalarMathBaseFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& scalar,\n                           bool inplace) const {\n    if (std::dynamic_pointer_cast<StaticZerosTensor>(x) && op_->op_type_name() == \"scalar_mul\") {\n      return x;\n    }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"float_operand\", \"has_float_operand\",\n                                                 \"int_operand\", \"has_int_operand\");\n    TensorProcessor tensor_processor;\n    Symbol<DType> lowest_dtype;\n    if (scalar.IsFloatingPoint() || scalar.IsComplex()) {\n      attrs.SetAllAttrs(scalar.As<double>(), true, NullOpt, false);\n      // Only promote type to Float32 when tensor is Int type but scalar is float type.\n      if (DType::priority_order[x->dtype()->data_type()]\n          < DType::priority_order[DType::Float16()->data_type()]) {\n        lowest_dtype = DType::Float();\n      } else {\n        lowest_dtype = x->dtype();\n      }\n    } else if (scalar.IsIntegral()) {\n      attrs.SetAllAttrs(NullOpt, false, scalar.As<int64_t>(), true);\n      // Promote type to Int64 when tensor is Bool type but scalar is int type.\n      // Promote type to Float32 when op is scalar_div.\n      if (DType::priority_order[x->dtype()->data_type()]\n          == DType::priority_order[DType::Bool()->data_type()]) {\n        lowest_dtype = DType::Int64();\n      } else if (op_->op_type_name() == \"scalar_div\") {\n        lowest_dtype = x->dtype() == DType::Float16() ? 
DType::Float16() : DType::Float();\n      } else {\n        lowest_dtype = x->dtype();\n      }\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"The scalar in \" << op_->op_type_name()\n                                  << \" should be float or int.\";\n    }\n    JUST(tensor_processor.AddInputs({x}, lowest_dtype).Apply());\n    TensorTuple casted_vec = JUST(tensor_processor.GetInputs());\n    if (inplace) {\n      JUST(CheckInplaceCastValid(x, casted_vec[0]));\n      JUST(CheckInplaceValid(x));\n\n      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n      (*outputs)[0] = x;\n      JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), OpExprInterpContext(attrs)));\n      return outputs->at(0);\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*op_, casted_vec, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ScalarAddFunctor : public ScalarMathBaseFunctor {\n public:\n  ScalarAddFunctor() : ScalarMathBaseFunctor(/*op_name=*/\"scalar_add\") {}\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const Scalar& other,\n                           const Scalar& alpha, const bool& inplace) const {\n    if (IsIntegralDataType(input->dtype()->data_type()) && other.IsIntegral()\n        && alpha.IsFloatingPoint()) {\n      return Error::RuntimeError()\n             << \"For integral input tensors, argument alpha must not be a floating point number.\";\n    }\n    Scalar scalar;\n    if (other.IsFloatingPoint() || alpha.IsFloatingPoint()) {\n      scalar = Scalar(other.Value<double>() * alpha.Value<double>());\n    } else {\n      scalar = Scalar(other.Value<int64_t>() * alpha.Value<int64_t>());\n    }\n    return ScalarMathBaseFunctor::operator()(input, scalar, inplace);\n  }\n};\n\nclass ScalarAdd2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& input, const std::shared_ptr<one::Tensor>& other,\n                           const Scalar& alpha) const {\n    if (IsIntegralDataType(other->dtype()->data_type()) && input.IsIntegral()\n        && alpha.IsFloatingPoint()) {\n      return Error::RuntimeError()\n             << \"For integral input tensors, argument alpha must not be a floating point number.\";\n    }\n    std::shared_ptr<one::Tensor> other_;\n    if ((alpha.IsIntegral() && alpha.Value<int64_t>() == 1)\n        || (alpha.IsFloatingPoint()\n            && std::fabs(alpha.Value<double>() - 1.0) < std::numeric_limits<double>::epsilon())) {\n      other_ = other;\n    } else {\n      other_ = JUST(ScalarMul(alpha, other));\n    }\n    return ScalarAdd(other_, input, /*alpha=*/1, /*inplace=*/false);\n  }\n};\n\nclass ScalarSubFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const Scalar& scalar,\n                           const Scalar& alpha, bool inplace) const {\n    return ScalarAdd(input, Scalar(-1) * scalar, alpha, inplace);\n  }\n};\n\nclass ScalarSub2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& input,\n                           const Scalar& alpha) const {\n    return ScalarAdd(scalar, input, Scalar(-1) * alpha);\n  }\n};\n\nclass ScalarMulFunctor : public ScalarMathBaseFunctor {\n public:\n  ScalarMulFunctor() : ScalarMathBaseFunctor(/*op_name=*/\"scalar_mul\") {}\n};\n\nclass ScalarMul2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& x) const {\n    return ScalarMul(x, scalar, false);\n  }\n};\n\nclass InplaceScalarMulFunctor 
: public ScalarMathBaseFunctor {\n public:\n  InplaceScalarMulFunctor() : ScalarMathBaseFunctor(/*op_name=*/\"scalar_mul\") {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& scalar) const {\n    return ScalarMathBaseFunctor::operator()(x, scalar, true);\n  }\n};\n\nclass ScalarDivFunctor : public ScalarMathBaseFunctor {\n public:\n  ScalarDivFunctor() : ScalarMathBaseFunctor(/*op_name=*/\"scalar_div\") {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& scalar) const {\n    return ScalarMathBaseFunctor::operator()(x, scalar, false);\n  }\n};\n\nclass ScalarDivModeFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& scalar,\n                           const Optional<std::string>& rounding_mode) const {\n    std::string rmode = rounding_mode.value_or(\"\");\n    CHECK_OR_RETURN(rmode == \"\" || rmode == \"floor\" || rmode == \"trunc\")\n        << \"div expected rounding_mode to be one of None,\"\n           \" 'trunc', or 'floor' but found \"\n        << rmode;\n    std::shared_ptr<one::Tensor> ret = JUST(functional::ScalarDiv(x, scalar));\n    if (rmode == \"floor\") {\n      return JUST(functional::Floor(ret));\n\n    } else if (rmode == \"trunc\") {\n      return JUST(functional::Trunc(ret));\n    }\n\n    return ret;\n  }\n};\n\nclass ScalarDiv2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& x) const {\n    return functional::ScalarMul(JUST(functional::Reciprocal(x)), scalar, /*inplace=*/false);\n  }\n};\n\nclass ScalarDivMode2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& x,\n                           const Optional<std::string>& rounding_mode) const {\n    std::string rmode = rounding_mode.value_or(\"\");\n    CHECK_OR_RETURN(rmode == \"\" || rmode == \"floor\" || rmode == \"trunc\")\n        << \"div expected rounding_mode to be one of None,\"\n           \" 'trunc', or 'floor' but found \"\n        << rmode;\n    std::shared_ptr<one::Tensor> ret = JUST(functional::ScalarDiv(scalar, x));\n    if (rmode == \"floor\") {\n      return JUST(functional::Floor(ret));\n\n    } else if (rmode == \"trunc\") {\n      return JUST(functional::Trunc(ret));\n    }\n\n    return ret;\n  }\n};\n\nclass InplaceScalarDivFunctor : public ScalarMathBaseFunctor {\n public:\n  InplaceScalarDivFunctor() : ScalarMathBaseFunctor(/*op_name=*/\"scalar_mul\") {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& scalar) const {\n    return ScalarMathBaseFunctor::operator()(x, Scalar(1.0) / scalar, true);\n  }\n};\n\nclass ScalarPowFunctor : public ScalarMathBaseFunctor {\n public:\n  ScalarPowFunctor() : ScalarMathBaseFunctor(/*op_name=*/\"scalar_pow\") {}\n};\n\nclass ScalarPowGradFunctor {\n public:\n  ScalarPowGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"scalar_pow_grad\").Input(\"x\").Input(\"dy\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& dy, const Scalar& scalar) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"float_operand\", \"has_float_operand\",\n                                                 \"int_operand\", \"has_int_operand\");\n    if (scalar.IsFloatingPoint()) {\n      attrs.SetAllAttrs(scalar.As<double>(), true, NullOpt, false);\n    } else if (scalar.IsIntegral()) {\n      
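// Integral exponent: fill the int operand slots and leave the float slots unset.\n      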
attrs.SetAllAttrs(NullOpt, false, scalar.As<int64_t>(), true);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"The scalar in ScalarPowGrad should be float or int.\";\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ScalarReversePowFunctor : public ScalarMathBaseFunctor {\n public:\n  ScalarReversePowFunctor() : ScalarMathBaseFunctor(/*op_name=*/\"scalar_reverse_pow\") {}\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& input) const {\n    return ScalarMathBaseFunctor::operator()(input, scalar, false);\n  }\n};\n\nclass ScalarReversePowGradFunctor {\n public:\n  ScalarReversePowGradFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"scalar_reverse_pow_grad\").Input(\"x\").Input(\"dy\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& dy, const Scalar& scalar) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"float_operand\", \"has_float_operand\",\n                                                 \"int_operand\", \"has_int_operand\");\n    if (scalar.IsFloatingPoint()) {\n      attrs.SetAllAttrs(scalar.As<double>(), true, NullOpt, false);\n    } else if (scalar.IsIntegral()) {\n      attrs.SetAllAttrs(NullOpt, false, scalar.As<int64_t>(), true);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"The scalar in ScalarReversePowGrad should be float or int.\";\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ScalarFloorDivFunctor : public ScalarMathBaseFunctor {\n public:\n  ScalarFloorDivFunctor() : ScalarMathBaseFunctor(/*op_name=*/\"scalar_floordiv\") {}\n};\n\nclass ScalarTruncDivFunctor : public ScalarMathBaseFunctor {\n public:\n  ScalarTruncDivFunctor() : ScalarMathBaseFunctor(/*op_name=*/\"scalar_truncdiv\") {}\n};\n\nclass ScalarFModFunctor : public ScalarMathBaseFunctor {\n public:\n  ScalarFModFunctor() : ScalarMathBaseFunctor(/*op_name=*/\"scalar_fmod\") {}\n};\n\nclass ReduceMaxFunctor {\n public:\n  ReduceMaxFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"reduce_max\").Input(\"input_tensor\").Output(\"output_tensor\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int32_t>& axis,\n                           const bool& keepdims) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"keepdims\");\n    if (axis.empty()) {\n      std::vector<int32_t> reduce_axis(x->ndim());\n      std::iota(reduce_axis.begin(), reduce_axis.end(), 0);\n      attrs.SetAllAttrs(reduce_axis, keepdims);\n    } else {\n      attrs.SetAllAttrs(axis, keepdims);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ReduceMinFunctor {\n public:\n  ReduceMinFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"reduce_min\").Input(\"input_tensor\").Output(\"output_tensor\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int32_t>& axis,\n                           const bool& keepdims) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"keepdims\");\n    if (axis.empty()) {\n      std::vector<int32_t> reduce_axis(x->ndim());\n      std::iota(reduce_axis.begin(), reduce_axis.end(), 0);\n      attrs.SetAllAttrs(reduce_axis, keepdims);\n    
} else {\n      attrs.SetAllAttrs(axis, keepdims);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass MaxFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    std::vector<int32_t> axis(x->ndim());\n    std::iota(axis.begin(), axis.end(), 0);\n    return ReduceMax(x, axis, /*keepdims=*/false);\n  }\n};\n\nclass Max2Functor {\n public:\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x, const int32_t& dim,\n                                const bool& keepdims) const {\n    auto outputs = std::make_shared<TensorTuple>(2);\n    int32_t axis = dim;\n    axis = JUST(maybe_wrap_dim(axis, x->ndim()));\n    (*outputs)[0] = JUST(ReduceMax(x, {axis}, keepdims));\n    (*outputs)[1] = JUST(ArgMax(x, dim, keepdims, NullOpt));\n    return outputs;\n  }\n};\n\nclass MinFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    std::vector<int32_t> axis(x->ndim());\n    std::iota(axis.begin(), axis.end(), 0);\n    return ReduceMin(x, axis, /*keepdims=*/false);\n  }\n};\n\nclass Min2Functor {\n public:\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x, const int32_t& dim,\n                                const bool& keepdims) const {\n    auto outputs = std::make_shared<TensorTuple>(2);\n    int32_t axis = dim;\n    axis = JUST(maybe_wrap_dim(axis, x->ndim()));\n    (*outputs)[0] = JUST(ReduceMin(x, {axis}, keepdims));\n    (*outputs)[1] = JUST(ArgMin(x, dim, keepdims, NullOpt));\n    return outputs;\n  }\n};\n\nclass AminFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const Optional<std::vector<int32_t>>& dim, const bool& keepdim) const {\n    if (!dim.has_value()) { return ReduceMin(x, {}, keepdim); }\n\n    const int32_t ndim = x->ndim();\n    std::vector<int32_t>& dims = *JUST(dim);\n    for (int i = 0; i < dims.size(); i++) { dims[i] = JUST(maybe_wrap_dim(dims[i], ndim)); }\n    return ReduceMin(x, dims, keepdim);\n  }\n};\n\nclass AmaxFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const Optional<std::vector<int32_t>>& dim, const bool& keepdim) const {\n    if (!dim.has_value()) { return ReduceMax(x, {}, keepdim); }\n\n    const int32_t ndim = x->ndim();\n    std::vector<int32_t>& dims = *JUST(dim);\n    for (int i = 0; i < dims.size(); i++) { dims[i] = JUST(maybe_wrap_dim(dims[i], ndim)); }\n    return ReduceMax(x, dims, keepdim);\n  }\n};\n\nclass ReduceSumWholeFunctor {\n public:\n  ReduceSumWholeFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"reduce_sum\").Input(\"input_tensor\").Output(\"output_tensor\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const Optional<Symbol<DType>>& dtype) const {\n    std::shared_ptr<one::Tensor> tensor = x;\n    if (dtype.has_value() && (dtype != x->dtype())) {\n      tensor = JUST(Cast(x, JUST(dtype), /*pin_memory=*/false));\n    }\n    const int32_t naxis = tensor->ndim();\n    if (naxis == 0) { return x; }  // for 0-dim Tensor\n    std::vector<int32_t> axis(naxis);\n    std::iota(axis.begin(), axis.end(), 0);\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"keepdims\");\n    attrs.SetAllAttrs(axis, false);\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.AddInputs({tensor}, /*lowest_dtype=*/DType::Int64()).Apply());\n 
   TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    return OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ReduceSumFunctor {\n public:\n  ReduceSumFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"reduce_sum\").Input(\"input_tensor\").Output(\"output_tensor\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int32_t>& axis,\n                           const bool keepdims, const Optional<Symbol<DType>>& dtype) const {\n    std::shared_ptr<one::Tensor> tensor = x;\n    if (dtype.has_value() && (dtype != x->dtype())) {\n      tensor = JUST(Cast(x, JUST(dtype), /*pin_memory=*/false));\n    }\n    std::vector<int32_t> reduce_axis = *JUST(CheckAxis(axis, x->ndim()));\n    if (reduce_axis.size() == 0) { return tensor; }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"keepdims\");\n    attrs.SetAllAttrs(reduce_axis, keepdims);\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.AddInputs({tensor}, /*lowest_dtype=*/DType::Int64()).Apply());\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    return OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ReduceNanSumFunctor {\n public:\n  ReduceNanSumFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"reduce_nansum\").Input(\"input_tensor\").Output(\"output_tensor\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int32_t>& axis,\n                           const bool& keepdims, const Optional<Symbol<DType>>& dtype) const {\n    std::shared_ptr<one::Tensor> tensor = x;\n    if (dtype.has_value() && (dtype != x->dtype())) {\n      tensor = JUST(Cast(x, JUST(dtype), /*pin_memory=*/false));\n    }\n\n    std::vector<int32_t> reduce_axis = *JUST(CheckAxis(axis, tensor->ndim()));\n    if (reduce_axis.size() == 0) { return tensor; }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"keepdims\");\n    attrs.SetAllAttrs(reduce_axis, keepdims);\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.AddInputs({tensor}, /*lowest_dtype=*/DType::Int64()).Apply());\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    return OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ReduceNanSumWholeFunctor {\n public:\n  ReduceNanSumWholeFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"reduce_nansum\").Input(\"input_tensor\").Output(\"output_tensor\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const Optional<Symbol<DType>>& dtype) const {\n    std::shared_ptr<one::Tensor> tensor = x;\n    if (dtype.has_value() && (dtype != x->dtype())) {\n      tensor = JUST(Cast(x, JUST(dtype), /*pin_memory=*/false));\n    }\n\n    const int32_t ndim = tensor->ndim();\n    if (ndim == 0) { return tensor; }  // for 0-dim Tensor\n    std::vector<int32_t> axis(ndim);\n    std::iota(axis.begin(), axis.end(), 0);\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"keepdims\");\n    attrs.SetAllAttrs(axis, false);\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.AddInputs({tensor}, /*lowest_dtype=*/DType::Int64()).Apply());\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    return OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple, 
attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ReduceAllWholeFunctor {\n public:\n  ReduceAllWholeFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"reduce_all\").Input(\"input_tensor\").Output(\"output_tensor\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    std::vector<int32_t> reduce_axis(x->ndim());\n    std::iota(reduce_axis.begin(), reduce_axis.end(), 0);\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"keepdims\");\n    attrs.SetAllAttrs(reduce_axis, false);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ReduceAllFunctor {\n public:\n  ReduceAllFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"reduce_all\").Input(\"input_tensor\").Output(\"output_tensor\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int32_t>& axis,\n                           const bool& keepdims) const {\n    std::vector<int32_t> reduce_axis = *JUST(CheckAxis(axis, x->ndim()));\n    if (reduce_axis.size() == 0) { return x; }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"keepdims\");\n    attrs.SetAllAttrs(reduce_axis, keepdims);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ReduceAnyWholeFunctor {\n public:\n  ReduceAnyWholeFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"reduce_any\").Input(\"input_tensor\").Output(\"output_tensor\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    std::vector<int32_t> reduce_axis(x->ndim());\n    std::iota(reduce_axis.begin(), reduce_axis.end(), 0);\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"keepdims\");\n    attrs.SetAllAttrs(reduce_axis, false);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ReduceAnyFunctor {\n public:\n  ReduceAnyFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"reduce_any\").Input(\"input_tensor\").Output(\"output_tensor\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int32_t>& axis,\n                           const bool& keepdims) const {\n    std::vector<int32_t> reduce_axis = *JUST(CheckAxis(axis, x->ndim()));\n    if (reduce_axis.size() == 0) { return x; }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"keepdims\");\n    attrs.SetAllAttrs(reduce_axis, keepdims);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\ntemplate<class T>\nclass ReduceDeviceStageBaseFunctor {\n public:\n  ReduceDeviceStageBaseFunctor()\n      : op_(CHECK_JUST(one::OpBuilder(T::GetOpName())\n                           .Input(\"in\")\n                           .Output(\"out\")\n                           .Output(\"mask\")\n                           .Output(\"count\")\n                           .Build())) {}\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& in,\n                                const std::vector<int32_t>& axis) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\");\n    attrs.SetAllAttrs(axis);\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {in}, attrs);\n  }\n  virtual ~ReduceDeviceStageBaseFunctor() = default;\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\ntemplate<class T>\nclass 
ReduceDeviceStageGradBaseFunctor {\n public:\n  ReduceDeviceStageGradBaseFunctor()\n      : op_(CHECK_JUST(one::OpBuilder(T::GetOpName())\n                           .Input(\"out_diff\")\n                           .Input(\"mask\")\n                           .Input(\"count\")\n                           .Output(\"in_diff\")\n                           .Build())) {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& out_diff,\n                           const std::shared_ptr<one::Tensor>& mask,\n                           const std::shared_ptr<one::Tensor>& count,\n                           const std::vector<int32_t>& axis) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\");\n    attrs.SetAllAttrs(axis);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {out_diff, mask, count}, attrs);\n  }\n  virtual ~ReduceDeviceStageGradBaseFunctor() = default;\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ReduceMinDeviceStageFunctor\n    : public ReduceDeviceStageBaseFunctor<ReduceMinDeviceStageFunctor> {\n public:\n  static std::string GetOpName() { return \"reduce_min_device_stage\"; }\n};\n\nclass ReduceMaxDeviceStageFunctor\n    : public ReduceDeviceStageBaseFunctor<ReduceMaxDeviceStageFunctor> {\n public:\n  static std::string GetOpName() { return \"reduce_max_device_stage\"; }\n};\n\nclass ReduceMinDeviceStageGradFunctor\n    : public ReduceDeviceStageGradBaseFunctor<ReduceMinDeviceStageGradFunctor> {\n public:\n  static std::string GetOpName() { return \"reduce_min_device_stage_grad\"; }\n};\n\nclass ReduceMaxDeviceStageGradFunctor\n    : public ReduceDeviceStageGradBaseFunctor<ReduceMaxDeviceStageGradFunctor> {\n public:\n  static std::string GetOpName() { return \"reduce_max_device_stage_grad\"; }\n};\n\ntemplate<class T>\nclass ReduceGlobalStageBaseFunctor {\n public:\n  ReduceGlobalStageBaseFunctor()\n      : op_(CHECK_JUST(one::OpBuilder(T::GetOpName())\n                           .Input(\"in\")\n                           .Input(\"device_count\")\n                           .Output(\"out\")\n                           .Output(\"mask\")\n                           .Build())) {}\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& in,\n                                const std::shared_ptr<one::Tensor>& device_count,\n                                const std::vector<int32_t>& axis, const bool& keepdims) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"keepdims\");\n    attrs.SetAllAttrs(axis, keepdims);\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {in, device_count}, attrs);\n  }\n  virtual ~ReduceGlobalStageBaseFunctor() = default;\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\ntemplate<class T>\nclass ReduceGlobalStageGradBaseFunctor {\n public:\n  ReduceGlobalStageGradBaseFunctor()\n      : op_(CHECK_JUST(one::OpBuilder(T::GetOpName())\n                           .Input(\"out_diff\")\n                           .Input(\"mask\")\n                           .Input(\"device_count\")\n                           .Output(\"in_diff\")\n                           .Build())) {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& out_diff,\n                           const std::shared_ptr<one::Tensor>& mask,\n                           const std::shared_ptr<one::Tensor>& device_count,\n                           const std::vector<int32_t>& axis, const bool& keepdims) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"keepdims\");\n    attrs.SetAllAttrs(axis, keepdims);\n    
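// The backward op consumes out_diff, the reduce mask and the per-device counts to produce\n    // in_diff.\n    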
return OpInterpUtil::Dispatch<Tensor>(*op_, {out_diff, mask, device_count}, attrs);\n  }\n  virtual ~ReduceGlobalStageGradBaseFunctor() = default;\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ReduceMinGlobalStageFunctor\n    : public ReduceGlobalStageBaseFunctor<ReduceMinGlobalStageFunctor> {\n public:\n  static std::string GetOpName() { return \"reduce_min_global_stage\"; }\n};\n\nclass ReduceMinGlobalStageGradFunctor\n    : public ReduceGlobalStageGradBaseFunctor<ReduceMinGlobalStageGradFunctor> {\n public:\n  static std::string GetOpName() { return \"reduce_min_global_stage_grad\"; }\n};\n\nclass ReduceMaxGlobalStageFunctor\n    : public ReduceGlobalStageBaseFunctor<ReduceMaxGlobalStageFunctor> {\n public:\n  static std::string GetOpName() { return \"reduce_max_global_stage\"; }\n};\n\nclass ReduceMaxGlobalStageGradFunctor\n    : public ReduceGlobalStageGradBaseFunctor<ReduceMaxGlobalStageGradFunctor> {\n public:\n  static std::string GetOpName() { return \"reduce_max_global_stage_grad\"; }\n};\n\nclass ReduceMeanWholeFunctor {\n public:\n  ReduceMeanWholeFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    // ReduceMean only calculate floating values.\n    CHECK_OR_RETURN(IsFloatingDataType(x->dtype()->data_type())\n                    || IsComplexDataType(x->dtype()->data_type()))\n        << \"RuntimeError: Can only calculate the mean of floating types or complex types.\";\n    size_t reduce_count = 1;\n    reduce_count = x->shape()->Count(0);\n    const auto& sum = JUST(functional::ReduceSumWhole(x, NullOpt));\n    if (reduce_count == 1 || reduce_count == 0) { return sum; }\n    return functional::ScalarMul(sum, 1.0 / reduce_count, false);\n  }\n};\n\nclass ReduceMeanFunctor {\n public:\n  ReduceMeanFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int32_t>& axis,\n                           const bool& keepdims) const {\n    // ReduceMean only calculate floating values.\n    // NOTE: Should use original reduce_mean op/kernel rather than current way(ReduceSum /\n    // reduce_count) because it could encounter precision problem(like overflow) in float16 case.\n    CHECK_OR_RETURN(IsFloatingDataType(x->dtype()->data_type()))\n        << \"RuntimeError: Can only calculate the mean of floating types.\";\n\n    const auto& sum = JUST(functional::ReduceSum(x, axis, keepdims, NullOpt));\n    size_t reduce_count = 1;\n    if (axis.empty()) {\n      reduce_count = x->shape()->Count(0);\n    } else {\n      std::vector<int32_t> reduce_axis = *JUST(CheckAxis(axis, x->ndim()));\n      for (int32_t& i : reduce_axis) { reduce_count *= x->shape()->At(i); }\n    }\n    if (reduce_count == 1 || reduce_count == 0) { return sum; }\n    return functional::ScalarMul(sum, 1.0 / reduce_count, false);\n  }\n};\n\nclass ReduceProdWholeFunctor {\n public:\n  ReduceProdWholeFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"reduce_prod\").Input(\"input_tensor\").Output(\"output_tensor\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const Optional<Symbol<DType>>& dtype) const {\n    std::shared_ptr<one::Tensor> tensor = x;\n    if (dtype.has_value() && (dtype != x->dtype())) {\n      tensor = JUST(Cast(tensor, JUST(dtype), /*pin_memory=*/false));\n    }\n    TensorProcessor tensor_processor;\n    Symbol<DType> lowest_dtype;\n    if (DType::priority_order[tensor->dtype()->data_type()]\n        == 
DType::priority_order[DType::Bool()->data_type()]) {\n      lowest_dtype = DType::Int64();\n    } else {\n      lowest_dtype = tensor->dtype();\n    }\n    JUST(tensor_processor.AddInputs({tensor}, lowest_dtype).Apply());\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    std::vector<int32_t> reduce_axis(tensor->ndim());\n    std::iota(reduce_axis.begin(), reduce_axis.end(), 0);\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"keepdims\");\n    attrs.SetAllAttrs(reduce_axis, false);\n    return JUST(OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple, attrs));\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass MedianFunctor {\n public:\n  MedianFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"median\").Input(\"input\").Output(\"output\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    if (x->shape()->elem_cnt() == 0) {\n      return functional::To(\n          JUST(functional::Constant(Shape({1}).RemoveOnes({0}),\n                                    Scalar(std::numeric_limits<float>::quiet_NaN()),\n                                    JUST(DType::Get(DataType::kFloat)), NullOpt)),\n          x, false);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass MedianWithIndicesFunctor {\n public:\n  MedianWithIndicesFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"median_with_indices\")\n                         .Input(\"input\")\n                         .Output(\"values\")\n                         .Output(\"indices\")\n                         .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x, const int32_t& dim,\n                                const bool& keepdim) const {\n    int32_t axis = dim;\n    const int64_t ndim = x->ndim();\n    axis = JUST(maybe_wrap_dim(axis, ndim));\n    std::shared_ptr<one::Tensor> tensor = x;\n    if (x->dim(axis) == 0) {\n      return Error::IndexError() << \"IndexError: Expected reduction dim \" << axis\n                                 << \" to have non-zero size.\";\n    }\n    if (axis != ndim - 1) {\n      tensor = JUST(functional::Squeeze(\n          JUST(functional::Transpose2dim(JUST(functional::Unsqueeze(x, -1)), axis, -1)),\n          std::vector<int32_t>({axis})));\n    }\n    std::shared_ptr<TensorTuple> result;\n    result = JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_, {tensor}));\n    if (keepdim) {\n      JUST(VectorAt(*result, 0)) = JUST(functional::Unsqueeze(JUST(VectorAt(*result, 0)), axis));\n      JUST(VectorAt(*result, 1)) = JUST(functional::Unsqueeze(JUST(VectorAt(*result, 1)), axis));\n    }\n    return result;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ModeFunctor {\n public:\n  ModeFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"mode\").Input(\"input\").Output(\"values\").Output(\"indices\").Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x, const int32_t& dim,\n                                const bool keepdim) const {\n    int32_t axis = dim;\n    const int64_t ndim = x->ndim();\n    axis = JUST(maybe_wrap_dim(axis, ndim));\n    std::shared_ptr<one::Tensor> tensor = x;\n    if (x->dim(axis) == 0) {\n      return Error::IndexError() << \"IndexError: Expected reduction dim \" << axis\n                                 << \" to have non-zero size.\";\n    }\n    if (axis != ndim - 1) {\n      tensor = JUST(functional::Squeeze(\n          
JUST(functional::Transpose2dim(JUST(functional::Unsqueeze(x, -1)), axis, -1)),\n          std::vector<int32_t>({axis})));\n    }\n    std::shared_ptr<TensorTuple> result;\n    result = JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_, {tensor}));\n    if (keepdim) {\n      JUST(VectorAt(*result, 0)) = JUST(functional::Unsqueeze(JUST(VectorAt(*result, 0)), axis));\n      JUST(VectorAt(*result, 1)) = JUST(functional::Unsqueeze(JUST(VectorAt(*result, 1)), axis));\n    }\n    return result;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ReduceProdFunctor {\n public:\n  ReduceProdFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"reduce_prod\").Input(\"input_tensor\").Output(\"output_tensor\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int32_t>& axis,\n                           const bool& keepdims, const Optional<Symbol<DType>>& dtype) const {\n    std::shared_ptr<one::Tensor> tensor = x;\n    if (dtype.has_value() && (dtype != x->dtype())) {\n      tensor = JUST(Cast(tensor, JUST(dtype), /*pin_memory=*/false));\n    }\n    TensorProcessor tensor_processor;\n    Symbol<DType> lowest_dtype;\n    if (DType::priority_order[tensor->dtype()->data_type()]\n        == DType::priority_order[DType::Bool()->data_type()]) {\n      lowest_dtype = DType::Int64();\n    } else {\n      lowest_dtype = tensor->dtype();\n    }\n    JUST(tensor_processor.AddInputs({tensor}, lowest_dtype).Apply());\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    std::vector<int32_t> reduce_axis = *JUST(CheckAxis(axis, x->ndim()));\n    if (reduce_axis.size() == 0) { return x; }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"keepdims\");\n    attrs.SetAllAttrs(reduce_axis, keepdims);\n    return JUST(OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple, attrs));\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass LogSumExpFunctor {\n public:\n  LogSumExpFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int32_t>& axis,\n                           const bool& keepdims) const {\n    if (x->ndim() == 0) {\n      // can't take amax of 0-dim tensor\n      return To(x, JUST(DType::Get(DataType::kFloat)), false);\n    } else if (x->nelement() == 0) {\n      // can't take amax of empty tensor\n      std::shared_ptr<one::Tensor> exp_out = JUST(Exp(x));\n      return Log(JUST(ReduceSum(exp_out, axis, keepdims, NullOpt)));\n    } else {\n      const std::shared_ptr<one::Tensor>& maxes = JUST(Amax(x, axis, true));\n      const std::shared_ptr<one::Tensor>& maxes_squeezed =\n          (keepdims ? 
maxes : JUST(SqueezeMultiple(maxes, axis)));\n      JUST(MaskedFillInplace(maxes_squeezed,\n                             JUST(ScalarLogicalEqual(JUST(Abs(maxes_squeezed)), INFINITY)), 0));\n      std::shared_ptr<one::Tensor> exp_out = JUST(Exp(JUST(Sub(x, maxes, 1, false))));\n      return Add(JUST(Log(JUST(ReduceSum(exp_out, axis, keepdims, NullOpt)))), maxes_squeezed, 1,\n                 false);\n    }\n  }\n\n private:\n  Maybe<Tensor> SqueezeMultiple(const std::shared_ptr<one::Tensor>& x,\n                                const std::vector<int32_t>& axis) const {\n    int ndims = x->ndim();\n    const auto& dims_to_squeeze = JUST(dim_list_to_bitset(axis, ndims));\n    std::shared_ptr<one::Tensor> result = x;\n    for (int i = ndims - 1; i >= 0; --i) {\n      if ((*dims_to_squeeze)[i]) {\n        std::vector<int32_t> dims = {i};\n        result = JUST(Squeeze(result, dims));\n      }\n    }\n    return result;\n  }\n};\n\nclass LogAddExpFunctor {\n public:\n  LogAddExpFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    CHECK_OR_RETURN(x->nelement() > 0 && y->nelement() > 0)\n        << \"logaddexp does not support 0-size tensors.\";\n    const std::shared_ptr<one::Tensor>& maxes = JUST(Maximum(x, y));\n    std::shared_ptr<one::Tensor> exp_out =\n        JUST(Exp(JUST(Negative(JUST(Abs(JUST(Sub(x, y, 1, false))))))));\n    std::shared_ptr<one::Tensor> add_out = JUST(ScalarAdd(1.0, exp_out, 1));\n    return Add(maxes, JUST(Log(add_out)), 1, false);\n  }\n};\n\nclass QuantileFunctor {\n public:\n  QuantileFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& q, const Optional<int64_t>& dim,\n                           const bool keepdim, const std::string& interpolation,\n                           const bool ignore_nan) const {\n    CHECK_GT_OR_RETURN(input->nelement(), 0) << \"oneflow.quantile input tensor must be non-empty\";\n    CHECK_LE_OR_RETURN(q->ndim(), 1)\n        << \"oneflow.quantile only supports `q` being a scalar or a 1D tensor.\";\n    int64_t wrapped_dim = JUST(maybe_wrap_dim(dim.value_or(0), input->ndim()));\n\n    // NOTE(hujiakui): this check is only performed when running on the CPU to avoid\n    // synchronizing an accelerator with the CPU\n    // Here q is a Tensor.\n    DeviceType input_device{};\n    if (input->is_global()) {\n      input_device = JUST(input->parallel_desc())->device_type();\n    } else {\n      input_device = JUST(input->device())->enum_type();\n    }\n    if (input_device == DeviceType::kCPU) {\n      std::shared_ptr<Tensor> condition =\n          JUST(functional::ReduceAllWhole(JUST(functional::BroadcastLogicalAnd(\n              JUST(functional::ScalarLogicalGreaterEqual(q, Scalar(0.0))),\n              JUST(functional::ScalarLogicalLessEqual(q, Scalar(1.0)))))));\n      CHECK_OR_RETURN(JUST(functional::Equal(\n          condition,\n          JUST(functional::Cast(JUST(functional::OnesLike(condition)), DType::Bool(), false)))))\n          << \"oneflow.quantile q values must be in the range [0, 1]\";\n    }\n\n    // calculate the shape of output\n    auto out_shape = quantile_output_shape(dim, input, q, keepdim, wrapped_dim);\n\n    std::shared_ptr<Tensor> sorted;\n    if (!dim.has_value()) {\n      sorted = JUST(functional::Flatten(input, 0, -1));\n      sorted = JUST(functional::Sort(sorted, -1, false))->at(0);\n    } else if (wrapped_dim == 
input->ndim() - 1) {\n      sorted = JUST(functional::Sort(input, -1, false))->at(0);\n    } else {\n      sorted = JUST(functional::Unsqueeze(input, input->ndim() - 1));\n      std::vector<int32_t> perm(sorted->ndim());\n      std::iota(perm.begin(), perm.end(), 0);\n      std::swap(perm[wrapped_dim], perm[perm.size() - 1]);\n      sorted = JUST(view::Transpose(sorted, perm));\n      sorted = JUST(functional::Sort(sorted, -1, false))->at(0);\n    }\n\n    std::vector<int64_t> in_shape(out_shape.size());\n    std::copy(out_shape.begin() + 1, out_shape.end(), in_shape.begin());\n    in_shape[in_shape.size() - 1] = sorted->dim(sorted->ndim() - 1);\n    DimVector inv(in_shape.size());\n    for (int i = 0; i < in_shape.size(); ++i) { inv[i] = in_shape[i]; }\n    const Shape step_shape(inv);\n    sorted = JUST(functional::View(sorted->contiguous(), step_shape));\n\n    CHECK_LE_OR_RETURN(sorted->dim(sorted->ndim() - 1), std::pow(2, 24))\n        << \"oneflow.quantile input tensor is too large\";\n\n    std::shared_ptr<Tensor> ranks;\n\n    if (ignore_nan) {\n      ranks = JUST(\n          functional::Mul(JUST(functional::ScalarSub(\n                              JUST(functional::ReduceSum(\n                                  JUST(functional::LogicalNot(JUST(functional::IsNan(sorted)))),\n                                  std::vector<int32_t>({static_cast<int32_t>(sorted->ndim() - 1)}),\n                                  /*keepdim=*/true, NullOpt)),\n                              Scalar(1), Scalar(1), /*inplace=*/false)),\n                          q));\n      ranks = JUST(functional::MaskedFill(\n          ranks, JUST(functional::ScalarLogicalLess(ranks, Scalar(0))), Scalar(0)));\n    } else {\n      int64_t last_index = sorted->dim(sorted->ndim() - 1) - 1;\n      std::shared_ptr<TensorTuple> tl = JUST(functional::BroadcastTensors(\n          {JUST(functional::ScalarMul(q, last_index, /*inplace=*/false)),\n           JUST(functional::ReduceAny(\n               JUST(functional::IsNan(sorted)),\n               std::vector<int32_t>({static_cast<int32_t>(sorted->ndim() - 1)}),\n               /*keepdim=*/true))}));\n      ranks = JUST(functional::MaskedFill(tl->at(0), tl->at(1), Scalar(last_index)));\n    }\n\n    if (interpolation == \"lower\") {\n      JUST(functional::Floor_(ranks));\n    } else if (interpolation == \"higher\") {\n      JUST(functional::Ceil_(ranks));\n    } else if (interpolation == \"nearest\") {\n      JUST(functional::Round_(ranks));\n    }\n\n    std::shared_ptr<Tensor> ranks_below = JUST(functional::Cast(ranks, DType::Int64(),\n                                                                /*pin_memory=*/false));\n    std::shared_ptr<Tensor> values_below =\n        JUST(functional::DimGather(sorted, sorted->ndim() - 1, ranks_below, false));\n\n    if (interpolation == \"linear\" || interpolation == \"midpoint\") {\n      std::shared_ptr<Tensor> weights = interpolation == \"midpoint\"\n                                            ? 
JUST(functional::FullLike(ranks, Scalar(0.5)))\n                                            : JUST(functional::Sub(ranks, ranks_below, Scalar(1.0),\n                                                                   /*inplace=*/false));\n      JUST(functional::Ceil_(ranks));\n      std::shared_ptr<Tensor> ranks_above =\n          JUST(functional::Cast(ranks, DType::Int64(), /*pin_memory=*/false));\n      std::shared_ptr<Tensor> values_above =\n          JUST(functional::DimGather(sorted, sorted->ndim() - 1, ranks_above, false));\n\n      values_below = JUST(functional::Lerp(values_below, values_above, weights));\n    }\n\n    values_below = JUST(view::Unsqueeze(values_below, 0));\n\n    int32_t ndim = values_below->ndim();\n    std::vector<int32_t> perm(ndim);\n    std::iota(perm.begin(), perm.end(), 0);\n    std::swap(perm[0], perm[perm.size() - 1]);\n    values_below = JUST(view::Transpose(values_below, perm));\n\n    return view::Squeeze(values_below,\n                         std::vector<int32_t>({static_cast<int32_t>(values_below->ndim() - 1)}));\n  }\n\n private:\n  static inline std::vector<int64_t> quantile_output_shape(const Optional<int64_t>& dim,\n                                                           const std::shared_ptr<Tensor>& input,\n                                                           const std::shared_ptr<Tensor>& q,\n                                                           const bool keepdim,\n                                                           int64_t wrapped_dim) {\n    // Compute output shape: q_size + reduced_size\n    std::vector<int64_t> out_shape;\n    if (dim.has_value() && input->ndim() > 0) {\n      out_shape =\n          std::vector<int64_t>(input->shape()->dim_vec().begin(), input->shape()->dim_vec().end());\n      if (keepdim) {\n        out_shape[wrapped_dim] = 1;\n      } else {\n        out_shape.erase(out_shape.begin() + wrapped_dim);\n      }\n    } else if (keepdim) {\n      out_shape = std::vector<int64_t>(input->ndim(), 1);\n    }\n    out_shape.insert(out_shape.begin(), q->nelement());\n\n    return out_shape;\n  }\n};\n\nclass ScalarQuantileFunctor {\n public:\n  ScalarQuantileFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const Scalar& q,\n                           const Optional<int64_t>& dim, const bool& keepdim,\n                           const std::string& interpolation, const bool& ignore_nan) const {\n    CHECK_GT_OR_RETURN(input->nelement(), 0) << \"oneflow.quantile input tensor must be non-empty\";\n    int64_t wrapped_dim = JUST(maybe_wrap_dim(dim.value_or(0), input->ndim()));\n    double qf = 0;\n    if (q.IsIntegral()) {\n      qf = static_cast<double>(q.As<int64_t>());\n    } else {\n      qf = q.As<double>();\n    }\n    CHECK_OR_RETURN(qf <= 1.0 && qf >= 0.0)\n        << \"oneflow.quantile q values must be in the range [0, 1]\";\n\n    // calculate the shape of output\n    auto out_shape = quantile_output_shape(dim, input, q, keepdim, wrapped_dim);\n\n    std::shared_ptr<Tensor> sorted;\n    if (!dim.has_value()) {\n      sorted = JUST(functional::Flatten(input, 0, -1));\n      sorted = JUST(functional::Sort(sorted, -1, false))->at(0);\n    } else if (wrapped_dim == input->ndim() - 1) {\n      sorted = JUST(functional::Sort(input, -1, false))->at(0);\n    } else {\n      sorted = JUST(functional::Unsqueeze(input, input->ndim() - 1));\n      std::vector<int32_t> perm(sorted->ndim());\n      std::iota(perm.begin(), perm.end(), 0);\n      std::swap(perm[wrapped_dim], 
perm[perm.size() - 1]);\n      sorted = JUST(view::Transpose(sorted, perm));\n      sorted = JUST(functional::Sort(sorted, -1, false))->at(0);\n    }\n\n    // q ==> 1-D Tensor\n    out_shape.insert(out_shape.begin(), 1);\n\n    std::vector<int64_t> in_shape(out_shape.size());\n    std::copy(out_shape.begin() + 1, out_shape.end(), in_shape.begin());\n    in_shape[in_shape.size() - 1] = sorted->dim(sorted->ndim() - 1);\n    DimVector inv(in_shape.size());\n    for (int i = 0; i < in_shape.size(); ++i) { inv[i] = in_shape[i]; }\n    const Shape step_shape(inv);\n    sorted = JUST(functional::View(sorted->contiguous(), step_shape));\n\n    CHECK_LE_OR_RETURN(sorted->dim(sorted->ndim() - 1), std::pow(2, 24))\n        << \"oneflow.quantile input tensor is too large\";\n\n    std::shared_ptr<Tensor> ranks;\n\n    if (ignore_nan) {\n      ranks = JUST(functional::ScalarMul(\n          JUST(functional::ScalarSub(\n              JUST(functional::ReduceSum(\n                  JUST(functional::LogicalNot(JUST(functional::IsNan(sorted)))),\n                  std::vector<int32_t>({static_cast<int32_t>(sorted->ndim() - 1)}),\n                  /*keepdim=*/true, NullOpt)),\n              Scalar(1), Scalar(1), /*inplace=*/false)),\n          q, /*inplace=*/false));\n      ranks = JUST(functional::MaskedFill(\n          ranks, JUST(functional::ScalarLogicalLess(ranks, Scalar(0))), Scalar(0)));\n    } else {\n      int64_t last_index = sorted->dim(sorted->ndim() - 1) - 1;\n      std::shared_ptr<Tensor> tl_index = JUST(\n          functional::ReduceAny(JUST(functional::IsNan(sorted)),\n                                std::vector<int32_t>({static_cast<int32_t>(sorted->ndim() - 1)}),\n                                /*keepdim=*/true));\n      std::shared_ptr<Tensor> tl_value;\n      if (input->is_local()) {\n        tl_value =\n            JUST(functional::Empty(*(tl_index->shape()), DType::Float(), JUST(tl_index->device()),\n                                   /*requires_grad=*/false, /*pin_memory=*/false));\n      } else {\n        tl_value = JUST(functional::GlobalEmpty(\n            *(tl_index->shape()), DType::Float(), JUST(tl_index->parallel_desc()),\n            *JUST(private_details::RawGetSbpList(JUST(tl_index->nd_sbp())))));\n      }\n      tl_value = JUST(functional::Fill(tl_value, Scalar(qf * last_index)));\n      ranks = JUST(functional::MaskedFill(tl_value, tl_index, Scalar(last_index)));\n    }\n\n    // adjust ranks based on the interpolation mode\n    if (interpolation == \"lower\") {\n      JUST(functional::Floor_(ranks));\n    } else if (interpolation == \"higher\") {\n      JUST(functional::Ceil_(ranks));\n    } else if (interpolation == \"nearest\") {\n      JUST(functional::Round_(ranks));\n    }\n\n    std::shared_ptr<Tensor> ranks_below = JUST(functional::Cast(ranks, DType::Int64(),\n                                                                /*pin_memory=*/false));\n    std::shared_ptr<Tensor> values_below =\n        JUST(functional::DimGather(sorted, sorted->ndim() - 1, ranks_below, false));\n\n    if (interpolation == \"linear\" || interpolation == \"midpoint\") {\n      std::shared_ptr<Tensor> weights = interpolation == \"midpoint\"\n                                            ? 
JUST(functional::FullLike(ranks, Scalar(0.5)))\n                                            : JUST(functional::Sub(ranks, ranks_below, Scalar(1.0),\n                                                                   /*inplace=*/false));\n      JUST(functional::Ceil_(ranks));\n      std::shared_ptr<Tensor> ranks_above =\n          JUST(functional::Cast(ranks, DType::Int64(), /*pin_memory=*/false));\n      std::shared_ptr<Tensor> values_above =\n          JUST(functional::DimGather(sorted, sorted->ndim() - 1, ranks_above, false));\n\n      values_below = JUST(functional::Lerp(values_below, values_above, weights));\n    }\n\n    return view::Squeeze(values_below,\n                         std::vector<int32_t>({static_cast<int32_t>(values_below->ndim() - 1)}));\n  }\n\n private:\n  static inline std::vector<int64_t> quantile_output_shape(const Optional<int64_t>& dim,\n                                                           const std::shared_ptr<Tensor>& input,\n                                                           const Scalar& q, const bool keepdim,\n                                                           int64_t wrapped_dim) {\n    // Compute output shape: q_size + reduced_size\n    std::vector<int64_t> out_shape;\n    if (dim.has_value() && input->ndim() > 0) {\n      out_shape =\n          std::vector<int64_t>(input->shape()->dim_vec().begin(), input->shape()->dim_vec().end());\n      if (keepdim) {\n        out_shape[wrapped_dim] = 1;\n      } else {\n        out_shape.erase(out_shape.begin() + wrapped_dim);\n      }\n    } else if (keepdim) {\n      out_shape = std::vector<int64_t>(input->ndim(), 1);\n    }\n\n    return out_shape;\n  }\n};\n\nclass TransposeFunctor {\n public:\n  TransposeFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"transpose\").Input(\"input\").Output(\"output\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::vector<int32_t>& permute) const {\n    auto ndim = input->ndim();\n    CHECK_EQ_OR_RETURN(ndim, permute.size()) << \"number of dims don't match in permute\";\n\n    // Handle negative permute values here; since permute is const,\n    // copy it to a local variable before modifying it.\n    auto positive_perm = permute;\n    for (auto i = 0; i < positive_perm.size(); i++) {\n      positive_perm[i] = JUST(maybe_wrap_dim(positive_perm[i], ndim));\n    }\n    // currently, view ops only support eager local mode\n    if (view::IsViewApplicable(input)) { return JUST(view::Transpose(input, positive_perm)); }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"perm\");\n    attrs.SetAllAttrs(positive_perm);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass Transpose2dimFunctor {\n public:\n  Transpose2dimFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"transpose\").Input(\"input\").Output(\"output\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int32_t dim0,\n                           const int32_t dim1) const {\n    const int64_t ndim = input->ndim();\n    std::vector<int32_t> permute;\n    permute.reserve(ndim);\n    int32_t dim_0 = dim0;\n    int32_t dim_1 = dim1;\n\n    dim_0 = JUST(maybe_wrap_dim(dim_0, ndim));\n    dim_1 = JUST(maybe_wrap_dim(dim_1, ndim));\n    for (int32_t i = 0; i < ndim; ++i) { permute.emplace_back(i); }\n    std::swap(permute[dim_0], permute[dim_1]);\n    if 
(view::IsViewApplicable(input)) { return JUST(view::Transpose(input, permute)); }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"perm\");\n    attrs.SetAllAttrs(permute);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass AsStridedFunctor {\n public:\n  AsStridedFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"as_strided\").Input(\"input\").Output(\"output\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::vector<int64_t>& size, const std::vector<int64_t>& stride,\n                           const int64_t& storage_offset) const {\n    CHECK_OR_RETURN(size.size() == stride.size()) << \"mismatch in length of strides and shape\";\n    for (size_t i = 0; i < size.size(); i++) {\n      CHECK_OR_RETURN(size[i] >= 0) << \"Trying to create tensor with negative dimension: \"\n                                    << size[i];\n      CHECK_OR_RETURN(stride[i] >= 0)\n          << \"as_strided: Negative strides are not supported at the moment, got strides: \"\n          << stride[i];\n    }\n    if (view::IsViewApplicable(input)) {\n      return JUST(view::AsStrided(input, size, stride, storage_offset));\n    }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"size\", \"stride\", \"storage_offset\");\n    attrs.SetAllAttrs(size, stride, storage_offset);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass AsStridedGradFunctor {\n public:\n  AsStridedGradFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"as_strided_grad\").Input(\"dy\").Input(\"input\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& input,\n                           const std::vector<int64_t>& size, const std::vector<int64_t>& stride,\n                           const int64_t& storage_offset) const {\n    if (view::IsViewApplicable(input)) {\n      return JUST(view::AsStridedGrad(dy, input, size, stride, storage_offset));\n    }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"size\", \"stride\", \"storage_offset\");\n    attrs.SetAllAttrs(size, stride, storage_offset);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, input}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass InplaceAsStridedFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::vector<int64_t>& size, const std::vector<int64_t>& stride,\n                           const int64_t& storage_offset) const {\n    JUST(CheckInplaceValid(input));\n    CHECK_OR_RETURN(size.size() == stride.size()) << \"mismatch in length of strides and shape\";\n    for (size_t i = 0; i < size.size(); i++) {\n      CHECK_OR_RETURN(size[i] >= 0) << \"Trying to create tensor with negative dimension: \"\n                                    << size[i];\n      CHECK_OR_RETURN(stride[i] >= 0)\n          << \"as_strided: Negative strides are not supported at the moment, got strides: \"\n          << stride[i];\n    }\n    CHECK_OR_RETURN(view::IsViewApplicable(input))\n        << \"as_strided_ is only supported in eager local mode\";\n    JUST(view::InplaceAsStrided(input, size, stride, storage_offset));\n    return input;\n  }\n};\n\nclass ArangeFunctor {\n public:\n  ArangeFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"arange\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const Scalar& start, const Scalar& limit, 
const Scalar& delta,\n                           const Optional<Symbol<DType>>& dtype,\n                           const Optional<Symbol<Device>>& device) const {\n    if (GlobalMode::is_enabled()) {\n      auto global_mode_guard = GlobalMode::Guard(false);\n      return JUST(functional::GlobalArange(start, limit, delta, dtype,\n                                           GetGlobalParallelDescFromDevice(device),\n                                           *JUST(GetSbpList(GlobalMode::nd_sbp()))));\n    }\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"integer_start\", \"integer_limit\", \"integer_delta\",\n                                       \"float_start\", \"float_limit\", \"float_delta\", \"dtype\");\n    if (dtype.has_value()) {\n      const DataType range_dtype = JUST(dtype)->data_type();\n      if (IsIntegralDataType(range_dtype)) {\n        attrs.SetAllAttrs(start.As<int64_t>(), limit.As<int64_t>(), delta.As<int64_t>(), NullOpt,\n                          NullOpt, NullOpt, range_dtype);\n      } else {\n        attrs.SetAllAttrs(NullOpt, NullOpt, NullOpt, start.As<double>(), limit.As<double>(),\n                          delta.As<double>(), range_dtype);\n      }\n    } else {\n      if (start.IsIntegral() && limit.IsIntegral() && delta.IsIntegral()) {\n        attrs.SetAllAttrs(start.As<int64_t>(), limit.As<int64_t>(), delta.As<int64_t>(), NullOpt,\n                          NullOpt, NullOpt, DType::Int64()->data_type());\n      } else {\n        attrs.SetAllAttrs(NullOpt, NullOpt, NullOpt, start.As<double>(), limit.As<double>(),\n                          delta.As<double>(), DType::Float()->data_type());\n      }\n    }\n    OpExprInterpContext ctx(attrs);\n    ctx.device = device;\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {}, ctx);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass Arange2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& limit, const Optional<Symbol<DType>>& dtype,\n                           const Optional<Symbol<Device>>& device) const {\n    return Arange(Scalar(0), limit, Scalar(1), dtype, device);\n  }\n};\n\nclass GlobalArangeFunctor {\n public:\n  GlobalArangeFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"arange\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const Scalar& start, const Scalar& limit, const Scalar& delta,\n                           const Optional<Symbol<DType>>& dtype,\n                           const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp_tuple) const {\n    JUST(CheckDeviceIdsIsValid(placement));\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"integer_start\", \"integer_limit\", \"integer_delta\",\n                                                 \"float_start\", \"float_limit\", \"float_delta\",\n                                                 \"dtype\", \"nd_sbp\");\n    if (dtype.has_value()) {\n      const DataType range_dtype = JUST(dtype)->data_type();\n      if (IsIntegralDataType(range_dtype)) {\n        attrs.SetAllAttrs(start.As<int64_t>(), limit.As<int64_t>(), delta.As<int64_t>(), NullOpt,\n                          NullOpt, NullOpt, range_dtype, NullOpt);\n      } else {\n        attrs.SetAllAttrs(NullOpt, NullOpt, NullOpt, start.As<double>(), limit.As<double>(),\n                          delta.As<double>(), range_dtype, NullOpt);\n      }\n    } else {\n      if (start.IsIntegral() && limit.IsIntegral() && delta.IsIntegral()) {\n        attrs.SetAllAttrs(start.As<int64_t>(), limit.As<int64_t>(), 
delta.As<int64_t>(), NullOpt,\n                          NullOpt, NullOpt, DType::Int64()->data_type(), NullOpt);\n      } else {\n        attrs.SetAllAttrs(NullOpt, NullOpt, NullOpt, start.As<double>(), limit.As<double>(),\n                          delta.As<double>(), DType::Float()->data_type(), NullOpt);\n      }\n    }\n    if (LazyMode::is_enabled()) {\n      std::vector<std::string> nd_sbp(sbp_tuple.size());\n      {\n        for (int i = 0; i < sbp_tuple.size(); ++i) {\n          nd_sbp.at(i) = SbpParallelToString(*sbp_tuple.at(i));\n        }\n      }\n      attrs.SetAttr<7>(nd_sbp);\n    }\n    const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp));\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GlobalArange2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& limit, const Optional<Symbol<DType>>& dtype,\n                           const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp_tuple) const {\n    JUST(CheckDeviceIdsIsValid(placement));\n    return GlobalArange(Scalar(0), limit, Scalar(1), dtype, placement, sbp_tuple);\n  }\n};\n\nclass HannWindowFunctor {\n public:\n  Maybe<Tensor> operator()(const int64_t window_length, const bool& periodic,\n                           const Optional<Symbol<Device>>& device,\n                           const Optional<Symbol<DType>>& dtype, const bool& requires_grad) const {\n    if (GlobalMode::is_enabled()) {\n      auto global_mode_guard = GlobalMode::Guard(false);\n      return JUST(functional::GlobalHannWindow(\n          window_length, periodic, GetGlobalParallelDescFromDevice(device),\n          *JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, requires_grad));\n    }\n    autograd::AutoGradMode mode(false);\n    if (dtype.has_value() && !IsFloatingDataType(JUST(dtype)->data_type())) {\n      return Error::RuntimeError()\n             << \"hann_window expects floating point dtypes, got: \" << JUST(dtype)->name();\n    }\n    // TODO: speedup\n    auto result = JUST(Arange(1, 2, 1, dtype, device));\n    if (window_length != 1) {\n      if (periodic) {\n        const auto indices = JUST(Arange(window_length + 1, dtype, device));\n        const auto div_result = JUST(ScalarDiv(JUST(ScalarMul(2 * M_PI, indices)), window_length));\n        result = JUST(Slice(JUST(ScalarDiv(JUST(ScalarSub(1, JUST(Cos(div_result)), 1)), 2)), {0},\n                            {window_length}, {1}, /*enable_view_slice=*/false));\n      } else {\n        const auto indices = JUST(Arange(window_length, dtype, device));\n        const auto div_result =\n            JUST(ScalarDiv(JUST(ScalarMul(2 * M_PI, indices)), window_length - 1));\n        result = JUST(ScalarDiv(JUST(ScalarSub(1, JUST(Cos(div_result)), 1)), 2));\n      }\n    }\n    JUST(result->set_requires_grad(requires_grad));\n    return result;\n  }\n};\n\nclass GlobalHannWindowFunctor {\n public:\n  Maybe<Tensor> operator()(const int64_t window_length, const bool& periodic,\n                           const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp,\n                           const Optional<Symbol<DType>>& dtype, const bool& requires_grad) const {\n    autograd::AutoGradMode mode(false);\n    JUST(CheckDeviceIdsIsValid(placement));\n    if (dtype.has_value() && !IsFloatingDataType(JUST(dtype)->data_type())) {\n      return Error::RuntimeError()\n             << 
\"hann_window expects floating point dtypes, got: \" << JUST(dtype)->name();\n    }\n    auto result = JUST(GlobalArange(1, 1 + window_length, 1, dtype, placement, sbp));\n    if (window_length != 1) {\n      if (periodic) {\n        const auto indice = JUST(GlobalArange(window_length + 8, dtype, placement, sbp));\n        const auto div_result = JUST(ScalarDiv(JUST(ScalarMul(2 * M_PI, indice)), window_length));\n        result = JUST(Slice(JUST(ScalarDiv(JUST(ScalarSub(1, JUST(Cos(div_result)), 1)), 2)), {0},\n                            {window_length}, {1}, /*enable_view_slice=*/false));\n      } else {\n        const auto indice = JUST(GlobalArange(window_length, dtype, placement, sbp));\n        const auto div_result =\n            JUST(ScalarDiv(JUST(ScalarMul(2 * M_PI, indice)), window_length - 1));\n        result = JUST(ScalarDiv(JUST(ScalarSub(1, JUST(Cos(div_result)), 1)), 2));\n      }\n    }\n    result = JUST(ToGlobal(result, placement, sbp, {}, true, /*copy=*/false));\n    JUST(result->set_requires_grad(requires_grad));\n    return result;\n  }\n};\n\nclass CastFunctor {\n public:\n  CastFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"cast\").Input(\"in\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Symbol<DType>& dtype,\n                           const bool pin_memory) const {\n    if (x->dtype() == dtype) { return x; }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dtype\", \"pin_memory\");\n    attrs.SetAllAttrs(dtype->data_type(), pin_memory);\n    // refers to pytorch's tensor.to (to_impl function at\n    // aten/src/ATen/native/TensorConversions.cpp)\n    if (JUST(IsNonOverlappingAndDense(x))) {\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {x->contiguous()}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ClampBaseFunctor {\n public:\n  ClampBaseFunctor() {\n    clip_op_ = CHECK_JUST(one::OpBuilder(\"clip_by_scalar\").Input(\"x\").Output(\"y\").Build());\n    clip_min_op_ = CHECK_JUST(one::OpBuilder(\"clip_by_scalar_min\").Input(\"x\").Output(\"y\").Build());\n    clip_max_op_ = CHECK_JUST(one::OpBuilder(\"clip_by_scalar_max\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Optional<Scalar>& min,\n                           const Optional<Scalar>& max, bool inplace) const {\n    CHECK_OR_RETURN(min.has_value() || max.has_value())\n        << \"Requires one of argument `min` and `max` at least in clip.\";\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"floating_min\", \"integral_min\", \"floating_max\",\n                                                 \"integral_max\");\n    if (IsFloatingDataType(x->dtype()->data_type())) {\n      if (min.has_value()) {\n        const auto& min_val = JUST(min);\n        attrs.SetAttr<0>(min_val->As<double>());\n        attrs.SetAttr<1>(static_cast<int64_t>(0));\n      }\n      if (max.has_value()) {\n        const auto& max_val = JUST(max);\n        attrs.SetAttr<2>(max_val->As<double>());\n        attrs.SetAttr<3>(static_cast<int64_t>(0));\n      }\n    } else if (IsIntegralDataType(x->dtype()->data_type())) {\n      if (min.has_value()) {\n        const auto& min_val = JUST(min);\n        attrs.SetAttr<0>(static_cast<double>(0));\n        attrs.SetAttr<1>(min_val->As<int64_t>());\n      }\n      if (max.has_value()) {\n        const auto& max_val = JUST(max);\n        
attrs.SetAttr<2>(static_cast<double>(0));\n        attrs.SetAttr<3>(max_val->As<int64_t>());\n      }\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"Only support floating or integral data type.\";\n    }\n    const OpExpr* op = nullptr;\n    if (!min.has_value()) {\n      op = clip_max_op_.get();\n    } else if (!max.has_value()) {\n      op = clip_min_op_.get();\n    } else {\n      op = clip_op_.get();\n    }\n    if (inplace) {\n      JUST(CheckInplaceValid(x));\n      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n      outputs->at(0) = x;\n      if (x->requires_grad()) {\n        JUST(OpInterpUtil::Dispatch(*op, {JUST(functional::Identity(x))}, outputs.get(), attrs));\n      } else {\n        JUST(OpInterpUtil::Dispatch(*op, {x}, outputs.get(), attrs));\n      }\n      return outputs->at(0);\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*op, {x}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> clip_op_;\n  std::shared_ptr<OpExpr> clip_min_op_;\n  std::shared_ptr<OpExpr> clip_max_op_;\n};\n\nclass ClampFunctor : public ClampBaseFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Optional<Scalar>& min,\n                           const Optional<Scalar>& max) const {\n    return ClampBaseFunctor::operator()(x, min, max, /* inplace=*/false);\n  }\n};\n\nclass ClampMinFunctor : public ClampBaseFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& min) const {\n    return ClampBaseFunctor::operator()(x, min, NullOpt, /* inplace=*/false);\n  }\n};\n\nclass ClampMaxFunctor : public ClampBaseFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& max) const {\n    return ClampBaseFunctor::operator()(x, NullOpt, max, /* inplace=*/false);\n  }\n};\n\nclass ClampInplaceFunctor : public ClampBaseFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Optional<Scalar>& min,\n                           const Optional<Scalar>& max) const {\n    return ClampBaseFunctor::operator()(x, min, max, /* inplace=*/true);\n  }\n};\n\nclass ClampMinInplaceFunctor : public ClampBaseFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& min) const {\n    return ClampBaseFunctor::operator()(x, min, NullOpt, /* inplace=*/true);\n  }\n};\n\nclass ClampMaxInplaceFunctor : public ClampBaseFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& max) const {\n    return ClampBaseFunctor::operator()(x, NullOpt, max, /* inplace=*/true);\n  }\n};\n\nclass ClipFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Optional<Scalar>& min,\n                           const Optional<Scalar>& max) const {\n    return Clamp(x, min, max);\n  }\n};\n\nclass ClipInplaceFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Optional<Scalar>& min,\n                           const Optional<Scalar>& max) const {\n    return ClampInplace(x, min, max);\n  }\n};\nclass SqrtSquareSumFunctor {\n public:\n  SqrtSquareSumFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"sqrt_square_sum\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, {});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass VectorNormFunctor {\n 
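// Computes the vector p-norm over the given dimensions (cf. torch.linalg.vector_norm).\n 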
public:\n  VectorNormFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& ord,\n                           const Optional<std::vector<int32_t>>& input_dim, const bool& keepdim,\n                           const Optional<Symbol<DType>>& dtype) const {\n    std::shared_ptr<one::Tensor> res;\n    Symbol<DType> dtype_val;\n    if (dtype) {\n      dtype_val = JUST(dtype);\n      if (!(dtype_val->data_type() == DataType::kFloat\n            || dtype_val->data_type() == DataType::kDouble\n            || dtype_val->data_type() == DataType::kFloat16\n            || dtype_val->data_type() == DataType::kBFloat16)) {\n        UNIMPLEMENTED_THEN_RETURN() << \"linalg.vector_norm(): only supports floating point and \"\n                                       \"complex dtypes, but got: Int.\";\n      }\n    } else {\n      if (!IsFloatingDataType(x->dtype()->data_type())) {\n        UNIMPLEMENTED_THEN_RETURN() << \"linalg.vector_norm(): only supports floating point and \"\n                                       \"complex dtypes, but got: Int.\";\n      }\n      dtype_val = x->dtype();\n    }\n    bool full_dim_flag = true;\n    std::vector<int32_t> dim;\n    if (!input_dim.has_value()) {\n      std::vector<int32_t> reduce_axis(x->ndim());\n      std::iota(reduce_axis.begin(), reduce_axis.end(), 0);\n      dim = reduce_axis;\n    } else {\n      std::vector<int32_t> dim_check;\n      dim_check = *JUST(input_dim);\n      for (int i = 0; i < dim_check.size(); ++i) {\n        if (dim_check[i] >= 0) {\n          dim.emplace_back(dim_check[i]);\n        } else {\n          dim.emplace_back(dim_check[i] + x->ndim());\n        }\n        if (dim[i] != i) { full_dim_flag = false; }\n      }\n      if ((int)dim.size() < x->ndim()) { full_dim_flag = false; }\n    }\n    if (ord.IsIntegral() || ord.IsFloatingPoint()) {\n      double ord_val = ord.As<double>();\n      if (ord_val == 0) {\n        res = JUST(ReduceSum(JUST(functional::NotEqualZero(x)), dim, keepdim, NullOpt));\n      } else if (ord_val == INFINITY) {\n        res = JUST(ReduceMax(JUST(Abs(x)), dim, keepdim));\n      } else if (ord_val == -INFINITY) {\n        res = JUST(ReduceMin(JUST(Abs(x)), dim, keepdim));\n      } else if (ord_val == 2.0 && keepdim == false && full_dim_flag\n                 && x->requires_grad() == false) {\n        res = JUST(SqrtSquareSum(x));\n      } else {\n        res = JUST(ScalarPow(\n            JUST(ReduceSum(JUST(ScalarPow(JUST(Abs(x)), ord, false)), dim, keepdim, NullOpt)),\n            Scalar(1.0) / ord, false));\n      }\n      res = JUST(Cast(res, dtype_val, /*pin_memory=*/false));\n      return res;\n    } else {\n      UNIMPLEMENTED_THEN_RETURN()\n          << \"linalg_vector_norm(): argument 'ord' must be Number, not str.\";\n    }\n  }\n};\n\nclass ScalarVectorNormFunctor {\n public:\n  ScalarVectorNormFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& ord,\n                           const Scalar& input_dim, const bool& keepdim,\n                           const Optional<Symbol<DType>>& dtype) const {\n    if (dtype) {\n      Symbol<DType> dtype_val = JUST(dtype);\n      if (!(dtype_val->data_type() == DataType::kFloat\n            || dtype_val->data_type() == DataType::kDouble\n            || dtype_val->data_type() == DataType::kFloat16\n            || dtype_val->data_type() == DataType::kBFloat16)) {\n        UNIMPLEMENTED_THEN_RETURN() << \"linalg.vector_norm(): only supports the float, double, \"\n                             
          \"cfloat and cdouble dtypes, but got: Int.\";\n      }\n    } else {\n      if (!IsFloatingDataType(x->dtype()->data_type())) {\n        UNIMPLEMENTED_THEN_RETURN() << \"linalg.vector_norm(): only supports the float, double, \"\n                                       \"cfloat and cdouble dtypes, but got: Int.\";\n      }\n    }\n    if (input_dim.IsIntegral()) {\n      std::vector<int32_t> dim(1, input_dim.As<int>());\n      return functional::VectorNorm(x, ord, dim, keepdim, dtype);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"linalg.vector_norm(): only support int dim.\";\n    }\n  }\n};\n\nclass ScalarMatrixNormFunctor {\n public:\n  ScalarMatrixNormFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& ord,\n                           const std::vector<int32_t>& input_dim, const bool& keepdim,\n                           const Optional<Symbol<DType>>& dtype) const {\n    std::shared_ptr<one::Tensor> res;\n\n    auto num_dims = x->ndim();\n    auto axis = input_dim.size();\n    CHECK_OR_RETURN(num_dims >= 2)\n        << \"linalg.matrix_norm(): input tensor must be a matrix or batch of matrices\";\n    CHECK_OR_RETURN(axis == 2 && input_dim[0] != input_dim[1])\n        << \"linalg.matrix_norm(): input_dim must be a 2-tuple of ints with different elements\";\n\n    Symbol<DType> dtype_val;\n    if (dtype) {\n      dtype_val = JUST(dtype);\n      if (!(dtype_val->data_type() == DataType::kFloat\n            || dtype_val->data_type() == DataType::kDouble\n            || dtype_val->data_type() == DataType::kFloat16\n            || dtype_val->data_type() == DataType::kBFloat16)) {\n        UNIMPLEMENTED_THEN_RETURN() << \"linalg.matrix_norm(): only supports the float, double, \"\n                                       \"cfloat and cdouble dtypes, but got: Int.\";\n      }\n    } else {\n      if (!IsFloatingDataType(x->dtype()->data_type())) {\n        UNIMPLEMENTED_THEN_RETURN() << \"linalg.matrix_norm(): only supports the float, double, \"\n                                       \"cfloat and cdouble dtypes, but got: Int.\";\n      }\n      dtype_val = x->dtype();\n    }\n    std::vector<int32_t> dim_tmp;\n    dim_tmp.reserve(axis);\n    for (int i = 0; i < axis; ++i) {\n      if (input_dim[i] >= 0) {\n        dim_tmp.emplace_back(input_dim[i]);\n      } else {\n        dim_tmp.emplace_back(input_dim[i] + num_dims);\n      }\n    }\n    std::vector<int32_t> dim(2);\n    double ord_tmp = ord.As<double>();\n    if (ord_tmp == INFINITY || ord_tmp == -INFINITY) {\n      dim = dim_tmp;\n      dim[0] = dim_tmp[1];\n      dim[1] = dim_tmp[0];\n    } else if (ord_tmp == 1 || ord_tmp == -1) {\n      dim = dim_tmp;\n    } else {\n      UNIMPLEMENTED_THEN_RETURN()\n          << \"linalg.matrix_norm(): Only support INFINITY,-INFINITY,1 or -1 data type.\";\n    }\n\n    if (dim[1] > dim[0] && keepdim == false) { dim[1] -= 1; }\n    std::vector<int32_t> dim_tmp0_vec(1, dim[0]);\n    std::vector<int32_t> dim_tmp1_vec(1, dim[1]);\n    res = JUST(ReduceSum(JUST(Abs(x)), dim_tmp0_vec, keepdim, NullOpt));\n\n    if (ord_tmp == INFINITY || ord_tmp == 1) {\n      res = JUST(ReduceMax(res, dim_tmp1_vec, keepdim));\n    } else if (ord_tmp == -INFINITY || ord_tmp == -1) {\n      res = JUST(ReduceMin(res, dim_tmp1_vec, keepdim));\n    }\n    res = JUST(Cast(res, dtype_val, /*pin_memory=*/false));\n    return res;\n  }\n};\n\nclass MatrixNormFunctor {\n public:\n  MatrixNormFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const 
std::string& ord,\n                           const std::vector<int32_t>& input_dim, const bool& keepdim,\n                           const Optional<Symbol<DType>>& dtype) const {\n    std::shared_ptr<one::Tensor> res;\n    Symbol<DType> dtype_val;\n    if (dtype) {\n      dtype_val = JUST(dtype);\n      if (!(dtype_val->data_type() == DataType::kFloat\n            || dtype_val->data_type() == DataType::kDouble\n            || dtype_val->data_type() == DataType::kFloat16\n            || dtype_val->data_type() == DataType::kBFloat16)) {\n        UNIMPLEMENTED_THEN_RETURN() << \"linalg.matrix_norm(): only supports the float, double, \"\n                                       \"cfloat and cdouble dtypes, but got: Int.\";\n      }\n    } else {\n      if (!IsFloatingDataType(x->dtype()->data_type())) {\n        UNIMPLEMENTED_THEN_RETURN() << \"linalg.matrix_norm(): only supports the float, double, \"\n                                       \"cfloat and cdouble dtypes, but got: Int.\";\n      }\n      dtype_val = x->dtype();\n    }\n    auto num_dims = x->ndim();\n    auto axis = input_dim.size();\n    std::vector<int32_t> dim_tmp(axis);\n    for (int i = 0; i < axis; ++i) {\n      if (input_dim[i] >= 0) {\n        dim_tmp[i] = input_dim[i];\n      } else {\n        dim_tmp[i] = input_dim[i] + num_dims;\n      }\n    }\n    if (ord == \"nuc\") {\n      UNIMPLEMENTED_THEN_RETURN() << \"linalg.matrix_norm(): ord == 'nuc' is not supported.\";\n    } else if (ord == \"fro\") {\n      res = JUST(Sqrt(JUST(ReduceSum(JUST(Square(x)), dim_tmp, keepdim, NullOpt))));\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"linalg.matrix_norm(): could not convert string to float: \"\n                                  << ord;\n    }\n    res = JUST(Cast(res, dtype_val, /*pin_memory=*/false));\n    return res;\n  }\n};\n\nclass NormFunctor {\n public:\n  NormFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Optional<Scalar>& ord,\n                           const Optional<std::vector<int32_t>>& input_dim, const bool& keepdim,\n                           const Optional<Symbol<DType>>& dtype, const bool& for_norm) const {\n    // If for_norm is true, this functor is used to implement oneflow.norm.\n    std::shared_ptr<one::Tensor> res;\n    if (dtype) {\n      Symbol<DType> dtype_val = JUST(dtype);\n      if (!(dtype_val->data_type() == DataType::kFloat\n            || dtype_val->data_type() == DataType::kDouble\n            || dtype_val->data_type() == DataType::kFloat16\n            || dtype_val->data_type() == DataType::kBFloat16)) {\n        UNIMPLEMENTED_THEN_RETURN() << \"linalg.norm(): only supports the float, double, cfloat and \"\n                                       \"cdouble dtypes, but got: Int.\";\n      }\n    } else {\n      if (!IsFloatingDataType(x->dtype()->data_type())) {\n        UNIMPLEMENTED_THEN_RETURN() << \"linalg.norm(): only supports the float, double, cfloat and \"\n                                       \"cdouble dtypes, but got: Int.\";\n      }\n    }\n    Scalar ord_sca;\n    bool ord_type = false;\n    if (ord.has_value()) {\n      ord_type = (*JUST(ord)).IsIntegral();\n      if (ord_type) {\n        ord_sca = Scalar((*JUST(ord)).As<double>());\n      } else {\n        ord_sca = *JUST(ord);\n      }\n    }\n    if (input_dim.has_value()) {\n      auto axis = (*JUST(input_dim)).size();\n      if (axis == 1) {\n        Scalar ord_val;\n        if (!ord.has_value()) {\n          ord_val = Scalar(2.0);\n        } else {\n          ord_val = ord_sca;\n        }\n       
 res = JUST(VectorNorm(x, ord_val, input_dim, keepdim, dtype));\n      } else if (axis > 2) {\n        res = JUST(MatrixNorm(x, ord_sca, *JUST(input_dim), keepdim, dtype));\n      } else if (axis == 2) {\n        if (!ord.has_value()) {\n          res = JUST(MatrixNorm(x, \"fro\", *JUST(input_dim), keepdim, dtype));\n        } else {\n          res = JUST(MatrixNorm(x, ord_sca, *JUST(input_dim), keepdim, dtype));\n        }\n      }\n    } else {\n      if (ord.has_value()) {\n        CHECK_OR_RETURN(x->ndim() <= 2)\n            << \"linalg.norm(): input must be 1-D or 2-D when dim is None and ord is not None\";\n        if (ord_type) {\n          const double ord_double = (*JUST(ord)).As<double>();\n          if (for_norm && (ord_double >= 2 || ord_double <= -2)) {\n            const int32_t num_axes = x->shape()->NumAxes();\n            std::vector<int32_t> axes_vec(num_axes);\n            std::iota(axes_vec.begin(), axes_vec.end(), 0);\n            return ScalarPow(JUST(ReduceSum(JUST(ScalarPow(JUST(Abs(x)), ord_sca, false)), axes_vec,\n                                            /*keepdims=*/false, NullOpt)),\n                             1 / ord_double, false);\n          }\n        }\n        if (x->ndim() == 1) {\n          res = JUST(VectorNorm(x, ord_sca, input_dim, keepdim, dtype));\n        } else {\n          std::vector<int32_t> dim{0, 1};\n          res = JUST(MatrixNorm(x, ord_sca, dim, keepdim, dtype));\n        }\n      } else {\n        res = JUST(VectorNorm(x, Scalar(2.0), input_dim, keepdim, dtype));\n      }\n    }\n    return res;\n  }\n};\n\nclass Norm2Functor {\n public:\n  Norm2Functor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::string& ord,\n                           const Optional<std::vector<int32_t>>& input_dim, const bool& keepdim,\n                           const Optional<Symbol<DType>>& dtype) const {\n    std::shared_ptr<one::Tensor> res;\n    std::vector<int32_t> dim(x->ndim());\n    std::iota(dim.begin(), dim.end(), 0);\n    if (dtype) {\n      Symbol<DType> dtype_val = JUST(dtype);\n      if (!(dtype_val->data_type() == DataType::kFloat\n            || dtype_val->data_type() == DataType::kDouble\n            || dtype_val->data_type() == DataType::kFloat16\n            || dtype_val->data_type() == DataType::kBFloat16)) {\n        UNIMPLEMENTED_THEN_RETURN() << \"linalg.norm(): only supports the float, double, cfloat and \"\n                                       \"cdouble dtypes, but got: Int.\";\n      }\n    } else {\n      if (!IsFloatingDataType(x->dtype()->data_type())) {\n        UNIMPLEMENTED_THEN_RETURN() << \"linalg.norm(): only supports the float, double, cfloat and \"\n                                       \"cdouble dtypes, but got: Int.\";\n      }\n    }\n    if (input_dim.has_value()) {\n      res = JUST(MatrixNorm(x, ord, *JUST(input_dim), keepdim, dtype));\n    } else {\n      res = JUST(MatrixNorm(x, ord, dim, keepdim, dtype));\n    }\n    return res;\n  }\n};\n\nclass ScalarNormFunctor {\n public:\n  ScalarNormFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Optional<Scalar>& ord,\n                           const Scalar& input_dim, const bool& keepdim,\n                           const Optional<Symbol<DType>>& dtype) const {\n    if (dtype) {\n      Symbol<DType> dtype_val = JUST(dtype);\n      if (!(dtype_val->data_type() == DataType::kFloat\n            || dtype_val->data_type() == DataType::kDouble\n            || dtype_val->data_type() == 
DataType::kFloat16\n            || dtype_val->data_type() == DataType::kBFloat16)) {\n        UNIMPLEMENTED_THEN_RETURN() << \"linalg.norm(): only supports the float, double, cfloat and \"\n                                       \"cdouble dtypes, but got: Int.\";\n      }\n    } else {\n      if (!IsFloatingDataType(x->dtype()->data_type())) {\n        UNIMPLEMENTED_THEN_RETURN() << \"linalg.norm(): only supports the float, double, cfloat and \"\n                                       \"cdouble dtypes, but got: Int.\";\n      }\n    }\n    if (input_dim.IsIntegral()) {\n      std::vector<int32_t> dim(1, input_dim.As<int>());\n      return functional::Norm(x, ord, dim, keepdim, dtype, /*for_norm=*/false);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"linalg_norm(): only supports int dim.\";\n    }\n  }\n};\n\nclass ScalarNorm2Functor {\n public:\n  ScalarNorm2Functor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::string& ord,\n                           const Scalar& input_dim, const bool& keepdim,\n                           const Optional<Symbol<DType>>& dtype) const {\n    if (dtype) {\n      Symbol<DType> dtype_val = JUST(dtype);\n      if (!(dtype_val->data_type() == DataType::kFloat\n            || dtype_val->data_type() == DataType::kDouble\n            || dtype_val->data_type() == DataType::kFloat16\n            || dtype_val->data_type() == DataType::kBFloat16)) {\n        UNIMPLEMENTED_THEN_RETURN() << \"linalg.norm(): only supports the float, double, cfloat and \"\n                                       \"cdouble dtypes, but got: Int.\";\n      }\n    } else {\n      if (!IsFloatingDataType(x->dtype()->data_type())) {\n        UNIMPLEMENTED_THEN_RETURN() << \"linalg.norm(): only supports the float, double, cfloat and \"\n                                       \"cdouble dtypes, but got: Int.\";\n      }\n    }\n    if (input_dim.IsIntegral()) {\n      std::vector<int32_t> dim(1, input_dim.As<int>());\n      return functional::Norm(x, ord, dim, keepdim, dtype);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"linalg_norm(): only supports int dim.\";\n    }\n  }\n};\n\nclass InvFunctor {\n public:\n  InvFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"inv\").Input(\"x\").Output(\"y\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x) const {\n    if (x->ndim() < 2) {\n      return Error::RuntimeError() << \"linalg.inv: The input tensor must be at least 2 dimensions.\";\n    }\n    if (x->dim(x->ndim() - 1) != x->dim(x->ndim() - 2)) {\n      return Error::RuntimeError() << \"linalg.inv: A must be batches of square matrices, \"\n                                   << \"but they are \" << x->dim(x->ndim() - 2) << \" by \"\n                                   << x->dim(x->ndim() - 1) << \" matrices\";\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, {});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass DetFunctor {\n public:\n  DetFunctor() {\n    det_op_ = CHECK_JUST(one::OpBuilder(\"det\").Input(\"x\").Output(\"y\").Build());\n    lu_decomposition_op_ = CHECK_JUST(\n        one::OpBuilder(\"lu_decomposition\").Input(\"x\").Output(\"LU\").Output(\"pivot\").Build());\n  }\n  Maybe<Tensor> GetPivotDet(const std::shared_ptr<Tensor>& pivot) const {\n    std::shared_ptr<Tensor> arange = nullptr;\n    int64_t end = pivot->shape()->At(pivot->ndim() - 1) + 1;\n    if (pivot->is_local()) {\n      arange = JUST(functional::Arange(1, end, 1, pivot->dtype(), JUST(pivot->device())));\n    } else {\n      
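// global pivot: build the comparison arange with the same placement and SBP as pivot\n      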
auto pivot_nd_sbp = JUST(pivot->nd_sbp());\n      std::vector<Symbol<SbpParallel>> nd_sbp(pivot_nd_sbp->sbp_parallel_size());\n      {\n        for (int i = 0; i < nd_sbp.size(); ++i) { nd_sbp[i] = pivot_nd_sbp->sbp_parallel(i); }\n      }\n      arange = JUST(functional::GlobalArange(1, end, 1, pivot->dtype(),\n                                             JUST(pivot->parallel_desc()), nd_sbp));\n    }\n    // det(P) = (-1)^k with k = #{i : pivot[i] != i + 1}, computed below as 1 - 2 * (k mod 2)\n    return sequence_function(functional::BroadcastNotEqual)\n        .then([](const auto& x) { return functional::ReduceSum(x, {-1}, false, NullOpt); })\n        .then([](const auto& x) { return functional::ScalarFMod(x, Scalar(2), true); })\n        .then([](const auto& x) { return functional::ScalarMul(x, Scalar(-2), true); })\n        .then([](const auto& x) { return functional::ScalarAdd(x, Scalar(1), Scalar(1), true); })\n        .call(arange, pivot);\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x) const {\n    const int64_t xdims = x->ndim();\n    if (xdims < 2) {\n      return Error::RuntimeError() << \"linalg.det: The input tensor must be at least 2 dimensions.\";\n    }\n    if (x->dim(xdims - 1) != x->dim(xdims - 2)) {\n      return Error::RuntimeError()\n             << \"linalg.det: A must be batches of square matrices, \"\n             << \"but they are \" << x->dim(xdims - 2) << \" by \" << x->dim(xdims - 1) << \" matrices\";\n    }\n\n    DeviceType x_device_type = DeviceType::kInvalidDevice;\n    if (x->is_local()) {\n      x_device_type = JUST(x->device())->enum_type();\n    } else if (x->is_global()) {\n      x_device_type = JUST(x->parallel_desc())->device_type();\n    }\n\n    if (x_device_type == DeviceType::kCPU) {\n      return JUST(OpInterpUtil::Dispatch<Tensor>(*det_op_, {x}, {}));\n    } else if (x_device_type == DeviceType::kCUDA) {\n      auto result = JUST(OpInterpUtil::Dispatch<TensorTuple>(*lu_decomposition_op_, {x}, {}));\n      auto LU = result->at(0);\n      auto pivot = result->at(1);\n      auto LU_det = JUST(\n          functional::ReduceProd(JUST(functional::Diagonal(LU, 0, -2, -1)), {-1}, false, NullOpt));\n      return functional::Mul(JUST(GetPivotDet(pivot)), LU_det);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"Det: only supports cpu and cuda devices.\";\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> det_op_;\n  std::shared_ptr<OpExpr> lu_decomposition_op_;\n};\n\nclass ClampGradFunctor {\n public:\n  ClampGradFunctor() {\n    clip_op_ = CHECK_JUST(\n        one::OpBuilder(\"clip_by_scalar_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n    clip_min_op_ = CHECK_JUST(\n        one::OpBuilder(\"clip_by_scalar_min_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n    clip_max_op_ = CHECK_JUST(\n        one::OpBuilder(\"clip_by_scalar_max_grad\").Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x, const Optional<Scalar>& min,\n                           const Optional<Scalar>& max) const {\n    CHECK_OR_RETURN(min.has_value() || max.has_value())\n        << \"Requires at least one of the arguments `min` and `max` in clip_grad.\";\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"floating_min\", \"integral_min\", \"floating_max\",\n                                                 \"integral_max\");\n    if (IsFloatingDataType(x->dtype()->data_type())) {\n      if (min.has_value()) {\n        const auto& min_val = JUST(min);\n        
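// floating input: store the floating bound and zero-fill the unused integral slot\n        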
attrs.SetAttr<0>(min_val->As<double>());\n        attrs.SetAttr<1>(static_cast<int64_t>(0));\n      }\n      if (max.has_value()) {\n        const auto& max_val = JUST(max);\n        attrs.SetAttr<2>(max_val->As<double>());\n        attrs.SetAttr<3>(static_cast<int64_t>(0));\n      }\n    } else if (IsIntegralDataType(x->dtype()->data_type())) {\n      if (min.has_value()) {\n        const auto& min_val = JUST(min);\n        attrs.SetAttr<0>(static_cast<double>(0));\n        attrs.SetAttr<1>(min_val->As<int64_t>());\n      }\n      if (max.has_value()) {\n        const auto& max_val = JUST(max);\n        attrs.SetAttr<2>(static_cast<double>(0));\n        attrs.SetAttr<3>(max_val->As<int64_t>());\n      }\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"Only support floating or integral data type.\";\n    }\n    const OpExpr* op = nullptr;\n    if (!min.has_value()) {\n      op = clip_max_op_.get();\n    } else if (!max.has_value()) {\n      op = clip_min_op_.get();\n    } else {\n      op = clip_op_.get();\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op, {dy, x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> clip_op_;\n  std::shared_ptr<OpExpr> clip_min_op_;\n  std::shared_ptr<OpExpr> clip_max_op_;\n};\n\nclass SelectFunctor {\n public:\n  SelectFunctor() = default;\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int32_t& dim,\n                           const int32_t& index) const {\n    int32_t ndim = input->ndim();\n    CHECK_OR_RETURN(ndim > 0) << \"select() cannot be applied to a 0-dim tensor.\";\n    int32_t pos_dim = JUST(maybe_wrap_dim(dim, ndim));\n    auto size = input->dim(pos_dim);\n    CHECK_OR_RETURN((index >= -size) && (index < size))\n        << \"Index out of range (expected to be in range of [\" << -size << \",\" << size - 1\n        << \"], but got \" << index << \")\";\n    int32_t pos_index = index >= 0 ? 
index : index + size;\n\n    std::vector<int64_t> sizes(input->shape()->dim_vec().begin(), input->shape()->dim_vec().end());\n    const auto& stride = *JUST(input->stride());\n    std::vector<int64_t> strides(stride.begin(), stride.end());\n    auto storage_offset = JUST(input->storage_offset()) + pos_index * strides[pos_dim];\n\n    sizes.erase(sizes.begin() + pos_dim);\n    strides.erase(strides.begin() + pos_dim);\n\n    return AsStrided(input, sizes, strides, storage_offset);\n  }\n};\n\nclass SelectTopNFunctor {\n public:\n  SelectTopNFunctor() { op_ = CHECK_JUST(one::SelectTopNOpExpr::New()); }\n\n  Maybe<TensorTuple> operator()(const TensorTuple& inputs, int32_t n) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"top_n\");\n    attrs.SetAllAttrs(n);\n    std::vector<bool> require_grad(n);\n    for (int i = 0; i < n; ++i) { require_grad[i] = JUST(VectorAt(inputs, i))->requires_grad(); }\n    const auto& output = JUST(OpInterpUtil::Dispatch<one::TensorTuple>(*op_, inputs, attrs));\n    for (int i = 0; i < output->size(); ++i) {\n      (*output)[i]->set_is_leaf(false);\n      JUST((*output)[i]->set_requires_grad(require_grad[i]));\n    }\n    return output;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass MinimumFunctor {\n public:\n  MinimumFunctor() {\n    elementwise_minimum_op_ =\n        CHECK_JUST(one::OpBuilder(\"elementwise_minimum\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n    broadcast_minimum_op_ =\n        CHECK_JUST(one::OpBuilder(\"broadcast_minimum\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    auto tensor_x = x;\n    auto tensor_y = y;\n    JUST(CastDeviceForCPUScalarTensor(tensor_x, tensor_y, /*inplace=*/false));\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({tensor_x, tensor_y}).Apply());\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    if (*x->shape() == *y->shape()) {\n      return OpInterpUtil::Dispatch<Tensor>(*elementwise_minimum_op_,\n                                            {input_tuple[0], input_tuple[1]});\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*broadcast_minimum_op_,\n                                            {input_tuple[0], input_tuple[1]});\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> elementwise_minimum_op_;\n  std::shared_ptr<OpExpr> broadcast_minimum_op_;\n};\n\nclass MaximumFunctor {\n public:\n  MaximumFunctor() {\n    elementwise_maximum_op_ =\n        CHECK_JUST(one::OpBuilder(\"elementwise_maximum\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n    broadcast_maximum_op_ =\n        CHECK_JUST(one::OpBuilder(\"broadcast_maximum\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y) const {\n    auto tensor_x = x;\n    auto tensor_y = y;\n    JUST(CastDeviceForCPUScalarTensor(tensor_x, tensor_y, /*inplace=*/false));\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({tensor_x, tensor_y}).Apply());\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    if (*x->shape() == *y->shape()) {\n      return OpInterpUtil::Dispatch<Tensor>(*elementwise_maximum_op_,\n                                            {input_tuple[0], 
input_tuple[1]});\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*broadcast_maximum_op_,\n                                            {input_tuple[0], input_tuple[1]});\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> elementwise_maximum_op_;\n  std::shared_ptr<OpExpr> broadcast_maximum_op_;\n};\n\nclass ScalarLogicalBaseFunctor {\n public:\n  explicit ScalarLogicalBaseFunctor(std::string op_name) {\n    op_ = CHECK_JUST(one::OpBuilder(op_name).Input(\"in\").Output(\"out\").Build());\n  }\n  virtual ~ScalarLogicalBaseFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& scalar) const {\n    TensorProcessor tensor_processor;\n    Symbol<DType> lowest_dtype;\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"float_operand\", \"has_float_operand\",\n                                                 \"int_operand\", \"has_int_operand\");\n    if (scalar.IsFloatingPoint()) {\n      attrs.SetAllAttrs(scalar.As<double>(), true, NullOpt, false);\n      // Only promote type to Float32 when tensor is Int type but scalar is float type.\n      if (DType::priority_order[x->dtype()->data_type()]\n          < DType::priority_order[DType::Float16()->data_type()]) {\n        lowest_dtype = DType::Float();\n      } else {\n        lowest_dtype = x->dtype();\n      }\n    } else if (scalar.IsIntegral() || scalar.IsBool()) {\n      attrs.SetAllAttrs(NullOpt, false, scalar.As<int64_t>(), true);\n      // Only promote type to Int64 when tensor is Bool type but scalar is int type.\n      if (DType::priority_order[x->dtype()->data_type()]\n          == DType::priority_order[DType::Bool()->data_type()]) {\n        lowest_dtype = DType::Int64();\n      } else {\n        lowest_dtype = x->dtype();\n      }\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"The scalar in \" << op_->op_type_name()\n                                  << \" should be float or int.\";\n    }\n    JUST(tensor_processor.AddInputs({x}, lowest_dtype).Apply());\n    TensorTuple casted_vec = JUST(tensor_processor.GetInputs());\n\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {casted_vec}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ScalarLogicalEqualFunctor : public ScalarLogicalBaseFunctor {\n public:\n  ScalarLogicalEqualFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/\"scalar_logical_equal\") {}\n};\n\n// (scalar == x) = (x == scalar)\nclass ScalarLogicalEqual2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& x) const {\n    return ScalarLogicalEqual(x, scalar);\n  }\n};\n\nclass ScalarLogicalNotEqualFunctor : public ScalarLogicalBaseFunctor {\n public:\n  ScalarLogicalNotEqualFunctor()\n      : ScalarLogicalBaseFunctor(/*op_name=*/\"scalar_logical_not_equal\") {}\n};\n\n// (scalar != x) = (x != scalar)\nclass ScalarLogicalNotEqual2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& x) const {\n    return ScalarLogicalNotEqual(x, scalar);\n  }\n};\n\nclass ScalarLogicalGreaterFunctor : public ScalarLogicalBaseFunctor {\n public:\n  ScalarLogicalGreaterFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/\"scalar_logical_greater\") {}\n};\n\n// (scalar > x) = (x < scalar)\nclass ScalarLogicalGreater2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& x) const {\n    return ScalarLogicalLess(x, scalar);\n  }\n};\n\nclass InplaceScalarLogicalGreaterFunctor {\n public:\n  
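// In-place variant of ScalarLogicalGreater: the result is written back into x.\n  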
InplaceScalarLogicalGreaterFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"scalar_logical_inplace_greater\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& scalar) const {\n    TensorProcessor tensor_processor;\n    Symbol<DType> lowest_dtype;\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"float_operand\", \"has_float_operand\",\n                                                 \"int_operand\", \"has_int_operand\");\n    if (scalar.IsFloatingPoint()) {\n      attrs.SetAllAttrs(scalar.As<double>(), true, NullOpt, false);\n      // Only promote type to Float32 when tensor is Int type but scalar is float type.\n      if (DType::priority_order[x->dtype()->data_type()]\n          < DType::priority_order[DType::Float16()->data_type()]) {\n        lowest_dtype = DType::Float();\n      } else {\n        lowest_dtype = x->dtype();\n      }\n    } else if (scalar.IsIntegral() || scalar.IsBool()) {\n      attrs.SetAllAttrs(NullOpt, false, scalar.As<int64_t>(), true);\n      // Only promote type to Int64 when tensor is Bool type but scalar is int type.\n      if (DType::priority_order[x->dtype()->data_type()]\n          == DType::priority_order[DType::Bool()->data_type()]) {\n        lowest_dtype = DType::Int64();\n      } else {\n        lowest_dtype = x->dtype();\n      }\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"The scalar in \" << op_->op_type_name()\n                                  << \" should be float or int.\";\n    }\n    JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x}, lowest_dtype).Apply());\n    TensorTuple input_vec = JUST(tensor_processor.GetInputs());\n    const std::shared_ptr<one::Tensor>& x_cast = input_vec.at(0);\n    JUST(CheckInplaceValid(x));\n    JUST(CheckInplaceCastValid(x, x_cast));\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    outputs->at(0) = x;\n    JUST(OpInterpUtil::Dispatch(*op_, input_vec, outputs.get(), attrs));\n    return outputs->at(0);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ScalarLogicalGreaterEqualFunctor : public ScalarLogicalBaseFunctor {\n public:\n  ScalarLogicalGreaterEqualFunctor()\n      : ScalarLogicalBaseFunctor(/*op_name=*/\"scalar_logical_greater_equal\") {}\n};\n\n// (scalar >= x) = (x <= scalar)\nclass ScalarLogicalGreaterEqual2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& x) const {\n    return ScalarLogicalLessEqual(x, scalar);\n  }\n};\n\nclass ScalarLogicalLessFunctor : public ScalarLogicalBaseFunctor {\n public:\n  ScalarLogicalLessFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/\"scalar_logical_less\") {}\n};\n\n// (scalar < x) = (x > scalar)\nclass ScalarLogicalLess2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& x) const {\n    return ScalarLogicalGreater(x, scalar);\n  }\n};\n\nclass ScalarLogicalLessEqualFunctor : public ScalarLogicalBaseFunctor {\n public:\n  ScalarLogicalLessEqualFunctor()\n      : ScalarLogicalBaseFunctor(/*op_name=*/\"scalar_logical_less_equal\") {}\n};\n\n// (scalar <= x) = (x >= scalar)\nclass ScalarLogicalLessEqual2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& x) const {\n    return ScalarLogicalGreaterEqual(x, scalar);\n  }\n};\n\nclass ScalarLogicalAndFunctor : public ScalarLogicalBaseFunctor {\n public:\n  ScalarLogicalAndFunctor() : 
ScalarLogicalBaseFunctor(/*op_name=*/\"scalar_logical_and\") {}\n};\n\n// (scalar && x) = (x && scalar)\nclass ScalarLogicalAnd2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& x) const {\n    return ScalarLogicalAnd(x, scalar);\n  }\n};\n\nclass ScalarLogicalOrFunctor : public ScalarLogicalBaseFunctor {\n public:\n  ScalarLogicalOrFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/\"scalar_logical_or\") {}\n};\n\n// (scalar || x) = (x || scalar)\nclass ScalarLogicalOr2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& x) const {\n    return ScalarLogicalOr(x, scalar);\n  }\n};\n\nclass ScalarLogicalXorFunctor : public ScalarLogicalBaseFunctor {\n public:\n  ScalarLogicalXorFunctor() : ScalarLogicalBaseFunctor(/*op_name=*/\"scalar_logical_xor\") {}\n};\n\n// (scalar ^ x) = (x ^ scalar)\nclass ScalarLogicalXor2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& x) const {\n    return ScalarLogicalXor(x, scalar);\n  }\n};\n\nclass ScalarBitwiseBaseFunctor {\n public:\n  explicit ScalarBitwiseBaseFunctor(std::string op_name) {\n    op_ = CHECK_JUST(one::OpBuilder(op_name).Input(\"in\").Output(\"out\").Build());\n  }\n  virtual ~ScalarBitwiseBaseFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& scalar) const {\n    TensorProcessor tensor_processor;\n    Symbol<DType> lowest_dtype;\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"operand\");\n    CHECK_OR_RETURN(scalar.IsIntegral() || scalar.IsBool())\n        << \"Bitwise ops only support int and bool dtypes\";\n    attrs.SetAllAttrs(scalar.As<int64_t>());\n    // Only promote type to Int64 when tensor is Bool type but scalar is int type.\n    if (DType::priority_order[x->dtype()->data_type()]\n        == DType::priority_order[DType::Bool()->data_type()]) {\n      lowest_dtype = DType::Int64();\n    } else {\n      lowest_dtype = x->dtype();\n    }\n    JUST(tensor_processor.AddInputs({x}, lowest_dtype).Apply());\n    TensorTuple casted_vec = JUST(tensor_processor.GetInputs());\n\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {casted_vec}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ScalarLerpFunctor {\n public:\n  ScalarLerpFunctor() {\n    op_ =\n        CHECK_JUST(one::OpBuilder(\"scalar_lerp\").Input(\"start\").Input(\"end\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& start,\n                           const std::shared_ptr<one::Tensor>& end, const Scalar& weight) const {\n    CHECK_EQ_OR_RETURN(start->shape()->NumAxes(), end->shape()->NumAxes())\n        << Error::RuntimeError() << \"expected dim \" << start->shape()->NumAxes()\n        << \" for `end` but got dim \" << end->shape()->NumAxes();\n\n    auto broadcast_shape = *start->shape();\n    if (*start->shape() != *end->shape()) {\n      broadcast_shape = *JUST(InferUnifiedShapeForBroadcasting({*start->shape(), *end->shape()}));\n    }\n\n    std::shared_ptr<Tensor> broadcast_start = start;\n    std::shared_ptr<Tensor> broadcast_end = end;\n    if (*start->shape() != broadcast_shape) {\n      broadcast_start = JUST(functional::Expand(start, broadcast_shape));\n    }\n    if (*end->shape() != broadcast_shape) {\n      broadcast_end = JUST(functional::Expand(end, broadcast_shape));\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"float_operand\", \"has_float_operand\",\n             
                                    \"int_operand\", \"has_int_operand\");\n    if (weight.IsFloatingPoint()) {\n      attrs.SetAllAttrs(weight.As<double>(), true, NullOpt, false);\n    } else if (weight.IsIntegral() || weight.IsBool()) {\n      attrs.SetAllAttrs(NullOpt, false, weight.As<int64_t>(), true);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"The scalar in \" << op_->op_type_name()\n                                  << \" should be float or int.\";\n    }\n\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {broadcast_start, broadcast_end}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ScalarInplaceLerpFunctor {\n public:\n  ScalarInplaceLerpFunctor() {\n    op_ =\n        CHECK_JUST(one::OpBuilder(\"scalar_lerp\").Input(\"start\").Input(\"end\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& start,\n                           const std::shared_ptr<one::Tensor>& end, const Scalar& weight) const {\n    CHECK_EQ_OR_RETURN(start->shape()->NumAxes(), end->shape()->NumAxes())\n        << Error::RuntimeError() << \"expected dim\" << start->shape()->NumAxes()\n        << \"for `end` but got dim\" << end->shape()->NumAxes();\n\n    auto broadcast_shape = *start->shape();\n    if (*start->shape() != *end->shape()) {\n      broadcast_shape = *JUST(InferUnifiedShapeForBroadcasting({*start->shape(), *end->shape()}));\n    }\n\n    std::shared_ptr<one::Tensor> broadcast_start = JUST(Identity(start));\n    std::shared_ptr<one::Tensor> broadcast_end = JUST(Identity(end));\n    if (*start->shape() != broadcast_shape) {\n      broadcast_start = JUST(view::Expand(start, broadcast_shape));\n    }\n    if (*end->shape() != broadcast_shape) {\n      broadcast_end = JUST(view::Expand(end, broadcast_shape));\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"float_operand\", \"has_float_operand\",\n                                                 \"int_operand\", \"has_int_operand\");\n    if (weight.IsFloatingPoint()) {\n      attrs.SetAllAttrs(weight.As<double>(), true, NullOpt, false);\n    } else if (weight.IsIntegral() || weight.IsBool()) {\n      attrs.SetAllAttrs(NullOpt, false, weight.As<int64_t>(), true);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"The scalar in \" << op_->op_type_name()\n                                  << \" should be float or int.\";\n    }\n\n    TensorProcessor tensor_processor;\n    if (broadcast_end->requires_grad()) {\n      JUST(tensor_processor.PromoteInputsToCommonDtype(true)\n               .AddInputs({JUST(Identity(broadcast_start)), broadcast_end})\n               .Apply());\n    } else {\n      JUST(tensor_processor.PromoteInputsToCommonDtype(true)\n               .AddInputs({broadcast_start, broadcast_end})\n               .Apply());\n    }\n    const TensorTuple& input_vec = JUST(tensor_processor.GetInputs());\n    const std::shared_ptr<one::Tensor>& start_cast = input_vec.at(0);\n    const std::shared_ptr<one::Tensor>& end_cast = input_vec.at(1);\n    JUST(CheckInplaceValid(broadcast_start));\n    JUST(CheckInplaceCastValid(broadcast_start, start_cast));\n    JUST(CheckInplaceShapeCanExpandTo(*start_cast->shape(), *end_cast->shape()));\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    outputs->at(0) = start;\n    JUST(OpInterpUtil::Dispatch(*op_, input_vec, outputs.get(), attrs));\n    return outputs->at(0);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ScalarLerpGradFunctor {\n public:\n  
ScalarLerpGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"scalar_lerp_grad\")\n                         .Input(\"start\")\n                         .Input(\"end\")\n                         .Input(\"out_diff\")\n                         .Output(\"start_diff\")\n                         .Output(\"end_diff\")\n                         .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& start,\n                                const std::shared_ptr<one::Tensor>& end,\n                                const std::shared_ptr<one::Tensor>& out_diff,\n                                const Scalar& weight) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"float_operand\", \"has_float_operand\",\n                                                 \"int_operand\", \"has_int_operand\");\n    if (weight.IsFloatingPoint()) {\n      attrs.SetAllAttrs(weight.As<double>(), true, NullOpt, false);\n    } else if (weight.IsIntegral()) {\n      attrs.SetAllAttrs(NullOpt, false, weight.As<int64_t>(), true);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"The scalar in ScalarLerpGrad should be float or int.\";\n    }\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {start, end, out_diff}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ScalarBitwiseAndFunctor : public ScalarBitwiseBaseFunctor {\n public:\n  ScalarBitwiseAndFunctor() : ScalarBitwiseBaseFunctor(/*op_name=*/\"scalar_bitwise_and\") {}\n};\n\nclass ScalarBitwiseAnd2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& x) const {\n    return ScalarBitwiseAnd(x, scalar);\n  }\n};\n\nclass ScalarBitwiseOrFunctor : public ScalarBitwiseBaseFunctor {\n public:\n  ScalarBitwiseOrFunctor() : ScalarBitwiseBaseFunctor(/*op_name=*/\"scalar_bitwise_or\") {}\n};\n\nclass ScalarBitwiseOr2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& x) const {\n    return ScalarBitwiseOr(x, scalar);\n  }\n};\n\nclass ScalarBitwiseXorFunctor : public ScalarBitwiseBaseFunctor {\n public:\n  ScalarBitwiseXorFunctor() : ScalarBitwiseBaseFunctor(/*op_name=*/\"scalar_bitwise_xor\") {}\n};\n\nclass ScalarBitwiseXor2Functor {\n public:\n  Maybe<Tensor> operator()(const Scalar& scalar, const std::shared_ptr<one::Tensor>& x) const {\n    return ScalarBitwiseXor(x, scalar);\n  }\n};\n\nclass StandardDeviationFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input,\n                           const Optional<std::vector<int32_t>>& dim,\n                           const Optional<bool>& unbiased, const Optional<bool>& keepdim) const {\n    std::vector<int32_t> axis;\n    if (!dim) {\n      for (int i = 0; i < input->ndim(); i++) { axis.emplace_back(i); }\n    } else {\n      axis = *JUST(CheckAxis(*JUST(dim), input->ndim()));\n    }\n    bool unbias = true;\n    bool keepdims = false;\n    if (unbiased.has_value()) { unbias = JUST(unbiased); }\n    if (keepdim.has_value()) { keepdims = JUST(keepdim); }\n\n    if (axis.size() == 0) {\n      return functional::Constant(*input->shape(), Scalar(0), *input->dtype(), NullOpt);\n    }\n\n    int32_t reduce_count = 1;\n    if (axis.size() == 1) {\n      reduce_count *= input->shape()->At(axis[0]);\n    } else {\n      for (int i = 0; i < axis.size(); ++i) { reduce_count *= input->shape()->At(axis[i]); }\n    }\n\n    bool is_double = input->dtype()->data_type() == DataType::kDouble;\n    if (is_double) {\n      const auto& sum = 
JUST(functional::ScalarDiv(\n          JUST(functional::ReduceSum(JUST(functional::Square(input)), axis, keepdims, NullOpt)),\n          Scalar((double)reduce_count)));\n      const auto& square = JUST(functional::Square(\n          JUST(functional::ScalarDiv(JUST(functional::ReduceSum(input, axis, keepdims, NullOpt)),\n                                     Scalar((double)reduce_count)))));\n      const auto& sub = JUST(functional::Sub(sum, square, /*alpha=*/1.0, /*inplace=*/false));\n      if (unbias) {\n        return functional::Sqrt(JUST(functional::ScalarMul(\n            sub, Scalar((double)reduce_count / (double)(reduce_count - 1)), false)));\n      }\n      /*\n      According to the standard deviation formula,\n      StandardDeviation = \sqrt{\frac{1}{N}\sum_{i=1}^{N}(X_i - \mu)^2}\n        = \sqrt{\frac{1}{N}\sum_{i=1}^{N}X_i^2 - \frac{2\mu}{N}\sum_{i=1}^{N}X_i + \mu^2}\n        = \sqrt{\frac{\sum_{i=1}^{N}X_i^2}{N} - \mu^2}\n\n      When we take the last sqrt, a radicand <= 0 makes the resulting gradient\n      undefined (nan), which is expected. In that case our gradient differs from\n      pytorch's; applying abs (absolute value) first keeps it consistent with pytorch:\n\n      const auto& abs = JUST(functional::Abs(sub));\n      return functional::Sqrt(abs);\n      */\n      // const auto& abs = JUST(functional::Abs(sub));\n      // return functional::Sqrt(abs);\n      return functional::Sqrt(sub);\n    } else {\n      //  If the input tensor's dtype is float32, cast it to double first,\n      //  because float32 does not have enough accuracy for this computation, see:\n      //  https://github.com/Oneflow-Inc/oneflow/issues/6526\n      const auto& double_input =\n          JUST(functional::Cast(input, DType::Double(), /*pin_memory=*/false));\n      const auto& sum = JUST(\n          functional::ScalarDiv(JUST(functional::ReduceSum(JUST(functional::Square(double_input)),\n                                                           axis, keepdims, NullOpt)),\n                                Scalar((double)reduce_count)));\n      const auto& square = JUST(functional::Square(JUST(\n          functional::ScalarDiv(JUST(functional::ReduceSum(double_input, axis, keepdims, NullOpt)),\n                                Scalar((double)reduce_count)))));\n      const auto& sub = JUST(functional::Sub(sum, square, /*alpha=*/1.0, /*inplace=*/false));\n      if (unbias) {\n        return functional::Cast(\n            JUST(functional::Sqrt(JUST(functional::ScalarMul(\n                sub, Scalar((double)reduce_count / (double)(reduce_count - 1)), false)))),\n            input->dtype(), /*pin_memory=*/false);\n      }\n      return functional::Cast(JUST(functional::Sqrt(sub)), input->dtype(), /*pin_memory=*/false);\n    }\n  }\n};\n\nclass VarianceFunctor {\n public:\n  VarianceFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"var\").Input(\"input\").Output(\"output\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input,\n                           const Optional<std::vector<int32_t>>& dim,\n                           const Optional<bool>& unbiased, const Optional<bool>& keepdim) const {\n    if (!(IsFloatingDataType(input->dtype()->data_type())\n          || IsHalfDataType(input->dtype()->data_type()))) {\n      return Error::RuntimeError() << \"var only supports floating point dtypes\";\n    }\n
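    // An unset dim means reduce over all dimensions; otherwise the given dims are\n    // bounds-checked, sorted, and wrapped into [0, ndim) below.\n    std::vector<int32_t> axis;\n    const int ndim = input->ndim();\n    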
axis.reserve(ndim);\n    if (!dim) {\n      for (int i = 0; i < ndim; i++) { axis.emplace_back(i); }\n    } else {\n      std::vector<int32_t>& dims = *JUST(dim);\n      JUST(maybe_wrap_dim(dims.size(), ndim));  // only check validation\n      std::sort(dims.begin(), dims.end());\n      axis.assign(dims.begin(), dims.end());\n    }\n    for (size_t i = 0; i < axis.size(); i++) {\n      if (axis[i] < 0) { axis[i] += ndim; }\n    }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"unbiased\", \"keepdim\", \"dim\", \"dtype\");\n    attrs.SetAllAttrs(unbiased, keepdim, axis, input->dtype()->data_type());\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass RMSLayerNormalizationFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& hidden_states,\n                           const std::shared_ptr<Tensor>& weight,\n                           const float& variance_epsilon) const {\n    std::shared_ptr<Tensor> cast_hidden_states = hidden_states;\n    if (hidden_states->dtype() != DType::Float()) {\n      cast_hidden_states =\n          JUST(functional::Cast(hidden_states, DType::Float(), /*pin_memory=*/false));\n    }\n    std::shared_ptr<Tensor> normalized_hidden_states = JUST(functional::Mul(\n        cast_hidden_states, JUST(functional::Rsqrt(JUST(functional::ScalarAdd(\n                                JUST(functional::ReduceMean(JUST(Square(hidden_states)),\n                                                            std::vector<int32_t>{-1}, true)),\n                                Scalar(variance_epsilon), 1.0, false))))));\n    if (weight->dtype() == DType::Float16()) {\n      normalized_hidden_states =\n          JUST(functional::Cast(normalized_hidden_states, weight->dtype(), /*pin_memory=*/false));\n    }\n    return JUST(functional::Mul(normalized_hidden_states, weight));\n  }\n};\n\nclass DotFunctor {\n public:\n  DotFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"dot\").Input(\"x\").Input(\"y\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& other) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input, other});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\nclass MovedimVecFunctor {\n public:\n  MovedimVecFunctor() = default;\n  static Maybe<void> CheckNoRepeat(const std::vector<int32_t>& perm, std::vector<int32_t>& perm_out,\n                                   int32_t ndim, const std::string& desc) {\n    std::vector<bool> is_used(ndim, false);\n    FOR_RANGE(size_t, i, 0, perm.size()) {\n      int32_t item = perm[i];\n      item = JUST(maybe_wrap_dim(item, ndim));\n      CHECK_EQ_OR_RETURN(is_used[item], false) << \"repeated dim in \" << desc;\n\n      is_used[item] = true;\n      perm_out[i] = item;\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::vector<int32_t>& source,\n                           const std::vector<int32_t>& destination) const {\n    int32_t ndim = input->ndim();\n    int32_t dim = source.size();\n\n    CHECK_EQ_OR_RETURN(source.size(), destination.size())\n        << \"movedim: Invalid source or destination dims: source (\" << source.size()\n        << \" dims ) should contain the same number of dims as destination (\" << destination.size()\n        << \" dims)\";\n\n    std::vector<int32_t> source_nopeat(dim);\n    
std::vector<int32_t> destination_nopeat(dim);\n\n    JUST(CheckNoRepeat(source, source_nopeat, ndim, \"source\"));\n    JUST(CheckNoRepeat(destination, destination_nopeat, ndim, \"destination\"));\n\n    std::vector<int32_t> order(ndim);\n    std::vector<int32_t> source_dims(ndim);\n    std::vector<int32_t> destination_dims(ndim);\n\n    std::iota(source_dims.begin(), source_dims.end(), 0);\n    std::iota(destination_dims.begin(), destination_dims.end(), 0);\n\n    FOR_RANGE(size_t, i, 0, dim) {\n      order[destination_nopeat[i]] = source_nopeat[i];\n      source_dims[source_nopeat[i]] = -1;\n      destination_dims[destination_nopeat[i]] = -1;\n    }\n\n    std::remove(source_dims.begin(), source_dims.end(), -1);\n    std::remove(destination_dims.begin(), destination_dims.end(), -1);\n\n    int64_t rest_dim = ndim - dim;\n    FOR_RANGE(size_t, i, 0, rest_dim) { order[destination_dims[i]] = source_dims[i]; }\n\n    return Transpose(input, order);\n  }\n};\n\nclass MovedimIntFunctor {\n public:\n  MovedimIntFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int32_t& source,\n                           const int32_t& destination) const {\n    std::vector<int32_t> src{source};\n    std::vector<int32_t> dest{destination};\n    return MovedimVec(input, src, dest);\n  }\n};\n\nclass TensorSplitVecFunctor {\n public:\n  TensorSplitVecFunctor() = default;\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,\n                                const std::vector<int32_t>& indices_or_sections,\n                                const int32_t& dim) const {\n    int32_t ndim = input->ndim();\n    int32_t pos_dim = JUST(maybe_wrap_dim(dim, ndim));\n\n    std::vector<int64_t> start(ndim, 0);\n    std::vector<int64_t> stop(ndim);\n    std::vector<int64_t> step(ndim, 1);\n    for (int32_t i = 0; i < ndim; i++) { stop[i] = input->dim(i); }\n\n    int32_t num_indices = indices_or_sections.size();\n    TensorTuple output(num_indices + 1);\n    for (int32_t i = 0; i < num_indices; i++) {\n      int32_t end_idx = indices_or_sections[i];\n      stop[pos_dim] = end_idx;\n      output[i] = JUST(Slice(input, start, stop, step, /*enable_view_slice=*/false));\n      start[pos_dim] = end_idx;\n    }\n    stop[pos_dim] = input->shape()->At(pos_dim);\n    output[num_indices] = JUST(Slice(input, start, stop, step, /*enable_view_slice=*/false));\n\n    return output;\n  }\n};\n\nclass TensorSplitIntFunctor {\n public:\n  TensorSplitIntFunctor() = default;\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,\n                                const int32_t& indices_or_sections, const int32_t& dim) const {\n    int32_t ndim = input->ndim();\n    int32_t pos_dim = JUST(maybe_wrap_dim(dim, ndim));\n    CHECK_OR_RETURN(indices_or_sections > 0)\n        << \"number of sections must be larger than 0, but got \" << indices_or_sections;\n\n    const auto dim_size = input->dim(pos_dim);\n    int64_t min_split_size = dim_size / indices_or_sections;\n    int64_t num_splits_one_extra = dim_size % indices_or_sections;\n\n    std::vector<int64_t> start(ndim, 0);\n    std::vector<int64_t> stop(ndim);\n    std::vector<int64_t> step(ndim, 1);\n    for (int32_t i = 0; i < ndim; i++) { stop[i] = input->dim(i); }\n    stop[pos_dim] = 0;\n\n    TensorTuple output(indices_or_sections);\n
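    // The first (dim_size % indices_or_sections) chunks each get one extra element;\n    // e.g. splitting a dimension of size 10 into 3 sections yields sizes [4, 3, 3].\n    for (int32_t i = 0; i < indices_or_sections; i++) {\n      int64_t split_size = (i < num_splits_one_extra) ? 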
(min_split_size + 1) : min_split_size;\n      stop[pos_dim] += split_size;\n      output[i] = JUST(Slice(input, start, stop, step, /*enable_view_slice=*/false));\n      start[pos_dim] += split_size;\n    }\n\n    return output;\n  }\n};\n\nclass HsplitIntFunctor {\n public:\n  HsplitIntFunctor() = default;\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,\n                                const int32_t& indices_or_sections) const {\n    int32_t ndim = input->ndim();\n    CHECK_OR_RETURN(ndim >= 1)\n        << \"flow.hsplit requires a tensor with at least 1 dimension, but got a tensor with \" << ndim\n        << \" dimensions!\";\n    CHECK_OR_RETURN(indices_or_sections > 0) << \"indices_or_sections must be greater than 0\";\n    int32_t dim = (ndim == 1) ? 0 : 1;\n    CHECK_OR_RETURN(input->dim(dim) % indices_or_sections == 0)\n        << \"flow.hsplit attempted to split along dimension \" << dim\n        << \", but the size of the dimension \" << input->shape()->At(dim)\n        << \" is not divisible by the split_size \" << indices_or_sections << \"!\";\n    return TensorSplitInt(input, indices_or_sections, dim);\n  }\n};\n\nclass HsplitVecFunctor {\n public:\n  HsplitVecFunctor() = default;\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,\n                                const std::vector<int32_t>& indices_or_sections) const {\n    int32_t ndim = input->ndim();\n    CHECK_OR_RETURN(ndim >= 1)\n        << \"flow.hsplit requires a tensor with at least 1 dimension, but got a tensor with \" << ndim\n        << \" dimensions!\";\n    int32_t dim = (ndim == 1) ? 0 : 1;\n    return TensorSplitVec(input, indices_or_sections, dim);\n  }\n};\n\nclass VsplitIntFunctor {\n public:\n  VsplitIntFunctor() = default;\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,\n                                const int32_t& indices_or_sections) const {\n    int32_t ndim = input->ndim();\n    CHECK_OR_RETURN(ndim >= 2)\n        << \"flow.vsplit requires a tensor with at least 2 dimensions, but got a tensor with \" << ndim\n        << \" dimensions!\";\n    CHECK_OR_RETURN(indices_or_sections > 0) << \"indices_or_sections must be greater than 0\";\n    CHECK_OR_RETURN(input->dim(0) % indices_or_sections == 0)\n        << \"flow.vsplit attempted to split along dimension \" << 0\n        << \", but the size of the dimension \" << input->dim(0)\n        << \" is not divisible by the split_size \" << indices_or_sections << \"!\";\n    return TensorSplitInt(input, indices_or_sections, 0);\n  }\n};\n\nclass VsplitVecFunctor {\n public:\n  VsplitVecFunctor() = default;\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,\n                                const std::vector<int32_t>& indices_or_sections) const {\n    int32_t ndim = input->ndim();\n    CHECK_OR_RETURN(ndim >= 2)\n        << \"flow.vsplit requires a tensor with at least 2 dimensions, but got a tensor with \" << ndim\n        << \" dimensions!\";\n    return TensorSplitVec(input, indices_or_sections, 0);\n  }\n};\n\nclass ErfinvFunctor {\n public:\n  ErfinvFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"erfinv\").Input(\"x\").Output(\"y\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x}, {});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n
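\n// In-place variant of erfinv, writing the result back into x. Mathematically erfinv\n// is finite only on (-1, 1); erfinv(1) = inf, erfinv(-1) = -inf, and |x| > 1 gives nan.\nclass ErfinvInplaceFunctor {\n public:\n  ErfinvInplaceFunctor() {\n    op_ = 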
CHECK_JUST(one::OpBuilder(\"erfinv\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    JUST(CheckInplaceValid(x));\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    outputs->at(0) = x;\n    JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), {}));\n    return outputs->at(0);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GeluWithApproximateFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::string& approximate) const {\n    if (approximate != \"none\" && approximate != \"tanh\") {\n      return Error::RuntimeError() << \"the approximate argument should be 'none' or 'tanh'\";\n    }\n    if (approximate == \"tanh\") { return FastGelu(x); }\n    return Gelu(x);\n  }\n};\n\nclass CumBaseFunctor {\n public:\n  explicit CumBaseFunctor(std::string op_name) {\n    op_ = CHECK_JUST(one::OpBuilder(op_name).Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, int64_t dim,\n                           const Optional<Symbol<DType>>& dtype) const {\n    auto ndim = input->ndim();\n    dim = JUST(maybe_wrap_dim(dim, ndim));\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dim\");\n    attrs.SetAllAttrs(dim);\n    TensorProcessor tensor_processor;\n    if (dtype) {\n      JUST(tensor_processor.AddInputs({input}, JUST(dtype)).Apply());\n    } else {\n      JUST(tensor_processor.AddInputs({input}, DType::Int64()).Apply());\n    }\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    return OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass CumsumFunctor : public CumBaseFunctor {\n public:\n  CumsumFunctor() : CumBaseFunctor(\"cumsum\") {}\n};\n\nclass CumProdFunctor : public CumBaseFunctor {\n public:\n  CumProdFunctor() : CumBaseFunctor(\"cumprod\") {}\n};\n\nclass CumGradBaseFunctor {\n protected:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass CumProdGradFunctor : public CumGradBaseFunctor {\n public:\n  CumProdGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"cumprod_grad\")\n                         .Input(\"dy\")\n                         .Input(\"output\")\n                         .Input(\"input\")\n                         .Output(\"dx\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& y,\n                           const std::shared_ptr<one::Tensor>& x, int64_t dim) const {\n    // No need to check dim validation here, while CumProdFunctor handled already\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dim\");\n    attrs.SetAllAttrs(dim);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, y, x}, attrs);\n  }\n};\n\n// NOTE(Liang Depeng): The implementation of sumproduct_pair are mostly taken from pytorch.\n//                     For more details pls refer to:\n//                     https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Linear.cpp#L65\n\n// sumproduct_pair computes `(left*right).sum(sumdims)` by means of permutation and\n// batch matrix multiplication\n// its main purpose is to provide a pairwise reduction for einsum\nstatic Maybe<one::Tensor> sumproduct_pair(const std::shared_ptr<one::Tensor>& left_,\n                                          const std::shared_ptr<one::Tensor>& right_,\n    
                                      const std::vector<int32_t>& sum_dims_, bool keepdim) {\n  // assumes that tensors have been pre-unsqueezed (so that all dimensions match - after\n  // broadcasting) but makes no other assumptions on the order of dimensions\n  CHECK_OR_RETURN(left_->ndim() == right_->ndim()) << \"number of dimensions must match\";\n  if (sum_dims_.size() == 0) return functional::Mul(left_, right_);\n  int64_t dim = left_->ndim();\n\n  constexpr size_t dim_bitset_size = 64;\n  CHECK_OR_RETURN(dim <= (int64_t)dim_bitset_size)\n      << \"only tensors with up to \" << dim_bitset_size << \" dims are supported\";\n  std::bitset<dim_bitset_size> sum_dims;\n  for (int i = 0; i < sum_dims_.size(); ++i) {\n    size_t d = sum_dims_[i];\n    CHECK_OR_RETURN(!sum_dims[d]) << \"dim \" << d << \" appears multiple times in the list of dims\";\n    sum_dims[d] = true;\n  }\n\n  // dimensions that will be part of the output (i.e. not summed over) in three vectors:\n  // dims in lro appear in left, right and output, similarly lo: left and output, ro: right and\n  // output. the sizes are also kept track of for reshaping\n  std::vector<int32_t> lro, lo, ro;\n  int32_t lro_size = 1, lo_size = 1, ro_size = 1, sum_size = 1;\n  std::shared_ptr<one::Tensor> left = left_;\n  std::shared_ptr<one::Tensor> right = right_;\n  for (int i = 0; i < dim; ++i) {\n    auto sl = left->shape()->At(i) > 1;\n    auto sr = right->shape()->At(i) > 1;\n    if (sum_dims[i]) {  // first dimensions that will be summed over after multiplication\n      if (sl && sr) {   // dimensions nontrivially in both left and right must be of the same size\n        CHECK_OR_RETURN(left->shape()->At(i) == right->shape()->At(i))\n            << \"non-broadcast dimensions must match\";\n        sum_size *= left->shape()->At(i);\n      } else if (sl) {  // if it is only in one of left and right, we can sum right away\n        left = JUST(functional::ReduceSum(left, {i}, true, NullOpt));\n      } else if (sr) {\n        right = JUST(functional::ReduceSum(right, {i}, true, NullOpt));\n      }\n    } else if (sl && sr) {  // now deal with dimensions that will be in the output\n      // dimensions nontrivially in both left and right must be of the same size\n      CHECK_OR_RETURN(left->shape()->At(i) == right->shape()->At(i))\n          << \"non-broadcast dimensions must match\";\n      lro.push_back(i);\n      lro_size *= left->shape()->At(i);\n    } else if (sl) {  // keep track of dimensions appearing only once\n      lo.push_back(i);\n      lo_size *= left->shape()->At(i);\n    } else {\n      ro.push_back(i);\n      ro_size *= right->shape()->At(i);\n    }\n  }\n\n  // we now work with the following permutations / shapes.\n  // the pipeline is: permute inputs -> reshape inputs -> batch matrix mul -> reshape(view) output\n  // -> permute output.\n  // output: \"lro, lo, 1-for-summed-dims, ro\" with original shape dimensions\n  // left:   \"lro, lo, summed\" permuted with lpermutation and flattened to three dims\n  // right:  \"lro, summed, ro\" permuted with rpermutation and flattened to three dims\n  // then the permuted output is a view of bmm(left, right)\n  // finally, opermutation reverts the permutation to the original order of dimensions\n  std::vector<int32_t> out_size;\n  for (auto& d : lro) out_size.push_back(left->shape()->At(d));\n  for (auto& d : lo) out_size.push_back(left->shape()->At(d));\n  for (auto& d : sum_dims_) {\n    out_size.push_back(1);\n    (void)(d);\n  };  // avoid warning about not using d\n  for (auto& d : 
ro) out_size.push_back(right->shape()->At(d));\n\n  std::vector<int32_t> lpermutation(lro);\n  lpermutation.insert(lpermutation.end(), lo.begin(), lo.end());\n  lpermutation.insert(lpermutation.end(), sum_dims_.begin(), sum_dims_.end());\n  lpermutation.insert(lpermutation.end(), ro.begin(), ro.end());\n\n  std::vector<int32_t> rpermutation(lro);\n  rpermutation.insert(rpermutation.end(), sum_dims_.begin(), sum_dims_.end());\n  rpermutation.insert(rpermutation.end(), ro.begin(), ro.end());\n  rpermutation.insert(rpermutation.end(), lo.begin(), lo.end());\n\n  std::vector<int32_t> opermutation(lro.size() + lo.size() + sum_dims_.size() + ro.size(), -1);\n  {\n    int32_t i = 0;\n\n    for (auto it = lro.cbegin(); it != lro.cend(); i++, it++) { opermutation[*it] = i; }\n    for (auto it = lo.cbegin(); it != lo.cend(); i++, it++) { opermutation[*it] = i; }\n    for (auto it = sum_dims_.cbegin(); it != sum_dims_.cend(); i++, it++) { opermutation[*it] = i; }\n    for (auto it = ro.cbegin(); it != ro.cend(); i++, it++) { opermutation[*it] = i; }\n  }\n\n  // now we can execute the operations above\n  left = JUST(functional::Permute(left, lpermutation));\n  DimVector lsv(3);\n  lsv[0] = lro_size;\n  lsv[1] = lo_size;\n  lsv[2] = sum_size;\n  const Shape ls(lsv);\n\n  left = JUST(functional::Reshape(left, ls));\n\n  right = JUST(functional::Permute(right, rpermutation));\n  DimVector rsv(3);\n  rsv[0] = lro_size;\n  rsv[1] = sum_size;\n  rsv[2] = ro_size;\n  const Shape rs(rsv);\n  right = JUST(functional::Reshape(right, rs));\n\n  std::shared_ptr<one::Tensor> result =\n      JUST(functional::BatchMatMul(left, right, false, false, 1.0));\n  DimVector osv(out_size.size());\n  for (int i = 0; i < out_size.size(); ++i) { osv[i] = out_size[i]; }\n  const Shape os(osv);\n  // TODO(Liang Depeng): change reshape to view\n  result = JUST(functional::Reshape(result, os));\n  result = JUST(functional::Permute(result, opermutation));\n\n  // finally squeeze summed dimensions if desired\n  if (!keepdim) {\n    auto sizes = result->shape()->dim_vec();\n    for (int i = dim - 1; i >= 0; i--) {\n      if (sum_dims[i]) { sizes.erase(sizes.begin() + i); }\n    }\n    // TODO(Liang Depeng): change reshape to view\n    const Shape s(sizes);\n    result = JUST(functional::Reshape(result, s));\n  }\n  return result;\n}\n\nnamespace {\n\nbool einsum_check_label(unsigned char label) { return std::isalpha(label); }\n\nuint8_t einsum_label_to_index(unsigned char label) {\n  constexpr uint8_t NUM_OF_LETTERS = 'z' - 'a' + 1;\n  return std::isupper(label) ? label - 'A' : NUM_OF_LETTERS + (label - 'a');\n}\n\nunsigned char einsum_index_to_label(uint8_t index) {\n  constexpr uint8_t NUM_OF_LETTERS = 'z' - 'a' + 1;\n  return index < NUM_OF_LETTERS ? index + 'A' : index - NUM_OF_LETTERS + 'a';\n}\n\n}  // namespace\n\n// NOTE(Liang Depeng): The implementation of EinSumFunctor is mostly taken from pytorch.\n//                     For more details please refer to:\n//                     https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Linear.cpp#L190\n
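\n// For example, einsum(\"ik,kj->ij\", {A, B}) is a plain matrix multiplication: the\n// repeated label k is contracted (summed over) while i and j index the output.\n\n// There are roughly three parts to compute einsum:\n// 1. Parse equation to extract the labels for each input operand and output\n// 2. Unsqueeze missing dimensions from input operands and permute to align them\n// 3. 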
Compute result by multiplying input operands and summing contraction\n//    dimensions We do the last part by reducing to batch matmul.\nclass EinSumFunctor {\n public:\n  EinSumFunctor() {}\n  Maybe<Tensor> operator()(const std::string& equation, const one::TensorTuple& operands) const {\n    CHECK_OR_RETURN(operands.size() > 0) << \"einsum(): must provide at least one input tensor.\";\n    // NOTE(Liang Depeng): In order to better understand what einsum is doing,\n    //                     the following comments will give a detailed explaination of\n    //                     how the operands of equation \"ik,jkl,il->ij\" (bilinear)\n    //                     are transformed during the computation.\n    //                     Assume that the size of each operands \"ik\", \"jkl\" and \"il\" are\n    //                     [2, 3], [4, 3, 5], [2, 5] respectively.\n\n    // Code used to identify ELLIPSIS (\"...\")\n    constexpr uint8_t ELLIPSIS = 52;\n\n    // Find arrow (->) to split equation into lhs (input equations) and rhs (output equation)\n    const auto arrow_pos = equation.find(\"->\");\n    const auto lhs = equation.substr(0, arrow_pos);\n\n    const auto num_ops = operands.size();\n\n    // Convert each input equations into indexes in range [0, 52] and store\n    // them in op_labels for each operand along with ELLIPSIS if present.\n    std::vector<std::vector<uint8_t>> op_labels(num_ops);\n    // NOTE(Liang Depeng): Continue explaining the equation \"ik,jkl,il->ij\".\n    //                     After running the following for loop, `op_labels` contains 3 vectors.\n    //                     The contents of each vectors are:\n    //                     op_labels[0]: [34('i'-'a'+26), 36('k'-'a'+26)]\n    //                     op_labels[1]: [35('j'-'a'+26), 36('k'-'a'+26), 37('l'-'a'+26)]\n    //                     op_labels[2]: [34('i'-'a'+26), 37('l'-'a'+26)]\n    bool found_ell = false;\n    std::size_t curr_op = 0;\n    for (auto i = decltype(lhs.length()){0}; i < lhs.length(); ++i) {\n      const unsigned char label = lhs[i];\n      switch (label) {\n        case ' ':\n          // Ignore spaces\n          break;\n\n        case '.':\n          // process ellipsis\n          CHECK_OR_RETURN(\n              // Only one ellipsis per operand can be given\n              !found_ell)\n              << \"einsum(): found \\'.\\' for operand \" << curr_op\n              << \" for which an ellipsis was already found\";\n          CHECK_OR_RETURN(\n              // Ensure it's a valid ellipsis\n              i + 2 < lhs.length() && lhs[++i] == '.' 
&& lhs[++i] == '.')\n              << \"einsum(): found \\'.\\' for operand \" << curr_op\n              << \" that is not part of any ellipsis\";\n          op_labels[curr_op].push_back(ELLIPSIS);\n          found_ell = true;\n          break;\n\n        case ',':\n          // Move onto next operand\n          ++curr_op;\n          CHECK_OR_RETURN(curr_op < num_ops)\n              << \"einsum(): fewer operands were provided than specified in the equation\";\n          found_ell = false;\n          break;\n\n        default:\n          // Parse label\n          CHECK_OR_RETURN(einsum_check_label(label))\n              << \"einsum(): invalid subscript given at index \" << i\n              << \" in the equation string, subscripts must be in [a-zA-Z]\";\n          op_labels[curr_op].push_back(einsum_label_to_index(label));\n      }\n    }\n\n    CHECK_OR_RETURN(curr_op == num_ops - 1)\n        << \"einsum(): more operands were provided than specified in the equation\";\n\n    // Labels must be within [a-zA-Z].\n    constexpr uint8_t TOTAL_LABELS = 52;\n    std::vector<int32_t> label_count(TOTAL_LABELS, 0);\n\n    // The maximum number of dimensions covered by any ellipsis, needed when\n    // unsqueezing missing dimensions from operands to permute and broadcast\n    int32_t ell_num_dim = 0;\n    // NOTE(Liang Depeng): Continue explaining the equation \"ik,jkl,il->ij\".\n    //                     After running the following for loop,\n    //                     the non-zero entries of `label_count` are:\n    //                     label_count[34] = 2\n    //                     label_count[35] = 1\n    //                     label_count[36] = 2\n    //                     label_count[37] = 2\n    //                     `ell_num_dim` equals 0 because there is no ellipsis in the equation\n\n    // Compute label frequency and number of dimensions covered by ellipsis\n    // We do this after parsing labels to make it more readable and simpler\n    // to compute the number of dimensions covered by ellipsis.\n    for (auto i = 0; i < num_ops; i++) {\n      const auto operand = operands[i];\n      const auto labels = op_labels[i];\n      const int ndims = operand->ndim();\n      int32_t nlabels = static_cast<int32_t>(labels.size());\n      bool has_ellipsis = false;\n\n      for (const auto& label : labels) {\n        if (label == ELLIPSIS) {\n          --nlabels;\n          has_ellipsis = true;\n          ell_num_dim = std::max(ell_num_dim, ndims - nlabels);\n        } else {\n          ++label_count[label];\n        }\n      }\n      if (has_ellipsis) {\n        CHECK_OR_RETURN(nlabels <= ndims)\n            << \"einsum(): the number of subscripts in the equation (\" << nlabels\n            << \") is more than the number of dimensions (\" << ndims << \") for operand \" << i;\n      } else {\n        CHECK_OR_RETURN(nlabels == ndims)\n            << \"einsum(): the number of subscripts in the equation (\" << nlabels\n            << \") does not match the number of dimensions (\" << ndims << \") for operand \" << i\n            << \" and no ellipsis was given\";\n      }\n    }\n\n    // We want to align the dimensions of every input tensor to have\n    // shape out_dims + sum_dims. 
For this, we create a mapping of label\n    // to index into the permuted shape.\n    std::vector<int32_t> label_perm_index(TOTAL_LABELS, -1);\n\n    // Current index in the permuted shape\n    int32_t perm_index = 0;\n\n    // Start index of ellipsis dimensions in the permuted shape\n    int32_t ell_index = 0;\n    found_ell = false;\n\n    // NOTE(Liang Depeng): Continue explaining the equation \"ik,jkl,il->ij\".\n    //                     After running the following if-else code block,\n    //                     the none -1 indexes of `label_perm_index` are:\n    //                     label_perm_index[34] = 0\n    //                     label_perm_index[35] = 1\n    //                     `perm_index` equals to 2\n    //                     `ell_index` equals to 0 because no ellipsis in equation\n    //                     `found_ell` equals to false because no ellipsis in equation\n    if (arrow_pos == std::string::npos) {\n      // Implicit output is ellipsis (...) + labels seen only once\n      perm_index = ell_num_dim;\n      found_ell = true;\n      for (auto label = 0; label < TOTAL_LABELS; label++) {\n        if (label_count[label] == 1) { label_perm_index[label] = perm_index++; }\n      }\n    } else {\n      // Parse explicit output\n      const auto rhs = equation.substr(arrow_pos + 2);\n      for (auto i = decltype(rhs.length()){0}; i < rhs.length(); ++i) {\n        const unsigned char label = rhs[i];\n        switch (label) {\n          case ' ':\n            // Ignore spaces\n            break;\n\n          case '.':\n            // process ellipsis\n            CHECK_OR_RETURN(\n                // There can only be one ellipsis in the output\n                !found_ell)\n                << \"einsum(): found \\'.\\' for output but an ellipsis (...) was already found\";\n            CHECK_OR_RETURN(\n                // Ensure ellipsis is correct\n                i + 2 < rhs.length() && rhs[++i] == '.' && rhs[++i] == '.')\n                << \"einsum(): found \\'.\\' for output that is not part of any ellipsis (...)\";\n            ell_index = perm_index;\n            perm_index += ell_num_dim;\n            found_ell = true;\n            break;\n\n          default:\n            CHECK_OR_RETURN(einsum_check_label(label))\n                << \"einsum(): invalid subscript given at index \" << lhs.size() + 2 + i\n                << \" in the equation string, subscripts must be in [a-zA-Z]\";\n            const auto index = einsum_label_to_index(label);\n            CHECK_OR_RETURN(\n                // Ensure label appeared at least once for some input operand\n                // and at most once for the output\n                label_count[index] > 0 && label_perm_index[index] == -1)\n                << \"einsum(): output subscript \" << label\n                << (label_perm_index[index] > -1\n                        ? 
\" appears more than once in the output\"\n                        : \" does not appear in the equation for any input operand\");\n            label_perm_index[index] = perm_index++;\n        }\n      }\n    }\n\n    // Save output size before adding contraction dims (dims to sum out)\n    const int32_t out_size = perm_index;\n\n    // If ellipsis is not part of the output, add to contraction dimensions\n    if (!found_ell) {\n      ell_index = perm_index;\n      perm_index += ell_num_dim;\n    }\n\n    // NOTE(Liang Depeng): Continue explaining the equation \"ik,jkl,il->ij\".\n    //                     After running the following foor loop,\n    //                     the none -1 indexes of `label_perm_index` are:\n    //                     label_perm_index[34] = 0 ('i')\n    //                     label_perm_index[35] = 1 ('j')\n    //                     label_perm_index[36] = 2 ('k')\n    //                     label_perm_index[37] = 3 ('l')\n    //                     `out_size` equals to 2\n    //                     `perm_index` equals to 4\n\n    // Add contraction labels (labels not present in output)\n    for (auto label = 0; label < TOTAL_LABELS; label++) {\n      if (label_count[label] > 0 && label_perm_index[label] == -1) {\n        label_perm_index[label] = perm_index++;\n      }\n    }\n\n    // Here we unsqueeze missing dimensions to make all operands have the same\n    // number of dimensions. We take diagonals for repeated labels within the\n    // same operand. Finally we permute the operands to align dimensions as\n    // per the perm_out_index we computed above.\n    TensorTuple permuted_operands;\n    for (auto i = 0; i < num_ops; i++) {\n      std::vector<int32_t> perm_shape(perm_index, -1);\n      std::vector<int32_t> label_dim(TOTAL_LABELS, -1);\n      std::shared_ptr<Tensor> operand = operands[i];\n      const auto labels = op_labels[i];\n      const auto original_sizes = operand->shape()->dim_vec();\n\n      int32_t j = 0;\n      for (const auto& label : labels) {\n        if (label == ELLIPSIS) {\n          // Add missing dimensions covered by the ellipsis\n          const auto num_missing_dim = ell_num_dim - (original_sizes.size() - labels.size() + 1);\n          for (auto k = 0; k < num_missing_dim; k++) {\n            operand = JUST(functional::Unsqueeze(operand, j));\n          }\n          for (auto k = 0; k < ell_num_dim; k++) { perm_shape[ell_index + k] = j++; }\n        } else if (label_dim[label] != -1) {\n          // Repeated label, take diagonal\n          const auto dim = label_dim[label];\n          CHECK_OR_RETURN(operand->dim(j) == operand->dim(dim))\n              << \"einsum() subscript \" << einsum_index_to_label(label)\n              << \" is repeated for operand \" << i << \" but the sizes don't match, \"\n              << operand->dim(j) << \" != \" << operand->dim(dim);\n\n          operand = JUST(functional::Diagonal(operand, 0, dim, j));\n          operand = JUST(functional::MovedimInt(operand, -1, dim));\n        } else {\n          // Lookup output index for label\n          label_dim[label] = j;\n          perm_shape[label_perm_index[label]] = j++;\n        }\n      }\n\n      // Add dimensions for missing labels\n      for (int32_t& index : perm_shape) {\n        if (index == -1) {\n          operand = JUST(functional::Unsqueeze(operand, -1));\n          index = j++;\n        }\n      }\n      permuted_operands.emplace_back(JUST(functional::Permute(operand, perm_shape)));\n\n      // NOTE(Liang Depeng): Continue explaining the 
equation \"ik,jkl,il->ij\".\n      //                     What is going on within this foor loop?\n      //                     For operand \"ik\" size = [2, 3]:\n      //                        `perm_shape` equals to [0, 2, 1, 3]\n      //                        first unsqueeze \"ik\" to 4 dim, from [2, 3] to [2, 3, 1, 1]\n      //                        then permute with `perm_shape`, from [2, 3, 1, 1] to [2, 1, 3, 1]\n      //\n      //                     For operand \"jkl\" size = [4, 3, 5]:\n      //                        `perm_shape` equals to [3, 0, 1, 2]\n      //                        first unsqueeze \"jkl\" to 4 dim, from [4, 3, 5] to [4, 3, 5, 1]\n      //                        then permute with `perm_shape`, from [4, 3, 5, 1] to [1, 4, 3, 5]\n      //\n      //                     For operand \"il\" size = [2, 5]:\n      //                        `perm_shape` equals to [0, 2, 3, 1]\n      //                        first unsqueeze \"ik\" to 4 dim, from [2, 5] to [2, 5, 1, 1]\n      //                        then permute with `perm_shape`, from [2, 5, 1, 1] to [2, 1, 1, 5]\n    }\n\n    // Check if operands broadcast and keep track of last operand with\n    // dimension size != 1 for optimizing reductions\n    std::vector<std::size_t> dim_last_op(perm_index, 0);\n    bool has_zero_size_dim = false;\n    // NOTE(Liang Depeng): Continue explaining the equation \"ik,jkl,il->ij\".\n    //                     After running the following foor loop,\n    //                     The contents of `dim_last_op` are:\n    //                     dim_last_op[0] = 2\n    //                     dim_last_op[1] = 1\n    //                     dim_last_op[2] = 1\n    //                     dim_last_op[3] = 2\n    //                     `has_zero_size_dim` equals to false\n    for (auto dim = 0; dim < perm_index; dim++) {\n      auto broadcast_size = permuted_operands[0]->dim(dim);\n      for (auto i = 1; i < num_ops; i++) {\n        const auto dim_size = permuted_operands[i]->dim(dim);\n        if (broadcast_size != dim_size && broadcast_size != 1 && dim_size != 1) {\n          std::ostringstream msg;\n          msg << \"einsum(): operands do not broadcast with remapped shapes [original->remapped]:\";\n          for (auto j = 0; j < num_ops; j++) {\n            msg << \" \" << operands[j]->shape()->DebugStr() << \"->\"\n                << permuted_operands[j]->shape()->DebugStr();\n          }\n          CHECK_OR_RETURN(false) << msg.str();\n        }\n        if (dim_size != 1) {\n          broadcast_size = dim_size;\n          dim_last_op[dim] = i;\n        }\n      }\n      has_zero_size_dim |= broadcast_size == 0;\n    }\n\n    // Compute result\n    std::shared_ptr<Tensor> result = permuted_operands[0];\n\n    // Fast path for when an operand has zero sized dim\n    if (has_zero_size_dim) {\n      DimVector out_shape(out_size);\n      for (auto i = 0; i < out_size; i++) {\n        out_shape[i] = permuted_operands[dim_last_op[i]]->dim(i);\n      }\n\n      const Shape shape(out_shape);\n      return functional::Constant(shape, Scalar(0), *permuted_operands[0]->dtype(), NullOpt);\n    }\n\n    // Sum out or squeeze dimensions that are size 1 for all later operands\n    int dim = out_size;\n    for (int i = dim; i < perm_index; ++i, ++dim) {\n      if (dim_last_op[i] == 0) {\n        if (result->dim(dim) == 1) {\n          std::vector<int32_t> dims = {dim--};\n          result = JUST(functional::Squeeze(result, dims));\n        } else {\n          result = JUST(functional::ReduceSum(result, 
{dim--}, false, NullOpt));\n        }\n      }\n    }\n\n    for (auto i = 1; i < num_ops; i++) {\n      auto operand = permuted_operands[i];\n      std::vector<int32_t> sum_dims;\n\n      // Sum out or squeeze dimensions that are size 1 for all later operands\n      dim = out_size;\n      for (int j = dim; j < perm_index; ++j, ++dim) {\n        if (dim_last_op[j] < i) {\n          std::vector<int32_t> dims = {dim--};\n          operand = JUST(functional::Squeeze(operand, dims));\n        } else if (dim_last_op[j] == i) {\n          if (result->dim(dim) == 1) {\n            operand = JUST(functional::ReduceSum(operand, {dim}, false, NullOpt));\n            std::vector<int32_t> dims = {dim--};\n            result = JUST(functional::Squeeze(result, dims));\n          } else {\n            sum_dims.push_back(dim);\n          }\n        }\n      }\n\n      // Multiply tensors and sum out dimensions in sum_dims\n      if (sum_dims.empty()) {\n        result = JUST(functional::Mul(result, operand));\n      } else if (sum_dims.size() == result->ndim()) {\n        auto flatten_result = JUST(functional::Flatten(result, 0, -1));\n        auto flatten_operand = JUST(functional::Flatten(operand, 0, -1));\n        result = JUST(functional::Dot(flatten_result, flatten_operand));\n      } else {\n        result = JUST(sumproduct_pair(result, operand, sum_dims, false));\n      }\n\n      // NOTE(Liang Depeng): Continue explaining the equation \"ik,jkl,il->ij\".\n      //                     What is going on within this foor loop?\n      //                     For iter i = 1:\n      //                        result = permuted_operands[0], size = [2, 1, 3, 1]\n      //                        operand = permuted_operands[1], size = [1, 4, 3, 5]\n      //                        sum_dims = [2, ]\n      //                        what happened in `sumproduct_pair` ?\n      //                            result [2, 1, 3, 1] will be permuted to [2, 3, 1, 1] then\n      //                                reshaped to [1, 2, 3]\n      //                            operand [1, 4, 3, 5] will be permuted to [3, 4, 5, 1] then\n      //                                reshape to [1, 3, 4 * 5]\n      //                            perform batch_matmul(result, operand) => [1, 2, 4 * 5]\n      //                            then reshape to [2, 1, 4, 5] then permute to\n      //                            [2, 4, 1, 5], at last reshape to [2, 4, 5]\n      //\n      //                     For iter i = 2:\n      //                        result, size = [2, 4, 5]\n      //                        operand = permuted_operands[2], size = [2, 1, 1, 5]\n      //                        squeeze operand from [2, 1, 1, 5] to [2, 1, 5]\n      //                        sum_dims = [2,]\n      //                        what happened in `sumproduct_pair` ?\n      //                            result [2, 4, 5] will be permuted to [2, 4, 5] then\n      //                                reshaped to [2, 4, 5]\n      //                            operand [2, 1, 5] will be permuted to [2, 5, 1] then\n      //                                reshape to [2, 5, 1]\n      //                            perform batch_matmul(result, operand)=>[2, 4, 1]\n      //                            then reshape to [2, 4, 1] then permute to [2, 4, 1]\n      //                            at last reshape to [2, 4]\n    }\n    return result;\n  }\n};\n\nclass TruncFunctor {\n public:\n  TruncFunctor() { op_ = 
CHECK_JUST(one::OpBuilder(\"trunc\").Input(\"in\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\nclass AddCDivFunctor {\n public:\n  AddCDivFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& tensor1,\n                           const std::shared_ptr<one::Tensor>& tensor2, const Scalar& value) const {\n    return JUST(Add(input, JUST(ScalarMul(JUST(Div(tensor1, tensor2)), value, false)), 1, false));\n  }\n};\n\nclass InplaceAddCDivFunctor {\n public:\n  InplaceAddCDivFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& tensor1,\n                           const std::shared_ptr<one::Tensor>& tensor2, const Scalar& value) const {\n    JUST(CheckInplaceValid(input));\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    JUST(VectorAt(*outputs, 0)) = input;\n    JUST(Add(input, JUST(ScalarMul(JUST(Div(tensor1, tensor2)), value, false)), 1, true));\n    return JUST(VectorAt(*outputs, 0));\n  }\n};\n\nnamespace {\nconstexpr int64_t cufft_max_ndim =\n    3;  // must keep Equal to `oneflow/user/kernels/cufft_plan_cache.h:max_rank`\nenum class fft_norm_mode {\n  none = 0,   // No normalization\n  by_root_n,  // Divide by sqrt(signal_size)\n  by_n,       // Divide by signal_size\n};\n\nbool use_optimized_cufft_path(const std::vector<int64_t>& fft_dims) {\n  // For performance reason, when dim starts with (0, 1), do not use the optimized path.\n  if (fft_dims.size() > cufft_max_ndim\n      || (fft_dims.size() >= 2 && fft_dims[0] == 0 && fft_dims[1] == 1)) {\n    return false;\n  } else {\n    return true;\n  }\n}\n\n// Convert NumPy compatible normalization mode string to enum values\n// In Numpy, \"forward\" translates to `by_n` for a forward transform and `none` for backward.\nstatic fft_norm_mode fft_norm_from_string(const Optional<std::string>& norm_op, bool forward) {\n  std::string norm_str = norm_op.value_or(\"backward\");\n  if (norm_str == \"backward\") {\n    return forward ? fft_norm_mode::none : fft_norm_mode::by_n;\n  } else if (norm_str == \"forward\") {\n    return forward ? 
fft_norm_mode::by_n : fft_norm_mode::none;\n  } else if (norm_str == \"ortho\") {\n    return fft_norm_mode::by_root_n;\n  }\n\n  return fft_norm_mode::none;\n}\n\ntemplate<typename T>\nstatic T fft_compute_fct(int64_t size, fft_norm_mode normalization) {\n  constexpr auto one = static_cast<T>(1);\n  switch (normalization) {\n    case fft_norm_mode::none: return one;\n    case fft_norm_mode::by_n: return one / static_cast<T>(size);\n    case fft_norm_mode::by_root_n: return one / std::sqrt(static_cast<T>(size));\n  }\n  return static_cast<T>(0);\n}\n\ntemplate<typename T>\nstatic T fft_compute_fct(const Shape& in_shape, const std::vector<int64_t>& dims,\n                         fft_norm_mode normalization) {\n  if (normalization == fft_norm_mode::none) { return static_cast<T>(1); }\n  int64_t n = 1;\n  for (int64_t idx : dims) { n *= in_shape.At(idx); }\n  return fft_compute_fct<T>(n, normalization);\n}\n}  // namespace\n\nclass FftBaseFunctor {\n public:\n  explicit FftBaseFunctor() {}\n  explicit FftBaseFunctor(std::string op_name) {\n    op_ = CHECK_JUST(one::OpBuilder(op_name).Input(\"input\").Output(\"out\").Build());\n  }\n  virtual ~FftBaseFunctor() = default;\n\n  Maybe<Tensor> resize_fft_input(const std::shared_ptr<one::Tensor>& x,\n                                 const std::vector<int64_t>& dims,\n                                 const std::vector<int64_t>& sizes) const {\n    CHECK_EQ_OR_THROW(dims.size(), sizes.size()) << \"dims.size() != sizes.size().\";\n    bool must_copy = false;\n    auto x_sizes = x->shape()->dim_vec();\n    std::vector<int64_t> pad_amount(x_sizes.size() * 2);\n    std::vector<int64_t> slice_st(x_sizes.size());\n    std::vector<int64_t> slice_end(x_sizes.size());\n    std::vector<int64_t> slice_step(x_sizes.size(), 1);\n\n    FOR_RANGE(int64_t, i, 0, x_sizes.size()) {\n      slice_st[i] = 0;\n      slice_end[i] = x_sizes[i];\n    }\n\n    FOR_RANGE(int64_t, i, 0, sizes.size()) {\n      if (sizes[i] == -1) { continue; }\n\n      if (x_sizes[dims[i]] < sizes[i]) {\n        must_copy = true;\n        auto pad_idx = pad_amount.size() - 2 * dims[i] - 1;\n        pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]];\n      }\n\n      if (x_sizes[dims[i]] > sizes[i]) {\n        // slice in dims[i]\n        slice_end[dims[i]] = sizes[i];\n      }\n    }\n\n    auto sliced_tenosr = JUST(functional::Slice(x, slice_st, slice_end, slice_step, false));\n    return must_copy ? 
functional::ConstantPad(sliced_tenosr, pad_amount, 0) : sliced_tenosr;\n  }\n\n  Maybe<Symbol<DType>> promote_type_fft(Symbol<DType> type, bool require_complex = false) const {\n    if (type->is_complex()) { return type; }\n\n    if (!type->is_floating_point()) { type = GetDefaultDType(); }\n    CHECK_OR_RETURN(type->data_type() == kFloat || type->data_type() == kDouble)\n        << \"Unsupported dtype \" << type->name() << \", \"\n        << \"only kFloat and kDouble are supported\";\n\n    if (!require_complex) { return type; }\n\n    switch (type->data_type()) {\n      //  TO-DO: add kFloat16\n      case (kFloat): return CHECK_JUST(DType::Get(DataType::kComplex64));\n      case (kDouble): return CHECK_JUST(DType::Get(DataType::kComplex128));\n      default: CHECK_OR_RETURN(false) << \"RuntimeError: dtype can't be handled\";\n    }\n    CHECK_OR_RETURN(false) << \"RuntimeError: dtype can't be handled\";\n  }\n\n  Maybe<Tensor> promote_tensor_fft(const std::shared_ptr<Tensor>& x,\n                                   bool require_complex = false) const {\n    auto cur_type = x->dtype();\n    auto new_type = JUST(promote_type_fft(cur_type, require_complex));\n    if (cur_type->data_type() == new_type->data_type()) {\n      return x;\n    } else {\n      TensorProcessor tensor_processor;\n      JUST(tensor_processor.AddInputs({x}, {new_type}).Apply());\n      return JUST(oneflow::VectorAt(JUST(tensor_processor.GetInputs()), 0));\n    }\n  }\n\n  Maybe<void> maybe_wrap_dims(std::vector<int64_t>& dims, int64_t dim_post_expr,\n                              bool wrap_scalar = true) const {\n    if (dim_post_expr <= 0) {\n      if (!wrap_scalar) {\n        CHECK_OR_RETURN(false) << \"RuntimeError: dimension specified as \" << dims[0]\n                               << \" but tensor has no dimensions\";\n      }\n      dim_post_expr = 1;  // this will make range [-1, 0]\n    }\n\n    int64_t min = -dim_post_expr;\n    int64_t max = dim_post_expr - 1;\n    for (auto& dim : dims) {\n      if (dim < min || dim > max) {\n        CHECK_OR_RETURN(false)\n            << \"RuntimeError: Dimension out of range (expected to be in range of [\" << min << \", \"\n            << max << \"], but got \" << dim << \")\";\n      }\n      if (dim < 0) dim += dim_post_expr;\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> calculate_fftn_shape_and_dims(const std::shared_ptr<Tensor>& x,\n                                            const Optional<std::vector<int64_t>>& n,\n                                            const Optional<std::vector<int64_t>>& dims,\n                                            std::vector<int64_t>& fft_shape,\n                                            std::vector<int64_t>& fft_dims) const {\n    if (dims.has_value()) {\n      fft_dims = *JUST(dims);\n      JUST(maybe_wrap_dims(fft_dims, x->ndim()));\n      std::vector<int64_t> copy = fft_dims;\n      std::sort(copy.begin(), copy.end());\n      auto duplicate = std::adjacent_find(copy.begin(), copy.end());\n      CHECK_OR_RETURN(duplicate == copy.end()) << \"RuntimeError: FFT dims must be unique\";\n    } else {\n      fft_dims.resize(x->ndim());\n      for (int i = 0; i < x->ndim(); i++) { fft_dims[i] = i; }\n    }\n\n    if (!n.has_value()) {\n      fft_shape.resize(fft_dims.size());\n      for (int i = 0; i < fft_dims.size(); i++) { fft_shape[i] = x->dim(fft_dims[i]); }\n    } else {\n      fft_shape = *JUST(n);\n      if (dims.has_value()) {\n        // got n, also got dim\n
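        // A size of -1 in n means: keep the full input extent along that dim.\n        for (int i = 0; i < fft_dims.size(); i++) {\n          if 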
(fft_shape[i] == -1) { fft_shape[i] = x->dim(fft_dims[i]); }\n        }\n      } else {\n        // got n, but not got dim\n        fft_dims.resize(fft_shape.size());\n        FOR_RANGE(size_t, i, 0, fft_dims.size()) { fft_dims[i] = x->ndim() - fft_dims.size() + i; }\n      }\n    }\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> parse_input_n_and_dims(const std::shared_ptr<Tensor>& x,\n                                     const Optional<std::vector<int64_t>>& n,\n                                     const Optional<std::vector<int64_t>>& dims,\n                                     std::vector<int64_t>& fft_len,\n                                     std::vector<int64_t>& wrapped_dims) const {\n    if (n.has_value() && dims.has_value()) {\n      CHECK_OR_RETURN((*JUST(n)).size() == (*JUST(dims)).size())\n          << \"RuntimeError: When dim and shape were both given, they must have the same length\";\n    }\n    wrapped_dims.resize(x->ndim());\n    fft_len.resize(x->ndim());\n    if (dims.has_value() && (*JUST(dims)).size() == 1) {\n      // 1D-discrete fourier transform\n      wrapped_dims = *JUST(dims);\n      JUST(maybe_wrap_dims(wrapped_dims, x->ndim()));\n      fft_len.resize(wrapped_dims.size());\n      fft_len[0] = n.has_value() == true ? (*JUST(n))[0] : x->dim(wrapped_dims[0]);\n      if (fft_len[0] == -1) { fft_len[0] = x->dim(wrapped_dims[0]); }\n      CHECK_OR_RETURN(fft_len[0] >= 1) << \"RuntimeError: Expected n >= 1, but got \" << fft_len[0];\n    } else if (n.has_value() && JUST(n)->size() == 1) {\n      // 1D-discrete fourier transform\n      fft_len = *(JUST(n));\n      if (fft_len[0] == -1) { fft_len[0] = x->shape()->back(); }\n      CHECK_OR_RETURN(fft_len[0] >= 1) << \"RuntimeError: Expected n >= 1, but got \" << fft_len[0];\n      wrapped_dims.resize(1);\n      wrapped_dims[0] = x->ndim() - 1;\n    } else {\n      // ND-discrete fourier transform\n      JUST(calculate_fftn_shape_and_dims(x, n, dims, fft_len, wrapped_dims));\n    }\n\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<Tensor> permute_and_reshape(const std::shared_ptr<Tensor>& self,\n                                    const std::vector<int64_t>& out_sizes,\n                                    const std::vector<int64_t>& fft_dims,\n                                    std::vector<int64_t>& out_strides) const {\n    // Permute and reshape `self` Tensor.\n    // This can maximizes data locality\n    const int64_t ndim = self->ndim();\n    const int64_t fft_ndim = fft_dims.size();\n    const int64_t batch_dims = ndim - fft_ndim;\n    const auto& in_stride = JUST(self->stride());\n    // Permute dimensions to make batch dims come first, and this maximizes data locality\n    std::vector<int32_t> dim_permute(ndim);\n    std::iota(dim_permute.begin(), dim_permute.end(), int32_t(0));\n    std::vector<bool> is_transformed_dim(ndim, false);\n    for (const auto& dim : fft_dims) { is_transformed_dim[dim] = true; }\n\n    auto batch_end = std::partition(dim_permute.begin(), dim_permute.end(),\n                                    [&](int64_t d) { return !is_transformed_dim[d]; });\n    std::sort(dim_permute.begin(), batch_end,\n              [&](int64_t a, int64_t b) { return in_stride->at(a) > in_stride->at(b); });\n    std::copy(fft_dims.begin(), fft_dims.end(), batch_end);\n\n    // permute\n    auto input = JUST(functional::Permute(self, dim_permute));\n\n    std::vector<int64_t> batched_sizes(fft_ndim + 1);\n    batched_sizes[0] = -1;\n    std::copy(input->shape()->begin() + batch_dims, input->shape()->end(),\n   
    std::vector<int64_t> batched_sizes(fft_ndim + 1);\n    batched_sizes[0] = -1;\n    std::copy(input->shape()->begin() + batch_dims, input->shape()->end(),\n              batched_sizes.begin() + 1);\n    // reshape\n    Shape batched_shape(batched_sizes);\n    input = JUST(functional::Reshape(input, batched_shape));\n\n    const auto batch_size = input->shape()->At(0);\n\n    batched_sizes[0] = batch_size;\n    std::vector<int64_t> batched_out_sizes(batched_sizes.begin(), batched_sizes.end());\n    FOR_RANGE(int64_t, i, 0, fft_dims.size()) { batched_out_sizes[i + 1] = out_sizes[fft_dims[i]]; }\n\n    // Compute output strides that restore the original batch shape and invert the dimension\n    // permutation\n    out_strides.resize(ndim, 0);\n\n    int64_t batch_numel = 1;\n    Stride contiguous_out_strides = Stride(batched_out_sizes);\n    for (int64_t i = batch_dims - 1; i >= 0; --i) {\n      out_strides[dim_permute[i]] = batch_numel * contiguous_out_strides[0];\n      batch_numel *= out_sizes[dim_permute[i]];\n    }\n    FOR_RANGE(int64_t, i, batch_dims, ndim) {\n      out_strides[dim_permute[i]] = contiguous_out_strides[1 + (i - batch_dims)];\n    }\n\n    // Determine whether the input needs to be cloned (made contiguous)\n    int64_t signal_ndim = input->shape()->size() - 1;\n    const Stride& batched_input_strides = *(JUST(input->stride()));\n    auto last_stride = JUST(oneflow::VectorAt(batched_input_strides, signal_ndim));\n    bool must_clone_input = false;\n    if (JUST(oneflow::VectorAt(batched_input_strides, 0)) == 0) { must_clone_input = true; }\n    for (auto i = signal_ndim - 1; !must_clone_input && i > 0; i--) {\n      auto stride = JUST(oneflow::VectorAt(batched_input_strides, i));\n      if (JUST(oneflow::VectorAt(*(input->shape()), i)) == 1) {\n        continue;\n      } else if (stride > 0 && stride % last_stride == 0) {\n        last_stride = stride;\n      } else {\n        must_clone_input = true;\n      }\n    }\n\n    if (must_clone_input) { input = JUST(functional::ToContiguous(input)); }\n    return input;\n  }\n\n  Maybe<void> parse_c2r_input_n_and_dims(const std::shared_ptr<Tensor>& x,\n                                         const Optional<std::vector<int64_t>>& n,\n                                         const Optional<std::vector<int64_t>>& dims,\n                                         int64_t& last_dim_size, std::vector<int64_t>& fft_len,\n                                         std::vector<int64_t>& wrapped_dims) const {\n    JUST(parse_input_n_and_dims(x, n, dims, fft_len, wrapped_dims));\n    // infer last_dim_size\n    last_dim_size = 0;\n    if (!n.has_value() || JUST(n)->back() == -1) {\n      int64_t last_dim = wrapped_dims.back();\n      last_dim_size = 2 * (x->dim(last_dim) - 1);\n    } else {\n      last_dim_size = JUST(n)->back();\n    }\n    CHECK_OR_RETURN(last_dim_size >= 1)\n        << \"RuntimeError: Invalid last_dim_size (\" << last_dim_size << \") specified\";\n    fft_len.back() = last_dim_size / 2 + 1;\n\n    return Maybe<void>::Ok();\n  }\n\n protected:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FftC2CFunctor : public FftBaseFunctor {\n public:\n  FftC2CFunctor() : FftBaseFunctor(\"fft_c2c\") {}\n
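  // A sketch of the CUDA dispatch below: cuFFT handles at most `cufft_max_ndim` axes per\n  // launch, so the loop repeatedly peels off the (up to) `cufft_max_ndim` innermost remaining\n  // transform dims (sorted by descending stride), dispatches the op on them, and continues\n  // until `sorted_dims` is empty, e.g. five transform dims with cufft_max_ndim == 3 take two\n  // dispatches (3 + 2 dims).\n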
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const Optional<std::vector<int64_t>>& n,\n                           const Optional<std::vector<int64_t>>& dims, int32_t norm_mode,\n                           bool forward, bool normalized) const {\n    // NOTE: The parameter `normalized` indicates whether the FFT result needs to be normalized\n    // with `ScalarMul`. It only takes effect on CUDA devices; on CPU it is ignored because the\n    // cpu fft operator normalizes internally according to `forward` and the transform type.\n\n    CHECK_OR_RETURN(x->dtype()->is_complex())\n        << \"RuntimeError: expected the dtype of the input Tensor to be Complex, but got \"\n        << x->dtype()->name();\n    std::vector<int64_t> fft_len(x->ndim(), 0);\n    std::vector<int64_t> wrapped_dims(x->ndim(), 0);\n\n    JUST(parse_input_n_and_dims(x, n, dims, fft_len, wrapped_dims));\n    auto resized_tensor = n.has_value() ? JUST(resize_fft_input(x, wrapped_dims, fft_len)) : x;\n\n    DeviceType input_device{};\n    if (x->is_global()) {\n      input_device = JUST(x->parallel_desc())->device_type();\n    } else {\n      input_device = JUST(x->device())->enum_type();\n    }\n\n    double norm_fct = fft_compute_fct<double>(*(resized_tensor->shape()), wrapped_dims,\n                                              static_cast<fft_norm_mode>(norm_mode));\n\n    if (input_device == DeviceType::kCPU) {\n      auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dims\", \"forward\", \"norm_mode\", \"norm_fct\");\n      attrs.SetAllAttrs(wrapped_dims, forward, norm_mode, norm_fct);\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {resized_tensor}, attrs);\n    } else if (input_device == DeviceType::kCUDA) {\n      if (wrapped_dims.empty()) { return resized_tensor; }\n      std::vector<int64_t> out_sizes(resized_tensor->shape()->dim_vec().begin(),\n                                     resized_tensor->shape()->dim_vec().end());\n      std::vector<int64_t> sorted_dims(wrapped_dims.begin(), wrapped_dims.end());\n      auto working_tensor = resized_tensor;\n      std::vector<int64_t> out_strides;\n      std::shared_ptr<Tensor> output;\n      while (true) {\n        // Sort dimensions by descending stride every iteration\n        auto strides = *JUST(working_tensor->stride());\n        std::sort(sorted_dims.begin(), sorted_dims.end(),\n                  [&](int64_t a, int64_t b) { return strides[a] > strides[b]; });\n\n        const auto max_dims = std::min(static_cast<size_t>(cufft_max_ndim), sorted_dims.size());\n        auto first_dims_end = sorted_dims.end();\n        auto first_dims_begin = first_dims_end - max_dims;\n        std::vector<int64_t> first_dims(first_dims_begin, first_dims_end);\n\n        auto input = JUST(permute_and_reshape(working_tensor, out_sizes, first_dims, out_strides));\n\n        std::vector<int64_t> fft_dims(input->ndim() - 1);  // must be >= 1\n        std::iota(fft_dims.begin(), fft_dims.end(), int64_t(1));\n        auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dims\", \"forward\", \"norm_mode\", \"norm_fct\");\n        attrs.SetAllAttrs(fft_dims, forward, norm_mode, norm_fct);\n        output = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs));\n        output = JUST(\n            functional::AsStrided(output, out_sizes, out_strides, JUST(output->storage_offset())));\n\n        sorted_dims.resize(sorted_dims.size() - max_dims);\n\n        if (sorted_dims.empty()) { break; }\n        working_tensor = std::move(output);\n      }\n\n      if (normalized) { JUST(functional::ScalarMul(output, Scalar(norm_fct), true)); }\n\n      return output;\n    } else {\n      CHECK_OR_RETURN(false) << \"RuntimeError: FftC2C only supports cpu and cuda devices.\";\n    }\n  }\n};\n\nclass FftR2CFunctor : public FftBaseFunctor {\n public:\n  FftR2CFunctor() : FftBaseFunctor(\"fft_r2c\") {}\n\n
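  // On CUDA, a multi-dim R2C is decomposed below into a one-sided R2C over the last\n  // transformed dim followed by C2C transforms over the remaining dims (unless the optimized\n  // cuFFT path applies); the CPU kernel handles all transformed dims in one dispatch.\n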
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const Optional<std::vector<int64_t>>& n,\n                           const Optional<std::vector<int64_t>>& dims, int32_t norm_mode,\n                           bool onesided, bool forward, bool normalized) const {\n    // NOTE: The parameter `normalized` indicates whether the FFT result needs to be normalized\n    // with `ScalarMul`. It only takes effect on CUDA devices; on CPU it is ignored because the\n    // cpu fft operator normalizes internally according to `forward` and the transform type.\n\n    CHECK_OR_RETURN(!(x->dtype()->is_complex()))\n        << \"RuntimeError: expected the dtype of the input Tensor to be Real, but got \"\n        << x->dtype()->name();\n\n    auto input_tensor = JUST(promote_tensor_fft(x));\n\n    if (n.has_value() && dims.has_value()) {\n      CHECK_OR_RETURN((*JUST(n)).size() == (*JUST(dims)).size())\n          << \"RuntimeError: When dim and shape were both given, they must have the same length\";\n    }\n\n    std::vector<int64_t> fft_len(input_tensor->ndim(), 0);\n    std::vector<int64_t> wrapped_dims(input_tensor->ndim(), 0);\n    JUST(parse_input_n_and_dims(input_tensor, n, dims, fft_len, wrapped_dims));\n    auto resized_tensor = n.has_value()\n                              ? JUST(resize_fft_input(input_tensor, wrapped_dims, fft_len))\n                              : input_tensor;\n    DeviceType input_device{};\n    if (x->is_global()) {\n      input_device = JUST(x->parallel_desc())->device_type();\n    } else {\n      input_device = JUST(x->device())->enum_type();\n    }\n\n    double norm_fct = fft_compute_fct<double>(*(resized_tensor->shape()), wrapped_dims,\n                                              static_cast<fft_norm_mode>(norm_mode));\n\n    std::shared_ptr<Tensor> output;\n    if (input_device == DeviceType::kCPU) {\n      auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dims\", \"norm_mode\", \"norm_fct\", \"onesided\");\n      attrs.SetAllAttrs(wrapped_dims, norm_mode, norm_fct, onesided);\n      output = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {resized_tensor}, attrs));\n    } else if (input_device == DeviceType::kCUDA) {\n      std::vector<int64_t> input_sizes(resized_tensor->shape()->begin(),\n                                       resized_tensor->shape()->end());\n      std::vector<int64_t> onesided_sizes = input_sizes;\n      int64_t last_dim = wrapped_dims.back();\n      int64_t last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;\n      onesided_sizes[last_dim] = last_dim_halfsize;\n
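      // The one-sided layout keeps only the non-redundant half of the spectrum: for a real\n      // signal, bins k and n - k are complex conjugates, so e.g. n = 8 yields 8 / 2 + 1 = 5\n      // unique bins along the last transformed dim.\n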
      std::vector<int64_t> out_sizes = onesided ? onesided_sizes : input_sizes;\n\n      if (use_optimized_cufft_path(wrapped_dims)) {\n        std::vector<int64_t> out_strides;\n        auto input =\n            JUST(permute_and_reshape(resized_tensor, out_sizes, wrapped_dims, out_strides));\n\n        std::vector<int64_t> fft_dims(input->ndim() - 1);  // must be >= 1\n        std::iota(fft_dims.begin(), fft_dims.end(), int64_t(1));\n\n        auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dims\", \"norm_mode\", \"norm_fct\", \"onesided\");\n        attrs.SetAllAttrs(fft_dims, norm_mode, norm_fct, onesided);\n        output = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs));\n        output = JUST(\n            functional::AsStrided(output, out_sizes, out_strides, JUST(output->storage_offset())));\n      } else {\n        // First do the **one-sided** R2C transform on the last dimension\n        const std::shared_ptr<Tensor>& working_tensor = resized_tensor;\n        {\n          std::vector<int64_t> out_strides;\n          auto input = JUST(\n              permute_and_reshape(/*self=*/working_tensor, /*out_sizes=*/onesided_sizes,\n                                  /*fft_dims=*/{wrapped_dims.back()}, /*out_strides=*/out_strides));\n          auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dims\", \"norm_mode\", \"norm_fct\", \"onesided\");\n          int64_t last_dim = input->shape()->size() - 1;\n          std::vector<int64_t> fft_last_dim_vec = {last_dim};\n          attrs.SetAllAttrs(fft_last_dim_vec, norm_mode, norm_fct, /*onesided=*/true);\n          output = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs));\n          output = JUST(functional::AsStrided(output, out_sizes, out_strides,\n                                              JUST(output->storage_offset())));\n        }\n\n        // Then any remaining C2C transforms\n        std::vector<int64_t> sorted_dims(wrapped_dims.begin(), wrapped_dims.end() - 1);\n        if (!sorted_dims.empty()) {\n          output = JUST(functional::FftC2C(output, NullOpt, sorted_dims, norm_mode,\n                                           /*forward=*/true, /*normalized=*/false));\n        }\n      }\n\n      if (normalized) { JUST(functional::ScalarMul(output, Scalar(norm_fct), true)); }\n\n    } else {\n      CHECK_OR_RETURN(false) << \"RuntimeError: FftR2C only supports cpu and cuda devices.\";\n    }\n\n    if (!forward) {\n      return functional::ConjPhysical(output);\n    } else {\n      return output;\n    }\n  }\n};\n\nclass FftC2RFunctor : public FftBaseFunctor {\n public:\n  FftC2RFunctor() : FftBaseFunctor(\"fft_c2r\") {}\n\n
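  // C2R consumes a Hermitian one-sided spectrum and produces a real tensor. When `n` is not\n  // given, the real output length of the last transformed dim is recovered as\n  // 2 * (input_size - 1), e.g. 5 one-sided bins map back to 8 real samples.\n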
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const Optional<std::vector<int64_t>>& n,\n                           const Optional<std::vector<int64_t>>& dims, int32_t norm_mode,\n                           bool forward, bool normalized) const {\n    // NOTE: The parameter `normalized` indicates whether the FFT result needs to be normalized\n    // with `ScalarMul`. It only takes effect on CUDA devices; on CPU it is ignored because the\n    // cpu fft operator normalizes internally according to `forward` and the transform type.\n\n    CHECK_OR_RETURN(x->dtype()->is_complex())\n        << \"RuntimeError: expected the dtype of the input Tensor to be Complex, but got \"\n        << x->dtype()->name();\n\n    if (n.has_value() && dims.has_value()) {\n      CHECK_OR_RETURN((*JUST(n)).size() == (*JUST(dims)).size())\n          << \"RuntimeError: When dim and shape were both given, they must have the same length\";\n    }\n\n    std::vector<int64_t> wrapped_dims(x->ndim(), 0);\n    std::vector<int64_t> fft_len(x->ndim(), 0);\n    int64_t last_dim_size = 0;\n    JUST(parse_c2r_input_n_and_dims(x, n, dims, last_dim_size, fft_len, wrapped_dims));\n\n    auto resized_tensor = n.has_value() ? JUST(resize_fft_input(x, wrapped_dims, fft_len)) : x;\n\n    Shape out_shape = *(resized_tensor->shape());\n    out_shape[wrapped_dims.back()] = last_dim_size;\n    double norm_fct =\n        fft_compute_fct<double>(out_shape, wrapped_dims, static_cast<fft_norm_mode>(norm_mode));\n\n    if (forward) { resized_tensor = JUST(functional::ConjPhysical(resized_tensor)); }\n\n    DeviceType input_device{};\n    if (x->is_global()) {\n      input_device = JUST(x->parallel_desc())->device_type();\n    } else {\n      input_device = JUST(x->device())->enum_type();\n    }\n\n    if (input_device == DeviceType::kCPU) {\n      auto& attrs =\n          THREAD_CACHED_MUTABLE_ATTR_MAP(\"dims\", \"norm_mode\", \"norm_fct\", \"last_dim_size\");\n      attrs.SetAllAttrs(wrapped_dims, norm_mode, norm_fct, last_dim_size);\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {resized_tensor}, attrs);\n    } else if (input_device == DeviceType::kCUDA) {\n      std::shared_ptr<Tensor> output;\n      if (use_optimized_cufft_path(wrapped_dims)) {\n        auto input = JUST(functional::ToContiguous(resized_tensor));\n        std::vector<int64_t> out_sizes(out_shape.dim_vec().begin(), out_shape.dim_vec().end());\n        std::vector<int64_t> out_strides;\n        input = JUST(permute_and_reshape(input, out_sizes, wrapped_dims, out_strides));\n\n        std::vector<int64_t> fft_dims(input->ndim() - 1);  // must be >= 1\n        std::iota(fft_dims.begin(), fft_dims.end(), int64_t(1));\n\n        auto& attrs =\n            THREAD_CACHED_MUTABLE_ATTR_MAP(\"dims\", \"norm_mode\", \"norm_fct\", \"last_dim_size\");\n        attrs.SetAllAttrs(fft_dims, norm_mode, norm_fct, last_dim_size);\n        output = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs));\n        output = JUST(\n            functional::AsStrided(output, out_sizes, out_strides, JUST(output->storage_offset())));\n      } else {\n        // First complete any C2C transforms\n        std::shared_ptr<Tensor> temp;\n        if (wrapped_dims.size() > 1) {\n          std::vector<int64_t> any_c2c_dims(wrapped_dims.begin(), wrapped_dims.end() - 1);\n          temp = JUST(functional::FftC2C(resized_tensor, NullOpt, any_c2c_dims,\n                                         static_cast<int32_t>(fft_norm_mode::none),\n                                         /*forward=*/false, /*normalized=*/false));\n        } else {\n          temp = JUST(functional::ToContiguous(resized_tensor));\n        }\n\n        // Finally, do the 1D C2R transform on the last dim\n        std::vector<int64_t> out_strides;\n        std::vector<int64_t> out_sizes(out_shape.dim_vec().begin(), out_shape.dim_vec().end());\n
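        // E.g. for wrapped_dims = {1, 2}, the C2C above handles dim 1 and the single-dim C2R\n        // below handles dim 2: the one-sided (halved) layout only exists on the last\n        // transformed dim.\n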
        auto input = JUST(permute_and_reshape(/*self=*/temp, /*out_sizes=*/out_sizes,\n                                              /*fft_dims=*/{wrapped_dims.back()},\n                                              /*out_strides=*/out_strides));\n\n        auto& attrs =\n            THREAD_CACHED_MUTABLE_ATTR_MAP(\"dims\", \"norm_mode\", \"norm_fct\", \"last_dim_size\");\n        int64_t last_dim = input->shape()->size() - 1;\n        std::vector<int64_t> fft_last_dim_vec = {last_dim};\n        attrs.SetAllAttrs(fft_last_dim_vec, norm_mode, norm_fct, /*last_dim_size=*/last_dim_size);\n\n        output = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs));\n        output = JUST(\n            functional::AsStrided(output, out_sizes, out_strides, JUST(output->storage_offset())));\n      }\n\n      if (normalized) { JUST(functional::ScalarMul(output, Scalar(norm_fct), /*inplace=*/true)); }\n      return output;\n    } else {\n      CHECK_OR_RETURN(false) << \"RuntimeError: FftC2R only supports cpu and cuda devices.\";\n    }\n  }\n};\n\nclass FftFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, int64_t n, int64_t dim,\n                           const Optional<std::string>& norm) const {\n    std::string norm_str = norm.value_or(\"backward\");\n    std::vector<int64_t> fft_dim{dim};\n\n    bool forward = true;\n    fft_norm_mode norm_mode = fft_norm_mode::none;\n    norm_mode = fft_norm_from_string(norm_str, forward);\n\n    std::vector<int64_t> len{n};\n    return input->dtype()->is_complex()\n               ? functional::FftC2C(input, len, fft_dim, static_cast<int32_t>(norm_mode),\n                                    /*forward=*/forward, /*normalized=*/true)\n               : functional::FftR2C(input, len, fft_dim, static_cast<int32_t>(norm_mode),\n                                    /*onesided=*/false, /*forward=*/forward, /*normalized=*/true);\n  }\n};\n\nclass IFftFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, int64_t n, int64_t dim,\n                           const Optional<std::string>& norm) const {\n    auto norm_str = norm.value_or(\"backward\");\n    std::vector<int64_t> fft_dim{dim};\n\n    bool forward = false;\n    fft_norm_mode norm_mode = fft_norm_mode::none;\n    norm_mode = fft_norm_from_string(norm_str, forward);\n    std::vector<int64_t> len{n};\n    return input->dtype()->is_complex()\n               ? 
functional::FftC2C(input, len, fft_dim, static_cast<int32_t>(norm_mode),\n                                    /*forward=*/forward, /*normalized=*/true)\n               : functional::FftR2C(input, len, fft_dim, static_cast<int32_t>(norm_mode),\n                                    /*onesided=*/false, /*forward=*/forward, /*normalized=*/true);\n  }\n};\n\nclass Fft2Functor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const Optional<std::vector<int64_t>>& s, const std::vector<int64_t>& dim,\n                           const Optional<std::string>& norm) const {\n    return functional::FftN(input, s, dim, norm);\n  }\n};\n\nclass IFft2Functor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const Optional<std::vector<int64_t>>& s, const std::vector<int64_t>& dim,\n                           const Optional<std::string>& norm) const {\n    return functional::IFftN(input, s, dim, norm);\n  }\n};\n\nclass FftNFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const Optional<std::vector<int64_t>>& s,\n                           const Optional<std::vector<int64_t>>& dim,\n                           const Optional<std::string>& norm) const {\n    std::string norm_str = norm.value_or(\"backward\");\n    bool forward = true;\n    fft_norm_mode norm_mode = fft_norm_mode::none;\n    norm_mode = fft_norm_from_string(norm_str, forward);\n\n    if (!(input->dtype()->is_complex())) {\n      // cast to complex\n      TensorProcessor tensor_processor;\n      Symbol<DType> complex_dtype;\n      if (input->dtype() == DType::Double()) {\n        complex_dtype = DType::Complex128();\n      } else {\n        complex_dtype = DType::Complex64();\n      }\n      JUST(tensor_processor.AddInputs({input}, {complex_dtype}).Apply());\n      TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n      return functional::FftC2C(JUST(oneflow::VectorAt(input_tuple, 0)), s, dim,\n                                static_cast<int32_t>(norm_mode), /*forward=*/forward,\n                                /*normalized=*/true);\n    } else {\n      return functional::FftC2C(input, s, dim, static_cast<int32_t>(norm_mode), /*forward=*/forward,\n                                /*normalized=*/true);\n    }\n  }\n};\n\nclass IFftNFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const Optional<std::vector<int64_t>>& s,\n                           const Optional<std::vector<int64_t>>& dim,\n                           const Optional<std::string>& norm) const {\n    std::string norm_str = norm.value_or(\"backward\");\n    bool forward = false;\n    fft_norm_mode norm_mode = fft_norm_mode::none;\n    norm_mode = fft_norm_from_string(norm_str, forward);\n\n    if (!(input->dtype()->is_complex())) {\n      // cast to complex\n      TensorProcessor tensor_processor;\n      Symbol<DType> complex_dtype;\n      if (input->dtype() == DType::Double()) {\n        complex_dtype = DType::Complex128();\n      } else {\n        complex_dtype = DType::Complex64();\n      }\n      JUST(tensor_processor.AddInputs({input}, {complex_dtype}).Apply());\n      TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n      return functional::FftC2C(JUST(oneflow::VectorAt(input_tuple, 0)), s, dim,\n                                static_cast<int32_t>(norm_mode), /*forward=*/forward,\n              
                  /*normalized=*/true);\n    } else {\n      return functional::FftC2C(input, s, dim, static_cast<int32_t>(norm_mode), /*forward=*/forward,\n                                /*normalized=*/true);\n    }\n  }\n};\n\nclass RFftFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, int64_t n, int64_t dim,\n                           const Optional<std::string>& norm) const {\n    CHECK_OR_RETURN(!(input->dtype()->is_complex()))\n        << \"RuntimeError: expected the dtype of the input Tensor to be Real, but got \"\n        << input->dtype()->name();\n\n    std::string norm_str = norm.value_or(\"backward\");\n    std::vector<int64_t> fft_dim{dim};\n    bool forward = true;\n    fft_norm_mode norm_mode = fft_norm_mode::none;\n    norm_mode = fft_norm_from_string(norm_str, forward);\n\n    std::vector<int64_t> len{n};\n    return functional::FftR2C(input, len, fft_dim, static_cast<int32_t>(norm_mode),\n                              /*onesided=*/true, /*forward=*/forward, /*normalized=*/true);\n  }\n};\n\nclass IRFftFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, int64_t n, int64_t dim,\n                           const Optional<std::string>& norm) const {\n    std::string norm_str = norm.value_or(\"backward\");\n    std::vector<int64_t> fft_dim{dim};\n\n    bool forward = false;\n    fft_norm_mode norm_mode = fft_norm_mode::none;\n    norm_mode = fft_norm_from_string(norm_str, forward);\n\n    std::vector<int64_t> len{n};\n    return functional::FftC2R(input, len, fft_dim, static_cast<int32_t>(norm_mode),\n                              /*forward=*/forward, /*normalized=*/true);\n  }\n};\n\nclass RFft2Functor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const Optional<std::vector<int64_t>>& s, const std::vector<int64_t>& dim,\n                           const Optional<std::string>& norm) const {\n    return functional::RFftN(input, s, dim, norm);\n  }\n};\n\nclass IRFft2Functor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const Optional<std::vector<int64_t>>& s, const std::vector<int64_t>& dim,\n                           const Optional<std::string>& norm) const {\n    return functional::IRFftN(input, s, dim, norm);\n  }\n};\n\nclass RFftNFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const Optional<std::vector<int64_t>>& s,\n                           const Optional<std::vector<int64_t>>& dim,\n                           const Optional<std::string>& norm) const {\n    CHECK_OR_RETURN(!(input->dtype()->is_complex()))\n        << \"RuntimeError: expected the dtype of the input Tensor to be Real, but got \"\n        << input->dtype()->name();\n\n    std::string norm_str = norm.value_or(\"backward\");\n    bool forward = true;\n    fft_norm_mode norm_mode = fft_norm_mode::none;\n    norm_mode = fft_norm_from_string(norm_str, forward);\n\n    return functional::FftR2C(input, s, dim, static_cast<int32_t>(norm_mode), /*onesided=*/true,\n                              /*forward=*/forward, /*normalized=*/true);\n  }\n};\n\nclass IRFftNFunctor {\n public:\n
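  // irfftn: the inverse of RFftN, i.e. C2R with forward = false. The real output length of\n  // the last transformed dim follows `s` when provided, otherwise 2 * (input_dim - 1).\n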
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const Optional<std::vector<int64_t>>& s,\n                           const Optional<std::vector<int64_t>>& dim,\n                           const Optional<std::string>& norm) const {\n    CHECK_OR_RETURN(input->dtype()->is_complex())\n        << \"RuntimeError: expected the dtype of the input Tensor to be Complex, but got \"\n        << input->dtype()->name();\n\n    std::string norm_str = norm.value_or(\"backward\");\n    bool forward = false;\n    fft_norm_mode norm_mode = fft_norm_mode::none;\n    norm_mode = fft_norm_from_string(norm_str, forward);\n\n    return functional::FftC2R(input, s, dim, static_cast<int32_t>(norm_mode), /*forward=*/false,\n                              /*normalized=*/true);\n  }\n};\n\nclass HFftFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, int64_t n, int64_t dim,\n                           const Optional<std::string>& norm) const {\n    CHECK_OR_RETURN(input->dtype()->is_complex())\n        << \"RuntimeError: expected the dtype of the input Tensor to be Complex, but got \"\n        << input->dtype()->name();\n\n    std::string norm_str = norm.value_or(\"backward\");\n    std::vector<int64_t> fft_dim{dim};\n\n    bool forward = true;\n    fft_norm_mode norm_mode = fft_norm_mode::none;\n    norm_mode = fft_norm_from_string(norm_str, forward);\n\n    std::vector<int64_t> len{n};\n    return functional::FftC2R(input, len, fft_dim, static_cast<int32_t>(norm_mode),\n                              /*forward=*/forward, /*normalized=*/true);\n  }\n};\n\nclass IHFftFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, int64_t n, int64_t dim,\n                           const Optional<std::string>& norm) const {\n    CHECK_OR_RETURN(!(input->dtype()->is_complex()))\n        << \"RuntimeError: expected the dtype of the input Tensor to be Real, but got \"\n        << input->dtype()->name();\n\n    std::string norm_str = norm.value_or(\"backward\");\n    std::vector<int64_t> fft_dim{dim};\n\n    bool forward = false;\n    fft_norm_mode norm_mode = fft_norm_mode::none;\n    norm_mode = fft_norm_from_string(norm_str, forward);\n\n    std::vector<int64_t> len{n};\n    return functional::FftR2C(input, len, fft_dim, static_cast<int32_t>(norm_mode),\n                              /*onesided=*/true,\n                              /*forward=*/forward, /*normalized=*/true);\n  }\n};\n\nclass HFft2Functor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const Optional<std::vector<int64_t>>& s, const std::vector<int64_t>& dim,\n                           const Optional<std::string>& norm) const {\n    return functional::HFftN(input, s, dim, norm);\n  }\n};\n\nclass IHFft2Functor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const Optional<std::vector<int64_t>>& s, const std::vector<int64_t>& dim,\n                           const Optional<std::string>& norm) const {\n    return functional::IHFftN(input, s, dim, norm);\n  }\n};\n\nclass HFftNFunctor : FftBaseFunctor {\n public:\n  HFftNFunctor() : FftBaseFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const Optional<std::vector<int64_t>>& s,\n                           const Optional<std::vector<int64_t>>& dim,\n                           const Optional<std::string>& norm) const {\n    CHECK_OR_RETURN(input->dtype()->is_complex())\n        << \"RuntimeError: expected the dtype of the input Tensor to be Complex, but got \"\n        << input->dtype()->name();\n\n    std::string norm_str = norm.value_or(\"backward\");\n\n
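    // Norm-string mapping (assumed to follow the torch.fft convention): \"backward\" leaves\n    // the forward transform unscaled, \"forward\" scales it by 1/n, and \"ortho\" by 1/sqrt(n),\n    // with n the product of the transformed lengths; `fft_norm_from_string` picks the\n    // matching fft_norm_mode for the given direction.\n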
    bool forward = true;\n    fft_norm_mode norm_mode = fft_norm_mode::none;\n    norm_mode = fft_norm_from_string(norm_str, forward);\n\n    if (s.has_value() && dim.has_value()) {\n      CHECK_OR_RETURN((*JUST(s)).size() == (*JUST(dim)).size())\n          << \"RuntimeError: When dim and shape were both given, they must have the same length\";\n    }\n\n    std::vector<int64_t> wrapped_dims(input->ndim(), 0);\n    std::vector<int64_t> fft_len(input->ndim(), 0);\n    int64_t last_dim_size = 0;\n    JUST(parse_c2r_input_n_and_dims(input, s, dim, last_dim_size, fft_len, wrapped_dims));\n\n    auto resized_tensor =\n        s.has_value() ? JUST(resize_fft_input(input, wrapped_dims, fft_len)) : input;\n\n    std::shared_ptr<Tensor> temp;\n    if (wrapped_dims.size() > 1) {\n      // ND Fast Fourier Transform\n      std::vector<int64_t> c2c_dims(wrapped_dims.begin(), wrapped_dims.end() - 1);\n      temp = JUST(functional::FftC2C(resized_tensor, NullOpt, c2c_dims,\n                                     static_cast<int32_t>(norm_mode), /*forward=*/forward,\n                                     /*normalized=*/true));\n    } else {\n      temp = resized_tensor;\n    }\n\n    // Finally, do 1D fft_c2r\n    int64_t last_dim = wrapped_dims.back();\n    std::vector<int64_t> last_dim_vec = {last_dim};\n    std::vector<int64_t> last_dim_size_vec = {last_dim_size};\n    return functional::FftC2R(temp, last_dim_size_vec, last_dim_vec,\n                              static_cast<int32_t>(norm_mode), /*forward=*/forward,\n                              /*normalized=*/true);\n  }\n};\n\nclass IHFftNFunctor : FftBaseFunctor {\n public:\n  IHFftNFunctor() : FftBaseFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const Optional<std::vector<int64_t>>& s,\n                           const Optional<std::vector<int64_t>>& dim,\n                           const Optional<std::string>& norm) const {\n    CHECK_OR_RETURN(!(input->dtype()->is_complex()))\n        << \"RuntimeError: expected the dtype of the input Tensor to be Real, but got \"\n        << input->dtype()->name();\n\n    std::string norm_str = norm.value_or(\"backward\");\n    bool forward = false;\n    fft_norm_mode norm_mode = fft_norm_mode::none;\n    norm_mode = fft_norm_from_string(norm_str, forward);\n\n    auto input_tensor = JUST(promote_tensor_fft(input, false));\n\n    if (s.has_value() && dim.has_value()) {\n      CHECK_OR_RETURN((*JUST(s)).size() == (*JUST(dim)).size())\n          << \"RuntimeError: When dim and shape were both given, they must have the same length\";\n    }\n\n    std::vector<int64_t> fft_len(input_tensor->ndim(), 0);\n    std::vector<int64_t> wrapped_dims(input_tensor->ndim(), 0);\n    JUST(parse_input_n_and_dims(input_tensor, s, dim, fft_len, wrapped_dims));\n
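    // resize_fft_input (defined earlier in this file) is expected to zero-pad or slice each\n    // transformed dim so that its length matches `fft_len` before the transform runs.\n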
    auto resized_tensor = s.has_value()\n                              ? JUST(resize_fft_input(input_tensor, wrapped_dims, fft_len))\n                              : input_tensor;\n\n    // First do the 1D R2C transform on the last dim\n    const auto last_dim_len = fft_len.back();\n    const auto last_dim = wrapped_dims.back();\n    std::vector<int64_t> r2c_fft_len = {last_dim_len};\n    std::vector<int64_t> r2c_fft_dim = {last_dim};\n    auto temp = JUST(functional::FftR2C(resized_tensor, r2c_fft_len, r2c_fft_dim,\n                                        static_cast<int32_t>(norm_mode), /*onesided=*/true,\n                                        /*forward=*/forward, /*normalized=*/true));\n    // NOTE: `temp` is already conjugated in `functional::FftR2C`\n    if (wrapped_dims.size() == 1) { return temp; }\n\n    // Finally, do the C2C transform on the remaining dims\n    std::vector<int64_t> c2c_dims(wrapped_dims.begin(), wrapped_dims.end() - 1);\n    return functional::FftC2C(temp, NullOpt, c2c_dims, static_cast<int32_t>(norm_mode),\n                              /*forward=*/forward, /*normalized=*/true);\n  }\n};\n\nclass StftFunctor {\n public:\n  StftFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"stft\").Input(\"input\").Output(\"output\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int64_t n_fft,\n                           const Optional<int64_t>& hop_length, const Optional<int64_t>& win_length,\n                           const Optional<one::Tensor>& window, const bool center,\n                           const std::string& mode, const bool normalized, const bool onesided,\n                           const bool return_complex) const {\n    CHECK_OR_RETURN(n_fft > 0) << Error::RuntimeError() << \"Expected 0 < n_fft, but got \" << n_fft;\n    int64_t new_hop_length = hop_length.has_value() ? JUST(hop_length) : n_fft / 4;\n    int64_t new_win_length = win_length.has_value() ? JUST(win_length) : n_fft;\n
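    // The defaults above follow the usual STFT convention: hop_length = n_fft / 4 and\n    // win_length = n_fft. After optional center padding, a signal of length `len` is split\n    // into 1 + (len - n_fft) / hop_length frames of n_fft samples each (computed below).\n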
    auto input_tensor = input;\n\n    // TODO(yzm): Remove this line when complex numbers are supported\n    CHECK_OR_RETURN(return_complex == false)\n        << Error::RuntimeError() << \"return_complex parameter is not supported at this time\";\n\n    const auto& NumAxes = input_tensor->shape()->NumAxes();\n    CHECK_OR_RETURN(NumAxes == 2 || NumAxes == 1)\n        << Error::RuntimeError() << \"Expected a 1D or 2D tensor, but got \" << NumAxes << \"D\";\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"normalized\", \"onesided\", \"return_complex\");\n    attrs.SetAllAttrs(normalized, onesided, return_complex);\n\n    if (NumAxes == 1) { input_tensor = JUST(functional::Unsqueeze(input_tensor, 0)); }\n    if (center) {\n      const auto& input_shape = input_tensor->shape();\n      const auto input_dim = input_tensor->shape()->NumAxes();\n\n      const auto extra_dims = std::max(size_t{3}, (size_t)input_dim) - input_dim;\n      const auto pad_amount = n_fft / 2;\n\n      DimVector extended_shape(extra_dims, 1);\n      extended_shape.append(input_shape->begin(), input_shape->end());\n      input_tensor =\n          JUST(functional::Pad(JUST(functional::View(input_tensor, Shape(extended_shape))),\n                               {pad_amount, pad_amount}, mode, Scalar(0)));\n\n      DimVector view_shape;\n      if (input_dim == 1) {\n        view_shape = {input_tensor->shape()->back()};\n      } else {\n        view_shape = {input_shape->at(0), input_tensor->shape()->back()};\n      }\n      input_tensor = JUST(functional::View(input_tensor, Shape(view_shape)));\n    }\n\n    int32_t batch = input_tensor->shape()->At(0);\n    int32_t len = input_tensor->shape()->At(1);\n    int32_t n_frames = 1 + (len - n_fft) / new_hop_length;\n    int32_t fft_size = static_cast<int32_t>(n_fft);\n    CHECK_OR_RETURN(n_fft > 0 && n_fft <= len)\n        << Error::RuntimeError() << \"Expected 0 < n_fft <= \" << len << \", but got \" << n_fft;\n    CHECK_GT_OR_RETURN(new_hop_length, 0)\n        << Error::RuntimeError() << \"Expected hop_length > 0, but got \" << new_hop_length;\n    CHECK_OR_RETURN(new_win_length > 0 && new_win_length <= n_fft)\n        << Error::RuntimeError() << \"Expected 0 < win_length <= n_fft, but got \" << new_win_length;\n    const auto& stride = *JUST(input_tensor->stride());\n    std::vector<int32_t> strides(stride.begin(), stride.end());\n    input_tensor =\n        JUST(view::AsStrided(input_tensor, {batch, n_frames, fft_size},\n                             {JUST(VectorAt(strides, 0)),\n                              static_cast<int32_t>(new_hop_length) * JUST(VectorAt(strides, 1)),\n                              JUST(VectorAt(strides, 1))},\n                             0));\n\n    std::shared_ptr<Tensor> temp_tensor;\n    if (window.has_value()) {\n      temp_tensor = JUST(window);\n      CHECK_OR_RETURN(temp_tensor->shape()->NumAxes() == 1\n                      && temp_tensor->shape()->at(0) == new_win_length)\n          << Error::RuntimeError()\n          << \"Expected a 1D window tensor of size equal to win_length=\" << new_win_length\n          << \", but got window with size \" << temp_tensor->shape()->ToString();\n    }\n    if (new_win_length < n_fft) {\n      temp_tensor = JUST(functional::Fill(temp_tensor, 0));\n      const int64_t left = (n_fft - new_win_length) / 2;\n\n      if (window.has_value()) {\n
        // TODO(yzm): Copy the window matrix to the defined range, such as\n        //'''\n        //      functional::AssignLocalTensor(JUST(functional::Narrow(temp_tensor, 0,\n        //      left, new_win_length)), window);\n        //'''\n        // Remove the following check after support\n        CHECK_OR_RETURN(false) << Error::RuntimeError()\n                               << \"The following conditions are not currently supported: \"\n                                  \"win_length < n_fft with a customized window function\";\n      } else {\n        temp_tensor = JUST(\n            functional::Fill(JUST(functional::Narrow(temp_tensor, 0, left, new_win_length)), 1.0));\n      }\n    }\n\n    if (new_win_length < n_fft || window.has_value()) {\n      input_tensor = JUST(functional::Mul(input_tensor, temp_tensor));\n    }\n\n    auto output = JUST(OpInterpUtil::Dispatch<Tensor>(\n        *op_, {JUST(functional::ToContiguous(input_tensor))}, attrs));\n    if (NumAxes == 2 && input->shape()->At(0) == 1) {\n      output = JUST(functional::Unsqueeze(output, 0));\n    }\n    return output;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedWeightedSumFunctor {\n public:\n  FusedWeightedSumFunctor() {\n    op_.resize(kMaxInputCount /*the maximum number of inputs*/);\n    for (int n = 1; n < op_.size(); ++n) {\n      op_[n] =\n          CHECK_JUST(one::OpBuilder(\"fused_weighted_sum\").Input(\"in\", n).Output(\"out\").Build());\n    }\n  }\n  Maybe<Tensor> operator()(const TensorTuple& in, const std::vector<float>& weights,\n                           const float& alpha) const {\n    CHECK_GE_OR_RETURN(in.size(), 1);\n    CHECK_LT_OR_RETURN(in.size(), kMaxInputCount);\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"weights\", \"alpha\");\n    attrs.SetAllAttrs(weights, alpha);\n    return JUST(OpInterpUtil::Dispatch<Tensor>(*op_[in.size()], in, attrs));\n  }\n\n private:\n  std::vector<std::shared_ptr<OpExpr>> op_;\n};\n\nclass FusedCenterFunctor {\n public:\n  FusedCenterFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_get_center_dist\")\n                         .Input(\"b1_x1\")\n                         .Input(\"b1_x2\")\n                         .Input(\"b2_x1\")\n                         .Input(\"b2_x2\")\n                         .Input(\"b1_y1\")\n                         .Input(\"b1_y2\")\n                         .Input(\"b2_y1\")\n                         .Input(\"b2_y2\")\n                         .Output(\"rho2\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(\n      const std::shared_ptr<one::Tensor>& b1_x1, const std::shared_ptr<one::Tensor>& b1_x2,\n      const std::shared_ptr<one::Tensor>& b2_x1, const std::shared_ptr<one::Tensor>& b2_x2,\n      const std::shared_ptr<one::Tensor>& b1_y1, const std::shared_ptr<one::Tensor>& b1_y2,\n      const std::shared_ptr<one::Tensor>& b2_y1, const std::shared_ptr<one::Tensor>& b2_y2) const {\n    return OpInterpUtil::Dispatch<Tensor>(\n        *op_, {b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2}, {});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedCenterGradFunctor {\n public:\n  FusedCenterGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_get_center_dist_grad\")\n                         .Input(\"b1_x1\")\n                         .Input(\"b1_x2\")\n                         .Input(\"b2_x1\")\n                         .Input(\"b2_x2\")\n                         .Input(\"b1_y1\")\n                         .Input(\"b1_y2\")\n                         .Input(\"b2_y1\")\n                         .Input(\"b2_y2\")\n                         
.Input(\"rho2_diff\")\n                         .Output(\"b1_x1_diff\")\n                         .Output(\"b1_x2_diff\")\n                         .Output(\"b2_x1_diff\")\n                         .Output(\"b2_x2_diff\")\n                         .Output(\"b1_y1_diff\")\n                         .Output(\"b1_y2_diff\")\n                         .Output(\"b2_y1_diff\")\n                         .Output(\"b2_y2_diff\")\n                         .Build());\n  }\n\n  Maybe<TensorTuple> operator()(\n      const std::shared_ptr<one::Tensor>& b1_x1, const std::shared_ptr<one::Tensor>& b1_x2,\n      const std::shared_ptr<one::Tensor>& b2_x1, const std::shared_ptr<one::Tensor>& b2_x2,\n      const std::shared_ptr<one::Tensor>& b1_y1, const std::shared_ptr<one::Tensor>& b1_y2,\n      const std::shared_ptr<one::Tensor>& b2_y1, const std::shared_ptr<one::Tensor>& b2_y2,\n      const std::shared_ptr<one::Tensor>& rho2_diff) const {\n    return OpInterpUtil::Dispatch<TensorTuple>(\n        *op_, {b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, rho2_diff}, {});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedGetIntersectionAreaFunctor {\n public:\n  FusedGetIntersectionAreaFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_get_intersection_area\")\n                         .Input(\"b1_x1\")\n                         .Input(\"b1_x2\")\n                         .Input(\"b2_x1\")\n                         .Input(\"b2_x2\")\n                         .Input(\"b1_y1\")\n                         .Input(\"b1_y2\")\n                         .Input(\"b2_y1\")\n                         .Input(\"b2_y2\")\n                         .Output(\"inter\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(\n      const std::shared_ptr<one::Tensor>& b1_x1, const std::shared_ptr<one::Tensor>& b1_x2,\n      const std::shared_ptr<one::Tensor>& b2_x1, const std::shared_ptr<one::Tensor>& b2_x2,\n      const std::shared_ptr<one::Tensor>& b1_y1, const std::shared_ptr<one::Tensor>& b1_y2,\n      const std::shared_ptr<one::Tensor>& b2_y1, const std::shared_ptr<one::Tensor>& b2_y2) const {\n    return OpInterpUtil::Dispatch<Tensor>(\n        *op_, {b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2}, {});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedGetIntersectionAreaGradFunctor {\n public:\n  FusedGetIntersectionAreaGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_get_intersection_area_grad\")\n                         .Input(\"b1_x1\")\n                         .Input(\"b1_x2\")\n                         .Input(\"b2_x1\")\n                         .Input(\"b2_x2\")\n                         .Input(\"b1_y1\")\n                         .Input(\"b1_y2\")\n                         .Input(\"b2_y1\")\n                         .Input(\"b2_y2\")\n                         .Input(\"inter_diff\")\n                         .Output(\"b1_x1_diff\")\n                         .Output(\"b1_x2_diff\")\n                         .Output(\"b2_x1_diff\")\n                         .Output(\"b2_x2_diff\")\n                         .Output(\"b1_y1_diff\")\n                         .Output(\"b1_y2_diff\")\n                         .Output(\"b2_y1_diff\")\n                         .Output(\"b2_y2_diff\")\n                         .Build());\n  }\n\n  Maybe<TensorTuple> operator()(\n      const std::shared_ptr<one::Tensor>& b1_x1, const std::shared_ptr<one::Tensor>& b1_x2,\n      const std::shared_ptr<one::Tensor>& b2_x1, const std::shared_ptr<one::Tensor>& 
b2_x2,\n      const std::shared_ptr<one::Tensor>& b1_y1, const std::shared_ptr<one::Tensor>& b1_y2,\n      const std::shared_ptr<one::Tensor>& b2_y1, const std::shared_ptr<one::Tensor>& b2_y2,\n      const std::shared_ptr<one::Tensor>& inter_diff) const {\n    return OpInterpUtil::Dispatch<TensorTuple>(\n        *op_, {b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, inter_diff}, {});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedGetBounddingBoxesCoordFunctor {\n public:\n  FusedGetBounddingBoxesCoordFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_get_boundding_boxes_coord\")\n                         .Input(\"x1\")\n                         .Input(\"y1\")\n                         .Input(\"w1\")\n                         .Input(\"h1\")\n                         .Input(\"x2\")\n                         .Input(\"y2\")\n                         .Input(\"w2\")\n                         .Input(\"h2\")\n                         .Output(\"b1_x1\")\n                         .Output(\"b1_x2\")\n                         .Output(\"b1_y1\")\n                         .Output(\"b1_y2\")\n                         .Output(\"b2_x1\")\n                         .Output(\"b2_x2\")\n                         .Output(\"b2_y1\")\n                         .Output(\"b2_y2\")\n                         .Build());\n  }\n\n  Maybe<TensorTuple> operator()(\n      const std::shared_ptr<one::Tensor>& x1, const std::shared_ptr<one::Tensor>& y1,\n      const std::shared_ptr<one::Tensor>& w1, const std::shared_ptr<one::Tensor>& h1,\n      const std::shared_ptr<one::Tensor>& x2, const std::shared_ptr<one::Tensor>& y2,\n      const std::shared_ptr<one::Tensor>& w2, const std::shared_ptr<one::Tensor>& h2) const {\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {x1, y1, w1, h1, x2, y2, w2, h2}, {});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedGetBounddingBoxesCoordGradFunctor {\n public:\n  FusedGetBounddingBoxesCoordGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_get_boundding_boxes_coord_grad\")\n                         .Input(\"b1_x1_diff\")\n                         .Input(\"b1_x2_diff\")\n                         .Input(\"b1_y1_diff\")\n                         .Input(\"b1_y2_diff\")\n                         .Input(\"b2_x1_diff\")\n                         .Input(\"b2_x2_diff\")\n                         .Input(\"b2_y1_diff\")\n                         .Input(\"b2_y2_diff\")\n                         .Output(\"x1_diff\")\n                         .Output(\"y1_diff\")\n                         .Output(\"w1_diff\")\n                         .Output(\"h1_diff\")\n                         .Output(\"x2_diff\")\n                         .Output(\"y2_diff\")\n                         .Output(\"w2_diff\")\n                         .Output(\"h2_diff\")\n                         .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& b1_x1_diff,\n                                const std::shared_ptr<one::Tensor>& b1_x2_diff,\n                                const std::shared_ptr<one::Tensor>& b1_y1_diff,\n                                const std::shared_ptr<one::Tensor>& b1_y2_diff,\n                                const std::shared_ptr<one::Tensor>& b2_x1_diff,\n                                const std::shared_ptr<one::Tensor>& b2_x2_diff,\n                                const std::shared_ptr<one::Tensor>& b2_y1_diff,\n                                const std::shared_ptr<one::Tensor>& b2_y2_diff) const {\n    return 
OpInterpUtil::Dispatch<TensorTuple>(*op_,\n                                               {b1_x1_diff, b1_x2_diff, b1_y1_diff, b1_y2_diff,\n                                                b2_x1_diff, b2_x2_diff, b2_y1_diff, b2_y2_diff},\n                                               {});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedGetCiouDiagonalAngleFunctor {\n public:\n  FusedGetCiouDiagonalAngleFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_get_ciou_diagonal_angle\")\n                         .Input(\"w1\")\n                         .Input(\"h1\")\n                         .Input(\"w2\")\n                         .Input(\"h2\")\n                         .Output(\"v\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& w1,\n                           const std::shared_ptr<one::Tensor>& h1,\n                           const std::shared_ptr<one::Tensor>& w2,\n                           const std::shared_ptr<one::Tensor>& h2, const float eps) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"eps\");\n    attrs.SetAllAttrs(eps);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {w1, h1, w2, h2}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedGetCiouResultFunctor {\n public:\n  FusedGetCiouResultFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_get_ciou_result\")\n                         .Input(\"v\")\n                         .Input(\"iou\")\n                         .Input(\"rho2\")\n                         .Input(\"c2\")\n                         .Output(\"y\")\n                         .Output(\"alpha\")\n                         .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& v,\n                                const std::shared_ptr<one::Tensor>& iou,\n                                const std::shared_ptr<one::Tensor>& rho2,\n                                const std::shared_ptr<one::Tensor>& c2, const float& eps) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"eps\");\n    attrs.SetAllAttrs(eps);\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {v, iou, rho2, c2}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedGetCiouDiagonalAngleGradFunctor {\n public:\n  FusedGetCiouDiagonalAngleGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_get_ciou_diagonal_angle_grad\")\n                         .Input(\"w1\")\n                         .Input(\"h1\")\n                         .Input(\"w2\")\n                         .Input(\"h2\")\n                         .Input(\"v_diff\")\n                         .Output(\"w1_diff\")\n                         .Output(\"h1_diff\")\n                         .Output(\"w2_diff\")\n                         .Output(\"h2_diff\")\n                         .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& w1,\n                                const std::shared_ptr<one::Tensor>& h1,\n                                const std::shared_ptr<one::Tensor>& w2,\n                                const std::shared_ptr<one::Tensor>& h2,\n                                const std::shared_ptr<one::Tensor>& v_diff, const float eps) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"eps\");\n    attrs.SetAllAttrs(eps);\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {w1, h1, w2, h2, v_diff}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedGetCiouResultGradFunctor {\n 
public:\n  FusedGetCiouResultGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_get_ciou_result_grad\")\n                         .Input(\"dy\")\n                         .Input(\"alpha\")\n                         .Input(\"rho2\")\n                         .Input(\"c2\")\n                         .Output(\"dv\")\n                         .Output(\"diou\")\n                         .Output(\"drho2\")\n                         .Output(\"dc2\")\n                         .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,\n                                const std::shared_ptr<one::Tensor>& alpha,\n                                const std::shared_ptr<one::Tensor>& rho2,\n                                const std::shared_ptr<one::Tensor>& c2) const {\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, alpha, rho2, c2}, {});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedGetIouFunctor {\n public:\n  FusedGetIouFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_get_iou\")\n                         .Input(\"w1\")\n                         .Input(\"h1\")\n                         .Input(\"w2\")\n                         .Input(\"h2\")\n                         .Input(\"inter\")\n                         .Output(\"iou\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& w1,\n                           const std::shared_ptr<one::Tensor>& h1,\n                           const std::shared_ptr<one::Tensor>& w2,\n                           const std::shared_ptr<one::Tensor>& h2,\n                           const std::shared_ptr<one::Tensor>& inter, const float& eps) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"eps\");\n    attrs.SetAllAttrs(eps);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {w1, h1, w2, h2, inter}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedGetIouGradFunctor {\n public:\n  FusedGetIouGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_get_iou_grad\")\n                         .Input(\"diou\")\n                         .Input(\"w1\")\n                         .Input(\"h1\")\n                         .Input(\"w2\")\n                         .Input(\"h2\")\n                         .Input(\"inter\")\n                         .Output(\"dw1\")\n                         .Output(\"dh1\")\n                         .Output(\"dinter\")\n                         .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& diou,\n                                const std::shared_ptr<one::Tensor>& w1,\n                                const std::shared_ptr<one::Tensor>& h1,\n                                const std::shared_ptr<one::Tensor>& w2,\n                                const std::shared_ptr<one::Tensor>& h2,\n                                const std::shared_ptr<one::Tensor>& inter, const float& eps) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"eps\");\n    attrs.SetAllAttrs(eps);\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {diou, w1, h1, w2, h2, inter}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedGetConvexDiagonalSquaredFunctor {\n public:\n  FusedGetConvexDiagonalSquaredFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_get_convex_diagonal_squared\")\n                         .Input(\"b1_x1\")\n                         .Input(\"b1_x2\")\n                         .Input(\"b2_x1\")\n                        
 .Input(\"b2_x2\")\n                         .Input(\"b1_y1\")\n                         .Input(\"b1_y2\")\n                         .Input(\"b2_y1\")\n                         .Input(\"b2_y2\")\n                         .Output(\"c2\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& b1_x1,\n                           const std::shared_ptr<one::Tensor>& b1_x2,\n                           const std::shared_ptr<one::Tensor>& b2_x1,\n                           const std::shared_ptr<one::Tensor>& b2_x2,\n                           const std::shared_ptr<one::Tensor>& b1_y1,\n                           const std::shared_ptr<one::Tensor>& b1_y2,\n                           const std::shared_ptr<one::Tensor>& b2_y1,\n                           const std::shared_ptr<one::Tensor>& b2_y2, const float& eps) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"eps\");\n    attrs.SetAllAttrs(eps);\n    return OpInterpUtil::Dispatch<Tensor>(\n        *op_, {b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedGetConvexDiagonalSquaredGradFunctor {\n public:\n  FusedGetConvexDiagonalSquaredGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_get_convex_diagonal_squared_grad\")\n                         .Input(\"c2_diff\")\n                         .Input(\"b1_x1\")\n                         .Input(\"b1_x2\")\n                         .Input(\"b2_x1\")\n                         .Input(\"b2_x2\")\n                         .Input(\"b1_y1\")\n                         .Input(\"b1_y2\")\n                         .Input(\"b2_y1\")\n                         .Input(\"b2_y2\")\n                         .Output(\"b1_x1_diff\")\n                         .Output(\"b1_x2_diff\")\n                         .Output(\"b2_x1_diff\")\n                         .Output(\"b2_x2_diff\")\n                         .Output(\"b1_y1_diff\")\n                         .Output(\"b1_y2_diff\")\n                         .Output(\"b2_y1_diff\")\n                         .Output(\"b2_y2_diff\")\n                         .Build());\n  }\n\n  Maybe<TensorTuple> operator()(\n      const std::shared_ptr<one::Tensor>& c2_diff, const std::shared_ptr<one::Tensor>& b1_x1,\n      const std::shared_ptr<one::Tensor>& b1_x2, const std::shared_ptr<one::Tensor>& b2_x1,\n      const std::shared_ptr<one::Tensor>& b2_x2, const std::shared_ptr<one::Tensor>& b1_y1,\n      const std::shared_ptr<one::Tensor>& b1_y2, const std::shared_ptr<one::Tensor>& b2_y1,\n      const std::shared_ptr<one::Tensor>& b2_y2, const float& eps) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"eps\");\n    attrs.SetAllAttrs(eps);\n    return OpInterpUtil::Dispatch<TensorTuple>(\n        *op_, {c2_diff, b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass RealFunctor {\n public:\n  RealFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"real\").Input(\"x\").Output(\"out\").Build()); }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    if (!x->dtype()->is_complex()) { return x; }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass RealGradFunctor {\n public:\n  RealGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"real_grad\").Input(\"dout\").Output(\"dx\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dout) const {\n   
 return OpInterpUtil::Dispatch<Tensor>(*op_, {dout});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ImagFunctor {\n public:\n  ImagFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"imag\").Input(\"x\").Output(\"out\").Build()); }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    CHECK_OR_RETURN(x->dtype()->is_complex())\n        << \"RuntimeError: imag is implemented for tensors with complex dtypes, but got \"\n        << x->dtype()->name();\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ImagGradFunctor {\n public:\n  ImagGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"imag_grad\").Input(\"dout\").Output(\"dx\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dout) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dout});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ConjFunctor {\n public:\n  ConjFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"conj_physical\").Input(\"x\").Output(\"out\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    if (!x->dtype()->is_complex()) { return x; }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ConjPhysicalFunctor {\n public:\n  ConjPhysicalFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"conj_physical\").Input(\"x\").Output(\"out\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    if (!IsComplexDataType(x->dtype()->data_type())) { return x; }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\n}  // namespace impl\n\nusing namespace impl;\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<AddNFunctor>(\"Add\");\n  m.add_functor<ScalarAddFunctor, ScalarAdd2Functor>(\"ScalarAdd\");\n  m.add_functor<ScalarSubFunctor, ScalarSub2Functor>(\"ScalarSub\");\n  m.add_functor<ScalarMulFunctor, ScalarMul2Functor>(\"ScalarMul\");\n  m.add_functor<InplaceScalarMulFunctor>(\"InplaceScalarMul\");\n  m.add_functor<AddCDivFunctor>(\"AddCDiv\");\n  m.add_functor<InplaceAddCDivFunctor>(\"InplaceAddCDiv\");\n  m.add_functor<ScalarDivFunctor, ScalarDiv2Functor>(\"ScalarDiv\");\n  m.add_functor<ScalarDivModeFunctor, ScalarDivMode2Functor>(\"ScalarDivMode\");\n  m.add_functor<InplaceScalarDivFunctor>(\"InplaceScalarDiv\");\n  m.add_functor<ScalarPowFunctor>(\"ScalarPow\");\n  m.add_functor<ScalarReversePowFunctor>(\"ScalarReversePow\");\n  m.add_functor<ScalarPowGradFunctor>(\"ScalarPowGrad\");\n  m.add_functor<ScalarReversePowGradFunctor>(\"ScalarReversePowGrad\");\n  m.add_functor<ReduceMaxFunctor>(\"ReduceMax\");\n  m.add_functor<MaxFunctor, Max2Functor>(\"Max\");\n  m.add_functor<ReduceMeanFunctor>(\"ReduceMean\");\n  m.add_functor<ReduceMeanWholeFunctor>(\"ReduceMeanWhole\");\n  m.add_functor<ReduceMinFunctor>(\"ReduceMin\");\n  m.add_functor<MinFunctor, Min2Functor>(\"Min\");\n  m.add_functor<AminFunctor>(\"Amin\");\n  m.add_functor<MedianFunctor>(\"Median\");\n  m.add_functor<MedianWithIndicesFunctor>(\"MedianWithIndices\");\n  m.add_functor<ModeFunctor>(\"Mode\");\n  m.add_functor<AmaxFunctor>(\"Amax\");\n  m.add_functor<ReduceSumFunctor>(\"ReduceSum\");\n  m.add_functor<ReduceSumWholeFunctor>(\"ReduceSumWhole\");\n  m.add_functor<ReduceNanSumFunctor>(\"ReduceNanSum\");\n  m.add_functor<ReduceNanSumWholeFunctor>(\"ReduceNanSumWhole\");\n  
m.add_functor<ReduceAllFunctor>(\"ReduceAll\");\n  m.add_functor<ReduceAllWholeFunctor>(\"ReduceAllWhole\");\n  m.add_functor<ReduceAnyFunctor>(\"ReduceAny\");\n  m.add_functor<ReduceAnyWholeFunctor>(\"ReduceAnyWhole\");\n  m.add_functor<ReduceProdFunctor>(\"ReduceProd\");\n  m.add_functor<ReduceProdWholeFunctor>(\"ReduceProdWhole\");\n  m.add_functor<ReduceMinDeviceStageFunctor>(\"ReduceMinDeviceStage\");\n  m.add_functor<ReduceMaxDeviceStageFunctor>(\"ReduceMaxDeviceStage\");\n  m.add_functor<ReduceMinGlobalStageFunctor>(\"ReduceMinGlobalStage\");\n  m.add_functor<ReduceMaxGlobalStageFunctor>(\"ReduceMaxGlobalStage\");\n  m.add_functor<ReduceMinDeviceStageGradFunctor>(\"ReduceMinDeviceStageGrad\");\n  m.add_functor<ReduceMaxDeviceStageGradFunctor>(\"ReduceMaxDeviceStageGrad\");\n  m.add_functor<ReduceMinGlobalStageGradFunctor>(\"ReduceMinGlobalStageGrad\");\n  m.add_functor<ReduceMaxGlobalStageGradFunctor>(\"ReduceMaxGlobalStageGrad\");\n  m.add_functor<LogSumExpFunctor>(\"LogSumExp\");\n  m.add_functor<LogAddExpFunctor>(\"LogAddExp\");\n  m.add_functor<QuantileFunctor>(\"Quantile\");\n  m.add_functor<ScalarQuantileFunctor>(\"ScalarQuantile\");\n  m.add_functor<TransposeFunctor>(\"Transpose\");\n  m.add_functor<Transpose2dimFunctor>(\"Transpose2dim\");\n  m.add_functor<TransposeFunctor>(\"Permute\");\n  m.add_functor<AsStridedFunctor>(\"AsStrided\");\n  m.add_functor<AsStridedGradFunctor>(\"AsStridedGrad\");\n  m.add_functor<InplaceAsStridedFunctor>(\"InplaceAsStrided\");\n  m.add_functor<Transpose2dimFunctor>(\"Swapaxes\");\n  m.add_functor<Transpose2dimFunctor>(\"Swapdims\");\n  m.add_functor<ArangeFunctor, Arange2Functor>(\"Arange\");\n  m.add_functor<GlobalArangeFunctor, GlobalArange2Functor>(\"GlobalArange\");\n  m.add_functor<HannWindowFunctor>(\"HannWindow\");\n  m.add_functor<GlobalHannWindowFunctor>(\"GlobalHannWindow\");\n  m.add_functor<CastFunctor>(\"Cast\");\n  m.add_functor<ClampFunctor>(\"Clamp\");\n  m.add_functor<ClampMinFunctor>(\"ClampMin\");\n  m.add_functor<ClampMaxFunctor>(\"ClampMax\");\n  m.add_functor<ClampInplaceFunctor>(\"ClampInplace\");\n  m.add_functor<ClampMinInplaceFunctor>(\"ClampMinInplace\");\n  m.add_functor<ClampMaxInplaceFunctor>(\"ClampMaxInplace\");\n  m.add_functor<ClipFunctor>(\"Clip\");\n  m.add_functor<ClipInplaceFunctor>(\"ClipInplace\");\n  m.add_functor<SqrtSquareSumFunctor>(\"SqrtSquareSum\");\n  m.add_functor<VectorNormFunctor, ScalarVectorNormFunctor>(\"VectorNorm\");\n  m.add_functor<ScalarMatrixNormFunctor, MatrixNormFunctor>(\"MatrixNorm\");\n  m.add_functor<NormFunctor, Norm2Functor>(\"Norm\");\n  m.add_functor<ScalarNormFunctor, ScalarNorm2Functor>(\"ScalarNorm\");\n  m.add_functor<ClampGradFunctor>(\"ClampGrad\");\n  m.add_functor<SelectFunctor>(\"Select\");\n  m.add_functor<SelectTopNFunctor>(\"SelectTopN\");\n  m.add_functor<MinimumFunctor>(\"Minimum\");\n  m.add_functor<MinimumFunctor>(\"Min\");\n  m.add_functor<MaximumFunctor>(\"Maximum\");\n  m.add_functor<MaximumFunctor>(\"Max\");\n  m.add_functor<ScalarFModFunctor>(\"ScalarFMod\");\n  m.add_functor<ScalarFloorDivFunctor>(\"ScalarFloorDiv\");\n  m.add_functor<ScalarTruncDivFunctor>(\"ScalarTruncDiv\");\n  m.add_functor<ScalarLogicalEqualFunctor, ScalarLogicalEqual2Functor>(\"ScalarLogicalEqual\");\n  m.add_functor<ScalarLogicalNotEqualFunctor, ScalarLogicalNotEqual2Functor>(\n      \"ScalarLogicalNotEqual\");\n  m.add_functor<ScalarLogicalGreaterFunctor, ScalarLogicalGreater2Functor>(\"ScalarLogicalGreater\");\n  
m.add_functor<InplaceScalarLogicalGreaterFunctor>(\"InplaceScalarLogicalGreater\");\n  m.add_functor<ScalarLogicalGreaterEqualFunctor, ScalarLogicalGreaterEqual2Functor>(\n      \"ScalarLogicalGreaterEqual\");\n  m.add_functor<ScalarLogicalLessFunctor, ScalarLogicalLess2Functor>(\"ScalarLogicalLess\");\n  m.add_functor<ScalarLogicalLessEqualFunctor, ScalarLogicalLessEqual2Functor>(\n      \"ScalarLogicalLessEqual\");\n  m.add_functor<ScalarLogicalAndFunctor, ScalarLogicalAnd2Functor>(\"ScalarLogicalAnd\");\n  m.add_functor<ScalarLogicalOrFunctor, ScalarLogicalOr2Functor>(\"ScalarLogicalOr\");\n  m.add_functor<ScalarLogicalXorFunctor, ScalarLogicalXor2Functor>(\"ScalarLogicalXor\");\n  m.add_functor<ScalarLerpFunctor>(\"ScalarLerp\");\n  m.add_functor<ScalarInplaceLerpFunctor>(\"ScalarInplaceLerp\");\n  m.add_functor<ScalarLerpGradFunctor>(\"ScalarLerpGrad\");\n  m.add_functor<StandardDeviationFunctor>(\"StandardDeviation\");\n  m.add_functor<VarianceFunctor>(\"Variance\");\n  m.add_functor<RMSLayerNormalizationFunctor>(\"RMSLayerNormalization\");\n  m.add_functor<DotFunctor>(\"Dot\");\n  m.add_functor<MovedimVecFunctor>(\"MovedimVec\");\n  m.add_functor<MovedimIntFunctor>(\"MovedimInt\");\n  m.add_functor<TensorSplitVecFunctor>(\"TensorSplitVec\");\n  m.add_functor<TensorSplitIntFunctor>(\"TensorSplitInt\");\n  m.add_functor<HsplitIntFunctor>(\"HsplitInt\");\n  m.add_functor<HsplitVecFunctor>(\"HsplitVec\");\n  m.add_functor<VsplitIntFunctor>(\"VsplitInt\");\n  m.add_functor<VsplitVecFunctor>(\"VsplitVec\");\n  m.add_functor<ErfinvFunctor>(\"Erfinv\");\n  m.add_functor<ErfinvInplaceFunctor>(\"ErfinvInplace\");\n  m.add_functor<CumsumFunctor>(\"Cumsum\");\n  m.add_functor<CumProdFunctor>(\"Cumprod\");\n  m.add_functor<CumProdGradFunctor>(\"CumprodGrad\");\n  m.add_functor<EinSumFunctor>(\"EinSum\");\n  m.add_functor<InvFunctor>(\"Inv\");\n  m.add_functor<DetFunctor>(\"Det\");\n  m.add_functor<GeluWithApproximateFunctor>(\"GeluWithApproximate\");\n  m.add_functor<impl::TruncFunctor>(\"Trunc\");\n\n  m.add_functor<StftFunctor>(\"Stft\");\n  m.add_functor<impl::FftC2CFunctor>(\"FftC2C\");\n  m.add_functor<impl::FftR2CFunctor>(\"FftR2C\");\n  m.add_functor<impl::FftC2RFunctor>(\"FftC2R\");\n  m.add_functor<impl::FftFunctor>(\"Fft\");\n  m.add_functor<impl::IFftFunctor>(\"IFft\");\n  m.add_functor<impl::Fft2Functor>(\"Fft2\");\n  m.add_functor<impl::IFft2Functor>(\"IFft2\");\n  m.add_functor<impl::FftNFunctor>(\"FftN\");\n  m.add_functor<impl::IFftNFunctor>(\"IFftN\");\n  m.add_functor<impl::RFftFunctor>(\"RFft\");\n  m.add_functor<impl::IRFftFunctor>(\"IRFft\");\n  m.add_functor<impl::RFft2Functor>(\"RFft2\");\n  m.add_functor<impl::IRFft2Functor>(\"IRFft2\");\n  m.add_functor<impl::RFftNFunctor>(\"RFftN\");\n  m.add_functor<impl::IRFftNFunctor>(\"IRFftN\");\n  m.add_functor<impl::HFftFunctor>(\"HFft\");\n  m.add_functor<impl::IHFftFunctor>(\"IHFft\");\n  m.add_functor<impl::HFft2Functor>(\"HFft2\");\n  m.add_functor<impl::IHFft2Functor>(\"IHFft2\");\n  m.add_functor<impl::HFftNFunctor>(\"HFftN\");\n  m.add_functor<impl::IHFftNFunctor>(\"IHFftN\");\n\n  m.add_functor<impl::FusedWeightedSumFunctor>(\"FusedWeightedSum\");\n  m.add_functor<impl::FusedCenterFunctor>(\"FusedCenter\");\n  m.add_functor<impl::FusedCenterGradFunctor>(\"FusedCenterGrad\");\n  m.add_functor<impl::FusedGetBounddingBoxesCoordFunctor>(\"FusedGetBounddingBoxesCoord\");\n  m.add_functor<impl::FusedGetBounddingBoxesCoordGradFunctor>(\"FusedGetBounddingBoxesCoordGrad\");\n  
m.add_functor<impl::FusedGetCiouDiagonalAngleFunctor>(\"FusedGetCiouDiagonalAngle\");\n  m.add_functor<impl::FusedGetCiouDiagonalAngleGradFunctor>(\"FusedGetCiouDiagonalAngleGrad\");\n  m.add_functor<impl::FusedGetCiouResultFunctor>(\"FusedGetCiouResult\");\n  m.add_functor<impl::FusedGetCiouResultGradFunctor>(\"FusedGetCiouResultGrad\");\n  m.add_functor<impl::FusedGetIntersectionAreaFunctor>(\"FusedGetIntersectionArea\");\n  m.add_functor<impl::FusedGetIntersectionAreaGradFunctor>(\"FusedGetIntersectionAreaGrad\");\n  m.add_functor<impl::FusedGetIouFunctor>(\"FusedGetIou\");\n  m.add_functor<impl::FusedGetIouGradFunctor>(\"FusedGetIouGrad\");\n  m.add_functor<impl::FusedGetConvexDiagonalSquaredFunctor>(\"FusedGetConvexDiagonalSquared\");\n  m.add_functor<impl::FusedGetConvexDiagonalSquaredGradFunctor>(\n      \"FusedGetConvexDiagonalSquaredGrad\");\n  m.add_functor<impl::ScalarBitwiseAndFunctor, impl::ScalarBitwiseAnd2Functor>(\"ScalarBitwiseAnd\");\n  m.add_functor<impl::ScalarBitwiseOrFunctor, impl::ScalarBitwiseOr2Functor>(\"ScalarBitwiseOr\");\n  m.add_functor<impl::ScalarBitwiseXorFunctor, impl::ScalarBitwiseXor2Functor>(\"ScalarBitwiseXor\");\n  m.add_functor<impl::RealFunctor>(\"Real\");\n  m.add_functor<impl::RealGradFunctor>(\"RealGrad\");\n  m.add_functor<impl::ImagFunctor>(\"Imag\");\n  m.add_functor<impl::ImagGradFunctor>(\"ImagGrad\");\n  m.add_functor<impl::ConjFunctor>(\"Conj\");\n  m.add_functor<impl::ConjPhysicalFunctor>(\"ConjPhysical\");\n};\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/nn_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/tensor_util.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n#include \"oneflow/core/functional/impl/unary_functor.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/user/kernels/random_mask_like_kernel.h\"\n#include \"oneflow/user/kernels/dropout_kernel.h\"\n#include \"oneflow/user/kernels/distributions/common.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n#include \"oneflow/user/kernels/scaled_dot_product_attention_kernel.h\"\n\n#include \"oneflow/core/common/container_util.h\"\n#include \"fmt/core.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nclass BiasAddFunctor {\n public:\n  BiasAddFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"bias_add\").Input(\"a\").Input(\"b\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& bias, const int32_t& axis) const {\n    int32_t axis_val = axis;\n    if (axis_val < 0) {\n      const int64_t num_axes = x->shape()->NumAxes();\n      axis_val += num_axes;\n    }\n    CHECK_LT_OR_RETURN(axis_val, x->shape()->NumAxes())\n        << Error::IndexError() << \"Dimension out of range (expected to be in range of [-\"\n        << x->shape()->NumAxes() << \",\" << x->shape()->NumAxes() - 1 << \"], but got \" << axis_val\n        << \")\";\n    CHECK_EQ_OR_RETURN(x->shape()->At(axis_val), bias->shape()->At(0))\n        << Error::RuntimeError() << \"The size of tensor x \" << x->shape()->ToString()\n        << \" must match the size of tensor b \" << bias->shape()->ToString() << \" at dimension \"\n        << axis_val;\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\");\n    attrs.SetAllAttrs(axis_val);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, bias}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ConvBaseFunctor {\n public:\n  explicit ConvBaseFunctor(const int& num_spatial_dims) : num_spatial_dims_(num_spatial_dims) {\n    bias_op_ = CHECK_JUST(one::OpBuilder(\"bias_add\").Input(\"a\").Input(\"b\").Output(\"out\").Build());\n    enable_fused_conv_bias_ = ParseBooleanFromEnv(\"ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS\", false);\n  }\n  virtual ~ConvBaseFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& weight,\n                           const Optional<one::Tensor>& bias, const std::vector<int32_t>& stride,\n                           const std::vector<int32_t>& padding,\n    
                       const std::vector<int32_t>& dilation, const int32_t& groups,\n                           const std::string& channel_pos) const {\n    std::shared_ptr<one::Tensor> unsqueezed_input;\n    bool is_batched = true;\n    std::string func_name;\n    if (num_spatial_dims_ == 1) {\n      func_name = \"conv1d\";\n    } else if (num_spatial_dims_ == 2) {\n      func_name = \"conv2d\";\n    } else {\n      func_name = \"conv3d\";\n    }\n    std::tie(unsqueezed_input, is_batched) = *JUST(batchify(input, num_spatial_dims_, func_name));\n    std::vector<int32_t> kernel_size_vec(num_spatial_dims_);\n    int32_t channel_idx = 1;\n    int32_t kernel_idx_offset = 2;\n    if (channel_pos == \"channels_last\") {\n      kernel_idx_offset = 1;\n      channel_idx = kernel_idx_offset + num_spatial_dims_;\n    }\n\n    for (int i = 0; i < num_spatial_dims_; i++) {\n      kernel_size_vec.at(i) = ((weight->shape())->At(i + kernel_idx_offset));\n    }\n    auto& conv_attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"filters\", \"kernel_size\", \"padding_before\", \"strides\",\n                                       \"dilation_rate\", \"groups\", \"data_format\");\n    conv_attrs.SetAllAttrs(static_cast<int32_t>(weight->shape()->At(0)), kernel_size_vec, padding,\n                           stride, dilation, groups, channel_pos);\n    if (bias && enable_fused_conv_bias_) {\n      return OpInterpUtil::Dispatch<Tensor>(*conv_bias_op_, {input, weight, JUST(bias)},\n                                            conv_attrs);\n    }\n    const std::shared_ptr<one::Tensor>& conv_out =\n        JUST(OpInterpUtil::Dispatch<Tensor>(*conv_op_, {unsqueezed_input, weight}, conv_attrs));\n    std::shared_ptr<one::Tensor> squeezed_conv_output = conv_out;\n    if (!is_batched) {\n      squeezed_conv_output = JUST(functional::Squeeze(conv_out, std::vector<int32_t>{0}));\n      channel_idx -= 1;\n    }\n    if (bias) {\n      return functional::BiasAdd(squeezed_conv_output, JUST(bias), channel_idx);\n    } else {\n      return squeezed_conv_output;\n    }\n  }\n\n protected:\n  std::shared_ptr<OpExpr> conv_op_;\n  std::shared_ptr<OpExpr> bias_op_;\n  std::shared_ptr<OpExpr> conv_bias_op_;\n  int32_t num_spatial_dims_;\n  bool enable_fused_conv_bias_;\n};\n\nclass Conv1dFunctor : public ConvBaseFunctor {\n public:\n  Conv1dFunctor() : ConvBaseFunctor(/*num_spatial_dims_=*/1) {\n    conv_op_ =\n        CHECK_JUST(one::OpBuilder(\"conv1d\").Input(\"in\").Input(\"weight\").Output(\"out\").Build());\n    conv_bias_op_ = CHECK_JUST(\n        one::OpBuilder(\"conv1d\").Input(\"in\").Input(\"weight\").Input(\"bias\").Output(\"out\").Build());\n  }\n};\n\nclass Conv2dFunctor : public ConvBaseFunctor {\n public:\n  Conv2dFunctor() : ConvBaseFunctor(/*num_spatial_dims_=*/2) {\n    conv_op_ =\n        CHECK_JUST(one::OpBuilder(\"conv2d\").Input(\"in\").Input(\"weight\").Output(\"out\").Build());\n    conv_bias_op_ = CHECK_JUST(\n        one::OpBuilder(\"conv2d\").Input(\"in\").Input(\"weight\").Input(\"bias\").Output(\"out\").Build());\n  }\n};\n\nclass Conv3dFunctor : public ConvBaseFunctor {\n public:\n  Conv3dFunctor() : ConvBaseFunctor(/*num_spatial_dims_=*/3) {\n    conv_op_ =\n        CHECK_JUST(one::OpBuilder(\"conv3d\").Input(\"in\").Input(\"weight\").Output(\"out\").Build());\n    conv_bias_op_ = CHECK_JUST(\n        one::OpBuilder(\"conv3d\").Input(\"in\").Input(\"weight\").Input(\"bias\").Output(\"out\").Build());\n  }\n};\n\nclass DeConvBaseFunctor {\n public:\n  explicit DeConvBaseFunctor(const int& 
num_spatial_dims) : num_spatial_dims_(num_spatial_dims) {\n    bias_op_ = CHECK_JUST(one::OpBuilder(\"bias_add\").Input(\"a\").Input(\"b\").Output(\"out\").Build());\n  }\n  virtual ~DeConvBaseFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& weight,\n                           const Optional<one::Tensor>& bias, const std::vector<int32_t>& stride,\n                           const std::vector<int32_t>& padding,\n                           const std::vector<int32_t>& output_padding, const int32_t& groups,\n                           const std::vector<int32_t>& dilation,\n                           const std::string& data_format) const {\n    std::shared_ptr<one::Tensor> unsqueezed_input;\n    bool is_batched = true;\n    std::string func_name;\n    if (num_spatial_dims_ == 1) {\n      func_name = \"deconv1d\";\n    } else if (num_spatial_dims_ == 2) {\n      func_name = \"deconv2d\";\n    } else {\n      func_name = \"deconv3d\";\n    }\n    std::tie(unsqueezed_input, is_batched) = *JUST(batchify(input, num_spatial_dims_, func_name));\n    int32_t channel_idx = 1;\n    std::vector<int32_t> kernel_size_vec(num_spatial_dims_);\n    int32_t kernel_idx_offset = 2;\n    if (data_format == \"channels_last\") { kernel_idx_offset = 1; }\n    for (int i = 0; i < num_spatial_dims_; i++) {\n      kernel_size_vec[i] = ((weight->shape())->At(i + kernel_idx_offset));\n    }\n\n    auto& deconv_attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"filters\", \"kernel_size\", \"padding_before\", \"output_padding\",\n                                       \"strides\", \"dilation_rate\", \"groups\", \"data_format\");\n    deconv_attrs.SetAllAttrs(static_cast<int32_t>(weight->shape()->At(1) * groups), kernel_size_vec,\n                             padding, output_padding, stride, dilation, groups, data_format);\n    std::shared_ptr<one::Tensor> deconv_out =\n        JUST(OpInterpUtil::Dispatch<Tensor>(*deconv_op_, {unsqueezed_input, weight}, deconv_attrs));\n    std::shared_ptr<one::Tensor> squeezed_deconv_output = deconv_out;\n    if (!is_batched) {\n      squeezed_deconv_output = JUST(functional::Squeeze(deconv_out, std::vector<int32_t>{0}));\n      channel_idx -= 1;\n    }\n    if (bias) {\n      auto& bias_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\");\n      bias_attrs.SetAllAttrs(static_cast<int32_t>(channel_idx));\n      return OpInterpUtil::Dispatch<Tensor>(*bias_op_, {squeezed_deconv_output, JUST(bias)},\n                                            bias_attrs);\n    } else {\n      return squeezed_deconv_output;\n    }\n  }\n\n protected:\n  std::shared_ptr<OpExpr> deconv_op_;\n  std::shared_ptr<OpExpr> bias_op_;\n  int32_t num_spatial_dims_;\n};\n\nclass DeConv1dFunctor : public DeConvBaseFunctor {\n public:\n  DeConv1dFunctor() : DeConvBaseFunctor(/*num_spatial_dims_=*/1) {\n    deconv_op_ =\n        CHECK_JUST(one::OpBuilder(\"deconv1d\").Input(\"in\").Input(\"weight\").Output(\"out\").Build());\n  }\n};\n\nclass DeConv2dFunctor : public DeConvBaseFunctor {\n public:\n  DeConv2dFunctor() : DeConvBaseFunctor(/*num_spatial_dims_=*/2) {\n    deconv_op_ =\n        CHECK_JUST(one::OpBuilder(\"deconv2d\").Input(\"in\").Input(\"weight\").Output(\"out\").Build());\n  }\n};\n\nclass DeConv3dFunctor : public DeConvBaseFunctor {\n public:\n  DeConv3dFunctor() : DeConvBaseFunctor(/*num_spatial_dims_=*/3) {\n    deconv_op_ =\n        
CHECK_JUST(one::OpBuilder(\"deconv3d\").Input(\"in\").Input(\"weight\").Output(\"out\").Build());\n  }\n};\n\nclass EmbeddingReNormFunctor {\n public:\n  EmbeddingReNormFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"embedding_renorm\").Input(\"in\").Input(\"indices\").Output(\"out\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in,\n                           const std::shared_ptr<one::Tensor>& indices, const double& max_norm,\n                           const double& norm_type) const {\n    CHECK_EQ_OR_RETURN(in->ndim(), 2)\n        << Error::RuntimeError() << \"The dimension of input should be 2.\";\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    JUST(oneflow::VectorAt(*outputs, 0)) = in;\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"max_norm\", \"norm_type\");\n    attrs.SetAllAttrs(max_norm, norm_type);\n\n    JUST(OpInterpUtil::Dispatch(*op_, {in, indices}, outputs.get(), attrs));\n    return JUST(oneflow::VectorAt(*outputs, 0));\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass EmbeddingFunctor {\n public:\n  EmbeddingFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"embedding\").Input(\"weight\").Input(\"indices\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& weight,\n                           const std::shared_ptr<one::Tensor>& indices,\n                           const Optional<int64_t>& padding_idx,\n                           const bool& scale_grad_by_freq) const {\n    CHECK_EQ_OR_RETURN(weight->ndim(), 2) << \"The dimension of weight should be 2\";\n    int64_t new_padding_idx = -1;\n    if (padding_idx.has_value()) { new_padding_idx = JUST(padding_idx); }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"padding_idx\", \"scale_grad_by_freq\");\n    attrs.SetAllAttrs(new_padding_idx, scale_grad_by_freq);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {weight, indices}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass MatMulNoBroadCastFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& mat2) const {\n    const auto& input_shape = input->shape();\n    const auto& mat2_shape = mat2->shape();\n    CHECK_EQ_OR_RETURN(input_shape->NumAxes(), 2)\n        << Error::RuntimeError() << \"self must be a matrix\";\n    CHECK_EQ_OR_RETURN(mat2_shape->NumAxes(), 2)\n        << Error::RuntimeError() << \"mat2 must be a matrix\";\n    CHECK_EQ_OR_RETURN(input_shape->at(1), mat2_shape->at(0))\n        << Error::RuntimeError() << \"mat1 and mat2 shapes cannot be multiplied (\"\n        << std::to_string(input_shape->at(0)) << \"x\" << std::to_string(input_shape->at(1))\n        << \" and \" << std::to_string(mat2_shape->at(0)) << \"x\" << std::to_string(mat2_shape->at(1))\n        << \")\";\n    return JUST(functional::MatMul(input, mat2, false, false, 1.0));\n  }\n};\n\nclass MatMulFunctor {\n public:\n  MatMulFunctor() {\n    matmul_op_ = CHECK_JUST(one::OpBuilder(\"matmul\").Input(\"a\").Input(\"b\").Output(\"out\").Build());\n    batch_matmul_op_ =\n        CHECK_JUST(one::OpBuilder(\"batch_matmul\").Input(\"a\").Input(\"b\").Output(\"out\").Build());\n    bcast_matmul_op_ =\n        CHECK_JUST(one::OpBuilder(\"broadcast_matmul\").Input(\"a\").Input(\"b\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& a,\n                           const 
std::shared_ptr<one::Tensor>& b, const bool& transpose_a,\n                           const bool& transpose_b, const double& alpha) const {\n    const auto& a_shape = a->shape();\n    const auto& b_shape = b->shape();\n    CHECK_GE_OR_RETURN(a_shape->NumAxes(), 1)\n        << Error::RuntimeError() << \"Tensor a's dim should >= 1\";\n    CHECK_GE_OR_RETURN(b_shape->NumAxes(), 1)\n        << Error::RuntimeError() << \"Tensor b's dim should >= 1\";\n\n    DeviceType device_type{};\n    if (a->is_global()) {\n      device_type = JUST(a->parallel_desc())->device_type();\n    } else {\n      device_type = JUST(a->device())->enum_type();\n    }\n    std::shared_ptr<one::Tensor> cast_a = a;\n    std::shared_ptr<one::Tensor> cast_b = b;\n    std::shared_ptr<one::Tensor> result;\n    if ((cast_a->dtype()->is_integer()) && (device_type == DeviceType::kCPU)) {\n      cast_a = JUST(functional::Cast(a, JUST(DType::Get(DataType::kFloat)), /*pin_memory=*/false));\n      cast_b = JUST(functional::Cast(b, JUST(DType::Get(DataType::kFloat)), /*pin_memory=*/false));\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"transpose_a\", \"transpose_b\", \"alpha\");\n    attrs.SetAllAttrs(transpose_a, transpose_b, alpha);\n    const int64_t a_num_axes = a_shape->NumAxes();\n    const int64_t b_num_axes = b_shape->NumAxes();\n    if (a_num_axes == 1 && b_num_axes == 2) {\n      result = JUST(VectorMatrixProduct(cast_a, cast_b));\n    } else if (a_num_axes == 2 && b_num_axes == 1) {\n      result = JUST(MatrixVectorProduct(cast_a, cast_b));\n    } else if (a_num_axes == 2 && b_num_axes == 2) {\n      result = JUST(OpInterpUtil::Dispatch<Tensor>(*matmul_op_, {cast_a, cast_b}, attrs));\n    } else if (a_num_axes == b_num_axes) {\n      bool if_batch_matmul = true;\n      for (int i = 0; i < a_num_axes - 2; ++i) {\n        if (a_shape->At(i) != b_shape->At(i)) {\n          if_batch_matmul = false;\n          break;\n        }\n      }\n      if (if_batch_matmul) {\n        result = JUST(OpInterpUtil::Dispatch<Tensor>(*batch_matmul_op_, {cast_a, cast_b}, attrs));\n      } else {\n        result = JUST(OpInterpUtil::Dispatch<Tensor>(*bcast_matmul_op_, {cast_a, cast_b}, attrs));\n      }\n    } else {\n      result = JUST(OpInterpUtil::Dispatch<Tensor>(*bcast_matmul_op_, {cast_a, cast_b}, attrs));\n    }\n\n    if ((a->dtype()->is_integer()) && (device_type == DeviceType::kCPU)) {\n      return JUST(functional::Cast(result, a->dtype(), /*pin_memory=*/false));\n    } else {\n      return result;\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> matmul_op_;\n  std::shared_ptr<OpExpr> batch_matmul_op_;\n  std::shared_ptr<OpExpr> bcast_matmul_op_;\n};\n\nclass BatchMatMulFunctor {\n public:\n  BatchMatMulFunctor() {\n    batch_matmul_op_ =\n        CHECK_JUST(one::OpBuilder(\"batch_matmul\").Input(\"a\").Input(\"b\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& a,\n                           const std::shared_ptr<one::Tensor>& b, const bool& transpose_a,\n                           const bool& transpose_b, const double& alpha) const {\n    const auto& a_shape = a->shape();\n    const auto& b_shape = b->shape();\n    CHECK_EQ_OR_RETURN(a_shape->NumAxes(), 3)\n        << Error::RuntimeError() << \"Expected 3-dimensional tensor, but got \" << a_shape->NumAxes()\n        << \"-dimensional tensor for argument #1\";\n    CHECK_EQ_OR_RETURN(b_shape->NumAxes(), 3)\n        << Error::RuntimeError() << \"Expected 3-dimensional tensor, but got \" << b_shape->NumAxes()\n        
<< \"-dimensional tensor for argument #2\";\n    CHECK_EQ_OR_RETURN(a_shape->At(0), b_shape->At(0))\n        << Error::RuntimeError() << \"Batch dim not match, please check input!\";\n    const int64_t matmul_dim_a = transpose_a ? a_shape->At(1) : a_shape->At(2);\n    const int64_t matmul_dim_b = transpose_b ? b_shape->At(2) : b_shape->At(1);\n    CHECK_EQ_OR_RETURN(matmul_dim_a, matmul_dim_b)\n        << Error::RuntimeError() << \"Matmul dim not match, got \" << matmul_dim_a << \" of mat1 and \"\n        << matmul_dim_b << \" of mat2, please check input!\";\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"transpose_a\", \"transpose_b\", \"alpha\");\n    attrs.SetAllAttrs(transpose_a, transpose_b, alpha);\n\n    DeviceType device_type{};\n    if (a->is_global()) {\n      device_type = JUST(a->parallel_desc())->device_type();\n    } else {\n      device_type = JUST(a->device())->enum_type();\n    }\n    std::shared_ptr<one::Tensor> cast_a = a;\n    std::shared_ptr<one::Tensor> cast_b = b;\n    if ((a->dtype()->is_integer()) && (device_type == DeviceType::kCPU)) {\n      cast_a = JUST(functional::Cast(a, JUST(DType::Get(DataType::kFloat)), /*pin_memory=*/false));\n      cast_b = JUST(functional::Cast(b, JUST(DType::Get(DataType::kFloat)), /*pin_memory=*/false));\n    }\n\n    auto result = JUST(OpInterpUtil::Dispatch<Tensor>(*batch_matmul_op_, {cast_a, cast_b}, attrs));\n    if ((a->dtype()->is_integer()) && (device_type == DeviceType::kCPU)) {\n      return JUST(functional::Cast(result, a->dtype(), /*pin_memory=*/false));\n    } else {\n      return result;\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> batch_matmul_op_;\n};\n\nclass VectorMatrixProductFunctor {\n public:\n  VectorMatrixProductFunctor() {\n    vector_matrix_product_op_ = CHECK_JUST(\n        one::OpBuilder(\"vector_matrix_product\").Input(\"a\").Input(\"b\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& vec,\n                           const std::shared_ptr<one::Tensor>& input) const {\n    const auto& vec_shape = vec->shape();\n    const auto& input_shape = input->shape();\n    CHECK_OR_RETURN(input_shape->NumAxes() == 2 && vec_shape->NumAxes() == 1)\n        << Error::RuntimeError() << \"vector @ matrix expected, got \"\n        << \"1, \" << input_shape->NumAxes() << \", \" << vec_shape->NumAxes();\n    CHECK_EQ_OR_RETURN(vec_shape->at(0), input_shape->at(0))\n        << Error::RuntimeError() << \"size mismatch, got \" << 1 << \", \"\n        << std::to_string(vec_shape->at(0)) << \" x \" << std::to_string(input_shape->at(0)) << \", \"\n        << std::to_string(input_shape->at(1));\n    return OpInterpUtil::Dispatch<Tensor>(*vector_matrix_product_op_, {vec, input});\n  }\n\n private:\n  std::shared_ptr<OpExpr> vector_matrix_product_op_;\n};\n\nclass TensorDotIntDimsFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& a, const std::shared_ptr<Tensor>& b,\n                           const int32_t dims) const {\n    CHECK_GE_OR_RETURN(dims, 0) << Error::RuntimeError()\n                                << \"tensordot expects dims >= 0, but got dims=\" << dims;\n    CHECK_LE_OR_RETURN(dims, a->ndim())\n        << Error::RuntimeError() << \"tensordot expects dims <= a.ndim which is \" << a->ndim()\n        << \", but got \" << dims;\n    CHECK_LE_OR_RETURN(dims, b->ndim())\n        << Error::RuntimeError() << \"tensordot expects dims <= b.ndim which is \" << b->ndim()\n        << \", but got \" << dims;\n    std::vector<int32_t> 
dot_dims_a(dims), dot_dims_b(dims);\n    for (int32_t i = 0; i < dims; i++) {\n      dot_dims_a[i] = a->ndim() - dims + i;\n      dot_dims_b[i] = i;\n    }\n    return JUST(functional::TensorDot(a, b, dot_dims_a, dot_dims_b));\n  }\n};\n\nclass TensorDotFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& a, const std::shared_ptr<Tensor>& b,\n                           const std::vector<int32_t>& dims_a,\n                           const std::vector<int32_t>& dims_b) const {\n    // dims_a and dims_b represent dim indices to calculate dot, and are copied to variables\n    // dot_dims_a and dot_dims_b when they need to be modified\n    CHECK_EQ_OR_RETURN(dims_a.size(), dims_b.size())\n        << Error::RuntimeError() << \"both dimension lists should have same length, got \"\n        << dims_a.size() << \" and \" << dims_b.size();\n\n    // dims_a.size() == dims_b.size(), and specially treat if both are empty\n    if (dims_a.empty()) {\n      DimVector shape_sum(a->ndim() + b->ndim());\n      for (int64_t i = 0; i < a->ndim(); i++) { shape_sum[i] = a->shape()->At(i); }\n      for (int64_t i = 0; i < b->ndim(); i++) { shape_sum[i + a->ndim()] = b->shape()->At(i); }\n      std::shared_ptr<Tensor> reshape_a = JUST(Reshape(a, Shape(DimVector{-1, 1})));\n      std::shared_ptr<Tensor> reshape_b = JUST(Reshape(b, Shape(DimVector{1, -1})));\n      return JUST(Reshape(JUST(functional::MatMul(reshape_a, reshape_b, false, false, 1.0)),\n                          Shape(DimVector(shape_sum.begin(), shape_sum.end()))));\n    }\n    std::vector<int32_t> dot_dims_a(dims_a.begin(), dims_a.end());\n    std::vector<int32_t> dot_dims_b(dims_b.begin(), dims_b.end());\n    for (int64_t i = 0; i < dot_dims_a.size(); i++) {\n      dot_dims_a[i] = JUST(maybe_wrap_dim(dot_dims_a[i], a->ndim()));\n      dot_dims_b[i] = JUST(maybe_wrap_dim(dot_dims_b[i], b->ndim()));\n    }\n    std::vector<bool> if_dot_dims_a(a->ndim(), false);\n    std::vector<bool> if_dot_dims_b(b->ndim(), false);\n    for (const int32_t dim_idx : dot_dims_a) {\n      CHECK_EQ_OR_RETURN(if_dot_dims_a[dim_idx], false)\n          << Error::RuntimeError() << \"dim \" << dim_idx\n          << \" appears multiple times in the list of dims\";\n      if_dot_dims_a[dim_idx] = true;\n    }\n    for (const int32_t dim_idx : dot_dims_b) {\n      CHECK_EQ_OR_RETURN(if_dot_dims_b[dim_idx], false)\n          << Error::RuntimeError() << \"dim \" << dim_idx\n          << \" appears multiple times in the list of dims\";\n      if_dot_dims_b[dim_idx] = true;\n    }\n\n    std::vector<int32_t> broadcast_dims_a, broadcast_dims_b;\n    for (int64_t i = 0; i < dot_dims_a.size(); i++) {\n      int64_t size_a = a->shape()->At(dot_dims_a[i]);\n      int64_t size_b = b->shape()->At(dot_dims_b[i]);\n      if (size_a == 1 && size_b > 1) {\n        broadcast_dims_b.emplace_back(dot_dims_b[i]);\n      } else if (size_b == 1 && size_a > 1) {\n        broadcast_dims_a.emplace_back(dot_dims_a[i]);\n      } else {\n        CHECK_EQ_OR_RETURN(size_a, size_b)\n            << Error::RuntimeError() << \"contracted dimensions need to match, but first has size \"\n            << size_a << \" in dim \" << dot_dims_a[i] << \" and second has size \" << size_b\n            << \" in dim \" << dot_dims_b[i];\n      }\n    }\n\n    // calculate ReduceSum for broadcasting of some axis\n    std::shared_ptr<Tensor> reduced_sum_a = a;\n    std::shared_ptr<Tensor> reduced_sum_b = b;\n    if (!broadcast_dims_a.empty())\n      reduced_sum_a = JUST(functional::ReduceSum(a, 
broadcast_dims_a, true, NullOpt));\n    if (!broadcast_dims_b.empty())\n      reduced_sum_b = JUST(functional::ReduceSum(b, broadcast_dims_b, true, NullOpt));\n\n    // int64_t non_dot_size_a = 1, non_dot_size_b = 1;\n    std::vector<int32_t> non_dot_shape_a, non_dot_shape_b;\n    non_dot_shape_a.reserve(a->ndim() - dot_dims_a.size() + b->ndim() - dot_dims_b.size());\n    non_dot_shape_b.reserve(b->ndim() - dot_dims_b.size());\n\n    std::vector<int32_t> permuted_dims_a, permuted_dims_b;\n    permuted_dims_a.reserve(a->ndim());\n    permuted_dims_b.reserve(b->ndim());\n\n    for (int32_t i = 0; i < a->ndim(); i++) {\n      if (!if_dot_dims_a[i]) {\n        permuted_dims_a.emplace_back(i);\n        // non_dot_size_a *= reduced_sum_a->shape()->At(i);\n        non_dot_shape_a.emplace_back(reduced_sum_a->shape()->At(i));\n      }\n    }\n\n    for (const int32_t dim_idx : dot_dims_a) permuted_dims_a.emplace_back(dim_idx);\n    for (const int32_t dim_idx : dot_dims_b) permuted_dims_b.emplace_back(dim_idx);\n\n    for (int32_t i = 0; i < b->ndim(); i++) {\n      if (!if_dot_dims_b[i]) {\n        permuted_dims_b.emplace_back(i);\n        // non_dot_size_b *= reduced_sum_b->shape()->At(i);\n        non_dot_shape_b.emplace_back(reduced_sum_b->shape()->At(i));\n      }\n    }\n    non_dot_shape_a.insert(non_dot_shape_a.end(), non_dot_shape_b.begin(), non_dot_shape_b.end());\n\n    int64_t dot_size = 1;\n    for (const int32_t dim_idx : dot_dims_a) dot_size *= reduced_sum_a->shape()->At(dim_idx);\n    std::shared_ptr<Tensor> permuted_a = JUST(\n        Reshape(JUST(Permute(reduced_sum_a, permuted_dims_a)), Shape(DimVector({-1, dot_size}))));\n    std::shared_ptr<Tensor> permuted_b = JUST(\n        Reshape(JUST(Permute(reduced_sum_b, permuted_dims_b)), Shape(DimVector({dot_size, -1}))));\n\n    return Reshape(JUST(functional::MatMul(permuted_a, permuted_b, false, false, 1.0)),\n                   Shape(DimVector({non_dot_shape_a.begin(), non_dot_shape_a.end()})));\n  }\n};\n\nclass FusedMLPFunctor {\n public:\n  FusedMLPFunctor() {\n#if CUDA_VERSION >= 11060\n    fused_op_.resize(kMaxInputCount /*the maximum number of inputs*/);\n    for (int n = 1; n < fused_op_.size(); ++n) {\n      fused_op_[n] = CHECK_JUST(one::OpBuilder(\"cublas_fused_mlp\")\n                                    .Input(\"x\")\n                                    .Input(\"weights\", n)\n                                    .Input(\"biases\", n)\n                                    .Output(\"out\")\n                                    .Output(\"cublas_aux\", n)\n                                    .Output(\"hidden\", n)\n                                    .Build());\n    }\n#endif\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const TensorTuple& weights,\n                           const TensorTuple& biases, bool skip_final_activation) const {\n    const int64_t weight_size = weights.size();\n    const int64_t bias_size = biases.size();\n    CHECK_GE_OR_RETURN(weight_size, 1)\n        << Error::RuntimeError() << \"The number of weights should be greater than or equal to 1. \";\n    CHECK_EQ_OR_RETURN(weight_size, bias_size)\n        << Error::RuntimeError() << \"The number of weights should be equal to the number of biases. 
\";\n    int64_t n = 0, k = 0;\n    /*\n    x: (m, k)\n    weight: (n, k) need transpose\n    bias: (n)\n    */\n    const auto& x_shape = x->shape();\n    k = x_shape->At(1);\n    for (int64_t i = 0; i < weight_size; i++) {\n      const auto& weight_shape = weights[i]->shape();\n      const auto& bias_shape = biases[i]->shape();\n\n      // TODO(): Support Fused batch/broadcast matmul.\n      CHECK_EQ_OR_RETURN(weight_shape->NumAxes(), 2)\n          << Error::RuntimeError() << \"Weight's dim size should == 2\";\n      CHECK_EQ_OR_RETURN(bias_shape->NumAxes(), 1)\n          << Error::RuntimeError() << \"Bias's dim size should == 1\";\n\n      n = weight_shape->At(0);\n      CHECK_EQ_OR_RETURN(bias_shape->At(0), n)\n          << Error::RuntimeError() << \"Bias's dim is not equal to weight's first dim. \";\n      CHECK_EQ_OR_RETURN(weight_shape->At(1), k)\n          << Error::RuntimeError() << \"weight's second dim should be equal to input's second dim. \";\n\n      // Set for next layer.\n      k = n;\n    }\n\n#if CUDA_VERSION >= 11060\n    DeviceType device_type{};\n    if (x->is_global()) {\n      device_type = JUST(x->parallel_desc())->device_type();\n    } else {\n      device_type = JUST(x->device())->enum_type();\n    }\n\n    if ((device_type == DeviceType::kCUDA) && (weight_size <= kMaxInputCount)\n        && (!ParseBooleanFromEnv(\"ONEFLOW_FUNCTOR_DISABLE_FUSED_MLP\", false))) {\n      TensorTuple input(2 * weight_size + 1);\n      input[0] = x;\n      std::copy(weights.begin(), weights.end(), input.begin() + 1);\n      std::copy(biases.begin(), biases.end(), input.begin() + 1 + weight_size);\n\n      auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"skip_final_activation\");\n      attrs.SetAllAttrs(skip_final_activation);\n      return OpInterpUtil::Dispatch<Tensor>(*fused_op_[weight_size], input, attrs);\n    }\n#endif  // CUDA_VERSION >= 11060\n\n    // Fall back to Naive matmul + bias_add + relu\n    std::shared_ptr<one::Tensor> out = x;\n    for (int32_t layer_idx = 0; layer_idx < weight_size; layer_idx++) {\n      out = JUST(\n          functional::BiasAdd(JUST(functional::MatMul(out, weights[layer_idx], false, true, 1.0)),\n                              biases[layer_idx], 1));\n      if ((layer_idx != weight_size - 1) || (!skip_final_activation)) {\n        /*\n        When it is not last dense layer, or it is last dense layer and skip_final_activate=False,\n        we add relu Layer.\n        */\n        out = JUST(functional::Relu(out, false));\n      }\n    }\n    return out;\n  }\n\n private:\n#if CUDA_VERSION >= 11060\n  std::vector<std::shared_ptr<OpExpr>> fused_op_;\n#endif\n};\n\nclass FusedMatmulBiasFunctor {\n public:\n  FusedMatmulBiasFunctor() {\n    _with_add_to_output_op = CHECK_JUST(one::OpBuilder(\"fused_matmul_bias\")\n                                            .Input(\"x\")\n                                            .Input(\"weight\")\n                                            .Input(\"bias\")\n                                            .Input(\"_add_to_output\")\n                                            .Output(\"out\")\n                                            .Build());\n    _without_add_to_output_op = CHECK_JUST(one::OpBuilder(\"fused_matmul_bias\")\n                                               .Input(\"x\")\n                                               .Input(\"weight\")\n                                               .Input(\"bias\")\n                                               .Output(\"out\")\n                                       
        .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& weight,\n                           const std::shared_ptr<one::Tensor>& bias,\n                           const Optional<one::Tensor>& _add_to_output, const double& alpha,\n                           const double& beta) const {\n    /*\n    x: (m_i, ... m_0, k)\n    weight: (n, k) need transpose\n    bias: (n)\n    */\n    const auto& x_shape = x->shape();\n    const int64_t k = x_shape->At(x->shape()->NumAxes() - 1);\n\n    const auto& weight_shape = weight->shape();\n    const auto& bias_shape = bias->shape();\n\n    CHECK_EQ_OR_RETURN(weight_shape->NumAxes(), 2)\n        << Error::RuntimeError() << \"Weight's dim size should == 2\";\n    CHECK_EQ_OR_RETURN(bias_shape->NumAxes(), 1)\n        << Error::RuntimeError() << \"Bias's dim size should == 1\";\n\n    const int64_t n = weight_shape->At(0);\n    CHECK_EQ_OR_RETURN(bias_shape->At(0), n)\n        << Error::RuntimeError() << \"Bias's dim is not equal to weight's first dim. \";\n    CHECK_EQ_OR_RETURN(weight_shape->At(1), k)\n        << Error::RuntimeError() << \"weight's second dim should be equal to input's last dim. \";\n\n#if CUDA_VERSION >= 11020\n    DeviceType device_type{};\n    if (x->is_global()) {\n      device_type = JUST(x->parallel_desc())->device_type();\n    } else {\n      device_type = JUST(x->device())->enum_type();\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"alpha\", \"beta\");\n    attrs.SetAllAttrs(alpha, beta);\n    if (device_type == DeviceType::kCUDA) {\n      if (_add_to_output) {\n        return OpInterpUtil::Dispatch<Tensor>(*_with_add_to_output_op,\n                                              {x, weight, bias, JUST(_add_to_output)}, attrs);\n      } else {\n        return OpInterpUtil::Dispatch<Tensor>(*_without_add_to_output_op, {x, weight, bias}, attrs);\n      }\n    }\n#endif  // CUDA_VERSION >= 11020\n\n    auto matmul_bias = JUST(functional::BiasAdd(\n        JUST(functional::MatMul(x, weight, false, true, alpha)), bias, x->shape()->NumAxes() - 1));\n    if (_add_to_output && beta != 0.0) {\n      if (beta == 1.0) {\n        return JUST(functional::Add({matmul_bias, JUST(_add_to_output)}, false));\n      } else {\n        return JUST(functional::Add(\n            {matmul_bias, JUST(functional::ScalarMul(JUST(_add_to_output), beta, false))}, false));\n      }\n    } else {\n      return matmul_bias;\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> _with_add_to_output_op;\n  std::shared_ptr<OpExpr> _without_add_to_output_op;\n};\n\nclass FusedMatmulBiasAddReluDropoutFunctor {\n public:\n  FusedMatmulBiasAddReluDropoutFunctor() {\n#if CUDA_VERSION >= 11060\n    fused_op_.resize(kMaxInputCount /*the maximum number of inputs*/);\n    for (int n = 1; n < fused_op_.size(); ++n) {\n      fused_op_[n] = CHECK_JUST(one::OpBuilder(\"fused_matmul_bias_add_relu_dropout\")\n                                    .Input(\"x\")\n                                    .Input(\"weights\", n)\n                                    .Input(\"biases\", n)\n                                    .Output(\"out\")\n                                    .Output(\"cublas_aux\", n)\n                                    .Output(\"hidden\", n)\n                                    .Build());\n    }\n#endif\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const TensorTuple& weights,\n                           const TensorTuple& biases, bool 
skip_final_activation,\n                           const std::vector<float>& dropout_rate_list,\n                           const Optional<one::Generator>& generator) const {\n    const int64_t weight_size = weights.size();\n    const int64_t bias_size = biases.size();\n    CHECK_GE_OR_RETURN(weight_size, 1)\n        << Error::RuntimeError() << \"The number of weights should be greater than or equal to 1. \";\n    CHECK_EQ_OR_RETURN(weight_size, bias_size)\n        << Error::RuntimeError() << \"The number of weights should be equal to the number of biases. \";\n    CHECK_EQ_OR_RETURN(weight_size, dropout_rate_list.size())\n        << Error::RuntimeError()\n        << \"The dropout rate list length should be equal to the number of weights. \";\n    int64_t n = 0, k = 0;\n    /*\n    x: (m, k)\n    weight: (n, k) need transpose\n    bias: (n)\n    */\n    const auto& x_shape = x->shape();\n    k = x_shape->At(1);\n    for (int64_t i = 0; i < weight_size; i++) {\n      CHECK_GE_OR_RETURN(dropout_rate_list[i], 0.0f)\n          << Error::RuntimeError() << \"Dropout rate should be >= 0.0\";\n\n      const auto& weight_shape = weights[i]->shape();\n      const auto& bias_shape = biases[i]->shape();\n      // TODO(): Support Fused batch/broadcast matmul.\n      CHECK_EQ_OR_RETURN(weight_shape->NumAxes(), 2) << \"Weight's dim should == 2\";\n      CHECK_EQ_OR_RETURN(bias_shape->NumAxes(), 1) << \"Bias's dim should == 1\";\n\n      n = weight_shape->At(0);\n      CHECK_EQ_OR_RETURN(bias_shape->At(0), n) << \"Bias's dim is not equal to weight's first dim. \";\n      CHECK_EQ_OR_RETURN(weight_shape->At(1), k)\n          << \"weight's second dim should be equal to input's last dim. \";\n      // Set for next layer.\n      k = n;\n    }\n\n    auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n\n#if CUDA_VERSION >= 11060\n    DeviceType device_type{};\n    if (x->is_global()) {\n      device_type = JUST(x->parallel_desc())->device_type();\n    } else {\n      device_type = JUST(x->device())->enum_type();\n    }\n    if ((device_type == DeviceType::kCUDA) && (weight_size <= kMaxInputCount)\n        && (!ParseBooleanFromEnv(\"ONEFLOW_FUNCTOR_DISABLE_FUSED_MLP\", false))) {\n      TensorTuple input(2 * weight_size + 1);\n      input[0] = x;\n      std::copy(weights.begin(), weights.end(), input.begin() + 1);\n      std::copy(biases.begin(), biases.end(), input.begin() + 1 + weight_size);\n\n      gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x));\n      auto& attrs =\n          THREAD_CACHED_MUTABLE_ATTR_MAP(\"skip_final_activation\", \"seed\", \"dropout_rate_list\");\n      attrs.SetAllAttrs(skip_final_activation, static_cast<int64_t>(gen->current_seed()),\n                        dropout_rate_list);\n      const auto& dropout_state = std::make_shared<FusedDropoutKernelState>(gen);\n      return OpInterpUtil::Dispatch<Tensor>(*fused_op_[weight_size], input,\n                                            OpExprInterpContext(attrs, dropout_state));\n    }\n#endif  // CUDA_VERSION >= 11060\n\n    // Fall back to Naive matmul + bias_add + relu + dropout\n    std::shared_ptr<one::Tensor> out = x;\n    for (int32_t layer_idx = 0; layer_idx < weight_size; layer_idx++) {\n      out = JUST(\n          functional::BiasAdd(JUST(functional::MatMul(out, weights[layer_idx], false, true, 1.0)),\n                              biases[layer_idx], 1));\n      if ((layer_idx != weight_size - 1) || !skip_final_activation) {\n        out = JUST(functional::Relu(out, false));\n        out = 
JUST(functional::Dropout(out, JUST(VectorAt(dropout_rate_list, layer_idx)),\n                                       /*training=*/true,\n                                       /*inplace=*/false,\n                                       /*generator=*/gen, /*addend=*/NullOpt));\n      } else {\n        out = JUST(functional::Dropout(out, JUST(VectorAt(dropout_rate_list, layer_idx)),\n                                       /*training=*/true,\n                                       /*inplace=*/false,\n                                       /*generator=*/gen, /*addend=*/NullOpt));\n      }\n    }\n    return out;\n  }\n\n private:\n#if CUDA_VERSION >= 11060\n  std::vector<std::shared_ptr<OpExpr>> fused_op_;\n#endif\n};\n\nclass LayerNormFunctor {\n public:\n  LayerNormFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"layer_norm\")\n                         .Input(\"x\")\n                         .Output(\"y\")\n                         .Output(\"mean\")\n                         .Output(\"inv_variance\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int64_t& begin_norm_axis,\n                           const int64_t& begin_params_axis, const double& epsilon) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"begin_norm_axis\", \"begin_params_axis\", \"epsilon\",\n                                                 \"center\", \"scale\");\n    attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon, false, false);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SkipLayerNormFunctor {\n public:\n  SkipLayerNormFunctor() {\n    std::vector<bool> bool_list = {true, false};\n\n    /* has_skip */\n    for (bool has_skip : bool_list) {\n      /* has_gamma */\n      for (bool has_gamma : bool_list) {\n        /* has_beta */\n        for (bool has_beta : bool_list) {\n          /* has_bias */\n          for (bool has_bias : bool_list) {\n            one::OpBuilder op_builder = one::OpBuilder(\"skip_layer_norm\").Input(\"x\");\n            if (has_gamma) { op_builder = op_builder.Input(\"gamma\"); }\n            if (has_beta) { op_builder = op_builder.Input(\"beta\"); }\n            if (has_bias) { op_builder = op_builder.Input(\"bias\"); }\n            if (has_skip) { op_builder = op_builder.Input(\"skip\"); }\n            op_builder = op_builder.Output(\"y\").Output(\"mean\").Output(\"inv_variance\");\n\n            std::shared_ptr<OpExpr> op_expr = CHECK_JUST(op_builder.Build());\n            ops_.insert(std::pair<std::tuple<bool, bool, bool, bool>, std::shared_ptr<OpExpr>>(\n                std::tuple<bool, bool, bool, bool>(has_skip, has_gamma, has_beta, has_bias),\n                op_expr));\n          }  // has_bias\n        }    // has_beta\n      }      // has_gamma\n    }        // has_skip\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const Optional<one::Tensor>& gamma, const Optional<one::Tensor>& beta,\n                           const Optional<one::Tensor>& bias, const Optional<one::Tensor>& skip,\n                           const double& epsilon, const double& alpha) const {\n    // check shape of x\n    const auto& x_shape = *(x->shape());\n    CHECK_GE_OR_RETURN(x_shape.NumAxes(), 2)\n        << \"number of axes of \\'x\\' should be greater than or equal to 2, yet got \"\n        << x_shape.NumAxes();\n\n    if (gamma) {\n      const auto& gamma_shape = 
*(JUST(gamma)->shape());\n      CHECK_EQ_OR_RETURN(gamma_shape.NumAxes(), 1)\n          << \"number of axes of \\'gamma\\' should be equal to 1, yet got \"\n          << gamma_shape.NumAxes();\n      CHECK_EQ_OR_RETURN(gamma_shape.At(0), x_shape.At(x_shape.NumAxes() - 1))\n          << \"the size of \\'gamma\\'(\" << gamma_shape.At(0)\n          << \") is not consistent with the last dimension of \\'x\\'(\"\n          << x_shape.At(x_shape.NumAxes() - 1) << \")\";\n    }\n    if (beta) {\n      const auto& beta_shape = *(JUST(beta)->shape());\n      CHECK_EQ_OR_RETURN(beta_shape.NumAxes(), 1)\n          << \"number of axes of \\'beta\\' should be equal to 1, yet got \"\n          << beta_shape.NumAxes();\n      CHECK_EQ_OR_RETURN(beta_shape.At(0), x_shape.At(x_shape.NumAxes() - 1))\n          << \"dimension 1 of \\'beta\\'(\" << beta_shape.At(0)\n          << \") is not consistent with the last dimension of \\'x\\'(\"\n          << x_shape.At(x_shape.NumAxes() - 1) << \")\";\n    }\n    if (bias) {\n      const auto& bias_shape = *(JUST(bias)->shape());\n      CHECK_EQ_OR_RETURN(bias_shape.NumAxes(), 1)\n          << \"number of axes of \\'bias\\' should be equal to 1, yet got \"\n          << bias_shape.NumAxes();\n      CHECK_EQ_OR_RETURN(bias_shape.At(0), x_shape.At(x_shape.NumAxes() - 1))\n          << \"dimension 1 of \\'bias\\'(\" << bias_shape.At(0)\n          << \") is not consistent with the last dimension of \\'x\\'(\"\n          << x_shape.At(x_shape.NumAxes() - 1) << \")\";\n    }\n    if (skip) {\n      const auto& skip_shape = *(JUST(skip)->shape());\n      CHECK_EQ_OR_RETURN(skip_shape, x_shape) << \"shape of \\'skip\\' is not the same as \\'x\\'\";\n    }\n\n    // set attributes\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"epsilon\", \"alpha\");\n    attrs.SetAllAttrs(epsilon, alpha);\n\n    // count number of all input tensors\n    size_t nb_inputs = 1;       // count x\n    if (skip) nb_inputs += 1;   // count skip\n    if (gamma) nb_inputs += 1;  // count gamma\n    if (beta) nb_inputs += 1;   // count beta\n    if (bias) nb_inputs += 1;   // count bias\n\n    // construct input tensor tuple\n    size_t tensor_index = 1;\n    TensorTuple input(nb_inputs);\n    bool has_gamma = false, has_beta = false, has_bias = false, has_skip = false;\n    input[0] = x;\n    if (gamma) {\n      input[tensor_index] = JUST(gamma);\n      tensor_index += 1;\n      has_gamma = true;\n    }\n    if (beta) {\n      input[tensor_index] = JUST(beta);\n      tensor_index += 1;\n      has_beta = true;\n    }\n    if (bias) {\n      input[tensor_index] = JUST(bias);\n      tensor_index += 1;\n      has_bias = true;\n    }\n    if (skip) {\n      input[tensor_index] = JUST(skip);\n      tensor_index += 1;\n      has_skip = true;\n    }\n\n    return OpInterpUtil::Dispatch<Tensor>(\n        *(ops_.find(std::tuple<bool, bool, bool, bool>(has_skip, has_gamma, has_beta, has_bias))\n              ->second),\n        input, attrs);\n  }\n\n private:\n  /* (has_skip, has_gamma, has_beta, has_bias) -> op */\n  std::map<std::tuple<bool, bool, bool, bool>, std::shared_ptr<OpExpr>> ops_;\n};\n\nclass LayerNormAffineFunctor {\n public:\n  LayerNormAffineFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"layer_norm\")\n                         .Input(\"x\")\n                         .Input(\"gamma\")\n                         .Input(\"beta\")\n                         .Output(\"y\")\n                         .Output(\"mean\")\n                         .Output(\"inv_variance\")\n           
              .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& gamma,\n                           const std::shared_ptr<one::Tensor>& beta, const int64_t& begin_norm_axis,\n                           const int64_t& begin_params_axis, const double& epsilon) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"begin_norm_axis\", \"begin_params_axis\", \"epsilon\",\n                                                 \"center\", \"scale\");\n    attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon, true, true);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, gamma, beta}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GroupNormFunctor {\n public:\n  GroupNormFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"group_norm\")\n                         .Input(\"x\")\n                         .Output(\"y\")\n                         .Output(\"mean\")\n                         .Output(\"inv_variance\")\n                         .Attr(\"affine\", false)\n                         .Build());\n    affine_op_ = CHECK_JUST(one::OpBuilder(\"group_norm\")\n                                .Input(\"x\")\n                                .Input(\"gamma\")\n                                .Input(\"beta\")\n                                .Output(\"y\")\n                                .Output(\"mean\")\n                                .Output(\"inv_variance\")\n                                .Attr(\"affine\", true)\n                                .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const Optional<one::Tensor>& gamma, const Optional<one::Tensor>& beta,\n                           const bool affine, const int32_t num_groups, const double& epsilon,\n                           const std::string& data_format, const std::string& activation) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"num_groups\", \"epsilon\", \"data_format\", \"activation\");\n    attrs.SetAllAttrs(num_groups, epsilon, data_format, activation);\n    if (affine) {\n      return OpInterpUtil::Dispatch<Tensor>(*affine_op_, {x, JUST(gamma), JUST(beta)}, attrs);\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> affine_op_;\n};\n\nbool CheckNormShape(const Shape& x_shape, const Shape& normalized_shape) {\n  if (x_shape.size() < normalized_shape.size()) { return false; }\n  // the trailing dims of x must match normalized_shape exactly\n  size_t b_ndim = x_shape.size() - normalized_shape.size();\n  for (size_t i = b_ndim; i < x_shape.size(); ++i) {\n    if (x_shape[i] != normalized_shape[i - b_ndim]) { return false; }\n  }\n  return true;\n}\n\nclass RMSNormFunctor {\n public:\n  RMSNormFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"rms_norm\").Input(\"x\").Output(\"y\").Output(\"inv_rms\").Build());\n    op_affine_ = CHECK_JUST(one::OpBuilder(\"rms_norm\")\n                                .Input(\"x\")\n                                .Input(\"weight\")\n                                .Output(\"y\")\n                                .Output(\"inv_rms\")\n                                .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, const Optional<one::Tensor>& weight,\n                           const Shape& normalized_shape, const float epsilon) const {\n    const Shape& x_shape = *x->shape();\n    if (weight) 
{\n      const Shape& w_shape = *JUST(weight)->shape();\n      CHECK_EQ_OR_RETURN(w_shape, normalized_shape)\n          << \"Expected weight to have the same shape as normalized_shape \"\n          << normalized_shape.ToString() << \", but got \" << w_shape.ToString();\n    }\n    if (!CheckNormShape(x_shape, normalized_shape)) {\n      // report the expected trailing dims, i.e. normalized_shape without its parentheses\n      const std::string normalized_shape_str = normalized_shape.ToString();\n      auto shape_str_without_parentheses =\n          normalized_shape_str.substr(1, normalized_shape_str.size() - 2);\n      return Error::RuntimeError()\n             << \"Given normalized_shape=\" << normalized_shape.ToString()\n             << \", expected input with shape (*, \" << shape_str_without_parentheses\n             << \"), but got input of \" << x_shape.ToString();\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"normalized_shape\", \"epsilon\");\n    attrs.SetAllAttrs(normalized_shape, epsilon);\n    if (weight) {\n      const DataType dtype = x->dtype()->data_type();\n      if (JUST(weight)->dtype()->data_type() != dtype) {\n        auto weight_cast = JUST(functional::Cast(JUST(weight), DType{dtype}, /*pin_memory=*/false));\n        return OpInterpUtil::Dispatch<Tensor>(*op_affine_, {x, weight_cast}, attrs);\n      }\n      return OpInterpUtil::Dispatch<Tensor>(*op_affine_, {x, JUST(weight)}, attrs);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> op_affine_;\n};\n\nclass SkipRMSNormFunctor {\n public:\n  SkipRMSNormFunctor() {\n    // build one op expression for every combination of optional inputs\n    std::vector<bool> bool_list = {true, false};\n\n    for (bool has_weight : bool_list) {\n      for (bool has_skip : bool_list) {\n        for (bool has_bias : bool_list) {\n          one::OpBuilder op_builder = one::OpBuilder(\"skip_rms_norm\").Input(\"x\");\n          if (has_weight) { op_builder = op_builder.Input(\"weight\"); }\n          if (has_bias) { op_builder = op_builder.Input(\"bias\"); }\n          if (has_skip) { op_builder = op_builder.Input(\"skip\"); }\n          op_builder = op_builder.Output(\"y\").Output(\"inv_rms\");\n\n          std::shared_ptr<OpExpr> op_expr = CHECK_JUST(op_builder.Build());\n          ops_.insert(std::pair<std::tuple<bool, bool, bool>, std::shared_ptr<OpExpr>>(\n              std::tuple<bool, bool, bool>(has_weight, has_skip, has_bias), op_expr));\n        }  // has_bias\n      }    // has_skip\n    }      // has_weight\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const Optional<one::Tensor>& weight, const Optional<one::Tensor>& bias,\n                           const Optional<one::Tensor>& skip, const double& epsilon,\n                           const double& alpha) const {\n    // check shape of x\n    const auto& x_shape = *(x->shape());\n    CHECK_GE_OR_RETURN(x_shape.NumAxes(), 2)\n        << \"number of axes of \\'x\\' should be greater than or equal to 2, yet got \"\n        << x_shape.NumAxes();\n\n    if (weight) {\n      const auto& weight_shape = *(JUST(weight)->shape());\n      CHECK_EQ_OR_RETURN(weight_shape.NumAxes(), 1)\n          << \"number of axes of \\'weight\\' should be equal to 1, yet got \"\n          << weight_shape.NumAxes();\n      CHECK_EQ_OR_RETURN(weight_shape.At(0), x_shape.At(x_shape.NumAxes() - 1))\n          << \"the size of \\'weight\\'(\" << weight_shape.At(0)\n          << \") is not consistent with the last dimension of \\'x\\'(\"\n          << x_shape.At(x_shape.NumAxes() - 1) << \")\";\n    }\n\n    if (bias) {\n      const auto& bias_shape = *(JUST(bias)->shape());\n      
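// as with 'weight' above, 'bias' must be 1-D and match the last dimension of 'x'\n      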
CHECK_EQ_OR_RETURN(bias_shape.NumAxes(), 1)\n          << \"number of axes of \\'bias\\' should be equal to 1, yet got \"\n          << bias_shape.NumAxes();\n      CHECK_EQ_OR_RETURN(bias_shape.At(0), x_shape.At(x_shape.NumAxes() - 1))\n          << \"the size of \\'bias\\'(\" << bias_shape.At(0)\n          << \") is not consistent with the last dimension of \\'x\\'(\"\n          << x_shape.At(x_shape.NumAxes() - 1) << \")\";\n    }\n\n    if (skip) {\n      const auto& skip_shape = *(JUST(skip)->shape());\n      CHECK_EQ_OR_RETURN(skip_shape, x_shape) << \"shape of \\'skip\\' is not the same as \\'x\\'\";\n    }\n\n    // set attributes\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"epsilon\", \"alpha\");\n    attrs.SetAllAttrs(epsilon, alpha);\n\n    // count number of all input tensors\n    size_t nb_inputs = 1;        // count x\n    if (skip) nb_inputs += 1;    // count skip\n    if (weight) nb_inputs += 1;  // count weight\n    if (bias) nb_inputs += 1;    // count bias\n\n    // construct input tensor tuple\n    size_t tensor_index = 1;\n    TensorTuple input(nb_inputs);\n    bool has_weight = false, has_bias = false, has_skip = false;\n    input[0] = x;\n    if (weight) {\n      input[tensor_index] = JUST(weight);\n      tensor_index += 1;\n      has_weight = true;\n    }\n    if (bias) {\n      input[tensor_index] = JUST(bias);\n      tensor_index += 1;\n      has_bias = true;\n    }\n    if (skip) {\n      input[tensor_index] = JUST(skip);\n      tensor_index += 1;\n      has_skip = true;\n    }\n\n    return OpInterpUtil::Dispatch<Tensor>(\n        *(ops_.find(std::tuple<bool, bool, bool>(has_weight, has_skip, has_bias))->second), input,\n        attrs);\n  }\n\n private:\n  /* (has_weight, has_skip, has_bias) -> op */\n  std::map<std::tuple<bool, bool, bool>, std::shared_ptr<OpExpr>> ops_;\n};\n\nclass PixelShuffleFunctor {\n public:\n  PixelShuffleFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int64_t& h_upscale_factor,\n                           const int64_t& w_upscale_factor) const {\n    CHECK_OR_RETURN(x->ndim() == 4) << Error::RuntimeError() << \"Only accepts 4D tensors\";\n    const int64_t batch = x->shape()->At(0);\n    const int64_t channel = x->shape()->At(1);\n    const int64_t height = x->shape()->At(2);\n    const int64_t width = x->shape()->At(3);\n    std::shared_ptr<one::Tensor> out;\n    CHECK_OR_RETURN(channel % (h_upscale_factor * w_upscale_factor) == 0)\n        << Error::RuntimeError()\n        << \"The number of channels of the input tensor must be divisible by \"\n           \"(upscale_factor * upscale_factor) or (h_upscale_factor * w_upscale_factor)\";\n    const int64_t new_c = channel / (h_upscale_factor * w_upscale_factor);\n    // (N, C*rh*rw, H, W) -> (N, C, rh*rw, H, W) -> (N, C, rh, rw, H, W)\n    // -> permute to (N, C, H, rh, W, rw) -> (N, C, H*rh, W*rw)\n    std::vector<int32_t> permute_vec = {0, 1, 4, 2, 5, 3};\n    std::vector<int64_t> reshape_vec_1 = {batch, new_c, h_upscale_factor * w_upscale_factor, height,\n                                          width};\n    Shape reshape_1(DimVector(reshape_vec_1.begin(), reshape_vec_1.end()));\n    std::vector<int64_t> reshape_vec_2 = {batch,  new_c, h_upscale_factor, w_upscale_factor,\n                                          height, width};\n    Shape reshape_2(DimVector(reshape_vec_2.begin(), reshape_vec_2.end()));\n    std::vector<int64_t> reshape_vec_3 = {batch, new_c, height * h_upscale_factor,\n                                          width * w_upscale_factor};\n    Shape reshape_3(DimVector(reshape_vec_3.begin(), reshape_vec_3.end()));\n    out = JUST(Reshape(x, 
reshape_1));\n    out = JUST(Reshape(out, reshape_2));\n    out = JUST(Permute(out, permute_vec));\n    out = JUST(Reshape(out, reshape_3));\n    return out;\n  }\n};\n\nclass TFPoolNDFunctor {\n public:\n  TFPoolNDFunctor() = default;\n  virtual ~TFPoolNDFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::vector<int32_t>& kernel_size,\n                           const std::vector<int32_t>& strides, const std::string& padding,\n                           const std::vector<int32_t>& padding_before,\n                           const std::vector<int32_t>& padding_after,\n                           const std::string& data_format, const bool& ceil_mode) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"pool_size\", \"strides\", \"padding\", \"padding_before\",\n                                       \"padding_after\", \"data_format\", \"ceil_mode\");\n    attrs.SetAllAttrs(kernel_size, strides, padding, padding_before, padding_after, data_format,\n                      ceil_mode);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n protected:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass MaxPoolNDFunctor {\n public:\n  explicit MaxPoolNDFunctor(const int& num_spatial_dims) : num_spatial_dims_(num_spatial_dims) {}\n  virtual ~MaxPoolNDFunctor() = default;\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,\n                                const std::vector<int32_t>& kernel_size,\n                                const Optional<std::vector<int32_t>>& stride,\n                                const std::vector<int32_t>& padding,\n                                const std::vector<int32_t>& dilation, const bool& return_indices,\n                                const bool& ceil_mode, const std::string& data_format) const {\n    // channels_last case\n    if (input->is_cuda() && num_spatial_dims_ == 2 && data_format == \"channels_last\") {\n      if (!return_indices && dilation.at(0) == 1 && dilation.at(1) == 1) {\n        // legacy TF-style maxpool2d: use the cuDNN implementation, which has high performance\n        // but does not support dilation/return_indices\n        std::vector<int32_t> padding_before{padding.at(0), padding.at(1)};\n        std::vector<int32_t> padding_after{padding.at(0), padding.at(1)};\n\n        auto& attrs =\n            THREAD_CACHED_MUTABLE_ATTR_MAP(\"pool_size\", \"strides\", \"padding\", \"padding_before\",\n                                           \"padding_after\", \"data_format\", \"ceil_mode\");\n        attrs.SetAllAttrs(kernel_size, stride ? 
*JUST(stride) : kernel_size,\n                          std::string(\"customized\"), padding_before, padding_after, data_format,\n                          ceil_mode);\n        TensorTuple output;\n        output.emplace_back(JUST(OpInterpUtil::Dispatch<Tensor>(*tf_maxpool_op_, {input}, attrs)));\n        return output;\n      }\n    }\n\n    std::shared_ptr<one::Tensor> unsqueezed_input;\n    bool is_batched = true;\n    std::string func_name;\n    if (num_spatial_dims_ == 1) {\n      func_name = \"max_pool1d\";\n    } else if (num_spatial_dims_ == 2) {\n      func_name = \"max_pool2d\";\n    } else {\n      func_name = \"max_pool3d\";\n    }\n    std::tie(unsqueezed_input, is_batched) = *JUST(batchify(input, num_spatial_dims_, func_name));\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"kernel_size\", \"padding\", \"stride\", \"dilation\",\n                                                 \"data_format\", \"return_indices\", \"ceil_mode\");\n    // If stride is None, set it to kernel_size to align with PyTorch.\n    attrs.SetAllAttrs(kernel_size, padding, stride ? *JUST(stride) : kernel_size, dilation,\n                      data_format, return_indices, ceil_mode);\n    const auto& pooling_out =\n        JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_, {unsqueezed_input}, attrs));\n    if (!is_batched) {\n      TensorTuple squeezed_pooling_out;  // (y, indices)\n      squeezed_pooling_out.emplace_back(\n          JUST(functional::Squeeze(pooling_out->at(0), std::vector<int32_t>{0})));\n      squeezed_pooling_out.emplace_back(\n          JUST(functional::Squeeze(pooling_out->at(1), std::vector<int32_t>{0})));\n      return squeezed_pooling_out;\n    }\n    return pooling_out;\n  }\n\n protected:\n  int32_t num_spatial_dims_;\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> tf_maxpool_op_;\n};\n\nclass AvgPoolNDFunctor {\n public:\n  AvgPoolNDFunctor() = default;\n  virtual ~AvgPoolNDFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::vector<int32_t>& kernel_size,\n                           const Optional<std::vector<int32_t>>& stride,\n                           const std::vector<int32_t>& padding, const bool& ceil_mode,\n                           const bool& count_include_pad, const int32_t& divisor_override,\n                           const std::string& data_format) const {\n    // legacy TF-style avgpool2d: use the cuDNN implementation, which has high performance but\n    // does not support count_include_pad or divisor_override.\n    if (x->is_cuda() && x->ndim() == 4 && data_format == \"channels_last\") {\n      CHECK_OR_THROW(count_include_pad)\n          << \"AvgPool2d with channels_last data format does not support count_include_pad=False \"\n             \"for now.\";\n      CHECK_EQ_OR_THROW(divisor_override, 0)\n          << \"AvgPool2d with channels_last data format does not support divisor_override for now.\";\n\n      std::vector<int32_t> padding_before{JUST(VectorAt(padding, 0)), JUST(VectorAt(padding, 1))};\n      std::vector<int32_t> padding_after{JUST(VectorAt(padding, 0)), JUST(VectorAt(padding, 1))};\n\n      auto& attrs =\n          THREAD_CACHED_MUTABLE_ATTR_MAP(\"pool_size\", \"strides\", \"padding\", \"padding_before\",\n                                         \"padding_after\", \"data_format\", \"ceil_mode\");\n      attrs.SetAllAttrs(kernel_size, stride ? 
*JUST(stride) : kernel_size,\n                        std::string(\"customized\"), padding_before, padding_after, data_format,\n                        ceil_mode);\n      return JUST(OpInterpUtil::Dispatch<Tensor>(*tf_avgpool_op_, {x}, attrs));\n    }\n\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"kernel_size\", \"padding\", \"stride\", \"data_format\",\n                                       \"ceil_mode\", \"count_include_pad\", \"divisor_override\");\n    // If stride is None, set it to kernel_size to align with PyTorch.\n    attrs.SetAllAttrs(kernel_size, padding, stride ? *JUST(stride) : kernel_size, data_format,\n                      ceil_mode, count_include_pad, divisor_override);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n protected:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> tf_avgpool_op_;\n};\n\nclass TFAvgPool2DFunctor : public TFPoolNDFunctor {\n public:\n  TFAvgPool2DFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"tf_avg_pool_2d\").Input(\"x\").Output(\"y\").Build());\n  }\n};\n\nclass MaxPool1DFunctor : public MaxPoolNDFunctor {\n public:\n  MaxPool1DFunctor() : MaxPoolNDFunctor(/*num_spatial_dims=*/1) {\n    op_ = CHECK_JUST(one::OpBuilder(\"max_pool_1d\").Input(\"x\").Output(\"y\").Output(\"indice\").Build());\n  }\n};\n\nclass MaxPool2DFunctor : public MaxPoolNDFunctor {\n public:\n  MaxPool2DFunctor() : MaxPoolNDFunctor(/*num_spatial_dims=*/2) {\n    op_ = CHECK_JUST(one::OpBuilder(\"max_pool_2d\").Input(\"x\").Output(\"y\").Output(\"indice\").Build());\n    tf_maxpool_op_ = CHECK_JUST(one::OpBuilder(\"tf_max_pool_2d\").Input(\"x\").Output(\"y\").Build());\n  }\n};\n\nclass MaxPool3DFunctor : public MaxPoolNDFunctor {\n public:\n  MaxPool3DFunctor() : MaxPoolNDFunctor(/*num_spatial_dims=*/3) {\n    op_ = CHECK_JUST(one::OpBuilder(\"max_pool_3d\").Input(\"x\").Output(\"y\").Output(\"indice\").Build());\n  }\n};\n\nclass AvgPool1DFunctor : public AvgPoolNDFunctor {\n public:\n  AvgPool1DFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"avg_pool_1d\").Input(\"x\").Output(\"y\").Build());\n  }\n};\n\nclass AvgPool2DFunctor : public AvgPoolNDFunctor {\n public:\n  AvgPool2DFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"avg_pool_2d\").Input(\"x\").Output(\"y\").Build());\n    tf_avgpool_op_ = CHECK_JUST(one::OpBuilder(\"tf_avg_pool_2d\").Input(\"x\").Output(\"y\").Build());\n  }\n};\n\nclass AvgPool3DFunctor : public AvgPoolNDFunctor {\n public:\n  AvgPool3DFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"avg_pool_3d\").Input(\"x\").Output(\"y\").Build());\n  }\n};\n\ntemplate<int N>\nclass MaxUnpoolNDFunctor {\n public:\n  MaxUnpoolNDFunctor()\n      : op_(CHECK_JUST(one::OpBuilder(fmt::format(\"max_unpool_{}d\", N))\n                           .Input(\"x\")\n                           .Input(\"indices\")\n                           .Output(\"y\")\n                           .Build())) {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& indices,\n                           const std::vector<int32_t>& kernel_size,\n                           const Optional<std::vector<int32_t>>& stride,\n                           const std::vector<int32_t>& padding,\n                           const Optional<Shape>& output_size) const {\n    const auto fmt_error_msg = [](const std::string& name, int32_t num, bool check_element) {\n      if (check_element) {\n        return fmt::format(\"each element in `{}` must be greater than 0, got {}\", 
name, num);\n      }\n      return fmt::format(\"`{}` must be an integer or a list of {} integers\", name, N);\n    };\n\n    CHECK_EQ_OR_RETURN(kernel_size.size(), N) << fmt_error_msg(\"kernel_size\", N, false);\n    for (int32_t pool_dim : kernel_size) {\n      CHECK_GT_OR_RETURN(pool_dim, 0) << fmt_error_msg(\"kernel_size\", pool_dim, true);\n    }\n\n    if (stride) {\n      CHECK_EQ_OR_RETURN(JUST(stride)->size(), N) << fmt_error_msg(\"stride\", N, false);\n      for (int32_t stride_dim : *JUST(stride)) {\n        CHECK_GT_OR_RETURN(stride_dim, 0) << fmt_error_msg(\"stride\", stride_dim, true);\n      }\n    }\n    for (size_t i = 0; i < padding.size(); ++i) {\n      CHECK_GE_OR_RETURN(kernel_size[i], 2 * padding[i])\n          << \"pad should be at most half of the kernel size\";\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"kernel_size\", \"padding\", \"stride\",\n                                                 \"has_output_size\", \"output_size\");\n    attrs.SetAllAttrs(kernel_size, padding, stride ? *JUST(stride) : kernel_size,\n                      output_size.has_value(),\n                      output_size.has_value() ? *JUST(output_size) : Shape());\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, indices}, attrs);\n  }\n\n protected:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass AdaptivePoolNDFunctor {\n public:\n  AdaptivePoolNDFunctor() = default;\n  virtual ~AdaptivePoolNDFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::vector<int64_t>& output_size,\n                           const std::string& data_format) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"output_size\", \"data_format\");\n    attrs.SetAllAttrs(output_size, data_format);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n protected:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass AdaptiveAvgPool1DFunctor : public AdaptivePoolNDFunctor {\n public:\n  AdaptiveAvgPool1DFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"adaptive_avg_pool1d\").Input(\"x\").Output(\"y\").Build());\n  }\n};\n\nclass AdaptiveAvgPool2DFunctor : public AdaptivePoolNDFunctor {\n public:\n  AdaptiveAvgPool2DFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"adaptive_avg_pool2d\").Input(\"x\").Output(\"y\").Build());\n  }\n};\n\nclass AdaptiveAvgPool3DFunctor : public AdaptivePoolNDFunctor {\n public:\n  AdaptiveAvgPool3DFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"adaptive_avg_pool3d\").Input(\"x\").Output(\"y\").Build());\n  }\n};\n\nclass AdaptiveMaxPoolBaseFunctor {\n public:\n  AdaptiveMaxPoolBaseFunctor() = default;\n  virtual ~AdaptiveMaxPoolBaseFunctor() = default;\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x,\n                                const std::vector<int64_t>& output_size,\n                                const std::string& data_format) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"output_size\", \"data_format\");\n    attrs.SetAllAttrs(output_size, data_format);\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {x}, attrs);\n  }\n\n protected:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass AdaptiveMaxPool1DFunctor : public AdaptiveMaxPoolBaseFunctor {\n public:\n  AdaptiveMaxPool1DFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"adaptive_max_pool1d\").Input(\"x\").Output(\"y\").Output(\"index\").Build());\n  }\n};\n\nclass AdaptiveMaxPool2DFunctor : public AdaptiveMaxPoolBaseFunctor {\n public:\n  AdaptiveMaxPool2DFunctor() {\n    
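// like the 1-D variant above, \"index\" holds the argmax locations (presumably consumed by the backward pass)\n    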
op_ = CHECK_JUST(\n        one::OpBuilder(\"adaptive_max_pool2d\").Input(\"x\").Output(\"y\").Output(\"index\").Build());\n  }\n};\n\nclass AdaptiveMaxPool3DFunctor : public AdaptiveMaxPoolBaseFunctor {\n public:\n  AdaptiveMaxPool3DFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"adaptive_max_pool3d\").Input(\"x\").Output(\"y\").Output(\"index\").Build());\n  }\n};\n\nclass LossFunctorBase {\n public:\n  Maybe<Tensor> apply_reduction(const Maybe<Tensor>& x, const std::string& reduction) const {\n    CHECK_OR_RETURN(reduction == \"none\" || reduction == \"sum\" || reduction == \"mean\")\n        << Error::RuntimeError() << \"Reduction should be none, sum or mean.\";\n    if (reduction == \"sum\") { return functional::ReduceSum(JUST(x), {}, false, NullOpt); }\n    if (reduction == \"mean\") { return functional::ReduceMean(JUST(x), {}, false); }\n    return x;\n  }\n\n protected:\n  LossFunctorBase() = default;\n  virtual ~LossFunctorBase() = default;\n};\n\nclass MseLossFunctor : public LossFunctorBase {\n public:\n  MseLossFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target,\n                           const std::string& reduction) const {\n    const auto out = sequence_function(functional::Sub)\n                         .then(functional::Square)\n                         .call(input, target, /*alpha=*/1.0, /*inplace=*/false);\n    return apply_reduction(out, reduction);\n  }\n};\n\nclass L1LossFunctor : public LossFunctorBase {\n public:\n  L1LossFunctor() {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target,\n                           const std::string& reduction) const {\n    const auto out = sequence_function(functional::Sub)\n                         .then(functional::Abs)\n                         .call(input, target, /*alpha=*/1.0, /*inplace=*/false);\n    return apply_reduction(out, reduction);\n  }\n};\n\nclass SmoothL1LossFunctor : public LossFunctorBase {\n public:\n  SmoothL1LossFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"smooth_l1_loss\").Input(\"input\").Input(\"target\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target, const float& beta,\n                           const std::string& reduction) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"beta\");\n    attrs.SetAllAttrs(beta);\n    return apply_reduction(OpInterpUtil::Dispatch<Tensor>(*op_, {input, target}, attrs), reduction);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass KLDivLossFunctor : public LossFunctorBase {\n public:\n  KLDivLossFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"kl_div_loss\").Input(\"input\").Input(\"target\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target, const bool log_target,\n                           const std::string& reduction) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"log_target\");\n    attrs.SetAllAttrs(log_target);\n    if (reduction == \"batchmean\" && input->ndim() != 0) {\n      // batchmean: sum over all elements, then divide by the batch size\n      const auto& result = JUST(\n          apply_reduction(OpInterpUtil::Dispatch<Tensor>(*op_, {input, target}, attrs), \"sum\"));\n      return ScalarDiv(result, input->shape()->At(0));\n    } 
else {\n      return apply_reduction(OpInterpUtil::Dispatch<Tensor>(*op_, {input, target}, attrs),\n                             reduction);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass MarginRankingLossFunctor : public LossFunctorBase {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input_1,\n                           const std::shared_ptr<one::Tensor>& input_2,\n                           const std::shared_ptr<one::Tensor>& target, const float margin,\n                           const std::string& reduction) const {\n    const auto out =\n        sequence_function(functional::Sub)\n            .then(functional::Negative)\n            .then(std::bind(functional::Mul, target, std::placeholders::_1))\n            .then([&margin](const std::shared_ptr<one::Tensor>& x) {\n              return functional::ScalarAdd(x, Scalar(margin), /*alpha=*/1, /*inplace=*/true);\n            })\n            .then(std::bind(functional::Clamp, std::placeholders::_1, Scalar(0), NullOpt))\n            .call(input_1, input_2, /*alpha=*/1.0, /*inplace=*/false);\n    return apply_reduction(out, reduction);\n  }\n};\n\nclass BinaryCrossEntropyLossFunctor : public LossFunctorBase {\n public:\n  BinaryCrossEntropyLossFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"binary_cross_entropy\")\n                         .Input(\"input\")\n                         .Input(\"target\")\n                         .Output(\"out\")\n                         .Build());\n    op_weight_ = CHECK_JUST(one::OpBuilder(\"binary_cross_entropy\")\n                                .Input(\"input\")\n                                .Input(\"target\")\n                                .Input(\"weight\")\n                                .Output(\"out\")\n                                .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target,\n                           const Optional<one::Tensor>& weight,\n                           const std::string& reduction) const {\n    auto out = weight ? 
OpInterpUtil::Dispatch<Tensor>(*op_weight_, {input, target, JUST(weight)})\n                      : OpInterpUtil::Dispatch<Tensor>(*op_, {input, target});\n    return apply_reduction(out, reduction);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> op_weight_;\n};\n\nclass BinaryCrossEntropyWithLogitsLossFunctor : public LossFunctorBase {\n public:\n  BinaryCrossEntropyWithLogitsLossFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"binary_cross_entropy_with_logits\")\n                         .Input(\"input\")\n                         .Input(\"target\")\n                         .Output(\"out\")\n                         .Build());\n    op_weight_ = CHECK_JUST(one::OpBuilder(\"binary_cross_entropy_with_logits\")\n                                .Input(\"input\")\n                                .Input(\"target\")\n                                .Input(\"weight\")\n                                .Output(\"out\")\n                                .Build());\n    op_pos_ = CHECK_JUST(one::OpBuilder(\"binary_cross_entropy_with_logits\")\n                             .Input(\"input\")\n                             .Input(\"target\")\n                             .Input(\"pos_weight\")\n                             .Output(\"out\")\n                             .Build());\n    op_weight_pos_ = CHECK_JUST(one::OpBuilder(\"binary_cross_entropy_with_logits\")\n                                    .Input(\"input\")\n                                    .Input(\"target\")\n                                    .Input(\"weight\")\n                                    .Input(\"pos_weight\")\n                                    .Output(\"out\")\n                                    .Build());\n    op_reduce_mean_ = CHECK_JUST(one::OpBuilder(\"binary_cross_entropy_with_logits_reduce_mean\")\n                                     .Input(\"input\")\n                                     .Input(\"target\")\n                                     .Output(\"out\")\n                                     .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target,\n                           const Optional<one::Tensor>& weight,\n                           const Optional<one::Tensor>& pos_weight,\n                           const std::string& reduction) const {\n    if (pos_weight) {\n      const auto pos_weight_shape = JUST(pos_weight)->shape();\n      // pos weight shape = (), (1,), (1,1)... 
or (input/target.shape[-1],)\n      const bool is_pos_weight_shape_valid =\n          (pos_weight_shape->elem_cnt() == 1)\n          || (pos_weight_shape->NumAxes() == 1\n              && pos_weight_shape->At(0) == target->shape()->back());\n\n      CHECK_OR_RETURN(is_pos_weight_shape_valid)\n          << Error::RuntimeError()\n          << \"pos_weight must be a vector with length equal to the number of classes.\";\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"has_pos_weight\");\n    attrs.SetAllAttrs(pos_weight.has_value());\n    std::shared_ptr<Tensor> out;\n    if (weight) {\n      if (pos_weight) {\n        out = JUST(OpInterpUtil::Dispatch<Tensor>(\n            *op_weight_pos_, {input, target, JUST(weight), JUST(pos_weight)}, attrs));\n      } else {\n        out =\n            JUST(OpInterpUtil::Dispatch<Tensor>(*op_weight_, {input, target, JUST(weight)}, attrs));\n      }\n    } else {\n      if (pos_weight) {\n        out = JUST(\n            OpInterpUtil::Dispatch<Tensor>(*op_pos_, {input, target, JUST(pos_weight)}, attrs));\n      } else {\n        if (reduction == \"mean\") {\n          return OpInterpUtil::Dispatch<Tensor>(*op_reduce_mean_, {input, target});\n        }\n        out = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {input, target}, attrs));\n      }\n    }\n    return apply_reduction(out, reduction);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> op_weight_;\n  std::shared_ptr<OpExpr> op_pos_;\n  std::shared_ptr<OpExpr> op_weight_pos_;\n  std::shared_ptr<OpExpr> op_reduce_mean_;\n};\n\nclass NLLLossFunctor {\n public:\n  NLLLossFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"nll\")\n                         .Input(\"input\")\n                         .Input(\"target\")\n                         .Output(\"output\")\n                         .Output(\"out_weight\")\n                         .Build());\n\n    op_weight_ = CHECK_JUST(one::OpBuilder(\"nll\")\n                                .Input(\"input\")\n                                .Input(\"target\")\n                                .Input(\"weight\")\n                                .Output(\"output\")\n                                .Output(\"out_weight\")\n                                .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target,\n                           const Optional<one::Tensor>& weight, const int64_t& ignore_index,\n                           const std::string& reduction) const {\n    CHECK_OR_RETURN(reduction == \"none\" || reduction == \"sum\" || reduction == \"mean\")\n        << Error::RuntimeError() << \"Reduction should be none, sum or mean.\";\n\n    const auto& input_shape = input->shape();\n    const int64_t K = input_shape->NumAxes();\n    CHECK_GE_OR_RETURN(K, 2) << Error::RuntimeError() << \"Expected 2 or more dimensions\";\n    const int64_t N = input_shape->At(0);\n    const int64_t C = input_shape->At(1);\n\n    const auto& target_shape = target->shape();\n    CHECK_EQ_OR_RETURN(target_shape->NumAxes(), K - 1)\n        << Error::RuntimeError() << \"Expected target dimensions (\" << K - 1\n        << \") to match input dimensions (\" << K << \"), got \" << target_shape->NumAxes();\n    CHECK_EQ_OR_RETURN(target_shape->At(0), N)\n        << Error::RuntimeError() << \"Expected input batch_size (\" << N\n        << \") to match target batch_size (\" << target_shape->At(0) << \")\";\n\n    std::shared_ptr<one::Tensor> input_;\n   
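 // views of input/target, flattened below when the input has extra spatial dims (K > 2)\n   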
 std::shared_ptr<one::Tensor> target_;\n    if (K > 2) {\n      DimVector ideal_target_dim_vec;\n      ideal_target_dim_vec.push_back(N);\n      for (int64_t i = 2; i < K; ++i) { ideal_target_dim_vec.push_back(input_shape->At(i)); }\n      Shape ideal_target_shape(ideal_target_dim_vec);\n      CHECK_EQ_OR_RETURN(*target_shape, ideal_target_shape)\n          << Error::RuntimeError() << \"Expected target shape \" << ideal_target_shape.ToString()\n          << \", got \" << target_shape->ToString();\n\n      // move the class dim to the end: (N, C, d1, ..., dk) -> (N, d1, ..., dk, C)\n      std::vector<int> perm(input_shape->dim_vec().size(), 0);\n      perm[perm.size() - 1] = 1;\n      for (size_t i = 1; i < perm.size() - 1; ++i) { perm[i] = i + 1; }\n\n      input_ = JUST(sequence_function(functional::Transpose)\n                        .then(std::bind(functional::Reshape, std::placeholders::_1, Shape({-1, C})))\n                        .call(input, perm));\n      target_ = JUST(functional::Flatten(target, 0, K - 2));\n    } else {\n      input_ = input;\n      target_ = target;\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"ignore_index\");\n    attrs.SetAllAttrs(ignore_index);\n\n    std::shared_ptr<TensorTuple> nll_result;\n    if (weight) {\n      nll_result = JUST(\n          OpInterpUtil::Dispatch<TensorTuple>(*op_weight_, {input_, target_, JUST(weight)}, attrs));\n    } else {\n      nll_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_, {input_, target_}, attrs));\n    }\n    auto output = JUST(VectorAt(*nll_result, 0));\n\n    if (K > 2) { output = JUST(functional::Reshape(output, *target_shape)); }\n\n    if (reduction == \"none\") { return output; }\n\n    auto sum = JUST(functional::ReduceSum(output, {}, false, NullOpt));\n\n    if (reduction == \"sum\") { return sum; }\n\n    auto total_weight =\n        JUST(functional::ReduceSum(JUST(VectorAt(*nll_result, 1)), {}, false, NullOpt));\n    return functional::Div(sum, total_weight);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> op_weight_;\n};\n\nclass CrossEntropyFunctor {\n public:\n  CrossEntropyFunctor() {\n    op_log_softmax_ = CHECK_JUST(one::OpBuilder(\"log_softmax\").Input(\"in\").Output(\"prob\").Build());\n\n    op_nll_ = CHECK_JUST(one::OpBuilder(\"nll\")\n                             .Input(\"input\")\n                             .Input(\"target\")\n                             .Output(\"output\")\n                             .Output(\"out_weight\")\n                             .Build());\n\n    op_nll_weight_ = CHECK_JUST(one::OpBuilder(\"nll\")\n                                    .Input(\"input\")\n                                    .Input(\"target\")\n                                    .Input(\"weight\")\n                                    .Output(\"output\")\n                                    .Output(\"out_weight\")\n                                    .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target,\n                           const Optional<one::Tensor>& weight, const int64_t& ignore_index,\n                           const std::string& reduction, const double& label_smoothing) const {\n    if (input->shape() == target->shape()) {\n      CHECK_OR_RETURN(target->dtype()->is_floating_point())\n          << \"Expected floating point type for target with class probabilities, got \"\n          << target->dtype()->name();\n      CHECK_LT_OR_RETURN(ignore_index, 0)\n          << \"ignore_index is not supported for floating point target\";\n    
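  // soft-label path: 'target' carries per-class probabilities and matches the shape of 'input'\n    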
  return CrossEntropyProb(input, target, weight, reduction, label_smoothing);\n    }\n    if (label_smoothing > 0.0) {\n      return CrossEntropyLabelSmoothing(input, target, weight, ignore_index, reduction,\n                                        label_smoothing);\n    }\n    CHECK_OR_RETURN(reduction == \"none\" || reduction == \"sum\" || reduction == \"mean\")\n        << Error::RuntimeError() << \"Reduction should be none, sum or mean.\";\n    const auto& input_shape = input->shape();\n    const auto& target_shape = target->shape();\n\n    std::vector<int> input_perm(input_shape->dim_vec().size(), 0);\n    input_perm[input_perm.size() - 1] = 1;\n    for (size_t i = 1; i < input_perm.size() - 1; ++i) { input_perm[i] = i + 1; }\n\n    const auto input_ = JUST(sequence_function(functional::Transpose)\n                                 .then(std::bind(functional::Reshape, std::placeholders::_1,\n                                                 Shape({-1, input_shape->At(1)})))\n                                 .then([this](const std::shared_ptr<one::Tensor>& x) {\n                                   return OpInterpUtil::Dispatch<Tensor>(*op_log_softmax_, {x});\n                                 })\n                                 .call(input, input_perm));\n\n    const auto target_ = JUST(functional::Flatten(target, 0, target->shape()->NumAxes() - 1));\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"ignore_index\");\n    attrs.SetAllAttrs(ignore_index);\n\n    std::shared_ptr<TensorTuple> nll_result;\n    if (weight) {\n      nll_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(\n          *op_nll_weight_, {input_, target_, JUST(weight)}, attrs));\n    } else {\n      nll_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_nll_, {input_, target_}, attrs));\n    }\n\n    auto output = JUST(VectorAt(*nll_result, 0));\n    output = JUST(functional::Reshape(output, *target_shape));\n    if (reduction == \"none\") { return output; }\n\n    auto sum = JUST(functional::ReduceSum(output, {}, false, NullOpt));\n    if (reduction == \"sum\") { return sum; }\n\n    auto total_weight =\n        JUST(functional::ReduceSum(JUST(VectorAt(*nll_result, 1)), {}, false, NullOpt));\n    return functional::Div(sum, total_weight);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_log_softmax_;\n  std::shared_ptr<OpExpr> op_nll_;\n  std::shared_ptr<OpExpr> op_nll_weight_;\n};\n\nclass CrossEntropyLabelSmoothingFunctor {\n public:\n  CrossEntropyLabelSmoothingFunctor() {\n    op_log_softmax_ = CHECK_JUST(one::OpBuilder(\"log_softmax\").Input(\"in\").Output(\"prob\").Build());\n\n    op_nll_ = CHECK_JUST(one::OpBuilder(\"nll\")\n                             .Input(\"input\")\n                             .Input(\"target\")\n                             .Output(\"output\")\n                             .Output(\"out_weight\")\n                             .Build());\n\n    op_nll_weight_ = CHECK_JUST(one::OpBuilder(\"nll\")\n                                    .Input(\"input\")\n                                    .Input(\"target\")\n                                    .Input(\"weight\")\n                                    .Output(\"output\")\n                                    .Output(\"out_weight\")\n                                    .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target,\n                           const Optional<one::Tensor>& weight, const int64_t& ignore_index,\n                           
const std::string& reduction, const double& label_smoothing) const {\n    CHECK_OR_RETURN(reduction == \"none\" || reduction == \"sum\" || reduction == \"mean\")\n        << Error::RuntimeError() << \"Reduction should be none, sum or mean.\";\n    const auto& input_shape = input->shape();\n    const auto& target_shape = target->shape();\n\n    std::vector<int> input_perm(input_shape->dim_vec().size(), 0);\n    input_perm[input_perm.size() - 1] = 1;\n    for (size_t i = 1; i < input_perm.size() - 1; ++i) { input_perm[i] = i + 1; }\n    CHECK_OR_RETURN(label_smoothing > 0.0 && label_smoothing <= 1.0)\n        << \"label_smoothing must be between 0.0 and 1.0. Got: \" << label_smoothing;\n\n    const auto& input_ = JUST(sequence_function(functional::Transpose)\n                                  .then(std::bind(functional::Reshape, std::placeholders::_1,\n                                                  Shape({-1, input_shape->At(1)})))\n                                  .then([this](const std::shared_ptr<one::Tensor>& x) {\n                                    return OpInterpUtil::Dispatch<Tensor>(*op_log_softmax_, {x});\n                                  })\n                                  .call(input, input_perm));\n    const auto& target_ = JUST(functional::Flatten(target, 0, target->shape()->NumAxes() - 1));\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"ignore_index\");\n    attrs.SetAllAttrs(ignore_index);\n\n    std::shared_ptr<TensorTuple> nll_result;\n    if (weight) {\n      nll_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(\n          *op_nll_weight_, {input_, target_, JUST(weight)}, attrs));\n    } else {\n      nll_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_nll_, {input_, target_}, attrs));\n    }\n\n    const auto& ignore_mask = JUST(Reshape(JUST(ScalarLogicalEqual(target_, ignore_index)), {-1}));\n\n    // smooth_loss = (-(input_ * weight.reshape(1, -1)).sum(1) * ~ignore_mask).reshape_as(target)\n    std::shared_ptr<Tensor> smooth_loss = input_;\n    if (weight) {\n      const auto& weight_2d = JUST(Reshape(JUST(weight), {1, -1}));\n      smooth_loss = JUST(Mul(smooth_loss, weight_2d));\n    }\n    smooth_loss = JUST(Negative(JUST(ReduceSum(smooth_loss, {1}, false, NullOpt))));\n    smooth_loss = JUST(MaskedFill(smooth_loss, ignore_mask, 0.0));\n    smooth_loss = JUST(Reshape(smooth_loss, *target_shape));\n\n    int64_t n_classes = input->shape()->At(1);\n    auto nll_loss = JUST(VectorAt(*nll_result, 0));\n    nll_loss = JUST(functional::Reshape(nll_loss, *target_shape));\n\n    // loss = nll_loss * (1 - label_smoothing) + smooth_loss * label_smoothing / num_classes\n    if (reduction == \"none\") {\n      return JUST(Add(JUST(ScalarMul(nll_loss, 1 - label_smoothing, false)),\n                      JUST(ScalarMul(smooth_loss, label_smoothing / n_classes, false)), 1, false));\n    }\n\n    const auto& nll_loss_sum = JUST(ReduceSum(nll_loss, {}, false, NullOpt));\n    const auto& smooth_loss_sum = JUST(ReduceSum(smooth_loss, {}, false, NullOpt));\n    const auto& cross_entropy_loss_sum =\n        JUST(Add(JUST(ScalarMul(nll_loss_sum, 1 - label_smoothing, false)),\n                 JUST(ScalarMul(smooth_loss_sum, label_smoothing / n_classes, false)), 1, false));\n    if (reduction == \"sum\") { return cross_entropy_loss_sum; }\n\n    const auto& total_weight = JUST(ReduceSum(JUST(VectorAt(*nll_result, 1)), {}, false, NullOpt));\n    return Div(cross_entropy_loss_sum, total_weight);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_log_softmax_;\n  
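// nll ops with and without the optional class-weight input\n  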
std::shared_ptr<OpExpr> op_nll_;\n  std::shared_ptr<OpExpr> op_nll_weight_;\n};\n\nclass CrossEntropyProbFunctor : public LossFunctorBase {\n public:\n  CrossEntropyProbFunctor() {\n    op_log_softmax_ = CHECK_JUST(one::OpBuilder(\"log_softmax\").Input(\"in\").Output(\"prob\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target,\n                           const Optional<one::Tensor>& weight, const std::string& reduction,\n                           const double& label_smoothing) const {\n    const auto& input_shape = input->shape();\n    const auto& target_shape = target->shape();\n\n    std::vector<int> input_perm(input_shape->NumAxes(), 0);\n    input_perm[input_perm.size() - 1] = 1;\n    for (size_t i = 1; i < input_perm.size() - 1; ++i) { input_perm[i] = i + 1; }\n\n    const auto input_ = JUST(sequence_function(functional::Transpose)\n                                 .then(std::bind(functional::Reshape, std::placeholders::_1,\n                                                 Shape({-1, input_shape->At(1)})))\n                                 .then([this](const std::shared_ptr<one::Tensor>& x) {\n                                   return OpInterpUtil::Dispatch<Tensor>(*op_log_softmax_, {x});\n                                 })\n                                 .call(input, input_perm));\n    std::shared_ptr<Tensor> target_ =\n        JUST(sequence_function(functional::Transpose)\n                 .then(std::bind(functional::Reshape, std::placeholders::_1,\n                                 Shape({-1, target_shape->At(1)})))\n                 .call(target, input_perm));\n    if (label_smoothing > 0) {\n      int32_t num_classes = input_->shape()->At(1);\n      target_ =\n          JUST(ScalarAdd(JUST(ScalarMul(target_, static_cast<double>(1) - label_smoothing, false)),\n                         label_smoothing / static_cast<double>(num_classes), 1, false));\n    }\n\n    auto nll_result = JUST(Negative(JUST(Mul(input_, target_))));\n    if (weight) {\n      const auto& weight_expand = JUST(Unsqueeze(JUST(weight), 0));\n      nll_result = JUST(Mul(nll_result, weight_expand));\n    }\n    DimVector target_reshape_(input->ndim() - 1);\n    for (size_t i = 0; i < target_reshape_.size(); ++i) {\n      target_reshape_[i] = input_shape->At(input_perm[i]);\n    }\n    nll_result = JUST(ReduceSum(nll_result, {-1}, false, NullOpt));\n    nll_result = JUST(Reshape(nll_result, Shape(target_reshape_)));\n    return apply_reduction(nll_result, reduction);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_log_softmax_;\n};\n\nclass SparseCrossEntropyFunctor {\n public:\n  SparseCrossEntropyFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"sparse_cross_entropy\")\n                         .Input(\"prediction\")\n                         .Input(\"label\")\n                         .Output(\"out\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& prediction,\n                           const std::shared_ptr<one::Tensor>& label, const int64_t& depth) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"depth\");\n    attrs.SetAllAttrs(depth);\n\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {prediction, label}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SparseCrossEntropyMsFunctor {\n public:\n  SparseCrossEntropyMsFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"sparse_cross_entropy_ms\")\n                  
       .Input(\"prediction\")\n                         .Input(\"label\")\n                         .Output(\"out\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& prediction,\n                           const std::shared_ptr<one::Tensor>& label, const int64_t& depth) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"depth\");\n    attrs.SetAllAttrs(depth);\n\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {prediction, label}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SparseSoftmaxCrossEntropyFunctor {\n public:\n  SparseSoftmaxCrossEntropyFunctor() {\n    // SparseSoftmaxCrossEntropy\n    op_sparse_softmax_cross_entropy_ = CHECK_JUST(one::OpBuilder(\"sparse_softmax_cross_entropy\")\n                                                      .Input(\"prediction\")\n                                                      .Input(\"label\")\n                                                      .Output(\"prob\")\n                                                      .Output(\"out\")\n                                                      .Build());\n    // lazy model SparseSoftmaxCrossEntropyMs\n    op_sparse_softmax_cross_entropy_ms_ =\n        CHECK_JUST(one::OpBuilder(\"sparse_softmax_cross_entropy_ms\")\n                       .Input(\"prediction\")\n                       .Input(\"label\")\n                       .Output(\"prob\")\n                       .Output(\"out\")\n                       .Build());\n    // eager model SparseSoftmaxCrossEntropyMs\n    op_reduce_max_device_stage_ = CHECK_JUST(one::OpBuilder(\"reduce_max_device_stage\")\n                                                 .Input(\"in\")\n                                                 .Output(\"out\")\n                                                 .Output(\"mask\")\n                                                 .Output(\"count\")\n                                                 .Build());\n    op_reduce_max_global_stage_ = CHECK_JUST(one::OpBuilder(\"reduce_max_global_stage\")\n                                                 .Input(\"in\")\n                                                 .Input(\"device_count\")\n                                                 .Output(\"out\")\n                                                 .Output(\"mask\")\n                                                 .Build());\n    op_sparse_cross_entropy_ms_ = CHECK_JUST(one::OpBuilder(\"sparse_cross_entropy_ms\")\n                                                 .Input(\"prediction\")\n                                                 .Input(\"label\")\n                                                 .Output(\"out\")\n                                                 .Build());\n    op_broadcast_sub_ =\n        CHECK_JUST(one::OpBuilder(\"broadcast_sub\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n    op_broadcast_div_ =\n        CHECK_JUST(one::OpBuilder(\"broadcast_div\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n    op_reduce_sum_ = CHECK_JUST(\n        one::OpBuilder(\"reduce_sum\").Input(\"input_tensor\").Output(\"output_tensor\").Build());\n    op_exp_ = CHECK_JUST(one::OpBuilder(\"exp\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& logits,\n                           const std::shared_ptr<one::Tensor>& label) const {\n    if (JUST(RunWithMsVersion(logits, label))) {\n      if (LazyMode::is_enabled()) {\n        return 
LazySparseSoftmaxCrossEntropyMsOperator(logits, label);\n      } else {\n        return EagerSparseSoftmaxCrossEntropyMsOperator(logits, label);\n      }\n    } else {\n      return SparseSoftmaxCrossEntropyOperator(logits, label);\n    }\n  }\n\n  Maybe<bool> RunWithMsVersion(const std::shared_ptr<one::Tensor>& logits,\n                               const std::shared_ptr<one::Tensor>& label) const {\n    if (!(logits->is_global() && label->is_global())) { return false; }\n    // the NPU and MLU implementations do not support the ms version yet\n#if defined(WITH_NPU) || defined(WITH_MLU)\n    return false;\n#endif\n\n    if (JUST(logits->parallel_desc())->parallel_num() == 1) { return false; }\n\n    if (logits->shape()->NumAxes() != 2) { return false; }\n\n    const NdSbp& logits_nd_sbp = *(JUST(logits->nd_sbp()));\n    const int32_t split_axis = logits->shape()->NumAxes() - 1;\n    bool has_split_axis_parallel = false;\n    for (int64_t i = 0; i < logits_nd_sbp.sbp_parallel_size(); ++i) {\n      const auto& sbp = logits_nd_sbp.sbp_parallel(i);\n      if (sbp.has_split_parallel() && sbp.split_parallel().axis() == split_axis) {\n        has_split_axis_parallel = true;\n      } else {\n        if (sbp.has_partial_sum_parallel()) { return false; }\n      }\n    }\n    if (!has_split_axis_parallel) { return false; }\n\n    return true;\n  }\n\n  Maybe<Tensor> SparseSoftmaxCrossEntropyOperator(const std::shared_ptr<one::Tensor>& logits,\n                                                  const std::shared_ptr<one::Tensor>& label) const {\n    int64_t depth = logits->shape()->At(logits->shape()->NumAxes() - 1);\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"depth\");\n    attrs.SetAllAttrs(depth);\n    const auto& result = JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_sparse_softmax_cross_entropy_,\n                                                                  {logits, label}, attrs));\n    return result->at(1);\n  }\n\n  Maybe<Tensor> LazySparseSoftmaxCrossEntropyMsOperator(\n      const std::shared_ptr<one::Tensor>& logits, const std::shared_ptr<one::Tensor>& label) const {\n    int64_t depth = logits->shape()->At(logits->shape()->NumAxes() - 1);\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"depth\");\n    attrs.SetAllAttrs(depth);\n    const auto& result = JUST(OpInterpUtil::Dispatch<TensorTuple>(\n        *op_sparse_softmax_cross_entropy_ms_, {logits, label}, attrs));\n    return result->at(1);\n  }\n\n  Maybe<Tensor> EagerSparseSoftmaxCrossEntropyMsOperator(\n      const std::shared_ptr<one::Tensor>& logits, const std::shared_ptr<one::Tensor>& label) const {\n    // op_reduce_max_device_stage_\n    int64_t depth = logits->shape()->At(logits->shape()->NumAxes() - 1);\n    int32_t axis = logits->shape()->NumAxes() - 1;\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\");\n    attrs.SetAllAttrs(std::vector<int32_t>{axis});\n    const auto& max_device_stage =\n        JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_reduce_max_device_stage_, {logits}, attrs));\n    std::shared_ptr<Tensor> max_global_stage_input0 = max_device_stage->at(0);\n    std::shared_ptr<Tensor> max_global_stage_input1 = max_device_stage->at(2);\n\n    const NdSbp& logits_nd_sbp = *(JUST(logits->nd_sbp()));\n    std::vector<Symbol<SbpParallel>> new_sbp_parallels;\n    std::vector<Symbol<SbpParallel>> s0s1_sbp_parallels;\n    if (logits_nd_sbp.sbp_parallel_size() == 2) {\n      for (int i = 0; i < logits_nd_sbp.sbp_parallel_size(); ++i) {\n        const auto& sbp_parallel = logits_nd_sbp.sbp_parallel(i);\n        if 
(sbp_parallel.has_split_parallel()) {\n          const int64_t& split_axis = sbp_parallel.split_parallel().axis();\n          if (split_axis == axis) {\n            SbpParallel sbp;\n            sbp.mutable_broadcast_parallel();\n            new_sbp_parallels.emplace_back(sbp);\n          } else {\n            CHECK_EQ_OR_RETURN(split_axis, 0)\n                << Error::RuntimeError() << \"Split axis must equal 0.\";\n            new_sbp_parallels.emplace_back(sbp_parallel);\n          }\n        } else {\n          new_sbp_parallels.emplace_back(sbp_parallel);\n        }\n      }\n\n      s0s1_sbp_parallels.emplace_back(logits_nd_sbp.sbp_parallel(0));\n      s0s1_sbp_parallels.emplace_back(logits_nd_sbp.sbp_parallel(1));\n      max_global_stage_input0 = JUST(functional::ToGlobal(\n          (*max_device_stage)[0], JUST((*max_device_stage)[0]->parallel_desc()), new_sbp_parallels,\n          s0s1_sbp_parallels, /*check_meta=*/false, /*copy=*/false));\n      max_global_stage_input1 = JUST(functional::ToGlobal(\n          (*max_device_stage)[2], JUST((*max_device_stage)[0]->parallel_desc()), new_sbp_parallels,\n          s0s1_sbp_parallels, /*check_meta=*/false, /*copy=*/false));\n    }\n    // op_reduce_max_global_stage_\n    auto& reduce_max_global_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"keepdims\");\n    reduce_max_global_attrs.SetAllAttrs(std::vector<int32_t>{axis}, true);\n    const auto& max_global_stage = JUST(OpInterpUtil::Dispatch<TensorTuple>(\n        *op_reduce_max_global_stage_, {max_global_stage_input0, max_global_stage_input1},\n        reduce_max_global_attrs));\n    auto& broadcast_sub_input = max_global_stage->at(0);\n    if (logits_nd_sbp.sbp_parallel_size() == 2) {\n      broadcast_sub_input = JUST(functional::ToGlobal(\n          broadcast_sub_input, JUST((*max_device_stage)[0]->parallel_desc()), new_sbp_parallels,\n          new_sbp_parallels, /*check_meta=*/false, /*copy=*/false));\n    }\n    // op_broadcast_sub_\n    const auto& output_broadcast_sub = JUST(\n        OpInterpUtil::Dispatch<TensorTuple>(*op_broadcast_sub_, {logits, broadcast_sub_input}));\n    // op_exp_\n    const auto& output_exp =\n        JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_exp_, {(*output_broadcast_sub)[0]}));\n    // op_reduce_sum_\n    auto& reduce_sum_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"keepdims\");\n    reduce_sum_attrs.SetAllAttrs(std::vector<int32_t>{axis}, true);\n    const auto& output_reduce_sum = JUST(\n        OpInterpUtil::Dispatch<TensorTuple>(*op_reduce_sum_, {(*output_exp)[0]}, reduce_sum_attrs));\n    std::shared_ptr<Tensor> broadcast_div_input1 = output_reduce_sum->at(0);\n    if (logits_nd_sbp.sbp_parallel_size() == 2) {\n      broadcast_div_input1 = JUST(functional::ToGlobal(\n          (*output_reduce_sum)[0], JUST((*output_reduce_sum)[0]->parallel_desc()),\n          new_sbp_parallels, new_sbp_parallels, /*check_meta=*/false, /*copy=*/false));\n    }\n    // op_broadcast_div_\n    const auto& predictions = JUST(OpInterpUtil::Dispatch<TensorTuple>(\n        *op_broadcast_div_, {(*output_exp)[0], broadcast_div_input1}));\n    // op_sparse_cross_entropy_ms_\n    auto& sparse_cross_entropy_ms_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"depth\");\n    sparse_cross_entropy_ms_attrs.SetAllAttrs(depth);\n    const auto& output = JUST(OpInterpUtil::Dispatch<Tensor>(\n        *op_sparse_cross_entropy_ms_, {(*predictions)[0], label}, sparse_cross_entropy_ms_attrs));\n    return 
output;\n  }\n\n private:\n  // SparseSoftmaxCrossEntropy\n  std::shared_ptr<OpExpr> op_sparse_softmax_cross_entropy_;\n  // lazy mode SparseSoftmaxCrossEntropyMs\n  std::shared_ptr<OpExpr> op_sparse_softmax_cross_entropy_ms_;\n  // eager mode SparseSoftmaxCrossEntropyMs\n  std::shared_ptr<OpExpr> op_reduce_max_device_stage_;\n  std::shared_ptr<OpExpr> op_reduce_max_global_stage_;\n  std::shared_ptr<OpExpr> op_broadcast_sub_;\n  std::shared_ptr<OpExpr> op_exp_;\n  std::shared_ptr<OpExpr> op_reduce_sum_;\n  std::shared_ptr<OpExpr> op_broadcast_div_;\n  std::shared_ptr<OpExpr> op_sparse_cross_entropy_ms_;\n};\n\nclass SoftmaxCrossEntropyFunctor {\n public:\n  SoftmaxCrossEntropyFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"softmax_cross_entropy\")\n                         .Input(\"prediction\")\n                         .Input(\"label\")\n                         .Output(\"out\")\n                         .Output(\"prob\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& logits,\n                           const std::shared_ptr<one::Tensor>& label) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {logits, label});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SoftmaxCrossEntropyGradFunctor {\n public:\n  SoftmaxCrossEntropyGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"softmax_cross_entropy_grad\")\n                         .Input(\"dy\")\n                         .Input(\"label\")\n                         .Input(\"prob\")\n                         .Output(\"prediction_diff\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& label,\n                           const std::shared_ptr<one::Tensor>& prob) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, label, prob});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass CombinedMarginLossFunctor {\n public:\n  CombinedMarginLossFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"combined_margin_loss\")\n                         .Input(\"x\")\n                         .Input(\"label\")\n                         .Output(\"y\")\n                         .Output(\"theta\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& label, const float& m1,\n                           const float& m2, const float& m3) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"m1\", \"m2\", \"m3\", \"depth\");\n    attrs.SetAllAttrs(m1, m2, m3, x->shape()->At(1));\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, label}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass CtcLossFunctor {\n public:\n  CtcLossFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"ctc_loss\")\n                         .Input(\"log_probs\")\n                         .Input(\"targets\")\n                         .Input(\"input_lengths\")\n                         .Input(\"target_lengths\")\n                         .Output(\"loss\")\n                         .Output(\"alpha\")\n                         .Build());\n    op_xdivy_ = CHECK_JUST(one::OpBuilder(\"xdivy\").Input(\"x\").Input(\"y\").Output(\"z\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& log_probs,\n                           const std::shared_ptr<one::Tensor>& targets,\n                           const 
std::shared_ptr<one::Tensor>& input_lengths,\n                           const std::shared_ptr<one::Tensor>& target_lengths,\n                           const int64_t& max_target_length, const int64_t& blank,\n                           const bool& zero_infinity, const std::string& reduction) const {\n    // FIXME: global ctc loss sometimes segfaults\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"max_target_length\", \"blank\", \"zero_infinity\");\n    attrs.SetAllAttrs(max_target_length, blank, zero_infinity);\n    std::shared_ptr<one::Tensor> out;\n    DeviceType log_probs_device_type;  // NOLINT\n    if (log_probs->is_local()) {\n      log_probs_device_type = JUST(log_probs->device())->enum_type();\n    } else {\n      log_probs_device_type = JUST(log_probs->parallel_desc())->device_type();\n    }\n    const std::string& log_probs_device_str = *JUST(DeviceTag4DeviceType(log_probs_device_type));\n    std::shared_ptr<one::Tensor> target_lengths_on_log_probs_device =\n        JUST(functional::To(target_lengths, log_probs_device_str));\n    if (targets->dtype()->data_type() == DataType::kInt32) {\n      out = JUST(OpInterpUtil::Dispatch<Tensor>(\n          *op_,\n          {\n              log_probs,\n              JUST(functional::To(targets, log_probs_device_str)),\n              JUST(functional::To(input_lengths, log_probs_device_str)),\n              target_lengths_on_log_probs_device,\n          },\n          attrs));\n    } else {\n      out = JUST(OpInterpUtil::Dispatch<Tensor>(\n          *op_,\n          {\n              log_probs,\n              JUST(functional::To(targets, Optional<std::string>(log_probs_device_str),\n                                  DType::Int64(), false)),\n              JUST(functional::To(input_lengths, log_probs_device_str)),\n              target_lengths_on_log_probs_device,\n          },\n          attrs));\n    }\n    if (zero_infinity) {\n      if (out->is_local()) {\n        const auto create_constant = [&](const Scalar& scalar) -> Maybe<Tensor> {\n          return functional::Constant(*out->shape(), scalar, out->dtype(), JUST(out->device()));\n        };\n\n        out = JUST(sequence_function(functional::Constant)\n                       .then(std::bind(functional::BroadcastEqual, out, std::placeholders::_1))\n                       .then(std::bind(functional::Where, std::placeholders::_1,\n                                       JUST(create_constant(Scalar(0))), out))\n                       .call(*out->shape(), Scalar(std::numeric_limits<double>::infinity()),\n                             out->dtype(), JUST(out->device())));\n      } else {\n        const auto& placement = JUST(out->parallel_desc());\n        const auto& nd_sbp = *JUST(GetSbpList(JUST(out->nd_sbp())));\n        const auto create_constant = [&](const Scalar& scalar) -> Maybe<Tensor> {\n          return functional::GlobalConstant(*out->shape(), scalar, out->dtype(), placement, nd_sbp);\n        };\n\n        out = JUST(sequence_function(functional::GlobalConstant)\n                       .then(std::bind(functional::BroadcastEqual, out, std::placeholders::_1))\n                       .then(std::bind(functional::Where, std::placeholders::_1,\n                                       JUST(create_constant(Scalar(0))), out))\n                       .call(*out->shape(), Scalar(std::numeric_limits<double>::infinity()),\n                             out->dtype(), placement, nd_sbp));\n      }\n    }\n    CHECK_OR_RETURN([&]() -> bool {\n      if ((reduction != \"none\") && (reduction != 
\"sum\") && (reduction != \"mean\")) return false;\n      return true;\n    }()) << Error::RuntimeError()\n         << \"Reduction should be none, sum or mean.\";\n    if (reduction == \"sum\") { return functional::ReduceSum(out, {}, false, NullOpt); }\n    if (reduction == \"mean\") {\n      return sequence_function(functional::Clamp)\n          .then(std::bind(functional::Cast, std::placeholders::_1, log_probs->dtype(),\n                          /*pin_memory=*/false))\n          .then([&](const std::shared_ptr<one::Tensor>& x) {\n            return OpInterpUtil::Dispatch<Tensor>(*op_xdivy_, {out, x});\n          })\n          .then(std::bind(functional::ReduceMean, std::placeholders::_1, std::vector<int32_t>({}),\n                          false))\n          .call(target_lengths_on_log_probs_device, Scalar(1), NullOpt);\n    }\n    return out;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> op_xdivy_;\n};\n\nclass TripletMarginLossFunctor {\n public:\n  TripletMarginLossFunctor() {}\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& anchor,\n                           const std::shared_ptr<one::Tensor>& positive,\n                           const std::shared_ptr<one::Tensor>& negative, const float& margin,\n                           const float& p, const float& eps, const bool& swap,\n                           const std::string& reduction) const {\n    int32_t dim_norm = anchor->ndim() - 1;\n    std::vector<int32_t> dim(1, dim_norm);\n    CHECK_OR_RETURN([&]() -> bool {\n      if ((reduction != \"none\") && (reduction != \"sum\") && (reduction != \"mean\")) return false;\n      return true;\n    }()) << Error::RuntimeError()\n         << \"Reduction should be none, sum or mean.\";\n    auto da_p = JUST(VectorNorm(\n        JUST(ScalarAdd(eps, JUST(Sub(anchor, positive, /*alpha=*/1.0, /*inplace=*/false)),\n                       /*alpha=*/1)),\n        p, dim,\n        /*keepdim=*/false, anchor->dtype()));\n    auto da_n = JUST(VectorNorm(\n        JUST(ScalarAdd(eps, JUST(Sub(anchor, negative, /*alpha=*/1.0, /*inplace=*/false)),\n                       /*alpha=*/1)),\n        p, dim,\n        /*keepdim=*/false, anchor->dtype()));\n    if (swap) {\n      auto distance_swap = JUST(VectorNorm(\n          JUST(ScalarAdd(eps, JUST(Sub(positive, negative, /*alpha=*/1.0, /*inplace=*/false)),\n                         /*alpha=*/1)),\n          p, dim,\n          /*keepdim=*/false, positive->dtype()));\n      da_n = JUST(Minimum(distance_swap, da_n));\n    }\n    auto triplet_loss =\n        JUST(Clamp(JUST(ScalarAdd(JUST(Sub(da_p, da_n, /*alpha=*/1.0, /*inplace=*/false)), margin,\n                                  /*alpha=*/1, /*inplace=*/false)),\n                   /*min=*/0.0, NullOpt));\n    int32_t ndim = triplet_loss->ndim() - 1;\n    std::vector<int32_t> axis(1, ndim);\n\n    if (reduction == \"mean\") {\n      triplet_loss = JUST(ReduceMean(triplet_loss, axis, /*keepdim=*/false));\n    } else if (reduction == \"sum\") {\n      triplet_loss = JUST(ReduceSum(triplet_loss, axis, /*keepdim=*/false, NullOpt));\n    }\n    return triplet_loss;\n  }\n};\n\nclass AffineGridFunctor {\n public:\n  AffineGridFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"affine_grid\").Input(\"theta\").Output(\"grid\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& theta, const Shape& size,\n                           const bool& align_corners) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"size\", \"align_corners\");\n  
  attrs.SetAllAttrs(size, align_corners);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {theta}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GridSampleFunctor {\n public:\n  GridSampleFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"grid_sample\").Input(\"input\").Input(\"grid\").Output(\"output\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& grid,\n                           const std::string& interpolation_mode, const std::string& padding_mode,\n                           const bool& align_corners) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"interpolation_mode\", \"padding_mode\", \"align_corners\");\n    attrs.SetAllAttrs(interpolation_mode, padding_mode, align_corners);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {input, grid}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass NormalizationFunctor {\n public:\n  NormalizationFunctor() {\n    norm_eval_op_ = CHECK_JUST(one::OpBuilder(\"normalization\")\n                                   .Input(\"x\")\n                                   .Input(\"moving_mean\")\n                                   .Input(\"moving_variance\")\n                                   .Input(\"gamma\")\n                                   .Input(\"beta\")\n                                   .Output(\"y\")\n                                   .Attr(\"training\", false)\n                                   .Build());\n    norm_training_stats_op_ = CHECK_JUST(one::OpBuilder(\"normalization\")\n                                             .Input(\"x\")\n                                             .Input(\"moving_mean\")\n                                             .Input(\"moving_variance\")\n                                             .Input(\"gamma\")\n                                             .Input(\"beta\")\n                                             .Output(\"y\")\n                                             .Output(\"mean\")\n                                             .Output(\"inv_variance\")\n                                             .Attr(\"training\", true)\n                                             .Build());\n    norm_training_no_stats_op_ = CHECK_JUST(one::OpBuilder(\"normalization\")\n                                                .Input(\"x\")\n                                                .Input(\"gamma\")\n                                                .Input(\"beta\")\n                                                .Output(\"y\")\n                                                .Output(\"mean\")\n                                                .Output(\"inv_variance\")\n                                                .Attr(\"training\", true)\n                                                .Build());\n    cast_op_ = CHECK_JUST(one::OpBuilder(\"cast\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const Optional<one::Tensor>& moving_mean,\n                           const Optional<one::Tensor>& moving_variance,\n                           const Optional<one::Tensor>& gamma, const Optional<one::Tensor>& beta,\n                           const int32_t& axis, const float& epsilon, const float& momentum,\n                           const bool& training) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"epsilon\", 
\"momentum\");\n    // convert torch momentum to tensorflow momentum\n    attrs.SetAllAttrs(axis, epsilon, static_cast<float>(1.0 - momentum));\n\n    CHECK_OR_RETURN((moving_mean && moving_variance) || (!moving_mean && !moving_variance))\n        << Error::RuntimeError()\n        << \"Both running_mean and running_variance should be None or Tensor.\";\n\n    const DataType dtype = x->dtype()->data_type();\n\n    std::shared_ptr<one::Tensor> gamma_val;\n    std::shared_ptr<one::Tensor> beta_val;\n\n    CHECK_GE_OR_RETURN(x->shape()->NumAxes(), 2)\n        << Error::RuntimeError() << \"NumAxes of x should be greater or equal than 2. \";\n    if (gamma.has_value() && beta.has_value()) {\n      gamma_val = JUST(gamma);\n      beta_val = JUST(beta);\n    } else {\n      const Shape gamma_beta_shape = Shape({x->shape()->At(1)});\n      gamma_val = JUST(functional::Constant(gamma_beta_shape, 1.0, x->dtype(), JUST(x->device())));\n      beta_val = JUST(functional::Constant(gamma_beta_shape, 0.0, x->dtype(), JUST(x->device())));\n    }\n\n    const DataType gamma_dtype = gamma_val->dtype()->data_type();\n    const DataType beta_dtype = beta_val->dtype()->data_type();\n    CHECK_EQ_OR_RETURN(gamma_dtype, beta_dtype)\n        << Error::RuntimeError() << \"gamma and beta have different data types.\";\n    if (gamma_dtype != dtype) {\n      gamma_val = JUST(functional::Cast(gamma_val, DType{dtype}, /*pin_memory=*/false));\n      beta_val = JUST(functional::Cast(beta_val, DType{dtype}, /*pin_memory=*/false));\n    }\n\n    std::shared_ptr<one::Tensor> moving_mean_val;\n    std::shared_ptr<one::Tensor> moving_variance_val;\n    bool need_cast_moving_stats = false;\n    if (moving_mean) {\n      const DataType moving_mean_dtype = JUST(moving_mean)->dtype()->data_type();\n      CHECK_EQ_OR_RETURN(JUST(moving_variance)->dtype()->data_type(), moving_mean_dtype)\n          << Error::RuntimeError() << \"moving_mean and moving_variance have different data types.\";\n      need_cast_moving_stats = (moving_mean_dtype != dtype);\n      if (need_cast_moving_stats) {\n        moving_mean_val =\n            JUST(functional::Cast(JUST(moving_mean), DType{dtype}, /*pin_memory=*/false));\n        moving_variance_val =\n            JUST(functional::Cast(JUST(moving_variance), DType{dtype}, /*pin_memory=*/false));\n      } else {\n        moving_mean_val = JUST(moving_mean);\n        moving_variance_val = JUST(moving_variance);\n      }\n    }\n\n    std::shared_ptr<one::Tensor> res;\n\n    if (!training) {\n      CHECK_OR_RETURN(moving_mean && moving_variance)\n          << Error::RuntimeError() << \"Must have moving_mean and moving_variance in eval mode.\";\n      res = JUST(OpInterpUtil::Dispatch<one::Tensor>(\n          *norm_eval_op_, {x, moving_mean_val, moving_variance_val, gamma_val, beta_val}, attrs));\n    } else if (moving_mean) {\n      res = JUST(OpInterpUtil::Dispatch<one::Tensor>(\n          *norm_training_stats_op_, {x, moving_mean_val, moving_variance_val, gamma_val, beta_val},\n          attrs));\n    } else {\n      res = JUST(OpInterpUtil::Dispatch<one::Tensor>(*norm_training_no_stats_op_,\n                                                     {x, gamma_val, beta_val}, attrs));\n    }\n\n    if (need_cast_moving_stats) {\n      // For inplace update moving_mean and moving_variance\n      JUST(CheckInplaceValid(JUST(moving_mean)));\n      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n      outputs->at(0) = JUST(moving_mean);\n      auto& attrs = 
THREAD_CACHED_MUTABLE_ATTR_MAP(\"dtype\", \"pin_memory\");\n      attrs.SetAllAttrs(JUST(moving_mean)->dtype()->data_type(), false);\n      JUST(OpInterpUtil::Dispatch(*cast_op_, {moving_mean_val}, outputs.get(), attrs));\n      JUST(CheckInplaceValid(JUST(moving_variance)));\n      outputs->at(0) = JUST(moving_variance);\n      JUST(OpInterpUtil::Dispatch(*cast_op_, {moving_variance_val}, outputs.get(), attrs));\n    }\n\n    return res;\n  }\n\n private:\n  std::shared_ptr<OpExpr> norm_eval_op_;\n  std::shared_ptr<OpExpr> norm_training_stats_op_;\n  std::shared_ptr<OpExpr> norm_training_no_stats_op_;\n  std::shared_ptr<OpExpr> cast_op_;\n};\n\nclass NormalizationAddReluFunctor {\n public:\n  NormalizationAddReluFunctor() {\n    fused_norm_training_stats_op_ =\n        CHECK_JUST(BuildFusedNormalizationOp(/*stats=*/true, /*addend=*/false, /*training=*/true));\n    fused_addend_norm_training_stats_op_ =\n        CHECK_JUST(BuildFusedNormalizationOp(/*stats=*/true, /*addend=*/true, /*training=*/true));\n    fused_norm_training_no_stats_op_ =\n        CHECK_JUST(BuildFusedNormalizationOp(/*stats=*/false, /*addend=*/false, /*training=*/true));\n    fused_addend_norm_training_no_stats_op_ =\n        CHECK_JUST(BuildFusedNormalizationOp(/*stats=*/false, /*addend=*/true, /*training=*/true));\n    fused_norm_eval_stats_op_ =\n        CHECK_JUST(BuildFusedNormalizationOp(/*stats=*/true, /*addend=*/false, /*training=*/false));\n    fused_addend_norm_eval_stats_op_ =\n        CHECK_JUST(BuildFusedNormalizationOp(/*stats=*/true, /*addend=*/true, /*training=*/false));\n  }\n\n  Maybe<one::UserOpExpr> BuildFusedNormalizationOp(bool stats, bool addend, bool training) {\n    auto op_builder = one::OpBuilder(\"normalization_add_relu\")\n                          .Input(\"x\")\n                          .Output(\"y\")\n                          .Output(\"reserve_space\")\n                          .Attr(\"training\", training);\n    if (addend) { op_builder.Input(\"addend\"); }\n    if (stats) { op_builder.Input(\"moving_mean\").Input(\"moving_variance\"); }\n    op_builder.Input(\"gamma\").Input(\"beta\");\n    if (training) { op_builder.Output(\"mean\").Output(\"inv_variance\"); }\n    return op_builder.Build();\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const Optional<one::Tensor>& addend,\n                           const Optional<one::Tensor>& moving_mean,\n                           const Optional<one::Tensor>& moving_variance,\n                           const std::shared_ptr<one::Tensor>& gamma,\n                           const std::shared_ptr<one::Tensor>& beta, const int32_t& axis,\n                           const float& epsilon, const float& momentum,\n                           const bool& is_training) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"epsilon\", \"momentum\");\n    // convert torch momentum to tensorflow momentum\n    attrs.SetAllAttrs(axis, epsilon, static_cast<float>(1.0 - momentum));\n\n    CHECK_OR_RETURN((moving_mean && moving_variance) || (!moving_mean && !moving_variance))\n        << Error::RuntimeError()\n        << \"Both moving_mean and moving_variance should be None or Tensor.\";\n    if (!is_training) {\n      CHECK_OR_RETURN(moving_mean && moving_variance)\n          << Error::RuntimeError() << \"Must have moving_mean and moving_variance in eval mode.\";\n      if (addend) {\n        return OpInterpUtil::Dispatch<one::Tensor>(\n            *fused_addend_norm_eval_stats_op_,\n          
  {x, JUST(addend), JUST(moving_mean), JUST(moving_variance), gamma, beta}, attrs);\n      } else {\n        return OpInterpUtil::Dispatch<one::Tensor>(\n            *fused_norm_eval_stats_op_, {x, JUST(moving_mean), JUST(moving_variance), gamma, beta},\n            attrs);\n      }\n    } else if (moving_mean) {\n      if (addend) {\n        return OpInterpUtil::Dispatch<one::Tensor>(\n            *fused_addend_norm_training_stats_op_,\n            {x, JUST(addend), JUST(moving_mean), JUST(moving_variance), gamma, beta}, attrs);\n      } else {\n        return OpInterpUtil::Dispatch<one::Tensor>(\n            *fused_norm_training_stats_op_,\n            {x, JUST(moving_mean), JUST(moving_variance), gamma, beta}, attrs);\n      }\n    } else {\n      if (addend) {\n        return OpInterpUtil::Dispatch<one::Tensor>(*fused_addend_norm_training_no_stats_op_,\n                                                   {x, JUST(addend), gamma, beta}, attrs);\n      } else {\n        return OpInterpUtil::Dispatch<one::Tensor>(*fused_norm_training_no_stats_op_,\n                                                   {x, gamma, beta}, attrs);\n      }\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> fused_norm_training_stats_op_;\n  std::shared_ptr<OpExpr> fused_addend_norm_training_stats_op_;\n  std::shared_ptr<OpExpr> fused_norm_training_no_stats_op_;\n  std::shared_ptr<OpExpr> fused_addend_norm_training_no_stats_op_;\n  std::shared_ptr<OpExpr> fused_norm_eval_stats_op_;\n  std::shared_ptr<OpExpr> fused_addend_norm_eval_stats_op_;\n};\n\nclass ConstantPadFunctor {\n public:\n  ConstantPadFunctor() {\n    constant_pad_ = CHECK_JUST(one::OpBuilder(\"pad\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::vector<int64_t>& pad, const Scalar& value) const {\n    const int64_t ndim = input->shape()->NumAxes();\n    const int64_t pad_size = pad.size();\n    CHECK_LE_OR_RETURN(pad_size, 2 * ndim)\n        << Error::RuntimeError() << \"Pad size should be less than or equal to input axes * 2.\";\n    CHECK_EQ_OR_RETURN(pad_size % 2, 0)\n        << Error::RuntimeError() << \"Length of pad must be even but instead it equals \" << pad_size;\n\n    std::vector<int64_t> pad_before(ndim, 0);\n    std::vector<int64_t> pad_after(ndim, 0);\n    const int64_t pad_pair = pad_size / 2;\n    for (int64_t i = 0; i < pad_pair; ++i) {\n      pad_before[ndim - i - 1] = pad[2 * i];\n      pad_after[ndim - i - 1] = pad[2 * i + 1];\n    }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"padding\", \"floating_constant_value\",\n                                                 \"integral_constant_value\", \"padding_before\",\n                                                 \"padding_after\");\n    if (IsFloatingDataType(input->dtype()->data_type())\n        || IsComplexDataType(input->dtype()->data_type())) {\n      attrs.SetAllAttrs(pad, value.As<double>(), static_cast<int64_t>(0), pad_before, pad_after);\n    } else if (IsIntegralDataType(input->dtype()->data_type())) {\n      attrs.SetAllAttrs(pad, static_cast<double>(0), value.As<int64_t>(), pad_before, pad_after);\n    } else if (input->dtype() == DType::Bool()) {\n      int64_t bool_value = value.As<int64_t>();\n      CHECK_OR_RETURN(bool_value == 1 || bool_value == 0)\n          << \"value must be 1/0 or True/False for bool Tensor\";\n      attrs.SetAllAttrs(pad, static_cast<double>(0), value.As<int64_t>(), pad_before, pad_after);\n    } else {\n      
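// remaining dtypes (neither floating/complex, integral, nor bool) are rejected here\n      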
UNIMPLEMENTED_THEN_RETURN() << \"Data type should be floating, bool or integral type.\";\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*constant_pad_, {input}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> constant_pad_;\n};\n\nclass ReflectionPadFunctor {\n public:\n  ReflectionPadFunctor() {\n    reflect_pad1d_ = CHECK_JUST(one::OpBuilder(\"reflection_pad1d\").Input(\"x\").Output(\"y\").Build());\n    reflect_pad2d_ = CHECK_JUST(one::OpBuilder(\"reflection_pad2d\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::vector<int64_t>& pad) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"padding\");\n    attrs.SetAllAttrs(pad);\n    const int64_t pad_size = pad.size();\n    const size_t ndim = input->ndim();\n    CHECK_LE_OR_RETURN(pad_size, 2 * ndim)\n        << Error::RuntimeError() << \"Pad size should be less than or equal to input axes * 2.\";\n\n    if (pad_size == 2) {\n      // 2D/3D reflect padding\n      CHECK_OR_RETURN((ndim == 2 && input->shape()->At(1) != 0)\n                      || (ndim == 3 && input->shape()->At(1) != 0 && input->shape()->At(2) != 0))\n          << \"2D or 3D (batch mode) tensor expected for input, but got: \" << ndim;\n      const int64_t pad_left = pad[0];\n      const int64_t pad_right = pad[1];\n      const int64_t dim_w = (ndim == 3) ? 2 : 1;\n      const int64_t input_width = input->shape()->At(dim_w);\n      const int64_t output_w = input_width + pad_left + pad_right;\n      CHECK_OR_RETURN(pad_left < input_width && pad_right < input_width)\n          << \"Padding size should be less than the corresponding input dimension, but got: \"\n             \"padding (\"\n          << pad_left << \", \" << pad_right << \") at dimension \" << dim_w << \" of input \"\n          << input->shape()->ToString();\n      CHECK_OR_RETURN(output_w >= 1)\n          << \"input (W: \" << input_width << \") is too small. 
Calculated output W: \" << output_w;\n\n      if (ndim == 2) {\n        // for 2D input\n        auto unsqueezed_input = JUST(functional::Unsqueeze(input, 0));\n        auto unsqueezed_output =\n            JUST(OpInterpUtil::Dispatch<Tensor>(*reflect_pad1d_, {unsqueezed_input}, attrs));\n        return JUST(functional::Squeeze(unsqueezed_output, std::vector<int32_t>{0}));\n      }\n      return OpInterpUtil::Dispatch<Tensor>(*reflect_pad1d_, {input}, attrs);\n    } else if (pad_size == 4) {\n      // 3D/4D reflect padding\n      bool valid_dims = input->shape()->At(1) != 0 && input->shape()->At(2) != 0;\n      CHECK_OR_RETURN((ndim == 3 && valid_dims)\n                      || (ndim == 4 && valid_dims && input->shape()->At(3) != 0))\n          << \"3D or 4D (batch mode) tensor expected for input, but got: \" << ndim;\n\n      int dim_h = 1;\n      int dim_w = 2;\n      if (ndim == 4) {\n        dim_w++;\n        dim_h++;\n      }\n\n      const int64_t pad_left = pad[0];\n      const int64_t pad_right = pad[1];\n      const int64_t pad_top = pad[2];\n      const int64_t pad_bottom = pad[3];\n\n      const int64_t input_h = input->shape()->At(dim_h);\n      const int64_t input_w = input->shape()->At(dim_w);\n      const int64_t output_h = input_h + pad_top + pad_bottom;\n      const int64_t output_w = input_w + pad_left + pad_right;\n      CHECK_OR_RETURN(pad_left < input_w && pad_right < input_w)\n          << Error::RuntimeError()\n          << \"Padding size should be less than the corresponding input \"\n             \"dimension, but got: padding (\"\n          << pad_left << \", \" << pad_right << \") at dimension \" << dim_w << \" of input \"\n          << input->shape()->ToString();\n\n      CHECK_OR_RETURN(pad_top < input_h && pad_bottom < input_h)\n          << Error::RuntimeError()\n          << \"Padding size should be less than the corresponding input \"\n             \"dimension, but got: padding (\"\n          << pad_top << \", \" << pad_bottom << \") at dimension \" << dim_h << \" of input \"\n          << input->shape()->ToString();\n\n      CHECK_OR_RETURN(output_w >= 1 || output_h >= 1)\n          << Error::RuntimeError() << \"input (H: \" << input_h << \", W: \" << input_w\n          << \") is too small. 
Calculated output H: \" << output_h << \" W: \" << output_w;\n\n      if (ndim == 3) {\n        // for 3D input\n        auto unsqueezed_input = JUST(functional::Unsqueeze(input, 0));\n        auto unsqueezed_output =\n            JUST(OpInterpUtil::Dispatch<Tensor>(*reflect_pad2d_, {unsqueezed_input}, attrs));\n        return JUST(functional::Squeeze(unsqueezed_output, std::vector<int32_t>{0}));\n      }\n      return OpInterpUtil::Dispatch<Tensor>(*reflect_pad2d_, {input}, attrs);\n    } else if (pad_size == 6) {\n      UNIMPLEMENTED_THEN_RETURN() << \"5D reflection padding is not supported for now\";\n    } else {\n      UNIMPLEMENTED_THEN_RETURN()\n          << \"Only 2D, 3D, 4D and 5D non-constant padding is supported for now\";\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> reflect_pad1d_;\n  std::shared_ptr<OpExpr> reflect_pad2d_;\n};\n\nclass ReplicationPadFunctor {\n public:\n  ReplicationPadFunctor() {\n    replicate_pad1d_ =\n        CHECK_JUST(one::OpBuilder(\"replication_pad1d\").Input(\"x\").Output(\"y\").Build());\n    replicate_pad2d_ =\n        CHECK_JUST(one::OpBuilder(\"replication_pad2d\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::vector<int64_t>& pad) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"padding\");\n    attrs.SetAllAttrs(pad);\n    const int64_t pad_size = pad.size();\n    const size_t ndim = input->ndim();\n    CHECK_LE_OR_RETURN(pad_size, 2 * ndim)\n        << Error::RuntimeError() << \"Pad size should be less than or equal to input axes * 2.\";\n    if (pad_size == 2) {\n      // 2D/3D replicate padding\n      CHECK_OR_RETURN((ndim == 2 && input->shape()->At(0) != 0 && input->shape()->At(1) != 0)\n                      || (ndim == 3 && input->shape()->At(1) != 0 && input->shape()->At(2) != 0))\n          << \"Expected 2D or 3D (batch mode) tensor with possibly 0 batch size and other \"\n             \"non-zero dimensions for input, but got: \"\n          << ndim;\n      const int64_t pad_left = pad[0];\n      const int64_t pad_right = pad[1];\n      const int64_t dim_w = (ndim == 3) ? 2 : 1;\n      const int64_t input_width = input->shape()->At(dim_w);\n      const int64_t output_w = input_width + pad_left + pad_right;\n      CHECK_OR_RETURN(output_w >= 1)\n          << \"input (W: \" << input_width << \") is too small. 
Calculated output W: \" << output_w;\n\n      if (ndim == 2) {\n        // for 2D input\n        auto unsqueezed_input = JUST(functional::Unsqueeze(input, 0));\n        auto unsqueezed_output =\n            JUST(OpInterpUtil::Dispatch<Tensor>(*replicate_pad1d_, {unsqueezed_input}, attrs));\n        return JUST(functional::Squeeze(unsqueezed_output, std::vector<int32_t>{0}));\n      }\n      return OpInterpUtil::Dispatch<Tensor>(*replicate_pad1d_, {input}, attrs);\n    } else if (pad_size == 4) {\n      // 3D/4D replicate padding\n      bool valid_dims = input->shape()->At(1) != 0 && input->shape()->At(2) != 0;\n      CHECK_OR_RETURN((ndim == 3 && valid_dims)\n                      || (ndim == 4 && valid_dims && input->shape()->At(3) != 0))\n          << \"3D or 4D (batch mode) tensor expected for input, but got: \" << ndim;\n\n      int dim_h = 1;\n      int dim_w = 2;\n      if (ndim == 4) {\n        dim_w++;\n        dim_h++;\n      }\n\n      const int64_t pad_left = pad[0];\n      const int64_t pad_right = pad[1];\n      const int64_t pad_top = pad[2];\n      const int64_t pad_bottom = pad[3];\n\n      const int64_t input_h = input->shape()->At(dim_h);\n      const int64_t input_w = input->shape()->At(dim_w);\n      const int64_t output_h = input_h + pad_top + pad_bottom;\n      const int64_t output_w = input_w + pad_left + pad_right;\n      CHECK_OR_RETURN(output_w >= 1 || output_h >= 1)\n          << Error::RuntimeError() << \"input (H: \" << input_h << \", W: \" << input_w\n          << \") is too small. Calculated output H: \" << output_h << \" W: \" << output_w;\n\n      if (ndim == 3) {\n        // for 3D input\n        auto unsqueezed_input = JUST(functional::Unsqueeze(input, 0));\n        auto unsqueezed_output =\n            JUST(OpInterpUtil::Dispatch<Tensor>(*replicate_pad2d_, {unsqueezed_input}, attrs));\n        return JUST(functional::Squeeze(unsqueezed_output, std::vector<int32_t>{0}));\n      }\n      return OpInterpUtil::Dispatch<Tensor>(*replicate_pad2d_, {input}, attrs);\n    } else if (pad_size == 6) {\n      UNIMPLEMENTED_THEN_RETURN() << \"5D replication padding is not supported for now\";\n    } else {\n      UNIMPLEMENTED_THEN_RETURN()\n          << \"Only 2D, 3D, 4D and 5D non-constant padding is supported for now\";\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> replicate_pad1d_;\n  std::shared_ptr<OpExpr> replicate_pad2d_;\n};\n\nclass PadFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::vector<int64_t>& pad, const std::string& mode,\n                           const Scalar& value) const {\n    if (mode == \"constant\") {\n      return functional::ConstantPad(input, pad, value);\n    } else if (mode == \"reflect\") {\n      return functional::ReflectionPad(input, pad);\n    } else if (mode == \"replicate\") {\n      return functional::ReplicationPad(input, pad);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"Pad mode is \" << mode\n                                  << \", but only constant, reflect and replicate are valid.\";\n    }\n  }\n};\n\nclass DropoutFunctor {\n public:\n  DropoutFunctor() {\n    dropout_op_ =\n        CHECK_JUST(one::OpBuilder(\"dropout\").Input(\"in\").Output(\"out\").Output(\"mask\").Build());\n    dropout_addend_op_ = CHECK_JUST(one::OpBuilder(\"dropout\")\n                                        .Input(\"in\")\n                                        .Input(\"_add_to_output\")\n                                        
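// \"_add_to_output\" presumably lets the dropout kernel fuse the residual add into its output\n                                        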
.Output(\"out\")\n                                        .Output(\"mask\")\n                                        .Build());\n    add_op_ = CHECK_JUST(one::OpBuilder(\"add_n\").Input(\"in\", 2).Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const float& p,\n                           const bool& training, const bool& inplace,\n                           const Optional<one::Generator>& generator,\n                           const Optional<one::Tensor>& addend) const {\n    auto outputs = std::make_shared<TensorTuple>(1);\n    if (inplace) {\n      JUST(CheckInplaceValid(x));\n      (*outputs)[0] = x;\n    }\n\n    auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n    gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x));\n    auto& dropout_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"rate\", \"seed\");\n    dropout_attrs.SetAllAttrs(p, static_cast<int64_t>(gen->current_seed()));\n\n    const auto& dropout_state = std::make_shared<FusedDropoutKernelState>(gen);\n    OpExprInterpContext ctx(dropout_attrs, dropout_state);\n    if (addend) {\n      if ((!training) || p == 0.0) {\n        JUST(OpInterpUtil::Dispatch(*add_op_, {x, JUST(addend)}, outputs.get()));\n      } else {\n        outputs->resize(2);\n        JUST(OpInterpUtil::Dispatch(*dropout_addend_op_, {x, JUST(addend)}, outputs.get(), ctx));\n      }\n    } else {\n      if (!training || p == 0.0) {\n        return x;\n      } else {\n        outputs->resize(2);\n        JUST(OpInterpUtil::Dispatch(*dropout_op_, {x}, outputs.get(), ctx));\n      }\n    }\n    return (*outputs)[0];\n  }\n\n private:\n  std::shared_ptr<OpExpr> dropout_op_;\n  std::shared_ptr<OpExpr> dropout_addend_op_;\n  std::shared_ptr<OpExpr> add_op_;\n};\n\nnamespace {\nMaybe<Tensor> MakeFeatureNoise(const std::shared_ptr<one::Tensor>& x) {\n  const int64_t ndim = x->ndim();\n  CHECK_GE_OR_RETURN(ndim, 2) << Error::RuntimeError()\n                              << \"Feature dropout requires at least 2 dimensions in the input\";\n  std::vector<int64_t> sizes;\n  sizes.reserve(ndim);\n  sizes.push_back(x->shape()->At(0));\n  sizes.push_back(x->shape()->At(1));\n  for (int i = 2; i < ndim; i++) { sizes.push_back(1); }\n  return JUST(Empty(Shape(sizes), x->dtype(), JUST(x->device()),\n                    /*requires_grad=*/x->requires_grad(),\n                    /*pin_memory=*/false));\n}\n\nMaybe<Tensor> DropoutImpl(const std::shared_ptr<one::Tensor>& input, const float& p,\n                          const bool& train) {\n  CHECK_EQ_OR_RETURN(p >= 0 && p <= 1, true)\n      << \"dropout probability has to be between 0 and 1, but got \" << p;\n  if (p == 0 || !train || input->shape()->elem_cnt() == 0) { return input; }\n  if (p == 1) {\n    std::shared_ptr<Tensor> other =\n        JUST(Constant(*input->shape(), Scalar(0.0), input->dtype(), JUST(input->device())));\n    return Mul(input, other);\n  }\n  std::shared_ptr<Tensor> noise = JUST(MakeFeatureNoise(input));\n  noise =\n      JUST(BernoulliProb(noise, 1.0 - p, noise->dtype(), JUST(one::DefaultAutoGenerator()), false));\n  noise = JUST(InplaceScalarDiv(noise, Scalar(1.0 - p)));\n  return JUST(Mul(input, noise));\n}\n}  // namespace\n\nclass Dropout1dFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const float& p,\n                           const bool& training) const {\n    CHECK_EQ_OR_RETURN(p < 0 || p > 1.0, false)\n        << \"dropout probability has to be between 0 and 1, but got \" << 
p;\n    const int input_dim = input->ndim();\n    CHECK_EQ_OR_RETURN(input_dim != 2 && input_dim != 3, false)\n        << \"dropout1d: Expected 2D or 3D input, but received a \" << input_dim\n        << \"D input. \"\n           \"Note that dropout1d exists to provide channel-wise dropout on inputs with 1 \"\n           \"spatial dimension, a channel dimension, and an optional batch dimension \"\n           \"(i.e. 2D or 3D inputs).\";\n    bool is_batched = (input_dim == 3);\n    std::shared_ptr<one::Tensor> result = input;\n    if (!is_batched) { result = JUST(Unsqueeze(input, 0)); }\n    result = JUST(DropoutImpl(result, p, training));\n    if (!is_batched) { result = JUST(Squeeze(result, std::vector<int32_t>{0})); }\n    return result;\n  }\n};\n\nclass Dropout2dFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const float& p,\n                           const bool& training) const {\n    CHECK_EQ_OR_RETURN(p < 0 || p > 1.0, false)\n        << \"dropout probability has to be between 0 and 1, but got \" << p;\n    const int input_dim = input->ndim();\n    if (input_dim != 3 && input_dim != 4) {\n      LOG(WARNING)\n          << \"dropout2d: Received a \" << input_dim\n          << \"-D input to dropout2d, which is deprecated \"\n             \"and will result in an error in a future release. To retain the behavior \"\n             \"and silence this warning, please use dropout instead. Note that dropout2d \"\n             \"exists to provide channel-wise dropout on inputs with 2 spatial dimensions, \"\n             \"a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs).\";\n    }\n    if (input_dim == 3) {\n      LOG(WARNING)\n          << \"dropout2d: Received a 3D input to dropout2d and assuming that channel-wise \"\n             \"1D dropout behavior is desired - input is interpreted as shape (N, C, L), where C \"\n             \"is the channel dim. This behavior will change in a future release to interpret the \"\n             \"input as one without a batch dimension, i.e. shape (C, H, W). To maintain the 1D \"\n             \"channel-wise dropout behavior, please switch to using dropout1d instead.\";\n    }\n    return JUST(DropoutImpl(input, p, training));\n  }\n};\n\nclass Dropout3dFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const float& p,\n                           const bool& training) const {\n    CHECK_EQ_OR_RETURN(p < 0 || p > 1.0, false)\n        << \"dropout probability has to be between 0 and 1, but got \" << p;\n    const int input_dim = input->ndim();\n    if (input_dim != 4 && input_dim != 5) {\n      LOG(WARNING)\n          << \"dropout3d: Received a \" << input_dim\n          << \"-D input to dropout3d, which is deprecated \"\n             \"and will result in an error in a future release. To retain the behavior \"\n             \"and silence this warning, please use dropout instead. Note that dropout3d \"\n             \"exists to provide channel-wise dropout on inputs with 3 spatial dimensions, \"\n             \"a channel dimension, and an optional batch dimension (i.e. 
4D or 5D inputs).\";\n    }\n    bool is_batched = (input_dim == 5);\n    std::shared_ptr<one::Tensor> result = input;\n    if (!is_batched) { result = JUST(Unsqueeze(input, 0)); }\n    result = JUST(DropoutImpl(result, p, training));\n    if (!is_batched) { result = JUST(Squeeze(result, std::vector<int32_t>{0})); }\n    return result;\n  }\n};\n\nclass DropoutGradFunctor {\n public:\n  DropoutGradFunctor() {\n    dropout_grad_op_ =\n        CHECK_JUST(one::OpBuilder(\"dropout_grad\").Input(\"dy\").Input(\"mask\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& mask, const float& scale) const {\n    auto& dropout_grad_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"scale\");\n    dropout_grad_attrs.SetAllAttrs(scale);\n    return OpInterpUtil::Dispatch<Tensor>(*dropout_grad_op_, {dy, mask}, dropout_grad_attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> dropout_grad_op_;\n};\n\nclass UnfoldFunctor {\n public:\n  UnfoldFunctor() {\n    unfold_op_ = CHECK_JUST(one::OpBuilder(\"unfold\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::vector<int32_t>& kernel_size,\n                           const std::vector<int32_t>& dilation_rate,\n                           const std::vector<int32_t>& padding, const std::vector<int32_t>& strides,\n                           const std::string& data_format) const {\n    const auto& x_shape = x->shape();\n    // Only 4-D tensors are supported now.\n    CHECK_EQ_OR_RETURN(x_shape->NumAxes(), 4)\n        << Error::RuntimeError() << \"Input tensor dim should be 4\";\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"kernel_size\", \"dilation_rate\", \"padding\",\n                                                 \"strides\", \"data_format\");\n    attrs.SetAllAttrs(kernel_size, dilation_rate, padding, strides, data_format);\n    return OpInterpUtil::Dispatch<Tensor>(*unfold_op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> unfold_op_;\n};\n\nclass FoldFunctor {\n public:\n  FoldFunctor() { fold_op_ = CHECK_JUST(one::OpBuilder(\"fold\").Input(\"x\").Output(\"y\").Build()); }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::vector<int32_t>& output_size,\n                           const std::vector<int32_t>& kernel_size,\n                           const std::vector<int32_t>& dilation_rate,\n                           const std::vector<int32_t>& padding, const std::vector<int32_t>& strides,\n                           const std::string& data_format) const {\n    const auto& x_shape = x->shape();\n    // Only 3-D tensor fold is supported now; the 
format is (N, C*K*K, L)\n    CHECK_EQ_OR_RETURN(x_shape->NumAxes(), 3)\n        << Error::RuntimeError() << \"Input tensor dim should be 3\";\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"output_size\", \"kernel_size\", \"dilation_rate\",\n                                                 \"padding\", \"strides\", \"data_format\");\n    attrs.SetAllAttrs(output_size, kernel_size, dilation_rate, padding, strides, data_format);\n    return OpInterpUtil::Dispatch<Tensor>(*fold_op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> fold_op_;\n};\n\nclass OneHotFunctor {\n public:\n  OneHotFunctor() {\n    one_hot_op_ = CHECK_JUST(one::OpBuilder(\"one_hot\").Input(\"indices\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int64_t& num_classes,\n                           const Scalar& on_value, const Scalar& off_value) const {\n    CHECK_OR_RETURN(!IsFloatingDataType(input->dtype()->data_type()))\n        << Error::RuntimeError() << \"one_hot is only applicable to index tensors.\";\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"depth\", \"dtype\", \"floating_on_value\", \"floating_off_value\",\n                                       \"integer_on_value\", \"integer_off_value\");\n    int64_t depth = num_classes;\n    if (num_classes == -1) {\n      std::vector<int32_t> axis(input->ndim());\n      std::iota(axis.begin(), axis.end(), 0);\n      auto tensor_max = JUST(functional::ReduceMax(input, axis, false));\n\n      int64_t max = 0;\n      const auto& callback = [&](ep::Stream* stream,\n                                 const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n        SyncAutoMemcpy(stream, &max, eager_blob_object->dptr(), sizeof(max),\n                       memory::MakeHostMemCase(), eager_blob_object->mem_case());\n      };\n      JUST(SyncAccessTensorWithTimeOut(tensor_max, callback, \"const\"));\n      depth = max + 1;\n    }\n    // Refer to: https://github.com/Oneflow-Inc/oneflow/pull/5315/files#r755823506\n    bool is_on_value_double = on_value.IsFloatingPoint();\n    bool is_off_value_double = off_value.IsFloatingPoint();\n    if (is_on_value_double || is_off_value_double) {\n      attrs.SetAllAttrs(depth, kFloat, on_value.As<double>(), off_value.As<double>(),\n                        static_cast<int64_t>(0), static_cast<int64_t>(0));\n    } else {\n      attrs.SetAllAttrs(depth, kInt64, static_cast<double>(0), static_cast<double>(0),\n                        on_value.As<int64_t>(), off_value.As<int64_t>());\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*one_hot_op_, {input}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> one_hot_op_;\n};\n\nclass PairwiseDistanceFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, const std::shared_ptr<Tensor>& y,\n                           const float& p, const double& eps, bool keepdim) const {\n    const int64_t xdim = x->ndim();\n    const int64_t ydim = y->ndim();\n    const int64_t output_dim = xdim > ydim ? 
xdim : ydim;\n    const auto& sub = JUST(ScalarAdd(JUST(Sub(x, y, 1, false)), eps, 1, false));\n    return ScalarNorm(sub, p, output_dim - 1, keepdim, NullOpt);\n  }\n};\n\nclass CosineSimilarityFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y, const int32_t& dim,\n                           const double& eps) const {\n    const auto& x_shape = *(x->shape());\n    const auto& y_shape = *(y->shape());\n    std::shared_ptr<one::Tensor> x_extend = x;\n    std::shared_ptr<one::Tensor> y_extend = y;\n    if (x_shape != y_shape) {\n      Shape max_shape = Shape::Ones(std::max(x_shape.NumAxes(), y_shape.NumAxes()));\n      for (int64_t i = max_shape.NumAxes() - 1; i >= 0; i--) {\n        int64_t offset = max_shape.NumAxes() - 1 - i;\n        int64_t dim_x = x_shape.NumAxes() - 1 - offset;\n        int64_t dim_y = y_shape.NumAxes() - 1 - offset;\n        int64_t size_x = (dim_x >= 0) ? x_shape.At(dim_x) : 1;\n        int64_t size_y = (dim_y >= 0) ? y_shape.At(dim_y) : 1;\n        if (!(size_x == size_y || size_x == 1 || size_y == 1)) {\n          return Error::RuntimeError()\n                 << \"The size of tensor a (\" << size_x << \") must match the size of tensor b (\"\n                 << size_y << \") at non-singleton dimension \" << i;\n        }\n        max_shape.Set(i, std::max(size_x, size_y));\n      }\n      x_extend = JUST(Expand(x, max_shape));\n      y_extend = JUST(Expand(y, max_shape));\n    }\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x_extend, y_extend}).Apply());\n    TensorTuple input_vec = JUST(tensor_processor.GetInputs());\n    const auto common_dtype = JUST(oneflow::VectorAt(input_vec, 0))->dtype();\n    if (!IsFloatingDataType(common_dtype->data_type())) {\n      return Error::RuntimeError()\n             << \"expected common dtype to be floating point, yet common dtype is \"\n             << common_dtype->name();\n    }\n    auto& x_ = JUST(oneflow::VectorAt(input_vec, 0));\n    auto& y_ = JUST(oneflow::VectorAt(input_vec, 1));\n    std::shared_ptr<Tensor> w12 =\n        JUST(functional::ReduceSum(JUST(functional::Mul(x_, y_)), {dim}, false, NullOpt));\n    std::shared_ptr<Tensor> w1 =\n        JUST(functional::ReduceSum(JUST(functional::Mul(x_, x_)), {dim}, false, NullOpt));\n    std::shared_ptr<Tensor> w2 =\n        JUST(functional::ReduceSum(JUST(functional::Mul(y_, y_)), {dim}, false, NullOpt));\n    std::shared_ptr<Tensor> n12 = JUST(functional::Sqrt(\n        JUST(functional::Clamp(JUST(functional::Mul(w1, w2)), Scalar(eps * eps), NullOpt))));\n    return functional::Div(w12, n12);\n  }\n};\n\nclass L2NormalizeFunctor {\n public:\n  L2NormalizeFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"l2_normalize\").Input(\"x\").Output(\"y\").Output(\"square_x_sum\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const int32_t& axis,\n                           const float& epsilon) const {\n    const int32_t ndims = input->shape()->NumAxes();\n    const int32_t final_dim = ndims - 1;\n\n    auto axis_ = axis >= 0 ? 
axis : axis + ndims;\n    CHECK_GE_OR_RETURN(axis_, 0) << Error::RuntimeError() << \"Axis should be >= 0, but axis is \"\n                                 << axis_ << \" now.\";\n    CHECK_LE_OR_RETURN(axis_, final_dim) << Error::RuntimeError() << \"Axis should be less than \"\n                                         << ndims << \", but axis is \" << axis_ << \" now.\";\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"epsilon\", \"axis\");\n    attrs.SetAllAttrs(epsilon, final_dim);\n\n    if (axis_ == final_dim) { return OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs); }\n\n    std::vector<int> input_perm(input->shape()->dim_vec().size(), 0);\n    for (size_t i = 0; i < input_perm.size(); ++i) { input_perm[i] = static_cast<int>(i); }\n    std::swap(input_perm[final_dim], input_perm[static_cast<size_t>(axis_)]);\n\n    const auto result = JUST(OpInterpUtil::Dispatch<TensorTuple>(\n        *op_, {JUST(functional::Transpose(input, input_perm))}, attrs));\n    return functional::Transpose((*result)[0], input_perm);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass NormalizeFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, const float& p,\n                           const int32_t& dim, const float& eps,\n                           const bool& use_l2_norm_kernel) const {\n    if (use_l2_norm_kernel && (std::fabs(p - 2.0f) < std::numeric_limits<float>::min())) {\n      return functional::L2Normalize(input, dim, eps);\n    }\n    return SequenceFunction<Maybe<Tensor>(const std::shared_ptr<Tensor>&, const float&,\n                                          const int32_t&)>(\n               [](const auto& x, const float& p, const int32_t& dim) -> Maybe<Tensor> {\n                 return functional::ScalarNorm(x, p, dim, true, NullOpt);\n               })\n        .then([&](const auto& x) { return functional::Clamp(x, eps, NullOpt); })\n        .then([&](const auto& x) { return functional::Div(input, x); })\n        .call(input, p, dim);\n  }\n};\n\nclass FusedSelfAttentionFunctor {\n public:\n  FusedSelfAttentionFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_self_attention_query_mul_key_and_value\")\n                         .Input(\"hidden_states\")\n                         .Output(\"query_mul_key\")\n                         .Output(\"value\")\n                         .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& hidden_states,\n                                const int64_t& head_size, const float& alpha) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"head_size\", \"alpha\");\n    attrs.SetAllAttrs(head_size, alpha);\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {hidden_states}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedSelfAttentionGradFunctor {\n public:\n  FusedSelfAttentionGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_self_attention_query_mul_key_and_value_grad\")\n                         .Input(\"query_mul_key_grad\")\n                         .Input(\"value_grad\")\n                         .Input(\"hidden_states\")\n                         .Output(\"hidden_states_grad\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& query_mul_key_grad,\n                           const std::shared_ptr<one::Tensor>& value_grad,\n                           const std::shared_ptr<one::Tensor>& hidden_states,\n                           const float& alpha) const {\n    auto& attrs = 
THREAD_CACHED_MUTABLE_ATTR_MAP(\"alpha\");\n    attrs.SetAllAttrs(alpha);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {query_mul_key_grad, value_grad, hidden_states},\n                                          attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedScaleTrilSoftmaxMaskScaleFunctor {\n public:\n  FusedScaleTrilSoftmaxMaskScaleFunctor() {\n    random_mask_like_op_ =\n        CHECK_JUST(one::OpBuilder(\"random_mask_like\").Input(\"like\").Output(\"out\").Build());\n    fused_op_ = CHECK_JUST(one::OpBuilder(\"fused_tril_scale_softmax_mask_scale\")\n                               .Input(\"x\")\n                               .Input(\"mask\")\n                               .Output(\"y\")\n                               .Output(\"softmax_y\")\n                               .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x, const float p,\n                                const int64_t diagonal, const float tril_scale_value,\n                                const float tril_fill_value,\n                                const Optional<one::Generator>& generator) const {\n    auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n    gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x));\n    auto& random_mask_like_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"rate\", \"seed\");\n    random_mask_like_attrs.SetAllAttrs(p, static_cast<int64_t>(gen->current_seed()));\n    const auto& random_mask_like_state = std::make_shared<RandomMaskLikeKernelState>(gen);\n    const auto& mask = JUST(OpInterpUtil::Dispatch<Tensor>(\n        *random_mask_like_op_, {x},\n        OpExprInterpContext(random_mask_like_attrs, random_mask_like_state)));\n\n    float mask_scale_value = 1.0;\n    if (p != 1.0) { mask_scale_value = 1.0 / (1.0 - p); }\n    auto& fused_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"diagonal\", \"tril_scale_value\",\n                                                       \"mask_scale_value\", \"tril_fill_value\");\n    fused_attrs.SetAllAttrs(diagonal, tril_scale_value, mask_scale_value, tril_fill_value);\n    return OpInterpUtil::Dispatch<TensorTuple>(*fused_op_, {x, mask}, fused_attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> fused_op_;\n  std::shared_ptr<OpExpr> random_mask_like_op_;\n};\n\nclass L2NormalizeGradFunctor {\n public:\n  L2NormalizeGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"l2_normalize_grad\")\n                         .Input(\"dy\")\n                         .Input(\"y\")\n                         .Input(\"square_x_sum\")\n                         .Output(\"dx\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& y,\n                           const std::shared_ptr<one::Tensor>& square_x_sum, const int32_t& axis,\n                           const float& epsilon) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"epsilon\");\n    attrs.SetAllAttrs(axis, epsilon);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, y, square_x_sum}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedBiasAddGeluFunctor {\n public:\n  FusedBiasAddGeluFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"fused_bias_add_gelu\").Input(\"a\").Input(\"b\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& a,\n                           const std::shared_ptr<one::Tensor>& b, 
const int32_t& axis) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\");\n    attrs.SetAllAttrs(axis);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {a, b}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedBiasAddGeluGradFunctor {\n public:\n  FusedBiasAddGeluGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_bias_add_gelu_grad\")\n                         .Input(\"a\")\n                         .Input(\"b\")\n                         .Input(\"dy\")\n                         .Output(\"dx\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& a,\n                           const std::shared_ptr<one::Tensor>& b,\n                           const std::shared_ptr<one::Tensor>& dy, const int32_t& axis) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\");\n    attrs.SetAllAttrs(axis);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {a, b, dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedBiasAddDropoutFunctor {\n public:\n  FusedBiasAddDropoutFunctor() {\n    random_mask_like_op_ =\n        CHECK_JUST(one::OpBuilder(\"random_mask_like\").Input(\"like\").Output(\"out\").Build());\n    fused_bias_add_mask_scale_op_ = CHECK_JUST(one::OpBuilder(\"fused_bias_add_mask_scale\")\n                                                   .Input(\"a\")\n                                                   .Input(\"b\")\n                                                   .Input(\"mask\")\n                                                   .Output(\"out\")\n                                                   .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& a,\n                           const std::shared_ptr<one::Tensor>& b, const float& p,\n                           const int32_t& axis, const Optional<one::Generator>& generator) const {\n    int32_t axis_val = axis;\n    if (axis_val < 0) {\n      const int64_t num_axes = a->shape()->NumAxes();\n      axis_val += num_axes;\n    }\n    if (p > 0.0) {\n      auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n      gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), a));\n      auto& random_mask_like_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"rate\", \"seed\");\n      random_mask_like_attrs.SetAllAttrs(p, static_cast<int64_t>(gen->current_seed()));\n      const auto& random_mask_like_state = std::make_shared<RandomMaskLikeKernelState>(gen);\n\n      float scale = 0.0;\n      if (p != 1.0) { scale = 1.0 / (1.0 - p); }\n      auto& fused_bias_add_mask_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"scale\", \"axis\");\n      fused_bias_add_mask_attrs.SetAllAttrs(scale, axis_val);\n\n      return SequenceFunction<Maybe<Tensor>()>([&]() -> Maybe<Tensor> {\n               return OpInterpUtil::Dispatch<Tensor>(\n                   *random_mask_like_op_, {a},\n                   OpExprInterpContext(random_mask_like_attrs, random_mask_like_state));\n             })\n          .then([&](const std::shared_ptr<one::Tensor>& x) {\n            return OpInterpUtil::Dispatch<Tensor>(*fused_bias_add_mask_scale_op_, {a, b, x},\n                                                  fused_bias_add_mask_attrs);\n          })\n          .call();\n    } else {\n      return functional::BiasAdd(a, b, axis_val);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> random_mask_like_op_;\n  std::shared_ptr<OpExpr> fused_bias_add_mask_scale_op_;\n};\n\nclass FusedGluFunctor {\n public:\n  
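// A hedged sketch of the computation (inferred from the op inputs/outputs and the shape\n  // checks in operator() below, not from the kernel itself):\n  //   matmul_wx = x @ w^T (+ b); in split mode additionally matmul_vx = x @ v^T (+ c);\n  //   non-split: y = half_0(matmul_wx) * activation(half_1(matmul_wx))\n  //   split:     y = matmul_wx * activation(matmul_vx)\n  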
FusedGluFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_glu\")\n                         .Input(\"x\")\n                         .Input(\"w\")\n                         .Input(\"b\")\n                         .Output(\"y\")\n                         .Output(\"matmul_wx\")\n                         .Build());\n\n    op_without_bias_ = CHECK_JUST(\n        one::OpBuilder(\"fused_glu\").Input(\"x\").Input(\"w\").Output(\"y\").Output(\"matmul_wx\").Build());\n\n    split_op_ = CHECK_JUST(one::OpBuilder(\"fused_glu\")\n                               .Input(\"x\")\n                               .Input(\"w\")\n                               .Input(\"b\")\n                               .Input(\"v\")\n                               .Input(\"c\")\n                               .Output(\"y\")\n                               .Output(\"matmul_wx\")\n                               .Output(\"matmul_vx\")\n                               .Build());\n\n    split_op_without_bias_ = CHECK_JUST(one::OpBuilder(\"fused_glu\")\n                                            .Input(\"x\")\n                                            .Input(\"w\")\n                                            .Input(\"v\")\n                                            .Output(\"y\")\n                                            .Output(\"matmul_wx\")\n                                            .Output(\"matmul_vx\")\n                                            .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& w, const Optional<one::Tensor>& b,\n                           const Optional<one::Tensor>& v, const Optional<one::Tensor>& c,\n                           const std::string& activation) const {\n    // check whether the user provides the weight tensor v\n    const bool is_split_mode = v.has_value();\n\n    // check whether the user provides the bias tensors\n    bool has_bias = false;\n    if (b) {\n      has_bias = true;\n      if (is_split_mode) {\n        CHECK_OR_RETURN(c) << \"expected existence of c when providing tensors w, v and b\";\n      }\n    } else {\n      CHECK_OR_RETURN(!c) << \"expected existence of b while providing c\";\n    }\n\n    // obtain input shape\n    const auto& x_shape = *(x->shape());\n    const auto& w_shape = *(w->shape());\n    std::shared_ptr<const oneflow::Shape> b_shape = nullptr;\n    if (has_bias) { b_shape = (JUST(b)->shape()); }\n\n    // check number of axes of x, w and b\n    CHECK_GT_OR_RETURN(x_shape.NumAxes(), 1)\n        << \"number of axes of \\'x\\' should be greater than 1, but got \" << x_shape.NumAxes();\n    CHECK_EQ_OR_RETURN(w_shape.NumAxes(), 2)\n        << \"number of axes of \\'w\\' should be equal to 2, but got \" << w_shape.NumAxes();\n    if (has_bias) {\n      CHECK_EQ_OR_RETURN(b_shape->NumAxes(), 1)\n          << \"number of axes of \\'b\\' should be equal to 1, but got \" << b_shape->NumAxes();\n    }\n\n    // check input shapes of w and b\n    size_t x_num_axes = x_shape.NumAxes();\n    CHECK_EQ_OR_RETURN(w_shape.At(1), x_shape.At(x_num_axes - 1))\n        << \"dimension 1 of \\'w\\'(\" << w_shape.At(1)\n        << \") is not consistent with the last dimension of \\'x\\'(\" << x_shape.At(x_num_axes - 1)\n        << \")\";\n    if (has_bias) {\n      CHECK_EQ_OR_RETURN(b_shape->At(0), w_shape.At(0))\n          << \"dimension 0 of \\'b\\'(\" << 
b_shape->At(0)\n          << \") is not consistent with dimension 0 of \\'w\\'(\" << w_shape.At(0) << \")\";\n    }\n    if (!is_split_mode) {\n      CHECK_EQ_OR_RETURN(w_shape.At(1) % 2, 0) << \"dimension 1 of \\'w\\' is not divisible by 2\";\n    }\n\n    // check both dimensions and input shapes of v and c (optional)\n    if (is_split_mode) {\n      const auto& v_shape = *(JUST(v)->shape());\n      std::shared_ptr<const oneflow::Shape> c_shape = nullptr;\n      if (has_bias) { c_shape = (JUST(c)->shape()); }\n\n      CHECK_EQ_OR_RETURN(v_shape.NumAxes(), 2)\n          << \"number of axes of \\'v\\' should be equal to 2, but got \" << v_shape.NumAxes();\n      if (has_bias) {\n        CHECK_EQ_OR_RETURN(c_shape->NumAxes(), 1)\n            << \"number of axes of \\'c\\' should be equal to 1, but got \" << c_shape->NumAxes();\n      }\n\n      CHECK_OR_RETURN(v_shape == w_shape) << \"the shape of \\'v\\' is not consistent with \\'w\\'\";\n      if (has_bias) {\n        CHECK_OR_RETURN((*c_shape) == (*b_shape))\n            << \"the shape of \\'c\\' is not consistent with \\'b\\'\";\n      }\n    }\n\n    // set activation attribute\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"activation\", \"has_bias\", \"is_split\");\n    attrs.SetAllAttrs(activation, has_bias, is_split_mode);\n\n    // dispatch the corresponding operator (the four cases are exhaustive)\n    if (is_split_mode && has_bias) {\n      return OpInterpUtil::Dispatch<one::Tensor>(*split_op_, {x, w, JUST(b), JUST(v), JUST(c)},\n                                                 attrs);\n    } else if (!is_split_mode && has_bias) {\n      return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x, w, JUST(b)}, attrs);\n    } else if (is_split_mode && !has_bias) {\n      return OpInterpUtil::Dispatch<one::Tensor>(*split_op_without_bias_, {x, w, JUST(v)}, attrs);\n    } else {  // !is_split_mode && !has_bias\n      return OpInterpUtil::Dispatch<one::Tensor>(*op_without_bias_, {x, w}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> op_without_bias_;\n  std::shared_ptr<OpExpr> split_op_;\n  std::shared_ptr<OpExpr> split_op_without_bias_;\n};\n\nclass FusedScaleTrilFunctor {\n public:\n  FusedScaleTrilFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_scale_tril\").Input(\"in\").Output(\"out\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int64_t& diagonal,\n                           const Scalar& fill_value, const Scalar& scale) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\n        \"diagonal\", \"floating_fill_value\", \"is_floating_fill_value\", \"integer_fill_value\",\n        \"floating_scale_value\", \"is_floating_scale_value\", \"integer_scale_value\");\n    bool is_fill_value_double = fill_value.IsFloatingPoint();\n    bool is_scale_double = scale.IsFloatingPoint();\n\n    double floating_fill_value = 0;\n    int64_t integer_fill_value = 0;\n    if (is_fill_value_double) {\n      floating_fill_value = fill_value.As<double>();\n    } else {\n      integer_fill_value = fill_value.As<int64_t>();\n    }\n    double floating_scale_value = 0;\n    int64_t integer_scale_value = 0;\n    if (is_scale_double) {\n      floating_scale_value = scale.As<double>();\n    } else {\n      integer_scale_value = scale.As<int64_t>();\n    }\n    attrs.SetAllAttrs(diagonal, floating_fill_value, is_fill_value_double, integer_fill_value,\n                      floating_scale_value, is_scale_double, integer_scale_value);\n   
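 // NOTE: the Scalar fill/scale values are packed as (floating, is_floating, integer) attr\n    // triples so the kernel can select the correctly typed value at runtime.\n   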
 return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedScaleMaskSoftmaxFunctor {\n public:\n  FusedScaleMaskSoftmaxFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"fused_scale_mask_softmax\").Input(\"x\").Input(\"mask\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& mask, const float& fill_value,\n                           const float& scale) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"scale_value\", \"mask_fill_value\");\n    attrs.SetAllAttrs(scale, fill_value);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, mask}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedScaleMaskSoftmaxDropoutFunctor {\n public:\n  FusedScaleMaskSoftmaxDropoutFunctor() {\n    random_mask_like_op_ =\n        CHECK_JUST(one::OpBuilder(\"random_mask_like\").Input(\"like\").Output(\"out\").Build());\n    fused_scale_mask_softmax_dropout_op_ =\n        CHECK_JUST(one::OpBuilder(\"fused_scale_mask_softmax_dropout\")\n                       .Input(\"x\")\n                       .Input(\"mask\")\n                       .Input(\"dropout_mask\")\n                       .Output(\"y\")\n                       .Output(\"softmax_y\")\n                       .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x,\n                                const std::shared_ptr<one::Tensor>& mask, const float& fill_value,\n                                const float& scale, const float& p, const bool& training,\n                                const Optional<one::Generator>& generator) const {\n    float rate = p;\n    if (!training) rate = 0.0;\n    auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n    gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x));\n    auto& random_mask_like_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"rate\", \"seed\");\n    random_mask_like_attrs.SetAllAttrs(rate, static_cast<int64_t>(gen->current_seed()));\n    const auto& random_mask_like_state = std::make_shared<RandomMaskLikeKernelState>(gen);\n    const auto& dropout_mask = JUST(OpInterpUtil::Dispatch<Tensor>(\n        *random_mask_like_op_, {x},\n        OpExprInterpContext(random_mask_like_attrs, random_mask_like_state)));\n\n    float dropout_scale = 0.0;\n    if (rate != 1.0) { dropout_scale = 1.0 / (1.0 - rate); }\n    auto& fused_scale_mask_softmax_dropout_attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"scale_value\", \"mask_fill_value\", \"dropout_scale_value\");\n    fused_scale_mask_softmax_dropout_attrs.SetAllAttrs(scale, fill_value, dropout_scale);\n    return OpInterpUtil::Dispatch<TensorTuple>(*fused_scale_mask_softmax_dropout_op_,\n                                               {x, mask, dropout_mask},\n                                               fused_scale_mask_softmax_dropout_attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> random_mask_like_op_;\n  std::shared_ptr<OpExpr> fused_scale_mask_softmax_dropout_op_;\n};\n\n// Equivalent to\n// masked = (x + bias) * mask * scale_value\n// unmask = (1 - mask).bool()\n// masked.masked_fill_(unmask, mask_fill_value)\n// softmax_y = softmax(masked, dim=-1)\n// y = dropout(softmax_y, p)\nclass FusedBiasAddScaleMaskSoftmaxDropoutFunctor {\n public:\n  FusedBiasAddScaleMaskSoftmaxDropoutFunctor() {\n    random_mask_op_ =\n        
CHECK_JUST(one::OpBuilder(\"random_mask_like\").Input(\"like\").Output(\"out\").Build());\n    fused_op_ = CHECK_JUST(one::OpBuilder(\"fused_bias_add_scale_mask_softmax_dropout\")\n                               .Input(\"x\")\n                               .Input(\"bias\")\n                               .Input(\"mask\")\n                               .Input(\"dropout_mask\")\n                               .Output(\"y\")\n                               .Output(\"softmax_y\")\n                               .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x,\n                                const std::shared_ptr<one::Tensor>& bias,\n                                const std::shared_ptr<one::Tensor>& mask, const float& fill_value,\n                                const float& scale, const float& p, const bool& training,\n                                const Optional<one::Generator>& generator) const {\n    float rate = p;\n    if (!training) rate = 0.0;\n    auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n    gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x));\n    auto& random_mask_like_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"rate\", \"seed\");\n    random_mask_like_attrs.SetAllAttrs(rate, static_cast<int64_t>(gen->current_seed()));\n    const auto& random_mask_like_state = std::make_shared<RandomMaskLikeKernelState>(gen);\n    const auto& dropout_mask = JUST(OpInterpUtil::Dispatch<Tensor>(\n        *random_mask_op_, {x},\n        OpExprInterpContext(random_mask_like_attrs, random_mask_like_state)));\n\n    float dropout_scale = 0.0;\n    if (rate != 1.0) { dropout_scale = 1.0 / (1.0 - rate); }\n    auto& fused_scale_mask_softmax_dropout_attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"scale_value\", \"mask_fill_value\", \"dropout_scale_value\");\n    fused_scale_mask_softmax_dropout_attrs.SetAllAttrs(scale, fill_value, dropout_scale);\n    return OpInterpUtil::Dispatch<TensorTuple>(*fused_op_, {x, bias, mask, dropout_mask},\n                                               fused_scale_mask_softmax_dropout_attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> random_mask_op_;\n  std::shared_ptr<OpExpr> fused_op_;\n};\n\nclass CtcGreedyDecoderFunctor {\n public:\n  CtcGreedyDecoderFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"ctc_greedy_decoder\")\n                         .Input(\"log_probs\")\n                         .Input(\"input_lengths\")\n                         .Output(\"decoded\")\n                         .Output(\"neg_sum_logits\")\n                         .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& log_probs,\n                                const std::shared_ptr<one::Tensor>& input_lengths,\n                                const bool& merge_repeated) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"merge_repeated\");\n    attrs.SetAllAttrs(merge_repeated);\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {log_probs, input_lengths}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass PariticalFCSampleDisableBoxing {\n public:\n  PariticalFCSampleDisableBoxing() {\n    op_ = CHECK_JUST(one::OpBuilder(\"distributed_partial_fc_sample_disable_boxing\")\n                         .Input(\"sampled_weight_diff\")\n                         .Input(\"sampled_label\")\n                         .Output(\"boxing_disabled_sampled_weight_diff\")\n                         .Output(\"boxing_disabled_sampled_label\")\n                      
   .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& sampled_weight_diff,\n                                const std::shared_ptr<one::Tensor>& sampled_label) const {\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {sampled_weight_diff, sampled_label});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass NmsFunctor {\n public:\n  NmsFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"nms\").Input(\"in\").Output(\"out\").Build()); }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const float& iou_threshold,\n                           const int32_t& keep_n) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"iou_threshold\", \"keep_n\");\n    attrs.SetAllAttrs(iou_threshold, keep_n);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass RoiAlignFunctor {\n public:\n  RoiAlignFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"roi_align\").Input(\"x\").Input(\"rois\").Output(\"y\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& rois, const float& spatial_scale,\n                           const int32_t& pooled_h, const int32_t& pooled_w,\n                           const int32_t& sampling_ratio, const bool& aligned) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"spatial_scale\", \"pooled_h\", \"pooled_w\",\n                                                 \"sampling_ratio\", \"aligned\");\n    attrs.SetAllAttrs(spatial_scale, pooled_h, pooled_w, sampling_ratio, aligned);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, rois}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass RoiAlignGradFunctor {\n public:\n  RoiAlignGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"roi_align_grad\")\n                         .Input(\"dy\")\n                         .Input(\"x_like\")\n                         .Input(\"rois\")\n                         .Output(\"dx\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x_like,\n                           const std::shared_ptr<one::Tensor>& rois, const float& spatial_scale,\n                           const int32_t& pooled_h, const int32_t& pooled_w,\n                           const int32_t& sampling_ratio, const bool& aligned) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"spatial_scale\", \"pooled_h\", \"pooled_w\",\n                                                 \"sampling_ratio\", \"aligned\");\n    attrs.SetAllAttrs(spatial_scale, pooled_h, pooled_w, sampling_ratio, aligned);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x_like, rois}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedDotFeatureInteractionFunctor {\n public:\n  FusedDotFeatureInteractionFunctor() {\n    ops_has_output_concat_.resize(kMaxInputCount);\n    ops_no_output_concat_.resize(kMaxInputCount);\n    for (int n = 0; n < ops_has_output_concat_.size(); ++n) {\n      ops_has_output_concat_[n] = CHECK_JUST(one::OpBuilder(\"fused_dot_feature_interaction\")\n                                                 .Input(\"features\", n + 1)\n                                                 .Input(\"output_concat\")\n                                                 .Output(\"out\")\n                                                 .Build());\n  
  }\n    for (int n = 0; n < ops_no_output_concat_.size(); ++n) {\n      ops_no_output_concat_[n] = CHECK_JUST(one::OpBuilder(\"fused_dot_feature_interaction\")\n                                                .Input(\"features\", n + 1)\n                                                .Output(\"out\")\n                                                .Build());\n    }\n  }\n\n  Maybe<Tensor> operator()(const TensorTuple& features, const Optional<one::Tensor>& output_concat,\n                           const bool& self_interaction, const int32_t& output_padding,\n                           const std::string& pooling) const {\n    const int64_t n_features = features.size();\n    TensorTuple inputs;\n    if (n_features > kMaxInputCount) {\n      inputs.push_back(JUST(functional::Concat(features, 1)));\n    } else {\n      inputs = features;\n    }\n    CHECK_OR_RETURN(pooling == \"sum\" || pooling == \"none\")\n        << Error::RuntimeError() << \"pooling should be sum or none, but got \" << pooling;\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"self_interaction\", \"output_padding\", \"pooling\",\n                                                 \"has_output_concat\");\n    if (pooling == \"sum\") {\n      CHECK_EQ_OR_RETURN(output_padding, 0)\n          << Error::RuntimeError() << \"output_padding should be equal to 0, but got \"\n          << output_padding;\n      CHECK_OR_RETURN(!output_concat)\n          << Error::RuntimeError() << \"output_concat should not be provided when pooling is sum\";\n      attrs.SetAllAttrs(self_interaction, output_padding, pooling, false);\n      const std::shared_ptr<one::Tensor>& bi_interaction = JUST(OpInterpUtil::Dispatch<Tensor>(\n          *JUST(oneflow::VectorAt(ops_no_output_concat_, n_features - 1)), inputs, attrs));\n      std::vector<int32_t> reduce_axes_vec = {1};\n      return functional::ReduceSum(bi_interaction, reduce_axes_vec, true, NullOpt);\n    }\n    if (output_concat) {\n      attrs.SetAllAttrs(self_interaction, output_padding, pooling, true);\n      inputs.push_back(JUST(output_concat));\n      return OpInterpUtil::Dispatch<Tensor>(\n          *JUST(oneflow::VectorAt(ops_has_output_concat_, n_features - 1)), inputs, attrs);\n    } else {\n      attrs.SetAllAttrs(self_interaction, output_padding, pooling, false);\n      return OpInterpUtil::Dispatch<Tensor>(\n          *JUST(oneflow::VectorAt(ops_no_output_concat_, n_features - 1)), inputs, attrs);\n    }\n  }\n\n private:\n  std::vector<std::shared_ptr<OpExpr>> ops_has_output_concat_;\n  std::vector<std::shared_ptr<OpExpr>> ops_no_output_concat_;\n};\n\nclass FusedCrossFeatureInteractionFunctor {\n public:\n  FusedCrossFeatureInteractionFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_cross_feature_interaction\")\n                         .Input(\"x\")\n                         .Input(\"weight\")\n                         .Input(\"x0\")\n                         .Input(\"bias\")\n                         .Output(\"out\")\n                         .Output(\"matmul_result\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& weight,\n                           const std::shared_ptr<one::Tensor>& x0,\n                           const std::shared_ptr<one::Tensor>& bias,\n                           const std::string& interaction_mode) const {\n    if (interaction_mode != \"vector\" && interaction_mode != \"matrix\") {\n      UNIMPLEMENTED_THEN_RETURN()\n          << \"Fused Cross Interaction mode only supports `vector` and 
`matrix`. \";\n    }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"interaction_mode\");\n    attrs.SetAllAttrs(interaction_mode);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, weight, x0, bias}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass OneEmbeddingIdShuffleFunctor {\n public:\n  OneEmbeddingIdShuffleFunctor() {\n    op_table_ids_has_in_out_ = CHECK_JUST(one::OpBuilder(\"id_shuffle\")\n                                              .Input(\"ids\")\n                                              .Input(\"table_ids\")\n                                              .Output(\"num_unique_matrix\")\n                                              .Output(\"inverse_unique_partition_indices\")\n                                              .Output(\"cur_rank_num_unique\")\n                                              .Output(\"cur_rank_unique_ids\")\n                                              .Output(\"cur_rank_unique_table_ids\")\n                                              .Output(\"cur_rank_inverse_indices\")\n                                              .Build());\n    op_table_ids_no_in_has_out_ = CHECK_JUST(one::OpBuilder(\"id_shuffle\")\n                                                 .Input(\"ids\")\n                                                 .Output(\"num_unique_matrix\")\n                                                 .Output(\"inverse_unique_partition_indices\")\n                                                 .Output(\"cur_rank_num_unique\")\n                                                 .Output(\"cur_rank_unique_ids\")\n                                                 .Output(\"cur_rank_unique_table_ids\")\n                                                 .Output(\"cur_rank_inverse_indices\")\n                                                 .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& ids,\n                                const Optional<one::Tensor>& table_ids, const int32_t& num_tables,\n                                const std::string& embedding_name) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"num_tables\", \"embedding_name\");\n    attrs.SetAllAttrs(num_tables, embedding_name);\n    if (table_ids) {\n      return OpInterpUtil::Dispatch<TensorTuple>(*op_table_ids_has_in_out_, {ids, JUST(table_ids)},\n                                                 attrs);\n    } else {\n      return OpInterpUtil::Dispatch<TensorTuple>(*op_table_ids_no_in_has_out_, {ids}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_table_ids_has_in_out_;\n  std::shared_ptr<OpExpr> op_table_ids_no_in_has_out_;\n};\n\nclass OneEmbeddingEmbeddingShuffleFunctor {\n public:\n  OneEmbeddingEmbeddingShuffleFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"embedding_shuffle\")\n                         .Input(\"cur_rank_embeddings\")\n                         .Input(\"num_unique_matrix\")\n                         .Input(\"cur_rank_inverse_indices\")\n                         .Input(\"inverse_unique_partition_indices\")\n                         .Output(\"embeddings\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& cur_rank_embeddings,\n                           const std::shared_ptr<one::Tensor>& num_unique_matrix,\n                           const std::shared_ptr<one::Tensor>& cur_rank_inverse_indices,\n                           const std::shared_ptr<one::Tensor>& inverse_unique_partition_indices,\n                          
 const std::string& embedding_name) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"embedding_size\", \"embedding_name\");\n    const int64_t num_axes = cur_rank_embeddings->shape()->NumAxes();\n    attrs.SetAllAttrs(cur_rank_embeddings->shape()->At(num_axes - 1), embedding_name);\n    return OpInterpUtil::Dispatch<Tensor>(\n        *op_,\n        {cur_rank_embeddings, num_unique_matrix, cur_rank_inverse_indices,\n         inverse_unique_partition_indices},\n        attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass OneEmbeddingEmbeddingGradientShuffleFunctor {\n public:\n  OneEmbeddingEmbeddingGradientShuffleFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"embedding_gradient_shuffle\")\n                         .Input(\"embedding_grad\")\n                         .Input(\"num_unique_matrix\")\n                         .Input(\"cur_rank_inverse_indices\")\n                         .Input(\"inverse_unique_partition_indices\")\n                         .Output(\"cur_rank_unique_embedding_grad\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& embedding_grad,\n                           const std::shared_ptr<one::Tensor>& num_unique_matrix,\n                           const std::shared_ptr<one::Tensor>& cur_rank_inverse_indices,\n                           const std::shared_ptr<one::Tensor>& inverse_unique_partition_indices,\n                           const std::string& embedding_name) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"embedding_size\", \"embedding_name\");\n    const int64_t num_axes = embedding_grad->shape()->NumAxes();\n    attrs.SetAllAttrs(embedding_grad->shape()->At(num_axes - 1), embedding_name);\n    return OpInterpUtil::Dispatch<Tensor>(\n        *op_,\n        {embedding_grad, num_unique_matrix, cur_rank_inverse_indices,\n         inverse_unique_partition_indices},\n        attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass OneEmbeddingLookupFunctor {\n public:\n  OneEmbeddingLookupFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"embedding_lookup\")\n                         .Input(\"num_unique_ids\")\n                         .Input(\"unique_ids\")\n                         .Input(\"table_ids\")\n                         .Output(\"unique_values\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& num_unique_ids,\n                           const std::shared_ptr<one::Tensor>& unique_ids,\n                           const std::shared_ptr<one::Tensor>& table_ids,\n                           const Symbol<DType>& dtype, const Symbol<DType>& embedding_dtype,\n                           const int64_t line_size, const int64_t embedding_size,\n                           const std::string& embedding_name, const std::string& embedding_tables,\n                           const std::string& state_initializer, const int64_t seed) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dtype\", \"embedding_dtype\", \"line_size\",\n                                                 \"embedding_size\", \"embedding_name\",\n                                                 \"embedding_tables\", \"state_initializer\", \"seed\");\n    attrs.SetAllAttrs(dtype->data_type(), embedding_dtype->data_type(), line_size, embedding_size,\n                      embedding_name, embedding_tables, state_initializer, seed);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {num_unique_ids, unique_ids, table_ids}, 
attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass OneEmbeddingFusedLookupFunctor {\n public:\n  OneEmbeddingFusedLookupFunctor() {\n    op_has_table_ids_ = CHECK_JUST(one::OpBuilder(\"one_embedding_fused_lookup\")\n                                       .Input(\"shadow\")\n                                       .Input(\"ids\")\n                                       .Input(\"table_ids\")\n                                       .Output(\"embeddings\")\n                                       .Build());\n    op_no_table_ids_ = CHECK_JUST(one::OpBuilder(\"one_embedding_fused_lookup\")\n                                      .Input(\"shadow\")\n                                      .Input(\"ids\")\n                                      .Output(\"embeddings\")\n                                      .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& shadow,\n                           const std::shared_ptr<one::Tensor>& ids,\n                           const Optional<one::Tensor>& table_ids, const Symbol<DType>& dtype,\n                           const std::string& embedding_name, const int64_t line_size,\n                           const int64_t embedding_size, const bool is_full_cache,\n                           const int32_t num_tables, const std::string& embedding_tables,\n                           const Optional<int64_t>& padding_idx, const int64_t seed) const {\n    int64_t padding_idx_val = -1;\n    bool has_padding_idx = false;\n    if (padding_idx.has_value()) {\n      padding_idx_val = JUST(padding_idx);\n      has_padding_idx = true;\n    }\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\n        \"dtype\", \"embedding_name\", \"line_size\", \"embedding_size\", \"is_full_cache\", \"num_tables\",\n        \"embedding_tables\", \"seed\", \"padding_idx\", \"has_padding_idx\");\n    attrs.SetAllAttrs(dtype->data_type(), embedding_name, line_size, embedding_size, is_full_cache,\n                      num_tables, embedding_tables, seed, padding_idx_val, has_padding_idx);\n    if (table_ids) {\n      const auto& table_ids_shape = *(JUST(table_ids)->shape());\n      const auto& ids_shape = *(ids->shape());\n      auto broadcast_table_ids = JUST(table_ids);\n      if (table_ids_shape != ids_shape) {\n        CHECK_LE_OR_RETURN(table_ids_shape.NumAxes(), ids_shape.NumAxes())\n            << \"table_ids num_axes should be less than or equal to ids num_axes, but got \"\n               \"table_ids num_axes \"\n            << table_ids_shape.NumAxes() << \" and ids num_axes \" << ids_shape.NumAxes();\n        const int64_t left_extend_dims = ids_shape.NumAxes() - table_ids_shape.NumAxes();\n        for (int64_t i = 0; i < table_ids_shape.NumAxes(); i++) {\n          CHECK_EQ_OR_RETURN(table_ids_shape.at(i), ids_shape.at(left_extend_dims + i))\n              << \"when table_ids's shape does not equal ids' shape, table_ids must be \"\n                 \"broadcastable to ids_shape, \"\n                 \"but got table_ids_shape: \"\n              << table_ids_shape.DebugStr() << \", ids_shape: \" << ids_shape.DebugStr();\n        }\n        broadcast_table_ids =\n            JUST(functional::BroadcastLike(JUST(table_ids), ids, std::vector<int32_t>{}));\n      }\n      return OpInterpUtil::Dispatch<Tensor>(*op_has_table_ids_, {shadow, ids, broadcast_table_ids},\n                                            attrs);\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*op_no_table_ids_, {shadow, ids}, attrs);\n    }\n  }\n\n private:\n  
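// Lookup op variants: one with and one without the optional table_ids input.\n  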
std::shared_ptr<OpExpr> op_has_table_ids_;\n  std::shared_ptr<OpExpr> op_no_table_ids_;\n};\n\nclass OneEmbeddingFusedLookupGradFunctor {\n public:\n  OneEmbeddingFusedLookupGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"one_embedding_fused_lookup_grad\")\n                         .Input(\"ids\")\n                         .Input(\"embedding_grad\")\n                         .Build());\n  }\n\n  Maybe<void> operator()(const std::shared_ptr<one::Tensor>& ids,\n                         const std::shared_ptr<one::Tensor>& embedding_grad,\n                         const std::string& embedding_name, const int64_t line_size,\n                         const int64_t embedding_size) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"embedding_name\", \"line_size\", \"embedding_size\");\n    attrs.SetAllAttrs(embedding_name, line_size, embedding_size);\n    JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_, {ids, embedding_grad}, attrs));\n    return Maybe<void>::Ok();\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass OneEmbeddingEmbeddingPutFunctor {\n public:\n  OneEmbeddingEmbeddingPutFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"embedding_put\")\n                         .Input(\"num_unique_ids\")\n                         .Input(\"unique_ids\")\n                         .Input(\"unique_embeddings\")\n                         .Build());\n  }\n\n  Maybe<void> operator()(const std::shared_ptr<one::Tensor>& num_unique_ids,\n                         const std::shared_ptr<one::Tensor>& unique_ids,\n                         const std::shared_ptr<one::Tensor>& unique_embeddings,\n                         const std::string& embedding_name, const int64_t line_size) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"embedding_name\", \"line_size\");\n    attrs.SetAllAttrs(embedding_name, line_size);\n    JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_, {num_unique_ids, unique_ids, unique_embeddings},\n                                             attrs));\n    return Maybe<void>::Ok();\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass OneEmbeddingUniqueKeyValuePairFunctor {\n public:\n  OneEmbeddingUniqueKeyValuePairFunctor() {\n    op_has_input_value_ = CHECK_JUST(one::OpBuilder(\"unique_key_value_pair\")\n                                         .Input(\"keys\")\n                                         .Input(\"values\")\n                                         .Output(\"num_unique\")\n                                         .Output(\"unique_keys\")\n                                         .Output(\"unique_values\")\n                                         .Output(\"inverse_indices\")\n                                         .Build());\n    op_no_input_value_ = CHECK_JUST(one::OpBuilder(\"unique_key_value_pair\")\n                                        .Input(\"keys\")\n                                        .Output(\"num_unique\")\n                                        .Output(\"unique_keys\")\n                                        .Output(\"unique_values\")\n                                        .Output(\"inverse_indices\")\n                                        .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& keys,\n                                const Optional<one::Tensor>& values, const int32_t num_tables,\n                                const std::string& embedding_name) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"num_tables\", \"embedding_name\");\n    
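// NOTE: THREAD_CACHED_MUTABLE_ATTR_MAP returns a thread-local cached attr map keyed by the\n    // given attr names, so the values must be (re)set via SetAllAttrs on every call.\n    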
attrs.SetAllAttrs(num_tables, embedding_name);\n    if (values) {\n      return OpInterpUtil::Dispatch<TensorTuple>(*op_has_input_value_, {keys, JUST(values)}, attrs);\n    } else {\n      return OpInterpUtil::Dispatch<TensorTuple>(*op_no_input_value_, {keys}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_has_input_value_;\n  std::shared_ptr<OpExpr> op_no_input_value_;\n};\n\nclass OneEmbeddingSgdUpdateFunctor {\n public:\n  OneEmbeddingSgdUpdateFunctor() {\n    // This functor is only used in one_embedding eager mode, with the lr passed by attr and no\n    // optional inputs; we also define a functor with all optional inputs just for unittest. When\n    // the optional learning_rate tensor is passed in, we assume all optional inputs are not None\n    // and check them.\n    sgd_no_optional_input_op_ = CHECK_JUST(one::OpBuilder(\"one_embedding_sgd_update\")\n                                               .Input(\"num_unique_ids\")\n                                               .Input(\"unique_embeddings\")\n                                               .Input(\"embedding_grad\")\n                                               .Output(\"updated_unique_embeddings\")\n                                               .Build());\n    momentum_no_optional_input_op_ = CHECK_JUST(one::OpBuilder(\"one_embedding_momentum_update\")\n                                                    .Input(\"num_unique_ids\")\n                                                    .Input(\"unique_embeddings\")\n                                                    .Input(\"embedding_grad\")\n                                                    .Output(\"updated_unique_embeddings\")\n                                                    .Build());\n    // This functor is just for unittest\n    sgd_op_ = CHECK_JUST(one::OpBuilder(\"one_embedding_sgd_update\")\n                             .Input(\"num_unique_ids\")\n                             .Input(\"unique_embeddings\")\n                             .Input(\"embedding_grad\")\n                             .Input(\"learning_rate\")\n                             .Input(\"down_scale_by_tensor\")\n                             .Input(\"skip_if\")\n                             .Output(\"updated_unique_embeddings\")\n                             .Build());\n    momentum_op_ = CHECK_JUST(one::OpBuilder(\"one_embedding_momentum_update\")\n                                  .Input(\"num_unique_ids\")\n                                  .Input(\"unique_embeddings\")\n                                  .Input(\"embedding_grad\")\n                                  .Input(\"learning_rate\")\n                                  .Input(\"down_scale_by_tensor\")\n                                  .Input(\"skip_if\")\n                                  .Output(\"updated_unique_embeddings\")\n                                  .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& num_unique_ids,\n                           const std::shared_ptr<one::Tensor>& unique_embeddings,\n                           const std::shared_ptr<one::Tensor>& embedding_grad,\n                           const Optional<one::Tensor>& learning_rate,\n                           const Optional<one::Tensor>& down_scale_by_tensor,\n                           const Optional<one::Tensor>& skip_if, const float learning_rate_val,\n                           const double scale, const float weight_decay, const float momentum,\n                           const int64_t line_size, const 
int64_t embedding_size,\n                           const std::string& embedding_name) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"learning_rate_val\", \"scale\", \"weight_decay\", \"line_size\",\n                                       \"embedding_size\", \"embedding_name\", \"beta\");\n    if (momentum == 0) {\n      attrs.SetAllAttrs(learning_rate_val, scale, weight_decay, line_size, embedding_size,\n                        embedding_name, NullOpt);\n\n      if (learning_rate) {\n        CHECK(down_scale_by_tensor);\n        CHECK(skip_if);\n        return OpInterpUtil::Dispatch<Tensor>(\n            *sgd_op_,\n            {num_unique_ids, unique_embeddings, embedding_grad, JUST(learning_rate),\n             JUST(down_scale_by_tensor), JUST(skip_if)},\n            attrs);\n      } else {\n        CHECK(!down_scale_by_tensor);\n        CHECK(!skip_if);\n        return OpInterpUtil::Dispatch<Tensor>(\n            *sgd_no_optional_input_op_, {num_unique_ids, unique_embeddings, embedding_grad}, attrs);\n      }\n    } else {\n      attrs.SetAllAttrs(learning_rate_val, scale, weight_decay, line_size, embedding_size,\n                        embedding_name, momentum);\n      if (learning_rate) {\n        CHECK(down_scale_by_tensor);\n        CHECK(skip_if);\n        return OpInterpUtil::Dispatch<Tensor>(\n            *momentum_op_,\n            {num_unique_ids, unique_embeddings, embedding_grad, JUST(learning_rate),\n             JUST(down_scale_by_tensor), JUST(skip_if)},\n            attrs);\n      } else {\n        CHECK(!down_scale_by_tensor);\n        CHECK(!skip_if);\n        return OpInterpUtil::Dispatch<Tensor>(*momentum_no_optional_input_op_,\n                                              {num_unique_ids, unique_embeddings, embedding_grad},\n                                              attrs);\n      }\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> sgd_no_optional_input_op_;\n  std::shared_ptr<OpExpr> sgd_op_;\n  std::shared_ptr<OpExpr> momentum_no_optional_input_op_;\n  std::shared_ptr<OpExpr> momentum_op_;\n};\n\n
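// A hedged reference sketch of the standard Adam update the dispatched kernels are expected\n// to implement (the actual math lives in the one_embedding_adam_update kernel):\n//   m = beta1 * m + (1 - beta1) * grad\n//   v = beta2 * v + (1 - beta2) * grad * grad\n//   if do_bias_correction: m_hat = m / bias_correction1, v_hat = v / bias_correction2\n//   embedding -= learning_rate * m_hat / (sqrt(v_hat) + epsilon)\n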
class OneEmbeddingAdamUpdateFunctor {\n public:\n  OneEmbeddingAdamUpdateFunctor() {\n    // This functor is only used in one_embedding eager mode, with the lr passed by attr and no\n    // optional inputs; we also define a functor with all optional inputs just for unittest. When\n    // the optional learning_rate tensor is passed in, we assume all optional inputs are not None\n    // and check them.\n    no_optional_input_op_ = CHECK_JUST(one::OpBuilder(\"one_embedding_adam_update\")\n                                           .Input(\"num_unique_ids\")\n                                           .Input(\"unique_embeddings\")\n                                           .Input(\"embedding_grad\")\n                                           .Output(\"updated_unique_embeddings\")\n                                           .Build());\n    // This functor is just for unittest\n    no_bias_correction_op_ = CHECK_JUST(one::OpBuilder(\"one_embedding_adam_update\")\n                                            .Input(\"num_unique_ids\")\n                                            .Input(\"unique_embeddings\")\n                                            .Input(\"embedding_grad\")\n                                            .Input(\"learning_rate\")\n                                            .Input(\"down_scale_by_tensor\")\n                                            .Input(\"skip_if\")\n                                            .Output(\"updated_unique_embeddings\")\n                                            .Build());\n    do_bias_correction_op_ = CHECK_JUST(one::OpBuilder(\"one_embedding_adam_update\")\n                                            .Input(\"num_unique_ids\")\n                                            .Input(\"unique_embeddings\")\n                                            .Input(\"embedding_grad\")\n                                            .Input(\"learning_rate\")\n                                            .Input(\"down_scale_by_tensor\")\n                                            .Input(\"skip_if\")\n                                            .Input(\"bias_correction1\")\n                                            .Input(\"bias_correction2\")\n                                            .Output(\"updated_unique_embeddings\")\n                                            .Build());\n  }\n\n  Maybe<Tensor> operator()(\n      const std::shared_ptr<one::Tensor>& num_unique_ids,\n      const std::shared_ptr<one::Tensor>& unique_embeddings,\n      const std::shared_ptr<one::Tensor>& embedding_grad,\n      const Optional<one::Tensor>& learning_rate, const Optional<one::Tensor>& down_scale_by_tensor,\n      const Optional<one::Tensor>& skip_if, const Optional<one::Tensor>& bias_correction1,\n      const Optional<one::Tensor>& bias_correction2, const float learning_rate_val,\n      const double scale, const float weight_decay, const float beta1, const float beta2,\n      const float& bias_correction1_val, const float& bias_correction2_val, const float epsilon,\n      const bool do_bias_correction, const int64_t line_size, const int64_t embedding_size,\n      const std::string& embedding_name) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\n        \"learning_rate_val\", \"scale\", \"weight_decay\", \"beta1\", \"beta2\", \"epsilon\",\n        \"bias_correction1_val\", \"bias_correction2_val\", \"do_bias_correction\", \"line_size\",\n        \"embedding_size\", \"embedding_name\");\n    attrs.SetAllAttrs(learning_rate_val, scale, weight_decay, beta1, beta2, epsilon,\n                      bias_correction1_val, bias_correction2_val, do_bias_correction, line_size,\n                      embedding_size, embedding_name);\n    if (learning_rate) {\n      CHECK(down_scale_by_tensor);\n      CHECK(skip_if);\n      if 
(do_bias_correction) {\n        CHECK(bias_correction1);\n        CHECK(bias_correction2);\n        return OpInterpUtil::Dispatch<Tensor>(\n            *do_bias_correction_op_,\n            {num_unique_ids, unique_embeddings, embedding_grad, JUST(learning_rate),\n             JUST(down_scale_by_tensor), JUST(skip_if), JUST(bias_correction1),\n             JUST(bias_correction2)},\n            attrs);\n      } else {\n        return OpInterpUtil::Dispatch<Tensor>(\n            *no_bias_correction_op_,\n            {num_unique_ids, unique_embeddings, embedding_grad, JUST(learning_rate),\n             JUST(down_scale_by_tensor), JUST(skip_if)},\n            attrs);\n      }\n    } else {\n      CHECK(!down_scale_by_tensor);\n      CHECK(!skip_if);\n      CHECK(!bias_correction1);\n      CHECK(!bias_correction2);\n      return OpInterpUtil::Dispatch<Tensor>(\n          *no_optional_input_op_, {num_unique_ids, unique_embeddings, embedding_grad}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> no_bias_correction_op_;\n  std::shared_ptr<OpExpr> do_bias_correction_op_;\n  std::shared_ptr<OpExpr> no_optional_input_op_;\n};\n\nclass OneEmbeddingAdagradUpdateFunctor {\n public:\n  OneEmbeddingAdagradUpdateFunctor() {\n    // This functor is only used in one_embedding eager mode, with the lr passed by attr and no\n    // optional inputs; we also define a functor with all optional inputs just for unittest. When\n    // the optional learning_rate tensor is passed in, we assume all optional inputs are not None\n    // and check them.\n    op_no_optional_input_ = CHECK_JUST(one::OpBuilder(\"one_embedding_adagrad_update\")\n                                           .Input(\"num_unique_ids\")\n                                           .Input(\"unique_embeddings\")\n                                           .Input(\"embedding_grad\")\n                                           .Output(\"updated_unique_embeddings\")\n                                           .Build());\n    // This functor is just for unittest\n    op_ = CHECK_JUST(one::OpBuilder(\"one_embedding_adagrad_update\")\n                         .Input(\"num_unique_ids\")\n                         .Input(\"unique_embeddings\")\n                         .Input(\"embedding_grad\")\n                         .Input(\"learning_rate\")\n                         .Input(\"down_scale_by_tensor\")\n                         .Input(\"skip_if\")\n                         .Input(\"train_step\")\n                         .Output(\"updated_unique_embeddings\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& num_unique_ids,\n                           const std::shared_ptr<one::Tensor>& unique_embeddings,\n                           const std::shared_ptr<one::Tensor>& embedding_grad,\n                           const Optional<one::Tensor>& learning_rate,\n                           const Optional<one::Tensor>& down_scale_by_tensor,\n                           const Optional<one::Tensor>& skip_if,\n                           const Optional<one::Tensor>& train_step, const int64_t train_step_val,\n                           const float learning_rate_val, const double scale,\n                           const float weight_decay, const float lr_decay, const float epsilon,\n                           const int64_t line_size, const int64_t embedding_size,\n                           const std::string& embedding_name) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"train_step_val\", 
\"learning_rate_val\", \"scale\",\n                                                 \"weight_decay\", \"lr_decay\", \"epsilon\", \"line_size\",\n                                                 \"embedding_size\", \"embedding_name\");\n    attrs.SetAllAttrs(train_step_val, learning_rate_val, scale, weight_decay, lr_decay, epsilon,\n                      line_size, embedding_size, embedding_name);\n    if (learning_rate) {\n      CHECK(down_scale_by_tensor);\n      CHECK(skip_if);\n      CHECK(train_step);\n      return OpInterpUtil::Dispatch<Tensor>(\n          *op_,\n          {num_unique_ids, unique_embeddings, embedding_grad, JUST(learning_rate),\n           JUST(down_scale_by_tensor), JUST(skip_if), JUST(train_step)},\n          attrs);\n    } else {\n      CHECK(!down_scale_by_tensor);\n      CHECK(!skip_if);\n      CHECK(!train_step);\n      return OpInterpUtil::Dispatch<Tensor>(\n          *op_no_optional_input_, {num_unique_ids, unique_embeddings, embedding_grad}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> op_no_optional_input_;\n};\n\nclass OneEmbeddingFtrlUpdateFunctor {\n public:\n  OneEmbeddingFtrlUpdateFunctor() {\n    // This functor is only used in one_embedding eager mode with lr passed by attr and no optional\n    // input, we also define functor with all optional input just for unittest. when the optional\n    // input learning_rate tensor has passed in, we think all optional input are not None and check\n    // them.\n    op_no_optional_input_ = CHECK_JUST(one::OpBuilder(\"one_embedding_ftrl_update\")\n                                           .Input(\"num_unique_ids\")\n                                           .Input(\"unique_embeddings\")\n                                           .Input(\"embedding_grad\")\n                                           .Output(\"updated_unique_embeddings\")\n                                           .Build());\n    // This functor is just for unittest\n    op_ = CHECK_JUST(one::OpBuilder(\"one_embedding_ftrl_update\")\n                         .Input(\"num_unique_ids\")\n                         .Input(\"unique_embeddings\")\n                         .Input(\"embedding_grad\")\n                         .Input(\"learning_rate\")\n                         .Input(\"down_scale_by_tensor\")\n                         .Input(\"skip_if\")\n                         .Output(\"updated_unique_embeddings\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& num_unique_ids,\n                           const std::shared_ptr<one::Tensor>& unique_embeddings,\n                           const std::shared_ptr<one::Tensor>& embedding_grad,\n                           const Optional<one::Tensor>& learning_rate,\n                           const Optional<one::Tensor>& down_scale_by_tensor,\n                           const Optional<one::Tensor>& skip_if, const float learning_rate_val,\n                           const double scale, const float weight_decay, const float lr_power,\n                           const float lambda1, const float lambda2, const float beta,\n                           const int64_t line_size, const int64_t embedding_size,\n                           const std::string& embedding_name) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"learning_rate_val\", \"scale\", \"weight_decay\",\n                                                 \"lr_power\", \"lambda1\", \"lambda2\", \"beta\",\n                                    
             \"line_size\", \"embedding_size\", \"embedding_name\");\n    attrs.SetAllAttrs(learning_rate_val, scale, weight_decay, lr_power, lambda1, lambda2, beta,\n                      line_size, embedding_size, embedding_name);\n    if (learning_rate) {\n      CHECK(down_scale_by_tensor);\n      CHECK(skip_if);\n      return OpInterpUtil::Dispatch<Tensor>(\n          *op_,\n          {num_unique_ids, unique_embeddings, embedding_grad, JUST(learning_rate),\n           JUST(down_scale_by_tensor), JUST(skip_if)},\n          attrs);\n    } else {\n      CHECK(!down_scale_by_tensor);\n      CHECK(!skip_if);\n      return OpInterpUtil::Dispatch<Tensor>(\n          *op_no_optional_input_, {num_unique_ids, unique_embeddings, embedding_grad}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> op_no_optional_input_;\n};\n\nclass DeformConv2dFunctor {\n public:\n  DeformConv2dFunctor() {\n    bias_op_ = CHECK_JUST(one::OpBuilder(\"bias_add\").Input(\"a\").Input(\"b\").Output(\"out\").Build());\n    deformconv2d_op_ = CHECK_JUST(one::OpBuilder(\"deform_conv2d\")\n                                      .Input(\"input\")\n                                      .Input(\"weight\")\n                                      .Input(\"offset\")\n                                      .Input(\"mask\")\n                                      .Output(\"output\")\n                                      .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& weight,\n                           const std::shared_ptr<one::Tensor>& offset,\n                           const std::shared_ptr<one::Tensor>& mask,\n                           const Optional<one::Tensor>& bias, const int32_t& stride_h,\n                           const int32_t& stride_w, const int32_t& pad_h, const int32_t& pad_w,\n                           const int32_t& dilation_h, const int32_t& dilation_w,\n                           const int32_t& groups, const int32_t& offset_groups,\n                           const bool& use_mask) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"stride_h\", \"stride_w\", \"pad_h\", \"pad_w\", \"dilation_h\",\n                                       \"dilation_w\", \"groups\", \"offset_groups\", \"use_mask\");\n    attrs.SetAllAttrs(stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, groups,\n                      offset_groups, use_mask);\n    const std::shared_ptr<one::Tensor>& deformconv2d_out = JUST(\n        OpInterpUtil::Dispatch<Tensor>(*deformconv2d_op_, {input, weight, offset, mask}, attrs));\n    if (bias) {\n      auto bias_shape = JUST(bias)->shape();\n      auto& bias_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\");\n      bias_attrs.SetAllAttrs(static_cast<int32_t>(1));\n      return OpInterpUtil::Dispatch<Tensor>(*bias_op_, {deformconv2d_out, JUST(bias)}, bias_attrs);\n    }\n    return deformconv2d_out;\n  }\n\n private:\n  std::shared_ptr<OpExpr> deformconv2d_op_;\n  std::shared_ptr<OpExpr> bias_op_;\n};\n\nclass RocAucScoreFunctor {\n public:\n  RocAucScoreFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"roc_auc_score\").Input(\"label\").Input(\"pred\").Output(\"out\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& label,\n                           const std::shared_ptr<one::Tensor>& pred) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {label, pred});\n  }\n\n private:\n  
std::shared_ptr<OpExpr> op_;\n};\n\nclass MultiTensorSgdUpdateFunctor {\n public:\n  MultiTensorSgdUpdateFunctor() {\n    op_.resize(kMaxInputCount /*the maximum number of inputs*/);\n    for (int n = 0; n < op_.size(); ++n) {\n      op_[n] = CHECK_JUST(one::OpBuilder(\"multi_tensor_sgd_update\")\n                              .Input(\"model\", n + 1)\n                              .Input(\"model_diff\", n + 1)\n                              .Build());\n    }\n  }\n\n  Maybe<void> operator()(const TensorTuple& model, const TensorTuple& model_diff,\n                         const double& scale, const float& weight_decay,\n                         const float& learning_rate_val) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"scale\", \"weight_decay\", \"learning_rate_val\");\n    attrs.SetAllAttrs(scale, weight_decay, learning_rate_val);\n    const int64_t weight_size = model.size();\n    for (int i = 0; i < weight_size; i += kMaxInputCount) {\n      size_t size = (i + kMaxInputCount) < weight_size ? kMaxInputCount : weight_size - i;\n      TensorTuple input(2 * size);\n      std::copy(model.begin() + i, model.begin() + i + size, input.begin());\n      std::copy(model_diff.begin() + i, model_diff.begin() + i + size, input.begin() + size);\n      JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_[size - 1], input, attrs));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  std::vector<std::shared_ptr<OpExpr>> op_;\n};\n\nclass MultiTensorMomentumUpdateFunctor {\n public:\n  MultiTensorMomentumUpdateFunctor() {\n    op_.resize(kMaxInputCount /*the maximum number of inputs*/);\n    for (int n = 0; n < op_.size(); ++n) {\n      op_[n] = CHECK_JUST(one::OpBuilder(\"multi_tensor_momentum_update\")\n                              .Input(\"model\", n + 1)\n                              .Input(\"model_diff\", n + 1)\n                              .Input(\"momentum_buf\", n + 1)\n                              .Build());\n    }\n  }\n\n  Maybe<void> operator()(const TensorTuple& model, const TensorTuple& model_diff,\n                         const TensorTuple& momentum_buf, const double& scale,\n                         const float& weight_decay, const float& learning_rate_val,\n                         const float& momentum, const float& dampening, const bool& nesterov,\n                         const bool& maximize) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"scale\", \"weight_decay\", \"learning_rate_val\",\n                                                 \"momentum\", \"dampening\", \"nesterov\", \"maximize\");\n    attrs.SetAllAttrs(scale, weight_decay, learning_rate_val, momentum, dampening, nesterov,\n                      maximize);\n    const int64_t weight_size = model.size();\n    for (int i = 0; i < weight_size; i += kMaxInputCount) {\n      size_t size = (i + kMaxInputCount) < weight_size ? 
kMaxInputCount : weight_size - i;\n      TensorTuple input(3 * size);\n      std::copy(model.begin() + i, model.begin() + i + size, input.begin());\n      std::copy(model_diff.begin() + i, model_diff.begin() + i + size, input.begin() + size);\n      std::copy(momentum_buf.begin() + i, momentum_buf.begin() + i + size,\n                input.begin() + 2 * size);\n      JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_[size - 1], input, attrs));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  std::vector<std::shared_ptr<OpExpr>> op_;\n};\n\nclass MultiTensorAdamUpdateFunctor {\n public:\n  MultiTensorAdamUpdateFunctor() {\n    op_.resize(kMaxInputCount /*the maximum number of inputs*/);\n    for (int n = 0; n < op_.size(); ++n) {\n      op_[n] = CHECK_JUST(one::OpBuilder(\"multi_tensor_adam_update\")\n                              .Input(\"model\", n + 1)\n                              .Input(\"model_diff\", n + 1)\n                              .Input(\"m\", n + 1)\n                              .Input(\"v\", n + 1)\n                              .Build());\n    }\n  }\n\n  Maybe<void> operator()(const TensorTuple& model, const TensorTuple& model_diff,\n                         const TensorTuple& m, const TensorTuple& v, const float& learning_rate_val,\n                         const float& l2, const float& beta1, const float& beta2,\n                         const float& bias_correction1_val, const float& bias_correction2_val,\n                         const bool& do_bias_correction, const double& scale,\n                         const float& weight_decay, const float& epsilon) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\n        \"scale\", \"weight_decay\", \"beta1\", \"beta2\", \"bias_correction1_val\", \"bias_correction2_val\",\n        \"do_bias_correction\", \"learning_rate_val\", \"l2\", \"epsilon\");\n    attrs.SetAllAttrs(scale, weight_decay, beta1, beta2, bias_correction1_val, bias_correction2_val,\n                      do_bias_correction, learning_rate_val, l2, epsilon);\n\n    const int64_t weight_size = model.size();\n    for (int i = 0; i < weight_size; i += kMaxInputCount) {\n      size_t size = (i + kMaxInputCount) < weight_size ? 
kMaxInputCount : weight_size - i;\n      TensorTuple input(4 * size);\n      std::copy(model.begin() + i, model.begin() + i + size, input.begin());\n      std::copy(model_diff.begin() + i, model_diff.begin() + i + size, input.begin() + size);\n      std::copy(m.begin() + i, m.begin() + i + size, input.begin() + 2 * size);\n      std::copy(v.begin() + i, v.begin() + i + size, input.begin() + 3 * size);\n      JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_[size - 1], input, attrs));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  std::vector<std::shared_ptr<OpExpr>> op_;\n};\n\nclass MatrixVectorProductFunctor {\n public:\n  MatrixVectorProductFunctor() {\n    matrix_vector_product_op_ = CHECK_JUST(\n        one::OpBuilder(\"matrix_vector_product\").Input(\"a\").Input(\"b\").Output(\"out\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& vec) const {\n    const auto& input_shape = input->shape();\n    const auto& vec_shape = vec->shape();\n    CHECK_OR_RETURN(input_shape->NumAxes() == 2 && vec_shape->NumAxes() == 1)\n        << Error::RuntimeError() << \"vector + matrix @ vector expected, got \"\n        << \"1, \" << input_shape->NumAxes() << \", \" << vec_shape->NumAxes();\n    CHECK_EQ_OR_RETURN(input_shape->at(1), vec_shape->at(0))\n        << Error::RuntimeError() << \"size mismatch, got \" << std::to_string(input_shape->at(0))\n        << \", \" << std::to_string(input_shape->at(0)) << \"x\" << std::to_string(input_shape->at(1))\n        << \", \" << std::to_string(vec_shape->at(0));\n    return OpInterpUtil::Dispatch<Tensor>(*matrix_vector_product_op_, {input, vec});\n  }\n\n private:\n  std::shared_ptr<OpExpr> matrix_vector_product_op_;\n};\n\nclass BatchNormStatsFunctor {\n public:\n  BatchNormStatsFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"batch_norm_stats\").Input(\"input\").Output(\"mean\").Output(\"invstd\").Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input, const int& axis,\n                                const float& eps) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"eps\");\n    attrs.SetAllAttrs(axis, eps);\n    return OpInterpUtil::Dispatch<one::TensorTuple>(*op_, {input}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass BatchNormGatherStatsWithCountsFunctor {\n public:\n  BatchNormGatherStatsWithCountsFunctor() {\n    op_with_running_mean_and_var_ = CHECK_JUST(one::OpBuilder(\"batch_norm_gather_stats_with_counts\")\n                                                   .Input(\"input\")\n                                                   .Input(\"mean\")\n                                                   .Input(\"invstd\")\n                                                   .Input(\"counts\")\n                                                   .Input(\"running_mean\")\n                                                   .Input(\"running_var\")\n                                                   .Output(\"global_mean\")\n                                                   .Output(\"global_invstd\")\n                                                   .Build());\n    op_without_running_mean_and_var_ =\n        CHECK_JUST(one::OpBuilder(\"batch_norm_gather_stats_with_counts\")\n                       .Input(\"input\")\n                       .Input(\"mean\")\n                       .Input(\"invstd\")\n                       .Input(\"counts\")\n                     
  .Output(\"global_mean\")\n                       .Output(\"global_invstd\")\n                       .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,\n                                const std::shared_ptr<one::Tensor>& mean,\n                                const std::shared_ptr<one::Tensor>& invstd,\n                                const Optional<one::Tensor>& running_mean,\n                                const Optional<one::Tensor>& running_var, const float& momentum,\n                                const float& eps,\n                                const std::shared_ptr<one::Tensor>& counts) const {\n    CHECK_OR_RETURN((running_mean && running_var) || (!running_mean && !running_var))\n        << Error::RuntimeError()\n        << \"Both running_mean and running_var should be None or Tensor at the same time.\";\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"eps\", \"momentum\");\n    attrs.SetAllAttrs(eps, momentum);\n\n    if (running_mean) {\n      return OpInterpUtil::Dispatch<one::TensorTuple>(\n          *op_with_running_mean_and_var_,\n          {input, mean, invstd, counts, JUST(running_mean), JUST(running_var)}, attrs);\n    }\n    return OpInterpUtil::Dispatch<one::TensorTuple>(*op_without_running_mean_and_var_,\n                                                    {input, mean, invstd, counts}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_with_running_mean_and_var_;\n  std::shared_ptr<OpExpr> op_without_running_mean_and_var_;\n};\n\nclass BatchNormElemtFunctor {\n public:\n  BatchNormElemtFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"batch_norm_elemt\")\n                         .Input(\"input\")\n                         .Input(\"weight\")\n                         .Input(\"bias\")\n                         .Input(\"mean\")\n                         .Input(\"invstd\")\n                         .Output(\"output\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& weight,\n                           const std::shared_ptr<one::Tensor>& bias,\n                           const std::shared_ptr<one::Tensor>& mean,\n                           const std::shared_ptr<one::Tensor>& invstd, const int& axis,\n                           const float& eps) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"eps\");\n    attrs.SetAllAttrs(axis, eps);\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {input, weight, bias, mean, invstd}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass BatchNormBackwardReduceFunctor {\n public:\n  BatchNormBackwardReduceFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"batch_norm_backward_reduce\")\n                         .Input(\"grad_out\")\n                         .Input(\"input\")\n                         .Input(\"mean\")\n                         .Input(\"invstd\")\n                         .Output(\"sum_dy\")\n                         .Output(\"sum_dy_xmu\")\n                         .Output(\"grad_weight\")\n                         .Output(\"grad_bias\")\n                         .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& grad_out,\n                                const std::shared_ptr<one::Tensor>& input,\n                                const std::shared_ptr<one::Tensor>& mean,\n                                const std::shared_ptr<one::Tensor>& invstd, const int& 
axis) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\");\n    attrs.SetAllAttrs(axis);\n    return OpInterpUtil::Dispatch<one::TensorTuple>(*op_, {grad_out, input, mean, invstd}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass BatchNormBackwardElemtFunctor {\n public:\n  BatchNormBackwardElemtFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"batch_norm_backward_elemt\")\n                         .Input(\"grad_out\")\n                         .Input(\"input\")\n                         .Input(\"mean\")\n                         .Input(\"invstd\")\n                         .Input(\"weight\")\n                         .Input(\"sum_dy\")\n                         .Input(\"sum_dy_xmu\")\n                         .Input(\"count\")\n                         .Output(\"grad_in\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& grad_out,\n                           const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& mean,\n                           const std::shared_ptr<one::Tensor>& invstd,\n                           const std::shared_ptr<one::Tensor>& weight,\n                           const std::shared_ptr<one::Tensor>& sum_dy,\n                           const std::shared_ptr<one::Tensor>& sum_dy_xmu,\n                           const std::shared_ptr<one::Tensor>& count, const int& axis) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\");\n    attrs.SetAllAttrs(axis);\n    return OpInterpUtil::Dispatch<one::Tensor>(\n        *op_, {grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedFastGeluMulFunctor {\n public:\n  FusedFastGeluMulFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_fast_gelu_mul\")\n                         .Input(\"in\")\n                         .Input(\"multiplier\")\n                         .Output(\"out\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& multiplier) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, multiplier});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedFastGeluMulGradFunctor {\n public:\n  FusedFastGeluMulGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_fast_gelu_mul_grad\")\n                         .Input(\"out_diff\")\n                         .Input(\"in\")\n                         .Input(\"multiplier\")\n                         .Output(\"in_diff\")\n                         .Output(\"multiplier_diff\")\n                         .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,\n                                const std::shared_ptr<one::Tensor>& x,\n                                const std::shared_ptr<one::Tensor>& multiplier) const {\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, multiplier});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GroupedMatmulBiasFunctor {\n public:\n  GroupedMatmulBiasFunctor() {\n    fused_op_.resize(kMaxInputCount /*the maximum number of inputs*/);\n    for (int n = 1; n < fused_op_.size(); ++n) {\n      fused_op_[n] = CHECK_JUST(one::OpBuilder(\"grouped_matmul_bias\")\n                                    .Input(\"xs\", n)\n                                    .Input(\"weights\", n)\n                             
       .Input(\"biases\", n)\n                                    .Output(\"ys\", n)\n                                    .Build());\n    }\n  }\n  Maybe<TensorTuple> operator()(const TensorTuple& xs, const TensorTuple& weights,\n                                const TensorTuple& biases) const {\n    const int64_t input_size = xs.size();\n    const int64_t weight_size = weights.size();\n    const int64_t bias_size = biases.size();\n    CHECK_GE_OR_RETURN(input_size, 1)\n        << Error::RuntimeError() << \"The number of xs should be greater equal than 1.\";\n    CHECK_EQ_OR_RETURN(weight_size, input_size)\n        << Error::RuntimeError() << \"The number of weights should be equal to xs.\";\n    CHECK_EQ_OR_RETURN(bias_size, input_size)\n        << Error::RuntimeError() << \"The number of bias should be equal to xs.\";\n    for (int64_t i = 0; i < input_size; ++i) {\n      const auto& input_shape = xs[i]->shape();\n      const auto& weight_shape = weights[i]->shape();\n      const auto& bias_shape = biases[i]->shape();\n      CHECK_GE_OR_RETURN(input_shape->NumAxes(), 2)\n          << Error::RuntimeError() << \"x's dim size should greater equal than 2.\";\n      CHECK_EQ_OR_RETURN(weight_shape->NumAxes(), 2)\n          << Error::RuntimeError() << \"Weight's dim size should == 2\";\n      CHECK_EQ_OR_RETURN(bias_shape->NumAxes(), 1)\n          << Error::RuntimeError() << \"Bias's dim size should == 1\";\n      const int64_t k = input_shape->At(input_shape->NumAxes() - 1);\n      CHECK_EQ_OR_RETURN(weight_shape->At(1), k)\n          << Error::RuntimeError() << \"weight's second dim should be equal to input's last dim. \";\n      const int64_t n = weight_shape->At(0);\n      CHECK_EQ_OR_RETURN(bias_shape->At(0), n)\n          << Error::RuntimeError() << \"Bias's dim is not equal to weight's first dim. 
\";\n    }\n    TensorTuple input(3 * input_size);\n    std::copy(xs.begin(), xs.end(), input.begin() + 0 * input_size);\n    std::copy(weights.begin(), weights.end(), input.begin() + 1 * input_size);\n    std::copy(biases.begin(), biases.end(), input.begin() + 2 * input_size);\n    return OpInterpUtil::Dispatch<TensorTuple>(*fused_op_[input_size], input);\n  }\n\n private:\n  std::vector<std::shared_ptr<OpExpr>> fused_op_;\n};\n\nclass GroupedMatmulFunctor {\n public:\n  GroupedMatmulFunctor() {\n    fused_op_.resize(kMaxInputCount /*the maximum number of inputs*/);\n    for (int n = 1; n < fused_op_.size(); ++n) {\n      fused_op_[n] = CHECK_JUST(one::OpBuilder(\"grouped_matmul_bias\")\n                                    .Input(\"xs\", n)\n                                    .Input(\"weights\", n)\n                                    .Output(\"ys\", n)\n                                    .Build());\n    }\n  }\n  Maybe<TensorTuple> operator()(const TensorTuple& xs, const TensorTuple& weights) const {\n    const int64_t input_size = xs.size();\n    const int64_t weight_size = weights.size();\n    CHECK_LT_OR_RETURN(input_size, kMaxInputCount)\n        << Error::RuntimeError() << \"input_size size should not be greater than 128\";\n    CHECK_GE_OR_RETURN(input_size, 1)\n        << Error::RuntimeError() << \"The number of xs should be greater equal than 1.\";\n    CHECK_EQ_OR_RETURN(weight_size, input_size)\n        << Error::RuntimeError() << \"The number of weights should be equal to xs.\";\n    for (int64_t i = 0; i < input_size; ++i) {\n      const auto& input_shape = xs[i]->shape();\n      const auto& weight_shape = weights[i]->shape();\n      CHECK_GE_OR_RETURN(input_shape->NumAxes(), 2)\n          << Error::RuntimeError() << \"x's dim size should greater equal than 2.\";\n      CHECK_EQ_OR_RETURN(weight_shape->NumAxes(), 2)\n          << Error::RuntimeError() << \"Weight's dim size should == 2\";\n      const int64_t k = input_shape->At(input_shape->NumAxes() - 1);\n      CHECK_EQ_OR_RETURN(weight_shape->At(1), k)\n          << Error::RuntimeError() << \"weight's second dim should be equal to input's last dim. \";\n    }\n    TensorTuple input(2 * input_size);\n    std::copy(xs.begin(), xs.end(), input.begin() + 0 * input_size);\n    std::copy(weights.begin(), weights.end(), input.begin() + 1 * input_size);\n    return OpInterpUtil::Dispatch<TensorTuple>(*fused_op_[input_size], input);\n  }\n\n private:\n  std::vector<std::shared_ptr<OpExpr>> fused_op_;\n};\n\nclass MultiTensorYoloV5WeightUpdateFunctor {\n public:\n  MultiTensorYoloV5WeightUpdateFunctor() {\n    op_.resize(kMaxInputCount /*the maximum number of inputs*/);\n    for (int n = 0; n < op_.size(); ++n) {\n      op_[n] = CHECK_JUST(one::OpBuilder(\"multi_tensor_yolov5_weight_update\")\n                              .Input(\"model\", n + 1)\n                              .Input(\"model_update\", n + 1)\n                              .Build());\n    }\n  }\n\n  Maybe<void> operator()(const TensorTuple& model, const TensorTuple& model_update,\n                         const float& d) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"d\");\n    attrs.SetAllAttrs(d);\n    const int64_t weight_size = model.size();\n    for (int i = 0; i < weight_size; i += kMaxInputCount) {\n      size_t size = (i + kMaxInputCount) < weight_size ? 
for (int i = 0; i < weight_size; i += kMaxInputCount) {\n      size_t size = (i + kMaxInputCount) < weight_size ? kMaxInputCount : weight_size - i;\n      TensorTuple input(size * 2);\n      std::copy(model.begin() + i, model.begin() + i + size, input.begin());\n      std::copy(model_update.begin() + i, model_update.begin() + i + size,\n                input.begin() + 1 * size);\n      JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_[size - 1], input, attrs));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  std::vector<std::shared_ptr<OpExpr>> op_;\n};\n\nclass FusedScaleMaskBiasSoftmaxFunctor {\n public:\n  FusedScaleMaskBiasSoftmaxFunctor() {\n    op_with_bias_ = CHECK_JUST(one::OpBuilder(\"fused_scale_mask_bias_softmax\")\n                                   .Input(\"x\")\n                                   .Input(\"mask\")\n                                   .Input(\"bias\")\n                                   .Output(\"out\")\n                                   .Build());\n    op_without_bias_ = CHECK_JUST(one::OpBuilder(\"fused_scale_mask_bias_softmax\")\n                                      .Input(\"x\")\n                                      .Input(\"mask\")\n                                      .Output(\"out\")\n                                      .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& mask,\n                           const Optional<one::Tensor>& bias, const float& scale,\n                           const bool& inplace = false) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"scale\", \"inplace\");\n    attrs.SetAllAttrs(scale, inplace);\n    if (bias) {\n      if (inplace) {\n        std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n        outputs->at(0) = x;\n        JUST(OpInterpUtil::Dispatch(*op_with_bias_, {x, mask, JUST(bias)}, outputs.get(), attrs));\n        return outputs->at(0);\n      }\n      return OpInterpUtil::Dispatch<Tensor>(*op_with_bias_, {x, mask, JUST(bias)}, attrs);\n    }\n    if (inplace) {\n      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n      outputs->at(0) = x;\n      JUST(OpInterpUtil::Dispatch(*op_without_bias_, {x, mask}, outputs.get(), attrs));\n      return outputs->at(0);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_without_bias_, {x, mask}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_without_bias_;\n  std::shared_ptr<OpExpr> op_with_bias_;\n};\n\nclass FusedScaleMaskBiasSoftmaxGradFunctor {\n public:\n  FusedScaleMaskBiasSoftmaxGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_scale_mask_bias_softmax_grad\")\n                         .Input(\"y\")\n                         .Input(\"dy\")\n                         .Output(\"dx\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& y,\n                           const std::shared_ptr<one::Tensor>& dy, const float& scale) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"scale\");\n    attrs.SetAllAttrs(scale);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {y, dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedClipGradFunctor {\n public:\n  FusedClipGradFunctor() {\n    op_.resize(kMaxInputCount /*the maximum number of inputs*/);\n    for (int n = 0; n < op_.size(); ++n) {\n      op_[n] = CHECK_JUST(\n          one::OpBuilder(\"fused_clip_grad\").Input(\"model_diff\", n + 1).Output(\"out\").Build());\n    }\n  }\n\n  
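// Fused gradient clipping: handles up to kMaxInputCount gradient tensors in one op dispatch.\n  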
Maybe<Tensor> operator()(const TensorTuple& model_diff, const float& max_norm,\n                           const float& norm_type) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"max_norm\", \"norm_type\");\n    attrs.SetAllAttrs(max_norm, norm_type);\n    const int64_t input_size = model_diff.size();\n    CHECK_LE_OR_RETURN(input_size, kMaxInputCount)\n        << Error::RuntimeError() << \"The number of model_diff tensors should not be greater than 128.\";\n    return JUST(OpInterpUtil::Dispatch<Tensor>(*op_[input_size - 1], model_diff, attrs));\n  }\n\n private:\n  std::vector<std::shared_ptr<OpExpr>> op_;\n};\n\nclass NonContiguousBinaryOpFunctor {\n public:\n  NonContiguousBinaryOpFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"noncontiguous_binary_op\").Input(\"lhs\").Input(\"rhs\").Output(\"y\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& lhs, const std::shared_ptr<Tensor>& rhs,\n                           const std::string& op, const bool& inplace = false) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"op\", \"inplace\");\n    attrs.SetAllAttrs(op, inplace);\n    if (inplace) {\n      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n      outputs->at(0) = lhs;\n      JUST(OpInterpUtil::Dispatch(*op_, {lhs, rhs}, outputs.get(), attrs));\n      return outputs->at(0);\n    }\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {lhs, rhs}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass NonContiguousBinaryOpGradFunctor {\n public:\n  NonContiguousBinaryOpGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"noncontiguous_binary_op_grad\")\n                         .Input(\"dy\")\n                         .Input(\"lhs\")\n                         .Input(\"rhs\")\n                         .Output(\"dlhs\")\n                         .Output(\"drhs\")\n                         .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<Tensor>& dy,\n                                const std::shared_ptr<Tensor>& lhs,\n                                const std::shared_ptr<Tensor>& rhs, const std::string& op,\n                                const bool& inplace = false) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"op\", \"inplace\");\n    attrs.SetAllAttrs(op, inplace);\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, lhs, rhs}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nnamespace {\n\ntemplate<int alignment_size>\nMaybe<one::Tensor> pad_last_dim(const std::shared_ptr<one::Tensor>& input) {\n  auto num_dims = input->shape()->NumAxes();\n  auto last_dim_size = input->shape()->At(num_dims - 1);\n  if (last_dim_size % alignment_size == 0) { return input; }\n  auto pad_count = alignment_size - (last_dim_size % alignment_size);\n\n  return JUST(functional::Pad(input, {0, pad_count}, \"constant\", Scalar(0)));\n}\n\n}  // namespace\n\nclass ScaledDotProductFlashAttentionFunctor {\n public:\n  ScaledDotProductFlashAttentionFunctor() {\n#if CUDA_VERSION >= 11070\n    op_ = CHECK_JUST(one::OpBuilder(\"scaled_dot_product_flash_attention\")\n                         .Input(\"query\")\n                         .Input(\"key\")\n                         .Input(\"value\")\n                         .Output(\"out\")\n                         .Output(\"softmax_lse\")\n                         .Output(\"rng_state\")\n                         .Build());\n#endif  // CUDA_VERSION >= 11070\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& query,\n                           const 
std::shared_ptr<one::Tensor>& key,\n                           const std::shared_ptr<one::Tensor>& value,\n                           const Optional<one::Tensor>& attn_mask, const float& dropout_p,\n                           const bool& is_causal, const Optional<float>& scale,\n                           const int64_t& seed = 0) const {\n#if CUDA_VERSION >= 11070\n    const auto og_size = query->shape()->At(3);\n    const auto batch_size = query->shape()->At(0);\n    const auto seqlen_q = query->shape()->At(2);\n    const auto num_heads = query->shape()->At(1);\n    const auto num_heads_k = key->shape()->At(1);\n    const auto max_seqlen_batch_k = key->shape()->At(2);\n    const auto max_seqlen_batch_v = value->shape()->At(2);\n\n    CHECK_EQ_OR_RETURN(batch_size, key->shape()->At(0))\n        << \" key has different batch size from query.\";\n    CHECK_EQ_OR_RETURN(batch_size, value->shape()->At(0))\n        << \" value has different batch size from query.\";\n    CHECK_EQ_OR_RETURN(num_heads_k, value->shape()->At(1))\n        << \" value has different num_heads from key.\";\n    CHECK_EQ_OR_RETURN(max_seqlen_batch_k, max_seqlen_batch_v)\n        << \"value has different seqlen from key.\";\n    CHECK_EQ_OR_RETURN(og_size, key->shape()->At(3)) << \" key has different head dims from query.\";\n    CHECK_EQ_OR_RETURN(og_size, value->shape()->At(3))\n        << \" value has different head dims from query.\";\n\n    // Query (Batch x Num_heads x Q_seq_len  x Dim_per_head)\n    // Key   (Batch x Num_heads x KV_seq_len x Dim_per_head)\n    // Value (Batch x Num_heads x KV_seq_len x Dim_per_head)\n    std::shared_ptr<Tensor> q_padded, k_padded, v_padded;\n    bool padded = og_size % 8;\n    if (padded) {\n      q_padded = JUST(pad_last_dim<8>(query));\n      k_padded = JUST(pad_last_dim<8>(key));\n      v_padded = JUST(pad_last_dim<8>(value));\n    } else {\n      q_padded = query;\n      k_padded = key;\n      v_padded = value;\n    }\n\n    auto q_ = JUST(functional::Transpose(q_padded, {0, 2, 1, 3}));\n    auto k_ = JUST(functional::Transpose(k_padded, {0, 2, 1, 3}));\n    auto v_ = JUST(functional::Transpose(v_padded, {0, 2, 1, 3}));\n    // Query -> Query(Batch x Q_seq_len  x Num_heads x Dim_per_head)\n    // Key   -> Key  (Batch x KV_seq_len x Num_heads x Dim_per_head)\n    // Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head)\n\n    const auto& scale_ =\n        scale.has_value() ? 
scale : (1.0f / std::sqrt(static_cast<float>(query->shape()->At(3))));\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"p_dropout\", \"softmax_scale\", \"is_causal\",\n                                                 \"window_size_left\", \"window_size_right\", \"seed\");\n    attrs.SetAllAttrs(dropout_p, scale_, is_causal, -1, -1, seed);\n\n    auto gen = JUST(one::DefaultAutoGenerator());\n    gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), query));\n    const auto& state = std::make_shared<ScaledDotProductFlashAttentionKernelState>(gen);\n    OpExprInterpContext ctx(attrs, state);\n\n    std::shared_ptr<one::Tensor> output_ =\n        JUST(OpInterpUtil::Dispatch<one::Tensor>(*op_, {q_, k_, v_}, ctx));\n\n    auto output_padded = JUST(functional::Transpose(output_, {0, 2, 1, 3}));\n\n    std::shared_ptr<Tensor> output;\n    if (padded) {\n      output =\n          JUST(functional::Slice(output_padded, {0, 0, 0, 0},\n                                 {batch_size, num_heads, seqlen_q, og_size}, {1, 1, 1, 1}, false));\n    } else {\n      output = output_padded;\n    }\n\n    return output;\n#endif  // CUDA_VERSION >= 11070\n\n    UNIMPLEMENTED_THEN_RETURN() << \"only support CUDA_VERSION >= 11070.\";\n  }\n\n private:\n#if CUDA_VERSION >= 11070\n  std::shared_ptr<OpExpr> op_;\n#endif  // CUDA_VERSION >= 11070\n};\n\nclass ScaledDotProductFlashAttentionGradFunctor {\n public:\n  ScaledDotProductFlashAttentionGradFunctor() {\n#if CUDA_VERSION >= 11070\n    op_ = CHECK_JUST(one::OpBuilder(\"scaled_dot_product_flash_attention_grad\")\n                         .Input(\"grad_out\")\n                         .Input(\"query\")\n                         .Input(\"key\")\n                         .Input(\"value\")\n                         .Input(\"out\")\n                         .Input(\"softmax_lse\")\n                         .Input(\"rng_state\")\n                         .Output(\"grad_q\")\n                         .Output(\"grad_k\")\n                         .Output(\"grad_v\")\n                         .Build());\n#endif\n  }\n\n  Maybe<TensorTuple> operator()(\n      const std::shared_ptr<one::Tensor>& grad_out, const std::shared_ptr<one::Tensor>& query,\n      const std::shared_ptr<one::Tensor>& key, const std::shared_ptr<one::Tensor>& value,\n      const std::shared_ptr<one::Tensor>& out, const std::shared_ptr<one::Tensor>& softmax_lse,\n      const std::shared_ptr<one::Tensor>& rng_state, const float& dropout_p, const bool& is_causal,\n      const float& scale) const {\n#if CUDA_VERSION >= 11070\n    // grad_out(batch x q_seq_len  x num_heads x head_size)\n    // query   (batch x q_seq_len  x num_heads x head_size_padded)\n    // key     (batch x kv_seq_len x num_heads_k x head_size_padded)\n    // value   (batch x kv_seq_len x num_heads_k x head_size_padded)\n    // out     (batch x q_seq_len  x num_heads x head_size_padded)\n    // softmax_lse (batch x num_heads x q_seq_len)\n    const auto head_size = grad_out->shape()->At(3);\n    const auto head_size_padded = query->shape()->At(3);\n    const auto batch_size = query->shape()->At(0);\n    const auto seqlen_q = query->shape()->At(1);\n    const auto seqlen_k = key->shape()->At(1);\n    const auto num_heads = query->shape()->At(2);\n    const auto num_heads_k = key->shape()->At(2);\n    CHECK_EQ_OR_RETURN(batch_size, key->shape()->At(0))\n        << \" key has different batch size from query.\";\n    CHECK_EQ_OR_RETURN(batch_size, value->shape()->At(0))\n        << \" value has different batch size from query.\";\n    CHECK_EQ_OR_RETURN(batch_size, grad_out->shape()->At(0))\n        << \" grad_out has different batch size from query.\";\n    CHECK_EQ_OR_RETURN(batch_size, out->shape()->At(0))\n        << \" out has different batch size from query.\";\n    CHECK_EQ_OR_RETURN(batch_size, softmax_lse->shape()->At(0))\n        << \" softmax_lse has different batch size from query.\";\n    CHECK_EQ_OR_RETURN(num_heads, grad_out->shape()->At(2))\n        << \" grad_out has different num_heads from query.\";\n    CHECK_EQ_OR_RETURN(num_heads, softmax_lse->shape()->At(1))\n        << \" softmax_lse has different num_heads from query.\";\n    CHECK_EQ_OR_RETURN(num_heads_k, value->shape()->At(2))\n        << \" value has different num_heads from key.\";\n    CHECK_EQ_OR_RETURN(seqlen_q, grad_out->shape()->At(1))\n        << \" grad_out has different seq_len from query.\";\n    CHECK_EQ_OR_RETURN(seqlen_q, softmax_lse->shape()->At(2))\n        << \" softmax_lse has different seq_len from query.\";\n    CHECK_EQ_OR_RETURN(head_size_padded, key->shape()->At(3))\n        << \" key has different head dims from query.\";\n    CHECK_EQ_OR_RETURN(head_size_padded, value->shape()->At(3))\n        << \" value has different head dims from query.\";\n    CHECK_EQ_OR_RETURN(head_size_padded, out->shape()->At(3))\n        << \" out has different head dims from query.\";\n\n    
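// grad_out comes in with the original (unpadded) head dim; re-pad it to a multiple of 8 to match\n    // the padded query/key/value/out tensors before dispatching the backward op.\n    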
query.\";\n    CHECK_EQ_OR_RETURN(batch_size, grad_out->shape()->At(0))\n        << \" grad_out has different batch size from query.\";\n    CHECK_EQ_OR_RETURN(batch_size, out->shape()->At(0))\n        << \" out has different batch size from query.\";\n    CHECK_EQ_OR_RETURN(batch_size, softmax_lse->shape()->At(0))\n        << \" softmax_lse has different batch size from query.\";\n    CHECK_EQ_OR_RETURN(num_heads, grad_out->shape()->At(2))\n        << \" grad_out has different num_heads from query.\";\n    CHECK_EQ_OR_RETURN(num_heads, softmax_lse->shape()->At(1))\n        << \" softmax_lse has different num_heads from query.\";\n    CHECK_EQ_OR_RETURN(num_heads_k, value->shape()->At(2))\n        << \" value has different num_heads from key.\";\n    CHECK_EQ_OR_RETURN(seqlen_q, grad_out->shape()->At(1))\n        << \" grad_out has different seq_len from query.\";\n    CHECK_EQ_OR_RETURN(seqlen_q, softmax_lse->shape()->At(2))\n        << \" softmax_lse has different seq_len from query.\";\n    CHECK_EQ_OR_RETURN(head_size_padded, key->shape()->At(3))\n        << \" key has different head dims from query.\";\n    CHECK_EQ_OR_RETURN(head_size_padded, value->shape()->At(3))\n        << \" key has different head dims from query.\";\n    CHECK_EQ_OR_RETURN(head_size_padded, out->shape()->At(3))\n        << \" out has different head dims from query.\";\n\n    bool padded = head_size % 8;\n\n    auto grad_out_ = padded ? JUST(pad_last_dim<8>(grad_out)) : grad_out;\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"p_dropout\", \"softmax_scale\", \"is_causal\",\n                                                 \"window_size_left\", \"window_size_right\");\n    attrs.SetAllAttrs(dropout_p, scale, is_causal, -1, -1);\n\n    auto output = std::make_shared<TensorTuple>(3);\n    auto output_ = JUST(OpInterpUtil::Dispatch<TensorTuple>(\n        *op_, {grad_out_, query, key, value, out, softmax_lse, rng_state}, attrs));\n    CHECK_EQ(output_->size(), 3);\n    auto grad_q_ = (*output_)[0];\n    auto grad_k_ = (*output_)[1];\n    auto grad_v_ = (*output_)[2];\n\n    std::shared_ptr<Tensor> grad_q_padded, grad_k_padded, grad_v_padded;\n\n    bool expanded = num_heads != num_heads_k;\n\n    grad_q_padded = grad_q_;\n    if (expanded) {\n      grad_k_padded = JUST(functional::ReduceSum(\n          JUST(functional::Reshape(grad_k_, {batch_size, seqlen_k, num_heads_k,\n                                             num_heads / num_heads_k, head_size_padded})),\n          {3}, false, grad_k_->dtype()));\n      grad_v_padded = JUST(functional::ReduceSum(\n          JUST(functional::Reshape(grad_v_, {batch_size, seqlen_k, num_heads_k,\n                                             num_heads / num_heads_k, head_size_padded})),\n          {3}, false, grad_v_->dtype()));\n    } else {\n      grad_k_padded = grad_k_;\n      grad_v_padded = grad_v_;\n    }\n\n    auto grad_q = padded ? JUST(functional::Slice(grad_q_padded, {0, 0, 0, 0},\n                                                  {batch_size, seqlen_q, num_heads, head_size},\n                                                  {1, 1, 1, 1}, false))\n                         : grad_q_padded;\n    auto grad_k = padded ? JUST(functional::Slice(grad_k_padded, {0, 0, 0, 0},\n                                                  {batch_size, seqlen_k, num_heads_k, head_size},\n                                                  {1, 1, 1, 1}, false))\n                         : grad_k_padded;\n    auto grad_v = padded ? 
JUST(functional::Slice(grad_v_padded, {0, 0, 0, 0},\n                                                  {batch_size, seqlen_k, num_heads_k, head_size},\n                                                  {1, 1, 1, 1}, false))\n                         : grad_v_padded;\n\n    (*output)[0] = grad_q;\n    (*output)[1] = grad_k;\n    (*output)[2] = grad_v;\n    return output;\n\n#endif  // CUDA_VERSION >= 11070\n\n    UNIMPLEMENTED_THEN_RETURN() << \"only support CUDA_VERSION >= 11070.\";\n  }\n\n private:\n#if CUDA_VERSION >= 11070\n  std::shared_ptr<OpExpr> op_;\n#endif  // CUDA_VERSION >= 11070\n};\n}  // namespace impl\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<impl::BiasAddFunctor>(\"BiasAdd\");\n  m.add_functor<impl::Conv1dFunctor>(\"Conv1d\");\n  m.add_functor<impl::Conv2dFunctor>(\"Conv2d\");\n  m.add_functor<impl::Conv3dFunctor>(\"Conv3d\");\n  m.add_functor<impl::DeConv1dFunctor>(\"Deconv1d\");\n  m.add_functor<impl::DeConv2dFunctor>(\"Deconv2d\");\n  m.add_functor<impl::DeConv3dFunctor>(\"Deconv3d\");\n  m.add_functor<impl::EmbeddingReNormFunctor>(\"EmbeddingReNorm\");\n  m.add_functor<impl::EmbeddingFunctor>(\"Embedding\");\n  m.add_functor<impl::MatMulFunctor>(\"MatMul\");\n  m.add_functor<impl::MatMulNoBroadCastFunctor>(\"MatMulNoBroadCast\");\n  m.add_functor<impl::BatchMatMulFunctor>(\"BatchMatMul\");\n  m.add_functor<impl::MatrixVectorProductFunctor>(\"MatrixVectorProduct\");\n  m.add_functor<impl::VectorMatrixProductFunctor>(\"VectorMatrixProduct\");\n  m.add_functor<impl::TensorDotFunctor>(\"TensorDot\");\n  m.add_functor<impl::TensorDotIntDimsFunctor>(\"TensorDotIntDims\");\n  m.add_functor<impl::FusedMLPFunctor>(\"FusedMLP\");\n  m.add_functor<impl::FusedMatmulBiasFunctor>(\"FusedMatmulBias\");\n  m.add_functor<impl::FusedMatmulBiasAddReluDropoutFunctor>(\"FusedMatmulBiasAddReluDropout\");\n  m.add_functor<impl::LayerNormFunctor>(\"LayerNorm\");\n  m.add_functor<impl::SkipLayerNormFunctor>(\"SkipLayerNorm\");\n  m.add_functor<impl::LayerNormAffineFunctor>(\"LayerNormAffine\");\n  m.add_functor<impl::GroupNormFunctor>(\"GroupNorm\");\n  m.add_functor<impl::TFAvgPool2DFunctor>(\"TFAvgPool2D\");\n  m.add_functor<impl::MaxPool1DFunctor>(\"MaxPool1D\");\n  m.add_functor<impl::MaxPool2DFunctor>(\"MaxPool2D\");\n  m.add_functor<impl::MaxPool3DFunctor>(\"MaxPool3D\");\n  m.add_functor<impl::MaxUnpoolNDFunctor<1>>(\"MaxUnpool1D\");\n  m.add_functor<impl::MaxUnpoolNDFunctor<2>>(\"MaxUnpool2D\");\n  m.add_functor<impl::MaxUnpoolNDFunctor<3>>(\"MaxUnpool3D\");\n  m.add_functor<impl::AdaptiveAvgPool1DFunctor>(\"AdaptiveAvgPool1D\");\n  m.add_functor<impl::AdaptiveAvgPool2DFunctor>(\"AdaptiveAvgPool2D\");\n  m.add_functor<impl::AdaptiveAvgPool3DFunctor>(\"AdaptiveAvgPool3D\");\n  m.add_functor<impl::AdaptiveMaxPool1DFunctor>(\"AdaptiveMaxPool1D\");\n  m.add_functor<impl::AdaptiveMaxPool2DFunctor>(\"AdaptiveMaxPool2D\");\n  m.add_functor<impl::AdaptiveMaxPool3DFunctor>(\"AdaptiveMaxPool3D\");\n  m.add_functor<impl::L1LossFunctor>(\"L1Loss\");\n  m.add_functor<impl::MseLossFunctor>(\"MseLoss\");\n  m.add_functor<impl::KLDivLossFunctor>(\"KLDivLoss\");\n  m.add_functor<impl::NLLLossFunctor>(\"NLLLoss\");\n  m.add_functor<impl::BinaryCrossEntropyLossFunctor>(\"BinaryCrossEntropyLoss\");\n  m.add_functor<impl::BinaryCrossEntropyWithLogitsLossFunctor>(\"BinaryCrossEntropyWithLogitsLoss\");\n  m.add_functor<impl::SparseCrossEntropyFunctor>(\"SparseCrossEntropy\");\n  m.add_functor<impl::SparseCrossEntropyMsFunctor>(\"SparseCrossEntropyMs\");\n  
m.add_functor<impl::CrossEntropyFunctor>(\"CrossEntropy\");\n  m.add_functor<impl::CrossEntropyLabelSmoothingFunctor>(\"CrossEntropyLabelSmoothing\");\n  m.add_functor<impl::CrossEntropyProbFunctor>(\"CrossEntropyProb\");\n  m.add_functor<impl::SparseSoftmaxCrossEntropyFunctor>(\"SparseSoftmaxCrossEntropy\");\n  m.add_functor<impl::SoftmaxCrossEntropyFunctor>(\"SoftmaxCrossEntropy\");\n  m.add_functor<impl::SoftmaxCrossEntropyGradFunctor>(\"SoftmaxCrossEntropyGrad\");\n  m.add_functor<impl::SmoothL1LossFunctor>(\"SmoothL1Loss\");\n  m.add_functor<impl::CombinedMarginLossFunctor>(\"CombinedMarginLoss\");\n  m.add_functor<impl::TripletMarginLossFunctor>(\"TripletMarginLoss\");\n  m.add_functor<impl::MarginRankingLossFunctor>(\"MarginRankingLoss\");\n  m.add_functor<impl::CtcLossFunctor>(\"CtcLoss\");\n  m.add_functor<impl::AffineGridFunctor>(\"AffineGrid\");\n  m.add_functor<impl::GridSampleFunctor>(\"GridSample\");\n  m.add_functor<impl::NormalizationFunctor>(\"Normalization\");\n  m.add_functor<impl::NormalizationAddReluFunctor>(\"NormalizationAddRelu\");\n  m.add_functor<impl::ConstantPadFunctor>(\"ConstantPad\");\n  m.add_functor<impl::ReflectionPadFunctor>(\"ReflectionPad\");\n  m.add_functor<impl::ReplicationPadFunctor>(\"ReplicationPad\");\n  m.add_functor<impl::PadFunctor>(\"Pad\");\n  m.add_functor<impl::DropoutFunctor>(\"Dropout\");\n  m.add_functor<impl::DropoutGradFunctor>(\"DropoutGrad\");\n  m.add_functor<impl::Dropout1dFunctor>(\"Dropout1d\");\n  m.add_functor<impl::Dropout2dFunctor>(\"Dropout2d\");\n  m.add_functor<impl::Dropout3dFunctor>(\"Dropout3d\");\n  m.add_functor<impl::PixelShuffleFunctor>(\"PixelShuffle\");\n  m.add_functor<impl::AvgPool1DFunctor>(\"AvgPool1D\");\n  m.add_functor<impl::AvgPool2DFunctor>(\"AvgPool2D\");\n  m.add_functor<impl::AvgPool3DFunctor>(\"AvgPool3D\");\n  m.add_functor<impl::UnfoldFunctor>(\"Unfold\");\n  m.add_functor<impl::FoldFunctor>(\"Fold\");\n  m.add_functor<impl::OneHotFunctor>(\"OneHot\");\n  m.add_functor<impl::FusedSelfAttentionFunctor>(\"FusedSelfAttention\");\n  m.add_functor<impl::FusedSelfAttentionGradFunctor>(\"FusedSelfAttentionGrad\");\n  m.add_functor<impl::PairwiseDistanceFunctor>(\"PairwiseDistance\");\n  m.add_functor<impl::CosineSimilarityFunctor>(\"CosineSimilarity\");\n  m.add_functor<impl::NormalizeFunctor>(\"Normalize\");\n  m.add_functor<impl::L2NormalizeFunctor>(\"L2Normalize\");\n  m.add_functor<impl::L2NormalizeGradFunctor>(\"L2NormalizeGrad\");\n  m.add_functor<impl::FusedBiasAddGeluFunctor>(\"FusedBiasAddGelu\");\n  m.add_functor<impl::FusedBiasAddGeluGradFunctor>(\"FusedBiasAddGeluGrad\");\n  m.add_functor<impl::FusedGluFunctor>(\"FusedGlu\");\n  m.add_functor<impl::FusedBiasAddDropoutFunctor>(\"FusedBiasAddDropout\");\n  m.add_functor<impl::FusedScaleMaskSoftmaxFunctor>(\"FusedScaleMaskSoftmax\");\n  m.add_functor<impl::FusedScaleMaskSoftmaxDropoutFunctor>(\"FusedScaleMaskSoftmaxDropout\");\n  m.add_functor<impl::FusedBiasAddScaleMaskSoftmaxDropoutFunctor>(\n      \"FusedBiasAddScaleMaskSoftmaxDropout\");\n  m.add_functor<impl::FusedScaleTrilSoftmaxMaskScaleFunctor>(\"FusedScaleTrilSoftmaxMaskScale\");\n  m.add_functor<impl::FusedScaleTrilFunctor>(\"FusedScaleTril\");\n  m.add_functor<impl::CtcGreedyDecoderFunctor>(\"CtcGreedyDecoder\");\n  m.add_functor<impl::PariticalFCSampleDisableBoxing>(\"DistributedPariticalFCSampleDisableBoxing\");\n  m.add_functor<impl::NmsFunctor>(\"Nms\");\n  m.add_functor<impl::RoiAlignFunctor>(\"RoiAlign\");\n  m.add_functor<impl::RoiAlignGradFunctor>(\"RoiAlignGrad\");\n  
m.add_functor<impl::FusedDotFeatureInteractionFunctor>(\"FusedDotFeatureInteraction\");\n  m.add_functor<impl::FusedCrossFeatureInteractionFunctor>(\"FusedCrossFeatureInteraction\");\n  m.add_functor<impl::OneEmbeddingIdShuffleFunctor>(\"OneEmbeddingIdShuffle\");\n  m.add_functor<impl::OneEmbeddingEmbeddingShuffleFunctor>(\"OneEmbeddingEmbeddingShuffle\");\n  m.add_functor<impl::OneEmbeddingEmbeddingGradientShuffleFunctor>(\n      \"OneEmbeddingEmbeddingGradientShuffle\");\n  m.add_functor<impl::OneEmbeddingLookupFunctor>(\"OneEmbeddingLookup\");\n  m.add_functor<impl::OneEmbeddingFusedLookupFunctor>(\"OneEmbeddingFusedLookup\");\n  m.add_functor<impl::OneEmbeddingFusedLookupGradFunctor>(\"OneEmbeddingFusedLookupGrad\");\n  m.add_functor<impl::OneEmbeddingEmbeddingPutFunctor>(\"OneEmbeddingEmbeddingPut\");\n  m.add_functor<impl::OneEmbeddingUniqueKeyValuePairFunctor>(\"OneEmbeddingUniqueKeyValuePair\");\n  m.add_functor<impl::OneEmbeddingSgdUpdateFunctor>(\"OneEmbeddingSgdUpdate\");\n  m.add_functor<impl::OneEmbeddingAdamUpdateFunctor>(\"OneEmbeddingAdamUpdate\");\n  m.add_functor<impl::OneEmbeddingAdagradUpdateFunctor>(\"OneEmbeddingAdagradUpdate\");\n  m.add_functor<impl::OneEmbeddingFtrlUpdateFunctor>(\"OneEmbeddingFtrlUpdate\");\n  m.add_functor<impl::RocAucScoreFunctor>(\"RocAucScore\");\n  m.add_functor<impl::MultiTensorSgdUpdateFunctor>(\"MultiTensorSgdUpdate\");\n  m.add_functor<impl::MultiTensorMomentumUpdateFunctor>(\"MultiTensorMomentumUpdate\");\n  m.add_functor<impl::MultiTensorAdamUpdateFunctor>(\"MultiTensorAdamUpdate\");\n  m.add_functor<impl::DeformConv2dFunctor>(\"DeformConv2d\");\n  m.add_functor<impl::BatchNormStatsFunctor>(\"BatchNormStats\");\n  m.add_functor<impl::BatchNormGatherStatsWithCountsFunctor>(\"BatchNormGatherStatsWithCounts\");\n  m.add_functor<impl::BatchNormElemtFunctor>(\"BatchNormElemt\");\n  m.add_functor<impl::BatchNormBackwardReduceFunctor>(\"BatchNormBackwardReduce\");\n  m.add_functor<impl::BatchNormBackwardElemtFunctor>(\"BatchNormBackwardElemt\");\n  m.add_functor<impl::FusedFastGeluMulFunctor>(\"FusedFastGeluMul\");\n  m.add_functor<impl::FusedFastGeluMulGradFunctor>(\"FusedFastGeluMulGrad\");\n  m.add_functor<impl::GroupedMatmulBiasFunctor>(\"GroupedMatmulBias\");\n  m.add_functor<impl::GroupedMatmulFunctor>(\"GroupedMatmul\");\n  m.add_functor<impl::RMSNormFunctor>(\"RMSNorm\");\n  m.add_functor<impl::SkipRMSNormFunctor>(\"SkipRMSNorm\");\n  m.add_functor<impl::FusedScaleMaskBiasSoftmaxFunctor>(\"FusedScaleMaskBiasSoftmax\");\n  m.add_functor<impl::FusedScaleMaskBiasSoftmaxGradFunctor>(\"FusedScaleMaskBiasSoftmaxGrad\");\n  m.add_functor<impl::NonContiguousBinaryOpFunctor>(\"NonContiguousBinaryOp\");\n  m.add_functor<impl::NonContiguousBinaryOpGradFunctor>(\"NonContiguousBinaryOpGrad\");\n  m.add_functor<impl::MultiTensorYoloV5WeightUpdateFunctor>(\"MultiTensorYoloV5WeightUpdate\");\n  m.add_functor<impl::FusedClipGradFunctor>(\"FusedClipGrad\");\n  m.add_functor<impl::ScaledDotProductFlashAttentionFunctor>(\"ScaledDotProductFlashAttention\");\n  m.add_functor<impl::ScaledDotProductFlashAttentionGradFunctor>(\n      \"ScaledDotProductFlashAttentionGrad\");\n}\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/nn_grad_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"fmt/core.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n#include \"oneflow/core/functional/impl/unary_functor.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nclass ConvBiasGradFunctor {\n public:\n  ConvBiasGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"conv_bias_grad\").Input(\"dy\").Output(\"bias_diff\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy, const int32_t& num_spatial_dims,\n                           const std::string& data_format) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"num_spatial_dims\", \"data_format\");\n    attrs.SetAllAttrs(num_spatial_dims, data_format);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ConvFilterGradFunctor {\n public:\n  ConvFilterGradFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"conv_filter_grad\").Input(\"dy\").Input(\"x\").Output(\"filter_diff\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x, const int32_t& num_spatial_dims,\n                           const std::vector<int32_t>& kernel_size,\n                           const std::vector<int32_t>& strides,\n                           const std::vector<int32_t>& padding_before,\n                           const std::vector<int32_t>& dilation_rate, const int32_t& groups,\n                           const std::string& data_format) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"num_spatial_dims\", \"kernel_size\", \"strides\",\n                                       \"padding_before\", \"dilation_rate\", \"groups\", \"data_format\");\n    attrs.SetAllAttrs(num_spatial_dims, kernel_size, strides, padding_before, dilation_rate, groups,\n                      data_format);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass ConvDataGradFunctor {\n public:\n  ConvDataGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"conv_data_grad\")\n                         .Input(\"dy\")\n                         .Input(\"filter\")\n                         .Input(\"x_like\")\n                         .Output(\"dx\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& weight,\n                           const std::shared_ptr<one::Tensor>& x, const int32_t& num_spatial_dims,\n                           const std::vector<int32_t>& kernel_size,\n              
             const std::vector<int32_t>& strides,\n                           const std::vector<int32_t>& padding_before,\n                           const std::vector<int32_t>& dilation_rate, const int32_t& groups,\n                           const std::string& data_format) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"num_spatial_dims\", \"kernel_size\", \"strides\",\n                                       \"padding_before\", \"dilation_rate\", \"groups\", \"data_format\");\n    attrs.SetAllAttrs(num_spatial_dims, kernel_size, strides, padding_before, dilation_rate, groups,\n                      data_format);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, weight, JUST(x->detach())}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass EmbeddingGradFunctor {\n public:\n  EmbeddingGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"embedding_grad\")\n                         .Input(\"dy\")\n                         .Input(\"weight\")\n                         .Input(\"indices\")\n                         .Output(\"dx\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& weight,\n                           const std::shared_ptr<one::Tensor>& indices, const int64_t& padding_idx,\n                           const bool& scale_grad_by_freq) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"padding_idx\", \"scale_grad_by_freq\");\n    attrs.SetAllAttrs(padding_idx, scale_grad_by_freq);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, weight, indices}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass MaxPoolNdGradFunctor {\n public:\n  MaxPoolNdGradFunctor() {\n    for (int ndims = 1; ndims <= 3; ++ndims) {\n      const auto& op_type_name = GetOpTypeName(ndims);\n      op_expr_map_[op_type_name] = CHECK_JUST(\n          one::OpBuilder(op_type_name).Input(\"dy\").Input(\"x\").Input(\"indice\").Output(\"dx\").Build());\n    }\n  }\n  static std::string GetOpTypeName(const int32_t& ndims) {\n    return \"max_pool_\" + std::to_string(ndims) + \"d_grad\";\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& indice,\n                           const std::shared_ptr<one::Tensor>& dy, const int32_t& ndims,\n                           const std::string& data_format, const std::vector<int32_t>& padding,\n                           const std::vector<int32_t>& kernel_size,\n                           const std::vector<int32_t>& stride, const std::vector<int32_t>& dilation,\n                           const bool& return_indices, const bool& ceil_mode) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"data_format\", \"padding\", \"kernel_size\", \"stride\",\n                                                 \"dilation\", \"return_indices\", \"ceil_mode\");\n    attrs.SetAllAttrs(data_format, padding, kernel_size, stride, dilation, return_indices,\n                      ceil_mode);\n    const auto& op_type_name = GetOpTypeName(ndims);\n    const auto& it = op_expr_map_.find(op_type_name);\n    CHECK_OR_RETURN(it != op_expr_map_.end())\n        << Error::RuntimeError() << \"Encountered unsupported op \" << op_type_name\n        << \" in MaxPoolNdGradFunctor.\";\n    CHECK_NOTNULL_OR_RETURN(it->second);  // NOLINT(maybe-need-error-msg)\n    return OpInterpUtil::Dispatch<Tensor>(*it->second, {dy, x, indice}, attrs);\n 
 }\n\n protected:\n  std::unordered_map<std::string, std::shared_ptr<OpExpr>> op_expr_map_;\n};\n\ntemplate<int N>\nclass MaxUnpoolNdGradFunctor {\n public:\n  MaxUnpoolNdGradFunctor()\n      : op_(CHECK_JUST(one::OpBuilder(fmt::format(\"max_unpool_{}d_grad\", N))\n                           .Input(\"dy\")\n                           .Input(\"x\")\n                           .Input(\"indices\")\n                           .Output(\"dx\")\n                           .Build())) {}\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& indice,\n                           const std::shared_ptr<one::Tensor>& dy) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x, indice});\n  }\n\n protected:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass AdaptiveMaxPoolNdGradFunctor {\n public:\n  AdaptiveMaxPoolNdGradFunctor() {\n    for (int ndims = 1; ndims <= 3; ++ndims) {\n      const auto& op_type_name = GetOpTypeName(ndims);\n      op_expr_map_[op_type_name] = CHECK_JUST(\n          one::OpBuilder(op_type_name).Input(\"dy\").Input(\"x\").Input(\"index\").Output(\"dx\").Build());\n    }\n  }\n  static std::string GetOpTypeName(const int32_t& ndims) {\n    return \"adaptive_max_pool\" + std::to_string(ndims) + \"d_grad\";\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& index, const int32_t& ndims,\n                           const std::string& data_format) const {\n    const auto& op_type_name = GetOpTypeName(ndims);\n    const auto& it = op_expr_map_.find(op_type_name);\n    CHECK_OR_RETURN(it != op_expr_map_.end())\n        << Error::RuntimeError() << \"Encountered unsupported op \" << op_type_name\n        << \" in AdaptiveMaxPoolNdGradFunctor.\";\n    CHECK_NOTNULL_OR_RETURN(it->second);  // NOLINT(maybe-need-error-msg)\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"data_format\");\n    attrs.SetAllAttrs(data_format);\n    return OpInterpUtil::Dispatch<Tensor>(*it->second, {dy, x, index}, attrs);\n  }\n\n protected:\n  std::unordered_map<std::string, std::shared_ptr<OpExpr>> op_expr_map_;\n};\n\nclass TFPoolNdGradFunctor {\n public:\n  TFPoolNdGradFunctor() {\n    for (const auto& mode : {\"tf_max\", \"tf_avg\"}) {\n      for (int ndims = 1; ndims <= 3; ++ndims) {\n        const auto& op_type_name = GetOpTypeName(mode, ndims);\n        op_expr_map_[op_type_name] = CHECK_JUST(\n            one::OpBuilder(op_type_name).Input(\"x\").Input(\"y\").Input(\"dy\").Output(\"dx\").Build());\n      }\n    }\n  }\n  static std::string GetOpTypeName(const std::string& mode, const int32_t& ndims) {\n    return mode + \"_pool_\" + std::to_string(ndims) + \"d_grad\";\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& y,\n                           const std::shared_ptr<one::Tensor>& dy, const std::string& mode,\n                           const int32_t& ndims, const std::string& data_format,\n                           const std::string& padding, const std::vector<int32_t>& padding_before,\n                           const std::vector<int32_t>& padding_after,\n                           const std::vector<int32_t>& pool_size,\n                           const std::vector<int32_t>& strides, const bool& ceil_mode) const {\n    auto& attrs =\n        
    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"data_format\", \"padding\", \"padding_before\", \"padding_after\",\n                                       \"pool_size\", \"strides\", \"ceil_mode\");\n    attrs.SetAllAttrs(data_format, padding, padding_before, padding_after, pool_size, strides,\n                      ceil_mode);\n    const auto& op_type_name = GetOpTypeName(mode, ndims);\n    const auto& it = op_expr_map_.find(op_type_name);\n    CHECK_OR_RETURN(it != op_expr_map_.end())\n        << Error::RuntimeError() << \"Encountered unsupported op \" << op_type_name\n        << \" in TFPoolNdGradFunctor.\";\n    CHECK_NOTNULL_OR_RETURN(it->second);  // NOLINT(maybe-need-error-msg)\n    return OpInterpUtil::Dispatch<Tensor>(*it->second, {x, y, dy}, attrs);\n  }\n\n protected:\n  std::unordered_map<std::string, std::shared_ptr<OpExpr>> op_expr_map_;\n};\n\nclass AdaptivePoolNdGradFunctor {\n public:\n  AdaptivePoolNdGradFunctor() {\n    for (const auto& mode : {\"avg\"}) {\n      for (int ndims = 1; ndims <= 3; ++ndims) {\n        const auto& op_type_name = GetOpTypeName(mode, ndims);\n        op_expr_map_[op_type_name] =\n            CHECK_JUST(one::OpBuilder(op_type_name).Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n      }\n    }\n  }\n  static std::string GetOpTypeName(const std::string& mode, const int32_t& ndims) {\n    return \"adaptive_\" + mode + \"_pool\" + std::to_string(ndims) + \"d_grad\";\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& dy, const std::string& mode,\n                           const int32_t& ndims, const std::string& data_format) const {\n    const auto& op_type_name = GetOpTypeName(mode, ndims);\n    const auto& it = op_expr_map_.find(op_type_name);\n    CHECK_OR_RETURN(it != op_expr_map_.end())\n        << Error::RuntimeError() << \"Encountered unsupported op \" << op_type_name\n        << \" in AdaptivePoolNdGradFunctor.\";\n    CHECK_NOTNULL_OR_RETURN(it->second);  // NOLINT(maybe-need-error-msg)\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"data_format\");\n    attrs.SetAllAttrs(data_format);\n    return OpInterpUtil::Dispatch<Tensor>(*it->second, {dy, x}, attrs);\n  }\n\n protected:\n  std::unordered_map<std::string, std::shared_ptr<OpExpr>> op_expr_map_;\n};\n\nclass SparseCrossEntropyGradFunctor {\n public:\n  SparseCrossEntropyGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"sparse_cross_entropy_grad\")\n                         .Input(\"prediction\")\n                         .Input(\"label\")\n                         .Input(\"dy\")\n                         .Output(\"prediction_diff\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& prediction,\n                           const std::shared_ptr<one::Tensor>& label,\n                           const std::shared_ptr<one::Tensor>& dy, const int64_t& depth) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"depth\");\n    attrs.SetAllAttrs(depth);\n\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {prediction, label, dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SparseCrossEntropyMsGradFunctor {\n public:\n  SparseCrossEntropyMsGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"sparse_cross_entropy_ms_grad\")\n                         .Input(\"prediction\")\n                         .Input(\"label\")\n                         .Input(\"dy\")\n                         .Output(\"prediction_diff\")\n                     
    .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& prediction,\n                           const std::shared_ptr<one::Tensor>& label,\n                           const std::shared_ptr<one::Tensor>& dy, const int64_t& depth) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"depth\");\n    attrs.SetAllAttrs(depth);\n\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {prediction, label, dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SparseSoftmaxCrossEntropyGrad {\n public:\n  SparseSoftmaxCrossEntropyGrad() {\n    op_ = CHECK_JUST(one::OpBuilder(\"sparse_softmax_cross_entropy_grad\")\n                         .Input(\"prob\")\n                         .Input(\"label\")\n                         .Input(\"dy\")\n                         .Output(\"prediction_diff\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& prob,\n                           const std::shared_ptr<one::Tensor>& label, const int64_t& depth) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"depth\");\n    attrs.SetAllAttrs(depth);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {prob, label, dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SparseSoftmaxCrossEntropyMsGrad {\n public:\n  SparseSoftmaxCrossEntropyMsGrad() {\n    op_ = CHECK_JUST(one::OpBuilder(\"sparse_softmax_cross_entropy_ms_grad\")\n                         .Input(\"prob\")\n                         .Input(\"label\")\n                         .Input(\"dy\")\n                         .Output(\"prediction_diff\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& prob,\n                           const std::shared_ptr<one::Tensor>& label, const int64_t& depth) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"depth\");\n    attrs.SetAllAttrs(depth);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {prob, label, dy}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass SmoothL1LossGradFunctor {\n public:\n  SmoothL1LossGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"smooth_l1_loss_grad\")\n                         .Input(\"dy\")\n                         .Input(\"input\")\n                         .Input(\"target\")\n                         .Output(\"dx\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target, const float& beta) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"beta\");\n    attrs.SetAllAttrs(beta);\n\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {dy, input, target}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass KLDivLossGradFunctor {\n public:\n  KLDivLossGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"kl_div_loss_grad\")\n                         .Input(\"dy\")\n                         .Input(\"input\")\n                         .Input(\"target\")\n                         .Output(\"dx\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& input,\n                           const 
std::shared_ptr<one::Tensor>& target,\n                           const bool log_target) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"log_target\");\n    attrs.SetAllAttrs(log_target);\n\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, input, target}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass KLDivLossTargetGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target,\n                           const bool log_target) const {\n    if (log_target) {\n      return functional::sequence_function(functional::Sub)\n          .then([](const std::shared_ptr<Tensor>& input) {\n            return functional::ScalarAdd(1, input, /*alpha=*/Scalar(1));\n          })\n          .then(std::bind(functional::Mul, std::placeholders::_1, JUST(functional::Exp(target))))\n          .then(std::bind(functional::Mul, std::placeholders::_1, dy))\n          .call(target, input, /*alpha=*/1, /*inplace=*/false);\n    } else {\n      return functional::sequence_function(functional::Log)\n          .then([](const std::shared_ptr<Tensor>& input) {\n            return functional::ScalarAdd(1, input, /*alpha=*/Scalar(1));\n          })\n          .then(std::bind(functional::Sub, std::placeholders::_1, input, /*alpha=*/1,\n                          /*inplace=*/false))\n          .then(std::bind(functional::Mul, std::placeholders::_1, dy))\n          .call(target);\n    }\n  }\n};\n\nclass NLLGradFunctor {\n public:\n  NLLGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"nll_grad\")\n                         .Input(\"out_grad\")\n                         .Input(\"input\")\n                         .Input(\"target\")\n                         .Output(\"in_grad\")\n                         .Build());\n\n    op_weight_ = CHECK_JUST(one::OpBuilder(\"nll_grad\")\n                                .Input(\"out_grad\")\n                                .Input(\"input\")\n                                .Input(\"target\")\n                                .Input(\"weight\")\n                                .Output(\"in_grad\")\n                                .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& out_grad,\n                           const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target,\n                           const Optional<one::Tensor>& weight, const int64_t ignore_index) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"ignore_index\");\n    attrs.SetAllAttrs(ignore_index);\n\n    if (weight) {\n      return OpInterpUtil::Dispatch<one::Tensor>(\n          *op_weight_, {out_grad, input, target, JUST(JUST(weight)->detach())}, attrs);\n    } else {\n      return OpInterpUtil::Dispatch<one::Tensor>(*op_, {out_grad, input, target}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> op_weight_;\n};\n\nclass BinaryCrossEntropyLossGradFunctor {\n public:\n  BinaryCrossEntropyLossGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"binary_cross_entropy_grad\")\n                         .Input(\"dy\")\n                         .Input(\"input\")\n                         .Input(\"target\")\n                         .Output(\"dx\")\n                         .Build());\n    op_weight_ = CHECK_JUST(one::OpBuilder(\"binary_cross_entropy_grad\")\n                     
           .Input(\"dy\")\n                                .Input(\"input\")\n                                .Input(\"target\")\n                                .Input(\"weight\")\n                                .Output(\"dx\")\n                                .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target,\n                           const Optional<one::Tensor>& weight) const {\n    if (weight) {\n      return OpInterpUtil::Dispatch<one::Tensor>(*op_weight_, {dy, input, target, JUST(weight)});\n    } else {\n      return OpInterpUtil::Dispatch<one::Tensor>(*op_, {dy, input, target});\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> op_weight_;\n};\n\nclass BinaryCrossEntropyLossTargetGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target,\n                           const Optional<one::Tensor>& weight) const {\n    auto log_one_sub_input = JUST(functional::Log(JUST(ScalarSub(1, input, /*alpha=*/1))));\n    auto grad = functional::sequence_function(functional::Log)\n                    .then(std::bind(functional::Sub, log_one_sub_input, std::placeholders::_1,\n                                    /*alpha=*/1, /*inplace=*/false))\n                    .then(std::bind(functional::Mul, dy, std::placeholders::_1))\n                    .call(input);\n    return weight ? Mul(JUST(grad), JUST(weight)) : grad;\n  }\n};\n\nclass BinaryCrossEntropyWithLogitsLossGradFunctor {\n public:\n  BinaryCrossEntropyWithLogitsLossGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"binary_cross_entropy_with_logits_grad\")\n                         .Input(\"dy\")\n                         .Input(\"input\")\n                         .Input(\"target\")\n                         .Output(\"dx\")\n                         .Build());\n    op_weight_ = CHECK_JUST(one::OpBuilder(\"binary_cross_entropy_with_logits_grad\")\n                                .Input(\"dy\")\n                                .Input(\"input\")\n                                .Input(\"target\")\n                                .Input(\"weight\")\n                                .Output(\"dx\")\n                                .Build());\n    op_pos_ = CHECK_JUST(one::OpBuilder(\"binary_cross_entropy_with_logits_grad\")\n                             .Input(\"dy\")\n                             .Input(\"input\")\n                             .Input(\"target\")\n                             .Input(\"pos_weight\")\n                             .Output(\"dx\")\n                             .Build());\n    op_weight_pos_ = CHECK_JUST(one::OpBuilder(\"binary_cross_entropy_with_logits_grad\")\n                                    .Input(\"dy\")\n                                    .Input(\"input\")\n                                    .Input(\"target\")\n                                    .Input(\"weight\")\n                                    .Input(\"pos_weight\")\n                                    .Output(\"dx\")\n                                    .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& input,\n                           const 
std::shared_ptr<one::Tensor>& target,\n                           const Optional<one::Tensor>& weight,\n                           const Optional<one::Tensor>& pos_weight) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"has_pos_weight\");\n    attrs.SetAllAttrs(pos_weight.has_value());\n\n    if (weight) {\n      if (pos_weight) {\n        return OpInterpUtil::Dispatch<one::Tensor>(\n            *op_weight_pos_, {dy, input, target, JUST(weight), JUST(pos_weight)}, attrs);\n      } else {\n        return OpInterpUtil::Dispatch<one::Tensor>(*op_weight_, {dy, input, target, JUST(weight)},\n                                                   attrs);\n      }\n    } else {\n      if (pos_weight) {\n        return OpInterpUtil::Dispatch<one::Tensor>(*op_pos_, {dy, input, target, JUST(pos_weight)},\n                                                   attrs);\n      } else {\n        return OpInterpUtil::Dispatch<one::Tensor>(*op_, {dy, input, target}, attrs);\n      }\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> op_weight_;\n  std::shared_ptr<OpExpr> op_pos_;\n  std::shared_ptr<OpExpr> op_weight_pos_;\n};\n\nclass BinaryCrossEntropyWithLogitsLossTargetGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target,\n                           const Optional<one::Tensor>& weight,\n                           const Optional<one::Tensor>& pos_weight) const {\n    if (pos_weight) {\n      auto sig = JUST(functional::Sigmoid(input));\n      auto log_one_sub_sig =\n          JUST(functional::Log(JUST(functional::ScalarSub(1, sig, /*alpha=*/1))));\n      auto grad = functional::sequence_function(functional::Log)\n                      .then(std::bind(functional::Mul, std::placeholders::_1, JUST(pos_weight)))\n                      .then(std::bind(functional::Sub, log_one_sub_sig, std::placeholders::_1,\n                                      /*alpha=*/1, false))\n                      .call(sig);\n\n      return weight ? functional::Mul(JUST(grad), JUST(weight)) : grad;\n    } else {\n      auto grad = functional::sequence_function(functional::Negative)\n                      .then(std::bind(functional::Mul, std::placeholders::_1, dy))\n                      .call(input);\n      return weight ? 
functional::Mul(JUST(grad), JUST(weight)) : grad;\n    }\n  }\n};\n\nclass BinaryCrossEntropyWithLogitsReduceMeanLossGradFunctor {\n public:\n  BinaryCrossEntropyWithLogitsReduceMeanLossGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"binary_cross_entropy_with_logits_reduce_mean_grad\")\n                         .Input(\"dy\")\n                         .Input(\"input\")\n                         .Input(\"target\")\n                         .Output(\"dx\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target) const {\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {dy, input, target});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass BinaryCrossEntropyWithLogitsReduceMeanLossTargetGradFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& target) const {\n    auto neg_mean_dy = JUST(functional::ScalarMul(-1.0 / input->nelement(), dy));\n    return functional::Mul(input, neg_mean_dy);\n  }\n};\n\nclass CombinedMarginLossGradFunctor {\n public:\n  CombinedMarginLossGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"combined_margin_loss_grad\")\n                         .Input(\"dy\")\n                         .Input(\"label\")\n                         .Input(\"theta\")\n                         .Output(\"dx\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& label,\n                           const std::shared_ptr<one::Tensor>& theta, const float& m1,\n                           const float& m2, const float& m3, const int64_t& depth) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"m1\", \"m2\", \"m3\", \"depth\");\n    attrs.SetAllAttrs(m1, m2, m3, depth);\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {dy, label, theta}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass AffineGridGradFunctor {\n public:\n  AffineGridGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"affine_grid_grad\").Input(\"dgrid\").Output(\"dtheta\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dgrid, const Shape& size,\n                           const bool& align_corners) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"size\", \"align_corners\");\n    attrs.SetAllAttrs(size, align_corners);\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, {dgrid}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GridSampleGradFunctor {\n public:\n  GridSampleGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"grid_sample_grad\")\n                         .Input(\"doutput\")\n                         .Input(\"input\")\n                         .Input(\"grid\")\n                         .Output(\"dinput\")\n                         .Output(\"dgrid\")\n                         .Build());\n  }\n
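  // Returns (dinput, dgrid): the fused grid_sample_grad op computes the gradients\n  // w.r.t. both the input and the sampling grid in a single dispatch.\n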
  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& doutput,\n                                const std::shared_ptr<one::Tensor>& input,\n                                const std::shared_ptr<one::Tensor>& grid,\n                                const std::string& interpolation_mode,\n                                const std::string& padding_mode, const bool& align_corners) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"interpolation_mode\", \"padding_mode\", \"align_corners\");\n    attrs.SetAllAttrs(interpolation_mode, padding_mode, align_corners);\n    return OpInterpUtil::Dispatch<one::TensorTuple>(*op_, {doutput, input, grid}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass CtcLossGradFunctor {\n public:\n  CtcLossGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"ctc_loss_grad\")\n                         .Input(\"grad_out\")\n                         .Input(\"log_probs\")\n                         .Input(\"targets\")\n                         .Input(\"input_lengths\")\n                         .Input(\"target_lengths\")\n                         .Input(\"loss\")\n                         .Input(\"alpha\")\n                         .Output(\"grad\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& grad_out,\n                           const std::shared_ptr<one::Tensor>& log_probs,\n                           const std::shared_ptr<one::Tensor>& targets,\n                           const std::shared_ptr<one::Tensor>& input_lengths,\n                           const std::shared_ptr<one::Tensor>& target_lengths,\n                           const std::shared_ptr<one::Tensor>& loss,\n                           const std::shared_ptr<one::Tensor>& alpha, const int64_t& blank,\n                           const bool& zero_infinity, const int64_t& max_target_length) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"blank\", \"zero_infinity\", \"max_target_length\");\n    attrs.SetAllAttrs(blank, zero_infinity, max_target_length);\n    if (targets->dtype()->data_type() == DataType::kInt32) {\n      return OpInterpUtil::Dispatch<one::Tensor>(\n          *op_, {grad_out, log_probs, targets, input_lengths, target_lengths, loss, alpha}, attrs);\n    } else {\n      return OpInterpUtil::Dispatch<one::Tensor>(\n          *op_,\n          {grad_out, log_probs, JUST(functional::Cast(targets, DType::Int64(), false)),\n           input_lengths, target_lengths, loss, alpha},\n          attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass PadGradFunctor {\n public:\n  PadGradFunctor() {\n    reflect_pad1d_grad_ =\n        CHECK_JUST(one::OpBuilder(\"reflection_pad1d_grad\").Input(\"dy\").Output(\"dx\").Build());\n    reflect_pad2d_grad_ =\n        CHECK_JUST(one::OpBuilder(\"reflection_pad2d_grad\").Input(\"dy\").Output(\"dx\").Build());\n    replicate_pad1d_grad_ =\n        CHECK_JUST(one::OpBuilder(\"replication_pad1d_grad\").Input(\"dy\").Output(\"dx\").Build());\n    replicate_pad2d_grad_ =\n        CHECK_JUST(one::OpBuilder(\"replication_pad2d_grad\").Input(\"dy\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy, const std::vector<int64_t>& pad,\n                           const std::string& mode, const Scalar& value) const {\n    const int64_t ndim = dy->shape()->NumAxes();\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"padding\");\n    attrs.SetAllAttrs(pad);\n
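    // An N-d pad grad op consumes a tensor of rank N + 2 (batch and channel dims\n    // included), so a 3-D dy maps to the 1-d op and a 4-D dy to the 2-d op below.\n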
    if (mode == \"reflect\") {\n      if (ndim == 3) {\n        return OpInterpUtil::Dispatch<Tensor>(*reflect_pad1d_grad_, {dy}, attrs);\n      } else if (ndim == 4) {\n        return OpInterpUtil::Dispatch<Tensor>(*reflect_pad2d_grad_, {dy}, attrs);\n      } else {\n        UNIMPLEMENTED_THEN_RETURN() << \"only 3D/4D reflect padding is supported for now\";\n      }\n\n    } else if (mode == \"replicate\") {\n      if (ndim == 3) {\n        return OpInterpUtil::Dispatch<Tensor>(*replicate_pad1d_grad_, {dy}, attrs);\n      } else if (ndim == 4) {\n        return OpInterpUtil::Dispatch<Tensor>(*replicate_pad2d_grad_, {dy}, attrs);\n      } else {\n        UNIMPLEMENTED_THEN_RETURN() << \"only 3D/4D replicate padding is supported for now\";\n      }\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"Pad mode is \" << mode\n                                  << \", but only constant, reflect and replicate are valid.\";\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> reflect_pad1d_grad_;\n  std::shared_ptr<OpExpr> reflect_pad2d_grad_;\n  std::shared_ptr<OpExpr> replicate_pad1d_grad_;\n  std::shared_ptr<OpExpr> replicate_pad2d_grad_;\n};\n\nclass AvgPoolNdGradFunctor {\n public:\n  AvgPoolNdGradFunctor() {\n    for (int ndims = 1; ndims <= 3; ++ndims) {\n      const auto& op_type_name = GetOpTypeName(ndims);\n      op_expr_map_[op_type_name] =\n          CHECK_JUST(one::OpBuilder(op_type_name).Input(\"dy\").Input(\"x\").Output(\"dx\").Build());\n    }\n  }\n  static std::string GetOpTypeName(const int32_t& ndims) {\n    return \"avg_pool_\" + std::to_string(ndims) + \"d_grad\";\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& dy, const int32_t& ndims,\n                           const std::string& data_format, const std::vector<int32_t>& padding,\n                           const std::vector<int32_t>& kernel_size,\n                           const std::vector<int32_t>& stride, const bool& ceil_mode,\n                           const bool& count_include_pad, const int32_t& divisor_override) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"data_format\", \"padding\", \"kernel_size\", \"stride\",\n                                       \"ceil_mode\", \"count_include_pad\", \"divisor_override\");\n    attrs.SetAllAttrs(data_format, padding, kernel_size, stride, ceil_mode, count_include_pad,\n                      divisor_override);\n    const auto& op_type_name = GetOpTypeName(ndims);\n    const auto& it = op_expr_map_.find(op_type_name);\n    CHECK_OR_RETURN(it != op_expr_map_.end())\n        << Error::RuntimeError() << \"Encountered unsupported op \" << op_type_name\n        << \" in AvgPoolNdGradFunctor.\";\n    CHECK_NOTNULL_OR_RETURN(it->second);  // NOLINT(maybe-need-error-msg)\n    return OpInterpUtil::Dispatch<Tensor>(*it->second, {dy, x}, attrs);\n  }\n\n protected:\n  std::unordered_map<std::string, std::shared_ptr<OpExpr>> op_expr_map_;\n};\n\nclass NormalizationGradFunctor {\n public:\n  NormalizationGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"normalization_grad\")\n                         .Input(\"dy\")\n                         .Input(\"x\")\n                         .Input(\"mean\")\n                         .Input(\"inv_variance\")\n                         .Input(\"gamma\")\n                         .Output(\"dx\")\n                         .Output(\"gamma_diff\")\n                         .Output(\"beta_diff\")\n                         .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& grad,\n 
                               const std::shared_ptr<one::Tensor>& x,\n                                const std::shared_ptr<one::Tensor>& mean,\n                                const std::shared_ptr<one::Tensor>& inv_variance,\n                                const std::shared_ptr<one::Tensor>& gamma, const float& epsilon,\n                                const int32_t& axis) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"epsilon\", \"axis\");\n    attrs.SetAllAttrs(epsilon, axis);\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {grad, x, mean, inv_variance, gamma}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass NormalizationAddReluGradFunctor {\n public:\n  NormalizationAddReluGradFunctor() {\n    addend_op_ = CHECK_JUST(one::OpBuilder(\"normalization_add_relu_grad\")\n                                .Input(\"x\")\n                                .Input(\"dy\")\n                                .Input(\"mean\")\n                                .Input(\"inv_variance\")\n                                .Input(\"gamma\")\n                                .Input(\"beta\")\n                                .Input(\"reserve_space\")\n                                .Input(\"y\")\n                                .Output(\"dx\")\n                                .Output(\"gamma_diff\")\n                                .Output(\"beta_diff\")\n                                .Output(\"addend_diff\")\n                                .Build());\n    no_addend_op_ = CHECK_JUST(one::OpBuilder(\"normalization_add_relu_grad\")\n                                   .Input(\"x\")\n                                   .Input(\"dy\")\n                                   .Input(\"mean\")\n                                   .Input(\"inv_variance\")\n                                   .Input(\"gamma\")\n                                   .Input(\"beta\")\n                                   .Input(\"reserve_space\")\n                                   .Input(\"y\")\n                                   .Output(\"dx\")\n                                   .Output(\"gamma_diff\")\n                                   .Output(\"beta_diff\")\n                                   .Build());\n  }\n  Maybe<TensorTuple> operator()(\n      const std::shared_ptr<one::Tensor>& x, const std::shared_ptr<one::Tensor>& grad,\n      const std::shared_ptr<one::Tensor>& mean, const std::shared_ptr<one::Tensor>& inv_variance,\n      const std::shared_ptr<one::Tensor>& gamma, const std::shared_ptr<one::Tensor>& beta,\n      const std::shared_ptr<one::Tensor>& reserve_space, const std::shared_ptr<one::Tensor>& y,\n      const int32_t& axis, const float& epsilon, bool has_addend) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"axis\", \"epsilon\");\n    attrs.SetAllAttrs(axis, epsilon);\n    if (has_addend) {\n      return OpInterpUtil::Dispatch<TensorTuple>(\n          *addend_op_, {x, grad, mean, inv_variance, gamma, beta, reserve_space, y}, attrs);\n    } else {\n      return OpInterpUtil::Dispatch<TensorTuple>(\n          *no_addend_op_, {x, grad, mean, inv_variance, gamma, beta, reserve_space, y}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> addend_op_;\n  std::shared_ptr<OpExpr> no_addend_op_;\n};\n\nclass LayerNormGradFunctor {\n public:\n  LayerNormGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"layer_norm_grad\")\n                         .Input(\"dy\")\n                         .Input(\"x\")\n                         .Input(\"mean\")\n                         
.Input(\"inv_variance\")\n                         .Output(\"dx\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& mean,\n                           const std::shared_ptr<one::Tensor>& inv_variance,\n                           const int64_t& begin_norm_axis, const double& epsilon) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"begin_norm_axis\", \"epsilon\");\n    attrs.SetAllAttrs(begin_norm_axis, epsilon);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x, mean, inv_variance}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass LayerNormAffineGradFunctor {\n public:\n  LayerNormAffineGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"layer_norm_grad\")\n                         .Input(\"dy\")\n                         .Input(\"x\")\n                         .Input(\"mean\")\n                         .Input(\"inv_variance\")\n                         .Input(\"gamma\")\n                         .Output(\"dx\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& mean,\n                           const std::shared_ptr<one::Tensor>& inv_variance,\n                           const std::shared_ptr<one::Tensor>& gamma,\n                           const int64_t& begin_norm_axis, const double& epsilon) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"begin_norm_axis\", \"epsilon\");\n    attrs.SetAllAttrs(begin_norm_axis, epsilon);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x, mean, inv_variance, gamma}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FuseLayerNormGradFunctor {\n public:\n  FuseLayerNormGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fuse_layer_norm_grad\")\n                         .Input(\"dy\")\n                         .Input(\"x\")\n                         .Input(\"mean\")\n                         .Input(\"inv_variance\")\n                         .Input(\"gamma\")\n                         .Output(\"dx\")\n                         .Output(\"gamma_diff\")\n                         .Output(\"beta_diff\")\n                         .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,\n                                const std::shared_ptr<one::Tensor>& x,\n                                const std::shared_ptr<one::Tensor>& mean,\n                                const std::shared_ptr<one::Tensor>& inv_variance,\n                                const std::shared_ptr<one::Tensor>& gamma,\n                                const int64_t& begin_norm_axis, const int64_t& begin_params_axis,\n                                const double& epsilon) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"begin_norm_axis\", \"begin_params_axis\", \"epsilon\");\n    attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon);\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, mean, inv_variance, gamma}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass LayerNormParamGradFunctor {\n public:\n  LayerNormParamGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"layer_norm_param_grad\")\n                         .Input(\"dy\")\n                         
.Input(\"x\")\n                         .Input(\"mean\")\n                         .Input(\"inv_variance\")\n                         .Output(\"gamma_diff\")\n                         .Output(\"beta_diff\")\n                         .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,\n                                const std::shared_ptr<one::Tensor>& x,\n                                const std::shared_ptr<one::Tensor>& mean,\n                                const std::shared_ptr<one::Tensor>& inv_variance,\n                                const int64_t& begin_params_axis) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"begin_params_axis\");\n    attrs.SetAllAttrs(begin_params_axis);\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, mean, inv_variance}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GroupNormGradFunctor {\n public:\n  GroupNormGradFunctor() {\n    affine_grad_op_ = CHECK_JUST(one::OpBuilder(\"group_norm_grad\")\n                                     .Input(\"dy\")\n                                     .Input(\"x\")\n                                     .Input(\"mean\")\n                                     .Input(\"inv_variance\")\n                                     .Input(\"gamma\")\n                                     .Output(\"dx\")\n                                     .Build());\n    grad_op_ = CHECK_JUST(one::OpBuilder(\"group_norm_grad\")\n                              .Input(\"dy\")\n                              .Input(\"x\")\n                              .Input(\"mean\")\n                              .Input(\"inv_variance\")\n                              .Output(\"dx\")\n                              .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& mean,\n                           const std::shared_ptr<one::Tensor>& inv_variance,\n                           const Optional<one::Tensor>& gamma, const int32_t& num_groups,\n                           const double& epsilon) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"num_groups\", \"epsilon\");\n    attrs.SetAllAttrs(num_groups, epsilon);\n    if (gamma) {\n      return OpInterpUtil::Dispatch<Tensor>(*affine_grad_op_,\n                                            {dy, x, mean, inv_variance, JUST(gamma)}, attrs);\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*grad_op_, {dy, x, mean, inv_variance}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> affine_grad_op_;\n  std::shared_ptr<OpExpr> grad_op_;\n};\n\nclass GroupNormParamGradFunctor {\n public:\n  GroupNormParamGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"group_norm_param_grad\")\n                         .Input(\"dy\")\n                         .Input(\"x\")\n                         .Input(\"mean\")\n                         .Input(\"inv_variance\")\n                         .Output(\"dgamma\")\n                         .Output(\"dbeta\")\n                         .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,\n                                const std::shared_ptr<one::Tensor>& x,\n                                const std::shared_ptr<one::Tensor>& mean,\n                                const std::shared_ptr<one::Tensor>& inv_variance) const {\n    return 
OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, mean, inv_variance});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass RMSNormGradFunctor {\n public:\n  RMSNormGradFunctor() {\n    grad_op_ = CHECK_JUST(one::OpBuilder(\"rms_norm_grad\")\n                              .Input(\"dy\")\n                              .Input(\"x\")\n                              .Input(\"inv_rms\")\n                              .Output(\"dx\")\n                              .Build());\n    affine_grad_op_ = CHECK_JUST(one::OpBuilder(\"rms_norm_grad\")\n                                     .Input(\"dy\")\n                                     .Input(\"x\")\n                                     .Input(\"inv_rms\")\n                                     .Input(\"weight\")\n                                     .Output(\"dx\")\n                                     .Build());\n    param_grad_op_ = CHECK_JUST(one::OpBuilder(\"rms_norm_param_grad\")\n                                    .Input(\"dy\")\n                                    .Input(\"x\")\n                                    .Input(\"inv_rms\")\n                                    .Output(\"weight_grad\")\n                                    .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& inv_rms,\n                           const Optional<one::Tensor>& weight, const bool param_grad) const {\n    if (param_grad) {\n      return OpInterpUtil::Dispatch<Tensor>(*param_grad_op_, {dy, x, inv_rms});\n    } else if (weight) {\n      return OpInterpUtil::Dispatch<Tensor>(*affine_grad_op_, {dy, x, inv_rms, JUST(weight)});\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*grad_op_, {dy, x, inv_rms});\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> grad_op_;\n  std::shared_ptr<OpExpr> affine_grad_op_;\n  std::shared_ptr<OpExpr> param_grad_op_;\n};\n\nclass BroadcastMatmulGradBFunctor {\n public:\n  BroadcastMatmulGradBFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"broadcast_matmul_grad_b\").Input(\"a\").Input(\"b\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& a,\n                           const std::shared_ptr<one::Tensor>& b, double alpha) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"alpha\");\n    attrs.SetAllAttrs(alpha);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {a, b}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedScaleTrilSoftmaxMaskScaleGradFunctor {\n public:\n  FusedScaleTrilSoftmaxMaskScaleGradFunctor() {\n    fused_op_ = CHECK_JUST(one::OpBuilder(\"fused_tril_scale_softmax_mask_scale_grad\")\n                               .Input(\"softmax_y\")\n                               .Input(\"dy\")\n                               .Input(\"mask\")\n                               .Output(\"dx\")\n                               .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& softmax_y,\n                           const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& mask, const int64_t diagonal,\n                           const float tril_scale_value, const float mask_scale_value) const {\n    auto& fused_attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"diagonal\", \"tril_scale_value\", \"mask_scale_value\");\n    fused_attrs.SetAllAttrs(diagonal, tril_scale_value, 
mask_scale_value);\n    return OpInterpUtil::Dispatch<Tensor>(*fused_op_, {softmax_y, dy, mask}, fused_attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> fused_op_;\n};\n\nclass FusedScaleMaskSoftmaxGradFunctor {\n public:\n  FusedScaleMaskSoftmaxGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_scale_mask_softmax_grad\")\n                         .Input(\"y\")\n                         .Input(\"dy\")\n                         .Input(\"mask\")\n                         .Output(\"dx\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& y,\n                           const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& mask, const float& scale) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"scale_value\");\n    attrs.SetAllAttrs(scale);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {y, dy, mask}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedScaleMaskSoftmaxDropoutGradFunctor {\n public:\n  FusedScaleMaskSoftmaxDropoutGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_scale_mask_softmax_dropout_grad\")\n                         .Input(\"softmax_y\")\n                         .Input(\"dy\")\n                         .Input(\"mask\")\n                         .Input(\"dropout_mask\")\n                         .Output(\"dx\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& softmax_y,\n                           const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& mask,\n                           const std::shared_ptr<one::Tensor>& dropout_mask, const float& scale,\n                           const float& dropout_scale) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"scale_value\", \"dropout_scale_value\");\n    attrs.SetAllAttrs(scale, dropout_scale);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {softmax_y, dy, mask, dropout_mask}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass CublasBiasAddReluMatmulGradFunctor {\n public:\n  CublasBiasAddReluMatmulGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"cublas_bias_add_relu_matmul_grad\")\n                         .Input(\"dy\")\n                         .Input(\"weight\")\n                         .Input(\"aux\")\n                         .Output(\"d_grad\")\n                         .Output(\"d_bias\")\n                         .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,\n                                const std::shared_ptr<one::Tensor>& weight,\n                                const std::shared_ptr<one::Tensor>& aux,\n                                const double& alpha) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"alpha\");\n    attrs.SetAllAttrs(alpha);\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, weight, aux}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass CublasMatmulBiasAddGradFunctor {\n public:\n  CublasMatmulBiasAddGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"cublas_matmul_bias_add_grad\")\n                         .Input(\"dy\")\n                         .Input(\"x\")\n                         .Output(\"w_grad\")\n                         .Output(\"b_grad\")\n                         .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,\n                                
const std::shared_ptr<one::Tensor>& x) const {\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedReluDropoutGradFunctor {\n public:\n  FusedReluDropoutGradFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"fused_relu_dropout_grad\").Input(\"dy\").Input(\"mask\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& mask, const float& scale) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"scale\");\n    attrs.SetAllAttrs(scale);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, mask}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedDotFeatureInteractionGradFunctor {\n public:\n  FusedDotFeatureInteractionGradFunctor() {\n    ops_has_output_concat_grad_.resize(kMaxInputCount);\n    ops_no_output_concat_grad_.resize(kMaxInputCount);\n    for (int n = 0; n < ops_has_output_concat_grad_.size(); ++n) {\n      ops_has_output_concat_grad_[n] =\n          CHECK_JUST(one::OpBuilder(\"fused_dot_feature_interaction_grad\")\n                         .Input(\"dy\")\n                         .Input(\"features\", n + 1)\n                         .Output(\"features_grad\", n + 1)\n                         .Output(\"output_concat_grad\")\n                         .Build());\n    }\n    for (int n = 0; n < ops_no_output_concat_grad_.size(); ++n) {\n      ops_no_output_concat_grad_[n] =\n          CHECK_JUST(one::OpBuilder(\"fused_dot_feature_interaction_grad\")\n                         .Input(\"dy\")\n                         .Input(\"features\", n + 1)\n                         .Output(\"features_grad\", n + 1)\n                         .Build());\n    }\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy, const TensorTuple& features,\n                                const bool& has_output_concat, const bool& self_interaction,\n                                const int32_t& output_concat_grad_dim,\n                                const std::string& pooling) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"self_interaction\", \"output_concat_grad_dim\", \"pooling\");\n    attrs.SetAllAttrs(self_interaction, output_concat_grad_dim, pooling);\n
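    // Validate the arguments before indexing the cached op vectors: pooling must be\n    // \"sum\" or \"none\", and there is one cached op expression per feature count.\n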
    CHECK_OR_RETURN(pooling == \"sum\" || pooling == \"none\")\n        << Error::RuntimeError() << \"pooling should be sum or none, but got \" << pooling << \".\";\n    const int64_t n_features_grad = features.size();\n    CHECK_LE_OR_RETURN(n_features_grad, kMaxInputCount)\n        << Error::RuntimeError() << \"The number of tensors in features should not exceed \"\n        << kMaxInputCount << \".\";\n    TensorTuple inputs(n_features_grad + 1);\n    inputs[0] = dy;\n    for (int32_t i = 0; i < n_features_grad; ++i) { inputs[i + 1] = features[i]; }\n    if (has_output_concat) {\n      return OpInterpUtil::Dispatch<TensorTuple>(\n          *JUST(oneflow::VectorAt(ops_has_output_concat_grad_, n_features_grad - 1)), inputs,\n          attrs);\n    } else {\n      return OpInterpUtil::Dispatch<TensorTuple>(\n          *JUST(oneflow::VectorAt(ops_no_output_concat_grad_, n_features_grad - 1)), inputs, attrs);\n    }\n  }\n\n private:\n  std::vector<std::shared_ptr<OpExpr>> ops_has_output_concat_grad_;\n  std::vector<std::shared_ptr<OpExpr>> ops_no_output_concat_grad_;\n};\n\nclass FusedCrossFeatureInteractionV1GradFunctor {\n public:\n  FusedCrossFeatureInteractionV1GradFunctor() {\n    v1_grad_op_ = CHECK_JUST(one::OpBuilder(\"fused_cross_feature_interaction_v1_grad\")\n                                 .Input(\"dy\")\n                                 .Input(\"weight\")\n                                 .Input(\"x\")\n                                 .Input(\"x0\")\n                                 .Input(\"matmul_result\")\n                                 .Output(\"dx\")\n                                 .Output(\"dw\")\n                                 .Output(\"dx0\")\n                                 .Output(\"dbias\")\n                                 .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,\n                                const std::shared_ptr<one::Tensor>& weight,\n                                const std::shared_ptr<one::Tensor>& x,\n                                const std::shared_ptr<one::Tensor>& x0,\n                                const std::shared_ptr<one::Tensor>& matmul_result) const {\n    return OpInterpUtil::Dispatch<TensorTuple>(*v1_grad_op_, {dy, weight, x, x0, matmul_result});\n  }\n\n private:\n  std::shared_ptr<OpExpr> v1_grad_op_;\n};\n\nclass FusedCrossFeatureInteractionV2GradFunctor {\n public:\n  FusedCrossFeatureInteractionV2GradFunctor() {\n    v2_grad_op_ = CHECK_JUST(one::OpBuilder(\"fused_cross_feature_interaction_v2_grad\")\n                                 .Input(\"dy\")\n                                 .Input(\"weight\")\n                                 .Input(\"bias\")\n                                 .Input(\"x\")\n                                 .Input(\"x0\")\n                                 .Input(\"matmul_result\")\n                                 .Output(\"dx\")\n                                 .Output(\"dw\")\n                                 .Output(\"dx0\")\n                                 .Output(\"dbias\")\n                                 .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,\n                                const std::shared_ptr<one::Tensor>& weight,\n                                const std::shared_ptr<one::Tensor>& bias,\n                                const std::shared_ptr<one::Tensor>& x,\n                                const std::shared_ptr<one::Tensor>& x0,\n                                const std::shared_ptr<one::Tensor>& matmul_result) const {\n    return OpInterpUtil::Dispatch<TensorTuple>(*v2_grad_op_,\n                                               {dy, weight, bias, x, x0, matmul_result});\n  }\n\n
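  // NOTE: unlike the v1 functor above, the v2 grad op also consumes the bias input.\n 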
private:\n  std::shared_ptr<OpExpr> v2_grad_op_;\n};\n\nclass MatrixVectorProductGradAFunctor {\n public:\n  MatrixVectorProductGradAFunctor() {\n    matrix_vector_product_grad_a_op_ = CHECK_JUST(\n        one::OpBuilder(\"matrix_vector_product_grad_a\").Input(\"dy\").Input(\"b\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& b) const {\n    return OpInterpUtil::Dispatch<Tensor>(*matrix_vector_product_grad_a_op_, {dy, b});\n  }\n\n private:\n  std::shared_ptr<OpExpr> matrix_vector_product_grad_a_op_;\n};\n\nclass MatrixVectorProductGradBFunctor {\n public:\n  MatrixVectorProductGradBFunctor() {\n    matrix_vector_product_grad_b_op_ = CHECK_JUST(\n        one::OpBuilder(\"matrix_vector_product_grad_b\").Input(\"dy\").Input(\"a\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& a) const {\n    return OpInterpUtil::Dispatch<Tensor>(*matrix_vector_product_grad_b_op_, {dy, a});\n  }\n\n private:\n  std::shared_ptr<OpExpr> matrix_vector_product_grad_b_op_;\n};\n\nclass VectorMatrixProductGradAFunctor {\n public:\n  VectorMatrixProductGradAFunctor() {\n    vector_matrix_product_grad_a_op_ = CHECK_JUST(\n        one::OpBuilder(\"vector_matrix_product_grad_a\").Input(\"dy\").Input(\"b\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& b) const {\n    return OpInterpUtil::Dispatch<Tensor>(*vector_matrix_product_grad_a_op_, {dy, b});\n  }\n\n private:\n  std::shared_ptr<OpExpr> vector_matrix_product_grad_a_op_;\n};\n\nclass VectorMatrixProductGradBFunctor {\n public:\n  VectorMatrixProductGradBFunctor() {\n    vector_matrix_product_grad_b_op_ = CHECK_JUST(\n        one::OpBuilder(\"vector_matrix_product_grad_b\").Input(\"dy\").Input(\"a\").Output(\"dx\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,\n                           const std::shared_ptr<one::Tensor>& a) const {\n    return OpInterpUtil::Dispatch<Tensor>(*vector_matrix_product_grad_b_op_, {dy, a});\n  }\n\n private:\n  std::shared_ptr<OpExpr> vector_matrix_product_grad_b_op_;\n};\n\nclass DeformConv2dInputGradFunctor {\n public:\n  DeformConv2dInputGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"deform_conv2d_input_grad\")\n                         .Input(\"output_grad\")\n                         .Input(\"input\")\n                         .Input(\"weight\")\n                         .Input(\"offset\")\n                         .Output(\"input_grad\")\n                         .Output(\"offset_grad\")\n                         .Build());\n\n    mask_op_ = CHECK_JUST(one::OpBuilder(\"deform_conv2d_input_grad\")\n                              .Input(\"output_grad\")\n                              .Input(\"input\")\n                              .Input(\"weight\")\n                              .Input(\"offset\")\n                              .Input(\"mask\")\n                              .Output(\"input_grad\")\n                              .Output(\"offset_grad\")\n                              .Output(\"mask_grad\")\n                              .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& output_grad,\n                                const std::shared_ptr<one::Tensor>& input,\n                                const 
std::shared_ptr<one::Tensor>& weight,\n                                const std::shared_ptr<one::Tensor>& offset,\n                                const Optional<one::Tensor>& mask, const int32_t& stride_h,\n                                const int32_t& stride_w, const int32_t& pad_h, const int32_t& pad_w,\n                                const int32_t& dilation_h, const int32_t& dilation_w,\n                                const int32_t& groups, const int32_t& offset_groups,\n                                const bool& use_mask) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"stride_h\", \"stride_w\", \"pad_h\", \"pad_w\", \"dilation_h\",\n                                       \"dilation_w\", \"groups\", \"offset_groups\", \"use_mask\");\n    attrs.SetAllAttrs(stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, groups,\n                      offset_groups, use_mask);\n    if (mask) {\n      return OpInterpUtil::Dispatch<TensorTuple>(\n          *mask_op_, {output_grad, input, weight, offset, JUST(mask)}, attrs);\n    } else {\n      return OpInterpUtil::Dispatch<TensorTuple>(*op_, {output_grad, input, weight, offset}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> mask_op_;\n};\n\nclass DeformConv2dParamGradFunctor {\n public:\n  DeformConv2dParamGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"deform_conv2d_param_grad\")\n                         .Input(\"output_grad\")\n                         .Input(\"input\")\n                         .Input(\"weight\")\n                         .Input(\"offset\")\n                         .Input(\"mask\")\n                         .Output(\"weight_grad\")\n                         .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& output_grad,\n                           const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& weight,\n                           const std::shared_ptr<one::Tensor>& offset,\n                           const std::shared_ptr<one::Tensor>& mask, const int32_t& stride_h,\n                           const int32_t& stride_w, const int32_t& pad_h, const int32_t& pad_w,\n                           const int32_t& dilation_h, const int32_t& dilation_w,\n                           const int32_t& groups, const int32_t& offset_groups,\n                           const bool& use_mask) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"stride_h\", \"stride_w\", \"pad_h\", \"pad_w\", \"dilation_h\",\n                                       \"dilation_w\", \"groups\", \"offset_groups\", \"use_mask\");\n    attrs.SetAllAttrs(stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, groups,\n                      offset_groups, use_mask);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {output_grad, input, weight, offset, mask}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FusedGluWithoutLinearGradFunctor {\n public:\n  FusedGluWithoutLinearGradFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fused_glu_without_linear_grad\")\n                         .Input(\"dy\")\n                         .Input(\"matmul_wx\")\n                         .Output(\"d_matmul_wx\")\n                         .Build());\n    split_op_ = CHECK_JUST(one::OpBuilder(\"fused_glu_without_linear_grad\")\n                               .Input(\"dy\")\n                               .Input(\"matmul_wx\")\n                               .Input(\"matmul_vx\")\n     
                          .Output(\"d_matmul_wx\")\n                               .Output(\"d_matmul_vx\")\n                               .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,\n                                const std::shared_ptr<one::Tensor>& matmul_wx,\n                                const Optional<one::Tensor>& matmul_vx,\n                                const std::string& activation) const {\n    // check whether the user provide splited tensors\n    bool is_split_mode = false;\n    if (matmul_vx) { is_split_mode = true; }\n\n    // obtain input shape\n    const auto& dy_shape = *(dy->shape());\n    const auto& matmul_wx_shape = *(matmul_wx->shape());\n\n    // check number of axes of dy and matmul_wx\n    size_t dy_num_axes = dy_shape.NumAxes();\n    size_t matmul_wx_num_axes = matmul_wx_shape.NumAxes();\n    CHECK_GT_OR_RETURN(dy_num_axes, 1)\n        << \"number of axes of \\'dy\\' should have be greater than 1, yet get \" << dy_num_axes;\n    CHECK_GE_OR_RETURN(matmul_wx_num_axes, 2)\n        << \"number of axes of \\'matmul_wx\\' should have be greater than 1, yet get \"\n        << matmul_wx_num_axes;\n    CHECK_EQ_OR_RETURN(dy_num_axes, matmul_wx_num_axes)\n        << \"number of axes of \\'matmul_wx\\' (\" << matmul_wx_num_axes\n        << \") should equal to the one of \\'dy\\' (\" << dy_num_axes << \")\";\n\n    // check input shapes of dy and matmul_wx\n    for (uint64_t i = 0; i < dy_num_axes - 1; i++) {\n      size_t dy_size = dy_shape.At(i);\n      size_t matmul_wx_size = matmul_wx_shape.At(i);\n      CHECK_EQ_OR_RETURN(dy_size, matmul_wx_size)\n          << \"dimension \" << i << \"of \\'dy\\'(\" << dy_size << \") and \\'matmul_wx\\'(\"\n          << matmul_wx_size << \") is not consistent\";\n    }\n    if (is_split_mode) {\n      CHECK_EQ_OR_RETURN(dy_shape.At(dy_num_axes - 1), matmul_wx_shape.At(matmul_wx_num_axes - 1))\n          << \"last dimension of \\'dy\\'(\" << dy_shape.At(dy_num_axes - 1) << \") and \\'matmul_wx\\'(\"\n          << matmul_wx_shape.At(matmul_wx_num_axes - 1) << \") is not consistent\";\n    } else {\n      CHECK_EQ_OR_RETURN(2 * dy_shape.At(dy_num_axes - 1),\n                         matmul_wx_shape.At(matmul_wx_num_axes - 1))\n          << \"two times of the last dimension of \\'dy\\'(\" << 2 * (dy_shape.At(dy_num_axes - 1))\n          << \") and \\'matmul_wx\\'(\" << matmul_wx_shape.At(matmul_wx_num_axes - 1)\n          << \") is not consistent\";\n    }\n\n    if (is_split_mode) {\n      // obtain input shape\n      const auto& matmul_vx_shape = *(JUST(matmul_vx)->shape());\n\n      // check number of axes of dy and matmul_vx\n      size_t matmul_vx_num_axes = matmul_vx_shape.NumAxes();\n      CHECK_EQ_OR_RETURN(dy_num_axes, matmul_vx_num_axes)\n          << \"number of axes of \\'matmul_vx\\' (\" << matmul_vx_num_axes\n          << \") should equal to the one of \\'dy\\' (\" << dy_num_axes << \")\";\n\n      // check input shapes of dy and matmul_vx\n      for (uint64_t i = 0; i < dy_num_axes - 1; i++) {\n        size_t dy_size = dy_shape.At(i);\n        size_t matmul_vx_size = matmul_vx_shape.At(i);\n        CHECK_EQ_OR_RETURN(dy_size, matmul_vx_size)\n            << \"dimension \" << i << \"of \\'dy\\'(\" << dy_size << \") and \\'matmul_vx\\'(\"\n            << matmul_vx_size << \") is not consistent\";\n      }\n      CHECK_EQ_OR_RETURN(dy_shape.At(dy_num_axes - 1), matmul_vx_shape.At(matmul_vx_num_axes - 1))\n          << \"last dimension of \\'dy\\'(\" << dy_shape.At(dy_num_axes 
- 1) << \") and \\'matmul_vx\\'(\"\n          << matmul_vx_shape.At(matmul_vx_num_axes - 1) << \") is not consistent\";\n    }\n\n    // set activation attribute\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"activation\");\n    attrs.SetAllAttrs(activation);\n\n    // dispatch corresponding operator\n    if (is_split_mode) {\n      return OpInterpUtil::Dispatch<TensorTuple>(*split_op_, {dy, matmul_wx, JUST(matmul_vx)},\n                                                 attrs);\n    } else {\n      return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, matmul_wx}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n  std::shared_ptr<OpExpr> split_op_;\n};\n\nclass FusedMLPGradFunctor {\n public:\n  FusedMLPGradFunctor() {\n#if CUDA_VERSION >= 11060\n    fused_op_.resize(kMaxInputCount /*the maximum number of layers*/);\n    for (int n = 1; n < fused_op_.size(); ++n) {\n      fused_op_[n] = CHECK_JUST(one::OpBuilder(\"cublas_fused_mlp_grad\")\n                                    .Input(\"dy\")\n                                    .Input(\"x\")\n                                    .Input(\"weights\", n)\n                                    .Input(\"cublas_aux\", n)\n                                    .Input(\"hidden\", n)\n                                    .Output(\"d_x\")\n                                    .Output(\"d_biases\", n)\n                                    .Output(\"d_weights\", n)\n                                    .Build());\n    }\n#endif\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,\n                                const std::shared_ptr<one::Tensor>& x, const TensorTuple& weights,\n                                const TensorTuple& cublas_aux, const TensorTuple& hidden,\n                                const std::vector<float>& alpha_list) const {\n    const int64_t weight_size = weights.size();\n    CHECK_EQ_OR_RETURN(alpha_list.size(), weight_size - 1)\n        << \"Alpha list size should be equal to weight_size - 1. 
\";\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"alpha_list\");\n    attrs.SetAllAttrs(alpha_list);\n    TensorTuple input(2 + 3 * weight_size);\n    input[0] = dy;\n    input[1] = x;\n    std::copy(weights.begin(), weights.end(), input.begin() + 2);\n    std::copy(cublas_aux.begin(), cublas_aux.end(), input.begin() + 2 + weight_size);\n    std::copy(hidden.begin(), hidden.end(), input.begin() + 2 + 2 * weight_size);\n#if CUDA_VERSION >= 11060\n    return OpInterpUtil::Dispatch<TensorTuple>(*fused_op_[weight_size], input, attrs);\n#endif\n    UNIMPLEMENTED_THEN_RETURN() << \"Only Support in CUDA_VERSION >= 11060\";\n  }\n\n private:\n#if CUDA_VERSION >= 11060\n  std::vector<std::shared_ptr<OpExpr>> fused_op_;\n#endif\n};\n\n}  // namespace impl\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<impl::ConvBiasGradFunctor>(\"ConvBiasGrad\");\n  m.add_functor<impl::ConvFilterGradFunctor>(\"ConvFilterGrad\");\n  m.add_functor<impl::ConvDataGradFunctor>(\"ConvDataGrad\");\n  m.add_functor<impl::EmbeddingGradFunctor>(\"EmbeddingGrad\");\n  m.add_functor<impl::TFPoolNdGradFunctor>(\"TFPoolNdGrad\");\n  m.add_functor<impl::AdaptivePoolNdGradFunctor>(\"AdaptivePoolNdGrad\");\n  m.add_functor<impl::KLDivLossGradFunctor>(\"KLDivLossGrad\");\n  m.add_functor<impl::KLDivLossTargetGradFunctor>(\"KLDivLossTargetGrad\");\n  m.add_functor<impl::NLLGradFunctor>(\"NLLGrad\");\n  m.add_functor<impl::BinaryCrossEntropyLossGradFunctor>(\"BinaryCrossEntropyLossGrad\");\n  m.add_functor<impl::BinaryCrossEntropyLossTargetGradFunctor>(\"BinaryCrossEntropyLossTargetGrad\");\n  m.add_functor<impl::BinaryCrossEntropyWithLogitsLossGradFunctor>(\n      \"BinaryCrossEntropyWithLogitsLossGrad\");\n  m.add_functor<impl::BinaryCrossEntropyWithLogitsLossTargetGradFunctor>(\n      \"BinaryCrossEntropyWithLogitsLossTargetGrad\");\n  m.add_functor<impl::SparseCrossEntropyGradFunctor>(\"SparseCrossEntropyGrad\");\n  m.add_functor<impl::SparseCrossEntropyMsGradFunctor>(\"SparseCrossEntropyMsGrad\");\n  m.add_functor<impl::SparseSoftmaxCrossEntropyGrad>(\"SparseSoftmaxCrossEntropyGrad\");\n  m.add_functor<impl::SparseSoftmaxCrossEntropyMsGrad>(\"SparseSoftmaxCrossEntropyMsGrad\");\n  m.add_functor<impl::SmoothL1LossGradFunctor>(\"SmoothL1LossGrad\");\n  m.add_functor<impl::CombinedMarginLossGradFunctor>(\"CombinedMarginLossGrad\");\n  m.add_functor<impl::AffineGridGradFunctor>(\"AffineGridGrad\");\n  m.add_functor<impl::GridSampleGradFunctor>(\"GridSampleGrad\");\n  m.add_functor<impl::MaxPoolNdGradFunctor>(\"MaxPoolNdGrad\");\n  m.add_functor<impl::MaxUnpoolNdGradFunctor<1>>(\"MaxUnpool1dGrad\");\n  m.add_functor<impl::MaxUnpoolNdGradFunctor<2>>(\"MaxUnpool2dGrad\");\n  m.add_functor<impl::MaxUnpoolNdGradFunctor<3>>(\"MaxUnpool3dGrad\");\n  m.add_functor<impl::AdaptiveMaxPoolNdGradFunctor>(\"AdaptiveMaxPoolNdGrad\");\n  m.add_functor<impl::PadGradFunctor>(\"PadGrad\");\n  m.add_functor<impl::AvgPoolNdGradFunctor>(\"AvgPoolNdGrad\");\n  m.add_functor<impl::NormalizationGradFunctor>(\"NormalizationGrad\");\n  m.add_functor<impl::NormalizationAddReluGradFunctor>(\"NormalizationAddReluGrad\");\n  m.add_functor<impl::LayerNormGradFunctor>(\"LayerNormGrad\");\n  m.add_functor<impl::LayerNormAffineGradFunctor>(\"LayerNormAffineGrad\");\n  m.add_functor<impl::LayerNormParamGradFunctor>(\"LayerNormParamGrad\");\n  m.add_functor<impl::FuseLayerNormGradFunctor>(\"FuseLayerNormGrad\");\n  m.add_functor<impl::GroupNormGradFunctor>(\"GroupNormGrad\");\n  m.add_functor<impl::GroupNormParamGradFunctor>(\"GroupNormParamGrad\");\n  
m.add_functor<impl::BroadcastMatmulGradBFunctor>(\"BroadcastMatmulGradB\");\n  m.add_functor<impl::CtcLossGradFunctor>(\"CtcLossGrad\");\n  m.add_functor<impl::FusedScaleTrilSoftmaxMaskScaleGradFunctor>(\n      \"FusedScaleTrilSoftmaxMaskScaleGrad\");\n  m.add_functor<impl::FusedScaleMaskSoftmaxGradFunctor>(\"FusedScaleMaskSoftmaxGrad\");\n  m.add_functor<impl::FusedScaleMaskSoftmaxDropoutGradFunctor>(\"FusedScaleMaskSoftmaxDropoutGrad\");\n  m.add_functor<impl::CublasBiasAddReluMatmulGradFunctor>(\"CublasBiasAddReluMatmulGrad\");\n  m.add_functor<impl::CublasMatmulBiasAddGradFunctor>(\"CublasMatmulBiasAddGrad\");\n  m.add_functor<impl::FusedReluDropoutGradFunctor>(\"FusedReluDropoutGrad\");\n  m.add_functor<impl::FusedDotFeatureInteractionGradFunctor>(\"FusedDotFeatureInteractionGrad\");\n  m.add_functor<impl::FusedCrossFeatureInteractionV1GradFunctor>(\n      \"FusedCrossFeatureInteractionV1Grad\");\n  m.add_functor<impl::FusedCrossFeatureInteractionV2GradFunctor>(\n      \"FusedCrossFeatureInteractionV2Grad\");\n  m.add_functor<impl::FusedGluWithoutLinearGradFunctor>(\"FusedGluWithoutLinearGrad\");\n  m.add_functor<impl::FusedMLPGradFunctor>(\"FusedMLPGrad\");\n  m.add_functor<impl::BinaryCrossEntropyWithLogitsReduceMeanLossGradFunctor>(\n      \"BinaryCrossEntropyWithLogitsReduceMeanLossGrad\");\n  m.add_functor<impl::BinaryCrossEntropyWithLogitsReduceMeanLossTargetGradFunctor>(\n      \"BinaryCrossEntropyWithLogitsReduceMeanLossTargetGrad\");\n  m.add_functor<impl::MatrixVectorProductGradAFunctor>(\"MatrixVectorProductGradA\");\n  m.add_functor<impl::MatrixVectorProductGradBFunctor>(\"MatrixVectorProductGradB\");\n  m.add_functor<impl::VectorMatrixProductGradAFunctor>(\"VectorMatrixProductGradA\");\n  m.add_functor<impl::VectorMatrixProductGradBFunctor>(\"VectorMatrixProductGradB\");\n  m.add_functor<impl::DeformConv2dInputGradFunctor>(\"DeformConv2dInputGrad\");\n  m.add_functor<impl::DeformConv2dParamGradFunctor>(\"DeformConv2dParamGrad\");\n  m.add_functor<impl::RMSNormGradFunctor>(\"RMSNormGrad\");\n};\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/quantization.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/functional/impl/binary_functor.h\"\n\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/functional/function_library.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nclass MinMaxObserverFunctor {\n public:\n  MinMaxObserverFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"min_max_observer\")\n                         .Input(\"in\")\n                         .Output(\"scale\")\n                         .Output(\"zero_point\")\n                         .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& in,\n                                const std::string& quantization_formula,\n                                const int32_t& quantization_bit,\n                                const std::string& quantization_scheme,\n                                const bool& per_layer_quantization) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"quantization_formula\", \"quantization_bit\",\n                                                 \"quantization_scheme\", \"per_layer_quantization\");\n    attrs.SetAllAttrs(quantization_formula, quantization_bit, quantization_scheme,\n                      per_layer_quantization);\n    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {in}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass MovingAverageMinMaxObserverFunctor {\n public:\n  MovingAverageMinMaxObserverFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"moving_average_min_max_observer\")\n                         .Input(\"in\")\n                         .Input(\"current_train_step\")\n                         .Input(\"moving_max\")\n                         .Input(\"moving_min\")\n                         .Output(\"scale\")\n                         .Output(\"zero_point\")\n                         .Build());\n  }\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& in,\n                                const std::shared_ptr<one::Tensor>& current_train_step,\n                                const std::shared_ptr<one::Tensor>& moving_max,\n                                const std::shared_ptr<one::Tensor>& moving_min,\n                                const bool& training, const int64_t& stop_update_after_iters,\n                                const std::string& quantization_formula,\n                                const int32_t& quantization_bit,\n                                const std::string& quantization_scheme,\n                                const float& momentum) const {\n    auto& attrs = 
THREAD_CACHED_MUTABLE_ATTR_MAP(\"training\", \"quantization_formula\",\n                                                 \"stop_update_after_iters\", \"quantization_bit\",\n                                                 \"quantization_scheme\", \"momentum\");\n    attrs.SetAllAttrs(training, quantization_formula, stop_update_after_iters, quantization_bit,\n                      quantization_scheme, momentum);\n    return OpInterpUtil::Dispatch<TensorTuple>(\n        *op_, {in, current_train_step, moving_max, moving_min}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FakeQuantizationFunctor {\n public:\n  FakeQuantizationFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"fake_quantization\")\n                         .Input(\"in\")\n                         .Input(\"scale\")\n                         .Input(\"zero_point\")\n                         .Output(\"out\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in,\n                           const std::shared_ptr<one::Tensor>& scale,\n                           const std::shared_ptr<one::Tensor>& zero_point,\n                           const std::string& quantization_formula, const int32_t& quantization_bit,\n                           const std::string& quantization_scheme) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"quantization_formula\", \"quantization_bit\",\n                                                 \"quantization_scheme\");\n    attrs.SetAllAttrs(quantization_formula, quantization_bit, quantization_scheme);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {in, scale, zero_point}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass QuantizationFunctor {\n public:\n  QuantizationFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"quantization\")\n                         .Input(\"in\")\n                         .Input(\"scale\")\n                         .Input(\"zero_point\")\n                         .Output(\"out\")\n                         .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in,\n                           const std::shared_ptr<one::Tensor>& scale,\n                           const std::shared_ptr<one::Tensor>& zero_point,\n                           const std::string quantization_formula, const int32_t& quantization_bit,\n                           const std::string quantization_scheme) const {\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"quantization_formula\", \"quantization_bit\",\n                                                 \"quantization_scheme\");\n    attrs.SetAllAttrs(quantization_formula, quantization_bit, quantization_scheme);\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {in, scale, zero_point}, attrs);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GroupwiseDequantizeFunctor {\n public:\n  GroupwiseDequantizeFunctor() {\n    symmetric_op_ = CHECK_JUST(\n        one::OpBuilder(\"groupwise_dequantize\").Input(\"in\").Input(\"scale\").Output(\"out\").Build());\n    asymmetric_op_ = CHECK_JUST(one::OpBuilder(\"groupwise_dequantize\")\n                                    .Input(\"in\")\n                                    .Input(\"scale\")\n                                    .Input(\"zero\")\n                                    .Output(\"out\")\n                                    .Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in,\n                           const 
std::shared_ptr<one::Tensor>& scale,\n                           const Optional<one::Tensor>& zero, const int32_t& num_bits,\n                           const bool& symmetric, const int64_t& group_dim,\n                           const int64_t& group_size) const {\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"num_bits\", \"symmetric\", \"group_dim\", \"group_size\");\n    CHECK_OR_RETURN(num_bits == 4 || num_bits == 8) << \"num_bits should be 4 or 8.\";\n    CHECK_GE_OR_RETURN(in->shape()->NumAxes(), 1)\n        << \"The number of dimensions for tensor in should be greater than or equal to 1.\";\n    const int64_t regularized_group_dim =\n        group_dim < 0 ? in->shape()->NumAxes() + group_dim : group_dim;\n    CHECK_OR_RETURN(regularized_group_dim >= 0 && regularized_group_dim < in->shape()->NumAxes())\n        << \"group_dim should be in range [-\" << in->shape()->NumAxes() << \",\"\n        << in->shape()->NumAxes() << \").\";\n    const int64_t group_dim_size =\n        in->shape()->At(regularized_group_dim)\n        * (regularized_group_dim == in->shape()->NumAxes() - 1 ? 8 / num_bits : 1);\n    const int64_t regularized_group_size = group_size < 0 ? group_dim_size : group_size;\n    CHECK_OR_RETURN(regularized_group_size > 0 && regularized_group_size <= group_dim_size)\n        << \"group_size should be in range (0,\" << group_dim_size << \"].\";\n    CHECK_EQ_OR_RETURN(group_dim_size % regularized_group_size, 0)\n        << \"group_size should be a divisor of \" << group_dim_size << \".\";\n    const int64_t num_groups = group_dim_size / regularized_group_size;\n    if (symmetric) {\n      CHECK_OR_RETURN(in->dtype()->data_type() == DataType::kUInt8\n                      || in->dtype()->data_type() == DataType::kInt8)\n          << \"The dtype of tensor in should be int8 or uint8.\";\n    } else {\n      CHECK_OR_RETURN(in->dtype()->data_type() == DataType::kUInt8)\n          << \"The dtype of tensor in should be uint8.\";\n    }\n    CHECK_EQ_OR_RETURN(scale->shape()->NumAxes(), in->shape()->NumAxes())\n        << \"The number of dimensions of tensor scale should be equal to tensor in.\";\n    for (int64_t i = 0; i < in->shape()->NumAxes(); ++i) {\n      if (i == regularized_group_dim) {\n        CHECK_EQ_OR_RETURN(scale->shape()->At(i), num_groups)\n            << \"The size of the \" << i << \"-th dimension of tensor scale should be equal to \"\n            << num_groups;\n      } else if (i == in->shape()->NumAxes() - 1) {\n        CHECK_EQ_OR_RETURN(scale->shape()->At(i), in->shape()->At(i) * (8 / num_bits))\n            << \"The size of the \" << i << \"-th dimension of tensor scale should be equal to \"\n            << in->shape()->At(i) * (8 / num_bits) << \".\";\n      } else {\n        CHECK_EQ_OR_RETURN(scale->shape()->At(i), in->shape()->At(i))\n            << \"The size of the \" << i\n            << \"-th dimension of tensor scale should be equal to tensor in.\";\n      }\n    }\n    if (!symmetric) {\n      CHECK_OR_RETURN(zero) << \"When symmetric is False, tensor zero should be specified.\";\n      CHECK_OR_RETURN(JUST(zero)->dtype() == scale->dtype())\n          << \"The dtype of the zero tensor should be the same as the scale \"\n             \"tensor.\";\n      CHECK_OR_RETURN(*JUST(zero)->shape() == *scale->shape())\n          << \"The shape of zero tensor should be equal to tensor scale.\";\n    } else {\n      CHECK_OR_RETURN(!zero) << \"When symmetric is True, tensor zero should be None.\";\n    }\n    attrs.SetAllAttrs(num_bits, symmetric, 
regularized_group_dim, regularized_group_size);\n    if (symmetric) {\n      return OpInterpUtil::Dispatch<Tensor>(*symmetric_op_, {in, scale}, attrs);\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*asymmetric_op_, {in, scale, JUST(zero)}, attrs);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> symmetric_op_;\n  std::shared_ptr<OpExpr> asymmetric_op_;\n};\n\nclass FusedLinearWithGroupwiseQuantizedWeightFunctor {\n public:\n  FusedLinearWithGroupwiseQuantizedWeightFunctor() {\n    symmetric_with_bias_op_ =\n        CHECK_JUST(one::OpBuilder(\"fused_linear_with_groupwise_quantized_weight\")\n                       .Input(\"x\")\n                       .Input(\"w\")\n                       .Input(\"w_scale\")\n                       .Input(\"b\")\n                       .Output(\"out\")\n                       .Build());\n    symmetric_without_bias_op_ =\n        CHECK_JUST(one::OpBuilder(\"fused_linear_with_groupwise_quantized_weight\")\n                       .Input(\"x\")\n                       .Input(\"w\")\n                       .Input(\"w_scale\")\n                       .Output(\"out\")\n                       .Build());\n    asymmetric_with_bias_op_ =\n        CHECK_JUST(one::OpBuilder(\"fused_linear_with_groupwise_quantized_weight\")\n                       .Input(\"x\")\n                       .Input(\"w\")\n                       .Input(\"w_scale\")\n                       .Input(\"w_zero\")\n                       .Input(\"b\")\n                       .Output(\"out\")\n                       .Build());\n    asymmetric_without_bias_op_ =\n        CHECK_JUST(one::OpBuilder(\"fused_linear_with_groupwise_quantized_weight\")\n                       .Input(\"x\")\n                       .Input(\"w\")\n                       .Input(\"w_scale\")\n                       .Input(\"w_zero\")\n                       .Output(\"out\")\n                       .Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           const std::shared_ptr<one::Tensor>& w,\n                           const std::shared_ptr<one::Tensor>& w_scale,\n                           const Optional<one::Tensor>& w_zero, const Optional<one::Tensor>& b,\n                           const int32_t& num_bits, const bool& symmetric, const int64_t& group_dim,\n                           const int64_t& group_size) const {\n    CHECK_GE_OR_RETURN(x->shape()->NumAxes(), 2)\n        << \"The number of dimensions for tensor x should be greater than or equal to 2.\";\n    const int64_t m = x->shape()->Count(0, x->shape()->NumAxes() - 1);\n    const int64_t k = x->shape()->At(x->shape()->NumAxes() - 1);\n    CHECK_OR_RETURN(num_bits == 4 || num_bits == 8) << \"num_bits should be 4 or 8.\";\n    CHECK_EQ_OR_RETURN(w->shape()->NumAxes(), 2)\n        << \"The number of dimensions for tensor w should be equal to 2.\";\n    CHECK_EQ_OR_RETURN(k % (8 / num_bits), 0)\n        << \"The size of the last dimension of x should be a multiple of (8/num_bits).\";\n    CHECK_EQ_OR_RETURN(w->shape()->At(1), k / (8 / num_bits))\n        << \"The size of second dimension of tensor w should be equal to \" << k / (8 / num_bits);\n    const int64_t n = w->shape()->At(0);\n    const int64_t regularized_group_dim =\n        group_dim < 0 ? w->shape()->NumAxes() + group_dim : group_dim;\n    CHECK_OR_RETURN(regularized_group_dim == 0 || regularized_group_dim == 1)\n        << \"group_dim should be in range [-2,2).\";\n    const int64_t group_dim_size = regularized_group_dim == 0 ? 
n : k;\n    const int64_t regularized_group_size = group_size < 0 ? group_dim_size : group_size;\n    CHECK_OR_RETURN(regularized_group_size > 0 && regularized_group_size <= group_dim_size)\n        << \"group_size should be in range (0,\" << group_dim_size << \"].\";\n    CHECK_EQ_OR_RETURN(group_dim_size % regularized_group_size, 0)\n        << \"group_size should be a divisor of \" << group_dim_size << \".\";\n    const int64_t num_groups = group_dim_size / regularized_group_size;\n    if (symmetric) {\n      CHECK_OR_RETURN(w->dtype()->data_type() == DataType::kUInt8\n                      || w->dtype()->data_type() == DataType::kInt8)\n          << \"The dtype of tensor w should be int8 or uint8.\";\n    } else {\n      CHECK_OR_RETURN(w->dtype()->data_type() == DataType::kUInt8)\n          << \"The dtype of tensor w should be uint8.\";\n    }\n    CHECK_EQ_OR_RETURN(w_scale->shape()->NumAxes(), 2)\n        << \"The number of dimensions of tensor w_scale should be equal to 2.\";\n    for (int64_t i = 0; i < 2; ++i) {\n      if (i == regularized_group_dim) {\n        CHECK_EQ_OR_RETURN(w_scale->shape()->At(i), num_groups)\n            << \"The size of the \" << i << \"-th dimension of tensor w_scale should be equal to \"\n            << num_groups;\n      } else if (i == 1) {\n        CHECK_EQ_OR_RETURN(w_scale->shape()->At(i), k)\n            << \"The size of the \" << i << \"-th dimension of tensor w_scale should be equal to \" << k\n            << \".\";\n      } else {\n        CHECK_EQ_OR_RETURN(w_scale->shape()->At(i), w->shape()->At(i))\n            << \"The size of the \" << i\n            << \"-th dimension of tensor w_scale should be equal to tensor w.\";\n      }\n    }\n    CHECK_OR_RETURN(w_scale->dtype() == x->dtype())\n        << \"The dtype of the w_scale tensor should be the same as the x tensor.\";\n    if (!symmetric) {\n      CHECK_OR_RETURN(w_zero) << \"When symmetric is False, tensor w_zero should be specified.\";\n      CHECK_OR_RETURN(JUST(w_zero)->dtype() == w_scale->dtype())\n          << \"The dtype of the w_zero tensor should be the same as the w_scale \"\n             \"tensor.\";\n      CHECK_OR_RETURN(*JUST(w_zero)->shape() == *w_scale->shape())\n          << \"The shape of w_zero tensor should be equal to tensor w_scale.\";\n    } else {\n      CHECK_OR_RETURN(!w_zero) << \"When symmetric is True, tensor w_zero should be None.\";\n    }\n\n    if (b) {\n      CHECK_OR_RETURN(JUST(b)->dtype() == x->dtype())\n          << \"The dtype of the b tensor should be the same as the x tensor.\";\n      CHECK_EQ_OR_RETURN(JUST(b)->shape()->NumAxes(), 1)\n          << \"The number of dimensions for tensor b should be equal to 1.\";\n      CHECK_EQ_OR_RETURN(JUST(b)->shape()->At(0), n)\n          << \"The size of first dimension of tensor b should be equal to the size of first \"\n             \"dimension of tensor w\";\n    }\n\n    if (m > 8) {\n      const auto w_dequantized = JUST(functional::GroupwiseDequantize(\n          w, w_scale, w_zero, num_bits, symmetric, group_dim, group_size));\n      if (b) {\n        return JUST(functional::FusedMatmulBias(x, w_dequantized, JUST(b), Optional<one::Tensor>(),\n                                                1.0, 1.0));\n      } else {\n        return JUST(functional::MatMul(x, w_dequantized, false, true, 1.0));\n      }\n    }\n    auto& attrs =\n        THREAD_CACHED_MUTABLE_ATTR_MAP(\"num_bits\", \"symmetric\", \"group_dim\", \"group_size\");\n\n    attrs.SetAllAttrs(num_bits, symmetric, regularized_group_dim, 
regularized_group_size);\n\n    if (symmetric) {\n      if (b) {\n        return OpInterpUtil::Dispatch<Tensor>(*symmetric_with_bias_op_, {x, w, w_scale, JUST(b)},\n                                              attrs);\n      } else {\n        return OpInterpUtil::Dispatch<Tensor>(*symmetric_without_bias_op_, {x, w, w_scale}, attrs);\n      }\n    } else {\n      if (b) {\n        return OpInterpUtil::Dispatch<Tensor>(*asymmetric_with_bias_op_,\n                                              {x, w, w_scale, JUST(w_zero), JUST(b)}, attrs);\n      } else {\n        return OpInterpUtil::Dispatch<Tensor>(*asymmetric_without_bias_op_,\n                                              {x, w, w_scale, JUST(w_zero)}, attrs);\n      }\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> symmetric_with_bias_op_;\n  std::shared_ptr<OpExpr> symmetric_without_bias_op_;\n  std::shared_ptr<OpExpr> asymmetric_with_bias_op_;\n  std::shared_ptr<OpExpr> asymmetric_without_bias_op_;\n};\n\n}  // namespace impl\n\nONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::FakeQuantizationFunctor>(\"FakeQuantization\"); };\nONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::QuantizationFunctor>(\"Quantization\"); };\nONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::MinMaxObserverFunctor>(\"MinMaxObserver\"); };\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<impl::MovingAverageMinMaxObserverFunctor>(\"MovingAverageMinMaxObserver\");\n};\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<impl::GroupwiseDequantizeFunctor>(\"GroupwiseDequantize\");\n  m.add_functor<impl::FusedLinearWithGroupwiseQuantizedWeightFunctor>(\n      \"FusedLinearWithGroupwiseQuantizedWeight\");\n};\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/random_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/framework/layout.h\"\n#include \"oneflow/core/framework/mutable_attr_map.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/impl/unary_functor.h\"\n#include \"oneflow/core/job/global_mode.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/user/kernels/distributions/common.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nclass BernoulliFunctor {\n public:\n  BernoulliFunctor() {\n    bernoulli_op_ = CHECK_JUST(one::OpBuilder(\"bernoulli\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Symbol<DType>& dtype,\n                           const Optional<one::Generator>& generator, const bool& inplace) const {\n    if (x->is_global()) { JUST(CheckDeviceIdsIsValid(JUST(x->parallel_desc()))); }\n    auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n    gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x));\n    auto& bernoulli_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dtype\", \"seed\", \"p\");\n    // p == -1 means bernoulli op doesn't use p to generate random number\n    bernoulli_attrs.SetAllAttrs(dtype->data_type(), static_cast<int64_t>(gen->current_seed()),\n                                static_cast<double>(-1));\n\n    const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);\n    OpExprInterpContext ctx(bernoulli_attrs, distribution_state);\n    if (inplace) {\n      auto outputs = std::make_shared<TensorTuple>(1);\n      JUST(CheckInplaceValid(x));\n      (*outputs)[0] = x;\n      JUST(OpInterpUtil::Dispatch(*bernoulli_op_, {x}, outputs.get(), ctx));\n      return outputs->at(0);\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*bernoulli_op_, {x}, ctx);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> bernoulli_op_;\n};\n\nclass BernoulliInplaceFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Symbol<DType>& dtype,\n                           const Optional<one::Generator>& generator) const {\n    return Bernoulli(x, dtype, generator, true);\n  }\n};\n\nclass BernoulliProbFunctor {\n public:\n  BernoulliProbFunctor() {\n    bernoulli_op_ = CHECK_JUST(one::OpBuilder(\"bernoulli\").Input(\"in\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const double& p,\n                           const Symbol<DType>& dtype, const Optional<one::Generator>& generator,\n                           const bool& inplace) const {\n   
 CHECK_OR_THROW(p >= 0.0 && p <= 1.0) << \"bernoulli expects p to be in [0, 1], but got p=\" << p;\n    if (x->is_global()) { JUST(CheckDeviceIdsIsValid(JUST(x->parallel_desc()))); }\n\n    auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n    gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x));\n    auto& bernoulli_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"dtype\", \"seed\", \"p\");\n    bernoulli_attrs.SetAllAttrs(dtype->data_type(), static_cast<int64_t>(gen->current_seed()), p);\n\n    const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);\n    OpExprInterpContext ctx(bernoulli_attrs, distribution_state);\n    if (inplace) {\n      auto outputs = std::make_shared<TensorTuple>(1);\n      JUST(CheckInplaceValid(x));\n      (*outputs)[0] = x;\n      JUST(OpInterpUtil::Dispatch(*bernoulli_op_, {x}, outputs.get(), ctx));\n      return outputs->at(0);\n    } else {\n      return OpInterpUtil::Dispatch<Tensor>(*bernoulli_op_, {x}, ctx);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> bernoulli_op_;\n};\n\nclass BernoulliProbInplaceFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const double& p,\n                           const Symbol<DType>& dtype,\n                           const Optional<one::Generator>& generator) const {\n    return BernoulliProb(x, p, dtype, generator, true);\n  }\n};\n\nclass InplaceUniformFunctor {\n public:\n  InplaceUniformFunctor() {\n    uniform_op_ = CHECK_JUST(one::OpBuilder(\"uniform\").Output(\"out\").Build());\n    uniform_int_op_ = CHECK_JUST(one::OpBuilder(\"uniform_int\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& from,\n                           const Scalar& to) const {\n    JUST(CheckInplaceValid(x));\n    const Shape& shape = *(x->shape());\n    std::shared_ptr<OpExpr> exec_op;\n    const auto& dtype = x->dtype();\n    bool is_integer = false;\n\n    if (dtype->is_floating_point()) {\n      exec_op = uniform_op_;\n    } else if (dtype->is_integer()) {\n      exec_op = uniform_int_op_;\n      is_integer = true;\n    } else {\n      OF_UNIMPLEMENTED() << \"Only support floating point and integer dtypes.\";\n    }\n    DataType dtype_val = dtype->data_type();\n\n    Optional<Symbol<Device>> device;\n    Optional<Symbol<ParallelDesc>> placement;\n    Optional<Symbol<NdSbp>> nd_sbp;\n\n    auto gen = JUST(one::DefaultAutoGenerator());\n    if (x->is_global()) {\n      JUST(CheckDeviceIdsIsValid(JUST(x->parallel_desc())));\n      placement = JUST(x->parallel_desc());\n      nd_sbp = JUST(x->nd_sbp());\n      gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), placement, nd_sbp));\n    } else {\n      device = JUST(x->device());\n      gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), NullOpt, NullOpt));\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"from\", \"to\", \"shape\", \"dtype\", \"seed\", \"nd_sbp\");\n    Optional<std::vector<std::string>> attr_nd_sbp{NullOpt};\n    if (nd_sbp) { attr_nd_sbp = *JUST(GetNdSbpStrList(JUST(nd_sbp))); }\n    if (is_integer) {\n      attrs.SetAllAttrs(from.Value<int64_t>(), to.Value<int64_t>(), shape, dtype_val,\n                        static_cast<int64_t>(gen->current_seed()), attr_nd_sbp);\n    } else {\n      attrs.SetAllAttrs(from.Value<double>(), to.Value<double>(), shape, dtype_val,\n                        static_cast<int64_t>(gen->current_seed()), attr_nd_sbp);\n    }\n\n    const auto& distribution_state = 
std::make_shared<DistributionKernelState>(gen);\n    OpExprInterpContext ctx(attrs, distribution_state);\n    ctx.parallel_desc = placement;\n    ctx.nd_sbp = nd_sbp;\n    ctx.device = device;\n\n    // dispatch with x as the pre-allocated output so the op writes in place\n    auto outputs = std::make_shared<TensorTuple>(1);\n    (*outputs)[0] = x;\n    JUST(OpInterpUtil::Dispatch(*exec_op, {}, outputs.get(), ctx));\n    return outputs->at(0);\n  }\n\n private:\n  std::shared_ptr<OpExpr> uniform_op_;\n  std::shared_ptr<OpExpr> uniform_int_op_;\n};\n\nclass RandFunctor {\n public:\n  RandFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"uniform\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const Shape& shape, const Optional<Symbol<DType>>& dtype,\n                           const Optional<Symbol<Device>>& device,\n                           const Optional<one::Generator>& generator,\n                           const bool& requires_grad) const {\n    if (GlobalMode::is_enabled()) {\n      auto global_mode_guard = GlobalMode::Guard(false);\n      return JUST(functional::GlobalRand(shape, GetGlobalParallelDescFromDevice(device),\n                                         *JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, generator,\n                                         requires_grad));\n    }\n    DataType dtype_val = GetDefaultDType()->data_type();\n    if (dtype.has_value()) {\n      dtype_val = JUST(dtype)->data_type();\n      if (!JUST(dtype)->is_floating_point()) {\n        OF_UNIMPLEMENTED() << \"Only support floating dtype in rand().\";\n      }\n    }\n\n    auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n    gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), NullOpt, NullOpt));\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"from\", \"to\", \"shape\", \"dtype\", \"seed\");\n    attrs.SetAllAttrs(static_cast<double>(0), static_cast<double>(1), shape, dtype_val,\n                      static_cast<int64_t>(gen->current_seed()));\n\n    const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);\n    OpExprInterpContext ctx(attrs, distribution_state);\n    ctx.device = device;\n    auto result = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {}, ctx));\n    JUST(result->set_requires_grad(requires_grad));\n    return result;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GlobalRandFunctor {\n public:\n  GlobalRandFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"uniform\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const Shape& shape, const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp_tuple,\n                           const Optional<Symbol<DType>>& dtype,\n                           const Optional<one::Generator>& generator,\n                           const bool& requires_grad) const {\n    DataType dtype_val = GetDefaultDType()->data_type();\n    if (dtype.has_value()) {\n      dtype_val = JUST(dtype)->data_type();\n      if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {\n        OF_UNIMPLEMENTED() << \"Only support floating dtype in rand().\";\n      }\n    }\n\n    JUST(CheckDeviceIdsIsValid(placement));\n    const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));\n    auto attr_nd_sbp = *JUST(GetNdSbpStrList(nd_sbp));\n\n    auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n    gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), placement, nd_sbp));\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"from\", \"to\", \"shape\", \"dtype\", \"seed\", \"nd_sbp\");\n   
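 // from = 0.0, to = 1.0: samples are drawn from the uniform distribution on [0, 1).\n   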
 attrs.SetAllAttrs(static_cast<double>(0), static_cast<double>(1), shape, dtype_val,\n                      static_cast<int64_t>(gen->current_seed()), attr_nd_sbp);\n\n    const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);\n    auto result = JUST(OpInterpUtil::Dispatch<Tensor>(\n        *op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, distribution_state)));\n    JUST(result->set_requires_grad(requires_grad));\n    return result;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass RandNFunctor {\n public:\n  Maybe<Tensor> operator()(const Shape& shape, const Optional<Symbol<DType>>& dtype,\n                           const Optional<Symbol<Device>>& device,\n                           const Optional<one::Generator>& generator, const bool& requires_grad,\n                           const Symbol<Layout>& layout) const {\n    if (GlobalMode::is_enabled()) {\n      auto global_mode_guard = GlobalMode::Guard(false);\n      return JUST(functional::GlobalRandN(shape, GetGlobalParallelDescFromDevice(device),\n                                          *JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, generator,\n                                          requires_grad));\n    }\n    if (dtype.has_value() && !JUST(dtype)->is_floating_point()) {\n      OF_UNIMPLEMENTED() << \"Only support floating dtype in randn().\";\n    }\n    const auto& out = Optional<one::Tensor>();\n    return Normal(static_cast<double>(0), static_cast<double>(1), shape, out, dtype, device,\n                  generator, requires_grad);\n  }\n};\n\nclass GlobalRandNFunctor {\n public:\n  Maybe<Tensor> operator()(const Shape& shape, const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp_tuple,\n                           const Optional<Symbol<DType>>& dtype,\n                           const Optional<one::Generator>& generator,\n                           const bool& requires_grad) const {\n    if (dtype.has_value() && !JUST(dtype)->is_floating_point()) {\n      OF_UNIMPLEMENTED() << \"Only support floating dtype in randn().\";\n    }\n    const auto& out = Optional<one::Tensor>();\n    return GlobalNormal(static_cast<double>(0), static_cast<double>(1), shape, out, placement,\n                        sbp_tuple, dtype, generator, requires_grad);\n  }\n};\n\nclass NormalFunctor {\n public:\n  NormalFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"normal\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const float mean, const float std, const Shape& shape,\n                           const Optional<one::Tensor>& out,\n                           const Optional<Symbol<DType>>& optional_dtype,\n                           const Optional<Symbol<Device>>& optional_device,\n                           const Optional<one::Generator>& optional_generator,\n                           const bool requires_grad) const {\n    Symbol<DType> dtype = GetDefaultDType();\n    if (optional_dtype.has_value()) {\n      if (!JUST(optional_dtype)->is_floating_point()) {\n        OF_UNIMPLEMENTED() << \"Only support float and double in normal().\";\n      }\n      dtype = JUST(optional_dtype);\n    }\n    Symbol<Device> device = JUST(Device::New(\"cpu\"));\n    if (optional_device.has_value()) { device = JUST(optional_device); }\n\n    if (out.has_value()) {\n      auto out_tensor = JUST(out);\n\n      CHECK_OR_RETURN(shape == (*out_tensor->shape()))\n          << \"Shape of out_tensor does not match shape. 
\"\n          << \"Expected shape: \" << shape << \", actual shape: \" << *out_tensor->shape();\n\n      Symbol<DType> output_tensor_dtype = out_tensor->dtype();\n      if (optional_dtype.has_value()) {\n        CHECK_OR_RETURN(output_tensor_dtype == dtype)\n            << Error::RuntimeError() << \"data type \" << dtype->name()\n            << \" does not match data type of out parameter \" << output_tensor_dtype->name();\n      }\n      dtype = output_tensor_dtype;\n      Symbol<Device> out_tensor_device = JUST(out_tensor->device());\n      if (optional_device.has_value()) {\n        CHECK_OR_RETURN(out_tensor_device == JUST(optional_device))\n            << Error::RuntimeError() << \"device type \" << device->ToString()\n            << \" does not match device type of out parameter \" << out_tensor_device->ToString();\n      }\n      device = out_tensor_device;\n    }\n\n    auto gen = optional_generator.value_or(JUST(one::DefaultAutoGenerator()));\n    gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), NullOpt, NullOpt));\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"mean\", \"std\", \"shape\", \"dtype\", \"seed\");\n    attrs.SetAllAttrs(static_cast<double>(mean), static_cast<double>(std), shape,\n                      dtype->data_type(), static_cast<int64_t>(gen->current_seed()));\n\n    const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);\n    OpExprInterpContext ctx(attrs, device, distribution_state);\n    if (out.has_value()) {\n      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n      (*outputs)[0] = JUST(out);\n      JUST(OpInterpUtil::Dispatch(*op_, {}, outputs.get(), ctx));\n      return (*outputs)[0];\n    }\n\n    auto result = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {}, ctx));\n    JUST(result->set_requires_grad(requires_grad));\n    return result;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass Normal2Functor {\n public:\n  Maybe<Tensor> operator()(const float mean, const float std, const int32_t shape,\n                           const Optional<one::Tensor>& out,\n                           const Optional<Symbol<DType>>& optional_dtype,\n                           const Optional<Symbol<Device>>& optional_device,\n                           const Optional<one::Generator>& optional_generator,\n                           const bool requires_grad) const {\n    const Shape size = Shape({shape});\n    return Normal(mean, std, size, out, optional_dtype, optional_device, optional_generator,\n                  requires_grad);\n  }\n};\n\nclass InplaceNormalFuctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const float mean, const float std,\n                           const Optional<one::Generator>& optional_generator) const {\n    return Normal(mean, std, *x->shape(), x, x->dtype(), JUST(x->device()), optional_generator,\n                  x->requires_grad());\n  }\n};\n\nclass TensorTensorNormalFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& mean,\n                           const std::shared_ptr<one::Tensor>& std,\n                           const Optional<one::Tensor>& out,\n                           const Optional<one::Generator>& optional_generator,\n                           const bool requires_grad) const {\n    JUST(CheckNormalTensorStd(std));\n    auto out_shape = *JUST(InferUnifiedShapeForBroadcasting({*mean->shape(), *std->shape()}));\n    auto output = JUST(Normal(0, 1, out_shape, out, 
Symbol<DType>(mean->dtype()),\n                              JUST(mean->device()), optional_generator, requires_grad));\n    // mean + output * std\n    JUST(InplaceMul(output, std));\n    JUST(Add(output, mean, 1, true));\n    JUST(output->set_requires_grad(requires_grad));\n    return output;\n  }\n};\n\nclass TensorScalarNormalFunctor {\n public:\n  // TODO: performance optimization: write as a single kernel\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& mean, const float std,\n                           const Optional<one::Tensor>& out,\n                           const Optional<one::Generator>& optional_generator,\n                           const bool requires_grad) const {\n    JUST(CheckNormalTensorStd(std));\n    auto output = JUST(Normal(0, std, *(mean->shape()), out, mean->dtype(), JUST(mean->device()),\n                              optional_generator, requires_grad));\n    JUST(Add(output, mean, 1, true));\n    JUST(output->set_requires_grad(requires_grad));\n    return output;\n  }\n};\n\nclass ScalarTensorNormalFunctor {\n public:\n  // TODO: performance optimization: fuse the multiplication and the addition into a single kernel\n  Maybe<Tensor> operator()(const float mean, const std::shared_ptr<one::Tensor>& std,\n                           const Optional<one::Tensor>& out,\n                           const Optional<one::Generator>& optional_generator,\n                           const bool requires_grad) const {\n    JUST(CheckNormalTensorStd(std));\n    auto output = JUST(Normal(0.0, 1.0, *(std->shape()), out, std->dtype(), JUST(std->device()),\n                              optional_generator, requires_grad));\n    JUST(InplaceMul(output, std));\n    JUST(ScalarAdd(output, mean, 1, true));\n    JUST(output->set_requires_grad(requires_grad));\n    return output;\n  }\n};\n\nclass GlobalNormalFunctor {\n public:\n  GlobalNormalFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"normal\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const float& mean, const float& std, const Shape& shape,\n                           const Optional<one::Tensor>& out, const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp_tuple,\n                           const Optional<Symbol<DType>>& optional_dtype,\n                           const Optional<one::Generator>& optional_generator,\n                           const bool& requires_grad) const {\n    Symbol<DType> dtype = DType::Float();\n    if (optional_dtype.has_value()) {\n      if (!JUST(optional_dtype)->is_floating_point()) {\n        OF_UNIMPLEMENTED() << \"Only support float and double in normal().\";\n      }\n      dtype = JUST(optional_dtype);\n    }\n\n    if (out.has_value()) {\n      auto out_tensor = JUST(out);\n      Symbol<DType> output_tensor_dtype = out_tensor->dtype();\n      if (optional_dtype.has_value()) {\n        CHECK_OR_RETURN(output_tensor_dtype == dtype)\n            << Error::RuntimeError() << \"data type \" << dtype->name()\n            << \" does not match data type of out parameter \" << output_tensor_dtype->name();\n      }\n      dtype = output_tensor_dtype;\n    }\n\n    JUST(CheckDeviceIdsIsValid(placement));\n    const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));\n    auto attr_nd_sbp = *JUST(GetNdSbpStrList(nd_sbp));\n\n    std::shared_ptr<Generator> gen = optional_generator.value_or(JUST(one::DefaultAutoGenerator()));\n    gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), placement, nd_sbp));\n\n    auto& attrs = 
THREAD_CACHED_MUTABLE_ATTR_MAP(\"mean\", \"std\", \"shape\", \"dtype\", \"seed\", \"nd_sbp\");\n    attrs.SetAllAttrs(static_cast<double>(mean), static_cast<double>(std), shape,\n                      dtype->data_type(), static_cast<int64_t>(gen->current_seed()), attr_nd_sbp);\n\n    const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);\n    OpExprInterpContext ctx(attrs, placement, nd_sbp, distribution_state);\n    if (out.has_value()) {\n      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n      (*outputs)[0] = JUST(out);\n      JUST(OpInterpUtil::Dispatch(*op_, {}, outputs.get(), ctx));\n      return (*outputs)[0];\n    }\n\n    auto result = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {}, ctx));\n    JUST(result->set_requires_grad(requires_grad));\n    return result;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GlobalNormal2Functor {\n public:\n  Maybe<Tensor> operator()(const float& mean, const float& std, const int32_t& shape,\n                           const Optional<one::Tensor>& out, const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp_tuple,\n                           const Optional<Symbol<DType>>& optional_dtype,\n                           const Optional<one::Generator>& optional_generator,\n                           const bool& requires_grad) const {\n    const Shape size = Shape({shape});\n    return GlobalNormal(mean, std, size, out, placement, sbp_tuple, optional_dtype,\n                        optional_generator, requires_grad);\n  }\n};\n\nclass RandnLikeFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input,\n                           const Optional<Symbol<DType>>& dtype,\n                           const Optional<Symbol<Device>>& device,\n                           const Optional<one::Generator>& generator,\n                           const bool& requires_grad) const {\n    return RandN(*input->shape(), dtype.value_or(input->dtype()),\n                 device.value_or(JUST(input->device())), generator, requires_grad,\n                 Layout::Strided());\n  }\n};\n\nclass GlobalRandnLikeFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input,\n                           const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp,\n                           const Optional<Symbol<DType>>& dtype,\n                           const Optional<one::Generator>& generator,\n                           const bool& requires_grad) const {\n    return GlobalRandN(*input->shape(), placement, sbp, dtype.value_or(input->dtype()), generator,\n                       requires_grad);\n  }\n};\n\nclass RandIntFunctor {\n public:\n  RandIntFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"uniform_int\").Output(\"out\").Build()); }\n\n  Maybe<Tensor> operator()(const int64_t low, const int64_t high, const Shape& shape,\n                           const Optional<Symbol<DType>>& dtype,\n                           const Optional<Symbol<Device>>& device,\n                           const Optional<one::Generator>& generator,\n                           const bool& requires_grad) const {\n    if (GlobalMode::is_enabled()) {\n      auto global_mode_gurad = GlobalMode::Guard(false);\n      return JUST(functional::GlobalRandInt(\n          low, high, shape, GetGlobalParallelDescFromDevice(device),\n          *JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, generator, 
requires_grad));\n    }\n\n    DataType dtype_val = DataType::kInt64;\n    if (dtype) { dtype_val = JUST(dtype)->data_type(); }\n\n    auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n    gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), NullOpt, NullOpt));\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"shape\", \"from\", \"to\", \"dtype\", \"seed\");\n    attrs.SetAllAttrs(shape, low, high, dtype_val, static_cast<int64_t>(gen->current_seed()));\n\n    const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);\n    OpExprInterpContext ctx(attrs, distribution_state);\n    ctx.device = device;\n    auto result = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {}, ctx));\n    JUST(result->set_requires_grad(requires_grad));\n    return result;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass RandInt2Functor {\n public:\n  Maybe<Tensor> operator()(const int64_t high, const Shape& shape,\n                           const Optional<Symbol<DType>>& dtype,\n                           const Optional<Symbol<Device>>& device,\n                           const Optional<one::Generator>& generator,\n                           const bool& requires_grad) const {\n    return RandInt(/*low*/ 0, high, shape, dtype, device, generator, requires_grad);\n  }\n};\n\nclass RandIntLikeFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input, const int64_t low,\n                           const int64_t high, const Optional<Symbol<DType>>& dtype,\n                           const Optional<Symbol<Device>>& device,\n                           const Optional<one::Generator>& generator,\n                           const bool& requires_grad) const {\n    const Shape shape = *input->shape();\n    return RandInt(low, high, shape, dtype, device, generator, requires_grad);\n  }\n};\n\nclass RandIntLike2Functor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input, const int64_t high,\n                           const Optional<Symbol<DType>>& dtype,\n                           const Optional<Symbol<Device>>& device,\n                           const Optional<one::Generator>& generator,\n                           const bool& requires_grad) const {\n    const Shape shape = *input->shape();\n    return RandInt(/*low*/ 0, high, shape, dtype, device, generator, requires_grad);\n  }\n};\n\nclass GlobalRandIntFunctor {\n public:\n  GlobalRandIntFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"uniform_int\").Output(\"out\").Build()); }\n\n  Maybe<Tensor> operator()(const int64_t low, const int64_t high, const Shape& shape,\n                           const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp,\n                           const Optional<Symbol<DType>>& dtype,\n                           const Optional<one::Generator>& generator,\n                           const bool& requires_grad) const {\n    JUST(CheckDeviceIdsIsValid(placement));\n    DataType dtype_val = DataType::kInt64;\n    if (dtype) { dtype_val = JUST(dtype)->data_type(); }\n\n    const auto& nd_sbp = JUST(GetNdSbp(sbp));\n    auto attr_nd_sbp = *JUST(GetNdSbpStrList(nd_sbp));\n\n    auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n    gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), placement, nd_sbp));\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"shape\", \"from\", \"to\", \"dtype\", \"seed\", \"nd_sbp\");\n    attrs.SetAllAttrs(shape, low, 
high, dtype_val, static_cast<int64_t>(gen->current_seed()),\n                      attr_nd_sbp);\n\n    const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);\n    auto result = JUST(OpInterpUtil::Dispatch<Tensor>(\n        *op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, distribution_state)));\n    JUST(result->set_requires_grad(requires_grad));\n    return result;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass GlobalRandInt2Functor {\n public:\n  Maybe<Tensor> operator()(const int64_t high, const Shape& shape,\n                           const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp,\n                           const Optional<Symbol<DType>>& dtype,\n                           const Optional<one::Generator>& generator,\n                           const bool& requires_grad) const {\n    JUST(CheckDeviceIdsIsValid(placement));\n    return GlobalRandInt(/*low*/ 0, high, shape, placement, sbp, dtype, generator, requires_grad);\n  }\n};\n\nclass GlobalRandIntLikeFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input, const int64_t low,\n                           const int64_t high, const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp,\n                           const Optional<Symbol<DType>>& dtype,\n                           const Optional<one::Generator>& generator,\n                           const bool& requires_grad) const {\n    const Shape shape = *input->shape();\n    return GlobalRandInt(low, high, shape, placement, sbp, dtype, generator, requires_grad);\n  }\n};\n\nclass GlobalRandIntLike2Functor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input, const int64_t high,\n                           const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp,\n                           const Optional<Symbol<DType>>& dtype,\n                           const Optional<one::Generator>& generator,\n                           const bool& requires_grad) const {\n    const Shape shape = *input->shape();\n    return GlobalRandInt(/*low*/ 0, high, shape, placement, sbp, dtype, generator, requires_grad);\n  }\n};\n\nclass RandPermFunctor {\n public:\n  RandPermFunctor() { randperm_op_ = CHECK_JUST(one::OpBuilder(\"randperm\").Output(\"out\").Build()); }\n  Maybe<Tensor> operator()(const int32_t n, const Optional<one::Generator>& generator,\n                           const Symbol<DType>& dtype, const Optional<Symbol<Device>>& device,\n                           const bool& requires_grad) const {\n    if (GlobalMode::is_enabled()) {\n      auto global_mode_guard = GlobalMode::Guard(false);\n      return JUST(functional::GlobalRandPerm(n, GetGlobalParallelDescFromDevice(device),\n                                             *JUST(GetSbpList(GlobalMode::nd_sbp())), generator,\n                                             dtype, requires_grad));\n    }\n\n    auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n    gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), NullOpt, NullOpt));\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"n\", \"seed\");\n    attrs.SetAllAttrs(n, static_cast<int64_t>(gen->current_seed()));\n\n    const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);\n    OpExprInterpContext ctx(attrs, distribution_state);\n    ctx.device = device;\n    
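// The op is dispatched without a dtype attr; the result is cast to the\n    // user-requested dtype after dispatch (see the Cast call below).\n    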
auto result = JUST(OpInterpUtil::Dispatch<Tensor>(*randperm_op_, {}, ctx));\n    JUST(result->set_requires_grad(requires_grad));\n    return functional::Cast(result, dtype, /*pin_memory=*/false);\n  }\n\n private:\n  std::shared_ptr<OpExpr> randperm_op_;\n};\n\nclass GlobalRandPermFunctor {\n public:\n  GlobalRandPermFunctor() {\n    randperm_op_ = CHECK_JUST(one::OpBuilder(\"randperm\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const int32_t n, const Symbol<ParallelDesc>& placement,\n                           const std::vector<Symbol<SbpParallel>>& sbp_tuple,\n                           const Optional<one::Generator>& generator, const Symbol<DType>& dtype,\n                           const bool& requires_grad) const {\n    JUST(CheckDeviceIdsIsValid(placement));\n    const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));\n    auto attr_nd_sbp = *JUST(GetNdSbpStrList(nd_sbp));\n\n    auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n    gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), placement, nd_sbp));\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"n\", \"seed\", \"nd_sbp\");\n    attrs.SetAllAttrs(n, static_cast<int64_t>(gen->current_seed()), attr_nd_sbp);\n\n    const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);\n    auto result = JUST(OpInterpUtil::Dispatch<Tensor>(\n        *randperm_op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, distribution_state)));\n    JUST(result->set_requires_grad(requires_grad));\n    return functional::Cast(result, dtype, /*pin_memory=*/false);\n  }\n\n private:\n  std::shared_ptr<OpExpr> randperm_op_;\n};\n\nclass ExponentialFunctor {\n public:\n  ExponentialFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"exponential\").Output(\"out\").Build()); }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const float& lambd,\n                           const Optional<one::Generator>& generator) const {\n    DataType dtype_val = x->dtype()->data_type();\n\n    Optional<Symbol<Device>> device;\n    Optional<Symbol<ParallelDesc>> placement;\n    Optional<Symbol<NdSbp>> nd_sbp;\n\n    auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n    if (x->is_global()) {\n      JUST(CheckDeviceIdsIsValid(JUST(x->parallel_desc())));\n      placement = JUST(x->parallel_desc());\n      nd_sbp = JUST(x->nd_sbp());\n      gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), placement, nd_sbp));\n    } else {\n      device = JUST(x->device());\n      gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), NullOpt, NullOpt));\n    }\n\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"seed\", \"lambd\", \"dtype\", \"out_shape\", \"nd_sbp\");\n    const Shape& out_shape = *(x->shape());\n    Optional<std::vector<std::string>> attr_nd_sbp{NullOpt};\n    if (nd_sbp) { attr_nd_sbp = *JUST(GetNdSbpStrList(JUST(nd_sbp))); }\n    attrs.SetAllAttrs(static_cast<int64_t>(gen->current_seed()), lambd, dtype_val, out_shape,\n                      attr_nd_sbp);\n\n    const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);\n    OpExprInterpContext ctx(attrs, distribution_state);\n    ctx.device = device;\n    ctx.parallel_desc = placement;\n    ctx.nd_sbp = nd_sbp;\n\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    outputs->at(0) = x;\n    JUST(OpInterpUtil::Dispatch(*op_, {}, outputs.get(), ctx));\n    return outputs->at(0);\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\n// 
NOTE(Liang Depeng): The implementation of MultinomialFunctor is modified from\n//                    https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Distributions.cpp#L548\nclass MultinomialFunctor {\n public:\n  MultinomialFunctor() {\n    op_cpu_ =\n        CHECK_JUST(one::OpBuilder(\"multinomial_with_replacement\").Input(\"x\").Output(\"out\").Build());\n    op_gpu_ = CHECK_JUST(one::OpBuilder(\"multinomial_with_replacement\")\n                             .Input(\"x\")\n                             .Input(\"prefix_sum\")\n                             .Output(\"out\")\n                             .Build());\n    op_npu_ =\n        CHECK_JUST(one::OpBuilder(\"multinomial_with_replacement\").Input(\"x\").Output(\"out\").Build());\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int& num_samples,\n                           const bool& replacement,\n                           const Optional<one::Generator>& generator) const {\n    CHECK_OR_RETURN(x->ndim() > 0 && x->ndim() <= 2)\n        << \"The input probability tensor must be 1 or 2 dim, \"\n        << \"but got: \" << x->ndim();\n    CHECK_OR_RETURN(x->dtype()->is_floating_point())\n        << \"multinomial only supports floating-point dtypes for input, but got: \"\n        << x->dtype()->name();\n    CHECK_OR_RETURN(num_samples > 0) << \"cannot sample num_samples <= 0 samples\";\n    int64_t num_categories = x->dim(x->ndim() - 1);\n    CHECK_OR_RETURN(replacement || num_samples <= num_categories)\n        << \"cannot sample num_samples > prob_dist.size(-1) samples without replacement\";\n\n    /* The largest consecutive integer representable in float32 (2^24) */\n    constexpr int64_t FLOAT32_MAX_CONSECUTIVE_INT = 1 << (FLT_MANT_DIG);\n    // Since the index tensor is float, numCategories cannot exceed max float integer precision\n    CHECK_OR_RETURN(num_categories <= FLOAT32_MAX_CONSECUTIVE_INT)\n        << \"number of categories cannot exceed 2^24\";\n\n    DeviceType input_device = DeviceType::kCPU;\n    if (x->is_global()) {\n      JUST(CheckDeviceIdsIsValid(JUST(x->parallel_desc())));\n      input_device = JUST(x->parallel_desc())->device_type();\n    } else {\n      input_device = JUST(x->device())->enum_type();\n    }\n    // Fast-path for no replacement.\n    // Reference:\n    // https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503\n    if (!replacement && input_device != DeviceType::kNPU) {\n      // The algorithm is from gumbel softmax.\n      // s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1)\n      // Here we can apply exp to the formula which will not affect result of\n      // argmax or topk. Then we have\n      // s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1).\n      // We can also simplify the formula above by\n      // s = argmax( p / q ) where q ~ Exp(1)\n      std::shared_ptr<Tensor> q =\n          JUST(functional::Empty(*(x->shape()), x->dtype(), JUST(x->device()),\n                                 /*requires_grad=*/x->requires_grad(), /*pin_memory=*/false));\n      q = JUST(functional::Exponential(q, 1, generator));\n      // In theory the probability to generate 0 from exponential distribution is\n      // 0. However, on CUDA side there is a protection to avoid 0s, but on CPU\n      // side, there is a very low probability to generate 0 from\n      // exponential<double>. The probability is about 2^(-DBL_MANT_DIG). 
We just\n      // ignore it here, but there may be some risk of getting invalid output on CPU.\n      q = JUST(functional::Div(x, q));\n      std::shared_ptr<Tensor> result;\n      if (num_samples == 1) {\n        result = JUST(functional::ArgMax(q, -1, true, JUST(DType::Get(DataType::kInt64))));\n      } else {\n        std::shared_ptr<TensorTuple> temp =\n            JUST(functional::TopK(q, num_samples, -1,\n                                  /*largest=*/true, /*sorted=*/true));\n        result = (*temp)[1];\n      }\n      return result;\n    }\n\n    auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));\n    gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x));\n    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(\"seed\", \"num_samples\", \"replacement\");\n    attrs.SetAllAttrs(static_cast<int64_t>(gen->current_seed()), num_samples, replacement);\n\n    const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);\n    OpExprInterpContext ctx(attrs, distribution_state);\n\n    if (input_device == DeviceType::kCPU) {\n      return OpInterpUtil::Dispatch<Tensor>(*op_cpu_, {x}, ctx);\n    } else if (input_device == DeviceType::kNPU) {\n      return OpInterpUtil::Dispatch<Tensor>(*op_npu_, {x}, ctx);\n    } else {\n      std::shared_ptr<Tensor> sum_last_dim = JUST(functional::ReduceSum(x, {-1}, true, NullOpt));\n      std::shared_ptr<Tensor> norm_dist = JUST(functional::Div(x, sum_last_dim));\n      std::shared_ptr<Tensor> prefix_sum = JUST(functional::Cumsum(norm_dist, -1, x->dtype()));\n      return OpInterpUtil::Dispatch<Tensor>(*op_gpu_, {norm_dist, prefix_sum}, ctx);\n    }\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_cpu_;\n  std::shared_ptr<OpExpr> op_gpu_;\n  std::shared_ptr<OpExpr> op_npu_;\n};\n\n}  // namespace impl\n\nusing namespace impl;\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<BernoulliFunctor>(\"Bernoulli\");\n  m.add_functor<BernoulliInplaceFunctor>(\"BernoulliInplace\");\n  m.add_functor<BernoulliProbFunctor>(\"BernoulliProb\");\n  m.add_functor<BernoulliProbInplaceFunctor>(\"BernoulliProbInplace\");\n  m.add_functor<RandPermFunctor>(\"RandPerm\");\n  m.add_functor<GlobalRandPermFunctor>(\"GlobalRandPerm\");\n  m.add_functor<RandFunctor>(\"Rand\");\n  m.add_functor<GlobalRandFunctor>(\"GlobalRand\");\n  m.add_functor<RandNFunctor>(\"RandN\");\n  m.add_functor<GlobalRandNFunctor>(\"GlobalRandN\");\n  m.add_functor<impl::NormalFunctor>(\"Normal\");\n  m.add_functor<impl::Normal2Functor>(\"Normal2\");\n  m.add_functor<impl::TensorTensorNormalFunctor>(\"TensorTensorNormal\");\n  m.add_functor<impl::TensorScalarNormalFunctor>(\"TensorScalarNormal\");\n  m.add_functor<impl::ScalarTensorNormalFunctor>(\"ScalarTensorNormal\");\n  m.add_functor<impl::InplaceNormalFuctor>(\"Normal_\");\n  m.add_functor<impl::GlobalNormalFunctor>(\"GlobalNormal\");\n  m.add_functor<impl::GlobalNormal2Functor>(\"GlobalNormal2\");\n  m.add_functor<RandnLikeFunctor>(\"RandnLike\");\n  m.add_functor<GlobalRandnLikeFunctor>(\"GlobalRandnLike\");\n  m.add_functor<RandIntFunctor, RandInt2Functor>(\"RandInt\");\n  m.add_functor<GlobalRandIntFunctor, GlobalRandInt2Functor>(\"GlobalRandInt\");\n  m.add_functor<RandIntLikeFunctor, RandIntLike2Functor>(\"RandIntLike\");\n  m.add_functor<GlobalRandIntLikeFunctor, GlobalRandIntLike2Functor>(\"GlobalRandIntLike\");\n  m.add_functor<ExponentialFunctor>(\"Exponential\");\n  m.add_functor<MultinomialFunctor>(\"Multinomial\");\n  
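// Registering several functor types under one name (e.g. \"RandInt\" above)\n  // forms an overload set; calls are resolved by the argument signature.\n  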
m.add_functor<InplaceUniformFunctor>(\"InplaceUniform\");\n};\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/rnn_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/error.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/tensor_util.h\"\n#include \"oneflow/core/framework/op_interpreter.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/sequence_function.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n#include \"oneflow/core/functional/impl/unary_functor.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/user/kernels/distributions/common.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n// NOTE(Liang Depeng): The implementation of rnn related functors are modified from\n//                     https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp\nstruct tanh_f {\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& t) const {\n    return JUST(functional::Tanh(t));\n  }\n};\nstruct relu_f {\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& t) const {\n    return JUST(functional::Relu(t, false));\n  }\n};\n\nMaybe<void> check_rnn_cell_forward_input(const std::shared_ptr<one::Tensor>& input,\n                                         int64_t input_size) {\n  CHECK_OR_RETURN(input->shape()->At(1) == input_size)\n      << \"input has inconsistent input_size: got \" << input->shape()->At(1) << \" expected \"\n      << input_size;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> check_rnn_cell_forward_hidden(const std::shared_ptr<one::Tensor>& input,\n                                          const std::shared_ptr<one::Tensor>& hx,\n                                          int64_t hidden_size, int64_t hidden_label) {\n  CHECK_OR_RETURN(input->shape()->At(0) == hx->shape()->At(0))\n      << \"Input batch size \" << input->shape()->At(0) << \" doesn't match hidden\" << hidden_label\n      << \" batch size \" << hx->shape()->At(0);\n\n  CHECK_OR_RETURN(hx->shape()->At(1) == hidden_size)\n      << \"hidden\" << hidden_label << \" has inconsistent hidden_size: got \" << hx->shape()->At(1)\n      << \", expected \" << hidden_size;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> check_attributes(const std::shared_ptr<one::Tensor>& input, const 
TensorTuple& params,\n                             const TensorTuple& hiddens, bool check_dtype = false) {\n  DeviceType input_device{};\n  if (input->is_global()) {\n    input_device = JUST(input->parallel_desc())->device_type();\n  } else {\n    input_device = JUST(input->device())->enum_type();\n  }\n\n  DataType input_dtype = input->dtype()->data_type();\n\n  auto check_tensors = [&](const std::string& name,\n                           const std::shared_ptr<one::Tensor>& t) -> Maybe<void> {\n    DeviceType t_device{};\n    if (t->is_global()) {\n      t_device = JUST(t->parallel_desc())->device_type();\n    } else {\n      t_device = JUST(t->device())->enum_type();\n    }\n\n    CHECK_OR_RETURN(input_device == t_device)\n        << \"Input and \" << name << \" tensors are not at the same device, found input tensor at \"\n        << input_device << \" and \" << name << \" tensor at \" << t_device;\n\n    if (check_dtype) {\n      DataType t_dtype = t->dtype()->data_type();\n      CHECK_OR_RETURN(input_dtype == t_dtype)\n          << \"Input and \" << name << \" tensors are not the same dtype, found input tensor with \"\n          << input_dtype << \" and \" << name << \" tensor with \" << t_dtype;\n    }\n    return Maybe<void>::Ok();\n  };\n\n  for (const auto& h : hiddens) JUST(check_tensors(\"hidden\", h));\n  for (const auto& p : params) JUST(check_tensors(\"parameter\", p));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<Tensor> linear(const std::shared_ptr<one::Tensor>& input,\n                     const std::shared_ptr<one::Tensor>& weight,\n                     const std::shared_ptr<one::Tensor>& bias) {\n  if (bias != nullptr) {\n    TensorTuple weights;\n    weights.emplace_back(weight);\n    TensorTuple biases;\n    biases.emplace_back(bias);\n    return functional::FusedMLP(input, weights, biases, true);\n  } else {\n    return functional::MatMul(input, weight, false, true, 1.0);\n  }\n}\n\nstruct CellParams {\n  CellParams(const std::shared_ptr<one::Tensor> _w_ih,  // NOLINT\n             const std::shared_ptr<one::Tensor> _w_hh,  // NOLINT\n             const std::shared_ptr<one::Tensor> _b_ih,  // NOLINT\n             const std::shared_ptr<one::Tensor> _b_hh,  // NOLINT\n             const std::shared_ptr<one::Tensor> _w_hr)  // NOLINT\n      : w_ih(_w_ih), w_hh(_w_hh), b_ih_(_b_ih), b_hh_(_b_hh), w_hr(_w_hr){};\n\n  const std::shared_ptr<one::Tensor> w_ih;\n  const std::shared_ptr<one::Tensor> w_hh;\n  const std::shared_ptr<one::Tensor> b_ih_;\n  const std::shared_ptr<one::Tensor> b_hh_;\n  const std::shared_ptr<one::Tensor> w_hr;  // only defined for LSTMs with projections\n\n  Maybe<Tensor> matmul_ih(const std::shared_ptr<one::Tensor>& input) const {\n    return functional::MatMul(input, w_ih, false, true, 1.0);\n  }\n\n  Maybe<Tensor> matmul_hh(const std::shared_ptr<one::Tensor>& h) const {\n    return functional::MatMul(h, w_hh, false, true, 1.0);\n  }\n\n  Maybe<Tensor> matmul_hr(const std::shared_ptr<one::Tensor>& h) const {\n    if (w_hr != nullptr) { return functional::MatMul(h, w_hr, false, true, 1.0); }\n    return h;\n  }\n\n  Maybe<Tensor> linear_ih(const std::shared_ptr<one::Tensor>& input) const {\n    return linear(input, w_ih, b_ih_);\n  }\n\n  Maybe<Tensor> linear_hh(const std::shared_ptr<one::Tensor>& h) const {\n    return linear(h, w_hh, b_hh_);\n  }\n\n  const std::shared_ptr<one::Tensor>& b_ih() const { return b_ih_; }\n  const std::shared_ptr<one::Tensor>& b_hh() const { return b_hh_; }\n};\n\n// Parses a flat list of parameter tensors into a list of 
CellParams\nstatic Maybe<std::vector<CellParams>> gather_params(const TensorTuple& params, bool has_biases,\n                                                    bool has_projections = false) {\n  std::vector<CellParams> result;\n  if (has_biases) {\n    if (has_projections) {\n      CHECK_OR_RETURN(params.size() % 5 == 0) << \"got an incorrect number of RNN parameters\";\n      for (size_t i = 0; i < params.size(); i += 5) {\n        result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3], params[i + 4]);\n      }\n    } else {\n      CHECK_OR_RETURN(params.size() % 4 == 0) << \"got an incorrect number of RNN parameters\";\n      for (size_t i = 0; i < params.size(); i += 4) {\n        result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3], nullptr);\n      }\n    }\n  } else {\n    if (has_projections) {\n      CHECK_OR_RETURN(params.size() % 3 == 0) << \"got an incorrect number of RNN parameters\";\n      for (size_t i = 0; i < params.size(); i += 3) {\n        result.emplace_back(params[i], params[i + 1], nullptr, nullptr, params[i + 2]);\n      }\n    } else {\n      CHECK_OR_RETURN(params.size() % 2 == 0) << \"got an incorrect number of RNN parameters\";\n      for (size_t i = 0; i < params.size(); i += 2) {\n        result.emplace_back(params[i], params[i + 1], nullptr, nullptr, nullptr);\n      }\n    }\n  }\n  return result;\n}\n\ntemplate<typename nonlinearity, typename cell_params>\nstruct SimpleCell {\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& hidden, const cell_params& params,\n                           bool pre_compute_input = false) const {\n    std::shared_ptr<one::Tensor> hh = JUST(params.linear_hh(hidden));\n    std::shared_ptr<one::Tensor> output;\n    if (pre_compute_input) {\n      output = JUST(functional::Add(hh, input, 1.0, true));\n    } else {\n      std::shared_ptr<one::Tensor> ih = JUST(params.linear_ih(input));\n      output = JUST(functional::Add(hh, ih, 1.0, true));\n    }\n    return nonlinearity{}(output);\n  }\n};\n\ntemplate<typename cell_params>\nstruct GRUCell {\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& hidden, const cell_params& params,\n                           bool pre_compute_input = false) const {\n    DeviceType input_device{};\n    if (input->is_global()) {\n      input_device = JUST(input->parallel_desc())->device_type();\n    } else {\n      input_device = JUST(input->device())->enum_type();\n    }\n\n    if (input_device == DeviceType::kCUDA) {\n      CHECK_OR_RETURN(!pre_compute_input);\n\n      std::shared_ptr<one::Tensor> igates = JUST(params.matmul_ih(input));\n      std::shared_ptr<one::Tensor> hgates = JUST(params.matmul_hh(hidden));\n\n      std::shared_ptr<TensorTuple> result =\n          JUST(functional::FusedGruCell(igates, hgates, hidden, params.b_ih(), params.b_hh()));\n\n      return (*result)[0];\n    }\n\n    std::shared_ptr<one::TensorTuple> chunked_igates;\n    if (pre_compute_input) {\n      chunked_igates = JUST(functional::Chunk(input, 3, 1));\n    } else {\n      std::shared_ptr<one::Tensor> gates_ih = JUST(params.linear_ih(input));\n      chunked_igates = JUST(functional::Chunk(gates_ih, 3, 1));\n    }\n\n    std::shared_ptr<one::Tensor> tmp = JUST(params.linear_hh(hidden));\n    std::shared_ptr<one::TensorTuple> chunked_hgates = JUST(functional::Chunk(tmp, 3, 1));\n    std::shared_ptr<one::Tensor> 
reset_gate =\n        JUST(functional::Add((*chunked_hgates)[0], (*chunked_igates)[0], 1.0, false));\n    reset_gate = JUST(functional::Sigmoid(reset_gate));\n    std::shared_ptr<one::Tensor> input_gate =\n        JUST(functional::Add((*chunked_hgates)[1], (*chunked_igates)[1], 1.0, false));\n    input_gate = JUST(functional::Sigmoid(input_gate));\n    std::shared_ptr<one::Tensor> new_gate = JUST(functional::Mul((*chunked_hgates)[2], reset_gate));\n    new_gate = JUST(functional::Add((*chunked_igates)[2], new_gate, 1.0, false));\n    new_gate = JUST(functional::Tanh(new_gate));\n    std::shared_ptr<one::Tensor> output = JUST(functional::Sub(hidden, new_gate, 1.0, false));\n    output = JUST(functional::Mul(output, input_gate));\n    output = JUST(functional::Add(output, new_gate, 1.0, false));\n    return output;\n  }\n};\n\ntemplate<typename cell_params>\nstruct LSTMCell {\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,\n                                const one::TensorTuple& hidden, const cell_params& params,\n                                bool pre_compute_input = false) const {\n    const std::shared_ptr<Tensor>& hx = hidden[0];\n    const std::shared_ptr<Tensor>& cx = hidden[1];\n\n    DeviceType input_device{};\n    if (input->is_global()) {\n      input_device = JUST(input->parallel_desc())->device_type();\n    } else {\n      input_device = JUST(input->device())->enum_type();\n    }\n\n    if (input_device == DeviceType::kCUDA) {\n      CHECK_OR_RETURN(!pre_compute_input);\n\n      std::shared_ptr<one::Tensor> igates = JUST(params.matmul_ih(input));\n      std::shared_ptr<one::Tensor> hgates = JUST(params.matmul_hh(hx));\n\n      std::shared_ptr<TensorTuple> result =\n          JUST(functional::FusedLstmCell(igates, hgates, cx, params.b_ih(), params.b_hh()));\n\n      auto outputs = std::make_shared<TensorTuple>(2);\n      (*outputs)[0] = JUST(params.matmul_hr((*result)[0]));\n      (*outputs)[1] = (*result)[1];\n      return outputs;\n    }\n\n    std::shared_ptr<one::Tensor> gates = JUST(params.linear_hh(hx));\n    if (pre_compute_input) {\n      gates = JUST(functional::Add(gates, input, 1.0, true));\n    } else {\n      std::shared_ptr<one::Tensor> gates_ih = JUST(params.linear_ih(input));\n      gates = JUST(functional::Add(gates, gates_ih, 1.0, true));\n    }\n    std::shared_ptr<one::TensorTuple> chunked_gates = JUST(functional::Chunk(gates, 4, 1));\n    std::shared_ptr<one::Tensor> ingate = JUST(functional::Sigmoid((*chunked_gates)[0]));\n    std::shared_ptr<one::Tensor> forgetgate = JUST(functional::Sigmoid((*chunked_gates)[1]));\n    std::shared_ptr<one::Tensor> cellgate = JUST(functional::Tanh((*chunked_gates)[2]));\n    std::shared_ptr<one::Tensor> outgate = JUST(functional::Sigmoid((*chunked_gates)[3]));\n    std::shared_ptr<one::Tensor> cy = JUST(functional::Mul(forgetgate, cx));\n    cellgate = JUST(functional::Mul(ingate, cellgate));\n    cy = JUST(functional::Add(cy, cellgate, 1.0, true));\n    std::shared_ptr<one::Tensor> tanh_cy = JUST(functional::Tanh(cy));\n    std::shared_ptr<one::Tensor> hy = JUST(functional::Mul(outgate, tanh_cy));\n    auto outputs = std::make_shared<TensorTuple>(2);\n    (*outputs)[0] = JUST(params.matmul_hr(hy));\n    (*outputs)[1] = cy;\n    return outputs;\n  }\n};\n\nclass RnnTanhCellFunctor {\n public:\n  RnnTanhCellFunctor() {}\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& hx,\n                           const 
std::shared_ptr<one::Tensor>& w_ih,\n                           const std::shared_ptr<one::Tensor>& w_hh,\n                           const Optional<one::Tensor>& b_ih,\n                           const Optional<one::Tensor>& b_hh) const {\n    JUST(check_rnn_cell_forward_input(input, w_ih->shape()->At(1)));\n    JUST(check_rnn_cell_forward_hidden(input, hx, w_hh->shape()->At(1), 0));\n    std::shared_ptr<one::Tensor> bias_ih = nullptr;\n    std::shared_ptr<one::Tensor> bias_hh = nullptr;\n    if (b_ih.has_value() && b_hh.has_value()) {\n      bias_ih = JUST(b_ih);\n      bias_hh = JUST(b_hh);\n    }\n    return SimpleCell<tanh_f, CellParams>{}(input, hx,\n                                            CellParams{w_ih, w_hh, bias_ih, bias_hh, nullptr});\n  }\n};\n\nclass RnnReluCellFunctor {\n public:\n  RnnReluCellFunctor() {}\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& hx,\n                           const std::shared_ptr<one::Tensor>& w_ih,\n                           const std::shared_ptr<one::Tensor>& w_hh,\n                           const Optional<one::Tensor>& b_ih,\n                           const Optional<one::Tensor>& b_hh) const {\n    JUST(check_rnn_cell_forward_input(input, w_ih->shape()->At(1)));\n    JUST(check_rnn_cell_forward_hidden(input, hx, w_hh->shape()->At(1), 0));\n    std::shared_ptr<one::Tensor> bias_ih = nullptr;\n    std::shared_ptr<one::Tensor> bias_hh = nullptr;\n    if (b_ih.has_value() && b_hh.has_value()) {\n      bias_ih = JUST(b_ih);\n      bias_hh = JUST(b_hh);\n    }\n    return SimpleCell<relu_f, CellParams>{}(input, hx,\n                                            CellParams{w_ih, w_hh, bias_ih, bias_hh, nullptr});\n  }\n};\n\nclass GruCellFunctor {\n public:\n  GruCellFunctor() {}\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,\n                           const std::shared_ptr<one::Tensor>& hx,\n                           const std::shared_ptr<one::Tensor>& w_ih,\n                           const std::shared_ptr<one::Tensor>& w_hh,\n                           const Optional<one::Tensor>& b_ih,\n                           const Optional<one::Tensor>& b_hh) const {\n    JUST(check_rnn_cell_forward_input(input, w_ih->shape()->At(1)));\n    JUST(check_rnn_cell_forward_hidden(input, hx, w_hh->shape()->At(1), 0));\n    std::shared_ptr<one::Tensor> bias_ih = nullptr;\n    std::shared_ptr<one::Tensor> bias_hh = nullptr;\n    if (b_ih.has_value() && b_hh.has_value()) {\n      bias_ih = JUST(b_ih);\n      bias_hh = JUST(b_hh);\n    }\n    return GRUCell<CellParams>{}(input, hx, CellParams{w_ih, w_hh, bias_ih, bias_hh, nullptr});\n  }\n};\n\nclass LstmCellFunctor {\n public:\n  LstmCellFunctor() {}\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,\n                                const one::TensorTuple& hx,\n                                const std::shared_ptr<one::Tensor>& w_ih,\n                                const std::shared_ptr<one::Tensor>& w_hh,\n                                const Optional<one::Tensor>& b_ih,\n                                const Optional<one::Tensor>& b_hh) const {\n    CHECK_OR_RETURN(hx.size() == 2) << \"lstm_cell expects two hidden states\";\n    JUST(check_rnn_cell_forward_input(input, w_ih->shape()->At(1)));\n    auto hidden_size = w_hh->shape()->At(1);\n    JUST(check_rnn_cell_forward_hidden(input, hx[0], hidden_size, 0));\n    JUST(check_rnn_cell_forward_hidden(input, hx[1], 
hidden_size, 1));\n    std::shared_ptr<one::Tensor> bias_ih = nullptr;\n    std::shared_ptr<one::Tensor> bias_hh = nullptr;\n    if (b_ih.has_value() && b_hh.has_value()) {\n      bias_ih = JUST(b_ih);\n      bias_hh = JUST(b_hh);\n    }\n    return LSTMCell<CellParams>{}(input, hx, CellParams{w_ih, w_hh, bias_ih, bias_hh, nullptr});\n  }\n};\n\nclass FusedGruCellFunctor {\n public:\n  FusedGruCellFunctor() {\n    op_with_bias_ = CHECK_JUST(one::OpBuilder(\"fused_gru_cell\")\n                                   .Input(\"input_gates\")\n                                   .Input(\"hidden_gates\")\n                                   .Input(\"hx\")\n                                   .Input(\"input_bias\")\n                                   .Input(\"hidden_bias\")\n                                   .Output(\"hy\")\n                                   .Output(\"workspace\")\n                                   .Build());\n    op_without_bias_ = CHECK_JUST(one::OpBuilder(\"fused_gru_cell\")\n                                      .Input(\"input_gates\")\n                                      .Input(\"hidden_gates\")\n                                      .Input(\"hx\")\n                                      .Output(\"hy\")\n                                      .Output(\"workspace\")\n                                      .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& igates,\n                                const std::shared_ptr<one::Tensor>& hgates,\n                                const std::shared_ptr<one::Tensor>& hx,\n                                const Optional<one::Tensor>& b_ih,\n                                const Optional<one::Tensor>& b_hh) const {\n    std::shared_ptr<TensorTuple> kernel_result;\n    if (b_ih.has_value() && b_hh.has_value()) {\n      kernel_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(\n          *op_with_bias_, {igates, hgates, hx, JUST(b_ih), JUST(b_hh)}));\n    } else {\n      kernel_result =\n          JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_without_bias_, {igates, hgates, hx}));\n    }\n    return kernel_result;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_with_bias_;\n  std::shared_ptr<OpExpr> op_without_bias_;\n};\n\nclass FusedGruCellGradFunctor {\n public:\n  FusedGruCellGradFunctor() {\n    op_with_bias_ = CHECK_JUST(one::OpBuilder(\"fused_gru_cell_grad\")\n                                   .Input(\"grad_hy\")\n                                   .Input(\"workspace\")\n                                   .Output(\"grad_input_gates\")\n                                   .Output(\"grad_hidden_gates\")\n                                   .Output(\"grad_hx\")\n                                   .Output(\"grad_input_bias\")\n                                   .Output(\"grad_hidden_bias\")\n                                   .Build());\n    op_with_bias_without_hx_ = CHECK_JUST(one::OpBuilder(\"fused_gru_cell_grad\")\n                                              .Input(\"grad_hy\")\n                                              .Input(\"workspace\")\n                                              .Output(\"grad_input_gates\")\n                                              .Output(\"grad_hidden_gates\")\n                                              .Output(\"grad_input_bias\")\n                                              .Output(\"grad_hidden_bias\")\n                                              .Build());\n    op_without_bias_ = CHECK_JUST(one::OpBuilder(\"fused_gru_cell_grad\")\n                          
            .Input(\"grad_hy\")\n                                      .Input(\"workspace\")\n                                      .Output(\"grad_input_gates\")\n                                      .Output(\"grad_hidden_gates\")\n                                      .Output(\"grad_hx\")\n                                      .Build());\n    op_without_bias_without_hx_ = CHECK_JUST(one::OpBuilder(\"fused_gru_cell_grad\")\n                                                 .Input(\"grad_hy\")\n                                                 .Input(\"workspace\")\n                                                 .Output(\"grad_input_gates\")\n                                                 .Output(\"grad_hidden_gates\")\n                                                 .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& grad_hy,\n                                const std::shared_ptr<one::Tensor>& workspace, bool has_bias,\n                                bool hx_needs_grad) const {\n    std::shared_ptr<TensorTuple> kernel_result;\n    if (has_bias) {\n      if (hx_needs_grad) {\n        kernel_result =\n            JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_with_bias_, {grad_hy, workspace}));\n      } else {\n        kernel_result = JUST(\n            OpInterpUtil::Dispatch<TensorTuple>(*op_with_bias_without_hx_, {grad_hy, workspace}));\n      }\n    } else {\n      if (hx_needs_grad) {\n        kernel_result =\n            JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_without_bias_, {grad_hy, workspace}));\n      } else {\n        kernel_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_without_bias_without_hx_,\n                                                                 {grad_hy, workspace}));\n      }\n    }\n    return kernel_result;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_with_bias_;\n  std::shared_ptr<OpExpr> op_with_bias_without_hx_;\n  std::shared_ptr<OpExpr> op_without_bias_;\n  std::shared_ptr<OpExpr> op_without_bias_without_hx_;\n};\n\nclass FusedLstmCellFunctor {\n public:\n  FusedLstmCellFunctor() {\n    op_with_bias_ = CHECK_JUST(one::OpBuilder(\"fused_lstm_cell\")\n                                   .Input(\"input_gates\")\n                                   .Input(\"hidden_gates\")\n                                   .Input(\"cx\")\n                                   .Input(\"input_bias\")\n                                   .Input(\"hidden_bias\")\n                                   .Output(\"hy\")\n                                   .Output(\"cy\")\n                                   .Output(\"workspace\")\n                                   .Build());\n    op_without_bias_ = CHECK_JUST(one::OpBuilder(\"fused_lstm_cell\")\n                                      .Input(\"input_gates\")\n                                      .Input(\"hidden_gates\")\n                                      .Input(\"cx\")\n                                      .Output(\"hy\")\n                                      .Output(\"cy\")\n                                      .Output(\"workspace\")\n                                      .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& igates,\n                                const std::shared_ptr<one::Tensor>& hgates,\n                                const std::shared_ptr<one::Tensor>& cx,\n                                const Optional<one::Tensor>& b_ih,\n                                const Optional<one::Tensor>& b_hh) const {\n    
std::shared_ptr<TensorTuple> kernel_result;\n    if (b_ih.has_value() && b_hh.has_value()) {\n      kernel_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(\n          *op_with_bias_, {igates, hgates, cx, JUST(b_ih), JUST(b_hh)}));\n    } else {\n      kernel_result =\n          JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_without_bias_, {igates, hgates, cx}));\n    }\n    return kernel_result;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_with_bias_;\n  std::shared_ptr<OpExpr> op_without_bias_;\n};\n\nclass FusedLstmCellGradFunctor {\n public:\n  FusedLstmCellGradFunctor() {\n    op_with_bias_ = CHECK_JUST(one::OpBuilder(\"fused_lstm_cell_grad\")\n                                   .Input(\"grad_hy\")\n                                   .Input(\"grad_cy\")\n                                   .Input(\"cx\")\n                                   .Input(\"cy\")\n                                   .Input(\"workspace\")\n                                   .Output(\"grad_gates\")\n                                   .Output(\"grad_cx\")\n                                   .Output(\"grad_bias\")\n                                   .Build());\n    op_without_bias_ = CHECK_JUST(one::OpBuilder(\"fused_lstm_cell_grad\")\n                                      .Input(\"grad_hy\")\n                                      .Input(\"grad_cy\")\n                                      .Input(\"cx\")\n                                      .Input(\"cy\")\n                                      .Input(\"workspace\")\n                                      .Output(\"grad_gates\")\n                                      .Output(\"grad_cx\")\n                                      .Build());\n    op_with_bias_no_grad_cx_ = CHECK_JUST(one::OpBuilder(\"fused_lstm_cell_grad\")\n                                              .Input(\"grad_hy\")\n                                              .Input(\"grad_cy\")\n                                              .Input(\"cx\")\n                                              .Input(\"cy\")\n                                              .Input(\"workspace\")\n                                              .Output(\"grad_gates\")\n                                              .Output(\"grad_bias\")\n                                              .Build());\n    op_without_bias_no_grad_cx_ = CHECK_JUST(one::OpBuilder(\"fused_lstm_cell_grad\")\n                                                 .Input(\"grad_hy\")\n                                                 .Input(\"grad_cy\")\n                                                 .Input(\"cx\")\n                                                 .Input(\"cy\")\n                                                 .Input(\"workspace\")\n                                                 .Output(\"grad_gates\")\n                                                 .Build());\n  }\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& grad_hy,\n                                const std::shared_ptr<one::Tensor>& grad_cy,\n                                const std::shared_ptr<one::Tensor>& cx,\n                                const std::shared_ptr<one::Tensor>& cy,\n                                const std::shared_ptr<one::Tensor>& workspace, bool need_cx_grad,\n                                bool has_bias) const {\n    std::shared_ptr<TensorTuple> kernel_result;\n    if (has_bias) {\n      if (need_cx_grad) {\n        kernel_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(\n            *op_with_bias_, {grad_hy, grad_cy, cx, cy, 
workspace}));\n      } else {\n        kernel_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(\n            *op_with_bias_no_grad_cx_, {grad_hy, grad_cy, cx, cy, workspace}));\n      }\n    } else {\n      if (need_cx_grad) {\n        kernel_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(\n            *op_without_bias_, {grad_hy, grad_cy, cx, cy, workspace}));\n      } else {\n        kernel_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(\n            *op_without_bias_no_grad_cx_, {grad_hy, grad_cy, cx, cy, workspace}));\n      }\n    }\n    return kernel_result;\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_with_bias_;\n  std::shared_ptr<OpExpr> op_with_bias_no_grad_cx_;\n  std::shared_ptr<OpExpr> op_without_bias_;\n  std::shared_ptr<OpExpr> op_without_bias_no_grad_cx_;\n};\n\ntemplate<typename cell_type>\nMaybe<TensorTuple> _rnn_impl(const std::shared_ptr<one::Tensor>& input,\n                             const std::shared_ptr<one::Tensor>& hx, const one::TensorTuple& params,\n                             const bool& has_biases, const int32_t& num_layers,\n                             const float& dropout, const bool& train, const bool& bidirectional,\n                             const bool& batch_first) {\n  TensorTuple hiddens;\n  hiddens.emplace_back(hx);\n  JUST(check_attributes(input, params, hiddens));\n\n  std::shared_ptr<one::Tensor> rnn_input = input;\n  if (batch_first) {\n    std::vector<int32_t> dims = {1, 0, 2};\n    rnn_input = JUST(functional::Permute(input, dims));\n  }\n  auto rnn_params = JUST(gather_params(params, has_biases));\n  std::shared_ptr<TensorTuple> rnn_hiddens = JUST(functional::Unbind(hx, 0));\n  std::shared_ptr<TensorTuple> rnn_inputs = JUST(functional::Unbind(rnn_input, 0));\n\n  auto generator = JUST(one::DefaultAutoGenerator());\n\n  TensorTuple final_hiddens;\n  if (bidirectional) {\n    std::shared_ptr<TensorTuple> fw_outputs = std::make_shared<TensorTuple>(rnn_inputs->size());\n    std::shared_ptr<TensorTuple> bw_outputs = std::make_shared<TensorTuple>(rnn_inputs->size());\n    for (int32_t l = 0; l < num_layers; ++l) {\n      // forward direction\n      std::shared_ptr<one::Tensor> fw_hidden = (*rnn_hiddens)[l * 2];\n      auto& fw_cell_param = (*rnn_params)[l * 2];\n      for (int32_t i = 0; i < rnn_inputs->size(); ++i) {\n        fw_hidden = JUST(cell_type{}((*rnn_inputs)[i], fw_hidden, fw_cell_param));\n        (*fw_outputs)[i] = fw_hidden;\n      }\n      final_hiddens.emplace_back(fw_hidden);\n\n      // reverse direction\n      std::shared_ptr<one::Tensor> bw_hidden = (*rnn_hiddens)[l * 2 + 1];\n      auto& bw_cell_param = (*rnn_params)[l * 2 + 1];\n      for (int32_t i = rnn_inputs->size() - 1; i >= 0; i--) {\n        bw_hidden = JUST(cell_type{}((*rnn_inputs)[i], bw_hidden, bw_cell_param));\n        (*bw_outputs)[i] = bw_hidden;\n      }\n      final_hiddens.emplace_back(bw_hidden);\n\n      // concat fw_outputs and bw_outputs\n      for (int32_t i = 0; i < rnn_inputs->size(); ++i) {\n        (*rnn_inputs)[i] = JUST(functional::Concat({(*fw_outputs)[i], (*bw_outputs)[i]},\n                                                   bw_hidden->shape()->NumAxes() - 1));\n      }\n\n      if (dropout != 0 && train && l < num_layers - 1) {\n        std::shared_ptr<one::Tensor> stack_res = JUST(functional::Stack(*rnn_inputs, 0));\n        std::shared_ptr<one::Tensor> dropout_res =\n            JUST(functional::Dropout(stack_res, dropout, train, false, generator, nullptr));\n        rnn_inputs = JUST(functional::Unbind(dropout_res, 0));\n      
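  // Re-split the dropped-out stack so the next layer consumes per-step\n        // tensors; dropout is skipped for the final layer (l == num_layers - 1).\n      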
}\n    }\n  } else {\n    for (int32_t l = 0; l < num_layers; ++l) {\n      std::shared_ptr<one::Tensor> hidden = (*rnn_hiddens)[l];\n      auto& cell_param = (*rnn_params)[l];\n      for (int32_t i = 0; i < rnn_inputs->size(); ++i) {\n        hidden = JUST(cell_type{}((*rnn_inputs)[i], hidden, cell_param));\n        (*rnn_inputs)[i] = hidden;\n      }\n      final_hiddens.emplace_back(hidden);\n      if (dropout != 0 && train && l < num_layers - 1) {\n        std::shared_ptr<one::Tensor> stack_res = JUST(functional::Stack(*rnn_inputs, 0));\n        std::shared_ptr<one::Tensor> dropout_res =\n            JUST(functional::Dropout(stack_res, dropout, train, false, generator, nullptr));\n        rnn_inputs = JUST(functional::Unbind(dropout_res, 0));\n      }\n    }\n  }\n\n  TensorTuple output;\n  std::shared_ptr<one::Tensor> output_0 = JUST(functional::Stack(*rnn_inputs, 0));\n  if (batch_first) {\n    std::vector<int32_t> dims = {1, 0, 2};\n    output.emplace_back(JUST(functional::Permute(output_0, dims)));\n  } else {\n    output.emplace_back(output_0);\n  }\n  output.emplace_back(JUST(functional::Stack(final_hiddens, 0)));\n  return output;\n}\n\nclass RnnTanhInputFunctor {\n public:\n  RnnTanhInputFunctor() {}\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,\n                                const std::shared_ptr<one::Tensor>& hx,\n                                const one::TensorTuple& params, const bool& has_biases,\n                                const int32_t& num_layers, const float& dropout, const bool& train,\n                                const bool& bidirectional, const bool& batch_first) const {\n    return _rnn_impl<SimpleCell<tanh_f, CellParams>>(input, hx, params, has_biases, num_layers,\n                                                     dropout, train, bidirectional, batch_first);\n  }\n};\n\nclass RnnReluInputFunctor {\n public:\n  RnnReluInputFunctor() {}\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,\n                                const std::shared_ptr<one::Tensor>& hx,\n                                const one::TensorTuple& params, const bool& has_biases,\n                                const int32_t& num_layers, const float& dropout, const bool& train,\n                                const bool& bidirectional, const bool& batch_first) const {\n    return _rnn_impl<SimpleCell<relu_f, CellParams>>(input, hx, params, has_biases, num_layers,\n                                                     dropout, train, bidirectional, batch_first);\n  }\n};\n\ntemplate<typename cell_type>\nMaybe<TensorTuple> _rnn_pack_sequence_impl(const std::shared_ptr<one::Tensor>& input,\n                                           const std::shared_ptr<one::Tensor>& batch_sizes,\n                                           const std::shared_ptr<one::Tensor>& hx,\n                                           const one::TensorTuple& params, const bool& has_biases,\n                                           const int32_t& num_layers, const float& dropout,\n                                           const bool& train, const bool& bidirectional) {\n  auto rnn_params = JUST(gather_params(params, has_biases));\n  std::shared_ptr<TensorTuple> rnn_hiddens = JUST(functional::Unbind(hx, 0));\n  auto generator = JUST(one::DefaultAutoGenerator());\n\n  TensorTuple final_hiddens;\n\n  std::vector<int64_t> batch_sizes_vec;\n  batch_sizes_vec.resize(batch_sizes->nelement());\n  const auto& callback = [&](ep::Stream* stream,\n                     
        const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n    SyncAutoMemcpy(stream, batch_sizes_vec.data(), eager_blob_object->dptr(),\n                   batch_sizes_vec.size() * sizeof(int64_t), memory::MakeHostMemCase(),\n                   eager_blob_object->mem_case());\n  };\n  JUST(SyncAccessTensorWithTimeOut(batch_sizes, callback, \"const\"));\n  int64_t num_steps = batch_sizes->shape()->At(0);\n  std::shared_ptr<TensorTuple> rnn_inputs = std::make_shared<TensorTuple>(num_steps);\n  int64_t input_offset = 0;\n  for (int32_t i = 0; i < num_steps; ++i) {\n    const int64_t batch_size = batch_sizes_vec[i];\n    (*rnn_inputs)[i] = JUST(functional::Narrow(input, 0, input_offset, batch_size));\n    input_offset += batch_size;\n  }\n\n  if (bidirectional) {\n    std::shared_ptr<TensorTuple> fw_outputs = std::make_shared<TensorTuple>(rnn_inputs->size());\n    std::shared_ptr<TensorTuple> bw_outputs = std::make_shared<TensorTuple>(rnn_inputs->size());\n    for (int32_t l = 0; l < num_layers; ++l) {\n      // forward direction\n      int64_t last_batch_size = batch_sizes_vec[0];\n      std::shared_ptr<one::Tensor> fw_hidden = (*rnn_hiddens)[l * 2];\n      auto& fw_cell_param = (*rnn_params)[l * 2];\n\n      TensorTuple fw_final_hiddens_for_single_layer;\n      for (int32_t i = 0; i < num_steps; ++i) {\n        const int64_t batch_size = batch_sizes_vec[i];\n        const int64_t dec = last_batch_size - batch_size;\n        if (dec > 0) {\n          fw_final_hiddens_for_single_layer.emplace_back(\n              JUST(functional::Narrow(fw_hidden, 0, last_batch_size - dec, dec)));\n          fw_hidden = JUST(functional::Narrow(fw_hidden, 0, 0, last_batch_size - dec));\n        }\n        last_batch_size = batch_size;\n        fw_hidden = JUST(cell_type{}((*rnn_inputs)[i], fw_hidden, fw_cell_param));\n        (*fw_outputs)[i] = fw_hidden;\n      }\n      fw_final_hiddens_for_single_layer.emplace_back(fw_hidden);\n      std::reverse(fw_final_hiddens_for_single_layer.begin(),\n                   fw_final_hiddens_for_single_layer.end());\n      final_hiddens.emplace_back(JUST(functional::Concat(fw_final_hiddens_for_single_layer, 0)));\n\n      // reverse direction\n      last_batch_size = batch_sizes_vec[num_steps - 1];\n      std::shared_ptr<one::Tensor> bw_hidden =\n          JUST(functional::Narrow((*rnn_hiddens)[l * 2 + 1], 0, 0, last_batch_size));\n      auto& bw_cell_param = (*rnn_params)[l * 2 + 1];\n      // Here the situation is similar to that above, except we start out with\n      // the smallest batch size (and a small set of hidden states we actually use),\n      // and progressively expand the hidden states, as we move backwards over the\n      // 1D list of inputs.\n      for (int64_t i = num_steps - 1; i >= 0; --i) {\n        const int64_t batch_size = batch_sizes_vec[i];\n        const int64_t inc = batch_size - last_batch_size;\n        if (inc > 0) {\n          std::shared_ptr<one::Tensor> hidden_slice = JUST(functional::Narrow(\n              (*rnn_hiddens)[l * 2 + 1], 0, last_batch_size, batch_size - last_batch_size));\n          std::shared_ptr<TensorTuple> tmp = std::make_shared<TensorTuple>(2);\n          (*tmp)[0] = bw_hidden;\n          (*tmp)[1] = hidden_slice;\n          bw_hidden = JUST(functional::Concat(*tmp, 0));\n        }\n        last_batch_size = batch_size;\n        bw_hidden = JUST(cell_type{}((*rnn_inputs)[i], bw_hidden, bw_cell_param));\n        (*bw_outputs)[i] = bw_hidden;\n      }\n\n      final_hiddens.emplace_back(bw_hidden);\n\n      // 
concat fw_outputs and bw_outputs\n      for (int32_t i = 0; i < num_steps; ++i) {\n        (*rnn_inputs)[i] = JUST(functional::Concat({(*fw_outputs)[i], (*bw_outputs)[i]},\n                                                   bw_hidden->shape()->NumAxes() - 1));\n      }\n\n      if (dropout != 0 && train && l < num_layers - 1) {\n        std::shared_ptr<one::Tensor> stack_res = JUST(functional::Concat(*rnn_inputs, 0));\n        std::shared_ptr<one::Tensor> dropout_res =\n            JUST(functional::Dropout(stack_res, dropout, train, false, generator, nullptr));\n        int64_t input_offset = 0;\n        for (int32_t i = 0; i < num_steps; ++i) {\n          const int64_t batch_size = batch_sizes_vec[i];\n          (*rnn_inputs)[i] = JUST(functional::Narrow(dropout_res, 0, input_offset, batch_size));\n          input_offset += batch_size;\n        }\n      }\n    }\n  } else {\n    // Batch sizes is a sequence of decreasing lengths, which are offsets\n    // into a 1D list of inputs. At every step we slice out batch_size elements,\n    // and possibly account for the decrease in the batch size since the last step,\n    // which requires us to slice the hidden state (since some sequences\n    // are completed now). The sliced parts are also saved, because we will need\n    // to return a tensor of final hidden state.\n    for (int32_t l = 0; l < num_layers; ++l) {\n      int64_t last_batch_size = batch_sizes_vec[0];\n      std::shared_ptr<one::Tensor> hidden = (*rnn_hiddens)[l];\n      auto& cell_param = (*rnn_params)[l];\n      TensorTuple final_hiddens_for_single_layer;\n      for (int32_t i = 0; i < num_steps; ++i) {\n        const int64_t batch_size = batch_sizes_vec[i];\n        const int64_t dec = last_batch_size - batch_size;\n        if (dec > 0) {\n          final_hiddens_for_single_layer.emplace_back(\n              JUST(functional::Narrow(hidden, 0, last_batch_size - dec, dec)));\n          hidden = JUST(functional::Narrow(hidden, 0, 0, last_batch_size - dec));\n        }\n        last_batch_size = batch_size;\n        hidden = JUST(cell_type{}((*rnn_inputs)[i], hidden, cell_param));\n        (*rnn_inputs)[i] = hidden;\n      }\n      final_hiddens_for_single_layer.emplace_back(hidden);\n      std::reverse(final_hiddens_for_single_layer.begin(), final_hiddens_for_single_layer.end());\n      final_hiddens.emplace_back(JUST(functional::Concat(final_hiddens_for_single_layer, 0)));\n\n      if (dropout != 0 && train && l < num_layers - 1) {\n        std::shared_ptr<one::Tensor> stack_res = JUST(functional::Concat(*rnn_inputs, 0));\n        std::shared_ptr<one::Tensor> dropout_res =\n            JUST(functional::Dropout(stack_res, dropout, train, false, generator, nullptr));\n        int64_t input_offset = 0;\n        for (int32_t i = 0; i < num_steps; ++i) {\n          const int64_t batch_size = batch_sizes_vec[i];\n          (*rnn_inputs)[i] = JUST(functional::Narrow(dropout_res, 0, input_offset, batch_size));\n          input_offset += batch_size;\n        }\n      }\n    }\n  }\n\n  TensorTuple output;\n  output.emplace_back(JUST(functional::Concat(*rnn_inputs, 0)));\n  output.emplace_back(JUST(functional::Stack(final_hiddens, 0)));\n  return output;\n}\n\nclass RnnTanhDataFunctor {\n public:\n  RnnTanhDataFunctor() {}\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& data,\n                                const std::shared_ptr<one::Tensor>& batch_sizes,\n                                const std::shared_ptr<one::Tensor>& hx,\n                                
const one::TensorTuple& params, const bool& has_biases,\n                                const int32_t& num_layers, const float& dropout, const bool& train,\n                                const bool& bidirectional) const {\n    return _rnn_pack_sequence_impl<SimpleCell<tanh_f, CellParams>>(\n        data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional);\n  }\n};\n\nclass RnnReluDataFunctor {\n public:\n  RnnReluDataFunctor() {}\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& data,\n                                const std::shared_ptr<one::Tensor>& batch_sizes,\n                                const std::shared_ptr<one::Tensor>& hx,\n                                const one::TensorTuple& params, const bool& has_biases,\n                                const int32_t& num_layers, const float& dropout, const bool& train,\n                                const bool& bidirectional) const {\n    return _rnn_pack_sequence_impl<SimpleCell<relu_f, CellParams>>(\n        data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional);\n  }\n};\n\nMaybe<TensorTuple> _lstm_impl(const std::shared_ptr<one::Tensor>& input, const one::TensorTuple& hx,\n                              const one::TensorTuple& params, const bool& has_biases,\n                              const int32_t& num_layers, const float& dropout, const bool& train,\n                              const bool& bidirectional, const bool& batch_first) {\n  CHECK_OR_RETURN(hx.size() == 2) << \"lstm expects two hidden states\";\n  // if cells are of different size, that means projections are used\n  bool has_projections = (hx[0]->shape()->At(2) != hx[1]->shape()->At(2));\n  JUST(check_attributes(input, params, hx));\n  std::shared_ptr<one::Tensor> rnn_input = input;\n  if (batch_first) {\n    std::vector<int32_t> dims = {1, 0, 2};\n    rnn_input = JUST(functional::Permute(input, dims));\n  }\n  auto rnn_params = JUST(gather_params(params, has_biases, has_projections));\n\n  std::shared_ptr<TensorTuple> layer_hxs = JUST(functional::Unbind(hx[0], 0));\n  std::shared_ptr<TensorTuple> layer_cxs = JUST(functional::Unbind(hx[1], 0));\n  std::shared_ptr<TensorTuple> rnn_inputs = JUST(functional::Unbind(rnn_input, 0));\n\n  auto generator = JUST(one::DefaultAutoGenerator());\n\n  TensorTuple final_hy;\n  TensorTuple final_cy;\n\n  if (bidirectional) {\n    std::shared_ptr<TensorTuple> fw_outputs = std::make_shared<TensorTuple>(rnn_inputs->size());\n    std::shared_ptr<TensorTuple> lstm_cell_out = std::make_shared<TensorTuple>(2);\n    std::shared_ptr<TensorTuple> bw_outputs = std::make_shared<TensorTuple>(rnn_inputs->size());\n\n    for (int32_t l = 0; l < num_layers; ++l) {\n      // forward direction\n      (*lstm_cell_out)[0] = (*layer_hxs)[l * 2];\n      (*lstm_cell_out)[1] = (*layer_cxs)[l * 2];\n      auto& fw_cell_param = (*rnn_params)[l * 2];\n      for (int32_t i = 0; i < rnn_inputs->size(); ++i) {\n        lstm_cell_out =\n            JUST(LSTMCell<CellParams>{}((*rnn_inputs)[i], *lstm_cell_out, fw_cell_param));\n        (*fw_outputs)[i] = (*lstm_cell_out)[0];\n      }\n      final_hy.emplace_back((*lstm_cell_out)[0]);\n      final_cy.emplace_back((*lstm_cell_out)[1]);\n\n      // reverse direction\n      (*lstm_cell_out)[0] = (*layer_hxs)[l * 2 + 1];\n      (*lstm_cell_out)[1] = (*layer_cxs)[l * 2 + 1];\n      auto& bw_cell_param = (*rnn_params)[l * 2 + 1];\n      for (int32_t i = rnn_inputs->size() - 1; i >= 0; i--) {\n        lstm_cell_out =\n            
JUST(LSTMCell<CellParams>{}((*rnn_inputs)[i], *lstm_cell_out, bw_cell_param));\n        (*bw_outputs)[i] = (*lstm_cell_out)[0];\n      }\n      final_hy.emplace_back((*lstm_cell_out)[0]);\n      final_cy.emplace_back((*lstm_cell_out)[1]);\n\n      // concat fw_outputs and bw_outputs\n      for (int32_t i = 0; i < rnn_inputs->size(); ++i) {\n        (*rnn_inputs)[i] = JUST(functional::Concat({(*fw_outputs)[i], (*bw_outputs)[i]},\n                                                   (*bw_outputs)[0]->shape()->NumAxes() - 1));\n      }\n\n      if (dropout != 0 && train && l < num_layers - 1) {\n        std::shared_ptr<one::Tensor> stack_res = JUST(functional::Stack(*rnn_inputs, 0));\n        std::shared_ptr<one::Tensor> dropout_res =\n            JUST(functional::Dropout(stack_res, dropout, train, false, generator, nullptr));\n        rnn_inputs = JUST(functional::Unbind(dropout_res, 0));\n      }\n    }\n\n  } else {\n    std::shared_ptr<TensorTuple> lstm_cell_out = std::make_shared<TensorTuple>(2);\n\n    for (int32_t l = 0; l < num_layers; ++l) {\n      auto& cell_param = (*rnn_params)[l];\n      (*lstm_cell_out)[0] = (*layer_hxs)[l];\n      (*lstm_cell_out)[1] = (*layer_cxs)[l];\n      for (int32_t i = 0; i < rnn_inputs->size(); ++i) {\n        lstm_cell_out = JUST(LSTMCell<CellParams>{}((*rnn_inputs)[i], *lstm_cell_out, cell_param));\n        (*rnn_inputs)[i] = (*lstm_cell_out)[0];\n      }\n      final_hy.emplace_back((*lstm_cell_out)[0]);\n      final_cy.emplace_back((*lstm_cell_out)[1]);\n\n      if (dropout != 0 && train && l < num_layers - 1) {\n        std::shared_ptr<one::Tensor> stack_res = JUST(functional::Stack(*rnn_inputs, 0));\n        std::shared_ptr<one::Tensor> dropout_res =\n            JUST(functional::Dropout(stack_res, dropout, train, false, generator, nullptr));\n        rnn_inputs = JUST(functional::Unbind(dropout_res, 0));\n      }\n    }\n  }\n\n  TensorTuple output;\n  std::shared_ptr<one::Tensor> output_0 = JUST(functional::Stack(*rnn_inputs, 0));\n  if (batch_first) {\n    std::vector<int32_t> dims = {1, 0, 2};\n    output.emplace_back(JUST(functional::Permute(output_0, dims)));\n  } else {\n    output.emplace_back(output_0);\n  }\n  output.emplace_back(JUST(functional::Stack(final_hy, 0)));\n  output.emplace_back(JUST(functional::Stack(final_cy, 0)));\n  return output;\n}\n\nclass LstmInputFunctor {\n public:\n  LstmInputFunctor() {}\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,\n                                const one::TensorTuple& hx, const one::TensorTuple& params,\n                                const bool& has_biases, const int32_t& num_layers,\n                                const float& dropout, const bool& train, const bool& bidirectional,\n                                const bool& batch_first) const {\n    return _lstm_impl(input, hx, params, has_biases, num_layers, dropout, train, bidirectional,\n                      batch_first);\n  }\n};\n\nMaybe<TensorTuple> _lstm_pack_sequence_impl(const std::shared_ptr<one::Tensor>& input,\n                                            const std::shared_ptr<one::Tensor>& batch_sizes,\n                                            const one::TensorTuple& hx,\n                                            const one::TensorTuple& params, const bool& has_biases,\n                                            const int32_t& num_layers, const float& dropout,\n                                            const bool& train, const bool& bidirectional) {\n  CHECK_OR_RETURN(hx.size() == 2) << 
\"lstm expects two hidden states\";\n  // if cells are of different size, that means projections are used\n  bool has_projections = (hx[0]->shape()->At(2) != hx[1]->shape()->At(2));\n  auto rnn_params = JUST(gather_params(params, has_biases, has_projections));\n\n  std::shared_ptr<TensorTuple> layer_hxs = JUST(functional::Unbind(hx[0], 0));\n  std::shared_ptr<TensorTuple> layer_cxs = JUST(functional::Unbind(hx[1], 0));\n\n  std::vector<int64_t> batch_sizes_vec;\n  batch_sizes_vec.resize(batch_sizes->nelement());\n  const auto& callback = [&](ep::Stream* stream,\n                             const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n    SyncAutoMemcpy(stream, batch_sizes_vec.data(), eager_blob_object->dptr(),\n                   batch_sizes_vec.size() * sizeof(int64_t), memory::MakeHostMemCase(),\n                   eager_blob_object->mem_case());\n  };\n  JUST(SyncAccessTensorWithTimeOut(batch_sizes, callback, \"const\"));\n  int64_t num_steps = batch_sizes->shape()->At(0);\n  std::shared_ptr<TensorTuple> rnn_inputs = std::make_shared<TensorTuple>(num_steps);\n  int64_t input_offset = 0;\n  for (int32_t i = 0; i < num_steps; ++i) {\n    const int64_t batch_size = batch_sizes_vec[i];\n    (*rnn_inputs)[i] = JUST(functional::Narrow(input, 0, input_offset, batch_size));\n    input_offset += batch_size;\n  }\n\n  auto generator = JUST(one::DefaultAutoGenerator());\n\n  TensorTuple final_hy;\n  TensorTuple final_cy;\n\n  if (bidirectional) {\n    std::shared_ptr<TensorTuple> fw_outputs = std::make_shared<TensorTuple>(rnn_inputs->size());\n    std::shared_ptr<TensorTuple> lstm_cell_out = std::make_shared<TensorTuple>(2);\n    std::shared_ptr<TensorTuple> bw_outputs = std::make_shared<TensorTuple>(rnn_inputs->size());\n\n    for (int32_t l = 0; l < num_layers; ++l) {\n      int64_t last_batch_size = batch_sizes_vec[0];\n      // forward direction\n      (*lstm_cell_out)[0] = (*layer_hxs)[l * 2];\n      (*lstm_cell_out)[1] = (*layer_cxs)[l * 2];\n      auto& fw_cell_param = (*rnn_params)[l * 2];\n\n      TensorTuple final_hy_for_single_layer;\n      TensorTuple final_cy_for_single_layer;\n      for (int32_t i = 0; i < num_steps; ++i) {\n        const int64_t batch_size = batch_sizes_vec[i];\n        const int64_t dec = last_batch_size - batch_size;\n        if (dec > 0) {\n          final_hy_for_single_layer.emplace_back(\n              JUST(functional::Narrow((*lstm_cell_out)[0], 0, last_batch_size - dec, dec)));\n          (*lstm_cell_out)[0] =\n              JUST(functional::Narrow((*lstm_cell_out)[0], 0, 0, last_batch_size - dec));\n\n          final_cy_for_single_layer.emplace_back(\n              JUST(functional::Narrow((*lstm_cell_out)[1], 0, last_batch_size - dec, dec)));\n          (*lstm_cell_out)[1] =\n              JUST(functional::Narrow((*lstm_cell_out)[1], 0, 0, last_batch_size - dec));\n        }\n        last_batch_size = batch_size;\n        lstm_cell_out =\n            JUST(LSTMCell<CellParams>{}((*rnn_inputs)[i], *lstm_cell_out, fw_cell_param));\n        (*fw_outputs)[i] = (*lstm_cell_out)[0];\n      }\n      final_hy_for_single_layer.emplace_back((*lstm_cell_out)[0]);\n      final_cy_for_single_layer.emplace_back((*lstm_cell_out)[1]);\n      std::reverse(final_hy_for_single_layer.begin(), final_hy_for_single_layer.end());\n      std::reverse(final_cy_for_single_layer.begin(), final_cy_for_single_layer.end());\n      final_hy.emplace_back(JUST(functional::Concat(final_hy_for_single_layer, 0)));\n      
final_cy.emplace_back(JUST(functional::Concat(final_cy_for_single_layer, 0)));\n\n      // reverse direction\n      last_batch_size = batch_sizes_vec[num_steps - 1];\n      (*lstm_cell_out)[0] =\n          JUST(functional::Narrow((*layer_hxs)[l * 2 + 1], 0, 0, last_batch_size));\n      (*lstm_cell_out)[1] =\n          JUST(functional::Narrow((*layer_cxs)[l * 2 + 1], 0, 0, last_batch_size));\n\n      auto& bw_cell_param = (*rnn_params)[l * 2 + 1];\n\n      for (int64_t i = num_steps - 1; i >= 0; --i) {\n        const int64_t batch_size = batch_sizes_vec[i];\n        const int64_t inc = batch_size - last_batch_size;\n        if (inc > 0) {\n          std::shared_ptr<one::Tensor> hxs_slice = JUST(functional::Narrow(\n              (*layer_hxs)[l * 2 + 1], 0, last_batch_size, batch_size - last_batch_size));\n          std::shared_ptr<TensorTuple> tmp = std::make_shared<TensorTuple>(2);\n          (*tmp)[0] = (*lstm_cell_out)[0];\n          (*tmp)[1] = hxs_slice;\n          (*lstm_cell_out)[0] = JUST(functional::Concat(*tmp, 0));\n\n          std::shared_ptr<one::Tensor> cxs_slice = JUST(functional::Narrow(\n              (*layer_cxs)[l * 2 + 1], 0, last_batch_size, batch_size - last_batch_size));\n          (*tmp)[0] = (*lstm_cell_out)[1];\n          (*tmp)[1] = cxs_slice;\n          (*lstm_cell_out)[1] = JUST(functional::Concat(*tmp, 0));\n        }\n        last_batch_size = batch_size;\n        lstm_cell_out =\n            JUST(LSTMCell<CellParams>{}((*rnn_inputs)[i], *lstm_cell_out, bw_cell_param));\n        (*bw_outputs)[i] = (*lstm_cell_out)[0];\n      }\n      final_hy.emplace_back((*lstm_cell_out)[0]);\n      final_cy.emplace_back((*lstm_cell_out)[1]);\n\n      // concat fw_outputs and bw_outputs\n      for (int32_t i = 0; i < rnn_inputs->size(); ++i) {\n        (*rnn_inputs)[i] = JUST(functional::Concat({(*fw_outputs)[i], (*bw_outputs)[i]},\n                                                   (*bw_outputs)[0]->shape()->NumAxes() - 1));\n      }\n\n      if (dropout != 0 && train && l < num_layers - 1) {\n        std::shared_ptr<one::Tensor> stack_res = JUST(functional::Concat(*rnn_inputs, 0));\n        std::shared_ptr<one::Tensor> dropout_res =\n            JUST(functional::Dropout(stack_res, dropout, train, false, generator, nullptr));\n        int64_t input_offset = 0;\n        for (int32_t i = 0; i < num_steps; ++i) {\n          const int64_t batch_size = batch_sizes_vec[i];\n          (*rnn_inputs)[i] = JUST(functional::Narrow(dropout_res, 0, input_offset, batch_size));\n          input_offset += batch_size;\n        }\n      }\n    }\n  } else {\n    std::shared_ptr<TensorTuple> lstm_cell_out = std::make_shared<TensorTuple>(2);\n    for (int32_t l = 0; l < num_layers; ++l) {\n      int64_t last_batch_size = batch_sizes_vec[0];\n      (*lstm_cell_out)[0] = (*layer_hxs)[l];\n      (*lstm_cell_out)[1] = (*layer_cxs)[l];\n      auto& cell_param = (*rnn_params)[l];\n      TensorTuple final_hy_for_single_layer;\n      TensorTuple final_cy_for_single_layer;\n      for (int32_t i = 0; i < num_steps; ++i) {\n        const int64_t batch_size = batch_sizes_vec[i];\n        const int64_t dec = last_batch_size - batch_size;\n        if (dec > 0) {\n          final_hy_for_single_layer.emplace_back(\n              JUST(functional::Narrow((*lstm_cell_out)[0], 0, last_batch_size - dec, dec)));\n          (*lstm_cell_out)[0] =\n              JUST(functional::Narrow((*lstm_cell_out)[0], 0, 0, last_batch_size - dec));\n\n          final_cy_for_single_layer.emplace_back(\n              
JUST(functional::Narrow((*lstm_cell_out)[1], 0, last_batch_size - dec, dec)));\n          (*lstm_cell_out)[1] =\n              JUST(functional::Narrow((*lstm_cell_out)[1], 0, 0, last_batch_size - dec));\n        }\n        last_batch_size = batch_size;\n        lstm_cell_out = JUST(LSTMCell<CellParams>{}((*rnn_inputs)[i], *lstm_cell_out, cell_param));\n        (*rnn_inputs)[i] = (*lstm_cell_out)[0];\n      }\n      final_hy_for_single_layer.emplace_back((*lstm_cell_out)[0]);\n      final_cy_for_single_layer.emplace_back((*lstm_cell_out)[1]);\n      std::reverse(final_hy_for_single_layer.begin(), final_hy_for_single_layer.end());\n      std::reverse(final_cy_for_single_layer.begin(), final_cy_for_single_layer.end());\n      final_hy.emplace_back(JUST(functional::Concat(final_hy_for_single_layer, 0)));\n      final_cy.emplace_back(JUST(functional::Concat(final_cy_for_single_layer, 0)));\n\n      if (dropout != 0 && train && l < num_layers - 1) {\n        std::shared_ptr<one::Tensor> stack_res = JUST(functional::Concat(*rnn_inputs, 0));\n        std::shared_ptr<one::Tensor> dropout_res =\n            JUST(functional::Dropout(stack_res, dropout, train, false, generator, nullptr));\n        int64_t input_offset = 0;\n        for (int32_t i = 0; i < num_steps; ++i) {\n          const int64_t batch_size = batch_sizes_vec[i];\n          (*rnn_inputs)[i] = JUST(functional::Narrow(dropout_res, 0, input_offset, batch_size));\n          input_offset += batch_size;\n        }\n      }\n    }\n  }\n\n  TensorTuple output;\n  std::shared_ptr<one::Tensor> output_0 = JUST(functional::Concat(*rnn_inputs, 0));\n  output.emplace_back(output_0);\n  output.emplace_back(JUST(functional::Stack(final_hy, 0)));\n  output.emplace_back(JUST(functional::Stack(final_cy, 0)));\n  return output;\n}\n\nclass LstmDataFunctor {\n public:\n  LstmDataFunctor() {}\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& data,\n                                const std::shared_ptr<one::Tensor>& batch_sizes,\n                                const one::TensorTuple& hx, const one::TensorTuple& params,\n                                const bool& has_biases, const int32_t& num_layers,\n                                const float& dropout, const bool& train,\n                                const bool& bidirectional) const {\n    return _lstm_pack_sequence_impl(data, batch_sizes, hx, params, has_biases, num_layers, dropout,\n                                    train, bidirectional);\n  }\n};\n\nclass GruInputFunctor {\n public:\n  GruInputFunctor() {}\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,\n                                const std::shared_ptr<one::Tensor>& hx,\n                                const one::TensorTuple& params, const bool& has_biases,\n                                const int32_t& num_layers, const float& dropout, const bool& train,\n                                const bool& bidirectional, const bool& batch_first) const {\n    return _rnn_impl<GRUCell<CellParams>>(input, hx, params, has_biases, num_layers, dropout, train,\n                                          bidirectional, batch_first);\n  }\n};\n\nclass GruDataFunctor {\n public:\n  GruDataFunctor() {}\n\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& data,\n                                const std::shared_ptr<one::Tensor>& batch_sizes,\n                                const std::shared_ptr<one::Tensor>& hx,\n                                const one::TensorTuple& params, const bool& 
has_biases,\n                                const int32_t& num_layers, const float& dropout, const bool& train,\n                                const bool& bidirectional) const {\n    return _rnn_pack_sequence_impl<GRUCell<CellParams>>(data, batch_sizes, hx, params, has_biases,\n                                                        num_layers, dropout, train, bidirectional);\n  }\n};\n
\nMaybe<void> checkLongTensor(const std::shared_ptr<one::Tensor>& tensor) {\n  auto& device = JUST(tensor->device())->type();\n  CHECK_OR_RETURN(tensor->ndim() == 1 && device == \"cpu\" && tensor->dtype() == DType::Int64())\n      << \"'lengths' argument should be a 1D CPU int64 tensor, but got \" << tensor->ndim() << \"D \"\n      << device << \" \" << tensor->dtype()->name() << \" tensor\";\n  return Maybe<void>::Ok();\n}\n
\nclass PackPaddedSequenceFunctor {\n public:\n  PackPaddedSequenceFunctor() {}\n
\n  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,\n                                const std::shared_ptr<one::Tensor>& lengths,\n                                const bool& batch_first) const {\n    CHECK_OR_RETURN(input->is_local() && lengths->is_local())\n        << \"pack_padded_sequence only accepts local tensors as input.\";\n    std::shared_ptr<one::Tensor> new_input = input;\n    if (batch_first) {\n      std::vector<int32_t> dims;\n      dims.resize(input->shape()->NumAxes());\n      dims[0] = 1;\n      dims[1] = 0;\n      for (int i = 2; i < input->shape()->NumAxes(); ++i) { dims[i] = i; }\n      new_input = JUST(functional::Permute(input, dims));\n    }\n    JUST(checkLongTensor(lengths));\n
\n    int64_t batch_size = new_input->shape()->At(1);\n    std::vector<int64_t> lengths_vec;\n    lengths_vec.resize(lengths->nelement());\n    const auto& callback = [&](ep::Stream* stream,\n                               const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n      SyncAutoMemcpy(stream, lengths_vec.data(), eager_blob_object->dptr(),\n                     lengths_vec.size() * sizeof(int64_t), memory::MakeHostMemCase(),\n                     eager_blob_object->mem_case());\n    };\n    JUST(SyncAccessTensorWithTimeOut(lengths, callback, \"const\"));\n
\n    CHECK_OR_RETURN(new_input->nelement() > 0) << \"Cannot pack empty tensors.\";\n    CHECK_OR_RETURN(lengths->shape()->At(0) == batch_size)\n        << \"Expected `len(lengths)` to be equal to batch_size, but got \" << lengths->shape()->At(0)\n        << \" (batch_size=\" << batch_size << \")\";\n    CHECK_OR_RETURN(lengths_vec[batch_size - 1] > 0)\n        << \"Length of all samples has to be greater than 0, but found an element in 'lengths' that \"\n           \"is <= 0\";\n    for (int i = 0; i < batch_size - 1; ++i) {\n      if (lengths_vec[batch_size - 1 - i] > lengths_vec[batch_size - 2 - i]) {\n        CHECK_OR_RETURN(false) << \"`lengths` array must be sorted in decreasing order when \"\n                                  \"`enforce_sorted` is True. 
You can pass `enforce_sorted=False` \"\n                                  \"to pack_padded_sequence and/or pack_sequence to sidestep this \"\n                                  \"requirement if you do not need ONNX exportability.\";\n      }\n    }\n
\n    std::vector<int64_t> step_shape_vec;  // == [-1, *input.shape[2:]]\n    {\n      const auto& input_sizes = new_input->shape();\n      step_shape_vec.push_back(-1);\n      for (int i = 2; i < input_sizes->NumAxes(); ++i) {\n        step_shape_vec.push_back(input_sizes->At(i));\n      }\n    }\n    DimVector rsv(step_shape_vec.size());\n    for (int i = 0; i < step_shape_vec.size(); ++i) { rsv[i] = step_shape_vec[i]; }\n    const Shape step_shape(rsv);\n
\n    // To understand what's going on in this loop imagine that the input is a padded 2D\n    // array that looks like this (x = valid entry, . = padding)\n    //\n    //  1 1 1 1 1\n    //  2 2 2 . .\n    //  2 2 2 . .\n    //  4 . . . .\n    //  4 . . . .\n    //\n    // Where the vertical dimension corresponds to time, and horizontal dim to batch.\n    // In this example, the lengths array will be equal to [5, 3, 3, 1, 1], and we will\n    // iterate over them in reverse order (from the rightmost column to the left).\n    // We want to avoid eager slicing of the input at every time step, and wait for\n    // the moments where the length increases. In this example, that will happen at the\n    // first, second and fourth steps. Then, we slice out the whole block of the input\n    // that corresponds to this length, and hasn't been sliced yet (the steps at which each\n    // element is sliced are annotated in the array above).  You can think of this as if we\n    // were scanning the sequences from the shortest one, and every time we realize there are\n    // more elements below in our column, we advance the counter (prev_l), and append the new\n    // block to the output.\n    std::vector<int64_t> batch_sizes;\n    batch_sizes.resize(lengths_vec[0]);\n    int64_t* batch_sizes_ptr = batch_sizes.data();\n    TensorTuple steps;\n    int64_t prev_l = 0;\n    for (int i = 0; i < batch_size; ++i) {\n      int64_t l = lengths_vec[batch_size - 1 - i];\n      if (l > prev_l) {\n        auto current_batch_size = batch_size - i;\n        std::shared_ptr<Tensor> slice_res =\n            JUST(functional::Narrow(new_input, 0, prev_l, l - prev_l));\n        slice_res = JUST(functional::Narrow(slice_res, 1, 0, current_batch_size));\n        slice_res = JUST(functional::View(slice_res->contiguous(), step_shape));\n        steps.emplace_back(slice_res);\n        for (int64_t j = 0; j < (l - prev_l); ++j) { (*batch_sizes_ptr++) = current_batch_size; }\n        prev_l = l;\n      }\n      CHECK_OR_RETURN(l >= prev_l)\n          << \"PackPaddedSequenceFunctor: `lengths` array must be sorted in decreasing order.\";\n    }\n
\n    DimVector lsv(1);\n    lsv[0] = lengths_vec[0];\n    const Shape ls(lsv);\n    std::shared_ptr<Tensor> batch_sizes_t =\n        JUST(functional::Empty(ls, lengths->dtype(), JUST(lengths->device()),\n                               /*requires_grad=*/lengths->requires_grad(), /*pin_memory=*/false));\n    const auto& callback2 = [&](ep::Stream* stream,\n                                const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n      SyncAutoMemcpy(stream, eager_blob_object->mut_dptr(), batch_sizes.data(),\n                     batch_sizes.size() * sizeof(int64_t), eager_blob_object->mem_case(),\n                     memory::MakeHostMemCase());  // copy the host batch_sizes values into the batch_sizes tensor\n    };\n    JUST(SyncAccessTensorWithTimeOut(batch_sizes_t, callback2, \"mut\"));\n
\n    std::shared_ptr<TensorTuple> output = std::make_shared<TensorTuple>(2);\n    (*output)[0] = JUST(functional::Concat(steps, 0));\n    (*output)[1] = batch_sizes_t;\n    return output;\n  }\n};\n
\n}  // namespace impl\n
\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<impl::RnnTanhCellFunctor>(\"RnnTanhCell\");\n  m.add_functor<impl::RnnReluCellFunctor>(\"RnnReluCell\");\n  m.add_functor<impl::LstmCellFunctor>(\"LstmCell\");\n  m.add_functor<impl::GruCellFunctor>(\"GruCell\");\n  m.add_functor<impl::FusedLstmCellFunctor>(\"FusedLstmCell\");\n  m.add_functor<impl::FusedLstmCellGradFunctor>(\"FusedLstmCellGrad\");\n  m.add_functor<impl::FusedGruCellFunctor>(\"FusedGruCell\");\n  m.add_functor<impl::FusedGruCellGradFunctor>(\"FusedGruCellGrad\");\n  m.add_functor<impl::RnnTanhInputFunctor>(\"RnnTanhInput\");\n  m.add_functor<impl::RnnTanhDataFunctor>(\"RnnTanhData\");\n  m.add_functor<impl::RnnReluInputFunctor>(\"RnnReluInput\");\n  m.add_functor<impl::RnnReluDataFunctor>(\"RnnReluData\");\n  m.add_functor<impl::LstmInputFunctor>(\"LstmInput\");\n  m.add_functor<impl::LstmDataFunctor>(\"LstmData\");\n  m.add_functor<impl::GruInputFunctor>(\"GruInput\");\n  m.add_functor<impl::GruDataFunctor>(\"GruData\");\n  m.add_functor<impl::PackPaddedSequenceFunctor>(\"PackPaddedSequence\");\n}\n
\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/slice_boxing_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/id_util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nnamespace {\n\nbool IsSplitSbp(Symbol<SbpParallel> sbp_parallel) { return sbp_parallel->has_split_parallel(); }\n\nMaybe<one::UserOpExpr> EagerSToB(Symbol<ParallelDesc> in_parallel_desc,\n                                 Symbol<ParallelDesc> out_parallel_desc,\n                                 Symbol<SbpParallel> src_sbp, const Shape& shape) {\n  return one::OpBuilder(\"eager_s_to_b\", *JUST(UniqueStr(\"eager_s_to_b\")))\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<int64_t>(\"in_split_axis\", src_sbp->split_parallel().axis())\n      .Attr<std::string>(\"in_parallel_conf\", PbMessage2TxtString(in_parallel_desc->parallel_conf()))\n      .Attr<std::string>(\"out_parallel_conf\",\n                         PbMessage2TxtString(out_parallel_desc->parallel_conf()))\n      .Attr<Shape>(\"shape\", shape)\n      .Build();\n}\n\nstatic constexpr auto* CachedEagerSToBOpExpr = DECORATE(&EagerSToB, ThreadLocalCopiable);\n\nMaybe<one::UserOpExpr> EagerPToB(Symbol<ParallelDesc> in_parallel_desc,\n                                 Symbol<ParallelDesc> out_parallel_desc, const Shape& shape) {\n  return one::OpBuilder(\"eager_p_to_b\", *JUST(UniqueStr(\"eager_p_to_b\")))\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<std::string>(\"in_parallel_conf\", PbMessage2TxtString(in_parallel_desc->parallel_conf()))\n      .Attr<std::string>(\"out_parallel_conf\",\n                         PbMessage2TxtString(out_parallel_desc->parallel_conf()))\n      .Attr<Shape>(\"shape\", shape)\n      .Build();\n}\n\nstatic constexpr auto* CachedEagerPToBOpExpr = DECORATE(&EagerPToB, ThreadLocalCopiable);\n\nMaybe<one::UserOpExpr> EagerNaiveSToS(Symbol<ParallelDesc> in_parallel_desc,\n                                      Symbol<ParallelDesc> out_parallel_desc,\n                                      Symbol<SbpParallel> src_sbp, Symbol<SbpParallel> dst_sbp,\n                                      const Shape& shape) {\n  return one::OpBuilder(\"eager_naive_s_to_s\", *JUST(UniqueStr(\"eager_naive_s_to_s\")))\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<int64_t>(\"in_split_axis\", src_sbp->split_parallel().axis())\n      .Attr<int64_t>(\"out_split_axis\", dst_sbp->split_parallel().axis())\n      .Attr<std::string>(\"in_parallel_conf\", 
PbMessage2TxtString(in_parallel_desc->parallel_conf()))\n      .Attr<std::string>(\"out_parallel_conf\",\n                         PbMessage2TxtString(out_parallel_desc->parallel_conf()))\n      .Attr<Shape>(\"shape\", shape)\n      .Build();\n}\n\nstatic constexpr auto* CachedEagerNaiveSToSOpExpr = DECORATE(&EagerNaiveSToS, ThreadLocalCopiable);\n\nMaybe<one::UserOpExpr> EagerBToS(Symbol<ParallelDesc> in_parallel_desc,\n                                 Symbol<ParallelDesc> out_parallel_desc,\n                                 Symbol<SbpParallel> dst_sbp, const Shape& shape) {\n  return one::OpBuilder(\"eager_b_to_s\", *JUST(UniqueStr(\"eager_b_to_s\")))\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<int64_t>(\"out_split_axis\", dst_sbp->split_parallel().axis())\n      .Attr<std::string>(\"in_parallel_conf\", PbMessage2TxtString(in_parallel_desc->parallel_conf()))\n      .Attr<std::string>(\"out_parallel_conf\",\n                         PbMessage2TxtString(out_parallel_desc->parallel_conf()))\n      .Attr<Shape>(\"shape\", shape)\n      .Build();\n}\n\nstatic constexpr auto* CachedEagerBToSOpExpr = DECORATE(&EagerBToS, ThreadLocalCopiable);\n\nMaybe<one::UserOpExpr> EagerPToS(Symbol<ParallelDesc> in_parallel_desc,\n                                 Symbol<ParallelDesc> out_parallel_desc,\n                                 Symbol<SbpParallel> dst_sbp, const Shape& shape) {\n  return one::OpBuilder(\"eager_p_to_s\", *JUST(UniqueStr(\"eager_p_to_s\")))\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<int64_t>(\"out_split_axis\", dst_sbp->split_parallel().axis())\n      .Attr<std::string>(\"in_parallel_conf\", PbMessage2TxtString(in_parallel_desc->parallel_conf()))\n      .Attr<std::string>(\"out_parallel_conf\",\n                         PbMessage2TxtString(out_parallel_desc->parallel_conf()))\n      .Attr<Shape>(\"shape\", shape)\n      .Build();\n}\n\nstatic constexpr auto* CachedEagerPToSOpExpr = DECORATE(&EagerPToS, ThreadLocalCopiable);\n\nMaybe<one::UserOpExpr> EagerSToP(Symbol<ParallelDesc> in_parallel_desc,\n                                 Symbol<ParallelDesc> out_parallel_desc,\n                                 Symbol<SbpParallel> src_sbp, const Shape& shape) {\n  return one::OpBuilder(\"eager_s_to_p\", *JUST(UniqueStr(\"eager_s_to_p\")))\n      .Input(\"in\")\n      .Output(\"out\")\n      .Attr<int64_t>(\"in_split_axis\", src_sbp->split_parallel().axis())\n      .Attr<std::string>(\"in_parallel_conf\", PbMessage2TxtString(in_parallel_desc->parallel_conf()))\n      .Attr<std::string>(\"out_parallel_conf\",\n                         PbMessage2TxtString(out_parallel_desc->parallel_conf()))\n      .Attr<Shape>(\"shape\", shape)\n      .Build();\n}\n\nstatic constexpr auto* CachedEagerSToPOpExpr = DECORATE(&EagerSToP, ThreadLocalCopiable);\n\n}  // namespace\n\nclass EagerSToBFunctor {\n public:\n  EagerSToBFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           Symbol<ParallelDesc> in_parallel_desc,\n                           Symbol<ParallelDesc> out_parallel_desc,\n                           const std::vector<Symbol<SbpParallel>>& in_sbp_parallels,\n                           const Shape& shape) const {\n    Symbol<NdSbp> in_nd_sbp = JUST(GetNdSbp(in_sbp_parallels));\n    {\n      CHECK_OR_RETURN(x->is_local())\n          << Error::RuntimeError() << \"input tensors `.is_local` should be true\";\n      CHECK_OR_RETURN(x->is_eager())\n          << Error::RuntimeError() << \"input tensors `.is_eager` should 
be true\";\n      CHECK_OR_RETURN((in_nd_sbp->sbp_parallel_size() == 1)\n                      && IsSplitSbp(in_nd_sbp->sbp_parallel(0)))\n          << Error::RuntimeError() << \"The input tensor's sbp should be (split, )\";\n    }\n    std::shared_ptr<OpExpr> op_expr = JUST(CachedEagerSToBOpExpr(\n        in_parallel_desc, out_parallel_desc, SymbolOf(in_nd_sbp->sbp_parallel(0)), shape));\n    return JUST(OpInterpUtil::Dispatch<Tensor>(*op_expr, {x}));\n  }\n};\n\nclass EagerPToBFunctor {\n public:\n  EagerPToBFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           Symbol<ParallelDesc> in_parallel_desc,\n                           Symbol<ParallelDesc> out_parallel_desc, const Shape& shape) const {\n    {\n      CHECK_OR_RETURN(x->is_local())\n          << Error::RuntimeError() << \"input tensors `.is_local` should be true\";\n      CHECK_OR_RETURN(x->is_eager())\n          << Error::RuntimeError() << \"input tensors `.is_eager` should be true\";\n    }\n    std::shared_ptr<OpExpr> op_expr =\n        JUST(CachedEagerPToBOpExpr(in_parallel_desc, out_parallel_desc, shape));\n    return JUST(OpInterpUtil::Dispatch<Tensor>(*op_expr, {x}));\n  }\n};\n\nclass EagerNaiveSToSFunctor {\n public:\n  EagerNaiveSToSFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           Symbol<ParallelDesc> in_parallel_desc,\n                           Symbol<ParallelDesc> out_parallel_desc,\n                           const std::vector<Symbol<SbpParallel>>& in_sbp_parallels,\n                           const std::vector<Symbol<SbpParallel>>& out_sbp_parallels,\n                           const Shape& shape) const {\n    Symbol<NdSbp> in_nd_sbp = JUST(GetNdSbp(in_sbp_parallels));\n    Symbol<NdSbp> out_nd_sbp = JUST(GetNdSbp(out_sbp_parallels));\n    {\n      CHECK_OR_RETURN(x->is_local())\n          << Error::RuntimeError() << \"input tensors `.is_local` should be true\";\n      CHECK_OR_RETURN(x->is_eager())\n          << Error::RuntimeError() << \"input tensors `.is_eager` should be true\";\n      CHECK_OR_RETURN((in_nd_sbp->sbp_parallel_size() == 1)\n                      && IsSplitSbp(in_nd_sbp->sbp_parallel(0)))\n          << Error::RuntimeError() << \"The input tensor's sbp should be (split, )\";\n      CHECK_OR_RETURN((out_nd_sbp->sbp_parallel_size() == 1)\n                      && IsSplitSbp(out_nd_sbp->sbp_parallel(0)))\n          << Error::RuntimeError() << \"The output tensor's sbp should be (split, )\";\n    }\n    std::shared_ptr<OpExpr> op_expr = JUST(CachedEagerNaiveSToSOpExpr(\n        in_parallel_desc, out_parallel_desc, SymbolOf(in_nd_sbp->sbp_parallel(0)),\n        SymbolOf(out_nd_sbp->sbp_parallel(0)), shape));\n    return JUST(OpInterpUtil::Dispatch<Tensor>(*op_expr, {x}));\n  }\n};\n\nclass EagerBToSFunctor {\n public:\n  EagerBToSFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           Symbol<ParallelDesc> in_parallel_desc,\n                           Symbol<ParallelDesc> out_parallel_desc,\n                           const std::vector<Symbol<SbpParallel>>& out_sbp_parallels,\n                           const Shape& shape) const {\n    Symbol<NdSbp> out_nd_sbp = JUST(GetNdSbp(out_sbp_parallels));\n    {\n      CHECK_OR_RETURN(x->is_local())\n          << Error::RuntimeError() << \"input tensors `.is_local` should be true\";\n      CHECK_OR_RETURN(x->is_eager())\n          << Error::RuntimeError() << \"input tensors 
`.is_eager` should be true\";\n      CHECK_OR_RETURN((out_nd_sbp->sbp_parallel_size() == 1)\n                      && IsSplitSbp(out_nd_sbp->sbp_parallel(0)))\n          << Error::RuntimeError() << \"The output tensor's sbp should be (split, )\";\n    }\n    std::shared_ptr<OpExpr> op_expr = JUST(CachedEagerBToSOpExpr(\n        in_parallel_desc, out_parallel_desc, SymbolOf(out_nd_sbp->sbp_parallel(0)), shape));\n    return JUST(OpInterpUtil::Dispatch<Tensor>(*op_expr, {x}));\n  }\n};\n\nclass EagerPToSFunctor {\n public:\n  EagerPToSFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           Symbol<ParallelDesc> in_parallel_desc,\n                           Symbol<ParallelDesc> out_parallel_desc,\n                           const std::vector<Symbol<SbpParallel>>& out_sbp_parallels,\n                           const Shape& shape) const {\n    Symbol<NdSbp> out_nd_sbp = JUST(GetNdSbp(out_sbp_parallels));\n    {\n      CHECK_OR_RETURN(x->is_local())\n          << Error::RuntimeError() << \"input tensors `.is_local` should be true\";\n      CHECK_OR_RETURN(x->is_eager())\n          << Error::RuntimeError() << \"input tensors `.is_eager` should be true\";\n      CHECK_OR_RETURN((out_nd_sbp->sbp_parallel_size() == 1)\n                      && IsSplitSbp(out_nd_sbp->sbp_parallel(0)))\n          << Error::RuntimeError() << \"The output tensor's sbp should be (split, )\";\n    }\n    std::shared_ptr<OpExpr> op_expr = JUST(CachedEagerPToSOpExpr(\n        in_parallel_desc, out_parallel_desc, SymbolOf(out_nd_sbp->sbp_parallel(0)), shape));\n    return JUST(OpInterpUtil::Dispatch<Tensor>(*op_expr, {x}));\n  }\n};\n\nclass EagerSToPFunctor {\n public:\n  EagerSToPFunctor() = default;\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,\n                           Symbol<ParallelDesc> in_parallel_desc,\n                           Symbol<ParallelDesc> out_parallel_desc,\n                           const std::vector<Symbol<SbpParallel>>& in_sbp_parallels,\n                           const Shape& shape) const {\n    Symbol<NdSbp> in_nd_sbp = JUST(GetNdSbp(in_sbp_parallels));\n    {\n      CHECK_OR_RETURN(x->is_local())\n          << Error::RuntimeError() << \"input tensors `.is_local` should be true\";\n      CHECK_OR_RETURN(x->is_eager())\n          << Error::RuntimeError() << \"input tensors `.is_eager` should be true\";\n      CHECK_OR_RETURN((in_nd_sbp->sbp_parallel_size() == 1)\n                      && IsSplitSbp(in_nd_sbp->sbp_parallel(0)))\n          << Error::RuntimeError() << \"The input tensor's sbp should be (split, )\";\n    }\n    std::shared_ptr<OpExpr> op_expr = JUST(CachedEagerSToPOpExpr(\n        in_parallel_desc, out_parallel_desc, SymbolOf(in_nd_sbp->sbp_parallel(0)), shape));\n    return JUST(OpInterpUtil::Dispatch<Tensor>(*op_expr, {x}));\n  }\n};\n\n}  // namespace impl\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<impl::EagerSToBFunctor>(\"EagerSToB\");\n  m.add_functor<impl::EagerPToBFunctor>(\"EagerPToB\");\n  m.add_functor<impl::EagerNaiveSToSFunctor>(\"EagerNaiveSToS\");\n  m.add_functor<impl::EagerBToSFunctor>(\"EagerBToS\");\n  m.add_functor<impl::EagerPToSFunctor>(\"EagerPToS\");\n  m.add_functor<impl::EagerSToPFunctor>(\"EagerSToP\");\n};\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/test_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nclass ThrowErrorFunctor final {\n public:\n  ThrowErrorFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"throw_error\").Input(\"x\").Output(\"y\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input) const {\n    return JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {input}));\n  }\n\n protected:\n  std::shared_ptr<OpExpr> op_;\n};\n\n}  // namespace impl\n\nusing namespace impl;\n\nONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<ThrowErrorFunctor>(\"ThrowError\"); };\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/unary_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/functional/impl/unary_functor.h\"\n#include \"oneflow/core/functional/impl/binary_functor.h\"\n\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/user/ops/math_unary_elementwise_seq.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\n#define INPLACE_UNARY_FLOAT_FUNC_SEQ          \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sin\", InplaceSin)     \\\n  OF_PP_MAKE_TUPLE_SEQ(\"floor\", InplaceFloor) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"ceil\", InplaceCeil)   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"round\", InplaceRound)\n\n#define UNARY_PRIMITIVE_FUNC_BWD_WITH_DY_X_SEQ    \\\n  OF_PP_MAKE_TUPLE_SEQ(\"abs\", Abs)                \\\n  OF_PP_MAKE_TUPLE_SEQ(\"acos\", Acos)              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"cosh\", Cosh)              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"lgamma\", Lgamma)          \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log_sigmoid\", LogSigmoid) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"reciprocal_no_nan\", ReciprocalNoNan)\n\n#define FLOAT_UNARY_PRIMITIVE_FUNC_BWD_WITH_DY_X_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(\"acosh\", Acosh)               \\\n  OF_PP_MAKE_TUPLE_SEQ(\"asin\", Asin)                 \\\n  OF_PP_MAKE_TUPLE_SEQ(\"asinh\", Asinh)               \\\n  OF_PP_MAKE_TUPLE_SEQ(\"atan\", Atan)                 \\\n  OF_PP_MAKE_TUPLE_SEQ(\"atanh\", Atanh)               \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sin\", Sin)                   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"cos\", Cos)                   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"erf\", Erf)                   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"erfc\", Erfc)                 \\\n  OF_PP_MAKE_TUPLE_SEQ(\"exp\", Exp)                   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"exp2\", Exp2)                 \\\n  OF_PP_MAKE_TUPLE_SEQ(\"expm1\", Expm1)               \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log\", Log)                   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log2\", Log2)                 \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log10\", Log10)               \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log1p\", Log1p)               \\\n  OF_PP_MAKE_TUPLE_SEQ(\"reciprocal\", Reciprocal)     \\\n  OF_PP_MAKE_TUPLE_SEQ(\"rsqrt\", Rsqrt)               \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sinh\", Sinh)                 \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sqrt\", Sqrt)                 \\\n  OF_PP_MAKE_TUPLE_SEQ(\"square\", Square)             \\\n  OF_PP_MAKE_TUPLE_SEQ(\"tan\", Tan)                   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"digamma\", Digamma)\n\n#define FLOAT_UNARY_PRIMITIVE_FUNC_BWD_WITH_DY_Y_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sigmoid\", Sigmoid)           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"tanh\", Tanh)\n\n#define UNARY_FUNC_BWD_WITH_FILL_SEQ   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"rint\", Rint)   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"round\", Round) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"floor\", Floor) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"ceil\", Ceil)\n\n#define FLOAT_UNARY_FUNC_BWD_WITH_FILL_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sign\", Sign)       \\\n  
OF_PP_MAKE_TUPLE_SEQ(\"not_equal_zero\", NotEqualZero)\n\n#define LOGICAL_FLOAT_UNARY_FUNC_SEQ OF_PP_MAKE_TUPLE_SEQ(\"logical_not\", LogicalNot)\n\n#define UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, base)                    \\\n  class class_name##Functor : public base {                                          \\\n   public:                                                                           \\\n    class_name##Functor() {                                                          \\\n      op_ = CHECK_JUST(one::OpBuilder(op_type_name).Input(\"x\").Output(\"y\").Build()); \\\n    }                                                                                \\\n  };\n\n#define UNARY_ELEMENTWISE_BWD_WITH_DY_X_FUNCTOR(op_type_name, class_name, base) \\\n  class class_name##WithDyXGradFunctor : public base {                          \\\n   public:                                                                      \\\n    class_name##WithDyXGradFunctor() {                                          \\\n      op_ = CHECK_JUST(one::OpBuilder(std::string(\"\") + op_type_name + \"_grad\") \\\n                           .Input(\"x\")                                          \\\n                           .Input(\"dy\")                                         \\\n                           .Output(\"dx\")                                        \\\n                           .Build());                                           \\\n    }                                                                           \\\n  };\n\n#define UNARY_ELEMENTWISE_BWD_WITH_DY_Y_FUNCTOR(op_type_name, class_name, base) \\\n  class class_name##WithDyYGradFunctor : public base {                          \\\n   public:                                                                      \\\n    class_name##WithDyYGradFunctor() {                                          \\\n      op_ = CHECK_JUST(one::OpBuilder(std::string(\"\") + op_type_name + \"_grad\") \\\n                           .Input(\"y\")                                          \\\n                           .Input(\"dy\")                                         \\\n                           .Output(\"dx\")                                        \\\n                           .Build());                                           \\\n    }                                                                           \\\n  };\n\n#define INPLACE_UNARY_FUNCTORS(op_type_name, class_name) \\\n  UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, InplaceUnaryFunctor)\n#define INPLACE_FLOAT_UNARY_FUNCTORS(op_type_name, class_name) \\\n  UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, InplaceFloatUnaryFunctor)\n#define LOGICAL_FLOAT_UNARY_FUNCTORS(op_type_name, class_name) \\\n  UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, FloatUnaryFunctor)\n#define UNARY_FUNCTORS(op_type_name, class_name)                    \\\n  UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, UnaryFunctor) \\\n  UNARY_ELEMENTWISE_BWD_WITH_DY_X_FUNCTOR(op_type_name, class_name, BinaryFunctor)\n#define FLOAT_UNARY_FUNCTORS(op_type_name, class_name)                   \\\n  UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, FloatUnaryFunctor) \\\n  UNARY_ELEMENTWISE_BWD_WITH_DY_X_FUNCTOR(op_type_name, class_name, BinaryFunctor)\n\n#define UNARY_BWD_WITH_DY_X_FUNCTORS(op_type_name, class_name)      \\\n  UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, UnaryFunctor) \\\n  UNARY_ELEMENTWISE_BWD_WITH_DY_X_FUNCTOR(op_type_name, class_name, BinaryFunctor)\n\n#define 
FLOAT_UNARY_BWD_WITH_DY_X_FUNCTORS(op_type_name, class_name)     \\\n  UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, FloatUnaryFunctor) \\\n  UNARY_ELEMENTWISE_BWD_WITH_DY_X_FUNCTOR(op_type_name, class_name, BinaryFunctor)\n\n#define FLOAT_UNARY_WITH_DY_Y_FUNCTORS(op_type_name, class_name)         \\\n  UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, FloatUnaryFunctor) \\\n  UNARY_ELEMENTWISE_BWD_WITH_DY_Y_FUNCTOR(op_type_name, class_name, BinaryFunctor)\n\n#define FLOAT_UNARY_BWD_WITH_FILL_FUNCTORS(op_type_name, class_name) \\\n  UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, FloatUnaryFunctor)\n\n#define UNARY_BWD_WITH_FILL_FUNCTORS(op_type_name, class_name) \\\n  UNARY_ELEMENTWISE_FUNCTOR(op_type_name, class_name, UnaryFunctor)\n\nOF_PP_FOR_EACH_TUPLE(INPLACE_FLOAT_UNARY_FUNCTORS, INPLACE_UNARY_FLOAT_FUNC_SEQ);\nOF_PP_FOR_EACH_TUPLE(LOGICAL_FLOAT_UNARY_FUNCTORS, LOGICAL_FLOAT_UNARY_FUNC_SEQ);\n\nOF_PP_FOR_EACH_TUPLE(UNARY_BWD_WITH_DY_X_FUNCTORS, UNARY_PRIMITIVE_FUNC_BWD_WITH_DY_X_SEQ);\nOF_PP_FOR_EACH_TUPLE(FLOAT_UNARY_BWD_WITH_DY_X_FUNCTORS,\n                     FLOAT_UNARY_PRIMITIVE_FUNC_BWD_WITH_DY_X_SEQ);\n\nOF_PP_FOR_EACH_TUPLE(FLOAT_UNARY_WITH_DY_Y_FUNCTORS, FLOAT_UNARY_PRIMITIVE_FUNC_BWD_WITH_DY_Y_SEQ);\nOF_PP_FOR_EACH_TUPLE(UNARY_BWD_WITH_FILL_FUNCTORS, UNARY_FUNC_BWD_WITH_FILL_SEQ);\nOF_PP_FOR_EACH_TUPLE(FLOAT_UNARY_BWD_WITH_FILL_FUNCTORS, FLOAT_UNARY_FUNC_BWD_WITH_FILL_SEQ);\n\nUNARY_ELEMENTWISE_FUNCTOR(\"negative\", Negative, FloatUnaryFunctor)\nUNARY_ELEMENTWISE_FUNCTOR(\"bitwise_not\", BitwiseNot, UnaryFunctor)\nUNARY_ELEMENTWISE_FUNCTOR(\"trigamma\", Trigamma, FloatUnaryFunctor)\n\n}  // namespace impl\n\nusing namespace impl;\n#define ADD_UNARY_FUNCTOR_WITH_DY_X(class_name, functor_name) \\\n  m.add_functor<class_name##Functor>(functor_name);           \\\n  m.add_functor<class_name##WithDyXGradFunctor>(std::string(\"\") + functor_name + \"Grad\");\n\n#define ADD_UNARY_FUNCTOR_WITH_DY_Y(class_name, functor_name) \\\n  m.add_functor<class_name##Functor>(functor_name);           \\\n  m.add_functor<class_name##WithDyYGradFunctor>(std::string(\"\") + functor_name + \"Grad\");\n\nONEFLOW_FUNCTION_LIBRARY(m) {\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Abs, \"Abs\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Acos, \"Acos\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Acosh, \"Acosh\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Asin, \"Asin\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Asinh, \"Asinh\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Atan, \"Atan\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Atanh, \"Atanh\");\n  m.add_functor<CeilFunctor>(\"Ceil\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Cos, \"Cos\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Cosh, \"Cosh\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Digamma, \"Digamma\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Erf, \"Erf\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Erfc, \"Erfc\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Exp, \"Exp\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Exp2, \"Exp2\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Expm1, \"Expm1\");\n  m.add_functor<FloorFunctor>(\"Floor\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Lgamma, \"Lgamma\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Log, \"Log\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Log2, \"Log2\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Log10, \"Log10\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Log1p, \"Log1p\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(LogSigmoid, \"LogSigmoid\");\n  m.add_functor<NegativeFunctor>(\"Negative\");\n  m.add_functor<BitwiseNotFunctor>(\"BitwiseNot\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Reciprocal, \"Reciprocal\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(ReciprocalNoNan, \"ReciprocalNoNan\");\n  
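// rint/round (like floor/ceil/sign/not_equal_zero above) have a zero gradient almost\n  // everywhere, so only their forward functors are registered and no *Grad functor is\n  // added here; see the *_BWD_WITH_FILL_SEQ lists above.\n  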
m.add_functor<RintFunctor>(\"Rint\");\n  m.add_functor<RoundFunctor>(\"Round\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Rsqrt, \"Rsqrt\");\n  ADD_UNARY_FUNCTOR_WITH_DY_Y(Sigmoid, \"Sigmoid\");\n  m.add_functor<SignFunctor>(\"Sign\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Sin, \"Sin\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Sinh, \"Sinh\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Sqrt, \"Sqrt\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Square, \"Square\");\n  ADD_UNARY_FUNCTOR_WITH_DY_X(Tan, \"Tan\");\n  ADD_UNARY_FUNCTOR_WITH_DY_Y(Tanh, \"Tanh\");\n  m.add_functor<NotEqualZeroFunctor>(\"NotEqualZero\");\n  m.add_functor<LogicalNotFunctor>(\"LogicalNot\");\n  m.add_functor<InplaceSinFunctor>(\"Sin_\");\n  m.add_functor<InplaceFloorFunctor>(\"Floor_\");\n  m.add_functor<InplaceCeilFunctor>(\"Ceil_\");\n  m.add_functor<InplaceRoundFunctor>(\"Round_\");\n  m.add_functor<TrigammaFunctor>(\"Trigamma\");\n};\n
\n#undef ADD_UNARY_FUNCTOR_WITH_DY_X\n#undef ADD_UNARY_FUNCTOR_WITH_DY_Y\n
\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/impl/unary_functor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_FUNCTIONAL_IMPL_UNARY_FUNCTOR_H_\n#define ONEFLOW_CORE_FUNCTIONAL_IMPL_UNARY_FUNCTOR_H_\n\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/tensor_processor.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nclass UnaryFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {x});\n  }\n\n protected:\n  UnaryFunctor() = default;\n  virtual ~UnaryFunctor() = default;\n\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass InplaceUnaryFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    JUST(CheckInplaceValid(x));\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    outputs->at(0) = x;\n    if (x->requires_grad()) {\n      JUST(OpInterpUtil::Dispatch(*op_, {JUST(functional::Identity(x))}, outputs.get()));\n    } else {\n      JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get()));\n    }\n    return outputs->at(0);\n  }\n\n protected:\n  InplaceUnaryFunctor() = default;\n  virtual ~InplaceUnaryFunctor() = default;\n\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass FloatUnaryFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    // The functor lowest Dtype is Float32. (For sigmoid, tanh and etc. 
)\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.AddInputs({x}, DType::Float())\n             .PromoteInputsToCommonDtype(true)\n             .PromoteIntegerInputsToFloatDtype(true)\n             .Apply());\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    return OpInterpUtil::Dispatch<one::Tensor>(*op_, input_tuple);\n  }\n\n protected:\n  FloatUnaryFunctor() = default;\n  virtual ~FloatUnaryFunctor() = default;\n\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass InplaceFloatUnaryFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {\n    TensorProcessor tensor_processor;\n    JUST(tensor_processor.AddInputs({x}, DType::Float()).Apply());\n    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());\n    JUST(CheckInplaceCastValid(x, input_tuple.at(0)));\n    JUST(CheckInplaceValid(x));\n    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);\n    outputs->at(0) = x;\n    if (x->requires_grad()) {\n      // It should copy input tensor in autograd_mode because these operators can't calculate\n      // in_grad with output.\n      JUST(OpInterpUtil::Dispatch(*op_, {JUST(functional::Identity(x))}, outputs.get()));\n    } else {\n      JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get()));\n    }\n    return outputs->at(0);\n  }\n\n protected:\n  InplaceFloatUnaryFunctor() = default;\n  virtual ~InplaceFloatUnaryFunctor() = default;\n\n  std::shared_ptr<OpExpr> op_;\n};\n\n}  // namespace impl\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FUNCTIONAL_IMPL_UNARY_FUNCTOR_H_\n"
  },
  {
    "path": "oneflow/core/functional/impl/util_ops_functor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/functional/function_library.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace impl {\n\nclass UtilOpsFunctor {\n public:\n  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input) const {\n    return JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {input}));\n  }\n\n protected:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass IsNanFunctor final : public UtilOpsFunctor {\n public:\n  IsNanFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"isnan\").Input(\"in\").Output(\"out\").Build()); }\n};\n\nclass IsInfFunctor final : public UtilOpsFunctor {\n public:\n  IsInfFunctor() { op_ = CHECK_JUST(one::OpBuilder(\"isinf\").Input(\"in\").Output(\"out\").Build()); }\n};\n\nclass IsFiniteFunctor final : public UtilOpsFunctor {\n public:\n  IsFiniteFunctor() {\n    op_ = CHECK_JUST(one::OpBuilder(\"isfinite\").Input(\"in\").Output(\"out\").Build());\n  }\n};\n\nclass DependFunctor {\n public:\n  DependFunctor() {\n    op_ = CHECK_JUST(\n        one::OpBuilder(\"depend\").Input(\"in\").Input(\"depend_tensor\").Output(\"out\").Build());\n  }\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in,\n                           const std::shared_ptr<one::Tensor>& depend_tensor) const {\n    return OpInterpUtil::Dispatch<Tensor>(*op_, {in, depend_tensor});\n  }\n\n private:\n  std::shared_ptr<OpExpr> op_;\n};\n\nclass DependTupleFunctor {\n public:\n  DependTupleFunctor() {\n    ops_.resize(kMaxInputCount);\n    for (int n = 0; n < ops_.size(); ++n) {\n      ops_[n] = CHECK_JUST(\n          one::OpBuilder(\"depend\").Input(\"in\").Input(\"depend_tensor\").Output(\"out\").Build());\n    }\n  }\n\n  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in,\n                           const one::TensorTuple& depends) const {\n    return _dispatch(in, depends, 0);\n  }\n\n private:\n  Maybe<Tensor> _dispatch(const std::shared_ptr<one::Tensor>& in, const one::TensorTuple& depends,\n                          const int pos) const {\n    const size_t ndepend = depends.size();\n    Maybe<Tensor> output = OpInterpUtil::Dispatch<Tensor>(*ops_[pos], {in, depends[pos]});\n    if (pos == ndepend - 1) { return output; }\n    return _dispatch(JUST(output), depends, pos + 1);\n  }\n\n  std::vector<std::shared_ptr<OpExpr>> ops_;\n};\n\n}  // namespace impl\n\nusing namespace 
impl;\n\n// Register all util-op functors in a single function-library block.\nONEFLOW_FUNCTION_LIBRARY(m) {\n  m.add_functor<IsNanFunctor>(\"IsNan\");\n  m.add_functor<IsInfFunctor>(\"IsInf\");\n  m.add_functor<IsFiniteFunctor>(\"IsFinite\");\n  m.add_functor<DependFunctor>(\"Depend\");\n  m.add_functor<DependTupleFunctor>(\"DependTuple\");\n};\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/packed_functor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_FUNCTIONAL_FUNCTOR_H_\n#define ONEFLOW_CORE_FUNCTIONAL_FUNCTOR_H_\n\n#include <memory>\n\n#include \"oneflow/core/common/function_traits.h\"\n#include \"oneflow/core/common/type_traits.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\ntemplate<typename T>\nusing remove_cvref_t = oneflow::detail::remove_cvref_t<T>;\n\ntemplate<typename T>\nclass PackedFunctor;\n\ntemplate<typename R, typename... Args>\nclass PackedFunctor<R(Args...)> {\n public:\n  PackedFunctor(const std::string& func_name, const std::function<R(Args...)>& impl)\n      : func_name_(func_name), impl_(impl) {}\n\n  virtual ~PackedFunctor() = default;\n\n  template<typename... TArgs>\n  R call(TArgs&&... args) const {\n    return impl_(std::forward<TArgs>(args)...);\n  }\n\n private:\n  std::string func_name_;\n  std::function<R(Args...)> impl_;\n};\n\ntemplate<typename T>\nclass PackedFunctorMaker;\n\ntemplate<typename R, typename... Args>\nclass PackedFunctorMaker<R(Args...)> {\n public:\n  using FType = R(const remove_cvref_t<Args>&...);\n\n  template<typename Func,\n           typename std::enable_if<\n               std::is_same<typename function_traits<Func>::func_type, R(Args...)>::value,\n               int>::type = 0>\n  static PackedFunctor<FType> make(const std::string& func_name, const Func& func) {\n    return PackedFunctor<FType>(func_name, [func](const remove_cvref_t<Args>&... args) -> R {\n      return func(std::forward<const remove_cvref_t<Args>&>(args)...);\n    });\n  }\n};\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FUNCTIONAL_FUNCTOR_H_\n"
  },
  {
    "path": "oneflow/core/functional/sequence_function.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FUNCTIONAL_SEQUENCE_FUNCTION_H_\n#define ONEFLOW_CORE_FUNCTIONAL_SEQUENCE_FUNCTION_H_\n\n#include <functional>\n#include <utility>\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\ntemplate<typename T>\nclass SequenceFunction;\n\ntemplate<typename R, typename... Args>\nclass SequenceFunction<R(Args...)> {\n public:\n  using first_f_type = std::function<R(Args...)>;\n  using f_type = std::function<R(\n      const decltype(std::declval<R>().Data_YouAreNotAllowedToCallThisFuncOutsideThisFile())&)>;\n\n  explicit SequenceFunction(first_f_type&& f) : fn_(std::forward<first_f_type>(f)) {}\n\n  explicit SequenceFunction(const first_f_type& f) : fn_(f) {}\n\n  SequenceFunction<R(Args...)>& then(f_type&& f) {\n    auto fn_ = std::move(this->fn_);\n    this->fn_ = [fn_, f](Args&&... args) -> R { return f(JUST(fn_(std::forward<Args>(args)...))); };\n    return *this;\n  }\n\n  SequenceFunction<R(Args...)>& then_if(bool condition, f_type&& f) {\n    return condition ? then(std::forward<f_type>(f)) : *this;\n  }\n\n  SequenceFunction<R(Args...)>& operator<<(f_type&& f) { return then(std::forward<f_type>(f)); }\n\n  R call(Args&&... args) const { return fn_(std::forward<Args>(args)...); }\n\n private:\n  std::function<R(Args...)> fn_;\n};\n\n#define sequence_function(f) SequenceFunction<decltype(f)>(f)\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FUNCTIONAL_SEQUENCE_FUNCTION_H_\n"
  },
  {
    "path": "oneflow/core/functional/tensor_index.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/functional/tensor_index.h\"\n\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/tensor_util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter/op_interpreter_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace {\n\nint64_t CountSpecifiedDims(const TensorIndex& index) {\n  int64_t specified_ndims = 0;\n  for (int i = 0; i < index.size(); ++i) {\n    const auto& index_item = index.at(i);\n    if (index_item.IsSlice() || index_item.IsInteger()) {\n      specified_ndims++;\n    } else if (index_item.IsTensor()) {\n      const auto& tensor = index_item.tensor();\n      if (IsMaskTensor(tensor)) {\n        specified_ndims += tensor->ndim();\n      } else {\n        specified_ndims++;\n      }\n    }\n  }\n  return specified_ndims;\n}\n\nMaybe<TensorTuple> ExpandMaskIndex(const std::shared_ptr<Tensor>& index) {\n  auto indices = std::make_shared<TensorTuple>();\n  const auto& res = JUST(functional::ArgWhere(index, DType::Int64()));\n  if (res->size() != 2) {\n    return Error::RuntimeError() << \"Argwhere should returns 2 tensors, but got \" << res->size();\n  }\n  auto size_tensor = res->at(1);\n  if (!size_tensor->is_eager()) {\n    return Error::RuntimeError()\n           << \"Advanced indexing by boolean(mask) tensor only valid in eager mode.\";\n  }\n  if (size_tensor->is_global()) {\n    // TODO(): check size_tensor sbp is broadcast.\n    size_tensor = JUST(functional::GlobalToLocal(size_tensor, /*copy=*/false));\n  }\n  int64_t size = 0;\n  const auto& callback = [&](ep::Stream* stream,\n                             const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) {\n    AutoMemcpy(stream, &size, eager_blob_object->dptr(), sizeof(size), memory::MakeHostMemCase(),\n               eager_blob_object->mem_case());\n  };\n  JUST(SyncAccessTensorWithTimeOut(size_tensor, callback, \"const\"));\n\n  for (int i = 0; i < index->ndim(); ++i) {\n    auto item = JUST(functional::Slice((*res)[0], {0, i}, {size, i + 1}, {1, 1},\n                                       /*enable_view_slice=*/false));\n    item = JUST(functional::Reshape(item, {size}));\n    indices->emplace_back(item);\n  }\n  return indices;\n}\n\n// NOTE: expand each non-empty indice to same shape.\nMaybe<TensorTuple> ExpandIndices(const TensorTuple& indices) {\n  std::shared_ptr<const Shape> expanded_shape;\n  {\n    bool first = true;\n    for (int i = 0; i < indices.size(); ++i) {\n      
if (!indices.at(i)) { continue; }\n      if (first) {\n        expanded_shape = indices.at(i)->shape();\n        first = false;\n      } else {\n        const auto& shape = indices.at(i)->shape();\n        int ndims = std::max(shape->NumAxes(), expanded_shape->NumAxes());\n        DimVector sizes(ndims);\n        for (int j = ndims - 1; j >= 0; --j) {\n          int dim = j - (ndims - shape->NumAxes());\n          int expanded_dim = j - (ndims - expanded_shape->NumAxes());\n          if (dim < 0) {\n            sizes[j] = expanded_shape->At(expanded_dim);\n          } else if (expanded_dim < 0) {\n            sizes[j] = shape->At(dim);\n          } else {\n            int size = shape->At(dim);\n            int expanded_size = expanded_shape->At(expanded_dim);\n            CHECK_OR_RETURN(size == expanded_size || size == 1 || expanded_size == 1)\n                << Error::RuntimeError() << \"The size of tensor a (\" << size\n                << \") must match the size of tensor b (\" << expanded_size\n                << \") at non-singleton dimension \" << i;\n            sizes[j] = size == 1 ? expanded_size : size;\n          }\n        }\n        expanded_shape.reset(new Shape(sizes));\n      }\n    }\n  }\n  auto expanded_indices = std::make_shared<TensorTuple>(indices.size());\n  for (int i = 0; i < indices.size(); ++i) {\n    if (!indices.at(i)) { continue; }\n    if (*(indices.at(i)->shape()) != *expanded_shape) {\n      expanded_indices->at(i) = JUST(Expand(indices.at(i), *expanded_shape));\n    } else {\n      expanded_indices->at(i) = indices.at(i);\n    }\n  }\n  return expanded_indices;\n}\n\n// NOTE(wyg):\n// Determine whether all tensor-indexed dims are contiguous.\n// e.g. [:, index0, index1, :] -> True\n// [index0, :, index1] -> False\n// [index0, index1, :] -> True\nMaybe<bool> IsContinuousSubspace(const TensorTuple& indices) {\n  int token = 0;\n  for (int i = 0; i < indices.size(); ++i) {\n    if (indices.at(i) && !token) {\n      token = 1;\n    } else if (indices.at(i) && token) {\n      if (token != 1) { return false; }\n    } else if (token) {\n      token += 1;\n    }\n  }\n  return true;\n}\n\n// NOTE(wyg):\n// Move the index subspace to the front, keeping it contiguous.\n// e.g. 
[:, index0, index1] -> [index0, index1, :]\nMaybe<std::vector<int>> TransposeFront(const std::shared_ptr<Tensor>& input,\n                                       const TensorTuple& indices, std::shared_ptr<Tensor>* output,\n                                       TensorTuple* valid_indices) {\n  std::vector<int> permute;\n  permute.reserve(input->ndim());\n  for (int i = 0; i < input->ndim(); ++i) {\n    if (i < indices.size() && indices.at(i)) {\n      permute.emplace_back(i);\n      valid_indices->emplace_back(indices.at(i));\n    }\n  }\n  for (int i = 0; i < input->ndim(); ++i) {\n    if (i >= indices.size() || !indices.at(i)) { permute.emplace_back(i); }\n  }\n  bool need_transpose = [&]() {\n    for (int i = 0; i < permute.size(); ++i) {\n      if (permute.at(i) != i) { return true; }\n    }\n    return false;\n  }();\n  if (need_transpose) {\n    *output = JUST(Transpose(input, permute));\n  } else {\n    *output = input;\n  }\n  return permute;\n}\n\nMaybe<Tensor> AdjustSubspace(const std::shared_ptr<Tensor>& input, const TensorTuple& indices,\n                             const int& index_ndim, bool reverse = false) {\n  int index_subspace_pos = -1;\n  for (int i = 0; i < indices.size(); ++i) {\n    if (indices.at(i)) {\n      index_subspace_pos = i;\n      break;\n    }\n  }\n  if (index_subspace_pos <= 0) { return input; }\n  int ndim = input->ndim();\n  CHECK_LE_OR_RETURN(index_subspace_pos + index_ndim, ndim)\n      << Error::IndexError()\n      << \"Failed to adjust subspace since the index is out of bounds for tensor dimension \" << ndim;\n  std::vector<int> permute;\n  {\n    permute.reserve(ndim);\n    if (reverse) {\n      for (int i = 0; i < index_ndim; ++i) { permute.emplace_back(index_subspace_pos + i); }\n      for (int i = 0; i < index_subspace_pos; ++i) { permute.emplace_back(i); }\n    } else {\n      for (int i = 0; i < index_subspace_pos; ++i) { permute.emplace_back(i + index_ndim); }\n      for (int i = 0; i < index_ndim; ++i) { permute.emplace_back(i); }\n    }\n    for (int i = permute.size(); i < ndim; ++i) { permute.emplace_back(i); }\n  }\n  return Transpose(input, permute);\n}\n\nMaybe<bool> HasFalseIndex(const TensorIndex& index) {\n  return std::any_of(index.begin(), index.end(), [](const detail::IndexItem& item) {\n    return item.IsBoolean() && !item.boolean();\n  });\n}\n\nbool IsValidScalarTensorIndex(const std::shared_ptr<one::Tensor>& tensor) {\n  if (!(tensor->dtype()->is_integer() || tensor->dtype() == DType::Bool())) { return false; }\n  return tensor->shape()->NumAxes() == 0 && tensor->shape()->elem_cnt() == 1;\n}\n\n// Permute back a global tensor whose dims were transposed to the front\nMaybe<Tensor> PermuteBackForGlobalTensor(const std::shared_ptr<Tensor>& result,\n                                         const std::vector<int>& permute) {\n  CHECK_OR_RETURN(result->is_global());                // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(result->ndim(), permute.size());  // NOLINT(maybe-need-error-msg)\n  std::vector<int> inv_permute(permute.size());\n  for (int32_t i = 0; i < permute.size(); ++i) { inv_permute[permute[i]] = i; }\n\n  bool is_identity_permute = true;\n  {\n    for (int32_t i = 0; i < permute.size(); ++i) {\n      if (inv_permute[i] != i) {\n        is_identity_permute = false;\n        break;\n      }\n    }\n  }\n  if (!is_identity_permute) {\n    return Transpose(result, inv_permute);\n  } else {\n    return result;\n  }\n}\n\n}  // namespace\n\nbool IsMaskTensor(const std::shared_ptr<Tensor>& tensor) {\n  return tensor->dtype() == DType::Int8() 
|| tensor->dtype() == DType::UInt8()\n         || tensor->dtype() == DType::Bool();\n}\n\nMaybe<void> PrepareSliceIndices(const TensorIndex& index, const Shape& shape,\n                                std::vector<detail::Slice>* slice_indices,\n                                TensorTuple* tensor_indices, std::vector<int64_t>* expand_dims,\n                                std::vector<int64_t>* target_dims) {\n  int64_t ndims = shape.NumAxes();\n  int64_t specified_ndims = CountSpecifiedDims(index);\n  CHECK_LE_OR_RETURN(specified_ndims, ndims)\n      << Error::IndexError() << \"Too many indices for tensor of dimension \" << ndims;\n  bool has_false_index = JUST(HasFalseIndex(index));\n  bool has_expand_boolean_dim = false;\n  int dim = 0;\n  for (int i = 0; i < index.size(); ++i) {\n    const auto& index_item = index.at(i);\n    if (index_item.IsNone()) {\n      expand_dims->emplace_back(dim);\n      slice_indices->emplace_back(0, 1, 1);\n      target_dims->emplace_back(1);\n      continue;\n    }\n    if (index_item.IsBoolean()) {\n      if (!has_expand_boolean_dim) {\n        int boolean_index = !has_false_index;\n        expand_dims->emplace_back(dim);\n        slice_indices->emplace_back(0, boolean_index, 1);\n        target_dims->emplace_back(boolean_index);\n        has_expand_boolean_dim = true;\n      }\n      continue;\n    }\n    if (index_item.IsEllipsis()) {\n      int64_t unspecified_ndims = ndims - specified_ndims;\n      unspecified_ndims = std::min(ndims - dim, unspecified_ndims);\n      for (int j = 0; j < unspecified_ndims; ++j) {\n        slice_indices->emplace_back(0, shape.At(dim + j), 1);\n        target_dims->emplace_back(shape.At(dim + j));\n      }\n      dim += unspecified_ndims;\n      continue;\n    }\n    CHECK_LT_OR_RETURN(dim, ndims)\n        << Error::IndexError() << \"Invalid index for tensor of dimension \" << ndims;\n    if (index_item.IsSlice()) {\n      const auto& slice = index_item.slice();\n      CHECK_GT_OR_RETURN(slice.step(), 0)\n          << Error::RuntimeError() << \"Step must be greater than zero.\";\n      int64_t step = std::min(slice.step(), shape.At(dim));\n      int64_t end = std::min(slice.end(), shape.At(dim));\n      int64_t start = std::min(slice.start(), shape.At(dim));\n      if (start < 0) { start += shape.At(dim); }\n      if (start < 0) { start = 0; }\n      if (end < 0) { end += shape.At(dim); }\n      if (end < start) { end = start; }\n      if (start == end) { step = 1; }\n      slice_indices->emplace_back(start, end, step);\n      int64_t length = start == end ? 
0 : (end - start + step - 1) / step;\n      target_dims->emplace_back(length);\n      dim++;\n    } else if (index_item.IsInteger()) {\n      int64_t integer = index_item.integer();\n      if (integer < 0) { integer += shape.At(dim); }\n      if (integer < 0 || integer >= shape.At(dim)) {\n        return Error::IndexError()\n               << \"Index \" << index_item.integer() << \" is out of bounds for dimension \" << dim\n               << \" with size \" << shape.At(dim);\n      }\n      slice_indices->emplace_back(integer, integer + 1, 1);\n      dim++;\n    } else if (index_item.IsTensor()) {\n      const auto& tensor = index_item.tensor();\n      if (IsValidScalarTensorIndex(tensor) && !LazyMode::is_enabled()) {\n        if (tensor->dtype()->is_integer() && tensor->dtype()->data_type() != DataType::kBool) {\n          int64_t integer = JUST(GetItemInScalarTensor<int64_t>(tensor));\n          // Keep the original value for the error message; `index_item.integer()` is not\n          // meaningful for a tensor index item.\n          const int64_t raw_index = integer;\n          if (integer < 0) { integer += shape.At(dim); }\n          if (integer < 0 || integer >= shape.At(dim)) {\n            return Error::IndexError()\n                   << \"Index \" << raw_index << \" is out of bounds for dimension \" << dim\n                   << \" with size \" << shape.At(dim);\n          }\n          slice_indices->emplace_back(integer, integer + 1, 1);\n          dim++;\n        } else {\n          bool boolean_index = JUST(GetItemInScalarTensor<bool>(tensor));\n          if (!has_expand_boolean_dim) {\n            expand_dims->emplace_back(dim);\n            slice_indices->emplace_back(0, boolean_index, 1);\n            target_dims->emplace_back(boolean_index);\n            has_expand_boolean_dim = true;\n          }\n        }\n      } else {\n        auto indices = std::make_shared<TensorTuple>();\n        if (tensor->dtype() == DType::Int8() || tensor->dtype() == DType::UInt8()\n            || tensor->dtype() == DType::Bool()) {\n          for (int j = 0; j < tensor->ndim(); ++j) {\n            if (tensor->shape()->At(j) != shape.At(dim + j)) {\n              return Error::IndexError()\n                     << \"The shape of the mask \" << tensor->shape()->ToString() << \" at index \" << j\n                     << \" does not match the shape of the indexed tensor \" << shape.ToString()\n                     << \" at index \" << dim + j;\n            }\n          }\n          indices = JUST(ExpandMaskIndex(tensor));\n        } else {\n          indices->emplace_back(tensor);\n        }\n        for (int j = 0; j < indices->size(); ++j) {\n          slice_indices->emplace_back(0, shape.At(dim), 1);\n          tensor_indices->resize(target_dims->size());\n          tensor_indices->emplace_back(indices->at(j));\n          target_dims->emplace_back(shape.At(dim));\n          dim++;\n        }\n      }\n    }\n  }\n  for (int i = dim; i < ndims; ++i) {\n    slice_indices->emplace_back(0, shape.At(i), 1);\n    target_dims->emplace_back(shape.At(i));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<std::vector<detail::Slice>> RemoveExpandDimSlice(\n    const std::vector<detail::Slice>& expand_slices, const std::vector<int64_t>& expand_dims) {\n  auto slices = std::make_shared<std::vector<detail::Slice>>();\n  std::vector<int> mask(expand_slices.size(), 0);\n  for (const auto& dim : expand_dims) {\n    if (dim >= expand_slices.size()) {\n      return Error::IndexError() << \"Dimension \" << dim << \" is out of bounds for size \"\n                                 << expand_slices.size();\n    }\n    mask[dim] = 1;\n  }\n  for (int i = 0; i < expand_slices.size(); ++i) {\n    
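// Keep only the slices whose dimensions were not added by expansion (None/boolean indices).\n    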
if (!mask[i]) { slices->emplace_back(expand_slices.at(i)); }\n  }\n  return slices;\n}\n\nMaybe<Tensor> ApplyAdvancedIndexing(const std::shared_ptr<Tensor>& input,\n                                    const TensorTuple& indices) {\n  CHECK_GE_OR_RETURN(input->ndim(), indices.size())\n      << Error::IndexError() << \"Too many indices for tensor of dimension \" << input->ndim();\n  const auto& expanded_indices = JUST(ExpandIndices(indices));\n  bool is_continuous_subspace = JUST(IsContinuousSubspace(indices));\n\n  // The start dimension cannot be specified for `gather_nd`, so we transpose the\n  // tensor-indexed dims to the front whenever the first index is null.\n  std::shared_ptr<Tensor> transposed_input;\n  TensorTuple valid_indices;\n  JUST(TransposeFront(input, *expanded_indices, &transposed_input, &valid_indices));\n  if (valid_indices.empty()) { return input; }\n  int index_ndim = valid_indices.at(0)->ndim();\n  std::shared_ptr<Tensor> packed_indices;\n  if (valid_indices.size() == 1) {\n    packed_indices = JUST(functional::Unsqueeze(valid_indices.at(0), -1));\n  } else {\n    packed_indices = JUST(Stack(valid_indices, 0));\n    int packed_ndim = packed_indices->ndim();\n    CHECK_GT_OR_RETURN(packed_ndim, 0)\n        << Error::RuntimeError() << \"Index array dimension should be greater than 0.\";\n    std::vector<int> permute(packed_ndim);\n    permute[packed_ndim - 1] = 0;\n    std::iota(permute.begin(), permute.end() - 1, 1);\n    packed_indices = JUST(Transpose(packed_indices, permute))->contiguous();\n  }\n\n  CHECK_EQ_OR_RETURN(transposed_input->is_local(), packed_indices->is_local())\n      << Error::RuntimeError() << \"The input and indices must be both local or global.\";\n\n  auto result = JUST(GatherNd(transposed_input, packed_indices));\n\n  int required_ndim = input->ndim() - valid_indices.size() + index_ndim;\n  CHECK_EQ_OR_RETURN(result->ndim(), required_ndim)\n      << Error::RuntimeError() << \"The indexing result dimension is \" << result->ndim()\n      << \", but should be \" << required_ndim;\n  if (is_continuous_subspace) {\n    result = JUST(AdjustSubspace(result, indices, index_ndim, /*reverse*/ false));\n  }\n  return result;\n}\n\nMaybe<Tensor> ApplyAdvancedIndexingUpdate(const std::shared_ptr<Tensor>& input,\n                                          const TensorTuple& indices,\n                                          const std::shared_ptr<Tensor>& value) {\n  CHECK_GE_OR_RETURN(input->ndim(), indices.size())\n      << Error::IndexError() << \"Too many indices for tensor of dimension \" << input->ndim();\n  const auto& expanded_indices = JUST(ExpandIndices(indices));\n  bool is_continuous_subspace = JUST(IsContinuousSubspace(indices));\n\n  // The start dimension cannot be specified for `scatter_nd`, so we transpose the\n  // tensor-indexed dims to the front whenever the first index is null.\n  std::shared_ptr<Tensor> transposed_input;\n  TensorTuple valid_indices;\n  const auto& transposed_input_permute =\n      JUST(TransposeFront(input, *expanded_indices, &transposed_input, &valid_indices));\n  // NOTE: For a local tensor, we make sure that transposed_input is a view of input.\n  //       Therefore we need not transpose it back, because tensor_scatter_nd_update\n  //       updates the value in the same memory.\n  if (input->is_local()) {\n    CHECK_EQ_OR_RETURN(JUST(transposed_input->tensor_storage()), JUST(input->tensor_storage()))\n        << Error::RuntimeError()\n        << \"This setitem operator must enable view mechanism, please try to set \"\n           
\"ONEFLOW_DISABLE_VIEW=0\";\n  }\n\n  if (valid_indices.empty()) {\n    CHECK_EQ_OR_RETURN(value->nelement(), 0) << Error::IndexError() << \"invalid indices\";\n    return input;\n  }\n  int index_ndim = valid_indices[0]->ndim();\n  auto packed_indices = JUST(Stack(valid_indices, 0));\n  {\n    int packed_ndim = packed_indices->ndim();\n    CHECK_GT_OR_RETURN(packed_ndim, 0)\n        << Error::RuntimeError() << \"Index array dimension should be greater than 0.\";\n    std::vector<int> permute(packed_ndim);\n    permute[packed_ndim - 1] = 0;\n    std::iota(permute.begin(), permute.end() - 1, 1);\n    packed_indices = JUST(Transpose(packed_indices, permute))->contiguous();\n  }\n\n  CHECK_EQ_OR_RETURN(transposed_input->is_local(), packed_indices->is_local())\n      << Error::RuntimeError() << \"The input and indices must be both local or global.\";\n\n  Shape expand_shape;\n  {\n    if (is_continuous_subspace) {\n      bool index_subspace_begin = true;\n      for (int i = 0; i < indices.size(); ++i) {\n        // if the index is the first not-null index\n        if (indices[i]) {\n          if (!index_subspace_begin) { continue; }\n          for (int j = 0; j < index_ndim; ++j) {\n            expand_shape.emplace_back(valid_indices[0]->shape()->At(j));\n          }\n          index_subspace_begin = false;\n        } else {\n          expand_shape.emplace_back(input->shape()->At(i));\n        }\n      }\n    } else {\n      expand_shape = *(valid_indices[0]->shape());\n      for (int i = 0; i < indices.size(); ++i) {\n        if (!indices[i]) { expand_shape.emplace_back(input->shape()->At(i)); }\n      }\n    }\n    for (int i = indices.size(); i < input->ndim(); ++i) {\n      expand_shape.emplace_back(input->shape()->At(i));\n    }\n  }\n  std::shared_ptr<Tensor> expand_value = JUST(Expand(value, expand_shape));\n\n  // reverse adjust value if index subspace is continuous but transposed since the start\n  // dimension cannot be specified for `scatter_nd`\n  if (is_continuous_subspace) {\n    expand_value = JUST(AdjustSubspace(expand_value, indices, index_ndim, /*reverse*/ true));\n  }\n  JUST(TensorScatterNdUpdate(transposed_input, packed_indices, expand_value, /*inplace=*/true));\n  // Global tensor is not support view, so we should permute back and copy to origin input if need\n  if (transposed_input->is_global()) {\n    return PermuteBackForGlobalTensor(transposed_input, *transposed_input_permute);\n  }\n  return transposed_input;\n}\n\nMaybe<Tensor> ApplySelectIndexing(const std::shared_ptr<one::Tensor>& input,\n                                  const TensorIndex& tensor_index) {\n  const int32_t index = tensor_index[0].integer();\n  const int32_t ndim = input->ndim();\n  CHECK_OR_RETURN(ndim > 0) << Error::RuntimeError()\n                            << \"select() cannot be applied to a 0-dim tensor.\";\n  const int32_t pos_dim = 0;\n  auto size = input->dim(pos_dim);\n  CHECK_OR_RETURN(index >= -size && index < size)\n      << Error::IndexError() << \"Index out of range (expected to be in range of [\" << -size << \",\"\n      << size - 1 << \"], but got \" << index << \")\";\n  int32_t pos_index = index >= 0 ? 
index : index + size;\n  std::vector<int64_t> sizes(input->shape()->dim_vec().begin() + 1,\n                             input->shape()->dim_vec().end());\n  const auto& stride = *JUST(input->stride());\n  const int64_t storage_offset = JUST(input->storage_offset()) + pos_index * stride[pos_dim];\n  std::vector<int64_t> strides(stride.begin() + 1, stride.end());\n  return functional::AsStrided(input, sizes, strides, storage_offset);\n}\n\nMaybe<void> UnifyInputAndIndicesOnDevice(const std::shared_ptr<Tensor>& x,\n                                         TensorTuple& tensor_indices) {\n  if (x->is_local()) {\n    const auto x_device = JUST(x->device());\n    for (int64_t i = 0; i < tensor_indices.size(); ++i) {\n      const auto& tensor_index = tensor_indices[i];\n      if (tensor_index == nullptr) { continue; }\n      if (tensor_index->is_global()) { return Maybe<void>::Ok(); }\n      const auto tensor_index_device = JUST(tensor_index->device());\n      if ((tensor_index_device->type() != x_device->type())\n          || (tensor_index_device->device_id() != x_device->device_id())) {\n        tensor_indices[i] =\n            JUST(Copy(tensor_index, x_device->type(), x_device->device_id(), /*pin_memory=*/false));\n      }\n    }\n  } else {\n    // global tensor\n    const auto& placement = JUST(x->parallel_desc());\n    const auto& broadcast_sbp = JUST(MakeBroadcastSbpParallel());\n    int n = JUST(x->nd_sbp())->sbp_parallel_size();\n    std::vector<Symbol<SbpParallel>> grad_sbp_tuple;\n    for (int64_t i = 0; i < tensor_indices.size(); ++i) {\n      const auto& tensor_index = tensor_indices[i];\n      if (tensor_index == nullptr) { continue; }\n      if (tensor_index->is_local()) {\n        // NOTE: LocalToGlobal should be called in eager mode\n        LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);\n        tensor_indices[i] = JUST(ToGlobal(tensor_index, placement,\n                                          std::vector<Symbol<SbpParallel>>(n, broadcast_sbp),\n                                          grad_sbp_tuple, /*check_meta=*/false, /*copy=*/false));\n      }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/tensor_index.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_FUNCTIONAL_TENSOR_INDEX_H_\n#define ONEFLOW_CORE_FUNCTIONAL_TENSOR_INDEX_H_\n\n#include <cstdint>\n#include <limits>\n#include <vector>\n\n#include \"oneflow/core/common/shape.h\"\n\nnamespace oneflow {\nnamespace one {\n\nclass Tensor;\nclass TensorTuple;\n\nnamespace functional {\n\nnamespace detail {\n\nstruct NoneIndex {};\nstruct EllipsisIndex {};\n\nclass Slice {\n public:\n  Slice() : Slice(0, std::numeric_limits<int64_t>::max(), 1) {}\n  explicit Slice(int64_t start) : Slice(start, std::numeric_limits<int64_t>::max(), 1) {}\n  explicit Slice(int64_t start, int64_t end) : Slice(start, end, 1) {}\n  explicit Slice(int64_t start, int64_t end, int64_t step)\n      : start_(start), end_(end), step_(step) {}\n\n  int64_t start() const { return start_; }\n  int64_t end() const { return end_; }\n  int64_t step() const { return step_; }\n  std::string ToString() const {\n    std::stringstream ss;\n    ss << \"[\" << start_ << \":\" << end_ << \":\" << step_ << \"]\\n\";\n    return ss.str();\n  }\n\n private:\n  int64_t start_;\n  int64_t end_;\n  int64_t step_;\n};\n\nclass IndexItem {\n public:\n  IndexItem() : IndexItem(NoneIndex()) {}\n  explicit IndexItem(NoneIndex none) : item_{.dummy = 0}, tag_(HAS_NONE) {}\n\n  explicit IndexItem(int64_t start, int64_t end, int64_t step)\n      : item_{.slice = Slice{start, end, step}}, tag_(HAS_SLICE) {}\n  explicit IndexItem(const Slice& slice) : item_{.slice = slice}, tag_(HAS_SLICE) {}\n\n  explicit IndexItem(int64_t index) : item_{.i = index}, tag_(HAS_INT) {}\n  explicit IndexItem(bool boolean) : item_{.b = boolean}, tag_(HAS_BOOLEAN) {}\n  explicit IndexItem(EllipsisIndex ellipsis) : item_{.dummy = 0}, tag_(HAS_ELLIPSIS) {}\n\n  explicit IndexItem(const std::shared_ptr<Tensor>& tensor)\n      : item_{.dummy = 0}, tensor_(tensor), tag_(HAS_TENSOR) {}\n\n  bool IsSlice() const { return tag_ == HAS_SLICE; }\n  const Slice& slice() const { return item_.slice; }\n\n  bool IsInteger() const { return tag_ == HAS_INT; }\n  int64_t integer() const { return item_.i; }\n\n  bool IsBoolean() const { return tag_ == HAS_BOOLEAN; }\n  bool boolean() const { return item_.b; }\n\n  bool IsEllipsis() const { return tag_ == HAS_ELLIPSIS; }\n\n  bool IsNone() const { return tag_ == HAS_NONE; }\n\n  bool IsTensor() const { return tag_ == HAS_TENSOR; }\n  const std::shared_ptr<Tensor>& tensor() const { return tensor_; }\n\n private:\n  union {\n    Slice slice;\n    bool b;\n    int64_t i;\n    char dummy;\n  } item_;\n  std::shared_ptr<Tensor> tensor_;\n  enum { HAS_SLICE, HAS_BOOLEAN, HAS_INT, HAS_ELLIPSIS, HAS_NONE, HAS_TENSOR } tag_;\n};\n\n}  // namespace detail\n\nclass TensorIndex : public std::vector<detail::IndexItem> {\n public:\n  using std::vector<detail::IndexItem>::vector;\n};\n\nbool IsMaskTensor(const std::shared_ptr<Tensor>& tensor);\n\nMaybe<void> PrepareSliceIndices(const TensorIndex& index, const Shape& 
shape,\n                                std::vector<detail::Slice>* slice_indices,\n                                TensorTuple* tensor_indices, std::vector<int64_t>* expand_dims,\n                                std::vector<int64_t>* target_dims);\n\nMaybe<std::vector<detail::Slice>> RemoveExpandDimSlice(\n    const std::vector<detail::Slice>& expand_slices, const std::vector<int64_t>& expand_dims);\n\nMaybe<Tensor> ApplyAdvancedIndexing(const std::shared_ptr<Tensor>& input,\n                                    const TensorTuple& indices);\n\nMaybe<Tensor> ApplySelectIndexing(const std::shared_ptr<one::Tensor>& input,\n                                  const TensorIndex& index);\n\nMaybe<void> UnifyInputAndIndicesOnDevice(const std::shared_ptr<Tensor>& x,\n                                         TensorTuple& tensor_indices);\n\nMaybe<Tensor> ApplyAdvancedIndexingUpdate(const std::shared_ptr<Tensor>& input,\n                                          const TensorTuple& indices,\n                                          const std::shared_ptr<Tensor>& value);\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FUNCTIONAL_TENSOR_INDEX_H_\n"
  },
  {
    "path": "oneflow/core/functional/tensor_processor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <glog/logging.h>\n#include <cstdio>\n#include \"oneflow/core/functional/tensor_processor.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nnamespace {\n\nSymbol<DType> ComputeCommonDType(const TensorTuple& tensor_tuple) {\n  Symbol<DType> common_dtype = DType::InvalidDataType();\n  bool all_scalar_tensors = std::all_of(\n      tensor_tuple.begin(), tensor_tuple.end(),\n      [](const std::shared_ptr<Tensor>& tensor) { return tensor->shape()->NumAxes() == 0; });\n  for (auto& tensor_ptr : tensor_tuple) {\n    // skip scalar tensor\n    if (!all_scalar_tensors && tensor_ptr->shape()->NumAxes() == 0\n        && !(tensor_ptr->dtype()->is_complex())) {\n      continue;\n    }\n    common_dtype = promoteTypes(tensor_ptr->dtype(), common_dtype);\n  }\n  return common_dtype;\n}\n\nbool CheckHasDifferentInputDType(const TensorTuple& tensor_tuple) {\n  if (tensor_tuple.size() <= 1) { return false; }\n  Symbol<DType> common_dtype = tensor_tuple[0]->dtype();\n  for (auto& tensor_ptr : tensor_tuple) {\n    if (common_dtype != tensor_ptr->dtype()) { return true; }\n  }\n  return false;\n}\n\nMaybe<void> CastToSameType(TensorTuple& tensor_tuple, const Symbol<DType>& common_dtype) {\n  for (auto& tensor_ptr : tensor_tuple) {\n    if (tensor_ptr->dtype() != common_dtype) {\n      tensor_ptr = JUST(functional::Cast(tensor_ptr, common_dtype, /*pin_memory=*/false));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nTensorProcessor& TensorProcessor::AddInputs(const TensorTuple& init_tensor_or_tuple) {\n  for (const auto& tensor : init_tensor_or_tuple) {\n    tensor_tuple_.emplace_back(tensor);\n    inputs_lowest_dtype_vec_.emplace_back(DType::InvalidDataType());\n  }\n  return *this;\n}\n\nTensorProcessor& TensorProcessor::AddInputs(const TensorTuple& init_tensor_or_tuple,\n                                            Symbol<DType> tensor_lowest_dtype) {\n  for (const auto& tensor : init_tensor_or_tuple) {\n    tensor_tuple_.emplace_back(tensor);\n    inputs_lowest_dtype_vec_.emplace_back(tensor_lowest_dtype);\n  }\n  return *this;\n}\n\nTensorProcessor& TensorProcessor::PromoteInputsToCommonDtype(bool is_promote) {\n  promote_inputs_to_common_dtype_ = is_promote;\n  return *this;\n}\n\nTensorProcessor& TensorProcessor::PromoteInputsToCommonDtype(\n    bool is_promote, const Optional<Symbol<DType>>& promote_dtype) {\n  promote_inputs_to_common_dtype_ = is_promote;\n  promote_dtype_ = promote_dtype;\n  return *this;\n}\n\nTensorProcessor& TensorProcessor::PromoteIntegerInputsToFloatDtype(bool is_promote) {\n  promote_integer_inputs_to_float_ = is_promote;\n  CHECK_OR_THROW(!promote_integer_inputs_to_float_ || 
promote_inputs_to_common_dtype_)\n      << \"promote_inputs_to_common_dtype should be set to 'True' before \"\n         \"promote_integer_inputs_to_float is set to 'True'!\";\n  return *this;\n}\n\nMaybe<void> TensorProcessor::Apply() {\n  if (promote_inputs_to_common_dtype_) {\n    bool has_different_input_dtype = CheckHasDifferentInputDType(tensor_tuple_);\n    if (has_different_input_dtype) {\n      if (promote_dtype_.has_value()) {\n        common_dtype_ = CHECK_JUST(promote_dtype_);\n      } else {\n        common_dtype_ = ComputeCommonDType(tensor_tuple_);\n      }\n      if (promote_integer_inputs_to_float_ && common_dtype_->is_integer()) {\n        // Promotes common dtype to the default float scalar type, if needed.\n        // same as PyTorch's computeTypes() in torch/csrc/jit/codegen/cuda/type_promotion.cpp\n        common_dtype_ = DType::Float();\n      }\n      JUST(CastToSameType(tensor_tuple_, common_dtype_));\n    } else {\n      if (tensor_tuple_.size() == 1\n          && !((tensor_tuple_[0]->dtype()->is_floating_point())\n               || tensor_tuple_[0]->dtype()->is_complex())) {\n        Symbol<DType> cast_dtype = (inputs_lowest_dtype_vec_[0] == DType::InvalidDataType())\n                                       ? DType::Float()\n                                       : inputs_lowest_dtype_vec_[0];\n        JUST(CastToSameType(tensor_tuple_, cast_dtype));\n      }\n    }\n  } else {\n    for (int i = 0; i < tensor_tuple_.size(); ++i) {\n      // Cast each input to its `lowest_dtype` attribute if the input tensor dtype has a lower\n      // priority than that `lowest_dtype`.\n      Symbol<DType> base_dtype = inputs_lowest_dtype_vec_.at(i);\n      if (base_dtype->data_type()\n          && DType::priority_order[base_dtype->data_type()]\n                 > DType::priority_order[tensor_tuple_.at(i)->dtype()->data_type()]) {\n        tensor_tuple_[i] =\n            JUST(one::functional::Cast(tensor_tuple_[i], base_dtype, /*pin_memory=*/false));\n      }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nstatic bool IsAllContiguous(const TensorTuple& tensors) {\n  for (const auto& t : tensors) {\n    if (t && !t->is_contiguous()) { return false; }\n  }\n  return true;\n}\n\nMaybe<void> TensorLayoutProcessor::Apply() {\n  if (LazyMode::is_enabled()) { return Maybe<void>::Ok(); }\n  if (!non_contiguous_enabled_ && !IsAllContiguous(inputs_)) {\n    contiguous_inputs_.resize(inputs_.size());\n    for (int i = 0; i < inputs_.size(); ++i) { contiguous_inputs_[i] = inputs_[i]->contiguous(); }\n  }\n  // inplace operation is not allowed if input is non-contiguous and non-contiguous is\n  // not supported for this operation\n  if (!non_contiguous_enabled_ && outputs_ && !IsAllContiguous(*outputs_)) {\n    post_process_outputs_.reserve(outputs_->size());\n    post_process_output_indices_.reserve(outputs_->size());\n    for (int i = 0; i < outputs_->size(); ++i) {\n      if ((*outputs_)[i] && !(*outputs_)[i]->is_contiguous()) {\n        post_process_outputs_.emplace_back((*outputs_)[i]);\n        post_process_output_indices_.emplace_back(i);\n        (*outputs_)[i] = nullptr;\n      }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nTensorLayoutProcessor::~TensorLayoutProcessor() {\n  for (int i = 0; i < post_process_output_indices_.size(); ++i) {\n    int output_index = post_process_output_indices_[i];\n    CHECK_OR_THROW((*outputs_)[output_index])\n        << \"the output at index \" << output_index << \" should not be nullptr\";\n    functional::TensorIndex ellipsis_index;\n    
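// post_process_outputs_[i][...] = computed result: write the freshly computed contiguous\n    // result back into the original non-contiguous output tensor in place.\n    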
ellipsis_index.emplace_back(functional::detail::EllipsisIndex());\n    CHECK_JUST(functional::TensorSetItem(post_process_outputs_[i], ellipsis_index,\n                                         (*outputs_)[output_index]));\n    (*outputs_)[output_index] = post_process_outputs_[i];\n  }\n}\n\nMaybe<void> TensorAutoCastProcessor::Apply() {\n  if (!autocast::is_enabled()) { return Maybe<void>::Ok(); }\n  if (autocast_meta_.autocast_color() == autocast::kNoColor) { return Maybe<void>::Ok(); }\n  auto autocast_device_type = autocast::get_autocast_device_type();\n  auto autocast_dtype = autocast::get_autocast_dtype();\n  auto IsDeviceType = [](const std::shared_ptr<Tensor>& tensor,\n                         DeviceType device_type) -> Maybe<bool> {\n    return tensor->is_local() ? JUST(tensor->device())->enum_type() == device_type\n                              : JUST(tensor->parallel_desc())->device_type() == device_type;\n  };\n  bool is_autocast_eligible = [&]() {\n    if (!autocast_meta_.is_autocast_eligible(autocast_device_type, autocast_dtype)) {\n      return false;\n    }\n    // Skip autocast if output data type is float32\n    if (outputs_) {\n      for (const auto& output : *outputs_) {\n        if (output && output->dtype() != autocast_dtype) { return false; }\n      }\n    }\n    // Skip autocast if any input is float32 for gray or clear list\n    if (autocast_meta_.autocast_color() != autocast::kWhite) {\n      for (int i = 0; i < inputs_.size(); ++i) {\n        if (autocast_meta_.is_args_autocast_eligible(i) && inputs_[i]->dtype()->is_floating_point()\n            && inputs_[i]->dtype() != autocast_dtype) {\n          return false;\n        }\n      }\n    }\n    return true;\n  }();\n  // Disable autocast temporarily to avoid falling into an infinite loop\n  autocast::set_enabled(false);\n  if (is_autocast_eligible) {\n    const auto& args_eligible = autocast_meta_.is_args_autocast_eligible();\n    CHECK_EQ_OR_RETURN(args_eligible.size(), inputs_.size())\n        << Error::RuntimeError() << \"argument autocast eligible size should equal the input size\";\n    autocast_inputs_.resize(inputs_.size());\n    for (int i = 0; i < inputs_.size(); ++i) {\n      if (args_eligible[i] && JUST(IsDeviceType(inputs_[i], autocast_device_type))\n          && inputs_[i]->dtype()->is_floating_point() && inputs_[i]->dtype() != autocast_dtype) {\n        autocast_inputs_[i] = JUST(autocast::cached_cast(inputs_[i], autocast_dtype,\n                                                         JUST(inputs_[i]->device())->enum_type()));\n      } else {\n        autocast_inputs_[i] = inputs_[i];\n      }\n    }\n  } else {\n    // Fallback to float32\n    auto common_dtype = ComputeCommonDType(inputs_);\n    auto promote_dtype = promoteTypes(common_dtype, DType::Float());\n    autocast_inputs_.resize(inputs_.size());\n    for (int i = 0; i < inputs_.size(); ++i) {\n      if (JUST(IsDeviceType(inputs_[i], autocast_device_type))\n          && inputs_[i]->dtype()->is_floating_point() && inputs_[i]->dtype() != promote_dtype) {\n        autocast_inputs_[i] = JUST(functional::To(inputs_[i], promote_dtype, /*copy*/ false));\n      } else {\n        autocast_inputs_[i] = inputs_[i];\n      }\n    }\n  }\n  // Re-enable autocast to restore the autocast state\n  autocast::set_enabled(true);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/functional/tensor_processor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_FUNCTIONAL_TENSOR_PROCESSOR_H_\n#define ONEFLOW_CORE_FUNCTIONAL_TENSOR_PROCESSOR_H_\n\n#include <algorithm>\n#include <functional>\n#include <memory>\n#include <tuple>\n\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n#include \"oneflow/core/framework/autocast.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/common/optional.h\"\n\nnamespace oneflow {\nnamespace one {\nnamespace functional {\n\nclass TensorProcessor final {\n public:\n  TensorProcessor()\n      : common_dtype_(DType::InvalidDataType()),\n        promote_dtype_(NullOpt),\n        promote_inputs_to_common_dtype_(false),\n        promote_integer_inputs_to_float_(false){};\n  TensorProcessor& AddInputs(const TensorTuple& init_list);\n  TensorProcessor& AddInputs(const TensorTuple& init_list, Symbol<DType> tensor_lowest_dtype);\n\n  Maybe<void> Apply();\n  TensorProcessor& PromoteInputsToCommonDtype(bool is_promote);\n  TensorProcessor& PromoteInputsToCommonDtype(bool is_promote,\n                                              const Optional<Symbol<DType>>& promote_dtype);\n  TensorProcessor& PromoteIntegerInputsToFloatDtype(bool is_promote);\n  Maybe<TensorTuple&> GetInputs() { return tensor_tuple_; };\n\n private:\n  TensorTuple tensor_tuple_;\n  Symbol<DType> common_dtype_;\n  Optional<Symbol<DType>> promote_dtype_;\n  std::vector<Symbol<DType>> inputs_lowest_dtype_vec_;\n\n  bool promote_inputs_to_common_dtype_;\n  bool promote_integer_inputs_to_float_;\n};\n\nclass TensorLayoutProcessor final {\n public:\n  TensorLayoutProcessor(const TensorTuple& inputs, bool non_contiguous_enabled)\n      : TensorLayoutProcessor(inputs, nullptr, non_contiguous_enabled) {}\n  TensorLayoutProcessor(const TensorTuple& inputs, TensorTuple* outputs,\n                        bool non_contiguous_enabled)\n      : inputs_(inputs), outputs_(outputs), non_contiguous_enabled_(non_contiguous_enabled) {}\n\n  ~TensorLayoutProcessor();\n\n  Maybe<void> Apply();\n\n  const TensorTuple& inputs() const {\n    if (!contiguous_inputs_.empty()) { return contiguous_inputs_; }\n    return inputs_;\n  }\n  TensorTuple* outputs() const { return outputs_; }\n\n private:\n  const TensorTuple& inputs_;\n  TensorTuple* outputs_;\n  bool non_contiguous_enabled_;\n  TensorTuple contiguous_inputs_;\n  std::vector<int> post_process_output_indices_;\n  TensorTuple post_process_outputs_;\n};\n\nclass TensorAutoCastProcessor final {\n public:\n  TensorAutoCastProcessor(const TensorTuple& inputs, const autocast::AutoCastMeta& autocast_meta)\n      : TensorAutoCastProcessor(inputs, nullptr, autocast_meta) {}\n  TensorAutoCastProcessor(const TensorTuple& inputs, TensorTuple* outputs,\n                          const autocast::AutoCastMeta& autocast_meta)\n      : inputs_(inputs), outputs_(outputs), autocast_meta_(autocast_meta) {}\n\n  ~TensorAutoCastProcessor() = 
default;\n\n  Maybe<void> Apply();\n\n  const TensorTuple& inputs() const {\n    if (!autocast_inputs_.empty()) { return autocast_inputs_; }\n    return inputs_;\n  }\n\n  TensorTuple* outputs() const { return outputs_; }\n\n private:\n  const TensorTuple& inputs_;\n  TensorTuple* outputs_;\n  const autocast::AutoCastMeta& autocast_meta_;\n  TensorTuple autocast_inputs_;\n};\n\ntemplate<typename... TPArgs>\nstruct TupleTrait {\n  constexpr static size_t size = sizeof...(TPArgs);\n  constexpr static size_t max_storage_size = std::max({sizeof(TPArgs)...});\n  constexpr static size_t alignment = std::max({alignof(TPArgs)...});\n  using type = std::tuple<TPArgs...>;\n};\n\nstruct TensorProcessorTuple {\n  using trait = TupleTrait<TensorLayoutProcessor, TensorAutoCastProcessor>;\n  constexpr static size_t size = trait::size;\n  constexpr static size_t max_storage_size = trait::max_storage_size;\n  constexpr static size_t alignment = trait::alignment;\n  using type = typename trait::type;\n};\n\nclass TensorProcessorStorage {\n public:\n  constexpr static size_t TPMaxStorageSize = TensorProcessorTuple::max_storage_size;\n\n  TensorProcessorStorage() = default;\n  TensorProcessorStorage(TensorProcessorStorage&& other) = default;\n\n  ~TensorProcessorStorage() {\n    if (deleter_) { deleter_(buffer_); }\n  }\n\n  template<typename TP, typename... Args>\n  void New(Args&&... args) {\n    static_assert(sizeof(TP) <= TPMaxStorageSize, \"Insufficient buffer size\");\n    new (buffer_) TP(std::forward<Args>(args)...);\n    deleter_ = [](char* buffer) { reinterpret_cast<TP*>(buffer)->~TP(); };\n  }\n\n  template<typename TP>\n  TP* As() {\n    return reinterpret_cast<TP*>(buffer_);\n  }\n\n private:\n  alignas(TensorProcessorTuple::alignment) char buffer_[TPMaxStorageSize];\n  std::function<void(char*)> deleter_;\n};\n\nclass TensorProcessorPipe final {\n public:\n  constexpr static size_t TPSize = TensorProcessorTuple::size;\n\n  TensorProcessorPipe(const TensorTuple& inputs) : TensorProcessorPipe(inputs, nullptr) {}\n  TensorProcessorPipe(const TensorTuple& inputs, TensorTuple* outputs)\n      : inputs_(&inputs), outputs_(outputs), index_(0) {}\n\n  template<typename TP, typename... Args>\n  Maybe<void> Apply(Args&&... args) {\n    CHECK_LT_OR_RETURN(index_, static_cast<int>(TPSize))\n        << Error::RuntimeError() << \"The tensor processor pipe can only be applied up to \"\n        << static_cast<int>(TPSize) << \" times\";\n    processors_[index_].New<TP>(*inputs_, outputs_, std::forward<Args>(args)...);\n    auto* processor = processors_[index_].As<TP>();\n    JUST(processor->Apply());\n    inputs_ = &(processor->inputs());\n    outputs_ = processor->outputs();\n    ++index_;\n    return Maybe<void>::Ok();\n  }\n\n  const TensorTuple& inputs() const { return *inputs_; }\n\n  TensorTuple* outputs() const { return outputs_; }\n\n private:\n  const TensorTuple* inputs_;\n  TensorTuple* outputs_;\n  int index_;\n  TensorProcessorStorage processors_[TPSize];\n};\n\n}  // namespace functional\n}  // namespace one\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FUNCTIONAL_TENSOR_PROCESSOR_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing/b21_sub_task_graph_builder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/boxing/b21_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_util.h\"\n\nnamespace oneflow {\n\nMaybe<SubTskGphBuilderStatus> B21SubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  if ((in_parallel_desc.parallel_num() == 1 || in_sbp_parallel.has_broadcast_parallel())\n      && out_parallel_desc.parallel_num() == 1) {\n    const int64_t out_parallel_id = 0;\n    const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId(\n        in_parallel_desc, out_parallel_desc, out_parallel_id);\n    sorted_ctrl_tasks->resize(1);\n    FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {\n      TaskNode* in_node = sorted_in_tasks.at(i);\n      if (i == nearest_in_parallel_id) {\n        TaskNode* proxy =\n            ctx->task_graph()->GetProxyNode(in_node, lbi, out_parallel_desc, out_parallel_id);\n        sorted_out_tasks->emplace_back(proxy);\n      } else {\n        sorted_ctrl_tasks->at(0).emplace_back(in_node);\n      }\n    }\n    return TRY(BuildSubTskGphBuilderStatus(\"B21SubTskGphBuilder\", \"\"));\n  } else {\n    return Error::BoxingNotSupportedError();\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing/b21_sub_task_graph_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_B21_SUB_TASK_GRAPH_BUILDER_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_B21_SUB_TASK_GRAPH_BUILDER_H_\n\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder.h\"\n\nnamespace oneflow {\n\nclass B21SubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(B21SubTskGphBuilder);\n  B21SubTskGphBuilder() = default;\n  ~B21SubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_B21_SUB_TASK_GRAPH_BUILDER_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing/boxing_logger.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/boxing/boxing_logger.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n#define OF_BOXING_LOGGER_CSV_COLNUM_NAME_FIELD                   \\\n  \"src_op_name,dst_op_name,src_parallel_desc,dst_parallel_desc,\" \\\n  \"src_nd_sbp,\"                                                  \\\n  \"dst_nd_sbp,lbi,dtype,shape,builder,comment\\n\"\n\nstd::string ShapeToString(const Shape& shape) {\n  std::stringstream shape_ss;\n  auto dim_vec = shape.dim_vec();\n  shape_ss << \"(\";\n  for (int32_t i = 0; i < dim_vec.size(); ++i) {\n    shape_ss << dim_vec.at(i);\n    if (i != dim_vec.size() - 1) { shape_ss << \" \"; }\n  }\n  shape_ss << \")\";\n  return shape_ss.str();\n}\n\nstd::string ParallelDescToString(const ParallelDesc& parallel_desc) {\n  std::string serialized_parallel_desc;\n  std::string device_type;\n  device_type = *CHECK_JUST(DeviceTag4DeviceType(parallel_desc.device_type()));\n  auto sorted_machine_ids = parallel_desc.sorted_machine_ids();\n  serialized_parallel_desc += \"{\";\n  for (int64_t i = 0; i < sorted_machine_ids.size(); ++i) {\n    const int64_t machine_id = sorted_machine_ids.at(i);\n    serialized_parallel_desc += std::to_string(machine_id) + \":\" + device_type + \":\";\n    int64_t min_id = parallel_desc.sorted_dev_phy_ids(machine_id).front();\n    int64_t max_id = parallel_desc.sorted_dev_phy_ids(machine_id).back();\n    serialized_parallel_desc += std::to_string(min_id) + \"-\" + std::to_string(max_id);\n    serialized_parallel_desc += \" \";\n  }\n  serialized_parallel_desc += ShapeToString(*parallel_desc.hierarchy());\n  serialized_parallel_desc += \"}\";\n  return serialized_parallel_desc;\n}\n\nstd::string NdSbpToCsvString(const NdSbp& nd_sbp) {\n  std::ostringstream ss;\n  ss << \"(\";\n  for (size_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) {\n    if (i > 0) { ss << \" \"; }\n    ss << SbpToString(nd_sbp.sbp_parallel(i));\n  }\n  ss << \")\";\n  return ss.str();\n}\n\nstd::string MakeBoxingLoggerCsvRow(const SubTskGphBuilderStatus& status,\n                                   const std::string& src_op_name, const std::string& dst_op_name,\n                                   const ParallelDesc& src_parallel_desc,\n                                   const ParallelDesc& dst_parallel_desc, const NdSbp& src_nd_sbp,\n                                   const NdSbp& dst_nd_sbp, const LogicalBlobId& lbi,\n                                   const BlobDesc& logical_blob_desc) {\n  std::string serialized_status;\n  serialized_status += src_op_name + \",\";\n  serialized_status += dst_op_name + \",\";\n  serialized_status += ParallelDescToString(src_parallel_desc) + \",\";\n  serialized_status += ParallelDescToString(dst_parallel_desc) + \",\";\n  serialized_status += NdSbpToCsvString(src_nd_sbp) + \",\";\n  serialized_status += 
NdSbpToCsvString(dst_nd_sbp) + \",\";\n  serialized_status += GenLogicalBlobName(lbi) + \",\";\n  serialized_status += DataType_Name(logical_blob_desc.data_type()) + \",\";\n  serialized_status += ShapeToString(logical_blob_desc.shape()) + \",\";\n  serialized_status += status.builder_name() + \",\";\n  if (status.comment().empty()) {\n    serialized_status += \"-\";\n  } else {\n    serialized_status += status.comment();\n  }\n  serialized_status += \"\\n\";\n  return serialized_status;\n}\n\n}  // namespace\n\nCsvBoxingLogger::CsvBoxingLogger(std::string path) {\n  log_stream_ = TeePersistentLogStream::Create(path);\n  log_stream_ << OF_BOXING_LOGGER_CSV_COLNUM_NAME_FIELD;\n}\n\nCsvBoxingLogger::~CsvBoxingLogger() { log_stream_->Flush(); }\n\nvoid CsvBoxingLogger::Log(const SubTskGphBuilderStatus& status, const std::string& src_op_name,\n                          const std::string& dst_op_name, const ParallelDesc& src_parallel_desc,\n                          const ParallelDesc& dst_parallel_desc, const NdSbp& src_nd_sbp,\n                          const NdSbp& dst_nd_sbp, const LogicalBlobId& lbi,\n                          const BlobDesc& logical_blob_desc) {\n  log_stream_ << MakeBoxingLoggerCsvRow(status, src_op_name, dst_op_name, src_parallel_desc,\n                                        dst_parallel_desc, src_nd_sbp, dst_nd_sbp, lbi,\n                                        logical_blob_desc);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing/boxing_logger.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_LOGGER_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_LOGGER_H_\n\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h\"\n\nnamespace oneflow {\n\nclass BoxingLogger {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BoxingLogger);\n  BoxingLogger() = default;\n  virtual ~BoxingLogger() = default;\n\n  virtual void Log(const SubTskGphBuilderStatus& status, const std::string& src_op_name,\n                   const std::string& dst_op_name, const ParallelDesc& src_parallel_desc,\n                   const ParallelDesc& dst_parallel_desc, const NdSbp& src_nd_sbp,\n                   const NdSbp& dst_nd_sbp, const LogicalBlobId& lbi,\n                   const BlobDesc& logical_blob_desc) = 0;\n};\n\nclass NullBoxingLogger final : public BoxingLogger {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NullBoxingLogger);\n  NullBoxingLogger() = default;\n  ~NullBoxingLogger() override = default;\n\n  void Log(const SubTskGphBuilderStatus& status, const std::string& src_op_name,\n           const std::string& dst_op_name, const ParallelDesc& src_parallel_desc,\n           const ParallelDesc& dst_parallel_desc, const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp,\n           const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc) override{};\n};\n\nclass CsvBoxingLogger final : public BoxingLogger {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CsvBoxingLogger);\n  CsvBoxingLogger() = delete;\n  CsvBoxingLogger(std::string path);\n  ~CsvBoxingLogger() override;\n\n  void Log(const SubTskGphBuilderStatus& status, const std::string& src_op_name,\n           const std::string& dst_op_name, const ParallelDesc& src_parallel_desc,\n           const ParallelDesc& dst_parallel_desc, const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp,\n           const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc) override;\n\n private:\n  std::unique_ptr<TeePersistentLogStream> log_stream_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_LOGGER_H_\n"
  },
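  {
    "path": "oneflow/core/graph/boxing/boxing_logger_example.cpp",
    "content": "// NOTE: illustrative sketch added for this review, not a file from the OneFlow source tree.\n// It shows one plausible way a caller could choose between the two BoxingLogger\n// implementations above; the factory name CreateBoxingLogger, the enable_debug_log flag, and\n// the CSV path are assumptions, not OneFlow APIs.\n#include <memory>\n#include <string>\n\n#include \"oneflow/core/graph/boxing/boxing_logger.h\"\n\nnamespace oneflow {\n\n// Hypothetical factory: log every boxing decision to CSV while debugging, stay silent otherwise.\nstd::unique_ptr<BoxingLogger> CreateBoxingLogger(bool enable_debug_log, const std::string& path) {\n  if (enable_debug_log) {\n    // CsvBoxingLogger writes the column header on construction and flushes on destruction.\n    return std::make_unique<CsvBoxingLogger>(path);\n  }\n  // NullBoxingLogger::Log is a no-op, so call sites can log unconditionally.\n  return std::make_unique<NullBoxingLogger>();\n}\n\n}  // namespace oneflow\n"
  },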
  {
    "path": "oneflow/core/graph/boxing/ccl_sub_task_graph_builder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/boxing/ccl_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_util.h\"\n#include \"oneflow/core/graph/collective_boxing_task_node.h\"\n#include \"oneflow/core/graph/collective_boxing_pack_task_node.h\"\n#include \"oneflow/core/graph/collective_boxing_unpack_task_node.h\"\n#include \"oneflow/core/graph/slice_boxing_task_node.h\"\n#include \"oneflow/core/graph/task_stream_id.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n\nnamespace oneflow {\n\nusing namespace boxing::collective;\n\nnamespace {\n\nvoid CclInitCollectiveNode(CollectiveBoxingGenericTaskNode* node, const ParallelDesc& parallel_desc,\n                           int64_t parallel_id, const std::string& name, const LogicalBlobId& lbi,\n                           const BlobDesc& logical_blob_desc, OpType op_type,\n                           DeviceType device_type, int64_t root) {\n  OperatorConf op_conf;\n  op_conf.set_name(name);\n  op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(device_type)));\n  CollectiveBoxingGenericOpConf* conf = op_conf.mutable_collective_boxing_generic_conf();\n  *conf->mutable_lbi() = lbi;\n  RankDesc* rank_desc = conf->mutable_rank_desc();\n  OpDesc* op_desc = rank_desc->mutable_op_desc();\n  op_desc->set_name(name);\n  op_desc->set_op_type(op_type);\n  if (op_type == OpType::kOpTypeAllReduce || op_type == OpType::kOpTypeReduceScatter\n      || op_type == OpType::kOpTypeReduce) {\n    op_desc->set_reduce_method(ReduceMethod::kReduceMethodSum);\n  }\n  op_desc->set_data_type(logical_blob_desc.data_type());\n  logical_blob_desc.shape().ToProto(op_desc->mutable_shape());\n  op_desc->set_num_ranks(parallel_desc.parallel_num());\n  if (op_type == OpType::kOpTypeBroadcast || op_type == OpType::kOpTypeReduce) {\n    CHECK_GE(root, 0);\n    CHECK_LT(root, parallel_desc.parallel_num());\n    op_desc->set_root(root);\n  } else {\n    CHECK_EQ(root, -1);\n  }\n  op_desc->set_device_type(device_type);\n  rank_desc->set_rank(parallel_id);\n\n  const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));\n  const int64_t device_index = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));\n  const int64_t thrd_id = EncodeStreamIdToInt64(GenerateNamedTaskStreamId(\n      machine_id, device_type, device_index, *CHECK_JUST(DeviceTag4DeviceType(device_type))));\n  node->Init(machine_id, thrd_id, lbi, op_conf);\n}\n\nint64_t FindRootParallelId(const ParallelDesc& multi_device, const ParallelDesc& sole_device) {\n  CHECK_EQ(sole_device.parallel_num(), 1);\n  const int64_t root_machine_id = CHECK_JUST(sole_device.MachineId4ParallelId(0));\n  const int64_t root_device_id = CHECK_JUST(sole_device.DeviceId4ParallelId(0));\n  int64_t root_parallel_id = -1;\n  FOR_RANGE(int64_t, i, 0, multi_device.parallel_num()) {\n    if (CHECK_JUST(multi_device.MachineId4ParallelId(i)) == root_machine_id\n        && 
CHECK_JUST(multi_device.DeviceId4ParallelId(i)) == root_device_id) {\n      root_parallel_id = i;\n      break;\n    }\n  }\n  return root_parallel_id;\n}\n\n}  // namespace\n\nbool IsSourceTimeShape(const Shape& shape) { return shape.elem_cnt() == 1; }\n\nMaybe<SubTskGphBuilderStatus> CclAllReduceSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  if (out_parallel_desc.Equals(in_parallel_desc)\n      && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc)\n      && out_parallel_desc.device_type() == device_type_ && out_parallel_desc.parallel_num() > 1\n      && SubTskGphBuilderUtil::IsBoxingP2B(in_sbp_parallel, out_sbp_parallel)) {\n    const std::string op_name = \"System-Boxing-CclBoxingAllReduce-\" + NewUniqueId();\n    FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {\n      TaskNode* in_node = sorted_in_tasks.at(i);  // NOLINT\n      auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();\n      CclInitCollectiveNode(collective_node, in_parallel_desc, i, op_name, lbi, logical_blob_desc,\n                            OpType::kOpTypeAllReduce, device_type_, -1);\n      ctx->task_graph()->ConnectWithLbi(in_node, collective_node, lbi);\n      sorted_out_tasks->emplace_back(collective_node);\n    }\n    return TRY(BuildSubTskGphBuilderStatus(\"CclBoxingAllReduceSubTskGphBuilder\", \"\"));\n  } else {\n    return Error::BoxingNotSupportedError();\n  }\n}\n\nMaybe<SubTskGphBuilderStatus> CclReduceScatterSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  if (out_parallel_desc.Equals(in_parallel_desc)\n      && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc)\n      && out_parallel_desc.device_type() == device_type_ && out_parallel_desc.parallel_num() > 1\n      && logical_blob_desc.shape().NumAxes() > 0\n      && logical_blob_desc.shape().At(0) % out_parallel_desc.parallel_num() == 0\n      && SubTskGphBuilderUtil::IsBoxingP2S(in_sbp_parallel, out_sbp_parallel)\n      && out_sbp_parallel.split_parallel().axis() == 0) {\n    const std::string op_name = \"System-Boxing-CclBoxingReduceScatter-\" + NewUniqueId();\n    FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {\n      TaskNode* in_node = sorted_in_tasks.at(i);  // NOLINT\n      auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();\n      CclInitCollectiveNode(collective_node, in_parallel_desc, i, op_name, lbi, logical_blob_desc,\n                            OpType::kOpTypeReduceScatter, device_type_, -1);\n      ctx->task_graph()->ConnectWithLbi(in_node, collective_node, lbi);\n      sorted_out_tasks->emplace_back(collective_node);\n    }\n    return 
TRY(BuildSubTskGphBuilderStatus(\"CclBoxingReduceScatterSubTskGphBuilder\", \"\"));\n  } else {\n    return Error::BoxingNotSupportedError();\n  }\n}\n\nMaybe<SubTskGphBuilderStatus> CclP2SNoncontinuousSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  const Shape& shape = logical_blob_desc.shape();\n  const int64_t out_split_axis = out_sbp_parallel.split_parallel().axis();\n  if (out_parallel_desc.Equals(in_parallel_desc)\n      && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc)\n      && out_parallel_desc.device_type() == device_type_ && out_parallel_desc.parallel_num() > 1\n      && SubTskGphBuilderUtil::IsBoxingP2S(in_sbp_parallel, out_sbp_parallel)\n      && shape.NumAxes() > out_split_axis\n      && shape.At(out_split_axis) % out_parallel_desc.parallel_num() == 0\n      && out_sbp_parallel.split_parallel().axis() != 0) {\n    const std::string op_name = \"System-Boxing-CclBoxingP2SNoncontinuous-\" + NewUniqueId();\n    FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {\n      const int64_t machine_id = CHECK_JUST(in_parallel_desc.MachineId4ParallelId(i));\n      const int64_t device_index = CHECK_JUST(in_parallel_desc.DeviceId4ParallelId(i));\n      const int64_t thrd_id = EncodeStreamIdToInt64(\n          GenerateComputeTaskStreamId(machine_id, device_type_, device_index));\n      TaskNode* in_node = sorted_in_tasks.at(i);  // NOLINT\n      CollectiveBoxingPackTaskNode* pack_node =\n          ctx->task_graph()->NewNode<CollectiveBoxingPackTaskNode>();\n      pack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel,\n                      out_sbp_parallel, in_parallel_desc.parallel_num());\n      ctx->task_graph()->ConnectWithLbi(in_node, pack_node, lbi);\n\n      auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();\n      CclInitCollectiveNode(\n          collective_node, in_parallel_desc, i, op_name, lbi,\n          BlobDesc({logical_blob_desc.shape().elem_cnt()}, logical_blob_desc.data_type(),\n                   logical_blob_desc.memory_format()),\n          OpType::kOpTypeReduceScatter, device_type_, -1);\n      ctx->task_graph()->ConnectWithLbi(pack_node, collective_node, lbi);\n\n      CollectiveBoxingUnpackTaskNode* unpack_node =\n          ctx->task_graph()->NewNode<CollectiveBoxingUnpackTaskNode>();\n      unpack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel,\n                        out_sbp_parallel, in_parallel_desc.parallel_num());\n      ctx->task_graph()->ConnectWithLbi(collective_node, unpack_node, lbi);\n      sorted_out_tasks->emplace_back(unpack_node);\n    }\n    return TRY(BuildSubTskGphBuilderStatus(\"CclBoxingP2SNoncontinuousSubTskGphBuilder\", \"\"));\n  } else {\n    return Error::BoxingNotSupportedError();\n  }\n}\n\nMaybe<SubTskGphBuilderStatus> CclAllGatherSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& 
out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  if (out_parallel_desc.EqualsIgnoringDeviceType(in_parallel_desc)\n      && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc)\n      && SubTskGphBuilderUtil::IsDeviceTypeCPUOr(in_parallel_desc, device_type_)\n      && out_parallel_desc.device_type() == device_type_ && out_parallel_desc.parallel_num() > 1\n      && logical_blob_desc.shape().NumAxes() > 0\n      && logical_blob_desc.shape().At(0) % out_parallel_desc.parallel_num() == 0\n      && SubTskGphBuilderUtil::IsBoxingS2B(in_sbp_parallel, out_sbp_parallel)\n      && in_sbp_parallel.split_parallel().axis() == 0) {\n    const std::string op_name = \"System-Boxing-CclBoxingAllGather-\" + NewUniqueId();\n    FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {\n      TaskNode* in_node = sorted_in_tasks.at(i);  // NOLINT\n      TaskNode* in_node_proxy = ctx->task_graph()->GetProxyNode(in_node, lbi, out_parallel_desc, i);\n      auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();\n      CclInitCollectiveNode(collective_node, out_parallel_desc, i, op_name, lbi, logical_blob_desc,\n                            OpType::kOpTypeAllGather, device_type_, -1);\n      ctx->task_graph()->ConnectWithLbi(in_node_proxy, collective_node, lbi);\n      sorted_out_tasks->emplace_back(collective_node);\n    }\n    return TRY(BuildSubTskGphBuilderStatus(\"CclBoxingAllGatherSubTskGphBuilder\", \"\"));\n  } else {\n    return Error::BoxingNotSupportedError();\n  }\n}\n\nMaybe<SubTskGphBuilderStatus> CclS2BNoncontinuousSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  const Shape& shape = logical_blob_desc.shape();\n  const int64_t in_split_axis = in_sbp_parallel.split_parallel().axis();\n  if (out_parallel_desc.EqualsIgnoringDeviceType(in_parallel_desc)\n      && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc)\n      && SubTskGphBuilderUtil::IsDeviceTypeCPUOr(in_parallel_desc, device_type_)\n      && out_parallel_desc.device_type() == device_type_ && out_parallel_desc.parallel_num() > 1\n      && SubTskGphBuilderUtil::IsBoxingS2B(in_sbp_parallel, out_sbp_parallel)\n      && shape.NumAxes() > in_split_axis && in_split_axis > 0\n      && shape.At(in_split_axis) % out_parallel_desc.parallel_num() == 0) {\n    const std::string op_name = \"System-Boxing-CclBoxingS2BNoncontinuous-\" + NewUniqueId();\n    FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {\n      const int64_t machine_id = CHECK_JUST(out_parallel_desc.MachineId4ParallelId(i));\n      const int64_t device_index = CHECK_JUST(out_parallel_desc.DeviceId4ParallelId(i));\n      const int64_t thrd_id = EncodeStreamIdToInt64(\n          GenerateComputeTaskStreamId(machine_id, device_type_, device_index));\n      TaskNode* in_node = sorted_in_tasks.at(i);  // NOLINT\n      TaskNode* in_node_proxy = ctx->task_graph()->GetProxyNode(in_node, lbi, out_parallel_desc, i);\n      CollectiveBoxingPackTaskNode* pack_node =\n          
ctx->task_graph()->NewNode<CollectiveBoxingPackTaskNode>();\n      pack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel,\n                      out_sbp_parallel, in_parallel_desc.parallel_num());\n      ctx->task_graph()->ConnectWithLbi(in_node_proxy, pack_node, lbi);\n      auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();\n      CclInitCollectiveNode(\n          collective_node, out_parallel_desc, i, op_name, lbi,\n          BlobDesc({logical_blob_desc.shape().elem_cnt()}, logical_blob_desc.data_type(),\n                   logical_blob_desc.memory_format()),\n          OpType::kOpTypeAllGather, device_type_, -1);\n      ctx->task_graph()->ConnectWithLbi(pack_node, collective_node, lbi);\n      CollectiveBoxingUnpackTaskNode* unpack_node =\n          ctx->task_graph()->NewNode<CollectiveBoxingUnpackTaskNode>();\n      unpack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel,\n                        out_sbp_parallel, in_parallel_desc.parallel_num());\n      ctx->task_graph()->ConnectWithLbi(collective_node, unpack_node, lbi);\n      sorted_out_tasks->emplace_back(unpack_node);\n    }\n    return TRY(BuildSubTskGphBuilderStatus(\"CclBoxingS2BNoncontinuousSubTskGphBuilder\", \"\"));\n  } else {\n    return Error::BoxingNotSupportedError();\n  }\n}\n\nMaybe<SubTskGphBuilderStatus> CclReduceSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  if (in_parallel_desc.parallel_num() > 1 && out_parallel_desc.parallel_num() == 1\n      && in_parallel_desc.device_type() == device_type_\n      && out_parallel_desc.device_type() == device_type_\n      && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc)\n      && in_sbp_parallel.has_partial_sum_parallel()) {\n    const int64_t root_parallel_id = FindRootParallelId(in_parallel_desc, out_parallel_desc);\n    if (root_parallel_id == -1) { return Error::BoxingNotSupportedError(); }\n\n    const std::string op_name = \"System-Boxing-CclBoxingReduce-\" + NewUniqueId();\n    sorted_ctrl_tasks->resize(out_parallel_desc.parallel_num());\n    FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {\n      TaskNode* in_node = sorted_in_tasks.at(i);  // NOLINT\n      auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();\n      CclInitCollectiveNode(collective_node, in_parallel_desc, i, op_name, lbi, logical_blob_desc,\n                            OpType::kOpTypeReduce, device_type_, root_parallel_id);\n      ctx->task_graph()->ConnectWithLbi(in_node, collective_node, lbi);\n      if (i == root_parallel_id) {\n        sorted_out_tasks->emplace_back(collective_node);\n      } else {\n        sorted_ctrl_tasks->at(0).emplace_back(collective_node);  // NOLINT\n      }\n    }\n    return TRY(BuildSubTskGphBuilderStatus(\"CclBoxingReduceSubTskGphBuilder\", \"\"));\n  } else {\n    return Error::BoxingNotSupportedError();\n  }\n}\n\nMaybe<SubTskGphBuilderStatus> CclScatterThenAllGatherSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* 
sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  if (in_parallel_desc.parallel_num() == 1 && out_parallel_desc.parallel_num() > 1\n      && in_parallel_desc.device_type() == DeviceType::kCPU\n      && out_parallel_desc.device_type() == device_type_\n      && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc)\n      && logical_blob_desc.shape().elem_cnt() >= 1024\n      && out_sbp_parallel.has_broadcast_parallel()\n      // a potential optimization: flatten the blob and then relax this requirement\n      && logical_blob_desc.shape().NumAxes() > 0\n      && logical_blob_desc.shape().At(0) % out_parallel_desc.parallel_num() == 0) {\n    const TensorSliceView in_slice = GetBroadcastTensorSliceView(logical_blob_desc);\n    SbpParallel split_sbp_parallel;\n    split_sbp_parallel.mutable_split_parallel()->set_axis(0);\n    std::vector<TensorSliceView> out_slices =\n        GetTensorSliceView(out_parallel_desc.parallel_num(), split_sbp_parallel, logical_blob_desc);\n    const std::string op_name = \"System-Boxing-CclBoxingAllGather-\" + NewUniqueId();\n    FOR_RANGE(int64_t, out_id, 0, out_parallel_desc.parallel_num()) {\n      const TensorSliceView& out_slice = out_slices.at(out_id);  // NOLINT\n      const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId(\n          in_parallel_desc, out_parallel_desc, out_id);\n\n      TaskNode* in_node = sorted_in_tasks.at(nearest_in_parallel_id);  // NOLINT\n      SliceBoxingTaskNode* slice_node = ctx->task_graph()->NewNode<SliceBoxingTaskNode>();\n      // slice on cpu\n      const auto in_machine_id = CHECK_JUST(in_parallel_desc.MachineId4ParallelId(0));\n      int64_t thrd_id =\n          EncodeStreamIdToInt64(GenerateComputeTaskStreamId(in_machine_id, DeviceType::kCPU, 0));\n      slice_node->Init(lbi, out_slice, kSliceBoxingTaskModeCopy, in_machine_id, thrd_id);\n      slice_node->ConnectToSrcNodeWithSlice(in_node, ctx->task_graph()->NewEdge(), in_slice);\n      // copy to dst gpu\n      TaskNode* slice_node_proxy =\n          ctx->task_graph()->GetProxyNode(slice_node, lbi, out_parallel_desc, out_id);\n      // allgather\n      auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();\n      CclInitCollectiveNode(collective_node, out_parallel_desc, out_id, op_name, lbi,\n                            logical_blob_desc, OpType::kOpTypeAllGather, device_type_, -1);\n      ctx->task_graph()->ConnectWithLbi(slice_node_proxy, collective_node, lbi);\n      sorted_out_tasks->emplace_back(collective_node);\n    }\n    return TRY(BuildSubTskGphBuilderStatus(\"CclBoxingScatterThenAllGatherSubTskGphBuilder\", \"\"));\n  } else {\n    return Error::BoxingNotSupportedError();\n  }\n}\n\nMaybe<SubTskGphBuilderStatus> CclBroadcastSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  if 
(in_parallel_desc.parallel_num() == 1 && out_parallel_desc.parallel_num() > 1\n      && (in_parallel_desc.device_type() == device_type_\n          || (in_parallel_desc.device_type() == DeviceType::kCPU\n              && logical_blob_desc.shape().elem_cnt() >= 1024))\n      && out_parallel_desc.device_type() == device_type_\n      && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc)\n      && out_sbp_parallel.has_broadcast_parallel()) {\n    TaskNode* gpu_in_node = nullptr;\n    int64_t root_parallel_id = -1;\n    if (in_parallel_desc.device_type() == DeviceType::kCPU) {\n      auto* cpu_in_node = sorted_in_tasks.front();\n      root_parallel_id =\n          SubTskGphBuilderUtil::FindNearestSrcParallelId(out_parallel_desc, in_parallel_desc, 0);\n      gpu_in_node =\n          ctx->task_graph()->GetProxyNode(cpu_in_node, lbi, out_parallel_desc, root_parallel_id);\n\n    } else if (in_parallel_desc.device_type() == device_type_) {\n      root_parallel_id = FindRootParallelId(out_parallel_desc, in_parallel_desc);\n      gpu_in_node = sorted_in_tasks.front();\n    } else {\n      return Error::BoxingNotSupportedError();\n    }\n    if (root_parallel_id == -1) { return Error::BoxingNotSupportedError(); }\n\n    const std::string op_name = \"System-Boxing-CclBoxingBroadcast-\" + NewUniqueId();\n    FOR_RANGE(int64_t, i, 0, out_parallel_desc.parallel_num()) {\n      auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();\n      CclInitCollectiveNode(collective_node, out_parallel_desc, i, op_name, lbi, logical_blob_desc,\n                            OpType::kOpTypeBroadcast, device_type_, root_parallel_id);\n      if (i == root_parallel_id) {\n        ctx->task_graph()->ConnectWithLbi(gpu_in_node, collective_node, lbi);\n      } else {\n        std::string regst_desc_name;\n        gpu_in_node->BuildCtrlRegstDesc(collective_node, &regst_desc_name);\n        TaskEdge* edge = ctx->task_graph()->NewEdge();\n        Connect<TaskNode>(gpu_in_node, edge, collective_node);\n        gpu_in_node->BindEdgeWithProducedRegst(edge, regst_desc_name);\n      }\n      sorted_out_tasks->emplace_back(collective_node);\n    }\n    return TRY(BuildSubTskGphBuilderStatus(\"CclBoxingBroadcastSubTskGphBuilder\", \"\"));\n  } else {\n    return Error::BoxingNotSupportedError();\n  }\n}\n\nMaybe<SubTskGphBuilderStatus> CclAll2AllSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  const Shape& shape = logical_blob_desc.shape();\n  const int64_t in_split_axis = in_sbp_parallel.split_parallel().axis();\n  const int64_t out_split_axis = out_sbp_parallel.split_parallel().axis();\n  if (out_parallel_desc.EqualsIgnoringDeviceType(in_parallel_desc)\n      && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc)\n      && in_parallel_desc.device_type() == device_type_\n      && out_parallel_desc.device_type() == device_type_ && out_parallel_desc.parallel_num() > 1\n      && shape.NumAxes() > std::max(in_split_axis, out_split_axis)\n      && shape.At(in_split_axis) % in_parallel_desc.parallel_num() == 0\n      && shape.At(out_split_axis) % 
out_parallel_desc.parallel_num() == 0\n      && in_sbp_parallel.split_parallel().axis() != out_sbp_parallel.split_parallel().axis()\n      && SubTskGphBuilderUtil::IsBoxingS2S(in_sbp_parallel, out_sbp_parallel)) {\n    const std::string op_name = \"System-Boxing-CclBoxingAll2All-\" + NewUniqueId();\n    FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {\n      const int64_t machine_id = CHECK_JUST(in_parallel_desc.MachineId4ParallelId(i));\n      const int64_t device_index = CHECK_JUST(in_parallel_desc.DeviceId4ParallelId(i));\n      const int64_t thrd_id = EncodeStreamIdToInt64(\n          GenerateComputeTaskStreamId(machine_id, device_type_, device_index));\n      TaskNode* in_node = sorted_in_tasks.at(i);  // NOLINT\n      CollectiveBoxingPackTaskNode* pack_node =\n          ctx->task_graph()->NewNode<CollectiveBoxingPackTaskNode>();\n      pack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel,\n                      out_sbp_parallel, in_parallel_desc.parallel_num());\n      ctx->task_graph()->ConnectWithLbi(in_node, pack_node, lbi);\n\n      auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();\n      CclInitCollectiveNode(collective_node, out_parallel_desc, i, op_name, lbi, logical_blob_desc,\n                            OpType::kOpTypeAll2All, device_type_, -1);\n      ctx->task_graph()->ConnectWithLbi(pack_node, collective_node, lbi);\n\n      CollectiveBoxingUnpackTaskNode* unpack_node =\n          ctx->task_graph()->NewNode<CollectiveBoxingUnpackTaskNode>();\n      unpack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel,\n                        out_sbp_parallel, in_parallel_desc.parallel_num());\n      ctx->task_graph()->ConnectWithLbi(collective_node, unpack_node, lbi);\n      sorted_out_tasks->emplace_back(unpack_node);\n    }\n    return TRY(BuildSubTskGphBuilderStatus(\"CclBoxingAll2AllSubTskGphBuilder\", \"\"));\n  } else {\n    return Error::BoxingNotSupportedError();\n  }\n}\n\n}  // namespace oneflow\n"
  },
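  {
    "path": "oneflow/core/graph/boxing/noncontinuous_split_example.cpp",
    "content": "// NOTE: illustrative, self-contained sketch added for this review, not a file from the\n// OneFlow source tree. It only mirrors the shape arithmetic behind the P2S/S2B\n// \"noncontinuous\" builders above: a split on axis != 0 is not contiguous in row-major\n// memory, so those builders wrap the collective in pack/unpack task nodes and run it on a\n// flattened (elem_cnt,) blob instead of the logical shape. All names below are local to\n// this sketch.\n#include <cassert>\n#include <cstdint>\n#include <iostream>\n#include <vector>\n\nint64_t ElemCnt(const std::vector<int64_t>& shape) {\n  int64_t cnt = 1;\n  for (int64_t dim : shape) { cnt *= dim; }\n  return cnt;\n}\n\nint main() {\n  const std::vector<int64_t> shape = {4, 6};  // logical blob shape\n  const int64_t num_ranks = 3;\n  const int64_t out_split_axis = 1;  // noncontinuous case: split axis != 0\n  // Precondition checked by the builders: the split axis must divide evenly across ranks.\n  assert(shape.at(out_split_axis) % num_ranks == 0);\n  // The collective itself sees a flat blob split evenly on dim 0, exactly like the\n  // BlobDesc({elem_cnt}) passed to CclInitCollectiveNode above.\n  const int64_t flat_elems = ElemCnt(shape);              // 24\n  const int64_t per_rank_elems = flat_elems / num_ranks;  // 8 elements per rank\n  std::cout << \"collective blob: (\" << flat_elems << \",), per-rank slice: (\" << per_rank_elems\n            << \",)\" << std::endl;\n  return 0;\n}\n"
  },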
  {
    "path": "oneflow/core/graph/boxing/ccl_sub_task_graph_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_CCL_SUB_TASK_GRAPH_BUILDER_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_CCL_SUB_TASK_GRAPH_BUILDER_H_\n\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder.h\"\n\nnamespace oneflow {\n\nbool IsSourceTimeShape(const Shape& shape);\n\nclass CclAllReduceSubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CclAllReduceSubTskGphBuilder);\n  CclAllReduceSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {}\n  ~CclAllReduceSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n\n private:\n  DeviceType device_type_;\n};\n\nclass CclReduceScatterSubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CclReduceScatterSubTskGphBuilder);\n  CclReduceScatterSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {}\n  ~CclReduceScatterSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n\n private:\n  DeviceType device_type_;\n};\n\nclass CclP2SNoncontinuousSubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CclP2SNoncontinuousSubTskGphBuilder);\n  CclP2SNoncontinuousSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {}\n  ~CclP2SNoncontinuousSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n\n private:\n  DeviceType device_type_;\n};\n\nclass CclAllGatherSubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  
OF_DISALLOW_COPY_AND_MOVE(CclAllGatherSubTskGphBuilder);\n  CclAllGatherSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {}\n  ~CclAllGatherSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n\n private:\n  DeviceType device_type_;\n};\n\nclass CclS2BNoncontinuousSubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CclS2BNoncontinuousSubTskGphBuilder);\n  CclS2BNoncontinuousSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {}\n  ~CclS2BNoncontinuousSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n\n private:\n  DeviceType device_type_;\n};\n\nclass CclReduceSubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CclReduceSubTskGphBuilder);\n  CclReduceSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {}\n  ~CclReduceSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n\n private:\n  DeviceType device_type_;\n};\n\nclass CclScatterThenAllGatherSubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CclScatterThenAllGatherSubTskGphBuilder);\n  CclScatterThenAllGatherSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {}\n  ~CclScatterThenAllGatherSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n\n private:\n  DeviceType device_type_;\n};\n\nclass CclBroadcastSubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CclBroadcastSubTskGphBuilder);\n  CclBroadcastSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {}\n  
~CclBroadcastSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n\n private:\n  DeviceType device_type_;\n};\n\nclass CclAll2AllSubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CclAll2AllSubTskGphBuilder);\n  CclAll2AllSubTskGphBuilder(DeviceType device_type) : device_type_(device_type) {}\n  ~CclAll2AllSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n\n private:\n  DeviceType device_type_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_CCL_SUB_TASK_GRAPH_BUILDER_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing/chain_sub_task_graph_builder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/boxing/chain_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_util.h\"\n\nnamespace oneflow {\n\nMaybe<SubTskGphBuilderStatus> ChainSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  for (const auto& builder : builders_) {\n    Maybe<SubTskGphBuilderStatus> boxing_builder_status = TRY(builder->Build(\n        ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc,\n        out_parallel_desc, lbi, logical_blob_desc, in_sbp_parallel, out_sbp_parallel, time_shape));\n    if (!boxing_builder_status.IsOk()\n        && SubTskGphBuilderUtil::IsErrorBoxingNotSupported(*boxing_builder_status.error())) {\n      continue;\n    } else {\n      return boxing_builder_status;\n    }\n  }\n  return Error::BoxingNotSupportedError();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing/chain_sub_task_graph_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_CHAIN_SUB_TASK_GRAPH_BUILDER_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_CHAIN_SUB_TASK_GRAPH_BUILDER_H_\n\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder.h\"\n\nnamespace oneflow {\n\nclass ChainSubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ChainSubTskGphBuilder);\n  explicit ChainSubTskGphBuilder(std::vector<std::shared_ptr<SubTskGphBuilder>> builders)\n      : builders_(std::move(builders)) {}\n  ~ChainSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n\n private:\n  std::vector<std::shared_ptr<SubTskGphBuilder>> builders_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_CHAIN_SUB_TASK_GRAPH_BUILDER_H_\n"
  },
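  {
    "path": "oneflow/core/graph/boxing/chain_fallthrough_example.cpp",
    "content": "// NOTE: illustrative, self-contained sketch added for this review, not a file from the\n// OneFlow source tree. It mirrors the fall-through contract of ChainSubTskGphBuilder::Build:\n// each sub-builder either handles the boxing case or reports \"not supported\", and the chain\n// returns the first real answer (success or a hard error). All names below are local to this\n// sketch.\n#include <functional>\n#include <iostream>\n#include <string>\n#include <vector>\n\nenum class Status { kOk, kBoxingNotSupported, kOtherError };\n\nstruct Result {\n  Status status;\n  std::string builder_name;\n};\n\nint main() {\n  // Stand-ins for the SubTskGphBuilder::Build calls made by the chain.\n  const std::vector<std::function<Result()>> builders = {\n      [] { return Result{Status::kBoxingNotSupported, \"AllReduce\"}; },\n      [] { return Result{Status::kOk, \"ReduceScatter\"}; },\n      [] { return Result{Status::kOtherError, \"never reached\"}; },\n  };\n  for (const auto& build : builders) {\n    const Result result = build();\n    // Only \"boxing not supported\" lets the chain try the next builder; any other\n    // status (success or failure) is returned to the caller as-is.\n    if (result.status == Status::kBoxingNotSupported) { continue; }\n    std::cout << \"handled by: \" << result.builder_name << std::endl;\n    return result.status == Status::kOk ? 0 : 1;\n  }\n  std::cout << \"boxing not supported by any builder\" << std::endl;\n  return 1;\n}\n"
  },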
  {
    "path": "oneflow/core/graph/boxing/collective_boxing.proto",
    "content": "syntax = \"proto2\";\npackage oneflow.boxing.collective;\n\nimport \"oneflow/core/common/shape.proto\";\nimport \"oneflow/core/common/data_type.proto\";\nimport \"oneflow/core/common/device_type.proto\";\n\nenum OpType {\n    kOpTypeInvalid = 0;\n    kOpTypeAllReduce = 1;\n    kOpTypeReduceScatter = 2;\n    kOpTypeAllGather = 3;\n    kOpTypeReduce = 4;\n    kOpTypeBroadcast = 5;\n    kOpTypeAll2All = 6;\n}\n\nenum ReduceMethod {\n    kReduceMethodInvalid = 0;\n    kReduceMethodSum = 1;\n}\n\nmessage DeviceDesc {\n    required int64 machine_id = 1;\n    required DeviceType device_type = 2;\n    required int64 device_id = 3;\n}\n\nmessage DeviceSet {\n    repeated DeviceDesc device = 1;\n}\n\nmessage OpDesc {\n    required string name = 1;\n    required OpType op_type = 2;\n    optional ReduceMethod reduce_method = 3;\n    optional int64 root = 4;\n    required DataType data_type = 5;\n    required ShapeProto shape = 6;\n    required int64 num_ranks = 7;\n    required DeviceType device_type = 8;\n}\n\nmessage RequestDesc {\n    required OpDesc op_desc = 1;\n    required DeviceSet device_set = 2;\n    required int64 order = 3;\n    required int64 dependency_depth = 4;\n}\n\nmessage RequestSet {\n    repeated RequestDesc request = 1;\n}\n\nmessage RankDesc {\n    required OpDesc op_desc = 1;\n    required int64 rank = 2;\n}\n"
  },
  {
    "path": "oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/boxing/chain_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/ccl_sub_task_graph_builder.h\"\n\nnamespace oneflow {\n\nCollectiveBoxingSubTskGphBuilder::CollectiveBoxingSubTskGphBuilder() {\n  const CollectiveBoxingConf collective_boxing_conf =\n      Singleton<ResourceDesc, ForSession>::Get()->collective_boxing_conf();\n  std::vector<std::shared_ptr<SubTskGphBuilder>> builders;\n  builders.emplace_back(new CclAllReduceSubTskGphBuilder(DeviceType::kCUDA));\n  builders.emplace_back(new CclReduceScatterSubTskGphBuilder(DeviceType::kCUDA));\n  builders.emplace_back(new CclP2SNoncontinuousSubTskGphBuilder(DeviceType::kCUDA));\n  builders.emplace_back(new CclAllGatherSubTskGphBuilder(DeviceType::kCUDA));\n  builders.emplace_back(new CclS2BNoncontinuousSubTskGphBuilder(DeviceType::kCUDA));\n  builders.emplace_back(new CclReduceSubTskGphBuilder(DeviceType::kCUDA));\n  builders.emplace_back(new CclScatterThenAllGatherSubTskGphBuilder(DeviceType::kCUDA));\n  builders.emplace_back(new CclBroadcastSubTskGphBuilder(DeviceType::kCUDA));\n\n  if (collective_boxing_conf.nccl_enable_all_to_all()) {\n#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700\n    builders.emplace_back(new CclAll2AllSubTskGphBuilder(DeviceType::kCUDA));\n#elif defined(WITH_NPU)\n    builders.emplace_back(new CclAll2AllSubTskGphBuilder(DeviceType::kNPU));\n#elif defined(WITH_MLU)\n    builders.emplace_back(new CclAll2AllSubTskGphBuilder(DeviceType::kMLU));\n#else\n    LOG(WARNING) << \"nccl_enable_all_to_all is unavailable unless NCCL_VERSION > 2.7.0\";\n#endif\n  }\n  chain_builder_.reset(new ChainSubTskGphBuilder(builders));\n}\n\nMaybe<SubTskGphBuilderStatus> CollectiveBoxingSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  if (!GlobalJobDesc().Bool(\"__is_user_function__\")) { return Error::BoxingNotSupportedError(); }\n  if (!IsSourceTimeShape(time_shape)) { return Error::BoxingNotSupportedError(); }\n  return chain_builder_->Build(ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks,\n                               in_parallel_desc, out_parallel_desc, lbi, logical_blob_desc,\n                               in_sbp_parallel, out_sbp_parallel, time_shape);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_COLLECTIVE_BOXING_SUB_TASK_GRAPH_BUILDER_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_COLLECTIVE_BOXING_SUB_TASK_GRAPH_BUILDER_H_\n\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder.h\"\n\nnamespace oneflow {\n\nclass CollectiveBoxingSubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingSubTskGphBuilder);\n  CollectiveBoxingSubTskGphBuilder();\n  ~CollectiveBoxingSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n\n private:\n  std::unique_ptr<SubTskGphBuilder> chain_builder_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_COLLECTIVE_BOXING_SUB_TASK_GRAPH_BUILDER_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing/collective_boxing_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/boxing/collective_boxing_util.h\"\n\nnamespace oneflow {\n\nnamespace boxing {\n\nnamespace collective {\n\nnamespace {\n\nShape GetSplitShape(const RankDesc& rank_desc) {\n  Shape shape(rank_desc.op_desc().shape());\n  CHECK_GT(shape.NumAxes(), 0);\n  CHECK(shape.At(0) % rank_desc.op_desc().num_ranks() == 0);\n  shape.Set(0, shape.At(0) / rank_desc.op_desc().num_ranks());\n  return shape;\n}\n\nShape GetFlattenSplitShape(const RankDesc& rank_desc) {\n  Shape shape(rank_desc.op_desc().shape());\n  CHECK_GT(shape.NumAxes(), 0);\n  CHECK(shape.elem_cnt() % rank_desc.op_desc().num_ranks() == 0);\n  Shape return_shape({shape.elem_cnt() / rank_desc.op_desc().num_ranks()});\n  return return_shape;\n}\n\n}  // namespace\n\nbool GenericOpHasInput(const RankDesc& rank_desc) {\n  const OpType op_type = rank_desc.op_desc().op_type();\n  if (op_type == OpType::kOpTypeAllReduce || op_type == OpType::kOpTypeAllGather\n      || op_type == OpType::kOpTypeReduceScatter || op_type == OpType::kOpTypeReduce\n      || op_type == OpType::kOpTypeAll2All) {\n    return true;\n  } else if (op_type == OpType::kOpTypeBroadcast) {\n    CHECK(rank_desc.op_desc().has_root());\n    return rank_desc.rank() == rank_desc.op_desc().root();\n  } else {\n    UNIMPLEMENTED();\n    return false;\n  }\n}\n\nbool GenericOpHasOutput(const RankDesc& rank_desc) {\n  const OpType op_type = rank_desc.op_desc().op_type();\n  if (op_type == OpType::kOpTypeAllReduce || op_type == OpType::kOpTypeAllGather\n      || op_type == OpType::kOpTypeReduceScatter || op_type == OpType::kOpTypeBroadcast\n      || op_type == OpType::kOpTypeAll2All) {\n    return true;\n  } else if (op_type == OpType::kOpTypeReduce) {\n    CHECK(rank_desc.op_desc().has_root());\n    return rank_desc.rank() == rank_desc.op_desc().root();\n  } else {\n    UNIMPLEMENTED();\n    return false;\n  }\n}\n\nShape GenericOpGetInputShape(const RankDesc& rank_desc) {\n  CHECK(GenericOpHasInput(rank_desc));\n  const OpType op_type = rank_desc.op_desc().op_type();\n  if (op_type == OpType::kOpTypeAllReduce || op_type == OpType::kOpTypeReduceScatter\n      || op_type == OpType::kOpTypeReduce || op_type == OpType::kOpTypeBroadcast) {\n    return Shape(rank_desc.op_desc().shape());\n  } else if (op_type == OpType::kOpTypeAllGather) {\n    return GetSplitShape(rank_desc);\n  } else if (op_type == OpType::kOpTypeAll2All) {\n    return GetFlattenSplitShape(rank_desc);\n  } else {\n    UNIMPLEMENTED();\n    return Shape();\n  }\n}\n\nShape GenericOpGetOutputShape(const RankDesc& rank_desc) {\n  CHECK(GenericOpHasOutput(rank_desc));\n  const OpType op_type = rank_desc.op_desc().op_type();\n  if (op_type == OpType::kOpTypeAllReduce || op_type == OpType::kOpTypeAllGather\n      || op_type == OpType::kOpTypeReduce || op_type == OpType::kOpTypeBroadcast) {\n    return Shape(rank_desc.op_desc().shape());\n  } else if (op_type == 
OpType::kOpTypeReduceScatter) {\n    return GetSplitShape(rank_desc);\n  } else if (op_type == OpType::kOpTypeAll2All) {\n    return GetFlattenSplitShape(rank_desc);\n  } else {\n    UNIMPLEMENTED();\n    return Shape();\n  }\n}\n\n}  // namespace collective\n\n}  // namespace boxing\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing/collective_boxing_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_UTIL_H_\n#define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_UTIL_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/graph/boxing/collective_boxing.pb.h\"\n#include \"oneflow/core/common/shape.h\"\n\nnamespace oneflow {\n\nnamespace boxing {\n\nnamespace collective {\n\ninline bool operator==(const OpDesc& lhs, const OpDesc& rhs) { return PbMd::Equals(lhs, rhs); }\n\ninline bool operator==(const DeviceDesc& lhs, const DeviceDesc& rhs) {\n  return PbMd::Equals(lhs, rhs);\n}\n\ninline bool operator==(const DeviceSet& lhs, const DeviceSet& rhs) {\n  return PbMd::Equals(lhs, rhs);\n}\n\ninline bool operator!=(const DeviceSet& lhs, const DeviceSet& rhs) { return !(lhs == rhs); }\n\nbool GenericOpHasInput(const RankDesc& rank_desc);\n\nbool GenericOpHasOutput(const RankDesc& rank_desc);\n\nShape GenericOpGetInputShape(const RankDesc& rank_desc);\n\nShape GenericOpGetOutputShape(const RankDesc& rank_desc);\n\n}  // namespace collective\n\n}  // namespace boxing\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::boxing::collective::DeviceDesc> {\n  size_t operator()(const oneflow::boxing::collective::DeviceDesc& device_desc) const {\n    size_t hash = std::hash<int64_t>()(device_desc.machine_id());\n    oneflow::HashCombine(&hash, std::hash<int64_t>()(device_desc.device_type()));\n    oneflow::HashCombine(&hash, std::hash<int64_t>()(device_desc.device_id()));\n    return hash;\n  }\n};\n\ntemplate<>\nstruct hash<oneflow::boxing::collective::DeviceSet> {\n  size_t operator()(const oneflow::boxing::collective::DeviceSet& device_set) const {\n    size_t hash = 0;\n    for (const auto& device : device_set.device()) { oneflow::AddHash(&hash, device); }\n    return hash;\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing/fallback_to_cpu_slice_boxing_sub_task_graph_builder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/boxing/fallback_to_cpu_slice_boxing_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_util.h\"\n\nnamespace oneflow {\n\nMaybe<SubTskGphBuilderStatus> FallbackToCpuSliceBoxingSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  std::vector<SubTskGphBuilderStatus> status;\n\n  std::vector<TaskNode*> cpu_in_tasks;\n  std::vector<TaskNode*> cpu_out_tasks;\n  std::vector<std::vector<TaskNode*>> cpu_ctrl_tasks;\n  cpu_out_tasks.reserve(out_parallel_desc.parallel_num());\n\n  FOR_RANGE(int64_t, in_id, 0, in_parallel_desc.parallel_num()) {\n    TaskNode* in_node = sorted_in_tasks.at(in_id);\n    TaskNode* proxy_on_src_host = ctx->task_graph()->GetProxyNode(\n        in_node, lbi, GetNodeCPUMemZoneId(in_node->MemZoneId121().rank()));\n    cpu_in_tasks.push_back(proxy_on_src_host);\n  }\n  status.emplace_back(\"MoveToCpu\", \"-\");\n\n  ParallelConf cpu_in_parallel_conf = in_parallel_desc.parallel_conf();\n  cpu_in_parallel_conf.set_device_tag(\"cpu\");\n  ParallelConf cpu_out_parallel_conf = out_parallel_desc.parallel_conf();\n  cpu_out_parallel_conf.set_device_tag(\"cpu\");\n  Maybe<SubTskGphBuilderStatus> boxing_builder_status =\n      TRY(builder_->Build(ctx, cpu_in_tasks, &cpu_out_tasks, &cpu_ctrl_tasks,\n                          ParallelDesc(cpu_in_parallel_conf), ParallelDesc(cpu_out_parallel_conf),\n                          lbi, logical_blob_desc, in_sbp_parallel, out_sbp_parallel, time_shape));\n  if (!boxing_builder_status.IsOk()\n      && SubTskGphBuilderUtil::IsErrorBoxingNotSupported(*boxing_builder_status.error())) {\n    return Error::BoxingNotSupportedError();\n  }\n  status.push_back(*JUST(boxing_builder_status));\n\n  FOR_RANGE(int64_t, out_id, 0, out_parallel_desc.parallel_num()) {\n    TaskNode* out_node =\n        ctx->task_graph()->GetProxyNode(cpu_out_tasks.at(out_id), lbi, out_parallel_desc, out_id);\n    sorted_out_tasks->push_back(out_node);\n  }\n  status.emplace_back(\"MoveBackToDevice\", \"-\");\n  if (!cpu_ctrl_tasks.empty()) {\n    CHECK_EQ(cpu_ctrl_tasks.size(), sorted_out_tasks->size());\n    FOR_RANGE(size_t, i, 0, sorted_out_tasks->size()) {\n      for (TaskNode* ctrl_node : cpu_ctrl_tasks.at(i)) {\n        std::string regst_desc_name;\n        ctrl_node->BuildCtrlRegstDesc(sorted_out_tasks->at(i), &regst_desc_name);\n        TaskEdge* edge = ctx->task_graph()->NewEdge();\n        Connect<TaskNode>(ctrl_node, edge, sorted_out_tasks->at(i));\n        
ctrl_node->BindEdgeWithProducedRegst(edge, regst_desc_name);\n      }\n    }\n  }\n\n  return TRY(MakeComposedSubTskGphBuilderStatus(status));\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing/fallback_to_cpu_slice_boxing_sub_task_graph_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_FALLBACK_TO_CPU_SLICE_BOXING_SUB_TASK_GRAPH_BUILDER_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_FALLBACK_TO_CPU_SLICE_BOXING_SUB_TASK_GRAPH_BUILDER_H_\n\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h\"\n\nnamespace oneflow {\n\nclass FallbackToCpuSliceBoxingSubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(FallbackToCpuSliceBoxingSubTskGphBuilder);\n  FallbackToCpuSliceBoxingSubTskGphBuilder() { builder_.reset(new SliceBoxingSubTskGphBuilder()); }\n  ~FallbackToCpuSliceBoxingSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n\n private:\n  std::unique_ptr<SliceBoxingSubTskGphBuilder> builder_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_FALLBACK_TO_CPU_SLICE_BOXING_SUB_TASK_GRAPH_BUILDER_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_context.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h\"\n\nnamespace oneflow {\n\nclass HierarchicalSubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(HierarchicalSubTskGphBuilder);\n  HierarchicalSubTskGphBuilder() = default;\n  virtual ~HierarchicalSubTskGphBuilder() = default;\n\n  virtual Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,\n      const Shape& time_shape) const = 0;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h\"\n#include \"oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_util.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/chain_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/fallback_to_cpu_slice_boxing_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/b21_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_util.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/graph/nccl_send_recv_boxing_task_node.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/graph/task_stream_id.h\"\n#include \"oneflow/core/job/job_desc.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::shared_ptr<ChainSubTskGphBuilder> Make1DSubTskGphBuilder() {\n  std::vector<std::shared_ptr<SubTskGphBuilder>> builders;\n  builders.emplace_back(new OneToOneSubTskGphBuilder());\n  builders.emplace_back(new B21SubTskGphBuilder());\n  if (!Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()) {\n    builders.emplace_back(new CollectiveBoxingSubTskGphBuilder());\n  }\n  builders.emplace_back(new SliceBoxingSubTskGphBuilder());\n  builders.emplace_back(new FallbackToCpuSliceBoxingSubTskGphBuilder());\n  builders.emplace_back(new NaiveB2BSubTskGphBuilder());\n  builders.emplace_back(new NaiveB2PSubTskGphBuilder());\n  return std::make_shared<ChainSubTskGphBuilder>(builders);\n}\n\nvoid MergeParallelConf(const ParallelDesc& parallel_desc_0, const ParallelDesc& parallel_desc_1,\n                       ParallelConf* parallel_conf) {\n  CHECK_EQ(parallel_desc_0.device_tag(), parallel_desc_1.device_tag());\n  std::set<std::pair<int64_t, int64_t>> machine_device_ids;\n  for (int64_t machine_id : parallel_desc_0.sorted_machine_ids()) {\n    for (int64_t device_id : parallel_desc_0.sorted_dev_phy_ids(machine_id)) {\n      machine_device_ids.insert(std::make_pair(machine_id, device_id));\n    }\n  }\n  for (int64_t machine_id : parallel_desc_1.sorted_machine_ids()) {\n    for (int64_t device_id : parallel_desc_1.sorted_dev_phy_ids(machine_id)) {\n      machine_device_ids.insert(std::make_pair(machine_id, device_id));\n    }\n  }\n  parallel_conf->set_device_tag(parallel_desc_0.device_tag());\n  for (const auto& pair : machine_device_ids) {\n    parallel_conf->add_device_name(\"@\" + std::to_string(pair.first) + \":\"\n        
                           + std::to_string(pair.second));\n  }\n}\n\ninline std::string NewUniqueIdGbc() {\n  // The boxing task graph is built on rank 0 and broadcasted to all the ranks,\n  // so the ids here are unique among all the ranks.\n  static std::atomic<int64_t> counter(0);\n  static std::atomic<int64_t> curr_job_id(0);\n  if (curr_job_id != GlobalJobDesc().job_id()) {\n    curr_job_id = GlobalJobDesc().job_id();\n    counter = 0;\n  }\n  return std::to_string(counter.fetch_add(1, std::memory_order_relaxed));\n}\n\nclass NDNcclSendRecvBoxingSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NDNcclSendRecvBoxingSubTskGphBuilder);\n  NDNcclSendRecvBoxingSubTskGphBuilder() {}\n  ~NDNcclSendRecvBoxingSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,\n                                      const std::vector<TaskNode*>& sorted_in_tasks,\n                                      std::vector<TaskNode*>* sorted_out_tasks,\n                                      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks,\n                                      const ParallelDesc& in_parallel_desc,\n                                      const ParallelDesc& out_parallel_desc,\n                                      const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,\n                                      const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,\n                                      const Shape& time_shape) const override {\n    if (in_parallel_desc.device_type() == out_parallel_desc.device_type()\n        && in_parallel_desc.device_type() != DeviceType::kCPU\n        && !NdSbpHasPartialParallel(out_nd_sbp)) {\n#if (defined(WITH_CUDA) && (NCCL_VERSION_CODE > 2700)) || defined(WITH_NPU) || defined(WITH_MLU)\n      ParallelConf merged_parallel_conf;\n      MergeParallelConf(in_parallel_desc.parallel_conf(), out_parallel_desc.parallel_conf(),\n                        &merged_parallel_conf);\n      ParallelDesc merged_parallel_desc(merged_parallel_conf);\n      TaskNode* first_in_node = sorted_in_tasks.front();\n      sorted_ctrl_tasks->resize(out_parallel_desc.parallel_num());\n      std::string stream_name = \"NCCL_SEND_RECV_BOXING\" + NewUniqueIdGbc();\n      FOR_RANGE(int64_t, id, 0, merged_parallel_desc.parallel_num()) {\n        NcclSendRecvBoxingTaskNode* node = ctx->task_graph()->NewNode<NcclSendRecvBoxingTaskNode>();\n        const int64_t machine_id = JUST(merged_parallel_desc.MachineId4ParallelId(id));\n        int64_t device_index = JUST(merged_parallel_desc.DeviceId4ParallelId(id));\n        int64_t thrd_id = EncodeStreamIdToInt64(GenerateNamedTaskStreamId(\n            machine_id, merged_parallel_desc.device_type(), device_index, stream_name));\n        bool has_input = in_parallel_desc.Containing(machine_id, device_index);\n        bool has_output = out_parallel_desc.Containing(machine_id, device_index);\n        node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(),\n                   logical_blob_desc.data_type(), in_nd_sbp, out_nd_sbp, in_parallel_desc,\n                   out_parallel_desc, id, merged_parallel_desc, has_input, has_output, stream_name);\n        if (has_input) {\n          int64_t in_id =\n              JUST(in_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index));\n          ctx->task_graph()->ConnectWithLbi(sorted_in_tasks.at(in_id), node, lbi);\n        } else {\n          // TODO: find nearest\n          std::string 
regst_desc_name;\n          first_in_node->BuildCtrlRegstDesc(node, &regst_desc_name);\n          TaskEdge* edge = ctx->task_graph()->NewEdge();\n          Connect<TaskNode>(first_in_node, edge, node);\n          first_in_node->BindEdgeWithProducedRegst(edge, regst_desc_name);\n        }\n        if (has_output) { sorted_out_tasks->push_back(node); }\n      }\n      return BuildSubTskGphBuilderStatus(\"NDNcclSendRecvBoxingSubTskGphBuilder\", \"\");\n#else\n      return Error::BoxingNotSupportedError() << \"No Device or low NCCL version\";\n#endif\n    } else {\n      return Error::BoxingNotSupportedError()\n             << \"Partial SBP in the consumer or not running on CUDA\";\n    }\n  }\n};\n\nclass Dim0NdSbpMismatchedSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Dim0NdSbpMismatchedSubTskGphBuilder);\n  Dim0NdSbpMismatchedSubTskGphBuilder() {\n    inter_group_sub_tsk_gph_builder_.reset(\n        new InterGroupSubTskGphBuilder(Make1DSubTskGphBuilder()));\n  }\n  ~Dim0NdSbpMismatchedSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,\n                                      const std::vector<TaskNode*>& sorted_in_tasks,\n                                      std::vector<TaskNode*>* sorted_out_tasks,\n                                      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks,\n                                      const ParallelDesc& in_parallel_desc,\n                                      const ParallelDesc& out_parallel_desc,\n                                      const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,\n                                      const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,\n                                      const Shape& time_shape) const override {\n    if (in_parallel_desc.hierarchy()->NumAxes() == 2\n        && (*in_parallel_desc.hierarchy() == *out_parallel_desc.hierarchy())\n        && in_nd_sbp.sbp_parallel(0) != out_nd_sbp.sbp_parallel(0)\n        && in_nd_sbp.sbp_parallel(1) == out_nd_sbp.sbp_parallel(1)\n        && !(NdSbpAllSameSplitParallel(in_nd_sbp) || NdSbpAllSameSplitParallel(out_nd_sbp))) {\n      return inter_group_sub_tsk_gph_builder_->Build(\n          ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc,\n          out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape);\n    } else {\n      return nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build(\n          ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc,\n          out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape);\n    }\n  }\n\n private:\n  std::unique_ptr<InterGroupSubTskGphBuilder> inter_group_sub_tsk_gph_builder_;\n  std::unique_ptr<NDNcclSendRecvBoxingSubTskGphBuilder>\n      nd_nccl_send_recv_boxing_sub_tsk_gph_builder_;\n};\n\nclass Same2DHierarchySubTskGphBuilder final : public HierarchicalSubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Same2DHierarchySubTskGphBuilder);\n  Same2DHierarchySubTskGphBuilder() {\n    intra_group_sub_tsk_gph_builder_.reset(\n        new IntraGroupSubTskGphBuilder(Make1DSubTskGphBuilder()));\n    dim0_nd_sbp_mismatched_sub_tsk_gph_builder_.reset(new Dim0NdSbpMismatchedSubTskGphBuilder());\n  }\n  ~Same2DHierarchySubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,\n                                      const std::vector<TaskNode*>& sorted_in_tasks,\n  
                                    std::vector<TaskNode*>* sorted_out_tasks,\n                                      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks,\n                                      const ParallelDesc& in_parallel_desc,\n                                      const ParallelDesc& out_parallel_desc,\n                                      const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,\n                                      const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,\n                                      const Shape& time_shape) const override {\n    if (in_parallel_desc.hierarchy()->NumAxes() == 2\n        && (*in_parallel_desc.hierarchy() == *out_parallel_desc.hierarchy())) {\n      if (in_nd_sbp.sbp_parallel(0) == out_nd_sbp.sbp_parallel(0)) {\n        return intra_group_sub_tsk_gph_builder_->Build(\n            ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc,\n            out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape);\n      } else {\n        return dim0_nd_sbp_mismatched_sub_tsk_gph_builder_->Build(\n            ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc,\n            out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape);\n      }\n    } else {\n      return Error::BoxingNotSupportedError();\n    }\n  }\n\n private:\n  std::unique_ptr<IntraGroupSubTskGphBuilder> intra_group_sub_tsk_gph_builder_;\n  std::unique_ptr<Dim0NdSbpMismatchedSubTskGphBuilder> dim0_nd_sbp_mismatched_sub_tsk_gph_builder_;\n};\n\n}  // namespace\n\nstruct DispatchHierarchicalSubTskGphBuilder::Impl {\n  Impl();\n  std::unique_ptr<FlatSubTskGphBuilder> flat_sub_tsk_gph_builder_;\n  std::unique_ptr<Same2DHierarchySubTskGphBuilder> same_2d_hierarchy_sub_tsk_gph_builder_;\n  std::unique_ptr<NDNcclSendRecvBoxingSubTskGphBuilder>\n      nd_nccl_send_recv_boxing_sub_tsk_gph_builder_;\n};\n\nDispatchHierarchicalSubTskGphBuilder::Impl::Impl() {\n  flat_sub_tsk_gph_builder_.reset(new FlatSubTskGphBuilder(Make1DSubTskGphBuilder()));\n  same_2d_hierarchy_sub_tsk_gph_builder_.reset(new Same2DHierarchySubTskGphBuilder());\n  nd_nccl_send_recv_boxing_sub_tsk_gph_builder_.reset(new NDNcclSendRecvBoxingSubTskGphBuilder());\n}\n\nDispatchHierarchicalSubTskGphBuilder::DispatchHierarchicalSubTskGphBuilder() {\n  impl_.reset(new Impl());\n}\n\nDispatchHierarchicalSubTskGphBuilder::~DispatchHierarchicalSubTskGphBuilder() = default;\n\nMaybe<SubTskGphBuilderStatus> DispatchHierarchicalSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,\n    const Shape& time_shape) const {\n  ParallelDesc reduced_in_parallel_desc = in_parallel_desc;\n  ParallelDesc reduced_out_parallel_desc = out_parallel_desc;\n  NdSbp reduced_in_nd_sbp;\n  NdSbp reduced_out_nd_sbp;\n  // The 1d to 2d and 2d to 1d cases are considered in this function.\n  // If it gives out a 1d sbp and a 2d sbp simultaneously, it means the 2d sbp can not be\n  // converted to a 1d sbp and the 1d sbp can not be expanded to a 2d sbp.\n  InOutParallelDimReduce(in_parallel_desc, out_parallel_desc, in_nd_sbp, out_nd_sbp,\n                         &reduced_in_parallel_desc, 
&reduced_out_parallel_desc, &reduced_in_nd_sbp,\n                         &reduced_out_nd_sbp, logical_blob_desc.shape());\n  const auto& in_hierarchy = reduced_in_parallel_desc.hierarchy();\n  const auto& out_hierarchy = reduced_out_parallel_desc.hierarchy();\n  if ((in_hierarchy->NumAxes() > 2 || out_hierarchy->NumAxes() > 2)\n      && reduced_in_parallel_desc.device_type() == reduced_out_parallel_desc.device_type()\n      && reduced_in_parallel_desc.device_type() != DeviceType::kCPU) {\n    return impl_->nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build(\n        ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc,\n        reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp,\n        time_shape);\n  }\n  if (in_hierarchy->NumAxes() <= 2 && out_hierarchy->NumAxes() <= 2) {\n    if (in_hierarchy->NumAxes() == 1 && out_hierarchy->NumAxes() == 1) {\n      return impl_->flat_sub_tsk_gph_builder_->Build(\n          ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc,\n          reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp,\n          time_shape);\n    } else if ((in_hierarchy->NumAxes() == 2) && (*in_hierarchy == *out_hierarchy)) {\n      return impl_->same_2d_hierarchy_sub_tsk_gph_builder_->Build(\n          ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc,\n          reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp,\n          time_shape);\n    } else if (reduced_in_parallel_desc.device_type() != DeviceType::kCPU\n               && reduced_out_parallel_desc.device_type() != DeviceType::kCPU) {\n      return impl_->nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build(\n          ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc,\n          reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp,\n          time_shape);\n    } else {\n      return Error::BoxingNotSupportedError();\n    }\n  }\n  return Error::BoxingNotSupportedError();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_IMPL_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_IMPL_H_\n\n#include \"oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder.h\"\n\nnamespace oneflow {\n\nclass DispatchHierarchicalSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DispatchHierarchicalSubTskGphBuilder);\n  DispatchHierarchicalSubTskGphBuilder();\n  ~DispatchHierarchicalSubTskGphBuilder() override;\n\n  Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,\n                                      const std::vector<TaskNode*>& sorted_in_tasks,\n                                      std::vector<TaskNode*>* sorted_out_tasks,\n                                      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks,\n                                      const ParallelDesc& in_parallel_desc,\n                                      const ParallelDesc& out_parallel_desc,\n                                      const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,\n                                      const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,\n                                      const Shape& time_shape) const override;\n\n private:\n  struct Impl;\n  std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_IMPL_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_util.h\"\n\nnamespace oneflow {\n\nMaybe<SubTskGphBuilderStatus> FlatSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,\n    const Shape& time_shape) const {\n  if (in_parallel_desc.hierarchy()->NumAxes() == 1\n      && out_parallel_desc.hierarchy()->NumAxes() == 1) {\n    return sub_tsk_gph_builder_->Build(ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks,\n                                       in_parallel_desc, out_parallel_desc, lbi, logical_blob_desc,\n                                       in_nd_sbp.sbp_parallel(0), out_nd_sbp.sbp_parallel(0),\n                                       time_shape);\n  } else {\n    return Error::BoxingNotSupportedError();\n  }\n}\n\nMaybe<SubTskGphBuilderStatus> IntraGroupSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,\n    const Shape& time_shape) const {\n  if (*in_parallel_desc.hierarchy() == *out_parallel_desc.hierarchy()\n      && in_parallel_desc.hierarchy()->NumAxes() == 2\n      && in_nd_sbp.sbp_parallel(0) == out_nd_sbp.sbp_parallel(0)\n      && in_nd_sbp.sbp_parallel(1) != out_nd_sbp.sbp_parallel(1)) {\n    const auto& hierarchy = in_parallel_desc.hierarchy();\n    std::vector<SubTskGphBuilderStatus> status;\n    const int64_t num_groups = hierarchy->At(0);\n    const int64_t group_size = hierarchy->At(1);\n    status.reserve(num_groups);\n    sorted_ctrl_tasks->resize(out_parallel_desc.parallel_num());\n    sorted_out_tasks->resize(out_parallel_desc.parallel_num());\n    FOR_RANGE(int64_t, i, 0, num_groups) {\n      std::vector<TaskNode*> in_tasks;\n      std::vector<TaskNode*> out_tasks;\n      std::vector<std::vector<TaskNode*>> ctrl_tasks;\n      ParallelConf in_parallel_conf;\n      in_parallel_conf.set_device_tag(in_parallel_desc.device_tag());\n      in_parallel_conf.mutable_hierarchy()->add_dim(group_size);\n      ParallelConf out_parallel_conf;\n      out_parallel_conf.set_device_tag(out_parallel_desc.device_tag());\n      out_parallel_conf.mutable_hierarchy()->add_dim(group_size);\n      FOR_RANGE(int64_t, j, 0, group_size) {\n        const int64_t parallel_id = i * group_size + j;\n        
in_tasks.emplace_back(sorted_in_tasks.at(parallel_id));  // NOLINT\n        in_parallel_conf.add_device_name(\n            \"@\" + std::to_string(JUST(in_parallel_desc.MachineId4ParallelId(parallel_id))) + \":\"\n            + std::to_string(JUST(in_parallel_desc.DeviceId4ParallelId(parallel_id))));\n        out_parallel_conf.add_device_name(\n            \"@\" + std::to_string(JUST(out_parallel_desc.MachineId4ParallelId(parallel_id))) + \":\"\n            + std::to_string(JUST(out_parallel_desc.DeviceId4ParallelId(parallel_id))));\n      }\n      DimVector dim_vec = logical_blob_desc.shape().dim_vec();\n      if (in_nd_sbp.sbp_parallel(0).has_split_parallel()) {\n        const int64_t axis = in_nd_sbp.sbp_parallel(0).split_parallel().axis();\n        dim_vec.at(axis) /= hierarchy->At(0);\n      }\n      BlobDesc new_blob_desc(Shape(dim_vec), logical_blob_desc.data_type(),\n                             logical_blob_desc.memory_format());\n      std::shared_ptr<SubTskGphBuilderStatus> boxing_builder_status =\n          JUST(sub_tsk_gph_builder_->Build(\n              ctx, in_tasks, &out_tasks, &ctrl_tasks, ParallelDesc(in_parallel_conf),\n              ParallelDesc(out_parallel_conf), lbi, new_blob_desc, in_nd_sbp.sbp_parallel(1),\n              out_nd_sbp.sbp_parallel(1), time_shape));\n      status.emplace_back(*boxing_builder_status);\n      CHECK_EQ_OR_RETURN(out_tasks.size(), group_size);  // NOLINT\n      FOR_RANGE(int64_t, j, 0, group_size) {\n        const int64_t parallel_id = i * group_size + j;\n        sorted_out_tasks->at(parallel_id) = out_tasks.at(j);  // NOLINT\n        if (!ctrl_tasks.empty()) {\n          for (TaskNode* ctrl_node : ctrl_tasks.at(j)) {                 // NOLINT\n            sorted_ctrl_tasks->at(parallel_id).emplace_back(ctrl_node);  // NOLINT\n          }\n        }\n      }\n    }\n    return MakeComposedSubTskGphBuilderStatus(status);\n  } else {\n    return Error::BoxingNotSupportedError();\n  }\n}\n\nMaybe<SubTskGphBuilderStatus> InterGroupSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,\n    const Shape& time_shape) const {\n  if (*in_parallel_desc.hierarchy() == *out_parallel_desc.hierarchy()\n      && in_parallel_desc.hierarchy()->NumAxes() == 2\n      && in_nd_sbp.sbp_parallel(1) == out_nd_sbp.sbp_parallel(1)\n      && in_nd_sbp.sbp_parallel(0) != out_nd_sbp.sbp_parallel(0)\n      && !NdSbpAllSameSplitParallel(in_nd_sbp) && !NdSbpAllSameSplitParallel(out_nd_sbp)) {\n    const auto& hierarchy = in_parallel_desc.hierarchy();\n    std::vector<SubTskGphBuilderStatus> status;\n    const int64_t num_groups = hierarchy->At(0);\n    const int64_t group_size = hierarchy->At(1);\n    status.reserve(group_size);\n    sorted_ctrl_tasks->resize(out_parallel_desc.parallel_num());\n    sorted_out_tasks->resize(out_parallel_desc.parallel_num());\n    FOR_RANGE(int64_t, i, 0, group_size) {\n      std::vector<TaskNode*> in_tasks;\n      std::vector<TaskNode*> out_tasks;\n      std::vector<std::vector<TaskNode*>> ctrl_tasks;\n      ParallelConf in_parallel_conf;\n      in_parallel_conf.set_device_tag(in_parallel_desc.device_tag());\n      in_parallel_conf.mutable_hierarchy()->add_dim(num_groups);\n      
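// Take the i-th device of every group: a flat hierarchy of num_groups devices.\n      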
ParallelConf out_parallel_conf;\n      out_parallel_conf.set_device_tag(out_parallel_desc.device_tag());\n      out_parallel_conf.mutable_hierarchy()->add_dim(num_groups);\n      FOR_RANGE(int64_t, j, 0, num_groups) {\n        const int64_t parallel_id = j * group_size + i;\n        in_tasks.emplace_back(sorted_in_tasks.at(parallel_id));  // NOLINT\n        in_parallel_conf.add_device_name(\n            \"@\" + std::to_string(JUST(in_parallel_desc.MachineId4ParallelId(parallel_id))) + \":\"\n            + std::to_string(JUST(in_parallel_desc.DeviceId4ParallelId(parallel_id))));\n        out_parallel_conf.add_device_name(\n            \"@\" + std::to_string(JUST(out_parallel_desc.MachineId4ParallelId(parallel_id))) + \":\"\n            + std::to_string(JUST(out_parallel_desc.DeviceId4ParallelId(parallel_id))));\n      }\n      DimVector dim_vec = logical_blob_desc.shape().dim_vec();\n      if (in_nd_sbp.sbp_parallel(1).has_split_parallel()) {\n        const int64_t axis = in_nd_sbp.sbp_parallel(1).split_parallel().axis();\n        dim_vec.at(axis) /= hierarchy->At(1);\n      }\n      BlobDesc new_blob_desc(Shape(dim_vec), logical_blob_desc.data_type(),\n                             logical_blob_desc.memory_format());\n      std::shared_ptr<SubTskGphBuilderStatus> boxing_builder_status =\n          JUST(sub_tsk_gph_builder_->Build(\n              ctx, in_tasks, &out_tasks, &ctrl_tasks, ParallelDesc(in_parallel_conf),\n              ParallelDesc(out_parallel_conf), lbi, new_blob_desc, in_nd_sbp.sbp_parallel(0),\n              out_nd_sbp.sbp_parallel(0), time_shape));\n      status.emplace_back(*boxing_builder_status);\n      CHECK_EQ_OR_RETURN(out_tasks.size(), num_groups);  // NOLINT\n      FOR_RANGE(int64_t, j, 0, num_groups) {\n        const int64_t parallel_id = j * group_size + i;\n        sorted_out_tasks->at(parallel_id) = out_tasks.at(j);  // NOLINT\n        if (!ctrl_tasks.empty()) {\n          for (TaskNode* ctrl_node : ctrl_tasks.at(j)) {                 // NOLINT\n            sorted_ctrl_tasks->at(parallel_id).emplace_back(ctrl_node);  // NOLINT\n          }\n        }\n      }\n    }\n    return MakeComposedSubTskGphBuilderStatus(status);\n  } else {\n    return Error::BoxingNotSupportedError();\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_UTIL_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_UTIL_H_\n\n#include \"oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder.h\"\n\nnamespace oneflow {\n\nclass FlatSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(FlatSubTskGphBuilder);\n  FlatSubTskGphBuilder(const std::shared_ptr<SubTskGphBuilder>& sub_tsk_gph_builder)\n      : sub_tsk_gph_builder_(sub_tsk_gph_builder) {}\n  ~FlatSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,\n                                      const std::vector<TaskNode*>& sorted_in_tasks,\n                                      std::vector<TaskNode*>* sorted_out_tasks,\n                                      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks,\n                                      const ParallelDesc& in_parallel_desc,\n                                      const ParallelDesc& out_parallel_desc,\n                                      const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,\n                                      const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,\n                                      const Shape& time_shape) const override;\n\n private:\n  std::shared_ptr<SubTskGphBuilder> sub_tsk_gph_builder_;\n};\n\nclass IntraGroupSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(IntraGroupSubTskGphBuilder);\n  IntraGroupSubTskGphBuilder(const std::shared_ptr<SubTskGphBuilder>& sub_tsk_gph_builder)\n      : sub_tsk_gph_builder_(sub_tsk_gph_builder) {}\n  ~IntraGroupSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,\n                                      const std::vector<TaskNode*>& sorted_in_tasks,\n                                      std::vector<TaskNode*>* sorted_out_tasks,\n                                      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks,\n                                      const ParallelDesc& in_parallel_desc,\n                                      const ParallelDesc& out_parallel_desc,\n                                      const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,\n                                      const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,\n                                      const Shape& time_shape) const override;\n\n private:\n  std::shared_ptr<SubTskGphBuilder> sub_tsk_gph_builder_;\n};\n\nclass InterGroupSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(InterGroupSubTskGphBuilder);\n  InterGroupSubTskGphBuilder(const std::shared_ptr<SubTskGphBuilder>& sub_tsk_gph_builder)\n      : 
sub_tsk_gph_builder_(sub_tsk_gph_builder) {}\n  ~InterGroupSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,\n                                      const std::vector<TaskNode*>& sorted_in_tasks,\n                                      std::vector<TaskNode*>* sorted_out_tasks,\n                                      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks,\n                                      const ParallelDesc& in_parallel_desc,\n                                      const ParallelDesc& out_parallel_desc,\n                                      const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,\n                                      const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,\n                                      const Shape& time_shape) const override;\n\n private:\n  std::shared_ptr<SubTskGphBuilder> sub_tsk_gph_builder_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_util.h\"\n\nnamespace oneflow {\n\nMaybe<SubTskGphBuilderStatus> NaiveB2BSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  if ((in_parallel_desc.parallel_num() == 1 || in_sbp_parallel.has_broadcast_parallel())\n      && (out_parallel_desc.parallel_num() == 1 || out_sbp_parallel.has_broadcast_parallel())) {\n    FOR_RANGE(int64_t, out_id, 0, out_parallel_desc.parallel_num()) {\n      const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId(\n          in_parallel_desc, out_parallel_desc, out_id);\n      TaskNode* nearest_in_node = sorted_in_tasks.at(nearest_in_parallel_id);\n      TaskNode* proxy =\n          ctx->task_graph()->GetProxyNode(nearest_in_node, lbi, out_parallel_desc, out_id);\n      sorted_out_tasks->emplace_back(proxy);\n    }\n    return TRY(BuildSubTskGphBuilderStatus(\"NaiveB2BSubTskGphBuilder\", \"\"));\n  } else {\n    return Error::BoxingNotSupportedError();\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_NAIVE_B2B_SUB_TASK_GRAPH_BUILDER_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_NAIVE_B2B_SUB_TASK_GRAPH_BUILDER_H_\n\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder.h\"\n\nnamespace oneflow {\n\nclass NaiveB2BSubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NaiveB2BSubTskGphBuilder);\n  NaiveB2BSubTskGphBuilder() = default;\n  ~NaiveB2BSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_NAIVE_B2B_SUB_TASK_GRAPH_BUILDER_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_util.h\"\n#include \"oneflow/core/graph/boxing_zeros_task_node.h\"\n#include \"oneflow/core/graph/task_stream_id.h\"\n\nnamespace oneflow {\n\nMaybe<SubTskGphBuilderStatus> NaiveB2PSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  if ((in_parallel_desc.parallel_num() == 1 || in_sbp_parallel.has_broadcast_parallel())\n      && out_parallel_desc.parallel_num() != 1 && out_sbp_parallel.has_partial_sum_parallel()) {\n    HashMap<int64_t, int64_t> out_id2nearest_in_id;\n    int64_t nearest_out_node_idx = -1;\n    int64_t nearest_out_node_distance = -1;\n\n    FOR_RANGE(int64_t, out_id, 0, out_parallel_desc.parallel_num()) {\n      const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId(\n          in_parallel_desc, out_parallel_desc, out_id);\n      out_id2nearest_in_id.emplace(out_id, nearest_in_parallel_id);\n      const int64_t distance = SubTskGphBuilderUtil::GetDistance(\n          in_parallel_desc, nearest_in_parallel_id, out_parallel_desc, out_id);\n      if (nearest_out_node_idx == -1 || distance < nearest_out_node_distance) {\n        nearest_out_node_idx = out_id;\n        nearest_out_node_distance = distance;\n      }\n    }\n    FOR_RANGE(int64_t, out_id, 0, out_parallel_desc.parallel_num()) {\n      const int64_t nearest_in_id = out_id2nearest_in_id.at(out_id);\n      TaskNode* nearest_in_node = sorted_in_tasks.at(nearest_in_id);\n      if (out_id == nearest_out_node_idx) {\n        TaskNode* proxy =\n            ctx->task_graph()->GetProxyNode(nearest_in_node, lbi, out_parallel_desc, out_id);\n\n        sorted_out_tasks->emplace_back(proxy);\n      } else {\n        int64_t out_machine_id = CHECK_JUST(out_parallel_desc.MachineId4ParallelId(out_id));\n        int64_t out_dev_phy_id = CHECK_JUST(out_parallel_desc.DeviceId4ParallelId(out_id));\n        if (out_parallel_desc.device_type() == DeviceType::kCPU) { out_dev_phy_id = 0; }\n        int64_t thrd_id = EncodeStreamIdToInt64(GenerateComputeTaskStreamId(\n            out_machine_id, out_parallel_desc.device_type(), out_dev_phy_id));\n        auto* zeros_node = ctx->task_graph()->NewNode<BoxingZerosTaskNode>();\n        zeros_node->Init(out_machine_id, thrd_id, lbi, logical_blob_desc.shape(),\n                         logical_blob_desc.data_type(), time_shape);\n        nearest_in_node->BuildCtrlRegstDesc(zeros_node);\n        
ctx->task_graph()->ConnectWithLbi(nearest_in_node, zeros_node, lbi);\n        sorted_out_tasks->emplace_back(zeros_node);\n      }\n    }\n    return TRY(BuildSubTskGphBuilderStatus(\"NaiveB2PSubTskGphBuilder\", \"\"));\n  } else {\n    return Error::BoxingNotSupportedError();\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_NAIVE_B2P_SUB_TASK_GRAPH_BUILDER_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_NAIVE_B2P_SUB_TASK_GRAPH_BUILDER_H_\n\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder.h\"\n\nnamespace oneflow {\n\nclass NaiveB2PSubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NaiveB2PSubTskGphBuilder);\n  NaiveB2PSubTskGphBuilder() = default;\n  ~NaiveB2PSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_NAIVE_B2P_SUB_TASK_GRAPH_BUILDER_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_util.h\"\n\nnamespace oneflow {\n\nMaybe<SubTskGphBuilderStatus> OneToOneSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  if ((in_parallel_desc.parallel_num() == 1 && out_parallel_desc.parallel_num() == 1)\n      || (in_parallel_desc.parallel_num() == out_parallel_desc.parallel_num()\n          && in_sbp_parallel == out_sbp_parallel)) {\n    for (int64_t i = 0; i < in_parallel_desc.parallel_num(); ++i) {\n      TaskNode* in_node = sorted_in_tasks.at(i);\n      TaskNode* proxy = ctx->task_graph()->GetProxyNode(in_node, lbi, out_parallel_desc, i);\n      sorted_out_tasks->emplace_back(proxy);\n    }\n    return TRY(BuildSubTskGphBuilderStatus(\"OneToOneSubTskGphBuilder\", \"\"));\n  } else {\n    return Error::BoxingNotSupportedError();\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_ONE_TO_ONE_SUB_TASK_GRAPH_BUILDER_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_ONE_TO_ONE_SUB_TASK_GRAPH_BUILDER_H_\n\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder.h\"\n\nnamespace oneflow {\n\nclass OneToOneSubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(OneToOneSubTskGphBuilder);\n  OneToOneSubTskGphBuilder() = default;\n  ~OneToOneSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_ONE_TO_ONE_SUB_TASK_GRAPH_BUILDER_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h\"\n#include \"oneflow/core/register/tensor_slice_view.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/graph/slice_boxing_task_node.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_util.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/graph/task_stream_id.h\"\n#include \"oneflow/core/ep/include/primitive/copy_nd.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool IsCopyNdPrimitiveSupported(DeviceType device_type, int64_t ndims) {\n  auto primitive = ep::primitive::NewPrimitive<ep::primitive::CopyNdFactory>(device_type, ndims);\n  return primitive.operator bool();\n}\n\n}  // namespace\n\nMaybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(\n    SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n    std::vector<TaskNode*>* sorted_out_tasks,\n    std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n    const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n    const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n    const SbpParallel& out_sbp_parallel, const Shape& time_shape) const {\n  if (!IsCopyNdPrimitiveSupported(in_parallel_desc.device_type(),\n                                  logical_blob_desc.shape().NumAxes())) {\n    return Error::BoxingNotSupportedError();\n  }\n  if (!IsCopyNdPrimitiveSupported(out_parallel_desc.device_type(),\n                                  logical_blob_desc.shape().NumAxes())) {\n    return Error::BoxingNotSupportedError();\n  }\n  if (SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc)) {\n    return Error::BoxingNotSupportedError();\n  }\n  if (SubTskGphBuilderUtil::HasEmptySliceIfSplit(in_parallel_desc.parallel_num(), in_sbp_parallel,\n                                                 logical_blob_desc)) {\n    return Error::BoxingNotSupportedError();\n  }\n  if (SubTskGphBuilderUtil::HasEmptySliceIfSplit(out_parallel_desc.parallel_num(), out_sbp_parallel,\n                                                 logical_blob_desc)) {\n    return Error::BoxingNotSupportedError();\n  }\n  if (!(SubTskGphBuilderUtil::IsBoxingS2B(in_sbp_parallel, out_sbp_parallel)\n        || SubTskGphBuilderUtil::IsBoxingS2S(in_sbp_parallel, out_sbp_parallel)\n        || SubTskGphBuilderUtil::IsBoxingP2S(in_sbp_parallel, out_sbp_parallel)\n        || SubTskGphBuilderUtil::IsBoxingP2B(in_sbp_parallel, out_sbp_parallel)\n        || SubTskGphBuilderUtil::IsBoxingB2S(in_sbp_parallel, out_sbp_parallel))) {\n    return Error::BoxingNotSupportedError();\n  }\n\n  const auto NewEdge = [&ctx]() -> TaskEdge* { return ctx->task_graph()->NewEdge(); };\n  const auto CreateSliceBoxingNode =\n      [&ctx, &lbi](const ParallelDesc& pd, const int64_t parallel_id, const TensorSliceView& slice,\n                   
SliceBoxingTaskMode mode) -> SliceBoxingTaskNode* {\n    SliceBoxingTaskNode* node = ctx->task_graph()->NewNode<SliceBoxingTaskNode>();\n    const int64_t machine_id = CHECK_JUST(pd.MachineId4ParallelId(parallel_id));\n    int64_t device_index = (pd.device_type() == DeviceType::kCPU)\n                               ? 0\n                               : CHECK_JUST(pd.DeviceId4ParallelId(parallel_id));\n    int64_t thrd_id = EncodeStreamIdToInt64(\n        GenerateComputeTaskStreamId(machine_id, pd.device_type(), device_index));\n    node->Init(lbi, slice, mode, machine_id, thrd_id);\n    return node;\n  };\n  const auto GetSliceCopyNode = [&CreateSliceBoxingNode, &NewEdge](\n                                    TaskNode* in_node, const TensorSliceView& in_slice,\n                                    const ParallelDesc& in_pd, const int64_t in_id,\n                                    const TensorSliceView& intersection) -> TaskNode* {\n    if (in_slice == intersection) {\n      return in_node;\n    } else {\n      SliceBoxingTaskNode* slice_copy_node =\n          CreateSliceBoxingNode(in_pd, in_id, intersection, kSliceBoxingTaskModeCopy);\n      slice_copy_node->ConnectToSrcNodeWithSlice(in_node, NewEdge(), in_slice);\n      return slice_copy_node;\n    }\n  };\n  const auto BuildSubTaskGphS2B =\n      [&ctx, &CreateSliceBoxingNode, &NewEdge, &lbi](\n          const ParallelDesc& in_pd, const ParallelDesc& out_pd, const SbpParallel& in_sbp,\n          const SbpParallel& out_sbp, const BlobDesc& blob_desc,\n          const std::vector<TaskNode*>& in_nodes, std::vector<TaskNode*>* out_nodes) {\n        CHECK(SubTskGphBuilderUtil::IsBoxingS2B(in_sbp, out_sbp));\n        const std::vector<TensorSliceView> in_slices =\n            GetTensorSliceView(in_pd.parallel_num(), in_sbp, blob_desc);\n        const TensorSliceView& out_slice = GetBroadcastTensorSliceView(blob_desc);\n        FOR_RANGE(int64_t, out_id, 0, out_pd.parallel_num()) {\n          SliceBoxingTaskNode* out_node =\n              CreateSliceBoxingNode(out_pd, out_id, out_slice, kSliceBoxingTaskModeCopy);\n          FOR_RANGE(int64_t, in_id, 0, in_pd.parallel_num()) {\n            const TensorSliceView& in_slice = in_slices.at(in_id);\n            TaskNode* in_node = in_nodes.at(in_id);\n            TaskNode* proxy_node = ctx->task_graph()->GetProxyNode(\n                in_node, lbi, dynamic_cast<TaskNode*>(out_node)->MemZoneId121());\n            out_node->ConnectToSrcNodeWithSlice(proxy_node, NewEdge(), in_slice);\n          }\n          out_nodes->emplace_back(out_node);\n        }\n      };\n  const auto BuildSubTaskGphS2S = [&ctx, &lbi, &CreateSliceBoxingNode, &GetSliceCopyNode, &NewEdge](\n                                      const ParallelDesc& in_pd, const ParallelDesc& out_pd,\n                                      const SbpParallel& in_sbp, const SbpParallel& out_sbp,\n                                      const BlobDesc& blob_desc,\n                                      const std::vector<TaskNode*>& in_nodes,\n                                      std::vector<TaskNode*>* out_nodes) {\n    CHECK(SubTskGphBuilderUtil::IsBoxingS2S(in_sbp, out_sbp));\n    const std::vector<TensorSliceView> in_slices =\n        GetTensorSliceView(in_pd.parallel_num(), in_sbp, blob_desc);\n    const std::vector<TensorSliceView> out_slices =\n        GetTensorSliceView(out_pd.parallel_num(), out_sbp, blob_desc);\n    for (int64_t out_id = 0; out_id < out_pd.parallel_num(); ++out_id) {\n      const TensorSliceView& out_slice = 
out_slices.at(out_id);\n      SliceBoxingTaskNode* out_node =\n          CreateSliceBoxingNode(out_pd, out_id, out_slice, kSliceBoxingTaskModeCopy);\n      for (int64_t in_id = 0; in_id < in_pd.parallel_num(); ++in_id) {\n        const TensorSliceView& in_slice = in_slices.at(in_id);\n        const TensorSliceView& intersection = out_slice.Intersect(in_slice);\n        if (intersection.IsEmpty()) { continue; }\n        TaskNode* in_node = in_nodes.at(in_id);\n        TaskNode* slice_copy_node = GetSliceCopyNode(in_node, in_slice, in_pd, in_id, intersection);\n        TaskNode* proxy_node = ctx->task_graph()->GetProxyNode(\n            slice_copy_node, lbi, dynamic_cast<TaskNode*>(out_node)->MemZoneId121());\n        out_node->ConnectToSrcNodeWithSlice(proxy_node, NewEdge(), intersection);\n      }\n      out_nodes->emplace_back(out_node);\n    }\n  };\n  const auto BuildSubTaskGphP2S = [&ctx, &lbi, &CreateSliceBoxingNode, &GetSliceCopyNode, &NewEdge](\n                                      const ParallelDesc& in_pd, const ParallelDesc& out_pd,\n                                      const SbpParallel& in_sbp, const SbpParallel& out_sbp,\n                                      const BlobDesc& blob_desc,\n                                      const std::vector<TaskNode*>& in_nodes,\n                                      std::vector<TaskNode*>* out_nodes) {\n    CHECK(SubTskGphBuilderUtil::IsBoxingP2S(in_sbp, out_sbp));\n    const TensorSliceView& in_slice = GetBroadcastTensorSliceView(blob_desc);\n    const std::vector<TensorSliceView> out_slices =\n        GetTensorSliceView(out_pd.parallel_num(), out_sbp, blob_desc);\n    for (int64_t out_id = 0; out_id < out_pd.parallel_num(); ++out_id) {\n      const TensorSliceView& out_slice = out_slices.at(out_id);\n      SliceBoxingTaskNode* out_node =\n          CreateSliceBoxingNode(out_pd, out_id, out_slice, kSliceBoxingTaskModeAdd);\n      for (int64_t in_id = 0; in_id < in_pd.parallel_num(); ++in_id) {\n        const TensorSliceView& intersection = out_slice.Intersect(in_slice);\n        if (intersection.IsEmpty()) { continue; }\n        TaskNode* in_node = in_nodes.at(in_id);\n        TaskNode* slice_copy_node = GetSliceCopyNode(in_node, in_slice, in_pd, in_id, intersection);\n        TaskNode* proxy_node = ctx->task_graph()->GetProxyNode(\n            slice_copy_node, lbi, dynamic_cast<TaskNode*>(out_node)->MemZoneId121());\n        out_node->ConnectToSrcNodeWithSlice(proxy_node, NewEdge(), intersection);\n      }\n      out_nodes->emplace_back(out_node);\n    }\n  };\n\n  const auto BuildSubTaskGphP2B =\n      [&ctx, &lbi, &CreateSliceBoxingNode, &NewEdge](\n          const ParallelDesc& in_pd, const ParallelDesc& out_pd, const SbpParallel& in_sbp,\n          const SbpParallel& out_sbp, const BlobDesc& blob_desc,\n          const std::vector<TaskNode*>& in_nodes, std::vector<TaskNode*>* out_nodes) {\n        CHECK(SubTskGphBuilderUtil::IsBoxingP2B(in_sbp, out_sbp));\n        const TensorSliceView& slice = GetBroadcastTensorSliceView(blob_desc);\n        for (int64_t out_id = 0; out_id < out_pd.parallel_num(); ++out_id) {\n          SliceBoxingTaskNode* out_node =\n              CreateSliceBoxingNode(out_pd, out_id, slice, kSliceBoxingTaskModeAdd);\n          for (int64_t in_id = 0; in_id < in_pd.parallel_num(); ++in_id) {\n            TaskNode* in_node = in_nodes.at(in_id);\n            TaskNode* proxy_node = ctx->task_graph()->GetProxyNode(\n                in_node, lbi, dynamic_cast<TaskNode*>(out_node)->MemZoneId121());\n            
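// Each input contributes its full broadcast-shaped slice; out_node was created\n            // with kSliceBoxingTaskModeAdd above, so the contributions are summed (P2B).\n            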
out_node->ConnectToSrcNodeWithSlice(proxy_node, NewEdge(), slice);\n          }\n          out_nodes->emplace_back(out_node);\n        }\n      };\n\n  const auto BuildSubTaskGphB2S =\n      [&ctx, &lbi, &CreateSliceBoxingNode, &NewEdge](\n          const ParallelDesc& in_pd, const ParallelDesc& out_pd, const SbpParallel& in_sbp,\n          const SbpParallel& out_sbp, const BlobDesc& blob_desc,\n          const std::vector<TaskNode*>& in_nodes, std::vector<TaskNode*>* out_nodes) {\n        CHECK(SubTskGphBuilderUtil::IsBoxingB2S(in_sbp, out_sbp));\n        const TensorSliceView& in_slice = GetBroadcastTensorSliceView(blob_desc);\n        const std::vector<TensorSliceView> out_slices =\n            GetTensorSliceView(out_pd.parallel_num(), out_sbp, blob_desc);\n        FOR_RANGE(int64_t, out_id, 0, out_pd.parallel_num()) {\n          const TensorSliceView& out_slice = out_slices.at(out_id);\n          const int64_t nearest_idx =\n              SubTskGphBuilderUtil::FindNearestSrcParallelId(in_pd, out_pd, out_id);\n          TaskNode* in_node = in_nodes.at(nearest_idx);\n          SliceBoxingTaskNode* slice_node =\n              CreateSliceBoxingNode(in_pd, nearest_idx, out_slice, kSliceBoxingTaskModeCopy);\n          slice_node->ConnectToSrcNodeWithSlice(in_node, NewEdge(), in_slice);\n          TaskNode* out_node = ctx->task_graph()->GetProxyNode(slice_node, lbi, out_pd, out_id);\n\n          out_nodes->emplace_back(out_node);\n        }\n      };\n\n  std::string comment;\n  if (SubTskGphBuilderUtil::IsBoxingS2B(in_sbp_parallel, out_sbp_parallel)) {\n    BuildSubTaskGphS2B(in_parallel_desc, out_parallel_desc, in_sbp_parallel, out_sbp_parallel,\n                       logical_blob_desc, sorted_in_tasks, sorted_out_tasks);\n    comment = \"BuildSubTaskGphS2B\";\n  } else if (SubTskGphBuilderUtil::IsBoxingS2S(in_sbp_parallel, out_sbp_parallel)) {\n    BuildSubTaskGphS2S(in_parallel_desc, out_parallel_desc, in_sbp_parallel, out_sbp_parallel,\n                       logical_blob_desc, sorted_in_tasks, sorted_out_tasks);\n    comment = \"BuildSubTaskGphS2S\";\n  } else if (SubTskGphBuilderUtil::IsBoxingP2S(in_sbp_parallel, out_sbp_parallel)) {\n    BuildSubTaskGphP2S(in_parallel_desc, out_parallel_desc, in_sbp_parallel, out_sbp_parallel,\n                       logical_blob_desc, sorted_in_tasks, sorted_out_tasks);\n    comment = \"BuildSubTaskGphP2S\";\n  } else if (SubTskGphBuilderUtil::IsBoxingP2B(in_sbp_parallel, out_sbp_parallel)) {\n    if (logical_blob_desc.shape().elem_cnt() < out_parallel_desc.parallel_num()) {\n      BuildSubTaskGphP2B(in_parallel_desc, out_parallel_desc, in_sbp_parallel, out_sbp_parallel,\n                         logical_blob_desc, sorted_in_tasks, sorted_out_tasks);\n      comment = \"BuildSubTaskGphP2B\";\n    } else {\n      BlobDesc flat_blob_desc(logical_blob_desc.data_type(), logical_blob_desc.memory_format());\n      flat_blob_desc.set_shape(Shape({logical_blob_desc.shape().elem_cnt()}));\n      std::vector<TaskNode*> middle_nodes;\n      SbpParallel middle_sbp;\n      middle_sbp.mutable_split_parallel()->set_axis(0);\n      BuildSubTaskGphP2S(in_parallel_desc, out_parallel_desc, in_sbp_parallel, middle_sbp,\n                         flat_blob_desc, sorted_in_tasks, &middle_nodes);\n      BuildSubTaskGphS2B(out_parallel_desc, out_parallel_desc, middle_sbp, out_sbp_parallel,\n                         flat_blob_desc, middle_nodes, sorted_out_tasks);\n      comment = \"BuildSubTaskGphP2S->BuildSubTaskGphS2B\";\n      for (TaskNode* out_node : *sorted_out_tasks) 
{\n        auto* slice_boxing_node = dynamic_cast<SliceBoxingTaskNode*>(out_node);\n        CHECK_NOTNULL(slice_boxing_node);\n        slice_boxing_node->SetOutShape(logical_blob_desc.shape());\n      }\n    }\n\n  } else if (SubTskGphBuilderUtil::IsBoxingB2S(in_sbp_parallel, out_sbp_parallel)) {\n    BuildSubTaskGphB2S(in_parallel_desc, out_parallel_desc, in_sbp_parallel, out_sbp_parallel,\n                       logical_blob_desc, sorted_in_tasks, sorted_out_tasks);\n    comment = \"BuildSubTaskGphB2S\";\n  } else {\n    UNIMPLEMENTED();\n  }\n  return TRY(BuildSubTskGphBuilderStatus(\"SliceBoxingSubTskGphBuilder\", comment));\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_SLICE_BOXING_SUB_TASK_GRAPH_BUILDER_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_SLICE_BOXING_SUB_TASK_GRAPH_BUILDER_H_\n\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder.h\"\n\nnamespace oneflow {\n\nclass SliceBoxingSubTskGphBuilder final : public SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SliceBoxingSubTskGphBuilder);\n  SliceBoxingSubTskGphBuilder() = default;\n  ~SliceBoxingSubTskGphBuilder() override = default;\n\n  Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_SLICE_BOXING_SUB_TASK_GRAPH_BUILDER_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing/sub_task_graph_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_context.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h\"\n\nnamespace oneflow {\n\nclass SubTskGphBuilder {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SubTskGphBuilder);\n  SubTskGphBuilder() = default;\n  virtual ~SubTskGphBuilder() = default;\n\n  virtual Maybe<SubTskGphBuilderStatus> Build(\n      SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,\n      std::vector<TaskNode*>* sorted_out_tasks,\n      std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,\n      const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,\n      const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel,\n      const SbpParallel& out_sbp_parallel, const Shape& time_shape) const = 0;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing/sub_task_graph_builder_context.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_context.h\"\n\nnamespace oneflow {\n\nSubTskGphBuilderCtx::SubTskGphBuilderCtx(TaskGraph* task_graph) : task_graph_(task_graph) {}\n\nTaskGraph* SubTskGphBuilderCtx::task_graph() { return task_graph_; }\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing/sub_task_graph_builder_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_CONTEXT_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_CONTEXT_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/graph/task_graph.h\"\n\nnamespace oneflow {\n\nclass SubTskGphBuilderCtx final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SubTskGphBuilderCtx);\n  explicit SubTskGphBuilderCtx(TaskGraph* task_graph);\n  virtual ~SubTskGphBuilderCtx() = default;\n\n  virtual TaskGraph* task_graph();\n\n private:\n  TaskGraph* task_graph_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_CONTEXT_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing/sub_task_graph_builder_status_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h\"\n\nnamespace oneflow {\n\nMaybe<SubTskGphBuilderStatus> BuildSubTskGphBuilderStatus(const std::string& builder_name,\n                                                          const std::string& comment) {\n  SubTskGphBuilderStatus status(builder_name, comment);\n  return status;\n}\n\nMaybe<SubTskGphBuilderStatus> MakeComposedSubTskGphBuilderStatus(\n    const std::vector<SubTskGphBuilderStatus>& status_vec) {\n  std::string builder_name = \"ComposedBuilder:\";\n  std::string comment = \"ComposedComment:\";\n  for (auto status : status_vec) {\n    builder_name += \" \";\n    builder_name += status.builder_name();\n    comment += \" \";\n    if (status.comment().empty()) {\n      comment += \"None\";\n    } else {\n      comment += status.comment();\n    }\n  }\n  SubTskGphBuilderStatus composed_status(builder_name, comment);\n  return composed_status;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_SUB_TASK_GRAPH_BUILDER_STATUS_UTIL_H_\n#define ONEFLOW_CORE_GRAPH_SUB_TASK_GRAPH_BUILDER_STATUS_UTIL_H_\n\n#include \"oneflow/core/graph/compute_task_node.h\"\n\nnamespace oneflow {\n\nclass SubTskGphBuilderStatus;\n\nMaybe<SubTskGphBuilderStatus> BuildSubTskGphBuilderStatus(const std::string& builder_name,\n                                                          const std::string& comment);\n\nMaybe<SubTskGphBuilderStatus> MakeComposedSubTskGphBuilderStatus(\n    const std::vector<SubTskGphBuilderStatus>& status);\n\nclass SubTskGphBuilderStatus final {\n public:\n  SubTskGphBuilderStatus(const std::string& builder_name, const std::string& comment)\n      : builder_name_(builder_name), comment_(comment){};\n  ~SubTskGphBuilderStatus() = default;\n\n  // Getters\n  const std::string& builder_name() const { return builder_name_; }\n  const std::string& comment() const { return comment_; }\n\n private:\n  std::string builder_name_;\n  std::string comment_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_SUB_TASK_GRAPH_BUILDER_STATUS_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing/sub_task_graph_builder_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\n\nbool SubTskGphBuilderUtil::IsDeviceTypeCPUOr(const ParallelDesc& parallel_desc,\n                                             DeviceType device_type) {\n  return parallel_desc.device_type() == DeviceType::kCPU\n         || parallel_desc.device_type() == device_type;\n}\n\nbool SubTskGphBuilderUtil::HasEmptySliceIfSplit(int64_t parallel_num,\n                                                const SbpParallel& sbp_parallel,\n                                                const BlobDesc& blob_desc) {\n  if (sbp_parallel.has_split_parallel()) {\n    return blob_desc.shape().At(sbp_parallel.split_parallel().axis()) < parallel_num;\n  } else {\n    return false;\n  }\n}\n\nbool SubTskGphBuilderUtil::IsOnSameDevice(const TaskNode* lhs, const TaskNode* rhs) {\n  return lhs->stream_id().device_id() == rhs->stream_id().device_id()\n         && lhs->stream_id().device_id().device_type() != DeviceType::kCPU;\n}\n\nbool SubTskGphBuilderUtil::IsBoxingS2S(const SbpParallel& src, const SbpParallel& dst) {\n  return src.has_split_parallel() && dst.has_split_parallel();\n}\n\nbool SubTskGphBuilderUtil::IsBoxingS2B(const SbpParallel& src, const SbpParallel& dst) {\n  return src.has_split_parallel() && dst.has_broadcast_parallel();\n}\n\nbool SubTskGphBuilderUtil::IsBoxingP2S(const SbpParallel& src, const SbpParallel& dst) {\n  return src.has_partial_sum_parallel() && dst.has_split_parallel();\n}\n\nbool SubTskGphBuilderUtil::IsBoxingP2B(const SbpParallel& src, const SbpParallel& dst) {\n  return src.has_partial_sum_parallel() && dst.has_broadcast_parallel();\n}\n\nbool SubTskGphBuilderUtil::IsBoxingB2B(const SbpParallel& src, const SbpParallel& dst) {\n  return src.has_broadcast_parallel() && dst.has_broadcast_parallel();\n}\n\nbool SubTskGphBuilderUtil::IsBoxingB2S(const SbpParallel& src, const SbpParallel& dst) {\n  return src.has_broadcast_parallel() && dst.has_split_parallel();\n}\n\nbool SubTskGphBuilderUtil::BlobHasDynamicShape(const BlobDesc& blob_desc) {\n  return blob_desc.is_dynamic();\n}\n\nbool SubTskGphBuilderUtil::IsErrorBoxingNotSupported(const ErrorProto& error) {\n  return error.has_boxing_not_supported_error();\n}\n\nint64_t SubTskGphBuilderUtil::GetDistance(\n    const int64_t src_machine_id, const int64_t src_dev_phy_id, const DeviceType src_device_type,\n    const int64_t dst_machine_id, const int64_t dst_dev_phy_id, const DeviceType dst_device_type) {\n  if (src_machine_id != dst_machine_id) {\n    return kDistanceDiffMachine;\n  } else if (src_device_type != dst_device_type) {\n    return kDistanceSameMachine;\n  } else if (src_device_type == DeviceType::kCPU) {\n    return kDistanceSameDevice;\n  } else {\n    if (src_dev_phy_id == dst_dev_phy_id) {\n      return 
kDistanceSameDevice;\n    } else {\n      return kDistanceSameMachine;\n    }\n  }\n}\n\nint64_t SubTskGphBuilderUtil::GetDistance(const ParallelDesc& src_parallel_desc,\n                                          const int64_t src_parallel_id,\n                                          const ParallelDesc& dst_parallel_desc,\n                                          const int64_t dst_parallel_id) {\n  const int64_t src_machine_id =\n      CHECK_JUST(src_parallel_desc.MachineId4ParallelId(src_parallel_id));\n  const int64_t src_dev_phy_id = CHECK_JUST(src_parallel_desc.DeviceId4ParallelId(src_parallel_id));\n  const int64_t dst_machine_id =\n      CHECK_JUST(dst_parallel_desc.MachineId4ParallelId(dst_parallel_id));\n  const int64_t dst_dev_phy_id = CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id));\n  return GetDistance(src_machine_id, src_dev_phy_id, src_parallel_desc.device_type(),\n                     dst_machine_id, dst_dev_phy_id, dst_parallel_desc.device_type());\n}\n\nint64_t SubTskGphBuilderUtil::GetDistance(const TaskNode* src, const TaskNode* dst) {\n  const auto GetDevPhyId = [](const TaskNode* node) -> int64_t {\n    const DeviceId& device_id = node->stream_id().device_id();\n    if (device_id.device_type() == DeviceType::kCPU) {\n      return 0;\n    } else {\n      return device_id.device_index();\n    }\n  };\n  const DeviceType src_device_type = src->device_type();\n  const int64_t src_dev_phy_id = GetDevPhyId(src);\n  const DeviceType dst_device_type = dst->device_type();\n  const int64_t dst_dev_phy_id = GetDevPhyId(dst);\n  return GetDistance(src->machine_id(), src_dev_phy_id, src_device_type, dst->machine_id(),\n                     dst_dev_phy_id, dst_device_type);\n}\n\nint64_t SubTskGphBuilderUtil::FindNearestSrcParallelId(const ParallelDesc& from_parallel_desc,\n                                                       const ParallelDesc& to_parallel_desc,\n                                                       const int64_t to_parallel_id) {\n  int64_t nearest_from_parallel_idx = -1;\n  int64_t nearest_distance = SubTskGphBuilderUtil::kDistanceMax;\n  for (int64_t i = 0; i < from_parallel_desc.parallel_num(); ++i) {\n    const int64_t distance =\n        SubTskGphBuilderUtil::GetDistance(from_parallel_desc, i, to_parallel_desc, to_parallel_id);\n    if (distance < nearest_distance) {\n      nearest_from_parallel_idx = i;\n      nearest_distance = distance;\n    }\n  }\n  CHECK_NE(nearest_from_parallel_idx, -1);\n  return nearest_from_parallel_idx;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing/sub_task_graph_builder_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_UTIL_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_UTIL_H_\n\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/register/tensor_slice_view.h\"\n#include \"oneflow/core/register/blob_desc.h\"\n#include \"oneflow/core/graph/task_node.h\"\n\nnamespace oneflow {\n\nstruct SubTskGphBuilderUtil {\n  static constexpr int64_t kDistanceSameDevice = 0;\n  static constexpr int64_t kDistanceSameMachine = 1;\n  static constexpr int64_t kDistanceDiffMachine = 2;\n  static constexpr int64_t kDistanceMax = 3;\n\n  static bool IsDeviceTypeCPUOr(const ParallelDesc& parallel_desc, DeviceType device_type);\n  static bool HasEmptySliceIfSplit(int64_t parallel_num, const SbpParallel& sbp_parallel,\n                                   const BlobDesc& blob_desc);\n  static bool IsOnSameDevice(const TaskNode* lhs, const TaskNode* rhs);\n  static bool IsBoxingS2S(const SbpParallel& src, const SbpParallel& dst);\n  static bool IsBoxingS2B(const SbpParallel& src, const SbpParallel& dst);\n  static bool IsBoxingP2S(const SbpParallel& src, const SbpParallel& dst);\n  static bool IsBoxingP2B(const SbpParallel& src, const SbpParallel& dst);\n  static bool IsBoxingB2B(const SbpParallel& src, const SbpParallel& dst);\n  static bool IsBoxingB2S(const SbpParallel& src, const SbpParallel& dst);\n  static bool BlobHasDynamicShape(const BlobDesc& blob_desc);\n  static bool IsErrorBoxingNotSupported(const ErrorProto& error);\n  static int64_t GetDistance(int64_t src_machine_id, int64_t src_dev_phy_id,\n                             DeviceType src_device_type, int64_t dst_machine_id,\n                             int64_t dst_dev_phy_id, DeviceType dst_device_type);\n  static int64_t GetDistance(const ParallelDesc& src_parallel_desc, int64_t src_parallel_id,\n                             const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id);\n  static int64_t GetDistance(const TaskNode* src, const TaskNode* dst);\n\n  template<typename NodeType>\n  static int64_t FindNearestNodeIndex(const std::vector<NodeType*> from_nodes,\n                                      const NodeType* to_node) {\n    CHECK(!from_nodes.empty());\n    int64_t nearest_from_node_idx = -1;\n    int64_t nearest_distance = SubTskGphBuilderUtil::kDistanceMax;\n    for (int64_t i = 0; i < from_nodes.size(); ++i) {\n      NodeType* from_node = from_nodes.at(i);\n      int64_t distance = SubTskGphBuilderUtil::GetDistance(from_node, to_node);\n      if (distance < nearest_distance) {\n        nearest_from_node_idx = i;\n        nearest_distance = distance;\n      }\n    }\n    return nearest_from_node_idx;\n  }\n\n  template<typename NodeType>\n  static NodeType* FindNearestNode(const std::vector<NodeType*> from_nodes,\n                                   const NodeType* to_node) {\n    const int64_t idx = FindNearestNodeIndex<NodeType>(from_nodes, 
to_node);\n    return from_nodes.at(idx);\n  }\n\n  static int64_t FindNearestSrcParallelId(const ParallelDesc& from_parallel_desc,\n                                          const ParallelDesc& to_parallel_desc,\n                                          int64_t to_parallel_id);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_SUB_TASK_GRAPH_BUILDER_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing_identity_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/graph/boxing_identity_task_node.h\"\n#include \"oneflow/core/graph/boxing_task_graph.pb.h\"\n\nnamespace oneflow {\n\nvoid BoxingIdentityTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi) {\n  set_machine_id(machine_id);\n  set_thrd_id(thrd_id);\n  set_lbi(lbi);\n}\n\nvoid BoxingIdentityTaskNode::ProduceAllRegstsAndBindEdges() {\n  std::shared_ptr<RegstDesc> out_regst = ProduceRegst(\"out\", true, 1, 1);\n  this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst(\"out\", out_regst); });\n}\n\nvoid BoxingIdentityTaskNode::ConsumeAllRegsts() {\n  this->ForEachInDataEdge(\n      [&](TaskEdge* in_edge) { ConsumeRegst(\"in\", SoleInDataEdge()->GetSoleRegst()); });\n}\n\nvoid BoxingIdentityTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  OperatorConf op_conf;\n  op_conf.set_name(\"System-Boxing-Identity-\" + NewUniqueId());\n  op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type())));\n  *op_conf.mutable_boxing_identity_conf()->mutable_lbi() = lbi();\n  std::shared_ptr<Operator> sole_op = CHECK_JUST(ConstructOp(op_conf));\n  node->mut_op() = sole_op;\n  node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst(\"in\"));\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));\n  node->BindBnWithRegst(sole_op->SoleObn(), out_regst);\n  (node->*GetInferBlobDescsMethod())(nullptr);\n}\n\nvoid BoxingIdentityTaskNode::InferProducedDataRegstTimeShape() {\n  NaiveInferProducedDataRegstTimeShape();\n}\n\nMaybe<void> BoxingIdentityTaskNode::InitTransportTaskFromProto(\n    const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) {\n  CHECK_OR_RETURN(transport_task_proto.has_boxing_identity_task())\n      << \"not a serialized BoxingIdentityTaskNode. debug string: \"\n      << transport_task_proto.DebugString();\n  return Maybe<void>::Ok();\n}\n\nvoid BoxingIdentityTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const {\n  ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false);\n  transport_task_proto->mutable_boxing_identity_task();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing_identity_task_node.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_IDENTITY_TASK_NODE_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_IDENTITY_TASK_NODE_H_\n\n#include \"oneflow/core/graph/transport_task_node.h\"\n\nnamespace oneflow {\n\nclass BoxingIdentityTaskNode : public TransportTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BoxingIdentityTaskNode);\n  BoxingIdentityTaskNode() = default;\n  ~BoxingIdentityTaskNode() override = default;\n\n  void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi);\n  TaskType GetTaskType() const override { return TaskType::kBoxingIdentity; }\n\n  Maybe<void> InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto,\n                                         const TaskGraphRebuildCtx& ctx) override;\n  void ToTransportTaskProto(TransportTaskProto*) const override;\n\n private:\n  void BuildExecGphAndRegst() override;\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() final;\n  void InferProducedDataRegstTimeShape() final;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_IDENTITY_TASK_NODE_H_\n"
  },
  {
    "path": "oneflow/core/graph/boxing_task_graph.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/register/logical_blob_id.proto\";\nimport \"oneflow/core/common/shape.proto\";\nimport \"oneflow/core/common/data_type.proto\";\nimport \"oneflow/core/job/sbp_parallel.proto\";\nimport \"oneflow/core/job/task.proto\";\nimport \"oneflow/core/job/placement.proto\";\nimport \"oneflow/core/graph/task_edge.proto\";\nimport \"oneflow/core/operator/op_conf.proto\";\nimport \"oneflow/core/register/tensor_slice_view.proto\";\n\nmessage ComputeTasksProto {\n  map<int64, TaskProto> parallel_id2task = 2;\n}\n\nmessage CollectiveBoxingGenericTaskProto {\n  required OperatorConf op_conf = 1;\n}\n\nmessage NcclSendRecvBoxingTaskProto {\n  required ShapeProto logical_shape = 1;\n  required DataType data_type = 2;\n  required NdSbp src_nd_sbp = 3;\n  required NdSbp dst_nd_sbp = 4;\n  required ParallelConf src_parallel_conf = 5;\n  required ParallelConf dst_parallel_conf = 6;\n  required ParallelConf parallel_conf = 7;\n  required ParallelContext parallel_ctx = 8;\n  required bool has_input = 9;\n  required bool has_output = 10;\n  required string stream_name = 11;\n}\n\nenum CopyHdType {\n  H2D = 0;\n  D2H = 1;\n}\n\nmessage CopyHdTaskProto {\n  required CopyHdType copy_type = 1;\n}\n\nmessage CopyCommNetTaskProto {\n}\n\nmessage BoxingZerosTaskProto {\n  required ShapeProto shape = 1;\n  required DataType data_type = 2;\n  required ShapeProto time_shape = 3;\n}\n\nenum SliceBoxingTaskMode {\n  kSliceBoxingTaskModeInvalid = 0;\n  kSliceBoxingTaskModeCopy = 1;\n  kSliceBoxingTaskModeAdd = 2;\n}\n\nmessage SliceBoxingTaskProto {\n  map<int64, TensorSliceViewProto> in_data_edge_uid2slice = 1;\n  repeated int64 ordered_in_data_edge_uid = 2;\n  required TensorSliceViewProto out_slice = 3;\n  required ShapeProto out_shape = 4;\n  required SliceBoxingTaskMode mode = 5;\n}\n\nmessage CollectiveBoxingPackTaskProto {\n  required ShapeProto logical_shape = 1;\n  required SbpParallel src_sbp_parallel = 2;\n  required SbpParallel dst_sbp_parallel = 3;\n  required int64 parallel_num = 4;\n}\n\nmessage CollectiveBoxingUnpackTaskProto {\n  required ShapeProto logical_shape = 1;\n  required SbpParallel src_sbp_parallel = 2;\n  required SbpParallel dst_sbp_parallel = 3;\n  required int64 parallel_num = 4;\n}\n\nmessage BoxingIdentityTaskProto {\n}\n\nmessage TransportTaskProto {\n  required TaskProto task_proto = 1;\n  required LogicalBlobId lbi = 11;\n  oneof transport_task_type {\n    CollectiveBoxingGenericTaskProto collective_boxing_generic_task = 2;\n    NcclSendRecvBoxingTaskProto nccl_send_recv_boxing_task = 3;\n    CopyHdTaskProto copy_hd_task = 4;\n    CopyCommNetTaskProto copy_comm_net_task = 5;\n    BoxingZerosTaskProto boxing_zeros_task = 6;\n    SliceBoxingTaskProto slice_boxing_task = 7;\n    CollectiveBoxingPackTaskProto collective_boxing_pack_task = 8;\n    CollectiveBoxingUnpackTaskProto collective_boxing_unpack_task = 9;\n    BoxingIdentityTaskProto boxing_identity_task = 10;\n  }\n}\n\nmessage TaskIdsProto {\n  repeated int64 task_id = 1;\n}\n\nmessage BoxingTaskGraphProto {\n  map<string, ComputeTasksProto> boxing_related_op_name2compute_tasks = 1;\n  repeated TransportTaskProto transport_task = 2;\n  repeated TaskEdgeProto task_edge = 3;\n  map<string, TaskIdsProto> boxing_unrelated_op_name2task_ids = 4;\n}\n"
  },
  {
    "path": "oneflow/core/graph/boxing_zeros_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/graph/boxing_zeros_task_node.h\"\n#include \"oneflow/core/graph/boxing_task_graph.pb.h\"\n\nnamespace oneflow {\n\nvoid BoxingZerosTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,\n                               const Shape& shape, DataType data_type, const Shape& time_shape) {\n  set_machine_id(machine_id);\n  set_thrd_id(thrd_id);\n  set_lbi(lbi);\n  shape_ = shape;\n  data_type_ = data_type;\n  time_shape_ = time_shape;\n}\n\nvoid BoxingZerosTaskNode::ProduceAllRegstsAndBindEdges() {\n  std::shared_ptr<RegstDesc> out_regst = ProduceRegst(\"out\", false, 1, 1);\n  this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst(\"out\", out_regst); });\n}\n\nvoid BoxingZerosTaskNode::ConsumeAllRegsts() {\n  // do nothing\n}\n\nvoid BoxingZerosTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  OperatorConf op_conf;\n  op_conf.set_name(\"System-Boxing-Zeros-\" + NewUniqueId());\n  op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type())));\n  *op_conf.mutable_boxing_zeros_conf()->mutable_lbi() = lbi();\n  shape_.ToProto(op_conf.mutable_boxing_zeros_conf()->mutable_shape());\n  op_conf.mutable_boxing_zeros_conf()->set_data_type(data_type_);\n  std::shared_ptr<Operator> sole_op = CHECK_JUST(ConstructOp(op_conf));\n  node->mut_op() = sole_op;\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));\n  node->BindBnWithRegst(sole_op->SoleObn(), out_regst);\n  (node->*GetInferBlobDescsMethod())(nullptr);\n}\n\nvoid BoxingZerosTaskNode::InferProducedDataRegstTimeShape() {\n  GetProducedRegst(\"out\")->mut_data_regst_time_shape()->reset(new Shape(time_shape_));\n}\nMaybe<void> BoxingZerosTaskNode::InitTransportTaskFromProto(\n    const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) {\n  CHECK_OR_RETURN(transport_task_proto.has_boxing_zeros_task())\n      << \"not a serialized BoxingZerosTaskNode. debug string: \"\n      << transport_task_proto.DebugString();\n  const auto& proto = transport_task_proto.boxing_zeros_task();\n  shape_ = Shape(proto.shape());\n  data_type_ = proto.data_type();\n  time_shape_ = Shape(proto.time_shape());\n  return Maybe<void>::Ok();\n}\n\nvoid BoxingZerosTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const {\n  ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false);\n  auto* proto = transport_task_proto->mutable_boxing_zeros_task();\n  shape_.ToProto(proto->mutable_shape());\n  proto->set_data_type(data_type_);\n  time_shape_.ToProto(proto->mutable_time_shape());\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/boxing_zeros_task_node.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_BOXING_ZEROS_TASK_NODE_H_\n#define ONEFLOW_CORE_GRAPH_BOXING_ZEROS_TASK_NODE_H_\n\n#include \"oneflow/core/graph/transport_task_node.h\"\n\nnamespace oneflow {\n\nclass BoxingZerosTaskNode : public TransportTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BoxingZerosTaskNode);\n  BoxingZerosTaskNode() = default;\n  ~BoxingZerosTaskNode() override = default;\n\n  void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi, const Shape& shape,\n            DataType data_type, const Shape& time_shape);\n  TaskType GetTaskType() const override { return TaskType::kBoxingZeros; }\n\n  Maybe<void> InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto,\n                                         const TaskGraphRebuildCtx& ctx) override;\n  void ToTransportTaskProto(TransportTaskProto*) const override;\n\n private:\n  void BuildExecGphAndRegst() override;\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() final;\n  void InferProducedDataRegstTimeShape() final;\n\n  Shape shape_;\n  DataType data_type_;\n  Shape time_shape_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_BOXING_ZEROS_TASK_NODE_H_\n"
  },
  {
    "path": "oneflow/core/graph/collective_boxing_pack_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/graph/collective_boxing_pack_task_node.h\"\n#include \"oneflow/core/graph/boxing_task_graph.pb.h\"\n\nnamespace oneflow {\n\nvoid CollectiveBoxingPackTaskNode::Init(int64_t machine_id, int64_t thrd_id,\n                                        const LogicalBlobId& lbi, const Shape& logical_shape,\n                                        const SbpParallel& src_sbp_parallel,\n                                        const SbpParallel& dst_sbp_parallel,\n                                        const int64_t parallel_num) {\n  set_machine_id(machine_id);\n  set_thrd_id(thrd_id);\n  set_lbi(lbi);\n  logical_shape_ = logical_shape;\n  parallel_num_ = parallel_num;\n  src_sbp_parallel_ = src_sbp_parallel;\n  dst_sbp_parallel_ = dst_sbp_parallel;\n}\n\nvoid CollectiveBoxingPackTaskNode::ProduceAllRegstsAndBindEdges() {\n  std::shared_ptr<RegstDesc> out_regst = ProduceRegst(\"out\", true, 1, 1);\n  this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst(\"out\", out_regst); });\n}\n\nvoid CollectiveBoxingPackTaskNode::ConsumeAllRegsts() {\n  this->ForEachInDataEdge(\n      [&](TaskEdge* in_edge) { ConsumeRegst(\"in\", SoleInDataEdge()->GetSoleRegst()); });\n}\n\nvoid CollectiveBoxingPackTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  OperatorConf op_conf;\n  op_conf.set_name(\"System-Collective-Boxing-Pack-\" + NewUniqueId());\n  op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type())));\n  auto* collective_boxing_pack_conf = op_conf.mutable_collective_boxing_pack_conf();\n  *collective_boxing_pack_conf->mutable_lbi() = lbi();\n  logical_shape_.ToProto(collective_boxing_pack_conf->mutable_logical_shape());\n  *collective_boxing_pack_conf->mutable_src_sbp_parallel() = src_sbp_parallel_;\n  *collective_boxing_pack_conf->mutable_dst_sbp_parallel() = dst_sbp_parallel_;\n  collective_boxing_pack_conf->set_num_ranks(parallel_num_);\n  std::shared_ptr<Operator> sole_op = CHECK_JUST(ConstructOp(op_conf));\n  node->mut_op() = sole_op;\n  node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst(\"in\"));\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));\n  node->BindBnWithRegst(sole_op->SoleObn(), out_regst);\n  (node->*GetInferBlobDescsMethod())(nullptr);\n}\n\nvoid CollectiveBoxingPackTaskNode::InferProducedDataRegstTimeShape() {\n  NaiveInferProducedDataRegstTimeShape();\n}\n\nMaybe<void> CollectiveBoxingPackTaskNode::InitTransportTaskFromProto(\n    const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) {\n  CHECK_OR_RETURN(transport_task_proto.has_collective_boxing_pack_task())\n      << \"not a serialized CollectiveBoxingPackTaskNode. 
debug string: \"\n      << transport_task_proto.DebugString();\n  const auto& proto = transport_task_proto.collective_boxing_pack_task();\n  logical_shape_ = Shape(proto.logical_shape());\n  src_sbp_parallel_ = proto.src_sbp_parallel();\n  dst_sbp_parallel_ = proto.dst_sbp_parallel();\n  parallel_num_ = proto.parallel_num();\n  return Maybe<void>::Ok();\n}\n\nvoid CollectiveBoxingPackTaskNode::ToTransportTaskProto(\n    TransportTaskProto* transport_task_proto) const {\n  ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false);\n  auto* proto = transport_task_proto->mutable_collective_boxing_pack_task();\n  logical_shape_.ToProto(proto->mutable_logical_shape());\n  *proto->mutable_src_sbp_parallel() = src_sbp_parallel_;\n  *proto->mutable_dst_sbp_parallel() = dst_sbp_parallel_;\n  proto->set_parallel_num(parallel_num_);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/collective_boxing_pack_task_node.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_PACK_TASK_NODE_H_\n#define ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_PACK_TASK_NODE_H_\n\n#include \"oneflow/core/graph/transport_task_node.h\"\n\nnamespace oneflow {\n\nclass CollectiveBoxingPackTaskNode : public TransportTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingPackTaskNode);\n  CollectiveBoxingPackTaskNode() = default;\n  ~CollectiveBoxingPackTaskNode() override = default;\n\n  void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,\n            const Shape& logical_shape, const SbpParallel& src_sbp_parallel,\n            const SbpParallel& dst_sbp_parallel, const int64_t parallel_num);\n  TaskType GetTaskType() const override { return TaskType::kCollectiveBoxingPack; }\n\n  Maybe<void> InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto,\n                                         const TaskGraphRebuildCtx& ctx) override;\n  void ToTransportTaskProto(TransportTaskProto*) const override;\n\n private:\n  void BuildExecGphAndRegst() override;\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() final;\n  void InferProducedDataRegstTimeShape() final;\n\n  Shape logical_shape_;\n  SbpParallel src_sbp_parallel_;\n  SbpParallel dst_sbp_parallel_;\n  int64_t parallel_num_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_PACK_TASK_NODE_H_\n"
  },
  {
    "path": "oneflow/core/graph/collective_boxing_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/boxing_task_graph.pb.h\"\n#include \"oneflow/core/graph/collective_boxing_task_node.h\"\n#include \"oneflow/core/graph/boxing/collective_boxing_util.h\"\n\nnamespace oneflow {\n\nvoid CollectiveBoxingGenericTaskNode::Init(int64_t machine_id, int64_t thrd_id,\n                                           const LogicalBlobId& lbi, const OperatorConf& op_conf) {\n  set_machine_id(machine_id);\n  set_thrd_id(thrd_id);\n  set_lbi(lbi);\n  op_conf_ = op_conf;\n}\n\nvoid CollectiveBoxingGenericTaskNode::ProduceAllRegstsAndBindEdges() {\n  if (boxing::collective::GenericOpHasOutput(\n          op_conf_.collective_boxing_generic_conf().rank_desc())) {\n    const bool enable_mem_reuse =\n        ParseBooleanFromEnv(\"ONEFLOW_GRAPH_BOXING_ENABLE_MEM_REUSE\", false);\n    std::shared_ptr<RegstDesc> out_regst = ProduceRegst(\"out\", enable_mem_reuse, 1, 1);\n    this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst(\"out\", out_regst); });\n  }\n}\n\nvoid CollectiveBoxingGenericTaskNode::ConsumeAllRegsts() {\n  this->ForEachInDataEdge(\n      [&](TaskEdge* in_edge) { ConsumeRegst(\"in\", SoleInDataEdge()->GetSoleRegst()); });\n}\n\nvoid CollectiveBoxingGenericTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  std::shared_ptr<Operator> boxing_op = CHECK_JUST(ConstructOp(op_conf_));\n  node->mut_op() = boxing_op;\n  for (const std::string& ibn : boxing_op->input_bns()) {\n    node->BindBnWithRegst(ibn, GetSoleConsumedRegst(\"in\"));\n  }\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  for (const std::string& obn : boxing_op->output_bns()) {\n    CHECK(out_regst != nullptr);\n    node->BindBnWithRegst(obn, out_regst);\n    out_regst->AddLbi(boxing_op->BnInOp2Lbi(obn));\n  }\n  (node->*GetInferBlobDescsMethod())(nullptr);\n}\n\nvoid CollectiveBoxingGenericTaskNode::InferProducedDataRegstTimeShape() {\n  auto out_regst = GetProducedRegst(\"out\");\n  if (out_regst != nullptr) { out_regst->mut_data_regst_time_shape()->reset(new Shape({1, 1})); }\n}\n\nMaybe<void> CollectiveBoxingGenericTaskNode::InitTransportTaskFromProto(\n    const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) {\n  CHECK_OR_RETURN(transport_task_proto.has_collective_boxing_generic_task())\n      << \"not a serialized CollectiveBoxingGenericTaskNode. debug string: \"\n      << transport_task_proto.DebugString();\n  op_conf_ = transport_task_proto.collective_boxing_generic_task().op_conf();\n  return Maybe<void>::Ok();\n}\n\nvoid CollectiveBoxingGenericTaskNode::ToTransportTaskProto(\n    TransportTaskProto* transport_task_proto) const {\n  ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false);\n  *transport_task_proto->mutable_collective_boxing_generic_task()->mutable_op_conf() = op_conf_;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/collective_boxing_task_node.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_TASK_NODE_H_\n#define ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_TASK_NODE_H_\n\n#include \"oneflow/core/graph/transport_task_node.h\"\n\nnamespace oneflow {\n\nclass CollectiveBoxingGenericTaskNode : public TransportTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingGenericTaskNode);\n  CollectiveBoxingGenericTaskNode() = default;\n  ~CollectiveBoxingGenericTaskNode() override = default;\n\n  void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,\n            const OperatorConf& op_conf);\n\n  Maybe<void> InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto,\n                                         const TaskGraphRebuildCtx& ctx) override;\n  void ToTransportTaskProto(TransportTaskProto*) const override;\n\n private:\n  void BuildExecGphAndRegst() override;\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() final;\n  void InferProducedDataRegstTimeShape() final;\n  TaskType GetTaskType() const override { return TaskType::kCollectiveBoxingGeneric; }\n\n  OperatorConf op_conf_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_TASK_NODE_H_\n"
  },
  {
    "path": "oneflow/core/graph/collective_boxing_unpack_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/graph/boxing_task_graph.pb.h\"\n#include \"oneflow/core/graph/collective_boxing_unpack_task_node.h\"\n\nnamespace oneflow {\n\nvoid CollectiveBoxingUnpackTaskNode::Init(int64_t machine_id, int64_t thrd_id,\n                                          const LogicalBlobId& lbi, const Shape& logical_shape,\n                                          const SbpParallel& src_sbp_parallel,\n                                          const SbpParallel& dst_sbp_parallel,\n                                          const int64_t parallel_num) {\n  set_machine_id(machine_id);\n  set_thrd_id(thrd_id);\n  set_lbi(lbi);\n  logical_shape_ = logical_shape;\n  parallel_num_ = parallel_num;\n  src_sbp_parallel_ = src_sbp_parallel;\n  dst_sbp_parallel_ = dst_sbp_parallel;\n}\n\nvoid CollectiveBoxingUnpackTaskNode::ProduceAllRegstsAndBindEdges() {\n  std::shared_ptr<RegstDesc> out_regst = ProduceRegst(\"out\", true, 1, 1);\n  this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst(\"out\", out_regst); });\n}\n\nvoid CollectiveBoxingUnpackTaskNode::ConsumeAllRegsts() {\n  this->ForEachInDataEdge(\n      [&](TaskEdge* in_edge) { ConsumeRegst(\"in\", SoleInDataEdge()->GetSoleRegst()); });\n}\n\nvoid CollectiveBoxingUnpackTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  OperatorConf op_conf;\n  op_conf.set_name(\"System-Collective-Boxing-Unpack-\" + NewUniqueId());\n  op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type())));\n  auto* collective_boxing_unpack_conf = op_conf.mutable_collective_boxing_unpack_conf();\n  *collective_boxing_unpack_conf->mutable_lbi() = lbi();\n  logical_shape_.ToProto(collective_boxing_unpack_conf->mutable_logical_shape());\n  *collective_boxing_unpack_conf->mutable_src_sbp_parallel() = src_sbp_parallel_;\n  *collective_boxing_unpack_conf->mutable_dst_sbp_parallel() = dst_sbp_parallel_;\n  collective_boxing_unpack_conf->set_num_ranks(parallel_num_);\n  std::shared_ptr<Operator> sole_op = CHECK_JUST(ConstructOp(op_conf));\n  node->mut_op() = sole_op;\n  node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst(\"in\"));\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));\n  node->BindBnWithRegst(sole_op->SoleObn(), out_regst);\n  (node->*GetInferBlobDescsMethod())(nullptr);\n}\n\nvoid CollectiveBoxingUnpackTaskNode::InferProducedDataRegstTimeShape() {\n  NaiveInferProducedDataRegstTimeShape();\n}\n\nMaybe<void> CollectiveBoxingUnpackTaskNode::InitTransportTaskFromProto(\n    const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) {\n  CHECK_OR_RETURN(transport_task_proto.has_collective_boxing_unpack_task())\n      << \"not a serialized CollectiveBoxingUnpackTaskNode. 
debug string: \"\n      << transport_task_proto.DebugString();\n  const auto& proto = transport_task_proto.collective_boxing_unpack_task();\n  logical_shape_ = Shape(proto.logical_shape());\n  src_sbp_parallel_ = proto.src_sbp_parallel();\n  dst_sbp_parallel_ = proto.dst_sbp_parallel();\n  parallel_num_ = proto.parallel_num();\n  return Maybe<void>::Ok();\n}\n\nvoid CollectiveBoxingUnpackTaskNode::ToTransportTaskProto(\n    TransportTaskProto* transport_task_proto) const {\n  ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false);\n  auto* proto = transport_task_proto->mutable_collective_boxing_unpack_task();\n  logical_shape_.ToProto(proto->mutable_logical_shape());\n  *proto->mutable_src_sbp_parallel() = src_sbp_parallel_;\n  *proto->mutable_dst_sbp_parallel() = dst_sbp_parallel_;\n  proto->set_parallel_num(parallel_num_);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/collective_boxing_unpack_task_node.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_UNPACK_TASK_NODE_H_\n#define ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_UNPACK_TASK_NODE_H_\n\n#include \"oneflow/core/graph/transport_task_node.h\"\n\nnamespace oneflow {\n\nclass CollectiveBoxingUnpackTaskNode : public TransportTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingUnpackTaskNode);\n  CollectiveBoxingUnpackTaskNode() = default;\n  ~CollectiveBoxingUnpackTaskNode() override = default;\n\n  void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,\n            const Shape& logical_shape, const SbpParallel& src_sbp_parallel,\n            const SbpParallel& dst_sbp_parallel, const int64_t parallel_num);\n\n  TaskType GetTaskType() const override { return TaskType::kCollectiveBoxingUnpack; }\n\n  Maybe<void> InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto,\n                                         const TaskGraphRebuildCtx& ctx) override;\n  void ToTransportTaskProto(TransportTaskProto*) const override;\n\n private:\n  void BuildExecGphAndRegst() override;\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() final;\n  void InferProducedDataRegstTimeShape() final;\n\n  Shape logical_shape_;\n  SbpParallel src_sbp_parallel_;\n  SbpParallel dst_sbp_parallel_;\n  int64_t parallel_num_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_UNPACK_TASK_NODE_H_\n"
  },
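  {
    "path": "oneflow/core/graph/examples/transport_task_proto_roundtrip_sketch.cpp",
    "content": "/*\nEditor's note: an illustrative, self-contained sketch added next to this dump; it is not part of\nOneFlow. It mirrors the serialize/rebuild pattern used by\nCollectiveBoxingUnpackTaskNode::ToTransportTaskProto and InitTransportTaskFromProto: the compiling\nrank serializes the node state, and the other ranks rebuild an equivalent node from it.\nFakeUnpackState/FakeUnpackProto are hypothetical stand-ins for the protobuf-generated\nTransportTaskProto.\n*/\n#include <cassert>\n#include <cstdint>\n#include <vector>\n\nstruct FakeUnpackProto {  // stand-in for the serialized message\n  std::vector<int64_t> logical_shape_dims;\n  int64_t parallel_num = 0;\n};\n\nstruct FakeUnpackState {  // stand-in for the task node's private fields\n  std::vector<int64_t> logical_shape_dims;\n  int64_t parallel_num = 0;\n\n  void ToProto(FakeUnpackProto* proto) const {\n    proto->logical_shape_dims = logical_shape_dims;\n    proto->parallel_num = parallel_num;\n  }\n  void InitFromProto(const FakeUnpackProto& proto) {\n    logical_shape_dims = proto.logical_shape_dims;\n    parallel_num = proto.parallel_num;\n  }\n};\n\nint main() {\n  FakeUnpackState src;\n  src.logical_shape_dims = {4, 8};\n  src.parallel_num = 4;\n  FakeUnpackProto proto;\n  src.ToProto(&proto);       // the compiling rank serializes the node state ...\n  FakeUnpackState dst;\n  dst.InitFromProto(proto);  // ... and every other rank rebuilds the node from it\n  assert(dst.parallel_num == src.parallel_num);\n  assert(dst.logical_shape_dims == src.logical_shape_dims);\n  return 0;\n}\n"
  },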
  {
    "path": "oneflow/core/graph/compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_graph.h\"\n#include \"oneflow/core/graph/normal_forward_compute_task_node.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconst OpNode* OpNodeOnEdge(TaskEdge* edge, TaskNode* (TaskEdge::*GetNode)() const,\n                           void (TaskNode::*ForEachDataEdge)(const std::function<void(TaskEdge*)>&)\n                               const) {\n  CompTaskNode* target_node = nullptr;\n  do {\n    TaskNode* tmp_node = (edge->*GetNode)();\n    target_node = dynamic_cast<CompTaskNode*>(tmp_node);\n    edge = nullptr;\n    (tmp_node->*ForEachDataEdge)([&](TaskEdge* e) {\n      if (edge == nullptr) { edge = e; }\n    });\n  } while (!target_node && edge);\n  if (target_node) { return target_node->op_node(); }\n  return nullptr;\n}\n\nstd::vector<CompTaskNode*> GetCompTaskNodesOnEdge(\n    TaskEdge* edge, TaskNode* (TaskEdge::*GetNode)() const,\n    void (TaskNode::*ForEachDataEdge)(const std::function<void(TaskEdge*)>&) const) {\n  std::queue<TaskNode*> nodes;\n  HashSet<TaskNode*> visited_nodes;\n  nodes.push((edge->*GetNode)());\n  CHECK(visited_nodes.emplace((edge->*GetNode)()).second);\n  std::vector<CompTaskNode*> comp_task_nodes;\n  while (!nodes.empty()) {\n    TaskNode* node = nodes.front();\n    nodes.pop();\n    CompTaskNode* comp_task_node = dynamic_cast<CompTaskNode*>(node);\n    if (comp_task_node) {\n      comp_task_nodes.emplace_back(comp_task_node);\n    } else {\n      (node->*ForEachDataEdge)([&](TaskEdge* task_edge) {\n        if (visited_nodes.find((task_edge->*GetNode)()) == visited_nodes.end()) {\n          nodes.push((task_edge->*GetNode)());\n          CHECK(visited_nodes.emplace((task_edge->*GetNode)()).second);\n        }\n      });\n    }\n  }\n  return comp_task_nodes;\n}\n\nstd::shared_ptr<RegstDesc> NewFakeDataRegstDesc() {\n  auto regst_desc = std::make_shared<RegstDesc>();\n  regst_desc->mut_regst_desc_type()->mutable_data_regst_desc();\n  return regst_desc;\n}\n\n}  // namespace\n\nvoid CompTaskNode::ConsumeFakeRegst(const std::string& regst_name) {\n  ConsumeRegst(regst_name, NewFakeDataRegstDesc());\n  fake_consumed_regst_names_.insert(regst_name);\n}\n\nvoid CompTaskNode::ConsumeFakeRegstsIf() {\n  ConsumeFakeRegsts();\n  RegstDesc* data_regst_desc = nullptr;\n  for (const auto& pair : consumed_regsts()) {\n    for (const auto& regst_desc : pair.second) {\n      if (regst_desc->regst_desc_type().has_data_regst_desc()) {\n        // Only one fake data regst is creatd for each CompTaskNode with ConsumeFakeRegsts().\n        CHECK(data_regst_desc == nullptr);\n        data_regst_desc = CHECK_NOTNULL(regst_desc.get());\n      } else if (regst_desc->regst_desc_type().has_ctrl_regst_desc()) {\n        // do nothing.\n      } else {\n        UNIMPLEMENTED();\n      }\n    }\n  }\n  if (data_regst_desc != nullptr) {\n    for (const auto& ibn : op_node()->op().input_bns()) 
{\n      // Only one fake data regst is creatd and just use it for all input_bns as a placeholder.\n      data_regst_desc->AddLbi(op_node()->op().BnInOp2Lbi(ibn));\n    }\n  }\n}\n\nvoid CompTaskNode::EraseFakeRegstsIf() {\n  for (const auto& fake_consumed_regst_name : fake_consumed_regst_names_) {\n    EraseConsumedRegstsByName(fake_consumed_regst_name);\n  }\n  fake_consumed_regst_names_.clear();\n}\n\nstd::string CompTaskNode::VisualStr() const { return op_node_->op().op_name(); }\n\nvoid CompTaskNode::InitFromProtoExceptConsumedRegsts(const TaskProto& proto) {\n  TaskNode::InitFromProtoExceptConsumedRegsts(proto);\n  parallel_ctx_ = proto.parallel_ctx();\n}\n\nvoid CompTaskNode::ToProto(TaskProto* task_proto, bool check) const {\n  TaskNode::ToProto(task_proto, check);\n  *(task_proto->mutable_parallel_ctx()) = parallel_ctx_;\n}\n\nconst OpNode* CompTaskNode::GetOneSuccOpNodeOnEdge(TaskEdge* edge) {\n  return OpNodeOnEdge(edge, &TaskEdge::dst_node, &TaskNode::ForEachOutDataEdge);\n}\n\nconst OpNode* CompTaskNode::GetOnePredOpNodeOnEdge(TaskEdge* edge) {\n  return OpNodeOnEdge(edge, &TaskEdge::src_node, &TaskNode::ForEachInDataEdge);\n}\n\nstd::vector<CompTaskNode*> CompTaskNode::GetSuccCompTaskNodesOnEdge(TaskEdge* edge) const {\n  return GetCompTaskNodesOnEdge(edge, &TaskEdge::dst_node, &TaskNode::ForEachOutDataEdge);\n}\n\nstd::vector<CompTaskNode*> CompTaskNode::GetPredCompTaskNodesOnEdge(TaskEdge* edge) const {\n  return GetCompTaskNodesOnEdge(edge, &TaskEdge::src_node, &TaskNode::ForEachInDataEdge);\n}\n\nvoid CompTaskNode::InferProducedDataRegstTimeShape() {\n  std::shared_ptr<Shape> op_time_shape(new Shape(*CHECK_JUST(op()->GetOpTimeShape())));\n  ForEachProducedDataRegst([op_time_shape](const std::string& name, RegstDesc* regst) {\n    *regst->mut_data_regst_time_shape() = op_time_shape;\n  });\n}\n\nCompTaskNode* NewCompTaskNode4OpNode(const OpNode* op_node) {\n  const OperatorConf& op_conf = op_node->op().op_conf();\n  if (op_conf.has_user_conf()) {\n    const std::string& op_type_name = op_conf.user_conf().op_type_name();\n    if (IsClassRegistered<std::string, OpCompTaskNodeCreator>(op_type_name)) {\n      return std::unique_ptr<OpCompTaskNodeCreator>(\n                 NewObj<std::string, OpCompTaskNodeCreator>(op_type_name))\n          ->NewCompTaskNode(op_conf);\n    } else {\n      return new NormalForwardCompTaskNode;\n    }\n  } else {\n    OperatorConf::OpTypeCase op_type_case = op_conf.op_type_case();\n    if (IsClassRegistered<int32_t, OpCompTaskNodeCreator>(op_type_case)) {\n      return std::unique_ptr<OpCompTaskNodeCreator>(\n                 NewObj<int32_t, OpCompTaskNodeCreator>(op_type_case))\n          ->NewCompTaskNode(op_conf);\n    } else {\n      return new NormalForwardCompTaskNode;\n    }\n  }\n}\n\n}  // namespace oneflow\n"
  },
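  {
    "path": "oneflow/core/graph/examples/comp_task_nodes_on_edge_bfs_sketch.cpp",
    "content": "/*\nEditor's note: an illustrative, self-contained sketch added next to this dump; it is not part of\nOneFlow. It mirrors the BFS in GetCompTaskNodesOnEdge (compute_task_node.cpp): starting from one\nendpoint of an edge, walk across non-compute nodes and collect the nearest compute nodes, without\nexpanding past them. Node/CompNode are hypothetical minimal stand-ins for TaskNode/CompTaskNode.\n*/\n#include <cassert>\n#include <queue>\n#include <unordered_set>\n#include <vector>\n\nstruct Node {\n  virtual ~Node() = default;\n  std::vector<Node*> nexts;  // stand-in for ForEachOutDataEdge\n};\nstruct CompNode : Node {};\n\nstd::vector<CompNode*> CollectCompNodes(Node* start) {\n  std::queue<Node*> nodes;\n  std::unordered_set<Node*> visited;\n  nodes.push(start);\n  visited.insert(start);\n  std::vector<CompNode*> result;\n  while (!nodes.empty()) {\n    Node* node = nodes.front();\n    nodes.pop();\n    if (auto* comp = dynamic_cast<CompNode*>(node)) {\n      result.push_back(comp);  // stop expanding once a compute node is reached\n    } else {\n      for (Node* next : node->nexts) {\n        if (visited.insert(next).second) { nodes.push(next); }\n      }\n    }\n  }\n  return result;\n}\n\nint main() {\n  // comp_a -> relay -> comp_b: searching from comp_a's successor finds comp_b.\n  CompNode comp_a;\n  CompNode comp_b;\n  Node relay;\n  comp_a.nexts = {&relay};\n  relay.nexts = {&comp_b};\n  std::vector<CompNode*> found = CollectCompNodes(comp_a.nexts.front());\n  assert(found.size() == 1 && found.front() == &comp_b);\n  return 0;\n}\n"
  },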
  {
    "path": "oneflow/core/graph/compute_task_node.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_COMPUTE_TASK_NODE_H_\n#define ONEFLOW_CORE_GRAPH_COMPUTE_TASK_NODE_H_\n\n#include \"oneflow/core/graph/task_node.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/graph/fake_consumed_regst_provider.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/job/compile_mode.h\"\n\nnamespace oneflow {\n\nclass CompTaskNode : public TaskNode, public FakeConsumedRegstProvider {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CompTaskNode);\n  CompTaskNode() = default;\n  virtual ~CompTaskNode() = default;\n\n  virtual void ToProto(TaskProto*, bool check) const override;\n  virtual void InitFromProtoExceptConsumedRegsts(const TaskProto&) override;\n  void ConsumeFakeRegstsIf() override;\n  void EraseFakeRegstsIf() override;\n\n  // ConsumeFakeRegsts is used for initializing CompTaskNode.consumed_regsts_ on the other ranks.\n  virtual void ConsumeFakeRegsts() = 0;\n  void ConsumeFakeRegst(const std::string& regst_name);\n\n  // parallel_ctx_\n  int64_t parallel_id() const { return parallel_ctx_.parallel_id(); }\n  const ParallelContext* parallel_ctx() const override { return &parallel_ctx_; }\n  ParallelContext* mut_parallel_ctx() { return &parallel_ctx_; }\n\n  // op_node_\n  const OpNode* op_node() const { return op_node_; }\n  void set_op_node(const OpNode* val) { op_node_ = val; }\n  std::string VisualStr() const override;\n\n  // op\n  std::shared_ptr<const Operator> op() const { return op_node_->shared_op(); }\n\n  ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const override {\n    // For default compilation mode, compute task node use input blob desc to infer output blob\n    // desc; For separate compilation mode, compute task node use NdSBP to infer output blob desc.\n    return InferBlobDescsMethodGetter::Visit(CHECK_JUST(CurrentCompileMode()));\n  }\n\n protected:\n  const OpNode* GetOneSuccOpNodeOnEdge(TaskEdge* edge);\n  const OpNode* GetOnePredOpNodeOnEdge(TaskEdge* edge);\n  std::vector<CompTaskNode*> GetSuccCompTaskNodesOnEdge(TaskEdge* edge) const;\n  std::vector<CompTaskNode*> GetPredCompTaskNodesOnEdge(TaskEdge* edge) const;\n\n  void InferProducedDataRegstTimeShape() override;\n\n private:\n  struct InferBlobDescsMethodGetter final : public CompileModeVisitor<InferBlobDescsMethodGetter> {\n    static ExecNode::InferBlobDescsMethod VisitNaive() { return &ExecNode::InferBlobDescsByInputs; }\n    static ExecNode::InferBlobDescsMethod VisitRankPerProcess() {\n      return &ExecNode::InferBlobDescsByNdSbp;\n    }\n    static ExecNode::InferBlobDescsMethod VisitInValid() { return nullptr; }\n  };\n\n  ParallelContext parallel_ctx_;\n  const OpNode* op_node_;\n  HashSet<std::string> fake_consumed_regst_names_;\n};\n\nclass OpCompTaskNodeCreator {\n public:\n  virtual ~OpCompTaskNodeCreator() = default;\n  virtual CompTaskNode* NewCompTaskNode(const OperatorConf& op_conf) = 0;\n};\n\ntemplate<typename 
CompTaskNodeType>\nclass StaticOpCompTaskNodeCreator : public OpCompTaskNodeCreator {\n public:\n  StaticOpCompTaskNodeCreator() = default;\n  ~StaticOpCompTaskNodeCreator() override = default;\n\n private:\n  CompTaskNode* NewCompTaskNode(const OperatorConf& op_conf) override {\n    return new CompTaskNodeType();\n  }\n};\n\nclass FnOpCompTaskNodeCreator : public OpCompTaskNodeCreator {\n public:\n  using CreateFn = std::function<CompTaskNode*(const OperatorConf& op_conf)>;\n  explicit FnOpCompTaskNodeCreator(CreateFn fn) : fn_(std::move(fn)) {}\n  ~FnOpCompTaskNodeCreator() override = default;\n\n private:\n  CompTaskNode* NewCompTaskNode(const OperatorConf& op_conf) override { return fn_(op_conf); }\n  CreateFn fn_;\n};\n\n#define REGISTER_USER_OP_COMP_TASK_NODE_TYPE(op_type_name, comp_task_node_type) \\\n  REGISTER_CLASS_CREATOR(std::string, op_type_name, OpCompTaskNodeCreator,      \\\n                         ([] { return new StaticOpCompTaskNodeCreator<comp_task_node_type>(); }));\n\n#define REGISTER_USER_OP_COMP_TASK_NODE_TYPE_WITH_FUNC(op_type_name, func) \\\n  REGISTER_CLASS_CREATOR(std::string, op_type_name, OpCompTaskNodeCreator, \\\n                         ([] { return new FnOpCompTaskNodeCreator(func); }));\n\n#define REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(op_type_case, comp_task_node_type) \\\n  REGISTER_CLASS_CREATOR(int32_t, op_type_case, OpCompTaskNodeCreator,            \\\n                         ([] { return new StaticOpCompTaskNodeCreator<comp_task_node_type>(); }));\n\n#define REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE_WITH_FUNC(op_type_case, func) \\\n  REGISTER_CLASS_CREATOR(int32_t, op_type_case, OpCompTaskNodeCreator,       \\\n                         ([] { return new FnOpCompTaskNodeCreator(func); }));\n\nCompTaskNode* NewCompTaskNode4OpNode(const OpNode* op_node);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_COMPUTE_TASK_NODE_H_\n"
  },
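  {
    "path": "oneflow/core/graph/examples/comp_task_node_registry_sketch.cpp",
    "content": "/*\nEditor's note: an illustrative, self-contained sketch added next to this dump; it is not part of\nOneFlow. It mirrors the creator-registry pattern behind REGISTER_USER_OP_COMP_TASK_NODE_TYPE and\nNewCompTaskNode4OpNode in compute_task_node.h: creators are registered under a key, and lookup\nfalls back to a default node type. Registry/REGISTER_NODE and the node classes below are\nhypothetical stand-ins for OneFlow's REGISTER_CLASS_CREATOR machinery.\n*/\n#include <cassert>\n#include <functional>\n#include <memory>\n#include <string>\n#include <unordered_map>\n\nstruct TaskNodeBase {\n  virtual ~TaskNodeBase() = default;\n};\nstruct DefaultNode : TaskNodeBase {};\nstruct SpecialNode : TaskNodeBase {};  // pretend some op type needs its own node type\n\nusing Creator = std::function<TaskNodeBase*()>;\n\n// A function-local static keeps the registry initialized before any registration runs.\nstd::unordered_map<std::string, Creator>& Registry() {\n  static std::unordered_map<std::string, Creator> registry;\n  return registry;\n}\n\n#define REGISTER_NODE(key, type) \\\n  static const bool type##_registered = (Registry()[key] = [] { return new type(); }, true)\n\nREGISTER_NODE(\"special_op\", SpecialNode);\n\n// Mirrors NewCompTaskNode4OpNode: registered op types get their own node type,\n// everything else falls back to the default forward compute node.\nTaskNodeBase* NewNode(const std::string& op_type_name) {\n  auto it = Registry().find(op_type_name);\n  if (it != Registry().end()) { return it->second(); }\n  return new DefaultNode();\n}\n\nint main() {\n  std::unique_ptr<TaskNodeBase> a(NewNode(\"special_op\"));\n  std::unique_ptr<TaskNodeBase> b(NewNode(\"relu\"));\n  assert(dynamic_cast<SpecialNode*>(a.get()) != nullptr);\n  assert(dynamic_cast<DefaultNode*>(b.get()) != nullptr);\n  return 0;\n}\n"
  },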
  {
    "path": "oneflow/core/graph/copy_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/copy_task_node.h\"\n#include \"oneflow/core/graph/task_stream_id.h\"\n#include \"oneflow/core/graph/boxing_task_graph.pb.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n\nnamespace oneflow {\n\nvoid CopyTaskNode::ProduceAllRegstsAndBindEdges() {\n  std::shared_ptr<RegstDesc> out_regst = ProduceRegst(\"copy_out\", false);\n  ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst(\"copy_out\", out_regst); });\n}\n\nvoid CopyTaskNode::ConsumeAllRegsts() { ConsumeRegst(\"copy_in\", SoleInDataEdge()->GetSoleRegst()); }\n\nvoid CopyTaskNode::BuildExecGphAndRegst() {\n  auto out_regst = GetProducedRegst(\"copy_out\");\n  auto in_regst = GetSoleConsumedRegst(\"copy_in\");\n  out_regst->CopyBlobDescFrom(in_regst.get());\n  ExecNode* node = mut_exec_gph().NewNode();\n  auto constructed = CHECK_JUST(ConstructOp(NewCopyOpConf()));\n\n  // prevent filling parallel desc for copy commnet\n  if (constructed->op_conf().has_user_conf()) {\n    std::shared_ptr<Shape> hierarchy = std::make_shared<Shape>(Shape({1}));\n    auto parallel_desc =\n        ParallelDesc::New(constructed->op_conf().device_tag(), {\"0:0-0\"}, hierarchy).GetOrThrow();\n    CHECK_JUST(constructed->FillOpParallelDesc(parallel_desc));\n  }\n\n  node->mut_op() = constructed;\n  node->BindBnWithRegst(node->op()->SoleIbn(), in_regst);\n  node->BindBnWithRegst(node->op()->SoleObn(), out_regst);\n}\n\nvoid CopyTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); }\n\nvoid CopyHdTaskNode::Init(CopyHdType copy_type, const DeviceId& device_id,\n                          const LogicalBlobId& lbi) {\n  copy_type_ = copy_type;\n  set_machine_id(device_id.rank());\n  int64_t thrd_id = -1;\n  if (copy_type == CopyHdType::H2D) {\n    thrd_id = EncodeStreamIdToInt64(GenerateNamedTaskStreamId(device_id, \"H2D\"));\n  } else if (copy_type == CopyHdType::D2H) {\n    thrd_id = EncodeStreamIdToInt64(GenerateNamedTaskStreamId(device_id, \"D2H\"));\n  } else {\n    UNIMPLEMENTED();\n  }\n  set_thrd_id(thrd_id);\n  set_lbi(lbi);\n}\n\nvoid CopyHdTaskNode::InitProducedRegstMemCase(MemoryCase* mem_case) {\n  if (copy_type_ == CopyHdType::H2D) {\n    TaskNode::InitProducedRegstMemCase(mem_case);\n  } else if (copy_type_ == CopyHdType::D2H) {\n    mem_case->set_device_type(DeviceType::kCPU);\n    mem_case->set_device_id(0);\n    mem_case->set_pinned_device_type(device_type());\n    mem_case->set_pinned_device_id(stream_id().device_id().device_index());\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nvoid CopyHdTaskNode::ProduceAllRegstsAndBindEdges() {\n  const bool enable_mem_reuse = ParseBooleanFromEnv(\"ONEFLOW_GRAPH_BOXING_ENABLE_MEM_REUSE\", false)\n                                && (copy_type_ == CopyHdType::H2D);\n  std::shared_ptr<RegstDesc> out_regst = ProduceRegst(\"copy_out\", enable_mem_reuse);\n  ForEachOutDataEdge([&](TaskEdge* edge) { 
edge->AddRegst(\"copy_out\", out_regst); });\n}\n\nOperatorConf CopyHdTaskNode::NewCopyOpConf() {\n  OperatorConf conf;\n  conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(device_type())));\n  auto copy_type_name = \"undefined\";\n  if (copy_type_ == CopyHdType::D2H) {\n    copy_type_name = \"copy_d2h\";\n  } else if (copy_type_ == CopyHdType::H2D) {\n    copy_type_name = \"copy_h2d\";\n  } else {\n    LOG(FATAL) << \"unknow copy type: \" << copy_type_;\n  }\n  conf.set_name(std::string(copy_type_name) + \"_\" + lbi().op_name() + \"-\" + lbi().blob_name() + \"_\"\n                + std::to_string(task_id()));\n  *conf.mutable_user_conf()->mutable_op_type_name() = copy_type_name;\n  auto in_regst = GetSoleConsumedRegst(\"copy_in\");\n  CHECK_EQ(in_regst->NumOfLbi(), 1);\n  in_regst->ForEachLbi([&](const LogicalBlobId& lbi) {\n    (*conf.mutable_user_conf()->mutable_input())[\"in\"].add_s(GenLogicalBlobName(lbi));\n    (*conf.mutable_user_conf()->mutable_output())[\"out\"].add_s(\n        GenLogicalBlobName(conf.name(), GenRepeatedBn(\"out\", 0)));\n  });\n  return conf;\n}\n\nvoid CopyCommNetTaskNode::Init(int64_t machine_id, const LogicalBlobId& lbi) {\n  set_machine_id(machine_id);\n  set_thrd_id(EncodeStreamIdToInt64(\n      GenerateNamedTaskStreamId(machine_id, DeviceType::kCPU, 0, \"COMM_NET\")));\n  set_lbi(lbi);\n}\n\nOperatorConf CopyCommNetTaskNode::NewCopyOpConf() {\n  OperatorConf conf;\n  conf.set_name(\"copy_comm_net_\" + NewUniqueId());\n  conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type())));\n  *(conf.mutable_copy_comm_net_conf()->mutable_lbi()) = lbi();\n  return conf;\n}\n\nMaybe<void> CopyHdTaskNode::InitTransportTaskFromProto(\n    const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) {\n  CHECK_OR_RETURN(transport_task_proto.has_copy_hd_task())\n      << \"not a serialized CopyHdTaskNode. debug string: \" << transport_task_proto.DebugString();\n  copy_type_ = transport_task_proto.copy_hd_task().copy_type();\n  return Maybe<void>::Ok();\n}\n\nvoid CopyHdTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const {\n  ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false);\n  transport_task_proto->mutable_copy_hd_task()->set_copy_type(copy_type_);\n}\n\nMaybe<void> CopyCommNetTaskNode::InitTransportTaskFromProto(\n    const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) {\n  CHECK_OR_RETURN(transport_task_proto.has_copy_comm_net_task())\n      << \"not a serialized CopyCommNetTaskNode. debug string: \"\n      << transport_task_proto.DebugString();\n  return Maybe<void>::Ok();\n}\n\nvoid CopyCommNetTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const {\n  ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false);\n  transport_task_proto->mutable_copy_comm_net_task();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/copy_task_node.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_\n#define ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_\n\n#include \"oneflow/core/graph/transport_task_node.h\"\n#include \"oneflow/core/graph/boxing_task_graph.pb.h\"\n\nnamespace oneflow {\nclass CopyTaskNode : public TransportTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CopyTaskNode);\n  CopyTaskNode() = default;\n  virtual ~CopyTaskNode() = default;\n\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void BuildExecGphAndRegst() override;\n\n protected:\n  virtual OperatorConf NewCopyOpConf() = 0;\n\n private:\n  void InferProducedDataRegstTimeShape() final;\n};\n\nclass CopyHdTaskNode final : public CopyTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CopyHdTaskNode);\n  CopyHdTaskNode() = default;\n  ~CopyHdTaskNode() = default;\n\n  TaskType GetTaskType() const override { return TaskType::kCopyHd; }\n\n  void Init(CopyHdType, const DeviceId& device_id, const LogicalBlobId& lbi);\n\n  void ProduceAllRegstsAndBindEdges() override;\n\n  CopyHdType copy_type() const { return copy_type_; }\n  MemZoneId MemZoneId121() const override {\n    if (copy_type_ == CopyHdType::H2D) {\n      return TaskNode::MemZoneId121();\n    } else if (copy_type_ == CopyHdType::D2H) {\n      return GetNodeCPUMemZoneId(this->machine_id());\n    } else {\n      UNIMPLEMENTED();\n    }\n    return kInvalidMemZoneId;\n  }\n\n  Maybe<void> InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto,\n                                         const TaskGraphRebuildCtx& ctx) override;\n  void ToTransportTaskProto(TransportTaskProto*) const override;\n\n private:\n  void InitProducedRegstMemCase(MemoryCase*) override;\n  OperatorConf NewCopyOpConf() override;\n\n  CopyHdType copy_type_;\n};\n\nclass CopyCommNetTaskNode final : public CopyTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CopyCommNetTaskNode);\n  CopyCommNetTaskNode() = default;\n  ~CopyCommNetTaskNode() = default;\n\n  TaskType GetTaskType() const override { return TaskType::kCopyCommNet; }\n\n  void Init(int64_t machine_id, const LogicalBlobId& lbi);\n\n  Maybe<void> InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto,\n                                         const TaskGraphRebuildCtx& ctx) override;\n  void ToTransportTaskProto(TransportTaskProto*) const override;\n\n private:\n  OperatorConf NewCopyOpConf() override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_\n"
  },
  {
    "path": "oneflow/core/graph/exec_graph.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/exec_graph.h\"\n#include <sstream>\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n\nnamespace oneflow {\n\nvoid ExecNode::BindBnWithRegst(const std::string& bn, std::shared_ptr<RegstDesc> regst) {\n  CHECK(bn_in_op2regst_.emplace(bn, regst).second);\n}\n\nvoid ExecNode::BindBnsWithRegst(const PbRpf<std::string>& (Operator::*bns_getter)() const,\n                                std::shared_ptr<RegstDesc> regst) {\n  for (const std::string& bn : (op_.get()->*bns_getter)()) { BindBnWithRegst(bn, regst); }\n}\n\nvoid ExecNode::AddBnToRegstAndBindIt(const PbRpf<std::string>& (Operator::*bns_getter)() const,\n                                     std::shared_ptr<RegstDesc> regst) {\n  for (const std::string& bn : (op_.get()->*bns_getter)()) { regst->AddLbi(op_->BnInOp2Lbi(bn)); }\n  BindBnsWithRegst(bns_getter, regst);\n}\n\nbool ExecNode::TryBindBnWithOneOfTheRegsts(const std::string& bn,\n                                           const std::list<std::shared_ptr<RegstDesc>>& regsts) {\n  const LogicalBlobId& lbi = op()->BnInOp2Lbi(bn);\n  bool has_binded = false;\n  for (std::shared_ptr<RegstDesc> regst : regsts) {\n    if (regst->GetBlobDesc(lbi) == nullptr) { continue; }\n    BindBnWithRegst(bn, regst);\n    has_binded = true;\n    break;\n  }\n  return has_binded;\n}\n\nvoid ExecNode::BindBnWithOneOfTheRegsts(const std::string& bn,\n                                        const std::list<std::shared_ptr<RegstDesc>>& regsts) {\n  CHECK(TryBindBnWithOneOfTheRegsts(bn, regsts));\n}\n\nvoid ExecNode::UnbindBnWithEmptyRegst() {\n  EraseIf<std::string, std::shared_ptr<RegstDesc>>(\n      &bn_in_op2regst_, [](HashMap<std::string, std::shared_ptr<RegstDesc>>::iterator it) {\n        return it->second->regst_desc_type().has_data_regst_desc() && it->second->NumOfLbi() == 0;\n      });\n}\n\nvoid ExecNode::ToProto(const ParallelContext* parallel_ctx, ExecNodeProto* ret) const {\n  op_->GenKernelConf(GetBlobDesc4BnInOpFunc(), parallel_ctx, ret->mutable_kernel_conf());\n  for (const auto& bn_regst : bn_in_op2regst_) {\n    const std::string& bn_in_op = bn_regst.first;\n    auto regst = bn_regst.second;\n    CHECK(regst);\n    PbMapPair<std::string, int64_t> pair{bn_in_op, regst->regst_desc_id()};\n    CHECK(ret->mutable_bn_in_op2regst_desc_id()->insert(pair).second);\n  }\n}\n\nnamespace {\n\nMaybe<void> CheckPhysicalBlobDesc(const BlobDesc& logical, const NdSbp& nd_sbp,\n                                  const ParallelDesc& parallel_desc,\n                                  const ParallelContext* parallel_ctx, const BlobDesc& physical) {\n  CHECK_EQ_OR_RETURN(physical.shape(), *JUST(GetPhysicalShape(logical.shape(), nd_sbp,\n                                                              parallel_desc, *parallel_ctx)));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckPhysicalBlobDesc(\n    const Operator& op, const 
PbRpf<std::string>& bns,\n    const std::function<Maybe<const BlobDesc>(const std::string&)>& GetLogicalBlobDesc,\n    const NdSbpSignature* nd_sbp_signature, const ParallelContext* parallel_ctx,\n    const std::function<BlobDesc*(const std::string&)>& GetPhysicalBlobDesc) {\n  const std::shared_ptr<const ParallelDesc> op_parallel_desc = JUST(op.GetOpParallelDesc());\n  for (const auto& bn : bns) {\n    const BlobDesc* physical_blob_desc = GetPhysicalBlobDesc(bn);\n    if (physical_blob_desc == nullptr) {\n      // TODO(liujuncheng): remove this hotfix\n      continue;\n    }\n    if (*JUST(op.GetParallelDesc4BnInOp(bn)) == *op_parallel_desc) {\n      JUST_MSG(CheckPhysicalBlobDesc(*JUST(GetLogicalBlobDesc(bn)),\n                                     nd_sbp_signature->bn_in_op2nd_sbp().at(bn), *op_parallel_desc,\n                                     parallel_ctx, *physical_blob_desc),\n               std::stringstream() << \" check physical shape failed, op name \" << op.op_loc());\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n// A helper function to infer blob's physical shape with ND SBP.\nMaybe<void> InferPhysicalBlobDesc(\n    const Operator& op, const PbRpf<std::string>& bns,\n    const std::function<Maybe<const BlobDesc>(const std::string&)>& GetLogicalBlobDesc,\n    const NdSbpSignature* nd_sbp_signature, const ParallelContext* parallel_ctx,\n    const std::function<BlobDesc*(const std::string&)>& GetPhysicalBlobDesc) {\n  const std::shared_ptr<const ParallelDesc> op_parallel_desc = JUST(op.GetOpParallelDesc());\n  for (const auto& bn : bns) {\n    BlobDesc* physical_blob_desc = GetPhysicalBlobDesc(bn);\n    const auto& logical_blob_desc = *JUST(GetLogicalBlobDesc(bn));\n    CHECK_NOTNULL_OR_RETURN(physical_blob_desc)\n        << \"physical_blob_desc should not be nullptr. 
op location: \" << op.op_loc();\n    *physical_blob_desc = logical_blob_desc;\n    const auto& physical_shape = JUST_MSG(\n        GetPhysicalShape(logical_blob_desc.shape(), nd_sbp_signature->bn_in_op2nd_sbp().at(bn),\n                         *op_parallel_desc, *parallel_ctx),\n        std::stringstream() << \" check physical shape failed, op name \" << op.op_loc());\n    physical_blob_desc->set_shape(*physical_shape);\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nvoid ExecNode::InferBlobDescsByInputs(const ParallelContext* parallel_ctx) {\n  auto GetBlobDesc4BnInOp = GetBlobDesc4BnInOpFunc();\n  const OpNode* op_node = Singleton<OpGraph>::Get()->OpNode4OpName(op()->op_name());\n  const NdSbpSignature* nd_sbp_signature = nullptr;\n  if (op_node != nullptr) { nd_sbp_signature = &op_node->nd_sbp_signature(); }\n\n  if (op_node != nullptr && parallel_ctx->parallel_num() > 1 && nd_sbp_signature != nullptr) {\n    CHECK_JUST(CheckPhysicalBlobDesc(\n        *op(), op()->input_bns(),\n        std::bind(&Operator::GetLogicalBlobDesc4Ibn, op().get(), std::placeholders::_1),\n        nd_sbp_signature, parallel_ctx, GetBlobDesc4BnInOp));\n  }\n  CHECK_JUST_MSG(op_->InferBlobDescsIf(GetBlobDesc4BnInOp, parallel_ctx, &GlobalJobDesc()),\n                 std::stringstream() << \" infer blob descs is failed, op name \" << op_->op_loc());\n  if (op_node != nullptr && parallel_ctx->parallel_num() > 1 && nd_sbp_signature != nullptr) {\n    CHECK_JUST(CheckPhysicalBlobDesc(\n        *op(), op()->output_bns(),\n        std::bind(&Operator::GetLogicalBlobDesc4Obn, op().get(), std::placeholders::_1),\n        nd_sbp_signature, parallel_ctx, GetBlobDesc4BnInOp));\n  }\n  CHECK_JUST_MSG(op_->InferInplaceObn2IbnIf(&mut_inplace_obn2ibn_, &con_inplace_obn2ibn_,\n                                            GetBlobDesc4BnInOp, parallel_ctx),\n                 std::stringstream()\n                     << \" infer inplace obn to ibn is failed, op name \" << op_->op_loc());\n}\n\nvoid ExecNode::InferBlobDescsByNdSbp(const ParallelContext* parallel_ctx) {\n  const HashSet<std::string> ibns{op()->input_bns().begin(), op()->input_bns().end()};\n  HashMap<std::string, BlobDesc> ibn2blob_desc{};\n  const auto& GetBlobDesc4BnInOp = [&](const std::string& bn_in_op) -> BlobDesc* {\n    // Generate temp regst to store input blob desc, and will be released after infer output blob\n    // desc.\n    if (ibns.count(bn_in_op) > 0) {\n      auto iter = ibn2blob_desc.find(bn_in_op);\n      if (iter == ibn2blob_desc.end()) {\n        iter = ibn2blob_desc.emplace(bn_in_op, BlobDesc(kInvalidDataType, kContiguous)).first;\n      }\n      return &iter->second;\n    }\n    auto it = bn_in_op2regst_.find(bn_in_op);\n    if (it == bn_in_op2regst_.end()) { return nullptr; }\n    std::shared_ptr<RegstDesc> regst = it->second;\n    CHECK(regst);\n    return regst->MutBlobDesc(op()->BnInOp2Lbi(bn_in_op));\n  };\n  const OpNode* op_node = Singleton<OpGraph>::Get()->OpNode4OpName(op()->op_name());\n  const NdSbpSignature* nd_sbp_signature = &CHECK_NOTNULL(op_node)->nd_sbp_signature();\n\n  // TODO(strint): user op can infer output with SBP, so there is no need to infer the input.\n  // Reference: https://github.com/Oneflow-Inc/oneflow/pull/8971\n  // Infer input blob desc with SBP, the infer results are set into the temp input blob desc.\n  CHECK_JUST(InferPhysicalBlobDesc(\n      *op(), op()->input_bns(),\n      std::bind(&Operator::GetLogicalBlobDesc4Ibn, op().get(), std::placeholders::_1),\n      nd_sbp_signature, parallel_ctx, 
GetBlobDesc4BnInOp));\n\n  // Infer output blob desc with input.\n  CHECK_JUST_MSG(op_->InferBlobDescsIf(GetBlobDesc4BnInOp, parallel_ctx, &GlobalJobDesc()),\n                 std::stringstream() << \" infer blob descs is failed, op name \" << op_->op_loc());\n  CHECK_JUST(CheckPhysicalBlobDesc(\n      *op(), op()->output_bns(),\n      std::bind(&Operator::GetLogicalBlobDesc4Obn, op().get(), std::placeholders::_1),\n      nd_sbp_signature, parallel_ctx, GetBlobDesc4BnInOp));\n  CHECK_JUST_MSG(op_->InferInplaceObn2IbnIf(&mut_inplace_obn2ibn_, &con_inplace_obn2ibn_,\n                                            GetBlobDesc4BnInOp, parallel_ctx),\n                 std::stringstream()\n                     << \" infer inplace obn to ibn is failed, op name \" << op_->op_loc());\n}\n\nstd::function<BlobDesc*(const std::string&)> ExecNode::GetBlobDesc4BnInOpFunc() const {\n  return [this](const std::string& bn_in_op) -> BlobDesc* {\n    auto it = bn_in_op2regst_.find(bn_in_op);\n    if (it == bn_in_op2regst_.end()) { return nullptr; }\n    std::shared_ptr<RegstDesc> regst = it->second;\n    CHECK(regst);\n    return regst->MutBlobDesc(op()->BnInOp2Lbi(bn_in_op));\n  };\n}\n\nvoid ExecGraph::ToExecSequence(const ParallelContext* parallel_ctx, ExecSequence* ret) const {\n  TopoForEachNode([&](ExecNode* node) { node->ToProto(parallel_ctx, ret->add_exec_node()); });\n}\n\n}  // namespace oneflow\n"
  },
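  {
    "path": "oneflow/core/graph/examples/physical_shape_of_split_sketch.cpp",
    "content": "/*\nEditor's note: an illustrative, self-contained sketch added next to this dump; it is not part of\nOneFlow. It shows the shape arithmetic that InferPhysicalBlobDesc in exec_graph.cpp delegates to\nGetPhysicalShape: under a split(axis) SBP, each of the parallel_num ranks holds one slice of the\nlogical shape along that axis. This simplified version assumes a 1-D hierarchy and an evenly\ndivisible split axis; OneFlow's real GetPhysicalShape works on an NdSbp over a parallel hierarchy.\n*/\n#include <cassert>\n#include <cstdint>\n#include <vector>\n\nstd::vector<int64_t> PhysicalShapeOfSplit(const std::vector<int64_t>& logical_shape,\n                                          int64_t split_axis, int64_t parallel_num) {\n  assert(logical_shape.at(split_axis) % parallel_num == 0);  // simplifying assumption\n  std::vector<int64_t> physical_shape = logical_shape;\n  physical_shape[split_axis] /= parallel_num;\n  return physical_shape;\n}\n\nint main() {\n  // A (1024, 768) blob split along axis 0 across 4 ranks: each rank sees (256, 768).\n  // Under broadcast or partial-sum the physical shape would equal the logical shape.\n  std::vector<int64_t> physical =\n      PhysicalShapeOfSplit({1024, 768}, /*split_axis=*/0, /*parallel_num=*/4);\n  assert(physical == (std::vector<int64_t>{256, 768}));\n  return 0;\n}\n"
  },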
  {
    "path": "oneflow/core/graph/exec_graph.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_EXEC_GRAPH_H_\n#define ONEFLOW_CORE_GRAPH_EXEC_GRAPH_H_\n\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/graph/exec_sequence.pb.h\"\n#include \"oneflow/core/graph/graph.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/register/register_desc.h\"\n\nnamespace oneflow {\n\nclass ExecNode;\n\nclass ExecEdge final : public Edge<ExecNode, ExecEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ExecEdge);\n  ExecEdge() = default;\n  ~ExecEdge() = default;\n\n  // Getters\n  const LogicalBlobId& lbi() const { return lbi_; }\n  const std::string& src_bn() const { return src_bn_; }\n  const std::string& dst_bn() const { return dst_bn_; }\n\n  // Setters\n  void set_lbi(const LogicalBlobId& lbi) { lbi_ = lbi; }\n  std::string& mut_src_bn() { return src_bn_; }\n  std::string& mut_dst_bn() { return dst_bn_; }\n\n private:\n  // various names for one blob\n  LogicalBlobId lbi_;\n  std::string src_bn_;\n  std::string dst_bn_;\n};\n\nclass ExecNode final : public Node<ExecNode, ExecEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ExecNode);\n  ExecNode() {}\n  ~ExecNode() = default;\n\n  std::shared_ptr<const Operator> op() const { return op_; }\n  std::shared_ptr<const Operator>& mut_op() { return op_; }\n  RegstDesc* RegstDesc4BnInOp(const std::string& bn) const { return bn_in_op2regst_.at(bn).get(); }\n\n  void BindBnWithRegst(const std::string& bn, std::shared_ptr<RegstDesc>);\n  void BindBnsWithRegst(const PbRpf<std::string>& (Operator::*bns_getter)() const,\n                        std::shared_ptr<RegstDesc>);\n  void AddBnToRegstAndBindIt(const PbRpf<std::string>& (Operator::*bns_getter)() const,\n                             std::shared_ptr<RegstDesc>);\n  bool TryBindBnWithOneOfTheRegsts(const std::string&,\n                                   const std::list<std::shared_ptr<RegstDesc>>&);\n  void BindBnWithOneOfTheRegsts(const std::string&, const std::list<std::shared_ptr<RegstDesc>>&);\n  void UnbindBnWithEmptyRegst();\n\n  std::string VisualStr() const override { return op_->op_name(); }\n  void ToProto(const ParallelContext*, ExecNodeProto*) const;\n\n  typedef void (ExecNode::*InferBlobDescsMethod)(const ParallelContext*);\n  void InferBlobDescsByInputs(const ParallelContext* parallel_ctx);\n  void InferBlobDescsByNdSbp(const ParallelContext* parallel_ctx);\n\n  const HashMap<std::string, std::string>& mut_inplace_obn2ibn() const {\n    return mut_inplace_obn2ibn_;\n  }\n  const HashMap<std::string, std::string>& con_inplace_obn2ibn() const {\n    return con_inplace_obn2ibn_;\n  }\n\n private:\n  std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOpFunc() const;\n\n  std::shared_ptr<const Operator> op_;\n  HashMap<std::string, std::shared_ptr<RegstDesc>> bn_in_op2regst_;\n\n  HashMap<std::string, std::string> mut_inplace_obn2ibn_;\n  HashMap<std::string, std::string> con_inplace_obn2ibn_;\n};\n\nclass 
ExecGraph final : public Graph<ExecNode, ExecEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ExecGraph);\n  ExecGraph() = default;\n  ~ExecGraph() = default;\n\n  void ToExecSequence(const ParallelContext*, ExecSequence*) const;\n  const char* TypeName() const override { return \"ExecGraph\"; }\n\n private:\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_EXEC_GRAPH_H_\n"
  },
  {
    "path": "oneflow/core/graph/exec_sequence.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/kernel/kernel.proto\";\n\nmessage ExecNodeProto {\n  required KernelConf kernel_conf = 1;\n  map<string, int64> bn_in_op2regst_desc_id = 2;\n}\n\nmessage ExecSequence {\n  repeated ExecNodeProto exec_node = 1;\n}\n"
  },
  {
    "path": "oneflow/core/graph/fake_consumed_regst_provider.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_\n#define ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_\n\nnamespace oneflow {\n\n// Provide a compute task node with a fake input regst, and its output regst can be inferred using\n// SBP + Placement. The fake compute task node can help the task graph of one rank to infer blob\n// desc, mainly to ensure that the transport task node has the correct input blob desc.\nclass FakeConsumedRegstProvider {\n public:\n  FakeConsumedRegstProvider() = default;\n  virtual ~FakeConsumedRegstProvider() = default;\n\n  virtual void ConsumeFakeRegstsIf() = 0;\n  virtual void EraseFakeRegstsIf() = 0;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_\n"
  },
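  {
    "path": "oneflow/core/graph/examples/fake_consumed_regst_sketch.cpp",
    "content": "/*\nEditor's note: an illustrative, self-contained sketch added next to this dump; it is not part of\nOneFlow. It mirrors the idea documented in fake_consumed_regst_provider.h: a rank that does not\nown a node's real producer installs a placeholder (fake) input regst, infers the output from\nSBP-like metadata alone, and erases the placeholder afterwards. FakeRegst/PlaceholderNode are\nhypothetical stand-ins for RegstDesc/CompTaskNode.\n*/\n#include <cassert>\n#include <cstdint>\n#include <string>\n#include <unordered_map>\n#include <unordered_set>\n\nstruct FakeRegst {\n  bool is_fake = false;\n  int64_t num_elems = 0;  // a real regst would carry a full blob desc\n};\n\nstruct PlaceholderNode {\n  std::unordered_map<std::string, FakeRegst> consumed;\n  std::unordered_set<std::string> fake_names;\n\n  // Mirrors CompTaskNode::ConsumeFakeRegst: install a placeholder and remember its name.\n  void ConsumeFakeRegst(const std::string& name) {\n    consumed[name] = FakeRegst{/*is_fake=*/true, /*num_elems=*/0};\n    fake_names.insert(name);\n  }\n  // The output is inferred from metadata (here a logical size plus parallel_num), so the\n  // placeholder input never needs a real blob desc behind it.\n  int64_t InferOutputNumElems(int64_t logical_num_elems, int64_t parallel_num) const {\n    return logical_num_elems / parallel_num;\n  }\n  // Mirrors CompTaskNode::EraseFakeRegstsIf: drop every placeholder after inference.\n  void EraseFakeRegstsIf() {\n    for (const std::string& name : fake_names) { consumed.erase(name); }\n    fake_names.clear();\n  }\n};\n\nint main() {\n  PlaceholderNode node;\n  node.ConsumeFakeRegst(\"in\");\n  assert(node.InferOutputNumElems(/*logical_num_elems=*/1024, /*parallel_num=*/4) == 256);\n  node.EraseFakeRegstsIf();\n  assert(node.consumed.empty());\n  return 0;\n}\n"
  },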
  {
    "path": "oneflow/core/graph/graph.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_GRAPH_H_\n#define ONEFLOW_CORE_GRAPH_GRAPH_H_\n\n#include <stack>\n#include <bitset>\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/graph/node.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n\nnamespace oneflow {\n\ntemplate<typename NodeType, typename EdgeType>\nclass Graph {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Graph);\n  Graph() = default;\n  virtual ~Graph() = default;\n\n  // For Each\n  void ForEachNode(std::function<void(NodeType*)> NodeHandler) const;\n  Maybe<void> MaybeForEachNode(std::function<Maybe<void>(NodeType*)> NodeHandler) const;\n  // In case you want to change the topological structure during the node handler.\n  // For example, adding/deleting a node or an edge.\n  // Still, it might have bugs even if you use TopoForEachNodeDynamic.\n  void TopoForEachNodeDynamic(std::function<void(NodeType*)> NodeHandler) const;\n  void TopoForEachNode(std::function<void(NodeType*)> NodeHandler) const;\n  Maybe<void> TopoForEachNodeDynamicWithErrorCaptured(\n      std::function<Maybe<void>(NodeType*)> NodeHandler) const;\n  Maybe<void> TopoForEachNodeWithErrorCaptured(\n      std::function<Maybe<void>(NodeType*)> NodeHandler) const;\n  void ReverseTopoForEachNode(std::function<void(NodeType*)> NodeHandler) const;\n  void ForEachEdge(std::function<void(EdgeType*)> EdgeHandler) const;\n  Maybe<void> MaybeForEachEdge(std::function<Maybe<void>(EdgeType*)> EdgeHandler) const;\n\n  void SortedTopoForEachNode(std::function<bool(const EdgeType* lhs, const EdgeType* rhs)> LessThan,\n                             std::function<void(NodeType*)> NodeHandler) const;\n\n  void BfsForEachNode(\n      const std::list<NodeType*>& starts,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachNext,\n      const std::function<void(NodeType*)>& Handler) const;\n\n  void DfsForEachNode(\n      const std::list<NodeType*>& starts,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachNext,\n      const std::function<void(NodeType*)>& Handler) const;\n\n  void TopoForEachNodeDynamic(\n      const std::list<NodeType*>& starts,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,\n      const std::function<void(NodeType*)>& Handler) const;\n\n  void TopoForEachNode(\n      const std::list<NodeType*>& starts,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,\n      const std::function<void(NodeType*)>& Handler) const;\n\n  void TopoForEachNode(\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n      const 
std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,\n      const std::function<void(NodeType*)>& Handler) const;\n\n  Maybe<void> TopoForEachNodeDynamicWithErrorCaptured(\n      const std::list<NodeType*>& starts,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,\n      const std::function<Maybe<void>(NodeType*)>& Handler) const;\n\n  Maybe<void> TopoForEachNodeWithErrorCaptured(\n      const std::list<NodeType*>& starts,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,\n      const std::function<Maybe<void>(NodeType*)>& Handler) const;\n\n  Maybe<void> TopoForEachNodeWithErrorCaptured(\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,\n      const std::function<Maybe<void>(NodeType*)>& Handler) const;\n\n  void DfsTopoForEachNode(\n      const std::list<NodeType*>& starts,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,\n      const std::function<void(NodeType*)>& Handler) const;\n\n  void DfsTopoForEachNodeSortByDistanceToSink(\n      const std::list<NodeType*>& starts,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,\n      const std::function<void(NodeType*)>& Handler) const;\n\n  std::function<bool(const NodeType* src, const NodeType* dst)> MakePredicatorIsReachable() const;\n\n  std::function<bool(const NodeType* src, const NodeType* dst)> MakePredicatorIsReachable(\n      const std::list<NodeType*>& starts,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode)\n      const;\n\n  void ForEachConnectedComponent(\n      const std::function<void(const HashSet<NodeType*>&)>& Handler) const;\n\n  void ForEachConnectedComponent(\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachConnected,\n      const std::function<void(const HashSet<NodeType*>&)>& Handler) const;\n\n  void ForEachConnectedComponent(\n      const std::function<void(const std::function<void(NodeType*)>&)>& ForEachNodeAsStart,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachConnected,\n      const std::function<void(const HashSet<NodeType*>&)>& Handler) const;\n\n  // find first nontrivial strongly connected component\n  std::unique_ptr<HashSet<NodeType*>> FindFirstNontrivialSCC(\n      const std::list<NodeType*>& starts,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode)\n      const;\n\n  std::unique_ptr<HashSet<NodeType*>> FindFirstNontrivialSCC(\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n      const 
std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode)\n      const;\n\n  std::unique_ptr<HashSet<NodeType*>> FindFirstNontrivialSCC() const;\n\n  // Getters\n  std::list<NodeType*> source_nodes() const;\n  std::list<NodeType*> sink_nodes() const;\n  NodeType* SoleSourceNode() const;\n  NodeType* SoleSinkNode() const;\n  NodeType* SoleNode() const;\n  size_t node_num() const { return nodes_.size(); }\n  size_t edge_num() const { return edges_.size(); }\n  virtual const char* TypeName() const { return \"\"; }\n\n  // Setters\n  template<typename DerivedNodeType = NodeType>\n  DerivedNodeType* NewNode();\n  template<class... Args>\n  EdgeType* NewEdge(Args&&... args);\n  void AddAllocatedNode(NodeType*);\n  void AddAllocatedEdge(EdgeType*);\n  void DeleteNode(NodeType*);\n\n  // ToDot\n  template<typename StreamT>\n  void ToDotWithStream(StreamT& out_stream) const;\n  template<typename StreamT>\n  void ToDotWithStream(const std::function<bool(NodeType*)>& IsNodeAllowed,\n                       const std::function<bool(EdgeType*)>& IsEdgeAllowed,\n                       const std::function<std::string(NodeType*)>& AddNodeAttribute,\n                       const std::function<std::string(EdgeType*)>& AddEdgeAttribute,\n                       StreamT& out_stream) const;\n  void ToDotWithFilePath(const std::string& file_path) const;\n  void ToDotWithFilePath(const std::function<std::string(NodeType*)>& AddNodeAttribute,\n                         const std::function<std::string(EdgeType*)>& AddEdgeAttribute,\n                         const std::string& file_path) const;\n  void ToDotWithFilePath(const std::function<bool(NodeType*)>& IsNodeAllowed,\n                         const std::function<bool(EdgeType*)>& IsEdgeAllowed,\n                         const std::string& file_path) const;\n  void ToDotWithAutoFilePath() const;\n\n private:\n  std::unique_ptr<HashSet<NodeType*>> FindFirstNontrivialSCC(\n      const std::function<void(const std::function<void(NodeType*)>&)>& ForEachStart,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode)\n      const;\n\n  // finish time first search\n  void FfsForEachNode(\n      const std::function<void(const std::function<void(NodeType*)>&)>& ForEachStart,\n      const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachNext,\n      const std::function<void(NodeType*)>& Handler) const;\n\n  void FfsForEachNode(const std::function<void(NodeType*)>& Handler) const;\n\n  std::vector<std::unique_ptr<NodeType>> nodes_;\n  std::vector<std::unique_ptr<EdgeType>> edges_;\n};\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::ForEachNode(std::function<void(NodeType*)> NodeHandler) const {\n  for (auto& x : nodes_) { NodeHandler(x.get()); }\n}\n\ntemplate<typename NodeType, typename EdgeType>\nMaybe<void> Graph<NodeType, EdgeType>::MaybeForEachNode(\n    std::function<Maybe<void>(NodeType*)> NodeHandler) const {\n  for (auto& x : nodes_) { JUST(NodeHandler(x.get())); }\n  return Maybe<void>::Ok();\n}\n\ntemplate<typename NodeType, typename EdgeType>\nstd::list<NodeType*> Graph<NodeType, EdgeType>::source_nodes() const {\n  std::list<NodeType*> ret;\n  ForEachNode([&](NodeType* node) {\n    if (node->in_edges().empty()) { ret.emplace_back(node); }\n  });\n  return ret;\n}\n\ntemplate<typename NodeType, typename 
EdgeType>\nstd::list<NodeType*> Graph<NodeType, EdgeType>::sink_nodes() const {\n  std::list<NodeType*> ret;\n  ForEachNode([&](NodeType* node) {\n    if (node->out_edges().empty()) { ret.emplace_back(node); }\n  });\n  return ret;\n}\n\ntemplate<typename NodeType, typename EdgeType>\nNodeType* Graph<NodeType, EdgeType>::SoleSourceNode() const {\n  std::list<NodeType*> source_nodes_list = source_nodes();\n  CHECK_EQ(source_nodes_list.size(), 1);\n  return source_nodes_list.front();\n}\n\ntemplate<typename NodeType, typename EdgeType>\nNodeType* Graph<NodeType, EdgeType>::SoleSinkNode() const {\n  std::list<NodeType*> sink_nodes_list = sink_nodes();\n  CHECK_EQ(sink_nodes_list.size(), 1);\n  return sink_nodes_list.front();\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::TopoForEachNodeDynamic(\n    std::function<void(NodeType*)> NodeHandler) const {\n  TopoForEachNodeDynamic(source_nodes(), &NodeType::ForEachNodeOnInEdge,\n                         &NodeType::ForEachNodeOnOutEdge, NodeHandler);\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::TopoForEachNode(std::function<void(NodeType*)> NodeHandler) const {\n  CHECK_JUST(TopoForEachNodeWithErrorCaptured(&NodeType::ForEachNodeOnInEdge,\n                                              &NodeType::ForEachNodeOnOutEdge, [&](NodeType* node) {\n                                                NodeHandler(node);\n                                                return Maybe<void>::Ok();\n                                              }));\n}\n\ntemplate<typename NodeType, typename EdgeType>\nMaybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeDynamicWithErrorCaptured(\n    std::function<Maybe<void>(NodeType*)> NodeHandler) const {\n  return TopoForEachNodeDynamicWithErrorCaptured(source_nodes(), &NodeType::ForEachNodeOnInEdge,\n                                                 &NodeType::ForEachNodeOnOutEdge, NodeHandler);\n}\n\ntemplate<typename NodeType, typename EdgeType>\nMaybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeWithErrorCaptured(\n    std::function<Maybe<void>(NodeType*)> NodeHandler) const {\n  return TopoForEachNodeWithErrorCaptured(&NodeType::ForEachNodeOnInEdge,\n                                          &NodeType::ForEachNodeOnOutEdge, NodeHandler);\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::SortedTopoForEachNode(\n    std::function<bool(const EdgeType* lhs, const EdgeType* rhs)> LessThan,\n    std::function<void(NodeType*)> NodeHandler) const {\n  ForEachNode([&](NodeType* node) { node->SortInOutEdges(LessThan); });\n  TopoForEachNode(&NodeType::ForEachNodeOnSortedInEdge, &NodeType::ForEachNodeOnSortedOutEdge,\n                  NodeHandler);\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::ReverseTopoForEachNode(\n    std::function<void(NodeType*)> NodeHandler) const {\n  TopoForEachNode(&NodeType::ForEachNodeOnOutEdge, &NodeType::ForEachNodeOnInEdge, NodeHandler);\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::ForEachEdge(std::function<void(EdgeType*)> EdgeHandler) const {\n  for (auto& x : edges_) {\n    if (x->src_node() == nullptr && x->dst_node() == nullptr) { continue; }\n    EdgeHandler(x.get());\n  }\n}\n\ntemplate<typename NodeType, typename EdgeType>\nMaybe<void> Graph<NodeType, EdgeType>::MaybeForEachEdge(\n    std::function<Maybe<void>(EdgeType*)> EdgeHandler) const {\n  for (auto& x : edges_) {\n    if (x->src_node() == 
nullptr && x->dst_node() == nullptr) { continue; }\n    JUST(EdgeHandler(x.get()));\n  }\n  return Maybe<void>::Ok();\n}\n\ntemplate<typename NodeType, typename EdgeType>\nNodeType* Graph<NodeType, EdgeType>::SoleNode() const {\n  CHECK_EQ(nodes_.size(), 1);\n  return nodes_.front().get();\n}\n\ntemplate<typename NodeType, typename EdgeType>\ntemplate<typename DerivedNodeType>\nDerivedNodeType* Graph<NodeType, EdgeType>::NewNode() {\n  DerivedNodeType* ret = new DerivedNodeType;\n  AddAllocatedNode(ret);\n  return ret;\n}\n\ntemplate<typename NodeType, typename EdgeType>\ntemplate<class... Args>\nEdgeType* Graph<NodeType, EdgeType>::NewEdge(Args&&... args) {\n  EdgeType* ret = new EdgeType(std::forward<Args>(args)...);\n  AddAllocatedEdge(ret);\n  return ret;\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::AddAllocatedNode(NodeType* node) {\n  nodes_.emplace_back(node);\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::AddAllocatedEdge(EdgeType* edge) {\n  edges_.emplace_back(edge);\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::DeleteNode(NodeType* node) {\n  Erase<std::vector<std::unique_ptr<NodeType>>>(\n      nodes_, [node](const std::unique_ptr<NodeType>& node_ptr) { return node_ptr.get() == node; });\n}\n\ntemplate<typename NodeType, typename EdgeType>\ntemplate<typename StreamT>\nvoid Graph<NodeType, EdgeType>::ToDotWithStream(StreamT& out_stream) const {\n  ToDotWithStream([](NodeType*) { return true; }, [](EdgeType*) { return true; },\n                  [](NodeType*) { return \"\"; }, [](EdgeType*) { return \"\"; }, out_stream);\n}\n\ntemplate<typename NodeType, typename EdgeType>\ntemplate<typename StreamT>\nvoid Graph<NodeType, EdgeType>::ToDotWithStream(\n    const std::function<bool(NodeType*)>& IsNodeAllowed,\n    const std::function<bool(EdgeType*)>& IsEdgeAllowed,\n    const std::function<std::string(NodeType*)>& AddNodeAttribute,\n    const std::function<std::string(EdgeType*)>& AddEdgeAttribute, StreamT& out_stream) const {\n  out_stream << \"digraph {\\n\";\n  this->ForEachNode([&](NodeType* node) {\n    if (IsNodeAllowed(node) == false) { return; }\n    out_stream << \"\\\"\" << node->node_id_str() << \"\\\" [label=\\\"\" << node->VisualStr() << \"\\\"\"\n               << AddNodeAttribute(node) << \"]\\n\";\n  });\n  this->ForEachEdge([&](EdgeType* edge) {\n    if (IsEdgeAllowed(edge) == false) { return; }\n    if (IsNodeAllowed(edge->src_node()) == false) { return; }\n    if (IsNodeAllowed(edge->dst_node()) == false) { return; }\n    out_stream << \"\\\"\" << edge->src_node()->node_id_str() << \"\\\" -> \"\n               << \"\\\"\" << edge->dst_node()->node_id_str() << \"\\\"\"\n               << \"[label=\\\"\" << edge->VisualStr() << \"\\\"\" << AddEdgeAttribute(edge) << \"];\\n\";\n  });\n  out_stream << \"}\\n\";\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::ToDotWithFilePath(const std::string& file_path) const {\n  auto log_stream = TeePersistentLogStream::Create(file_path);\n  ToDotWithStream(log_stream);\n  log_stream->Flush();\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::ToDotWithFilePath(\n    const std::function<std::string(NodeType*)>& AddNodeAttribute,\n    const std::function<std::string(EdgeType*)>& AddEdgeAttribute,\n    const std::string& file_path) const {\n  auto log_stream = TeePersistentLogStream::Create(file_path);\n  ToDotWithStream([](NodeType*) { 
return true; }, [](EdgeType*) { return true; }, AddNodeAttribute,\n                  AddEdgeAttribute, log_stream);\n  log_stream->Flush();\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::ToDotWithFilePath(\n    const std::function<bool(NodeType*)>& IsNodeAllowed,\n    const std::function<bool(EdgeType*)>& IsEdgeAllowed, const std::string& file_path) const {\n  auto log_stream = TeePersistentLogStream::Create(file_path);\n  ToDotWithStream(\n      IsNodeAllowed, IsEdgeAllowed, [](NodeType*) { return \"\"; }, [](EdgeType*) { return \"\"; },\n      log_stream);\n  log_stream->Flush();\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::ToDotWithAutoFilePath() const {\n  std::string file_path = JoinPath(\"dot\", TypeName(), NewUniqueId() + \".dot\");\n  ToDotWithFilePath(file_path);\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::BfsForEachNode(\n    const std::list<NodeType*>& starts,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachNext,\n    const std::function<void(NodeType*)>& Handler) const {\n  HashSet<NodeType*> queued_nodes;\n  std::queue<NodeType*> queue;\n  for (NodeType* start : starts) {\n    if (queued_nodes.find(start) == queued_nodes.end()) {\n      queue.push(start);\n      queued_nodes.insert(start);\n    }\n  }\n  while (!queue.empty()) {\n    NodeType* cur_node = queue.front();\n    queue.pop();\n    Handler(cur_node);\n    ForEachNext(cur_node, [&](NodeType* next) {\n      if (queued_nodes.find(next) == queued_nodes.end()) {\n        queue.push(next);\n        queued_nodes.insert(next);\n      }\n    });\n  }\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::DfsForEachNode(\n    const std::list<NodeType*>& starts,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachNext,\n    const std::function<void(NodeType*)>& Handler) const {\n  HashSet<NodeType*> visited_nodes;\n  std::stack<NodeType*> stack;\n  for (NodeType* start : starts) { stack.push(start); }\n  while (!stack.empty()) {\n    NodeType* cur_node = stack.top();\n    stack.pop();\n    if (visited_nodes.find(cur_node) == visited_nodes.end()) {\n      Handler(cur_node);\n      visited_nodes.insert(cur_node);\n      ForEachNext(cur_node, [&](NodeType* next) {\n        if (visited_nodes.find(next) == visited_nodes.end()) { stack.push(next); }\n      });\n    }\n  }\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::FfsForEachNode(\n    const std::function<void(const std::function<void(NodeType*)>&)>& ForEachStart,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachNext,\n    const std::function<void(NodeType*)>& Handler) const {\n  HashSet<NodeType*> visited_nodes;\n  HashSet<NodeType*> handled_nodes;\n  ForEachStart([&](NodeType* start) {\n    if (visited_nodes.find(start) != visited_nodes.end()) { return; }\n    std::stack<std::queue<NodeType*>> stack;\n    stack.emplace(std::queue<NodeType*>{});\n    stack.top().push(start);\n    while (!stack.empty()) {\n      if (stack.top().empty()) {\n        stack.pop();\n        continue;\n      }\n      if (handled_nodes.find(stack.top().front()) != handled_nodes.end()) {\n        stack.top().pop();\n        continue;\n      }\n      NodeType* cur_node = stack.top().front();\n      if (visited_nodes.find(cur_node) == visited_nodes.end()) { visited_nodes.insert(cur_node); }\n      
int64_t next_unvisited_cnt = 0;\n      ForEachNext(cur_node, [&](NodeType* next) {\n        if (visited_nodes.find(next) == visited_nodes.end()) {\n          if (next_unvisited_cnt == 0) { stack.emplace(std::queue<NodeType*>()); }\n          stack.top().push(next);\n          ++next_unvisited_cnt;\n        }\n      });\n      if (next_unvisited_cnt == 0) {\n        Handler(cur_node);\n        handled_nodes.insert(cur_node);\n      }\n    }\n  });\n}\n\ntemplate<typename NodeType, typename EdgeType>\nstd::unique_ptr<HashSet<NodeType*>> Graph<NodeType, EdgeType>::FindFirstNontrivialSCC(\n    const std::list<NodeType*>& starts,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode)\n    const {\n  auto ForEachStart = [&](const std::function<void(NodeType*)>& Handler) {\n    for (NodeType* start : starts) { Handler(start); }\n  };\n  return FindFirstNontrivialSCC(ForEachStart, ForEachInNode, ForEachOutNode);\n}\n\ntemplate<typename NodeType, typename EdgeType>\nstd::unique_ptr<HashSet<NodeType*>> Graph<NodeType, EdgeType>::FindFirstNontrivialSCC(\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode)\n    const {\n  return FindFirstNontrivialSCC(\n      [&](const std::function<void(NodeType*)>& Handler) { ForEachNode(Handler); }, ForEachInNode,\n      ForEachOutNode);\n}\n\ntemplate<typename NodeType, typename EdgeType>\nstd::unique_ptr<HashSet<NodeType*>> Graph<NodeType, EdgeType>::FindFirstNontrivialSCC() const {\n  return FindFirstNontrivialSCC(\n      [&](const std::function<void(NodeType*)>& Handler) { ForEachNode(Handler); },\n      &NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge);\n}\n\ntemplate<typename NodeType, typename EdgeType>\nstd::unique_ptr<HashSet<NodeType*>> Graph<NodeType, EdgeType>::FindFirstNontrivialSCC(\n    const std::function<void(const std::function<void(NodeType*)>&)>& ForEachStart,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode)\n    const {\n  std::stack<NodeType*> stack;\n  FfsForEachNode(ForEachStart, ForEachOutNode, [&](NodeType* node) { stack.push(node); });\n  HashSet<NodeType*> visited;\n  auto ForEachUnvisitedInNode = [&](NodeType* node, const std::function<void(NodeType*)>& Handler) {\n    ForEachInNode(node, [&](NodeType* in_node) {\n      if (visited.find(in_node) == visited.end()) { Handler(in_node); }\n    });\n  };\n  while (stack.empty() == false) {\n    NodeType* cur_node = stack.top();\n    stack.pop();\n    auto ret = std::make_unique<HashSet<NodeType*>>();\n    DfsForEachNode({cur_node}, ForEachUnvisitedInNode,\n                   [&](NodeType* node) { CHECK(ret->insert(node).second); });\n    for (const auto& node : *ret) { visited.insert(node); }\n    if (ret->size() > 1) { return ret; }\n  }\n  return std::unique_ptr<HashSet<NodeType*>>();\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::TopoForEachNodeDynamic(\n    const std::list<NodeType*>& starts,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,\n    const 
std::function<void(NodeType*)>& Handler) const {\n  CHECK_JUST(TopoForEachNodeDynamicWithErrorCaptured(starts, ForEachInNode, ForEachOutNode,\n                                                     [&](NodeType* node) {\n                                                       Handler(node);\n                                                       return Maybe<void>::Ok();\n                                                     }));\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::TopoForEachNode(\n    const std::list<NodeType*>& starts,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,\n    const std::function<void(NodeType*)>& Handler) const {\n  CHECK_JUST(\n      TopoForEachNodeWithErrorCaptured(starts, ForEachInNode, ForEachOutNode, [&](NodeType* node) {\n        Handler(node);\n        return Maybe<void>::Ok();\n      }));\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::TopoForEachNode(\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,\n    const std::function<void(NodeType*)>& Handler) const {\n  CHECK_JUST(TopoForEachNodeWithErrorCaptured(ForEachInNode, ForEachOutNode, [&](NodeType* node) {\n    Handler(node);\n    return Maybe<void>::Ok();\n  }));\n}\n\ntemplate<typename NodeType, typename EdgeType>\nMaybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeDynamicWithErrorCaptured(\n    const std::list<NodeType*>& starts,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,\n    const std::function<Maybe<void>(NodeType*)>& Handler) const {\n  HashMap<NodeType*, bool> has_queued;\n  std::queue<NodeType*> queue;\n  for (NodeType* start : starts) {\n    queue.push(start);\n    has_queued[start] = true;\n    ForEachInNode(start, [&](NodeType*) { LOG(FATAL) << \"not a source\"; });\n  }\n  while (!queue.empty()) {\n    NodeType* cur_node = queue.front();\n    queue.pop();\n    JUST(Handler(cur_node));\n    ForEachOutNode(cur_node, [&](NodeType* out) {\n      bool is_ready = true;\n      ForEachInNode(out, [&](NodeType* in) {\n        if (is_ready && !has_queued[in]) { is_ready = false; }\n      });\n      if (is_ready && !has_queued[out]) {\n        queue.push(out);\n        has_queued[out] = true;\n      }\n    });\n  }\n  return Maybe<void>::Ok();\n}\n\ntemplate<typename NodeType, typename EdgeType>\nMaybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeWithErrorCaptured(\n    const std::list<NodeType*>& starts,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,\n    const std::function<Maybe<void>(NodeType*)>& Handler) const {\n  HashMap<NodeType*, int32_t> counter_in;\n  std::queue<NodeType*> queue;\n  for (NodeType* start : starts) {\n    queue.push(start);\n    counter_in[start] = 0;\n    ForEachInNode(start, [&](NodeType*) { LOG(FATAL) << \"not a source\"; });\n  }\n  while (!queue.empty()) {\n    NodeType* cur_node = queue.front();\n    queue.pop();\n    JUST(Handler(cur_node));\n    ForEachOutNode(cur_node, [&](NodeType* out) {\n      auto it = 
counter_in.find(out);\n      // Move the initialization here\n      if (it == counter_in.end()) {\n        int32_t count = 0;\n        ForEachInNode(out, [&](NodeType* out_in) { count++; });\n        counter_in[out] = count;\n        it = counter_in.find(out);\n      }\n      it->second--;\n      if (it->second == 0) { queue.push(out); }\n    });\n  }\n  return Maybe<void>::Ok();\n}\n\ntemplate<typename NodeType, typename EdgeType>\nMaybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeWithErrorCaptured(\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,\n    const std::function<Maybe<void>(NodeType*)>& Handler) const {\n  HashMap<NodeType*, int32_t> counter_in;\n  std::queue<NodeType*> queue;\n  ForEachNode([&](NodeType* node) {\n    int32_t count = 0;\n    ForEachInNode(node, [&](NodeType*) { count++; });\n    counter_in[node] = count;\n    if (count == 0) { queue.push(node); }\n  });\n  while (!queue.empty()) {\n    NodeType* cur_node = queue.front();\n    queue.pop();\n    JUST(Handler(cur_node));\n    ForEachOutNode(cur_node, [&](NodeType* out) {\n      --counter_in[out];\n      if (counter_in[out] == 0) { queue.push(out); }\n    });\n  }\n  return Maybe<void>::Ok();\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::DfsTopoForEachNodeSortByDistanceToSink(\n    const std::list<NodeType*>& starts,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,\n    const std::function<void(NodeType*)>& Handler) const {\n  HashMap<NodeType*, int64_t> node2distance_to_sink;\n  {\n    std::list<NodeType*> nodes;\n    TopoForEachNode(ForEachInNode, ForEachOutNode,\n                    [&](NodeType* node) { nodes.emplace_back(node); });\n    std::list<NodeType*> sinks;\n    for (NodeType* node : nodes) {\n      bool is_sink = true;\n      ForEachOutNode(node, [&](NodeType* out_node) { is_sink = false; });\n      if (is_sink) { sinks.emplace_back(node); }\n    }\n    TopoForEachNode(ForEachOutNode, ForEachInNode, [&](NodeType* node) {\n      int64_t distance_to_sink = -1;\n      ForEachOutNode(node, [&](NodeType* out_node) {\n        distance_to_sink = std::max(distance_to_sink, node2distance_to_sink[out_node]);\n      });\n      node2distance_to_sink[node] = distance_to_sink + 1;\n    });\n  }\n  auto ForEachOutNodeSortedByDistanceToSink = [&](NodeType* node,\n                                                  const std::function<void(NodeType*)>& Handler) {\n    std::vector<NodeType*> out_nodes;\n    ForEachOutNode(node, [&](NodeType* out_node) { out_nodes.emplace_back(out_node); });\n    std::sort(out_nodes.begin(), out_nodes.end(), [&](NodeType* lhs, NodeType* rhs) {\n      // DfsTopoForEachNode use stack, so sort desc\n      return node2distance_to_sink.at(lhs) > node2distance_to_sink.at(rhs);\n    });\n    for (NodeType* out_node : out_nodes) { Handler(out_node); }\n  };\n  DfsTopoForEachNode(starts, ForEachInNode, ForEachOutNodeSortedByDistanceToSink, Handler);\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::DfsTopoForEachNode(\n    const std::list<NodeType*>& starts,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& 
ForEachOutNode,\n    const std::function<void(NodeType*)>& Handler) const {\n  HashMap<NodeType*, bool> be_visited;\n  std::stack<NodeType*> stack;\n  for (NodeType* start : starts) {\n    stack.push(start);\n    ForEachInNode(start, [&](NodeType*) { LOG(FATAL) << \"not a source\"; });\n  }\n  while (!stack.empty()) {\n    NodeType* cur_node = stack.top();\n    stack.pop();\n    Handler(cur_node);\n    be_visited[cur_node] = true;\n    ForEachOutNode(cur_node, [&](NodeType* out) {\n      bool is_ready = true;\n      ForEachInNode(out, [&](NodeType* in) {\n        if (is_ready && !be_visited[in]) { is_ready = false; }\n      });\n      if (is_ready && !be_visited[out]) { stack.push(out); }\n    });\n  }\n}\n\ntemplate<typename NodeType, typename EdgeType>\nstd::function<bool(const NodeType* src, const NodeType* dst)>\nGraph<NodeType, EdgeType>::MakePredicatorIsReachable() const {\n  return MakePredicatorIsReachable(source_nodes(), &NodeType::ForEachNodeOnInEdge,\n                                   &NodeType::ForEachNodeOnOutEdge);\n}\n\ntemplate<typename NodeType, typename EdgeType>\nstd::function<bool(const NodeType* src, const NodeType* dst)>\nGraph<NodeType, EdgeType>::MakePredicatorIsReachable(\n    const std::list<NodeType*>& starts,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode)\n    const {\n  static constexpr int64_t BITSET_SIZE = 512;  // size of cache line\n  class BitSet {\n   public:\n    BitSet() = default;\n    ~BitSet() = default;\n\n    void Insert(int64_t k) { bitset_vec_.at(k / BITSET_SIZE).set(k % BITSET_SIZE, true); }\n\n    bool Contains(int64_t k) { return bitset_vec_.at(k / BITSET_SIZE).test(k % BITSET_SIZE); }\n\n    void Merge(const BitSet& other) {\n      CHECK_EQ(bitset_vec_.size(), other.bitset_vec_.size());\n      for (int64_t i = 0; i < bitset_vec_.size(); ++i) {\n        bitset_vec_.at(i) |= other.bitset_vec_.at(i);\n      }\n    }\n\n    void Resize(size_t size) {\n      const int64_t bitset_vec_size = RoundUp(size, BITSET_SIZE) / BITSET_SIZE;\n      bitset_vec_.resize(bitset_vec_size);\n    }\n\n   private:\n    using bitset_vec = std::vector<std::bitset<BITSET_SIZE>>;\n    bitset_vec bitset_vec_;\n  };\n\n  using NodePtr2Id = HashMap<const NodeType*, int64_t>;\n  using Id2Ancestor = std::vector<BitSet>;\n  std::shared_ptr<NodePtr2Id> node2id(new NodePtr2Id);\n  std::shared_ptr<Id2Ancestor> id2ancestor(new Id2Ancestor(node_num()));\n  int64_t id = 0;\n  node2id->reserve(node_num());\n  TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) {\n    node2id->emplace(node, id);\n    id2ancestor->at(id).Resize(node_num());\n    id += 1;\n  });\n  TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) {\n    const int64_t node_id = node2id->at(node);\n    auto& ancestor_bitset_vec = id2ancestor->at(node_id);\n    ForEachInNode(node, [&](NodeType* in_node) {\n      const int64_t in_node_id = node2id->at(in_node);\n      ancestor_bitset_vec.Insert(in_node_id);\n      ancestor_bitset_vec.Merge(id2ancestor->at(in_node_id));\n    });\n  });\n  return [id2ancestor, node2id](const NodeType* src, const NodeType* dst) -> bool {\n    const int64_t dst_id = node2id->at(dst);\n    return id2ancestor->at(dst_id).Contains(node2id->at(src));\n  };\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::ForEachConnectedComponent(\n    const std::function<void(const 
HashSet<NodeType*>&)>& Handler) const {\n  ForEachConnectedComponent(\n      [&](const std::function<void(NodeType*)>& Handler) { ForEachNode(Handler); },\n      &NodeType::ForEachNodeOnInOutEdge, Handler);\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::ForEachConnectedComponent(\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachConnected,\n    const std::function<void(const HashSet<NodeType*>&)>& Handler) const {\n  ForEachConnectedComponent(\n      [&](const std::function<void(NodeType*)>& Handler) { ForEachNode(Handler); },\n      ForEachConnected, Handler);\n}\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Graph<NodeType, EdgeType>::ForEachConnectedComponent(\n    const std::function<void(const std::function<void(NodeType*)>&)>& ForEachNodeAsStart,\n    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachConnected,\n    const std::function<void(const HashSet<NodeType*>&)>& Handler) const {\n  HashMap<NodeType*, int32_t> node2component_id;\n  int32_t cur_component_id = 0;\n  ForEachNodeAsStart([&](NodeType* start) {\n    if (node2component_id.find(start) != node2component_id.end()) { return; }\n    ++cur_component_id;\n    BfsForEachNode({start}, ForEachConnected, [&](NodeType* node) {\n      CHECK(node2component_id.emplace(node, cur_component_id).second);\n    });\n  });\n  HashMap<int32_t, HashSet<NodeType*>> component_id2nodes;\n  for (const auto& pair : node2component_id) { component_id2nodes[pair.second].insert(pair.first); }\n  for (const auto& pair : component_id2nodes) { Handler(pair.second); }\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_GRAPH_H_\n"
  },
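  {
    "path": "oneflow/core/graph/graph_example.cpp",
    "content": "/*\nEditor's note: a hypothetical usage sketch, NOT part of the OneFlow source tree.\nIt illustrates the Graph<NodeType, EdgeType> API defined in oneflow/core/graph/graph.h:\nsubclass Node/Edge/Graph, wire nodes together with Connect<>, then traverse.\nDemoNode, DemoEdge, DemoGraph and BuildAndTraverseDiamond are made-up names.\n*/\n#include \"oneflow/core/graph/graph.h\"\n\nnamespace oneflow {\n\nclass DemoEdge;\n\nclass DemoNode final : public Node<DemoNode, DemoEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DemoNode);\n  explicit DemoNode(int64_t id) : id_(id) {}\n  ~DemoNode() = default;\n\n  int64_t id() const { return id_; }\n\n private:\n  int64_t id_;\n};\n\nclass DemoEdge final : public Edge<DemoNode, DemoEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DemoEdge);\n  DemoEdge() = default;\n  ~DemoEdge() = default;\n};\n\nclass DemoGraph final : public Graph<DemoNode, DemoEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DemoGraph);\n  DemoGraph() = default;\n  ~DemoGraph() = default;\n  const char* TypeName() const override { return \"DemoGraph\"; }\n};\n\nvoid BuildAndTraverseDiamond() {\n  DemoGraph graph;\n  // Diamond: a -> b, a -> c, b -> d, c -> d.\n  std::vector<DemoNode*> nodes;\n  for (int64_t i = 0; i < 4; ++i) {\n    auto* node = new DemoNode(i);\n    graph.AddAllocatedNode(node);  // the graph takes ownership\n    nodes.emplace_back(node);\n  }\n  auto AddEdge = [&](DemoNode* src, DemoNode* dst) {\n    Connect<DemoNode, DemoEdge>(src, graph.NewEdge(), dst);\n  };\n  AddEdge(nodes.at(0), nodes.at(1));\n  AddEdge(nodes.at(0), nodes.at(2));\n  AddEdge(nodes.at(1), nodes.at(3));\n  AddEdge(nodes.at(2), nodes.at(3));\n  // Visits a node only after all of its in-nodes; `a` comes first, `d` last.\n  graph.TopoForEachNode([](DemoNode* node) { LOG(INFO) << \"visit node \" << node->id(); });\n  // Reachability queries backed by per-node ancestor bitsets (see graph.h).\n  auto IsReachable = graph.MakePredicatorIsReachable();\n  CHECK(IsReachable(nodes.at(0), nodes.at(3)));\n  CHECK(!IsReachable(nodes.at(3), nodes.at(0)));\n}\n\n}  // namespace oneflow\n"
  },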
  {
    "path": "oneflow/core/graph/inplace_lbi_graph.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/inplace_lbi_graph.h\"\n#include \"oneflow/core/common/protobuf.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool IsSourceNode(const Operator& op) {\n  const auto& op_conf = op.op_conf();\n  if (op_conf.has_user_conf() && op_conf.user_conf().input().size() == 0\n      && op_conf.user_conf().output().size() == 1) {\n    return true;\n  }\n  if (op_conf.has_user_conf() && op_conf.user_conf().op_type_name() == \"mutable_cast_once\") {\n    return true;\n  }\n  if (op_conf.has_variable_conf()) { return true; }\n  if (op_conf.has_distribute_clone_conf() && op_conf.distribute_clone_conf().is_variable_ref()) {\n    return true;\n  }\n  if (op_conf.has_distribute_split_conf() && op_conf.distribute_split_conf().is_variable_ref()) {\n    return true;\n  }\n  return false;\n}\n\nvoid CheckSubGraph(const HashSet<const InplaceLbiNode*>& nodes) {\n  size_t source_op_node_cnt = 0;\n  size_t updt_node_cnt = 0;\n  size_t source_cnt = 0;\n  for (const auto* node : nodes) {\n    if (node->in_edges().empty()) { CHECK_EQ(++source_cnt, 1); }\n    if (dynamic_cast<const SourceOpInplaceLbiNode*>(node) != nullptr) {\n      CHECK_EQ(++source_op_node_cnt, 1);\n      CHECK(node->in_edges().empty());\n    }\n    if (dynamic_cast<const UpdateInplaceLbiNode*>(node) != nullptr) {\n      CHECK_EQ(++updt_node_cnt, 1);\n      CHECK(dynamic_cast<const SourceOpInplaceLbiNode*>(node->SoleInEdge()->src_node()) != nullptr)\n          << \"UpdateInplaceLbiNode-lbi: \" << PbMessage2TxtString(node->lbi())\n          << \", src_node.in_edges_size: \" << node->SoleInEdge()->src_node()->in_edges().size()\n          << \", SoleInNode: \" << typeid(node->SoleInEdge()->src_node()).name() << \", \"\n          << PbMessage2TxtString(node->SoleInEdge()->src_node()->lbi());\n    }\n  }\n}\n\nconst InplaceLbiNode* GetRoot(const HashSet<const InplaceLbiNode*>& nodes,\n                              const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) {\n  const InplaceLbiNode* root = nullptr;\n  for (const InplaceLbiNode* node : nodes) {\n    if (node->GetValidInEdge(IsValidEdge) == nullptr) {\n      CHECK_ISNULL(root);\n      root = node;\n    }\n  }\n  return root;\n}\n\nconst InplaceLbiNode* FindSoleIsMutableIbnConsumer(const SourceOpInplaceLbiNode* node) {\n  const InplaceLbiNode* ret = nullptr;\n  for (const InplaceLbiEdge* edge : node->out_edges()) {\n    if (dynamic_cast<const UpdateInplaceLbiNode*>(edge->dst_node()) != nullptr) {\n      CHECK_ISNULL(ret);\n      ret = edge->dst_node();\n    }\n  }\n  return ret;\n}\n\nInplaceLbiNode* CreateNode(const LogicalBlobId& lbi,\n                           const std::function<const Operator*(const std::string&)>& Op4OpName) {\n  const Operator& op = *Op4OpName(lbi.op_name());\n  if (IsSourceNode(op)) {\n    return new SourceOpInplaceLbiNode(lbi);\n  } else if (std::find_if(op.output_bns().begin(), op.output_bns().end(),\n                          
[&](const std::string& obn) { return op.BnInOp2Lbi(obn) == lbi; })\n             != op.output_bns().end()) {\n    return new NormalInplaceLbiNode(lbi);\n  } else {\n    return new UpdateInplaceLbiNode(lbi);\n  }\n}\n\nvoid GetUnconnectedNodes(const HashSet<const InplaceLbiNode*>& nodes,\n                         const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,\n                         HashSet<const InplaceLbiNode*>* cur_disabled_nodes) {\n  for (const InplaceLbiNode* node : nodes) {\n    size_t cnt = 0;\n    for (const InplaceLbiEdge* edge : node->in_edges()) { cnt += IsValidEdge(edge); }\n    for (const InplaceLbiEdge* edge : node->out_edges()) { cnt += IsValidEdge(edge); }\n    if (cnt == 0) { CHECK(cur_disabled_nodes->emplace(node).second); }\n  }\n}\n\nconst InplaceLbiNode* GetFirstDiffNode(const std::vector<const InplaceLbiNode*>& lhs,\n                                       const std::vector<const InplaceLbiNode*>& rhs) {\n  FOR_RANGE(int32_t, i, 0, std::min(lhs.size(), rhs.size())) {\n    if (lhs.at(i) != rhs.at(i)) { return lhs.at(i); }\n  }\n  return nullptr;\n};\n\nstd::function<void(const InplaceLbiNode*, const std::function<void(const InplaceLbiNode*)>&)>\nGetForEachValidInNode(const HashSet<const InplaceLbiNode*>* nodes,\n                      std::function<bool(const InplaceLbiEdge*)> IsValidEdge) {\n  return [nodes, IsValidEdge](const InplaceLbiNode* node,\n                              const std::function<void(const InplaceLbiNode*)>& Handler) {\n    const InplaceLbiEdge* in_edge = node->GetValidInEdge(IsValidEdge);\n    if (in_edge == nullptr) { return; }\n    if (nodes->find(in_edge->src_node()) != nodes->end()) { Handler(in_edge->src_node()); }\n  };\n}\n\nstd::function<void(const InplaceLbiNode*, const std::function<void(const InplaceLbiNode*)>&)>\nGetForEachValidOutNode(const HashSet<const InplaceLbiNode*>* nodes,\n                       std::function<bool(const InplaceLbiEdge*)> IsValidEdge) {\n  return [nodes, IsValidEdge](const InplaceLbiNode* node,\n                              const std::function<void(const InplaceLbiNode*)>& Handler) {\n    node->ForEachNodeOnValidOutEdge(IsValidEdge, [&](const InplaceLbiNode* out_node) {\n      if (nodes->find(out_node) != nodes->end()) { Handler(out_node); }\n    });\n  };\n}\n\nbool IsOtherIbnBoundToOneOfLbis(const HashSet<LogicalBlobId>& lbis, const InplaceLbiEdge* edge) {\n  const Operator& op = edge->op();\n  for (const std::string& ibn : op.input_bns()) {\n    if (ibn != edge->ibn() && lbis.find(op.BnInOp2Lbi(ibn)) != lbis.end()) { return true; }\n  }\n  return false;\n}\n\nvoid RemoveUnconnectedNodes(HashSet<const InplaceLbiNode*>* nodes,\n                            const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) {\n  HashSet<const InplaceLbiNode*> cur_disabled_nodes;\n  GetUnconnectedNodes(*nodes, IsValidEdge, &cur_disabled_nodes);\n  for (const auto* node : cur_disabled_nodes) { nodes->erase(node); }\n}\n\n}  // namespace\n\nconst InplaceLbiEdge* InplaceLbiNode::GetValidInEdge(\n    const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const {\n  if (!in_edges().empty() && IsValidEdge(SoleInEdge())) { return SoleInEdge(); }\n  return nullptr;\n}\n\nconst InplaceLbiEdge* InplaceLbiNode::GetSoleValidInEdge(\n    const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const {\n  const auto* edge = GetValidInEdge(IsValidEdge);\n  CHECK_NOTNULL(edge);\n  return edge;\n}\n\nvoid InplaceLbiNode::ForEachNodeOnValidOutEdge(\n    const std::function<bool(const InplaceLbiEdge*)>& 
IsValidEdge,\n    const std::function<void(const InplaceLbiNode*)>& Handler) const {\n  for (const auto* edge : out_edges()) {\n    if (IsValidEdge(edge)) { Handler(edge->dst_node()); }\n  }\n}\n\nbool InplaceLbiNode::IsMutRef(const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const {\n  UNIMPLEMENTED();\n}\n\nbool InplaceLbiNode::IsConstRef(\n    const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const {\n  return !IsMutRef(IsValidEdge);\n}\n\nbool NormalInplaceLbiNode::IsMutRef(\n    const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const {\n  const InplaceLbiEdge* in_edge = GetValidInEdge(IsValidEdge);\n  return in_edge != nullptr && in_edge->IsMutRef();\n}\n\nbool InplaceLbiEdge::IsMutRef() const {\n  CHECK_NOTNULL(dynamic_cast<const NormalInplaceLbiNode*>(dst_node()));\n  return is_mut_ref_;\n}\n\nstd::function<InplaceLbiNode*(const LogicalBlobId&)> InplaceLbiGraph::MakeMutFindOrCreateNode(\n    std::function<const Operator*(const std::string&)> Op4OpName) {\n  auto lbi2node = std::make_shared<HashMap<LogicalBlobId, InplaceLbiNode*>>();\n  return [this, lbi2node, Op4OpName](const LogicalBlobId& lbi) -> InplaceLbiNode* {\n    auto node_it = lbi2node->find(lbi);\n    if (node_it == lbi2node->end()) {\n      auto* node = CreateNode(lbi, Op4OpName);\n      AddAllocatedNode(node);\n      node_it = lbi2node->emplace(lbi, node).first;\n    }\n    return node_it->second;\n  };\n}\n\nvoid InplaceLbiGraph::Init(const InplaceObasInfo& obas_info,\n                           const std::function<const Operator*(const std::string&)>& Op4OpName) {\n  auto FindOrCreateNode = MakeMutFindOrCreateNode(Op4OpName);\n  auto AddEdge = [&](const Operator& op, const LogicalBlobId& lbi, const std::string& ibn,\n                     const std::string& obn, bool is_mut) {\n    auto* edge = new InplaceLbiEdge(&op, ibn, obn, is_mut);\n    AddAllocatedEdge(edge);\n    Connect<InplaceLbiNode, InplaceLbiEdge>(FindOrCreateNode(op.BnInOp2Lbi(ibn)), edge,\n                                            FindOrCreateNode(lbi));\n  };\n\n  auto BuildNodeAndEdge4InplacePairs = [&](const OpBlobArgPairs& pairs, bool is_mut) {\n    for (const auto& pair : pairs.pair()) {\n      CHECK_EQ(pair.first().op_name(), pair.second().op_name());\n      const Operator& op = *Op4OpName(pair.first().op_name());\n      std::string ibn = pair.first().bn_in_op();\n      std::string obn = pair.second().bn_in_op();\n      LogicalBlobId lbi = op.BnInOp2Lbi(obn);\n      CHECK(std::find(op.input_bns().begin(), op.input_bns().end(), ibn) != op.input_bns().end());\n      CHECK(std::find(op.output_bns().begin(), op.output_bns().end(), obn)\n            != op.output_bns().end());\n      AddEdge(op, lbi, ibn, obn, is_mut);\n    }\n  };\n\n  for (const auto& oba : obas_info.mut_in_obas.oba()) {\n    const Operator& op = *Op4OpName(oba.op_name());\n    std::string ibn = oba.bn_in_op();\n    std::string obn = ibn + \"_updated\";\n    LogicalBlobId lbi;\n    lbi.set_op_name(op.op_name());\n    lbi.set_blob_name(obn);\n    CHECK(std::find(op.input_bns().begin(), op.input_bns().end(), ibn) != op.input_bns().end());\n    CHECK(std::find_if(op.output_bns().begin(), op.output_bns().end(),\n                       [&](const std::string& obn) { return op.BnInOp2Lbi(obn) == lbi; })\n          == op.output_bns().end());\n    AddEdge(op, lbi, ibn, obn, true);\n  }\n\n  BuildNodeAndEdge4InplacePairs(obas_info.mut_inplace_oba_pairs, true);\n  BuildNodeAndEdge4InplacePairs(obas_info.con_inplace_oba_pairs, false);\n\n  
ForEachNode([](const InplaceLbiNode* node) { CHECK_LE(node->in_edges().size(), 1); });\n  CHECK(!FindFirstNontrivialSCC());\n}\n\nvoid InplaceLbiGraph::ComputeSafeInplaceObns(\n    InplaceObasInfo* obas_info,\n    const std::function<bool(const LogicalBlobId&, const std::string&)>& IsReachableFromLbiToOpName)\n    const {\n  ComputeSafeInplaceEdges(IsReachableFromLbiToOpName, [&](const InplaceLbiEdge* edge) {\n    CHECK_NOTNULL(dynamic_cast<const NormalInplaceLbiNode*>(edge->dst_node()));\n    if (edge->IsMutRef()) {\n      auto* pair = obas_info->mut_inplace_oba_pairs.mutable_pair()->Add();\n      *pair->mutable_first() = GenOpBlobArg(edge->op().op_name(), edge->ibn());\n      *pair->mutable_second() = GenOpBlobArg(edge->op().op_name(), edge->obn());\n    } else {\n      auto* pair = obas_info->con_inplace_oba_pairs.mutable_pair()->Add();\n      *pair->mutable_first() = GenOpBlobArg(edge->op().op_name(), edge->ibn());\n      *pair->mutable_second() = GenOpBlobArg(edge->op().op_name(), edge->obn());\n    }\n  });\n}\n\nvoid InplaceLbiGraph::ComputeSafeInplaceEdges(\n    const std::function<bool(const LogicalBlobId&, const std::string&)>& IsReachableFromLbiToOpName,\n    const std::function<void(const InplaceLbiEdge*)>& Handler) const {\n  ForEachConnectedComponent([&](const HashSet<const InplaceLbiNode*>& nodes) {\n    ComputeSafeInplaceEdges(nodes, IsReachableFromLbiToOpName, Handler);\n  });\n}\n\nvoid InplaceLbiGraph::ForEachSafeInplaceEdgeInSourceOpSubTree(\n    const HashSet<const InplaceLbiNode*>& nodes,\n    const std::function<bool(const LogicalBlobId&, const std::string&)>& IsReachableFromLbiToOpName,\n    const std::function<void(const InplaceLbiEdge*)>& Handler,\n    HashSet<const InplaceLbiEdge*>* disabled_edges) const {\n  disabled_edges->clear();\n  auto IsValidEdge = [&](const InplaceLbiEdge* edge) {\n    return disabled_edges->find(edge) == disabled_edges->end();\n  };\n  const InplaceLbiNode* root = GetRoot(nodes, [](const InplaceLbiEdge*) { return true; });\n  const auto* source_op_root = dynamic_cast<const SourceOpInplaceLbiNode*>(root);\n  if (source_op_root != nullptr) {\n    const InplaceLbiNode* updt_node = FindSoleIsMutableIbnConsumer(source_op_root);\n    if (updt_node != nullptr) {\n      HashSet<const InplaceLbiEdge*> cur_disabled_edges;\n      FixConstRefOrMutRefConflictsToUpdtNode(nodes, IsReachableFromLbiToOpName,\n                                             &cur_disabled_edges);\n      disabled_edges->insert(cur_disabled_edges.begin(), cur_disabled_edges.end());\n    }\n    {\n      HashSet<const InplaceLbiEdge*> cur_disabled_edges;\n      FixMutRefConflictsFromSourceOpNode(source_op_root, IsValidEdge, &cur_disabled_edges);\n      disabled_edges->insert(cur_disabled_edges.begin(), cur_disabled_edges.end());\n    }\n    {\n      // disconnect edges in the subtree containing `root`\n      HashSet<const InplaceLbiEdge*> cur_disabled_edges;\n      auto ForEachNext = GetForEachValidOutNode(&nodes, IsValidEdge);\n      BfsForEachNode({root}, ForEachNext, [&](const InplaceLbiNode* node) {\n        const InplaceLbiEdge* in_edge = node->GetValidInEdge(IsValidEdge);\n        if (in_edge != nullptr) { CHECK(cur_disabled_edges.emplace(in_edge).second); }\n        if (dynamic_cast<const NormalInplaceLbiNode*>(node) != nullptr) {\n          CHECK_NOTNULL(in_edge);\n          if (node->IsConstRef(IsValidEdge)) { Handler(in_edge); }\n        }\n      });\n      disabled_edges->insert(cur_disabled_edges.begin(), cur_disabled_edges.end());\n    }\n  }\n}\n\nvoid 
InplaceLbiGraph::ComputeSafeInplaceEdges(\n    const HashSet<const InplaceLbiNode*>& nodes,\n    const std::function<bool(const LogicalBlobId&, const std::string&)>& IsReachableFromLbiToOpName,\n    const std::function<void(const InplaceLbiEdge*)>& Handler) const {\n  CheckSubGraph(nodes);\n  HashSet<const InplaceLbiNode*> remainder_nodes(nodes);\n  HashSet<const InplaceLbiEdge*> disabled_edges;\n  {\n    // compute safe inplace edges in the subtree containing SourceOpInplaceLbiNode as root\n    HashSet<const InplaceLbiEdge*> cur_disabled_edges;\n    ForEachSafeInplaceEdgeInSourceOpSubTree(remainder_nodes, IsReachableFromLbiToOpName, Handler,\n                                            &cur_disabled_edges);\n    disabled_edges.insert(cur_disabled_edges.begin(), cur_disabled_edges.end());\n  }\n  auto IsValidEdge = [&](const InplaceLbiEdge* edge) {\n    return remainder_nodes.find(edge->src_node()) != remainder_nodes.end()\n           && remainder_nodes.find(edge->dst_node()) != remainder_nodes.end()\n           && disabled_edges.find(edge) == disabled_edges.end();\n  };\n  RemoveUnconnectedNodes(&remainder_nodes, IsValidEdge);\n  size_t dead_loop_check = remainder_nodes.size();\n  while (!remainder_nodes.empty()) {\n    ForEachTree(remainder_nodes, IsValidEdge, [&](const HashSet<const InplaceLbiNode*>& nodes) {\n      const InplaceLbiEdge* cur_disabled_edge =\n          FindFirstInterOpRefConflictMutRefEdge(nodes, IsValidEdge, IsReachableFromLbiToOpName);\n      if (cur_disabled_edge != nullptr) { disabled_edges.insert(cur_disabled_edge); }\n    });\n    ForEachTree(remainder_nodes, IsValidEdge, [&](const HashSet<const InplaceLbiNode*>& nodes) {\n      const InplaceLbiEdge* cur_disabled_edge =\n          FindFirstConstRefConflictMutRefEdge(nodes, IsValidEdge, IsReachableFromLbiToOpName);\n      if (cur_disabled_edge != nullptr) { disabled_edges.insert(cur_disabled_edge); }\n    });\n    ForEachTree(remainder_nodes, IsValidEdge, [&](const HashSet<const InplaceLbiNode*>& nodes) {\n      const InplaceLbiEdge* cur_disabled_edge =\n          FindFirstIntraOpRefConflictMutRefEdge(nodes, IsValidEdge);\n      if (cur_disabled_edge != nullptr) { disabled_edges.insert(cur_disabled_edge); }\n    });\n    {\n      HashSet<const InplaceLbiEdge*> cur_safe_inplace_obn_edges;\n      GetSafeInplaceObnEdges(remainder_nodes, IsValidEdge, IsReachableFromLbiToOpName,\n                             &cur_safe_inplace_obn_edges);\n      for (const auto* edge : cur_safe_inplace_obn_edges) { Handler(edge); }\n      disabled_edges.insert(cur_safe_inplace_obn_edges.begin(), cur_safe_inplace_obn_edges.end());\n    }\n    RemoveUnconnectedNodes(&remainder_nodes, IsValidEdge);\n    CHECK_GT(dead_loop_check--, 0);  // dead_loop_check is unsigned, so test before decrementing\n  }\n}\n\nvoid InplaceLbiGraph::FindAllEdges(const HashSet<const InplaceLbiNode*>& nodes,\n                                   const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,\n                                   HashSet<const InplaceLbiEdge*>* cur_disabled_edges) const {\n  for (const auto* node : nodes) {\n    node->ForEachNodeOnValidOutEdge(IsValidEdge, [&](const InplaceLbiNode* out_node) {\n      CHECK(cur_disabled_edges->emplace(out_node->GetSoleValidInEdge(IsValidEdge)).second);\n    });\n  }\n}\n\nconst InplaceLbiEdge* InplaceLbiGraph::FindFirstIntraOpRefConflictMutRefEdge(\n    const HashSet<const InplaceLbiNode*>& nodes,\n    const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const {\n  const InplaceLbiEdge* ret = nullptr;\n  HashSet<LogicalBlobId> lbis;\n  for (const auto* 
node : nodes) { CHECK(lbis.insert(node->lbi()).second); }\n\n  const auto* root = GetRoot(nodes, IsValidEdge);\n  auto ForEachInNode = GetForEachValidInNode(&nodes, IsValidEdge);\n  auto ForEachOutNode = GetForEachValidOutNode(&nodes, IsValidEdge);\n  TopoForEachNode({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) {\n    if (ret != nullptr) { return; }\n    if (node->IsMutRef(IsValidEdge) && IsOtherIbnBoundToOneOfLbis(lbis, node->SoleInEdge())) {\n      ret = node->SoleInEdge();\n    }\n  });\n  return ret;\n}\n\nbool InplaceLbiGraph::IsConstRefConflictMutRefNode(\n    const InplaceLbiNode* mut_ref_node, const HashSet<const InplaceLbiNode*>& nodes,\n    const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,\n    const std::function<bool(const LogicalBlobId&, const std::string&)>&\n        IsLbiAllConsumerReachableToOpName) const {\n  CHECK(mut_ref_node->IsMutRef(IsValidEdge));\n  auto ForEachNext = [&](const InplaceLbiNode* node,\n                         const std::function<void(const InplaceLbiNode*)>& Handler) {\n    node->ForEachNodeOnValidOutEdge(IsValidEdge, [&](const InplaceLbiNode* out_node) {\n      if (out_node != mut_ref_node) { Handler(out_node); }\n    });\n  };\n  bool conflict = false;\n  const auto& op_name = mut_ref_node->lbi().op_name();\n  BfsForEachNode({GetRoot(nodes, IsValidEdge)}, ForEachNext, [&](const InplaceLbiNode* node) {\n    conflict = conflict || !IsLbiAllConsumerReachableToOpName(node->lbi(), op_name);\n  });\n  return conflict;\n}\n\nconst InplaceLbiEdge* InplaceLbiGraph::FindFirstConstRefConflictMutRefEdge(\n    const HashSet<const InplaceLbiNode*>& nodes,\n    const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,\n    const std::function<bool(const LogicalBlobId&, const std::string&)>&\n        IsLbiAllConsumerReachableToOpName) const {\n  const InplaceLbiNode* root = GetRoot(nodes, IsValidEdge);\n  auto ForEachInNode = GetForEachValidInNode(&nodes, IsValidEdge);\n  auto ForEachOutNode = GetForEachValidOutNode(&nodes, IsValidEdge);\n  const InplaceLbiEdge* ret = nullptr;\n  TopoForEachNode({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) {\n    if (ret != nullptr) { return; }\n    if (node->IsMutRef(IsValidEdge)\n        && IsConstRefConflictMutRefNode(node, nodes, IsValidEdge,\n                                        IsLbiAllConsumerReachableToOpName)) {\n      ret = node->GetValidInEdge(IsValidEdge);\n    }\n  });\n  return ret;\n}\n\nconst InplaceLbiEdge* InplaceLbiGraph::FindFirstInterOpRefConflictMutRefEdge(\n    const HashSet<const InplaceLbiNode*>& nodes,\n    const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,\n    const std::function<bool(const LogicalBlobId&, const std::string&)>&\n        IsLbiAllConsumerReachableToOpName) const {\n  HashSet<const InplaceLbiNode*> mut_ref_nodes;\n  HashMap<const InplaceLbiNode*, std::vector<const InplaceLbiNode*>> node2mut_ref_ancestors;\n  {\n    const InplaceLbiNode* root = GetRoot(nodes, IsValidEdge);\n    auto ForEachInNode = GetForEachValidInNode(&nodes, IsValidEdge);\n    auto ForEachOutNode = GetForEachValidOutNode(&nodes, IsValidEdge);\n    TopoForEachNode({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) {\n      if (node->IsMutRef(IsValidEdge)) { mut_ref_nodes.insert(node); }\n      size_t in_edges_size_check = 0;\n      ForEachInNode(node, [&](const InplaceLbiNode* in_node) {\n        node2mut_ref_ancestors[node] = node2mut_ref_ancestors[in_node];\n        if (in_node->IsMutRef(IsValidEdge)) { 
node2mut_ref_ancestors[node].emplace_back(in_node); }\n        CHECK_EQ(++in_edges_size_check, 1);\n      });\n    });\n  }\n  std::vector<const InplaceLbiNode*> last_mut_ref_nodes;\n  {\n    HashMap<const InplaceLbiNode*, size_t> mut_ref_node2descendants_size;\n    for (const InplaceLbiNode* descendant : mut_ref_nodes) {\n      for (const InplaceLbiNode* ancestor : node2mut_ref_ancestors.at(descendant)) {\n        ++mut_ref_node2descendants_size[ancestor];\n      }\n    }\n    for (const InplaceLbiNode* node : mut_ref_nodes) {\n      if (mut_ref_node2descendants_size[node] == 0) { last_mut_ref_nodes.emplace_back(node); }\n    }\n  }\n  if (last_mut_ref_nodes.size() <= 1) { return nullptr; }\n  const InplaceLbiNode* first_diff_node = nullptr;\n  {\n    const auto& first = node2mut_ref_ancestors.at(last_mut_ref_nodes.at(0));\n    const auto& second = node2mut_ref_ancestors.at(last_mut_ref_nodes.at(1));\n    first_diff_node = GetFirstDiffNode(first, second);\n    if (first_diff_node == nullptr) {\n      first_diff_node = last_mut_ref_nodes.at(first.size() < second.size() ? 0 : 1);\n    }\n  }\n  return first_diff_node->GetSoleValidInEdge(IsValidEdge);\n}\n\nvoid InplaceLbiGraph::GetSafeInplaceObnEdges(\n    const HashSet<const InplaceLbiNode*>& nodes,\n    const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,\n    const std::function<bool(const LogicalBlobId&, const std::string&)>&\n        IsLbiAllConsumerReachableToOpName,\n    HashSet<const InplaceLbiEdge*>* cur_disabled_edges) const {\n  ForEachTree(nodes, IsValidEdge, [&](const HashSet<const InplaceLbiNode*>& nodes) {\n    // no inter-op reference conflicts\n    const InplaceLbiEdge* inter_op_conflict_ref_edge = FindFirstInterOpRefConflictMutRefEdge(\n        nodes, IsValidEdge, IsLbiAllConsumerReachableToOpName);\n    // mutable reference always goes after const reference\n    const InplaceLbiEdge* const_ref_conflict_ref_edge =\n        FindFirstConstRefConflictMutRefEdge(nodes, IsValidEdge, IsLbiAllConsumerReachableToOpName);\n    // no intra-op reference conflicts\n    const InplaceLbiEdge* intra_op_conflict_ref_edge =\n        FindFirstIntraOpRefConflictMutRefEdge(nodes, IsValidEdge);\n    if (const_ref_conflict_ref_edge == nullptr && intra_op_conflict_ref_edge == nullptr\n        && inter_op_conflict_ref_edge == nullptr) {\n      FindAllEdges(nodes, IsValidEdge, cur_disabled_edges);\n    }\n  });\n}\n\nvoid InplaceLbiGraph::ForEachTree(\n    const HashSet<const InplaceLbiNode*>& nodes,\n    const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,\n    const std::function<void(const HashSet<const InplaceLbiNode*>&)>& Handler) const {\n  auto ForEachNode = [&](const std::function<void(const InplaceLbiNode*)>& Handler) {\n    for (const auto* node : nodes) { Handler(node); }\n  };\n  auto ForEachInNode = GetForEachValidInNode(&nodes, IsValidEdge);\n  auto ForEachOutNode = GetForEachValidOutNode(&nodes, IsValidEdge);\n  auto ForEachConnected = [&](const InplaceLbiNode* node,\n                              const std::function<void(const InplaceLbiNode*)>& Handler) {\n    ForEachInNode(node, Handler);\n    ForEachOutNode(node, Handler);\n  };\n  ForEachConnectedComponent(ForEachNode, ForEachConnected, Handler);\n}\n\nvoid InplaceLbiGraph::FixConstRefOrMutRefConflictsToUpdtNode(\n    const HashSet<const InplaceLbiNode*>& nodes,\n    const std::function<bool(const LogicalBlobId&, const std::string&)>&\n        IsLbiAllConsumerReachableToOpName,\n    HashSet<const InplaceLbiEdge*>* cur_disabled_edges) const {\n  auto 
IsValidEdge = [](const InplaceLbiEdge*) { return true; };\n  const InplaceLbiNode* updt_node = nullptr;\n  HashSet<const InplaceLbiNode*> safe_const_ref_nodes;\n  const InplaceLbiNode* root = GetRoot(nodes, IsValidEdge);\n  CHECK_NOTNULL(root);\n  {\n    const auto* source_op_root = dynamic_cast<const SourceOpInplaceLbiNode*>(root);\n    CHECK_NOTNULL(source_op_root);\n    updt_node = FindSoleIsMutableIbnConsumer(source_op_root);\n    CHECK_NOTNULL(updt_node);\n    auto ForEachNext = [&](const InplaceLbiNode* node,\n                           const std::function<void(const InplaceLbiNode*)>& Handler) {\n      node->ForEachNodeOnValidOutEdge(IsValidEdge, [&](const InplaceLbiNode* out_node) {\n        if (dynamic_cast<const NormalInplaceLbiNode*>(out_node) == nullptr) { return; }\n        if (out_node->IsMutRef(IsValidEdge)) { return; }\n        if (!IsLbiAllConsumerReachableToOpName(out_node->lbi(), updt_node->lbi().op_name())) {\n          return;\n        }\n        Handler(out_node);\n      });\n    };\n    BfsForEachNode({root}, ForEachNext, [&](const InplaceLbiNode* node) {\n      if (node == root) { return; }\n      CHECK(safe_const_ref_nodes.emplace(node).second);\n    });\n  }\n  for (const auto* node : safe_const_ref_nodes) {\n    node->ForEachNodeOnValidOutEdge(IsValidEdge, [&](const InplaceLbiNode* out_node) {\n      if (safe_const_ref_nodes.find(out_node) == safe_const_ref_nodes.end()\n          && out_node != updt_node) {\n        CHECK(nodes.find(out_node) != nodes.end());\n        CHECK(cur_disabled_edges->emplace(out_node->GetSoleValidInEdge(IsValidEdge)).second);\n      }\n    });\n  }\n  // remove mutable inplace edges from root that do not end at the model update node\n  root->ForEachNodeOnValidOutEdge(IsValidEdge, [&](const InplaceLbiNode* out_node) {\n    const auto* node = dynamic_cast<const NormalInplaceLbiNode*>(out_node);\n    if (node != nullptr && node->IsMutRef(IsValidEdge)) {\n      CHECK(nodes.find(out_node) != nodes.end());\n      CHECK(cur_disabled_edges->emplace(node->GetSoleValidInEdge(IsValidEdge)).second);\n    }\n  });\n}\n\nvoid InplaceLbiGraph::FixMutRefConflictsFromSourceOpNode(\n    const SourceOpInplaceLbiNode* root,\n    const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,\n    HashSet<const InplaceLbiEdge*>* cur_disabled_edges) const {\n  HashSet<const InplaceLbiNode*> safe_const_ref_nodes;\n  {\n    auto ForEachNext = [&](const InplaceLbiNode* node,\n                           const std::function<void(const InplaceLbiNode*)>& Handler) {\n      node->ForEachNodeOnValidOutEdge(IsValidEdge, [&](const InplaceLbiNode* out_node) {\n        if (dynamic_cast<const NormalInplaceLbiNode*>(out_node) == nullptr) {\n          Handler(out_node);\n        } else if (out_node->IsConstRef(IsValidEdge)) {\n          Handler(out_node);\n        } else {\n          // do nothing\n        }\n      });\n    };\n    BfsForEachNode({root}, ForEachNext, [&](const InplaceLbiNode* node) {\n      if (dynamic_cast<const NormalInplaceLbiNode*>(node) != nullptr) {\n        CHECK(safe_const_ref_nodes.emplace(node).second);\n      }\n    });\n  }\n  for (const auto* node : safe_const_ref_nodes) {\n    node->ForEachNodeOnValidOutEdge(IsValidEdge, [&](const InplaceLbiNode* out_node) {\n      if (safe_const_ref_nodes.find(out_node) == safe_const_ref_nodes.end()\n          && dynamic_cast<const NormalInplaceLbiNode*>(out_node) != nullptr\n          && out_node->IsMutRef(IsValidEdge)) {\n        
CHECK(cur_disabled_edges->emplace(out_node->GetSoleValidInEdge(IsValidEdge)).second);\n      }\n    });\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/inplace_lbi_graph.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_INPLACE_LBI_GRAPH_H_\n#define ONEFLOW_CORE_GRAPH_INPLACE_LBI_GRAPH_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/graph/graph.h\"\n#include \"oneflow/core/register/op_blob_arg_info.h\"\n\nnamespace oneflow {\n\nclass InplaceLbiEdge;\n\nclass InplaceLbiNode : public Node<InplaceLbiNode, InplaceLbiEdge> {\n public:\n  virtual ~InplaceLbiNode() = default;\n\n  const LogicalBlobId& lbi() const { return lbi_; }\n  const InplaceLbiEdge* GetValidInEdge(\n      const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const;\n  const InplaceLbiEdge* GetSoleValidInEdge(\n      const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const;\n  void ForEachNodeOnValidOutEdge(const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,\n                                 const std::function<void(const InplaceLbiNode*)>& Handler) const;\n  virtual bool IsMutRef(const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const;\n  bool IsConstRef(const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const;\n\n  std::string VisualStr() const override { return GenLogicalBlobName(lbi_); }\n\n protected:\n  OF_DISALLOW_COPY_AND_MOVE(InplaceLbiNode);\n  explicit InplaceLbiNode(const LogicalBlobId& lbi) : lbi_(lbi) {}\n\n private:\n  LogicalBlobId lbi_;\n};\n\nclass NormalInplaceLbiNode final : public InplaceLbiNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NormalInplaceLbiNode);\n  explicit NormalInplaceLbiNode(const LogicalBlobId& lbi) : InplaceLbiNode(lbi) {}\n  ~NormalInplaceLbiNode() override = default;\n\n  bool IsMutRef(const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const override;\n};\n\nclass SourceOpInplaceLbiNode final : public InplaceLbiNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SourceOpInplaceLbiNode);\n  explicit SourceOpInplaceLbiNode(const LogicalBlobId& lbi) : InplaceLbiNode(lbi) {}\n  ~SourceOpInplaceLbiNode() = default;\n};\n\nclass UpdateInplaceLbiNode final : public InplaceLbiNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(UpdateInplaceLbiNode);\n  explicit UpdateInplaceLbiNode(const LogicalBlobId& lbi) : InplaceLbiNode(lbi) {}\n  ~UpdateInplaceLbiNode() = default;\n};\n\nclass InplaceLbiEdge final : public Edge<InplaceLbiNode, InplaceLbiEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(InplaceLbiEdge);\n  InplaceLbiEdge(const Operator* op, const std::string& ibn, const std::string& obn,\n                 bool is_mut_ref)\n      : op_(op), ibn_(ibn), obn_(obn), is_mut_ref_(is_mut_ref) {}\n  ~InplaceLbiEdge() = default;\n\n  const Operator& op() const { return *op_; }\n  const std::string& ibn() const { return ibn_; }\n  const std::string& obn() const { return obn_; }\n  bool IsMutRef() const;\n  bool IsConstRef() const { return !IsMutRef(); }\n\n  std::string VisualStr() const override {\n    return std::string(op_->op_name() + \"/\" + ibn_ + 
\":\" + obn_);\n  }\n\n private:\n  const Operator* op_;\n  const std::string ibn_;\n  const std::string obn_;\n  const bool is_mut_ref_;\n};\n\nclass InplaceLbiGraph final : public Graph<const InplaceLbiNode, const InplaceLbiEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(InplaceLbiGraph);\n  InplaceLbiGraph(const InplaceObasInfo& obas_info,\n                  const std::function<const Operator*(const std::string&)>& Op4OpName) {\n    Init(obas_info, Op4OpName);\n  }\n  ~InplaceLbiGraph() = default;\n  const char* TypeName() const override { return \"InplaceLbiGraph\"; }\n\n  void ComputeSafeInplaceObns(InplaceObasInfo* obas_info,\n                              const std::function<bool(const LogicalBlobId&, const std::string&)>&\n                                  IsLbiAllConsumerReachableToOpName) const;\n\n private:\n  void Init(const InplaceObasInfo& obas_info,\n            const std::function<const Operator*(const std::string&)>& Op4OpName);\n  std::function<InplaceLbiNode*(const LogicalBlobId&)> MakeMutFindOrCreateNode(\n      std::function<const Operator*(const std::string&)> Op4OpName);\n  void ComputeSafeInplaceEdges(const std::function<bool(const LogicalBlobId&, const std::string&)>&\n                                   IsLbiAllConsumerReachableToOpName,\n                               const std::function<void(const InplaceLbiEdge*)>& Handler) const;\n  void ComputeSafeInplaceEdges(const HashSet<const InplaceLbiNode*>& nodes,\n                               const std::function<bool(const LogicalBlobId&, const std::string&)>&\n                                   IsLbiAllConsumerReachableToOpName,\n                               const std::function<void(const InplaceLbiEdge*)>& Handler) const;\n  void ForEachSafeInplaceEdgeInSourceOpSubTree(\n      const HashSet<const InplaceLbiNode*>& nodes,\n      const std::function<bool(const LogicalBlobId&, const std::string&)>&\n          IsLbiAllConsumerReachableToOpName,\n      const std::function<void(const InplaceLbiEdge*)>& Handler,\n      HashSet<const InplaceLbiEdge*>* cur_disabled_edges) const;\n  void GetSafeInplaceObnEdges(const HashSet<const InplaceLbiNode*>& nodes,\n                              const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,\n                              const std::function<bool(const LogicalBlobId&, const std::string&)>&\n                                  IsLbiAllConsumerReachableToOpName,\n                              HashSet<const InplaceLbiEdge*>* cur_disabled_edges) const;\n  const InplaceLbiEdge* FindFirstConstRefConflictMutRefEdge(\n      const HashSet<const InplaceLbiNode*>& nodes,\n      const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,\n      const std::function<bool(const LogicalBlobId&, const std::string&)>&\n          IsLbiAllConsumerReachableToOpName) const;\n\n  const InplaceLbiEdge* FindFirstIntraOpRefConflictMutRefEdge(\n      const HashSet<const InplaceLbiNode*>& nodes,\n      const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const;\n\n  const InplaceLbiEdge* FindFirstInterOpRefConflictMutRefEdge(\n      const HashSet<const InplaceLbiNode*>& nodes,\n      const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,\n      const std::function<bool(const LogicalBlobId&, const std::string&)>&\n          IsLbiAllConsumerReachableToOpName) const;\n\n  bool IsConstRefConflictMutRefNode(\n      const InplaceLbiNode* mut_ref_node, const HashSet<const InplaceLbiNode*>& nodes,\n      const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,\n      const 
std::function<bool(const LogicalBlobId&, const std::string&)>&\n          IsLbiAllConsumerReachableToOpName) const;\n\n  void FixConstRefOrMutRefConflictsToUpdtNode(\n      const HashSet<const InplaceLbiNode*>& nodes,\n      const std::function<bool(const LogicalBlobId&, const std::string&)>&\n          IsLbiAllConsumerReachableToOpName,\n      HashSet<const InplaceLbiEdge*>* cur_disabled_edges) const;\n\n  void FixMutRefConflictsFromSourceOpNode(\n      const SourceOpInplaceLbiNode* root,\n      const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,\n      HashSet<const InplaceLbiEdge*>* cur_disabled_edges) const;\n\n  void ForEachTree(const HashSet<const InplaceLbiNode*>& nodes,\n                   const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,\n                   const std::function<void(const HashSet<const InplaceLbiNode*>&)>& Handler) const;\n  void FindAllEdges(const HashSet<const InplaceLbiNode*>& nodes,\n                    const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,\n                    HashSet<const InplaceLbiEdge*>* cur_disabled_edges) const;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_INPLACE_LBI_GRAPH_H_\n"
  },
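  {
    "path": "oneflow/core/graph/inplace_lbi_graph_example.cpp",
    "content": "/*\nEditor's note: a hypothetical wiring sketch, NOT part of the OneFlow source tree.\nIt shows how InplaceLbiGraph is typically driven: the caller provides an operator\nlookup and a reachability predicate, and gets back the inplace pairs proven safe.\nop_name2op and SketchComputeSafeInplaceObns are made-up names; in OneFlow proper\nboth callbacks would be backed by the op graph.\n*/\n#include \"oneflow/core/graph/inplace_lbi_graph.h\"\n\nnamespace oneflow {\n\nInplaceObasInfo SketchComputeSafeInplaceObns(\n    const InplaceObasInfo& obas_info, const HashMap<std::string, const Operator*>& op_name2op,\n    const std::function<bool(const LogicalBlobId&, const std::string&)>&\n        IsLbiAllConsumerReachableToOpName) {\n  // Resolve an operator by name; every op referenced by obas_info must be present.\n  auto Op4OpName = [&](const std::string& op_name) -> const Operator* {\n    return op_name2op.at(op_name);\n  };\n  InplaceLbiGraph graph(obas_info, Op4OpName);\n  // Inherited from Graph<...>: dump the lbi graph as a .dot file for debugging.\n  graph.ToDotWithAutoFilePath();\n  // Keeps only edges free of const-ref, intra-op and inter-op mut-ref conflicts;\n  // safe pairs land in mut_inplace_oba_pairs or con_inplace_oba_pairs.\n  InplaceObasInfo safe_obas_info;\n  graph.ComputeSafeInplaceObns(&safe_obas_info, IsLbiAllConsumerReachableToOpName);\n  return safe_obas_info;\n}\n\n}  // namespace oneflow\n"
  },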
  {
    "path": "oneflow/core/graph/inplace_regst_graph.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/inplace_regst_graph.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::function<const RegstDescProto*(int64_t)> MakeGetterRegstDesc4RegstDescId(\n    const HashSet<const RegstDescProto*>& regst_descs) {\n  auto regst_desc_id2regst_desc = std::make_shared<HashMap<int64_t, const RegstDescProto*>>();\n  for (const auto* regst_desc : regst_descs) {\n    CHECK(regst_desc_id2regst_desc->emplace(regst_desc->regst_desc_id(), regst_desc).second);\n  }\n  return [regst_desc_id2regst_desc](int64_t regst_desc_id) -> const RegstDescProto* {\n    auto it = regst_desc_id2regst_desc->find(regst_desc_id);\n    return it == regst_desc_id2regst_desc->end() ? nullptr : it->second;\n  };\n}\n\n}  // namespace\n\nInplaceRegstGraph::InplaceRegstGraph(const HashSet<const RegstDescProto*>& regst_descs) {\n  auto RegstDesc4RegstDescId = MakeGetterRegstDesc4RegstDescId(regst_descs);\n  auto FindOrCreate = MakeMutFindOrCreateNode();\n  for (const RegstDescProto* regst_desc : regst_descs) {\n    if (regst_desc->has_hint_inplace_consumed_regst_desc_id()) {\n      const RegstDescProto* in_regst_desc =\n          RegstDesc4RegstDescId(regst_desc->hint_inplace_consumed_regst_desc_id());\n      if (in_regst_desc != nullptr) {\n        auto* edge = new InplaceRegstEdge();\n        AddAllocatedEdge(edge);\n        Connect<InplaceRegstNode, InplaceRegstEdge>(FindOrCreate(in_regst_desc), edge,\n                                                    FindOrCreate(regst_desc));\n      }\n    }\n  }\n}\n\nstd::function<InplaceRegstNode*(const RegstDescProto*)>\nInplaceRegstGraph::MakeMutFindOrCreateNode() {\n  auto regst_desc2node = std::make_shared<HashMap<const RegstDescProto*, InplaceRegstNode*>>();\n  return [regst_desc2node, this](const RegstDescProto* regst_desc) -> InplaceRegstNode* {\n    auto it = regst_desc2node->find(regst_desc);\n    if (it == regst_desc2node->end()) {\n      InplaceRegstNode* node = new InplaceRegstNode(regst_desc);\n      AddAllocatedNode(node);\n      it = regst_desc2node->emplace(regst_desc, node).first;\n    }\n    return it->second;\n  };\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/inplace_regst_graph.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_INPLACE_REGST_GRAPH_H_\n#define ONEFLOW_CORE_GRAPH_INPLACE_REGST_GRAPH_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/register/register_desc.pb.h\"\n#include \"oneflow/core/graph/graph.h\"\n\nnamespace oneflow {\n\nclass InplaceRegstEdge;\nclass InplaceRegstNode final : public Node<InplaceRegstNode, InplaceRegstEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(InplaceRegstNode);\n  explicit InplaceRegstNode(const RegstDescProto* regst_desc) : regst_desc_(regst_desc) {}\n  ~InplaceRegstNode() = default;\n\n  const RegstDescProto* regst_desc() const { return regst_desc_; }\n\n private:\n  const RegstDescProto* regst_desc_;\n};\n\nclass InplaceRegstEdge final : public Edge<InplaceRegstNode, InplaceRegstEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(InplaceRegstEdge);\n  InplaceRegstEdge() = default;\n  ~InplaceRegstEdge() = default;\n};\n\nclass InplaceRegstGraph final : public Graph<const InplaceRegstNode, const InplaceRegstEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(InplaceRegstGraph);\n  explicit InplaceRegstGraph(const HashSet<const RegstDescProto*>& regst_descs);\n\n private:\n  std::function<InplaceRegstNode*(const RegstDescProto*)> MakeMutFindOrCreateNode();\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_INPLACE_REGST_GRAPH_H_\n"
  },
  {
    "path": "oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/graph/nccl_send_recv_boxing_task_node.h\"\n#include \"oneflow/core/graph/boxing_task_graph.pb.h\"\n#include \"oneflow/core/job/placement.pb.h\"\n\nnamespace oneflow {\n\nvoid NcclSendRecvBoxingTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,\n                                      const Shape& logical_shape, const DataType& data_type,\n                                      const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp,\n                                      const ParallelDesc& src_parallel_desc,\n                                      const ParallelDesc& dst_parallel_desc,\n                                      const int64_t parallel_id, const ParallelDesc& parallel_desc,\n                                      const bool has_input, const bool has_output,\n                                      const std::string& stream_name) {\n  set_machine_id(machine_id);\n  set_thrd_id(thrd_id);\n  set_lbi(lbi);\n  logical_shape_ = logical_shape;\n  src_nd_sbp_ = src_nd_sbp;\n  dst_nd_sbp_ = dst_nd_sbp;\n  src_parallel_conf_ = src_parallel_desc.parallel_conf();\n  dst_parallel_conf_ = dst_parallel_desc.parallel_conf();\n  parallel_conf_ = parallel_desc.parallel_conf();\n  parallel_ctx_.set_parallel_id(parallel_id);\n  parallel_ctx_.set_parallel_num(parallel_desc.parallel_num());\n  has_input_ = has_input;\n  has_output_ = has_output;\n  data_type_ = data_type;\n  stream_name_ = stream_name;\n}\n\nvoid NcclSendRecvBoxingTaskNode::ProduceAllRegstsAndBindEdges() {\n  if (has_output_) {\n    std::shared_ptr<RegstDesc> out_regst = ProduceRegst(\"out\", true, 1, 1);\n    this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst(\"out\", out_regst); });\n  }\n  ProduceRegst(\"tmp\", true);\n}\n\nvoid NcclSendRecvBoxingTaskNode::ConsumeAllRegsts() {\n  this->ForEachInDataEdge(\n      [&](TaskEdge* in_edge) { ConsumeRegst(\"in\", SoleInDataEdge()->GetSoleRegst()); });\n}\n\nvoid NcclSendRecvBoxingTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  OperatorConf op_conf;\n  op_conf.set_name(\"System-Nccl-Send-Recv-Boxing-\" + NewUniqueId());\n  op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type())));\n  op_conf.set_stream_name_hint(stream_name_);\n  auto* nccl_send_recv_boxing_conf = op_conf.mutable_nccl_send_recv_boxing_conf();\n  *nccl_send_recv_boxing_conf->mutable_lbi() = lbi();\n  logical_shape_.ToProto(nccl_send_recv_boxing_conf->mutable_logical_shape());\n  nccl_send_recv_boxing_conf->set_data_type(data_type_);\n  *nccl_send_recv_boxing_conf->mutable_src_nd_sbp() = src_nd_sbp_;\n  *nccl_send_recv_boxing_conf->mutable_dst_nd_sbp() = dst_nd_sbp_;\n  *nccl_send_recv_boxing_conf->mutable_parallel_conf() = parallel_conf_;\n  *nccl_send_recv_boxing_conf->mutable_src_parallel_conf() = src_parallel_conf_;\n  
*nccl_send_recv_boxing_conf->mutable_dst_parallel_conf() = dst_parallel_conf_;\n  nccl_send_recv_boxing_conf->set_has_input(has_input_);\n  nccl_send_recv_boxing_conf->set_has_output(has_output_);\n  std::shared_ptr<Operator> sole_op = CHECK_JUST(ConstructOp(op_conf));\n  node->mut_op() = sole_op;\n  CHECK_JUST(sole_op->FillOpParallelDesc(parallel_conf_));\n  if (has_input_) { node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst(\"in\")); }\n  if (has_output_) {\n    std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n    out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));\n    node->BindBnWithRegst(sole_op->SoleObn(), out_regst);\n  }\n  node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst(\"tmp\"));\n  (node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nvoid NcclSendRecvBoxingTaskNode::InferProducedDataRegstTimeShape() {\n  auto out_regst = GetProducedRegst(\"out\");\n  if (out_regst != nullptr) { out_regst->mut_data_regst_time_shape()->reset(new Shape({1, 1})); }\n  auto tmp_regst = GetProducedRegst(\"tmp\");\n  tmp_regst->mut_data_regst_time_shape()->reset(new Shape({1, 1}));\n}\n\nMaybe<void> NcclSendRecvBoxingTaskNode::InitTransportTaskFromProto(\n    const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) {\n  CHECK_OR_RETURN(transport_task_proto.has_nccl_send_recv_boxing_task())\n      << \"not a serialized NcclSendRecvBoxingTaskNode. debug string: \"\n      << transport_task_proto.DebugString();\n  const auto& proto = transport_task_proto.nccl_send_recv_boxing_task();\n  logical_shape_ = Shape(proto.logical_shape());\n  data_type_ = proto.data_type();\n  src_nd_sbp_ = proto.src_nd_sbp();\n  dst_nd_sbp_ = proto.dst_nd_sbp();\n  src_parallel_conf_ = proto.src_parallel_conf();\n  dst_parallel_conf_ = proto.dst_parallel_conf();\n  parallel_conf_ = proto.parallel_conf();\n  parallel_ctx_ = proto.parallel_ctx();\n  has_input_ = proto.has_input();\n  has_output_ = proto.has_output();\n  stream_name_ = proto.stream_name();\n  return Maybe<void>::Ok();\n}\n\nvoid NcclSendRecvBoxingTaskNode::ToTransportTaskProto(\n    TransportTaskProto* transport_task_proto) const {\n  ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false);\n  auto* proto = transport_task_proto->mutable_nccl_send_recv_boxing_task();\n  logical_shape_.ToProto(proto->mutable_logical_shape());\n  proto->set_data_type(data_type_);\n  *proto->mutable_src_nd_sbp() = src_nd_sbp_;\n  *proto->mutable_dst_nd_sbp() = dst_nd_sbp_;\n  *proto->mutable_src_parallel_conf() = src_parallel_conf_;\n  *proto->mutable_dst_parallel_conf() = dst_parallel_conf_;\n  *proto->mutable_parallel_conf() = parallel_conf_;\n  *proto->mutable_parallel_ctx() = parallel_ctx_;\n  proto->set_has_input(has_input_);\n  proto->set_has_output(has_output_);\n  proto->set_stream_name(stream_name_);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/nccl_send_recv_boxing_task_node.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_NCCL_SEND_RECV_BOXING_TASK_NODE_H_\n#define ONEFLOW_CORE_GRAPH_NCCL_SEND_RECV_BOXING_TASK_NODE_H_\n\n#include \"oneflow/core/graph/transport_task_node.h\"\n#include \"oneflow/core/graph/boxing_task_graph.pb.h\"\n#include \"oneflow/core/job/placement.pb.h\"\n\nnamespace oneflow {\n\nclass NcclSendRecvBoxingTaskNode : public TransportTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NcclSendRecvBoxingTaskNode);\n  NcclSendRecvBoxingTaskNode() = default;\n  ~NcclSendRecvBoxingTaskNode() override = default;\n\n  void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,\n            const Shape& logical_shape, const DataType& data_type, const NdSbp& src_nd_sbp,\n            const NdSbp& dst_nd_sbp, const ParallelDesc& src_parallel_desc,\n            const ParallelDesc& dst_parallel_desc, const int64_t parallel_id,\n            const ParallelDesc& parallel_desc, const bool has_input, const bool has_output,\n            const std::string& stream_name);\n  TaskType GetTaskType() const override { return TaskType::kNcclSendRecvBoxing; }\n  const ParallelContext* parallel_ctx() const override { return &parallel_ctx_; }\n\n  Maybe<void> InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto,\n                                         const TaskGraphRebuildCtx& ctx) override;\n  void ToTransportTaskProto(TransportTaskProto*) const override;\n\n private:\n  void BuildExecGphAndRegst() override;\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() final;\n  void InferProducedDataRegstTimeShape() final;\n\n  Shape logical_shape_;\n  DataType data_type_;\n  NdSbp src_nd_sbp_;\n  NdSbp dst_nd_sbp_;\n  ParallelConf src_parallel_conf_;\n  ParallelConf dst_parallel_conf_;\n  ParallelConf parallel_conf_;\n  ParallelContext parallel_ctx_;\n  bool has_input_;\n  bool has_output_;\n  std::string stream_name_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_NCCL_SEND_RECV_BOXING_TASK_NODE_H_\n"
  },
  {
    "path": "oneflow/core/graph/node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/node.h\"\n\nnamespace oneflow {\n\nint64_t NewNodeId() {\n  static int64_t node_id = 0;\n  return node_id++;\n}\n\nint64_t NewEdgeId() {\n  static int64_t edge_id = 0;\n  return edge_id++;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/node.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_NODE_H_\n#define ONEFLOW_CORE_GRAPH_NODE_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/id_manager.h\"\n\nnamespace oneflow {\n\ntemplate<typename NodeType, typename EdgeType>\nvoid Connect(NodeType* src_node, EdgeType* edge, NodeType* dst_node) {\n  CHECK(src_node->out_edges_.insert(edge).second);\n  CHECK(dst_node->in_edges_.insert(edge).second);\n  CHECK(edge->src_node_ == nullptr);\n  CHECK(edge->dst_node_ == nullptr);\n  edge->src_node_ = src_node;\n  edge->dst_node_ = dst_node;\n}\n\ntemplate<typename EdgeType>\nvoid DisConnect(EdgeType* edge) {\n  CHECK_EQ(edge->src_node_->out_edges_.erase(edge), 1);\n  CHECK_EQ(edge->dst_node_->in_edges_.erase(edge), 1);\n  edge->src_node_ = nullptr;\n  edge->dst_node_ = nullptr;\n}\n\nint64_t NewNodeId();\nint64_t NewEdgeId();\n\ntemplate<typename NodeType, typename EdgeType>\nclass Edge {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Edge);\n  Edge() {\n    edge_id_ = NewEdgeId();\n    src_node_ = nullptr;\n    dst_node_ = nullptr;\n  }\n  virtual ~Edge() = default;\n\n  int64_t edge_id() const { return edge_id_; }\n\n  NodeType* src_node() const { return src_node_; }\n  NodeType* dst_node() const { return dst_node_; }\n\n  virtual std::string VisualStr() const { return \"\"; }\n\n private:\n  friend void Connect<NodeType, EdgeType>(NodeType* src_node, EdgeType* edge, NodeType* dst_node);\n  friend void DisConnect<EdgeType>(EdgeType* edge);\n\n  int64_t edge_id_;\n\n  NodeType* src_node_;\n  NodeType* dst_node_;\n};\n\ntemplate<typename NodeType, typename EdgeType>\nclass Node {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Node);\n  Node() { node_id_ = NewNodeId(); }\n  virtual ~Node() = default;\n\n  int64_t node_id() const { return node_id_; }\n  std::string node_id_str() const { return std::to_string(node_id_); }\n\n  EdgeType* SoleInEdge() const {\n    CHECK_EQ(in_edges_.size(), 1);\n    return *(in_edges_.begin());\n  }\n  EdgeType* SoleOutEdge() const {\n    CHECK_EQ(out_edges_.size(), 1);\n    return *(out_edges_.begin());\n  }\n\n  const std::unordered_set<EdgeType*>& in_edges() const { return in_edges_; }\n  const std::unordered_set<EdgeType*>& out_edges() const { return out_edges_; }\n\n  void ForEachNodeOnInEdge(std::function<void(NodeType*)> Handler) const {\n    for (EdgeType* edge : in_edges_) { Handler(edge->src_node()); }\n  }\n  void ForEachNodeOnOutEdge(std::function<void(NodeType*)> Handler) const {\n    for (EdgeType* edge : out_edges_) { Handler(edge->dst_node()); }\n  }\n  void ForEachNodeOnInOutEdge(std::function<void(NodeType*)> Handler) const {\n    ForEachNodeOnInEdge(Handler);\n    ForEachNodeOnOutEdge(Handler);\n  }\n  Maybe<void> ForEachInNode(std::function<Maybe<void>(NodeType*)> Handler) const {\n    for (EdgeType* edge : in_edges_) { JUST(Handler(edge->src_node())); }\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> 
ForEachOutNode(std::function<Maybe<void>(NodeType*)> Handler) const {\n    for (EdgeType* edge : out_edges_) { JUST(Handler(edge->dst_node())); }\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> ForEachInOutNode(std::function<Maybe<void>(NodeType*)> Handler) const {\n    JUST(ForEachInNode(Handler));\n    JUST(ForEachOutNode(Handler));\n    return Maybe<void>::Ok();\n  }\n\n  void ForEachNodeOnSortedInEdge(std::function<void(NodeType*)> Handler) const {\n    for (EdgeType* edge : sorted_in_edges_) { Handler(edge->src_node()); }\n  }\n  void ForEachNodeOnSortedOutEdge(std::function<void(NodeType*)> Handler) const {\n    for (EdgeType* edge : sorted_out_edges_) { Handler(edge->dst_node()); }\n  }\n  void ForEachNodeOnSortedInOutEdge(std::function<void(NodeType*)> Handler) const {\n    ForEachNodeOnSortedInEdge(Handler);\n    ForEachNodeOnSortedOutEdge(Handler);\n  }\n\n  void DisconnectAllEdges() {\n    // DisConnect() erases the edge from in_edges_/out_edges_, so iterate over\n    // snapshots to avoid invalidating the hash sets mid-traversal.\n    const std::vector<EdgeType*> in_edges(in_edges_.begin(), in_edges_.end());\n    const std::vector<EdgeType*> out_edges(out_edges_.begin(), out_edges_.end());\n    for (EdgeType* edge : in_edges) { DisConnect(edge); }\n    for (EdgeType* edge : out_edges) { DisConnect(edge); }\n  }\n\n  virtual std::string VisualStr() const { return \"\"; }\n\n  void SortInOutEdges(std::function<bool(const EdgeType* lhs, const EdgeType* rhs)> LessThan) {\n    sorted_in_edges_.assign(in_edges_.begin(), in_edges_.end());\n    sorted_out_edges_.assign(out_edges_.begin(), out_edges_.end());\n    std::sort(sorted_in_edges_.begin(), sorted_in_edges_.end(), LessThan);\n    std::sort(sorted_out_edges_.begin(), sorted_out_edges_.end(), LessThan);\n  }\n\n private:\n  friend void Connect<NodeType, EdgeType>(NodeType* src_node, EdgeType* edge, NodeType* dst_node);\n  friend void DisConnect<EdgeType>(EdgeType* edge);\n\n  int64_t node_id_;\n  HashSet<EdgeType*> in_edges_;\n  HashSet<EdgeType*> out_edges_;\n  std::vector<EdgeType*> sorted_in_edges_;\n  std::vector<EdgeType*> sorted_out_edges_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_NODE_H_\n"
  },
  {
    "path": "oneflow/core/graph/normal_forward_compute_task_node.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_\n#define ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_\n\n#include \"oneflow/core/graph/compute_task_node.h\"\n\nnamespace oneflow {\n\nsize_t RegstNum4Op(const Operator& sole_op);\n\nclass NormalForwardCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NormalForwardCompTaskNode);\n  NormalForwardCompTaskNode() = default;\n  ~NormalForwardCompTaskNode() = default;\n\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n\n  TaskType GetTaskType() const override { return TaskType::kNormalForward; }\n\n private:\n  void ProduceOutRegstByNameAndBlockNum(const std::string& name, size_t mem_block_num);\n  void BuildExecGphAndRegst() override;\n  void BuildExecGphStructAndBindInRegst();\n  void BuildOutRegst();\n  void BuildTmp7BufRegsts();\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_\n"
  },
  {
    "path": "oneflow/core/graph/op_graph.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/job/job_builder.h\"\n#include \"oneflow/core/job/local_sig_infer_hint.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include \"oneflow/core/auto_parallel/algorithm_util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n\nnamespace oneflow {\n\nbool OpEdge::NeedBoxing() const {\n  if (src_node()->parallel_desc_sym() != dst_node()->parallel_desc_sym()) { return true; }\n  if (src_node()->parallel_desc().parallel_num() == 1) { return false; }\n  for (const auto& lbi : *lbis_) {\n    Shape src_reduced_hierarchy;\n    Shape dst_reduced_hierarchy;\n    NdSbp src_reduced_nd_sbp;\n    NdSbp dst_reduced_nd_sbp;\n\n    InOutParallelDimReduce(*src_node()->parallel_desc().hierarchy(),\n                           *dst_node()->parallel_desc().hierarchy(), src_node()->NdSbp4Lbi(lbi),\n                           dst_node()->NdSbp4Lbi(lbi), &src_reduced_hierarchy,\n                           &dst_reduced_hierarchy, &src_reduced_nd_sbp, &dst_reduced_nd_sbp,\n                           src_node()->LogicalBlobDesc4Lbi(lbi).shape());\n    if (src_reduced_hierarchy != dst_reduced_hierarchy\n        || src_reduced_nd_sbp != dst_reduced_nd_sbp) {\n      // Not one to one\n      return true;\n    }\n  }\n  return false;\n}\n\nstd::string OpEdge::VisualStr() const {\n  std::string str;\n  int32_t idx = 0;\n  for (const LogicalBlobId& lbi : *lbis_) {\n    if (idx++ > 0) { str += \"\\\\n\"; }\n    str += lbi.blob_name() + \":\";\n    str += src_node()->LogicalBlobDesc4Lbi(lbi).shape().ToString();\n  }\n  return str;\n}\n\nconst SbpParallel& OpNode::SbpParallel4BnInOp(const std::string& bn_in_op) const {\n  return *CHECK_JUST(op().SbpParallel4BnInOp(bn_in_op));\n}\n\nconst SbpParallel& OpNode::SbpParallel4Lbi(const LogicalBlobId& lbi) const {\n  auto it = lbi2nd_sbp_.find(lbi);\n  CHECK(it != lbi2nd_sbp_.end());\n  CHECK_EQ(it->second.sbp_parallel_size(), 1);\n  return it->second.sbp_parallel(0);\n}\n\nconst NdSbp& OpNode::NdSbp4BnInOp(const std::string& bn_in_op) const {\n  return *CHECK_JUST(op().NdSbp4BnInOp(bn_in_op));\n}\n\nconst NdSbp& OpNode::NdSbp4Lbi(const LogicalBlobId& lbi) const {\n  auto it = lbi2nd_sbp_.find(lbi);\n  CHECK(it != lbi2nd_sbp_.end());\n  return it->second;\n}\n\nOpNode::OpNode(Symbol<ParallelDesc> parallel_desc, const OperatorConf& op_conf)\n    : parallel_desc_(parallel_desc),\n      op_(CHECK_JUST(ConstructOp(op_conf, parallel_desc->device_type()))),\n      ibns_(op_->input_bns().begin(), op_->input_bns().end()) {\n  CHECK_JUST(op_->FillOpParallelDesc(parallel_desc.shared_from_symbol()));\n}\n\nstd::string OpNode::VisualStr() const {\n  std::string str = op().op_name();\n  {\n    for (int64_t machine_id : 
parallel_desc().sorted_machine_ids()) {\n      const std::string dev_type = *CHECK_JUST(DeviceTag4DeviceType(parallel_desc().device_type()));\n\n      std::string parallel_desc_str = std::to_string(machine_id) + \":\" + dev_type + \":\";\n      const auto& dev_phy_ids = parallel_desc().sorted_dev_phy_ids(machine_id);\n      parallel_desc_str += std::to_string(dev_phy_ids.front());\n      if (dev_phy_ids.back() > dev_phy_ids.front()) {\n        parallel_desc_str += \"-\" + std::to_string(dev_phy_ids.back());\n      }\n      str += \"\\\\n\" + parallel_desc_str;\n    }\n  }\n  auto GetTimeShapeStr = [&](const Shape& shape, const std::string& prefix) {\n    std::string time_shape_str = prefix + \":\";\n    time_shape_str += shape.ToString();\n    return time_shape_str;\n  };\n  if (in_edges().empty() == false) {\n    str +=\n        \"\\\\n\"\n        + GetTimeShapeStr(*CHECK_JUST(op().GetInputBlobFastestTimeShape()), \"in_blob_time_shape\");\n  }\n  str += \"\\\\n\" + GetTimeShapeStr(*CHECK_JUST(op().GetOpTimeShape()), \"op_time_shape\");\n  return str;\n}\n\nconst BlobDesc& OpNode::LogicalBlobDesc4Lbi(const LogicalBlobId& lbi) const {\n  const OpNode& producer = ProducerOpNode4Lbi(lbi);\n  const int32_t index = CHECK_JUST(producer.op().GetOutputIndex(lbi));\n  const BlobDesc* blob_desc = CHECK_JUST(producer.op().GetLogicalBlobDescPtr4OutputIndex(index));\n  return *blob_desc;\n}\n\nconst OpNode& OpNode::SrcNode4Ibn(const std::string& bn_in_op) const {\n  return *MutSrcNode4Ibn(bn_in_op);\n}\n\nOpNode* OpNode::MutSrcNode4Ibn(const std::string& bn_in_op) const {\n  const LogicalBlobId& lbi = op().BnInOp2Lbi(bn_in_op);\n  CHECK(ibns_.find(bn_in_op) != ibns_.end());\n  return MutSrcNode4InputLbi(lbi);\n}\n\nconst OpNode& OpNode::ProducerOpNode4Lbi(const LogicalBlobId& lbi) const {\n  const OpNode* producer = MutSrcNode4InputLbi(lbi);\n  if (producer == nullptr) { producer = this; }\n  return *producer;\n}\n\nOpNode* OpNode::MutSrcNode4InputLbi(const LogicalBlobId& lbi) const {\n  auto it = lbi2source_node_.find(lbi);\n  if (it == lbi2source_node_.end()) {\n    return nullptr;\n  } else {\n    return it->second;\n  }\n}\n\nbool OpNode::IsTimeShapeIdentity() const {\n  std::shared_ptr<const Shape> in_shape = CHECK_JUST(op().GetInputBlobFastestTimeShape());\n  if (!in_shape) { return true; }\n  std::shared_ptr<const Shape> op_shape = CHECK_JUST(op().GetOpTimeShape());\n  return *in_shape == *op_shape;\n}\n\nvoid OpNode::InitLbi2SourceNode() {\n  for (OpEdge* edge : in_edges()) {\n    for (const LogicalBlobId& lbi : edge->lbis()) {\n      CHECK(lbi2source_node_.emplace(lbi, edge->src_node()).second);\n    }\n  }\n}\n\nvoid OpNode::InitLbi2NdSbp() {\n  const auto Update = [&](const PbRpf<std::string>& bns) {\n    for (const auto& bn : bns) {\n      const LogicalBlobId& lbi = op().BnInOp2Lbi(bn);\n      const NdSbp& nd_sbp = NdSbp4BnInOp(bn);\n      auto it = lbi2nd_sbp_.find(lbi);\n      if (it == lbi2nd_sbp_.end()) {\n        lbi2nd_sbp_[lbi] = nd_sbp;\n      } else {\n        CHECK(it->second == nd_sbp);\n      }\n    }\n  };\n  Update(op().input_bns());\n  Update(op().output_bns());\n}\n\nMaybe<OpGraph> OpGraph::New(const Job& job) {\n  const auto& op_graph = std::make_shared<OpGraph>();\n  JUST(op_graph->Init(job));\n  return op_graph;\n}\n\nMaybe<void> OpGraph::Init(const Job& job) {\n  InitNodes(job);\n  op_name2op_node_.reserve(job.net().op_size());\n  ForEachNode([&](OpNode* node) {\n    CHECK(op_name2op_node_.emplace(node->op().op_name(), node).second)\n        << \"op_name: \" << 
node->op().op_name();\n  });\n  InitEdges();\n  InitProducerOpName2CtrlConsumerOpNames(job);\n  CheckIsDAG();\n  ForEachNode([](OpNode* node) { node->InitLbi2SourceNode(); });\n  InferBlobLastUsed();\n  InferTimeShape();\n  {\n    LazyMode::Guard enable_lazy_mode_guard(true);\n    JUST(InferLogicalBlobDesc(job));\n  }\n  return Maybe<void>::Ok();\n}\n\nvoid OpGraph::CheckIsDAG() const {\n  CHECK(!FindFirstNontrivialSCC());\n  auto ForEachIn = [&](OpNode* node, const std::function<void(OpNode*)>& Handler) {\n    ForEachDataAndCtrlInNode(node, Handler);\n  };\n  auto ForEachOut = [&](OpNode* node, const std::function<void(OpNode*)>& Handler) {\n    ForEachDataAndCtrlOutNode(node, Handler);\n  };\n  CHECK(!FindFirstNontrivialSCC(ForEachIn, ForEachOut));\n}\n\nnamespace {\n\nstd::function<Symbol<ParallelDesc>(const std::string&)> MakeGetterParallelDesc4OpName(\n    const Job& job) {\n  const Placement& placement = job.placement();\n  auto op_name2parallel_desc = std::make_shared<HashMap<std::string, Symbol<ParallelDesc>>>();\n  op_name2parallel_desc->reserve(job.net().op_size());\n  for (const auto& placement_group : placement.placement_group()) {\n    const ParallelConf& parallel_conf = placement_group.parallel_conf();\n    Symbol<ParallelDesc> parallel_desc = SymbolOf(ParallelDesc(parallel_conf));\n    for (const std::string& op_name : placement_group.op_set().op_name()) {\n      CHECK(op_name2parallel_desc->emplace(op_name, parallel_desc).second)\n          << \"op_name: \" << op_name;\n    }\n  }\n  return [op_name2parallel_desc](const std::string& op_name) {\n    return op_name2parallel_desc->at(op_name);\n  };\n}\n\n}  // namespace\n\nvoid OpGraph::InitNodes(const Job& job) {\n  auto ParallelDesc4OpName = MakeGetterParallelDesc4OpName(job);\n  for (const auto& op_conf : job.net().op()) {\n    op_names_.emplace_back(op_conf.name());\n    OpNode* node = new OpNode(ParallelDesc4OpName(op_conf.name()), op_conf);\n    AddAllocatedNode(node);\n  }\n}\n\nvoid OpGraph::InitEdges() {\n  HashMap<LogicalBlobId, OpNode*> lbi2producer;\n  HashMap<std::string, std::shared_ptr<HashMap<LogicalBlobId, std::string>>>\n      producer_op_name2lbi2obn;\n  ForEachNode([&](OpNode* op_node) {\n    for (const auto& obn : op_node->op().output_bns()) {\n      const auto& lbi = op_node->op().BnInOp2Lbi(obn);\n      CHECK(lbi2producer.emplace(lbi, op_node).second);\n      auto& lbi2obn = producer_op_name2lbi2obn[op_node->op().op_name()];\n      if (!lbi2obn) { lbi2obn.reset(new HashMap<LogicalBlobId, std::string>()); }\n      CHECK(lbi2obn->emplace(lbi, obn).second);\n    }\n  });\n  ForEachNode([&](OpNode* op_node) {\n    HashMap<std::string, HashSet<LogicalBlobId>> producer_op_name2lbis;\n    std::shared_ptr<HashMap<LogicalBlobId, std::vector<std::string>>> consumer_lbi2ibns(\n        new HashMap<LogicalBlobId, std::vector<std::string>>);\n    op_node->input_index2producer_and_output_index_.reserve(op_node->op().input_bns().size());\n    for (const auto& ibn : op_node->op().input_bns()) {\n      const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn);\n      producer_op_name2lbis[lbi.op_name()].insert(lbi);\n      (*consumer_lbi2ibns)[lbi].emplace_back(ibn);\n      auto producer_it = lbi2producer.find(lbi);\n      CHECK(producer_it != lbi2producer.end()) << \"producer not found: \" << GenLogicalBlobName(lbi);\n      const int32_t output_index = CHECK_JUST(producer_it->second->op().GetOutputIndex(lbi));\n      op_node->input_index2producer_and_output_index_.emplace_back(producer_it->second,\n                       
                                            output_index);\n    }\n    for (const auto& pair : producer_op_name2lbis) {\n      std::shared_ptr<std::vector<LogicalBlobId>> lbis(\n          new std::vector<LogicalBlobId>({pair.second.begin(), pair.second.end()}));\n      const auto it = producer_op_name2lbi2obn.find(pair.first);\n      CHECK(it != producer_op_name2lbi2obn.end()) << \"producer_op_name: \" << pair.first;\n      const auto& lbi2obn = it->second;\n      auto producer_it = lbi2producer.find(lbis->front());\n      CHECK(producer_it != lbi2producer.end())\n          << \"producer not found: \" << GenLogicalBlobName(lbis->front());\n      Connect(producer_it->second, NewEdge(lbis, lbi2obn, consumer_lbi2ibns), op_node);\n    }\n  });\n}\n\nvoid OpGraph::InitProducerOpName2CtrlConsumerOpNames(const Job& job) {\n  for (const auto& op_conf : job.net().op()) {\n    for (const auto& ctrl_in_op_name : op_conf.ctrl_in_op_name()) {\n      auto* consumer_op_names = &producer_op_name2ctrl_consumer_op_names_[ctrl_in_op_name];\n      CHECK(consumer_op_names->emplace(op_conf.name()).second);\n    }\n  }\n}\n\nvoid OpGraph::InferBlobLastUsed() const {\n  HashSet<LogicalBlobId> visited_lbi;\n  for (auto iter = op_names_.rbegin(); iter != op_names_.rend(); iter++) {\n    Operator* op = op_name2op_node_.at(*iter)->mut_op();\n    auto* map = op->mut_blob_last_used_signature()->mutable_bn_in_op2blob_last_used();\n    const auto InferLastUsed = [&](const std::string& bn_in_op) {\n      (*map)[bn_in_op] = visited_lbi.insert(op->BnInOp2Lbi(bn_in_op)).second;\n    };\n    for (const auto& obn : op->output_bns()) { InferLastUsed(obn); }\n    for (const auto& ibn : op->input_bns()) { InferLastUsed(ibn); }\n  }\n}\n\nvoid OpGraph::InferTimeShape() const {\n  TopoForEachNode([&](OpNode* op_node) {\n    auto GetInputBlobTimeShape = [&](int32_t index) -> Maybe<const Shape> {\n      CHECK_LT_OR_RETURN(index, op_node->input_index2producer_and_output_index_.size());\n      return op_node->input_index2producer_and_output_index_.at(index).first->op().GetOpTimeShape();\n    };\n    CHECK_JUST(op_node->mut_op()->FillInputBlobTimeShape(GetInputBlobTimeShape));\n    CHECK_JUST(op_node->mut_op()->InferOpTimeShapeIf());\n  });\n}\n\nvoid OpGraph::InferOpNodeNdSbpSignature(OpNode* op_node,\n                                        const NdSbpSignature& nd_sbp_sig_conf) const {\n  HashMap<std::string, NdSbpInferHint> ibn2nd_sbp_infer_hint;\n  for (const std::string& ibn : op_node->op().input_bns()) {\n    const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn);\n    OpNode* producer = op_node->MutSrcNode4Ibn(ibn);\n    const std::string& producer_lbn = *CHECK_JUST(producer->op().obn4lbi(lbi));\n    const ParallelDesc* parallel_desc =\n        CHECK_JUST(producer->op().GetParallelDesc4BnInOp(producer_lbn)).get();\n    const BlobDesc* logical_blob_desc = &producer->LogicalBlobDesc4Lbi(lbi);\n    const NdSbp* nd_sbp = &producer->NdSbp4Lbi(lbi);\n    ibn2nd_sbp_infer_hint.emplace(ibn, NdSbpInferHint(parallel_desc, logical_blob_desc, nd_sbp));\n  }\n  const auto NdSbpInferHint4Ibn = [&](const std::string& bn) -> Maybe<const NdSbpInferHint*> {\n    auto it = ibn2nd_sbp_infer_hint.find(bn);\n    CHECK_OR_RETURN(it != ibn2nd_sbp_infer_hint.end());\n    return Maybe<const NdSbpInferHint*>(&it->second);\n  };\n  CHECK_JUST(op_node->mut_op()->InferNdSbpSignatureIf(nd_sbp_sig_conf, op_node->parallel_desc(),\n                                                      NdSbpInferHint4Ibn));\n  op_node->InitLbi2NdSbp();\n}\n\nMaybe<void> 
OpGraph::InferOpNodeLocalSignature(OpNode* op_node, bool is_local_conf) const {\n  HashMap<std::string, LocalSigInferHint> ibn2local_sig_infer_hint;\n  for (const std::string& ibn : op_node->op().input_bns()) {\n    const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn);\n    const auto* producer = op_node->MutSrcNode4Ibn(ibn);\n    const ParallelDesc* parallel_desc = &producer->parallel_desc();\n    const auto& producer_obn = *JUST(producer->op().obn4lbi(lbi));\n    const auto& opt_local_parallel = *JUST(producer->op().OptLocalParallel4BnInOp(producer_obn));\n    LocalSigInferHint infer_ctx(parallel_desc, opt_local_parallel.has_local_parallel());\n    ibn2local_sig_infer_hint.emplace(ibn, infer_ctx);\n  }\n  const auto& LocalSigInferHint4Ibn =\n      [&](const std::string& ibn) -> Maybe<const LocalSigInferHint*> {\n    const auto& iter = ibn2local_sig_infer_hint.find(ibn);\n    CHECK_OR_RETURN(iter != ibn2local_sig_infer_hint.end()) << \"input blob not found. ibn: \" << ibn;\n    return &iter->second;\n  };\n  JUST(op_node->mut_op()->InferLocalSignatureIf(LocalSigInferHint4Ibn, is_local_conf,\n                                                op_node->parallel_desc()));\n  return Maybe<void>::Ok();\n}\n\nconst OpNode* OpGraph::OpNode4OpName(const std::string& op_name) const {\n  const auto& op_node_it = op_name2op_node_.find(op_name);\n  if (op_node_it == op_name2op_node_.end()) { return nullptr; }\n  return op_node_it->second;\n}\n\nMaybe<void> OpGraph::InferLogicalBlobDesc(const Job& job) const {\n  JobParallelViewConf job_parallel_view_conf(job.job_parallel_view_conf());\n  JUST(TopoForEachNodeWithErrorCaptured([&](OpNode* op_node) -> Maybe<void> {\n    auto LogicalBlobDesc4InputIndex = [&](int32_t index) -> Maybe<const BlobDesc> {\n      CHECK_LT_OR_RETURN(index, op_node->input_index2producer_and_output_index_.size());\n      const auto& producer_info = op_node->input_index2producer_and_output_index_.at(index);\n      return producer_info.first->op().GetLogicalBlobDesc4OutputIndex(producer_info.second);\n    };\n    JUST(op_node->mut_op()->FillLogicalInBlobDesc(LogicalBlobDesc4InputIndex));\n    // Infer ParallelSignature\n    JUST(op_node->mut_op()->InferParallelSignatureIf());\n    // Infer local_signature\n    bool is_local_conf = false;\n    {\n      const auto& op_name2is_local = job_parallel_view_conf.op_name2is_local_parallel_view();\n      const auto& iter = op_name2is_local.find(op_node->op().op_name());\n      if (iter != op_name2is_local.end()) { is_local_conf = iter->second; }\n    }\n    JUST(InferOpNodeLocalSignature(op_node, is_local_conf));\n    NdSbpSignature nd_sbp_sig_conf;\n    {\n      const auto& op_name2nd_sbp_sig_conf = job_parallel_view_conf.op_name2nd_sbp_signature_conf();\n      const auto& iter = op_name2nd_sbp_sig_conf.find(op_node->op().op_name());\n      if (iter != op_name2nd_sbp_sig_conf.end()) {\n        nd_sbp_sig_conf = NdSbpSignature(iter->second);\n        if (op_node->parallel_desc().hierarchy()->NumAxes() == 1) {\n          const auto& op_name2sbp_sig_conf = job_parallel_view_conf.op_name2sbp_signature_conf();\n          const auto& op_name2sbp_sig_conf_it = op_name2sbp_sig_conf.find(op_node->op().op_name());\n          CHECK_OR_RETURN(op_name2sbp_sig_conf_it != op_name2sbp_sig_conf.end())\n              << op_node->op().op_name();\n          CheckSbpSignatureAndNdSbpEquals(SbpSignature(op_name2sbp_sig_conf_it->second),\n                                          NdSbpSignature(iter->second));\n        } else {\n          // do nothing\n        
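  // A hierarchy with more than one axis has no flattened SbpSignature\n          // counterpart to cross-check, so the NdSbpSignature conf is used as-is.\n        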
}\n      }\n    }\n    InferOpNodeNdSbpSignature(op_node, nd_sbp_sig_conf);\n    JUST(op_node->mut_op()->InferLogicalOutBlobDescsIf());\n    return Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\nint64_t OpGraph::GetParallelNum(const std::string& op_name) const {\n  return op_name2op_node_.at(op_name)->parallel_desc().parallel_num();\n}\n\nconst SbpParallel& OpGraph::GetSbpParallel(const std::string& op_name,\n                                           const LogicalBlobId& lbi) const {\n  return op_name2op_node_.at(GetOpNameKey(op_name, lbi))\n      ->SbpParallel4Lbi(GetLogicalBlobIdKey(op_name, lbi));\n}\n\nconst NdSbp& OpGraph::GetNdSbp(const std::string& op_name, const LogicalBlobId& lbi) const {\n  return op_name2op_node_.at(GetOpNameKey(op_name, lbi))\n      ->NdSbp4Lbi(GetLogicalBlobIdKey(op_name, lbi));\n}\n\nDataType OpGraph::GetBlobDataType(const LogicalBlobId& lbi) const {\n  return op_name2op_node_.at(lbi.op_name())\n      ->LogicalBlobDesc4Lbi(GetLogicalBlobIdKey(lbi.op_name(), lbi))\n      .data_type();\n}\n\nconst BlobDesc& OpGraph::GetLogicalBlobDesc(const LogicalBlobId& lbi) const {\n  return op_name2op_node_.at(lbi.op_name())\n      ->LogicalBlobDesc4Lbi(GetLogicalBlobIdKey(lbi.op_name(), lbi));\n}\n\nstd::string OpGraph::GetOpNameKey(const std::string& op_name, const LogicalBlobId& lbi) const {\n  if (op_name2op_node_.find(op_name) != op_name2op_node_.end()) {\n    return op_name;\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nLogicalBlobId OpGraph::GetLogicalBlobIdKey(const std::string& op_name,\n                                           const LogicalBlobId& lbi) const {\n  if (op_name2op_node_.find(op_name) != op_name2op_node_.end()) {\n    return lbi;\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nvoid OpGraph::ForEachDataAndCtrlInNode(OpNode* node,\n                                       const std::function<void(OpNode*)>& Handler) const {\n  node->ForEachNodeOnInEdge(Handler);\n  for (const auto& ctrl_in_op_name : node->op().op_conf().ctrl_in_op_name()) {\n    CHECK(op_name2op_node_.find(ctrl_in_op_name) != op_name2op_node_.end())\n        << \" cannot find ctrl_in_op_name: [\" << ctrl_in_op_name << \"] of op: [\"\n        << node->op().op_name() << \"] in OpGraph. 
\";\n    Handler(op_name2op_node_.at(ctrl_in_op_name));\n  }\n}\n\nvoid OpGraph::ForEachDataAndCtrlOutNode(OpNode* node,\n                                        const std::function<void(OpNode*)>& Handler) const {\n  node->ForEachNodeOnOutEdge(Handler);\n  const auto& op_name_it = producer_op_name2ctrl_consumer_op_names_.find(node->op().op_name());\n  if (op_name_it == producer_op_name2ctrl_consumer_op_names_.end()) { return; }\n  for (const std::string& ctrl_consumer_op_name : op_name_it->second) {\n    CHECK(op_name2op_node_.find(ctrl_consumer_op_name) != op_name2op_node_.end())\n        << \" cannot find ctrl_consumer_op_name: [\" << ctrl_consumer_op_name << \"] of op: [\"\n        << node->op().op_name() << \"] in OpGraph.\";\n    Handler(op_name2op_node_.at(ctrl_consumer_op_name));\n  }\n}\n\nvoid OpGraph::TopoForEachNodeWithCtrlEdge(const std::function<void(OpNode*)>& NodeHandler) const {\n  auto OpGraphForEachInDataAndCtrlNode = [&](OpNode* node,\n                                             const std::function<void(OpNode*)>& Handler) {\n    ForEachDataAndCtrlInNode(node, Handler);\n  };\n  auto OpGraphForEachOutDataAndCtrlNode = [&](OpNode* node,\n                                              const std::function<void(OpNode*)>& Handler) {\n    ForEachDataAndCtrlOutNode(node, Handler);\n  };\n  TopoForEachNode(OpGraphForEachInDataAndCtrlNode, OpGraphForEachOutDataAndCtrlNode, NodeHandler);\n}\n\nstd::function<bool(const std::string&, const std::string&)>\nOpGraph::MakePredicatorIsOpNameDataOrCtrlReachable() const {\n  auto IsDataOrCtrlReachable = MakePredicatorIsDataOrCtrlReachable();\n  return [IsDataOrCtrlReachable, this](const std::string& lhs, const std::string& rhs) {\n    const auto& src_node_it = op_name2op_node_.find(lhs);\n    if (src_node_it == op_name2op_node_.end()) { return false; }\n    const auto& dst_node_it = op_name2op_node_.find(rhs);\n    if (dst_node_it == op_name2op_node_.end()) { return false; }\n    return (src_node_it->second == dst_node_it->second)\n           || IsDataOrCtrlReachable(src_node_it->second, dst_node_it->second);\n  };\n}\n\nstd::function<bool(const OpNode*, const OpNode*)> OpGraph::MakePredicatorIsDataOrCtrlReachable()\n    const {\n  auto _1 = std::placeholders::_1;\n  auto _2 = std::placeholders::_2;\n  return MakePredicatorIsReachable(DataOrCtrlSourceNodes(),\n                                   std::bind(&OpGraph::ForEachDataAndCtrlInNode, this, _1, _2),\n                                   std::bind(&OpGraph::ForEachDataAndCtrlOutNode, this, _1, _2));\n}\n\nstd::list<OpNode*> OpGraph::DataOrCtrlSourceNodes() const {\n  std::list<OpNode*> ret;\n  ForEachNode([&](OpNode* op_node) {\n    size_t in_edges_cnt = 0;\n    ForEachDataAndCtrlInNode(op_node, [&](OpNode*) { ++in_edges_cnt; });\n    if (in_edges_cnt == 0) { ret.emplace_back(op_node); }\n  });\n  return ret;\n}\n\nvoid OpGraph::DumpLogicalBlobDesc(Job* job) const {\n  auto* helper = job->mutable_helper();\n  ForEachNode([&](const OpNode* node) {\n    for (const auto& obn : node->op().output_bns()) {\n      const auto& lbi = node->op().BnInOp2Lbi(obn);\n      node->LogicalBlobDesc4Lbi(lbi).ToProto(\n          &(*helper->mutable_lbn2logical_blob_desc())[GenLogicalBlobName(lbi)]);\n    }\n  });\n}\n\nvoid OpGraph::DumpNdSbpSignature(Job* job) const {\n  ForEachNode([&](const OpNode* node) -> void {\n    (*job->mutable_job_parallel_view_conf()\n          ->mutable_op_name2nd_sbp_signature_conf())[node->op().op_name()] =\n        *CHECK_JUST(node->op().nd_sbp_signature());\n    if 
(node->parallel_desc().hierarchy()->NumAxes() == 1) {\n      (*job->mutable_job_parallel_view_conf()\n            ->mutable_op_name2sbp_signature_conf())[node->op().op_name()] = node->sbp_signature();\n    }\n  });\n}\n\nvoid OpGraph::DumpArgSignature(Job* job) const {\n  ForEachNode([&](const OpNode* node) {\n    auto* op_arg_signature =\n        &(*job->mutable_helper()->mutable_op_name2arg_signature())[node->op().op_name()];\n    for (const auto& ibn : node->op().input_bns()) {\n      const auto& lbi = node->op().BnInOp2Lbi(ibn);\n      (*op_arg_signature->mutable_bn_in_op2lbi())[ibn] = lbi;\n    }\n    for (const auto& obn : node->op().output_bns()) {\n      const auto& lbi = node->op().BnInOp2Lbi(obn);\n      (*op_arg_signature->mutable_bn_in_op2lbi())[obn] = lbi;\n    }\n  });\n}\n\nMaybe<void> OpGraph::ForEachOpNode(const std::function<Maybe<void>(const OpNode&)>& DoEach) const {\n  HashMap<LogicalBlobId, bool> visited;\n  for (const auto& op_name : op_names_) {\n    const OpNode& op_node = *op_name2op_node_.at(op_name);\n    for (const auto& ibn : op_node.op().input_bns()) {\n      const auto& lbi = op_node.op().BnInOp2Lbi(ibn);\n      CHECK_OR_RETURN(visited[lbi]) << \"input blob '\" << ibn << \"' is not defined\\n\"\n                                    << lbi.DebugString() << \"\\n==== op_conf ====\\n\"\n                                    << op_node.op().op_conf().DebugString();\n    }\n    for (const auto& obn : op_node.op().output_bns()) {\n      const auto& lbi = op_node.op().BnInOp2Lbi(obn);\n      CHECK_OR_RETURN(!visited[lbi]) << \"output blob '\" << obn << \"' is already defined\\n\"\n                                     << lbi.DebugString() << \"\\n==== op_conf ====\\n\"\n                                     << op_node.op().op_conf().DebugString();\n      visited[lbi] = true;\n    }\n    JUST(DoEach(op_node));\n  }\n  return Maybe<void>::Ok();\n}\n\nstd::function<bool(const OpNode* src, const OpNode* dst)> OpGraph::CreatePredicatorIsReachable()\n    const {\n  return MakePredicatorIsReachable();\n}\n\n// Print the graph with SBP in order\nvoid OpGraph::PrintSBPGraphDebugInfo() const {\n  // test debug\n  std::cout << \"Get Into Print Op Graph\" << std::endl;\n  // Collect op_node\n  std::vector<OpNode*> NodeList;\n  ForEachNode([&](OpNode* op_node) { NodeList.push_back(op_node); });\n\n  // test debug\n  std::cout << \"Deciding order\" << std::endl;\n  // Decide the order to visit the ops\n  std::vector<int32_t> order;\n  auto_parallel::DecideOrder(NodeList, order, [&](OpNode* a, OpNode* b) {\n    return a->op().op_name().compare(b->op().op_name()) > 0;\n  });\n  std::vector<int32_t> str_order;\n\n  // test debug\n  std::cout << \"Finish deciding order\" << std::endl;\n\n  for (int32_t i = 0; i < NodeList.size(); i++) {\n    OpNode* op_node = NodeList[order[i]];\n    std::cout << op_node->op().op_name() << \" (^_^):\" << std::endl;\n    // Sort before printing\n    const auto& op_input_bns = op_node->op().input_bns();\n    auto comp = [](const std::string& a, const std::string& b) { return a.compare(b) > 0; };\n    auto_parallel::DecideOrder(op_input_bns, str_order, comp);\n    // Print out SBP information for input blobs\n    for (int32_t j : str_order) {\n      const auto& ibn = op_input_bns[j];\n      auto producer_node = op_node->MutSrcNode4Ibn(ibn);\n      std::cout << \"Pre Op:\" << producer_node->op().op_name() << \": \" << ibn;\n      const auto& this_sbp_parallel = op_node->NdSbp4BnInOp(ibn);\n      std::cout << \", \" << NdSbpToString(this_sbp_parallel);\n      const 
auto input_blob_modifier = op_node->op().InputBlobModifier4Ibn(ibn);\n      bool is_same_sbp = input_blob_modifier.has_is_mutable() && input_blob_modifier.is_mutable();\n      if (is_same_sbp) std::cout << \", same SBP\";\n      std::cout << \", \" << op_node->LogicalBlobDesc4Lbi(op_node->op().BnInOp2Lbi(ibn)).shape();\n      std::cout << std::endl;\n    }\n    // Sort before printing\n    const auto& op_output_bns = op_node->op().output_bns();\n    auto_parallel::DecideOrder(op_output_bns, str_order, comp);\n    // Print out SBP information for output blobs\n    for (int32_t j : str_order) {\n      const auto& obn = op_output_bns[j];\n      std::cout << \"Out Op:\" << obn;\n      const auto& this_sbp_parallel = op_node->NdSbp4BnInOp(obn);\n      std::cout << \", \" << NdSbpToString(this_sbp_parallel);\n      std::cout << \", \" << op_node->LogicalBlobDesc4Lbi(op_node->op().BnInOp2Lbi(obn)).shape();\n      std::cout << std::endl;\n    }\n    std::cout << std::endl;\n  }\n}\n\nOpGraphSingletonGuard::OpGraphSingletonGuard(const Job& job) {\n  // Create the OpGraph singleton and write debug artifacts if debug mode is enabled.\n  Singleton<OpGraph>::New(job);\n  const JobDesc& job_desc = GlobalJobDesc();\n  if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {\n    TeePersistentLogStream::Create(StrCat(\"optimized_job\", job_desc.job_id()))->Write(job);\n    Singleton<OpGraph>::Get()->ToDotWithFilePath(\n        \"optimized_dlnet_\" + std::to_string(job_desc.job_id()) + \"_op_graph.dot\");\n  }\n}\n\nOpGraphSingletonGuard::~OpGraphSingletonGuard() { Singleton<OpGraph>::Delete(); }\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/op_graph.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_OP_GRAPH_H_\n#define ONEFLOW_CORE_GRAPH_OP_GRAPH_H_\n\n#include \"oneflow/core/graph/graph.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/local_parallel.pb.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n\nnamespace oneflow {\nnamespace auto_parallel {\nclass SbpConstructor;\n}\n\nclass OpEdge;\nclass OpGraph;\n\nclass OpNode final : public Node<OpNode, OpEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(OpNode);\n  explicit OpNode(Symbol<ParallelDesc> parallel_desc, const OperatorConf& op_conf);\n  ~OpNode() = default;\n\n  // Getters\n  bool IsTimeShapeIdentity() const;\n  const Operator& op() const { return *op_; }\n  std::shared_ptr<const Operator> shared_op() const { return op_; }\n  const ParallelDesc& parallel_desc() const { return *parallel_desc_; }\n  Symbol<ParallelDesc> parallel_desc_sym() const { return parallel_desc_; }\n  const SbpSignature& sbp_signature() const { return *CHECK_JUST(op().sbp_signature()); }\n  const NdSbpSignature& nd_sbp_signature() const { return *CHECK_JUST(op().nd_sbp_signature()); }\n  const SbpParallel& SbpParallel4Lbi(const LogicalBlobId& lbi) const;\n  const SbpParallel& SbpParallel4BnInOp(const std::string& bn_in_op) const;\n  const NdSbp& NdSbp4Lbi(const LogicalBlobId& lbi) const;\n  const NdSbp& NdSbp4BnInOp(const std::string& bn_in_op) const;\n  const BlobDesc& LogicalBlobDesc4Lbi(const LogicalBlobId& lbi) const;\n  const OpNode& ProducerOpNode4Lbi(const LogicalBlobId& lbi) const;\n  const OpNode& SrcNode4Ibn(const std::string& bn_in_op) const;\n\n  std::string VisualStr() const override;\n\n private:\n  friend class OpGraph;\n  friend class OpEdge;\n  friend class auto_parallel::SbpConstructor;\n\n  // Setters\n  Operator* mut_op() { return op_.get(); }\n  OpNode* MutSrcNode4Ibn(const std::string& bn_in_op) const;\n  OpNode* MutSrcNode4InputLbi(const LogicalBlobId& lbi) const;\n  void InitLbi2SourceNode();\n  void InitLbi2NdSbp();\n\n  Symbol<ParallelDesc> parallel_desc_;\n  std::shared_ptr<Operator> op_;\n  HashSet<std::string> ibns_;\n  HashMap<LogicalBlobId, OpNode*> lbi2source_node_;\n  HashMap<LogicalBlobId, NdSbp> lbi2nd_sbp_;\n  std::vector<std::pair<const OpNode*, int32_t>> input_index2producer_and_output_index_;\n};\n\nclass OpEdge final : public Edge<OpNode, OpEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(OpEdge);\n  explicit OpEdge(std::shared_ptr<std::vector<LogicalBlobId>> lbis,\n                  std::shared_ptr<HashMap<LogicalBlobId, std::string>> lbi2obn,\n                  std::shared_ptr<HashMap<LogicalBlobId, std::vector<std::string>>> lbi2ibns)\n      : lbis_(std::move(lbis)), lbi2obn_(std::move(lbi2obn)), lbi2ibns_(std::move(lbi2ibns)) {}\n  ~OpEdge() override = default;\n\n  // Getters\n  const std::vector<LogicalBlobId>& lbis() const { return *lbis_; }\n  
const HashMap<LogicalBlobId, std::string>& lbi2obn() const { return *lbi2obn_; }\n  const HashMap<LogicalBlobId, std::vector<std::string>>& lbi2ibns() const { return *lbi2ibns_; }\n\n  bool NeedBoxing() const;\n  std::string VisualStr() const override;\n\n private:\n  std::shared_ptr<std::vector<LogicalBlobId>> lbis_;\n  std::shared_ptr<HashMap<LogicalBlobId, std::string>> lbi2obn_;\n  std::shared_ptr<HashMap<LogicalBlobId, std::vector<std::string>>> lbi2ibns_;\n};\n\nclass OpGraph final : public Graph<OpNode, OpEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(OpGraph);\n  explicit OpGraph(const Job& job) { CHECK_JUST(Init(job)); }\n  explicit OpGraph() = default;\n  ~OpGraph() override = default;\n\n  static Maybe<OpGraph> New(const Job& job);\n\n  Maybe<void> ForEachOpNode(const std::function<Maybe<void>(const OpNode&)>& DoEach) const;\n\n  const OpNode* OpNode4OpName(const std::string& name) const;\n\n  int64_t GetParallelNum(const std::string& op_name) const;\n  const SbpParallel& GetSbpParallel(const std::string& op_name, const LogicalBlobId& lbi) const;\n  const NdSbp& GetNdSbp(const std::string& op_name, const LogicalBlobId& lbi) const;\n  DataType GetBlobDataType(const LogicalBlobId& lbi) const;\n  const BlobDesc& GetLogicalBlobDesc(const LogicalBlobId& lbi) const;\n\n  std::function<bool(const std::string&, const std::string&)>\n  MakePredicatorIsOpNameDataOrCtrlReachable() const;\n\n  void ForEachDataAndCtrlInNode(OpNode* node, const std::function<void(OpNode*)>& Handler) const;\n  void ForEachDataAndCtrlOutNode(OpNode* node, const std::function<void(OpNode*)>& Handler) const;\n  void TopoForEachNodeWithCtrlEdge(const std::function<void(OpNode*)>& NodeHandler) const;\n  // NOTE(chengcheng): For topo for each with ctrl edges. OpEdge is ONLY data edge.\n  std::list<OpNode*> DataOrCtrlSourceNodes() const;\n\n  void DumpLogicalBlobDesc(Job* job) const;\n  void DumpArgSignature(Job* job) const;\n  void DumpNdSbpSignature(Job* job) const;\n\n  Maybe<void> Init(const Job& job);\n\n  std::function<bool(const OpNode* src, const OpNode* dst)> CreatePredicatorIsReachable() const;\n  // Print the graph with SBP in order\n  void PrintSBPGraphDebugInfo() const;\n\n private:\n  friend class auto_parallel::SbpConstructor;\n\n  void InitNodes(const Job& job);\n  void InitEdges();\n  void InitProducerOpName2CtrlConsumerOpNames(const Job& job);\n  void CheckIsDAG() const;\n  void InferBlobLastUsed() const;\n  void InferTimeShape() const;\n  void InferOpNodeNdSbpSignature(OpNode* op_node, const NdSbpSignature& nd_sbp_sig_conf) const;\n  Maybe<void> InferOpNodeLocalSignature(OpNode* op_node, bool is_local_conf) const;\n  Maybe<void> InferLogicalBlobDesc(const Job& job) const;\n  std::string GetOpNameKey(const std::string& op_name, const LogicalBlobId& lbi) const;\n  LogicalBlobId GetLogicalBlobIdKey(const std::string& op_name, const LogicalBlobId& lbi) const;\n\n  std::function<bool(const OpNode*, const OpNode*)> MakePredicatorIsDataOrCtrlReachable() const;\n\n  HashMap<std::string, OpNode*> op_name2op_node_;\n  std::list<std::string> op_names_;\n  HashMap<std::string, HashSet<std::string>> producer_op_name2ctrl_consumer_op_names_;\n};\n\nclass OpGraphSingletonGuard {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(OpGraphSingletonGuard);\n  explicit OpGraphSingletonGuard(const Job& job);\n  ~OpGraphSingletonGuard();\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_OP_GRAPH_H_\n"
  },
  {
    "path": "oneflow/core/graph/plan_task_graph.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/graph/plan_task_graph.h\"\n\nnamespace oneflow {\n\nPlanTaskGraph::PlanTaskGraph(const Plan& plan) : plan_(&plan) {\n  InitNodes();\n  InitEdges();\n}\n\nvoid PlanTaskGraph::InitNodes() {\n  for (const auto& task : plan_->task()) {\n    PlanTaskNode* plan_task_node = new PlanTaskNode(task);\n    task_id2plan_task_node_.insert({task.task_id(), plan_task_node});\n    AddAllocatedNode(plan_task_node);\n  }\n}\n\nvoid PlanTaskGraph::InitEdges() {\n  for (const auto& task_id_and_plan_task_node : task_id2plan_task_node_) {\n    PlanTaskNode* producer_node = task_id_and_plan_task_node.second;\n    for (const auto& pair : producer_node->task_proto()->produced_regst_desc()) {\n      for (int64_t consumer_task_id : pair.second.consumer_task_id()) {\n        PlanTaskNode* consumer_node = CHECK_JUST(MapAt(task_id2plan_task_node_, consumer_task_id));\n        TryConnect(producer_node, consumer_node);\n      }\n    }\n  }\n}\n\nvoid PlanTaskGraph::TryConnect(PlanTaskNode* src, PlanTaskNode* dst) {\n  if (edges_.insert({src, dst}).second) { Connect(src, NewEdge(), dst); }\n}\n\nconst TaskProto* PlanTaskGraph::TaskProto4TaskId(int64_t task_id) const {\n  return CHECK_JUST(MapAt(task_id2plan_task_node_, task_id))->task_proto();\n}\n}  // namespace oneflow\n"
  },
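  {
    "path": "docs/editorial_examples/edge_dedup_sketch.cpp",
    "content": "// Editorial sketch, NOT part of the OneFlow sources; the path and all names here\n// are hypothetical. It mirrors PlanTaskGraph::TryConnect above: a set of\n// (src, dst) pairs guarantees at most one edge per ordered node pair, no matter\n// how many produced regsts connect the same producer and consumer.\n#include <iostream>\n#include <set>\n#include <utility>\n\nint main() {\n  // Stands in for HashSet<std::pair<PlanTaskNode*, PlanTaskNode*>> edges_\n  std::set<std::pair<int, int>> edges;\n  auto TryConnect = [&](int src, int dst) {\n    // insert().second is true only the first time this (src, dst) pair is seen\n    if (edges.insert({src, dst}).second) {\n      std::cout << \"connect \" << src << \" -> \" << dst << std::endl;\n    }\n  };\n  TryConnect(1, 2);\n  TryConnect(1, 2);  // duplicate: no second edge is created\n  TryConnect(2, 1);  // the reverse direction is a different edge\n  std::cout << \"edge count: \" << edges.size() << std::endl;  // prints 2\n  return 0;\n}\n"
  },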
  {
    "path": "oneflow/core/graph/plan_task_graph.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_PLAN_TASK_GRAPH_H_\n#define ONEFLOW_CORE_GRAPH_PLAN_TASK_GRAPH_H_\n\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/core/graph/graph.h\"\n\nnamespace oneflow {\n\nclass PlanTaskNode;\n\nclass PlanTaskEdge final : public Edge<PlanTaskNode, PlanTaskEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PlanTaskEdge);\n  PlanTaskEdge() = default;\n  ~PlanTaskEdge() = default;\n};\n\nclass PlanTaskNode final : public Node<PlanTaskNode, PlanTaskEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PlanTaskNode);\n  explicit PlanTaskNode(const TaskProto& task_proto) : task_proto_(&task_proto) {}\n  ~PlanTaskNode() = default;\n\n  const TaskProto* task_proto() const { return task_proto_; }\n  int64_t task_id() const { return task_proto_->task_id(); }\n  int64_t chain_id() const { return task_proto_->chain_id(); }\n  int64_t order_in_chain() const { return task_proto_->order_in_chain(); }\n\n private:\n  const TaskProto* task_proto_;\n};\n\nclass PlanTaskGraph : public Graph<const PlanTaskNode, PlanTaskEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PlanTaskGraph);\n  explicit PlanTaskGraph(const Plan& plan);\n  virtual ~PlanTaskGraph() = default;\n\n  const TaskProto* TaskProto4TaskId(int64_t task_id) const;\n  const Plan& plan() const { return *plan_; }\n\n protected:\n  void InitNodes();\n  void InitEdges();\n  void TryConnect(PlanTaskNode* src, PlanTaskNode* dst);\n\n  const Plan* plan_;\n  HashMap<int64_t, PlanTaskNode*> task_id2plan_task_node_;\n  HashSet<std::pair<PlanTaskNode*, PlanTaskNode*>> edges_;\n};\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_PLAN_TASK_GRAPH_H_\n"
  },
  {
    "path": "oneflow/core/graph/slice_boxing_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/slice_boxing_task_node.h\"\n\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/graph/task_graph_rebuild_ctx.h\"\n\nnamespace oneflow {\n\nvoid SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice,\n                               const SliceBoxingTaskMode mode, int64_t machine_id,\n                               int64_t thrd_id) {\n  out_slice_ = out_slice;\n  out_shape_ = out_slice.shape();\n  mode_ = mode;\n  set_machine_id(machine_id);\n  set_thrd_id(thrd_id);\n  set_lbi(lbi);\n}\n\nvoid SliceBoxingTaskNode::ProduceAllRegstsAndBindEdges() {\n  std::shared_ptr<RegstDesc> out_regst_desc = ProduceRegst(\"out\", true);\n  this->ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst(\"out\", out_regst_desc); });\n  ProduceRegst(\"tmp\", true);\n}\n\nvoid SliceBoxingTaskNode::ConsumeAllRegsts() {\n  HashMap<const TaskEdge*, int64_t> edge2order_;\n  FOR_RANGE(int64_t, i, 0, ordered_in_data_edges_.size()) {\n    edge2order_.emplace(ordered_in_data_edges_.at(i), i);\n  }\n  int64_t in_data_edge_cnt = 0;\n  ForEachInDataEdge([&](TaskEdge* edge) {\n    const auto order_it = edge2order_.find(edge);\n    CHECK(order_it != edge2order_.end());\n    ConsumeRegst(\"in_\" + std::to_string(order_it->second), edge->GetSoleRegst());\n    in_data_edge_cnt += 1;\n  });\n  CHECK_EQ(in_data_edge_cnt, ordered_in_data_edges_.size());\n}\n\nvoid SliceBoxingTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  std::shared_ptr<Operator> op = CHECK_JUST(ConstructOp(GetBoxingOpConf()));\n  node->mut_op() = op;\n  FOR_RANGE(size_t, i, 0, op->input_bns().size()) {\n    const std::string& ibn = op->input_bns().Get(i);\n    CHECK_EQ(GenUnRepeatedBn(ibn).second, i);\n    node->BindBnWithRegst(ibn, GetSoleConsumedRegst(\"in_\" + std::to_string(i)));\n  }\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  out_regst->AddLbi(lbi());\n  node->BindBnWithRegst(op->SoleObn(), out_regst);\n  node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst(\"tmp\"));\n  (node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nvoid SliceBoxingTaskNode::InferProducedDataRegstTimeShape() {\n  NaiveInferProducedDataRegstTimeShape();\n}\n\nvoid SliceBoxingTaskNode::SetInDataEdgeSlice(const TaskEdge* edge, const TensorSliceView& slice) {\n  CHECK(in_data_edge2slice_.emplace(edge, slice).second);\n  ordered_in_data_edges_.emplace_back(edge);\n}\n\nvoid SliceBoxingTaskNode::ConnectToSrcNodeWithSlice(TaskNode* src, TaskEdge* edge,\n                                                    const TensorSliceView& slice) {\n  edge->AddLbi(lbi());\n  Connect<TaskNode>(src, edge, this);\n  SetInDataEdgeSlice(edge, slice);\n}\n\nvoid SliceBoxingTaskNode::SetOutShape(const Shape& shape) { out_shape_ = shape; }\n\nOperatorConf SliceBoxingTaskNode::GetBoxingOpConf() {\n  OperatorConf op_conf{};\n  
op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(device_type())));\n  SliceBoxingConf boxing_conf{};\n  *boxing_conf.mutable_lbi() = lbi();\n  out_slice_.ToProto(boxing_conf.mutable_out_slice());\n  out_shape_.ToProto(boxing_conf.mutable_out_shape());\n  for (const TaskEdge* edge : ordered_in_data_edges_) {\n    in_data_edge2slice_.at(edge).ToProto(boxing_conf.mutable_in_slice()->Add());\n  }\n  if (mode_ == kSliceBoxingTaskModeCopy) {\n    op_conf.set_name(\"System-Boxing-BoxingCopy-\" + NewUniqueId());\n    SliceBoxingCopyOpConf* conf = op_conf.mutable_slice_boxing_copy_conf();\n    *conf->mutable_slice_boxing_conf() = boxing_conf;\n  } else if (mode_ == kSliceBoxingTaskModeAdd) {\n    op_conf.set_name(\"System-Boxing-BoxingAdd-\" + NewUniqueId());\n    SliceBoxingAddOpConf* conf = op_conf.mutable_slice_boxing_add_conf();\n    *conf->mutable_slice_boxing_conf() = boxing_conf;\n  } else {\n    UNIMPLEMENTED();\n  }\n  return op_conf;\n}\n\nMaybe<void> SliceBoxingTaskNode::InitTransportTaskFromProto(\n    const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) {\n  CHECK_OR_RETURN(transport_task_proto.has_slice_boxing_task())\n      << \"not a serialized SliceBoxingTaskNode. debug string: \"\n      << transport_task_proto.DebugString();\n  const auto& proto = transport_task_proto.slice_boxing_task();\n  for (const auto& pair : proto.in_data_edge_uid2slice()) {\n    const auto* edge = JUST(ctx.TaskEdge4Uid(pair.first));\n    CHECK_OR_RETURN(in_data_edge2slice_.emplace(edge, pair.second).second)\n        << \"redundant edge found. edge_uid: \" << pair.first;\n  }\n  for (int64_t edge_uid : proto.ordered_in_data_edge_uid()) {\n    ordered_in_data_edges_.push_back(JUST(ctx.TaskEdge4Uid(edge_uid)));\n  }\n  out_slice_ = TensorSliceView(proto.out_slice());\n  out_shape_ = Shape(proto.out_shape());\n  mode_ = proto.mode();\n  return Maybe<void>::Ok();\n}\n\nvoid SliceBoxingTaskNode::ToTransportTaskProto(TransportTaskProto* transport_task_proto) const {\n  ToProto(transport_task_proto->mutable_task_proto(), /*check=*/false);\n  auto* proto = transport_task_proto->mutable_slice_boxing_task();\n  for (const auto& pair : in_data_edge2slice_) {\n    int64_t edge_uid = reinterpret_cast<int64_t>(pair.first);\n    pair.second.ToProto(&(*proto->mutable_in_data_edge_uid2slice())[edge_uid]);\n  }\n  for (const auto* edge : ordered_in_data_edges_) {\n    proto->add_ordered_in_data_edge_uid(reinterpret_cast<int64_t>(edge));\n  }\n  out_slice_.ToProto(proto->mutable_out_slice());\n  out_shape_.ToProto(proto->mutable_out_shape());\n  proto->set_mode(mode_);\n}\n\n}  // namespace oneflow\n"
  },
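  {
    "path": "docs/editorial_examples/edge_uid_roundtrip_sketch.cpp",
    "content": "// Editorial sketch, NOT part of the OneFlow sources; the path and all names here\n// are hypothetical. SliceBoxingTaskNode::ToTransportTaskProto above serializes\n// edge pointers as opaque int64 uids via reinterpret_cast, and deserialization\n// resolves those uids through a rebuild context (TaskGraphRebuildCtx::TaskEdge4Uid)\n// rather than casting them back. This shows the same uid -> object indirection.\n#include <cassert>\n#include <cstdint>\n#include <unordered_map>\n#include <vector>\n\nstruct Edge { int payload; };\n\nint main() {\n  Edge a{1}, b{2};\n  // \"Serialize\": pointer identity becomes an opaque int64 uid.\n  std::vector<int64_t> serialized{reinterpret_cast<int64_t>(&a), reinterpret_cast<int64_t>(&b)};\n  // Rebuild context: uid -> live object (here the same objects, for brevity).\n  std::unordered_map<int64_t, Edge*> uid2edge;\n  uid2edge[serialized[0]] = &a;\n  uid2edge[serialized[1]] = &b;\n  // \"Deserialize\": every uid must resolve through the context, never via a cast.\n  for (int64_t uid : serialized) { assert(uid2edge.count(uid) == 1); }\n  assert(uid2edge.at(serialized[1])->payload == 2);\n  return 0;\n}\n"
  },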
  {
    "path": "oneflow/core/graph/slice_boxing_task_node.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_SLICE_BOXING_TASK_NODE_H_\n#define ONEFLOW_CORE_GRAPH_SLICE_BOXING_TASK_NODE_H_\n\n#include \"oneflow/core/graph/boxing_task_graph.pb.h\"\n#include \"oneflow/core/graph/transport_task_node.h\"\n#include \"oneflow/core/register/tensor_slice_view.h\"\n#include \"oneflow/core/memory/memory_zone.h\"\n\nnamespace oneflow {\n\nclass SliceBoxingTaskNode final : public TransportTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SliceBoxingTaskNode);\n  SliceBoxingTaskNode() = default;\n  ~SliceBoxingTaskNode() override = default;\n\n  void Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice, SliceBoxingTaskMode mode,\n            int64_t machine_id, int64_t thrd_id);\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  TaskType GetTaskType() const override { return TaskType::kSliceBoxing; }\n  void SetInDataEdgeSlice(const TaskEdge* edge, const TensorSliceView& slice);\n  void ConnectToSrcNodeWithSlice(TaskNode* src, TaskEdge* edge, const TensorSliceView& slice);\n  void SetOutShape(const Shape& shape);\n\n  Maybe<void> InitTransportTaskFromProto(const TransportTaskProto& transport_task_proto,\n                                         const TaskGraphRebuildCtx& ctx) override;\n  void ToTransportTaskProto(TransportTaskProto*) const override;\n\n private:\n  void BuildExecGphAndRegst() override;\n  void InferProducedDataRegstTimeShape() override;\n  OperatorConf GetBoxingOpConf();\n\n  HashMap<const TaskEdge*, TensorSliceView> in_data_edge2slice_;\n  std::vector<const TaskEdge*> ordered_in_data_edges_;\n  TensorSliceView out_slice_;\n  Shape out_shape_;\n  SliceBoxingTaskMode mode_ = kSliceBoxingTaskModeInvalid;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_SLICE_BOXING_TASK_NODE_H_\n"
  },
  {
    "path": "oneflow/core/graph/straighten_nodes.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <memory>\n#include <string>\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/straighten_nodes.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/graph/task_graph.h\"\n#include \"oneflow/core/graph/task_node.h\"\n#include \"oneflow/core/graph/transport_task_node.h\"\n#include \"oneflow/core/job/job_conf.pb.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/job/task.pb.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/register/runtime_register_desc.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nenum TaskClassifier : int {\n  kWaitingOverlapNode = 0,\n  kWaitingMainComputation = 1,\n  kRunASAP = 2,\n  kRunALAP = 3\n};\n\n// The difference between a descending order and its corresponding ascending order\nstatic const int kDiff4AscendDescend = 100;\n\nclass TopoStruct {\n public:\n  TaskNode* node = nullptr;\n  int32_t min_layer = -1;\n  int32_t tributary_layer = -1;\n  bool on_trunk = false;\n  int32_t counter = 0;\n  int32_t min_distance2overlap = -1;\n  int64_t memory_increment = -1;\n  int32_t exceed_time = -1;\n  int32_t min_lifetime = -1;\n  int64_t memory_volume = -1;\n  int32_t max_layer = -1;\n  TaskClassifier task_classifier;\n  std::string key;\n  // We can have some other nodes in it for example\n  // SbpNode<NdSbpSignature>* node;\n  // SbpEdge<NdSbpSignature>* node;\n  // Or we can omit all the pointers and leave all the useful parameters.\n\n  int32_t ComputeMinLayer(HashMap<TaskNode*, TopoStruct>* task_node2topo_struct,\n                          std::map<std::string, std::vector<TopoStruct*>>* key2topo_structs);\n  // Drop down the tributary layer\n  void DropTributaryLayer(int32_t upper_bound);\n\n  void SpreadTributaryLayer(HashMap<TaskNode*, TopoStruct>* task_node2topo_struct);\n  void ComputeMaxLayer(HashMap<TaskNode*, TopoStruct>* task_node2topo_struct);\n\n  void SpreadTrunk(HashMap<TaskNode*, TopoStruct>* task_node2topo_struct);\n\n  // The minimum computation distance from the beginning of this op to the next overlap node\n  int32_t GetMinDistance2Overlap(HashMap<TaskNode*, TopoStruct>* task_node2topo_struct);\n\n  // Memory increment = (memory of out registers) - (memory of in registers)\n  void ComputeMemoryIncrement();\n\n  // Exceed time = time of cpu - time of gpu\n  // For most operators, the execution time on gpu exceed the execution time on cpu.\n  // However, overlap is needed if time of cpu > time of gpu.\n  void ComputeExceedTime();\n\n  // Memory volume is memory * lifetime, but we might change the formula\n  void ComputeMemoryVolume();\n\n  // TODO: We might design more deciding parameter and choose a right combination of them in the\n  // future.\n\n  // deciding 
parameter\n  // kTributaryLayerAscend = 0,     // small tributary layers go first\n  // kDistanceToOverlapAscend = 1,  // small minimum distance to overlap go first\n  // kLayerAscend = 2,              // first in first out\n  // kMemoryIncrementAscend = 3,    // small memory increment go first\n  // kExceedTimeAscend = 4,         // small exceed time go first\n  // kTributaryLayerDescend = 100,     // large tributary layers go first\n  // kDistanceToOverlapDescend = 101,  // long distance to overlap go first\n  // kLayerDescend = 102,              // last in first out\n  // kMemoryIncrementDescend = 103,    // large memory increment go first\n  // kExceedTimeDescend = 104,         // large exceed time go first\n  int64_t GetDecidingParameter(StraightenOrder so) const;\n};\n\nstatic StraightenAlgorithmTag sat;\n\n// NOTE: Leave these code for debugging in the future\n// static std::vector<StraightenOrder> decide_parameters({ParseIntegerFromEnv(\"Parameter0\", 3),\n//                                                        ParseIntegerFromEnv(\"Parameter1\", 0),\n//                                                        ParseIntegerFromEnv(\"Parameter2\", 3)});\n// The best parameter set for saving time is {102, 100}\n// The best parameter set for saving memory is {3, 0}\nstatic std::vector<StraightenOrder> decide_parameters;\n\n// move the head from source to target\nvoid MoveFrontBetweenMaps(std::map<int32_t, TopoStruct*>& source,\n                          std::map<int32_t, TopoStruct*>& target) {\n  if (!source.empty()) {\n    const auto& front = source.begin();\n    target[front->first] = front->second;\n    source.erase(front);\n  }\n};\n\nbool ShouldRunASAP(TaskType task_type) {\n  // They are sorted according to frequency of occurrences\n  switch (task_type) {\n    // We mark the number of occurrences in bert\n    case TaskType::kDeviceTick:                  // 38\n    case TaskType::kTick:                        // 8\n    case TaskType::kSrcSubsetTick:               // 6\n    case TaskType::kDstSubsetTick:               // 6\n    case TaskType::kCriticalSectionWaitTick:     // 4\n    case TaskType::kWaitAndSendIds:              // 2\n    case TaskType::kPack:                        // 0\n    case TaskType::kUnpack:                      // 0\n    case TaskType::kRepeat:                      // 0\n    case TaskType::kAcc:                         // 0\n    case TaskType::kSourceTick:                  // 0\n    case TaskType::kAccTick:                     // 0\n    case TaskType::kAccCtrlTick:                 // ?\n    case TaskType::kCase:                        // 0\n    case TaskType::kEsac:                        // 0\n    case TaskType::kReentrantLock: return true;  // 0\n    default: return false;\n  }\n}\n\nbool IsTransferNode(TaskType task_type) {\n  // return task_type == 12 || task_type == 13 || (48 <= task_type && task_type <= 64);\n  // They are sorted according to frequency of occurrences\n  switch (task_type) {\n    // We mark the number of occurrences in bert\n    case TaskType::kCollectiveBoxingGeneric:        // 76\n    case TaskType::kNcclSendRecvBoxing:             // ?\n    case TaskType::kCopyHd:                         // 27\n    case TaskType::kSliceBoxing:                    // 16\n    case TaskType::kCopyCommNet:                    // 12\n    case TaskType::kCollectiveBoxingPack:           // 8\n    case TaskType::kCollectiveBoxingUnpack:         // 8\n    case TaskType::kBoxingZeros:                    // 3\n    case TaskType::kDistributeConcat:               
// 0\n    case TaskType::kDistributeSplit:                // 0\n    case TaskType::kBoxingIdentity:                 // 0\n    case TaskType::kDecodeH2D:                      // 0\n    case TaskType::kSspVariableProxy: return true;  // 0\n    default: return false;\n  }\n}\n\n// Classifier for the set according to the task type\nTaskClassifier GetTaskClassifier(const TaskNode* node, bool nccl_use_compute_stream) {\n  // Check task.pb.h for detail\n  // They are sorted according to frequency of judgement\n  // frequency of judgement = the number of occurrences / the times of judgement\n  TaskType task_type = node->GetTaskType();\n  if (task_type == TaskType::kNormalForward) {\n    const auto& op_conf = dynamic_cast<const CompTaskNode*>(node)->op()->op_conf();\n    if (sat == StraightenAlgorithmTag::kOverlap4CpuGpu && ShortGpuTime(op_conf)) {\n      return TaskClassifier::kWaitingOverlapNode;\n    } else {\n      return TaskClassifier::kWaitingMainComputation;\n    }\n  }\n  if (IsTransferNode(task_type)) {\n    if (sat == StraightenAlgorithmTag::kCompressMemory && nccl_use_compute_stream) {\n      // Overlap is not the first consideration, memory is\n      return TaskClassifier::kWaitingMainComputation;\n    } else {\n      return TaskClassifier::kWaitingOverlapNode;\n    }\n  }\n  if (task_type == TaskType::kCallbackNotify) { return TaskClassifier::kRunALAP; }\n  if (ShouldRunASAP(task_type)) { return TaskClassifier::kRunASAP; }\n  CHECK(false) << \"Unclassified or invalid task type (\" << task_type << \") showing up\";\n  // Throw a kRunASAP which means ignoring this node in the algorithm\n  return TaskClassifier::kRunASAP;\n}\n\nint32_t MaxProducerMinLayer(HashMap<TaskNode*, TopoStruct>* task_node2topo_struct,\n                            std::map<std::string, std::vector<TopoStruct*>>* key2topo_structs,\n                            TaskNode* node) {\n  int32_t max_min_layer = 0;\n  node->ForEachNodeOnInEdge([&](TaskNode* in) {\n    max_min_layer = std::max(max_min_layer, task_node2topo_struct->at(in).ComputeMinLayer(\n                                                task_node2topo_struct, key2topo_structs));\n  });\n  return max_min_layer + 1;\n}\n\nint32_t TopoStruct::ComputeMinLayer(\n    HashMap<TaskNode*, TopoStruct>* task_node2topo_struct,\n    std::map<std::string, std::vector<TopoStruct*>>* key2topo_structs) {\n  // Directly return the value if computed\n  if (min_layer > -1) { return min_layer; }\n  auto transport_task_node = dynamic_cast<TransportTaskNode*>(node);\n  if (transport_task_node) {\n    // Only compute the minimum layer for this transport node\n    min_layer = MaxProducerMinLayer(task_node2topo_struct, key2topo_structs, node);\n    // Generate the key to determine the same task nodes\n    // Since the key is connected with the min_layer for transport nodes\n    key = transport_task_node->lbi().ShortDebugString() + \"MinLayer:\" + std::to_string(min_layer);\n    // Gather all the task nodes with the same key\n    (*key2topo_structs)[key].push_back(this);\n  } else {\n    // Compute the minimum layer for all the nodes with the same key simultaneously\n    int32_t max_min_layer = -1;\n    for (auto& curr_topo_struct : key2topo_structs->at(key)) {\n      max_min_layer = std::max(\n          max_min_layer,\n          MaxProducerMinLayer(task_node2topo_struct, key2topo_structs, curr_topo_struct->node));\n    }\n    for (auto& curr_topo_struct : key2topo_structs->at(key)) {\n      curr_topo_struct->min_layer = max_min_layer;\n    }\n  }\n  return min_layer;\n}\n\n// Drop 
down the maximum layer with the minimum layer from consumer\nvoid TopoStruct::DropTributaryLayer(int32_t upper_bound) {\n  if (upper_bound < tributary_layer || tributary_layer < 0) { tributary_layer = upper_bound; }\n}\n\n// Should initialize the counter to be the number of out edges\n// Compute maximum layer for tributaries\nvoid TopoStruct::SpreadTributaryLayer(HashMap<TaskNode*, TopoStruct>* task_node2topo_struct) {\n  if (counter || min_layer <= 0) { return; }\n  int32_t producer_max_lay = 0;\n  if (on_trunk) {\n    producer_max_lay = min_layer - 1;\n  } else {\n    // On a tributary, the operator could be run later.\n    producer_max_lay = tributary_layer;\n  }\n  node->ForEachNodeOnInEdge([&](TaskNode* in) {\n    auto& topo_struct_in = task_node2topo_struct->at(in);\n    topo_struct_in.DropTributaryLayer(producer_max_lay);\n    --topo_struct_in.counter;\n    if (topo_struct_in.counter == 0) { topo_struct_in.SpreadTributaryLayer(task_node2topo_struct); }\n  });\n  // Reduce counter to -1 to avoid visiting again\n  counter--;\n}\n\nvoid TopoStruct::ComputeMaxLayer(HashMap<TaskNode*, TopoStruct>* task_node2topo_struct) {\n  node->ForEachNodeOnOutEdge([&](TaskNode* out) {\n    max_layer = std::max(max_layer, task_node2topo_struct->at(out).min_layer);\n  });\n}\n\n// Judge if this node is on the trunk\n// If so, judge it for its producer/upstream nodes\nvoid TopoStruct::SpreadTrunk(HashMap<TaskNode*, TopoStruct>* task_node2topo_struct) {\n  // Skip it if this node is already judged.\n  if (on_trunk) { return; }\n  CHECK_GE(min_layer, 0) << \"TopoStruct not initialized!\";\n  on_trunk = true;\n  // If I am in the trunk, then all the children with (min_layer >= my layer id - 1) would be\n  // considered as in the trunk\n  node->ForEachNodeOnInEdge([&](TaskNode* in) {\n    auto& topo_struct_in = task_node2topo_struct->at(in);\n    if (topo_struct_in.min_layer == min_layer - 1) {\n      topo_struct_in.SpreadTrunk(task_node2topo_struct);\n    }\n  });\n}\n\n// The minimum computation distance from the beginning of this op to the next overlap\nint32_t TopoStruct::GetMinDistance2Overlap(HashMap<TaskNode*, TopoStruct>* task_node2topo_struct) {\n  if (min_distance2overlap >= 0) { return min_distance2overlap; }\n  // if this node should be overlapped by main computation nodes\n  if (task_classifier == TaskClassifier::kWaitingOverlapNode) {\n    min_distance2overlap = 0;\n    return min_distance2overlap;\n  }\n  // Otherwise, initialize it with a large number\n  // Well, the total number in the task graph is large enough\n  min_distance2overlap = task_node2topo_struct->size();\n  node->ForEachNodeOnOutEdge([&](TaskNode* out) {\n    min_distance2overlap =\n        std::min(min_distance2overlap,\n                 task_node2topo_struct->at(out).GetMinDistance2Overlap(task_node2topo_struct));\n  });\n  ++min_distance2overlap;\n  return min_distance2overlap;\n}\n\n// Memory increment = (memory of out registers) - (memory of in registers)\nvoid TopoStruct::ComputeMemoryIncrement() {\n  if (memory_increment < 0) {\n    memory_increment = 0;\n    for (const auto& produced_register : node->produced_regsts()) {\n      if (produced_register.second->enable_reuse_mem()) {\n        RegstDescProto temp_proto;\n        produced_register.second->ToProto(&temp_proto);\n        memory_increment += RtRegstDesc(temp_proto).TotalMainByteSize4AllRegst();\n      }\n    }\n    for (const auto& consumed_register_list : node->consumed_regsts()) {\n      for (const auto& consumed_register : consumed_register_list.second) {\n   
     if (consumed_register->enable_reuse_mem()) {\n          RegstDescProto temp_proto;\n          consumed_register->ToProto(&temp_proto);\n          memory_increment -= RtRegstDesc(temp_proto).TotalMainByteSize4AllRegst()\n                              / consumed_register->consumers().size();\n        }\n      }\n    }\n  }\n}\n\n// Exceed time = time of cpu - time of gpu\nvoid TopoStruct::ComputeExceedTime() {\n  if (node->GetTaskType() == TaskType::kNormalForward\n      && ShortGpuTime(dynamic_cast<const CompTaskNode*>(node)->op()->op_conf())) {\n    exceed_time = 1;\n  } else {\n    exceed_time = 0;\n  }\n}\n\n// Memory volume is memory * lifetime, but we might change the formula\nvoid TopoStruct::ComputeMemoryVolume() {\n  static float lifetime_order = ParseFloatFromEnv(\"LifetimeOrder\", 1.0);\n  // We might get a large tensor multiply by a long life time, we need some rescaling\n  memory_volume = static_cast<int64_t>(\n      (memory_increment * pow(static_cast<double>(min_lifetime), lifetime_order)) / 1000.0);\n  // We need to distinguish zero or negative memory increment from slight positive memory increment.\n  // Make sure that we execute -0.1, 0, -0.003 before 0.1, 0.2\n  if (memory_increment > 0) { memory_volume += 1; }\n}\n\n// deciding parameter\n// kTributaryLayerAscend = 0,     // small tributary layers go first\n// kDistanceToOverlapAscend = 1,  // small minimum distance to overlap go first\n// kLayerAscend = 2,              // first in first out\n// kMemoryIncrementAscend = 3,    // small memory increment go first\n// kExceedTimeAscend = 4,         // small exceed time go first\n// kMemoryVolumeAscend = 5,       // small memory volume go first\n// kTributaryLayerDescend = 100,     // large tributary layers go first\n// kDistanceToOverlapDescend = 101,  // long distance to overlap go first\n// kLayerDescend = 102,              // last in first out\n// kMemoryIncrementDescend = 103,    // large memory increment go first\n// kExceedTimeDescend = 104,         // large exceed time go first\n// kMemoryVolumeAscend = 105,        // large memory volume go first\nint64_t TopoStruct::GetDecidingParameter(StraightenOrder so) const {\n  int64_t sign = 1;\n  if (so >= kDiff4AscendDescend) {\n    so = StraightenOrder(int(so) - kDiff4AscendDescend);\n    sign = -1;\n  }\n  switch (so) {\n    case StraightenOrder::kTributaryLayerAscend: return sign * tributary_layer;\n    case StraightenOrder::kDistanceToOverlapAscend: return sign * min_distance2overlap;\n    case StraightenOrder::kLayerAscend: return sign * min_layer;\n    case StraightenOrder::kMemoryIncrementAscend: return sign * memory_increment;\n    case StraightenOrder::kExceedTimeAscend: return sign * exceed_time;\n    case StraightenOrder::kMemoryVolumeAscend: return sign * memory_volume;\n    case StraightenOrder::kMaxLayerAscend: return sign * max_layer;\n    default: return 0;\n  }\n}\n\n// Find the trunk of the task graph, then reduce the wait time for tributaries\nvoid FindTrunk(HashMap<TaskNode*, TopoStruct>* task_node2topo_struct) {\n  // Find the maximum layer number\n  int32_t max_min_layer = -1;\n  for (const auto& pair : *task_node2topo_struct) {\n    if (max_min_layer < pair.second.min_layer) { max_min_layer = pair.second.min_layer; }\n  }\n  // All the nodes with min_layer>=trunk_end_id would be considered as trunk nodes\n  // The last 5 layers would be considered as in trunk anyway.\n  int32_t trunk_end_id = max_min_layer - 4;\n  for (auto& pair : *task_node2topo_struct) {\n    auto& topo_struct = pair.second;\n   
 // Initialize the counter and Tributary Layer\n    topo_struct.counter = pair.first->out_edges().size();\n    topo_struct.tributary_layer = max_min_layer;\n    // Find out all the nodes on the trunk.\n    if (topo_struct.min_layer >= trunk_end_id) { topo_struct.SpreadTrunk(task_node2topo_struct); }\n  }\n\n  for (auto& pair : *task_node2topo_struct) {\n    // Compute maximum layer for tributaries\n    pair.second.SpreadTributaryLayer(task_node2topo_struct);\n    // Set the min_distance2overlap for each topological structure\n    pair.second.GetMinDistance2Overlap(task_node2topo_struct);\n  }\n\n  // The computation of maximum layer must behind those of minimum layer for the whole graph.\n  for (auto& pair : *task_node2topo_struct) { pair.second.ComputeMaxLayer(task_node2topo_struct); }\n}\n\n// Find the minimum life time of the task graph,\n// which is the maximum of the minimum layer among all the consumers.\n// The function must be executed after generating min layer\nvoid FindMinLifetime(HashMap<TaskNode*, TopoStruct>* task_node2topo_struct) {\n  // Find the maximum consumer layer\n  for (auto& pair : *task_node2topo_struct) {\n    int32_t curr_min_layer = pair.second.min_layer;\n    pair.first->ForEachNodeOnInDataEdge([&](TaskNode* in) {\n      auto& max_consumer_layer = task_node2topo_struct->at(in).min_lifetime;\n      if (max_consumer_layer < curr_min_layer) { max_consumer_layer = curr_min_layer; }\n    });\n  }\n  // Compute the life time\n  for (auto& pair : *task_node2topo_struct) {\n    if (pair.second.min_layer >= pair.second.min_lifetime) {\n      // No consumer, the register will be killed after the execution of the current operator\n      // The life time is 1 (including the current operator)\n      pair.second.min_lifetime = 1;\n    } else {\n      // The life time is the distance between two operators + 1\n      // For example, a ---(x)---> b\n      // Register x is created while executing a, and x is killed after the execution of b.\n      // The life time is 2 (including a and b) == b.lifetime - a.lifetime\n      pair.second.min_lifetime -= pair.second.min_layer - 1;\n    }\n    pair.second.ComputeMemoryVolume();\n  }\n}\n\n}  // anonymous namespace\n\n// Some operators have longer time in cpu and less time in gpu.\n// Running those operators without overlap would cause large gap during each iteration.\n// For example, expand dims would not execute any kernel on gpu but still need 10us to execute some\n// functions on cpu.\nbool ShortGpuTime(const OperatorConf& op_conf) {\n  if (op_conf.has_variable_conf()) {\n    // Variable operators would not be run. They just create tensors.\n    // We do not visualize any execution in NVTX. (Even a tick operator has something in NVTX.)\n    return true;\n  }\n  if (op_conf.has_user_conf()) {\n    const auto& op_type_name = op_conf.user_conf().op_type_name();\n    // They are sorted according to frequency of occurrences in stable diffusion\n    if (op_type_name == \"expand_dims\"  // 90\n        || op_type_name == \"cast\"      // 16\n        || op_type_name == \"expand\"    // 2\n    ) {\n      return true;\n    }\n  }\n  return false;\n}\n\n// SAT, a.k.a. 
Scholastic Aptitude Test,\n// is the college admission test in the United States of America.\nvoid InitDecideParameters(StraightenAlgorithmTag sat,\n                          std::vector<StraightenOrder>* decide_parameters) {\n  decide_parameters->clear();\n  if (sat == StraightenAlgorithmTag::kCompressMemory) {\n    decide_parameters->push_back(StraightenOrder::kMemoryVolumeAscend);\n    decide_parameters->push_back(StraightenOrder::kMemoryIncrementAscend);\n    decide_parameters->push_back(StraightenOrder::kTributaryLayerAscend);\n  } else if (sat == StraightenAlgorithmTag::kOverlap4Transfer) {\n    decide_parameters->push_back(StraightenOrder::kLayerDescend);\n    decide_parameters->push_back(StraightenOrder::kTributaryLayerDescend);\n  } else if (sat == StraightenAlgorithmTag::kOverlap4CpuGpu) {\n    decide_parameters->push_back(StraightenOrder::kExceedTimeDescend);\n    decide_parameters->push_back(StraightenOrder::kLayerDescend);\n    decide_parameters->push_back(StraightenOrder::kMemoryIncrementAscend);\n  } else if (sat == StraightenAlgorithmTag::kDelayShortGpu) {\n    decide_parameters->push_back(StraightenOrder::kExceedTimeAscend);\n    decide_parameters->push_back(StraightenOrder::kMaxLayerAscend);\n    decide_parameters->push_back(StraightenOrder::kMemoryIncrementAscend);\n  } else {\n    // sat == StraightenAlgorithmTag::kDisable\n    decide_parameters->push_back(StraightenOrder::kLayerAscend);\n  }\n}\n\n// Maximum overlap number\n// While running an overlap operator, we would run some other operators simultaneously.\nint32_t MaximumOverlapNum(StraightenAlgorithmTag sat, bool nccl_use_compute_stream) {\n  if (sat == StraightenAlgorithmTag::kOverlap4CpuGpu) {\n    // 10 operators on GPU is enough to cover the time for a CPU operator\n    return 10;\n  }\n  // This condition should be following the sat == StraightenAlgorithmTag::kOverlap4CpuGpu\n  // Since the kOverlap4CpuGpu would not be affected by transfer.\n  if (nccl_use_compute_stream) {\n    // Using nccl compute stream would disable the overlap for transfer\n    // We need to reduce it to 1\n    return 1;\n  }\n  if (sat == StraightenAlgorithmTag::kCompressMemory) {\n    // Actually we do not need the overlap.\n    // Time is not the main consideration, memory is.\n    return 2;\n  }\n  // The default number is 10. 
Mainly for sat == StraightenAlgorithmTag::kOverlap4Transfer\n  // sat == StraightenAlgorithmTag::kDisable does not need a maximum overlap number.\n  return 10;\n}\n\nvoid StraightenNodes(TaskGraph* task_graph, std::vector<TaskNode*>* ordered_task_nodes,\n                     bool nccl_use_compute_stream) {\n  // Generate topological data structure for each task node\n  HashMap<TaskNode*, TopoStruct> task_node2topo_struct;\n  // Determine the same nodes which should run simultaneously by the keys\n  std::map<std::string, std::vector<TopoStruct*>> key2topo_structs;\n  task_graph->TopoForEachNode([&](TaskNode* node) {\n    auto& topo_struct = task_node2topo_struct[node];\n    topo_struct.node = node;\n    topo_struct.ComputeMemoryIncrement();\n    topo_struct.ComputeExceedTime();\n    // Generate the key to determine the same task nodes\n    if (dynamic_cast<TransportTaskNode*>(node)) {\n      // Deal with the key and the same task nodes later\n      return;\n      // topo_struct.key = dynamic_cast<TransportTaskNode*>(node)->lbi().ShortDebugString();\n    } else if (node->GetTaskType() == TaskType::kNormalForward) {\n      topo_struct.key = dynamic_cast<CompTaskNode*>(node)->op()->op_name();\n    } else {\n      topo_struct.key = node->VisualStr();\n    }\n    // Gather all the task nodes with the same key\n    key2topo_structs[topo_struct.key].push_back(&topo_struct);\n  });\n\n  // Compute all the min layer and generate the rest of the keys\n  for (auto& pair : task_node2topo_struct) {\n    pair.second.ComputeMinLayer(&task_node2topo_struct, &key2topo_structs);\n  }\n\n  // Generate other parameters in the topological data structure\n  FindTrunk(&task_node2topo_struct);\n  FindMinLifetime(&task_node2topo_struct);\n\n  // Update sat, since sat might be changed in previous jobs\n  UpdateSat(task_node2topo_struct, &sat);\n  // Decide the task classifier after updating sat\n  for (auto& pair : task_node2topo_struct) {\n    pair.second.task_classifier = GetTaskClassifier(pair.first, nccl_use_compute_stream);\n  }\n  // Check the task classifier for all the nodes with the same key\n  for (auto& pair : key2topo_structs) {\n    TaskClassifier first_task_classifier = pair.second.at(0)->task_classifier;\n    for (auto& topo_struct : pair.second) {\n      CHECK_EQ(first_task_classifier, topo_struct->task_classifier)\n          << \" We have different task classifier \" << first_task_classifier << \" and \"\n          << topo_struct->task_classifier << \" for the nodes with the same key: \" << pair.first;\n    }\n  }\n  // Decide which node should run first\n  InitDecideParameters(sat, &decide_parameters);\n  VLOG(3) << \"Straightening order: \";\n  for (int32_t decide_parameter : decide_parameters) { VLOG(3) << decide_parameter; }\n\n  // Order in the waiting sets\n  struct comp {\n    bool operator()(const TopoStruct* a, const TopoStruct* b) const {\n      for (auto decide_parameter : decide_parameters) {\n        auto decide_parameter_a = a->GetDecidingParameter(decide_parameter);\n        auto decide_parameter_b = b->GetDecidingParameter(decide_parameter);\n        if (decide_parameter_a != decide_parameter_b) {\n          return decide_parameter_a < decide_parameter_b;\n        }\n      }\n      return a->node < b->node;\n    }\n  };\n\n  // Classify sets for the task nodes\n  // 0, TaskClassifier::kWaitingOverlapNode\n  // It contains transfer nodes, and those with less time in gpu if request.\n  // std::set<TopoStruct*, comp> waiting_overlap_node;\n  // 1, 
TaskClassifier::kWaitingMainComputation\n  // std::set<TopoStruct*, comp> waiting_main_computation;\n  // 2, TaskClassifier::kRunASAP , run as soon as possible\n  // std::set<TopoStruct*, comp> run_asap;\n  // 3, TaskClassifier::kRunALAP , run as late as possible\n  // std::set<TopoStruct*, comp> run_alap;\n  const int32_t num_classifier = 4;\n  std::vector<std::set<TopoStruct*, comp>> waiting_lists(num_classifier);\n\n  std::vector<int32_t> remain_task_nums(num_classifier, 0);\n\n  auto AddOrderedNodes = [&](TaskNode* task_node) { ordered_task_nodes->emplace_back(task_node); };\n\n  // wait in the list\n  auto wait = [&](TaskNode* node) {\n    TopoStruct* first_topo_struct = &task_node2topo_struct[node];\n    // Check if all the same nodes are ready simultaneously\n    for (auto& curr_topo_struct : key2topo_structs.at(first_topo_struct->key)) {\n      if (curr_topo_struct->counter) { return; }\n    }\n    // Add all the same nodes at the same time\n    auto& waiting_list = waiting_lists[first_topo_struct->task_classifier];\n    for (auto& curr_topo_struct : key2topo_structs.at(first_topo_struct->key)) {\n      waiting_list.insert(curr_topo_struct);\n      // Reduce counter then this node will never be added again\n      // Though inserting into a map twice does not matter because of the same keys\n      curr_topo_struct->counter--;\n    }\n  };\n\n  // initialization\n  task_graph->ForEachNode([&](TaskNode* node) {\n    int32_t count = node->in_edges().size();\n    auto& topo_struct = task_node2topo_struct[node];\n    topo_struct.counter = count;\n    if (count == 0) { wait(node); }\n    remain_task_nums[topo_struct.task_classifier]++;\n  });\n\n  // Finish execution\n  auto finish_execution = [&](TaskNode* node) {\n    node->ForEachNodeOnOutEdge([&](TaskNode* out) {\n      --(task_node2topo_struct[out].counter);\n      if (task_node2topo_struct[out].counter == 0) { wait(out); }\n    });\n  };\n\n  // Find the iterator of an element in set\n  // Make sure that the element exist in the set before using this function\n  auto FindElementInSet = [&](TopoStruct* element, std::set<TopoStruct*, comp>& set) {\n    auto it = set.find(element);\n    // NOTE: In some cases, the set can not find this element\n    // Tested in machine-16:\n    // Deleting: 0x7f75041d64c0, size: 4:\n    // 0x7f75041d64c0, 0x7f75040d7390, 0x7f7504384540, 0x7f75042bc410,\n    // Find: 0x4\n    // Or it may have the chance to delete multiple elements while deleting one element.\n    CHECK(it != set.end() && *it == element)\n        << \" Something happens. If you make sure that the element exist in the set but you still \"\n           \"can not find that element, please report this issue to Oneflow Inc.\";\n    // TODO: One simple resolution is to traverse all the elements in the set and find the\n    // corresponding iterator. But it is not recommended. If std::set do have problem, we may need\n    // to implement our own set. Or we find out the problematic version of std and make it clear to\n    // the users that we do not support that version.\n    // We may be able to reproduce the bug in the commit 0c06021c7e48d2e84d20e555e4f4dfbaf04a5e7b\n    // by running\n    // ONEFLOW_LAZY_COMPILE_MODE=\"rank_per_thread\" ONEFLOW_TEST_DEVICE_NUM=4 python3 -m\n    // oneflow.distributed.launch --nproc_per_node 4 -m unittest discover . 
--failfast --verbose\n    // under the path oneflow/python/oneflow/test/graph\n    // We still need to delete the file test_alexnet_auto_parallel.py before running the command.\n    return it;\n  };\n\n  // Since the erase function call the find function\n  // we also need to reset the erase function\n  auto EraseElementInSet = [&](TopoStruct* element, std::set<TopoStruct*, comp>& set) {\n    set.erase(FindElementInSet(element, set));\n  };\n\n  // Move the first node of the waiting list to the execution list\n  auto move2execution_list = [&](std::set<TopoStruct*, comp>& waiting_list,\n                                 std::vector<TaskNode*>& execution_list) {\n    TaskNode* first_node = (*waiting_list.begin())->node;\n    int32_t execution_num = 0;\n    TopoStruct* first_topo_struct = &task_node2topo_struct[first_node];\n    // Find all the same nodes in different machines which should be run simultaneously\n    for (auto& curr_topo_struct : key2topo_structs.at(first_topo_struct->key)) {\n      execution_num++;\n      execution_list.push_back(curr_topo_struct->node);\n      EraseElementInSet(curr_topo_struct, waiting_list);\n    }\n    CHECK_GT(execution_num, 0) << \"Error, no task nodes are moved to the execution list\";\n  };\n\n  // Execute the first n nodes in the waiting list\n  auto execute = [&](int32_t list_classifier, int32_t n, bool if_reverse = false) {\n    // n > 0\n    if (n <= 0) { return; }\n    auto& waiting_list = waiting_lists[list_classifier];\n    std::vector<TaskNode*> execution_list;\n    int32_t count = 0;\n    // Move to the execution list\n    while (!waiting_list.empty()) {\n      move2execution_list(waiting_list, execution_list);\n      count++;\n      if (count >= n) { break; }\n    }\n    remain_task_nums[list_classifier] -= execution_list.size();\n    // Set the order and then remove from the execution list\n    for (auto* node : execution_list) {\n      AddOrderedNodes(node);\n      finish_execution(node);\n    }\n  };\n\n  // straightening\n  int32_t maximum_overlap_num = MaximumOverlapNum(sat, nccl_use_compute_stream);\n  while (true) {\n    if (waiting_lists[TaskClassifier::kRunASAP].empty()) {\n      if (waiting_lists[TaskClassifier::kWaitingOverlapNode].empty()) {\n        if (waiting_lists[TaskClassifier::kWaitingMainComputation].empty()) {\n          if (waiting_lists[TaskClassifier::kRunALAP].empty()) {\n            // All the waiting lists are empty\n            break;\n          } else {\n            // Execute all the nodes left\n            execute(TaskClassifier::kRunALAP, waiting_lists[TaskClassifier::kRunALAP].size());\n          }\n        } else {\n          // Execute one computation node\n          execute(TaskClassifier::kWaitingMainComputation, 1);\n        }\n      } else {\n        int32_t computation_num = std::min(\n            std::min(int32_t(waiting_lists[TaskClassifier::kWaitingMainComputation].size()\n                             / (waiting_lists[TaskClassifier::kWaitingOverlapNode].size())),\n                     remain_task_nums[TaskClassifier::kWaitingMainComputation]\n                         / remain_task_nums[TaskClassifier::kWaitingOverlapNode]),\n            maximum_overlap_num);\n        // Holding the node to be overlapped\n        std::vector<TaskNode*> overlap_execution_list;\n        move2execution_list(waiting_lists[TaskClassifier::kWaitingOverlapNode],\n                            overlap_execution_list);\n        remain_task_nums[TaskClassifier::kWaitingOverlapNode] -= overlap_execution_list.size();\n        for 
(auto* overlap_node : overlap_execution_list) { AddOrderedNodes(overlap_node); }\n        // Overlap the node with computation from the trunk\n        execute(TaskClassifier::kWaitingMainComputation, computation_num);\n\n        // Release the overlap node\n        for (auto* overlap_node : overlap_execution_list) { finish_execution(overlap_node); }\n      }\n    } else {\n      execute(TaskClassifier::kRunASAP, waiting_lists[TaskClassifier::kRunASAP].size());\n    }\n  }\n}\n\n}  // namespace oneflow\n"
  },
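  {
    "path": "docs/editorial_examples/deciding_parameter_comparator_sketch.cpp",
    "content": "// Editorial sketch, NOT part of the OneFlow sources; the path and all names here\n// are hypothetical. The straighten algorithm above orders its waiting sets with\n// a comparator that walks the deciding parameters in priority order and falls\n// back to the node pointer as a final tie breaker, keeping the comparison a\n// strict weak ordering even when every parameter ties. Same shape, plain ints:\n#include <iostream>\n#include <set>\n\nstruct Item {\n  int layer;\n  int memory;\n  int id;  // stands in for the TaskNode* tie breaker\n  // Deciding parameters in priority order, loosely mirroring GetDecidingParameter\n  long Param(int which) const { return which == 0 ? layer : memory; }\n};\n\nstruct Comp {\n  bool operator()(const Item* a, const Item* b) const {\n    for (int which : {0, 1}) {\n      long pa = a->Param(which);\n      long pb = b->Param(which);\n      if (pa != pb) { return pa < pb; }\n    }\n    return a->id < b->id;  // total order even on a full tie\n  }\n};\n\nint main() {\n  Item x{1, 9, 0}, y{1, 3, 1}, z{1, 3, 2};\n  std::set<const Item*, Comp> waiting{&x, &y, &z};\n  for (const Item* it : waiting) { std::cout << it->id << \" \"; }  // prints: 1 2 0\n  std::cout << std::endl;\n  return 0;\n}\n"
  },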
  {
    "path": "oneflow/core/graph/straighten_nodes.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_STRAIGHTEN_NODES_H_\n#define ONEFLOW_CORE_GRAPH_STRAIGHTEN_NODES_H_\n\n#include \"oneflow/core/graph/task_graph.h\"\n#include \"oneflow/core/job/job_conf.pb.h\"\n\nnamespace oneflow {\n\n// The difference between a descending order and its corresponding ascending order\nconst int kDiff4AscendDescend = 100;\n\n// deciding parameter\n// The sorting order of nodes for the straighten algorithm\nenum StraightenOrder : int {\n  kTributaryLayerAscend = 0,     // small tributary layers go first\n  kDistanceToOverlapAscend = 1,  // small minimum distance to overlap go first\n  kLayerAscend = 2,              // first in first out\n  kMemoryIncrementAscend = 3,    // small memory increment go first\n  kExceedTimeAscend = 4,         // small exceed time go first\n  kMemoryVolumeAscend = 5,       // small memory volume go first\n  kMaxLayerAscend = 6,           // the urgent one go first\n\n  kTributaryLayerDescend =\n      kDiff4AscendDescend + kTributaryLayerAscend,  // large tributary layers go first\n  kDistanceToOverlapDescend =\n      kDiff4AscendDescend + kDistanceToOverlapAscend,  // long distance to overlap go first\n  kLayerDescend = kDiff4AscendDescend + kLayerAscend,  // last in first out\n  kMemoryIncrementDescend =\n      kDiff4AscendDescend + kMemoryIncrementAscend,              // large memory increment go first\n  kExceedTimeDescend = kDiff4AscendDescend + kExceedTimeAscend,  // large exceed time go first\n  kMemoryVolumeDescend = kDiff4AscendDescend + kMemoryVolumeAscend,  // large memory volume go first\n  kMaxLayerDescent = kDiff4AscendDescend + kMaxLayerAscend,          // the non-urgent one go first\n};\n\n// Some operators have longer time in cpu and less time in gpu.\n// Running those operators without overlap would cause large gap during each iteration.\n// For example, expand dims would not execute any kernel on gpu but still need 10us to execute some\n// functions on cpu.\nbool ShortGpuTime(const OperatorConf& op_conf);\n\n// SAT, a.k.a. 
Scholastic Aptitude Test,\n// is the college admission test in the United States of America.\nvoid InitDecideParameters(StraightenAlgorithmTag sat,\n                          std::vector<StraightenOrder>* decide_parameters);\n\n// Maximum overlap number\n// While running an overlap operator, we would run some other operators simultaneously.\nint32_t MaximumOverlapNum(StraightenAlgorithmTag sat, bool nccl_use_compute_stream);\n\ntemplate<class HashMapType>\nvoid UpdateSat(const HashMapType& node2topo_struct, StraightenAlgorithmTag* sat) {\n  *sat = GlobalJobDesc().job_conf().straighten_algorithm_tag_in_task_graph();\n  if (*sat == StraightenAlgorithmTag::kOverlap4CpuGpu) {\n    // If not cpu nodes, then the overlap strategy between cpu and gpu might consume large memory\n    bool exist_cpu_nodes = false;\n    for (const auto& pair : node2topo_struct) {\n      // Found a cpu node\n      if (pair.second.exceed_time == 1) {\n        exist_cpu_nodes = true;\n        break;\n      }\n    }\n    if (!exist_cpu_nodes) {\n      // Switch to the compress memory strategy, the default one\n      // Since the overlap strategy for transfer might not be working on 1n1d.\n      *sat = StraightenAlgorithmTag::kCompressMemory;\n    }\n  }\n}\n\n// Make sure that we use the same boolean value nccl_use_compute_stream through the straighten\n// algorithm\nvoid StraightenNodes(TaskGraph* task_graph, std::vector<TaskNode*>* ordered_task_nodes,\n                     bool nccl_use_compute_stream);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_STRAIGHTEN_NODES_H_\n"
  },
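  {
    "path": "docs/editorial_examples/ascend_descend_offset_sketch.cpp",
    "content": "// Editorial sketch, NOT part of the OneFlow sources; the path and all names here\n// are hypothetical. StraightenOrder above encodes every descending variant as\n// its ascending base plus kDiff4AscendDescend (100); GetDecidingParameter\n// decodes that by subtracting the offset and flipping the sign, so one switch\n// statement covers both directions. A minimal re-creation of that trick:\n#include <cassert>\n\nconstexpr int kDiff = 100;  // mirrors kDiff4AscendDescend\n\n// Sort key for value v under order code so: codes below 100 sort the field\n// ascending, codes at or above 100 sort the same field descending.\nlong Key(int so, long v) {\n  long sign = 1;\n  if (so >= kDiff) {\n    so -= kDiff;  // recover the ascending base code\n    sign = -1;    // negating makes \"smaller key first\" mean \"larger value first\"\n  }\n  // A real implementation would switch on so here to pick the field.\n  return sign * v;\n}\n\nint main() {\n  assert(Key(2, 5) < Key(2, 7));      // ascending (kLayerAscend): first in, first out\n  assert(Key(102, 7) < Key(102, 5));  // descending (kLayerDescend): last in, first out\n  return 0;\n}\n"
  },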
  {
    "path": "oneflow/core/graph/stream_id.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/stream_id.h\"\n#include <climits>\n\nnamespace oneflow {\n\n// StreamId encoding (bits)\n// | reserved |   node_index   | device_type | device_index  | stream_index |\n// | -- 18 -- | ----- 19 ----- | ---- 5 ---- | ----- 7 ----- |              |\n// |          |                  DeviceId                    |              |\n// |          | ------------------- 31 --------------------- | ---- 15 ---- |\n// |                               StreamId                                 |\n// | -------------------------------- 64 ---------------------------------- |\n\nnamespace {\n\nconstexpr size_t kInt64Bits = sizeof(int64_t) * CHAR_BIT;\n\nconstexpr size_t kDeviceIndexShift = StreamId::kStreamIndexBits;\nconstexpr size_t kDeviceTypeShift = kDeviceIndexShift + DeviceId::kDeviceIndexBits;\nconstexpr size_t kRankShift = kDeviceTypeShift + DeviceId::kDeviceTypeBits;\n\nstatic_assert(kRankShift + DeviceId::kRankBits < kInt64Bits, \"\");\n\nconstexpr int64_t kStreamIndexInt64Mask = (int64_t{1} << StreamId::kStreamIndexBits) - 1;\nconstexpr int64_t kDeviceIndexInt64Mask = ((int64_t{1} << DeviceId::kDeviceIndexBits) - 1)\n                                          << kDeviceIndexShift;\nconstexpr int64_t kDeviceTypeInt64Mask = ((int64_t{1} << DeviceId::kDeviceTypeBits) - 1)\n                                         << kDeviceTypeShift;\nconstexpr int64_t kRankInt64Mask = ((int64_t{1} << DeviceId::kRankBits) - 1) << kRankShift;\n\n}  // namespace\n\nint64_t EncodeStreamIdToInt64(const StreamId& stream_id) {\n  int64_t id = static_cast<int64_t>(stream_id.stream_index());\n  id |= static_cast<int64_t>(stream_id.device_index()) << kDeviceIndexShift;\n  id |= static_cast<int64_t>(stream_id.device_type()) << kDeviceTypeShift;\n  id |= static_cast<int64_t>(stream_id.rank()) << kRankShift;\n  return id;\n}\n\nStreamId DecodeStreamIdFromInt64(int64_t stream_id_val) {\n  int64_t rank = (stream_id_val & kRankInt64Mask) >> kRankShift;\n  int64_t device_type = (stream_id_val & kDeviceTypeInt64Mask) >> kDeviceTypeShift;\n  int64_t device_index = (stream_id_val & kDeviceIndexInt64Mask) >> kDeviceIndexShift;\n  int64_t stream_index = (stream_id_val & kStreamIndexInt64Mask);\n  return StreamId{static_cast<DeviceId::rank_t>(rank), static_cast<DeviceType>(device_type),\n                  static_cast<DeviceId::device_index_t>(device_index),\n                  static_cast<StreamId::stream_index_t>(stream_index)};\n}\n\n}  // namespace oneflow\n"
  },
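  {
    "path": "docs/editorial_examples/stream_id_bits_sketch.cpp",
    "content": "// Editorial sketch, NOT part of the OneFlow sources; the path is hypothetical.\n// A worked example of the StreamId bit layout documented in stream_id.cpp above:\n// | reserved 18 | rank 19 | device_type 5 | device_index 7 | stream_index 15 |\n// Encoding ORs each field into its shifted slot; decoding masks it back out.\n#include <cassert>\n#include <cstdint>\n\nconstexpr int kStreamBits = 15;\nconstexpr int kDevIndexBits = 7;\nconstexpr int kDevTypeBits = 5;\nconstexpr int kRankBits = 19;\nconstexpr int kDevIndexShift = kStreamBits;                    // 15\nconstexpr int kDevTypeShift = kDevIndexShift + kDevIndexBits;  // 22\nconstexpr int kRankShift = kDevTypeShift + kDevTypeBits;       // 27\n\nconstexpr int64_t Mask(int bits) { return (int64_t{1} << bits) - 1; }\n\nint64_t Encode(int64_t rank, int64_t dev_type, int64_t dev_index, int64_t stream) {\n  return stream | (dev_index << kDevIndexShift) | (dev_type << kDevTypeShift)\n         | (rank << kRankShift);\n}\n\nint main() {\n  const int64_t id = Encode(/*rank=*/3, /*dev_type=*/2, /*dev_index=*/1, /*stream=*/7);\n  // Decode each field by shifting its slot down and masking its width.\n  assert(((id >> kRankShift) & Mask(kRankBits)) == 3);\n  assert(((id >> kDevTypeShift) & Mask(kDevTypeBits)) == 2);\n  assert(((id >> kDevIndexShift) & Mask(kDevIndexBits)) == 1);\n  assert((id & Mask(kStreamBits)) == 7);\n  return 0;\n}\n"
  },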
  {
    "path": "oneflow/core/graph/stream_id.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_STREAM_ID_H_\n#define ONEFLOW_CORE_GRAPH_STREAM_ID_H_\n\n#include \"oneflow/core/device/device_id.h\"\n\nnamespace oneflow {\n\nclass StreamId {\n public:\n  using stream_index_t = uint32_t;\n\n  constexpr static size_t kStreamIndexBits = 15;\n  constexpr static stream_index_t kMaxStreamIndex =\n      (stream_index_t{1} << kStreamIndexBits) - stream_index_t{1};\n\n  StreamId(const DeviceId& device_id, stream_index_t stream_index)\n      : device_id_(device_id), stream_index_(stream_index) {\n    CHECK_LE(stream_index, kMaxStreamIndex);\n  }\n  StreamId(DeviceId::rank_t node_index, DeviceType device_type,\n           DeviceId::device_index_t device_index, stream_index_t stream_index)\n      : device_id_(node_index, device_type, device_index), stream_index_(stream_index) {\n    CHECK_LE(stream_index, kMaxStreamIndex);\n  }\n\n  const DeviceId& device_id() const { return device_id_; }\n  DeviceId::rank_t rank() const { return device_id_.rank(); }\n  DeviceType device_type() const { return device_id_.device_type(); }\n  DeviceId::device_index_t device_index() const { return device_id_.device_index(); }\n  stream_index_t stream_index() const { return stream_index_; }\n\n  bool operator==(const StreamId& rhs) const {\n    return device_id_ == rhs.device_id_ && stream_index_ == rhs.stream_index_;\n  }\n\n  bool operator!=(const StreamId& rhs) const { return !(*this == rhs); }\n\n  size_t hash() const {\n    size_t hash = device_id_.hash();\n    HashCombine(&hash, std::hash<stream_index_t>{}(stream_index_));\n    return hash;\n  }\n\n private:\n  DeviceId device_id_;\n  stream_index_t stream_index_;\n};\n\nint64_t EncodeStreamIdToInt64(const StreamId&);\nStreamId DecodeStreamIdFromInt64(int64_t);\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::StreamId> {\n  size_t operator()(const oneflow::StreamId& stream_id) const { return stream_id.hash(); }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_GRAPH_STREAM_ID_H_\n"
  },
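  {
    "path": "docs/editorial_examples/hash_specialization_sketch.cpp",
    "content": "// Editorial sketch, NOT part of the OneFlow sources; the path and all names here\n// are hypothetical. stream_id.h above makes StreamId usable as an unordered\n// container key by combining field hashes and specializing std::hash. Same\n// pattern with a plain struct; the combine formula below is the common\n// boost-style one, an assumption -- OneFlow's HashCombine may differ in detail.\n#include <cstddef>\n#include <functional>\n#include <iostream>\n#include <unordered_set>\n\nstruct Key {\n  int a;\n  int b;\n  bool operator==(const Key& rhs) const { return a == rhs.a && b == rhs.b; }\n  size_t hash() const {\n    size_t h = std::hash<int>{}(a);\n    // boost-style hash_combine of the second field\n    h ^= std::hash<int>{}(b) + 0x9e3779b9 + (h << 6) + (h >> 2);\n    return h;\n  }\n};\n\nnamespace std {\ntemplate<>\nstruct hash<Key> {\n  size_t operator()(const Key& key) const { return key.hash(); }\n};\n}  // namespace std\n\nint main() {\n  std::unordered_set<Key> keys{{1, 2}, {1, 2}, {3, 4}};\n  std::cout << keys.size() << std::endl;  // prints 2: the duplicate collapses\n  return 0;\n}\n"
  },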
  {
    "path": "oneflow/core/graph/stream_index_generator.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/stream_index_generator.h\"\n#include <mutex>\n#include \"oneflow/core/job/id_state.h\"\n\nnamespace oneflow {\n\nStreamIndexGenerator::StreamIndexGenerator(stream_index_t stream_index)\n    : next_stream_index_(stream_index) {}\n\nStreamIndexGenerator::stream_index_t StreamIndexGenerator::GenerateAnonymous() {\n  std::unique_lock<std::mutex> lck(mtx_);\n  return next_stream_index_++;\n}\n\nStreamIndexGenerator::stream_index_t StreamIndexGenerator::GenerateNamed(const std::string& name) {\n  return GenerateNamedRoundRobin(name, 1);\n}\n\nStreamIndexGenerator::stream_index_t StreamIndexGenerator::GenerateNamedRoundRobin(\n    const std::string& name, size_t size) {\n  CHECK_GT(size, 0);\n  std::unique_lock<std::mutex> lck(mtx_);\n  auto it = name2rr_range_.find(name);\n  if (it == name2rr_range_.end()) {\n    it = name2rr_range_.emplace(name, RoundRobinRange{next_stream_index_, size}).first;\n    next_stream_index_ += size;\n  } else {\n    CHECK_EQ(it->second.size, size) << name;\n  }\n\n  stream_index_t cur_stream_index = it->second.begin;\n  if (size > 1) {\n    size_t& offset = it->second.offset;\n    cur_stream_index += offset++;\n    if (offset >= size) { offset = 0; }\n  }\n  return cur_stream_index;\n}\n\nStreamIndexGenerator::stream_index_t StreamIndexGenerator::GetCurrStreamIndex() {\n  std::unique_lock<std::mutex> lck(mtx_);\n  return next_stream_index_;\n}\n\nvoid StreamIndexGenerator::TryUpdateNextStreamIndex(stream_index_t next_stream_index) {\n  std::unique_lock<std::mutex> lck(mtx_);\n  next_stream_index_ = std::max(next_stream_index_, next_stream_index);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/stream_index_generator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_STREAM_INDEX_GENERATOR_H_\n#define ONEFLOW_CORE_GRAPH_STREAM_INDEX_GENERATOR_H_\n\n#include \"oneflow/core/graph/stream_id.h\"\n#include \"oneflow/core/job/id_state.h\"\n\nnamespace oneflow {\n\nclass StreamIndexGenerator final {\n public:\n  using stream_index_t = StreamId::stream_index_t;\n\n  explicit StreamIndexGenerator(stream_index_t stream_index);\n  OF_DISALLOW_COPY_AND_MOVE(StreamIndexGenerator);\n  ~StreamIndexGenerator() = default;\n\n  stream_index_t GenerateAnonymous();\n  stream_index_t GenerateNamed(const std::string& name);\n  stream_index_t GenerateNamedRoundRobin(const std::string& name, size_t size);\n  stream_index_t GetCurrStreamIndex();\n  void TryUpdateNextStreamIndex(stream_index_t next_stream_index);\n\n private:\n  struct RoundRobinRange {\n    RoundRobinRange(stream_index_t begin, size_t size) : begin(begin), size(size), offset(0) {}\n    stream_index_t begin;\n    size_t size;\n    size_t offset;\n  };\n\n  stream_index_t next_stream_index_;\n  HashMap<std::string, RoundRobinRange> name2rr_range_;\n  std::mutex mtx_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_STREAM_INDEX_GENERATOR_H_\n"
  },
  {
    "path": "oneflow/core/graph/task_edge.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/register/logical_blob_id.proto\";\n\nmessage TaskEdgeProto {\n  required int64 task_edge_uid = 1;\n  required int64 src_task_id = 2;\n  required int64 dst_task_id = 3;\n  repeated LogicalBlobId lbi = 4;\n  map<string, int64> name_in_producer2regst_desc_id = 5;\n};\n"
  },
  {
    "path": "oneflow/core/graph/task_graph.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/task_graph.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n#include \"oneflow/core/graph/inplace_lbi_graph.h\"\n#include \"oneflow/core/job/job_conf.pb.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/task.pb.h\"\n#include \"oneflow/core/register/blob_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/operator/variable_op.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/graph/normal_forward_compute_task_node.h\"\n#include \"oneflow/core/graph/boxing_identity_task_node.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/job_rewriter/calculation_pass.h\"\n#include \"oneflow/core/graph/boxing/sub_task_graph_builder_util.h\"\n#include \"oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/graph/straighten_nodes.h\"\n#include \"oneflow/core/register/runtime_register_desc.h\"\n#include \"oneflow/core/common/env_var/env_var.h\"\n#include \"oneflow/core/graph/boxing_task_graph.pb.h\"\n#include \"oneflow/core/graph/task_graph_rebuild_ctx.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n#include \"oneflow/core/graph/task_type_visitor.h\"\n\nnamespace oneflow {\n\n// TODO(Chengcheng): default false.\nDEFINE_ENV_BOOL(ONEFLOW_ENABLE_OUTDATED_OPT_FW_CHAIN_MERGE, true);\n\nnamespace {\n\nbool IsMemcpyPrimitiveSupported(DeviceType device_type, ep::primitive::MemcpyKind kind) {\n  auto primitive = ep::primitive::NewPrimitive<ep::primitive::MemcpyFactory>(device_type, kind);\n  return primitive.operator bool();\n}\n\nbool IsMemcpyHtoDSupported(DeviceType device_type) {\n  return IsMemcpyPrimitiveSupported(device_type, ep::primitive::MemcpyKind::kHtoD);\n}\n\nbool IsMemcpyDtoHSupported(DeviceType device_type) {\n  return IsMemcpyPrimitiveSupported(device_type, ep::primitive::MemcpyKind::kDtoH);\n}\n\nbool IsConnectToTickOp(const TaskNode* node) {\n  const auto* comp_task_node = dynamic_cast<const CompTaskNode*>(node);\n  if (comp_task_node == nullptr) { return false; }\n  const Operator* op = comp_task_node->op().get();\n  if (dynamic_cast<const VariableOp*>(op) != nullptr) { return true; }\n  return false;\n}\n\nbool IsSubsetTickOpConf(const OperatorConf& op_conf) {\n  return op_conf.has_src_subset_tick_conf() || op_conf.has_dst_subset_tick_conf();\n}\n\nbool IsTickOpConf(const OperatorConf& conf) {\n  return IsClassRegistered<int32_t, IsTickTockOpTypeCase>(conf.op_type_case());\n}\n\nconst std::string& 
GetOpConfCalculationPassName(const OperatorConf& op_conf) {\n  CHECK(op_conf.has_scope_symbol_id());\n  if (op_conf.has_calculation_pass_name()) { return op_conf.calculation_pass_name(); }\n  int64_t scope_symbol_id = op_conf.scope_symbol_id();\n  CHECK(Singleton<symbol::Storage<Scope>>::Get()->Has(scope_symbol_id))\n      << \" Error! op : \\n \" << op_conf.DebugString()\n      << \" has invalid scope_symbol_id = \" << scope_symbol_id\n      << \" which cannot be found in Singleton<symbol::Storage<Scope>>::Get()\\n\";\n  const Scope& scope = Singleton<symbol::Storage<Scope>>::Get()->Get(scope_symbol_id);\n  return scope.scope_proto().calculation_pass_name();\n}\n\nbool IsOptimizerPassOp(const Operator* op) {\n  // NOTE(chengcheng): use scope::calculation_pass_name instead of area_id to not merge optimizer\n  // ops with fw/bw ops\n  if (!op->op_conf().has_scope_symbol_id()) {\n    // NOTE(chengcheng): Some system ops inserted into the OpGraph may not set scope_symbol_id;\n    // they MUST NOT be optimizer subgraph ops.\n    return false;\n  }\n  return GetOpConfCalculationPassName(op->op_conf()) == kOptimizerPass;\n}\n\nbool IsSpecialOpNotConsiderMergeInChain(const Operator* op) {\n  const OperatorConf& op_conf = op->op_conf();\n  if (op_conf.has_variable_conf() || op_conf.has_tick_conf() || op_conf.has_device_tick_conf()\n      || op_conf.has_src_subset_tick_conf() || op_conf.has_dst_subset_tick_conf()\n      || op_conf.has_source_tick_conf() || op_conf.has_sink_tick_conf()\n      || op_conf.has_acc_tick_conf()) {\n    return true;\n  }\n  if (op_conf.has_user_conf()) {\n    const std::string& user_type_name = op_conf.user_conf().op_type_name();\n    if (user_type_name == \"repeat\" || user_type_name == \"acc\" || user_type_name == \"pack\"\n        || user_type_name == \"unpack\" || user_type_name == \"identity_buffer\") {\n      return true;\n    }\n  }\n  // NOTE(chengcheng): ONLY nccl_use_compute_stream = false will exclude optimizer pass ops\n  if (!Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()\n      && IsOptimizerPassOp(op) && EnvBool<ONEFLOW_ENABLE_OUTDATED_OPT_FW_CHAIN_MERGE>()) {\n    return true;\n  }\n  return false;\n}\n\nbool IsTaskNodeProducedRegstHasMultiRegstNum(const TaskNode* node) {\n  for (const auto& pair : node->produced_regsts()) {\n    if (pair.second->min_register_num() > 1) { return true; }\n  }\n  return false;\n}\n\nbool CanBeMergedInChain(const TaskNode* node) {\n  // ONLY a NormalForward node that is on GPU and NOT a variable can be merged.\n  if (IsTaskNodeProducedRegstHasMultiRegstNum(node)) { return false; }\n  const auto* fw_comp_node = dynamic_cast<const NormalForwardCompTaskNode*>(node);\n  if (fw_comp_node == nullptr) { return false; }\n  if (fw_comp_node->device_type() == DeviceType::kCPU) { return false; }\n  const Operator* op = fw_comp_node->op().get();\n  if (IsSpecialOpNotConsiderMergeInChain(op)) { return false; }\n  return true;\n}\n\nstd::shared_ptr<const Shape> GetTaskNodeTimeShape(const TaskNode* node) {\n  const auto* fw_comp_node = dynamic_cast<const NormalForwardCompTaskNode*>(node);\n  CHECK(fw_comp_node != nullptr);\n  return CHECK_JUST(fw_comp_node->op()->GetOpTimeShape());\n}\n\nvoid TraverseConnectedSubGraphMergeInThisChain(TaskNode* this_node, const int64_t this_chain_id) {\n  CHECK(IsValidChainId(this_chain_id));\n  CHECK(!IsValidChainId(this_node->chain_id()));\n  // BFS search for all nodes that can be merged into this chain\n  std::shared_ptr<const Shape> seed_time_shape = GetTaskNodeTimeShape(this_node);\n  HashSet<TaskNode*> 
visited_nodes;\n  std::queue<TaskNode*> queued_nodes;\n  queued_nodes.push(this_node);\n  visited_nodes.insert(this_node);\n  while (!queued_nodes.empty()) {\n    TaskNode* cur_node = queued_nodes.front();\n    queued_nodes.pop();\n\n    CHECK(!IsValidChainId(cur_node->chain_id()));\n    cur_node->set_chain_id(this_chain_id);\n\n    cur_node->ForEachNodeOnInOutDataEdge([&](TaskNode* next_node) {\n      if (visited_nodes.find(next_node) == visited_nodes.end() && CanBeMergedInChain(next_node)\n          && this_node->thrd_id() == next_node->thrd_id()\n          && (*GetTaskNodeTimeShape(next_node)) == (*seed_time_shape)) {\n        if (!IsValidChainId(next_node->chain_id())) {\n          queued_nodes.push(next_node);\n          visited_nodes.insert(next_node);\n        } else {\n          CHECK_EQ(next_node->chain_id(), this_chain_id);\n        }\n      }\n    });\n  }\n}\n\nstd::function<TaskNode*(const std::string&)> MakeGetterTaskNode4SoleOpName(\n    const HashSet<TaskNode*>& task_nodes) {\n  auto op_name2task_nodes = std::make_shared<HashMap<std::string, HashSet<TaskNode*>>>();\n  for (TaskNode* task_node : task_nodes) {\n    if (task_node->exec_gph().node_num() == 1) {\n      ExecNode* exec_node = task_node->exec_gph().SoleNode();\n      CHECK((*op_name2task_nodes)[exec_node->op()->op_name()].emplace(task_node).second);\n    }\n  }\n  return [op_name2task_nodes](const std::string& op_name) -> TaskNode* {\n    const auto& iter = op_name2task_nodes->find(op_name);\n    if (iter == op_name2task_nodes->end()) { return nullptr; }\n    if (iter->second.size() > 1) { return nullptr; }\n    return *iter->second.begin();\n  };\n}\n\nbool IsLbiOnTaskEdge(const TaskEdge* edge, const LogicalBlobId& lbi) {\n  for (const auto& regst_desc : edge->GetRegsts()) {\n    if (regst_desc->HasLbi(lbi)) { return true; }\n  }\n  return false;\n}\n\nstd::function<bool(const LogicalBlobId&, const std::string&)>\nMakePredicatorIsLbiAllConsumersReachable(\n    const std::function<const TaskNode*(const std::string&)>& TaskNode4SoleOpName,\n    const std::function<bool(const std::string&, const std::string&)>&\n        IsOpNameDataOrCtrlReachable) {\n  auto IsDataOrCtrlReachable = [IsOpNameDataOrCtrlReachable](const TaskNode* src_node,\n                                                             const TaskNode* dst_node) -> bool {\n    if (IsValidChainId(src_node->chain_id()) && IsValidChainId(dst_node->chain_id())\n        && src_node->chain_id() == dst_node->chain_id()\n        && src_node->order_in_chain() <= dst_node->order_in_chain()) {\n      return true;\n    }\n    const CompTaskNode* comp_src_node = dynamic_cast<const CompTaskNode*>(src_node);\n    if (comp_src_node == nullptr) { return false; }\n    const CompTaskNode* comp_dst_node = dynamic_cast<const CompTaskNode*>(dst_node);\n    if (comp_dst_node == nullptr) { return false; }\n    return IsOpNameDataOrCtrlReachable(comp_src_node->op()->op_name(),\n                                       comp_dst_node->op()->op_name());\n  };\n  return [TaskNode4SoleOpName, IsDataOrCtrlReachable](const LogicalBlobId& lbi,\n                                                      const std::string& op_name) -> bool {\n    const TaskNode* src_task_node = TaskNode4SoleOpName(lbi.op_name());\n    const TaskNode* dst_task_node = TaskNode4SoleOpName(op_name);\n    size_t out_edges_size = 0;\n    size_t reachable_out_edges_size = 0;\n    for (TaskEdge* out_edge : src_task_node->out_edges()) {\n      if (IsLbiOnTaskEdge(out_edge, lbi)) {\n        out_edges_size += 1;\n        
reachable_out_edges_size += IsDataOrCtrlReachable(out_edge->dst_node(), dst_task_node);\n      }\n    }\n    return out_edges_size > 0 && out_edges_size == reachable_out_edges_size;\n  };\n}\n\nbool IsInplaceAllowed(\n    TaskNode* task_node, const std::vector<std::string>& bns,\n    const std::function<const TaskNode*(const std::string&)>& TaskNode4SoleOpName) {\n  if (task_node->exec_gph().node_num() != 1) { return false; }\n  const auto& exec_node = *task_node->exec_gph().SoleNode();\n  for (const auto& bn : bns) {\n    // TaskNode for bn is not nullptr if it's on the same device with `task_node`\n    if (TaskNode4SoleOpName(exec_node.op()->BnInOp2Lbi(bn).op_name()) == nullptr) { return false; }\n    const RegstDesc& regst_desc = *exec_node.RegstDesc4BnInOp(bn);\n    if (regst_desc.NumOfLbi() != 1) { return false; }\n  }\n  const BlobDesc* first_blob = nullptr;\n  for (const auto& bn : bns) {\n    const BlobDesc* blob_desc = exec_node.RegstDesc4BnInOp(bn)->SoleBlobDesc();\n    if (first_blob == nullptr) {\n      first_blob = blob_desc;\n    } else {\n      if (!(first_blob->shape().elem_cnt() == blob_desc->shape().elem_cnt()\n            && first_blob->data_type() == blob_desc->data_type())) {\n        return false;\n      }\n    }\n  }\n  return true;\n}\n\nstd::unique_ptr<BoxingLogger> CreateBoxingLogger() {\n  if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {\n    return std::unique_ptr<BoxingLogger>(\n        new CsvBoxingLogger(StrCat(\"boxing/log/\", GlobalJobDesc().job_id()) + \".csv\"));\n  } else {\n    return std::unique_ptr<BoxingLogger>(new NullBoxingLogger());\n  }\n}\n\nMaybe<void> MakeGetterTaskNode4MachineId7ThrdId(\n    const std::vector<CompTaskNode*>& task_nodes,\n    std::function<Maybe<CompTaskNode*>(int64_t mchn_id, int64_t thrd_id)>* Getter) {\n  // ticks are shared within a machine/process\n  auto machine_id2task_node = std::make_shared<HashMap<int64_t, CompTaskNode*>>();\n  for (auto* task_node : task_nodes) {\n    machine_id2task_node->emplace(task_node->machine_id(), task_node);\n  }\n  *Getter = [machine_id2task_node](int64_t mchn_id, int64_t thrd_id) -> Maybe<CompTaskNode*> {\n    const auto& iter = machine_id2task_node->find(mchn_id);\n    CHECK_OR_RETURN(iter != machine_id2task_node->end());\n    return iter->second;\n  };\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nStreamId GetStreamId(const OpNode* op_node, int64_t parallel_id, TaskType task_type) {\n  const ParallelDesc& parallel_desc = op_node->parallel_desc();\n  int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));\n  int64_t dev_phy_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));\n\n  DeviceId::device_index_t device_index = parallel_desc.device_type() == DeviceType::kCPU\n                                              ? 
0\n                                              : static_cast<DeviceId::device_index_t>(dev_phy_id);\n  DeviceId device_id{static_cast<DeviceId::rank_t>(machine_id), parallel_desc.device_type(),\n                     device_index};\n  StreamId::stream_index_t stream_index = 0;\n  if (op_node->op().op_conf().has_stream_name_hint()) {\n    const std::string& stream_name_hint = op_node->op().op_conf().stream_name_hint();\n    VLOG(3) << \"set op: \" << op_node->op().op_name() << \" to stream: \" << stream_name_hint;\n    stream_index = Singleton<TaskStreamIndexManager>::Get()->GetNamedTaskStreamIndex(\n        device_id, stream_name_hint);\n  } else {\n    stream_index =\n        Singleton<TaskStreamIndexManager>::Get()->GetTaskStreamIndex(task_type, device_id);\n  }\n  return StreamId{device_id, stream_index};\n}\n\nTaskType TaskType4OpNode(const OpNode* op_node) {\n  std::unique_ptr<CompTaskNode> comp_task_node(NewCompTaskNode4OpNode(op_node));\n  return comp_task_node->GetTaskType();\n}\n\n}  // namespace\n\nCompTaskNode* GenCompTaskNode(\n    const OpNode* op_node, int64_t parallel_id,\n    const std::function<StreamId(const OpNode* op_node, int64_t parallel_id, TaskType task_type)>&\n        GetOrCreateStreamId) {\n  const ParallelDesc& parallel_desc = op_node->parallel_desc();\n  int64_t parallel_num = parallel_desc.parallel_num();\n  CompTaskNode* comp_task_node = NewCompTaskNode4OpNode(op_node);\n  int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));\n  comp_task_node->set_machine_id(machine_id);\n  comp_task_node->mut_parallel_ctx()->set_parallel_id(parallel_id);\n  comp_task_node->mut_parallel_ctx()->set_parallel_num(parallel_num);\n  StreamId stream_id = GetOrCreateStreamId(op_node, parallel_id, comp_task_node->GetTaskType());\n  comp_task_node->set_thrd_id(EncodeStreamIdToInt64(stream_id));\n  comp_task_node->set_op_node(op_node);\n  return comp_task_node;\n}\n\nvoid GenSortedCompTaskNodes(const OpNode* op_node, std::vector<CompTaskNode*>* sorted_comp_tasks) {\n  int64_t parallel_idx = 0;\n  const ParallelDesc& parallel_desc = op_node->parallel_desc();\n  for (int64_t machine_id : parallel_desc.sorted_machine_ids()) {\n    for (int64_t dev_phy_id : parallel_desc.sorted_dev_phy_ids(machine_id)) {\n      sorted_comp_tasks->emplace_back(GenCompTaskNode(op_node, parallel_idx++, &GetStreamId));\n      (void)dev_phy_id;\n    }\n    (void)machine_id;\n  }\n}\n\nbool IsConnectedLbisAllSameNdSbp(const OpEdge* op_edge) {\n  const OpNode* src_node = op_edge->src_node();\n  const OpNode* dst_node = op_edge->dst_node();\n  CHECK_GT(op_edge->lbis().size(), 0);\n  HashSet<bool> predicators;\n  for (const LogicalBlobId& lbi : op_edge->lbis()) {\n    const NdSbp& src_nd_sbp = src_node->NdSbp4Lbi(lbi);\n    const NdSbp& dst_nd_sbp = dst_node->NdSbp4Lbi(lbi);\n    predicators.insert(src_nd_sbp == dst_nd_sbp);\n  }\n  CHECK_EQ(predicators.size(), 1);\n  return *predicators.begin();\n}\n\nBldSubTskGphMthd GetMthdForBldSubTskGph(const OpEdge* op_edge) {\n  const OpNode* src_node = op_edge->src_node();\n  const OpNode* dst_node = op_edge->dst_node();\n  const ParallelDesc& src_pd = src_node->parallel_desc();\n  const ParallelDesc& dst_pd = dst_node->parallel_desc();\n  const OperatorConf& src_op_conf = src_node->op().op_conf();\n  const OperatorConf& dst_op_conf = dst_node->op().op_conf();\n\n  // WaitAndSendIds -> Reentrantlock\n  if (src_op_conf.has_wait_and_send_ids_conf() && dst_op_conf.has_reentrant_lock_conf()) {\n    CHECK_EQ(src_pd.parallel_num(), 1);\n    
CHECK_EQ(dst_pd.parallel_num(), 1);\n    return &TaskGraph::BldSubTskGphByBoxing;\n  }\n\n  // *Tick -> *Tick\n  if (IsTickOpConf(src_op_conf) || IsTickOpConf(dst_op_conf)) {\n    if (src_op_conf.has_source_tick_conf()) {\n      CHECK(dst_op_conf.has_tick_conf());\n      CHECK_EQ(src_pd.parallel_num(), 1);\n      CHECK_EQ(dst_pd.parallel_num(), 1);\n      return &TaskGraph::BldSubTskGphByBoxing;\n    } else if (dst_op_conf.has_sink_tick_conf()) {\n      CHECK(src_op_conf.has_tick_conf() || src_op_conf.has_sink_tick_conf());\n      CHECK_EQ(src_pd.parallel_num(), 1);\n      CHECK_EQ(dst_pd.parallel_num(), 1);\n      return &TaskGraph::BldSubTskGphByBoxing;\n    } else if (IsSubsetTickOpConf(src_op_conf)) {\n      return &TaskGraph::BldSubTskGphBySrcSubsetConnect;\n    } else if (IsSubsetTickOpConf(dst_op_conf)) {\n      return &TaskGraph::BldSubTskGphByDstSubsetConnect;\n    } else if (IsTickOpConf(src_op_conf) && IsTickOpConf(dst_op_conf)) {\n      if (src_pd.parallel_num() == dst_pd.parallel_num()) {\n        return &TaskGraph::BldSubTskGphByOneToOne;\n      } else {\n        CHECK_EQ(src_pd.parallel_num(), 1);\n        return &TaskGraph::BldSubTskGphByBroadcastToBroadcast;\n      }\n    }\n  }\n\n  std::shared_ptr<CompTaskNode> src_comp_task(NewCompTaskNode4OpNode(src_node));\n  std::shared_ptr<CompTaskNode> dst_comp_task(NewCompTaskNode4OpNode(dst_node));\n  // NOTE(chengcheng): MUST use TaskType instead of OpTypeCase because multiple op types may\n  //   correspond to the SAME TaskType, such as:\n  //     DistributeConcatOpConf and DistributeAddOpConf -> TaskType::kDistributeConcat\n  //     DistributeSplitOpConf  and DistributeCloneOpConf -> TaskType::kDistributeSplit\n  // * -> DistributeConcat\n  if (dst_comp_task->GetTaskType() == TaskType::kDistributeConcat) {\n    return &TaskGraph::BldSubTskGphByPartialInLbiConnect;\n  }\n\n  // DistributeSplit -> *\n  if (src_comp_task->GetTaskType() == TaskType::kDistributeSplit) {\n    return &TaskGraph::BldSubTskGphByPartialOutLbiConnect;\n  }\n\n  // NormalForward -> DecodeH2D\n  if (src_comp_task->GetTaskType() == TaskType::kNormalForward\n      && dst_comp_task->GetTaskType() == TaskType::kDecodeH2D) {\n    return &TaskGraph::BldSubTskGphNormalForwardToDecodeH2D;\n  }\n\n  if (src_pd.parallel_num() == 1 && dst_pd.parallel_num() == 1) {\n    return &TaskGraph::BldSubTskGphByOneToOne;\n  }\n\n  // one to one\n  if (src_pd.parallel_num() == dst_pd.parallel_num() && *src_pd.hierarchy() == *dst_pd.hierarchy()\n      && IsConnectedLbisAllSameNdSbp(op_edge)) {\n    return &TaskGraph::BldSubTskGphByOneToOne;\n  }\n\n  return &TaskGraph::BldSubTskGphByBoxing;\n}\n\nvoid ForEachOpGraphNecessaryCtrlEdge(\n    const OpGraph* op_graph, const std::function<void(const OpNode*, const OpNode*)>& Handler) {\n  auto IsOpGraphDataReachable = op_graph->CreatePredicatorIsReachable();\n  op_graph->ForEachNode([&](OpNode* dst) {\n    for (const auto& ctrl_in_op_name : dst->op().op_conf().ctrl_in_op_name()) {\n      const OpNode* src = op_graph->OpNode4OpName(ctrl_in_op_name);\n      CHECK(!IsOpGraphDataReachable(dst, src));\n      // src has ctrl to dst, but src has no data path to dst.\n      if (!IsOpGraphDataReachable(src, dst)) {\n        CHECK_EQ(dst->parallel_desc().parallel_num(), src->parallel_desc().parallel_num());\n        const Shape* src_time_shape = CHECK_JUST(src->op().GetOpTimeShape()).get();\n        const Shape* dst_time_shape = CHECK_JUST(dst->op().GetInputBlobFastestTimeShape()).get();\n        if (dst_time_shape == nullptr) {\n          dst_time_shape 
= CHECK_JUST(dst->op().GetOpTimeShape()).get();\n        }\n        if (src_time_shape->elem_cnt() != dst_time_shape->elem_cnt()) {\n          // NOTE(chengcheng): acc / pack op node can be merged and add ctrl edge.\n          CHECK(src->op().op_conf().has_user_conf());\n          const std::string& op_type_name = src->op().op_conf().user_conf().op_type_name();\n          CHECK(op_type_name == \"acc\" || op_type_name == \"pack\");\n          const Shape* src_input_time_shape =\n              CHECK_JUST(src->op().GetInputBlobFastestTimeShape()).get();\n          CHECK_EQ(src_input_time_shape->elem_cnt(), dst_time_shape->elem_cnt());\n        } else {\n          CHECK_EQ(src_time_shape->elem_cnt(), dst_time_shape->elem_cnt());\n        }\n        if (!src->parallel_desc().EqualsIgnoringHierarchy(dst->parallel_desc())) {\n          LOG(WARNING) << \" Warning, there is a ctrl edge connected across placement from: \"\n                       << src->op().op_name() << \" [\"\n                       << src->parallel_desc().parallel_conf().DebugString()\n                       << \"] to: \" << dst->op().op_name() << \" [\"\n                       << dst->parallel_desc().parallel_conf().DebugString() << \"]\";\n        }\n        Handler(src, dst);\n      }\n    }\n  });\n}\n\nvoid GetHostInputLbis4OpNode(const OpNode* op_node,\n                             std::vector<LogicalBlobId>* host_mem_input_lbis) {\n  host_mem_input_lbis->clear();\n  if (op_node->op().op_conf().has_user_conf()) {\n    const auto& user_conf = op_node->op().op_conf().user_conf();\n    const auto& op_type_name = user_conf.op_type_name();\n    if (user_op::UserOpHostMemoryInputRegistry::Get().HasHostMemoryInput(op_type_name)) {\n      const auto& inputs = [&]() -> std::vector<std::pair<std::string, int32_t>> {\n        const auto& arg_map = op_node->op().op_conf().user_conf().input();\n        std::vector<std::pair<std::string, int32_t>> arg_vec;\n        for (auto it = arg_map.begin(); it != arg_map.end(); ++it) {\n          for (int32_t i = 0; i < it->second.s_size(); ++i) {\n            arg_vec.emplace_back(std::make_pair(it->first, i));\n          }\n        }\n        return arg_vec;\n      }();\n      for (const auto& pair : inputs) {\n        if (user_op::UserOpHostMemoryInputRegistry::Get().IsHostMemoryInput4Op(\n                op_type_name, pair.first, pair.second)) {\n          const LogicalBlobId& host_input_lbi =\n              GenLogicalBlobId(user_conf.input().at(pair.first).s(pair.second));\n          host_mem_input_lbis->emplace_back(host_input_lbi);\n        }\n      }\n    }\n  }\n}\n\nHashMap<DeviceType, CreateSubTskGphBuilderFn>* GlobalDeviceType2CreateSubTskGphBuilderFn() {\n  static HashMap<DeviceType, CreateSubTskGphBuilderFn>\n      global_device_type_create_sub_tsk_gph_builder_fn;\n  return &global_device_type_create_sub_tsk_gph_builder_fn;\n}\n\n}  // namespace\n\nTaskGraph::TaskGraph() = default;\nTaskGraph::~TaskGraph() = default;\n\nMaybe<void> RegisterCreateSubTskGphBuilderFn(DeviceType device_type,\n                                             const CreateSubTskGphBuilderFn& fn) {\n  auto* global_device_type_create_sub_tsk_gph_builder_fn =\n      GlobalDeviceType2CreateSubTskGphBuilderFn();\n  global_device_type_create_sub_tsk_gph_builder_fn->emplace(device_type, fn);\n  return Maybe<void>::Ok();\n}\n\nTaskEdge* TaskGraph::NewTaskEdgeWithLbi(const LogicalBlobId& lbi) {\n  TaskEdge* edge = NewEdge();\n  edge->AddLbi(lbi);\n  return edge;\n}\n\nTaskEdge* TaskGraph::NewTaskEdgeWithLbis(const 
std::vector<LogicalBlobId>& lbis) {\n  TaskEdge* edge = NewEdge();\n  edge->AddLbis(lbis);\n  return edge;\n}\n\nTaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,\n                                  const MemZoneId& dst_mem_zone_id) {\n  const auto& src_mem_zone_id = src_node->MemZoneId121();\n  const ProxyKey key(src_node, lbi, dst_mem_zone_id);\n  auto it = proxy2node.find(key);\n  if (it != proxy2node.cend()) {\n    // hit cache\n    return it->second;\n  } else {\n    if (src_mem_zone_id == dst_mem_zone_id) {\n      // in the same memory zone\n      proxy2node[key] = src_node;\n      return src_node;\n    } else if (dst_mem_zone_id.device_type() == DeviceType::kCPU) {\n      if (src_mem_zone_id.rank() == dst_mem_zone_id.rank()) {\n        // on the same node, not on the same device\n        // src must not be on the cpu mem zone, so copy d2h first\n        CHECK(IsMemcpyDtoHSupported(src_mem_zone_id.device_type()));\n        CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();\n        copy_task->Init(CopyHdType::D2H, src_mem_zone_id, lbi);\n        Connect<TaskNode>(src_node, NewTaskEdgeWithLbi(lbi), copy_task);\n        proxy2node[key] = copy_task;\n        return copy_task;\n      } else {\n        // not on the same node, need CopyCommNet from src to dst\n        // build src cpu proxy first\n        TaskNode* proxy_on_src_host =\n            GetProxyNode(src_node, lbi, GetNodeCPUMemZoneId(src_mem_zone_id.rank()));\n        CopyCommNetTaskNode* copy_comm_net_task = NewNode<CopyCommNetTaskNode>();\n        copy_comm_net_task->Init(dst_mem_zone_id.rank(), lbi);\n        Connect<TaskNode>(proxy_on_src_host, NewTaskEdgeWithLbi(lbi), copy_comm_net_task);\n        proxy2node[key] = copy_comm_net_task;\n        return copy_comm_net_task;\n      }\n    } else {\n      TaskNode* proxy_on_dst_host =\n          GetProxyNode(src_node, lbi, GetNodeCPUMemZoneId(dst_mem_zone_id.rank()));\n      CHECK(IsMemcpyHtoDSupported(dst_mem_zone_id.device_type()));\n      CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();\n      copy_task->Init(CopyHdType::H2D, dst_mem_zone_id, lbi);\n      Connect<TaskNode>(proxy_on_dst_host, NewTaskEdgeWithLbi(lbi), copy_task);\n      proxy2node[key] = copy_task;\n      return copy_task;\n    }\n  }\n  return nullptr;\n}\n\nTaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,\n                                  const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id) {\n  const int64_t dst_machine_id =\n      CHECK_JUST(dst_parallel_desc.MachineId4ParallelId(dst_parallel_id));\n  const int64_t dev_id = CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id));\n  DeviceType device_type = dst_parallel_desc.device_type();\n  auto device_index =\n      (device_type == DeviceType::kCPU ? 
0 : static_cast<DeviceId::device_index_t>(dev_id));\n  MemZoneId mem_zone_id{static_cast<MemZoneId::rank_t>(dst_machine_id), device_type, device_index};\n  return GetProxyNode(src_node, lbi, mem_zone_id);\n}\n\nvoid TaskGraph::ConnectCtrlEdge(CompTaskNode* src_task_node, CompTaskNode* dst_task_node) {\n  std::string regst_desc_name;\n  src_task_node->BuildCtrlRegstDesc(dst_task_node, &regst_desc_name);\n  TaskEdge* edge = NewEdge();\n  Connect<TaskNode>(src_task_node, edge, dst_task_node);\n  src_task_node->BindEdgeWithProducedRegst(edge, regst_desc_name);\n}\n\nvoid TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_nodes,\n                                 const std::vector<CompTaskNode*>& dst_task_nodes) {\n  CHECK_EQ(src_task_nodes.size(), dst_task_nodes.size());\n  FOR_RANGE(int32_t, i, 0, src_task_nodes.size()) {\n    ConnectCtrlEdge(src_task_nodes.at(i), dst_task_nodes.at(i));\n  }\n}\n\nvoid TaskGraph::RemoveEmptyRegsts() {\n  ForEachNode([&](TaskNode* node) { node->EraseUninitializedShapeProducedBlob(); });\n  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeConsumedRegst(); });\n  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeProducedRegst(); });\n  ForEachNode([&](TaskNode* node) { node->UnbindBnWithEmptyRegst(); });\n}\n\nvoid TaskGraph::MergeChainAndAddOrderingCtrlEdgeInSameChain() {\n  if (EnableLogicalChain()) {\n    // Ctrl edges in the chain have already been added in the logical chain pass, so\n    // there is no need to call BuildCtrlRegstDescInSameChain here.\n    MergeChainByLogicalChainId();\n  } else {\n    // TODO(chengcheng): erase old chain version in the future.\n    MergeChainByPhysicalTaskGraph();\n    BuildCtrlRegstDescInSameChain();\n  }\n}\n\nvoid TaskGraph::InitOrderedTaskNodes() {\n  // NOTE(chengcheng): Warning, ordered_task_nodes_ in topological order is NOT valid in\n  //  process-parallel compile, because the current rank task graph is incomplete.\n  TopoForEachNode([&](TaskNode* task_node) { ordered_task_nodes_.emplace_back(task_node); });\n}\n\nvoid TaskGraph::MergeChainByPhysicalTaskGraph() {\n  int64_t chain_id = 0;\n  for (auto* this_node : ordered_task_nodes_) {\n    // skip if this node has been set in a chain.\n    if (IsValidChainId(this_node->chain_id())) { continue; }\n\n    if (CanBeMergedInChain(this_node)) {\n      TraverseConnectedSubGraphMergeInThisChain(this_node, chain_id);\n    } else {\n      this_node->set_chain_id(chain_id);\n    }\n\n    ++chain_id;\n  }\n\n  // set order_in_chain by ordered_task_nodes_\n  HashMap<int64_t, int64_t> chain_id2order;\n  for (auto* node : ordered_task_nodes_) {\n    CHECK(IsValidChainId(node->chain_id()));\n    int64_t this_chain_id = node->chain_id();\n    if (chain_id2order.find(this_chain_id) == chain_id2order.end()) {\n      chain_id2order.emplace(this_chain_id, 0);\n    }\n    node->set_order_in_chain(chain_id2order.at(this_chain_id)++);\n  }\n}\n\nvoid TaskGraph::MergeChainByLogicalChainId() {\n  for (TaskNode* this_node : ordered_task_nodes_) {\n    CompTaskNode* comp_node = dynamic_cast<CompTaskNode*>(this_node);\n    if (!comp_node) { continue; }\n    const OperatorConf& conf = comp_node->op()->op_conf();\n    if (conf.has_logical_chain_id()) {\n      const int64_t logical_chain_id = conf.logical_chain_id();\n      CHECK(IsValidChainId(logical_chain_id));\n      this_node->set_chain_id(logical_chain_id);\n      CHECK(conf.has_order_in_logical_chain());\n      this_node->set_order_in_chain(conf.order_in_logical_chain());\n    }\n  }\n}\n\nvoid 
TaskGraph::BuildCtrlRegstDescInSameChain() {\n  auto GenPhysicalChainId = [](TaskNode* node) {\n    // NOTE(chengcheng): different ranks must not use the same chain id, to avoid bad ctrl links.\n    return (node->chain_id() << 31) | (node->machine_id());\n  };\n  HashMap<int64_t, TaskNode*> physical_chain_id2node;\n  // Note that the topological order of ordered_task_nodes_ is not guaranteed in separation plan\n  // compile, so adding ctrl edges along ordered_task_nodes_ in separation plan compile may cause\n  // deadlock.\n  for (auto* node : ordered_task_nodes_) {\n    if (IsConnectToTickOp(node)) { continue; }\n    // NOTE(chengcheng): skip invalid chain id\n    if (!IsValidChainId(node->chain_id())) { continue; }\n    int64_t physical_chain_id = GenPhysicalChainId(node);\n    auto iter = physical_chain_id2node.find(physical_chain_id);\n    if (iter == physical_chain_id2node.end()) {\n      CHECK(physical_chain_id2node.emplace(physical_chain_id, node).second);\n    } else {\n      TaskNode* src_node = iter->second;\n      TaskNode* dst_node = node;\n      std::string ctrl_regst_name;\n      bool build_ctrl_edge = src_node->BuildCtrlRegstDescIfNeed(dst_node, &ctrl_regst_name);\n      if (build_ctrl_edge) {\n        CHECK(!ctrl_regst_name.empty());\n        TaskEdge* edge = NewEdge();\n        Connect<TaskNode>(src_node, edge, dst_node);\n        src_node->BindEdgeWithProducedRegst(edge, ctrl_regst_name);\n      }\n      iter->second = dst_node;\n    }\n  }\n}\n\nvoid TaskGraph::GetInplaceOpBlobArgList(\n    InplaceObasInfo* obas_info, const HashSet<TaskNode*>& dev_nodes,\n    const std::function<const TaskNode*(const std::string&)>& TaskNode4OpName) const {\n  auto AddMutableInplaceArgPair = [&](TaskNode* node, const std::string& ibn,\n                                      const std::string& obn, const std::string& op_name) {\n    if (IsInplaceAllowed(node, {ibn, obn}, TaskNode4OpName)) {\n      auto* pair = obas_info->mut_inplace_oba_pairs.mutable_pair()->Add();\n      *pair->mutable_first() = GenOpBlobArg(op_name, ibn);\n      *pair->mutable_second() = GenOpBlobArg(op_name, obn);\n    }\n  };\n  auto AddConstInplaceArgPair = [&](TaskNode* node, const std::string& ibn, const std::string& obn,\n                                    const std::string& op_name) {\n    if (IsInplaceAllowed(node, {ibn, obn}, TaskNode4OpName)) {\n      auto* pair = obas_info->con_inplace_oba_pairs.mutable_pair()->Add();\n      *pair->mutable_first() = GenOpBlobArg(op_name, ibn);\n      *pair->mutable_second() = GenOpBlobArg(op_name, obn);\n    }\n  };\n\n  for (TaskNode* task_node : dev_nodes) {\n    if (task_node->exec_gph().node_num() != 1) { continue; }\n    const auto& op = *task_node->exec_gph().SoleNode()->op();\n    for (const std::string& ibn : op.input_bns()) {\n      if (op.InputBlobModifier4Ibn(ibn).is_mutable()) {\n        CHECK(IsInplaceAllowed(task_node, {ibn}, TaskNode4OpName));\n        *obas_info->mut_in_obas.mutable_oba()->Add() = GenOpBlobArg(op.op_name(), ibn);\n      }\n    }\n    for (const auto& pair : task_node->exec_gph().SoleNode()->mut_inplace_obn2ibn()) {\n      AddMutableInplaceArgPair(task_node, pair.second, pair.first, op.op_name());\n    }\n    for (const auto& pair : task_node->exec_gph().SoleNode()->con_inplace_obn2ibn()) {\n      AddConstInplaceArgPair(task_node, pair.second, pair.first, op.op_name());\n    }\n  }\n}\n\nvoid TaskGraph::GetSafeInplaceOpBlobArgList(\n    InplaceObasInfo* safe_obas_info, const HashSet<TaskNode*>& dev_nodes,\n    const std::function<bool(const std::string&, const std::string&)>& 
IsOpNameDataOrCtrlReachable)\n    const {\n  auto TaskNode4SoleOpName = MakeGetterTaskNode4SoleOpName(dev_nodes);\n  InplaceObasInfo obas_info;\n  GetInplaceOpBlobArgList(&obas_info, dev_nodes, TaskNode4SoleOpName);\n  auto Op4OpName = [&](const std::string& op_name) -> const Operator* {\n    return TaskNode4SoleOpName(op_name)->exec_gph().SoleNode()->op().get();\n  };\n  auto IsLbiAllConsumersReachable =\n      MakePredicatorIsLbiAllConsumersReachable(TaskNode4SoleOpName, IsOpNameDataOrCtrlReachable);\n  InplaceLbiGraph origin_graph(obas_info, Op4OpName);\n  InplaceLbiGraph safe_graph(*safe_obas_info, Op4OpName);\n  origin_graph.ComputeSafeInplaceObns(safe_obas_info, IsLbiAllConsumersReachable);\n  if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {\n    origin_graph.ToDotWithFilePath(\n        JoinPath(\"dot\", \"InplaceLbiGraph\", GlobalJobDesc().job_name() + \"_origin.dot\"));\n    safe_graph.ToDotWithFilePath(\n        JoinPath(\"dot\", \"InplaceLbiGraph\", GlobalJobDesc().job_name() + \"_safe.dot\"));\n  }\n}\n\nvoid TaskGraph::SetTaskRegstInplaceInfo(const InplaceObasInfo& obas_info,\n                                        const HashSet<TaskNode*>& dev_nodes) const {\n  auto TaskNode4SoleOpName = MakeGetterTaskNode4SoleOpName(dev_nodes);\n  auto Op4OpName = [&](const std::string& op_name) -> const Operator* {\n    return TaskNode4SoleOpName(op_name)->exec_gph().SoleNode()->op().get();\n  };\n  InplaceLbiGraph inplace_gph(obas_info, Op4OpName);\n  inplace_gph.ForEachConnectedComponent([&](const HashSet<const InplaceLbiNode*>& inplace_nodes) {\n    for (const auto* inplace_node : inplace_nodes) {\n      if (inplace_node->in_edges().empty()) { continue; }\n      const auto* inplace_edge = inplace_node->SoleInEdge();\n      auto* exec_node = TaskNode4SoleOpName(inplace_edge->op().op_name())->exec_gph().SoleNode();\n      RegstDesc* in_regst = exec_node->RegstDesc4BnInOp(inplace_edge->ibn());\n      RegstDesc* out_regst = exec_node->RegstDesc4BnInOp(inplace_edge->obn());\n      out_regst->set_hint_inplace_consumed_regst_desc_id(in_regst->regst_desc_id());\n    }\n  });\n}\n\nvoid TaskGraph::ForEachGpuDeviceNodes(\n    const std::function<void(const HashSet<TaskNode*>& dev_nodes)>& Handler) const {\n  HashMap<std::pair<int64_t, int64_t>, HashSet<TaskNode*>> global_dev_phy_id2nodes;\n  ForEachNode([&](TaskNode* task_node) {\n    if (task_node->device_type() == DeviceType::kCPU) { return; }\n    int64_t dev_phy_id = task_node->stream_id().device_id().device_index();\n    global_dev_phy_id2nodes[{task_node->machine_id(), dev_phy_id}].emplace(task_node);\n  });\n  for (const auto& pair : global_dev_phy_id2nodes) { Handler(pair.second); }\n}\n\nvoid TaskGraph::EnableInplaceMemSharing(\n    const std::function<bool(const std::string&, const std::string&)>&\n        IsOpNameDataOrCtrlReachable) {\n  ForEachGpuDeviceNodes([&](const HashSet<TaskNode*>& dev_nodes) {\n    EnableInplaceMemSharing(dev_nodes, IsOpNameDataOrCtrlReachable);\n  });\n}\n\nvoid TaskGraph::EnableInplaceMemSharing(\n    const HashSet<TaskNode*>& dev_nodes,\n    const std::function<bool(const std::string&, const std::string&)>&\n        IsOpNameDataOrCtrlReachable) {\n  InplaceObasInfo safe_inplace_obas_info;\n  GetSafeInplaceOpBlobArgList(&safe_inplace_obas_info, dev_nodes, IsOpNameDataOrCtrlReachable);\n  SetTaskRegstInplaceInfo(safe_inplace_obas_info, dev_nodes);\n}\n\nvoid TaskGraph::DecideExecutionOrder() {\n  // For one machine with no transfer available, the straighten algorithm for overlaps consumes 
a lot\n  // of memory\n  StraightenAlgorithmTag straighten_algorithm_tag =\n      GlobalJobDesc().job_conf().straighten_algorithm_tag_in_task_graph();\n  if (straighten_algorithm_tag == StraightenAlgorithmTag::kDisableStraighten\n      || (straighten_algorithm_tag == StraightenAlgorithmTag::kOverlap4Transfer\n          && GlobalProcessCtx::WorldSize() == 1)) {\n    InitOrderedTaskNodes();\n  } else {\n    StraightenNodes(this, &ordered_task_nodes_,\n                    Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream());\n  }\n}\n\n#define DEFINE_BLD_SUB_TASK_GRAPH_METHOD(method_name) \\\n  void TaskGraph::method_name BLD_SUB_TSK_GPH_MTHD_ARGS()\n\nDEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {\n  const OpNode* src_op_node = op_edge->src_node();\n  const OpNode* dst_op_node = op_edge->dst_node();\n  std::vector<LogicalBlobId> host_mem_input_lbis;\n  GetHostInputLbis4OpNode(dst_op_node, &host_mem_input_lbis);\n  for (const LogicalBlobId& lbi : op_edge->lbis()) {\n    std::vector<TaskNode*> in_nodes(sorted_src_comp_tasks.begin(), sorted_src_comp_tasks.end());\n    std::vector<TaskNode*> out_nodes;\n    out_nodes.reserve(sorted_dst_comp_tasks.size());\n    std::vector<std::vector<TaskNode*>> sorted_ctrl_tasks;\n    const NdSbp& src_nd_sbp = src_op_node->NdSbp4Lbi(lbi);\n    const NdSbp& dst_nd_sbp = dst_op_node->NdSbp4Lbi(lbi);\n    const ParallelDesc& src_parallel_desc = src_op_node->parallel_desc();\n    const ParallelDesc& dst_parallel_desc = [&]() {\n      if (std::find(host_mem_input_lbis.begin(), host_mem_input_lbis.end(), lbi)\n          != host_mem_input_lbis.end()) {\n        return *CHECK_JUST(\n            ReplaceDeviceType(SymbolOf(dst_op_node->parallel_desc()), DeviceType::kCPU));\n      } else {\n        return dst_op_node->parallel_desc();\n      }\n    }();\n    const BlobDesc& blob_desc = src_op_node->LogicalBlobDesc4Lbi(lbi);\n    VLOG(3) << \"src op: \" << src_op_node->op().op_name()\n            << \" dst op: \" << dst_op_node->op().op_name()\n            << \" src_parallel_conf: \" << src_parallel_desc.parallel_conf().DebugString()\n            << \" dst parallel conf: \" << dst_parallel_desc.parallel_conf().DebugString()\n            << \" src_nd_sbp \" << src_nd_sbp.DebugString() << \" dst nd_sbp \"\n            << dst_nd_sbp.DebugString();\n    std::shared_ptr<SubTskGphBuilderStatus> status;\n    const DeviceType device_type = [&src_parallel_desc, &dst_parallel_desc]() {\n      return src_parallel_desc.device_type() != DeviceType::kCPU ? 
src_parallel_desc.device_type()\n                                                                 : dst_parallel_desc.device_type();\n    }();\n    if (device_type != DeviceType::kCPU\n        && device_type2sub_tsk_gph_builder_.find(device_type)\n               != device_type2sub_tsk_gph_builder_.end()) {\n      auto maybe_status =                                                             // NOLINT\n          device_type2sub_tsk_gph_builder_                                            // NOLINT\n              .at(device_type)                                                        // NOLINT\n              ->Build(sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes,           // NOLINT\n                      &sorted_ctrl_tasks, src_parallel_desc, dst_parallel_desc, lbi,  // NOLINT\n                      blob_desc, src_nd_sbp, dst_nd_sbp,                              // NOLINT\n                      *(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get()));       // NOLINT\n      if (maybe_status.IsOk()) { status = CHECK_JUST(maybe_status); }\n    }\n    if (!status) {\n      status = CHECK_JUST(hierarchical_sub_tsk_gph_builder_->Build(\n          sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks,\n          src_parallel_desc, dst_parallel_desc, lbi, blob_desc, src_nd_sbp, dst_nd_sbp,\n          *(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get())));\n    }\n    boxing_logger_->Log(*status, src_op_node->op().op_name(), dst_op_node->op().op_name(),\n                        src_parallel_desc, dst_parallel_desc, src_nd_sbp, dst_nd_sbp, lbi,\n                        blob_desc);\n    CHECK_EQ(out_nodes.size(), sorted_dst_comp_tasks.size());\n    FOR_RANGE(size_t, i, 0, out_nodes.size()) {\n      ConnectWithLbi(out_nodes.at(i), sorted_dst_comp_tasks.at(i), lbi);\n    }\n    if (!sorted_ctrl_tasks.empty()) {\n      CHECK_EQ(sorted_ctrl_tasks.size(), sorted_dst_comp_tasks.size());\n      FOR_RANGE(size_t, i, 0, sorted_dst_comp_tasks.size()) {\n        for (TaskNode* ctrl_node : sorted_ctrl_tasks.at(i)) {\n          std::string regst_desc_name;\n          ctrl_node->BuildCtrlRegstDesc(sorted_dst_comp_tasks.at(i), &regst_desc_name);\n          TaskEdge* edge = NewEdge();\n          Connect<TaskNode>(ctrl_node, edge, sorted_dst_comp_tasks.at(i));\n          ctrl_node->BindEdgeWithProducedRegst(edge, regst_desc_name);\n        }\n      }\n    }\n  }\n}\n\nDEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne) {\n  std::vector<LogicalBlobId> host_mem_input_lbis;\n  GetHostInputLbis4OpNode(op_edge->dst_node(), &host_mem_input_lbis);\n  CHECK_EQ(sorted_src_comp_tasks.size(), sorted_dst_comp_tasks.size());\n  FOR_RANGE(size_t, i, 0, sorted_src_comp_tasks.size()) {\n    for (const LogicalBlobId& lbi : op_edge->lbis()) {\n      bool is_host_mem_input =\n          std::find(host_mem_input_lbis.begin(), host_mem_input_lbis.end(), lbi)\n          != host_mem_input_lbis.end();\n      BuildTaskPath(sorted_src_comp_tasks.at(i), sorted_dst_comp_tasks.at(i), lbi,\n                    is_host_mem_input);\n    }\n  }\n}\n\nDEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBroadcastToBroadcast) {\n  std::vector<LogicalBlobId> host_mem_input_lbis;\n  GetHostInputLbis4OpNode(op_edge->dst_node(), &host_mem_input_lbis);\n  for (CompTaskNode* dst_node : sorted_dst_comp_tasks) {\n    CompTaskNode* nearest_src_node =\n        SubTskGphBuilderUtil::FindNearestNode(sorted_src_comp_tasks, dst_node);\n    CHECK_NOTNULL(nearest_src_node);\n    for (const LogicalBlobId& lbi : op_edge->lbis()) {\n      bool 
is_host_mem_input =\n          std::find(host_mem_input_lbis.begin(), host_mem_input_lbis.end(), lbi)\n          != host_mem_input_lbis.end();\n      BuildTaskPath(nearest_src_node, dst_node, lbi, is_host_mem_input);\n    }\n  }\n}\n\nDEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialInLbiConnect) {\n  const Operator& src_op = op_edge->src_node()->op();\n  const Operator& dst_op = op_edge->dst_node()->op();\n  HashSet<LogicalBlobId> lbis;\n  std::vector<LogicalBlobId> host_mem_input_lbis;\n  GetHostInputLbis4OpNode(op_edge->dst_node(), &host_mem_input_lbis);\n  for (const auto& obn : src_op.output_bns()) { lbis.insert(src_op.BnInOp2Lbi(obn)); }\n  CHECK_EQ(sorted_src_comp_tasks.size(), 1);\n  CHECK_EQ(dst_op.input_bns().size(), sorted_dst_comp_tasks.size());\n  FOR_RANGE(int, i, 0, sorted_dst_comp_tasks.size()) {\n    const auto& lbi = dst_op.BnInOp2Lbi(dst_op.input_bns().Get(i));\n    if (lbis.find(lbi) != lbis.end()) {\n      bool is_host_mem_input =\n          std::find(host_mem_input_lbis.begin(), host_mem_input_lbis.end(), lbi)\n          != host_mem_input_lbis.end();\n      BuildTaskPath(sorted_src_comp_tasks.at(0), sorted_dst_comp_tasks.at(i), lbi,\n                    is_host_mem_input);\n    }\n  }\n}\n\nDEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialOutLbiConnect) {\n  const Operator& src_op = op_edge->src_node()->op();\n  const Operator& dst_op = op_edge->dst_node()->op();\n  HashSet<LogicalBlobId> lbis;\n  std::vector<LogicalBlobId> host_mem_input_lbis;\n  GetHostInputLbis4OpNode(op_edge->dst_node(), &host_mem_input_lbis);\n  for (const auto& ibn : dst_op.input_bns()) { lbis.insert(dst_op.BnInOp2Lbi(ibn)); }\n  CHECK_EQ(sorted_dst_comp_tasks.size(), 1);\n  CHECK_EQ(src_op.output_bns().size(), sorted_src_comp_tasks.size());\n  FOR_RANGE(int, i, 0, sorted_src_comp_tasks.size()) {\n    const auto& lbi = src_op.BnInOp2Lbi(src_op.output_bns().Get(i));\n    if (lbis.find(lbi) != lbis.end()) {\n      bool is_host_mem_input =\n          std::find(host_mem_input_lbis.begin(), host_mem_input_lbis.end(), lbi)\n          != host_mem_input_lbis.end();\n      BuildTaskPath(sorted_src_comp_tasks.at(i), sorted_dst_comp_tasks.at(0), lbi,\n                    is_host_mem_input);\n    }\n  }\n}\n\nDEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphBySrcSubsetConnect) {\n  std::function<Maybe<CompTaskNode*>(int64_t mchn_id, int64_t thrd_id)> TaskNode4MachineId7ThrdId;\n  CHECK_JUST(\n      MakeGetterTaskNode4MachineId7ThrdId(sorted_src_comp_tasks, &TaskNode4MachineId7ThrdId));\n  for (CompTaskNode* dst_task_node : sorted_dst_comp_tasks) {\n    CompTaskNode* src_task_node = CHECK_JUST(\n        TaskNode4MachineId7ThrdId(dst_task_node->machine_id(), dst_task_node->thrd_id()));\n    Connect<TaskNode>(src_task_node, NewTaskEdgeWithLbis(op_edge->lbis()), dst_task_node);\n  }\n}\n\nDEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByDstSubsetConnect) {\n  std::function<Maybe<CompTaskNode*>(int64_t mchn_id, int64_t thrd_id)> TaskNode4MachineId7ThrdId;\n  CHECK_JUST(\n      MakeGetterTaskNode4MachineId7ThrdId(sorted_dst_comp_tasks, &TaskNode4MachineId7ThrdId));\n  for (CompTaskNode* src_task_node : sorted_src_comp_tasks) {\n    CompTaskNode* dst_task_node = CHECK_JUST(\n        TaskNode4MachineId7ThrdId(src_task_node->machine_id(), src_task_node->thrd_id()));\n    Connect<TaskNode>(src_task_node, NewTaskEdgeWithLbis(op_edge->lbis()), dst_task_node);\n  }\n}\n\nDEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D) {\n  CHECK_EQ(sorted_src_comp_tasks.size(), sorted_dst_comp_tasks.size());\n 
 FOR_RANGE(size_t, i, 0, sorted_src_comp_tasks.size()) {\n    CompTaskNode* src = sorted_src_comp_tasks.at(i);\n    CompTaskNode* dst = sorted_dst_comp_tasks.at(i);\n    for (const LogicalBlobId& lbi : op_edge->lbis()) { ConnectWithLbi(src, dst, lbi); }\n  }\n}\n\nvoid TaskGraph::ConnectWithLbi(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi) {\n  if (src_node == dst_node) { return; }\n  for (TaskEdge* out_edge : src_node->out_edges()) {\n    TaskNode* out_node = out_edge->dst_node();\n    if (out_node == dst_node) {\n      out_edge->AddLbi(lbi);\n      return;\n    }\n  }\n\n  TaskEdge* connected_edge = NewEdge();\n  connected_edge->AddLbi(lbi);\n  Connect<TaskNode>(src_node, connected_edge, dst_node);\n}\n\nvoid TaskGraph::BuildTaskPath(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi,\n                              bool is_host_mem_input) {\n  const MemZoneId dst_mem_zone_id = [&]() {\n    if (is_host_mem_input) {\n      MemZoneId mem_zone_id = dst_node->MemZoneId121();\n      return MemZoneId(mem_zone_id.rank(), DeviceType::kCPU, 0);\n    } else {\n      return dst_node->MemZoneId121();\n    }\n  }();\n  TaskNode* proxy_node = GetProxyNode(src_node, lbi, dst_mem_zone_id);\n  ConnectWithLbi(proxy_node, dst_node, lbi);\n}\n\nMaybe<void> GlobalTaskGraph::Init() {\n  OpGraph* op_graph = Singleton<OpGraph>::Get();\n  sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this));\n  boxing_logger_ = CreateBoxingLogger();\n  // Register the corresponding task graph builders by device type and store them in the map\n  const auto* global_device_type_create_sub_tsk_gph_builder_fn =\n      GlobalDeviceType2CreateSubTskGphBuilderFn();\n  for (const auto& pair : *global_device_type_create_sub_tsk_gph_builder_fn) {\n    device_type2sub_tsk_gph_builder_.emplace(pair.first, pair.second());\n  }\n  hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder());\n  HashMap<const OpNode*, std::vector<CompTaskNode*>> op_node2sorted_comp_tasks;\n\n  op_graph->ForEachNode([&](const OpNode* op_node) {\n    std::vector<CompTaskNode*>* sorted_comp_tasks = &(op_node2sorted_comp_tasks[op_node]);\n    GenSortedCompTaskNodes(op_node, sorted_comp_tasks);\n    for (CompTaskNode* comp_task : *sorted_comp_tasks) { AddAllocatedNode(comp_task); }\n  });\n\n  op_graph->ForEachEdge([&](const OpEdge* op_edge) {\n    BldSubTskGphMthd method = GetMthdForBldSubTskGph(op_edge);\n    (this->*method)(op_edge, op_node2sorted_comp_tasks.at(op_edge->src_node()),\n                    op_node2sorted_comp_tasks.at(op_edge->dst_node()));\n  });\n\n  ForEachOpGraphNecessaryCtrlEdge(op_graph, [&](const OpNode* src, const OpNode* dst) {\n    const auto& src_task_nodes = op_node2sorted_comp_tasks.at(src);\n    const auto& dst_task_nodes = op_node2sorted_comp_tasks.at(dst);\n    if (src->op().op_conf().has_src_subset_tick_conf()) {\n      UNIMPLEMENTED();\n    } else if (dst->op().op_conf().has_dst_subset_tick_conf()) {\n      UNIMPLEMENTED();\n    } else {\n      ConnectCtrlEdges(src_task_nodes, dst_task_nodes);\n    }\n  });\n\n  if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) { ToDotWithAutoFilePath(); }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BoxingTaskGraph::Init(\n    const std::function<void(size_t, const std::function<void(size_t i)>&)>& ParallelRunLoop) {\n  OpGraph* op_graph = Singleton<OpGraph>::Get();\n  sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this));\n  boxing_logger_ = CreateBoxingLogger();\n  // Register the corresponding 
task graph builders by device type and store them in the map\n  const auto* global_device_type_create_sub_tsk_gph_builder_fn =\n      GlobalDeviceType2CreateSubTskGphBuilderFn();\n  for (const auto& pair : *global_device_type_create_sub_tsk_gph_builder_fn) {\n    device_type2sub_tsk_gph_builder_.emplace(pair.first, pair.second());\n  }\n\n  hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder());\n\n  const auto& TryCreateSortedCompTaskNodes = [&](const OpNode* op_node) {\n    if (boxing_related_op_node2sorted_comp_tasks_.count(op_node) > 0) { return; }\n    std::vector<CompTaskNode*>* sorted_comp_tasks =\n        &(boxing_related_op_node2sorted_comp_tasks_[op_node]);\n    GenSortedCompTaskNodes(op_node, sorted_comp_tasks);\n    for (CompTaskNode* comp_task : *sorted_comp_tasks) { AddAllocatedNode(comp_task); }\n  };\n  op_graph->ForEachEdge([&](const OpEdge* op_edge) {\n    if (!op_edge->NeedBoxing()) { return; }\n    TryCreateSortedCompTaskNodes(op_edge->src_node());\n    TryCreateSortedCompTaskNodes(op_edge->dst_node());\n    BldSubTskGphMthd method = GetMthdForBldSubTskGph(op_edge);\n    (this->*method)(op_edge, boxing_related_op_node2sorted_comp_tasks_.at(op_edge->src_node()),\n                    boxing_related_op_node2sorted_comp_tasks_.at(op_edge->dst_node()));\n  });\n  ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, std::placeholders::_1));\n  CreateOpNode2TaskIds(ParallelRunLoop);\n  return Maybe<void>::Ok();\n}\n\nvoid BoxingTaskGraph::CreateOpNode2TaskIds(\n    const std::function<void(size_t, const std::function<void(size_t i)>&)>& ParallelRunLoop) {\n  const OpGraph* op_graph = Singleton<OpGraph>::Get();\n  std::vector<const OpNode*> op_nodes;\n  op_nodes.reserve(op_graph->node_num());\n  op_graph->ForEachNode([&](OpNode* op_node) {\n    if (boxing_related_op_node2sorted_comp_tasks_.count(op_node) == 0) {\n      op_nodes.push_back(op_node);\n      boxing_unrelated_op_node2sorted_task_ids_[op_node].reserve(\n          op_node->parallel_desc().parallel_num());\n    }\n  });\n  ParallelRunLoop(op_nodes.size(), [&](size_t i) {\n    const OpNode* op_node = op_nodes.at(i);\n    TaskType task_type = TaskType4OpNode(op_node);\n    const auto& parallel_desc = op_node->parallel_desc();\n    auto* task_ids = &boxing_unrelated_op_node2sorted_task_ids_[op_node];\n    for (int parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) {\n      const auto& stream_id = GetStreamId(op_node, parallel_id, task_type);\n      task_ids->push_back(Singleton<IDMgr>::Get()->GetTaskIdGenerator()->Generate(stream_id));\n    }\n  });\n}\n\nnamespace {\n\nbool IsCompTaskNodeDutyRank(int64_t current_rank, const ParallelDesc& parallel_desc,\n                            int64_t task_node_rank) {\n  if (current_rank == 0) {\n    // make sure master knows at least one op_node.\n    return CHECK_JUST(parallel_desc.MachineId4ParallelId(0)) == task_node_rank;\n  } else if (parallel_desc.HasMachineId(current_rank)) {\n    // workers only care about their own rank.\n    return current_rank == task_node_rank;\n  } else {\n    return false;\n  }\n}\n\n// A template function to process a task node according to its concrete task node type.\n// RetT, function return type\n// HandleTransportTaskNode, if the task node is a transport task node, call this processing function\n// HandleComputeTaskNode, if the task node is a compute task node, call this processing function\n// task_node, the input task node\ntemplate<typename RetT, typename HandleTransportTaskNodeT, typename 
HandleComputeTaskNodeT>\nRetT TaskNodeVisitor(TaskNode* task_node, const HandleTransportTaskNodeT& HandleTransportTaskNode,\n                     const HandleComputeTaskNodeT& HandleComputeTaskNode) {\n  auto* transport_task_node = dynamic_cast<TransportTaskNode*>(task_node);\n  if (transport_task_node != nullptr) {\n    return HandleTransportTaskNode(transport_task_node);\n  } else {\n    auto* comp_task_node = dynamic_cast<CompTaskNode*>(task_node);\n    if (comp_task_node != nullptr) {\n      return HandleComputeTaskNode(comp_task_node);\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n}\n\n}  // namespace\n\n/*static*/ bool BoxingTaskGraph::SelectTaskNodeByRank(TaskNode* task_node, int64_t rank) {\n  return TaskNodeVisitor<bool>(\n      task_node, [&](TransportTaskNode* task_node) { return task_node->machine_id() == rank; },\n      [&](CompTaskNode* task_node) {\n        const auto& machine_id = task_node->machine_id();\n        return IsCompTaskNodeDutyRank(rank, task_node->op_node()->parallel_desc(), machine_id);\n      });\n}\n\nvoid BoxingTaskGraph::ToProto(const std::function<bool(TaskNode*)>& Pick,\n                              BoxingTaskGraphProto* proto) const {\n  const auto sources = [&]() -> std::list<TaskNode*> {\n    HashSet<TaskNode*> sources;\n    ForEachNode([&](TaskNode* task_node) {\n      if (Pick(task_node)) { sources.insert(task_node); }\n    });\n    HashSet<TaskNode*> sources_out;\n    for (auto* source : sources) {\n      // The consumed task_ids must be generated from out_nodes.\n      source->ForEachNodeOnOutEdge([&](TaskNode* out_node) {\n        if (!sources.count(out_node)) { sources_out.insert(out_node); }\n      });\n    }\n    sources.insert(sources_out.begin(), sources_out.end());\n    return std::list<TaskNode*>{sources.begin(), sources.end()};\n  }();\n  const auto& TransportTaskNodeToProto = [&](TransportTaskNode* task_node) {\n    task_node->ToTransportTaskProtoIf(proto->mutable_transport_task()->Add());\n  };\n  const auto& ComputeTaskNodeToProto = [&](CompTaskNode* task_node) {\n    auto* map = proto->mutable_boxing_related_op_name2compute_tasks();\n    const auto& op_name = task_node->op_node()->op().op_name();\n    auto* parallel_id2task_proto = (*map)[op_name].mutable_parallel_id2task();\n    int64_t parallel_id = task_node->parallel_id();\n    task_node->ToProto(&(*parallel_id2task_proto)[parallel_id], /*check=*/false);\n  };\n  HashSet<TaskNode*> rank_task_nodes;\n  BfsForEachNode(sources, &TaskNode::ForEachNodeOnInEdge, [&](TaskNode* task_node) {\n    rank_task_nodes.insert(task_node);\n    TaskNodeVisitor<void>(task_node, TransportTaskNodeToProto, ComputeTaskNodeToProto);\n  });\n  const auto rank_task_edges = [&] {\n    HashSet<TaskEdge*> rank_task_edges;\n    const auto& TryInsertEdge = [&](TaskEdge* edge) {\n      if (rank_task_nodes.count(edge->src_node()) > 0\n          && rank_task_nodes.count(edge->dst_node()) > 0) {\n        rank_task_edges.insert(edge);\n      }\n    };\n    for (const auto* task_node : rank_task_nodes) {\n      for (auto* in_edge : task_node->in_edges()) { TryInsertEdge(in_edge); }\n      for (auto* out_edge : task_node->out_edges()) { TryInsertEdge(out_edge); }\n    }\n    return rank_task_edges;\n  }();\n  for (auto* edge : rank_task_edges) { edge->ToProto(proto->mutable_task_edge()->Add()); }\n  for (const auto& pair : boxing_unrelated_op_node2sorted_task_ids_) {\n    const auto& op_name = pair.first->op().op_name();\n    auto* vec = &(*proto->mutable_boxing_unrelated_op_name2task_ids())[op_name];\n    for 
(const auto& task_id : pair.second) { vec->add_task_id(EncodeTaskIdToInt64(task_id)); }\n  }\n}\n\nRankTaskGraph::RankTaskGraph(const std::shared_ptr<BoxingTaskGraphProto>& boxing_task_graph_proto,\n                             int64_t current_rank)\n    : boxing_task_graph_proto_(boxing_task_graph_proto),\n      current_rank_(current_rank),\n      task_graph_rebuild_ctx_(std::make_unique<TaskGraphRebuildCtx>()) {}\n\nMaybe<CompTaskNode*> RankTaskGraph::TryGetBoxingRelatedComTaskNode(const OpNode* op_node,\n                                                                   int64_t parallel_id) {\n  const auto& op_name = op_node->op().op_name();\n  auto iter = boxing_task_graph_proto_->boxing_related_op_name2compute_tasks().find(op_name);\n  if (iter == boxing_task_graph_proto_->boxing_related_op_name2compute_tasks().end()) {\n    return nullptr;\n  }\n  if (iter == boxing_task_graph_proto_->boxing_related_op_name2compute_tasks().end()) {\n    return nullptr;\n  }\n  auto task_iter = iter->second.parallel_id2task().find(parallel_id);\n  if (task_iter == iter->second.parallel_id2task().end()) { return nullptr; }\n  int64_t task_id = task_iter->second.task_id();\n  auto* task_node = JUST(task_graph_rebuild_ctx_->TaskNode4Id(task_id));\n  auto* comp_task_node = dynamic_cast<CompTaskNode*>(task_node);\n  CHECK_NOTNULL_OR_RETURN(comp_task_node) << \"invalid task_type. task_id: \" << task_id;\n  return comp_task_node;\n}\n\nMaybe<CompTaskNode*> RankTaskGraph::CreateOrFindRankCompTaskNodeByParallelId(const OpNode* op_node,\n                                                                             int64_t parallel_id) {\n  auto* comp_task_node = JUST(TryGetBoxingRelatedComTaskNode(op_node, parallel_id));\n  if (comp_task_node != nullptr) { return comp_task_node; }\n  auto iter = op_node2comp_task_node_.find(op_node);\n  if (iter != op_node2comp_task_node_.end()) { return iter->second; }\n\n  const TaskId task_id = *JUST([&]() -> Maybe<TaskId> {\n    const auto& map = boxing_task_graph_proto_->boxing_unrelated_op_name2task_ids();\n    const auto& iter = map.find(op_node->op().op_name());\n    CHECK_OR_RETURN(iter != map.end());\n    CHECK_LT_OR_RETURN(parallel_id, iter->second.task_id_size());\n    return DecodeTaskIdFromInt64(iter->second.task_id().Get(parallel_id));\n  }());\n  const auto& GetStreamIdFromMaster = [&](const OpNode* op_node, int64_t parallel_id, TaskType) {\n    return task_id.stream_id();\n  };\n  auto comp_task_node_ptr = GenCompTaskNode(op_node, parallel_id, GetStreamIdFromMaster);\n  comp_task_node_ptr->update_new_task_id(task_id);\n  AddAllocatedNode(comp_task_node_ptr);\n  CHECK_OR_RETURN(op_node2comp_task_node_.emplace(op_node, comp_task_node_ptr).second)\n      << \"Got dupliacted op_node \" << op_node->op().op_name();\n  return comp_task_node_ptr;\n}\n\nMaybe<CompTaskNode*> RankTaskGraph::CreateOrFindRankCompTaskNodeByRank(const OpNode* op_node,\n                                                                       int64_t rank) {\n  CHECK_OR_RETURN(op_node->parallel_desc().HasMachineId(rank))\n      << \"rank is not contained in the placment\";\n  int64_t parallel_id = -1;\n  CHECK_OR_RETURN(JUST(op_node->parallel_desc().TryGetParallelId(rank, &parallel_id)))\n      << \"parallel_id not found.\";\n  return CreateOrFindRankCompTaskNodeByParallelId(op_node, parallel_id);\n}\n\nMaybe<CompTaskNode*> RankTaskGraph::TryGetRankCompTaskNode(const OpNode* op_node, int64_t rank) {\n  if (!op_node->parallel_desc().HasMachineId(rank)) { return nullptr; }\n  int64_t parallel_id = 
-1;\n  CHECK_OR_RETURN(JUST(op_node->parallel_desc().TryGetParallelId(rank, &parallel_id)))\n      << \"parallel_id not found.\";\n  auto* comp_task_node = JUST(TryGetBoxingRelatedComTaskNode(op_node, parallel_id));\n  if (comp_task_node != nullptr) { return comp_task_node; }\n  auto iter = op_node2comp_task_node_.find(op_node);\n  CHECK_OR_RETURN(iter != op_node2comp_task_node_.end())\n      << \"op_node \" << op_node->op().op_name() << \" not found.\";\n  return iter->second;\n}\n\nMaybe<void> RankTaskGraph::AddBoxingReletedCompTaskNodesFromProto() {\n  OpGraph* op_graph = Singleton<OpGraph>::Get();\n  for (const auto& pair : boxing_task_graph_proto_->boxing_related_op_name2compute_tasks()) {\n    const OpNode* op_node = op_graph->OpNode4OpName(pair.first);\n    for (const auto& pair : pair.second.parallel_id2task()) {\n      const auto& task_proto = pair.second;\n      CHECK_OR_RETURN(task_id2task_proto_.emplace(task_proto.task_id(), &task_proto).second)\n          << \"redundant task_id.\";\n      CompTaskNode* comp_task_node = NewCompTaskNode4OpNode(op_node);\n      comp_task_node->set_op_node(op_node);\n      AddAllocatedNode(comp_task_node);\n      // Note here has no consume regst\n      // Init task node and produce regst\n      comp_task_node->InitFromProtoExceptConsumedRegsts(task_proto);\n      JUST(task_graph_rebuild_ctx_->AddTaskNode(comp_task_node));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RankTaskGraph::CreateAndPartiallyInitTransportTaskNodesFromProto() {\n  for (const auto& transport_task_proto : boxing_task_graph_proto_->transport_task()) {\n    const auto& task_proto = transport_task_proto.task_proto();\n    CHECK_OR_RETURN(task_id2task_proto_.emplace(task_proto.task_id(), &task_proto).second)\n        << \"redundant task_id.\";\n    auto* task_node = JUST(CreateTransportTask::Visit(task_proto.task_type()));\n    AddAllocatedNode(task_node);\n    // Init task node and produce regst\n    task_node->InitFromProtoExceptConsumedRegsts(transport_task_proto.task_proto());\n    JUST(task_graph_rebuild_ctx_->AddTaskNode(task_node));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RankTaskGraph::AddTransportTaskEdgesFromProto() {\n  for (const auto& task_edge_proto : boxing_task_graph_proto_->task_edge()) {\n    TaskEdge* edge = NewEdge();\n    auto* src_task_node = JUST(task_graph_rebuild_ctx_->TaskNode4Id(task_edge_proto.src_task_id()));\n    auto* dst_task_node = JUST(task_graph_rebuild_ctx_->TaskNode4Id(task_edge_proto.dst_task_id()));\n    Connect<TaskNode>(src_task_node, edge, dst_task_node);\n    JUST(edge->InitFromProto(task_edge_proto, *task_graph_rebuild_ctx_));\n    JUST(task_graph_rebuild_ctx_->AddTaskEdge(edge, task_edge_proto.task_edge_uid()));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RankTaskGraph::InitTransportTaskNodesFromProto() {\n  for (const auto& transport_task_proto : boxing_task_graph_proto_->transport_task()) {\n    int64_t task_id = transport_task_proto.task_proto().task_id();\n    auto* task_node = JUST(task_graph_rebuild_ctx_->TaskNode4Id(task_id));\n    auto* transport_task_node = dynamic_cast<TransportTaskNode*>(task_node);\n    CHECK_NOTNULL_OR_RETURN(transport_task_node)\n        << \"task node is not a TransportTaskNode. 
task_id\" << task_id;\n    JUST(transport_task_node->InitTransportTaskFromProtoIf(transport_task_proto,\n                                                           *task_graph_rebuild_ctx_));\n  }\n  return Maybe<void>::Ok();\n}\n\nbool RankTaskGraph::ContainRank(const OpNode* op_node, int64_t rank) const {\n  return op_node->parallel_desc().HasMachineId(rank);\n}\n\nMaybe<void> RankTaskGraph::ConnectDataEdges(const OpEdge* op_edge, int64_t rank) {\n  if (!op_edge->NeedBoxing()) {\n    auto* src_task_node = JUST(TryGetRankCompTaskNode(op_edge->src_node(), rank));\n    auto* dst_task_node = JUST(TryGetRankCompTaskNode(op_edge->dst_node(), rank));\n    if (ContainRank(op_edge->src_node(), rank)) {\n      CHECK_NOTNULL_OR_RETURN(src_task_node) << \"src_task_node should not be nullptr. op_name: \"\n                                             << op_edge->src_node()->op().op_name();\n    }\n    if (ContainRank(op_edge->dst_node(), rank)) {\n      CHECK_NOTNULL_OR_RETURN(dst_task_node) << \"dst_task_node should not be nullptr. op_name: \"\n                                             << op_edge->dst_node()->op().op_name();\n    }\n    if (src_task_node != nullptr && dst_task_node != nullptr) {\n      for (const auto& lbi : op_edge->lbis()) { ConnectWithLbi(src_task_node, dst_task_node, lbi); }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RankTaskGraph::ConnectCtrlEdges(const OpNode* src, const OpNode* dst, int64_t rank) {\n  if ((ContainRank(src, rank) && ContainRank(dst, rank))) {\n    auto* src_task_node = CHECK_JUST(TryGetRankCompTaskNode(src, rank));\n    auto* dst_task_node = CHECK_JUST(TryGetRankCompTaskNode(dst, rank));\n    if (src->op().op_conf().has_src_subset_tick_conf()) {\n      UNIMPLEMENTED_THEN_RETURN() << \"ctrl edge from src_subset_tick is not supported.\";\n    } else if (dst->op().op_conf().has_dst_subset_tick_conf()) {\n      UNIMPLEMENTED_THEN_RETURN() << \"ctrl edge to dst_subset_tick is not supported.\";\n    } else {\n      ConnectCtrlEdge(CHECK_NOTNULL(src_task_node), CHECK_NOTNULL(dst_task_node));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nbool RankTaskGraph::IsDutyRank(const ParallelDesc& parallel_desc, int64_t rank) const {\n  return IsComputTaskNodeDutyRank(current_rank_, parallel_desc, rank);\n}\n\ntemplate<typename DoEachRankT>\nMaybe<void> RankTaskGraph::DoRankDuty(const ParallelDesc& parallel_desc,\n                                      const DoEachRankT& DoWithRank) {\n  if (current_rank_ == 0) {\n    // make sure master knows at least one op_node.\n    JUST(DoWithRank(JUST(parallel_desc.MachineId4ParallelId(0))));\n  } else if (parallel_desc.HasMachineId(current_rank_)) {\n    // workers only care their own rank.\n    JUST(DoWithRank(current_rank_));\n  } else {\n    // Do nothing.\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RankTaskGraph::InitRegstDescsConsumers() {\n  const auto& RegstDesc4Id = [&](int64_t regst_desc_id) -> Maybe<RegstDesc> {\n    return JUST(task_graph_rebuild_ctx_->RegstDesc4Id(regst_desc_id));\n  };\n  JUST(MaybeForEachNode([&](TaskNode* task_node) -> Maybe<void> {\n    const auto& task_proto = *JUST(MapAt(task_id2task_proto_, task_node->task_id()));\n    JUST(task_node->InitConsumedRegstsFromProto(task_proto, RegstDesc4Id));\n    return Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RankTaskGraph::Init(const HashSet<std::string>& var_op_names) {\n  JUST(AddBoxingReletedCompTaskNodesFromProto());\n  JUST(CreateAndPartiallyInitTransportTaskNodesFromProto());\n  
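  // Note on ordering: edges are connected only after all task nodes above have been\n  // registered in the rebuild context; transport task nodes are fully initialized only\n  // once their edges exist; consumed regsts are wired last because they reference regsts\n  // produced by other task nodes.\n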
  JUST(AddTransportTaskEdgesFromProto());\n  JUST(InitTransportTaskNodesFromProto());\n  JUST(InitRegstDescsConsumers());\n
  // Note that all tasks added by the code above come from the BoxingTaskGraph, so they are\n  // all boxing related.\n  OpGraph* op_graph = Singleton<OpGraph>::Get();\n
  JUST(op_graph->MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {\n    JUST(DoRankDuty(op_node->parallel_desc(), [&](int64_t rank) -> Maybe<void> {\n      JUST(CreateOrFindRankCompTaskNodeByRank(op_node, rank));\n      return Maybe<void>::Ok();\n    }));\n
    if (var_op_names.count(op_node->op().op_name()) > 0\n        && !IsDutyRank(op_node->parallel_desc(), current_rank_)) {\n      // To make sure all ranks know all var_op_names, at least one task for each variable op\n      // is needed in the plan.\n      JUST(CreateOrFindRankCompTaskNodeByParallelId(op_node, /*parallel_id=*/0));\n    }\n    return Maybe<void>::Ok();\n  }));\n\n
  JUST(op_graph->MaybeForEachEdge([&](const OpEdge* op_edge) -> Maybe<void> {\n    return DoRankDuty(op_edge->src_node()->parallel_desc(),\n                      [&](int64_t rank) { return ConnectDataEdges(op_edge, rank); });\n  }));\n\n
  ForEachOpGraphNecessaryCtrlEdge(op_graph, [&](const OpNode* src, const OpNode* dst) {\n    if (!src->parallel_desc_sym()->EqualsIgnoringHierarchy(*dst->parallel_desc_sym())) {\n      LOG(INFO) << \" src \" << src->parallel_desc_sym()->data().DebugString() << \" dst \"\n                << dst->parallel_desc_sym()->data().DebugString();\n      return;\n    }\n
    CHECK_JUST(DoRankDuty(src->parallel_desc(),\n                          [&](int64_t rank) { return ConnectCtrlEdges(src, dst, rank); }));\n  });\n\n
  if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) { ToDotWithAutoFilePath(); }\n\n
  ForEachNode([&](TaskNode* task_node) { task_node->ProduceAllRegstsAndBindEdges(); });\n  ForEachEdge([&](TaskEdge* edge) {\n    CHECK(edge->HasRegst()) << \"Found an edge that has no regst bound, src task \"\n                            << edge->src_node()->VisualStr();\n  });\n  return Maybe<void>::Ok();\n}\n\nRankTaskGraph::~RankTaskGraph() {}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/task_graph.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_\n#define ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_\n\n#include \"oneflow/core/graph/task_node.h\"\n#include \"oneflow/core/job/id_manager.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/copy_task_node.h\"\n#include \"oneflow/core/register/op_blob_arg_info.h\"\n#include \"oneflow/core/graph/boxing/boxing_logger.h\"\n#include \"oneflow/core/memory/memory_zone.h\"\n\nnamespace oneflow {\n\nclass SubTskGphBuilderCtx;\nclass HierarchicalSubTskGphBuilder;\n\n#define BLD_SUB_TSK_GPH_MTHD_ARGS()                                                \\\n  (const OpEdge* op_edge, const std::vector<CompTaskNode*>& sorted_src_comp_tasks, \\\n   const std::vector<CompTaskNode*>& sorted_dst_comp_tasks)\n\nclass TaskGraph;\nusing BldSubTskGphMthd = void(TaskGraph::*) BLD_SUB_TSK_GPH_MTHD_ARGS();\n\nclass TaskGraph : public Graph<TaskNode, TaskEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TaskGraph);\n  virtual ~TaskGraph() override;\n\n  const char* TypeName() const override { return \"TaskGraph\"; }\n  void RemoveEmptyRegsts();\n  void MergeChainAndAddOrderingCtrlEdgeInSameChain();\n  void DecideExecutionOrder();\n\n  void EnableInplaceMemSharing(const std::function<bool(const std::string&, const std::string&)>&\n                                   IsOpNameDataOrCtrlReachable);\n\n  void EnableInplaceMemSharing(const HashSet<TaskNode*>& dev_nodes,\n                               const std::function<bool(const std::string&, const std::string&)>&\n                                   IsOpNameDataOrCtrlReachable);\n\n  TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,\n                         const MemZoneId& dst_mem_zone_id);\n\n  TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,\n                         const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id);\n\n  TaskEdge* NewTaskEdgeWithLbi(const LogicalBlobId& lbi);\n  TaskEdge* NewTaskEdgeWithLbis(const std::vector<LogicalBlobId>& lbis);\n\n  void ConnectWithLbi(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi);\n\n#define DECLARE_BLD_SUB_TASK_GRAPH_METHOD(method_name) void method_name BLD_SUB_TSK_GPH_MTHD_ARGS();\n\n  DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing);\n  DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne);\n  DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBroadcastToBroadcast);\n  DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialInLbiConnect);\n  DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialOutLbiConnect);\n  DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphBySrcSubsetConnect);\n  DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByDstSubsetConnect);\n  DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D);\n\n  void 
ForEachGpuDeviceNodes(\n      const std::function<void(const HashSet<TaskNode*>& dev_nodes)>& Handler) const;\n\n protected:\n  explicit TaskGraph();\n\n  void BuildTaskPath(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi,\n                     bool is_host_mem_input);\n\n  void ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_nodes,\n                        const std::vector<CompTaskNode*>& dst_task_nodes);\n\n  void ConnectCtrlEdge(CompTaskNode* src_task_node, CompTaskNode* dst_task_node);\n\n  void InitOrderedTaskNodes();\n  void MergeChainByPhysicalTaskGraph();\n  void MergeChainByLogicalChainId();\n  void BuildCtrlRegstDescInSameChain();\n\n  // inplace\n  void GetInplaceOpBlobArgList(\n      InplaceObasInfo* obas_info, const HashSet<TaskNode*>& dev_nodes,\n      const std::function<const TaskNode*(const std::string&)>& TaskNode4OpName) const;\n  void GetSafeInplaceOpBlobArgList(\n      InplaceObasInfo* safe_obas_info, const HashSet<TaskNode*>& dev_nodes,\n      const std::function<bool(const std::string&, const std::string&)>&\n          IsOpNameDataOrCtrlReachable) const;\n  void SetTaskRegstInplaceInfo(const InplaceObasInfo& obas_info,\n                               const HashSet<TaskNode*>& dev_nodes) const;\n  std::vector<TaskNode*> ordered_task_nodes_;\n  HashMap<DeviceType, std::unique_ptr<HierarchicalSubTskGphBuilder>>\n      device_type2sub_tsk_gph_builder_;\n  std::unique_ptr<HierarchicalSubTskGphBuilder> hierarchical_sub_tsk_gph_builder_;\n  std::unique_ptr<SubTskGphBuilderCtx> sub_tsk_gph_builder_ctx_;\n  std::unique_ptr<BoxingLogger> boxing_logger_;\n\n  struct ProxyKey {\n    TaskNode* src_node;\n    LogicalBlobId lbi;\n    MemZoneId dst_mem_zone_id;\n\n    ProxyKey(TaskNode* src, const LogicalBlobId& arg_lbi, const MemZoneId& arg_mem_zone_id)\n        : src_node(src), lbi(arg_lbi), dst_mem_zone_id(arg_mem_zone_id) {}\n\n    bool operator==(const ProxyKey& other) const {\n      return src_node == other.src_node && lbi == other.lbi\n             && dst_mem_zone_id == other.dst_mem_zone_id;\n    }\n\n    struct Hasher {\n      inline size_t operator()(const ProxyKey& key) const {\n        return Hash(key.src_node, key.lbi, key.dst_mem_zone_id.hash());\n      }\n    };\n  };\n\n  HashMap<ProxyKey, TaskNode*, ProxyKey::Hasher> proxy2node;\n};\n\nclass GlobalTaskGraph final : public TaskGraph {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(GlobalTaskGraph);\n  ~GlobalTaskGraph() = default;\n  static Maybe<GlobalTaskGraph> New() {\n    std::shared_ptr<GlobalTaskGraph> graph(new GlobalTaskGraph());\n    JUST(graph->Init());\n    return graph;\n  }\n\n private:\n  GlobalTaskGraph() = default;\n  Maybe<void> Init();\n};\n\nclass BoxingTaskGraphProto;\n\nclass BoxingTaskGraph final : public TaskGraph {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BoxingTaskGraph);\n  ~BoxingTaskGraph() = default;\n\n  static Maybe<BoxingTaskGraph> New(\n      const std::function<void(size_t, const std::function<void(size_t i)>&)>& ParallelRunLoop) {\n    std::shared_ptr<BoxingTaskGraph> graph(new BoxingTaskGraph());\n    JUST(graph->Init(ParallelRunLoop));\n    return graph;\n  }\n\n  void ToProto(const std::function<bool(TaskNode*)>& Pick, BoxingTaskGraphProto* proto) const;\n  static bool SelectTaskNodeByRank(TaskNode*, int64_t rank);\n\n private:\n  BoxingTaskGraph() = default;\n  Maybe<void> Init(\n      const std::function<void(size_t, const std::function<void(size_t i)>&)>& ParallelRunLoop);\n\n  void CreateOpNode2TaskIds(\n      const std::function<void(size_t, const 
std::function<void(size_t i)>&)>& ParallelRunLoop);\n\n  HashMap<const OpNode*, std::vector<CompTaskNode*>> boxing_related_op_node2sorted_comp_tasks_;\n  HashMap<const OpNode*, std::vector<TaskId>> boxing_unrelated_op_node2sorted_task_ids_;\n};\n\nclass TaskGraphRebuildCtx;\n\n
class RankTaskGraph final : public TaskGraph {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RankTaskGraph);\n  ~RankTaskGraph();\n\n
  static Maybe<RankTaskGraph> New(\n      const std::shared_ptr<BoxingTaskGraphProto>& boxing_task_graph_proto,\n      const HashSet<std::string>& var_op_names, int64_t current_rank) {\n    std::shared_ptr<RankTaskGraph> graph(new RankTaskGraph(boxing_task_graph_proto, current_rank));\n    JUST(graph->Init(var_op_names));\n    return graph;\n  }\n\n
  // Returns whether `rank` is this process's duty.\n  bool IsDutyRank(const ParallelDesc& parallel_desc, int64_t rank) const;\n\n private:\n  RankTaskGraph(const std::shared_ptr<BoxingTaskGraphProto>& boxing_task_graph_proto, int64_t rank);\n\n
  Maybe<void> Init(const HashSet<std::string>& var_op_names);\n  bool ContainRank(const OpNode* op_node, int64_t rank) const;\n  Maybe<void> AddBoxingRelatedCompTaskNodesFromProto();\n  Maybe<void> CreateAndPartiallyInitTransportTaskNodesFromProto();\n  Maybe<void> AddTransportTaskEdgesFromProto();\n  Maybe<void> InitTransportTaskNodesFromProto();\n  Maybe<void> InitRegstDescsConsumers();\n
  template<typename DoEachRankT>\n  Maybe<void> DoRankDuty(const ParallelDesc& parallel_desc, const DoEachRankT& DoWithRank);\n\n
  Maybe<CompTaskNode*> TryGetBoxingRelatedCompTaskNode(const OpNode* op_node, int64_t parallel_id);\n  Maybe<CompTaskNode*> CreateOrFindRankCompTaskNodeByParallelId(const OpNode* op_node,\n                                                                int64_t parallel_id);\n  Maybe<CompTaskNode*> CreateOrFindRankCompTaskNodeByRank(const OpNode* op_node, int64_t rank);\n  Maybe<CompTaskNode*> TryGetRankCompTaskNode(const OpNode* op_node, int64_t rank);\n\n
  Maybe<void> ConnectDataEdges(const OpEdge* op_edge, int64_t rank);\n  Maybe<void> ConnectCtrlEdges(const OpNode* src, const OpNode* dst, int64_t rank);\n\n
  std::shared_ptr<BoxingTaskGraphProto> boxing_task_graph_proto_;\n  HashMap<int64_t, const TaskProto*> task_id2task_proto_;\n  const int64_t current_rank_;\n  std::unique_ptr<TaskGraphRebuildCtx> task_graph_rebuild_ctx_;\n  HashMap<const OpNode*, CompTaskNode*> op_node2comp_task_node_;\n};\n\n
using CreateSubTskGphBuilderFn = std::function<std::unique_ptr<HierarchicalSubTskGphBuilder>()>;\n\nMaybe<void> RegisterCreateSubTskGphBuilderFn(DeviceType device_type,\n                                             const CreateSubTskGphBuilderFn& fn);\n\n
#define REGISTER_CREATE_SUB_TASK_GRAPH_BUILDER_FN(device_type, fn) \\\n  COMMAND(CHECK_JUST(RegisterCreateSubTskGphBuilderFn(device_type, fn)))\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_\n"
  },
  {
    "path": "oneflow/core/graph/task_graph_rebuild_ctx.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/graph/task_node.h\"\n#include \"oneflow/core/graph/task_graph_rebuild_ctx.h\"\n\nnamespace oneflow {\n\nMaybe<TaskNode*> TaskGraphRebuildCtx::TaskNode4Id(int64_t task_id) const {\n  auto* task_node = JUST(MapAt(id2task_node_, task_id));\n  CHECK_EQ_OR_RETURN(task_node->task_id(), task_id);  // NOLINT\n  return task_node;\n}\n\nMaybe<TaskEdge*> TaskGraphRebuildCtx::TaskEdge4Uid(int64_t task_edge_uid) const {\n  return JUST(MapAt(uid2task_edge_, task_edge_uid));\n}\n\nMaybe<RegstDesc> TaskGraphRebuildCtx::RegstDesc4Id(int64_t regst_desc_id) const {\n  return JUST(MapAt(id2regst_desc_, regst_desc_id));\n}\n\nMaybe<void> TaskGraphRebuildCtx::AddTaskNode(TaskNode* task_node) {\n  CHECK_OR_RETURN(id2task_node_.emplace(task_node->task_id(), task_node).second)\n      << \"redundant task id found. value: \" << task_node->task_id();\n  for (const auto& pair : task_node->produced_regsts()) { JUST(AddRegstDesc(pair.second)); }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> TaskGraphRebuildCtx::AddTaskEdge(TaskEdge* task_edge, int64_t task_edge_uid) {\n  CHECK_OR_RETURN(uid2task_edge_.emplace(task_edge_uid, task_edge).second)\n      << \"redundant task edge uid found. value: \" << task_edge_uid;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> TaskGraphRebuildCtx::AddRegstDesc(const std::shared_ptr<RegstDesc>& regst_desc) {\n  CHECK_OR_RETURN(id2regst_desc_.emplace(regst_desc->regst_desc_id(), regst_desc).second)\n      << \"redundant register descriptor id found. value: \" << regst_desc->regst_desc_id();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/task_graph_rebuild_ctx.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_TASK_GRAPH_REBUILD_CTX_H_\n#define ONEFLOW_CORE_GRAPH_TASK_GRAPH_REBUILD_CTX_H_\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/register/register_desc.h\"\n\nnamespace oneflow {\n\nclass TaskNode;\nclass TaskEdge;\n\nclass TaskGraphRebuildCtx {\n public:\n  TaskGraphRebuildCtx() = default;\n  ~TaskGraphRebuildCtx() = default;\n\n  Maybe<TaskNode*> TaskNode4Id(int64_t task_id) const;\n  Maybe<TaskEdge*> TaskEdge4Uid(int64_t task_edge_uid) const;\n  Maybe<RegstDesc> RegstDesc4Id(int64_t regst_desc_id) const;\n\n  Maybe<void> AddTaskNode(TaskNode* task_node);\n  Maybe<void> AddTaskEdge(TaskEdge* task_edge, int64_t task_edge_uid);\n  Maybe<void> AddRegstDesc(const std::shared_ptr<RegstDesc>& regst_desc);\n\n private:\n  HashMap<int64_t, TaskNode*> id2task_node_;\n  HashMap<int64_t, TaskEdge*> uid2task_edge_;\n  HashMap<int64_t, std::shared_ptr<RegstDesc>> id2regst_desc_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_TASK_GRAPH_REBUILD_CTX_H_\n"
  },
  {
    "path": "oneflow/core/graph/task_id.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/task_id.h\"\n#include <climits>\n\nnamespace oneflow {\n\n
// TaskId encoding (maybe extended to 128 bits in future)\n// |            rank            | device_type | device_index  |                           |\n// | ----------- 16 ----------- | ---- 5 ---- | ----- 7 ----- |                           |\n// |                        DeviceId                          | stream_index |            |\n// | ------------------------- 28 --------------------------- | ---- 15 ---- |            |\n// |                               StreamId                                  | task_index |\n// | -------------------------------- 43 ----------------------------------- | --- 21 --- |\n// |                                      TaskId                                          |\n// | ----------------------------------- 64 bit ----------------------------------------- |\n\nnamespace {\n\n
constexpr size_t kInt64Bits = sizeof(int64_t) * CHAR_BIT;\n\nconstexpr size_t kStreamIndexShift = TaskId::kTaskIndexBits;\nconstexpr size_t kDeviceIndexShift = kStreamIndexShift + StreamId::kStreamIndexBits;\nconstexpr size_t kDeviceTypeShift = kDeviceIndexShift + DeviceId::kDeviceIndexBits;\nconstexpr size_t kRankShift = kDeviceTypeShift + DeviceId::kDeviceTypeBits;\nstatic_assert(kInt64Bits == kRankShift + DeviceId::kRankBits, \"\");\n\n
constexpr int64_t kTaskIndexInt64Mask = (int64_t{1} << TaskId::kTaskIndexBits) - 1;\nconstexpr int64_t kStreamIndexInt64Mask = ((int64_t{1} << StreamId::kStreamIndexBits) - 1)\n                                          << kStreamIndexShift;\nconstexpr int64_t kDeviceIndexInt64Mask = ((int64_t{1} << DeviceId::kDeviceIndexBits) - 1)\n                                          << kDeviceIndexShift;\nconstexpr int64_t kDeviceTypeInt64Mask = ((int64_t{1} << DeviceId::kDeviceTypeBits) - 1)\n                                         << kDeviceTypeShift;\nconstexpr int64_t kRankInt64Mask = ((int64_t{1} << DeviceId::kRankBits) - 1) << kRankShift;\n\n}  // namespace\n\n
int64_t EncodeTaskIdToInt64(const TaskId& task_id) {\n  int64_t id = static_cast<int64_t>(task_id.task_index());\n  id |= static_cast<int64_t>(task_id.stream_id().stream_index()) << kStreamIndexShift;\n  id |= static_cast<int64_t>(task_id.stream_id().device_index()) << kDeviceIndexShift;\n  id |= static_cast<int64_t>(task_id.stream_id().device_type()) << kDeviceTypeShift;\n  id |= static_cast<int64_t>(task_id.stream_id().rank()) << kRankShift;\n  return id;\n}\n\n
TaskId DecodeTaskIdFromInt64(int64_t task_id_val) {\n  int64_t rank = (task_id_val & kRankInt64Mask) >> kRankShift;\n  int64_t device_type = (task_id_val & kDeviceTypeInt64Mask) >> kDeviceTypeShift;\n  int64_t device_index = (task_id_val & kDeviceIndexInt64Mask) >> kDeviceIndexShift;\n  int64_t stream_index = (task_id_val & kStreamIndexInt64Mask) >> kStreamIndexShift;\n  int64_t task_index = task_id_val & kTaskIndexInt64Mask;\n
  StreamId stream_id{static_cast<DeviceId::rank_t>(rank), static_cast<DeviceType>(device_type),\n                     static_cast<DeviceId::device_index_t>(device_index),\n                     static_cast<StreamId::stream_index_t>(stream_index)};\n  return TaskId{stream_id, static_cast<TaskId::task_index_t>(task_index)};\n}\n\n
int64_t MachineId4ActorId(int64_t actor_id) {\n  return DecodeTaskIdFromInt64(actor_id).stream_id().rank();\n}\n\nint64_t ThrdId4ActorId(int64_t actor_id) {\n  return EncodeStreamIdToInt64(DecodeTaskIdFromInt64(actor_id).stream_id());\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/task_id.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_TASK_ID_H_\n#define ONEFLOW_CORE_GRAPH_TASK_ID_H_\n\n#include \"oneflow/core/graph/stream_id.h\"\n\nnamespace oneflow {\n\nclass TaskId {\n public:\n  using task_index_t = uint32_t;\n\n  const static size_t kTaskIndexBits = 21;\n  constexpr static task_index_t kMaxTaskIndex =\n      (task_index_t{1} << kTaskIndexBits) - task_index_t{1};\n\n  TaskId(const StreamId& stream_id, task_index_t task_index)\n      : stream_id_(stream_id), task_index_(task_index) {\n    CHECK_LE(task_index_, kMaxTaskIndex);\n  }\n\n  const StreamId& stream_id() const { return stream_id_; }\n  task_index_t task_index() const { return task_index_; }\n\n  bool operator==(const TaskId& rhs) const {\n    return stream_id_ == rhs.stream_id_ && task_index_ == rhs.task_index_;\n  }\n  bool operator!=(const TaskId& rhs) const { return !(*this == rhs); }\n\n  size_t hash() const {\n    size_t hash = stream_id_.hash();\n    HashCombine(&hash, std::hash<task_index_t>{}(task_index_));\n    return hash;\n  }\n\n private:\n  StreamId stream_id_;\n  task_index_t task_index_;\n};\n\nint64_t EncodeTaskIdToInt64(const TaskId&);\nTaskId DecodeTaskIdFromInt64(int64_t);\n\nint64_t MachineId4ActorId(int64_t actor_id);\nint64_t ThrdId4ActorId(int64_t actor_id);\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::TaskId> {\n  size_t operator()(const oneflow::TaskId& task_id) const { return task_id.hash(); }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_GRAPH_TASK_ID_H_\n"
  },
  {
    "path": "oneflow/core/graph/task_id_generator.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/multi_client_session_context.h\"\n#include \"oneflow/core/graph/stream_id.h\"\n#include \"oneflow/core/graph/task_id.h\"\n#include \"oneflow/core/graph/task_id_generator.h\"\n\nnamespace oneflow {\n\nvoid TaskIdGenerator::GetTaskIndex(HashMap<int64_t, uint32_t>* task_index_state) {\n  for (const auto& pair : stream_id2task_index_counter_) {\n    const int64_t i64_stream_id = EncodeStreamIdToInt64(pair.first);\n    (*task_index_state)[i64_stream_id] = pair.second;\n  }\n}\n\nvoid TaskIdGenerator::TryUpdateTaskIndex(const HashMap<int64_t, uint32_t>& task_index_state) {\n  for (auto& pair : stream_id2task_index_counter_) {\n    const int64_t i64_stream_id = EncodeStreamIdToInt64(pair.first);\n    uint32_t initial_task_index = 0;\n    if (task_index_state.count(i64_stream_id) != 0) {\n      initial_task_index = task_index_state.at(i64_stream_id);\n    }\n    pair.second = std::max(pair.second, initial_task_index);\n  }\n\n  // try update the task_index_init_state\n  for (const auto& pair : task_index_state) {\n    const auto& key = pair.first;\n    const auto& val = pair.second;\n    if (task_index_init_state_.count(key) != 0) {\n      task_index_init_state_[key] = std::max(task_index_init_state_.at(key), val);\n    } else {\n      task_index_init_state_[key] = val;\n    }\n  }\n}\n\nTaskId TaskIdGenerator::Generate(const StreamId& stream_id) {\n  std::unique_lock<std::mutex> lock(mutex_);\n  if (stream_id2task_index_counter_.count(stream_id) == 0) {\n    uint32_t init_task_index = 0;\n    const int64_t i64_stream_id = EncodeStreamIdToInt64(stream_id);\n    if (task_index_init_state_.count(i64_stream_id) != 0) {\n      init_task_index = task_index_init_state_.at(i64_stream_id);\n    }\n    stream_id2task_index_counter_[stream_id] = init_task_index;\n  }\n  task_index_t task_index = stream_id2task_index_counter_[stream_id]++;\n  return TaskId{stream_id, task_index};\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/task_id_generator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_TASK_ID_GENERATOR_H_\n#define ONEFLOW_CORE_GRAPH_TASK_ID_GENERATOR_H_\n\n#include \"oneflow/core/graph/task_id.h\"\n#include \"oneflow/core/job/id_state.h\"\n\nnamespace oneflow {\n\nclass TaskIdGenerator final {\n public:\n  using task_index_t = TaskId::task_index_t;\n\n  TaskIdGenerator() = default;\n  OF_DISALLOW_COPY_AND_MOVE(TaskIdGenerator);\n  ~TaskIdGenerator() = default;\n\n  TaskId Generate(const StreamId& stream_id);\n\n  void GetTaskIndex(HashMap<int64_t, uint32_t>* task_index_state);\n  void TryUpdateTaskIndex(const HashMap<int64_t, uint32_t>& task_index_state);\n\n private:\n  std::mutex mutex_;\n  HashMap<StreamId, task_index_t> stream_id2task_index_counter_;\n  // The task_index_init_state is used to initialize the `stream_id2task_index_counter_` hashmap.\n  HashMap<int64_t, uint32_t> task_index_init_state_{};\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_TASK_ID_GENERATOR_H_\n"
  },
  {
    "path": "oneflow/core/graph/task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/task_node.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/job/id_manager.h\"\n#include \"oneflow/core/memory/memory_case_util.h\"\n#include \"oneflow/core/graph/task_graph_rebuild_ctx.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nvoid ForEachDataEdge(const std::unordered_set<TaskEdge*>& edges,\n                     const std::function<void(TaskEdge*)>& Handler) {\n  for (TaskEdge* edge : edges) {\n    const auto& regsts = edge->GetRegsts();\n    int32_t data_regst_size =\n        std::count_if(regsts.begin(), regsts.end(), [](const std::shared_ptr<RegstDesc>& regst) {\n          return regst->regst_desc_type().has_data_regst_desc();\n        });\n    if (data_regst_size == regsts.size()) {\n      Handler(edge);\n    } else {\n      CHECK_EQ(data_regst_size, 0);\n    }\n  }\n}\n\n}  // namespace\n\nTaskNode::TaskNode()\n    : machine_id_(-1), thrd_id_(-1), task_id_(-1), chain_id_(-1), order_in_chain_(-1) {}\n\nstd::shared_ptr<RegstDesc> TaskNode::GetProducedRegst(const std::string& name) {\n  auto produced_regsts_it = produced_regsts_.find(name);\n  if (produced_regsts_it == produced_regsts_.end()) {\n    return nullptr;\n  } else {\n    return produced_regsts_it->second;\n  }\n}\n\nconst std::list<std::shared_ptr<RegstDesc>>& TaskNode::GetConsumedRegst(const std::string& name) {\n  return consumed_regsts_.at(name);\n}\n\nstd::shared_ptr<RegstDesc> TaskNode::GetSoleConsumedRegst(const std::string& name) {\n  auto it = consumed_regsts_.find(name);\n  if (it == consumed_regsts_.end()) { return nullptr; }\n  const std::list<std::shared_ptr<RegstDesc>>& vec = it->second;\n  CHECK_EQ(vec.size(), 1);\n  return vec.front();\n}\n\nconst StreamId& TaskNode::stream_id() const {\n  CHECK(new_task_id_);\n  return new_task_id_->stream_id();\n}\n\nDeviceType TaskNode::device_type() const { return stream_id().device_id().device_type(); }\n\nvoid TaskNode::set_machine_id(int64_t val) {\n  CHECK_EQ(machine_id_, -1);\n  machine_id_ = val;\n  if (thrd_id_ != -1) { UpdateTaskId(); }\n}\n\nvoid TaskNode::set_thrd_id(int64_t val) {\n  CHECK_EQ(thrd_id_, -1);\n  thrd_id_ = val;\n  CHECK_GE(thrd_id_, 0);\n  if (machine_id_ != -1) { UpdateTaskId(); }\n}\n\nvoid TaskNode::set_chain_id(int64_t val) {\n  CHECK(!IsValidChainId(chain_id_));\n  chain_id_ = val;\n}\n\nvoid TaskNode::set_order_in_chain(int64_t val) {\n  CHECK_EQ(order_in_chain_, -1);\n  order_in_chain_ = val;\n}\n\nvoid TaskNode::PinConsumedRegst() {\n  for (auto& pair : consumed_regsts_) {\n    for (const std::shared_ptr<RegstDesc>& regst : pair.second) {\n      PinConsumedRegstMemCase(regst->mut_mem_case());\n    }\n  }\n}\n\nvoid TaskNode::NaiveInferProducedDataRegstTimeShape() {\n  if (IsMeaningLess()) { return; }\n  std::shared_ptr<Shape> time_shape;\n  ForEachConsumedDataRegst([&time_shape](const std::string& name, const RegstDesc* regst) {\n    if (time_shape) {\n      
CHECK_EQ(*time_shape.get(), *regst->data_regst_time_shape().get());\n    } else {\n      time_shape = regst->data_regst_time_shape();\n    }\n  });\n\n  CHECK(time_shape);\n\n  ForEachProducedDataRegst([time_shape](const std::string& name, RegstDesc* regst) {\n    *regst->mut_data_regst_time_shape() = time_shape;\n  });\n}\n\nvoid TaskNode::InferTimeShapeIfMeaningful() {\n  if (!IsMeaningLess()) { InferProducedDataRegstTimeShape(); }\n}\n\nstd::shared_ptr<Shape> TaskNode::GetFastestInputOutputTimeShape() const {\n  std::shared_ptr<Shape> shape;\n  auto UpdateRetShape = [&](TaskEdge* edge) {\n    for (const auto& regst : edge->GetRegsts()) {\n      if (!shape || shape->elem_cnt() < regst->data_regst_time_shape()->elem_cnt()) {\n        shape = regst->data_regst_time_shape();\n      }\n    }\n  };\n  ForEachOutDataEdge(UpdateRetShape);\n  if (shape) { return shape; }\n  ForEachInDataEdge(UpdateRetShape);\n  return shape;\n}\n\nvoid TaskNode::ForEachConsumedDataRegst(\n    const std::function<void(const std::string&, const RegstDesc*)>& Handler) const {\n  for (const auto& pair : consumed_regsts_) {\n    for (const auto& regst : pair.second) {\n      if (!regst->regst_desc_type().has_data_regst_desc()) { continue; }\n      Handler(pair.first, regst.get());\n    }\n  }\n}\n\nvoid TaskNode::ForEachProducedDataRegst(\n    const std::function<void(const std::string&, RegstDesc*)>& Handler) {\n  for (auto& pair : produced_regsts_) {\n    if (!pair.second->regst_desc_type().has_data_regst_desc()) { continue; }\n    Handler(pair.first, pair.second.get());\n  }\n}\n\nvoid TaskNode::Build() { BuildExecGphAndRegst(); }\n\nvoid TaskNode::EraseUninitializedShapeProducedBlob() {\n  for (auto& pair : produced_regsts_) { pair.second->EraseUninitializedShapeBlob(); }\n}\n\nvoid TaskNode::EraseZeroSizeConsumedRegst() {\n  for (auto& pair : consumed_regsts_) {\n    for (auto it = pair.second.begin(); it != pair.second.end();) {\n      auto regst_ptr = *it;\n      CHECK(regst_ptr);\n      if (regst_ptr->regst_desc_type().has_data_regst_desc() && regst_ptr->NumOfLbi() == 0) {\n        it = pair.second.erase(it);\n      } else {\n        ++it;\n      }\n    }\n  }\n  EraseIf<std::string, std::list<std::shared_ptr<RegstDesc>>>(\n      &consumed_regsts_,\n      [](HashMap<std::string, std::list<std::shared_ptr<RegstDesc>>>::iterator it) {\n        return it->second.empty();\n      });\n}\n\nvoid TaskNode::EraseZeroSizeProducedRegst() {\n  EraseIf<std::string, std::shared_ptr<RegstDesc>>(\n      &produced_regsts_, [](HashMap<std::string, std::shared_ptr<RegstDesc>>::iterator it) {\n        return it->second->regst_desc_type().has_data_regst_desc() && it->second->NumOfLbi() == 0;\n      });\n}\n\nvoid TaskNode::UnbindBnWithEmptyRegst() {\n  exec_gph_.ForEachNode([&](ExecNode* exec_node) { exec_node->UnbindBnWithEmptyRegst(); });\n}\n\nstd::string TaskNode::VisualStr() const {\n  std::stringstream ss;\n  ss << TaskType_Name(GetTaskType()) << \"\\\\n\"\n     << machine_id_ << \":\" << thrd_id_ << \"\\\\n\"\n     << task_id_;\n  return ss.str();\n}\n\nbool TaskNode::IsMeaningLess() { return produced_regsts_.empty() && consumed_regsts_.empty(); }\n\nvoid TaskNode::InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto) {\n  // Step1: init some scalar items.\n  CHECK(task_proto.task_type() == GetTaskType());\n  machine_id_ = task_proto.machine_id();\n  thrd_id_ = task_proto.thrd_id();\n  task_id_ = task_proto.task_id();\n  new_task_id_.reset(new TaskId(DecodeTaskIdFromInt64(task_id_)));\n  CHECK(task_proto.job_id() 
== GlobalJobDesc().job_id());\n  chain_id_ = task_proto.chain_id();\n  order_in_chain_ = task_proto.order_in_chain();\n  // Step2: check exec_gph empty.\n  CHECK(task_proto.exec_sequence().exec_node().empty());\n  // Step3: init produced_regst.\n  for (const auto& pair : task_proto.produced_regst_desc()) {\n    const auto& regst_desc = ProduceRegst(pair.first, pair.second.enable_reuse_mem());\n    // regst_desc->consumers_ will be initialized by RegstDesc::InitConsumersFromProto.\n    regst_desc->InitFromProtoExceptConsumers(pair.second);\n  }\n}\n\nMaybe<void> TaskNode::InitConsumedRegstsFromProto(\n    const TaskProto& task_proto,\n    const std::function<Maybe<RegstDesc>(int64_t regst_desc_id)>& RegstDesc4Id) {\n  // init consumed_regst.\n  for (const auto& pair : task_proto.consumed_regst_desc_id()) {\n    for (int64_t regst_desc_id : pair.second.regst_desc_id()) {\n      ConsumeRegst(pair.first, JUST(RegstDesc4Id(regst_desc_id)));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nvoid TaskNode::ToProto(TaskProto* task_proto, bool check) const {\n  // Step1: process some scalar items.\n  task_proto->set_task_type(GetTaskType());\n  task_proto->set_machine_id(machine_id_);\n  task_proto->set_thrd_id(thrd_id_);\n  task_proto->set_task_id(task_id_);\n  task_proto->set_job_id(GlobalJobDesc().job_id());\n  task_proto->set_chain_id(chain_id_);\n  task_proto->set_order_in_chain(order_in_chain_);\n\n  // Step2: process exec_gph.\n  exec_gph_.ToExecSequence(parallel_ctx(), task_proto->mutable_exec_sequence());\n\n  // Step3: process produced_regst.\n  auto* produced_regst_proto = task_proto->mutable_produced_regst_desc();\n  for (auto& pair : produced_regsts_) {\n    RegstDescProto regst_desc_proto;\n    pair.second->ToProto(&regst_desc_proto, check);\n    CHECK(produced_regst_proto->insert({pair.first, regst_desc_proto}).second);\n  }\n\n  // Step4: process consumed_regst.\n  auto* consumed_regst_proto = task_proto->mutable_consumed_regst_desc_id();\n  for (const auto& pair : consumed_regsts_) {\n    RegstDescIdSet regst_desc_ids;\n    for (const std::shared_ptr<RegstDesc>& regst : pair.second) {\n      regst_desc_ids.add_regst_desc_id(regst->regst_desc_id());\n    }\n    CHECK(consumed_regst_proto->insert({pair.first, regst_desc_ids}).second);\n  }\n}\n\nMemZoneId TaskNode::MemZoneId121() const {\n  StreamId stream_id = DecodeStreamIdFromInt64(thrd_id_);\n  return stream_id.device_id();\n}\n\nbool TaskNode::BuildCtrlRegstDescIfNeed(TaskNode* dst_node, std::string* name) {\n  if (IsMeaningLess() || dst_node->IsMeaningLess()) { return false; }\n  for (const TaskEdge* in_edge : dst_node->in_edges()) {\n    if (in_edge->src_node() == this) { return false; }\n  }\n  BuildCtrlRegstDesc(dst_node, name);\n  return true;\n}\n\nRegstDesc* TaskNode::BuildCtrlRegstDesc(TaskNode* dst_node) {\n  std::string name;\n  return BuildCtrlRegstDesc(dst_node, &name);\n}\n\nRegstDesc* TaskNode::BuildCtrlRegstDesc(TaskNode* dst_node, std::string* name) {\n  RegstDescTypeProto regst_desc_type;\n  regst_desc_type.mutable_ctrl_regst_desc();\n  auto regst = NewProducedRegst(false, 1, kMaxRegisterNum, regst_desc_type);\n  *name = \"out_ctrl_\" + std::to_string(regst->regst_desc_id());\n  CHECK(produced_regsts_.emplace(*name, regst).second);\n  dst_node->ConsumeRegst(\"in_ctrl\", regst);\n  return regst.get();\n}\n\nvoid TaskNode::BindEdgeWithProducedRegst(TaskEdge* edge, const std::string& name) {\n  if (edge->HasRegst(name)) { return; }\n  edge->AddRegst(name, GetProducedRegst(name));\n}\n\nstd::shared_ptr<RegstDesc> 
TaskNode::GetAndCheckRegst(const std::string& name,\n                                                      bool enable_reuse_mem,\n                                                      int32_t min_register_num,\n                                                      int32_t max_register_num) const {\n
  auto iter = produced_regsts_.find(name);\n  if (iter == produced_regsts_.end()) { return nullptr; }\n  const auto& regst = (iter->second);\n  CHECK_EQ(regst->min_register_num(), min_register_num);\n  CHECK_EQ(regst->max_register_num(), max_register_num);\n  CHECK_EQ(regst->enable_reuse_mem(), enable_reuse_mem);\n  return regst;\n}\n\n
std::shared_ptr<RegstDesc> TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem) {\n  return ProduceRegst(name, enable_reuse_mem, 1, kMaxRegisterNum);\n}\n\n
std::shared_ptr<RegstDesc> TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem,\n                                                  int32_t min_register_num,\n                                                  int32_t max_register_num) {\n
  // During separate compilation regsts are not created in order, so the requested regst may\n  // already have been built. Returning the existing regst makes repeated ProduceRegst calls\n  // for the same name safe.\n
  const auto& regst = GetAndCheckRegst(name, enable_reuse_mem, min_register_num, max_register_num);\n  if (regst) { return regst; }\n  RegstDescTypeProto regst_desc_type;\n  regst_desc_type.mutable_data_regst_desc();\n  return ProduceRegst(name, enable_reuse_mem, min_register_num, max_register_num, regst_desc_type);\n}\n\n
std::shared_ptr<RegstDesc> TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem,\n                                                  int32_t min_register_num,\n                                                  int32_t max_register_num,\n                                                  const RegstDescTypeProto& regst_desc_type) {\n
  auto regst =\n      NewProducedRegst(enable_reuse_mem, min_register_num, max_register_num, regst_desc_type);\n  CHECK(produced_regsts_.emplace(name, regst).second);\n  return regst;\n}\n\n
std::shared_ptr<RegstDesc> TaskNode::NewProducedRegst(bool enable_reuse_mem,\n                                                      int32_t min_register_num,\n                                                      int32_t max_register_num,\n                                                      const RegstDescTypeProto& regst_desc_type) {\n
  auto regst = std::make_shared<RegstDesc>();\n  regst->set_producer(this);\n  *(regst->mut_regst_desc_type()) = regst_desc_type;\n  regst->UpdtMinRegstNumIfNeed(min_register_num);\n  regst->UpdtMaxRegstNumIfNeed(max_register_num);\n  regst->set_enable_reuse_mem(GlobalJobDesc().enable_reuse_mem() && enable_reuse_mem);\n  InitProducedRegstMemCase(regst.get());\n  return regst;\n}\n\n
void TaskNode::InitProducedRegstMemCase(RegstDesc* regst) {\n  InitProducedRegstMemCase(regst->mut_mem_case());\n}\n\n
void TaskNode::InitProducedRegstMemCase(MemoryCase* mem_case) {\n  mem_case->set_device_type(device_type());\n  mem_case->set_device_id(stream_id().device_id().device_index());\n}\n\n
void TaskNode::PinConsumedRegstMemCase(MemoryCase* mem_case) {\n  // When a node located on a non-CPU device consumes a CPU regst,\n  // the regst memory should be pinned in host memory (page-locked memory).\n  // When the regst is not in host memory, skip pinning.\n  if (!memory::IsHostMem(*mem_case)) { return; }\n  // When the node itself is located on the host, skip pinning.\n  if (device_type() == DeviceType::kCPU) { return; }\n
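  // For example, a regst produced in host memory but consumed by a CUDA kernel gets pinned\n  // (page-locked) so that the device's copy engine can access the buffer directly.\n 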
 mem_case->set_pinned_device_type(device_type());\n  mem_case->set_pinned_device_id(stream_id().device_id().device_index());\n}\n\nvoid TaskNode::ConsumeRegst(const std::string& name) {\n  consumed_regsts_.emplace(name, std::list<std::shared_ptr<RegstDesc>>{});\n}\n\nvoid TaskNode::ConsumeRegst(const std::string& name, const std::shared_ptr<RegstDesc>& regst) {\n  regst->AddConsumer(this);\n  consumed_regsts_[name].emplace_back(regst);\n}\n\nvoid TaskNode::UpdateTaskId() {\n  CHECK_NE(machine_id_, -1);\n  CHECK_NE(thrd_id_, -1);\n  StreamId stream_id = DecodeStreamIdFromInt64(thrd_id_);\n  new_task_id_.reset(\n      new TaskId(Singleton<IDMgr>::Get()->GetTaskIdGenerator()->Generate(stream_id)));\n  task_id_ = EncodeTaskIdToInt64(*new_task_id_);\n}\n\nvoid TaskNode::update_new_task_id(const TaskId& task_id) {\n  CHECK(static_cast<bool>(new_task_id_));\n  CHECK(new_task_id_->stream_id() == task_id.stream_id());\n  *new_task_id_ = task_id;\n  task_id_ = EncodeTaskIdToInt64(*new_task_id_);\n}\n\nvoid TaskNode::EraseConsumedRegstsByName(const std::string& name) {\n  if (consumed_regsts_.find(name) != consumed_regsts_.end()) {\n    for (auto& regst : consumed_regsts_[name]) { regst->DeleteConsumer(this); }\n    CHECK_EQ(consumed_regsts_.erase(name), 1);\n  }\n}\n\nstd::shared_ptr<RegstDesc> TaskEdge::GetRegst(const std::string& name_in_producer) const {\n  return name_in_producer2regst_.at(name_in_producer);\n}\n\nbool TaskEdge::HasRegst(const std::string& name_in_producer) const {\n  return (name_in_producer2regst_.find(name_in_producer) != name_in_producer2regst_.end());\n}\n\nstd::shared_ptr<RegstDesc> TaskEdge::GetSoleRegst() const {\n  CHECK_EQ(name_in_producer2regst_.size(), 1)\n      << \"edge: \" << this << \", src: \" << src_node()->task_id()\n      << \", dst: \" << dst_node()->task_id();\n  return name_in_producer2regst_.begin()->second;\n}\n\nstd::vector<std::shared_ptr<RegstDesc>> TaskEdge::GetRegsts() const {\n  std::vector<std::shared_ptr<RegstDesc>> regst_descs;\n  regst_descs.reserve(name_in_producer2regst_.size());\n  for (auto& pair : name_in_producer2regst_) { regst_descs.emplace_back(pair.second); }\n  return regst_descs;\n}\n\nvoid TaskEdge::AddRegst(const std::string& name_in_producer,\n                        const std::shared_ptr<RegstDesc>& regst) {\n  if (HasRegst(name_in_producer)) {\n    CHECK(CHECK_JUST(MapAt(name_in_producer2regst_, name_in_producer))->regst_desc_id()\n          == regst->regst_desc_id());\n    return;\n  }\n  CHECK(name_in_producer2regst_.emplace(name_in_producer, regst).second);\n}\n\nvoid TaskEdge::CheckRegstLbiValid() const {\n  HashMap<LogicalBlobId, std::shared_ptr<RegstDesc>> lbi2data_regst;\n  for (auto& pair : name_in_producer2regst_) {\n    std::shared_ptr<RegstDesc> regst = pair.second;\n    if (regst->regst_desc_type().has_data_regst_desc()) {\n      // NOTE(chengcheng): regst_desc_type is Set, BUT regst_desc_type.data_regst_desc is UNSET!\n      //  So you can ONLY use NumOfLbi and ForEachLbi interface.\n      CHECK_EQ(regst->NumOfLbi(), 1);\n      regst->ForEachLbi(\n          [&](const LogicalBlobId& lbi) { CHECK(lbi2data_regst.emplace(lbi, regst).second); });\n    }\n  }\n\n  CHECK_EQ(lbi2data_regst.size(), lbis_.size())\n      << \" \\n\\n TaskEdge lbi and regst NOT match.\"\n      << \" TaskEdge: edge_id = \" << edge_id() << \" From: [\" << src_node()->VisualStr() << \"] To: [\"\n      << dst_node()->VisualStr() << \"]\\n\";\n  for (auto& lbi : lbis_) {\n    CHECK(lbi2data_regst.find(lbi) != lbi2data_regst.end())\n        << \" 
\\n\\n Cannot find lbi: \" << lbi.DebugString() << \" in TaskEdge From: [\"\n        << src_node()->VisualStr() << \"] To: [\" << dst_node()->VisualStr() << \"]\\n\\n\";\n  }\n}\n\n
RegstDescProto* FindOrCreateProducedCtrlRegstDesc(TaskProto* task_proto,\n                                                  const std::string& regst_desc_name) {\n  auto* produced_regst_desc = task_proto->mutable_produced_regst_desc();\n  if (produced_regst_desc->find(regst_desc_name) == produced_regst_desc->end()) {\n    RegstDescProto ctrl_regst_desc;\n    InitCtrlRegstDesc(task_proto->task_id(), &ctrl_regst_desc);\n    CHECK(produced_regst_desc->insert({regst_desc_name, ctrl_regst_desc}).second);\n  }\n  return &produced_regst_desc->at(regst_desc_name);\n}\n\n
RegstDescIdSet* FindOrCreateConsumedCtrlRegstDescIdSet(TaskProto* task_proto,\n                                                       const std::string& regst_desc_name) {\n  auto* consumed_regst_desc_id_sets = task_proto->mutable_consumed_regst_desc_id();\n  if (consumed_regst_desc_id_sets->find(regst_desc_name) == consumed_regst_desc_id_sets->end()) {\n    CHECK(consumed_regst_desc_id_sets->insert({regst_desc_name, RegstDescIdSet()}).second);\n  }\n  return &consumed_regst_desc_id_sets->at(regst_desc_name);\n}\n\n
void TaskNode::ForEachInDataEdge(const std::function<void(TaskEdge*)>& Handler) const {\n  ForEachDataEdge(in_edges(), Handler);\n}\n\n
void TaskNode::ForEachOutDataEdge(const std::function<void(TaskEdge*)>& Handler) const {\n  ForEachDataEdge(out_edges(), Handler);\n}\n\n
void TaskNode::ForEachNodeOnInDataEdge(const std::function<void(TaskNode*)>& Handler) const {\n  ForEachInDataEdge([&](TaskEdge* in_edge) { Handler(in_edge->src_node()); });\n}\n\n
void TaskNode::ForEachNodeOnOutDataEdge(const std::function<void(TaskNode*)>& Handler) const {\n  ForEachOutDataEdge([&](TaskEdge* out_edge) { Handler(out_edge->dst_node()); });\n}\n\n
void TaskNode::ForEachNodeOnInOutDataEdge(const std::function<void(TaskNode*)>& Handler) const {\n  ForEachNodeOnInDataEdge(Handler);\n  ForEachNodeOnOutDataEdge(Handler);\n}\n\n
TaskEdge* TaskNode::GetSoleEdge(void (TaskNode::*ForEachEdge)(const std::function<void(TaskEdge*)>&)\n                                    const) const {\n  TaskEdge* ret = nullptr;\n  (this->*ForEachEdge)([&](TaskEdge* edge) {\n    CHECK(ret == nullptr);\n    ret = edge;\n  });\n  CHECK_NOTNULL(ret);\n  return ret;\n}\n\n
size_t TaskNode::GetEdgesSize(void (TaskNode::*ForEachEdge)(const std::function<void(TaskEdge*)>&)\n                                  const) const {\n  size_t size = 0;\n  (this->*ForEachEdge)([&](TaskEdge* edge) { ++size; });\n  return size;\n}\n\n
TaskEdge* TaskNode::SoleInDataEdge() const { return GetSoleEdge(&TaskNode::ForEachInDataEdge); }\n\nTaskEdge* TaskNode::SoleOutDataEdge() const { return GetSoleEdge(&TaskNode::ForEachOutDataEdge); }\n\nsize_t TaskNode::in_data_edges_size() const { return GetEdgesSize(&TaskNode::ForEachInDataEdge); }\n\nsize_t TaskNode::out_data_edges_size() const { return GetEdgesSize(&TaskNode::ForEachOutDataEdge); }\n\n
Maybe<void> TaskEdge::InitFromProto(const TaskEdgeProto& proto,\n                                    const TaskGraphRebuildCtx& task_graph_rebuild_ctx) {\n  CHECK_NE_OR_RETURN(proto.src_task_id(), proto.dst_task_id()) << \"self-loops are not supported\";\n  JUST(task_graph_rebuild_ctx.TaskNode4Id(proto.src_task_id()));\n  JUST(task_graph_rebuild_ctx.TaskNode4Id(proto.dst_task_id()));\n  // Note that edge id from proto is ignored.\n  lbis_.insert(proto.lbi().begin(), 
proto.lbi().end());\n  for (const auto& pair : proto.name_in_producer2regst_desc_id()) {\n    AddRegst(pair.first, JUST(task_graph_rebuild_ctx.RegstDesc4Id(pair.second)));\n  }\n  return Maybe<void>::Ok();\n}\n\nvoid TaskEdge::ToProto(TaskEdgeProto* proto) const {\n  // The in-memory pointer value is used as the task edge uid instead of edge_id();\n  // InitFromProto ignores this uid when the task graph is rebuilt.\n  proto->set_task_edge_uid(reinterpret_cast<int64_t>(this));\n  proto->set_src_task_id(src_node()->task_id());\n  proto->set_dst_task_id(dst_node()->task_id());\n  *proto->mutable_lbi() = {lbis_.begin(), lbis_.end()};\n  auto* map = proto->mutable_name_in_producer2regst_desc_id();\n  for (const auto& pair : name_in_producer2regst_) {\n    CHECK(map->insert({pair.first, pair.second->regst_desc_id()}).second);\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/task_node.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_TASK_NODE_H_\n#define ONEFLOW_CORE_GRAPH_TASK_NODE_H_\n\n#include \"oneflow/core/graph/exec_graph.h\"\n#include \"oneflow/core/job/task.pb.h\"\n#include \"oneflow/core/graph/task_edge.pb.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/auto_registration_factory.h\"\n#include \"oneflow/core/memory/memory_zone.h\"\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::TaskType> {\n  std::size_t operator()(const oneflow::TaskType& task_type) const {\n    return std::hash<uint32_t>{}(static_cast<uint32_t>(task_type));\n  }\n};\n\n}  // namespace std\n\nnamespace oneflow {\n\nRegstDescProto* FindOrCreateProducedCtrlRegstDesc(TaskProto* task_proto,\n                                                  const std::string& regst_desc_name);\nRegstDescIdSet* FindOrCreateConsumedCtrlRegstDescIdSet(TaskProto* task_proto,\n                                                       const std::string& regst_desc_name);\n\nbool inline IsValidChainId(int64_t val) { return val >= 0; }\n\nclass TaskEdge;\n\nclass TaskNode : public Node<TaskNode, TaskEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TaskNode);\n  TaskNode();\n  ~TaskNode() override = default;\n\n  // Getters\n  int64_t machine_id() const { return machine_id_; }\n  int64_t thrd_id() const { return thrd_id_; }\n  int64_t task_id() const { return task_id_; }\n  const StreamId& stream_id() const;\n  int64_t chain_id() const { return chain_id_; }\n  int64_t order_in_chain() const { return order_in_chain_; }\n  const ExecGraph& exec_gph() const { return exec_gph_; }\n  std::shared_ptr<RegstDesc> GetProducedRegst(const std::string& name);\n  const std::list<std::shared_ptr<RegstDesc>>& GetConsumedRegst(const std::string& name);\n  std::shared_ptr<RegstDesc> GetSoleConsumedRegst(const std::string& name);\n  const HashMap<std::string, std::shared_ptr<RegstDesc>>& produced_regsts() const {\n    return produced_regsts_;\n  }\n  const HashMap<std::string, std::list<std::shared_ptr<RegstDesc>>>& consumed_regsts() const {\n    return consumed_regsts_;\n  }\n  DeviceType device_type() const;\n  virtual const ParallelContext* parallel_ctx() const { return nullptr; }\n\n  // Different types of TaskNode/Compile Mode choose different output BlobDesc inference methods\n  virtual ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const = 0;\n\n  // Setters\n  void set_machine_id(int64_t val);\n  void set_thrd_id(int64_t val);\n  void set_chain_id(int64_t val);\n  void set_order_in_chain(int64_t val);\n\n  // Build\n  virtual void ProduceAllRegstsAndBindEdges() = 0;\n  virtual void ConsumeAllRegsts() = 0;\n  void PinConsumedRegst();\n  void InferTimeShapeIfMeaningful();\n  void ForEachProducedDataRegst(const std::function<void(const std::string&, RegstDesc*)>& Handler);\n  void ForEachConsumedDataRegst(\n      const std::function<void(const std::string&, const RegstDesc*)>& Handler) const;\n  void 
Build();\n\n  void EraseUninitializedShapeProducedBlob();\n  void EraseZeroSizeConsumedRegst();\n  void EraseZeroSizeProducedRegst();\n  void UnbindBnWithEmptyRegst();\n\n  // Others\n  virtual TaskType GetTaskType() const { return TaskType::kInvalid; }\n  std::string VisualStr() const override;\n  virtual bool IsMeaningLess();\n  void ToProto(TaskProto* task_proto) const { ToProto(task_proto, /*check*/ true); }\n  // Used to create task node from proto in plan separation compilation.\n  virtual void InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto);\n  Maybe<void> InitConsumedRegstsFromProto(\n      const TaskProto& task_proto,\n      const std::function<Maybe<RegstDesc>(int64_t regst_desc_id)>& RegstDesc4Id);\n  virtual void ToProto(TaskProto* task_proto, bool check) const;\n  void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name);\n  virtual MemZoneId MemZoneId121() const;\n  bool BuildCtrlRegstDescIfNeed(TaskNode* dst_node, std::string* name);\n  RegstDesc* BuildCtrlRegstDesc(TaskNode* dst_node);\n  RegstDesc* BuildCtrlRegstDesc(TaskNode* dst_node, std::string* name);\n  std::shared_ptr<Shape> GetFastestInputOutputTimeShape() const;\n\n  void ForEachInDataEdge(const std::function<void(TaskEdge*)>& Handler) const;\n  void ForEachOutDataEdge(const std::function<void(TaskEdge*)>& Handler) const;\n\n  void ForEachNodeOnInDataEdge(const std::function<void(TaskNode*)>& Handler) const;\n  void ForEachNodeOnOutDataEdge(const std::function<void(TaskNode*)>& Handler) const;\n  void ForEachNodeOnInOutDataEdge(const std::function<void(TaskNode*)>& Handler) const;\n\n  TaskEdge* SoleInDataEdge() const;\n  TaskEdge* SoleOutDataEdge() const;\n  size_t in_data_edges_size() const;\n  size_t out_data_edges_size() const;\n  const TaskId& new_task_id() const {\n    CHECK(has_new_task_id());\n    return *new_task_id_;\n  }\n  void update_new_task_id(const TaskId& task_id);\n  bool has_new_task_id() const { return static_cast<bool>(new_task_id_); }\n\n protected:\n  std::shared_ptr<RegstDesc> ProduceRegst(const std::string& name, bool enable_reuse_mem);\n  std::shared_ptr<RegstDesc> ProduceRegst(const std::string& name, bool enable_reuse_mem,\n                                          int32_t min_register_num, int32_t max_register_num);\n  std::shared_ptr<RegstDesc> ProduceRegst(const std::string& name, bool enable_reuse_mem,\n                                          int32_t min_register_num, int32_t max_register_num,\n                                          const RegstDescTypeProto&);\n  std::shared_ptr<RegstDesc> NewProducedRegst(bool enable_reuse_mem, int32_t min_register_num,\n                                              int32_t max_register_num, const RegstDescTypeProto&);\n  virtual void InitProducedRegstMemCase(RegstDesc* regst);\n  virtual void InitProducedRegstMemCase(MemoryCase*);\n  virtual void PinConsumedRegstMemCase(MemoryCase*);\n  void ConsumeRegst(const std::string& name);\n  void ConsumeRegst(const std::string& name, const std::shared_ptr<RegstDesc>&);\n  ExecGraph& mut_exec_gph() { return exec_gph_; }\n  void EraseConsumedRegstsByName(const std::string& name);\n\n  virtual void BuildExecGphAndRegst() = 0;\n\n  virtual void InferProducedDataRegstTimeShape() = 0;\n  void NaiveInferProducedDataRegstTimeShape();\n\n  TaskEdge* GetSoleEdge(void (TaskNode::*ForEachEdge)(const std::function<void(TaskEdge*)>&)\n                            const) const;\n  size_t GetEdgesSize(void (TaskNode::*ForEachEdge)(const std::function<void(TaskEdge*)>&)\n                         
 const) const;\n\n private:\n  void UpdateTaskId();\n  std::shared_ptr<RegstDesc> GetAndCheckRegst(const std::string& name, bool enable_reuse_mem,\n                                              int32_t min_register_num,\n                                              int32_t max_register_num) const;\n\n  int64_t machine_id_;\n  int64_t thrd_id_;\n  int64_t task_id_;\n  int64_t chain_id_;\n  int64_t order_in_chain_;\n  std::unique_ptr<TaskId> new_task_id_;\n\n  ExecGraph exec_gph_;\n  HashMap<std::string, std::shared_ptr<RegstDesc>> produced_regsts_;\n  HashMap<std::string, std::list<std::shared_ptr<RegstDesc>>> consumed_regsts_;\n};\n\nclass TaskGraphRebuildCtx;\n\nclass TaskEdge final : public Edge<TaskNode, TaskEdge> {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TaskEdge);\n  TaskEdge() = default;\n  ~TaskEdge() override = default;\n\n  std::shared_ptr<RegstDesc> GetRegst(const std::string& name_in_producer) const;\n  bool HasRegst(const std::string& name_in_producer) const;\n  std::shared_ptr<RegstDesc> GetSoleRegst() const;\n  std::vector<std::shared_ptr<RegstDesc>> GetRegsts() const;\n  const HashSet<LogicalBlobId>& GetLbis() const { return lbis_; }\n\n  void AddRegst(const std::string& name_in_producer, const std::shared_ptr<RegstDesc>& regst);\n  void AddLbi(const LogicalBlobId& lbi) { lbis_.insert(lbi); }\n  void AddLbis(const std::vector<LogicalBlobId>& lbis) { lbis_.insert(lbis.begin(), lbis.end()); }\n\n  void CheckRegstLbiValid() const;\n  bool HasRegst() const { return !name_in_producer2regst_.empty(); }\n\n  Maybe<void> InitFromProto(const TaskEdgeProto& proto,\n                            const TaskGraphRebuildCtx& task_graph_rebuild_ctx);\n  void ToProto(TaskEdgeProto* proto) const;\n\n private:\n  HashSet<LogicalBlobId> lbis_;\n  HashMap<std::string, std::shared_ptr<RegstDesc>> name_in_producer2regst_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_TASK_NODE_H_\n"
  },
  {
    "path": "oneflow/core/graph/task_stream_id.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_TASK_STREAM_ID_H_\n#define ONEFLOW_CORE_GRAPH_TASK_STREAM_ID_H_\n\n#include \"oneflow/core/graph/stream_id.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\ninline StreamId GenerateComputeTaskStreamId(const DeviceId& device_id) {\n  auto stream_index =\n      Singleton<TaskStreamIndexManager>::Get()->GetComputeTaskStreamIndex(device_id);\n  return StreamId{device_id, stream_index};\n}\n\ninline StreamId GenerateComputeTaskStreamId(int64_t rank, DeviceType device_type,\n                                            int64_t device_index) {\n  DeviceId device_id{static_cast<DeviceId::rank_t>(rank), device_type,\n                     static_cast<DeviceId::device_index_t>(device_index)};\n  return GenerateComputeTaskStreamId(device_id);\n}\n\ninline StreamId GenerateNamedTaskStreamId(const DeviceId& device_id, const std::string& name) {\n  auto stream_index =\n      Singleton<TaskStreamIndexManager>::Get()->GetNamedTaskStreamIndex(device_id, name);\n  return StreamId{device_id, stream_index};\n}\n\ninline StreamId GenerateNamedTaskStreamId(int64_t rank, DeviceType device_type,\n                                          int64_t device_index, const std::string& name) {\n  DeviceId device_id{static_cast<DeviceId::rank_t>(rank), device_type,\n                     static_cast<DeviceId::device_index_t>(device_index)};\n  return GenerateNamedTaskStreamId(device_id, name);\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_TASK_STREAM_ID_H_\n"
  },
  {
    "path": "oneflow/core/graph/task_stream_index_manager.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n#include \"oneflow/core/framework/multi_client_session_context.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/id_state.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n\nnamespace oneflow {\n\nStreamIndexGenerator* TaskStreamIndexManager::GetGenerator(const DeviceId& device_id) {\n  std::unique_lock<std::mutex> lck(mtx_);\n  auto iter = generators_.find(device_id);\n  if (iter == generators_.end()) {\n    uint32_t init_stream_index = 0;\n    const int64_t i64_device_id = EncodeDeviceIdToInt64(device_id);\n    if (stream_index_init_state_.count(i64_device_id) != 0) {\n      init_stream_index = stream_index_init_state_.at(i64_device_id);\n    }\n    iter = generators_.emplace(device_id, std::make_unique<StreamIndexGenerator>(init_stream_index))\n               .first;\n  }\n  return iter->second.get();\n}\n\nTaskStreamIndexManager::stream_index_t TaskStreamIndexManager::GetTaskStreamIndex(\n    TaskType task_type, const DeviceId& device_id) {\n  auto* generator = GetGenerator(device_id);\n  auto stream_index = CHECK_JUST(TaskStreamIndexGetterRegistry::Instance().Dispatch(\n      device_id.device_type(), task_type, generator));\n  return stream_index;\n}\n\nTaskStreamIndexManager::stream_index_t TaskStreamIndexManager::GetComputeTaskStreamIndex(\n    const DeviceId& device_id) {\n  auto* generator = GetGenerator(device_id);\n  return GenerateComputeTaskStreamIndex(device_id.device_type(), generator);\n}\n\nTaskStreamIndexManager::stream_index_t TaskStreamIndexManager::GetNamedTaskStreamIndex(\n    const DeviceId& device_id, const std::string& name) {\n  auto* generator = GetGenerator(device_id);\n  return generator->GenerateNamed(name);\n}\n\nvoid TaskStreamIndexManager::GetTaskStreamIndex(HashMap<int64_t, uint32_t>* stream_index_state) {\n  for (auto& pair : generators_) {\n    const int64_t i64_device_id = EncodeDeviceIdToInt64(pair.first);\n    (*stream_index_state)[i64_device_id] = pair.second->GetCurrStreamIndex();\n  }\n}\n\nvoid TaskStreamIndexManager::TryUpdateTaskStreamIndex(\n    const HashMap<int64_t, uint32_t>& stream_index_state) {\n  // Try Update generator's new_stream_index\n  for (auto& pair : generators_) {\n    const int64_t i64_device_id = EncodeDeviceIdToInt64(pair.first);\n    uint32_t initial_stream_index = 0;\n    if (stream_index_state.count(i64_device_id) != 0) {\n      initial_stream_index = stream_index_state.at(i64_device_id);\n    }\n    pair.second->TryUpdateNextStreamIndex(initial_stream_index);\n  }\n\n  // try update stream_index_init_state\n  for (const auto& pair : stream_index_state) {\n    const auto& key = pair.first;\n    const auto& val = pair.second;\n    if (stream_index_init_state_.count(key) != 0) {\n      stream_index_init_state_[key] = std::max(stream_index_init_state_.at(key), val);\n    } else {\n      stream_index_init_state_[key] = val;\n  
  }\n  }\n}\n\nvoid TaskStreamIndexGetterRegistry::Register(const key_t& key, const stream_index_getter& getter) {\n  bool insert_success = stream_index_getter_map_.emplace(key, getter).second;\n  if (!insert_success) {\n    std::cerr << \"DeviceType \" << key.first << \", TaskType \" << key.second\n              << \" was already registered\" << std::endl;\n    abort();\n  }\n}\n\nMaybe<StreamId::stream_index_t> TaskStreamIndexGetterRegistry::Dispatch(\n    DeviceType device_type, TaskType task_type, StreamIndexGenerator* generator) {\n  auto key = std::make_pair(device_type, task_type);\n  auto it = stream_index_getter_map_.find(key);\n  CHECK_OR_RETURN(it != stream_index_getter_map_.end())\n      << \"TaskType: \" << key.second << \", DeviceType: \" << key.first << \" has not been registered\";\n  return it->second(generator);\n}\n\nStreamId::stream_index_t GenerateComputeTaskStreamIndex(DeviceType device_type,\n                                                        StreamIndexGenerator* generator) {\n  // CPU compute tasks rotate round-robin across cpu_device_num stream indexes; other\n  // device types share the single stream index named \"COMPUTE\".\n  if (device_type == DeviceType::kCPU) {\n    size_t cpu_device_num = Singleton<ResourceDesc, ForSession>::Get()->CpuDeviceNum();\n    return generator->GenerateNamedRoundRobin(\"CPU_COMPUTE\", cpu_device_num);\n  } else {\n    return generator->GenerateNamed(\"COMPUTE\");\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/task_stream_index_manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_TASK_STREAM_INDEX_MANAGER_H_\n#define ONEFLOW_CORE_GRAPH_TASK_STREAM_INDEX_MANAGER_H_\n\n#include \"oneflow/core/job/task.pb.h\"\n#include \"oneflow/core/graph/stream_index_generator.h\"\n\nnamespace oneflow {\n\nclass TaskStreamIndexManager {\n public:\n  using stream_index_t = StreamId::stream_index_t;\n\n  OF_DISALLOW_COPY_AND_MOVE(TaskStreamIndexManager);\n  TaskStreamIndexManager() = default;\n  virtual ~TaskStreamIndexManager() = default;\n\n  StreamIndexGenerator* GetGenerator(const DeviceId& device_id);\n  stream_index_t GetTaskStreamIndex(TaskType task_type, const DeviceId& device_id);\n  stream_index_t GetComputeTaskStreamIndex(const DeviceId& device_id);\n  stream_index_t GetNamedTaskStreamIndex(const DeviceId& device_id, const std::string& name);\n  void GetTaskStreamIndex(HashMap<int64_t, uint32_t>* stream_index_state);\n  void TryUpdateTaskStreamIndex(const HashMap<int64_t, uint32_t>& stream_index_state);\n\n private:\n  HashMap<DeviceId, std::unique_ptr<StreamIndexGenerator>> generators_;\n  // The stream_index_init_state is used to initialize the generator.\n  HashMap<int64_t, uint32_t> stream_index_init_state_{};\n  std::mutex mtx_;\n};\n\nclass TaskStreamIndexGetterRegistry final {\n public:\n  using key_t = std::pair<DeviceType, TaskType>;\n  using stream_index_getter = std::function<StreamId::stream_index_t(StreamIndexGenerator*)>;\n  using map_t = HashMap<key_t, stream_index_getter>;\n\n  struct GetterRegister {\n    GetterRegister(DeviceType device_type, TaskType task_type, const stream_index_getter& getter) {\n      TaskStreamIndexGetterRegistry::Instance().Register(std::make_pair(device_type, task_type),\n                                                         getter);\n    }\n  };\n\n  static TaskStreamIndexGetterRegistry& Instance() {\n    static TaskStreamIndexGetterRegistry registry;\n    return registry;\n  }\n\n  OF_DISALLOW_COPY_AND_MOVE(TaskStreamIndexGetterRegistry);\n  ~TaskStreamIndexGetterRegistry() = default;\n\n  void Register(const key_t& key, const stream_index_getter& getter);\n  Maybe<StreamId::stream_index_t> Dispatch(DeviceType device_type, TaskType task_type,\n                                           StreamIndexGenerator* generator);\n\n private:\n  TaskStreamIndexGetterRegistry() = default;\n  map_t stream_index_getter_map_;\n};\n\nStreamId::stream_index_t GenerateComputeTaskStreamIndex(DeviceType device_type,\n                                                        StreamIndexGenerator* generator);\n\n}  // namespace oneflow\n\n#define REGISTER_TASK_STREAM_INDEX_GETTER(device_type, task_type, getter) \\\n  static auto OF_PP_CAT(g_stream_index_getter_register_, __COUNTER__) =   \\\n      ::oneflow::TaskStreamIndexGetterRegistry::GetterRegister(device_type, task_type, getter)\n\n#define REGISTER_NAMED_TASK_STREAM_INDEX_GETTER(device_type, task_type, name)                    \\\n  
REGISTER_TASK_STREAM_INDEX_GETTER(                                                             \\\n      device_type, task_type, ([](StreamIndexGenerator* generator) -> StreamId::stream_index_t { \\\n        return generator->GenerateNamed(name);                                                   \\\n      }));\n\n#define REGISTER_INDEPENDENT_TASK_STREAM_INDEX_GETTER(task_type)         \\\n  REGISTER_TASK_STREAM_INDEX_GETTER(                                     \\\n      DeviceType::kCPU, task_type,                                       \\\n      ([](StreamIndexGenerator* generator) -> StreamId::stream_index_t { \\\n        return generator->GenerateAnonymous();                           \\\n      }));\n\n#define REGISTER_TICK_TASK_STREAM_INDEX_GETTER(task_type)                \\\n  REGISTER_TASK_STREAM_INDEX_GETTER(                                     \\\n      DeviceType::kCPU, task_type,                                       \\\n      ([](StreamIndexGenerator* generator) -> StreamId::stream_index_t { \\\n        return generator->GenerateNamed(\"TICK\");                         \\\n      }));\n\n#define REGISTER_DEVICE_COMP_TASK_STREAM_INDEX_GETTER(device_type, task_type)                    \\\n  REGISTER_TASK_STREAM_INDEX_GETTER(                                                             \\\n      device_type, task_type, ([](StreamIndexGenerator* generator) -> StreamId::stream_index_t { \\\n        return GenerateComputeTaskStreamIndex(device_type, generator);                           \\\n      }));\n\n#define REGISTER_COMP_TASK_STREAM_INDEX_GETTER(task_type)                                          \\\n  OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_DEVICE_COMP_TASK_STREAM_INDEX_GETTER, DEVICE_TYPE_SEQ, \\\n                                   (task_type))\n\n#endif  // ONEFLOW_CORE_GRAPH_TASK_STREAM_INDEX_MANAGER_H_\n"
  },
  {
    "path": "oneflow/core/graph/task_type_visitor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/job/task.pb.h\"\n#include \"oneflow/core/graph/collective_boxing_task_node.h\"\n#include \"oneflow/core/graph/nccl_send_recv_boxing_task_node.h\"\n#include \"oneflow/core/graph/copy_task_node.h\"\n#include \"oneflow/core/graph/boxing_zeros_task_node.h\"\n#include \"oneflow/core/graph/slice_boxing_task_node.h\"\n#include \"oneflow/core/graph/collective_boxing_pack_task_node.h\"\n#include \"oneflow/core/graph/collective_boxing_unpack_task_node.h\"\n#include \"oneflow/core/graph/boxing_identity_task_node.h\"\n\nnamespace oneflow {\n\ntemplate<typename DerivedT>\nstruct TransportTaskTypeVisitor {\n  template<typename... Args>\n  static auto Visit(TaskType task_type, Args&&... args) {\n    switch (task_type) {\n      case TaskType::kInvalid: LOG(FATAL) << \"invalid task type\";\n      case TaskType::kCopyHd: return DerivedT::VisitCopyHd(std::forward<Args>(args)...);\n      case TaskType::kCopyCommNet: return DerivedT::VisitCopyCommNet(std::forward<Args>(args)...);\n      case TaskType::kSliceBoxing: return DerivedT::VisitSliceBoxing(std::forward<Args>(args)...);\n      case TaskType::kCollectiveBoxingGeneric:\n        return DerivedT::VisitCollectiveBoxingGeneric(std::forward<Args>(args)...);\n      case TaskType::kBoxingIdentity:\n        return DerivedT::VisitBoxingIdentity(std::forward<Args>(args)...);\n      case TaskType::kNcclSendRecvBoxing:\n        return DerivedT::VisitNcclSendRecvBoxing(std::forward<Args>(args)...);\n      case TaskType::kBoxingZeros: return DerivedT::VisitBoxingZeros(std::forward<Args>(args)...);\n      case TaskType::kCollectiveBoxingPack:\n        return DerivedT::VisitCollectiveBoxingPack(std::forward<Args>(args)...);\n      case TaskType::kCollectiveBoxingUnpack:\n        return DerivedT::VisitCollectiveBoxingUnpack(std::forward<Args>(args)...);\n      default: LOG(FATAL) << \"invalid task type\";\n    }\n  }\n};\n\nstruct CreateTransportTask final : public TransportTaskTypeVisitor<CreateTransportTask> {\n  static Maybe<TransportTaskNode*> VisitCopyHd() { return new CopyHdTaskNode(); }\n  static Maybe<TransportTaskNode*> VisitCopyCommNet() { return new CopyCommNetTaskNode(); }\n  static Maybe<TransportTaskNode*> VisitSliceBoxing() { return new SliceBoxingTaskNode(); }\n  static Maybe<TransportTaskNode*> VisitCollectiveBoxingGeneric() {\n    return new CollectiveBoxingGenericTaskNode();\n  }\n  static Maybe<TransportTaskNode*> VisitBoxingIdentity() { return new BoxingIdentityTaskNode(); }\n  static Maybe<TransportTaskNode*> VisitCollectiveBoxingPack() {\n    return new CollectiveBoxingPackTaskNode();\n  }\n  static Maybe<TransportTaskNode*> VisitCollectiveBoxingUnpack() {\n    return new CollectiveBoxingUnpackTaskNode();\n  }\n  static Maybe<TransportTaskNode*> VisitBoxingZeros() { return new BoxingZerosTaskNode(); }\n  static Maybe<TransportTaskNode*> VisitNcclSendRecvBoxing() {\n    
return new NcclSendRecvBoxingTaskNode();\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_TASK_TYPE_VISITOR_H_\n"
  },
  {
    "path": "oneflow/core/graph/transport_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/transport_task_node.h\"\n#include \"oneflow/core/graph/boxing_task_graph.pb.h\"\n\nnamespace oneflow {\n\nMaybe<void> TransportTaskNode::InitTransportTaskFromProtoIf(\n    const TransportTaskProto& transport_task_proto, const TaskGraphRebuildCtx& ctx) {\n  CHECK(has_new_task_id());\n  JUST(InitTransportTaskFromProto(transport_task_proto, ctx));\n  lbi_ = transport_task_proto.lbi();\n  return Maybe<void>::Ok();\n}\n\nvoid TransportTaskNode::ToTransportTaskProtoIf(TransportTaskProto* transport_task_proto) const {\n  ToTransportTaskProto(transport_task_proto);\n  *transport_task_proto->mutable_lbi() = lbi_;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph/transport_task_node.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_GRAPH_TRANSPORT_TASK_NODE_H_\n#define ONEFLOW_CORE_GRAPH_TRANSPORT_TASK_NODE_H_\n\n#include \"oneflow/core/graph/task_node.h\"\n#include \"oneflow/core/register/logical_blob_id.pb.h\"\n\nnamespace oneflow {\n\nclass TransportTaskProto;\nclass TaskGraphRebuildCtx;\n\nclass TransportTaskNode : public TaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TransportTaskNode);\n  TransportTaskNode() = default;\n  virtual ~TransportTaskNode() = default;\n\n  void set_lbi(const LogicalBlobId& lbi) { lbi_ = lbi; }\n  LogicalBlobId lbi() const { return lbi_; }\n\n  Maybe<void> InitTransportTaskFromProtoIf(const TransportTaskProto& transport_task_proto,\n                                           const TaskGraphRebuildCtx& ctx);\n  void ToTransportTaskProtoIf(TransportTaskProto*) const;\n\n  ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const override {\n    // TransportTaskNode infers output BlobDesc based on input BlobDesc, because it can't infers\n    // output BlobDesc with SBP.\n    return &ExecNode::InferBlobDescsByInputs;\n  }\n\n private:\n  virtual Maybe<void> InitTransportTaskFromProto(const TransportTaskProto&,\n                                                 const TaskGraphRebuildCtx& ctx) = 0;\n\n  virtual void ToTransportTaskProto(TransportTaskProto*) const = 0;\n  LogicalBlobId lbi_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_GRAPH_TRANSPORT_TASK_NODE_H_\n"
  },
  {
    "path": "oneflow/core/graph_impl/acc_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass AccCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AccCompTaskNode);\n  AccCompTaskNode() = default;\n  ~AccCompTaskNode() = default;\n  TaskType GetTaskType() const override { return TaskType::kAcc; }\n  void BuildExecGphAndRegst() override;\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n};\n\nvoid AccCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  std::shared_ptr<RegstDesc> regst = ProduceRegst(\"out\", false);\n  ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst(\"out\", regst); });\n}\n\nvoid AccCompTaskNode::ConsumeAllRegsts() { ConsumeRegst(\"in\", SoleInDataEdge()->GetSoleRegst()); }\nvoid AccCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid AccCompTaskNode::BuildExecGphAndRegst() {\n  std::shared_ptr<RegstDesc> in_regst = GetSoleConsumedRegst(\"in\");\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  ExecNode* exec_node = mut_exec_gph().NewNode();\n  exec_node->mut_op() = op();\n  exec_node->BindBnWithRegst(op()->SoleIbn(), in_regst);\n  out_regst->AddLbi(op()->BnInOp2Lbi(op()->SoleObn()));\n  exec_node->BindBnWithRegst(op()->SoleObn(), out_regst);\n  (exec_node->*GetInferBlobDescsMethod())(parallel_ctx());\n  out_regst->ForEachLbi([out_regst](const LogicalBlobId& lbi) {\n    const BlobDesc* blob_desc = out_regst->GetBlobDesc(lbi);\n    CHECK_EQ(blob_desc->is_dynamic(), false);\n  });\n}\n\nREGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kAcc);\n\nREGISTER_USER_OP_COMP_TASK_NODE_TYPE(\"acc\", AccCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass AccCtrlTickCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AccCtrlTickCompTaskNode);\n  AccCtrlTickCompTaskNode() = default;\n  ~AccCtrlTickCompTaskNode() = default;\n  TaskType GetTaskType() const override { return TaskType::kAccCtrlTick; }\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void BuildExecGphAndRegst() override;\n  void ConsumeFakeRegsts() override;\n};\n\nvoid AccCtrlTickCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  std::shared_ptr<RegstDesc> regst = ProduceRegst(\"out\", false);\n  ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst(\"out\", regst); });\n}\n\nvoid AccCtrlTickCompTaskNode::ConsumeAllRegsts() {\n  ConsumeRegst(\"in\", SoleInDataEdge()->GetSoleRegst());\n}\n\nvoid AccCtrlTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid AccCtrlTickCompTaskNode::BuildExecGphAndRegst() {\n  std::shared_ptr<RegstDesc> in_regst = GetSoleConsumedRegst(\"in\");\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  std::shared_ptr<const Operator> op = this->op();\n  ExecNode* exec_node = mut_exec_gph().NewNode();\n  exec_node->mut_op() = op;\n  exec_node->BindBnWithRegst(op->SoleIbn(), in_regst);\n  out_regst->AddLbi(op->BnInOp2Lbi(op->SoleObn()));\n  exec_node->BindBnWithRegst(op->SoleObn(), out_regst);\n  (exec_node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nREGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kAccCtrlTick);\n\nREGISTER_USER_OP_COMP_TASK_NODE_TYPE(\"acc_ctrl_tick\", AccCtrlTickCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/acc_tick_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass AccTickCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AccTickCompTaskNode);\n  AccTickCompTaskNode() = default;\n  ~AccTickCompTaskNode() = default;\n  TaskType GetTaskType() const override { return TaskType::kAccTick; }\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n  void BuildExecGphAndRegst() override;\n};\n\nvoid AccTickCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  std::shared_ptr<RegstDesc> regst = ProduceRegst(\"out\", false);\n  ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst(\"out\", regst); });\n}\n\nvoid AccTickCompTaskNode::ConsumeAllRegsts() {\n  ConsumeRegst(\"in\", SoleInDataEdge()->GetSoleRegst());\n}\nvoid AccTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid AccTickCompTaskNode::BuildExecGphAndRegst() {\n  std::shared_ptr<RegstDesc> in_regst = GetSoleConsumedRegst(\"in\");\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  std::shared_ptr<const Operator> op = this->op();\n  ExecNode* exec_node = mut_exec_gph().NewNode();\n  exec_node->mut_op() = op;\n  exec_node->BindBnWithRegst(op->SoleIbn(), in_regst);\n  out_regst->AddLbi(op->BnInOp2Lbi(op->SoleObn()));\n  exec_node->BindBnWithRegst(op->SoleObn(), out_regst);\n  (exec_node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nREGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kAccTick);\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kAccTickConf, AccTickCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/callback_notify_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass CallbackNotifyCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CallbackNotifyCompTaskNode);\n  CallbackNotifyCompTaskNode() = default;\n  ~CallbackNotifyCompTaskNode() = default;\n\n  TaskType GetTaskType() const override { return TaskType::kCallbackNotify; }\n\n private:\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n  void BuildExecGphAndRegst() override;\n};\n\nvoid CallbackNotifyCompTaskNode::ProduceAllRegstsAndBindEdges() {}\n\nvoid CallbackNotifyCompTaskNode::ConsumeAllRegsts() {\n  ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst(\"in\", edge->GetSoleRegst()); });\n}\n\nvoid CallbackNotifyCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid CallbackNotifyCompTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  node->mut_op() = this->op();\n  for (const std::string& ibn : node->op()->input_bns()) {\n    node->BindBnWithOneOfTheRegsts(ibn, GetConsumedRegst(\"in\"));\n  }\n  CHECK(node->op()->tmp_bns().empty());\n  CHECK(node->op()->output_bns().empty());\n}\n\nREGISTER_NAMED_TASK_STREAM_INDEX_GETTER(DeviceType::kCPU, TaskType::kCallbackNotify,\n                                        \"CALLBACK_NOTIFY\");\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kCallbackNotifyConf,\n                                       CallbackNotifyCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/case_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass CaseCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CaseCompTaskNode);\n  CaseCompTaskNode() = default;\n  ~CaseCompTaskNode() override = default;\n\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n\n  TaskType GetTaskType() const override { return TaskType::kCase; }\n\n private:\n  void BuildExecGphAndRegst() override;\n  void InferProducedDataRegstTimeShape() override;\n};\n\nvoid CaseCompTaskNode::ConsumeAllRegsts() { ConsumeRegst(\"in\", SoleInDataEdge()->GetSoleRegst()); }\n\nvoid CaseCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid CaseCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  HashMap<LogicalBlobId, int64_t> lbi2obn_id;\n  FOR_RANGE(int64_t, obn_id, 0, op()->output_bns().size()) {\n    CHECK(lbi2obn_id.emplace(op()->BnInOp2Lbi(GenRepeatedBn(\"out\", obn_id)), obn_id).second);\n  }\n  ForEachOutDataEdge([&](TaskEdge* edge) {\n    const OpNode* succ = GetOneSuccOpNodeOnEdge(edge);\n    int64_t obn_id = -1;\n    for (const std::string& ibn : succ->shared_op()->input_bns()) {\n      const LogicalBlobId& lbi = succ->shared_op()->BnInOp2Lbi(ibn);\n      if (lbi2obn_id.find(lbi) != lbi2obn_id.cend()) {\n        CHECK_EQ(obn_id, -1);\n        obn_id = lbi2obn_id.at(lbi);\n      }\n    }\n    CHECK_NE(obn_id, -1);\n    std::string name = \"out_\" + std::to_string(obn_id);\n    CHECK(GetProducedRegst(name) == nullptr);\n    edge->AddRegst(\"out\", ProduceRegst(name, false));\n  });\n}\n\nvoid CaseCompTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  std::shared_ptr<const Operator> sole_op = op();\n  node->mut_op() = sole_op;\n  node->BindBnWithRegst(\"in\", GetSoleConsumedRegst(\"in\"));\n  FOR_RANGE(int64_t, obn_id, 0, sole_op->output_bns().size()) {\n    std::string name = \"out_\" + std::to_string(obn_id);\n    std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(name);\n    out_regst->AddLbi(sole_op->BnInOp2Lbi(name));\n    node->BindBnWithRegst(name, out_regst);\n  }\n  (node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nvoid CaseCompTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); }\n\nREGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kCase);\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kCaseConf, CaseCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass CriticalSectionWaitTickCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CriticalSectionWaitTickCompTaskNode);\n  CriticalSectionWaitTickCompTaskNode() = default;\n  ~CriticalSectionWaitTickCompTaskNode() = default;\n\n  bool IsMeaningLess() override { return false; }\n  TaskType GetTaskType() const override { return TaskType::kCriticalSectionWaitTick; }\n\n private:\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n  void BuildExecGphAndRegst() override;\n};\n\nvoid CriticalSectionWaitTickCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  ProduceRegst(\"out\", false, 128, 128);\n  ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, \"out\"); });\n}\n\nvoid CriticalSectionWaitTickCompTaskNode::ConsumeAllRegsts() {\n  ConsumeRegst(\"in\");\n  ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst(\"in\", edge->GetSoleRegst()); });\n}\n\nvoid CriticalSectionWaitTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid CriticalSectionWaitTickCompTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  node->mut_op() = op();\n  const std::list<std::shared_ptr<RegstDesc>>& in_regsts = GetConsumedRegst(\"in\");\n  for (const std::string& ibn : node->op()->input_bns()) {\n    node->BindBnWithOneOfTheRegsts(ibn, in_regsts);\n  }\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  for (const std::string& obn : node->op()->output_bns()) {\n    const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn);\n    out_regst->AddLbi(lbi);\n    node->BindBnWithRegst(obn, out_regst);\n  }\n  (node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nREGISTER_INDEPENDENT_TASK_STREAM_INDEX_GETTER(TaskType::kCriticalSectionWaitTick);\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kCriticalSectionWaitTickConf,\n                                       CriticalSectionWaitTickCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/decode_h2d_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/normal_forward_compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass DecodeH2DCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DecodeH2DCompTaskNode);\n  DecodeH2DCompTaskNode() = default;\n  ~DecodeH2DCompTaskNode() override = default;\n\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n\n  TaskType GetTaskType() const override { return TaskType::kDecodeH2D; }\n\n private:\n  void BuildExecGphAndRegst() override;\n};\n\nvoid DecodeH2DCompTaskNode::ConsumeAllRegsts() {\n  ConsumeRegst(\"in\", SoleInDataEdge()->GetSoleRegst());\n}\n\nvoid DecodeH2DCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid DecodeH2DCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  auto regst_num = ParseIntegerFromEnv(\"ONEFLOW_DECODE_H2D_REGST_NUM\", 2);\n  std::shared_ptr<RegstDesc> out_regst = ProduceRegst(\"out\", false, regst_num, regst_num);\n  ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst(\"out\", out_regst); });\n  ProduceRegst(\"tmp\", false);\n}\n\nvoid DecodeH2DCompTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  std::shared_ptr<const Operator> sole_op = op();\n  node->mut_op() = sole_op;\n  node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst(\"in\"));\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));\n  node->BindBnWithRegst(sole_op->SoleObn(), out_regst);\n  node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst(\"tmp\"));\n  (node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nREGISTER_NAMED_TASK_STREAM_INDEX_GETTER(DeviceType::kCUDA, TaskType::kDecodeH2D, \"DECODE_H2D\")\n\nnamespace {\n\nCompTaskNode* CreateCompTaskNodeByOpDeviceType(const OperatorConf& op_conf) {\n  if (CHECK_JUST(DeviceType4DeviceTag(op_conf.device_tag())) == DeviceType::kCUDA) {\n    return new DecodeH2DCompTaskNode;\n  } else {\n    return new NormalForwardCompTaskNode;\n  }\n}\n\n}  // namespace\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE_WITH_FUNC(OperatorConf::kImageDecoderRandomCropResizeConf,\n                                                 CreateCompTaskNodeByOpDeviceType);\nREGISTER_USER_OP_COMP_TASK_NODE_TYPE_WITH_FUNC(\"raw_reader\", CreateCompTaskNodeByOpDeviceType);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/device_tick_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass DeviceTickCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DeviceTickCompTaskNode);\n  DeviceTickCompTaskNode() = default;\n  ~DeviceTickCompTaskNode() = default;\n\n  bool IsMeaningLess() override { return false; }\n  TaskType GetTaskType() const override { return TaskType::kDeviceTick; }\n\n private:\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n  void BuildExecGphAndRegst() override;\n};\n\nvoid DeviceTickCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  ProduceRegst(\"out\", false, 1, 1);\n  ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, \"out\"); });\n}\n\nvoid DeviceTickCompTaskNode::ConsumeAllRegsts() {\n  ConsumeRegst(\"in\");\n  ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst(\"in\", edge->GetSoleRegst()); });\n}\n\nvoid DeviceTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid DeviceTickCompTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  node->mut_op() = op();\n  const std::list<std::shared_ptr<RegstDesc>>& in_regsts = GetConsumedRegst(\"in\");\n  for (const std::string& ibn : node->op()->input_bns()) {\n    node->BindBnWithOneOfTheRegsts(ibn, in_regsts);\n  }\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  for (const std::string& obn : node->op()->output_bns()) {\n    const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn);\n    out_regst->AddLbi(lbi);\n    node->BindBnWithRegst(obn, out_regst);\n  }\n  (node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nREGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kDeviceTick);\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kDeviceTickConf, DeviceTickCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/distribute_concat_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass DistributeConcatCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DistributeConcatCompTaskNode);\n  DistributeConcatCompTaskNode() = default;\n  ~DistributeConcatCompTaskNode() = default;\n\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n\n  TaskType GetTaskType() const override { return TaskType::kDistributeConcat; }\n\n private:\n  void BuildExecGphAndRegst() override;\n  void BuildExecGphStructAndBindInRegst();\n  void BuildOutRegst();\n};\n\nvoid DistributeConcatCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  ProduceRegst(\"out\", true);\n  ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, \"out\"); });\n}\n\nvoid DistributeConcatCompTaskNode::ConsumeAllRegsts() {\n  size_t cnt = 0;\n  ForEachInDataEdge([&](TaskEdge* edge) {\n    cnt += 1;\n    ConsumeRegst(\"in\", edge->GetSoleRegst());\n  });\n  CHECK_EQ(cnt, 1);\n}\n\nvoid DistributeConcatCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid DistributeConcatCompTaskNode::BuildExecGphAndRegst() {\n  BuildExecGphStructAndBindInRegst();\n  BuildOutRegst();\n  mut_exec_gph().TopoForEachNode(\n      [this](ExecNode* node) { (node->*GetInferBlobDescsMethod())(parallel_ctx()); });\n}\n\nvoid DistributeConcatCompTaskNode::BuildExecGphStructAndBindInRegst() {\n  ExecNode* cur_node = mut_exec_gph().NewNode();\n  cur_node->mut_op() = this->op();\n  auto in_regst = GetSoleConsumedRegst(\"in\");\n  mut_exec_gph().ForEachNode([&](ExecNode* cur_node) {\n    const auto& ibn = cur_node->op()->input_bns().Get(parallel_ctx()->parallel_id());\n    cur_node->BindBnWithRegst(ibn, in_regst);\n    CHECK(in_regst->HasLbi(cur_node->op()->BnInOp2Lbi(ibn)));\n  });\n}  // namespace oneflow\n\nvoid DistributeConcatCompTaskNode::BuildOutRegst() {\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  mut_exec_gph().ForEachNode([&](ExecNode* cur_node) {\n    HashSet<LogicalBlobId> found_lbis;\n    for (ExecEdge* out_edge : cur_node->out_edges()) { found_lbis.insert(out_edge->lbi()); }\n    for (const std::string& obn : cur_node->op()->output_bns()) {\n      out_regst->AddLbi(cur_node->op()->BnInOp2Lbi(obn));\n      cur_node->BindBnWithRegst(obn, out_regst);\n    }\n  });\n  // NOTE: we can ONLY set inplace when regst has ONLY ONE blob\n  auto in_regst = GetSoleConsumedRegst(\"in\");\n  if (in_regst->NumOfLbi() == 1) {\n    out_regst->set_hint_inplace_consumed_regst_desc_id(in_regst->regst_desc_id());\n  }\n}\n\nREGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kDistributeConcat);\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kDistributeConcatConf,\n                                       
DistributeConcatCompTaskNode);\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kDistributeAddConf,\n                                       DistributeConcatCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/distribute_split_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass DistributeSplitCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DistributeSplitCompTaskNode);\n  DistributeSplitCompTaskNode() = default;\n  ~DistributeSplitCompTaskNode() = default;\n\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n\n  TaskType GetTaskType() const override { return TaskType::kDistributeSplit; }\n\n private:\n  void BuildExecGphAndRegst() override;\n  void BuildExecGphStructAndBindInRegst();\n  void BuildOutRegst();\n};\n\nvoid DistributeSplitCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  ProduceRegst(\"out\", true);\n  ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, \"out\"); });\n}\n\nvoid DistributeSplitCompTaskNode::ConsumeAllRegsts() {\n  ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst(\"in\", edge->GetSoleRegst()); });\n}\n\nvoid DistributeSplitCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid DistributeSplitCompTaskNode::BuildExecGphAndRegst() {\n  BuildExecGphStructAndBindInRegst();\n  BuildOutRegst();\n  mut_exec_gph().TopoForEachNode(\n      [this](ExecNode* node) { (node->*GetInferBlobDescsMethod())(parallel_ctx()); });\n}\n\nvoid DistributeSplitCompTaskNode::BuildExecGphStructAndBindInRegst() {\n  ExecNode* cur_node = mut_exec_gph().NewNode();\n  cur_node->mut_op() = this->op();\n  for (const std::string& ibn : cur_node->op()->input_bns()) {\n    cur_node->BindBnWithRegst(ibn, GetSoleConsumedRegst(\"in\"));\n  }\n}\n\nvoid DistributeSplitCompTaskNode::BuildOutRegst() {\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  mut_exec_gph().ForEachNode([&](ExecNode* cur_node) {\n    const auto& obn = cur_node->op()->output_bns().Get(parallel_ctx()->parallel_id());\n    out_regst->AddLbi(cur_node->op()->BnInOp2Lbi(obn));\n    cur_node->BindBnWithRegst(obn, out_regst);\n  });\n  // NOTE: we can ONLY set inplace when regst has ONLY ONE blob\n  auto in_regst = GetSoleConsumedRegst(\"in\");\n  if (in_regst->NumOfLbi() == 1) {\n    out_regst->set_hint_inplace_consumed_regst_desc_id(in_regst->regst_desc_id());\n  }\n}\n\nREGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kDistributeSplit);\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kDistributeSplitConf,\n                                       DistributeSplitCompTaskNode);\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kDistributeCloneConf,\n                                       DistributeSplitCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/dst_subset_tick_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass DstSubsetTickCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DstSubsetTickCompTaskNode);\n  DstSubsetTickCompTaskNode() = default;\n  ~DstSubsetTickCompTaskNode() = default;\n\n  bool IsMeaningLess() override { return false; }\n  TaskType GetTaskType() const override { return TaskType::kDstSubsetTick; }\n\n private:\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n  void BuildExecGphAndRegst() override;\n};\n\nvoid DstSubsetTickCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  ProduceRegst(\"out\", false, 2, 2);\n  ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, \"out\"); });\n}\n\nvoid DstSubsetTickCompTaskNode::ConsumeAllRegsts() {\n  ConsumeRegst(\"in\");\n  ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst(\"in\", edge->GetSoleRegst()); });\n}\n\nvoid DstSubsetTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid DstSubsetTickCompTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  node->mut_op() = op();\n  const std::list<std::shared_ptr<RegstDesc>>& in_regsts = GetConsumedRegst(\"in\");\n  for (const std::string& ibn : node->op()->input_bns()) {\n    node->TryBindBnWithOneOfTheRegsts(ibn, in_regsts);\n  }\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  for (const std::string& obn : node->op()->output_bns()) {\n    const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn);\n    out_regst->AddLbi(lbi);\n    node->BindBnWithRegst(obn, out_regst);\n  }\n  (node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nREGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kDstSubsetTick);\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kDstSubsetTickConf, DstSubsetTickCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/esac_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass EsacCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(EsacCompTaskNode);\n  EsacCompTaskNode() = default;\n  ~EsacCompTaskNode() override = default;\n\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override { UNIMPLEMENTED() << \"EsacCompTaskNode is deprecated\"; }\n\n  TaskType GetTaskType() const override { return TaskType::kEsac; }\n\n private:\n  void BuildExecGphAndRegst() override;\n  void InferProducedDataRegstTimeShape() override;\n};\n\nvoid EsacCompTaskNode::ConsumeAllRegsts() {\n  HashMap<LogicalBlobId, int64_t> lbi2ibn_id;\n  FOR_RANGE(int64_t, ibn_id, 0, op()->input_bns().size()) {\n    CHECK(lbi2ibn_id.emplace(op()->BnInOp2Lbi(GenRepeatedBn(\"in\", ibn_id)), ibn_id).second);\n  }\n  ForEachInDataEdge([&](TaskEdge* edge) {\n    const OpNode* pred = GetOnePredOpNodeOnEdge(edge);\n    int64_t ibn_id = -1;\n    for (const std::string& obn : pred->shared_op()->output_bns()) {\n      const LogicalBlobId& lbi = pred->shared_op()->BnInOp2Lbi(obn);\n      if (lbi2ibn_id.find(lbi) != lbi2ibn_id.cend()) {\n        CHECK_EQ(ibn_id, -1);\n        ibn_id = lbi2ibn_id.at(lbi);\n      }\n    }\n    CHECK_NE(ibn_id, -1);\n    ConsumeRegst(\"in_\" + std::to_string(ibn_id), edge->GetSoleRegst());\n  });\n}\n\nvoid EsacCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  std::shared_ptr<RegstDesc> out_regst = ProduceRegst(\"out\", false, 1, 1);\n  ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst(\"out\", out_regst); });\n}\n\nvoid EsacCompTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  std::shared_ptr<const Operator> sole_op = this->op();\n  node->mut_op() = sole_op;\n  FOR_RANGE(int64_t, ibn_id, 0, sole_op->input_bns().size()) {\n    node->BindBnWithRegst(GenRepeatedBn(\"in\", ibn_id),\n                          GetSoleConsumedRegst(\"in_\" + std::to_string(ibn_id)));\n  }\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  out_regst->AddLbi(sole_op->BnInOp2Lbi(\"out\"));\n  node->BindBnWithRegst(\"out\", out_regst);\n  (node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nvoid EsacCompTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); }\n\nREGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kEsac);\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kEsacConf, EsacCompTaskNode);\n\n}  // namespace oneflow\n"
  },
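  {
    "path": "oneflow/core/graph_impl/example_compute_task_node.cpp",
    "content": "// NOTE: a minimal illustrative sketch, not part of the build. It distills the\n// CompTaskNode pattern shared by the task nodes in this directory: produce an\n// \"out\" regst and bind it to every out data edge, consume one \"in\" regst per\n// in data edge, then build the exec graph and bind blob names to regsts.\n// The class name ExampleCompTaskNode is hypothetical; a real node returns its\n// own TaskType value and registers itself, e.g.\n//   REGISTER_USER_OP_COMP_TASK_NODE_TYPE(\"repeat\", RepeatCompTaskNode);\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass ExampleCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ExampleCompTaskNode);\n  ExampleCompTaskNode() = default;\n  ~ExampleCompTaskNode() override = default;\n\n  // Placeholder only: a real task node defines and returns its own TaskType.\n  TaskType GetTaskType() const override { return TaskType::kNormalForward; }\n\n  void ProduceAllRegstsAndBindEdges() override {\n    // Produce a single \"out\" regst and bind it to every out data edge.\n    ProduceRegst(\"out\", true);\n    ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, \"out\"); });\n  }\n\n  void ConsumeAllRegsts() override {\n    // Consume the sole regst carried by each in data edge under the name \"in\".\n    ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst(\"in\", edge->GetSoleRegst()); });\n  }\n\n  void ConsumeFakeRegsts() override { ConsumeFakeRegst(\"in\"); }\n\n private:\n  void BuildExecGphAndRegst() override {\n    ExecNode* node = mut_exec_gph().NewNode();\n    node->mut_op() = op();\n    // Bind every input blob name to one of the consumed \"in\" regsts.\n    const std::list<std::shared_ptr<RegstDesc>>& in_regsts = GetConsumedRegst(\"in\");\n    for (const std::string& ibn : node->op()->input_bns()) {\n      node->BindBnWithOneOfTheRegsts(ibn, in_regsts);\n    }\n    // Add every output blob to the produced \"out\" regst and bind it.\n    std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n    for (const std::string& obn : node->op()->output_bns()) {\n      out_regst->AddLbi(node->op()->BnInOp2Lbi(obn));\n      node->BindBnWithRegst(obn, out_regst);\n    }\n    (node->*GetInferBlobDescsMethod())(parallel_ctx());\n  }\n};\n\n}  // namespace oneflow\n"
  },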
  {
    "path": "oneflow/core/graph_impl/normal_forward_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/normal_forward_compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nsize_t RegstNum4OpSameOutputBlob(OperatorConf::OpTypeCase op_type_case) {\n  if (IsClassRegistered<int32_t, RuntimeRegstNum4OpSameOutputBlob>(op_type_case)) {\n    std::unique_ptr<RuntimeRegstNum4OpSameOutputBlob> ptr;\n    ptr.reset(NewObj<int32_t, RuntimeRegstNum4OpSameOutputBlob>(op_type_case));\n    return *ptr;\n  } else {\n    return -1;\n  }\n}\n\nstd::string GetOutRegstNameByObn(const std::string& obn) { return \"__\" + obn; }\n\n}  // namespace\n\nvoid NormalForwardCompTaskNode::ProduceOutRegstByNameAndBlockNum(const std::string& name,\n                                                                 size_t mem_block_num) {\n  if (mem_block_num != -1) {\n    CHECK_GT(mem_block_num, 0);\n    ProduceRegst(name, false, mem_block_num, mem_block_num);\n  } else {\n    ProduceRegst(name, true);\n  }\n}\n\nsize_t RegstNum4Op(const Operator& sole_op) {\n  size_t mem_block_num = RegstNum4OpSameOutputBlob(sole_op.op_conf().op_type_case());\n  if (sole_op.op_conf().has_user_conf()) {\n    const std::string& op_type_name = sole_op.op_conf().user_conf().op_type_name();\n    const auto* op_reg_result = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type_name);\n    CHECK(op_reg_result != nullptr) << \"op_type_name \" << op_type_name << \" not registered\";\n    if (op_reg_result->same_output_regst_num > 0) {\n      mem_block_num = op_reg_result->same_output_regst_num;\n    }\n    if (IsClassRegistered<std::string, RuntimeRegstNum4OpSameOutputBlob>(op_type_name)) {\n      std::unique_ptr<RuntimeRegstNum4OpSameOutputBlob> ptr;\n      ptr.reset(NewObj<std::string, RuntimeRegstNum4OpSameOutputBlob>(op_type_name));\n      mem_block_num = *ptr;\n    }\n    if (op_type_name == \"identity_buffer\") {\n      mem_block_num = user_op::UserOpConfWrapper(sole_op.op_conf()).attr<int64_t>(\"buffer_size\");\n    }\n  }\n  return mem_block_num;\n}\n\nvoid NormalForwardCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  std::shared_ptr<const Operator> sole_op = op();\n  size_t mem_block_num = RegstNum4Op(*sole_op);\n  // When the op has more than one output blob and every task node on the out edges is a\n  // NormalForwardCompTaskNode, create one out regst per output blob name in the op.\n\n  HashMap<LogicalBlobId, std::string> lbi2out_regst_name;\n  for (const std::string& obn : sole_op->output_bns()) {\n    const LogicalBlobId& lbi = sole_op->BnInOp2Lbi(obn);\n    std::string out_regst_name = GetOutRegstNameByObn(obn);\n    lbi2out_regst_name.insert({lbi, out_regst_name});\n    ProduceOutRegstByNameAndBlockNum(out_regst_name, mem_block_num);\n  }\n  ForEachOutDataEdge([&](TaskEdge* edge) {\n    for (const LogicalBlobId& lbi : edge->GetLbis()) {\n      auto it = lbi2out_regst_name.find(lbi);\n      CHECK(it != lbi2out_regst_name.end());\n      BindEdgeWithProducedRegst(edge, it->second);\n    }\n  });\n  ProduceRegst(\"tmp\", true);\n}\n\nvoid NormalForwardCompTaskNode::ConsumeAllRegsts() {\n  ForEachInDataEdge([&](TaskEdge* edge) {\n    for (const auto& regst : edge->GetRegsts()) { ConsumeRegst(\"in\", regst); }\n  });\n}\n\nvoid NormalForwardCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid NormalForwardCompTaskNode::BuildExecGphAndRegst() {\n  BuildExecGphStructAndBindInRegst();\n  BuildOutRegst();\n  BuildTmp7BufRegsts();\n  mut_exec_gph().TopoForEachNode(\n      [this](ExecNode* node) { (node->*GetInferBlobDescsMethod())(parallel_ctx()); });\n}\n\nvoid NormalForwardCompTaskNode::BuildExecGphStructAndBindInRegst() {\n  ExecNode* cur_node = mut_exec_gph().NewNode();\n  cur_node->mut_op() = op();\n  const std::list<std::shared_ptr<RegstDesc>>& in_regsts = GetConsumedRegst(\"in\");\n  for (const std::string& ibn : cur_node->op()->input_bns()) {\n    cur_node->BindBnWithOneOfTheRegsts(ibn, in_regsts);\n  }\n}\n\nvoid NormalForwardCompTaskNode::BuildOutRegst() {\n  ExecNode* exec_node = mut_exec_gph().SoleNode();\n  for (const std::string& obn : exec_node->op()->output_bns()) {\n    std::string out_regst_name = GetOutRegstNameByObn(obn);\n    std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(out_regst_name);\n    out_regst->AddLbi(exec_node->op()->BnInOp2Lbi(obn));\n    exec_node->BindBnWithRegst(obn, out_regst);\n  }\n}\n\nvoid NormalForwardCompTaskNode::BuildTmp7BufRegsts() {\n  mut_exec_gph().ForEachNode([&](ExecNode* node) {\n    node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst(\"tmp\"));\n  });\n}\n\nREGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kNormalForward);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/pack_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass PackCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PackCompTaskNode);\n  PackCompTaskNode() = default;\n  ~PackCompTaskNode() override = default;\n\n  TaskType GetTaskType() const override { return TaskType::kPack; }\n\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n\n private:\n  void BuildExecGphAndRegst() override;\n};\n\nvoid PackCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  ProduceRegst(\"out\", false);\n  ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, \"out\"); });\n}\n\nvoid PackCompTaskNode::ConsumeAllRegsts() { ConsumeRegst(\"in\", SoleInDataEdge()->GetSoleRegst()); }\n\nvoid PackCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid PackCompTaskNode::BuildExecGphAndRegst() {\n  ExecNode* exec_node = mut_exec_gph().NewNode();\n  exec_node->mut_op() = op();\n  std::shared_ptr<RegstDesc> in_regst = GetSoleConsumedRegst(\"in\");\n  exec_node->BindBnWithRegst(op()->SoleIbn(), in_regst);\n\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  out_regst->AddLbi(op()->BnInOp2Lbi(op()->SoleObn()));\n  exec_node->BindBnWithRegst(op()->SoleObn(), out_regst);\n\n  (exec_node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nREGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kPack);\n\nREGISTER_USER_OP_COMP_TASK_NODE_TYPE(\"pack\", PackCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/reentrant_lock_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass ReentrantLockCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ReentrantLockCompTaskNode);\n  ReentrantLockCompTaskNode() = default;\n  ~ReentrantLockCompTaskNode() = default;\n\n  bool IsMeaningLess() override { return false; }\n  TaskType GetTaskType() const override { return TaskType::kReentrantLock; }\n\n private:\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n  void BuildExecGphAndRegst() override;\n  void InferProducedDataRegstTimeShape() override;\n};\n\nvoid ReentrantLockCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  ProduceRegst(\"out\", false, 1, 1);\n  ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, \"out\"); });\n}\n\nvoid ReentrantLockCompTaskNode::ConsumeAllRegsts() {\n  ConsumeRegst(\"in\");\n  ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst(\"in\", edge->GetSoleRegst()); });\n}\n\nvoid ReentrantLockCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid ReentrantLockCompTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  node->mut_op() = op();\n  const std::list<std::shared_ptr<RegstDesc>>& in_regsts = GetConsumedRegst(\"in\");\n  // no regst_desc for ibn \"end\" provided because TaskGraph hates cycle\n  node->BindBnWithOneOfTheRegsts(\"start\", in_regsts);\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  for (const std::string& obn : node->op()->output_bns()) {\n    const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn);\n    out_regst->AddLbi(lbi);\n    node->BindBnWithRegst(obn, out_regst);\n  }\n  (node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nvoid ReentrantLockCompTaskNode::InferProducedDataRegstTimeShape() {\n  std::shared_ptr<Shape> time_shape(new Shape());\n  for (TaskEdge* edge : in_edges()) {\n    if (edge->src_node()->GetFastestInputOutputTimeShape()) {\n      *time_shape = *edge->src_node()->GetFastestInputOutputTimeShape();\n    }\n  }\n  CHECK_GT(time_shape->elem_cnt(), 0);\n  ForEachProducedDataRegst([time_shape](const std::string& name, RegstDesc* regst) {\n    *regst->mut_data_regst_time_shape() = time_shape;\n  });\n}\n\nREGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kReentrantLock);\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kReentrantLockConf, ReentrantLockCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/repeat_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass RepeatCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RepeatCompTaskNode);\n  RepeatCompTaskNode() = default;\n  ~RepeatCompTaskNode() override = default;\n\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n\n  TaskType GetTaskType() const override { return TaskType::kRepeat; }\n\n private:\n  void BuildExecGphAndRegst() override;\n};\n\nvoid RepeatCompTaskNode::ConsumeAllRegsts() {\n  ConsumeRegst(\"in\", SoleInDataEdge()->GetSoleRegst());\n}\n\nvoid RepeatCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid RepeatCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  std::shared_ptr<RegstDesc> out_regst = ProduceRegst(\"out\", false, 1, 1);\n  ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst(\"out\", out_regst); });\n}\n\nvoid RepeatCompTaskNode::BuildExecGphAndRegst() {\n  std::shared_ptr<RegstDesc> in_regst = GetSoleConsumedRegst(\"in\");\n  ExecNode* node = mut_exec_gph().NewNode();\n  std::shared_ptr<const Operator> sole_op = op();\n  node->mut_op() = sole_op;\n  node->BindBnWithRegst(sole_op->SoleIbn(), in_regst);\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));\n  node->BindBnWithRegst(sole_op->SoleObn(), out_regst);\n  (node->*GetInferBlobDescsMethod())(parallel_ctx());\n\n  // NOTE(chengcheng): force inplace\n  CHECK_EQ(in_regst->NumOfLbi(), 1);\n  CHECK_EQ(out_regst->NumOfLbi(), 1);\n  CHECK_EQ(in_regst->min_register_num(), 1);\n  // NOTE(chengcheng): input need unreused mem\n  in_regst->set_enable_reuse_mem(false);\n  out_regst->set_force_inplace_consumed_regst_desc_id(in_regst->regst_desc_id());\n}\n\nREGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kRepeat);\n\nREGISTER_USER_OP_COMP_TASK_NODE_TYPE(\"repeat\", RepeatCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/source_tick_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass SourceTickCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SourceTickCompTaskNode);\n  SourceTickCompTaskNode() = default;\n  ~SourceTickCompTaskNode() = default;\n\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override {}\n  void ConsumeFakeRegsts() override {}\n  void BuildExecGphAndRegst() override;\n  bool IsMeaningLess() override { return false; }\n\n  TaskType GetTaskType() const override { return TaskType::kSourceTick; }\n};\n\nvoid SourceTickCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  std::shared_ptr<RegstDesc> out_regst = ProduceRegst(\"out\", false, 2, 2);\n  ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst(\"out\", out_regst); });\n}\n\nvoid SourceTickCompTaskNode::BuildExecGphAndRegst() {\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  ExecNode* node = mut_exec_gph().NewNode();\n  node->mut_op() = op();\n  for (const std::string& obn : node->op()->output_bns()) {\n    const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn);\n    out_regst->AddLbi(lbi);\n    node->BindBnWithRegst(obn, out_regst);\n  }\n  (node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nREGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kSourceTick);\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kSourceTickConf, SourceTickCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/src_subset_tick_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass SrcSubsetTickCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SrcSubsetTickCompTaskNode);\n  SrcSubsetTickCompTaskNode() = default;\n  ~SrcSubsetTickCompTaskNode() = default;\n\n  bool IsMeaningLess() override { return false; }\n  TaskType GetTaskType() const override { return TaskType::kSrcSubsetTick; }\n\n private:\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n  void BuildExecGphAndRegst() override;\n};\n\nvoid SrcSubsetTickCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  ProduceRegst(\"out\", false, 2, 2);\n  ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, \"out\"); });\n}\n\nvoid SrcSubsetTickCompTaskNode::ConsumeAllRegsts() {\n  ConsumeRegst(\"in\");\n  ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst(\"in\", edge->GetSoleRegst()); });\n}\n\nvoid SrcSubsetTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid SrcSubsetTickCompTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  node->mut_op() = op();\n  const std::list<std::shared_ptr<RegstDesc>>& in_regsts = GetConsumedRegst(\"in\");\n  for (const std::string& ibn : node->op()->input_bns()) {\n    node->TryBindBnWithOneOfTheRegsts(ibn, in_regsts);\n  }\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  for (const std::string& obn : node->op()->output_bns()) {\n    const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn);\n    out_regst->AddLbi(lbi);\n    node->BindBnWithRegst(obn, out_regst);\n  }\n  (node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nREGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kSrcSubsetTick);\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kSrcSubsetTickConf, SrcSubsetTickCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/ssp_variable_proxy_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/copy_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nclass SspVariableProxyCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SspVariableProxyCompTaskNode);\n  SspVariableProxyCompTaskNode() = default;\n  ~SspVariableProxyCompTaskNode() = default;\n\n  void ProduceAllRegstsAndBindEdges() override {\n    int64_t buffer_size = user_op::UserOpConfWrapper(op()->op_conf()).attr<int64_t>(\"buffer_size\");\n    CHECK_GT(buffer_size, 0);\n    ProduceRegst(\"value\", false, buffer_size, buffer_size);\n    ProduceRegst(\"ref\", false, 1, 1);\n    HashMap<std::string, std::set<TaskEdge*>> out_regst_name2edges;\n    ForEachOutDataEdge(\n        [&](TaskEdge* edge) {\n          {\n            auto* copy_hd_node = dynamic_cast<CopyHdTaskNode*>(edge->dst_node());\n            if (copy_hd_node != nullptr) {\n              // The only possible regst_name is \"value\" because \"ref\" is always strictly one-to-one\n              // connected.\n              CHECK_EQ(*out_regst_name2edges[\"value\"].insert(edge).first, edge);\n              return;\n            }\n          }\n          auto* dst_node = dynamic_cast<CompTaskNode*>(edge->dst_node());\n          CHECK(dst_node != nullptr)\n              << \"SspVariableProxyTaskNode must be consumed by CompTaskNode. Got \"\n              << TaskType_Name(edge->dst_node()->GetTaskType());\n          for (const std::string& ibn : dst_node->op()->input_bns()) {\n            const LogicalBlobId& dst_in_lbi = dst_node->op()->BnInOp2Lbi(ibn);\n            if (dst_in_lbi == op()->BnInOp2Lbi(\"ref_0\")) {\n              CHECK_EQ(*out_regst_name2edges[\"ref\"].insert(edge).first, edge);\n            } else if (dst_in_lbi == op()->BnInOp2Lbi(\"value_0\")) {\n              CHECK_EQ(*out_regst_name2edges[\"value\"].insert(edge).first, edge);\n            } else {\n              // do nothing\n            }\n          }\n        });\n    for (const auto& pair : out_regst_name2edges) {\n      for (TaskEdge* edge : pair.second) { BindEdgeWithProducedRegst(edge, pair.first); }\n    }\n  }\n  void ConsumeAllRegsts() override {\n    ConsumeRegst(\"var\");\n    ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst(\"var\", edge->GetSoleRegst()); });\n  }\n\n  void ConsumeFakeRegsts() override { ConsumeFakeRegst(\"var\"); }\n\n  TaskType GetTaskType() const override { return TaskType::kSspVariableProxy; }\n\n private:\n  void BuildExecGphAndRegst() override {\n    BuildExecGphStructAndBindInRegst();\n    BuildOutRegst();\n    mut_exec_gph().TopoForEachNode(\n        [this](ExecNode* node) { (node->*GetInferBlobDescsMethod())(parallel_ctx()); });\n  }\n\n  void BuildExecGphStructAndBindInRegst() {\n    ExecNode* exec_node = mut_exec_gph().NewNode();\n    exec_node->mut_op() = op();\n    exec_node->BindBnWithOneOfTheRegsts(\"var_0\", GetConsumedRegst(\"var\"));\n    BindInplaceBetweenVarAndRef();\n  }\n\n  void BindInplaceBetweenVarAndRef() {\n    const auto& var_regst = GetSoleConsumedRegst(\"var\");\n    CHECK_EQ(var_regst->NumOfLbi(), 1);\n    CHECK_EQ(var_regst->min_register_num(), 1);\n    CHECK_EQ(var_regst->max_register_num(), 1);\n    const auto& ref_regst = GetProducedRegst(\"ref\");\n    ref_regst->set_force_inplace_consumed_regst_desc_id(var_regst->regst_desc_id());\n  }\n\n  void BuildOutRegst() {\n    ExecNode* exec_node = mut_exec_gph().SoleNode();\n    const auto& AddLbiAndBindBn = [&](const std::string& regst_name) {\n      // \"ref_0\" obn <-> \"ref\" regst_name\n      // \"value_0\" obn <-> \"value\" regst_name\n      const std::string& obn = regst_name + \"_0\";\n      const std::shared_ptr<RegstDesc>& regst = GetProducedRegst(regst_name);\n      regst->AddLbi(exec_node->op()->BnInOp2Lbi(obn));\n      exec_node->BindBnWithRegst(obn, regst);\n    };\n    AddLbiAndBindBn(\"ref\");\n    AddLbiAndBindBn(\"value\");\n  }\n\n  void InferProducedDataRegstTimeShape() override { NaiveInferProducedDataRegstTimeShape(); }\n};\n\nREGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kSspVariableProxy);\n\nREGISTER_USER_OP_COMP_TASK_NODE_TYPE(\"ssp_variable_proxy\", SspVariableProxyCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/tick_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass TickCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TickCompTaskNode);\n  TickCompTaskNode() = default;\n  ~TickCompTaskNode() = default;\n\n  bool IsMeaningLess() override { return false; }\n  TaskType GetTaskType() const override { return TaskType::kTick; }\n\n private:\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n  void BuildExecGphAndRegst() override;\n};\n\nvoid TickCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  ProduceRegst(\"out\", false, 1, 1);\n  ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, \"out\"); });\n}\n\nvoid TickCompTaskNode::ConsumeAllRegsts() {\n  ConsumeRegst(\"in\");\n  ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst(\"in\", edge->GetSoleRegst()); });\n}\n\nvoid TickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid TickCompTaskNode::BuildExecGphAndRegst() {\n  ExecNode* node = mut_exec_gph().NewNode();\n  node->mut_op() = op();\n  const std::list<std::shared_ptr<RegstDesc>>& in_regsts = GetConsumedRegst(\"in\");\n  for (const std::string& ibn : node->op()->input_bns()) {\n    node->BindBnWithOneOfTheRegsts(ibn, in_regsts);\n  }\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  for (const std::string& obn : node->op()->output_bns()) {\n    const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn);\n    out_regst->AddLbi(lbi);\n    node->BindBnWithRegst(obn, out_regst);\n  }\n  (node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nREGISTER_TICK_TASK_STREAM_INDEX_GETTER(TaskType::kTick);\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kTickConf, TickCompTaskNode);\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kSinkTickConf, TickCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/unpack_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass UnpackCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(UnpackCompTaskNode);\n  UnpackCompTaskNode() = default;\n  ~UnpackCompTaskNode() override = default;\n\n  TaskType GetTaskType() const override { return TaskType::kUnpack; }\n\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override;\n  void ConsumeFakeRegsts() override;\n\n private:\n  void BuildExecGphAndRegst() override;\n};\n\nvoid UnpackCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  ProduceRegst(\"out\", false);\n  ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, \"out\"); });\n}\n\nvoid UnpackCompTaskNode::ConsumeAllRegsts() {\n  ConsumeRegst(\"in\", SoleInDataEdge()->GetSoleRegst());\n}\n\nvoid UnpackCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst(\"in\"); }\n\nvoid UnpackCompTaskNode::BuildExecGphAndRegst() {\n  ExecNode* exec_node = mut_exec_gph().NewNode();\n  exec_node->mut_op() = op();\n  exec_node->BindBnWithRegst(op()->SoleIbn(), GetSoleConsumedRegst(\"in\"));\n\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  out_regst->AddLbi(op()->BnInOp2Lbi(op()->SoleObn()));\n  exec_node->BindBnWithRegst(op()->SoleObn(), out_regst);\n  (exec_node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nREGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kUnpack);\n\nREGISTER_USER_OP_COMP_TASK_NODE_TYPE(\"unpack\", UnpackCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/graph_impl/wait_and_send_ids_compute_task_node.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/compute_task_node.h\"\n#include \"oneflow/core/graph/task_stream_index_manager.h\"\n\nnamespace oneflow {\n\nclass WaitAndSendIdsCompTaskNode final : public CompTaskNode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(WaitAndSendIdsCompTaskNode);\n  WaitAndSendIdsCompTaskNode() = default;\n  ~WaitAndSendIdsCompTaskNode() override = default;\n\n  void ProduceAllRegstsAndBindEdges() override;\n  void ConsumeAllRegsts() override {}\n  void ConsumeFakeRegsts() override {}\n  void BuildExecGphAndRegst() override;\n  bool IsMeaningLess() override { return false; }\n\n  TaskType GetTaskType() const override { return TaskType::kWaitAndSendIds; }\n\n private:\n  void InferProducedDataRegstTimeShape() override;\n};\n\nvoid WaitAndSendIdsCompTaskNode::ProduceAllRegstsAndBindEdges() {\n  std::shared_ptr<RegstDesc> out_regst = ProduceRegst(\"out\", false, 100, 100);\n  ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst(\"out\", out_regst); });\n}\n\nvoid WaitAndSendIdsCompTaskNode::BuildExecGphAndRegst() {\n  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(\"out\");\n  ExecNode* node = mut_exec_gph().NewNode();\n  node->mut_op() = op();\n  for (const std::string& obn : node->op()->output_bns()) {\n    const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn);\n    out_regst->AddLbi(lbi);\n    node->BindBnWithRegst(obn, out_regst);\n  }\n  (node->*GetInferBlobDescsMethod())(parallel_ctx());\n}\n\nvoid WaitAndSendIdsCompTaskNode::InferProducedDataRegstTimeShape() {\n  std::shared_ptr<Shape> time_shape(new Shape({1, 1}));\n  ForEachProducedDataRegst([time_shape](const std::string& name, RegstDesc* regst) {\n    *regst->mut_data_regst_time_shape() = time_shape;\n  });\n}\n\nREGISTER_INDEPENDENT_TASK_STREAM_INDEX_GETTER(TaskType::kWaitAndSendIds);\n\nREGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kWaitAndSendIdsConf,\n                                       WaitAndSendIdsCompTaskNode);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/hardware/basic_device_descriptor_list.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/hardware/basic_device_descriptor_list.h\"\n\nnamespace oneflow {\n\nnamespace hardware {\n\nBasicDeviceDescriptorList::BasicDeviceDescriptorList(\n    std::vector<std::shared_ptr<const DeviceDescriptor>> device_descriptor_list)\n    : device_descriptor_list_(std::move(device_descriptor_list)) {}\n\nBasicDeviceDescriptorList::BasicDeviceDescriptorList()\n    : BasicDeviceDescriptorList(std::vector<std::shared_ptr<const DeviceDescriptor>>()) {}\n\nBasicDeviceDescriptorList::~BasicDeviceDescriptorList() = default;\n\nsize_t BasicDeviceDescriptorList::DeviceCount() const { return device_descriptor_list_.size(); }\n\nstd::shared_ptr<const DeviceDescriptor> BasicDeviceDescriptorList::GetDevice(size_t ordinal) const {\n  if (ordinal < device_descriptor_list_.size()) {\n    return device_descriptor_list_.at(ordinal);\n  } else {\n    return nullptr;\n  }\n}\n\n}  // namespace hardware\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/hardware/basic_device_descriptor_list.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_HARDWARE_BASIC_DEVICE_DESCRIPTOR_LIST_H_\n#define ONEFLOW_CORE_HARDWARE_BASIC_DEVICE_DESCRIPTOR_LIST_H_\n\n#include \"oneflow/core/hardware/device_descriptor_list.h\"\n#include \"oneflow/core/common/util.h\"\n#include <cstdint>\n#include <memory>\n#include <vector>\n\nnamespace oneflow {\n\nnamespace hardware {\n\nclass BasicDeviceDescriptorList : public DeviceDescriptorList {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BasicDeviceDescriptorList);\n  explicit BasicDeviceDescriptorList(\n      std::vector<std::shared_ptr<const DeviceDescriptor>> device_descriptor_list);\n  BasicDeviceDescriptorList();\n  ~BasicDeviceDescriptorList() override;\n\n  size_t DeviceCount() const override;\n  std::shared_ptr<const DeviceDescriptor> GetDevice(size_t ordinal) const override;\n\n private:\n  std::vector<std::shared_ptr<const DeviceDescriptor>> device_descriptor_list_;\n};\n\n}  // namespace hardware\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_HARDWARE_BASIC_DEVICE_DESCRIPTOR_LIST_H_\n"
  },
  {
    "path": "oneflow/core/hardware/cuda_device_descriptor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/hardware/cuda_device_descriptor.h\"\n\n#ifdef WITH_CUDA\n\n#include <cuda_runtime.h>\n#include <cuda.h>\n#include <atomic>\n#include <cctype>\n#include <set>\n#include \"nlohmann/json.hpp\"\n\nnamespace oneflow {\n\nnamespace hardware {\n\nnamespace {\n\nconstexpr char kJsonKeyOrdinal[] = \"ordinal\";\nconstexpr char kJsonKeyName[] = \"name\";\nconstexpr char kJsonKeyTotalGlobalMemory[] = \"total_global_memory_bytes\";\nconstexpr char kJsonKeyClockRate[] = \"clock_rate_khz\";\nconstexpr char kJsonKeyComputeCapabilityMajor[] = \"compute_capability_major\";\nconstexpr char kJsonKeyComputeCapabilityMinor[] = \"compute_capability_minor\";\nconstexpr char kJsonKeyMemoryClockRate[] = \"memory_clock_rate_khz\";\nconstexpr char kJsonKeyMemoryBusWidth[] = \"memory_bus_width_bit\";\nconstexpr char kJsonKeyPCIBusID[] = \"pci_bus_id\";\n\n}  // namespace\n\nstruct CudaDeviceDescriptor::Impl {\n  int32_t ordinal{};\n  std::string name;\n  size_t total_global_memory_bytes{};\n  int32_t clock_rate_khz{};\n  int32_t compute_capability_major{};\n  int32_t compute_capability_minor{};\n  int32_t memory_clock_rate_khz{};\n  int32_t memory_bus_width_bit{};\n  std::string pci_bus_id;\n};\n\nCudaDeviceDescriptor::CudaDeviceDescriptor() { impl_.reset(new Impl()); }\n\nCudaDeviceDescriptor::~CudaDeviceDescriptor() = default;\n\nint32_t CudaDeviceDescriptor::Ordinal() const { return impl_->ordinal; }\n\nconst std::string& CudaDeviceDescriptor::Name() const { return impl_->name; }\n\nsize_t CudaDeviceDescriptor::GlobalMemorySizeBytes() const {\n  return impl_->total_global_memory_bytes;\n}\n\nint32_t CudaDeviceDescriptor::ClockRateKHz() const { return impl_->clock_rate_khz; }\n\nint32_t CudaDeviceDescriptor::ComputeCapabilityMajor() const {\n  return impl_->compute_capability_major;\n}\n\nint32_t CudaDeviceDescriptor::ComputeCapabilityMinor() const {\n  return impl_->compute_capability_minor;\n}\n\nint32_t CudaDeviceDescriptor::MemoryClockRateKHz() const { return impl_->memory_clock_rate_khz; }\n\nint32_t CudaDeviceDescriptor::MemoryBusWidthBit() const { return impl_->memory_bus_width_bit; }\n\nconst std::string& CudaDeviceDescriptor::PCIBusID() const { return impl_->pci_bus_id; }\n\nstd::shared_ptr<const CudaDeviceDescriptor> CudaDeviceDescriptor::Query(int32_t ordinal) {\n  cudaDeviceProp prop{};\n  const cudaError_t err = cudaGetDeviceProperties(&prop, ordinal);\n  CHECK(err == cudaSuccess);\n  static const std::set<int> compiled_archs{CUDA_REAL_ARCHS};\n  if (compiled_archs.find(prop.major * 10 + prop.minor) == compiled_archs.cend()\n      && compiled_archs.find(prop.major * 10) == compiled_archs.cend()) {\n    static std::atomic<bool> once_flag(false);\n    if (!once_flag.exchange(true)) {\n      LOG(WARNING)\n          << \"The CUDA device '\" << prop.name << \"' with capability \"\n          << prop.major * 10 + prop.minor\n          << \" is not compatible with the current OneFlow installation. The current program \"\n             \"may throw a 'no kernel image is available for execution \"\n             \"on the device' error or hang for a long time. Please reinstall OneFlow \"\n             \"compiled with a newer version of CUDA.\";\n    }\n  }\n  auto* desc = new CudaDeviceDescriptor();\n  desc->impl_->ordinal = ordinal;\n  desc->impl_->name = prop.name;\n  desc->impl_->total_global_memory_bytes = prop.totalGlobalMem;\n  desc->impl_->clock_rate_khz = prop.clockRate;\n  desc->impl_->compute_capability_major = prop.major;\n  desc->impl_->compute_capability_minor = prop.minor;\n  desc->impl_->memory_clock_rate_khz = prop.memoryClockRate;\n  desc->impl_->memory_bus_width_bit = prop.memoryBusWidth;\n  char pci_bus_id_buf[sizeof(\"00000000:00:00.0\")];\n  if (CUDA_VERSION >= 11000\n      && cudaDeviceGetPCIBusId(pci_bus_id_buf, sizeof(pci_bus_id_buf), ordinal) == cudaSuccess) {\n    for (int i = 0; i < sizeof(pci_bus_id_buf) - 1; ++i) {\n      pci_bus_id_buf[i] = static_cast<char>(std::tolower(pci_bus_id_buf[i]));\n    }\n    desc->impl_->pci_bus_id = pci_bus_id_buf;\n  } else {\n    desc->impl_->pci_bus_id = \"\";\n  }\n  return std::shared_ptr<const CudaDeviceDescriptor>(desc);\n}\n\nvoid CudaDeviceDescriptor::Serialize(std::string* serialized) const {\n  nlohmann::json json_object;\n  json_object[kJsonKeyOrdinal] = impl_->ordinal;\n  json_object[kJsonKeyName] = impl_->name;\n  json_object[kJsonKeyTotalGlobalMemory] = impl_->total_global_memory_bytes;\n  json_object[kJsonKeyClockRate] = impl_->clock_rate_khz;\n  json_object[kJsonKeyComputeCapabilityMajor] = impl_->compute_capability_major;\n  json_object[kJsonKeyComputeCapabilityMinor] = impl_->compute_capability_minor;\n  json_object[kJsonKeyMemoryClockRate] = impl_->memory_clock_rate_khz;\n  json_object[kJsonKeyMemoryBusWidth] = impl_->memory_bus_width_bit;\n  json_object[kJsonKeyPCIBusID] = impl_->pci_bus_id;\n  *serialized = json_object.dump(2);\n}\n\nstd::shared_ptr<const CudaDeviceDescriptor> CudaDeviceDescriptor::Deserialize(\n    const std::string& serialized) {\n  auto json_object = nlohmann::json::parse(serialized);\n  auto* desc = new CudaDeviceDescriptor();\n  desc->impl_->ordinal = json_object[kJsonKeyOrdinal];\n  desc->impl_->name = json_object[kJsonKeyName];\n  desc->impl_->total_global_memory_bytes = json_object[kJsonKeyTotalGlobalMemory];\n  desc->impl_->clock_rate_khz = json_object[kJsonKeyClockRate];\n  desc->impl_->compute_capability_major = json_object[kJsonKeyComputeCapabilityMajor];\n  desc->impl_->compute_capability_minor = json_object[kJsonKeyComputeCapabilityMinor];\n  desc->impl_->memory_clock_rate_khz = json_object[kJsonKeyMemoryClockRate];\n  desc->impl_->memory_bus_width_bit = json_object[kJsonKeyMemoryBusWidth];\n  desc->impl_->pci_bus_id = json_object[kJsonKeyPCIBusID];\n  return std::shared_ptr<const CudaDeviceDescriptor>(desc);\n}\n\n}  // namespace hardware\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/core/hardware/cuda_device_descriptor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_HARDWARE_CUDA_DEVICE_DESCRIPTOR_H_\n#define ONEFLOW_CORE_HARDWARE_CUDA_DEVICE_DESCRIPTOR_H_\n\n#include \"oneflow/core/hardware/device_descriptor.h\"\n#include \"oneflow/core/common/util.h\"\n#include <string>\n#include <memory>\n\n#ifdef WITH_CUDA\n\nnamespace oneflow {\n\nnamespace hardware {\n\nconstexpr char kCudaDeviceDescriptorClassName[] = \"cuda\";\n\nclass CudaDeviceDescriptor : public DeviceDescriptor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaDeviceDescriptor);\n  ~CudaDeviceDescriptor() override;\n\n  int32_t Ordinal() const;\n  const std::string& Name() const;\n  size_t GlobalMemorySizeBytes() const;\n  int32_t ClockRateKHz() const;\n  int32_t ComputeCapabilityMajor() const;\n  int32_t ComputeCapabilityMinor() const;\n  int32_t MemoryClockRateKHz() const;\n  int32_t MemoryBusWidthBit() const;\n  const std::string& PCIBusID() const;\n  void Serialize(std::string* serialized) const;\n  static std::shared_ptr<const CudaDeviceDescriptor> Query(int32_t ordinal);\n  static std::shared_ptr<const CudaDeviceDescriptor> Deserialize(const std::string& serialized);\n\n private:\n  CudaDeviceDescriptor();\n\n  struct Impl;\n  std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace hardware\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_CORE_HARDWARE_CUDA_DEVICE_DESCRIPTOR_H_\n"
  },
  {
    "path": "oneflow/core/hardware/cuda_device_descriptor_class.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/hardware/device_descriptor_class.h\"\n#include \"oneflow/core/hardware/cuda_device_descriptor.h\"\n#include \"oneflow/core/hardware/basic_device_descriptor_list.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"nlohmann/json.hpp\"\n\n#ifdef WITH_CUDA\n\n#include <cuda_runtime.h>\n\nnamespace oneflow {\n\nnamespace hardware {\n\nnamespace {\n\nconstexpr char kJsonKeyDevices[] = \"devices\";\n\n}  // namespace\n\nclass CudaDeviceDescriptorClass : public DeviceDescriptorClass {\n public:\n  CudaDeviceDescriptorClass() = default;\n  ~CudaDeviceDescriptorClass() override = default;\n\n  std::shared_ptr<const DeviceDescriptorList> QueryDeviceDescriptorList() const override {\n    int n_dev = 0;\n    cudaError_t err = cudaGetDeviceCount(&n_dev);\n    if (err != cudaSuccess) {\n      LOG(WARNING) << cudaGetErrorString(err);\n      return std::make_shared<const BasicDeviceDescriptorList>(\n          std::vector<std::shared_ptr<const DeviceDescriptor>>());\n    }\n    std::vector<std::shared_ptr<const DeviceDescriptor>> devices(n_dev);\n    for (int dev = 0; dev < n_dev; ++dev) { devices.at(dev) = CudaDeviceDescriptor::Query(dev); }\n    return std::make_shared<const BasicDeviceDescriptorList>(devices);\n  }\n\n  std::string Name() const override { return kCudaDeviceDescriptorClassName; }\n\n  void SerializeDeviceDescriptorList(const std::shared_ptr<const DeviceDescriptorList>& list,\n                                     std::string* serialized) const override {\n    std::vector<std::string> serialized_devices;\n    serialized_devices.reserve(list->DeviceCount());\n    for (size_t i = 0; i < list->DeviceCount(); ++i) {\n      auto cuda_device = std::dynamic_pointer_cast<const CudaDeviceDescriptor>(list->GetDevice(i));\n      CHECK(cuda_device);\n      std::string serialized_device;\n      cuda_device->Serialize(&serialized_device);\n      serialized_devices.emplace_back(std::move(serialized_device));\n    }\n    nlohmann::json json_object;\n    json_object[kJsonKeyDevices] = serialized_devices;\n    *serialized = json_object.dump();\n  }\n\n  std::shared_ptr<const DeviceDescriptorList> DeserializeDeviceDescriptorList(\n      const std::string& serialized) const override {\n    auto json_object = nlohmann::json::parse(serialized);\n    std::vector<std::string> serialized_devices = json_object[kJsonKeyDevices];\n    std::vector<std::shared_ptr<const DeviceDescriptor>> devices(serialized_devices.size());\n    for (int i = 0; i < serialized_devices.size(); ++i) {\n      devices.at(i) = CudaDeviceDescriptor::Deserialize(serialized_devices.at(i));\n    }\n    return std::make_shared<const BasicDeviceDescriptorList>(devices);\n  }\n\n  void DumpDeviceDescriptorListSummary(const std::shared_ptr<const DeviceDescriptorList>& list,\n                                       const std::string& path) const override {\n    for (size_t i = 0; i < list->DeviceCount(); ++i) {\n      auto cuda_device = std::dynamic_pointer_cast<const CudaDeviceDescriptor>(list->GetDevice(i));\n      CHECK(cuda_device);\n      auto stream = TeePersistentLogStream::Create(JoinPath(path, std::to_string(i) + \".json\"));\n      std::string serialized;\n      cuda_device->Serialize(&serialized);\n      stream << serialized;\n    }\n  }\n};\n\nCOMMAND(DeviceDescriptorClass::RegisterClass(std::make_shared<CudaDeviceDescriptorClass>()));\n\n}  // namespace hardware\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
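  {
    "path": "oneflow/core/hardware/cuda_device_descriptor_example.cpp",
    "content": "// NOTE: a minimal usage sketch, not part of the build; it exercises only the\n// public interface declared in cuda_device_descriptor.h. It queries the\n// descriptor of CUDA device 0, round-trips it through the JSON serialization,\n// and prints a few fields. Assumes WITH_CUDA and at least one visible device.\n#include \"oneflow/core/hardware/cuda_device_descriptor.h\"\n#include <iostream>\n#include <memory>\n#include <string>\n\n#ifdef WITH_CUDA\n\nint main() {\n  using oneflow::hardware::CudaDeviceDescriptor;\n  // Query the device with ordinal 0 directly from the CUDA runtime.\n  std::shared_ptr<const CudaDeviceDescriptor> desc = CudaDeviceDescriptor::Query(0);\n  // Round-trip through the JSON form produced by Serialize/Deserialize.\n  std::string serialized;\n  desc->Serialize(&serialized);\n  std::shared_ptr<const CudaDeviceDescriptor> restored =\n      CudaDeviceDescriptor::Deserialize(serialized);\n  std::cout << restored->Name() << \" (compute capability \" << restored->ComputeCapabilityMajor()\n            << \".\" << restored->ComputeCapabilityMinor() << \"), \"\n            << restored->GlobalMemorySizeBytes() << \" bytes of global memory\" << std::endl;\n  return 0;\n}\n\n#endif  // WITH_CUDA\n"
  },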
  {
    "path": "oneflow/core/hardware/device_descriptor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_H_\n#define ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace hardware {\n\nclass DeviceDescriptor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DeviceDescriptor);\n  DeviceDescriptor() = default;\n  virtual ~DeviceDescriptor() = default;\n};\n\n}  // namespace hardware\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_H_\n"
  },
  {
    "path": "oneflow/core/hardware/device_descriptor_class.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/hardware/device_descriptor_class.h\"\n#include <mutex>\n#include <utility>\n#include <vector>\n#include <unordered_map>\n\nnamespace oneflow {\n\nnamespace hardware {\n\nnamespace {\n\nclass DeviceClassRegistryStorage {\n public:\n  DeviceClassRegistryStorage() = default;\n  ~DeviceClassRegistryStorage() = default;\n  void Register(std::shared_ptr<const DeviceDescriptorClass> descriptor_class) {\n    std::lock_guard<std::mutex> lock(mutex_);\n    const std::string name = descriptor_class->Name();\n    if (!name2index_.emplace(name, classes_.size()).second) { abort(); }\n    classes_.emplace_back(std::make_shared<std::string>(name), std::move(descriptor_class));\n  }\n\n  size_t RegisteredCount() {\n    std::lock_guard<std::mutex> lock(mutex_);\n    return classes_.size();\n  }\n\n  const std::string& GetRegisteredClass(size_t index) {\n    std::lock_guard<std::mutex> lock(mutex_);\n    return *classes_.at(index).first;\n  }\n\n  std::shared_ptr<const DeviceDescriptorClass> GetRegistered(size_t index) {\n    std::lock_guard<std::mutex> lock(mutex_);\n    return classes_.at(index).second;\n  }\n\n  std::shared_ptr<const DeviceDescriptorClass> GetRegistered(const std::string& name) {\n    std::lock_guard<std::mutex> lock(mutex_);\n    auto it = name2index_.find(name);\n    if (it == name2index_.end()) { return std::shared_ptr<const DeviceDescriptorClass>(); }\n    return classes_.at(it->second).second;\n  }\n\n  static DeviceClassRegistryStorage& Instance() {\n    static DeviceClassRegistryStorage instance;\n    return instance;\n  }\n\n private:\n  std::unordered_map<std::string, size_t> name2index_;\n  std::vector<std::pair<std::shared_ptr<std::string>, std::shared_ptr<const DeviceDescriptorClass>>>\n      classes_;\n  std::mutex mutex_;\n};\n\n}  // namespace\n\nvoid DeviceDescriptorClass::RegisterClass(\n    std::shared_ptr<const DeviceDescriptorClass> descriptor_class) {\n  DeviceClassRegistryStorage::Instance().Register(std::move(descriptor_class));\n}\n\nsize_t DeviceDescriptorClass::GetRegisteredClassesCount() {\n  return DeviceClassRegistryStorage::Instance().RegisteredCount();\n}\n\nstd::shared_ptr<const DeviceDescriptorClass> DeviceDescriptorClass::GetRegisteredClass(\n    size_t index) {\n  return DeviceClassRegistryStorage::Instance().GetRegistered(index);\n}\n\nstd::shared_ptr<const DeviceDescriptorClass> DeviceDescriptorClass::GetRegisteredClass(\n    const std::string& class_name) {\n  return DeviceClassRegistryStorage::Instance().GetRegistered(class_name);\n}\n\n}  // namespace hardware\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/hardware/device_descriptor_class.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_CLASS_H_\n#define ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_CLASS_H_\n\n#include \"oneflow/core/hardware/device_descriptor_list.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace hardware {\n\nclass DeviceDescriptorClass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DeviceDescriptorClass);\n  DeviceDescriptorClass() = default;\n  virtual ~DeviceDescriptorClass() = default;\n\n  virtual std::shared_ptr<const DeviceDescriptorList> QueryDeviceDescriptorList() const = 0;\n  virtual std::string Name() const = 0;\n  virtual void SerializeDeviceDescriptorList(\n      const std::shared_ptr<const DeviceDescriptorList>& list, std::string* serialized) const = 0;\n  virtual std::shared_ptr<const DeviceDescriptorList> DeserializeDeviceDescriptorList(\n      const std::string& serialized) const = 0;\n  virtual void DumpDeviceDescriptorListSummary(\n      const std::shared_ptr<const DeviceDescriptorList>& list, const std::string& path) const = 0;\n\n  static void RegisterClass(std::shared_ptr<const DeviceDescriptorClass> descriptor_class);\n  static size_t GetRegisteredClassesCount();\n  static std::shared_ptr<const DeviceDescriptorClass> GetRegisteredClass(size_t index);\n  static std::shared_ptr<const DeviceDescriptorClass> GetRegisteredClass(\n      const std::string& class_name);\n};\n\n}  // namespace hardware\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_CLASS_H_\n"
  },
  {
    "path": "oneflow/core/hardware/device_descriptor_list.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_LIST_H_\n#define ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_LIST_H_\n\n#include \"oneflow/core/hardware/device_descriptor.h\"\n#include \"oneflow/core/common/util.h\"\n#include <cstdint>\n#include <memory>\n\nnamespace oneflow {\n\nnamespace hardware {\n\nclass DeviceDescriptorList {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DeviceDescriptorList);\n  DeviceDescriptorList() = default;\n  virtual ~DeviceDescriptorList() = default;\n\n  virtual size_t DeviceCount() const = 0;\n  virtual std::shared_ptr<const DeviceDescriptor> GetDevice(size_t ordinal) const = 0;\n};\n\n}  // namespace hardware\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_HARDWARE_DEVICE_DESCRIPTOR_LIST_H_\n"
  },
  {
    "path": "oneflow/core/hardware/net_ib_device_descriptor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/hardware/net_ib_device_descriptor.h\"\n\n#ifdef WITH_RDMA\n\n#include \"nlohmann/json.hpp\"\n\nnamespace oneflow {\n\nnamespace hardware {\n\nnamespace {\n\nconstexpr char kJsonKeyOrdinal[] = \"ordinal\";\nconstexpr char kJsonKeyName[] = \"name\";\nconstexpr char kJsonKeyGUID[] = \"guid\";\nconstexpr char kJsonKeyPort[] = \"port\";\nconstexpr char kJsonKeyLankLayer[] = \"link_layer\";\nconstexpr char kJsonValueLinkLayerInfiniBand[] = \"InfiniBand\";\nconstexpr char kJsonValueLinkLayerEthernet[] = \"Ethernet\";\nconstexpr char kJsonKeyPCIBusID[] = \"pci_bus_id\";\n\nvoid GetPCIBusID(const std::string& name, std::string* pci_bus_id) {\n#ifdef __linux__\n  const std::string device_path = \"/sys/class/infiniband/\" + name + \"/device\";\n  const char* device_real_path = realpath(device_path.data(), nullptr);\n  if (device_real_path == nullptr) { return; }\n  const std::string device_real_path_str = device_real_path;\n  const size_t pos = device_real_path_str.rfind('/');\n  if (pos == std::string::npos) { return; }\n  *pci_bus_id = device_real_path_str.substr(pos + 1);\n#endif\n}\n\n}  // namespace\n\nstruct NetIBDeviceDescriptor::Impl {\n  int32_t ordinal{};\n  std::string name;\n  uint64_t guid{};\n  uint8_t port{};\n  NetIBDeviceDescriptorLinkLayer link_layer{};\n  std::string pci_bus_id;\n};\n\nNetIBDeviceDescriptor::NetIBDeviceDescriptor() { impl_.reset(new Impl()); }\n\nNetIBDeviceDescriptor::~NetIBDeviceDescriptor() = default;\n\nint32_t NetIBDeviceDescriptor::Ordinal() const { return impl_->ordinal; }\n\nconst std::string& NetIBDeviceDescriptor::Name() const { return impl_->name; }\n\nuint64_t NetIBDeviceDescriptor::GUID() const { return impl_->guid; }\n\nuint8_t NetIBDeviceDescriptor::Port() const { return impl_->port; }\n\nNetIBDeviceDescriptorLinkLayer NetIBDeviceDescriptor::LinkLayer() const {\n  return impl_->link_layer;\n}\n\nconst std::string& NetIBDeviceDescriptor::PCIBusID() const { return impl_->pci_bus_id; }\n\nvoid NetIBDeviceDescriptor::Serialize(std::string* serialized) const {\n  nlohmann::json json_object;\n  json_object[kJsonKeyOrdinal] = impl_->ordinal;\n  json_object[kJsonKeyName] = impl_->name;\n  json_object[kJsonKeyGUID] = impl_->guid;\n  json_object[kJsonKeyPort] = impl_->port;\n  if (impl_->link_layer == kNetIBDeviceDescriptorLinkLayerInfiniBand) {\n    json_object[kJsonKeyLankLayer] = kJsonValueLinkLayerInfiniBand;\n  } else if (impl_->link_layer == kNetIBDeviceDescriptorLinkLayerEthernet) {\n    json_object[kJsonKeyLankLayer] = kJsonValueLinkLayerEthernet;\n  } else {\n    UNIMPLEMENTED();\n  }\n  json_object[kJsonKeyPCIBusID] = impl_->pci_bus_id;\n  *serialized = json_object.dump(2);\n}\n\nstd::shared_ptr<const NetIBDeviceDescriptor> NetIBDeviceDescriptor::Query(int32_t ordinal,\n                                                                          ibv_context* context,\n                                                    
                      uint8_t port) {\n  CHECK(ibv::IsAvailable());\n  ibv_device_attr device_attr{};\n  if (ibv::wrapper.ibv_query_device(context, &device_attr) != 0) {\n    VLOG(3) << \"Unable to query device: \" << context->device->name;\n    return std::shared_ptr<const NetIBDeviceDescriptor>();\n  }\n  ibv_port_attr port_attr{};\n  if (ibv::wrapper.ibv_query_port_wrap(context, port, &port_attr) != 0) {\n    VLOG(3) << \"Unable to query port: device \" << context->device->name << \" port \" << port;\n    return std::shared_ptr<const NetIBDeviceDescriptor>();\n  }\n  if (port_attr.state != IBV_PORT_ACTIVE) {\n    VLOG(3) << \"Inactivate port: device \" << context->device->name << \" port \" << port;\n    return std::shared_ptr<const NetIBDeviceDescriptor>();\n  }\n  if (port_attr.link_layer != IBV_LINK_LAYER_INFINIBAND\n      && port_attr.link_layer != IBV_LINK_LAYER_ETHERNET) {\n    VLOG(3) << \"Link layer is not supported: device \" << context->device->name << \" port \" << port;\n    return std::shared_ptr<const NetIBDeviceDescriptor>();\n  }\n  auto* desc = new NetIBDeviceDescriptor();\n  desc->impl_->ordinal = ordinal;\n  desc->impl_->name = context->device->name;\n  desc->impl_->guid = device_attr.sys_image_guid;\n  desc->impl_->port = port;\n  if (port_attr.link_layer == IBV_LINK_LAYER_INFINIBAND) {\n    desc->impl_->link_layer = kNetIBDeviceDescriptorLinkLayerInfiniBand;\n  } else if (port_attr.link_layer == IBV_LINK_LAYER_ETHERNET) {\n    desc->impl_->link_layer = kNetIBDeviceDescriptorLinkLayerEthernet;\n  } else {\n    UNIMPLEMENTED();\n  }\n  GetPCIBusID(desc->impl_->name, &desc->impl_->pci_bus_id);\n  return std::shared_ptr<const NetIBDeviceDescriptor>(desc);\n}\n\nstd::shared_ptr<const NetIBDeviceDescriptor> NetIBDeviceDescriptor::Deserialize(\n    const std::string& serialized) {\n  auto json_object = nlohmann::json::parse(serialized);\n  auto* desc = new NetIBDeviceDescriptor();\n  desc->impl_->ordinal = json_object[kJsonKeyOrdinal];\n  desc->impl_->name = json_object[kJsonKeyName];\n  desc->impl_->guid = json_object[kJsonKeyGUID];\n  desc->impl_->port = json_object[kJsonKeyPort];\n  const std::string link_layer_value = json_object[kJsonKeyLankLayer];\n  if (link_layer_value == kJsonValueLinkLayerInfiniBand) {\n    desc->impl_->link_layer = kNetIBDeviceDescriptorLinkLayerInfiniBand;\n  } else if (link_layer_value == kJsonValueLinkLayerEthernet) {\n    desc->impl_->link_layer = kNetIBDeviceDescriptorLinkLayerEthernet;\n  } else {\n    UNIMPLEMENTED();\n  }\n  desc->impl_->pci_bus_id = json_object[kJsonKeyPCIBusID];\n  return std::shared_ptr<const NetIBDeviceDescriptor>(desc);\n}\n\n}  // namespace hardware\n\n}  // namespace oneflow\n\n#endif  // WITH_RDMA\n"
  },
  {
    "path": "oneflow/core/hardware/net_ib_device_descriptor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_HARDWARE_NET_IB_DEVICE_DESCRIPTOR_H_\n#define ONEFLOW_CORE_HARDWARE_NET_IB_DEVICE_DESCRIPTOR_H_\n\n#include \"oneflow/core/hardware/device_descriptor.h\"\n#include \"oneflow/core/common/util.h\"\n#include <string>\n#include <memory>\n\n#ifdef WITH_RDMA\n\n#include \"oneflow/core/platform/include/ibv.h\"\n\nnamespace oneflow {\n\nnamespace hardware {\n\nconstexpr char kNetIBDeviceDescriptorClassName[] = \"net_ib\";\n\nenum NetIBDeviceDescriptorLinkLayer {\n  kNetIBDeviceDescriptorLinkLayerInvalid = 0,\n  kNetIBDeviceDescriptorLinkLayerInfiniBand = 1,\n  kNetIBDeviceDescriptorLinkLayerEthernet = 2,\n};\n\nclass NetIBDeviceDescriptor : public DeviceDescriptor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NetIBDeviceDescriptor);\n  ~NetIBDeviceDescriptor() override;\n\n  int32_t Ordinal() const;\n  const std::string& Name() const;\n  uint64_t GUID() const;\n  uint8_t Port() const;\n  NetIBDeviceDescriptorLinkLayer LinkLayer() const;\n  const std::string& PCIBusID() const;\n  void Serialize(std::string* serialized) const;\n  static std::shared_ptr<const NetIBDeviceDescriptor> Query(int32_t ordinal, ibv_context* context,\n                                                            uint8_t port);\n  static std::shared_ptr<const NetIBDeviceDescriptor> Deserialize(const std::string& serialized);\n\n private:\n  NetIBDeviceDescriptor();\n\n  struct Impl;\n  std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace hardware\n\n}  // namespace oneflow\n\n#endif  // WITH_RDMA\n\n#endif  // ONEFLOW_CORE_HARDWARE_NET_IB_DEVICE_DESCRIPTOR_H_\n"
  },
  {
    "path": "oneflow/core/hardware/net_ib_device_descriptor_class.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/hardware/device_descriptor_class.h\"\n#include \"oneflow/core/hardware/net_ib_device_descriptor.h\"\n#include \"oneflow/core/hardware/basic_device_descriptor_list.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"nlohmann/json.hpp\"\n\n#ifdef WITH_RDMA\n\nnamespace oneflow {\n\nnamespace hardware {\n\nnamespace {\n\nconstexpr char kJsonKeyDevices[] = \"devices\";\n\n}  // namespace\n\nclass NetIBDeviceDescriptorClass : public DeviceDescriptorClass {\n public:\n  NetIBDeviceDescriptorClass() = default;\n  ~NetIBDeviceDescriptorClass() override = default;\n\n  std::shared_ptr<const DeviceDescriptorList> QueryDeviceDescriptorList() const override {\n    std::vector<std::shared_ptr<const DeviceDescriptor>> devices;\n    int num_devices;\n    if (!ibv::IsAvailable()) { return std::make_shared<const BasicDeviceDescriptorList>(devices); }\n    ibv_device** device_list = ibv::wrapper.ibv_get_device_list(&num_devices);\n    if (device_list == nullptr) {\n      return std::make_shared<const BasicDeviceDescriptorList>(devices);\n    }\n    for (int i = 0; i < num_devices; ++i) {\n      ibv_device* device = device_list[i];\n      ibv_context* context = ibv::wrapper.ibv_open_device(device);\n      if (context == nullptr) { continue; }\n      ibv_device_attr device_attr{};\n      if (ibv::wrapper.ibv_query_device(context, &device_attr) != 0) {\n        CHECK_EQ(ibv::wrapper.ibv_close_device(context), 0);\n      }\n      for (int port = 1; port <= device_attr.phys_port_cnt; ++port) {\n        auto device_desc =\n            NetIBDeviceDescriptor::Query(static_cast<int32_t>(devices.size()), context, port);\n        if (device_desc) { devices.emplace_back(device_desc); }\n      }\n    }\n    ibv::wrapper.ibv_free_device_list(device_list);\n    return std::make_shared<const BasicDeviceDescriptorList>(devices);\n  }\n\n  std::string Name() const override { return kNetIBDeviceDescriptorClassName; }\n\n  void SerializeDeviceDescriptorList(const std::shared_ptr<const DeviceDescriptorList>& list,\n                                     std::string* serialized) const override {\n    std::vector<std::string> serialized_devices;\n    serialized_devices.reserve(list->DeviceCount());\n    for (size_t i = 0; i < list->DeviceCount(); ++i) {\n      auto ib_device = std::dynamic_pointer_cast<const NetIBDeviceDescriptor>(list->GetDevice(i));\n      CHECK(ib_device);\n      std::string serialized_device;\n      ib_device->Serialize(&serialized_device);\n      serialized_devices.emplace_back(std::move(serialized_device));\n    }\n    nlohmann::json json_object;\n    json_object[kJsonKeyDevices] = serialized_devices;\n    *serialized = json_object.dump();\n  }\n\n  std::shared_ptr<const DeviceDescriptorList> DeserializeDeviceDescriptorList(\n      const std::string& 
serialized) const override {\n    auto json_object = nlohmann::json::parse(serialized);\n    std::vector<std::string> serialized_devices = json_object[kJsonKeyDevices];\n    std::vector<std::shared_ptr<const DeviceDescriptor>> devices(serialized_devices.size());\n    for (int i = 0; i < serialized_devices.size(); ++i) {\n      devices.at(i) = NetIBDeviceDescriptor::Deserialize(serialized_devices.at(i));\n    }\n    return std::make_shared<const BasicDeviceDescriptorList>(devices);\n  }\n\n  void DumpDeviceDescriptorListSummary(const std::shared_ptr<const DeviceDescriptorList>& list,\n                                       const std::string& path) const override {\n    for (size_t i = 0; i < list->DeviceCount(); ++i) {\n      auto ib_device = std::dynamic_pointer_cast<const NetIBDeviceDescriptor>(list->GetDevice(i));\n      CHECK(ib_device);\n      auto stream = TeePersistentLogStream::Create(JoinPath(path, std::to_string(i) + \".json\"));\n      std::string serialized;\n      ib_device->Serialize(&serialized);\n      stream << serialized;\n    }\n  }\n};\n\nCOMMAND(DeviceDescriptorClass::RegisterClass(std::make_shared<NetIBDeviceDescriptorClass>()));\n\n}  // namespace hardware\n\n}  // namespace oneflow\n\n#endif  // WITH_RDMA\n"
  },
  {
    "path": "oneflow/core/hardware/net_socket_device_descriptor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifdef __linux__\n\n#include \"oneflow/core/hardware/net_socket_device_descriptor.h\"\n#include \"nlohmann/json.hpp\"\n\nnamespace oneflow {\n\nnamespace hardware {\n\nnamespace {\n\nconstexpr char kJsonKeyOrdinal[] = \"ordinal\";\nconstexpr char kJsonKeyName[] = \"name\";\nconstexpr char kJsonKeyAddress[] = \"address\";\nconstexpr char kJsonKeyPCIBusID[] = \"pci_bus_id\";\n\nvoid GetPCIBusID(const std::string& name, std::string* pci_bus_id) {\n#ifdef __linux__\n  const std::string device_path = \"/sys/class/net/\" + name + \"/device\";\n  char* device_real_path = realpath(device_path.data(), nullptr);\n  if (device_real_path == nullptr) { return; }\n  const std::string device_real_path_str = device_real_path;\n  free(device_real_path);  // NOLINT\n  const size_t pos = device_real_path_str.rfind('/');\n  if (pos == std::string::npos) { return; }\n  *pci_bus_id = device_real_path_str.substr(pos + 1);\n#endif\n}\n\n}  // namespace\n\nstruct NetSocketDeviceDescriptor::Impl {\n  int32_t ordinal{};\n  std::string name;\n  std::string address;\n  std::string pci_bus_id;\n};\n\nNetSocketDeviceDescriptor::NetSocketDeviceDescriptor() { impl_.reset(new Impl()); }\n\nNetSocketDeviceDescriptor::~NetSocketDeviceDescriptor() = default;\n\nint32_t NetSocketDeviceDescriptor::Ordinal() const { return impl_->ordinal; }\n\nconst std::string& NetSocketDeviceDescriptor::Name() const { return impl_->name; }\n\nconst std::string& NetSocketDeviceDescriptor::Address() const { return impl_->address; }\n\nconst std::string& NetSocketDeviceDescriptor::PCIBusID() const { return impl_->pci_bus_id; }\n\nvoid NetSocketDeviceDescriptor::Serialize(std::string* serialized) const {\n  nlohmann::json json_object;\n  json_object[kJsonKeyOrdinal] = impl_->ordinal;\n  json_object[kJsonKeyName] = impl_->name;\n  json_object[kJsonKeyAddress] = impl_->address;\n  json_object[kJsonKeyPCIBusID] = impl_->pci_bus_id;\n  *serialized = json_object.dump(2);\n}\n\nstd::shared_ptr<const NetSocketDeviceDescriptor> NetSocketDeviceDescriptor::Query(\n    int32_t ordinal, const std::string& name, const std::string& address) {\n  auto* desc = new NetSocketDeviceDescriptor();\n  desc->impl_->ordinal = ordinal;\n  desc->impl_->name = name;\n  desc->impl_->address = address;\n  GetPCIBusID(name, &desc->impl_->pci_bus_id);\n  return std::shared_ptr<const NetSocketDeviceDescriptor>(desc);\n}\n\nstd::shared_ptr<const NetSocketDeviceDescriptor> NetSocketDeviceDescriptor::Deserialize(\n    const std::string& serialized) {\n  auto json_object = nlohmann::json::parse(serialized);\n  auto* desc = new NetSocketDeviceDescriptor();\n  desc->impl_->ordinal = json_object[kJsonKeyOrdinal];\n  desc->impl_->name = json_object[kJsonKeyName];\n  desc->impl_->address = json_object[kJsonKeyAddress];\n  desc->impl_->pci_bus_id = json_object[kJsonKeyPCIBusID];\n  return std::shared_ptr<const NetSocketDeviceDescriptor>(desc);\n}\n\n}  // namespace 
hardware\n\n}  // namespace oneflow\n\n#endif  // __linux__\n"
  },
  {
    "path": "oneflow/core/hardware/net_socket_device_descriptor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_HARDWARE_NET_SOCKET_DEVICE_DESCRIPTOR_H_\n#define ONEFLOW_CORE_HARDWARE_NET_SOCKET_DEVICE_DESCRIPTOR_H_\n\n#ifdef __linux__\n\n#include \"oneflow/core/hardware/device_descriptor.h\"\n#include \"oneflow/core/common/util.h\"\n#include <string>\n#include <memory>\n#include <ifaddrs.h>\n\nnamespace oneflow {\n\nnamespace hardware {\n\nconstexpr char kNetSocketDeviceDescriptorClassName[] = \"net_socket\";\n\nclass NetSocketDeviceDescriptor : public DeviceDescriptor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NetSocketDeviceDescriptor);\n  ~NetSocketDeviceDescriptor() override;\n\n  int32_t Ordinal() const;\n  const std::string& Name() const;\n  const std::string& Address() const;\n  const std::string& PCIBusID() const;\n  void Serialize(std::string* serialized) const;\n  static std::shared_ptr<const NetSocketDeviceDescriptor> Query(int32_t ordinal,\n                                                                const std::string& name,\n                                                                const std::string& address);\n  static std::shared_ptr<const NetSocketDeviceDescriptor> Deserialize(\n      const std::string& serialized);\n\n private:\n  NetSocketDeviceDescriptor();\n\n  struct Impl;\n  std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace hardware\n\n}  // namespace oneflow\n\n#endif  // __linux__\n\n#endif  // ONEFLOW_CORE_HARDWARE_NET_SOCKET_DEVICE_DESCRIPTOR_H_\n"
  },
  {
    "path": "oneflow/core/hardware/net_socket_device_descriptor_class.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifdef __linux__\n\n#include \"oneflow/core/hardware/device_descriptor_class.h\"\n#include \"oneflow/core/hardware/net_socket_device_descriptor.h\"\n#include \"oneflow/core/hardware/basic_device_descriptor_list.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"nlohmann/json.hpp\"\n#include <ifaddrs.h>\n#include <sys/socket.h>\n#include <netinet/in.h>\n#include <netdb.h>\n\nnamespace oneflow {\n\nnamespace hardware {\n\nnamespace {\n\nconstexpr char kJsonKeyDevices[] = \"devices\";\n\n}  // namespace\n\nclass NetSocketDeviceDescriptorClass : public DeviceDescriptorClass {\n public:\n  NetSocketDeviceDescriptorClass() = default;\n  ~NetSocketDeviceDescriptorClass() override = default;\n\n  std::shared_ptr<const DeviceDescriptorList> QueryDeviceDescriptorList() const override {\n    std::vector<std::shared_ptr<const NetSocketDeviceDescriptor>> devices;\n    ifaddrs* interfaces = nullptr;\n    if (getifaddrs(&interfaces) != 0) {\n      return std::make_shared<const BasicDeviceDescriptorList>();\n    }\n    ifaddrs* ifa = nullptr;\n    for (ifa = interfaces; ifa != nullptr; ifa = ifa->ifa_next) {\n      if (ifa->ifa_addr == nullptr) { continue; }\n      const std::string name(ifa->ifa_name);\n      if (name == \"lo\") { continue; }\n      // TODO(liujuncheng): support ipv6\n      if (ifa->ifa_addr->sa_family != AF_INET) { continue; }\n      if (std::count_if(devices.cbegin(), devices.cend(),\n                        [&](const std::shared_ptr<const NetSocketDeviceDescriptor>& device) {\n                          return device->Name() == name;\n                        })\n          != 0) {\n        continue;\n      }\n      char host[NI_MAXHOST];\n      const socklen_t sa_len = (ifa->ifa_addr->sa_family == AF_INET) ? 
sizeof(struct sockaddr_in)\n                                                                     : sizeof(struct sockaddr_in6);\n      if (getnameinfo(ifa->ifa_addr, sa_len, host, NI_MAXHOST, nullptr, 0, NI_NUMERICHOST) != 0) {\n        continue;\n      }\n      auto socket_device =\n          NetSocketDeviceDescriptor::Query(static_cast<int32_t>(devices.size()), name, host);\n      if (socket_device) { devices.emplace_back(socket_device); }\n    }\n    freeifaddrs(interfaces);\n    return std::make_shared<const BasicDeviceDescriptorList>(\n        std::vector<std::shared_ptr<const DeviceDescriptor>>{devices.begin(), devices.end()});\n  }\n\n  std::string Name() const override { return kNetSocketDeviceDescriptorClassName; }\n\n  void SerializeDeviceDescriptorList(const std::shared_ptr<const DeviceDescriptorList>& list,\n                                     std::string* serialized) const override {\n    std::vector<std::string> serialized_devices;\n    serialized_devices.reserve(list->DeviceCount());\n    for (size_t i = 0; i < list->DeviceCount(); ++i) {\n      auto socket_device =\n          std::dynamic_pointer_cast<const NetSocketDeviceDescriptor>(list->GetDevice(i));\n      CHECK(socket_device);\n      std::string serialized_device;\n      socket_device->Serialize(&serialized_device);\n      serialized_devices.emplace_back(std::move(serialized_device));\n    }\n    nlohmann::json json_object;\n    json_object[kJsonKeyDevices] = serialized_devices;\n    *serialized = json_object.dump();\n  }\n\n  std::shared_ptr<const DeviceDescriptorList> DeserializeDeviceDescriptorList(\n      const std::string& serialized) const override {\n    auto json_object = nlohmann::json::parse(serialized);\n    std::vector<std::string> serialized_devices = json_object[kJsonKeyDevices];\n    std::vector<std::shared_ptr<const DeviceDescriptor>> devices(serialized_devices.size());\n    for (int i = 0; i < serialized_devices.size(); ++i) {\n      devices.at(i) = NetSocketDeviceDescriptor::Deserialize(serialized_devices.at(i));\n    }\n    return std::make_shared<const BasicDeviceDescriptorList>(devices);\n  }\n\n  void DumpDeviceDescriptorListSummary(const std::shared_ptr<const DeviceDescriptorList>& list,\n                                       const std::string& path) const override {\n    for (size_t i = 0; i < list->DeviceCount(); ++i) {\n      auto socket_device =\n          std::dynamic_pointer_cast<const NetSocketDeviceDescriptor>(list->GetDevice(i));\n      CHECK(socket_device);\n      auto stream = TeePersistentLogStream::Create(JoinPath(path, std::to_string(i) + \".json\"));\n      std::string serialized;\n      socket_device->Serialize(&serialized);\n      stream << serialized;\n    }\n  }\n};\n\nCOMMAND(DeviceDescriptorClass::RegisterClass(std::make_shared<NetSocketDeviceDescriptorClass>()));\n\n}  // namespace hardware\n\n}  // namespace oneflow\n\n#endif  // __linux__\n"
  },
  {
    "path": "oneflow/core/hardware/node_device_descriptor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/hardware/node_device_descriptor.h\"\n#include \"oneflow/core/hardware/device_descriptor_class.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include \"nlohmann/json.hpp\"\n#ifdef WITH_HWLOC\n#include <hwloc.h>\n#endif  // WITH_HWLOC\n\nnamespace oneflow {\n\nnamespace hardware {\n\nnamespace {\n\nconstexpr char kJsonKeyClasses[] = \"classes\";\nconstexpr char kJsonKeyClassName[] = \"class_name\";\nconstexpr char kJsonKeySerializedDescriptorList[] = \"serialized_descriptor_list\";\nconstexpr char kJsonKeyHostMemorySize[] = \"host_memory_size_bytes\";\nconstexpr char kJsonKeyTopology[] = \"topology\";\n\nclass DummyCPUAffinityDescriptor : public TopologyCPUAffinityDescriptor {\n public:\n  DummyCPUAffinityDescriptor() = default;\n  ~DummyCPUAffinityDescriptor() override = default;\n};\n\nclass DummyMemoryAffinityDescriptor : public TopologyMemoryAffinityDescriptor {\n public:\n  DummyMemoryAffinityDescriptor() = default;\n  ~DummyMemoryAffinityDescriptor() override = default;\n};\n\nclass DummyTopologyDescriptor : public TopologyDescriptor {\n public:\n  DummyTopologyDescriptor() = default;\n  ~DummyTopologyDescriptor() override = default;\n\n  std::shared_ptr<const TopologyCPUAffinityDescriptor> GetCPUAffinity() const override {\n    return std::make_shared<const DummyCPUAffinityDescriptor>();\n  }\n\n  std::shared_ptr<const TopologyMemoryAffinityDescriptor> GetMemoryAffinity() const override {\n    return std::make_shared<const DummyMemoryAffinityDescriptor>();\n  }\n\n  std::shared_ptr<const TopologyCPUAffinityDescriptor> GetCPUAffinityByPCIBusID(\n      const std::string& bus_id) const override {\n    return std::make_shared<const DummyCPUAffinityDescriptor>();\n  }\n\n  std::shared_ptr<const TopologyMemoryAffinityDescriptor> GetMemoryAffinityByPCIBusID(\n      const std::string& bus_id) const override {\n    return std::make_shared<const DummyMemoryAffinityDescriptor>();\n  }\n\n  void SetCPUAffinity(\n      const std::shared_ptr<const TopologyCPUAffinityDescriptor>& affinity) const override {}\n\n  void SetMemoryAffinity(\n      const std::shared_ptr<const TopologyMemoryAffinityDescriptor>& affinity) const override {}\n};\n\n#ifdef WITH_HWLOC\n\nclass HWLocCPUAffinityDescriptor : public TopologyCPUAffinityDescriptor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(HWLocCPUAffinityDescriptor);\n  explicit HWLocCPUAffinityDescriptor(hwloc_cpuset_t hwloc_cpu_set)\n      : hwloc_cpu_set_(hwloc_cpu_set) {}\n  ~HWLocCPUAffinityDescriptor() override { hwloc_bitmap_free(hwloc_cpu_set_); }\n\n  hwloc_cpuset_t HWLocCPUSet() const { return hwloc_cpu_set_; }\n\n private:\n  hwloc_cpuset_t hwloc_cpu_set_;\n};\n\nclass HWLocMemoryAffinityDescriptor : public TopologyMemoryAffinityDescriptor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(HWLocMemoryAffinityDescriptor);\n  explicit 
HWLocMemoryAffinityDescriptor(hwloc_bitmap_t hwloc_bitmap, hwloc_membind_policy_t policy)\n      : hwloc_bitmap_(hwloc_bitmap), policy_(policy) {}\n  ~HWLocMemoryAffinityDescriptor() override { hwloc_bitmap_free(hwloc_bitmap_); }\n\n  hwloc_bitmap_t HWLocBitmap() const { return hwloc_bitmap_; }\n  hwloc_membind_policy_t HWLocPolicy() const { return policy_; }\n\n private:\n  hwloc_bitmap_t hwloc_bitmap_;\n  hwloc_membind_policy_t policy_;\n};\n\nclass HWLocTopologyDescriptor : public TopologyDescriptor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(HWLocTopologyDescriptor);\n  ~HWLocTopologyDescriptor() override { hwloc_topology_destroy(topology_); }\n\n  std::shared_ptr<const TopologyCPUAffinityDescriptor> GetCPUAffinity() const override {\n    hwloc_bitmap_t set = hwloc_bitmap_alloc();\n    if (hwloc_get_cpubind(topology_, set, HWLOC_CPUBIND_THREAD) != 0) {\n      hwloc_bitmap_free(set);  // do not leak the bitmap on failure\n      return nullptr;\n    }\n    return std::make_shared<const HWLocCPUAffinityDescriptor>(set);\n  }\n\n  std::shared_ptr<const TopologyMemoryAffinityDescriptor> GetMemoryAffinity() const override {\n    hwloc_bitmap_t set = hwloc_bitmap_alloc();\n    hwloc_membind_policy_t policy{};\n    if (hwloc_get_membind(topology_, set, &policy, HWLOC_MEMBIND_THREAD) != 0) {\n      hwloc_bitmap_free(set);  // do not leak the bitmap on failure\n      return nullptr;\n    }\n    return std::make_shared<const HWLocMemoryAffinityDescriptor>(set, policy);\n  }\n\n  std::shared_ptr<const TopologyCPUAffinityDescriptor> GetCPUAffinityByPCIBusID(\n      const std::string& bus_id) const override {\n    if (bus_id.empty()) { return nullptr; }\n    hwloc_obj_t non_io_ancestor = GetNonIOAncestorByPCIBusID(bus_id);\n    if (non_io_ancestor == nullptr) { return nullptr; }\n    if (non_io_ancestor->cpuset == nullptr) { return nullptr; }\n    return std::make_shared<const HWLocCPUAffinityDescriptor>(\n        hwloc_bitmap_dup(non_io_ancestor->cpuset));\n  }\n\n  std::shared_ptr<const TopologyMemoryAffinityDescriptor> GetMemoryAffinityByPCIBusID(\n      const std::string& bus_id) const override {\n    if (bus_id.empty()) { return nullptr; }\n    hwloc_obj_t non_io_ancestor = GetNonIOAncestorByPCIBusID(bus_id);\n    if (non_io_ancestor == nullptr) { return nullptr; }\n    if (non_io_ancestor->cpuset == nullptr) { return nullptr; }\n    return std::make_shared<const HWLocMemoryAffinityDescriptor>(\n        hwloc_bitmap_dup(non_io_ancestor->cpuset), HWLOC_MEMBIND_BIND);\n  }\n\n  void SetCPUAffinity(\n      const std::shared_ptr<const TopologyCPUAffinityDescriptor>& affinity) const override {\n    auto hwloc_affinity = std::dynamic_pointer_cast<const HWLocCPUAffinityDescriptor>(affinity);\n    if (!hwloc_affinity) { return; }\n    hwloc_set_cpubind(topology_, hwloc_affinity->HWLocCPUSet(), HWLOC_CPUBIND_THREAD);\n  }\n\n  void SetMemoryAffinity(\n      const std::shared_ptr<const TopologyMemoryAffinityDescriptor>& affinity) const override {\n    auto hwloc_affinity = std::dynamic_pointer_cast<const HWLocMemoryAffinityDescriptor>(affinity);\n    if (!hwloc_affinity) { return; }\n    hwloc_set_membind(topology_, hwloc_affinity->HWLocBitmap(), hwloc_affinity->HWLocPolicy(),\n                      HWLOC_MEMBIND_THREAD);\n  }\n\n  static std::shared_ptr<const HWLocTopologyDescriptor> Query() {\n    hwloc_topology_t topology = nullptr;\n    do {\n      if (hwloc_topology_init(&topology) != 0) { break; }\n      if (hwloc_topology_set_io_types_filter(topology, HWLOC_TYPE_FILTER_KEEP_ALL) != 0) { break; }\n      if (hwloc_topology_load(topology) != 0) { break; }\n      auto* desc = new HWLocTopologyDescriptor(topology);\n      return std::shared_ptr<const HWLocTopologyDescriptor>(desc);\n    } while (false);\n    if (topology != nullptr) { hwloc_topology_destroy(topology); }\n    return nullptr;\n  }\n\n  static std::shared_ptr<const HWLocTopologyDescriptor> Deserialize(const std::string& serialized) {\n    hwloc_topology_t topology = nullptr;\n    do {\n      if (hwloc_topology_init(&topology) != 0) { break; }\n      if (hwloc_topology_set_xmlbuffer(topology, serialized.data(),\n                                       static_cast<int>(serialized.size()))\n          != 0) {\n        break;\n      }\n      if (hwloc_topology_load(topology) != 0) { break; }\n      auto* desc = new HWLocTopologyDescriptor(topology);\n      return std::shared_ptr<const HWLocTopologyDescriptor>(desc);\n    } while (false);\n    if (topology != nullptr) { hwloc_topology_destroy(topology); }\n    return nullptr;\n  }\n\n  void Serialize(std::string* serialized) const {\n    char* buffer = nullptr;\n    int len = 0;\n    if (hwloc_topology_export_xmlbuffer(topology_, &buffer, &len, 0) == 0) {\n      *serialized = buffer;\n      hwloc_free_xmlbuffer(topology_, buffer);\n    }\n  }\n\n private:\n  hwloc_obj_t GetNonIOAncestorByPCIBusID(const std::string& pci_bus_id) const {\n    hwloc_obj_t device = hwloc_get_pcidev_by_busidstring(topology_, pci_bus_id.data());\n    if (device == nullptr) { return nullptr; }\n    hwloc_obj_t non_io_ancestor = hwloc_get_non_io_ancestor_obj(topology_, device);\n    return non_io_ancestor;\n  }\n\n  explicit HWLocTopologyDescriptor(hwloc_topology_t topology) : topology_(topology) {}\n  hwloc_topology_t topology_{};\n};\n\n#endif  // WITH_HWLOC\n\nstd::shared_ptr<const TopologyDescriptor> QueryTopologyDescriptor() {\n  std::shared_ptr<const TopologyDescriptor> topology;\n#ifdef WITH_HWLOC\n  topology = HWLocTopologyDescriptor::Query();\n#endif  // WITH_HWLOC\n  if (!topology) { topology.reset(new DummyTopologyDescriptor()); }\n  return topology;\n}\n\nstd::shared_ptr<const TopologyDescriptor> DeserializeTopologyDescriptor(\n    const std::string& serialized) {\n  std::shared_ptr<const TopologyDescriptor> topology;\n  if (serialized.empty()) {\n    topology.reset(new DummyTopologyDescriptor());\n  } else {\n#ifdef WITH_HWLOC\n    topology = HWLocTopologyDescriptor::Deserialize(serialized);\n#else\n    UNIMPLEMENTED();\n#endif  // WITH_HWLOC\n  }\n  if (!topology) { topology.reset(new DummyTopologyDescriptor()); }\n  return topology;\n}\n\nvoid SerializeTopologyDescriptor(const std::shared_ptr<const TopologyDescriptor>& topology,\n                                 std::string* serialized) {\n#ifdef WITH_HWLOC\n  auto hwloc_topology = std::dynamic_pointer_cast<const HWLocTopologyDescriptor>(topology);\n  if (hwloc_topology) { hwloc_topology->Serialize(serialized); }\n#endif  // WITH_HWLOC\n}\n\n}  // namespace\n\nstruct NodeDeviceDescriptor::Impl {\n  std::unordered_map<std::string, std::shared_ptr<const DeviceDescriptorList>>\n      class_name2descriptor_list;\n  size_t host_memory_size_bytes{};\n  std::shared_ptr<const TopologyDescriptor> topology;\n};\n\nNodeDeviceDescriptor::NodeDeviceDescriptor() { impl_.reset(new Impl()); }\n\nNodeDeviceDescriptor::~NodeDeviceDescriptor() = default;\n\nbool NodeDeviceDescriptor::HasDeviceClass(const std::string& class_name) const {\n  return impl_->class_name2descriptor_list.find(class_name)\n         != impl_->class_name2descriptor_list.end();\n}\n\nstd::shared_ptr<const DeviceDescriptorList> NodeDeviceDescriptor::GetDeviceDescriptorList(\n    const std::string& class_name) const {\n  auto it = 
impl_->class_name2descriptor_list.find(class_name);\n  if (it != impl_->class_name2descriptor_list.end()) {\n    return it->second;\n  } else {\n    return nullptr;\n  }\n}\n\nstd::shared_ptr<const DeviceDescriptor> NodeDeviceDescriptor::GetDevice(\n    const std::string& class_name, size_t ordinal) const {\n  const auto device_list = GetDeviceDescriptorList(class_name);\n  if (device_list) {\n    return device_list->GetDevice(ordinal);\n  } else {\n    return nullptr;\n  }\n}\n\nsize_t NodeDeviceDescriptor::HostMemorySizeBytes() const { return impl_->host_memory_size_bytes; }\n\nstd::shared_ptr<const TopologyDescriptor> NodeDeviceDescriptor::Topology() const {\n  return impl_->topology;\n}\n\nvoid NodeDeviceDescriptor::Serialize(std::string* serialized) const {\n  nlohmann::json json_object;\n  json_object[kJsonKeyHostMemorySize] = impl_->host_memory_size_bytes;\n  for (const auto& pair : impl_->class_name2descriptor_list) {\n    std::string serialized_descriptor_list;\n    auto clz = DeviceDescriptorClass::GetRegisteredClass(pair.first);\n    CHECK(clz);\n    clz->SerializeDeviceDescriptorList(pair.second, &serialized_descriptor_list);\n    json_object[kJsonKeyClasses].push_back(\n        {{kJsonKeyClassName, clz->Name()},\n         {kJsonKeySerializedDescriptorList, serialized_descriptor_list}});\n  }\n  std::string serialized_topology;\n  SerializeTopologyDescriptor(impl_->topology, &serialized_topology);\n  json_object[kJsonKeyTopology] = serialized_topology;\n  *serialized = json_object.dump();\n}\n\nvoid NodeDeviceDescriptor::DumpSummary(const std::string& path) const {\n  std::string classes_base = JoinPath(path, \"classes\");\n  for (const auto& pair : impl_->class_name2descriptor_list) {\n    auto clz = DeviceDescriptorClass::GetRegisteredClass(pair.first);\n    CHECK(clz);\n    clz->DumpDeviceDescriptorListSummary(pair.second, JoinPath(classes_base, pair.first));\n  }\n  std::string serialized_topology;\n  SerializeTopologyDescriptor(impl_->topology, &serialized_topology);\n  if (!serialized_topology.empty()) {\n    TeePersistentLogStream::Create(JoinPath(path, \"topology\"))->Write(serialized_topology);\n  }\n}\n\nstd::shared_ptr<const NodeDeviceDescriptor> NodeDeviceDescriptor::Query() {\n  auto* desc = new NodeDeviceDescriptor();\n  desc->impl_->host_memory_size_bytes = GetAvailableCpuMemSize();\n  const size_t num_classes = DeviceDescriptorClass::GetRegisteredClassesCount();\n  for (size_t i = 0; i < num_classes; ++i) {\n    std::shared_ptr<const DeviceDescriptorClass> descriptor_class =\n        DeviceDescriptorClass::GetRegisteredClass(i);\n    desc->impl_->class_name2descriptor_list.emplace(descriptor_class->Name(),\n                                                    descriptor_class->QueryDeviceDescriptorList());\n  }\n  desc->impl_->topology = QueryTopologyDescriptor();\n  return std::shared_ptr<const NodeDeviceDescriptor>(desc);\n}\n\nstd::shared_ptr<const NodeDeviceDescriptor> NodeDeviceDescriptor::Deserialize(\n    const std::string& serialized) {\n  auto json_object = nlohmann::json::parse(serialized);\n  auto* desc = new NodeDeviceDescriptor();\n  desc->impl_->host_memory_size_bytes = json_object[kJsonKeyHostMemorySize];\n  auto num_classes = json_object[kJsonKeyClasses].size();\n  for (int i = 0; i < num_classes; ++i) {\n    const std::string class_name = json_object[kJsonKeyClasses].at(i)[kJsonKeyClassName];\n    const std::string serialized_descriptor_list =\n        json_object[kJsonKeyClasses].at(i)[kJsonKeySerializedDescriptorList];\n    auto clz = 
DeviceDescriptorClass::GetRegisteredClass(class_name);\n    CHECK(clz);\n    const auto descriptor_list = clz->DeserializeDeviceDescriptorList(serialized_descriptor_list);\n    desc->impl_->class_name2descriptor_list.emplace(class_name, descriptor_list);\n  }\n  desc->impl_->topology = DeserializeTopologyDescriptor(json_object[kJsonKeyTopology]);\n  return std::shared_ptr<const NodeDeviceDescriptor>(desc);\n}\n\n}  // namespace hardware\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/hardware/node_device_descriptor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_HARDWARE_NODE_DEVICE_DESCRIPTOR_H_\n#define ONEFLOW_CORE_HARDWARE_NODE_DEVICE_DESCRIPTOR_H_\n\n#include \"oneflow/core/hardware/device_descriptor_list.h\"\n#include \"oneflow/core/hardware/topology_descriptor.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace hardware {\n\nclass NodeDeviceDescriptor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NodeDeviceDescriptor);\n  ~NodeDeviceDescriptor();\n\n  bool HasDeviceClass(const std::string& class_name) const;\n  std::shared_ptr<const DeviceDescriptorList> GetDeviceDescriptorList(\n      const std::string& class_name) const;\n  std::shared_ptr<const DeviceDescriptor> GetDevice(const std::string& class_name,\n                                                    size_t ordinal) const;\n  size_t HostMemorySizeBytes() const;\n  std::shared_ptr<const TopologyDescriptor> Topology() const;\n  void Serialize(std::string* serialized) const;\n  void DumpSummary(const std::string& path) const;\n\n  static std::shared_ptr<const NodeDeviceDescriptor> Query();\n  static std::shared_ptr<const NodeDeviceDescriptor> Deserialize(const std::string& serialized);\n\n private:\n  NodeDeviceDescriptor();\n\n  struct Impl;\n  std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace hardware\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_HARDWARE_NODE_DEVICE_DESCRIPTOR_H_\n"
  },
  {
    "path": "oneflow/core/hardware/node_device_descriptor_manager.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/hardware/node_device_descriptor_manager.h\"\n#include \"oneflow/core/control/ctrl_client.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\n\nnamespace hardware {\n\nnamespace {\n\nstd::string MakeNodeDeviceDescriptorRpcKey(const int64_t rank) {\n  return \"NodeDeviceDescriptorRpcKey/\" + std::to_string(rank);\n}\n\n}  // namespace\n\nstruct NodeDeviceDescriptorManager::Impl {\n  Impl(int64_t rank, int64_t num_ranks) : rank(rank) { nodes.resize(num_ranks); }\n  std::vector<std::shared_ptr<const NodeDeviceDescriptor>> nodes;\n  int64_t rank;\n};\n\nNodeDeviceDescriptorManager::NodeDeviceDescriptorManager() {\n  impl_.reset(new Impl(GlobalProcessCtx::Rank(), GlobalProcessCtx::WorldSize()));\n  std::shared_ptr<const NodeDeviceDescriptor> local = NodeDeviceDescriptor::Query();\n  impl_->nodes.at(impl_->rank) = local;\n  if (impl_->nodes.size() > 1) {\n    std::string serialized_local_node;\n    local->Serialize(&serialized_local_node);\n    Singleton<CtrlClient>::Get()->PushKV(MakeNodeDeviceDescriptorRpcKey(impl_->rank),\n                                         serialized_local_node);\n    for (int64_t i = 0; i < impl_->nodes.size(); ++i) {\n      if (i == impl_->rank) { continue; }\n      Singleton<CtrlClient>::Get()->PullKV(\n          MakeNodeDeviceDescriptorRpcKey(i), [&](const std::string& serialized) {\n            impl_->nodes.at(i) = NodeDeviceDescriptor::Deserialize(serialized);\n          });\n    }\n  }\n}\n\nNodeDeviceDescriptorManager::~NodeDeviceDescriptorManager() = default;\n\nstd::shared_ptr<const NodeDeviceDescriptor> NodeDeviceDescriptorManager::GetNodeDeviceDescriptor(\n    int64_t rank) const {\n  CHECK_LT(rank, impl_->nodes.size());\n  return impl_->nodes.at(rank);\n}\n\nstd::shared_ptr<const NodeDeviceDescriptor>\nNodeDeviceDescriptorManager::GetLocalNodeDeviceDescriptor() const {\n  return impl_->nodes.at(impl_->rank);\n}\n\nvoid NodeDeviceDescriptorManager::DumpSummary(const std::string& base) const {\n  for (int64_t i = 0; i < impl_->nodes.size(); ++i) {\n    impl_->nodes.at(i)->DumpSummary(JoinPath(base, \"nodes\", std::to_string(i)));\n  }\n}\n\n}  // namespace hardware\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/hardware/node_device_descriptor_manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_HARDWARE_NODE_DEVICE_DESCRIPTOR_MANAGER_H_\n#define ONEFLOW_CORE_HARDWARE_NODE_DEVICE_DESCRIPTOR_MANAGER_H_\n\n#include \"oneflow/core/hardware/node_device_descriptor.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace hardware {\n\nclass NodeDeviceDescriptorManager {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NodeDeviceDescriptorManager);\n  NodeDeviceDescriptorManager();\n  ~NodeDeviceDescriptorManager();\n\n  std::shared_ptr<const NodeDeviceDescriptor> GetNodeDeviceDescriptor(int64_t rank) const;\n  std::shared_ptr<const NodeDeviceDescriptor> GetLocalNodeDeviceDescriptor() const;\n\n  void DumpSummary(const std::string& path) const;\n\n private:\n  struct Impl;\n  std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace hardware\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_HARDWARE_NODE_DEVICE_DESCRIPTOR_MANAGER_H_\n"
  },
  {
    "path": "oneflow/core/hardware/topology_descriptor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/hardware/topology_descriptor.h\"\n\nnamespace oneflow {\n\nnamespace hardware {\n\nvoid TopologyDescriptor::SetCPUAffinityByPCIBusID(const std::string& bus_id) const {\n  SetCPUAffinity(GetCPUAffinityByPCIBusID(bus_id));\n}\n\nvoid TopologyDescriptor::SetMemoryAffinityByPCIBusID(const std::string& bus_id) const {\n  SetMemoryAffinity(GetMemoryAffinityByPCIBusID(bus_id));\n}\n\n}  // namespace hardware\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/hardware/topology_descriptor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_HARDWARE_TOPOLOGY_DESCRIPTOR_H_\n#define ONEFLOW_CORE_HARDWARE_TOPOLOGY_DESCRIPTOR_H_\n\n#include <string>\n#include <memory>\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace hardware {\n\nclass TopologyCPUAffinityDescriptor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TopologyCPUAffinityDescriptor);\n  TopologyCPUAffinityDescriptor() = default;\n  virtual ~TopologyCPUAffinityDescriptor() = default;\n};\n\nclass TopologyMemoryAffinityDescriptor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TopologyMemoryAffinityDescriptor);\n  TopologyMemoryAffinityDescriptor() = default;\n  virtual ~TopologyMemoryAffinityDescriptor() = default;\n};\n\nclass TopologyDescriptor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TopologyDescriptor);\n  TopologyDescriptor() = default;\n  virtual ~TopologyDescriptor() = default;\n\n  virtual std::shared_ptr<const TopologyCPUAffinityDescriptor> GetCPUAffinity() const = 0;\n  virtual std::shared_ptr<const TopologyMemoryAffinityDescriptor> GetMemoryAffinity() const = 0;\n  virtual std::shared_ptr<const TopologyCPUAffinityDescriptor> GetCPUAffinityByPCIBusID(\n      const std::string& bus_id) const = 0;\n  virtual std::shared_ptr<const TopologyMemoryAffinityDescriptor> GetMemoryAffinityByPCIBusID(\n      const std::string& bus_id) const = 0;\n  virtual void SetCPUAffinity(\n      const std::shared_ptr<const TopologyCPUAffinityDescriptor>& affinity) const = 0;\n  virtual void SetMemoryAffinity(\n      const std::shared_ptr<const TopologyMemoryAffinityDescriptor>& affinity) const = 0;\n  virtual void SetCPUAffinityByPCIBusID(const std::string& bus_id) const;\n  virtual void SetMemoryAffinityByPCIBusID(const std::string& bus_id) const;\n};\n\n}  // namespace hardware\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_HARDWARE_TOPOLOGY_DESCRIPTOR_H_\n"
  },
  {
    "path": "oneflow/core/intrusive/README.md",
    "content": "### 概念与数据结构\n本子系统可以方便用户定义可侵入式类型。内建支持侵入式智能指针`intrusive::shared_ptr`和侵入式容器。\n目前有主要有两类侵入式容器：\n1. `intrusive::List`，双链表。基于此，还提供了`intrusive::MutexedList`和`intrusive::Channel`。\n2. `intrusive::SkipList`，跳表，等同于map。\n\n为了管理元素CURD所带来的生命周期，侵入式容器需要`intrusive::shared_ptr`来实现内存生命周期的管理，它与`std::shared_ptr`的不同在于其引用计数嵌入在目标结构体里。\n### 接口\n需要使用`intrusive::shared_ptr`来管理生命周期的类必须拥有`intrusive::Ref* mut_intrusive_ref();`方法\n\n由于侵入式容器支持比标准容器更为强大的迭代方式，同时为了性能起见，我们提供三类迭代宏：\n1. `INTRUSIVE_FOR_EACH`，支持迭代过程中删除当前元素，同时使用`intrusive::shared_ptr`管理好当前元素生命周期\n2. `INTRUSIVE_FOR_EACH_PTR`，支持迭代过程中删除当前元素，类型直接为裸指针，即不负责当前元素生命周期的管理\n3. `INTRUSIVE_UNSAFE_FOR_EACH_PTR`，不支持迭代中删除元素，不负责当前元素生命周期的管理。\n\n### 特点\n本组件与boost::intrusive最大不同在于实现了完整的生命周期管理，另外提供了其他更能减少内存分配的容器定义方式（详见intrusive::HeadFreeList）。\n"
  },
  {
    "path": "oneflow/core/intrusive/base.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_BASE_H_\n#define ONEFLOW_CORE_INTRUSIVE_BASE_H_\n\nnamespace oneflow {\nnamespace intrusive {\n\nclass Base {\n public:\n  void __Init__() {}\n  void __Delete__() {}\n};\n\n}  // namespace intrusive\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_BASE_H_\n"
  },
  {
    "path": "oneflow/core/intrusive/cpp_attribute.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_COMMON_INTRUSIVE_ATTRIBUTE_H_\n#define ONEFLOW_CORE_COMMON_INTRUSIVE_ATTRIBUTE_H_\n\n#define INTRUSIVE_PREDICT_TRUE GOOGLE_PREDICT_TRUE\n#define INTRUSIVE_PREDICT_FALSE GOOGLE_PREDICT_FALSE\n\n#endif  // ONEFLOW_CORE_COMMON_INTRUSIVE_ATTRIBUTE_H_\n"
  },
  {
    "path": "oneflow/core/intrusive/dss.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_DSS_H_\n#define ONEFLOW_CORE_INTRUSIVE_DSS_H_\n\n#include <cstddef>\n#include <typeinfo>\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/intrusive/struct_traits.h\"\n\nnamespace oneflow {\n\n// DSS is short for domain specific struct\n#define DSS_BEGIN(field_counter, type) _DSS_BEGIN(field_counter, type)\n#define DSS_DEFINE_FIELD(field_counter, dss_type, field_type, field_name) \\\n  _DSS_DEFINE_FIELD(field_counter, dss_type, field_type, field_name)\n#define DSS_END(field_counter, dss_type, type) _DSS_END(field_counter, dss_type, type)\n#define DSS_DEFINE_UNION_FIELD_VISITOR(field_counter, field_case, type7field7case_tuple_seq) \\\n  _DSS_DEFINE_UNION_FIELD_VISITOR(field_counter, field_case, type7field7case_tuple_seq)\n#define DSS_GET_FIELD_COUNTER() __COUNTER__\n\n// details\n\n#define _DSS_DEFINE_UNION_FIELD_VISITOR(field_counter, field_case, type7field7case_tuple_seq)   \\\n private:                                                                                       \\\n  template<template<int, class, class> class F, typename WalkCtxType, typename DssFieldType,    \\\n           typename Enabled>                                                                    \\\n  struct __DssVisitField__<field_counter, F, WalkCtxType, DssFieldType, Enabled> {              \\\n    template<typename __DssFieldType>                                                           \\\n    using PartialF = F<field_counter, WalkCtxType, __DssFieldType>;                             \\\n    static void Call(WalkCtxType* ctx, DssFieldType* field_ptr) {                               \\\n      switch (field_ptr->field_case) {                                                          \\\n        OF_PP_FOR_EACH_TUPLE(_DSS_MAKE_UNION_FIELD_VISITOR_HOOK, type7field7case_tuple_seq)     \\\n        default:;                                                                               \\\n      }                                                                                         \\\n    }                                                                                           \\\n  };                                                                                            \\\n  template<template<int, class, class> class F, typename WalkCtxType, typename DssFieldType,    \\\n           typename Enabled>                                                                    \\\n  struct __DssVisitVerboseField__<field_counter, F, WalkCtxType, DssFieldType, Enabled> {       \\\n    template<typename __DssFieldType>                                                           \\\n    using PartialF = F<field_counter, WalkCtxType, __DssFieldType>;                             \\\n    static void Call(WalkCtxType* ctx, DssFieldType* field_ptr, const char* __field_name__) {   \\\n      switch (field_ptr->field_case) {                                               
           \\\n        OF_PP_FOR_EACH_TUPLE(_DSS_MAKE_UNION_FIELD_VISITOR_HOOK_VERBOSE,                        \\\n                             type7field7case_tuple_seq)                                         \\\n        default:;                                                                               \\\n      }                                                                                         \\\n    }                                                                                           \\\n  };                                                                                            \\\n  template<template<class, int, class, class, bool> class F, typename WalkCtxType,              \\\n           typename DssFieldType, typename Enabled>                                             \\\n  struct __DssVisitStaticVerboseField__<field_counter, F, WalkCtxType, DssFieldType, Enabled> { \\\n    template<typename __DssFieldType>                                                           \\\n    using PartialF = F<__DssSelfType__, field_counter, WalkCtxType, __DssFieldType, true>;      \\\n    static void Call(WalkCtxType* ctx, const char* __oneof_name__) {                            \\\n      OF_PP_FOR_EACH_TUPLE(_DSS_MAKE_UNION_FIELD_VISITOR_HOOK_STATIC_VERBOSE,                   \\\n                           type7field7case_tuple_seq)                                           \\\n    }                                                                                           \\\n  };                                                                                            \\\n  template<template<int, class, class> class F, typename WalkCtxType, typename DssFieldType,    \\\n           typename Enabled>                                                                    \\\n  struct __DssVisitFieldUntil__<field_counter, F, WalkCtxType, DssFieldType, Enabled> {         \\\n    template<typename __DssFieldType>                                                           \\\n    using PartialF = F<field_counter, WalkCtxType, __DssFieldType>;                             \\\n    static bool Call(WalkCtxType* ctx, DssFieldType* field_ptr) {                               \\\n      switch (field_ptr->field_case) {                                                          \\\n        OF_PP_FOR_EACH_TUPLE(_DSS_MAKE_UNION_FIELD_VISITOR_HOOK, type7field7case_tuple_seq)     \\\n        default: return false;                                                                  \\\n      }                                                                                         \\\n    }                                                                                           \\\n  };\n\n#define _DSS_MAKE_UNION_FIELD_VISITOR_HOOK(field_type, field_name, field_case_value) \\\n  case field_case_value: {                                                           \\\n    return PartialF<field_type>::Call(ctx, &field_ptr->field_name);                  \\\n  }\n\n#define _DSS_MAKE_UNION_FIELD_VISITOR_HOOK_VERBOSE(field_type, field_name, field_case_value) \\\n  case field_case_value: {                                                                   \\\n    const char* case_field_name = OF_PP_STRINGIZE(field_name);                               \\\n    return PartialF<field_type>::Call(ctx, &field_ptr->field_name, case_field_name);         \\\n  }\n\n#define _DSS_MAKE_UNION_FIELD_VISITOR_HOOK_STATIC_VERBOSE(field_type, field_name, \\\n                                                          field_case_value)  
     \\\n  {                                                                               \\\n    const char* case_field_name = OF_PP_STRINGIZE(field_name);                    \\\n    PartialF<field_type>::Call(ctx, case_field_name, __oneof_name__);             \\\n  }\n\n#define _DSS_BEGIN(field_counter, type)                                                       \\\n private:                                                                                     \\\n  using __DssSelfType__ = type;                                                               \\\n                                                                                              \\\n public:                                                                                      \\\n  template<int tpl_fld_counter, typename Enabled = void>                                      \\\n  struct __DssFieldType__;                                                                    \\\n  template<template<int, class, class> class F, typename WalkCtxType>                         \\\n  void __WalkField__(WalkCtxType* ctx) {                                                      \\\n    __DssFieldIter__<field_counter, F, WalkCtxType>::Call(ctx, this);                         \\\n  }                                                                                           \\\n  template<template<int, class, class> class F, typename WalkCtxType>                         \\\n  void __WalkVerboseField__(WalkCtxType* ctx) {                                               \\\n    __DssVerboseFieldIter__<field_counter, F, WalkCtxType>::Call(ctx, this);                  \\\n  }                                                                                           \\\n  template<template<class, int, class, class, bool> class F, typename WalkCtxType>            \\\n  static void __WalkStaticVerboseField__(WalkCtxType* ctx) {                                  \\\n    __DssStaticVerboseFieldIter__<field_counter, F, WalkCtxType>::Call(ctx);                  \\\n  }                                                                                           \\\n  template<template<int, class, class> class F, typename WalkCtxType>                         \\\n  bool __WalkFieldUntil__(WalkCtxType* ctx) {                                                 \\\n    return __DssFieldIterUntil__<field_counter, F, WalkCtxType>::Call(ctx, this);             \\\n  }                                                                                           \\\n                                                                                              \\\n private:                                                                                     \\\n  template<int tpl_fld_counter, template<int, class, class> class F, typename WalkCtxType,    \\\n           typename DssFieldType, typename Enabled = void>                                    \\\n  struct __DssVisitField__ {                                                                  \\\n    static void Call(WalkCtxType* ctx, DssFieldType* field_ptr) {                             \\\n      F<tpl_fld_counter, WalkCtxType, DssFieldType>::Call(ctx, field_ptr);                    \\\n    }                                                                                         \\\n  };                                                                                          \\\n  template<int tpl_fld_counter, template<int, class, class> class F, typename WalkCtxType,    \\\n           typename DssFieldType, typename 
Enabled = void>                                    \\\n  struct __DssVisitVerboseField__ {                                                           \\\n    static void Call(WalkCtxType* ctx, DssFieldType* field_ptr, const char* __field_name__) { \\\n      F<tpl_fld_counter, WalkCtxType, DssFieldType>::Call(ctx, field_ptr, __field_name__);    \\\n    }                                                                                         \\\n  };                                                                                          \\\n  template<int tpl_fld_counter, template<class, int, class, class, bool> class F,             \\\n           typename WalkCtxType, typename DssFieldType, typename Enabled = void>              \\\n  struct __DssVisitStaticVerboseField__ {                                                     \\\n    static void Call(WalkCtxType* ctx, const char* __field_name__) {                          \\\n      const char* __oneof_name__ = nullptr;                                                   \\\n      F<__DssSelfType__, tpl_fld_counter, WalkCtxType, DssFieldType, false>::Call(            \\\n          ctx, __field_name__, __oneof_name__);                                               \\\n    }                                                                                         \\\n  };                                                                                          \\\n  template<int tpl_fld_counter, template<int, class, class> class F, typename WalkCtxType,    \\\n           typename DssFieldType, typename Enabled = void>                                    \\\n  struct __DssVisitFieldUntil__ {                                                             \\\n    static bool Call(WalkCtxType* ctx, DssFieldType* field_ptr) {                             \\\n      return F<tpl_fld_counter, WalkCtxType, DssFieldType>::Call(ctx, field_ptr);             \\\n    }                                                                                         \\\n  };                                                                                          \\\n  template<int tpl_fld_counter, template<int, class, class> class F, typename WalkCtxType,    \\\n           typename Enabled = void>                                                           \\\n  struct __DssFieldIter__ {                                                                   \\\n    static void Call(WalkCtxType* ctx, __DssSelfType__* self) {                               \\\n      __DssFieldIter__<tpl_fld_counter + 1, F, WalkCtxType>::Call(ctx, self);                 \\\n    }                                                                                         \\\n  };                                                                                          \\\n  template<int tpl_fld_counter, template<int, class, class> class F, typename WalkCtxType,    \\\n           typename Enabled = void>                                                           \\\n  struct __DssVerboseFieldIter__ {                                                            \\\n    static void Call(WalkCtxType* ctx, __DssSelfType__* self) {                               \\\n      __DssVerboseFieldIter__<tpl_fld_counter + 1, F, WalkCtxType>::Call(ctx, self);          \\\n    }                                                                                         \\\n  };                                                                                          \\\n  template<int tpl_fld_counter, template<class, int, class, class, 
bool> class F,             \\\n           typename WalkCtxType, typename Enabled = void>                                     \\\n  struct __DssStaticVerboseFieldIter__ {                                                      \\\n    static void Call(WalkCtxType* ctx) {                                                      \\\n      __DssStaticVerboseFieldIter__<tpl_fld_counter + 1, F, WalkCtxType>::Call(ctx);          \\\n    }                                                                                         \\\n  };                                                                                          \\\n  template<int tpl_fld_counter, template<int, class, class> class F, typename WalkCtxType,    \\\n           typename Enabled = void>                                                           \\\n  struct __DssFieldIterUntil__ {                                                              \\\n    static bool Call(WalkCtxType* ctx, __DssSelfType__* self) {                               \\\n      return __DssFieldIterUntil__<tpl_fld_counter + 1, F, WalkCtxType>::Call(ctx, self);     \\\n    }                                                                                         \\\n  };                                                                                          \\\n  template<int tpl_fld_counter, template<int, class, class> class F, typename WalkCtxType,    \\\n           typename Enabled = void>                                                           \\\n  struct __DssFieldReverseIter__ {                                                            \\\n    static void Call(WalkCtxType* ctx, __DssSelfType__* self) {                               \\\n      __DssFieldReverseIter__<tpl_fld_counter - 1, F, WalkCtxType>::Call(ctx, self);          \\\n    }                                                                                         \\\n  };                                                                                          \\\n  template<template<int, class, class> class F, typename WalkCtxType, typename Enabled>       \\\n  struct __DssFieldReverseIter__<field_counter, F, WalkCtxType, Enabled> {                    \\\n    static void Call(WalkCtxType* ctx, __DssSelfType__* self) {}                              \\\n  };                                                                                          \\\n  template<int tpl_fld_counter, typename Enabled = void>                                      \\\n  struct __DssFieldAlign4Counter__ {                                                          \\\n    static const int value = 1;                                                               \\\n  };                                                                                          \\\n  template<int tpl_fld_counter, typename Enabled = void>                                      \\\n  struct __DssFieldSize4Counter__ {                                                           \\\n    static const int value = 0;                                                               \\\n  };                                                                                          \\\n  template<int tpl_fld_counter, typename Enabled = void>                                      \\\n  struct __DssFieldOffsetOfFieldNumber__ {                                                    \\\n    constexpr static int Get() {                                                              \\\n      return __DssFieldOffsetOfFieldNumber__<tpl_fld_counter - 1, Enabled>::Get();            
\\\n    }                                                                                         \\\n  };                                                                                          \\\n  template<typename Enabled>                                                                  \\\n  struct __DssFieldOffsetOfFieldNumber__<field_counter, Enabled> {                            \\\n    constexpr static int Get() { return 0; }                                                  \\\n  };                                                                                          \\\n  template<int tpl_fld_counter, typename Enabled = void>                                      \\\n  struct __DssStaticAssertFieldCounter__ {};                                                  \\\n                                                                                              \\\n  template<int tpl_fld_counter, typename Enabled = void>                                      \\\n  struct __DssAccumulatedAlignedSize4Counter__ {                                              \\\n    static const int value =                                                                  \\\n        ConstExprRoundUp<__DssAccumulatedAlignedSize4Counter__<tpl_fld_counter - 1>::value    \\\n                             + __DssFieldSize4Counter__<tpl_fld_counter - 1>::value,          \\\n                         __DssFieldAlign4Counter__<tpl_fld_counter>::value>();                \\\n  };                                                                                          \\\n  template<typename Enabled>                                                                  \\\n  struct __DssAccumulatedAlignedSize4Counter__<field_counter, Enabled> {                      \\\n    static const int value = 0;                                                               \\\n  };                                                                                          \\\n                                                                                              \\\n public:                                                                                      \\\n  template<int field_index>                                                                   \\\n  struct __DssFieldOffset4FieldIndex__ {                                                      \\\n    static const int value = __DssAccumulatedAlignedSize4Counter__<field_index>::value;       \\\n  };\n\n#define DSS_ASSERT_VERBOSE(dss_type)                                        \\\n  \"\\n\\n\\n    please check file \" __FILE__ \" (before line \" OF_PP_STRINGIZE( \\\n      __LINE__) \") carefully\\n\"                                             \\\n                \"    non \" dss_type \" member found before line \" OF_PP_STRINGIZE(__LINE__) \"\\n\\n\"\n\n#define _DSS_DEFINE_FIELD(field_counter, dss_type, field_type, field)                              \\\n private:                                                                                          \\\n  template<template<int, class, class> class F, typename WalkCtxType, typename Enabled>            \\\n  struct __DssFieldIter__<field_counter, F, WalkCtxType, Enabled> {                                \\\n    static void Call(WalkCtxType* ctx, __DssSelfType__* self) {                                    \\\n      __DssVisitField__<field_counter, F, WalkCtxType, decltype(self->field)>::Call(ctx,           \\\n                                                                                    &self->field); \\\n      
__DssFieldIter__<field_counter + 1, F, WalkCtxType>::Call(ctx, self);                        \\\n    }                                                                                              \\\n  };                                                                                               \\\n  template<template<int, class, class> class F, typename WalkCtxType, typename Enabled>            \\\n  struct __DssVerboseFieldIter__<field_counter, F, WalkCtxType, Enabled> {                         \\\n    static void Call(WalkCtxType* ctx, __DssSelfType__* self) {                                    \\\n      const char* __field_name__ = OF_PP_STRINGIZE(field);                                         \\\n      __DssVisitVerboseField__<field_counter, F, WalkCtxType, decltype(self->field)>::Call(        \\\n          ctx, &self->field, __field_name__);                                                      \\\n      __DssVerboseFieldIter__<field_counter + 1, F, WalkCtxType>::Call(ctx, self);                 \\\n    }                                                                                              \\\n  };                                                                                               \\\n  template<template<class, int, class, class, bool> class F, typename WalkCtxType,                 \\\n           typename Enabled>                                                                       \\\n  struct __DssStaticVerboseFieldIter__<field_counter, F, WalkCtxType, Enabled> {                   \\\n    static void Call(WalkCtxType* ctx) {                                                           \\\n      const char* __field_name__ = OF_PP_STRINGIZE(field);                                         \\\n      __DssVisitStaticVerboseField__<                                                              \\\n          field_counter, F, WalkCtxType,                                                           \\\n          decltype(((__DssSelfType__*)nullptr)->field)>::Call(ctx, __field_name__);                \\\n      __DssStaticVerboseFieldIter__<field_counter + 1, F, WalkCtxType>::Call(ctx);                 \\\n    }                                                                                              \\\n  };                                                                                               \\\n  template<template<int, class, class> class F, typename WalkCtxType, typename Enabled>            \\\n  struct __DssFieldIterUntil__<field_counter, F, WalkCtxType, Enabled> {                           \\\n    static bool Call(WalkCtxType* ctx, __DssSelfType__* self) {                                    \\\n      bool end =                                                                                   \\\n          __DssVisitFieldUntil__<field_counter, F, WalkCtxType, decltype(self->field)>::Call(      \\\n              ctx, &self->field);                                                                  \\\n      if (end) { return true; }                                                                    \\\n      return __DssFieldIterUntil__<field_counter + 1, F, WalkCtxType>::Call(ctx, self);            \\\n    }                                                                                              \\\n  };                                                                                               \\\n  template<template<int, class, class> class F, typename WalkCtxType, typename Enabled>            \\\n  struct __DssFieldReverseIter__<field_counter, F, 
WalkCtxType, Enabled> {                         \\\n    static void Call(WalkCtxType* ctx, __DssSelfType__* self) {                                    \\\n      __DssVisitField__<field_counter, F, WalkCtxType, decltype(self->field)>::Call(ctx,           \\\n                                                                                    &self->field); \\\n      __DssFieldReverseIter__<field_counter - 1, F, WalkCtxType>::Call(ctx, self);                 \\\n    }                                                                                              \\\n  };                                                                                               \\\n  template<typename Enabled>                                                                       \\\n  struct __DssFieldAlign4Counter__<field_counter, Enabled> {                                       \\\n    static const int value = alignof(field_type);                                                  \\\n  };                                                                                               \\\n  template<typename Enabled>                                                                       \\\n  struct __DssFieldSize4Counter__<field_counter, Enabled> {                                        \\\n    static const int value = sizeof(field_type);                                                   \\\n  };                                                                                               \\\n  template<typename Enabled>                                                                       \\\n  struct __DssFieldOffsetOfFieldNumber__<field_counter, Enabled> {                                 \\\n    constexpr static int Get() {                                                                   \\\n      static_assert(std::is_standard_layout<__DssSelfType__>::value, \"\");                          \\\n      return offsetof(__DssSelfType__, field);                                                     \\\n    }                                                                                              \\\n  };                                                                                               \\\n  template<typename Enabled>                                                                       \\\n  struct __DssStaticAssertFieldCounter__<field_counter, Enabled> {                                 \\\n    static void StaticAssert() {                                                                   \\\n      static const int kAccSize = __DssAccumulatedAlignedSize4Counter__<field_counter>::value;     \\\n      static_assert(kAccSize == __DssFieldOffsetOfFieldNumber__<field_counter>::Get(),             \\\n                    DSS_ASSERT_VERBOSE(dss_type));                                                 \\\n    }                                                                                              \\\n  };                                                                                               \\\n                                                                                                   \\\n public:                                                                                           \\\n  template<typename Enabled>                                                                       \\\n  struct __DssFieldType__<field_counter, Enabled> {                                                \\\n    using type = field_type;                                                                       
\\\n  };                                                                                               \\\n  [[maybe_unused]] static const int OF_PP_CAT(field, kDssFieldNumber) = field_counter;             \\\n  using OF_PP_CAT(field, DssFieldType) = field_type;                                               \\\n  [[maybe_unused]] static const int OF_PP_CAT(field, kDssFieldOffset) =                            \\\n      __DssAccumulatedAlignedSize4Counter__<field_counter>::value;\n\n#define _DSS_END(field_counter, dss_type, type)                                         \\\n public:                                                                                \\\n  template<template<int, class, class> class F, typename WalkCtxType>                   \\\n  void __ReverseWalkField__(WalkCtxType* ctx) {                                         \\\n    __DssFieldReverseIter__<field_counter, F, WalkCtxType>::Call(ctx, this);            \\\n  }                                                                                     \\\n                                                                                        \\\n private:                                                                               \\\n  template<template<int, class, class> class F, typename WalkCtxType, typename Enabled> \\\n  struct __DssFieldIter__<field_counter, F, WalkCtxType, Enabled> {                     \\\n    static void Call(WalkCtxType* ctx, type* self) {}                                   \\\n  };                                                                                    \\\n  template<template<int, class, class> class F, typename WalkCtxType, typename Enabled> \\\n  struct __DssVerboseFieldIter__<field_counter, F, WalkCtxType, Enabled> {              \\\n    static void Call(WalkCtxType* ctx, type* self) {}                                   \\\n  };                                                                                    \\\n  template<template<class, int, class, class, bool> class F, typename WalkCtxType,      \\\n           typename Enabled>                                                            \\\n  struct __DssStaticVerboseFieldIter__<field_counter, F, WalkCtxType, Enabled> {        \\\n    static void Call(WalkCtxType* ctx) {}                                               \\\n  };                                                                                    \\\n  template<template<int, class, class> class F, typename WalkCtxType, typename Enabled> \\\n  struct __DssFieldIterUntil__<field_counter, F, WalkCtxType, Enabled> {                \\\n    static bool Call(WalkCtxType* ctx, type* self) { return false; }                    \\\n  };                                                                                    \\\n  static void __DssStaticAssertStructSize__() {                                         \\\n    static const int kSize =                                                            \\\n        ConstExprRoundUp<__DssAccumulatedAlignedSize4Counter__<field_counter>::value,   \\\n                         alignof(type)>();                                              \\\n    static_assert((kSize == 0 && sizeof(type) == 1) || (kSize == sizeof(type)),         \\\n                  DSS_ASSERT_VERBOSE(dss_type));                                        \\\n  }\n\ntemplate<int x, int y>\nconstexpr int ConstExprRoundUp() {\n  return (x + y - 1) / y * y;\n}\n\ntemplate<bool is_pointer, typename Enabled = void>\nstruct GetterTrait {};\n\ntemplate<typename Enabled>\nstruct 
GetterTrait<false, Enabled> {\n  template<typename T>\n  static const T& Call(const T& data) {\n    return data;\n  }\n};\ntemplate<typename Enabled>\nstruct GetterTrait<true, Enabled> {\n  template<typename T>\n  static const T& Call(const T* data) {\n    return *data;\n  }\n};\n\ntemplate<bool is_pointer, typename Enabled = void>\nstruct MutableTrait {};\n\ntemplate<typename Enabled>\nstruct MutableTrait<false, Enabled> {\n  template<typename T>\n  static T* Call(T* data) {\n    return data;\n  }\n};\ntemplate<typename Enabled>\nstruct MutableTrait<true, Enabled> {\n  template<typename T>\n  static T* Call(T** data) {\n    return *data;\n  }\n};\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_DSS_H_\n"
  },
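  {
    "path": "oneflow/core/intrusive/dss_example.cpp",
    "content": "// NOTE(editor): illustrative sketch only; this file is not part of the original\n// OneFlow tree. It shows the compile-time layout bookkeeping behind dss.h:\n// DSS_DEFINE_FIELD exports <field>kDssFieldNumber and <field>kDssFieldOffset,\n// and DSS_END statically checks those accumulated offsets against the real\n// struct layout. The struct name ExampleMsg is hypothetical.\n#include <cstdint>\n#include \"oneflow/core/intrusive/dss.h\"\n\nnamespace oneflow {\nnamespace {\n\nstruct ExampleMsg {\n  DSS_BEGIN(DSS_GET_FIELD_COUNTER(), ExampleMsg);\n  int32_t a;\n  int64_t b;\n\n  DSS_DEFINE_FIELD(DSS_GET_FIELD_COUNTER(), \"example dss\", int32_t, a);\n  DSS_DEFINE_FIELD(DSS_GET_FIELD_COUNTER(), \"example dss\", int64_t, b);\n\n  DSS_END(DSS_GET_FIELD_COUNTER(), \"example dss\", ExampleMsg);\n};\n\n// `a` sits at offset 0; `b` is 8-byte aligned, so the 4 bytes after `a` are\n// padding and `b` lands at offset 8, exactly as offsetof would report.\nstatic_assert(ExampleMsg::akDssFieldOffset == 0, \"\");\nstatic_assert(ExampleMsg::bkDssFieldOffset == 8, \"\");\n\n}  // namespace\n}  // namespace oneflow\n"
  },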
  {
    "path": "oneflow/core/intrusive/dss_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/intrusive/dss.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstruct Foo {\n  DSS_BEGIN(DSS_GET_FIELD_COUNTER(), Foo);\n  int x;\n  int y;\n  int* z;\n\n  DSS_DEFINE_FIELD(DSS_GET_FIELD_COUNTER(), \"demo dss\", int, x);\n  DSS_DEFINE_FIELD(DSS_GET_FIELD_COUNTER(), \"demo dss\", int, y);\n  DSS_DEFINE_FIELD(DSS_GET_FIELD_COUNTER(), \"demo dss\", int*, z);\n\n  DSS_END(DSS_GET_FIELD_COUNTER(), \"demo dss\", Foo);\n};\n\nstruct Bar {\n  DSS_BEGIN(DSS_GET_FIELD_COUNTER(), Bar);\n\n  DSS_END(DSS_GET_FIELD_COUNTER(), \"demo dss\", Bar);\n};\n\ntemplate<typename T>\nstruct IsPointer {\n  static const bool value = std::is_pointer<T>::value;\n};\n\ntemplate<typename T>\nstruct RemovePointer {\n  using type = typename std::remove_pointer<T>::type;\n};\n\ntemplate<typename T>\nstruct IsScalar {\n  static const bool value =\n      std::is_arithmetic<T>::value || std::is_enum<T>::value || std::is_same<T, std::string>::value;\n};\n\ntemplate<int field_counter, typename WalkCtxType, typename FieldType>\nstruct DumpFieldName {\n  static void Call(WalkCtxType* ctx, FieldType* field, const char* field_name) {\n    ctx->emplace_back(field_name);\n  }\n};\n\nTEST(DSS, walk_field) {\n  Foo foo;\n  std::vector<std::string> field_names;\n  foo.__WalkVerboseField__<DumpFieldName>(&field_names);\n  ASSERT_EQ(field_names.size(), 3);\n  ASSERT_TRUE(field_names[0] == \"x\");\n  ASSERT_TRUE(field_names[1] == \"y\");\n  ASSERT_TRUE(field_names[2] == \"z\");\n}\n\ntemplate<bool is_pointer>\nstruct PushBackPtrFieldName {\n  template<typename WalkCtxType>\n  static void Call(WalkCtxType* ctx, const char* field_name) {}\n};\n\ntemplate<>\nstruct PushBackPtrFieldName<true> {\n  template<typename WalkCtxType>\n  static void Call(WalkCtxType* ctx, const char* field_name) {\n    ctx->emplace_back(field_name);\n  }\n};\n\ntemplate<int field_counter, typename WalkCtxType, typename FieldType>\nstruct FilterPointerFieldName {\n  static void Call(WalkCtxType* ctx, FieldType* field, const char* field_name) {\n    PushBackPtrFieldName<std::is_pointer<FieldType>::value>::Call(ctx, field_name);\n  }\n};\n\ntemplate<int field_counter, typename WalkCtxType, typename FieldType>\nstruct FilterPointerFieldNameUntil {\n  static bool Call(WalkCtxType* ctx, FieldType* field) {\n    PushBackPtrFieldName<std::is_pointer<FieldType>::value>::Call(ctx, \"\");\n    return true;\n  }\n};\n\nTEST(DSS, filter_field) {\n  Foo foo;\n  std::vector<std::string> field_names;\n  foo.__WalkVerboseField__<FilterPointerFieldName>(&field_names);\n  ASSERT_EQ(field_names.size(), 1);\n  ASSERT_TRUE(field_names[0] == \"z\");\n}\n\nTEST(DSS, filter_field_until) {\n  Foo foo;\n  std::vector<std::string> field_names;\n  ASSERT_TRUE(foo.__WalkFieldUntil__<FilterPointerFieldNameUntil>(&field_names));\n  ASSERT_TRUE(field_names.empty());\n}\n\n#define 
DSS_DEFINE_TEST_UNION_FIELD(field_counter)                      \\\n  DSS_DEFINE_FIELD(field_counter, \"demo dss\", UnionField, union_field); \\\n  DSS_DEFINE_UNION_FIELD_VISITOR(field_counter, union_case,             \\\n                                 OF_PP_MAKE_TUPLE_SEQ(int32_t, x, 1)    \\\n                                     OF_PP_MAKE_TUPLE_SEQ(int64_t, y, 2));\n\nstruct TestDssUnion {\n  DSS_BEGIN(DSS_GET_FIELD_COUNTER(), TestDssUnion);\n\n public:\n  struct UnionField {\n    int32_t union_case;\n    union {\n      int32_t x;\n      int64_t y;\n    };\n  } union_field;\n\n  DSS_DEFINE_TEST_UNION_FIELD(DSS_GET_FIELD_COUNTER());\n  DSS_END(DSS_GET_FIELD_COUNTER(), \"demo dss\", TestDssUnion);\n};\n\ntemplate<typename StructT, int field_counter, typename WalkCtxType, typename FieldType,\n         bool is_oneof_field>\nstruct StaticDumpFieldName {\n  static void Call(WalkCtxType* ctx, const char* field_name, const char* oneof_name) {\n    ctx->emplace_back(field_name);\n    ctx->emplace_back(oneof_name);\n  }\n};\n\nTEST(DSS, union_field) {\n  TestDssUnion foo;\n  foo.union_field.union_case = 0;\n  {\n    std::vector<std::string> field_names;\n    foo.__WalkVerboseField__<DumpFieldName>(&field_names);\n    ASSERT_EQ(field_names.size(), 0);\n  }\n  foo.union_field.union_case = 1;\n  {\n    std::vector<std::string> field_names;\n    foo.__WalkVerboseField__<DumpFieldName>(&field_names);\n    ASSERT_EQ(field_names.size(), 1);\n    ASSERT_EQ(field_names.at(0), \"x\");\n  }\n  foo.union_field.union_case = 2;\n  {\n    std::vector<std::string> field_names;\n    foo.__WalkVerboseField__<DumpFieldName>(&field_names);\n    ASSERT_EQ(field_names.size(), 1);\n    ASSERT_EQ(field_names.at(0), \"y\");\n  }\n}\n\nTEST(DSS, static_verbose_field) {\n  std::vector<std::string> field_names;\n  TestDssUnion::__WalkStaticVerboseField__<StaticDumpFieldName>(&field_names);\n  ASSERT_EQ(field_names.size(), 4);\n  ASSERT_EQ(field_names.at(0), \"x\");\n  ASSERT_EQ(field_names.at(1), \"union_field\");\n  ASSERT_EQ(field_names.at(2), \"y\");\n  ASSERT_EQ(field_names.at(3), \"union_field\");\n}\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
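  {
    "path": "oneflow/core/intrusive/flat_msg_example.cpp",
    "content": "// NOTE(editor): illustrative sketch only; this file is not part of the original\n// OneFlow tree. It demonstrates the accessors generated by FLAT_MSG_DEFINE_ONEOF\n// in flat_msg.h below: has_<field>(), set_<field>(), mutable_<field>() and\n// clear_<field>(). The names ExampleFlatMsg, int_val and float_val are\n// hypothetical.\n#include <cstdint>\n#include \"oneflow/core/intrusive/flat_msg.h\"\n\nnamespace oneflow {\nnamespace {\n\n// clang-format off\nFLAT_MSG_BEGIN(ExampleFlatMsg)\n  FLAT_MSG_DEFINE_ONEOF(demo_value,\n      FLAT_MSG_ONEOF_FIELD(int32_t, int_val)\n      FLAT_MSG_ONEOF_FIELD(float, float_val));\nFLAT_MSG_END(ExampleFlatMsg)\n// clang-format on\n\n[[maybe_unused]] void DemoOneofAccessors() {\n  FlatMsg<ExampleFlatMsg> msg;        // zero-initialized, so no oneof case is set\n  msg->set_int_val(7);                // selects the int_val case\n  bool has_int = msg->has_int_val();  // true\n  *msg->mutable_float_val() = 0.5f;   // switches the active case to float_val\n  has_int = msg->has_int_val();       // false now\n  msg->clear_float_val();             // back to the not-set case\n  (void)has_int;\n}\n\n}  // namespace\n}  // namespace oneflow\n"
  },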
  {
    "path": "oneflow/core/intrusive/flat_msg.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_FLAT__H_\n#define ONEFLOW_CORE_INTRUSIVE_FLAT__H_\n\n#include <array>\n#include <cstring>\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/intrusive/dss.h\"\n#include \"oneflow/core/intrusive/static_counter.h\"\n\nnamespace oneflow {\n\n#define FLAT_MSG_BEGIN(struct_name)                        \\\n  struct struct_name final {                               \\\n    using self_type = struct_name;                         \\\n    using self_value_type = struct_name;                   \\\n    static const bool __is_flat_message_type__ = true;     \\\n                                                           \\\n   public:                                                 \\\n    DEFINE_STATIC_COUNTER(field_counter);                  \\\n    DSS_BEGIN(STATIC_COUNTER(field_counter), struct_name); \\\n    FLAT_MSG_DEFINE_BASIC_METHODS(struct_name);            \\\n    FLAT_MSG_DEFINE_DEFAULT(struct_name);\n\n#define FLAT_MSG_END(struct_name)                                                       \\\n  static_assert(__is_flat_message_type__, \"this struct is not a flat message\");         \\\n                                                                                        \\\n public:                                                                                \\\n  [[maybe_unused]] static const int __NumberOfFields__ = STATIC_COUNTER(field_counter); \\\n                                                                                        \\\n public:                                                                                \\\n  INCREASE_STATIC_COUNTER(field_counter);                                               \\\n  DSS_END(STATIC_COUNTER(field_counter), \"flat message\", struct_name);                  \\\n  }                                                                                     \\\n  ;\n\n#define FLAT_MSG_DEFINE_OPTIONAL(field_type, field_name)                        \\\n  static_assert(__is_flat_message_type__, \"this struct is not a flat message\"); \\\n  FLAT_MSG_DEFINE_ONEOF(OF_PP_CAT(__flat_msg_optional__, field_name),           \\\n                        FLAT_MSG_ONEOF_FIELD(field_type, field_name))\n\n#define FLAT_MSG_DEFINE_ONEOF(oneof_name, type_and_field_name_seq) \\\n  _FLAT_MSG_DEFINE_ONEOF(_FLAT_MSG_DEFINE_NOTHING, oneof_name, type_and_field_name_seq);\n\n#define FLAT_MSG_DEFINE_STRICT_ONEOF(oneof_name, type_and_field_name_seq) \\\n  _FLAT_MSG_DEFINE_ONEOF(_FLAT_MSG_DEFINE_ONEOF_VALUE4TYPE, oneof_name, type_and_field_name_seq);\n\n#define _FLAT_MSG_DEFINE_ONEOF(define_field_value4field_type, oneof_name, type_and_field_name_seq) \\\n  static_assert(__is_flat_message_type__, \"this struct is not a flat message\");                    \\\n  FLAT_MSG_DEFINE_ONEOF_ENUM_TYPE(oneof_name, type_and_field_name_seq);                            \\\n  
FLAT_MSG_DEFINE_ONEOF_UNION(define_field_value4field_type, oneof_name, type_and_field_name_seq); \\\n  FLAT_MSG_DEFINE_ONEOF_ACCESSOR(oneof_name, type_and_field_name_seq)                              \\\n public:                                                                                           \\\n  INCREASE_STATIC_COUNTER(field_counter);                                                          \\\n  FLAT_MSG_DSS_DEFINE_UNION_FIELD(STATIC_COUNTER(field_counter), oneof_name,                       \\\n                                  type_and_field_name_seq);\n\n#define FLAT_MSG_DEFINE_REPEATED(field_type, field_name, max_size)                        \\\n  static_assert(__is_flat_message_type__, \"this struct is not a flat message\");           \\\n  _FLAT_MSG_DEFINE_REPEATED_FIELD(FLAT_MSG_TYPE_CHECK(field_type), field_name, max_size); \\\n                                                                                          \\\n public:                                                                                  \\\n  INCREASE_STATIC_COUNTER(field_counter);                                                 \\\n  DSS_DEFINE_FIELD(STATIC_COUNTER(field_counter), \"flat message\",                         \\\n                   OF_PP_CAT(field_name, _RepeatedField), OF_PP_CAT(field_name, _));\n\n#define FLAT_MSG_DEFINE_COMPARE_OPERATORS_BY_MEMCMP() _FLAT_MSG_DEFINE_COMPARE_OPERATORS_BY_MEMCMP()\n\n#define FLAT_MSG_ONEOF_FIELD(field_type, field_name) \\\n  OF_PP_MAKE_TUPLE_SEQ(FLAT_MSG_TYPE_CHECK(field_type), field_name)\n\n#define FLAT_MSG_ONEOF_CASE(oneof_name) _FLAT_MSG_ONEOF_ENUM_TYPE(oneof_name)\n\n#define FLAT_MSG_ONEOF_CASE_VALUE(field) _FLAT_MSG_ONEOF_ENUM_VALUE(field)\n\n#define FLAT_MSG_ONEOF_NOT_SET_VALUE(field_type, oneof_name) \\\n  field_type::_FLAT_MSG_ONEOF_NOT_SET_VALUE(oneof_name)\n\n#define FLAT_MSG_TYPE_CHECK(type_name) FlatMsgSelfType<type_name>::type\n\n// details\n\n#define FLAT_MSG_DSS_DEFINE_UNION_FIELD(field_counter, oneof_name, type_and_field_name_seq) \\\n  DSS_DEFINE_FIELD(field_counter, \"flat message\", OF_PP_CAT(oneof_name, _OneofType),       \\\n                   OF_PP_CAT(oneof_name, _));                                              \\\n  DSS_DEFINE_UNION_FIELD_VISITOR(                                                          \\\n      field_counter, case_,                                                                \\\n      OF_PP_FOR_EACH_TUPLE(FLAT_MSG_MAKE_UNION_TYPE7FIELD4CASE, type_and_field_name_seq));\n\n#define FLAT_MSG_MAKE_UNION_TYPE7FIELD4CASE(field_type, field_name) \\\n  OF_PP_MAKE_TUPLE_SEQ(field_type, OF_PP_CAT(field_name, _), _FLAT_MSG_ONEOF_ENUM_VALUE(field_name))\n\ntemplate<typename T, typename Enabled = void>\nstruct FlatMsgSelfType {\n  static_assert(T::__is_flat_message_type__, \"T is not a flat message type\");\n  using type = T;\n};\n\ntemplate<typename T>\nstruct FlatMsgSelfType<\n    T, typename std::enable_if<std::is_arithmetic<T>::value || std::is_enum<T>::value>::type> {\n  using type = T;\n};\n\ntemplate<typename T>\nstruct FlatMsg final {\n  using value_type = T;\n  using self_value_type = value_type;\n  FlatMsg() { msg_.clear(); }\n  FlatMsg(const FlatMsg& rhs) { msg_.CopyFrom(rhs.msg_); }\n  FlatMsg(const T& msg) { msg_.CopyFrom(msg); }\n\n  const value_type& operator*() const { return msg_; }\n  value_type& operator*() { return msg_; }\n  const value_type* operator->() const { return &msg_; }\n  value_type* operator->() { return &msg_; }\n\n  const value_type& Get() const { return msg_; }\n  value_type* 
Mutable() { return &msg_; }\n\n  template<typename RhsT>\n  bool operator==(const RhsT& rhs) const {\n    static_assert(std::is_same<FlatMsg, RhsT>::value, \"\");\n    return msg_ == rhs.msg_;\n  }\n\n  template<typename RhsT>\n  bool operator!=(const RhsT& rhs) const {\n    static_assert(std::is_same<FlatMsg, RhsT>::value, \"\");\n    return msg_ != rhs.msg_;\n  }\n\n  template<typename RhsT>\n  bool operator>=(const RhsT& rhs) const {\n    static_assert(std::is_same<FlatMsg, RhsT>::value, \"\");\n    return msg_ >= rhs.msg_;\n  }\n\n  template<typename RhsT>\n  bool operator<=(const RhsT& rhs) const {\n    static_assert(std::is_same<FlatMsg, RhsT>::value, \"\");\n    return msg_ <= rhs.msg_;\n  }\n\n  template<typename RhsT>\n  bool operator>(const RhsT& rhs) const {\n    static_assert(std::is_same<FlatMsg, RhsT>::value, \"\");\n    return msg_ > rhs.msg_;\n  }\n\n  template<typename RhsT>\n  bool operator<(const RhsT& rhs) const {\n    static_assert(std::is_same<FlatMsg, RhsT>::value, \"\");\n    return msg_ < rhs.msg_;\n  }\n\n private:\n  union {\n    value_type msg_;\n  };\n};\n\n#define FLAT_MSG_DEFINE_DEFAULT(flat_msg_type_name)            \\\n  const flat_msg_type_name& __Default__() const {              \\\n    static const FlatMsg<flat_msg_type_name> default_flat_msg; \\\n    return default_flat_msg.Get();                             \\\n  }\n\ntemplate<typename T>\nstruct FlatMsgIsScalar final {\n  static const bool value = std::is_arithmetic<T>::value || std::is_enum<T>::value;\n};\n\ntemplate<bool is_scalar>\nstruct FlatMsgGetDefault final {\n  template<typename T>\n  static const T& Call(const T* val) {\n    return val->__Default__();\n  }\n};\ntemplate<>\nstruct FlatMsgGetDefault<true> final {\n  template<typename T>\n  static const T& Call(const T* val) {\n    return *val;\n  }\n};\n\n#define _FLAT_MSG_ONEOF_CASE_NAME(oneof_name) OF_PP_CAT(oneof_name, _case)\n\n#define _FLAT_MSG_ONEOF_ENUM_VALUE(field) SNAKE_TO_CAMEL(field)\n\n#define _FLAT_MSG_ONEOF_ENUM_TYPE(oneof_name) SNAKE_TO_CAMEL(oneof_name)\n\n#define _FLAT_MSG_ONEOF_NOT_SET_VALUE(oneof_name) OF_PP_CAT(k_, OF_PP_CAT(oneof_name, _not_set))\n\n#define FLAT_MSG_DEFINE_BASIC_METHODS(T) _FLAT_MSG_DEFINE_BASIC_METHODS(T)\n\n#define _FLAT_MSG_DEFINE_BASIC_METHODS(T)                                           \\\n public:                                                                            \\\n  void clear() { std::memset(reinterpret_cast<void*>(this), 0, sizeof(T)); }        \\\n  void CopyFrom(const self_type& rhs) {                                             \\\n    std::memcpy(reinterpret_cast<void*>(this), reinterpret_cast<const void*>(&rhs), \\\n                sizeof(self_type));                                                 \\\n  }\n\n#define FLAT_MSG_DEFINE_ONEOF_ENUM_TYPE(oneof_name, type_and_field_name_seq)     \\\n public:                                                                         \\\n  enum _FLAT_MSG_ONEOF_ENUM_TYPE(oneof_name) {                                   \\\n    _FLAT_MSG_ONEOF_NOT_SET_VALUE(oneof_name) = 0,                               \\\n    OF_PP_FOR_EACH_TUPLE(MAKE_FLAT_MSG_ONEOF_ENUM_CASE, type_and_field_name_seq) \\\n  }\n\n#define MAKE_FLAT_MSG_ONEOF_ENUM_CASE(field_type, field_name) \\\n  _FLAT_MSG_ONEOF_ENUM_VALUE(field_name),\n\n#define FLAT_MSG_DEFINE_ONEOF_ACCESSOR(oneof_name, type_and_field_name_seq)                    \\\n  _FLAT_MSG_DEFINE_ONEOF_CASE_ACCESSOR(oneof_name, _FLAT_MSG_ONEOF_ENUM_TYPE(oneof_name));     \\\n  
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_FLAT_MSG_ONEOF_ACCESSOR, (_FLAT_MSG_ONEOF_ENUM_VALUE), \\\n                                   (oneof_name), type_and_field_name_seq)\n\n#define MAKE_FLAT_MSG_ONEOF_ACCESSOR(get_enum_value, oneof_name, pair)                         \\\n public:                                                                                       \\\n  const OF_PP_PAIR_FIRST(pair) & OF_PP_PAIR_SECOND(pair)() const {                             \\\n    if (OF_PP_CAT(has_, OF_PP_PAIR_SECOND(pair))()) {                                          \\\n      return OF_PP_CAT(oneof_name, _).OF_PP_CAT(OF_PP_PAIR_SECOND(pair), _);                   \\\n    }                                                                                          \\\n    return FlatMsgGetDefault<FlatMsgIsScalar<OF_PP_PAIR_FIRST(pair)>::value>::Call(            \\\n        &OF_PP_CAT(oneof_name, _).OF_PP_CAT(OF_PP_PAIR_SECOND(pair), _));                      \\\n  }                                                                                            \\\n  bool OF_PP_CAT(has_, OF_PP_PAIR_SECOND(pair))() const {                                      \\\n    return _FLAT_MSG_ONEOF_CASE_NAME(oneof_name)() == get_enum_value(OF_PP_PAIR_SECOND(pair)); \\\n  }                                                                                            \\\n  void OF_PP_CAT(clear_, OF_PP_PAIR_SECOND(pair))() {                                          \\\n    if (!OF_PP_CAT(has_, OF_PP_PAIR_SECOND(pair))()) { return; }                               \\\n    OF_PP_CAT(set_, _FLAT_MSG_ONEOF_CASE_NAME(oneof_name))                                     \\\n    (_FLAT_MSG_ONEOF_NOT_SET_VALUE(oneof_name));                                               \\\n  }                                                                                            \\\n  OF_PP_PAIR_FIRST(pair) * OF_PP_CAT(mut_, OF_PP_PAIR_SECOND(pair))() {                        \\\n    OF_PP_CAT(set_, _FLAT_MSG_ONEOF_CASE_NAME(oneof_name))                                     \\\n    (get_enum_value(OF_PP_PAIR_SECOND(pair)));                                                 \\\n    return &OF_PP_CAT(oneof_name, _).OF_PP_CAT(OF_PP_PAIR_SECOND(pair), _);                    \\\n  }                                                                                            \\\n  OF_PP_PAIR_FIRST(pair) * OF_PP_CAT(mutable_, OF_PP_PAIR_SECOND(pair))() {                    \\\n    OF_PP_CAT(set_, _FLAT_MSG_ONEOF_CASE_NAME(oneof_name))                                     \\\n    (get_enum_value(OF_PP_PAIR_SECOND(pair)));                                                 \\\n    return &OF_PP_CAT(oneof_name, _).OF_PP_CAT(OF_PP_PAIR_SECOND(pair), _);                    \\\n  }                                                                                            \\\n  void OF_PP_CAT(set_, OF_PP_PAIR_SECOND(pair))(const OF_PP_PAIR_FIRST(pair) & val) {          \\\n    *OF_PP_CAT(mutable_, OF_PP_PAIR_SECOND(pair))() = val;                                     \\\n  }\n\n#define FLAT_MSG_DEFINE_ONEOF_UNION(define_field_value4field_type, oneof_name,             \\\n                                    type_and_field_name_seq)                               \\\n public:                                                                                   \\\n  struct OF_PP_CAT(oneof_name, _OneofType) {                                               \\\n   public:                                                                                 \\\n    using self_oneof_type = 
OF_PP_CAT(oneof_name, _OneofType);                             \\\n    using self_oneof_case_type = _FLAT_MSG_ONEOF_ENUM_TYPE(oneof_name);                    \\\n    template<self_oneof_case_type oneof_case, typename Enabled = void>                     \\\n    struct FieldType4FieldValueStruct {};                                                  \\\n    template<self_oneof_case_type oneof_case, typename Enabled = void>                     \\\n    struct HasStruct {};                                                                   \\\n    template<self_oneof_case_type oneof_case, typename Enabled = void>                     \\\n    struct GetStruct {};                                                                   \\\n    template<self_oneof_case_type oneof_case, typename Enabled = void>                     \\\n    struct MutableStruct {};                                                               \\\n    OF_PP_FOR_EACH_TUPLE(_MAKE_FLAT_MSG_ONEOF_TEMPLATE_ACCESSOR, type_and_field_name_seq); \\\n    define_field_value4field_type(type_and_field_name_seq);                                \\\n    template<self_oneof_case_type oneof_case>                                              \\\n    bool Has() const {                                                                     \\\n      return HasStruct<oneof_case>::Call(*this);                                           \\\n    }                                                                                      \\\n    template<self_oneof_case_type oneof_case>                                              \\\n    const typename FieldType4FieldValueStruct<oneof_case>::type& Get() const {             \\\n      return GetStruct<oneof_case>::Call(*this);                                           \\\n    }                                                                                      \\\n    template<self_oneof_case_type oneof_case>                                              \\\n    typename FieldType4FieldValueStruct<oneof_case>::type* Mutable() {                     \\\n      return MutableStruct<oneof_case>::Call(this);                                        \\\n    }                                                                                      \\\n                                                                                           \\\n    union {                                                                                \\\n      OF_PP_FOR_EACH_TUPLE(MAKE_FLAT_MSG_ONEOF_UNION_FIELD, type_and_field_name_seq)       \\\n    };                                                                                     \\\n    self_oneof_case_type case_;                                                            \\\n  };                                                                                       \\\n                                                                                           \\\n private:                                                                                  \\\n  OF_PP_CAT(oneof_name, _OneofType) OF_PP_CAT(oneof_name, _);                              \\\n                                                                                           \\\n public:                                                                                   \\\n  const OF_PP_CAT(oneof_name, _OneofType) & oneof_name() const {                           \\\n    return OF_PP_CAT(oneof_name, _);                                                       \\\n  }                                                                 
                       \\\n  OF_PP_CAT(oneof_name, _OneofType) * OF_PP_CAT(mutable_, oneof_name)() {                  \\\n    return &OF_PP_CAT(oneof_name, _);                                                      \\\n  }\n\n#define _MAKE_FLAT_MSG_ONEOF_TEMPLATE_ACCESSOR(field_type, field_name)                 \\\n public:                                                                               \\\n  template<typename Enabled>                                                           \\\n  struct FieldType4FieldValueStruct<_FLAT_MSG_ONEOF_ENUM_VALUE(field_name), Enabled> { \\\n    using type = field_type;                                                           \\\n  };                                                                                   \\\n  template<typename Enabled>                                                           \\\n  struct HasStruct<_FLAT_MSG_ONEOF_ENUM_VALUE(field_name), Enabled> {                  \\\n    static bool Call(const self_oneof_type& self) {                                    \\\n      return self.case_ == _FLAT_MSG_ONEOF_ENUM_VALUE(field_name);                     \\\n    }                                                                                  \\\n  };                                                                                   \\\n  template<typename Enabled>                                                           \\\n  struct GetStruct<_FLAT_MSG_ONEOF_ENUM_VALUE(field_name), Enabled> {                  \\\n    static const field_type& Call(const self_oneof_type& self) {                       \\\n      return self.OF_PP_CAT(field_name, _);                                            \\\n    }                                                                                  \\\n  };                                                                                   \\\n  template<typename Enabled>                                                           \\\n  struct MutableStruct<_FLAT_MSG_ONEOF_ENUM_VALUE(field_name), Enabled> {              \\\n    static field_type* Call(self_oneof_type* self) {                                   \\\n      self->case_ = _FLAT_MSG_ONEOF_ENUM_VALUE(field_name);                            \\\n      return &self->OF_PP_CAT(field_name, _);                                          \\\n    }                                                                                  \\\n  };\n\n#define _FLAT_MSG_DEFINE_NOTHING(type_and_field_name_seq)\n\n#define _FLAT_MSG_DEFINE_ONEOF_VALUE4TYPE(type_and_field_name_seq)                \\\n public:                                                                          \\\n  template<typename T, typename Enabled = void>                                   \\\n  struct FieldValue4FieldType {};                                                 \\\n  OF_PP_FOR_EACH_TUPLE(_MAKE_FLAT_MSG_ONEOF_VALUE4TYPE, type_and_field_name_seq); \\\n  template<typename T>                                                            \\\n  bool HasField() const {                                                         \\\n    return Has<FieldValue4FieldType<T>::value>();                                 \\\n  }                                                                               \\\n  template<typename T>                                                            \\\n  const T& GetField() const {                                                     \\\n    return Get<FieldValue4FieldType<T>::value>();                                 \\\n  }                                           
                                    \\\n  template<typename T>                                                            \\\n  T* MutableField() {                                                             \\\n    return Mutable<FieldValue4FieldType<T>::value>();                             \\\n  }\n\n#define _MAKE_FLAT_MSG_ONEOF_VALUE4TYPE(field_type, field_name)                       \\\n  template<typename Enabled>                                                          \\\n  struct FieldValue4FieldType<field_type, Enabled> {                                  \\\n    static const self_oneof_case_type value = _FLAT_MSG_ONEOF_ENUM_VALUE(field_name); \\\n  };\n\n#define MAKE_FLAT_MSG_ONEOF_UNION_FIELD(field_type, field_name) field_type OF_PP_CAT(field_name, _);\n\n#define SNAKE_TO_CAMEL(name) OF_PP_CAT(__FlatMsgSnakeToCamel__, name)\n\n#define _FLAT_MSG_DEFINE_ONEOF_CASE_ACCESSOR(oneof_name, T)                             \\\n public:                                                                                \\\n  T OF_PP_CAT(oneof_name, _case)() const { return OF_PP_CAT(oneof_name, _).case_; }     \\\n  bool OF_PP_CAT(has_, oneof_name)() const {                                            \\\n    return OF_PP_CAT(oneof_name, _).case_ != _FLAT_MSG_ONEOF_NOT_SET_VALUE(oneof_name); \\\n  }                                                                                     \\\n                                                                                        \\\n private:                                                                               \\\n  void OF_PP_CAT(set_, OF_PP_CAT(oneof_name, _case))(T val) {                           \\\n    OF_PP_CAT(oneof_name, _).case_ = val;                                               \\\n  }\n\n#define _FLAT_MSG_DEFINE_REPEATED_FIELD(T, field_name, N)                                       \\\n public:                                                                                        \\\n  using OF_PP_CAT(field_name, _RepeatedField) = FlatMsgRepeatedField<T, N>;                     \\\n  std::size_t OF_PP_CAT(field_name, _size)() const { return OF_PP_CAT(field_name, _).size(); }  \\\n  const OF_PP_CAT(field_name, _RepeatedField) & field_name() const {                            \\\n    return OF_PP_CAT(field_name, _);                                                            \\\n  }                                                                                             \\\n  const T& field_name(int32_t i) const { return OF_PP_CAT(field_name, _).Get(i); }              \\\n  OF_PP_CAT(field_name, _RepeatedField) * OF_PP_CAT(mut_, field_name)() {                       \\\n    return &OF_PP_CAT(field_name, _);                                                           \\\n  }                                                                                             \\\n  OF_PP_CAT(field_name, _RepeatedField) * OF_PP_CAT(mutable_, field_name)() {                   \\\n    return &OF_PP_CAT(field_name, _);                                                           \\\n  }                                                                                             \\\n  T* OF_PP_CAT(mut_, field_name)(int32_t i) { return OF_PP_CAT(field_name, _).Mutable(i); }     \\\n  T* OF_PP_CAT(mutable_, field_name)(int32_t i) { return OF_PP_CAT(field_name, _).Mutable(i); } \\\n  T* OF_PP_CAT(add_, field_name)() { return OF_PP_CAT(field_name, _).Add(); }                   \\\n  void OF_PP_CAT(clear_, field_name)() { OF_PP_CAT(field_name, _).clear(); }  
                  \\\n                                                                                                \\\n private:                                                                                       \\\n  OF_PP_CAT(field_name, _RepeatedField)                                                         \\\n  OF_PP_CAT(field_name, _);\n\n#define _FLAT_MSG_DEFINE_COMPARE_OPERATORS_BY_MEMCMP()                                           \\\n public:                                                                                         \\\n  bool operator<(const self_type& rhs) const {                                                   \\\n    return std::memcmp(reinterpret_cast<const void*>(this), reinterpret_cast<const void*>(&rhs), \\\n                       sizeof(self_type))                                                        \\\n           < 0;                                                                                  \\\n  }                                                                                              \\\n  bool operator<=(const self_type& rhs) const {                                                  \\\n    return std::memcmp(reinterpret_cast<const void*>(this), reinterpret_cast<const void*>(&rhs), \\\n                       sizeof(self_type))                                                        \\\n           <= 0;                                                                                 \\\n  }                                                                                              \\\n  bool operator==(const self_type& rhs) const {                                                  \\\n    return std::memcmp(reinterpret_cast<const void*>(this), reinterpret_cast<const void*>(&rhs), \\\n                       sizeof(self_type))                                                        \\\n           == 0;                                                                                 \\\n  }                                                                                              \\\n  bool operator!=(const self_type& rhs) const {                                                  \\\n    return std::memcmp(reinterpret_cast<const void*>(this), reinterpret_cast<const void*>(&rhs), \\\n                       sizeof(self_type))                                                        \\\n           != 0;                                                                                 \\\n  }                                                                                              \\\n  bool operator>(const self_type& rhs) const {                                                   \\\n    return std::memcmp(reinterpret_cast<const void*>(this), reinterpret_cast<const void*>(&rhs), \\\n                       sizeof(self_type))                                                        \\\n           > 0;                                                                                  \\\n  }                                                                                              \\\n  bool operator>=(const self_type& rhs) const {                                                  \\\n    return std::memcmp(reinterpret_cast<const void*>(this), reinterpret_cast<const void*>(&rhs), \\\n                       sizeof(self_type))                                                        \\\n           >= 0;                                                                                 \\\n  }\n\ntemplate<typename T, std::size_t N>\nclass FlatMsgRepeatedField 
final {\n public:\n  using value_type = T;\n  static const int capacity = N;\n\n  bool empty() const { return size_ == 0; }\n\n  std::size_t size() const { return size_; }\n\n  void clear() { size_ = 0; }\n\n  T* begin() { return &data_[0]; }\n  T* end() {\n    CHECK_GE(size_, 0);\n    CHECK_LE(size_, N);\n    return &data_[size_];\n  }\n\n  const T* begin() const { return &data_[0]; }\n  const T* end() const {\n    CHECK_GE(size_, 0);\n    CHECK_LE(size_, N);\n    return &data_[size_];\n  }\n\n  const T& Get(int32_t index) const {\n    CHECK_GE(index, 0);\n    CHECK_LT(index, N);\n    return data_[index];\n  }\n\n  T* Mutable(int32_t index) {\n    CHECK_GE(index, 0);\n    CHECK_LT(index, N);\n    return &data_[index];\n  }\n\n  const T* data() const { return &Get(0); }\n  T* data() { return Mutable(0); }\n  T* mut_data() { return Mutable(0); }\n\n  T* Add() {\n    CHECK_GE(size_, 0);\n    CHECK_LT(size_, N);\n    return &data_[size_++];\n  }\n\n private:\n  std::size_t size_;\n  std::array<T, N> data_;\n};\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_FLAT_MSG_H_\n"
  },
  {
    "path": "oneflow/core/intrusive/flat_msg_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/intrusive/flat_msg.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<int field_counter, typename WalkCtxType, typename FieldType>\nstruct DumpFieldName {\n  static void Call(WalkCtxType* ctx, FieldType* field, const char* field_name) {\n    ctx->emplace_back(field_name);\n  }\n};\n\ntemplate<typename T>\nstd::vector<std::string> GetFieldNames(T* flat_msg) {\n  std::vector<std::string> field_names;\n  flat_msg->template __WalkVerboseField__<DumpFieldName>(&field_names);\n  return field_names;\n}\n\ntemplate<typename T>\nvoid CheckSoleFieldName(T* flat_msg, const std::string& expected) {\n  const auto& field_names = GetFieldNames(flat_msg);\n  ASSERT_EQ(field_names.size(), 1);\n  ASSERT_EQ(field_names.at(0), expected);\n}\n// clang-format off\nFLAT_MSG_BEGIN(TestOptional)\n  FLAT_MSG_DEFINE_OPTIONAL(int32_t, bar);\nFLAT_MSG_END(TestOptional)\n// clang-format on\n\nTEST(FlatMsg, optional) {\n  static_assert(std::is_trivial<TestOptional>::value, \"TestOptional is not trivial\");\n  FlatMsg<TestOptional> foo_box;\n  auto& foo = *foo_box.Mutable();\n  ASSERT_TRUE(!foo.has_bar());\n  ASSERT_EQ(foo.bar(), 0);\n  ASSERT_TRUE(GetFieldNames(&foo).empty());\n  *foo.mutable_bar() = 9527;\n  ASSERT_TRUE(foo.has_bar());\n  ASSERT_EQ(foo.bar(), 9527);\n  auto field_names = GetFieldNames(&foo);\n  ASSERT_EQ(field_names.size(), 1);\n  ASSERT_EQ(field_names.at(0), \"bar_\");\n}\n\n// clang-format off\nFLAT_MSG_BEGIN(FooOneof)\n  FLAT_MSG_DEFINE_ONEOF(type,\n      FLAT_MSG_ONEOF_FIELD(int32_t, case_0)\n      FLAT_MSG_ONEOF_FIELD(int64_t, case_1)\n      FLAT_MSG_ONEOF_FIELD(TestOptional, bar));\nFLAT_MSG_END(FooOneof)\n// clang-format on\n\nTEST(FlatMsg, oneof) {\n  FlatMsg<FooOneof> foo_box;\n  auto& foo = *foo_box.Mutable();\n  ASSERT_TRUE(GetFieldNames(&foo).empty());\n  ASSERT_TRUE(!foo.has_bar());\n  ASSERT_EQ(foo.bar().bar(), 0);\n  foo.mutable_case_0();\n  CheckSoleFieldName(&foo, \"case_0_\");\n  ASSERT_TRUE(foo.has_case_0());\n  FooOneof::FLAT_MSG_ONEOF_CASE(type) x = foo.type_case();\n  ASSERT_TRUE(x == FooOneof::FLAT_MSG_ONEOF_CASE_VALUE(case_0));\n  *foo.mutable_case_1() = 9527;\n  CheckSoleFieldName(&foo, \"case_1_\");\n  ASSERT_TRUE(foo.has_case_1());\n  ASSERT_EQ(foo.case_1(), 9527);\n}\n\n// clang-format off\nFLAT_MSG_BEGIN(FooRepeated)\n  FLAT_MSG_DEFINE_REPEATED(char, char_field, 1);\n  FLAT_MSG_DEFINE_REPEATED(TestOptional, bar, 10);\nFLAT_MSG_END(FooRepeated)\n// clang-format on\n\nTEST(FlatMsg, repeated) {\n  FlatMsg<FooRepeated> foo_box;\n  auto& foo = *foo_box.Mutable();\n  ASSERT_EQ(foo.bar_size(), 0);\n  ASSERT_EQ(foo.bar().size(), 0);\n  auto* bar = foo.mutable_bar()->Add();\n  ASSERT_TRUE(!bar->has_bar());\n  ASSERT_EQ(foo.bar_size(), 1);\n  ASSERT_EQ(foo.bar().size(), 1);\n  bar->set_bar(9527);\n  ASSERT_TRUE(bar->has_bar());\n  ASSERT_EQ(bar->bar(), 9527);\n  bar = 
foo.mutable_bar()->Add();\n  ASSERT_TRUE(!bar->has_bar());\n  ASSERT_EQ(foo.bar_size(), 2);\n  ASSERT_EQ(foo.bar().size(), 2);\n  bar->set_bar(9528);\n  for (const auto& x : foo.bar()) { ASSERT_TRUE(x.has_bar()); }\n  foo.clear_bar();\n  ASSERT_EQ(foo.bar_size(), 0);\n}\n\n// clang-format off\ntemplate<int N>\nFLAT_MSG_BEGIN(TestTemplateFlatMsg);\n  FLAT_MSG_DEFINE_REPEATED(char, char_field, N);\nFLAT_MSG_END(TestTemplateFlatMsg);\n// clang-format on\n\nTEST(FlatMsg, flat_msg_template) {\n  FlatMsg<TestTemplateFlatMsg<1024>> foo;\n  ASSERT_TRUE(foo.Get().char_field().empty());\n}\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/intrusive/flat_msg_view.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_FLAT_MSG_VIEW_H_\n#define ONEFLOW_CORE_INTRUSIVE_FLAT_MSG_VIEW_H_\n\n#include <vector>\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/intrusive/dss.h\"\n#include \"oneflow/core/intrusive/flat_msg.h\"\n#include \"oneflow/core/intrusive/struct_traits.h\"\n#include \"oneflow/core/intrusive/static_counter.h\"\n\nnamespace oneflow {\n\n#define FLAT_MSG_VIEW_BEGIN(struct_name)                    \\\n  struct struct_name final {                                \\\n    using self_type = struct_name;                          \\\n    static const bool __is_flat_message_view_type__ = true; \\\n    FLAT_MSG_VIEW_DEFINE_BASIC_METHODS(struct_name);        \\\n                                                            \\\n   public:                                                  \\\n    DEFINE_STATIC_COUNTER(field_counter);                   \\\n    DSS_BEGIN(STATIC_COUNTER(field_counter), struct_name);\n\n#define FLAT_MSG_VIEW_END(struct_name)                                                    \\\n  static_assert(__is_flat_message_view_type__, \"this struct is not a flat message view\"); \\\n                                                                                          \\\n public:                                                                                  \\\n  static const int __LastFieldIndex__ = STATIC_COUNTER(field_counter);                    \\\n                                                                                          \\\n public:                                                                                  \\\n  INCREASE_STATIC_COUNTER(field_counter);                                                 \\\n  DSS_END(STATIC_COUNTER(field_counter), \"flat message view\", struct_name);               \\\n  }                                                                                       \\\n  ;\n\n#define FLAT_MSG_VIEW_DEFINE_PATTERN(flat_msg_field_type, field_name)                        \\\n  static_assert(__is_flat_message_view_type__, \"this struct is not a flat message view\");    \\\n  _FLAT_MSG_VIEW_DEFINE_PATTERN(FLAT_MSG_TYPE_CHECK(flat_msg_field_type), field_name);       \\\n                                                                                             \\\n public:                                                                                     \\\n  INCREASE_STATIC_COUNTER(field_counter);                                                    \\\n  FLAT_MSG_VIEW_SPECIALIZE_FIELD_TYPE(STATIC_COUNTER(field_counter), flat_msg_field_type);   \\\n  FLAT_MSG_VIEW_CHECK_LAST_FIELD_TYPE(STATIC_COUNTER(field_counter), flat_msg_field_type);   \\\n  DSS_DEFINE_FIELD(STATIC_COUNTER(field_counter), \"flat message view\", flat_msg_field_type*, \\\n                   OF_PP_CAT(field_name, _));\n\n#define FLAT_MSG_VIEW_DEFINE_REPEATED_PATTERN(flat_msg_field_type, field_name)                  
\\\n  static_assert(__is_flat_message_view_type__, \"this struct is not a flat message view\");       \\\n  _FLAT_MSG_VIEW_DEFINE_REPEATED_PATTERN(FLAT_MSG_TYPE_CHECK(flat_msg_field_type), field_name); \\\n                                                                                                \\\n public:                                                                                        \\\n  INCREASE_STATIC_COUNTER(field_counter);                                                       \\\n  _SPECIALIZE_IS_REPEATED_PATTERN(STATIC_COUNTER(field_counter));                               \\\n  FLAT_MSG_VIEW_SPECIALIZE_FIELD_TYPE(STATIC_COUNTER(field_counter), flat_msg_field_type);      \\\n  FLAT_MSG_VIEW_CHECK_LAST_FIELD_TYPE(STATIC_COUNTER(field_counter), flat_msg_field_type);      \\\n  DSS_DEFINE_FIELD(STATIC_COUNTER(field_counter), \"flat message view\",                          \\\n                   FlatMsgViewPatternVec<flat_msg_field_type>, OF_PP_CAT(field_name, _));\n\n// details\n\n#define _FLAT_MSG_VIEW_DEFINE_PATTERN(field_type, field_name)                \\\n public:                                                                     \\\n  const field_type& field_name() const { return *OF_PP_CAT(field_name, _); } \\\n                                                                             \\\n private:                                                                    \\\n  const field_type* OF_PP_CAT(field_name, _);\n\n#define _FLAT_MSG_VIEW_DEFINE_REPEATED_PATTERN(field_type, field_name)                         \\\n public:                                                                                       \\\n  const field_type& field_name(int i) const { return *OF_PP_CAT(field_name, _).at(i); }        \\\n  std::size_t OF_PP_CAT(field_name, _size)() const { return OF_PP_CAT(field_name, _).size(); } \\\n                                                                                               \\\n private:                                                                                      \\\n  FlatMsgViewPatternVec<field_type> OF_PP_CAT(field_name, _);\n\n#define FLAT_MSG_VIEW_DEFINE_BASIC_METHODS(T)        \\\n public:                                             \\\n  template<int field_index, typename Enabled = void> \\\n  struct IsRepeatedPattern {                         \\\n    static const bool value = false;                 \\\n  };                                                 \\\n                                                     \\\n private:                                            \\\n  template<int field_index, typename Enabled = void> \\\n  struct __FlatMsgViewFieldType__ {                  \\\n    struct type {};                                  \\\n  };\n\n#define FLAT_MSG_VIEW_SPECIALIZE_FIELD_TYPE(field_index, field_type) \\\n private:                                                            \\\n  template<typename Enabled>                                         \\\n  struct __FlatMsgViewFieldType__<field_index, Enabled> {            \\\n    using type = field_type;                                         \\\n  };\n\n#define FLAT_MSG_VIEW_CHECK_LAST_FIELD_TYPE(field_index, field_type)                            \\\n private:                                                                                       \\\n  static void OF_PP_CAT(__CheckLastFieldType__, __LINE__)() {                                   \\\n    static_assert(                                                                              \\\n        
!(IsRepeatedPattern<field_index - 1>::value                                             \\\n          && std::is_same<__FlatMsgViewFieldType__<field_index - 1>::type, field_type>::value), \\\n        \"a repeated pattern shouldn't be followed by a pattern of the same type\");              \\\n  }\n\n#define _SPECIALIZE_IS_REPEATED_PATTERN(field_index) \\\n  template<typename Enabled>                         \\\n  struct IsRepeatedPattern<field_index, Enabled> {   \\\n    static const bool value = true;                  \\\n  }\n\ntemplate<typename T>\nstruct FlatMsgViewPatternVec {\n  using value_type = T;\n\n  void __Init__() { new (&vec_buffer_) Vec(); }\n  void __Delete__() { mut_vec()->~Vec(); }\n\n  const T* at(int index) const { return vec().at(index); }\n  size_t size() const { return vec().size(); }\n  void clear() { mut_vec()->clear(); }\n  void emplace_back(const T* ptr) { mut_vec()->emplace_back(ptr); }\n\n private:\n  using Vec = std::vector<const T*>;\n\n  Vec* mut_vec() {\n    Vec* __attribute__((__may_alias__)) ptr = reinterpret_cast<Vec*>(&vec_buffer_);\n    return ptr;\n  }\n\n  const Vec& vec() const {\n    const Vec* __attribute__((__may_alias__)) ptr = reinterpret_cast<const Vec*>(&vec_buffer_);\n    return *ptr;\n  }\n\n  union {\n    char vec_buffer_[sizeof(Vec)];\n    int64_t align64_;\n  };\n};\n\ntemplate<typename FlatMsgViewT, typename FlatMsgOneofField, typename OneofValueType>\nclass FlatMsgViewFieldCtx {\n public:\n  using flat_msg_view_type = FlatMsgViewT;\n  static_assert(std::is_same<OneofValueType, typename FlatMsgOneofField::struct_type>::value,\n                \"invalid view match\");\n  FlatMsgViewFieldCtx(const FlatMsgViewFieldCtx&) = delete;\n  FlatMsgViewFieldCtx(FlatMsgViewFieldCtx&&) = delete;\n  FlatMsgViewFieldCtx(const OneofValueType* repeated_flag_msg, std::size_t size)\n      : repeated_flag_msg_(repeated_flag_msg), token_index_(0), size_(size) {}\n  ~FlatMsgViewFieldCtx() = default;\n\n  const OneofValueType* GetFlatMsg() const { return repeated_flag_msg_ + token_index_; }\n  typename FlatMsgOneofField::field_type* GetOneof() const {\n    return FlatMsgOneofField::FieldPtr4StructPtr(GetFlatMsg());\n  }\n  bool is_token_index_valid() const { return token_index_ < size_; }\n  void increase_token_index() { ++token_index_; }\n  int32_t token_count() const { return token_index_; }\n\n private:\n  const OneofValueType* repeated_flag_msg_;\n  int32_t token_index_;\n  const std::size_t size_;\n};\n\ntemplate<bool is_repeated_pattern, typename WalkCtxType, typename FieldPtrT>\nstruct _FlatMsgViewFieldMatcher {};\n\ntemplate<int field_counter, typename WalkCtxType, typename FieldPtrT>\nstruct FlatMsgViewFieldMatcher {\n  static const bool is_repeated_pattern =\n      WalkCtxType::flat_msg_view_type::template IsRepeatedPattern<field_counter>::value;\n  // returns true if an error occurred\n  static bool Call(WalkCtxType* ctx, FieldPtrT* field) {\n    return _FlatMsgViewFieldMatcher<is_repeated_pattern, WalkCtxType, FieldPtrT>::Call(ctx, field);\n  }\n};\n\ntemplate<typename WalkCtxType, typename FieldPtrT>\nstruct _FlatMsgViewFieldMatcher<false, WalkCtxType, FieldPtrT> {\n  // returns true if an error occurred\n  static bool Call(WalkCtxType* ctx, FieldPtrT* field) {\n    if (!ctx->is_token_index_valid()) { return true; }\n    using ConstFieldType = typename std::remove_pointer<FieldPtrT>::type;\n    using FieldType = typename std::remove_const<ConstFieldType>::type;\n    const auto* oneof = ctx->GetOneof();\n    if (!oneof->template HasField<FieldType>()) { 
return true; }\n    *field = &oneof->template GetField<FieldType>();\n    ctx->increase_token_index();\n    return false;\n  }\n};\n\ntemplate<typename WalkCtxType, typename FieldPtrT>\nstruct _FlatMsgViewFieldMatcher<true, WalkCtxType, FieldPtrT> {\n  // returns true if an error occurred\n  static bool Call(WalkCtxType* ctx, FieldPtrT* field) {\n    field->clear();\n    using FieldType = typename FieldPtrT::value_type;\n    while (ctx->is_token_index_valid()) {\n      const auto* oneof = ctx->GetOneof();\n      if (!oneof->template HasField<FieldType>()) { break; }\n      field->emplace_back(&oneof->template GetField<FieldType>());\n      ctx->increase_token_index();\n    }\n    return false;\n  }\n};\n\ntemplate<typename FlatMsgViewT, typename FlatMsgOneofField, typename ValueType>\nstruct FlatMsgViewUtil {\n  static_assert(std::is_same<ValueType, typename FlatMsgOneofField::struct_type>::value,\n                \"invalid view match\");\n  static bool Match(FlatMsgViewT* flat_msg_view, const ValueType* data_ptr, std::size_t size) {\n    FlatMsgViewFieldCtx<FlatMsgViewT, FlatMsgOneofField, ValueType> ctx(data_ptr, size);\n    bool ret = !flat_msg_view->template __WalkFieldUntil__<FlatMsgViewFieldMatcher>(&ctx);\n    if (ret) {\n      if (FlatMsgViewT::template IsRepeatedPattern<FlatMsgViewT::__LastFieldIndex__>::value) {\n        ret = (ctx.token_count() == size)\n              || /* last repeated field empty */ (ctx.token_count() - 1 == size);\n      } else {\n        ret = (ctx.token_count() == size);\n      }\n    }\n    return ret;\n  }\n};\n\ntemplate<typename FlatMsgViewT, typename ValueType, typename ContainerT, typename Enabled = void>\nstruct FlatMsgViewContainerUtil {\n  using FlatMsgOneofField = intrusive::OffsetStructField<ValueType, typename ValueType::__OneofType,\n                                                         ValueType::__kDssFieldOffset>;\n  static bool Match(FlatMsgViewT* self, const ContainerT& container) {\n    return FlatMsgViewUtil<FlatMsgViewT, FlatMsgOneofField, typename ContainerT::value_type>::Match(\n        self, container.data(), container.size());\n  }\n};\n\ntemplate<typename FlatMsgViewT, typename ValueType, typename Enabled>\nstruct FlatMsgViewContainerUtil<FlatMsgViewT, ValueType, std::vector<FlatMsg<ValueType>>, Enabled> {\n  using FlatMsgOneofField = intrusive::OffsetStructField<ValueType, typename ValueType::__OneofType,\n                                                         ValueType::__kDssFieldOffset>;\n  static_assert(sizeof(ValueType) == sizeof(FlatMsg<ValueType>), \"\");\n  static_assert(alignof(ValueType) == alignof(FlatMsg<ValueType>), \"\");\n  static bool Match(FlatMsgViewT* self, const std::vector<FlatMsg<ValueType>>& container) {\n    return FlatMsgViewUtil<FlatMsgViewT, FlatMsgOneofField, ValueType>::Match(\n        self, &container.data()->Get(), container.size());\n  }\n};\n\ntemplate<bool is_repeated_pattern, typename FieldPtrT>\nstruct _FlatMsgViewFieldInit {};\n\ntemplate<int field_counter, typename WalkCtxType, typename FieldPtrT>\nstruct FlatMsgViewFieldInit {\n  static const bool is_repeated_pattern =\n      WalkCtxType::template IsRepeatedPattern<field_counter>::value;\n  static void Call(WalkCtxType* ctx, FieldPtrT* field) {\n    _FlatMsgViewFieldInit<is_repeated_pattern, FieldPtrT>::Call(field);\n  }\n};\n\ntemplate<typename FieldPtrT>\nstruct _FlatMsgViewFieldInit<false, FieldPtrT> {\n  static void Call(FieldPtrT* field) {}\n};\n\ntemplate<typename FieldPtrT>\nstruct _FlatMsgViewFieldInit<true, FieldPtrT> {\n  static 
void Call(FieldPtrT* field) { field->__Init__(); }\n};\n\ntemplate<bool is_repeated_pattern, typename FieldPtrT>\nstruct _FlatMsgViewFieldDelete {};\n\ntemplate<int field_counter, typename WalkCtxType, typename FieldPtrT>\nstruct FlatMsgViewFieldDelete {\n  static const bool is_repeated_pattern =\n      WalkCtxType::template IsRepeatedPattern<field_counter>::value;\n  static void Call(WalkCtxType* ctx, FieldPtrT* field) {\n    _FlatMsgViewFieldDelete<is_repeated_pattern, FieldPtrT>::Call(field);\n  }\n};\n\ntemplate<typename FieldPtrT>\nstruct _FlatMsgViewFieldDelete<false, FieldPtrT> {\n  static void Call(FieldPtrT* field) {}\n};\n\ntemplate<typename FieldPtrT>\nstruct _FlatMsgViewFieldDelete<true, FieldPtrT> {\n  static void Call(FieldPtrT* field) { field->__Delete__(); }\n};\n\ntemplate<typename T>\nstruct FlatMsgView final {\n  FlatMsgView(const FlatMsgView&) = delete;\n  FlatMsgView(FlatMsgView&&) = delete;\n  static_assert(T::__is_flat_message_view_type__, \"T is not a flat message view type\");\n  FlatMsgView() { view_.template __WalkField__<FlatMsgViewFieldInit>(&view_); }\n  template<typename RepeatedFlatMsgT>\n  explicit FlatMsgView(const RepeatedFlatMsgT& repeated_flat_msg) {\n    view_.template __WalkField__<FlatMsgViewFieldInit>(&view_);\n    CHECK(this->template Match(repeated_flat_msg));\n  }\n  ~FlatMsgView() { view_.template __ReverseWalkField__<FlatMsgViewFieldDelete>(&view_); }\n\n  const T& operator*() const { return view_; }\n  T& operator*() { return view_; }\n  const T* operator->() const { return &view_; }\n  T* operator->() { return &view_; }\n\n  const T& Get() const { return view_; }\n  T* Mutable() { return &view_; }\n\n  template<typename RepeatedFlatMsgT>\n  bool Match(const RepeatedFlatMsgT& repeated_flat_msg) {\n    using OneofType = typename RepeatedFlatMsgT::value_type::self_value_type;\n    return FlatMsgViewContainerUtil<T, OneofType, RepeatedFlatMsgT>::Match(&view_,\n                                                                           repeated_flat_msg);\n  }\n\n private:\n  union {\n    T view_;\n  };\n};\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_FLAT_MSG_VIEW_H_\n"
  },
  {
    "path": "oneflow/core/intrusive/flat_msg_view_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/intrusive/flat_msg_view.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace test {\n\nnamespace {\n\n// clang-format off\nFLAT_MSG_BEGIN(VariantFoo);\n  FLAT_MSG_DEFINE_STRICT_ONEOF(_,\n    FLAT_MSG_ONEOF_FIELD(int8_t, int8_value)\n    FLAT_MSG_ONEOF_FIELD(int16_t, int16_value)\n    FLAT_MSG_ONEOF_FIELD(int32_t, int32_value)\n    FLAT_MSG_ONEOF_FIELD(float, float_value));\nFLAT_MSG_END(VariantFoo);\n// clang-format on\n\n// clang-format off\nFLAT_MSG_BEGIN(VariantList);\n  FLAT_MSG_DEFINE_REPEATED(VariantFoo, foo, 16);\nFLAT_MSG_END(VariantList);\n// clang-format on\n\n// clang-format off\nFLAT_MSG_VIEW_BEGIN(ViewFoo);\n  FLAT_MSG_VIEW_DEFINE_PATTERN(int32_t, int32_value);\n  FLAT_MSG_VIEW_DEFINE_PATTERN(int16_t, int16_value);\n  FLAT_MSG_VIEW_DEFINE_PATTERN(float, float_value);\nFLAT_MSG_VIEW_END(ViewFoo);\n// clang-format on\n\nTEST(FlatMsgView, match_success) {\n  FlatMsg<VariantList> variant_list;\n  variant_list.Mutable()->mutable_foo()->Add()->set_int32_value(30);\n  variant_list.Mutable()->mutable_foo()->Add()->set_int16_value(40);\n  variant_list.Mutable()->mutable_foo()->Add()->set_float_value(50.0);\n  FlatMsgView<ViewFoo> view;\n  ASSERT_TRUE(view.template Match(variant_list.Get().foo()));\n  ASSERT_EQ(view->int32_value(), 30);\n  ASSERT_EQ(view->int16_value(), 40);\n  ASSERT_EQ(view->float_value(), 50.0);\n}\n\nTEST(FlatMsgView, match_failed) {\n  FlatMsg<VariantList> variant_list;\n  variant_list.Mutable()->mutable_foo()->Add()->set_int16_value(40);\n  variant_list.Mutable()->mutable_foo()->Add()->set_int32_value(30);\n  variant_list.Mutable()->mutable_foo()->Add()->set_float_value(50.0);\n  FlatMsgView<ViewFoo> view;\n  ASSERT_TRUE(!view.template Match(variant_list.Get().foo()));\n}\n\nTEST(FlatMsgView, match_success_vector) {\n  std::vector<FlatMsg<VariantFoo>> variant_list(3);\n  variant_list.at(0)->set_int32_value(30);\n  variant_list.at(1)->set_int16_value(40);\n  variant_list.at(2)->set_float_value(50.0);\n  FlatMsgView<ViewFoo> view;\n  ASSERT_TRUE(view.template Match(variant_list));\n  ASSERT_EQ(view->int32_value(), 30);\n  ASSERT_EQ(view->int16_value(), 40);\n  ASSERT_EQ(view->float_value(), 50.0);\n}\n\nTEST(FlatMsgView, match_failed_vector) {\n  std::vector<FlatMsg<VariantFoo>> variant_list(3);\n  variant_list.at(0)->set_int16_value(40);\n  variant_list.at(1)->set_int32_value(30);\n  variant_list.at(2)->set_float_value(50.0);\n  FlatMsgView<ViewFoo> view;\n  ASSERT_TRUE(!view.template Match(variant_list));\n}\n\n// clang-format off\nFLAT_MSG_VIEW_BEGIN(RepeatedFoo);\n  FLAT_MSG_VIEW_DEFINE_PATTERN(int32_t, int32_value);\n  FLAT_MSG_VIEW_DEFINE_REPEATED_PATTERN(int16_t, int16_value);\n  FLAT_MSG_VIEW_DEFINE_PATTERN(float, float_value);\nFLAT_MSG_VIEW_END(RepeatedFoo);\n// clang-format on\n\nTEST(FlatMsgView, repeated_empty) {\n  std::vector<FlatMsg<VariantFoo>> 
variant_list(2);\n  variant_list.at(0)->set_int32_value(40);\n  variant_list.at(1)->set_float_value(50.0);\n  FlatMsgView<RepeatedFoo> view;\n  ASSERT_TRUE(view.Match(variant_list));\n  ASSERT_EQ(view->int16_value_size(), 0);\n}\n\nTEST(FlatMsgView, repeated_empty_failed) {\n  std::vector<FlatMsg<VariantFoo>> variant_list(2);\n  variant_list.at(0)->set_float_value(50.0);\n  variant_list.at(1)->set_int32_value(40);\n  FlatMsgView<RepeatedFoo> view;\n  ASSERT_TRUE(!view.Match(variant_list));\n}\n\nTEST(FlatMsgView, repeated_one) {\n  std::vector<FlatMsg<VariantFoo>> variant_list(3);\n  variant_list.at(0)->set_int32_value(40);\n  variant_list.at(1)->set_int16_value(45);\n  variant_list.at(2)->set_float_value(50.0);\n  FlatMsgView<RepeatedFoo> view;\n  ASSERT_TRUE(view.Match(variant_list));\n  ASSERT_EQ(view->int16_value_size(), 1);\n  ASSERT_EQ(view->int16_value(0), 45);\n}\n\nTEST(FlatMsgView, repeated_one_failed) {\n  std::vector<FlatMsg<VariantFoo>> variant_list(3);\n  variant_list.at(0)->set_int32_value(40);\n  variant_list.at(1)->set_float_value(50.0);\n  variant_list.at(2)->set_int16_value(45);\n  FlatMsgView<RepeatedFoo> view;\n  ASSERT_TRUE(!view.Match(variant_list));\n}\n\nTEST(FlatMsgView, repeated_many) {\n  std::vector<FlatMsg<VariantFoo>> variant_list(4);\n  variant_list.at(0)->set_int32_value(40);\n  variant_list.at(1)->set_int16_value(45);\n  variant_list.at(2)->set_int16_value(45);\n  variant_list.at(3)->set_float_value(50.0);\n  FlatMsgView<RepeatedFoo> view;\n  ASSERT_TRUE(view.Match(variant_list));\n  ASSERT_EQ(view->int16_value_size(), 2);\n  ASSERT_EQ(view->int16_value(0), 45);\n  ASSERT_EQ(view->int16_value(1), 45);\n}\n\nTEST(FlatMsgView, repeated_many_failed) {\n  std::vector<FlatMsg<VariantFoo>> variant_list(4);\n  variant_list.at(0)->set_int32_value(40);\n  variant_list.at(1)->set_int16_value(45);\n  variant_list.at(2)->set_float_value(45.0);\n  variant_list.at(3)->set_float_value(50.0);\n  FlatMsgView<RepeatedFoo> view;\n  ASSERT_TRUE(!view.Match(variant_list));\n}\n\n// clang-format off\nFLAT_MSG_VIEW_BEGIN(LastFieldRepeatedFoo);\n  FLAT_MSG_VIEW_DEFINE_PATTERN(int32_t, int32_value);\n  FLAT_MSG_VIEW_DEFINE_PATTERN(float, float_value);\n  FLAT_MSG_VIEW_DEFINE_REPEATED_PATTERN(int16_t, int16_value);\nFLAT_MSG_VIEW_END(LastFieldRepeatedFoo);\n// clang-format on\n\nTEST(FlatMsgView, last_field_repeated_empty) {\n  std::vector<FlatMsg<VariantFoo>> variant_list(2);\n  variant_list.at(0)->set_int32_value(40);\n  variant_list.at(1)->set_float_value(50.0);\n  FlatMsgView<LastFieldRepeatedFoo> view;\n  ASSERT_TRUE(view.Match(variant_list));\n  ASSERT_EQ(view->int16_value_size(), 0);\n}\n\nTEST(FlatMsgView, last_field_repeated_empty_failed) {\n  std::vector<FlatMsg<VariantFoo>> variant_list(2);\n  variant_list.at(0)->set_float_value(50.0);\n  variant_list.at(1)->set_int32_value(40);\n  FlatMsgView<LastFieldRepeatedFoo> view;\n  ASSERT_TRUE(!view.Match(variant_list));\n}\n\nTEST(FlatMsgView, last_field_repeated_one) {\n  std::vector<FlatMsg<VariantFoo>> variant_list(3);\n  variant_list.at(0)->set_int32_value(40);\n  variant_list.at(1)->set_float_value(50.0);\n  variant_list.at(2)->set_int16_value(45);\n  FlatMsgView<LastFieldRepeatedFoo> view;\n  ASSERT_TRUE(view.Match(variant_list));\n  ASSERT_EQ(view->int16_value_size(), 1);\n  ASSERT_EQ(view->int16_value(0), 45);\n}\n\nTEST(FlatMsgView, last_field_repeated_one_failed) {\n  std::vector<FlatMsg<VariantFoo>> variant_list(3);\n  variant_list.at(0)->set_int32_value(40);\n  variant_list.at(1)->set_int16_value(45);\n  
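// the int16 token arrives before the float token, so the scalar float pattern fails to match\n  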
variant_list.at(2)->set_float_value(50.0);\n  FlatMsgView<LastFieldRepeatedFoo> view;\n  ASSERT_TRUE(!view.Match(variant_list));\n}\n\nTEST(FlatMsgView, last_field_repeated_many) {\n  std::vector<FlatMsg<VariantFoo>> variant_list(4);\n  variant_list.at(0)->set_int32_value(40);\n  variant_list.at(1)->set_float_value(50.0);\n  variant_list.at(2)->set_int16_value(45);\n  variant_list.at(3)->set_int16_value(45);\n  FlatMsgView<LastFieldRepeatedFoo> view;\n  ASSERT_TRUE(view.Match(variant_list));\n  ASSERT_EQ(view->int16_value_size(), 2);\n  ASSERT_EQ(view->int16_value(0), 45);\n  ASSERT_EQ(view->int16_value(1), 45);\n}\n\nTEST(FlatMsgView, last_field_repeated_many_failed) {\n  std::vector<FlatMsg<VariantFoo>> variant_list(4);\n  variant_list.at(0)->set_int32_value(40);\n  variant_list.at(1)->set_int16_value(45);\n  variant_list.at(2)->set_float_value(50.0);\n  variant_list.at(3)->set_int16_value(45);\n  FlatMsgView<LastFieldRepeatedFoo> view;\n  ASSERT_TRUE(!view.Match(variant_list));\n}\n\n}  // namespace\n\n}  // namespace test\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/intrusive/for_each.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_FOR_EACH_H_\n#define ONEFLOW_CORE_INTRUSIVE_FOR_EACH_H_\n\n#include \"oneflow/core/intrusive/list_hook.h\"\n#include \"oneflow/core/intrusive/struct_traits.h\"\n\nnamespace oneflow {\nnamespace intrusive {\n\n#define INTRUSIVE_FOR_EACH(elem, container) \\\n  _INTRUSIVE_FOR_EACH(std::remove_pointer<decltype(container)>::type, elem, container)\n\n#define INTRUSIVE_FOR_EACH_PTR(elem, container) \\\n  _INTRUSIVE_FOR_EACH_PTR(std::remove_pointer<decltype(container)>::type, elem, container)\n\n#define INTRUSIVE_UNSAFE_FOR_EACH_PTR(elem, container) \\\n  _INTRUSIVE_UNSAFE_FOR_EACH_PTR(std::remove_pointer<decltype(container)>::type, elem, container)\n\n// details\n\n#define _INTRUSIVE_FOR_EACH(container_type, elem, container)                     \\\n  for (intrusive::shared_ptr<typename container_type::value_type> elem,          \\\n       *end_if_not_null = nullptr;                                               \\\n       end_if_not_null == nullptr; end_if_not_null = nullptr, ++end_if_not_null) \\\n  LIST_HOOK_FOR_EACH_WITH_EXPR(                                                  \\\n      (intrusive::OffsetStructField<                                             \\\n          typename container_type, intrusive::ListHook,                          \\\n          container_type::IteratorHookOffset()>::FieldPtr4StructPtr(container)), \\\n      container_type::iterator_struct_field, elem_ptr, (elem.Reset(elem_ptr), true))\n\n#define _INTRUSIVE_FOR_EACH_PTR(container_type, elem, container)                                \\\n  LIST_HOOK_FOR_EACH((intrusive::OffsetStructField<                                             \\\n                         typename container_type, intrusive::ListHook,                          \\\n                         container_type::IteratorHookOffset()>::FieldPtr4StructPtr(container)), \\\n                     container_type::iterator_struct_field, elem)\n\n#define _INTRUSIVE_UNSAFE_FOR_EACH_PTR(container_type, elem, container)          \\\n  LIST_HOOK_UNSAFE_FOR_EACH(                                                     \\\n      (intrusive::OffsetStructField<                                             \\\n          typename container_type, intrusive::ListHook,                          \\\n          container_type::IteratorHookOffset()>::FieldPtr4StructPtr(container)), \\\n      container_type::iterator_struct_field, elem)\n\n}  // namespace intrusive\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_FOR_EACH_H_\n"
  },
  {
    "path": "oneflow/core/intrusive/force_standard_layout.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_FORCE_STANDARD_LAYOUT_H_\n#define ONEFLOW_CORE_INTRUSIVE_FORCE_STANDARD_LAYOUT_H_\n\nnamespace oneflow {\nnamespace intrusive {\n\ntemplate<typename T>\nclass ForceStandardLayout final {\n public:\n  ForceStandardLayout() { new (&object_) T(); }\n  template<typename Arg, typename = typename std::enable_if<!std::is_same<\n                             ForceStandardLayout, typename std::decay<Arg>::type>::value>::type>\n  explicit ForceStandardLayout(Arg&& arg) {\n    new (&object_) T(std::forward<Arg>(arg));\n  }\n  template<typename Arg0, typename Arg1, typename... Args>\n  ForceStandardLayout(Arg0&& arg0, Arg1&& arg1, Args&&... args) {\n    new (&object_)\n        T(std::forward<Arg0>(arg0), std::forward<Arg1>(arg1), std::forward<Args>(args)...);\n  }\n\n  ~ForceStandardLayout() { Mutable()->~T(); }\n\n  ForceStandardLayout(const ForceStandardLayout& other) { new (&object_) T(other.Get()); }\n  ForceStandardLayout(ForceStandardLayout&& other) {\n    new (&object_) T(std::move(*other.Mutable()));\n  }\n\n  ForceStandardLayout& operator=(const ForceStandardLayout& other) {\n    *Mutable() = other.Get();\n    return *this;\n  }\n  ForceStandardLayout& operator=(ForceStandardLayout&& other) {\n    *Mutable() = std::move(*other.Mutable());\n    return *this;\n  }\n\n  const T& Get() const {\n    const auto* __attribute__((__may_alias__)) ptr = reinterpret_cast<const T*>(&object_[0]);\n    return *ptr;\n  }\n\n  T* Mutable() {\n    auto* __attribute__((__may_alias__)) ptr = reinterpret_cast<T*>(&object_[0]);\n    return ptr;\n  }\n\n private:\n  alignas(T) char object_[sizeof(T)];\n};\n\n}  // namespace intrusive\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_FORCE_STANDARD_LAYOUT_H_\n"
  },
  {
    "path": "oneflow/core/intrusive/force_standard_layout_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// include sstream first to avoid some compiling error\n// caused by the following trick\n// reference: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=65899\n#include <sstream>\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/intrusive/force_standard_layout.h\"\n\nnamespace oneflow {\n\nnamespace intrusive {\n\nnamespace test {\n\nconstexpr const int unstandard_value = 999;\nconstexpr const int standard_value = 666;\n\nstruct Unstandard {\n public:\n  explicit Unstandard(int* ptr) : x(unstandard_value), ptr_(ptr) {}\n  ~Unstandard() { *ptr_ = unstandard_value; }\n\n  Unstandard(const Unstandard&) = default;\n  Unstandard(Unstandard&&) = default;\n  Unstandard& operator=(const Unstandard&) = default;\n  Unstandard& operator=(Unstandard&&) = default;\n\n  int* ptr() const { return ptr_; }\n  void set_ptr(int* val) { ptr_ = val; }\n\n  int x;\n\n private:\n  int* ptr_;\n};\n\nTEST(ForceStandardLayout, default_constructor) {\n  int value = standard_value;\n  ForceStandardLayout<Unstandard> sl(&value);\n  ASSERT_EQ(sl.Get().x, unstandard_value);\n  ASSERT_EQ(sl.Get().ptr(), &value);\n}\n\nTEST(ForceStandardLayout, copy_constructor) {\n  int value = standard_value;\n  const ForceStandardLayout<Unstandard> const_sl(&value);\n  ForceStandardLayout<Unstandard> sl(const_sl);  // NOLINT\n  ASSERT_EQ(sl.Get().x, unstandard_value);\n  ASSERT_EQ(sl.Get().ptr(), &value);\n}\n\nTEST(ForceStandardLayout, move_constructor) {\n  int value = standard_value;\n  ForceStandardLayout<Unstandard> old_sl(&value);\n  ForceStandardLayout<Unstandard> sl(std::move(old_sl));\n  ASSERT_EQ(sl.Get().x, unstandard_value);\n  ASSERT_EQ(sl.Get().ptr(), &value);\n}\n\nTEST(ForceStandardLayout, copy_assign) {\n  int value = standard_value;\n  const ForceStandardLayout<Unstandard> const_sl(&value);\n  ForceStandardLayout<Unstandard> sl = const_sl;  // NOLINT\n  ASSERT_EQ(sl.Get().x, unstandard_value);\n  ASSERT_EQ(sl.Get().ptr(), &value);\n}\n\nTEST(ForceStandardLayout, move_assign) {\n  int value = standard_value;\n  ForceStandardLayout<Unstandard> sl = ForceStandardLayout<Unstandard>(&value);\n  ASSERT_EQ(sl.Get().x, unstandard_value);\n  ASSERT_EQ(sl.Get().ptr(), &value);\n}\n\nTEST(ForceStandardLayout, destructor) {\n  int value = standard_value;\n  { ForceStandardLayout<Unstandard> sl(&value); }\n  ASSERT_EQ(value, unstandard_value);\n}\n\n}  // namespace test\n\n}  // namespace intrusive\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/intrusive/head_free_list.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_HEAD_FREE_LIST_H_\n#define ONEFLOW_CORE_INTRUSIVE_HEAD_FREE_LIST_H_\n\n#include \"oneflow/core/intrusive/ref.h\"\n#include \"oneflow/core/intrusive/list_hook.h\"\n#include \"oneflow/core/intrusive/struct_traits.h\"\n#include \"oneflow/core/intrusive/reflective.h\"\n\nnamespace oneflow {\nnamespace intrusive {\n\ntemplate<typename ValueHookField, int field_counter>\nclass HeadFreeList {\n public:\n  static_assert(std::is_same<typename ValueHookField::field_type, intrusive::ListHook>::value, \"\");\n  HeadFreeList(const HeadFreeList&) = delete;\n  HeadFreeList(HeadFreeList&&) = delete;\n  HeadFreeList() { this->__Init__(); }\n  ~HeadFreeList() { this->Clear(); }\n\n  using value_type = typename ValueHookField::struct_type;\n  using iterator_struct_field = ValueHookField;\n\n  // field_counter is last field_number\n  static const int field_number_in_countainter = field_counter + 1;\n\n  template<typename Enabled = void>\n  static constexpr int IteratorHookOffset() {\n    return offsetof(HeadFreeList, list_head_)\n           + intrusive::ListHead<ValueHookField>::IteratorHookOffset();\n  }\n\n  std::size_t size() const { return list_head_.size(); }\n  bool empty() const { return list_head_.empty(); }\n\n  void __Init__() {\n    list_head_.__Init__();\n    static_assert(\n        std::is_same<HeadFreeList, REFLECTIVE_FIELD_TYPE(typename value_type,\n                                                         field_number_in_countainter)>::value,\n        \"It's invalid to define fields between definition of head-free list type and definition of \"\n        \"head-free list field.\");\n    using ThisInContainer =\n        OffsetStructField<value_type, HeadFreeList,\n                          REFLECTIVE_FIELD_OFFSET(value_type, field_number_in_countainter)>;\n    container_ = ThisInContainer::StructPtr4FieldPtr(this);\n  }\n\n  value_type* Begin() {\n    if (list_head_.empty()) { return nullptr; }\n    return list_head_.Begin();\n  }\n  value_type* Next(value_type* ptr) {\n    if (ptr == nullptr) { return nullptr; }\n    value_type* next = list_head_.Next(ptr);\n    if (next == list_head_.End()) { return nullptr; }\n    return next;\n  }\n  value_type* Last() {\n    if (list_head_.empty()) { return nullptr; }\n    return list_head_.Last();\n  }\n  constexpr value_type* End() const { return nullptr; }\n\n  void MoveToDstBack(value_type* ptr, HeadFreeList* dst) {\n    list_head_.MoveToDstBack(ptr, &dst->list_head_);\n    MoveReference(ptr, dst);\n  }\n  void MoveToDstFront(value_type* ptr, HeadFreeList* dst) {\n    list_head_.MoveToDstFront(ptr, &dst->list_head_);\n    MoveReference(ptr, dst);\n  }\n  value_type* MoveFrontToDstBack(HeadFreeList* dst) {\n    value_type* begin = list_head_.Begin();\n    MoveToDstBack(begin, dst);\n    return begin;\n  }\n  value_type* MoveBackToDstBack(HeadFreeList* dst) {\n    value_type* begin = list_head_.Last();\n    
MoveToDstBack(begin, dst);\n    return begin;\n  }\n\n  void PushBack(value_type* ptr) {\n    list_head_.PushBack(ptr);\n    if (container_ != ptr) { Ref::IncreaseRef(ptr); }\n  }\n\n  void PushFront(value_type* ptr) {\n    list_head_.PushFront(ptr);\n    if (container_ != ptr) { Ref::IncreaseRef(ptr); }\n  }\n\n  void EmplaceBack(intrusive::shared_ptr<value_type>&& ptr) {\n    value_type* raw_ptr = nullptr;\n    if (container_ != ptr.Mutable()) {\n      ptr.__UnsafeMoveTo__(&raw_ptr);\n    } else {\n      raw_ptr = ptr.Mutable();\n    }\n    list_head_.PushBack(raw_ptr);\n  }\n\n  void EmplaceFront(intrusive::shared_ptr<value_type>&& ptr) {\n    value_type* raw_ptr = nullptr;\n    if (container_ != ptr.Mutable()) {\n      ptr.__UnsafeMoveTo__(&raw_ptr);\n    } else {\n      raw_ptr = ptr.Mutable();\n    }\n    list_head_.PushFront(raw_ptr);\n  }\n\n  intrusive::shared_ptr<value_type> Erase(value_type* ptr) {\n    list_head_.Erase(ptr);\n    if (container_ != ptr) {\n      return intrusive::shared_ptr<value_type>::__UnsafeMove__(ptr);\n    } else {\n      return intrusive::shared_ptr<value_type>(ptr);\n    }\n  }\n\n  intrusive::shared_ptr<value_type> PopBack() {\n    value_type* raw_ptr = nullptr;\n    if (!list_head_.empty()) { raw_ptr = list_head_.PopBack(); }\n    if (container_ != raw_ptr) {\n      return intrusive::shared_ptr<value_type>::__UnsafeMove__(raw_ptr);\n    } else {\n      return intrusive::shared_ptr<value_type>(raw_ptr);\n    }\n  }\n\n  intrusive::shared_ptr<value_type> PopFront() {\n    value_type* raw_ptr = nullptr;\n    if (!list_head_.empty()) { raw_ptr = list_head_.PopFront(); }\n    if (container_ != raw_ptr) {\n      return intrusive::shared_ptr<value_type>::__UnsafeMove__(raw_ptr);\n    } else {\n      return intrusive::shared_ptr<value_type>(raw_ptr);\n    }\n  }\n\n  void MoveTo(HeadFreeList* list) { MoveToDstBack(list); }\n  void MoveToDstBack(HeadFreeList* list) {\n    while (!empty()) { MoveToDstBack(list_head_.Begin(), list); }\n  }\n\n  void Clear() {\n    while (!empty()) {\n      auto* ptr = list_head_.PopFront();\n      if (container_ != ptr) { Ref::DecreaseRef(ptr); }\n    }\n  }\n\n private:\n  void MoveReference(value_type* ptr, HeadFreeList* dst) {\n    if (ptr == container_ && ptr != dst->container_) {\n      Ref::IncreaseRef(ptr);\n    } else if (ptr != container_ && ptr == dst->container_) {\n      Ref::DecreaseRef(ptr);\n    } else {\n      // do nothing\n    }\n  }\n\n  intrusive::ListHead<ValueHookField> list_head_;\n  const value_type* container_;\n};\n\n}  // namespace intrusive\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_HEAD_FREE_LIST_H_\n"
  },
  {
    "path": "oneflow/core/intrusive/head_free_list_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// include sstream first to avoid some compiling error\n// caused by the following trick\n// reference: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=65899\n#include <sstream>\n#include \"gtest/gtest.h\"\n#define private public\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/intrusive/intrusive.h\"\n\nnamespace oneflow {\n\nnamespace test {\n\nnamespace {\n\n// clang-format off\nREFLECTIVE_CLASS_BEGIN(SelfLoopContainer);\n public:\n  void __Init__() { clear_deleted(); }\n  // Getters\n  bool has_deleted() const { return deleted_ != nullptr; }\n  bool deleted() const { return *deleted_; } \n  bool is_hook_empty() const { return hook_.empty(); }\n  // Setters\n  bool* mut_deleted() { return deleted_; }\n  void set_deleted(bool* val) { deleted_ = val; }\n  void clear_deleted() { deleted_ = nullptr; }\n\n  // methods\n  void __Init__(bool* deleted) {\n    __Init__();\n    set_deleted(deleted);\n  }\n  void __Delete__() { *mut_deleted() = true; }\n\n  size_t ref_cnt() const { return intrusive_ref_.ref_cnt(); }\n\n private:\n  friend class intrusive::Ref;\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n  SelfLoopContainer() : intrusive_ref_(), deleted_(), hook_(), head_() {}\n  REFLECTIVE_CLASS_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);\n  // fields\n  REFLECTIVE_CLASS_DEFINE_FIELD(bool*, deleted_);\n  // list hooks\n  REFLECTIVE_CLASS_DEFINE_FIELD(intrusive::ListHook, hook_);\n\n public:\n  // Do not insert other REFLECTIVE_CLASS_DEFINE_FIELD between `using SelfLoopContainerList = ...;` and `REFLECTIVE_CLASS_DEFINE_FIELD(SelfLoopContainerList, ...);` \n  using SelfLoopContainerList =\n      intrusive::HeadFreeList<REFLECTIVE_FIELD(SelfLoopContainer, hook_), REFLECTIVE_FIELD_COUNTER>;\n  const SelfLoopContainerList& head() const { return head_; }\n  SelfLoopContainerList* mut_head() { return &head_; }\n\n private:\n  REFLECTIVE_CLASS_DEFINE_FIELD(SelfLoopContainerList, head_);\nREFLECTIVE_CLASS_END(SelfLoopContainer);\n// clang-format on\n\nTEST(HeadFreeList, __Init__) {\n  bool deleted = false;\n  auto self_loop_head = intrusive::make_shared<SelfLoopContainer>(&deleted);\n  ASSERT_EQ(self_loop_head->mut_head()->container_, self_loop_head.Mutable());\n}\n\nTEST(HeadFreeList, PushBack) {\n  bool deleted0 = false;\n  bool deleted1 = false;\n  {\n    auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);\n    auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);\n    ASSERT_EQ(self_loop_head0->ref_cnt(), 1);\n    ASSERT_EQ(self_loop_head1->ref_cnt(), 1);\n    self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());\n    ASSERT_EQ(self_loop_head0->head().size(), 1);\n    ASSERT_EQ(self_loop_head0->ref_cnt(), 1);\n    self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());\n    ASSERT_EQ(self_loop_head1->ref_cnt(), 2);\n    ASSERT_EQ(self_loop_head0->head().size(), 2);\n  }\n  
ASSERT_TRUE(deleted0);\n  ASSERT_TRUE(deleted1);\n}\n\nTEST(HeadFreeList, PushFront) {\n  bool deleted0 = false;\n  bool deleted1 = false;\n  {\n    auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);\n    auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);\n    ASSERT_EQ(self_loop_head0->ref_cnt(), 1);\n    ASSERT_EQ(self_loop_head1->ref_cnt(), 1);\n    self_loop_head0->mut_head()->PushFront(self_loop_head0.Mutable());\n    ASSERT_EQ(self_loop_head0->head().size(), 1);\n    ASSERT_EQ(self_loop_head0->ref_cnt(), 1);\n    self_loop_head0->mut_head()->PushFront(self_loop_head1.Mutable());\n    ASSERT_EQ(self_loop_head1->ref_cnt(), 2);\n    ASSERT_EQ(self_loop_head0->head().size(), 2);\n  }\n  ASSERT_TRUE(deleted0);\n  ASSERT_TRUE(deleted1);\n}\n\nTEST(HeadFreeList, EmplaceBack) {\n  bool deleted0 = false;\n  bool deleted1 = false;\n  {\n    auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);\n    auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);\n    ASSERT_EQ(self_loop_head0->ref_cnt(), 1);\n    ASSERT_EQ(self_loop_head1->ref_cnt(), 1);\n    self_loop_head0->mut_head()->EmplaceBack(\n        intrusive::shared_ptr<SelfLoopContainer>(self_loop_head0));\n    ASSERT_EQ(self_loop_head0->head().size(), 1);\n    ASSERT_EQ(self_loop_head0->ref_cnt(), 1);\n    self_loop_head0->mut_head()->EmplaceBack(\n        intrusive::shared_ptr<SelfLoopContainer>(self_loop_head1));\n    ASSERT_EQ(self_loop_head1->ref_cnt(), 2);\n    ASSERT_EQ(self_loop_head0->head().size(), 2);\n  }\n  ASSERT_TRUE(deleted0);\n  ASSERT_TRUE(deleted1);\n}\n\nTEST(HeadFreeList, EmplaceFront) {\n  bool deleted0 = false;\n  bool deleted1 = false;\n  {\n    auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);\n    auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);\n    ASSERT_EQ(self_loop_head0->ref_cnt(), 1);\n    ASSERT_EQ(self_loop_head1->ref_cnt(), 1);\n    self_loop_head0->mut_head()->EmplaceFront(\n        intrusive::shared_ptr<SelfLoopContainer>(self_loop_head0));\n    ASSERT_EQ(self_loop_head0->head().size(), 1);\n    ASSERT_EQ(self_loop_head0->ref_cnt(), 1);\n    self_loop_head0->mut_head()->EmplaceFront(\n        intrusive::shared_ptr<SelfLoopContainer>(self_loop_head1));\n    ASSERT_EQ(self_loop_head1->ref_cnt(), 2);\n    ASSERT_EQ(self_loop_head0->head().size(), 2);\n  }\n  ASSERT_TRUE(deleted0);\n  ASSERT_TRUE(deleted1);\n}\n\nTEST(HeadFreeList, Erase) {\n  bool deleted0 = false;\n  bool deleted1 = false;\n  {\n    auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);\n    auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);\n    self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());\n    self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());\n    self_loop_head0->mut_head()->Erase(self_loop_head0.Mutable());\n    self_loop_head0->mut_head()->Erase(self_loop_head1.Mutable());\n    ASSERT_EQ(self_loop_head0->ref_cnt(), 1);\n    ASSERT_EQ(self_loop_head1->ref_cnt(), 1);\n  }\n  ASSERT_TRUE(deleted0);\n  ASSERT_TRUE(deleted1);\n}\n\nTEST(HeadFreeList, PopBack) {\n  bool deleted0 = false;\n  bool deleted1 = false;\n  {\n    auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);\n    auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);\n    self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());\n    
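// pushing the list's own container takes no extra reference (head-free semantics)\n    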
self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());\n    self_loop_head0->mut_head()->PopBack();\n    self_loop_head0->mut_head()->PopBack();\n    ASSERT_EQ(self_loop_head0->ref_cnt(), 1);\n    ASSERT_EQ(self_loop_head1->ref_cnt(), 1);\n  }\n  ASSERT_TRUE(deleted0);\n  ASSERT_TRUE(deleted1);\n}\n\nTEST(HeadFreeList, PopFront) {\n  bool deleted0 = false;\n  bool deleted1 = false;\n  {\n    auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);\n    auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);\n    self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());\n    self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());\n    self_loop_head0->mut_head()->PopFront();\n    self_loop_head0->mut_head()->PopFront();\n    ASSERT_EQ(self_loop_head0->ref_cnt(), 1);\n    ASSERT_EQ(self_loop_head1->ref_cnt(), 1);\n  }\n  ASSERT_TRUE(deleted0);\n  ASSERT_TRUE(deleted1);\n}\n\nTEST(HeadFreeList, MoveTo) {\n  bool deleted0 = false;\n  bool deleted1 = false;\n  {\n    auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);\n    auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);\n    self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());\n    self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());\n    self_loop_head0->mut_head()->MoveTo(self_loop_head1->mut_head());\n    ASSERT_EQ(self_loop_head0->ref_cnt(), 2);\n    ASSERT_EQ(self_loop_head1->ref_cnt(), 1);\n  }\n  ASSERT_TRUE(deleted0);\n  ASSERT_TRUE(deleted1);\n}\n\nTEST(HeadFreeList, Clear) {\n  bool deleted0 = false;\n  bool deleted1 = false;\n  {\n    auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);\n    auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);\n    self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());\n    self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());\n    self_loop_head0->mut_head()->Clear();\n    ASSERT_EQ(self_loop_head0->ref_cnt(), 1);\n    ASSERT_EQ(self_loop_head1->ref_cnt(), 1);\n  }\n  ASSERT_TRUE(deleted0);\n  ASSERT_TRUE(deleted1);\n}\n\n}  // namespace\n\n}  // namespace test\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/intrusive/intrusive.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_H_\n#define ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_H_\n\n#include \"oneflow/core/intrusive/struct_traits.h\"\n#include \"oneflow/core/intrusive/base.h\"\n#include \"oneflow/core/intrusive/ref.h\"\n#include \"oneflow/core/intrusive/shared_ptr.h\"\n#include \"oneflow/core/intrusive/list.h\"\n#include \"oneflow/core/intrusive/head_free_list.h\"\n#include \"oneflow/core/intrusive/skiplist.h\"\n#include \"oneflow/core/intrusive/for_each.h\"\n#include \"oneflow/core/intrusive/reflective.h\"\n#include \"oneflow/core/intrusive/force_standard_layout.h\"\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_H_\n"
  },
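The umbrella header above is the single include the test files in this dump rely on. As a reference for how its pieces compose, here is a minimal sketch of the boilerplate an intrusive ref-counted type needs, following the pattern the tests themselves use; the `Counter` class, the `example` namespace, and the standalone `main` are illustrative only, not files in the tree.

#include "oneflow/core/intrusive/intrusive.h"

namespace example {  // hypothetical namespace, for illustration only

class Counter final : public oneflow::intrusive::Base {
 public:
  void __Init__() { value_ = 0; }  // invoked by intrusive::make_shared after allocation
  void __Delete__() {}             // invoked by Ref::DecreaseRef just before deallocation

  int value() const { return value_; }
  void set_value(int val) { value_ = val; }

 private:
  friend class oneflow::intrusive::Ref;  // Ref reaches the counter via mut_intrusive_ref()
  oneflow::intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }

  Counter() : intrusive_ref_(), value_() {}  // constructed through make_shared only
  oneflow::intrusive::Ref intrusive_ref_;
  int value_;
};

}  // namespace example

int main() {
  auto counter = oneflow::intrusive::make_shared<example::Counter>();
  counter->set_value(42);
  return counter->value() == 42 ? 0 : 1;  // the last reference dies at scope exit
}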
  {
    "path": "oneflow/core/intrusive/intrusive_core_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// include sstream first to avoid some compiling error\n// caused by the following trick\n// reference: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=65899\n#include <sstream>\n#include \"gtest/gtest.h\"\n#define private public\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/intrusive/intrusive.h\"\n#include \"oneflow/core/intrusive/flat_msg.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n\nnamespace oneflow {\n\nnamespace intrusive {\n\nnamespace test {\n\nnamespace {\n\nTEST(Ref, ref_cnt) {\n  class Foo final : public Ref {\n   public:\n    Foo() = default;\n  };\n  Foo foo;\n  foo.InitRefCount();\n  foo.IncreaseRefCount();\n  foo.IncreaseRefCount();\n  ASSERT_EQ(foo.DecreaseRefCount(), 1);\n  ASSERT_EQ(foo.DecreaseRefCount(), 0);\n}\n\nclass IntrusiveFoo final : public intrusive::Base {\n public:\n  void __Init__() { clear_is_deleted(); }\n  void __Delete__();\n\n  // Getters\n  int8_t x() const { return x_; }\n  int32_t foo() const { return foo_; }\n  int16_t bar() const { return bar_; }\n  int64_t foobar() const { return foobar_; }\n  bool has_is_deleted() const { return is_deleted_ != nullptr; }\n  const std::string& is_deleted() const { return *is_deleted_; }\n\n  // Setters\n  void set_x(int8_t val) { x_ = val; }\n  void set_foo(int32_t val) { foo_ = val; }\n  void set_bar(int16_t val) { bar_ = val; }\n  void set_foobar(int64_t val) { foobar_ = val; }\n  void set_is_deleted(std::string* val) { is_deleted_ = val; }\n  std::string* mut_is_deleted() { return is_deleted_; }\n  void clear_is_deleted() { is_deleted_ = nullptr; }\n\n  size_t ref_cnt() const { return intrusive_ref_.ref_cnt(); }\n\n private:\n  friend class intrusive::Ref;\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n  IntrusiveFoo() : intrusive_ref_(), x_(), foo_(), bar_(), foobar_(), is_deleted_() {}\n  intrusive::Ref intrusive_ref_;\n  int8_t x_;\n  int32_t foo_;\n  int16_t bar_;\n  int64_t foobar_;\n  std::string* is_deleted_;\n};\n\nvoid IntrusiveFoo::__Delete__() {\n  if (mut_is_deleted()) { *mut_is_deleted() = \"deleted\"; }\n}\n\nTEST(intrusive, naive) {\n  auto foo = intrusive::make_shared<IntrusiveFoo>();\n  foo->set_bar(9527);\n  ASSERT_TRUE(foo->bar() == 9527);\n}\n\nTEST(intrusive, __delete__) {\n  std::string is_deleted;\n  {\n    auto foo = intrusive::make_shared<IntrusiveFoo>();\n    foo->set_bar(9527);\n    foo->set_is_deleted(&is_deleted);\n    ASSERT_EQ(foo->bar(), 9527);\n  }\n  ASSERT_TRUE(is_deleted == \"deleted\");\n}\n\nclass IntrusiveBar final : public intrusive::Base {\n public:\n  void __Init__() { clear_is_deleted(); }\n  void __Delete__() {\n    if (mut_is_deleted()) { *mut_is_deleted() = \"bar_deleted\"; }\n  }\n\n  // Getters\n  const IntrusiveFoo& foo() const {\n    if (foo_) { return foo_.Get(); }\n    static const auto default_val = intrusive::make_shared<IntrusiveFoo>();\n    return default_val.Get();\n  }\n  const std::string& 
is_deleted() const { return *is_deleted_; }\n  bool has_is_deleted() const { return is_deleted_ != nullptr; }\n\n  // Setters\n  IntrusiveFoo* mut_foo() {\n    if (!foo_) { foo_ = intrusive::make_shared<IntrusiveFoo>(); }\n    return foo_.Mutable();\n  }\n  std::string* mut_is_deleted() { return is_deleted_; }\n  void set_is_deleted(std::string* val) { is_deleted_ = val; }\n  void clear_is_deleted() { is_deleted_ = nullptr; }\n\n  size_t ref_cnt() const { return intrusive_ref_.ref_cnt(); }\n\n private:\n  friend class intrusive::Ref;\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n  IntrusiveBar() : intrusive_ref_(), foo_(), is_deleted_() {}\n  intrusive::Ref intrusive_ref_;\n  intrusive::shared_ptr<IntrusiveFoo> foo_;\n  std::string* is_deleted_;\n};\n\nTEST(intrusive, nested_objects) {\n  auto bar = intrusive::make_shared<IntrusiveBar>();\n  bar->mut_foo()->set_bar(9527);\n  ASSERT_TRUE(bar->foo().bar() == 9527);\n}\n\nTEST(intrusive, nested_delete) {\n  std::string bar_is_deleted;\n  std::string is_deleted;\n  {\n    auto bar = intrusive::make_shared<IntrusiveBar>();\n    bar->set_is_deleted(&bar_is_deleted);\n    auto* foo = bar->mut_foo();\n    foo->set_bar(9527);\n    foo->set_is_deleted(&is_deleted);\n    ASSERT_EQ(foo->bar(), 9527);\n    ASSERT_EQ(bar->ref_cnt(), 1);\n    ASSERT_EQ(foo->ref_cnt(), 1);\n  }\n  ASSERT_EQ(is_deleted, std::string(\"deleted\"));\n  ASSERT_EQ(bar_is_deleted, std::string(\"bar_deleted\"));\n}\n\n// clang-format off\nFLAT_MSG_BEGIN(FlatMsgDemo)\n  FLAT_MSG_DEFINE_ONEOF(type,\n      FLAT_MSG_ONEOF_FIELD(int32_t, int32_field)\n      FLAT_MSG_ONEOF_FIELD(float, float_field));\nFLAT_MSG_END(FlatMsgDemo)\n// clang-format on\n\nclass IntrusiveContainerDemo final : public intrusive::Base {\n public:\n  // Getters\n  const FlatMsgDemo& flat_field() const { return flat_field_.Get(); }\n  // Setters\n  FlatMsgDemo* mut_flat_field() { return flat_field_.Mutable(); }\n\n private:\n  friend class intrusive::Ref;\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n  IntrusiveContainerDemo() : intrusive_ref_(), flat_field_() {}\n  intrusive::Ref intrusive_ref_;\n  FlatMsg<FlatMsgDemo> flat_field_;\n};\n\nTEST(intrusive, flat_msg_field) {\n  auto obj = intrusive::make_shared<IntrusiveContainerDemo>();\n  ASSERT_TRUE(!obj->flat_field().has_int32_field());\n  obj->mut_flat_field()->set_int32_field(33);\n  ASSERT_TRUE(obj->flat_field().has_int32_field());\n  ASSERT_EQ(obj->flat_field().int32_field(), 33);\n}\n\n// clang-format off\nREFLECTIVE_CLASS_BEGIN(TestIntrusiveField);\n  TestIntrusiveField() = default;\n  static_assert(REFLECTIVE_FIELD_COUNTER == 0, \"\");\n  static_assert(REFLECTIVE_FIELD_COUNTER == 0, \"\");\n  REFLECTIVE_CLASS_DEFINE_FIELD(int32_t, a);\n  static_assert(REFLECTIVE_FIELD_COUNTER == 1, \"\");\n  static_assert(REFLECTIVE_FIELD_COUNTER == 1, \"\");\n  REFLECTIVE_CLASS_DEFINE_FIELD(int64_t, b);\n  static_assert(REFLECTIVE_FIELD_COUNTER == 2, \"\");\n  static_assert(REFLECTIVE_FIELD_COUNTER == 2, \"\");\n  REFLECTIVE_CLASS_DEFINE_FIELD(int8_t, c);\n  static_assert(REFLECTIVE_FIELD_COUNTER == 3, \"\");\n  static_assert(REFLECTIVE_FIELD_COUNTER == 3, \"\");\n  REFLECTIVE_CLASS_DEFINE_FIELD(int64_t, d);\n  static_assert(REFLECTIVE_FIELD_COUNTER == 4, \"\");\n  static_assert(REFLECTIVE_FIELD_COUNTER == 4, \"\");\nREFLECTIVE_CLASS_END(TestIntrusiveField);\n// clang-format on\n\nTEST(intrusive, intrusive_field_number) {\n  static_assert(REFLECTIVE_FIELD_NUMBER(TestIntrusiveField, a) == 1, \"\");\n  
static_assert(REFLECTIVE_FIELD_NUMBER(TestIntrusiveField, b) == 2, \"\");\n  static_assert(REFLECTIVE_FIELD_NUMBER(TestIntrusiveField, c) == 3, \"\");\n  static_assert(REFLECTIVE_FIELD_NUMBER(TestIntrusiveField, d) == 4, \"\");\n}\n\nTEST(intrusive, intrusive_field_type) {\n  static_assert(std::is_same<REFLECTIVE_FIELD_TYPE(TestIntrusiveField, 1), int32_t>::value, \"\");\n  static_assert(std::is_same<REFLECTIVE_FIELD_TYPE(TestIntrusiveField, 2), int64_t>::value, \"\");\n  static_assert(std::is_same<REFLECTIVE_FIELD_TYPE(TestIntrusiveField, 3), int8_t>::value, \"\");\n  static_assert(std::is_same<REFLECTIVE_FIELD_TYPE(TestIntrusiveField, 4), int64_t>::value, \"\");\n}\n\nTEST(intrusive, intrusive_field_offset) {\n  static_assert(REFLECTIVE_FIELD_OFFSET(TestIntrusiveField, 1) == 0, \"\");\n  static_assert(REFLECTIVE_FIELD_OFFSET(TestIntrusiveField, 2) == 8, \"\");\n  static_assert(REFLECTIVE_FIELD_OFFSET(TestIntrusiveField, 3) == 16, \"\");\n  static_assert(REFLECTIVE_FIELD_OFFSET(TestIntrusiveField, 4) == 24, \"\");\n}\n\n}  // namespace\n\n}  // namespace test\n\n}  // namespace intrusive\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/intrusive/list.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_LIST_H_\n#define ONEFLOW_CORE_INTRUSIVE_LIST_H_\n\n#include \"oneflow/core/intrusive/ref.h\"\n#include \"oneflow/core/intrusive/list_hook.h\"\n\nnamespace oneflow {\n\nnamespace intrusive {\n\ntemplate<typename HookField>\nclass List {\n public:\n  List(const List&) = delete;\n  List(List&&) = delete;\n  List() { this->__Init__(); }\n  ~List() { this->Clear(); }\n\n  using value_type = typename HookField::struct_type;\n  using iterator_struct_field = HookField;\n\n  template<typename Enabled = void>\n  static constexpr int IteratorHookOffset() {\n    return offsetof(List, list_head_) + intrusive::ListHead<HookField>::IteratorHookOffset();\n  }\n\n  std::size_t size() const { return list_head_.size(); }\n  bool empty() const { return list_head_.empty(); }\n\n  void CheckSize() const { list_head_.CheckSize(); }\n\n  void __Init__() { list_head_.__Init__(); }\n\n  value_type* Begin() {\n    if (list_head_.empty()) { return nullptr; }\n    return list_head_.Begin();\n  }\n  value_type* Prev(value_type* ptr) {\n    if (ptr == nullptr) { return nullptr; }\n    value_type* prev = list_head_.Prev(ptr);\n    if (prev == list_head_.End()) { return nullptr; }\n    return prev;\n  }\n  value_type* Next(value_type* ptr) {\n    if (ptr == nullptr) { return nullptr; }\n    value_type* next = list_head_.Next(ptr);\n    if (next == list_head_.End()) { return nullptr; }\n    return next;\n  }\n  value_type* Last() {\n    if (list_head_.empty()) { return nullptr; }\n    return list_head_.Last();\n  }\n  constexpr value_type* End() const { return nullptr; }\n\n  void MoveToDstBack(value_type* ptr, List* dst) {\n    list_head_.MoveToDstBack(ptr, &dst->list_head_);\n  }\n  void MoveToDstFront(value_type* ptr, List* dst) {\n    list_head_.MoveToDstFront(ptr, &dst->list_head_);\n  }\n  value_type* MoveFrontToDstBack(List* dst) {\n    value_type* begin = list_head_.Begin();\n    MoveToDstBack(begin, dst);\n    return begin;\n  }\n  value_type* MoveBackToDstBack(List* dst) {\n    value_type* begin = list_head_.Last();\n    MoveToDstBack(begin, dst);\n    return begin;\n  }\n\n  void PushBack(value_type* ptr) {\n    list_head_.PushBack(ptr);\n    Ref::IncreaseRef(ptr);\n  }\n\n  void PushFront(value_type* ptr) {\n    list_head_.PushFront(ptr);\n    Ref::IncreaseRef(ptr);\n  }\n\n  void EmplaceBack(intrusive::shared_ptr<value_type>&& ptr) {\n    value_type* raw_ptr = nullptr;\n    ptr.__UnsafeMoveTo__(&raw_ptr);\n    list_head_.PushBack(raw_ptr);\n  }\n\n  void EmplaceFront(intrusive::shared_ptr<value_type>&& ptr) {\n    value_type* raw_ptr = nullptr;\n    ptr.__UnsafeMoveTo__(&raw_ptr);\n    list_head_.PushFront(raw_ptr);\n  }\n\n  intrusive::shared_ptr<value_type> Erase(value_type* ptr) {\n    list_head_.Erase(ptr);\n    return intrusive::shared_ptr<value_type>::__UnsafeMove__(ptr);\n  }\n\n  intrusive::shared_ptr<value_type> PopBack() {\n    value_type* raw_ptr = 
nullptr;\n    if (!list_head_.empty()) { raw_ptr = list_head_.PopBack(); }\n    return intrusive::shared_ptr<value_type>::__UnsafeMove__(raw_ptr);\n  }\n\n  intrusive::shared_ptr<value_type> PopFront() {\n    value_type* raw_ptr = nullptr;\n    if (!list_head_.empty()) { raw_ptr = list_head_.PopFront(); }\n    return intrusive::shared_ptr<value_type>::__UnsafeMove__(raw_ptr);\n  }\n\n  void MoveTo(List* list) { MoveToDstBack(list); }\n  void MoveToDstBack(List* list) { list_head_.MoveToDstBack(&list->list_head_); }\n\n  void Clear() {\n    while (!empty()) { Ref::DecreaseRef(list_head_.PopFront()); }\n  }\n\n private:\n  intrusive::ListHead<HookField> list_head_;\n};\n\n}  // namespace intrusive\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_LIST_H_\n"
  },
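A note on the ownership contract above: PushBack/PushFront link the element and take one extra reference, EmplaceBack/EmplaceFront steal the reference already held by the passed-in shared_ptr, and Erase/PopBack/PopFront hand the list's reference back to the caller. A compact sketch of those rules, assuming oneflow's namespace is in scope; `Item` is an illustrative fixture in the style of the tests later in this dump, not a real file in the tree.

class Item final : public intrusive::Base {
 public:
  intrusive::ListHook hook_;  // the list links elements through this embedded field

  size_t ref_cnt() const { return intrusive_ref_.ref_cnt(); }

 private:
  friend class intrusive::Ref;
  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
  Item() : hook_(), intrusive_ref_() {}
  intrusive::Ref intrusive_ref_;
};

void OwnershipSketch() {
  intrusive::List<INTRUSIVE_FIELD(Item, hook_)> list;
  auto item = intrusive::make_shared<Item>();    // ref_cnt == 1: the local handle
  list.PushBack(item.Mutable());                 // ref_cnt == 2: the list holds one
  intrusive::shared_ptr<Item> popped = list.PopBack();  // the list's reference moves here
  // ref_cnt == 2 again: `item` and `popped`; both drop at scope exit, then __Delete__ runs.
}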
  {
    "path": "oneflow/core/intrusive/list_hook.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_LIST_HOOK_H_\n#define ONEFLOW_CORE_INTRUSIVE_LIST_HOOK_H_\n\n#include \"oneflow/core/intrusive/struct_traits.h\"\n#include \"oneflow/core/common/throw.h\"\n\nnamespace oneflow {\n\nnamespace intrusive {\n\nstruct ListHook {\n public:\n  ListHook() { Clear(); }\n\n  ListHook* prev() const { return prev_; }\n  ListHook* next() const { return next_; }  // NOLINT\n\n  void __Init__() { Clear(); }\n  void Clear() {\n    prev_ = this;\n    next_ = this;\n  }\n\n  bool empty() const { return prev_ == this || next_ == this; }\n  void AppendTo(ListHook* prev) {\n    prev->set_next(this);\n    this->set_prev(prev);\n  }\n  void InsertAfter(ListHook* prev) {\n    auto* next = prev->next();\n    this->AppendTo(prev);\n    next->AppendTo(this);\n  }\n  void Erase() {\n    next_->AppendTo(prev_);\n    Clear();\n  }\n\n  bool nullptr_empty() const { return prev_ == nullptr && next_ == nullptr; }\n\n  void NullptrClear() {\n    prev_ = nullptr;\n    next_ = nullptr;\n  }\n\n private:\n  void set_prev(ListHook* prev) { prev_ = prev; }\n  void set_next(ListHook* next) { next_ = next; }\n\n  ListHook* prev_;\n  ListHook* next_;\n};\n\n#define LIST_HOOK_FOR_EACH(head_hook, elem_hook_struct_field, elem) \\\n  LIST_HOOK_FOR_EACH_WITH_EXPR(head_hook, elem_hook_struct_field, elem, 0)\n\n#define LIST_HOOK_FOR_EACH_WITH_EXPR(head_hook, elem_hook_struct_field, elem, expr)   \\\n  for (typename elem_hook_struct_field::struct_type* elem = nullptr; elem == nullptr; \\\n       elem = nullptr, elem++)                                                        \\\n  LIST_HOOK_FOR_EACH_I(head_hook, __elem_hook__,                                      \\\n                       ((elem = elem_hook_struct_field::StructPtr4FieldPtr(__elem_hook__)), expr))\n\n#define LIST_HOOK_FOR_EACH_I(head_hook, elem_hook, expr)                                     \\\n  for (intrusive::ListHook* __head_hook__ = (head_hook), *elem_hook = __head_hook__->next(), \\\n                            *__next_hook__ = elem_hook->next();                              \\\n       (elem_hook != __head_hook__) && ((expr) || true);                                     \\\n       elem_hook = __next_hook__, __next_hook__ = __next_hook__->next())\n\n#define LIST_HOOK_UNSAFE_FOR_EACH(head_hook, elem_hook_struct_field, elem)            \\\n  for (typename elem_hook_struct_field::struct_type* elem = nullptr; elem == nullptr; \\\n       elem = nullptr, elem++)                                                        \\\n  LIST_HOOK_UNSAFE_FOR_EACH_I(head_hook, __elem_hook__,                               \\\n                              (elem = elem_hook_struct_field::StructPtr4FieldPtr(__elem_hook__)))\n\n#define LIST_HOOK_UNSAFE_FOR_EACH_I(head_hook, elem_hook, expr)                              \\\n  for (intrusive::ListHook* __head_hook__ = (head_hook), *elem_hook = __head_hook__->next(); \\\n       (elem_hook != __head_hook__) 
&& ((expr), true); elem_hook = elem_hook->next())\n\ntemplate<typename HookField>\nclass ListHead {\n public:\n  ListHead() { Clear(); }\n  using value_type = typename HookField::struct_type;\n  static_assert(std::is_same<typename HookField::field_type, ListHook>::value, \"no ListHook found\");\n\n  template<typename Enabled = void>\n  static constexpr int IteratorHookOffset() {\n    return offsetof(ListHead, container_);\n  }\n\n  std::size_t size() const { return size_; }\n  bool empty() const {\n    bool list_empty = (&Begin() == &End());\n    bool size_empty = (size_ == 0);\n    CHECK_EQ(list_empty, size_empty);\n    return size_empty;\n  }\n  void CheckSize() const {\n    size_t hook_size = 0;\n    for (ListHook* iter = container_.next(); iter != &container_; iter = iter->next()) {\n      ++hook_size;\n    }\n    CHECK_EQ(size_, hook_size);\n  }\n  const value_type& Begin() const { return Next(End()); }\n  const value_type& ReverseBegin() const { return Prev(End()); }\n  const value_type& End() const { return *HookField::StructPtr4FieldPtr(&container()); }\n  const value_type& Next(const value_type& current) const {\n    return *HookField::StructPtr4FieldPtr(HookField::FieldPtr4StructPtr(&current)->next());\n  }\n  const value_type& Prev(const value_type& current) const {\n    return *HookField::StructPtr4FieldPtr(HookField::FieldPtr4StructPtr(&current)->prev());\n  }\n\n  value_type* Begin() { return Next(End()); }\n  value_type* Last() { return Prev(End()); }\n  value_type* End() { return HookField::StructPtr4FieldPtr(mut_container()); }\n  value_type* Next(value_type* current) {\n    return HookField::StructPtr4FieldPtr(HookField::FieldPtr4StructPtr(current)->next());\n  }\n  value_type* Prev(value_type* current) {\n    return HookField::StructPtr4FieldPtr(HookField::FieldPtr4StructPtr(current)->prev());\n  }\n  void __Init__() { Clear(); }\n\n  void Clear() {\n    container_.__Init__();\n    size_ = 0;\n  }\n\n  void Erase(value_type* elem) {\n    CHECK_GT(size_, 0);\n    CHECK_NE(elem, End());\n    ListHook* list_hook = HookField::FieldPtr4StructPtr(elem);\n    CHECK(!list_hook->empty());\n    list_hook->Erase();\n    --size_;\n  }\n  void MoveToDstBack(value_type* elem, ListHead* dst) {\n    CHECK(!container_.empty());\n    auto* dst_rbegin = dst->container_.prev();\n    auto* dst_end = &dst->container_;\n    ListHook* elem_hook = HookField::FieldPtr4StructPtr(elem);\n    elem_hook->next()->AppendTo(elem_hook->prev());\n    elem_hook->AppendTo(dst_rbegin);\n    dst_end->AppendTo(elem_hook);\n    --size_;\n    ++dst->size_;\n  }\n  void MoveToDstFront(value_type* elem, ListHead* dst) {\n    CHECK(!container_.empty());\n    auto* dst_end = &dst->container_;\n    auto* dst_begin = dst->container_.next();\n    ListHook* elem_hook = HookField::FieldPtr4StructPtr(elem);\n    elem_hook->next()->AppendTo(elem_hook->prev());\n    elem_hook->AppendTo(dst_end);\n    dst_begin->AppendTo(elem_hook);\n    --size_;\n    ++dst->size_;\n  }\n  void PushBack(value_type* elem) { InsertAfter(Last(), elem); }\n  void PushFront(value_type* elem) { InsertAfter(End(), elem); }\n  value_type* PopBack() {\n    CHECK(!empty());\n    value_type* last = Last();\n    Erase(last);\n    return last;\n  }\n  value_type* PopFront() {\n    CHECK(!empty());\n    value_type* first = Begin();\n    Erase(first);\n    return first;\n  }\n  void MoveToDstBack(ListHead* dst) {\n    if (container_.empty()) { return; }\n    auto* dst_last = dst->container_.prev();\n    auto* dst_end = &dst->container_;\n    auto* 
this_first = container_.next();\n    auto* this_last = container_.prev();\n    this_first->AppendTo(dst_last);\n    dst_end->AppendTo(this_last);\n    dst->size_ += size();\n    this->Clear();\n  }\n\n private:\n  void InsertAfter(value_type* prev_elem, value_type* new_elem) {\n    ListHook* prev_list_hook = HookField::FieldPtr4StructPtr(prev_elem);\n    ListHook* next_list_hook = prev_list_hook->next();\n    ListHook* new_list_hook = HookField::FieldPtr4StructPtr(new_elem);\n    CHECK(new_list_hook->empty());\n    new_list_hook->AppendTo(prev_list_hook);\n    next_list_hook->AppendTo(new_list_hook);\n    ++size_;\n  }\n  const ListHook& container() const { return container_; }\n  ListHook* mut_container() { return &container_; }\n\n private:\n  ListHook container_;\n  volatile std::size_t size_;\n};\n\n}  // namespace intrusive\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_LIST_HOOK_H_\n"
  },
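The FOR_EACH macros above differ only in whether `__next_hook__` is cached before the loop body runs; the safe variant therefore tolerates unlinking the current element mid-iteration. A small sketch of the raw hook mechanics, assuming oneflow's namespace is in scope; the `Node` fixture is illustrative only.

struct Node {
  Node() : value(0) {}  // a default-constructed ListHook self-loops, i.e. it is empty
  int value;
  intrusive::ListHook hook;
};

using NodeField = INTRUSIVE_FIELD(Node, hook);

void HookIterationSketch() {
  intrusive::ListHook head;      // sentinel: next()/prev() point back to itself
  Node a;
  Node b;
  a.hook.InsertAfter(&head);     // head -> a
  b.hook.InsertAfter(&a.hook);   // head -> a -> b
  LIST_HOOK_FOR_EACH(&head, NodeField, node) {
    // __next_hook__ was saved before this body ran, so calling node->hook.Erase()
    // here would not derail the traversal; the UNSAFE variant offers no such guarantee.
    ++node->value;
  }
}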
  {
    "path": "oneflow/core/intrusive/list_hook_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// include sstream first to avoid some compiling error\n// caused by the following trick\n// reference: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=65899\n#include <sstream>\n#include \"gtest/gtest.h\"\n#define private public\n#include \"oneflow/core/intrusive/list_hook.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace intrusive {\n\nnamespace test {\n\nstruct ListItemBar final {\n  ListItemBar() : value() { bar_list.__Init__(); }\n  int value;\n  ListHook bar_list;\n};\n\nclass TestListHook final : public ListHook {\n public:\n  TestListHook() { this->__Init__(); }\n};\n\ntemplate<typename ItemField>\nclass TestListHead : public intrusive::ListHead<ItemField> {\n public:\n  TestListHead() { this->__Init__(); }\n};\n\nusing BarListHead = TestListHead<INTRUSIVE_FIELD(ListItemBar, bar_list)>;\n\nTEST(TestListHook, init) {\n  TestListHook list_iterator;\n  ASSERT_EQ(&list_iterator, list_iterator.prev());\n  ASSERT_EQ(&list_iterator, list_iterator.next());\n}\n\nTEST(TestListHook, append_to) {\n  TestListHook list_iter0;\n  TestListHook list_iter1;\n  list_iter1.AppendTo(&list_iter0);\n  ASSERT_EQ(&list_iter0, list_iter1.prev());\n  ASSERT_EQ(&list_iter1, list_iter0.next());\n}\n\nTEST(TestListHook, clear) {\n  TestListHook list_head0;\n  TestListHook list_head1;\n  list_head1.AppendTo(&list_head0);\n  list_head1.__Init__();\n  ASSERT_EQ(&list_head1, list_head1.prev());\n  ASSERT_EQ(&list_head1, list_head1.next());\n}\n\nTEST(ListHead, empty) {\n  BarListHead list_head;\n  ASSERT_TRUE(list_head.empty());\n}\n\nTEST(ListHead, push_front) {\n  BarListHead list_head;\n  ListHook& head = list_head.container_;\n  ListItemBar item0;\n  list_head.PushFront(&item0);\n  ASSERT_EQ(head.next(), &item0.bar_list);\n  ASSERT_EQ(head.prev(), &item0.bar_list);\n  ASSERT_EQ(item0.bar_list.next(), &head);\n  ASSERT_EQ(item0.bar_list.prev(), &head);\n  ListItemBar item1;\n  list_head.PushFront(&item1);\n  ASSERT_EQ(head.next(), &item1.bar_list);\n  ASSERT_EQ(item1.bar_list.prev(), &head);\n  ASSERT_EQ(item1.bar_list.next(), &item0.bar_list);\n  ASSERT_EQ(item0.bar_list.prev(), &item1.bar_list);\n  ASSERT_EQ(item0.bar_list.next(), &head);\n  ASSERT_EQ(head.prev(), &item0.bar_list);\n}\n\nTEST(ListHead, end) {\n  BarListHead list_head;\n  ListItemBar* end_item = list_head.End();\n  ListItemBar item0;\n  list_head.PushFront(&item0);\n  ASSERT_EQ(end_item, list_head.End());\n}\n\nTEST(ListHead, begin) {\n  BarListHead list_head;\n  ASSERT_EQ(list_head.Begin(), list_head.End());\n  ListItemBar item0;\n  list_head.PushFront(&item0);\n  ASSERT_EQ(list_head.Begin(), &item0);\n  ListItemBar item1;\n  list_head.PushFront(&item1);\n  ASSERT_EQ(list_head.Begin(), &item1);\n}\n\nTEST(ListHead, last) {\n  BarListHead list_head;\n  ASSERT_EQ(list_head.Begin(), list_head.End());\n  ListItemBar item0;\n  list_head.PushFront(&item0);\n  ASSERT_EQ(list_head.Last(), &item0);\n  ListItemBar 
item1;\n  list_head.PushFront(&item1);\n  ASSERT_EQ(list_head.Last(), &item0);\n}\n\nTEST(ListHead, push_back) {\n  BarListHead list_head;\n  ASSERT_EQ(list_head.Begin(), list_head.End());\n  ListItemBar item0;\n  list_head.PushBack(&item0);\n  ASSERT_EQ(list_head.Last(), &item0);\n  ListItemBar item1;\n  list_head.PushBack(&item1);\n  ASSERT_EQ(list_head.Last(), &item1);\n}\n\nTEST(ListHead, erase) {\n  BarListHead list_head;\n  ASSERT_EQ(list_head.Begin(), list_head.End());\n  ListItemBar item0;\n  list_head.PushBack(&item0);\n  ASSERT_EQ(list_head.Last(), &item0);\n  ListItemBar item1;\n  list_head.PushBack(&item1);\n  ASSERT_EQ(list_head.Last(), &item1);\n  list_head.Erase(&item0);\n  ASSERT_EQ(list_head.Last(), &item1);\n  ASSERT_EQ(list_head.Begin(), &item1);\n  ASSERT_EQ(item0.bar_list.prev(), &item0.bar_list);\n  ASSERT_EQ(item0.bar_list.next(), &item0.bar_list);\n}\n\nTEST(ListHead, pop_front) {\n  BarListHead list_head;\n  ASSERT_EQ(list_head.Begin(), list_head.End());\n  ListItemBar item0;\n  list_head.PushBack(&item0);\n  ASSERT_EQ(list_head.Last(), &item0);\n  ListItemBar item1;\n  list_head.PushBack(&item1);\n  ASSERT_EQ(list_head.Last(), &item1);\n  list_head.PopFront();\n  ASSERT_EQ(list_head.Last(), &item1);\n  ASSERT_EQ(list_head.Begin(), &item1);\n  ASSERT_EQ(item0.bar_list.prev(), &item0.bar_list);\n  ASSERT_EQ(item0.bar_list.next(), &item0.bar_list);\n}\n\nTEST(ListHead, pop_back) {\n  BarListHead list_head;\n  ASSERT_EQ(list_head.Begin(), list_head.End());\n  ListItemBar item0;\n  list_head.PushBack(&item0);\n  ASSERT_EQ(list_head.Last(), &item0);\n  ListItemBar item1;\n  list_head.PushBack(&item1);\n  ASSERT_EQ(list_head.Last(), &item1);\n  list_head.PopBack();\n  ASSERT_EQ(list_head.Last(), &item0);\n  ASSERT_EQ(list_head.Begin(), &item0);\n  ASSERT_EQ(item1.bar_list.prev(), &item1.bar_list);\n  ASSERT_EQ(item1.bar_list.next(), &item1.bar_list);\n}\n\nTEST(ListHead, Next) {\n  BarListHead list_head;\n  ListItemBar item0;\n  list_head.PushBack(&item0);\n  ListItemBar item1;\n  list_head.PushBack(&item1);\n\n  ListItemBar* item = list_head.Begin();\n  ASSERT_EQ(item, &item0);\n  item = list_head.Next(item);\n  ASSERT_EQ(item, &item1);\n  item = list_head.Next(item);\n  ASSERT_EQ(item, list_head.End());\n  item = list_head.Next(item);\n  ASSERT_EQ(item, &item0);\n}\n\nTEST(ListHead, prev_item) {\n  BarListHead list_head;\n  ListItemBar item0;\n  list_head.PushBack(&item0);\n  ListItemBar item1;\n  list_head.PushBack(&item1);\n\n  ListItemBar* item = list_head.Begin();\n  ASSERT_EQ(item, &item0);\n  item = list_head.Prev(item);\n  ASSERT_EQ(item, list_head.End());\n  item = list_head.Prev(item);\n  ASSERT_EQ(item, &item1);\n  item = list_head.Prev(item);\n  ASSERT_EQ(item, &item0);\n}\n\n}  // namespace test\n\n}  // namespace intrusive\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/intrusive/list_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// include sstream first to avoid some compiling error\n// caused by the following trick\n// reference: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=65899\n#include <sstream>\n#include \"gtest/gtest.h\"\n#define private public\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/intrusive/intrusive.h\"\n\nnamespace oneflow {\n\nnamespace test {\n\nnamespace {\n\nclass TestListItem : public intrusive::Base {\n public:\n  void __Init__() { clear_cnt(); }\n  void __Delete__() {\n    if (has_cnt()) { --*mut_cnt(); }\n  }\n\n  // Getters\n  bool has_cnt() const { return cnt_ != nullptr; }\n  int cnt() const { return *cnt_; }\n  bool is_foo_list_empty() const { return foo_list_.empty(); }\n\n  // Setters\n  void set_cnt(int* val) { cnt_ = val; }\n  void clear_cnt() { cnt_ = nullptr; }\n  int* mut_cnt() { return cnt_; }\n\n  size_t ref_cnt() const { return intrusive_ref_.ref_cnt(); }\n\n  intrusive::ListHook foo_list_;\n\n private:\n  friend class intrusive::Ref;\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n  TestListItem() : foo_list_(), intrusive_ref_(), cnt_() {}\n  intrusive::Ref intrusive_ref_;\n  int* cnt_;\n};\n\nusing TestList = intrusive::List<INTRUSIVE_FIELD(TestListItem, foo_list_)>;\n\nTEST(List, empty) {\n  TestList foo_list;\n  ASSERT_TRUE(foo_list.empty());\n  ASSERT_EQ(foo_list.size(), 0);\n}\n\nTEST(List, empty_Begin) {\n  TestList foo_list;\n  intrusive::shared_ptr<TestListItem> obj_ptr;\n  obj_ptr = foo_list.Begin();\n  ASSERT_TRUE(!obj_ptr);\n  intrusive::shared_ptr<TestListItem> next;\n  obj_ptr = foo_list.Begin();\n  next = foo_list.Next(obj_ptr.Mutable());\n  ASSERT_TRUE(!obj_ptr);\n}\n\nTEST(List, empty_Next) {\n  TestList foo_list;\n  intrusive::shared_ptr<TestListItem> obj_ptr;\n  intrusive::shared_ptr<TestListItem> next;\n  obj_ptr = foo_list.Begin();\n  next = foo_list.Next(obj_ptr.Mutable());\n  ASSERT_TRUE(!obj_ptr);\n  ASSERT_TRUE(!next);\n  obj_ptr = foo_list.Next(obj_ptr.Mutable());\n  ASSERT_TRUE(!obj_ptr);\n  obj_ptr = next;\n  next = foo_list.Next(next.Mutable());\n  ASSERT_TRUE(!obj_ptr);\n  ASSERT_TRUE(!next);\n}\n\nTEST(List, PushFront) {\n  TestList foo_list;\n  auto item0 = intrusive::make_shared<TestListItem>();\n  auto item1 = intrusive::make_shared<TestListItem>();\n  foo_list.PushFront(item0.Mutable());\n  foo_list.PushFront(item1.Mutable());\n  intrusive::shared_ptr<TestListItem> obj_ptr;\n  intrusive::shared_ptr<TestListItem> next;\n  obj_ptr = foo_list.Begin();\n  next = foo_list.Next(obj_ptr.Mutable());\n  ASSERT_TRUE(obj_ptr == item1);\n  ASSERT_TRUE(next == item0);\n}\n\nTEST(List, destructor) {\n  int elem_cnt = 2;\n  {\n    TestList foo_list;\n    auto item0 = intrusive::make_shared<TestListItem>();\n    item0->set_cnt(&elem_cnt);\n    auto item1 = intrusive::make_shared<TestListItem>();\n    item1->set_cnt(&elem_cnt);\n    foo_list.PushFront(item0.Mutable());\n    
foo_list.PushFront(item1.Mutable());\n  }\n  ASSERT_EQ(elem_cnt, 0);\n  elem_cnt = 2;\n  auto item0 = intrusive::make_shared<TestListItem>();\n  {\n    TestList foo_list;\n    item0->set_cnt(&elem_cnt);\n    auto item1 = intrusive::make_shared<TestListItem>();\n    item1->set_cnt(&elem_cnt);\n    foo_list.PushFront(item0.Mutable());\n    foo_list.PushFront(item1.Mutable());\n  }\n  ASSERT_EQ(elem_cnt, 1);\n}\n\nTEST(List, PushBack) {\n  TestList foo_list;\n  auto item0 = intrusive::make_shared<TestListItem>();\n  auto item1 = intrusive::make_shared<TestListItem>();\n  foo_list.PushBack(item0.Mutable());\n  foo_list.PushBack(item1.Mutable());\n  intrusive::shared_ptr<TestListItem> obj_ptr;\n  intrusive::shared_ptr<TestListItem> next;\n  obj_ptr = foo_list.Begin();\n  next = foo_list.Next(obj_ptr.Mutable());\n  ASSERT_TRUE(obj_ptr == item0);\n  ASSERT_TRUE(next == item1);\n}\n\nTEST(List, Erase) {\n  TestList foo_list;\n  auto item0 = intrusive::make_shared<TestListItem>();\n  auto item1 = intrusive::make_shared<TestListItem>();\n  foo_list.PushBack(item0.Mutable());\n  foo_list.PushBack(item1.Mutable());\n  ASSERT_EQ(item1->ref_cnt(), 2);\n  foo_list.Erase(item1.Mutable());\n  ASSERT_EQ(item1->ref_cnt(), 1);\n  intrusive::shared_ptr<TestListItem> obj_ptr;\n  intrusive::shared_ptr<TestListItem> next;\n  obj_ptr = foo_list.Begin();\n  next = foo_list.Next(obj_ptr.Mutable());\n  ASSERT_TRUE(obj_ptr == item0);\n  ASSERT_TRUE(!next);\n}\n\nTEST(List, PopBack) {\n  TestList foo_list;\n  auto item0 = intrusive::make_shared<TestListItem>();\n  auto item1 = intrusive::make_shared<TestListItem>();\n  foo_list.PushBack(item0.Mutable());\n  foo_list.PushBack(item1.Mutable());\n  ASSERT_EQ(item1->ref_cnt(), 2);\n  foo_list.PopBack();\n  ASSERT_EQ(item1->ref_cnt(), 1);\n  intrusive::shared_ptr<TestListItem> obj_ptr;\n  intrusive::shared_ptr<TestListItem> next;\n  obj_ptr = foo_list.Begin();\n  next = foo_list.Next(obj_ptr.Mutable());\n  ASSERT_TRUE(obj_ptr == item0);\n  ASSERT_TRUE(!next);\n}\n\nTEST(List, PopFront) {\n  TestList foo_list;\n  auto item0 = intrusive::make_shared<TestListItem>();\n  auto item1 = intrusive::make_shared<TestListItem>();\n  foo_list.PushBack(item0.Mutable());\n  foo_list.PushBack(item1.Mutable());\n  ASSERT_EQ(item0->ref_cnt(), 2);\n  foo_list.PopFront();\n  ASSERT_EQ(item0->ref_cnt(), 1);\n  intrusive::shared_ptr<TestListItem> obj_ptr;\n  intrusive::shared_ptr<TestListItem> next;\n  obj_ptr = foo_list.Begin();\n  next = foo_list.Next(obj_ptr.Mutable());\n  ASSERT_TRUE(!next);\n}\n\nTEST(List, Clear) {\n  TestList foo_list;\n  auto item0 = intrusive::make_shared<TestListItem>();\n  auto item1 = intrusive::make_shared<TestListItem>();\n  foo_list.PushBack(item0.Mutable());\n  foo_list.PushBack(item1.Mutable());\n  ASSERT_EQ(item0->ref_cnt(), 2);\n  ASSERT_EQ(item1->ref_cnt(), 2);\n  foo_list.Clear();\n  ASSERT_TRUE(foo_list.empty());\n  ASSERT_EQ(item0->ref_cnt(), 1);\n  ASSERT_EQ(item1->ref_cnt(), 1);\n}\n\nTEST(List, UNSAFE_FOR_EACH_PTR) {\n  TestList foo_list;\n  auto item0 = intrusive::make_shared<TestListItem>();\n  auto item1 = intrusive::make_shared<TestListItem>();\n  foo_list.PushBack(item0.Mutable());\n  foo_list.PushBack(item1.Mutable());\n  int i = 0;\n  INTRUSIVE_UNSAFE_FOR_EACH_PTR(item, &foo_list) {\n    if (i == 0) {\n      ASSERT_TRUE(item == item0.Mutable());\n    } else if (i == 1) {\n      ASSERT_TRUE(item == item1.Mutable());\n    }\n    ++i;\n  }\n  ASSERT_EQ(i, 2);\n}\n\nTEST(List, FOR_EACH) {\n  TestList foo_list;\n  auto item0 = 
intrusive::make_shared<TestListItem>();\n  auto item1 = intrusive::make_shared<TestListItem>();\n  foo_list.PushBack(item0.Mutable());\n  foo_list.PushBack(item1.Mutable());\n  int i = 0;\n  INTRUSIVE_FOR_EACH(item, &foo_list) {\n    if (i == 0) {\n      ASSERT_TRUE(item == item0);\n      foo_list.Erase(item.Mutable());\n    } else if (i == 1) {\n      ASSERT_TRUE(item == item1);\n      foo_list.Erase(item.Mutable());\n    }\n    ++i;\n  }\n  ASSERT_EQ(i, 2);\n  ASSERT_TRUE(foo_list.empty());\n  ASSERT_EQ(item0->ref_cnt(), 1);\n  ASSERT_EQ(item1->ref_cnt(), 1);\n}\n\nclass TestIntrusiveListHead final : public intrusive::Base {\n public:\n  // types\n  using FooList = intrusive::List<INTRUSIVE_FIELD(TestListItem, foo_list_)>;\n  // Getters\n  const FooList& foo_list() const { return foo_list_; }\n  // Setters\n  FooList* mut_foo_list() { return &foo_list_; }\n\n private:\n  friend class intrusive::Ref;\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n  TestIntrusiveListHead() : intrusive_ref_(), foo_list_() {}\n  intrusive::Ref intrusive_ref_;\n  FooList foo_list_;\n};\n\nTEST(List, intrusive_list_for_each) {\n  auto foo_list_head = intrusive::make_shared<TestIntrusiveListHead>();\n  auto& foo_list = *foo_list_head->mut_foo_list();\n  auto item0 = intrusive::make_shared<TestListItem>();\n  auto item1 = intrusive::make_shared<TestListItem>();\n  foo_list.PushBack(item0.Mutable());\n  foo_list.PushBack(item1.Mutable());\n  ASSERT_EQ(item0->ref_cnt(), 2);\n  ASSERT_EQ(item1->ref_cnt(), 2);\n  int i = 0;\n  INTRUSIVE_FOR_EACH(item, &foo_list) {\n    if (i == 0) {\n      ASSERT_TRUE(item == item0);\n      foo_list.Erase(item.Mutable());\n    } else if (i == 1) {\n      ASSERT_TRUE(item == item1);\n      foo_list.Erase(item.Mutable());\n    }\n    ++i;\n  }\n  ASSERT_EQ(i, 2);\n  ASSERT_TRUE(foo_list.empty());\n  ASSERT_EQ(item0->ref_cnt(), 1);\n  ASSERT_EQ(item1->ref_cnt(), 1);\n}\n\nclass TestIntrusiveListHeadWrapper final : public intrusive::Base {\n public:\n  // Getters\n  const TestIntrusiveListHead& head() const {\n    if (head_) { return head_.Get(); }\n    static const auto default_val = intrusive::make_shared<TestIntrusiveListHead>();\n    return default_val.Get();\n  }\n  // Setters\n  TestIntrusiveListHead* mut_head() {\n    if (!head_) { head_ = intrusive::make_shared<TestIntrusiveListHead>(); }\n    return head_.Mutable();\n  }\n  void clear_head() {\n    if (head_) { head_.Reset(); }\n  }\n\n private:\n  friend class intrusive::Ref;\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n  TestIntrusiveListHeadWrapper() : intrusive_ref_(), head_() {}\n  intrusive::Ref intrusive_ref_;\n  intrusive::shared_ptr<TestIntrusiveListHead> head_;\n};\n\nTEST(List, nested_list_delete) {\n  auto foo_list_head = intrusive::make_shared<TestIntrusiveListHeadWrapper>();\n  auto& foo_list = *foo_list_head->mut_head()->mut_foo_list();\n  auto item0 = intrusive::make_shared<TestListItem>();\n  auto item1 = intrusive::make_shared<TestListItem>();\n  foo_list.PushBack(item0.Mutable());\n  foo_list.PushBack(item1.Mutable());\n  ASSERT_EQ(item0->ref_cnt(), 2);\n  ASSERT_EQ(item1->ref_cnt(), 2);\n  int i = 0;\n  INTRUSIVE_UNSAFE_FOR_EACH_PTR(item, &foo_list) {\n    if (i == 0) {\n      ASSERT_TRUE(item == item0.Mutable());\n    } else if (i == 1) {\n      ASSERT_TRUE(item == item1.Mutable());\n    }\n    ++i;\n  }\n  ASSERT_EQ(i, 2);\n  foo_list_head->clear_head();\n  ASSERT_EQ(item0->ref_cnt(), 1);\n  ASSERT_EQ(item1->ref_cnt(), 1);\n}\n\nTEST(List, MoveTo) {\n  
TestList foo_list;\n  TestList foo_list0;\n  auto item0 = intrusive::make_shared<TestListItem>();\n  auto item1 = intrusive::make_shared<TestListItem>();\n  ASSERT_EQ(item0->is_foo_list_empty(), true);\n  ASSERT_EQ(item1->is_foo_list_empty(), true);\n  foo_list.PushBack(item0.Mutable());\n  foo_list.PushBack(item1.Mutable());\n  ASSERT_EQ(item0->is_foo_list_empty(), false);\n  ASSERT_EQ(item1->is_foo_list_empty(), false);\n  ASSERT_EQ(foo_list.size(), 2);\n  ASSERT_EQ(foo_list0.empty(), true);\n  ASSERT_EQ(item0->ref_cnt(), 2);\n  ASSERT_EQ(item1->ref_cnt(), 2);\n  foo_list.MoveTo(&foo_list0);\n  ASSERT_EQ(foo_list0.size(), 2);\n  ASSERT_EQ(foo_list.empty(), true);\n  ASSERT_TRUE(foo_list0.Begin() == item0.Mutable());\n  ASSERT_TRUE(foo_list0.Last() == item1.Mutable());\n  ASSERT_EQ(item0->ref_cnt(), 2);\n  ASSERT_EQ(item1->ref_cnt(), 2);\n}\n\n}  // namespace\n\n}  // namespace test\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/intrusive/mutexed_list.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_MUTEXED_LIST_H_\n#define ONEFLOW_CORE_INTRUSIVE_MUTEXED_LIST_H_\n\n#include <mutex>\n#include \"oneflow/core/intrusive/list.h\"\n\nnamespace oneflow {\n\nnamespace intrusive {\n\ntemplate<typename HookField>\nclass MutexedList {\n public:\n  using value_type = typename HookField::struct_type;\n  using list_type = List<HookField>;\n\n  MutexedList(const MutexedList&) = delete;\n  MutexedList(MutexedList&&) = delete;\n  explicit MutexedList(std::mutex* mutex) { this->__Init__(mutex); }\n  ~MutexedList() { this->Clear(); }\n\n  std::size_t thread_unsafe_size() const { return list_head_.size(); }\n  std::size_t size() const {\n    std::unique_lock<std::mutex> lock(*mutex_);\n    return list_head_.size();\n  }\n  bool empty() const {\n    std::unique_lock<std::mutex> lock(*mutex_);\n    return list_head_.empty();\n  }\n\n  void __Init__(std::mutex* mutex) {\n    list_head_.__Init__();\n    mutex_ = mutex;\n  }\n\n  void EmplaceBack(intrusive::shared_ptr<value_type>&& ptr) {\n    std::unique_lock<std::mutex> lock(*mutex_);\n    return list_head_.EmplaceBack(std::move(ptr));\n  }\n  void EmplaceFront(intrusive::shared_ptr<value_type>&& ptr) {\n    std::unique_lock<std::mutex> lock(*mutex_);\n    return list_head_.EmplaceFront(std::move(ptr));\n  }\n  void PushBack(value_type* ptr) { EmplaceBack(intrusive::shared_ptr<value_type>(ptr)); }\n  void PushFront(value_type* ptr) { EmplaceFront(intrusive::shared_ptr<value_type>(ptr)); }\n  intrusive::shared_ptr<value_type> PopBack() {\n    std::unique_lock<std::mutex> lock(*mutex_);\n    return list_head_.PopBack();\n  }\n  intrusive::shared_ptr<value_type> PopFront() {\n    std::unique_lock<std::mutex> lock(*mutex_);\n    return list_head_.PopFront();\n  }\n\n  // Returns true if old list is empty.\n  bool MoveFrom(list_type* src) {\n    std::unique_lock<std::mutex> lock(*mutex_);\n    return ThreadUnsafeMoveFrom(src);\n  }\n\n  // Returns true if old list is empty.\n  bool ThreadUnsafeMoveFrom(list_type* src) {\n    bool old_list_empty = list_head_.empty();\n    src->MoveToDstBack(&list_head_);\n    return old_list_empty;\n  }\n\n  void MoveTo(list_type* dst) {\n    std::unique_lock<std::mutex> lock(*mutex_);\n    list_head_.MoveToDstBack(dst);\n  }\n\n  void ThreadUnsafeMoveTo(list_type* dst) { list_head_.MoveToDstBack(dst); }\n\n  void Clear() {\n    std::unique_lock<std::mutex> lock(*mutex_);\n    list_head_.Clear();\n  }\n\n private:\n  list_type list_head_;\n  std::mutex* mutex_;\n};\n\n}  // namespace intrusive\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_MUTEXED_LIST_H_\n"
  },
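MutexedList borrows a caller-owned mutex and wraps every operation in a lock, so a common pattern is to drain it wholesale into a plain List under a single acquisition rather than popping one element at a time. A sketch of that pattern, assuming oneflow's namespace is in scope; `Msg` and `ChannelSketch` are illustrative names.

class Msg final : public intrusive::Base {
 public:
  intrusive::ListHook hook_;

 private:
  friend class intrusive::Ref;
  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
  Msg() : hook_(), intrusive_ref_() {}
  intrusive::Ref intrusive_ref_;
};

void ChannelSketch() {
  std::mutex mutex;  // the list only borrows the mutex; the caller keeps ownership
  intrusive::MutexedList<INTRUSIVE_FIELD(Msg, hook_)> channel(&mutex);
  channel.EmplaceBack(intrusive::make_shared<Msg>());  // locked push, reference moves in

  intrusive::List<INTRUSIVE_FIELD(Msg, hook_)> local;
  channel.MoveTo(&local);  // splice everything out under one lock acquisition
  while (intrusive::shared_ptr<Msg> msg = local.PopFront()) {
    // handle msg without holding the mutex
  }
}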
  {
    "path": "oneflow/core/intrusive/object_pool.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_OBJECT_POOL_H_\n#define ONEFLOW_CORE_INTRUSIVE_OBJECT_POOL_H_\n\n#include <vector>\n#include \"oneflow/core/intrusive/cpp_attribute.h\"\n\nnamespace oneflow {\nnamespace intrusive {\n\nenum ObjectPoolStrategey {\n  kThreadUnsafeAndDisableDestruct,\n};\n\ntemplate<typename T, ObjectPoolStrategey object_pool_strategy>\nclass ObjectPool;\n\ntemplate<typename T, ObjectPoolStrategey object_pool_strategy>\nclass EnableObjectPool {\n public:\n  EnableObjectPool() = default;\n  EnableObjectPool(const EnableObjectPool&) = default;\n  EnableObjectPool(EnableObjectPool&&) = default;\n  ~EnableObjectPool() = default;\n\n  using object_pool_type = ObjectPool<T, object_pool_strategy>;\n  object_pool_type* mut_object_pool() { return object_pool_; }\n  void set_object_pool(object_pool_type* val) { object_pool_ = val; }\n\n private:\n  object_pool_type* object_pool_;\n};\n\ntemplate<typename T>\nclass ObjectPool<T, kThreadUnsafeAndDisableDestruct> {\n public:\n  ObjectPool() { container_.reserve(kObjectPoolInitCap); }\n  ObjectPool(const ObjectPool&) = delete;\n  ObjectPool(ObjectPool&&) = delete;\n  ~ObjectPool() {\n    for (auto* elem : container_) { delete elem; }\n  }\n\n  template<typename... Args>\n  intrusive::shared_ptr<T> make_shared(Args&&... args) {\n    if (INTRUSIVE_PREDICT_FALSE(container_.empty())) {\n      auto ptr = intrusive::make_shared<T>(std::forward<Args>(args)...);\n      InitObjectPoolFields4Element(ptr.get());\n      return ptr;\n    } else {\n      auto* ptr = container_.back();\n      container_.pop_back();\n      ptr->__Init__(std::forward<Args>(args)...);\n      InitObjectPoolFields4Element(ptr);\n      return intrusive::shared_ptr<T>(ptr);\n    }\n  }\n\n  static void Put(void* raw_ptr) {\n    T* ptr = reinterpret_cast<T*>(raw_ptr);\n    ptr->mut_object_pool()->container_.push_back(ptr);\n  }\n\n private:\n  inline void InitObjectPoolFields4Element(T* ptr) {\n    ptr->set_object_pool(this);\n    ptr->mut_intrusive_ref()->set_deleter(&ObjectPool::Put);\n  }\n\n  static constexpr int kObjectPoolInitCap = 65536;\n  std::vector<T*> container_;\n};\n\n}  // namespace intrusive\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_OBJECT_POOL_H_\n"
  },
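The pool works by installing ObjectPool::Put as the element's deleter, so a ref count reaching zero returns the storage to `container_` instead of running `__Delete__` and `delete`; the next make_shared re-runs `__Init__` on the recycled object. A sketch mirroring the tests that follow, assuming oneflow's namespace is in scope; `PooledFoo` is illustrative.

class PooledFoo final
    : public intrusive::Base,
      public intrusive::EnableObjectPool<PooledFoo, intrusive::kThreadUnsafeAndDisableDestruct> {
 public:
  PooledFoo() = default;
  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }

 private:
  intrusive::Ref intrusive_ref_;
};

void PoolSketch() {
  intrusive::ObjectPool<PooledFoo, intrusive::kThreadUnsafeAndDisableDestruct> pool;
  PooledFoo* first = pool.make_shared().get();  // the temporary handle dies immediately,
                                                // so Put() recycles the storage at once
  intrusive::shared_ptr<PooledFoo> reused = pool.make_shared();
  // first == reused.get(): the pool handed the recycled object back, as the tests assert.
  (void)first;
}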
  {
    "path": "oneflow/core/intrusive/object_pool_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <sstream>\n#include \"gtest/gtest.h\"\n#define private public\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/intrusive/intrusive.h\"\n#include \"oneflow/core/intrusive/object_pool.h\"\n\nnamespace oneflow {\n\nnamespace intrusive {\n\nnamespace test {\n\nnamespace {\n\nclass IntrusiveFoo final  // NOLINT\n    : public intrusive::Base,\n      public intrusive::EnableObjectPool<IntrusiveFoo, kThreadUnsafeAndDisableDestruct> {  // NOLINT\n public:\n  IntrusiveFoo() = default;  // NOLINT\n\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n private:\n  intrusive::Ref intrusive_ref_;\n};\n\nTEST(ObjectPool_kThreadUnsafeAndDisableDestruct, append_to_pool) {\n  ObjectPool<IntrusiveFoo, kThreadUnsafeAndDisableDestruct> object_pool;\n  IntrusiveFoo* ptr = nullptr;\n  { ptr = object_pool.make_shared().get(); }\n  ASSERT_EQ(ptr, object_pool.make_shared().get());\n}\n\nTEST(ObjectPool_kThreadUnsafeAndDisableDestruct, recycle) {\n  ObjectPool<IntrusiveFoo, kThreadUnsafeAndDisableDestruct> object_pool;\n  auto* ptr = object_pool.make_shared().get();\n  ASSERT_EQ(ptr, object_pool.make_shared().get());\n}\n\n}  // namespace\n}  // namespace test\n}  // namespace intrusive\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/intrusive/ref.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_REF_H_\n#define ONEFLOW_CORE_INTRUSIVE_REF_H_\n\n#include <atomic>\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/intrusive/cpp_attribute.h\"\n\nnamespace oneflow {\n\nnamespace intrusive {\n\nclass Ref {\n public:\n  Ref() : ref_cnt_(), deleter_(nullptr) {}\n\n  using RefCntType = int32_t;\n\n  RefCntType ref_cnt() const { return ref_cnt_; }\n\n  template<typename T>\n  static void NewAndInitRef(T** ptr) {\n    *ptr = new T();\n    (*ptr)->mut_intrusive_ref()->InitRefCount();\n    IncreaseRef(*ptr);\n  }\n  template<typename T>\n  static void IncreaseRef(T* ptr) {\n    ptr->mut_intrusive_ref()->IncreaseRefCount();\n  }\n  template<typename T>\n  static void DecreaseRef(T* ptr) {\n    CHECK_NOTNULL(ptr);\n    auto* ref = ptr->mut_intrusive_ref();\n    if (INTRUSIVE_PREDICT_TRUE(ref->DecreaseRefCount() > 0)) { return; }\n    if (INTRUSIVE_PREDICT_TRUE(ref->deleter_ == nullptr)) {\n      ptr->__Delete__();\n      delete ptr;\n    } else {\n      ref->deleter_(ptr);\n    }\n  }\n\n  void set_deleter(void (*deleter)(void*)) { deleter_ = deleter; }\n\n private:\n  void InitRefCount() { ref_cnt_ = 0; }\n  void IncreaseRefCount() { ref_cnt_++; }\n  RefCntType DecreaseRefCount() { return --ref_cnt_; }\n\n  std::atomic<RefCntType> ref_cnt_;\n  void (*deleter_)(void*);\n};\n\n}  // namespace intrusive\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_REF_H_\n"
  },
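shared_ptr later in this dump is a thin wrapper over this manual protocol: NewAndInitRef allocates with a count of one, IncreaseRef/DecreaseRef adjust the atomic count, and the final decrement either runs `__Delete__` plus `delete` or defers to an installed deleter (that hook is how the object pool above recycles storage). A minimal sketch of the raw calls, assuming oneflow's namespace is in scope; `Widget` is an illustrative fixture.

class Widget final : public intrusive::Base {
 public:
  Widget() = default;

 private:
  friend class intrusive::Ref;  // lets Ref call mut_intrusive_ref()
  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
  intrusive::Ref intrusive_ref_;
};

void ManualRefSketch() {
  Widget* raw = nullptr;
  intrusive::Ref::NewAndInitRef(&raw);  // new Widget, ref_cnt == 1
  intrusive::Ref::IncreaseRef(raw);     // ref_cnt == 2
  intrusive::Ref::DecreaseRef(raw);     // ref_cnt == 1
  intrusive::Ref::DecreaseRef(raw);     // ref_cnt == 0: __Delete__() runs, then delete
}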
  {
    "path": "oneflow/core/intrusive/reflective.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_REFLECTIVE_CORE_H_\n#define ONEFLOW_CORE_INTRUSIVE_REFLECTIVE_CORE_H_\n\n#include \"oneflow/core/intrusive/dss.h\"\n#include \"oneflow/core/intrusive/static_counter.h\"\n#include \"oneflow/core/intrusive/struct_traits.h\"\n#include \"oneflow/core/intrusive/base.h\"\n\nnamespace oneflow {\n\n#define REFLECTIVE_CLASS_BEGIN(class_name)           \\\n  struct class_name final : public intrusive::Base { \\\n   public:                                           \\\n    using self_type = class_name;                    \\\n    static const bool __has_intrusive_ref__ = true;  \\\n                                                     \\\n   private:                                          \\\n    DEFINE_STATIC_COUNTER(field_counter);            \\\n    DSS_BEGIN(STATIC_COUNTER(field_counter), class_name);\n\n#define REFLECTIVE_CLASS_END(class_name)                                            \\\n  static_assert(__has_intrusive_ref__, \"this class is not intrusive-referenced\");   \\\n                                                                                    \\\n public:                                                                            \\\n  static const int __NumberOfFields__ = STATIC_COUNTER(field_counter);              \\\n                                                                                    \\\n private:                                                                           \\\n  INCREASE_STATIC_COUNTER(field_counter);                                           \\\n  DSS_END(STATIC_COUNTER(field_counter), \"intrusive-referenced class\", class_name); \\\n  }                                                                                 \\\n  ;\n\n#define REFLECTIVE_CLASS_DEFINE_FIELD(field_type, field_name)                               \\\n  static_assert(__has_intrusive_ref__, \"this class is not intrusive-referenced\");           \\\n  field_type field_name;                                                                    \\\n  INCREASE_STATIC_COUNTER(field_counter);                                                   \\\n  DSS_DEFINE_FIELD(STATIC_COUNTER(field_counter), \"intrusive-referenced class\", field_type, \\\n                   field_name);\n\n#define REFLECTIVE_FIELD(struct_type, field_name)                                             \\\n  intrusive::OffsetStructField<struct_type, struct_type::OF_PP_CAT(field_name, DssFieldType), \\\n                               struct_type::OF_PP_CAT(field_name, kDssFieldOffset)>\n\n// Get field number by field name\n// note: field numbers start from 1 instead of 0.\n#define REFLECTIVE_FIELD_NUMBER(cls, field_name) cls::OF_PP_CAT(field_name, kDssFieldNumber)\n\n// Get field type by field number\n#define REFLECTIVE_FIELD_TYPE(cls, field_number) cls::template __DssFieldType__<field_number>::type\n\n// Get field offset by field number\n#define REFLECTIVE_FIELD_OFFSET(cls, 
field_number) \\\n  cls::template __DssFieldOffset4FieldIndex__<field_number>::value\n\n// Get the field counter currently defined inside an intrusive-referenced class.\n// note: only usable between REFLECTIVE_CLASS_BEGIN ... REFLECTIVE_CLASS_END\n// e.g.:\n// REFLECTIVE_CLASS_BEGIN(Foo);\n//   static_assert(REFLECTIVE_FIELD_COUNTER == 0, \"\");\n//   REFLECTIVE_CLASS_DEFINE_FIELD(int64_t, a);\n//   static_assert(REFLECTIVE_FIELD_COUNTER == 1, \"\");\n//   REFLECTIVE_CLASS_DEFINE_FIELD(int64_t, b);\n//   static_assert(REFLECTIVE_FIELD_COUNTER == 2, \"\");\n//   REFLECTIVE_CLASS_DEFINE_FIELD(int8_t, c);\n//   static_assert(REFLECTIVE_FIELD_COUNTER == 3, \"\");\n//   REFLECTIVE_CLASS_DEFINE_FIELD(int64_t, d);\n// REFLECTIVE_CLASS_END(Foo);\n#define REFLECTIVE_FIELD_COUNTER STATIC_COUNTER(field_counter)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_REFLECTIVE_CORE_H_\n"
  },
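The intrusive_core_test.cpp entry earlier in this dump exercises these macros with static_asserts; condensed here into one self-contained sketch of the reflection surface (the `Point` class is illustrative, and oneflow's namespace is assumed to be in scope).

// clang-format off
REFLECTIVE_CLASS_BEGIN(Point);
  Point() = default;
  REFLECTIVE_CLASS_DEFINE_FIELD(int32_t, x);
  REFLECTIVE_CLASS_DEFINE_FIELD(int32_t, y);
REFLECTIVE_CLASS_END(Point);
// clang-format on

// Field numbers start at 1; number, type, and offset are all compile-time queries.
static_assert(REFLECTIVE_FIELD_NUMBER(Point, x) == 1, "");
static_assert(REFLECTIVE_FIELD_NUMBER(Point, y) == 2, "");
static_assert(std::is_same<REFLECTIVE_FIELD_TYPE(Point, 1), int32_t>::value, "");
using PointXField = REFLECTIVE_FIELD(Point, x);  // an OffsetStructField mapping Point <-> x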
  {
    "path": "oneflow/core/intrusive/shared_ptr.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_SHARED_PTR_H_\n#define ONEFLOW_CORE_INTRUSIVE_SHARED_PTR_H_\n\n#include \"oneflow/core/intrusive/ref.h\"\n\nnamespace oneflow {\n\nnamespace intrusive {\n\ntemplate<typename T>\nclass shared_ptr final {\n public:\n  using value_type = T;\n  shared_ptr() : ptr_(nullptr) {}\n  shared_ptr(value_type* ptr) : ptr_(nullptr) { Reset(ptr); }\n  shared_ptr(const shared_ptr& obj_ptr) {\n    ptr_ = nullptr;\n    Reset(obj_ptr.ptr_);\n  }\n  shared_ptr(shared_ptr&& obj_ptr) noexcept {\n    ptr_ = obj_ptr.ptr_;\n    obj_ptr.ptr_ = nullptr;\n  }\n  // NOLINTNEXTLINE(google-explicit-constructor)\n  operator shared_ptr<const T>() const { return shared_ptr<const T>(ptr_); }\n  ~shared_ptr() { Clear(); }\n\n  template<typename... Args>\n  static shared_ptr make_shared(Args&&... args) {\n    shared_ptr ret;\n    Ref::NewAndInitRef(&ret.ptr_);\n    ret.Mutable()->__Init__(std::forward<Args>(args)...);\n    return ret;\n  }\n\n  explicit operator bool() const { return ptr_ != nullptr; }\n  value_type* get() const { return ptr_; }\n  const value_type& Get() const { return *ptr_; }\n  value_type* operator->() const { return ptr_; }\n  value_type& operator*() const { return *ptr_; }\n  bool operator==(const shared_ptr& rhs) const { return this->ptr_ == rhs.ptr_; }\n\n  value_type* Mutable() { return ptr_; }\n\n  void Reset() { Reset(nullptr); }\n\n  void Reset(value_type* ptr) {\n    Clear();\n    if (ptr == nullptr) { return; }\n    ptr_ = ptr;\n    Ref::IncreaseRef<value_type>(ptr_);\n  }\n\n  shared_ptr& operator=(const shared_ptr& rhs) {\n    Reset(rhs.ptr_);\n    return *this;\n  }\n\n  shared_ptr& operator=(shared_ptr&& rhs) noexcept {\n    ptr_ = rhs.ptr_;\n    rhs.ptr_ = nullptr;\n    return *this;\n  }\n\n  static shared_ptr __UnsafeMove__(value_type* ptr) {\n    shared_ptr ret;\n    ret.ptr_ = ptr;\n    return ret;\n  }\n  void __UnsafeMoveTo__(value_type** ptr) {\n    *ptr = ptr_;\n    ptr_ = nullptr;\n  }\n\n private:\n  void Clear() {\n    if (ptr_ == nullptr) { return; }\n    Ref::DecreaseRef<value_type>(ptr_);\n    ptr_ = nullptr;\n  }\n  mutable value_type* ptr_;\n};\n\ntemplate<typename T, typename... Args>\nshared_ptr<T> make_shared(Args&&... args) {\n  return shared_ptr<T>::make_shared(std::forward<Args>(args)...);\n}\n\n}  // namespace intrusive\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_SHARED_PTR_H_\n"
  },
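  {
    "path": "oneflow/core/intrusive/shared_ptr_usage_sketch.cpp",
    "content": "/*\nEditor's illustrative sketch (hypothetical file, not part of the original tree): a minimal\nusage example for intrusive::shared_ptr. It assumes a type following the intrusive::Base\nconvention used by the tests in this directory: a private constructor, a friend declaration\nfor intrusive::Ref, an intrusive::Ref member exposed via mut_intrusive_ref(), and __Init__\noverloads that make_shared forwards its arguments to.\n*/\n#include <utility>\n#include \"oneflow/core/intrusive/intrusive.h\"\n\nnamespace oneflow {\nnamespace intrusive_example {\n\nclass Foo final : public intrusive::Base {\n public:\n  void __Init__() {}\n  void __Init__(int x) { x_ = x; }\n  int x() const { return x_; }\n  size_t ref_cnt() const { return intrusive_ref_.ref_cnt(); }\n\n private:\n  friend class intrusive::Ref;\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n  Foo() : intrusive_ref_(), x_(0) {}\n  intrusive::Ref intrusive_ref_;\n  int x_;\n};\n\nvoid Demo() {\n  // make_shared allocates the object, holds one reference, and forwards to __Init__(int).\n  auto a = intrusive::make_shared<Foo>(42);     // a->ref_cnt() == 1\n  // Copying bumps the intrusive reference count; moving transfers it unchanged.\n  intrusive::shared_ptr<Foo> b = a;             // ref_cnt == 2\n  intrusive::shared_ptr<Foo> c = std::move(b);  // b is now empty, ref_cnt still 2\n  // The object is destroyed when the last shared_ptr referencing it goes away.\n}\n\n}  // namespace intrusive_example\n}  // namespace oneflow\n"
  },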
  {
    "path": "oneflow/core/intrusive/skiplist.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_SKIPLIST_H_\n#define ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_SKIPLIST_H_\n\n#include \"oneflow/core/intrusive/ref.h\"\n#include \"oneflow/core/intrusive/skiplist_hook.h\"\n\nnamespace oneflow {\n\nnamespace intrusive {\n\ntemplate<typename ElemKeyField>\nclass SkipList {\n public:\n  SkipList(const SkipList&) = delete;\n  SkipList(SkipList&&) = delete;\n\n  SkipList() { this->__Init__(); }\n  ~SkipList() { this->Clear(); }\n\n  using value_type = typename ElemKeyField::struct_type;\n  using key_type = typename ElemKeyField::field_type::key_type;\n  using elem_key_level0_hook_struct_field =\n      OffsetStructField<typename ElemKeyField::field_type, intrusive::ListHook,\n                        ElemKeyField::field_type::LevelZeroHookOffset()>;\n  using iterator_struct_field = ComposeStructField<ElemKeyField, elem_key_level0_hook_struct_field>;\n  template<typename Enabled = void>\n  static constexpr int IteratorHookOffset() {\n    return offsetof(SkipList, skiplist_head_)\n           + intrusive::SkipListHead<ElemKeyField>::IteratorHookOffset();\n  }\n\n  void __Init__() { skiplist_head_.__Init__(); }\n\n  std::size_t size() const { return skiplist_head_.size(); }\n  bool empty() const { return skiplist_head_.empty(); }\n  value_type* Begin() { return skiplist_head_.Begin(); }\n  intrusive::shared_ptr<value_type> Find(const key_type& key) {\n    intrusive::shared_ptr<value_type> ret;\n    ret.Reset(skiplist_head_.Find(key));\n    return ret;\n  }\n  value_type* FindPtr(const key_type& key) { return skiplist_head_.Find(key); }\n  const value_type* FindPtr(const key_type& key) const { return skiplist_head_.Find(key); }\n  bool EqualsEnd(const intrusive::shared_ptr<value_type>& ptr) { return !ptr; }\n  void Erase(const key_type& key) { Ref::DecreaseRef(skiplist_head_.Erase(key)); }\n  void Erase(value_type* elem_ptr) {\n    skiplist_head_.Erase(elem_ptr);\n    Ref::DecreaseRef(elem_ptr);\n  }\n  std::pair<intrusive::shared_ptr<value_type>, bool> Insert(value_type* elem_ptr) {\n    value_type* ret_elem = nullptr;\n    bool success = false;\n    std::tie(ret_elem, success) = skiplist_head_.Insert(elem_ptr);\n    std::pair<intrusive::shared_ptr<value_type>, bool> ret;\n    ret.first.Reset(ret_elem);\n    ret.second = success;\n    if (success) { Ref::IncreaseRef(elem_ptr); }\n    return ret;\n  }\n\n  void Clear() {\n    skiplist_head_.Clear([](value_type* elem) { Ref::DecreaseRef(elem); });\n  }\n\n private:\n  intrusive::SkipListHead<ElemKeyField> skiplist_head_;\n};\n\n}  // namespace intrusive\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_SKIPLIST_H_\n"
  },
  {
    "path": "oneflow/core/intrusive/skiplist_hook.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_EMBEDDED_SKIPLIST_H_\n#define ONEFLOW_CORE_INTRUSIVE_EMBEDDED_SKIPLIST_H_\n\n#include <array>\n#include <tuple>\n#include <random>\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/intrusive/struct_traits.h\"\n#include \"oneflow/core/intrusive/list_hook.h\"\n\nnamespace oneflow {\n\nnamespace intrusive {\n\ntemplate<int max_level>\nstruct ListHookArray final {\n public:\n  ListHookArray() { Clear(); }\n  using self_type = ListHookArray<max_level>;\n  template<typename Enabled = void>\n  static constexpr int LevelZeroHookOffset() {\n    return 0;\n  }\n\n  bool empty() const { return hooks_[0].nullptr_empty(); }\n\n  void __Init__() { Clear(); }\n\n  void Clear() {\n    for (auto& hook : hooks_) { hook.Clear(); }\n  }\n  void NullptrClear() {\n    for (auto& hook : hooks_) { hook.NullptrClear(); }\n  }\n  void InsertAfter(ListHookArray* prev_skiplist_hook, int levels) {\n    CHECK(empty());\n    ListHook* prev_hook = &prev_skiplist_hook->hooks_[0];\n    int i = 0;\n    for (; i < levels; ++i, ++prev_hook) {\n      while (prev_hook->nullptr_empty()) { prev_hook = (prev_hook - 1)->prev() + 1; }\n      hooks_[i].InsertAfter(prev_hook);\n    }\n  }\n  void Erase() {\n    for (int i = 0; i < max_level; ++i) {\n      if (hooks_[i].nullptr_empty()) { return; }\n      hooks_[i].next()->AppendTo(hooks_[i].prev());\n      hooks_[i].NullptrClear();\n    }\n  }\n  static ListHookArray* ThisPtr4HookPtr(ListHook* slist_ptr, int level) {\n    auto* hooks_ptr = (std::array<intrusive::ListHook, max_level>*)(slist_ptr - level);\n    return OffsetStructField<self_type, decltype(hooks_), HooksOffset()>::StructPtr4FieldPtr(\n        hooks_ptr);\n  }\n  void CheckEmpty() const {\n    for (const auto& hook : hooks_) { CHECK(hook.empty()); }\n  }\n  void CheckNullptrEmpty() const {\n    for (const auto& hook : hooks_) { CHECK(hook.nullptr_empty()); }\n  }\n\n  ListHook* mutable_hook(int i) { return &hooks_[i]; }\n\n private:\n  template<typename Enabled = void>\n  static constexpr int HooksOffset() {\n    return offsetof(self_type, hooks_);\n  }\n\n  std::array<intrusive::ListHook, max_level> hooks_;\n};\n\ntemplate<typename T, int N = 20>\nstruct SkipListHook {\n public:\n  SkipListHook() : key_() { __Init__(); }\n  using self_type = SkipListHook<T, N>;\n  using hook_type = ListHookArray<N>;\n  using key_type = T;\n  static const int max_level = N;\n  static_assert(N > 0, \"invalid number of levels\");\n  template<typename Enabled = void>\n  static constexpr int LevelZeroHookOffset() {\n    return offsetof(SkipListHook, hook_) + hook_type::LevelZeroHookOffset();\n  }\n\n  bool empty() const { return hook_.empty(); }\n\n  void __Init__() { hook_.NullptrClear(); }\n\n  const T& key() const { return key_; }\n  T* mut_key() { return &key_; }\n\n  void CheckEmpty() const { return hook_.CheckNullptrEmpty(); }\n\n  void Clear() {\n    hook_.NullptrClear();\n    
mut_key()->__Delete__();\n  }\n\n  static self_type* Find(const key_type& key, hook_type* head, int size_shift) {\n    ListHook* last_hook_less_than_key = SearchLastBottomHookLessThan(key, head, size_shift);\n    if (last_hook_less_than_key->next() == head->mutable_hook(0)) { return nullptr; }\n    self_type* searched = ThisPtr4HookPtr(last_hook_less_than_key->next(), 0);\n    if (searched->key() == key) { return searched; }\n    return nullptr;\n  }\n  static self_type* Erase(const key_type& key, hook_type* head, int size_shift) {\n    self_type* searched = Find(key, head, size_shift);\n    CHECK_NOTNULL(searched);\n    Erase(searched);\n    return searched;\n  }\n  static void Erase(self_type* elem) { elem->hook_.Erase(); }\n  // return true if success\n  static std::pair<self_type*, bool> Insert(self_type* elem, hook_type* head, int size_shift) {\n    ListHook* prev_list_hook = SearchLastBottomHookLessThan(elem->key(), head, size_shift);\n    self_type* maybe_searched = nullptr;\n    if (prev_list_hook->next() == head->mutable_hook(0)) {\n      maybe_searched = nullptr;\n    } else {\n      maybe_searched = ThisPtr4HookPtr(prev_list_hook->next(), 0);\n    }\n    self_type* ret_elem = nullptr;\n    bool success = false;\n    if (maybe_searched != nullptr && (maybe_searched->key() == elem->key())) {\n      ret_elem = maybe_searched;\n      success = false;\n    } else {\n      self_type* prev = ThisPtr4HookPtr(prev_list_hook, 0);\n      ret_elem = elem;\n      elem->hook_.InsertAfter(&prev->hook_, RandomNumLevels(size_shift));\n      success = true;\n    }\n    // CHECK_EQ(Find(ret_elem->key(), head), ret_elem, GetMaxVal<int32_t>() / 2);\n    return std::make_pair(ret_elem, success);\n  }\n  static SkipListHook* ThisPtr4HookPtr(ListHook* list_hook_ptr, int level) {\n    auto* skip_list_ptr = hook_type::ThisPtr4HookPtr(list_hook_ptr, level);\n    using FieldUtil = OffsetStructField<self_type, hook_type, SkipListIteratorOffset()>;\n    return FieldUtil::StructPtr4FieldPtr(skip_list_ptr);\n  }\n\n private:\n  template<typename Enabled = void>\n  static constexpr int SkipListIteratorOffset() {\n    return offsetof(self_type, hook_);\n  }\n  static int32_t RandomNumLevels(int size_shift) {\n    std::minstd_rand rand{std::random_device{}()};\n    int32_t max_num_levels = std::min(size_shift, N);\n    int32_t num_levels = 1;\n    for (int i = 1; (rand() % 2 == 0) && i < max_num_levels; ++i) { ++num_levels; }\n    return num_levels;\n  }\n\n  static ListHook* SearchLastBottomHookLessThan(const key_type& key, hook_type* head,\n                                                int size_shift) {\n    int max_num_level = std::min(size_shift, N);\n    ListHook* list_hook = head->mutable_hook(max_num_level);\n    for (int level = max_num_level - 1; level >= 0; --level) {\n      --list_hook;\n      while (list_hook->next() != head->mutable_hook(level)\n             && ThisPtr4HookPtr(list_hook->next(), level)->key() < key) {\n        list_hook = list_hook->next();\n      }\n    }\n    return list_hook;\n  }\n\n  hook_type hook_;\n  T key_;\n};\n\ntemplate<typename ValueHookField>\nclass SkipListHead {\n public:\n  SkipListHead() { __Init__(); }\n  using value_type = typename ValueHookField::struct_type;\n  using key_hook_type = typename ValueHookField::field_type;\n  using key_type = typename key_hook_type::key_type;\n  using value_key_level0_hook_struct_field =\n      OffsetStructField<typename ValueHookField::field_type, intrusive::ListHook,\n                        
ValueHookField::field_type::LevelZeroHookOffset()>;\n  using value_level0_hook_struct_field =\n      ComposeStructField<ValueHookField, value_key_level0_hook_struct_field>;\n  static const int max_level = key_hook_type::max_level;\n  template<typename Enabled = void>\n  static constexpr int IteratorHookOffset() {\n    return offsetof(SkipListHead, skiplist_head_) + ListHookArray<max_level>::LevelZeroHookOffset();\n  }\n\n  void __Init__() {\n    skiplist_head_.__Init__();\n    size_ = 0;\n  }\n\n  std::size_t size() const { return size_; }\n  bool empty() const { return size_ == 0; }\n\n  value_type* Begin() {\n    ListHook* head_level0 = skiplist_head_.mutable_hook(0);\n    ListHook* begin_list_hook = head_level0->next();\n    if (begin_list_hook == head_level0) { return nullptr; }\n    return value_level0_hook_struct_field::StructPtr4FieldPtr(begin_list_hook);\n  }\n\n  value_type* Find(const key_type& key) {\n    auto* key_hook_ptr = key_hook_type::Find(key, &skiplist_head_, size_shift());\n    if (key_hook_ptr == nullptr) { return nullptr; }\n    return ValueHookField::StructPtr4FieldPtr(key_hook_ptr);\n  }\n  const value_type* Find(const key_type& key) const {\n    auto* key_hook_ptr = key_hook_type::Find(\n        key, const_cast<ListHookArray<max_level>*>(&skiplist_head_), size_shift());\n    if (key_hook_ptr == nullptr) { return nullptr; }\n    return ValueHookField::StructPtr4FieldPtr(key_hook_ptr);\n  }\n  value_type* Erase(const key_type& key) {\n    key_hook_type* erased = key_hook_type::Erase(key, &skiplist_head_, size_shift());\n    --size_;\n    return ValueHookField::StructPtr4FieldPtr(erased);\n  }\n  void Erase(value_type* elem) {\n    key_hook_type::Erase(ValueHookField::FieldPtr4StructPtr(elem));\n    --size_;\n  }\n  // return true if success\n  std::pair<value_type*, bool> Insert(value_type* elem) {\n    key_hook_type* elem_key_hook = ValueHookField::FieldPtr4StructPtr(elem);\n    key_hook_type* ret_key_hook = nullptr;\n    bool success = false;\n    std::tie(ret_key_hook, success) =\n        key_hook_type::Insert(elem_key_hook, &skiplist_head_, size_shift());\n    if (success) { ++size_; }\n    return std::make_pair(ValueHookField::StructPtr4FieldPtr(ret_key_hook), success);\n  }\n\n  template<typename Callback>\n  void Clear(const Callback& cb) {\n    using hook_type = ListHookArray<max_level>;\n    for (; size_ > 0; --size_) {\n      ListHook* begin_list_hook = skiplist_head_.mutable_hook(0)->next();\n      auto* begin = hook_type::ThisPtr4HookPtr(begin_list_hook, 0);\n      if (begin == &skiplist_head_) { break; }\n      begin->Erase();\n      cb(value_level0_hook_struct_field::StructPtr4FieldPtr(begin_list_hook));\n    }\n    CHECK(empty_debug());\n  }\n  void Clear() {\n    Clear([](value_type*) {});\n  }\n\n  bool empty_debug() const {\n    bool ret = (size_ == 0);\n    if (ret) { skiplist_head_.CheckEmpty(); }\n    return ret;\n  }\n\n private:\n  int size_shift() const { return std::log2(size_ + 1); }\n\n  ListHookArray<max_level> skiplist_head_;\n  volatile std::size_t size_;\n};\n\n}  // namespace intrusive\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_EMBEDDED_SKIPLIST_H_\n"
  },
  {
    "path": "oneflow/core/intrusive/skiplist_hook_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/intrusive/skiplist_hook.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace intrusive {\n\nnamespace test {\n\ntemplate<typename ElemKeyField>\nclass TestSkipListHead final : public SkipListHead<ElemKeyField> {  // NOLINT\n public:\n  TestSkipListHead() { this->__Init__(); }\n  TestSkipListHead(const TestSkipListHead&) = delete;\n  TestSkipListHead(TestSkipListHead&&) = delete;\n  TestSkipListHead& operator==(const TestSkipListHead&) = delete;\n  TestSkipListHead& operator==(TestSkipListHead&&) = delete;\n  ~TestSkipListHead() { this->Clear(); }\n};\n\nstruct FooSkipListElem {\n  FooSkipListElem() : value() { key.__Init__(); }\n\n  int value;\n  SkipListHook<int> key;\n};\n\nusing FooSkipList = TestSkipListHead<INTRUSIVE_FIELD(FooSkipListElem, key)>;\n\nTEST(SkipListHook, empty) {\n  FooSkipList skiplist;\n  ASSERT_TRUE(skiplist.empty_debug());\n  ASSERT_EQ(skiplist.size(), 0);\n}\n\nTEST(SkipListHook, insert_naive) {\n  FooSkipList skiplist;\n  FooSkipListElem elem0;\n  *elem0.key.mut_key() = 0;\n  elem0.value = 1;\n  skiplist.Insert(&elem0);\n  ASSERT_EQ(skiplist.size(), 1);\n  {\n    auto* searched = skiplist.Find(int(0));\n    ASSERT_EQ(searched, &elem0);\n  }\n  {\n    auto* searched = skiplist.Find(int(-1));\n    ASSERT_TRUE(searched == nullptr);\n  }\n}\n\nTEST(SkipListHook, erase_by_key) {\n  FooSkipList skiplist;\n  FooSkipListElem elem0;\n  *elem0.key.mut_key() = 0;\n  elem0.value = 1;\n  skiplist.Insert(&elem0);\n  ASSERT_EQ(skiplist.size(), 1);\n  ASSERT_TRUE(skiplist.Find(int(0)) != nullptr);\n  skiplist.Erase(int(0));\n  ASSERT_EQ(skiplist.size(), 0);\n  ASSERT_TRUE(skiplist.Find(int(0)) == nullptr);\n}\n\nTEST(SkipListHook, erase_by_elem) {\n  FooSkipList skiplist;\n  FooSkipListElem elem0;\n  *elem0.key.mut_key() = 0;\n  elem0.value = 1;\n  skiplist.Insert(&elem0);\n  ASSERT_EQ(skiplist.size(), 1);\n  ASSERT_TRUE(skiplist.Find(int(0)) != nullptr);\n  skiplist.Erase(&elem0);\n  ASSERT_EQ(skiplist.size(), 0);\n  ASSERT_TRUE(skiplist.Find(int(0)) == nullptr);\n}\n\nTEST(SkipListHook, insert_many) {\n  FooSkipList skiplist;\n  FooSkipListElem exists[100];\n  for (int i = 0; i < 100; ++i) {\n    int key = i - 50;\n    if (key >= 0) { ++key; }\n    *exists[i].key.mut_key() = key;\n    skiplist.Insert(&exists[i]);\n    ASSERT_EQ(skiplist.Find(key), &exists[i]);\n  }\n  FooSkipListElem elem0;\n  *elem0.key.mut_key() = 0;\n  elem0.value = 1;\n  skiplist.Insert(&elem0);\n  ASSERT_EQ(skiplist.size(), 101);\n  {\n    auto* searched = skiplist.Find(int(0));\n    ASSERT_EQ(searched, &elem0);\n  }\n  {\n    auto* searched = skiplist.Find(int(-1001));\n    ASSERT_TRUE(searched == nullptr);\n  }\n  skiplist.Clear();\n  ASSERT_TRUE(skiplist.empty_debug());\n}\n\nTEST(SkipListHook, erase_many_by_key) {\n  FooSkipList skiplist;\n  FooSkipListElem exists[100];\n  for (int i = 0; i < 100; ++i) {\n    int 
key = i - 50;\n    if (key >= 0) { ++key; }\n    *exists[i].key.mut_key() = key;\n    skiplist.Insert(&exists[i]);\n    ASSERT_EQ(skiplist.Find(key), &exists[i]);\n  }\n  FooSkipListElem elem0;\n  *elem0.key.mut_key() = 0;\n  elem0.value = 1;\n  skiplist.Insert(&elem0);\n  ASSERT_EQ(skiplist.size(), 101);\n  ASSERT_TRUE(skiplist.Find(int(0)) != nullptr);\n  skiplist.Erase(int(0));\n  ASSERT_EQ(skiplist.size(), 100);\n  ASSERT_TRUE(skiplist.Find(int(0)) == nullptr);\n  skiplist.Clear();\n  ASSERT_TRUE(skiplist.empty_debug());\n}\n\nTEST(SkipListHook, erase_many_by_elem) {\n  FooSkipList skiplist;\n  FooSkipListElem exists[100];\n  for (int i = 0; i < 100; ++i) {\n    int key = i - 50;\n    if (key >= 0) { ++key; }\n    *exists[i].key.mut_key() = key;\n    skiplist.Insert(&exists[i]);\n    ASSERT_EQ(skiplist.Find(key), &exists[i]);\n  }\n  FooSkipListElem elem0;\n  *elem0.key.mut_key() = 0;\n  elem0.value = 1;\n  skiplist.Insert(&elem0);\n  ASSERT_EQ(skiplist.size(), 101);\n  ASSERT_TRUE(skiplist.Find(int(0)) != nullptr);\n  skiplist.Erase(&elem0);\n  ASSERT_EQ(skiplist.size(), 100);\n  ASSERT_TRUE(skiplist.Find(int(0)) == nullptr);\n  skiplist.Clear();\n  ASSERT_TRUE(skiplist.empty_debug());\n}\n\n}  // namespace test\n\n}  // namespace intrusive\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/intrusive/skiplist_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/intrusive/intrusive.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace intrusive {\n\nnamespace test {\n\nnamespace {\n\nclass SkipListFoo final : public intrusive::Base {\n public:\n  void __Init__() { clear_is_deleted(); }\n  void __Delete__() {\n    if (has_is_deleted()) { ++*mut_is_deleted(); }\n  }\n\n  // Getters\n  bool has_is_deleted() const { return is_deleted_ != nullptr; }\n  int is_deleted() const { return *is_deleted_; }\n  int32_t foo_map_key() const { return foo_map_key_.key(); }\n  // Setters\n  void set_is_deleted(int* val) { is_deleted_ = val; }\n  void clear_is_deleted() { is_deleted_ = nullptr; }\n  int* mut_is_deleted() { return is_deleted_; }\n  void set_foo_map_key(int32_t val) { *foo_map_key_.mut_key() = val; }\n\n  size_t ref_cnt() const { return intrusive_ref_.ref_cnt(); }\n\n private:\n  friend class intrusive::Ref;\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n  SkipListFoo() : intrusive_ref_(), is_deleted_(), foo_map_key_() {}\n  intrusive::Ref intrusive_ref_;\n  int* is_deleted_;\n\n public:\n  intrusive::SkipListHook<int32_t> foo_map_key_;\n};\n\nclass SkipListFooContainer final : public intrusive::Base {\n public:\n  // types\n  using Key2SkipListFoo = intrusive::SkipList<INTRUSIVE_FIELD(SkipListFoo, foo_map_key_)>;\n  // Getters\n  const Key2SkipListFoo& foo_map() const { return foo_map_; }\n  // Setters\n  Key2SkipListFoo* mut_foo_map() { return &foo_map_; }\n\n private:\n  friend class intrusive::Ref;\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n  SkipListFooContainer() : intrusive_ref_(), foo_map_() {}\n  intrusive::Ref intrusive_ref_;\n  // maps\n  Key2SkipListFoo foo_map_;\n};\n\nusing Key2SkipListFoo = intrusive::SkipList<INTRUSIVE_FIELD(SkipListFoo, foo_map_key_)>;\nTEST(SkipList, empty) {\n  Key2SkipListFoo foo_map;\n  ASSERT_TRUE(foo_map.empty());\n  ASSERT_EQ(foo_map.size(), 0);\n}\n\nTEST(SkipList, insert_naive) {\n  Key2SkipListFoo foo_map;\n  auto elem0 = intrusive::make_shared<SkipListFoo>();\n  elem0->set_foo_map_key(0);\n  foo_map.Insert(elem0.Mutable());\n  ASSERT_EQ(foo_map.size(), 1);\n  {\n    auto searched = foo_map.Find(int(0));\n    ASSERT_TRUE(searched == elem0);\n  }\n  {\n    auto searched = foo_map.Find(int(-1));\n    ASSERT_TRUE(foo_map.EqualsEnd(searched));\n  }\n}\n\nTEST(SkipList, insert_twice) {\n  Key2SkipListFoo foo_map;\n  auto elem0 = intrusive::make_shared<SkipListFoo>();\n  elem0->set_foo_map_key(0);\n  auto elem1 = intrusive::make_shared<SkipListFoo>();\n  elem1->set_foo_map_key(0);\n  ASSERT_TRUE(foo_map.Insert(elem0.Mutable()).second);\n  ASSERT_TRUE(!foo_map.Insert(elem1.Mutable()).second);\n}\n\nTEST(SkipList, erase_by_key) {\n  Key2SkipListFoo foo_map;\n  auto elem0 = intrusive::make_shared<SkipListFoo>();\n  elem0->set_foo_map_key(0);\n  foo_map.Insert(elem0.Mutable());\n  
ASSERT_EQ(foo_map.size(), 1);\n  ASSERT_TRUE(!foo_map.EqualsEnd(foo_map.Find(int(0))));\n  foo_map.Erase(int(0));\n  ASSERT_EQ(foo_map.size(), 0);\n  ASSERT_TRUE(foo_map.EqualsEnd(foo_map.Find(int(0))));\n}\n\nTEST(SkipList, erase_by_elem) {\n  Key2SkipListFoo foo_map;\n  auto elem0 = intrusive::make_shared<SkipListFoo>();\n  elem0->set_foo_map_key(0);\n  foo_map.Insert(elem0.Mutable());\n  ASSERT_EQ(foo_map.size(), 1);\n  ASSERT_TRUE(!foo_map.EqualsEnd(foo_map.Find(int(0))));\n  foo_map.Erase(elem0.Mutable());\n  ASSERT_EQ(foo_map.size(), 0);\n  ASSERT_TRUE(foo_map.EqualsEnd(foo_map.Find(int(0))));\n}\n\nTEST(SkipList, insert_many) {\n  Key2SkipListFoo foo_map;\n  intrusive::shared_ptr<SkipListFoo> exists[100];\n  for (int i = 0; i < 100; ++i) {\n    exists[i] = intrusive::make_shared<SkipListFoo>();\n    int key = i - 50;\n    if (key >= 0) { ++key; }\n    exists[i]->set_foo_map_key(key);\n    foo_map.Insert(exists[i].Mutable());\n    ASSERT_TRUE(foo_map.Find(key) == exists[i]);\n  }\n  auto elem0 = intrusive::make_shared<SkipListFoo>();\n  elem0->set_foo_map_key(0);\n  foo_map.Insert(elem0.Mutable());\n  ASSERT_EQ(foo_map.size(), 101);\n  {\n    auto searched = foo_map.Find(int(0));\n    ASSERT_TRUE(searched == elem0);\n  }\n  {\n    auto searched = foo_map.Find(int(-1001));\n    ASSERT_TRUE(foo_map.EqualsEnd(searched));\n  }\n  foo_map.Clear();\n  ASSERT_TRUE(foo_map.empty());\n}\n\nTEST(SkipList, erase_many_by_key) {\n  Key2SkipListFoo foo_map;\n  intrusive::shared_ptr<SkipListFoo> exists[100];\n  for (int i = 0; i < 100; ++i) {\n    exists[i] = intrusive::make_shared<SkipListFoo>();\n    int key = i - 50;\n    if (key >= 0) { ++key; }\n    exists[i]->set_foo_map_key(key);\n    foo_map.Insert(exists[i].Mutable());\n    ASSERT_TRUE(foo_map.Find(key) == exists[i]);\n  }\n  auto elem0 = intrusive::make_shared<SkipListFoo>();\n  elem0->set_foo_map_key(0);\n  foo_map.Insert(elem0.Mutable());\n  ASSERT_EQ(foo_map.size(), 101);\n  ASSERT_TRUE(!foo_map.EqualsEnd(foo_map.Find(int(0))));\n  foo_map.Erase(int(0));\n  ASSERT_EQ(foo_map.size(), 100);\n  ASSERT_TRUE(foo_map.EqualsEnd(foo_map.Find(int(0))));\n  foo_map.Clear();\n  ASSERT_TRUE(foo_map.empty());\n}\n\nTEST(SkipList, erase_many_by_elem) {\n  Key2SkipListFoo foo_map;\n  intrusive::shared_ptr<SkipListFoo> exists[100];\n  for (int i = 0; i < 100; ++i) {\n    exists[i] = intrusive::make_shared<SkipListFoo>();\n    int key = i - 50;\n    if (key >= 0) { ++key; }\n    exists[i]->set_foo_map_key(key);\n    foo_map.Insert(exists[i].Mutable());\n    ASSERT_TRUE(foo_map.Find(key) == exists[i]);\n  }\n  auto elem0 = intrusive::make_shared<SkipListFoo>();\n  elem0->set_foo_map_key(0);\n  foo_map.Insert(elem0.Mutable());\n  ASSERT_EQ(foo_map.size(), 101);\n  ASSERT_TRUE(!foo_map.EqualsEnd(foo_map.Find(int(0))));\n  foo_map.Erase(elem0.Mutable());\n  ASSERT_EQ(foo_map.size(), 100);\n  ASSERT_TRUE(foo_map.EqualsEnd(foo_map.Find(int(0))));\n  foo_map.Clear();\n  ASSERT_TRUE(foo_map.empty());\n}\n\nTEST(SkipList, MAP_HEAD) {\n  int elem_cnt = 0;\n  {\n    auto foo_map_container = intrusive::make_shared<SkipListFooContainer>();\n    auto& foo_map = *foo_map_container->mut_foo_map();\n    intrusive::shared_ptr<SkipListFoo> exists[100];\n    for (int i = 0; i < 100; ++i) {\n      exists[i] = intrusive::make_shared<SkipListFoo>();\n      int key = i - 50;\n      if (key >= 0) { ++key; }\n      exists[i]->set_foo_map_key(key);\n      exists[i]->set_is_deleted(&elem_cnt);\n      foo_map.Insert(exists[i].Mutable());\n      ASSERT_TRUE(foo_map.Find(key) == 
exists[i]);\n      ASSERT_EQ(exists[i]->ref_cnt(), 2);\n    }\n    auto elem0 = intrusive::make_shared<SkipListFoo>();\n    elem0->set_foo_map_key(0);\n    elem0->set_is_deleted(&elem_cnt);\n    foo_map.Insert(elem0.Mutable());\n    ASSERT_EQ(foo_map.size(), 101);\n    ASSERT_TRUE(!foo_map.EqualsEnd(foo_map.Find(int(0))));\n    ASSERT_EQ(elem0->ref_cnt(), 2);\n    foo_map.Erase(elem0->foo_map_key());\n    ASSERT_EQ(elem0->ref_cnt(), 1);\n    ASSERT_EQ(foo_map.size(), 100);\n    ASSERT_TRUE(foo_map.EqualsEnd(foo_map.Find(int(0))));\n    foo_map.Clear();\n    ASSERT_TRUE(foo_map.empty());\n  }\n  ASSERT_EQ(elem_cnt, 101);\n}\n\nTEST(SkipList, FOR_EACH) {\n  int elem_cnt = 0;\n  {\n    auto foo_map_container = intrusive::make_shared<SkipListFooContainer>();\n    auto& foo_map = *foo_map_container->mut_foo_map();\n    intrusive::shared_ptr<SkipListFoo> exists[100];\n    for (int i = 0; i < 100; ++i) {\n      exists[i] = intrusive::make_shared<SkipListFoo>();\n      int key = i - 50;\n      exists[i]->set_foo_map_key(key);\n      exists[i]->set_is_deleted(&elem_cnt);\n      foo_map.Insert(exists[i].Mutable());\n      ASSERT_TRUE(foo_map.Find(key) == exists[i]);\n      ASSERT_EQ(exists[i]->ref_cnt(), 2);\n    }\n    int value = -50;\n    INTRUSIVE_UNSAFE_FOR_EACH_PTR(foo, &foo_map) {\n      ASSERT_EQ(foo->foo_map_key(), value);\n      ++value;\n    }\n  }\n  ASSERT_EQ(elem_cnt, 100);\n}\n\n}  // namespace\n\n}  // namespace test\n\n}  // namespace intrusive\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/intrusive/static_counter.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_STATIC_COUNTER_H_\n#define ONEFLOW_CORE_INTRUSIVE_STATIC_COUNTER_H_\n\nnamespace oneflow {\n\n#define STATIC_COUNTER(counter_name) _STATIC_COUNTER_NAME(counter_name)<_AUTO_INCREMENT()>::value\n\n#define DEFINE_STATIC_COUNTER(counter_name) _DEFINE_STATIC_COUNTER(_AUTO_INCREMENT(), counter_name)\n\n#define INCREASE_STATIC_COUNTER(counter_name) \\\n  _INCREASE_STATIC_COUNTER(_AUTO_INCREMENT(), counter_name)\n\n// details\n\n#define _STATIC_COUNTER_NAME(counter_name) StaticCounter_##counter_name\n#define _AUTO_INCREMENT() __COUNTER__\n\n#define _DEFINE_STATIC_COUNTER(auto_counter, counter_name)                               \\\n  template<int tpl_counter, typename Enabled = void>                                     \\\n  struct _STATIC_COUNTER_NAME(counter_name) {                                            \\\n    static const int value = _STATIC_COUNTER_NAME(counter_name)<tpl_counter - 1>::value; \\\n  };                                                                                     \\\n  template<typename Enabled>                                                             \\\n  struct _STATIC_COUNTER_NAME(counter_name)<auto_counter, Enabled> {                     \\\n    static const int value = 0;                                                          \\\n  };\n\n#define _INCREASE_STATIC_COUNTER(auto_counter, counter_name)                                  \\\n  template<typename Enabled>                                                                  \\\n  struct _STATIC_COUNTER_NAME(counter_name)<auto_counter, Enabled> {                          \\\n    static const int value = _STATIC_COUNTER_NAME(counter_name)<auto_counter - 1>::value + 1; \\\n  };\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_STATIC_COUNTER_H_\n"
  },
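  {
    "path": "oneflow/core/intrusive/static_counter_example_sketch.cpp",
    "content": "/*\nEditor's illustrative sketch (hypothetical file, not part of the original tree): shows how\nthe __COUNTER__-based static counter in static_counter.h works. DEFINE_STATIC_COUNTER pins\na template specialization with value 0 at the then-current __COUNTER__ value; each\nINCREASE_STATIC_COUNTER pins another specialization equal to the previous value plus one;\nSTATIC_COUNTER instantiates the primary template at a later __COUNTER__ value, and the\nprimary template recurses downward until it reaches the nearest pinned specialization.\n*/\n#include \"oneflow/core/intrusive/static_counter.h\"\n\nnamespace oneflow {\nnamespace static_counter_example {\n\nDEFINE_STATIC_COUNTER(demo);  // pins demo == 0 at this point in the translation unit\nstatic_assert(STATIC_COUNTER(demo) == 0, \"\");\n\nINCREASE_STATIC_COUNTER(demo);  // pins demo == previous + 1\nstatic_assert(STATIC_COUNTER(demo) == 1, \"\");\n\n// Reading the counter does not change it.\nstatic_assert(STATIC_COUNTER(demo) == 1, \"\");\n\n}  // namespace static_counter_example\n}  // namespace oneflow\n"
  },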
  {
    "path": "oneflow/core/intrusive/static_counter_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/intrusive/static_counter.h\"\n#include \"oneflow/core/intrusive/intrusive.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace test {\n\nnamespace {\n\nDEFINE_STATIC_COUNTER(static_counter);\n\nstatic_assert(STATIC_COUNTER(static_counter) == 0, \"\");\n\nTEST(StaticCounter, eq0) { static_assert(STATIC_COUNTER(static_counter) == 0, \"\"); }\n\nINCREASE_STATIC_COUNTER(static_counter);\n\nstatic_assert(STATIC_COUNTER(static_counter) == 1, \"\");\n\nTEST(StaticCounter, eq1) { static_assert(STATIC_COUNTER(static_counter) == 1, \"\"); }\n\nstatic_assert(STATIC_COUNTER(static_counter) == 1, \"\");\n\nTEST(StaticCounter, eq1_again) { static_assert(STATIC_COUNTER(static_counter) == 1, \"\"); }\n\nINCREASE_STATIC_COUNTER(static_counter);\n\nstatic_assert(STATIC_COUNTER(static_counter) == 2, \"\");\n\nTEST(StaticCounter, eq2) { static_assert(STATIC_COUNTER(static_counter) == 2, \"\"); }\n\n// clang-format off\nREFLECTIVE_CLASS_BEGIN(FooBar);\n  FooBar() = default;\n  static_assert(STATIC_COUNTER(field_counter) == 0, \"\");\nREFLECTIVE_CLASS_END(FooBar);\n// clang-format on\n\n}  // namespace\n\n}  // namespace test\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/intrusive/struct_traits.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_INTRUSIVE_STRUCT_MACRO_TRAITS_H_\n#define ONEFLOW_CORE_INTRUSIVE_STRUCT_MACRO_TRAITS_H_\n\n#include <cstddef>\n#include <type_traits>\n#include \"oneflow/core/common/preprocessor.h\"\n\nnamespace oneflow {\nnamespace intrusive {\n\ntemplate<typename T, typename F, F T::*ptr2member>\nstruct PtrStructField {\n  using struct_type = T;\n  using field_type = F;\n\n  static T* StructPtr4FieldPtr(const F* field_ptr) {\n    int offset_value = reinterpret_cast<long long>(&(((T*)nullptr)->*ptr2member));\n    return (T*)((const_cast<char*>(reinterpret_cast<const char*>(field_ptr))) - offset_value);\n  }\n  static F* FieldPtr4StructPtr(const T* struct_ptr) {\n    return &(const_cast<T*>(struct_ptr)->*ptr2member);\n  }\n};\n\ntemplate<typename T, typename F, int offset>\nstruct OffsetStructField {\n  using struct_type = T;\n  using field_type = F;\n  static const int offset_value = offset;\n\n  static T* StructPtr4FieldPtr(const F* field_ptr) {\n    return (T*)((const_cast<char*>(reinterpret_cast<const char*>(field_ptr))) - offset_value);\n  }\n  static F* FieldPtr4StructPtr(const T* struct_ptr) {\n    return (F*)((const_cast<char*>(reinterpret_cast<const char*>(struct_ptr))) + offset_value);\n  }\n};\n\n#define INTRUSIVE_FIELD(struct_type, field_name)                                        \\\n  intrusive::PtrStructField<struct_type, decltype(((struct_type*)nullptr)->field_name), \\\n                            &struct_type::field_name>\n\ntemplate<typename X, typename Y>\nstruct ComposeStructField {\n  static_assert(std::is_same<typename X::field_type, typename Y::struct_type>::value,\n                \"invalid type\");\n  using struct_type = typename X::struct_type;\n  using field_type = typename Y::field_type;\n  static struct_type* StructPtr4FieldPtr(const field_type* field_ptr) {\n    return X::StructPtr4FieldPtr(Y::StructPtr4FieldPtr(field_ptr));\n  }\n  static field_type* FieldPtr4StructPtr(const struct_type* struct_ptr) {\n    return Y::FieldPtr4StructPtr(X::FieldPtr4StructPtr(struct_ptr));\n  }\n};\n\ntemplate<typename T>\nstruct ConstStruct {\n  using type = const T;\n};\ntemplate<typename T>\nstruct ConstStruct<const T> {\n  using type = const T;\n};\n\ntemplate<typename T>\nusing ConstType = typename ConstStruct<T>::type;\n\ntemplate<typename T>\nstruct ConstRefOrPtrStruct {\n  using type = ConstType<T>&;\n};\n\ntemplate<typename T>\nstruct ConstRefOrPtrStruct<T*> {\n  using type = ConstType<T>*;\n};\n\ntemplate<typename T>\nusing ConstRefOrPtr = typename ConstRefOrPtrStruct<T>::type;\n\n}  // namespace intrusive\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_INTRUSIVE_STRUCT_MACRO_TRAITS_H_\n"
  },
  {
    "path": "oneflow/core/intrusive/struct_traits_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/intrusive/struct_traits.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace test {\n\nnamespace {\n\nstruct OneflowTestNamespaceFoo {\n  OneflowTestNamespaceFoo() : x(0), bar(0), const_bar(0) {}\n\n  int x;\n  int bar;\n  const int const_bar;\n};\n\nTEST(StructField, mutable_struct_mutable_field) {\n  OneflowTestNamespaceFoo foo;\n  auto* bar = &foo.bar;\n  auto* struct_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, bar)::StructPtr4FieldPtr(bar);\n  auto* field_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, bar)::FieldPtr4StructPtr(&foo);\n  ASSERT_EQ(struct_ptr, &foo);\n  ASSERT_EQ(field_ptr, bar);\n}\n\nTEST(StructField, mutable_struct_const_field) {\n  OneflowTestNamespaceFoo foo;\n  auto* bar = &foo.const_bar;\n  auto* struct_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, const_bar)::StructPtr4FieldPtr(bar);\n  auto* field_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, const_bar)::FieldPtr4StructPtr(&foo);\n  ASSERT_EQ(struct_ptr, &foo);\n  ASSERT_EQ(field_ptr, bar);\n}\n\nTEST(StructField, const_struct_mutable_field) {\n  const OneflowTestNamespaceFoo foo;\n  auto* bar = &foo.bar;\n  auto* struct_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, bar)::StructPtr4FieldPtr(bar);\n  auto* field_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, bar)::FieldPtr4StructPtr(&foo);\n  ASSERT_EQ(struct_ptr, &foo);\n  ASSERT_EQ(field_ptr, bar);\n}\n\nTEST(StructField, const_struct_const_field) {\n  const OneflowTestNamespaceFoo foo;\n  auto* bar = &foo.const_bar;\n  auto* struct_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, const_bar)::StructPtr4FieldPtr(bar);\n  auto* field_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, const_bar)::FieldPtr4StructPtr(&foo);\n  ASSERT_EQ(struct_ptr, &foo);\n  ASSERT_EQ(field_ptr, bar);\n}\n\nstruct X {\n  int a;\n  int b;\n};\n\nstruct Y {\n  int c;\n  X d;\n};\n\nTEST(StructField, compose) {\n  using BFieldInY = intrusive::ComposeStructField<INTRUSIVE_FIELD(Y, d), INTRUSIVE_FIELD(X, b)>;\n  Y y{};\n  int* field_b = &y.d.b;\n  ASSERT_EQ(BFieldInY::FieldPtr4StructPtr(&y), field_b);\n  ASSERT_EQ(BFieldInY::StructPtr4FieldPtr(field_b), &y);\n}\n\n}  // namespace\n\n}  // namespace test\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ipc/shared_memory.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ipc/shared_memory.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/pcheck.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/env_var/env_var.h\"\n#ifdef __linux__\n#include <sys/types.h>\n#include <sys/mman.h>\n#include <sys/shm.h>\n#include <sys/stat.h>\n#include <fcntl.h>\n#include <error.h>\n#include <dirent.h>\n#endif\n\nnamespace oneflow {\nnamespace ipc {\n\nnamespace {\n\n#ifdef __linux__\n\n// return errno\nint ShmOpen(const std::string& shm_name, int* fd, bool create) {\n  SharedMemoryManager::get().AddShmName(shm_name);\n  *fd = shm_open((\"/\" + shm_name).c_str(), (create ? O_CREAT : 0) | O_RDWR | O_EXCL,\n                 S_IRUSR | S_IWUSR);\n  return *fd == -1 ? errno : 0;\n}\n\n// return errno\nint ShmOpen(std::string* shm_name, int* fd, bool create) {\n  int err = EEXIST;\n  while (true) {\n    static constexpr int kNameLength = 8;\n    *shm_name = std::string(\"ofshm_\") + GenAlphaNumericString(kNameLength);\n    err = ShmOpen(*shm_name, fd, create);\n    if (err != EEXIST) { return err; }\n  }\n  return err;\n}\n\nint ShmMap(int fd, const size_t shm_size, void** ptr) {\n  *ptr = mmap(NULL, shm_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);\n  return (*ptr == MAP_FAILED) ? 
errno : 0;\n}\n\n#endif\n\nMaybe<void*> ShmSetUp(std::string* shm_name, size_t shm_size, bool create) {\n#ifdef __linux__\n  int fd = 0;\n  PCHECK_OR_RETURN(ShmOpen(shm_name, &fd, create));\n  PCHECK_OR_RETURN(posix_fallocate(fd, 0, shm_size)) << ReturnEmptyStr([&] { close(fd); });\n  void* ptr = nullptr;\n  PCHECK_OR_RETURN(ShmMap(fd, shm_size, &ptr)) << ReturnEmptyStr([&] { close(fd); });\n  close(fd);\n  std::memset(ptr, 0, shm_size);\n  return ptr;\n#else\n  TODO_THEN_RETURN();\n#endif\n}\n\nMaybe<void*> ShmSetUp(const std::string& shm_name, size_t* shm_size, bool create) {\n#ifdef __linux__\n  int fd = 0;\n  PCHECK_OR_RETURN(ShmOpen(shm_name, &fd, create));\n  struct stat st;  // NOLINT\n  PCHECK_OR_RETURN(fstat(fd, &st)) << ReturnEmptyStr([&] { close(fd); });\n  *shm_size = st.st_size;\n  void* ptr = nullptr;\n  PCHECK_OR_RETURN(ShmMap(fd, *shm_size, &ptr)) << ReturnEmptyStr([&] { close(fd); });\n  close(fd);\n  return ptr;\n#else\n  TODO_THEN_RETURN();\n#endif\n}\n\nMaybe<std::set<std::string>> GetContentsOfShmDirectory() {\n#ifdef __linux__\n  std::set<std::string> contents;\n  DIR* dir = opendir(\"/dev/shm/\");\n  CHECK_NOTNULL_OR_RETURN(dir)\n      << \"/dev/shm directory does not exist, there may be a problem with your machine!\";\n  while (dirent* f = readdir(dir)) {\n    if (f->d_name[0] == '.') continue;\n    contents.insert(f->d_name);\n  }\n  closedir(dir);\n  return contents;\n#else\n  TODO_THEN_RETURN();\n#endif\n}\n}  // namespace\n\nSharedMemoryManager& SharedMemoryManager::get() {\n  // Must be a static singleton variable instead of Singleton<SharedMemoryManager>.\n  // Subprocesses don't have chance to call `Singleton<SharedMemoryManager>::Delete()`\n  static SharedMemoryManager shared_memory_manager;\n  return shared_memory_manager;\n}\n\nvoid SharedMemoryManager::FindAndDeleteOutdatedShmNames() {\n  std::unique_lock<std::recursive_mutex> lock(mutex_);\n  static size_t counter = 0;\n  const int delete_invalid_names_interval =\n      EnvInteger<ONEFLOW_DELETE_OUTDATED_SHM_NAMES_INTERVAL>();\n  if (counter % delete_invalid_names_interval == 0) {\n    const auto& existing_shm_names = CHECK_JUST(GetContentsOfShmDirectory());\n    // std::remove_if doesn't support std::map\n    for (auto it = shm_names_.begin(); it != shm_names_.end(); /* do nothing */) {\n      if (existing_shm_names->find(*it) == existing_shm_names->end()) {\n        it = shm_names_.erase(it);\n      } else {\n        it++;\n      }\n    }\n  }\n  counter++;\n}\n\nvoid SharedMemoryManager::AddShmName(const std::string& shm_name) {\n  FindAndDeleteOutdatedShmNames();\n  std::unique_lock<std::recursive_mutex> lock(mutex_);\n  shm_names_.insert(shm_name);\n}\n\nMaybe<void> SharedMemoryManager::DeleteShmName(const std::string& shm_name) {\n  std::unique_lock<std::recursive_mutex> lock(mutex_);\n  auto it = std::find(shm_names_.begin(), shm_names_.end(), shm_name);\n  if (it != shm_names_.end()) {\n    shm_names_.erase(it);\n  } else {\n    return Error::RuntimeError() << \"shared memory was not created but attempted to be freed.\";\n  }\n  return Maybe<void>::Ok();\n}\n\nvoid SharedMemoryManager::UnlinkAllShms() {\n#ifdef __linux__\n  // Here we deliberately do not handle unlink errors.\n  std::unique_lock<std::recursive_mutex> lock(mutex_);\n  for (const auto& shm : shm_names_) { shm_unlink(shm.c_str()); }\n  shm_names_.clear();\n#else\n  UNIMPLEMENTED();\n#endif\n}\n\nSharedMemoryManager::~SharedMemoryManager() { UnlinkAllShms(); }\n\nSharedMemory::~SharedMemory() { CHECK_JUST(Close()); 
}\n\nMaybe<SharedMemory> SharedMemory::Open(size_t shm_size, bool create) {\n  std::string shm_name;\n  char* ptr = static_cast<char*>(JUST(ShmSetUp(&shm_name, shm_size, create)));\n  return std::shared_ptr<SharedMemory>(new SharedMemory(ptr, shm_name, shm_size));\n}\n\nMaybe<SharedMemory> SharedMemory::Open(const std::string& shm_name, bool create) {\n  size_t shm_size = 0;\n  char* ptr = static_cast<char*>(JUST(ShmSetUp(shm_name, &shm_size, create)));\n  return std::shared_ptr<SharedMemory>(new SharedMemory(ptr, shm_name, shm_size));\n}\n\nMaybe<void> SharedMemory::Close() {\n#ifdef __linux__\n  if (buf_ != nullptr) {\n    PCHECK_OR_RETURN(munmap(buf_, size_));\n    buf_ = nullptr;\n  }\n  return Maybe<void>::Ok();\n#else\n  TODO_THEN_RETURN();\n#endif\n}\n\nMaybe<void> SharedMemory::Unlink() {\n#ifdef __linux__\n  PCHECK_OR_RETURN(shm_unlink(name_.c_str()));\n  JUST(SharedMemoryManager::get().DeleteShmName(name_));\n  return Maybe<void>::Ok();\n#else\n  TODO_THEN_RETURN();\n#endif\n}\n\n}  // namespace ipc\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ipc/shared_memory.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_IPC_SHARED_MEMORY_H_\n#define ONEFLOW_CORE_IPC_SHARED_MEMORY_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/singleton.h\"\n\nnamespace oneflow {\nnamespace ipc {\n\nclass SharedMemoryManager final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SharedMemoryManager);\n  ~SharedMemoryManager();\n  void AddShmName(const std::string& shm_name);\n  Maybe<void> DeleteShmName(const std::string& shm_name);\n\n  void UnlinkAllShms();\n\n  static SharedMemoryManager& get();\n\n private:\n  SharedMemoryManager() = default;\n  void FindAndDeleteOutdatedShmNames();\n  std::set<std::string> shm_names_;\n  std::recursive_mutex mutex_;\n};\n\nclass SharedMemory final {\n public:\n  SharedMemory(const SharedMemory&) = delete;\n  SharedMemory(SharedMemory&&) = delete;\n  ~SharedMemory();\n\n  static Maybe<SharedMemory> Open(size_t size, bool create);\n  static Maybe<SharedMemory> Open(const std::string& name, bool create);\n\n  const char* buf() const { return buf_; }\n  char* mut_buf() { return buf_; }\n\n  const std::string& name() const { return name_; }\n  size_t size() const { return size_; }\n\n  Maybe<void> Close();\n  Maybe<void> Unlink();\n\n private:\n  SharedMemory(char* buf, const std::string& name, size_t size)\n      : buf_(buf), name_(name), size_(size) {}\n\n  char* buf_;\n  std::string name_;\n  size_t size_;\n};\n\n}  // namespace ipc\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_IPC_SHARED_MEMORY_H_\n"
  },
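  {
    "path": "oneflow/core/ipc/shared_memory_usage_sketch.cpp",
    "content": "/*\nEditor's illustrative sketch (hypothetical file, not part of the original tree): a minimal\nround trip through the SharedMemory API, assuming the Maybe/JUST/CHECK_EQ_OR_RETURN\nerror-handling macros from oneflow/core/common/maybe.h. Open(size, create) creates a\nrandomly named segment; Open(name, create) maps an existing segment by name; Unlink\nremoves the name once no further opens are needed.\n*/\n#include <cstring>\n#include \"oneflow/core/ipc/shared_memory.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\nnamespace ipc_example {\n\nMaybe<void> RoundTrip() {\n  // Create a fresh 4 KiB segment; the returned object owns the mapping.\n  auto shm = JUST(ipc::SharedMemory::Open(4096, /*create=*/true));\n  std::memcpy(shm->mut_buf(), \"hello\", 6);\n  // A second mapping opened by name sees the same bytes.\n  auto view = JUST(ipc::SharedMemory::Open(shm->name(), /*create=*/false));\n  CHECK_EQ_OR_RETURN(std::strcmp(view->buf(), \"hello\"), 0);\n  // Remove the name so the segment disappears once all mappings are closed.\n  JUST(shm->Unlink());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace ipc_example\n}  // namespace oneflow\n"
  },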
  {
    "path": "oneflow/core/job/blob_lifetime_signature.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage BlobLastUsedSignature {\n  map<string, bool> bn_in_op2blob_last_used = 1;\n}\n\nmessage BlobBackwardUsedSignature {\n  map<string, bool> bn_in_op2blob_backward_used = 1;\n}\n"
  },
  {
    "path": "oneflow/core/job/checkpointing_config_def.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/config_def.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nREGISTER_SCOPE_CONFIG_DEF().Bool(\n    \"checkpointing\", false,\n    \"enable checkpointing op/tensor for backward recomputation to sublinear memory cost\");\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/cluster_instruction.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <mutex>\n#include \"oneflow/core/job/cluster_instruction.h\"\n#include \"oneflow/core/job/cluster_instruction.pb.h\"\n#include \"oneflow/core/control/ctrl_server.h\"\n#include \"oneflow/core/control/ctrl_client.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/env_desc.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::string GetHaltAckCtrlKey(int64_t machine_id) {\n  return \"HaltAckCtrlKey/\" + std::to_string(machine_id);\n}\n\n// return unique sequential key\n// because ctrl key is not allowed to push/pull twice\nstd::string GetClusterInstructionKey() {\n  static int64_t seq = 0;\n  return \"ClusterInstructionKey/\" + std::to_string(seq++);\n}\n\nclass ObsoleteCtrlKeys {\n public:\n  ObsoleteCtrlKeys() = default;\n  ~ObsoleteCtrlKeys() = default;\n\n  template<typename CallbackT>\n  void ForEach(const CallbackT& Callback) const {\n    std::unique_lock<std::mutex> lck(mutex_);\n    for (const std::string& k : keys_) { Callback(k); }\n  }\n\n  void Clear() {\n    std::unique_lock<std::mutex> lck(mutex_);\n    keys_.clear();\n  }\n  void Add(const std::string& key) {\n    std::unique_lock<std::mutex> lck(mutex_);\n    keys_.emplace_back(key);\n  }\n\n private:\n  mutable std::mutex mutex_;\n  std::vector<std::string> keys_;\n};\n\nCOMMAND(Singleton<ObsoleteCtrlKeys>::SetAllocated(new ObsoleteCtrlKeys()));\n\nvoid OccasionallyClearCtrlKV(const std::string& key) {\n  static std::atomic<int64_t> seq(0LL);\n  const static int64_t interval = 65536;\n  Singleton<ObsoleteCtrlKeys>::Get()->Add(key);\n  // 1 instead of 0 is better for avoid clearing no ctrl kv\n  if ((seq++) % interval == 1) {\n    OF_ENV_BARRIER();\n    if (GlobalProcessCtx::IsThisProcessMaster()) {\n      Singleton<ObsoleteCtrlKeys>::Get()->ForEach(\n          [](const std::string& k) { Singleton<CtrlClient>::Get()->ClearMasterKV(k); });\n    }\n    Singleton<ObsoleteCtrlKeys>::Get()->Clear();\n    OF_ENV_BARRIER();\n  }\n}\n\nvoid PushClusterInstruction(const ClusterInstructionProto& cluster_instruction) {\n  const std::string& key = GetClusterInstructionKey();\n  Singleton<CtrlClient>::Get()->PushMasterKV(key, cluster_instruction);\n  OccasionallyClearCtrlKV(key);\n}\n\nvoid PullClusterInstruction(ClusterInstructionProto* cluster_instruction) {\n  const std::string& key = GetClusterInstructionKey();\n  Singleton<CtrlClient>::Get()->PullMasterKV(key, cluster_instruction);\n  OccasionallyClearCtrlKV(key);\n}\n\n}  // namespace\n\nvoid ClusterInstruction::NewSessionBarrier() {\n  OF_ENV_BARRIER();\n  Singleton<CtrlClient>::Get()->Clear();\n  Singleton<ObsoleteCtrlKeys>::Get()->Clear();\n  OF_ENV_BARRIER();\n}\n\nvoid ClusterInstruction::MasterSendSessionStart() {\n  ClusterInstructionProto cluster_instruction;\n  cluster_instruction.mutable_cluster_ctrl_session_start();\n  PushClusterInstruction(cluster_instruction);\n  NewSessionBarrier();\n}\n\nvoid 
ClusterInstruction::MasterSendHalt() {\n  ClusterInstructionProto cluster_instruction;\n  cluster_instruction.mutable_cluster_ctrl_halt();\n  PushClusterInstruction(cluster_instruction);\n  HaltBarrier();\n}\n\nvoid ClusterInstruction::MasterSendAbort() {\n  LOG(INFO) << \"Sending abort instruction.\";\n  ClusterInstructionProto cluster_instruction;\n  cluster_instruction.mutable_cluster_ctrl_abort();\n  PushClusterInstruction(cluster_instruction);\n}\n\nvoid ClusterInstruction::WorkerReceiveInstruction(ClusterInstructionProto* cluster_instruction) {\n  PullClusterInstruction(cluster_instruction);\n}\n\nvoid ClusterInstruction::HaltBarrier() { OF_ENV_BARRIER(); }\n\nvoid ClusterInstruction::EagerSyncBarrier() {\n  // TODO(jianhao): update here after eager instructions are run asynchronously\n  OF_ENV_BARRIER();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/cluster_instruction.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_CLUSTER_CONTROL_H_\n#define ONEFLOW_CORE_JOB_CLUSTER_CONTROL_H_\n\n#include \"oneflow/core/job/cluster_instruction.pb.h\"\n\nnamespace oneflow {\n\nstruct ClusterInstruction final {\n  static void MasterSendSessionStart();\n  static void MasterSendHalt();\n  static void MasterSendAbort();\n  static void MasterSendEagerSync();\n  static void WorkerReceiveInstruction(ClusterInstructionProto* cluster_instruction);\n  static void NewSessionBarrier();\n  static void HaltBarrier();\n  static void EagerSyncBarrier();\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_CLUSTER_CONTROL_H_\n"
  },
  {
    "path": "oneflow/core/job/cluster_instruction.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage ClusterCtrlSessionStart {}\nmessage ClusterCtrlHalt {}\nmessage ClusterCtrlAbort {}\n\nmessage ClusterInstructionProto {\n  oneof instruction_type {\n    ClusterCtrlSessionStart cluster_ctrl_session_start = 1;\n    ClusterCtrlHalt cluster_ctrl_halt = 2; // normal exit\n    ClusterCtrlAbort cluster_ctrl_abort = 5; // error exit\n  }\n}\n"
  },
  {
    "path": "oneflow/core/job/collective_boxing/coordinator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_COORDINATOR_H_\n#define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_COORDINATOR_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace boxing {\n\nnamespace collective {\n\nclass RequestStore;\nclass Executor;\nstruct RequestId;\n\nclass Coordinator {\n public:\n  Coordinator() = default;\n  virtual ~Coordinator() = default;\n\n  virtual void Init(std::shared_ptr<RequestStore> request_store,\n                    std::shared_ptr<Executor> executor) = 0;\n  virtual void InitJob(int64_t job_id) = 0;\n  virtual void DeinitJob(int64_t job_id) = 0;\n  virtual void AddRequest(void* coordinator_token) = 0;\n  virtual void* CreateCoordinatorToken(const RequestId& request_id) = 0;\n  virtual void DestroyCoordinatorToken(void* coordinator_token) = 0;\n};\n\n}  // namespace collective\n\n}  // namespace boxing\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_COORDINATOR_H_\n"
  },
  {
    "path": "oneflow/core/job/collective_boxing/executor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/collective_boxing/executor.h\"\n\nnamespace oneflow {\n\nnamespace boxing {\n\nnamespace collective {\n\nvoid Executor::ExecuteRequests(const std::vector<RequestId>& request_ids) {\n  GroupRequests(request_ids, [&](std::vector<RequestId>&& group, GroupToken* group_token) {\n    ExecuteGroup(group_token);\n  });\n}\n\n}  // namespace collective\n\n}  // namespace boxing\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/collective_boxing/executor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_H_\n#define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace boxing {\n\nnamespace collective {\n\nclass RequestStore;\n\nstruct RequestId;\n\nclass GroupToken;\n\nclass Executor {\n public:\n  Executor() = default;\n  virtual ~Executor() = default;\n\n  virtual void Init(std::shared_ptr<RequestStore> request_store) = 0;\n  virtual void InitJob(int64_t job_id) = 0;\n  virtual void DeinitJob(int64_t job_id) = 0;\n  virtual void GroupRequests(\n      const std::vector<RequestId>& request_ids,\n      const std::function<void(std::vector<RequestId>&&, GroupToken*)>& Handler) = 0;\n  virtual void ExecuteGroup(GroupToken* group_token) = 0;\n  virtual void DestroyGroupToken(GroupToken* group_token) = 0;\n  virtual void ExecuteRequests(const std::vector<RequestId>& request_ids);\n};\n\n}  // namespace collective\n\n}  // namespace boxing\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_H_\n"
  },
  {
    "path": "oneflow/core/job/collective_boxing/executor_backend.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_BACKEND_H_\n#define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_BACKEND_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace boxing {\n\nnamespace collective {\n\nclass RequestStore;\n\nstruct RequestId;\n\nclass ExecutorBackend {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ExecutorBackend);\n  ExecutorBackend() = default;\n  virtual ~ExecutorBackend() = default;\n\n  virtual void Init(std::shared_ptr<RequestStore> request_store) = 0;\n  virtual void InitJob(int64_t job_id) = 0;\n  virtual void DeinitJob(int64_t job_id) = 0;\n  virtual void GroupRequests(\n      const std::vector<RequestId>& request_ids,\n      const std::function<void(std::vector<RequestId>&&, void*)>& Handler) = 0;\n  virtual void ExecuteGroup(void* group_token) = 0;\n  virtual void* CreateGroupToken(const std::vector<RequestId>& group) = 0;\n  virtual void DestroyGroupToken(void* group_token) = 0;\n};\n\n}  // namespace collective\n\n}  // namespace boxing\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_BACKEND_H_\n"
  },
  {
    "path": "oneflow/core/job/collective_boxing/executor_backend_manager.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/job/collective_boxing/executor_backend_manager.h\"\n\nnamespace oneflow {\n\nnamespace boxing {\n\nnamespace collective {\n\nExecutorBackendMgr& ExecutorBackendMgr::Get() {\n  static ExecutorBackendMgr mgr;\n  return mgr;\n}\n\n}  // namespace collective\n\n}  // namespace boxing\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/collective_boxing/executor_backend_manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_BACKEND_MANAGER_H_\n#define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_BACKEND_MANAGER_H_\n\n#include \"oneflow/core/job/collective_boxing/executor_backend.h\"\n#include \"oneflow/core/common/device_type.h\"\n\nnamespace oneflow {\n\nnamespace boxing {\n\nnamespace collective {\n\nclass ExecutorBackendMgr {\n public:\n  using Creator = std::function<std::unique_ptr<ExecutorBackend>()>;\n\n  ExecutorBackendMgr(ExecutorBackendMgr const&) = delete;\n  ExecutorBackendMgr& operator=(ExecutorBackendMgr const&) = delete;\n  static ExecutorBackendMgr& Get();\n\n  template<typename Derived>\n  void RegisterExecutorBackendType(DeviceType device_type) {\n    executor_backend_reg_result_.emplace(device_type, []() -> std::unique_ptr<ExecutorBackend> {\n      return std::make_unique<Derived>();\n    });\n    vaild_executor_device_types_.emplace_back(device_type);\n  }\n\n  std::unique_ptr<ExecutorBackend> NewExecutorBackend(DeviceType device_type) const {\n    const auto& it = executor_backend_reg_result_.find(device_type);\n    CHECK(it != executor_backend_reg_result_.end());\n    return it->second();\n  }\n\n  const std::vector<DeviceType>& vaild_executor_device_types() const {\n    return vaild_executor_device_types_;\n  }\n\n private:\n  ExecutorBackendMgr() = default;\n\n  HashMap<DeviceType, Creator> executor_backend_reg_result_;\n  std::vector<DeviceType> vaild_executor_device_types_;\n};\n\n#define REGISTER_EXECUTOR_BACKEND(device, Derived) \\\n  COMMAND(ExecutorBackendMgr::Get().RegisterExecutorBackendType<Derived>(device))\n\n}  // namespace collective\n\n}  // namespace boxing\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_EXECUTOR_BACKEND_MANAGER_H_\n"
  },
  {
    "path": "oneflow/core/job/collective_boxing/nccl_executor_backend.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/collective_boxing/executor_backend_manager.h\"\n#include \"oneflow/core/job/collective_boxing/request_store.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/graph/boxing/collective_boxing_util.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/control/ctrl_client.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/thread/thread_pool.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n\n#include <nccl.h>\n\n#include <memory>\n#include <utility>\n\nnamespace oneflow {\n\nnamespace boxing {\n\nnamespace collective {\n\nnamespace {\n\nncclRedOp_t GetNcclReduceOp(ReduceMethod reduce_method) {\n  if (reduce_method == kReduceMethodSum) {\n    return ncclRedOp_t::ncclSum;\n  } else {\n    UNIMPLEMENTED();\n    return ncclRedOp_t{};\n  }\n}\n\nstd::string GetNcclUniqueIdRpcKey(const std::string& name, int64_t stream_id) {\n  return \"CollectiveBoxingExecutorNcclUniqueIdRpcKey-\" + name + \"-\" + std::to_string(stream_id);\n}\n\nstruct CopyParams {\n  void* dst;\n  const void* src;\n  int64_t count;\n};\n\nconstexpr int64_t kMultiCopyParamsMaxSize = 128;\nconstexpr int64_t kMultiCopyAlignSize = 32;\n\nint64_t GetMultiCopyAlignedSize(int64_t size) {\n  return ((size + kMultiCopyAlignSize - 1) / kMultiCopyAlignSize) * kMultiCopyAlignSize;\n}\n\nstruct MultiCopyParams {\n  CopyParams params[kMultiCopyParamsMaxSize];\n  int64_t count;\n\n  MultiCopyParams() : count(0), params{} {}\n\n  void Add(void* dst, const void* src, int64_t count) {\n    CHECK_LT(this->count, kMultiCopyParamsMaxSize);\n    params[this->count].dst = dst;\n    params[this->count].src = src;\n    params[this->count].count = count;\n    this->count += 1;\n  }\n};\n\nusing BulkType = ulonglong2;\n\n__global__ void MultiCopyGpu(MultiCopyParams multi_params) {\n  for (int64_t p = 0; p < multi_params.count; ++p) {\n    const CopyParams params = multi_params.params[p];\n    auto* bulk_dst = reinterpret_cast<BulkType*>(params.dst);\n    const auto* bulk_src = reinterpret_cast<const BulkType*>(params.src);\n    const int64_t bulk_count = params.count / sizeof(BulkType);\n    CUDA_1D_KERNEL_LOOP_T(int64_t, i, bulk_count) { bulk_dst[i] = bulk_src[i]; }\n    const int64_t tail_offset = bulk_count * sizeof(BulkType);\n    auto* tail_dst = reinterpret_cast<char*>(params.dst) + tail_offset;\n    const auto* tail_src = reinterpret_cast<const char*>(params.src) + tail_offset;\n    const int64_t tail_count = params.count - tail_offset;\n    CUDA_1D_KERNEL_LOOP_T(int64_t, i, tail_count) { tail_dst[i] = tail_src[i]; }\n  }\n}\n\nvoid MultiCopy(cudaStream_t stream, const MultiCopyParams& multi_params) {\n  if (multi_params.count <= 0) { return; }\n  CHECK_LE(multi_params.count, kMultiCopyParamsMaxSize);\n  int64_t max_count = multi_params.params[0].count;\n  for (int64_t i = 0; i < 
multi_params.count; ++i) {\n    max_count = std::max(max_count, multi_params.params[i].count);\n  }\n  MultiCopyGpu<<<BlocksNum4ThreadsNum(max_count), kCudaThreadsNumPerBlock, 0, stream>>>(\n      multi_params);\n}\n\nclass CommRank final {\n public:\n  OF_DISALLOW_COPY(CommRank);\n  CommRank(int32_t device_id, int32_t global_rank, int32_t global_rank_count, int32_t local_rank,\n           int32_t local_rank_count)\n      : device_id_(device_id),\n        global_rank_(global_rank),\n        local_rank_(local_rank),\n        nccl_comm_(nullptr) {}\n\n  CommRank(CommRank&& rhs) noexcept {\n    this->device_id_ = rhs.device_id_;\n    this->global_rank_ = rhs.global_rank_;\n    this->local_rank_ = rhs.local_rank_;\n    this->nccl_comm_ = rhs.nccl_comm_;\n    rhs.nccl_comm_ = nullptr;\n  }\n\n  ~CommRank() {\n    if (nccl_comm_ != nullptr) {\n      CudaCurrentDeviceGuard guard(device_id_);\n      OF_NCCL_CHECK(ncclCommDestroy(nccl_comm_));\n    }\n  }\n\n  int32_t device_id() const { return device_id_; }\n\n  ncclComm_t nccl_comm() const { return nccl_comm_; }\n\n  void InitRank(ncclUniqueId unique_id, int32_t global_rank_count) {\n    CudaCurrentDeviceGuard guard(device_id_);\n    OF_NCCL_CHECK(ncclCommInitRank(&nccl_comm_, global_rank_count, unique_id, global_rank_));\n  }\n\n private:\n  int32_t device_id_;\n  int32_t global_rank_;\n  int32_t local_rank_;\n  ncclComm_t nccl_comm_;\n};\n\nclass CommGroup final {\n public:\n  OF_DISALLOW_COPY(CommGroup);\n  CommGroup() = default;\n  ~CommGroup() = default;\n  CommGroup(CommGroup&& rhs) noexcept {\n    rank_vec_.swap(rhs.rank_vec_);\n    global_rank_count_ = rhs.global_rank_count_;\n  }\n\n  void InitGroup(const DeviceSet& device_set, const std::string& unique_name) {\n    CudaCurrentDeviceGuard guard;\n    const int64_t this_machine_id = GlobalProcessCtx::Rank();\n    global_rank_count_ = device_set.device_size();\n    std::vector<int32_t> local_ranks;\n    for (int32_t i = 0; i < global_rank_count_; ++i) {\n      if (device_set.device(i).machine_id() == this_machine_id) { local_ranks.emplace_back(i); }\n    }\n    const int32_t local_rank_count = local_ranks.size();\n    CHECK_GT(local_rank_count, 0);\n    ncclUniqueId nccl_unique_id{};\n    if (local_ranks.front() == 0) {\n      OF_NCCL_CHECK(ncclGetUniqueId(&nccl_unique_id));\n      if (local_rank_count != global_rank_count_) {\n        Singleton<CtrlClient>::Get()->PushKV(unique_name, NcclUniqueIdToString(nccl_unique_id));\n      }\n    } else {\n      Singleton<CtrlClient>::Get()->PullKV(unique_name, [&nccl_unique_id](const std::string& val) {\n        NcclUniqueIdFromString(val, &nccl_unique_id);\n      });\n    }\n    rank_vec_.reserve(local_rank_count);\n    OF_NCCL_CHECK(ncclGroupStart());\n    for (int32_t local_rank = 0; local_rank < local_ranks.size(); ++local_rank) {\n      const int32_t global_rank = local_ranks.at(local_rank);\n      const int32_t device_id = device_set.device(global_rank).device_id();\n      OF_CUDA_CHECK(cudaSetDevice(device_id));\n      rank_vec_.emplace_back(device_id, global_rank, global_rank_count_, local_rank,\n                             local_rank_count);\n      rank_vec_.at(local_rank).InitRank(nccl_unique_id, global_rank_count_);\n    }\n    OF_NCCL_CHECK(ncclGroupEnd());\n  }\n\n  int32_t global_rank_count() const { return global_rank_count_; }\n\n  int32_t local_rank_count() const { return rank_vec_.size(); }\n\n  const CommRank& GetCommRank(int32_t local_rank) const { return rank_vec_.at(local_rank); }\n\n private:\n  std::vector<CommRank> 
rank_vec_;\n  int32_t global_rank_count_ = 0;\n};\n\nclass StreamCtx {\n public:\n  OF_DISALLOW_COPY(StreamCtx);\n  StreamCtx(int32_t device_id, size_t fusion_buffer_size)\n      : device_id_(device_id), fusion_buffer_size_(fusion_buffer_size) {\n    CudaCurrentDeviceGuard guard(device_id_);\n    int priority;\n    OF_CUDA_CHECK(cudaDeviceGetStreamPriorityRange(nullptr, &priority));\n    OF_CUDA_CHECK(cudaStreamCreateWithPriority(&stream_, cudaStreamNonBlocking, priority));\n    OF_CUDA_CHECK(cudaMalloc(&fusion_buffer_, fusion_buffer_size_));\n    cb_event_poller_ = std::thread(&StreamCtx::PollEvent, this);\n  }\n  ~StreamCtx() {\n    cb_event_chan_.Close();\n    cb_event_poller_.join();\n    CudaCurrentDeviceGuard guard(device_id_);\n    OF_CUDA_CHECK(cudaStreamSynchronize(stream_));\n    OF_CUDA_CHECK(cudaStreamDestroy(stream_));\n    OF_CUDA_CHECK(cudaFree(fusion_buffer_));\n  }\n\n  void PollEvent() {\n    CudaCurrentDeviceGuard guard(device_id_);\n    while (true) {\n      std::pair<cudaEvent_t, std::function<void()>> cb_event;\n      ChannelStatus status = cb_event_chan_.Receive(&cb_event);\n      if (status == kChannelStatusErrorClosed) { break; }\n      CHECK_EQ(status, kChannelStatusSuccess);\n      OF_CUDA_CHECK(cudaEventSynchronize(cb_event.first));\n      cb_event.second();\n      OF_CUDA_CHECK(cudaEventDestroy(cb_event.first));\n    }\n  }\n\n  void AddCallback(const std::function<void()>& callback) {\n    cudaEvent_t event;\n    OF_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));\n    OF_CUDA_CHECK(cudaEventRecord(event, stream_));\n    CHECK_EQ(cb_event_chan_.Send(std::make_pair(event, callback)), kChannelStatusSuccess);\n  }\n\n  int32_t device_id() const { return device_id_; }\n\n  cudaStream_t stream() const { return stream_; }\n\n  size_t fusion_buffer_size() const { return fusion_buffer_size_; }\n\n  char* fusion_buffer() const { return fusion_buffer_; }\n\n private:\n  int32_t device_id_;\n  cudaStream_t stream_ = nullptr;\n  size_t fusion_buffer_size_;\n  char* fusion_buffer_ = nullptr;\n  Channel<std::pair<cudaEvent_t, std::function<void()>>> cb_event_chan_;\n  std::thread cb_event_poller_;\n};\n\nvoid LaunchFusedAllReduce(const CommGroup& comm_group,\n                          const std::vector<std::unique_ptr<StreamCtx>>& device_id2stream_ctx,\n                          const std::shared_ptr<RequestStore>& request_store,\n                          const std::vector<RequestId>& request_ids) {\n  CHECK_LE(request_ids.size(), kMultiCopyParamsMaxSize);\n  RequestEntry* first_request_entry = request_store->MutRequestEntry(request_ids.front());\n  const ncclDataType_t nccl_data_type =\n      GetNcclDataType(first_request_entry->desc().op_desc().data_type());\n  const ncclRedOp_t nccl_reduce_op =\n      GetNcclReduceOp(first_request_entry->desc().op_desc().reduce_method());\n  const int64_t size_of_data_type =\n      GetSizeOfDataType(first_request_entry->desc().op_desc().data_type());\n  std::vector<int64_t> offset_vec;\n  offset_vec.reserve(request_ids.size());\n  int64_t offset = 0;\n  request_store->ForEachMutRequestEntryForIdsInJob(\n      request_ids, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) {\n        offset_vec.emplace_back(offset);\n        offset += GetMultiCopyAlignedSize(request_entry->size_in_bytes());\n      });\n  const int64_t elem_cnt = offset / size_of_data_type;\n  for (int32_t local_rank = 0; local_rank < comm_group.local_rank_count(); ++local_rank) {\n    MultiCopyParams copy_in_params;\n    const 
CommRank& comm_rank = comm_group.GetCommRank(local_rank);\n    const StreamCtx* stream_ctx = device_id2stream_ctx.at(comm_rank.device_id()).get();\n    CHECK_LE(offset, stream_ctx->fusion_buffer_size());\n    request_store->ForEachMutRequestEntryForIdsInJob(\n        request_ids, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) {\n          copy_in_params.Add(stream_ctx->fusion_buffer() + offset_vec.at(i),\n                             request_entry->GetRuntimeRequest(local_rank)->send_buff,\n                             request_entry->size_in_bytes());\n        });\n    OF_CUDA_CHECK(cudaSetDevice(comm_rank.device_id()));\n    MultiCopy(stream_ctx->stream(), copy_in_params);\n  }\n\n  OF_NCCL_CHECK(ncclGroupStart());\n  for (int32_t local_rank = 0; local_rank < comm_group.local_rank_count(); ++local_rank) {\n    const CommRank& comm_rank = comm_group.GetCommRank(local_rank);\n    const StreamCtx* stream_ctx = device_id2stream_ctx.at(comm_rank.device_id()).get();\n    OF_CUDA_CHECK(cudaSetDevice(comm_rank.device_id()));\n    OF_NCCL_CHECK(ncclAllReduce(stream_ctx->fusion_buffer(), stream_ctx->fusion_buffer(), elem_cnt,\n                                nccl_data_type, nccl_reduce_op, comm_rank.nccl_comm(),\n                                stream_ctx->stream()));\n  }\n  OF_NCCL_CHECK(ncclGroupEnd());\n\n  for (int32_t local_rank = 0; local_rank < comm_group.local_rank_count(); ++local_rank) {\n    MultiCopyParams copy_out_params;\n    const CommRank& comm_rank = comm_group.GetCommRank(local_rank);\n    const StreamCtx* stream_ctx = device_id2stream_ctx.at(comm_rank.device_id()).get();\n    request_store->ForEachMutRequestEntryForIdsInJob(\n        request_ids, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) {\n          copy_out_params.Add(request_entry->GetRuntimeRequest(local_rank)->recv_buff,\n                              stream_ctx->fusion_buffer() + offset_vec.at(i),\n                              request_entry->size_in_bytes());\n        });\n    OF_CUDA_CHECK(cudaSetDevice(comm_rank.device_id()));\n    MultiCopy(stream_ctx->stream(), copy_out_params);\n  }\n}\n\nvoid LaunchAggregatedOps(const CommGroup& comm_group,\n                         const std::vector<std::unique_ptr<StreamCtx>>& device_id2stream_ctx,\n                         const std::shared_ptr<RequestStore>& request_store,\n                         const std::vector<RequestId>& request_ids) {\n  OF_NCCL_CHECK(ncclGroupStart());\n  for (int32_t local_rank = 0; local_rank < comm_group.local_rank_count(); ++local_rank) {\n    const CommRank& comm_rank = comm_group.GetCommRank(local_rank);\n    const auto comm = comm_rank.nccl_comm();\n    const StreamCtx* stream_ctx = device_id2stream_ctx.at(comm_rank.device_id()).get();\n    OF_CUDA_CHECK(cudaSetDevice(comm_rank.device_id()));\n    request_store->ForEachMutRequestEntryForIdsInJob(\n        request_ids, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) {\n          const auto& op_desc = request_entry->desc().op_desc();\n          const std::shared_ptr<const RuntimeRequestInfo>& runtime_request_info =\n              request_entry->GetRuntimeRequest(local_rank);\n          const OpType op_type = op_desc.op_type();\n          const void* send_buff = runtime_request_info->send_buff;\n          void* recv_buff = runtime_request_info->recv_buff;\n          const int64_t elem_cnt = request_entry->elem_cnt();\n          const ncclDataType_t nccl_data_type = GetNcclDataType(op_desc.data_type());\n          const 
int32_t num_ranks = comm_group.global_rank_count();\n          if (op_type == OpType::kOpTypeAllReduce) {\n            OF_NCCL_CHECK(ncclAllReduce(send_buff, recv_buff, elem_cnt, nccl_data_type,\n                                        GetNcclReduceOp(op_desc.reduce_method()), comm,\n                                        stream_ctx->stream()));\n          } else if (op_type == OpType::kOpTypeAllGather) {\n            CHECK_EQ(elem_cnt % num_ranks, 0);\n            OF_NCCL_CHECK(ncclAllGather(send_buff, recv_buff, elem_cnt / num_ranks, nccl_data_type,\n                                        comm, stream_ctx->stream()));\n          } else if (op_type == OpType::kOpTypeReduceScatter) {\n            CHECK_EQ(elem_cnt % num_ranks, 0);\n            OF_NCCL_CHECK(ncclReduceScatter(\n                send_buff, recv_buff, elem_cnt / num_ranks, nccl_data_type,\n                GetNcclReduceOp(op_desc.reduce_method()), comm, stream_ctx->stream()));\n          } else if (op_type == OpType::kOpTypeReduce) {\n            OF_NCCL_CHECK(ncclReduce(send_buff, recv_buff, elem_cnt, nccl_data_type,\n                                     GetNcclReduceOp(op_desc.reduce_method()), op_desc.root(), comm,\n                                     stream_ctx->stream()));\n          } else if (op_type == OpType::kOpTypeBroadcast) {\n            OF_NCCL_CHECK(ncclBroadcast(send_buff, recv_buff, elem_cnt, nccl_data_type,\n                                        op_desc.root(), comm, stream_ctx->stream()));\n          } else if (op_type == OpType::kOpTypeAll2All) {\n#if NCCL_VERSION_CODE > 2700\n            const int64_t elem_per_rank = elem_cnt / num_ranks;\n            const int64_t elem_per_chunk = elem_per_rank / num_ranks;\n            const int64_t dtype_size = GetSizeOfDataType(op_desc.data_type());\n            const int64_t chunk_size = elem_per_chunk * dtype_size;\n            for (int64_t j = 0; j < num_ranks; ++j) {\n              OF_NCCL_CHECK(ncclSend(reinterpret_cast<const void*>(\n                                         reinterpret_cast<const char*>(send_buff) + j * chunk_size),\n                                     elem_per_chunk, nccl_data_type, j, comm,\n                                     stream_ctx->stream()));\n              OF_NCCL_CHECK(ncclRecv(\n                  reinterpret_cast<void*>(reinterpret_cast<char*>(recv_buff) + j * chunk_size),\n                  elem_per_chunk, nccl_data_type, j, comm, stream_ctx->stream()));\n            }\n#else\n        UNIMPLEMENTED();\n#endif\n          } else {\n            UNIMPLEMENTED();\n          }\n        });\n  }\n  OF_NCCL_CHECK(ncclGroupEnd());\n}\n\nvoid AddCallbackAndResetRuntimeRequest(\n    const CommGroup& comm_group,\n    const std::vector<std::unique_ptr<StreamCtx>>& device_id2stream_ctx,\n    const std::shared_ptr<RequestStore>& request_store, const std::vector<RequestId>& request_ids) {\n  std::vector<std::vector<std::shared_ptr<const RuntimeRequestInfo>>> saved_runtime_request_info(\n      request_ids.size());\n  request_store->ForEachMutRequestEntryForIdsInJob(\n      request_ids, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) {\n        saved_runtime_request_info.at(i) = std::move(request_entry->ResetRuntimeRequest());\n      });\n  for (int32_t local_rank = 0; local_rank < comm_group.local_rank_count(); ++local_rank) {\n    const CommRank& comm_rank = comm_group.GetCommRank(local_rank);\n    StreamCtx* stream_ctx = device_id2stream_ctx.at(comm_rank.device_id()).get();\n    auto runtime_request_info_vec =\n      
  std::make_shared<std::vector<std::shared_ptr<const RuntimeRequestInfo>>>();\n    runtime_request_info_vec->reserve(request_ids.size());\n    request_store->ForEachMutRequestEntryForIdsInJob(\n        request_ids, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) {\n          runtime_request_info_vec->emplace_back(\n              std::move(saved_runtime_request_info.at(i).at(local_rank)));\n        });\n    OF_CUDA_CHECK(cudaSetDevice(comm_rank.device_id()));\n    stream_ctx->AddCallback([runtime_request_info_vec]() {\n      for (auto& runtime_request_info : *runtime_request_info_vec) {\n        runtime_request_info->callback(Maybe<void>::Ok());\n      }\n    });\n  }\n}\n\n}  // namespace\n\nclass NcclExecutorBackend : public ExecutorBackend {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NcclExecutorBackend);\n  NcclExecutorBackend();\n  ~NcclExecutorBackend() override;\n\n private:\n  void Init(std::shared_ptr<RequestStore> request_store) override;\n  void InitJob(int64_t job_id) override;\n  void DeinitJob(int64_t job_id) override;\n  void GroupRequests(const std::vector<RequestId>& request_ids,\n                     const std::function<void(std::vector<RequestId>&&, void*)>& Handler) override;\n  void ExecuteGroup(void* group_token) override;\n  void* CreateGroupToken(const std::vector<RequestId>& group) override;\n  void DestroyGroupToken(void* group_token) override;\n\n  struct Impl;\n  std::unique_ptr<Impl> impl_;\n};\n\nstruct NcclExecutorBackend::Impl {\n  Impl(const CollectiveBoxingConf& conf, std::shared_ptr<RequestStore> request_store)\n      : conf(conf), request_store(std::move(request_store)) {\n    CHECK_GT(conf.nccl_num_streams(), 0);\n    CHECK_GE(conf.nccl_fusion_threshold_mb(), 0);\n    fusion_threshold = conf.nccl_fusion_threshold_mb() * 1024 * 1024;\n    num_streams = conf.nccl_num_streams();\n    current_stream_id = 0;\n    enable_mixed_fusion =\n        (!conf.nccl_fusion_all_reduce_use_buffer()) && conf.nccl_enable_mixed_fusion();\n    int nccl_version;\n    OF_NCCL_CHECK(ncclGetVersion(&nccl_version));\n    if (nccl_version == 21003) {\n      LOG(WARNING)\n          << \"Current nccl version is 2.10.3, in this version, ncclGroup() with mixed \"\n             \"datatype/element/collective could induce crash or corruption, so we will not \"\n             \"fuse any request.\";\n    }\n    InitStreamCtx();\n    InitIsOpTypeFusionEnabled();\n  }\n  ~Impl() {\n    stream_id2device_id2stream_ctx.clear();\n    device_set2stream_id2comm_group.clear();\n  }\n\n  void InitCommGroup(int64_t job_id) {\n    std::set<int64_t> local_device_ids;\n    request_store->ForEachMutRequestEntryInJob(\n        job_id, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) {\n          const auto& request = request_entry->desc();\n          if (request.op_desc().device_type() != DeviceType::kCUDA) { return; }\n          if (!request_entry->HasRankOnThisNode()) { return; }\n          const DeviceSet& device_set = request.device_set();\n          if (device_set2stream_id2comm_group.count(device_set) > 0) { return; }\n          auto& stream_id2comm_group = device_set2stream_id2comm_group[device_set];\n          stream_id2comm_group.resize(num_streams);\n          for (int32_t stream_id = 0; stream_id < num_streams; ++stream_id) {\n            stream_id2comm_group.at(stream_id).InitGroup(\n                device_set, GetNcclUniqueIdRpcKey(request.op_desc().name(), stream_id));\n          }\n          for (int32_t j = 0; j < 
stream_id2comm_group.at(0).local_rank_count(); ++j) {\n            local_device_ids.emplace(stream_id2comm_group.at(0).GetCommRank(j).device_id());\n          }\n        });\n    for (int32_t stream_id = 0; stream_id < num_streams; ++stream_id) {\n      for (const int64_t device_id : local_device_ids) {\n        if (stream_id2device_id2stream_ctx.at(stream_id).at(device_id) == nullptr) {\n          stream_id2device_id2stream_ctx.at(stream_id).at(device_id) =\n              std::make_unique<StreamCtx>(device_id, fusion_threshold);\n        }\n      }\n    }\n  }\n\n  void InitStreamCtx() {\n    int32_t num_devices;\n    OF_CUDA_CHECK(cudaGetDeviceCount(&num_devices));\n    stream_id2device_id2stream_ctx.resize(num_streams);\n    for (int64_t stream_id = 0; stream_id < num_streams; ++stream_id) {\n      stream_id2device_id2stream_ctx.at(stream_id).resize(num_devices);\n    }\n  }\n\n  void InitIsOpTypeFusionEnabled() {\n    op_type2fusion_enabled.resize(OpType_ARRAYSIZE, false);\n    op_type2fusion_enabled.at(OpType::kOpTypeAllReduce) = conf.nccl_fusion_all_reduce();\n    op_type2fusion_enabled.at(OpType::kOpTypeAllGather) = conf.nccl_fusion_all_gather();\n    op_type2fusion_enabled.at(OpType::kOpTypeReduceScatter) = conf.nccl_fusion_reduce_scatter();\n    op_type2fusion_enabled.at(OpType::kOpTypeReduce) = conf.nccl_fusion_reduce();\n    op_type2fusion_enabled.at(OpType::kOpTypeBroadcast) = conf.nccl_fusion_broadcast();\n    op_type2fusion_enabled.at(OpType::kOpTypeAll2All) = false;\n  }\n\n  int32_t NextStreamId() {\n    const int32_t stream_id = current_stream_id;\n    current_stream_id = (current_stream_id + 1) % num_streams;\n    return stream_id;\n  }\n\n  bool IsOpTypeFusionEnabled(OpType op_type) const { return op_type2fusion_enabled.at(op_type); }\n\n  bool IsRequestEntryFusionEnabled(const RequestEntry* entry) const {\n    return IsOpTypeFusionEnabled(entry->desc().op_desc().op_type());\n  }\n\n  bool CanRequestEntryFuse(const RequestEntry* lhs, const RequestEntry* rhs) const {\n    {\n      int nccl_version;\n      OF_NCCL_CHECK(ncclGetVersion(&nccl_version));\n      // Workaround for https://github.com/NVIDIA/nccl/issues/560\n      if (nccl_version == 21003) { return false; }\n    }\n    if (lhs->device_set_symbol() != rhs->device_set_symbol()) { return false; }\n    if ((!IsRequestEntryFusionEnabled(lhs)) || (!IsRequestEntryFusionEnabled(rhs))) {\n      return false;\n    }\n    if ((!enable_mixed_fusion)\n        && lhs->desc().op_desc().op_type() != rhs->desc().op_desc().op_type()) {\n      return false;\n    }\n    if (conf.nccl_fusion_all_reduce_use_buffer()) {\n      if (lhs->desc().op_desc().op_type() == OpType::kOpTypeAllReduce\n          && rhs->desc().op_desc().op_type() == OpType::kOpTypeAllReduce) {\n        CHECK(lhs->desc().op_desc().has_reduce_method());\n        CHECK(rhs->desc().op_desc().has_reduce_method());\n        return lhs->desc().op_desc().reduce_method() == rhs->desc().op_desc().reduce_method()\n               && lhs->desc().op_desc().data_type() == rhs->desc().op_desc().data_type();\n      } else if (lhs->desc().op_desc().op_type() == OpType::kOpTypeAllReduce\n                 || rhs->desc().op_desc().op_type() == OpType::kOpTypeAllReduce) {\n        return false;\n      } else {\n        return true;\n      }\n    } else {\n      return true;\n    }\n  }\n\n  void GroupRequests(const std::vector<RequestId>& request_ids,\n                     const std::function<void(std::vector<RequestId>&&, void*)>& Handler) {\n    std::vector<RequestId> group;\n    
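// Greedy grouping: flush the current group when the next request cannot fuse with it,\n    // or when adding it would exceed the fusion buffer size or the max-ops limit.\n    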
int64_t group_size = 0;\n    const int64_t fusion_max_ops = std::min(conf.nccl_fusion_max_ops(), kMultiCopyParamsMaxSize);\n    request_store->ForEachMutRequestEntryForIdsInJob(\n        request_ids, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) {\n          const auto& request = request_entry->desc();\n          const int64_t size = GetMultiCopyAlignedSize(request_entry->size_in_bytes());\n          if (group.empty()\n              || !CanRequestEntryFuse(request_store->MutRequestEntry(group.back()), request_entry)\n              || group_size + size > fusion_threshold || group.size() >= fusion_max_ops) {\n            if (!group.empty()) {\n              void* token = CreateGroupToken(group);\n              Handler(std::move(group), token);\n              group.clear();\n              group_size = 0;\n            }\n          }\n          group.emplace_back(request_id);\n          group_size += size;\n        });\n    if (!group.empty()) {\n      void* token = CreateGroupToken(group);\n      Handler(std::move(group), token);\n    }\n  }\n\n  struct GroupToken {\n    GroupToken(const std::vector<RequestId>& group, std::vector<CommGroup>* stream_id2comm_group)\n        : request_ids(group), stream_id2comm_group(stream_id2comm_group) {}\n    std::vector<RequestId> request_ids;\n    std::vector<CommGroup>* stream_id2comm_group;\n  };\n\n  void* CreateGroupToken(const std::vector<RequestId>& group) {\n    CHECK_GT(group.size(), 0);\n    void* group_token;\n    const DeviceSet& first_device_set =\n        request_store->MutRequestEntry(group.front())->desc().device_set();\n    auto it = device_set2stream_id2comm_group.find(first_device_set);\n    CHECK(it != device_set2stream_id2comm_group.end());\n    group_token = new GroupToken(group, &it->second);\n    request_store->ForEachMutRequestEntryForIdsInJob(\n        group, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) {\n          const DeviceSet& device_set = request_entry->desc().device_set();\n          CHECK(first_device_set == device_set);\n        });\n    return group_token;\n  }\n\n  void DestroyGroupToken(void* group_token) {\n    GroupToken* token = static_cast<GroupToken*>(group_token);\n    delete token;\n  }\n\n  void ExecuteGroup(void* group_token) {\n    GroupToken* token = static_cast<GroupToken*>(group_token);\n    const std::vector<RequestId>& request_ids = token->request_ids;\n    if (request_ids.empty()) { return; }\n    const int32_t stream_id = NextStreamId();\n    CudaCurrentDeviceGuard device_guard;\n    const auto& comm_group = token->stream_id2comm_group->at(stream_id);\n    auto& device_id2stream_ctx = stream_id2device_id2stream_ctx.at(stream_id);\n    if (request_store->MutRequestEntry(request_ids.front())->desc().op_desc().op_type()\n            == OpType::kOpTypeAllReduce\n        && conf.nccl_fusion_all_reduce_use_buffer() && request_ids.size() > 1) {\n      LaunchFusedAllReduce(comm_group, device_id2stream_ctx, request_store, request_ids);\n    } else {\n      LaunchAggregatedOps(comm_group, device_id2stream_ctx, request_store, request_ids);\n    }\n    AddCallbackAndResetRuntimeRequest(comm_group, device_id2stream_ctx, request_store, request_ids);\n  }\n\n  CollectiveBoxingConf conf;\n  int64_t fusion_threshold;\n  int32_t num_streams;\n  int32_t current_stream_id;\n  bool enable_mixed_fusion;\n  std::vector<bool> op_type2fusion_enabled;\n  std::shared_ptr<RequestStore> request_store;\n  HashMap<DeviceSet, std::vector<CommGroup>> 
device_set2stream_id2comm_group;\n  std::vector<std::vector<std::unique_ptr<StreamCtx>>> stream_id2device_id2stream_ctx;\n};\n\nNcclExecutorBackend::NcclExecutorBackend() = default;\n\nNcclExecutorBackend::~NcclExecutorBackend() = default;\n\nvoid NcclExecutorBackend::Init(std::shared_ptr<RequestStore> request_store) {\n  impl_ = std::make_unique<Impl>(\n      Singleton<ResourceDesc, ForSession>::Get()->collective_boxing_conf(),\n      std::move(request_store));\n}\n\nvoid NcclExecutorBackend::InitJob(int64_t job_id) {\n  CudaCurrentDeviceGuard guard;\n  impl_->InitCommGroup(job_id);\n}\n\nvoid NcclExecutorBackend::DeinitJob(int64_t job_id) {}\n\nvoid NcclExecutorBackend::GroupRequests(\n    const std::vector<RequestId>& request_ids,\n    const std::function<void(std::vector<RequestId>&&, void*)>& Handler) {\n  impl_->GroupRequests(request_ids, Handler);\n}\n\nvoid* NcclExecutorBackend::CreateGroupToken(const std::vector<RequestId>& group) {\n  return impl_->CreateGroupToken(group);\n}\n\nvoid NcclExecutorBackend::DestroyGroupToken(void* group_token) {\n  impl_->DestroyGroupToken(group_token);\n}\n\nvoid NcclExecutorBackend::ExecuteGroup(void* group_token) { impl_->ExecuteGroup(group_token); }\n\nREGISTER_EXECUTOR_BACKEND(DeviceType::kCUDA, NcclExecutorBackend);\n\n}  // namespace collective\n\n}  // namespace boxing\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/collective_boxing/request_store.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/collective_boxing/request_store.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/data_type.h\"\n\nnamespace oneflow {\n\nnamespace boxing {\n\nnamespace collective {\n\nRequestEntry::RequestEntry(const RequestDesc& desc) : desc_(desc) {\n  std::set<int64_t> node_ids;\n  for (int64_t global_rank = 0; global_rank < desc.device_set().device().size(); ++global_rank) {\n    const DeviceDesc& device = desc.device_set().device(global_rank);\n    if (device.machine_id() == GlobalProcessCtx::Rank()) {\n      local_device_vec_.emplace_back(device);\n      global_rank2local_rank_.emplace(global_rank, local_rank2global_rank_.size());\n      local_rank2global_rank_.emplace_back(global_rank);\n    }\n    node_ids.emplace(device.machine_id());\n  }\n  const size_t local_rank_count = local_device_vec_.size();\n  node_count_ = node_ids.size();\n  state_.runtime_request_info_vec.resize(local_rank_count);\n  state_.runtime_request_count = 0;\n  elem_cnt_ = Shape(desc.op_desc().shape()).elem_cnt();\n  size_in_bytes_ = elem_cnt_ * GetSizeOfDataType(desc.op_desc().data_type());\n  device_set_symbol_.reset(desc.device_set());\n}\n\nbool RequestEntry::AddRuntimeRequest(\n    int32_t local_rank, std::shared_ptr<const RuntimeRequestInfo> runtime_request_info) {\n  CHECK_LT(local_rank, state_.runtime_request_info_vec.size());\n  std::lock_guard<std::mutex> lock(state_.mutex);\n  CHECK(!state_.runtime_request_info_vec.at(local_rank));\n  state_.runtime_request_info_vec.at(local_rank) = std::move(runtime_request_info);\n  state_.runtime_request_count += 1;\n  return state_.runtime_request_count == state_.runtime_request_info_vec.size();\n}\n\nconst std::shared_ptr<const RuntimeRequestInfo>& RequestEntry::GetRuntimeRequest(\n    int32_t local_rank) {\n  std::lock_guard<std::mutex> lock(state_.mutex);\n  return state_.runtime_request_info_vec.at(local_rank);\n}\n\nstd::vector<std::shared_ptr<const RuntimeRequestInfo>> RequestEntry::ResetRuntimeRequest() {\n  std::lock_guard<std::mutex> lock(state_.mutex);\n  std::vector<std::shared_ptr<const RuntimeRequestInfo>> ret(\n      state_.runtime_request_info_vec.size());\n  ret.swap(state_.runtime_request_info_vec);\n  state_.runtime_request_count = 0;\n  return ret;\n}\n\nvoid RequestStore::InitJob(int64_t job_id, const RequestSet& request_set) {\n  std::vector<std::unique_ptr<RequestEntry>>& request_entry_vec = job_id2request_entry_vec_[job_id];\n  CHECK_EQ(request_entry_vec.size(), 0);\n  for (const RequestDesc& desc : request_set.request()) {\n    request_entry_vec.emplace_back(std::make_unique<RequestEntry>(desc));\n  }\n  for (int32_t i = 0; i < request_entry_vec.size(); ++i) {\n    const std::unique_ptr<RequestEntry>& entry 
= request_entry_vec.at(i);\n    CHECK(name2request_id_.emplace(entry->desc().op_desc().name(), RequestId(job_id, i)).second);\n  }\n}\n\nvoid RequestStore::DeinitJob(int64_t job_id) {\n  const auto& it = job_id2request_entry_vec_.find(job_id);\n  CHECK(it != job_id2request_entry_vec_.end());\n  const auto& request_entry_vec = it->second;\n  for (const auto& request_entry : request_entry_vec) {\n    name2request_id_.erase(request_entry->desc().op_desc().name());\n  }\n  job_id2request_entry_vec_.erase(job_id);\n}\n\nstruct RequestEntryToken {\n  RequestEntry* request_entry;\n};\n\nvoid* RequestStore::CreateRequestEntryToken(const RequestId& request_id) {\n  auto it = job_id2request_entry_vec_.find(request_id.job_id);\n  CHECK(it != job_id2request_entry_vec_.end());\n  return new RequestEntryToken{it->second.at(request_id.request_index).get()};\n}\n\nvoid RequestStore::DestroyRequestEntryToken(void* request_entry_token) {\n  auto token = static_cast<RequestEntryToken*>(request_entry_token);\n  delete token;\n}\n\nRequestEntry* RequestStore::GetRequestEntry(void* request_entry_token) {\n  return static_cast<RequestEntryToken*>(request_entry_token)->request_entry;\n}\n\n}  // namespace collective\n\n}  // namespace boxing\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/collective_boxing/request_store.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_REQUEST_STORE_H_\n#define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_REQUEST_STORE_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/collective_boxing/runtime_request_info.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/graph/boxing/collective_boxing_util.h\"\n\nnamespace oneflow {\n\nnamespace boxing {\n\nnamespace collective {\n\nclass RequestEntry final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RequestEntry);\n  RequestEntry(const RequestDesc& desc);\n  ~RequestEntry() = default;\n\n  const RequestDesc& desc() const { return desc_; }\n  int32_t LocalRankCount() const { return local_rank2global_rank_.size(); }\n  int32_t LocalRankToGlobalRank(int32_t local_rank) const {\n    return local_rank2global_rank_.at(local_rank);\n  }\n  int32_t GlobalRankToLocalRank(int32_t global_rank) const {\n    return global_rank2local_rank_.at(global_rank);\n  }\n  bool HasRankOnThisNode() const { return !local_rank2global_rank_.empty(); }\n  int32_t NodeCount() const { return node_count_; }\n  const DeviceDesc& LocalDeviceDesc(int32_t local_rank) const {\n    return local_device_vec_.at(local_rank);\n  }\n  bool IsRootOnThisNode() const {\n    return (!local_rank2global_rank_.empty()) && local_rank2global_rank_.front() == 0;\n  }\n\n  bool AddRuntimeRequest(int32_t local_rank,\n                         std::shared_ptr<const RuntimeRequestInfo> runtime_request_info);\n  const std::shared_ptr<const RuntimeRequestInfo>& GetRuntimeRequest(int32_t local_rank);\n  std::vector<std::shared_ptr<const RuntimeRequestInfo>> ResetRuntimeRequest();\n  int64_t elem_cnt() const { return elem_cnt_; }\n  int64_t size_in_bytes() const { return size_in_bytes_; }\n  const Symbol<DeviceSet>& device_set_symbol() const { return device_set_symbol_; }\n\n private:\n  RequestDesc desc_;\n  int32_t node_count_;\n  std::vector<DeviceDesc> local_device_vec_;\n  std::vector<int64_t> local_rank2global_rank_;\n  std::map<int64_t, int64_t> global_rank2local_rank_;\n  int64_t elem_cnt_;\n  int64_t size_in_bytes_;\n  Symbol<DeviceSet> device_set_symbol_;\n\n  struct State {\n    std::vector<std::shared_ptr<const RuntimeRequestInfo>> runtime_request_info_vec;\n    int32_t runtime_request_count;\n    std::mutex mutex;\n  };\n\n  State state_;\n};\n\nstruct RequestId {\n  RequestId(int64_t job_id, int32_t request_index) : job_id(job_id), request_index(request_index) {}\n  int64_t job_id;\n  int32_t request_index;\n};\n\nclass RequestStore {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RequestStore);\n  RequestStore() = default;\n  ~RequestStore() = default;\n\n  void InitJob(int64_t job_id, const RequestSet& request_set);\n  void DeinitJob(int64_t job_id);\n\n  RequestEntry* MutRequestEntry(const RequestId& request_id) {\n    auto it = job_id2request_entry_vec_.find(request_id.job_id);\n    CHECK(it != 
job_id2request_entry_vec_.end());\n    return it->second.at(request_id.request_index).get();\n  }\n\n  void ForEachMutRequestEntryForIdsInJob(\n      const std::vector<RequestId>& request_ids,\n      const std::function<void(RequestEntry*, int32_t i, const RequestId& request_id)>& Handler) {\n    if (request_ids.size() == 0) { return; }\n    int64_t job_id = request_ids.front().job_id;\n    auto it = job_id2request_entry_vec_.find(job_id);\n    CHECK(it != job_id2request_entry_vec_.end());\n    for (int32_t i = 0; i < request_ids.size(); ++i) {\n      CHECK_EQ(request_ids.at(i).job_id, job_id);\n      Handler(it->second.at(request_ids.at(i).request_index).get(), i, request_ids.at(i));\n    }\n  }\n\n  void ForEachMutRequestEntryInJob(\n      int64_t job_id,\n      const std::function<void(RequestEntry*, int32_t i, const RequestId& request_id)>& Handler) {\n    auto it = job_id2request_entry_vec_.find(job_id);\n    CHECK(it != job_id2request_entry_vec_.end());\n    for (int32_t i = 0; i < it->second.size(); ++i) {\n      RequestId request_id(job_id, i);\n      Handler(it->second.at(i).get(), i, request_id);\n    }\n  }\n\n  int32_t RequestCountForJob(int64_t job_id) const {\n    const auto& it = job_id2request_entry_vec_.find(job_id);\n    CHECK(it != job_id2request_entry_vec_.end());\n    return it->second.size();\n  }\n\n  RequestId GetRequestIdByName(const std::string& name) const { return name2request_id_.at(name); }\n\n  void* CreateRequestEntryToken(const RequestId& request_id);\n\n  void DestroyRequestEntryToken(void* token);\n\n  RequestEntry* GetRequestEntry(void* token);\n\n private:\n  HashMap<int64_t, std::vector<std::unique_ptr<RequestEntry>>> job_id2request_entry_vec_;\n  HashMap<std::string, RequestId> name2request_id_;\n};\n\n}  // namespace collective\n\n}  // namespace boxing\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_REQUEST_STORE_H_\n"
  },
  {
    "path": "oneflow/core/job/collective_boxing/runtime_request_info.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_RUNTIME_REQUEST_INFO_H_\n#define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_RUNTIME_REQUEST_INFO_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace boxing {\n\nnamespace collective {\n\nstruct RuntimeRequestInfo {\n  const void* send_buff;\n  void* recv_buff;\n  std::function<void(const Maybe<void>&)> callback;\n};\n\n}  // namespace collective\n\n}  // namespace boxing\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_RUNTIME_REQUEST_INFO_H_\n"
  },
  {
    "path": "oneflow/core/job/collective_boxing/scheduler.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/collective_boxing/scheduler.h\"\n#include \"oneflow/core/job/collective_boxing/executor.h\"\n#include \"oneflow/core/job/collective_boxing/request_store.h\"\n#include \"oneflow/core/job/collective_boxing/coordinator.h\"\n#include \"oneflow/core/job/collective_boxing/static_group_coordinator.h\"\n#include \"oneflow/core/graph/boxing/collective_boxing_util.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/collective_boxing/executor_backend_manager.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n\nnamespace oneflow {\n\nnamespace boxing {\n\nnamespace collective {\n\nnamespace {\n\nbool CanMergeIntoCurGroup(RequestStore* request_store, const RequestEntry* request_entry,\n                          const RequestId& request_id, const std::vector<RequestId>& group_buffer) {\n  if (group_buffer.empty()) { return true; }\n  const RequestId& group_entry_id = group_buffer.front();\n  const auto* group_entry = request_store->MutRequestEntry(group_entry_id);\n  return (request_id.job_id == group_entry_id.job_id\n          && request_entry->desc().dependency_depth() == group_entry->desc().dependency_depth()\n          && request_entry->desc().op_desc().device_type()\n                 == group_entry->desc().op_desc().device_type()\n          && request_entry->device_set_symbol() == group_entry->device_set_symbol());\n}\n\nbool HasRankInteraction(const DeviceSet& a, const DeviceSet& b) {\n  for (int64_t i = 0; i < a.device_size(); ++i) {\n    const DeviceDesc& a_device_desc = a.device(i);\n    for (int64_t j = 0; j < b.device_size(); ++j) {\n      if (a_device_desc.machine_id() == b.device(j).machine_id()) { return true; }\n    }\n  }\n  return false;\n}\n\n}  // namespace\n\nclass RequestHandle final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RequestHandle);\n  RequestHandle(int32_t local_rank, void* request_entry_token, void* coordinator_token)\n      : local_rank_(local_rank),\n        request_entry_token_(request_entry_token),\n        coordinator_token_(coordinator_token) {}\n  ~RequestHandle() = default;\n\n  int32_t local_rank() const { return local_rank_; }\n\n  void* request_entry_token() { return request_entry_token_; }\n\n  void* coordinator_token() { return coordinator_token_; }\n\n private:\n  int32_t local_rank_;\n  void* request_entry_token_;\n  void* coordinator_token_;\n};\n\nclass GroupToken final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(GroupToken);\n  GroupToken(DeviceType device_type, void* backend_group_token)\n      : device_type_(device_type), backend_group_token_(backend_group_token) {}\n  ~GroupToken() = default;\n\n  DeviceType device_type() { return device_type_; }\n\n  void* backend_group_token() { return backend_group_token_; }\n\n private:\n  DeviceType device_type_;\n  void* 
backend_group_token_;\n};\n\nclass ExecutorImpl : public Executor {\n public:\n  ExecutorImpl() = default;\n  ~ExecutorImpl() override = default;\n\n  void Init(std::shared_ptr<RequestStore> request_store) override;\n  void InitJob(int64_t job_id) override;\n  void DeinitJob(int64_t job_id) override;\n  void GroupRequests(\n      const std::vector<RequestId>& request_ids,\n      const std::function<void(std::vector<RequestId>&&, GroupToken*)>& Handler) override;\n  void ExecuteGroup(GroupToken* group_token) override;\n  void DestroyGroupToken(GroupToken* group_token) override;\n\n private:\n  DeviceType GetUniqueDeviceType(const std::vector<RequestId>& group);\n  GroupToken* CreateGroupToken(const std::vector<RequestId>& group, void* backend_group_token);\n\n  std::vector<std::unique_ptr<ExecutorBackend>> backends_;\n  std::shared_ptr<RequestStore> request_store_;\n  std::vector<RequestId> group_buffer_;\n};\n\nvoid ExecutorImpl::Init(std::shared_ptr<RequestStore> request_store) {\n  request_store_ = request_store;\n  backends_.resize(DeviceType_ARRAYSIZE);\n  const auto& vaild_executor_device_types = ExecutorBackendMgr::Get().vaild_executor_device_types();\n  CHECK_LE(vaild_executor_device_types.size(), 1)\n      << \"Currently only one backend is supported at the same time\";\n\n  for (DeviceType device_type : vaild_executor_device_types) {\n    size_t dev_count = Singleton<ep::DeviceManagerRegistry>::Get()->GetDeviceCount(device_type);\n    if (dev_count > 0) {\n      std::unique_ptr<ExecutorBackend> backend =\n          ExecutorBackendMgr::Get().NewExecutorBackend(device_type);\n      CHECK(backend);\n      backend->Init(request_store_);\n      backends_.at(device_type) = std::move(backend);\n    }\n  }\n}\n\nvoid ExecutorImpl::InitJob(int64_t job_id) {\n  const auto& vaild_executor_device_types = ExecutorBackendMgr::Get().vaild_executor_device_types();\n  for (DeviceType device_type : vaild_executor_device_types) {\n    CHECK(backends_.at(device_type));\n    backends_.at(device_type)->InitJob(job_id);\n  }\n}\n\nvoid ExecutorImpl::DeinitJob(int64_t job_id) {\n  const auto& vaild_executor_device_types = ExecutorBackendMgr::Get().vaild_executor_device_types();\n  for (DeviceType device_type : vaild_executor_device_types) {\n    CHECK(backends_.at(device_type));\n    backends_.at(device_type)->DeinitJob(job_id);\n  }\n}\n\nGroupToken* ExecutorImpl::CreateGroupToken(const std::vector<RequestId>& group,\n                                           void* backend_group_token) {\n  return new GroupToken(GetUniqueDeviceType(group), backend_group_token);\n}\n\nvoid ExecutorImpl::DestroyGroupToken(GroupToken* group_token) {\n  const auto& vaild_executor_device_types = ExecutorBackendMgr::Get().vaild_executor_device_types();\n  for (DeviceType device_type : vaild_executor_device_types) {\n    CHECK(backends_.at(device_type));\n    backends_.at(device_type)->DestroyGroupToken(group_token->backend_group_token());\n  }\n  delete group_token;\n}\n\nvoid ExecutorImpl::GroupRequests(\n    const std::vector<RequestId>& request_ids,\n    const std::function<void(std::vector<RequestId>&&, GroupToken*)>& Handler) {\n  if (request_ids.empty()) { return; }\n  const CollectiveBoxingConf& conf =\n      Singleton<ResourceDesc, ForSession>::Get()->collective_boxing_conf();\n  auto BackendHandler = [&](std::vector<RequestId>&& group, void* backend_group_token) {\n    GroupToken* group_token = CreateGroupToken(group, backend_group_token);\n    Handler(std::move(group), group_token);\n  };\n  auto HandleGroup = 
[&]() {\n    if (group_buffer_.empty()) { return; }\n    const auto device_type =\n        request_store_->MutRequestEntry(group_buffer_.front())->desc().op_desc().device_type();\n    backends_.at(device_type)->GroupRequests(group_buffer_, BackendHandler);\n    group_buffer_.clear();\n  };\n  request_store_->ForEachMutRequestEntryForIdsInJob(\n      request_ids, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) {\n        if (request_entry->HasRankOnThisNode()) {\n          if (!(conf.enable_fusion()\n                && CanMergeIntoCurGroup(request_store_.get(), request_entry, request_id,\n                                        group_buffer_))) {\n            HandleGroup();\n          }\n          group_buffer_.emplace_back(request_id);\n        } else {\n          if (!group_buffer_.empty()\n              && HasRankInteraction(\n                  request_store_->MutRequestEntry(group_buffer_.back())->desc().device_set(),\n                  request_entry->desc().device_set())) {\n            HandleGroup();\n          }\n        }\n      });\n  HandleGroup();\n}\n\nvoid ExecutorImpl::ExecuteGroup(GroupToken* group_token) {\n  const DeviceType device_type = group_token->device_type();\n  backends_.at(device_type)->ExecuteGroup(group_token->backend_group_token());\n}\n\nDeviceType ExecutorImpl::GetUniqueDeviceType(const std::vector<RequestId>& group) {\n  const DeviceType device_type =\n      request_store_->MutRequestEntry(group.front())->desc().op_desc().device_type();\n  request_store_->ForEachMutRequestEntryForIdsInJob(\n      group, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) {\n        CHECK_EQ(request_entry->desc().op_desc().device_type(), device_type);\n      });\n  return device_type;\n}\n\nstruct Scheduler::Impl {\n  Impl();\n  std::shared_ptr<RequestStore> request_store;\n  std::shared_ptr<Executor> executor;\n  std::shared_ptr<Coordinator> coordinator;\n};\n\nScheduler::Impl::Impl() {\n  request_store.reset(new RequestStore());\n  executor.reset(new ExecutorImpl());\n  executor->Init(request_store);\n  coordinator.reset(new StaticGroupCoordinator());\n  coordinator->Init(request_store, executor);\n}\n\nclass SchedulerPlanToken {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SchedulerPlanToken);\n  explicit SchedulerPlanToken(const std::vector<int64_t>& job_ids) : job_ids_(job_ids) {}\n  ~SchedulerPlanToken() = default;\n  const std::vector<int64_t>& job_ids() const { return job_ids_; }\n\n private:\n  std::vector<int64_t> job_ids_;\n};\n\nSchedulerPlanToken* Scheduler::AddPlan(const Plan& plan) {\n  std::vector<int64_t> job_ids;\n  for (const auto& job_id7request_set : plan.collective_boxing_plan().job_id2request_set()) {\n    const int64_t job_id = job_id7request_set.first;\n    job_ids.emplace_back(job_id);\n    impl_->request_store->InitJob(job_id, job_id7request_set.second);\n    impl_->executor->InitJob(job_id);\n    impl_->coordinator->InitJob(job_id);\n  }\n  return new SchedulerPlanToken(job_ids);\n}\n\nvoid Scheduler::DeletePlan(SchedulerPlanToken* plan_token) {\n  const std::vector<int64_t>& job_ids = plan_token->job_ids();\n  for (const auto& job_id : job_ids) {\n    impl_->coordinator->DeinitJob(job_id);\n    impl_->executor->DeinitJob(job_id);\n    impl_->request_store->DeinitJob(job_id);\n  }\n  delete plan_token;\n}\n\nScheduler::Scheduler() { impl_.reset(new Impl()); }\n\nScheduler::~Scheduler() = default;\n\nRequestHandle* Scheduler::CreateRequestHandle(const RankDesc& rank_desc) {\n  const RequestId& request_id =\n    
  impl_->request_store->GetRequestIdByName(rank_desc.op_desc().name());\n  auto* request_entry = impl_->request_store->MutRequestEntry(request_id);\n  CHECK(rank_desc.op_desc() == request_entry->desc().op_desc());\n  const int32_t local_rank = request_entry->GlobalRankToLocalRank(rank_desc.rank());\n  void* request_entry_token = impl_->request_store->CreateRequestEntryToken(request_id);\n  void* coordinator_token = impl_->coordinator->CreateCoordinatorToken(request_id);\n  return new RequestHandle(local_rank, request_entry_token, coordinator_token);\n}\n\nvoid Scheduler::DestroyRequestHandle(RequestHandle* handle) {\n  impl_->coordinator->DestroyCoordinatorToken(handle->coordinator_token());\n  impl_->request_store->DestroyRequestEntryToken(handle->request_entry_token());\n}\n\nvoid Scheduler::Schedule(RequestHandle* handle,\n                         std::shared_ptr<const RuntimeRequestInfo> request_info) {\n  const int32_t local_rank = handle->local_rank();\n  const bool ready = impl_->request_store->GetRequestEntry(handle->request_entry_token())\n                         ->AddRuntimeRequest(local_rank, std::move(request_info));\n  if (ready) { impl_->coordinator->AddRequest(handle->coordinator_token()); }\n}\n\n}  // namespace collective\n\n}  // namespace boxing\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/collective_boxing/scheduler.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_SCHEDULER_H_\n#define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_SCHEDULER_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/collective_boxing/runtime_request_info.h\"\n#include \"oneflow/core/job/collective_boxing/request_store.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n\nnamespace oneflow {\n\nnamespace boxing {\n\nnamespace collective {\n\nclass RequestHandle;\nclass SchedulerPlanToken;\n\nclass Scheduler final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Scheduler);\n  ~Scheduler();\n\n  RequestHandle* CreateRequestHandle(const RankDesc& rank_desc);\n  void DestroyRequestHandle(RequestHandle*);\n  void Schedule(RequestHandle* handle, std::shared_ptr<const RuntimeRequestInfo> request_info);\n  SchedulerPlanToken* AddPlan(const Plan& plan);\n  void DeletePlan(SchedulerPlanToken* plan_token);\n\n private:\n  friend class Singleton<Scheduler>;\n  Scheduler();\n\n  struct Impl;\n  std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace collective\n\n}  // namespace boxing\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_SCHEDULER_H_\n"
  },
  {
    "path": "oneflow/core/job/collective_boxing/static_group_coordinator.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/collective_boxing/static_group_coordinator.h\"\n#include \"oneflow/core/job/collective_boxing/executor.h\"\n#include \"oneflow/core/job/collective_boxing/request_store.h\"\n#include \"oneflow/core/graph/boxing/collective_boxing_util.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/common/str_util.h\"\n\nnamespace oneflow {\n\nnamespace boxing {\n\nnamespace collective {\n\nnamespace {\n\nvoid SortRequestIdsByOrder(RequestStore* request_store, std::vector<RequestId>* requests) {\n  std::sort(requests->begin(), requests->end(),\n            [request_store](const RequestId& a, const RequestId& b) {\n              return request_store->MutRequestEntry(a)->desc().order()\n                     < request_store->MutRequestEntry(b)->desc().order();\n            });\n}\n\nbool HasRankInteractionOnDeviceSet(const DeviceSet& a, const DeviceSet& b) {\n  for (int64_t i = 0; i < a.device_size(); ++i) {\n    const DeviceDesc& a_device_desc = a.device(i);\n    for (int64_t j = 0; j < b.device_size(); ++j) {\n      if (a_device_desc.machine_id() == b.device(j).machine_id()) { return true; }\n    }\n  }\n  return false;\n}\n\n}  // namespace\n\nstruct GroupState {\n  explicit GroupState(int32_t group_size) : index2is_ready(group_size), ready_request_count(0) {}\n\n  void AddReadyRequest(int32_t index);\n  bool IsReady() const;\n  void Reset();\n\n  std::vector<bool> index2is_ready;\n  int32_t ready_request_count;\n};\nstd::mutex mutex_;\nint64_t current_job_id_ = -1;\nint64_t current_group_idx_in_job_ = -1;\n\nstruct RequestGroupIndex {\n  int32_t group_id;\n  int32_t index_in_group;\n};\n\nclass GroupToken;\n\nstruct StaticGroupRequestsInfo {\n  std::vector<RequestGroupIndex> request_index2request_group_index;\n  std::vector<GroupState> group_states;\n  std::vector<std::vector<RequestId>> group_id2request_ids;\n  std::vector<GroupToken*> group_id2group_token;\n};\n\nstruct StaticGroupRequestsInfoToken {\n  RequestId request_id;\n  StaticGroupRequestsInfo* info;\n};\n\nstruct StaticGroupCoordinator::Impl {\n  Impl(const std::shared_ptr<RequestStore>& request_store,\n       const std::shared_ptr<Executor>& executor);\n  std::shared_ptr<RequestStore> request_store_;\n  std::shared_ptr<Executor> executor_;\n  HashMap<int64_t, StaticGroupRequestsInfo> job_id2static_group_requests_info_;\n};\n\nStaticGroupCoordinator::Impl::Impl(const std::shared_ptr<RequestStore>& request_store,\n                                   const std::shared_ptr<Executor>& executor)\n    : request_store_(request_store), executor_(executor) {}\n\nStaticGroupCoordinator::StaticGroupCoordinator() = default;\n\nStaticGroupCoordinator::~StaticGroupCoordinator() = default;\n\nvoid StaticGroupCoordinator::Init(std::shared_ptr<RequestStore> request_store,\n                 
                 std::shared_ptr<Executor> executor) {\n  impl_ = std::make_unique<Impl>(request_store, executor);\n}\n\nvoid* StaticGroupCoordinator::CreateCoordinatorToken(const RequestId& request_id) {\n  std::unique_lock<std::mutex> lock(mutex_);\n  auto it = impl_->job_id2static_group_requests_info_.find(request_id.job_id);\n  CHECK(it != impl_->job_id2static_group_requests_info_.end());\n  return new StaticGroupRequestsInfoToken{request_id, &it->second};\n}\n\nvoid StaticGroupCoordinator::DestroyCoordinatorToken(void* coordinator_token) {\n  std::unique_lock<std::mutex> lock(mutex_);\n  auto token = static_cast<StaticGroupRequestsInfoToken*>(coordinator_token);\n  delete token;\n}\n\nvoid StaticGroupCoordinator::InitJob(int64_t job_id) {\n  std::unique_lock<std::mutex> lock(mutex_);\n  std::vector<RequestId> request_ids;\n  impl_->request_store_->ForEachMutRequestEntryInJob(\n      job_id, [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) {\n        request_ids.emplace_back(request_id);\n      });\n  SortRequestIdsByOrder(impl_->request_store_.get(), &request_ids);\n  StaticGroupRequestsInfo info;\n  std::vector<GroupState>& group_states = info.group_states;\n  std::vector<RequestGroupIndex>& request_index2request_group_index =\n      info.request_index2request_group_index;\n  std::vector<std::vector<RequestId>>& group_id2request_ids = info.group_id2request_ids;\n  std::vector<GroupToken*>& group_id2group_token = info.group_id2group_token;\n  const int32_t request_count = impl_->request_store_->RequestCountForJob(job_id);\n  request_index2request_group_index.resize(request_count);\n  impl_->executor_->GroupRequests(\n      request_ids, [&](std::vector<RequestId>&& group, GroupToken* group_token) {\n        const int32_t group_id = group_states.size();\n        group_states.emplace_back(group.size());\n        for (int32_t idx_in_group = 0; idx_in_group < group.size(); ++idx_in_group) {\n          const RequestId& request_id = group.at(idx_in_group);\n          RequestGroupIndex request_group_index{group_id, idx_in_group};\n          request_index2request_group_index.at(request_id.request_index) = request_group_index;\n        }\n        group_id2request_ids.emplace_back(group);\n        group_id2group_token.emplace_back(group_token);\n      });\n\n  CHECK(impl_->job_id2static_group_requests_info_.emplace(job_id, info).second);\n  if (group_states.size() != 0) { DumpSummary(job_id); }\n}\n\nvoid StaticGroupCoordinator::DeinitJob(int64_t job_id) {\n  std::unique_lock<std::mutex> lock(mutex_);\n  const auto& it = impl_->job_id2static_group_requests_info_.find(job_id);\n  CHECK(it != impl_->job_id2static_group_requests_info_.end());\n  const auto& group_id2group_token = it->second.group_id2group_token;\n  for (int32_t group_id = 0; group_id < group_id2group_token.size(); ++group_id) {\n    impl_->executor_->DestroyGroupToken(group_id2group_token.at(group_id));\n  }\n  impl_->job_id2static_group_requests_info_.erase(job_id);\n}\n\nvoid StaticGroupCoordinator::AddRequest(void* coordinator_token) {\n  std::unique_lock<std::mutex> lock(mutex_);\n  StaticGroupRequestsInfoToken* token =\n      static_cast<StaticGroupRequestsInfoToken*>(coordinator_token);\n  const RequestId& request_id = token->request_id;\n  if (current_job_id_ == -1) {\n    current_job_id_ = request_id.job_id;\n    current_group_idx_in_job_ = 0;\n  } else {\n    CHECK_EQ(current_job_id_, request_id.job_id);\n  }\n  StaticGroupRequestsInfo* info = token->info;\n  const RequestGroupIndex& 
request_group_index =\n      info->request_index2request_group_index.at(request_id.request_index);\n  info->group_states.at(request_group_index.group_id)\n      .AddReadyRequest(request_group_index.index_in_group);\n  int64_t num_launched_groups = 0;\n  while (true) {\n    auto& group_state = info->group_states.at(current_group_idx_in_job_);\n    if (group_state.IsReady()) {\n      impl_->executor_->ExecuteGroup(info->group_id2group_token.at(current_group_idx_in_job_));\n      group_state.Reset();\n      current_group_idx_in_job_ = (current_group_idx_in_job_ + 1) % info->group_states.size();\n      num_launched_groups += 1;\n    } else {\n      break;\n    }\n  }\n  if (current_group_idx_in_job_ == 0 && num_launched_groups > 0) {\n    current_job_id_ = -1;\n    current_group_idx_in_job_ = -1;\n  }\n}\n\nvoid StaticGroupCoordinator::DumpSummary(const int64_t job_id) const {\n  if (!Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) { return; }\n  auto group_ls = TeePersistentLogStream::Create(StrCat(\"boxing/collective/job_\", job_id));\n  const auto& it = impl_->job_id2static_group_requests_info_.find(job_id);\n\n  CHECK(it != impl_->job_id2static_group_requests_info_.end());\n  const auto& group_id2request_ids = it->second.group_id2request_ids;\n  for (int32_t group_id = 0; group_id < group_id2request_ids.size(); ++group_id) {\n    group_ls << \"group id: \" << std::to_string(group_id) << \"\\n\";\n    impl_->request_store_->ForEachMutRequestEntryForIdsInJob(\n        group_id2request_ids.at(group_id),\n        [&](RequestEntry* request_entry, int32_t i, const RequestId& request_id) {\n          group_ls->Write(request_entry->desc());\n        });\n  }\n}\n\nvoid GroupState::AddReadyRequest(int32_t index) {\n  CHECK(!index2is_ready.at(index));\n  CHECK(index2is_ready.at(index) = true);\n  ready_request_count += 1;\n}\n\nbool GroupState::IsReady() const { return ready_request_count == index2is_ready.size(); }\n\nvoid GroupState::Reset() {\n  ready_request_count = 0;\n  std::fill(index2is_ready.begin(), index2is_ready.end(), false);\n}\n\n}  // namespace collective\n\n}  // namespace boxing\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/collective_boxing/static_group_coordinator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_STATIC_GROUP_COORDINATOR_H_\n#define ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_STATIC_GROUP_COORDINATOR_H_\n\n#include \"oneflow/core/job/collective_boxing/coordinator.h\"\n\nnamespace oneflow {\n\nnamespace boxing {\n\nnamespace collective {\n\nclass RequestStore;\nclass Executor;\n\nclass StaticGroupCoordinator : public Coordinator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(StaticGroupCoordinator);\n  StaticGroupCoordinator();\n  ~StaticGroupCoordinator() override;\n\n  void Init(std::shared_ptr<RequestStore> request_store,\n            std::shared_ptr<Executor> executor) override;\n  void InitJob(int64_t job_id) override;\n  void DeinitJob(int64_t job_id) override;\n  void AddRequest(void* coordinator_token) override;\n  void* CreateCoordinatorToken(const RequestId& request_id) override;\n  void DestroyCoordinatorToken(void* token) override;\n\n private:\n  void DumpSummary(const int64_t job_id) const;\n\n  struct Impl;\n  std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace collective\n\n}  // namespace boxing\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_COLLECTIVE_BOXING_STATIC_GROUP_COORDINATOR_H_\n"
  },
  {
    "path": "oneflow/core/job/compile_mode.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/compile_mode.h\"\n#include \"oneflow/core/common/env_var/env_var.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstruct CompileModeName final : public CompileModeVisitor<CompileModeName> {\n  static std::string VisitNaive() { return \"naive\"; }\n  static std::string VisitRankPerProcess() { return \"rank_per_process\"; }\n  static std::string VisitInValid() { return \"invalid\"; }\n};\n\nstd::unordered_map<std::string, CompileMode> Name2CompileMode() {\n  std::unordered_map<std::string, CompileMode> name2compile_mode;\n  for (int i = static_cast<int>(CompileMode::kInvalid) + 1;\n       i != static_cast<int>(CompileMode::kEnd); ++i) {\n    CompileMode compile_mode = static_cast<CompileMode>(i);\n    CHECK(name2compile_mode.emplace(CompileModeName::Visit(compile_mode), compile_mode).second);\n  }\n  return name2compile_mode;\n}\n\nstd::string GetValidCompileModeNames() {\n  std::stringstream ss;\n  for (int i = static_cast<int>(CompileMode::kInvalid) + 1;\n       i != static_cast<int>(CompileMode::kEnd); ++i) {\n    if (i > static_cast<int>(CompileMode::kInvalid) + 1) { ss << \", \"; }\n    CompileMode compile_mode = static_cast<CompileMode>(i);\n    ss << CompileModeName::Visit(compile_mode);\n  }\n  return ss.str();\n}\n\n}  // namespace\n\nMaybe<CompileMode> CurrentCompileMode() {\n  static thread_local CompileMode mode =\n      JUST_MSG(MapAt(Name2CompileMode(), ThreadLocalEnvString<ONEFLOW_LAZY_COMPILE_MODE>()),\n               std::stringstream()\n                   << \"ONEFLOW_LAZY_COMPILER(value: \"\n                   << ThreadLocalEnvString<ONEFLOW_LAZY_COMPILE_MODE>()\n                   << \") is invalid. valid options: \\\"\" << GetValidCompileModeNames() << \"\\\"\");\n  return mode;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/compile_mode.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_COMPILE_MODE_H_\n#define ONEFLOW_CORE_JOB_COMPILE_MODE_H_\n\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nenum class CompileMode {\n  kInvalid = 0,  // make sure kInvalid is the first CompileMode\n  kNaive,\n  kRankPerProcess,\n  kEnd,  // make sure kEnd is the last CompileMode\n};\n\ntemplate<typename DerivedT>\nstruct CompileModeVisitor {\n  template<typename... Args>\n  static auto Visit(CompileMode compile_mode, Args&&... args) {\n    switch (compile_mode) {\n      case CompileMode::kNaive: return DerivedT::VisitNaive(std::forward<Args>(args)...);\n      case CompileMode::kRankPerProcess:\n        return DerivedT::VisitRankPerProcess(std::forward<Args>(args)...);\n      default: {\n        LOG(FATAL) << \"invalid compile mode\";\n        return DerivedT::VisitInValid(std::forward<Args>(args)...);\n      }\n    }\n  }\n};\n\nMaybe<CompileMode> CurrentCompileMode();\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_COMPILE_MODE_H_\n"
  },
  {
    "path": "oneflow/core/job/compiler.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/compiler.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/intra_job_mem_sharing_util.h\"\n#include \"oneflow/core/job/plan_util.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/job_rewriter/job_completer.h\"\n#include \"oneflow/core/thread/thread_pool.h\"\n#include \"oneflow/core/common/blocking_counter.h\"\n#include \"oneflow/core/common/cost_util.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n\nnamespace oneflow {\n\nvoid Compiler::Compile(Job* job, Plan* plan) const {\n  const auto& job_name = job->job_conf().job_name();\n  auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true);\n  // Step1: new Singleton<OpGraph> and set log configs.\n  Singleton<OpGraph>::New(*job);\n  const JobDesc& job_desc = GlobalJobDesc();\n  compile_tc->Count(\"[GraphCompile]\" + job_name + \" NewOpGraph\", 1);\n\n  // Step2: build task_gph.\n  // TODO(levi): we can rewrite this part of code in visitor pattern.\n  auto task_gph = CHECK_JUST(GlobalTaskGraph::New());\n  using std::placeholders::_1;\n  LazyMode::Guard guard(true);\n  task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1));\n  task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1));\n  task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1));\n  task_gph->TopoForEachNode(&TaskNode::Build);\n  task_gph->RemoveEmptyRegsts();\n  task_gph->TopoForEachNode(&TaskNode::InferTimeShapeIfMeaningful);\n  task_gph->DecideExecutionOrder();\n  task_gph->MergeChainAndAddOrderingCtrlEdgeInSameChain();\n  auto IsReachable = Singleton<OpGraph>::Get()->MakePredicatorIsOpNameDataOrCtrlReachable();\n  if (job_desc.enable_inplace()) { task_gph->EnableInplaceMemSharing(IsReachable); }\n  task_gph->ForEachEdge([&](TaskEdge* task_edge) { task_edge->CheckRegstLbiValid(); });\n  compile_tc->Count(\"[GraphCompile]\" + job_name + \" BuildTaskGraph\", 1, true);\n\n  // Step3: put infomation from task_gph into plan.\n  const int64_t node_num = task_gph->node_num();\n  const int64_t cpu_num = std::thread::hardware_concurrency();\n  const int64_t thread_pool_size = std::min(node_num, cpu_num);\n  BlockingCounter counter(node_num);\n  std::mutex mtx;\n  ThreadPool thread_pool(thread_pool_size);\n  task_gph->ForEachNode([&](TaskNode* task_node) {\n    thread_pool.AddWork([task_node, plan, &job_desc, &counter, &mtx]() {\n      if (!task_node->IsMeaningLess()) {\n        TaskProto task_proto;\n        task_node->ToProto(&task_proto);\n        {\n          std::unique_lock<std::mutex> guard(mtx);\n          if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat\n              || task_node->GetTaskType() == kAcc) {\n            PlanUtil::CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto);\n          }\n          
plan->mutable_task()->Add(std::move(task_proto));\n        }  // guard(mtx)\n      }\n      counter.Decrease();\n    } /* thread_pool.AddWork */);\n  } /* task_gph->ForEachNode */);\n  counter.WaitForeverUntilCntEqualZero();\n  // NOTE(levi): release task_gph here to decrease memory peak.\n  task_gph.reset();\n  compile_tc->Count(\"[GraphCompile]\" + job_name + \" AddTaskToPlan\", 1, true);\n\n  // Step4: post-process for plan and delete Singleton<OpGraph>.\n  auto* job_id2job_conf = plan->mutable_job_confs()->mutable_job_id2job_conf();\n  (*job_id2job_conf)[GlobalJobDesc().job_id()] = GlobalJobDesc().job_conf();\n  // NOTE(chengcheng): infer mem blob id & set inplace & add ctrl\n  // TODO(chengcheng): set inplace hint for cpu regst\n  IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan);\n  PlanUtil::MergeMemBlockIdByLogicalChainId(plan, *job);\n  PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan);\n  PlanUtil::SetForceInplaceMemBlock(plan);\n  compile_tc->Count(\"[GraphCompile]\" + job_name + \" InferMemShare\", 1, true);\n  Singleton<OpGraph>::Delete();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/compiler.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_COMPILER_H_\n#define ONEFLOW_CORE_JOB_COMPILER_H_\n\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/graph/task_graph.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass Compiler final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Compiler);\n  Compiler() = default;\n  ~Compiler() = default;\n\n  void Compile(Job*, Plan*) const;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_COMPILER_H_\n"
  },
  {
    "path": "oneflow/core/job/critical_section.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage TotalJobCriticalSection {}\nmessage InputOutputCriticalSection {\n  repeated string lbi_producer_op_name = 1;\n}\n\nmessage CriticalSection {\n  required int64 job_id = 1;\n  map<int64, string> machine_id2source_tick_op_name = 2;\n  map<int64, string> machine_id2sink_tick_op_name = 3;\n  repeated int64 mem_block_id = 4;\n  repeated int64 chunk_id = 5;\n  oneof type {\n    TotalJobCriticalSection total_job_critical_section = 6;\n    InputOutputCriticalSection input_output_critical_section = 7;\n  }\n}\n"
  },
  {
    "path": "oneflow/core/job/critical_section_desc.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/critical_section_desc.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include <google/protobuf/text_format.h>\n#include <cstdint>\n#include <string>\n\nnamespace oneflow {\n\nCriticalSection* CriticalSectionDesc::AddCriticalSection(int64_t job_id) {\n  CHECK_EQ(inited_, false);\n  auto critical_section = std::make_unique<CriticalSection>();\n  CriticalSection* ret = critical_section.get();\n  critical_section->set_job_id(job_id);\n  critical_sections_.emplace_back(std::move(critical_section));\n  return ret;\n}\n\nvoid CriticalSectionDesc::Done() {\n  CHECK_EQ(inited_, false);\n  UpdateJobId2CriticalSectionIds();\n  UpdateJobId2TotalJobCriticalSectionId();\n  UpdateCriticalSectionIds2IntersectingIds();\n  CHECK_EQ(job_id2critical_section_ids_.size(), job_id2total_job_critical_section_id_.size());\n  CHECK_EQ(critical_sections_.size(), critical_section_id2intersecting_ids_.size());\n  inited_ = true;\n  std::string all_output;\n  int32_t i = 0;\n  for (const auto& cs : critical_sections_) {\n    all_output += \"CriticalSection \" + std::to_string(i) + \"\\n\";\n    std::string output;\n    google::protobuf::TextFormat::PrintToString(*cs, &output);\n    all_output += output;\n    all_output += \"\\n\";\n    i++;\n  }\n  TeePersistentLogStream::Create(\"critical_section_desc\")->Write(all_output);\n}\n\nconst CriticalSection& CriticalSectionDesc::GetCriticalSection(int64_t critical_section_id) const {\n  CHECK(inited_);\n  return *critical_sections_.at(critical_section_id);\n}\n\nCriticalSection* CriticalSectionDesc::MutCriticalSection(int64_t critical_section_id) const {\n  CHECK_EQ(inited_, false);\n  return critical_sections_.at(critical_section_id).get();\n}\n\nconst std::vector<int64_t>& CriticalSectionDesc::CriticalSectionIds4JobId(int64_t job_id) const {\n  CHECK(inited_);\n  return job_id2critical_section_ids_.at(job_id);\n}\n\nvoid CriticalSectionDesc::DumpCriticalSectionId2IntersectinIds(PbRpf<Int64List>* id2id_list) const {\n  CHECK(inited_);\n  FOR_RANGE(int64_t, i, 0, critical_sections_.size()) {\n    *id2id_list->Add()->mutable_value() = {critical_section_id2intersecting_ids_.at(i).begin(),\n                                           critical_section_id2intersecting_ids_.at(i).end()};\n  }\n}\n\nvoid CriticalSectionDesc::UpdateJobId2CriticalSectionIds() {\n  CHECK_EQ(inited_, false);\n  job_id2critical_section_ids_.resize(critical_sections_.size());\n  int64_t max_job_id = -1;\n  FOR_RANGE(int64_t, i, 0, critical_sections_.size()) {\n    const auto& critical_section = *critical_sections_.at(i);\n    int64_t job_id = critical_section.job_id();\n    job_id2critical_section_ids_[job_id].emplace_back(i);\n    max_job_id = std::max(max_job_id, job_id);\n  }\n  job_id2critical_section_ids_.resize(max_job_id + 1);\n}\n\nvoid CriticalSectionDesc::UpdateJobId2TotalJobCriticalSectionId() {\n  CHECK_EQ(inited_, false);\n  
HashSet<int64_t> unique_check;\n  job_id2total_job_critical_section_id_.resize(critical_sections_.size());\n  FOR_RANGE(int64_t, i, 0, critical_sections_.size()) {\n    const auto& critical_section = *critical_sections_.at(i);\n    if (critical_section.has_total_job_critical_section()) {\n      CHECK(unique_check.emplace(critical_section.job_id()).second);\n      job_id2total_job_critical_section_id_.at(critical_section.job_id()) = i;\n    }\n  }\n  job_id2total_job_critical_section_id_.resize(unique_check.size());\n}\n\nvoid CriticalSectionDesc::UpdateCriticalSectionIds2IntersectingIds() {\n  CHECK_EQ(inited_, false);\n  critical_section_id2intersecting_ids_.resize(critical_sections_.size());\n  HashMap<int64_t, HashSet<int64_t>> mem_block_id2critical_section_ids;\n  HashMap<int64_t, HashSet<int64_t>> chunk_id2critical_section_ids;\n  FOR_RANGE(int64_t, i, 0, critical_sections_.size()) {\n    for (int64_t mem_block_id : critical_sections_.at(i)->mem_block_id()) {\n      mem_block_id2critical_section_ids[mem_block_id].insert(i);\n    }\n    for (int64_t chunk_id : critical_sections_.at(i)->chunk_id()) {\n      chunk_id2critical_section_ids[chunk_id].insert(i);\n    }\n  }\n  for (const auto& pair : mem_block_id2critical_section_ids) {\n    for (int64_t first_id : pair.second) {\n      for (int64_t second_id : pair.second) {\n        if (first_id != second_id) {\n          critical_section_id2intersecting_ids_[first_id].insert(second_id);\n        }\n      }\n    }\n  }\n  for (const auto& pair : chunk_id2critical_section_ids) {\n    for (int64_t first_id : pair.second) {\n      for (int64_t second_id : pair.second) {\n        if (first_id != second_id) {\n          critical_section_id2intersecting_ids_[first_id].insert(second_id);\n        }\n      }\n    }\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/critical_section_desc.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_CRITICAL_SECTION_DESC_H_\n#define ONEFLOW_CORE_JOB_CRITICAL_SECTION_DESC_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/job/critical_section.pb.h\"\n\nnamespace oneflow {\n\nclass CriticalSectionDesc final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CriticalSectionDesc);\n  ~CriticalSectionDesc() = default;\n\n  CriticalSection* AddCriticalSection(int64_t job_id);\n  void Done();\n\n  size_t CriticalSectionNum() const { return critical_sections_.size(); }\n  const CriticalSection& GetCriticalSection(int64_t) const;\n  CriticalSection* MutCriticalSection(int64_t) const;\n  const std::vector<int64_t>& CriticalSectionIds4JobId(int64_t) const;\n  void DumpCriticalSectionId2IntersectinIds(PbRpf<Int64List>* id2id_list) const;\n  const std::vector<std::vector<int64_t>>& job_id2critical_section_ids() const {\n    return job_id2critical_section_ids_;\n  }\n  const std::vector<int64_t>& job_id2total_job_critical_section_id() const {\n    return job_id2total_job_critical_section_id_;\n  }\n\n private:\n  friend class Singleton<CriticalSectionDesc>;\n  CriticalSectionDesc() : inited_(false) {}\n  void UpdateJobId2CriticalSectionIds();\n  void UpdateJobId2TotalJobCriticalSectionId();\n  void UpdateCriticalSectionIds2IntersectingIds();\n\n  bool inited_;\n  std::vector<std::unique_ptr<CriticalSection>> critical_sections_;\n  std::vector<std::vector<int64_t>> job_id2critical_section_ids_;\n  std::vector<int64_t> job_id2total_job_critical_section_id_;\n  std::vector<HashSet<int64_t>> critical_section_id2intersecting_ids_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_CRITICAL_SECTION_DESC_H_\n"
  },
  {
    "path": "oneflow/core/job/critical_section_instance.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_CRITICAL_SECTION_INSTANCE_H_\n#define ONEFLOW_CORE_JOB_CRITICAL_SECTION_INSTANCE_H_\n\n#include <string>\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nclass Blob;\n\nnamespace ep {\nclass Stream;\n}\n\nclass CriticalSectionInstance {\n public:\n  CriticalSectionInstance() = default;\n\n  virtual const std::string& job_name() const = 0;\n\n  virtual ~CriticalSectionInstance() = default;\n\n  virtual void AccessBlobByOpName(ep::Stream* stream, Blob* blob,\n                                  const std::string& op_name) const {\n    UNIMPLEMENTED();\n  }\n  virtual void Finish() const { UNIMPLEMENTED(); }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_CRITICAL_SECTION_INSTANCE_H_\n"
  },
  {
    "path": "oneflow/core/job/distribute_hirarchy.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/job/sbp_parallel.proto\";\n\nenum DistributeType {\n  kInvalidDistributeType = 0;\n  kSpaceDistribute = 2;\n  kTimeDistribute = 3;\n}\n\nmessage DistributeDim {\n  required DistributeType distribute_type = 1;\n  required SbpParallel sbp_parallel = 2;\n  required int64 distribute_num = 3;\n}\n\nmessage DistributeHirarchy {\n  repeated DistributeDim dim = 1;\n}"
  },
  {
    "path": "oneflow/core/job/dlnet_conf.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/operator/op_conf.proto\";\n\nmessage DLNetConf {\n  repeated OperatorConf op = 1;\n}\n"
  },
  {
    "path": "oneflow/core/job/eager_ccl_comm_manager.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/job/eager_ccl_comm_manager.h\"\n\nnamespace oneflow {\n\nconst std::string EagerCclCommMgr::kDefaultCclStreamName = \"DEFAULT\";\n\nEagerCclCommMgrBuilder& EagerCclCommMgrBuilder::Get() {\n  static EagerCclCommMgrBuilder mgr;\n  return mgr;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/eager_ccl_comm_manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_EAGER_CCL_COMM_MANAGER_H_\n#define ONEFLOW_CORE_JOB_EAGER_CCL_COMM_MANAGER_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/user/kernels/collective_communication/include/collective_communication.h\"\n\nnamespace oneflow {\n\nclass EagerCclCommMgr {\n public:\n  static const std::string kDefaultCclStreamName;\n  OF_DISALLOW_COPY_AND_MOVE(EagerCclCommMgr);\n  virtual ~EagerCclCommMgr() = default;\n\n  virtual void CreateCommFromPlan(const Plan& plan) = 0;\n  virtual bool IsAsyncLaunchCclLogicalKernel() const = 0;\n  virtual void SetAsyncLaunchCclLogicalKernel(bool val) = 0;\n  virtual ccl::CclComm GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) = 0;\n  virtual ccl::CclComm GetCclCommForParallelDescAndStreamName(const ParallelDesc& parallel_desc,\n                                                              const std::string& stream_name) = 0;\n  virtual ccl::CclComm GetCclCommForParallelDescNdHierarchy(const ParallelDesc& parallel_desc,\n                                                            const std::string& stream_name,\n                                                            const int64_t this_parallel_id,\n                                                            const std::string& comm_key) = 0;\n\n  template<typename T>\n  T* As() {\n    return dynamic_cast<T*>(this);\n  }\n\n protected:\n  EagerCclCommMgr() = default;\n};\n\nclass EagerCclCommMgrBuilder {\n public:\n  using Creator = std::function<EagerCclCommMgr*()>;\n\n  EagerCclCommMgrBuilder(EagerCclCommMgrBuilder const&) = delete;\n  EagerCclCommMgrBuilder& operator=(EagerCclCommMgrBuilder const&) = delete;\n  static EagerCclCommMgrBuilder& Get();\n\n  template<typename Derived>\n  void RegisterEagerCclCommMgrType(DeviceType device_type) {\n    ccl_comm_mgr_reg_result_->emplace(device_type,\n                                      []() -> EagerCclCommMgr* { return new Derived; });\n    vaild_ccl_comm_mgr_device_types_.emplace_back(device_type);\n  }\n\n  EagerCclCommMgr* NewCclCommMgr(DeviceType device_type) const {\n    const auto& it = ccl_comm_mgr_reg_result_->find(device_type);\n    CHECK(it != ccl_comm_mgr_reg_result_->end());\n    return it->second();\n  }\n\n  const std::vector<DeviceType>& vaild_ccl_comm_mgr_device_types() const {\n    return vaild_ccl_comm_mgr_device_types_;\n  }\n\n private:\n  EagerCclCommMgrBuilder() { ccl_comm_mgr_reg_result_.reset(new std::map<DeviceType, Creator>); }\n\n  std::unique_ptr<std::map<DeviceType, Creator>> ccl_comm_mgr_reg_result_;\n  std::vector<DeviceType> vaild_ccl_comm_mgr_device_types_;\n};\n\n#define REGISTER_CCL_COMM_MGR(device, Derived) \\\n  COMMAND(EagerCclCommMgrBuilder::Get().RegisterEagerCclCommMgrType<Derived>(device))\n\nclass UserKernelUnifiedCclCommInitRegistry final {\n public:\n  struct Trigger {\n    explicit Trigger(const std::string& key) {\n      
UserKernelUnifiedCclCommInitRegistry::Instance().Register(key);\n    }\n  };\n\n  static UserKernelUnifiedCclCommInitRegistry& Instance() {\n    static UserKernelUnifiedCclCommInitRegistry reg;\n    return reg;\n  }\n\n  OF_DISALLOW_COPY_AND_MOVE(UserKernelUnifiedCclCommInitRegistry);\n  ~UserKernelUnifiedCclCommInitRegistry() = default;\n\n  void Register(const std::string& key) {\n    bool insert_success = reg_set_.insert(key).second;\n    if (!insert_success) {\n      std::cerr << key << \" was already registered in CclCommRegistry\" << std::endl;\n      abort();\n    }\n  }\n\n  bool IsRegistered(const std::string& key) const { return reg_set_.find(key) != reg_set_.end(); }\n\n private:\n  UserKernelUnifiedCclCommInitRegistry() = default;\n  std::set<std::string> reg_set_;\n};\n\nstatic const std::string kSystemCclOpPrefix = \"sys_op_\";\n\n#define REGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT(op_type_name) \\\n  static auto OF_PP_CAT(g_nccl_comm_reg_, __COUNTER__) =         \\\n      ::oneflow::UserKernelUnifiedCclCommInitRegistry::Trigger(op_type_name)\n\n#define REGISTER_SYSTEM_OP_KERNEL_UNIFIED_CCL_COMM_INIT(op_type_case)                        \\\n  static auto OF_PP_CAT(g_nccl_comm_reg_, __COUNTER__) =                                     \\\n      ::oneflow::UserKernelUnifiedCclCommInitRegistry::Trigger(::oneflow::kSystemCclOpPrefix \\\n                                                               + std::to_string(op_type_case))\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_EAGER_CCL_COMM_MANAGER_H_\n"
  },
  {
    "path": "oneflow/core/job/eager_nccl_comm_manager.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <iomanip>\n#include <string>\n#include \"oneflow/core/control/ctrl_client.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/eager_nccl_comm_manager.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/job/id_manager.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n\n#ifdef WITH_CUDA\n\nnamespace oneflow {\n\nnamespace {\n\nstd::string GetNcclUniqueIdRpcKey(const std::vector<std::pair<int64_t, int64_t>>& sorted_devices) {\n  std::ostringstream oss;\n  oss << \"eager_nccl_unique_id_rpc_key\";\n  for (const auto& pair : sorted_devices) { oss << \",\" << pair.first << \":\" << pair.second; }\n  return oss.str();\n}\n\nstd::string NcclUniqueId2String(const ncclUniqueId& id) {\n  std::stringstream ss;\n  for (int i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) {\n    ss << std::hex << std::setfill('0') << std::setw(2) << static_cast<int>(id.internal[i]);\n  }\n  return ss.str();\n}\n\nbool CompareDeviceSetPair(const std::pair<int64_t, int64_t>& a,\n                          const std::pair<int64_t, int64_t>& b) {\n  if (a.first == b.first) {\n    return a.second < b.second;\n  } else {\n    return a.first < b.first;\n  }\n}\n\nvoid CreateNcclComm(ncclComm_t* comm, const int dev, const std::string& key,\n                    const std::vector<std::pair<int64_t, int64_t>>& device_vec) {\n  ncclUniqueId nccl_unique_id{};\n  int64_t machine = GlobalProcessCtx::Rank();\n  std::pair<int64_t, int64_t> this_device(machine, dev);\n  auto it = std::find(device_vec.cbegin(), device_vec.cend(), this_device);\n  CHECK(it != device_vec.end());\n  int rank = std::distance(device_vec.cbegin(), it);\n  if (rank == 0) {\n    OF_NCCL_CHECK(ncclGetUniqueId(&nccl_unique_id));\n    Singleton<CtrlClient>::Get()->PushKV(\n        key, std::string(nccl_unique_id.internal, NCCL_UNIQUE_ID_BYTES));\n  } else {\n    Singleton<CtrlClient>::Get()->PullKV(key, [&nccl_unique_id](const std::string& val) {\n      memcpy(nccl_unique_id.internal, val.data(), NCCL_UNIQUE_ID_BYTES);\n    });\n  }\n  VLOG(2) << \" EagerNcclCommMgr::ncclCommInitRank device_vec.size() = \" << device_vec.size()\n          << \", nccl_unique_id = \" << NcclUniqueId2String(nccl_unique_id) << \", rank = \" << rank\n          << \", key = {\" << key << \"}\\n\";\n  OF_NCCL_CHECK(ncclCommInitRank(comm, device_vec.size(), nccl_unique_id, rank));\n  VLOG(2) << \" EagerNcclCommMgr::ncclCommInitRank succeed device_vec.size() = \" << device_vec.size()\n          << \", nccl_unique_id = \" << NcclUniqueId2String(nccl_unique_id) << \", rank = \" << rank\n          << \", key = {\" << key << \"}\\n\";\n}\n\nbool NeedUnifiedNcclCommInit(const OperatorConf& op_conf) {\n  if (op_conf.has_user_conf()) {\n    return UserKernelUnifiedNcclCommInitRegistry::Instance().IsRegistered(\n        
op_conf.user_conf().op_type_name());\n  } else {\n    // Please check the .h file for hard-coding of the name\n    return UserKernelUnifiedNcclCommInitRegistry::Instance().IsRegistered(\n        kSystemOpPrefix + std::to_string(op_conf.op_type_case()));\n  }\n}\n\n}  // namespace\n\nconst std::string EagerNcclCommMgr::kDefaultStreamName = \"DEFAULT\";\n\nEagerNcclCommMgr::~EagerNcclCommMgr() {\n  for (auto& device_set7device_id2comm : device_set2device_id2comm_) {\n    for (auto& device_id7comm : device_set7device_id2comm.second) {\n      OF_NCCL_CHECK(ncclCommDestroy(device_id7comm.second));\n    }\n  }\n  for (auto& pair : device7stream2device_id2comm_) {\n    for (auto& device_id7comm : pair.second) {\n      OF_NCCL_CHECK(ncclCommDestroy(device_id7comm.second));\n    }\n  }\n}\n\nncclComm_t EagerNcclCommMgr::GetCommForDevice(\n    const std::set<std::pair<int64_t, int64_t>>& device_set) {\n  int dev;\n  OF_CUDA_CHECK(cudaGetDevice(&dev));\n  {\n    std::lock_guard<std::mutex> lock(mutex_);\n    auto it = device_set2device_id2comm_.find(device_set);\n    if (it != device_set2device_id2comm_.end()) { return it->second.at(dev); }\n  }\n  std::vector<std::pair<int64_t, int64_t>> device_vec(device_set.cbegin(), device_set.cend());\n  std::sort(device_vec.begin(), device_vec.end(), CompareDeviceSetPair);\n\n  ncclComm_t comm;\n  std::string nccl_unique_id_rpc_key = GetNcclUniqueIdRpcKey(device_vec);\n  CreateNcclComm(&comm, dev, nccl_unique_id_rpc_key, device_vec);\n\n  {\n    std::lock_guard<std::mutex> lock(mutex_);\n    device_set2device_id2comm_[device_set][dev] = comm;\n  }\n  return comm;\n}\n\nncclComm_t EagerNcclCommMgr::GetCommForDeviceAndStreamName(\n    const std::set<std::pair<int64_t, int64_t>>& device_set, const std::string& stream_name) {\n  int dev;\n  OF_CUDA_CHECK(cudaGetDevice(&dev));\n\n  std::vector<std::pair<int64_t, int64_t>> device_vec(device_set.cbegin(), device_set.cend());\n  std::sort(device_vec.begin(), device_vec.end(), CompareDeviceSetPair);\n  std::string key = GetNcclUniqueIdRpcKey(device_vec) + \"-stream_name_hint:\" + stream_name;\n\n  {\n    std::lock_guard<std::mutex> lock(mutex_);\n    auto it = device7stream2device_id2comm_.find(key);\n    if (it != device7stream2device_id2comm_.end()) { return it->second.at(dev); }\n  }\n\n  ncclComm_t comm;\n  CreateNcclComm(&comm, dev, key, device_vec);\n\n  {\n    std::lock_guard<std::mutex> lock(mutex_);\n    device7stream2device_id2comm_[key][dev] = comm;\n  }\n  return comm;\n}\n\nccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) {\n  std::set<std::pair<int64_t, int64_t>> device_set;\n  FOR_RANGE(int64_t, parallel_id, 0, parallel_desc.parallel_num()) {\n    int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));\n    int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));\n    device_set.emplace(std::make_pair(machine_id, device_id));\n  }\n\n  ncclComm_t comm = GetCommForDevice(device_set);\n  std::shared_ptr<ccl::CommBase> ncclCommAdapter = std::make_shared<ccl::NcclCommAdapter>(comm);\n  ccl::CclComm ccl_comm(ncclCommAdapter);\n  return ccl_comm;\n}\n\nccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDescAndStreamName(\n    const ParallelDesc& parallel_desc, const std::string& stream_name) {\n  std::set<std::pair<int64_t, int64_t>> device_set;\n  FOR_RANGE(int64_t, parallel_id, 0, parallel_desc.parallel_num()) {\n    int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));\n    int64_t 
device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));\n    device_set.emplace(std::make_pair(machine_id, device_id));\n  }\n\n  ncclComm_t comm = GetCommForDeviceAndStreamName(device_set, stream_name);\n  std::shared_ptr<ccl::CommBase> ncclCommAdapter = std::make_shared<ccl::NcclCommAdapter>(comm);\n  ccl::CclComm ccl_comm(ncclCommAdapter);\n  return ccl_comm;\n}\n\nccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDescNdHierarchy(\n    const ParallelDesc& parallel_desc, const std::string& stream_name,\n    const int64_t this_parallel_id, const std::string& comm_key) {\n  std::set<std::pair<int64_t, int64_t>> device_set;\n  const Shape& hierarchy = *parallel_desc.hierarchy();\n  CHECK_LE(hierarchy.NumAxes(), 2);\n\n  // 1D\n  if (hierarchy.NumAxes() == 1) {\n    // 1D hierarchy\n    for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) {\n      int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));\n      int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));\n      device_set.emplace(std::make_pair(machine_id, device_id));\n    }\n  } else if (hierarchy.NumAxes() == 2) {\n    // 2D hierarchy\n    CHECK(comm_key == \"SameDim0\" || comm_key == \"SameDim1\");\n    if (comm_key == \"SameDim0\") {\n      const int64_t num_groups = hierarchy.At(0);\n      const int64_t group_size = hierarchy.At(1);\n      CHECK_EQ(num_groups * group_size, parallel_desc.parallel_num());\n      const int64_t this_group_begin_parallel_id = this_parallel_id / group_size * group_size;\n      CHECK_EQ(this_group_begin_parallel_id % group_size, 0);\n      CHECK_LE(this_group_begin_parallel_id + group_size, parallel_desc.parallel_num());\n      for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) {\n        const int64_t parallel_id = this_group_begin_parallel_id + id_in_group;\n        const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));\n        const int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));\n        device_set.emplace(std::make_pair(machine_id, device_id));\n      }\n    } else if (comm_key == \"SameDim1\") {\n      const int64_t group_size = hierarchy.At(0);\n      const int64_t num_groups = hierarchy.At(1);\n      CHECK_EQ(num_groups * group_size, parallel_desc.parallel_num());\n      const int64_t this_group_begin_parallel_id = this_parallel_id % num_groups;\n      CHECK_LT(this_group_begin_parallel_id + (group_size - 1) * num_groups,\n               parallel_desc.parallel_num());\n      for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) {\n        const int64_t parallel_id = this_group_begin_parallel_id + (id_in_group * num_groups);\n        const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));\n        const int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));\n        device_set.emplace(std::make_pair(machine_id, device_id));\n      }\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n\n  ncclComm_t comm = GetCommForDeviceAndStreamName(device_set, stream_name);\n  std::shared_ptr<ccl::CommBase> ncclCommAdapter = std::make_shared<ccl::NcclCommAdapter>(comm);\n  ccl::CclComm ccl_comm(ncclCommAdapter);\n  return ccl_comm;\n}\n\nvoid EagerNcclCommMgr::CreateCommFromPlan(const Plan& plan) {\n  const int64_t rank = GlobalProcessCtx::Rank();\n  const int64_t dev = GlobalProcessCtx::LocalRank();\n  std::map<std::string, 
std::vector<std::pair<int64_t, int64_t>>> nccl_comm_key2devices;\n\n  for (const auto& task_proto : plan.task()) {\n    if (task_proto.machine_id() != rank) { continue; }\n    if (task_proto.exec_sequence().exec_node_size() != 1) { continue; }\n    const auto& kernel_conf = task_proto.exec_sequence().exec_node(0).kernel_conf();\n    const OpAttribute* op_attr = nullptr;\n    if (kernel_conf.has_op_attribute()) {\n      op_attr = &kernel_conf.op_attribute();\n    } else if (kernel_conf.has_op_attribute_ref()) {\n      const auto& ref_name = kernel_conf.op_attribute_ref();\n      op_attr = &plan.job_id2op_attribute_ref_table()\n                     .at(task_proto.job_id())\n                     .op_name2op_attribute()\n                     .at(ref_name);\n    } else {\n      continue;\n    }\n    const auto& op_conf = op_attr->op_conf();\n    if (!NeedUnifiedNcclCommInit(op_conf)) { continue; }\n    if (!op_attr->has_parallel_conf_signature()) { continue; }\n    if (!op_attr->parallel_conf_signature().has_op_parallel_conf()) { continue; }\n\n    std::vector<std::pair<int64_t, int64_t>> device_vec;\n    ParallelDesc parallel_desc(op_attr->parallel_conf_signature().op_parallel_conf());\n    for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) {\n      int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));\n      int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));\n      device_vec.emplace_back(machine_id, device_id);\n    }\n\n    std::string stream_name = kDefaultStreamName;\n    if (op_conf.has_stream_name_hint()) { stream_name = op_conf.stream_name_hint(); }\n    std::string key = GetNcclUniqueIdRpcKey(device_vec) + \"-stream_name_hint:\" + stream_name;\n\n    VLOG(3) << \" EagerNcclCommMgr create nccl comm for \" << op_conf.name() << \", rank = \" << rank\n            << \", dev = \" << dev << \", key = {\" << key << \"}\\n\";\n    nccl_comm_key2devices.emplace(std::move(key), std::move(device_vec));\n  }\n\n  if (nccl_comm_key2devices.size() == 0) { return; }\n\n  CHECK_JUST(vm::CurrentRankSync());\n  CudaCurrentDeviceGuard guard(dev);\n\n  for (const auto& pair : nccl_comm_key2devices) {\n    const auto& key = pair.first;\n    auto device_id2comm_it = device7stream2device_id2comm_.find(key);\n    if (device_id2comm_it != device7stream2device_id2comm_.end()) {\n      auto comm_it = device_id2comm_it->second.find(dev);\n      if (comm_it != device_id2comm_it->second.end()) { continue; }\n    }\n    ncclComm_t comm;\n    CreateNcclComm(&comm, dev, key, pair.second);\n    device7stream2device_id2comm_[key][dev] = comm;\n  }\n}\n\nREGISTER_CCL_COMM_MGR(DeviceType::kCUDA, EagerNcclCommMgr);\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/core/job/eager_nccl_comm_manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_EAGER_NCCL_COMM_MANAGER_H_\n#define ONEFLOW_CORE_JOB_EAGER_NCCL_COMM_MANAGER_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/core/job/eager_ccl_comm_manager.h\"\n\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/device/cuda_util.h\"\n\nnamespace oneflow {\nnamespace ccl {\n\nclass NcclCommAdapter : public CommBase {\n public:\n  explicit NcclCommAdapter(ncclComm_t comm) : comm_(comm) {}\n\n  void* getComm() const override { return const_cast<void*>(static_cast<const void*>(&comm_)); }\n\n private:\n  ncclComm_t comm_;\n};\n\n}  // namespace ccl\n\nclass EagerNcclCommMgr final : public EagerCclCommMgr {\n public:\n  static const std::string kDefaultStreamName;\n\n  OF_DISALLOW_COPY_AND_MOVE(EagerNcclCommMgr);\n  ~EagerNcclCommMgr() override;\n\n  ncclComm_t GetCommForDevice(const std::set<std::pair<int64_t, int64_t>>& device_set);\n  ncclComm_t GetCommForDeviceAndStreamName(const std::set<std::pair<int64_t, int64_t>>& device_set,\n                                           const std::string& stream_name);\n  ccl::CclComm GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) override;\n  ccl::CclComm GetCclCommForParallelDescAndStreamName(const ParallelDesc& parallel_desc,\n                                                      const std::string& stream_name) override;\n  ccl::CclComm GetCclCommForParallelDescNdHierarchy(const ParallelDesc& parallel_desc,\n                                                    const std::string& stream_name,\n                                                    const int64_t this_parallel_id,\n                                                    const std::string& comm_key) override;\n\n  void CreateCommFromPlan(const Plan& plan) override;\n  bool IsAsyncLaunchCclLogicalKernel() const override { return async_launch_nccl_logical_kernel_; }\n  void SetAsyncLaunchCclLogicalKernel(bool val) override {\n    async_launch_nccl_logical_kernel_ = val;\n  }\n\n private:\n  friend class EagerCclCommMgrBuilder;\n  // NOTE(chengcheng): default async launch nccl logical kernel is true for better performence.\n  EagerNcclCommMgr() : EagerCclCommMgr(), async_launch_nccl_logical_kernel_(true) {}\n\n  std::map<std::set<std::pair<int64_t, int64_t>>, HashMap<int64_t, ncclComm_t>>\n      device_set2device_id2comm_;\n  std::map<std::string, HashMap<int64_t, ncclComm_t>> device7stream2device_id2comm_;\n  std::mutex mutex_;\n  bool async_launch_nccl_logical_kernel_;\n};\n\nclass UserKernelUnifiedNcclCommInitRegistry final {\n public:\n  struct Trigger {\n    explicit Trigger(const std::string& key) {\n      UserKernelUnifiedNcclCommInitRegistry::Instance().Register(key);\n    }\n  };\n\n  static UserKernelUnifiedNcclCommInitRegistry& Instance() {\n    static UserKernelUnifiedNcclCommInitRegistry reg;\n    return reg;\n  }\n\n  
OF_DISALLOW_COPY_AND_MOVE(UserKernelUnifiedNcclCommInitRegistry);\n  ~UserKernelUnifiedNcclCommInitRegistry() = default;\n\n  void Register(const std::string& key) {\n    bool insert_success = reg_set_.insert(key).second;\n    if (!insert_success) {\n      std::cerr << key << \" was already registered in NcclCommRegistry\" << std::endl;\n      abort();\n    }\n  }\n\n  bool IsRegistered(const std::string& key) const { return reg_set_.find(key) != reg_set_.end(); }\n\n private:\n  UserKernelUnifiedNcclCommInitRegistry() = default;\n  std::set<std::string> reg_set_;\n};\n\nstatic const std::string kSystemOpPrefix = \"sys_op_\";\n\n}  // namespace oneflow\n\n#define REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT(op_type_name) \\\n  static auto OF_PP_CAT(g_nccl_comm_reg_, __COUNTER__) =          \\\n      ::oneflow::UserKernelUnifiedNcclCommInitRegistry::Trigger(op_type_name)\n\n#define REGISTER_SYSTEM_OP_KERNEL_UNIFIED_NCCL_COMM_INIT(op_type_case)                     \\\n  static auto OF_PP_CAT(g_nccl_comm_reg_, __COUNTER__) =                                   \\\n      ::oneflow::UserKernelUnifiedNcclCommInitRegistry::Trigger(::oneflow::kSystemOpPrefix \\\n                                                                + std::to_string(op_type_case))\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_CORE_JOB_EAGER_NCCL_COMM_MANAGER_H_\n"
  },
  {
    "path": "oneflow/core/job/env.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/control/ctrl_bootstrap.proto\";\n\nmessage Machine {\n  required int64 id = 1;\n  required string addr = 2; // domain name or ip\n  optional int32 ctrl_port_agent = 3 [default = -1];\n  optional int32 data_port_agent = 4 [default = -1];\n}\n\nmessage CppLoggingConf {\n  optional string log_dir = 1 [default = \"./log\"];\n  optional int32 logtostderr = 2 [default = 1];\n  optional int32 logbuflevel = 3 [default = -1];\n  optional int32 minloglevel = 4 [default = 1];\n}\n\nmessage EnvProto {\n  repeated Machine machine = 1;\n  required int32 ctrl_port = 2;\n  optional int32 data_port = 3 [default = -1];\n  optional CppLoggingConf cpp_logging_conf = 4;\n  optional BootstrapConf ctrl_bootstrap_conf = 5;\n}\n"
  },
  {
    "path": "oneflow/core/job/env_desc.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/env_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n\nnamespace oneflow {\n\nconst BootstrapConf& EnvDesc::bootstrap_conf() const {\n  CHECK(has_ctrl_bootstrap_conf());\n  return env_proto_.ctrl_bootstrap_conf();\n}\n\nint32_t EnvDesc::bootstrap_conf_ctrl_port() const {\n  CHECK(has_bootstrap_conf_ctrl_port());\n  return env_proto_.ctrl_bootstrap_conf().ctrl_port();\n}\n\nsize_t EnvDesc::TotalMachineNum() const {\n  if (env_proto_.has_ctrl_bootstrap_conf()) {\n    return env_proto_.ctrl_bootstrap_conf().world_size();\n  } else {\n    return env_proto_.machine().size();\n  }\n}\n\nint64_t EnvDesc::GetMachineId(const std::string& addr) const {\n  int64_t machine_id = -1;\n  int64_t machine_num = env_proto_.machine_size();\n  FOR_RANGE(int64_t, i, 0, machine_num) {\n    if (addr == env_proto_.machine(i).addr()) {\n      machine_id = i;\n      break;\n    }\n  }\n  CHECK_GE(machine_id, 0);\n  CHECK_LT(machine_id, machine_num);\n  return machine_id;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/env_desc.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_CLUSTER_DESC_H_\n#define ONEFLOW_CORE_JOB_CLUSTER_DESC_H_\n\n#include \"oneflow/core/job/env.pb.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nclass EnvDesc final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(EnvDesc);\n  explicit EnvDesc(const EnvProto& env_proto) : env_proto_(env_proto) {}\n  ~EnvDesc() = default;\n\n  const EnvProto& env_proto() const { return env_proto_; }\n  const Machine& machine(int32_t idx) const { return env_proto_.machine(idx); }\n  int32_t ctrl_port() const { return env_proto_.ctrl_port(); }\n  int32_t data_port() const { return env_proto_.data_port(); }\n  bool has_ctrl_bootstrap_conf() const { return env_proto_.has_ctrl_bootstrap_conf(); }\n  bool has_bootstrap_conf_ctrl_port() const {\n    return has_ctrl_bootstrap_conf() && env_proto_.ctrl_bootstrap_conf().has_ctrl_port();\n  }\n  const BootstrapConf& bootstrap_conf() const;\n  int32_t bootstrap_conf_ctrl_port() const;\n  size_t TotalMachineNum() const;\n  int64_t GetMachineId(const std::string& addr) const;\n\n private:\n  EnvProto env_proto_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_CLUSTER_DESC_H_\n"
  },
  {
    "path": "oneflow/core/job/env_global_objects_scope.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/remat/allocator.h\"\n#ifdef WITH_CUDA\n#include <cuda.h>\n#endif  // WITH_CUDA\n#include <thread>\n#include \"oneflow/core/thread/thread_pool.h\"\n#include \"oneflow/core/job/env_global_objects_scope.h\"\n#include \"oneflow/core/control/ctrl_server.h\"\n#include \"oneflow/core/control/ctrl_bootstrap.h\"\n#include \"oneflow/core/control/ctrl_client.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n#include \"oneflow/core/persistence/file_system.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/vm/virtual_machine_scope.h\"\n#include \"oneflow/core/vm/remat/util.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx_mgr.h\"\n#include \"oneflow/core/job/eager_ccl_comm_manager.h\"\n#include \"oneflow/core/device/cudnn_conv_util.h\"\n#include \"oneflow/core/rpc/include/manager.h\"\n#include \"oneflow/core/transport/transport.h\"\n#include \"oneflow/core/hardware/node_device_descriptor_manager.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/framework/multi_client_session_context.h\"\n#include \"oneflow/core/framework/scope_util.h\"\n#include \"oneflow/core/operator/op_node_signature.pb.h\"\n#include \"oneflow/core/comm_network/comm_network.h\"\n#include \"oneflow/core/comm_network/epoll/epoll_comm_network.h\"\n#include \"oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h\"\n#include \"oneflow/core/kernel/chain_kernel_observer.h\"\n#include \"oneflow/core/kernel/sync_check_kernel_observer.h\"\n#include \"oneflow/core/kernel/blob_access_checker_kernel_observer.h\"\n#include \"oneflow/core/kernel/profiler_kernel_observer.h\"\n#include \"oneflow/core/embedding/embedding_manager.h\"\n#include \"oneflow/core/vm/remat/env.h\"\n#ifdef WITH_RDMA\n#include \"oneflow/core/platform/include/ibv.h\"\n#include \"oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h\"\n#endif  // WITH_RDMA\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/ep/cpu/cpu_device_manager.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::string LogDir(const std::string& log_dir) {\n  char hostname[255];\n  CHECK_EQ(gethostname(hostname, sizeof(hostname)), 0);\n  std::string v = JoinPath(log_dir, std::string(hostname));\n  return v;\n}\n\nvoid InitLogging(const CppLoggingConf& logging_conf) {\n  FLAGS_log_dir = LogDir(logging_conf.log_dir());\n  FLAGS_logtostderr = logging_conf.logtostderr();\n  FLAGS_logbuflevel = logging_conf.logbuflevel();\n  FLAGS_minloglevel = logging_conf.minloglevel();\n  FLAGS_stderrthreshold = 1;  // 1=WARNING\n  google::InitGoogleLogging(\"oneflow\");\n  if (IsInDebugMode()) {\n    // record all level logs to file in debug mode\n 
   FLAGS_logtostderr = 0;\n    FLAGS_minloglevel = 0;  // 0=INFO\n  }\n  if (!FLAGS_logtostderr) { LocalFS()->RecursivelyCreateDirIfNotExist(FLAGS_log_dir); }\n}\n\nint32_t GetDefaultCpuDeviceNum() { return std::thread::hardware_concurrency(); }\n\nResource GetDefaultResource(const EnvProto& env_proto) {\n  Resource resource;\n  if (env_proto.has_ctrl_bootstrap_conf()) {\n    resource.set_machine_num(GlobalProcessCtx::NodeSize());\n  } else {\n    resource.set_machine_num(env_proto.machine_size());\n  }\n  resource.set_cpu_device_num(GetDefaultCpuDeviceNum());\n  return resource;\n}\n\nvoid SetCpuDeviceManagerNumThreads() {\n  ep::CpuDeviceManager* cpu_device_manager = dynamic_cast<ep::CpuDeviceManager*>(\n      Singleton<ep::DeviceManagerRegistry>::Get()->GetDeviceManager(DeviceType::kCPU));\n  constexpr size_t kDefaultUsedNumThreads = 2;\n  int64_t cpu_logic_core = std::thread::hardware_concurrency();\n  int64_t default_num_threads =\n      (cpu_logic_core / GlobalProcessCtx::NumOfProcessPerNode()) - kDefaultUsedNumThreads;\n  int64_t num_threads = ParseIntegerFromEnv(\"OMP_NUM_THREADS\", default_num_threads);\n  cpu_device_manager->SetDeviceNumThreads(num_threads);\n}\n\nvoid ClearAllSymbol() {\n  Singleton<symbol::Storage<Scope>>::Get()->ClearAll();\n  Singleton<symbol::Storage<JobDesc>>::Get()->ClearAll();\n  Singleton<symbol::Storage<ParallelDesc>>::Get()->ClearAll();\n  Singleton<symbol::Storage<OperatorConfSymbol>>::Get()->ClearAll();\n}\n\n#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)\n\nbool CommNetIBEnabled() {\n  if (!ibv::IsAvailable()) { return false; }\n  const auto* node_manager = Singleton<hardware::NodeDeviceDescriptorManager>::Get();\n  if (node_manager == nullptr) { return false; }\n  for (int64_t rank = 0; rank < GlobalProcessCtx::WorldSize(); ++rank) {\n    const auto& node = node_manager->GetNodeDeviceDescriptor(rank);\n    if (!node) { return false; }\n    const auto& list = node->GetDeviceDescriptorList(\"net_ib\");\n    if (!list) { return false; }\n    if (list->DeviceCount() == 0) { return false; }\n  }\n  return true;\n}\n\n#endif  // WITH_RDMA && OF_PLATFORM_POSIX\n\n}  // namespace\n\nEnvGlobalObjectsScope::EnvGlobalObjectsScope(const std::string& env_proto_str) {\n  EnvProto env_proto;\n  CHECK(TxtString2PbMessage(env_proto_str, &env_proto))\n      << \"failed to parse env_proto: \" << env_proto_str;\n  CHECK_JUST(Init(env_proto));\n}\n\nEnvGlobalObjectsScope::EnvGlobalObjectsScope(const EnvProto& env_proto) {\n  CHECK_JUST(Init(env_proto));\n}\n\nMaybe<void> EnvGlobalObjectsScope::Init(const EnvProto& env_proto) {\n  CHECK(Singleton<EnvGlobalObjectsScope>::Get() == nullptr);\n  Singleton<EnvGlobalObjectsScope>::SetAllocated(this);\n\n  InitLogging(env_proto.cpp_logging_conf());\n  Singleton<remat::Env>::New();\n  Singleton<EnvDesc>::New(env_proto);\n  Singleton<ProcessCtx>::New();\n  // Avoid deadlock by using CHECK_JUST instead of JUST, 
because it may be blocked in\n  // ~CtrlBootstrap.\n\n  if ((env_proto.machine_size() == 1 && env_proto.has_ctrl_bootstrap_conf() == false)\n      || (env_proto.has_ctrl_bootstrap_conf()\n          && env_proto.ctrl_bootstrap_conf().world_size() == 1)) /*single process*/ {\n#ifdef RPC_BACKEND_LOCAL\n    LOG(INFO) << \"Using rpc backend: local\";\n    Singleton<RpcManager>::SetAllocated(new LocalRpcManager());\n#else\n    static_assert(false, \"Requires rpc backend local to run oneflow in a single process\");\n#endif  // RPC_BACKEND_LOCAL\n  } else /*multi process, multi machine*/ {\n#ifdef RPC_BACKEND_GRPC\n    LOG(INFO) << \"Using rpc backend: gRPC\";\n    Singleton<RpcManager>::SetAllocated(new GrpcRpcManager());\n#else\n    UNIMPLEMENTED() << \"To run distributed oneflow, you must enable at least one multi-node rpc \"\n                       \"backend by adding a cmake argument, for instance: -DRPC_BACKEND=GRPC\";\n#endif  // RPC_BACKEND_GRPC\n  }\n  CHECK_JUST(Singleton<RpcManager>::Get()->CreateServer());\n  CHECK_JUST(Singleton<RpcManager>::Get()->Bootstrap());\n  CHECK_JUST(Singleton<RpcManager>::Get()->CreateClient());\n  Singleton<ResourceDesc, ForEnv>::New(GetDefaultResource(env_proto),\n                                       GlobalProcessCtx::NumOfProcessPerNode());\n  Singleton<ResourceDesc, ForSession>::New(GetDefaultResource(env_proto),\n                                           GlobalProcessCtx::NumOfProcessPerNode());\n  Singleton<hardware::NodeDeviceDescriptorManager>::SetAllocated(\n      new hardware::NodeDeviceDescriptorManager());\n  if (Singleton<ResourceDesc, ForEnv>::Get()->enable_debug_mode()) {\n    Singleton<hardware::NodeDeviceDescriptorManager>::Get()->DumpSummary(\"devices\");\n  }\n  Singleton<ep::DeviceManagerRegistry>::New();\n  Singleton<remat::AllocatorManager>::New();\n  Singleton<ThreadPool>::New(Singleton<ResourceDesc, ForSession>::Get()->ComputeThreadPoolSize());\n  SetCpuDeviceManagerNumThreads();\n#ifdef WITH_CUDA\n  Singleton<CudnnConvAlgoCache>::New();\n  Singleton<CudnnHandlePool>::New();\n  Singleton<embedding::EmbeddingManager>::New();\n#endif\n  const auto& vaild_ccl_comm_mgr_device_types =\n      EagerCclCommMgrBuilder::Get().vaild_ccl_comm_mgr_device_types();\n  CHECK_LE_OR_RETURN(vaild_ccl_comm_mgr_device_types.size(), 1)\n      << \"Only one kind of collective communication manager is supported at a time for \"\n         \"now!\";\n  if (!vaild_ccl_comm_mgr_device_types.empty() && !Singleton<EagerCclCommMgr>::Get()) {\n    Singleton<EagerCclCommMgr>::SetAllocated(\n        EagerCclCommMgrBuilder::Get().NewCclCommMgr(vaild_ccl_comm_mgr_device_types.front()));\n  }\n\n  Singleton<vm::VirtualMachineScope>::New(Singleton<ResourceDesc, ForSession>::Get()->resource());\n#ifdef __linux__\n  Singleton<EpollCommNet>::New();\n  Singleton<Transport>::New();\n  if (Singleton<ResourceDesc, ForSession>::Get()->process_ranks().size() > 1) {\n    Singleton<CommNet>::SetAllocated(Singleton<EpollCommNet>::Get());\n  }\n#endif  // __linux__\n  {\n    std::vector<std::shared_ptr<KernelObserver>> kernel_observers;\n    if (ParseBooleanFromEnv(\"ONEFLOW_DEBUG_KERNEL_SYNC_CHECK\", false)) {\n      LOG(WARNING)\n          << \"Environment variable ONEFLOW_DEBUG_KERNEL_SYNC_CHECK has been set to a truthy \"\n             \"value; this will impact performance\";\n      kernel_observers.emplace_back(new SyncCheckKernelObserver());\n    }\n    if (!ParseBooleanFromEnv(\"ONEFLOW_KERNEL_DISABLE_BLOB_ACCESS_CHECKER\", true)) {\n      
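// ParseBooleanFromEnv defaults to true here, so the blob access checker is\n      // installed only when ONEFLOW_KERNEL_DISABLE_BLOB_ACCESS_CHECKER is explicitly\n      // set to a falsy value.\n      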
kernel_observers.emplace_back(new BlobAccessCheckerKernelObserver());\n    }\n    kernel_observers.emplace_back(new ProfilerKernelObserver());\n    Singleton<KernelObserver>::SetAllocated(new ChainKernelObserver(kernel_observers));\n  }\n  TensorBufferPool::New();\n  return Maybe<void>::Ok();\n}\n\nEnvGlobalObjectsScope::~EnvGlobalObjectsScope() {\n  VLOG(2) << \"Try to close env global objects scope.\" << std::endl;\n  OF_ENV_BARRIER();\n  if (is_normal_exit_.has_value() && !CHECK_JUST(is_normal_exit_)) { return; }\n  TensorBufferPool::Delete();\n  Singleton<KernelObserver>::Delete();\n#ifdef __linux__\n  if (Singleton<ResourceDesc, ForSession>::Get()->process_ranks().size() > 1) {\n    if (Singleton<EpollCommNet>::Get() != dynamic_cast<EpollCommNet*>(Singleton<CommNet>::Get())) {\n      Singleton<CommNet>::Delete();\n    }\n  }\n  Singleton<Transport>::Delete();\n  Singleton<EpollCommNet>::Delete();\n#endif  // __linux__\n  Singleton<vm::VirtualMachineScope>::Delete();\n#ifdef WITH_CUDA\n  Singleton<embedding::EmbeddingManager>::Delete();\n  Singleton<CudnnConvAlgoCache>::Delete();\n  Singleton<CudnnHandlePool>::Delete();\n#endif\n  if (Singleton<EagerCclCommMgr>::Get() != nullptr) { Singleton<EagerCclCommMgr>::Delete(); }\n  Singleton<ThreadPool>::Delete();\n  Singleton<remat::AllocatorManager>::Delete();\n  Singleton<ep::DeviceManagerRegistry>::Delete();\n  if (Singleton<ResourceDesc, ForSession>::Get() != nullptr) {\n    Singleton<ResourceDesc, ForSession>::Delete();\n  }\n  Singleton<ResourceDesc, ForEnv>::Delete();\n  Singleton<hardware::NodeDeviceDescriptorManager>::Delete();\n  CHECK_NOTNULL(Singleton<CtrlClient>::Get());\n  CHECK_NOTNULL(Singleton<EnvDesc>::Get());\n  Singleton<RpcManager>::Delete();\n  Singleton<ProcessCtx>::Delete();\n  Singleton<EnvDesc>::Delete();\n  Singleton<remat::Env>::Delete();\n  ClearAllSymbol();\n  ClearAllBackwardPassScope();\n  if (Singleton<EnvGlobalObjectsScope>::Get() != nullptr) {\n    Singleton<EnvGlobalObjectsScope>::SetAllocated(nullptr);\n  }\n  VLOG(2) << \"Finish closing env global objects scope.\" << std::endl;\n  google::ShutdownGoogleLogging();\n}\n\nMaybe<void> InitRDMA() {\n#ifdef __linux__\n  if (Singleton<ResourceDesc, ForSession>::Get()->process_ranks().size() > 1) {\n#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)\n    if (CommNetIBEnabled()) {\n      if (Singleton<IBVerbsCommNet>::Get() == nullptr) {\n        Singleton<IBVerbsCommNet>::New();\n        Singleton<CommNet>::SetAllocated(Singleton<IBVerbsCommNet>::Get());\n      } else {\n        LOG(INFO) << \"Skip init RDMA because RDMA is already initialized!\";\n      }\n    } else {\n      LOG(WARNING) << \"Skip init RDMA because RDMA is unavailable!\";\n    }\n#else\n    LOG(WARNING) << \"Skip init RDMA because RDMA is not compiled!\";\n#endif  // WITH_RDMA && OF_PLATFORM_POSIX\n  } else {\n    LOG(INFO) << \"Skip init RDMA because only one process in this group!\";\n  }\n#endif  // __linux__\n  return Maybe<void>::Ok();\n}\n\nMaybe<bool> RDMAIsInitialized() {\n#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)\n  return Singleton<IBVerbsCommNet>::Get() != nullptr;\n#else\n  return false;\n#endif  // WITH_RDMA && OF_PLATFORM_POSIX\n}\n\nMaybe<void> DestoryRDMA() {\n#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)\n  if (JUST(RDMAIsInitialized())) {\n    CHECK_NOTNULL(Singleton<IBVerbsCommNet>::Get());\n    CHECK_NOTNULL(Singleton<CommNet>::Get());\n    Singleton<IBVerbsCommNet>::Delete();\n    if (Singleton<EpollCommNet>::Get()) {\n      
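// Hand the global CommNet back to the epoll implementation now that the RDMA\n      // CommNet has been destroyed.\n      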
Singleton<CommNet>::SetAllocated(Singleton<EpollCommNet>::Get());\n    }\n  }\n#endif  // WITH_RDMA && OF_PLATFORM_POSIX\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/env_global_objects_scope.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_CLUSTER_OBJECTS_SCOPE_H_\n#define ONEFLOW_CORE_JOB_CLUSTER_OBJECTS_SCOPE_H_\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/job/env_desc.h\"\n#include \"oneflow/core/framework/device.h\"\n\nnamespace oneflow {\n\nclass ParallelDesc;\n\nclass EnvGlobalObjectsScope final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(EnvGlobalObjectsScope);\n  explicit EnvGlobalObjectsScope(const std::string& env_proto_str);\n  explicit EnvGlobalObjectsScope(const EnvProto& env_proto);\n  ~EnvGlobalObjectsScope();\n\n  Maybe<void> init_is_normal_exit(bool is_normal_exit) {\n    CHECK_OR_RETURN(!is_normal_exit_.has_value());\n    is_normal_exit_ = is_normal_exit;\n    return Maybe<void>::Ok();\n  }\n\n private:\n  Maybe<void> Init(const EnvProto& env_proto);\n\n private:\n  Optional<bool> is_normal_exit_;\n};\n\nMaybe<void> InitRDMA();\n\nMaybe<bool> RDMAIsInitialized();\n\nMaybe<void> DestoryRDMA();\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_CLUSTER_OBJECTS_SCOPE_H_\n"
  },
  {
    "path": "oneflow/core/job/function_config_def.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/config_def.h\"\n\nnamespace oneflow {}\n"
  },
  {
    "path": "oneflow/core/job/global_for.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/common/error.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nCOMMAND(Singleton<Optional<bool>, MultiClient>::SetAllocated(new Optional<bool>()));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/global_for.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_GLOBAL_FOR_H_\n#define ONEFLOW_CORE_JOB_GLOBAL_FOR_H_\n\n#include \"oneflow/core/common/singleton.h\"\n\nnamespace oneflow {\n\nclass ForSession {};\nclass ForEnv {};\n\nclass MultiClient {};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_GLOBAL_FOR_H_\n"
  },
  {
    "path": "oneflow/core/job/global_mode.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/job/global_mode.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/framework/device.h\"\n\nnamespace oneflow {\n\nSymbol<ParallelDesc> GetGlobalParallelDescFromDevice(const Optional<Symbol<Device>>& device) {\n  auto parallel_desc = GlobalMode::parallel_desc();\n  if (device.has_value()) {\n    const auto& device_type = device.value_or(Symbol<Device>())->type();\n    if (parallel_desc->parallel_conf().device_tag() != device_type) {\n      ParallelConf parallel_conf = parallel_desc->parallel_conf();\n      parallel_conf.set_device_tag(device_type);\n      parallel_desc = SymbolOf(ParallelDesc(parallel_conf));\n    }\n  }\n  return parallel_desc;\n}\n\n/* static */ bool* GlobalMode::get_mode_ptr() {\n  thread_local bool mode = false;\n  return &mode;\n}\n/* static */ bool GlobalMode::is_enabled() { return *get_mode_ptr(); }\n/* static */ void GlobalMode::set_enabled(bool enabled) { *get_mode_ptr() = enabled; }\n\n/* static */ Symbol<NdSbp>* GlobalMode::get_nd_sbp_ptr() {\n  thread_local Symbol<NdSbp> nd_sbp;\n  return &nd_sbp;\n}\n/* static */ Symbol<NdSbp> GlobalMode::nd_sbp() { return *get_nd_sbp_ptr(); }\n/* static */ void GlobalMode::set_nd_sbp(Symbol<NdSbp> nd_sbp) { *get_nd_sbp_ptr() = nd_sbp; }\n\n/* static */ Symbol<ParallelDesc>* GlobalMode::get_parallel_desc_ptr() {\n  thread_local Symbol<ParallelDesc> parallel_desc;\n  return &parallel_desc;\n}\n/* static */ Symbol<ParallelDesc> GlobalMode::parallel_desc() { return *get_parallel_desc_ptr(); }\n/* static */ void GlobalMode::set_parallel_desc(Symbol<ParallelDesc> parallel_desc) {\n  *get_parallel_desc_ptr() = parallel_desc;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/global_mode.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_JOB_GLOBAL_MODE_H_\n#define ONEFLOW_CORE_JOB_GLOBAL_MODE_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/sbp_parallel.pb.h\"\n\nnamespace oneflow {\n\nSymbol<ParallelDesc> GetGlobalParallelDescFromDevice(const Optional<Symbol<Device>>& device);\nclass GlobalMode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(GlobalMode);\n  GlobalMode() = default;\n  ~GlobalMode() = default;\n\n  static bool is_enabled();\n  static Symbol<NdSbp> nd_sbp();\n  static Symbol<ParallelDesc> parallel_desc();\n\n  class Guard {\n   public:\n    explicit Guard(bool enabled)\n        : prev_mode_(GlobalMode::is_enabled()),\n          prev_nd_sbp_(GlobalMode::nd_sbp()),\n          prev_parallel_desc_(GlobalMode::parallel_desc()) {\n      CHECK(!enabled);\n      GlobalMode::set_enabled(enabled);\n    }\n    explicit Guard(bool enabled, Symbol<NdSbp> nd_sbp, Symbol<ParallelDesc> parallel_desc)\n        : prev_mode_(GlobalMode::is_enabled()),\n          prev_nd_sbp_(GlobalMode::nd_sbp()),\n          prev_parallel_desc_(GlobalMode::parallel_desc()) {\n      GlobalMode::set_enabled(enabled);\n      if (enabled) {\n        GlobalMode::set_nd_sbp(nd_sbp);\n        GlobalMode::set_parallel_desc(parallel_desc);\n      }\n    }\n    ~Guard() {\n      GlobalMode::set_enabled(prev_mode_);\n      GlobalMode::set_nd_sbp(prev_nd_sbp_);\n      GlobalMode::set_parallel_desc(prev_parallel_desc_);\n    }\n\n   private:\n    bool prev_mode_;\n    Symbol<NdSbp> prev_nd_sbp_;\n    Symbol<ParallelDesc> prev_parallel_desc_;\n  };\n\n private:\n  static bool* get_mode_ptr();\n  static Symbol<NdSbp>* get_nd_sbp_ptr();\n  static Symbol<ParallelDesc>* get_parallel_desc_ptr();\n\n  static void set_enabled(bool enabled);\n  static void set_nd_sbp(Symbol<NdSbp> nd_sbp);\n  static void set_parallel_desc(Symbol<ParallelDesc> parallel_desc);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_GLOBAL_MODE_H_\n"
  },
  {
    "path": "oneflow/core/job/graph_scope_vars.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/graph_scope_vars.h\"\n#include <vector>\n\nnamespace oneflow {\n\nnamespace {\n\nstd::vector<std::string>* GetPythonPathsToBeFilteredForDebuggingVar() {\n  static thread_local std::vector<std::string> filtered_paths;\n  return &filtered_paths;\n}\n\nstd::vector<std::string>* GetPythonPathsToBeKeptForDebuggingVar() {\n  static thread_local std::vector<std::string> kept_paths;\n  return &kept_paths;\n}\n\nbool* GetGraphVerboseStepLr() {\n  static thread_local bool graph_verbose_step_lr = false;\n  return &graph_verbose_step_lr;\n}\n\nint32_t* GetGraphDebugMaxPyStackDepthVar() {\n  static thread_local int32_t graph_debug_max_py_stack_depth = 2;\n  return &graph_debug_max_py_stack_depth;\n}\n\nbool* GetGraphDebugModeFlag() {\n  static thread_local bool graph_debug_mode_flag = false;\n  return &graph_debug_mode_flag;\n}\n\nbool* GetGraphDebugOnlyUserPyStackFlag() {\n  static thread_local bool graph_debug_only_user_py_stack = true;\n  return &graph_debug_only_user_py_stack;\n}\n}  // namespace\n\nbool IsOpenGraphVerboseStepLr() {\n  auto* graph_verbose_step_lr = GetGraphVerboseStepLr();\n  bool is_graph_verbose_step_lr = *graph_verbose_step_lr;\n  return is_graph_verbose_step_lr;\n}\n\nvoid SetGraphVerboseStepLr(bool verbose) {\n  auto* graph_verbose_step_lr = GetGraphVerboseStepLr();\n  *graph_verbose_step_lr = verbose;\n}\n\nvoid InitPythonPathsToBeKeptAndFilteredForDebugging(const std::string& python_base_dir) {\n  std::vector<std::string>* kept_paths = GetPythonPathsToBeKeptForDebuggingVar();\n  kept_paths->clear();\n  kept_paths->push_back(python_base_dir + \"/test\");\n  kept_paths->push_back(python_base_dir + \"/nn/modules\");\n\n  std::vector<std::string>* filtered_paths = GetPythonPathsToBeFilteredForDebuggingVar();\n  filtered_paths->clear();\n  filtered_paths->push_back(python_base_dir);\n}\n\nconst std::vector<std::string>& GetPythonPathsToBeFilteredForDebugging() {\n  return *GetPythonPathsToBeFilteredForDebuggingVar();\n}\nconst std::vector<std::string>& GetPythonPathsToBeKeptForDebugging() {\n  return *GetPythonPathsToBeKeptForDebuggingVar();\n}\n\nvoid SetGraphDebugMaxPyStackDepth(int32_t depth) { *GetGraphDebugMaxPyStackDepthVar() = depth; }\nint32_t GetGraphDebugMaxPyStackDepth() { return *GetGraphDebugMaxPyStackDepthVar(); }\n\nvoid SetGraphDebugMode(bool mode) { *GetGraphDebugModeFlag() = mode; }\nbool GetGraphDebugMode() { return *GetGraphDebugModeFlag(); }\n\nvoid SetGraphDebugOnlyUserPyStack(bool flag) { *GetGraphDebugOnlyUserPyStackFlag() = flag; }\nbool GetGraphDebugOnlyUserPyStack() { return *GetGraphDebugOnlyUserPyStackFlag(); }\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/graph_scope_vars.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_GRAPH_SCOPE_VARS_H_\n#define ONEFLOW_CORE_JOB_GRAPH_SCOPE_VARS_H_\n\n#include <cstdint>\n#include <string>\n#include <vector>\n\nnamespace oneflow {\n\nbool IsOpenGraphVerboseStepLr();\nvoid SetGraphVerboseStepLr(bool verbose);\n\nvoid SetGraphDebugMaxPyStackDepth(int32_t depth);\nint32_t GetGraphDebugMaxPyStackDepth();\nvoid SetGraphDebugMode(bool mode);\nbool GetGraphDebugMode();\nvoid SetGraphDebugOnlyUserPyStack(bool flag);\nbool GetGraphDebugOnlyUserPyStack();\nvoid InitPythonPathsToBeKeptAndFilteredForDebugging(const std::string& python_base_dir);\nconst std::vector<std::string>& GetPythonPathsToBeFilteredForDebugging();\nconst std::vector<std::string>& GetPythonPathsToBeKeptForDebugging();\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_GRAPH_SCOPE_VARS_H_\n"
  },
  {
    "path": "oneflow/core/job/id_manager.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/id_manager.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/framework/multi_client_session_context.h\"\n#include \"oneflow/core/job/id_state.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr static int64_t kRankLimitShift = 16;\nconstexpr static int64_t kIdLimitShift = (sizeof(int64_t) * 8 - kRankLimitShift);\nstatic_assert(kIdLimitShift > 0, \"\");\n\nint64_t AddCurrentRankOffset(int64_t x) {\n  CHECK_GE(x, 0);\n  CHECK_LT(x, (static_cast<int64_t>(1) << kIdLimitShift));\n  return (static_cast<int64_t>(GlobalProcessCtx::Rank()) << kIdLimitShift) + x;\n}\n\n}  // namespace\n\nIDMgr::IDMgr() {\n  regst_desc_id_count_ = 0;\n  mem_block_id_count_ = 0;\n  chunk_id_count_ = 0;\n  CHECK_LE(GlobalProcessCtx::WorldSize(), (static_cast<int64_t>(1) << kRankLimitShift));\n}\n\nint64_t IDMgr::NewRegstDescId() { return AddCurrentRankOffset(regst_desc_id_count_++); }\n\nint64_t IDMgr::NewMemBlockId() { return AddCurrentRankOffset(mem_block_id_count_++); }\n\nint64_t IDMgr::NewChunkId() { return AddCurrentRankOffset(chunk_id_count_++); }\nvoid IDMgr::SaveIdAndTaskIndex(IdState* id_state) {\n  id_state->regst_desc_id_state_ = regst_desc_id_count_;\n  id_state->mem_block_id_state_ = mem_block_id_count_;\n  id_state->chunk_id_state_ = chunk_id_count_;\n  task_id_gen_.GetTaskIndex(&id_state->task_index_state_);\n}\n\nvoid IDMgr::TryUpdateIdAndTaskIndex(const IdState* id_state) {\n  regst_desc_id_count_ = std::max(regst_desc_id_count_.load(std::memory_order_relaxed),\n                                  id_state->regst_desc_id_state_);\n  mem_block_id_count_ =\n      std::max(mem_block_id_count_.load(std::memory_order_relaxed), id_state->mem_block_id_state_);\n  chunk_id_count_ =\n      std::max(chunk_id_count_.load(std::memory_order_relaxed), id_state->chunk_id_state_);\n  task_id_gen_.TryUpdateTaskIndex(id_state->task_index_state_);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/id_manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_ID_MANAGER_H_\n#define ONEFLOW_CORE_JOB_ID_MANAGER_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/id_state.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/graph/task_id_generator.h\"\n\nnamespace oneflow {\n\nclass IDMgr final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(IDMgr);\n  ~IDMgr() = default;\n\n  int64_t NewRegstDescId();\n  int64_t NewMemBlockId();\n  int64_t NewChunkId();\n\n  TaskIdGenerator* GetTaskIdGenerator() { return &task_id_gen_; }\n\n  void SaveIdAndTaskIndex(IdState* id_state);\n  void TryUpdateIdAndTaskIndex(const IdState* id_state);\n\n private:\n  friend class Singleton<IDMgr>;\n  IDMgr();\n\n  std::atomic<int64_t> regst_desc_id_count_;\n  std::atomic<int64_t> mem_block_id_count_;\n  std::atomic<int64_t> chunk_id_count_;\n  TaskIdGenerator task_id_gen_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_ID_MANAGER_H_\n"
  },
  {
    "path": "oneflow/core/job/id_manager_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/id_manager.h\"\n#include \"oneflow/core/job/global_for.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstatic const int64_t machine_id_shl = 11 + 21 + 21;\nstatic const int64_t thread_id_shl = 21 + 21;\nstatic const int64_t local_work_stream_shl = 21;\n\nEnvProto GetEnvProto() {\n  EnvProto ret;\n  for (size_t i = 0; i < 10; ++i) {\n    auto* machine = ret.add_machine();\n    machine->set_id(i);\n    machine->set_addr(\"192.168.1.\" + std::to_string(i));\n  }\n  ret.set_ctrl_port(9527);\n  return ret;\n}\n\nResource GetResource() {\n  Resource ret;\n  ret.set_machine_num(10);\n  ret.set_cpu_device_num(5);\n  ret.set_comm_net_worker_num(4);\n  return ret;\n}\n\nvoid New() {\n  Singleton<EnvDesc>::New(GetEnvProto());\n  Singleton<ProcessCtx>::New();\n  Singleton<ProcessCtx>::Get()->mutable_ctrl_addr()->Add();\n  Singleton<ProcessCtx>::Get()->set_rank(0);\n  Singleton<ProcessCtx>::Get()->set_node_size(1);\n  Singleton<ResourceDesc, ForSession>::New(GetResource(), GlobalProcessCtx::NumOfProcessPerNode());\n  Singleton<IDMgr>::New();\n}\n\nvoid Delete() {\n  Singleton<IDMgr>::Delete();\n  Singleton<ProcessCtx>::Delete();\n  Singleton<ResourceDesc, ForSession>::Delete();\n  Singleton<EnvDesc>::Delete();\n}\n\n}  // namespace\n\nTEST(IDMgr, compile_regst_desc_id) {\n  New();\n  ASSERT_EQ(Singleton<IDMgr>::Get()->NewRegstDescId(), 0);\n  ASSERT_EQ(Singleton<IDMgr>::Get()->NewRegstDescId(), 1);\n  ASSERT_EQ(Singleton<IDMgr>::Get()->NewRegstDescId(), 2);\n  Delete();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/id_state.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_ID_STATE_H_\n#define ONEFLOW_CORE_JOB_ID_STATE_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/device/device_id.h\"\n#include \"oneflow/core/graph/stream_id.h\"\n#include \"oneflow/core/graph/task_id.h\"\n\nnamespace oneflow {\n\nclass IdState {\n public:\n  int64_t regst_desc_id_state_{};\n  int64_t mem_block_id_state_{};\n  int64_t chunk_id_state_{};\n  int64_t job_id_state_{};\n  HashMap<int64_t, uint32_t> task_index_state_{};\n  HashMap<int64_t, uint32_t> stream_index_state_{};\n};\n\n}  // namespace oneflow\n\n#endif\n"
  },
  {
    "path": "oneflow/core/job/initializer_conf.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage ConstantInitializerConf {\n  optional float value = 1 [default = 0];\n}\n\nmessage ConstantIntInitializerConf {\n  optional int64 value = 1 [default = 0];\n}\n\nmessage RandomNormalInitializerConf {\n  optional float mean = 1 [default = 0];\n  optional float std = 2 [default = 1];\n}\n\n//output[D_0 ... D_(axis - 1) i D_(axis + 1) ... D_n] = start + i * stride\nmessage RangeInitializerConf {\n  optional double start = 1 [default = 0];\n  optional double stride = 2 [default = 1];\n  optional int64 axis = 3 [default = -1];\n}\n\nmessage IntRangeInitializerConf {\n  optional int64 start = 1 [default = 0];\n  optional int64 stride = 2 [default = 1];\n  optional int64 axis = 3 [default = -1];\n}\n\nmessage EmptyInitializerConf {\n}\n\nmessage InitializerConf {\n  oneof type {\n    ConstantInitializerConf constant_conf = 1;\n    ConstantIntInitializerConf constant_int_conf = 2;\n    RandomNormalInitializerConf random_normal_conf = 3;\n    RangeInitializerConf range_conf = 4;\n    IntRangeInitializerConf int_range_conf = 5;\n    EmptyInitializerConf empty_conf = 6;\n  }\n}\n\nmessage InitializeWithSnapshotConf {\n  required string path = 1;\n  optional string key = 2;\n}\n"
  },
  {
    "path": "oneflow/core/job/inter_job_mem_sharing_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/inter_job_mem_sharing_util.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/memory/memory_case_util.h\"\n#include \"oneflow/core/register/runtime_register_desc.h\"\n#include \"oneflow/core/job/id_manager.h\"\n#include \"oneflow/core/job/plan_util.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nvoid GetOpName2JobId2TaskProtos(\n    Plan* plan, const HashSet<std::string>& op_names,\n    HashMap<std::string, HashMap<int64_t, std::vector<TaskProto*>>>* op_name2job_id2task_protos) {\n  for (int64_t i = 0; i < plan->task_size(); ++i) {\n    TaskProto* task = plan->mutable_task(i);\n    if (task->exec_sequence().exec_node_size() == 1) {\n      const KernelConf& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf();\n      std::string op_name =\n          PlanUtil::GetOpAttribute(plan, task->job_id(), kernel_conf).op_conf().name();\n      if (op_names.find(op_name) != op_names.end()) {\n        CHECK(task->has_parallel_ctx());\n        (*op_name2job_id2task_protos)[op_name][task->job_id()].emplace_back(task);\n      }\n    }\n  }\n  for (auto& op2job_task_pair : *op_name2job_id2task_protos) {\n    for (auto& job2task_pair : op2job_task_pair.second) {\n      std::vector<TaskProto*>& task_protos = job2task_pair.second;\n      std::sort(task_protos.begin(), task_protos.end(),\n                [](const TaskProto* lhs, const TaskProto* rhs) {\n                  return lhs->parallel_ctx().parallel_id() < rhs->parallel_ctx().parallel_id();\n                });\n    }\n  }\n}\n\nHashMap<std::string, HashSet<int64_t>> GetInterfaceOpName2JobIds(\n    const std::vector<std::shared_ptr<Job>>& jobs) {\n  HashMap<std::string, HashSet<int64_t>> interface_op_name2job_ids;\n  HashSet<std::string> unique_op_name_check;\n  FOR_RANGE(int64_t, i, 0, jobs.size()) {\n    const auto& job = *jobs.at(i);\n    for (const auto& op : job.net().op()) {\n      if (IsInterfaceOpConf(op)) {\n        CHECK(interface_op_name2job_ids[op.name()].emplace(i).second);\n        unique_op_name_check.emplace(op.name());\n      } else {\n        // interface ops shouldn't share op_name with other ops\n        CHECK(unique_op_name_check.find(op.name()) == unique_op_name_check.end());\n      }\n    }\n  }\n  return interface_op_name2job_ids;\n}\n\nstd::vector<HashSet<int64_t>> InitJobId2MutualExclusionJobIds(\n    const std::vector<std::shared_ptr<Job>>& jobs) {\n  int64_t job_size = jobs.size();\n  std::vector<HashSet<int64_t>> job_id2mutual_exclusion_ids(job_size);\n  for (const auto& pair : GetInterfaceOpName2JobIds(jobs)) {\n    for (int64_t first_id : pair.second) {\n      for (int64_t second_id : pair.second) {\n        if (first_id != second_id) { job_id2mutual_exclusion_ids[first_id].emplace(second_id); }\n      }\n    }\n  }\n  const InterJobReuseMemStrategy* strategy = Singleton<const 
InterJobReuseMemStrategy>::Get();\n  if (strategy->has_custom_parallelism()) {\n    auto* job_name2job_id = Singleton<JobName2JobId>::Get();\n    for (const auto& group : strategy->custom_parallelism().nonparallel_group()) {\n      for (const std::string& first_name : group.job_name()) {\n        for (const std::string& second_name : group.job_name()) {\n          if (first_name != second_name) {\n            CHECK(job_name2job_id->find(first_name) != job_name2job_id->end());\n            CHECK(job_name2job_id->find(second_name) != job_name2job_id->end());\n            int64_t first_id = (*job_name2job_id)[first_name];\n            int64_t second_id = (*job_name2job_id)[second_name];\n            job_id2mutual_exclusion_ids[first_id].emplace(second_id);\n          }\n        }\n      }\n    }\n  }\n  return job_id2mutual_exclusion_ids;\n}\n\nstd::vector<HashSet<int64_t>> GetMutualExclusionJobGroups(\n    const std::vector<std::shared_ptr<Job>>& jobs) {\n  int64_t job_size = jobs.size();\n  std::vector<HashSet<int64_t>> job_groups;\n  job_groups.reserve(job_size);\n  if (Singleton<const InterJobReuseMemStrategy>::Get()->has_reuse_mem_priority()) {\n    job_groups.emplace_back(HashSet<int64_t>());\n    FOR_RANGE(int64_t, i, 0, job_size) { job_groups.front().emplace(i); }\n    return job_groups;\n  }\n\n  // default using parallelism_priority strategy\n  std::vector<HashSet<int64_t>> job_id2mutual_exclusion_ids = InitJobId2MutualExclusionJobIds(jobs);\n  std::vector<HashSet<int64_t>> job_id2enable_parallel_ids(job_size);\n  FOR_RANGE(int64_t, i, 0, job_size) {\n    FOR_RANGE(int64_t, j, 0, job_size) {\n      if (job_id2mutual_exclusion_ids[i].find(j) == job_id2mutual_exclusion_ids[i].end()) {\n        job_id2enable_parallel_ids[i].emplace(j);\n      }\n    }\n  }\n  int64_t mem_share_group_num = 0;\n  std::vector<int64_t> job_id2mem_share_group_id(job_size, -1);\n  FOR_RANGE(int64_t, this_job_id, 0, job_size) {\n    HashSet<int64_t> mem_share_group_id_used;\n    for (int64_t enable_parallel_job_id : job_id2enable_parallel_ids[this_job_id]) {\n      int64_t group_id = job_id2mem_share_group_id[enable_parallel_job_id];\n      if (group_id != -1) { mem_share_group_id_used.emplace(group_id); }\n    }\n    FOR_RANGE(int64_t, this_group_id, 0, mem_share_group_num) {\n      if (mem_share_group_id_used.find(this_group_id) == mem_share_group_id_used.end()) {\n        job_id2mem_share_group_id[this_job_id] = this_group_id;\n        break;\n      }\n    }\n    if (job_id2mem_share_group_id[this_job_id] == -1) {\n      job_id2mem_share_group_id[this_job_id] = mem_share_group_num;\n      ++mem_share_group_num;\n      CHECK_LE(mem_share_group_num, job_size);\n    }\n  }\n\n  job_groups.resize(mem_share_group_num);\n  FOR_RANGE(int64_t, this_job_id, 0, job_size) {\n    job_groups[job_id2mem_share_group_id[this_job_id]].emplace(this_job_id);\n  }\n  {\n    HashSet<int64_t> job_id_unique_check;\n    for (auto& job_group : job_groups) {\n      for (int64_t job_id : job_group) { CHECK(job_id_unique_check.emplace(job_id).second); }\n    }\n  }\n  return job_groups;\n}\n\nvoid MergeReusedChunk(HashMap<int64_t, ChunkProto>* chunk_id2chunk,\n                      HashMap<int64_t, MemBlockProto*>* mem_block_id2mem_block,\n                      const std::vector<HashSet<int64_t>>& reuse_mem_job_groups) {\n  // mzuid = memory zone unique id\n  HashMap<int64_t, HashMap<int64_t, int64_t>> job_id2mzuid2chunk_id;\n  HashMap<int64_t, HashSet<MemBlockProto*>> chunk_id2mem_blocks;\n\n  for (auto& pair : 
*mem_block_id2mem_block) {\n    MemBlockProto* mem_block = pair.second;\n    if (mem_block->enable_reuse_mem() == false) {\n      CHECK(mem_block->has_chunk_id() == false);\n      CHECK(mem_block->has_chunk_offset() == false);\n      continue;\n    }\n    CHECK(mem_block->has_chunk_id() && mem_block->chunk_id() >= 0);\n    CHECK(mem_block->has_chunk_offset() && mem_block->chunk_offset() >= 0);\n    CHECK(chunk_id2mem_blocks[mem_block->chunk_id()].insert(mem_block).second);\n  }\n\n  // merge chunk and delete useless chunk\n  for (const auto& pair : *chunk_id2chunk) {\n    const ChunkProto& chunk = pair.second;\n    const MemoryCase& mem_case = chunk.mem_case();\n    // NOTE(zwx): do not reuse mem on cpu\n    if (memory::IsHostMem(mem_case)) { continue; }\n    int64_t mzuid = memory::GetUniqueMemCaseId(chunk.machine_id(), mem_case);\n    CHECK_EQ(chunk.job_id_size(), 1);\n    CHECK(job_id2mzuid2chunk_id[chunk.job_id(0)].emplace(mzuid, chunk.chunk_id()).second);\n  }\n\n  auto MergeMemChunkIdR2L = [&](int64_t left_chunk_id, int64_t right_chunk_id) {\n    CHECK_NE(left_chunk_id, right_chunk_id);\n    ChunkProto* chunk_l = &(chunk_id2chunk->at(left_chunk_id));\n    ChunkProto* chunk_r = &(chunk_id2chunk->at(right_chunk_id));\n    CHECK_GE(chunk_l->job_id_size(), 1);\n    CHECK_EQ(chunk_r->job_id_size(), 1);\n    CHECK_EQ(chunk_l->machine_id(), chunk_r->machine_id());\n    CHECK(chunk_l->mem_case() == chunk_r->mem_case());\n    CHECK_GT(chunk_l->mem_size(), 0);\n    CHECK_GT(chunk_r->mem_size(), 0);\n    for (MemBlockProto* mem_block : chunk_id2mem_blocks[right_chunk_id]) {\n      CHECK_EQ(mem_block->machine_id(), chunk_l->machine_id());\n      CHECK(mem_block->mem_case() == chunk_l->mem_case());\n      mem_block->set_chunk_id(left_chunk_id);\n    }\n    chunk_l->add_job_id(chunk_r->job_id(0));\n    chunk_l->set_mem_size(std::max(chunk_l->mem_size(), chunk_r->mem_size()));\n    chunk_id2chunk->erase(chunk_id2chunk->find(right_chunk_id));\n  };\n  auto InitMzuid2JobIdsInJobGroup =\n      [&](const HashSet<int64_t>& job_group) -> HashMap<int64_t, HashSet<int64_t>> {\n    HashMap<int64_t, HashSet<int64_t>> mzuid2job_ids;\n    for (int64_t job_id : job_group) {\n      for (const auto& pair : job_id2mzuid2chunk_id[job_id]) {\n        CHECK(mzuid2job_ids[pair.first].emplace(job_id).second);\n      }\n    }\n    return mzuid2job_ids;\n  };\n  for (const HashSet<int64_t>& job_group : reuse_mem_job_groups) {\n    if (job_group.size() <= 1) { continue; }\n    HashMap<int64_t, HashSet<int64_t>> mzuid2job_ids = InitMzuid2JobIdsInJobGroup(job_group);\n    for (const auto& pair : mzuid2job_ids) {\n      const HashSet<int64_t>& job_ids = pair.second;\n      if (job_ids.size() <= 1) { continue; }\n      int64_t mzuid = pair.first;\n      int64_t merged_job_id = *(job_ids.begin());\n      for (int64_t job_id : job_ids) {\n        if (job_id == merged_job_id) { continue; }\n        MergeMemChunkIdR2L(job_id2mzuid2chunk_id[merged_job_id].at(mzuid),\n                           job_id2mzuid2chunk_id[job_id].at(mzuid));\n      }\n    }\n  }\n}\n\nvoid MergeSharedMemBlockR2L(RegstDescProto* lhs, RegstDescProto* rhs,\n                            HashMap<int64_t, MemBlockProto>* mem_block_id2mem_block) {\n  if (lhs == rhs) { return; }\n  auto CheckValidAndGetMemBlock = [&](int64_t mem_block_id, int64_t mem_size,\n                                      const MemoryCase& mem_case) {\n    CHECK_NE(mem_block_id, -1);\n    CHECK(mem_block_id2mem_block->find(mem_block_id) != mem_block_id2mem_block->end());\n    MemBlockProto* 
mem_block = &(mem_block_id2mem_block->at(mem_block_id));\n    CHECK(mem_block->enable_reuse_mem() == false);\n    CHECK(mem_block->has_chunk_id() == false);\n    CHECK(mem_block->has_chunk_offset() == false);\n    CHECK_EQ(mem_block->mem_size(), mem_size);\n    CHECK(mem_block->mem_case() == mem_case);\n    return mem_block;\n  };\n\n  auto MergeAndEraseMemBlock = [&](MemBlockProto* merged_block, MemBlockProto* erased_block) {\n    CHECK_NE(merged_block->mem_block_id(), erased_block->mem_block_id());\n    CHECK_EQ(erased_block->job_id_size(), 1);\n    CHECK_EQ(merged_block->mem_size(), erased_block->mem_size());\n    merged_block->add_job_id(erased_block->job_id(0));\n    CHECK_EQ(mem_block_id2mem_block->erase(erased_block->mem_block_id()), 1);\n  };\n\n  int64_t merged_mem_block_id = lhs->mem_block_id();\n  int64_t erased_mem_block_id = rhs->mem_block_id();\n  CHECK(lhs->enable_reuse_mem() == false && rhs->enable_reuse_mem() == false);\n  CHECK_EQ(lhs->mem_block_offset(), 0);\n  CHECK_EQ(rhs->mem_block_offset(), 0);\n  RtRegstDesc left_rt_regst(*lhs);\n  RtRegstDesc right_rt_regst(*rhs);\n  MemBlockProto* merged_mem_block = CheckValidAndGetMemBlock(\n      merged_mem_block_id, left_rt_regst.TotalMainByteSize4AllRegst(), lhs->mem_case());\n  MemBlockProto* erased_mem_block = CheckValidAndGetMemBlock(\n      erased_mem_block_id, right_rt_regst.TotalMainByteSize4AllRegst(), rhs->mem_case());\n  MergeAndEraseMemBlock(merged_mem_block, erased_mem_block);\n  rhs->set_mem_block_id(merged_mem_block_id);\n\n  int64_t separated_header_mem_size = left_rt_regst.TotalSeparatedHeaderByteSize4AllRegst();\n  if (separated_header_mem_size > 0) {\n    CHECK_EQ(separated_header_mem_size, right_rt_regst.TotalSeparatedHeaderByteSize4AllRegst());\n    int64_t merged_header_id = lhs->separated_header_mem_block_id();\n    int64_t erased_header_id = rhs->separated_header_mem_block_id();\n    MemoryCase header_mem_case = memory::GetPinnedHostMemoryCase(lhs->mem_case());\n    MemBlockProto* merged_header_block =\n        CheckValidAndGetMemBlock(merged_header_id, separated_header_mem_size, header_mem_case);\n    MemBlockProto* erased_header_block =\n        CheckValidAndGetMemBlock(erased_header_id, separated_header_mem_size, header_mem_case);\n    MergeAndEraseMemBlock(merged_header_block, erased_header_block);\n    rhs->set_separated_header_mem_block_id(merged_header_id);\n  }\n}\n\nvoid MergeSharedInterfaceMemBlock(const std::vector<std::shared_ptr<Job>>& jobs, Plan* plan,\n                                  HashMap<int64_t, MemBlockProto>* mem_block_id2mem_block) {\n  HashMap<std::string, HashSet<int64_t>> interface_op_name2job_ids =\n      GetInterfaceOpName2JobIds(jobs);\n  HashSet<std::string> interfaces_op_names;\n  for (const auto& pair : interface_op_name2job_ids) { interfaces_op_names.insert(pair.first); }\n  HashMap<std::string, HashMap<int64_t, std::vector<TaskProto*>>> op_name2job_id2task_protos;\n  GetOpName2JobId2TaskProtos(plan, interfaces_op_names, &op_name2job_id2task_protos);\n\n  for (const auto& op_job_pair : interface_op_name2job_ids) {\n    if (op_job_pair.second.size() <= 1) { continue; }\n    const HashMap<int64_t, std::vector<TaskProto*>>& job_id2same_op_name_sorted_task_protos =\n        op_name2job_id2task_protos.at(op_job_pair.first);\n    const auto& first_vec = job_id2same_op_name_sorted_task_protos.begin()->second;\n    std::vector<MemoryCase> common_mem_case_vec(first_vec.size());\n    std::transform(\n        first_vec.cbegin(), first_vec.cend(), common_mem_case_vec.begin(),\n      
  [](TaskProto* tp) { return PlanUtil::GetSoleProducedDataRegst(tp)->mem_case(); });\n    for (const auto& pair : job_id2same_op_name_sorted_task_protos) {\n      const auto& task_protos = pair.second;\n      CHECK_EQ(task_protos.size(), first_vec.size());\n      FOR_RANGE(int64_t, i, 0, first_vec.size()) {\n        CHECK_EQ(task_protos.at(i)->machine_id(), first_vec.at(i)->machine_id());\n        RegstDescProto* first_regst_desc = PlanUtil::GetSoleProducedDataRegst(first_vec.at(i));\n        RegstDescProto* regst_desc = PlanUtil::GetSoleProducedDataRegst(task_protos.at(i));\n\n        MergeSharedMemBlockR2L(first_regst_desc, regst_desc, mem_block_id2mem_block);\n\n        CHECK(memory::EqualsIgnorePinnedDevice(common_mem_case_vec.at(i), regst_desc->mem_case()));\n        common_mem_case_vec[i] = regst_desc->mem_case();\n      }\n    }\n    for (const auto& pair : job_id2same_op_name_sorted_task_protos) {\n      const auto& task_protos = pair.second;\n      FOR_RANGE(int64_t, i, 0, task_protos.size()) {\n        RegstDescProto* regst_desc = PlanUtil::GetSoleProducedDataRegst(task_protos.at(i));\n        *(regst_desc->mutable_mem_case()) = common_mem_case_vec.at(i);\n        CHECK(mem_block_id2mem_block->find(regst_desc->mem_block_id())\n              != mem_block_id2mem_block->end());\n        *(mem_block_id2mem_block->at(regst_desc->mem_block_id()).mutable_mem_case()) =\n            common_mem_case_vec.at(i);\n      }\n    }\n  }\n}\n\n}  // namespace\n\nvoid InterJobMemSharingUtil::MergeMemSharedInterfaceMemBlockBetweenJobs(\n    const std::vector<std::shared_ptr<Job>>& jobs, Plan* plan) {\n  if (jobs.size() == 1) { return; }\n\n  HashMap<int64_t, MemBlockProto> mem_block_id2mem_block;\n  for (const auto& mem_block : plan->block_chunk_list().mem_block()) {\n    CHECK(mem_block_id2mem_block.emplace(mem_block.mem_block_id(), mem_block).second);\n  }\n  plan->mutable_block_chunk_list()->clear_mem_block();\n\n  MergeSharedInterfaceMemBlock(jobs, plan, &mem_block_id2mem_block);\n\n  for (const auto& pair : mem_block_id2mem_block) {\n    *(plan->mutable_block_chunk_list()->add_mem_block()) = pair.second;\n  }\n}\n\nvoid InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs(\n    const std::vector<std::shared_ptr<Job>>& user_jobs, Plan* plan) {\n  if (user_jobs.size() == 1) { return; }\n  std::vector<HashSet<int64_t>> reuse_mem_job_groups = GetMutualExclusionJobGroups(user_jobs);\n\n  HashMap<int64_t, ChunkProto> chunk_id2chunk;\n  HashMap<int64_t, MemBlockProto*> mem_block_id2mem_block;\n  for (const auto& chunk : plan->block_chunk_list().chunk()) {\n    CHECK(chunk_id2chunk.emplace(chunk.chunk_id(), chunk).second);\n  }\n  plan->mutable_block_chunk_list()->clear_chunk();\n  for (MemBlockProto& mem_block : *plan->mutable_block_chunk_list()->mutable_mem_block()) {\n    CHECK(mem_block_id2mem_block.emplace(mem_block.mem_block_id(), &mem_block).second);\n  }\n\n  MergeReusedChunk(&chunk_id2chunk, &mem_block_id2mem_block, reuse_mem_job_groups);\n\n  for (const auto& pair : chunk_id2chunk) {\n    *(plan->mutable_block_chunk_list()->add_chunk()) = pair.second;\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/inter_job_mem_sharing_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_INTER_JOB_MEM_SHARING_UTIL_H_\n#define ONEFLOW_CORE_JOB_INTER_JOB_MEM_SHARING_UTIL_H_\n\n#include \"oneflow/core/job/job_set.pb.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n\nnamespace oneflow {\n\nstruct InterJobMemSharingUtil {\n  static void MergeMemSharedInterfaceMemBlockBetweenJobs(\n      const std::vector<std::shared_ptr<Job>>& jobs, Plan* plan);\n\n  static void MergeMemReusedChunkBetweenUserJobs(const std::vector<std::shared_ptr<Job>>& user_jobs,\n                                                 Plan* plan);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_INTER_JOB_MEM_SHARING_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/job/inter_user_job_info.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage InterUserJobInfo {\n  map<string, string> input_or_var_op_name2push_job_name = 1;\n  map<string, string> output_or_var_op_name2pull_job_name = 2;\n  optional string global_model_init_job_name = 4;\n  optional string global_model_load_job_name = 5;\n  optional string global_model_save_job_name = 6;\n}\n"
  },
  {
    "path": "oneflow/core/job/intra_job_mem_sharing_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/intra_job_mem_sharing_util.h\"\n#include <vector>\n#include \"oneflow/core/common/blocking_counter.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/hash_container.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/id_manager.h\"\n#include \"oneflow/core/job/job_conf.pb.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/memory_share_strategy.h\"\n#include \"oneflow/core/register/runtime_register_desc.h\"\n#include \"oneflow/core/thread/thread_pool.h\"\n#include \"oneflow/core/graph/task_node.h\"\n#include \"oneflow/core/job/plan_util.h\"\n\nnamespace oneflow {\n\nenum MemAllocAlgoType {\n  kMemSizeFirstAlgo = 0,\n  kLifetimeFirstAlgo = 1,\n  kTimeLineAlgo = 2,\n  kMemVolumeFirstAlgo = 3,\n};\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<::oneflow::MemAllocAlgoType> {\n  std::size_t operator()(const ::oneflow::MemAllocAlgoType& type) const {\n    return std::hash<int>()(static_cast<size_t>(type));\n  }\n};\n\n}  // namespace std\n\nnamespace oneflow {\n\nnamespace {\n\nint64_t GenDeviceUniqueId(int64_t machine_id, int64_t device_id) {\n  return (machine_id << 32) | device_id;\n}\n\nvoid TryConnectWithMemSafeGuardCtrlRegstDesc(TaskProto* src_task_proto, TaskProto* dst_task_proto) {\n  RegstDescProto* ctrl_regst_desc =\n      FindOrCreateProducedCtrlRegstDesc(src_task_proto, \"out_ctrl_shared_mem_safe_guard\");\n  int64_t dst_task_id = dst_task_proto->task_id();\n  if (!IsInRepeatedField(ctrl_regst_desc->consumer_task_id(), dst_task_id)) {\n    ctrl_regst_desc->add_consumer_task_id(dst_task_id);\n    int64_t ctrl_regst_desc_id = ctrl_regst_desc->regst_desc_id();\n    RegstDescIdSet* consumed_ctrl_regst_desc_ids =\n        FindOrCreateConsumedCtrlRegstDescIdSet(dst_task_proto, \"in_ctrl\");\n    CHECK(!IsInRepeatedField(consumed_ctrl_regst_desc_ids->regst_desc_id(), ctrl_regst_desc_id));\n    consumed_ctrl_regst_desc_ids->add_regst_desc_id(ctrl_regst_desc_id);\n  }\n}\n\nstruct MemoryChain {\n  std::vector<TaskProto*> sorted_tasks;\n  HashSet<RegstDescProto*> mem_reused_regsts;\n  int64_t total_mem_reused_size = 0;\n  Shape time_shape;\n};\n\nvoid InitMemoryChains(Plan* plan,\n                      HashMap<int64_t, HashMap<int64_t, MemoryChain>>* device2chain2mem_chain,\n                      HashMap<RegstDescProto*, size_t>* mem_reused_regst2size) {\n  for (int64_t i = 0; i < plan->task_size(); ++i) {\n    TaskProto* task = plan->mutable_task(i);\n    const StreamId stream_id = PlanUtil::GetStreamId(*task);\n    int64_t machine_id = task->machine_id();\n    DeviceType device_type = stream_id.device_id().device_type();\n    // TODO(zwx): eliminate this special 'is cpu' determine\n    if (device_type == DeviceType::kCPU) { continue; }\n    if 
(!IsValidChainId(task->chain_id())) { continue; }\n    int64_t device_id = stream_id.device_id().device_index();\n    int64_t device_unique_id = GenDeviceUniqueId(machine_id, device_id);\n    MemoryChain* mem_chain = &((*device2chain2mem_chain)[device_unique_id][task->chain_id()]);\n    mem_chain->sorted_tasks.emplace_back(task);\n    for (auto& pair : *(task->mutable_produced_regst_desc())) {\n      RegstDescProto* regst_desc = &pair.second;\n      int64_t regst_total_main_size = RtRegstDesc(*regst_desc).TotalMainByteSize4AllRegst();\n      if (regst_desc->mem_case().device_type() == device_type\n          && regst_desc->mem_case().device_id() == device_id && regst_desc->enable_reuse_mem()\n          && regst_desc->register_num() == 1 && regst_desc->mem_block_id() == -1\n          && regst_desc->mem_block_offset() == -1\n          && regst_desc->regst_desc_type().has_data_regst_desc() && regst_total_main_size > 0) {\n        CHECK(mem_chain->mem_reused_regsts.insert(regst_desc).second);\n        (*mem_reused_regst2size)[regst_desc] = regst_total_main_size;\n        mem_chain->total_mem_reused_size += regst_total_main_size;\n\n        // for time shape in mem chain\n        Shape regst_time_shape =\n            Shape(regst_desc->regst_desc_type().data_regst_desc().time_shape());\n        if (!mem_chain->time_shape.is_initialized()) {\n          mem_chain->time_shape = regst_time_shape;\n        } else {\n          CHECK(mem_chain->time_shape == regst_time_shape);\n        }\n      }\n    }\n  }\n  for (auto& device_pair : *device2chain2mem_chain) {\n    HashMap<int64_t, MemoryChain>* chain2mem_chain = &device_pair.second;\n    HashSet<int64_t> useless_chain_ids;\n    for (auto& pair : *chain2mem_chain) {\n      if (pair.second.mem_reused_regsts.empty()) { useless_chain_ids.insert(pair.first); }\n    }\n    for (int64_t chain_id : useless_chain_ids) { chain2mem_chain->erase(chain_id); }\n    for (auto& pair : *chain2mem_chain) {\n      MemoryChain* mem_chain = &pair.second;\n      std::sort(mem_chain->sorted_tasks.begin(), mem_chain->sorted_tasks.end(),\n                [&](const TaskProto* lhs, const TaskProto* rhs) {\n                  int64_t lhs_order_in_chain = lhs->order_in_chain();\n                  int64_t rhs_order_in_chain = rhs->order_in_chain();\n                  CHECK_NE(lhs_order_in_chain, rhs_order_in_chain);\n                  return lhs_order_in_chain < rhs_order_in_chain;\n                });\n    }\n  }\n}\n\nbool IsReachableToAnyOtherTask(const TaskProto* src_task, const HashSet<int64_t>& task_ids) {\n  for (const auto& pair : src_task->produced_regst_desc()) {\n    for (int64_t consumer : pair.second.consumer_task_id()) {\n      if (task_ids.find(consumer) != task_ids.end()) { return true; }\n    }\n  }\n  return false;\n}\n\nbool IsTaskConnectedL2R(const TaskProto* src, const TaskProto* dst) {\n  for (const auto& pair : src->produced_regst_desc()) {\n    for (int64_t consumer : pair.second.consumer_task_id()) {\n      if (consumer == dst->task_id()) { return true; }\n    }\n  }\n  return false;\n}\n\nvoid GenMemChainTasksAndRegsts(\n    Plan* plan, HashMap<int64_t, std::vector<TaskProto*>>* mem_chain2sorted_tasks,\n    HashMap<int64_t, std::vector<RegstDescProto*>>* mem_chain2mem_reused_regsts,\n    HashMap<int64_t, HashMap<int64_t, RegstDescProto*>>* mem_chain2regst_desc_id2reuse_regst_desc,\n    HashMap<RegstDescProto*, size_t>* mem_reused_regst2size) {\n  mem_chain2sorted_tasks->clear();\n  mem_chain2mem_reused_regsts->clear();\n  HashMap<int64_t, HashMap<int64_t, 
MemoryChain>> device2chain2mem_chain;\n  InitMemoryChains(plan, &device2chain2mem_chain, mem_reused_regst2size);\n\n  int64_t mem_chain_id = 0;\n\n  for (auto& device_chain_pair : device2chain2mem_chain) {\n    if (device_chain_pair.second.empty()) { continue; }\n    std::vector<MemoryChain*> mem_chains;\n    mem_chains.reserve(device_chain_pair.second.size());\n    for (auto& pair : device_chain_pair.second) { mem_chains.emplace_back(&pair.second); }\n    for (MemoryChain* mem_chain : mem_chains) {\n      std::vector<TaskProto*>* sorted_tasks = &((*mem_chain2sorted_tasks)[mem_chain_id]);\n      CHECK(sorted_tasks->empty());\n      sorted_tasks->insert(sorted_tasks->end(), mem_chain->sorted_tasks.begin(),\n                           mem_chain->sorted_tasks.end());\n      std::vector<RegstDescProto*>* mem_reused_regsts =\n          &((*mem_chain2mem_reused_regsts)[mem_chain_id]);\n      CHECK(mem_reused_regsts->empty());\n      mem_reused_regsts->insert(mem_reused_regsts->end(), mem_chain->mem_reused_regsts.begin(),\n                                mem_chain->mem_reused_regsts.end());\n      // Index the mem-reused regsts of this chain by regst_desc_id in\n      // regst_desc_id2reuse_regst_desc, keeping it consistent with the HashSet mem_reused_regsts\n      auto& regst_desc_id2reuse_regst_desc =\n          (*mem_chain2regst_desc_id2reuse_regst_desc)[mem_chain_id];\n      CHECK(regst_desc_id2reuse_regst_desc.empty());\n      for (auto& mem_reused_regst : mem_chain->mem_reused_regsts) {\n        regst_desc_id2reuse_regst_desc[mem_reused_regst->regst_desc_id()] = mem_reused_regst;\n      }\n      ++mem_chain_id;\n    }\n  }\n\n  CHECK_EQ(mem_chain2sorted_tasks->size(), mem_chain2mem_reused_regsts->size());\n\n  // NOTE(chengcheng): add ctrl safe guard for each mem chain\n  HashMap<int64_t, TaskProto*> task_id2proto;\n  for (int64_t i = 0; i < plan->task_size(); ++i) {\n    TaskProto* task = plan->mutable_task(i);\n    CHECK(task_id2proto.emplace(task->task_id(), task).second);\n  }\n  for (auto& pair : *mem_chain2sorted_tasks) {\n    std::vector<TaskProto*>* sorted_tasks = &(pair.second);\n    // NOTE(chengcheng): We CANNOT add the ctrl safe guard only between the first and last task,\n    //  because sorted_tasks may be connected as a graph with multiple tail (sink) tasks.\n    const std::vector<RegstDescProto*>& mem_reused_regsts =\n        mem_chain2mem_reused_regsts->at(pair.first);\n    if (mem_reused_regsts.size() <= 1) { continue; }\n\n    HashSet<int64_t> consumer_task_ids;\n    for (const RegstDescProto* regst : mem_reused_regsts) {\n      for (int64_t consumer : regst->consumer_task_id()) { consumer_task_ids.insert(consumer); }\n    }\n    std::vector<TaskProto*> sink_tasks;\n    sink_tasks.reserve(consumer_task_ids.size());\n    for (int64_t src_task_id : consumer_task_ids) {\n      auto it = task_id2proto.find(src_task_id);\n      CHECK(it != task_id2proto.end());\n      if (!IsReachableToAnyOtherTask(it->second, consumer_task_ids)) {\n        sink_tasks.emplace_back(it->second);\n      }\n    }\n\n    TaskProto* first_task = sorted_tasks->front();\n    for (TaskProto* sink_task : sink_tasks) {\n      CHECK(first_task != sink_task);\n      if (!IsTaskConnectedL2R(first_task, sink_task)) {\n        TryConnectWithMemSafeGuardCtrlRegstDesc(first_task, sink_task);\n      }\n    }\n  }\n}\n\nvoid GenRegstAllocFreeTimeLineAndRegstLifetimes(\n    const std::vector<TaskProto*>& sorted_tasks,\n    const std::vector<RegstDescProto*>& mem_reused_regsts,\n    const HashMap<int64_t, RegstDescProto*>& regst_desc_id2reuse_regst_desc,\n    const HashMap<RegstDescProto*, 
size_t>& mem_reused_regst2size,\n    HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>* regst2lifetime,\n    HashMap<RegstDescProto*, RegstDescProto*>* consumer2inplaced_regst, size_t* peak_memory) {\n  CHECK(consumer2inplaced_regst->empty());\n  std::vector<std::vector<RegstDescProto*>> alloc_regsts_timeline(sorted_tasks.size());\n  std::vector<std::vector<RegstDescProto*>> free_regsts_timeline(sorted_tasks.size());\n  HashMap<int64_t, int64_t> task_id2sorted_id;\n  for (int64_t i = 0; i < sorted_tasks.size(); ++i) {\n    TaskProto* task = sorted_tasks.at(i);\n    CHECK(task_id2sorted_id.emplace(task->task_id(), i).second);\n  }\n\n  auto FindLastFreeIndexInSortedTasks = [&](RegstDescProto* regst_desc) -> int64_t {\n    // a temp regst's free index is the same as its alloc index\n    int64_t free_index = task_id2sorted_id.at(regst_desc->producer_task_id());\n    for (int64_t consumer_task_id : regst_desc->consumer_task_id()) {\n      // if consumer is not in this mem chain, set free index = last index\n      int64_t this_sorted_index = sorted_tasks.size() - 1;\n      if (task_id2sorted_id.find(consumer_task_id) != task_id2sorted_id.end()) {\n        this_sorted_index = task_id2sorted_id.at(consumer_task_id);\n      }\n      free_index = std::max(free_index, this_sorted_index);\n    }\n    return free_index;\n  };\n\n  auto TryFindFirstInplacedRegstDesc = [&](RegstDescProto* consumer_regst) -> RegstDescProto* {\n    RegstDescProto* inplaced_regst = nullptr;\n    while (consumer_regst->has_hint_inplace_consumed_regst_desc_id()\n           && consumer_regst->hint_inplace_consumed_regst_desc_id() != -1) {\n      const auto& iterator_hint_inplaced_regst = regst_desc_id2reuse_regst_desc.find(\n          consumer_regst->hint_inplace_consumed_regst_desc_id());\n      if (iterator_hint_inplaced_regst != regst_desc_id2reuse_regst_desc.end()) {\n        inplaced_regst = iterator_hint_inplaced_regst->second;\n        consumer_regst = iterator_hint_inplaced_regst->second;\n      } else {\n        break;\n      }\n    }\n    return inplaced_regst;\n  };\n\n  HashMap<int64_t, int64_t> regst_desc_id2free_index;\n  for (RegstDescProto* regst_desc : mem_reused_regsts) {\n    RegstDescProto* inplaced_regst_desc = TryFindFirstInplacedRegstDesc(regst_desc);\n    if (inplaced_regst_desc != nullptr) {\n      CHECK(consumer2inplaced_regst->emplace(regst_desc, inplaced_regst_desc).second);\n      continue;\n    }\n\n    alloc_regsts_timeline[task_id2sorted_id.at(regst_desc->producer_task_id())].push_back(\n        regst_desc);\n    CHECK(regst_desc_id2free_index\n              .emplace(regst_desc->regst_desc_id(), FindLastFreeIndexInSortedTasks(regst_desc))\n              .second);\n  }\n  // an inplace consumer extends the free index of its inplaced regst\n  for (const auto& pair : *consumer2inplaced_regst) {\n    RegstDescProto* consumer_regst_desc = pair.first;\n    int64_t inplaced_regst_desc_id = pair.second->regst_desc_id();\n    CHECK(regst_desc_id2free_index.find(inplaced_regst_desc_id) != regst_desc_id2free_index.end());\n    regst_desc_id2free_index.at(inplaced_regst_desc_id) =\n        std::max(regst_desc_id2free_index.at(inplaced_regst_desc_id),\n                 FindLastFreeIndexInSortedTasks(consumer_regst_desc));\n  }\n  for (const auto& pair : regst_desc_id2free_index) {\n    free_regsts_timeline[pair.second].push_back(regst_desc_id2reuse_regst_desc.at(pair.first));\n  }\n\n  HashSet<RegstDescProto*> remain_regsts;\n  size_t remain_memory = 0;\n  *peak_memory = 0;\n  for (int64_t i = 0; i < sorted_tasks.size(); ++i) {\n    for 
(RegstDescProto* alloc_regst : alloc_regsts_timeline.at(i)) {\n      // Record the born time\n      (*regst2lifetime)[alloc_regst].first = i;\n      CHECK(remain_regsts.insert(alloc_regst).second);\n      remain_memory += mem_reused_regst2size.at(alloc_regst);\n      // NOTE(chengcheng): insert time line to regst proto\n      alloc_regst->set_mem_block_total_actor_count(sorted_tasks.size());\n      alloc_regst->set_alloc_before_actor(i);\n    }\n    // Update the peak of memory during execution\n    if (*peak_memory < remain_memory) { *peak_memory = remain_memory; }\n    for (RegstDescProto* free_regst : free_regsts_timeline.at(i)) {\n      CHECK_EQ(remain_regsts.erase(free_regst), 1);\n      free_regst->set_free_after_actor(i);\n      remain_memory -= mem_reused_regst2size.at(free_regst);\n      // Record the die time\n      (*regst2lifetime)[free_regst].second = i + 1;\n    }\n  }\n  // Make sure that every register has a die time\n  CHECK(remain_regsts.empty());\n}\n\nvoid MemReusedLifetimeFirstAlgo(\n    const bool compact_insert,\n    const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& regst2lifetime,\n    const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size,\n    MemBlockResultInfo<RegstDescProto*>* result) {\n  std::vector<RegstDescProto*> order;\n  order.reserve(regst2lifetime.size());\n  for (const auto& pair : regst2lifetime) { order.emplace_back(pair.first); }\n  std::sort(order.begin(), order.end(), [&](RegstDescProto* lhs, RegstDescProto* rhs) {\n    int64_t l_value = regst2lifetime.at(lhs).second - regst2lifetime.at(lhs).first;\n    int64_t r_value = regst2lifetime.at(rhs).second - regst2lifetime.at(rhs).first;\n    if (l_value == r_value) { return regst2lifetime.at(lhs).first < regst2lifetime.at(rhs).first; }\n    return l_value > r_value;\n  });\n  MemReusedAlgorithmAllocateByOrder(compact_insert, order, mem_reused_regst2size, regst2lifetime,\n                                    result);\n}\n\nvoid MemReusedTimeLineAlgo(\n    const bool compact_insert,\n    const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& regst2lifetime,\n    const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size,\n    MemBlockResultInfo<RegstDescProto*>* result) {\n  std::vector<RegstDescProto*> order;\n  order.reserve(regst2lifetime.size());\n  for (const auto& pair : regst2lifetime) { order.emplace_back(pair.first); }\n  std::sort(order.begin(), order.end(), [&](RegstDescProto* lhs, RegstDescProto* rhs) {\n    int64_t l_value = regst2lifetime.at(lhs).first;\n    int64_t r_value = regst2lifetime.at(rhs).first;\n    if (l_value == r_value) {\n      return regst2lifetime.at(lhs).second > regst2lifetime.at(rhs).second;\n    }\n    return l_value > r_value;\n  });\n  MemReusedAlgorithmAllocateByOrder(compact_insert, order, mem_reused_regst2size, regst2lifetime,\n                                    result);\n}\n\nvoid MemReusedMemVolumeFirstAlgo(\n    const bool compact_insert,\n    const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& regst2lifetime,\n    const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size,\n    MemBlockResultInfo<RegstDescProto*>* result) {\n  std::vector<RegstDescProto*> order;\n  order.reserve(regst2lifetime.size());\n  auto ComputeMemoryVolume = [&](RegstDescProto* key) {\n    return mem_reused_regst2size.at(key)\n           * (regst2lifetime.at(key).second - regst2lifetime.at(key).first) / 1000;\n  };\n  for (const auto& pair : regst2lifetime) { order.emplace_back(pair.first); }\n  std::sort(order.begin(), order.end(), 
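// order by memory volume (size * lifetime), largest first; ties prefer the larger register\n            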
[&](RegstDescProto* lhs, RegstDescProto* rhs) {\n    size_t l_value = ComputeMemoryVolume(lhs);\n    size_t r_value = ComputeMemoryVolume(rhs);\n    if (l_value == r_value) {\n      return mem_reused_regst2size.at(lhs) > mem_reused_regst2size.at(rhs);\n    }\n    return l_value > r_value;\n  });\n  MemReusedAlgorithmAllocateByOrder(compact_insert, order, mem_reused_regst2size, regst2lifetime,\n                                    result);\n}\n\nvoid SelectAlgorithmGenMemBlockOffset4Regsts(\n    MemAllocAlgoType algo_id, const bool compact_insert,\n    const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& regst2lifetime,\n    const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size,\n    MemBlockResultInfo<RegstDescProto*>* result) {\n  CHECK_EQ(result->mem_block_size, 0);\n  CHECK(result->regst_desc2offset.empty());\n\n  switch (algo_id) {\n    case kMemSizeFirstAlgo:\n      MemReusedMemSizeFirstAlgo(compact_insert, regst2lifetime, mem_reused_regst2size, result);\n      break;\n    case kLifetimeFirstAlgo:\n      MemReusedLifetimeFirstAlgo(compact_insert, regst2lifetime, mem_reused_regst2size, result);\n      break;\n    case kTimeLineAlgo:\n      MemReusedTimeLineAlgo(compact_insert, regst2lifetime, mem_reused_regst2size, result);\n      break;\n    case kMemVolumeFirstAlgo:\n      MemReusedMemVolumeFirstAlgo(compact_insert, regst2lifetime, mem_reused_regst2size, result);\n      break;\n    default: UNIMPLEMENTED();\n  }\n  CHECK_GT(result->mem_block_size, 0);\n  CHECK(!result->regst_desc2offset.empty());\n}\n\nint64_t CountMemAllocAlgoNum() {\n  const MemoryAllocationAlgorithmConf& mem_alloc_algo_conf =\n      GlobalJobDesc().job_conf().memory_allocation_algorithm_conf();\n  int64_t alloc_algo_num = 0;\n  if (mem_alloc_algo_conf.use_mem_size_first_algo()) { ++alloc_algo_num; }\n  if (mem_alloc_algo_conf.use_lifetime_first_algo()) { ++alloc_algo_num; }\n  if (mem_alloc_algo_conf.use_time_line_algo()) { ++alloc_algo_num; }\n  if (mem_alloc_algo_conf.use_mem_volume_first_algo()) { ++alloc_algo_num; }\n  CHECK_GT(alloc_algo_num, 0) << \"Choose at least one type of memory allocation algorithm. We \"\n                                 \"recommend use_mem_size_first_algo()\";\n  const MemoryCompactInsertConf& mem_compact_insert_conf =\n      GlobalJobDesc().job_conf().memory_compact_insert_conf();\n  int64_t compact_insert_num = 0;\n  if (mem_compact_insert_conf.use_compact_insert()) { ++compact_insert_num; }\n  if (mem_compact_insert_conf.use_non_compact_insert()) { ++compact_insert_num; }\n  CHECK_GT(compact_insert_num, 0) << \"Choose at least one type of memory arrangement algorithm \"\n                                     \"during memory allocation. 
We recommend use_compact_insert()\";\n\n  return alloc_algo_num * compact_insert_num;\n}\n\nvoid InitAlgo2Result(\n    HashMap<std::pair<MemAllocAlgoType, bool>, MemBlockResultInfo<RegstDescProto*>>* algo2result) {\n  CHECK(algo2result->empty());\n  std::vector<bool> compact_insert_algorithms;\n  const MemoryCompactInsertConf& mem_compact_insert_conf =\n      GlobalJobDesc().job_conf().memory_compact_insert_conf();\n  if (mem_compact_insert_conf.use_compact_insert()) { compact_insert_algorithms.push_back(true); }\n  if (mem_compact_insert_conf.use_non_compact_insert()) {\n    compact_insert_algorithms.push_back(false);\n  }\n\n  const MemoryAllocationAlgorithmConf& mem_alloc_algo_conf =\n      GlobalJobDesc().job_conf().memory_allocation_algorithm_conf();\n  // NOTE: Experiments show that memory size first might be good enough for some cases.\n  for (auto compact_insert : compact_insert_algorithms) {\n    if (mem_alloc_algo_conf.use_mem_size_first_algo()) {\n      (*algo2result)[{kMemSizeFirstAlgo, compact_insert}] = MemBlockResultInfo<RegstDescProto*>();\n    }\n    if (mem_alloc_algo_conf.use_lifetime_first_algo()) {\n      (*algo2result)[{kLifetimeFirstAlgo, compact_insert}] = MemBlockResultInfo<RegstDescProto*>();\n    }\n    if (mem_alloc_algo_conf.use_time_line_algo()) {\n      (*algo2result)[{kTimeLineAlgo, compact_insert}] = MemBlockResultInfo<RegstDescProto*>();\n    }\n    if (mem_alloc_algo_conf.use_mem_volume_first_algo()) {\n      (*algo2result)[{kMemVolumeFirstAlgo, compact_insert}] = MemBlockResultInfo<RegstDescProto*>();\n    }\n  }\n}\n\n}  // namespace\n\nvoid IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(Plan* plan) {\n  // 1 device 1 mem chain\n  HashMap<int64_t, std::vector<TaskProto*>> mem_chain2sorted_tasks;\n  HashMap<int64_t, std::vector<RegstDescProto*>> mem_chain2mem_reused_regsts;\n  // NOTE: We only store those reusable registers in mem_chain2regst_desc_id2reuse_regst_desc.\n  //      There are no duplicated registers in different memory chains.\n  HashMap<int64_t, HashMap<int64_t, RegstDescProto*>> mem_chain2regst_desc_id2reuse_regst_desc;\n  HashMap<RegstDescProto*, size_t> mem_reused_regst2size;\n  GenMemChainTasksAndRegsts(plan, &mem_chain2sorted_tasks, &mem_chain2mem_reused_regsts,\n                            &mem_chain2regst_desc_id2reuse_regst_desc, &mem_reused_regst2size);\n  if (mem_chain2mem_reused_regsts.empty()) { return; }\n  HashSet<int64_t> mem_chains;\n  for (const auto& pair : mem_chain2mem_reused_regsts) { mem_chains.insert(pair.first); }\n  // register lifetime\n  HashMap<int64_t, HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>> mem_chain2regst2lifetime;\n  // info for inplace\n  HashMap<int64_t, HashMap<RegstDescProto*, RegstDescProto*>> mem_chain2consumer2inplaced_regst;\n  // info for straighten\n  HashMap<int64_t, size_t> mem_chain2peak_memory;\n\n  // step 1: generate regst alloc/free queue AND regst lifetimes\n  for (const auto& pair : mem_chain2mem_reused_regsts) {\n    GenRegstAllocFreeTimeLineAndRegstLifetimes(\n        mem_chain2sorted_tasks.at(pair.first), pair.second,\n        mem_chain2regst_desc_id2reuse_regst_desc.at(pair.first), mem_reused_regst2size,\n        &mem_chain2regst2lifetime[pair.first], &mem_chain2consumer2inplaced_regst[pair.first],\n        &mem_chain2peak_memory[pair.first]);\n  }\n\n  // step 2: run the chosen algorithms for each mem chain on a thread pool\n  HashMap<int64_t, HashMap<std::pair<MemAllocAlgoType, bool>, MemBlockResultInfo<RegstDescProto*>>>\n      mem_chain2algo2result;\n  {\n    int64_t 
work_size = mem_chain2mem_reused_regsts.size() * CountMemAllocAlgoNum();\n    int64_t thread_pool_size = std::min<int64_t>(work_size, std::thread::hardware_concurrency());\n    BlockingCounter counter(work_size);\n    ThreadPool thread_pool(thread_pool_size);\n    for (int64_t mem_chain_id : mem_chains) {\n      InitAlgo2Result(&mem_chain2algo2result[mem_chain_id]);\n      for (auto& pair : mem_chain2algo2result.at(mem_chain_id)) {\n        MemAllocAlgoType algo_id = pair.first.first;\n        bool compact_insert = pair.first.second;\n        MemBlockResultInfo<RegstDescProto*>* result = &pair.second;\n        thread_pool.AddWork([algo_id, compact_insert, mem_chain_id, &mem_chain2regst2lifetime,\n                             &mem_reused_regst2size, result, &counter]() {\n          SelectAlgorithmGenMemBlockOffset4Regsts(algo_id, compact_insert,\n                                                  mem_chain2regst2lifetime.at(mem_chain_id),\n                                                  mem_reused_regst2size, result);\n          counter.Decrease();\n        });\n      }\n    }\n    counter.WaitForeverUntilCntEqualZero();\n  }\n\n  // step 3: choose best one for each mem chain and set offset for inplace consumer regst\n  for (auto& pair : mem_chain2algo2result) {\n    MemBlockResultInfo<RegstDescProto*>* best_result = nullptr;\n    for (auto& algo_result_pair : pair.second) {\n      if (!best_result || algo_result_pair.second.mem_block_size < best_result->mem_block_size) {\n        best_result = &algo_result_pair.second;\n      }\n    }\n    CHECK(best_result != nullptr);\n\n    // Update the offset with a smaller total memory size if the current size is greater than the\n    // lower bound\n    if (GlobalJobDesc().job_conf().enable_compress_memory()) {\n      MemoryShareStrategy mss;\n      mss.AdaptivelyUpdateOffset(mem_reused_regst2size, mem_chain2regst2lifetime.at(pair.first),\n                                 mem_chain2peak_memory[pair.first], &best_result->mem_block_size,\n                                 &best_result->regst_desc2offset);\n    }\n\n    int64_t mem_block_id = Singleton<IDMgr>::Get()->NewMemBlockId();\n    CHECK_EQ(mem_chain2mem_reused_regsts.at(pair.first).size(),\n             (best_result->regst_desc2offset.size()\n              + mem_chain2consumer2inplaced_regst.at(pair.first).size()));\n    for (const auto& regst_offset_pair : best_result->regst_desc2offset) {\n      RegstDescProto* regst_desc = regst_offset_pair.first;\n      CHECK_EQ(regst_desc->mem_block_id(), -1);\n      regst_desc->set_mem_block_id(mem_block_id);\n      regst_desc->set_mem_block_offset(regst_offset_pair.second);\n    }\n    // set inplace\n    for (auto& consumer_inplace_pair : mem_chain2consumer2inplaced_regst.at(pair.first)) {\n      RegstDescProto* consumer_regst_desc = consumer_inplace_pair.first;\n      CHECK_EQ(consumer_regst_desc->mem_block_id(), -1);\n      RegstDescProto* inplaced_regst_desc = consumer_inplace_pair.second;\n      CHECK_EQ(inplaced_regst_desc->mem_block_id(), mem_block_id);\n      CHECK_NE(inplaced_regst_desc->mem_block_offset(), -1);\n      consumer_regst_desc->set_mem_block_id(inplaced_regst_desc->mem_block_id());\n      consumer_regst_desc->set_mem_block_offset(inplaced_regst_desc->mem_block_offset());\n    }\n\n    // set inplace hint and check\n    const auto& regst_desc_id2reuse_regst_desc =\n        mem_chain2regst_desc_id2reuse_regst_desc.at(pair.first);\n    for (auto& consumer_inplace_pair : mem_chain2consumer2inplaced_regst.at(pair.first)) {\n      
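// Re-derive each consumer's direct inplace hint and check that mem block id and offset\n      // agree along the (possibly nested) inplace chain.\n      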
RegstDescProto* consumer_regst_desc = consumer_inplace_pair.first;\n      RegstDescProto* inplaced_regst_desc = consumer_inplace_pair.second;\n      CHECK(!consumer_regst_desc->has_inplace_consumed_regst_desc_id());\n      CHECK(consumer_regst_desc->has_hint_inplace_consumed_regst_desc_id());\n      int64_t hint = consumer_regst_desc->hint_inplace_consumed_regst_desc_id();\n      // NOTE(chengcheng): hint regst desc id may NOT be the inplaced_regst_desc_id\n      //   because of nested inplace.\n      // NOTE: All the registers in mem_chain2consumer2inplaced_regst are reusable\n      auto hint_it = regst_desc_id2reuse_regst_desc.find(hint);\n      CHECK(hint_it != regst_desc_id2reuse_regst_desc.end());\n      RegstDescProto* in_regst_desc = hint_it->second;\n      CHECK_EQ(consumer_regst_desc->mem_block_id(), in_regst_desc->mem_block_id());\n      CHECK_EQ(consumer_regst_desc->mem_block_offset(), in_regst_desc->mem_block_offset());\n      CHECK_EQ(in_regst_desc->mem_block_offset(), inplaced_regst_desc->mem_block_offset());\n      CHECK_EQ(consumer_regst_desc->register_num(), in_regst_desc->register_num());\n      consumer_regst_desc->set_inplace_consumed_regst_desc_id(hint);\n    }\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/intra_job_mem_sharing_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_IN_JOB_MEM_SHARING_UTIL_H_\n#define ONEFLOW_CORE_JOB_IN_JOB_MEM_SHARING_UTIL_H_\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/hash_container.h\"\n#include \"oneflow/core/job/memory_share_strategy.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n#include <functional>\n#include <string>\n\nnamespace oneflow {\n\nstruct IntraJobMemSharingUtil {\n  static void InferMemBlockId4MemReusedRegst(Plan* plan);\n};\n\ntemplate<class T>\nstruct MemBlockResultInfo {\n  size_t mem_block_size;\n  HashMap<T, int64_t> regst_desc2offset;\n};\n\n// Judge whether a is suitable than b for a gap\ninline bool SuitableThan(int64_t a, int64_t b) {\n  // The number have orders\n  // A non-negative number is always more suitable than a negative number\n  // If a number is non-negative, then the smaller the better\n  // If a number is negative, then the larger the better\n  // 0 > 1 > 2 > ... > 999999999 > -1 > -2 > ... > -99999999\n  // Now we flip the positive part to make it \"the larger the better\".\n  if (a >= 0) { a = GetMaxVal<int64_t>() - a; }\n  if (b >= 0) { b = GetMaxVal<int64_t>() - b; }\n  return a > b;\n}\n\ntemplate<class T>\nvoid MemReusedAlgorithmAllocateByOrder(\n    const bool compact_insert, const std::vector<T>& order,\n    const HashMap<T, size_t>& regst_desc2size,\n    const HashMap<T, std::pair<int32_t, int32_t>>& regst2lifetime, MemBlockResultInfo<T>* result) {\n  HashMap<T, int64_t>* regst_desc2offset = &(result->regst_desc2offset);\n  // NOTE: It is important to make the variables local.\n  // It took me several days to find out that using passed-in vector for size, order, and lifetime\n  // would double the running time. 
So we copy them into local vectors below.\n  int32_t total_register_num = order.size();\n  std::vector<int64_t> order2size(total_register_num);\n  std::vector<std::pair<int32_t, int32_t>> order2lifetime(total_register_num);\n  std::vector<int64_t> order2offset(total_register_num);\n  for (int32_t i = 0; i < total_register_num; i++) {\n    order2size[i] = regst_desc2size.at(order[i]);\n    order2lifetime[i] = regst2lifetime.at(order[i]);\n  }\n  size_t buffer_size = 1;\n  // Sort by offset\n  auto comp = [&order2offset](const auto& a, const auto& b) {\n    if (order2offset[a] != order2offset[b]) { return order2offset[a] < order2offset[b]; }\n    // Make sure we have a stable order even if we have the same offset for different registers\n    return a < b;\n  };\n  std::set<int32_t, decltype(comp)> sorted_registers(comp);\n  // Decide offset following the given order\n  for (int32_t inserting_id = 0; inserting_id < total_register_num; inserting_id++) {\n    const auto& inserting_lifetime = order2lifetime[inserting_id];\n    // At the beginning, try to place the register at the front of the whole memory pool.\n    int64_t inserting_offset = 0;\n    int64_t inserting_end = inserting_offset + order2size[inserting_id];\n    if (compact_insert) {\n      // Find the most suitable gap for the register\n      int64_t gap_head = 0;\n      int64_t inserting_size = order2size[inserting_id];\n      // difference = length of gap - length of the inserting register\n      int64_t diff_gap = 0, suitable_diff_gap = -1 - inserting_size;\n      for (const auto& curr_register : sorted_registers) {\n        // Only excluded registers (whose lifetimes overlap the inserting one) constrain placement\n        if (IsLifetimeExcluded(inserting_lifetime, order2lifetime[curr_register])) {\n          if (gap_head < order2offset[curr_register]) {\n            // Found a gap in front of curr_register\n            diff_gap = (order2offset[curr_register] - gap_head) - inserting_size;\n            // Compare with the previously most suitable gap\n            if (SuitableThan(diff_gap, suitable_diff_gap)) {\n              suitable_diff_gap = diff_gap;\n              // We may insert the register into the gap\n              inserting_offset = gap_head;\n            }\n            // Update gap head\n            gap_head = order2offset[curr_register] + order2size[curr_register];\n          } else {\n            // No gap, update gap head\n            gap_head = std::max(gap_head, order2offset[curr_register] + order2size[curr_register]);\n          }\n        }\n      }\n      // Deal with the buffer_size, which may be the final gap\n      diff_gap = (buffer_size - gap_head) - inserting_size;\n      // Compare with the previously most suitable gap\n      if (SuitableThan(diff_gap, suitable_diff_gap)) {\n        suitable_diff_gap = diff_gap;\n        // We may insert the register into the gap\n        inserting_offset = gap_head;\n      }\n      // If no gap is large enough to contain the current register\n      if (suitable_diff_gap < 0) {\n        // Prolong the maximum memory pool size by (-suitable_diff_gap)\n        buffer_size -= suitable_diff_gap;\n        int64_t gap_end = suitable_diff_gap + inserting_size + inserting_offset;\n        for (auto reverse_it = sorted_registers.rbegin(); reverse_it != sorted_registers.rend();\n             reverse_it++) {\n          // All the registers with offset < gap_end maintain their position\n          if (order2offset[*reverse_it] < gap_end) { break; }\n          // All the registers with offset >= gap_end shift up by (-suitable_diff_gap)\n          order2offset[*reverse_it] -= suitable_diff_gap;\n        }\n    
  }\n\n    } else {\n      for (const auto& curr_register : sorted_registers) {\n        // i: inserting register, j: current register\n        // x: register offset, l: register size\n        // If x_i + l_i <= x_j, then the inserting register would be placed at x_i\n        if (order2offset[curr_register] >= inserting_end) { break; }\n        // If i and j are excluded, and x_i + l_i > x_j,\n        // then we try to place i at x_j + l_j and check the following registers\n        if (IsLifetimeExcluded(inserting_lifetime, order2lifetime[curr_register])) {\n          int64_t curr_end = order2offset[curr_register] + order2size[curr_register];\n          // Cannot set inserting offset = current end directly.\n          // We might have two excluded registers like this:\n          // register a: [100, 10000]\n          // register b: [500, 600]\n          // (b ends before a does, so taking b's end directly would move the offset backwards)\n          if (inserting_offset < curr_end) {\n            inserting_offset = curr_end;\n            inserting_end = inserting_offset + order2size[inserting_id];\n          }\n        }\n      }\n      // Update total size\n      if (inserting_end > buffer_size) { buffer_size = inserting_end; }\n    }\n    // Whether we broke out of the loop or it terminated naturally, we can now place i at\n    // inserting_offset\n    order2offset[inserting_id] = inserting_offset;\n    sorted_registers.insert(inserting_id);\n  }\n\n  result->mem_block_size = buffer_size;\n  // Copy the offsets back from the vector into the result HashMap\n  for (int32_t i = 0; i < total_register_num; i++) {\n    (*regst_desc2offset)[order[i]] = order2offset[i];\n  }\n}\n\ntemplate<class T>\nvoid MemReusedMemSizeFirstAlgo(const bool compact_insert,\n                               const HashMap<T, std::pair<int32_t, int32_t>>& regst2lifetime,\n                               const HashMap<T, size_t>& mem_reused_regst2size,\n                               MemBlockResultInfo<T>* result) {\n  std::vector<T> order;\n  order.reserve(regst2lifetime.size());\n  for (const auto& pair : regst2lifetime) { order.emplace_back(pair.first); }\n  std::sort(order.begin(), order.end(), [&](T lhs, T rhs) {\n    size_t l_value = mem_reused_regst2size.at(lhs);\n    size_t r_value = mem_reused_regst2size.at(rhs);\n    if (l_value == r_value) { return regst2lifetime.at(lhs).first < regst2lifetime.at(rhs).first; }\n    return l_value > r_value;\n  });\n  MemReusedAlgorithmAllocateByOrder(compact_insert, order, mem_reused_regst2size, regst2lifetime,\n                                    result);\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_IN_JOB_MEM_SHARING_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/job/job.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/job/dlnet_conf.proto\";\nimport \"oneflow/core/job/placement.proto\";\nimport \"oneflow/core/job/job_conf.proto\";\nimport \"oneflow/core/register/logical_blob_id.proto\";\nimport \"oneflow/core/register/op_blob_arg.proto\";\nimport \"oneflow/core/register/blob_desc.proto\";\nimport \"oneflow/core/operator/op_conf.proto\";\nimport \"oneflow/core/job/sbp_parallel.proto\";\nimport \"oneflow/core/job/module_conf.proto\";\n\nmessage JobParallelViewConf {\n  map<string, SbpSignature> op_name2sbp_signature_conf = 1;\n  map<string, bool> op_name2is_local_parallel_view = 2;\n  map<string, NdSbpSignature> op_name2nd_sbp_signature_conf = 3;\n}\n\nmessage JobHelperConf {\n  map<string, LogicalBlobIdPairs> tag2lbi_relations = 1;\n  map<string, OpNameRelations> tag2op_name_relations = 2;\n  map<string, BlobDescProto> lbn2logical_blob_desc = 4;\n  map<string, int64> lbn2logical_object_id = 5;\n  map<string, ArgSignature> op_name2arg_signature = 9;\n}\n\nmessage MergedLogicalChainIdGroup {\n  repeated int64 logical_chain_id_list = 1;\n}\n\nmessage Job {\n  optional DLNetConf net = 1;\n  optional Placement placement = 2;\n  required JobConfigProto job_conf = 3;\n  optional JobParallelViewConf job_parallel_view_conf = 4;\n  optional JobHelperConf helper = 5;\n  map<string, ModuleConf> module_name2module_conf = 6;\n  repeated MergedLogicalChainIdGroup logical_chain_groups = 7;\n}\n"
  },
  {
    "path": "oneflow/core/job/job_build_and_infer_ctx.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/cost_util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/framework/config_def.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/framework/scope_util.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx.h\"\n#include \"oneflow/core/job/local_sig_infer_hint.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/job_rewriter/autograd.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/user/summary/summary_converter.h\"\n\n#include <google/protobuf/text_format.h>\n#include \"nlohmann/json.hpp\"\n\nnamespace oneflow {\n\nstatic const std::string kAutoLocalBlobNamePrefix =\n    \"System-Local-Blob-Auto-Converted-From-Global-Blob\";\n\nnamespace {\n\nvoid ResetOpConfName(OperatorConf* op_conf, const std::string& new_op_name) {\n  op_conf->set_name(new_op_name);\n  PbMessage* op_type_conf = MutableMessageInPbMessage(op_conf, op_conf->op_type_case());\n  UserOpConf* user_conf = dynamic_cast<UserOpConf*>(op_type_conf);\n  if (user_conf) {\n    for (const auto& pair : user_conf->output()) {\n      for (const std::string& old_lbn : pair.second.s()) {\n        LogicalBlobId old_lbi = GenLogicalBlobId(old_lbn);\n        auto blob_name_id_pair = GenUnRepeatedBn(old_lbi.blob_name());\n        std::string new_lbn = GenLogicalBlobName(new_op_name, old_lbi.blob_name());\n        (*(user_conf->mutable_output()))[pair.first].set_s(blob_name_id_pair.second, new_lbn);\n      }\n    }\n  }\n}\n\nMaybe<void> GetOpNames(const Job& job, HashSet<std::string>* op_names) {\n  for (const auto& op_conf : job.net().op()) {\n    CHECK_OR_RETURN(op_names->insert(op_conf.name()).second);\n  }\n  return Maybe<void>::Ok();\n}\n\nvoid UpdateOpName2AncestorsNeedNoGrad(\n    const Operator& op, const std::function<const Operator*(const std::string&)>& Op4OpName,\n    const bool is_train, HashMap<std::string, bool>* op_name2ancestors_need_no_grad) {\n  bool no_grad = !is_train;\n  auto IsTrainableVariableLbi = [&](const LogicalBlobId& lbi) {\n    const auto& op_conf = Op4OpName(lbi.op_name())->op_conf();\n    return op_conf.has_variable_conf() && op_conf.variable_conf().trainable();\n  };\n  for (const auto& ibn : op.input_bns()) {\n    const auto& lbi = op.BnInOp2Lbi(ibn);\n    no_grad = no_grad && !IsTrainableVariableLbi(lbi);\n    no_grad = no_grad && !op.InputBlobModifier4Ibn(ibn).requires_grad();\n    no_grad = no_grad && (*op_name2ancestors_need_no_grad)[lbi.op_name()];\n  }\n  (*op_name2ancestors_need_no_grad)[op.op_name()] = no_grad;\n}\n\n}  // namespace\n\nJobBuildAndInferCtx::JobBuildAndInferCtx(Job* job, int64_t job_id)\n    : job_(job), job_id_(job_id), unique_op_name_index_(0) {\n  is_job_conf_frozen_ = false;\n  has_job_conf_ = 
false;\n}\n\nMaybe<void> JobBuildAndInferCtx::SetJobConf(const JobConfigProto& job_conf) {\n  CHECK_OR_RETURN(!is_job_conf_frozen_) << Error::JobConfFrozenError();\n  CHECK_OR_RETURN(!has_job_conf_) << Error::JobConfRepeatedSetError();\n  has_job_conf_ = true;\n  CHECK_EQ_OR_RETURN(job_->job_conf().job_name(), job_conf.job_name())\n      << Error::JobNameNotEqualError() << \"job name you set: \" << job_conf.job_name()\n      << \" not equal to origin job name: \" << job_->job_conf().job_name();\n  job_->mutable_job_conf()->CopyFrom(job_conf);\n  CHECK_ISNULL_OR_RETURN(Singleton<JobDesc>::Get());\n  Singleton<JobDesc>::New(job_conf, job_id_);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> JobBuildAndInferCtx::AddOpNameParallelConf2Placement(\n    const std::string& op_name, const ParallelConf& parallel_conf) {\n  ParallelDesc parallel_desc(parallel_conf);\n  PlacementGroup* pg = nullptr;\n  if (parallel_desc2placement_group_.find(parallel_desc) == parallel_desc2placement_group_.end()) {\n    pg = job_->mutable_placement()->add_placement_group();\n    parallel_desc2placement_group_.emplace(parallel_desc, pg);\n    *(pg->mutable_parallel_conf()) = parallel_conf;\n  } else {\n    pg = parallel_desc2placement_group_.at(parallel_desc);\n  }\n  pg->mutable_op_set()->add_op_name(op_name);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> JobBuildAndInferCtx::AddLbiParallelConf2BlobPlacement(\n    const Operator* op, std::function<ParallelDesc*(const std::string&)> ParallelDesc4Obn) {\n  for (const auto& obn : op->output_bns()) {\n    const auto& parallel_desc = *ParallelDesc4Obn(obn);\n    auto iter = parallel_desc2blob_placement_group_.find(parallel_desc);\n    if (iter == parallel_desc2blob_placement_group_.end()) {\n      auto* blob_pg = job_->mutable_placement()->add_blob_placement_group();\n      *blob_pg->mutable_parallel_conf() = parallel_desc.parallel_conf();\n      iter = parallel_desc2blob_placement_group_.emplace(parallel_desc, blob_pg).first;\n    }\n    const auto& lbi = op->BnInOp2Lbi(obn);\n    *iter->second->add_lbi() = lbi;\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<OperatorConf> JobBuildAndInferCtx::DecodeLbiHintAndReturnNewOpConf(\n    const Operator& op, SbpSignature* sbp_sig_conf) const {\n  auto op_conf_without_split_hint = std::make_shared<OperatorConf>(op.op_conf());\n  for (const std::string& ibn : op.input_bns()) {\n    std::string lbn_may_with_hint = GetInputLbnInOpCustomizedConf(op.op_conf(), ibn);\n    SbpParallel sbp_parallel;\n    bool has_sbp_hint = JUST(GetSbpParallelInLbnOrNothing(lbn_may_with_hint, &sbp_parallel));\n    if (has_sbp_hint) {\n      (*(sbp_sig_conf->mutable_bn_in_op2sbp_parallel()))[ibn] = sbp_parallel;\n      const LogicalBlobId& lbi = op.BnInOp2Lbi(ibn);\n      std::string lbn = GenLogicalBlobName(lbi);\n      CHECK_EQ_OR_RETURN(lbn_may_with_hint, ReplaceInputLbnInOpCustomizedConf(\n                                                op_conf_without_split_hint.get(), ibn, lbn));\n    }\n  }\n  return op_conf_without_split_hint;\n}\n\nvoid JobBuildAndInferCtx::AddOpAndUpdateJobParallelViewConf(const OperatorConf& operator_conf,\n                                                            const ParallelDesc& parallel_desc,\n                                                            const NdSbpSignature& nd_sbp_signature,\n                                                            bool is_local_parallel_view) const {\n  auto* op_name2sbp_sig =\n      job_->mutable_job_parallel_view_conf()->mutable_op_name2sbp_signature_conf();\n  auto* 
op_name2nd_sbp_sig =\n      job_->mutable_job_parallel_view_conf()->mutable_op_name2nd_sbp_signature_conf();\n  if (nd_sbp_signature.bn_in_op2nd_sbp().size() > 0) {\n    (*op_name2nd_sbp_sig)[operator_conf.name()] = nd_sbp_signature;\n    if (parallel_desc.hierarchy()->NumAxes() == 1) {\n      SbpSignature sbp_signature;\n      NdSbpSignatureToSbpSignature(nd_sbp_signature, &sbp_signature);\n      (*op_name2sbp_sig)[operator_conf.name()] = sbp_signature;\n    }\n  }\n  auto* op_name2is_local_parallel_view =\n      job_->mutable_job_parallel_view_conf()->mutable_op_name2is_local_parallel_view();\n  if (is_local_parallel_view) { (*op_name2is_local_parallel_view)[operator_conf.name()] = true; }\n  job_->mutable_net()->add_op()->CopyFrom(operator_conf);\n\n  // set up the module config\n  const auto& scope =\n      Singleton<symbol::Storage<Scope>>::Get()->Get(operator_conf.scope_symbol_id());\n  if (scope.scope_proto().has_module_name()) {\n    const auto& module_name = scope.scope_proto().module_name();\n    auto* module_name2module_conf = job_->mutable_module_name2module_conf();\n    if (!(*module_name2module_conf)[module_name].has_name()) {\n      (*module_name2module_conf)[module_name].set_name(scope.scope_proto().module_name());\n    }\n\n    *((*module_name2module_conf)[module_name].add_ops()) = operator_conf.name();\n  }\n}\n\nMaybe<void> JobBuildAndInferCtx::InferLocalSignature(Operator* op, bool is_local_parallel_view_conf,\n                                                     const ParallelDesc& parallel_desc) {\n  HashMap<std::string, LocalSigInferHint> ibn2local_sig_infer_hint;\n  for (const std::string& ibn : op->input_bns()) {\n    const LogicalBlobId& lbi = op->BnInOp2Lbi(ibn);\n    CHECK_OR_RETURN(lbi2logical_blob_desc_.find(lbi) != lbi2logical_blob_desc_.end())\n        << Error::LogicalBlobNameNotExistError()\n        << \"infer blob desc not found, when infer op_name: \\\"\" << op->op_name()\n        << \"\\\", consumed op_name: \\\"\" << lbi.op_name() << \"\\\", blob_name: \\\"\" << lbi.blob_name();\n    const ParallelDesc* pd = &lbi2parallel_desc_from_producer_view_.at(lbi);\n    const auto* producer_op = op_name2op_.at(lbi.op_name()).get();\n    const auto& producer_obn = *JUST(producer_op->obn4lbi(lbi));\n    const auto& opt_local_parallel =\n        *CHECK_JUST(producer_op->OptLocalParallel4BnInOp(producer_obn));\n    ibn2local_sig_infer_hint.emplace(\n        ibn, LocalSigInferHint(pd, opt_local_parallel.has_local_parallel()));\n  }\n  const auto& LocalSigInferHint4Ibn =\n      [&](const std::string& ibn) -> Maybe<const LocalSigInferHint*> {\n    const auto& iter = ibn2local_sig_infer_hint.find(ibn);\n    CHECK_OR_RETURN(iter != ibn2local_sig_infer_hint.end()) << \"input blob not found. 
ibn: \" << ibn;\n    return &iter->second;\n  };\n  JUST(\n      op->InferLocalSignatureIf(LocalSigInferHint4Ibn, is_local_parallel_view_conf, parallel_desc));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> JobBuildAndInferCtx::InferOpOutNdSbp(Operator* op,\n                                                 const NdSbpSignature& nd_sbp_sig_conf,\n                                                 const ParallelDesc& parallel_desc) {\n  HashMap<std::string, NdSbpInferHint> ibn2nd_sbp_infer_hint;\n  for (const std::string& ibn : op->input_bns()) {\n    const LogicalBlobId& lbi = op->BnInOp2Lbi(ibn);\n    auto logical_blob_desc_it = lbi2logical_blob_desc_.find(lbi);\n    CHECK_OR_RETURN(logical_blob_desc_it != lbi2logical_blob_desc_.end())\n        << Error::LogicalBlobNameNotExistError()\n        << \"infer blob desc not found, when infer op_name: \\\"\" << op->op_name()\n        << \"\\\", consumed op_name: \\\"\" << lbi.op_name() << \"\\\", blob_name: \\\"\" << lbi.blob_name();\n    const BlobDesc* logical_blob_desc = logical_blob_desc_it->second.get();\n    const ParallelDesc* pd = &lbi2parallel_desc_from_producer_view_.at(lbi);\n    auto nd_sbp_it = lbi2nd_sbp_from_producer_view_.find(lbi);\n    CHECK_OR_RETURN(nd_sbp_it != lbi2nd_sbp_from_producer_view_.end())\n        << Error::LogicalBlobNameNotExistError() << \"when infer op_name: \" << op->op_name()\n        << \" consumed op_name: \" << lbi.op_name() << \" blob_name: \" << lbi.blob_name()\n        << \" not infer parallel distribution\";\n    const NdSbp* nd_sbp = &nd_sbp_it->second;\n    ibn2nd_sbp_infer_hint.emplace(ibn, NdSbpInferHint(pd, logical_blob_desc, nd_sbp));\n  }\n\n  const auto NdSbpInferHint4Ibn = [&](const std::string& bn) -> Maybe<const NdSbpInferHint*> {\n    return &ibn2nd_sbp_infer_hint.at(bn);\n  };\n\n  JUST(op->InferNdSbpSignatureIf(nd_sbp_sig_conf, parallel_desc, NdSbpInferHint4Ibn));\n\n  const auto& bn2nd_sbp = JUST(op->nd_sbp_signature())->bn_in_op2nd_sbp();\n  for (const auto& obn : op->output_bns()) {\n    const LogicalBlobId& lbi = op->BnInOp2Lbi(obn);\n    CHECK_OR_RETURN(bn2nd_sbp.find(obn) != bn2nd_sbp.end())\n        << Error::BlobSplitAxisInferError() << \"op_name: \" << lbi.op_name()\n        << \" blob_name: \" << lbi.blob_name() << \" not infer split axis\";\n    CHECK_OR_RETURN(lbi2nd_sbp_from_producer_view_.emplace(lbi, bn2nd_sbp.at(obn)).second)\n        << Error::BlobSplitAxisInferError() << \"op_name: \" << lbi.op_name()\n        << \" blob_name: \" << lbi.blob_name() << \" infer split axis repeated\";\n    CHECK_OR_RETURN(lbi2parallel_desc_from_producer_view_.emplace(lbi, parallel_desc).second)\n        << Error::BlobSplitAxisInferError() << \"op_name: \" << lbi.op_name()\n        << \" blob_name: \" << lbi.blob_name() << \" add parallel desc repeated\";\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> JobBuildAndInferCtx::GenOpProducedEmptyLogicalBlobDesc(Operator* op) {\n  // check consumed blob\n  for (const std::string& consumed_bn : op->input_bns()) {\n    const LogicalBlobId& lbi = op->BnInOp2Lbi(consumed_bn);\n    CHECK_OR_RETURN(lbi2logical_blob_desc_.find(lbi) != lbi2logical_blob_desc_.end())\n        << Error::LogicalBlobNameNotExistError() << \"op_name: \" << op->op_name()\n        << \" consumed_op_name:\" << lbi.op_name() << \" blob_name: \" << lbi.blob_name()\n        << \" not exist\";\n  }\n\n  // create produced blob\n  std::vector<std::string> produced_bns;\n  produced_bns.reserve(op->output_bns().size() + op->tmp_bns().size());\n  produced_bns.insert(produced_bns.end(), 
op->output_bns().begin(), op->output_bns().end());\n  produced_bns.insert(produced_bns.end(), op->tmp_bns().begin(), op->tmp_bns().end());\n  for (const std::string& produced_bn : produced_bns) {\n    const LogicalBlobId& lbi = op->BnInOp2Lbi(produced_bn);\n    CHECK_OR_RETURN(lbi2logical_blob_desc_.find(lbi) == lbi2logical_blob_desc_.end())\n        << Error::LogicalBlobNameExistError()\n        << \"duplicate logical blob name found. op_name: \" << lbi.op_name()\n        << \" blob_name: \" << lbi.blob_name();\n    lbi2logical_blob_desc_.emplace(\n        lbi, std::make_unique<BlobDesc>(DataType::kInvalidDataType, MemoryFormat::kContiguous));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> JobBuildAndInferCtx::CheckOpBlobSplitability(Operator* op, int64_t parallel_num) {\n  const auto& parallel_hierarchy = JUST(op->GetOpParallelDesc())->hierarchy();\n  if (parallel_hierarchy->NumAxes() == 1) {\n    HashSet<std::string> obns(op->output_bns().begin(), op->output_bns().end());\n    auto GetParallelNum = [&](const std::string& bn_in_op) {\n      if (obns.find(bn_in_op) == obns.end()) { return parallel_num; }\n      return lbi2parallel_desc_from_producer_view_.at(op->BnInOp2Lbi(bn_in_op)).parallel_num();\n    };\n    for (const auto& pair : JUST(op->sbp_signature())->bn_in_op2sbp_parallel()) {\n      if (!pair.second.has_split_parallel()) { continue; }\n      if (JUST(op->OptLocalParallel4BnInOp(pair.first))->has_local_parallel()) { continue; }\n      int64_t axis = pair.second.split_parallel().axis();\n      const LogicalBlobId& lbi = op->BnInOp2Lbi(pair.first);\n      int64_t blob_parallel_num = GetParallelNum(pair.first);\n      const BlobDesc& logical_blob_desc = *(lbi2logical_blob_desc_.at(lbi).get());\n      int64_t num_axes = logical_blob_desc.shape().NumAxes();\n      if (axis < 0) { axis += num_axes; }\n      CHECK_GE_OR_RETURN(axis, 0);\n      CHECK_LE_OR_RETURN(axis, num_axes)\n          << \"op: \" << op->op_name() << \", blob: \" << pair.first << \", axis: \" << axis\n          << \", shape: \" << logical_blob_desc.shape();\n      if (logical_blob_desc.shape().NumAxes() > 0) {\n        CHECK_GE_OR_RETURN(logical_blob_desc.shape().At(axis), blob_parallel_num)\n            << \"op_name: \" << lbi.op_name() << \" blob_name: \" << lbi.blob_name()\n            << \" shape: \" << logical_blob_desc.shape()\n            << \" cannot be split by parallel_num: \" << blob_parallel_num << \" at axis \" << axis;\n      }\n    }\n  } else {\n    for (const auto& pair : JUST(op->nd_sbp_signature())->bn_in_op2nd_sbp()) {\n      if (JUST(op->OptLocalParallel4BnInOp(pair.first))->has_local_parallel()) { continue; }\n      const LogicalBlobId& lbi = op->BnInOp2Lbi(pair.first);\n      const BlobDesc& logical_blob_desc = *(lbi2logical_blob_desc_.at(lbi).get());\n      Shape current_shape = logical_blob_desc.shape();\n      for (int64_t i = 0; i < pair.second.sbp_parallel_size(); ++i) {\n        const SbpParallel& sbp_parallel = pair.second.sbp_parallel(i);\n        if (sbp_parallel.has_split_parallel()) {\n          const int64_t axis = sbp_parallel.split_parallel().axis();\n          CHECK_GT_OR_RETURN(current_shape.At(axis), 0);\n          // Support unbalanced splitting\n          CHECK_GE_OR_RETURN(current_shape.At(axis), parallel_hierarchy->At(i))\n              << \"op_name: \" << lbi.op_name() << \" blob_name: \" << lbi.blob_name()\n              << \" shape: \" << logical_blob_desc.shape()\n              << \" cannot be split by nd sbp: \" << NdSbpToString(pair.second) << \" at axis 
\"\n              << axis << \" with parallel_hierarchy: \" << *parallel_hierarchy;\n          // Split and take the minimum one\n          current_shape.Set(axis, current_shape.At(axis) / parallel_hierarchy->At(i));\n        }\n      }\n    }\n  }\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<ParallelConf> JobBuildAndInferCtx::InferOpParallelConf(\n    const Operator& op, const ParallelConf& origin_parallel_conf,\n    const HashMap<std::string, bool>& ibn2disable_boxing) const {\n  const ParallelDesc* parallel_desc = nullptr;\n  for (const auto& ibn : op.input_bns()) {\n    if (ibn2disable_boxing.at(ibn) == false) { continue; }\n    const auto& lbi = op.BnInOp2Lbi(ibn);\n    const auto& ibn_parallel_desc = lbi2parallel_desc_from_producer_view_.at(lbi);\n    if (parallel_desc == nullptr) {\n      parallel_desc = &ibn_parallel_desc;\n    } else {\n      CHECK_EQ_OR_RETURN(parallel_desc->parallel_num(), ibn_parallel_desc.parallel_num());\n    }\n  }\n  if (parallel_desc == nullptr) { return std::make_shared<ParallelConf>(origin_parallel_conf); }\n  return std::make_shared<ParallelConf>(parallel_desc->parallel_conf());\n}\n\nvoid JobBuildAndInferCtx::InitIbn2DisableBoxing(const Operator& op,\n                                                HashMap<std::string, bool>* ibn2disable_boxing) {\n  for (const auto& ibn : op.input_bns()) {\n    (*ibn2disable_boxing)[ibn] = lbi2disable_boxing_[op.BnInOp2Lbi(ibn)];\n  }\n}\n\nMaybe<NdSbpSignature> JobBuildAndInferCtx::InitConstraitNdSbpSignature(\n    const Operator& op, const HashMap<std::string, bool>& ibn2disable_boxing) const {\n  auto nd_sbp_sig = std::make_shared<NdSbpSignature>();\n  for (const auto& it : ibn2disable_boxing) {\n    if (it.second) {\n      const auto& ibn = it.first;\n      const LogicalBlobId& lbi = op.BnInOp2Lbi(ibn);\n      const auto& nd_sbp_iter = lbi2nd_sbp_from_producer_view_.find(lbi);\n      if (nd_sbp_iter == lbi2nd_sbp_from_producer_view_.end()) {\n        return Error::RuntimeError()\n               << \"The nd_sbp of input \" << ibn << \" (tensor name is \" << GenLogicalBlobName(lbi)\n               << \") is not found for operation \" << op.op_name()\n               << \". 
It may be caused by an invalid inplace operation.\";\n      }\n      (*(nd_sbp_sig->mutable_bn_in_op2nd_sbp()))[ibn] = lbi2nd_sbp_from_producer_view_.at(lbi);\n    }\n  }\n  return nd_sbp_sig;\n}\n\nbool JobBuildAndInferCtx::HasAnyLocalBlobInput(const Operator& op) const {\n  for (const auto& ibn : op.input_bns()) {\n    const auto& lbi = op.BnInOp2Lbi(ibn);\n    if (local_lbi2sub_lbis_.find(lbi) != local_lbi2sub_lbis_.end()) { return true; }\n  }\n  return false;\n}\n\nMaybe<const SbpParallel*> JobBuildAndInferCtx::SbpParallel4Lbi(const LogicalBlobId& lbi) const {\n  const auto& iter = lbi2nd_sbp_from_producer_view_.find(lbi);\n  CHECK_OR_RETURN(iter != lbi2nd_sbp_from_producer_view_.end())\n      << \"lbn: \" << GenLogicalBlobName(lbi) << \" undefined\";\n  CHECK_EQ_OR_RETURN(iter->second.sbp_parallel_size(), 1);\n  return &(iter->second.sbp_parallel(0));\n}\n\nMaybe<const ParallelDesc*> JobBuildAndInferCtx::ParallelDesc4Lbi(const LogicalBlobId& lbi) const {\n  const auto& iter = lbi2parallel_desc_from_producer_view_.find(lbi);\n  CHECK_OR_RETURN(iter != lbi2parallel_desc_from_producer_view_.end())\n      << \"lbn: \" << GenLogicalBlobName(lbi) << \" undefined\";\n  return &iter->second;\n}\n\nMaybe<bool> JobBuildAndInferCtx::AllInputsBroadcastParallel(const Operator& op) const {\n  for (const auto& ibn : op.input_bns()) {\n    const LogicalBlobId& lbi = op.BnInOp2Lbi(ibn);\n    const auto& iter = local_lbi2sbp_parallel_.find(lbi);\n    if (iter != local_lbi2sbp_parallel_.end()) {\n      if (!iter->second.has_broadcast_parallel()) { return false; }\n    } else {\n      if (!JUST(SbpParallel4Lbi(lbi))->has_broadcast_parallel()) { return false; }\n    }\n  }\n  return true;\n}\n\nbool JobBuildAndInferCtx::IsVariableLbi(const LogicalBlobId& lbi) const {\n  return op_name2op_.at(lbi.op_name())->op_conf().has_variable_conf();\n}\n\nMaybe<void> JobBuildAndInferCtx::CheckAllInputsConvertableToLocalBlob(const Operator& op) const {\n  for (const auto& ibn : op.input_bns()) {\n    const auto& lbi = op.BnInOp2Lbi(ibn);\n    if (local_lbi2sub_lbis_.find(lbi) != local_lbi2sub_lbis_.end()) { continue; }\n    const auto& sbp = *JUST(SbpParallel4Lbi(lbi));\n    if (sbp.has_broadcast_parallel()) { continue; }\n    if (sbp.has_split_parallel() && sbp.split_parallel().axis() == 0) { continue; }\n    const std::string& lbn = GenLogicalBlobName(lbi);\n    return Error::CheckFailedError() << \"input lbn: \" << lbn << \" is not convertible to local blob\";\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LazyJobBuildAndInferCtx::CheckAllInputsWithSameParallelNum(const Operator& op,\n                                                                       int32_t parallel_num) const {\n  for (const auto& ibn : op.input_bns()) {\n    const auto& lbi = op.BnInOp2Lbi(ibn);\n    const auto& iter = local_lbi2sub_lbis().find(lbi);\n    int32_t ibn_parallel_num = 0;\n    if (iter != local_lbi2sub_lbis().end()) {\n      ibn_parallel_num = iter->second.size();\n    } else {\n      ibn_parallel_num = JUST(ParallelDesc4Lbi(lbi))->parallel_num();\n    }\n    CHECK_EQ_OR_RETURN(ibn_parallel_num, parallel_num)\n        << \"the parallel_num of input lbn: \" << GenLogicalBlobName(lbi)\n        << \" is not equal to the op's parallel_num\";\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<OpAttribute> JobBuildAndInferCtx::AddAndInferLocalOp(const OperatorConf& op_conf) {\n  CHECK_OR_RETURN(op_conf.has_scope_symbol_id());\n  const auto& scope = Singleton<symbol::Storage<Scope>>::Get()->Get(op_conf.scope_symbol_id());\n  const auto* 
job_desc = JUST(scope.job_desc());\n  const auto& parallel_desc = *JUST(scope.GetParallelDesc(op_conf));\n  auto op = JUST(ConstructOp(op_conf, parallel_desc.device_type()));\n  JUST(CheckAllInputsConvertableToLocalBlob(*op));\n  int32_t parallel_num = parallel_desc.parallel_num();\n  JUST(CheckAllInputsWithSameParallelNum(*op, parallel_num));\n  auto GetSubOpName = [&](int index) { return GetLocalOpName(op_conf.name(), index); };\n  OperatorConf sub_op_conf(op_conf);\n  int64_t sub_op_list_size = SizeOfSubGlobalOpList(parallel_num);\n  auto last_op_attribute = std::make_shared<OpAttribute>();\n  FOR_RANGE(int32_t, i, 0, sub_op_list_size) {\n    ResetOpConfName(&sub_op_conf, GetSubOpName(i));\n    for (const auto& ibn : op->input_bns()) {\n      const auto& lbi = *JUST(GetSubLbi(op_conf.scope_symbol_id(), op->BnInOp2Lbi(ibn), i));\n      ReplaceInputLbnInOpCustomizedConf(&sub_op_conf, ibn, GenLogicalBlobName(lbi));\n    }\n    const ParallelConf& parallel_conf = GetLocalOpParallelConf(parallel_desc, i);\n    bool is_local_parallel_view = GetIsLocalParallelView();\n    last_op_attribute =\n        JUST(AddAndInferOp(sub_op_conf, parallel_conf, job_desc, is_local_parallel_view));\n  }\n  bool is_broadcast = JUST(AllInputsBroadcastParallel(*op));\n  for (const auto& obn : op->output_bns()) {\n    const auto& lbi = op->BnInOp2Lbi(obn);\n    auto* sub_lbis = &local_lbi2sub_lbis_[lbi];\n    sub_lbis->resize(sub_op_list_size, op->BnInOp2Lbi(obn));\n    FOR_RANGE(int32_t, i, 0, sub_op_list_size) { sub_lbis->at(i).set_op_name(GetSubOpName(i)); }\n    CHECK(local_lbi2parallel_desc_.emplace(lbi, parallel_desc).second);\n    auto* sbp_parallel = &local_lbi2sbp_parallel_[lbi];\n    if (is_broadcast) {\n      sbp_parallel->mutable_broadcast_parallel();\n    } else {\n      sbp_parallel->mutable_split_parallel()->set_axis(0);\n    }\n  }\n  return last_op_attribute;\n}\n\nMaybe<const LogicalBlobId*> JobBuildAndInferCtx::GetSubLbi(int64_t scope_symbol_id,\n                                                           const LogicalBlobId& lbi,\n                                                           int32_t index) {\n  auto lbi_vec_iter = local_lbi2sub_lbis_.find(lbi);\n  if (lbi_vec_iter == local_lbi2sub_lbis_.end()) {\n    const auto& new_lbi = JUST(FindOrCreateLocalLbiFromCompatibleGlobalBlob(scope_symbol_id, lbi));\n    lbi_vec_iter = local_lbi2sub_lbis_.find(*new_lbi);\n    CHECK(lbi_vec_iter != local_lbi2sub_lbis_.end());\n  }\n  return &lbi_vec_iter->second.at(index);\n}\n\nMaybe<OpAttribute> JobBuildAndInferCtx::AddAndInferGlobalOp(const OperatorConf& op_conf) {\n  CHECK_OR_RETURN(op_conf.has_scope_symbol_id());\n  const auto& scope = Singleton<symbol::Storage<Scope>>::Get()->Get(op_conf.scope_symbol_id());\n  const auto& parallel_desc = *JUST(scope.GetParallelDesc(op_conf));\n  const auto* job_desc = JUST(scope.job_desc());\n  return AddAndInferOp(op_conf, parallel_desc.parallel_conf(), job_desc, false);\n}\n\n// TODO(): add handle error of same interface op blob between jobs\nMaybe<OpAttribute> JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_conf,\n                                                      const ParallelConf& origin_parallel_conf,\n                                                      const JobDesc* job_desc,\n                                                      bool is_local_parallel_view) {\n  CHECK_OR_RETURN(has_job_conf_) << Error::JobConfNotSetError();\n  if (!is_job_conf_frozen_) { is_job_conf_frozen_ = true; }\n  const std::string& op_name = op_conf.name();\n  
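// reject duplicate op names and missing device tags before constructing the op\n  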
CHECK_OR_RETURN(op_name2op_.find(op_name) == op_name2op_.end())\n      << Error::OpNameExistError() << \"op_name: \" << op_name\n      << \" already exist in job: \" << job_->job_conf().job_name();\n  CHECK_NE_OR_RETURN(op_conf.device_tag(), \"invalid_device\")\n      << Error::OpConfDeviceTagNoSetError() << \"op_name: \" << op_name << \" not set device tag\";\n\n  op_name2op_.emplace(op_name, JUST(ConstructOp(op_conf)));\n  Operator* op = op_name2op_.at(op_name).get();\n\n  SbpSignature sbp_sig_conf;\n  HashMap<std::string, bool> ibn2disable_boxing;\n  InitIbn2DisableBoxing(*op, &ibn2disable_boxing);\n  auto new_op_conf = JUST(DecodeLbiHintAndReturnNewOpConf(*op, &sbp_sig_conf));\n  auto parallel_conf = JUST(InferOpParallelConf(*op, origin_parallel_conf, ibn2disable_boxing));\n  ParallelDesc parallel_desc(*parallel_conf);\n  JUST(op->FillOpParallelDesc(parallel_desc));\n  JUST(AddOpNameParallelConf2Placement(op_name, *parallel_conf));\n\n  auto GetBlobDesc4BnInOp = [&](const std::string& bn) -> BlobDesc* {\n    const LogicalBlobId& lbi = op->BnInOp2Lbi(bn);\n    if (lbi2logical_blob_desc_.find(lbi) != lbi2logical_blob_desc_.end()) {\n      return lbi2logical_blob_desc_.at(lbi).get();\n    }\n    return nullptr;\n  };\n  JUST(op->FillLogicalInBlobDesc(GetBlobDesc4BnInOp));\n  JUST(op->InferParallelSignatureIf());\n\n  // infer local signature\n  JUST(InferLocalSignature(op, is_local_parallel_view, parallel_desc));\n\n  // infer nd_sbp signature\n  NdSbpSignature nd_sbp_sig_conf;\n  // Only infer the nd_sbp signature if auto parallel is not enabled,\n  // since the semi-auto parallelism rule might be inconsistent with the auto-parallel strategy.\n  if (!job_desc->enable_auto_parallel()) {\n    nd_sbp_sig_conf = *JUST(InitConstraitNdSbpSignature(*op, ibn2disable_boxing));\n  }\n  // Override the constraint nd_sbp if an sbp hint is given\n  if (!sbp_sig_conf.bn_in_op2sbp_parallel().empty()) {\n    SbpSignatureToNdSbpSignature(sbp_sig_conf, &nd_sbp_sig_conf);\n  }\n  AddOpAndUpdateJobParallelViewConf(*new_op_conf, parallel_desc, nd_sbp_sig_conf,\n                                    is_local_parallel_view);\n  JUST(InferOpOutNdSbp(op, nd_sbp_sig_conf, parallel_desc));\n\n  // infer logical blob desc\n  JUST(GenOpProducedEmptyLogicalBlobDesc(op));\n  JUST(op->InferLogicalOutBlobDescsIf());\n  for (const auto& bn : op->output_bns()) {\n    *lbi2logical_blob_desc_.at(op->BnInOp2Lbi(bn)) = *JUST(op->GetLogicalBlobDesc4Obn(bn));\n  }\n  // Infer ParallelDesc for output blobs.\n  auto ParallelDesc4Obn = [&](const std::string& obn) -> ParallelDesc* {\n    const auto& lbi = op->BnInOp2Lbi(obn);\n    auto iter = lbi2parallel_desc_from_producer_view_.find(lbi);\n    if (iter == lbi2parallel_desc_from_producer_view_.end()) {\n      iter = lbi2parallel_desc_from_producer_view_.emplace(lbi, parallel_desc).first;\n    }\n    return &iter->second;\n  };\n  for (const auto& bn : op->output_bns()) {\n    lbi2parallel_desc_from_producer_view_.emplace(op->BnInOp2Lbi(bn),\n                                                  *JUST(op->GetParallelDesc4BnInOp(bn)));\n  }\n  JUST(AddLbiParallelConf2BlobPlacement(op, ParallelDesc4Obn));\n  // Check splitability\n  JUST(CheckOpBlobSplitability(op, parallel_desc.parallel_num()));\n\n  return op->GetOpAttributeWithoutOpNameAndLbn();\n}\n\nbool JobBuildAndInferCtx::HasJobConf() const { return has_job_conf_; }\n\nMaybe<void> JobBuildAndInferCtx::SetTrainConf(const TrainConf& train_conf) {\n  *job_->mutable_job_conf()->mutable_train_conf() = train_conf;\n  return 
Maybe<void>::Ok();\n}\n\nMaybe<void> JobBuildAndInferCtx::AddLossLogicalBlobName(const std::string& lbn) {\n  if (IsLocalBlob(lbn)) { return AddLossLocalBlobName(lbn); }\n  return AddLossGlobalBlobName(lbn);\n}\n\nMaybe<void> JobBuildAndInferCtx::AddLossGlobalBlobName(const std::string& lbn) {\n  JUST(CheckLbnValidAndExist(lbn));\n  CHECK_OR_RETURN(job_->job_conf().has_train_conf())\n      << Error::UnknownJobBuildAndInferError()\n      << \"job has no TrainConf when adding loss logical blob name\";\n  job_->mutable_job_conf()->mutable_train_conf()->add_loss_lbn(lbn);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> JobBuildAndInferCtx::MarkVariableGradientBlobNames(\n    const HashMap<std::string, std::string>& variable_grad_lbns) {\n  CHECK_OR_RETURN(job_->job_conf().has_train_conf())\n      << Error::UnknownJobBuildAndInferError()\n      << \"job has no TrainConf when adding variable gradient logical blob name\";\n  auto* train_conf = job_->mutable_job_conf()->mutable_train_conf();\n  for (int i = 0; i < train_conf->optimizer_conf_size(); ++i) {\n    auto* optimizer_conf = train_conf->mutable_optimizer_conf(i);\n    for (const auto& variable_op_name : optimizer_conf->variable_op_names()) {\n      const auto& it = variable_grad_lbns.find(variable_op_name + \"/out\");\n      if (it != variable_grad_lbns.end()) {\n        optimizer_conf->add_variable_grad_lbns(it->second);\n      } else {\n        // add an empty gradient lbn for variables that have no gradient\n        optimizer_conf->add_variable_grad_lbns(\"\");\n      }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> JobBuildAndInferCtx::MarkOutputGradientBlobNames(\n    const HashMap<std::string, std::string>& output_gradient_lbns) {\n  CHECK_OR_RETURN(job_->job_conf().has_train_conf())\n      << Error::UnknownJobBuildAndInferError()\n      << \"job has no TrainConf when adding output gradient logical blob name\";\n  auto* train_conf = job_->mutable_job_conf()->mutable_train_conf();\n  for (const auto& loss_lbn : train_conf->loss_lbn()) {\n    const auto& it = output_gradient_lbns.find(loss_lbn);\n    CHECK_OR_RETURN(it != output_gradient_lbns.end())\n        << Error::UnknownJobBuildAndInferError() << \"gradient is missing for loss \" << loss_lbn;\n    train_conf->add_loss_grad_lbn(it->second);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<Shape> JobBuildAndInferCtx::GetStaticShape(const std::string& lbn) const {\n  JUST(CheckLbnValidAndExist(lbn));\n  return lbi2logical_blob_desc_.at(GenLogicalBlobId(lbn))->shape();\n}\n\nMaybe<DataType> JobBuildAndInferCtx::GetDataType(const std::string& lbn) const {\n  JUST(CheckLbnValidAndExist(lbn));\n  return lbi2logical_blob_desc_.at(GenLogicalBlobId(lbn))->data_type();\n}\n\nMaybe<bool> JobBuildAndInferCtx::IsDynamic(const std::string& lbn) const {\n  JUST(CheckLbnValidAndExist(lbn));\n  return lbi2logical_blob_desc_.at(GenLogicalBlobId(lbn))->is_dynamic();\n}\n\nMaybe<bool> JobBuildAndInferCtx::IsDisableBoxing(const std::string& lbn) const {\n  JUST(CheckLbnValidAndExist(lbn));\n  LogicalBlobId lbi(GenLogicalBlobId(lbn));\n  const auto& iter = lbi2disable_boxing_.find(lbi);\n  CHECK_OR_RETURN(iter != lbi2disable_boxing_.end());\n  return iter->second;\n}\n\nMaybe<void> JobBuildAndInferCtx::DisableBoxing(const std::string& lbn) {\n  JUST(CheckLbnValidAndExist(lbn));\n  LogicalBlobId lbi(GenLogicalBlobId(lbn));\n  lbi2disable_boxing_[lbi] = true;\n  return Maybe<void>::Ok();\n}\n\nMaybe<Operator*> JobBuildAndInferCtx::Op4OpName(const std::string& op_name) const {\n  const auto& op_iter = 
op_name2op_.find(op_name);\n  CHECK_OR_RETURN(op_iter != op_name2op_.end());\n  auto* op = op_iter->second.get();\n  CHECK_NOTNULL_OR_RETURN(op);\n  return op;\n}\n\nMaybe<OptInt64> JobBuildAndInferCtx::GetSplitAxisFromProducerView(const std::string& lbn) const {\n  JUST(CheckLbnValidAndExist(lbn));\n  OptInt64 ret;\n  const auto& nd_sbp = lbi2nd_sbp_from_producer_view_.at(GenLogicalBlobId(lbn));\n  CHECK_EQ_OR_RETURN(nd_sbp.sbp_parallel_size(), 1);\n  const auto& sbp = nd_sbp.sbp_parallel(0);\n  if (sbp.has_split_parallel()) { ret.set_value(sbp.split_parallel().axis()); }\n  return ret;\n}\n\nMaybe<const ParallelDesc*> JobBuildAndInferCtx::GetParallelDescFromProducerView(\n    const std::string& lbn) const {\n  JUST(CheckLbnValidAndExist(lbn));\n  return &(lbi2parallel_desc_from_producer_view_.at(GenLogicalBlobId(lbn)));\n}\n\nMaybe<void> JobBuildAndInferCtx::AddLossLocalBlobName(const std::string& lbn) {\n  const auto& local_lbi = JUST(GetLocalLbi(lbn));\n  CHECK_OR_RETURN(job_->job_conf().has_train_conf())\n      << Error::UnknownJobBuildAndInferError()\n      << \"job has no TrainConf when adding loss logical blob name\";\n  for (const auto& lbi : local_lbi2sub_lbis_[*local_lbi]) {\n    job_->mutable_job_conf()->mutable_train_conf()->add_loss_lbn(GenLogicalBlobName(lbi));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<LogicalBlobId> JobBuildAndInferCtx::GetLocalLbi(const std::string& lbn_with_hint) const {\n  const LogicalBlobId& lbi = GenLogicalBlobId(lbn_with_hint);\n  if (local_lbi2sub_lbis_.find(lbi) != local_lbi2sub_lbis_.end()) { return lbi; }\n  return Error::CheckFailedError() << lbn_with_hint << \" is not a local blob name\";\n}\n\nMaybe<int> JobBuildAndInferCtx::LocalBlobGetNumSubLbi(const std::string& lbn_with_hint) const {\n  const auto& local_lbi = JUST(GetLocalLbi(lbn_with_hint));\n  return local_lbi2sub_lbis_.at(*local_lbi).size();  // NOLINT\n}\n\nMaybe<const LogicalBlobId*> JobBuildAndInferCtx::LocalBlobGetSubLbi(\n    const std::string& lbn_with_hint, int index) const {\n  const auto& local_lbi = JUST(GetLocalLbi(lbn_with_hint));\n  const auto& vec = local_lbi2sub_lbis_.at(*local_lbi);  // NOLINT\n  CHECK_GE_OR_RETURN(index, 0);\n  CHECK_LT_OR_RETURN(index, vec.size());\n  return &vec.at(index);\n}\n\nbool JobBuildAndInferCtx::IsLocalBlob(const std::string& lbn) const {\n  bool is_local_blob = TRY(GetLocalLbi(lbn)).IsOk();\n  if (is_local_blob) { return is_local_blob; }\n  const LogicalBlobId& lbi = GenLogicalBlobId(lbn);\n  CHECK(lbi2logical_blob_desc_.find(lbi) != lbi2logical_blob_desc_.end()) << \"lbn: \" << lbn;\n  return false;\n}\n\nMaybe<Shape> JobBuildAndInferCtx::LocalBlobGetStaticShape(const std::string& lbn_with_hint) const {\n  const auto& lbi = *JUST(LocalBlobGetSubLbi(lbn_with_hint, 0));\n  return lbi2logical_blob_desc_.at(lbi)->shape();\n}\n\nMaybe<DataType> JobBuildAndInferCtx::LocalBlobGetDataType(const std::string& lbn_with_hint) const {\n  const auto& lbi = *JUST(LocalBlobGetSubLbi(lbn_with_hint, 0));\n  return lbi2logical_blob_desc_.at(lbi)->data_type();\n}\n\nMaybe<bool> JobBuildAndInferCtx::LocalBlobIsDynamic(const std::string& lbn_with_hint) const {\n  const auto& lbi = *JUST(LocalBlobGetSubLbi(lbn_with_hint, 0));\n  return lbi2logical_blob_desc_.at(lbi)->is_dynamic();\n}\n\nMaybe<OptInt64> JobBuildAndInferCtx::LocalBlobGetSplitAxisFromProducerView(\n    const std::string& lbn_with_hint) const {\n  const auto& lbi = *JUST(LocalBlobGetSubLbi(lbn_with_hint, 0));\n  OptInt64 ret;\n  const auto& nd_sbp = lbi2nd_sbp_from_producer_view_.at(lbi);\n  
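// a local blob always carries a 1-D sbp; report the split axis only when it is split\n  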
CHECK_EQ_OR_RETURN(nd_sbp.sbp_parallel_size(), 1);\n  const auto& sbp = nd_sbp.sbp_parallel(0);\n  if (sbp.has_split_parallel()) { ret.set_value(sbp.split_parallel().axis()); }\n  return ret;\n}\n\nMaybe<const ParallelDesc*> JobBuildAndInferCtx::LocalBlobGetParallelDescFromProducerView(\n    const std::string& lbn_with_hint) const {\n  const auto& lbi = JUST(GetLocalLbi(lbn_with_hint));\n  return &(local_lbi2parallel_desc_.at(*lbi));  // NOLINT\n}\n\nMaybe<void> JobBuildAndInferCtx::CheckJob() const {\n  JUST(CheckPlacement());\n  JUST(CheckJobConf());\n  JUST(CheckOpScope());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> JobBuildAndInferCtx::CheckPlacement() const {\n  HashSet<std::string> op_names_in_net;\n  HashSet<std::string> op_names_in_placement;\n  for (const OperatorConf& op_conf : job_->net().op()) {\n    CHECK_OR_RETURN(op_names_in_net.insert(op_conf.name()).second)\n        << Error::OpNameExistError() << \"op_name: \" << op_conf.name()\n        << \" already exist in job: \" << job_->job_conf().job_name() << \" net\";\n  }\n  for (const PlacementGroup& placement_group : job_->placement().placement_group()) {\n    for (const std::string& op_name : placement_group.op_set().op_name()) {\n      CHECK_OR_RETURN(op_names_in_placement.insert(op_name).second)\n          << Error::OpNameExistError() << \"op_name: \" << op_name\n          << \" already exist in job: \" << job_->job_conf().job_name() << \" placement\";\n    }\n  }\n  CHECK_EQ_OR_RETURN(op_names_in_net.size(), op_names_in_placement.size())\n      << Error::PlacementError() << \"job: \" << job_->job_conf().job_name()\n      << \" op number not equal between net and placement\";\n  for (const std::string& op_name : op_names_in_net) {\n    CHECK_OR_RETURN(op_names_in_placement.find(op_name) != op_names_in_placement.end())\n        << Error::PlacementError() << \"job: \" << job_->job_conf().job_name()\n        << \" op_name: \" << op_name << \" defined in net cannot find its placement\";\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> JobBuildAndInferCtx::CheckJobConf() const {\n  if (job_->job_conf().job_type_case() == JobConfigProto::JOB_TYPE_NOT_SET) {\n    return Error::JobTypeNotSetError() << \"job_type not set, please set predict_conf or train_conf\";\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> JobBuildAndInferCtx::CheckOpScope() const {\n  for (const OperatorConf& op_conf : job_->net().op()) {\n    if (!op_conf.has_scope_symbol_id()) {\n      // NOTE(chengcheng): LOG(WARNING) instead of CHECK_OR_RETURN() for transition\n      LOG(WARNING) << \" ERROR! op_name: \" << op_conf.name()\n                   << \" has NOT set scope(scope_symbol_id) in job: \" << job_->job_conf().job_name()\n                   << \" net. 
\\n op_conf = \" << op_conf.DebugString();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> JobBuildAndInferCtx::CheckLbnValidAndExist(const std::string& lbn) const {\n  CHECK_OR_RETURN(lbn.find('/') != std::string::npos)\n      << Error::LogicalBlobNameInvalidError() << \"lbn:\" << lbn;\n  LogicalBlobId lbi = GenLogicalBlobId(lbn);\n\n#define CHECK_HAS_LBI_KEY(info_src)                     \\\n  CHECK_OR_RETURN(info_src.find(lbi) != info_src.end()) \\\n      << Error::LogicalBlobNameNotExistError() << \"lbn:\" << lbn;\n\n  CHECK_HAS_LBI_KEY(lbi2logical_blob_desc_);\n  CHECK_HAS_LBI_KEY(lbi2nd_sbp_from_producer_view_);\n  CHECK_HAS_LBI_KEY(lbi2parallel_desc_from_producer_view_);\n#undef CHECK_HAS_LBI_KEY\n\n  return Maybe<void>::Ok();\n}\n\nconst Job& JobBuildAndInferCtx::job() const { return *job_; }\n\nstd::string LazyJobBuildAndInferCtx::GetLocalOpName(const std::string& op_name,\n                                                    int64_t parallel_id) const {\n  return op_name + \"_\" + std::to_string(parallel_id);\n}\n\nParallelConf LazyJobBuildAndInferCtx::GetLocalOpParallelConf(const ParallelDesc& parallel_desc,\n                                                             int64_t parallel_id) const {\n  return parallel_desc.GetParallelIdOnlyParallelConf(parallel_id);\n}\n\nMaybe<LogicalBlobId> LazyJobBuildAndInferCtx::FindOrCreateLocalLbiFromCompatibleGlobalBlob(\n    int64_t scope_symbol_id, const LogicalBlobId& lbi) {\n  const std::string& lbn = GenLogicalBlobName(lbi);\n  const auto& sbn_it = mut_global_lbi2local_lbi()->find(lbi);\n  if (sbn_it != mut_global_lbi2local_lbi()->end()) { return sbn_it->second; }\n  const SbpParallel& sbp = *JUST(SbpParallel4Lbi(lbi));\n  const ParallelDesc& parallel_desc = *JUST(ParallelDesc4Lbi(lbi));\n  LogicalBlobId local_lbi;\n  local_lbi.set_op_name(kAutoLocalBlobNamePrefix + NewUniqueId());\n  local_lbi.set_blob_name(\"out\");\n  (*mut_global_lbi2local_lbi())[lbi] = local_lbi;\n  auto* lbi_vec = &(*mut_local_lbi2sub_lbis())[local_lbi];\n  lbi_vec->reserve(parallel_desc.parallel_num());\n  auto PushBackSubLbi = [&](const std::string& op_name, const std::string& blob_name) {\n    LogicalBlobId sub_lbi;\n    sub_lbi.set_op_name(op_name);\n    sub_lbi.set_blob_name(blob_name);\n    lbi_vec->emplace_back(sub_lbi);\n  };\n  OperatorConf op_conf;\n  op_conf.set_scope_symbol_id(scope_symbol_id);\n  op_conf.set_device_tag(*JUST(DeviceTag4DeviceType(parallel_desc.device_type())));\n  if (sbp.has_broadcast_parallel()) {\n    op_conf.set_name(kAutoLocalBlobNamePrefix + \"-DistributeClone-\" + NewUniqueId());\n    auto* distribute_clone = op_conf.mutable_distribute_clone_conf();\n    distribute_clone->set_in(lbn);\n    FOR_RANGE(int32_t, i, 0, parallel_desc.parallel_num()) {\n      const std::string& blob_name = \"out_\" + std::to_string(i);\n      distribute_clone->add_out(blob_name);\n      distribute_clone->set_is_variable_ref(IsVariableLbi(lbi));\n      PushBackSubLbi(op_conf.name(), blob_name);\n    }\n  } else if (sbp.has_split_parallel()) {\n    CHECK_EQ_OR_RETURN(sbp.split_parallel().axis(), 0)\n        << \"only `S(0)' global blob is compatible to local blob\";\n    op_conf.set_name(kAutoLocalBlobNamePrefix + \"-DistributeSplit-\" + NewUniqueId());\n    auto* distribute_split = op_conf.mutable_distribute_split_conf();\n    distribute_split->set_in(lbn);\n    distribute_split->set_axis(0);\n    distribute_split->set_is_variable_ref(IsVariableLbi(lbi));\n    FOR_RANGE(int32_t, i, 0, parallel_desc.parallel_num()) {\n      const std::string& 
blob_name = \"out_\" + std::to_string(i);\n      distribute_split->add_out(blob_name);\n      PushBackSubLbi(op_conf.name(), blob_name);\n    }\n  } else {\n    OF_UNIMPLEMENTED() << \"`P' global blob is not compatible to local blob\";\n  }\n  {\n    const auto& producer_op_conf = JUST(Op4OpName(lbi.op_name()))->op_conf();\n    CHECK_OR_RETURN(producer_op_conf.has_scope_symbol_id());\n    const auto& scope = Singleton<symbol::Storage<Scope>>::Get()->Get(scope_symbol_id);\n    const auto* job_desc = JUST(scope.job_desc());\n    JUST(AddAndInferOp(op_conf, parallel_desc.parallel_conf(), job_desc, false));\n  }\n  return local_lbi;\n}\n\nMaybe<void> LazyJobBuildAndInferCtx::Complete() {\n  CHECK_GT_OR_RETURN(job().net().op_size(), 0)\n      << \" Sorry, nn.Graph need at least 1 op in net, but get 0 now.\";\n  auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true);\n  CHECK_NOTNULL(Singleton<JobDesc>::Get());\n  // A global variable to get graph configurations.\n  auto current_graph_config = std::make_unique<GlobalJobDescScope>(mut_job()->job_conf(), job_id());\n  JobPassCtx job_pass_ctx(GlobalJobDesc());\n  const auto job_name = job().job_conf().job_name();\n  auto LogJob = [&](const std::string& name_suffix) -> void {\n    std::string full_log_name =\n        job_name + \"-job_id_\" + std::to_string(job_id()) + \"-\" + name_suffix;\n    TeePersistentLogStream::Create(full_log_name)->Write(job());\n    Singleton<OpGraph>::New(job());\n    Singleton<OpGraph>::Get()->ToDotWithFilePath(full_log_name + \".dot\");\n    Singleton<OpGraph>::Delete();\n  };\n  std::string debug_pass_name = GetStringFromEnv(\"ONEFLOW_DEBUG_PASS\", \"\");\n  auto NeedLogJob = [&](const std::string& pass_name) -> bool {\n    if (\"ALL\" == debug_pass_name) {\n      return true;\n    } else if (pass_name == debug_pass_name) {\n      return true;\n    } else {\n      return false;\n    }\n  };\n  int32_t pass_cnt = 0;\n  const int64_t prev_v = FLAGS_v;\n  auto DoPass = [&](const std::string& pass_name, int32_t cnt = 0) -> Maybe<void> {\n    auto pass_tc = std::make_unique<CostCounter<std::chrono::milliseconds>>(true, true);\n    VLOG(1) << job_name << \" start compiling with pass\"\n            << \" pass_cnt_\" + std::to_string(pass_cnt) + \"-\" + pass_name\n            << (cnt > 0 ? std::to_string(cnt) : \"\");\n    if (unlikely(NeedLogJob(pass_name))) {\n      std::string cnt_str = cnt > 0 ? std::to_string(cnt) : \"\";\n      LogJob(\"pass_cnt_\" + std::to_string(pass_cnt) + \"-\" + pass_name + cnt_str + \"-before\");\n      FLAGS_v = 3;\n    }\n    JUST(JobPass4Name(pass_name)(mut_job(), &job_pass_ctx));\n    if (unlikely(NeedLogJob(pass_name))) {\n      FLAGS_v = prev_v;\n      std::string cnt_str = cnt > 0 ? std::to_string(cnt) : \"\";\n      LogJob(\"pass_cnt_\" + std::to_string(pass_cnt) + \"-\" + pass_name + cnt_str + \"-after\");\n    }\n    VLOG(1) << job_name << \" finish compiling with pass\"\n            << \" pass_cnt_\" + std::to_string(pass_cnt) + \"-\" + pass_name\n            << (cnt > 0 ? 
std::to_string(cnt) : \"\");\n    pass_tc->Count(\"[GraphCompile]\" + job_name + \" \" + pass_name, 1, true);\n    ++pass_cnt;\n    return Maybe<void>::Ok();\n  };\n\n  if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {\n    TeePersistentLogStream::Create(StrCat(\"forward_graph\", job_id()))->Write(job());\n    Singleton<OpGraph>::New(job());\n    Singleton<OpGraph>::Get()->ToDotWithFilePath(\"forward_dlnet_\" + std::to_string(job_id())\n                                                 + \"_op_graph.dot\");\n    Singleton<OpGraph>::Delete();\n  }\n\n  if (GlobalJobDesc().Bool(\"__is_user_function__\")) {\n    // insert pinned identity to prevent the loss, loss initial gradient and\n    // variable gradient from being eliminated by IRRoundTripBeforeAD pass\n    JUST(DoPass(\"InsertPinnedIdentityOpPass\"));\n    // prune the dangling constants which are the 0 gradients initialized by\n    // the autograd engine for those tensors that have no gradients\n    JUST(DoPass(\"EliminateDeadNodesPass\"));\n    JUST(DoPass(\"NormalizationExponentialAverageAutoTickPass\"));\n    JUST(DoPass(\"AutoMixedPrecision\"));\n    // prune depend OP and add ctrl_in_op to op_conf accordingly\n    // to express the same semantics and avoid performance loss\n    JUST(DoPass(\"PruneDependOpPass\"));\n    JUST(DoPass(\"PruneAmpWhiteIdentityOpPass\"));\n    JUST(DoPass(\"OptimizerPlacementOptimizationPass\"));\n    // run FuseAddToOutputPass before IRRoundTripBeforeAD since add_2 may be\n    // fused as add_n in IRRoundTripBeforeAD pass\n    JUST(DoPass(\"FuseAddToOutputPass\"));\n#ifdef WITH_MLIR\n    JUST(DoPass(\"IRRoundTripBeforeAD\"));\n#endif  // WITH_MLIR\n    // run DynamicLossScaleSchedulePass, AutoTrainStep and AutoLearningRate\n    // after IRRoundTripBeforeAD since IRRoundTripBeforeAD will do DCE\n    // optimization which could eliminate the nodes inserted by them\n    JUST(DoPass(\"DynamicLossScaleSchedulePass\"));\n    JUST(DoPass(\"AutoTrainStep\"));\n    JUST(DoPass(\"AutoLearningRate\"));\n    JUST(DoPass(\"QuantAwareTraining\"));\n    JUST(DoPass(\"GenerateOptimizerOpConfs\"));\n    // pinned identity can be pruned since GenerateOptimizerOpConfs pass has\n    // already constructed a complete computational graph\n    JUST(DoPass(\"PrunePinnedIdentityOpPass\"));\n    JUST(DoPass(\"ReplaceEmbeddingOps\"));\n    JUST(DoPass(\"SequentialOneEmbeddingOpsPass\"));\n    JUST(DoPass(\"FuseEmbeddingShuffleInteractionPass\"));\n    JUST(DoPass(\"FuseBCEReduceMeanFwBwPass\"));\n    JUST(DoPass(\"AddSspVariableProxy\"));\n    JUST(DoPass(\"CheckpointingPass\"));\n    JUST(DoPass(\"CudnnFusedNormalizationAddReluPass\"));\n    JUST(DoPass(\"PruneCastToStaticShapeOpsPass\"));\n#ifdef WITH_MLIR\n    JUST(DoPass(\"IRRoundTrip\"));\n#endif  // WITH_MLIR\n    // run this pass again to fuse ops created in the first run.\n    // TODO(guoran): loop multiple times inside the pass\n    JUST(DoPass(\"FuseAddToOutputPass\", 1));\n    JUST(DoPass(\"FuseConsecutiveAddPass\"));\n    JUST(DoPass(\"IndexedSlicesOptimizerRewritePass\"));\n    JUST(DoPass(\"SplitSparseSoftmaxCrossEntropyOpPass\"));\n    JUST(DoPass(\"DoParallelCastBeforeWideningTypeCast\"));\n    JUST(DoPass(\"FuseCastScalePass\"));\n    JUST(DoPass(\"PruneParallelCastOpsPass\"));\n    JUST(DoPass(\"FuseUpdateOpsPass\"));\n    JUST(DoPass(\"FuseModelUpdateCastOpsPass\"));\n    JUST(DoPass(\"MultiTensorModelUpdatePass\"));\n    JUST(DoPass(\"FixPipelineStageIdPass\"));\n    JUST(DoPass(\"PipelineBufferPass\"));\n    
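// AutoParallelPass searches sbp signatures for the whole graph when enable_auto_parallel is set\n    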
JUST(DoPass(\"AutoParallelPass\"));\n    JUST(DoPass(\"DelayVariableOpExecutionPass\"));\n#ifdef WITH_CUTLASS\n    JUST(DoPass(\"CutlassConvTuningWarmupPass\"));\n#endif  // WITH_CUTLASS\n    JUST(DoPass(\"DumpVariableInfoPass\"));\n  }\n  JUST(DoPass(\"DumpBlobParallelConfPass\"));\n  JUST(CheckJob());\n  compile_tc->Count(\"[GraphCompile]\" + job_name + \" OptimizationLogicalGraph\", 0);\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nstd::string OpConf2ClassName(const OperatorConf& op_conf) {\n  if (op_conf.has_user_conf()) {\n    return op_conf.user_conf().op_type_name();\n  } else if (op_conf.has_variable_conf()) {\n    return \"variable\";\n  } else if (op_conf.has_input_conf() && op_conf.has_return_conf()) {\n    return \"input\";\n  } else if (op_conf.has_output_conf() && op_conf.has_return_conf()) {\n    return \"output\";\n  } else {\n    return \"system_op\";\n  }\n}\n\nvoid FormateUserConf(nlohmann::json& json_conf) {\n  nlohmann::json user_conf = json_conf[\"user_conf\"];\n  if (user_conf.is_null()) {\n    json_conf.erase(json_conf.find(\"user_conf\"));\n    return;\n  }\n  std::string nomarl_array[] = {\"at_int32\",  \"at_int64\", \"at_bool\",   \"at_float\",    \"at_double\",\n                                \"at_string\", \"at_shape\", \"at_stride\", \"at_data_type\"};\n  std::string list_array[] = {\"at_list_int32\",     \"at_list_int64\", \"at_list_float\",\n                              \"at_list_data_type\", \"at_list_shape\", \"at_list_stride\",\n                              \"at_list_string\"};\n  nlohmann::json attr_json = user_conf[\"attr\"];\n  for (int32_t i = 0; i < attr_json.size(); i++) {\n    std::string key = attr_json[i][\"key\"];\n    nlohmann::json value_json = attr_json[i][\"value\"];\n    bool is_found_normal = false;\n    for (int32_t j = 0; j < nomarl_array->length(); j++) {\n      std::string value_key = nomarl_array[j];\n      if (value_json.contains(value_key)) {\n        is_found_normal = true;\n        if (\"at_shape\" == value_key || \"at_stride\" == value_key) {\n          json_conf[key] = value_json[value_key][\"dim\"];\n        } else {\n          json_conf[key] = value_json[value_key];\n        }\n        break;\n      }\n    }\n    if (is_found_normal) { continue; }\n    for (int32_t j = 0; j < list_array->length(); j++) {\n      std::string value_key = list_array[j];\n      if (value_json.contains(value_key)) {\n        if (value_json[value_key].contains(\"val\")) {\n          json_conf[key] = value_json[value_key][\"val\"];\n          break;\n        } else if (value_json[value_key].contains(\"dim\")) {\n          json_conf[key] = value_json[value_key][\"dim\"];\n          break;\n        }\n      }\n    }\n  }\n  json_conf.erase(json_conf.find(\"user_conf\"));\n}\n\nvoid FormateVariableConf(nlohmann::json& json_conf) {\n  nlohmann::json variable_conf = json_conf[\"variable_conf\"];\n  if (variable_conf == nullptr) {\n    json_conf.erase(json_conf.find(\"variable_conf\"));\n    return;\n  }\n  for (nlohmann::json::iterator it = variable_conf.begin(); it != variable_conf.end(); ++it) {\n    std::string key = it.key();\n    if (\"shape\" == key) {\n      json_conf[key] = it.value()[\"dim\"];\n    } else {\n      json_conf[key] = it.value();\n    }\n  }\n  json_conf.erase(json_conf.find(\"variable_conf\"));\n}\n\n}  // namespace\n\nstd::string oneflow::JobBuildAndInferCtx::GetJobStructureGraphJson(\n    const std::string& job_name) const {\n  HashSet<std::string> inputs_op_names;\n  HashSet<std::string> outputs_op_names;\n  
std::vector<nlohmann::json> layers_vec;\n  layers_vec.reserve(op_name2op_.size());\n  for (const auto& pair : op_name2op_) {\n    nlohmann::json json_layers_pair;\n\n    const Operator* op = pair.second.get();\n    const std::string& op_name = pair.first;\n    HashSet<std::string> inbound_nodes;\n    for (const auto& ibn : op->input_bns()) {\n      const LogicalBlobId& lbi = op->BnInOp2Lbi(ibn);\n      if (op_name2op_.find(lbi.op_name()) != op_name2op_.end()) {\n        inbound_nodes.insert(lbi.op_name());\n      }\n    }\n\n    if (op->op_conf().has_input_conf() && op->op_conf().has_return_conf()) {\n      inputs_op_names.insert(op_name);\n    }\n    if (op->op_conf().has_output_conf() && op->op_conf().has_return_conf()) {\n      outputs_op_names.insert(op_name);\n    }\n    json_layers_pair[\"name\"] = op_name;\n\n    std::string class_name = OpConf2ClassName(op->op_conf());\n    json_layers_pair[\"class_name\"] = class_name;\n\n    nlohmann::json json_conf;\n    summary::ConvertProtobufMsg2Json(json_conf, op->op_conf());\n    FormateUserConf(json_conf);\n    FormateVariableConf(json_conf);\n    json_layers_pair[\"config\"] = json_conf;\n\n    std::vector<std::string> inbound_nodes_vec;\n    inbound_nodes_vec.reserve(inbound_nodes.size());\n    for (const auto& in_node_name : inbound_nodes) { inbound_nodes_vec.emplace_back(in_node_name); }\n    json_layers_pair[\"inbound_nodes\"] = inbound_nodes_vec;\n\n    layers_vec.emplace_back(json_layers_pair);\n  }\n\n  nlohmann::json json_pair;\n  json_pair[\"name\"] = job_name;\n  json_pair[\"layers\"] = layers_vec;\n  json_pair[\"input_layers\"] = inputs_op_names;\n  json_pair[\"output_layers\"] = outputs_op_names;\n\n  return json_pair.dump();\n}\n\nMaybe<void> JobBuildAndInferCtx::Rebuild() {\n  // clear old state\n  lbi2logical_blob_desc_.clear();\n  lbi2nd_sbp_from_producer_view_.clear();\n  lbi2parallel_desc_from_producer_view_.clear();\n  lbi2disable_boxing_.clear();\n  op_name2op_.clear();\n  parallel_desc2placement_group_.clear();\n  parallel_desc2blob_placement_group_.clear();\n  global_lbi2local_lbi_.clear();\n  local_lbi2sub_lbis_.clear();\n  local_lbi2parallel_desc_.clear();\n  local_lbi2sbp_parallel_.clear();\n  op_name2ancestors_need_no_grad_.clear();\n  // record op mirror view\n  HashMap<std::string, bool> op_name2is_local;\n  CHECK_OR_RETURN(job_->has_job_parallel_view_conf());\n  for (const auto& op_conf : job_->net().op()) {\n    const auto& op_name = op_conf.name();\n    CHECK_OR_RETURN(op_name2is_local.find(op_name) == op_name2is_local.end());  // NOLINT\n    op_name2is_local[op_name] = false;\n    const auto& op_name2is_local_parallel_view =\n        job_->job_parallel_view_conf().op_name2is_local_parallel_view();\n    if (op_name2is_local_parallel_view.find(op_name) != op_name2is_local_parallel_view.end()) {\n      if (op_name2is_local_parallel_view.at(op_name)) { op_name2is_local[op_name] = true; }\n    }\n  }\n  // build op graph\n  OpGraph op_graph;\n  if (Singleton<JobDesc>::Get()) {\n    JUST(op_graph.Init(*job_));\n  } else {\n    auto scope = std::make_unique<GlobalJobDescScope>(job_->job_conf(), job_id());\n    JUST(op_graph.Init(*job_));\n  }\n  // clear old job except job_conf\n  job_->mutable_net()->Clear();\n  job_->mutable_placement()->Clear();\n  job_->mutable_job_parallel_view_conf()->Clear();\n  job_->mutable_helper()->Clear();\n  // topo traverse op_graph to AddAndInferOp\n  op_graph.TopoForEachNode([&](OpNode* node) -> void {\n    const auto& op_conf = node->op().op_conf();\n    
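// re-add each op with the local/global view recorded before the old state was cleared\n    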
CHECK(op_name2is_local.find(op_conf.name()) != op_name2is_local.end());\n    bool is_local = op_name2is_local.at(op_conf.name());\n    if (is_local) {\n      CHECK_JUST(AddAndInferLocalOp(op_conf));\n    } else {\n      CHECK_JUST(AddAndInferGlobalOp(op_conf));\n    }\n  });\n  // update job_helper\n  op_graph.DumpLogicalBlobDesc(job_);\n  op_graph.DumpNdSbpSignature(job_);\n  return Maybe<void>::Ok();\n}\n\nMaybe<std::string> JobBuildAndInferCtx::GetOpBlobLbn(const std::string& op_name,\n                                                     const std::string& bn_in_op) const {\n  const auto& lbi = JUST(Op4OpName(op_name))->BnInOp2Lbi(bn_in_op);\n  return GenLogicalBlobName(lbi);\n}\n\nMaybe<std::string> JobBuildAndInferCtx::NewUniqueOpNameByFunctionalOpConf(\n    const OperatorConf& op_conf) {\n  // NOTE(chengcheng): arg op_conf has a default global op_name because it is created by\n  //  static functional op expr, so we need to reset a unique op name for each functional op.\n  //  This op_conf can NOT be an input/output/variable op, which already has a correct name set\n  //  in nn.Graph. But free eager tensor is treated as a special variable which needs a name\n  //  created here.\n  CHECK_OR_RETURN(!(op_conf.has_input_conf() || op_conf.has_output_conf()));\n\n  const auto& scope = JUST(GetCurrentScope());\n\n  std::string op_name_prefix;\n  for (const std::string& prefix : scope->scope_proto().scope_op_name_prefixes()) {\n    op_name_prefix += (prefix + \"-\");\n  }\n  std::string op_type_name;\n  if (op_conf.has_user_conf()) {\n    op_type_name = op_conf.user_conf().op_type_name();\n  } else if (op_conf.has_variable_conf()) {\n    // NOTE(chengcheng): To support Free Eager Tensor caught by nn.Graph\n    op_type_name = \"FreeEagerTensor\";\n  } else {\n    op_type_name = \"SystemOp\";\n  }\n  std::string op_name = op_name_prefix + op_type_name + \"-\" + std::to_string(unique_op_name_index_);\n  ++unique_op_name_index_;\n\n  return op_name;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/job_build_and_infer_ctx.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_JOB_BUILD_AND_INFER_CTX_H_\n#define ONEFLOW_CORE_JOB_JOB_BUILD_AND_INFER_CTX_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/register/blob_desc.h\"\n\nnamespace oneflow {\n\nclass JobBuildAndInferCtx {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(JobBuildAndInferCtx);\n  JobBuildAndInferCtx(Job* job, int64_t job_id);\n  virtual ~JobBuildAndInferCtx() = default;\n\n  Maybe<void> SetJobConf(const JobConfigProto& job_conf);\n  Maybe<OpAttribute> AddAndInferGlobalOp(const OperatorConf& op_conf);\n  Maybe<OpAttribute> AddAndInferLocalOp(const OperatorConf& op_conf);\n  Maybe<void> AddLossLogicalBlobName(const std::string& lbn);\n  Maybe<void> SetTrainConf(const TrainConf& train_conf);\n  Maybe<void> MarkVariableGradientBlobNames(\n      const HashMap<std::string, std::string>& variable_grad_lbns);\n  Maybe<void> MarkOutputGradientBlobNames(\n      const HashMap<std::string, std::string>& output_gradient_lbns);\n\n  bool HasJobConf() const;\n  Maybe<Shape> GetStaticShape(const std::string& lbn) const;\n  Maybe<DataType> GetDataType(const std::string& lbn) const;\n  Maybe<bool> IsDynamic(const std::string& lbn) const;\n  Maybe<bool> IsDisableBoxing(const std::string& lbn) const;\n  Maybe<void> DisableBoxing(const std::string& lbn);\n  Maybe<OptInt64> GetSplitAxisFromProducerView(const std::string& lbn) const;\n  Maybe<const ParallelDesc*> GetParallelDescFromProducerView(const std::string& lbn) const;\n\n  bool IsLocalBlob(const std::string& lbn) const;\n  Maybe<int> LocalBlobGetNumSubLbi(const std::string& lbn) const;\n  Maybe<const LogicalBlobId*> LocalBlobGetSubLbi(const std::string& lbn, int index) const;\n\n  Maybe<Shape> LocalBlobGetStaticShape(const std::string& lbn_with_hint) const;\n  Maybe<DataType> LocalBlobGetDataType(const std::string& lbn_with_hint) const;\n  Maybe<bool> LocalBlobIsDynamic(const std::string& lbn_with_hint) const;\n  Maybe<OptInt64> LocalBlobGetSplitAxisFromProducerView(const std::string& lbn_with_hint) const;\n  Maybe<const ParallelDesc*> LocalBlobGetParallelDescFromProducerView(\n      const std::string& lbn_with_hint) const;\n\n  const Job& job() const;\n  int64_t job_id() const { return job_id_; }\n  Maybe<void> CheckJob() const;\n  std::string GetJobStructureGraphJson(const std::string& job_name) const;\n  Maybe<void> CheckLbnValidAndExist(const std::string& lbn) const;\n  Maybe<void> Rebuild();\n  Maybe<std::string> GetOpBlobLbn(const std::string& op_name, const std::string& bn_in_op) const;\n\n  // NOTE(chengcheng): Only used in multi-client.\n  Maybe<std::string> NewUniqueOpNameByFunctionalOpConf(const 
OperatorConf& op_conf);\n  Maybe<Operator*> Op4OpName(const std::string& op_name) const;\n\n  virtual Maybe<void> Complete() = 0;\n\n protected:\n  virtual Maybe<void> CheckAllInputsWithSameParallelNum(const Operator& op,\n                                                        int32_t parallel_num) const = 0;\n  virtual std::string GetLocalOpName(const std::string& op_name, int64_t parallel_id) const = 0;\n  virtual int64_t SizeOfSubGlobalOpList(int64_t parallel_num) const = 0;\n  virtual ParallelConf GetLocalOpParallelConf(const ParallelDesc&, int64_t parallel_id) const = 0;\n  virtual bool GetIsLocalParallelView() const = 0;\n  virtual Maybe<LogicalBlobId> FindOrCreateLocalLbiFromCompatibleGlobalBlob(\n      int64_t scope_symbol_id, const LogicalBlobId& lbn) = 0;\n\n  Job* mut_job() const { return job_; }\n  const HashMap<LogicalBlobId, std::vector<LogicalBlobId>>& local_lbi2sub_lbis() const {\n    return local_lbi2sub_lbis_;\n  }\n  HashMap<LogicalBlobId, std::vector<LogicalBlobId>>* mut_local_lbi2sub_lbis() {\n    return &local_lbi2sub_lbis_;\n  }\n  Maybe<const ParallelDesc*> ParallelDesc4Lbi(const LogicalBlobId& lbi) const;\n  HashMap<LogicalBlobId, LogicalBlobId>* mut_global_lbi2local_lbi() {\n    return &global_lbi2local_lbi_;\n  }\n  Maybe<const SbpParallel*> SbpParallel4Lbi(const LogicalBlobId& lbi) const;\n  bool IsVariableLbi(const LogicalBlobId& lbi) const;\n  Maybe<OpAttribute> AddAndInferOp(const OperatorConf& op_conf, const ParallelConf& parallel_conf,\n                                   const JobDesc* job_desc, bool is_local_parallel_view);\n\n private:\n  Maybe<ParallelConf> InferOpParallelConf(\n      const Operator& op, const ParallelConf& origin_parallel_conf,\n      const HashMap<std::string, bool>& ibn2disable_boxing) const;\n  Maybe<void> AddOpNameParallelConf2Placement(const std::string& op_name,\n                                              const ParallelConf& parallel_conf);\n  void InitIbn2DisableBoxing(const Operator& op, HashMap<std::string, bool>* ibn2disable_boxing);\n  Maybe<NdSbpSignature> InitConstraitNdSbpSignature(\n      const Operator& op, const HashMap<std::string, bool>& ibn2disable_boxing) const;\n  Maybe<OperatorConf> DecodeLbiHintAndReturnNewOpConf(const Operator& op,\n                                                      SbpSignature* sbp_sig_conf) const;\n  Maybe<void> AddLbiParallelConf2BlobPlacement(\n      const Operator* op, std::function<ParallelDesc*(const std::string&)> ParallelDesc4Obn);\n  void AddOpAndUpdateJobParallelViewConf(const OperatorConf& operator_conf,\n                                         const ParallelDesc& parallel_desc,\n                                         const NdSbpSignature& nd_sbp_signature,\n                                         bool is_local_parallel_view) const;\n  Maybe<void> InferLocalSignature(Operator*, bool is_local_parallel_view_conf, const ParallelDesc&);\n  Maybe<void> InferOpOutNdSbp(Operator*, const NdSbpSignature&, const ParallelDesc&);\n  Maybe<void> GenOpProducedEmptyLogicalBlobDesc(Operator* op);\n  Maybe<void> CheckOpBlobSplitability(Operator*, int64_t parallel_num);\n  Maybe<void> CheckPlacement() const;\n  Maybe<void> CheckJobConf() const;\n  Maybe<void> CheckOpScope() const;\n  Maybe<LogicalBlobId> GetLocalLbi(const std::string& lbn_with_hint) const;\n  bool HasAnyLocalBlobInput(const Operator& op) const;\n  Maybe<void> CheckAllInputsConvertableToLocalBlob(const Operator& op) const;\n  Maybe<void> AddLossGlobalBlobName(const std::string& lbn);\n  Maybe<void> 
AddLossLocalBlobName(const std::string& lbn);\n  Maybe<const LogicalBlobId*> GetSubLbi(int64_t scope_symbol_id, const LogicalBlobId& lbi,\n                                        int32_t index);\n  Maybe<bool> AllInputsBroadcastParallel(const Operator& op) const;\n\n  Job* job_;\n  int64_t job_id_;\n  HashMap<LogicalBlobId, std::unique_ptr<BlobDesc>> lbi2logical_blob_desc_;\n  HashMap<LogicalBlobId, NdSbp> lbi2nd_sbp_from_producer_view_;\n  HashMap<LogicalBlobId, ParallelDesc> lbi2parallel_desc_from_producer_view_;\n  HashMap<LogicalBlobId, bool> lbi2disable_boxing_;\n  HashMap<std::string, std::shared_ptr<Operator>> op_name2op_;\n  HashMap<ParallelDesc, PlacementGroup*> parallel_desc2placement_group_;\n  HashMap<ParallelDesc, BlobPlacementGroup*> parallel_desc2blob_placement_group_;\n  HashMap<LogicalBlobId, LogicalBlobId> global_lbi2local_lbi_;\n  HashMap<LogicalBlobId, std::vector<LogicalBlobId>> local_lbi2sub_lbis_;\n  HashMap<LogicalBlobId, ParallelDesc> local_lbi2parallel_desc_;\n  HashMap<LogicalBlobId, SbpParallel> local_lbi2sbp_parallel_;\n  bool is_job_conf_frozen_;\n  bool has_job_conf_;\n  HashMap<std::string, bool> op_name2ancestors_need_no_grad_;\n  int64_t unique_op_name_index_;\n};\n\nclass LazyJobBuildAndInferCtx : public JobBuildAndInferCtx {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LazyJobBuildAndInferCtx);\n  LazyJobBuildAndInferCtx(Job* job, int64_t job_id) : JobBuildAndInferCtx(job, job_id) {}\n  virtual ~LazyJobBuildAndInferCtx() = default;\n\n private:\n  Maybe<void> Complete() override;\n  Maybe<void> CheckAllInputsWithSameParallelNum(const Operator& op,\n                                                int32_t parallel_num) const override;\n  std::string GetLocalOpName(const std::string& op_name, int64_t parallel_id) const override;\n  int64_t SizeOfSubGlobalOpList(int64_t parallel_num) const override { return parallel_num; }\n  ParallelConf GetLocalOpParallelConf(const ParallelDesc&, int64_t parallel_id) const override;\n  bool GetIsLocalParallelView() const override { return false; }\n  Maybe<LogicalBlobId> FindOrCreateLocalLbiFromCompatibleGlobalBlob(\n      int64_t scope_symbol_id, const LogicalBlobId& lbn) override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_JOB_BUILD_AND_INFER_CTX_H_\n"
  },
  {
    "path": "oneflow/core/job/job_build_and_infer_ctx_mgr.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/job_build_and_infer_ctx_mgr.h\"\n\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/multi_client_session_context.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/id_state.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"nlohmann/json.hpp\"\n\nnamespace oneflow {\n\nMaybe<void> JobBuildAndInferCtxMgr::OpenJobBuildAndInferCtx(const std::string& job_name) {\n  CHECK_OR_RETURN(!has_cur_job_) << Error::UnknownJobBuildAndInferError()\n                                 << \"cur job not leave before you enter this job_name:\" << job_name;\n  CHECK_OR_RETURN(!job_name.empty()) << Error::JobNameEmptyError();\n  CHECK_OR_RETURN(job_name2infer_ctx_.find(job_name) == job_name2infer_ctx_.end())\n      << Error::JobNameExistError() << \"job name: \" << job_name << \" already exist\";\n  int64_t job_id = job_id_count_++;\n  Job* job = job_set_.add_job();\n  job->mutable_job_conf()->set_job_name(job_name);\n  std::unique_ptr<JobBuildAndInferCtx> ctx(NewJobBuildAndInferCtx(job, job_id));\n  job_name2infer_ctx_.emplace(job_name, std::move(ctx));\n  cur_job_name_ = job_name;\n  has_cur_job_ = true;\n  return Maybe<void>::Ok();\n}\n\nJobBuildAndInferCtx* LazyJobBuildAndInferCtxMgr::NewJobBuildAndInferCtx(Job* job,\n                                                                        int64_t job_id) const {\n  return new LazyJobBuildAndInferCtx(job, job_id);\n}\n\nMaybe<JobBuildAndInferCtx*> JobBuildAndInferCtxMgr::FindJobBuildAndInferCtx(\n    const std::string& job_name) {\n  CHECK_OR_RETURN(job_name2infer_ctx_.find(job_name) != job_name2infer_ctx_.end())\n      << Error::NoJobBuildAndInferCtxError() << \"cannot find job name:\" << job_name;\n  return job_name2infer_ctx_.at(job_name).get();\n}\n\nMaybe<std::string> JobBuildAndInferCtxMgr::GetCurrentJobName() const {\n  CHECK_OR_RETURN(has_cur_job_) << Error::NoJobBuildAndInferCtxError()\n                                << \"current JobBuildAndInferCtx was closed, job name: \"\n                                << cur_job_name_;\n  return cur_job_name_;\n}\n\nMaybe<void> JobBuildAndInferCtxMgr::CloseCurrentJobBuildAndInferCtx() {\n  OF_RETURN_IF_ERROR(VirtualCloseJob());\n  has_cur_job_ = false;\n  return Maybe<void>::Ok();\n}\n\nstd::string JobBuildAndInferCtxMgr::structure_graph() const {\n  nlohmann::json json_array;\n  for (const auto& pair : job_name2infer_ctx_) {\n    nlohmann::json json_pair;\n    json_pair[\"class_name\"] = \"Model\";\n    std::string tmp_json = pair.second->GetJobStructureGraphJson(pair.first);\n    json_pair[\"config\"] = nlohmann::json::parse(tmp_json);\n    json_pair[\"backend\"] = \"oneflow\";\n    json_array.emplace_back(json_pair);\n  }\n  return json_array.dump();\n}\n\nvoid JobBuildAndInferCtxMgr::TryUpdateJobIdCount(int64_t id_count) {\n  job_id_count_ = std::max(id_count, 
job_id_count_);\n}\n\nint64_t JobBuildAndInferCtxMgr::GetJobIdCount() const { return job_id_count_; }\n\nMaybe<void> LazyJobBuildAndInferCtxMgr::VirtualCloseJob() {\n  const JobDesc* job_desc = Singleton<JobDesc>::Get();\n  if (job_desc == nullptr) { return Maybe<void>::Ok(); }\n  CHECK_EQ_OR_RETURN(job_desc->job_name(), *JUST(GetCurrentJobName()));\n  Singleton<JobDesc>::Delete();\n  return Maybe<void>::Ok();\n}\n\nMaybe<JobBuildAndInferCtxMgr*> GlobalJobBuildAndInferCtxMgr() {\n  return JUST(SingletonMaybe<LazyJobBuildAndInferCtxMgr>());\n}\n\nMaybe<JobBuildAndInferCtx*> GetJobBuildAndInferCtx(const std::string& job_name) {\n  auto* mgr = JUST(GlobalJobBuildAndInferCtxMgr());\n  return mgr->FindJobBuildAndInferCtx(job_name);\n}\n\nMaybe<JobBuildAndInferCtx*> GetCurInferCtx() {\n  auto* mgr = JUST(GlobalJobBuildAndInferCtxMgr());\n  return mgr->FindJobBuildAndInferCtx(*JUST(mgr->GetCurrentJobName()));\n}\n\n}  // namespace oneflow\n"
  },
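  {
    "path": "oneflow/core/job/job_build_and_infer_ctx_mgr_demo.cpp",
    "content": "/*\nNOTE: editorial sketch, not part of the original source tree. It illustrates the\nopen/find/close lifecycle implemented in job_build_and_infer_ctx_mgr.cpp, assuming\nthe LazyJobBuildAndInferCtxMgr singleton has already been created by session setup.\nThe file name and the demo function are hypothetical.\n*/\n#include \"oneflow/core/job/job_build_and_infer_ctx_mgr.h\"\n\nnamespace oneflow {\n\nMaybe<void> JobBuildAndInferCtxMgrDemo() {\n  JobBuildAndInferCtxMgr* mgr = JUST(GlobalJobBuildAndInferCtxMgr());\n  // Only one job may be open at a time; opening allocates a job id and a Job proto.\n  JUST(mgr->OpenJobBuildAndInferCtx(\"demo_job\"));\n  CHECK_EQ_OR_RETURN(*JUST(mgr->GetCurrentJobName()), \"demo_job\");\n  // While the job is open, ops are added through the current infer ctx.\n  JobBuildAndInferCtx* ctx = JUST(GetCurInferCtx());\n  (void)ctx;\n  // Closing only clears the \"has current job\" flag; the ctx stays registered\n  // under its name and can still be looked up afterwards.\n  JUST(mgr->CloseCurrentJobBuildAndInferCtx());\n  JUST(GetJobBuildAndInferCtx(\"demo_job\"));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },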
  {
    "path": "oneflow/core/job/job_build_and_infer_ctx_mgr.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_JOB_BUILD_AND_INFER_CXT_MGR_H_\n#define ONEFLOW_CORE_JOB_JOB_BUILD_AND_INFER_CXT_MGR_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job/job_set.pb.h\"\n#include \"oneflow/core/job/job_build_and_infer_ctx.h\"\n\nnamespace oneflow {\n\nclass JobBuildAndInferCtxMgr {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(JobBuildAndInferCtxMgr);\n  virtual ~JobBuildAndInferCtxMgr() = default;\n\n  Maybe<void> OpenJobBuildAndInferCtx(const std::string& job_name);\n  Maybe<JobBuildAndInferCtx*> FindJobBuildAndInferCtx(const std::string& job_name);\n  Maybe<std::string> GetCurrentJobName() const;\n  Maybe<void> CloseCurrentJobBuildAndInferCtx();\n\n  const JobSet& job_set() const { return job_set_; }\n  std::string structure_graph() const;\n  void TryUpdateJobIdCount(int64_t id_count);\n  int64_t GetJobIdCount() const;\n\n protected:\n  virtual JobBuildAndInferCtx* NewJobBuildAndInferCtx(Job* job, int64_t job_id) const = 0;\n  JobBuildAndInferCtxMgr() : has_cur_job_(false) {}\n  virtual Maybe<void> VirtualCloseJob() = 0;\n  JobSet* mut_job_set() { return &job_set_; }\n\n  void clear_job_name2infer_ctx() { job_name2infer_ctx_.clear(); }\n\n private:\n  JobSet job_set_;\n  int64_t job_id_count_{0};\n  bool has_cur_job_;\n  std::string cur_job_name_;\n  HashMap<std::string, std::unique_ptr<JobBuildAndInferCtx>> job_name2infer_ctx_;\n};\n\nclass LazyJobBuildAndInferCtxMgr : public JobBuildAndInferCtxMgr {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LazyJobBuildAndInferCtxMgr);\n  LazyJobBuildAndInferCtxMgr() : JobBuildAndInferCtxMgr() {}\n  ~LazyJobBuildAndInferCtxMgr() override = default;\n\n private:\n  friend class Singleton<LazyJobBuildAndInferCtxMgr>;\n\n  Maybe<void> VirtualCloseJob() override;\n  JobBuildAndInferCtx* NewJobBuildAndInferCtx(Job* job, int64_t job_id) const override;\n};\n\nMaybe<JobBuildAndInferCtxMgr*> GlobalJobBuildAndInferCtxMgr();\nMaybe<JobBuildAndInferCtx*> GetJobBuildAndInferCtx(const std::string& job_name);\nMaybe<JobBuildAndInferCtx*> GetCurInferCtx();\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_JOB_BUILD_AND_INFER_CXT_MGR_H_\n"
  },
  {
    "path": "oneflow/core/job/job_builder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/job_builder.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job/sbp_parallel.pb.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/framework/scope_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nint64_t GetParallelHierarchyNumAxes(\n    const HashMap<std::string, ParallelConf*>& op_name2parallel_conf, const std::string& op_name) {\n  const auto& it = op_name2parallel_conf.find(op_name);\n  CHECK(it != op_name2parallel_conf.end());\n  if (!it->second->has_hierarchy()) {\n    return 1;\n  } else if (it->second->hierarchy().dim_size() == 0) {\n    return 1;\n  } else {\n    return it->second->hierarchy().dim_size();\n  }\n}\n\nvoid SetNdSbpSignature4Oba(Job* job,\n                           HashMap<std::string, NdSbpSignature*>* op_name2nd_sbp_signature_map,\n                           const OpBlobArg& oba, const NdSbp& nd_sbp) {\n  auto* nd_sbp_sig = &(*job->mutable_job_parallel_view_conf()\n                            ->mutable_op_name2nd_sbp_signature_conf())[oba.op_name()];\n  (*nd_sbp_sig->mutable_bn_in_op2nd_sbp())[oba.bn_in_op()] = nd_sbp;\n  auto* op_name2nd_sbp_signature_conf =\n      job->mutable_job_parallel_view_conf()->mutable_op_name2nd_sbp_signature_conf();\n  (*op_name2nd_sbp_signature_map)[oba.op_name()] = &(*op_name2nd_sbp_signature_conf)[oba.op_name()];\n}\n\nvoid SetSbpSignature4Oba(Job* job, const OpBlobArg& oba, const SbpParallel& sbp_parallel) {\n  auto* sbp_sig = &(\n      *job->mutable_job_parallel_view_conf()->mutable_op_name2sbp_signature_conf())[oba.op_name()];\n  (*sbp_sig->mutable_bn_in_op2sbp_parallel())[oba.bn_in_op()] = sbp_parallel;\n}\n\nvoid AddOrSetNdSbpSignature4OpName(\n    Job* job, HashMap<std::string, NdSbpSignature*>* op_name2nd_sbp_signature_map,\n    const std::string& op_name, const NdSbpSignature& nd_sbp_signature) {\n  const auto& it = op_name2nd_sbp_signature_map->find(op_name);\n  if (it != op_name2nd_sbp_signature_map->end()) {\n    *it->second = nd_sbp_signature;\n  } else {\n    auto* op_name2nd_sbp_signature_conf =\n        job->mutable_job_parallel_view_conf()->mutable_op_name2nd_sbp_signature_conf();\n    (*op_name2nd_sbp_signature_conf)[op_name] = nd_sbp_signature;\n    op_name2nd_sbp_signature_map->emplace(op_name, &(*op_name2nd_sbp_signature_conf)[op_name]);\n  }\n}\n\nvoid AddOrSetSbpSignature4OpName(Job* job, const std::string& op_name,\n                                 const SbpSignature& sbp_signature) {\n  auto* op_name2sbp_signature_conf =\n      job->mutable_job_parallel_view_conf()->mutable_op_name2sbp_signature_conf();\n  (*op_name2sbp_signature_conf)[op_name] = sbp_signature;\n}\n\n}  // 
namespace\n\nstd::function<const ParallelConf*(const std::string&)> MakeGetterParallelConf4OpName(\n    const Placement& placement) {\n  auto op_name2parallel_conf = std::make_shared<HashMap<std::string, const ParallelConf*>>();\n  for (const auto& placement_group : placement.placement_group()) {\n    for (const std::string& op_name : placement_group.op_set().op_name()) {\n      const ParallelConf* parallel_conf = &placement_group.parallel_conf();\n      CHECK(op_name2parallel_conf->emplace(op_name, parallel_conf).second)\n          << \"op_name: \" << op_name;\n    }\n  }\n  return [op_name2parallel_conf](const std::string& op_name) {\n    return op_name2parallel_conf->at(op_name);\n  };\n}\n\nJobBuilder::JobBuilder(Job* job) : job_(job) {\n  FOR_RANGE(int32_t, i, 0, job->net().op_size()) {\n    CHECK(op_name2op_conf_.emplace(job->net().op(i).name(), job->mutable_net()->mutable_op(i))\n              .second);\n  }\n  bool all_ops_1d_hierarchy = true;\n  FOR_RANGE(int32_t, i, 0, job->placement().placement_group_size()) {\n    auto* placement_group = job->mutable_placement()->mutable_placement_group(i);\n    if (placement_group->parallel_conf().has_hierarchy()\n        && placement_group->parallel_conf().hierarchy().dim_size() > 1) {\n      all_ops_1d_hierarchy = false;\n    }\n  }\n  auto* job_parallel_view_conf = job->mutable_job_parallel_view_conf();\n  for (auto& pair : *(job_parallel_view_conf->mutable_op_name2nd_sbp_signature_conf())) {\n    op_name2nd_sbp_signature_conf_.emplace(pair.first, &pair.second);\n  }\n  if (all_ops_1d_hierarchy) {\n    CHECK_EQ(job_parallel_view_conf->op_name2sbp_signature_conf_size(),\n             job_parallel_view_conf->op_name2nd_sbp_signature_conf_size());\n    for (const auto& pair : job_parallel_view_conf->op_name2nd_sbp_signature_conf()) {\n      const auto& op_name2sbp_sig = job_parallel_view_conf->op_name2sbp_signature_conf();\n      const auto it = op_name2sbp_sig.find(pair.first);\n      CHECK(it != op_name2sbp_sig.end());\n      CheckSbpSignatureAndNdSbpEquals(SbpSignature(it->second), NdSbpSignature(pair.second));\n    }\n  }\n  FOR_RANGE(int32_t, i, 0, job->placement().blob_placement_group_size()) {\n    auto* blob_pg = job->mutable_placement()->mutable_blob_placement_group(i);\n    for (const auto& lbi : blob_pg->lbi()) {\n      CHECK(lbi2blob_parallel_conf_.emplace(lbi, blob_pg->mutable_parallel_conf()).second);\n    }\n  }\n  for (auto& placement_group : *job->mutable_placement()->mutable_placement_group()) {\n    if (placement_group.op_set().op_name().empty()) { continue; }\n    const ParallelConf& parallel_conf = placement_group.parallel_conf();\n    auto it = parallel_conf2placement_group_.find(parallel_conf);\n    if (it == parallel_conf2placement_group_.end()) {\n      parallel_conf2placement_group_.emplace(parallel_conf, &placement_group);\n      for (const auto& op_name : placement_group.op_set().op_name()) {\n        CHECK(op_name2parallel_conf_.emplace(op_name, placement_group.mutable_parallel_conf())\n                  .second);\n      }\n    } else {\n      PlacementGroup* existing_placement_group = it->second;\n      for (const auto& op_name : placement_group.op_set().op_name()) {\n        *existing_placement_group->mutable_op_set()->mutable_op_name()->Add() = op_name;\n        CHECK(op_name2parallel_conf_\n                  .emplace(op_name, existing_placement_group->mutable_parallel_conf())\n                  .second);\n      }\n      placement_group.mutable_op_set()->mutable_op_name()->Clear();\n    }\n  
}\n}\n\nMaybe<OperatorConf*> JobBuilder::MutableOpConf4OpName(const std::string& op_name) {\n  const auto& it = op_name2op_conf_.find(op_name);\n  CHECK_OR_RETURN(it != op_name2op_conf_.end());\n  return it->second;\n}\n\nMaybe<const OperatorConf&> JobBuilder::OpConf4OpName(const std::string& op_name) const {\n  return *JUST(MapAt(op_name2op_conf_, op_name));\n}\n\nMaybe<const ParallelConf&> JobBuilder::ParallelConf4Lbi(const LogicalBlobId& lbi) const {\n  const auto& iter = lbi2blob_parallel_conf_.find(lbi);\n  if (iter != lbi2blob_parallel_conf_.end()) { return *iter->second; }\n  return ParallelConf4OpName(lbi.op_name());\n}\n\nMaybe<void> JobBuilder::AddOp(const ParallelConf& parallel_conf, const OperatorConf& op_conf) {\n  CHECK_OR_RETURN(op_name2op_conf_.find(op_conf.name()) == op_name2op_conf_.end());\n  OperatorConf* mut_op_conf = job_->mutable_net()->add_op();\n  *mut_op_conf = op_conf;\n  CHECK_OR_RETURN(op_name2op_conf_.emplace(op_conf.name(), mut_op_conf).second);\n  AddOpToModuleConf(op_conf);\n  AddOpNamesToPlacementGroup({op_conf.name()}, parallel_conf);\n  return Maybe<void>::Ok();\n}\n\nvoid JobBuilder::AddOps(const ParallelConf& parallel_conf,\n                        const std::vector<OperatorConf>& op_confs) {\n  if (op_confs.empty()) { return; }\n  std::vector<std::string> op_names;\n  op_names.reserve(op_confs.size());\n  for (const auto& op_conf : op_confs) {\n    CHECK(op_name2op_conf_.find(op_conf.name()) == op_name2op_conf_.end());\n    OperatorConf* mut_op_conf = job_->mutable_net()->add_op();\n    *mut_op_conf = op_conf;\n    CHECK(op_name2op_conf_.emplace(op_conf.name(), mut_op_conf).second);\n    op_names.emplace_back(op_conf.name());\n    AddOpToModuleConf(op_conf);\n  }\n  AddOpNamesToPlacementGroup(op_names, parallel_conf);\n}\n\nvoid JobBuilder::AddOpToModuleConf(const OperatorConf& op_conf) {\n  // set up the module config\n  if (Singleton<symbol::Storage<Scope>>::Get()->Has(op_conf.scope_symbol_id())) {\n    const auto& scope = Singleton<symbol::Storage<Scope>>::Get()->Get(op_conf.scope_symbol_id());\n    if (scope.scope_proto().has_module_name()) {\n      const auto& module_name = scope.scope_proto().module_name();\n      auto* module_name2module_conf = job_->mutable_module_name2module_conf();\n      if (!(*module_name2module_conf)[module_name].has_name()) {\n        (*module_name2module_conf)[module_name].set_name(scope.scope_proto().module_name());\n      }\n\n      *((*module_name2module_conf)[module_name].add_ops()) = op_conf.name();\n      return;\n    }\n  }\n  const auto& module_name = job_->job_conf().job_name();\n  auto* module_name2module_conf = job_->mutable_module_name2module_conf();\n  if (!(*module_name2module_conf)[module_name].has_name()) {\n    (*module_name2module_conf)[module_name].set_name(module_name);\n  }\n\n  *((*module_name2module_conf)[module_name].add_ops()) = op_conf.name();\n}\n\nvoid JobBuilder::AddOpNamesToPlacementGroup(const std::vector<std::string>& op_names,\n                                            const ParallelConf& parallel_conf) {\n  PlacementGroup* placement_group = nullptr;\n  auto it = parallel_conf2placement_group_.find(parallel_conf);\n  if (it != parallel_conf2placement_group_.end()) {\n    placement_group = it->second;\n  } else {\n    placement_group = job_->mutable_placement()->add_placement_group();\n    *placement_group->mutable_parallel_conf() = parallel_conf;\n    parallel_conf2placement_group_.emplace(parallel_conf, placement_group);\n  }\n  for (const auto& op_name : op_names) {\n    
placement_group->mutable_op_set()->add_op_name(op_name);\n    CHECK(op_name2parallel_conf_.emplace(op_name, placement_group->mutable_parallel_conf()).second);\n  }\n}\n\nvoid JobBuilder::MutParallelConfOnlyOnce(const std::string& op_name,\n                                         const ParallelConf& parallel_conf) {\n  CHECK(modified_parallel_conf_op_names_.emplace(op_name).second);\n  const auto& parallel_conf_it = op_name2parallel_conf_.find(op_name);\n  CHECK(parallel_conf_it != op_name2parallel_conf_.end());\n  auto old_placement_group_it = parallel_conf2placement_group_.find(*parallel_conf_it->second);\n  CHECK(old_placement_group_it != parallel_conf2placement_group_.end());\n  op_name2parallel_conf_.erase(parallel_conf_it);\n  Erase<PbRpf<std::string>>(*old_placement_group_it->second->mutable_op_set()->mutable_op_name(),\n                            [&](const std::string& x) { return x == op_name; });\n  AddOpNamesToPlacementGroup({op_name}, parallel_conf);\n}\n\nvoid JobBuilder::RemoveOpByName(const std::string& op_name) {\n  RemoveOpByName(std::unordered_set<std::string>{op_name});\n}\n\nvoid JobBuilder::RemoveOpByName(const std::unordered_set<std::string>& removing_names) {\n  // Update net\n  DLNetConf net = job_->net();\n  job_->mutable_net()->clear_op();\n  for (const OperatorConf& op_conf : net.op()) {\n    if (removing_names.count(op_conf.name()) == 0) { *(job_->mutable_net()->add_op()) = op_conf; }\n  }\n  // Update module conf\n  auto module_confs_map = job_->module_name2module_conf();\n  job_->clear_module_name2module_conf();\n  for (const auto& module_conf_pair : module_confs_map) {\n    const auto& module_name = module_conf_pair.first;\n    auto* module_name2module_conf = job_->mutable_module_name2module_conf();\n    if (!(*module_name2module_conf)[module_name].has_name()) {\n      (*module_name2module_conf)[module_name].set_name(module_name);\n    }\n    for (const auto& op_name : module_conf_pair.second.ops()) {\n      if (removing_names.count(op_name) == 0) {\n        *((*module_name2module_conf)[module_name].add_ops()) = op_name;\n      }\n    }\n  }\n  // Update placement\n  auto placement_group = job_->placement().placement_group();\n  job_->mutable_placement()->clear_placement_group();\n  for (const PlacementGroup& place : placement_group) {\n    PlacementGroup p;\n    OpNameSet* op_set = p.mutable_op_set();\n    for (const std::string& name : place.op_set().op_name()) {\n      if (removing_names.count(name) == 0) { op_set->add_op_name(name); }\n    }\n\n    *(p.mutable_parallel_conf()) = place.parallel_conf();\n    if (op_set->op_name().size() > 0) { *(job_->mutable_placement()->add_placement_group()) = p; }\n  }\n\n  auto* op_name2sbp_signature_conf =\n      job_->mutable_job_parallel_view_conf()->mutable_op_name2sbp_signature_conf();\n  auto* op_name2nd_sbp_signature_conf =\n      job_->mutable_job_parallel_view_conf()->mutable_op_name2nd_sbp_signature_conf();\n  for (const std::string& op_name : removing_names) {\n    // Update NdSbp, Sbp\n    if (op_name2nd_sbp_signature_conf->count(op_name) > 0) {\n      op_name2nd_sbp_signature_conf->erase(op_name);\n      if (GetParallelHierarchyNumAxes(op_name2parallel_conf_, op_name) == 1) {\n        CHECK(op_name2sbp_signature_conf->count(op_name) > 0);\n        op_name2sbp_signature_conf->erase(op_name);\n      }\n    }\n  }\n  // Update builder\n  JobBuilder builder(job_);\n  op_name2op_conf_.swap(builder.op_name2op_conf_);\n  op_name2parallel_conf_.swap(builder.op_name2parallel_conf_);\n  
op_name2nd_sbp_signature_conf_.swap(builder.op_name2nd_sbp_signature_conf_);\n  parallel_conf2placement_group_.swap(builder.parallel_conf2placement_group_);\n}\n\nvoid JobBuilder::DelOps(const std::vector<std::string>& op_names) {\n  std::unordered_set<std::string> removing_names;\n  for (const auto& op_name : op_names) { removing_names.insert(op_name); }\n  RemoveOpByName(removing_names);\n}\n\nvoid JobBuilder::DelOps(const std::vector<OperatorConf>& op_confs) {\n  std::unordered_set<std::string> removing_names;\n  for (const auto& op_conf : op_confs) { removing_names.insert(op_conf.name()); }\n  RemoveOpByName(removing_names);\n}\n\nMaybe<void> JobBuilder::MutOpOnlyOnce(const OperatorConf& op_conf) {\n  CHECK_OR_RETURN(modified_op_conf_op_names_.emplace(op_conf.name()).second)\n      << op_conf.name() << \" is mut twice.\";\n  auto find_iter = op_name2op_conf_.find(op_conf.name());\n  CHECK_OR_RETURN(find_iter != op_name2op_conf_.end()) << op_conf.name() << \" not found.\";\n  find_iter->second->CopyFrom(op_conf);\n  return Maybe<void>::Ok();\n}\n\nvoid JobBuilder::MutOpsOnlyOnce(const std::vector<OperatorConf>& op_confs) {\n  for (const auto& op_conf : op_confs) {\n    CHECK(modified_op_conf_op_names_.emplace(op_conf.name()).second)\n        << op_conf.name() << \" is mut twice.\";\n    op_name2op_conf_.at(op_conf.name())->CopyFrom(op_conf);\n  }\n}\n\nMaybe<bool> JobBuilder::IsInMutOpTransaction(const std::string& op_name) const {\n  auto find_iter = mut_op_transaction_name2op_conf_.find(op_name);\n  return find_iter != mut_op_transaction_name2op_conf_.end();\n}\n\nMaybe<OperatorConf&> JobBuilder::MutOpTransactionGet(const std::string& op_name) {\n  return JUST(MapAt(mut_op_transaction_name2op_conf_, op_name));\n}\n\nMaybe<void> JobBuilder::MutOpTransactionMut(const OperatorConf& op_conf) {\n  auto find_iter = mut_op_transaction_name2op_conf_.find(op_conf.name());\n  if (find_iter == mut_op_transaction_name2op_conf_.end()) {\n    CHECK_OR_RETURN(mut_op_transaction_name2op_conf_.emplace(op_conf.name(), op_conf).second)\n        << op_conf.name() << \" has been added.\";\n  } else {\n    find_iter->second.CopyFrom(op_conf);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> JobBuilder::MutOpTransactionCommit() {\n  for (const auto& pair : mut_op_transaction_name2op_conf_) { JUST(MutOpOnlyOnce(pair.second)); }\n  return Maybe<void>::Ok();\n}\n\nvoid JobBuilder::AddOrMutOpsOnlyOnce(const ParallelConf& parallel_conf,\n                                     const std::vector<OperatorConf>& op_confs) {\n  std::vector<OperatorConf> add_ops;\n  std::vector<OperatorConf> mut_ops;\n  for (const auto& op_conf : op_confs) {\n    if (op_name2op_conf_.find(op_conf.name()) == op_name2op_conf_.end()) {\n      add_ops.emplace_back(op_conf);\n    } else {\n      mut_ops.emplace_back(op_conf);\n    }\n  }\n  AddOps(parallel_conf, add_ops);\n  MutOpsOnlyOnce(mut_ops);\n}\n\nMaybe<void> JobBuilder::ForEachOperator(\n    const std::function<Maybe<void>(const Operator&)>& Handler) const {\n  for (const auto& pair : op_name2op_conf_) {\n    auto it = op_name2parallel_conf_.find(pair.first);\n    CHECK_OR_RETURN(it != op_name2parallel_conf_.end()) << \"op_name: \" << pair.first;\n    DeviceType device_type = ParallelDesc(*it->second).device_type();\n    std::shared_ptr<Operator> op = JUST(ConstructOp(*pair.second, device_type));\n    JUST(Handler(*op));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<const ParallelConf&> JobBuilder::ParallelConf4OpName(const std::string& op_name) const {\n  const auto& iter = 
op_name2parallel_conf_.find(op_name);\n  CHECK_OR_RETURN(iter != op_name2parallel_conf_.end());\n  return *iter->second;\n}\n\nSbpParallel* JobBuilder::MutSbpParallel4Oba(const OpBlobArg& oba) const {\n  // TODO(guoran): rm this func\n  auto* sbp_sig = &(\n      *job_->mutable_job_parallel_view_conf()->mutable_op_name2sbp_signature_conf())[oba.op_name()];\n  return &(*sbp_sig->mutable_bn_in_op2sbp_parallel())[oba.bn_in_op()];\n}\n\nvoid JobBuilder::SetSbpParallel4Oba(const OpBlobArg& oba, const SbpParallel& sbp_parallel) {\n  CHECK_EQ(GetParallelHierarchyNumAxes(op_name2parallel_conf_, oba.op_name()), 1);\n  SetSbpSignature4Oba(job_, oba, sbp_parallel);\n  NdSbp nd_sbp;\n  *nd_sbp.add_sbp_parallel() = sbp_parallel;\n  SetNdSbpSignature4Oba(job_, &op_name2nd_sbp_signature_conf_, oba, nd_sbp);\n}\n\nvoid JobBuilder::SetNdSbp4Oba(const OpBlobArg& oba, const NdSbp& nd_sbp) {\n  SetNdSbpSignature4Oba(job_, &op_name2nd_sbp_signature_conf_, oba, nd_sbp);\n  if (GetParallelHierarchyNumAxes(op_name2parallel_conf_, oba.op_name()) == 1) {\n    SetSbpSignature4Oba(job_, oba, nd_sbp.sbp_parallel(0));\n  }\n}\n\nconst SbpSignature JobBuilder::SbpSignature4OpName(const std::string& op_name) const {\n  CHECK_EQ(GetParallelHierarchyNumAxes(op_name2parallel_conf_, op_name), 1);\n  const auto& it = op_name2nd_sbp_signature_conf_.find(op_name);\n  CHECK(it != op_name2nd_sbp_signature_conf_.end());\n\n  SbpSignature sbp_sig_conf;\n  NdSbpSignatureToSbpSignature(*it->second, &sbp_sig_conf);\n  return sbp_sig_conf;\n}\n\nvoid JobBuilder::AddSbpSignature4OpName(const std::string& op_name,\n                                        const SbpSignature& sbp_signature) {\n  NdSbpSignature nd_sbp_signature;\n  SbpSignatureToNdSbpSignature(sbp_signature, &nd_sbp_signature);\n  AddOrSetNdSbpSignature4OpName(job_, &op_name2nd_sbp_signature_conf_, op_name, nd_sbp_signature);\n  CHECK_EQ(GetParallelHierarchyNumAxes(op_name2parallel_conf_, op_name), 1);\n  AddOrSetSbpSignature4OpName(job_, op_name, sbp_signature);\n}\n\nconst NdSbpSignature& JobBuilder::NdSbpSignature4OpName(const std::string& op_name) const {\n  const auto& it = op_name2nd_sbp_signature_conf_.find(op_name);\n  CHECK(it != op_name2nd_sbp_signature_conf_.end());\n  return *(it->second);\n}\n\nvoid JobBuilder::AddNdSbpSignature4OpName(const std::string& op_name,\n                                          const NdSbpSignature& nd_sbp_signature) {\n  AddOrSetNdSbpSignature4OpName(job_, &op_name2nd_sbp_signature_conf_, op_name, nd_sbp_signature);\n  if (GetParallelHierarchyNumAxes(op_name2parallel_conf_, op_name) == 1) {\n    SbpSignature sbp_signature;\n    NdSbpSignatureToSbpSignature(nd_sbp_signature, &sbp_signature);\n    AddOrSetSbpSignature4OpName(job_, op_name, sbp_signature);\n  }\n}\n\n}  // namespace oneflow\n"
  },
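  {
    "path": "oneflow/core/job/job_builder_remove_demo.cpp",
    "content": "/*\nNOTE: editorial sketch, not part of the original source tree. Removal in\njob_builder.cpp is rebuild-based: RemoveOpByName re-creates the net, module confs\nand placement without the removed names, then constructs a fresh JobBuilder over\nthe same Job and swaps in its lookup tables. The file name, DemoRemoveOps and the\nop names are hypothetical.\n*/\n#include \"oneflow/core/job/job_builder.h\"\n\nnamespace oneflow {\n\nvoid DemoRemoveOps(Job* job) {\n  JobBuilder builder(job);\n  // The single-name overload wraps the set-based one.\n  builder.RemoveOpByName(\"obsolete_op_a\");\n  // DelOps also funnels into RemoveOpByName(const std::unordered_set<...>&).\n  builder.DelOps(std::vector<std::string>{\"obsolete_op_b\", \"obsolete_op_c\"});\n}\n\n}  // namespace oneflow\n"
  },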
  {
    "path": "oneflow/core/job/job_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_JOB_CONF_BUILDER_H_\n#define ONEFLOW_CORE_JOB_JOB_CONF_BUILDER_H_\n\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/register/op_blob_arg.pb.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n\nnamespace oneflow {\n\nconst static std::string kProducedLbi2ConsumedDiffLbi = \"produced_lbi2consumed_diff_lbi\";\n\nstd::function<const ParallelConf*(const std::string&)> MakeGetterParallelConf4OpName(\n    const Placement& placement);\n\nclass SbpParallel;\nclass LogicalBlobId;\nclass Operator;\n\nclass JobBuilder final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(JobBuilder);\n  explicit JobBuilder(Job* job);\n  ~JobBuilder() = default;\n\n  const Job& job() const { return *job_; }\n  JobHelperConf* mutable_helper() { return job_->mutable_helper(); }\n  JobParallelViewConf* mutable_job_parallel_view_conf() {\n    return job_->mutable_job_parallel_view_conf();\n  }\n\n  MergedLogicalChainIdGroup* add_logical_chain_groups() { return job_->add_logical_chain_groups(); }\n\n  Maybe<const OperatorConf&> OpConf4OpName(const std::string& op_name) const;\n  Maybe<OperatorConf*> MutableOpConf4OpName(const std::string& op_name);\n\n  Maybe<void> AddOp(const ParallelConf& parallel_conf, const OperatorConf& op_conf);\n  void AddOps(const ParallelConf& parallel_conf, const std::vector<OperatorConf>& op_confs);\n  Maybe<void> MutOpOnlyOnce(const OperatorConf& op_conf);\n  void MutOpsOnlyOnce(const std::vector<OperatorConf>& op_confs);\n  // Mut op with transaction\n  Maybe<bool> IsInMutOpTransaction(const std::string& op_name) const;\n  Maybe<OperatorConf&> MutOpTransactionGet(const std::string& op_name);\n  Maybe<void> MutOpTransactionMut(const OperatorConf& op_conf);\n  Maybe<void> MutOpTransactionCommit();\n  void MutParallelConfOnlyOnce(const std::string& op_name, const ParallelConf& parallel_conf);\n  void AddOrMutOpsOnlyOnce(const ParallelConf& parallel_conf,\n                           const std::vector<OperatorConf>& op_confs);\n\n  void RemoveOpByName(const std::string& op_name);\n  void RemoveOpByName(const std::unordered_set<std::string>& removing_names);\n  void DelOps(const std::vector<std::string>& op_names);\n  void DelOps(const std::vector<OperatorConf>& op_confs);\n\n  SbpParallel* MutSbpParallel4Oba(const OpBlobArg& oba) const;\n  void SetSbpParallel4Oba(const OpBlobArg& oba, const SbpParallel& sbp_parallel);\n  void SetNdSbp4Oba(const OpBlobArg& oba, const NdSbp& nd_sbp);\n  Maybe<void> ForEachOperator(const std::function<Maybe<void>(const Operator&)>& Handler) const;\n\n  Maybe<const ParallelConf&> ParallelConf4Lbi(const LogicalBlobId& lbi) const;\n  Maybe<const ParallelConf&> ParallelConf4OpName(const std::string& op_name) const;\n\n  const SbpSignature SbpSignature4OpName(const std::string& op_name) const;\n  void AddSbpSignature4OpName(const std::string& op_name, const SbpSignature& sbp_signature);\n\n  const NdSbpSignature& 
NdSbpSignature4OpName(const std::string& op_name) const;\n  void AddNdSbpSignature4OpName(const std::string& op_name, const NdSbpSignature& nd_sbp_signature);\n\n private:\n  void AddOpNamesToPlacementGroup(const std::vector<std::string>& op_names,\n                                  const ParallelConf& parallel_conf);\n  void AddOpToModuleConf(const OperatorConf& op_conf);\n\n  Job* job_;\n  HashMap<std::string, OperatorConf*> op_name2op_conf_;\n  HashMap<std::string, ParallelConf*> op_name2parallel_conf_;\n  HashMap<LogicalBlobId, ParallelConf*> lbi2blob_parallel_conf_;\n  HashSet<std::string> modified_op_conf_op_names_;\n  HashSet<std::string> modified_parallel_conf_op_names_;\n\n  HashMap<std::string, NdSbpSignature*> op_name2nd_sbp_signature_conf_;\n  HashMap<ParallelConf, PlacementGroup*> parallel_conf2placement_group_;\n  HashMap<std::string, OperatorConf> mut_op_transaction_name2op_conf_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_JOB_CONF_BUILDER_H_\n"
  },
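  {
    "path": "oneflow/core/job/job_builder_demo.cpp",
    "content": "/*\nNOTE: editorial sketch, not part of the original source tree. It demonstrates the\nJobBuilder mutation API declared in job_builder.h: AddOp registers an op together\nwith its placement, and the MutOpTransaction* calls stage mutations that\nMutOpTransactionCommit applies through MutOpOnlyOnce. The file name, DemoBuildJob\nand the op name are hypothetical; a scope symbol storage is assumed to be\ninitialized so that AddOpToModuleConf can consult it.\n*/\n#include \"oneflow/core/job/job_builder.h\"\n\nnamespace oneflow {\n\nMaybe<void> DemoBuildJob(Job* job) {\n  JobBuilder builder(job);\n\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  parallel_conf.add_device_name(\"0:0\");  // machine 0, device 0\n\n  OperatorConf op_conf;\n  op_conf.set_name(\"demo_op\");\n  JUST(builder.AddOp(parallel_conf, op_conf));\n\n  // Stage a mutation and commit it; committing a second mutation of the same op\n  // would fail the \"only once\" check.\n  op_conf.set_device_tag(\"cpu\");\n  JUST(builder.MutOpTransactionMut(op_conf));\n  JUST(builder.MutOpTransactionCommit());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },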
  {
    "path": "oneflow/core/job/job_conf.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/common/data_type.proto\";\nimport \"oneflow/core/job/placement.proto\";\nimport \"oneflow/core/register/blob_desc.proto\";\nimport \"oneflow/core/job/sbp_parallel.proto\";\nimport \"oneflow/core/framework/user_op_attr.proto\";\nimport \"oneflow/core/job/initializer_conf.proto\";\nimport \"oneflow/core/job/learning_rate_schedule_conf.proto\";\nimport \"oneflow/core/register/logical_blob_id.proto\";\nimport \"oneflow/core/operator/interface_blob_conf.proto\";\n\nmessage NaiveModelUpdateConf {\n}\n\nmessage MomentumModelUpdateConf {\n  optional float beta = 1 [default = 0.9];\n  optional float dampening = 2 [default = 0.0];\n  optional bool nesterov = 3 [default = false];\n  optional bool maximize = 4 [default = false];\n}\n\nmessage RMSPropModelUpdateConf {\n  optional float decay_rate = 1 [default = 0.99];\n  optional float epsilon = 2 [default = 1e-8];\n  optional bool centered = 3 [default = false];\n}\n\nmessage LARSModelUpdateConf {\n  optional float momentum_beta = 1 [default = 0.9];\n  optional float epsilon = 2 [default = 1e-9];\n  optional float lars_coefficient = 3 [default = 0.0001];\n}\n\nmessage AdamModelUpdateConf {\n  optional float beta1 = 1 [default = 0.9];\n  optional float beta2 = 2 [default = 0.999];\n  optional float epsilon = 3 [default = 1e-8];\n  optional bool do_bias_correction = 4 [default = true];\n  optional bool amsgrad = 5 [default = false];\n  optional bool smart_decay = 6 [default = false];\n}\n\nmessage LazyAdamModelUpdateConf {\n  optional float beta1 = 1 [default = 0.9];\n  optional float beta2 = 2 [default = 0.999];\n  optional float epsilon = 3 [default = 1e-8];\n  optional bool do_bias_correction = 4 [default = true];\n  optional bool amsgrad = 5 [default = false];\n}\n\nmessage LambModelUpdateConf {\n  optional float beta1 = 1 [default = 0.9];\n  optional float beta2 = 2 [default = 0.999];\n  optional float epsilon = 3 [default = 1e-8];\n  optional bool do_bias_correction = 4 [default = true];\n}\n\nmessage AdagradModelUpdateConf {\n  required float lr_decay = 1 [default = 0.0];\n  required float initial_accumulator_value = 2 [default = 0.0];\n  required float epsilon = 3 [default = 1e-10];\n}\n\nmessage FtrlModelUpdateConf {\n  required float initial_accumulator_value = 1 [default = 0.1];\n  required float lr_power = 2 [default = 0.5];\n  optional float lambda1 = 3 [default = 0.0];\n  optional float lambda2 = 4 [default = 0.0];\n  optional float beta = 5 [default = 0.0];\n}\n\nmessage AdadeltaModelUpdateConf {\n  required float rho = 1 [default = 0.9];\n  required float epsilon = 2 [default = 1e-6];\n  required bool maximize = 3 [default = false];\n}\n\nmessage ClipByGlobalNormConf {\n  optional float max_norm = 1 [default = 1.0];\n  optional double norm_type = 2 [default = 2.0];\n}\n\nmessage ClipConf {\n  oneof type {\n    ClipByGlobalNormConf clip_by_global_norm = 1;\n  }\n}\n\nmessage WeightDecayFilterPatternSet {\n  repeated string pattern = 1;\n}\n\nmessage WeightDecayConf {\n  required float weight_decay_rate = 1;\n  oneof weight_decay_filter_type {\n    WeightDecayFilterPatternSet includes = 2;\n    WeightDecayFilterPatternSet excludes = 3;\n  }\n}\n\nmessage OptimizerConf {\n  repeated string variable_op_names = 1;\n  optional float base_learning_rate = 2;\n  repeated string variable_grad_lbns = 3;\n  optional LearningRateDecayConf learning_rate_decay = 4;\n  optional string learning_rate_lbn = 5;\n  optional ClipConf clip_conf = 6;\n  optional WeightDecayConf 
weight_decay_conf = 7;\n  optional float lr_scale = 8 [default = 1.0];\n  oneof normal_mdupdt {\n    NaiveModelUpdateConf naive_conf = 1000;\n    MomentumModelUpdateConf momentum_conf = 1001;\n    RMSPropModelUpdateConf rmsprop_conf = 1002;\n    LARSModelUpdateConf lars_conf = 1003;\n    AdamModelUpdateConf adam_conf = 1004;\n    LazyAdamModelUpdateConf lazy_adam_conf = 1005;\n    LambModelUpdateConf lamb_conf = 1006;\n    AdagradModelUpdateConf adagrad_conf = 1007;\n    FtrlModelUpdateConf ftrl_conf = 1008;\n    AdadeltaModelUpdateConf adadelta_conf = 1009; \n  }\n}\n\nmessage NormalModelUpdateOpUserConf {\n  optional LearningRateDecayConf learning_rate_decay = 1;\n  optional ClipConf clip_conf = 3;\n  optional WeightDecayConf weight_decay_conf = 4;\n  oneof normal_mdupdt {\n    NaiveModelUpdateConf naive_conf = 1000;\n    MomentumModelUpdateConf momentum_conf = 1001;\n    RMSPropModelUpdateConf rmsprop_conf = 1002;\n    LARSModelUpdateConf lars_conf = 1003;\n    AdamModelUpdateConf adam_conf = 1004;\n    LazyAdamModelUpdateConf lazy_adam_conf = 1005;\n    LambModelUpdateConf lamb_conf = 1006;\n    AdagradModelUpdateConf adagrad_conf = 1007;\n    FtrlModelUpdateConf ftrl_conf = 1008;\n  }\n}\n\nmessage DynamicLossScalePolicy {\n  optional float initial_loss_scale = 1 [default = 1073741824.0];\n  optional float increment_period = 2 [default = 2000];\n  optional float multiplier = 3 [default=2.0];\n}\n\nmessage TrainConf {\n  repeated OptimizerConf optimizer_conf = 1;\n  repeated string loss_lbn = 2;\n  repeated string loss_grad_lbn = 6;\n  optional string train_step_lbn = 3;\n  oneof loss_scale_policy {\n    float loss_scale_factor = 4 [default = 1];\n    DynamicLossScalePolicy dynamic_loss_scale_policy = 5;\n  }\n  // Deprecated model update conf, will be removed later.\n  optional NormalModelUpdateOpUserConf model_update_conf = 101;\n  optional float primary_lr = 102;\n  optional float secondary_lr = 103;\n  optional string primary_lr_lbn = 104;\n  optional string secondary_lr_lbn = 105;\n}\n\nmessage PredictConf {\n}\n\nmessage MemoryAllocationAlgorithmConf {\n  optional bool use_mem_size_first_algo = 1 [default = true];\n  optional bool use_lifetime_first_algo = 2 [default = false];\n  optional bool use_time_line_algo = 3 [default = false];\n  optional bool use_mem_volume_first_algo = 4 [default = false];\n}\n\nmessage MemoryCompactInsertConf {\n  optional bool use_compact_insert = 1 [default = false];\n  optional bool use_non_compact_insert = 2 [default = true];\n}\n\nmessage QatConfig {\n  optional bool per_channel_weight_quantization = 1 [default = false];\n  optional bool symmetric = 2 [default = true];\n  optional float moving_min_max_momentum = 3 [default = 0.95];\n  optional int64 moving_min_max_stop_update_after_iters = 4;\n  optional string target_backend = 5 [default = \"\"];\n}\n\nmessage IndexedSlicesOptimizerConf {\n  optional bool enable = 1 [default = true];\n  required OpNameSet include_op_names = 2;\n}\n\nmessage ParallelBlobConf {\n  required BlobDescProto logical_blob_desc_conf = 1;\n  required ParallelConf parallel_conf = 2;\n  required NdSbp nd_sbp = 3;\n}\n\nmessage JobInputDef {\n  required LogicalBlobId lbi = 1;\n  required InterfaceBlobConf blob_conf = 2;\n}\n\nmessage JobOutputDef {\n  required LogicalBlobId lbi = 1;\n}\n\nmessage JobSignatureDef {\n  map<string, JobInputDef> inputs = 1;\n  map<string, JobOutputDef> outputs = 2;\n}\n\nenum StraightenAlgorithmTag {\n  kDisableStraighten = 1;\n  kOverlap4Transfer = 2;\n  kCompressMemory = 3;\n  kOverlap4CpuGpu = 
4;\n  kDelayShortGpu = 5;\n}\n\nenum AutoMemoryStrategy {\n  kDisableAutoMemory = 1;\n  kSlightAutoMemory = 2;\n  kModerateAutoMemory = 3;\n  kHeavyAutoMemory = 4;\n  kAdaptiveAutoMemory = 5;\n}\n\nmessage JobConfigProto {\n  required string job_name = 1;\n\n  oneof job_type {\n    TrainConf train_conf = 3;\n    PredictConf predict_conf = 4;\n  }\n  optional DataType default_data_type = 8 [default = kFloat]; // kFloat or kDouble\n  oneof default_initialize_conf {\n    InitializerConf default_initializer_conf = 10;\n    string default_initialize_with_snapshot_path = 11;\n  }\n\n  optional MemoryAllocationAlgorithmConf memory_allocation_algorithm_conf = 102;\n  optional MemoryCompactInsertConf memory_compact_insert_conf = 103;\n\n  optional IndexedSlicesOptimizerConf indexed_slices_optimizer_conf = 104;\n  optional bool enable_fuse_model_update_ops = 105 [default = false];\n  optional bool enable_gradients_stats_aggregation = 106 [default = true];\n  optional string optimizer_placement_optimization_mode = 107;\n  optional int64 optimizer_placement_optimization_threshold = 108 [default = 1024];\n  optional int64 optimizer_placement_optimization_shard_restore_level = 110 [default = 2];\n\n  optional QatConfig qat_config = 109;\n\n  optional bool enable_cudnn = 200 [default = true];\n  optional int64 cudnn_buf_limit_mbyte = 201 [default = 1024];  // 1GByte\n  optional int32 cudnn_conv_force_fwd_algo = 202;\n  optional int32 cudnn_conv_force_bwd_data_algo = 203;\n  optional int32 cudnn_conv_force_bwd_filter_algo = 204;\n  optional bool cudnn_conv_heuristic_search_algo = 205 [default = true];\n  optional bool cudnn_conv_use_deterministic_algo_only = 206 [default = false];\n  optional bool enable_cudnn_fused_normalization_add_relu = 207;\n  optional bool enable_fuse_add_to_output = 208 [default = false];\n  optional bool enable_fuse_cast_scale = 209 [default = false];\n  optional int64 num_gradient_accumulation_steps = 210;\n\n  optional bool enable_reuse_mem = 300 [default = true];\n  optional bool enable_inplace = 301 [default = true];\n  optional bool enable_inplace_in_reduce_struct = 302 [default = true];\n\n  optional bool do_parallel_cast_before_widening_type_cast = 403 [default = true];\n\n  optional bool prune_parallel_cast_ops = 509 [default = true];\n  optional bool prune_cast_to_static_shape_ops = 510 [default = true];\n  optional bool prune_amp_white_identity_ops = 511 [default = true];\n  optional bool prune_depend_ops = 512 [default = true];\n\n  optional bool cudnn_conv_enable_pseudo_half = 600 [default = true];\n  optional bool enable_auto_mixed_precision = 602 [default = false];\n  optional bool enable_quantization_aware_training = 603 [default = false];\n  optional DataType mixed_precision_data_type = 604 [default = kFloat16]; // kFloat16 or kBFloat16\n  optional bool enable_multi_tensor_update = 605 [default = false];\n  optional bool enable_fused_model_update_cast = 606 [default = false];\n\n  optional bool enable_auto_parallel = 700 [default = false];\n  optional double auto_parallel_computation_cost_ratio = 701 [default = 0.05];\n  optional double auto_parallel_wait_time = 702 [default = 1.65e4];\n  optional bool enable_auto_parallel_trunk_algo = 703 [default = true];\n  optional bool enable_auto_parallel_sbp_collector = 704 [default = false];\n  optional bool enable_auto_parallel_ignore_user_sbp_config = 705 [default = false];\n  optional AutoMemoryStrategy enable_auto_memory = 706 [default = kAdaptiveAutoMemory];\n  \n  optional StraightenAlgorithmTag 
straighten_algorithm_tag_in_task_graph = 800 [default = kCompressMemory];\n  optional bool enable_compress_memory = 801 [default = false];\n\n  optional int64 concurrency_width = 1000 [default = 128];\n\n  map<string, AttrValue> flag_name2flag_value = 2000;\n\n  optional int64 logical_object_id = 3000;\n\n  optional JobSignatureDef signature = 4000;\n}\n"
  },
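  {
    "path": "oneflow/core/job/job_conf_demo.cpp",
    "content": "/*\nNOTE: editorial sketch, not part of the original source tree. It populates a\nminimal training JobConfigProto using only fields defined in job_conf.proto: one\noptimizer with a base learning rate and a momentum conf, plus one loss lbn. The\nfile name, MakeDemoTrainConf and the lbn strings are hypothetical.\n*/\n#include \"oneflow/core/job/job_conf.pb.h\"\n\nnamespace oneflow {\n\nJobConfigProto MakeDemoTrainConf() {\n  JobConfigProto job_conf;\n  job_conf.set_job_name(\"demo_train_job\");\n  TrainConf* train_conf = job_conf.mutable_train_conf();\n  train_conf->add_loss_lbn(\"loss_op/out_0\");\n  OptimizerConf* optimizer_conf = train_conf->add_optimizer_conf();\n  optimizer_conf->add_variable_op_names(\"weight_variable\");\n  optimizer_conf->set_base_learning_rate(0.1f);\n  // Setting momentum_conf selects that arm of the normal_mdupdt oneof.\n  optimizer_conf->mutable_momentum_conf()->set_beta(0.9f);\n  return job_conf;\n}\n\n}  // namespace oneflow\n"
  },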
  {
    "path": "oneflow/core/job/job_desc.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/job_set.pb.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/persistence/hadoop/hadoop_file_system.h\"\n#include \"oneflow/core/graph/graph.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/job/job_builder.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nvoid CheckFunctionConfig(const JobConfigProto& job_conf) {\n  const auto& attr_name2attr_def = GlobalFunctionConfigDef().attr_name2attr_def();\n  for (const auto& pair : job_conf.flag_name2flag_value()) {\n    const auto& iter = attr_name2attr_def.find(pair.first);\n    CHECK(iter != attr_name2attr_def.end());\n    CHECK_EQ(iter->second.default_val().value_case(), pair.second.value_case());\n  }\n}\n\n}  // namespace\n\nJobDesc::JobDesc(const JobConfigProto& job_conf, int64_t job_id)\n    : job_conf_(job_conf), job_id_(job_id), symbol_id_(NullOpt) {\n  CHECK_JUST(Init());\n  Singleton<ResourceDesc, ForSession>::Get()->DumpCudnnConf(job_conf);\n}\n\nMaybe<JobDesc> JobDesc::New(int64_t symbol_id, const JobConfigProto& job_conf) {\n  auto job_desc = std::make_shared<JobDesc>(job_conf);\n  job_desc->symbol_id_ = symbol_id;\n  return job_desc;\n}\n\nMaybe<void> JobDesc::Init() {\n  CheckFunctionConfig(job_conf_);\n  return Maybe<void>::Ok();\n}\n\nconst AttrValue& JobDesc::GetFunctionFlagVal(const std::string& field_name) const {\n  const auto& iter = job_conf_.flag_name2flag_value().find(field_name);\n  if (iter != job_conf_.flag_name2flag_value().end()) { return iter->second; }\n  const auto& attr_name2attr_def = GlobalFunctionConfigDef().attr_name2attr_def();\n  const auto& def_iter = attr_name2attr_def.find(field_name);\n  CHECK(def_iter != attr_name2attr_def.end());\n  return def_iter->second.default_val();\n}\n\nbool IsInterfaceOpConf(const OperatorConf& op_conf) {\n  return IsClassRegistered<int32_t, IsInterfaceOpConf4OpTypeCase>(op_conf.op_type_case());\n}\n\nGlobalJobDescScope::GlobalJobDescScope(const JobConfigProto& job_conf, int64_t job_id) {\n  if (Singleton<JobDesc>::Get() != nullptr) { Singleton<JobDesc>::Delete(); }\n  Singleton<JobDesc>::New(job_conf, job_id);\n}\n\nGlobalJobDescScope::~GlobalJobDescScope() { Singleton<JobDesc>::Delete(); }\n\nconst JobDesc& GlobalJobDesc() { return *Singleton<JobDesc>::Get(); }\n\n}  // namespace oneflow\n"
  },
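  {
    "path": "oneflow/core/job/job_desc_scope_demo.cpp",
    "content": "/*\nNOTE: editorial sketch, not part of the original source tree. It shows the RAII\npattern implemented in job_desc.cpp: constructing GlobalJobDescScope installs a\nSingleton<JobDesc> (replacing any existing one) and the destructor deletes it, so\nGlobalJobDesc() is only valid inside the scope. The file name and DemoWithJobDesc\nare hypothetical; a ForSession ResourceDesc singleton is assumed to exist because\nthe JobDesc constructor dumps its cudnn conf there.\n*/\n#include \"oneflow/core/job/job_desc.h\"\n\nnamespace oneflow {\n\nvoid DemoWithJobDesc(const JobConfigProto& job_conf) {\n  GlobalJobDescScope scope(job_conf, /*job_id=*/0);\n  const JobDesc& job_desc = GlobalJobDesc();\n  // Plain config fields come through typed accessors.\n  (void)job_desc.EnableCudnn();\n  (void)job_desc.IsTrain();\n}  // ~GlobalJobDescScope deletes the Singleton<JobDesc> here\n\n}  // namespace oneflow\n"
  },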
  {
    "path": "oneflow/core/job/job_desc.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_JOB_DESC_H_\n#define ONEFLOW_CORE_JOB_JOB_DESC_H_\n\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/job/dlnet_conf.pb.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/framework/user_op_attr.pb.h\"\n#include \"oneflow/core/job/placement.pb.h\"\n#include \"oneflow/core/job/inter_user_job_info.pb.h\"\n#include \"oneflow/core/register/logical_blob_id.pb.h\"\n#include \"oneflow/core/framework/config_def.h\"\n\nnamespace oneflow {\n\nbool IsInterfaceOpConf(const OperatorConf& op_conf);\n\nclass JobDesc final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(JobDesc);\n  JobDesc(const JobConfigProto& job_conf, int64_t job_id);\n  explicit JobDesc(const JobConfigProto& job_conf) : JobDesc(job_conf, -1) {}\n  ~JobDesc() = default;\n\n  static Maybe<JobDesc> New(int64_t symbol_id, const JobConfigProto& job_conf);\n  const Optional<int64_t>& symbol_id() const { return symbol_id_; }\n\n  // Common\n  int64_t job_id() const { return job_id_; }\n  const std::string& job_name() const { return job_conf_.job_name(); }\n  int64_t concurrency_width() const { return job_conf_.concurrency_width(); }\n  const JobConfigProto& job_conf() const { return job_conf_; }\n  const JobConfigProto& data() const { return job_conf_; }\n  DataType DefaultDataType() const { return job_conf_.default_data_type(); }\n  bool EnableCudnn() const { return job_conf_.enable_cudnn(); }\n  bool IsTrain() const { return job_conf_.has_train_conf(); }\n  bool IsPredict() const { return job_conf_.has_predict_conf(); }\n  bool enable_reuse_mem() const { return job_conf_.enable_reuse_mem(); }\n  bool enable_inplace() const { return job_conf_.enable_inplace(); }\n  bool enable_auto_mixed_precision() const { return job_conf_.enable_auto_mixed_precision(); }\n  bool enable_multi_tensor_update() const { return job_conf_.enable_multi_tensor_update(); }\n  bool enable_fused_model_update_cast() const { return job_conf_.enable_fused_model_update_cast(); }\n  DataType mixed_precision_data_type() const { return job_conf_.mixed_precision_data_type(); }\n  bool do_parallel_cast_before_widening_type_cast() const {\n    return job_conf_.do_parallel_cast_before_widening_type_cast();\n  };\n  bool prune_parallel_cast_ops() const { return job_conf_.prune_parallel_cast_ops(); }\n  bool prune_cast_to_static_shape_ops() const { return job_conf_.prune_cast_to_static_shape_ops(); }\n  bool prune_amp_white_identity_ops() const { return job_conf_.prune_amp_white_identity_ops(); }\n  bool prune_depend_ops() const { return job_conf_.prune_depend_ops(); }\n  bool enable_auto_parallel() const { return job_conf_.enable_auto_parallel(); }\n  int64_t cudnn_buf_limit_mbyte() const { return job_conf_.cudnn_buf_limit_mbyte(); }\n\n#define DEFINE_FUNCTION_CONFIG_GETTER(T, func_name, field_name) \\\n  T func_name(const std::string& field_name) const {            
\\\n    const AttrValue& attr_val = GetFunctionFlagVal(field_name); \\\n    CHECK(attr_val.has_##field_name());                         \\\n    return attr_val.field_name();                               \\\n  }\n  DEFINE_FUNCTION_CONFIG_GETTER(bool, Bool, at_bool);\n  DEFINE_FUNCTION_CONFIG_GETTER(int64_t, Int64, at_int64);\n  DEFINE_FUNCTION_CONFIG_GETTER(double, Double, at_double);\n  DEFINE_FUNCTION_CONFIG_GETTER(const std::string&, String, at_string);\n\n private:\n  Maybe<void> Init();\n  const AttrValue& GetFunctionFlagVal(const std::string& field_name) const;\n\n  JobConfigProto job_conf_;\n  int64_t job_id_;\n  Optional<int64_t> symbol_id_;\n};\n\ntypedef HashMap<std::string, int64_t> JobName2JobId;\n\nclass GlobalJobDescScope final {\n public:\n  GlobalJobDescScope(const JobConfigProto& job_conf, int64_t job_id);\n  ~GlobalJobDescScope();\n};\nconst JobDesc& GlobalJobDesc();\n\nbool IsPullJob(const std::string& job_name, const InterUserJobInfo& inter_user_job_info);\nbool IsPushJob(const std::string& job_name, const InterUserJobInfo& inter_user_job_info);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_JOB_DESC_H_\n"
  },
  {
    "path": "oneflow/core/job/job_instance.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_JOB_INSTANCE_H_\n#define ONEFLOW_CORE_JOB_JOB_INSTANCE_H_\n\n#include <string>\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nclass JobInstance {\n public:\n  JobInstance() = default;\n\n  virtual ~JobInstance() = default;\n\n  virtual std::string job_name() const { UNIMPLEMENTED(); }\n  virtual void Finish() const { UNIMPLEMENTED(); }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_JOB_INSTANCE_H_\n"
  },
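  {
    "path": "oneflow/core/job/job_instance_demo.cpp",
    "content": "/*\nNOTE: editorial sketch, not part of the original source tree. JobInstance in\njob_instance.h is a virtual interface whose default implementations abort via\nUNIMPLEMENTED(), so callers subclass it as sketched here. DemoJobInstance is\nhypothetical.\n*/\n#include <string>\n#include <utility>\n\n#include \"oneflow/core/job/job_instance.h\"\n\nnamespace oneflow {\n\nclass DemoJobInstance final : public JobInstance {\n public:\n  explicit DemoJobInstance(std::string job_name) : job_name_(std::move(job_name)) {}\n\n  std::string job_name() const override { return job_name_; }\n  void Finish() const override { /* e.g. notify a waiting caller */ }\n\n private:\n  std::string job_name_;\n};\n\n}  // namespace oneflow\n"
  },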
  {
    "path": "oneflow/core/job/job_interpreter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/nn_graph.h\"\n#include \"oneflow/core/framework/op_builder.h\"\n#include \"oneflow/core/framework/op_interpreter.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"oneflow/core/framework/local_tensor_infer_cache.h\"\n#include \"oneflow/core/framework/global_tensor_infer_cache.h\"\n#include \"oneflow/core/boxing/eager_boxing_interpreter_mgr.h\"\n#include \"oneflow/core/framework/tensor_global_id.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/boxing/eager_boxing_logger.h\"\n\nnamespace oneflow {\nnamespace one {\n\nusing Env = std::map<std::string, std::shared_ptr<Tensor>>;\nusing NameToParallelDescMap = std::map<std::string, Symbol<ParallelDesc>>;\n\nMaybe<Env> InitEnv(const one::TensorTuple& graph_inputs, const std::shared_ptr<NNGraph>& graph) {\n  Env env;\n  for (const auto& [name, tensor] : graph->variable_op_name2tensor()) {\n    env.emplace(name + \"/out\", tensor);\n  }\n  for (size_t i = 0; i < graph->inputs_op_names().size(); ++i) {\n    const auto& name = graph->inputs_op_names()[i];\n    env.emplace(name + \"/out\", JUST(VectorAt(graph_inputs, i)));\n  }\n  return env;\n}\n\nMaybe<UserOpExpr> OpConfToUserOpExpr(const OperatorConf& op_conf) {\n  CHECK_OR_RETURN(op_conf.has_user_conf());\n  const auto& user_conf = op_conf.user_conf();\n  auto builder = OpBuilder(user_conf.op_type_name());\n  for (const auto& pair : user_conf.attr()) { builder.Attr(pair.first, pair.second); }\n  for (const auto& pair : user_conf.input()) {\n    // ignore \"UserSourceOpTickInput\"\n    if (pair.first == \"UserSourceOpTickInput\") { continue; }\n    builder.Input(pair.first, pair.second.s_size());\n  }\n  for (const auto& pair : user_conf.output()) { builder.Output(pair.first, pair.second.s_size()); }\n  return JUST(builder.Build());\n}\n\ntemplate<typename Func>\nMaybe<std::pair<TensorTuple, OpArgsVector<std::string>>> GetInputTensors(\n    const UserOpConf& user_conf, const Env& env, const Func& preprocess) {\n  TensorTuple inputs;\n  OpArgsVector<std::string> ibns;\n  for (const auto& [ibn, ibs] : user_conf.input()) {\n    if (ibn == \"UserSourceOpTickInput\") { continue; }\n    const auto& tensor_names = ibs.s();\n    for (int i = 0; i < tensor_names.size(); ++i) {\n      inputs.emplace_back(preprocess(JUST(MapAt(env, tensor_names[i]))));\n      ibns.emplace_back(ibn + '_' + std::to_string(i));\n    }\n  }\n  return std::make_pair(inputs, ibns);\n}\n\nOpArgsVector<std::string> GetOutputNamesOfOp(const UserOpConf& user_conf) {\n  OpArgsVector<std::string> output_names;\n  for (const auto& pair : user_conf.output()) {\n    for (const auto& name : pair.second.s()) { output_names.emplace_back(name); }\n  }\n  return output_names;\n}\n\n// Only support a limited subset 
of view ops for now\nbool IsViewOp(const std::shared_ptr<UserOpExpr>& op) {\n  return op->op_type_name() == \"reshape\" || op->op_type_name() == \"expand_dims\";\n}\n\nMaybe<void> RunViewOp(const std::shared_ptr<UserOpExpr>& op, Env& env, const TensorTuple& inputs,\n                      const OpArgsVector<std::string>& output_names) {\n  // eliminate the memcpy of view ops\n  CHECK_OR_RETURN(IsViewOp(op));\n  const std::shared_ptr<const LocalTensorInferResult> result =\n      JUST([&]() -> Maybe<const LocalTensorInferResult> {\n        LocalTensorMetaInferArgs infer_args;\n        JUST(infer_args.Init(op->base_attrs(), JUST(inputs[0]->device()), inputs));\n        return JUST(op->mut_local_tensor_infer_cache()->GetOrInfer(infer_args));\n      }());\n  const auto& output_shape = result->output_tensor_metas()[0]->shape();\n  const auto output =\n      JUST(view::BasicView(inputs[0], output_shape, JUST(inputs[0]->storage_offset())));\n  env.emplace(output_names[0], output);\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<void> RawRunGlobalNormalOp(const std::shared_ptr<UserOpExpr>& op, TensorTuple& inputs,\n                                 TensorTuple* outputs, Env& env,\n                                 const OpArgsVector<std::string>& ibns,\n                                 const OpArgsVector<std::string>& output_names,\n                                 const NdSbpSignature& ndsbp_signature,\n                                 const Symbol<ParallelDesc>& op_parallel_desc) {\n  Optional<int64_t> parallel_id;\n  const auto& tensor_device =\n      JUST(GetTensorDevice4CurrentProcessCtx(op_parallel_desc, &parallel_id));\n  const auto* mgr = Singleton<EagerBoxingInterpreterManager>::Get();\n  CHECK_OR_RETURN(inputs.size() == ibns.size()) << \"inputs size != ibns size\";\n  for (int i = 0; i < inputs.size(); ++i) {\n    std::shared_ptr<Tensor> input_tensor = inputs[i];\n    std::string lbn = JUST(VectorAt(ibns, i));\n    const auto& logical_shape = input_tensor->shape();\n    CHECK_OR_RETURN(logical_shape->elem_cnt() > 0) << \"tensor logical shape has no elements\";\n    const auto& in_nd_sbp = JUST(input_tensor->nd_sbp());\n    const auto& out_nd_sbp = SymbolOf(JUST(MapAt(ndsbp_signature.bn_in_op2nd_sbp(), lbn)));\n    const auto& in_parallel_desc = JUST(input_tensor->parallel_desc());\n    const auto& out_parallel_desc = op_parallel_desc;\n    CHECK_OR_RETURN(in_parallel_desc == out_parallel_desc) << \"input placement != output placement\";\n    if (in_parallel_desc->parallel_num() != 1 && in_nd_sbp != out_nd_sbp) {\n      const auto& boxing_interpreter = JUST(mgr->GetEagerBoxingInterpreter(\n          in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc, *logical_shape));\n      Singleton<const EagerBoxingLogger>::Get()->Log(\n          *JUST(boxing_interpreter->boxing_interpreter_status()), /* prefix */ \"\");\n      if (parallel_id.has_value()) {\n        inputs.at(i) = JUST(boxing_interpreter->Interpret(input_tensor, in_nd_sbp, out_nd_sbp,\n                                                          in_parallel_desc, out_parallel_desc));\n      }\n    }\n  }\n  static EagerGlobalInterpreter it;\n  // The interp context depends on this op's placement and output nd_sbp, so it is\n  // rebuilt on every call rather than cached in a function-local static.\n  OpExprInterpContext ctx =\n      OpExprInterpContext(AttrMap{}, op_parallel_desc,\n                          SymbolOf(JUST(MapAt(ndsbp_signature.bn_in_op2nd_sbp(), \"out_0\"))));\n  JUST(it.Apply(*op, inputs, outputs, ctx));\n  for (size_t i = 0; i < output_names.size(); ++i) {\n    env.emplace(output_names[i], JUST(VectorAt(*outputs, i)));\n  }\n  return Maybe<void>::Ok();\n}\n\nauto* 
RunGlobalNormalOpThenInitGlobalId = DECORATE(&RawRunGlobalNormalOp, NonRecursiveInitGlobalId);\n\n}  // namespace\n\nMaybe<void> RunGlobalNormalOp(const std::shared_ptr<UserOpExpr>& op, TensorTuple& inputs, Env& env,\n                              const OpArgsVector<std::string>& ibns,\n                              const OpArgsVector<std::string>& output_names,\n                              const NdSbpSignature& ndsbp_signature,\n                              const Symbol<ParallelDesc>& op_parallel_desc) {\n  TensorTuple outputs(output_names.size());\n  return RunGlobalNormalOpThenInitGlobalId(op, inputs, &outputs, env, ibns, output_names,\n                                           ndsbp_signature, op_parallel_desc);\n}\n\nMaybe<void> RunNormalOp(const std::shared_ptr<UserOpExpr>& op, Env& env, const TensorTuple& inputs,\n                        const OpArgsVector<std::string>& output_names) {\n  TensorTuple outputs(output_names.size());\n  static EagerLocalInterpreter it;\n  static AttrMap empty_attr_map;\n  JUST(it.Apply(*op, inputs, &outputs, empty_attr_map));\n  for (size_t i = 0; i < output_names.size(); ++i) {\n    env.emplace(output_names[i], JUST(VectorAt(outputs, i)));\n  }\n  return Maybe<void>::Ok();\n}\n\n// tensors in outdated_tensors_after_op[i] will not be accessed any more after i-th op\n// so they can be released once i-th op's execution finishes.\nstd::vector<std::vector<std::string>> GetOutdatedTensorsAfterOp(const Job& job) {\n  std::vector<std::vector<std::string>> outdated_tensors_after_op(job.net().op_size());\n  std::set<std::string> visited;\n  for (int i = job.net().op_size() - 1; i >= 0; --i) {\n    const auto& op_conf = job.net().op(i);\n    // do not release the graph output tensors\n    if (op_conf.has_output_conf()) {\n      const auto& output_conf = op_conf.output_conf();\n      visited.insert(output_conf.in());\n    } else if (op_conf.has_user_conf()) {\n      const auto& user_conf = op_conf.user_conf();\n      for (const auto& pair : user_conf.input()) {\n        if (pair.first == \"UserSourceOpTickInput\") { continue; }\n        for (const auto& name : pair.second.s()) {\n          if (visited.find(name) == visited.end()) {\n            outdated_tensors_after_op[i].push_back(name);\n            visited.insert(name);\n          }\n        }\n      }\n    }\n  }\n  return outdated_tensors_after_op;\n}\n\nMaybe<void> InitOpExprs(const std::shared_ptr<NNGraph>& graph) {\n  CHECK_OR_RETURN(graph->cached_op_exprs.empty());\n\n  const auto& job = graph->job();\n  for (int i = 0; i < job.net().op_size(); i++) {\n    const auto& op_conf = job.net().op(i);\n    if (op_conf.has_user_conf()) {\n      const auto op_expr = JUST(OpConfToUserOpExpr(op_conf));\n      graph->cached_op_exprs.push_back(op_expr);\n    } else {\n      graph->cached_op_exprs.push_back(nullptr);\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<one::TensorTuple> InterpretJob(const one::TensorTuple& graph_inputs,\n                                     const std::shared_ptr<NNGraph>& graph) {\n  if (graph->cached_op_exprs.empty()) { JUST(InitOpExprs(graph)); }\n\n  const auto& job = graph->job();\n  auto env = *JUST(InitEnv(graph_inputs, graph));\n\n  // See comments above GetOutdatedTensorsAfterOp's definition for more details\n  const auto outdated_tensors_after_op = GetOutdatedTensorsAfterOp(job);\n\n  CHECK_OR_RETURN(job.has_placement()) << \"no job placement\";\n  const auto& job_placement = job.placement();\n  NameToParallelDescMap op2paralleldesc;\n  for (const auto& 
blob_placement_group : job_placement.blob_placement_group()) {\n    const auto parallel_desc = SymbolOf(ParallelDesc(blob_placement_group.parallel_conf()));\n    for (const auto& logical_blob_id : blob_placement_group.lbi()) {\n      op2paralleldesc.emplace(logical_blob_id.op_name(), parallel_desc);\n    }\n  }\n  CHECK_OR_RETURN(job.has_job_parallel_view_conf()) << \"no job parallel conf\";\n  const auto& op_name2nd_sbp_signature_conf =\n      job.job_parallel_view_conf().op_name2nd_sbp_signature_conf();\n\n  one::TensorTuple graph_outputs;\n  for (int i = 0; i < job.net().op_size(); i++) {\n    const auto& op_conf = job.net().op(i);\n    if (op_conf.has_user_conf()) {\n      auto op = CHECK_NOTNULL(graph->cached_op_exprs[i]);\n      const auto& user_conf = op_conf.user_conf();\n      OF_PROFILER_RANGE_GUARD(user_conf.op_type_name());\n      auto [inputs, ibns] =\n          *JUST(GetInputTensors(user_conf, env, [&op_conf](const std::shared_ptr<Tensor>& tensor) {\n            return CHECK_JUST(functional::To(tensor, op_conf.device_tag()));\n          }));\n      OpArgsVector<std::string> output_names = GetOutputNamesOfOp(user_conf);\n      if (!inputs.empty()\n          && inputs[0]->is_local()) {  // All tensors maintain the same properties of is_local\n        if (IsViewOp(op)) {\n          JUST(RunViewOp(op, env, inputs, output_names));\n        } else {\n          JUST(RunNormalOp(op, env, inputs, output_names));\n        }\n      } else {\n        const auto& op_parallel_desc = JUST(MapAt(op2paralleldesc, op_conf.name()));\n        const auto& nd_sbp_signature_conf =\n            JUST(MapAt(op_name2nd_sbp_signature_conf, op_conf.name()));\n        JUST(RunGlobalNormalOp(op, inputs, env, ibns, output_names, nd_sbp_signature_conf,\n                               op_parallel_desc));\n      }\n      for (const auto& name : outdated_tensors_after_op[i]) {\n        CHECK_EQ_OR_RETURN(env.erase(name), 1);\n      }\n    } else if (op_conf.has_learning_rate_schedule_conf()) {\n      // FIXME(daquexian):\n      // It is a temporary hack to support learning_rate_schedule op.\n      // Only the naive sgd without any lr decay is supported.\n      const auto& lr_conf = op_conf.learning_rate_schedule_conf();\n      env.emplace(\n          op_conf.name() + \"/\" + lr_conf.out(),\n          JUST(functional::Constant({1}, lr_conf.learning_rate(), DType::Float(), NullOpt)));\n    } else if (op_conf.has_identity_conf()) {\n      const auto& identity_conf = op_conf.identity_conf();\n      const auto& in = identity_conf.in();\n      const auto& out = op_conf.name() + \"/\" + identity_conf.out();\n      env.emplace(out, JUST(functional::Identity(JUST(MapAt(env, in)))));\n    } else if (op_conf.has_output_conf()) {\n      const auto& output_conf = op_conf.output_conf();\n      graph_outputs.emplace_back(JUST(MapAt(env, output_conf.in())));\n    }\n  }\n  return graph_outputs;\n}\n}  // namespace one\n}  // namespace oneflow\n"
  },
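The reverse scan in `GetOutdatedTensorsAfterOp` above is a classic last-use analysis: walking ops from back to front, the first (reverse) occurrence of a tensor name is its last use in execution order, so the tensor can be released right after that op runs. A minimal standalone sketch of the idea, with hypothetical data and none of the OneFlow API:

```cpp
#include <iostream>
#include <set>
#include <string>
#include <vector>

int main() {
  // Each op listed with the tensor names it consumes (illustrative data).
  std::vector<std::vector<std::string>> op_inputs = {{"a"}, {"a", "b"}, {"b", "c"}};
  std::vector<std::vector<std::string>> outdated_after_op(op_inputs.size());
  std::set<std::string> visited;
  // Reverse scan: the first time a name is seen is its last use.
  for (int i = static_cast<int>(op_inputs.size()) - 1; i >= 0; --i) {
    for (const auto& name : op_inputs[i]) {
      if (visited.insert(name).second) { outdated_after_op[i].push_back(name); }
    }
  }
  for (size_t i = 0; i < outdated_after_op.size(); ++i) {
    std::cout << "releasable after op " << i << ":";
    for (const auto& n : outdated_after_op[i]) { std::cout << " " << n; }
    std::cout << "\n";  // op 0: (none, a is used again at op 1); op 1: a; op 2: b c
  }
}
```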
  {
    "path": "oneflow/core/job/job_interpreter.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/job/job.pb.h\"\n\nnamespace oneflow {\nclass NNGraph;\nnamespace one {\nclass TensorTuple;\nMaybe<one::TensorTuple> InterpretJob(const one::TensorTuple& inputs,\n                                     const std::shared_ptr<NNGraph>& graph);\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/job_ir.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/job_ir.h\"\n\nnamespace oneflow {\n\n#ifndef WITH_MLIR\n\nMaybe<std::string> ConvertJobToTosaIR(Job* job) {\n  UNIMPLEMENTED_THEN_RETURN() << \"ConvertJobToTosaIR is only supported WITH_MLIR\";\n}\n\nMaybe<void> SaveJobToIR(Job* job, const std::string& path) {\n  UNIMPLEMENTED_THEN_RETURN() << \"SaveJobToIR is only supported WITH_MLIR\";\n}\n\nMaybe<std::string> ConvertJobToIR(Job* job) {\n  UNIMPLEMENTED_THEN_RETURN() << \"ConvertJobToIR is only supported WITH_MLIR\";\n}\n\nMaybe<void> LoadJobFromIR(Job* job, const std::string& path) {\n  UNIMPLEMENTED_THEN_RETURN() << \"LoadJobFromIR is only supported WITH_MLIR\";\n}\n\n#endif\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/job_ir.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_JOB_IR_H_\n#define ONEFLOW_CORE_JOB_JOB_IR_H_\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/job/job.pb.h\"\n\nnamespace oneflow {\n\nMaybe<std::string> ConvertJobToTosaIR(Job* job);\nMaybe<std::string> ConvertJobToIR(Job* job);\nMaybe<void> SaveJobToIR(Job* job, const std::string& path);\nMaybe<void> LoadJobFromIR(Job* job, const std::string& path);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_JOB_IR_H_\n"
  },
  {
    "path": "oneflow/core/job/job_set.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/job/job.proto\";\nimport \"oneflow/core/job/resource.proto\";\n\nmessage ReuseMemPriorityStrategy {\n}\n\nmessage ParallelismPriorityStrategy {\n}\n\nmessage JobNameGroup {\n  repeated string job_name = 1;\n}\n\nmessage CustomParallelismStrategy {\n  repeated JobNameGroup nonparallel_group = 1;\n}\n\nmessage InterJobReuseMemStrategy {\n  oneof strategy_case {\n    ReuseMemPriorityStrategy reuse_mem_priority = 1;\n    ParallelismPriorityStrategy parallelism_priority = 2;\n    CustomParallelismStrategy custom_parallelism = 3;\n  }\n}\n\nmessage ConfigProto {\n  required Resource resource = 1;\n  required int64 session_id = 5;\n}\n\nmessage JobSet {\n  repeated Job job = 1;\n  optional InterJobReuseMemStrategy inter_job_reuse_mem_strategy = 5;\n}\n"
  },
  {
    "path": "oneflow/core/job/job_set_compile_ctx.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_SET_COMPILE_CTX_\n#define ONEFLOW_CORE_JOB_SET_COMPILE_CTX_\n\n#include \"oneflow/core/job/compiler.h\"\n#include \"oneflow/core/job/job_set_compile_ctx.pb.h\"\n\nnamespace oneflow {\n\nclass JobSetCompileCtx final {\n public:\n  JobSetCompileCtx() = default;\n  ~JobSetCompileCtx() = default;\n\n  PbMap<std::string, int64_t>* GetVarOpName2randomSeed() {\n    return job_set_compile_ctx_proto_.mutable_var_op_name2random_seed();\n  }\n\n private:\n  JobSetCompileCtxProto job_set_compile_ctx_proto_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_SET_COMPILE_CTX_\n"
  },
  {
    "path": "oneflow/core/job/job_set_compile_ctx.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage JobSetCompileCtxProto {\n  map<string, int64> var_op_name2random_seed = 1;\n}\n"
  },
  {
    "path": "oneflow/core/job/lazy_mode.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/lazy_mode.h\"\n\nnamespace oneflow {\n\n/* static */ bool* LazyMode::get_mode_ptr() {\n  static thread_local bool mode = false;\n  return &mode;\n}\n\n/* static */ bool LazyMode::is_enabled() { return *get_mode_ptr(); }\n\n/* static */ void LazyMode::set_enabled(bool enabled) { *get_mode_ptr() = enabled; }\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/lazy_mode.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_LAZY_MODE_H_\n#define ONEFLOW_CORE_JOB_LAZY_MODE_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nclass LazyMode {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LazyMode);\n  LazyMode() = delete;\n  ~LazyMode() = delete;\n\n  static bool is_enabled();\n  class Guard {\n   public:\n    explicit Guard(bool enabled) : prev_mode_(LazyMode::is_enabled()) {\n      LazyMode::set_enabled(enabled);\n    }\n    ~Guard() { LazyMode::set_enabled(prev_mode_); }\n\n   private:\n    bool prev_mode_;\n  };\n\n private:\n  static bool* get_mode_ptr();\n  static void set_enabled(bool enabled);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_LAZY_MODE_H_\n"
  },
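`LazyMode::Guard` above is the usual RAII scoped-flag pattern over a thread-local: the constructor records the current mode and sets the new one, and the destructor restores the saved value, so nested guards and early returns unwind correctly. A small usage sketch (the function name is illustrative, not part of the codebase):

```cpp
#include "oneflow/core/job/lazy_mode.h"

void BuildGraphLazily() {
  using oneflow::LazyMode;
  // The thread-local mode defaults to false (eager).
  {
    LazyMode::Guard lazy_guard(/*enabled=*/true);
    // Everything here observes LazyMode::is_enabled() == true (lazy tracing).
    {
      LazyMode::Guard eager_guard(/*enabled=*/false);
      // Temporarily back to eager inside the lazy region.
    }
    // eager_guard destroyed: lazy again.
  }
  // lazy_guard destroyed: the previous (eager) mode is restored.
}
```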
  {
    "path": "oneflow/core/job/learning_rate_schedule_conf.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage ExponentialDecayConf {\n  required int64 decay_batches = 1;\n  required double decay_rate = 2;\n  optional bool staircase = 3 [default = false];\n}\n\nmessage InverseTimeDecayConf {\n  required int64 decay_batches = 1;\n  required double decay_rate = 2;\n  optional bool staircase = 3 [default = false];\n}\n\nmessage NaturalExpDecayConf {\n  required int64 decay_batches = 1;\n  required double decay_rate = 2;\n  optional bool staircase = 3 [default = false];\n}\n\nmessage PiecewiseConstantConf {\n  repeated int64 boundaries = 1;\n  repeated double values = 2;\n}\n\nmessage PolynomialDecayConf {\n  required int64 decay_batches = 1;\n  optional double end_learning_rate = 2 [default = 0.0001];\n  optional double power = 3 [default = 1.0];\n  optional bool cycle = 4 [default = false];\n}\n\nmessage CosineDecayConf {\n  required int64 decay_batches = 1;\n  optional double alpha = 2 [default = 0.0];\n}\n\nmessage CosineAnnealingDecayConf {\n  required int64 t_max = 1;\n  optional double eta_min = 2 [default = 0.0];\n}\n\nmessage LinearCosineDecayConf {\n  required int64 decay_batches = 1;\n  optional double num_periods = 2 [default = 0.5];\n  optional double alpha = 3 [default = 0.0];\n  optional double beta = 4 [default = 0.001];\n}\n\nmessage PiecewiseScalingConf {\n  repeated int64 boundaries = 1;\n  repeated double scales = 2;\n}\n\nmessage StepConf {\n  required int64 step_size = 1;\n  optional double gamma = 2 [default = 0.1];\n}\n\nmessage MultiStepConf {\n  repeated int64 milestones = 1;\n  optional double gamma = 2 [default = 0.1];\n}\n\nmessage LinearLRConf {\n  required double start_factor = 1;\n  required double end_factor = 2;\n  required int64 total_iters = 3;\n}\n\nmessage ConstantLRConf {\n  required double factor = 1;\n  required int64 total_iters = 2;\n}\n\nmessage CosineAnnealingWarmRestartsConf {\n  required int64 t_initial = 1;\n  required int64 t_mult = 2;\n  required double eta_min = 3;\n  required double decay_rate = 4;\n  required int64 restart_limit = 5;\n}\n\nmessage SequentialSchedulerConf {\n  repeated LearningRateDecayConf schedulers = 1;\n  repeated int64 milestones = 2;\n  // NOTE(zwx): should be repeated bool, however it has bug in cfg\n  repeated int32 interval_rescaling = 3;\n}\n\n// TODO(zwx): ChainedSchedulerConf \n\nmessage LearningRateDecayConf {\n  oneof type {\n    ExponentialDecayConf exponential_conf = 2000;\n    InverseTimeDecayConf inverse_time_conf = 2001;\n    NaturalExpDecayConf natural_exp_conf = 2002;\n    PiecewiseConstantConf piecewise_constant_conf = 2003;\n    PolynomialDecayConf polynomial_conf = 2004;\n    CosineDecayConf cosine_conf = 2005;\n    LinearCosineDecayConf linear_cosine_conf = 2006;\n    PiecewiseScalingConf piecewise_scaling_conf = 2007;\n    MultiStepConf multi_step_conf = 2008;\n    StepConf step_conf = 2009;\n    CosineAnnealingDecayConf cosine_annealing_conf = 2010;\n    LinearLRConf linear_lr_conf = 2011;\n    ConstantLRConf constant_lr_conf = 2012;\n    CosineAnnealingWarmRestartsConf cosine_annealing_warm_restarts_conf = 2013;\n    SequentialSchedulerConf sequential_scheduler_conf = 2014;\n  }\n}\n"
  },
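The proto above only declares schedule parameters; the decay formulas live in the scheduler implementation. As a hedged illustration of what `ExponentialDecayConf` conventionally encodes (the common TensorFlow-style semantics, which may differ in detail from OneFlow's scheduler): lr = base_lr * decay_rate^(step / decay_batches), with `staircase` flooring the exponent so the rate drops in discrete steps.

```cpp
#include <cmath>
#include <cstdint>
#include <iostream>

// Assumed semantics for ExponentialDecayConf; not taken from the OneFlow source.
double ExponentialDecay(double base_lr, int64_t step, int64_t decay_batches,
                        double decay_rate, bool staircase) {
  double p = static_cast<double>(step) / static_cast<double>(decay_batches);
  if (staircase) { p = std::floor(p); }  // discrete drops instead of a smooth curve
  return base_lr * std::pow(decay_rate, p);
}

int main() {
  // base_lr = 0.1, halved every 100 batches, queried at step 250.
  std::cout << ExponentialDecay(0.1, 250, 100, 0.5, /*staircase=*/true) << "\n";   // 0.025
  std::cout << ExponentialDecay(0.1, 250, 100, 0.5, /*staircase=*/false) << "\n";  // ~0.0177
}
```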
  {
    "path": "oneflow/core/job/local_parallel.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage LocalParallel {\n}\n\nmessage OptLocalParallel {\n  optional LocalParallel local_parallel = 1;\n}\n\nmessage LocalSignature {\n  map<string, OptLocalParallel> bn_in_op2opt_local_parallel = 1;\n}\n"
  },
  {
    "path": "oneflow/core/job/local_sig_infer_hint.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_MIRRORED_SIG_INFER_HINT_H_\n#define ONEFLOW_CORE_JOB_MIRRORED_SIG_INFER_HINT_H_\n\n#include \"oneflow/core/job/parallel_desc.h\"\n\nnamespace oneflow {\n\nclass LocalSigInferHint final {\n public:\n  LocalSigInferHint(const ParallelDesc* parallel_desc, bool is_local_parallel_view)\n      : parallel_desc_(parallel_desc), is_local_parallel_view_(is_local_parallel_view) {}\n\n  const ParallelDesc& parallel_desc() const { return *parallel_desc_; }\n  bool is_local_parallel_view() const { return is_local_parallel_view_; }\n\n private:\n  const ParallelDesc* parallel_desc_;\n  bool is_local_parallel_view_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_MIRRORED_SIG_INFER_HINT_H_\n"
  },
  {
    "path": "oneflow/core/job/memory_share_strategy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/job/memory_share_strategy.h\"\n#include <glog/logging.h>\n#include <algorithm>\n#include \"oneflow/core/common/hash_container.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/register/runtime_register_desc.h\"\n\nnamespace oneflow {\n\nnamespace {\nconstexpr int32_t kMaxIterStep = 100;\n}  // anonymous namespace\n\nbool IsLifetimeExcluded(const std::pair<int32_t, int32_t>& a,\n                        const std::pair<int32_t, int32_t>& b) {\n  return a.first < b.second && b.first < a.second;\n}\n\n// Initialization\nvoid MemoryShareStrategy::InitRegister(\n    const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& register2lifetime) {\n  total_register_num_ = register2lifetime.size();\n  index2register_.resize(total_register_num_);\n  int32_t register_id = 0;\n  for (const auto& pair : register2lifetime) {\n    index2register_[register_id] = pair.first;\n    register_id++;\n  }\n}\n\nvoid MemoryShareStrategy::InitRegisterInformation(\n    const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size) {\n  total_register_num_ = index2register_.size();\n  register_size_.resize(total_register_num_);\n  for (int32_t register_id = 0; register_id < total_register_num_; register_id++) {\n    const auto& register_ = index2register_[register_id];\n    int64_t register_size = mem_reused_regst2size.at(register_);\n    register_size_[register_id] = register_size;\n    register2index_[register_] = register_id;\n  }\n  order_.resize(total_register_num_);\n  for (int32_t i = 0; i < total_register_num_; i++) { order_[i] = i; }\n}\n\n// Steal a compact position as the initial strategy\nvoid MemoryShareStrategy::StealCompactPosition(\n    const HashMap<RegstDescProto*, int64_t>& regst_desc2offset,\n    const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size,\n    const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& register2lifetime) {\n  // Initialization\n  InitRegister(register2lifetime);\n\n  // Sort index2register_\n  std::sort(index2register_.begin(), index2register_.end(),\n            [&](RegstDescProto* i, RegstDescProto* j) {\n              return regst_desc2offset.at(i) < regst_desc2offset.at(j);\n            });\n  // Update other information\n  InitRegisterInformation(mem_reused_regst2size);\n\n  left_registers_.clear();\n  left_registers_.resize(total_register_num_);\n  excluded_registers_.clear();\n  excluded_registers_.resize(total_register_num_);\n  // should_visit_[i] indicates whether we should visit register[i].\n  // should_visit_[i] = 0: should not visit i, or have already visited i..\n  // should_visit_[i] = 1: should visit i, i is excluded with j\n  // should_visit_[i] = 2: should visit i, i is not excluded with j\n  should_visit_.clear();\n  should_visit_.resize(total_register_num_, 0);\n  register_offset_.resize(total_register_num_);\n  // Generate a compact 
relationship of position\n  // For example we have 3 relationship: x1 < x2, x2 < x3, x1 < x3\n  // We would delete the redundant relationship (x1 < x3)\n  for (int32_t j = 0; j < total_register_num_; j++) {\n    const auto& register_j = index2register_[j];\n    register_offset_[j] = regst_desc2offset.at(register_j);\n    auto& excluded_register_j = excluded_registers_[j];\n    const auto& lifetime_j = register2lifetime.at(register_j);\n    // Init should visit with all orders of the excluded register\n    for (int32_t i = j + 1; i < total_register_num_; i++) {\n      if (IsLifetimeExcluded(lifetime_j, register2lifetime.at(index2register_[i]))) {\n        // Copy the data to excluded registers\n        excluded_register_j.insert(i);\n        excluded_registers_[i].insert(j);\n      }\n    }\n  }\n\n  for (int32_t j = 0; j < total_register_num_; j++) { ResetCompactPosition(j); }\n}\n\n// Generate a compact position with the order of occurrence\n// Not recommended\nvoid MemoryShareStrategy::GenerateCompactPosition(\n    const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size,\n    const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& register2lifetime) {\n  HashMap<RegstDescProto*, int64_t> regst_desc2offset;\n  int64_t offset = 0;\n  for (const auto& pair : register2lifetime) {\n    regst_desc2offset[pair.first] = offset;\n    offset++;\n  }\n  StealCompactPosition(regst_desc2offset, mem_reused_regst2size, register2lifetime);\n}\n\n// Compute optimal cost with compact relationship\nsize_t MemoryShareStrategy::ComputeOptimalCost4CompactRelationship() {\n  int64_t mem_block_size = 0;\n  for (int32_t i = 0; i < total_register_num_; i++) {\n    mem_block_size =\n        std::max(mem_block_size, ComputeOffset4CompactRelationship(i) + register_size_[i]);\n  }\n  mem_block_size_ = size_t(mem_block_size);\n  return mem_block_size_;\n}\n\n// Compute offset with compact relationship\nint64_t MemoryShareStrategy::ComputeOffset4CompactRelationship(int32_t i) {\n  if (register_offset_[i] < 0) {\n    // An initial value x would be store as -x - 1.\n    register_offset_[i] = -register_offset_[i] - 1;\n    for (int32_t j : left_registers_[i]) {\n      register_offset_[i] =\n          std::max(register_offset_[i], ComputeOffset4CompactRelationship(j) + register_size_[j]);\n    }\n  }\n  return register_offset_[i];\n}\n\nsize_t MemoryShareStrategy::ComputeOptimalAdjustedCost() {\n  // Initial optimal cost\n  size_t optimal_cost = ComputeOptimalCostFrom0();\n  // All the registers excluded with register i are sorted from left to right\n  // std::vector<int32_t> order_;\n  // auto CompareRegisterPosition = [&](int32_t i, int32_t j) {\n  //   return register_offset_[i] < register_offset_[j];\n  // };\n  backup_registers_.clear();\n  backup_registers_.resize(total_register_num_);\n  // The number of steps that the optimal cost does not decrease\n  int32_t step_no_decrease = 0;\n  for (int32_t m = 0; m < max_iteration_step_; m++) {\n    for (int32_t i = 0; i < total_register_num_; i++) {\n      EliminateRegister(i);\n      size_t cost_without_i = ComputeOptimalCostFrom0();\n      // Find the offset of i which has the minimum cost\n      int64_t min_x_i = -1;\n      if (cost_without_i < optimal_cost) {\n        // Find the minimum cost\n        int64_t min_cost = optimal_cost;\n        // Back up the current register offset with elimination of i\n        auto register_offset_backup = register_offset_;\n        // Try to insert the register i into the sorted excluded registers\n        HashSet<int64_t> 
all_x_i;\n        for (int32_t j : excluded_registers_[i]) {\n          // Insert i before j\n          all_x_i.insert(register_offset_backup[j]);\n          // Insert i after j\n          all_x_i.insert(register_offset_backup[j] + register_size_[j]);\n        }\n\n        for (int64_t x_i : all_x_i) {\n          int64_t cost_insert_i = ComputeOptimalCostWithOccupation(i, x_i, register_offset_backup);\n          // Check if we found a smaller cost\n          if (cost_insert_i < min_cost) {\n            min_cost = cost_insert_i;\n            min_x_i = x_i;\n            if (min_cost <= cost_without_i) { break; }\n          }\n        }\n        // Found a smaller cost\n        if (min_x_i >= 0) {\n          InsertRegister(i, min_x_i, register_offset_backup);\n          optimal_cost = ComputeOptimalCostFrom0();\n        }\n      }\n      // Found a smaller cost\n      if (min_x_i >= 0) {\n        // Move to a new status with smaller cost, dump the backup of the offset.\n        ClearBackup();\n        step_no_decrease = 0;\n      } else {\n        // Recover to the original status\n        RecoverFromBackup(i);\n        // Adjust the offset after recovery\n        ComputeOptimalCostFrom0();\n        // Terminate it if no cost reduce for any of the adjustment.\n        step_no_decrease++;\n        if (step_no_decrease >= total_register_num_) { break; }\n      }\n    }\n    if (step_no_decrease >= total_register_num_) { break; }\n  }\n  CHECK_JUST(CheckConflict());\n  return optimal_cost;\n}\n\n// Let x_i occupy some space [x_i, x_i + l_i), then we recompute the optimal cost\nsize_t MemoryShareStrategy::ComputeOptimalCostWithOccupation(\n    int32_t i, int64_t x_i, const std::vector<int64_t>& register_offset_backup) {\n  // The end of register i.\n  int64_t e_i = x_i + register_size_[i];\n  register_offset_.clear();\n  register_offset_.resize(total_register_num_, -1);\n  for (int32_t k : excluded_registers_[i]) {\n    // x_k + l_k > x_i\n    // k is behind i\n    if (register_offset_backup[k] + register_size_[k] > x_i) {\n      register_offset_[k] = -e_i - 1;\n    } else {\n      register_offset_[k] = register_offset_backup[k];\n    }\n  }\n  register_offset_[i] = x_i;\n  return ComputeOptimalCost4CompactRelationship();\n}\n\n// Eliminate one register\nvoid MemoryShareStrategy::EliminateRegister(int32_t i) {\n  // Init back up registers\n  backup_registers_[i] = left_registers_[i];\n  for (auto j : excluded_registers_[i]) {\n    if (register_offset_[i] < register_offset_[j]) {\n      should_visit_.clear();\n      should_visit_.resize(total_register_num_, 0);\n      // should_visit_[i] = 0: should not visit i, or have already visited i..\n      // should_visit_[i] = 1: should visit i, i is excluded with j\n      // should_visit_[i] = 2: should visit i, i is not excluded with j\n      // should_visit_[i] = -1: i is visited, i is excluded with j\n      // should_visit_[i] = -2: i is visited, i is not excluded with j\n      for (int32_t k = 0; k < total_register_num_; k++) {\n        if (register_offset_[k] < register_offset_[j]) {\n          if (Exclude(k, j)) {\n            should_visit_[k] = 1;\n          } else {\n            should_visit_[k] = 2;\n          }\n        }\n      }\n      // Eliminate all the grandsons of the excluded registers\n      for (int32_t k : excluded_registers_[j]) {\n        if (should_visit_[k] == 1) { EliminateRedundantRelationshipIgnore(i, k); }\n      }\n      for (int32_t k : excluded_registers_[j]) {\n        if (should_visit_[k] == -1) {\n          if 
(left_registers_[j].insert(k).second) { backup_registers_[j].insert(k); }\n        }\n      }\n      if (left_registers_[j].erase(i)) { backup_register_behind_i_.insert(j); }\n    }\n  }\n  left_registers_[i].clear();\n}\n\n// Whether i and j occurs simultaneously\nbool MemoryShareStrategy::Exclude(int32_t i, int32_t j) {\n  return excluded_registers_[i].find(j) != excluded_registers_[i].end();\n}\n\n// If the previous strategy has fewer cost, recover to the previous one from the backup.\nvoid MemoryShareStrategy::RecoverFromBackup(int32_t i) {\n  for (int32_t j = 0; j < total_register_num_; j++) {\n    if (i == j) {\n      left_registers_[i] = backup_registers_[i];\n    } else {\n      for (int32_t k : backup_registers_[j]) { left_registers_[j].erase(k); }\n    }\n  }\n  for (int32_t j : backup_register_behind_i_) { left_registers_[j].insert(i); }\n  ClearBackup();\n}\n\n// Clear backup\nvoid MemoryShareStrategy::ClearBackup() {\n  for (auto& backup_register : backup_registers_) { backup_register.clear(); }\n  backup_register_behind_i_.clear();\n}\n\nsize_t MemoryShareStrategy::ComputeOptimalCostFrom0() {\n  register_offset_.clear();\n  register_offset_.resize(total_register_num_, -1);\n  return ComputeOptimalCost4CompactRelationship();\n}\n\n// Insert register i at position [x_i, x_i + l_i)\nvoid MemoryShareStrategy::InsertRegister(int32_t i, int64_t x_i,\n                                         const std::vector<int64_t>& original_register_offset) {\n  ComputeOptimalCostWithOccupation(i, x_i, original_register_offset);\n  std::sort(order_.begin(), order_.end(),\n            [&](int32_t k, int32_t j) { return register_offset_[k] < register_offset_[j]; });\n  for (int32_t j : order_) {\n    if (register_offset_[i] <= register_offset_[j]) { ResetCompactPosition(j); }\n  }\n}\n\n// Eliminate children of j but ignore i.\nvoid MemoryShareStrategy::EliminateRedundantRelationshipIgnore(int32_t i, int32_t j) {\n  // Ignore i\n  if (i == j) { return; }\n  if (should_visit_[j] > 0) {\n    // Do not look into it again\n    should_visit_[j] = -should_visit_[j];\n    for (int32_t k : left_registers_[j]) {\n      EliminateRedundantRelationshipIgnore(i, k);\n      should_visit_[k] = 0;\n    }\n  }\n}\n\n// Check whether the current offset does not introduce any conflict\nMaybe<void> MemoryShareStrategy::CheckConflict() {\n  CHECK_EQ_OR_RETURN(index2register_.size(), register_offset_.size())\n      << \"Not equal size, we might be calling CheckConflict() at a wrong time.\";\n  for (int32_t i = 0; i < total_register_num_; i++) {\n    CHECK_GE_OR_RETURN(register_offset_[i], 0) << \"Register offset is not computed.\";\n    for (int32_t j : excluded_registers_[i]) {\n      CHECK_OR_RETURN(register_offset_[i] + register_size_[i] <= register_offset_[j]\n                      || register_offset_[j] + register_size_[j] <= register_offset_[i])\n          << \"Two registers overlap\";\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n// Update the offset with the adjusted strategy\nvoid MemoryShareStrategy::UpdateOffset(size_t* mem_block_size,\n                                       HashMap<RegstDescProto*, int64_t>* regst_desc2offset) {\n  size_t optimal_cost = ComputeOptimalAdjustedCost();\n  if (optimal_cost < *mem_block_size) {\n    VLOG(3) << \"Original cost: \" << *mem_block_size << \", updated cost: \" << optimal_cost;\n    *mem_block_size = optimal_cost;\n    for (auto& pair : *regst_desc2offset) {\n      pair.second = register_offset_[register2index_[pair.first]];\n    }\n  }\n}\n\n// Find all the k < i, 
eliminates k < j,\n// since k < i and i < j have already implied that.\nvoid MemoryShareStrategy::EliminateRedundantRelationship(int32_t i) {\n  // If i is already eliminate, skip it.\n  if (should_visit_[i]) {\n    for (int32_t k : left_registers_[i]) {\n      // Eliminate all the k < i\n      EliminateRedundantRelationship(k);\n      // Eliminate left[i]\n      should_visit_[k] = 0;\n    }\n  }\n}\n\n// Reset the compact position for the registers\nvoid MemoryShareStrategy::ResetCompactPosition(int32_t j) {\n  left_registers_[j].clear();\n  // Mark all the registers on the left\n  for (int32_t i = 0; i < total_register_num_; i++) {\n    if (register_offset_[i] < register_offset_[j]) {\n      if (Exclude(i, j)) {\n        should_visit_[i] = 1;\n      } else {\n        should_visit_[i] = 2;\n      }\n    } else {\n      // Might be unnecessary since we clear up should_visit_ before.\n      should_visit_[i] = 0;\n    }\n  }\n\n  for (int32_t i = 0; i < total_register_num_; i++) {\n    if (should_visit_[i] == 1) {\n      // Find all the k < i, eliminates k < j,\n      // since k < i and i < j have already implied that.\n      // Also reset should_visit_[i] to false,\n      // since we have already visited i.\n      EliminateRedundantRelationship(i);\n    }\n  }\n\n  for (int32_t i = 0; i < total_register_num_; i++) {\n    if (should_visit_[i] == 1) {\n      // i < j\n      left_registers_[j].insert(i);\n    }\n    // Might be unnecessary since we clear up should_visit_ before.\n    should_visit_[i] = 0;\n  }\n}\n\n// Update the maximum iteration step with the current size and lower bound\nvoid MemoryShareStrategy::UpdateMaxIteration(size_t mem_block_size, size_t lower_bound) {\n  if (lower_bound > 0) {\n    max_iteration_step_ = ((mem_block_size - lower_bound) * 100) / lower_bound;\n  } else {\n    // A graph only containing several 0 size tensors might have lower bound = 0.\n    // Check test_div.py::TestDiv::test_0_size_div for example.\n    max_iteration_step_ = 0;\n  }\n  // if mem_block_size is closed to the maximum number of type size_t, then we might have a negative\n  // value for (mem_block_size - lower_bound) * 100\n  // In this case, we just set a large max_iteration_step_\n  if (max_iteration_step_ < 0) { max_iteration_step_ = kMaxIterStep; }\n}\n\n// Adaptively update the offset of registers to minimize the total memory\nvoid MemoryShareStrategy::AdaptivelyUpdateOffset(\n    const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size,\n    const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& register2lifetime,\n    size_t lower_bound, size_t* mem_block_size,\n    HashMap<RegstDescProto*, int64_t>* regst_desc2offset) {\n  VLOG(3) << \"Current memory size: \" << *mem_block_size << \", lower bound : \" << lower_bound;\n  if (*mem_block_size > lower_bound) {\n    UpdateMaxIteration(*mem_block_size, lower_bound);\n    VLOG(3) << \"max iteration step: \" << max_iteration_step_;\n    if (max_iteration_step_ > 0) {\n      StealCompactPosition(*regst_desc2offset, mem_reused_regst2size, register2lifetime);\n      UpdateOffset(mem_block_size, regst_desc2offset);\n    }\n    VLOG(3) << \"After compression, memory size: \" << *mem_block_size;\n  }\n}\n\n// Set the offset of registers to minimize the total memory\n// Iterating from a random order might take a lot of steps to reach the optimal cost.\n// Therefore, this function is not recommended with an initial offset provided.\nvoid MemoryShareStrategy::GenerateOffset(\n    const HashMap<RegstDescProto*, size_t>& 
mem_reused_regst2size,\n    const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& register2lifetime,\n    size_t* mem_block_size, HashMap<RegstDescProto*, int64_t>* regst_desc2offset) {\n  max_iteration_step_ = kMaxIterStep;\n  VLOG(3) << \"max iteration step: \" << max_iteration_step_;\n  GenerateCompactPosition(mem_reused_regst2size, register2lifetime);\n  UpdateOffset(mem_block_size, regst_desc2offset);\n}\n\n}  // namespace oneflow\n"
  },
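The `-x - 1` seeding in `ComputeOffset4CompactRelationship` above doubles as both a memo ("negative means not yet finalized") and a carrier for the seed offset, since -x - 1 < 0 for every x >= 0; the recursion then pushes each register to the right of all of its left-neighbors. A standalone sketch of the trick with hypothetical names (the real class threads this state through its member vectors):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> offset;             // < 0 encodes a not-yet-finalized seed -x - 1
std::vector<int64_t> size_;              // register sizes
std::vector<std::vector<int32_t>> left;  // left[i]: registers directly left of i

// offset[i] = max(seed_i, max over j in left[i] of (offset[j] + size[j]))
int64_t ComputeOffset(int32_t i) {
  if (offset[i] < 0) {
    offset[i] = -offset[i] - 1;  // decode the seed; now >= 0, so memoized
    for (int32_t j : left[i]) {
      offset[i] = std::max(offset[i], ComputeOffset(j) + size_[j]);
    }
  }
  return offset[i];
}

int main() {
  size_ = {4, 2, 3};
  left = {{}, {0}, {0, 1}};  // 0 < 1, 0 < 2, 1 < 2
  offset.assign(3, -1);      // every register seeded at offset 0, encoded as -0 - 1
  std::cout << ComputeOffset(2) << "\n";  // 0 -> [0,4), 1 -> [4,6), 2 starts at 6
}
```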
  {
    "path": "oneflow/core/job/memory_share_strategy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_JOB_MEMORY_SHARE_STRATEGY_H_\n#define ONEFLOW_CORE_JOB_MEMORY_SHARE_STRATEGY_H_\n\n#include <vector>\n#include \"oneflow/core/common/hash_container.h\"\n#include \"oneflow/core/register/register_desc.pb.h\"\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\n// NOTE: Another trick to save times.\n// Comparing two numbers is faster than asking the existence in a HashSet.\nbool IsLifetimeExcluded(const std::pair<int32_t, int32_t>& a, const std::pair<int32_t, int32_t>& b);\n\nclass MemoryShareStrategy {\n public:\n  // Adaptively update the offset of registers to minimize the total memory\n  void AdaptivelyUpdateOffset(\n      const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size,\n      const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& register2lifetime,\n      size_t lower_bound, size_t* mem_block_size,\n      HashMap<RegstDescProto*, int64_t>* regst_desc2offset);\n\n  // Set the offset of registers to minimize the total memory\n  // Iterating from a random order might take a lot of steps to reach the optimal cost.\n  // Therefore, this function is not recommended with an initial offset provided.\n  void GenerateOffset(\n      const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size,\n      const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& register2lifetime,\n      size_t* mem_block_size, HashMap<RegstDescProto*, int64_t>* regst_desc2offset);\n\n private:\n  size_t mem_block_size_;\n  int32_t max_iteration_step_;\n  std::vector<int64_t> register_offset_;\n  std::vector<int64_t> register_size_;\n  HashMap<RegstDescProto*, int32_t> register2index_;\n  std::vector<RegstDescProto*> index2register_;\n  // left registers store the first registers on the left, which have smaller offsets.\n  // For example, 1 < 2 < 3 < 5\n  //                  2 < 4 < 5\n  // Then\n  //      left_registers_[1] = {}\n  //      left_registers_[2] = {1}\n  //      left_registers_[3] = {2}\n  //      left_registers_[4] = {2}\n  //      left_registers_[5] = {3, 4}\n  //  We know that 1 < 3, but 1 is not in left_registers_[3],\n  //  since we only store the first registers.\n  std::vector<HashSet<int32_t>> left_registers_;\n  // Store all the registers which exist simultaneously.\n  std::vector<HashSet<int32_t>> excluded_registers_;\n  // Back up the changes\n  std::vector<HashSet<int32_t>> backup_registers_;\n  HashSet<int32_t> backup_register_behind_i_;\n  // A buffer which implies whether we should visit a register\n  std::vector<int32_t> should_visit_;\n  int32_t total_register_num_;\n  std::vector<int32_t> order_;\n\n  // Mid-level interfaces\n  // Steal a compact position as the initial strategy\n  void StealCompactPosition(\n      const HashMap<RegstDescProto*, int64_t>& regst_desc2offset,\n      const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size,\n      const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& 
register2lifetime);\n  // Generate a compact position with the order of occurrence\n  void GenerateCompactPosition(\n      const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size,\n      const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& register2lifetime);\n  // Update the offset with the adjusted strategy\n  void UpdateOffset(size_t* mem_block_size, HashMap<RegstDescProto*, int64_t>* regst_desc2offset);\n  // Update the maximum iteration step with the current size and lower bound\n  void UpdateMaxIteration(size_t mem_block_size, size_t lower_bound);\n\n  // Initialization\n  void InitRegister(const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& register2lifetime);\n  void InitRegisterInformation(const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size);\n  // Adjust the original strategy, return the updated optimal cost\n  size_t ComputeOptimalAdjustedCost();\n  // Eliminate one register\n  void EliminateRegister(int32_t i);\n  // Eliminate children of j but ignore i.\n  void EliminateRedundantRelationshipIgnore(int32_t i, int32_t j);\n  // Whether i and j occurs simultaneously\n  bool Exclude(int32_t i, int32_t j);\n  // If the previous strategy without the elimination of i has fewer cost, recover to the previous\n  // one from the backup.\n  void RecoverFromBackup(int32_t i);\n  // Clear backup\n  void ClearBackup();\n  // Let x_i occupy some space [x_i, x_i + l_i), then we recompute the optimal cost\n  size_t ComputeOptimalCostWithOccupation(int32_t i, int64_t x_i,\n                                          const std::vector<int64_t>& register_offset_backup);\n  // Insert register i at position [x_i, x_i + l_i)\n  void InsertRegister(int32_t i, int64_t x_i, const std::vector<int64_t>& original_register_offset);\n\n  // Compute optimal cost with compact relationship\n  size_t ComputeOptimalCost4CompactRelationship();\n  size_t ComputeOptimalCostFrom0();\n  // Compute offset with compact relationship\n  int64_t ComputeOffset4CompactRelationship(int32_t i);\n  // Check whether the current offset does not introduce any conflict\n  Maybe<void> CheckConflict();\n  // Reset the compact position for the registers with should_visit_ = 0\n  void ResetCompactPosition(int32_t j);\n  // Find all the k < i, eliminates k < j,\n  // since k < i and i < j have already implied that.\n  void EliminateRedundantRelationship(int32_t i);\n};\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_MEMORY_SHARE_STRATEGY_H_\n"
  },
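`IsLifetimeExcluded` treats lifetimes as half-open intervals [first, second): two registers conflict, and thus may not share memory, exactly when their intervals overlap. A self-contained check of that predicate (the body is copied from the .cpp above; the asserted cases are illustrative):

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

bool IsLifetimeExcluded(const std::pair<int32_t, int32_t>& a,
                        const std::pair<int32_t, int32_t>& b) {
  return a.first < b.second && b.first < a.second;
}

int main() {
  assert(IsLifetimeExcluded({0, 5}, {3, 8}));   // overlap on [3, 5): excluded
  assert(!IsLifetimeExcluded({0, 5}, {5, 8}));  // touching endpoints: no overlap
  assert(IsLifetimeExcluded({2, 3}, {0, 10}));  // containment counts as overlap
}
```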
  {
    "path": "oneflow/core/job/module_conf.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage ModuleConf {\n  required string name = 1;\n  repeated string ops = 2;\n}\n"
  },
  {
    "path": "oneflow/core/job/nd_sbp_infer_hint.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_ND_SBP_INFER_HINT_H_\n#define ONEFLOW_CORE_JOB_ND_SBP_INFER_HINT_H_\n\n#include \"oneflow/core/job/sbp_parallel.pb.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/register/blob_desc.h\"\n\nnamespace oneflow {\n\nclass NdSbpInferHint final {\n public:\n  NdSbpInferHint(const ParallelDesc* parallel_desc, const BlobDesc* logical_blob_desc,\n                 const NdSbp* nd_sbp)\n      : parallel_desc_(parallel_desc), logical_blob_desc_(logical_blob_desc), nd_sbp_(nd_sbp) {}\n  NdSbpInferHint(const NdSbpInferHint&) = default;\n  ~NdSbpInferHint() = default;\n\n  // Getters\n  const ParallelDesc& parallel_desc() const { return *parallel_desc_; }\n  const BlobDesc& logical_blob_desc() const { return *logical_blob_desc_; }\n  const NdSbp& nd_sbp() const { return *nd_sbp_; }\n\n private:\n  const ParallelDesc* parallel_desc_;\n  const BlobDesc* logical_blob_desc_;\n  const NdSbp* nd_sbp_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_ND_SBP_INFER_HINT_H_\n"
  },
  {
    "path": "oneflow/core/job/nd_sbp_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\n\nstd::vector<TensorSliceView> GetTensorSliceView(const int64_t parallel_num,\n                                                const SbpParallel& sbp_parallel,\n                                                const BlobDesc& blob_desc) {\n  const Shape& shape = blob_desc.shape();\n  std::vector<Range> ranges(shape.NumAxes());\n  FOR_RANGE(int64_t, i, 0, shape.NumAxes()) {\n    ranges[i].mut_begin() = 0;\n    ranges[i].mut_end() = shape.At(i);\n  }\n  if (shape.NumAxes() == 0) {\n    // NOTE(chengcheng): For Scalar Tensor.\n    ranges.emplace_back(0, 1);\n  }\n  std::vector<TensorSliceView> views;\n  views.reserve(parallel_num);\n  if (sbp_parallel.has_partial_sum_parallel() || sbp_parallel.has_broadcast_parallel()) {\n    FOR_RANGE(int64_t, i, 0, parallel_num) { views.emplace_back(ranges); }\n  } else if (sbp_parallel.has_split_parallel()) {\n    const int64_t axis = sbp_parallel.split_parallel().axis();\n    CHECK_LT(axis, shape.NumAxes());\n    const BalancedSplitter bs(shape.At(axis), parallel_num);\n    FOR_RANGE(int64_t, i, 0, parallel_num) {\n      if (bs.At(i).size() == 0) {\n        views.emplace_back();\n      } else {\n        ranges[axis] = bs.At(i);\n        views.emplace_back(ranges);\n      }\n    }\n  } else {\n    UNIMPLEMENTED();\n  }\n  return views;\n}\n\nTensorSliceView GetTensorSliceView4ParallelRank(const Shape& parallel_hierarchy,\n                                                const NdSbp& nd_sbp, const Shape& logical_shape,\n                                                const std::vector<int64_t>& parallel_rank) {\n  std::vector<Range> ranges(logical_shape.NumAxes());\n  FOR_RANGE(int64_t, i, 0, logical_shape.NumAxes()) {\n    ranges[i].mut_begin() = 0;\n    ranges[i].mut_end() = logical_shape.At(i);\n  }\n  if (parallel_hierarchy.elem_cnt() == 1) { return TensorSliceView(ranges); }\n  if (parallel_hierarchy.NumAxes() == 1) {\n    const SbpParallel& sbp_parallel = nd_sbp.sbp_parallel(0);\n    if (sbp_parallel.has_split_parallel()) {\n      const int64_t split_axis = sbp_parallel.split_parallel().axis();\n      CHECK_GE(split_axis, 0);\n      CHECK_LT(split_axis, ranges.size());\n      const int64_t id = parallel_rank.front();\n      CHECK_GE(id, 0);\n      CHECK_LT(id, parallel_hierarchy.elem_cnt());\n      const BalancedSplitter bs(logical_shape.At(split_axis), parallel_hierarchy.elem_cnt());\n      CHECK_GT(bs.At(id).size(), 0);\n      ranges[split_axis] = bs.At(id);\n    }\n  } else {\n    Shape physical_shape(logical_shape);\n    FOR_RANGE(int64_t, i, 0, parallel_hierarchy.NumAxes()) {\n      const SbpParallel& sbp_parallel = nd_sbp.sbp_parallel(i);\n      if (sbp_parallel.has_split_parallel()) {\n        const int64_t split_axis = 
sbp_parallel.split_parallel().axis();\n        CHECK_GE(split_axis, 0);\n        CHECK_LT(split_axis, ranges.size());\n        CHECK_GE(ranges[split_axis].size(), parallel_hierarchy.At(i));\n        const BalancedSplitter bs(physical_shape.At(split_axis), parallel_hierarchy.At(i));\n        const auto& range = bs.At(parallel_rank.at(i));\n        const int64_t range_size = range.size();\n        const int64_t dim_start = ranges[split_axis].begin() + range.begin();\n        physical_shape.Set(split_axis, range_size);\n        ranges[split_axis].mut_begin() = dim_start;\n        ranges[split_axis].mut_end() = dim_start + range_size;\n      }\n    }\n  }\n  return TensorSliceView(ranges);\n}\n\nTensorSliceView GetTensorSliceView4ParallelId(const Shape& parallel_hierarchy, const NdSbp& nd_sbp,\n                                              const Shape& logical_shape, int64_t parallel_id) {\n  NdIndexOffsetHelper<int64_t, SHAPE_MAX_AXIS_SIZE> hierarchy_index_helper(\n      parallel_hierarchy.dim_vec().data(), parallel_hierarchy.NumAxes());\n  std::vector<int64_t> parallel_rank(SHAPE_MAX_AXIS_SIZE);\n  hierarchy_index_helper.OffsetToNdIndex(parallel_id, parallel_rank.data());\n  return GetTensorSliceView4ParallelRank(parallel_hierarchy, nd_sbp, logical_shape, parallel_rank);\n}\n\nstd::vector<TensorSliceView> GetTensorSliceView(const Shape& parallel_hierarchy,\n                                                const NdSbp& nd_sbp, const Shape& logical_shape) {\n  std::vector<TensorSliceView> views;\n  views.reserve(parallel_hierarchy.elem_cnt());\n  FOR_RANGE(int64_t, i, 0, parallel_hierarchy.elem_cnt()) {\n    views.emplace_back(GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, i));\n  }\n  return views;\n}\n\nTensorSliceView GetBroadcastTensorSliceView(const BlobDesc& blob_desc) {\n  return TensorSliceView(blob_desc.shape());\n}\n\nbool NdSbpHasPartialParallel(const NdSbp& nd_sbp) {\n  CHECK_GT(nd_sbp.sbp_parallel_size(), 0);\n  FOR_RANGE(int64_t, i, 0, nd_sbp.sbp_parallel_size()) {\n    if (nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) { return true; }\n  }\n  return false;\n}\n\nbool NdSbpHasBroadcastParallel(const NdSbp& nd_sbp) {\n  CHECK_GT(nd_sbp.sbp_parallel_size(), 0);\n  FOR_RANGE(int64_t, i, 0, nd_sbp.sbp_parallel_size()) {\n    if (nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { return true; }\n  }\n  return false;\n}\n\nbool NdSbpIsAllBroadcast(const NdSbp& nd_sbp) {\n  for (const auto& sbp_parallel : nd_sbp.sbp_parallel()) {\n    if (!sbp_parallel.has_broadcast_parallel()) { return false; }\n  }\n  return true;\n}\n\nbool NdSbpIsAllPartialSum(const NdSbp& nd_sbp) {\n  for (const auto& sbp_parallel : nd_sbp.sbp_parallel()) {\n    if (!sbp_parallel.has_partial_sum_parallel()) { return false; }\n  }\n  return true;\n}\n\nbool NdSbpIsAllSplit(const NdSbp& nd_sbp, int64_t axis) {\n  for (const auto& sbp_parallel : nd_sbp.sbp_parallel()) {\n    if (!(sbp_parallel.has_split_parallel() && sbp_parallel.split_parallel().axis() == axis)) {\n      return false;\n    }\n  }\n  return true;\n}\n\n}  // namespace oneflow\n"
  },
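The split (S) branches above all lean on `BalancedSplitter`, which divides an axis of `dim` elements over `n` parts so that the first `dim % n` parts receive one extra element. A standalone sketch of that partition rule (illustrative names, not the OneFlow API):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>

// Returns [begin, end) of part i when an axis of length dim is split over n parts,
// with the remainder spread over the leading parts.
std::pair<int64_t, int64_t> BalancedRange(int64_t dim, int64_t n, int64_t i) {
  const int64_t base = dim / n, extra = dim % n;
  const int64_t begin = i * base + std::min(i, extra);
  const int64_t size = base + (i < extra ? 1 : 0);
  return {begin, begin + size};
}

int main() {
  // Splitting an axis of length 10 across 4 ranks yields sizes 3, 3, 2, 2.
  for (int64_t i = 0; i < 4; ++i) {
    auto [b, e] = BalancedRange(10, 4, i);
    std::cout << "rank " << i << ": [" << b << ", " << e << ")\n";
  }
}
```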
  {
    "path": "oneflow/core/job/nd_sbp_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_ND_SBP_UTIL_H_\n#define ONEFLOW_CORE_JOB_ND_SBP_UTIL_H_\n\n#include \"oneflow/core/register/tensor_slice_view.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n\nnamespace oneflow {\n\nstd::vector<TensorSliceView> GetTensorSliceView(int64_t parallel_num,\n                                                const SbpParallel& sbp_parallel,\n                                                const BlobDesc& blob_desc);\nstd::vector<TensorSliceView> GetTensorSliceView(const Shape& parallel_hierarchy,\n                                                const NdSbp& nd_sbp, const Shape& logical_shape);\nTensorSliceView GetTensorSliceView4ParallelRank(const Shape& parallel_hierarchy,\n                                                const NdSbp& nd_sbp, const Shape& logical_shape,\n                                                const std::vector<int64_t>& parallel_rank);\nTensorSliceView GetTensorSliceView4ParallelId(const Shape& parallel_hierarchy, const NdSbp& nd_sbp,\n                                              const Shape& logical_shape, int64_t parallel_id);\nTensorSliceView GetBroadcastTensorSliceView(const BlobDesc& blob_desc);\n\nbool NdSbpIsAllBroadcast(const NdSbp& nd_sbp);\nbool NdSbpIsAllPartialSum(const NdSbp& nd_sbp);\nbool NdSbpIsAllSplit(const NdSbp& nd_sbp, int64_t axis);\nbool NdSbpHasPartialParallel(const NdSbp& nd_sbp);\nbool NdSbpHasBroadcastParallel(const NdSbp& nd_sbp);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_SBP_PARALLEL_H_\n"
  },
  {
    "path": "oneflow/core/job/oneflow.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/constant.h\"\n#include \"oneflow/core/common/range.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/control/ctrl_client.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/common/buffer_manager.h\"\n#include \"oneflow/core/job/compiler.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/job_builder.h\"\n#include \"oneflow/core/job/job_set.pb.h\"\n#include \"oneflow/core/job/sub_plan.pb.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include \"oneflow/core/job/oneflow.h\"\n#include \"oneflow/core/job/inter_job_mem_sharing_util.h\"\n#include \"oneflow/core/job/plan_util.h\"\n#include \"oneflow/core/operator/interface_op_util.h\"\n#include \"oneflow/core/job/critical_section_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n#include \"oneflow/core/graph/plan_task_graph.h\"\n#include \"oneflow/core/graph/boxing/collective_boxing_util.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/job_rewriter/job_completer.h\"\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::ParallelBlobConf> {\n  size_t operator()(const oneflow::ParallelBlobConf& parallel_blob_conf) const {\n    std::string serialized;\n    parallel_blob_conf.SerializeToString(&serialized);\n    return std::hash<std::string>()(serialized);\n  }\n};\n\n}  // namespace std\n\nnamespace oneflow {\n\nbool operator==(const ParallelBlobConf& lhs, const ParallelBlobConf& rhs) {\n  return BlobDesc(lhs.logical_blob_desc_conf()) == BlobDesc(rhs.logical_blob_desc_conf())\n         && lhs.parallel_conf() == rhs.parallel_conf() && lhs.nd_sbp() == rhs.nd_sbp();\n}\n\nnamespace {\n\n// There are circles in MainJob.\n// A MainJob is a Job like:\n//\n// wait_and_send_ids_op -> reentrant_lock_op -> case_op -> identity_op -> esac_op ->\n//                                \\________________________________________________/\n//\n// back edges esac_op -> reentrant_lock_op are linked by rewriting the plan instead of\n// compiling OpGraph to TaskGraph.\n// ReentrantLockBackEdge holds the key information of a back edge\nstruct ReentrantLockBackEdge {\n  std::string reentrant_lock_op_name;       // back edge destination.\n  LogicalBlobId critical_section_sink_lbi;  // back edge source.\n};\n\nstd::string cluster_thrd_ids_key(const std::string& plan_name) {\n  return plan_name + \"_cluster_thrd_ids\";\n}\n\nstd::string ctrl_regst_desc_info_key(const std::string& plan_name) {\n  return plan_name + \"_ctrl_regst_desc_info_key\";\n}\n\nstd::string job_id2job_conf(const std::string& plan_name) { return plan_name + \"_job_id2job_conf\"; }\n\nstd::string GetCollectiveBoxingPlanKey(const std::string& plan_name) {\n  
return plan_name + \"_collective_boxing_plan\";\n}\n\nstd::string sub_plan_key(const std::string& plan_name, int64_t machine_id, int64_t thrd_id) {\n  return plan_name + \"_\" + std::to_string(machine_id) + \"_\" + std::to_string(thrd_id);\n}\n\nstd::string block7chunk_key(const std::string& plan_name, int64_t machine_id) {\n  return plan_name + \"_\" + std::to_string(machine_id) + \"_block7chunk\";\n}\n\nvoid PushPlan(const std::string& plan_name, Plan&& plan) {\n  HashMap<int64_t, std::set<int64_t>> machine_id2thrd_id_set;\n  HashMap<std::pair<int64_t, int64_t>, std::list<TaskProto>> mchn_thrd_id2task_protos;\n  HashMap<int64_t, MemBlockAndChunkList> machine_id2block7chunk;\n\n  for (TaskProto& task : *plan.mutable_task()) {\n    machine_id2thrd_id_set[task.machine_id()].insert(task.thrd_id());\n    mchn_thrd_id2task_protos[std::make_pair(task.machine_id(), task.thrd_id())].emplace_back(\n        std::move(task));\n  }\n\n  HashMap<int64_t, ThrdIds> machine_id2thrd_ids;\n  for (const auto& pair : machine_id2thrd_id_set) {\n    CHECK(machine_id2thrd_ids.emplace(pair.first, ThrdIds()).second);\n    std::vector<int64_t> thrd_id_vec(pair.second.begin(), pair.second.end());\n    *(machine_id2thrd_ids.at(pair.first).mutable_thrd_id()) = StdVec2PbRf(thrd_id_vec);\n  }\n\n  ClusterThrdIds cluster_thrd_ids;\n  *(cluster_thrd_ids.mutable_machine_id2thrd_ids()) = HashMap2PbMap(machine_id2thrd_ids);\n  Singleton<CtrlClient>::Get()->PushKV(cluster_thrd_ids_key(plan_name), cluster_thrd_ids);\n\n  for (std::pair<const std::pair<int64_t, int64_t>, std::list<oneflow::TaskProto>>& pair :\n       mchn_thrd_id2task_protos) {\n    SubPlan sub_plan;\n    sub_plan.mutable_task()->Reserve(pair.second.size());\n    while (!pair.second.empty()) {\n      sub_plan.mutable_task()->Add(std::move(pair.second.front()));\n      pair.second.pop_front();\n    }\n    Singleton<CtrlClient>::Get()->PushKV(\n        sub_plan_key(plan_name, pair.first.first, pair.first.second), sub_plan);\n  }\n\n  for (const auto& mem_block : plan.block_chunk_list().mem_block()) {\n    *machine_id2block7chunk[mem_block.machine_id()].add_mem_block() = mem_block;\n  }\n  for (const auto& chunk : plan.block_chunk_list().chunk()) {\n    *machine_id2block7chunk[chunk.machine_id()].add_chunk() = chunk;\n  }\n  for (const auto& pair : machine_id2block7chunk) {\n    Singleton<CtrlClient>::Get()->PushKV(block7chunk_key(plan_name, pair.first), pair.second);\n  }\n\n  Singleton<CtrlClient>::Get()->PushKV(ctrl_regst_desc_info_key(plan_name),\n                                       plan.ctrl_regst_desc_info());\n  Singleton<CtrlClient>::Get()->PushKV(job_id2job_conf(plan_name), plan.job_confs());\n  Singleton<CtrlClient>::Get()->PushKV(GetCollectiveBoxingPlanKey(plan_name),\n                                       plan.collective_boxing_plan());\n}\n\nvoid PullPlan(const std::string& plan_name, Plan* plan) {\n  ClusterThrdIds cluster_thrd_ids;\n  Singleton<CtrlClient>::Get()->PullKV(cluster_thrd_ids_key(plan_name), &cluster_thrd_ids);\n  PrintProtoToTextFile(cluster_thrd_ids, JoinPath(FLAGS_log_dir, cluster_thrd_ids_key(plan_name)));\n  HashMap<int64_t, ThrdIds> machine_id2thrd_ids;\n  machine_id2thrd_ids = PbMap2HashMap(cluster_thrd_ids.machine_id2thrd_ids());\n  int64_t machine_id = GlobalProcessCtx::Rank();\n  auto thrd_ids_it = machine_id2thrd_ids.find(machine_id);\n  CHECK(thrd_ids_it != machine_id2thrd_ids.end());\n  std::vector<int64_t> thrd_id_vec = PbRf2StdVec(thrd_ids_it->second.thrd_id());\n  for (auto thrd_id : thrd_id_vec) {\n    SubPlan 
sub_plan;\n    Singleton<CtrlClient>::Get()->PullKV(sub_plan_key(plan_name, machine_id, thrd_id), &sub_plan);\n    plan->mutable_task()->MergeFrom(sub_plan.task());\n  }\n  CtrlRegstDescInfo ctrl_regst_desc_info;\n  Singleton<CtrlClient>::Get()->PullKV(ctrl_regst_desc_info_key(plan_name), &ctrl_regst_desc_info);\n  *(plan->mutable_ctrl_regst_desc_info()) = ctrl_regst_desc_info;\n  JobConfs job_confs;\n  Singleton<CtrlClient>::Get()->PullKV(job_id2job_conf(plan_name), &job_confs);\n  *(plan->mutable_job_confs()) = job_confs;\n  Singleton<CtrlClient>::Get()->PullKV(GetCollectiveBoxingPlanKey(plan_name),\n                                       plan->mutable_collective_boxing_plan());\n  MemBlockAndChunkList block7chunk;\n  Singleton<CtrlClient>::Get()->PullKV(block7chunk_key(plan_name, machine_id), &block7chunk);\n  plan->mutable_block_chunk_list()->CopyFrom(block7chunk);\n  // pull op_attribute_info\n  OpAttributeInfo op_attribute_info;\n  Singleton<CtrlClient>::Get()->PullKV(\"op_attribute_info\", &op_attribute_info);\n  // populate op_attribute_info\n  PlanUtil::PopulateOpAttribute(plan, op_attribute_info.job_id2op_attribute_ref_table());\n}\n\nMaybe<void> CompileCurJobOnMaster(Job* job, Plan* plan, bool need_job_complete) {\n  const JobDesc& job_desc = GlobalJobDesc();\n  if (GlobalProcessCtx::IsThisProcessMaster()) {\n    double start = GetCurTime();\n    if (need_job_complete) { JUST(JobCompleter::Complete(job)); }\n    Compiler().Compile(job, plan);\n    PlanUtil::GenMemBlockAndChunk4Plan(plan);\n\n    LOG(INFO) << \"\\njob_id: \" << job_desc.job_id() << \" , job_name: \" << job_desc.job_name()\n              << \" , compile time: \" << (GetCurTime() - start) / 1000000000.0 << \" seconds.\\n\";\n    if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {\n      TeePersistentLogStream::Create(StrCat(\"subplan_job_\", job_desc.job_id()))->Write(*plan);\n    }\n  }\n  PlanUtil::GenCollectiveBoxingPlan(job, plan);\n  PlanUtil::GenRegisterHint(plan);\n  return Maybe<void>::Ok();\n}\n\nvoid MergePlan(Plan* plan, Plan&& other) {\n  PbRpf<TaskProto>* dst_tasks = plan->mutable_task();\n  PbRpf<TaskProto>* src_tasks = other.mutable_task();\n  dst_tasks->Reserve(dst_tasks->size() + src_tasks->size());\n  for (TaskProto& task : *src_tasks) { *(dst_tasks->Add()) = std::move(task); }\n  plan->mutable_block_chunk_list()->MergeFrom(other.block_chunk_list());\n\n  for (const auto& pair : other.job_confs().job_id2job_conf()) {\n    CHECK(plan->mutable_job_confs()->mutable_job_id2job_conf()->insert(pair).second);\n  }\n  for (const auto& pair : other.collective_boxing_plan().job_id2request_set()) {\n    CHECK(\n        plan->mutable_collective_boxing_plan()->mutable_job_id2request_set()->insert(pair).second);\n  }\n  for (auto& pair : *(other.mutable_job_id2op_attribute_ref_table())) {\n    CHECK(plan->job_id2op_attribute_ref_table().find(pair.first)\n          == plan->job_id2op_attribute_ref_table().end())\n        << \"fail to merge op attribute info for job: \" << pair.first;\n    (*plan->mutable_job_id2op_attribute_ref_table())[pair.first] = std::move(pair.second);\n  }\n}\n\nvoid MergeSubPlan(Plan* plan, std::vector<Plan>&& sub_plans) {\n  CHECK(!sub_plans.empty());\n  *plan = std::move(sub_plans.at(0));\n  FOR_RANGE(int32_t, i, 1, sub_plans.size()) { MergePlan(plan, std::move(sub_plans.at(i))); }\n}\n\nRegstDescProto* GetSoleDataRegstDescProto(TaskProto* task) {\n  RegstDescProto* ret = nullptr;\n  for (auto& pair : *task->mutable_produced_regst_desc()) {\n    
CHECK(pair.second.regst_desc_type().has_data_regst_desc());\n    CHECK_ISNULL(ret);\n    ret = &pair.second;\n  }\n  CHECK_NOTNULL(ret);\n  return ret;\n}\n\nconst OperatorConf& GetSoleOpConf(Plan* plan, const TaskProto& task) {\n  CHECK_EQ(task.exec_sequence().exec_node_size(), 1);\n  return PlanUtil::GetOpAttribute(plan, task.job_id(),\n                                  task.exec_sequence().exec_node(0).kernel_conf())\n      .op_conf();\n}\n\nvoid UpdateSoleObnRegstDescId(Plan* plan, TaskProto* task) {\n  CHECK_EQ(task->exec_sequence().exec_node_size(), 1);\n  auto* exec_node = task->mutable_exec_sequence()->mutable_exec_node(0);\n  const auto& obns =\n      PlanUtil::GetOpAttribute(plan, task->job_id(), exec_node->kernel_conf()).output_bns();\n  CHECK_EQ(obns.size(), 1);\n  int64_t regst_desc_id = GetSoleDataRegstDescProto(task)->regst_desc_id();\n  (*exec_node->mutable_bn_in_op2regst_desc_id())[obns.Get(0)] = regst_desc_id;\n}\n\n// example\n// given caller plan: op_A --> op_identity_tick --> op_B\n// given callee plan: op_src_tick --> op_C --> op_D --> op_E --> op_sink_tick\n// return:\n//         op_A --> op_identity_tick --> op_C --> op_D --> op_E --> op_sink_tick --> op_B\n//                                        /\n//                        op_src_tick -->/\n//\n// note: after this function is called, op_src_tick is invalid and needs to be deleted from the plan\nvoid LinkTickTaskProto(Plan* plan, TaskProto* identity_tick, TaskProto* src_tick,\n                       TaskProto* sink_tick) {\n  CHECK(GetSoleOpConf(plan, *identity_tick).has_tick_conf());\n  CHECK(GetSoleOpConf(plan, *src_tick).has_source_tick_conf());\n  CHECK(GetSoleOpConf(plan, *sink_tick).has_sink_tick_conf());\n  RegstDescProto* id_tick_sole_regst = GetSoleDataRegstDescProto(identity_tick);\n  RegstDescProto* src_tick_sole_regst = GetSoleDataRegstDescProto(src_tick);\n  RegstDescProto* sink_tick_sole_regst = GetSoleDataRegstDescProto(sink_tick);\n\n  sink_tick_sole_regst->set_regst_desc_id(id_tick_sole_regst->regst_desc_id());\n  *sink_tick_sole_regst->mutable_consumer_task_id() = id_tick_sole_regst->consumer_task_id();\n  UpdateSoleObnRegstDescId(plan, sink_tick);\n  CHECK_EQ(identity_tick->machine_id(), sink_tick->machine_id());\n\n  id_tick_sole_regst->set_regst_desc_id(src_tick_sole_regst->regst_desc_id());\n  *id_tick_sole_regst->mutable_consumer_task_id() = src_tick_sole_regst->consumer_task_id();\n  UpdateSoleObnRegstDescId(plan, identity_tick);\n}\n\nvoid LinkMainPlan(Plan* plan, Plan&& main_plan,\n                  const std::vector<std::map<int64_t, std::string>>& identity_tick_op_names) {\n  std::function<bool(const TaskProto*)> IsInterfaceTickTockTask;\n  {\n    auto task_ids = std::make_shared<HashSet<int64_t>>();\n    for (const auto& task : main_plan.task()) {\n      if (task.task_type() == TaskType::kTick) { CHECK(task_ids->emplace(task.task_id()).second); }\n    }\n    IsInterfaceTickTockTask = [task_ids, plan](const TaskProto* task) {\n      if (task_ids->find(task->task_id()) != task_ids->end()) { return true; }\n      if (task->exec_sequence().exec_node_size() != 1) { return false; }\n      const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf();\n      OperatorConf::OpTypeCase op_type_case =\n          PlanUtil::GetOpAttribute(plan, task->job_id(), kernel_conf).op_conf().op_type_case();\n      return op_type_case == OperatorConf::kSourceTickConf\n             || op_type_case == OperatorConf::kSinkTickConf;\n    };\n  }\n  MergePlan(plan, std::move(main_plan));\n  
HashMap<std::string, TaskProto*> sole_tick_op_name2sole_task;\n  FOR_RANGE(int64_t, i, 0, plan->task_size()) {\n    TaskProto* task = plan->mutable_task(i);\n    if (IsInterfaceTickTockTask(task) == false) { continue; }\n    const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf();\n    const auto& op_name =\n        PlanUtil::GetOpAttribute(plan, task->job_id(), kernel_conf).op_conf().name();\n    CHECK(sole_tick_op_name2sole_task.emplace(op_name, task).second);\n  }\n  auto TaskProto4TaskId = PlanUtil::MakeGetterTaskProto4TaskId(*plan);\n  const auto& process_ranks = Singleton<ResourceDesc, ForSession>::Get()->process_ranks();\n  FOR_RANGE(int32_t, i, 0, Singleton<CriticalSectionDesc>::Get()->CriticalSectionNum()) {\n    const CriticalSection& cs = Singleton<CriticalSectionDesc>::Get()->GetCriticalSection(i);\n    for (int64_t machine_id : process_ranks) {\n      TaskProto* identity_tick =\n          sole_tick_op_name2sole_task.at(identity_tick_op_names.at(i).at(machine_id));\n      LinkTickTaskProto(\n          plan, identity_tick,\n          sole_tick_op_name2sole_task.at(cs.machine_id2source_tick_op_name().at(machine_id)),\n          sole_tick_op_name2sole_task.at(cs.machine_id2sink_tick_op_name().at(machine_id)));\n    }\n  }\n  {\n    // erase source_tick task_proto\n    HashSet<std::string> source_tick_op_names;\n    FOR_RANGE(int32_t, i, 0, Singleton<CriticalSectionDesc>::Get()->CriticalSectionNum()) {\n      const CriticalSection& cs = Singleton<CriticalSectionDesc>::Get()->GetCriticalSection(i);\n      for (int64_t machine_id : process_ranks) {\n        const auto& src_tick_op_name = cs.machine_id2source_tick_op_name().at(machine_id);\n        CHECK(source_tick_op_names.emplace(src_tick_op_name).second);\n      }\n    }\n    Erase<PbRpf<TaskProto>>(*plan->mutable_task(), [&](const TaskProto& task) {\n      if (task.task_type() == TaskType::kSourceTick) {\n        CHECK(task.exec_sequence().exec_node_size() == 1);\n        const auto& kernel_conf = task.exec_sequence().exec_node(0).kernel_conf();\n        const auto& op_conf = PlanUtil::GetOpAttribute(plan, task.job_id(), kernel_conf).op_conf();\n        CHECK(op_conf.has_source_tick_conf());\n        CHECK(source_tick_op_names.find(op_conf.name()) != source_tick_op_names.end());\n        return true;\n      } else {\n        return false;\n      }\n    });\n  }\n}\n\nvoid GetMemSharingOpBlobInfo(const JobBuilder& job_builder, const std::string& op_name,\n                             ParallelBlobConf* blob_conf) {\n  std::string obn = \"out\";\n  std::string lbn;\n  {\n    const auto& op_conf = CHECK_JUST(job_builder.OpConf4OpName(op_name));\n    if (op_conf.has_variable_conf()) {\n      lbn = op_name + \"/\" + op_conf.variable_conf().out();\n    } else if (op_conf.has_input_conf()) {\n      lbn = op_name + \"/\" + op_conf.input_conf().out();\n    } else if (op_conf.has_output_conf()) {\n      lbn = op_name + \"/\" + op_conf.output_conf().out();\n    } else if (op_conf.has_return_conf()) {\n      lbn = op_name + \"/\" + op_conf.return_conf().out();\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n  const auto& job = job_builder.job();\n  ParallelBlobConf ret;\n  *blob_conf->mutable_parallel_conf() = CHECK_JUST(job_builder.ParallelConf4OpName(op_name));\n  *blob_conf->mutable_logical_blob_desc_conf() = job.helper().lbn2logical_blob_desc().at(lbn);\n  *blob_conf->mutable_nd_sbp() =\n      job.job_parallel_view_conf().op_name2nd_sbp_signature_conf().at(op_name).bn_in_op2nd_sbp().at(\n          obn);\n}\n\nvoid 
FilterOpName2ParallelBlobConf(\n    const HashSet<OperatorConf::OpTypeCase>& match, const std::vector<std::shared_ptr<Job>>& jobs,\n    HashMap<std::string, ParallelBlobConf>* op_name2parallel_blob_conf) {\n  FOR_RANGE(int64_t, job_id, 0, jobs.size()) {\n    JobBuilder job_builder(jobs.at(job_id).get());\n    for (const OperatorConf& op_conf : jobs.at(job_id)->net().op()) {\n      if (match.find(op_conf.op_type_case()) == match.end()) { continue; }\n      ParallelBlobConf parallel_blob_conf;\n      GetMemSharingOpBlobInfo(job_builder, op_conf.name(), &parallel_blob_conf);\n      auto iter = op_name2parallel_blob_conf->find(op_conf.name());\n      if (iter == op_name2parallel_blob_conf->end()) {\n        CHECK(op_name2parallel_blob_conf->emplace(op_conf.name(), parallel_blob_conf).second);\n      } else {\n        CHECK(parallel_blob_conf == iter->second);\n      }\n    }\n  }\n}\n\nvoid CheckNonDistributeOptimizerAvailable(const std::vector<std::shared_ptr<Job>>& jobs) {\n  bool has_job_enable_optimizer_placement_optimization = false;\n  const auto IsEnabled = [](const Job& job) {\n    return job.job_conf().has_train_conf()\n           && job.job_conf().has_optimizer_placement_optimization_mode();\n  };\n  FOR_RANGE(int64_t, job_id, 0, jobs.size()) {\n    if (IsEnabled(*jobs.at(job_id))) {\n      has_job_enable_optimizer_placement_optimization = true;\n      break;\n    }\n  }\n  if (!has_job_enable_optimizer_placement_optimization) { return; }\n\n  HashSet<std::string> var_names;\n  FOR_RANGE(int64_t, job_id, 0, jobs.size()) {\n    if (!IsEnabled(*jobs.at(job_id))) { continue; }\n    for (const OperatorConf& op_conf : jobs.at(job_id)->net().op()) {\n      if (op_conf.op_type_case() != OperatorConf::kVariableConf) { continue; }\n      if (var_names.find(op_conf.name()) == var_names.end()) {\n        var_names.emplace(op_conf.name());\n      } else {\n        // Two optimizer_placement_optimization jobs share the same variable.\n        LOG(FATAL)\n            << \"Only support optimizer_placement_optimization when jobs not sharing same variable\";\n      }\n    }\n  }\n  FOR_RANGE(int64_t, job_id, 0, jobs.size()) {\n    if (IsEnabled(*jobs.at(job_id))) { continue; }\n    for (const OperatorConf& op_conf : jobs.at(job_id)->net().op()) {\n      if (op_conf.op_type_case() != OperatorConf::kVariableConf) { continue; }\n      if (var_names.find(op_conf.name()) != var_names.end()) {\n        // Another job shares a variable with an optimizer_placement_optimization job.\n        LOG(FATAL)\n            << \"Only support optimizer_placement_optimization when jobs not sharing same variable\";\n      }\n    }\n  }\n}\n\nMaybe<ReentrantLockBackEdge> MakeMainJobComponent(\n    const std::string& wait_and_send_ids_lbn, const Range& machine_id_range,\n    JobBuilder* job_builder, std::vector<std::map<int64_t, std::string>>* identity_tick_op_names,\n    std::vector<std::map<int64_t, std::string>>* cb_sink_tick_op_names) {\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  parallel_conf.add_device_name(std::string(\"@\") + std::to_string(machine_id_range.begin()) + \":0\");\n  auto lock_back_edge = std::make_shared<ReentrantLockBackEdge>();\n  OperatorConf reentrant_lock_op_conf;\n  {\n    lock_back_edge->reentrant_lock_op_name =\n        std::string(\"System-Main-ReentrantLock_\") + NewUniqueId();\n    reentrant_lock_op_conf.set_name(lock_back_edge->reentrant_lock_op_name);\n    auto* reentrant_lock_conf = reentrant_lock_op_conf.mutable_reentrant_lock_conf();\n    
reentrant_lock_conf->set_start(wait_and_send_ids_lbn);\n    // ibn \"end\" is set after the plan is generated because we don't want a cycle in the job\n    reentrant_lock_conf->set_out(\"out\");\n    Singleton<CriticalSectionDesc>::Get()->DumpCriticalSectionId2IntersectinIds(\n        reentrant_lock_conf->mutable_lock_id2intersecting_lock_ids());\n    JUST(job_builder->AddOp(parallel_conf, reentrant_lock_op_conf));\n  }\n  // critical section case op conf\n  OperatorConf cs_case_op_conf;\n  {\n    cs_case_op_conf.set_name(std::string(\"System-Main-Case_\") + NewUniqueId());\n    auto* cs_case_conf = cs_case_op_conf.mutable_case_conf();\n    cs_case_conf->set_in(reentrant_lock_op_conf.name() + \"/out\");\n    FOR_RANGE(int64_t, i, 0, Singleton<CriticalSectionDesc>::Get()->CriticalSectionNum()) {\n      cs_case_conf->add_out(GenRepeatedBn(\"out\", i));\n    }\n    JUST(job_builder->AddOp(parallel_conf, cs_case_op_conf));\n  }\n  const int64_t num_critial_sections = Singleton<CriticalSectionDesc>::Get()->CriticalSectionNum();\n  std::vector<std::string> snk_tick_op_names;\n  snk_tick_op_names.reserve(num_critial_sections * machine_id_range.size());\n  FOR_RANGE(int64_t, i, 0, num_critial_sections) {\n    // source tick\n    OperatorConf src_tick_op_conf;\n    {\n      std::string name_prefix = \"System-Main-SourceTick_CriticalSection_\";\n      src_tick_op_conf.set_name(name_prefix + std::to_string(i) + \"_\" + NewUniqueId());\n      auto* src_tick_conf = src_tick_op_conf.mutable_tick_conf();\n      src_tick_conf->add_tick(cs_case_op_conf.name() + \"/\" + GenRepeatedBn(\"out\", i));\n      src_tick_conf->set_out(\"out\");\n      JUST(job_builder->AddOp(parallel_conf, src_tick_op_conf));\n    }\n\n    auto* cur_cb_sink_tick_op_names = &cb_sink_tick_op_names->at(i);\n    for (int64_t machine_id = machine_id_range.begin(); machine_id < machine_id_range.end();\n         ++machine_id) {\n      // identity tick\n      OperatorConf identity_tick_op_conf;\n      {\n        std::string name_prefix = \"System-Main-Tick_CriticalSection_\";\n        identity_tick_op_conf.set_name(name_prefix + std::to_string(i) + \"_\" + NewUniqueId());\n        auto* identity_tick_conf = identity_tick_op_conf.mutable_tick_conf();\n        identity_tick_conf->add_tick(src_tick_op_conf.name() + \"/out\");\n        identity_tick_conf->set_out(\"out\");\n        JUST(job_builder->AddOp(parallel_conf, identity_tick_op_conf));\n        auto* cur_id_tick_op_names = &identity_tick_op_names->at(i);\n        CHECK_OR_RETURN(\n            cur_id_tick_op_names->emplace(machine_id, identity_tick_op_conf.name()).second);\n      }\n      // callback\n      {\n        OperatorConf cb_sink_tick_op_conf;\n        std::string name_prefix = \"System-Main-CallbackSinkTick_\";\n        cb_sink_tick_op_conf.set_name(name_prefix + std::to_string(i) + NewUniqueId());\n        auto* cb_sink_tick_conf = cb_sink_tick_op_conf.mutable_sink_tick_conf();\n        cb_sink_tick_conf->add_tick(identity_tick_op_conf.name() + \"/out\");\n        cb_sink_tick_conf->set_out(\"out\");\n        JUST(job_builder->AddOp(parallel_conf, cb_sink_tick_op_conf));\n        CHECK_OR_RETURN(\n            cur_cb_sink_tick_op_names->emplace(machine_id, cb_sink_tick_op_conf.name()).second);\n      }\n      // sink tick\n      {\n        OperatorConf snk_tick_op_conf;\n        std::string name_prefix = \"System-Main-SinkTick_CriticalSection_\";\n        snk_tick_op_conf.set_name(name_prefix + std::to_string(i) + NewUniqueId());\n        auto* snk_tick_conf = 
snk_tick_op_conf.mutable_sink_tick_conf();\n        snk_tick_conf->add_tick(identity_tick_op_conf.name() + \"/out\");\n        snk_tick_conf->set_out(\"out\");\n        JUST(job_builder->AddOp(parallel_conf, snk_tick_op_conf));\n        snk_tick_op_names.emplace_back(snk_tick_op_conf.name());\n      }\n    }\n  }\n  // critical section esac op conf\n  OperatorConf cs_esac_op_conf;\n  {\n    cs_esac_op_conf.set_name(std::string(\"System-Main-Esac_\") + NewUniqueId());\n    // cs_esac_op_conf.set_pass_tag(\"main\");\n    auto* cs_esac_conf = cs_esac_op_conf.mutable_esac_conf();\n    for (const auto& snk_tick_op_name : snk_tick_op_names) {\n      cs_esac_conf->add_in(snk_tick_op_name + \"/out\");\n    }\n    cs_esac_conf->set_out(\"out\");\n    cs_esac_conf->set_data_type(DataType::kInt32);\n    JUST(job_builder->AddOp(parallel_conf, cs_esac_op_conf));\n  }\n  lock_back_edge->critical_section_sink_lbi.set_op_name(cs_esac_op_conf.name());\n  lock_back_edge->critical_section_sink_lbi.set_blob_name(\"out\");\n  return lock_back_edge;\n}\n\nMaybe<void> MakeCallbackNotifierSinkTick(\n    const std::set<int64_t>& process_ranks,\n    const std::vector<std::map<int64_t, std::string>>& cb_sink_tick_op_names,\n    JobBuilder* job_builder, const std::function<void(const std::string& lbn)>& DoEachSinkTickLbn) {\n  const auto& MakeSinkTick = [&](const std::vector<int64_t>& job_cs_ids,\n                                 int64_t machine_id) -> Maybe<std::string> {\n    if (job_cs_ids.size() == 1) {\n      return cb_sink_tick_op_names.at(job_cs_ids.at(0)).at(machine_id) + \"/out\";\n    }\n    ParallelConf machine_parallel_conf;\n    {\n      machine_parallel_conf.set_device_tag(\"cpu\");\n      machine_parallel_conf.add_device_name(\"@\" + std::to_string(machine_id) + \":0\");\n    }\n    OperatorConf snk_tick_op_conf;\n    {\n      std::string name_prefix = \"System-Main-CallbackNotifier_CriticalSection_\";\n      snk_tick_op_conf.set_name(name_prefix + NewUniqueId());\n      auto* snk_tick_conf = snk_tick_op_conf.mutable_sink_tick_conf();\n      for (int64_t job_cs_id : job_cs_ids) {\n        const auto& cb_sink_tick_op_name = cb_sink_tick_op_names.at(job_cs_id).at(machine_id);\n        snk_tick_conf->add_tick(cb_sink_tick_op_name + \"/out\");\n      }\n      snk_tick_conf->set_out(\"out\");\n      JUST(job_builder->AddOp(machine_parallel_conf, snk_tick_op_conf));\n    }\n    return snk_tick_op_conf.name() + \"/out\";\n  };\n  ParallelConf parallel_conf;\n  {\n    parallel_conf.set_device_tag(\"cpu\");\n    parallel_conf.add_device_name(\"0:0\");\n  }\n  for (const auto& cs_ids : Singleton<CriticalSectionDesc>::Get()->job_id2critical_section_ids()) {\n    OperatorConf snk_tick_op_conf;\n    {\n      std::string name_prefix = \"System-Main-CallbackNotifier_CriticalSection_\";\n      snk_tick_op_conf.set_name(name_prefix + NewUniqueId());\n      snk_tick_op_conf.set_pass_tag(kMainOp);\n      auto* snk_tick_conf = snk_tick_op_conf.mutable_sink_tick_conf();\n      for (int64_t machine_id : process_ranks) {\n        snk_tick_conf->add_tick(*JUST(MakeSinkTick(cs_ids, machine_id)));\n      }\n      snk_tick_conf->set_out(\"out\");\n      JUST(job_builder->AddOp(parallel_conf, snk_tick_op_conf));\n    }\n    DoEachSinkTickLbn(snk_tick_op_conf.name() + \"/out\");\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MakeMainJob(Job* main_job,\n                        std::vector<std::map<int64_t, std::string>>* identity_tick_op_names,\n                        std::vector<ReentrantLockBackEdge>* lock_back_edges) {\n  
JobBuilder job_builder(main_job);\n  CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  parallel_conf.add_device_name(\"0:0\");\n  OperatorConf wait_and_send_ids_op_conf;\n  {\n    wait_and_send_ids_op_conf.set_name(std::string(\"System-Main-WaitAndSendIds_\") + NewUniqueId());\n    wait_and_send_ids_op_conf.set_pass_tag(kMainOp);\n    auto* wait_and_send_ids_conf = wait_and_send_ids_op_conf.mutable_wait_and_send_ids_conf();\n    wait_and_send_ids_conf->set_out(\"out\");\n    wait_and_send_ids_conf->set_wait_buffer_name(kBufferNameGlobalWaitJobId);\n    wait_and_send_ids_conf->set_data_type(DataType::kInt32);\n    auto* id_list = wait_and_send_ids_conf->mutable_id_list();\n    FOR_RANGE(int32_t, i, 0, Singleton<JobName2JobId>::Get()->size()) { id_list->Add(); }\n    HashSet<int64_t> unique_check;\n    for (const auto& pair : *Singleton<JobName2JobId>::Get()) {\n      int64_t job_id = pair.second;\n      CHECK_OR_RETURN(unique_check.insert(job_id).second);\n      const auto& cs_idx = Singleton<CriticalSectionDesc>::Get()->CriticalSectionIds4JobId(job_id);\n      *id_list->Mutable(job_id)->mutable_value() = {cs_idx.begin(), cs_idx.end()};\n    }\n    JUST(job_builder.AddOp(parallel_conf, wait_and_send_ids_op_conf));\n  }\n  const int64_t num_critial_sections = Singleton<CriticalSectionDesc>::Get()->CriticalSectionNum();\n  std::vector<std::map<int64_t, std::string>> cb_sink_tick_op_names;\n  identity_tick_op_names->resize(num_critial_sections);\n  cb_sink_tick_op_names.resize(num_critial_sections);\n  const auto& process_ranks = Singleton<ResourceDesc, ForSession>::Get()->process_ranks();\n  for (int64_t machine_id : process_ranks) {\n    Range sub_range(machine_id, machine_id + 1);\n    const auto& in_lbn = wait_and_send_ids_op_conf.name() + \"/out\";\n    lock_back_edges->emplace_back(*JUST(MakeMainJobComponent(\n        in_lbn, sub_range, &job_builder, identity_tick_op_names, &cb_sink_tick_op_names)));\n  }\n  OperatorConf callback_notify_esac_op_conf;\n  {\n    callback_notify_esac_op_conf.set_name(std::string(\"System-Main-Esac_\") + NewUniqueId());\n    callback_notify_esac_op_conf.set_pass_tag(kMainOp);\n    auto* callback_notify_esac_conf = callback_notify_esac_op_conf.mutable_esac_conf();\n    JUST(MakeCallbackNotifierSinkTick(\n        process_ranks, cb_sink_tick_op_names, &job_builder,\n        [&](const std::string& lbn) { callback_notify_esac_conf->add_in(lbn); }));\n    callback_notify_esac_conf->set_out(\"out\");\n    callback_notify_esac_conf->set_data_type(DataType::kInt32);\n    JUST(job_builder.AddOp(parallel_conf, callback_notify_esac_op_conf));\n  }\n  OperatorConf callback_notify_op_conf;\n  {\n    callback_notify_op_conf.set_name(std::string(\"System-Main-CallbackNotify_\") + NewUniqueId());\n    callback_notify_op_conf.set_pass_tag(kMainOp);\n    auto* callback_notify_conf = callback_notify_op_conf.mutable_callback_notify_conf();\n    callback_notify_conf->set_in(callback_notify_esac_op_conf.name() + \"/out\");\n    auto* buffer_names = callback_notify_conf->mutable_callback_buffer_name();\n    FOR_RANGE(int64_t, i, 0, Singleton<JobName2JobId>::Get()->size()) { buffer_names->Add(); }\n    for (const auto& pair : *Singleton<JobName2JobId>::Get()) {\n      int64_t job_id = pair.second;\n      const auto& buffer_name = GetCallbackNotifierBufferName(pair.first);\n      *buffer_names->Mutable(job_id) = buffer_name;\n    }\n    JUST(job_builder.AddOp(parallel_conf, 
callback_notify_op_conf));\n  }\n\n  auto* job_conf = main_job->mutable_job_conf();\n  job_conf->set_job_name(\"MainJob-unamed\");\n  job_conf->mutable_predict_conf();\n  job_conf->set_default_data_type(DataType::kInt32);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ConnectCriticalSectionEndToReentrantLockEnd(\n    Plan* main_plan, const ReentrantLockBackEdge& lock_back_edge) {\n  TaskProto* reentrant_lock_task = nullptr;\n  TaskProto* cs_sink_task = nullptr;\n  FOR_RANGE(int64_t, i, 0, main_plan->task_size()) {\n    auto* task = main_plan->mutable_task(i);\n    CHECK_EQ_OR_RETURN(task->exec_sequence().exec_node_size(), 1);\n    const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf();\n    const auto& op_name =\n        PlanUtil::GetOpAttribute(main_plan, task->job_id(), kernel_conf).op_conf().name();\n    if (op_name == lock_back_edge.reentrant_lock_op_name) {\n      CHECK_ISNULL_OR_RETURN(reentrant_lock_task);\n      reentrant_lock_task = task;\n    } else if (op_name == lock_back_edge.critical_section_sink_lbi.op_name()) {\n      CHECK_ISNULL_OR_RETURN(cs_sink_task);\n      cs_sink_task = task;\n    } else {\n      // do nothing\n    }\n  }\n  CHECK_NOTNULL_OR_RETURN(reentrant_lock_task);\n  CHECK_NOTNULL_OR_RETURN(cs_sink_task);\n  RegstDescProto* cs_end_regst = PlanUtil::GetSoleProducedDataRegst(cs_sink_task);\n  cs_end_regst->add_consumer_task_id(reentrant_lock_task->task_id());\n  reentrant_lock_task->mutable_consumed_regst_desc_id()->at(\"in\").add_regst_desc_id(\n      cs_end_regst->regst_desc_id());\n\n  auto* reentrant_exec_node = reentrant_lock_task->mutable_exec_sequence()->mutable_exec_node(0);\n  (*reentrant_exec_node->mutable_bn_in_op2regst_desc_id())[\"end\"] = cs_end_regst->regst_desc_id();\n\n  auto* op_attribute = reentrant_exec_node->mutable_kernel_conf()->mutable_op_attribute();\n  op_attribute->add_input_bns(\"end\");\n  (*op_attribute->mutable_arg_signature()->mutable_bn_in_op2lbi())[\"end\"] =\n      lock_back_edge.critical_section_sink_lbi;\n  const auto& blob_desc_signature_map =\n      op_attribute->logical_blob_desc_signature().bn_in_op2blob_desc();\n  const auto it = blob_desc_signature_map.find(\"start\");\n  CHECK_OR_RETURN(it != blob_desc_signature_map.end());\n  CHECK_OR_RETURN(blob_desc_signature_map.find(\"end\") == blob_desc_signature_map.end());\n  (*op_attribute->mutable_logical_blob_desc_signature()->mutable_bn_in_op2blob_desc())[\"end\"] =\n      it->second;\n  auto* reentrant_lock_conf = op_attribute->mutable_op_conf()->mutable_reentrant_lock_conf();\n  reentrant_lock_conf->set_end(GenLogicalBlobName(lock_back_edge.critical_section_sink_lbi));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CompileMainJob(Job* main_job, const std::vector<ReentrantLockBackEdge>& lock_back_edges,\n                           int64_t job_id, Plan* main_plan) {\n  CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());\n  {\n    auto scope = std::make_unique<GlobalJobDescScope>(main_job->job_conf(), job_id);\n    JUST(CompileCurJobOnMaster(main_job, main_plan, false));\n  }\n  for (const auto& lock_back_edge : lock_back_edges) {\n    JUST(ConnectCriticalSectionEndToReentrantLockEnd(main_plan, lock_back_edge));\n  }\n  return Maybe<void>::Ok();\n}\n\nvoid AddJobName2JobId(const std::string& job_name, int64_t job_id) {\n  if (!GlobalProcessCtx::IsThisProcessMaster()) { return; }\n  CHECK(Singleton<JobName2JobId>::Get()->emplace(job_name, job_id).second);\n}\n\nbool NeedAllocateMemory(const RegstDescTypeProto& regst_desc_type) {\n  return 
regst_desc_type.has_data_regst_desc();\n}\n\nvoid FinishGlobalCriticalSectionDesc(const Plan& plan, int64_t job_size) {\n  std::vector<HashMap<std::string, HashSet<int64_t>>> job_id2sole_op_name2mem_block_ids(job_size);\n  std::vector<HashSet<int64_t>> job_id2mem_block_ids(job_size);\n  std::vector<HashSet<int64_t>> job_id2chunk_ids(job_size);\n  for (const auto& task : plan.task()) {\n    if (task.exec_sequence().exec_node_size() == 1) {\n      const auto& kernel_conf = task.exec_sequence().exec_node(0).kernel_conf();\n      const std::string& op_name =\n          PlanUtil::GetOpAttribute(&plan, task.job_id(), kernel_conf).op_conf().name();\n      HashSet<int64_t>* mem_block_ids =\n          &(job_id2sole_op_name2mem_block_ids.at(task.job_id())[op_name]);\n      for (const auto& pair : task.produced_regst_desc()) {\n        if (NeedAllocateMemory(pair.second.regst_desc_type())) {\n          mem_block_ids->emplace(pair.second.mem_block_id());\n        }\n        if (pair.second.has_separated_header_mem_block_id()\n            && pair.second.separated_header_mem_block_id() != -1) {\n          mem_block_ids->emplace(pair.second.separated_header_mem_block_id());\n        }\n      }\n    }\n  }\n  for (const auto& mem_block : plan.block_chunk_list().mem_block()) {\n    if (mem_block.mem_size() == 0) { continue; }\n    for (int64_t job_id : mem_block.job_id()) {\n      job_id2mem_block_ids.at(job_id).insert(mem_block.mem_block_id());\n    }\n  }\n  for (const auto& chunk : plan.block_chunk_list().chunk()) {\n    if (chunk.mem_size() == 0) { continue; }\n    for (int64_t job_id : chunk.job_id()) { job_id2chunk_ids.at(job_id).insert(chunk.chunk_id()); }\n  }\n\n  HashMap<int64_t, HashSet<int64_t>> job_id2input_output_mem_block_ids;\n  auto* critical_section_desc = Singleton<CriticalSectionDesc>::Get();\n  // set mem_block_id for InputOutputCriticalSection\n  FOR_RANGE(int64_t, i, 0, critical_section_desc->CriticalSectionNum()) {\n    auto* critical_section = critical_section_desc->MutCriticalSection(i);\n    int64_t job_id = critical_section->job_id();\n    auto* input_output_mem_block_ids = &job_id2input_output_mem_block_ids[job_id];\n    if (critical_section->has_input_output_critical_section()) {\n      HashSet<int64_t> mem_block_ids;\n      for (const auto& op_name :\n           critical_section->input_output_critical_section().lbi_producer_op_name()) {\n        const auto& cur_mem_block_ids = job_id2sole_op_name2mem_block_ids.at(job_id).at(op_name);\n        mem_block_ids.insert(cur_mem_block_ids.begin(), cur_mem_block_ids.end());\n      }\n      *critical_section->mutable_mem_block_id() = {mem_block_ids.begin(), mem_block_ids.end()};\n      input_output_mem_block_ids->insert(mem_block_ids.begin(), mem_block_ids.end());\n    } else {\n      CHECK(critical_section->has_total_job_critical_section());\n    }\n  }\n  HashSet<int64_t> unique_job_id_check;\n  // set mem_block_id for TotalJobCriticalSection\n  FOR_RANGE(int64_t, i, 0, critical_section_desc->CriticalSectionNum()) {\n    auto* critical_section = critical_section_desc->MutCriticalSection(i);\n    int64_t job_id = critical_section->job_id();\n    const auto& input_output_mem_block_ids = job_id2input_output_mem_block_ids.at(job_id);\n    if (critical_section->has_total_job_critical_section()) {\n      CHECK(unique_job_id_check.emplace(job_id).second);\n      auto* mem_block_ids = &job_id2mem_block_ids.at(job_id);\n      {\n        // exclude input/output critical section mem_block_ids from total_job\n        auto it = 
mem_block_ids->begin();\n        while (it != mem_block_ids->end()) {\n          if (input_output_mem_block_ids.find(*it) == input_output_mem_block_ids.end()) {\n            ++it;\n          } else {\n            it = mem_block_ids->erase(it);\n          }\n        }\n      }\n      *critical_section->mutable_mem_block_id() = {mem_block_ids->begin(), mem_block_ids->end()};\n      *critical_section->mutable_chunk_id() = {job_id2chunk_ids.at(job_id).begin(),\n                                               job_id2chunk_ids.at(job_id).end()};\n    }\n  }\n  critical_section_desc->Done();\n}\n\nREGISTER_FUNCTION_CONFIG_DEF().Bool(\"__is_user_function__\", true, \"is user defined function\");\n\nMaybe<void> CompileJobsAndMergePlans(const PbRpf<Job>& job_confs, Plan& plan) {\n  std::vector<std::shared_ptr<Job>> jobs(job_confs.size());\n  FOR_RANGE(int, i, 0, jobs.size()) { jobs.at(i).reset(new Job(job_confs.Get(i))); }\n  // These checks do not work in the nn.Graph API because only one job is compiled at a time.\n  // And nn.Graph supports training and evaluation sharing the same variable.\n  if (jobs.size() > 1) { CheckNonDistributeOptimizerAvailable(jobs); }\n  HashMap<std::string, ParallelBlobConf> var_op_name2parallel_blob_conf;\n  FilterOpName2ParallelBlobConf({OperatorConf::kVariableConf}, jobs,\n                                &var_op_name2parallel_blob_conf);\n  std::vector<std::shared_ptr<Job>> function_jobs;\n  function_jobs.reserve(jobs.size());\n  FOR_RANGE(int, i, 0, jobs.size()) {\n    JobDesc job_desc(jobs.at(i)->job_conf(), i);\n    if (job_desc.Bool(\"__is_user_function__\")) { function_jobs.emplace_back(jobs.at(i)); }\n  }\n\n  std::vector<Plan> sub_plans(jobs.size());\n  FOR_RANGE(int64_t, i, 0, jobs.size()) {\n    AddJobName2JobId(jobs.at(i)->job_conf().job_name(), i);\n    auto scope = std::make_unique<GlobalJobDescScope>(jobs.at(i)->job_conf(), i);\n    JUST(CompileCurJobOnMaster(jobs.at(i).get(), &sub_plans.at(i), true));\n  }\n  MergeSubPlan(&plan, std::move(sub_plans));\n  InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs(function_jobs, &plan);\n  InterJobMemSharingUtil::MergeMemSharedInterfaceMemBlockBetweenJobs(jobs, &plan);\n  PlanUtil::SetForceInplaceMemBlock(&plan);\n  FinishGlobalCriticalSectionDesc(plan, jobs.size());\n  Plan main_plan;\n  std::vector<std::map<int64_t, std::string>> identity_tick_op_names;\n  {\n    Job main_job;\n    std::vector<ReentrantLockBackEdge> lock_back_edges;\n    JUST(MakeMainJob(&main_job, &identity_tick_op_names, &lock_back_edges));\n    AddJobName2JobId(main_job.job_conf().job_name(), jobs.size());\n    JUST(CompileMainJob(&main_job, lock_back_edges, jobs.size(), &main_plan));\n  }\n  LinkMainPlan(&plan, std::move(main_plan), identity_tick_op_names);\n  PlanUtil::CleanUselessMemBlockAndCheckValid(&plan);\n  PlanUtil::DumpCtrlRegstInfoToPlan(&plan);\n  PlanUtil::PlanMemoryLog(&plan, \"merged_plan\");\n  if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {\n    TeePersistentLogStream::Create(\"merged_plan\")->Write(plan);\n    PlanUtil::ToDotFile(plan, \"/dot/merged_plan.dot\");\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CompileJobsAndPushMergedPlan(const PbRpf<Job>& job_confs) {\n  if (GlobalProcessCtx::IsThisProcessMaster()) {\n    Plan plan;\n    JUST(CompileJobsAndMergePlans(job_confs, plan));\n    double start = GetCurTime();\n    // push op_attribute_info\n    OpAttributeInfo op_attribute_info;\n    *op_attribute_info.mutable_job_id2op_attribute_ref_table() =\n        
plan.job_id2op_attribute_ref_table();\n    Singleton<CtrlClient>::Get()->PushKV(\"op_attribute_info\", op_attribute_info);\n    // push plan\n    PushPlan(\"merged_plan\", std::move(plan));\n    LOG(INFO) << \" PushPlan merged_plan time: \" << (GetCurTime() - start) / 1e9 << \" seconds.\\n\";\n  }\n  OF_SESSION_BARRIER();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> Oneflow::Init(const oneflow::JobSet& job_set) {\n  OF_PROFILER_RANGE_GUARD(\"Oneflow::Init\");\n  // Runtime\n  OF_PROFILER_RANGE_PUSH(\"CompileJobsAndPushMergedPlan\");\n  JUST(CompileJobsAndPushMergedPlan(job_set.job()));\n  OF_PROFILER_RANGE_POP();  // CompileJobsAndPushMergedPlan\n  double start = GetCurTime();\n  PullPlan(\"merged_plan\", &plan_);\n  LOG(INFO) << \" PullPlan merged_plan time: \" << (GetCurTime() - start) / 1e9 << \" seconds.\\n\";\n  if (GlobalProcessCtx::IsThisProcessMaster()) {\n    runtime_buffers_scope_.reset(new RuntimeBuffersScope(plan_.job_confs()));\n  }\n  OF_PROFILER_RANGE_PUSH(\"new Runtime\");\n\n  HashMap<std::string, vm::EagerBlobObject*> variable_op_name2eager_blob_object;\n  runtime_.reset(new Runtime(plan_, variable_op_name2eager_blob_object));\n  OF_PROFILER_RANGE_POP();  // new Runtime\n  return Maybe<void>::Ok();\n}\n\nOneflow::~Oneflow() {\n  if (GlobalProcessCtx::IsThisProcessMaster()) { runtime_buffers_scope_.reset(); }\n  runtime_.reset();\n}\n\n}  // namespace oneflow\n"
  },
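PushPlan/PullPlan in oneflow.cpp shard the merged Plan through the CtrlClient key-value store so that each rank only pulls the tasks addressed to its own (machine_id, thrd_id) pairs. Below is a minimal sketch of that partition-and-rejoin idea, assuming a std::map as a stand-in for the KV store; Task, SubPlan, SubPlanKey, and the key format here are simplified illustrations, not the real protos or client:

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Simplified stand-ins for TaskProto and SubPlan.
struct Task {
  int64_t machine_id;
  int64_t thrd_id;
  std::string name;
};
using SubPlan = std::vector<Task>;

// Mirrors the "<plan>_<machine>_<thrd>" key scheme used by sub_plan_key().
std::string SubPlanKey(const std::string& plan_name, int64_t machine_id, int64_t thrd_id) {
  return plan_name + "_" + std::to_string(machine_id) + "_" + std::to_string(thrd_id);
}

// "Push": group tasks by (machine_id, thrd_id) so each rank can pull only its shard.
void PushPlan(std::map<std::string, SubPlan>* kv, const std::string& plan_name,
              const std::vector<Task>& plan) {
  for (const Task& task : plan) {
    (*kv)[SubPlanKey(plan_name, task.machine_id, task.thrd_id)].push_back(task);
  }
}

// "Pull": a rank reassembles its local plan from the shards addressed to it.
SubPlan PullPlan(const std::map<std::string, SubPlan>& kv, const std::string& plan_name,
                 int64_t machine_id, const std::vector<int64_t>& thrd_ids) {
  SubPlan local;
  for (int64_t thrd_id : thrd_ids) {
    auto it = kv.find(SubPlanKey(plan_name, machine_id, thrd_id));
    if (it == kv.end()) { continue; }
    local.insert(local.end(), it->second.begin(), it->second.end());
  }
  return local;
}

int main() {
  std::map<std::string, SubPlan> kv;
  PushPlan(&kv, "merged_plan", {{0, 0, "a"}, {0, 1, "b"}, {1, 0, "c"}});
  for (const Task& t : PullPlan(kv, "merged_plan", 0, {0, 1})) {
    std::cout << t.name << "\n";  // prints a, b
  }
}
```

The real PushPlan additionally publishes the thread-id index (cluster_thrd_ids_key) so each rank first learns which thrd_ids it owns before pulling the corresponding shards.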
  {
    "path": "oneflow/core/job/oneflow.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_ONEFLOW_H_\n#define ONEFLOW_CORE_JOB_ONEFLOW_H_\n\n#include \"oneflow/core/job/job_set.pb.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/core/control/ctrl_server.h\"\n#include \"oneflow/core/job/runtime.h\"\n#include \"oneflow/core/job/runtime_buffers_scope.h\"\n#include \"oneflow/core/job/inter_user_job_info.pb.h\"\n\nnamespace oneflow {\n\nclass Oneflow final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Oneflow);\n  Oneflow() {}\n  ~Oneflow();\n\n  Maybe<void> Init(const oneflow::JobSet& job_set);\n\n private:\n  Plan plan_;\n  std::unique_ptr<RuntimeBuffersScope> runtime_buffers_scope_;\n  std::unique_ptr<Runtime> runtime_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_ONEFLOW_H_\n"
  },
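One ordering detail worth noting in this header, together with Oneflow::~Oneflow() in oneflow.cpp: the destructor resets runtime_buffers_scope_ before runtime_, i.e. the job buffers close before the runtime is torn down. A small self-contained sketch of that explicit-reset teardown idiom, using illustrative names (DemoBuffers/DemoRuntime/DemoSession are not OneFlow types):

```cpp
#include <iostream>
#include <memory>

struct DemoBuffers {
  ~DemoBuffers() { std::cout << "buffers closed\n"; }
};

struct DemoRuntime {
  explicit DemoRuntime(DemoBuffers*) {}
  ~DemoRuntime() { std::cout << "runtime stopped\n"; }
};

class DemoSession {
 public:
  void Init() {
    buffers_ = std::make_unique<DemoBuffers>();
    runtime_ = std::make_unique<DemoRuntime>(buffers_.get());
  }
  // Mirrors the reset order in Oneflow::~Oneflow(): the buffers scope is
  // released first, then the runtime that consumed those buffers.
  ~DemoSession() {
    buffers_.reset();
    runtime_.reset();
  }

 private:
  std::unique_ptr<DemoBuffers> buffers_;
  std::unique_ptr<DemoRuntime> runtime_;
};

int main() {
  DemoSession session;
  session.Init();
}  // prints "buffers closed" then "runtime stopped"
```

Without the explicit resets, default member destruction would already run in reverse declaration order; spelling the order out in the destructor makes the dependency explicit and keeps it correct even if the member declarations are later reordered.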
  {
    "path": "oneflow/core/job/parallel_conf_signature.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/job/placement.proto\";\n\nmessage ParallelConfSignature {\n  optional ParallelConf op_parallel_conf = 1;\n  map<string, ParallelConf> bn_in_op2parallel_conf = 2;\n}\n"
  },
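For reference, filling this message from C++ goes through the standard protobuf-generated accessors: the singular message field gets a mutable_* getter, and the map field gets a std::map-like mutable accessor. A hedged sketch, assuming the conventional generated header path for this .proto and using "in_0" as an illustrative blob name:

```cpp
#include "oneflow/core/job/parallel_conf_signature.pb.h"  // assumed generated header

oneflow::ParallelConfSignature MakeSignature() {
  oneflow::ParallelConfSignature signature;
  // Op-level placement: cpu device 0 on rank 0.
  auto* op_conf = signature.mutable_op_parallel_conf();
  op_conf->set_device_tag("cpu");
  op_conf->add_device_name("@0:0");
  // Per-blob override: the map field behaves like a std::map of submessages.
  oneflow::ParallelConf blob_conf;
  blob_conf.set_device_tag("cpu");
  blob_conf.add_device_name("@1:0");
  (*signature.mutable_bn_in_op2parallel_conf())["in_0"] = blob_conf;
  return signature;
}
```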
  {
    "path": "oneflow/core/job/parallel_desc.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <algorithm>\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/cpp_attribute.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/id_manager.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/framework/parallel_conf_util.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nint64_t GetDeviceCount(DeviceType device_type) {\n  return Singleton<ep::DeviceManagerRegistry>::Get()->GetDeviceCount(device_type);\n}\n\nusing MachineId2DeviceIdList =\n    std::shared_ptr<HashMap<int64_t, std::shared_ptr<std::vector<int64_t>>>>;\n\nbool GlobalDeviceIdsContaining(const MachineId2DeviceIdList& bigger,\n                               const MachineId2DeviceIdList& smaller) {\n  for (const auto& pair : *smaller) {\n    if (bigger->find(pair.first) == bigger->end()) { return false; }\n    const auto& bigger_device_ids = bigger->find(pair.first)->second;\n    std::vector<int64_t>::iterator ret;\n    for (int64_t device_id : *pair.second) {\n      ret = std::find(bigger_device_ids->begin(), bigger_device_ids->end(), device_id);\n      if (ret == bigger_device_ids->end()) { return false; }\n    }\n  }\n  return true;\n}\n\n}  // namespace\n\nMaybe<std::pair<int64_t, std::string>> ParseDeviceNameConf(const std::string& device_name) {\n  size_t delimiter_pos = device_name.rfind(\":\");\n  CHECK_NE_OR_RETURN(delimiter_pos, std::string::npos);\n  int64_t mchn_id = oneflow_cast<int64_t>(device_name.substr(0, delimiter_pos));\n  std::string device_id_str = device_name.substr(delimiter_pos + 1);\n  return std::make_pair(mchn_id, device_id_str);\n}\n\nMaybe<OFRecord> ParseMachineAndDeviceIdList(const ParallelConf& parallel_conf) {\n  ParallelDesc parallel_desc;\n  JUST(parallel_desc.MaybeInit(parallel_conf));\n  auto machine2device_list = std::make_shared<OFRecord>();\n  auto* features = machine2device_list->mutable_feature();\n  for (int64_t machine_id : parallel_desc.sorted_machine_ids()) {\n    Int32List* device_id_list = (*features)[std::to_string(machine_id)].mutable_int32_list();\n    for (int64_t device_id : parallel_desc.sorted_dev_phy_ids(machine_id)) {\n      device_id_list->add_value(device_id);\n    }\n  }\n  return machine2device_list;\n}\n\nParallelDesc::ParallelDesc(const ParallelConf& user_conf) : symbol_id_(NullOpt) {  // NOLINT\n  CHECK_JUST(MaybeInit(user_conf));\n}\n\nMaybe<ParallelDesc> ParallelDesc::New(int64_t symbol_id, const ParallelConf& parallel_conf) {\n  std::shared_ptr<ParallelDesc> parallel_desc(new ParallelDesc(symbol_id));\n  
JUST(parallel_desc->MaybeInit(parallel_conf));\n  return parallel_desc;\n}\n\nMaybe<ParallelDesc> ParallelDesc::New(const std::string& device_tag,\n                                      const std::vector<std::string>& machine_device_ids,\n                                      const std::shared_ptr<Shape>& hierarchy) {\n  const auto parallel_conf = JUST(MakeParallelConf(device_tag, machine_device_ids, hierarchy));\n  std::shared_ptr<ParallelDesc> parallel_desc;\n  JUST(PhysicalRun([&parallel_desc, &parallel_conf](InstructionsBuilder* builder) -> Maybe<void> {\n    parallel_desc = JUST(builder->GetParallelDescSymbol(*parallel_conf));\n    return Maybe<void>::Ok();\n  }));\n  return parallel_desc;\n}\n\nMaybe<void> ParallelDesc::MaybeInit(const ParallelConf& user_conf) {\n  parallel_conf_ = user_conf;\n  device_type_ = DeviceType::kInvalidDevice;\n  const std::string& device_tag = parallel_conf_.device_tag();\n  DeviceType device_type = JUST(DeviceType4DeviceTag(device_tag));\n  CHECK_OR_RETURN(device_type_ == DeviceType::kInvalidDevice || device_type_ == device_type);\n  device_type_ = device_type;\n  machine_id2sorted_dev_phy_ids_ =\n      std::make_shared<HashMap<int64_t, std::shared_ptr<std::vector<int64_t>>>>();\n  for (const std::string& device_name : parallel_conf_.device_name()) {\n    if (device_name[0] == '@') {\n      JUST(SetMachineIdAndDeviceIdsByParsingDeviceName(device_name.substr(1), 1));\n    } else {\n      JUST(SetMachineIdAndDeviceIdsByParsingDeviceName(device_name,\n                                                       GlobalProcessCtx::NumOfProcessPerNode()));\n    }\n  }\n  containing_current_rank_ = machine_id2sorted_dev_phy_ids_->count(GlobalProcessCtx::Rank()) > 0;\n  ClearUp();\n  JUST(SanityCheck());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ParallelDesc::SetMachineIdAndDeviceIdsByParsingDeviceName(\n    const std::string& device_name, size_t cols) {\n  auto [node_id, device_id_str] = *JUST(ParseDeviceNameConf(device_name));\n  int64_t minus_pos = device_id_str.find(\"-\");\n  if (minus_pos == std::string::npos) {\n    device_id_str = device_id_str + \"-\" + device_id_str;\n    minus_pos = device_id_str.find(\"-\");\n  }\n  int64_t min_id = oneflow_cast<int64_t>(device_id_str.substr(0, minus_pos));\n  int64_t max_id = oneflow_cast<int64_t>(device_id_str.substr(minus_pos + 1));\n  CHECK_LE_OR_RETURN(min_id, max_id);\n  for (int64_t dev_phy_id = min_id; dev_phy_id <= max_id; ++dev_phy_id) {\n    int64_t mchn_id = dev_phy_id % cols + node_id * cols;\n    if (!(*machine_id2sorted_dev_phy_ids_)[mchn_id]) {\n      (*machine_id2sorted_dev_phy_ids_)[mchn_id] = std::make_shared<std::vector<int64_t>>();\n    }\n    (*machine_id2sorted_dev_phy_ids_)[mchn_id]->emplace_back(dev_phy_id);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<int64_t> ParallelDesc::ParallelId4MachineDeviceId(int64_t machine_id,\n                                                        int64_t device_id) const {\n  const auto& machine_iter = machine_id2device_id2parallel_id_.find(machine_id);\n  CHECK_OR_RETURN(machine_iter != machine_id2device_id2parallel_id_.end());\n  const auto& device_iter = machine_iter->second.find(device_id);\n  CHECK_OR_RETURN(device_iter != machine_iter->second.end());\n  return device_iter->second;\n}\n\nMaybe<Symbol<Device>> ParallelDesc::GetTensorDevice4CurrentProcessCtx(\n    Optional<int64_t>* parallel_id) const {\n  int64_t machine_id = 0;\n  int64_t device_id = 0;\n  GlobalProcessCtx::GetCurrentMachineIdAndDeviceId(&machine_id, &device_id);\n  const auto& device = 
JUST(Device::New(device_tag(), device_id));\n  int64_t parallel_id_val = -1;\n  if (TryGetParallelId(machine_id, device_id, &parallel_id_val)) {\n    *parallel_id = parallel_id_val;\n  } else {\n    *parallel_id = Optional<int64_t>();\n  }\n  return device;\n}\n\nMaybe<Symbol<Device>> GetTensorDevice4CurrentProcessCtx(Symbol<ParallelDesc> parallel_desc,\n                                                        Optional<int64_t>* parallel_id) {\n  static thread_local HashMap<Symbol<ParallelDesc>, Optional<int64_t>> parallel_desc2parallel_id;\n  static thread_local HashMap<Symbol<ParallelDesc>, Symbol<Device>> parallel_desc2device;\n  auto parallel_id_iter = parallel_desc2parallel_id.find(parallel_desc);\n  auto device_iter = parallel_desc2device.find(parallel_desc);\n  if (device_iter == parallel_desc2device.end()) {\n    CHECK_OR_RETURN(parallel_id_iter == parallel_desc2parallel_id.end());\n    Optional<int64_t> id_val;\n    const auto& device_symbol = JUST(parallel_desc->GetTensorDevice4CurrentProcessCtx(&id_val));\n    parallel_id_iter = parallel_desc2parallel_id.emplace(parallel_desc, id_val).first;\n    device_iter = parallel_desc2device.emplace(parallel_desc, device_symbol).first;\n  } else {\n    CHECK_OR_RETURN(parallel_id_iter != parallel_desc2parallel_id.end());\n  }\n  *parallel_id = parallel_id_iter->second;\n  return device_iter->second;\n}\n\nbool ParallelDesc::TryGetParallelId(int64_t machine_id, int64_t device_id,\n                                    int64_t* parallel_id) const {\n  const auto& machine_iter = machine_id2device_id2parallel_id_.find(machine_id);\n  if (machine_iter == machine_id2device_id2parallel_id_.end()) { return false; }\n  const auto& device_iter = machine_iter->second.find(device_id);\n  if (device_iter == machine_iter->second.end()) { return false; }\n  *parallel_id = device_iter->second;\n  return true;\n}\n\nMaybe<bool> ParallelDesc::TryGetParallelId(int64_t rank, int64_t* parallel_id) const {\n  if (!HasMachineId(rank)) { return false; }\n  const auto& device_ids = sorted_dev_phy_ids(rank);\n  CHECK_EQ_OR_RETURN(device_ids.size(), 1) << \"only sole device_id supported. 
parallel_conf: \\n\"\n                                           << parallel_conf().DebugString();\n  return TryGetParallelId(rank, JUST(VectorAt(device_ids, 0)), parallel_id);\n}\n\nMaybe<void> ParallelDesc::GetParallelContext(ParallelContext* parallel_ctx, int64_t machine_id,\n                                             int64_t device_id) const {\n  parallel_ctx->set_parallel_num(parallel_num());\n  parallel_ctx->set_parallel_id(JUST(ParallelId4MachineDeviceId(machine_id, device_id)));\n  return Maybe<void>::Ok();\n}\n\nbool ParallelDesc::Equals(const ParallelDesc& rhs) const {\n  return (this == &rhs)\n         || (device_type_ == rhs.device_type_ && sorted_machine_ids_ == rhs.sorted_machine_ids_\n             && EqualsMachineId2SortedDevPhyIds(rhs) && *hierarchy_ == *rhs.hierarchy_);\n}\n\nbool ParallelDesc::EqualsIgnoringDeviceType(const ParallelDesc& rhs) const {\n  return sorted_machine_ids_ == rhs.sorted_machine_ids_ && EqualsMachineId2SortedDevPhyIds(rhs)\n         && *hierarchy_ == *rhs.hierarchy_;\n}\n\nbool ParallelDesc::EqualsIgnoringHierarchy(const ParallelDesc& rhs) const {\n  return (this == &rhs)\n         || (device_type_ == rhs.device_type_ && sorted_machine_ids_ == rhs.sorted_machine_ids_\n             && EqualsMachineId2SortedDevPhyIds(rhs));\n}\n\nbool ParallelDesc::EqualsOnlyForMachineAndDeviceIds(const ParallelDesc& rhs) const {\n  return (this == &rhs)\n         || (sorted_machine_ids_ == rhs.sorted_machine_ids_\n             && EqualsMachineId2SortedDevPhyIds(rhs));\n}\n\nbool ParallelDesc::EqualsMachineId2SortedDevPhyIds(const ParallelDesc& rhs) const {\n  for (int64_t machine_id : sorted_machine_ids_) {\n    if (*machine_id2sorted_dev_phy_ids_->at(machine_id)\n        != *rhs.machine_id2sorted_dev_phy_ids_->at(machine_id)) {\n      return false;\n    }\n  }\n  return true;\n}\n\nvoid ParallelDesc::ClearUp() {\n  EraseIf<int64_t, std::shared_ptr<std::vector<int64_t>>>(\n      machine_id2sorted_dev_phy_ids_.get(),\n      [](HashMap<int64_t, std::shared_ptr<std::vector<int64_t>>>::iterator it) {\n        return it->second->empty();\n      });\n  sorted_machine_ids_.clear();\n  parallel_num_ = 0;\n  for (auto& pair : *machine_id2sorted_dev_phy_ids_) {\n    sorted_machine_ids_.emplace_back(pair.first);\n    SortAndRemoveDuplication((pair.second).get());\n    parallel_num_ += pair.second->size();\n  }\n  if (parallel_conf_.has_hierarchy() && parallel_conf_.hierarchy().dim_size() != 0) {\n    hierarchy_.reset(new Shape(parallel_conf_.hierarchy()));\n    CHECK_EQ(hierarchy_->elem_cnt(), parallel_num_);\n  } else {\n    hierarchy_.reset(new Shape({parallel_num_}));\n    hierarchy_->ToProto(parallel_conf_.mutable_hierarchy());\n  }\n  SortAndRemoveDuplication(&sorted_machine_ids_);\n  parallel_conf_.clear_device_name();\n  int64_t parallel_id = 0;\n  for (int64_t machine_id : sorted_machine_ids_) {\n    for (int64_t device_id : *machine_id2sorted_dev_phy_ids_->at(machine_id)) {\n      parallel_conf_.add_device_name(std::string(\"@\") + std::to_string(machine_id) + \":\"\n                                     + std::to_string(device_id));\n      CHECK_EQ(parallel_id, parallel_id2machine_id_.size());\n      parallel_id2machine_id_.push_back(machine_id);\n      CHECK_EQ(parallel_id, parallel_id2device_id_.size());\n      parallel_id2device_id_.push_back(device_id);\n      machine_id2device_id2parallel_id_[machine_id][device_id] = parallel_id;\n      parallel_id += 1;\n    }\n  }\n}\n\nvoid ParallelDesc::set_device_type(DeviceType device_type) {\n  if (device_type == 
device_type_) { return; }\n  device_type_ = device_type;\n  const std::string tag = *CHECK_JUST(DeviceTag4DeviceType(device_type));\n  parallel_conf_.set_device_tag(tag);\n}\n\nMaybe<void> ParallelDesc::SanityCheck() {\n  device_num_of_each_machine_ = -1;\n  for (auto& pair : *machine_id2sorted_dev_phy_ids_) {\n    if (device_num_of_each_machine_ == -1) {\n      device_num_of_each_machine_ = pair.second->size();\n    } else {\n      CHECK_EQ_OR_RETURN(device_num_of_each_machine_, pair.second->size());\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ParallelDesc::CheckDeviceIdsIsValid() const {\n  const auto& sorted_dev_phy_ids_iter =\n      machine_id2sorted_dev_phy_ids_->find(GlobalProcessCtx::Rank());\n  for (int64_t machine_id : sorted_machine_ids_) {\n    CHECK_LT_OR_RETURN(machine_id, GlobalProcessCtx::WorldSize())\n        << Error::RuntimeError()\n        << \"Placement is invalid because rank must be less than world size!\";\n  }\n  if (sorted_dev_phy_ids_iter != machine_id2sorted_dev_phy_ids_->end()) {\n    for (int64_t dev_phy_id : *sorted_dev_phy_ids_iter->second) {\n      if (device_type_ == DeviceType::kCPU) {\n        CHECK_LT_OR_RETURN(dev_phy_id, GlobalProcessCtx::NumOfProcessPerNode())\n            << Error::RuntimeError()\n            << \"Placement is invalid because device id must be less than num of process per node\";\n      } else {\n        const int64_t device_count = GetDeviceCount(device_type_);\n        CHECK_NE_OR_RETURN(device_count, 0)\n            << Error::RuntimeError() << \"Placement is invalid because there is no device!\";\n        int64_t device_num = std::min(GlobalProcessCtx::NumOfProcessPerNode(), device_count);\n        CHECK_LT_OR_RETURN(dev_phy_id, device_num)\n            << Error::RuntimeError() << \"Placement is invalid because device id must be less than \"\n            << (device_count < GlobalProcessCtx::NumOfProcessPerNode() ? 
\"num devices on node\"\n                                                                       : \"num of process per node\");\n      }\n    }\n  }\n\n  return Maybe<void>::Ok();\n}\n\nParallelConf ParallelDesc::GetParallelIdOnlyParallelConf(int64_t parallel_id) const {\n  ParallelConf parallel_conf;\n  std::string rank = std::to_string(CHECK_JUST(MachineId4ParallelId(parallel_id)));\n  std::string device_id = std::to_string(CHECK_JUST(DeviceId4ParallelId(parallel_id)));\n  parallel_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(device_type())));\n  parallel_conf.add_device_name(std::string(\"@\") + rank + \":\" + device_id);\n  return parallel_conf;\n}\n\nMaybe<int64_t> ParallelDesc::MachineId4ParallelId(int64_t parallel_id) const {\n  CHECK_LT_OR_RETURN(parallel_id, parallel_id2machine_id_.size())\n      << \"parallel_id: \" << parallel_id << \"\\n----[ parallel_conf ]----\"\n      << parallel_conf().DebugString();\n  return parallel_id2machine_id_.at(parallel_id);\n}\n\nMaybe<int64_t> ParallelDesc::DeviceId4ParallelId(int64_t parallel_id) const {\n  CHECK_LT_OR_RETURN(parallel_id, parallel_id2device_id_.size())\n      << \"parallel_id: \" << parallel_id << \"\\n----[ parallel_conf ]----\"\n      << parallel_conf().DebugString();\n  return parallel_id2device_id_.at(parallel_id);\n}\n\nbool ParallelDesc::ContainingMachineId(int64_t machine_id) const {\n  return machine_id2sorted_dev_phy_ids_->find(machine_id) != machine_id2sorted_dev_phy_ids_->end();\n}\n\nbool ParallelDesc::Containing(int64_t machine_id, int64_t device_id) const {\n  const auto& machine_iter = machine_id2sorted_dev_phy_ids_->find(machine_id);\n  if (machine_iter == machine_id2sorted_dev_phy_ids_->end()) { return false; }\n  const auto& vec = machine_iter->second;\n  return std::find(vec->begin(), vec->end(), device_id) != vec->end();\n}\n\nbool ParallelDesc::Bigger(const ParallelDesc& rhs) const {\n  if (device_tag() != rhs.device_tag()) { return false; }\n  return GlobalDeviceIdsContaining(machine_id2sorted_dev_phy_ids_,\n                                   rhs.machine_id2sorted_dev_phy_ids());\n}\n\nstd::tuple<int32_t, int32_t> GetPartIdAndPartNumFromParallelCtx(\n    const ParallelContext* parallel_ctx) {\n  return std::make_tuple(parallel_ctx->parallel_id(), parallel_ctx->parallel_num());\n}\n\nParallelConf GenParallelConfOfCpuZeroOnMaster() {\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  parallel_conf.add_device_name(\"0:0\");\n  return parallel_conf;\n}\n\nParallelConf GenParallelConfOfCpuZeroOnAllMachines() {\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  for (int64_t i : Singleton<ResourceDesc, ForSession>::Get()->process_ranks()) {\n    parallel_conf.add_device_name(std::string(\"@\") + std::to_string(i) + \":0\");\n  }\n  return parallel_conf;\n}\n\nParallelConf GenParallelConfOfCpuOnAllRanks() {\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  int64_t node_size = GlobalProcessCtx::NodeSize();\n  int64_t device_num = GlobalProcessCtx::NumOfProcessPerNode();\n  for (int64_t node_id = 0; node_id < node_size; ++node_id) {\n    parallel_conf.add_device_name(std::to_string(node_id) + \":0-\" + std::to_string(device_num - 1));\n  }\n  return parallel_conf;\n}\n\nnamespace {\n\nMaybe<Optional<int64_t>> CalcParallelId4CurrentProcessCtx(Symbol<ParallelDesc> parallel_desc) {\n  int64_t machine_id = 0;\n  int64_t device_id = 0;\n  GlobalProcessCtx::GetCurrentMachineIdAndDeviceId(&machine_id, &device_id);\n  int64_t parallel_id = 
-1;\n  if (parallel_desc->TryGetParallelId(machine_id, device_id, &parallel_id)) {\n    return Optional<int64_t>(parallel_id);\n  } else {\n    return Optional<int64_t>();\n  }\n}\n\nMaybe<const ParallelContext> CalcParallelContext4CurrentProcessCtx(\n    Symbol<ParallelDesc> parallel_desc) {\n  int64_t machine_id = 0;\n  int64_t device_id = 0;\n  GlobalProcessCtx::GetCurrentMachineIdAndDeviceId(&machine_id, &device_id);\n  int64_t parallel_id_val = -1;\n  CHECK_OR_RETURN(parallel_desc->TryGetParallelId(machine_id, device_id, &parallel_id_val));\n  std::shared_ptr<ParallelContext> parallel_ctx = std::make_shared<ParallelContext>();\n  parallel_ctx->set_parallel_id(parallel_id_val);\n  parallel_ctx->set_parallel_num(parallel_desc->parallel_num());\n  return std::shared_ptr<const ParallelContext>(parallel_ctx);\n}\n\nMaybe<Symbol<ParallelDesc>> RawReplaceDeviceType(Symbol<ParallelDesc> parallel_desc,\n                                                 DeviceType device_type) {\n  ParallelConf parallel_conf(parallel_desc->parallel_conf());\n  parallel_conf.set_device_tag(*JUST(DeviceTag4DeviceType(device_type)));\n  return SymbolOf(ParallelDesc(parallel_conf));\n}\n\nMaybe<std::string> RanksToString(int64_t axis, const int64_t* ranks, const Shape& shape) {\n  if (axis == shape.NumAxes()) { return std::to_string(*ranks); }\n  int64_t stride = shape.Count(axis) / shape.At(axis);\n  std::string str = \"[\";\n  for (int i = 0; i < shape.At(axis); ++i) {\n    str += *JUST(RanksToString(axis + 1, ranks, shape));\n    ranks += stride;\n    if (i != shape.At(axis) - 1) { str += \", \"; }\n  }\n  str += \"]\";\n  return str;\n}\n\nMaybe<std::string> RawPlacementToString(Symbol<ParallelDesc> placement) {\n  const std::string& device_type = placement->device_tag();\n  std::vector<int64_t> sorted_node_ids;\n  sorted_node_ids.reserve(placement->sorted_machine_ids().size());\n  HashMap<int64_t, std::vector<int64_t>> node_id2sorted_dev_phy_ids;\n  for (int64_t machine_id : placement->sorted_machine_ids()) {\n    int64_t node_id = GlobalProcessCtx::NodeId(machine_id);\n    if (!std::count(sorted_node_ids.begin(), sorted_node_ids.end(), node_id)) {\n      sorted_node_ids.emplace_back(node_id);\n    }\n    for (int64_t device_id : placement->sorted_dev_phy_ids(machine_id)) {\n      node_id2sorted_dev_phy_ids[node_id].emplace_back(device_id);\n    }\n  }\n  std::vector<int64_t> ranks;\n  for (int64_t node_id : sorted_node_ids) {\n    for (int64_t device_id : node_id2sorted_dev_phy_ids.at(node_id)) {\n      ranks.emplace_back(node_id * GlobalProcessCtx::NumOfProcessPerNode() + device_id);\n    }\n  }\n  CHECK_EQ_OR_RETURN(ranks.size(), placement->hierarchy()->elem_cnt())\n      << \"rank size is \" << ranks.size() << \", but shape is \" << placement->hierarchy()->ToString();\n  const auto& ranks_str = JUST(RanksToString(0, ranks.data(), *placement->hierarchy()));\n  return \"oneflow.placement(type=\\\"\" + device_type + \"\\\", ranks=\" + *ranks_str + \")\";\n}\n\nMaybe<Symbol<Device>> RawGetTensorDevice(Symbol<ParallelDesc> parallel_desc) {\n  int64_t machine_id = 0;\n  int64_t device_id = 0;\n  GlobalProcessCtx::GetCurrentMachineIdAndDeviceId(&machine_id, &device_id);\n  const auto& type = parallel_desc->device_tag();\n  return JUST(Device::New(type, device_id));\n}\n\nMaybe<Symbol<ParallelDesc>> RawTxtStringToPlacement(const std::string& parallel_conf_str) {\n  ParallelConf parallel_conf;\n  CHECK_OR_RETURN(TxtString2PbMessage(parallel_conf_str, &parallel_conf));\n  return 
SymbolOf(ParallelDesc(parallel_conf));\n}\n\nMaybe<void> RawCheckDeviceIdsIsValid(Symbol<ParallelDesc> placement) {\n  JUST(placement->CheckDeviceIdsIsValid());\n  return Maybe<void>::Ok();\n}\n\nMaybe<Symbol<ParallelDesc>> RawGetParallelDescOfThisRank(const std::string& device_tag) {\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(device_tag);\n  parallel_conf.add_device_name(std::to_string(GlobalProcessCtx::Rank()) + \":\"\n                                + std::to_string(GlobalProcessCtx::LocalRank()));\n  return SymbolOf(ParallelDesc(parallel_conf));\n}\n\n}  // namespace\n\ndecltype(GetParallelId4CurrentProcessCtx) GetParallelId4CurrentProcessCtx =\n    DECORATE(&CalcParallelId4CurrentProcessCtx, ThreadLocal);\ndecltype(GetParallelContext4CurrentProcessCtx) GetParallelContext4CurrentProcessCtx =\n    DECORATE(&CalcParallelContext4CurrentProcessCtx, ThreadLocal);\ndecltype(ReplaceDeviceType) ReplaceDeviceType = DECORATE(&RawReplaceDeviceType, ThreadLocal);\ndecltype(PlacementToString) PlacementToString = DECORATE(&RawPlacementToString, ThreadLocal);\ndecltype(GetTensorDevice) GetTensorDevice = DECORATE(&RawGetTensorDevice, ThreadLocal);\ndecltype(TxtStringToPlacement) TxtStringToPlacement =\n    DECORATE(&RawTxtStringToPlacement, ThreadLocalCopiable);\ndecltype(GetParallelDescOfThisRank) GetParallelDescOfThisRank =\n    DECORATE(&RawGetParallelDescOfThisRank, ThreadLocalCopiable);\ndecltype(CheckDeviceIdsIsValid) CheckDeviceIdsIsValid =\n    DECORATE(&RawCheckDeviceIdsIsValid, ThreadLocal);\n\n}  // namespace oneflow\n"
  },
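For orientation between the two parallel_desc files: below is a minimal sketch (a hypothetical helper, not part of the tree) of how a ParallelConf's device_name entries expand into the parallel_id -> (machine_id, device_id) tables queried by MachineId4ParallelId/DeviceId4ParallelId above. It assumes a context where ParallelDesc can be constructed directly, as in parallel_desc_test.cpp further down.

```cpp
// Sketch (hypothetical helper): walk the parallel_id -> (machine_id, device_id)
// mapping that ParallelDesc builds from a ParallelConf.
#include "oneflow/core/job/parallel_desc.h"

namespace oneflow {

void PrintParallelMapping() {
  ParallelConf parallel_conf;
  parallel_conf.set_device_tag("cpu");
  parallel_conf.add_device_name("0:0-3");  // machine 0, devices 0,1,2,3
  ParallelDesc parallel_desc(parallel_conf);
  for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) {
    // Both lookups return Maybe<int64_t>; CHECK_JUST unwraps or aborts.
    int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
    int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
    LOG(INFO) << "parallel_id " << parallel_id << " -> machine " << machine_id << ", device "
              << device_id;
  }
}

}  // namespace oneflow
```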
  {
    "path": "oneflow/core/job/parallel_desc.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_PARALLEL_DESC_H_\n#define ONEFLOW_CORE_JOB_PARALLEL_DESC_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/job/placement.pb.h\"\n#include \"oneflow/core/record/record.pb.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/cached_caller.h\"\n\nnamespace oneflow {\n\nclass ResourceDesc;\n\nMaybe<OFRecord> ParseMachineAndDeviceIdList(const ParallelConf& parallel_conf);\n\nMaybe<std::pair<int64_t, std::string>> ParseDeviceNameConf(const std::string& device_name);\n\nclass ParallelContext;\nclass Device;\n\nclass ParallelDesc final {\n public:\n  ~ParallelDesc() = default;\n\n  ParallelDesc(const ParallelDesc&) = default;\n  ParallelDesc(const ParallelConf& user_conf);\n\n  static Maybe<ParallelDesc> New(int64_t symbol_id, const ParallelConf& parallel_conf);\n  static Maybe<ParallelDesc> New(const std::string& device_tag,\n                                 const std::vector<std::string>& machine_device_ids,\n                                 const std::shared_ptr<Shape>& hierarchy);\n\n  Maybe<void> MaybeInit(const ParallelConf& user_conf);\n\n  // Getters\n  const Optional<int64_t>& symbol_id() const { return symbol_id_; }\n  bool containing_current_rank() const { return containing_current_rank_; }\n  DeviceType device_type() const { return device_type_; }\n  const std::string& device_tag() const { return parallel_conf_.device_tag(); }\n  std::shared_ptr<HashMap<int64_t, std::shared_ptr<std::vector<int64_t>>>>\n  machine_id2sorted_dev_phy_ids() const {\n    return machine_id2sorted_dev_phy_ids_;\n  }\n  bool HasMachineId(int64_t machine_id) const {\n    return machine_id2sorted_dev_phy_ids_->find(machine_id)\n           != machine_id2sorted_dev_phy_ids_->end();\n  }\n  const std::vector<int64_t>& sorted_machine_ids() const { return sorted_machine_ids_; }\n  const std::vector<int64_t>& sorted_dev_phy_ids(int64_t machine_id) const {\n    return *machine_id2sorted_dev_phy_ids_->at(machine_id);\n  }\n  int64_t parallel_num() const { return parallel_num_; }\n  int64_t device_num_of_each_machine() const { return device_num_of_each_machine_; }\n  const ParallelConf& parallel_conf() const { return parallel_conf_; }\n\n  const ParallelConf& data() const { return parallel_conf_; }\n\n  Maybe<void> GetParallelContext(ParallelContext* parallel_ctx, int64_t machine_id,\n                                 int64_t device_id) const;\n  std::shared_ptr<Shape> hierarchy() const { return hierarchy_; }\n\n  // Setters\n  void set_device_type(DeviceType device_type);\n\n  ParallelConf GetParallelIdOnlyParallelConf(int64_t parallel_id) const;\n\n  bool EqualsIgnoringDeviceType(const ParallelDesc& rhs) const;\n  bool 
EqualsIgnoringHierarchy(const ParallelDesc& rhs) const;\n  bool EqualsOnlyForMachineAndDeviceIds(const ParallelDesc& rhs) const;\n  bool Equals(const ParallelDesc& rhs) const;\n  bool operator==(const ParallelDesc& rhs) const { return Equals(rhs); }\n  bool operator!=(const ParallelDesc& rhs) const { return !(*this == rhs); }\n  bool Equals(const ParallelDesc* rhs) const { return Equals(*rhs); }\n  const std::vector<int64_t>& parallel_id2machine_id() const { return parallel_id2machine_id_; }\n  const std::vector<int64_t>& parallel_id2device_id() const { return parallel_id2device_id_; }\n  Maybe<int64_t> MachineId4ParallelId(int64_t parallel_id) const;\n  Maybe<int64_t> DeviceId4ParallelId(int64_t parallel_id) const;\n  Maybe<int64_t> ParallelId4MachineDeviceId(int64_t machine_id, int64_t device_id) const;\n  Maybe<Symbol<Device>> GetTensorDevice4CurrentProcessCtx(Optional<int64_t>* parallel_id) const;\n  bool Containing(int64_t machine_id, int64_t device_id) const;\n  // this API is exported to Python as Containing\n  bool Bigger(const ParallelDesc& rhs) const;\n  bool ContainingMachineId(int64_t machine_id) const;\n\n  bool TryGetParallelId(int64_t machine_id, int64_t device_id, int64_t* parallel_id) const;\n  Maybe<bool> TryGetParallelId(int64_t rank, int64_t* parallel_id) const;\n\n  Maybe<void> CheckDeviceIdsIsValid() const;\n\n private:\n  friend Maybe<OFRecord> ParseMachineAndDeviceIdList(const ParallelConf& parallel_conf);\n  ParallelDesc() : symbol_id_(NullOpt) {}\n  ParallelDesc(int64_t symbol_id) : symbol_id_(symbol_id) {}\n  void ClearUp();\n  Maybe<void> SetMachineIdAndDeviceIdsByParsingDeviceName(const std::string& device_name,\n                                                          size_t cols);\n  Maybe<void> SanityCheck();\n  Maybe<void> CheckWithResourceDesc(const ResourceDesc& resource_desc);\n  bool EqualsMachineId2SortedDevPhyIds(const ParallelDesc& rhs) const;\n\n  Optional<int64_t> symbol_id_;\n  DeviceType device_type_;\n  ParallelConf parallel_conf_;\n  std::shared_ptr<Shape> hierarchy_;\n  std::vector<int64_t> sorted_machine_ids_;\n  std::shared_ptr<HashMap<int64_t, std::shared_ptr<std::vector<int64_t>>>>\n      machine_id2sorted_dev_phy_ids_;\n  int64_t parallel_num_;\n  int64_t device_num_of_each_machine_;\n  std::vector<int64_t> parallel_id2machine_id_;\n  std::vector<int64_t> parallel_id2device_id_;\n  HashMap<int64_t, HashMap<int64_t, int64_t>> machine_id2device_id2parallel_id_;\n  // cached result of ContainingMachineId(GlobalProcessCtx::Rank()) for performance optimization.\n  bool containing_current_rank_;\n};\n\nMaybe<Symbol<Device>> GetTensorDevice4CurrentProcessCtx(Symbol<ParallelDesc> parallel_desc,\n                                                        Optional<int64_t>* parallel_id);\n\nextern Maybe<Optional<int64_t>> (*GetParallelId4CurrentProcessCtx)(\n    Symbol<ParallelDesc> parallel_desc);\nextern Maybe<const ParallelContext> (*GetParallelContext4CurrentProcessCtx)(\n    Symbol<ParallelDesc> parallel_desc);\nextern Maybe<Symbol<ParallelDesc>> (*ReplaceDeviceType)(Symbol<ParallelDesc>, DeviceType);\nextern Maybe<std::string> (*PlacementToString)(Symbol<ParallelDesc> placement);\nextern Maybe<Symbol<Device>> (*GetTensorDevice)(Symbol<ParallelDesc> parallel_desc);\nextern Maybe<Symbol<ParallelDesc>> (*TxtStringToPlacement)(const std::string& parallel_conf_str);\nextern Maybe<void> (*CheckDeviceIdsIsValid)(Symbol<ParallelDesc> placement);\n\nextern Maybe<Symbol<ParallelDesc>> (*GetParallelDescOfThisRank)(const std::string& device_tag);\n\n
inline bool operator==(const ParallelConf& lhs, const ParallelConf& rhs) {\n  return ParallelDesc(lhs) == ParallelDesc(rhs);\n}\n\ninline bool operator!=(const ParallelConf& lhs, const ParallelConf& rhs) {\n  return ParallelDesc(lhs) != ParallelDesc(rhs);\n}\n\nstd::tuple<int32_t, int32_t> GetPartIdAndPartNumFromParallelCtx(\n    const ParallelContext* parallel_ctx);\n\nParallelConf GenParallelConfOfCpuZeroOnMaster();\nParallelConf GenParallelConfOfCpuZeroOnAllMachines();\nParallelConf GenParallelConfOfCpuOnAllRanks();\n\nnamespace private_details {\n\nMaybe<Symbol<ParallelDesc>> RawReplaceDeviceType(Symbol<ParallelDesc>, DeviceType);\n\nMaybe<std::string> RawPlacementToString(Symbol<ParallelDesc> placement);\n\nMaybe<Symbol<ParallelDesc>> RawTxtStringToPlacement(const std::string& parallel_conf_str);\n\n}  // namespace private_details\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::ParallelDesc> {\n  size_t operator()(const oneflow::ParallelDesc& pr) const {\n    using namespace oneflow;\n    size_t ret = 0;\n    int i = 0;\n    int shift_roundtrip = (sizeof(size_t) / 2);\n    for (int machine_id : pr.sorted_machine_ids()) {\n      int shift = i++ % shift_roundtrip;\n      AddHash(&ret, machine_id << shift_roundtrip << shift);\n      AddHash(&ret, pr.sorted_dev_phy_ids(machine_id).size() << shift);\n    }\n    AddHash(&ret, *pr.hierarchy());\n    return hash<size_t>()(ret);\n  }\n};\n\ntemplate<>\nstruct hash<oneflow::ParallelConf> {\n  size_t operator()(const oneflow::ParallelConf& parallel_conf) const {\n    return std::hash<oneflow::ParallelDesc>()(oneflow::ParallelDesc(parallel_conf));\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_JOB_PARALLEL_DESC_H_\n"
  },
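The std::hash<ParallelDesc> specialization and the Equals/operator== family above are what allow placements to be interned as Symbol<ParallelDesc>. A short sketch of that behavior, assuming SymbolOf deduplicates by value (the GetBroadcastGroup tests below compare separately constructed symbols this way):

```cpp
// Sketch: equal ParallelConfs produce ParallelDescs that compare equal and
// hash identically, so SymbolOf resolves them to one interned symbol.
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/job/parallel_desc.h"

namespace oneflow {

void ParallelDescInterningExample() {
  ParallelConf conf;
  conf.set_device_tag("cpu");
  conf.add_device_name("0:0-3");
  Symbol<ParallelDesc> a = SymbolOf(ParallelDesc(conf));
  Symbol<ParallelDesc> b = SymbolOf(ParallelDesc(conf));
  CHECK(a == b);  // interned: both confs map to the same symbol
}

}  // namespace oneflow
```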
  {
    "path": "oneflow/core/job/parallel_desc_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <algorithm>\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/job/placement.pb.h\"\n#include \"oneflow/core/framework/placement_sbp_util.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/control/ctrl_bootstrap.pb.h\"\n\nnamespace oneflow {\n\nnamespace test {\n\nnamespace {\n\nstruct GlobaProcessCtxScope final {\n  GlobaProcessCtxScope(int64_t node_size, int64_t world_size) {\n    Singleton<ProcessCtx>::New();\n    auto* ctx = Singleton<ProcessCtx>::Get();\n    for (int i = 0; i < world_size; ++i) { ctx->mutable_ctrl_addr()->Add(); }\n    ctx->set_rank(0);\n    ctx->set_node_size(node_size);\n  }\n  ~GlobaProcessCtxScope() { Singleton<ProcessCtx>::Delete(); }\n};\n\n}  // namespace\n\nTEST(ParallelDesc, continuous_1n4d) {\n  GlobaProcessCtxScope scope(1, 4);\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  parallel_conf.add_device_name(\"0:0-3\");\n  ParallelDesc parallel_desc(parallel_conf);\n  ASSERT_EQ(parallel_desc.device_tag(), \"cpu\");\n  ASSERT_EQ(parallel_desc.parallel_num(), 4);\n}\n\nTEST(ParallelDesc, continuous_1n4d_multi_process) {\n  GlobaProcessCtxScope scope(1, 4);\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  parallel_conf.add_device_name(\"0:0-3\");\n  ParallelDesc parallel_desc(parallel_conf);\n  const std::vector<int64_t>& machine_ids = parallel_desc.sorted_machine_ids();\n  ASSERT_EQ(parallel_desc.device_tag(), \"cpu\");\n  ASSERT_EQ(parallel_desc.parallel_num(), 4);\n  ASSERT_EQ(std::count(machine_ids.begin(), machine_ids.end(), 0), 1);\n  ASSERT_EQ(std::count(machine_ids.begin(), machine_ids.end(), 1), 1);\n  ASSERT_EQ(std::count(machine_ids.begin(), machine_ids.end(), 2), 1);\n  ASSERT_EQ(std::count(machine_ids.begin(), machine_ids.end(), 3), 1);\n}\n\nTEST(ParallelDesc, continuous_1n4d_multi_process_with_rank) {\n  GlobaProcessCtxScope scope(1, 4);\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  parallel_conf.add_device_name(\"@0:0-3\");\n  ParallelDesc parallel_desc(parallel_conf);\n  const std::vector<int64_t>& machine_ids = parallel_desc.sorted_machine_ids();\n  ASSERT_EQ(parallel_desc.device_tag(), \"cpu\");\n  ASSERT_EQ(parallel_desc.parallel_num(), 4);\n  ASSERT_EQ(machine_ids.size(), 1);\n  ASSERT_EQ(std::count(machine_ids.begin(), machine_ids.end(), 0), 1);\n}\n\nTEST(ParallelDesc, discrete_1n4d) {\n  GlobaProcessCtxScope scope(1, 4);\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  parallel_conf.add_device_name(\"0:0-1\");\n  parallel_conf.add_device_name(\"0:2-3\");\n  ParallelDesc parallel_desc(parallel_conf);\n  ASSERT_EQ(parallel_desc.device_tag(), \"cpu\");\n  ASSERT_EQ(parallel_desc.parallel_num(), 4);\n}\n\nTEST(ParallelDesc, continuous_2n8d) {\n  GlobaProcessCtxScope scope(2, 8);\n  ParallelConf 
parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  parallel_conf.add_device_name(\"0:0-3\");\n  parallel_conf.add_device_name(\"1:0-3\");\n  ParallelDesc parallel_desc(parallel_conf);\n  ASSERT_EQ(parallel_desc.device_tag(), \"cpu\");\n  ASSERT_EQ(parallel_desc.parallel_num(), 8);\n}\n\nTEST(ParallelDesc, discrete_2n8d) {\n  GlobaProcessCtxScope scope(2, 8);\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  parallel_conf.add_device_name(\"0:0-1\");\n  parallel_conf.add_device_name(\"0:2-3\");\n  parallel_conf.add_device_name(\"1:0-1\");\n  parallel_conf.add_device_name(\"1:2-3\");\n  ParallelDesc parallel_desc(parallel_conf);\n  ASSERT_EQ(parallel_desc.device_tag(), \"cpu\");\n  ASSERT_EQ(parallel_desc.parallel_num(), 8);\n}\n\nTEST(GetBroadcastGroup, naive_1n1d) {\n  GlobaProcessCtxScope scope(1, 1);\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  parallel_conf.add_device_name(\"0:0\");\n  const auto& parallel_desc = SymbolOf(ParallelDesc(parallel_conf));\n  const auto& map = CHECK_JUST(GetBroadcastGroup(parallel_desc, parallel_desc));\n  ASSERT_EQ(map->size(), 1);\n  ASSERT_EQ(map->begin()->first, 0);\n  ASSERT_TRUE(map->begin()->second == parallel_desc);\n}\n\nTEST(GetBroadcastGroup, naive_1n4d) {\n  GlobaProcessCtxScope scope(1, 4);\n  ParallelConf src_parallel_conf;\n  src_parallel_conf.set_device_tag(\"cpu\");\n  src_parallel_conf.add_device_name(\"0:0\");\n  const auto& src_parallel_desc = SymbolOf(ParallelDesc(src_parallel_conf));\n  ParallelConf dst_parallel_conf;\n  dst_parallel_conf.set_device_tag(\"cpu\");\n  dst_parallel_conf.add_device_name(\"0:0-3\");\n  const auto& dst_parallel_desc = SymbolOf(ParallelDesc(dst_parallel_conf));\n  const auto& map = CHECK_JUST(GetBroadcastGroup(src_parallel_desc, dst_parallel_desc));\n  ASSERT_EQ(map->size(), 4);\n  for (int i = 0; i < 4; ++i) {\n    const auto& iter = map->find(i);\n    ASSERT_TRUE(iter != map->end());\n    ASSERT_TRUE(iter->second == dst_parallel_desc);\n  }\n}\n\nTEST(GetBroadcastGroup, naive_2n8d) {\n  GlobaProcessCtxScope scope(2, 8);\n  ParallelConf src_parallel_conf;\n  src_parallel_conf.set_device_tag(\"cpu\");\n  src_parallel_conf.add_device_name(\"0:0\");\n  src_parallel_conf.add_device_name(\"1:0\");\n  const auto& src_parallel_desc = SymbolOf(ParallelDesc(src_parallel_conf));\n  ParallelConf dst_parallel_conf;\n  dst_parallel_conf.set_device_tag(\"cpu\");\n  dst_parallel_conf.add_device_name(\"0:0-3\");\n  dst_parallel_conf.add_device_name(\"1:0-3\");\n  const auto& dst_parallel_desc = SymbolOf(ParallelDesc(dst_parallel_conf));\n  const auto& map = CHECK_JUST(GetBroadcastGroup(src_parallel_desc, dst_parallel_desc));\n  ASSERT_EQ(map->size(), 8);\n\n  ParallelConf first_node_parallel_conf;\n  first_node_parallel_conf.set_device_tag(\"cpu\");\n  first_node_parallel_conf.add_device_name(\"0:0-3\");\n  const auto& first_node_parallel_desc = SymbolOf(ParallelDesc(first_node_parallel_conf));\n  for (int i = 0; i < 4; ++i) {\n    const auto& iter = map->find(i);\n    ASSERT_TRUE(iter != map->end());\n    ASSERT_TRUE(iter->second == first_node_parallel_desc);\n  }\n  ParallelConf second_node_parallel_conf;\n  second_node_parallel_conf.set_device_tag(\"cpu\");\n  second_node_parallel_conf.add_device_name(\"1:0-3\");\n  const auto& second_node_parallel_desc = SymbolOf(ParallelDesc(second_node_parallel_conf));\n  for (int i = 4; i < 8; ++i) {\n    const auto& iter = map->find(i);\n    ASSERT_TRUE(iter != map->end());\n    ASSERT_TRUE(iter->second == 
second_node_parallel_desc);\n  }\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },
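A hypothetical companion test, in the same style as the cases above (it would live alongside them inside namespace oneflow::test), would pin down the difference between Containing, which checks a single (machine_id, device_id) pair, and Bigger, which checks whole device-set inclusion and is the relation exported to Python as Containing per the header comment:

```cpp
// Hypothetical test, not in the file above: Containing vs. Bigger.
TEST(ParallelDesc, containing_and_bigger_1n4d) {
  GlobalProcessCtxScope scope(1, 4);
  ParallelConf whole_conf;
  whole_conf.set_device_tag("cpu");
  whole_conf.add_device_name("0:0-3");
  ParallelDesc whole(whole_conf);

  ParallelConf part_conf;
  part_conf.set_device_tag("cpu");
  part_conf.add_device_name("0:0-1");
  ParallelDesc part(part_conf);

  ASSERT_TRUE(whole.Containing(0, 2));   // device 2 lies in 0:0-3
  ASSERT_FALSE(part.Containing(0, 2));   // but not in 0:0-1
  ASSERT_TRUE(whole.Bigger(part));       // {0,1,2,3} contains {0,1}
  ASSERT_FALSE(part.Bigger(whole));
}
```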
  {
    "path": "oneflow/core/job/parallel_signature.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage ParallelSignature {\n  optional int64 op_parallel_desc_symbol_id = 1;\n  map<string, int64> bn_in_op2parallel_desc_symbol_id = 2;\n}\n"
  },
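A small sketch (hypothetical helper, placeholder symbol ids) of filling this message through the protoc-generated API; the map field uses the standard protobuf map accessors:

```cpp
// Sketch: populating ParallelSignature via the generated protobuf API.
// The symbol ids here are placeholders, not real values.
#include "oneflow/core/job/parallel_signature.pb.h"

void FillParallelSignature(oneflow::ParallelSignature* signature) {
  signature->set_op_parallel_desc_symbol_id(42);
  // Map from blob-name-in-op to the parallel desc symbol it is placed on.
  (*signature->mutable_bn_in_op2parallel_desc_symbol_id())["in_0"] = 42;
  (*signature->mutable_bn_in_op2parallel_desc_symbol_id())["out_0"] = 42;
}
```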
  {
    "path": "oneflow/core/job/pipeline_config_def.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/config_def.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nREGISTER_SCOPE_CONFIG_DEF().Int64(\n    \"pipeline_stage_id_hint\", 0,\n    \"Manually marking different stages of pipelining parallelism. \\n Generally speaking, different \"\n    \"stages are on different devices, and these stages are connected sequentially, so that the \"\n    \"whole network can be pipeline parallel.\");\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
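The registration above is the whole file. For comparison, the same builder pattern with a hypothetical key would look like this (the REGISTER_SCOPE_CONFIG_DEF() macro and its Int64 builder come from oneflow/core/framework/config_def.h, exactly as used above):

```cpp
// Sketch with a hypothetical config key, mirroring the registration above.
#include "oneflow/core/framework/config_def.h"

namespace oneflow {
namespace {

REGISTER_SCOPE_CONFIG_DEF().Int64(
    "example_stage_hint", -1,
    "A hypothetical per-scope integer knob, registered exactly like "
    "pipeline_stage_id_hint above.");

}  // namespace
}  // namespace oneflow
```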
  {
    "path": "oneflow/core/job/placement.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/register/logical_blob_id.proto\";\nimport \"oneflow/core/common/shape.proto\";\n\nmessage ParallelContext {\n  required int64 parallel_id = 1;\n  required int64 parallel_num = 2;\n}\n\nmessage ParallelConf {\n  repeated string device_name = 1;\n  required string device_tag = 2;\n  optional ShapeProto hierarchy = 3;\n}\n\nmessage OpNameSet {\n  repeated string op_name = 1;\n}\n\nmessage PlacementGroup {\n  required OpNameSet op_set = 1;\n  required ParallelConf parallel_conf = 2;\n}\n\nmessage BlobPlacementGroup {\n  repeated LogicalBlobId lbi = 1;\n  required ParallelConf parallel_conf = 2;\n}\n\nmessage Placement {\n  repeated PlacementGroup placement_group = 1;\n  repeated BlobPlacementGroup blob_placement_group = 2;\n}\n"
  },
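A sketch of a ParallelConf written in protobuf text format, the representation consumed by TxtStringToPlacement in parallel_desc.cpp above. This assumes ShapeProto exposes a repeated dim field and that TxtString2PbMessage is declared in oneflow/core/common/protobuf.h; the hierarchy here views the four ranks as a 2x2 device mesh.

```cpp
// Sketch: parse a text-format ParallelConf, as TxtStringToPlacement does.
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/placement.pb.h"

const char* kConfTxt =
    "device_tag: \"cpu\" "
    "device_name: \"0:0-3\" "
    "hierarchy { dim: 2 dim: 2 }";

oneflow::ParallelConf ParseExampleConf() {
  oneflow::ParallelConf conf;
  CHECK(oneflow::TxtString2PbMessage(kConfTxt, &conf));  // aborts on parse failure
  return conf;
}
```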
  {
    "path": "oneflow/core/job/placement_scope.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/placement_scope.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nMaybe<Symbol<ParallelDesc>> PlacementScope::GetParallelDesc(const std::string& device_tag,\n                                                            const OperatorConf& op_conf) const {\n  if (device_tag == \"cpu\" || IsCpuOnly(op_conf)) {\n    return host_parallel_desc_;\n  } else {\n    return device_parallel_desc_;\n  }\n}\n\nMaybe<Symbol<ParallelDesc>> PlacementScope::GetParallelDesc(const std::string& device_tag,\n                                                            const std::string& op_type_name) const {\n  if (device_tag == \"cpu\" || IsCpuOnly(op_type_name)) {\n    return host_parallel_desc_;\n  } else {\n    return device_parallel_desc_;\n  }\n}\n\nMaybe<Symbol<ParallelDesc>> PlacementScope::GetParallelDesc(const std::string& op_type_name) const {\n  return GetParallelDesc(device_parallel_desc_->device_tag(), op_type_name);\n}\n\n}  // namespace oneflow\n"
  },
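A sketch of the fallback rule implemented above: the "cpu" device_tag (or a cpu-only op) resolves to the host placement, anything else to the device placement. The function name is hypothetical, and it assumes "relu" is not registered as a cpu-only op type.

```cpp
// Sketch (hypothetical helper): exercising PlacementScope's host/device fallback.
#include "oneflow/core/job/placement_scope.h"

namespace oneflow {

Maybe<void> PlacementScopeExample(Symbol<ParallelDesc> device_desc,
                                  Symbol<ParallelDesc> host_desc) {
  PlacementScope scope(device_desc, host_desc);
  // A "cpu" device_tag always selects the host placement ...
  CHECK_OR_RETURN(JUST(scope.GetParallelDesc("cpu", std::string("relu"))) == host_desc);
  // ... while a device tag selects the device placement unless the op is cpu-only.
  CHECK_OR_RETURN(JUST(scope.GetParallelDesc("cuda", std::string("relu"))) == device_desc);
  return Maybe<void>::Ok();
}

}  // namespace oneflow
```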
  {
    "path": "oneflow/core/job/placement_scope.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_PLACEMENT_SCOPE_H_\n#define ONEFLOW_CORE_JOB_PLACEMENT_SCOPE_H_\n\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n\nnamespace oneflow {\n\nclass OperatorConf;\n\nclass PlacementScope final {\n public:\n  PlacementScope(Symbol<ParallelDesc> device_parallel_desc, Symbol<ParallelDesc> host_parallel_desc)\n      : device_parallel_desc_(device_parallel_desc), host_parallel_desc_(host_parallel_desc) {}\n\n  size_t hash_value() const { return Hash(device_parallel_desc_, host_parallel_desc_); }\n\n  bool operator==(const PlacementScope& other) const {\n    return this->device_parallel_desc_ == other.device_parallel_desc_\n           && this->host_parallel_desc_ == other.host_parallel_desc_;\n  }\n\n  Symbol<ParallelDesc> device_parallel_desc() const { return device_parallel_desc_; }\n  Symbol<ParallelDesc> host_parallel_desc() const { return host_parallel_desc_; }\n\n  Maybe<Symbol<ParallelDesc>> GetParallelDesc(const std::string& device_tag,\n                                              const OperatorConf& op_conf) const;\n\n  Maybe<Symbol<ParallelDesc>> GetParallelDesc(const std::string& device_tag,\n                                              const std::string& op_type_name) const;\n\n  Maybe<Symbol<ParallelDesc>> GetParallelDesc(const std::string& op_type_name) const;\n\n private:\n  Symbol<ParallelDesc> device_parallel_desc_;\n  Symbol<ParallelDesc> host_parallel_desc_;\n};\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::PlacementScope> final {\n  size_t operator()(const oneflow::PlacementScope& val) const { return val.hash_value(); }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_JOB_PLACEMENT_SCOPE_H_\n"
  },
  {
    "path": "oneflow/core/job/plan.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/job/task.proto\";\nimport \"oneflow/core/job/job_conf.proto\";\nimport \"oneflow/core/memory/memory_block.proto\";\nimport \"oneflow/core/graph/boxing/collective_boxing.proto\";\nimport \"oneflow/core/operator/op_attribute.proto\";\n\nmessage MachineIds {\n  repeated int64 machine_id = 1;\n}\n\nmessage JobConfs {\n  map<int64, JobConfigProto> job_id2job_conf = 1;\n}\n\nmessage CollectiveBoxingPlan {\n  map<int64, boxing.collective.RequestSet> job_id2request_set = 1;\n}\n\nmessage CtrlRegstDescInfo {\n  map<int64, int64> ctrl_regst_desc_id2producer_task_id = 6;\n}\n\nmessage OpAttributeRefTable {\n  map<string, OpAttribute> op_name2op_attribute = 1;\n}\n\nmessage OpAttributeInfo {\n  map<int64, OpAttributeRefTable> job_id2op_attribute_ref_table = 1;\n}\n\nmessage Plan {\n  repeated TaskProto task = 1;\n  required MemBlockAndChunkList block_chunk_list = 2;\n  required JobConfs job_confs = 4;\n  required CollectiveBoxingPlan collective_boxing_plan= 5;\n  required CtrlRegstDescInfo ctrl_regst_desc_info = 6;\n  map<int64, OpAttributeRefTable> job_id2op_attribute_ref_table = 7;\n}\n"
  },
  {
    "path": "oneflow/core/job/plan_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/constant.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/plan_util.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/graph/plan_task_graph.h\"\n#include \"oneflow/core/graph/boxing/collective_boxing_util.h\"\n#include \"oneflow/core/memory/chunk_manager.h\"\n#include \"oneflow/core/memory/memory_case_util.h\"\n#include \"oneflow/core/register/runtime_register_desc.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/graph/task_node.h\"\n\nnamespace oneflow {\n\nRegstDescProto* PlanUtil::GetSoleProducedDataRegst(TaskProto* task_proto) {\n  RegstDescProto* ret = nullptr;\n  for (auto& pair : *task_proto->mutable_produced_regst_desc()) {\n    RegstDescProto* regst_desc = &pair.second;\n    if (regst_desc->regst_desc_type().has_data_regst_desc()) {\n      CHECK_ISNULL(ret);\n      CHECK_EQ(regst_desc->regst_desc_type().data_regst_desc().lbi2blob_desc_size(), 1);\n      ret = regst_desc;\n    }\n  }\n  CHECK_NOTNULL(ret);\n  return ret;\n}\n\nstd::function<const TaskProto*(int64_t)> PlanUtil::MakeGetterTaskProto4TaskId(const Plan& plan) {\n  auto task_id2task_proto = std::make_shared<HashMap<int64_t, const TaskProto*>>();\n  for (const TaskProto& task_proto : plan.task()) {\n    task_id2task_proto->emplace(task_proto.task_id(), &task_proto);\n  }\n  return [task_id2task_proto](int64_t task_id) { return task_id2task_proto->at(task_id); };\n}\n\nnamespace {\n\nvoid SetVariableOpNamesForVariableAndRepeatRegst(Plan* plan) {\n  // NOTE(chengcheng): set variable_op_name before set separated header because var regst alway\n  //  separated.\n  HashMap<int64_t, std::string> regst_id2var_name;\n  for (int i = 0; i < plan->task_size(); i++) {\n    TaskProto* task = plan->mutable_task(i);\n    if (task->exec_sequence().exec_node_size() == 1) {\n      const auto& op_conf =\n          PlanUtil::GetOpAttribute(plan, task->job_id(),\n                                   task->exec_sequence().exec_node(0).kernel_conf())\n              .op_conf();\n      if (op_conf.has_variable_conf()) {\n        RegstDescProto* regst = PlanUtil::GetSoleProducedDataRegst(task);\n        regst_id2var_name.emplace(regst->regst_desc_id(), op_conf.name());\n        regst->set_variable_op_name(op_conf.name());\n      }\n    }\n  }\n\n  for (int i = 0; i < plan->task_size(); i++) {\n    TaskProto* task = plan->mutable_task(i);\n    if (task->task_type() == TaskType::kRepeat) {\n      RegstDescProto* regst = PlanUtil::GetSoleProducedDataRegst(task);\n      CHECK(regst->has_force_inplace_consumed_regst_desc_id());\n      int64_t 
force_inplace_regst_id = regst->force_inplace_consumed_regst_desc_id();\n      auto var_name_it = regst_id2var_name.find(force_inplace_regst_id);\n      if (var_name_it != regst_id2var_name.end()) {\n        regst->set_variable_op_name(var_name_it->second);\n        VLOG(3) << \" set var op name to repeat regst : \" << regst->DebugString();\n      }\n    }\n  }\n}\n\n}  // namespace\n\nvoid PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(Plan* plan) {\n  SetVariableOpNamesForVariableAndRepeatRegst(plan);\n\n  for (int i = 0; i < plan->task_size(); i++) {\n    TaskProto* task = plan->mutable_task(i);\n\n    for (auto& pair : *task->mutable_produced_regst_desc()) {\n      RegstDescProto* regst_desc = &pair.second;\n      if (regst_desc->mem_block_id() == -1) {\n        CHECK_EQ(regst_desc->mem_block_offset(), -1);\n        regst_desc->set_mem_block_id(Singleton<IDMgr>::Get()->NewMemBlockId());\n        regst_desc->set_mem_block_offset(0);\n      }\n\n      RtRegstDesc rt_regst_desc(*regst_desc);\n      int64_t regst_separated_size = rt_regst_desc.TotalSeparatedHeaderByteSize4AllRegst();\n      if (regst_separated_size > 0) {\n        int64_t separated_mem_block_id = Singleton<IDMgr>::Get()->NewMemBlockId();\n        regst_desc->set_separated_header_mem_block_id(separated_mem_block_id);\n      }\n    }\n  }\n}\n\nvoid PlanUtil::GenMemBlockAndChunk4Plan(Plan* plan) {\n  HashSet<std::string> variable_op_names;\n  PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan(plan, variable_op_names);\n}\n\nnamespace {\n\nvoid GenChunkForMultiNNGraphMemoryReuseInMultiClient(\n    Plan* plan, HashMap<int64_t, std::unique_ptr<MemBlockProto>>* mem_block_id2mem_block) {\n  HashMap<int64_t, HashSet<MemBlockProto*>> mzuid2mem_blocks;\n\n  for (auto& pair : *mem_block_id2mem_block) {\n    MemBlockProto* mem_block = pair.second.get();\n    CHECK(mem_block->has_chunk_id() == false);\n    CHECK(mem_block->has_chunk_offset() == false);\n    if (mem_block->has_variable_op_name()) { continue; }\n    if (!mem_block->enable_reuse_mem()) { continue; }\n    // NOTE(chengcheng):\n    //   only reused mem in cuda device.\n    //   special cpu memory like OFRecord pb and TensorBuffer CANNOT reused by another plan.\n    if (memory::IsHostMem(mem_block->mem_case())) { continue; }\n    int64_t mem_zone_uid =\n        memory::GetUniqueMemCaseId(mem_block->machine_id(), mem_block->mem_case());\n    auto it = mzuid2mem_blocks.find(mem_zone_uid);\n    if (it == mzuid2mem_blocks.end()) {\n      it = mzuid2mem_blocks.emplace(mem_zone_uid, HashSet<MemBlockProto*>()).first;\n    }\n    CHECK(it->second.insert(mem_block).second);\n  }\n\n  std::vector<ChunkProto> all_chunks;\n  HashSet<int64_t> unique_chunk_ids;\n\n  for (auto& pair : mzuid2mem_blocks) {\n    int64_t mem_zone_uid = pair.first;\n    std::vector<const ChunkProto*> exist_chunks;\n    Singleton<ChunkMgr>::Get()->GetChunkProtosByMemZoneUniqueId(mem_zone_uid, &exist_chunks);\n    auto chunk_it = exist_chunks.begin();\n    auto& mem_blocks = pair.second;\n    int64_t current_chunk_offset = 0;\n    HashSet<MemBlockProto*> remain_blocks;\n    for (auto mem_block_it = mem_blocks.begin(); mem_block_it != mem_blocks.end(); ++mem_block_it) {\n      if (chunk_it == exist_chunks.end()) {\n        // NOTE(chengcheng): it means that exist chunk has run out.\n        CHECK(remain_blocks.insert(*mem_block_it).second);\n      } else {\n        // NOTE(chengcheng): find chunk which has enough space left.\n        while (chunk_it != exist_chunks.end()\n               && (current_chunk_offset 
+ (*mem_block_it)->mem_size() > (*chunk_it)->mem_size())) {\n          // NOTE(chengcheng): current chunk has no space left, so we move to next chunk.\n          ++chunk_it;\n          current_chunk_offset = 0;\n        }\n        if (chunk_it != exist_chunks.end()) {\n          // NOTE(chengcheng): lucky, we find a appropriate chunk.\n          MemBlockProto* mem_block = *mem_block_it;\n          const ChunkProto* chunk = *chunk_it;\n          CHECK_EQ(mem_block->machine_id(), chunk->machine_id());\n          CHECK(mem_block->mem_case() == chunk->mem_case());\n          CHECK_LE(current_chunk_offset + mem_block->mem_size(), chunk->mem_size());\n          CHECK_GE(current_chunk_offset, 0);\n          // CHECK_GT(mem_block->mem_size(), 0); NOTE(chengcheng): has mem block mem size = 0\n          CHECK_GE(chunk->mem_size(), 0);\n          mem_block->set_chunk_id(chunk->chunk_id());\n          mem_block->set_chunk_offset(current_chunk_offset);\n          current_chunk_offset += mem_block->mem_size();\n          VLOG(3) << \"Lazy nn.Graph Reused MemBlock :[\" << mem_block->DebugString()\n                  << \"] to old Chunk :[\" << chunk->DebugString() << \"]\\n\";\n        } else {\n          // NOTE(chengcheng): sad, no chunk can used, so this mem block need to insert in remain.\n          CHECK(remain_blocks.insert(*mem_block_it).second);\n        }\n      }\n    }\n\n    for (const ChunkProto* exist_chunk : exist_chunks) {\n      all_chunks.emplace_back(*exist_chunk);\n      CHECK(unique_chunk_ids.insert(exist_chunk->chunk_id()).second);\n    }\n\n    if (!remain_blocks.empty()) {\n      auto remain_block_it = remain_blocks.begin();\n      MemBlockProto* first_block = *remain_block_it;\n      ChunkProto new_chunk;\n      new_chunk.set_chunk_id(Singleton<IDMgr>::Get()->NewChunkId());\n      new_chunk.set_machine_id(first_block->machine_id());\n      *new_chunk.mutable_mem_case() = first_block->mem_case();\n      new_chunk.set_mem_size(first_block->mem_size());\n      first_block->set_chunk_id(new_chunk.chunk_id());\n      first_block->set_chunk_offset(0);\n      ++remain_block_it;\n      VLOG(3) << \"Lazy nn.Graph Add MemBlock :[\" << first_block->DebugString() << \"] to NewChunk :[\"\n              << new_chunk.DebugString() << \"]\\n\";\n\n      while (remain_block_it != remain_blocks.end()) {\n        MemBlockProto* this_block = *remain_block_it;\n        CHECK_EQ(this_block->machine_id(), new_chunk.machine_id());\n        CHECK(this_block->mem_case() == new_chunk.mem_case());\n        this_block->set_chunk_id(new_chunk.chunk_id());\n        this_block->set_chunk_offset(new_chunk.mem_size());\n        new_chunk.set_mem_size(new_chunk.mem_size() + this_block->mem_size());\n        VLOG(3) << \"Lazy nn.Graph Add MemBlock :[\" << this_block->DebugString()\n                << \"] to NewChunk :[\" << new_chunk.DebugString() << \"]\\n\";\n        ++remain_block_it;\n      }\n\n      all_chunks.emplace_back(new_chunk);\n      CHECK(unique_chunk_ids.insert(new_chunk.chunk_id()).second);\n\n      Singleton<ChunkMgr>::Get()->AddChunkProto(new_chunk);\n    }\n  }\n\n  CHECK_EQ(all_chunks.size(), unique_chunk_ids.size());\n\n  for (const ChunkProto& chunk : all_chunks) {\n    *(plan->mutable_block_chunk_list()->add_chunk()) = chunk;\n  }\n}\n\n}  // namespace\n\nvoid PlanUtil::MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job, int64_t limited_rank) {\n  if (job.logical_chain_groups_size() == 0) { return; }\n  HashMap<int64_t, HashMap<int64_t, int64_t>> 
logical_chain_id2machine_id2mem_block_id;\n\n  for (int64_t i = 0; i < plan->task_size(); ++i) {\n    TaskProto* task = plan->mutable_task(i);\n    const StreamId stream_id = PlanUtil::GetStreamId(*task);\n    int64_t machine_id = task->machine_id();\n    DeviceType device_type = stream_id.device_id().device_type();\n    // TODO(zwx): eliminate this special 'is cpu' determine\n    if (device_type == DeviceType::kCPU) { continue; }\n    if (!IsValidChainId(task->chain_id())) { continue; }\n    int64_t logical_chain_id = task->chain_id();\n\n    for (auto& pair : *(task->mutable_produced_regst_desc())) {\n      RegstDescProto* regst_desc = &pair.second;\n      if (regst_desc->mem_block_id() != -1 && regst_desc->enable_reuse_mem()\n          && regst_desc->mem_case().device_type() == device_type\n          && regst_desc->regst_desc_type().has_data_regst_desc()) {\n        int64_t mem_block_id = regst_desc->mem_block_id();\n        auto* rank2blocks = &(logical_chain_id2machine_id2mem_block_id[logical_chain_id]);\n        if (rank2blocks->find(machine_id) == rank2blocks->end()) {\n          rank2blocks->emplace(machine_id, mem_block_id);\n        } else {\n          CHECK_EQ(rank2blocks->at(machine_id), mem_block_id);\n        }\n      }\n    }\n  }\n\n  HashMap<int64_t, int64_t> mem_block_id2merged_mem_block_id;\n  for (const auto& logical_chain_group : job.logical_chain_groups()) {\n    CHECK_GE(logical_chain_group.logical_chain_id_list_size(), 2);\n    int64_t merged_logical_chain_id = logical_chain_group.logical_chain_id_list(0);\n    if (limited_rank == -1) {\n      CHECK(logical_chain_id2machine_id2mem_block_id.find(merged_logical_chain_id)\n            != logical_chain_id2machine_id2mem_block_id.end());\n    } else {\n      if (logical_chain_id2machine_id2mem_block_id.find(merged_logical_chain_id)\n          == logical_chain_id2machine_id2mem_block_id.end()) {\n        // Skip when doing rank compile and this logical chain group is not related to this rank.\n        continue;\n      }\n    }\n    const auto& merged_rank2block =\n        logical_chain_id2machine_id2mem_block_id.at(merged_logical_chain_id);\n    for (int64_t i = 1; i < logical_chain_group.logical_chain_id_list_size(); ++i) {\n      int64_t this_logical_chain_id = logical_chain_group.logical_chain_id_list(i);\n      // NOTE(chengcheng): merge mem block id by each rank\n      CHECK(logical_chain_id2machine_id2mem_block_id.find(this_logical_chain_id)\n            != logical_chain_id2machine_id2mem_block_id.end());\n      const auto& this_rank2block =\n          logical_chain_id2machine_id2mem_block_id.at(this_logical_chain_id);\n      for (const auto& pair : this_rank2block) {\n        int64_t this_machine_id = pair.first;\n        int64_t this_mem_block_id = pair.second;\n        if (limited_rank == -1) {\n          CHECK(merged_rank2block.find(this_machine_id) != merged_rank2block.end());\n        } else {\n          if (merged_rank2block.find(this_machine_id) == merged_rank2block.end()) { continue; }\n        }\n\n        int64_t merged_mem_block_id = merged_rank2block.at(this_machine_id);\n        CHECK(mem_block_id2merged_mem_block_id.emplace(this_mem_block_id, merged_mem_block_id)\n                  .second);\n        VLOG(2) << \" merge mem_block_id: \" << this_mem_block_id << \" to \" << merged_mem_block_id;\n      }\n    }\n  }\n\n  for (int64_t i = 0; i < plan->task_size(); ++i) {\n    TaskProto* task = plan->mutable_task(i);\n    const StreamId stream_id = PlanUtil::GetStreamId(*task);\n    DeviceType device_type 
= stream_id.device_id().device_type();\n    // TODO(zwx): eliminate this special 'is cpu' determine\n    if (device_type == DeviceType::kCPU) { continue; }\n    if (!IsValidChainId(task->chain_id())) { continue; }\n\n    for (auto& pair : *(task->mutable_produced_regst_desc())) {\n      RegstDescProto* regst_desc = &pair.second;\n      if (regst_desc->mem_block_id() != -1 && regst_desc->enable_reuse_mem()\n          && regst_desc->mem_case().device_type() == device_type\n          && regst_desc->regst_desc_type().has_data_regst_desc()) {\n        int64_t mem_block_id = regst_desc->mem_block_id();\n        if (mem_block_id2merged_mem_block_id.find(mem_block_id)\n            != mem_block_id2merged_mem_block_id.end()) {\n          // merge mem_block_id\n          int64_t merged_mem_block_id = mem_block_id2merged_mem_block_id.at(mem_block_id);\n          regst_desc->set_mem_block_id(merged_mem_block_id);\n          if (VLOG_IS_ON(3)) {\n            const auto& data_regst = regst_desc->regst_desc_type().data_regst_desc();\n            CHECK_GE(data_regst.lbi2blob_desc_size(), 1);\n            const auto& lbi2blob_desc_pair = data_regst.lbi2blob_desc(0);\n            std::string tensor_name = GenLogicalBlobName(lbi2blob_desc_pair.lbi());\n            VLOG(3) << \" regst: \" << tensor_name << \" merge mem block id \" << mem_block_id << \" to \"\n                    << merged_mem_block_id;\n          }\n        }\n      }\n    }\n  }\n}\n\nvoid PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan(\n    Plan* plan, const HashSet<std::string>& variable_op_names) {\n  HashMap<int64_t, std::unique_ptr<MemBlockProto>> mem_block_id2mem_block;\n\n  auto IsVariableRegst = [&](const TaskProto* task, std::string* name) -> bool {\n    if (variable_op_names.empty()) { return false; }\n    if (task->exec_sequence().exec_node_size() != 1) { return false; }\n    const auto& op_conf =\n        GetOpAttribute(plan, task->job_id(), task->exec_sequence().exec_node(0).kernel_conf())\n            .op_conf();\n    if (!op_conf.has_variable_conf()) { return false; }\n    const std::string& var_name = op_conf.name();\n    if (variable_op_names.find(var_name) == variable_op_names.end()) {\n      LOG(WARNING) << \" Oh no! Cannot find variable_op_name: \" << var_name\n                   << \" in nn.Graph Compiler bind EagerTensor with VariableOp. 
\"\n                   << \" \\n But each variable need bind with eager tensor for init.\";\n      return false;\n    }\n    *name = var_name;\n    return true;\n  };\n\n  auto GenMemBlock4RegstIfNeed = [&](RegstDescProto* regst_desc, const TaskProto* task) {\n    const int64_t job_id = task->job_id();\n    const int64_t machine_id = task->machine_id();\n    const int64_t thrd_id = task->thrd_id();\n    int64_t mem_block_id = regst_desc->mem_block_id();\n    int64_t mem_block_offset = regst_desc->mem_block_offset();\n    CHECK_NE(mem_block_id, -1);\n    CHECK_NE(mem_block_offset, -1);\n\n    std::string var_name;\n    bool is_variable_regst = IsVariableRegst(task, &var_name);\n    if (is_variable_regst) {\n      CHECK(!var_name.empty());\n      CHECK_EQ(regst_desc->register_num(), 1);\n      CHECK_EQ(regst_desc->min_register_num(), 1);\n      // NOTE(xuxiaoyu): this check cannot pass when open ZeRO\n      // CHECK_EQ(regst_desc->max_register_num(), 1) << var_name;\n      regst_desc->set_variable_op_name(var_name);\n    }\n\n    RtRegstDesc rt_regst_desc(*regst_desc);\n    int64_t regst_main_size = rt_regst_desc.TotalMainByteSize4AllRegst();\n    int64_t regst_separated_size = rt_regst_desc.TotalSeparatedHeaderByteSize4AllRegst();\n\n    auto mem_block_it = mem_block_id2mem_block.find(mem_block_id);\n    if (mem_block_it == mem_block_id2mem_block.end()) {\n      MemBlockProto mem_block;\n      mem_block.set_mem_block_id(mem_block_id);\n      mem_block.add_job_id(job_id);\n      mem_block.set_machine_id(machine_id);\n      *(mem_block.mutable_mem_case()) = regst_desc->mem_case();\n      mem_block.set_enable_reuse_mem(regst_desc->enable_reuse_mem());\n      mem_block.set_mem_size(regst_main_size + mem_block_offset);\n      mem_block.set_thrd_id_hint(thrd_id);\n      if (is_variable_regst) {\n        mem_block.set_variable_op_name(var_name);\n        mem_block.set_is_separated_header(false);\n      }\n      CHECK(mem_block_id2mem_block\n                .emplace(mem_block.mem_block_id(), std::make_unique<MemBlockProto>(mem_block))\n                .second);\n    } else {\n      MemBlockProto* mem_block = mem_block_it->second.get();\n      CHECK_EQ(mem_block->job_id(0), job_id);\n      CHECK_EQ(mem_block->machine_id(), machine_id);\n      CHECK(mem_block->mem_case() == regst_desc->mem_case());\n      CHECK_EQ(mem_block->enable_reuse_mem(), regst_desc->enable_reuse_mem());\n      if (mem_block->enable_reuse_mem()) {\n        mem_block->set_mem_size(\n            std::max(mem_block->mem_size(), regst_main_size + mem_block_offset));\n      } else {\n        CHECK_EQ(mem_block->mem_size(), regst_main_size);\n        CHECK_EQ(mem_block_offset, 0);\n      }\n      if (is_variable_regst) {\n        mem_block->set_variable_op_name(var_name);\n        mem_block->set_is_separated_header(false);\n      }\n    }\n\n    if (regst_separated_size > 0) {\n      CHECK(regst_desc->has_separated_header_mem_block_id()) << regst_desc->DebugString();\n      int64_t separated_mem_block_id = regst_desc->separated_header_mem_block_id();\n      CHECK_NE(separated_mem_block_id, -1);\n      if (mem_block_id2mem_block.find(separated_mem_block_id) == mem_block_id2mem_block.end()) {\n        MemBlockProto mem_block;\n        mem_block.set_mem_block_id(separated_mem_block_id);\n        mem_block.add_job_id(job_id);\n        mem_block.set_machine_id(machine_id);\n        *(mem_block.mutable_mem_case()) = memory::GetPinnedHostMemoryCase(regst_desc->mem_case());\n        mem_block.set_enable_reuse_mem(false);\n        
mem_block.set_mem_size(regst_separated_size);\n        mem_block.set_thrd_id_hint(thrd_id);\n        if (is_variable_regst) {\n          mem_block.set_variable_op_name(var_name);\n          mem_block.set_is_separated_header(true);\n        }\n        CHECK(mem_block_id2mem_block\n                  .emplace(mem_block.mem_block_id(), std::make_unique<MemBlockProto>(mem_block))\n                  .second);\n      } else {\n        MemBlockProto* mem_block = mem_block_id2mem_block.at(separated_mem_block_id).get();\n        CHECK_EQ(mem_block->job_id(0), job_id);\n        CHECK_EQ(mem_block->machine_id(), machine_id);\n        CHECK(mem_block->mem_case() == memory::GetPinnedHostMemoryCase(regst_desc->mem_case()));\n        CHECK_EQ(mem_block->enable_reuse_mem(), false);\n        CHECK_EQ(mem_block->mem_size(), regst_separated_size);\n        if (is_variable_regst) {\n          mem_block->set_variable_op_name(var_name);\n          mem_block->set_is_separated_header(true);\n        }\n      }\n    }\n  };\n\n  for (int i = 0; i < plan->task_size(); i++) {\n    TaskProto* task = plan->mutable_task(i);\n    for (auto& pair : *task->mutable_produced_regst_desc()) {\n      GenMemBlock4RegstIfNeed(&pair.second, task);\n    }\n  }\n\n  GenChunkForMultiNNGraphMemoryReuseInMultiClient(plan, &mem_block_id2mem_block);\n\n  for (const auto& pair : mem_block_id2mem_block) {\n    *(plan->mutable_block_chunk_list()->add_mem_block()) = *(pair.second);\n  }\n}\n\nvoid PlanUtil::CleanUselessMemBlockAndCheckValid(Plan* plan) {\n  HashMap<int64_t, ChunkProto> chunk_id2chunk;\n  HashMap<int64_t, MemBlockProto> mem_block_id2mem_block;\n  for (const auto& chunk : plan->block_chunk_list().chunk()) {\n    CHECK(chunk_id2chunk.emplace(chunk.chunk_id(), chunk).second);\n  }\n  for (const auto& mem_block : plan->block_chunk_list().mem_block()) {\n    CHECK(mem_block_id2mem_block.emplace(mem_block.mem_block_id(), mem_block).second);\n  }\n  plan->mutable_block_chunk_list()->clear_mem_block();\n\n  HashMap<int64_t, HashSet<int64_t>> chunk_id2job_ids;\n  HashMap<int64_t, HashSet<int64_t>> mem_block_id2job_ids;\n  for (const auto& pair : chunk_id2chunk) {\n    for (int64_t job_id : pair.second.job_id()) {\n      CHECK(chunk_id2job_ids[pair.first].insert(job_id).second);\n    }\n  }\n  for (const auto& pair : mem_block_id2mem_block) {\n    for (int64_t job_id : pair.second.job_id()) {\n      CHECK(mem_block_id2job_ids[pair.first].insert(job_id).second);\n    }\n  }\n\n  HashSet<int64_t> valid_mem_block_ids;\n  for (const TaskProto& task : plan->task()) {\n    for (const auto& pair : task.produced_regst_desc()) {\n      const RegstDescProto& regst = pair.second;\n      RtRegstDesc rt_regst(regst);\n      int64_t regst_size = rt_regst.TotalMainByteSize4AllRegst();\n      CHECK(mem_block_id2mem_block.find(regst.mem_block_id()) != mem_block_id2mem_block.end());\n      const MemBlockProto& mem_block = mem_block_id2mem_block.at(regst.mem_block_id());\n      CHECK_GE(mem_block.mem_size(), regst.mem_block_offset() + regst_size);\n      CHECK_EQ(task.machine_id(), mem_block.machine_id());\n      CHECK_EQ(mem_block.enable_reuse_mem(), regst.enable_reuse_mem());\n      CHECK(mem_block.mem_case() == regst.mem_case());\n      const auto& job_ids = mem_block_id2job_ids[regst.mem_block_id()];\n      CHECK(job_ids.find(task.job_id()) != job_ids.end());\n      valid_mem_block_ids.insert(regst.mem_block_id());\n\n      // separated_header\n      int64_t separated_header_mem_size = rt_regst.TotalSeparatedHeaderByteSize4AllRegst();\n      if 
(separated_header_mem_size > 0) {\n        int64_t header_block_id = regst.separated_header_mem_block_id();\n        CHECK_NE(header_block_id, -1);\n        CHECK(mem_block_id2mem_block.find(header_block_id) != mem_block_id2mem_block.end());\n        const MemBlockProto& header_mem_block = mem_block_id2mem_block.at(header_block_id);\n        CHECK_EQ(header_mem_block.mem_size(), separated_header_mem_size);\n        CHECK_EQ(task.machine_id(), header_mem_block.machine_id());\n        CHECK(header_mem_block.mem_case() == memory::GetPinnedHostMemoryCase(regst.mem_case()));\n        CHECK(header_mem_block.enable_reuse_mem() == false);\n        const auto& header_block_job_ids = mem_block_id2job_ids[header_block_id];\n        CHECK(header_block_job_ids.find(task.job_id()) != header_block_job_ids.end());\n        valid_mem_block_ids.insert(regst.separated_header_mem_block_id());\n      }\n    }\n  }\n\n  HashSet<int64_t> useless_mem_block_ids;\n  HashSet<int64_t> valid_chunk_ids;\n  for (const auto& pair : mem_block_id2mem_block) {\n    if (valid_mem_block_ids.find(pair.first) == valid_mem_block_ids.end()) {\n      CHECK(useless_mem_block_ids.insert(pair.first).second);\n      continue;\n    }\n    const MemBlockProto& mem_block = pair.second;\n    if (mem_block.has_chunk_id()) {\n      CHECK(mem_block.has_chunk_offset());\n      CHECK(mem_block.enable_reuse_mem());\n      CHECK(chunk_id2chunk.find(mem_block.chunk_id()) != chunk_id2chunk.end());\n      const ChunkProto& chunk = chunk_id2chunk.at(mem_block.chunk_id());\n      CHECK_GE(chunk.mem_size(), mem_block.chunk_offset() + mem_block.mem_size());\n      CHECK_EQ(mem_block.job_id_size(), 1);\n      CHECK_GE(chunk.job_id_size(), 1);\n      const HashSet<int64_t>& chunk_job_ids = chunk_id2job_ids.at(chunk.chunk_id());\n      CHECK(chunk_job_ids.find(mem_block.job_id(0)) != chunk_job_ids.end());\n      valid_chunk_ids.insert(mem_block.chunk_id());\n    }\n  }\n  CHECK_EQ(valid_chunk_ids.size(), chunk_id2chunk.size());\n\n  for (int64_t useless_block_id : useless_mem_block_ids) {\n    mem_block_id2mem_block.erase(useless_block_id);\n  }\n\n  for (const auto& pair : mem_block_id2mem_block) {\n    *(plan->mutable_block_chunk_list()->add_mem_block()) = pair.second;\n  }\n}\n\nvoid PlanUtil::ToDotFile(const Plan& plan, const std::string& filepath) {\n  const auto& process_ranks = Singleton<ResourceDesc, ForSession>::Get()->process_ranks();\n  size_t gpu_device_num =\n      Singleton<ep::DeviceManagerRegistry>::Get()->GetDeviceCount(DeviceType::kCUDA);\n  std::map<int64_t, std::map<int64_t, std::vector<std::vector<std::string>>>>\n      machine_id2job_id_device_id2node_list;\n  for (size_t i : process_ranks) {\n    for (const auto& pair : plan.job_confs().job_id2job_conf()) {\n      machine_id2job_id_device_id2node_list[i][pair.first].resize(gpu_device_num);\n    }\n  }\n  std::map<int64_t, std::map<int64_t, std::vector<std::string>>> machine_id2job_id2host_node_list;\n  std::vector<std::string> main_node_list;\n  std::vector<std::string> copy_comm_net_node_list;\n  HashSet<int64_t> ctrl_regst_desc_ids;\n  HashMap<int64_t, HashMap<int64_t, std::string>> task_id2consumer_regst_id2name;\n  HashMap<int64_t, std::string> task_id2op_name;\n  HashMap<int64_t, std::vector<int64_t>> task_id2producer_task_ids;\n  std::vector<std::set<int64_t>> machine_id2device_id2node_list_job_ids(process_ranks.size());\n  std::vector<std::set<int64_t>> machine_id2host_node_list_job_ids(process_ranks.size());\n\n  auto InsertNodeDefByTaskProto = [&](const TaskProto& 
task_proto, const std::string& node_def,\n                                      const std::string& pass_tag) {\n    if (task_proto.task_type() == TaskType::kCopyCommNet) {\n      copy_comm_net_node_list.emplace_back(node_def);\n      return;\n    }\n    if (pass_tag == kNoPassTag) {\n      const StreamId stream_id = PlanUtil::GetStreamId(task_proto);\n      if (stream_id.device_id().device_type() == DeviceType::kCUDA) {\n        machine_id2job_id_device_id2node_list[task_proto.machine_id()][task_proto.job_id()]\n                                             [stream_id.device_id().device_index()]\n                                                 .emplace_back(node_def);\n        machine_id2device_id2node_list_job_ids[task_proto.machine_id()].insert(task_proto.job_id());\n      } else {\n        machine_id2job_id2host_node_list[task_proto.machine_id()][task_proto.job_id()].emplace_back(\n            node_def);\n        machine_id2host_node_list_job_ids[task_proto.machine_id()].insert(task_proto.job_id());\n      }\n    } else if (pass_tag == kMainOp) {\n      main_node_list.emplace_back(node_def);\n    } else {\n      UNIMPLEMENTED();\n    }\n  };\n\n  auto GenEdgeColorStr = [](const RegstDescTypeProto& type) {\n    if (type.has_ctrl_regst_desc()) { return \"fontcolor=\\\"gray65\\\",color=\\\"gray65\\\"\"; }\n    return \"fontcolor=\\\"gray15\\\",color=\\\"gray15\\\"\";\n  };\n\n  auto IsEsac2ReentrantLockEdge = [](const std::string& src_name, const std::string& dst_name) {\n    if (src_name.find(\"Esac\") != std::string::npos\n        && dst_name.find(\"ReentrantLock\") != std::string::npos) {\n      return true;\n    }\n    return false;\n  };\n\n  auto IsEsacNode = [](const std::string& name) {\n    if (name.find(\"Esac\") != std::string::npos) { return true; }\n    return false;\n  };\n\n  auto log_stream = TeePersistentLogStream::Create(filepath);\n  // task node\n  for (const TaskProto& task_proto : plan.task()) {\n    for (const auto& pair : task_proto.produced_regst_desc()) {\n      const RegstDescProto& regst = pair.second;\n      for (int64_t consumer_task_id : regst.consumer_task_id()) {\n        task_id2producer_task_ids[consumer_task_id].emplace_back(task_proto.task_id());\n      }\n    }\n  }\n\n  for (const TaskProto& task_proto : plan.task()) {\n    std::string task_id_str = \"task\" + std::to_string(task_proto.task_id());\n    std::string task_class = task_id_str;\n    for (const auto& in_task_id : task_id2producer_task_ids[task_proto.task_id()]) {\n      task_class += \" in\" + std::to_string(in_task_id);\n    }\n    for (const auto& pair : task_proto.produced_regst_desc()) {\n      const RegstDescProto& regst = pair.second;\n      for (int64_t consumer_task_id : regst.consumer_task_id()) {\n        task_class += \" out\" + std::to_string(consumer_task_id);\n      }\n    }\n    task_class += \" job_id\" + std::to_string(task_proto.job_id());\n    task_class += \" machine_id\" + std::to_string(task_proto.machine_id());\n    std::string node_def = task_id_str + \"[class=\\\"\" + task_class + \"\\\",label=\\\"{{\";\n    node_def += std::to_string(task_proto.task_id()) + \":\" + std::to_string(task_proto.machine_id())\n                + \"\\\\n\";\n    std::string op_name = \"\";\n    std::string pass_tag = kNoPassTag;\n    for (const ExecNodeProto& exec_node : task_proto.exec_sequence().exec_node()) {\n      const auto& op_conf =\n          GetOpAttribute(&plan, task_proto.job_id(), exec_node.kernel_conf()).op_conf();\n      op_name += op_conf.name();\n      if 
(op_conf.has_pass_tag()) { pass_tag = op_conf.pass_tag(); }\n    }\n    task_id2op_name[task_proto.task_id()] = op_name;\n    node_def += op_name;\n    size_t index = 0;\n    for (const auto& pair : task_proto.produced_regst_desc()) {\n      std::string regst_id = std::to_string(pair.second.regst_desc_id());\n      if (index % 2 == 0) {\n        node_def += \"}|{\";\n      } else {\n        node_def += \"|\";\n      }\n      // node_def += \"<regst_desc_\" + regst_id + \">\";\n      node_def += (pair.first + \":\" + regst_id + \":\" + std::to_string(pair.second.register_num()));\n      ++index;\n    }\n    node_def += \"}}\";\n    node_def +=\n        (\"\\\",tooltip=\\\"\" + TaskType_Name(task_proto.task_type()) + \"  \"\n         + std::to_string(task_proto.task_id()) + \"-\" + std::to_string(task_proto.machine_id())\n         + \":\" + std::to_string(task_proto.thrd_id()) + \":\"\n         + std::to_string(task_proto.parallel_ctx().parallel_id())\n         + \"\\\", shape=record, style=\\\"rounded,filled\\\"\"\n         + \",colorscheme=set312, fillcolor=\" + std::to_string((task_proto.job_id() % 12) + 1));\n    if (IsEsacNode(op_name)) { node_def += \",width=5,height=1.5\"; }\n    node_def += \"];\\n\";\n    InsertNodeDefByTaskProto(task_proto, node_def, pass_tag);\n    for (const auto& pair : task_proto.consumed_regst_desc_id()) {\n      for (int64_t regst_desc_id : pair.second.regst_desc_id()) {\n        task_id2consumer_regst_id2name[task_proto.task_id()][regst_desc_id] = pair.first;\n      }\n    }\n  }\n\n  log_stream << \"digraph merged_plan_graph {\\n\";\n  log_stream << \"#splines=\\\"ortho\\\";\\n\";\n  log_stream << \"#rankdir=TB;\\n\";\n  log_stream << \"#nodesep=1.3;\\n\";\n  log_stream << \"#ranksep=1.3;\\n\";\n  log_stream << \"node[color=\\\"gray\\\"];\\n\";\n  // main_node and copy_comm_net_node graph\n  for (const std::string& main_node : main_node_list) { log_stream << main_node; }\n  for (const std::string& copy_comm_net_node : copy_comm_net_node_list) {\n    log_stream << copy_comm_net_node;\n  }\n  // sub graph\n  for (size_t machine_id : process_ranks) {\n    std::string machine_name = \"machine_\" + std::to_string(machine_id);\n    log_stream << \"subgraph cluster_\" << machine_name << \" { label = \\\"\" << machine_name << \"\\\";\\n\";\n    log_stream << \"style=\\\"rounded\\\";\\n\";\n    {\n      for (const auto& job_id : machine_id2host_node_list_job_ids[machine_id]) {\n        std::string job_name = plan.job_confs().job_id2job_conf().at(job_id).job_name();\n        job_name += (std::string(\":\") + std::to_string(job_id));\n        if (job_id != plan.job_confs().job_id2job_conf().size() - 1) {\n          log_stream << \"subgraph cluster_job_\" << std::to_string(job_id) << \" { label = \\\"\"\n                     << job_name << \"\\\";\\n\";\n          log_stream << \"style=\\\"rounded\\\";\\n\";\n        }\n        for (const std::string& host_node_def :\n             machine_id2job_id2host_node_list[machine_id][job_id]) {\n          log_stream << host_node_def;\n        }\n        if (machine_id2device_id2node_list_job_ids[machine_id].find(job_id)\n            != machine_id2device_id2node_list_job_ids[machine_id].end()) {\n          for (size_t device_id = 0; device_id < gpu_device_num; ++device_id) {\n            std::string device_name = machine_name + \"_device_\" + std::to_string(device_id);\n            log_stream << \"#subgraph cluster_\" << device_name << \" { label = \\\"\" << device_name\n                       << \"\\\";\\n\";\n            
log_stream << \"#color=\\\"skyblue\\\";\\n\";\n            log_stream << \"#fillcolor=\\\"azure\\\";\\n\";\n            log_stream << \"#style=\\\"rounded,filled\\\";\\n\";\n            for (const auto& device_node_def :\n                 machine_id2job_id_device_id2node_list[machine_id][job_id][device_id]) {\n              log_stream << device_node_def;\n            }\n            log_stream << \"#}\\n\";\n          }\n          machine_id2device_id2node_list_job_ids[machine_id].erase(job_id);\n        }\n\n        if (job_id != plan.job_confs().job_id2job_conf().size() - 1) { log_stream << \"}\\n\"; }\n      }\n      for (const auto& job_id : machine_id2device_id2node_list_job_ids[machine_id]) {\n        std::string job_name = plan.job_confs().job_id2job_conf().at(job_id).job_name();\n        job_name += (std::string(\":\") + std::to_string(job_id));\n        if (job_id != plan.job_confs().job_id2job_conf().size() - 1) {\n          log_stream << \"subgraph cluster_job_\" << std::to_string(job_id) << \" { label = \\\"\"\n                     << job_name << \"\\\";\\n\";\n          log_stream << \"style=\\\"rounded\\\";\\n\";\n        }\n        for (size_t device_id = 0; device_id < gpu_device_num; ++device_id) {\n          std::string device_name = machine_name + \"_device_\" + std::to_string(device_id);\n          log_stream << \"#subgraph cluster_\" << device_name << \" { label = \\\"\" << device_name\n                     << \"\\\";\\n\";\n          log_stream << \"#color=\\\"skyblue\\\";\\n\";\n          log_stream << \"#fillcolor=\\\"azure\\\";\\n\";\n          log_stream << \"#style=\\\"rounded,filled\\\";\\n\";\n          for (const auto& device_node_def :\n               machine_id2job_id_device_id2node_list[machine_id][job_id][device_id]) {\n            log_stream << device_node_def;\n          }\n          log_stream << \"#}\\n\";\n        }\n        if (job_id != plan.job_confs().job_id2job_conf().size() - 1) { log_stream << \"}\\n\"; }\n      }\n    }\n    log_stream << \"}\\n\";\n  }\n\n  // produce/consume edge\n  for (const TaskProto& task_proto : plan.task()) {\n    for (const auto& pair : task_proto.produced_regst_desc()) {\n      const RegstDescProto& regst = pair.second;\n      std::string src_node = \"task\" + std::to_string(task_proto.task_id());\n      // src_node += \":regst_desc_\" + std::to_string(regst.regst_desc_id());\n      for (int64_t consumer_task_id : regst.consumer_task_id()) {\n        std::string dst_node = \"task\" + std::to_string(consumer_task_id);\n        // dst_node +=  \":task_node_\" + std::to_string(consumer_task_id);\n        std::string consumer_regst_name =\n            task_id2consumer_regst_id2name[consumer_task_id][regst.regst_desc_id()];\n        std::string consumer_op_name = task_id2op_name[consumer_task_id];\n        std::string producer_regst_name = pair.first;\n        std::string producer_op_name = task_id2op_name[task_proto.task_id()];\n        std::string tooltip = producer_op_name + \" : \" + producer_regst_name + \" -> \"\n                              + consumer_op_name + \" : \" + consumer_regst_name;\n        if (IsEsac2ReentrantLockEdge(producer_op_name, consumer_op_name)) {\n          log_stream << dst_node << \"->\" << src_node\n                     << \"[arrowhead=\\\"invempty\\\",fontcolor=\\\"red\\\",color=\\\"red\\\",taillabel=\\\"\"\n                     << consumer_regst_name << \"\\\",tailtooltip=\\\"\" << tooltip;\n        } else {\n          log_stream << src_node << \"->\" << dst_node << \"[\"\n                 
    << GenEdgeColorStr(regst.regst_desc_type()) << \",headlabel=\\\"\"\n                     << consumer_regst_name << \"\\\",headtooltip=\\\"\" << tooltip;\n        }\n        log_stream << \"\\\",tooltip=\\\"\" << tooltip << \"\\\",arrowsize=0.5,labeldistance=1.5,penwidth=2\"\n                   << \"];\\n\";\n      }\n    }\n  }\n  log_stream << \"}\\n\";\n}\n\nstd::function<RegstDescProto*(int64_t)> PlanUtil::MakeMutRegstDesc4Id(Plan* plan) {\n  auto regst_desc_id2regst_desc = std::make_shared<HashMap<int64_t, RegstDescProto*>>();\n  for (int i = 0; i < plan->task_size(); i++) {\n    TaskProto* task = plan->mutable_task(i);\n    for (auto& pair : *task->mutable_produced_regst_desc()) {\n      int64_t regst_desc_id = pair.second.regst_desc_id();\n      CHECK(regst_desc_id2regst_desc->insert({regst_desc_id, &pair.second}).second)\n          << \"regst_desc_id2regst_desc got a duplicated regst_desc_id \" << regst_desc_id;\n    }\n  }\n  return [regst_desc_id2regst_desc](int64_t regst_desc_id) -> RegstDescProto* {\n    auto iter = regst_desc_id2regst_desc->find(regst_desc_id);\n    CHECK(iter != regst_desc_id2regst_desc->end())\n        << \"regst_desc_id \" << regst_desc_id << \" can't be found in plan.\";\n    return iter->second;\n  };\n}\n\nvoid PlanUtil::SetForceInplaceMemBlock(Plan* plan, int64_t limited_rank) {\n  auto RegstDesc4Id = MakeMutRegstDesc4Id(plan);\n  for (int i = 0; i < plan->task_size(); i++) {\n    TaskProto* task = plan->mutable_task(i);\n    // When doing separation compilation, some rank's plan (such as rank 0) contains other ranks'\n    // task nodes for compilation. There is no need to set mem blocks for other ranks' task nodes.\n    if (limited_rank >= 0 && task->machine_id() != limited_rank) { continue; }\n    for (auto& pair : *task->mutable_produced_regst_desc()) {\n      RegstDescProto* regst_desc = &pair.second;\n      if (regst_desc->has_force_inplace_consumed_regst_desc_id()) {\n        int64_t force_id = regst_desc->force_inplace_consumed_regst_desc_id();\n        const RegstDescProto* in_regst_desc = RegstDesc4Id(force_id);\n        CHECK(!in_regst_desc->enable_reuse_mem());\n        CHECK(!regst_desc->enable_reuse_mem());\n        CHECK_NE(in_regst_desc->mem_block_id(), -1);\n        CHECK_EQ(in_regst_desc->mem_block_offset(), 0);\n        CHECK_EQ(regst_desc->mem_block_offset(), 0);\n        CHECK_EQ(in_regst_desc->register_num(), regst_desc->register_num());\n        CHECK(in_regst_desc->mem_case() == regst_desc->mem_case());\n        RtRegstDesc in_regst_rt(*in_regst_desc);\n        RtRegstDesc regst_rt(*regst_desc);\n        CHECK_EQ(in_regst_rt.TotalByteSize4AllRegst(), regst_rt.TotalByteSize4AllRegst());\n        CHECK_EQ(in_regst_rt.TotalMainByteSize4AllRegst(), regst_rt.TotalMainByteSize4AllRegst());\n        CHECK_EQ(in_regst_rt.TotalSeparatedHeaderByteSize4AllRegst(),\n                 regst_rt.TotalSeparatedHeaderByteSize4AllRegst());\n        regst_desc->set_mem_block_id(in_regst_desc->mem_block_id());\n        regst_desc->set_inplace_consumed_regst_desc_id(force_id);\n        if (in_regst_desc->has_separated_header_mem_block_id()) {\n          CHECK(regst_desc->has_separated_header_mem_block_id());\n          regst_desc->set_separated_header_mem_block_id(\n              in_regst_desc->separated_header_mem_block_id());\n        }\n        VLOG(3) << \" set force inplace from \" << regst_desc->DebugString() << \" to \"\n                << in_regst_desc->DebugString();\n      }\n    }\n  }\n}\n\nvoid PlanUtil::DumpCtrlRegstInfoToPlan(Plan* plan) {\n  
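// Record the producer task id of every ctrl regst so it can be looked up from the plan.\n  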
auto* ctrl_regst_desc_id2producer_task_id =\n      plan->mutable_ctrl_regst_desc_info()->mutable_ctrl_regst_desc_id2producer_task_id();\n  for (const TaskProto& task : plan->task()) {\n    for (const auto& pair : task.produced_regst_desc()) {\n      if (pair.second.regst_desc_type().has_ctrl_regst_desc()) {\n        ctrl_regst_desc_id2producer_task_id->insert(\n            {pair.second.regst_desc_id(), pair.second.producer_task_id()});\n      }\n    }\n  }\n}\n\nnamespace {\n\nbool IsCollectiveBoxingTaskType(TaskType task_type) {\n  return task_type == TaskType::kCollectiveBoxingGeneric;\n}\n\nbool IsCollectiveBoxingNode(const PlanTaskNode* node) {\n  const TaskType task_type = node->task_proto()->task_type();\n  return IsCollectiveBoxingTaskType(task_type);\n}\n\nconst boxing::collective::RankDesc& GetRankDesc(const OperatorConf& conf) {\n  if (conf.has_collective_boxing_generic_conf()) {\n    return conf.collective_boxing_generic_conf().rank_desc();\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nconst boxing::collective::RankDesc& GetRankDesc(Plan* plan, const TaskProto& task_proto) {\n  CHECK_EQ(task_proto.exec_sequence().exec_node_size(), 1);\n  return GetRankDesc(PlanUtil::GetOpAttribute(plan, task_proto.job_id(),\n                                              task_proto.exec_sequence().exec_node(0).kernel_conf())\n                         .op_conf());\n}\n\nstruct CollectiveBoxingRequestInfo {\n  boxing::collective::OpDesc op_desc;\n  std::map<int64_t, const PlanTaskNode*> rank2node;\n  int64_t order;\n  int64_t dependency_depth;\n};\n\nvoid GetDeviceDesc(const TaskProto* task_proto, boxing::collective::DeviceDesc* device_desc) {\n  device_desc->set_machine_id(task_proto->machine_id());\n  const StreamId stream_id = PlanUtil::GetStreamId(*task_proto);\n  const DeviceId& device_id = stream_id.device_id();\n  device_desc->set_device_type(device_id.device_type());\n  device_desc->set_device_id(device_id.device_index());\n}\n\n}  // namespace\n\nvoid PlanUtil::GenCollectiveBoxingPlan(Job* job, Plan* plan) {\n  using namespace boxing::collective;\n\n  RequestSet* request_set = &(*plan->mutable_collective_boxing_plan()\n                                   ->mutable_job_id2request_set())[GlobalJobDesc().job_id()];\n  const int64_t cb_task_count = std::count_if(\n      plan->task().cbegin(), plan->task().cend(),\n      [](const TaskProto& task) { return IsCollectiveBoxingTaskType(task.task_type()); });\n  if (cb_task_count == 0) { return; }\n\n  PlanTaskGraph plan_task_graph(*plan);\n  int64_t dependency_depth = 0;\n  int64_t order = 0;\n  HashSet<const PlanTaskNode*> all_visited;\n  while (true) {\n    std::list<const PlanTaskNode*> src_nodes;\n    plan_task_graph.ForEachNode([&](const PlanTaskNode* node) {\n      if (all_visited.count(node) != 0) { return; }\n      int64_t in_cnt = 0;\n      node->ForEachNodeOnInEdge([&](const PlanTaskNode* node_on_in_edge) {\n        if (all_visited.count(node_on_in_edge) != 0) { return; }\n        in_cnt += 1;\n      });\n      if (in_cnt == 0) { src_nodes.emplace_back(node); }\n    });\n    if (src_nodes.empty()) { break; }\n    auto ForEachNodeOnInEdge = [&](const PlanTaskNode* node,\n                                   const std::function<void(const PlanTaskNode*)>& Handler) {\n      node->ForEachNodeOnInEdge([&](const PlanTaskNode* node_on_in_edge) {\n        if (all_visited.count(node_on_in_edge) == 0) { Handler(node_on_in_edge); }\n      });\n    };\n    auto ForEachNodeOnOutEdge = [&](const PlanTaskNode* node,\n                                    
const std::function<void(const PlanTaskNode*)>& Handler) {\n      if (!IsCollectiveBoxingNode(node)) {\n        node->ForEachNodeOnOutEdge([&](const PlanTaskNode* node_on_out_edge) {\n          bool has_unvisited_collective_boxing_node_on_in_edges = false;\n          node_on_out_edge->ForEachNodeOnInEdge([&](const PlanTaskNode* node_on_in_edge) {\n            if (!has_unvisited_collective_boxing_node_on_in_edges\n                && IsCollectiveBoxingNode(node_on_in_edge)\n                && all_visited.count(node_on_in_edge) == 0) {\n              has_unvisited_collective_boxing_node_on_in_edges = true;\n            }\n          });\n          if (!has_unvisited_collective_boxing_node_on_in_edges) { Handler(node_on_out_edge); }\n        });\n      }\n    };\n    HashSet<const PlanTaskNode*> visited;\n    std::vector<const PlanTaskNode*> collective_boxing_nodes;\n    plan_task_graph.TopoForEachNode(src_nodes, ForEachNodeOnInEdge, ForEachNodeOnOutEdge,\n                                    [&](const PlanTaskNode* node) {\n                                      visited.insert(node);\n                                      if (IsCollectiveBoxingNode(node)) {\n                                        collective_boxing_nodes.emplace_back(node);\n                                      }\n                                    });\n    if (collective_boxing_nodes.empty()) { break; }\n    HashMap<std::string, CollectiveBoxingRequestInfo> name2request_info;\n    for (const PlanTaskNode* node : collective_boxing_nodes) {\n      const TaskProto* task_proto = node->task_proto();\n      const RankDesc& rank_desc = GetRankDesc(plan, *task_proto);\n      CHECK_GE(rank_desc.rank(), 0);\n      CHECK_LT(rank_desc.rank(), rank_desc.op_desc().num_ranks());\n      const std::string& name = rank_desc.op_desc().name();\n      boxing::collective::DeviceDesc device_desc;\n      GetDeviceDesc(task_proto, &device_desc);\n      auto it = name2request_info.find(name);\n      if (it == name2request_info.end()) {\n        CollectiveBoxingRequestInfo request_info{\n            .op_desc = rank_desc.op_desc(),\n            .rank2node = {std::make_pair(rank_desc.rank(), node)},\n            .order = order,\n            .dependency_depth = dependency_depth,\n        };\n        name2request_info.emplace(std::make_pair(name, std::move(request_info)));\n        order += 1;\n      } else {\n        CHECK(it->second.op_desc == rank_desc.op_desc());\n        CHECK(it->second.rank2node.emplace(std::make_pair(rank_desc.rank(), node)).second);\n      }\n    }\n    int64_t collected = 0;\n    for (const auto& name7request_info : name2request_info) {\n      const CollectiveBoxingRequestInfo& info = name7request_info.second;\n      if (info.rank2node.size() == info.op_desc.num_ranks()) {\n        collected += 1;\n        boxing::collective::RequestDesc* request_desc = request_set->mutable_request()->Add();\n        *request_desc->mutable_op_desc() = info.op_desc;\n        for (int64_t i = 0; i < info.op_desc.num_ranks(); ++i) {\n          GetDeviceDesc(info.rank2node.at(i)->task_proto(),\n                        request_desc->mutable_device_set()->mutable_device()->Add());\n        }\n        request_desc->set_order(info.order);\n        request_desc->set_dependency_depth(info.dependency_depth);\n      } else {\n        CHECK_LT(info.rank2node.size(), info.op_desc.num_ranks());\n        for (const auto& pair : info.rank2node) { visited.erase(pair.second); }\n      }\n    }\n    CHECK_GT(collected, 0);\n    all_visited.insert(visited.begin(), 
visited.end());\n    ++dependency_depth;\n  }\n}\n\nvoid PlanUtil::GenRegisterHint(Plan* plan) {\n  HashSet<int64_t> multi_regst_regst_desc_ids;\n  for (const TaskProto& task : plan->task()) {\n    for (const auto& pair : task.produced_regst_desc()) {\n      if (pair.second.register_num() != 1 || task.task_type() == TaskType::kRepeat) {\n        multi_regst_regst_desc_ids.emplace(pair.second.regst_desc_id());\n      }\n    }\n  }\n  for (TaskProto& task : *(plan->mutable_task())) {\n    bool all_register_num_eq_one = true;\n    for (const auto& pair : task.produced_regst_desc()) {\n      if (pair.second.register_num() != 1) {\n        all_register_num_eq_one = false;\n        break;\n      }\n    }\n    for (const auto& pair : task.consumed_regst_desc_id()) {\n      if (!all_register_num_eq_one) { break; }\n      for (auto regst_desc_id : pair.second.regst_desc_id()) {\n        if (multi_regst_regst_desc_ids.count(regst_desc_id) > 0) {\n          all_register_num_eq_one = false;\n          break;\n        }\n      }\n    }\n    task.set_all_register_num_eq_one_hint(all_register_num_eq_one);\n  }\n}\n\nnamespace {\n\nstruct MemBlockMemoryInfo {\n  int64_t mem_block_id;\n  int64_t mem_block_mem_size;\n  int64_t regst_num;\n  std::vector<int64_t> ordered_regst_desc_id;\n  MemBlockMemoryInfo() : mem_block_id(-1), mem_block_mem_size(-1), regst_num(-1) {}\n};\n\nstruct ChunkMemoryInfo {\n  int64_t chunk_id;\n  int64_t chunk_mem_size;\n  std::vector<int64_t> mem_block_ids;\n  ChunkMemoryInfo() : chunk_id(-1), chunk_mem_size(-1) {}\n};\n\nstruct RankDeviceMemoryInfo {\n  int64_t rank_id;\n  int64_t device_id;\n  ChunkMemoryInfo chunk_info;\n  int64_t total_mem_size;\n  int64_t not_reused_mem_size;\n  std::vector<int64_t> not_reused_mem_block_ids;\n  int64_t eager_variable_total_mem_size;\n  std::vector<int64_t> eager_variable_mem_block_ids;\n  RankDeviceMemoryInfo()\n      : rank_id(-1),\n        device_id(-1),\n        total_mem_size(0),\n        not_reused_mem_size(0),\n        eager_variable_total_mem_size(0) {}\n};\n\n}  // namespace\n\nvoid PlanUtil::PlanMemoryLog(Plan* plan, const std::string& plan_name) {\n  std::vector<RankDeviceMemoryInfo> rank_device_memory_infos(GlobalProcessCtx::WorldSize(),\n                                                             RankDeviceMemoryInfo());\n  HashMap<int64_t, MemBlockMemoryInfo> mem_block_id2info;\n  HashMap<int64_t, const RegstDescProto*> regst_desc_id2regst;\n\n  for (const ChunkProto& chunk : plan->block_chunk_list().chunk()) {\n    int64_t rank_id = chunk.machine_id();\n    auto& info = rank_device_memory_infos[rank_id];\n    info.rank_id = rank_id;\n    if (!memory::IsHostMem(chunk.mem_case())) { info.device_id = chunk.mem_case().device_id(); }\n    info.total_mem_size += chunk.mem_size();\n    info.chunk_info.chunk_id = chunk.chunk_id();\n    info.chunk_info.chunk_mem_size = chunk.mem_size();\n  }\n\n  for (const MemBlockProto& mem_block : plan->block_chunk_list().mem_block()) {\n    int64_t mem_block_id = mem_block.mem_block_id();\n    mem_block_id2info.emplace(mem_block_id, MemBlockMemoryInfo());\n    auto& info = mem_block_id2info.at(mem_block_id);\n    info.mem_block_id = mem_block_id;\n    info.mem_block_mem_size = mem_block.mem_size();\n    auto& rank_memory_info = rank_device_memory_infos.at(mem_block.machine_id());\n    if (!memory::IsHostMem(mem_block.mem_case())) {\n      if (mem_block.has_chunk_id()) {\n        rank_memory_info.chunk_info.mem_block_ids.push_back(mem_block_id);\n      } else {\n        if 
(mem_block.has_variable_op_name()) {\n          rank_memory_info.eager_variable_mem_block_ids.push_back(mem_block_id);\n          rank_memory_info.eager_variable_total_mem_size += mem_block.mem_size();\n        } else {\n          rank_memory_info.not_reused_mem_block_ids.push_back(mem_block_id);\n          rank_memory_info.not_reused_mem_size += mem_block.mem_size();\n        }\n        rank_memory_info.total_mem_size += mem_block.mem_size();\n      }\n    }\n  }\n\n  for (const auto& task : plan->task()) {\n    for (const auto& pair : task.produced_regst_desc()) {\n      const auto& regst = pair.second;\n      if (regst.regst_desc_type().has_data_regst_desc()\n          && mem_block_id2info.find(regst.mem_block_id()) != mem_block_id2info.end()) {\n        mem_block_id2info.at(regst.mem_block_id())\n            .ordered_regst_desc_id.push_back(regst.regst_desc_id());\n        regst_desc_id2regst.emplace(regst.regst_desc_id(), &regst);\n      }\n    }\n  }\n\n  auto CompMemBlock = [&](int64_t a, int64_t b) {\n    return mem_block_id2info[a].mem_block_mem_size > mem_block_id2info[b].mem_block_mem_size;\n  };\n\n  // Convert bytes to MiB; the log labels below say MiB, so divide by 1024 * 1024.\n  auto B2MiB = [](int64_t val) { return val * 1.0 / (1024.0 * 1024.0); };\n\n  for (auto& rank_memory_info : rank_device_memory_infos) {\n    std::sort(rank_memory_info.chunk_info.mem_block_ids.begin(),\n              rank_memory_info.chunk_info.mem_block_ids.end(), CompMemBlock);\n    std::sort(rank_memory_info.not_reused_mem_block_ids.begin(),\n              rank_memory_info.not_reused_mem_block_ids.end(), CompMemBlock);\n    std::sort(rank_memory_info.eager_variable_mem_block_ids.begin(),\n              rank_memory_info.eager_variable_mem_block_ids.end(), CompMemBlock);\n    LOG(INFO) << \"\\n Graph name \" << plan_name << \" in Rank: \" << rank_memory_info.rank_id\n              << \", Device: \" << rank_memory_info.device_id << \" needs to allocate [ \"\n              << B2MiB(rank_memory_info.total_mem_size)\n              << \" MiB ] device memory. 
\\n   In general, Chunk id: \"\n              << rank_memory_info.chunk_info.chunk_id << \" memory is [ \"\n              << B2MiB(rank_memory_info.chunk_info.chunk_mem_size)\n              << \" MiB ] with mem_block_num = \" << rank_memory_info.chunk_info.mem_block_ids.size()\n              << \"\\n        Unreused memory (excluding eager variables) is [ \"\n              << B2MiB(rank_memory_info.not_reused_mem_size)\n              << \" MiB ] with mem_block_num = \" << rank_memory_info.not_reused_mem_block_ids.size()\n              << \"\\n        Eager Variable Tensor total memory is [ \"\n              << B2MiB(rank_memory_info.eager_variable_total_mem_size)\n              << \" MiB ] with mem_block_num = \"\n              << rank_memory_info.eager_variable_mem_block_ids.size() << \"\\n\";\n  }\n\n  auto Vlog3ForMemBlockDetails = [&](int64_t device_id, const std::vector<int64_t>& mem_block_ids,\n                                     const std::string& prefix) {\n    for (int64_t mem_block_id : mem_block_ids) {\n      CHECK(mem_block_id2info.find(mem_block_id) != mem_block_id2info.end());\n      const auto& mem_block_info = mem_block_id2info.at(mem_block_id);\n      if (mem_block_info.ordered_regst_desc_id.size() != 1) { continue; }\n      const auto* regst = regst_desc_id2regst.at(mem_block_info.ordered_regst_desc_id.at(0));\n      const auto& data_regst = regst->regst_desc_type().data_regst_desc();\n      const auto& lbi2blob_desc_pair = data_regst.lbi2blob_desc(0);\n      std::string tensor_name = GenLogicalBlobName(lbi2blob_desc_pair.lbi());\n      const auto& blob_desc = lbi2blob_desc_pair.blob_desc();\n      VLOG(3) << \"In Device: \" << device_id << \" Memblock id: \" << mem_block_id << prefix\n              << \" size: \" << B2MiB(mem_block_info.mem_block_mem_size)\n              << \" MiB, name: \" << tensor_name << \"\\nshape: \" << Shape(blob_desc.shape()).ToString()\n              << \", dtype: \" << DataType_Name(blob_desc.data_type());\n    }\n  };\n\n  for (const auto& rank_memory_info : rank_device_memory_infos) {\n    int64_t chunk_id = rank_memory_info.chunk_info.chunk_id;\n    int64_t device_id = rank_memory_info.device_id;\n    VLOG(2) << \"========================= \"\n            << \"In Device: \" << device_id << \" Chunk Memory info details:\";\n    for (int64_t mem_block_id : rank_memory_info.chunk_info.mem_block_ids) {\n      CHECK(mem_block_id2info.find(mem_block_id) != mem_block_id2info.end());\n      const auto& mem_block_info = mem_block_id2info.at(mem_block_id);\n      VLOG(2) << \"     In Device: \" << device_id << \" Chunk id: \" << chunk_id\n              << \" MemBlock id: \" << mem_block_id\n              << \" has num = \" << mem_block_info.ordered_regst_desc_id.size()\n              << \" tensor with mem size = \" << B2MiB(mem_block_info.mem_block_mem_size);\n      for (int64_t i = 0; i < mem_block_info.ordered_regst_desc_id.size(); ++i) {\n        const auto* regst = regst_desc_id2regst.at(mem_block_info.ordered_regst_desc_id.at(i));\n        const auto& data_regst = regst->regst_desc_type().data_regst_desc();\n        const auto& lbi2blob_desc_pair = data_regst.lbi2blob_desc(0);\n        std::string tensor_name = GenLogicalBlobName(lbi2blob_desc_pair.lbi());\n        const auto& blob_desc = lbi2blob_desc_pair.blob_desc();\n        std::string alloc_order = \"inplaced\";\n        if (regst->has_alloc_before_actor()) {\n          alloc_order = std::to_string(regst->alloc_before_actor());\n        }\n        std::string free_order = \"inplaced\";\n        if 
(regst->has_free_after_actor()) {\n          free_order = std::to_string(regst->free_after_actor());\n        }\n        VLOG(3) << \"In Chunk id: \" << chunk_id << \", MemBlock id: \" << mem_block_id\n                << \" Order: \" << i\n                << \", duration: \" << (regst->free_after_actor() - regst->alloc_before_actor() + 1)\n                << \", size: \" << B2MiB(BlobDesc(blob_desc).AlignedTotalByteSize())\n                << \" MiB, name: \" << tensor_name\n                << \"\\nshape: \" << Shape(blob_desc.shape()).ToString()\n                << \", dtype: \" << DataType_Name(blob_desc.data_type())\n                << \", alloc_order: \" << alloc_order << \", free_order: \" << free_order;\n      }\n    }\n\n    Vlog3ForMemBlockDetails(device_id, rank_memory_info.not_reused_mem_block_ids, \" Unreused \");\n    Vlog3ForMemBlockDetails(device_id, rank_memory_info.eager_variable_mem_block_ids,\n                            \" EagerVariable \");\n  }\n}\n\nvoid PlanUtil::GenLightPlan(Plan* plan, const std::string& plan_name, int64_t limited_rank) {\n  // NOTE(chengcheng): ordered_tasks is NOT exec order, just task id order.\n  std::vector<const TaskProto*> ordered_tasks;\n  for (const TaskProto& task : plan->task()) { ordered_tasks.push_back(&task); }\n  auto CompTask = [](const TaskProto* a, const TaskProto* b) {\n    return a->task_id() < b->task_id();\n  };\n  std::sort(ordered_tasks.begin(), ordered_tasks.end(), CompTask);\n\n  HashMap<int64_t, std::string> task_id2name;\n  HashMap<int64_t, const TaskProto*> task_id2proto;\n  HashMap<int64_t, std::string> regst_id2name;\n  HashMap<int64_t, const RegstDescProto&> regst_id2proto;\n  for (const auto* task : ordered_tasks) {\n    const auto& exec_seq = task->exec_sequence();\n    std::string name;\n    if (exec_seq.exec_node_size() >= 1) {\n      const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf();\n      if (kernel_conf.has_op_attribute_ref()) {\n        name = kernel_conf.op_attribute_ref();\n      } else {\n        name = kernel_conf.op_attribute().op_conf().name();\n      }\n    } else {\n      name = TaskType_Name(task->task_type());\n    }\n    task_id2name.emplace(task->task_id(), name);\n    task_id2proto.emplace(task->task_id(), task);\n    CHECK(!name.empty());\n    for (const auto& pair : task->produced_regst_desc()) {\n      std::string regst_name = name + \"/\" + pair.first;\n      regst_id2name.emplace(pair.second.regst_desc_id(), regst_name);\n      regst_id2proto.emplace(pair.second.regst_desc_id(), pair.second);\n    }\n  }\n\n  auto RegstId2TensorStr = [&](int64_t regst_id) -> std::string {\n    CHECK(regst_id2proto.find(regst_id) != regst_id2proto.end())\n        << \" regst_id2proto cannot find: \" << regst_id;\n    std::ostringstream ss;\n    ss << \"{\";\n    const RegstDescProto& regst = regst_id2proto.at(regst_id);\n    ss << \"regst_num: \" << std::to_string(regst.register_num());\n    ss << \", device: \" << *CHECK_JUST(DeviceTag4DeviceType(regst.mem_case().device_type()));\n    if (regst.regst_desc_type().has_data_regst_desc()) {\n      const DataRegstDesc& data = regst.regst_desc_type().data_regst_desc();\n      ss << \", time_shape: \" << Shape(data.time_shape()).ToString();\n      const BlobDescProto& blob = data.lbi2blob_desc(0).blob_desc();\n      ss << \", shape: \" << Shape(blob.shape()).ToString();\n      ss << \", dtype: \" << DataType_Name(blob.data_type());\n    } else {\n      ss << \", ctrl\";\n    }\n    ss << \"}\";\n    return ss.str();\n  };\n  
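// Bucket tasks by rank so each rank's light plan is written to its own file.\n  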
std::vector<std::vector<const TaskProto*>> rank2ordered_task(GlobalProcessCtx::WorldSize(),\n                                                               std::vector<const TaskProto*>());\n  for (const auto* task : ordered_tasks) {\n    CHECK_LT(task->machine_id(), rank2ordered_task.size());\n    rank2ordered_task.at(task->machine_id()).push_back(task);\n  }\n  for (int64_t rank = 0; rank < GlobalProcessCtx::WorldSize(); ++rank) {\n    // Filter which rank's log to generate; limited_rank == -1 means all ranks.\n    if (limited_rank >= 0 && rank != limited_rank) { continue; }\n    auto file_stream =\n        TeePersistentLogStream::Create(plan_name + \"_rank_\" + std::to_string(rank) + \"_light_plan\");\n    file_stream << \"rank : \" << std::to_string(rank) << \"\\n\";\n    CHECK_LT(rank, rank2ordered_task.size());\n    const auto& ordered_task_in_rank = rank2ordered_task.at(rank);\n    for (int64_t i = 0; i < ordered_task_in_rank.size(); ++i) {\n      CHECK_LT(i, ordered_task_in_rank.size());\n      const auto* task = ordered_task_in_rank.at(i);\n      int64_t task_id = task->task_id();\n      CHECK(task_id2name.find(task_id) != task_id2name.end())\n          << \" task_id2name cannot find: \" << task_id;\n      int64_t thrd_id = task->thrd_id();\n      StreamId stream_id = DecodeStreamIdFromInt64(thrd_id);\n      file_stream << \"i : \" << std::to_string(i) << \" , actor id : \" << std::to_string(task_id)\n                  << \" thrd : \" << std::to_string(thrd_id) << \" name : \" << task_id2name.at(task_id)\n                  << \"\\n  chain_id : \" << std::to_string(task->chain_id())\n                  << \" order_in_chain : \" << std::to_string(task->order_in_chain())\n                  << \" device_type : \" << DeviceType_Name(stream_id.device_type())\n                  << \" stream_index : \" << std::to_string(stream_id.stream_index()) << \" {\\n\";\n      for (const auto& key2consume_regst : task->consumed_regst_desc_id()) {\n        std::string key = key2consume_regst.first;\n        for (int64_t consume_regst_id : key2consume_regst.second.regst_desc_id()) {\n          std::string other_rank_str = \"\";\n          CHECK(regst_id2proto.find(consume_regst_id) != regst_id2proto.end())\n              << \" regst_id2proto cannot find: \" << consume_regst_id;\n          int64_t consume_task_id = regst_id2proto.at(consume_regst_id).producer_task_id();\n          CHECK(task_id2proto.find(consume_task_id) != task_id2proto.end())\n              << \" task_id2proto cannot find: \" << consume_task_id;\n          int64_t other_rank = task_id2proto.at(consume_task_id)->machine_id();\n          if (other_rank != rank) { other_rank_str = \", rank: \" + std::to_string(other_rank); }\n          CHECK(regst_id2name.find(consume_regst_id) != regst_id2name.end())\n              << \" regst_id2name cannot find: \" << consume_regst_id;\n          file_stream << \"  consume : \" << key << \" : <- [ \" << regst_id2name.at(consume_regst_id)\n                      << \" ] ( actor_id: \" << std::to_string(consume_task_id) << other_rank_str\n                      << \", regst: \" << RegstId2TensorStr(consume_regst_id) << \" )\\n\";\n        }\n      }\n      for (const auto& key2produce_regst : task->produced_regst_desc()) {\n        const RegstDescProto& regst = key2produce_regst.second;\n        file_stream << \"  produce : \" << key2produce_regst.first\n                    << \" regst: \" << RegstId2TensorStr(regst.regst_desc_id()) << \" {\\n\";\n        for (int64_t consumer_task_id : regst.consumer_task_id()) {\n          std::string 
other_rank_str = \"\";\n          CHECK(task_id2proto.find(consumer_task_id) != task_id2proto.end())\n              << \" task_id2proto cannot find \" << consumer_task_id;\n          CHECK(task_id2name.find(consumer_task_id) != task_id2name.end())\n              << \" task_id2name cannot find \" << consumer_task_id;\n          int64_t other_rank = task_id2proto.at(consumer_task_id)->machine_id();\n          if (other_rank != rank) { other_rank_str = \" , rank: \" + std::to_string(other_rank); }\n          file_stream << \"    -> [ \" << task_id2name.at(consumer_task_id)\n                      << \" ] ( actor_id: \" << std::to_string(consumer_task_id) << other_rank_str\n                      << \" )\\n\";\n        }\n        file_stream << \"  }\\n\";\n      }\n\n      file_stream << \"}\\n\";\n    }\n  }\n}\n\nconst oneflow::OpAttribute& PlanUtil::GetOpAttribute(const Plan* plan, int64_t job_id,\n                                                     const oneflow::KernelConf& kernel_conf) {\n  if (kernel_conf.has_op_attribute()) {\n    return kernel_conf.op_attribute();\n  } else if (kernel_conf.has_op_attribute_ref()) {\n    auto table_it = plan->job_id2op_attribute_ref_table().find(job_id);\n    CHECK(table_it != plan->job_id2op_attribute_ref_table().end())\n        << \"op attribute ref table not found for job id: \" << job_id;\n    ;\n    auto it = table_it->second.op_name2op_attribute().find(kernel_conf.op_attribute_ref());\n    CHECK(it != table_it->second.op_name2op_attribute().end())\n        << \"op attribute ref: \" << kernel_conf.op_attribute_ref() << \" not found\";\n    return it->second;\n  } else {\n    UNIMPLEMENTED() << \"kernel_conf must has either op_attribute or op_attribute_ref. kernel_conf: \"\n                    << kernel_conf.DebugString();\n  }\n}\n\nvoid PlanUtil::PopulateOpAttribute(\n    Plan* plan,\n    const PbMap<int64_t, ::oneflow::OpAttributeRefTable>& job_id2op_attribute_ref_table) {\n  for (auto& task : *plan->mutable_task()) {\n    if (task.exec_sequence().exec_node_size() == 1\n        && task.exec_sequence().exec_node(0).kernel_conf().has_op_attribute_ref()) {\n      auto* kernel_conf = task.mutable_exec_sequence()->mutable_exec_node(0)->mutable_kernel_conf();\n      auto table_it = job_id2op_attribute_ref_table.find(task.job_id());\n      CHECK(table_it != job_id2op_attribute_ref_table.end())\n          << \"op attribute ref table not found for job id: \" << task.job_id();\n      auto it = table_it->second.op_name2op_attribute().find(kernel_conf->op_attribute_ref());\n      CHECK(it != table_it->second.op_name2op_attribute().end())\n          << \"ref: \" << kernel_conf->op_attribute_ref() << \" not found\";\n      *kernel_conf->mutable_op_attribute() = it->second;\n      kernel_conf->clear_op_attribute_ref();\n    } else {\n      for (auto& exec_node : task.exec_sequence().exec_node()) {\n        CHECK(exec_node.kernel_conf().has_op_attribute())\n            << \"op_attribute absent, exec_node: \" << exec_node.DebugString();\n      }\n    }\n  }\n}\n\n/*static*/ StreamId PlanUtil::GetStreamId(const TaskProto& task) {\n  return DecodeStreamIdFromInt64(task.thrd_id());\n}\n\n/*static*/ int64_t PlanUtil::GetDeviceIndex(const TaskProto& task) {\n  return GetStreamId(task).device_id().device_index();\n}\n\n/*static*/ void PlanUtil::CreateOpAttributeRef(Plan* plan, int64_t job_id, TaskProto* task_proto) {\n  auto* job_id2op_attribute_ref_table = plan->mutable_job_id2op_attribute_ref_table();\n  CHECK(task_proto->exec_sequence().exec_node_size() == 1);\n  
auto* exec_node = task_proto->mutable_exec_sequence()->mutable_exec_node(0);\n  CHECK(exec_node->kernel_conf().has_op_attribute());\n  const std::string op_name = exec_node->kernel_conf().op_attribute().op_conf().name();\n  auto* op_name2op_attribute =\n      (*job_id2op_attribute_ref_table)[job_id].mutable_op_name2op_attribute();\n  auto find_it = op_name2op_attribute->find(op_name);\n  if (find_it == op_name2op_attribute->end()) {\n    op_name2op_attribute->insert(\n        {op_name, task_proto->exec_sequence().exec_node(0).kernel_conf().op_attribute()});\n  }\n  auto* kernel_conf =\n      task_proto->mutable_exec_sequence()->mutable_exec_node(0)->mutable_kernel_conf();\n  kernel_conf->set_op_attribute_ref(op_name);\n  // NOTE(levi): memory of op_attribute_ is released here.\n  kernel_conf->set_allocated_op_attribute(nullptr);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/plan_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_PLAN_UTIL_H_\n#define ONEFLOW_CORE_JOB_PLAN_UTIL_H_\n\n#include <functional>\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/graph/stream_id.h\"\n#include \"oneflow/core/graph/plan_task_graph.h\"\n\nnamespace oneflow {\n\nstruct PlanUtil {\n  static RegstDescProto* GetSoleProducedDataRegst(TaskProto* task_proto);\n  static std::function<const TaskProto*(int64_t)> MakeGetterTaskProto4TaskId(const Plan& plan);\n  // limited_rank equals -1 means taking care of all ranks.\n  // Otherwise, only take care of rank limited_rank.\n  static void MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job,\n                                              int64_t limited_rank = -1);\n  static void SetUniqueMemBlockId4UnreusedMemRegst(Plan* plan);\n  static void GenMemBlockAndChunk4Plan(Plan* plan);\n  static void GenMemBlockAndChunkWithVariableOpNames4Plan(\n      Plan* plan, const HashSet<std::string>& variable_op_names);\n  static void CleanUselessMemBlockAndCheckValid(Plan* plan);\n  static void ToDotFile(const Plan& plan, const std::string& filepath);\n  static std::function<RegstDescProto*(int64_t)> MakeMutRegstDesc4Id(Plan* plan);\n  // limited_rank equals -1 means taking care of all ranks.\n  // Otherwise, only take care of rank limited_rank.\n  static void SetForceInplaceMemBlock(Plan* plan, int64_t limited_rank = -1);\n  static void DumpCtrlRegstInfoToPlan(Plan* plan);\n  static void GenCollectiveBoxingPlan(Job* job, Plan* plan);\n  static void GenRegisterHint(Plan* plan);\n  // Generate readable plan log from plan proto.\n  // Use filter_rank to choose which rank to generate. When filter_rank is -1, all rank will be\n  // generated. The default value of filter_rank is -1.\n  static void GenLightPlan(Plan* plan, const std::string& plan_name, int64_t limited_rank = -1);\n  static void PlanMemoryLog(Plan* plan, const std::string& plan_name);\n  static const oneflow::OpAttribute& GetOpAttribute(const Plan* plan, int64_t job_id,\n                                                    const oneflow::KernelConf& kernel_conf);\n  // NOTE(chengcheng): recovery op_attr\n  static void PopulateOpAttribute(\n      Plan* plan,\n      const PbMap<int64_t, ::oneflow::OpAttributeRefTable>& job_id2op_attribute_ref_table);\n  static StreamId GetStreamId(const TaskProto& task);\n  static int64_t GetDeviceIndex(const TaskProto& task);\n  static void CreateOpAttributeRef(Plan* plan, int64_t job_id, TaskProto* task_proto);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_PLAN_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/job/qat_config_def.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/config_def.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nREGISTER_SCOPE_CONFIG_DEF().Bool(\"quantization_aware_training\", true,\n                                 \"enable quantization aware training\");\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/rank_compiler.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/rank_compiler.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/intra_job_mem_sharing_util.h\"\n#include \"oneflow/core/job/plan_util.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/job_rewriter/job_completer.h\"\n#include \"oneflow/core/thread/thread_pool.h\"\n#include \"oneflow/core/common/blocking_counter.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\n\nMaybe<void> RankCompiler::Compile(const HashSet<std::string>& var_op_names, Job* job,\n                                  Plan* plan) const {\n#ifdef WITH_CUDA\n  // Use the right device when some plan compilation needs cuda to avoid creating unnecessary cuda\n  // context on cuda:0.\n  CudaCurrentDeviceGuard guard(GetCudaDeviceIndex());\n#endif  // WITH_CUDA\n  auto task_gph = JUST(RankTaskGraph::New(boxing_task_graph_proto_, var_op_names, rank_));\n  using std::placeholders::_1;\n  const auto& IsNotMyDuty = [&](const CompTaskNode* comp_task_node) {\n    if (comp_task_node == nullptr) { return false; }\n    const auto& parallel_desc = comp_task_node->op_node()->parallel_desc();\n    return !task_gph->IsDutyRank(parallel_desc, comp_task_node->machine_id());\n  };\n  task_gph->ForEachNode([&](TaskNode* task_node) {\n    auto* comp_task_node = dynamic_cast<CompTaskNode*>(task_node);\n    if (IsNotMyDuty(comp_task_node)) {\n      auto* fake_consumed_regsts_provider =\n          dynamic_cast<FakeConsumedRegstProvider*>(comp_task_node);\n      CHECK_NOTNULL(fake_consumed_regsts_provider)->ConsumeFakeRegstsIf();\n    } else {\n      task_node->ConsumeAllRegsts();\n    }\n  });\n  task_gph->ForEachNode([&](TaskNode* task_node) {\n    auto* comp_task_node = dynamic_cast<CompTaskNode*>(task_node);\n    if (IsNotMyDuty(comp_task_node)) {\n      // Do nothing. 
All consumed registers are fake, so there is nothing to pin.\n    } else {\n      task_node->PinConsumedRegst();\n    }\n  });\n  task_gph->TopoForEachNode(&TaskNode::Build);\n  task_gph->RemoveEmptyRegsts();\n  task_gph->TopoForEachNode(&TaskNode::InferTimeShapeIfMeaningful);\n  task_gph->DecideExecutionOrder();\n  task_gph->MergeChainAndAddOrderingCtrlEdgeInSameChain();\n  auto IsReachable = Singleton<OpGraph>::Get()->MakePredicatorIsOpNameDataOrCtrlReachable();\n  const JobDesc& job_desc = GlobalJobDesc();\n  if (job_desc.enable_inplace()) {\n    task_gph->ForEachGpuDeviceNodes([&](const HashSet<TaskNode*>& dev_nodes) {\n      if (dev_nodes.empty()) { return; }\n      if ((*dev_nodes.begin())->machine_id() != rank_) { return; }  // other ranks are ignored.\n      task_gph->EnableInplaceMemSharing(dev_nodes, IsReachable);\n    });\n  }\n  task_gph->ForEachEdge([&](TaskEdge* task_edge) { task_edge->CheckRegstLbiValid(); });\n\n  // Put information from task_gph into the plan.\n  task_gph->ForEachNode([&](TaskNode* task_node) {\n    if (task_node->IsMeaningLess()) { return; }\n    auto* comp_task_node = dynamic_cast<CompTaskNode*>(task_node);\n    if (comp_task_node != nullptr) {\n      const auto& parallel_desc = comp_task_node->op_node()->parallel_desc();\n      if (!task_gph->IsDutyRank(parallel_desc, task_node->machine_id())) {\n        auto* fake_consumed_regsts_provider =\n            dynamic_cast<FakeConsumedRegstProvider*>(comp_task_node);\n        CHECK_NOTNULL(fake_consumed_regsts_provider)->EraseFakeRegstsIf();\n      }\n    }\n    TaskProto task_proto;\n    task_node->ToProto(&task_proto);\n    if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat\n        || task_node->GetTaskType() == kAcc) {\n      PlanUtil::CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto);\n    }\n    plan->mutable_task()->Add(std::move(task_proto));\n  });\n\n  // post-process for plan and delete Singleton<OpGraph>.\n  auto* job_id2job_conf = plan->mutable_job_confs()->mutable_job_id2job_conf();\n  (*job_id2job_conf)[GlobalJobDesc().job_id()] = GlobalJobDesc().job_conf();\n  // NOTE(chengcheng): infer mem block id & set inplace & add ctrl\n  IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan);\n  PlanUtil::MergeMemBlockIdByLogicalChainId(plan, *job, rank_);\n  PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan);\n  PlanUtil::SetForceInplaceMemBlock(plan, rank_);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/rank_compiler.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_RANK_COMPILER_H_\n#define ONEFLOW_CORE_JOB_RANK_COMPILER_H_\n\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/graph/task_graph.h\"\n#include \"oneflow/core/graph/boxing_task_graph.pb.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass RankCompiler final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RankCompiler);\n  RankCompiler(const std::shared_ptr<BoxingTaskGraphProto>& boxing_task_graph_proto, int64_t rank)\n      : boxing_task_graph_proto_(boxing_task_graph_proto), rank_(rank) {}\n  ~RankCompiler() = default;\n\n  Maybe<void> Compile(const HashSet<std::string>& var_op_names, Job* job, Plan* plan) const;\n\n private:\n  std::shared_ptr<BoxingTaskGraphProto> boxing_task_graph_proto_;\n  int64_t rank_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_RANK_COMPILER_H_\n"
  },
  {
    "path": "oneflow/core/job/rank_group.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <map>\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<Symbol<RankGroup>> RankGroup::New(Symbol<ParallelDesc> parallel_desc) {\n  return DECORATE(&RankGroup::RawNew, ThreadLocal)(parallel_desc);\n}\n\n/*static*/ Maybe<Symbol<RankGroup>> RankGroup::RawNew(Symbol<ParallelDesc> parallel_desc) {\n  CHECK_EQ_OR_RETURN(parallel_desc->sorted_machine_ids().size(), parallel_desc->parallel_num());\n  const auto& sorted_machine_ids = parallel_desc->sorted_machine_ids();\n  return New(std::set<int64_t>{sorted_machine_ids.begin(), sorted_machine_ids.end()});\n}\n\n/*static*/ Maybe<Symbol<RankGroup>> RankGroup::New(const std::set<int64_t>& ranks) {\n  static thread_local std::map<std::set<int64_t>, Symbol<RankGroup>> map;\n  auto iter = map.find(ranks);\n  if (iter == map.end()) {\n    RankGroup rank_group;\n    JUST(rank_group.Init(ranks));\n    iter = map.emplace(ranks, SymbolOf(rank_group)).first;\n  }\n  return iter->second;\n}\n\nnamespace {\n\nMaybe<Symbol<ParallelDesc>> CalcDefaultParallelDesc(DeviceType device_type,\n                                                    Symbol<RankGroup> rank_group) {\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(*JUST(DeviceTag4DeviceType(device_type)));\n  JUST(rank_group->ForEachRank([&](int64_t rank) -> Maybe<void> {\n    int64_t local_rank = GlobalProcessCtx::LocalRank(rank);\n    parallel_conf.add_device_name(std::string(\"@\") + std::to_string(rank) + \":\"\n                                  + std::to_string(local_rank));\n    return Maybe<void>::Ok();\n  }));\n  return SymbolOf(ParallelDesc(parallel_conf));\n}\n\nauto* CachedDefaultParallelDesc = DECORATE(&CalcDefaultParallelDesc, ThreadLocal);\n\n}  // namespace\n\n/*static*/ Maybe<Symbol<ParallelDesc>> RankGroup::GetDefaultParallelDesc(\n    DeviceType device_type, Symbol<RankGroup> rank_group) {\n  return CachedDefaultParallelDesc(device_type, rank_group);\n}\n\nnamespace {\n\nMaybe<std::set<int64_t>> AllWorldRanks() {\n  const auto& ranks = std::make_shared<std::set<int64_t>>();\n  for (int i = 0; i < GlobalProcessCtx::WorldSize(); ++i) { ranks->insert(i); }\n  return ranks;\n}\n\n}  // namespace\n\n/*static*/ Maybe<Symbol<RankGroup>> RankGroup::DefaultRankGroup() {\n  const auto& all_wold_ranks = JUST(AllWorldRanks());\n  const auto& rank_group = JUST(RankGroup::New(*all_wold_ranks));\n  return rank_group;\n}\n\nMaybe<void> RankGroup::Init(const std::set<int64_t>& ranks) {\n  ranks_ = ranks;\n  // Initialize rank2next_rank_in_ring_ and rank2prev_rank_in_ring_\n  {\n    CHECK_GT_OR_RETURN(ranks.size(), 0);\n    int64_t last = 
*(--ranks.end());\n    for (int64_t i : ranks) {\n      CHECK_OR_RETURN(rank2next_rank_in_ring_.emplace(last, i).second);\n      CHECK_OR_RETURN(rank2prev_rank_in_ring_.emplace(i, last).second);\n      last = i;\n    }\n  }\n  // Initialize hash_value_\n  hash_value_ = 0;\n  for (int64_t i : ranks) { HashCombine(&hash_value_, i); }\n  return Maybe<void>::Ok();\n}\n\nMaybe<int64_t> RankGroup::GetNextRankInRing(int64_t rank) const {\n  return MapAt(rank2next_rank_in_ring_, rank);\n}\n\nMaybe<int64_t> RankGroup::GetNextRankInRing() const {\n  return GetNextRankInRing(GlobalProcessCtx::Rank());\n}\n\nMaybe<int64_t> RankGroup::GetPrevRankInRing(int64_t rank) const {\n  return MapAt(rank2prev_rank_in_ring_, rank);\n}\n\nMaybe<int64_t> RankGroup::GetPrevRankInRing() const {\n  return GetPrevRankInRing(GlobalProcessCtx::Rank());\n}\n\nbool RankGroup::ContainingCurrentRank() const {\n  return rank2next_rank_in_ring_.count(GlobalProcessCtx::Rank()) > 0;\n}\n\nMaybe<void> RankGroup::ForEachRank(const std::function<Maybe<void>(int64_t)>& DoEach) const {\n  for (int64_t i : ranks_) { JUST(DoEach(i)); }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/rank_group.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_RANK_GROUP_H_\n#define ONEFLOW_CORE_JOB_RANK_GROUP_H_\n\n#include <functional>\n#include <vector>\n#include <unordered_map>\n#include <set>\n#include <string>\n#include <memory>\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/device_type.h\"\n\nnamespace oneflow {\n\nclass ParallelDesc;\n\nclass RankGroup final {\n public:\n  ~RankGroup() = default;\n\n  static Maybe<Symbol<RankGroup>> New(const std::set<int64_t>& ranks);\n  static Maybe<Symbol<RankGroup>> New(Symbol<ParallelDesc> parallel_desc);\n  static Maybe<Symbol<RankGroup>> DefaultRankGroup();\n\n  static Maybe<Symbol<ParallelDesc>> GetDefaultParallelDesc(DeviceType device_type,\n                                                            Symbol<RankGroup> rank_group);\n\n  bool operator==(const RankGroup& that) const { return this->ranks_ == that.ranks_; }\n  bool operator!=(const RankGroup& that) const { return !(*this == that); }\n\n  size_t size() const { return ranks_.size(); }\n  size_t hash_value() const { return hash_value_; }\n  Maybe<int64_t> GetNextRankInRing(int64_t rank) const;\n  Maybe<int64_t> GetNextRankInRing() const;\n  Maybe<int64_t> GetPrevRankInRing(int64_t rank) const;\n  Maybe<int64_t> GetPrevRankInRing() const;\n  bool ContainingCurrentRank() const;\n\n  Maybe<void> ForEachRank(const std::function<Maybe<void>(int64_t)>&) const;\n\n private:\n  RankGroup() = default;\n  Maybe<void> Init(const std::set<int64_t>& ranks);\n  static Maybe<Symbol<RankGroup>> RawNew(Symbol<ParallelDesc> parallel_desc);\n\n  std::set<int64_t> ranks_;\n  std::unordered_map<int64_t, int64_t> rank2next_rank_in_ring_;\n  std::unordered_map<int64_t, int64_t> rank2prev_rank_in_ring_;\n  size_t hash_value_;\n};\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::RankGroup> final {\n  size_t operator()(const oneflow::RankGroup& rank_group) const { return rank_group.hash_value(); }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_JOB_RANK_GROUP_H_\n"
  },
  {
    "path": "oneflow/core/job/rank_group_scope.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/rank_group_scope.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<Symbol<RankGroup>> RankGroupScope::CurrentRankGroup() {\n  return RankGroup::DefaultRankGroup();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/rank_group_scope.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_RANK_GROUP_SCOPE_H_\n#define ONEFLOW_CORE_JOB_RANK_GROUP_SCOPE_H_\n\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/common/symbol.h\"\n\nnamespace oneflow {\n\n// NOTE(daquexian): this scope class is not actually used. We only keep\n// it in case we need it in the future.\nclass RankGroupScope final {\n public:\n  static Maybe<Symbol<RankGroup>> CurrentRankGroup();\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_RANK_GROUP_SCOPE_H_\n"
  },
  {
    "path": "oneflow/core/job/rank_group_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"gtest/gtest.h\"\n#include <algorithm>\n#include <set>\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/control/ctrl_bootstrap.pb.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(RankGroup, two_rank) {\n  const auto& rank_group = CHECK_JUST(RankGroup::New(std::set<int64_t>{0, 1}));\n  int64_t rank = 0;\n  rank = CHECK_JUST(rank_group->GetNextRankInRing(0));\n  ASSERT_EQ(rank, 1);\n  rank = CHECK_JUST(rank_group->GetNextRankInRing(1));\n  ASSERT_EQ(rank, 0);\n  rank = CHECK_JUST(rank_group->GetPrevRankInRing(0));\n  ASSERT_EQ(rank, 1);\n  rank = CHECK_JUST(rank_group->GetPrevRankInRing(1));\n  ASSERT_EQ(rank, 0);\n}\n\nTEST(RankGroup, nonconsecutive_rank) {\n  const auto& rank_group = CHECK_JUST(RankGroup::New(std::set<int64_t>{0, 1, 3, 4}));\n  int64_t rank = 0;\n  rank = CHECK_JUST(rank_group->GetNextRankInRing(0));\n  ASSERT_EQ(rank, 1);\n  rank = CHECK_JUST(rank_group->GetNextRankInRing(1));\n  ASSERT_EQ(rank, 3);\n  rank = CHECK_JUST(rank_group->GetNextRankInRing(3));\n  ASSERT_EQ(rank, 4);\n  rank = CHECK_JUST(rank_group->GetNextRankInRing(4));\n  ASSERT_EQ(rank, 0);\n  bool is_ok = TRY(rank_group->GetNextRankInRing(2)).IsOk();\n  ASSERT_FALSE(is_ok);\n  rank = CHECK_JUST(rank_group->GetPrevRankInRing(1));\n  ASSERT_EQ(rank, 0);\n  rank = CHECK_JUST(rank_group->GetPrevRankInRing(3));\n  ASSERT_EQ(rank, 1);\n  rank = CHECK_JUST(rank_group->GetPrevRankInRing(4));\n  ASSERT_EQ(rank, 3);\n  rank = CHECK_JUST(rank_group->GetPrevRankInRing(0));\n  ASSERT_EQ(rank, 4);\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/regularizer_conf.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage L1L2RegularizerConf {\n  optional float l1 = 1 [default = 0.0];\n  optional float l2 = 2 [default = 0.0];\n}\n\nmessage RegularizerConf {\n  oneof type {\n    L1L2RegularizerConf l1_l2_conf = 1;\n  }\n}\n"
  },
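`RegularizerConf` keeps its variants in a proto2 `oneof`, so only one regularizer can be active at a time. A minimal sketch of the generated C++ API (standard proto2 codegen assumed; helper name and values are illustrative only):

```cpp
#include "oneflow/core/job/regularizer_conf.pb.h"

// Hypothetical helper, for illustration only.
void ConfigureL1L2(oneflow::RegularizerConf* conf) {
  // mutable_l1_l2_conf() selects the l1_l2_conf arm of the oneof.
  auto* l1_l2 = conf->mutable_l1_l2_conf();
  l1_l2->set_l1(0.0001f);  // fields left unset fall back to [default = 0.0]
  l1_l2->set_l2(0.01f);
  // conf->has_l1_l2_conf() is now true; setting a different oneof arm
  // (if one is ever added) would clear this one.
}
```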
  {
    "path": "oneflow/core/job/resource.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport public \"oneflow/core/common/device_type.proto\";\n\nmessage CollectiveBoxingConf {\n  // global\n  optional bool enable_fusion = 1 [default = true];\n  optional int64 num_callback_threads = 2 [default = 4];\n\n  // nccl\n  optional int64 nccl_num_streams = 101 [default = 1];\n  optional int64 nccl_fusion_threshold_mb = 102 [default = 16];\n  optional bool nccl_fusion_all_reduce = 103 [default = true];\n  optional bool nccl_fusion_reduce_scatter = 104 [default = false];\n  optional bool nccl_fusion_all_gather = 105 [default = false];\n  optional bool nccl_fusion_reduce = 106 [default = true];\n  optional bool nccl_fusion_broadcast = 107 [default = true];\n  optional bool nccl_fusion_all_reduce_use_buffer = 108 [default = false];\n  optional int64 nccl_fusion_max_ops = 109 [default = 64];\n  optional bool nccl_enable_all_to_all = 110 [default = false];\n  optional bool nccl_enable_mixed_fusion = 111 [default = false];\n}\n\nmessage CudnnConfig {\n  optional bool enable_cudnn = 1 [default = true];\n  optional int64 cudnn_buf_limit_mbyte = 2 [default = 1024];  // 1GByte\n  optional int32 cudnn_conv_force_fwd_algo = 3;\n  optional int32 cudnn_conv_force_bwd_data_algo = 4;\n  optional int32 cudnn_conv_force_bwd_filter_algo = 5;\n  optional bool cudnn_conv_heuristic_search_algo = 6 [default = true];\n  optional bool cudnn_conv_use_deterministic_algo_only = 7 [default = false];\n  optional bool enable_cudnn_fused_normalization_add_relu = 8;\n  optional bool cudnn_conv_enable_pseudo_half = 9 [default = true];\n}\n\nmessage Resource {\n  optional int32 machine_num = 1 [default = 0];\n  optional int32 cpu_device_num = 5 [default = 0];\n  optional int32 comm_net_worker_num = 6 [default = 4];\n  optional int32 max_mdsave_worker_num = 7 [default = 64];\n  optional uint64 reserved_host_mem_mbyte = 12 [default = 500];\n  optional uint64 reserved_device_mem_mbyte = 13 [default = 500];\n  optional int32 compute_thread_pool_size = 15;\n  optional bool enable_thread_local_cache = 16 [default = true];\n  optional int64 thread_local_cache_max_size = 17 [default = 67108864]; // 64M\n  optional bool enable_debug_mode = 18 [default = false];\n  optional bool enable_tensor_float_32_compute = 20 [default = true];\n  \n  optional CollectiveBoxingConf collective_boxing_conf = 19;\n\n  // NOTE(chengcheng) to reuse nccl memory and speed up\n  optional bool nccl_use_compute_stream = 30 [default = false];\n  optional bool disable_group_boxing_by_dst_parallel = 31 [default = false];\n\n  optional CudnnConfig cudnn_conf = 32;\n  optional bool enable_legacy_model_io = 33 [default = true];\n  optional bool enable_legacy_model_io_v2 = 34 [default = false];\n}\n"
  },
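Because every field above is a proto2 `optional` with an explicit default, reading an unset field returns the declared default while `has_*()` still reports whether it was explicitly set. A minimal sketch (standard proto2 codegen assumed; function name is illustrative):

```cpp
#include "oneflow/core/job/resource.pb.h"

// Illustration of proto2 default semantics for the Resource message.
void InspectResourceDefaults() {
  oneflow::Resource resource;
  int32_t workers = resource.comm_net_worker_num();       // 4, the declared default
  uint64_t host_mb = resource.reserved_host_mem_mbyte();  // 500
  bool pool_set = resource.has_compute_thread_pool_size();  // false: never set
  resource.set_cpu_device_num(8);  // now has_cpu_device_num() == true
  (void)workers; (void)host_mb; (void)pool_set;
}
```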
  {
    "path": "oneflow/core/job/resource_desc.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <algorithm>\n#include \"oneflow/core/job/resource.pb.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#ifdef WITH_CUDA\n#include <nccl.h>\n#endif\n\nnamespace oneflow {\n\nResourceDesc::ResourceDesc(const Resource& resource, int64_t num_process_per_node)\n    : resource_(resource) {\n  CHECK_GT(resource_.machine_num(), 0);\n  CHECK_LE(resource_.machine_num(), Singleton<EnvDesc>::Get()->TotalMachineNum());\n  for (int i = 0; i < GlobalProcessCtx::WorldSize(); ++i) {\n    CHECK(process_ranks_.emplace(i).second);\n  }\n}\n\nMachine ResourceDesc::machine(int32_t idx) const {\n  CHECK_GE(idx, 0);\n  CHECK(process_ranks().find(idx) != process_ranks().end());\n  if (Singleton<EnvDesc>::Get()->has_ctrl_bootstrap_conf()) {\n    CHECK_NOTNULL(Singleton<ProcessCtx>::Get());\n    CHECK_GE(Singleton<ProcessCtx>::Get()->ctrl_addr().size(), process_ranks().size());\n    Machine machine;\n    const Address& addr = Singleton<ProcessCtx>::Get()->ctrl_addr(idx);\n    machine.set_addr(addr.host());\n    return machine;\n  } else {\n    return Singleton<EnvDesc>::Get()->machine(idx);\n  }\n}\n\nint32_t ResourceDesc::ComputeThreadPoolSize() const {\n  if (resource_.has_compute_thread_pool_size()) {\n    CHECK_GT(resource_.compute_thread_pool_size(), 0);\n    return resource_.compute_thread_pool_size();\n  } else {\n    return CpuDeviceNum();\n  }\n}\n\nbool ResourceDesc::enable_debug_mode() const {\n  return IsInDebugMode() || resource_.enable_debug_mode();\n}\n\nCollectiveBoxingConf ResourceDesc::collective_boxing_conf() const {\n  if (resource_.has_collective_boxing_conf()) {\n    return resource_.collective_boxing_conf();\n  } else {\n    return CollectiveBoxingConf();\n  }\n}\n\nbool ResourceDesc::nccl_use_compute_stream() const {\n#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700\n  return resource_.nccl_use_compute_stream();\n#elif defined(WITH_NPU)\n  return resource_.nccl_use_compute_stream();\n#elif defined(WITH_MLU)\n  return resource_.nccl_use_compute_stream();\n#else\n  return false;\n#endif\n}\n\nvoid ResourceDesc::DumpCudnnConf(const JobConfigProto& job_conf) {\n  auto* cudnn_conf = resource_.mutable_cudnn_conf();\n  if (job_conf.has_enable_cudnn()) { cudnn_conf->set_enable_cudnn(job_conf.enable_cudnn()); }\n  if (job_conf.has_cudnn_buf_limit_mbyte()) {\n    cudnn_conf->set_cudnn_buf_limit_mbyte(job_conf.cudnn_buf_limit_mbyte());\n  }\n  if (job_conf.has_cudnn_conv_force_fwd_algo()) {\n    cudnn_conf->set_cudnn_conv_force_fwd_algo(job_conf.cudnn_conv_force_fwd_algo());\n  }\n  if (job_conf.has_cudnn_conv_force_bwd_data_algo()) {\n    cudnn_conf->set_cudnn_conv_force_bwd_data_algo(job_conf.cudnn_conv_force_bwd_data_algo());\n  }\n  if (job_conf.has_cudnn_conv_force_bwd_filter_algo()) {\n    
cudnn_conf->set_cudnn_conv_force_bwd_filter_algo(job_conf.cudnn_conv_force_bwd_filter_algo());\n  }\n  if (job_conf.has_cudnn_conv_heuristic_search_algo()) {\n    cudnn_conf->set_cudnn_conv_heuristic_search_algo(job_conf.cudnn_conv_heuristic_search_algo());\n  }\n  if (job_conf.has_cudnn_conv_use_deterministic_algo_only()) {\n    cudnn_conf->set_cudnn_conv_use_deterministic_algo_only(\n        job_conf.cudnn_conv_use_deterministic_algo_only());\n  }\n  if (job_conf.has_enable_cudnn_fused_normalization_add_relu()) {\n    cudnn_conf->set_enable_cudnn_fused_normalization_add_relu(\n        job_conf.enable_cudnn_fused_normalization_add_relu());\n  }\n  if (job_conf.has_cudnn_conv_enable_pseudo_half()) {\n    cudnn_conf->set_cudnn_conv_enable_pseudo_half(job_conf.cudnn_conv_enable_pseudo_half());\n  }\n}\n\nvoid ResourceDesc::Update(const Resource& reso_conf) { resource_.CopyFrom(reso_conf); }\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/resource_desc.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_RESOURCE_DESC_H_\n#define ONEFLOW_CORE_JOB_RESOURCE_DESC_H_\n\n#include <set>\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job/resource.pb.h\"\n#include \"oneflow/core/job/env_desc.h\"\n\nnamespace oneflow {\n\nstatic const size_t kMB = 1024 * 1024;\n\nclass ResourceDesc final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ResourceDesc);\n  ResourceDesc(const Resource& resource, int64_t num_process_per_node);\n  ResourceDesc(const Resource& resource)\n      : resource_(resource) {}  // TODO(yaochi): Only for eager, remove it later\n\n  ~ResourceDesc() = default;\n\n  const std::set<int64_t>& process_ranks() const { return process_ranks_; }\n  __attribute__((deprecated)) Machine machine(int32_t idx) const;\n  size_t CommNetWorkerNum() const { return resource_.comm_net_worker_num(); }\n  int32_t CpuDeviceNum() const { return resource_.cpu_device_num(); }\n  int32_t MaxMdSaveWorkerNum() const { return resource_.max_mdsave_worker_num(); }\n  size_t reserved_host_mem_byte() const { return resource_.reserved_host_mem_mbyte() * kMB; }\n  size_t reserved_device_mem_byte() const { return resource_.reserved_device_mem_mbyte() * kMB; }\n  bool enable_thread_local_cache() const { return resource_.enable_thread_local_cache(); }\n  size_t thread_local_cache_max_size() const { return resource_.thread_local_cache_max_size(); }\n  int32_t ComputeThreadPoolSize() const;\n  bool enable_debug_mode() const;\n  CollectiveBoxingConf collective_boxing_conf() const;\n  bool nccl_use_compute_stream() const;\n\n  void SetMachineNum(int32_t val) { resource_.set_machine_num(val); }\n  void SetCpuDeviceNum(int32_t val) { resource_.set_cpu_device_num(val); }\n  bool enable_tensor_float_32_compute() const { return resource_.enable_tensor_float_32_compute(); }\n  const Resource& resource() const { return resource_; }\n  void DumpCudnnConf(const JobConfigProto& job_conf);\n  void Update(const Resource& reso_conf);\n\n private:\n  Resource resource_;\n  std::set<int64_t> process_ranks_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_RESOURCE_DESC_H_\n"
  },
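The `*_mbyte` fields from `resource.proto` are scaled to bytes with `kMB = 1024 * 1024`. A minimal sketch using the single-argument constructor above (the TODO marks that constructor as eager-only, so treat this as illustrative rather than the normal construction path):

```cpp
#include "oneflow/core/job/resource_desc.h"

void ShowReservedMemory() {
  oneflow::Resource resource;
  resource.set_reserved_host_mem_mbyte(600);
  oneflow::ResourceDesc desc(resource);
  // 600 mbyte -> 600 * kMB == 629145600 bytes.
  size_t bytes = desc.reserved_host_mem_byte();
  (void)bytes;
}
```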
  {
    "path": "oneflow/core/job/runtime.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/runtime.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/control/ctrl_client.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/job/env_desc.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/runtime_context.h\"\n#include \"oneflow/core/job/runtime_job_descs.h\"\n#include \"oneflow/core/job/eager_nccl_comm_manager.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/core/graph/task_node.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/memory/memory_allocator.h\"\n#include \"oneflow/core/register/register_manager.h\"\n#include \"oneflow/user/summary/events_writer.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nvoid SendCmdMsg(const std::vector<const TaskProto*>& tasks, ActorCmd cmd) {\n  for (const TaskProto* task : tasks) {\n    ActorMsg msg = ActorMsg::BuildCommandMsg(task->task_id(), cmd);\n    Singleton<ActorMsgBus>::Get()->SendMsg(msg);\n  }\n}\n\nvoid HandoutTasks(const std::vector<const TaskProto*>& tasks) {\n  for (const TaskProto* task : tasks) {\n    Singleton<ThreadMgr>::Get()->GetThrd(task->thrd_id())->AddTask(*task);\n  }\n  SendCmdMsg(tasks, ActorCmd::kConstructActor);\n}\n\nbool HasNonCtrlConsumedRegstDescId(const TaskProto& task) {\n  for (const auto& pair : task.consumed_regst_desc_id()) {\n    if (pair.first == \"in_ctrl\") { continue; }\n    return true;\n  }\n  return false;\n}\n\n}  // namespace\n\nRuntime::Runtime(\n    const Plan& plan,\n    const HashMap<std::string, vm::EagerBlobObject*>& variable_op_name2eager_blob_object) {\n  DumpThreadIdsFromPlan(plan);\n  {\n    // NOTE(chengcheng): All runtime global(singleton) objects AddPlan\n    Singleton<RegstMgr>::Get()->AddPlan(plan, variable_op_name2eager_blob_object);\n    Singleton<ThreadMgr>::Get()->AddThreads(thread_ids_);\n    Singleton<RuntimeJobDescs>::Get()->AddPlan(plan);\n    collective_boxing_scheduler_plan_token_ =\n        Singleton<boxing::collective::Scheduler>::Get()->AddPlan(plan);\n#if defined(WITH_CUDA) || defined(WITH_NPU) || defined(WITH_MLU)\n    const auto& vaild_ccl_comm_mgr_device_types =\n        EagerCclCommMgrBuilder::Get().vaild_ccl_comm_mgr_device_types();\n    if (!vaild_ccl_comm_mgr_device_types.empty() && !Singleton<EagerCclCommMgr>::Get()) {\n      Singleton<EagerCclCommMgr>::SetAllocated(\n          EagerCclCommMgrBuilder::Get().NewCclCommMgr(vaild_ccl_comm_mgr_device_types.front()));\n    }\n    Singleton<EagerCclCommMgr>::Get()->CreateCommFromPlan(plan);\n#endif  // defined(WITH_CUDA) || defined(WITH_NPU) || defined(WITH_MLU)\n  }\n  std::vector<const TaskProto*> source_tasks;\n  source_tasks.reserve(plan.task().size());\n  std::vector<const TaskProto*> other_tasks;\n  other_tasks.reserve(plan.task().size());\n  int64_t 
this_machine_task_num = 0;\n  for (const TaskProto& task : plan.task()) {\n    if (task.machine_id() != GlobalProcessCtx::Rank()) { continue; }\n    if (!HasNonCtrlConsumedRegstDescId(task)) {\n      source_tasks.emplace_back(&task);\n    } else {\n      other_tasks.emplace_back(&task);\n    }\n    auto it = job_id2actor_size_.find(task.job_id());\n    if (it == job_id2actor_size_.end()) {\n      auto emplace_ret_pair = job_id2actor_size_.emplace(task.job_id(), 0);\n      CHECK(emplace_ret_pair.second);\n      it = emplace_ret_pair.first;\n    }\n    it->second++;\n    this_machine_task_num++;\n  }\n  RuntimeCtx* runtime_ctx = Singleton<RuntimeCtx>::Get();\n  runtime_ctx->NewCounter(\"constructing_actor_cnt\", this_machine_task_num);\n  HandoutTasks(source_tasks);\n  HandoutTasks(other_tasks);\n  runtime_ctx->WaitUntilCntEqualZero(\"constructing_actor_cnt\");\n  VLOG(3) << \"Actors on this machine constructed\";\n  OF_SESSION_BARRIER();\n  VLOG(3) << \"Actors on every machine constructed\";\n  for (const auto& pair : job_id2actor_size_) {\n    runtime_ctx->NewCounter(GetRunningActorCountKeyByJobId(pair.first), pair.second);\n  }\n  SendCmdMsg(source_tasks, ActorCmd::kStart);\n}\n\nRuntime::~Runtime() {\n  for (const auto& pair : job_id2actor_size_) {\n    Singleton<RuntimeCtx>::Get()->WaitUntilCntEqualZero(GetRunningActorCountKeyByJobId(pair.first));\n  }\n  OF_SESSION_BARRIER();\n  Singleton<ThreadMgr>::Get()->DeleteThreads(independent_thread_ids_);\n  Singleton<boxing::collective::Scheduler>::Get()->DeletePlan(\n      collective_boxing_scheduler_plan_token_);\n}\n\nvoid Runtime::DumpThreadIdsFromPlan(const Plan& plan) {\n  const int64_t this_rank = GlobalProcessCtx::Rank();\n  for (const TaskProto& task : plan.task()) {\n    TaskId task_id = DecodeTaskIdFromInt64(task.task_id());\n    StreamId stream_id = task_id.stream_id();\n    if (stream_id.rank() != this_rank) { continue; }\n    int64_t thrd_id = EncodeStreamIdToInt64(stream_id);\n    thread_ids_.insert(thrd_id);\n    // NOTE(chengcheng): there is no interface to query whether a task type is independent,\n    //  so it is hard-coded here.\n    if (task.task_type() == TaskType::kWaitAndSendIds\n        || task.task_type() == TaskType::kCriticalSectionWaitTick) {\n      CHECK(independent_thread_ids_.insert(thrd_id).second)\n          << \" RuntimeError! Thread : \" << thrd_id\n          << \" not independent with task proto: \" << task.DebugString();\n    }\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/runtime.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_RUNTIME_H_\n#define ONEFLOW_CORE_JOB_RUNTIME_H_\n\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/core/register/blob.h\"\n#include \"oneflow/core/job/collective_boxing/scheduler.h\"\n\nnamespace oneflow {\n\nnamespace vm {\nclass EagerBlobObject;\n}\n\nclass Runtime final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Runtime);\n  Runtime() = delete;\n  ~Runtime();\n\n  // TODO(chengcheng): refactor Runtime interface about variable_op_name2eager_blob_object\n  Runtime(const Plan& plan,\n          const HashMap<std::string, vm::EagerBlobObject*>& variable_op_name2eager_blob_object);\n\n private:\n  void DumpThreadIdsFromPlan(const Plan& plan);\n\n  HashMap<int64_t, int64_t> job_id2actor_size_;\n  HashSet<int64_t> thread_ids_;\n  HashSet<int64_t> independent_thread_ids_;\n\n  boxing::collective::SchedulerPlanToken* collective_boxing_scheduler_plan_token_;\n};\n\n}  // namespace oneflow\n\n#endif\n"
  },
  {
    "path": "oneflow/core/job/runtime_buffer_managers_scope.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/buffer_manager.h\"\n#include \"oneflow/core/job/runtime_buffer_managers_scope.h\"\n#include \"oneflow/core/job/job_instance.h\"\n\nnamespace oneflow {\n\nRuntimeBufferManagersScope::RuntimeBufferManagersScope() {\n  Singleton<BufferMgr<int64_t>>::New();\n  Singleton<BufferMgr<std::shared_ptr<JobInstance>>>::New();\n}\n\nRuntimeBufferManagersScope::~RuntimeBufferManagersScope() {\n  Singleton<BufferMgr<std::shared_ptr<JobInstance>>>::Delete();\n  Singleton<BufferMgr<int64_t>>::Delete();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/runtime_buffer_managers_scope.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_RUNTIME_BUFFER_MANAGERS_SCOPE_H_\n#define ONEFLOW_CORE_JOB_RUNTIME_BUFFER_MANAGERS_SCOPE_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nclass RuntimeBufferManagersScope final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RuntimeBufferManagersScope);\n  RuntimeBufferManagersScope();\n  ~RuntimeBufferManagersScope();\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_RUNTIME_BUFFER_MANAGERS_SCOPE_H_\n"
  },
  {
    "path": "oneflow/core/job/runtime_buffers_scope.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/buffer_manager.h\"\n#include \"oneflow/core/job/runtime_buffers_scope.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/job_instance.h\"\n\nnamespace oneflow {\n\nRuntimeBuffersScope::RuntimeBuffersScope(const JobConfs& job_confs) {\n  size_t job_size = Singleton<JobName2JobId>::Get()->size();\n  Singleton<BufferMgr<int64_t>>::Get()->NewBuffer(kBufferNameGlobalWaitJobId, job_size);\n  auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<JobInstance>>>::Get();\n  for (const auto& pair : job_confs.job_id2job_conf()) {\n    const auto& job_name = pair.second.job_name();\n    CHECK_EQ(pair.first, Singleton<JobName2JobId>::Get()->at(job_name));\n    size_t concurrency_width = pair.second.concurrency_width();\n    buffer_mgr->NewBuffer(GetCallbackNotifierBufferName(job_name), concurrency_width);\n  }\n}\n\nRuntimeBuffersScope::~RuntimeBuffersScope() {\n  auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<JobInstance>>>::Get();\n  for (const auto& pair : *Singleton<JobName2JobId>::Get()) {\n    const auto& job_name = pair.first;\n    buffer_mgr->Get(GetCallbackNotifierBufferName(job_name))->Close();\n  }\n  Singleton<BufferMgr<int64_t>>::Get()->Get(kBufferNameGlobalWaitJobId)->Close();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/runtime_buffers_scope.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_RUNTIME_BUFFERS_SCOPE_H_\n#define ONEFLOW_CORE_JOB_RUNTIME_BUFFERS_SCOPE_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n\nnamespace oneflow {\n\nclass RuntimeBuffersScope final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RuntimeBuffersScope);\n  RuntimeBuffersScope(const JobConfs& job_confs);\n  ~RuntimeBuffersScope();\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_RUNTIME_BUFFERS_SCOPE_H_\n"
  },
  {
    "path": "oneflow/core/job/runtime_context.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/runtime_context.h\"\n\nnamespace oneflow {\n\nvoid RuntimeCtx::NewCounter(const std::string& name, int64_t val) {\n  VLOG(3) << \"NewCounter \" << name << \" \" << val;\n  CHECK(counters_.emplace(name, std::make_unique<BlockingCounter>(val)).second);\n}\n\nvoid RuntimeCtx::DecreaseCounter(const std::string& name) {\n  auto it = counters_.find(name);\n  CHECK(it != counters_.end());\n  int64_t cur_val = it->second->Decrease();\n  VLOG(3) << \"DecreaseCounter \" << name << \", current val is \" << cur_val;\n}\n\nvoid RuntimeCtx::WaitUntilCntEqualZero(const std::string& name) {\n  auto it = counters_.find(name);\n  CHECK(it != counters_.end());\n  it->second->WaitForeverUntilCntEqualZero();\n  counters_.erase(it);\n}\n\nstd::string GetRunningActorCountKeyByJobId(int64_t job_id) {\n  return \"job_\" + std::to_string(job_id) + \"_running_actor_count\";\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/runtime_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_RUNTIME_CONTEXT_H_\n#define ONEFLOW_CORE_JOB_RUNTIME_CONTEXT_H_\n\n#include \"oneflow/core/common/blocking_counter.h\"\n\nnamespace oneflow {\n\nclass RuntimeCtx final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RuntimeCtx);\n  RuntimeCtx() = default;\n  ~RuntimeCtx() = default;\n\n  void NewCounter(const std::string& name, int64_t val);\n  void DecreaseCounter(const std::string& name);\n  void WaitUntilCntEqualZero(const std::string& name);\n\n private:\n  HashMap<std::string, std::unique_ptr<BlockingCounter>> counters_;\n};\n\nstd::string GetRunningActorCountKeyByJobId(int64_t job_id);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_RUNTIME_CONTEXT_H_\n"
  },
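`RuntimeCtx` counters are thin named wrappers over `BlockingCounter`: `NewCounter` creates one, each worker calls `DecreaseCounter`, and `WaitUntilCntEqualZero` blocks and then erases it. A minimal sketch of that protocol using `BlockingCounter` directly, assuming only the calls visible in `runtime_context.cpp` (a count-taking constructor, `Decrease()`, `WaitForeverUntilCntEqualZero()`):

```cpp
#include <thread>
#include <vector>
#include "oneflow/core/common/blocking_counter.h"

void WaitForWorkers(int64_t num_workers) {
  oneflow::BlockingCounter counter(num_workers);
  std::vector<std::thread> workers;
  for (int64_t i = 0; i < num_workers; ++i) {
    // Each worker decrements the shared counter exactly once.
    workers.emplace_back([&counter]() { counter.Decrease(); });
  }
  counter.WaitForeverUntilCntEqualZero();  // unblocks once the count reaches 0
  for (auto& t : workers) { t.join(); }
}
```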
  {
    "path": "oneflow/core/job/runtime_job_descs.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/runtime_job_descs.h\"\n\nnamespace oneflow {\n\nvoid RuntimeJobDescs::AddPlan(const Plan& plan) {\n  for (const auto& pair : plan.job_confs().job_id2job_conf()) {\n    auto job_desc = std::make_unique<JobDesc>(pair.second, pair.first);\n    CHECK(job_id2job_desc_.emplace(pair.first, std::move(job_desc)).second);\n  }\n}\n\nconst JobDesc& RuntimeJobDescs::job_desc(int64_t job_id) const {\n  auto it = job_id2job_desc_.find(job_id);\n  CHECK(it != job_id2job_desc_.end());\n  return *(it->second);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/runtime_job_descs.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_RUNTIME_JOB_DESCS_H_\n#define ONEFLOW_CORE_JOB_RUNTIME_JOB_DESCS_H_\n\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/core/job/job_desc.h\"\n\nnamespace oneflow {\n\nclass RuntimeJobDescs final {\n public:\n  RuntimeJobDescs() = default;\n  ~RuntimeJobDescs() = default;\n\n  void AddPlan(const Plan& plan);\n  const JobDesc& job_desc(int64_t job_id) const;\n\n private:\n  HashMap<int64_t, std::unique_ptr<JobDesc>> job_id2job_desc_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_RUNTIME_JOB_DESCS_H_\n"
  },
  {
    "path": "oneflow/core/job/sbp_infer_hint.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_SBP_INFER_HINT_H_\n#define ONEFLOW_CORE_JOB_SBP_INFER_HINT_H_\n\n#include \"oneflow/core/job/sbp_parallel.pb.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/register/blob_desc.h\"\n\nnamespace oneflow {\n\nclass SbpInferHint final {\n public:\n  SbpInferHint(const ParallelDesc* parallel_desc, const BlobDesc* logical_blob_desc,\n               const SbpParallel* sbp_parallel)\n      : parallel_desc_(parallel_desc),\n        logical_blob_desc_(logical_blob_desc),\n        sbp_parallel_(sbp_parallel) {}\n  SbpInferHint(const SbpInferHint&) = default;\n  ~SbpInferHint() = default;\n\n  // Getters\n  const ParallelDesc& parallel_desc() const { return *parallel_desc_; }\n  const BlobDesc& logical_blob_desc() const { return *logical_blob_desc_; }\n  const SbpParallel& sbp_parallel() const { return *sbp_parallel_; }\n\n private:\n  const ParallelDesc* parallel_desc_;\n  const BlobDesc* logical_blob_desc_;\n  const SbpParallel* sbp_parallel_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_SBP_INFER_HINT_H_\n"
  },
  {
    "path": "oneflow/core/job/sbp_parallel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n\nnamespace oneflow {\n\nbool operator==(const SbpParallel& lhs, const SbpParallel& rhs) {\n  if (lhs.parallel_type_case() != rhs.parallel_type_case()) { return false; }\n  if (lhs.has_split_parallel()) {\n    return lhs.split_parallel().axis() == rhs.split_parallel().axis();\n  } else if (lhs.has_broadcast_parallel()) {\n    return true;\n  } else if (lhs.has_partial_sum_parallel()) {\n    return true;\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nbool operator==(const NdSbp& lhs, const NdSbp& rhs) {\n  if (lhs.sbp_parallel_size() != rhs.sbp_parallel_size()) { return false; }\n  for (int i = 0; i < lhs.sbp_parallel_size(); ++i) {\n    if (lhs.sbp_parallel(i) != rhs.sbp_parallel(i)) { return false; }\n  }\n  return true;\n  ;\n}\n\nbool operator==(const SbpSignature& lhs, const SbpSignature& rhs) {\n  if (lhs.bn_in_op2sbp_parallel_size() != rhs.bn_in_op2sbp_parallel_size()) { return false; }\n  const auto& lhs_map = lhs.bn_in_op2sbp_parallel();\n  const auto& rhs_map = rhs.bn_in_op2sbp_parallel();\n  for (const auto& lhs_pair : lhs_map) {\n    const auto& rhs_iter = rhs_map.find(lhs_pair.first);\n    if (rhs_iter == rhs_map.end()) { return false; }\n    if (lhs_pair.second != rhs_iter->second) { return false; }\n  }\n  return true;\n}\n\nbool operator==(const NdSbpSignature& lhs, const NdSbpSignature& rhs) {\n  if (lhs.bn_in_op2nd_sbp_size() != rhs.bn_in_op2nd_sbp_size()) { return false; }\n  const auto& lhs_map = lhs.bn_in_op2nd_sbp();\n  const auto& rhs_map = rhs.bn_in_op2nd_sbp();\n  for (const auto& lhs_pair : lhs_map) {\n    const auto& rhs_iter = rhs_map.find(lhs_pair.first);\n    if (rhs_iter == rhs_map.end()) { return false; }\n    if (lhs_pair.second != rhs_iter->second) { return false; }\n  }\n  return true;\n}\n\nMaybe<Symbol<SbpParallel>> MakeSplitSbpParallel(int axis) {\n  CHECK_LT_OR_RETURN(axis, kMaxSplitAxis);\n  SbpParallel split_sbp_parallel;\n  split_sbp_parallel.mutable_split_parallel()->set_axis(axis);\n  return SymbolOf(split_sbp_parallel);\n}\n\nMaybe<Symbol<SbpParallel>> MakeBroadcastSbpParallel() {\n  SbpParallel broadcast_sbp;\n  broadcast_sbp.mutable_broadcast_parallel();\n  return SymbolOf(broadcast_sbp);\n}\n\nMaybe<Symbol<SbpParallel>> MakePartialSumSbpParallel() {\n  SbpParallel partial_sum_sbp;\n  partial_sum_sbp.mutable_partial_sum_parallel();\n  return SymbolOf(partial_sum_sbp);\n}\n\n//  S -> S\n//  P -> B\n//  B -> P\nSbpParallel GetDualSbpParallel(const SbpParallel& sbp_parallel) {\n  SbpParallel ret(sbp_parallel);\n  if (sbp_parallel.has_split_parallel()) {\n    //  do nothing\n  } else if (sbp_parallel.has_broadcast_parallel()) {\n    ret.mutable_partial_sum_parallel();\n  } else if (sbp_parallel.has_partial_sum_parallel()) {\n    
ret.mutable_broadcast_parallel();\n  } else {\n    UNIMPLEMENTED();\n  }\n  return ret;\n}\n\nbool IsSbpSignatureContaining(const SbpSignature& bigger, const SbpSignature& smaller) {\n  auto& bn2sbp = bigger.bn_in_op2sbp_parallel();\n  for (const auto& pair : smaller.bn_in_op2sbp_parallel()) {\n    if (pair.second.parallel_type_case() == SbpParallel::PARALLEL_TYPE_NOT_SET) { continue; }\n    CHECK(bn2sbp.find(pair.first) != bn2sbp.end()) << pair.first;\n    if (bn2sbp.at(pair.first) != pair.second) { return false; }\n  }\n  return true;\n}\n\nvoid FilterSbpSignatureList(const SbpSignatureList& sbp_sig_list, const SbpSignature& sbp_sig_conf,\n                            SbpSignatureList* filtered_sbp_sig_list) {\n  for (const auto& sbp_signature : sbp_sig_list.sbp_signature()) {\n    if (IsSbpSignatureContaining(sbp_signature, sbp_sig_conf)) {\n      *filtered_sbp_sig_list->mutable_sbp_signature()->Add() = sbp_signature;\n    }\n  }\n}\n\ndouble ComputeCopyCostBetweenTwoSbpParallel(const SbpInferHint& producer_sbp_infer_hint,\n                                            const SbpParallel& consumer_sbp_parallel) {\n  if (producer_sbp_infer_hint.sbp_parallel() == consumer_sbp_parallel) { return 0.0; }\n  if (consumer_sbp_parallel.has_partial_sum_parallel()) { return GetMaxVal<int64_t>(); }\n  if (producer_sbp_infer_hint.sbp_parallel().has_broadcast_parallel()) {\n    return GetMaxVal<int32_t>();\n  }\n  const auto& logical_blob_desc = producer_sbp_infer_hint.logical_blob_desc();\n  return logical_blob_desc.shape().elem_cnt() * GetSizeOfDataType(logical_blob_desc.data_type());\n}\n\ndouble ComputeIbnCopyCost4SbpSig(\n    const PbRpf<std::string>& ibns,\n    const std::function<Maybe<const SbpInferHint*>(const std::string&)>& SbpInferHint4Ibn,\n    const SbpSignature& sbp_signature) {\n  double cost = 0;\n  for (const auto& ibn : ibns) {\n    const auto& consumer_sbp_parallel = sbp_signature.bn_in_op2sbp_parallel().find(ibn)->second;\n    cost += ComputeCopyCostBetweenTwoSbpParallel(*CHECK_JUST(SbpInferHint4Ibn(ibn)),\n                                                 consumer_sbp_parallel);\n  }\n  return cost;\n}\n\nstd::function<double(const SbpSignature*)> MakeGetterIbnCopyCost4SbpSig(\n    const PbRpf<std::string>& ibns,\n    const std::function<Maybe<const SbpInferHint*>(const std::string&)>& SbpInferHint4Ibn,\n    const SbpSignatureList& sbp_sig_list) {\n  auto sbp_sig2ibn_copy_cost = std::make_shared<HashMap<const SbpSignature*, double>>();\n  for (const auto& sbp_signature : sbp_sig_list.sbp_signature()) {\n    double cost = ComputeIbnCopyCost4SbpSig(ibns, SbpInferHint4Ibn, sbp_signature);\n    CHECK(sbp_sig2ibn_copy_cost->emplace(&sbp_signature, cost).second);\n  }\n  return [sbp_sig2ibn_copy_cost](const SbpSignature* sbp_sig) -> double {\n    return sbp_sig2ibn_copy_cost->at(sbp_sig);\n  };\n}\n\nstd::function<int32_t(const SbpSignature* sbp_sig)> MakeGetterOrderValue4SbpSig(\n    const SbpSignatureList& sbp_sig_list,\n    const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig) {\n  auto sbp_sig2order_value = std::make_shared<HashMap<const SbpSignature*, int32_t>>();\n  for (const SbpSignature& sbp_signature : sbp_sig_list.sbp_signature()) {\n    sbp_sig2order_value->emplace(&sbp_signature, CalcOrderValue4SbpSig(sbp_signature));\n  }\n  return [sbp_sig2order_value](const SbpSignature* sbp_sig) {\n    return sbp_sig2order_value->at(sbp_sig);\n  };\n}\n\nvoid SortSbpSignatureListByCopyCost(\n    const SbpSignatureList& sbp_sig_list, const PbRpf<std::string>& ibns,\n    
const std::function<Maybe<const SbpInferHint*>(const std::string&)>& SbpInferHint4Ibn,\n    const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n    std::vector<const SbpSignature*>* sorted_sbp_signatures) {\n  auto OrderValue4SbpSig = MakeGetterOrderValue4SbpSig(sbp_sig_list, CalcOrderValue4SbpSig);\n  auto IbnCopyCost4SbpSig = MakeGetterIbnCopyCost4SbpSig(ibns, SbpInferHint4Ibn, sbp_sig_list);\n  for (const auto& sbp_signature : sbp_sig_list.sbp_signature()) {\n    sorted_sbp_signatures->emplace_back(&sbp_signature);\n  }\n  std::sort(sorted_sbp_signatures->begin(), sorted_sbp_signatures->end(),\n            [&](const SbpSignature* lhs, const SbpSignature* rhs) {\n              if (OrderValue4SbpSig(lhs) < OrderValue4SbpSig(rhs)) { return true; }\n              if (OrderValue4SbpSig(lhs) > OrderValue4SbpSig(rhs)) { return false; }\n              return IbnCopyCost4SbpSig(lhs) < IbnCopyCost4SbpSig(rhs);\n            });\n}\n\nbool IsValidSbpParallelString(const std::string& sbp_str) {\n  SbpParallel sbp_parallel;\n  return ParseSbpParallelFromString(sbp_str, &sbp_parallel);\n}\n\nbool ParseNdSbpFromLongString(const std::string& nd_sbp_str, NdSbp* nd_sbp) {\n  bool success = true;\n  Split(nd_sbp_str, \",\", [&](std::string&& sbp_str) {\n    SbpParallel* sbp_parallel = nd_sbp->add_sbp_parallel();\n    bool ret = ParseSbpParallelFromString(sbp_str, sbp_parallel);\n    if (!ret) { success = false; }\n  });\n  if (nd_sbp->sbp_parallel_size() == 0) { return false; }\n  return success;\n}\n\nstd::string NdSbpToLongString(const NdSbp& nd_sbp) {\n  std::string ret = \"\";\n  for (int32_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) {\n    if (i > 0) { ret += \",\"; }  // NOTE(chengcheng): Separator ','\n    ret += SbpToString(nd_sbp.sbp_parallel(i));\n  }\n  return ret;\n}\n\nbool ParseSbpParallelFromString(const std::string& sbp_str, SbpParallel* sbp_parallel) {\n  bool success = false;\n  if (sbp_str.length() >= 1) {\n    if (sbp_str == \"B\") {\n      sbp_parallel->mutable_broadcast_parallel();\n      success = true;\n    } else if (sbp_str == \"P\") {\n      sbp_parallel->mutable_partial_sum_parallel();\n      success = true;\n    } else if (sbp_str[0] == 'S') {\n      if (sbp_str.length() >= 4 && sbp_str[1] == '(' && sbp_str[sbp_str.length() - 1] == ')') {\n        int split_axis = 0;\n        if (sbp_str.length() == 4) {\n          split_axis = sbp_str[2] - '0';\n          if (split_axis >= 0 && split_axis <= 9) { success = true; }\n        } else {\n          std::string split_axis_str = sbp_str.substr(2, sbp_str.length() - 3);\n          if (std::all_of(split_axis_str.cbegin(), split_axis_str.cend(),\n                          [](char ch) { return std::isdigit(ch); })) {\n            size_t pos = 0;\n            split_axis = std::stoi(split_axis_str, &pos);\n            if (pos == split_axis_str.length()) { success = true; }\n          }\n        }\n        if (success) { sbp_parallel->mutable_split_parallel()->set_axis(split_axis); }\n      }\n    }\n  }\n  return success;\n}\n\nstd::string SbpParallelToString(const SbpParallel& sbp_parallel) {\n  return SbpToString(sbp_parallel);\n}\n\nbool ParseNdSbpFromStringList(const std::vector<std::string>& sbp_str_list, NdSbp* nd_sbp) {\n  for (const auto& sbp_str : sbp_str_list) {\n    if (!ParseSbpParallelFromString(sbp_str, nd_sbp->add_sbp_parallel())) { return false; }\n  }\n  return true;\n}\n\nstd::vector<std::string> NdSbpToStringList(const NdSbp& nd_sbp) {\n  std::vector<std::string> 
sbp_str_list(nd_sbp.sbp_parallel_size());\n  for (size_t i = 0; i < sbp_str_list.size(); ++i) {\n    sbp_str_list[i] = SbpToString(nd_sbp.sbp_parallel(i));\n  }\n  return sbp_str_list;\n}\n\nvoid SbpSignatureToNdSbpSignature(const SbpSignature& sbp_signature,\n                                  NdSbpSignature* nd_sbp_signature) {\n  for (const auto& pair : sbp_signature.bn_in_op2sbp_parallel()) {\n    *((*nd_sbp_signature->mutable_bn_in_op2nd_sbp())[pair.first].add_sbp_parallel()) = pair.second;\n  }\n}\n\nvoid NdSbpSignatureToSbpSignature(const NdSbpSignature& nd_sbp_signature,\n                                  SbpSignature* sbp_signature) {\n  for (const auto& pair : nd_sbp_signature.bn_in_op2nd_sbp()) {\n    CHECK_EQ(pair.second.sbp_parallel_size(), 1);\n    (*sbp_signature->mutable_bn_in_op2sbp_parallel())[pair.first] = pair.second.sbp_parallel(0);\n  }\n}\n\nvoid CheckSbpSignatureAndNdSbpEquals(const SbpSignature& sbp_sig,\n                                     const NdSbpSignature& nd_sbp_sig) {\n  CHECK_EQ(sbp_sig.bn_in_op2sbp_parallel_size(), nd_sbp_sig.bn_in_op2nd_sbp_size());\n  for (const auto& pair : nd_sbp_sig.bn_in_op2nd_sbp()) {\n    const auto& bn_in_op2sbp_parallel = sbp_sig.bn_in_op2sbp_parallel();\n    const auto it = bn_in_op2sbp_parallel.find(pair.first);\n    CHECK(it != bn_in_op2sbp_parallel.end());\n    CHECK_EQ(pair.second.sbp_parallel_size(), 1);\n    CHECK(pair.second.sbp_parallel(0) == it->second);\n  }\n}\n\nbool NdSbpAllSameSplitParallel(const NdSbp& nd_sbp) {\n  CHECK_GT(nd_sbp.sbp_parallel_size(), 0);\n  const SbpParallel& first_sbp = nd_sbp.sbp_parallel(0);\n  if (!first_sbp.has_split_parallel()) { return false; }\n  FOR_RANGE(int64_t, i, 1, nd_sbp.sbp_parallel_size()) {\n    if (nd_sbp.sbp_parallel(i) != first_sbp) { return false; }\n  }\n  return true;\n}\n\nMaybe<std::string> NdSbpSignatureToString(const NdSbpSignature& nd_sbp_signature,\n                                          const std::vector<std::string>& inputs,\n                                          const std::vector<std::string>& outputs) {\n  std::ostringstream ss;\n\n  auto AppendBnNdSbpString = [&](const std::string& bn) -> Maybe<void> {\n    auto iter = nd_sbp_signature.bn_in_op2nd_sbp().find(bn);\n    if (iter == nd_sbp_signature.bn_in_op2nd_sbp().end()) {\n      return Error::RuntimeError()\n             << \"can't find \" << bn << \" in NdSbpSignature: \" << nd_sbp_signature.DebugString();\n    }\n    ss << \" \" << NdSbpToString(iter->second);\n    return Maybe<void>::Ok();\n  };\n\n  int bn_index = 0;\n  for (const auto& ibn : inputs) {\n    if (bn_index > 0) { ss << \", \"; }\n    ss << ibn;\n    JUST(AppendBnNdSbpString(ibn));\n    bn_index++;\n  }\n\n  ss << \" -> \";\n  bn_index = 0;\n  for (const auto& obn : outputs) {\n    if (bn_index > 0) { ss << \", \"; }\n    ss << obn;\n    JUST(AppendBnNdSbpString(obn));\n    bn_index++;\n  }\n\n  return ss.str();\n}\n\nMaybe<std::string> NdSbpSignatureToString(const NdSbpSignature& nd_sbp_signature,\n                                          const PbRpf<std::string>& inputs,\n                                          const PbRpf<std::string>& outputs) {\n  return NdSbpSignatureToString(nd_sbp_signature,\n                                std::vector<std::string>{inputs.begin(), inputs.end()},\n                                std::vector<std::string>{outputs.begin(), outputs.end()});\n}\n\nMaybe<std::string> NdSbpSignatureListToString(const std::vector<NdSbpSignature>& nd_sbp_sig_list,\n                                              const 
std::vector<std::string>& inputs,\n                                              const std::vector<std::string>& outputs) {\n  std::ostringstream ss;\n  if (nd_sbp_sig_list.empty()) { return ss.str(); }\n\n  auto WalkIO =\n      [&](const std::function<Maybe<std::string>(const std::string&)>& bn_handler) -> Maybe<void> {\n    ss << \"(\";\n    for (size_t i = 0; i < inputs.size(); ++i) {\n      ss << *JUST(bn_handler(inputs[i]));\n      if (i != inputs.size() - 1) { ss << \", \"; }\n    }\n    ss << \") -> (\";\n    for (size_t i = 0; i < outputs.size(); ++i) {\n      ss << *JUST(bn_handler(outputs[i]));\n      if (i != outputs.size() - 1) { ss << \", \"; }\n    }\n    ss << \")\";\n    return Maybe<void>::Ok();\n  };\n\n  ss << \"\\n\";\n  JUST(WalkIO([](const std::string& bn) -> Maybe<std::string> { return bn; }));\n  ss << \": \";\n\n  ss << \"[\\n\";\n  for (const auto& nd_sbp_sig : nd_sbp_sig_list) {\n    ss << \"\\t\";\n    JUST(WalkIO([&](const std::string& bn) -> Maybe<std::string> {\n      auto it = nd_sbp_sig.bn_in_op2nd_sbp().find(bn);\n      if (it == nd_sbp_sig.bn_in_op2nd_sbp().end()) {\n        return Error::RuntimeError()\n               << \"can't find \" << bn << \" in NdSbpSignature: \" << nd_sbp_sig.DebugString();\n      }\n      return NdSbpToString(it->second);\n    }));\n    ss << \",\\n\";\n  }\n  ss << \"]\";\n  return ss.str();\n}\n\nMaybe<std::string> NdSbpSignatureListToString(const std::vector<NdSbpSignature>& nd_sbp_sig_list,\n                                              const PbRpf<std::string>& inputs,\n                                              const PbRpf<std::string>& outputs) {\n  return NdSbpSignatureListToString(nd_sbp_sig_list,\n                                    std::vector<std::string>{inputs.begin(), inputs.end()},\n                                    std::vector<std::string>{outputs.begin(), outputs.end()});\n}\n\n}  // namespace oneflow\n"
  },
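The parse/print helpers above round-trip the compact SBP notation: `"B"`, `"P"`, and `"S(axis)"`, joined by `','` for n-d SBP. A minimal sketch (the exact `NdSbpToLongString` output depends on `SbpToString`, which is declared elsewhere, but by construction it joins the per-axis strings with commas):

```cpp
#include <string>
#include "oneflow/core/job/sbp_parallel.h"

void RoundTripNdSbp() {
  oneflow::NdSbp nd_sbp;
  bool ok = oneflow::ParseNdSbpFromLongString("S(0),B,P", &nd_sbp);
  // ok == true; nd_sbp holds [Split(axis=0), Broadcast, PartialSum].
  std::string printed = oneflow::NdSbpToLongString(nd_sbp);  // expected: "S(0),B,P"
  (void)ok; (void)printed;
}
```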
  {
    "path": "oneflow/core/job/sbp_parallel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_SBP_PARALLEL_H_\n#define ONEFLOW_CORE_JOB_SBP_PARALLEL_H_\n\n#include \"oneflow/core/job/sbp_parallel.pb.h\"\n#include \"oneflow/core/job/sbp_infer_hint.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/symbol.h\"\n\nnamespace oneflow {\n\nbool operator==(const SbpParallel& lhs, const SbpParallel& rhs);\ninline bool operator!=(const SbpParallel& lhs, const SbpParallel& rhs) { return !(lhs == rhs); }\n\nbool operator==(const NdSbp& lhs, const NdSbp& rhs);\ninline bool operator!=(const NdSbp& lhs, const NdSbp& rhs) { return !(lhs == rhs); }\n\nbool operator==(const SbpSignature& lhs, const SbpSignature& rhs);\ninline bool operator!=(const SbpSignature& lhs, const SbpSignature& rhs) { return !(lhs == rhs); }\n\nbool operator==(const NdSbpSignature& lhs, const NdSbpSignature& rhs);\ninline bool operator!=(const NdSbpSignature& lhs, const NdSbpSignature& rhs) {\n  return !(lhs == rhs);\n}\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::SbpSignature> : public oneflow::SerializedHashPb<oneflow::SbpSignature> {};\n\ntemplate<>\nstruct hash<oneflow::NdSbpSignature> : public oneflow::SerializedHashPb<oneflow::NdSbpSignature> {};\n\n}  // namespace std\n\nnamespace oneflow {\n\nMaybe<Symbol<SbpParallel>> MakeSplitSbpParallel(int axis);\nMaybe<Symbol<SbpParallel>> MakeBroadcastSbpParallel();\nMaybe<Symbol<SbpParallel>> MakePartialSumSbpParallel();\n\nSbpParallel GetDualSbpParallel(const SbpParallel&);\n\nbool IsSbpSignatureContaining(const SbpSignature& bigger, const SbpSignature& smaller);\n\nvoid FilterSbpSignatureList(const SbpSignatureList& sbp_sig_list, const SbpSignature& sbp_sig_conf,\n                            SbpSignatureList* filtered_sbp_sig_list);\n\nvoid SortSbpSignatureListByCopyCost(\n    const SbpSignatureList& sbp_sig_list, const PbRpf<std::string>& ibns,\n    const std::function<Maybe<const SbpInferHint*>(const std::string&)>& SbpInferHint4Ibn,\n    const std::function<int32_t(const SbpSignature&)>& OrderValue4SbpSig,\n    std::vector<const SbpSignature*>* sorted_sbp_signatures);\n\nbool IsValidSbpParallelString(const std::string& sbp_str);\nbool ParseSbpParallelFromString(const std::string& sbp_str, SbpParallel* sbp_parallel);\nstd::string SbpParallelToString(const SbpParallel& sbp_parallel);\n\nbool ParseNdSbpFromStringList(const std::vector<std::string>& sbp_str_list, NdSbp* nd_sbp);\nstd::vector<std::string> NdSbpToStringList(const NdSbp& nd_sbp);\n\nbool ParseNdSbpFromLongString(const std::string& nd_sbp_str, NdSbp* nd_sbp);\nstd::string NdSbpToLongString(const NdSbp& nd_sbp);\n\nvoid SbpSignatureToNdSbpSignature(const SbpSignature& sbp_signature,\n                                  NdSbpSignature* nd_sbp_signature);\n\nvoid NdSbpSignatureToSbpSignature(const NdSbpSignature& nd_sbp_signature,\n                                  SbpSignature* sbp_signature);\n\nvoid 
CheckSbpSignatureAndNdSbpEquals(const SbpSignature& sbp_sig, const NdSbpSignature& nd_sbp_sig);\n\nbool NdSbpAllSameSplitParallel(const NdSbp& nd_sbp);\n\n// Print functions\n\nMaybe<std::string> NdSbpSignatureToString(const NdSbpSignature& nd_sbp_signature,\n                                          const std::vector<std::string>& inputs,\n                                          const std::vector<std::string>& outputs);\n\nMaybe<std::string> NdSbpSignatureToString(const NdSbpSignature& nd_sbp_signature,\n                                          const PbRpf<std::string>& inputs,\n                                          const PbRpf<std::string>& outputs);\n\nMaybe<std::string> NdSbpSignatureListToString(const std::vector<NdSbpSignature>& nd_sbp_sig_list,\n                                              const std::vector<std::string>& inputs,\n                                              const std::vector<std::string>& outputs);\n\nMaybe<std::string> NdSbpSignatureListToString(const std::vector<NdSbpSignature>& nd_sbp_sig_list,\n                                              const PbRpf<std::string>& inputs,\n                                              const PbRpf<std::string>& outputs);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_SBP_PARALLEL_H_\n"
  },
  {
    "path": "oneflow/core/job/sbp_parallel.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\n// Take matmal_op as an example.\n\n//   Y     =   A    *   B\n// (m, n)    (m, k) , (k, n)\n\n\n// candidate signature 0:\n//     Y:Split(0), A:Split(0), B:Broadcast\n//     -----------------------------------\n//     device0:   Y0    =   A0    *   B\n//             (m0, n)    (m0, k) , (k, n)\n//     -----------------------------------\n//     device1:   Y1    =   A1    *   B\n//             (m1, n)    (m1, k) , (k, n)\n//     -----------------------------------\n//     where (m0 + m1 == m)\n//            and (A0 == A[0:m0, :]) and (A1 == A[m0:, :])\n//            and (Y0 == Y[0:m0, :]) and (Y1 == Y[m0:, :])\n\n// candidate signature 1:\n//     Y:Split(1), A:Broadcast, B:Split(1)\n//     -----------------------------------\n//     device0:   Y0    =   A    *   B0\n//             (m, n0)    (m, k) , (k, n0)\n//     -----------------------------------\n//     device1:   Y1    =   A    *   B1\n//             (m, n1)    (m, k) , (k, n1)\n//     -----------------------------------\n//     where (n0 + n1 == n)\n//            and (B0 == B[:, 0:n0]) and (B1 == B[:, n0:])\n//            and (Y0 == Y[:, 0:n0]) and (Y1 == Y[:, n0:])\n\n// candidate signature 2:\n//     Y:PartialSum, A:Split(1), B:Split(0)\n//     ------------------------------------\n//     device0:   Y0    =   A0    *   B0\n//              (m, n)    (m, k0) , (k0, n)\n//     ------------------------------------\n//     device1:   Y1    =   A1    *   B1\n//              (m, n)    (m, k1) , (k1, n)\n//     ------------------------------------\n//     where (k0 + k1 == k) and (Y0 + Y1 == Y)\n\nmessage SplitParallel {\n  required int64 axis = 1;\n}\n\nmessage BroadcastParallel {\n}\n\nmessage PartialSumParallel {\n}\n\nmessage SbpParallel {\n  oneof parallel_type {\n    SplitParallel split_parallel = 1;\n    BroadcastParallel broadcast_parallel = 2;\n    PartialSumParallel partial_sum_parallel = 3;\n  }\n}\n\nmessage SbpSignature {\n  map<string, SbpParallel> bn_in_op2sbp_parallel = 1;\n}\n\nmessage NdSbp {\n  repeated SbpParallel sbp_parallel = 1;\n}\n\nmessage NdSbpSignature {\n  map<string, NdSbp> bn_in_op2nd_sbp = 1;\n}\n\nmessage SbpSignatureList {\n  repeated SbpSignature sbp_signature = 1;\n}\n"
  },
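As a concrete instance of "candidate signature 2" from the comments above (Y:PartialSum, A:Split(1), B:Split(0)), a minimal sketch with the generated proto API; the blob names "a", "b", "y" are illustrative placeholders:

```cpp
#include "oneflow/core/job/sbp_parallel.pb.h"

oneflow::SbpSignature MakeMatmulSignature2() {
  oneflow::SbpSignature sig;
  auto* bn2sbp = sig.mutable_bn_in_op2sbp_parallel();
  (*bn2sbp)["a"].mutable_split_parallel()->set_axis(1);  // A split along k
  (*bn2sbp)["b"].mutable_split_parallel()->set_axis(0);  // B split along k
  (*bn2sbp)["y"].mutable_partial_sum_parallel();         // Y0 + Y1 == Y
  return sig;
}
```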
  {
    "path": "oneflow/core/job/sbp_signature_builder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n\nnamespace oneflow {\n\nvoid SplitSbpSignatureListBuilder::CheckTemplate() {\n  CHECK_GT(sbp_signature_template_.bn_in_op2sbp_parallel().size(), 0);\n  const auto& first = sbp_signature_template_.bn_in_op2sbp_parallel().begin()->second;\n  CHECK(first.has_split_parallel());\n  for (const auto& pair : sbp_signature_template_.bn_in_op2sbp_parallel()) {\n    CHECK(first == pair.second);\n  }\n}\n\nSplitSbpSignatureListBuilder&& SplitSbpSignatureListBuilder::SetNumAxes(int64_t num_axes) {\n  num_axes_ = num_axes;\n  return std::move(*this);\n}\n\nvoid SplitSbpSignatureListBuilder::Build(SbpSignatureList* list) const {\n  CHECK_GE(num_axes_, 0);\n  SbpSignature sbp_sig_template(sbp_signature_template_);\n  FOR_RANGE(int32_t, axis, 0, num_axes_) {\n    for (auto& pair : *sbp_sig_template.mutable_bn_in_op2sbp_parallel()) {\n      pair.second.mutable_split_parallel()->set_axis(axis);\n    }\n    *list->mutable_sbp_signature()->Add() = sbp_sig_template;\n  }\n}\n\nSbpSignatureBuilder&& SbpSignatureBuilder::Split(const std::string& bn_in_op, int64_t axis) {\n  (*sbp_signature_.mutable_bn_in_op2sbp_parallel())[bn_in_op].mutable_split_parallel()->set_axis(\n      axis);\n  return std::move(*this);\n}\n\nSbpSignatureBuilder&& SbpSignatureBuilder::Broadcast(const std::string& bn_in_op) {\n  (*sbp_signature_.mutable_bn_in_op2sbp_parallel())[bn_in_op].mutable_broadcast_parallel();\n  return std::move(*this);\n}\n\nSbpSignatureBuilder&& SbpSignatureBuilder::PartialSum(const std::string& bn_in_op) {\n  (*sbp_signature_.mutable_bn_in_op2sbp_parallel())[bn_in_op].mutable_partial_sum_parallel();\n  return std::move(*this);\n}\n\nSbpSignatureBuilder&& SbpSignatureBuilder::Split(const PbRpf<std::string>& bns, int64_t axis) {\n  for (const auto& bn_in_op : bns) { Split(bn_in_op, axis); }\n  return std::move(*this);\n}\n\nSbpSignatureBuilder&& SbpSignatureBuilder::Broadcast(const PbRpf<std::string>& bns) {\n  for (const auto& bn_in_op : bns) { Broadcast(bn_in_op); }\n  return std::move(*this);\n}\n\nSbpSignatureBuilder&& SbpSignatureBuilder::PartialSum(const PbRpf<std::string>& bns) {\n  for (const auto& bn_in_op : bns) { PartialSum(bn_in_op); }\n  return std::move(*this);\n}\n\nSbpSignatureBuilder&& SbpSignatureBuilder::Split(const std::initializer_list<std::string>& bns,\n                                                 int64_t axis) {\n  for (const auto& bn_in_op : bns) { Split(bn_in_op, axis); }\n  return std::move(*this);\n}\n\nSbpSignatureBuilder&& SbpSignatureBuilder::Broadcast(\n    const std::initializer_list<std::string>& bns) {\n  for (const auto& bn_in_op : bns) { Broadcast(bn_in_op); }\n  return std::move(*this);\n}\n\nSbpSignatureBuilder&& SbpSignatureBuilder::PartialSum(\n    const std::initializer_list<std::string>& bns) {\n  for (const 
auto& bn_in_op : bns) { PartialSum(bn_in_op); }\n  return std::move(*this);\n}\n\nSplitSbpSignatureListBuilder SbpSignatureBuilder::MakeSplitSignatureListBuilder(\n    int64_t num_axes) const {\n  SbpSignature sbp_signature;\n  Build(&sbp_signature);\n  return SplitSbpSignatureListBuilder(sbp_signature).SetNumAxes(num_axes);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/sbp_signature_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_SBP_SIGNATURE_BUILDER_H_\n#define ONEFLOW_CORE_JOB_SBP_SIGNATURE_BUILDER_H_\n\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/protobuf.h\"\n\nnamespace oneflow {\n\nclass SplitSbpSignatureListBuilder final {\n public:\n  SplitSbpSignatureListBuilder(const SplitSbpSignatureListBuilder&) = default;\n  explicit SplitSbpSignatureListBuilder(const SbpSignature& sbp_signature_template)\n      : sbp_signature_template_(sbp_signature_template), num_axes_(0) {\n    CheckTemplate();\n  }\n  ~SplitSbpSignatureListBuilder() = default;\n\n  SplitSbpSignatureListBuilder&& SetNumAxes(int64_t num_axes);\n  void Build(SbpSignatureList* list) const;\n\n private:\n  void CheckTemplate();\n\n  SbpSignature sbp_signature_template_;\n  int64_t num_axes_;\n};\n\nclass SbpSignatureBuilder final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SbpSignatureBuilder);\n  SbpSignatureBuilder() = default;\n  ~SbpSignatureBuilder() = default;\n\n  // split\n  SbpSignatureBuilder&& Split(const std::string& bn_in_op, int64_t axis);\n  SbpSignatureBuilder&& Split(const PbRpf<std::string>& bns, int64_t axis);\n  SbpSignatureBuilder&& Split(const std::initializer_list<std::string>& bns, int64_t axis);\n\n  // broadcast\n  SbpSignatureBuilder&& Broadcast(const std::string& bn_in_op);\n  SbpSignatureBuilder&& Broadcast(const PbRpf<std::string>& bns);\n  SbpSignatureBuilder&& Broadcast(const std::initializer_list<std::string>& bns);\n\n  // partial_sum\n  SbpSignatureBuilder&& PartialSum(const std::string& bn_in_op);\n  SbpSignatureBuilder&& PartialSum(const PbRpf<std::string>& bns);\n  SbpSignatureBuilder&& PartialSum(const std::initializer_list<std::string>& bns);\n\n  SplitSbpSignatureListBuilder MakeSplitSignatureListBuilder(int64_t num_axes) const;\n  void Build(SbpSignature* ret) const { *ret = sbp_signature_; }\n\n private:\n  SbpSignature sbp_signature_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_SBP_SIGNATURE_BUILDER_H_\n"
  },
  {
    "path": "oneflow/core/job/scope.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/job/scope.pb.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n\nnamespace oneflow {\n\nScope::Scope(const ScopeProto& scope_proto)\n    : auto_increment_id_(0), symbol_id_(NullOpt), scope_proto_(scope_proto) {\n  CHECK_OK(Init()) << scope_proto_.DebugString();\n}\n\nScope::Scope(int64_t symbol_id, const ScopeProto& scope_proto)\n    : auto_increment_id_(0), symbol_id_(symbol_id), scope_proto_(scope_proto) {}\n\nMaybe<Scope> Scope::New(int64_t symbol_id, const ScopeProto& scope_proto) {\n  auto* ptr = new Scope(symbol_id, scope_proto);\n  std::shared_ptr<Scope> scope(ptr);\n  JUST(scope->Init());\n  return scope;\n}\n\nMaybe<void> Scope::Init() {\n  {\n    const auto& storage = *Singleton<symbol::Storage<JobDesc>>::Get();\n    job_desc_ = JUST(storage.MaybeGetPtr(scope_proto_.job_desc_symbol_id()));\n  }\n  {\n    const auto& storage = *Singleton<symbol::Storage<ParallelDesc>>::Get();\n    const auto& device_parallel_desc =\n        SymbolOf(*JUST(storage.MaybeGetPtr(scope_proto_.device_parallel_desc_symbol_id())));\n    const auto& host_parallel_desc =\n        SymbolOf(*JUST(storage.MaybeGetPtr(scope_proto_.host_parallel_desc_symbol_id())));\n    placement_scope_ = SymbolOf(PlacementScope(device_parallel_desc, host_parallel_desc));\n  }\n  {\n    const auto& storage = *Singleton<symbol::Storage<Scope>>::Get();\n    if (scope_proto_.has_parent_scope_symbol_id()) {\n      parent_scope_symbol_ = JUST(storage.MaybeGetPtr(scope_proto_.parent_scope_symbol_id()));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<const JobDesc*> Scope::job_desc() const {\n  CHECK_NOTNULL_OR_RETURN(job_desc_.get());\n  return job_desc_.get();\n}\n\nMaybe<int64_t> Scope::GetParallelDescSymbolId(const OperatorConf& op_conf) const {\n  if (op_conf.device_tag() == \"cpu\" || IsCpuOnly(op_conf)) {\n    return scope_proto_.host_parallel_desc_symbol_id();\n  } else {\n    return scope_proto_.device_parallel_desc_symbol_id();\n  }\n}\n\nMaybe<Symbol<ParallelDesc>> Scope::GetParallelDesc(const OperatorConf& op_conf) const {\n  return placement_scope_->GetParallelDesc(op_conf.device_tag(), op_conf);\n}\n\nconst AttrValue& Scope::GetAttrValue(const std::string& attr_name) const {\n  const auto& iter = scope_proto_.attr_name2attr_value().find(attr_name);\n  if (iter != scope_proto_.attr_name2attr_value().end()) { return iter->second; }\n  const auto& attr_name2attr_def = GlobalScopeConfigDef().attr_name2attr_def();\n  const auto& def_iter = attr_name2attr_def.find(attr_name);\n  CHECK(def_iter != attr_name2attr_def.end());\n  return def_iter->second.default_val();\n}\n\nMaybe<ScopeProto> Scope::MakeChildScopeProto() const {\n  auto child = std::make_shared<ScopeProto>(scope_proto_);\n  
child->set_parent_scope_symbol_id(JUST(symbol_id()));\n  return child;\n}\n\nMaybe<int64_t> NewScopeSymbolId(\n    int64_t old_scope_symbol_id,\n    const std::function<void(std::shared_ptr<ScopeProto> new_scope)>& InitNewScopeProto) {\n  CHECK_OR_RETURN(Singleton<symbol::Storage<Scope>>::Get()->Has(old_scope_symbol_id));  // NOLINT\n  const Scope& old_scope = Singleton<symbol::Storage<Scope>>::Get()->Get(old_scope_symbol_id);\n  std::shared_ptr<ScopeProto> new_scope = JUST(old_scope.MakeChildScopeProto());\n  InitNewScopeProto(new_scope);\n  std::shared_ptr<Scope> new_scope_symbol;\n  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    new_scope_symbol = JUST(builder->GetScopeSymbol(*new_scope));\n    return Maybe<void>::Ok();\n  }));\n  return JUST(new_scope_symbol->symbol_id());\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/scope.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_SCOPE_H_\n#define ONEFLOW_CORE_JOB_SCOPE_H_\n\n#include \"oneflow/core/job/scope.pb.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/placement_scope.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/symbol.h\"\n\nnamespace oneflow {\n\nclass OperatorConf;\n\nclass Scope final {\n public:\n  Scope(const Scope&) = delete;\n  Scope(Scope&&) = delete;\n  explicit Scope(const ScopeProto& scope_proto);\n  ~Scope() = default;\n\n  static Maybe<Scope> New(int64_t symbol_id, const ScopeProto& scope_proto);\n  const Optional<int64_t>& symbol_id() const { return symbol_id_; }\n  int64_t auto_increment_id() { return ++auto_increment_id_; }\n  int64_t session_id() const { return scope_proto().session_id(); }\n  const std::shared_ptr<JobDesc>& job_desc_symbol() const { return job_desc_; }\n  Symbol<PlacementScope> placement_scope() const { return placement_scope_; }\n  Symbol<ParallelDesc> device_parallel_desc_symbol() const {\n    return placement_scope_->device_parallel_desc();\n  }\n  const std::shared_ptr<Scope>& parent_scope_symbol() const { return parent_scope_symbol_; }\n  Maybe<ScopeProto> MakeChildScopeProto() const;\n\n  Maybe<const JobDesc*> job_desc() const;\n  Maybe<int64_t> GetParallelDescSymbolId(const OperatorConf& op_conf) const;\n  Maybe<Symbol<ParallelDesc>> GetParallelDesc(const OperatorConf& op_conf) const;\n\n  const OptLocalParallel& opt_local_parallel_conf() const {\n    return scope_proto_.opt_local_parallel_conf();\n  }\n  const ScopeProto& scope_proto() const { return scope_proto_; }\n  const ScopeProto& data() const { return scope_proto_; }\n\n#define DEFINE_SCOPE_CONFIG_GETTER(T, func_name, field_name) \\\n  T func_name(const std::string& field_name) const {         \\\n    const AttrValue& attr_val = GetAttrValue(field_name);    \\\n    CHECK(attr_val.has_##field_name());                      \\\n    return attr_val.field_name();                            \\\n  }\n  DEFINE_SCOPE_CONFIG_GETTER(bool, Bool, at_bool);\n  DEFINE_SCOPE_CONFIG_GETTER(int64_t, Int64, at_int64);\n  DEFINE_SCOPE_CONFIG_GETTER(double, Double, at_double);\n  DEFINE_SCOPE_CONFIG_GETTER(const std::string&, String, at_string);\n\n private:\n  Scope(int64_t symbol_id, const ScopeProto& scope_proto);\n  Maybe<void> Init();\n\n  const AttrValue& GetAttrValue(const std::string& attr_name) const;\n\n  int64_t auto_increment_id_;\n  Optional<int64_t> symbol_id_;\n  const ScopeProto scope_proto_;\n  std::shared_ptr<JobDesc> job_desc_;\n  Symbol<PlacementScope> placement_scope_;\n  std::shared_ptr<Scope> parent_scope_symbol_;\n};\n\nMaybe<int64_t> NewScopeSymbolId(\n    int64_t old_scope_symbol_id,\n    const std::function<void(std::shared_ptr<ScopeProto> new_scope)>& InitNewScopeProto);\n\n}  // namespace oneflow\n\n#endif  // 
ONEFLOW_CORE_JOB_SCOPE_H_\n"
  },
  {
    "path": "oneflow/core/job/scope.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/job/local_parallel.proto\";\nimport \"oneflow/core/framework/user_op_attr.proto\";\nimport \"oneflow/core/job/module_conf.proto\";\n\nmessage ScopeProto {\n  required int64 job_desc_symbol_id = 20;\n  required int64 device_parallel_desc_symbol_id = 30;\n  required int64 host_parallel_desc_symbol_id = 40; \n  optional bool enable_cpu_alternative_op = 41 [default = true];\n  required OptLocalParallel opt_local_parallel_conf = 50;\n  repeated string scope_op_name_prefixes = 60;\n  optional int64 parent_scope_symbol_id = 70;\n  required int64 session_id = 80;\n  map<string, AttrValue> attr_name2attr_value = 90;\n  optional string calculation_pass_name = 100 [default = \"forward_pass\"];\n  optional string module_name = 110;\n}\n"
  },
  {
    "path": "oneflow/core/job/session.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <atomic>\n#include \"oneflow/core/job/session.h\"\n#include \"oneflow/core/job/job_set.pb.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nint64_t NewSessionId() {\n  static std::atomic<int64_t> counter(0);\n  return counter++;\n}\n\nConfigProtoContext::ConfigProtoContext(const ConfigProto& config_proto)\n    : session_id_(config_proto.session_id()) {}\n\nConfigProtoContext::~ConfigProtoContext() {}\n\nLogicalConfigProtoContext::LogicalConfigProtoContext(const std::string& config_proto_str) {\n  ConfigProto config_proto;\n  CHECK(TxtString2PbMessage(config_proto_str, &config_proto));\n  // TODO(hanbinbin): init for worker machines\n  config_proto_ctx_.reset(new ConfigProtoContext(config_proto));\n}\n\nLogicalConfigProtoContext::~LogicalConfigProtoContext() {\n  config_proto_ctx_.reset();\n  // TODO(hanbinbin): destroy ConfigProtoContext of worker machines\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/session.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_SESSION_H_\n#define ONEFLOW_CORE_JOB_SESSION_H_\n\n#include <memory>\n#include <string>\n\nnamespace oneflow {\n\nint64_t NewSessionId();\n\nclass ConfigProto;\nclass ConfigProtoContext {\n public:\n  ConfigProtoContext(const ConfigProto& config_proto);\n  ~ConfigProtoContext();\n\n  int64_t session_id() const { return session_id_; }\n\n private:\n  int64_t session_id_;\n};\n\nclass LogicalConfigProtoContext {\n public:\n  LogicalConfigProtoContext(const std::string& config_proto_str);\n  ~LogicalConfigProtoContext();\n\n  std::unique_ptr<ConfigProtoContext> config_proto_ctx_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_SESSION_H_\n"
  },
  {
    "path": "oneflow/core/job/ssp_config_def.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/config_def.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nREGISTER_FUNCTION_CONFIG_DEF()\n    .Bool(\"enable_ssp\", false, \"enable ssp\")\n    .String(\"ssp_partition_strategy\", \"naive_sequential\",\n            \"ssp partition strategy, Avaiable strategies: naive_sequential | disable\")\n    .ListInt64(\"ssp_partition_scope_ids\", {}, \"type: list[int64]. ssp partition scope symbol ids\");\n\nREGISTER_SCOPE_CONFIG_DEF()\n    .Int64(\"ssp_num_stages\", -1, \"total number of ssp stages\")\n    .Int64(\"ssp_stage_id\", -1, \"current ssp stage id \");\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/sub_plan.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/job/task.proto\";\n\nmessage ThrdIds {\n  repeated int64 thrd_id = 1;\n}\n\nmessage ClusterThrdIds {\n  map<int64, ThrdIds> machine_id2thrd_ids = 1;\n}\n\nmessage SubPlan {\n  repeated TaskProto task = 1;\n}\n"
  },
  {
    "path": "oneflow/core/job/task.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/graph/exec_sequence.proto\";\nimport \"oneflow/core/register/register_desc.proto\";\nimport \"oneflow/core/job/placement.proto\";\n\nenum TaskType {\n  kInvalid = 0;\n  kNormalForward = 1;\n  kCopyHd = 12;\n  kCopyCommNet = 13;\n  kDeviceTick = 27;\n  kPack = 30;\n  kUnpack = 31;\n  kRepeat = 32;\n  kAcc = 33;\n  kAccCtrlTick = 34;\n  kSrcSubsetTick = 38;\n  kDstSubsetTick = 39;\n  kSourceTick = 40;\n  kTick = 41;\n  kAccTick = 42;\n  kCase = 43;\n  kEsac = 44;\n  kWaitAndSendIds = 45;\n  kReentrantLock = 46;\n  kCallbackNotify = 47;\n  kDistributeConcat = 55;\n  kDistributeSplit = 56;\n  kSliceBoxing = 57;\n  kCollectiveBoxingGeneric = 58;\n  kBoxingIdentity = 59;\n  kDecodeH2D = 60;\n  kCollectiveBoxingPack = 61;\n  kCollectiveBoxingUnpack = 62;\n  kSspVariableProxy = 63;\n  kBoxingZeros = 64;\n  kCriticalSectionWaitTick = 65;\n  kNcclSendRecvBoxing = 66;\n};\n\nmessage RegstDescIdSet {\n  repeated int64 regst_desc_id = 1;\n}\n\nmessage TaskProto {\n  // common\n  required TaskType task_type = 1;\n  required int64 machine_id = 2;\n  required int64 thrd_id = 3;\n  required int64 task_id = 4;\n  required int64 job_id = 5;\n  required ExecSequence exec_sequence = 7;\n  map<string, RegstDescProto> produced_regst_desc = 8;\n  map<string, RegstDescIdSet> consumed_regst_desc_id = 9;\n  optional bool all_register_num_eq_one_hint = 10 [default = false];\n  required int64 chain_id = 20;\n  required int64 order_in_chain = 21;\n  // compute task\n  optional ParallelContext parallel_ctx = 1000; // CompTask\n};\n"
  },
  {
    "path": "oneflow/core/job/utils/progress_bar.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/utils/progress_bar.h\"\n#include \"oneflow/core/job/graph_scope_vars.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\n\nMaybe<void> LogProgress(const std::string& task_name, bool is_end) {\n  const bool log_progress =\n      GetGraphDebugMode() || ThreadLocalEnvBool<ONEFLOW_NNGRAPH_ENABLE_PROGRESS_BAR>();\n  if (!log_progress || OF_PREDICT_FALSE(GlobalProcessCtx::Rank() != 0)) {\n    return Maybe<void>::Ok();\n  }\n\n  const static thread_local uint64_t progress_total_num = 60;\n  static thread_local uint64_t progress_cnt = 1;\n  static constexpr char clear_line[] =\n      \"                                                                         \\r\";\n\n  auto const& limited_str = task_name.size() > 60 ? task_name.substr(0, 60) : task_name;\n  std::cout << clear_line << \"[\" << progress_cnt << \"/\" << progress_total_num << \"]\" << limited_str\n            << \"\\r\" << std::flush;\n  if (is_end) {\n    progress_cnt = 0;\n    std::cout << clear_line << std::endl << std::flush;\n  }\n  ++progress_cnt;\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/utils/progress_bar.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_UTILS_PROGRESS_BAR_H_\n#define ONEFLOW_CORE_JOB_UTILS_PROGRESS_BAR_H_\n\n#include <string>\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/env_var/env_var.h\"\n\nnamespace oneflow {\n\nDEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_NNGRAPH_ENABLE_PROGRESS_BAR, false);\n\nMaybe<void> LogProgress(const std::string& task_name = \"\", bool is_end = false);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_UTILS_PROGRESS_BAR_H_\n"
  },
  {
    "path": "oneflow/core/job/version.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/version.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n\nnamespace oneflow {\n\nvoid DumpVersionInfo() {\n  LOG(INFO) << \"OneFlow git version: \" << GetOneFlowGitVersion();\n  ep::DeviceManagerRegistry::DumpVersionInfo();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job/version.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_VERSION_H_\n#define ONEFLOW_CORE_JOB_VERSION_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nconst char* GetOneFlowGitVersion();\n\nvoid DumpVersionInfo();\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_VERSION_H_\n"
  },
  {
    "path": "oneflow/core/job_rewriter/adadelta_optim.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/job/initializer_conf.pb.h\"\n#include \"oneflow/core/job/job_builder.h\"\n#include \"oneflow/core/job/job_conf.pb.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job_rewriter/optimizer.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/operator/variable_op.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::string GenVariableOutputLbn(const OperatorConf& op_conf) {\n  CHECK(op_conf.has_variable_conf());\n  return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out());\n}\n\nOperatorConf GenerateAdadeltaHelperVariableConf(const VariableOp& op, const std::string& name) {\n  OperatorConf helper_variable_op(op.op_conf());\n  helper_variable_op.set_name(op.op_name() + \"-\" + name);\n  helper_variable_op.mutable_variable_conf()->set_out(\"out\");\n  InitializerConf constant_initializer;\n  constant_initializer.mutable_constant_conf()->set_value(0.0f);\n  *(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer;\n  helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id());\n  return helper_variable_op;\n}\n\nvoid GenerateAdadeltaOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node,\n                                     const std::string& model_diff_lbn,\n                                     const OptimizerConf& optimizer_conf, JobBuilder* job_builder) {\n  const VariableOp* var_op = dynamic_cast<const VariableOp*>(&var_op_node.op());\n  CHECK_NOTNULL(var_op);\n\n  user_op::UserOpConfWrapperBuilder adadelta_update_op_builder(var_op->op_name() + \"_optimizer\");\n  float rho = 0.0;\n  float epsilon = 0.0;\n  bool maximize = false;\n\n  const AdadeltaModelUpdateConf& adadelta_conf = optimizer_conf.adadelta_conf();\n  rho = adadelta_conf.rho();\n  epsilon = adadelta_conf.epsilon();\n  maximize = adadelta_conf.maximize();\n  const std::string& learning_rate_lbn = optimizer_conf.learning_rate_lbn();\n\n  OperatorConf square_avgs_var(GenerateAdadeltaHelperVariableConf(*var_op, \"square_avgs\"));\n  OperatorConf acc_deltas_var(GenerateAdadeltaHelperVariableConf(*var_op, \"acc_deltas\"));\n  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(),\n                      {square_avgs_var, acc_deltas_var});\n\n  adadelta_update_op_builder.OpTypeName(\"adadelta_update\")\n      .Input(\"model\", GenLogicalBlobName(var_op->BnInOp2Lbi(\"out\")))\n      .Input(\"model_diff\", model_diff_lbn)\n      .Input(\"learning_rate\", learning_rate_lbn)\n      .Input(\"square_avgs\", GenVariableOutputLbn(square_avgs_var))\n      .Input(\"acc_deltas\", GenVariableOutputLbn(acc_deltas_var))\n      .Attr<float>(\"rho\", rho)\n      .Attr<float>(\"epsilon\", epsilon)\n      .Attr<bool>(\"maximize\", maximize)\n      
.Attr<float>(\"weight_decay\", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))\n      .ScopeSymbolId(var_op->op_conf().scope_symbol_id());\n  if (optimizer_conf.has_lr_scale()) {\n    adadelta_update_op_builder.Attr<float>(\"learning_rate_scale\", optimizer_conf.lr_scale());\n  }\n  SetDynamicLossScaleSkipIf(ctx, &adadelta_update_op_builder);\n  const auto adadelta_update_op = adadelta_update_op_builder.Build();\n  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {adadelta_update_op.op_conf()});\n}\n\n}  // namespace\n\nREGISTER_OPTIMIZER(OptimizerConf::kAdadeltaConf, &GenerateAdadeltaOptimizerOpConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/adagrad_optm.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/job/initializer_conf.pb.h\"\n#include \"oneflow/core/job/job_builder.h\"\n#include \"oneflow/core/job/job_conf.pb.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job_rewriter/optimizer.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/operator/variable_op.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::string GenVariableOutputLbn(const OperatorConf& op_conf) {\n  CHECK(op_conf.has_variable_conf());\n  return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out());\n}\n\nOperatorConf GenerateAdagradHelperVariableConf(const VariableOp& op, const std::string& name,\n                                               const float initial_value) {\n  OperatorConf helper_variable_op(op.op_conf());\n  helper_variable_op.set_name(op.op_name() + \"-\" + name);\n  helper_variable_op.mutable_variable_conf()->set_out(\"out\");\n  InitializerConf constant_initializer;\n  constant_initializer.mutable_constant_conf()->set_value(initial_value);\n  *(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer;\n  helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id());\n  return helper_variable_op;\n}\n\nvoid GenerateAdagradOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node,\n                                    const std::string& model_diff_lbn,\n                                    const OptimizerConf& optimizer_conf, JobBuilder* job_builder) {\n  const VariableOp* var_op = dynamic_cast<const VariableOp*>(&var_op_node.op());\n  CHECK_NOTNULL(var_op);\n\n  user_op::UserOpConfWrapperBuilder adagrad_update_op_builder(var_op->op_name() + \"_optimizer\");\n  float lr_decay = 0.0;\n  float initial_accumulator_value = 0.0;\n  float epsilon = 0.0;\n\n  const AdagradModelUpdateConf& adagrad_conf = optimizer_conf.adagrad_conf();\n  lr_decay = adagrad_conf.lr_decay();\n  initial_accumulator_value = adagrad_conf.initial_accumulator_value();\n  epsilon = adagrad_conf.epsilon();\n\n  const std::string& train_step_lbn = job_builder->job().job_conf().train_conf().train_step_lbn();\n  const std::string& learning_rate_lbn = optimizer_conf.learning_rate_lbn();\n\n  OperatorConf sum_var(\n      GenerateAdagradHelperVariableConf(*var_op, \"sum\", initial_accumulator_value));\n  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {sum_var});\n\n  adagrad_update_op_builder.OpTypeName(\"adagrad_update\")\n      .Input(\"model\", GenLogicalBlobName(var_op->BnInOp2Lbi(\"out\")))\n      .Input(\"model_diff\", model_diff_lbn)\n      .Input(\"learning_rate\", learning_rate_lbn)\n      .Input(\"train_step\", train_step_lbn)\n      .Input(\"sum\", GenVariableOutputLbn(sum_var))\n      .Attr<float>(\"epsilon\", epsilon)\n      
.Attr<float>(\"lr_decay\", lr_decay)\n      .Attr<float>(\"weight_decay\", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))\n      .ScopeSymbolId(var_op->op_conf().scope_symbol_id());\n  if (optimizer_conf.has_lr_scale()) {\n    adagrad_update_op_builder.Attr<float>(\"learning_rate_scale\", optimizer_conf.lr_scale());\n  }\n  SetDynamicLossScaleSkipIf(ctx, &adagrad_update_op_builder);\n  const auto adagrad_update_op = adagrad_update_op_builder.Build();\n  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {adagrad_update_op.op_conf()});\n}\n\n}  // namespace\n\nREGISTER_OPTIMIZER(OptimizerConf::kAdagradConf, &GenerateAdagradOptimizerOpConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/adam_optm.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/optimizer.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nstruct BiasCorrectionFactorCacheKey {\n  float beta = 1.0;\n  ParallelConf parallel_conf;\n};\n\nbool operator==(const BiasCorrectionFactorCacheKey& lhs, const BiasCorrectionFactorCacheKey& rhs) {\n  return (lhs.beta == rhs.beta) && (lhs.parallel_conf == rhs.parallel_conf);\n}\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::BiasCorrectionFactorCacheKey> {\n  size_t operator()(const oneflow::BiasCorrectionFactorCacheKey& key) const {\n    using namespace oneflow;\n    return Hash(key.beta, key.parallel_conf);\n  }\n};\n\n}  // namespace std\n\nnamespace oneflow {\n\nclass BiasCorrectionFactorState final : public JobPassState {\n public:\n  BiasCorrectionFactorState() {}\n  ~BiasCorrectionFactorState() override = default;\n\n  std::string GetLbn(float beta, std::string bias_correction_name, ParallelConf parallel_conf,\n                     const std::function<std::string(float beta_val, std::string op_name)>&\n                         BiasCorrectionFactorStateOp) {\n    BiasCorrectionFactorCacheKey cache_key;\n    cache_key.beta = beta;\n    cache_key.parallel_conf = parallel_conf;\n    const auto& iter = key2lbn_.find(cache_key);\n    if (iter != key2lbn_.end()) {\n      return iter->second;\n    } else {\n      std::string lbn = BiasCorrectionFactorStateOp(beta, std::move(bias_correction_name));\n      key2lbn_.emplace(cache_key, lbn);\n      return lbn;\n    }\n  }\n\n private:\n  HashMap<BiasCorrectionFactorCacheKey, std::string> key2lbn_;\n};\n\nnamespace {\n\nstd::string GenVariableOutputLbn(const OperatorConf& op_conf) {\n  CHECK(op_conf.has_variable_conf());\n  return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out());\n}\n\nOperatorConf GenerateAdamHelperVariableOpConf(const VariableOp& op, const std::string& name,\n                                              const float initial_value) {\n  OperatorConf helper_variable_op(op.op_conf());\n  helper_variable_op.set_name(op.op_name() + \"-\" + name);\n  helper_variable_op.mutable_variable_conf()->set_out(\"out\");\n  InitializerConf constant_initializer;\n  constant_initializer.mutable_constant_conf()->set_value(initial_value);\n  *(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer;\n  helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id());\n  return helper_variable_op;\n}\n\nvoid GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node,\n                             const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf,\n                             JobBuilder* job_builder) {\n  const VariableOp* var_op = dynamic_cast<const VariableOp*>(&var_op_node.op());\n  CHECK_NOTNULL(var_op);\n\n  user_op::UserOpConfWrapperBuilder adam_update_op_builder(var_op->op_name() + 
\"_optimizer\");\n  float beta1 = 0.9;\n  float beta2 = 0.999;\n  float epsilon = 1e-8;\n  bool do_bias_correction = true;\n  bool amsgrad = false;\n  if (optimizer_conf.has_adam_conf()) {\n    const AdamModelUpdateConf& adam_conf = optimizer_conf.adam_conf();\n    beta1 = adam_conf.beta1();\n    beta2 = adam_conf.beta2();\n    epsilon = adam_conf.epsilon();\n    do_bias_correction = adam_conf.do_bias_correction();\n    amsgrad = adam_conf.amsgrad();\n  } else if (optimizer_conf.has_lazy_adam_conf()) {\n    const LazyAdamModelUpdateConf& lazy_adam_conf = optimizer_conf.lazy_adam_conf();\n    beta1 = lazy_adam_conf.beta1();\n    beta2 = lazy_adam_conf.beta2();\n    epsilon = lazy_adam_conf.epsilon();\n    do_bias_correction = lazy_adam_conf.do_bias_correction();\n    amsgrad = lazy_adam_conf.amsgrad();\n  } else {\n    UNIMPLEMENTED();\n  }\n  OperatorConf m_var(GenerateAdamHelperVariableOpConf(*var_op, \"m\", 0.f));\n  OperatorConf v_var(GenerateAdamHelperVariableOpConf(*var_op, \"v\", 0.f));\n  OperatorConf max_v_var{};\n  if (amsgrad) {\n    max_v_var = GenerateAdamHelperVariableOpConf(*var_op, \"max_v\", 0.f);\n    job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {m_var, v_var, max_v_var});\n  } else {\n    job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {m_var, v_var});\n  }\n\n  const std::string& train_step_lbn = job_builder->job().job_conf().train_conf().train_step_lbn();\n  const std::string& learning_rate_lbn = optimizer_conf.learning_rate_lbn();\n\n  adam_update_op_builder.OpTypeName(\"adam_update\")\n      .Input(\"model\", GenLogicalBlobName(var_op->BnInOp2Lbi(\"out\")))\n      .Input(\"model_diff\", model_diff_lbn)\n      .Input(\"learning_rate\", learning_rate_lbn)\n      .Input(\"m\", GenVariableOutputLbn(m_var))\n      .Input(\"v\", GenVariableOutputLbn(v_var))\n      .Attr<float>(\"beta1\", beta1)\n      .Attr<float>(\"beta2\", beta2)\n      .Attr<float>(\"epsilon\", epsilon)\n      .Attr<float>(\"weight_decay\", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))\n      .Attr<bool>(\"amsgrad\", amsgrad)\n      .Attr<bool>(\"do_bias_correction\", do_bias_correction)\n      .ScopeSymbolId(var_op->op_conf().scope_symbol_id());\n  if (do_bias_correction) {\n    const std::string& job_pass_state_key = \"adam_bias_correction_factor\";\n    const bool has_state = CHECK_JUST(ctx->HasState<BiasCorrectionFactorState>(job_pass_state_key));\n    if (!has_state) {\n      CHECK_JUST(\n          ctx->ResetState(job_pass_state_key, std::make_unique<BiasCorrectionFactorState>()));\n    }\n    auto* state = CHECK_JUST(ctx->MutableState<BiasCorrectionFactorState>(job_pass_state_key));\n    ParallelConf bias_correction_parallel_conf;\n    const auto& lr_parallel_conf =\n        CHECK_JUST(job_builder->ParallelConf4Lbi(GenLogicalBlobId(learning_rate_lbn)));\n    const auto& train_step_parallel_conf =\n        CHECK_JUST(job_builder->ParallelConf4Lbi(GenLogicalBlobId(train_step_lbn)));\n    if (lr_parallel_conf == train_step_parallel_conf) {\n      bias_correction_parallel_conf = lr_parallel_conf;\n    } else {\n      bias_correction_parallel_conf = var_op_node.parallel_desc().parallel_conf();\n    }\n    auto AddAdamBiasCorrectionFactorOp = [&](float beta_val,\n                                             const std::string& op_name) -> std::string {\n      user_op::UserOpConfWrapperBuilder op_builder(var_op->op_name() + op_name);\n      const auto adam_bias_correction_factor_op =\n          op_builder.OpTypeName(\"adam_bias_correction_factor\")\n             
 .Input(\"train_step\", train_step_lbn)\n              .Attr<float>(\"beta\", beta_val)\n              .Output(\"out\")\n              .ScopeSymbolId(var_op->op_conf().scope_symbol_id())\n              .Build();\n\n      job_builder->AddOps(bias_correction_parallel_conf,\n                          {adam_bias_correction_factor_op.op_conf()});\n      return adam_bias_correction_factor_op.output(\"out\", 0);\n    };\n    const std::string bias_correction1_lbn =\n        state->GetLbn(beta1, \"adam_bias_correction_factor1\", bias_correction_parallel_conf,\n                      AddAdamBiasCorrectionFactorOp);\n    const std::string bias_correction2_lbn =\n        state->GetLbn(beta2, \"adam_bias_correction_factor2\", bias_correction_parallel_conf,\n                      AddAdamBiasCorrectionFactorOp);\n    adam_update_op_builder.Input(\"bias_correction1\", bias_correction1_lbn)\n        .Input(\"bias_correction2\", bias_correction2_lbn);\n  }\n  if (amsgrad) { adam_update_op_builder.Input(\"max_v\", GenVariableOutputLbn(max_v_var)); }\n  if (optimizer_conf.has_lr_scale()) {\n    adam_update_op_builder.Attr<float>(\"learning_rate_scale\", optimizer_conf.lr_scale());\n  }\n\n  SetDynamicLossScaleSkipIf(ctx, &adam_update_op_builder);\n  const auto adam_update_op = adam_update_op_builder.Build();\n  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {adam_update_op.op_conf()});\n}\n\n}  // namespace\n\nREGISTER_OPTIMIZER(OptimizerConf::kAdamConf, &GenerateOptimizerOpConf);\nREGISTER_OPTIMIZER(OptimizerConf::kLazyAdamConf, &GenerateOptimizerOpConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/add_ssp_variable_proxy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/job_rewriter/calculation_pass.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass AddSspVariableProxyPass final : public JobPass {\n public:\n  AddSspVariableProxyPass(const AddSspVariableProxyPass&) = delete;\n  AddSspVariableProxyPass(AddSspVariableProxyPass&&) = delete;\n  AddSspVariableProxyPass() = default;\n  ~AddSspVariableProxyPass() = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return ctx.job_desc().IsTrain() && ctx.job_desc().Bool(\"enable_ssp\");\n  }\n\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {\n    HashMap<LogicalBlobId, std::pair<std::string, std::string>> var2ref_value_pair;\n    HashSet<OpNode*> var_consumers;\n    HashSet<std::string> trainable_variable_op_names;\n    const Job& job = job_builder->job();\n    for (const auto& optimizer_conf : job.job_conf().train_conf().optimizer_conf()) {\n      for (const auto& variable_op_name : optimizer_conf.variable_op_names()) {\n        trainable_variable_op_names.insert(variable_op_name);\n      }\n    }\n    auto IsTrainableVarOp = [&](const OperatorConf& op_conf) {\n      if (!op_conf.has_variable_conf()) { return false; }\n      return trainable_variable_op_names.count(op_conf.name()) > 0;\n    };\n    JUST(ForEachTrainableVarOpNode(op_graph, IsTrainableVarOp, [&](OpNode* op_node) -> Maybe<void> {\n      op_node->ForEachNodeOnOutEdge([&](OpNode* consumer) { var_consumers.insert(consumer); });\n      const auto& old_var_out_lbi = op_node->op().BnInOp2Lbi(\"out\");\n      return AddSspVarProxyOp(op_node, job_builder, &var2ref_value_pair[old_var_out_lbi].first,\n                              &var2ref_value_pair[old_var_out_lbi].second);\n    }));\n    {\n      const auto& NeedReplace = [&](const LogicalBlobId& var_lbi) -> bool {\n        return var2ref_value_pair.count(var_lbi) > 0;\n      };\n      const auto& Ref4Var = [&](const LogicalBlobId& var_lbi) -> const std::string& {\n        return var2ref_value_pair.at(var_lbi).first;\n      };\n      const auto& Val4Var = [&](const LogicalBlobId& var_lbi) -> const std::string& {\n        return var2ref_value_pair.at(var_lbi).second;\n      };\n      for (OpNode* op_node : var_consumers) {\n        JUST(ReplaceVarWithSspVarProxyOp(op_node, job_builder, NeedReplace, Ref4Var, Val4Var));\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> ForEachTrainableVarOpNode(\n      const OpGraph& op_graph, const 
std::function<bool(const OperatorConf&)>& IsTrainableVarOp,\n      const std::function<Maybe<void>(OpNode*)>& DoEach) const {\n    const auto& IsSspVarProxy = [](const OperatorConf& op_conf) {\n      return op_conf.has_user_conf() && op_conf.user_conf().op_type_name() == \"ssp_variable_proxy\";\n    };\n    JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {\n      const auto& op_conf = op_node->op().op_conf();\n      CHECK_OR_RETURN(!IsSspVarProxy(op_conf)) << \"AddSspVariableProxy can not be applied twice\";\n      if (IsTrainableVarOp(op_conf)) { return DoEach(op_node); }\n      return Maybe<void>::Ok();\n    }));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> AddSspVarProxyOp(OpNode* op_node, JobBuilder* job_builder, std::string* ref_lbn,\n                               std::string* value_lbn) const {\n    const LogicalBlobId& old_var_out_lbi = op_node->op().BnInOp2Lbi(\"out\");\n    int64_t scope_symbol_id = op_node->op().op_conf().scope_symbol_id();\n    JUST(AddSspVarProxyOp(old_var_out_lbi, scope_symbol_id, job_builder, ref_lbn, value_lbn));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> ReplaceVarWithSspVarProxyOp(\n      OpNode* op_node, JobBuilder* job_builder,\n      const std::function<bool(const LogicalBlobId&)>& NeedReplace,\n      const std::function<const std::string&(const LogicalBlobId&)>& Ref4Var,\n      const std::function<const std::string&(const LogicalBlobId&)>& Val4Var) const {\n    const auto& op = op_node->op();\n    std::unique_ptr<std::vector<OperatorConf>> new_op_confs;\n    for (const auto& ibn : op.input_bns()) {\n      const auto& lbi = op.BnInOp2Lbi(ibn);\n      if (!NeedReplace(lbi)) { continue; }\n      if (!new_op_confs) { new_op_confs.reset(new std::vector<OperatorConf>({op.op_conf()})); }\n      auto* new_op_conf = &new_op_confs->at(0);\n      int64_t scope_symbol_id = op.op_conf().scope_symbol_id();\n      bool in_optimizer_pass = JUST(IsInOptimizerPass(scope_symbol_id));\n      const auto* lbn = (in_optimizer_pass ? 
&Ref4Var(lbi) : &Val4Var(lbi));\n      ReplaceInputLbnInOpCustomizedConf(new_op_conf, ibn, *lbn);\n    }\n    if (new_op_confs) { job_builder->MutOpsOnlyOnce(*new_op_confs); }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<bool> IsInOptimizerPass(int64_t scope_symbol_id) const {\n    const auto& scope = JUST(Singleton<symbol::Storage<Scope>>::Get()->MaybeGet(scope_symbol_id));\n    return scope.scope_proto().calculation_pass_name() == kOptimizerPass;\n  }\n\n  Maybe<void> AddSspVarProxyOp(const LogicalBlobId& old_var_out_lbi, int64_t scope_symbol_id,\n                               JobBuilder* job_builder, std::string* ref_lbn,\n                               std::string* value_lbn) const {\n    const Scope& scope = JUST(Singleton<symbol::Storage<Scope>>::Get()->MaybeGet(scope_symbol_id));\n    int64_t buffer_size = 0;\n    {\n      int64_t num_stages = scope.Int64(\"ssp_num_stages\");\n      int64_t stage_id = scope.Int64(\"ssp_stage_id\");\n      CHECK_GT(num_stages, 0);\n      CHECK_GE(stage_id, 0);\n      CHECK_LT(stage_id, num_stages);\n      buffer_size = num_stages - stage_id;\n    }\n    std::string op_name = old_var_out_lbi.op_name() + \"_ssp_variable_proxy\";\n    const auto proxy_op = user_op::UserOpConfWrapperBuilder(op_name)\n                              .Op(\"ssp_variable_proxy\")\n                              .ScopeSymbolId(scope_symbol_id)\n                              .Input(\"var\", GenLogicalBlobName(old_var_out_lbi))\n                              .Output(\"ref\")\n                              .Output(\"value\")\n                              .Attr<int64_t>(\"buffer_size\", buffer_size)\n                              .Build();\n    const auto& parallel_desc = *JUST(scope.GetParallelDesc(proxy_op.op_conf()));\n    job_builder->AddOps(parallel_desc.parallel_conf(), {proxy_op.op_conf()});\n    *ref_lbn = op_name + \"/ref_0\";\n    *value_lbn = op_name + \"/value_0\";\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_JOB_PASS(\"AddSspVariableProxy\", AddSspVariableProxyPass);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/auto_learning_rate.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass AutoLearningRate final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AutoLearningRate);\n  AutoLearningRate() = default;\n  ~AutoLearningRate() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().IsTrain(); }\n\n  Maybe<void> Apply(const OpGraph& op_graph, Job* job) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    return Apply(op_graph, job);\n  }\n};\n\nMaybe<void> AutoLearningRate::Apply(const OpGraph& op_graph, Job* job) const {\n  JobBuilder job_builder(job);\n  const TrainConf& train_conf = job->job_conf().train_conf();\n  auto AddScheduleOp = [&](const OptimizerConf& optimizer_conf,\n                           const std::string& op_name) -> std::string {\n    const class oneflow::OpNode* op_node =\n        op_graph.OpNode4OpName(GenLogicalBlobId(train_conf.train_step_lbn()).op_name());\n    CHECK_OR_RETURN(op_node != nullptr) << \"op node not found in op graph, op name: \" << op_name;\n    const ParallelConf& parallel_conf = op_node->parallel_desc().parallel_conf();\n    OperatorConf schedule_op_conf{};\n    schedule_op_conf.set_name(op_name);\n    auto* schedule_conf = schedule_op_conf.mutable_learning_rate_schedule_conf();\n    schedule_conf->set_train_step(train_conf.train_step_lbn());\n    schedule_conf->set_learning_rate(optimizer_conf.base_learning_rate());\n    schedule_conf->set_out(\"out\");\n    if (optimizer_conf.has_learning_rate_decay()) {\n      *schedule_conf->mutable_learning_rate_decay() = optimizer_conf.learning_rate_decay();\n    }\n    schedule_op_conf.set_scope_symbol_id(op_node->op().op_conf().scope_symbol_id());\n    job_builder.AddOps(parallel_conf, {schedule_op_conf});\n    return GenLogicalBlobName(op_name, schedule_conf->out());\n  };\n  FOR_RANGE(int64_t, i, 0, train_conf.optimizer_conf_size()) {\n    const auto& optimizer_conf = train_conf.optimizer_conf(i);\n    const std::string& lbn =\n        AddScheduleOp(optimizer_conf, \"System-Train-LearningRate-Scheduler_\" + NewUniqueId());\n    job->mutable_job_conf()->mutable_train_conf()->mutable_optimizer_conf(i)->set_learning_rate_lbn(\n        lbn);\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_JOB_PASS(\"AutoLearningRate\", AutoLearningRate);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/auto_mixed_precision.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/job_rewriter/auto_mixed_precision.h\"\n#include \"oneflow/core/job_rewriter/auto_mixed_precision_lists.h\"\n\n#include <algorithm>\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job_rewriter/pass_util.h\"\n#include \"oneflow/core/job/job_desc.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nvoid VerifyAMPList(const AMPList& amp_list) {\n  for (const auto& op_type : amp_list) {\n    CHECK(user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type) != nullptr)\n        << \"Cannot find \" << op_type << \" of AutoMixedPrecision list in OpRegistry.\";\n  }\n}\n\nusing NoCastRegistry = std::multimap<std::string, OpArg>;\n\nNoCastRegistry* GetNoCastRegistry() {\n  static NoCastRegistry s_registry;\n  return &s_registry;\n}\n\nbool FindInNoCastRegisry(const std::string& op_type, const OpArg& op_arg) {\n  auto range = GetNoCastRegistry()->equal_range(op_type);\n  for (auto it = range.first; it != range.second; ++it) {\n    if (it->second == op_arg) { return true; }\n  }\n  return false;\n}\n\nstd::function<bool(OpNode*)> MakePredicatorIsAllowedToRunWithHalf(const OpGraph& op_graph) {\n  auto allowed_set = std::make_shared<HashSet<OpNode*>>();\n  op_graph.ForEachNode([&](OpNode* node) {\n    // half computation is not supported on cpu\n    if (node->parallel_desc().device_type() == DeviceType::kCPU) { return; }\n    if (node->op().output_bns().size() > 0\n        || IsUserOpWithTypeName(node->op().op_conf(), \"one_embedding_fused_lookup_grad\")) {\n      INSERT_CHECK(allowed_set->insert(node));\n    }\n  });\n  return [allowed_set](OpNode* node) -> bool { return IsKeyFound(*allowed_set, node); };\n}\n\nvoid InsertCastOpImpl(bool f2h, const OpGraph& op_graph, const HashSet<OpNode*>& white_set,\n                      const DataType mixed_precision_data_type, JobBuilder* job_builder) {\n  HashSet<OpEdge*> white_set_edges;\n  {\n    std::function<const std::unordered_set<OpEdge*>&(OpNode*)> Node2Edges =\n        f2h ? &OpNode::in_edges : &OpNode::out_edges;\n    std::function<OpNode*(OpEdge*)> OppositeNode = f2h ? 
&OpEdge::src_node : &OpEdge::dst_node;\n    op_graph.ForEachNode([&](OpNode* node) {\n      if (IsKeyFound(white_set, node)) {\n        for (OpEdge* edge : Node2Edges(node)) {\n          if (!IsKeyFound(white_set, OppositeNode(edge))) {\n            INSERT_CHECK(white_set_edges.insert(edge));\n          }\n        }\n      }\n    });\n    auto EdgeName4Edge = [](OpEdge* const& edge) {\n      return std::string(\"edge of\\t\") + edge->src_node()->op().op_name() + \"\\tto\\t\"\n             + edge->dst_node()->op().op_name();\n    };\n    VLOG(3) << \"white_set_edges for f2h value: \" << f2h << \" is \"\n            << Container2Str<HashSet<OpEdge*>, OpEdge*>(white_set_edges, EdgeName4Edge);\n  }\n\n  HashMap<std::string, std::vector<OpEdge*>> edges_group_by_lbn;\n  {\n    for (OpEdge* edge : white_set_edges) {\n      for (const auto& lbi : edge->lbis()) {\n        std::string lbn = GenLogicalBlobName(lbi);\n        edges_group_by_lbn[lbn].emplace_back(edge);\n      }\n    }\n  }\n\n  HashMap<std::string, OperatorConf> dst_op_name2dst_op_confs;\n  for (auto& pair : edges_group_by_lbn) {\n    const std::string& lbn = pair.first;\n    LogicalBlobId cur_lbi = GenLogicalBlobId(lbn);\n    OpNode* src_node = pair.second.front()->src_node();\n\n    const BlobDesc& blob_desc = src_node->LogicalBlobDesc4Lbi(cur_lbi);\n    if (blob_desc.data_type() != DataType::kFloat) { continue; }\n\n    std::string cast_suffix = f2h ? \"-cast_f2h\" : \"-cast_h2f\";\n    DataType cast_data_type = f2h ? mixed_precision_data_type : DataType::kFloat;\n    auto cast_op = user_op::UserOpConfWrapperBuilder(ReplaceSlashToDash4Lbn(lbn) + cast_suffix)\n                       .Op(\"cast\")\n                       .Input(\"in\", lbn)\n                       .Output(\"out\")\n                       .Attr<DataType>(\"dtype\", cast_data_type)\n                       .ScopeSymbolId(src_node->op().op_conf().scope_symbol_id())\n                       .Build();\n\n    bool cast_is_consumed = false;\n    for (OpEdge* edge : pair.second) {\n      CHECK(src_node == edge->src_node());\n      OpNode* dst_node = edge->dst_node();\n      const auto& dst_ibns = edge->lbi2ibns().at(cur_lbi);\n      for (const auto& dst_ibn : dst_ibns) {\n        if (dst_node->op().op_conf().has_user_conf()) {\n          const std::string& op_type = dst_node->op().op_conf().user_conf().op_type_name();\n          const auto& op_arg = GenUnRepeatedBn(dst_ibn);\n          if (FindInNoCastRegisry(op_type, op_arg)) { continue; }\n        }\n        cast_is_consumed = true;\n        const std::string& dst_op_name = dst_node->op().op_name();\n        if (!IsKeyFound(dst_op_name2dst_op_confs, dst_op_name)) {\n          INSERT_CHECK(dst_op_name2dst_op_confs.insert(\n              std::make_pair(dst_op_name, dst_node->op().op_conf())));\n        }\n        OperatorConf& dst_op_conf = dst_op_name2dst_op_confs.at(dst_op_name);\n        std::string new_lbn = cast_op.op_name() + \"/out_0\";\n        CHECK_EQ(lbn, ReplaceInputLbnInOpCustomizedConf(&dst_op_conf, dst_ibn, new_lbn));\n      }\n    }\n\n    if (cast_is_consumed) {\n      job_builder->AddOps(src_node->parallel_desc().parallel_conf(),\n                          std::vector<OperatorConf>{cast_op.op_conf()});\n      VLOG(3) << \"Insert CastOp: \" << cast_op.op_name() << \" between \" << lbn;\n    }\n  }\n\n  std::vector<OperatorConf> dst_op_confs;\n  dst_op_confs.reserve(dst_op_name2dst_op_confs.size());\n  for (const auto& pair : dst_op_name2dst_op_confs) { dst_op_confs.emplace_back(pair.second); }\n  // make sure 
an op_conf can only be updated once, because a later update would override the earlier one\n  job_builder->MutOpsOnlyOnce(dst_op_confs);\n}\n\nclass AutoMixedPrecision final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AutoMixedPrecision);\n  AutoMixedPrecision()\n      : white_list_(AutoMixedPrecisionLists::WhiteList()),\n        black_list_(AutoMixedPrecisionLists::BlackList()),\n        gray_list_(AutoMixedPrecisionLists::GrayList()),\n        clear_list_(AutoMixedPrecisionLists::ClearList()) {}\n  ~AutoMixedPrecision() = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n#if defined(WITH_CUDA) && defined(CUDA_VERSION) && CUDA_VERSION < 10000\n    return false;\n#else\n    return ctx.job_desc().enable_auto_mixed_precision();\n#endif\n  }\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;\n\n private:\n  void FillBlackSet(const OpGraph& op_graph, HashSet<OpNode*>* black_set) const;\n  void FillWhiteSet(const OpGraph& op_graph, std::function<bool(OpNode*)> IsAllowedToRunWithHalf,\n                    const HashSet<OpNode*>& black_set, HashSet<OpNode*>* white_set) const;\n  void PropagateWhiteThroughClearNodes(const OpGraph& op_graph,\n                                       std::function<bool(OpNode*)> IsAllowedToRunWithHalf,\n                                       const HashSet<OpNode*>& black_set,\n                                       HashSet<OpNode*>* white_set) const;\n  void InsertCastOp(const OpGraph& op_graph, const HashSet<OpNode*>& white_set,\n                    const DataType mixed_precision_data_type, JobBuilder* job_builder) const;\n\n  const AMPList& white_list_;\n  const AMPList& black_list_;\n  const AMPList& gray_list_;\n  const AMPList& clear_list_;\n};\n\nMaybe<void> AutoMixedPrecision::Apply(Job* job, JobPassCtx* ctx) const {\n  if (!ctx->job_desc().enable_auto_mixed_precision()) { return Maybe<void>::Ok(); }\n  const OpGraph op_graph(*job);\n  JobBuilder job_builder(job);\n  CHECK(GlobalJobDesc().DefaultDataType() == DataType::kFloat);\n\n  VerifyAMPList(white_list_);\n  VerifyAMPList(black_list_);\n  VerifyAMPList(gray_list_);\n  VerifyAMPList(clear_list_);\n\n  std::function<std::string(OpNode* const&)> OpName4Node = [](OpNode* const& node) {\n    return node->op().op_name();\n  };\n  HashSet<OpNode*> black_set;\n  HashSet<OpNode*> white_set;\n\n  FillBlackSet(op_graph, &black_set);\n  VLOG(3) << \"BlackSet includes: \"\n          << Container2Str<HashSet<OpNode*>, OpNode*>(black_set, OpName4Node);\n\n  auto IsAllowedToRunWithHalf = MakePredicatorIsAllowedToRunWithHalf(op_graph);\n  FillWhiteSet(op_graph, IsAllowedToRunWithHalf, black_set, &white_set);\n  VLOG(3) << \"WhiteSet before propagation includes: \"\n          << Container2Str<HashSet<OpNode*>, OpNode*>(white_set, OpName4Node);\n  PropagateWhiteThroughClearNodes(op_graph, IsAllowedToRunWithHalf, black_set, &white_set);\n  VLOG(2) << \"WhiteSet includes: \"\n          << Container2Str<HashSet<OpNode*>, OpNode*>(white_set, OpName4Node);\n  const DataType mixed_precision_data_type = ctx->job_desc().mixed_precision_data_type();\n  CHECK(mixed_precision_data_type == DataType::kFloat16\n        || mixed_precision_data_type == DataType::kBFloat16);\n  InsertCastOp(op_graph, white_set, mixed_precision_data_type, &job_builder);\n  return Maybe<void>::Ok();\n}\n\nvoid AutoMixedPrecision::FillBlackSet(const OpGraph& op_graph, HashSet<OpNode*>* black_set) const {\n  HashSet<OpNode*> upstream_or_part_of_black_and_gray;\n  DfsTopoGraphTraversal(\n      op_graph, true,\n      [&](OpNode* node) {\n        
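// seed the traversal with ops on the black or gray list\n        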
return IsNodeInList(black_list_, node) || IsNodeInList(gray_list_, node);\n      },\n      [&](OpNode* node) { return IsNodeInList(clear_list_, node); },\n      [&](OpNode* node) { return IsKeyFound(upstream_or_part_of_black_and_gray, node); },\n      [&](OpNode* node) {\n        INSERT_CHECK(upstream_or_part_of_black_and_gray.insert(node));\n        VLOG(3) << \"FillBlackSet(): Insert \" << node->op().op_name()\n                << \" to upstream_or_part_of_black_and_gray\";\n      });\n\n  // propagate black through upstream_or_part_of_black_and_gray\n  DfsTopoGraphTraversal(\n      op_graph, false, [&](OpNode* node) { return IsNodeInList(black_list_, node); },\n      [&](OpNode* node) { return IsKeyFound(upstream_or_part_of_black_and_gray, node); },\n      [&](OpNode* node) { return IsKeyFound(*black_set, node); },\n      [&](OpNode* node) {\n        INSERT_CHECK(black_set->insert(node));\n        VLOG(3) << \"FillBlackSet(): Insert \" << node->op().op_name() << \" to black_set\";\n      });\n}\n\nvoid AutoMixedPrecision::FillWhiteSet(const OpGraph& op_graph,\n                                      std::function<bool(OpNode*)> IsAllowedToRunWithHalf,\n                                      const HashSet<OpNode*>& black_set,\n                                      HashSet<OpNode*>* white_set) const {\n  auto IsWhiteOrSinkAndAllowedToRunHalf = [&](OpNode* node) {\n    return IsAllowedToRunWithHalf(node)\n           && (IsNodeInList(white_list_, node)\n               || (node->out_edges().empty()\n                   && (IsNodeInList(gray_list_, node) || IsNodeInList(clear_list_, node))));\n  };\n  HashSet<OpNode*> upstream_or_part_of_white;\n  DfsTopoGraphTraversal(\n      op_graph, true, IsWhiteOrSinkAndAllowedToRunHalf,\n      [&](OpNode* node) {\n        return !IsKeyFound(black_set, node) && IsAllowedToRunWithHalf(node)\n               && (IsNodeInList(gray_list_, node) || IsNodeInList(clear_list_, node));\n      },\n      [&](OpNode* node) { return IsKeyFound(upstream_or_part_of_white, node); },\n      [&](OpNode* node) {\n        INSERT_CHECK(upstream_or_part_of_white.insert(node));\n        VLOG(3) << \"FillWhiteSet(): Insert \" << node->op().op_name()\n                << \" to upstream_or_part_of_white\";\n      });\n\n  auto IsWhiteAndAllowedToRunHalf = [&](OpNode* node) {\n    return IsAllowedToRunWithHalf(node) && IsNodeInList(white_list_, node);\n  };\n  DfsTopoGraphTraversal(\n      op_graph, false, IsWhiteAndAllowedToRunHalf,\n      [&](OpNode* node) { return IsKeyFound(upstream_or_part_of_white, node); },\n      [&](OpNode* node) { return IsKeyFound(*white_set, node); },\n      [&](OpNode* node) {\n        INSERT_CHECK(white_set->insert(node));\n        VLOG(3) << \"FillWhiteSet(): Insert \" << node->op().op_name() << \" to white_set\";\n      });\n}\n\nvoid AutoMixedPrecision::PropagateWhiteThroughClearNodes(\n    const OpGraph& op_graph, std::function<bool(OpNode*)> IsAllowedToRunWithHalf,\n    const HashSet<OpNode*>& black_set, HashSet<OpNode*>* white_set) const {\n  auto PropagateIntoOneDirection = [&](bool is_downward) {\n    DfsTopoGraphTraversal(\n        op_graph, !is_downward, [&](OpNode* node) { return false; },\n        [&](OpNode* node) {\n          return !IsKeyFound(*white_set, node) && !IsKeyFound(black_set, node)\n                 && IsNodeInList(clear_list_, node) && IsAllowedToRunWithHalf(node);\n        },\n        [&](OpNode* node) { return IsKeyFound(*white_set, node); },\n        [&](OpNode* node) {\n          INSERT_CHECK(white_set->insert(node));\n        
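  // this clear-list op has just joined the white set\n        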
  VLOG(3) << \"PropagateWhiteThroughNonListNodes(): Insert \" << node->op().op_name()\n                  << \" to white_set\";\n        });\n  };\n  PropagateIntoOneDirection(true);\n  PropagateIntoOneDirection(false);\n}\n\nvoid AutoMixedPrecision::InsertCastOp(const OpGraph& op_graph, const HashSet<OpNode*>& white_set,\n                                      const DataType mixed_precision_data_type,\n                                      JobBuilder* job_builder) const {\n  InsertCastOpImpl(true, op_graph, white_set, mixed_precision_data_type, job_builder);\n  InsertCastOpImpl(false, op_graph, white_set, mixed_precision_data_type, job_builder);\n}\n\nREGISTER_JOB_PASS(\"AutoMixedPrecision\", AutoMixedPrecision);\n\n}  // namespace\n\nnamespace {\n\nstruct NoCastRegistrar final {\n  NoCastRegistrar(const std::string& op_type, OpArg&& op_arg) {\n    auto* registry = GetNoCastRegistry();\n    registry->emplace(std::make_pair(op_type, std::move(op_arg)));\n  }\n  ~NoCastRegistrar() = default;\n};\n\n#define REGISTER_NO_CAST_REGISTRY(op_type, input_arg_name, idx)       \\\n  static NoCastRegistrar OF_PP_CAT(g_registrar, __COUNTER__)(op_type, \\\n                                                             std::make_pair(input_arg_name, idx));\n\n// For Example:\n// REGISTER_NO_CAST_REGISTRY(\"matmul\", \"b\", 0);\n\nREGISTER_NO_CAST_REGISTRY(\"normalization\", \"moving_mean\", 0)\nREGISTER_NO_CAST_REGISTRY(\"normalization\", \"moving_variance\", 0)\nREGISTER_NO_CAST_REGISTRY(\"normalization\", \"gamma\", 0)\nREGISTER_NO_CAST_REGISTRY(\"normalization\", \"beta\", 0)\n\nREGISTER_NO_CAST_REGISTRY(\"normalization_grad\", \"gamma\", 0)\n\nREGISTER_NO_CAST_REGISTRY(\"normalization_add_relu\", \"moving_mean\", 0)\nREGISTER_NO_CAST_REGISTRY(\"normalization_add_relu\", \"moving_variance\", 0)\nREGISTER_NO_CAST_REGISTRY(\"normalization_add_relu\", \"gamma\", 0)\nREGISTER_NO_CAST_REGISTRY(\"normalization_add_relu\", \"beta\", 0)\n\nREGISTER_NO_CAST_REGISTRY(\"normalization_add_relu_grad\", \"gamma\", 0)\nREGISTER_NO_CAST_REGISTRY(\"normalization_add_relu_grad\", \"beta\", 0)\nREGISTER_NO_CAST_REGISTRY(\"normalization_add_relu_grad\", \"mean\", 0)\nREGISTER_NO_CAST_REGISTRY(\"normalization_add_relu_grad\", \"inv_variance\", 0)\nREGISTER_NO_CAST_REGISTRY(\"normalization_add_relu_grad\", \"reserve_space\", 0)\n\nREGISTER_NO_CAST_REGISTRY(\"layer_norm_grad\", \"mean\", 0)\nREGISTER_NO_CAST_REGISTRY(\"layer_norm_grad\", \"inv_variance\", 0)\nREGISTER_NO_CAST_REGISTRY(\"layer_norm_param_grad\", \"mean\", 0)\nREGISTER_NO_CAST_REGISTRY(\"layer_norm_param_grad\", \"inv_variance\", 0)\nREGISTER_NO_CAST_REGISTRY(\"fuse_layer_norm_grad\", \"mean\", 0)\nREGISTER_NO_CAST_REGISTRY(\"fuse_layer_norm_grad\", \"inv_variance\", 0)\n\n}  // namespace\n\nnamespace amp {\n\nbool IsNoCast(const std::string& op_type, const OpArg& op_arg) {\n  return FindInNoCastRegisry(op_type, op_arg);\n}\n\n}  // namespace amp\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/auto_mixed_precision.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_REWRITER_AUTO_MIXED_PRECISION_H_\n#define ONEFLOW_CORE_JOB_REWRITER_AUTO_MIXED_PRECISION_H_\n\n#include <string>\n\nnamespace oneflow {\n\nusing OpArg = std::pair<std::string, int32_t>;\n\nnamespace amp {\n\nbool IsNoCast(const std::string& op_type, const OpArg& op_arg);\n\n}  // namespace amp\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_REWRITER_AUTO_MIXED_PRECISION_H_\n"
  },
  {
    "path": "oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/auto_mixed_precision_lists.h\"\n\nnamespace oneflow {\n\nconst AMPList& AutoMixedPrecisionLists::WhiteList() {\n  static AMPList white_list = {\"matmul\",\n                               \"batch_matmul\",\n                               \"conv2d\",\n                               \"conv_data_grad\",\n                               \"conv_filter_grad\",\n                               \"conv_bias_grad\",\n                               \"amp_white_identity\",\n                               \"broadcast_matmul\",\n                               \"broadcast_matmul_grad_b\",\n                               \"fused_self_attention_query_mul_key_and_value\",\n                               \"fused_self_attention_query_mul_key_and_value_grad\",\n                               \"prelu\",\n                               \"prelu_grad\",\n                               \"tf_prelu\",\n                               \"tf_prelu_grad\",\n                               \"cublas_fused_mlp\",\n                               \"cublas_fused_mlp_grad\",\n                               \"fused_matmul_bias\",\n                               \"cublas_bias_add_relu_matmul_grad\",\n                               \"fused_glu\",\n                               \"fused_glu_without_linear_grad\",\n                               \"fused_matmul_bias_add_relu_dropout\",\n                               \"fused_relu_dropout_grad\",\n                               \"fused_dot_feature_interaction\",\n                               \"fused_dot_feature_interaction_grad\",\n                               \"one_embedding_fused_lookup\",\n                               \"one_embedding_fused_lookup_grad\",\n                               \"binary_cross_entropy_with_logits_reduce_mean\",\n                               \"binary_cross_entropy_with_logits_reduce_mean_grad\",\n                               \"fused_cross_feature_interaction\",\n                               \"fused_cross_feature_interaction_v1_grad\",\n                               \"fused_cross_feature_interaction_v2_grad\",\n                               \"fused_multi_head_attention_inference\",\n                               \"grouped_matmul_bias\"};\n  return white_list;\n}\n\nconst AMPList& AutoMixedPrecisionLists::BlackList() {\n  // TODO(niuchong): reduce_mean?\n  static AMPList black_list = {\"amp_black_identity\"};\n  return black_list;\n}\n\nconst AMPList& AutoMixedPrecisionLists::GrayList() {\n  static AMPList gray_list = {\"add_n\",\n                              \"tf_avg_pool_1d\",\n                              \"tf_avg_pool_1d_grad\",\n                              \"tf_avg_pool_2d\",\n                              \"tf_avg_pool_2d_grad\",\n                              \"tf_avg_pool_3d\",\n                              \"tf_avg_pool_3d_grad\",\n                              \"avg_pool_1d\",\n            
                  \"avg_pool_1d_grad\",\n                              \"avg_pool_2d\",\n                              \"avg_pool_2d_grad\",\n                              \"avg_pool_3d\",\n                              \"avg_pool_3d_grad\",\n                              \"bias_add\",\n                              \"reduce_sum\",\n                              \"reduce_sum_like\",\n                              \"sigmoid_grad\",\n                              \"tanh\",\n                              \"tanh_grad\",\n                              \"sqrt\",\n                              \"sqrt_grad\",\n                              \"scalar_mul\",\n                              \"scalar_mul_by_tensor\",\n                              \"scalar_add\",\n                              \"scalar_div\",\n                              \"scalar_pow\",\n                              \"broadcast_add\",\n                              \"broadcast_sub\",\n                              \"broadcast_mul\",\n                              \"broadcast_div\",\n                              \"layer_norm\",\n                              \"layer_norm_param_grad\",\n                              \"layer_norm_grad\",\n                              \"fuse_layer_norm_grad\",\n                              \"skip_layer_norm\",\n                              \"rms_norm\",\n                              \"rms_norm_grad\",\n                              \"rms_norm_param_grad\",\n                              \"dropout\",\n                              \"dropout_grad\",\n                              \"softmax\",\n                              \"softmax_grad\",\n                              \"log_softmax\",\n                              \"log_softmax_grad\",\n                              \"gelu\",\n                              \"gelu_grad\",\n                              \"fast_gelu\",\n                              \"fast_gelu_grad\",\n                              \"normalization\",\n                              \"normalization_grad\",\n                              \"normalization_add_relu\",\n                              \"normalization_add_relu_grad\",\n                              \"sparse_softmax_cross_entropy\",\n                              \"sparse_softmax_cross_entropy_grad\",\n                              \"nll\",\n                              \"nll_grad\",\n                              \"fused_tril_scale_softmax_mask_scale\",\n                              \"fused_tril_scale_softmax_mask_scale_grad\",\n                              \"fused_scale_mask_softmax_dropout\",\n                              \"fused_scale_mask_softmax_dropout_grad\",\n                              \"fused_scale_mask_softmax\",\n                              \"fused_scale_mask_softmax_grad\",\n                              \"fused_bias_add_scale_mask_softmax_dropout\",\n                              \"fused_bias_add_gelu\",\n                              \"fused_bias_add_gelu_grad\",\n                              \"fused_bias_add_mask_scale\",\n                              \"fused_fast_gelu_mul\",\n                              \"fused_fast_gelu_mul_grad\",\n                              \"acc\",\n                              \"reciprocal\",\n                              \"reciprocal_no_nan\",\n                              \"group_norm\",\n                              \"group_norm_param_grad\",\n                              \"group_norm_grad\",\n                              \"silu\",\n                              
\"silu_grad\",\n                              \"fused_weighted_sum\"};\n  return gray_list;\n}\n\nconst AMPList& AutoMixedPrecisionLists::ClearList() {\n  // TODO(niuchong): tuple_identity\n  static AMPList clear_list = {\"broadcast_like\",\n                               \"gather\",\n                               \"gather_nd\",\n                               \"scatter_nd\",\n                               \"scatter_nd_like\",\n                               \"unsorted_segment_sum_like\",\n                               \"tf_max_pool_1d\",\n                               \"tf_max_pool_1d_grad\",\n                               \"tf_max_pool_2d\",\n                               \"tf_max_pool_2d_grad\",\n                               \"tf_max_pool_3d\",\n                               \"tf_max_pool_3d_grad\",\n                               \"max_pool_1d\",\n                               \"max_pool_1d_grad\",\n                               \"max_pool_2d\",\n                               \"max_pool_2d_grad\",\n                               \"max_pool_3d\",\n                               \"max_pool_3d_grad\",\n                               \"reshape\",\n                               \"reshape_like\",\n                               \"relu\",\n                               \"relu_grad\",\n                               \"transpose\",\n                               \"random_mask_like\",\n                               \"cat\",\n                               \"split_like\",\n                               \"pad\",\n                               \"same_padding\",\n                               \"same_padding_grad\",\n                               \"tril\",\n                               \"slice\",\n                               \"slice_grad\",\n                               \"fused_scale_tril\",\n                               \"identity\",\n                               \"squeeze\",\n                               \"embedding\",\n                               \"embedding_grad\",\n                               \"expand\",\n                               \"expand_dims\",\n                               \"cast_to_static_shape\",\n                               \"parallel_cast\",\n                               \"hierarchical_parallel_cast\",\n                               \"hierarchical_parallel_cast_like\",\n                               \"repeat\",\n                               \"unpack\",\n                               \"pack\",\n                               \"nvtx_start\",\n                               \"nvtx_end\",\n                               \"narrow\",\n                               \"narrow_grad\",\n                               \"ones_like\",\n                               \"pinned_identity\",\n                               \"to_contiguous\",\n                               \"copy\",\n                               \"where\",\n                               \"upsample_nearest_2d\",\n                               \"fill_\"};\n\n  return clear_list;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/auto_mixed_precision_lists.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_REWRITER_AUTO_MIXED_PRECISION_LISTS_H_\n#define ONEFLOW_CORE_JOB_REWRITER_AUTO_MIXED_PRECISION_LISTS_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/operator/op_conf_util.h\"\n\nnamespace oneflow {\n\ntypedef HashSet<std::string> AMPList;\n\nclass AutoMixedPrecisionLists final {\n public:\n  // TODO(niuchong): list include grad\n  static const AMPList& WhiteList();\n  static const AMPList& BlackList();\n  static const AMPList& GrayList();\n  static const AMPList& ClearList();\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_REWRITER_AUTO_MIXED_PRECISION_LISTS_H_\n"
  },
  {
    "path": "oneflow/core/job_rewriter/auto_parallel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <chrono>\n#include \"oneflow/core/common/hash_container.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/auto_parallel/sbp_constructor.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass AutoParallelPass final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AutoParallelPass);\n  AutoParallelPass() = default;\n  ~AutoParallelPass() override = default;\n\n  Maybe<void> Apply(const OpGraph& op_graph, Job* job) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!job->job_conf().enable_auto_parallel()) { return Maybe<void>::Ok(); }\n    VLOG(3) << \"=== Enable AutoParallel ===\";\n    if (job->job_conf().enable_auto_parallel_ignore_user_sbp_config()) {\n      JUST(RemoveParallelCastOps(job));\n    }\n    const OpGraph op_graph(*job);\n    return Apply(op_graph, job);\n  }\n\n private:\n  Maybe<void> RemoveParallelCastOps(Job* job) const;\n};\n\nMaybe<void> AutoParallelPass::Apply(const OpGraph& op_graph, Job* job) const {\n  // auto-parallel\n  LOG(INFO) << \"Start Auto Parallel\";\n  auto time_begin = std::chrono::high_resolution_clock::now();\n\n  auto_parallel::SbpConstructor sbp_constructor(op_graph, job);\n  JUST(sbp_constructor.FindBestSbpSignature());\n  JUST(sbp_constructor.DumpNdSbpSignatureForJob(op_graph, job));\n  auto time_end = std::chrono::high_resolution_clock::now();\n  VLOG(2) << \"Auto parallel took \"\n          << std::chrono::duration_cast<std::chrono::milliseconds>(time_end - time_begin).count()\n          << \" ms\\n\";\n  if (GlobalProcessCtx::Rank() == 0) {\n    // sbp_constructor.PrintSBPGraphDebugInfo();\n    JUST(sbp_constructor.CheckSbpAgreement(*job));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_JOB_PASS(\"AutoParallelPass\", AutoParallelPass);\n\nMaybe<void> AutoParallelPass::RemoveParallelCastOps(Job* job) const {\n  VLOG(3) << \"Remove parallel cast ops for auto_parallel:\";\n  const OpGraph op_graph(*job);\n  JobBuilder job_builder(job);\n  HashMap<std::string, OperatorConf> op_name2op_conf;\n  HashMap<std::string, NdSbpSignature> op_name2nd_sbp_signature;\n  HashSet<std::string> ctrl_in_op_names;\n  op_graph.ForEachNode([&](const OpNode* op_node) {\n    for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) {\n      ctrl_in_op_names.insert(ctrl_in_op_name);\n    }\n  });\n  const auto IsParallelCastOp = [](const OperatorConf& op_conf) -> bool {\n    return op_conf.has_user_conf()\n           && (op_conf.user_conf().op_type_name() == \"parallel_cast\"\n               || op_conf.user_conf().op_type_name() == \"hierarchical_parallel_cast\"\n               || op_conf.user_conf().op_type_name() == 
\"hierarchical_parallel_cast_like\");\n  };\n  std::vector<std::string> del_op_names;\n  HashSet<std::string> del_op_name_set;\n  std::function<void(const OpNode*)> Try2Delete = [&](const OpNode* op_node) {\n    if (del_op_name_set.find(op_node->op().op_name()) != del_op_name_set.end()) { return; }\n    const OperatorConf& op_conf = op_node->op().op_conf();\n    if (!IsParallelCastOp(op_conf)) { return; }\n    if (!op_conf.ctrl_in_op_name().empty()) {\n      VLOG(3) << \"Skip \" << op_conf.name() << \", because it has ctrl edge.\";\n      return;\n    }\n    if (ctrl_in_op_names.find(op_conf.name()) != ctrl_in_op_names.end()) {\n      VLOG(3) << \"Skip \" << op_conf.name() << \", because it is a ctrl edge.\";\n      return;\n    }\n    if (op_node->in_edges().size() != 1) { return; }\n\n    // Find the first op which won't be deleted\n    const OpNode* source_op = op_node;\n    const OpNode* producer = op_node->SoleInEdge()->src_node();\n    while (IsParallelCastOp(producer->op().op_conf())) {\n      Try2Delete(producer);\n      if (del_op_name_set.find(producer->op().op_name()) == del_op_name_set.end()) { break; }\n      source_op = producer;\n      producer = source_op->SoleInEdge()->src_node();\n    }\n    user_op::UserOpConfWrapper conf_wrapper_in(source_op->op().op_conf());\n    const LogicalBlobId& parallel_cast_in_lbi = GenLogicalBlobId(conf_wrapper_in.input(\"in\", 0));\n\n    user_op::UserOpConfWrapper conf_wrapper_out(op_conf);\n    const LogicalBlobId& parallel_cast_out_lbi =\n        GenLogicalBlobId(conf_wrapper_out.output(\"out\", 0));\n    if (op_node->parallel_desc() != producer->parallel_desc()) {\n      VLOG(3) << \"Skip \" << op_node->op().op_name() << \"(with placement: \"\n              << *CHECK_JUST(PlacementToString(SymbolOf(op_node->parallel_desc())))\n              << \"), because producer \" << producer->op().op_name() << \"'s placement is \"\n              << *CHECK_JUST(PlacementToString(SymbolOf(producer->parallel_desc())));\n      return;\n    }\n    for (const OpEdge* out_edge : op_node->out_edges()) {\n      const OpNode* consumer = out_edge->dst_node();\n      if (consumer->parallel_desc() != op_node->parallel_desc()) {\n        VLOG(3) << \"Skip \" << op_node->op().op_name() << \"(with placement: \"\n                << *CHECK_JUST(PlacementToString(SymbolOf(op_node->parallel_desc())))\n                << \"), because consumer \" << consumer->op().op_name() << \"'s placement is \"\n                << *CHECK_JUST(PlacementToString(SymbolOf(consumer->parallel_desc())));\n        return;\n      }\n    }\n    op_name2nd_sbp_signature[producer->op().op_name()] = producer->nd_sbp_signature();\n    for (const OpEdge* out_edge : op_node->out_edges()) {\n      const OpNode* consumer = out_edge->dst_node();\n      const std::string& consumer_op_name = consumer->op().op_name();\n      op_name2nd_sbp_signature[consumer_op_name] = consumer->nd_sbp_signature();\n      if (op_name2op_conf.find(consumer_op_name) == op_name2op_conf.end()) {\n        op_name2op_conf[consumer_op_name] = consumer->op().op_conf();\n      }\n      OperatorConf& consumer_op_conf = op_name2op_conf.at(consumer_op_name);\n      for (const std::string& ibn : consumer->op().input_bns()) {\n        if (consumer->op().BnInOp2Lbi(ibn) == parallel_cast_out_lbi) {\n          const auto& new_val = GenLogicalBlobName(parallel_cast_in_lbi);\n          const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, new_val);\n          CHECK_EQ(GenLogicalBlobName(parallel_cast_out_lbi), 
old_val);\n        }\n      }\n    }\n    del_op_names.emplace_back(op_conf.name());\n    del_op_name_set.insert(op_conf.name());\n    VLOG(3) << \"\\tremove \" << op_conf.name();\n  };\n  op_graph.ForEachNode(Try2Delete);\n  for (const auto& pair : op_name2op_conf) { job_builder.MutOpsOnlyOnce({pair.second}); }\n  for (const auto& pair : op_name2nd_sbp_signature) {\n    job_builder.AddNdSbpSignature4OpName(pair.first, pair.second);\n  }\n  job_builder.DelOps(del_op_names);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/auto_train_step.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job_rewriter/dynamic_loss_scale_job_pass_state.h\"\n#include \"oneflow/core/framework/scope_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass AutoTrainStep final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AutoTrainStep);\n  AutoTrainStep() = default;\n  ~AutoTrainStep() override = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;\n};\n\nMaybe<void> AutoTrainStep::Apply(Job* job, JobPassCtx* ctx) const {\n  if (!ctx->job_desc().IsTrain()) { return Maybe<void>::Ok(); }\n  const OpGraph op_graph(*job);\n  const TrainConf& train_conf = job->job_conf().train_conf();\n  if (train_conf.has_train_step_lbn()) {\n    CHECK_OR_RETURN(!train_conf.has_dynamic_loss_scale_policy());\n    return Maybe<void>::Ok();\n  }\n  OperatorConf variable_op_conf{};\n  const std::string train_step_name = \"System-Train-TrainStep\";\n  variable_op_conf.set_name(train_step_name);\n  VariableOpConf* variable_conf = variable_op_conf.mutable_variable_conf();\n  variable_conf->set_out(\"out\");\n  *variable_conf->mutable_shape()->mutable_dim()->Add() = 1;\n  variable_conf->set_data_type(DataType::kInt64);\n  variable_conf->mutable_initializer()->mutable_constant_int_conf()->set_value(0);\n\n  OperatorConf identity_op_conf{};\n  identity_op_conf.set_name(train_step_name + \"-Identity\");\n  IdentityOpConf* identity_conf = identity_op_conf.mutable_identity_conf();\n  identity_conf->set_in(GenLogicalBlobName(variable_op_conf.name(), variable_conf->out()));\n  identity_conf->set_out(\"out\");\n  const std::string& train_step_lbn =\n      GenLogicalBlobName(identity_op_conf.name(), identity_conf->out());\n\n  JobBuilder job_builder(job);\n  ParallelConf parallel_conf;\n  if (ParseBooleanFromEnv(\"ONEFLOW_GRAPH_PLACE_TRAINING_STATE_ON_ALL_RANKS\", false)) {\n    parallel_conf = GenParallelConfOfCpuOnAllRanks();\n\n  } else {\n    parallel_conf = GenParallelConfOfCpuZeroOnMaster();\n  }\n  int64_t scope_symbol_id = 0;\n  {\n    const auto& opt_scope_symbol_id =\n        JUST(MakeInitialScope(job->job_conf(), SymbolOf(ParallelDesc(parallel_conf)),\n                              /* is_local */ false))\n            ->symbol_id();\n    CHECK_OR_RETURN(opt_scope_symbol_id.has_value())\n        << Error::RuntimeError() << \"symbol_id not initialized\";\n    scope_symbol_id = JUST(opt_scope_symbol_id);\n  }\n\n  auto scalar_add_op = user_op::UserOpConfWrapperBuilder(train_step_name + \"-ScalarAdd\")\n                           .Op(\"scalar_add\")\n                           .Input(\"in\", train_step_lbn)\n                           .Output(\"out\")\n                           .Attr<bool>(\"has_float_operand\", false)\n                           .Attr<double>(\"float_operand\", 0)\n                           
.Attr<bool>(\"has_int_operand\", true)\n                           .Attr<int64_t>(\"int_operand\", 1)\n                           .ScopeSymbolId(scope_symbol_id)\n                           .Build();\n\n  variable_op_conf.set_scope_symbol_id(scope_symbol_id);\n  identity_op_conf.set_scope_symbol_id(scope_symbol_id);\n  job_builder.AddOps(parallel_conf, {variable_op_conf, identity_op_conf, scalar_add_op.op_conf()});\n  if (train_conf.has_dynamic_loss_scale_policy()) {\n    const auto& dynamic_loss_scale_state =\n        JUST(ctx->GetState<DynamicLossScaleJobPassState>(\"dynamic_loss_scale_state\"));\n    auto assign_op =\n        user_op::UserOpConfWrapperBuilder(train_step_name + \"-AssignIfNot\")\n            .Op(\"assign_if_not\")\n            .Input(\"ref\", GenLogicalBlobName(variable_op_conf.name(), variable_conf->out()))\n            .Input(\"value\", scalar_add_op.output(\"out\", 0))\n            .Input(\"condition\", dynamic_loss_scale_state.count_not_finite_lbn())\n            .ScopeSymbolId(scope_symbol_id)\n            .Build();\n    job_builder.AddOps(parallel_conf, {assign_op.op_conf()});\n  } else {\n    auto assign_op =\n        user_op::UserOpConfWrapperBuilder(train_step_name + \"-Assign\")\n            .Op(\"assign\")\n            .Input(\"ref\", GenLogicalBlobName(variable_op_conf.name(), variable_conf->out()))\n            .Input(\"value\", scalar_add_op.output(\"out\", 0))\n            .ScopeSymbolId(scope_symbol_id)\n            .Build();\n    job_builder.AddOps(parallel_conf, {assign_op.op_conf()});\n  }\n\n  job->mutable_job_conf()->mutable_train_conf()->set_train_step_lbn(train_step_lbn);\n  return Maybe<void>::Ok();\n}\n\nREGISTER_JOB_PASS(\"AutoTrainStep\", AutoTrainStep);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/autograd.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/autograd.h\"\n#include \"oneflow/core/job/job_builder.h\"\n#include \"oneflow/core/job_rewriter/clone_grad.h\"\n#include \"oneflow/core/operator/variable_op.h\"\n#include \"oneflow/core/register/op_blob_arg.pb.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job_rewriter/dynamic_loss_scale_job_pass_state.h\"\n#include \"oneflow/core/framework/scope_util.h\"\n#include \"oneflow/core/job_rewriter/clip_by_global_norm_job_pass_state.h\"\n#include \"oneflow/core/job_rewriter/pass_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconst TrainConf& GetTrainConf() { return GlobalJobDesc().job_conf().train_conf(); }\n\nint64_t ScopeSymbolId4Lbi(const OpGraph& op_graph, const LogicalBlobId& lbi) {\n  return op_graph.OpNode4OpName(lbi.op_name())->op().op_conf().scope_symbol_id();\n}\n\nbool AnyLbiWithDiffLbi(const OpEdge* op_edge) {\n  const Operator& src_op = op_edge->src_node()->op();\n  const Operator& dst_op = op_edge->dst_node()->op();\n  auto IsOutputBlobModifierRequiresGrad = [&](const LogicalBlobId& lbi) {\n    return src_op.OutputBlobModifier4Obn(op_edge->lbi2obn().at(lbi)).requires_grad();\n  };\n  auto IsInputBlobModifierRequiresGrad = [&](const LogicalBlobId& lbi) {\n    const auto& ibns = op_edge->lbi2ibns().at(lbi);\n    for (const std::string& ibn : ibns) {\n      if (dst_op.InputBlobModifier4Ibn(ibn).requires_grad()) { return true; }\n    }\n    CHECK_GT(ibns.size(), 0);\n    return false;\n  };\n  for (const LogicalBlobId& lbi : op_edge->lbis()) {\n    if (IsOutputBlobModifierRequiresGrad(lbi) && IsInputBlobModifierRequiresGrad(lbi)) {\n      return true;\n    }\n  }\n  CHECK_GT(op_edge->lbis().size(), 0);\n  return false;\n}\n\nvoid CheckNotReachableAmongOpNodes(const OpGraph& op_graph, const std::list<OpNode*>& op_nodes) {\n  auto IsReachable = op_graph.MakePredicatorIsReachable();\n  for (OpNode* src_node : op_nodes) {\n    for (OpNode* dst_node : op_nodes) {\n      if (src_node == dst_node) { continue; }\n      CHECK(!IsReachable(src_node, dst_node));\n    }\n  }\n}\n\nMaybe<void> GetLossOpNodes(const OpGraph& op_graph, std::list<OpNode*>* loss_op_nodes) {\n  const auto& train_conf = GetTrainConf();\n  HashSet<std::string> loss_op_names;\n  for (const std::string& loss_lbn : train_conf.loss_lbn()) {\n    loss_op_names.emplace(GenLogicalBlobId(loss_lbn).op_name());\n  }\n  op_graph.ForEachNode([&](OpNode* op_node) {\n    if (loss_op_names.find(op_node->op().op_name()) != loss_op_names.end()) {\n      loss_op_nodes->emplace_back(op_node);\n    }\n  });\n  if (loss_op_nodes->empty()) { return Error::LossBlobNotFoundError() << \"Loss blob not found.\"; }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> 
GetLossOpNodesAndAscendants(const OpGraph& op_graph, HashSet<OpNode*>* op_nodes) {\n  std::list<OpNode*> starts;\n  JUST(GetLossOpNodes(op_graph, &starts));\n  auto ForEachNextNode = [&](OpNode* op_node, const std::function<void(OpNode*)>& Handler) {\n    for (OpEdge* edge : op_node->in_edges()) {\n      if (AnyLbiWithDiffLbi(edge)) { Handler(edge->src_node()); }\n    }\n  };\n  op_graph.BfsForEachNode(starts, ForEachNextNode,\n                          [&](OpNode* op_node) { op_nodes->emplace(op_node); });\n  return Maybe<void>::Ok();\n}\n\nconst ParallelConf& ProducerParallelConf4Lbi(const OpGraph& op_graph, const LogicalBlobId& lbi) {\n  return op_graph.OpNode4OpName(lbi.op_name())->parallel_desc().parallel_conf();\n}\n\nvoid ScaleModelDiffByConstantLossInstanceNum(const OpGraph& op_graph, JobBuilder* job_builder,\n                                             HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi,\n                                             const int64_t loss_instance_num) {\n  if (loss_instance_num == 1) { return; }\n  const float scale_factor = 1.0f / static_cast<float>(loss_instance_num);\n  for (auto& pair : *lbi2diff_lbi) {\n    const LogicalBlobId& lbi = pair.first;\n    LogicalBlobId& diff_lbi = pair.second;\n    auto scalar_mul_op =\n        user_op::UserOpConfWrapperBuilder(\"Sys-DiffScale-ScalarMul-\" + lbi.op_name() + \"_\"\n                                          + lbi.blob_name() + \"-\" + NewUniqueId())\n            .Op(\"scalar_mul\")\n            .Input(\"in\", GenLogicalBlobName(diff_lbi))\n            .Output(\"out\")\n            .Attr<bool>(\"has_float_operand\", true)\n            .Attr<double>(\"float_operand\", scale_factor)\n            .Attr<bool>(\"has_int_operand\", false)\n            .Attr<int64_t>(\"int_operand\", 0)\n            .ScopeSymbolId(ScopeSymbolId4Lbi(op_graph, lbi))\n            .Build();\n    job_builder->AddOps(ProducerParallelConf4Lbi(op_graph, lbi), {scalar_mul_op.op_conf()});\n    diff_lbi = GenLogicalBlobId(scalar_mul_op.output(\"out\", 0));\n  }\n}\n\nMaybe<void> TryLocalCastTotalLossInstanceNum(\n    JobBuilder* job_builder, const HashMap<LogicalBlobId, OpNode*>& loss_lbi2loss_node,\n    LogicalBlobId* total_loss_instance_num_lbi) {\n  auto IsLocal4Lbi = [](const LogicalBlobId& lbi, OpNode* op_node) -> Maybe<bool> {\n    const auto& obn = *JUST(op_node->op().obn4lbi(lbi));\n    const auto& opt_local_parallel = *JUST(op_node->op().OptLocalParallel4BnInOp(obn));\n    return opt_local_parallel.has_local_parallel();\n  };\n  const auto& begin = *loss_lbi2loss_node.begin();\n  bool is_local = JUST(IsLocal4Lbi(begin.first, begin.second));\n  for (const auto& pair : loss_lbi2loss_node) {\n    bool is_other_local = JUST(IsLocal4Lbi(pair.first, pair.second));\n    CHECK_EQ_OR_RETURN(is_local, is_other_local);  // NOLINT\n  }\n  if (is_local) {\n    OperatorConf op_conf;\n    op_conf.set_name(\"System-Cast-Local-TotalLossInstanceNum\" + NewUniqueId());\n    CastFromLocalOpConf* cast_from_local = op_conf.mutable_cast_from_local_conf();\n    cast_from_local->set_in(GenLogicalBlobName(*total_loss_instance_num_lbi));\n    cast_from_local->set_out(\"out\");\n    cast_from_local->mutable_sbp_parallel()->mutable_partial_sum_parallel();\n    const auto& parallel_conf = JUST(job_builder->ParallelConf4Lbi(*total_loss_instance_num_lbi));\n    int64_t scope_symbol_id = 0;\n    {\n      const auto& opt_scope_symbol_id = JUST(MakeInitialScope(job_builder->job().job_conf(),\n                                                              
SymbolOf(ParallelDesc(parallel_conf)),\n                                                              /* is_local */ false))\n                                            ->symbol_id();\n      CHECK_OR_RETURN(opt_scope_symbol_id.has_value())\n          << Error::RuntimeError() << \"symbol_id not initialized\";\n      scope_symbol_id = JUST(opt_scope_symbol_id);\n    }\n    op_conf.set_scope_symbol_id(scope_symbol_id);\n    job_builder->AddOps(parallel_conf, {op_conf});\n    total_loss_instance_num_lbi->set_op_name(op_conf.name());\n    total_loss_instance_num_lbi->set_blob_name(\"out\");\n  }\n  return Maybe<void>::Ok();\n}\n\nvoid ScaleModelDiffByDynamicLossInstanceNum(\n    const OpGraph& op_graph, JobBuilder* job_builder,\n    HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi,\n    const HashMap<LogicalBlobId, OpNode*>& loss_lbi2loss_node) {\n  auto BuildInstanceNumOpConf4LossOpNode = [&](const LogicalBlobId& loss_lbi, const OpNode* op_node,\n                                               LogicalBlobId* lbi) {\n    OperatorConf instance_num_op;\n    instance_num_op.set_name(\"System-Autograd-\" + loss_lbi.op_name() + \"-\" + loss_lbi.blob_name()\n                             + \"-LossInstanceNum\");\n    auto* instance_num_op_conf = instance_num_op.mutable_shape_elem_cnt_conf();\n    instance_num_op_conf->set_x(GenLogicalBlobName(loss_lbi));\n    instance_num_op_conf->set_y(\"y\");\n    instance_num_op_conf->set_data_type(op_node->LogicalBlobDesc4Lbi(loss_lbi).data_type());\n    instance_num_op_conf->mutable_include_axis_conf();\n    instance_num_op.set_scope_symbol_id(op_node->op().op_conf().scope_symbol_id());\n    job_builder->AddOps(op_node->parallel_desc().parallel_conf(), {instance_num_op});\n    lbi->set_op_name(instance_num_op.name());\n    lbi->set_blob_name(\"y\");\n  };\n  LogicalBlobId total_loss_instance_num_lbi;\n  if (loss_lbi2loss_node.size() == 1) {\n    const auto& pair_it = loss_lbi2loss_node.begin();\n    BuildInstanceNumOpConf4LossOpNode(pair_it->first, pair_it->second,\n                                      &total_loss_instance_num_lbi);\n  } else if (loss_lbi2loss_node.size() > 1) {\n    OperatorConf op_conf;\n    op_conf.set_name(\"System-Autograd-total_loss_instance_num\");\n    TotalLossInstanceNumOpConf* total_loss_instance_num_conf =\n        op_conf.mutable_total_loss_instance_num_conf();\n    for (const auto& pair : loss_lbi2loss_node) {\n      LogicalBlobId loss_instance_num_lbi;\n      BuildInstanceNumOpConf4LossOpNode(pair.first, pair.second, &loss_instance_num_lbi);\n      total_loss_instance_num_conf->add_in(GenLogicalBlobName(loss_instance_num_lbi));\n    }\n    total_loss_instance_num_conf->set_out(\"out\");\n\n    ParallelConf parallel_conf;\n    parallel_conf.set_device_tag(\"cpu\");\n    parallel_conf.add_device_name(\"0:0\");\n    int64_t scope_symbol_id = 0;\n    {\n      const auto& opt_scope_symbol_id =\n          CHECK_JUST(MakeInitialScope(job_builder->job().job_conf(),\n                                      SymbolOf(ParallelDesc(parallel_conf)),\n                                      /* is_local */ false))\n              ->symbol_id();\n      if (!opt_scope_symbol_id.has_value()) { THROW(RuntimeError) << \"symbol_id not initialized\"; }\n      scope_symbol_id = CHECK_JUST(opt_scope_symbol_id);\n    }\n    op_conf.set_scope_symbol_id(scope_symbol_id);\n    job_builder->AddOps(parallel_conf, {op_conf});\n\n    total_loss_instance_num_lbi.set_op_name(op_conf.name());\n    total_loss_instance_num_lbi.set_blob_name(\"out\");\n  } else {\n    
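// no loss blob was found, so there is no instance count to build\n    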
UNIMPLEMENTED();\n  }\n  CHECK_JUST(TryLocalCastTotalLossInstanceNum(job_builder, loss_lbi2loss_node,\n                                              &total_loss_instance_num_lbi));\n  for (auto& pair : *lbi2diff_lbi) {\n    const LogicalBlobId& lbi = pair.first;\n    LogicalBlobId& diff_lbi = pair.second;\n    auto scalar_div_op =\n        user_op::UserOpConfWrapperBuilder(\"Sys-DiffScale-ScalarDiv-\" + lbi.op_name() + \"_\"\n                                          + lbi.blob_name() + \"-\" + NewUniqueId())\n            .Op(\"scalar_div_by_tensor\")\n            .Input(\"x\", GenLogicalBlobName(diff_lbi))\n            .Input(\"scalar\", GenLogicalBlobName(total_loss_instance_num_lbi))\n            .Output(\"y\")\n            .ScopeSymbolId(ScopeSymbolId4Lbi(op_graph, lbi))\n            .Build();\n    job_builder->AddOps(ProducerParallelConf4Lbi(op_graph, lbi), {scalar_div_op.op_conf()});\n    diff_lbi = GenLogicalBlobId(scalar_div_op.output(\"y\", 0));\n  }\n}\n\nbool AllSplitDistribution(const NdSbp& nd_sbp) {\n  for (int64_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) {\n    if (!nd_sbp.sbp_parallel(i).has_split_parallel()) { return false; }\n  }\n  return true;\n}\n\nvoid ForEachAggregatedParamGroup(\n    const OpGraph& op_graph, const HashMap<LogicalBlobId, LogicalBlobId>& lbi2diff_lbi,\n    const std::function<void(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp,\n                             const std::vector<LogicalBlobId>& lbis)>& Handler) {\n  HashMap<std::pair<ParallelDesc, NdSbp>, std::vector<LogicalBlobId>> group;\n  for (auto& pair : lbi2diff_lbi) {\n    const LogicalBlobId& lbi = pair.first;\n    const OpNode* model_op_node = op_graph.OpNode4OpName(lbi.op_name());\n    const ParallelDesc& parallel_desc = model_op_node->parallel_desc();\n    const NdSbp& nd_sbp = model_op_node->NdSbp4Lbi(lbi);\n    group[std::make_pair(parallel_desc, nd_sbp)].emplace_back(lbi);\n  }\n  for (const auto& pair : group) { Handler(pair.first.first, pair.first.second, pair.second); }\n}\n\nint64_t MakeScopeSymbolId(const JobConfigProto& job_conf, const ParallelConf& parallel_conf) {\n  const auto& opt_scope_symbol_id =\n      CHECK_JUST(MakeInitialScope(job_conf, SymbolOf(ParallelDesc(parallel_conf)),\n                                  /* is_local */ false))\n          ->symbol_id();\n  if (!opt_scope_symbol_id.has_value()) { THROW(RuntimeError) << \"symbol_id not initialized\"; }\n  return CHECK_JUST(opt_scope_symbol_id);\n}\n\nstd::string AddLbns(JobBuilder* job_builder, const std::vector<std::string>& lbns,\n                    const ParallelConf& parallel_conf, int64_t scope_symbol_id,\n                    const std::string& op_name_prefix) {\n  if (lbns.size() == 1) {\n    return lbns.front();\n  } else {\n    user_op::UserOpConfWrapperBuilder add_op_builder(op_name_prefix + NewUniqueId());\n    add_op_builder.Op(\"add_n\");\n    for (const std::string& lbn : lbns) { add_op_builder.Input(\"in\", lbn); }\n    const auto add_op = add_op_builder.Output(\"out\").ScopeSymbolId(scope_symbol_id).Build();\n    job_builder->AddOps(parallel_conf, {add_op.op_conf()});\n    return add_op.output(\"out\", 0);\n  }\n}\n\nstd::string AddParallelCast(JobBuilder* job_builder, const std::string& in_lbn,\n                            const std::string& sbp_str, const ParallelConf& parallel_conf,\n                            const std::string& op_name_prefix) {\n  ParallelConf flat_parallel_conf = parallel_conf;\n  
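// flatten the placement hierarchy to 1-D before inserting the cast\n  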
flat_parallel_conf.mutable_hierarchy()->clear_dim();\n  const int64_t scope_symbol_id =\n      MakeScopeSymbolId(job_builder->job().job_conf(), flat_parallel_conf);\n  std::vector<std::string> sbp = {sbp_str};\n  auto parallel_cast_op =\n      user_op::UserOpConfWrapperBuilder(op_name_prefix + NewUniqueId())\n          .Op(\"hierarchical_parallel_cast\")\n          .Input(\"in\", in_lbn)\n          .Output(\"out\")\n          .Attr<std::vector<std::string>>(\"nd_sbp\", sbp)\n          .Attr<std::string>(\"grad_mode\", \"auto\")\n          .Attr<std::vector<std::string>>(\"grad_nd_sbp\", std::vector<std::string>{})\n          .ScopeSymbolId(scope_symbol_id)\n          .Build();\n  job_builder->AddOps(flat_parallel_conf, {parallel_cast_op.op_conf()});\n  return parallel_cast_op.output(\"out\", 0);\n}\n\nbool IsBroadcast(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc) {\n  if (parallel_desc.parallel_num() == 1) { return true; }\n  for (int64_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) {\n    if (!nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { return false; }\n  }\n  return true;\n}\n\nbool HasSplit(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc) {\n  if (parallel_desc.parallel_num() == 1) { return false; }\n  for (const auto& sbp : nd_sbp.sbp_parallel()) {\n    if (sbp.has_split_parallel()) { return true; }\n  }\n  return false;\n}\n\nOperatorConf GenConstantLikeOp(const std::string& op_name, int64_t scope_symbol_id,\n                               const std::string& like_lbn, double value, DataType dtype) {\n  OperatorConf op_conf;\n  op_conf.set_name(op_name);\n  op_conf.set_scope_symbol_id(scope_symbol_id);\n  ConstantLikeOpConf* constant_like_conf = op_conf.mutable_constant_like_conf();\n  constant_like_conf->set_like(like_lbn);\n  if (dtype == DataType::kInt32) {\n    constant_like_conf->set_int_operand(static_cast<int32_t>(value));\n  } else if (dtype == DataType::kInt64) {\n    constant_like_conf->set_int_operand(static_cast<int64_t>(value));\n  } else if (dtype == DataType::kFloat) {\n    constant_like_conf->set_float_operand(static_cast<float>(value));\n  } else if (dtype == DataType::kDouble) {\n    constant_like_conf->set_float_operand(value);\n  } else {\n    UNIMPLEMENTED();\n  }\n  constant_like_conf->set_data_type(dtype);\n  constant_like_conf->set_out(\"out\");\n  return op_conf;\n}\n\nstd::string GlobalAbsMaxMin(const OpGraph& op_graph, JobBuilder* job_builder,\n                            const HashMap<LogicalBlobId, LogicalBlobId>& lbi2diff_lbi,\n                            bool max_or_min, ParallelConf* out_parallel_conf) {\n  // max(abs(x))\n  bool all_same_parallel_desc = true;\n  const ParallelDesc& any_parallel_desc =\n      op_graph.OpNode4OpName(lbi2diff_lbi.begin()->first.op_name())->parallel_desc();\n  std::vector<std::string> group_reduce_lbns;\n\n  auto GroupReduce = [&](const ParallelDesc& parallel_desc, const NdSbp& nd_sbp,\n                         const std::vector<LogicalBlobId>& lbis) {\n    if (!parallel_desc.EqualsIgnoringHierarchy(any_parallel_desc)) {\n      all_same_parallel_desc = false;\n    }\n    int64_t scope_symbol_id =\n        MakeScopeSymbolId(job_builder->job().job_conf(), parallel_desc.parallel_conf());\n    bool has_split = HasSplit(nd_sbp, parallel_desc);\n    if (job_builder->job().job_conf().enable_gradients_stats_aggregation()) {\n      std::string multi_reduce_op_type_name =\n          has_split ? (max_or_min ? \"local_multi_reduce_max_abs\" : \"local_multi_reduce_min_abs\")\n                    : (max_or_min ? 
\"multi_reduce_max_abs\" : \"multi_reduce_min_abs\");\n      std::string multi_reduce_op_name =\n          \"System-ClipGradient-GlobalNorm-MultiReduceXimumAbs-\" + NewUniqueId();\n      auto multi_reduce_op_builder = user_op::UserOpConfWrapperBuilder(multi_reduce_op_name)\n                                         .Op(multi_reduce_op_type_name)\n                                         .Output(\"y\")\n                                         .ScopeSymbolId(scope_symbol_id);\n      for (const auto& lbi : lbis) {\n        multi_reduce_op_builder.Input(\"x\", GenLogicalBlobName(lbi2diff_lbi.at(lbi)));\n      }\n      auto multi_reduce_op = multi_reduce_op_builder.Build();\n      job_builder->AddOps(parallel_desc.parallel_conf(), {multi_reduce_op.op_conf()});\n      if (has_split) {\n        std::string group_reduce_op_type_name = max_or_min ? \"reduce_max\" : \"reduce_min\";\n        std::string group_reduce_op_name =\n            \"System-ClipGradient-GlobalNorm-GroupReduceXimum-\" + NewUniqueId();\n        auto group_reduce_op = user_op::UserOpConfWrapperBuilder(group_reduce_op_name)\n                                   .Op(group_reduce_op_type_name)\n                                   .Input(\"input_tensor\", multi_reduce_op.output(\"y\", 0))\n                                   .Output(\"output_tensor\")\n                                   .Attr(\"axis\", std::vector<int32_t>{0})\n                                   .Attr(\"keepdims\", false)\n                                   .ScopeSymbolId(scope_symbol_id)\n                                   .Build();\n        job_builder->AddOps(parallel_desc.parallel_conf(), {group_reduce_op.op_conf()});\n        group_reduce_lbns.push_back(group_reduce_op.output(\"output_tensor\", 0));\n      } else {\n        group_reduce_lbns.push_back(multi_reduce_op.output(\"y\", 0));\n      }\n    } else {\n      UNIMPLEMENTED();\n    }\n  };\n  ForEachAggregatedParamGroup(op_graph, lbi2diff_lbi, GroupReduce);\n  CHECK_GT(group_reduce_lbns.size(), 0);\n\n  *out_parallel_conf = all_same_parallel_desc ? any_parallel_desc.parallel_conf()\n                                              : GenParallelConfOfCpuZeroOnMaster();\n  out_parallel_conf->mutable_hierarchy()->clear_dim();\n  if (group_reduce_lbns.size() == 1) {\n    return group_reduce_lbns[0];\n  } else {\n    // stack all group max and go on max\n    const int64_t scope_symbol_id =\n        MakeScopeSymbolId(job_builder->job().job_conf(), *out_parallel_conf);\n    auto stack_op_builder =\n        user_op::UserOpConfWrapperBuilder(\"System-ClipGradient-GlobalNorm-GlobalStack-\"\n                                          + NewUniqueId())\n            .Op(\"stack\")\n            .Output(\"out\")\n            .Attr(\"axis\", int64_t(0))\n            .Attr(\"max_dim_size\", static_cast<int64_t>(group_reduce_lbns.size()))\n            .ScopeSymbolId(scope_symbol_id);\n    for (const auto& lbn : group_reduce_lbns) { stack_op_builder.Input(\"in\", lbn); }\n    auto stack_op = stack_op_builder.Build();\n    job_builder->AddOps(*out_parallel_conf, {stack_op.op_conf()});\n\n    std::string reduce_op_type_name = max_or_min ? 
\"reduce_max\" : \"reduce_min\";\n    std::string reduce_op_name =\n        \"System-ClipGradient-GlobalNorm-GlobalReduceXimum-\" + NewUniqueId();\n    auto reduce_op = user_op::UserOpConfWrapperBuilder(reduce_op_name)\n                         .Op(reduce_op_type_name)\n                         .Input(\"input_tensor\", stack_op.output(\"out\", 0))\n                         .Output(\"output_tensor\")\n                         .Attr(\"axis\", std::vector<int32_t>{0})\n                         .Attr(\"keepdims\", false)\n                         .ScopeSymbolId(scope_symbol_id)\n                         .Build();\n    job_builder->AddOps(*out_parallel_conf, {reduce_op.op_conf()});\n    return reduce_op.output(\"output_tensor\", 0);\n  }\n}\n\nstd::string GlobalNorm(const OpGraph& op_graph, JobBuilder* job_builder,\n                       const HashMap<LogicalBlobId, LogicalBlobId>& lbi2diff_lbi, float p,\n                       ParallelConf* out_parallel_conf) {\n  bool all_same_parallel_desc = true;\n  const ParallelDesc& any_parallel_desc =\n      op_graph.OpNode4OpName(lbi2diff_lbi.begin()->first.op_name())->parallel_desc();\n  bool all_broadcast = true;\n  std::vector<std::string> group_lbns;\n  std::vector<ParallelConf> group_parallel_confs;\n  group_lbns.reserve(lbi2diff_lbi.size());\n  group_parallel_confs.reserve(lbi2diff_lbi.size());\n\n  auto GroupNorm = [&](const ParallelDesc& parallel_desc, const NdSbp& nd_sbp,\n                       const std::vector<LogicalBlobId>& lbis) {\n    if (!parallel_desc.EqualsIgnoringHierarchy(any_parallel_desc)) {\n      all_same_parallel_desc = false;\n    }\n    int64_t scope_symbol_id =\n        MakeScopeSymbolId(job_builder->job().job_conf(), parallel_desc.parallel_conf());\n    if (!IsBroadcast(nd_sbp, parallel_desc)) { all_broadcast = false; }\n    group_parallel_confs.emplace_back(parallel_desc.parallel_conf());\n\n    if (job_builder->job().job_conf().enable_gradients_stats_aggregation()) {\n      auto multi_reduce_sum_op_builder =\n          user_op::UserOpConfWrapperBuilder(\"System-ClipGradient-GlobalNorm-MultiReduceSumPowAbs-\"\n                                            + NewUniqueId())\n              .Op(\"multi_reduce_sum_pow_abs\")\n              .Attr(\"p\", p)\n              .Output(\"y\")\n              .ScopeSymbolId(scope_symbol_id);\n      for (const auto& lbi : lbis) {\n        multi_reduce_sum_op_builder.Input(\"x\", GenLogicalBlobName(lbi2diff_lbi.at(lbi)));\n      }\n      const auto multi_reduce_sum_op = multi_reduce_sum_op_builder.Build();\n      job_builder->AddOps(parallel_desc.parallel_conf(), {multi_reduce_sum_op.op_conf()});\n      group_lbns.emplace_back(multi_reduce_sum_op.output(\"y\", 0));\n    } else {\n      std::vector<std::string> lbns_to_add;\n      lbns_to_add.reserve(lbis.size());\n      for (const auto& lbi : lbis) {\n        const LogicalBlobId& diff_lbi = lbi2diff_lbi.at(lbi);\n        const auto square_sum_op =\n            user_op::UserOpConfWrapperBuilder(\"System-ClipGradient-GlobalNorm-ReduceSumPowAbs-\"\n                                              + NewUniqueId())\n                .Op(\"multi_reduce_sum_pow_abs\")\n                .Input(\"x\", GenLogicalBlobName(diff_lbi))\n                .Attr(\"p\", p)\n                .Output(\"y\")\n                .ScopeSymbolId(scope_symbol_id)\n                .Build();\n        job_builder->AddOps(parallel_desc.parallel_conf(), {square_sum_op.op_conf()});\n        lbns_to_add.emplace_back(square_sum_op.output(\"y\", 0));\n      }\n      
group_lbns.emplace_back(AddLbns(job_builder, lbns_to_add, parallel_desc.parallel_conf(),\n                                      scope_symbol_id, \"System-ClipGradient-GlobalNorm-Add-\"));\n    }\n  };\n  ForEachAggregatedParamGroup(op_graph, lbi2diff_lbi, GroupNorm);\n\n  // sum the per-group partial results\n  *out_parallel_conf = all_same_parallel_desc ? any_parallel_desc.parallel_conf()\n                                              : GenParallelConfOfCpuZeroOnMaster();\n  const int64_t scope_symbol_id =\n      MakeScopeSymbolId(job_builder->job().job_conf(), *out_parallel_conf);\n  std::vector<std::string> sum_group_lbns;\n  if (all_broadcast) {\n    sum_group_lbns = std::move(group_lbns);\n  } else {\n    sum_group_lbns.reserve(group_lbns.size());\n    for (size_t i = 0; i < group_lbns.size(); ++i) {\n      std::string lbn;\n      if (all_same_parallel_desc) {\n        // reduce the number of P->B (allreduce) casts from many to one\n        lbn = AddParallelCast(job_builder, group_lbns.at(i), \"P\", group_parallel_confs.at(i),\n                              \"System-ClipGradient-ParallelCast-\");\n      } else {\n        // the sum will run on cpu 0, so we must do P->B first, because only B\n        // is accepted when execution is on a single device\n        lbn = AddParallelCast(job_builder, group_lbns.at(i), \"B\", group_parallel_confs.at(i),\n                              \"System-ClipGradient-ParallelCast-\");\n      }\n      sum_group_lbns.push_back(std::move(lbn));\n    }\n    out_parallel_conf->mutable_hierarchy()->clear_dim();\n  }\n  auto global_reduce_sum_lbn = AddLbns(job_builder, sum_group_lbns, *out_parallel_conf,\n                                       scope_symbol_id, \"System-ClipGradient-GlobalNorm-Add-\");\n\n  auto global_pow_op =\n      user_op::UserOpConfWrapperBuilder(\"System-ClipGradient-GlobalNorm-GlobalPow-\" + NewUniqueId())\n          .Op(\"scalar_pow\")\n          .Input(\"in\", global_reduce_sum_lbn)\n          .Attr(\"float_operand\", 1.0 / p)\n          .Attr(\"has_float_operand\", true)\n          .Output(\"out\")\n          .ScopeSymbolId(scope_symbol_id)\n          .Build();\n  job_builder->AddOps(*out_parallel_conf, {global_pow_op.op_conf()});\n\n  return global_pow_op.output(\"out\", 0);\n}\n\nvoid ClipGradientByGlobalNorm(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,\n                              HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi,\n                              const ClipByGlobalNormConf& conf) {\n  if (lbi2diff_lbi->empty()) { return; }\n  ParallelConf parallel_conf;\n  std::string total_norm_lbn;\n  CHECK(conf.has_norm_type());\n  double norm_type = conf.norm_type();\n  if (std::isinf(norm_type) && norm_type > 0) {\n    total_norm_lbn = GlobalAbsMaxMin(op_graph, job_builder, *lbi2diff_lbi, true, &parallel_conf);\n  } else if (std::isinf(norm_type) && norm_type < 0) {\n    total_norm_lbn = GlobalAbsMaxMin(op_graph, job_builder, *lbi2diff_lbi, false, &parallel_conf);\n  } else {\n    total_norm_lbn = GlobalNorm(op_graph, job_builder, *lbi2diff_lbi, norm_type, &parallel_conf);\n  }\n\n  int64_t scope_symbol_id = MakeScopeSymbolId(job_builder->job().job_conf(), parallel_conf);\n\n  auto add_eps_ops =\n      user_op::UserOpConfWrapperBuilder(\"System-ClipGradient-GlobalNorm-AddEps-\" + NewUniqueId())\n          .Op(\"scalar_add\")\n          .Input(\"in\", total_norm_lbn)\n          .Attr(\"float_operand\", 1e-6)\n          .Attr(\"has_float_operand\", true)\n          .Output(\"out\")\n          .ScopeSymbolId(scope_symbol_id)\n          .Build();\n  
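// Together with the reciprocal, scalar-mul, and clamp ops below, this computes\n  // clip_coeff = min(max_norm / (total_norm + 1e-6), 1.0), which is then\n  // multiplied into every gradient via scalar_mul_by_tensor.\n  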
job_builder->AddOps(parallel_conf, {add_eps_ops.op_conf()});\n\n  auto inv_op =\n      user_op::UserOpConfWrapperBuilder(\"System-ClipGradient-GlobalNorm-Inv-\" + NewUniqueId())\n          .Op(\"reciprocal_no_nan\")\n          .Input(\"x\", add_eps_ops.output(\"out\", 0))\n          .Output(\"y\")\n          .ScopeSymbolId(scope_symbol_id)\n          .Build();\n  job_builder->AddOps(parallel_conf, {inv_op.op_conf()});\n\n  auto coeff_op =\n      user_op::UserOpConfWrapperBuilder(\"System-ClipGradient-GlobalNorm-Coeff-\" + NewUniqueId())\n          .Op(\"scalar_mul\")\n          .Input(\"in\", inv_op.output(\"y\", 0))\n          .Attr(\"float_operand\", static_cast<double>(conf.max_norm()))\n          .Attr(\"has_float_operand\", true)\n          .Output(\"out\")\n          .ScopeSymbolId(scope_symbol_id)\n          .Build();\n  job_builder->AddOps(parallel_conf, {coeff_op.op_conf()});\n\n  auto clamp_coeff_op =\n      user_op::UserOpConfWrapperBuilder(\"System-ClipGradient-GlobalNorm-Clamp-\" + NewUniqueId())\n          .Op(\"clip_by_scalar_max\")\n          .Input(\"x\", coeff_op.output(\"out\", 0))\n          .Attr(\"floating_max\", 1.0)\n          .Output(\"y\")\n          .ScopeSymbolId(scope_symbol_id)\n          .Build();\n  job_builder->AddOps(parallel_conf, {clamp_coeff_op.op_conf()});\n\n  const std::string& coeff_lbn = clamp_coeff_op.output(\"y\", 0);\n  for (auto& pair : *lbi2diff_lbi) {\n    const LogicalBlobId& lbi = pair.first;\n    LogicalBlobId& diff_lbi = pair.second;\n    auto mul_op_name = \"System-ClipGradient-GlobalNorm-ScalarMul-\" + NewUniqueId();\n    auto scalar_mul_op = user_op::UserOpConfWrapperBuilder(mul_op_name)\n                             .Op(\"scalar_mul_by_tensor\")\n                             .Input(\"x\", GenLogicalBlobName(diff_lbi))\n                             .Input(\"scalar\", coeff_lbn)\n                             .Output(\"y\")\n                             .ScopeSymbolId(ScopeSymbolId4Lbi(op_graph, lbi))\n                             .Build();\n    job_builder->AddOps(op_graph.OpNode4OpName(lbi.op_name())->parallel_desc().parallel_conf(),\n                        {scalar_mul_op.op_conf()});\n    diff_lbi = GenLogicalBlobId(scalar_mul_op.output(\"y\", 0));\n  }\n\n  if (!CHECK_JUST(ctx->HasState<ClipByGlobalNormJobPassState>(\"clip_by_global_norm_state\"))) {\n    CHECK_JUST(ctx->ResetState(\"clip_by_global_norm_state\",\n                               std::make_unique<ClipByGlobalNormJobPassState>()));\n  }\n  auto state =\n      CHECK_JUST(ctx->MutableState<ClipByGlobalNormJobPassState>(\"clip_by_global_norm_state\"));\n  const std::shared_ptr<ClipByGlobalNormJobPassState::TotalNormState>& total_norm_state =\n      std::make_shared<ClipByGlobalNormJobPassState::TotalNormState>(\n          total_norm_lbn, coeff_lbn, parallel_conf, scope_symbol_id);\n  for (auto& pair : *lbi2diff_lbi) {\n    const LogicalBlobId& lbi = pair.first;\n    const std::string& variable_op_name = lbi.op_name();\n    state->AddTotalNormState(variable_op_name, total_norm_state);\n  }\n}\n\n}  // namespace\n\nMaybe<void> MakeGetterLossOpNode4OpName(\n    const OpGraph& op_graph, std::function<OpNode*(const std::string&)>* LossOpNode4OpName) {\n  std::list<OpNode*> loss_nodes;\n  JUST(GetLossOpNodes(op_graph, &loss_nodes));\n  auto loss_op_name2op_node = std::make_shared<HashMap<std::string, OpNode*>>();\n  for (OpNode* op_node : loss_nodes) {\n    CHECK(loss_op_name2op_node->emplace(op_node->op().op_name(), op_node).second);\n  }\n  *LossOpNode4OpName = 
[loss_op_name2op_node](const std::string& op_name) -> OpNode* {\n    return loss_op_name2op_node->at(op_name);\n  };\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ScaleModelDiffByLossInstanceNum(const OpGraph& op_graph, JobBuilder* job_builder,\n                                            HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi) {\n  std::function<OpNode*(const std::string&)> LossOpNode4OpName;\n  JUST(MakeGetterLossOpNode4OpName(op_graph, &LossOpNode4OpName));\n  const auto& train_conf = GetTrainConf();\n  HashMap<LogicalBlobId, OpNode*> loss_lbi2op_node;\n  for (const auto& loss_lbn : train_conf.loss_lbn()) {\n    const auto& lbi = GenLogicalBlobId(loss_lbn);\n    CHECK(loss_lbi2op_node.emplace(lbi, LossOpNode4OpName(lbi.op_name())).second);\n  }\n  const Shape src_time_shape({1, 1});\n  const int64_t source_time_shape_elem_cnt = src_time_shape.elem_cnt();\n  bool all_loss_time_shape_eq_src = true;\n  for (const auto& pair : loss_lbi2op_node) {\n    const int64_t time_shape_elem_cnt = JUST(pair.second->op().GetOpTimeShape())->elem_cnt();\n    if (time_shape_elem_cnt != source_time_shape_elem_cnt) {\n      CHECK_EQ(time_shape_elem_cnt % source_time_shape_elem_cnt, 0);\n      all_loss_time_shape_eq_src = false;\n    }\n  }\n  if (all_loss_time_shape_eq_src) {\n    const BlobDesc* blob_desc = nullptr;\n    for (const auto& pair : loss_lbi2op_node) {\n      const BlobDesc* cur_blob_desc = &pair.second->LogicalBlobDesc4Lbi(pair.first);\n      if (blob_desc != nullptr) { CHECK(*blob_desc == *cur_blob_desc); }\n      blob_desc = cur_blob_desc;\n    }\n    if (blob_desc->is_dynamic()) {\n      ScaleModelDiffByDynamicLossInstanceNum(op_graph, job_builder, lbi2diff_lbi, loss_lbi2op_node);\n    } else {\n      ScaleModelDiffByConstantLossInstanceNum(op_graph, job_builder, lbi2diff_lbi,\n                                              blob_desc->shape().elem_cnt());\n    }\n  } else {\n    std::unique_ptr<BlobDesc> blob_desc;\n    for (const auto& pair : loss_lbi2op_node) {\n      const BlobDesc* cur_blob_desc = &pair.second->LogicalBlobDesc4Lbi(pair.first);\n      // TODO: support dynamic\n      CHECK(!cur_blob_desc->is_dynamic());\n      const DataType loss_data_type = cur_blob_desc->data_type();\n      const int64_t time_shape_elem_cnt = JUST(pair.second->op().GetOpTimeShape())->elem_cnt();\n      // TODO: consider sbp\n      const int64_t loss_elem_cnt =\n          cur_blob_desc->shape().elem_cnt() * time_shape_elem_cnt / source_time_shape_elem_cnt;\n      if (blob_desc) {\n        CHECK_EQ(blob_desc->data_type(), loss_data_type);\n        CHECK_EQ(blob_desc->shape().elem_cnt(), loss_elem_cnt);\n      } else {\n        blob_desc.reset(\n            new BlobDesc(Shape({loss_elem_cnt}), loss_data_type, cur_blob_desc->memory_format()));\n      }\n    }\n    ScaleModelDiffByConstantLossInstanceNum(op_graph, job_builder, lbi2diff_lbi,\n                                            blob_desc->shape().elem_cnt());\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ScaleInitialDiffByLossScale(\n    JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,\n    HashMap<LogicalBlobId, LogicalBlobId>* loss_lbi2initial_diff_lbi) {\n  const TrainConf& train_conf = ctx->job_desc().job_conf().train_conf();\n  if (!train_conf.has_dynamic_loss_scale_policy() && !train_conf.has_loss_scale_factor()) {\n    return Maybe<void>::Ok();\n  }\n  for (auto& it : *loss_lbi2initial_diff_lbi) {\n    const auto& loss_lbi = it.first;\n    const auto& initial_diff_lbi = it.second;\n    const OpNode* 
initial_diff_node = op_graph.OpNode4OpName(initial_diff_lbi.op_name());\n    int64_t scope_symbol_id = initial_diff_node->op().op_conf().scope_symbol_id();\n    const auto& parallel_conf = initial_diff_node->parallel_desc().parallel_conf();\n\n    std::string loss_diff_lbn = GenLogicalBlobName(initial_diff_lbi);\n    const DataType init_diff_data_type = op_graph.GetLogicalBlobDesc(initial_diff_lbi).data_type();\n    // cast the initial loss diff from float16 to float32, since we need to do\n    // the loss-scale multiply in float32 later\n    if (init_diff_data_type != DataType::kFloat) {\n      std::string cast_op_name =\n          initial_diff_lbi.op_name() + \"_\" + initial_diff_lbi.blob_name() + \"_loss_scale-cast_h2f\";\n      auto cast_op = user_op::UserOpConfWrapperBuilder(cast_op_name)\n                         .Op(\"cast\")\n                         .Input(\"in\", loss_diff_lbn)\n                         .Output(\"out\")\n                         .Attr<DataType>(\"dtype\", DataType::kFloat)\n                         .ScopeSymbolId(scope_symbol_id)\n                         .Build();\n      job_builder->AddOps(parallel_conf, {cast_op.op_conf()});\n      loss_diff_lbn = cast_op.output(\"out\", 0);\n    }\n\n    std::string loss_scale_val_lbn;\n    if (train_conf.has_dynamic_loss_scale_policy()) {\n      const auto& dynamic_loss_scale_state =\n          JUST(ctx->GetState<DynamicLossScaleJobPassState>(\"dynamic_loss_scale_state\"));\n      loss_scale_val_lbn = dynamic_loss_scale_state.loss_scale_val_lbn();\n    } else if (train_conf.has_loss_scale_factor()) {\n      OperatorConf constant_like_op{};\n      constant_like_op.set_name(loss_lbi.op_name() + \"_\" + loss_lbi.blob_name()\n                                + \"_constant_like_loss_scale\");\n      constant_like_op.set_scope_symbol_id(scope_symbol_id);\n      ConstantLikeOpConf* constant_like_conf = constant_like_op.mutable_constant_like_conf();\n      constant_like_conf->set_like(loss_diff_lbn);\n      constant_like_conf->set_out(\"out\");\n      constant_like_conf->set_float_operand(train_conf.loss_scale_factor());\n      job_builder->AddOps(parallel_conf, {constant_like_op});\n      loss_scale_val_lbn = GenLogicalBlobName(constant_like_op.name(), constant_like_conf->out());\n    } else {\n      UNIMPLEMENTED_THEN_RETURN() << \"dynamic or static loss scale must be configured\";\n    }\n\n    const int64_t time_shape_elem_cnt =\n        JUST(initial_diff_node->op().GetInputBlobFastestTimeShape())->elem_cnt();\n    if (time_shape_elem_cnt != 1) {\n      const auto repeat_op =\n          user_op::UserOpConfWrapperBuilder(loss_lbi.op_name() + \"_\" + loss_lbi.blob_name()\n                                            + \"_loss_scale-repeat\")\n              .OpTypeName(\"repeat\")\n              .Input(\"in\", loss_scale_val_lbn)\n              .Output(\"out\")\n              .Attr<int32_t>(\"repeat_num\", time_shape_elem_cnt)\n              .ScopeSymbolId(scope_symbol_id)\n              .Build();\n      job_builder->AddOps(parallel_conf, {repeat_op.op_conf()});\n      loss_scale_val_lbn = repeat_op.output(\"out\", 0);\n    }\n\n    auto scalar_mul_op =\n        user_op::UserOpConfWrapperBuilder(initial_diff_lbi.op_name() + \"_\"\n                                          + initial_diff_lbi.blob_name() + \"_scale_initial_diff\")\n            .Op(\"scalar_mul_by_tensor\")\n            .Input(\"x\", loss_diff_lbn)\n            .Input(\"scalar\", loss_scale_val_lbn)\n            .Output(\"y\")\n            .ScopeSymbolId(scope_symbol_id)\n            
.Build();\n    job_builder->AddOps(parallel_conf, {scalar_mul_op.op_conf()});\n    std::string scaled_initial_diff_lbn = scalar_mul_op.output(\"y\", 0);\n\n    // cast loss initial diff back to float16\n    if (init_diff_data_type != DataType::kFloat) {\n      std::string cast_op_name =\n          initial_diff_lbi.op_name() + \"_\" + initial_diff_lbi.blob_name() + \"_loss_scale-cast_f2h\";\n      auto cast_op = user_op::UserOpConfWrapperBuilder(cast_op_name)\n                         .Op(\"cast\")\n                         .Input(\"in\", scaled_initial_diff_lbn)\n                         .Output(\"out\")\n                         .Attr<DataType>(\"dtype\", init_diff_data_type)\n                         .ScopeSymbolId(scope_symbol_id)\n                         .Build();\n      job_builder->AddOps(parallel_conf, {cast_op.op_conf()});\n      scaled_initial_diff_lbn = cast_op.output(\"out\", 0);\n    }\n\n    // update consumer input by scalar_mul_op output\n    initial_diff_node->ForEachNodeOnOutEdge([&](const OpNode* out_node) {\n      for (const std::string& ibn : out_node->op().input_bns()) {\n        if (out_node->op().BnInOp2Lbi(ibn) == initial_diff_lbi) {\n          if (!CHECK_JUST(job_builder->IsInMutOpTransaction(out_node->op().op_name()))) {\n            CHECK_JUST(job_builder->MutOpTransactionMut(out_node->op().op_conf()));\n          }\n          OperatorConf& mut_consumer_op =\n              CHECK_JUST(job_builder->MutOpTransactionGet(out_node->op().op_name()));\n          const auto& old_lbn =\n              ReplaceInputLbnInOpCustomizedConf(&mut_consumer_op, ibn, scaled_initial_diff_lbn);\n          CHECK_EQ(old_lbn, GenLogicalBlobName(initial_diff_lbi));\n        }\n      }\n    });\n    // update initial diff lbi\n    it.second = GenLogicalBlobId(scaled_initial_diff_lbn);\n  }\n  JUST(job_builder->MutOpTransactionCommit());\n  return Maybe<void>::Ok();\n}\n\nvoid ScaleModelDiffByLossScale(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,\n                               HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi) {\n  auto ProducerOpNode4Lbi = [&](const LogicalBlobId& lbi) {\n    return op_graph.OpNode4OpName(lbi.op_name());\n  };\n  auto ProducerOpNode4Lbn = [&](const std::string& lbn) {\n    return ProducerOpNode4Lbi(GenLogicalBlobId(lbn));\n  };\n  const TrainConf& train_conf = ctx->job_desc().job_conf().train_conf();\n  if (train_conf.has_dynamic_loss_scale_policy()) {\n    const auto& dynamic_loss_scale_state =\n        CHECK_JUST(ctx->GetState<DynamicLossScaleJobPassState>(\"dynamic_loss_scale_state\"));\n    HashMap<DataType, std::string> data_type2loss_scale_lbn;\n    const auto LossScale4DataType = [&](DataType data_type) -> std::string {\n      auto it = data_type2loss_scale_lbn.find(data_type);\n      if (it == data_type2loss_scale_lbn.end()) {\n        const std::string& loss_scale_val_lbn = dynamic_loss_scale_state.loss_scale_val_lbn();\n        const int64_t scope_symbol_id =\n            ScopeSymbolId4Lbi(op_graph, GenLogicalBlobId(loss_scale_val_lbn));\n        const ParallelConf& parallel_conf =\n            ProducerOpNode4Lbn(loss_scale_val_lbn)->parallel_desc().parallel_conf();\n        std::string loss_scale_lbn_with_data_type;\n        if (data_type == DataType::kFloat) {\n          loss_scale_lbn_with_data_type = loss_scale_val_lbn;\n        } else {\n          auto cast_op =\n              user_op::UserOpConfWrapperBuilder(\"System-DynamicLossScale-Cast-\" + NewUniqueId())\n                  .Op(\"cast\")\n                  
.Input(\"in\", loss_scale_val_lbn)\n                  .Output(\"out\")\n                  .Attr<DataType>(\"dtype\", data_type)\n                  .ScopeSymbolId(scope_symbol_id)\n                  .Build();\n          loss_scale_lbn_with_data_type = cast_op.output(\"out\", 0);\n          job_builder->AddOps(parallel_conf, {cast_op.op_conf()});\n        }\n        auto inv_scale_op =\n            user_op::UserOpConfWrapperBuilder(\"System-DynamicLossScale-Reciprocal-\" + NewUniqueId())\n                .Op(\"reciprocal\")\n                .Input(\"x\", loss_scale_lbn_with_data_type)\n                .Output(\"y\")\n                .ScopeSymbolId(scope_symbol_id)\n                .Build();\n        job_builder->AddOps(parallel_conf, {inv_scale_op.op_conf()});\n        std::string lbn = inv_scale_op.output(\"y\", 0);\n        data_type2loss_scale_lbn[data_type] = lbn;\n        return lbn;\n      } else {\n        return it->second;\n      }\n    };\n    for (auto& pair : *lbi2diff_lbi) {\n      const LogicalBlobId& lbi = pair.first;\n      LogicalBlobId& diff_lbi = pair.second;\n      auto scalar_mul_op =\n          user_op::UserOpConfWrapperBuilder(\"Sys-DiffScale-ScalarMul-\" + lbi.op_name() + \"_\"\n                                            + lbi.blob_name() + \"-\" + NewUniqueId())\n              .Op(\"scalar_mul_by_tensor\")\n              .Input(\"x\", GenLogicalBlobName(diff_lbi))\n              .Input(\"scalar\", LossScale4DataType(op_graph.GetLogicalBlobDesc(lbi).data_type()))\n              .Output(\"y\")\n              .ScopeSymbolId(ScopeSymbolId4Lbi(op_graph, lbi))\n              .Build();\n      job_builder->AddOps(ProducerParallelConf4Lbi(op_graph, lbi), {scalar_mul_op.op_conf()});\n      diff_lbi = GenLogicalBlobId(scalar_mul_op.output(\"y\", 0));\n    }\n  } else if (train_conf.has_loss_scale_factor()) {\n    const float loss_scale_factor = train_conf.loss_scale_factor();\n    if (loss_scale_factor == 1) { return; }\n    const float down_scale_factor = 1.0f / loss_scale_factor;\n    for (auto& pair : *lbi2diff_lbi) {\n      const LogicalBlobId& lbi = pair.first;\n      LogicalBlobId& diff_lbi = pair.second;\n      auto scalar_mul_op =\n          user_op::UserOpConfWrapperBuilder(\"Sys-DiffScale-ScalarMul-\" + lbi.op_name() + \"_\"\n                                            + lbi.blob_name() + \"-\" + NewUniqueId())\n              .Op(\"scalar_mul\")\n              .Input(\"in\", GenLogicalBlobName(diff_lbi))\n              .Output(\"out\")\n              .Attr<bool>(\"has_float_operand\", true)\n              .Attr<double>(\"float_operand\", down_scale_factor)\n              .Attr<bool>(\"has_int_operand\", false)\n              .Attr<int64_t>(\"int_operand\", 0)\n              .ScopeSymbolId(ScopeSymbolId4Lbi(op_graph, lbi))\n              .Build();\n      job_builder->AddOps(ProducerParallelConf4Lbi(op_graph, lbi), {scalar_mul_op.op_conf()});\n      diff_lbi = GenLogicalBlobId(scalar_mul_op.output(\"out\", 0));\n    }\n  } else {\n    return;\n  }\n}\n\nvoid RegularizeGradient(const OpGraph& op_graph, JobBuilder* job_builder,\n                        HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi) {\n  for (auto& pair : *lbi2diff_lbi) {\n    const LogicalBlobId& lbi = pair.first;\n    LogicalBlobId& diff_lbi = pair.second;\n    const OpNode* model_op_node = op_graph.OpNode4OpName(lbi.op_name());\n    int64_t scope_symbol_id = model_op_node->op().op_conf().scope_symbol_id();\n    CHECK(model_op_node->op().op_conf().has_variable_conf());\n    const VariableOpConf& 
variable_conf = model_op_node->op().op_conf().variable_conf();\n    if (!variable_conf.has_regularizer()) { continue; }\n    const RegularizerConf& regularizer_conf = variable_conf.regularizer();\n    if (regularizer_conf.has_l1_l2_conf()) {\n      user_op::UserOpConfWrapper regularize_gradient_op =\n          user_op::UserOpConfWrapperBuilder(\"System-RegularizeGradient-L1L2-\" + NewUniqueId())\n              .Op(\"l1_l2_regularize_gradient\")\n              .Input(\"model\", GenLogicalBlobName(lbi))\n              .Input(\"model_diff\", GenLogicalBlobName(diff_lbi))\n              .Output(\"out\")\n              .Attr<float>(\"l1\", regularizer_conf.l1_l2_conf().l1())\n              .Attr<float>(\"l2\", regularizer_conf.l1_l2_conf().l2())\n              .ScopeSymbolId(scope_symbol_id)\n              .Build();\n      job_builder->AddOps(model_op_node->parallel_desc().parallel_conf(),\n                          {regularize_gradient_op.op_conf()});\n      diff_lbi = GenLogicalBlobId(regularize_gradient_op.output(\"out\", 0));\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n}\n\nvoid ClipGradient(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,\n                  HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi, const ClipConf& clip_conf) {\n  if (clip_conf.has_clip_by_global_norm()) {\n    ClipGradientByGlobalNorm(ctx, op_graph, job_builder, lbi2diff_lbi,\n                             clip_conf.clip_by_global_norm());\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nvoid AddDiffParallelCast(const OpGraph& op_graph, JobBuilder* job_builder,\n                         HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi) {\n  for (auto& pair : *lbi2diff_lbi) {\n    const LogicalBlobId& lbi = pair.first;\n    LogicalBlobId& diff_lbi = pair.second;\n    const OpNode* model_op_node = op_graph.OpNode4OpName(lbi.op_name());\n    if (model_op_node->parallel_desc().parallel_num() <= 1) { continue; }\n    const int64_t scope_symbol_id = model_op_node->op().op_conf().scope_symbol_id();\n    std::vector<std::string> nd_sbp;\n    const std::string& variable_sole_obn = model_op_node->op().SoleObn();\n    nd_sbp.reserve(model_op_node->NdSbp4BnInOp(variable_sole_obn).sbp_parallel().size());\n    for (const auto& sbp_parallel : model_op_node->NdSbp4BnInOp(variable_sole_obn).sbp_parallel()) {\n      nd_sbp.emplace_back(SbpParallelToString(sbp_parallel));\n    }\n    auto parallel_cast_op =\n        user_op::UserOpConfWrapperBuilder(\"System-AutoGrad-ParallelCast-\" + NewUniqueId())\n            .Op(\"hierarchical_parallel_cast\")\n            .Input(\"in\", GenLogicalBlobName(diff_lbi))\n            .Output(\"out\")\n            .Attr<std::vector<std::string>>(\"nd_sbp\", nd_sbp)\n            .Attr<std::string>(\"grad_mode\", \"auto\")\n            .Attr<std::vector<std::string>>(\"grad_nd_sbp\", std::vector<std::string>())\n            .ScopeSymbolId(scope_symbol_id)\n            .Build();\n    job_builder->AddOps(model_op_node->parallel_desc().parallel_conf(),\n                        {parallel_cast_op.op_conf()});\n    diff_lbi = GenLogicalBlobId(parallel_cast_op.output(\"out\", 0));\n  }\n}\n\nvoid AddDiffHalf2FloatCast(const OpGraph& op_graph, JobBuilder* job_builder,\n                           HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi) {\n  for (auto& pair : *lbi2diff_lbi) {\n    LogicalBlobId& diff_lbi = pair.second;\n    auto data_type = op_graph.GetLogicalBlobDesc(diff_lbi).data_type();\n    if (data_type != DataType::kFloat) {\n      std::string lbn = 
GenLogicalBlobName(diff_lbi);\n      const OpNode* op_node = op_graph.OpNode4OpName(diff_lbi.op_name());\n      int64_t scope_symbol_id = op_node->op().op_conf().scope_symbol_id();\n      auto cast_op = user_op::UserOpConfWrapperBuilder(ReplaceSlashToDash4Lbn(lbn) + \"-cast_h2f\")\n                         .Op(\"cast\")\n                         .Input(\"in\", lbn)\n                         .Output(\"out\")\n                         .Attr<DataType>(\"dtype\", DataType::kFloat)\n                         .ScopeSymbolId(scope_symbol_id)\n                         .Build();\n      job_builder->AddOps(op_node->parallel_desc().parallel_conf(), {cast_op.op_conf()});\n      diff_lbi = GenLogicalBlobId(cast_op.output(\"out\", 0));\n    }\n  }\n}\n\nvoid AddDiffStaticShapeCast(const OpGraph& op_graph, JobBuilder* job_builder,\n                            HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi) {\n  for (auto& pair : *lbi2diff_lbi) {\n    const LogicalBlobId& lbi = pair.first;\n    LogicalBlobId& diff_lbi = pair.second;\n    const OpNode* model_op_node = op_graph.OpNode4OpName(lbi.op_name());\n    int64_t scope_symbol_id = model_op_node->op().op_conf().scope_symbol_id();\n    const auto cast_to_static_shape_op =\n        user_op::UserOpConfWrapperBuilder(\"System-AutoGrad-StaticShapeCast-\" + NewUniqueId())\n            .Op(\"cast_to_static_shape\")\n            .Input(\"input\", GenLogicalBlobName(diff_lbi))\n            .Output(\"output\")\n            .ScopeSymbolId(scope_symbol_id)\n            .Build();\n    job_builder->AddOps(model_op_node->parallel_desc().parallel_conf(),\n                        {cast_to_static_shape_op.op_conf()});\n    diff_lbi = GenLogicalBlobId(cast_to_static_shape_op.output(\"output\", 0));\n  }\n}\n\nMaybe<void> CountNotFiniteIfNeeded(JobPassCtx* ctx, const OpGraph& op_graph,\n                                   JobBuilder* job_builder,\n                                   const HashMap<LogicalBlobId, LogicalBlobId>& lbi2diff_lbi) {\n  if (lbi2diff_lbi.empty()) { return Maybe<void>::Ok(); }\n  if (!ctx->job_desc().job_conf().train_conf().has_dynamic_loss_scale_policy()) {\n    return Maybe<void>::Ok();\n  }\n  bool all_same_parallel_desc = true;\n  const ParallelDesc& any_parallel_desc =\n      op_graph.OpNode4OpName(lbi2diff_lbi.begin()->first.op_name())->parallel_desc();\n  std::vector<std::string> partial_count_not_finite_lbns;\n  std::vector<bool> is_broadcast_nd_sbp;\n  std::vector<ParallelConf> param_group_parallel_confs;\n  ForEachAggregatedParamGroup(\n      op_graph, lbi2diff_lbi,\n      [&](const ParallelDesc& parallel_desc, const NdSbp& nd_sbp,\n          const std::vector<LogicalBlobId>& lbis) {\n        if (!parallel_desc.EqualsIgnoringHierarchy(any_parallel_desc)) {\n          all_same_parallel_desc = false;\n        }\n        const int64_t scope_symbol_id =\n            MakeScopeSymbolId(job_builder->job().job_conf(), parallel_desc.parallel_conf());\n        is_broadcast_nd_sbp.emplace_back(IsBroadcast(nd_sbp, parallel_desc));\n        param_group_parallel_confs.emplace_back(parallel_desc.parallel_conf());\n        if (job_builder->job().job_conf().enable_gradients_stats_aggregation()) {\n          auto multi_count_not_finite_op_builder =\n              user_op::UserOpConfWrapperBuilder(\"System-DynamicLossScale-MultiCountNotFinite-\"\n                                                + NewUniqueId())\n                  .Op(\"multi_count_not_finite\")\n                  .Output(\"y\")\n                  .ScopeSymbolId(scope_symbol_id);\n         
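 // With gradient-stats aggregation enabled, every gradient in the group is\n          // fed into this single multi_count_not_finite op instead of one\n          // count_not_finite op per gradient (the else-branch below).\n         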
 for (const auto& lbi : lbis) {\n            multi_count_not_finite_op_builder.Input(\"x\", GenLogicalBlobName(lbi2diff_lbi.at(lbi)));\n          }\n          const auto multi_count_not_finite_op = multi_count_not_finite_op_builder.Build();\n          job_builder->AddOps(parallel_desc.parallel_conf(), {multi_count_not_finite_op.op_conf()});\n          partial_count_not_finite_lbns.emplace_back(multi_count_not_finite_op.output(\"y\", 0));\n        } else {\n          std::vector<std::string> lbns_to_add;\n          for (const auto& lbi : lbis) {\n            const auto count_not_finite_op =\n                user_op::UserOpConfWrapperBuilder(\"System-DynamicLossScale-CountNotFinite-\"\n                                                  + NewUniqueId())\n                    .Op(\"count_not_finite\")\n                    .Input(\"x\", GenLogicalBlobName(lbi2diff_lbi.at(lbi)))\n                    .Output(\"y\")\n                    .ScopeSymbolId(scope_symbol_id)\n                    .Build();\n            job_builder->AddOps(parallel_desc.parallel_conf(), {count_not_finite_op.op_conf()});\n            lbns_to_add.emplace_back(count_not_finite_op.output(\"y\", 0));\n          }\n          partial_count_not_finite_lbns.emplace_back(\n              AddLbns(job_builder, lbns_to_add, parallel_desc.parallel_conf(), scope_symbol_id,\n                      \"System-DynamicLossScale-CountNotFinite-Add-\"));\n        }\n      });\n\n  const bool all_group_broadcast =\n      std::all_of(is_broadcast_nd_sbp.begin(), is_broadcast_nd_sbp.end(), [](bool i) { return i; });\n  std::vector<std::string> count_not_finite_lbns_for_add;\n  ParallelConf count_all_parallel_conf = all_same_parallel_desc\n                                             ? any_parallel_desc.parallel_conf()\n                                             : GenParallelConfOfCpuZeroOnMaster();\n  if (!all_group_broadcast) {\n    for (int64_t i = 0; i < partial_count_not_finite_lbns.size(); ++i) {\n      count_not_finite_lbns_for_add.emplace_back(AddParallelCast(\n          job_builder, JUST(VectorAt(partial_count_not_finite_lbns, i)), \"P\",\n          JUST(VectorAt(param_group_parallel_confs, i)), \"System-DynamicLossScale-ParallelCast-\"));\n    }\n    count_all_parallel_conf.mutable_hierarchy()->clear_dim();\n  } else {\n    count_not_finite_lbns_for_add = std::move(partial_count_not_finite_lbns);\n  }\n  const int64_t scope_symbol_id =\n      MakeScopeSymbolId(job_builder->job().job_conf(), count_all_parallel_conf);\n  std::string count_all_lbn =\n      AddLbns(job_builder, count_not_finite_lbns_for_add, count_all_parallel_conf, scope_symbol_id,\n              \"System-DynamicLossScale-CountNotFinite-Add-\");\n  if (!all_group_broadcast) {\n    std::vector<std::string> cast_nd_sbp;\n    cast_nd_sbp.emplace_back(\"B\");\n    auto parallel_cast_op =\n        user_op::UserOpConfWrapperBuilder(\n            \"System-DynamicLossScale-CountNotFinite-After-Add-ParallelCast-\" + NewUniqueId())\n            .Op(\"hierarchical_parallel_cast\")\n            .Input(\"in\", count_all_lbn)\n            .Output(\"out\")\n            .Attr<std::vector<std::string>>(\"nd_sbp\", cast_nd_sbp)\n            .Attr<std::string>(\"grad_mode\", \"auto\")\n            .Attr<std::vector<std::string>>(\"grad_nd_sbp\", std::vector<std::string>())\n            .ScopeSymbolId(scope_symbol_id)\n            .Build();\n    job_builder->AddOps(count_all_parallel_conf, {parallel_cast_op.op_conf()});\n    count_all_lbn = parallel_cast_op.output(\"out\", 0);\n  }\n  const 
LogicalBlobId count_not_finite_lbi =\n      GenLogicalBlobId(JUST(ctx->GetState<DynamicLossScaleJobPassState>(\"dynamic_loss_scale_state\"))\n                           .count_not_finite_lbn());\n  auto count_not_finite_op = user_op::UserOpConfWrapperBuilder(count_not_finite_lbi.op_name())\n                                 .Op(\"identity\")\n                                 .Input(\"in\", count_all_lbn)\n                                 .Output(\"out\")\n                                 .ScopeSymbolId(scope_symbol_id)\n                                 .Build();\n  job_builder->MutOpsOnlyOnce({count_not_finite_op.op_conf()});\n  job_builder->MutParallelConfOnlyOnce(count_not_finite_op.op_name(), count_all_parallel_conf);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/autograd.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_REWRITER_AUTOGRAD_H_\n#define ONEFLOW_CORE_JOB_REWRITER_AUTOGRAD_H_\n\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n\nnamespace oneflow {\n\nclass JobPassCtx;\n\nvoid AddDiffHalf2FloatCast(const OpGraph& op_graph, JobBuilder* job_builder,\n                           HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi);\nvoid AddDiffParallelCast(const OpGraph& op_graph, JobBuilder* job_builder,\n                         HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi);\nvoid AddDiffStaticShapeCast(const OpGraph& op_graph, JobBuilder* job_builder,\n                            HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi);\nMaybe<void> CountNotFiniteIfNeeded(JobPassCtx* ctx, const OpGraph& op_graph,\n                                   JobBuilder* job_builder,\n                                   const HashMap<LogicalBlobId, LogicalBlobId>& lbi2diff_lbi);\nMaybe<void> MakeGetterLossOpNode4OpName(\n    const OpGraph& op_graph, std::function<OpNode*(const std::string&)>* LossOpNode4OpName);\nMaybe<void> ScaleModelDiffByLossInstanceNum(const OpGraph& op_graph, JobBuilder* job_builder,\n                                            HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi);\n\nMaybe<void> ScaleInitialDiffByLossScale(\n    JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,\n    HashMap<LogicalBlobId, LogicalBlobId>* loss_lbi2initial_diff_lbi);\n\nvoid ScaleModelDiffByLossScale(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,\n                               HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi);\nvoid RegularizeGradient(const OpGraph& op_graph, JobBuilder* job_builder,\n                        HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi);\nvoid ClipGradient(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,\n                  HashMap<LogicalBlobId, LogicalBlobId>* lbi2diff_lbi, const ClipConf& clip_conf);\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_REWRITER_AUTOGRAD_H_\n"
  },
  {
    "path": "oneflow/core/job_rewriter/autotick.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/job_rewriter/autotick.h\"\n#include \"oneflow/core/job/job_builder.h\"\n#include \"oneflow/core/job/critical_section_desc.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/buffer_manager.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::unique_ptr<MutOpConTickInputHelper> NewMutOpConTickInputHelper(const OperatorConf& op_conf) {\n  std::unique_ptr<MutOpConTickInputHelper> ret;\n  if (IsClassRegistered<int32_t, MutOpConTickInputHelper>(op_conf.op_type_case())) {\n    ret.reset(NewObj<int32_t, MutOpConTickInputHelper>(op_conf.op_type_case()));\n    ret->InitFromOpConf(op_conf);\n  }\n  return ret;\n}\n\nvoid PrependTickByParallelDesc(const OpGraph& op_graph, JobBuilder* job_builder) {\n  HashMap<ParallelDesc, std::vector<OpNode*>> parallel_desc2op_node;\n  op_graph.ForEachNode([&](OpNode* op_node) {\n    auto mut_tick_input_helper = NewMutOpConTickInputHelper(op_node->op().op_conf());\n    if (!mut_tick_input_helper) { return; }\n    if (mut_tick_input_helper->IsTickInputBound() == true) { return; }\n    parallel_desc2op_node[op_node->parallel_desc()].emplace_back(op_node);\n  });\n  for (const auto& pair : parallel_desc2op_node) {\n    OperatorConf device_tick_op;\n    device_tick_op.set_name(\"System-AutoTick-Prepend-DeviceTick_\" + NewUniqueId());\n    auto* device_tick_op_conf = device_tick_op.mutable_device_tick_conf();\n    device_tick_op_conf->set_out(\"out\");\n    job_builder->AddOps(pair.first.parallel_conf(), {device_tick_op});\n\n    for (const auto* op_node : pair.second) {\n      auto mut_tick_input_helper = NewMutOpConTickInputHelper(op_node->op().op_conf());\n      job_builder->MutOpsOnlyOnce(\n          {mut_tick_input_helper->NewTickInputBoundOpConf(device_tick_op.name() + \"/out\")});\n    }\n  }\n}\n\nMaybe<const OperatorConf&> FindJobSoleSrcSubsetTickOpConf(const Job& job) {\n  const OperatorConf* src_subset_tick_op_conf = nullptr;\n  for (const auto& op_conf : job.net().op()) {\n    if (!op_conf.has_src_subset_tick_conf()) { continue; }\n    CHECK_ISNULL_OR_RETURN(src_subset_tick_op_conf);\n    src_subset_tick_op_conf = &op_conf;\n  }\n  CHECK_NOTNULL_OR_RETURN(src_subset_tick_op_conf);\n  return *src_subset_tick_op_conf;\n}\n\nMaybe<void> BuildDstSubsetTickOpAndParallelConf(const HashSet<LogicalBlobId>& tick_lbis,\n                                                OperatorConf* dst_subset_tick_op,\n                                                JobBuilder* job_builder) {\n  dst_subset_tick_op->set_name(\"System-AutoTick-DstSubsetTick_\" + 
NewUniqueId());\n  auto* dst_subset_tick_op_conf = dst_subset_tick_op->mutable_dst_subset_tick_conf();\n  dst_subset_tick_op_conf->set_out(\"out\");\n  for (const LogicalBlobId& tick_lbi : tick_lbis) {\n    dst_subset_tick_op_conf->add_in(GenLogicalBlobName(tick_lbi));\n  }\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  for (int64_t machine_id : Singleton<ResourceDesc, ForSession>::Get()->process_ranks()) {\n    parallel_conf.add_device_name(std::string(\"@\") + std::to_string(machine_id) + \":0\");\n  }\n  JUST(job_builder->AddOp(parallel_conf, *dst_subset_tick_op));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CreateDstSubsetTickAndSinkTicks(\n    const OperatorConf& src_subset_tick, const HashSet<LogicalBlobId>& tick_lbis,\n    JobBuilder* job_builder,\n    const std::function<Maybe<void>(int64_t machine_id, const std::string& op_name)>& DoEachSink) {\n  OperatorConf dst_subset_tick;\n  dst_subset_tick.mutable_dst_subset_tick_conf()->add_in(\n      src_subset_tick.name() + \"/\" + src_subset_tick.src_subset_tick_conf().out());\n  JUST(BuildDstSubsetTickOpAndParallelConf(tick_lbis, &dst_subset_tick, job_builder));\n  const auto& process_ranks = Singleton<ResourceDesc, ForSession>::Get()->process_ranks();\n  HashMap<int64_t, std::string> machine_id2gather_tick_in_lbns;\n  for (int64_t machine_id : process_ranks) {\n    ParallelConf parallel_conf;\n    parallel_conf.set_device_tag(\"cpu\");\n    parallel_conf.add_device_name(std::string(\"@\") + std::to_string(machine_id) + \":0\");\n    OperatorConf tick_op;\n    {\n      tick_op.set_name(\"System-AutoTick-Tick_\" + NewUniqueId());\n      auto* tick_conf = tick_op.mutable_tick_conf();\n      tick_conf->add_tick(dst_subset_tick.name() + \"/\"\n                          + dst_subset_tick.dst_subset_tick_conf().out());\n      tick_conf->set_out(\"out\");\n      JUST(job_builder->AddOp(parallel_conf, tick_op));\n    }\n    CHECK_OR_RETURN(\n        machine_id2gather_tick_in_lbns.emplace(machine_id, tick_op.name() + \"/out\").second);\n  }\n  for (int64_t machine_id : process_ranks) {\n    ParallelConf parallel_conf;\n    parallel_conf.set_device_tag(\"cpu\");\n    parallel_conf.add_device_name(std::string(\"@\") + std::to_string(machine_id) + \":0\");\n    OperatorConf tick_op;\n    {\n      tick_op.set_name(\"System-SyncAllRanksSinkTick_\" + NewUniqueId());\n      auto* tick_conf = tick_op.mutable_tick_conf();\n      // gather ticks from all processes.\n      for (int64_t tick_machine_id : process_ranks) {\n        tick_conf->add_tick(JUST(MapAt(machine_id2gather_tick_in_lbns, tick_machine_id)));\n      }\n      tick_conf->set_out(\"out\");\n      JUST(job_builder->AddOp(parallel_conf, tick_op));\n    }\n    OperatorConf sink_tick_op;\n    {\n      sink_tick_op.set_name(\"System-AutoTick-SinkTick_\" + NewUniqueId());\n      auto* sink_tick_conf = sink_tick_op.mutable_sink_tick_conf();\n      sink_tick_conf->add_tick(tick_op.name() + \"/out\");\n      sink_tick_conf->set_out(\"out\");\n      JUST(job_builder->AddOp(parallel_conf, sink_tick_op));\n    }\n    JUST(DoEachSink(machine_id, sink_tick_op.name()));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CreateDstSubsetTickAndSinkTicks(CriticalSection* critical_section,\n                                            const OperatorConf& src_subset_tick,\n                                            const HashSet<LogicalBlobId>& tick_lbis,\n                                            JobBuilder* job_builder) {\n  auto* map = 
critical_section->mutable_machine_id2sink_tick_op_name();\n  const auto& DoEachSink = [&](int64_t machine_id, const std::string& op_name) -> Maybe<void> {\n    (*map)[machine_id] = op_name;\n    return Maybe<void>::Ok();\n  };\n  JUST(CreateDstSubsetTickAndSinkTicks(src_subset_tick, tick_lbis, job_builder, DoEachSink));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BuildSrcSubsetTickOpAndParallelConf(OperatorConf* src_subset_tick_op,\n                                                JobBuilder* job_builder) {\n  src_subset_tick_op->set_name(\"System-AutoTick-SrcSubsetTick_\" + NewUniqueId());\n  src_subset_tick_op->mutable_src_subset_tick_conf()->set_out(\"out\");\n  ParallelConf parallel_conf;\n  parallel_conf.set_device_tag(\"cpu\");\n  for (int64_t machine_id : Singleton<ResourceDesc, ForSession>::Get()->process_ranks()) {\n    parallel_conf.add_device_name(std::string(\"@\") + std::to_string(machine_id) + \":0\");\n  }\n  JUST(job_builder->AddOp(parallel_conf, *src_subset_tick_op));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CreateSourceTicksAndSrcSubsetTick(\n    OperatorConf* src_subset_tick_op, JobBuilder* job_builder,\n    const std::function<Maybe<void>(int64_t machine_id, const std::string& op_name)>& DoEachSrc) {\n  for (int64_t machine_id : Singleton<ResourceDesc, ForSession>::Get()->process_ranks()) {\n    ParallelConf parallel_conf;\n    parallel_conf.set_device_tag(\"cpu\");\n    parallel_conf.add_device_name(std::string(\"@\") + std::to_string(machine_id) + \":0\");\n    OperatorConf src_tick_op;\n    {\n      src_tick_op.set_name(\"System-AutoTick-SourceTick_\" + NewUniqueId());\n      src_tick_op.mutable_source_tick_conf()->set_out(\"out\");\n      JUST(job_builder->AddOp(parallel_conf, src_tick_op));\n    }\n    JUST(DoEachSrc(machine_id, src_tick_op.name()));\n    OperatorConf tick_op;\n    {\n      tick_op.set_name(\"System-AutoTick-Tick_\" + NewUniqueId());\n      tick_op.mutable_tick_conf()->add_tick(src_tick_op.name() + \"/out\");\n      tick_op.mutable_tick_conf()->set_out(\"out\");\n      JUST(job_builder->AddOp(parallel_conf, tick_op));\n    }\n    src_subset_tick_op->mutable_src_subset_tick_conf()->add_in(tick_op.name() + \"/out\");\n  }\n  JUST(job_builder->MutOpOnlyOnce(*src_subset_tick_op));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CreateSourceTicksAndSrcSubsetTick(CriticalSection* critical_section,\n                                              OperatorConf* src_subset_tick_op,\n                                              JobBuilder* job_builder) {\n  auto* map = critical_section->mutable_machine_id2source_tick_op_name();\n  const auto& DoEachSrc = [&](int64_t machine_id, const std::string& op_name) -> Maybe<void> {\n    (*map)[machine_id] = op_name;\n    return Maybe<void>::Ok();\n  };\n  JUST(CreateSourceTicksAndSrcSubsetTick(src_subset_tick_op, job_builder, DoEachSrc));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ConnectSrcSubsetTickAndOtherTick(const OperatorConf& src_subset_tick_op,\n                                             JobBuilder* job_builder) {\n  CHECK_OR_RETURN(src_subset_tick_op.has_src_subset_tick_conf());\n  const std::string& src_lbn =\n      src_subset_tick_op.name() + \"/\" + src_subset_tick_op.src_subset_tick_conf().out();\n  JUST(job_builder->ForEachOperator([&](const Operator& op) -> Maybe<void> {\n    if (op.op_name() != src_subset_tick_op.name()) {\n      CHECK_OR_RETURN(!op.op_conf().has_src_subset_tick_conf());\n    }\n    auto mut_helper = NewMutOpConTickInputHelper(op.op_conf());\n    if (!mut_helper) { return 
Maybe<void>::Ok(); }\n    if (mut_helper->IsTickInputBound() == true) { return Maybe<void>::Ok(); }\n    JUST(job_builder->MutOpOnlyOnce(mut_helper->NewTickInputBoundOpConf(src_lbn)));\n    return Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\nMaybe<const OpNode*> GetSrcSubsetTickOpNode(const OpGraph& op_graph) {\n  const OpNode* src_subset_tick = nullptr;\n  JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {\n    if (op_node->op().op_conf().has_src_subset_tick_conf()) {\n      CHECK_ISNULL_OR_RETURN(src_subset_tick);\n      src_subset_tick = op_node;\n    }\n    return Maybe<void>::Ok();\n  }));\n  CHECK_NOTNULL_OR_RETURN(src_subset_tick);\n  return src_subset_tick;\n}\n\nOperatorConf MakeTickOpConf(const std::string& tick_name) {\n  OperatorConf tick_op_conf;\n  tick_op_conf.set_name(std::string(\"System-AutoTick-\" + tick_name + \"Tick_\") + NewUniqueId());\n  auto* tick_conf = tick_op_conf.mutable_tick_conf();\n  tick_conf->set_out(\"out\");\n  return tick_op_conf;\n}\n\nOperatorConf MakeDeviceTickOpConf(const std::string& tick_name) {\n  OperatorConf device_tick_op_conf;\n  device_tick_op_conf.set_name(std::string(\"System-AutoTick-\" + tick_name + \"DeviceTick_\")\n                               + NewUniqueId());\n  auto* tick_conf = device_tick_op_conf.mutable_device_tick_conf();\n  tick_conf->set_out(\"out\");\n  return device_tick_op_conf;\n}\n\nOperatorConf AppendTick(const std::string tick_name, const std::vector<std::string>& op_names,\n                        const std::shared_ptr<const Shape>& time_shape, ParallelConf parallel_conf,\n                        JobBuilder* job_builder) {\n  OperatorConf device_tick_op_conf = MakeDeviceTickOpConf(tick_name);\n  if (time_shape) {\n    time_shape->ToProto(device_tick_op_conf.mutable_device_tick_conf()->mutable_time_shape());\n  }\n  for (const auto& op_name : op_names) { device_tick_op_conf.add_ctrl_in_op_name(op_name); }\n  job_builder->AddOps(parallel_conf, {device_tick_op_conf});\n  return device_tick_op_conf;\n}\n\nOperatorConf AppendTick(const std::string tick_name, const std::list<const OpNode*>& op_nodes,\n                        const std::shared_ptr<const Shape>& time_shape, JobBuilder* job_builder) {\n  std::vector<std::string> op_names;\n  op_names.reserve(op_nodes.size());\n  for (const auto* op_node : op_nodes) {\n    CHECK(op_nodes.front()->parallel_desc() == op_node->parallel_desc());\n    op_names.emplace_back(op_node->op().op_name());\n  }\n  return AppendTick(tick_name, op_names, time_shape,\n                    op_nodes.front()->parallel_desc().parallel_conf(), job_builder);\n}\n\nOperatorConf PrependTick(const HashSet<const OpNode*>& op_nodes, JobBuilder* job_builder) {\n  CHECK_GE(op_nodes.size(), 1);\n  OperatorConf tick_op_conf = MakeTickOpConf(\"Prepend\");\n  std::vector<OperatorConf> op_confs;\n  op_confs.reserve(op_nodes.size());\n  for (const OpNode* op_node : op_nodes) {\n    OperatorConf op_conf(op_node->op().op_conf());\n    op_conf.add_ctrl_in_op_name(tick_op_conf.name());\n    op_confs.emplace_back(op_conf);\n  }\n  job_builder->MutOpsOnlyOnce({op_confs});\n  ParallelDesc pd((*op_nodes.begin())->parallel_desc());\n  pd.set_device_type(DeviceType::kCPU);\n  job_builder->AddOps(pd.parallel_conf(), {tick_op_conf});\n  return tick_op_conf;\n}\n\nOperatorConf AppendAccTick(const Shape& src_shape, const std::list<const OpNode*>& op_nodes,\n                           JobBuilder* job_builder) {\n  std::shared_ptr<const Shape> tick_shape = 
CHECK_JUST(op_nodes.front()->op().GetOpTimeShape());\n  CHECK_EQ(tick_shape->elem_cnt() % src_shape.elem_cnt(), 0);\n  const OperatorConf& tick_op_conf = AppendTick(\"AppendAcc\", op_nodes, tick_shape, job_builder);\n  OperatorConf acc_op_conf;\n  {\n    acc_op_conf.set_name(std::string(\"System-AutoTick-AccTick_\") + NewUniqueId());\n    auto* acc_conf = acc_op_conf.mutable_acc_tick_conf();\n    CHECK(tick_op_conf.has_device_tick_conf());\n    acc_conf->set_one(tick_op_conf.name() + \"/\" + tick_op_conf.device_tick_conf().out());\n    acc_conf->set_acc(\"acc\");\n    acc_conf->set_max_acc_num(tick_shape->elem_cnt() / src_shape.elem_cnt());\n  }\n  OperatorConf last_device_tick_op_conf;\n  {\n    last_device_tick_op_conf.set_name(std::string(\"System-AutoTick-Tick_\") + NewUniqueId());\n    auto* device_tick_conf = last_device_tick_op_conf.mutable_device_tick_conf();\n    device_tick_conf->add_tick(acc_op_conf.name() + \"/acc\");\n    device_tick_conf->set_out(\"out\");\n  }\n  job_builder->AddOps(op_nodes.front()->parallel_desc().parallel_conf(),\n                      {acc_op_conf, last_device_tick_op_conf});\n  return last_device_tick_op_conf;\n}\n\nstd::vector<std::string> GetOpNames(const HashSet<const OpNode*>& op_nodes) {\n  std::vector<std::string> ret;\n  ret.reserve(op_nodes.size());\n  for (const OpNode* op_node : op_nodes) { ret.emplace_back(op_node->op().op_name()); }\n  return ret;\n};\n\nMaybe<void> InitOpTypeCase2OpNodes(\n    const OpGraph& op_graph,\n    HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>>* op_type_case2op_nodes) {\n  JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {\n    const auto& op_conf = op_node->op().op_conf();\n    if (IsInterfaceOpConf(op_conf)) {\n      CHECK_OR_RETURN((*op_type_case2op_nodes)[op_conf.op_type_case()].emplace(op_node).second);\n    }\n    return Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ForEachInputCriticalSectionOpNodes(\n    const OpGraph& op_graph,\n    const std::function<Maybe<void>(const HashSet<const OpNode*>&,\n                                    const std::vector<std::string>&)>& Handler) {\n  HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes;\n  JUST(InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes));\n  OperatorConf::OpTypeCase op_type_case = OperatorConf::kInputConf;\n  if (op_type_case2op_nodes[op_type_case].empty()) { return Maybe<void>::Ok(); }\n  HashSet<const OpNode*> op_nodes = op_type_case2op_nodes[op_type_case];\n  for (const OpNode* op_node : op_type_case2op_nodes[op_type_case]) {\n    op_node->ForEachNodeOnOutEdge([&](OpNode* out_node) { op_nodes.insert(out_node); });\n  }\n  JUST(Handler(op_nodes, GetOpNames(op_type_case2op_nodes[op_type_case])));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ForEachOutputCriticalSectionOpNodes(\n    const OpGraph& op_graph,\n    const std::function<Maybe<void>(const HashSet<const OpNode*>&,\n                                    const std::vector<std::string>&)>& Handler) {\n  HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes;\n  JUST(InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes));\n  if (op_type_case2op_nodes[OperatorConf::kReturnConf].empty() == false) {\n    JUST(Handler(op_type_case2op_nodes[OperatorConf::kReturnConf],\n                 GetOpNames(op_type_case2op_nodes[OperatorConf::kReturnConf])));\n  }\n  if (op_type_case2op_nodes[OperatorConf::kOutputConf].empty() == false) {\n    
JUST(Handler(op_type_case2op_nodes[OperatorConf::kOutputConf],\n                 GetOpNames(op_type_case2op_nodes[OperatorConf::kOutputConf])));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<std::vector<OperatorConf>> AddTickForTimeShape(const Shape& src_time_shape,\n                                                     const HashSet<const OpNode*>& op_nodes,\n                                                     JobBuilder* job_builder) {\n  HashMap<std::pair<ParallelDesc, std::pair<Shape, Shape>>, std::list<const OpNode*>>\n      pd7ts2op_nodes;\n  for (const OpNode* op_node : op_nodes) {\n    auto ts = std::make_pair(*JUST(op_node->op().GetInputOutputFastestTimeShape()),\n                             *JUST(op_node->op().GetOpTimeShape()));\n    pd7ts2op_nodes[{op_node->parallel_desc(), ts}].emplace_back(op_node);\n  }\n  std::vector<OperatorConf> op_confs;\n  op_confs.reserve(pd7ts2op_nodes.size());\n  for (const auto& pair : pd7ts2op_nodes) {\n    const std::pair<Shape, Shape>& ts = pair.first.second;\n    if (ts.second.elem_cnt() == src_time_shape.elem_cnt()) {\n      CHECK_GE_OR_RETURN(ts.first.elem_cnt(), ts.second.elem_cnt());\n      op_confs.emplace_back(\n          AppendTick(\"Append\", pair.second, std::make_shared<const Shape>(ts.second), job_builder));\n    } else if (ts.second.elem_cnt() > src_time_shape.elem_cnt()) {\n      op_confs.emplace_back(AppendAccTick(src_time_shape, pair.second, job_builder));\n    } else {\n      UNIMPLEMENTED_THEN_RETURN();\n    }\n  }\n  return op_confs;\n}\n\nMaybe<void> AddGlobalInputOutputCriticalSection(\n    const HashSet<const OpNode*>& op_nodes, const std::vector<std::string>& lbi_producer_op_names,\n    JobBuilder* job_builder) {\n  auto* critical_section =\n      Singleton<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id());\n  {\n    auto* io_cs = critical_section->mutable_input_output_critical_section();\n    *io_cs->mutable_lbi_producer_op_name() = {lbi_producer_op_names.begin(),\n                                              lbi_producer_op_names.end()};\n  }\n  auto time_shape = std::make_unique<Shape>(DimVector{1, 1});\n  HashMap<ParallelDesc, HashSet<const OpNode*>> parallel_desc2op_nodes;\n  for (const OpNode* op_node : op_nodes) {\n    CHECK_OR_RETURN(parallel_desc2op_nodes[op_node->parallel_desc()].insert(op_node).second);\n  }\n  std::vector<OperatorConf> source_ticks;\n  std::vector<OperatorConf> sink_ticks;\n  source_ticks.reserve(parallel_desc2op_nodes.size());\n  for (const auto& pair : parallel_desc2op_nodes) {\n    source_ticks.emplace_back(PrependTick(pair.second, job_builder));\n    const auto& ops = JUST(AddTickForTimeShape(*time_shape, pair.second, job_builder));\n    for (const auto& sink_tick : *ops) { sink_ticks.emplace_back(sink_tick); }\n  }\n  OperatorConf src_subset_tick_op;\n  {\n    CHECK_EQ_OR_RETURN(source_ticks.empty(), false);\n    JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder));\n    JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick_op, job_builder));\n    for (auto& op_conf : source_ticks) {\n      op_conf.mutable_tick_conf()->add_tick(src_subset_tick_op.name() + \"/\"\n                                            + src_subset_tick_op.src_subset_tick_conf().out());\n    }\n    job_builder->MutOpsOnlyOnce(source_ticks);\n  }\n  HashSet<LogicalBlobId> tick_lbis;\n  for (const auto& op_conf : sink_ticks) {\n    LogicalBlobId lbi;\n    lbi.set_op_name(op_conf.name());\n    CHECK_OR_RETURN(op_conf.has_device_tick_conf());\n    
lbi.set_blob_name(op_conf.device_tick_conf().out());\n    CHECK_OR_RETURN(tick_lbis.insert(lbi).second);\n  }\n  JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick_op, tick_lbis,\n                                       job_builder));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MultiClientAddOneWaitAndSendIdsOp(JobBuilder* job_builder, int64_t machine_id,\n                                              const OperatorConf& src_op_consumer) {\n  ParallelConf parallel_conf;\n  {\n    parallel_conf.set_device_tag(\"cpu\");\n    parallel_conf.add_device_name(std::string(\"@\") + std::to_string(machine_id) + \":0\");\n  }\n\n  // add wait_and_send_ids op conf\n  OperatorConf wait_and_send_ids_op_conf;\n  {\n    wait_and_send_ids_op_conf.set_name(std::string(\"System-Src-WaitAndSendIds_\") + NewUniqueId());\n    wait_and_send_ids_op_conf.set_pass_tag(kMainOp);\n    auto* wait_and_send_ids_conf = wait_and_send_ids_op_conf.mutable_wait_and_send_ids_conf();\n    wait_and_send_ids_conf->set_out(\"out\");\n    wait_and_send_ids_conf->set_wait_buffer_name(\"UnimplementedBufferName\");\n    wait_and_send_ids_conf->set_data_type(DataType::kInt32);\n    // wait_and_send_ids_conf->id_list() is unused in multi-client mode.\n  }\n  JUST(job_builder->AddOp(parallel_conf, wait_and_send_ids_op_conf));\n\n  // connect wait_and_send_ids to tick op which was connected to the src tick op\n  OperatorConf tick_op_conf;\n  tick_op_conf.CopyFrom(src_op_consumer);\n  CHECK_OR_RETURN(tick_op_conf.has_tick_conf());\n  CHECK_EQ_OR_RETURN(tick_op_conf.tick_conf().tick_size(), 1);\n  tick_op_conf.mutable_tick_conf()->clear_tick();\n  tick_op_conf.mutable_tick_conf()->add_tick(\n      GenLogicalBlobName(wait_and_send_ids_op_conf.name(), \"out\"));\n  JUST(job_builder->MutOpOnlyOnce(tick_op_conf));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MultiClientAddWaitAndSendIds(\n    JobBuilder* job_builder, const HashMap<int64_t, std::string>& machine_id2src_op_name) {\n  // Prepare the consumer tick op for each Source op\n  HashMap<std::string, OperatorConf> src_op_name2solo_consumer_tick_op;\n  HashSet<std::string> src_op_names;\n  for (const auto& pair : machine_id2src_op_name) {\n    CHECK_OR_RETURN(src_op_names.insert(pair.second).second)\n        << \" duplicated src op name \" << pair.second;\n  }\n  JUST(job_builder->ForEachOperator([&](const Operator& op) -> Maybe<void> {\n    // skip if the op is not a tick op\n    if (!op.op_conf().has_tick_conf()) { return Maybe<void>::Ok(); }\n    for (const auto& ibn : op.input_bns()) {\n      const auto& input_lbi = op.BnInOp2Lbi(ibn);\n      if (src_op_names.count(input_lbi.op_name()) == 0) { continue; }\n      auto insert_pair =\n          src_op_name2solo_consumer_tick_op.emplace(input_lbi.op_name(), op.op_conf());\n      CHECK_OR_RETURN(insert_pair.second)\n          << \" Duplicated src op name \" << input_lbi.op_name() << \" old op \"\n          << insert_pair.first->second.DebugString() << \" new op \" << op.op_conf().DebugString();\n    }\n    return Maybe<void>::Ok();\n  }));\n\n  // Replace Source op with WaitAndSendIds op\n  for (const auto& pair : machine_id2src_op_name) {\n    auto tick_op_iter = src_op_name2solo_consumer_tick_op.find(pair.second);\n    CHECK_OR_RETURN(tick_op_iter != src_op_name2solo_consumer_tick_op.end())\n        << \"Can't find consumer tick op of source op name \" << pair.second << \" machine id \"\n        << pair.first;\n    JUST(MultiClientAddOneWaitAndSendIdsOp(job_builder, pair.first, tick_op_iter->second));\n  }\n\n  
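// Each source op's consumer tick now reads from a WaitAndSendIds op instead,\n  // so the original source ops are dead and safe to remove.\n  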
// Delete Source op\n  std::vector<std::string> src_op_name_vec{src_op_names.begin(), src_op_names.end()};\n  job_builder->DelOps(src_op_name_vec);\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MultiClientAddCallbackNotifier(JobBuilder* job_builder, int64_t machine_id,\n                                           const std::string& sink_op_name) {\n  ParallelConf parallel_conf;\n  {\n    parallel_conf.set_device_tag(\"cpu\");\n    parallel_conf.add_device_name(std::string(\"@\") + std::to_string(machine_id) + \":0\");\n  }\n  OperatorConf callback_notify_op_conf;\n  {\n    callback_notify_op_conf.set_name(std::string(\"System-Sink-CallbackNotify_\") + NewUniqueId());\n    callback_notify_op_conf.set_pass_tag(kMainOp);\n    auto* callback_notify_conf = callback_notify_op_conf.mutable_callback_notify_conf();\n    callback_notify_conf->set_in(GenLogicalBlobName(sink_op_name, \"out\"));\n    // callback_notify_conf->callback_buffer_name() is unused in multi-client mode.\n  }\n  JUST(job_builder->AddOp(parallel_conf, callback_notify_op_conf));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder) {\n  PrependTickByParallelDesc(op_graph, job_builder);\n  OperatorConf src_subset_tick_op;\n  JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder));\n  JUST(ConnectSrcSubsetTickAndOtherTick(src_subset_tick_op, job_builder));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder) {\n  const auto* op_node = JUST(GetSrcSubsetTickOpNode(op_graph));\n  const auto& src_time_shape = *JUST(op_node->op().GetOpTimeShape());\n  HashSet<const OpNode*> sink_op_nodes;\n  JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {\n    CHECK_OR_RETURN(!op_node->op().op_conf().has_sink_tick_conf());\n    size_t out_cnt = 0;\n    op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; });\n    if (out_cnt == 0) { sink_op_nodes.insert(op_node); }\n    return Maybe<void>::Ok();\n  }));\n  JUST(AddTickForTimeShape(src_time_shape, sink_op_nodes, job_builder));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AutoSourceAndSinkTick(\n    const OpGraph& op_graph, JobBuilder* job_builder,\n    const std::function<Maybe<void>(int64_t machine_id, const std::string& op_name)>& DoEachSrc,\n    const std::function<Maybe<void>(int64_t machine_id, const std::string& op_name)>& DoEachSink) {\n  JUST(op_graph.MaybeForEachNode([&](OpNode* node) -> Maybe<void> {\n    CHECK_OR_RETURN(!node->op().op_conf().has_sink_tick_conf());\n    return Maybe<void>::Ok();\n  }));\n  const auto* op_node = JUST(GetSrcSubsetTickOpNode(op_graph));\n  const auto& src_time_shape = JUST(op_node->op().GetOpTimeShape());\n  HashSet<LogicalBlobId> tick_lbis;\n  JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {\n    size_t out_cnt = 0;\n    op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; });\n    if (out_cnt > 0) { return Maybe<void>::Ok(); }\n    CHECK_OR_RETURN(op_node->op().op_conf().has_device_tick_conf());\n    CHECK_OR_RETURN(JUST(op_node->op().GetOpTimeShape())->elem_cnt() == src_time_shape->elem_cnt());\n    CHECK_OR_RETURN(tick_lbis.emplace(op_node->op().BnInOp2Lbi(op_node->op().SoleObn())).second);\n    return Maybe<void>::Ok();\n  }));\n  OperatorConf src_subset_tick = JUST(FindJobSoleSrcSubsetTickOpConf(job_builder->job()));\n  JUST(CreateSourceTicksAndSrcSubsetTick(&src_subset_tick, job_builder, DoEachSrc));\n  
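// Mirror the source side on the sink side: fan all collected sink tick lbis into a\n  // single dst_subset_tick so that DoEachSink is invoked for every sink.\n  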
JUST(CreateDstSubsetTickAndSinkTicks(src_subset_tick, tick_lbis, job_builder, DoEachSink));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MultiClientAutoSourceAndSinkTick(const OpGraph& op_graph, Job* job) {\n  HashMap<int64_t, std::string> machine_id2src_op_name;\n  HashMap<int64_t, std::string> machine_id2sink_op_name;\n  {\n    JobBuilder job_builder(job);\n    const auto& DoEachSrc = [&](int64_t machine_id, const std::string& op_name) -> Maybe<void> {\n      CHECK_OR_RETURN(machine_id2src_op_name.emplace(machine_id, op_name).second);\n      return Maybe<void>::Ok();\n    };\n    const auto& DoEachSink = [&](int64_t machine_id, const std::string& op_name) -> Maybe<void> {\n      CHECK_OR_RETURN(machine_id2sink_op_name.emplace(machine_id, op_name).second);\n      return Maybe<void>::Ok();\n    };\n    JUST(AutoSourceAndSinkTick(op_graph, &job_builder, DoEachSrc, DoEachSink));\n  }\n  {\n    JobBuilder job_builder(job);\n    JUST(MultiClientAddWaitAndSendIds(&job_builder, machine_id2src_op_name));\n\n    for (const auto& pair : machine_id2sink_op_name) {\n      JUST(MultiClientAddCallbackNotifier(&job_builder, pair.first, pair.second));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<void> InsertCriticalSectionSrcAndDstTicks(\n    const std::vector<const OpNode*>& interface_op_nodes, JobBuilder* job_builder,\n    std::vector<std::string>* interface_src_tick_op_names,\n    std::vector<std::string>* interface_dst_tick_lbns) {\n  HashMap<ParallelDesc, std::vector<const OpNode*>> parallel_desc2interface_op_nodes;\n  for (const auto* op_node : interface_op_nodes) {\n    parallel_desc2interface_op_nodes[op_node->parallel_desc()].push_back(op_node);\n  }\n  for (const auto& pair : parallel_desc2interface_op_nodes) {\n    const auto& parallel_conf = pair.first.parallel_conf();\n    for (const auto* op_node : pair.second) {\n      OperatorConf interface_op(op_node->op().op_conf());\n      {\n        OperatorConf device_tick_op;\n        device_tick_op.set_name(\"System-EagerCriticalSection-Interface-Begin-Tick-\"\n                                + NewUniqueId());\n        auto* device_tick_op_conf = device_tick_op.mutable_device_tick_conf();\n        device_tick_op_conf->set_out(\"out\");\n        interface_src_tick_op_names->push_back(device_tick_op.name());\n        JUST(job_builder->AddOp(parallel_conf, device_tick_op));\n        interface_op.add_ctrl_in_op_name(device_tick_op.name());\n        JUST(job_builder->MutOpOnlyOnce(interface_op));\n      }\n      {\n        OperatorConf device_tick_op;\n        device_tick_op.set_name(\"System-EagerCriticalSection-Interface-End-Tick-\" + NewUniqueId());\n        device_tick_op.add_ctrl_in_op_name(interface_op.name());\n        auto* device_tick_op_conf = device_tick_op.mutable_device_tick_conf();\n        device_tick_op_conf->set_out(\"out\");\n        interface_dst_tick_lbns->push_back(device_tick_op.name() + \"/out\");\n        JUST(job_builder->AddOp(parallel_conf, device_tick_op));\n      }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InsertSrcSubsetTickAndDstSubsetTick(\n    const std::vector<std::string>& interface_src_tick_op_names,\n    const std::vector<std::string>& interface_dst_tick_lbns, JobBuilder* job_builder,\n    std::string* src_subset_tick_op_name, LogicalBlobId* dst_subset_tick_lbi) {\n  {\n    OperatorConf src_subset_tick;\n    JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick, job_builder));\n    *src_subset_tick_op_name = src_subset_tick.name();\n  }\n  for (const auto& op_name : 
interface_src_tick_op_names) {\n    OperatorConf op_conf(JUST(job_builder->OpConf4OpName(op_name)));\n    CHECK_OR_RETURN(op_conf.has_device_tick_conf());\n    op_conf.mutable_device_tick_conf()->add_tick(*src_subset_tick_op_name + \"/out\");\n    JUST(job_builder->MutOpOnlyOnce(op_conf));\n  }\n  HashSet<LogicalBlobId> dst_subset_tick_input_lbis;\n  dst_subset_tick_input_lbis.insert(GenLogicalBlobId(*src_subset_tick_op_name + \"/out\"));\n  for (const auto& lbn : interface_dst_tick_lbns) {\n    const auto& lbi = GenLogicalBlobId(lbn);\n    CHECK_OR_RETURN(dst_subset_tick_input_lbis.insert(lbi).second);\n  }\n  {\n    OperatorConf dst_subset_tick_op;\n    JUST(BuildDstSubsetTickOpAndParallelConf(dst_subset_tick_input_lbis, &dst_subset_tick_op,\n                                             job_builder));\n    dst_subset_tick_lbi->set_op_name(dst_subset_tick_op.name());\n    CHECK_OR_RETURN(dst_subset_tick_op.has_dst_subset_tick_conf());\n    dst_subset_tick_lbi->set_blob_name(dst_subset_tick_op.dst_subset_tick_conf().out());\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InsertCriticalSectionWaitTicks(const OpGraph& op_graph, JobBuilder* job_builder,\n                                           const std::string& src_subset_tick_op_name,\n                                           const std::string& wait_buffer_name) {\n  std::vector<const OpNode*> wait_and_send_id_op_nodes;\n  op_graph.ForEachNode([&](OpNode* op_node) {\n    if (!op_node->op().op_conf().has_wait_and_send_ids_conf()) { return; }\n    wait_and_send_id_op_nodes.push_back(op_node);\n  });\n  CHECK_GT_OR_RETURN(wait_and_send_id_op_nodes.size(), 0);\n  OperatorConf src_subset_tick_op(JUST(job_builder->OpConf4OpName(src_subset_tick_op_name)));\n  CHECK_OR_RETURN(src_subset_tick_op.has_src_subset_tick_conf());\n  for (const OpNode* wait_and_send_id_op_node : wait_and_send_id_op_nodes) {\n    LogicalBlobId lbi;\n    lbi.set_op_name(wait_and_send_id_op_node->op().op_name());\n    lbi.set_blob_name(wait_and_send_id_op_node->op().op_conf().wait_and_send_ids_conf().out());\n    OperatorConf critical_section_wait_op;\n    {\n      critical_section_wait_op.set_name(\"System-EagerCriticalSection-Wait-\" + NewUniqueId());\n      auto* conf = critical_section_wait_op.mutable_critical_section_wait_tick_conf();\n      conf->add_tick(GenLogicalBlobName(lbi));\n      conf->set_out(\"out\");\n      conf->set_buffer_name(wait_buffer_name);\n    }\n    const auto& parallel_conf = wait_and_send_id_op_node->parallel_desc().parallel_conf();\n    JUST(job_builder->AddOp(parallel_conf, critical_section_wait_op));\n    src_subset_tick_op.mutable_src_subset_tick_conf()->add_in(critical_section_wait_op.name()\n                                                              + \"/out\");\n  }\n  JUST(job_builder->MutOpOnlyOnce(src_subset_tick_op));\n  return Maybe<void>::Ok();\n}\n\nMaybe<LogicalBlobId> InsertCriticalSectionCallbackTicks(const OpGraph& op_graph,\n                                                        JobBuilder* job_builder,\n                                                        const LogicalBlobId& dst_subset_tick_lbi,\n                                                        const std::string& callback_buffer_name) {\n  OperatorConf critical_section_callback_op;\n  critical_section_callback_op.set_name(\"System-EagerCriticalSection-Callback-\" + NewUniqueId());\n  auto* conf = critical_section_callback_op.mutable_critical_section_callback_tick_conf();\n  conf->add_tick(GenLogicalBlobName(dst_subset_tick_lbi));\n  
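// The callback tick is driven by the dst_subset_tick, so it fires only after the\n  // whole critical section has finished.\n  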
conf->set_out(\"out\");\n  conf->set_buffer_name(callback_buffer_name);\n  const auto& op_name = dst_subset_tick_lbi.op_name();\n  const auto& parallel_conf = JUST(job_builder->ParallelConf4OpName(op_name));\n  JUST(job_builder->AddOp(parallel_conf, critical_section_callback_op));\n  LogicalBlobId critical_section_callback_lbi;\n  critical_section_callback_lbi.set_op_name(critical_section_callback_op.name());\n  critical_section_callback_lbi.set_blob_name(\"out\");\n  return critical_section_callback_lbi;\n}\n\nMaybe<LogicalBlobId> MultiClientAutoCriticalSectionTick(\n    const OpGraph& op_graph, JobBuilder* job_builder,\n    const std::vector<const OpNode*>& interface_op_nodes, const std::string& wait_buffer_name,\n    const std::string& callback_buffer_name) {\n  std::vector<std::string> interface_src_tick_op_names;\n  std::vector<std::string> interface_dst_tick_lbns;\n  JUST(InsertCriticalSectionSrcAndDstTicks(interface_op_nodes, job_builder,\n                                           &interface_src_tick_op_names, &interface_dst_tick_lbns));\n  std::string src_subset_tick_op_name;\n  LogicalBlobId dst_subset_tick_lbi;\n  JUST(InsertSrcSubsetTickAndDstSubsetTick(interface_src_tick_op_names, interface_dst_tick_lbns,\n                                           job_builder, &src_subset_tick_op_name,\n                                           &dst_subset_tick_lbi));\n  JUST(InsertCriticalSectionWaitTicks(op_graph, job_builder, src_subset_tick_op_name,\n                                      wait_buffer_name));\n  const auto& lbi = JUST(InsertCriticalSectionCallbackTicks(\n      op_graph, job_builder, dst_subset_tick_lbi, callback_buffer_name));\n  return lbi;\n}\n\nMaybe<void> ConnectCriticalSectionCallbackToJobSoleDstSubsetTick(\n    const OpGraph& op_graph, JobBuilder* job_builder,\n    const std::vector<std::shared_ptr<LogicalBlobId>>& critical_section_callback_lbis) {\n  const OpNode* dst_subset_tick_op_node = nullptr;\n  JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {\n    if (!op_node->op().op_conf().has_dst_subset_tick_conf()) { return Maybe<void>::Ok(); }\n    CHECK_OR_RETURN(dst_subset_tick_op_node == nullptr);\n    dst_subset_tick_op_node = op_node;\n    return Maybe<void>::Ok();\n  }));\n  CHECK_NOTNULL_OR_RETURN(dst_subset_tick_op_node);\n  OperatorConf dst_subset_tick_op(dst_subset_tick_op_node->op().op_conf());\n  auto* conf = dst_subset_tick_op.mutable_dst_subset_tick_conf();\n  for (const auto& lbi : critical_section_callback_lbis) { conf->add_in(GenLogicalBlobName(*lbi)); }\n  JUST(job_builder->MutOpOnlyOnce(dst_subset_tick_op));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> MultiClientAutoInterfaceCriticalSectionTick(const OpGraph& op_graph, Job* job) {\n  JobBuilder job_builder(job);\n  std::vector<std::shared_ptr<LogicalBlobId>> critical_section_callback_lbis;\n  {\n    std::vector<const OpNode*> interface_op_nodes;\n    op_graph.ForEachNode([&](OpNode* node) {\n      if (node->op().op_conf().has_input_conf()) { interface_op_nodes.push_back(node); }\n    });\n    const auto& lbi = JUST(MultiClientAutoCriticalSectionTick(\n        op_graph, &job_builder, interface_op_nodes,\n        GetInputCriticalSectionWaitBufferName(job->job_conf().job_name()),\n        GetInputCriticalSectionCallbackBufferName(job->job_conf().job_name())));\n    critical_section_callback_lbis.push_back(lbi);\n  }\n  {\n    std::vector<const OpNode*> interface_op_nodes;\n    op_graph.ForEachNode([&](OpNode* node) {\n      if 
(node->op().op_conf().has_output_conf()) { interface_op_nodes.push_back(node); }\n    });\n    const auto& lbi = JUST(MultiClientAutoCriticalSectionTick(\n        op_graph, &job_builder, interface_op_nodes,\n        GetOutputCriticalSectionWaitBufferName(job->job_conf().job_name()),\n        GetOutputCriticalSectionCallbackBufferName(job->job_conf().job_name())));\n    critical_section_callback_lbis.push_back(lbi);\n  }\n  JUST(ConnectCriticalSectionCallbackToJobSoleDstSubsetTick(op_graph, &job_builder,\n                                                            critical_section_callback_lbis));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/autotick.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_REWRITER_AUTOTICK_H_\n#define ONEFLOW_CORE_JOB_REWRITER_AUTOTICK_H_\n\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n\nnamespace oneflow {\n\nMaybe<void> AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder);\nMaybe<void> AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder);\nMaybe<void> MultiClientAutoSourceAndSinkTick(const OpGraph& op_graph, Job* job);\nMaybe<void> MultiClientAutoInterfaceCriticalSectionTick(const OpGraph& op_graph, Job* job);\n\nclass MutOpConTickInputHelper {\n public:\n  bool IsTickInputBound() const { return VirtualIsTickInputBound(); }\n  virtual bool VirtualIsTickInputBound() const = 0;\n  virtual OperatorConf NewTickInputBoundOpConf(const std::string& lbn) const = 0;\n  void InitFromOpConf(const OperatorConf& op_conf) { op_conf_ = &op_conf; }\n  virtual ~MutOpConTickInputHelper() = default;\n\n protected:\n  MutOpConTickInputHelper() : op_conf_(nullptr) {}\n  const OperatorConf& op_conf() const { return *op_conf_; }\n\n private:\n  const OperatorConf* op_conf_;\n};\n\n#define REGISTER_AUTO_TICK(op_type_case, HelperType) \\\n  REGISTER_CLASS(int32_t, op_type_case, MutOpConTickInputHelper, HelperType)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_REWRITER_AUTOTICK_H_\n"
  },
  {
    "path": "oneflow/core/job_rewriter/boxing_with_middle_nodes.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/boxing_with_middle_nodes.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/auto_parallel/boxing_collector.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\n\nMaybe<void> BoxingWithMiddleNodes(const OpGraph& op_graph, JobBuilder* job_builder) {\n  // Not allowed two-step boxing and disable checking for debugging\n  if (ParseBooleanFromEnv(\"ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK\", false)) {\n    return Maybe<void>::Ok();\n  }\n  // Initialize boxing collector\n  BoxingCollector boxing_collector;\n  std::vector<NdSbp> middle_sbps;\n  HashMap<const OpNode*, OperatorConf> op_node2op_conf;\n  // Fill other unsupported combinations\n  op_graph.ForEachNode([&](const OpNode* node) -> Maybe<void> {\n    OperatorConf::OpTypeCase op_type_case = node->op().op_conf().op_type_case();\n    if (IsClassRegistered<int32_t, DisableInputBoxingGroup>(op_type_case)) {\n      return Maybe<void>::Ok();\n    }\n    for (const std::string& ibn : node->op().input_bns()) {\n      const LogicalBlobId& lbi = node->op().BnInOp2Lbi(ibn);\n      const OpNode& producer = node->ProducerOpNode4Lbi(lbi);\n      const NdSbp& producer_nd_sbp = producer.NdSbp4Lbi(lbi);\n      const NdSbp& consumer_nd_sbp = node->NdSbp4BnInOp(ibn);\n\n      // If dealing with different placement\n      if (producer.parallel_desc().parallel_num() != 1\n          || node->parallel_desc().parallel_num() != 1) {\n        const auto& logical_blob_desc = producer.LogicalBlobDesc4Lbi(lbi);\n        // Ask for middle nodes\n        int32_t diag_node = 0;\n        JUST(boxing_collector.AskSbpCombination(producer_nd_sbp, consumer_nd_sbp, logical_blob_desc,\n                                                producer.parallel_desc(), node->parallel_desc(),\n                                                /*is_customized=*/false, middle_sbps, &diag_node,\n                                                /*compute_cost=*/false));\n        // move to the next ibn if no middle nodes needed\n        if (middle_sbps.size() <= 0) { continue; }\n        LogicalBlobId middle_node_lbi = lbi;\n        VLOG(3) << \" Lbi \" << lbi.op_name() << \"/\" << lbi.blob_name() << \" src sbp \"\n                << NdSbpToString(producer_nd_sbp);\n        VLOG(3) << \" Lbi \" << lbi.op_name() << \"/\" << lbi.blob_name() << \" dst sbp \"\n                << NdSbpToString(consumer_nd_sbp);\n        for (int32_t middle_node_id = 0; middle_node_id < middle_sbps.size(); middle_node_id++) {\n          VLOG(3) << \" Lbi \" << lbi.op_name() << \"/\" << lbi.blob_name() << \" add middle node \"\n                  << NdSbpToString(JUST(VectorAt(middle_sbps, 
middle_node_id)));\n          // Create the middle operators\n          OperatorConf identity_op_conf{};\n          identity_op_conf.set_name(\"System-Boxing-Middle-Identity-\" + NewUniqueId());\n          IdentityOpConf* identity_conf = identity_op_conf.mutable_identity_conf();\n          identity_conf->set_in(GenLogicalBlobName(middle_node_lbi));\n          identity_conf->set_out(\"out\");\n          if (middle_node_id < diag_node) {\n            job_builder->AddOps(producer.parallel_desc().parallel_conf(), {identity_op_conf});\n          } else {\n            job_builder->AddOps(node->parallel_desc().parallel_conf(), {identity_op_conf});\n          }\n          NdSbpSignature identity_nd_sbp_signature;\n          (*identity_nd_sbp_signature.mutable_bn_in_op2nd_sbp())[\"in\"] =\n              middle_sbps[middle_node_id];\n          (*identity_nd_sbp_signature.mutable_bn_in_op2nd_sbp())[\"out\"] =\n              middle_sbps[middle_node_id];\n          job_builder->AddNdSbpSignature4OpName(identity_op_conf.name(), identity_nd_sbp_signature);\n          // Connection for the next middle node\n          middle_node_lbi.set_op_name(identity_op_conf.name());\n          middle_node_lbi.set_blob_name(identity_conf->out());\n        }\n        // Replace input blob with configuration from middle nodes\n        if (op_node2op_conf.find(node) == op_node2op_conf.end()) {\n          op_node2op_conf[node] = node->op().op_conf();\n        }\n        OperatorConf& consumer_op_conf = op_node2op_conf[node];\n        const auto& old_val = ReplaceInputLbnInOpCustomizedConf(\n            &consumer_op_conf, ibn, GenLogicalBlobName(middle_node_lbi));\n        CHECK_EQ_OR_RETURN(GenLogicalBlobName(lbi), old_val);\n      }\n    }\n\n    return Maybe<void>::Ok();\n  });\n  for (const auto& op_node7op_conf : op_node2op_conf) {\n    JUST(job_builder->MutOpOnlyOnce(op_node7op_conf.second));\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/boxing_with_middle_nodes.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_REWRITER_BOXING_WITH_MIDDLE_NODES_H_\n#define ONEFLOW_CORE_JOB_REWRITER_BOXING_WITH_MIDDLE_NODES_H_\n\n#include \"oneflow/core/graph/op_graph.h\"\n\nnamespace oneflow {\n\nclass OpGraph;\nclass Job;\n\nMaybe<void> BoxingWithMiddleNodes(const OpGraph& op_graph, JobBuilder* job_builder);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_REWRITER_BOXING_WITH_MIDDLE_NODES_H_\n"
  },
  {
    "path": "oneflow/core/job_rewriter/calculation_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/calculation_pass.h\"\n\nnamespace oneflow {\n\nconst std::string kForwardPass = \"forward_pass\";\nconst std::string kBackwardPass = \"backward_pass\";\nconst std::string kOptimizerPass = \"optimizer_pass\";\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/calculation_pass.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_REWRITE_CALCULATION_PASS_H_\n#define ONEFLOW_CORE_JOB_REWRITE_CALCULATION_PASS_H_\n\n#include <string>\n\nnamespace oneflow {\n\nextern const std::string kForwardPass;\nextern const std::string kBackwardPass;\nextern const std::string kOptimizerPass;\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_REWRITE_CALCULATION_PASS_H_\n"
  },
  {
    "path": "oneflow/core/job_rewriter/checkpointing_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/job_rewriter/calculation_pass.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n// Do CheckpointingPass will use backward recomputation for sublinear memory cost.\nclass CheckpointingPass final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CheckpointingPass);\n  CheckpointingPass() = default;\n  ~CheckpointingPass() = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n\n  bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().IsTrain(); }\n\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n};\n\nconst std::string kCheckpointingFakeOpNamePrefix = \"Sys-Checkpointing-Fake-Fw-Op_\";\nconst std::string kCheckpointingIdentityOpName = \"Sys-Checkpointing-Identity\";\nconst std::string kCheckpointingBadOpName = \"Sys-CheckpointPassBadEndOpName\";\n\nconst Scope& Scope4OpNode(const OpNode* op_node) {\n  int64_t scope_symbol_id = op_node->op().op_conf().scope_symbol_id();\n  CHECK(Singleton<symbol::Storage<Scope>>::Get()->Has(scope_symbol_id))\n      << \"rank[\" << GlobalProcessCtx::Rank() << \"] \"\n      << \"scope_symbol_id: \" << scope_symbol_id;\n  return Singleton<symbol::Storage<Scope>>::Get()->Get(scope_symbol_id);\n}\n\nbool IsForwardPassScope(const Scope& scope) {\n  return scope.scope_proto().calculation_pass_name() == kForwardPass;\n}\n\nbool IsForwardPass7CheckpointingScope(const Scope& scope) {\n  return IsForwardPassScope(scope) && scope.Bool(\"checkpointing\");\n}\n\nvoid CollectAllCheckpointingOpsInForwardPass(\n    const OpGraph& op_graph, HashMap<std::string, const OpNode*>* checkpointing_op_name2op_node) {\n  // NOTE(chengcheng):\n  //   ignore batch_norm ops because of recompute bn will repeat the calculation of 'm' and 'v'.\n  //   in the future, we need to support the recomputation version of batch_norm which do NOT\n  //   update forward variables.\n  HashSet<std::string> ignore_op_type_names = {\"normalization\", \"normalization_add_relu\",\n                                               \"cudnn_fused_normalization_add_relu\", \"repeat\",\n                                               \"unpack\"};\n  op_graph.ForEachNode([&](const OpNode* op_node) {\n    const OperatorConf& op_conf = op_node->op().op_conf();\n    if (!op_conf.has_user_conf()) { return; }\n    if 
(ignore_op_type_names.find(op_conf.user_conf().op_type_name())\n        != ignore_op_type_names.end()) {\n      return;\n    }\n    if (IsForwardPass7CheckpointingScope(Scope4OpNode(op_node))) {\n      CHECK(checkpointing_op_name2op_node->emplace(op_conf.name(), op_node).second);\n    }\n  });\n}\n\nvoid GenConnectedCheckpointingSubgraphs(\n    const HashMap<std::string, const OpNode*>& checkpointing_op_name2op_node,\n    std::vector<HashSet<const OpNode*>>* checkpointing_subgraphs) {\n  HashSet<const OpNode*> visited_nodes;\n  checkpointing_subgraphs->reserve(checkpointing_op_name2op_node.size());\n  for (const auto& pair : checkpointing_op_name2op_node) {\n    const OpNode* node = pair.second;\n    if (visited_nodes.find(node) != visited_nodes.end()) { continue; }\n\n    // new subgraph\n    checkpointing_subgraphs->emplace_back(HashSet<const OpNode*>());\n    CHECK(!checkpointing_subgraphs->empty());\n    auto& subgraph = checkpointing_subgraphs->back();\n    CHECK(subgraph.empty());\n\n    // BFS over all nodes within the checkpointing ops\n    CHECK(visited_nodes.insert(node).second);\n    std::queue<const OpNode*> queued_nodes;\n    queued_nodes.push(node);\n    while (!queued_nodes.empty()) {\n      const OpNode* cur_node = queued_nodes.front();\n      queued_nodes.pop();\n\n      CHECK(subgraph.insert(cur_node).second);\n\n      cur_node->ForEachNodeOnInOutEdge([&](const OpNode* next_node) {\n        const std::string& next_op_name = next_node->op().op_name();\n        if (checkpointing_op_name2op_node.find(next_op_name) != checkpointing_op_name2op_node.end()\n            && cur_node->parallel_desc() == next_node->parallel_desc()\n            && visited_nodes.find(next_node) == visited_nodes.end()) {\n          queued_nodes.push(next_node);\n          CHECK(visited_nodes.insert(next_node).second);\n        }\n      });\n    }\n  }\n}\n\nMaybe<void> CheckpointingPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {\n  // step 1. collect all checkpointing ops in the forward pass.\n  HashMap<std::string, const OpNode*> checkpointing_op_name2op_node;\n  CollectAllCheckpointingOpsInForwardPass(op_graph, &checkpointing_op_name2op_node);\n  if (checkpointing_op_name2op_node.empty()) { return Maybe<void>::Ok(); }\n\n  // step 2. get all connected subgraphs in checkpointing ops.\n  std::vector<HashSet<const OpNode*>> checkpointing_subgraphs;\n  GenConnectedCheckpointingSubgraphs(checkpointing_op_name2op_node, &checkpointing_subgraphs);\n\n  HashMap<const OpNode*, int32_t> op_node2order;\n  int32_t order = 0;\n  op_graph.TopoForEachNode([&](const OpNode* op_node) {\n    CHECK(op_node2order.emplace(op_node, order).second);\n    ++order;\n  });\n\n  // step 3. for each subgraph:\n\n  // NOTE(chengcheng):\n  //   a bw consumer may consume multiple subgraphs for recompute,\n  //   so we collect bw consumers across subgraphs and update them in the job builder only once.\n  HashMap<std::string, OperatorConf> total_bw_consumers_op_name2conf;\n\n  int32_t subgraph_id = 0;\n  for (auto& subgraph : checkpointing_subgraphs) {\n    // step 3.1 ignore this subgraph if there is no direct edge to a backward pass op.\n    HashSet<const OpNode*> bw_consumers;\n    for (const OpNode* node : subgraph) {\n      node->ForEachNodeOnOutEdge([&](const OpNode* out_node) {\n        if (!IsForwardPassScope(Scope4OpNode(out_node))) {\n          bw_consumers.insert(out_node);\n          CHECK(subgraph.find(out_node) == subgraph.end());\n        }\n      });\n    }\n    if (bw_consumers.empty()) { continue; }\n\n    HashSet<LogicalBlobId> checkpointing_tensor;\n\n    HashMap<std::string, const OpNode*> subgraph_op_name2op_node;\n    ParallelConf parallel_conf;\n    for (const OpNode* node : subgraph) {\n      subgraph_op_name2op_node.emplace(node->op().op_name(), node);\n      parallel_conf = node->parallel_desc().parallel_conf();\n    }\n\n    // step 3.2 generate fake subgraph for recomputation\n    HashMap<std::string, OperatorConf> fake_op_name2conf;\n    HashSet<std::string> source_node_in_fake_subgraph;\n    for (const OpNode* node : subgraph) {\n      OperatorConf fake_op_conf = node->op().op_conf();\n      std::string fake_op_name = kCheckpointingFakeOpNamePrefix + fake_op_conf.name();\n      fake_op_conf.set_name(fake_op_name);\n      const int64_t old_scope_symbol_id = fake_op_conf.scope_symbol_id();\n      // update fake op conf scope from fw to bw\n      const int64_t new_scope_symbol_id = JUST(\n          NewScopeSymbolId(old_scope_symbol_id, [](const std::shared_ptr<ScopeProto>& new_scope) {\n            CHECK_EQ(new_scope->calculation_pass_name(), kForwardPass);\n            new_scope->set_calculation_pass_name(kBackwardPass);\n          }));\n      fake_op_conf.set_scope_symbol_id(new_scope_symbol_id);\n\n      auto* user_conf = fake_op_conf.mutable_user_conf();\n      // change output lbns\n      for (auto& pair : *(user_conf->mutable_output())) {\n        auto& list_s = pair.second;\n        for (int i = 0; i < list_s.s_size(); ++i) {\n          std::string old_lbn = list_s.s(i);\n          list_s.set_s(i, kCheckpointingFakeOpNamePrefix + old_lbn);\n          // check valid\n          LogicalBlobId old_lbi = GenLogicalBlobId(old_lbn);\n          CHECK_EQ(node->op().op_conf().name(), old_lbi.op_name());\n          CHECK_EQ(kCheckpointingFakeOpNamePrefix + old_lbi.op_name(), fake_op_name);\n          std::string new_lbn = list_s.s(i);\n          LogicalBlobId new_lbi = GenLogicalBlobId(new_lbn);\n          CHECK_EQ(new_lbi.op_name(), fake_op_name);\n          CHECK_EQ(old_lbi.blob_name(), new_lbi.blob_name());\n        }\n      }\n\n      int32_t input_num = 0;\n      // change input lbns if in subgraph\n      for (auto& pair : *(user_conf->mutable_input())) {\n        auto& list_s = pair.second;\n        for (int i = 0; i < list_s.s_size(); ++i) {\n          ++input_num;\n          std::string old_lbn = list_s.s(i);\n          LogicalBlobId old_lbi = GenLogicalBlobId(old_lbn);\n\n          std::string old_input_op_name = old_lbi.op_name();\n          if (subgraph_op_name2op_node.find(old_input_op_name) != subgraph_op_name2op_node.end()) {\n            list_s.set_s(i, kCheckpointingFakeOpNamePrefix + old_lbn);\n          } else {\n            
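// The producer lies outside the subgraph: keep the original input lbn, mark this\n            // fake op as a source of the recompute graph, and record the external blob as a\n            // checkpointed tensor.\n            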
source_node_in_fake_subgraph.insert(fake_op_name);\n            checkpointing_tensor.insert(old_lbi);\n          }\n        }\n      }\n      if (input_num == 0) { source_node_in_fake_subgraph.insert(fake_op_name); }\n\n      fake_op_name2conf.emplace(fake_op_name, fake_op_conf);\n    }\n\n    const OpNode* first_bw_consumer = nullptr;\n    int32_t first_bw_order = std::numeric_limits<int32_t>::max();\n    // step 3.3 change bw consumers' inputs from the subgraph to the fake subgraph\n    for (const OpNode* node : bw_consumers) {\n      std::string bw_consumer_name = node->op().op_name();\n      OperatorConf bw_consumer_op_conf;\n      // NOTE(chengcheng):\n      //   reuse the bw consumer op conf if it already exists in the map.\n      if (total_bw_consumers_op_name2conf.find(bw_consumer_name)\n          != total_bw_consumers_op_name2conf.end()) {\n        bw_consumer_op_conf = total_bw_consumers_op_name2conf.at(bw_consumer_name);\n      } else {\n        bw_consumer_op_conf = node->op().op_conf();\n      }\n      CHECK_EQ(bw_consumer_name, bw_consumer_op_conf.name());\n\n      auto* user_conf = bw_consumer_op_conf.mutable_user_conf();\n      // change input lbns if in subgraph\n      for (auto& pair : *(user_conf->mutable_input())) {\n        auto& list_s = pair.second;\n        for (int i = 0; i < list_s.s_size(); ++i) {\n          std::string old_lbn = list_s.s(i);\n          LogicalBlobId old_lbi = GenLogicalBlobId(old_lbn);\n\n          std::string old_input_op_name = old_lbi.op_name();\n          if (subgraph_op_name2op_node.find(old_input_op_name) != subgraph_op_name2op_node.end()) {\n            list_s.set_s(i, kCheckpointingFakeOpNamePrefix + old_lbn);\n          }\n        }\n      }\n\n      // NOTE(chengcheng):\n      //   the emplace may be repeated, so do not check the return value\n      total_bw_consumers_op_name2conf.emplace(bw_consumer_name, bw_consumer_op_conf);\n\n      CHECK(op_node2order.find(node) != op_node2order.end());\n      int32_t this_order = op_node2order.at(node);\n      if (this_order < first_bw_order) {\n        first_bw_consumer = node;\n        first_bw_order = this_order;\n      }\n    }\n\n    // step 3.4 add control edges from the End Op to all source nodes in the fake subgraph\n    CHECK(first_bw_consumer != nullptr);\n    std::string end_op_name = kCheckpointingBadOpName;\n    int32_t end_order = -1;\n    const OpNode* end_op_node = nullptr;\n    first_bw_consumer->ForEachNodeOnInEdge([&](const OpNode* end_node) {\n      CHECK(op_node2order.find(end_node) != op_node2order.end());\n      int32_t this_order = op_node2order.at(end_node);\n      if (this_order > end_order) {\n        end_order = this_order;\n        end_op_name = end_node->op().op_name();\n        end_op_node = end_node;\n      }\n    });\n    CHECK_NE(end_order, -1);\n    CHECK_NE(end_op_name, kCheckpointingBadOpName);\n    CHECK_LT(end_order, first_bw_order);\n    CHECK(end_op_node != nullptr);\n    // NOTE(chengcheng): if the end_op placement differs from first_bw_consumer's, the ctrl edge\n    //   cannot be directly connected.\n    if (!first_bw_consumer->parallel_desc().EqualsIgnoringHierarchy(end_op_node->parallel_desc())) {\n      std::string lbn = \"\";\n      LogicalBlobId lbi;\n      const OpEdge* end_op_edge = nullptr;\n      for (const OpEdge* in_edge : first_bw_consumer->in_edges()) {\n        if (in_edge->src_node() == end_op_node) {\n          lbi = in_edge->lbis().front();\n          lbn = GenLogicalBlobName(lbi);\n          end_op_edge = in_edge;\n          break;\n        }\n      }\n      
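// Route the data edge through an identity op placed on first_bw_consumer so that the\n      // ctrl edge added below stays within a single placement.\n      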
CHECK(!lbn.empty());\n\n      auto id_op = user_op::UserOpConfWrapperBuilder(kCheckpointingIdentityOpName + NewUniqueId())\n                       .Op(\"identity\")\n                       .Input(\"in\", lbn)\n                       .Output(\"out\")\n                       .ScopeSymbolId(first_bw_consumer->op().op_conf().scope_symbol_id())\n                       .Build();\n\n      std::string id_out = id_op.output(\"out\", 0);\n      for (const std::string& ibn : end_op_edge->lbi2ibns().at(lbi)) {\n        std::string old_lbn = ReplaceInputLbnInOpCustomizedConf(\n            &(total_bw_consumers_op_name2conf.at(first_bw_consumer->op().op_name())), ibn, id_out);\n        CHECK_EQ(old_lbn, lbn);\n      }\n\n      JUST(job_builder->AddOp(first_bw_consumer->parallel_desc().parallel_conf(), id_op.op_conf()));\n      end_op_name = id_op.op_name();\n    }\n    for (const auto& source_op_name : source_node_in_fake_subgraph) {\n      fake_op_name2conf.at(source_op_name).add_ctrl_in_op_name(end_op_name);\n    }\n\n    // step 3.5 add fake subgraph ops to job builder\n    std::vector<OperatorConf> fake_op_confs;\n    for (auto& pair : fake_op_name2conf) { fake_op_confs.emplace_back(pair.second); }\n    job_builder->AddOps(parallel_conf, fake_op_confs);\n\n    // step 3.6 log the checkpointing tensors for debugging.\n    if (IsInDebugMode()) {\n      VLOG(2) << \" In subgraph: \" << subgraph_id\n              << \" has checkpointing tensor num = \" << checkpointing_tensor.size();\n      for (const auto& lbi : checkpointing_tensor) {\n        const OpNode* node = op_graph.OpNode4OpName(lbi.op_name());\n        const BlobDesc& blob = node->LogicalBlobDesc4Lbi(lbi);\n        VLOG(2) << \"Checkpointing tensor: \" << GenLogicalBlobName(lbi)\n                << \" ,shape: \" << blob.shape().ToString()\n                << \" ,dtype: \" << DataType_Name(blob.data_type())\n                << \" ,placement: \" << *JUST(PlacementToString(SymbolOf(node->parallel_desc())))\n                << \" ,sbp: \" << NdSbpToString(node->NdSbp4Lbi(lbi));\n      }\n      subgraph_id++;\n    }\n  }\n\n  // step 4. update bw consumers in job builder only once\n  std::vector<OperatorConf> total_bw_consumer_op_confs;\n  for (auto& pair : total_bw_consumers_op_name2conf) {\n    total_bw_consumer_op_confs.emplace_back(pair.second);\n  }\n  job_builder->MutOpsOnlyOnce(total_bw_consumer_op_confs);\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"CheckpointingPass\", CheckpointingPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/clip_by_global_norm_job_pass_state.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_REWRITER_CLIP_BY_GLOBAL_NORM_JOB_PASS_STATE_H_\n#define ONEFLOW_CORE_JOB_REWRITER_CLIP_BY_GLOBAL_NORM_JOB_PASS_STATE_H_\n\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n\nnamespace oneflow {\n\nclass ClipByGlobalNormJobPassState : public JobPassState {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ClipByGlobalNormJobPassState);\n  ClipByGlobalNormJobPassState() = default;\n  ~ClipByGlobalNormJobPassState() override = default;\n\n  class TotalNormState {\n   public:\n    TotalNormState(const std::string& total_norm_lbn, const std::string& coeff_lbn,\n                   const ParallelConf& parallel_conf, int64_t scope_symbol_id)\n        : total_norm_lbn_(total_norm_lbn),\n          coeff_lbn_(coeff_lbn),\n          parallel_conf_(parallel_conf),\n          scope_symbol_id_(scope_symbol_id) {}\n\n    void set_total_norm_lbn(const std::string& total_norm_lbn) { total_norm_lbn_ = total_norm_lbn; }\n    const std::string& total_norm_lbn() const { return total_norm_lbn_; }\n    const std::string& coeff_lbn() const { return coeff_lbn_; }\n    const ParallelConf& parallel_conf() const { return parallel_conf_; }\n    int64_t scope_symbol_id() const { return scope_symbol_id_; }\n\n   private:\n    std::string total_norm_lbn_;\n    std::string coeff_lbn_;\n    ParallelConf parallel_conf_;\n    int64_t scope_symbol_id_;\n  };\n\n  void AddTotalNormState(const std::string& variable_op_name,\n                         const std::shared_ptr<TotalNormState>& total_norm_state) {\n    CHECK(variable_op_name2total_norm_state_.emplace(variable_op_name, total_norm_state).second)\n        << variable_op_name;\n  }\n\n  const std::shared_ptr<TotalNormState>& GetTotalNormState(const std::string& variable_op_name) {\n    const auto& it = variable_op_name2total_norm_state_.find(variable_op_name);\n    CHECK(it != variable_op_name2total_norm_state_.end());\n    return it->second;\n  }\n\n  const bool HasTotalNormState(const std::string& variable_op_name) {\n    const auto& it = variable_op_name2total_norm_state_.find(variable_op_name);\n    return (it != variable_op_name2total_norm_state_.end());\n  }\n\n private:\n  HashMap<std::string, std::shared_ptr<TotalNormState>> variable_op_name2total_norm_state_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_REWRITER_CLIP_BY_GLOBAL_NORM_JOB_PASS_STATE_H_\n"
  },
  {
    "path": "oneflow/core/job_rewriter/clone_grad.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/clone_grad.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nMaybe<void> GenerateCloneGradOpIfNeed(\n    const OpNode& op_node, JobBuilder* job_builder,\n    const HashMap<OpBlobArg, LogicalBlobId>& in_oba2in_diff_lbi,\n    HashMap<OpBlobArg, LogicalBlobId>* out_oba2out_diff_lbi,\n    HashMap<OpBlobArg, LogicalBlobId>* out_oba2clone_bw_add_out_lbi) {\n  HashMap<LogicalBlobId, OpBlobArg> out_lbi2out_oba;\n  for (const auto& obn : op_node.op().output_bns()) {\n    out_lbi2out_oba[op_node.op().BnInOp2Lbi(obn)] = GenOpBlobArg(op_node.op().op_name(), obn);\n  }\n  HashMap<OpBlobArg, std::vector<LogicalBlobId>> out_oba2in_diff_lbis;\n  op_node.ForEachNodeOnOutEdge([&](OpNode* out_node) {\n    for (const auto& ibn : out_node->op().input_bns()) {\n      const auto& oba_it = out_lbi2out_oba.find(out_node->op().BnInOp2Lbi(ibn));\n      if (oba_it == out_lbi2out_oba.end()) { continue; }\n      const auto& in_diff_lbi_it =\n          in_oba2in_diff_lbi.find(GenOpBlobArg(out_node->op().op_name(), ibn));\n      if (in_diff_lbi_it == in_oba2in_diff_lbi.end()) { continue; }\n      out_oba2in_diff_lbis[oba_it->second].emplace_back(in_diff_lbi_it->second);\n    }\n  });\n  for (const auto& obn : op_node.op().output_bns()) {\n    const OpBlobArg& oba = GenOpBlobArg(op_node.op().op_name(), obn);\n    const LogicalBlobId& lbi = op_node.op().BnInOp2Lbi(obn);\n    const std::vector<LogicalBlobId>& lbis_to_add = out_oba2in_diff_lbis[oba];\n    if (lbis_to_add.empty()) {\n      continue;\n    } else if (lbis_to_add.size() == 1) {\n      out_oba2out_diff_lbi->emplace(oba, lbis_to_add.front());\n    } else {\n      user_op::UserOpConfWrapperBuilder add_op_builder(op_node.op().op_name() + \"_clone_grad_\"\n                                                       + NewUniqueId());\n      add_op_builder.Op(\"add_n\");\n      for (const LogicalBlobId& lbi_to_add : lbis_to_add) {\n        add_op_builder.Input(\"in\", GenLogicalBlobName(lbi_to_add));\n      }\n      const auto& op_conf = JUST(job_builder->OpConf4OpName(lbi.op_name()));\n      const auto add_op =\n          add_op_builder.Output(\"out\").ScopeSymbolId(op_conf.scope_symbol_id()).Build();\n      job_builder->AddOps(JUST(job_builder->ParallelConf4Lbi(lbi)), {add_op.op_conf()});\n      CHECK(out_oba2clone_bw_add_out_lbi->emplace(oba, lbis_to_add.front()).second);\n      out_oba2out_diff_lbi->emplace(oba, GenLogicalBlobId(add_op.output(\"out\", 0)));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/clone_grad.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_REWRITER_CLONE_GRAD_H_\n#define ONEFLOW_CORE_JOB_REWRITER_CLONE_GRAD_H_\n\n#include \"oneflow/core/job_rewriter/autograd.h\"\n\nnamespace oneflow {\n\nMaybe<void> GenerateCloneGradOpIfNeed(\n    const OpNode& op_node, JobBuilder* job_builder,\n    const HashMap<OpBlobArg, LogicalBlobId>& in_oba2in_diff_lbi,\n    HashMap<OpBlobArg, LogicalBlobId>* out_oba2out_diff_lbi,\n    HashMap<OpBlobArg, LogicalBlobId>* out_oba2clone_bw_add_out_lbi);\n}\n\n#endif  // ONEFLOW_CORE_JOB_REWRITER_CLONE_GRAD_H_\n"
  },
  {
    "path": "oneflow/core/job_rewriter/cudnn_fused_normalization_add_relu_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/framework.h\"\n\n#ifdef WITH_CUDA\n#include <cudnn.h>\n#endif  // WITH_CUDA\n\nnamespace oneflow {\n\nnamespace {\n\nbool IsFusedBnAddReluSupported() {\n#if defined(WITH_CUDA) && (CUDNN_VERSION >= 7401)\n  return true;\n#else\n  return false;\n#endif\n}\n\nbool IsNormalizationAddReluOp(const OperatorConf& op) {\n  return op.has_user_conf()\n         && (op.user_conf().op_type_name() == \"normalization_add_relu\"\n             || op.user_conf().op_type_name() == \"normalization_add_relu_grad\");\n}\n\nbool NeedDoPass(const Job& job) {\n  return std::any_of(job.net().op().cbegin(), job.net().op().cend(), IsNormalizationAddReluOp);\n}\n\n}  // namespace\n\nclass CudnnFusedNormalizationAddReluPass final : public JobPass {\n public:\n  CudnnFusedNormalizationAddReluPass() = default;\n  ~CudnnFusedNormalizationAddReluPass() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    if (ctx.job_desc().job_conf().has_enable_cudnn_fused_normalization_add_relu()) {\n      bool enabled = ctx.job_desc().job_conf().enable_cudnn_fused_normalization_add_relu();\n      CHECK(!enabled || IsFusedBnAddReluSupported())\n          << \"Option 'enable_cudnn_fused_normalization_add_relu' is only supported when cuDNN \"\n             \"version >= 7.4.1\";\n      return enabled;\n    } else {\n      return IsFusedBnAddReluSupported();\n    }\n  }\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;\n};\n\nMaybe<void> CudnnFusedNormalizationAddReluPass::Apply(Job* job, JobPassCtx* ctx) const {\n  if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n  if (!NeedDoPass(*job)) { return Maybe<void>::Ok(); }\n  const OpGraph op_graph(*job);\n  JobBuilder job_builder(job);\n  const DataType mixed_precision_data_type = ctx->job_desc().mixed_precision_data_type();\n  op_graph.ForEachNode([&](const OpNode* op_node) {\n    const OperatorConf& op_conf = op_node->op().op_conf();\n    if (!IsNormalizationAddReluOp(op_conf)) { return; }\n    const std::string& op_type_name = op_conf.user_conf().op_type_name();\n    const user_op::UserOpConfWrapper user_op_conf(op_conf);\n    const BlobDesc& x_desc =\n        op_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(user_op_conf.input(\"x\", 0)));\n    const int32_t axis = user_op_conf.attr<int32_t>(\"axis\");\n    if (x_desc.data_type() != mixed_precision_data_type) { return; }\n    const Shape& x_shape = x_desc.shape();\n    if (x_shape.Count(axis + 1) != 1) { return; }\n    if (x_shape.At(axis) % 4 != 0) { return; }\n    OperatorConf new_op_conf = op_conf;\n    auto mute_attrs = new_op_conf.mutable_user_conf()->mutable_attr();\n    auto training_it = mute_attrs->find(\"training\");\n    if (training_it != mute_attrs->end()) {\n      const bool training = user_op_conf.attr<bool>(\"training\");\n      if (!training) { return; }\n      mute_attrs->erase(training_it);\n   
 }\n    new_op_conf.mutable_user_conf()->set_op_type_name(\"cudnn_fused_\" + op_type_name);\n    job_builder.MutOpsOnlyOnce({new_op_conf});\n  });\n  return Maybe<void>::Ok();\n}\n\nREGISTER_JOB_PASS(\"CudnnFusedNormalizationAddReluPass\", CudnnFusedNormalizationAddReluPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/cutlass_conv_tuning_warmup_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifdef WITH_CUTLASS\n\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/user/kernels/cutlass_conv_tuner.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include <nlohmann/json.hpp>\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr size_t kMaxWorkspaceSize = 128 * 1024 * 1024;   // 128MB\nconstexpr size_t kBufferMallocAlign = 128 * 1024 * 1024;  // 128MB\n\nclass CutlassConvTuningWarmupPass final : public JobPass {\n public:\n  CutlassConvTuningWarmupPass() = default;\n  ~CutlassConvTuningWarmupPass() override = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;\n};\n\nMaybe<void> CutlassConvTuningWarmupPass::Apply(Job* job, JobPassCtx* ctx) const {\n  // Compatible with typo `KERENL`\n  if (!ParseBooleanFromEnv(\"ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL\",\n                           ParseBooleanFromEnv(\"ONEFLOW_KERENL_CONV_ENABLE_CUTLASS_IMPL\", false))) {\n    return Maybe<void>::Ok();\n  }\n  if (!ParseBooleanFromEnv(\n          \"ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP\",\n          ParseBooleanFromEnv(\"ONEFLOW_KERENL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP\", false))) {\n    return Maybe<void>::Ok();\n  }\n  const OpGraph op_graph(*job);\n  JobBuilder job_builder(job);\n\n  auto device = Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, 0);\n  ep::Stream* stream = device->CreateStream();\n  void* workspace = nullptr;\n  char* buffer = nullptr;\n  size_t buffer_size = 0;\n  OF_CUDA_CHECK(cudaMalloc(&workspace, kMaxWorkspaceSize));\n  std::vector<OperatorConf> op_confs;\n  op_graph.ForEachNode([&](const OpNode* node) {\n    const OperatorConf& op_conf = node->op().op_conf();\n    if (!op_conf.has_user_conf()) { return; }\n    if (op_conf.user_conf().op_type_name() != \"conv2d\") { return; }\n    if (node->parallel_desc().device_type() != DeviceType::kCUDA) { return; }\n    if (node->parallel_desc().parallel_num() != 1) { return; }\n    if (!node->parallel_desc().containing_current_rank()) { return; }\n    user_op::UserOpConfWrapper conv2d_op(op_conf);\n    if (conv2d_op.attr<std::string>(\"data_format\") != \"channels_last\") { return; }\n    if (conv2d_op.attr<int32_t>(\"groups\") != 1) { return; }\n    VLOG(3) << \"Tuning \" << op_conf.name();\n    const auto& in_desc = node->LogicalBlobDesc4Lbi(GenLogicalBlobId(conv2d_op.input(\"in\", 0)));\n    if (in_desc.data_type() != DataType::kFloat16) { return; }\n    const auto& weight_desc =\n        node->LogicalBlobDesc4Lbi(GenLogicalBlobId(conv2d_op.input(\"weight\", 0)));\n    const auto& out_desc = node->LogicalBlobDesc4Lbi(GenLogicalBlobId(conv2d_op.output(\"out\", 0)));\n\n    const auto& padding_before = conv2d_op.attr<std::vector<int32_t>>(\"padding_before\");\n    const 
auto& dilation_rate = conv2d_op.attr<std::vector<int32_t>>(\"dilation_rate\");\n    const auto& strides = conv2d_op.attr<std::vector<int32_t>>(\"strides\");\n\n    const int n = in_desc.shape().At(0);\n    const int h = in_desc.shape().At(1);\n    const int w = in_desc.shape().At(2);\n    const int c = in_desc.shape().At(3);\n\n    const int k = weight_desc.shape().At(0);\n    const int r = weight_desc.shape().At(1);\n    const int s = weight_desc.shape().At(2);\n    CHECK_EQ(weight_desc.shape().At(3), c);\n\n    const int p = out_desc.shape().At(1);\n    const int q = out_desc.shape().At(2);\n\n    cutlass::library::ConvFunctionalKey key(\n        cutlass::library::Provider::kCUTLASS, cutlass::library::ConvKind::kFprop,\n        cutlass::library::NumericTypeID::kF16, cutlass::library::LayoutTypeID::kTensorNHWC,\n        cutlass::library::NumericTypeID::kF16, cutlass::library::LayoutTypeID::kTensorNHWC,\n        cutlass::library::NumericTypeID::kF16, cutlass::library::LayoutTypeID::kTensorNHWC,\n        cutlass::library::NumericTypeID::kF32, cutlass::library::NumericTypeID::kF32);\n\n    const bool allow_half_accumulation =\n        ParseBooleanFromEnv(\"ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION\", false);\n\n    if (allow_half_accumulation) {\n      key.element_accumulator = cutlass::library::NumericTypeID::kF16;\n      key.element_compute = cutlass::library::NumericTypeID::kF16;\n    }\n\n    const size_t x_size = GetCudaAlignedSize(in_desc.ByteSizeOfBlobBody());\n    const size_t w_size = GetCudaAlignedSize(weight_desc.ByteSizeOfBlobBody());\n    const size_t y_size = GetCudaAlignedSize(out_desc.ByteSizeOfBlobBody());\n    size_t bias_size = 0;\n    if (conv2d_op.has_input(\"bias\", 0)) {\n      bias_size =\n          GetCudaAlignedSize(node->LogicalBlobDesc4Lbi(GenLogicalBlobId(conv2d_op.input(\"bias\", 0)))\n                                 .ByteSizeOfBlobBody());\n    }\n    const size_t total_buf_size = x_size + w_size + y_size + bias_size;\n    if (total_buf_size > buffer_size) {\n      size_t malloc_size = RoundUp(total_buf_size, kBufferMallocAlign);\n      OF_CUDA_CHECK(cudaFree(buffer));\n      OF_CUDA_CHECK(cudaMalloc(&buffer, malloc_size));\n      buffer_size = malloc_size;\n    }\n    void* x_ptr = buffer;\n    void* w_ptr = buffer + x_size;\n    void* y_ptr = buffer + x_size + w_size;\n    void* bias_ptr = nullptr;\n    if (bias_size != 0) { bias_ptr = buffer + x_size + w_size + y_size; }\n\n    cutlass::conv::Conv2dProblemSize problem_size(\n        n, h, w, c, k, r, s, p, q, padding_before.at(0), padding_before.at(1), strides.at(0),\n        strides.at(1), dilation_rate.at(0), dilation_rate.at(1),\n        cutlass::conv::Mode::kCrossCorrelation);\n    cutlass::library::Conv2dConfiguration configuraion;\n    configuraion.split_k_mode = cutlass::conv::SplitKMode::kSerial;\n    configuraion.problem_size = problem_size;\n    configuraion.stride_a = {c, w * c, h * w * c};\n    configuraion.stride_b = {c, s * c, r * s * c};\n    configuraion.stride_c = {0, 0, 0};\n    cutlass::library::ConvArguments arguments;\n    arguments.A = x_ptr;\n    arguments.B = w_ptr;\n    arguments.reordered_B = nullptr;\n    arguments.C = bias_ptr;\n    arguments.D = y_ptr;\n    union SP {\n      float f{};\n      half h;\n    };\n\n    SP alpha;\n    SP beta;\n\n    if (allow_half_accumulation) {\n      alpha.h = static_cast<half>(1.0F);\n      if (bias_ptr == nullptr) {\n        beta.h = static_cast<half>(0.0F);\n      } else {\n        beta.h = static_cast<half>(1.0F);\n      }\n    } else 
{\n      alpha.f = 1.0F;\n      if (bias_ptr == nullptr) {\n        beta.f = 0.0F;\n      } else {\n        beta.f = 1.0F;\n      }\n    }\n    arguments.alpha = &alpha;\n    arguments.beta = &beta;\n    arguments.pointer_mode = cutlass::library::ScalarPointerMode::kHost;\n\n    const cutlass::library::Operation* operation = CutlassConvTuner::Get().FindConv2dOperation(\n        stream->As<ep::CudaStream>(), key, configuration, arguments, workspace, kMaxWorkspaceSize);\n    if (operation != nullptr) {\n      VLOG(3) << \"Fastest operation: \" << operation->description().name;\n      nlohmann::json tuning_cache;\n      tuning_cache[\"cutlass\"] = operation->description().name;\n      OperatorConf new_op_conf = op_conf;\n      (*(*new_op_conf.mutable_user_conf()->mutable_attr())[\"tuning_cache\"].mutable_at_string()) =\n          tuning_cache.dump();\n      op_confs.push_back(new_op_conf);\n    }\n  });\n  job_builder.MutOpsOnlyOnce(op_confs);\n  OF_CUDA_CHECK(cudaFree(workspace));\n  OF_CUDA_CHECK(cudaFree(buffer));\n  device->DestroyStream(stream);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"CutlassConvTuningWarmupPass\", CutlassConvTuningWarmupPass);\n\n}  // namespace oneflow\n\n#endif  // WITH_CUTLASS\n"
  },
  {
    "path": "oneflow/core/job_rewriter/delay_variable_op_execution_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\n\nclass DelayVariableOpExecutionPass final : public JobPass {\n public:\n  DelayVariableOpExecutionPass() = default;\n  ~DelayVariableOpExecutionPass() override = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;\n};\n\nMaybe<void> DelayVariableOpExecutionPass::Apply(Job* job, JobPassCtx* ctx) const {\n  if (!ParseBooleanFromEnv(\"ONEFLOW_GRAPH_DELAY_VARIABLE_OP_EXECUTION\", false)) {\n    return Maybe<void>::Ok();\n  }\n  const JobConfigProto& job_conf = ctx->job_desc().job_conf();\n  if (job_conf.has_train_conf()) { return Maybe<void>::Ok(); }\n  if (job_conf.has_num_gradient_accumulation_steps()\n      && job_conf.num_gradient_accumulation_steps() > 1) {\n    return Maybe<void>::Ok();\n  }\n  if (GlobalProcessCtx::WorldSize() > 1) { return Maybe<void>::Ok(); }\n  const OpGraph op_graph(*job);\n  JobBuilder job_builder(job);\n  JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](const OpNode* node) -> Maybe<void> {\n    const OperatorConf& op_conf = node->op().op_conf();\n    if (!op_conf.has_variable_conf()) { return Maybe<void>::Ok(); }\n    if (!op_conf.ctrl_in_op_name().empty()) { return Maybe<void>::Ok(); }\n    if (op_conf.variable_conf().has_tick()) { return Maybe<void>::Ok(); }\n    if (node->out_edges().size() != 1) { return Maybe<void>::Ok(); }\n    if (node->parallel_desc().parallel_num() != 1) { return Maybe<void>::Ok(); }\n    const OpNode* dst_node = (*node->out_edges().begin())->dst_node();\n    if (dst_node->parallel_desc() != node->parallel_desc()) { return Maybe<void>::Ok(); }\n\n    const OpEdge* none_variable_edge = nullptr;\n    for (const OpEdge* edge : dst_node->in_edges()) {\n      if (edge->src_node()->op().op_conf().has_variable_conf()) { continue; }\n      if (edge->lbis().size() == 0) { continue; }\n      if (edge->src_node()->parallel_desc() != node->parallel_desc()) { continue; }\n      none_variable_edge = edge;\n      break;\n    }\n    if (none_variable_edge == nullptr) { return Maybe<void>::Ok(); }\n    OperatorConf new_varibale_conf = op_conf;\n    new_varibale_conf.mutable_variable_conf()->set_tick(\n        GenLogicalBlobName(none_variable_edge->lbis().front()));\n    job_builder.MutOpsOnlyOnce({new_varibale_conf});\n    return Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_JOB_PASS(\"DelayVariableOpExecutionPass\", DelayVariableOpExecutionPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/device_tick_autotick.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/autotick.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass MutDeviceTickOpConTickInputHelper final : public MutOpConTickInputHelper {\n public:\n  MutDeviceTickOpConTickInputHelper() : MutOpConTickInputHelper() {}\n\n  bool VirtualIsTickInputBound() const override {\n    return op_conf().device_tick_conf().tick_size() > 0;\n  }\n\n  OperatorConf NewTickInputBoundOpConf(const std::string& lbn) const override {\n    OperatorConf ret(op_conf());\n    ret.mutable_device_tick_conf()->add_tick(lbn);\n    return ret;\n  }\n};\n\n}  // namespace\n\nREGISTER_AUTO_TICK(OperatorConf::kDeviceTickConf, MutDeviceTickOpConTickInputHelper);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/do_parallel_cast_before_widening_type_cast_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job_rewriter/pass_util.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass DoParallelCastBeforeWideningTypeCast final : public JobPass {\n public:\n  DoParallelCastBeforeWideningTypeCast() = default;\n  ~DoParallelCastBeforeWideningTypeCast() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return ctx.job_desc().do_parallel_cast_before_widening_type_cast();\n  }\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    if (GlobalProcessCtx::WorldSize() == 1) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n};\n\nMaybe<void> DoParallelCastBeforeWideningTypeCast::Apply(const OpGraph& op_graph,\n                                                        JobBuilder* job_builder) const {\n  OpConfCache op_conf_cache;\n  op_graph.ForEachNode([&op_conf_cache](OpNode* parallel_cast_node) {\n    // find cast_fp16_to_fp32_or_double -> parallel_cast pattern\n    const OperatorConf& parallel_cast_op_conf =\n        op_conf_cache.GetLatest(parallel_cast_node->op().op_conf());\n    if (!(parallel_cast_op_conf.has_user_conf()\n          && (parallel_cast_op_conf.user_conf().op_type_name() == \"parallel_cast\"\n              || parallel_cast_op_conf.user_conf().op_type_name()\n                     == \"hierarchical_parallel_cast\"))) {\n      return;\n    }\n    auto* cast_node = parallel_cast_node->SoleInEdge()->src_node();\n    if (cast_node->out_edges().size() != 1) { return; }\n    auto cast_op_conf = op_conf_cache.GetLatest(cast_node->op().op_conf());\n    if (!(cast_op_conf.has_user_conf() && cast_op_conf.user_conf().op_type_name() == \"cast\")) {\n      return;\n    }\n    user_op::UserOpConfWrapper cast_conf_wrapper(cast_op_conf);\n    const auto cast_in_lbi = cast_node->SoleInEdge()->lbis().front();\n    const auto cast_in_dtype = cast_node->LogicalBlobDesc4Lbi(cast_in_lbi).data_type();\n    const auto cast_out_dtype = cast_conf_wrapper.attr<DataType>(\"dtype\");\n    if (!((cast_in_dtype == DataType::kFloat16 || cast_in_dtype == DataType::kBFloat16)\n          && (cast_out_dtype == DataType::kFloat || cast_out_dtype == DataType::kDouble))) {\n      return;\n    }\n\n    user_op::UserOpConfWrapper parallel_cast_conf_wrapper(parallel_cast_op_conf);\n    // replace parallel_cast op input with cast op input\n    {\n      OperatorConf new_parallel_cast_op_conf(parallel_cast_op_conf);\n      const auto& cast_input = cast_conf_wrapper.input(\"in\", 0);\n      const auto& parallel_cast_input = parallel_cast_conf_wrapper.input(\"in\", 0);\n     
 const auto& old_val =\n          ReplaceInputLbnInOpCustomizedConf(&new_parallel_cast_op_conf, \"in_0\", cast_input);\n      CHECK_EQ(parallel_cast_input, old_val);\n      op_conf_cache.Put(new_parallel_cast_op_conf);\n    }\n    // replace cast op input with parallel_cast op output\n    {\n      OperatorConf new_cast_op_conf(cast_op_conf);\n      const auto& parallel_cast_output = parallel_cast_conf_wrapper.output(\"out\", 0);\n      const auto& cast_input = cast_conf_wrapper.input(\"in\", 0);\n      const auto& old_val =\n          ReplaceInputLbnInOpCustomizedConf(&new_cast_op_conf, \"in_0\", parallel_cast_output);\n      CHECK_EQ(cast_input, old_val);\n      op_conf_cache.Put(new_cast_op_conf);\n    }\n\n    // update all parallel_cast op consumers\n    const std::string& cast_output = cast_conf_wrapper.output(\"out\", 0);\n    for (OpEdge* edge : parallel_cast_node->out_edges()) {\n      CHECK_EQ(1, edge->lbis().size());\n      LogicalBlobId cur_lbi = edge->lbis().front();\n      const auto lbn = GenLogicalBlobName(cur_lbi);\n      CHECK_EQ(1, edge->lbi2ibns().at(cur_lbi).size());\n      const std::string& dst_ibn = edge->lbi2ibns().at(cur_lbi).front();\n\n      OpNode* dst_node = edge->dst_node();\n      OperatorConf dst_op_conf = op_conf_cache.GetLatest(dst_node->op().op_conf());\n      CHECK_EQ(lbn, ReplaceInputLbnInOpCustomizedConf(&dst_op_conf, dst_ibn, cast_output));\n      op_conf_cache.Put(dst_op_conf);\n    }\n  });\n  job_builder->MutOpsOnlyOnce(op_conf_cache.op_confs());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"DoParallelCastBeforeWideningTypeCast\", DoParallelCastBeforeWideningTypeCast);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/dump_blob_parallel_conf_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job/job.pb.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass DumpBlobParallelConfPass final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DumpBlobParallelConfPass);\n  DumpBlobParallelConfPass() = default;\n  ~DumpBlobParallelConfPass() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const { return true; }\n\n  Maybe<void> Apply(const OpGraph& op_graph, Job* job) const {\n    op_graph.DumpLogicalBlobDesc(job);\n    op_graph.DumpArgSignature(job);\n    op_graph.DumpNdSbpSignature(job);\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    const OpGraph op_graph(*job);\n    return Apply(op_graph, job);\n  }\n};\n\nREGISTER_JOB_PASS(\"DumpBlobParallelConfPass\", DumpBlobParallelConfPass);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/dump_variable_info_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::string GetNdSbpString(const VariableOpConf& conf, const ParallelDesc& parallel_desc) {\n  const bool has_nd_sbp_conf = (conf.nd_sbp_size() != 0);\n  const int64_t num_axes = parallel_desc.hierarchy()->NumAxes();\n  if (has_nd_sbp_conf) { CHECK_EQ(conf.nd_sbp_size(), num_axes); }\n  std::string nd_sbp_str;\n  FOR_RANGE(int64_t, i, 0, num_axes) {\n    if (has_nd_sbp_conf) {\n      nd_sbp_str += conf.nd_sbp(i);\n    } else {\n      nd_sbp_str += \"B\";\n    }\n    if (i != num_axes - 1) { nd_sbp_str += \", \"; }\n  }\n  return nd_sbp_str;\n}\n\nclass DumpVariableInfoPass final : public JobPass {\n public:\n  DumpVariableInfoPass() = default;\n  ~DumpVariableInfoPass() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode();\n  }\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n};\n\nMaybe<void> DumpVariableInfoPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {\n  int64_t cnt = 0;\n  const std::string sep = \"\\t\";\n  auto log_stream =\n      TeePersistentLogStream::Create(\"variable_table_\" + std::to_string(GlobalJobDesc().job_id()));\n  (*log_stream) << \"id\" << sep << \"name\" << sep << \"device_tag\" << sep << \"parallel_hierarchy\"\n                << sep << \"distribute\" << sep << \"data_type\" << sep << \"shape\" << sep << \"elem_cnt\"\n                << sep << \"size\"\n                << \"\\n\";\n  JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](const OpNode* node) -> Maybe<void> {\n    const OperatorConf& op_conf = node->op().op_conf();\n    if (!op_conf.has_variable_conf()) { return Maybe<void>::Ok(); }\n    const VariableOpConf& conf = op_conf.variable_conf();\n    (*log_stream) << std::to_string(cnt);\n    (*log_stream) << sep;\n    (*log_stream) << op_conf.name();\n    (*log_stream) << sep;\n    (*log_stream) << op_conf.device_tag();\n    (*log_stream) << sep;\n    (*log_stream) << node->parallel_desc().hierarchy()->DebugStr();\n    (*log_stream) << sep;\n    (*log_stream) << GetNdSbpString(conf, node->parallel_desc());\n    (*log_stream) << sep;\n    (*log_stream) << DataType_Name(conf.data_type());\n    (*log_stream) << sep;\n    const Shape shape(conf.shape());\n    (*log_stream) << shape.ToString();\n    (*log_stream) << sep;\n    (*log_stream) << std::to_string(shape.elem_cnt());\n    (*log_stream) << sep;\n    (*log_stream) << std::to_string(shape.elem_cnt() * 
GetSizeOfDataType(conf.data_type()));\n    (*log_stream) << \"\\n\";\n    cnt += 1;\n    return Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"DumpVariableInfoPass\", DumpVariableInfoPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/dynamic_loss_scale_job_pass_state.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_REWRITER_DYNAMIC_LOSS_SCALE_JOB_PASS_STATE_H_\n#define ONEFLOW_CORE_JOB_REWRITER_DYNAMIC_LOSS_SCALE_JOB_PASS_STATE_H_\n\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n\nnamespace oneflow {\n\nclass DynamicLossScaleJobPassState : public JobPassState {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DynamicLossScaleJobPassState);\n  DynamicLossScaleJobPassState() = default;\n  ~DynamicLossScaleJobPassState() override = default;\n\n  const std::string& count_not_finite_lbn() const { return count_not_finite_lbn_; }\n  void set_count_not_finite_lbn(const std::string& lbn) { count_not_finite_lbn_ = lbn; }\n\n  const std::string& loss_scale_val_lbn() const { return loss_scale_val_lbn_; }\n  void set_loss_scale_val_lbn(const std::string& lbn) { loss_scale_val_lbn_ = lbn; }\n\n private:\n  std::string count_not_finite_lbn_;\n  std::string loss_scale_val_lbn_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_REWRITER_DYNAMIC_LOSS_SCALE_JOB_PASS_STATE_H_\n"
  },
  {
    "path": "oneflow/core/job_rewriter/dynamic_loss_scale_schedule_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job_rewriter/dynamic_loss_scale_job_pass_state.h\"\n#include \"oneflow/core/framework/scope_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass DynamicLossScaleSchedulePass final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DynamicLossScaleSchedulePass);\n  DynamicLossScaleSchedulePass() = default;\n  ~DynamicLossScaleSchedulePass() override = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;\n};\n\nMaybe<void> DynamicLossScaleSchedulePass::Apply(Job* job, JobPassCtx* ctx) const {\n  if (!ctx->job_desc().IsTrain()) { return Maybe<void>::Ok(); }\n  const TrainConf& train_conf = job->job_conf().train_conf();\n  if (!train_conf.has_dynamic_loss_scale_policy()) { return Maybe<void>::Ok(); }\n  const auto& policy = train_conf.dynamic_loss_scale_policy();\n  const OpGraph op_graph(*job);\n  JobBuilder job_builder(job);\n  const ParallelConf& parallel_conf = GenParallelConfOfCpuZeroOnMaster();\n  int64_t scope_symbol_id;\n  {\n    const auto& opt_scope_symbol_id =\n        JUST(MakeInitialScope(job->job_conf(), SymbolOf(ParallelDesc(parallel_conf)),\n                              /* is_local */ false))\n            ->symbol_id();\n    CHECK_OR_RETURN(opt_scope_symbol_id.has_value())\n        << Error::RuntimeError() << \"symbol_id not initialized\";\n    scope_symbol_id = JUST(opt_scope_symbol_id);\n  }\n  OperatorConf loss_scale_var_op_conf{};\n  const std::string op_name_prefix = \"System-Train-DynamicLossScale-\";\n  {\n    loss_scale_var_op_conf.set_name(op_name_prefix + job->job_conf().job_name() + \"-LossScale\");\n    VariableOpConf* variable_conf = loss_scale_var_op_conf.mutable_variable_conf();\n    variable_conf->set_out(\"out\");\n    *variable_conf->mutable_shape()->mutable_dim()->Add() = 1;\n    variable_conf->set_data_type(DataType::kFloat);\n    variable_conf->mutable_initializer()->mutable_constant_conf()->set_value(\n        policy.initial_loss_scale());\n    loss_scale_var_op_conf.set_scope_symbol_id(scope_symbol_id);\n  }\n  OperatorConf good_step_counter_var_conf{};\n  {\n    good_step_counter_var_conf.set_name(op_name_prefix + job->job_conf().job_name()\n                                        + \"-GoodStepCounter\");\n    VariableOpConf* variable_conf = good_step_counter_var_conf.mutable_variable_conf();\n    variable_conf->set_out(\"out\");\n    *variable_conf->mutable_shape()->mutable_dim()->Add() = 1;\n    variable_conf->set_data_type(DataType::kInt64);\n    variable_conf->mutable_initializer()->mutable_constant_int_conf()->set_value(0);\n    good_step_counter_var_conf.set_scope_symbol_id(scope_symbol_id);\n  }\n  OperatorConf loss_scale_val_op_conf{};\n  const std::string loss_scale_var_lbn = GenLogicalBlobName(\n      loss_scale_var_op_conf.name(), 
loss_scale_var_op_conf.variable_conf().out());\n  {\n    loss_scale_val_op_conf.set_name(loss_scale_var_op_conf.name() + \"-Identity\");\n    loss_scale_val_op_conf.set_scope_symbol_id(scope_symbol_id);\n    IdentityOpConf* identity_conf = loss_scale_val_op_conf.mutable_identity_conf();\n    identity_conf->set_in(loss_scale_var_lbn);\n    identity_conf->set_out(\"out\");\n  }\n  // stub constant; a later pass replaces it with the real count of non-finite values\n  auto count_not_finite_stub_op =\n      user_op::UserOpConfWrapperBuilder(op_name_prefix + job->job_conf().job_name()\n                                        + \"-CountNotFinite\")\n          .Op(\"constant\")\n          .Output(\"out\")\n          .Attr<double>(\"floating_value\", 0.0)\n          .Attr<int64_t>(\"integer_value\", 0)\n          .Attr<bool>(\"is_floating_value\", false)\n          .Attr<DataType>(\"dtype\", DataType::kInt64)\n          .Attr<Shape>(\"shape\", Shape({1}))\n          .ScopeSymbolId(scope_symbol_id)\n          .Build();\n  const std::string loss_scale_val_lbn = GenLogicalBlobName(\n      loss_scale_val_op_conf.name(), loss_scale_val_op_conf.identity_conf().out());\n  const std::string good_step_counter_var_lbn = GenLogicalBlobName(\n      good_step_counter_var_conf.name(), good_step_counter_var_conf.variable_conf().out());\n  auto schedule =\n      user_op::UserOpConfWrapperBuilder(op_name_prefix + job->job_conf().job_name() + \"-Schedule\")\n          .Op(\"dynamic_loss_scale_schedule\")\n          .Input(\"count_not_finite\", count_not_finite_stub_op.output(\"out\", 0))\n          .Input(\"loss_scale\", loss_scale_var_lbn)\n          .Input(\"good_step_counter\", good_step_counter_var_lbn)\n          .Attr<int64_t>(\"increment_period\", policy.increment_period())\n          .Attr<float>(\"multiplier\", policy.multiplier())\n          .ScopeSymbolId(scope_symbol_id)\n          .Build();\n  job_builder.AddOps(parallel_conf,\n                     {loss_scale_var_op_conf, loss_scale_val_op_conf, good_step_counter_var_conf,\n                      count_not_finite_stub_op.op_conf(), schedule.op_conf()});\n  if (!JUST(ctx->HasState<DynamicLossScaleJobPassState>(\"dynamic_loss_scale_state\"))) {\n    JUST(ctx->ResetState(\"dynamic_loss_scale_state\",\n                         std::make_unique<DynamicLossScaleJobPassState>()));\n  }\n  auto state = JUST(ctx->MutableState<DynamicLossScaleJobPassState>(\"dynamic_loss_scale_state\"));\n  state->set_loss_scale_val_lbn(loss_scale_val_lbn);\n  state->set_count_not_finite_lbn(count_not_finite_stub_op.output(\"out\", 0));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_JOB_PASS(\"DynamicLossScaleSchedulePass\", DynamicLossScaleSchedulePass);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/eliminate_dead_nodes_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass EliminateDeadNodesPass final : public JobPass {\n public:\n  EliminateDeadNodesPass() = default;\n  ~EliminateDeadNodesPass() override = default;\n\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n};\n\nstatic bool IsNoSideEffect(const OpNode* op_node) {\n  static HashSet<std::string> no_side_effect_ops = {\n      \"constant\", \"zeros_like\", \"ones_like\", \"repeat\", \"acc\", \"pack\", \"unpack\",\n  };\n  static HashSet<OperatorConf::OpTypeCase> no_side_effect_system_ops = {\n      OperatorConf::kDeviceTickConf,\n  };\n  const auto& op_conf = op_node->op().op_conf();\n  if (!op_conf.has_user_conf()) { return no_side_effect_system_ops.count(op_conf.op_type_case()); }\n  return no_side_effect_ops.count(op_conf.user_conf().op_type_name());\n}\n\nMaybe<void> EliminateDeadNodesPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {\n  HashSet<const OpNode*> delete_ops;\n  std::vector<OperatorConf> delete_op_confs;\n  op_graph.ReverseTopoForEachNode([&](const OpNode* op_node) {\n    if (!IsNoSideEffect(op_node)) { return; }\n    for (const auto* out_edge : op_node->out_edges()) {\n      if (!delete_ops.count(out_edge->dst_node())) { return; }\n    }\n    VLOG(3) << \"Eliminate dead node: \" << op_node->op().op_name();\n    delete_ops.insert(op_node);\n    delete_op_confs.emplace_back(op_node->op().op_conf());\n  });\n\n  job_builder->DelOps(delete_op_confs);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"EliminateDeadNodesPass\", EliminateDeadNodesPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/fix_pipeline_stage_id_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/job_rewriter/calculation_pass.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass FixPipelineStageIdPass final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(FixPipelineStageIdPass);\n  FixPipelineStageIdPass() = default;\n  ~FixPipelineStageIdPass() = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return ctx.job_desc().IsTrain()\n           && ctx.job_desc().job_conf().num_gradient_accumulation_steps() > 1;\n  }\n\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n};\n\nconst Scope& Scope4ScopeSymbolId(int64_t scope_symbol_id) {\n  CHECK(Singleton<symbol::Storage<Scope>>::Get()->Has(scope_symbol_id));\n  return Singleton<symbol::Storage<Scope>>::Get()->Get(scope_symbol_id);\n}\n\nconst Scope& Scope4OpNode(const OpNode* op_node) {\n  const OperatorConf& op_conf = op_node->op().op_conf();\n  CHECK(op_conf.has_scope_symbol_id());\n  return Scope4ScopeSymbolId(op_conf.scope_symbol_id());\n}\n\nbool OpNodeHasScope(const OpNode* node) { return node->op().op_conf().has_scope_symbol_id(); }\n\nint64_t GetStageIdHint(const OpNode* node) {\n  return Scope4OpNode(node).Int64(\"pipeline_stage_id_hint\");\n}\n\nstd::string ParallelDesc2HashString(const ParallelDesc& parallel_desc) {\n  std::string ret = parallel_desc.device_tag() + \",{\";\n  for (int64_t m : parallel_desc.sorted_machine_ids()) {\n    ret += (std::to_string(m) + \":[\");\n    for (int64_t d : parallel_desc.sorted_dev_phy_ids(m)) { ret += (std::to_string(d) + \",\"); }\n    ret += \"],\";\n  }\n  ret += \"}\";\n  return ret;\n}\n\nMaybe<int64_t> NewScopeWithStageId(int64_t old_scope_symbol_id, int64_t stage_id) {\n  return NewScopeSymbolId(\n      old_scope_symbol_id,\n      [stage_id](\n          std::shared_ptr<ScopeProto> new_scope) {  // NOLINT(performance-unnecessary-value-param)\n        auto* attr_map = new_scope->mutable_attr_name2attr_value();\n        (*attr_map)[\"pipeline_stage_id_hint\"].set_at_int64(stage_id);\n      });\n}\n\nMaybe<void> FixPipelineStageIdPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {\n  int64_t max_stage_id = 0;\n  op_graph.ForEachNode([&](const OpNode* this_node) {\n    if (!OpNodeHasScope(this_node)) {\n      LOG(WARNING) << \" op : \" << 
this_node->op().op_conf().DebugString() << \" has no scope!\";\n      return;\n    }\n    max_stage_id = std::max(max_stage_id, GetStageIdHint(this_node));\n  });\n\n  if (max_stage_id == 0) { return Maybe<void>::Ok(); }\n  const int64_t total_stage_num = max_stage_id + 1;\n  VLOG(3) << \"total stage num = \" << total_stage_num;\n\n  HashMap<std::string, const OpNode*> op_name2node;\n  HashMap<std::string, std::vector<const OpNode*>> placement2op_nodes;\n  std::vector<OperatorConf> fix_stage_op_confs;\n\n  // NOTE(chengcheng): group op by placement.\n  op_graph.ForEachNode([&](const OpNode* this_node) {\n    if (!OpNodeHasScope(this_node)) { return; }\n    const std::string& op_name = this_node->op().op_name();\n    op_name2node.emplace(op_name, this_node);\n    std::string placement = ParallelDesc2HashString(this_node->parallel_desc());\n    placement2op_nodes[placement].emplace_back(this_node);\n  });\n\n  for (auto& pair : placement2op_nodes) {\n    int64_t max_stage_id = -1;\n    for (const OpNode* this_node : pair.second) {\n      max_stage_id = std::max(max_stage_id, GetStageIdHint(this_node));\n    }\n    CHECK_GE_OR_RETURN(max_stage_id, 0);\n    for (const OpNode* this_node : pair.second) {\n      int64_t this_stage_id = GetStageIdHint(this_node);\n      if (this_stage_id != max_stage_id) {\n        VLOG(3) << \" In FixPipelineStageIdPass, op_name: \" << this_node->op().op_name()\n                << \" origin_stage_id = \" << this_stage_id\n                << \" differs from the max_stage_id: \" << max_stage_id\n                << \" of ops with the same placement: \" << pair.first\n                << \", so this op is changed to the max stage id.\\n\";\n        OperatorConf new_op_conf = this_node->op().op_conf();\n        int64_t new_scope_symbol_id =\n            JUST(NewScopeWithStageId(new_op_conf.scope_symbol_id(), max_stage_id));\n        new_op_conf.set_scope_symbol_id(new_scope_symbol_id);\n        fix_stage_op_confs.emplace_back(std::move(new_op_conf));\n      }\n    }\n  }\n\n  for (const auto& op : fix_stage_op_confs) { JUST(job_builder->MutOpOnlyOnce(op)); }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"FixPipelineStageIdPass\", FixPipelineStageIdPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/ftrl_optm.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/job/initializer_conf.pb.h\"\n#include \"oneflow/core/job/job_builder.h\"\n#include \"oneflow/core/job/job_conf.pb.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job_rewriter/optimizer.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/operator/variable_op.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::string GenVariableOutputLbn(const OperatorConf& op_conf) {\n  CHECK(op_conf.has_variable_conf());\n  return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out());\n}\n\nOperatorConf GenerateFtrlHelperVariableConf(const VariableOp& op, const std::string& name,\n                                            const float initial_value) {\n  OperatorConf helper_variable_op(op.op_conf());\n  helper_variable_op.set_name(op.op_name() + \"-\" + name);\n  helper_variable_op.mutable_variable_conf()->set_out(\"out\");\n  InitializerConf constant_initializer;\n  constant_initializer.mutable_constant_conf()->set_value(initial_value);\n  *(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer;\n  helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id());\n  return helper_variable_op;\n}\n\nvoid GenerateFtrlOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node,\n                                 const std::string& model_diff_lbn,\n                                 const OptimizerConf& optimizer_conf, JobBuilder* job_builder) {\n  const VariableOp* var_op = dynamic_cast<const VariableOp*>(&var_op_node.op());\n  CHECK_NOTNULL(var_op);\n\n  user_op::UserOpConfWrapperBuilder ftrl_update_op_builder(var_op->op_name() + \"_optimizer\");\n  float lr_power = 0.0;\n  float initial_accumulator_value = 0.0;\n  float lambda1 = 0.0;\n  float lambda2 = 0.0;\n  float beta = 0.0;\n\n  const FtrlModelUpdateConf& ftrl_conf = optimizer_conf.ftrl_conf();\n  lr_power = ftrl_conf.lr_power();\n  initial_accumulator_value = ftrl_conf.initial_accumulator_value();\n  lambda1 = ftrl_conf.lambda1();\n  lambda2 = ftrl_conf.lambda2();\n  beta = ftrl_conf.beta();\n\n  const std::string& learning_rate_lbn = optimizer_conf.learning_rate_lbn();\n  OperatorConf accumulator_var(\n      GenerateFtrlHelperVariableConf(*var_op, \"accumulate\", initial_accumulator_value));\n  OperatorConf z_var(GenerateFtrlHelperVariableConf(*var_op, \"z\", 0.0));\n  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {accumulator_var, z_var});\n\n  ftrl_update_op_builder.OpTypeName(\"ftrl_update\")\n      .Input(\"model\", GenLogicalBlobName(var_op->BnInOp2Lbi(\"out\")))\n      .Input(\"model_diff\", model_diff_lbn)\n      .Input(\"learning_rate\", learning_rate_lbn)\n      .Input(\"accumulate\", GenVariableOutputLbn(accumulator_var))\n      
.Input(\"z\", GenVariableOutputLbn(z_var))\n      .Attr<float>(\"lr_power\", lr_power)\n      .Attr<float>(\"lambda1\", lambda1)\n      .Attr<float>(\"lambda2\", lambda2)\n      .Attr<float>(\"beta\", beta)\n      .Attr<float>(\"weight_decay\", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))\n      .ScopeSymbolId(var_op->op_conf().scope_symbol_id());\n  if (optimizer_conf.has_lr_scale()) {\n    ftrl_update_op_builder.Attr<float>(\"learning_rate_scale\", optimizer_conf.lr_scale());\n  }\n  SetDynamicLossScaleSkipIf(ctx, &ftrl_update_op_builder);\n  const auto ftrl_update_op = ftrl_update_op_builder.Build();\n  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {ftrl_update_op.op_conf()});\n}\n\n}  // namespace\n\nREGISTER_OPTIMIZER(OptimizerConf::kFtrlConf, &GenerateFtrlOptimizerOpConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/hash_container.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass FuseAddToOutputPass final : public JobPass {\n public:\n  FuseAddToOutputPass() = default;\n  ~FuseAddToOutputPass() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return ctx.job_desc().job_conf().enable_fuse_add_to_output();\n  }\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n};\n\nMaybe<void> FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {\n  const HashMap<std::string, user_op::OpArg> supported_op_type_name2output_arg(\n      {{\"normalization\", user_op::OpArg(\"y\", 0)},\n       {\"dropout\", user_op::OpArg(\"out\", 0)},\n       {\"matmul\", user_op::OpArg(\"out\", 0)},\n       {\"layer_norm_grad\", user_op::OpArg(\"dx\", 0)},\n       {\"batch_matmul\", user_op::OpArg(\"out\", 0)},\n       {\"fused_bias_add_mask_scale\", user_op::OpArg(\"out\", 0)},\n       {\"fused_matmul_bias\", user_op::OpArg(\"out\", 0)},\n       {\"broadcast_matmul\", user_op::OpArg(\"out\", 0)},\n       {\"broadcast_matmul_grad_b\", user_op::OpArg(\"out\", 0)}});\n  HashSet<std::string> consumer_op_names;\n  auto IsAddToOutputSupported = [&](const OpNode* node, const LogicalBlobId& lbi) -> bool {\n    const OperatorConf& op_conf = node->op().op_conf();\n    if (!op_conf.has_user_conf()) { return false; }\n    if (consumer_op_names.count(op_conf.name()) > 0) { return false; }\n    auto it = supported_op_type_name2output_arg.find(op_conf.user_conf().op_type_name());\n    if (it == supported_op_type_name2output_arg.end()) { return false; }\n    const user_op::UserOpConfWrapper user_op_conf(op_conf);\n    if (GenLogicalBlobId(user_op_conf.output(it->second.name(), it->second.index())) != lbi) {\n      return false;\n    }\n    // add op should be the only consumer\n    int64_t output_consumer_cnt = 0;\n    for (const OpEdge* out_edge : node->out_edges()) {\n      if (std::find(out_edge->lbis().cbegin(), out_edge->lbis().cend(), lbi)\n          != out_edge->lbis().cend()) {\n        output_consumer_cnt += 1;\n      }\n    }\n    if (output_consumer_cnt != 1) { return false; }\n    // already fused\n    if (user_op_conf.has_input(\"_add_to_output\", 0)) { return false; }\n    return true;\n  };\n\n  // Save all op's ctrl in op name in a set.\n  HashSet<std::string> ctrl_in_op_names;\n  op_graph.ForEachNode([&](const OpNode* op_node) {\n    for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) {\n      
ctrl_in_op_names.insert(ctrl_in_op_name);\n    }\n  });\n\n  auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable();\n  std::vector<OperatorConf> delete_ops;\n  HashSet<std::string> be_fused_op_names;\n  JUST(op_graph.MaybeForEachNode([&](const OpNode* op_node) -> Maybe<void> {\n    const OperatorConf& op_conf = op_node->op().op_conf();\n    if (!op_conf.has_user_conf()) { return Maybe<void>::Ok(); }\n    if (!op_conf.ctrl_in_op_name().empty()) { return Maybe<void>::Ok(); }\n    if (ctrl_in_op_names.find(op_conf.name()) != ctrl_in_op_names.end()) {\n      return Maybe<void>::Ok();\n    }\n    if (op_conf.user_conf().op_type_name() != \"add_n\") { return Maybe<void>::Ok(); }\n    if (be_fused_op_names.count(op_conf.name()) > 0) { return Maybe<void>::Ok(); }\n    if (consumer_op_names.count(op_conf.name()) > 0) { return Maybe<void>::Ok(); }\n    const user_op::UserOpConfWrapper user_op_conf(op_conf);\n    if (user_op_conf.input_size(\"in\") != 2) { return Maybe<void>::Ok(); }\n\n    const LogicalBlobId in_0 = GenLogicalBlobId(user_op_conf.input(\"in\", 0));\n    const LogicalBlobId in_1 = GenLogicalBlobId(user_op_conf.input(\"in\", 1));\n    const LogicalBlobId out = GenLogicalBlobId(user_op_conf.output(\"out\", 0));\n    const OpNode* in_0_node = op_graph.OpNode4OpName(in_0.op_name());\n    const OpNode* in_1_node = op_graph.OpNode4OpName(in_1.op_name());\n\n    const OpNode* add_to_node;\n    const LogicalBlobId* add_to_lbi;\n    const LogicalBlobId* sum_lbi;\n    if ((!IsReachable(in_0.op_name(), in_1.op_name())) && IsAddToOutputSupported(in_0_node, in_0)) {\n      add_to_node = in_0_node;\n      add_to_lbi = &in_1;\n      sum_lbi = &in_0;\n      be_fused_op_names.insert(in_1.op_name());\n    } else if ((!IsReachable(in_1.op_name(), in_0.op_name()))\n               && IsAddToOutputSupported(in_1_node, in_1)) {\n      add_to_node = in_1_node;\n      add_to_lbi = &in_0;\n      sum_lbi = &in_1;\n      be_fused_op_names.insert(in_0.op_name());\n    } else {\n      return Maybe<void>::Ok();\n    }\n    // Build a new_add_to_op that fuses this add_n into the producer op.\n    if (JUST(job_builder->IsInMutOpTransaction(add_to_node->op().op_name()))) {\n      OperatorConf& new_add_to_op_conf =\n          JUST(job_builder->MutOpTransactionGet(add_to_node->op().op_name()));\n      *(*(new_add_to_op_conf.mutable_user_conf()->mutable_input()))[\"_add_to_output\"]\n           .mutable_s()\n           ->Add() = GenLogicalBlobName(*add_to_lbi);\n    } else {\n      OperatorConf new_add_to_op_conf = add_to_node->op().op_conf();\n      *(*(new_add_to_op_conf.mutable_user_conf()->mutable_input()))[\"_add_to_output\"]\n           .mutable_s()\n           ->Add() = GenLogicalBlobName(*add_to_lbi);\n      JUST(job_builder->MutOpTransactionMut(new_add_to_op_conf));\n    }\n    for (const OpEdge* out_edge : op_node->out_edges()) {\n      const OpNode* consumer = out_edge->dst_node();\n      const std::string& consumer_op_name = consumer->op().op_name();\n      if (consumer_op_names.count(consumer_op_name) == 0) {\n        if (!JUST(job_builder->IsInMutOpTransaction(consumer->op().op_name()))) {\n          consumer_op_names.insert(consumer_op_name);\n          JUST(job_builder->MutOpTransactionMut(consumer->op().op_conf()));\n        }\n      }\n      // Redirect the add_n op's consumers to consume the new_add_to_op's output\n      for (const std::string& ibn : consumer->op().input_bns()) {\n        if (consumer->op().BnInOp2Lbi(ibn) == out) {\n          OperatorConf& consumer_op_conf = 
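/* the conf staged in the mut-op transaction */\n          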
JUST(job_builder->MutOpTransactionGet(consumer_op_name));\n          const auto& new_val = GenLogicalBlobName(*sum_lbi);\n          const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, new_val);\n          CHECK_EQ(GenLogicalBlobName(out), old_val);\n        }\n      }\n    }\n    // Add the add_n op to the removal list\n    delete_ops.emplace_back(op_conf);\n    return Maybe<void>::Ok();\n  }));\n  JUST(job_builder->MutOpTransactionCommit());\n  job_builder->DelOps(delete_ops);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"FuseAddToOutputPass\", FuseAddToOutputPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/fuse_bce_reduce_mean_fw_bw_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nvoid UpdateConsumerOpConf(const OpNode* consumer, const LogicalBlobId& out,\n                          const std::string& new_out_lbn,\n                          HashMap<std::string, OperatorConf>* op_name2op_conf) {\n  const std::string& consumer_op_name = consumer->op().op_name();\n  if (op_name2op_conf->find(consumer_op_name) == op_name2op_conf->end()) {\n    (*op_name2op_conf)[consumer_op_name] = consumer->op().op_conf();\n  }\n  for (const std::string& ibn : consumer->op().input_bns()) {\n    if (consumer->op().BnInOp2Lbi(ibn) == out) {\n      OperatorConf& consumer_op_conf = op_name2op_conf->at(consumer_op_name);\n      const auto& new_val = new_out_lbn;\n      const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, new_val);\n      CHECK_EQ(GenLogicalBlobName(out), old_val);\n    }\n  }\n}\n\nclass FuseBCEReduceMeanFwBwPass final : public JobPass {\n public:\n  FuseBCEReduceMeanFwBwPass() = default;\n  ~FuseBCEReduceMeanFwBwPass() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return ParseBooleanFromEnv(\"ONEFLOW_FUSE_BCE_REDUCE_MEAN_FW_BW\", false);\n  }\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n};\n\nMaybe<void> FuseBCEReduceMeanFwBwPass::Apply(const OpGraph& op_graph,\n                                             JobBuilder* job_builder) const {\n  // This pass fuse binary_cross_entropy_with_logits_reduce_mean and\n  // binary_cross_entropy_with_logits_reduce_mean_grad. 
delete the h2f cast to loss, and the\n  // constant_like of dy.\n  const auto IsSafeToDelete = MakePredicatorIsSafeToDelete(op_graph);\n  HashMap<std::string, OperatorConf> op_name2op_conf;\n  std::vector<OperatorConf> delete_ops;\n  op_graph.ForEachNode([&](const OpNode* op_node) {\n    if (!IsUserOpWithTypeName(op_node->op().op_conf(),\n                              \"binary_cross_entropy_with_logits_reduce_mean\")) {\n      return;\n    }\n    if (op_node->out_edges().size() > 2) { return; }\n    bool find_grad_op = false;\n    for (const OpEdge* out_edge : op_node->out_edges()) {\n      const OpNode* consumer = out_edge->dst_node();\n      if (!IsSafeToDelete(consumer)) { return; }\n      if (!(IsUserOpWithTypeName(consumer->op().op_conf(), \"cast\")\n            || consumer->op().op_conf().has_constant_like_conf()\n            || consumer->op().op_conf().has_output_conf())) {\n        return;\n      }\n      if (consumer->op().op_conf().has_constant_like_conf()) {\n        const OpNode* grad_node = consumer->SoleOutEdge()->dst_node();\n        if (!IsUserOpWithTypeName(grad_node->op().op_conf(),\n                                  \"binary_cross_entropy_with_logits_reduce_mean_grad\")) {\n          return;\n        }\n        find_grad_op = true;\n        if (!IsSafeToDelete(grad_node)) { return; }\n      }\n    }\n    if (!find_grad_op) { return; }\n    const user_op::UserOpConfWrapper bce_op_conf(op_node->op().op_conf());\n    user_op::UserOpConfWrapperBuilder fused_op_builder(bce_op_conf.op_name());\n    fused_op_builder.OpTypeName(\"fused_bce_reduce_mean_fw_bw\")\n        .Input(\"input\", bce_op_conf.input(\"input\", 0))\n        .Input(\"target\", bce_op_conf.input(\"target\", 0))\n        .Output(\"out\")\n        .Output(\"dx\");\n    for (const OpEdge* out_edge : op_node->out_edges()) {\n      const OpNode* consumer = out_edge->dst_node();\n      if (IsUserOpWithTypeName(consumer->op().op_conf(), \"cast\")) {\n        const user_op::UserOpConfWrapper cast_conf(consumer->op().op_conf());\n        fused_op_builder.Attr<DataType>(\"out_dtype\", cast_conf.attr<DataType>(\"dtype\"));\n        // delete cast and update cast consumer's in.\n        delete_ops.push_back(consumer->op().op_conf());\n        for (const OpEdge* cast_out_edge : consumer->out_edges()) {\n          const OpNode* cast_consumer = cast_out_edge->dst_node();\n          UpdateConsumerOpConf(cast_consumer, GenLogicalBlobId(cast_conf.output(\"out\", 0)),\n                               GenLogicalBlobName(bce_op_conf.op_name(), \"out_0\"),\n                               &op_name2op_conf);\n        }\n      } else if (consumer->op().op_conf().has_constant_like_conf()) {\n        fused_op_builder.Attr<double>(\n            \"constant_value\", consumer->op().op_conf().constant_like_conf().float_operand());\n        const OpNode* grad_node = consumer->SoleOutEdge()->dst_node();\n        // delete constant_like and grad op, update consumer\n        delete_ops.push_back(grad_node->op().op_conf());\n        delete_ops.push_back(consumer->op().op_conf());\n        const user_op::UserOpConfWrapper grad_conf(grad_node->op().op_conf());\n        for (const OpEdge* grad_out_edge : grad_node->out_edges()) {\n          const OpNode* grad_consumer = grad_out_edge->dst_node();\n          UpdateConsumerOpConf(grad_consumer, GenLogicalBlobId(grad_conf.output(\"dx\", 0)),\n                               GenLogicalBlobName(bce_op_conf.op_name(), \"dx_0\"), &op_name2op_conf);\n        }\n      } else {\n        continue;\n      }\n    
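  // Note: output ops (output_conf consumers) keep consuming out_0, whose lbn the fused op preserves.\n    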
}\n    user_op::UserOpConfWrapper fused_op =\n        fused_op_builder.ScopeSymbolId(bce_op_conf.op_conf().scope_symbol_id()).Build();\n    job_builder->MutOpsOnlyOnce({fused_op.op_conf()});\n  });\n  job_builder->DelOps(delete_ops);\n  for (const auto& pair : op_name2op_conf) { job_builder->MutOpsOnlyOnce({pair.second}); }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"FuseBCEReduceMeanFwBwPass\", FuseBCEReduceMeanFwBwPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/fuse_cast_scale_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass FuseCastScalePass final : public JobPass {\n public:\n  FuseCastScalePass() = default;\n  ~FuseCastScalePass() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return ctx.job_desc().job_conf().enable_fuse_cast_scale();\n  }\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n};\n\nMaybe<void> FuseCastScalePass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {\n  const auto IsSafeToDelete = MakePredicatorIsSafeToDelete(op_graph);\n  std::vector<OperatorConf> delete_ops;\n  op_graph.ForEachNode([&](const OpNode* op_node) {\n    if (!IsUserOpWithTypeName(op_node->op().op_conf(), \"cast\")) { return; }\n    if (!IsSafeToDelete(op_node)) { return; }\n    if (op_node->out_edges().size() != 1) { return; }\n    OpNode* sole_dst_node = op_node->SoleOutEdge()->dst_node();\n    if (IsUserOpWithTypeName(sole_dst_node->op().op_conf(), \"scalar_mul\")) {\n      if (!IsSafeToDelete(sole_dst_node)) { return; }\n      if (!IsUserOpWithTypeName(sole_dst_node->SoleOutEdge()->dst_node()->op().op_conf(),\n                                \"scalar_mul_by_tensor\")) {\n        return;\n      }\n    } else {\n      if (!IsUserOpWithTypeName(sole_dst_node->op().op_conf(), \"scalar_mul_by_tensor\")) { return; }\n    }\n    const user_op::UserOpConfWrapper cast_user_conf(op_node->op().op_conf());\n    if (op_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(cast_user_conf.input(\"in\", 0))).data_type()\n            != DataType::kFloat16\n        && op_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(cast_user_conf.input(\"in\", 0))).data_type()\n               != DataType::kBFloat16) {\n      return;\n    }\n    if (op_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(cast_user_conf.output(\"out\", 0))).data_type()\n        != DataType::kFloat) {\n      return;\n    }\n    if (op_node->parallel_desc().device_type() != DeviceType::kCUDA) { return; }\n    double scale = 1.0;\n    if (IsUserOpWithTypeName(sole_dst_node->op().op_conf(), \"scalar_mul\")) {\n      const user_op::UserOpConfWrapper scalar_mul_op_conf(sole_dst_node->op().op_conf());\n      if (scalar_mul_op_conf.attr<bool>(\"has_int_operand\")) {\n        scale = static_cast<double>(scalar_mul_op_conf.attr<int64_t>(\"int_operand\"));\n      } else if (scalar_mul_op_conf.attr<bool>(\"has_float_operand\")) {\n        scale = scalar_mul_op_conf.attr<double>(\"float_operand\");\n      } else {\n        UNIMPLEMENTED();\n      }\n      delete_ops.emplace_back(sole_dst_node->op().op_conf());\n      sole_dst_node = 
sole_dst_node->SoleOutEdge()->dst_node();\n    }\n    delete_ops.emplace_back(op_node->op().op_conf());\n    const user_op::UserOpConfWrapper scale_user_conf(sole_dst_node->op().op_conf());\n\n    user_op::UserOpConfWrapperBuilder fused_op_builder(sole_dst_node->op().op_name());\n    fused_op_builder.OpTypeName(\"fused_cast_scale\")\n        .Input(\"x\", cast_user_conf.input(\"in\", 0))\n        .Input(\"scale_by_tensor\", scale_user_conf.input(\"scalar\", 0))\n        .Attr<double>(\"scale\", scale)\n        .Output(\"y\");\n\n    OperatorConf new_op_conf = sole_dst_node->op().op_conf();\n    *new_op_conf.mutable_user_conf() = fused_op_builder.Build().op_conf().user_conf();\n\n    job_builder->MutOpsOnlyOnce({new_op_conf});\n  });\n  job_builder->DelOps(delete_ops);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"FuseCastScalePass\", FuseCastScalePass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/fuse_consecutive_add_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/cost_util.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass FuseConsecutiveAddPass final : public JobPass {\n public:\n  FuseConsecutiveAddPass() = default;\n  ~FuseConsecutiveAddPass() override = default;\n\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    JUST(Apply(op_graph, &job_builder));\n    return Maybe<void>::Ok();\n  }\n};\n\nMaybe<void> FuseConsecutiveAddPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {\n  const auto IsSafeToDelete = MakePredicatorIsSafeToDelete(op_graph);\n  std::vector<std::string> delete_ops;\n  op_graph.TopoForEachNode([&](const OpNode* op_node) {\n    if (!IsUserOpWithTypeName(op_node->op().op_conf(), \"add_n\") || !IsSafeToDelete(op_node)\n        || op_node->out_edges().size() != 1) {\n      return;\n    }\n    OpNode* sole_dst_node = op_node->SoleOutEdge()->dst_node();\n    if (!IsUserOpWithTypeName(sole_dst_node->op().op_conf(), \"add_n\")\n        || !IsSafeToDelete(sole_dst_node)) {\n      return;\n    }\n\n    const std::string this_op_name = op_node->op().op_name();\n\n    const auto& GetCurOpConf = [&](const OpNode& cur_op) -> OperatorConf {\n      const std::string& cur_op_name = cur_op.op().op_name();\n      if (!CHECK_JUST(job_builder->IsInMutOpTransaction(cur_op_name))) {\n        return cur_op.op().op_conf();\n      } else {\n        return CHECK_JUST(job_builder->MutOpTransactionGet(cur_op_name));\n      }\n    };\n\n    int64_t fused_cnt = 0;\n    auto fused_op_conf = GetCurOpConf(*sole_dst_node);\n    auto in_it = fused_op_conf.mutable_user_conf()->mutable_input()->find(\"in\");\n    CHECK(in_it != fused_op_conf.mutable_user_conf()->mutable_input()->end());\n    auto* in_lbns = in_it->second.mutable_s();\n    auto in_lbn_it = in_lbns->begin();\n    while (in_lbn_it != in_lbns->end()) {\n      const auto lbi = GenLogicalBlobId(*in_lbn_it);\n      if (lbi.op_name() == this_op_name) {\n        in_lbn_it = in_lbns->erase(in_lbn_it);\n        ++fused_cnt;\n      } else {\n        ++in_lbn_it;\n      }\n    }\n\n    const auto& this_op_conf = GetCurOpConf(*op_node);\n    auto this_in_it = this_op_conf.user_conf().input().find(\"in\");\n    CHECK(this_in_it != this_op_conf.user_conf().input().end());\n    for (int64_t fuse_i = 0; fuse_i < fused_cnt; ++fuse_i) {\n      for (const auto& this_in_lbn : this_in_it->second.s()) { *(in_lbns->Add()) = this_in_lbn; }\n    }\n\n    CHECK_JUST(job_builder->MutOpTransactionMut(fused_op_conf));\n    delete_ops.emplace_back(this_op_name);\n  });\n\n  if (delete_ops.empty()) { return 
Maybe<void>::Ok(); }\n  JUST(job_builder->MutOpTransactionCommit());\n  job_builder->DelOps(delete_ops);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"FuseConsecutiveAddPass\", FuseConsecutiveAddPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/fuse_embedding_interaction_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass FuseEmbeddingShuffleInteractionPass final : public JobPass {\n public:\n  FuseEmbeddingShuffleInteractionPass() = default;\n  ~FuseEmbeddingShuffleInteractionPass() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    // if enable quantize, not support fuse kernel.\n    bool enable_quantized_comm =\n        ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\", false);\n    bool enable_fuse_embedding_interaction =\n        ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_FUSE_EMBEDDING_INTERACTION\", false);\n    return (!enable_quantized_comm && enable_fuse_embedding_interaction);\n  }\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n};\n\nMaybe<void> FuseEmbeddingShuffleInteractionPass::Apply(const OpGraph& op_graph,\n                                                       JobBuilder* job_builder) const {\n  op_graph.ForEachNode([&](const OpNode* op_node) {\n    if (!IsUserOpWithTypeName(op_node->op().op_conf(), \"embedding_shuffle\")) { return; }\n    if (op_node->out_edges().size() > 2) { return; }\n    const user_op::UserOpConfWrapper embedding_shuffle_conf(op_node->op().op_conf());\n    const std::string& embeddings_lbn = embedding_shuffle_conf.output(\"embeddings\", 0);\n    const std::string& indices_lbn =\n        embedding_shuffle_conf.input(\"inverse_unique_partition_indices\", 0);\n    const std::string& num_unique_matrix_lbn = embedding_shuffle_conf.input(\"num_unique_matrix\", 0);\n    if (op_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(embeddings_lbn)).data_type()\n            != DataType::kFloat16\n        || embedding_shuffle_conf.attr<int64_t>(\"embedding_size\") % 2 != 0) {\n      // only support half and embedding_size % 2 == 0 fuse, because atomicAdd half is slow.\n      return;\n    }\n    if (op_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(indices_lbn)).data_type()\n        != DataType::kUInt32) {\n      // only support indices with uint32_t dtype\n      return;\n    }\n    if (op_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(num_unique_matrix_lbn)).data_type()\n        != DataType::kUInt32) {\n      // only support num_unique with uint32_t dtype\n      return;\n    }\n    for (const OpEdge* out_edge : op_node->out_edges()) {\n      const OpNode* consumer = out_edge->dst_node();\n      if (!consumer->op().op_conf().has_user_conf()) { return; }\n      const user_op::UserOpConfWrapper consumer_op_conf(consumer->op().op_conf());\n      if (!(consumer_op_conf.op_type_name() == \"fused_dot_feature_interaction\"\n            || 
consumer_op_conf.op_type_name() == \"fused_dot_feature_interaction_grad\")) {\n        return;\n      }\n      if (consumer_op_conf.attr<std::string>(\"pooling\") != \"none\") { return; }\n      int input_size = consumer_op_conf.input_size(\"features\");\n      CHECK_GT(input_size, 0) << input_size;\n      if (consumer_op_conf.input(\"features\", input_size - 1) != embeddings_lbn) {\n        // only support embeddings as last feature\n        return;\n      }\n      user_op::UserOpConfWrapperBuilder fused_op_builder(consumer_op_conf.op_name());\n      const std::string& op_type_name = consumer_op_conf.op_type_name();\n      fused_op_builder.OpTypeName(op_type_name)\n          .Input(\"sparse_feature\", embeddings_lbn)\n          .Input(\"sparse_indices\", indices_lbn)\n          .Input(\"num_valid_sparse_feature\", num_unique_matrix_lbn)\n          .Attr<bool>(\"self_interaction\", consumer_op_conf.attr<bool>(\"self_interaction\"))\n          .Attr<std::string>(\"pooling\", consumer_op_conf.attr<std::string>(\"pooling\"));\n      for (int i = 0; i < input_size - 1; ++i) {\n        fused_op_builder.Input(\"features\", consumer_op_conf.input(\"features\", i));\n      }\n      OperatorConf new_op_conf = consumer->op().op_conf();\n      if (op_type_name == \"fused_dot_feature_interaction\") {\n        if (consumer_op_conf.has_input(\"output_concat\", 0)) {\n          fused_op_builder.Input(\"output_concat\", consumer_op_conf.input(\"output_concat\", 0));\n        }\n        fused_op_builder.Output(\"out\")\n            .Attr<bool>(\"has_output_concat\", consumer_op_conf.attr<bool>(\"has_output_concat\"))\n            .Attr<int32_t>(\"output_padding\", consumer_op_conf.attr<int32_t>(\"output_padding\"));\n        *new_op_conf.mutable_user_conf() = fused_op_builder.Build().op_conf().user_conf();\n      } else {\n        // fused_dot_feature_interaction_grad\n        fused_op_builder.Input(\"dy\", consumer_op_conf.input(\"dy\", 0))\n            .Output(\"features_grad\", input_size - 1)\n            .Output(\"sparse_feature_grad\")\n            .Attr<int32_t>(\"output_concat_grad_dim\",\n                           consumer_op_conf.attr<int32_t>(\"output_concat_grad_dim\"));\n        if (consumer_op_conf.has_output(\"output_concat_grad\", 0)) {\n          fused_op_builder.Output(\"output_concat_grad\");\n        }\n        user_op::UserOpConfWrapper fused_dot_feature_interaction_grad_op = fused_op_builder.Build();\n        *new_op_conf.mutable_user_conf() =\n            fused_dot_feature_interaction_grad_op.op_conf().user_conf();\n        const LogicalBlobId last_feature_grad_lbi =\n            GenLogicalBlobId(consumer_op_conf.output(\"features_grad\", input_size - 1));\n        std::string sparse_feature_grad_lbn =\n            fused_dot_feature_interaction_grad_op.output(\"sparse_feature_grad\", 0);\n        for (const OpEdge* out_edge : consumer->out_edges()) {\n          const OpNode* grad_out_node = out_edge->dst_node();\n          if (out_edge->lbis().size() == 1 && out_edge->lbis().front() == last_feature_grad_lbi) {\n            if (!IsUserOpWithTypeName(grad_out_node->op().op_conf(),\n                                      \"embedding_gradient_shuffle\")) {\n              return;\n            }\n            OperatorConf new_embedding_gradient_shuffle_conf = grad_out_node->op().op_conf();\n            for (const std::string& ibn : grad_out_node->op().input_bns()) {\n              if (grad_out_node->op().BnInOp2Lbi(ibn) == last_feature_grad_lbi) {\n                const auto& new_val = 
sparse_feature_grad_lbn;\n                const auto& old_val = ReplaceInputLbnInOpCustomizedConf(\n                    &new_embedding_gradient_shuffle_conf, ibn, new_val);\n                CHECK_EQ(GenLogicalBlobName(last_feature_grad_lbi), old_val);\n              }\n            }\n            auto bool_attr = ::oneflow::AttrValue();\n            bool_attr.set_at_bool(true);\n            (*(new_embedding_gradient_shuffle_conf.mutable_user_conf()\n                   ->mutable_attr()))[\"skip_first_scatter\"] = bool_attr;\n            job_builder->MutOpsOnlyOnce({new_embedding_gradient_shuffle_conf});\n          }\n        }\n      }\n      job_builder->MutOpsOnlyOnce({new_op_conf});\n    }\n    auto bool_attr = ::oneflow::AttrValue();\n    bool_attr.set_at_bool(true);\n    OperatorConf new_embedding_shuffle_conf = op_node->op().op_conf();\n    (*(new_embedding_shuffle_conf.mutable_user_conf()->mutable_attr()))[\"skip_last_gather\"] =\n        bool_attr;\n    job_builder->MutOpsOnlyOnce({new_embedding_shuffle_conf});\n  });\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"FuseEmbeddingShuffleInteractionPass\", FuseEmbeddingShuffleInteractionPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/fuse_model_update_cast_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass FuseModelUpdateCastOpsPass final : public JobPass {\n public:\n  FuseModelUpdateCastOpsPass() = default;\n  ~FuseModelUpdateCastOpsPass() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return (ctx.job_desc().enable_fused_model_update_cast()\n            || ParseBooleanFromEnv(\"ONEFLOW_FUSE_MODEL_UPDATE_CAST\", false))\n           && ctx.job_desc().enable_auto_mixed_precision();\n  }\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    LOG(INFO) << \"Enable fuse model update cast pass. \";\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n};\n\nMaybe<void> FuseModelUpdateCastOpsPass::Apply(const OpGraph& op_graph,\n                                              JobBuilder* job_builder) const {\n  op_graph.ForEachNode([&](OpNode* op_node) {\n    const auto& op_conf = op_node->op().op_conf();\n    if (!op_conf.has_variable_conf()) { return; }\n    LogicalBlobId model_copy_lbi;\n\n    for (OpEdge* find_cast_edge : op_node->out_edges()) {\n      OpNode* find_cast_node = find_cast_edge->dst_node();\n      if (!IsUserOpWithTypeName(find_cast_node->op().op_conf(), \"cast\")) { continue; }\n      const user_op::UserOpConfWrapper cast_user_conf(find_cast_node->op().op_conf());\n      if (find_cast_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(cast_user_conf.input(\"in\", 0)))\n              .data_type()\n          != DataType::kFloat) {\n        continue;\n      }\n      if (find_cast_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(cast_user_conf.output(\"out\", 0)))\n              .data_type()\n          != DataType::kFloat16) {\n        continue;\n      }\n      // Currently only support for cuda, maybe remove this limit.\n      if (find_cast_node->parallel_desc().device_type() != DeviceType::kCUDA) { continue; }\n\n      for (OpEdge* find_model_update_edge : op_node->out_edges()) {\n        OpNode* find_model_update_update_node = find_model_update_edge->dst_node();\n        if (!IsUserOpWithTypeName(find_model_update_update_node->op().op_conf(), \"sgd_update\")\n            && !IsUserOpWithTypeName(find_model_update_update_node->op().op_conf(),\n                                     \"adam_update\")) {\n          continue;\n        }\n\n        // Currently only support for cuda, maybe remove this limit.\n        if (find_model_update_update_node->parallel_desc().device_type() != DeviceType::kCUDA) {\n          continue;\n        }\n\n        const user_op::UserOpConfWrapper model_update_user_conf(\n            find_model_update_update_node->op().op_conf());\n\n        // Here we find cast and 
model_update node, Replace cast as mutable_cast_once, and add\n        // model_copy to model_update node.\n        user_op::UserOpConfWrapperBuilder fused_cast_op_builder(cast_user_conf.op_name());\n        fused_cast_op_builder.OpTypeName(\"mutable_cast_once\")\n            .Input(\"in\", cast_user_conf.input(\"in\", 0))\n            .Attr<DataType>(\"dtype\", cast_user_conf.attr<DataType>(\"dtype\"))\n            .Output(\"out\");\n\n        CHECK(cast_user_conf.op_conf().has_scope_symbol_id());\n        fused_cast_op_builder.ScopeSymbolId(cast_user_conf.op_conf().scope_symbol_id());\n\n        OperatorConf new_cast_op_conf = cast_user_conf.op_conf();\n        *new_cast_op_conf.mutable_user_conf() = fused_cast_op_builder.Build().op_conf().user_conf();\n        job_builder->MutOpsOnlyOnce({new_cast_op_conf});\n\n        const user_op::UserOpConfWrapper new_cast_user_conf(new_cast_op_conf);\n        model_copy_lbi = GenLogicalBlobId(new_cast_user_conf.output(\"out\", 0));\n        user_op::UserOpConfWrapperBuilder fused_model_update_op_builder(\n            model_update_user_conf.op_name());\n        if (IsUserOpWithTypeName(find_model_update_update_node->op().op_conf(), \"sgd_update\")) {\n          fused_model_update_op_builder.OpTypeName(\"sgd_update\")\n              .Input(\"model\", model_update_user_conf.input(\"model\", 0))\n              .Input(\"model_diff\", model_update_user_conf.input(\"model_diff\", 0))\n              .Input(\"learning_rate\", model_update_user_conf.input(\"learning_rate\", 0))\n              .Attr<double>(\"scale\", model_update_user_conf.attr<double>(\"scale\"))\n              .Attr<float>(\"l1\", model_update_user_conf.attr<float>(\"l1\"))\n              .Attr<float>(\"l2\", model_update_user_conf.attr<float>(\"l2\"))\n              .Attr<float>(\"weight_decay\", model_update_user_conf.attr<float>(\"weight_decay\"))\n              .Attr<float>(\"learning_rate_scale\",\n                           model_update_user_conf.attr<float>(\"learning_rate_scale\"));\n        } else if (IsUserOpWithTypeName(find_model_update_update_node->op().op_conf(),\n                                        \"adam_update\")) {\n          fused_model_update_op_builder.OpTypeName(\"adam_update\")\n              .Input(\"model\", model_update_user_conf.input(\"model\", 0))\n              .Input(\"model_diff\", model_update_user_conf.input(\"model_diff\", 0))\n              .Input(\"m\", model_update_user_conf.input(\"m\", 0))\n              .Input(\"v\", model_update_user_conf.input(\"v\", 0))\n              .Input(\"learning_rate\", model_update_user_conf.input(\"learning_rate\", 0))\n              .Attr<double>(\"scale\", model_update_user_conf.attr<double>(\"scale\"))\n              .Attr<float>(\"l1\", model_update_user_conf.attr<float>(\"l1\"))\n              .Attr<float>(\"l2\", model_update_user_conf.attr<float>(\"l2\"))\n              .Attr<float>(\"weight_decay\", model_update_user_conf.attr<float>(\"weight_decay\"))\n              .Attr<float>(\"beta1\", model_update_user_conf.attr<float>(\"beta1\"))\n              .Attr<float>(\"beta2\", model_update_user_conf.attr<float>(\"beta2\"))\n              .Attr<float>(\"epsilon\", model_update_user_conf.attr<float>(\"epsilon\"))\n              .Attr<bool>(\"amsgrad\", model_update_user_conf.attr<bool>(\"amsgrad\"))\n              .Attr<bool>(\"do_bias_correction\",\n                          model_update_user_conf.attr<bool>(\"do_bias_correction\"))\n              .Attr<float>(\"learning_rate_scale\",\n                           
model_update_user_conf.attr<float>(\"learning_rate_scale\"));\n          ;\n          if (model_update_user_conf.attr<bool>(\"do_bias_correction\")) {\n            fused_model_update_op_builder.Input(\n                \"bias_correction1\", model_update_user_conf.input(\"bias_correction1\", 0));\n            fused_model_update_op_builder.Input(\n                \"bias_correction2\", model_update_user_conf.input(\"bias_correction2\", 0));\n          }\n          if (model_update_user_conf.attr<bool>(\"amsgrad\")) {\n            fused_model_update_op_builder.Input(\"max_v\", model_update_user_conf.input(\"max_v\", 0));\n          }\n        } else {\n          UNIMPLEMENTED() << \"Need support more optimizers. \";\n        }\n        fused_model_update_op_builder.Input(\"model_copy\", GenLogicalBlobName(model_copy_lbi));\n        CHECK(model_update_user_conf.op_conf().has_scope_symbol_id());\n        fused_model_update_op_builder.ScopeSymbolId(\n            model_update_user_conf.op_conf().scope_symbol_id());\n\n        OperatorConf new_model_update_op_conf = model_update_user_conf.op_conf();\n        *new_model_update_op_conf.mutable_user_conf() =\n            fused_model_update_op_builder.Build().op_conf().user_conf();\n        job_builder->MutOpsOnlyOnce({new_model_update_op_conf});\n        break;\n      }\n      break;\n    }\n  });\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"FuseModelUpdateCastOpsPass\", FuseModelUpdateCastOpsPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/fuse_update_ops_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass FuseUpdateOpsPass final : public JobPass {\n public:\n  FuseUpdateOpsPass() = default;\n  ~FuseUpdateOpsPass() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return ctx.job_desc().job_conf().enable_fuse_model_update_ops();\n  }\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n};\n\nMaybe<void> FuseUpdateOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {\n  const auto IsSafeToDelete = MakePredicatorIsSafeToDelete(op_graph);\n  std::vector<std::string> del_op_names;\n  op_graph.ForEachNode([&](const OpNode* op_node) {\n    if (!op_node->op().op_conf().has_user_conf()) { return; }\n    const user_op::UserOpConfWrapper user_op_conf(op_node->op().op_conf());\n    if (user_op_conf.op_type_name() != \"sgd_update\"\n        && user_op_conf.op_type_name() != \"momentum_update\"\n        && user_op_conf.op_type_name() != \"adam_update\"\n        && user_op_conf.op_type_name() != \"rmsprop_update\"\n        && user_op_conf.op_type_name() != \"lars_update\"\n        && user_op_conf.op_type_name() != \"adagrad_update\"\n        && user_op_conf.op_type_name() != \"lamb_update\"\n        && user_op_conf.op_type_name() != \"ftrl_update\"\n        && user_op_conf.op_type_name() != \"adadelta_update\") {\n      return;\n    }\n    if (user_op_conf.attr<double>(\"scale\") != 1.0 || user_op_conf.attr<float>(\"l1\") != 0.0f\n        || user_op_conf.attr<float>(\"l2\") != 0.0f) {\n      return;\n    }\n    float l1 = 0;\n    float l2 = 0;\n    double scale = 1;\n    bool fused = false;\n    LogicalBlobId model_diff_lbi = GenLogicalBlobId(user_op_conf.input(\"model_diff\", 0));\n    std::string scale_by_tensor_lbn;\n\n    [&]() {\n      do {\n        const OpNode* producer = op_graph.OpNode4OpName(model_diff_lbi.op_name());\n        if (!IsUserOpWithTypeName(producer->op().op_conf(), \"l1_l2_regularize_gradient\")) { break; }\n        if (!IsSafeToDelete(producer)) { return; }\n        const user_op::UserOpConfWrapper l1_l2_regularize_gradient_op_conf(\n            producer->op().op_conf());\n        if (l1_l2_regularize_gradient_op_conf.input(\"model\", 0) != user_op_conf.input(\"model\", 0)) {\n          return;\n        }\n        l1 = l1_l2_regularize_gradient_op_conf.attr<float>(\"l1\");\n        l2 = l1_l2_regularize_gradient_op_conf.attr<float>(\"l2\");\n        model_diff_lbi = GenLogicalBlobId(l1_l2_regularize_gradient_op_conf.input(\"model_diff\", 0));\n        del_op_names.emplace_back(producer->op().op_name());\n        fused = 
true;\n      } while (false);\n\n      do {\n        const OpNode* producer = op_graph.OpNode4OpName(model_diff_lbi.op_name());\n        if (!IsUserOpWithTypeName(producer->op().op_conf(), \"scalar_mul_by_tensor\")) { break; }\n        if (!IsSafeToDelete(producer)) { return; }\n        const user_op::UserOpConfWrapper scalar_mul_by_tensor_op_conf(producer->op().op_conf());\n        model_diff_lbi = GenLogicalBlobId(scalar_mul_by_tensor_op_conf.input(\"x\", 0));\n        scale_by_tensor_lbn = scalar_mul_by_tensor_op_conf.input(\"scalar\", 0);\n        del_op_names.emplace_back(producer->op().op_name());\n        fused = true;\n      } while (false);\n\n      do {\n        const OpNode* producer = op_graph.OpNode4OpName(model_diff_lbi.op_name());\n        if (!IsUserOpWithTypeName(producer->op().op_conf(), \"scalar_mul\")) { break; }\n        if (!IsSafeToDelete(producer)) { return; }\n        const user_op::UserOpConfWrapper scalar_mul_op_conf(producer->op().op_conf());\n        if (scalar_mul_op_conf.attr<bool>(\"has_int_operand\")) {\n          scale = static_cast<double>(scalar_mul_op_conf.attr<int64_t>(\"int_operand\"));\n        } else if (scalar_mul_op_conf.attr<bool>(\"has_float_operand\")) {\n          scale = scalar_mul_op_conf.attr<double>(\"float_operand\");\n        } else {\n          UNIMPLEMENTED();\n        }\n        model_diff_lbi = GenLogicalBlobId(scalar_mul_op_conf.input(\"in\", 0));\n        del_op_names.emplace_back(producer->op().op_name());\n        fused = true;\n      } while (false);\n\n      do {\n        const OpNode* producer = op_graph.OpNode4OpName(model_diff_lbi.op_name());\n        if (!IsUserOpWithTypeName(producer->op().op_conf(), \"cast\")) { break; }\n        if (!IsSafeToDelete(producer)) { return; }\n        const user_op::UserOpConfWrapper cast_op_conf(producer->op().op_conf());\n        if (producer->LogicalBlobDesc4Lbi(GenLogicalBlobId(cast_op_conf.input(\"in\", 0))).data_type()\n                != DataType::kFloat16\n            || cast_op_conf.attr<DataType>(\"dtype\") != DataType::kFloat) {\n          return;\n        }\n        model_diff_lbi = GenLogicalBlobId(cast_op_conf.input(\"in\", 0));\n        del_op_names.emplace_back(producer->op().op_name());\n        fused = true;\n      } while (false);\n    }();\n\n    if (!fused) { return; }\n\n    const TrainConf& train_conf = job_builder->job().job_conf().train_conf();\n\n    user_op::UserOpConfWrapperBuilder fused_op_builder(user_op_conf.op_name());\n    fused_op_builder.OpTypeName(user_op_conf.op_type_name())\n        .Input(\"model\", user_op_conf.input(\"model\", 0))\n        .Input(\"model_diff\", GenLogicalBlobName(model_diff_lbi))\n        .Input(\"learning_rate\", user_op_conf.input(\"learning_rate\", 0))\n        .Attr<double>(\"scale\", scale)\n        .Attr<float>(\"l1\", l1)\n        .Attr<float>(\"l2\", l2)\n        .Attr<float>(\"weight_decay\", user_op_conf.attr<float>(\"weight_decay\"))\n        .Attr<float>(\"learning_rate_scale\", user_op_conf.attr<float>(\"learning_rate_scale\"));\n    if (scale_by_tensor_lbn != \"\") {\n      fused_op_builder.Input(\"scale_by_tensor\", scale_by_tensor_lbn);\n    }\n    if (user_op_conf.has_input(\"skip_if\", 0)) {\n      fused_op_builder.Input(\"skip_if\", user_op_conf.input(\"skip_if\", 0));\n    }\n    if (user_op_conf.op_type_name() == \"sgd_update\") {\n      // do nothing\n    } else if (user_op_conf.op_type_name() == \"momentum_update\") {\n      fused_op_builder.Input(\"momentum\", user_op_conf.input(\"momentum\", 0))\n          
.Attr<float>(\"beta\", user_op_conf.attr<float>(\"beta\"))\n          .Attr<float>(\"dampening\", user_op_conf.attr<float>(\"dampening\"))\n          .Attr<bool>(\"nesterov\", user_op_conf.attr<bool>(\"nesterov\"))\n          .Attr<bool>(\"maximize\", user_op_conf.attr<bool>(\"maximize\"));\n    } else if (user_op_conf.op_type_name() == \"adam_update\") {\n      fused_op_builder.Input(\"m\", user_op_conf.input(\"m\", 0))\n          .Input(\"v\", user_op_conf.input(\"v\", 0))\n          .Attr<float>(\"beta1\", user_op_conf.attr<float>(\"beta1\"))\n          .Attr<float>(\"beta2\", user_op_conf.attr<float>(\"beta2\"))\n          .Attr<float>(\"epsilon\", user_op_conf.attr<float>(\"epsilon\"))\n          .Attr<bool>(\"amsgrad\", user_op_conf.attr<bool>(\"amsgrad\"))\n          .Attr<bool>(\"do_bias_correction\", user_op_conf.attr<bool>(\"do_bias_correction\"));\n      if (user_op_conf.has_input(\"max_v\", 0)) {\n        fused_op_builder.Input(\"max_v\", user_op_conf.input(\"max_v\", 0));\n      }\n      if (user_op_conf.has_input(\"bias_correction1\", 0)) {\n        fused_op_builder.Input(\"bias_correction1\", user_op_conf.input(\"bias_correction1\", 0));\n      }\n      if (user_op_conf.has_input(\"bias_correction2\", 0)) {\n        fused_op_builder.Input(\"bias_correction2\", user_op_conf.input(\"bias_correction2\", 0));\n      }\n    } else if (user_op_conf.op_type_name() == \"rmsprop_update\") {\n      const bool centered = user_op_conf.attr<bool>(\"centered\");\n      fused_op_builder.Input(\"mean_square\", user_op_conf.input(\"mean_square\", 0.f))\n          .Attr<bool>(\"centered\", user_op_conf.attr<bool>(\"centered\"))\n          .Attr<float>(\"epsilon\", user_op_conf.attr<float>(\"epsilon\"))\n          .Attr<float>(\"decay_rate\", user_op_conf.attr<float>(\"decay_rate\"));\n      if (centered) {\n        fused_op_builder.Input(\"mean_gradient\", user_op_conf.input(\"mean_gradient\", 0.f));\n      }\n    } else if (user_op_conf.op_type_name() == \"lars_update\") {\n      fused_op_builder.Input(\"momentum\", user_op_conf.input(\"momentum\", 0))\n          .Attr<float>(\"momentum_beta\", user_op_conf.attr<float>(\"momentum_beta\"))\n          .Attr<float>(\"epsilon\", user_op_conf.attr<float>(\"epsilon\"))\n          .Attr<float>(\"lars_coefficient\", user_op_conf.attr<float>(\"lars_coefficient\"));\n    } else if (user_op_conf.op_type_name() == \"adagrad_update\") {\n      fused_op_builder.Input(\"sum\", user_op_conf.input(\"sum\", 0))\n          .Input(\"train_step\", train_conf.train_step_lbn())\n          .Attr<float>(\"lr_decay\", user_op_conf.attr<float>(\"lr_decay\"))\n          .Attr<float>(\"epsilon\", user_op_conf.attr<float>(\"epsilon\"));\n    } else if (user_op_conf.op_type_name() == \"lamb_update\") {\n      fused_op_builder.Input(\"m\", user_op_conf.input(\"m\", 0))\n          .Input(\"v\", user_op_conf.input(\"v\", 0))\n          .Attr<float>(\"beta1\", user_op_conf.attr<float>(\"beta1\"))\n          .Attr<float>(\"beta2\", user_op_conf.attr<float>(\"beta2\"))\n          .Attr<float>(\"epsilon\", user_op_conf.attr<float>(\"epsilon\"))\n          .Attr<bool>(\"do_bias_correction\", user_op_conf.attr<bool>(\"do_bias_correction\"));\n      if (user_op_conf.has_input(\"bias_correction1\", 0)) {\n        fused_op_builder.Input(\"bias_correction1\", user_op_conf.input(\"bias_correction1\", 0));\n      }\n      if (user_op_conf.has_input(\"bias_correction2\", 0)) {\n        fused_op_builder.Input(\"bias_correction2\", user_op_conf.input(\"bias_correction2\", 0));\n      }\n   
 } else if (user_op_conf.op_type_name() == \"ftrl_update\") {\n      fused_op_builder.Input(\"accumulate\", user_op_conf.input(\"accumulate\", 0))\n          .Input(\"z\", user_op_conf.input(\"z\", 0))\n          .Attr<float>(\"lr_power\", user_op_conf.attr<float>(\"lr_power\"))\n          .Attr<float>(\"lambda1\", user_op_conf.attr<float>(\"lambda1\"))\n          .Attr<float>(\"lambda2\", user_op_conf.attr<float>(\"lambda2\"))\n          .Attr<float>(\"beta\", user_op_conf.attr<float>(\"beta\"));\n    } else if (user_op_conf.op_type_name() == \"adadelta_update\") {\n      fused_op_builder.Input(\"square_avgs\", user_op_conf.input(\"square_avgs\", 0))\n          .Input(\"acc_deltas\", user_op_conf.input(\"acc_deltas\", 0))\n          .Attr<float>(\"rho\", user_op_conf.attr<float>(\"rho\"))\n          .Attr<float>(\"epsilon\", user_op_conf.attr<float>(\"epsilon\"))\n          .Attr<bool>(\"maximize\", user_op_conf.attr<bool>(\"maximize\"));\n    } else {\n      UNIMPLEMENTED();\n    }\n    CHECK(user_op_conf.op_conf().has_scope_symbol_id());\n    fused_op_builder.ScopeSymbolId(user_op_conf.op_conf().scope_symbol_id());\n    OperatorConf new_op_conf = user_op_conf.op_conf();\n    *new_op_conf.mutable_user_conf() = fused_op_builder.Build().op_conf().user_conf();\n    job_builder->MutOpsOnlyOnce({new_op_conf});\n  });\n  job_builder->DelOps(del_op_names);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"FuseUpdateOpsPass\", FuseUpdateOpsPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/generate_optimizer_op_confs.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job_rewriter/autograd.h\"\n#include \"oneflow/core/job_rewriter/optimizer.h\"\n#include \"oneflow/core/job_rewriter/calculation_pass.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/job/scope.pb.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass GenerateOptimizerOpConfs final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(GenerateOptimizerOpConfs);\n  GenerateOptimizerOpConfs() = default;\n  ~GenerateOptimizerOpConfs() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().IsTrain(); }\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;\n};\n\nvoid FilterCurModelLbi2ModelDiffLbiByName(\n    const ::google::protobuf::RepeatedPtrField<std::string>& variables,\n    const HashMap<LogicalBlobId, LogicalBlobId>& model_lbi2model_diff_lbi,\n    HashMap<LogicalBlobId, LogicalBlobId>* cur_model_lbi2model_diff_lbi) {\n  for (const std::string& variable : variables) {\n    const LogicalBlobId& lbi = GenLogicalBlobId(variable + \"/out\");\n    if (model_lbi2model_diff_lbi.find(lbi) != model_lbi2model_diff_lbi.end()) {\n      (*cur_model_lbi2model_diff_lbi)[lbi] = model_lbi2model_diff_lbi.at(lbi);\n    }\n  }\n}\n\nMaybe<JobBuilder> WithCalculationPassScope(const std::string& pass_name, Job* job,\n                                           const std::function<Maybe<void>()>& Handler) {\n  HashSet<std::string> exists_op_names;\n  for (const auto& op_conf : job->net().op()) {\n    CHECK_OR_RETURN(exists_op_names.emplace(op_conf.name()).second);\n  }\n  JUST(Handler());\n  // using a new JobBuilder to avoid bugs caused by MutOnlyOnce\n  auto new_job_builder = std::make_shared<JobBuilder>(job);\n  HashMap<int64_t, std::vector<const OperatorConf*>> scope_id2op_names;\n  const auto& scope_storage = *Singleton<symbol::Storage<Scope>>::Get();\n  for (const auto& op_conf : job->net().op()) {\n    if (exists_op_names.count(op_conf.name()) > 0) { continue; }\n    CHECK_OR_RETURN(op_conf.has_scope_symbol_id());\n    OF_RETURN_IF_ERROR(scope_storage.MaybeGet(op_conf.scope_symbol_id())) << op_conf.DebugString();\n    scope_id2op_names[op_conf.scope_symbol_id()].emplace_back(&op_conf);\n  }\n  const auto& GetNewScopeSymbolId = [&](int64_t old_scope_symbol_id) -> Maybe<int64_t> {\n    const auto& old_scope = JUST(scope_storage.MaybeGet(old_scope_symbol_id));\n    std::shared_ptr<ScopeProto> new_scope = std::make_shared<ScopeProto>(old_scope.scope_proto());\n    new_scope->set_parent_scope_symbol_id(old_scope_symbol_id);\n    new_scope->set_calculation_pass_name(pass_name);\n    std::shared_ptr<Scope> new_scope_symbol;\n    JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n      new_scope_symbol = 
JUST(builder->GetScopeSymbol(*new_scope));\n      return Maybe<void>::Ok();\n    }));\n    return JUST(new_scope_symbol->symbol_id());\n  };\n  for (const auto& pair : scope_id2op_names) {\n    int64_t new_scope_symbol_id = JUST(GetNewScopeSymbolId(pair.first));\n    std::vector<OperatorConf> op_confs(pair.second.size());\n    for (int i = 0; i < pair.second.size(); ++i) {\n      op_confs.at(i).CopyFrom(*pair.second.at(i));\n      op_confs.at(i).set_scope_symbol_id(new_scope_symbol_id);\n    }\n    new_job_builder->MutOpsOnlyOnce(op_confs);\n  }\n  return new_job_builder;\n}\n\nMaybe<void> GenerateOptimizerOpConfs::Apply(Job* job, JobPassCtx* ctx) const {\n  if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n  const auto& train_conf = job->job_conf().train_conf();\n  // loss initial gradients\n  HashMap<LogicalBlobId, LogicalBlobId> loss_lbi2initial_diff_lbi;\n  CHECK_OR_RETURN(train_conf.loss_lbn_size() == train_conf.loss_grad_lbn_size())\n      << \"loss_lbn and loss_grad_lbn size mismatch\";\n  for (int i = 0; i < train_conf.loss_lbn_size(); ++i) {\n    auto loss_lbi = GenLogicalBlobId(train_conf.loss_lbn(i));\n    auto loss_grad_lbi = GenLogicalBlobId(train_conf.loss_grad_lbn(i));\n    loss_lbi2initial_diff_lbi.emplace(loss_lbi, loss_grad_lbi);\n  }\n  // variable gradients\n  HashMap<LogicalBlobId, LogicalBlobId> model_lbi2model_diff_lbi;\n  for (const auto& optimizer_conf : train_conf.optimizer_conf()) {\n    CHECK_OR_RETURN(optimizer_conf.variable_op_names_size()\n                    == optimizer_conf.variable_grad_lbns_size())\n        << \"variable_op_names and variable_grad_lbns size mismatch\";\n    for (int i = 0; i < optimizer_conf.variable_op_names_size(); ++i) {\n      auto model_lbi = GenLogicalBlobId(optimizer_conf.variable_op_names(i) + \"/out\");\n      const auto& model_diff_lbn = optimizer_conf.variable_grad_lbns(i);\n      // variable maybe has no gradient, so skip it if model_diff_lbn is empty\n      if (!model_diff_lbn.empty()) {\n        model_lbi2model_diff_lbi.emplace(model_lbi, GenLogicalBlobId(model_diff_lbn));\n      }\n    }\n  }\n  const OpGraph op_graph(*job);\n  auto job_builder = std::make_shared<JobBuilder>(job);\n  const JobBuilder* old_job_builder = job_builder.get();\n  job_builder = JUST(WithCalculationPassScope(kOptimizerPass, job, [&]() -> Maybe<void> {\n    CHECK(old_job_builder == job_builder.get());  // Check this lambda never been async called\n    AddDiffHalf2FloatCast(op_graph, job_builder.get(), &model_lbi2model_diff_lbi);\n    AddDiffStaticShapeCast(op_graph, job_builder.get(), &model_lbi2model_diff_lbi);\n    AddDiffParallelCast(op_graph, job_builder.get(), &model_lbi2model_diff_lbi);\n    JUST(ScaleModelDiffByLossInstanceNum(op_graph, job_builder.get(), &model_lbi2model_diff_lbi));\n    JUST(ScaleInitialDiffByLossScale(ctx, op_graph, job_builder.get(), &loss_lbi2initial_diff_lbi));\n    ScaleModelDiffByLossScale(ctx, op_graph, job_builder.get(), &model_lbi2model_diff_lbi);\n    JUST(CountNotFiniteIfNeeded(ctx, op_graph, job_builder.get(), model_lbi2model_diff_lbi));\n    for (const auto& optimizer_conf : job->job_conf().train_conf().optimizer_conf()) {\n      HashMap<LogicalBlobId, LogicalBlobId> cur_model_lbi2model_diff_lbi;\n      FilterCurModelLbi2ModelDiffLbiByName(optimizer_conf.variable_op_names(),\n                                           model_lbi2model_diff_lbi, &cur_model_lbi2model_diff_lbi);\n      if (optimizer_conf.has_clip_conf()) {\n        ClipGradient(ctx, op_graph, job_builder.get(), &cur_model_lbi2model_diff_lbi,\n  
                   optimizer_conf.clip_conf());\n      }\n      RegularizeGradient(op_graph, job_builder.get(), &cur_model_lbi2model_diff_lbi);\n      op_graph.ForEachNode([&](OpNode* op_node) {\n        const VariableOp* var_op = dynamic_cast<const VariableOp*>(&op_node->op());\n        if (var_op == nullptr\n            || cur_model_lbi2model_diff_lbi.find(var_op->BnInOp2Lbi(var_op->SoleObn()))\n                   == cur_model_lbi2model_diff_lbi.end()) {\n          return;\n        }\n        const std::string& model_diff_lbn = GenLogicalBlobName(\n            cur_model_lbi2model_diff_lbi.at(var_op->BnInOp2Lbi(var_op->SoleObn())));\n        AddOptimizerOp(ctx, *op_node, model_diff_lbn, optimizer_conf, job_builder.get());\n      });\n    }\n    return Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_JOB_PASS(\"GenerateOptimizerOpConfs\", GenerateOptimizerOpConfs);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/group_boxing_by_dst_parallel.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/common/protobuf.h\"\n\nnamespace oneflow {\n\nconst Scope& Scope4ScopeSymbolId(int64_t scope_symbol_id) {\n  CHECK(Singleton<symbol::Storage<Scope>>::Get()->Has(scope_symbol_id));\n  return Singleton<symbol::Storage<Scope>>::Get()->Get(scope_symbol_id);\n}\n\nconst Scope& Scope4OpNode(const OpNode* op_node) {\n  const OperatorConf& op_conf = op_node->op().op_conf();\n  CHECK(op_conf.has_scope_symbol_id());\n  return Scope4ScopeSymbolId(op_conf.scope_symbol_id());\n}\n\nbool OpNodeHasScope(const OpNode* node) { return node->op().op_conf().has_scope_symbol_id(); }\n\nint64_t GetStageIdHint(const OpNode* node) {\n  return Scope4OpNode(node).Int64(\"pipeline_stage_id_hint\");\n}\n\nMaybe<void> GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder) {\n  {\n    // NOTE(chengcheng): Disable group boxing for pipeline parallel, because there will be bad case\n    //  make forward backward exec sequential in ZeRO + 3-D Parallel by insert additional boxing\n    //  identity.\n    int64_t max_stage_id = 0;\n    op_graph.ForEachNode([&](const OpNode* this_node) {\n      if (!OpNodeHasScope(this_node)) {\n        LOG(WARNING) << \" op : \" << this_node->op().op_conf().DebugString() << \" has NOT scope!\";\n        return;\n      }\n      max_stage_id = std::max(max_stage_id, GetStageIdHint(this_node));\n    });\n    if (max_stage_id > 0) { return Maybe<void>::Ok(); }\n  }\n  HashMap<LogicalBlobId, HashMap<std::pair<ParallelDesc, NdSbp>,\n                                 std::vector<std::pair<const OpNode*, std::string>>>>\n      lbi2consumer_grouped_by_parallel;\n  HashMap<const OpNode*, OperatorConf> op_node2op_conf;\n  op_graph.ForEachNode([&](const OpNode* node) {\n    OperatorConf::OpTypeCase op_type_case = node->op().op_conf().op_type_case();\n    if (IsClassRegistered<int32_t, DisableInputBoxingGroup>(op_type_case)) { return; }\n    for (const std::string& ibn : node->op().input_bns()) {\n      const auto& blob_modifier_ = node->op().InputBlobModifier4Ibn(ibn);\n      if (blob_modifier_.has_is_mutable() && blob_modifier_.is_mutable()) { continue; }\n      const LogicalBlobId& lbi = node->op().BnInOp2Lbi(ibn);\n      const OpNode& producer = node->ProducerOpNode4Lbi(lbi);\n      const auto& logical_shape = node->LogicalBlobDesc4Lbi(lbi).shape();\n      const NdSbp& producer_nd_sbp = producer.NdSbp4Lbi(lbi);\n      const std::string& producer_lbn = *CHECK_JUST(producer.op().obn4lbi(lbi));\n      const ParallelDesc& producer_parallel_desc =\n          *CHECK_JUST(producer.op().GetParallelDesc4BnInOp(producer_lbn)).get();\n      ParallelDesc reduced_in_parallel_desc = producer_parallel_desc;\n      
NdSbp reduced_in_nd_sbp;\n      NdSbpDimReduce(producer_parallel_desc, producer_nd_sbp, &reduced_in_parallel_desc,\n                     &reduced_in_nd_sbp, logical_shape);\n\n      const NdSbp& consumer_nd_sbp = node->NdSbp4BnInOp(ibn);\n      const ParallelDesc& consumer_parallel_desc =\n          *CHECK_JUST(node->op().GetParallelDesc4BnInOp(ibn));\n      ParallelDesc reduced_out_parallel_desc = consumer_parallel_desc;\n      NdSbp reduced_out_nd_sbp;\n      NdSbpDimReduce(consumer_parallel_desc, consumer_nd_sbp, &reduced_out_parallel_desc,\n                     &reduced_out_nd_sbp, logical_shape);\n\n      if (reduced_in_parallel_desc == reduced_out_parallel_desc\n          && reduced_in_nd_sbp == reduced_out_nd_sbp) {\n        continue;\n      }\n      lbi2consumer_grouped_by_parallel[lbi][{reduced_out_parallel_desc, reduced_out_nd_sbp}]\n          .push_back({node, ibn});\n      if (op_node2op_conf.find(node) == op_node2op_conf.end()) {\n        op_node2op_conf[node] = node->op().op_conf();\n      }\n    }\n  });\n  for (const auto& lbi7groups : lbi2consumer_grouped_by_parallel) {\n    const LogicalBlobId& lbi = lbi7groups.first;\n    for (const auto& parallel7group : lbi7groups.second) {\n      if (parallel7group.second.size() < 2) { continue; }\n      const ParallelDesc& dst_parallel_desc = parallel7group.first.first;\n      const NdSbp& dst_nd_sbp = parallel7group.first.second;\n      OperatorConf identity_op_conf{};\n      identity_op_conf.set_name(\"Sys-Boxing-GroupIdentity-\" + lbi.op_name() + \"_\" + lbi.blob_name()\n                                + \"-\" + NewUniqueId());\n      IdentityOpConf* identity_conf = identity_op_conf.mutable_identity_conf();\n      identity_conf->set_in(GenLogicalBlobName(lbi));\n      identity_conf->set_out(\"out\");\n      job_builder->AddOps(dst_parallel_desc.parallel_conf(), {identity_op_conf});\n      NdSbpSignature identity_nd_sbp_signature;\n      (*identity_nd_sbp_signature.mutable_bn_in_op2nd_sbp())[\"in\"] = dst_nd_sbp;\n      (*identity_nd_sbp_signature.mutable_bn_in_op2nd_sbp())[\"out\"] = dst_nd_sbp;\n      job_builder->AddNdSbpSignature4OpName(identity_op_conf.name(), identity_nd_sbp_signature);\n\n      LogicalBlobId grouped_lbi;\n      grouped_lbi.set_op_name(identity_op_conf.name());\n      grouped_lbi.set_blob_name(identity_conf->out());\n      for (const auto& consumer7ibn : parallel7group.second) {\n        const OpNode* consumer = consumer7ibn.first;\n        const std::string& ibn = consumer7ibn.second;\n        OperatorConf& consumer_op_conf = op_node2op_conf[consumer];\n        const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn,\n                                                                GenLogicalBlobName(grouped_lbi));\n        CHECK_EQ_OR_RETURN(GenLogicalBlobName(lbi), old_val);\n      }\n    }\n  }\n  for (const auto& op_node7op_conf : op_node2op_conf) {\n    JUST(job_builder->MutOpOnlyOnce(op_node7op_conf.second));\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/group_boxing_by_dst_parallel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_REWRITER_GROUP_BOXING_BY_DST_PARALLEL_H_\n#define ONEFLOW_CORE_JOB_REWRITER_GROUP_BOXING_BY_DST_PARALLEL_H_\n\n#include \"oneflow/core/graph/op_graph.h\"\n\nnamespace oneflow {\n\nclass OpGraph;\nclass Job;\n\nMaybe<void> GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_REWRITER_GROUP_BOXING_BY_DST_PARALLEL_H_\n"
  },
  {
    "path": "oneflow/core/job_rewriter/indexed_slices_optimizer_rewrite_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nclass IndexedSlicesOptimizerRewritePass final : public JobPass {\n public:\n  IndexedSlicesOptimizerRewritePass() = default;\n  ~IndexedSlicesOptimizerRewritePass() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return ctx.job_desc().job_conf().has_indexed_slices_optimizer_conf()\n           && ctx.job_desc().job_conf().indexed_slices_optimizer_conf().enable();\n  }\n\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n};\n\nMaybe<void> IndexedSlicesOptimizerRewritePass::Apply(const OpGraph& op_graph,\n                                                     JobBuilder* job_builder) const {\n  const PbRpf<std::string>& include_op_names =\n      GlobalJobDesc().job_conf().indexed_slices_optimizer_conf().include_op_names().op_name();\n  const std::set<std::string> include_op_name_set(\n      {include_op_names.cbegin(), include_op_names.cend()});\n  op_graph.ForEachNode([&](const OpNode* src_node) {\n    const OperatorConf& src_op_conf = src_node->op().op_conf();\n    if (src_node->out_edges().size() != 1) { return; }\n    std::string indices_lbn;\n    std::string values_lbn;\n    std::string model_op_name;\n    if (!src_op_conf.has_user_conf()) { return; }\n    const user_op::UserOpConfWrapper src_op(src_op_conf);\n    if (src_op.op_type_name() == \"unsorted_segment_sum\" && src_op.attr<int64_t>(\"axis\") == 0) {\n      indices_lbn = src_op.input(\"segment_ids\", 0);\n      values_lbn = src_op.input(\"data\", 0);\n    } else if (src_op.op_type_name() == \"unsorted_segment_sum_like\"\n               && src_op.attr<int64_t>(\"axis\") == 0) {\n      indices_lbn = src_op.input(\"segment_ids\", 0);\n      values_lbn = src_op.input(\"data\", 0);\n    } else {\n      return;\n    }\n    std::vector<const OpNode*> op_nodes_to_remove;\n    std::vector<const OpNode*> op_nodes_apply_to_diff;\n    const OpNode* dst_node = src_node->SoleOutEdge()->dst_node();\n    do {\n      if (dst_node->op().output_bns().empty()) { break; }\n      const OperatorConf& dst_op_conf = dst_node->op().op_conf();\n      if (dst_op_conf.has_user_conf()\n          && dst_op_conf.user_conf().op_type_name() == \"hierarchical_parallel_cast\") {\n        if (dst_node->out_edges().size() != 1) { return; }\n        op_nodes_to_remove.emplace_back(dst_node);\n        dst_node = dst_node->SoleOutEdge()->dst_node();\n        continue;\n      } else if (dst_op_conf.has_user_conf()\n                 && dst_op_conf.user_conf().op_type_name() == \"scalar_mul\") {\n        if (dst_node->out_edges().size() != 1) { 
return; }\n        op_nodes_apply_to_diff.emplace_back(dst_node);\n        dst_node = dst_node->SoleOutEdge()->dst_node();\n        continue;\n      } else {\n        return;\n      }\n    } while (true);\n    if (!dst_node->op().op_conf().has_user_conf()) { return; }\n    const user_op::UserOpConfWrapper user_op_conf(dst_node->op().op_conf());\n    if (user_op_conf.op_type_name() != \"sgd_update\"\n        && user_op_conf.op_type_name() != \"momentum_update\"\n        && user_op_conf.op_type_name() != \"adam_update\") {\n      return;\n    }\n    if (user_op_conf.attr<double>(\"scale\") != 1.0 || user_op_conf.attr<float>(\"l1\") != 0.0f\n        || user_op_conf.attr<float>(\"l2\") != 0.0f || user_op_conf.has_input(\"scale_by_tensor\", 0)) {\n      return;\n    }\n    const LogicalBlobId& model_lbi = GenLogicalBlobId(user_op_conf.input(\"model\", 0));\n    if (dst_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(user_op_conf.input(\"model_diff\", 0)))\n            .data_type()\n        != dst_node->LogicalBlobDesc4Lbi(model_lbi).data_type()) {\n      return;\n    }\n    model_op_name = model_lbi.op_name();\n    user_op::UserOpConfWrapperBuilder indexed_slices_op_builder(\"System-Optimizer-IndexedSlices-\"\n                                                                + model_op_name);\n    indexed_slices_op_builder.OpTypeName(\"indexed_slices_\" + user_op_conf.op_type_name())\n        .Input(\"model\", user_op_conf.input(\"model\", 0))\n        .Input(\"learning_rate\", user_op_conf.input(\"learning_rate\", 0))\n        .Attr<float>(\"weight_decay\", user_op_conf.attr<float>(\"weight_decay\"))\n        .Attr<float>(\"learning_rate_scale\", user_op_conf.attr<float>(\"learning_rate_scale\"));\n\n    if (user_op_conf.op_type_name() == \"sgd_update\") {\n      // do nothing\n    } else if (user_op_conf.op_type_name() == \"momentum_update\") {\n      indexed_slices_op_builder.Input(\"momentum\", user_op_conf.input(\"momentum\", 0))\n          .Attr<float>(\"beta\", user_op_conf.attr<float>(\"beta\"))\n          .Attr<float>(\"dampening\", user_op_conf.attr<float>(\"dampening\"))\n          .Attr<bool>(\"nesterov\", user_op_conf.attr<bool>(\"nesterov\"))\n          .Attr<bool>(\"maximize\", user_op_conf.attr<bool>(\"maximize\"));\n    } else if (user_op_conf.op_type_name() == \"adam_update\") {\n      indexed_slices_op_builder.Input(\"m\", user_op_conf.input(\"m\", 0))\n          .Input(\"v\", user_op_conf.input(\"v\", 0))\n          .Attr<float>(\"beta1\", user_op_conf.attr<float>(\"beta1\"))\n          .Attr<float>(\"beta2\", user_op_conf.attr<float>(\"beta2\"))\n          .Attr<float>(\"epsilon\", user_op_conf.attr<float>(\"epsilon\"));\n      if (user_op_conf.has_input(\"max_v\", 0)) {\n        indexed_slices_op_builder.Input(\"max_v\", user_op_conf.input(\"max_v\", 0));\n      }\n    } else {\n      return;\n    }\n    CHECK(!model_op_name.empty());\n    CHECK(!indices_lbn.empty());\n    CHECK(!values_lbn.empty());\n    if (include_op_name_set.find(model_op_name) == include_op_name_set.end()) { return; }\n    for (const OpNode* node : op_nodes_to_remove) { job_builder->DelOps({node->op().op_conf()}); }\n    for (const OpNode* node : op_nodes_apply_to_diff) {\n      OperatorConf new_conf = node->op().op_conf();\n      if (new_conf.has_user_conf() && new_conf.user_conf().op_type_name() == \"scalar_mul\") {\n        const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&new_conf, \"in_0\", values_lbn);\n        CHECK_EQ(GenLogicalBlobName(node->op().BnInOp2Lbi(\"in_0\")), old_val);\n        
values_lbn = GenLogicalBlobName(new_conf.name(), \"out_0\");\n        job_builder->MutOpsOnlyOnce({new_conf});\n      } else {\n        UNIMPLEMENTED();\n      }\n    }\n    indexed_slices_op_builder.Input(\"model_diff_indices\", indices_lbn)\n        .Input(\"model_diff_values\", values_lbn)\n        .ScopeSymbolId(src_op_conf.scope_symbol_id());\n    job_builder->DelOps({src_op_conf, user_op_conf.op_conf()});\n    job_builder->AddOps(dst_node->parallel_desc().parallel_conf(),\n                        {indexed_slices_op_builder.Build().op_conf()});\n  });\n  return Maybe<void>::Ok();\n}\n\nREGISTER_JOB_PASS(\"IndexedSlicesOptimizerRewritePass\", IndexedSlicesOptimizerRewritePass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/input_autotick.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/autotick.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass MutInputOpConTickInputHelper final : public MutOpConTickInputHelper {\n public:\n  MutInputOpConTickInputHelper() : MutOpConTickInputHelper() {}\n\n  bool VirtualIsTickInputBound() const override { return op_conf().input_conf().has_tick(); }\n\n  OperatorConf NewTickInputBoundOpConf(const std::string& lbn) const override {\n    OperatorConf ret(op_conf());\n    ret.mutable_input_conf()->set_tick(lbn);\n    return ret;\n  }\n};\n\n}  // namespace\n\nREGISTER_AUTO_TICK(OperatorConf::kInputConf, MutInputOpConTickInputHelper);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/auto_parallel/auto_memory.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#if defined(WITH_CUDA) || defined(WITH_NPU) || defined(WITH_MLU)\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/job/eager_nccl_comm_manager.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job_rewriter/calculation_pass.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n\nnamespace oneflow {\n\nDEFINE_ENV_INTEGER(ONEFLOW_GRAPH_MAX_NCCL_COMPUTE_STREAM, 8);\n\nnamespace {\n\nclass InsertNcclLogicalOpPass final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(InsertNcclLogicalOpPass);\n  InsertNcclLogicalOpPass() = default;\n  ~InsertNcclLogicalOpPass() = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream();\n  }\n\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n};\n\nconst std::string kNcclLogicalOpNamePrefix = \"System-NCCL-Logical\";\n\nbool IsTickOpConf(const OperatorConf& op_conf) {\n  if (IsClassRegistered<int32_t, IsTickTockOpTypeCase>(op_conf.op_type_case())) { return true; }\n  if (op_conf.has_user_conf()) {\n    const std::string& user_type_name = op_conf.user_conf().op_type_name();\n    if (user_type_name == \"cast_to_tick\" || user_type_name == \"acc_ctrl_tick\") { return true; }\n  }\n  return false;\n}\n\nbool IsBreakpointOpNode(const OpNode* node) {\n  // NOTE(chengcheng): breakpoint op is special which CANNOT through subgraph such as:\n  //   variable, tick, repeat/acc/pack/unpack change timeshape\n  const Operator& op = node->op();\n  const OperatorConf& op_conf = op.op_conf();\n  // TODO(chengcheng): filter ops which has special type\n  // TODO(chengcheng): get stream by op type\n  if (op_conf.has_variable_conf()                                                   /* varialbe */\n      || IsTickOpConf(op_conf)                                                      /* tick */\n      || op_conf.has_input_conf() || op_conf.has_output_conf()                      /* io */\n      || op_conf.has_wait_and_send_ids_conf() || op_conf.has_callback_notify_conf() /* ctrl */\n      || 
op_conf.has_image_decoder_random_crop_resize_conf() /* gpu decode */) {\n    return true;\n  }\n\n  if (op_conf.has_user_conf()) {\n    const std::string& user_type_name = op_conf.user_conf().op_type_name();\n    if (user_type_name == \"repeat\" || user_type_name == \"pack\" || user_type_name == \"unpack\"\n        || user_type_name == \"identity_buffer\") {\n      return true;\n    }\n    if (!EnableLogicalChain()) {\n      // NOTE(chengcheng): in the old task graph chain version, consider acc as a breakpoint node\n      if (user_type_name == \"acc\") { return true; }\n    }\n  }\n  return false;\n}\n\nbool IsAccOpNode(const OpNode* node) {\n  return node->op().op_conf().has_user_conf()\n         && node->op().op_conf().user_conf().op_type_name() == \"acc\";\n}\n\nbool IsRepeatOpNode(const OpNode* node) {\n  return node->op().op_conf().has_user_conf()\n         && node->op().op_conf().user_conf().op_type_name() == \"repeat\";\n}\n\nstd::shared_ptr<const Shape> GetOpNodeTimeShape(const OpNode* op_node) {\n  return CHECK_JUST(op_node->op().GetOpTimeShape());\n}\n\nstd::shared_ptr<const Shape> GetOpNodeInputTimeShape(const OpNode* op_node) {\n  return CHECK_JUST(op_node->op().GetInputBlobFastestTimeShape());\n}\n\nstd::shared_ptr<const Shape> GetOpNodeFastestTimeShape(const OpNode* op_node) {\n  return CHECK_JUST(op_node->op().GetInputOutputFastestTimeShape());\n}\n\nbool SharedPtrShapeEqual(const std::shared_ptr<const Shape>& lhs,\n                         const std::shared_ptr<const Shape>& rhs) {\n  return (*lhs) == (*rhs);\n}\n\nvoid FindAllConnectedSubgraphForGpuExecOrder(std::vector<HashSet<const OpNode*>>* ret,\n                                             const OpGraph& op_graph,\n                                             const std::vector<const OpNode*>& order) {\n  // NOTE(chengcheng): the acc subgraph's time shape may be greater than the fw/bw subgraph's,\n
  // so we need to use the max time shape.\n  std::shared_ptr<const Shape> seed_time_shape = std::make_shared<const Shape>(Shape({1, 1}));\n  op_graph.ForEachNode([&](const OpNode* node) {\n    std::shared_ptr<const Shape> this_time_shape = GetOpNodeFastestTimeShape(node);\n    if (this_time_shape->elem_cnt() > seed_time_shape->elem_cnt()) {\n      seed_time_shape = this_time_shape;\n    }\n  });\n\n  VLOG(2) << \" seed time shape = \" << seed_time_shape->ToString();\n\n  HashSet<const OpNode*> visited;\n\n  for (const OpNode* seed_node : order) {\n    if (visited.find(seed_node) != visited.end()) { continue; }\n    CHECK(visited.insert(seed_node).second);\n    const ParallelDesc& seed_parallel_desc = seed_node->parallel_desc();\n    // NOTE(chengcheng): ONLY consider GPU ops with parallel num > 1.\n    if (seed_parallel_desc.device_type() == DeviceType::kCPU) { continue; }\n    if (seed_parallel_desc.parallel_num() <= 1) { continue; }\n    // NOTE(chengcheng): use the fastest time shape to merge acc into the bw subgraph.\n    if (!SharedPtrShapeEqual(GetOpNodeFastestTimeShape(seed_node), seed_time_shape)) { continue; }\n    if (IsBreakpointOpNode(seed_node)) { continue; }\n    // NOTE(chengcheng):\n    //   the stream name hint may be set by another job pass (e.g. replace embedding);\n    //   we cannot replace the stream name in the subgraph\n    if (seed_node->op().op_conf().has_stream_name_hint()) { continue; }\n\n    HashSet<const OpNode*> this_subgraph;\n    std::queue<const OpNode*> queued_nodes;\n\n    queued_nodes.push(seed_node);\n    while (!queued_nodes.empty()) {\n      const OpNode* cur_node = queued_nodes.front();\n      queued_nodes.pop();\n\n      CHECK(cur_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc));\n      CHECK(this_subgraph.insert(cur_node).second);\n\n      cur_node->ForEachNodeOnInOutEdge([&](const OpNode* next_node) {\n        if (visited.find(next_node) == visited.end() && (!IsBreakpointOpNode(next_node))\n            && next_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc)\n            && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(next_node), seed_time_shape)) {\n          CHECK(visited.insert(next_node).second);\n          queued_nodes.push(next_node);\n        }\n      });\n    }\n\n    if (this_subgraph.size() > 1) {\n      ret->emplace_back(HashSet<const OpNode*>());\n      ret->back().swap(this_subgraph);\n    }\n  }\n\n  std::sort(ret->begin(), ret->end(),\n            [](const HashSet<const OpNode*>& lhs, const HashSet<const OpNode*>& rhs) {\n              return lhs.size() > rhs.size();\n            });\n}\n\nbool TryBuildNcclBy1DHierarchy(OperatorConf* ret, const SbpParallel& src_sbp,\n                               const SbpParallel& dst_sbp, const std::string& lbn,\n                               const int64_t scope_symbol_id, const BlobDesc& logical_blob_desc,\n                               const int64_t parallel_num) {\n  auto CanSplitAtDim = [&](int64_t dim) -> bool {\n    if (logical_blob_desc.shape().NumAxes() <= dim) { return false; }\n    return logical_blob_desc.shape().At(dim) % parallel_num == 0;\n  };\n  if (src_sbp.has_partial_sum_parallel() && dst_sbp.has_broadcast_parallel()) {\n    // P->B : AllReduce\n    *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + \"-P2B-\" + NewUniqueId())\n               .Op(\"_nccl_logical_all_reduce\")\n               .Input(\"in\", lbn)\n               .Output(\"out\")\n               .Attr<std::vector<std::string>>(\"src_reduced_nd_sbp\", {SbpToString(src_sbp)})\n               
.Attr<std::vector<std::string>>(\"dst_reduced_nd_sbp\", {SbpToString(dst_sbp)})\n               .ScopeSymbolId(scope_symbol_id)\n               .Build()\n               .op_conf();\n    return true;\n  } else if (CanSplitAtDim(0)\n             && (src_sbp.has_partial_sum_parallel() && dst_sbp.has_split_parallel())\n             && (dst_sbp.split_parallel().axis() == 0)) {\n    // P->S(0) : ReduceScatter\n    *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + \"-P2S-\" + NewUniqueId())\n               .Op(\"_nccl_logical_reduce_scatter\")\n               .Input(\"in\", lbn)\n               .Output(\"out\")\n               .Attr<std::vector<std::string>>(\"src_reduced_nd_sbp\", {SbpToString(src_sbp)})\n               .Attr<std::vector<std::string>>(\"dst_reduced_nd_sbp\", {SbpToString(dst_sbp)})\n               .ScopeSymbolId(scope_symbol_id)\n               .Build()\n               .op_conf();\n    return true;\n  } else if (CanSplitAtDim(0) && (src_sbp.has_split_parallel() && dst_sbp.has_broadcast_parallel())\n             && (src_sbp.split_parallel().axis() == 0)) {\n    // S(0)->B : AllGather\n    *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + \"-S2B-\" + NewUniqueId())\n               .Op(\"_nccl_logical_all_gather\")\n               .Input(\"in\", lbn)\n               .Output(\"out\")\n               .Attr<std::vector<std::string>>(\"src_reduced_nd_sbp\", {SbpToString(src_sbp)})\n               .Attr<std::vector<std::string>>(\"dst_reduced_nd_sbp\", {SbpToString(dst_sbp)})\n               .ScopeSymbolId(scope_symbol_id)\n               .Build()\n               .op_conf();\n    return true;\n  } else if (src_sbp.has_split_parallel() && dst_sbp.has_broadcast_parallel()\n             && src_sbp.split_parallel().axis() > 0\n             && CanSplitAtDim(src_sbp.split_parallel().axis())) {\n    // S(1)->B : AllGather Noncontinuous\n    *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + \"-S2B-\" + NewUniqueId())\n               .Op(\"_nccl_logical_all_gather_noncontinuous\")\n               .Input(\"in\", lbn)\n               .Output(\"out\")\n               .Attr<std::vector<std::string>>(\"src_reduced_nd_sbp\", {SbpToString(src_sbp)})\n               .Attr<std::vector<std::string>>(\"dst_reduced_nd_sbp\", {SbpToString(dst_sbp)})\n               .ScopeSymbolId(scope_symbol_id)\n               .Build()\n               .op_conf();\n    return true;\n  } else if (src_sbp.has_split_parallel() && dst_sbp.has_split_parallel()\n             && src_sbp.split_parallel().axis() != dst_sbp.split_parallel().axis()\n             && CanSplitAtDim(src_sbp.split_parallel().axis())\n             && CanSplitAtDim(dst_sbp.split_parallel().axis())) {\n    // S(in)->S(out) : All2All\n    *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + \"-S2S-\" + NewUniqueId())\n               .Op(\"_nccl_logical_s2s\")\n               .Input(\"in\", lbn)\n               .Output(\"out\")\n               .Attr<std::vector<std::string>>(\"src_reduced_nd_sbp\", {SbpToString(src_sbp)})\n               .Attr<std::vector<std::string>>(\"dst_reduced_nd_sbp\", {SbpToString(dst_sbp)})\n               .ScopeSymbolId(scope_symbol_id)\n               .Build()\n               .op_conf();\n    return true;\n  } else if (CanSplitAtDim(dst_sbp.split_parallel().axis())\n             && (src_sbp.has_partial_sum_parallel() && dst_sbp.has_split_parallel())\n             && (dst_sbp.split_parallel().axis() > 0)) {\n    // P->S(1) : ReduceScatter Noncontinuous\n    *ret = 
user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + \"-P2S-\" + NewUniqueId())\n               .Op(\"_nccl_logical_reduce_scatter_noncontinuous\")\n               .Input(\"in\", lbn)\n               .Output(\"out\")\n               .Attr<std::vector<std::string>>(\"src_reduced_nd_sbp\", {SbpToString(src_sbp)})\n               .Attr<std::vector<std::string>>(\"dst_reduced_nd_sbp\", {SbpToString(dst_sbp)})\n               .ScopeSymbolId(scope_symbol_id)\n               .Build()\n               .op_conf();\n    return true;\n  } else if (!dst_sbp.has_partial_sum_parallel()) {\n    *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + \"-(Send)2(Recv)-\"\n                                             + NewUniqueId())\n               .Op(\"_nccl_logical_send_recv\")\n               .Input(\"in\", lbn)\n               .Output(\"out\")\n               .Attr<std::vector<std::string>>(\"src_reduced_nd_sbp\", {SbpToString(src_sbp)})\n               .Attr<std::vector<std::string>>(\"dst_reduced_nd_sbp\", {SbpToString(dst_sbp)})\n               .ScopeSymbolId(scope_symbol_id)\n               .Build()\n               .op_conf();\n    return true;\n  }\n  return false;\n}\n\nbool TryBuildNcclBy2DHierarchySameDim0(OperatorConf* ret, const NdSbp& src_nd_sbp,\n                                       const NdSbp& dst_nd_sbp,\n                                       const std::shared_ptr<Shape>& hierarchy,\n                                       const std::string& lbn, const int64_t scope_symbol_id,\n                                       const BlobDesc& logical_blob_desc) {\n  CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2);\n  CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 2);\n  CHECK(src_nd_sbp.sbp_parallel(0) == dst_nd_sbp.sbp_parallel(0));\n  const SbpParallel& src_dim1_sbp = src_nd_sbp.sbp_parallel(1);\n  const SbpParallel& dst_dim1_sbp = dst_nd_sbp.sbp_parallel(1);\n\n  // split when dim0 sbp is split parallel\n  DimVector dim_vec = logical_blob_desc.shape().dim_vec();\n  if (src_nd_sbp.sbp_parallel(0).has_split_parallel()) {\n    const int64_t axis = src_nd_sbp.sbp_parallel(0).split_parallel().axis();\n    dim_vec.at(axis) /= hierarchy->At(0);\n  }\n  const int64_t num_ranks = hierarchy->At(1);\n\n  if (src_dim1_sbp.has_partial_sum_parallel() && dst_dim1_sbp.has_broadcast_parallel()) {\n    // (*, P)->(*, B) : AllReduce\n    *ret =\n        user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + \"-(*P)2(*B)-\" + NewUniqueId())\n            .Op(\"_nccl_logical_2D_same_dim0_all_reduce\")\n            .Input(\"in\", lbn)\n            .Output(\"out\")\n            .Attr<std::vector<std::string>>(\"src_reduced_nd_sbp\", NdSbpToStringList(src_nd_sbp))\n            .Attr<std::vector<std::string>>(\"dst_reduced_nd_sbp\", NdSbpToStringList(dst_nd_sbp))\n            .ScopeSymbolId(scope_symbol_id)\n            .Build()\n            .op_conf();\n    return true;\n  } else if ((src_dim1_sbp.has_split_parallel() && dst_dim1_sbp.has_broadcast_parallel())\n             && (src_dim1_sbp.split_parallel().axis() == 0) && (dim_vec.at(0) % num_ranks == 0)) {\n    // (*, S(0)) -> (*, B) : AllGather\n    *ret =\n        user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + \"-(*S0)2(*B)-\" + NewUniqueId())\n            .Op(\"_nccl_logical_2D_same_dim0_all_gather\")\n            .Input(\"in\", lbn)\n            .Output(\"out\")\n            .Attr<std::vector<std::string>>(\"src_reduced_nd_sbp\", NdSbpToStringList(src_nd_sbp))\n            .Attr<std::vector<std::string>>(\"dst_reduced_nd_sbp\", 
NdSbpToStringList(dst_nd_sbp))\n            .ScopeSymbolId(scope_symbol_id)\n            .Build()\n            .op_conf();\n    return true;\n  } else if (src_dim1_sbp.has_split_parallel() && dst_dim1_sbp.has_broadcast_parallel()\n             && (src_dim1_sbp.split_parallel().axis() > 0)\n             && (dim_vec.at(src_dim1_sbp.split_parallel().axis()) % num_ranks == 0)) {\n    // (*, S(1)) -> (*, B) : AllGather Noncontinuous\n    *ret =\n        user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + \"-(*S1)2(*B)-\" + NewUniqueId())\n            .Op(\"_nccl_logical_2D_same_dim0_all_gather_noncontinuous\")\n            .Input(\"in\", lbn)\n            .Output(\"out\")\n            .Attr<std::vector<std::string>>(\"src_reduced_nd_sbp\", NdSbpToStringList(src_nd_sbp))\n            .Attr<std::vector<std::string>>(\"dst_reduced_nd_sbp\", NdSbpToStringList(dst_nd_sbp))\n            .ScopeSymbolId(scope_symbol_id)\n            .Build()\n            .op_conf();\n    return true;\n  } else if ((src_dim1_sbp.has_split_parallel() && dst_dim1_sbp.has_split_parallel())\n             && (src_dim1_sbp.split_parallel().axis() != dst_dim1_sbp.split_parallel().axis())\n             && (dim_vec.at(src_dim1_sbp.split_parallel().axis()) % num_ranks == 0)\n             && (dim_vec.at(dst_dim1_sbp.split_parallel().axis()) % num_ranks == 0)) {\n    // (*, S(src_split_axis)) -> (*, S(dst_split_axis)) : All2All\n    *ret =\n        user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + \"-(*S)2(*S)-\" + NewUniqueId())\n            .Op(\"_nccl_logical_2D_same_dim0_all2all\")\n            .Input(\"in\", lbn)\n            .Output(\"out\")\n            .Attr<std::vector<std::string>>(\"src_reduced_nd_sbp\", NdSbpToStringList(src_nd_sbp))\n            .Attr<std::vector<std::string>>(\"dst_reduced_nd_sbp\", NdSbpToStringList(dst_nd_sbp))\n            .ScopeSymbolId(scope_symbol_id)\n            .Build()\n            .op_conf();\n    return true;\n  }\n  return false;\n}\n\nbool TryBuildNcclBy2DHierarchySameDim1(OperatorConf* ret, const NdSbp& src_nd_sbp,\n                                       const NdSbp& dst_nd_sbp,\n                                       const std::shared_ptr<Shape>& hierarchy,\n                                       const std::string& lbn, const int64_t scope_symbol_id,\n                                       const BlobDesc& logical_blob_desc) {\n  CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2);\n  CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 2);\n  CHECK(src_nd_sbp.sbp_parallel(1) == dst_nd_sbp.sbp_parallel(1));\n  const SbpParallel& src_dim0_sbp = src_nd_sbp.sbp_parallel(0);\n  const SbpParallel& dst_dim0_sbp = dst_nd_sbp.sbp_parallel(0);\n  if (src_dim0_sbp.has_partial_sum_parallel() && dst_dim0_sbp.has_broadcast_parallel()) {\n    // (P, *) -> (B, *) : AllReduce\n    *ret =\n        user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + \"-(P*)2(B*)-\" + NewUniqueId())\n            .Op(\"_nccl_logical_2D_same_dim1_all_reduce\")\n            .Input(\"in\", lbn)\n            .Output(\"out\")\n            .Attr<std::vector<std::string>>(\"src_reduced_nd_sbp\", NdSbpToStringList(src_nd_sbp))\n            .Attr<std::vector<std::string>>(\"dst_reduced_nd_sbp\", NdSbpToStringList(dst_nd_sbp))\n            .ScopeSymbolId(scope_symbol_id)\n            .Build()\n            .op_conf();\n    return true;\n  }\n  return false;\n}\n\nbool TryBuildNcclBy2DHierarchyOthers(OperatorConf* ret, const NdSbp& src_nd_sbp,\n                                     const NdSbp& dst_nd_sbp,\n                    
                 const std::shared_ptr<Shape>& hierarchy,\n                                     const std::string& lbn, const int64_t scope_symbol_id,\n                                     const BlobDesc& logical_blob_desc) {\n  CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2);\n  CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 2);\n  // send recv is dealing with same 0-Dim\n  VLOG_IF(3, src_nd_sbp.sbp_parallel(0) == dst_nd_sbp.sbp_parallel(0))\n      << \"send recv is dealing with same 0-Dim, src sbp \" << NdSbpToString(src_nd_sbp)\n      << \", dst sbp \" << NdSbpToString(dst_nd_sbp);\n  // send recv is dealing with same 1-Dim, such as (B, S0) -> (S0, S0)\n  VLOG_IF(3, ((src_nd_sbp.sbp_parallel(1) == dst_nd_sbp.sbp_parallel(1))\n              && !(NdSbpAllSameSplitParallel(src_nd_sbp) || NdSbpAllSameSplitParallel(dst_nd_sbp))))\n      << \"send recv is dealing with same 1-Dim, src sbp \" << NdSbpToString(src_nd_sbp)\n      << \", dst sbp \" << NdSbpToString(dst_nd_sbp);\n  // send recv cannot deal with P in dst_nd_sbp\n  if (NdSbpHasPartialParallel(dst_nd_sbp)) { return false; }\n  *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + \"-(Send)2(Recv)-\"\n                                           + NewUniqueId())\n             .Op(\"_nccl_logical_send_recv\")\n             .Input(\"in\", lbn)\n             .Output(\"out\")\n             .Attr<std::vector<std::string>>(\"src_reduced_nd_sbp\", NdSbpToStringList(src_nd_sbp))\n             .Attr<std::vector<std::string>>(\"dst_reduced_nd_sbp\", NdSbpToStringList(dst_nd_sbp))\n             .ScopeSymbolId(scope_symbol_id)\n             .Build()\n             .op_conf();\n  return true;\n}\n\nMaybe<int64_t> BuildScopeWithReducedParallelDesc(int64_t old_scope_symbol_id,\n                                                 const ParallelDesc& parallel_desc) {\n  auto* scope_storage = Singleton<symbol::Storage<Scope>>::Get();\n  CHECK_OR_RETURN(scope_storage->Has(old_scope_symbol_id));\n  auto old_scope = scope_storage->GetPtr(old_scope_symbol_id);\n  std::shared_ptr<Scope> new_scope;\n  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {\n    new_scope =\n        JUST(builder->BuildScopeWithNewParallelConf(old_scope, parallel_desc.parallel_conf()));\n    return Maybe<void>::Ok();\n  }));\n  // NOTE(chengcheng): need to sync vm to get the scope right now\n  JUST(vm::CurrentRankSync());\n  CHECK_OR_RETURN(new_scope);\n  return JUST(new_scope->symbol_id());\n}\n\nbool TryBuildNcclLogicalOpConf(OperatorConf* ret, const OpNode* src_node, const OpNode* dst_node,\n                               const LogicalBlobId& lbi, ParallelDesc* src_reduced_parallel_desc,\n                               ParallelDesc* dst_reduced_parallel_desc, NdSbp* src_reduced_nd_sbp,\n                               NdSbp* dst_reduced_nd_sbp) {\n  if (!src_node->op().op_conf().has_scope_symbol_id()) { return false; /* device_tick */ }\n  const std::string lbn = GenLogicalBlobName(lbi);\n  const BlobDesc& logical_blob_desc = src_node->LogicalBlobDesc4Lbi(lbi);\n\n  // reduce hierarchy\n  InOutParallelDimReduce(src_node->parallel_desc(), dst_node->parallel_desc(),\n                         src_node->NdSbp4Lbi(lbi), dst_node->NdSbp4Lbi(lbi),\n                         src_reduced_parallel_desc, dst_reduced_parallel_desc, src_reduced_nd_sbp,\n                         dst_reduced_nd_sbp, logical_blob_desc.shape());\n\n  CHECK_EQ(src_reduced_parallel_desc->parallel_num(), dst_reduced_parallel_desc->parallel_num());\n  std::shared_ptr<Shape> src_reduced_hierarchy = 
src_reduced_parallel_desc->hierarchy();\n  std::shared_ptr<Shape> dst_reduced_hierarchy = dst_reduced_parallel_desc->hierarchy();\n\n  if ((*src_reduced_hierarchy) == (*dst_reduced_hierarchy)\n      && (*src_reduced_nd_sbp) == (*dst_reduced_nd_sbp)) {\n    // one to one\n    return false;\n  }\n\n  // NOTE(chengcheng): nccl does not support dynamic shapes.\n  if (logical_blob_desc.is_dynamic()) { return false; }\n  CHECK_GT(logical_blob_desc.shape().elem_cnt(), 0)\n      << dst_node->op().op_name() << \" consume \" << GenLogicalBlobName(lbi) << \", \"\n      << *CHECK_JUST(PlacementToString(*src_reduced_parallel_desc)) << \" \"\n      << NdSbpToString(*src_reduced_nd_sbp) << \" -> \"\n      << *CHECK_JUST(PlacementToString(*dst_reduced_parallel_desc)) << \" \"\n      << NdSbpToString(*dst_reduced_nd_sbp);\n\n  int64_t scope_symbol_id = CHECK_JUST(BuildScopeWithReducedParallelDesc(\n      src_node->op().op_conf().scope_symbol_id(), *src_reduced_parallel_desc));\n\n  if (src_reduced_hierarchy->NumAxes() == 1 && dst_reduced_hierarchy->NumAxes() == 1) {\n    return TryBuildNcclBy1DHierarchy(ret, src_reduced_nd_sbp->sbp_parallel(0),\n                                     dst_reduced_nd_sbp->sbp_parallel(0), lbn, scope_symbol_id,\n                                     logical_blob_desc, src_reduced_parallel_desc->parallel_num());\n  } else if (src_reduced_hierarchy->NumAxes() == 2\n             && (*src_reduced_hierarchy == *dst_reduced_hierarchy)) {\n    bool got_nccl = false;\n    if (src_reduced_nd_sbp->sbp_parallel(0) == dst_reduced_nd_sbp->sbp_parallel(0)) {\n      // TODO(): same dim 0 needs to deal with (*, P) -> (*, S)\n      got_nccl = TryBuildNcclBy2DHierarchySameDim0(ret, *src_reduced_nd_sbp, *dst_reduced_nd_sbp,\n                                                   src_reduced_hierarchy, lbn, scope_symbol_id,\n                                                   logical_blob_desc);\n    } else if (src_reduced_nd_sbp->sbp_parallel(1) == dst_reduced_nd_sbp->sbp_parallel(1)) {\n      if (!(NdSbpAllSameSplitParallel(*src_reduced_nd_sbp)\n            || NdSbpAllSameSplitParallel(*dst_reduced_nd_sbp))) {\n        got_nccl = TryBuildNcclBy2DHierarchySameDim1(ret, *src_reduced_nd_sbp, *dst_reduced_nd_sbp,\n                                                     src_reduced_hierarchy, lbn, scope_symbol_id,\n                                                     logical_blob_desc);\n      }\n    }\n    if (!got_nccl) {\n      got_nccl = TryBuildNcclBy2DHierarchyOthers(ret, *src_reduced_nd_sbp, *dst_reduced_nd_sbp,\n                                                 src_reduced_hierarchy, lbn, scope_symbol_id,\n                                                 logical_blob_desc);\n    }\n    VLOG_IF(3, !got_nccl) << \"Cannot get nccl logical op for 2D sbp, src nd sbp \"\n                          << NdSbpToString(*src_reduced_nd_sbp) << \", dst nd sbp \"\n                          << NdSbpToString(*dst_reduced_nd_sbp) << \".\";\n    return got_nccl;\n  }\n  return false;\n}\n\nvoid InsertNcclLogicalOpsAsCloseAsPossibleToDstNode(\n    HashMap<std::string, OperatorConf>* subgraph_op_name2conf, HashSet<std::string>* mut_op_names,\n    std::vector<OperatorConf>* nccl_op_confs, std::vector<ParallelConf>* nccl_op_parallel_confs,\n    const std::vector<const OpNode*>& subgraph_ordered_nodes,\n    const HashMap<const OpNode*, int64_t>& node2subgraph_order) {\n  for (const OpNode* dst_node : subgraph_ordered_nodes) {\n    const std::string& dst_op_name = dst_node->op().op_name();\n    for (const OpEdge* op_edge : 
dst_node->in_edges()) {\n      const OpNode* src_node = op_edge->src_node();\n      const std::string& src_op_name = src_node->op().op_name();\n      CHECK(src_node != dst_node);\n      if (src_node->parallel_desc().EqualsIgnoringHierarchy(dst_node->parallel_desc())) {\n        // NOTE(chengcheng): We don't care whether the src node is in this subgraph, whether it is\n        //  a repeat op, or whether it is a breaking op. We ONLY care that the src node has the\n        //  same placement as dst. So we can handle ZeRO (from variable), GradAcc (from repeat),\n        //  and Pipeline.\n        for (const LogicalBlobId& lbi : op_edge->lbis()) {\n          OperatorConf nccl_op;\n          ParallelDesc src_reduced_parallel_desc = op_edge->src_node()->parallel_desc();\n          ParallelDesc dst_reduced_parallel_desc = op_edge->dst_node()->parallel_desc();\n          NdSbp src_reduced_nd_sbp;\n          NdSbp dst_reduced_nd_sbp;\n          if (!TryBuildNcclLogicalOpConf(&nccl_op, src_node, dst_node, lbi,\n                                         &src_reduced_parallel_desc, &dst_reduced_parallel_desc,\n                                         &src_reduced_nd_sbp, &dst_reduced_nd_sbp)) {\n            continue;\n          }\n          mut_op_names->insert(dst_op_name);\n          // insert nccl op\n          user_op::UserOpConfWrapper nccl_op_wrapper(nccl_op);\n          for (const std::string& ibn : op_edge->lbi2ibns().at(lbi)) {\n            std::string old_lbn = ReplaceInputLbnInOpCustomizedConf(\n                &subgraph_op_name2conf->at(dst_op_name), ibn, nccl_op_wrapper.output(\"out\", 0));\n            CHECK(old_lbn == GenLogicalBlobName(lbi));\n          }\n\n          // NOTE(chengcheng): Do NOT add ctrl edge for nccl fusion.\n          nccl_op_confs->emplace_back(nccl_op);\n          // NOTE(chengcheng, guoran): setting the nccl op to dst_node's parallel_conf (hierarchy)\n          //   may fail the check in the compiler, so we need to use dst_node's\n          //   reduced_parallel_conf.\n          nccl_op_parallel_confs->emplace_back(dst_reduced_parallel_desc.parallel_conf());\n          VLOG(2) << \" insert nccl op: \" << nccl_op.name() << \" from [\" << src_op_name << \"] to [\"\n                  << dst_op_name << \"]\\n\";\n        }\n      }\n    }\n  }\n}\n\nvoid GenAfterAccSubgraph(std::vector<const OpNode*>* ordered_after_acc_subgraph,\n                         const HashMap<const OpNode*, int64_t>& op_node2global_order,\n                         const std::vector<const OpNode*>& ordered_acc_op_nodes) {\n  std::shared_ptr<const Shape> seed_time_shape = std::make_shared<const Shape>(Shape({1, 1}));\n  const ParallelDesc& seed_parallel_desc = ordered_acc_op_nodes.front()->parallel_desc();\n  HashSet<const OpNode*> visited;\n  std::queue<const OpNode*> queued_nodes;\n  auto SearchToNextNode = [&](const OpNode* cur_node, const OpNode* next_node, const OpEdge* edge) {\n    if (visited.find(next_node) == visited.end() && (!IsBreakpointOpNode(next_node))\n        && next_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc)\n        && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(next_node), seed_time_shape)) {\n      CHECK(visited.insert(next_node).second);\n      queued_nodes.push(next_node);\n    }\n  };\n\n  auto CmpOpNodeOrder = [&](const OpNode* lhs, const OpNode* rhs) {\n    return op_node2global_order.at(lhs) < op_node2global_order.at(rhs);\n  };\n\n  for (const OpNode* acc_node : ordered_acc_op_nodes) {\n    for (const OpEdge* out_edge : acc_node->out_edges()) {\n      const OpNode* seed_node = out_edge->dst_node();\n      
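// Try to take this direct consumer of the acc op as a BFS seed (added comment).\n      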
SearchToNextNode(acc_node, seed_node, out_edge);\n    }\n  }\n\n  while (!queued_nodes.empty()) {\n    const OpNode* cur_node = queued_nodes.front();\n    queued_nodes.pop();\n\n    ordered_after_acc_subgraph->push_back(cur_node);\n\n    for (const OpEdge* in_edge : cur_node->in_edges()) {\n      SearchToNextNode(cur_node, in_edge->src_node(), in_edge);\n    }\n    for (const OpEdge* out_edge : cur_node->out_edges()) {\n      SearchToNextNode(cur_node, out_edge->dst_node(), out_edge);\n    }\n  }\n\n  std::sort(ordered_after_acc_subgraph->begin(), ordered_after_acc_subgraph->end(), CmpOpNodeOrder);\n}\n\nstruct InsertNcclSubGraph {\n  std::vector<const OpNode*> ordered_op_nodes;\n  int64_t begin_op_global_order;\n  int64_t end_op_global_order;\n  const OpNode* begin_op;\n  const OpNode* end_op;\n};\n\nstruct PlacementNcclSubGraphsInfo {\n  std::vector<std::shared_ptr<InsertNcclSubGraph>> ordered_subgraph;\n  std::vector<const OpNode*> ordered_acc_op_nodes;\n  const ParallelDesc* seed_parallel_desc;\n};\n\nvoid InitInsertNcclSubGraphInfoFromSet(\n    std::shared_ptr<InsertNcclSubGraph> nccl_subgraph_info, const HashSet<const OpNode*>& subgraph,\n    const HashMap<const OpNode*, int64_t>& op_node2global_order,\n    const std::function<bool(const OpNode*, const OpNode*)>& CmpOpNodeOrder) {\n  auto* subgraph_ordered_nodes = &nccl_subgraph_info->ordered_op_nodes;\n  subgraph_ordered_nodes->assign(subgraph.begin(), subgraph.end());\n  std::sort(subgraph_ordered_nodes->begin(), subgraph_ordered_nodes->end(), CmpOpNodeOrder);\n  nccl_subgraph_info->begin_op = subgraph_ordered_nodes->front();\n  nccl_subgraph_info->end_op = subgraph_ordered_nodes->back();\n  nccl_subgraph_info->begin_op_global_order = op_node2global_order.at(nccl_subgraph_info->begin_op);\n  nccl_subgraph_info->end_op_global_order = op_node2global_order.at(nccl_subgraph_info->end_op);\n  CHECK(nccl_subgraph_info->begin_op != nccl_subgraph_info->end_op);\n  CHECK_LT(nccl_subgraph_info->begin_op_global_order, nccl_subgraph_info->end_op_global_order);\n}\n\nstd::string GetStreamIndexName(uint32_t id) { return \"NCCL_COMPUTE_\" + std::to_string(id); }\n\nint64_t InsertNcclLogicalOpsInSubGraph(const OpGraph& op_graph, JobBuilder* job_builder,\n                                       const std::vector<const OpNode*>& subgraph_ordered_nodes,\n                                       int64_t* nccl_compute_stream_id,\n                                       const int64_t logical_chain_id) {\n  HashMap<const OpNode*, int64_t> node2subgraph_order;\n  node2subgraph_order.reserve(subgraph_ordered_nodes.size());\n  for (int64_t i = 0; i < subgraph_ordered_nodes.size(); ++i) {\n    CHECK(node2subgraph_order.emplace(subgraph_ordered_nodes.at(i), i).second);\n  }\n\n  VLOG(3) << \" ======================================================================== \\n\"\n          << \" Try insert nccl logical ops into Graph: \" << job_builder->job().job_conf().job_name()\n          << \" , logical_chain: \" << logical_chain_id << \". 
Begin...\\n\";\n\n  HashSet<std::string> mut_op_names;\n  HashMap<std::string, OperatorConf> subgraph_op_name2conf;\n  for (const OpNode* this_node : subgraph_ordered_nodes) {\n    VLOG(3) << \"logical_chain: \" << logical_chain_id << \" , op: \" << this_node->op().op_name();\n    CHECK(\n        subgraph_op_name2conf.emplace(this_node->op().op_name(), this_node->op().op_conf()).second);\n  }\n\n  std::vector<OperatorConf> nccl_op_confs;\n  std::vector<ParallelConf> nccl_op_parallel_confs;\n  // NOTE(chengcheng): ONLY support inserting nccl ops close to dst, for memory.\n  InsertNcclLogicalOpsAsCloseAsPossibleToDstNode(&subgraph_op_name2conf, &mut_op_names,\n                                                 &nccl_op_confs, &nccl_op_parallel_confs,\n                                                 subgraph_ordered_nodes, node2subgraph_order);\n\n  VLOG(3) << \" ======================================================================== \\n\"\n          << \" Try insert nccl logical ops into Graph: \" << job_builder->job().job_conf().job_name()\n          << \" , logical_chain: \" << logical_chain_id << \". End.\\n\";\n\n  // NOTE(chengcheng): For correct NCCL logical exec order in pipeline multi-subgraph.\n  if (nccl_op_confs.empty()) { return 0; }\n  const int64_t max_nccl_stream_count = EnvInteger<ONEFLOW_GRAPH_MAX_NCCL_COMPUTE_STREAM>();\n  if ((*nccl_compute_stream_id) >= max_nccl_stream_count) {\n    return 0;  // NOTE(chengcheng): ONLY support at most max_nccl_stream_count nccl subgraphs.\n  }\n  std::string stream_index_name = GetStreamIndexName(*nccl_compute_stream_id);\n  // NOTE(chengcheng): ONLY a valid subgraph will increase the nccl stream id.\n  (*nccl_compute_stream_id)++;\n\n  // NOTE(chengcheng): set stream index and logical chain id for ALL subgraph ops and nccl ops.\n  for (auto& pair : subgraph_op_name2conf) {\n    mut_op_names.insert(pair.first);\n    pair.second.set_stream_name_hint(stream_index_name);\n    pair.second.set_logical_chain_id(logical_chain_id);\n  }\n  for (auto& nccl_op : nccl_op_confs) {\n    nccl_op.set_stream_name_hint(stream_index_name);\n    nccl_op.set_logical_chain_id(logical_chain_id);\n  }\n\n  std::vector<OperatorConf> mut_op_confs;\n  mut_op_confs.reserve(mut_op_names.size());\n  for (const std::string& mut_op_name : mut_op_names) {\n    mut_op_confs.emplace_back(subgraph_op_name2conf.at(mut_op_name));\n  }\n  job_builder->MutOpsOnlyOnce(mut_op_confs);\n\n  CHECK_EQ(nccl_op_confs.size(), nccl_op_parallel_confs.size());\n  for (int64_t i = 0; i < nccl_op_confs.size(); ++i) {\n    CHECK_JUST(job_builder->AddOp(nccl_op_parallel_confs.at(i), nccl_op_confs.at(i)));\n  }\n  VLOG(3) << \" In logical chain id: \" << logical_chain_id\n          << \" insert nccl op num = \" << nccl_op_confs.size()\n          << \" and origin chain op num = \" << subgraph_ordered_nodes.size();\n  return nccl_op_confs.size() + subgraph_ordered_nodes.size();\n}\n\nvoid InsertNcclLogicalOpsAfterAcc(const OpGraph& op_graph, JobBuilder* job_builder,\n                                  const std::vector<const OpNode*>& ordered_acc_op_nodes,\n                                  const HashMap<const OpNode*, int64_t>& op_node2global_order,\n                                  const int64_t nccl_compute_stream_id,\n                                  const int64_t logical_chain_id) {\n  // insert nccl ops after acc\n  std::vector<const OpNode*> ordered_after_acc_subgraph;\n  GenAfterAccSubgraph(&ordered_after_acc_subgraph, op_node2global_order, ordered_acc_op_nodes);\n  if (ordered_after_acc_subgraph.size() 
<= 1) { return; }\n\n  HashMap<const OpNode*, int64_t> node2subgraph_order;\n  node2subgraph_order.reserve(ordered_after_acc_subgraph.size());\n  for (int64_t i = 0; i < ordered_after_acc_subgraph.size(); ++i) {\n    CHECK(node2subgraph_order.emplace(ordered_after_acc_subgraph.at(i), i).second);\n  }\n\n  std::vector<OperatorConf> after_acc_nccl_op_confs;\n  std::vector<ParallelConf> after_acc_nccl_parallel_confs;\n  HashSet<std::string> mut_op_names;\n  HashMap<std::string, OperatorConf> acc_subgraph_op_name2conf;\n  for (const OpNode* this_node : ordered_after_acc_subgraph) {\n    CHECK(acc_subgraph_op_name2conf.emplace(this_node->op().op_name(), this_node->op().op_conf())\n              .second);\n    VLOG(3) << \"After Acc logical_chain: \" << logical_chain_id\n            << \" , op: \" << this_node->op().op_name();\n  }\n\n  InsertNcclLogicalOpsAsCloseAsPossibleToDstNode(\n      &acc_subgraph_op_name2conf, &mut_op_names, &after_acc_nccl_op_confs,\n      &after_acc_nccl_parallel_confs, ordered_after_acc_subgraph, node2subgraph_order);\n\n  if (after_acc_nccl_op_confs.empty()) {\n    CHECK(after_acc_nccl_parallel_confs.empty());\n    CHECK(mut_op_names.empty());\n  } else {\n    std::string stream_index_name = GetStreamIndexName(nccl_compute_stream_id);\n\n    // set logical chain id and stream name for ops after acc\n    for (auto& pair : acc_subgraph_op_name2conf) {\n      mut_op_names.insert(pair.first);\n      pair.second.set_stream_name_hint(stream_index_name);\n      pair.second.set_logical_chain_id(logical_chain_id);\n    }\n    for (auto& nccl_op : after_acc_nccl_op_confs) {\n      nccl_op.set_stream_name_hint(stream_index_name);\n      nccl_op.set_logical_chain_id(logical_chain_id);\n    }\n\n    // insert nccl ops after acc\n    std::vector<OperatorConf> mut_op_confs;\n    mut_op_confs.reserve(mut_op_names.size());\n    for (const std::string& mut_op_name : mut_op_names) {\n      mut_op_confs.emplace_back(acc_subgraph_op_name2conf.at(mut_op_name));\n    }\n    job_builder->MutOpsOnlyOnce(mut_op_confs);\n\n    CHECK_EQ(after_acc_nccl_op_confs.size(), after_acc_nccl_parallel_confs.size());\n    for (int64_t i = 0; i < after_acc_nccl_op_confs.size(); ++i) {\n      CHECK_JUST(\n          job_builder->AddOp(after_acc_nccl_parallel_confs.at(i), after_acc_nccl_op_confs.at(i)));\n    }\n  }\n}\n\nMaybe<void> InsertNcclLogicalOpPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {\n  std::vector<const OpNode*> ordered_op_nodes;\n  if (ParseBooleanFromEnv(\"DISABLE_LOGICAL_STRAIGHTEN\", false)) {\n    op_graph.TopoForEachNodeWithCtrlEdge(\n        [&](const OpNode* node) { ordered_op_nodes.emplace_back(node); });\n  } else {\n    auto_parallel::StraightenOpGraph(op_graph, &ordered_op_nodes);\n  }\n\n  HashMap<const OpNode*, int64_t> op_node2global_order;\n  for (int32_t global_order = 0; global_order < ordered_op_nodes.size(); global_order++) {\n    op_node2global_order.emplace(ordered_op_nodes[global_order], global_order);\n  }\n\n  std::vector<HashSet<const OpNode*>> subgraph_list;\n  FindAllConnectedSubgraphForGpuExecOrder(&subgraph_list, op_graph, ordered_op_nodes);\n  if (subgraph_list.empty()) { return Maybe<void>::Ok(); }\n\n  // assign subgraph ops' logical chain id for merging.\n  int64_t global_logical_chain_id = 0;\n\n  auto CmpOpNodeOrder = [&](const OpNode* lhs, const OpNode* rhs) {\n    return op_node2global_order.at(lhs) < op_node2global_order.at(rhs);\n  };\n  auto CmpSubGraphOrder = [&](const std::shared_ptr<InsertNcclSubGraph>& lhs,\n                    
          const std::shared_ptr<InsertNcclSubGraph>& rhs) {\n    int64_t lhs_begin_op_global_order = op_node2global_order.at(lhs->ordered_op_nodes.front());\n    int64_t rhs_begin_op_global_order = op_node2global_order.at(rhs->ordered_op_nodes.front());\n    return lhs_begin_op_global_order < rhs_begin_op_global_order;\n  };\n\n  HashMap<std::string, PlacementNcclSubGraphsInfo> placement2subgraphs;\n  for (const auto& subgraph : subgraph_list) {\n    const OpNode* rand_node = *subgraph.begin();\n    const ParallelDesc& this_parallel_desc = rand_node->parallel_desc();\n    std::string key = GenParallelConfKey(this_parallel_desc.parallel_conf());\n    auto it = placement2subgraphs.find(key);\n    if (it == placement2subgraphs.end()) {\n      it = placement2subgraphs.emplace(key, PlacementNcclSubGraphsInfo()).first;\n      it->second.seed_parallel_desc = &this_parallel_desc;\n    } else {\n      CHECK(this_parallel_desc.EqualsIgnoringHierarchy(*it->second.seed_parallel_desc));\n    }\n    auto& info = it->second;\n    info.ordered_subgraph.emplace_back(std::make_shared<InsertNcclSubGraph>());\n    InitInsertNcclSubGraphInfoFromSet(info.ordered_subgraph.back(), subgraph, op_node2global_order,\n                                      CmpOpNodeOrder);\n  }\n  for (auto& pair : placement2subgraphs) {\n    std::sort(pair.second.ordered_subgraph.begin(), pair.second.ordered_subgraph.end(),\n              CmpSubGraphOrder);\n  }\n\n  for (const OpNode* this_node : ordered_op_nodes) {\n    if (IsAccOpNode(this_node)) {\n      const ParallelDesc& this_parallel_desc = this_node->parallel_desc();\n      std::string key = GenParallelConfKey(this_parallel_desc.parallel_conf());\n      auto it = placement2subgraphs.find(key);\n      if (it != placement2subgraphs.end()) {\n        it->second.ordered_acc_op_nodes.emplace_back(this_node);\n      }\n    }\n  }\n\n  for (auto& pair : placement2subgraphs) {\n    PlacementNcclSubGraphsInfo& info = pair.second;\n\n    // NOTE(chengcheng): insert nccl ops for each subgraph\n    int64_t stream_offset = 0;\n    int64_t total_op_num = 0;\n    for (int i = 0; i < info.ordered_subgraph.size(); i++) {\n      auto& ordered_op_nodes = info.ordered_subgraph.at(i)->ordered_op_nodes;\n      int64_t this_op_num = InsertNcclLogicalOpsInSubGraph(\n          op_graph, job_builder, ordered_op_nodes, &stream_offset, global_logical_chain_id++);\n      total_op_num += this_op_num;\n    }\n    if (stream_offset >= 2 && total_op_num >= 1000) {\n      LOG(WARNING) << \" In Graph: \" << job_builder->job().job_conf().job_name()\n                   << \" Placement: \" << pair.first << \" the total_op_num = \" << total_op_num\n                   << \" and has \" << stream_offset\n                   << \" different nccl streams, which may trigger the cuda stream kernel \"\n                      \"launch upper limit.\"\n                   << \" So the nccl logical kernels will switch from async to sync exec, which \"\n                      \"may affect performance.\";\n      EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n      comm_mgr->SetAsyncLaunchCclLogicalKernel(false);\n    }\n\n    // NOTE(chengcheng): insert nccl ops after acc for all subgraphs in the same placement group\n    if (!info.ordered_acc_op_nodes.empty()) {\n      InsertNcclLogicalOpsAfterAcc(op_graph, job_builder, info.ordered_acc_op_nodes,\n                                   op_node2global_order, stream_offset++,\n                                   global_logical_chain_id++);\n    }\n  }\n\n
  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"InsertNcclLogicalOpPass\", InsertNcclLogicalOpPass);\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA || WITH_NPU || WITH_MLU\n"
  },
  {
    "path": "oneflow/core/job_rewriter/insert_pinned_identity_op_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass InsertPinnedIdentityOpPass final : public JobPass {\n public:\n  InsertPinnedIdentityOpPass() = default;\n  ~InsertPinnedIdentityOpPass() override = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;\n};\n\nMaybe<std::string> InsertPinnedIdentityOp(JobBuilder* job_builder, const OpGraph& op_graph,\n                                          const std::string& lbn) {\n  auto lbi = GenLogicalBlobId(lbn);\n  const OpNode* node = op_graph.OpNode4OpName(lbi.op_name());\n  auto pinned_identity_op =\n      user_op::UserOpConfWrapperBuilder(lbi.op_name() + \"_\" + lbi.blob_name() + \"_pinned_identity\")\n          .Op(\"pinned_identity\")\n          .Input(\"in\", lbn)\n          .Output(\"out\")\n          .ScopeSymbolId(node->op().op_conf().scope_symbol_id())\n          .Build();\n  const auto& parallel_conf = node->parallel_desc().parallel_conf();\n  job_builder->AddOps(parallel_conf, {pinned_identity_op.op_conf()});\n\n  node->ForEachNodeOnOutEdge([&](const OpNode* out_node) {\n    for (const std::string& ibn : out_node->op().input_bns()) {\n      if (out_node->op().BnInOp2Lbi(ibn) == lbi) {\n        if (!CHECK_JUST(job_builder->IsInMutOpTransaction(out_node->op().op_name()))) {\n          CHECK_JUST(job_builder->MutOpTransactionMut(out_node->op().op_conf()));\n        }\n        OperatorConf& mut_consumer_op =\n            CHECK_JUST(job_builder->MutOpTransactionGet(out_node->op().op_name()));\n        const auto& old_lbn = ReplaceInputLbnInOpCustomizedConf(\n            &mut_consumer_op, ibn, pinned_identity_op.output(\"out\", 0));\n        CHECK_EQ(old_lbn, GenLogicalBlobName(lbi));\n      }\n    }\n  });\n  return pinned_identity_op.output(\"out\", 0);\n}\n\nMaybe<void> InsertPinnedIdentityOpPass::Apply(Job* job, JobPassCtx* ctx) const {\n  if (!ctx->job_desc().IsTrain()) { return Maybe<void>::Ok(); }\n  const OpGraph op_graph(*job);\n  JobBuilder job_builder(job);\n\n  HashMap<std::string, std::string> pinned_lbns;\n  TrainConf* train_conf = job->mutable_job_conf()->mutable_train_conf();\n  // insert after loss\n  for (int i = 0; i < train_conf->loss_lbn_size(); ++i) {\n    const auto& loss_lbn = train_conf->loss_lbn(i);\n    auto it = pinned_lbns.find(loss_lbn);\n    if (it == pinned_lbns.end()) {\n      const auto& pinned_loss_lbn = JUST(InsertPinnedIdentityOp(&job_builder, op_graph, loss_lbn));\n      it = pinned_lbns.emplace(loss_lbn, *pinned_loss_lbn).first;\n    }\n    train_conf->set_loss_lbn(i, it->second);\n  }\n  // insert after loss initial gradient\n  for (int i = 0; i < train_conf->loss_grad_lbn_size(); ++i) {\n    const auto& loss_grad_lbn = train_conf->loss_grad_lbn(i);\n    auto it = pinned_lbns.find(loss_grad_lbn);\n    if (it == pinned_lbns.end()) {\n      const auto& pinned_loss_grad_lbn =\n   
       JUST(InsertPinnedIdentityOp(&job_builder, op_graph, loss_grad_lbn));\n      it = pinned_lbns.emplace(loss_grad_lbn, *pinned_loss_grad_lbn).first;\n    }\n    train_conf->set_loss_grad_lbn(i, it->second);\n  }\n  // insert after variable gradient\n  for (int i = 0; i < train_conf->optimizer_conf_size(); ++i) {\n    auto* optimizer_conf = train_conf->mutable_optimizer_conf(i);\n    for (int j = 0; j < optimizer_conf->variable_grad_lbns_size(); ++j) {\n      const auto& variable_grad_lbn = optimizer_conf->variable_grad_lbns(j);\n      if (variable_grad_lbn.empty()) { continue; }\n      auto it = pinned_lbns.find(variable_grad_lbn);\n      if (it == pinned_lbns.end()) {\n        const auto& pinned_variable_grad_lbn =\n            JUST(InsertPinnedIdentityOp(&job_builder, op_graph, variable_grad_lbn));\n        it = pinned_lbns.emplace(variable_grad_lbn, *pinned_variable_grad_lbn).first;\n      }\n      optimizer_conf->set_variable_grad_lbns(j, it->second);\n    }\n  }\n  JUST(job_builder.MutOpTransactionCommit());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"InsertPinnedIdentityOpPass\", InsertPinnedIdentityOpPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/job_completer.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_completer.h\"\n#include \"oneflow/core/framework/placed_nd_sbp.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job_rewriter/autograd.h\"\n#include \"oneflow/core/job_rewriter/autotick.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job_rewriter/group_boxing_by_dst_parallel.h\"\n#include \"oneflow/core/framework/config_def.h\"\n#include \"oneflow/core/job_rewriter/boxing_with_middle_nodes.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/common/cost_util.h\"\n#include \"oneflow/core/common/buffer_manager.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> CheckOpGraph(const OpGraph& op_graph) {\n  JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {\n    size_t in_cnt = 0;\n    op_graph.ForEachDataAndCtrlInNode(op_node, [&](OpNode*) { ++in_cnt; });\n    if (in_cnt == 0) { CHECK_OR_RETURN(op_node->op().op_conf().has_wait_and_send_ids_conf()); }\n\n    size_t out_cnt = 0;\n    op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; });\n\n    if (out_cnt == 0) { CHECK_OR_RETURN(op_node->op().op_conf().has_callback_notify_conf()); }\n    return Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckAndLogOpGraph(const Job& job) {\n  auto op_graph = std::make_unique<OpGraph>(job);\n  // Check op graph.\n  JUST(CheckOpGraph(*op_graph));\n  // Log op graph.\n  if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {\n    const JobDesc& job_desc = GlobalJobDesc();\n    TeePersistentLogStream::Create(StrCat(\"optimized_job\", job_desc.job_id()))->Write(job);\n    op_graph->ToDotWithFilePath(\"optimized_dlnet_\" + std::to_string(job_desc.job_id())\n                                + \"_op_graph.dot\");\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> WithOpGraphAndMutJob(Job* job,\n                                 const std::function<Maybe<void>(const OpGraph&, Job*)>& Handler) {\n  OpGraph op_graph(*job);\n  JUST(Handler(op_graph, job));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> WithOpGraphAndMutJobBuilder(\n    Job* job, const std::function<Maybe<void>(const OpGraph&, JobBuilder*)>& Handler) {\n  OpGraph op_graph(*job);\n  JobBuilder job_builder(job);\n  JUST(Handler(op_graph, &job_builder));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder) {\n  auto IsMutableConsumedLbi = [](const Operator& op, const LogicalBlobId& lbi) -> bool {\n    for (const std::string& bn : op.input_bns()) {\n      if (op.BnInOp2Lbi(bn) == lbi && op.InputBlobModifier4Ibn(bn).is_mutable()) { return true; }\n    }\n    return false;\n  };\n  auto IsReachable = 
op_graph.MakePredicatorIsOpNameDataOrCtrlReachable();\n  HashMap<const OperatorConf*, HashSet<std::string>> op_conf2ctrl_in_op_names;\n  JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {\n    if (!op_node->op().op_conf().has_variable_conf()) { return Maybe<void>::Ok(); }\n    if (op_node->out_edges().size() <= 1) { return Maybe<void>::Ok(); }\n    const Operator& variable_op = op_node->op();\n    const LogicalBlobId& variable_lbi = variable_op.BnInOp2Lbi(variable_op.SoleObn());\n    const OperatorConf* mutable_consumer = nullptr;\n    std::vector<const OperatorConf*> naive_consumers;\n    naive_consumers.reserve(op_node->out_edges().size());\n    for (OpEdge* edge : op_node->out_edges()) {\n      const auto& op_conf = edge->dst_node()->op().op_conf();\n      if (IsMutableConsumedLbi(edge->dst_node()->op(), variable_lbi)) {\n        CHECK_OR_RETURN(mutable_consumer == nullptr);\n        mutable_consumer = &op_conf;\n      } else {\n        naive_consumers.emplace_back(&op_conf);\n      }\n    }\n    if (mutable_consumer == nullptr) { return Maybe<void>::Ok(); }\n    for (const auto* fw_bw_op : naive_consumers) {\n      op_conf2ctrl_in_op_names[mutable_consumer].insert(fw_bw_op->name());\n    }\n    return Maybe<void>::Ok();\n  }));\n  for (const auto& pair : op_conf2ctrl_in_op_names) {\n    OperatorConf mut_mutable_consumer_op_conf(*pair.first);\n    for (const auto& fw_bw_op_name : pair.second) {\n      if (!IsReachable(fw_bw_op_name, mut_mutable_consumer_op_conf.name())) {\n        mut_mutable_consumer_op_conf.add_ctrl_in_op_name(fw_bw_op_name);\n      }\n    }\n    JUST(job_builder->MutOpOnlyOnce(mut_mutable_consumer_op_conf));\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> JobCompleter::Complete(Job* job) {\n  const auto& job_name = job->job_conf().job_name();\n  JobPassCtx job_pass_ctx(GlobalJobDesc());\n  // NOTE(chengcheng): this pass can be disabled to reduce the boxing memory life cycle and\n  //   memory cost.\n  auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true);\n  if (!Singleton<ResourceDesc, ForSession>::Get()\n           ->resource()\n           .disable_group_boxing_by_dst_parallel()) {\n    JUST(WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel));\n  }\n  compile_tc->Count(\"[GraphCompile]\" + job_name + \" GroupBoxingByDstParallel\", 1, true);\n  if (GlobalProcessCtx::WorldSize() > 1) {\n    JUST(WithOpGraphAndMutJobBuilder(job, &BoxingWithMiddleNodes));\n  }\n  compile_tc->Count(\"[GraphCompile]\" + job_name + \" BoxingWithMiddleNodes\", 1, true);\n  JUST(WithOpGraphAndMutJobBuilder(job, &SetCtrlInOpName4VariableOp));\n  compile_tc->Count(\"[GraphCompile]\" + job_name + \" SetCtrl\", 1, true);\n  // complete tick ops\n  JUST(WithOpGraphAndMutJobBuilder(job, &AutoPrependTick));\n  compile_tc->Count(\"[GraphCompile]\" + job_name + \" AutoPrependTick\", 1, true);\n  JUST(WithOpGraphAndMutJobBuilder(job, &AddTickForTimeShape));\n  compile_tc->Count(\"[GraphCompile]\" + job_name + \" AddTickForTimeShape\", 1, true);\n  JUST(WithOpGraphAndMutJob(job, &MultiClientAutoSourceAndSinkTick));\n  compile_tc->Count(\"[GraphCompile]\" + job_name + \" AutoSourceAndSinkTick\", 1, true);\n  JUST(WithOpGraphAndMutJob(job, &MultiClientAutoInterfaceCriticalSectionTick));\n  compile_tc->Count(\"[GraphCompile]\" + job_name + \" CriticalSectionTick\", 1, true);\n  JUST(JobPass4Name(\"SystemOpFillJobNamePass\")(job, &job_pass_ctx));\n  compile_tc->Count(\"[GraphCompile]\" + job_name + \" SystemOpFillJobNamePass\", 1, true);\n  
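// Dump each blob's parallel conf into the job; re-run below after nccl logical ops are inserted\n  // (added comment).\n  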
JUST(JobPass4Name(\"DumpBlobParallelConfPass\")(job, &job_pass_ctx));\n  compile_tc->Count(\"[GraphCompile]\" + job_name + \" DumpBlobParallelConfPass\", 1, true);\n#if defined(WITH_CUDA) || defined(WITH_NPU) || defined(WITH_MLU)\n  if (Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()) {\n    // NOTE(chengcheng): this pass need as last pass for insert correct op with nccl boxing.\n    JUST(JobPass4Name(\"InsertNcclLogicalOpPass\")(job, &job_pass_ctx));\n    compile_tc->Count(\"[GraphCompile]\" + job_name + \" InsertNcclLogicalOpPass\", 1, true);\n    // NOTE(chengcheng): must do this pass after InsertNcclLogicalOpPass for nccl op fusion and\n    //    add ctrl stirct order.\n    JUST(JobPass4Name(\"NcclLogicalOpFusionPass\")(job, &job_pass_ctx));\n    compile_tc->Count(\"[GraphCompile]\" + job_name + \" NcclLogicalOpFusionPass\", 1, true);\n    JUST(JobPass4Name(\"NcclLogicalChainStrictOrderPass\")(job, &job_pass_ctx));\n    compile_tc->Count(\"[GraphCompile]\" + job_name + \" NcclLogicalChainStrictOrderPass\", 1, true);\n\n    // NOTE(chengcheng): Because insert new logical nccl op, MUST dump time shape, sbp again.\n    JUST(JobPass4Name(\"DumpBlobParallelConfPass\")(job, &job_pass_ctx));\n    compile_tc->Count(\"[GraphCompile]\" + job_name + \" DumpBlobParallelConfPass\", 1, true);\n  }\n#endif  // WITH_CUDA || WITH_NPU || WITH_MLU\n  JUST(JobPass4Name(\"LogicalChainPass\")(job, &job_pass_ctx));\n  JUST(JobPass4Name(\"DumpBlobParallelConfPass\")(job, &job_pass_ctx));\n\n  JUST(CheckAndLogOpGraph(*job));\n  compile_tc->Count(\"[GraphCompile]\" + job_name + \" CheckAndLogOpGraph\", 1, true);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> JobCompleter::UpdateSharedGraphForNewInput(\n    Job* job,\n    const std::function<Maybe<std::shared_ptr<one::Tensor>>(const std::string&)>& InputTensor4Name,\n    const std::function<Maybe<const OperatorConf*>(const std::string& shared_op_name)>&\n        NewOp4SharedOpName) {\n  // job is a copy from a shared graph.\n  // The job name has already update in py nn.Graph.\n  const auto& new_job_name = job->job_conf().job_name();\n\n  const auto& UpdateInputShape = [&InputTensor4Name](OperatorConf& op_conf) -> Maybe<void> {\n    // Input op needs to be updated with new input tensor.\n    if (op_conf.has_input_conf()) {\n      InputOpConf* input_conf = op_conf.mutable_input_conf();\n      InterfaceBlobConf* blob_conf = input_conf->mutable_blob_conf();\n      auto input_tensor = *JUST(InputTensor4Name(op_conf.name()));\n      input_tensor->shape()->ToProto(blob_conf->mutable_shape());\n      blob_conf->set_data_type(input_tensor->dtype()->data_type());\n    }\n    return Maybe<void>::Ok();\n  };\n\n  const auto& UpdateAttr = [&NewOp4SharedOpName](OperatorConf& op_conf) -> Maybe<void> {\n    // Some op attributes need to be updated with the new traced graph.\n    if (op_conf.has_user_conf()) {\n      for (auto& pair : *op_conf.mutable_user_conf()->mutable_attr()) {\n        const auto* new_op_conf = JUST(NewOp4SharedOpName(op_conf.name()));\n        if (new_op_conf == nullptr) { continue; }\n        CHECK_EQ_OR_RETURN(new_op_conf->user_conf().op_type_name(),\n                           op_conf.user_conf().op_type_name())\n            << \" new op \" << new_op_conf->DebugString() << \" is not corresponding with \"\n            << op_conf.DebugString();\n        auto attr_iter = new_op_conf->user_conf().attr().find(pair.first);\n        CHECK_OR_RETURN(attr_iter != new_op_conf->user_conf().attr().end())\n            << \" There is not attr 
\" << pair.first << \" in new op \" << new_op_conf->DebugString();\n        if (pair.second.has_at_shape()) {\n          *pair.second.mutable_at_shape() = attr_iter->second.at_shape();\n        } else if (pair.second.has_at_double()) {\n          pair.second.set_at_double(attr_iter->second.at_double());\n        } else if (pair.second.has_at_list_int64()) {\n          pair.second.mutable_at_list_int64()->CopyFrom(attr_iter->second.at_list_int64());\n        }\n      }\n    }\n    return Maybe<void>::Ok();\n  };\n\n  const auto& UpdateBufferName = [&new_job_name](OperatorConf& op_conf) -> Maybe<void> {\n  // These operators' execution depends on new job name.\n#define UPDATE_JOB_NAME(op_conf_name)                             \\\n  if (op_conf.has_##op_conf_name()) {                             \\\n    op_conf.mutable_##op_conf_name()->set_job_name(new_job_name); \\\n  }\n    UPDATE_JOB_NAME(input_conf);\n    UPDATE_JOB_NAME(output_conf);\n    UPDATE_JOB_NAME(callback_notify_conf);\n    UPDATE_JOB_NAME(wait_and_send_ids_conf);\n    UPDATE_JOB_NAME(return_conf);\n#undef UPDATE_JOB_NAME\n\n    // Critical section operators depend job_name related buffer_name.\n    if (op_conf.has_critical_section_wait_tick_conf()) {\n      const auto& buffer_name = op_conf.critical_section_wait_tick_conf().buffer_name();\n      if (buffer_name.rfind(kInputCriticalSectionWaitBufferNamePrefix, 0) == 0) {\n        op_conf.mutable_critical_section_wait_tick_conf()->set_buffer_name(\n            GetInputCriticalSectionWaitBufferName(new_job_name));\n      } else if (buffer_name.rfind(kOutputCriticalSectionWaitBufferNamePrefix, 0) == 0) {\n        op_conf.mutable_critical_section_wait_tick_conf()->set_buffer_name(\n            GetOutputCriticalSectionWaitBufferName(new_job_name));\n      }\n    }\n    if (op_conf.has_critical_section_callback_tick_conf()) {\n      const auto& buffer_name = op_conf.critical_section_callback_tick_conf().buffer_name();\n      if (buffer_name.rfind(kInputCriticalSectionCallbackBufferNamePrefix, 0) == 0) {\n        op_conf.mutable_critical_section_callback_tick_conf()->set_buffer_name(\n            GetInputCriticalSectionCallbackBufferName(new_job_name));\n      } else if (buffer_name.rfind(kOutputCriticalSectionCallbackBufferNamePrefix, 0) == 0) {\n        op_conf.mutable_critical_section_callback_tick_conf()->set_buffer_name(\n            GetOutputCriticalSectionCallbackBufferName(new_job_name));\n      }\n    }\n    return Maybe<void>::Ok();\n  };\n\n  // Update the job for new input.\n  for (auto& op_conf : *job->mutable_net()->mutable_op()) {\n    JUST(UpdateInputShape(op_conf));\n    JUST(UpdateAttr(op_conf));\n    JUST(UpdateBufferName(op_conf));\n  }\n  // Use OpGraph init to infer all LogicalBlobDesc with the new input shape.\n  auto op_graph = std::make_unique<OpGraph>(*job);\n  op_graph->DumpLogicalBlobDesc(job);\n\n#ifdef WITH_CUTLASS\n  // Warmup cutlass conv with new input shape.\n  JobPassCtx job_pass_ctx(GlobalJobDesc());\n  JUST(JobPass4Name(\"CutlassConvTuningWarmupPass\")(job, &job_pass_ctx));\n#endif  // WITH_CUTLASS\n\n  JUST(CheckAndLogOpGraph(*job));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/job_completer.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_REWRITER_JOB_COMPLETER_H_\n#define ONEFLOW_CORE_JOB_REWRITER_JOB_COMPLETER_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/framework/tensor.h\"\n\nnamespace oneflow {\n\nclass JobCompleter final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(JobCompleter);\n  JobCompleter() = default;\n  ~JobCompleter() = default;\n\n  static Maybe<void> Complete(Job* job);\n  // The job is copied from a shared graph, it needs to be modified\n  // for a new graph with different input.\n  static Maybe<void> UpdateSharedGraphForNewInput(\n      Job* job,\n      const std::function<Maybe<std::shared_ptr<one::Tensor>>(const std::string&)>&\n          InputTensor4Name,\n      const std::function<Maybe<const OperatorConf*>(const std::string& shared_op_name)>&\n          NewOp4SharedOpName);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_REWRITER_JOB_COMPLETER_H_\n"
  },
  {
    "path": "oneflow/core/job_rewriter/job_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nHashMap<std::string, const JobPass*>* PassName2JobPass() {\n  static HashMap<std::string, const JobPass*> pass_name2job_pass;\n  return &pass_name2job_pass;\n}\n\n}  // namespace\n\nvoid RegisterJobPass(const std::string& pass_name, const JobPass* pass) {\n  CHECK(PassName2JobPass()->emplace(pass_name, pass).second);\n}\n\nbool HasJobPass(const std::string& pass_name) {\n  return PassName2JobPass()->find(pass_name) != PassName2JobPass()->end();\n}\n\nconst JobPass& JobPass4Name(const std::string& pass_name) {\n  const auto& iter = PassName2JobPass()->find(pass_name);\n  CHECK(iter != PassName2JobPass()->end()) << \"Cannot find job pass: \" << pass_name;\n  return *iter->second;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/job_pass.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_REWRITER_JOB_PASS_H_\n#define ONEFLOW_CORE_JOB_REWRITER_JOB_PASS_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/job/job_builder.h\"\n#include \"oneflow/core/job_rewriter/pass_util.h\"\n\nnamespace oneflow {\n\nclass JobPassCtx;\n\nclass JobPass {\n public:\n  JobPass() = default;\n  virtual ~JobPass() = default;\n\n  Maybe<void> operator()(Job* job, JobPassCtx* ctx) const { return Apply(job, ctx); }\n  virtual Maybe<void> Apply(Job* job, JobPassCtx* ctx) const = 0;\n};\n\nclass JobPassState {\n public:\n  virtual ~JobPassState() = default;\n\n protected:\n  JobPassState() = default;\n};\n\nclass JobPassCtx {\n public:\n  JobPassCtx(const JobPassCtx&) = delete;\n  JobPassCtx(JobPassCtx&&) = delete;\n  JobPassCtx(const JobDesc& job_desc) : job_desc_(&job_desc) {}\n  ~JobPassCtx() = default;\n\n  const JobDesc& job_desc() const { return *job_desc_; }\n\n  template<typename T>\n  Maybe<const T&> GetState(const std::string& key) const {\n    const auto& iter = key2state_.find(key);\n    CHECK_OR_RETURN(iter != key2state_.end());\n    const T* ptr = dynamic_cast<T*>(iter->second.get());\n    const auto& origin_obj = *iter->second;\n    CHECK_NOTNULL_OR_RETURN(ptr) << typeid(origin_obj).name();\n    return *ptr;\n  }\n\n  template<typename T>\n  Maybe<T*> MutableState(const std::string& key) {\n    const auto& iter = key2state_.find(key);\n    CHECK_OR_RETURN(iter != key2state_.end());\n    T* ptr = dynamic_cast<T*>(iter->second.get());\n    const auto& origin_obj = *iter->second;\n    CHECK_NOTNULL_OR_RETURN(ptr) << typeid(origin_obj).name();\n    return ptr;\n  }\n\n  template<typename T>\n  Maybe<bool> HasState(const std::string& key) const {\n    const auto& iter = key2state_.find(key);\n    return (iter != key2state_.end());\n  }\n\n  Maybe<void> ResetState(const std::string& key, std::unique_ptr<JobPassState>&& state) {\n    if (!state) {\n      key2state_.erase(key);\n    } else {\n      key2state_.emplace(key, std::move(state));\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> ResetState(const std::string& key) {\n    key2state_.erase(key);\n    return Maybe<void>::Ok();\n  }\n\n private:\n  const JobDesc* job_desc_;\n  HashMap<std::string, std::unique_ptr<JobPassState>> key2state_;\n};\n\n#define REGISTER_JOB_PASS(pass_name, pass_type) COMMAND(RegisterJobPass(pass_name, new pass_type))\n\nvoid RegisterJobPass(const std::string& pass_name, const JobPass* pass);\nbool HasJobPass(const std::string& pass_name);\nconst JobPass& JobPass4Name(const std::string& pass_name);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_REWRITER_JOB_PASS_H_\n"
  },
  {
    "path": "oneflow/core/job_rewriter/lamb_optm.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/optimizer.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nstruct BiasCorrectionFactorCacheKey {\n  float beta = 1.0;\n  ParallelConf parallel_conf;\n};\n\nbool operator==(const BiasCorrectionFactorCacheKey& lhs, const BiasCorrectionFactorCacheKey& rhs);\n\n}  // namespace oneflow\n\nnamespace std {\ntemplate<>\nstruct hash<oneflow::BiasCorrectionFactorCacheKey> {\n  size_t operator()(const oneflow::BiasCorrectionFactorCacheKey& key) const {\n    using namespace oneflow;\n    return Hash(key.beta, key.parallel_conf);\n  }\n};\n\n}  // namespace std\n\nnamespace oneflow {\n\n// Forward declaration for bias correction factor\nclass BiasCorrectionFactorState final : public JobPassState {\n public:\n  BiasCorrectionFactorState() {}\n  ~BiasCorrectionFactorState() override = default;\n\n  std::string GetLbn(float beta, std::string bias_correction_name, ParallelConf parallel_conf,\n                     const std::function<std::string(float beta_val, std::string op_name)>&\n                         BiasCorrectionFactorStateOp);\n\n private:\n  HashMap<BiasCorrectionFactorCacheKey, std::string> key2lbn_;\n};\n\nnamespace {\n\nstd::string GenVariableOutputLbn(const OperatorConf& op_conf) {\n  CHECK(op_conf.has_variable_conf());\n  return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out());\n}\n\nOperatorConf GenerateLAMBHelperVariableOpConf(const VariableOp& op, const std::string& name,\n                                              const float initial_value) {\n  OperatorConf helper_variable_op(op.op_conf());\n  helper_variable_op.set_name(op.op_name() + \"-\" + name);\n  helper_variable_op.mutable_variable_conf()->set_out(\"out\");\n  InitializerConf constant_initializer;\n  constant_initializer.mutable_constant_conf()->set_value(initial_value);\n  *(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer;\n  helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id());\n  return helper_variable_op;\n}\n\nvoid SetScalarShapeAndNdSbpConf(const ParallelDesc& parallel_desc, OperatorConf* op_conf) {\n  op_conf->mutable_variable_conf()->mutable_shape()->clear_dim();\n  op_conf->mutable_variable_conf()->mutable_shape()->add_dim(1);\n  op_conf->mutable_variable_conf()->clear_nd_sbp();\n  FOR_RANGE(int, i, 0, parallel_desc.hierarchy()->NumAxes()) {\n    *op_conf->mutable_variable_conf()->add_nd_sbp() = \"B\";\n  }\n  CHECK_NE(op_conf->name(), std::string(\"\"));\n}\n\nvoid GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node,\n                             const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf,\n                             JobBuilder* job_builder) {\n  const VariableOp* var_op = dynamic_cast<const VariableOp*>(&var_op_node.op());\n  CHECK_NOTNULL(var_op);\n\n  OperatorConf m_var = GenerateLAMBHelperVariableOpConf(*var_op, 
\"m\", 0.f);\n  OperatorConf v_var = GenerateLAMBHelperVariableOpConf(*var_op, \"v\", 0.f);\n\n  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {m_var, v_var});\n\n  user_op::UserOpConfWrapperBuilder lamb_update_op_builder(var_op->op_name() + \"_optimizer\");\n\n  const LambModelUpdateConf& lamb_conf = optimizer_conf.lamb_conf();\n  float beta1 = lamb_conf.beta1();\n  float beta2 = lamb_conf.beta2();\n  float epsilon = lamb_conf.epsilon();\n  bool do_bias_correction = lamb_conf.do_bias_correction();\n\n  const std::string& train_step_lbn = job_builder->job().job_conf().train_conf().train_step_lbn();\n  const std::string& learning_rate_lbn = optimizer_conf.learning_rate_lbn();\n\n  if (do_bias_correction) {\n    // Reuse adam bias_correction job pass\n    const std::string& job_pass_state_key = \"adam_bias_correction_factor\";\n    const bool has_state = CHECK_JUST(ctx->HasState<BiasCorrectionFactorState>(job_pass_state_key));\n    if (!has_state) {\n      CHECK_JUST(\n          ctx->ResetState(job_pass_state_key, std::make_unique<BiasCorrectionFactorState>()));\n    }\n    auto* state = CHECK_JUST(ctx->MutableState<BiasCorrectionFactorState>(job_pass_state_key));\n    ParallelConf bias_correction_parallel_conf;\n    const auto& lr_parallel_conf =\n        CHECK_JUST(job_builder->ParallelConf4Lbi(GenLogicalBlobId(learning_rate_lbn)));\n    const auto& train_step_parallel_conf =\n        CHECK_JUST(job_builder->ParallelConf4Lbi(GenLogicalBlobId(train_step_lbn)));\n    if (lr_parallel_conf == train_step_parallel_conf) {\n      bias_correction_parallel_conf = lr_parallel_conf;\n    } else {\n      bias_correction_parallel_conf = var_op_node.parallel_desc().parallel_conf();\n    }\n    auto AddLambBiasCorrectionFactorOp = [&](float beta_val,\n                                             const std::string& op_name) -> std::string {\n      user_op::UserOpConfWrapperBuilder op_builder(var_op->op_name() + op_name);\n      const auto lamb_bias_correction_factor_op =\n          op_builder.OpTypeName(\"adam_bias_correction_factor\")\n              .Input(\"train_step\", train_step_lbn)\n              .Attr<float>(\"beta\", beta_val)\n              .Output(\"out\")\n              .ScopeSymbolId(var_op->op_conf().scope_symbol_id())\n              .Build();\n\n      job_builder->AddOps(bias_correction_parallel_conf,\n                          {lamb_bias_correction_factor_op.op_conf()});\n      return lamb_bias_correction_factor_op.output(\"out\", 0);\n    };\n\n    const std::string bias_correction1_lbn =\n        state->GetLbn(beta1, \"lamb_bias_correction_factor1\", bias_correction_parallel_conf,\n                      AddLambBiasCorrectionFactorOp);\n    const std::string bias_correction2_lbn =\n        state->GetLbn(beta2, \"lamb_bias_correction_factor2\", bias_correction_parallel_conf,\n                      AddLambBiasCorrectionFactorOp);\n\n    lamb_update_op_builder.OpTypeName(\"lamb_update\")\n        .Input(\"model\", GenLogicalBlobName(var_op->BnInOp2Lbi(\"out\")))\n        .Input(\"model_diff\", model_diff_lbn)\n        .Input(\"m\", GenVariableOutputLbn(m_var))\n        .Input(\"v\", GenVariableOutputLbn(v_var))\n        .Input(\"learning_rate\", learning_rate_lbn)\n        .Input(\"bias_correction1\", bias_correction1_lbn)\n        .Input(\"bias_correction2\", bias_correction2_lbn)\n        .Attr<float>(\"beta1\", beta1)\n        .Attr<float>(\"beta2\", beta2)\n        .Attr<float>(\"epsilon\", epsilon)\n        .Attr<float>(\"weight_decay\", 
GetOptimizerWeightDecayRate(optimizer_conf, *var_op))\n        .Attr<bool>(\"do_bias_correction\", true)\n        .ScopeSymbolId(var_op->op_conf().scope_symbol_id());\n  } else {\n    lamb_update_op_builder.OpTypeName(\"lamb_update\")\n        .Input(\"model\", GenLogicalBlobName(var_op->BnInOp2Lbi(\"out\")))\n        .Input(\"model_diff\", model_diff_lbn)\n        .Input(\"m\", GenVariableOutputLbn(m_var))\n        .Input(\"v\", GenVariableOutputLbn(v_var))\n        .Input(\"learning_rate\", learning_rate_lbn)\n        .Attr<float>(\"beta1\", beta1)\n        .Attr<float>(\"beta2\", beta2)\n        .Attr<float>(\"epsilon\", epsilon)\n        .Attr<float>(\"weight_decay\", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))\n        .Attr<bool>(\"do_bias_correction\", false)\n        .ScopeSymbolId(var_op->op_conf().scope_symbol_id());\n  }\n\n  if (optimizer_conf.has_lr_scale()) {\n    lamb_update_op_builder.Attr<float>(\"learning_rate_scale\", optimizer_conf.lr_scale());\n  }\n\n  SetDynamicLossScaleSkipIf(ctx, &lamb_update_op_builder);\n  const auto lamb_update_op = lamb_update_op_builder.Build();\n  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {lamb_update_op.op_conf()});\n}\n\n}  // namespace\n\nREGISTER_OPTIMIZER(OptimizerConf::kLambConf, &GenerateOptimizerOpConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/lars_optm.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/optimizer.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nvoid GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node,\n                             const std::string& model_diff_lbn, const OptimizerConf optimizer_conf,\n                             JobBuilder* job_builder) {\n  const VariableOp* var_op = dynamic_cast<const VariableOp*>(&var_op_node.op());\n  CHECK_NOTNULL(var_op);\n  const std::string momentum_var_op_name = var_op->op_name() + \"-momentum\";\n  OperatorConf momentum_var(var_op->op_conf());\n  InitializerConf constant_initializer;\n  constant_initializer.mutable_constant_conf()->set_value(0.f);\n  *(momentum_var.mutable_variable_conf()->mutable_initializer()) = constant_initializer;\n  momentum_var.set_name(momentum_var_op_name);\n  momentum_var.mutable_variable_conf()->set_out(\"out\");\n  momentum_var.set_scope_symbol_id(var_op->op_conf().scope_symbol_id());\n  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {momentum_var});\n\n  user_op::UserOpConfWrapperBuilder lars_update_op_builder(var_op->op_name() + \"_optimizer\");\n  lars_update_op_builder.OpTypeName(\"lars_update\")\n      .Input(\"model\", GenLogicalBlobName(var_op->BnInOp2Lbi(\"out\")))\n      .Input(\"model_diff\", model_diff_lbn)\n      .Input(\"learning_rate\", optimizer_conf.learning_rate_lbn())\n      .Input(\"momentum\",\n             GenLogicalBlobName(momentum_var_op_name, momentum_var.variable_conf().out()))\n      .Attr<float>(\"momentum_beta\", optimizer_conf.lars_conf().momentum_beta())\n      .Attr<float>(\"epsilon\", optimizer_conf.lars_conf().epsilon())\n      .Attr<float>(\"lars_coefficient\", optimizer_conf.lars_conf().lars_coefficient())\n      .Attr<float>(\"weight_decay\", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))\n      .ScopeSymbolId(var_op->op_conf().scope_symbol_id());\n  if (optimizer_conf.has_lr_scale()) {\n    lars_update_op_builder.Attr<float>(\"learning_rate_scale\", optimizer_conf.lr_scale());\n  }\n  SetDynamicLossScaleSkipIf(ctx, &lars_update_op_builder);\n  user_op::UserOpConfWrapper lars_update_op = lars_update_op_builder.Build();\n  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {lars_update_op.op_conf()});\n}\n\n}  // namespace\n\nREGISTER_OPTIMIZER(OptimizerConf::kLarsConf, &GenerateOptimizerOpConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/logical_chain_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/auto_parallel/auto_memory.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job_rewriter/calculation_pass.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/env_var/env_var.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\n\nDEFINE_ENV_BOOL(ENABLE_ACC_CHAIN_MERGE, true);\n\nnamespace {\n\nclass LogicalChainPass final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LogicalChainPass);\n  LogicalChainPass() = default;\n  ~LogicalChainPass() = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n\n  bool IsEnabled(const JobPassCtx& ctx) const { return EnableLogicalChain(); }\n\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n};\n\nbool IsTickOpConf(const OperatorConf& op_conf) {\n  if (IsClassRegistered<int32_t, IsTickTockOpTypeCase>(op_conf.op_type_case())) { return true; }\n  if (op_conf.has_user_conf()) {\n    const std::string& user_type_name = op_conf.user_conf().op_type_name();\n    if (user_type_name == \"cast_to_tick\" || user_type_name == \"acc_ctrl_tick\") { return true; }\n  }\n  return false;\n}\n\nbool IsBreakpointOpNode(const OpNode* node) {\n  // NOTE(chengcheng): breakpoint op is special which CANNOT merge in chain such as:\n  //   variable, tick, repeat/acc/pack/unpack change timeshape\n  const Operator& op = node->op();\n  const OperatorConf& op_conf = op.op_conf();\n\n  // TODO(chengcheng): filter ops which has special type\n  // TODO(chengcheng): get stream by op type\n  if (op_conf.has_variable_conf()                                                   /* variable */\n      || IsTickOpConf(op_conf)                                                      /* tick */\n      || op_conf.has_input_conf() || op_conf.has_output_conf()                      /* io */\n      || op_conf.has_wait_and_send_ids_conf() || op_conf.has_callback_notify_conf() /* ctrl */\n      || op_conf.has_image_decoder_random_crop_resize_conf() /* gpu decode */) {\n    return true;\n  }\n\n  if (op_conf.has_user_conf()) {\n    const std::string& user_type_name = op_conf.user_conf().op_type_name();\n    
if (user_type_name == \"repeat\" || user_type_name == \"unpack\"\n        || user_type_name == \"identity_buffer\" || user_type_name == \"copy_h2d\"\n        || user_type_name == \"copy_d2h\") {\n      return true;\n    }\n  }\n  return false;\n}\n\nbool IsAccOrPackOpNode(const OpNode* node) {\n  const auto& op_conf = node->op().op_conf();\n  return op_conf.has_user_conf()\n         && (op_conf.user_conf().op_type_name() == \"acc\"\n             || op_conf.user_conf().op_type_name() == \"pack\");\n}\n\nbool IsAccOpNode(const OpNode* node) {\n  return node->op().op_conf().has_user_conf()\n         && node->op().op_conf().user_conf().op_type_name() == \"acc\";\n}\n\nbool IsRepeatOpNode(const OpNode* node) {\n  return node->op().op_conf().has_user_conf()\n         && node->op().op_conf().user_conf().op_type_name() == \"repeat\";\n}\n\nstd::shared_ptr<const Shape> GetOpNodeFastestTimeShape(const OpNode* op_node) {\n  return CHECK_JUST(op_node->op().GetInputOutputFastestTimeShape());\n}\n\nstd::shared_ptr<const Shape> GetOpNodeInputTimeShape(const OpNode* op_node) {\n  return CHECK_JUST(op_node->op().GetInputBlobFastestTimeShape());\n}\n\nbool SharedPtrShapeEqual(const std::shared_ptr<const Shape>& lhs,\n                         const std::shared_ptr<const Shape>& rhs) {\n  return (*lhs) == (*rhs);\n}\n\nbool IsOpEdge121Connected(const OpNode* src_node, const OpNode* dst_node, const OpEdge* edge) {\n  CHECK(src_node != dst_node && (edge->src_node() == src_node || edge->src_node() == dst_node)\n        && (edge->dst_node() == src_node || edge->dst_node() == dst_node));\n  if (src_node->parallel_desc().parallel_num() != dst_node->parallel_desc().parallel_num()) {\n    return false;\n  }\n  if (src_node->parallel_desc().parallel_num() == 1) { return true; }\n  for (const auto& lbi : edge->lbis()) {\n    // NOTE(chengcheng): nd_sbp need to be reduction like from [P, P] to [P]\n    Shape src_reduced_hierarchy;\n    Shape dst_reduced_hierarchy;\n    NdSbp src_reduced_nd_sbp;\n    NdSbp dst_reduced_nd_sbp;\n\n    InOutParallelDimReduce(*src_node->parallel_desc().hierarchy(),\n                           *dst_node->parallel_desc().hierarchy(), src_node->NdSbp4Lbi(lbi),\n                           dst_node->NdSbp4Lbi(lbi), &src_reduced_hierarchy, &dst_reduced_hierarchy,\n                           &src_reduced_nd_sbp, &dst_reduced_nd_sbp,\n                           src_node->LogicalBlobDesc4Lbi(lbi).shape());\n    if (src_reduced_hierarchy != dst_reduced_hierarchy\n        || src_reduced_nd_sbp != dst_reduced_nd_sbp) {\n      // Not one to one\n      return false;\n    }\n  }\n\n  return true;\n}\n\nvoid GetLogicalChainsWithTimeShape(std::vector<HashSet<const OpNode*>>* ret,\n                                   const std::vector<const OpNode*>& order,\n                                   const std::shared_ptr<const Shape>& seed_time_shape) {\n  HashSet<const OpNode*> visited;\n  for (const OpNode* seed_node : order) {\n    if (visited.find(seed_node) != visited.end()) { continue; }\n    CHECK(visited.insert(seed_node).second);\n    const ParallelDesc& seed_parallel_desc = seed_node->parallel_desc();\n    if (seed_node->op().op_conf().has_logical_chain_id()) { continue; }\n    // TODO(chengcheng): support cpu chain.\n    if (seed_parallel_desc.device_type() == DeviceType::kCPU) { continue; }\n    if (!SharedPtrShapeEqual(GetOpNodeFastestTimeShape(seed_node), seed_time_shape)) { continue; }\n    if (IsBreakpointOpNode(seed_node)) { continue; }\n\n    HashSet<const OpNode*> this_subgraph;\n    
std::queue<const OpNode*> queued_nodes;\n\n    queued_nodes.push(seed_node);\n    while (!queued_nodes.empty()) {\n      const OpNode* cur_node = queued_nodes.front();\n      queued_nodes.pop();\n\n      CHECK(cur_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc));\n      CHECK(this_subgraph.insert(cur_node).second);\n\n      auto SearchToNextNode = [&](const OpNode* cur_node, const OpNode* next_node,\n                                  const OpEdge* edge) {\n        if (visited.find(next_node) == visited.end() && (!IsBreakpointOpNode(next_node))\n            && (!next_node->op().op_conf().has_logical_chain_id()) /* skip logical chain id */\n            && next_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc)\n            && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(next_node), seed_time_shape)\n            && next_node->op().op_conf().stream_name_hint()\n                   == seed_node->op().op_conf().stream_name_hint()\n            && IsOpEdge121Connected(cur_node, next_node, edge)) {\n          CHECK(visited.insert(next_node).second);\n          queued_nodes.push(next_node);\n        }\n      };\n\n      for (const OpEdge* in_edge : cur_node->in_edges()) {\n        SearchToNextNode(cur_node, in_edge->src_node(), in_edge);\n      }\n      for (const OpEdge* out_edge : cur_node->out_edges()) {\n        SearchToNextNode(cur_node, out_edge->dst_node(), out_edge);\n      }\n    }\n\n    if (this_subgraph.size() > 1) {\n      ret->emplace_back(HashSet<const OpNode*>());\n      ret->back().swap(this_subgraph);\n    }\n  }\n}\n\nstruct LogicalChain {\n  int64_t logical_chain_id;\n  std::vector<const OpNode*> ordered_op_nodes;\n  explicit LogicalChain(int64_t val) : logical_chain_id(val) { CHECK_GE(val, 0); }\n};\n\nstruct PlacementLogicalChainsInfo {\n  std::vector<std::shared_ptr<LogicalChain>> ordered_logical_chains;\n  std::vector<const OpNode*> ordered_acc_op_nodes;\n  std::shared_ptr<LogicalChain> after_acc_logical_chain;\n  const ParallelDesc* seed_parallel_desc;\n  PlacementLogicalChainsInfo() : seed_parallel_desc(nullptr) {}\n};\n\nvoid InitPlacementLogicalChainsInfoFromSet(\n    const std::shared_ptr<LogicalChain>& logical_chain,\n    const HashSet<const OpNode*>& origin_logical_chain,\n    const HashMap<const OpNode*, int64_t>& op_node2global_order,\n    const std::function<bool(const OpNode*, const OpNode*)>& CmpOpNodeOrder) {\n  auto* logical_chain_ordered_nodes = &logical_chain->ordered_op_nodes;\n  CHECK(logical_chain_ordered_nodes->empty());\n  logical_chain_ordered_nodes->assign(origin_logical_chain.begin(), origin_logical_chain.end());\n  std::sort(logical_chain_ordered_nodes->begin(), logical_chain_ordered_nodes->end(),\n            CmpOpNodeOrder);\n  const OpNode* begin_op = logical_chain_ordered_nodes->front();\n  const OpNode* end_op = logical_chain_ordered_nodes->back();\n  int64_t begin_op_global_order = op_node2global_order.at(begin_op);\n  int64_t end_op_global_order = op_node2global_order.at(end_op);\n  CHECK(begin_op != end_op);\n  CHECK_LT(begin_op_global_order, end_op_global_order);\n}\n\nvoid CreateAfterAccLogicalChain(const std::shared_ptr<LogicalChain>& after_acc_logical_chain,\n                                const std::vector<const OpNode*>& ordered_acc_op_nodes,\n                                const ParallelDesc& seed_parallel_desc) {\n  // Meta time shape (1, 1)\n  std::shared_ptr<const Shape> meta_time_shape = std::make_shared<const Shape>(Shape({1, 1}));\n  HashSet<const OpNode*> visited;\n  HashSet<const OpNode*> 
after_acc_chain_ops;\n  std::queue<const OpNode*> queued_nodes;\n  auto SearchToNextNode = [&](const OpNode* cur_node, const OpNode* next_node, const OpEdge* edge) {\n    if (visited.find(next_node) == visited.end() && (!IsBreakpointOpNode(next_node))\n        && next_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc)\n        && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(next_node), meta_time_shape)\n        && IsOpEdge121Connected(cur_node, next_node, edge)) {\n      CHECK(visited.insert(next_node).second);\n      queued_nodes.push(next_node);\n    }\n  };\n\n  for (const OpNode* acc_node : ordered_acc_op_nodes) {\n    for (const OpEdge* out_edge : acc_node->out_edges()) {\n      const OpNode* seed_node = out_edge->dst_node();\n      SearchToNextNode(acc_node, seed_node, out_edge);\n    }\n  }\n\n  while (!queued_nodes.empty()) {\n    const OpNode* cur_node = queued_nodes.front();\n    queued_nodes.pop();\n\n    CHECK(after_acc_chain_ops.insert(cur_node).second);\n\n    for (const OpEdge* in_edge : cur_node->in_edges()) {\n      SearchToNextNode(cur_node, in_edge->src_node(), in_edge);\n    }\n    for (const OpEdge* out_edge : cur_node->out_edges()) {\n      SearchToNextNode(cur_node, out_edge->dst_node(), out_edge);\n    }\n  }\n\n  if (after_acc_chain_ops.size() > 1) {\n    for (const OpNode* node : after_acc_chain_ops) {\n      after_acc_logical_chain->ordered_op_nodes.push_back(node);\n    }\n    CHECK_EQ(after_acc_logical_chain->ordered_op_nodes.size(), after_acc_chain_ops.size());\n  }\n}\n\nvoid TryMergeAfterAccLogicalChainToMaxLogicalChain(\n    PlacementLogicalChainsInfo* info, HashMap<std::string, OperatorConf>* mut_op_name2conf,\n    JobBuilder* job_builder,\n    const std::function<bool(const std::string&, const std::string&)>& IsReachable,\n    const std::shared_ptr<const Shape>& seed_time_shape) {\n  if (!EnvBool<ENABLE_ACC_CHAIN_MERGE>()) { return; }\n  int64_t max_chain_index = 0;\n  for (int64_t i = 1; i < info->ordered_logical_chains.size(); ++i) {\n    if (info->ordered_logical_chains.at(i)->ordered_op_nodes.size()\n        > info->ordered_logical_chains.at(max_chain_index)->ordered_op_nodes.size()) {\n      max_chain_index = i;\n    }\n  }\n\n  const int64_t acc_chain_id = info->after_acc_logical_chain->logical_chain_id;\n  auto& acc_chain_order_ops = info->after_acc_logical_chain->ordered_op_nodes;\n  const auto& max_chain = info->ordered_logical_chains.at(max_chain_index);\n  const OpNode* max_chain_src_op = max_chain->ordered_op_nodes.front();\n  const OpNode* max_chain_sink_op = max_chain->ordered_op_nodes.back();\n  HashSet<const OpNode*> max_chain_ops(max_chain->ordered_op_nodes.begin(),\n                                       max_chain->ordered_op_nodes.end());\n\n  const OpNode* acc_chain_src_op = acc_chain_order_ops.front();\n  const OpNode* acc_chain_sink_op = acc_chain_order_ops.back();\n  // NOTE(chengcheng): find all nontrivial sink consumer ops\n  HashSet<const OpNode*> nontrivial_sink_consumers;\n  for (const OpNode* chain_op : max_chain->ordered_op_nodes) {\n    chain_op->ForEachNodeOnOutEdge([&](const OpNode* out_node) {\n      if (max_chain_ops.find(out_node) == max_chain_ops.end()\n          && !IsTickOpConf(out_node->op().op_conf())\n          && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(out_node), seed_time_shape)) {\n        nontrivial_sink_consumers.insert(out_node);\n      }\n    });\n  }\n\n  
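// Trim the acc chain from both ends: drop ops that cannot take the tick input (non user\n  // ops) or whose tick edge would close a cycle (already reachable in that direction).\n  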
// NOTE(chengcheng): find the last op in the acc chain that can take the acc ctrl tick.\n  while ((!acc_chain_sink_op->op().op_conf().has_user_conf())\n         || IsReachable(acc_chain_sink_op->op().op_name(), max_chain_src_op->op().op_name())) {\n    VLOG(3) << \" cannot insert acc ctrl edge between: [\" << max_chain_src_op->op().op_name()\n            << \"] -> [\" << acc_chain_sink_op->op().op_name() << \"] , debug info :\\n\"\n            << max_chain_src_op->op().op_conf().DebugString() << \"\\n\"\n            << acc_chain_sink_op->op().op_conf().DebugString() << \"\\n\";\n\n    VLOG(3) << \"remove op : \" << acc_chain_sink_op->op().op_name()\n            << \" from after acc logical chain: \" << acc_chain_id;\n    acc_chain_order_ops.pop_back();\n    if (acc_chain_order_ops.size() > 1) {\n      acc_chain_sink_op = acc_chain_order_ops.back();\n    } else {\n      acc_chain_sink_op = nullptr;\n      break;\n    }\n  }\n  if (acc_chain_sink_op == nullptr) { return; }\n\n  // NOTE(chengcheng): find the first op in the acc chain that can take the acc tick.\n  while (IsReachable(acc_chain_src_op->op().op_name(), max_chain_sink_op->op().op_name())) {\n    VLOG(3) << \" cannot insert acc tick edge between: [\" << max_chain_sink_op->op().op_name()\n            << \"] -> [\" << acc_chain_src_op->op().op_name() << \"] , debug info :\\n\"\n            << max_chain_sink_op->op().op_conf().DebugString() << \"\\n\"\n            << acc_chain_src_op->op().op_conf().DebugString() << \"\\n\";\n\n    VLOG(3) << \"remove op : \" << acc_chain_src_op->op().op_name()\n            << \" from after acc logical chain: \" << acc_chain_id;\n    acc_chain_order_ops.erase(acc_chain_order_ops.begin());\n    if (acc_chain_order_ops.size() > 1) {\n      acc_chain_src_op = acc_chain_order_ops.front();\n    } else {\n      acc_chain_src_op = nullptr;\n      break;\n    }\n  }\n  if (acc_chain_src_op == nullptr) { return; }\n\n  // NOTE(chengcheng):\n  //   1. Add an acc ctrl tick from the max chain src to the acc chain sink for memory locking.\n  const int64_t acc_num = job_builder->job().job_conf().num_gradient_accumulation_steps();\n  CHECK_GT(acc_num, 1);\n  const auto& fc_src_obns = max_chain_src_op->op().output_bns();\n  CHECK(!fc_src_obns.empty());\n  const std::string& max_chain_src_out_lbn =\n      GenLogicalBlobName(max_chain_src_op->op().BnInOp2Lbi(fc_src_obns.Get(0)));\n\n  VLOG(3) << \" max_chain_src_out_lbn : \" << max_chain_src_out_lbn;\n  user_op::UserOpConfWrapper acc_ctrl_tick_op =\n      user_op::UserOpConfWrapperBuilder(\"Sys-AccCtrlTick4MergeMaxAccChain-\" + NewUniqueId())\n          .OpTypeName(\"acc_ctrl_tick\")\n          .Input(\"in\", max_chain_src_out_lbn)\n          .Output(\"out\")\n          .ScopeSymbolId(max_chain_src_op->op().op_conf().scope_symbol_id())\n          .Attr<int32_t>(\"max_acc_num\", acc_num)\n          .Build();\n\n  OperatorConf& acc_chain_sink_op_conf =\n      CHECK_JUST(MapAt(*mut_op_name2conf, acc_chain_sink_op->op().op_name()));\n  CHECK(acc_chain_sink_op_conf.has_user_conf());\n  (*acc_chain_sink_op_conf.mutable_user_conf()\n        ->mutable_input())[user_op::kUserSourceOpTickInputArgName]\n      .add_s(acc_ctrl_tick_op.output(\"out\", 0));\n  CHECK_JUST(job_builder->AddOp(max_chain_src_op->parallel_desc().parallel_conf(),\n                                acc_ctrl_tick_op.op_conf()));\n  VLOG(3) << \" Insert acc ctrl tick between: [\" << max_chain_src_op->op().op_name() << \"] -> [\"\n          << acc_chain_sink_op->op().op_name() << \"]\";\n\n  // NOTE(chengcheng):\n  //   2. Add an acc tick from the max chain sink to the acc chain src for strict exec order.\n  const auto& fc_sink_obns = max_chain_sink_op->op().output_bns();\n  CHECK(!fc_sink_obns.empty());\n  
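// cast_to_tick turns the sink op's first output into a tick signal so that the ordering\n  // edges below can start from an ordinary data op.\n  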
const std::string max_chain_sink_lbn =\n      GenLogicalBlobName(max_chain_sink_op->op().BnInOp2Lbi(fc_sink_obns.Get(0)));\n  VLOG(3) << \" max_chain_sink_lbn : \" << max_chain_sink_lbn;\n\n  user_op::UserOpConfWrapper cast_to_tick_op =\n      user_op::UserOpConfWrapperBuilder(\"Sys-LogicalChainSink-CastToTick-\" + NewUniqueId())\n          .OpTypeName(\"cast_to_tick\")\n          .Input(\"in\", max_chain_sink_lbn)\n          .Output(\"out\")\n          .ScopeSymbolId(max_chain_sink_op->op().op_conf().scope_symbol_id())\n          .Build();\n\n  CHECK_JUST(job_builder->AddOp(max_chain_sink_op->parallel_desc().parallel_conf(),\n                                cast_to_tick_op.op_conf()));\n\n  std::string acc_tick_output_lbn = cast_to_tick_op.output(\"out\", 0);\n  if (!IsAccOrPackOpNode(max_chain_sink_op)) {\n    // NOTE(chengcheng): acc ops can be merged into the fw/bw chain; if the last op is an acc op,\n    //  there is no need to (and we CANNOT) insert an acc tick op.\n\n    OperatorConf sink_acc_tick_conf;\n    sink_acc_tick_conf.set_name(std::string(\"Sys-LogicalChainSink-AccTick_\") + NewUniqueId());\n    sink_acc_tick_conf.set_scope_symbol_id(max_chain_sink_op->op().op_conf().scope_symbol_id());\n    auto* acc_conf = sink_acc_tick_conf.mutable_acc_tick_conf();\n    acc_conf->set_one(cast_to_tick_op.output(\"out\", 0));\n    acc_conf->set_acc(\"acc\");\n    acc_conf->set_max_acc_num(acc_num);\n    acc_tick_output_lbn = GenLogicalBlobName(sink_acc_tick_conf.name(), \"acc\");\n\n    VLOG(3) << \" insert acc tick op : \" << sink_acc_tick_conf.name()\n            << \" of last op in fw/bw chain.\";\n\n    CHECK_JUST(\n        job_builder->AddOp(max_chain_sink_op->parallel_desc().parallel_conf(), sink_acc_tick_conf));\n  }\n\n  OperatorConf sink_final_tick_conf;\n  sink_final_tick_conf.set_name(std::string(\"Sys-LogicalChainSink-FinalTick-DeviceTick_\")\n                                + NewUniqueId());\n  sink_final_tick_conf.set_scope_symbol_id(max_chain_sink_op->op().op_conf().scope_symbol_id());\n  auto* tick_conf = sink_final_tick_conf.mutable_device_tick_conf();\n  tick_conf->add_tick(acc_tick_output_lbn);\n  tick_conf->set_out(\"out\");\n\n  
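// The final device tick gathers the sink acc tick plus one acc tick per nontrivial sink\n  // consumer (added below); the acc chain src then takes a ctrl-in edge from this tick.\n  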
// NOTE(chengcheng):\n  //   3. Important: if there are nontrivial_sink_consumers, ctrl edges must be inserted\n  //   between each sink consumer and the acc chain for exec order.\n  for (const OpNode* sink_consumer : nontrivial_sink_consumers) {\n    VLOG(2) << \" insert acc tick between nontrivial_sink_consumer: [\"\n            << sink_consumer->op().op_name() << \"] -> [\" << sink_final_tick_conf.name()\n            << \"] for mem safe guard.\";\n    CHECK(!IsReachable(acc_chain_src_op->op().op_name(), sink_consumer->op().op_name()));\n    const auto& sink_consumer_obns = sink_consumer->op().output_bns();\n    CHECK(!sink_consumer_obns.empty());\n\n    std::string sink_consumer_output_lbn =\n        GenLogicalBlobName(sink_consumer->op().BnInOp2Lbi(sink_consumer_obns.Get(0)));\n    user_op::UserOpConfWrapper sink_consumer_cast_to_tick_op =\n        user_op::UserOpConfWrapperBuilder(\"Sys-LogicalChainSinkConsumer-CastToTick-\"\n                                          + NewUniqueId())\n            .OpTypeName(\"cast_to_tick\")\n            .Input(\"in\", sink_consumer_output_lbn)\n            .Output(\"out\")\n            .ScopeSymbolId(sink_consumer->op().op_conf().scope_symbol_id())\n            .Build();\n\n    CHECK_JUST(job_builder->AddOp(sink_consumer->parallel_desc().parallel_conf(),\n                                  sink_consumer_cast_to_tick_op.op_conf()));\n\n    std::string sink_consumer_acc_tick_lbn = sink_consumer_cast_to_tick_op.output(\"out\", 0);\n    if (!IsAccOrPackOpNode(sink_consumer)) {\n      OperatorConf sink_consumer_acc_tick_conf;\n      sink_consumer_acc_tick_conf.set_name(std::string(\"Sys-LogicalChainSinkConsumer-AccTick_\")\n                                           + NewUniqueId());\n      sink_consumer_acc_tick_conf.set_scope_symbol_id(\n          acc_chain_src_op->op().op_conf().scope_symbol_id());\n      auto* acc_conf = sink_consumer_acc_tick_conf.mutable_acc_tick_conf();\n      acc_conf->set_one(sink_consumer_acc_tick_lbn);\n      acc_conf->set_acc(\"acc\");\n      acc_conf->set_max_acc_num(acc_num);\n      sink_consumer_acc_tick_lbn = GenLogicalBlobName(sink_consumer_acc_tick_conf.name(), \"acc\");\n\n      VLOG(3) << \" insert acc tick op : \" << sink_consumer_acc_tick_conf.name()\n              << \" of nontrivial_sink_consumer in fw/bw chain.\";\n\n      CHECK_JUST(job_builder->AddOp(sink_consumer->parallel_desc().parallel_conf(),\n                                    sink_consumer_acc_tick_conf));\n    }\n    tick_conf->add_tick(sink_consumer_acc_tick_lbn);\n  }\n\n  CHECK_JUST(\n      job_builder->AddOp(max_chain_sink_op->parallel_desc().parallel_conf(), sink_final_tick_conf));\n\n  CHECK_JUST(MapAt(*mut_op_name2conf, acc_chain_src_op->op().op_name()))\n      .add_ctrl_in_op_name(sink_final_tick_conf.name());\n\n  VLOG(3) << \" Insert acc tick between: [\" << max_chain_sink_op->op().op_name() << \"] -> [\"\n          << acc_chain_src_op->op().op_name() << \"]\";\n\n  // NOTE(chengcheng):\n  //   4. 
merge max chain and acc chain\n  MergedLogicalChainIdGroup* group = job_builder->add_logical_chain_groups();\n  group->add_logical_chain_id_list(max_chain->logical_chain_id);\n  group->add_logical_chain_id_list(acc_chain_id);\n  VLOG(3) << \" Merge acc chain : \" << acc_chain_id\n          << \" to max logical chain : \" << max_chain->logical_chain_id;\n}\n\nMaybe<void> LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {\n  const int64_t acc_num = job_builder->job().job_conf().num_gradient_accumulation_steps();\n  bool has_acc = acc_num > 1;\n  int64_t max_logical_chain_id = -1;\n\n  HashMap<std::string, PlacementLogicalChainsInfo> placement2logical_chains;\n  auto FindOrCreatePlacementLogicalChainsInfo = [&](const OpNode* node) {\n    const ParallelDesc& this_parallel_desc = node->parallel_desc();\n    std::string key = GenParallelConfKey(this_parallel_desc.parallel_conf());\n    auto it = placement2logical_chains.find(key);\n    if (it == placement2logical_chains.end()) {\n      it = placement2logical_chains.emplace(key, PlacementLogicalChainsInfo()).first;\n      it->second.seed_parallel_desc = &this_parallel_desc;\n    }\n    return &(it->second);\n  };\n\n  std::vector<const OpNode*> ordered_op_nodes;\n  HashMap<const OpNode*, int64_t> op_node2global_order;\n  HashMap<std::string, OperatorConf> mut_op_name2conf;\n  std::shared_ptr<const Shape> seed_time_shape = std::make_shared<const Shape>(Shape({1, 1}));\n  if (ParseBooleanFromEnv(\"DISABLE_LOGICAL_STRAIGHTEN\", false)) {\n    op_graph.TopoForEachNodeWithCtrlEdge(\n        [&](const OpNode* node) { ordered_op_nodes.emplace_back(node); });\n  } else {\n    auto_parallel::StraightenOpGraph(op_graph, &ordered_op_nodes);\n  }\n\n  for (int32_t global_order = 0; global_order < ordered_op_nodes.size(); global_order++) {\n    const OpNode* node = JUST(VectorAt(ordered_op_nodes, global_order));\n    op_node2global_order.emplace(node, global_order);\n    std::shared_ptr<const Shape> this_time_shape = GetOpNodeFastestTimeShape(node);\n    if (this_time_shape->elem_cnt() > seed_time_shape->elem_cnt()) {\n      seed_time_shape = this_time_shape;\n    }\n    mut_op_name2conf.emplace(node->op().op_name(), node->op().op_conf());\n    // NOTE(chengcheng): handle logical chain id set by nccl logical pass\n    if (node->op().op_conf().has_logical_chain_id()) {\n      const int64_t logical_chain_id = node->op().op_conf().logical_chain_id();\n      max_logical_chain_id = std::max(max_logical_chain_id, logical_chain_id);\n      PlacementLogicalChainsInfo* info = FindOrCreatePlacementLogicalChainsInfo(node);\n      if (has_acc && this_time_shape->elem_cnt() == 1) {\n        // acc logical chain\n        if (info->after_acc_logical_chain.get() == nullptr) {\n          info->after_acc_logical_chain = std::make_shared<LogicalChain>(logical_chain_id);\n        }\n        info->after_acc_logical_chain->ordered_op_nodes.push_back(node);\n        CHECK_EQ(info->after_acc_logical_chain->logical_chain_id, logical_chain_id);\n      } else {\n        // fw/bw logical chain\n        bool find_chain = false;\n        for (const auto& logical_chain : info->ordered_logical_chains) {\n          if (logical_chain->logical_chain_id == logical_chain_id) {\n            logical_chain->ordered_op_nodes.push_back(node);\n            find_chain = true;\n            break;\n          }\n        }\n        if (!find_chain) {\n          info->ordered_logical_chains.push_back(std::make_shared<LogicalChain>(logical_chain_id));\n          
info->ordered_logical_chains.back()->ordered_op_nodes.push_back(node);\n          CHECK_EQ(info->ordered_logical_chains.back()->logical_chain_id, logical_chain_id);\n        }\n      }\n    }\n  }\n\n  VLOG(2) << \" seed time shape = \" << seed_time_shape->ToString();\n\n  std::vector<HashSet<const OpNode*>> logical_chains;\n  GetLogicalChainsWithTimeShape(&logical_chains, ordered_op_nodes, seed_time_shape);\n  if (logical_chains.empty() && placement2logical_chains.empty()) { return Maybe<void>::Ok(); }\n\n  auto CmpOpNodeOrder = [&](const OpNode* lhs, const OpNode* rhs) {\n    return op_node2global_order.at(lhs) < op_node2global_order.at(rhs);\n  };\n  auto CmpLogicalChainOrder = [&](const std::shared_ptr<LogicalChain>& lhs,\n                                  const std::shared_ptr<LogicalChain>& rhs) {\n    int64_t lhs_begin_op_global_order = op_node2global_order.at(lhs->ordered_op_nodes.front());\n    int64_t rhs_begin_op_global_order = op_node2global_order.at(rhs->ordered_op_nodes.front());\n    return lhs_begin_op_global_order < rhs_begin_op_global_order;\n  };\n  auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable();\n\n  for (const auto& origin_logical_chain : logical_chains) {\n    const OpNode* rand_node = *origin_logical_chain.begin();\n    PlacementLogicalChainsInfo* info = FindOrCreatePlacementLogicalChainsInfo(rand_node);\n    info->ordered_logical_chains.emplace_back(\n        std::make_shared<LogicalChain>(++max_logical_chain_id));\n    InitPlacementLogicalChainsInfoFromSet(info->ordered_logical_chains.back(), origin_logical_chain,\n                                          op_node2global_order, CmpOpNodeOrder);\n  }\n\n  for (auto& pair : placement2logical_chains) {\n    std::sort(pair.second.ordered_logical_chains.begin(), pair.second.ordered_logical_chains.end(),\n              CmpLogicalChainOrder);\n  }\n\n  for (const OpNode* this_node : ordered_op_nodes) {\n    if (IsAccOpNode(this_node)) {\n      const ParallelDesc& this_parallel_desc = this_node->parallel_desc();\n      std::string key = GenParallelConfKey(this_parallel_desc.parallel_conf());\n      auto it = placement2logical_chains.find(key);\n      if (it != placement2logical_chains.end()) {\n        it->second.ordered_acc_op_nodes.emplace_back(this_node);\n      }\n    }\n  }\n\n  auto InsertCtrlEdgeInChain = [&](const std::vector<const OpNode*>& ordered_op_nodes) {\n    for (int64_t i = 1; i < ordered_op_nodes.size(); ++i) {\n      const OpNode* this_node = CHECK_JUST(VectorAt(ordered_op_nodes, i));\n      const OpNode* prev_node = CHECK_JUST(VectorAt(ordered_op_nodes, i - 1));\n      const std::string& this_op_name = this_node->op().op_name();\n      const std::string& prev_op_name = prev_node->op().op_name();\n      if (!IsReachable(prev_op_name, this_op_name)) {\n        CHECK_JUST(MapAt(mut_op_name2conf, this_op_name)).add_ctrl_in_op_name(prev_op_name);\n      }\n    }\n  };\n\n  auto InsertLogicalChainId = [&](const std::vector<const OpNode*>& ordered_op_nodes,\n                                  const int64_t logical_chain_id) {\n    int64_t order = 0;\n    for (const OpNode* op_node : ordered_op_nodes) {\n      auto& conf = CHECK_JUST(MapAt(mut_op_name2conf, op_node->op().op_name()));\n      conf.set_logical_chain_id(logical_chain_id);\n      conf.set_order_in_logical_chain(order++);\n    }\n  };\n\n  HashSet<int64_t> exist_chain_ids;\n  for (auto& pair : placement2logical_chains) {\n    const auto& placement = pair.first;\n    auto& info = pair.second;\n    
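// For this placement: tag every fw/bw chain op with its chain id and order, add the\n    // intra-chain ctrl edges, then build the after-acc chain and try to merge it into the\n    // largest fw/bw chain.\n    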
CHECK_GE(info.ordered_logical_chains.size(), 1);\n\n    // NOTE(chengcheng): set logical chain id for each op in each logical chain, and insert ctrl\n    //   edge for order.\n    for (auto& logical_chain : info.ordered_logical_chains) {\n      CHECK_GE(logical_chain->logical_chain_id, 0);\n      CHECK(exist_chain_ids.insert(logical_chain->logical_chain_id).second);\n      InsertLogicalChainId(logical_chain->ordered_op_nodes, logical_chain->logical_chain_id);\n      InsertCtrlEdgeInChain(logical_chain->ordered_op_nodes);\n    }\n\n    for (const auto& logical_chain : info.ordered_logical_chains) {\n      VLOG(3) << \" In placement: \" << placement\n              << \" logical_chain_id: \" << logical_chain->logical_chain_id\n              << \" has op num = \" << logical_chain->ordered_op_nodes.size();\n\n      for (int i = 0; i < logical_chain->ordered_op_nodes.size(); ++i) {\n        const OpNode* ordered_op = JUST(VectorAt(logical_chain->ordered_op_nodes, i));\n        VLOG(3) << \" ChainId: \" << logical_chain->logical_chain_id << \" order: \" << i\n                << \" op_name: \" << ordered_op->op().op_name()\n                << \" global_order: \" << JUST(MapAt(op_node2global_order, ordered_op));\n      }\n    }\n\n    // NOTE(chengcheng): create logical chain after acc, and merge with max logical chain.\n    const std::vector<const OpNode*>& ordered_acc_op_nodes = info.ordered_acc_op_nodes;\n    if (!ordered_acc_op_nodes.empty()) {\n      if (info.after_acc_logical_chain.get() == nullptr) {\n        info.after_acc_logical_chain = std::make_shared<LogicalChain>(++max_logical_chain_id);\n        CreateAfterAccLogicalChain(info.after_acc_logical_chain, ordered_acc_op_nodes,\n                                   *info.seed_parallel_desc);\n      }\n      CHECK_GE(info.after_acc_logical_chain->logical_chain_id, 0);\n      CHECK(exist_chain_ids.insert(info.after_acc_logical_chain->logical_chain_id).second);\n      auto& acc_chain_order_ops = info.after_acc_logical_chain->ordered_op_nodes;\n      if (acc_chain_order_ops.size() > 1) {\n        std::sort(acc_chain_order_ops.begin(), acc_chain_order_ops.end(), CmpOpNodeOrder);\n\n        TryMergeAfterAccLogicalChainToMaxLogicalChain(&info, &mut_op_name2conf, job_builder,\n                                                      IsReachable, seed_time_shape);\n\n        if (acc_chain_order_ops.size() <= 1) { continue; }\n\n        VLOG(3) << \" In placement: \" << placement\n                << \" AccLogicalChain: \" << info.after_acc_logical_chain->logical_chain_id\n                << \" has op num = \" << acc_chain_order_ops.size();\n\n        for (int i = 0; i < acc_chain_order_ops.size(); ++i) {\n          const OpNode* ordered_op = JUST(VectorAt(acc_chain_order_ops, i));\n          VLOG(3) << \" AfterAccChainId: \" << info.after_acc_logical_chain->logical_chain_id\n                  << \" order: \" << i << \" op_name: \" << ordered_op->op().op_name()\n                  << \" global_order: \" << JUST(MapAt(op_node2global_order, ordered_op));\n        }\n\n        InsertLogicalChainId(acc_chain_order_ops, info.after_acc_logical_chain->logical_chain_id);\n        InsertCtrlEdgeInChain(acc_chain_order_ops);\n      }\n    }\n  }\n\n  // NOTE(chengcheng): update global order and chain id for ops.\n  for (const auto& pair : mut_op_name2conf) { JUST(job_builder->MutOpOnlyOnce(pair.second)); }\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"LogicalChainPass\", LogicalChainPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/momentum_optm.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/job_rewriter/optimizer.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nvoid GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node,\n                             const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf,\n                             JobBuilder* job_builder) {\n  const VariableOp* var_op = dynamic_cast<const VariableOp*>(&var_op_node.op());\n  CHECK_NOTNULL(var_op);\n  const std::string op_name = var_op->op_name() + \"-momentum\";\n  OperatorConf momentum_var(var_op->op_conf());\n  const bool has_snapshot_path =\n      job_builder->job().job_conf().has_default_initialize_with_snapshot_path();\n  std::string file_path;\n  if (has_snapshot_path) {\n    file_path = JoinPath(job_builder->job().job_conf().default_initialize_with_snapshot_path(),\n                         op_name, \"out\");\n  }\n  if (has_snapshot_path && SnapshotFS()->FileExists(file_path)) {\n    VLOG(3) << \"file_path: \" << file_path;\n    momentum_var.mutable_variable_conf()->mutable_initialize_with_snapshot()->set_path(\n        JoinPath(job_builder->job().job_conf().default_initialize_with_snapshot_path(), op_name));\n    momentum_var.mutable_variable_conf()->mutable_initialize_with_snapshot()->set_key(\"out\");\n  } else {\n    if (has_snapshot_path) { VLOG(3) << file_path << \" not found, will be initialized\"; }\n    InitializerConf constant_initializer;\n    constant_initializer.mutable_constant_conf()->set_value(0.f);\n    *(momentum_var.mutable_variable_conf()->mutable_initializer()) = constant_initializer;\n  }\n  momentum_var.set_name(op_name);\n  momentum_var.mutable_variable_conf()->set_out(\"out\");\n  momentum_var.set_scope_symbol_id(var_op->op_conf().scope_symbol_id());\n  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {momentum_var});\n\n  user_op::UserOpConfWrapperBuilder momentum_update_op_builder(var_op->op_name() + \"_optimizer\");\n  momentum_update_op_builder.OpTypeName(\"momentum_update\")\n      .Input(\"model\", GenLogicalBlobName(var_op->BnInOp2Lbi(\"out\")))\n      .Input(\"model_diff\", model_diff_lbn)\n      .Input(\"learning_rate\", optimizer_conf.learning_rate_lbn())\n      .Input(\"momentum\", GenLogicalBlobName(op_name, momentum_var.variable_conf().out()))\n      .Attr<float>(\"beta\", optimizer_conf.momentum_conf().beta())\n      .Attr<float>(\"dampening\", optimizer_conf.momentum_conf().dampening())\n      .Attr<bool>(\"nesterov\", optimizer_conf.momentum_conf().nesterov())\n      .Attr<bool>(\"maximize\", optimizer_conf.momentum_conf().maximize())\n      .Attr<float>(\"weight_decay\", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))\n      .ScopeSymbolId(var_op->op_conf().scope_symbol_id());\n  if (optimizer_conf.has_lr_scale()) {\n    momentum_update_op_builder.Attr<float>(\"learning_rate_scale\", 
optimizer_conf.lr_scale());\n  }\n  SetDynamicLossScaleSkipIf(ctx, &momentum_update_op_builder);\n  user_op::UserOpConfWrapper momentum_update_op = momentum_update_op_builder.Build();\n  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {momentum_update_op.op_conf()});\n}\n\n}  // namespace\n\nREGISTER_OPTIMIZER(OptimizerConf::kMomentumConf, &GenerateOptimizerOpConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/multi_tensor_model_update.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nstruct SGDOptimizerKey {\n  std::string learning_rate;\n  std::string scale_by_tensor_lbn;\n  std::string skip_if_lbn;\n  double scale;\n  float l1;\n  float l2;\n  float weight_decay;\n  ParallelConf parallel_conf;\n  bool has_model_copy;\n  /*\n  In fuse_model_update_cast pass, not all the cast fp16 model_diff kernel can be fused,\n  it may cause some model diff type is float16, some is float32.\n  So here we need to use model_diff datatype as key to group.\n  */\n  DataType model_diff_dtype;\n};\n\nbool operator==(const SGDOptimizerKey& lhs, const SGDOptimizerKey& rhs) {\n  return (lhs.learning_rate == rhs.learning_rate)\n         && (lhs.scale_by_tensor_lbn == rhs.scale_by_tensor_lbn)\n         && (lhs.skip_if_lbn == rhs.skip_if_lbn) && (lhs.scale == rhs.scale) && (lhs.l1 == rhs.l1)\n         && (lhs.l2 == rhs.l2) && (lhs.weight_decay == rhs.weight_decay)\n         && (lhs.parallel_conf == rhs.parallel_conf) && (lhs.has_model_copy == rhs.has_model_copy)\n         && (lhs.model_diff_dtype == rhs.model_diff_dtype);\n}\n\nstruct AdamOptimizerKey {\n  std::string learning_rate;\n  std::string scale_by_tensor_lbn;\n  std::string skip_if_lbn;\n  double scale;\n  float l1;\n  float l2;\n  float beta1;\n  float beta2;\n  float epsilon;\n  float weight_decay;\n  bool amsgrad;\n  bool do_bias_correction;\n  ParallelConf parallel_conf;\n  bool has_model_copy;\n  DataType model_diff_dtype;\n};\n\nbool operator==(const AdamOptimizerKey& lhs, const AdamOptimizerKey& rhs) {\n  return (lhs.learning_rate == rhs.learning_rate)\n         && (lhs.scale_by_tensor_lbn == rhs.scale_by_tensor_lbn)\n         && (lhs.skip_if_lbn == rhs.skip_if_lbn) && (lhs.scale == rhs.scale) && (lhs.l1 == rhs.l1)\n         && (lhs.l2 == rhs.l2) && (lhs.beta1 == rhs.beta1) && (lhs.beta2 == rhs.beta2)\n         && (lhs.epsilon == rhs.epsilon) && (lhs.weight_decay == rhs.weight_decay)\n         && (lhs.amsgrad == rhs.amsgrad) && (lhs.do_bias_correction == rhs.do_bias_correction)\n         && (lhs.parallel_conf == rhs.parallel_conf) && (lhs.has_model_copy == rhs.has_model_copy)\n         && (lhs.model_diff_dtype == rhs.model_diff_dtype);\n}\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::SGDOptimizerKey> {\n  size_t operator()(const oneflow::SGDOptimizerKey& key) const {\n    const auto float_hash = std::hash<float>();\n    const auto double_hash = std::hash<float>();\n    const auto& string_hash = std::hash<std::string>();\n    const auto& parallel_conf_hash = std::hash<oneflow::ParallelConf>();\n    const auto& bool_hash = std::hash<bool>();\n    const auto& dtype_hash = std::hash<oneflow::DataType>();\n\n    size_t hash = string_hash(key.learning_rate);\n    oneflow::HashCombine(&hash, 
\n    oneflow::HashCombine(&hash, string_hash(key.scale_by_tensor_lbn));\n    oneflow::HashCombine(&hash, string_hash(key.skip_if_lbn));\n    oneflow::HashCombine(&hash, double_hash(key.scale));\n    oneflow::HashCombine(&hash, float_hash(key.l1));\n    oneflow::HashCombine(&hash, float_hash(key.l2));\n    oneflow::HashCombine(&hash, float_hash(key.weight_decay));\n    oneflow::HashCombine(&hash, parallel_conf_hash(key.parallel_conf));\n    oneflow::HashCombine(&hash, bool_hash(key.has_model_copy));\n    oneflow::HashCombine(&hash, dtype_hash(key.model_diff_dtype));\n    return hash;\n  }\n};\n\ntemplate<>\nstruct hash<oneflow::AdamOptimizerKey> {\n  size_t operator()(const oneflow::AdamOptimizerKey& key) const {\n    const auto& float_hash = std::hash<float>();\n    const auto& double_hash = std::hash<double>();\n    const auto& string_hash = std::hash<std::string>();\n    const auto& bool_hash = std::hash<bool>();\n    const auto& parallel_conf_hash = std::hash<oneflow::ParallelConf>();\n    const auto& dtype_hash = std::hash<oneflow::DataType>();\n\n    size_t hash = string_hash(key.learning_rate);\n    oneflow::HashCombine(&hash, string_hash(key.scale_by_tensor_lbn));\n    oneflow::HashCombine(&hash, string_hash(key.skip_if_lbn));\n    oneflow::HashCombine(&hash, double_hash(key.scale));\n    oneflow::HashCombine(&hash, float_hash(key.l1));\n    oneflow::HashCombine(&hash, float_hash(key.l2));\n    oneflow::HashCombine(&hash, float_hash(key.beta1));\n    oneflow::HashCombine(&hash, float_hash(key.beta2));\n    oneflow::HashCombine(&hash, float_hash(key.epsilon));\n    oneflow::HashCombine(&hash, float_hash(key.weight_decay));\n    oneflow::HashCombine(&hash, bool_hash(key.amsgrad));\n    oneflow::HashCombine(&hash, bool_hash(key.do_bias_correction));\n    oneflow::HashCombine(&hash, parallel_conf_hash(key.parallel_conf));\n    oneflow::HashCombine(&hash, bool_hash(key.has_model_copy));\n    oneflow::HashCombine(&hash, dtype_hash(key.model_diff_dtype));\n    return hash;\n  }\n};\n\n}  // namespace std\n\nnamespace oneflow {\n\nnamespace {\n\nvoid AddScaleAndSkipLbn(user_op::UserOpConfWrapperBuilder& multi_tensor_model_update_op_builder,\n                        const user_op::UserOpConfWrapper& model_update_user_conf) {\n  if (model_update_user_conf.has_input(\"scale_by_tensor\", 0)) {\n    multi_tensor_model_update_op_builder.Input(\"scale_by_tensor\",\n                                               model_update_user_conf.input(\"scale_by_tensor\", 0));\n  }\n  if (model_update_user_conf.has_input(\"skip_if\", 0)) {\n    multi_tensor_model_update_op_builder.Input(\"skip_if\",\n                                               model_update_user_conf.input(\"skip_if\", 0));\n  }\n}\n\nvoid AddProcessedVariable(HashSet<std::string>& processed_variable_list,\n                          const user_op::UserOpConfWrapper& model_update_user_conf) {\n  /*\n  Every variable op is visited by this pass, and one optimizer may own several\n  variables; the Adam optimizer, for example, keeps 3 buffers per parameter (model, m and v).
\n  The fused multi-tensor op handles all of them at once, so without this filter the\n  same buffers would be appended 3 times to the multi_tensor_update kernel.\n\n  A HashSet records which variables have already been processed.\n  */\n  processed_variable_list.emplace(model_update_user_conf.input(\"model\", 0));\n  if (model_update_user_conf.op_type_name() == \"adam_update\") {\n    processed_variable_list.emplace(model_update_user_conf.input(\"m\", 0));\n    processed_variable_list.emplace(model_update_user_conf.input(\"v\", 0));\n  }\n}\n\nbool IfVariableProcessed(const HashSet<std::string>& processed_variable_list,\n                         const user_op::UserOpConfWrapper& model_update_user_conf) {\n  if (model_update_user_conf.op_type_name() == \"sgd_update\") {\n    const auto& processed_model_iter =\n        processed_variable_list.find(model_update_user_conf.input(\"model\", 0));\n    if (processed_model_iter != processed_variable_list.end()) { return true; }\n  } else if (model_update_user_conf.op_type_name() == \"adam_update\") {\n    const auto& processed_model_iter =\n        processed_variable_list.find(model_update_user_conf.input(\"model\", 0));\n    const auto& processed_m_iter =\n        processed_variable_list.find(model_update_user_conf.input(\"m\", 0));\n    const auto& processed_v_iter =\n        processed_variable_list.find(model_update_user_conf.input(\"v\", 0));\n    if (processed_model_iter != processed_variable_list.end()\n        && processed_m_iter != processed_variable_list.end()\n        && processed_v_iter != processed_variable_list.end()) {\n      return true;\n    }\n  } else {\n    UNIMPLEMENTED() << \"The current optimizer does not support multi tensor update. \";\n  }\n  return false;\n}\n\nclass MultiTensorModelUpdatePass final : public JobPass {\n public:\n  MultiTensorModelUpdatePass() = default;\n  ~MultiTensorModelUpdatePass() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return ctx.job_desc().enable_multi_tensor_update()\n           || ParseBooleanFromEnv(\"ONEFLOW_ENABLE_MULTI_TENSOR_MODEL_UPDATE\", false);\n  }\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n};\n\nMaybe<void> MultiTensorModelUpdatePass::Apply(const OpGraph& op_graph,\n                                              JobBuilder* job_builder) const {\n  if (!job_builder->job().job_conf().has_train_conf()) { return Maybe<void>::Ok(); }\n  std::vector<OperatorConf> delete_ops;\n  ParallelConf parallel_conf{};\n  HashMap<SGDOptimizerKey, user_op::UserOpConfWrapperBuilder> multi_tensor_sgd_update_hashmap;\n  HashMap<AdamOptimizerKey, user_op::UserOpConfWrapperBuilder> multi_tensor_adam_update_hashmap;\n  HashSet<std::string> processed_variable_list{};\n\n  op_graph.ForEachNode([&](OpNode* op_node) {\n    const auto& op_conf = op_node->op().op_conf();\n    if (!op_conf.has_variable_conf()) { return; }\n\n    for (OpEdge* find_model_update_edge : op_node->out_edges()) {\n      OpNode* find_model_update_update_node = find_model_update_edge->dst_node();
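\n      // A variable may feed several consumers; only an sgd_update / adam_update\n      // consumer is eligible for fusion, so skip every other out edge.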
\"adam_update\")) {\n        continue;\n      }\n      const user_op::UserOpConfWrapper model_update_user_conf(\n          find_model_update_update_node->op().op_conf());\n      // Multi tensor update pass only support for CUDA currently.\n      if (find_model_update_update_node->parallel_desc().device_type() != DeviceType::kCUDA) {\n        continue;\n      }\n\n      // Multi tensor update pass only support Data Parallel.\n      bool if_data_parallel = true;\n      for (const auto& pair :\n           find_model_update_update_node->sbp_signature().bn_in_op2sbp_parallel()) {\n        if (!pair.second.has_broadcast_parallel()) {\n          if_data_parallel = false;\n          break;\n        }\n      }\n      if (!if_data_parallel) { continue; }\n\n      // Check the variable has been processed before.\n      if (IfVariableProcessed(processed_variable_list, model_update_user_conf)) { continue; }\n\n      delete_ops.emplace_back(find_model_update_update_node->op().op_conf());\n      parallel_conf = find_model_update_update_node->parallel_desc().parallel_conf();\n\n      std::string scale_by_tensor_lbn = \"\";\n      std::string skip_if_lbn = \"\";\n      bool has_model_copy = false;\n      if (model_update_user_conf.has_input(\"scale_by_tensor\", 0)) {\n        scale_by_tensor_lbn = model_update_user_conf.input(\"scale_by_tensor\", 0);\n      }\n      if (model_update_user_conf.has_input(\"skip_if\", 0)) {\n        skip_if_lbn = model_update_user_conf.input(\"skip_if\", 0);\n      }\n      if (model_update_user_conf.has_input(\"model_copy\", 0)) { has_model_copy = true; }\n\n      const BlobDesc& model_diff_blob_desc = op_graph.GetLogicalBlobDesc(\n          GenLogicalBlobId(model_update_user_conf.input(\"model_diff\", 0)));\n      const DataType model_diff_dtype = model_diff_blob_desc.data_type();\n\n      if (IsUserOpWithTypeName(find_model_update_update_node->op().op_conf(), \"sgd_update\")) {\n        SGDOptimizerKey key{model_update_user_conf.input(\"learning_rate\", 0),\n                            scale_by_tensor_lbn,\n                            skip_if_lbn,\n                            model_update_user_conf.attr<double>(\"scale\"),\n                            model_update_user_conf.attr<float>(\"l1\"),\n                            model_update_user_conf.attr<float>(\"l2\"),\n                            model_update_user_conf.attr<float>(\"weight_decay\"),\n                            parallel_conf,\n                            has_model_copy,\n                            model_diff_dtype};\n        const auto& iter = multi_tensor_sgd_update_hashmap.find(key);\n\n        if (iter != multi_tensor_sgd_update_hashmap.end()) {\n          iter->second.Input(\"model\", model_update_user_conf.input(\"model\", 0))\n              .Input(\"model_diff\", model_update_user_conf.input(\"model_diff\", 0));\n          if (has_model_copy) {\n            iter->second.Input(\"model_copy\", model_update_user_conf.input(\"model_copy\", 0));\n          }\n        } else {\n          user_op::UserOpConfWrapperBuilder multi_tensor_sgd_update_op_builder(\n              \"multi_tensor_model_update\" + NewUniqueId());\n          std::string op_type_name = \"multi_tensor_sgd_update\";\n          if (has_model_copy) { op_type_name = \"multi_tensor_sgd_update_with_cast\"; }\n\n          multi_tensor_sgd_update_op_builder.OpTypeName(op_type_name)\n              .Input(\"model\", model_update_user_conf.input(\"model\", 0))\n              .Input(\"model_diff\", model_update_user_conf.input(\"model_diff\", 0))\n   
           .Input(\"learning_rate\", model_update_user_conf.input(\"learning_rate\", 0))\n              .Attr<double>(\"scale\", model_update_user_conf.attr<double>(\"scale\"))\n              .Attr<float>(\"l1\", model_update_user_conf.attr<float>(\"l1\"))\n              .Attr<float>(\"l2\", model_update_user_conf.attr<float>(\"l2\"))\n              .Attr<float>(\"weight_decay\", model_update_user_conf.attr<float>(\"weight_decay\"));\n          if (has_model_copy) {\n            multi_tensor_sgd_update_op_builder.Input(\"model_copy\",\n                                                     model_update_user_conf.input(\"model_copy\", 0));\n          }\n\n          AddScaleAndSkipLbn(multi_tensor_sgd_update_op_builder, model_update_user_conf);\n\n          CHECK(model_update_user_conf.op_conf().has_scope_symbol_id());\n          multi_tensor_sgd_update_op_builder.ScopeSymbolId(\n              model_update_user_conf.op_conf().scope_symbol_id());\n          multi_tensor_sgd_update_hashmap.emplace(key, multi_tensor_sgd_update_op_builder);\n        }\n      } else if (IsUserOpWithTypeName(find_model_update_update_node->op().op_conf(),\n                                      \"adam_update\")) {\n        AdamOptimizerKey key{model_update_user_conf.input(\"learning_rate\", 0),\n                             scale_by_tensor_lbn,\n                             skip_if_lbn,\n                             model_update_user_conf.attr<double>(\"scale\"),\n                             model_update_user_conf.attr<float>(\"l1\"),\n                             model_update_user_conf.attr<float>(\"l2\"),\n                             model_update_user_conf.attr<float>(\"beta1\"),\n                             model_update_user_conf.attr<float>(\"beta2\"),\n                             model_update_user_conf.attr<float>(\"epsilon\"),\n                             model_update_user_conf.attr<float>(\"weight_decay\"),\n                             model_update_user_conf.attr<bool>(\"amsgrad\"),\n                             model_update_user_conf.attr<bool>(\"do_bias_correction\"),\n                             parallel_conf,\n                             has_model_copy,\n                             model_diff_dtype};\n        if (key.amsgrad) {\n          UNIMPLEMENTED() << \"Multi Tensor Adam update do not support amsgrad = True. 
\";\n        }\n        const auto& iter = multi_tensor_adam_update_hashmap.find(key);\n\n        if (iter != multi_tensor_adam_update_hashmap.end()) {\n          iter->second.Input(\"model\", model_update_user_conf.input(\"model\", 0))\n              .Input(\"model_diff\", model_update_user_conf.input(\"model_diff\", 0))\n              .Input(\"m\", model_update_user_conf.input(\"m\", 0))\n              .Input(\"v\", model_update_user_conf.input(\"v\", 0));\n          if (has_model_copy) {\n            iter->second.Input(\"model_copy\", model_update_user_conf.input(\"model_copy\", 0));\n          }\n          if (model_update_user_conf.attr<bool>(\"do_bias_correction\")) {\n            iter->second\n                .Input(\"bias_correction1\", model_update_user_conf.input(\"bias_correction1\", 0))\n                .Input(\"bias_correction2\", model_update_user_conf.input(\"bias_correction2\", 0));\n          }\n        } else {\n          user_op::UserOpConfWrapperBuilder multi_tensor_adam_update_op_builder(\n              \"multi_tensor_model_update\" + NewUniqueId());\n          std::string op_type_name = \"multi_tensor_adam_update\";\n          if (has_model_copy) { op_type_name = \"multi_tensor_adam_update_with_cast\"; }\n          multi_tensor_adam_update_op_builder.OpTypeName(op_type_name)\n              .Input(\"model\", model_update_user_conf.input(\"model\", 0))\n              .Input(\"model_diff\", model_update_user_conf.input(\"model_diff\", 0))\n              .Input(\"m\", model_update_user_conf.input(\"m\", 0))\n              .Input(\"v\", model_update_user_conf.input(\"v\", 0))\n              .Input(\"learning_rate\", model_update_user_conf.input(\"learning_rate\", 0))\n              .Attr<double>(\"scale\", model_update_user_conf.attr<double>(\"scale\"))\n              .Attr<float>(\"l1\", model_update_user_conf.attr<float>(\"l1\"))\n              .Attr<float>(\"l2\", model_update_user_conf.attr<float>(\"l2\"))\n              .Attr<float>(\"beta1\", model_update_user_conf.attr<float>(\"beta1\"))\n              .Attr<float>(\"beta2\", model_update_user_conf.attr<float>(\"beta2\"))\n              .Attr<float>(\"epsilon\", model_update_user_conf.attr<float>(\"epsilon\"))\n              .Attr<float>(\"weight_decay\", model_update_user_conf.attr<float>(\"weight_decay\"))\n              .Attr<bool>(\"amsgrad\", model_update_user_conf.attr<bool>(\"amsgrad\"))\n              .Attr<bool>(\"do_bias_correction\",\n                          model_update_user_conf.attr<bool>(\"do_bias_correction\"));\n\n          if (model_update_user_conf.attr<bool>(\"do_bias_correction\")) {\n            multi_tensor_adam_update_op_builder\n                .Input(\"bias_correction1\", model_update_user_conf.input(\"bias_correction1\", 0))\n                .Input(\"bias_correction2\", model_update_user_conf.input(\"bias_correction2\", 0));\n          }\n          if (has_model_copy) {\n            multi_tensor_adam_update_op_builder.Input(\n                \"model_copy\", model_update_user_conf.input(\"model_copy\", 0));\n          }\n          AddScaleAndSkipLbn(multi_tensor_adam_update_op_builder, model_update_user_conf);\n\n          CHECK(model_update_user_conf.op_conf().has_scope_symbol_id());\n          multi_tensor_adam_update_op_builder.ScopeSymbolId(\n              model_update_user_conf.op_conf().scope_symbol_id());\n          multi_tensor_adam_update_hashmap.emplace(key, multi_tensor_adam_update_op_builder);\n        }\n      } else {\n        UNIMPLEMENTED() << \"Current Optimizer do not 
\n      AddProcessedVariable(processed_variable_list, model_update_user_conf);\n      break;\n    }\n  });\n  // Each fused op is placed according to the parallel_conf stored in its own key.\n  for (auto& op : multi_tensor_sgd_update_hashmap) {\n    auto multi_tensor_model_update_sgd_op = op.second.Build();\n    job_builder->AddOps(op.first.parallel_conf, {multi_tensor_model_update_sgd_op.op_conf()});\n  }\n  for (auto& op : multi_tensor_adam_update_hashmap) {\n    auto multi_tensor_model_update_adam_op = op.second.Build();\n    job_builder->AddOps(op.first.parallel_conf, {multi_tensor_model_update_adam_op.op_conf()});\n  }\n  job_builder->DelOps(delete_ops);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"MultiTensorModelUpdatePass\", MultiTensorModelUpdatePass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/nccl_logical_chain_strict_order_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#if defined(WITH_CUDA) || defined(WITH_NPU) || defined(WITH_MLU)\n#include \"oneflow/core/auto_parallel/auto_memory.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job_rewriter/calculation_pass.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass NcclLogicalChainStrictOrderPass final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NcclLogicalChainStrictOrderPass);\n  NcclLogicalChainStrictOrderPass() = default;\n  ~NcclLogicalChainStrictOrderPass() = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream();\n  }\n\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n};\n\nbool IsAccOrPackOpNode(const OpNode* node) {\n  const auto& op_conf = node->op().op_conf();\n  return op_conf.has_user_conf()\n         && (op_conf.user_conf().op_type_name() == \"acc\"\n             || op_conf.user_conf().op_type_name() == \"pack\");\n}\n\nMaybe<void> InsertCtrlOpBetweenBwChainAndAccChain(\n    HashMap<std::string, OperatorConf>* mut_op_name2conf, JobBuilder* job_builder,\n    const std::vector<const OpNode*>& ordered_op_nodes,\n    const std::function<bool(const std::string&, const std::string&)>& IsReachable) {\n  HashMap<std::string, const OpNode*> placement2last_normal_node;\n  HashMap<std::string, const OpNode*> placement2first_after_acc_node;\n  int64_t acc_num = job_builder->job().job_conf().num_gradient_accumulation_steps();\n\n  for (int32_t global_order = 0; global_order < ordered_op_nodes.size(); global_order++) {\n    const OpNode* node = JUST(VectorAt(ordered_op_nodes, global_order));\n    if (!node->op().op_conf().has_logical_chain_id()) { continue; }\n    const int64_t time_shape_cnt =\n        CHECK_JUST(node->op().GetInputOutputFastestTimeShape())->elem_cnt();\n    CHECK(time_shape_cnt == acc_num || time_shape_cnt == 1)\n        << \" invalid time shape count = \" << time_shape_cnt << \" which should be : [ \" << acc_num\n        << \" , 1 ]\";\n    std::string placement_key = GenParallelConfKey(node->parallel_desc().parallel_conf());\n    if (time_shape_cnt == acc_num) {\n      // for all fw/bw chains in this placement\n      placement2last_normal_node[placement_key] 
= node;  // create or update\n    } else {\n      // acc chain\n      if (placement2first_after_acc_node.find(placement_key)\n          == placement2first_after_acc_node.end()) {\n        CHECK(placement2first_after_acc_node.emplace(placement_key, node).second);\n      }\n    }\n  }\n\n  for (const auto& pair : placement2last_normal_node) {\n    if (placement2first_after_acc_node.find(pair.first) == placement2first_after_acc_node.end()) {\n      continue;\n    }\n    const OpNode* last_bw_node = pair.second;\n    const OpNode* first_after_acc_node = JUST(MapAt(placement2first_after_acc_node, pair.first));\n    const std::string& last_bw_op_name = last_bw_node->op().op_name();\n    const std::string& first_after_acc_op_name = first_after_acc_node->op().op_name();\n\n    CHECK_OR_RETURN(!IsReachable(first_after_acc_op_name, last_bw_op_name))\n        << Error::RuntimeError()\n        << \" Error! Cycle control edge from first acc chain op: \" << first_after_acc_op_name\n        << \" to last bw chain sink op: \" << last_bw_op_name;\n\n    const auto& bw_sink_obns = last_bw_node->op().output_bns();\n    CHECK_OR_RETURN(!bw_sink_obns.empty());\n    const std::string bw_sink_lbn =\n        GenLogicalBlobName(last_bw_node->op().BnInOp2Lbi(bw_sink_obns.Get(0)));\n    VLOG(3) << \" bw_sink_lbn : \" << bw_sink_lbn;\n\n    user_op::UserOpConfWrapper cast_to_tick_op =\n        user_op::UserOpConfWrapperBuilder(\"Sys-LastNcclChainSink-CastToTick-\" + NewUniqueId())\n            .OpTypeName(\"cast_to_tick\")\n            .Input(\"in\", bw_sink_lbn)\n            .Output(\"out\")\n            .ScopeSymbolId(last_bw_node->op().op_conf().scope_symbol_id())\n            .Build();\n\n    JUST(job_builder->AddOp(last_bw_node->parallel_desc().parallel_conf(),\n                            cast_to_tick_op.op_conf()));\n\n    std::string acc_tick_output_lbn = cast_to_tick_op.output(\"out\", 0);\n    if (!IsAccOrPackOpNode(last_bw_node)) {\n      // NOTE(chengcheng): An acc op can be merged into the fw/bw chain; if the last op is\n      //  already an acc op, there is no need to (and we cannot) insert an acc tick op.\n\n      OperatorConf sink_acc_tick_conf;\n      sink_acc_tick_conf.set_name(std::string(\"Sys-LastNcclChainSink-AccTick_\") + NewUniqueId());\n      sink_acc_tick_conf.set_scope_symbol_id(last_bw_node->op().op_conf().scope_symbol_id());\n      auto* acc_conf = sink_acc_tick_conf.mutable_acc_tick_conf();\n      acc_conf->set_one(acc_tick_output_lbn);\n      acc_conf->set_acc(\"acc\");\n      acc_conf->set_max_acc_num(acc_num);\n\n      acc_tick_output_lbn = GenLogicalBlobName(sink_acc_tick_conf.name(), \"acc\");\n\n      VLOG(3) << \" insert acc tick op : \" << sink_acc_tick_conf.name()\n              << \" of last op in fw/bw chain.\";\n\n      JUST(job_builder->AddOp(last_bw_node->parallel_desc().parallel_conf(), sink_acc_tick_conf));\n    }\n\n    OperatorConf sink_final_tick_conf;\n    sink_final_tick_conf.set_name(std::string(\"Sys-LastNcclChainSink-FinalTick-DeviceTick_\")\n                                  + NewUniqueId());\n    sink_final_tick_conf.set_scope_symbol_id(last_bw_node->op().op_conf().scope_symbol_id());\n    auto* tick_conf = sink_final_tick_conf.mutable_device_tick_conf();\n    tick_conf->add_tick(acc_tick_output_lbn);\n    tick_conf->set_out(\"out\");\n\n    JUST(job_builder->AddOp(last_bw_node->parallel_desc().parallel_conf(), sink_final_tick_conf));\n\n    if (mut_op_name2conf->find(first_after_acc_op_name) == mut_op_name2conf->end()) {\n      mut_op_name2conf->emplace(first_after_acc_op_name, 
first_after_acc_node->op().op_conf());\n    }\n    JUST(MapAt(*mut_op_name2conf, first_after_acc_op_name))\n        .add_ctrl_in_op_name(sink_final_tick_conf.name());\n\n    VLOG(2) << \" In: \" << pair.first << \" , insert ctrl edge from: [ \" << last_bw_op_name\n            << \" ] to: [ \" << first_after_acc_op_name << \" ]\";\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NcclLogicalChainStrictOrderPass::Apply(const OpGraph& op_graph,\n                                                   JobBuilder* job_builder) const {\n  HashMap<int64_t, const OpNode*> nccl_chain_id2cur_last_node;\n  HashMap<std::string, OperatorConf> mut_op_name2conf;\n  auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable();\n\n  std::vector<const OpNode*> ordered_op_nodes;\n  if (ParseBooleanFromEnv(\"DISABLE_LOGICAL_STRAIGHTEN\", false)) {\n    op_graph.TopoForEachNodeWithCtrlEdge(\n        [&](const OpNode* node) { ordered_op_nodes.emplace_back(node); });\n  } else {\n    auto_parallel::StraightenOpGraph(op_graph, &ordered_op_nodes);\n  }\n\n  for (int32_t global_order = 0; global_order < ordered_op_nodes.size(); global_order++) {\n    const OpNode* node = JUST(VectorAt(ordered_op_nodes, global_order));\n    if (!node->op().op_conf().has_logical_chain_id()) { continue; }\n    const int64_t logical_chain_id = node->op().op_conf().logical_chain_id();\n\n    // add ctrl edge for strict order\n    auto it = nccl_chain_id2cur_last_node.find(logical_chain_id);\n    if (it == nccl_chain_id2cur_last_node.end()) {\n      nccl_chain_id2cur_last_node.emplace(logical_chain_id, node);\n    } else {\n      const std::string& this_op_name = node->op().op_name();\n      const std::string& prev_op_name = it->second->op().op_name();\n      if (!IsReachable(prev_op_name, this_op_name)) {\n        CHECK(mut_op_name2conf.emplace(this_op_name, node->op().op_conf()).second);\n        JUST(MapAt(mut_op_name2conf, this_op_name)).add_ctrl_in_op_name(prev_op_name);\n      }\n      it->second = node;\n    }\n  }\n\n  if (job_builder->job().job_conf().num_gradient_accumulation_steps() > 1) {\n    JUST(InsertCtrlOpBetweenBwChainAndAccChain(&mut_op_name2conf, job_builder, ordered_op_nodes,\n                                               IsReachable));\n  }\n\n  for (const auto& pair : mut_op_name2conf) { JUST(job_builder->MutOpOnlyOnce(pair.second)); }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"NcclLogicalChainStrictOrderPass\", NcclLogicalChainStrictOrderPass);\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA || WITH_NPU || WITH_MLU\n"
  },
  {
    "path": "oneflow/core/job_rewriter/nccl_logical_op_fusion_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#if defined(WITH_CUDA) || defined(WITH_NPU) || defined(WITH_MLU)\n#include \"oneflow/core/auto_parallel/auto_memory.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job_rewriter/calculation_pass.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/common/env_var/env_var.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/user/ops/nccl_logical_util.h\"\n\nnamespace oneflow {\n\n// nccl fusion bucket size 500MiB.\nDEFINE_ENV_INTEGER(ONEFLOW_GRAPH_NCCL_LOGICAL_FUSION_BUCKET_SIZE, 5e8);\n\nnamespace {\n\nclass NcclLogicalOpFusionPass final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NcclLogicalOpFusionPass);\n  NcclLogicalOpFusionPass() = default;\n  ~NcclLogicalOpFusionPass() = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()\n           && EnableNcclLogicalFusion();\n  }\n\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n};\n\nconst std::string kNcclLogicalFusionOpNamePrefix = \"Sys-NCCL-Logical-Fusion\";\n\nbool IsNcclLogicalOpNode(const OpNode* node) {\n  if (node->op().op_conf().has_user_conf()) {\n    const std::string& user_type_name = node->op().op_conf().user_conf().op_type_name();\n    if (user_type_name == \"_nccl_logical_all_reduce\"\n        || user_type_name == \"_nccl_logical_reduce_scatter\"\n        || user_type_name == \"_nccl_logical_reduce_scatter_noncontinuous\"\n        || user_type_name == \"_nccl_logical_all_gather\"\n        || user_type_name == \"_nccl_logical_all_gather_noncontinuous\"\n        || user_type_name == \"_nccl_logical_s2s\"\n        || user_type_name == \"_nccl_logical_2D_same_dim0_all_reduce\"\n        || user_type_name == \"_nccl_logical_2D_same_dim0_all_gather\"\n        || user_type_name == \"_nccl_logical_2D_same_dim0_all_gather_noncontinuous\"\n        || user_type_name == \"_nccl_logical_2D_same_dim0_all2all\"\n        || user_type_name == \"_nccl_logical_2D_same_dim1_all_reduce\"\n        /* || user_type_name == \"_nccl_logical_send_recv\" */) {\n      // TODO(chengcheng) : support nccl send/recv kernel\n      return true;\n    }\n  }\n  return false;\n}\n\nMaybe<void> ReplaceNcclOpsWithFusionOp(std::vector<OperatorConf>* nccl_fusion_ops,\n                      
                 std::vector<ParallelConf>* nccl_fusion_op_parallel_confs,\n                                       std::unordered_set<std::string>* del_ops,\n                                       HashMap<std::string, OperatorConf>* mut_op_name2conf,\n                                       const std::vector<const OpNode*>& nccl_ops) {\n  if (nccl_ops.size() <= 1) { return Maybe<void>::Ok(); }\n  const int32_t nccl_size = nccl_ops.size();\n  const OpNode* first_nccl = nccl_ops.front();\n  const OperatorConf& first_nccl_conf = first_nccl->op().op_conf();\n  const ParallelDesc& seed_placement = first_nccl->parallel_desc();\n  const int64_t scope_symbol_id = first_nccl_conf.scope_symbol_id();\n  std::vector<std::string> src_nd_sbp_str_list;\n  std::vector<std::string> dst_nd_sbp_str_list;\n  std::vector<std::string> nccl_type_list;\n  int64_t logical_chain_id = first_nccl_conf.logical_chain_id();\n  bool has_stream_name_hint = first_nccl_conf.has_stream_name_hint();\n  std::string stream_name_hint = first_nccl_conf.stream_name_hint();\n  user_op::UserOpConfWrapperBuilder fusion_builder =\n      user_op::UserOpConfWrapperBuilder(\"Sys-NCCL-fusion-\" + NewUniqueId());\n  fusion_builder.OpTypeName(\"_nccl_logical_fusion\");\n  for (const OpNode* nccl_op : nccl_ops) {\n    fusion_builder.Input(\"in\",\n                         GenLogicalBlobName(nccl_op->op().BnInOp2Lbi(nccl_op->op().SoleIbn())));\n    src_nd_sbp_str_list.push_back(\n        NdSbpToLongString(nccl_op->NdSbp4BnInOp(nccl_op->op().SoleIbn())));\n    dst_nd_sbp_str_list.push_back(\n        NdSbpToLongString(nccl_op->NdSbp4BnInOp(nccl_op->op().SoleObn())));\n    nccl_type_list.push_back(nccl_op->op().op_conf().user_conf().op_type_name());\n    CHECK(seed_placement == nccl_op->parallel_desc());\n    CHECK_EQ(has_stream_name_hint, nccl_op->op().op_conf().has_stream_name_hint());\n    CHECK_EQ(stream_name_hint, nccl_op->op().op_conf().stream_name_hint());\n    // 1. update del op\n    VLOG(3) << \" Del op: \" << nccl_op->op().op_conf().DebugString();\n    del_ops->insert(nccl_op->op().op_name());\n  }\n\n  auto fusion_nccl_op =\n      fusion_builder.Output(\"out\", nccl_size)\n          .Attr<std::vector<std::string>>(\"src_nd_sbp_str_list\", src_nd_sbp_str_list)\n          .Attr<std::vector<std::string>>(\"dst_nd_sbp_str_list\", dst_nd_sbp_str_list)\n          .Attr<std::vector<std::string>>(\"nccl_type_list\", nccl_type_list)\n          .ScopeSymbolId(scope_symbol_id)\n          .Build();\n  OperatorConf fusion_nccl_op_conf = fusion_nccl_op.op_conf();\n  fusion_nccl_op_conf.set_logical_chain_id(logical_chain_id);\n  if (has_stream_name_hint) { fusion_nccl_op_conf.set_stream_name_hint(stream_name_hint); }\n\n  // 2. 
update fusion op\n  VLOG(3) << \" Add fusion op : \" << fusion_nccl_op_conf.DebugString()\n          << \" \\n with placement: \" << seed_placement.parallel_conf().DebugString();\n  nccl_fusion_ops->push_back(fusion_nccl_op_conf);\n  nccl_fusion_op_parallel_confs->push_back(seed_placement.parallel_conf());\n\n  for (int32_t i = 0; i < nccl_size; ++i) {\n    std::string output_lbn = fusion_nccl_op.output(\"out\", i);\n    std::string input_lbn = fusion_nccl_op.input(\"in\", i);\n    const OpNode* origin_nccl = JUST(VectorAt(nccl_ops, i));\n    const OpEdge* origin_edge = origin_nccl->SoleOutEdge();\n    std::string origin_nccl_input_lbn =\n        GenLogicalBlobName(origin_nccl->op().BnInOp2Lbi(origin_nccl->op().SoleIbn()));\n    std::string origin_nccl_output_lbn =\n        GenLogicalBlobName(origin_nccl->op().BnInOp2Lbi(origin_nccl->op().SoleObn()));\n    CHECK_EQ(input_lbn, origin_nccl_input_lbn);\n    const OpNode* origin_consumer = origin_edge->dst_node();\n    const std::string& consumer_op_name = origin_consumer->op().op_name();\n    if (mut_op_name2conf->find(consumer_op_name) == mut_op_name2conf->end()) {\n      mut_op_name2conf->emplace(consumer_op_name, origin_consumer->op().op_conf());\n    }\n    CHECK_EQ(origin_edge->lbis().size(), 1);\n    const LogicalBlobId& lbi = origin_edge->lbis().front();\n    VLOG(3) << \" input_lbn: \" << input_lbn;\n    VLOG(3) << \" lbi: \" << GenLogicalBlobName(lbi);\n    CHECK_EQ(origin_nccl_output_lbn, GenLogicalBlobName(lbi));\n\n    // 3. update consumer op\n    for (const std::string& ibn : JUST(MapAt(origin_edge->lbi2ibns(), lbi))) {\n      std::string old_lbn = ReplaceInputLbnInOpCustomizedConf(\n          &JUST(MapAt(*mut_op_name2conf, consumer_op_name)), ibn, output_lbn);\n      CHECK_EQ(old_lbn, origin_nccl_output_lbn);\n    }\n\n    VLOG(3) << \" Update origin consumer op from: \\n [ \"\n            << origin_consumer->op().op_conf().DebugString() << \" ] \\n to \\n [ \"\n            << JUST(MapAt(*mut_op_name2conf, consumer_op_name)).DebugString() << \" ] \\n\";\n  }\n  return Maybe<void>::Ok();\n}\n\nstruct NcclFusionBucket {\n  std::vector<const OpNode*> nccl_ops;\n  int64_t fusion_bucket_size;\n  NcclFusionBucket() : fusion_bucket_size(0) {}\n};\n\nstd::string GenNcclFusionKey(const OpNode* nccl_op) {\n  // NOTE(chengcheng): Chain need same placement but ignore hierarchy,\n  //   logical_chain_id + hierarchy_shape can guarantee the same device_mesh.\n  int64_t logical_chain_id = nccl_op->op().op_conf().logical_chain_id();\n  const auto& hierarchy = nccl_op->parallel_desc().hierarchy();\n  std::string fusion_key =\n      \"logical_chain_id: \" + std::to_string(logical_chain_id)\n      + \", device_mesh: \" + hierarchy->ToString()\n      + \", comm: \" + GetCommKeyFromNcclType(nccl_op->op().op_conf().user_conf().op_type_name());\n  return fusion_key;\n}\n\nint64_t GetNcclOpMemSize(const OpNode* nccl_op) {\n  const LogicalBlobId& in_lbi = nccl_op->op().BnInOp2Lbi(nccl_op->op().SoleIbn());\n  const LogicalBlobId& out_lbi = nccl_op->op().BnInOp2Lbi(nccl_op->op().SoleObn());\n  const BlobDesc& in_logical_blob_desc = nccl_op->LogicalBlobDesc4Lbi(in_lbi);\n  const BlobDesc& out_logical_blob_desc = nccl_op->LogicalBlobDesc4Lbi(out_lbi);\n  const std::shared_ptr<Shape> in_local_shape = CHECK_JUST(GetPhysicalShape(\n      in_logical_blob_desc.shape(), nccl_op->NdSbp4Lbi(in_lbi), nccl_op->parallel_desc(), 0));\n  const std::shared_ptr<Shape> out_local_shape = CHECK_JUST(GetPhysicalShape(\n      out_logical_blob_desc.shape(), 
nccl_op->NdSbp4Lbi(out_lbi), nccl_op->parallel_desc(), 0));\n  int64_t elem_cnt = std::max(in_local_shape->elem_cnt(), out_local_shape->elem_cnt());\n  return GetCudaAlignedSize(elem_cnt * GetSizeOfDataType(in_logical_blob_desc.data_type()));\n}\n\nvoid AppendOrCreateFusionBucket(std::vector<NcclFusionBucket>* buckets, const OpNode* nccl_op,\n                                const int64_t bucket_limit) {\n  const int64_t nccl_mem_size = GetNcclOpMemSize(nccl_op);\n  for (auto& fusion_bucket : *buckets) {\n    if (fusion_bucket.fusion_bucket_size + nccl_mem_size < bucket_limit) {\n      fusion_bucket.nccl_ops.push_back(nccl_op);\n      fusion_bucket.fusion_bucket_size += nccl_mem_size;\n      return;\n    }\n  }\n  buckets->push_back(NcclFusionBucket());\n  buckets->back().nccl_ops.push_back(nccl_op);\n  buckets->back().fusion_bucket_size += nccl_mem_size;\n}\n\nMaybe<void> NcclLogicalOpFusionPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {\n  HashMap<const OpNode*, int64_t> op_node2nccl_depth;\n  HashMap<int64_t, std::vector<const OpNode*>> nccl_depth2nccl_ops;\n  auto ConstForEachDataAndCtrlInNode = [&](const OpNode* node,\n                                           const std::function<void(const OpNode*)>& Handler) {\n    node->ForEachNodeOnInEdge(Handler);\n    for (const auto& ctrl_in_op_name : node->op().op_conf().ctrl_in_op_name()) {\n      const OpNode* in_node = op_graph.OpNode4OpName(ctrl_in_op_name);\n      CHECK(in_node) << \" cannot find ctrl_in_op_name: [\" << ctrl_in_op_name << \"] of op: [\"\n                     << node->op().op_name() << \"] in OpGraph. \";\n      Handler(in_node);\n    }\n  };\n\n  std::vector<const OpNode*> ordered_op_nodes;\n  if (ParseBooleanFromEnv(\"DISABLE_LOGICAL_STRAIGHTEN\", false)) {\n    op_graph.TopoForEachNodeWithCtrlEdge(\n        [&](const OpNode* node) { ordered_op_nodes.emplace_back(node); });\n  } else {\n    auto_parallel::StraightenOpGraph(op_graph, &ordered_op_nodes);\n  }\n\n  for (const OpNode* node : ordered_op_nodes) {\n    int64_t nccl_depth = 0;\n    ConstForEachDataAndCtrlInNode(node, [&](const OpNode* in_node) {\n      auto it = op_node2nccl_depth.find(in_node);\n      CHECK(it != op_node2nccl_depth.end());  // topo search\n      nccl_depth = std::max(nccl_depth, it->second);\n    });\n    if (IsNcclLogicalOpNode(node)) {\n      nccl_depth++;  // ONLY nccl node update depth\n      nccl_depth2nccl_ops[nccl_depth].push_back(node);\n    }\n    CHECK(op_node2nccl_depth.emplace(node, nccl_depth).second);\n  }\n\n  if (nccl_depth2nccl_ops.empty()) { return Maybe<void>::Ok(); }\n\n  std::vector<OperatorConf> nccl_fusion_ops;\n  std::vector<ParallelConf> nccl_fusion_op_parallel_confs;\n\n  std::unordered_set<std::string> del_ops;\n  HashMap<std::string, OperatorConf> mut_op_name2conf;\n\n  const int64_t bucket_limit = EnvInteger<ONEFLOW_GRAPH_NCCL_LOGICAL_FUSION_BUCKET_SIZE>();\n  VLOG(2) << \"bucket_limit = \" << bucket_limit;\n\n  for (const auto& pair : nccl_depth2nccl_ops) {\n    HashMap<std::string, std::vector<NcclFusionBucket>> fusion_key2nccl_buckets;\n    for (const OpNode* nccl_op : pair.second) {\n      CHECK(nccl_op->op().op_conf().has_logical_chain_id());\n      std::string fusion_key = GenNcclFusionKey(nccl_op);\n      AppendOrCreateFusionBucket(&fusion_key2nccl_buckets[fusion_key], nccl_op, bucket_limit);\n    }
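\n    // Each bucket now holds same-depth nccl ops that share a logical chain, device\n    // mesh and comm type, with summed payload under bucket_limit; fuse per bucket.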
\n    for (const auto& bucket_pair : fusion_key2nccl_buckets) {\n      for (const auto& fusion_bucket : bucket_pair.second) {\n        JUST(ReplaceNcclOpsWithFusionOp(&nccl_fusion_ops, &nccl_fusion_op_parallel_confs, &del_ops,\n                                        &mut_op_name2conf, fusion_bucket.nccl_ops));\n      }\n    }\n  }\n\n  job_builder->RemoveOpByName(del_ops);\n  for (const auto& pair : mut_op_name2conf) { JUST(job_builder->MutOpOnlyOnce(pair.second)); }\n  CHECK_EQ(nccl_fusion_ops.size(), nccl_fusion_op_parallel_confs.size());\n  for (int32_t i = 0; i < nccl_fusion_ops.size(); ++i) {\n    JUST(job_builder->AddOp(JUST(VectorAt(nccl_fusion_op_parallel_confs, i)),\n                            JUST(VectorAt(nccl_fusion_ops, i))));\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"NcclLogicalOpFusionPass\", NcclLogicalOpFusionPass);\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA || WITH_NPU || WITH_MLU\n"
  },
  {
    "path": "oneflow/core/job_rewriter/normalization_exponential_average_auto_tick_rewrite_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nclass NormalizationExponentialAverageAutoTickPass final : public JobPass {\n public:\n  NormalizationExponentialAverageAutoTickPass() = default;\n  ~NormalizationExponentialAverageAutoTickPass() override = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;\n};\n\nMaybe<void> NormalizationExponentialAverageAutoTickPass::Apply(Job* job, JobPassCtx* ctx) const {\n  const JobConfigProto& job_conf = ctx->job_desc().job_conf();\n  if (!job_conf.has_train_conf()) { return Maybe<void>::Ok(); }\n  if ((!job_conf.has_num_gradient_accumulation_steps())\n      || job_conf.num_gradient_accumulation_steps() <= 1) {\n    return Maybe<void>::Ok();\n  }\n  const OpGraph op_graph(*job);\n  JobBuilder job_builder(job);\n  JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](const OpNode* node) -> Maybe<void> {\n    const OperatorConf& op_conf = node->op().op_conf();\n    if (!op_conf.has_user_conf()) { return Maybe<void>::Ok(); }\n    const user_op::UserOpConfWrapper user_op_conf(op_conf);\n    if (user_op_conf.op_type_name() != \"normalization\"\n        && user_op_conf.op_type_name() != \"normalization_add_relu\") {\n      return Maybe<void>::Ok();\n    }\n    const std::string& x_lbn = user_op_conf.input(\"x\", 0);\n    const std::string& moving_mean_lbn = user_op_conf.input(\"moving_mean\", 0);\n    const std::string& moving_variance_lbn = user_op_conf.input(\"moving_variance\", 0);\n    std::string x_tick_lbn;\n    auto GetXTick = [&]() {\n      if (x_tick_lbn.empty()) {\n        user_op::UserOpConfWrapperBuilder cast_to_tick_builder(\"System-CastToTick-\"\n                                                               + NewUniqueId());\n        const auto cast_to_tick_op = cast_to_tick_builder.OpTypeName(\"cast_to_tick\")\n                                         .Input(\"in\", x_lbn)\n                                         .Output(\"out\")\n                                         .Build();\n        job_builder.AddOps(node->parallel_desc().parallel_conf(), {cast_to_tick_op.op_conf()});\n        x_tick_lbn = cast_to_tick_op.output(\"out\", 0);\n      }\n      return x_tick_lbn;\n    };\n    auto TrySetTickForNode = [&](const OpNode* var_node) {\n      if (!var_node->in_edges().empty()) { return; }\n      if (!var_node->op().op_conf().has_variable_conf()) { return; }\n      if (var_node->op().op_conf().variable_conf().has_tick()) { return; }\n      OperatorConf new_var_op_conf = var_node->op().op_conf();\n      new_var_op_conf.mutable_variable_conf()->set_tick(GetXTick());\n      job_builder.MutOpsOnlyOnce({new_var_op_conf});\n    };\n    TrySetTickForNode(op_graph.OpNode4OpName(GenLogicalBlobId(moving_mean_lbn).op_name()));\n    TrySetTickForNode(op_graph.OpNode4OpName(GenLogicalBlobId(moving_variance_lbn).op_name()));\n    return 
Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_JOB_PASS(\"NormalizationExponentialAverageAutoTickPass\",\n                  NormalizationExponentialAverageAutoTickPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/optimizer.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/optimizer.h\"\n#include \"oneflow/core/job_rewriter/dynamic_loss_scale_job_pass_state.h\"\n#include <re2/re2.h>\n\nnamespace oneflow {\n\nvoid GenerateOptimizerOpConfWrapperStruct::Call(JobPassCtx* ctx, const OpNode& var_op_node,\n                                                const std::string& model_diff_lbn,\n                                                const OptimizerConf& optimizer_conf,\n                                                JobBuilder* job_builder) const {\n  (*func_)(ctx, var_op_node, model_diff_lbn, optimizer_conf, job_builder);\n}\n\nvoid AddOptimizerOp(JobPassCtx* ctx, const OpNode& var_op_node, const std::string& model_diff_lbn,\n                    const OptimizerConf& optimizer_conf, JobBuilder* job_builder) {\n  const auto optimizer_case = optimizer_conf.normal_mdupdt_case();\n  auto* obj = NewObj<int32_t, GenerateOptimizerOpConfWrapperStruct>(optimizer_case);\n  obj->Call(ctx, var_op_node, model_diff_lbn, optimizer_conf, job_builder);\n}\n\nfloat GetOptimizerWeightDecayRate(const OptimizerConf& optimizer_conf, const VariableOp& op) {\n  if (optimizer_conf.has_weight_decay_conf()) {\n    const WeightDecayConf& weight_decay_conf = optimizer_conf.weight_decay_conf();\n    std::function<bool(const std::string& op_name)> WeightDecayFilter;\n    if (weight_decay_conf.has_includes()) {\n      WeightDecayFilter = [&](const std::string& op_name) {\n        return std::any_of(\n            weight_decay_conf.includes().pattern().cbegin(),\n            weight_decay_conf.includes().pattern().cend(),\n            [&](const std::string& pattern) { return RE2::PartialMatch(op_name, pattern); });\n      };\n    } else if (weight_decay_conf.has_excludes()) {\n      WeightDecayFilter = [&](const std::string& op_name) {\n        return !std::any_of(\n            weight_decay_conf.excludes().pattern().cbegin(),\n            weight_decay_conf.excludes().pattern().cend(),\n            [&](const std::string& pattern) { return RE2::PartialMatch(op_name, pattern); });\n      };\n    } else {\n      WeightDecayFilter = [&](const std::string& op_name) { return true; };\n    }\n    if (WeightDecayFilter(op.op_name())) {\n      return weight_decay_conf.weight_decay_rate();\n    } else {\n      return 0;\n    }\n  } else {\n    return 0;\n  }\n}\n\nvoid SetDynamicLossScaleSkipIf(JobPassCtx* ctx, user_op::UserOpConfWrapperBuilder* builder) {\n  if (!ctx->job_desc().job_conf().train_conf().has_dynamic_loss_scale_policy()) { return; }\n  builder->Input(\"skip_if\",\n                 CHECK_JUST(ctx->GetState<DynamicLossScaleJobPassState>(\"dynamic_loss_scale_state\"))\n                     .count_not_finite_lbn());\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/optimizer.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_REWRITER_OPTIMIZER_H_\n#define ONEFLOW_CORE_JOB_REWRITER_OPTIMIZER_H_\n\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/operator/variable_op.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n\nnamespace oneflow {\n\nvoid AddOptimizerOp(JobPassCtx* ctx, const OpNode& var_op_node, const std::string& model_diff_lbn,\n                    const OptimizerConf& optimizer_conf, JobBuilder* job_builder);\n\nfloat GetOptimizerWeightDecayRate(const OptimizerConf& optimizer_conf, const VariableOp& op);\n\nvoid SetDynamicLossScaleSkipIf(JobPassCtx* ctx, user_op::UserOpConfWrapperBuilder* builder);\n\nclass GenerateOptimizerOpConfWrapperStruct final {\n public:\n  using Func = std::function<void(JobPassCtx*, const OpNode&, const std::string&,\n                                  const OptimizerConf&, JobBuilder*)>;\n  GenerateOptimizerOpConfWrapperStruct(const Func& f) : func_(std::make_unique<Func>(f)) {}\n  void Call(JobPassCtx* ctx, const OpNode& var_op_node, const std::string& model_diff_lbn,\n            const OptimizerConf& optimizer_conf, JobBuilder* job_builder) const;\n\n private:\n  const std::unique_ptr<const Func> func_;\n};\n\n#define REGISTER_OPTIMIZER(model_update_case, gen_optimizer_conf_func)  \\\n  REGISTER_CLASS_CREATOR(                                               \\\n      int32_t, model_update_case, GenerateOptimizerOpConfWrapperStruct, \\\n      ([] { return new GenerateOptimizerOpConfWrapperStruct(gen_optimizer_conf_func); }))\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_REWRITER_OPTIMIZER_H_\n"
  },
  {
    "path": "oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <glog/logging.h>\n#include <cstdint>\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/job/job_conf.pb.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/job/sbp_parallel.pb.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nint64_t GetSoleOutBlobSize(const OpNode* node) {\n  const BlobDesc& blob_desc =\n      node->LogicalBlobDesc4Lbi(node->op().BnInOp2Lbi(node->op().SoleObn()));\n  return blob_desc.shape().elem_cnt() * GetSizeOfDataType(blob_desc.data_type());\n}\n\nclass DataParallelNodeSequence final {\n public:\n  DataParallelNodeSequence(std::vector<const OpNode*> nodes, int64_t order)\n      : nodes_(std::move(nodes)), order_(order), len_(nodes_.size()) {\n    const OpNode* var_node = nodes_.front();\n    CHECK(var_node->op().op_conf().has_variable_conf());\n    model_size_ = GetSoleOutBlobSize(var_node);\n  }\n  ~DataParallelNodeSequence() = default;\n\n  const OpNode* GetVariableNode() const { return nodes_.front(); }\n\n  const OpNode* GetLastNode() const { return nodes_.back(); }\n\n  int64_t order() const { return order_; }\n\n  const std::vector<const OpNode*>& nodes() const { return nodes_; }\n\n  const ParallelDesc& parallel_desc() const { return nodes_.front()->parallel_desc(); }\n\n  int64_t model_size() const { return model_size_; }\n\n  int64_t len() const { return len_; }\n\n  void resize(const int64_t size) {\n    CHECK_LE(size, len_);\n    CHECK_GE(size, 1);\n    nodes_.resize(size);\n    len_ = nodes().size();\n  }\n\n private:\n  std::vector<const OpNode*> nodes_;\n  int64_t order_;\n  int64_t model_size_;\n  int64_t len_;\n};\n\nusing SequencePtr = std::shared_ptr<DataParallelNodeSequence>;\n\nParallelConf NonDistributedParallelConf4ParallelId(const ParallelDesc& pd,\n                                                   const int64_t parallel_id) {\n  std::string device_name;\n  device_name += std::to_string(CHECK_JUST(pd.MachineId4ParallelId(parallel_id)));\n  device_name += \":\";\n  device_name += std::to_string(CHECK_JUST(pd.DeviceId4ParallelId(parallel_id)));\n  ParallelConf parallel_conf;\n  *parallel_conf.mutable_device_name()->Add() = device_name;\n  parallel_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(pd.device_type())));\n  return parallel_conf;\n}\n\nMaybe<void> GetDataParallelVariableAndNaiveSuccNode(\n    const OpNode* start, const std::function<bool(const OpNode*)>& IsAllowed,\n    std::vector<const OpNode*>* out) {\n  // Find sequence like: vairable -> cast_fp32_to_fp16\n  if (!start->op().op_conf().has_variable_conf()) { 
\n  while (cur_node != nullptr) {\n    if (cur_node != start) {\n      if (cur_node->parallel_desc() != pd) { break; }\n      if (cur_node->in_edges().size() > 1) { break; }\n      if (cur_node->op().input_bns().size() != 1) { break; }\n      const std::string& sole_ibn = cur_node->op().SoleIbn();\n      const NdSbp& ibn_nd_sbp = cur_node->NdSbp4BnInOp(sole_ibn);\n      bool has_broadcast = false;\n      FOR_RANGE(int, i, 0, ibn_nd_sbp.sbp_parallel_size()) {\n        if (ibn_nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { has_broadcast = true; }\n      }\n      if (!has_broadcast) { break; }\n    }\n    if (!IsAllowed(cur_node)) { break; }\n    if (cur_node->op().output_bns().size() != 1) { break; }\n    const std::string& sole_obn = cur_node->op().SoleObn();\n    const NdSbp& obn_nd_sbp = cur_node->NdSbp4BnInOp(sole_obn);\n    bool has_broadcast = false;\n    FOR_RANGE(int, i, 0, obn_nd_sbp.sbp_parallel_size()) {\n      if (obn_nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { has_broadcast = true; }\n    }\n    if (!has_broadcast) { break; }\n    out->emplace_back(cur_node);\n    if (cur_node->out_edges().size() == 1) {\n      cur_node = cur_node->SoleOutEdge()->dst_node();\n    } else {\n      cur_node = nullptr;\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nvoid SetBroadcastParallel4OpNodeIbn(JobBuilder* builder, const OpNode* node,\n                                    const std::string& ibn) {\n  OpBlobArg op_blob_arg;\n  op_blob_arg.set_op_name(node->op().op_name());\n  op_blob_arg.set_bn_in_op(ibn);\n  SbpParallel sbp_parallel;\n  sbp_parallel.mutable_broadcast_parallel();\n  builder->SetSbpParallel4Oba(op_blob_arg, sbp_parallel);\n}\n\nvoid SetBroadcastParallel4Consumers(JobBuilder* builder, const SequencePtr& sequence) {\n  const OpNode* node = sequence->GetLastNode();\n  const LogicalBlobId& lbi = node->op().BnInOp2Lbi(node->op().SoleObn());\n  node->ForEachNodeOnOutEdge([&](const OpNode* out_node) {\n    for (const std::string& ibn : out_node->op().input_bns()) {\n      if (out_node->op().BnInOp2Lbi(ibn) == lbi) {\n        SetBroadcastParallel4OpNodeIbn(builder, out_node, ibn);\n      }\n    }\n  });\n}\n\nvoid SetNdSbp4OpNodeIbn(JobBuilder* builder, const OpNode* node, const std::string& ibn,\n                        const NdSbp& nd_sbp) {\n  OpBlobArg op_blob_arg;\n  op_blob_arg.set_op_name(node->op().op_name());\n  op_blob_arg.set_bn_in_op(ibn);\n  builder->SetNdSbp4Oba(op_blob_arg, nd_sbp);\n}\n\nvoid SetNdSbp4Consumers(JobBuilder* builder, const SequencePtr& sequence, const NdSbp& nd_sbp) {\n  const OpNode* node = sequence->GetLastNode();\n  const LogicalBlobId& lbi = node->op().BnInOp2Lbi(node->op().SoleObn());\n  const int64_t shard_restore_level =\n      builder->job().job_conf().optimizer_placement_optimization_shard_restore_level();\n  // If shard_restore_level == 0, no limit on consumer\n  if (shard_restore_level == 1) {\n    // Input lbn for parallel cast op\n    std::string parallel_cast_input_lbn = GenLogicalBlobName(lbi);\n    // Add a parallel cast op as a soft limit, guiding consumers to consume the weight\n    // with broadcast SBP.\n    const auto parallel_cast_op =\n        user_op::UserOpConfWrapperBuilder(\"System-ZeRO-ParallelCast-\" + node->op().op_name() + \"-\"\n                                          + NewUniqueId())\n            .Op(\"hierarchical_parallel_cast\")\n            .Input(\"in\", parallel_cast_input_lbn)\n            
.Output(\"out\")\n            .Attr<std::vector<std::string>>(\"nd_sbp\", NdSbpToStringList(nd_sbp))\n            .Attr<std::string>(\"grad_mode\", \"identity\")  // don't do ndsbp cast at backward\n            .Attr<std::vector<std::string>>(\"grad_nd_sbp\", std::vector<std::string>())\n            .ScopeSymbolId(node->op().op_conf().scope_symbol_id())\n            .Build();\n    builder->AddOps(node->parallel_desc().parallel_conf(), {parallel_cast_op.op_conf()});\n\n    // Make consumers to consume parallel cast op\n    auto out_lbn = parallel_cast_op.output(\"out\", 0);\n    node->ForEachNodeOnOutEdge([&](const OpNode* out_node) {\n      for (const std::string& ibn : out_node->op().input_bns()) {\n        if (out_node->op().BnInOp2Lbi(ibn) == lbi) {\n          if (!CHECK_JUST(builder->IsInMutOpTransaction(out_node->op().op_name()))) {\n            CHECK_JUST(builder->MutOpTransactionMut(out_node->op().op_conf()));\n          }\n          OperatorConf& mut_consumer_op =\n              CHECK_JUST(builder->MutOpTransactionGet(out_node->op().op_name()));\n          const auto& old_lbn = ReplaceInputLbnInOpCustomizedConf(&mut_consumer_op, ibn, out_lbn);\n          CHECK_EQ(old_lbn, GenLogicalBlobName(lbi));\n        }\n      }\n    });\n  } else if (shard_restore_level == 2) {\n    // Hard limt consumer to consume weight as Broadcast.\n    node->ForEachNodeOnOutEdge([&](const OpNode* out_node) {\n      for (const std::string& ibn : out_node->op().input_bns()) {\n        if (out_node->op().BnInOp2Lbi(ibn) == lbi) {\n          SetNdSbp4OpNodeIbn(builder, out_node, ibn, nd_sbp);\n        }\n      }\n    });\n  }\n}\n\nstd::function<int64_t(const OpNode*)> MakeGetterOpNode2TopoOrder(const OpGraph& op_graph) {\n  HashMap<const OpNode*, int64_t> op_node2topo_order;\n  int64_t node_cnt = 0;\n  op_graph.TopoForEachNode([&](const OpNode* node) {\n    op_node2topo_order[node] = node_cnt;\n    node_cnt += 1;\n  });\n  return [op_node2topo_order](const OpNode* node) { return op_node2topo_order.at(node); };\n}\n\nint64_t GetMinConsumerOrder(const OpGraph& op_graph, const OpNode* node,\n                            const std::function<int64_t(const OpNode*)>& OpNode2Order) {\n  int64_t min_consumer_topo_order = op_graph.node_num();\n  node->ForEachNodeOnOutEdge([&](const OpNode* dst) {\n    min_consumer_topo_order = std::min(min_consumer_topo_order, OpNode2Order(dst));\n  });\n  return min_consumer_topo_order;\n}\n\nvoid ForEachDataParallelNodeSequence(const OpGraph& op_graph,\n                                     const std::function<bool(const OpNode*)>& IsAllowed,\n                                     std::function<void(SequencePtr&&)> Handler) {\n  auto OpNode2Order = MakeGetterOpNode2TopoOrder(op_graph);\n  op_graph.ForEachNode([&](const OpNode* node) {\n    std::vector<const OpNode*> nodes;\n    // Find sequence like: vairable -> cast_fp32_to_fp16\n    CHECK_JUST(GetDataParallelVariableAndNaiveSuccNode(node, IsAllowed, &nodes));\n    if (nodes.empty()) { return; }\n    const int64_t order = GetMinConsumerOrder(op_graph, nodes.back(), OpNode2Order);\n    Handler(std::make_shared<DataParallelNodeSequence>(std::move(nodes), order));\n  });\n}\n\nbool SequenceCompSortedByOrderAsc(const SequencePtr& lhs, const SequencePtr& rhs) {\n  return lhs->order() < rhs->order();\n}\n\nbool SequenceCompSortedByModelSizeDesc(const SequencePtr& lhs, const SequencePtr& rhs) {\n  return lhs->model_size() > rhs->model_size();\n}\n\nvoid ForEachParallelSortedNodeSequence(\n    const OpGraph& op_graph, const 
std::function<bool(const OpNode*)>& IsAllowed,\n    const std::function<bool(const SequencePtr&, const SequencePtr&)>& Comp,\n    const std::function<void(const ParallelDesc&, std::vector<SequencePtr>&&)>& Handler) {\n  HashMap<ParallelDesc, std::vector<SequencePtr>> parallel_desc2sequences;\n  // Find sequence like: variable -> cast_fp32_to_fp16\n  ForEachDataParallelNodeSequence(op_graph, IsAllowed, [&](SequencePtr&& sequence) {\n    parallel_desc2sequences[sequence->parallel_desc()].emplace_back(std::move(sequence));\n  });\n  for (auto& pair : parallel_desc2sequences) {\n    auto& sequences = pair.second;\n    std::sort(sequences.begin(), sequences.end(), Comp);\n    Handler(pair.first, std::move(sequences));\n  }\n}\n\nbool IsS0Parallel(const SbpParallel& sbp_parallel) {\n  return sbp_parallel.has_split_parallel() && sbp_parallel.split_parallel().axis() == 0;\n}\n\nbool IsS0Parallel(const SbpSignature& signature, const std::string& bn) {\n  return IsS0Parallel(signature.bn_in_op2sbp_parallel().at(bn));\n}\n\nbool IsNdSbpMatch(const NdSbpSignature& signature, const std::string& bn, const NdSbp& nd_sbp) {\n  return signature.bn_in_op2nd_sbp().at(bn) == nd_sbp;\n}\n\nbool IsNdSbpSupported4Op(const OpNode* node, const NdSbp& nd_sbp) {\n  if (node->op().input_bns().size() != 1 || node->op().output_bns().size() != 1) { return false; }\n  std::vector<NdSbpSignature> list;\n  auto LogicalBlobDesc4Ibn = [&](const std::string& bn) -> Maybe<const BlobDesc&> {\n    return Maybe<const BlobDesc&>(node->LogicalBlobDesc4Lbi(node->op().BnInOp2Lbi(bn)));\n  };\n  CHECK_JUST(node->op().GetNdSbpSignatureList(LogicalBlobDesc4Ibn, node->parallel_desc(), &list));\n  const auto IsInAndOutMatch = [&](const NdSbpSignature& signature) {\n    return IsNdSbpMatch(signature, node->op().SoleIbn(), nd_sbp)\n           && IsNdSbpMatch(signature, node->op().SoleObn(), nd_sbp);\n  };\n  return std::any_of(list.cbegin(), list.cend(), IsInAndOutMatch);\n}\n\nbool IsS0SignatureSupported(const OpNode* node) {\n  if (node->op().input_bns().size() != 1 || node->op().output_bns().size() != 1) { return false; }\n  SbpSignatureList list;\n  auto LogicalBlobDesc4Ibn = [&](const std::string& bn) -> Maybe<const BlobDesc&> {\n    return Maybe<const BlobDesc&>(node->LogicalBlobDesc4Lbi(node->op().BnInOp2Lbi(bn)));\n  };\n  CHECK_JUST(node->op().GetSbpSignaturesIf(LogicalBlobDesc4Ibn,\n                                           node->parallel_desc().parallel_num(), &list));\n  const auto IsInOutS0Parallel = [&](const SbpSignature& signature) {\n    return IsS0Parallel(signature, node->op().SoleIbn())\n           && IsS0Parallel(signature, node->op().SoleObn());\n  };\n  return std::any_of(list.sbp_signature().cbegin(), list.sbp_signature().cend(), IsInOutS0Parallel);\n}\n\nvoid ForEachModelSizeBalancedPartition(\n    const ParallelDesc& parallel_desc, std::vector<SequencePtr>&& sorted_sequences,\n    const std::function<void(ParallelDesc new_parallel_desc, std::vector<SequencePtr>&&)>&\n        Handler) {\n  std::vector<SequencePtr> sequences = std::move(sorted_sequences);\n  std::vector<int64_t> parallel_id2model_size(parallel_desc.parallel_num(), 0);\n  std::vector<std::vector<SequencePtr>> partitions(parallel_desc.parallel_num());\n  for (auto& sequence : sequences) {\n    const auto it =\n        std::min_element(parallel_id2model_size.cbegin(), parallel_id2model_size.cend());\n    const int64_t min_parallel_id = std::distance(parallel_id2model_size.cbegin(), it);\n    parallel_id2model_size.at(min_parallel_id) += 
sequence->model_size();\n    partitions.at(min_parallel_id).emplace_back(std::move(sequence));\n  }\n  for (int64_t i = 0; i < parallel_desc.parallel_num(); ++i) {\n    ParallelConf parallel_conf = NonDistributedParallelConf4ParallelId(parallel_desc, i);\n    Handler(parallel_conf, std::move(partitions.at(i)));\n  }\n}\n\nnamespace {\nbool IsSplitValid(const Shape& shape, const NdSbp& nd_sbp, const Shape& hierarchy,\n                  int64_t min_size) {\n  if (shape.NumAxes() < 1 || shape.elem_cnt() < 1) { return false; }\n  CHECK_EQ(nd_sbp.sbp_parallel_size(), hierarchy.NumAxes());\n  Shape cur_shape = shape;\n  if (cur_shape.elem_cnt() < min_size) { return false; }\n  FOR_RANGE(int64_t, i, 0, hierarchy.NumAxes()) {\n    const auto& sbp = nd_sbp.sbp_parallel(i);\n    if (sbp.has_split_parallel()) {\n      const int64_t dim = sbp.split_parallel().axis();\n      if (dim >= cur_shape.NumAxes()) { return false; }\n      // For an unbalanced split, take the minimum (floor) share per rank.\n      cur_shape.Set(dim, cur_shape.At(dim) / hierarchy.At(i));\n      // Each shard must remain no smaller than min_size.\n      if (cur_shape.elem_cnt() < min_size) { return false; }\n    }\n  }\n  return true;\n}\n\nvoid GenerateSplitSignature(const NdSbp& var_nd_sbp, const OperatorConf& new_var_op_conf,\n                            std::string& new_split_signature, int64_t& split_dim) {\n  if (new_var_op_conf.variable_conf().nd_sbp_size() > 0 && NdSbpIsAllBroadcast(var_nd_sbp)) {\n    // split last dim\n    split_dim = new_var_op_conf.variable_conf().nd_sbp_size() - 1;\n    // All B, B -> S0\n    new_split_signature = \"S(0)\";\n  } else {\n    // ND sbp, (*, B, S, *) -> (*, S, S, *)\n    // ND sbp, (*, S, B, *) -> (*, S, S, *)\n    FOR_RANGE(int64_t, j, 0, new_var_op_conf.variable_conf().nd_sbp_size()) {\n      if (new_var_op_conf.variable_conf().nd_sbp(j) == \"B\") {\n        std::vector<int64_t> adjacent_dim{j - 1, j + 1};\n        for (auto const& dim_to_try : adjacent_dim) {\n          if (dim_to_try >= 0 && dim_to_try < new_var_op_conf.variable_conf().nd_sbp_size()) {\n            SbpParallel sbp;\n            if (ParseSbpParallelFromString(new_var_op_conf.variable_conf().nd_sbp(dim_to_try), &sbp)\n                && sbp.has_split_parallel()) {\n              new_split_signature = new_var_op_conf.variable_conf().nd_sbp(dim_to_try);\n              split_dim = j;\n            }\n          }\n          if (new_split_signature != \"\") break;\n        }\n      }\n      // Only split one more dim.\n      if (new_split_signature != \"\") break;\n    }\n  }\n}\nvoid ShardSequence(JobBuilder* builder, const int64_t threshold, const ParallelDesc& pd,\n                   std::vector<SequencePtr>&& sorted_sequences) {\n  // For each sorted sequence, set the variable op in the sequence to S\n  // and add a ctrl edge to control the execution order between variable ops.\n  // A sequence is a variable op and its cast(fp32 to fp16) op. 
This is because the forward pass\n  // consumes the fp16 variable and the optimizer consumes the fp32 variable.\n  std::string prev_allowed_op_name = \"\";\n  for (int64_t i = 0; i < sorted_sequences.size(); ++i) {\n    const OpNode* var_node = sorted_sequences.at(i)->GetVariableNode();\n    OperatorConf new_var_op_conf = var_node->op().op_conf();\n    const std::string& sole_obn = var_node->op().SoleObn();\n    const NdSbp& var_nd_sbp = var_node->NdSbp4BnInOp(sole_obn);\n    const Shape& logical_shape = Shape(new_var_op_conf.variable_conf().shape());\n\n    std::string new_split_signature = \"\";\n    int64_t split_dim = 0;\n    GenerateSplitSignature(var_nd_sbp, new_var_op_conf, new_split_signature, split_dim);\n    if (new_split_signature != \"\") {\n      *new_var_op_conf.mutable_variable_conf()->mutable_nd_sbp(split_dim) = new_split_signature;\n    } else {\n      continue;\n    }\n\n    bool split_is_allowed = true;\n    {\n      NdSbp new_nd_sbp;\n      std::vector<std::string> nd_sbp_str_vec;\n      for (const auto& sbp_str : new_var_op_conf.variable_conf().nd_sbp()) {\n        nd_sbp_str_vec.emplace_back(sbp_str);\n      }\n      ParseNdSbpFromStringList(nd_sbp_str_vec, &new_nd_sbp);\n      // Check whether the split is allowed by the min shard size and even splitting.\n      if (split_is_allowed) {\n        split_is_allowed = IsSplitValid(logical_shape, new_nd_sbp, *pd.hierarchy(), threshold);\n      }\n      if (split_is_allowed) {\n        // Resize the sequence to the longest prefix whose ops support the new nd_sbp.\n        auto& cur_seq = sorted_sequences.at(i);\n        int64_t max_len = 1;\n        if (cur_seq->len() > 1) {\n          FOR_RANGE(int64_t, node_idx, 1, cur_seq->len()) {\n            if (IsNdSbpSupported4Op(cur_seq->nodes().at(node_idx), new_nd_sbp)) {\n              ++max_len;\n            } else {\n              break;\n            }\n          }\n        }\n        if (max_len < cur_seq->len()) { cur_seq->resize(max_len); }\n      }\n    }\n    if (!split_is_allowed) {\n      VLOG(3) << var_node->op().op_name() << \" failed to change from B to S \"\n              << \" with op conf \" << new_var_op_conf.variable_conf().DebugString();\n      continue;\n    }\n    if (!prev_allowed_op_name.empty()) {\n      new_var_op_conf.add_ctrl_in_op_name(prev_allowed_op_name);\n    }\n    builder->MutOpsOnlyOnce({new_var_op_conf});\n    // Set consumers to consume the output of this variable op's cast op as Broadcast.\n    if (new_split_signature != \"\") {\n      SetNdSbp4Consumers(builder, sorted_sequences.at(i), var_nd_sbp);\n    }\n    prev_allowed_op_name = var_node->op().op_name();\n    VLOG(3) << var_node->op().op_name() << \" succeeded in changing from B to \" << new_split_signature\n            << \" on ranks dim \" << split_dim << \" with op conf \"\n            << new_var_op_conf.variable_conf().DebugString();\n  }\n}\n}  // namespace\n\nMaybe<void> RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder) {\n  const int64_t threshold = builder->job().job_conf().optimizer_placement_optimization_threshold();\n  const auto IsAllowed = [](const OpNode* n) -> bool {\n    // No need to limit here.\n    return true;\n  };\n  const auto PlacementSequencesAsSplitParallel = [&](const ParallelDesc& pd,\n                                                     std::vector<SequencePtr>&& sorted_sequences) {\n    ShardSequence(builder, threshold, pd, std::move(sorted_sequences));\n  };\n  ForEachParallelSortedNodeSequence(op_graph, IsAllowed, SequenceCompSortedByOrderAsc,\n                                    
PlacementSequencesAsSplitParallel);\n  JUST(builder->MutOpTransactionCommit());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RewriteNonDistributed(const OpGraph& op_graph, JobBuilder* builder) {\n  HashMap<ParallelDesc, std::vector<SequencePtr>> new_parallel_desc2sequences;\n  const auto RewritePartition = [&](const ParallelDesc& new_parallel_desc,\n                                    std::vector<SequencePtr>&& partition) {\n    for (auto& sequence : partition) {\n      for (const OpNode* op_node : sequence->nodes()) {\n        builder->MutParallelConfOnlyOnce(op_node->op().op_name(),\n                                         new_parallel_desc.parallel_conf());\n      }\n      SetBroadcastParallel4Consumers(builder, sequence);\n      new_parallel_desc2sequences[new_parallel_desc].emplace_back(std::move(sequence));\n    }\n  };\n  const auto RewriteSequences = [&](const ParallelDesc& pd,\n                                    std::vector<SequencePtr>&& sorted_sequences) {\n    ForEachModelSizeBalancedPartition(pd, std::move(sorted_sequences), RewritePartition);\n  };\n  const int64_t threshold = builder->job().job_conf().optimizer_placement_optimization_threshold();\n  const auto IsAllowed = [threshold](const OpNode* n) -> bool {\n    if (n->op().op_conf().has_variable_conf()) {\n      const Shape shape(n->op().op_conf().variable_conf().shape());\n      const int64_t parallel_num = n->parallel_desc().parallel_num();\n      return shape.elem_cnt() >= threshold * parallel_num;\n    } else {\n      return true;\n    }\n  };\n  ForEachParallelSortedNodeSequence(op_graph, IsAllowed, SequenceCompSortedByModelSizeDesc,\n                                    RewriteSequences);\n\n  for (auto& parallel_desc7sequences : new_parallel_desc2sequences) {\n    auto& sequences = parallel_desc7sequences.second;\n    std::sort(sequences.begin(), sequences.end(), SequenceCompSortedByOrderAsc);\n    for (int64_t i = 1; i < sequences.size(); ++i) {\n      const OpNode* cur_var_node = sequences.at(i)->GetVariableNode();\n      OperatorConf cur_var_conf(cur_var_node->op().op_conf());\n      const OpNode* prev_var_node = sequences.at(i - 1)->GetVariableNode();\n      cur_var_conf.add_ctrl_in_op_name(prev_var_node->op().op_name());\n      builder->MutOpsOnlyOnce({cur_var_conf});\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nclass OptimizerPlacementOptimizationPass final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(OptimizerPlacementOptimizationPass);\n  OptimizerPlacementOptimizationPass() = default;\n  ~OptimizerPlacementOptimizationPass() override = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!(ctx->job_desc().IsTrain()\n          && ctx->job_desc().job_conf().has_optimizer_placement_optimization_mode()\n          && ctx->job_desc().job_conf().optimizer_placement_optimization_mode() != \"none\")) {\n      return Maybe<void>::Ok();\n    }\n    if (job->job_conf().enable_auto_parallel()\n        && job->job_conf().enable_auto_parallel_ignore_user_sbp_config()) {\n      LOG(WARNING) << \"ZeRO optimization will be ignored when enabling AutoParallel to ignore user \"\n                      \"sbp configuration\";\n      if (job->job_conf().enable_auto_memory() != oneflow::AutoMemoryStrategy::kHeavyAutoMemory) {\n        job->mutable_job_conf()->set_enable_auto_memory(\n            ::oneflow::AutoMemoryStrategy::kModerateAutoMemory);\n        LOG(WARNING) << \"But we turn on moderate auto memory to reduce the memory, which has \"\n                        \"a similar 
effect to the ZeRO optimization\";\n      }\n      return Maybe<void>::Ok();\n    }\n    const std::string& mode = ctx->job_desc().job_conf().optimizer_placement_optimization_mode();\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    if (mode == \"non_distributed\") {\n      return RewriteNonDistributed(op_graph, &job_builder);\n    } else if (mode == \"distributed_split\") {\n      return RewriteDistributedSplit(op_graph, &job_builder);\n    } else {\n      return Error::UnimplementedError();\n    }\n  }\n};\n\nREGISTER_JOB_PASS(\"OptimizerPlacementOptimizationPass\", OptimizerPlacementOptimizationPass);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/pass_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/pass_util.h\"\n\nnamespace oneflow {\n\nbool IsNodeInList(const HashSet<std::string>& op_list, OpNode* node) {\n  if (node->op().op_conf().has_user_conf() == false) { return false; }\n  const std::string op_type = node->op().op_conf().user_conf().op_type_name();\n  return IsKeyFound(op_list, op_type);\n}\n\nstd::string ReplaceSlashToDash4Lbn(std::string lbn) {\n  std::replace(lbn.begin(), lbn.end(), '/', '-');\n  return lbn;\n}\n\nvoid DfsTopoGraphTraversal(const OpGraph& graph, bool reversed,\n                           std::function<bool(OpNode*)> IsCurNodeStartNode,\n                           std::function<bool(OpNode*)> IsCurNodeSatisfied,\n                           std::function<bool(OpNode*)> IsFatherNodeSatisfied,\n                           std::function<void(OpNode*)> NodeHandler) {\n  auto start_nodes = reversed ? graph.sink_nodes() : graph.source_nodes();\n  std::function<void(OpNode*, std::function<void(OpNode*)>)> NodeOnInEdge =\n      reversed ? &OpNode::ForEachNodeOnOutEdge : &OpNode::ForEachNodeOnInEdge;\n  std::function<void(OpNode*, std::function<void(OpNode*)>)> NodeOnOutEdge =\n      reversed ? &OpNode::ForEachNodeOnInEdge : &OpNode::ForEachNodeOnOutEdge;\n  graph.DfsTopoForEachNode(start_nodes, NodeOnInEdge, NodeOnOutEdge, [&](OpNode* node) {\n    if (IsCurNodeStartNode(node)) {\n      NodeHandler(node);\n      return;\n    }\n    if (IsCurNodeSatisfied(node)) {\n      bool is_one_father_of_node_satisfied = false;\n      NodeOnInEdge(node, [&](OpNode* father_node) {\n        if (is_one_father_of_node_satisfied) { return; }\n        if (IsFatherNodeSatisfied(father_node)) { is_one_father_of_node_satisfied = true; }\n      });\n      if (is_one_father_of_node_satisfied) { NodeHandler(node); }\n    }\n  });\n}\n\nstd::function<bool(const OpNode* op_node)> MakePredicatorIsSafeToDelete(const OpGraph& op_graph) {\n  HashSet<std::string> ctrl_in_op_names;\n  op_graph.ForEachNode([&](const OpNode* op_node) {\n    for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) {\n      ctrl_in_op_names.insert(ctrl_in_op_name);\n    }\n  });\n  return [=](const OpNode* op_node) {\n    if (op_node->out_edges().size() > 1) { return false; }\n    if (!op_node->op().op_conf().ctrl_in_op_name().empty()) { return false; }\n    if (ctrl_in_op_names.find(op_node->op().op_conf().name()) != ctrl_in_op_names.end()) {\n      return false;\n    }\n    return true;\n  };\n}\n\nbool IsUserOpWithTypeName(const OperatorConf& op_conf, const std::string& op_type_name) {\n  return op_conf.has_user_conf() && op_conf.user_conf().op_type_name() == op_type_name;\n}\n\nstd::string GenParallelConfKey(const ParallelConf& conf) {\n  std::string ret = conf.device_tag();\n  for (const auto& name : conf.device_name()) { ret += (\"-\" + name); }\n  return ret;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/pass_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_JOB_REWRITER_PASS_UTIL_H_\n#define ONEFLOW_CORE_JOB_REWRITER_PASS_UTIL_H_\n\n#include <string>\n#include <map>\n\n#include \"oneflow/core/graph/op_graph.h\"\n\nnamespace oneflow {\n#define INSERT_CHECK(expr) CHECK(expr.second)\n#define INSERT_CHECK_OR_RETURN(expr) CHECK_OR_RETURN(expr.second)\n\ntemplate<typename MapT, typename KeyT>\nbool IsKeyFound(const MapT& m, const KeyT& k) {\n  return m.find(k) != m.end();\n}\n\nbool IsNodeInList(const HashSet<std::string>& op_list, OpNode* node);\n\ntemplate<typename ContainerT, typename ElemT>\nstd::string Container2Str(const ContainerT& container,\n                          std::function<std::string(const ElemT&)> elem2str) {\n  std::string ret;\n  bool is_first = true;\n  for (const ElemT& elem : container) {\n    if (is_first) {\n      is_first = false;\n    } else {\n      ret += \",\\n\";\n    }\n    ret += elem2str(elem);\n  }\n  return ret;\n}\n\nstd::string ReplaceSlashToDash4Lbn(std::string lbn);\n\nvoid DfsTopoGraphTraversal(const OpGraph& graph, bool reversed,\n                           std::function<bool(OpNode*)> IsCurNodeStartNode,\n                           std::function<bool(OpNode*)> IsCurNodeSatisfied,\n                           std::function<bool(OpNode*)> IsFatherNodeSatisfied,\n                           std::function<void(OpNode*)> NodeHandler);\n\n// make sure an op_conf can only be udpated once, cuz later update will override before\nclass OpConfCache {\n  std::map<std::string, OperatorConf> _op_confs_to_update;\n\n public:\n  OperatorConf GetLatest(const OperatorConf& op_conf) {\n    if (_op_confs_to_update.find(op_conf.name()) != _op_confs_to_update.end()) {\n      return _op_confs_to_update[op_conf.name()];\n    }\n    return op_conf;\n  }\n  void Put(const OperatorConf& op_conf) { _op_confs_to_update[op_conf.name()] = op_conf; }\n  std::vector<OperatorConf> op_confs() {\n    std::vector<OperatorConf> ret;\n    for (const auto& x : _op_confs_to_update) { ret.emplace_back(x.second); }\n    return ret;\n  }\n};\n\nstd::function<bool(const OpNode* op_node)> MakePredicatorIsSafeToDelete(const OpGraph& op_graph);\nbool IsUserOpWithTypeName(const OperatorConf& op_conf, const std::string& op_type_name);\n\nstd::string GenParallelConfKey(const ParallelConf& conf);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_JOB_REWRITER_PASS_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/job_rewriter/pipeline_buffer_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/job_rewriter/calculation_pass.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass PipelineBufferPass final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PipelineBufferPass);\n  PipelineBufferPass() = default;\n  ~PipelineBufferPass() = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    // Pipeline optimization depends on gradient accumulatioin.\n    return ctx.job_desc().IsTrain()\n           && ctx.job_desc().job_conf().num_gradient_accumulation_steps() > 1;\n  }\n\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n};\n\nconst std::string kBufferOpNamePrefix = \"System-Pipeline-Buffer-Op_\";\n\nconst Scope& Scope4ScopeSymbolId(int64_t scope_symbol_id) {\n  CHECK(Singleton<symbol::Storage<Scope>>::Get()->Has(scope_symbol_id));\n  return Singleton<symbol::Storage<Scope>>::Get()->Get(scope_symbol_id);\n}\n\nconst Scope& Scope4OpNode(const OpNode* op_node) {\n  const OperatorConf& op_conf = op_node->op().op_conf();\n  CHECK(op_conf.has_scope_symbol_id());\n  return Scope4ScopeSymbolId(op_conf.scope_symbol_id());\n}\n\nbool IsForwardPass(const OpNode* node) {\n  return Scope4OpNode(node).scope_proto().calculation_pass_name() == kForwardPass;\n}\n\nbool IsBackwardPass(const OpNode* node) {\n  return Scope4OpNode(node).scope_proto().calculation_pass_name() == kBackwardPass;\n}\n\nbool OpNodeHasScope(const OpNode* node) { return node->op().op_conf().has_scope_symbol_id(); }\n\nbool IsIdentityBufferOrRepeatOpNode(const OpNode* node) {\n  const OperatorConf& op_conf = node->op().op_conf();\n  if (op_conf.has_user_conf()) {\n    const std::string& op_type_name = op_conf.user_conf().op_type_name();\n    if (op_type_name == \"identity_buffer\" || op_type_name == \"repeat\") { return true; }\n  }\n  return false;\n}\n\nint64_t GetStageIdHint(const OpNode* node) {\n  return Scope4OpNode(node).Int64(\"pipeline_stage_id_hint\");\n}\n\nvoid TryInsertOrUseBufferOpToDstNode(\n    const OpEdge* op_edge, const int64_t buffer_size,\n    HashMap<std::string, OperatorConf>* buffer_op_name2op_conf,\n    HashMap<std::string, ParallelConf>* buffer_op_name2parallel_conf,\n    HashMap<std::string, OperatorConf>* mut_op_name2conf) {\n  const OpNode* src_node = op_edge->src_node();\n  const OpNode* dst_node = op_edge->dst_node();\n  const int64_t src_stage_id = GetStageIdHint(src_node);\n 
const int64_t dst_stage_id = GetStageIdHint(dst_node);\n  const std::string& dst_op_name = dst_node->op().op_name();\n  for (const LogicalBlobId& lbi : op_edge->lbis()) {\n    std::string lbn = GenLogicalBlobName(lbi);\n    std::string buffer_op_name = kBufferOpNamePrefix + \"-\" + lbi.op_name() + \"-\" + lbi.blob_name()\n                                 + \"-stage_id_\" + std::to_string(dst_stage_id);\n\n    auto it = buffer_op_name2op_conf->find(buffer_op_name);\n    if (it == buffer_op_name2op_conf->end()) {\n      it = buffer_op_name2op_conf\n               ->emplace(buffer_op_name,\n                         user_op::UserOpConfWrapperBuilder(buffer_op_name)\n                             .Op(\"identity_buffer\")\n                             .Input(\"in\", lbn)\n                             .Output(\"out\")\n                             .Attr<int64_t>(\"buffer_size\", buffer_size)\n                             .ScopeSymbolId(dst_node->op().op_conf().scope_symbol_id())\n                             .Build()\n                             .op_conf())\n               .first;\n      CHECK(buffer_op_name2parallel_conf\n                ->emplace(buffer_op_name, dst_node->parallel_desc().parallel_conf())\n                .second);\n\n      VLOG(3) << \"\\n Insert buffer op : [\" << buffer_op_name << \"](buffer_size:\" << buffer_size\n              << \") \\n from [\" << src_node->op().op_name()\n              << \"] (stage_id:\" << std::to_string(src_stage_id) << \") -> [\"\n              << dst_node->op().op_name() << \"] (stage_id:\" << std::to_string(dst_stage_id)\n              << \") \\n\";\n    }\n\n    auto mut_op_it = mut_op_name2conf->find(dst_op_name);\n    if (mut_op_it == mut_op_name2conf->end()) {\n      mut_op_it = mut_op_name2conf->emplace(dst_op_name, dst_node->op().op_conf()).first;\n    }\n\n    const std::string buffer_out = user_op::UserOpConfWrapper(it->second).output(\"out\", 0);\n    for (const std::string& ibn : op_edge->lbi2ibns().at(lbi)) {\n      std::string old_lbn =\n          ReplaceInputLbnInOpCustomizedConf(&(mut_op_it->second), ibn, buffer_out);\n      CHECK_EQ(old_lbn, lbn);\n    }\n  }\n}\n\nvoid TryInsertOrUseBufferOpBothSrcDst(\n    const OpEdge* op_edge, const int64_t src_buffer_size, const int64_t dst_buffer_size,\n    HashMap<std::string, OperatorConf>* buffer_op_name2op_conf,\n    HashMap<std::string, ParallelConf>* buffer_op_name2parallel_conf,\n    HashMap<std::string, OperatorConf>* mut_op_name2conf) {\n  const OpNode* src_node = op_edge->src_node();\n  const OpNode* dst_node = op_edge->dst_node();\n  const ParallelDesc& src_parallel_desc = src_node->parallel_desc();\n  const ParallelDesc& dst_parallel_desc = dst_node->parallel_desc();\n  const std::string& src_op_name = src_node->op().op_name();\n  const std::string& dst_op_name = dst_node->op().op_name();\n  const int64_t src_stage_id = GetStageIdHint(src_node);\n  const int64_t dst_stage_id = GetStageIdHint(dst_node);\n  CHECK_NE(src_stage_id, dst_stage_id);\n  CHECK_GE(src_buffer_size, 1);\n  CHECK_GE(dst_buffer_size, 1);\n  CHECK(!src_parallel_desc.EqualsIgnoringHierarchy(dst_parallel_desc))\n      << \" Pipeline buffer pass met an ERROR! the src_op: \" << src_op_name\n      << \" -> dst_op: \" << dst_op_name\n      << \" with same placement: \" << src_parallel_desc.parallel_conf().DebugString()\n      << \", but with different stage id: src_stage_id (\" << src_stage_id << \") -> dst_stage_id (\"\n      << dst_stage_id << \"). 
Please check your stage id config for modules.\";\n  for (const LogicalBlobId& lbi : op_edge->lbis()) {\n    std::string lbn = GenLogicalBlobName(lbi);\n    std::string src_buffer_op_name =\n        kBufferOpNamePrefix + \"-\" + lbi.op_name() + \"-\" + lbi.blob_name();\n    std::string dst_buffer_op_name = kBufferOpNamePrefix + \"-\" + lbi.op_name() + \"-\"\n                                     + lbi.blob_name() + \"-stage_id_\"\n                                     + std::to_string(dst_stage_id);\n\n    auto src_buffer_it = buffer_op_name2op_conf->find(src_buffer_op_name);\n    if (src_buffer_it == buffer_op_name2op_conf->end()) {\n      src_buffer_it = buffer_op_name2op_conf\n                          ->emplace(src_buffer_op_name,\n                                    user_op::UserOpConfWrapperBuilder(src_buffer_op_name)\n                                        .Op(\"identity_buffer\")\n                                        .Input(\"in\", lbn)\n                                        .Output(\"out\")\n                                        .Attr<int64_t>(\"buffer_size\", src_buffer_size)\n                                        .ScopeSymbolId(src_node->op().op_conf().scope_symbol_id())\n                                        .Build()\n                                        .op_conf())\n                          .first;\n      CHECK(buffer_op_name2parallel_conf\n                ->emplace(src_buffer_op_name, src_parallel_desc.parallel_conf())\n                .second);\n    }\n    const OperatorConf& src_conf = src_buffer_it->second;\n    const std::string src_buffer_out = user_op::UserOpConfWrapper(src_conf).output(\"out\", 0);\n\n    auto dst_buffer_it = buffer_op_name2op_conf->find(dst_buffer_op_name);\n    if (dst_buffer_it == buffer_op_name2op_conf->end()) {\n      dst_buffer_it = buffer_op_name2op_conf\n                          ->emplace(dst_buffer_op_name,\n                                    user_op::UserOpConfWrapperBuilder(dst_buffer_op_name)\n                                        .Op(\"identity_buffer\")\n                                        .Input(\"in\", src_buffer_out)\n                                        .Output(\"out\")\n                                        .Attr<int64_t>(\"buffer_size\", dst_buffer_size)\n                                        .ScopeSymbolId(dst_node->op().op_conf().scope_symbol_id())\n                                        .Build()\n                                        .op_conf())\n                          .first;\n      CHECK(buffer_op_name2parallel_conf\n                ->emplace(dst_buffer_op_name, dst_parallel_desc.parallel_conf())\n                .second);\n    }\n    const OperatorConf& dst_conf = dst_buffer_it->second;\n\n    auto mut_op_it = mut_op_name2conf->find(dst_op_name);\n    if (mut_op_it == mut_op_name2conf->end()) {\n      mut_op_it = mut_op_name2conf->emplace(dst_op_name, dst_node->op().op_conf()).first;\n    }\n\n    VLOG(3) << \"\\n Insert buffer op pair : src_buffer = <\" << src_buffer_op_name\n            << \">(buffer_size:\" << src_buffer_size << \") , dst_buffer = <\" << dst_buffer_op_name\n            << \">(buffer_size:\" << dst_buffer_size << \") \\n from [\" << src_node->op().op_name()\n            << \"] (stage_id:\" << std::to_string(src_stage_id) << \") -> [\"\n            << dst_node->op().op_name() << \"] (stage_id:\" << std::to_string(dst_stage_id) << \") \\n\";\n\n    const std::string dst_buffer_out = user_op::UserOpConfWrapper(dst_conf).output(\"out\", 0);\n    for (const std::string& ibn : 
op_edge->lbi2ibns().at(lbi)) {\n      std::string old_lbn =\n          ReplaceInputLbnInOpCustomizedConf(&(mut_op_it->second), ibn, dst_buffer_out);\n      CHECK_EQ(old_lbn, lbn);\n    }\n  }\n}\n\nMaybe<void> PipelineBufferPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {\n  int64_t max_stage_id = 0;\n  op_graph.ForEachNode([&](const OpNode* this_node) {\n    if (!OpNodeHasScope(this_node)) {\n      LOG(WARNING) << \" op : \" << this_node->op().op_conf().DebugString() << \" has no scope!\";\n      return;\n    }\n    max_stage_id = std::max(max_stage_id, GetStageIdHint(this_node));\n  });\n\n  if (max_stage_id == 0) { return Maybe<void>::Ok(); }\n  const int64_t total_stage_num = max_stage_id + 1;\n  VLOG(3) << \"total stage num = \" << total_stage_num;\n\n  HashMap<std::string, OperatorConf> buffer_op_name2op_conf;\n  HashMap<std::string, ParallelConf> buffer_op_name2parallel_conf;\n  HashMap<std::string, OperatorConf> mut_op_name2conf;\n\n  op_graph.ForEachNode([&](const OpNode* this_node) {\n    if (!OpNodeHasScope(this_node)) { return; /* ignore op without scope */ }\n    if (!IsBackwardPass(this_node)) { return; /* ignore fw dst op */ }\n    for (const OpEdge* in_edge : this_node->in_edges()) {\n      const OpNode* src_node = in_edge->src_node();\n      if (!OpNodeHasScope(src_node)) { continue; /* ignore op without scope */ }\n      const int64_t src_stage_id = GetStageIdHint(src_node);\n      const int64_t dst_stage_id = GetStageIdHint(this_node);\n\n      if (IsForwardPass(src_node) && (!IsIdentityBufferOrRepeatOpNode(src_node))) {\n        if (dst_stage_id == max_stage_id) {\n          continue; /* last stage(loss) does NOT need to insert buffer */\n        }\n        if (src_stage_id != dst_stage_id) {\n          LOG(WARNING)\n              << \" Cross diff stage link From: [\" << src_node->op().op_conf().DebugString()\n              << \"](stage_id:\" << std::to_string(src_stage_id) << \") -> [\"\n              << this_node->op().op_conf().DebugString()\n              << \"](stage_id:\" << std::to_string(dst_stage_id)\n              << \"). 
Make sure to change the tensor's placement before it enters the module \"\n                 \"of the next pipeline stage.\\n\";\n        }\n        const int64_t buffer_size = total_stage_num * 2; /* NOTE(chengcheng): max buffer size */\n        TryInsertOrUseBufferOpToDstNode(in_edge, buffer_size, &buffer_op_name2op_conf,\n                                        &buffer_op_name2parallel_conf, &mut_op_name2conf);\n      }\n    }\n    for (const std::string& ctrl_in_op_name : this_node->op().op_conf().ctrl_in_op_name()) {\n      const OpNode* src_node = op_graph.OpNode4OpName(ctrl_in_op_name);\n      if (!OpNodeHasScope(src_node)) { continue; /* ignore op without scope */ }\n      if (IsForwardPass(src_node)) {\n        LOG(WARNING) << \"CtrlEdge: src_op[FwPass]: \" << src_node->op().op_conf().DebugString()\n                     << \" dst_op[BwPass]: \" << this_node->op().op_conf().DebugString()\n                     << \" connected.\";\n      }\n    }\n  });\n\n  op_graph.ForEachEdge([&](const OpEdge* edge) {\n    const OpNode* src_node = edge->src_node();\n    const OpNode* dst_node = edge->dst_node();\n    if (OpNodeHasScope(src_node) && OpNodeHasScope(dst_node) && IsForwardPass(src_node)\n        && IsForwardPass(dst_node)) {\n      const int64_t src_stage_id = GetStageIdHint(src_node);\n      const int64_t dst_stage_id = GetStageIdHint(dst_node);\n      if (src_node->parallel_desc().device_type() == DeviceType::kCPU\n          && dst_node->parallel_desc().device_type() == DeviceType::kCUDA) {\n        if (src_stage_id == 0 && dst_stage_id == max_stage_id) {\n          TryInsertOrUseBufferOpToDstNode(edge, total_stage_num * 2, &buffer_op_name2op_conf,\n                                          &buffer_op_name2parallel_conf, &mut_op_name2conf);\n          return;\n        }\n      }\n      if (src_stage_id < dst_stage_id) {\n        /* NOTE(chengcheng): We insert double buffer between src / dst node.\n         *   src_buffer_size = 1 because we need to free memory as early as possible so we can\n         *   overlap CopyD2H with Compute.\n         *   dst_buffer_size = dst_stage_id - src_stage_id for pipeline.\n         */\n        const int64_t dst_buffer_size = dst_stage_id - src_stage_id;\n        TryInsertOrUseBufferOpBothSrcDst(edge, 1, dst_buffer_size, &buffer_op_name2op_conf,\n                                         &buffer_op_name2parallel_conf, &mut_op_name2conf);\n      }\n    }\n    if (OpNodeHasScope(src_node) && OpNodeHasScope(dst_node) && IsBackwardPass(src_node)\n        && IsBackwardPass(dst_node)) {\n      const int64_t src_stage_id = GetStageIdHint(src_node);\n      const int64_t dst_stage_id = GetStageIdHint(dst_node);\n      // NOTE(chengcheng): Backward ONLY needs buffer size 1.\n      if (src_stage_id > dst_stage_id) {\n        TryInsertOrUseBufferOpBothSrcDst(edge, 1, 1, &buffer_op_name2op_conf,\n                                         &buffer_op_name2parallel_conf, &mut_op_name2conf);\n      }\n    }\n  });\n\n  for (auto& pair : buffer_op_name2op_conf) {\n    CHECK(buffer_op_name2parallel_conf.find(pair.first) != buffer_op_name2parallel_conf.end());\n    JUST(job_builder->AddOp(buffer_op_name2parallel_conf.at(pair.first), pair.second));\n  }\n  for (auto& pair : mut_op_name2conf) { JUST(job_builder->MutOpOnlyOnce(pair.second)); }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"PipelineBufferPass\", PipelineBufferPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/prune_amp_white_identity_op_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool IsAmpIdentityOp(const OperatorConf& op) {\n  return op.has_user_conf()\n         && (op.user_conf().op_type_name() == \"amp_white_identity\"\n             || op.user_conf().op_type_name() == \"amp_black_identity\");\n}\n\nbool NeedDoPass(const Job& job) {\n  return std::any_of(job.net().op().cbegin(), job.net().op().cend(), IsAmpIdentityOp);\n}\n\nclass PruneAmpWhiteIdentityOpPass final : public JobPass {\n public:\n  PruneAmpWhiteIdentityOpPass() = default;\n  ~PruneAmpWhiteIdentityOpPass() override = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;\n};\n\nMaybe<void> PruneAmpWhiteIdentityOpPass::Apply(Job* job, JobPassCtx* ctx) const {\n  if (!ctx->job_desc().prune_amp_white_identity_ops()) { return Maybe<void>::Ok(); }\n  if (!NeedDoPass(*job)) { return Maybe<void>::Ok(); }\n  const OpGraph op_graph(*job);\n\n  HashSet<std::string> ctrl_in_op_names;\n  op_graph.ForEachNode([&](const OpNode* op_node) {\n    for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) {\n      ctrl_in_op_names.insert(ctrl_in_op_name);\n    }\n  });\n\n  HashSet<const OpNode*> del_nodes;\n  op_graph.ForEachNode([&](const OpNode* op_node) {\n    const std::string& op_name = op_node->op().op_name();\n    const OperatorConf& op_conf = op_node->op().op_conf();\n    // not amp identity op\n    if (!IsAmpIdentityOp(op_conf)) { return; }\n    // has ctrl in\n    if (!op_conf.ctrl_in_op_name().empty()) { return; }\n    // is ctrl in of another op\n    if (ctrl_in_op_names.find(op_name) != ctrl_in_op_names.end()) { return; }\n    // not sole in\n    if (op_node->in_edges().size() != 1) { return; }\n\n    del_nodes.insert(op_node);\n  });\n\n  HashMap<std::string, OperatorConf> to_update_op_confs;\n  std::vector<std::string> del_op_names;\n  del_op_names.reserve(del_nodes.size());\n  for (const OpNode* op_node : del_nodes) {\n    del_op_names.emplace_back(op_node->op().op_name());\n\n    // find first node not deleted\n    const OpNode* first = op_node;\n    const OpNode* producer = op_node->SoleInEdge()->src_node();\n    while (del_nodes.find(producer) != del_nodes.end()) {\n      first = producer;\n      producer = producer->SoleInEdge()->src_node();\n    }\n\n    const auto& old_lbi = op_node->op().BnInOp2Lbi(op_node->op().SoleObn());\n    const auto& new_lbi = first->op().BnInOp2Lbi(first->op().SoleIbn());\n\n    for (const OpEdge* out_edge : op_node->out_edges()) {\n      const OpNode* consumer = out_edge->dst_node();\n      if (del_nodes.find(consumer) == del_nodes.end()) {\n        const Operator& op = consumer->op();\n        for (const std::string& ibn : op.input_bns()) {\n          if (op.BnInOp2Lbi(ibn) == old_lbi) {\n            auto iter = to_update_op_confs.find(op.op_name());\n            if (iter 
== to_update_op_confs.end()) {\n              iter = to_update_op_confs.emplace(op.op_name(), op.op_conf()).first;\n            }\n            OperatorConf& op_conf = iter->second;\n            const auto& old_val =\n                ReplaceInputLbnInOpCustomizedConf(&op_conf, ibn, GenLogicalBlobName(new_lbi));\n            CHECK_EQ_OR_RETURN(GenLogicalBlobName(old_lbi), old_val);\n          }\n        }\n      }\n    }\n  }\n\n  JobBuilder job_builder(job);\n  for (const auto& pair : to_update_op_confs) { job_builder.MutOpsOnlyOnce({pair.second}); }\n  job_builder.DelOps(del_op_names);\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"PruneAmpWhiteIdentityOpPass\", PruneAmpWhiteIdentityOpPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/prune_cast_to_static_shape_op_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool IsRelatedOp(const OperatorConf& op) {\n  return op.has_user_conf() && (op.user_conf().op_type_name() == \"cast_to_static_shape\");\n}\n\nbool NeedDoPass(const Job& job) {\n  return std::any_of(job.net().op().cbegin(), job.net().op().cend(), IsRelatedOp);\n}\n\nclass PruneCastToStaticShapeOpsPass final : public JobPass {\n public:\n  PruneCastToStaticShapeOpsPass() = default;\n  ~PruneCastToStaticShapeOpsPass() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return ctx.job_desc().IsTrain() && ctx.job_desc().prune_cast_to_static_shape_ops();\n  }\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    if (!NeedDoPass(*job)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n};\n\nMaybe<void> PruneCastToStaticShapeOpsPass::Apply(const OpGraph& op_graph,\n                                                 JobBuilder* job_builder) const {\n  HashMap<std::string, OperatorConf> op_name2op_conf;\n  HashSet<std::string> ctrl_in_op_names;\n  op_graph.ForEachNode([&](const OpNode* op_node) {\n    for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) {\n      ctrl_in_op_names.insert(ctrl_in_op_name);\n    }\n  });\n  std::vector<std::string> del_op_names;\n  op_graph.ForEachNode([&](const OpNode* op_node) {\n    const OperatorConf& op_conf = op_node->op().op_conf();\n    if (!op_conf.has_user_conf()) { return; }\n    const std::string& op_type_name = op_conf.user_conf().op_type_name();\n    if (op_type_name != \"cast_to_static_shape\") { return; }\n    if (!op_conf.ctrl_in_op_name().empty()) { return; }\n    if (ctrl_in_op_names.find(op_conf.name()) != ctrl_in_op_names.end()) { return; }\n    if (op_node->in_edges().size() != 1) { return; }\n    const user_op::UserOpConfWrapper user_op_conf(op_conf);\n    const LogicalBlobId& cast_in_lbi = GenLogicalBlobId(user_op_conf.input(\"input\", 0));\n    const LogicalBlobId& cast_out_lbi = GenLogicalBlobId(user_op_conf.output(\"output\", 0));\n    const OpNode* producer = op_graph.OpNode4OpName(cast_in_lbi.op_name());\n    const BlobDesc& cast_in_logical_blob_desc = producer->LogicalBlobDesc4Lbi(cast_in_lbi);\n    if (cast_in_logical_blob_desc.is_dynamic()) { return; }\n    for (const OpEdge* out_edge : op_node->out_edges()) {\n      const OpNode* consumer = out_edge->dst_node();\n      const std::string& consumer_op_name = consumer->op().op_name();\n      if (op_name2op_conf.find(consumer_op_name) == op_name2op_conf.end()) {\n        op_name2op_conf[consumer_op_name] = consumer->op().op_conf();\n      }\n   
   OperatorConf& consumer_op_conf = op_name2op_conf.at(consumer_op_name);\n      for (const std::string& ibn : consumer->op().input_bns()) {\n        if (consumer->op().BnInOp2Lbi(ibn) == cast_out_lbi) {\n          const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn,\n                                                                  GenLogicalBlobName(cast_in_lbi));\n          CHECK_EQ(GenLogicalBlobName(cast_out_lbi), old_val);\n        }\n      }\n    }\n    del_op_names.emplace_back(op_conf.name());\n  });\n  for (const auto& pair : op_name2op_conf) { job_builder->MutOpsOnlyOnce({pair.second}); }\n  job_builder->DelOps(del_op_names);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"PruneCastToStaticShapeOpsPass\", PruneCastToStaticShapeOpsPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/prune_depend_op_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <glog/logging.h>\n#include <string>\n#include <vector>\n#include \"oneflow/core/common/hash_container.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/graph/node.h\"\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/register/logical_blob_id.pb.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstruct UpdatedNodeInfo {\n  const OpNode* node = nullptr;\n  const OpNode* new_src_node = nullptr;\n  const OpNode* depend_node_nearest_src = nullptr;\n  const OpNode* depend_node_nearest_dst = nullptr;\n  std::vector<const OpNode*> new_in_ctrl_nodes;\n  bool updated = false;\n};\n\nbool IsDependyOp(const OperatorConf& op) {\n  return op.has_user_conf() && (op.user_conf().op_type_name() == \"depend\");\n}\n\nbool NeedDoPass(const Job& job) {\n  return std::any_of(job.net().op().cbegin(), job.net().op().cend(), IsDependyOp);\n}\n\nconst OpNode* GetNodeFromEdgeByTensorName(const OpNode* op_node,\n                                          const std::string& target_tensor_name) {\n  CHECK(IsDependyOp(op_node->op().op_conf()));\n  for (const OpEdge* in_edge : op_node->in_edges()) {\n    const OpNode* in_op_node = in_edge->src_node();\n    const std::string& in_op_node_name = in_op_node->op().op_name();\n    const HashMap<LogicalBlobId, std::vector<std::string>>& lbi2ibns = in_edge->lbi2ibns();\n\n    for (const auto& item : lbi2ibns) {\n      const std::string& lbi_op_name = item.first.op_name();\n      for (const std::string& tensor_name : item.second) {\n        if (in_op_node_name == lbi_op_name && tensor_name == target_tensor_name) {\n          return in_op_node;\n        }\n      }\n    }\n  }\n  return nullptr;\n}\n\nconst OpNode* GetNodeFromInputEdge(const OpNode* op_node) {\n  return GetNodeFromEdgeByTensorName(op_node, \"in_0\");\n}\n\nconst OpNode* GetNodeFromInCtrlEdge(const OpNode* op_node) {\n  return GetNodeFromEdgeByTensorName(op_node, \"depend_tensor_0\");\n}\n\nLogicalBlobId GetNewLbi(const OpNode* src_node, const OpNode* depend_node_nearest_src) {\n  CHECK(IsDependyOp(depend_node_nearest_src->op().op_conf()));\n  for (const OpEdge* out_edge : src_node->out_edges()) {\n    const OpNode* dst_node = out_edge->dst_node();\n    if (dst_node != depend_node_nearest_src) { continue; }\n\n    CHECK(out_edge->lbis().size() == 1);\n    return out_edge->lbis()[0];\n  }\n  // should not reach here\n  CHECK(false);\n  return {};\n}\n\nclass PruneDependOpPass final : public JobPass {\n public:\n  PruneDependOpPass() = default;\n  ~PruneDependOpPass() override = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;\n};\n\nMaybe<void> PruneDependOpPass::Apply(Job* job, JobPassCtx* ctx) const {\n  if (!ctx->job_desc().prune_depend_ops()) { return Maybe<void>::Ok(); }\n  if (!NeedDoPass(*job)) { return Maybe<void>::Ok(); }\n  const OpGraph op_graph(*job);\n\n  
HashMap<std::string, UpdatedNodeInfo> node_info_with_update;\n  std::vector<const OpNode*> ordered_nodes;\n\n  // Step 0: topological sort, set up a map for recording modifications\n  op_graph.TopoForEachNodeWithCtrlEdge([&](const OpNode* node) {\n    UpdatedNodeInfo node_info;\n    node_info.node = node;\n    node_info_with_update.emplace(node->op().op_name(), node_info);\n    ordered_nodes.emplace_back(node);\n  });\n\n  // Step 1: process nodes in topological order,\n  // recording modification info when meeting Depend OP nodes\n  for (const OpNode* cur_node : ordered_nodes) {\n    const std::string& cur_op_name = cur_node->op().op_name();\n    const OperatorConf& cur_op_conf = cur_node->op().op_conf();\n    if (!IsDependyOp(cur_op_conf)) { continue; }\n\n    // record modification info to each dst_node\n    for (const OpEdge* out_edge : cur_node->out_edges()) {\n      const OpNode* dst_node = out_edge->dst_node();\n      const Operator& dst_op = dst_node->op();\n\n      UpdatedNodeInfo& updated_dst_node_info = node_info_with_update.find(dst_op.op_name())->second;\n      UpdatedNodeInfo& updated_cur_node_info = node_info_with_update.find(cur_op_name)->second;\n      updated_dst_node_info.updated = true;\n      updated_dst_node_info.depend_node_nearest_dst = cur_node;\n\n      // Step 1.1: record a new in-ctrl node\n      const OpNode* cur_in_ctrl_node = GetNodeFromInCtrlEdge(cur_node);\n      updated_dst_node_info.new_in_ctrl_nodes.emplace_back(cur_in_ctrl_node);\n\n      // Step 1.2: inherit in-ctrl nodes from Depend OP nodes\n      const auto& ori_in_ctrl_op_names = cur_op_conf.ctrl_in_op_name();\n      for (const std::string& ori_ctrl_in_op_name : ori_in_ctrl_op_names) {\n        updated_dst_node_info.new_in_ctrl_nodes.emplace_back(\n            node_info_with_update[ori_ctrl_in_op_name].node);\n      }\n      if (updated_cur_node_info.updated) {\n        std::vector<const OpNode*>& inherit_in_ctrl_nodes = updated_cur_node_info.new_in_ctrl_nodes;\n        for (const OpNode* inherit_in_ctrl_node : inherit_in_ctrl_nodes) {\n          updated_dst_node_info.new_in_ctrl_nodes.emplace_back(inherit_in_ctrl_node);\n        }\n      }\n\n      // Step 1.3: process src nodes\n      const OpNode* cur_src_node = GetNodeFromInputEdge(cur_node);\n      if (IsDependyOp(dst_node->op().op_conf()) && cur_node == GetNodeFromInCtrlEdge(dst_node)) {\n        // \"cur_node\" and \"dst_node\" are both Depend OP nodes, and their connection is like this\n        // other_node   cur_node\n        //          \\   /\n        //         dst_node\n        // in this case, all src nodes of \"cur_node\" should be seen as in-ctrl nodes\n        if (updated_cur_node_info.updated && updated_cur_node_info.new_src_node) {\n          updated_dst_node_info.new_in_ctrl_nodes.emplace_back(updated_cur_node_info.new_src_node);\n        }\n        updated_dst_node_info.new_in_ctrl_nodes.emplace_back(cur_src_node);\n      } else {\n        if (!IsDependyOp(cur_src_node->op().op_conf())) {\n          updated_dst_node_info.new_src_node = cur_src_node;\n          updated_dst_node_info.depend_node_nearest_src = cur_node;\n        } else if (updated_cur_node_info.updated && updated_cur_node_info.new_src_node) {\n          updated_dst_node_info.new_src_node = updated_cur_node_info.new_src_node;\n          updated_dst_node_info.depend_node_nearest_src =\n              updated_cur_node_info.depend_node_nearest_src;\n        }\n      }\n    }\n  }\n\n  // Step 2: extract modification info,\n  // including new connections and the nodes to delete\n  
std::vector<std::string> del_node_names;\n  HashMap<std::string, OperatorConf> to_update_op_confs;\n  for (const auto& node_info : node_info_with_update) {\n    // filter nodes not updated\n    if (!node_info.second.updated) { continue; }\n    const OpNode* cur_node = node_info.second.node;\n    const std::string& cur_op_name = cur_node->op().op_name();\n    // filter Depend nodes\n    if (IsDependyOp(cur_node->op().op_conf())) {\n      del_node_names.emplace_back(cur_op_name);\n      continue;\n    }\n\n    const Operator& cur_op = cur_node->op();\n    auto iter = to_update_op_confs.find(node_info.first);\n    if (iter == to_update_op_confs.end()) {\n      iter = to_update_op_confs.emplace(node_info.first, cur_op.op_conf()).first;\n    }\n    OperatorConf& cur_op_conf = iter->second;\n\n    // Step 2.1: connect updated src_node with cur_node (dst_node of Depend OP)\n    const OpNode* src_node = node_info.second.new_src_node;\n    const OpNode* depend_node_nearest_dst = node_info.second.depend_node_nearest_dst;\n    const OpNode* depend_node_nearest_src = node_info.second.depend_node_nearest_src;\n    CHECK(src_node && depend_node_nearest_dst && depend_node_nearest_src);\n    const auto& old_lbi =\n        depend_node_nearest_dst->op().BnInOp2Lbi(depend_node_nearest_dst->op().SoleObn());\n    const auto new_lbi = GetNewLbi(src_node, depend_node_nearest_src);\n    for (const std::string& ibn : cur_node->op().input_bns()) {\n      if (cur_op.BnInOp2Lbi(ibn) == old_lbi) {\n        const auto& old_val =\n            ReplaceInputLbnInOpCustomizedConf(&cur_op_conf, ibn, GenLogicalBlobName(new_lbi));\n        CHECK_EQ(GenLogicalBlobName(old_lbi), old_val);\n        VLOG(3) << \"Update input edge, Src Node: \" << src_node->op().op_name()\n                << \"\\t->\\tDst Node: \" << cur_op_name;\n      }\n    }\n\n    // Step 2.2: add in-ctrl OPs\n    const auto& existed_ctrl_in_op_names = cur_op_conf.ctrl_in_op_name();\n    for (const OpNode* in_ctrl_node : node_info.second.new_in_ctrl_nodes) {\n      // filter Depend nodes\n      if (IsDependyOp(in_ctrl_node->op().op_conf())) { continue; }\n      CHECK(cur_node != in_ctrl_node);  // guard against self-loop\n      const std::string& new_ctrl_in_op_name = in_ctrl_node->op().op_name();\n      auto existed_it = std::find(existed_ctrl_in_op_names.begin(), existed_ctrl_in_op_names.end(),\n                                  new_ctrl_in_op_name);\n      // filter src node or duplicate in-ctrl nodes\n      if (in_ctrl_node != src_node && existed_it == existed_ctrl_in_op_names.end()) {\n        cur_op_conf.add_ctrl_in_op_name(new_ctrl_in_op_name);\n        VLOG(3) << \"Add in-ctrl edge, Src Node: \" << new_ctrl_in_op_name\n                << \"\\t->\\tDst Node: \" << cur_op_name;\n      }\n    }\n  }\n\n  // Step 3: apply modifications to the job\n  JobBuilder job_builder(job);\n  for (const auto& pair : to_update_op_confs) { job_builder.MutOpsOnlyOnce({pair.second}); }\n  job_builder.DelOps(del_node_names);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"PruneDependOpPass\", PruneDependOpPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/prune_parallel_cast_op_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool IsParallelCastOp(const OperatorConf& op_conf) {\n  return op_conf.has_user_conf()\n         && (op_conf.user_conf().op_type_name() == \"parallel_cast\"\n             || op_conf.user_conf().op_type_name() == \"hierarchical_parallel_cast\"\n             || op_conf.user_conf().op_type_name() == \"hierarchical_parallel_cast_like\");\n}\n\nbool NeedDoPass(const Job& job) {\n  return std::any_of(job.net().op().cbegin(), job.net().op().cend(), IsParallelCastOp);\n}\n\nclass PruneParallelCastOpsPass final : public JobPass {\n public:\n  PruneParallelCastOpsPass() = default;\n  ~PruneParallelCastOpsPass() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().prune_parallel_cast_ops(); }\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    if (!NeedDoPass(*job)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n};\n\nMaybe<void> PruneParallelCastOpsPass::Apply(const OpGraph& op_graph,\n                                            JobBuilder* job_builder) const {\n  HashMap<std::string, OperatorConf> op_name2op_conf;\n  HashMap<std::string, NdSbpSignature> op_name2nd_sbp_signature;\n  HashSet<std::string> ctrl_in_op_names;\n  op_graph.ForEachNode([&](const OpNode* op_node) {\n    for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) {\n      ctrl_in_op_names.insert(ctrl_in_op_name);\n    }\n  });\n  std::vector<std::string> del_op_names;\n  op_graph.ForEachNode([&](const OpNode* op_node) {\n    const OperatorConf& op_conf = op_node->op().op_conf();\n    if (!op_conf.ctrl_in_op_name().empty()) { return; }\n    if (ctrl_in_op_names.find(op_conf.name()) != ctrl_in_op_names.end()) { return; }\n    if (!IsParallelCastOp(op_conf)) { return; }\n    if (op_node->in_edges().size() != 1) { return; }\n    user_op::UserOpConfWrapper conf_wrapper(op_conf);\n    const LogicalBlobId& parallel_cast_in_lbi = GenLogicalBlobId(conf_wrapper.input(\"in\", 0));\n    const LogicalBlobId& parallel_cast_out_lbi = GenLogicalBlobId(conf_wrapper.output(\"out\", 0));\n    const OpNode* producer = op_graph.OpNode4OpName(parallel_cast_in_lbi.op_name());\n    const NdSbp& parallel_cast_nd_sbp = op_node->NdSbp4Lbi(parallel_cast_in_lbi);\n    const NdSbp& producer_nd_sbp = producer->NdSbp4Lbi(parallel_cast_in_lbi);\n    if (op_node->parallel_desc() != producer->parallel_desc()) { return; }\n    if (parallel_cast_nd_sbp != producer_nd_sbp && op_node->out_edges().size() > 1) { return; }\n    for (const OpEdge* out_edge : op_node->out_edges()) {\n      const OpNode* 
consumer = out_edge->dst_node();\n      if (IsParallelCastOp(consumer->op().op_conf())) { return; }\n      if (consumer->parallel_desc() != op_node->parallel_desc()) { return; }\n      if (consumer->NdSbp4Lbi(parallel_cast_out_lbi) != parallel_cast_nd_sbp) { return; }\n    }\n    op_name2nd_sbp_signature[producer->op().op_name()] = producer->nd_sbp_signature();\n    for (const OpEdge* out_edge : op_node->out_edges()) {\n      const OpNode* consumer = out_edge->dst_node();\n      const std::string& consumer_op_name = consumer->op().op_name();\n      op_name2nd_sbp_signature[consumer_op_name] = consumer->nd_sbp_signature();\n      if (op_name2op_conf.find(consumer_op_name) == op_name2op_conf.end()) {\n        op_name2op_conf[consumer_op_name] = consumer->op().op_conf();\n      }\n      OperatorConf& consumer_op_conf = op_name2op_conf.at(consumer_op_name);\n      for (const std::string& ibn : consumer->op().input_bns()) {\n        if (consumer->op().BnInOp2Lbi(ibn) == parallel_cast_out_lbi) {\n          const auto& new_val = GenLogicalBlobName(parallel_cast_in_lbi);\n          const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, new_val);\n          CHECK_EQ(GenLogicalBlobName(parallel_cast_out_lbi), old_val);\n        }\n      }\n    }\n    del_op_names.emplace_back(op_conf.name());\n  });\n  for (const auto& pair : op_name2op_conf) { job_builder->MutOpsOnlyOnce({pair.second}); }\n  for (const auto& pair : op_name2nd_sbp_signature) {\n    job_builder->AddNdSbpSignature4OpName(pair.first, pair.second);\n  }\n  job_builder->DelOps(del_op_names);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"PruneParallelCastOpsPass\", PruneParallelCastOpsPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/prune_pinned_identity_op_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass PrunePinnedIdentityOpPass final : public JobPass {\n public:\n  PrunePinnedIdentityOpPass() = default;\n  ~PrunePinnedIdentityOpPass() override = default;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;\n};\n\nMaybe<std::string> PrunePinnedIdentityOp(JobBuilder* job_builder,\n                                         std::vector<std::string>* outdated_ops,\n                                         const OpGraph& op_graph, const std::string& lbn) {\n  auto lbi = GenLogicalBlobId(lbn);\n  const OpNode* op_node = op_graph.OpNode4OpName(lbi.op_name());\n  CHECK_EQ_OR_RETURN(op_node->in_edges().size(), 1);  // NOLINT\n  const OperatorConf& op_conf = op_node->op().op_conf();\n  CHECK_OR_RETURN(op_conf.has_user_conf());  // NOLINT\n  const std::string& op_type_name = op_conf.user_conf().op_type_name();\n  CHECK_OR_RETURN(op_type_name == \"pinned_identity\");  // NOLINT\n\n  // skip prune if the pinned identity has `ctrl_in_op`\n  if (!op_conf.ctrl_in_op_name().empty()) { return lbn; }\n\n  const user_op::UserOpConfWrapper user_op_conf(op_conf);\n  const LogicalBlobId& in_lbi = GenLogicalBlobId(user_op_conf.input(\"in\", 0));\n  const LogicalBlobId& out_lbi = GenLogicalBlobId(user_op_conf.output(\"out\", 0));\n\n  op_node->ForEachNodeOnOutEdge([&](const OpNode* out_node) {\n    for (const std::string& ibn : out_node->op().input_bns()) {\n      if (out_node->op().BnInOp2Lbi(ibn) == out_lbi) {\n        if (!CHECK_JUST(job_builder->IsInMutOpTransaction(out_node->op().op_name()))) {\n          CHECK_JUST(job_builder->MutOpTransactionMut(out_node->op().op_conf()));\n        }\n        OperatorConf& mut_consumer_op =\n            CHECK_JUST(job_builder->MutOpTransactionGet(out_node->op().op_name()));\n        const auto& old_lbn =\n            ReplaceInputLbnInOpCustomizedConf(&mut_consumer_op, ibn, GenLogicalBlobName(in_lbi));\n        CHECK_EQ(old_lbn, GenLogicalBlobName(out_lbi));\n      }\n    }\n  });\n  outdated_ops->push_back(op_conf.name());\n  return GenLogicalBlobName(in_lbi);\n}\n\nMaybe<void> PrunePinnedIdentityOpPass::Apply(Job* job, JobPassCtx* ctx) const {\n  if (!job->job_conf().has_train_conf()) { return Maybe<void>::Ok(); }\n  const OpGraph op_graph(*job);\n  JobBuilder job_builder(job);\n  HashMap<std::string, std::string> pruned_lbns;\n  std::vector<std::string> outdated_ops;\n  TrainConf* train_conf = job->mutable_job_conf()->mutable_train_conf();\n  // prune loss pinned identity\n  for (int i = 0; i < train_conf->loss_lbn_size(); ++i) {\n    const auto& pinned_loss_lbn = train_conf->loss_lbn(i);\n    auto it = pruned_lbns.find(pinned_loss_lbn);\n    if (it == pruned_lbns.end()) {\n      const auto& loss_lbn =\n          JUST(PrunePinnedIdentityOp(&job_builder, &outdated_ops, op_graph, pinned_loss_lbn));\n      
it = pruned_lbns.emplace(pinned_loss_lbn, *loss_lbn).first;\n    }\n    train_conf->set_loss_lbn(i, it->second);\n  }\n  // prune loss initial gradient pinned identity\n  for (int i = 0; i < train_conf->loss_grad_lbn_size(); ++i) {\n    const auto& pinned_loss_grad_lbn = train_conf->loss_grad_lbn(i);\n    auto it = pruned_lbns.find(pinned_loss_grad_lbn);\n    if (it == pruned_lbns.end()) {\n      const auto& loss_grad_lbn =\n          JUST(PrunePinnedIdentityOp(&job_builder, &outdated_ops, op_graph, pinned_loss_grad_lbn));\n      it = pruned_lbns.emplace(pinned_loss_grad_lbn, *loss_grad_lbn).first;\n    }\n    train_conf->set_loss_grad_lbn(i, it->second);\n  }\n  // prune variable gradient pinned identity\n  for (int i = 0; i < train_conf->optimizer_conf_size(); ++i) {\n    auto* optimizer_conf = train_conf->mutable_optimizer_conf(i);\n    for (int j = 0; j < optimizer_conf->variable_grad_lbns_size(); ++j) {\n      const auto& pinned_variable_grad_lbn = optimizer_conf->variable_grad_lbns(j);\n      if (pinned_variable_grad_lbn.empty()) { continue; }\n      auto it = pruned_lbns.find(pinned_variable_grad_lbn);\n      if (it == pruned_lbns.end()) {\n        const auto& variable_grad_lbn = JUST(\n            PrunePinnedIdentityOp(&job_builder, &outdated_ops, op_graph, pinned_variable_grad_lbn));\n        it = pruned_lbns.emplace(pinned_variable_grad_lbn, *variable_grad_lbn).first;\n      }\n      optimizer_conf->set_variable_grad_lbns(j, it->second);\n    }\n  }\n  job_builder.DelOps(outdated_ops);\n  JUST(job_builder.MutOpTransactionCommit());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"PrunePinnedIdentityOpPass\", PrunePinnedIdentityOpPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/quantization_aware_training.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/job_conf.pb.h\"\n#include <algorithm>\n\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job_rewriter/pass_util.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nusing OpTypeSet = HashSet<std::string>;\n\nconst std::string FAKE_QUANT_SUFFIX = \"-fake-quant\";\nconst std::string ZP_SUFFIX = \"-fake-quant-zp\";\nconst std::string MOVING_MAX_SUFFIX = \"-fake-quant-moving-max\";\nconst std::string MOVING_MIN_SUFFIX = \"-fake-quant-moving-min\";\nconst std::string MUL_BIAS_SUFFIX = \"-fake-quant-mul-bias\";\nconst std::string OBSERVER_SUFFIX = \"-fake-quant-observer\";\nconst std::string TRAIN_STEP_SUFFIX = \"-fake-train-step\";\n\nMaybe<void> VerifyQATList(const OpTypeSet& op_list) {\n  for (const auto& op_type : op_list) {\n    CHECK_OR_RETURN(user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type) != nullptr)\n        << \"Cannot find \" << op_type << \" of QuantAwareTraining list in OpRegistry.\";\n  }\n  return Maybe<void>::Ok();\n}\n\nHashMap<std::string, std::string> scale_map;\n\nMaybe<std::string> GetScaleLbn(const std::string& lbn) {\n  CHECK_OR_RETURN(scale_map.find(lbn) != scale_map.end());\n  return scale_map[lbn];\n}\n\nMaybe<bool> IsConvBiasEdge(const QatConfig& qat_config, const OpEdge* edge,\n                           std::string* conv_input_scale_lbn, std::string* conv_weight_scale_lbn,\n                           int64_t* weight_scale_length) {\n  const auto* dst_node = edge->dst_node();\n\n  const auto dst_op_type = dst_node->op().op_conf().user_conf().op_type_name();\n\n  auto GetInputAndWeightScaleLbnAndWeightScaleLen4ConvNode =\n      [](const QatConfig& qat_config, const OpNode* conv_node, std::string* conv_input_scale_lbn,\n         std::string* conv_weight_scale_lbn, int64_t* weight_scale_length) -> Maybe<void> {\n    *weight_scale_length = 1;\n    for (const OpEdge* in_edge : conv_node->in_edges()) {\n      CHECK_EQ_OR_RETURN(in_edge->lbis().size(), 1);\n      const auto lbi = in_edge->lbis().front();\n      const auto ibn = in_edge->lbi2ibns().at(lbi);\n      CHECK_EQ_OR_RETURN(ibn.size(), 1);\n      CHECK_OR_RETURN(ibn[0] == \"in_0\" || ibn[0] == \"weight_0\");\n      if (ibn[0] == \"in_0\") {\n        *conv_input_scale_lbn = *JUST(GetScaleLbn(GenLogicalBlobName(in_edge->lbis()[0])));\n      } else if (ibn[0] == \"weight_0\") {\n        if (qat_config.per_channel_weight_quantization()) {\n          *weight_scale_length = conv_node->LogicalBlobDesc4Lbi(lbi).shape().At(0);\n        }\n        *conv_weight_scale_lbn = *JUST(GetScaleLbn(GenLogicalBlobName(in_edge->lbis()[0])));\n      }\n    }\n    return Maybe<void>::Ok();\n  };\n\n  if 
(dst_op_type == \"conv2d\") {\n    CHECK_EQ_OR_RETURN(edge->lbis().size(), 1);\n    const auto lbi = edge->lbis().front();\n    const auto ibn = edge->lbi2ibns().at(lbi);\n    CHECK_EQ_OR_RETURN(ibn.size(), 1);\n    if (ibn[0] == \"bias_0\") {\n      JUST(GetInputAndWeightScaleLbnAndWeightScaleLen4ConvNode(\n          qat_config, dst_node, conv_input_scale_lbn, conv_weight_scale_lbn, weight_scale_length));\n      return true;\n    }\n  } else if (dst_op_type == \"bias_add\") {\n    // check whether the bias_add corresponds to a conv\n    for (const OpEdge* edge : dst_node->in_edges()) {\n      const auto* src_node = edge->src_node();\n      if (src_node->op().op_conf().user_conf().op_type_name() == \"conv2d\") {\n        JUST(GetInputAndWeightScaleLbnAndWeightScaleLen4ConvNode(\n            qat_config, src_node, conv_input_scale_lbn, conv_weight_scale_lbn,\n            weight_scale_length));\n        return true;\n      }\n    }\n  }\n  return false;\n}\n\nbool IsWeightEdge(const OpEdge* edge) {\n  return edge->src_node()->op().op_conf().has_variable_conf();\n}\n\nbool IsBnInputEdge(const OpEdge* edge) {\n  // Skip the inputs of bn for now.\n  // In the complete qat pass, bn will be merged into conv.\n  return edge->dst_node()->op().op_conf().user_conf().op_type_name() == \"normalization\";\n}\n\nstd::string OpTypeName4OpNode(const OpNode* node) {\n  return node->op().op_conf().user_conf().op_type_name();\n}\n\nusing OpConfMap = HashMap<std::string, OperatorConf>;\n\ntemplate<DataType data_type = DataType::kFloat>\nOperatorConf Get1DZeroVariableOpConf(std::string name, const int64_t scope_symbol_id,\n                                     const int64_t length, OpConfMap* inserted_ops) {\n  OperatorConf variable_op_conf{};\n  variable_op_conf.set_name(name);\n  variable_op_conf.set_scope_symbol_id(scope_symbol_id);\n  VariableOpConf* variable_conf = variable_op_conf.mutable_variable_conf();\n  variable_conf->set_out(\"out\");\n  *variable_conf->mutable_shape()->mutable_dim()->Add() = length;\n  variable_conf->set_data_type(data_type);\n  variable_conf->mutable_initializer()->mutable_constant_conf()->set_value(0);\n  (*inserted_ops)[name] = variable_op_conf;\n  return variable_op_conf;\n}\n\nMaybe<OpNode*> GetInferenceOutputNode(const OpGraph& op_graph, OpNode* node) {\n  OpNode* cur_node = node;\n  if (node->op().op_conf().user_conf().op_type_name() == \"conv2d\"\n      && node->out_edges().size() == 1) {\n    OpNode* next_node = node->SoleOutEdge()->dst_node();\n    if (OpTypeName4OpNode(next_node) == \"bias_add\") {\n      cur_node = next_node;\n      if (next_node->out_edges().size() == 1) { next_node = next_node->SoleOutEdge()->dst_node(); }\n    }\n    if (OpTypeName4OpNode(next_node) == \"normalization\") {\n      cur_node = next_node;\n      if (next_node->out_edges().size() == 1) { next_node = next_node->SoleOutEdge()->dst_node(); }\n    }\n    if (OpTypeName4OpNode(next_node) == \"relu\") { cur_node = next_node; }\n  }\n  VLOG(3) << \"For node: \" << node->op().op_name();\n  VLOG(3) << \"output node is: \" << cur_node->op().op_name();\n  return cur_node;\n}\n\nbool PerLayerQuantizationAttr4Config(const QatConfig& qat_config) {\n  return !qat_config.per_channel_weight_quantization();\n}\n\nstd::string QuantizationSchemeAttr4QatConfig(const QatConfig& qat_config) {\n  return qat_config.symmetric() ? 
\"symmetric\" : \"affine\";\n}\n\n// TODO: refactor the following 4 methods by registration\nMaybe<std::string> QuantizationFormulaAttr4QatConfig(const QatConfig& qat_config) {\n  const auto target_backend = qat_config.target_backend();\n  if (target_backend == \"\" || target_backend == \"tensorrt\") {\n    return std::string(\"google\");\n  } else if (target_backend == \"cambricon\") {\n    return std::string(\"cambricon\");\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n}\n\nMaybe<OpTypeSet> Int8List4QatConfig(const QatConfig& qat_config) {\n  const auto target_backend = qat_config.target_backend();\n  if (target_backend == \"\") {\n    return OpTypeSet{\"add_n\",  \"matmul\",         \"batch_matmul\",\n                     \"conv2d\", \"tf_avg_pool_2d\", \"tf_max_pool_2d\"};\n  } else if (target_backend == \"cambricon\" || target_backend == \"tensorrt\") {\n    return OpTypeSet{\"conv2d\", \"matmul\"};\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n}\n\nMaybe<OpTypeSet> TransparentList4QatConfig(const QatConfig& qat_config) {\n  const auto target_backend = qat_config.target_backend();\n  if (target_backend == \"\" || target_backend == \"tensorrt\") {\n    return OpTypeSet{\"reshape\"};\n  } else if (target_backend == \"cambricon\") {\n    return OpTypeSet{};\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n}\n\nMaybe<bool> InsertQuantOpAfterInt8Ops4QatConfig(const QatConfig& qat_config) {\n  const auto target_backend = qat_config.target_backend();\n  if (target_backend == \"\" || target_backend == \"tensorrt\") {\n    return true;\n  } else if (target_backend == \"cambricon\") {\n    return false;\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n}\n\nuser_op::UserOpConfWrapper MultiplyOp(const std::string& name, const std::string& x,\n                                      const std::string& y, const int64_t scope_symbol_id,\n                                      OpConfMap* inserted_ops) {\n  auto op_wrapper = user_op::UserOpConfWrapperBuilder(name)\n                        .Op(\"broadcast_mul\")\n                        .Input(\"x\", x)\n                        .Input(\"y\", y)\n                        .Output(\"z\")\n                        .ScopeSymbolId(scope_symbol_id)\n                        .Build();\n  (*inserted_ops)[name] = op_wrapper.op_conf();\n  return op_wrapper;\n}\n\nMaybe<user_op::UserOpConfWrapper> MinMaxObserver(const std::string& name, const std::string& input,\n                                                 const QatConfig& qat_config,\n                                                 const int64_t scope_symbol_id,\n                                                 OpConfMap* inserted_ops) {\n  const auto op_wrapper =\n      user_op::UserOpConfWrapperBuilder(name)\n          .Op(\"min_max_observer\")\n          .Input(\"in\", input)\n          .Output(\"scale\")\n          .Output(\"zero_point\")\n          .Attr<std::string>(\"quantization_formula\",\n                             *JUST(QuantizationFormulaAttr4QatConfig(qat_config)))\n          .Attr<std::string>(\"quantization_scheme\", QuantizationSchemeAttr4QatConfig(qat_config))\n          .Attr(\"per_layer_quantization\", PerLayerQuantizationAttr4Config(qat_config))\n          .ScopeSymbolId(scope_symbol_id)\n          .Build();\n  (*inserted_ops)[name] = op_wrapper.op_conf();\n  return op_wrapper;\n}\n\nMaybe<user_op::UserOpConfWrapper> MovingMinMaxObserver(\n    const std::string& name, const std::string& input, const std::string& train_step_lbn,\n    const QatConfig& qat_config, const 
int64_t scope_symbol_id, OpConfMap* inserted_ops) {\n  const std::string moving_max_name = name + MOVING_MAX_SUFFIX;\n  const std::string moving_min_name = name + MOVING_MIN_SUFFIX;\n  const auto moving_max_var =\n      Get1DZeroVariableOpConf(moving_max_name, scope_symbol_id, 1, inserted_ops);\n  const auto moving_min_var =\n      Get1DZeroVariableOpConf(moving_min_name, scope_symbol_id, 1, inserted_ops);\n  std::string observer_current_train_step = train_step_lbn;\n  if (!GlobalJobDesc().IsTrain()) {\n    const std::string train_step_name = name + TRAIN_STEP_SUFFIX;\n    const auto train_step_var = Get1DZeroVariableOpConf<DataType::kInt64>(\n        train_step_name, scope_symbol_id, 1, inserted_ops);\n    observer_current_train_step =\n        GenLogicalBlobName(train_step_var.name(), train_step_var.variable_conf().out());\n  }\n  const auto op_wrapper =\n      user_op::UserOpConfWrapperBuilder(name)\n          .Op(\"moving_average_min_max_observer\")\n          .Input(\"in\", input)\n          .Input(\"current_train_step\", observer_current_train_step)\n          .Input(\"moving_max\",\n                 GenLogicalBlobName(moving_max_var.name(), moving_max_var.variable_conf().out()))\n          .Input(\"moving_min\",\n                 GenLogicalBlobName(moving_min_var.name(), moving_min_var.variable_conf().out()))\n          .Output(\"scale\")\n          .Output(\"zero_point\")\n          .Attr(\"training\", GlobalJobDesc().IsTrain())\n          .Attr(\"stop_update_after_iters\", qat_config.moving_min_max_stop_update_after_iters())\n          .Attr<std::string>(\"quantization_formula\",\n                             *JUST(QuantizationFormulaAttr4QatConfig(qat_config)))\n          .Attr<std::string>(\"quantization_scheme\", QuantizationSchemeAttr4QatConfig(qat_config))\n          .Attr(\"momentum\", qat_config.moving_min_max_momentum())\n          .ScopeSymbolId(scope_symbol_id)\n          .Build();\n  (*inserted_ops)[name] = op_wrapper.op_conf();\n  return op_wrapper;\n}\n\nMaybe<user_op::UserOpConfWrapper> FakeQuantOp(const std::string& name, const std::string& input,\n                                              const std::string& scale,\n                                              const std::string& zero_point,\n                                              const QatConfig& qat_config,\n                                              const int64_t scope_symbol_id,\n                                              OpConfMap* inserted_ops) {\n  const auto op_wrapper =\n      user_op::UserOpConfWrapperBuilder(name)\n          .Op(\"fake_quantization\")\n          .Input(\"in\", input)\n          .Input(\"scale\", scale)\n          .Input(\"zero_point\", zero_point)\n          .Attr<std::string>(\"quantization_formula\",\n                             *JUST(QuantizationFormulaAttr4QatConfig(qat_config)))\n          .Attr<std::string>(\"quantization_scheme\", QuantizationSchemeAttr4QatConfig(qat_config))\n          .Output(\"out\")\n          .ScopeSymbolId(scope_symbol_id)\n          .Build();\n  (*inserted_ops)[name] = op_wrapper.op_conf();\n  return op_wrapper;\n}\n\nMaybe<void> GetScaleAndZeroPointLbn4Edge(OpEdge* edge, const std::string train_step_lbn,\n                                         std::string* scale, std::string* zero_point,\n                                         const QatConfig& qat_config, const int64_t scope_symbol_id,\n                                         OpConfMap* inserted_ops) {\n  std::string lbn = GenLogicalBlobName(edge->lbis().front());\n  std::string 
conv_input_scale_lbn;\n  std::string conv_weight_scale_lbn;\n  int64_t weight_scale_length;\n  if (JUST(IsConvBiasEdge(qat_config, edge, &conv_input_scale_lbn, &conv_weight_scale_lbn,\n                          &weight_scale_length))) {\n    // mul scale\n    const std::string mul_scale_op_name = ReplaceSlashToDash4Lbn(lbn) + MUL_BIAS_SUFFIX;\n    CHECK_OR_RETURN(inserted_ops->find(mul_scale_op_name) == inserted_ops->end());\n    const auto mul_scale_op = MultiplyOp(mul_scale_op_name, conv_input_scale_lbn,\n                                         conv_weight_scale_lbn, scope_symbol_id, inserted_ops);\n\n    *scale = mul_scale_op.output(\"z\", 0);\n    const std::string zp_var_name = ReplaceSlashToDash4Lbn(lbn) + ZP_SUFFIX;\n    const auto zp_var =\n        Get1DZeroVariableOpConf(zp_var_name, scope_symbol_id, weight_scale_length, inserted_ops);\n    *zero_point = GenLogicalBlobName(zp_var.name(), zp_var.variable_conf().out());\n  } else {\n    const std::string observer_op_name = ReplaceSlashToDash4Lbn(lbn) + OBSERVER_SUFFIX;\n    if (IsWeightEdge(edge)) {\n      const auto observer_op =\n          JUST(MinMaxObserver(observer_op_name, lbn, qat_config, scope_symbol_id, inserted_ops));\n      *scale = observer_op->output(\"scale\", 0);\n      *zero_point = observer_op->output(\"zero_point\", 0);\n    } else {\n      CHECK_OR_RETURN(qat_config.has_moving_min_max_stop_update_after_iters());\n      const auto observer_op = JUST(MovingMinMaxObserver(\n          observer_op_name, lbn, train_step_lbn, qat_config, scope_symbol_id, inserted_ops));\n      *scale = observer_op->output(\"scale\", 0);\n      *zero_point = observer_op->output(\"zero_point\", 0);\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReplaceInputLbn4DstNodeOfEdge(OpEdge* edge, const std::string& new_lbn,\n                                          OpConfCache* op_conf_cache) {\n  OpNode* dst_node = edge->dst_node();\n  LogicalBlobId cur_lbi = edge->lbis().front();\n  CHECK_EQ_OR_RETURN(1, edge->lbi2ibns().at(cur_lbi).size());\n  const std::string& dst_ibn = edge->lbi2ibns().at(cur_lbi).front();\n\n  OperatorConf dst_op_conf = op_conf_cache->GetLatest(dst_node->op().op_conf());\n  ReplaceInputLbnInOpCustomizedConf(&dst_op_conf, dst_ibn, new_lbn);\n  op_conf_cache->Put(dst_op_conf);\n  return Maybe<void>::Ok();\n}\n\nclass QuantAwareTraining final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(QuantAwareTraining);\n  QuantAwareTraining() = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return ctx.job_desc().job_conf().enable_quantization_aware_training();\n  }\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;\n\n private:\n  Maybe<void> InsertFakeQuantOp(const QatConfig& qat_config, const OpGraph& op_graph,\n                                const OpTypeSet& int8_list, const OpTypeSet& transparent_list,\n                                bool insert_quant_op_after_int8_ops,\n                                HashSet<OpNode*> downstream_white, Job* job) const;\n};\n\nMaybe<bool> IsNodeQuantizationEnabled(const OpNode& node) {\n  int64_t scope_symbol_id = node.op().op_conf().scope_symbol_id();\n  CHECK_OR_RETURN(Singleton<symbol::Storage<Scope>>::Get()->Has(scope_symbol_id));  // NOLINT\n  const Scope& scope = Singleton<symbol::Storage<Scope>>::Get()->Get(scope_symbol_id);\n  return scope.Bool(\"quantization_aware_training\");\n}\n\nMaybe<void> QuantAwareTraining::Apply(Job* job, JobPassCtx* ctx) const {\n  if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n  const OpGraph 
op_graph(*job);\n  CHECK_OR_RETURN(GlobalJobDesc().DefaultDataType() == DataType::kFloat);\n\n  const auto qat_config = ctx->job_desc().job_conf().qat_config();\n\n  OpTypeSet int8_list = *JUST(Int8List4QatConfig(qat_config));\n  OpTypeSet transparent_list = *JUST(TransparentList4QatConfig(qat_config));\n  // if `insert_quant_op_after_int8_ops` is false,\n  // always insert quant op before int8 ops.\n  // if `insert_quant_op_after_int8_ops` is true,\n  // always insert quant op after int8 ops\n  bool insert_quant_op_after_int8_ops = JUST(InsertQuantOpAfterInt8Ops4QatConfig(qat_config));\n\n  JUST(VerifyQATList(int8_list));\n  JUST(VerifyQATList(transparent_list));\n\n  std::function<std::string(OpNode* const&)> OpName4Node = [](OpNode* const& node) {\n    return node->op().op_name();\n  };\n  HashSet<OpNode*> white_set;\n  DfsTopoGraphTraversal(\n      op_graph, false, [&int8_list](OpNode* node) { return IsNodeInList(int8_list, node); },\n      [&](OpNode* node) { return IsNodeInList(transparent_list, node); },\n      [&](OpNode* node) { return IsKeyFound(white_set, node); },\n      [&](OpNode* node) {\n        INSERT_CHECK(white_set.insert(node));\n        if (node->op().op_conf().user_conf().op_type_name() == \"conv2d\"\n            && node->out_edges().size() == 1) {\n          OpNode* next_node = node->SoleOutEdge()->dst_node();\n          if (OpTypeName4OpNode(next_node) == \"bias_add\") {\n            INSERT_CHECK(white_set.insert(next_node));\n            // TODO(daquexian): mark these special nodes\n            if (next_node->out_edges().size() == 1) {\n              next_node = next_node->SoleOutEdge()->dst_node();\n            }\n          }\n          if (OpTypeName4OpNode(next_node) == \"normalization\") {\n            INSERT_CHECK(white_set.insert(next_node));\n            if (next_node->out_edges().size() == 1) {\n              next_node = next_node->SoleOutEdge()->dst_node();\n            }\n          }\n          if (OpTypeName4OpNode(next_node) == \"relu\") { INSERT_CHECK(white_set.insert(next_node)); }\n        }\n      });\n\n  VLOG(3) << \"white_set include: \"\n          << Container2Str<HashSet<OpNode*>, OpNode*>(white_set, OpName4Node);\n\n  JUST(InsertFakeQuantOp(ctx->job_desc().job_conf().qat_config(), op_graph, int8_list,\n                         transparent_list, insert_quant_op_after_int8_ops, white_set, job));\n  return Maybe<void>::Ok();\n}\n\n// TODO: remove int8_list, transparent_list and insert_quant_op_after_int8_ops arguments\nMaybe<void> QuantAwareTraining::InsertFakeQuantOp(const QatConfig& qat_config,\n                                                  const OpGraph& op_graph,\n                                                  const OpTypeSet& int8_list,\n                                                  const OpTypeSet& transparent_list,\n                                                  const bool insert_quant_op_after_int8_ops,\n                                                  HashSet<OpNode*> white_set, Job* job) const {\n  JobBuilder job_builder(job);\n  HashSet<OpEdge*> white_set_edges;\n  auto EdgeName4Edge = [](OpEdge* const& edge) {\n    return std::string(\"edge of\\t\") + edge->src_node()->op().op_name() + \"\\tto\\t\"\n           + edge->dst_node()->op().op_name();\n  };\n  auto AddWhiteSetEdge = [&white_set_edges, &EdgeName4Edge](OpEdge* edge) -> Maybe<void> {\n    VLOG(3) << \"insert \" << EdgeName4Edge(edge);\n    CHECK_EQ_OR_RETURN(edge->lbis().size(), 1);\n    const std::string lbn = GenLogicalBlobName(edge->lbis().front());\n    
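// record in scale_map the scale lbn that this edge's observer op (created later in\n    // GetScaleAndZeroPointLbn4Edge) is expected to produce, so that GetScaleLbn can\n    // resolve conv input/weight scales before those ops are actually built\n    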
scale_map[lbn] = ReplaceSlashToDash4Lbn(lbn) + OBSERVER_SUFFIX + \"/scale_0\";\n    VLOG(3) << \"set \" << lbn << \" to \" << scale_map[lbn];\n    INSERT_CHECK_OR_RETURN(white_set_edges.insert(edge));\n    return Maybe<void>::Ok();\n  };\n  auto PropagateScale = [](OpNode* node) -> Maybe<void> {\n    CHECK_EQ_OR_RETURN(node->in_edges().size(), 1);\n    CHECK_EQ_OR_RETURN(node->SoleInEdge()->lbis().size(), 1);\n    for (OpEdge* edge : node->out_edges()) {\n      CHECK_EQ_OR_RETURN(edge->lbis().size(), 1);\n      const std::string node_input_lbn = GenLogicalBlobName(node->SoleInEdge()->lbis().front());\n      const std::string lbn = GenLogicalBlobName(edge->lbis().front());\n      if (scale_map.find(node_input_lbn) != scale_map.end()) {\n        scale_map[lbn] = scale_map[node_input_lbn];\n      }\n    }\n    return Maybe<void>::Ok();\n  };\n\n  {\n    JUST(op_graph.MaybeForEachNode([&](OpNode* node) -> Maybe<void> {\n      if (IsKeyFound(white_set, node)) {\n        for (OpEdge* edge : node->in_edges()) {\n          if (IsKeyFound(white_set, edge->src_node())) { continue; }\n          if (JUST(IsNodeQuantizationEnabled(*edge->dst_node()))) { JUST(AddWhiteSetEdge(edge)); }\n        }\n        if (IsNodeInList(int8_list, node)) {\n          if (insert_quant_op_after_int8_ops) {\n            OpNode* inference_node = JUST(GetInferenceOutputNode(op_graph, node));\n            if (JUST(IsNodeQuantizationEnabled(*inference_node))) {\n              for (OpEdge* edge : inference_node->out_edges()) { JUST(AddWhiteSetEdge(edge)); }\n            }\n          } else {\n            if (JUST(IsNodeQuantizationEnabled(*node))) {\n              for (OpEdge* edge : node->in_edges()) {\n                if (white_set_edges.find(edge) == white_set_edges.end()) {\n                  JUST(AddWhiteSetEdge(edge));\n                }\n              }\n            }\n          }\n        } else if (IsNodeInList(transparent_list, node)) {\n          JUST(PropagateScale(node));\n        } else {\n          // this is bias_add, relu or bn op in \"conv -> bias_add -> bn -> relu\" pattern,\n          // do nothing\n        }\n      }\n      return Maybe<void>::Ok();\n    }));\n    VLOG(3) << \"white_set_edges: \"\n            << Container2Str<HashSet<OpEdge*>, OpEdge*>(white_set_edges, EdgeName4Edge);\n  }\n\n  // group edges by lbn so that we can use `src_node` when calling `AddOps`\n  HashMap<std::string, std::vector<OpEdge*>> edges_group_by_lbn;\n  {\n    for (OpEdge* edge : white_set_edges) {\n      CHECK_EQ_OR_RETURN(1, edge->lbis().size());\n      std::string lbn = GenLogicalBlobName(edge->lbis().front());\n      edges_group_by_lbn[lbn].emplace_back(edge);\n    }\n  }\n\n  OpConfCache op_conf_cache;\n  for (auto& pair : edges_group_by_lbn) {\n    const std::string& lbn = pair.first;\n    const OpNode* src_node = pair.second.front()->src_node();\n\n    const BlobDesc& blob_desc = src_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(lbn));\n    if (blob_desc.data_type() != DataType::kFloat) { continue; }\n\n    OpConfMap inserted_ops;\n    for (OpEdge* edge : pair.second) {\n      if (IsBnInputEdge(edge)) { continue; }\n      std::string scale;\n      std::string zero_point;\n      const int64_t scope_symbol_id = edge->src_node()->op().op_conf().scope_symbol_id();\n      JUST(GetScaleAndZeroPointLbn4Edge(edge, job->job_conf().train_conf().train_step_lbn(), &scale,\n                                        &zero_point, qat_config, scope_symbol_id, &inserted_ops));\n      const std::string fake_quant_op_name = 
ReplaceSlashToDash4Lbn(lbn) + FAKE_QUANT_SUFFIX;\n      const auto fake_quant_op = JUST(FakeQuantOp(fake_quant_op_name, lbn, scale, zero_point,\n                                                  qat_config, scope_symbol_id, &inserted_ops));\n\n      const std::string fake_quant_op_output_name = fake_quant_op->output(\"out\", 0);\n\n      JUST(ReplaceInputLbn4DstNodeOfEdge(edge, fake_quant_op_output_name, &op_conf_cache));\n    }\n\n    for (const auto& pair : inserted_ops) {\n      VLOG(3) << \"Insert op: \" << pair.second.DebugString() << \" between \" << lbn;\n      job_builder.AddOps(src_node->parallel_desc().parallel_conf(), {pair.second});\n    }\n  }\n\n  job_builder.MutOpsOnlyOnce(op_conf_cache.op_confs());\n  return Maybe<void>::Ok();\n}\n\nREGISTER_JOB_PASS(\"QuantAwareTraining\", QuantAwareTraining);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/replace_embedding_ops_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job_rewriter/dynamic_loss_scale_job_pass_state.h\"\n#include \"oneflow/core/job_rewriter/autograd.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/job_rewriter/clip_by_global_norm_job_pass_state.h\"\n#include \"oneflow/core/embedding/embedding_manager.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::string BuildIdentityOp(JobBuilder* job_builder, const std::string& in_lbn,\n                            const ParallelConf& parallel_conf,\n                            const user_op::UserOpConfWrapper& embedding_op) {\n  user_op::UserOpConfWrapperBuilder identity_op_builder(embedding_op.op_name() + \"_identity_\"\n                                                        + NewUniqueId());\n  user_op::UserOpConfWrapper identity_op =\n      identity_op_builder.OpTypeName(\"identity\")\n          .Input(\"in\", in_lbn)\n          .Output(\"out\")\n          .ScopeSymbolId(embedding_op.op_conf().scope_symbol_id())\n          .Build();\n  job_builder->AddOps(parallel_conf, {identity_op.op_conf()});\n  return identity_op.output(\"out\", 0);\n}\n\nMaybe<void> DynamicLossScaleAddGradient(\n    JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,\n    const HashMap<std::string, std::string>& shadow_op_name2grad_lbn, int64_t scope_symbol_id,\n    const ParallelConf& parallel_conf) {\n  if (job_builder->job().job_conf().train_conf().has_dynamic_loss_scale_policy()) {\n    const auto& dynamic_loss_scale_state =\n        JUST(ctx->GetState<DynamicLossScaleJobPassState>(\"dynamic_loss_scale_state\"));\n    const LogicalBlobId count_not_finite_lbi =\n        GenLogicalBlobId(dynamic_loss_scale_state.count_not_finite_lbn());\n    const OpNode* op_node = op_graph.OpNode4OpName(count_not_finite_lbi.op_name());\n    if (op_node->op().op_conf().has_user_conf()\n        && op_node->op().op_conf().user_conf().op_type_name() == \"identity\") {\n      const user_op::UserOpConfWrapper identity_op_conf(op_node->op().op_conf());\n      std::string new_count_not_finite_lbn;\n      if (shadow_op_name2grad_lbn.size() == 1) {\n        const std::string& grad_lbn = shadow_op_name2grad_lbn.begin()->second;\n        const auto count_not_finite_op =\n            user_op::UserOpConfWrapperBuilder(\"OneEmbedding-DynamicLossScale-CountNotFinite-\"\n                                              + NewUniqueId())\n                .Op(\"count_not_finite\")\n                .Input(\"x\", grad_lbn)\n                .Output(\"y\")\n                .ScopeSymbolId(op_node->op().op_conf().scope_symbol_id())\n                .Build();\n        job_builder->AddOps(parallel_conf, {count_not_finite_op.op_conf()});\n        new_count_not_finite_lbn = count_not_finite_op.output(\"y\", 0);\n      } else {\n        auto multi_count_not_finite_op_builder =\n            
user_op::UserOpConfWrapperBuilder(\"OneEmbedding-DynamicLossScale-MultiCountNotFinite-\"\n                                              + NewUniqueId())\n                .Op(\"multi_count_not_finite\")\n                .Output(\"y\")\n                .ScopeSymbolId(op_node->op().op_conf().scope_symbol_id());\n        for (const auto& pair : shadow_op_name2grad_lbn) {\n          multi_count_not_finite_op_builder.Input(\"x\", pair.second);\n        }\n        const auto multi_count_not_finite_op = multi_count_not_finite_op_builder.Build();\n        job_builder->AddOps(parallel_conf, {multi_count_not_finite_op.op_conf()});\n        new_count_not_finite_lbn = multi_count_not_finite_op.output(\"y\", 0);\n      }\n      user_op::UserOpConfWrapperBuilder add_op_builder(\n          \"OneEmbedding-DynamicLossScale-CountNotFinite-Add_\" + NewUniqueId());\n      const auto add_op = add_op_builder.Op(\"add_n\")\n                              .Input(\"in\", identity_op_conf.input(\"in\", 0))\n                              .Input(\"in\", new_count_not_finite_lbn)\n                              .Output(\"out\")\n                              .ScopeSymbolId(op_node->op().op_conf().scope_symbol_id())\n                              .Build();\n\n      job_builder->AddOps(op_node->parallel_desc().parallel_conf(), {add_op.op_conf()});\n\n      OperatorConf new_identity_conf = identity_op_conf.op_conf();\n      const auto& old_val =\n          ReplaceInputLbnInOpCustomizedConf(&new_identity_conf, \"in_0\", add_op.output(\"out\", 0));\n      CHECK_EQ_OR_RETURN(identity_op_conf.input(\"in\", 0), old_val);\n      job_builder->MutOpsOnlyOnce({new_identity_conf});\n    } else {\n      UNIMPLEMENTED_THEN_RETURN();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nvoid BuildEmbeddingLookup(\n    JobPassCtx* ctx, JobBuilder* job_builder, const int64_t embedding_size, const int64_t line_size,\n    const std::string& embedding_name, const int64_t seed, bool has_embedding_prefetch,\n    const ParallelConf& parallel_conf, const user_op::UserOpConfWrapper& embedding_op,\n    const std::string& prefetch_num_unique_ids_lbn, const std::string& prefetch_unique_ids_lbn,\n    const std::string& prefetch_unique_table_ids_lbn, const std::string& num_unique_ids_lbn,\n    const std::string& unique_ids_lbn, const std::string& unique_table_ids_lbn,\n    std::string* embedding_lbn, std::string* unique_values_lbn,\n    OperatorConf* embedding_prefetch_op_conf, OperatorConf* embedding_lookup_op_conf) {\n  std::string context_lbn;\n  if (has_embedding_prefetch) {\n    // embedding prefetch op\n    user_op::UserOpConfWrapperBuilder embedding_prefetch_op_builder(\n        embedding_op.op_name() + \"_embedding_prefetch\" + NewUniqueId());\n    user_op::UserOpConfWrapper embedding_prefetch_op =\n        embedding_prefetch_op_builder.OpTypeName(\"embedding_prefetch\")\n            .Input(\"num_unique_ids\", prefetch_num_unique_ids_lbn)\n            .Input(\"unique_ids\", prefetch_unique_ids_lbn)\n            .Input(\"table_ids\", prefetch_unique_table_ids_lbn)\n            .Output(\"context\")\n            .Attr<int64_t>(\"embedding_size\", embedding_size)\n            .Attr<int64_t>(\"line_size\", line_size)\n            .Attr<std::string>(\"embedding_tables\",\n                               embedding_op.attr<std::string>(\"embedding_tables\"))\n            .Attr<std::string>(\"embedding_name\", embedding_name)\n            .Attr<int64_t>(\"seed\", seed)\n            .ScopeSymbolId(embedding_op.op_conf().scope_symbol_id())\n            
.Build();\n    *embedding_prefetch_op_conf = embedding_prefetch_op.op_conf();\n    if (!ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION\", false)) {\n      embedding_prefetch_op_conf->set_stream_name_hint(embedding_name + \"_EMBEDDING\");\n    }\n    context_lbn = embedding_prefetch_op.output(\"context\", 0);\n  }\n\n  // embedding lookup op\n  user_op::UserOpConfWrapperBuilder embedding_lookup_op_builder(\n      embedding_op.op_name() + \"_embedding_lookup\" + NewUniqueId());\n  embedding_lookup_op_builder.OpTypeName(\"embedding_lookup\")\n      .Input(\"num_unique_ids\", num_unique_ids_lbn)\n      .Input(\"unique_ids\", unique_ids_lbn)\n      .Input(\"table_ids\", unique_table_ids_lbn)\n      .Output(\"unique_values\")\n      .Attr<DataType>(\"dtype\", embedding_op.attr<DataType>(\"dtype\"))\n      .Attr<int64_t>(\"embedding_size\", embedding_size)\n      .Attr<int64_t>(\"line_size\", line_size)\n      .Attr<std::string>(\"embedding_tables\", embedding_op.attr<std::string>(\"embedding_tables\"))\n      .Attr<std::string>(\"embedding_name\", embedding_name)\n      .Attr<int64_t>(\"seed\", seed)\n      .ScopeSymbolId(embedding_op.op_conf().scope_symbol_id());\n  if (has_embedding_prefetch) { embedding_lookup_op_builder.Input(\"context\", context_lbn); }\n  bool has_embeddings_output =\n      (line_size != embedding_size) || ctx->job_desc().enable_auto_mixed_precision();\n  if (has_embeddings_output) {\n    DataType embeddings_dtype = ctx->job_desc().enable_auto_mixed_precision()\n                                    ? DataType::kFloat16\n                                    : embedding_op.attr<DataType>(\"dtype\");\n    embedding_lookup_op_builder.Output(\"embeddings\")\n        .Attr<DataType>(\"embeddings_dtype\", embeddings_dtype);\n  }\n  user_op::UserOpConfWrapper embedding_lookup_op = embedding_lookup_op_builder.Build();\n  *embedding_lookup_op_conf = embedding_lookup_op.op_conf();\n  if (!ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION\", false)) {\n    embedding_lookup_op_conf->set_stream_name_hint(embedding_name + \"_EMBEDDING\");\n  }\n  if (has_embeddings_output) {\n    *embedding_lbn = embedding_lookup_op.output(\"embeddings\", 0);\n  } else {\n    *embedding_lbn = embedding_lookup_op.output(\"unique_values\", 0);\n  }\n  *unique_values_lbn = embedding_lookup_op.output(\"unique_values\", 0);\n}\n\nvoid BuildEmbeddingShuffle(JobBuilder* job_builder, const std::string& embedding_name,\n                           int64_t embedding_size, const ParallelConf& parallel_conf,\n                           const user_op::UserOpConfWrapper& embedding_op,\n                           const std::string& inverse_indices_lbn,\n                           const std::string& inner_inverse_unique_partition_indices_lbn,\n                           const std::string& num_unique_matrix_lbn,\n                           const std::string& embedding_lbn, std::vector<OperatorConf>* add_ops,\n                           std::string* new_embeddings_lbn) {\n  const bool is_train_job = job_builder->job().job_conf().has_train_conf();\n  user_op::UserOpConfWrapperBuilder embedding_shuffle_op_builder(\n      embedding_op.op_name() + \"_embedding_shuffle\" + NewUniqueId());\n  user_op::UserOpConfWrapper embedding_shuffle_op =\n      embedding_shuffle_op_builder.OpTypeName(\"embedding_shuffle\")\n          .Input(\"cur_rank_embeddings\", embedding_lbn)\n          .Input(\"cur_rank_inverse_indices\", inverse_indices_lbn)\n          
.Input(\"inverse_unique_partition_indices\", inner_inverse_unique_partition_indices_lbn)\n          .Input(\"num_unique_matrix\", num_unique_matrix_lbn)\n          .Attr<std::string>(\"embedding_name\", embedding_name)\n          .Attr<int64_t>(\"embedding_size\", embedding_size)\n          .Attr<bool>(\"is_train\", is_train_job)\n          .Output(\"embeddings\")\n          .ScopeSymbolId(embedding_op.op_conf().scope_symbol_id())\n          .Build();\n  OperatorConf embedding_shuffle_new_op_conf = embedding_shuffle_op.op_conf();\n  if (!ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION\", false)\n      && ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_EMBEDDING_SHUFFLE_INDEPENTENT_STREAM\", true)) {\n    embedding_shuffle_new_op_conf.set_stream_name_hint(embedding_name + \"_EMBEDDING\");\n  }\n  add_ops->push_back(embedding_shuffle_new_op_conf);\n  *new_embeddings_lbn = embedding_shuffle_op.output(\"embeddings\", 0);\n}\n\nvoid BuildEmbeddingGradientShuffle(\n    JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder, const OpNode* op_node,\n    const std::string& embedding_name, int64_t embedding_size, const bool use_system_gather,\n    const ParallelConf& embedding_parallel_conf, const int64_t embedding_scope_symbol_id,\n    const user_op::UserOpConfWrapper& embedding_op, const std::string& inverse_indices_lbn,\n    const std::string& inner_inverse_unique_partition_indices_lbn,\n    const std::string& num_unique_matrix_lbn, const std::string& update_embedding_grad,\n    const bool has_clip_grad, std::string* cur_rank_unique_embedding_grad_lbn) {\n  std::string update_embedding_grad_lbn = update_embedding_grad;\n  if (ctx->job_desc().enable_auto_mixed_precision()\n      && !ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_GRADIENT_SHUFFLE_USE_FP16\", true)) {\n    auto cast_op =\n        user_op::UserOpConfWrapperBuilder(embedding_op.op_name() + \"_before_grad_shuffle_cast_h2f\")\n            .Op(\"cast\")\n            .Input(\"in\", update_embedding_grad_lbn)\n            .Output(\"out\")\n            .Attr<DataType>(\"dtype\", DataType::kFloat)\n            .ScopeSymbolId(embedding_scope_symbol_id)\n            .Build();\n    job_builder->AddOps(embedding_parallel_conf, {cast_op.op_conf()});\n    update_embedding_grad_lbn = cast_op.output(\"out\", 0);\n  }\n  if (use_system_gather) {\n    const int64_t num_segments =\n        op_node->LogicalBlobDesc4Lbi(op_node->op().BnInOp2Lbi(\"ids_0\")).shape().elem_cnt();\n    user_op::UserOpConfWrapperBuilder unsorted_segment_sum_op_builder(embedding_op.op_name()\n                                                                      + \"_unsorted_segment_sum\");\n    user_op::UserOpConfWrapper unsorted_segment_sum_op =\n        unsorted_segment_sum_op_builder.OpTypeName(\"unsorted_segment_sum\")\n            .Input(\"data\", update_embedding_grad_lbn)\n            .Input(\"segment_ids\", inverse_indices_lbn)\n            .Output(\"out\")\n            .Attr<int64_t>(\"num_segments\", num_segments)\n            .ScopeSymbolId(embedding_scope_symbol_id)\n            .Build();\n    job_builder->AddOps(embedding_parallel_conf, {unsorted_segment_sum_op.op_conf()});\n    *cur_rank_unique_embedding_grad_lbn = unsorted_segment_sum_op.output(\"out\", 0);\n  } else {\n    // embedding_gradient_shuffle op\n    // if no dynamic loss scale or no clip_grad, we think gradient shuffle grad's invalid buffer\n    // need not to be memset.\n    const bool has_dynamic_loss_scale =\n        
job_builder->job().job_conf().train_conf().has_dynamic_loss_scale_policy();\n    const bool only_zero_valid_grad = (!has_clip_grad) && (!has_dynamic_loss_scale);\n    user_op::UserOpConfWrapperBuilder embedding_gradient_shuffle_op_builder(\n        embedding_op.op_name() + \"_embedding_gradient_shuffle\" + NewUniqueId());\n    user_op::UserOpConfWrapper embedding_gradient_shuffle_op =\n        embedding_gradient_shuffle_op_builder.OpTypeName(\"embedding_gradient_shuffle\")\n            .Input(\"cur_rank_inverse_indices\", inverse_indices_lbn)\n            .Input(\"inverse_unique_partition_indices\", inner_inverse_unique_partition_indices_lbn)\n            .Input(\"embedding_grad\", update_embedding_grad_lbn)\n            .Input(\"num_unique_matrix\", num_unique_matrix_lbn)\n            .Output(\"cur_rank_unique_embedding_grad\")\n            .Attr<std::string>(\"embedding_name\", embedding_name)\n            .Attr<int64_t>(\"embedding_size\", embedding_size)\n            .Attr<bool>(\"only_zero_valid_grad\", only_zero_valid_grad)\n            .ScopeSymbolId(embedding_scope_symbol_id)\n            .Build();\n    OperatorConf embedding_gradient_shuffle_new_op_conf = embedding_gradient_shuffle_op.op_conf();\n    if (!ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION\", false)\n        && ParseBooleanFromEnv(\n            \"ONEFLOW_ONE_EMBEDDING_EMBEDDING_GRADIENT_SHUFFLE_INDEPENTENT_STREAM\", true)) {\n      embedding_gradient_shuffle_new_op_conf.set_stream_name_hint(embedding_name + \"_EMBEDDING\");\n    }\n    job_builder->AddOps(embedding_parallel_conf, {embedding_gradient_shuffle_new_op_conf});\n    *cur_rank_unique_embedding_grad_lbn =\n        embedding_gradient_shuffle_op.output(\"cur_rank_unique_embedding_grad\", 0);\n  }\n  if (ctx->job_desc().enable_auto_mixed_precision()\n      && ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_GRADIENT_SHUFFLE_USE_FP16\", true)\n      && (ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_NOT_FUSE_CAST_TO_UPDATE\", false)\n          || has_clip_grad)) {\n    auto cast_op = user_op::UserOpConfWrapperBuilder(embedding_op.op_name() + \"_cast_h2f\")\n                       .Op(\"cast\")\n                       .Input(\"in\", *cur_rank_unique_embedding_grad_lbn)\n                       .Output(\"out\")\n                       .Attr<DataType>(\"dtype\", DataType::kFloat)\n                       .ScopeSymbolId(embedding_scope_symbol_id)\n                       .Build();\n    *cur_rank_unique_embedding_grad_lbn = cast_op.output(\"out\", 0);\n    job_builder->AddOps(embedding_parallel_conf, {cast_op.op_conf()});\n  }\n}\n\ndouble GetLossInstanceNumScaleFactor(const OpGraph& op_graph, JobBuilder* job_builder) {\n  double scale_factor = 1;\n  std::function<OpNode*(const std::string&)> LossOpNode4OpName;\n  CHECK_JUST(MakeGetterLossOpNode4OpName(op_graph, &LossOpNode4OpName));\n  const TrainConf& train_conf = job_builder->job().job_conf().train_conf();\n  HashMap<LogicalBlobId, OpNode*> loss_lbi2op_node;\n  CHECK_GT(train_conf.loss_lbn().size(), 0);\n  for (const auto& loss_lbn : train_conf.loss_lbn()) {\n    const auto& lbi = GenLogicalBlobId(loss_lbn);\n    CHECK(loss_lbi2op_node.emplace(lbi, LossOpNode4OpName(lbi.op_name())).second);\n  }\n  const Shape src_time_shape({1, 1});\n  const int64_t source_time_shape_elem_cnt = src_time_shape.elem_cnt();\n  bool all_loss_time_shape_eq_src = true;\n  for (const auto& pair : loss_lbi2op_node) {\n    const int64_t time_shape_elem_cnt = CHECK_JUST(pair.second->op().GetOpTimeShape())->elem_cnt();\n    if 
(time_shape_elem_cnt != source_time_shape_elem_cnt) {\n      CHECK_EQ(time_shape_elem_cnt % source_time_shape_elem_cnt, 0);\n      all_loss_time_shape_eq_src = false;\n    }\n  }\n  if (all_loss_time_shape_eq_src) {\n    const BlobDesc* blob_desc = nullptr;\n    for (const auto& pair : loss_lbi2op_node) {\n      const BlobDesc* cur_blob_desc = &pair.second->LogicalBlobDesc4Lbi(pair.first);\n      if (blob_desc != nullptr) { CHECK(*blob_desc == *cur_blob_desc); }\n      blob_desc = cur_blob_desc;\n    }\n    CHECK(blob_desc != nullptr);\n    scale_factor = 1.0f / static_cast<float>(blob_desc->shape().elem_cnt());\n  } else {\n    std::unique_ptr<BlobDesc> blob_desc;\n    for (const auto& pair : loss_lbi2op_node) {\n      const BlobDesc* cur_blob_desc = &pair.second->LogicalBlobDesc4Lbi(pair.first);\n      // TODO: support dynamic\n      CHECK(!cur_blob_desc->is_dynamic());\n      const DataType loss_data_type = cur_blob_desc->data_type();\n      const int64_t time_shape_elem_cnt =\n          CHECK_JUST(pair.second->op().GetOpTimeShape())->elem_cnt();\n      // TODO: consider sbp\n      const int64_t loss_elem_cnt =\n          cur_blob_desc->shape().elem_cnt() * time_shape_elem_cnt / source_time_shape_elem_cnt;\n      if (blob_desc) {\n        CHECK_EQ(blob_desc->data_type(), loss_data_type);\n        CHECK_EQ(blob_desc->shape().elem_cnt(), loss_elem_cnt);\n      } else {\n        blob_desc.reset(\n            new BlobDesc(Shape({loss_elem_cnt}), loss_data_type, cur_blob_desc->memory_format()));\n      }\n    }\n    scale_factor = 1.0f / static_cast<float>(blob_desc->shape().elem_cnt());\n  }\n  return scale_factor;\n}\n\nvoid BuildIdShuffle(bool use_system_gather, const std::string& embedding_name,\n                    const user_op::UserOpConfWrapper& embedding_op,\n                    std::vector<OperatorConf>* add_ops, std::string* prefetch_num_unique_lbn,\n                    std::string* prefetch_unique_ids_lbn,\n                    std::string* prefetch_unique_table_ids_lbn,\n                    std::string* inner_inverse_unique_partition_indices_lbn,\n                    std::string* num_unique_ids_lbn, std::string* unique_ids_lbn,\n                    std::string* unique_table_ids_lbn, std::string* inverse_indices_lbn,\n                    std::string* num_unique_matrix_lbn) {\n  const int32_t num_tables = embedding_op.attr<int32_t>(\"num_tables\");\n  const int64_t padding_idx = embedding_op.attr<int64_t>(\"padding_idx\");\n  const bool has_padding_idx = embedding_op.attr<bool>(\"has_padding_idx\");\n  bool enable_pipelined_execution =\n      !ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION\", false);\n  if (use_system_gather) {\n    user_op::UserOpConfWrapperBuilder unique_op_builder(embedding_op.op_name()\n                                                        + \"_unique_ids_and_tables\");\n    unique_op_builder.OpTypeName(\"unique_key_value_pair\")\n        .Input(\"keys\", embedding_op.input(\"ids\", 0))\n        .Output(\"num_unique\")\n        .Output(\"unique_keys\")\n        .Output(\"unique_values\")\n        .Output(\"inverse_indices\")\n        .Attr<int32_t>(\"num_tables\", num_tables)\n        .Attr<int64_t>(\"padding_idx\", padding_idx)\n        .Attr<bool>(\"has_padding_idx\", has_padding_idx)\n        .Attr<std::string>(\"embedding_name\", embedding_name)\n        .ScopeSymbolId(embedding_op.op_conf().scope_symbol_id());\n    if (embedding_op.has_input(\"table_ids\", 0)) {\n      unique_op_builder.Input(\"values\", 
embedding_op.input(\"table_ids\", 0));\n    }\n    user_op::UserOpConfWrapper unique_op = unique_op_builder.Build();\n    OperatorConf unique_new_op_conf = unique_op.op_conf();\n    if (enable_pipelined_execution) {\n      unique_new_op_conf.set_stream_name_hint(embedding_name + \"_ID_SHUFFLE\");\n    }\n    add_ops->push_back(unique_new_op_conf);\n    *num_unique_ids_lbn = unique_op.output(\"num_unique\", 0);\n    *unique_ids_lbn = unique_op.output(\"unique_keys\", 0);\n    *unique_table_ids_lbn = unique_op.output(\"unique_values\", 0);\n    *inverse_indices_lbn = unique_op.output(\"inverse_indices\", 0);\n    *prefetch_num_unique_lbn = *num_unique_ids_lbn;\n    *prefetch_unique_ids_lbn = *unique_ids_lbn;\n    *prefetch_unique_table_ids_lbn = *unique_table_ids_lbn;\n  } else {\n    user_op::UserOpConfWrapperBuilder id_shuffle_op_builder(embedding_op.op_name() + \"_id_shuffle\"\n                                                            + NewUniqueId());\n    id_shuffle_op_builder.OpTypeName(\"id_shuffle\")\n        .Input(\"ids\", embedding_op.input(\"ids\", 0))\n        .Output(\"inverse_unique_partition_indices\")\n        .Output(\"cur_rank_num_unique\")\n        .Output(\"cur_rank_unique_ids\")\n        .Output(\"cur_rank_unique_table_ids\")\n        .Output(\"cur_rank_inverse_indices\")\n        .Output(\"num_unique_matrix\")\n        .Attr<int32_t>(\"num_tables\", num_tables)\n        .Attr<int64_t>(\"padding_idx\", padding_idx)\n        .Attr<bool>(\"has_padding_idx\", has_padding_idx)\n        .Attr<std::string>(\"embedding_name\", embedding_name)\n        .ScopeSymbolId(embedding_op.op_conf().scope_symbol_id());\n    if (embedding_op.has_input(\"table_ids\", 0)) {\n      id_shuffle_op_builder.Input(\"table_ids\", embedding_op.input(\"table_ids\", 0));\n    }\n    user_op::UserOpConfWrapper id_shuffle_op = id_shuffle_op_builder.Build();\n    OperatorConf id_shuffle_new_op_conf = id_shuffle_op.op_conf();\n    if (enable_pipelined_execution) {\n      id_shuffle_new_op_conf.set_stream_name_hint(embedding_name + \"_ID_SHUFFLE\");\n    }\n    add_ops->push_back(id_shuffle_new_op_conf);\n    if (ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_ADD_ID_SHUFFLE_COPY_OUT\", true)) {\n      // add id_shuffle_copy_out, so the consumer can use light_actor and cuda_graph.\n      user_op::UserOpConfWrapperBuilder identity_op_builder(\n          embedding_op.op_name() + \"_id_shuffle_copy_out_\" + NewUniqueId());\n      user_op::UserOpConfWrapper identity_op =\n          identity_op_builder.OpTypeName(\"id_shuffle_copy_out\")\n              .Attr<std::string>(\"embedding_name\", embedding_name)\n              .Input(\"inverse_unique_partition_indices\",\n                     id_shuffle_op.output(\"inverse_unique_partition_indices\", 0))\n              .Input(\"cur_rank_num_unique\", id_shuffle_op.output(\"cur_rank_num_unique\", 0))\n              .Input(\"cur_rank_unique_ids\", id_shuffle_op.output(\"cur_rank_unique_ids\", 0))\n              .Input(\"cur_rank_unique_table_ids\",\n                     id_shuffle_op.output(\"cur_rank_unique_table_ids\", 0))\n              .Input(\"cur_rank_inverse_indices\",\n                     id_shuffle_op.output(\"cur_rank_inverse_indices\", 0))\n              .Input(\"num_unique_matrix\", id_shuffle_op.output(\"num_unique_matrix\", 0))\n              .Output(\"out_inverse_unique_partition_indices\")\n              .Output(\"out_cur_rank_num_unique\")\n              .Output(\"out_cur_rank_unique_ids\")\n              .Output(\"out_cur_rank_unique_table_ids\")\n    
          .Output(\"out_cur_rank_inverse_indices\")\n              .Output(\"out_num_unique_matrix\")\n              .ScopeSymbolId(embedding_op.op_conf().scope_symbol_id())\n              .Build();\n      OperatorConf identity_op_conf = identity_op.op_conf();\n      if (enable_pipelined_execution) {\n        identity_op_conf.set_stream_name_hint(embedding_name + \"_EMBEDDING\");\n      }\n      add_ops->push_back(identity_op_conf);\n      *inner_inverse_unique_partition_indices_lbn =\n          identity_op.output(\"out_inverse_unique_partition_indices\", 0);\n      *num_unique_ids_lbn = identity_op.output(\"out_cur_rank_num_unique\", 0);\n      *unique_ids_lbn = identity_op.output(\"out_cur_rank_unique_ids\", 0);\n      *unique_table_ids_lbn = identity_op.output(\"out_cur_rank_unique_table_ids\", 0);\n      *inverse_indices_lbn = identity_op.output(\"out_cur_rank_inverse_indices\", 0);\n      *num_unique_matrix_lbn = identity_op.output(\"out_num_unique_matrix\", 0);\n    } else {\n      *inner_inverse_unique_partition_indices_lbn =\n          id_shuffle_op.output(\"inverse_unique_partition_indices\", 0);\n      *num_unique_ids_lbn = id_shuffle_op.output(\"cur_rank_num_unique\", 0);\n      *unique_ids_lbn = id_shuffle_op.output(\"cur_rank_unique_ids\", 0);\n      *unique_table_ids_lbn = id_shuffle_op.output(\"cur_rank_unique_table_ids\", 0);\n      *inverse_indices_lbn = id_shuffle_op.output(\"cur_rank_inverse_indices\", 0);\n      *num_unique_matrix_lbn = id_shuffle_op.output(\"num_unique_matrix\", 0);\n    }\n    *prefetch_num_unique_lbn = id_shuffle_op.output(\"cur_rank_num_unique\", 0);\n    *prefetch_unique_ids_lbn = id_shuffle_op.output(\"cur_rank_unique_ids\", 0);\n    *prefetch_unique_table_ids_lbn = id_shuffle_op.output(\"cur_rank_unique_table_ids\", 0);\n  }\n}\n\nvoid MakeConstantInitializerAttr(const int64_t embedding_size, const int64_t line_size,\n                                 const std::vector<float>& values, std::string* initializer_attr) {\n  if (embedding_size == line_size) { return; }\n  const int32_t num_states = line_size / embedding_size - 1;\n  CHECK_GT(num_states, 0) << \"num_states \" << num_states;\n  CHECK(values.size() == 0 || num_states == values.size())\n      << \"must set \" << num_states << \" optimizer states init value, but get \" << values.size();\n  nlohmann::json initializers;\n  for (int32_t i = 0; i < num_states; ++i) {\n    nlohmann::json initializer;\n    initializer[\"type\"] = \"constant\";\n    const float initial_value = values.size() > 0 ? 
values.at(i) : 0.0;\n    initializer[\"value\"] = initial_value;\n    initializers.push_back(initializer);\n  }\n  *initializer_attr = initializers.dump();\n}\n\nvoid ScaleGrad(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,\n               const ParallelConf& embedding_parallel_conf, const int64_t embedding_scope_symbol_id,\n               const bool has_clip_grad, const std::string& embedding_grad_lbn,\n               std::string* new_embedding_grad_lbn, std::string* update_skip_if_lbn,\n               std::string* fuse_to_update_down_scale_by_lbn, double* fuse_to_update_scale) {\n  *new_embedding_grad_lbn = embedding_grad_lbn;\n  const TrainConf& train_conf = job_builder->job().job_conf().train_conf();\n  double scale = GetLossInstanceNumScaleFactor(op_graph, job_builder);\n  if (train_conf.has_dynamic_loss_scale_policy()) {\n    const auto& dynamic_loss_scale_state =\n        CHECK_JUST(ctx->GetState<DynamicLossScaleJobPassState>(\"dynamic_loss_scale_state\"));\n    const std::string& loss_scale_val_lbn = dynamic_loss_scale_state.loss_scale_val_lbn();\n    *update_skip_if_lbn = dynamic_loss_scale_state.count_not_finite_lbn();\n    if (has_clip_grad) {\n      const LogicalBlobId loss_scale_val_lbi = GenLogicalBlobId(loss_scale_val_lbn);\n      const OpNode* loss_scale_node = op_graph.OpNode4OpName(loss_scale_val_lbi.op_name());\n      auto inv_scale_op = user_op::UserOpConfWrapperBuilder(\n                              \"OneEmbedding-DynamicLossScale-Reciprocal-\" + NewUniqueId())\n                              .Op(\"reciprocal\")\n                              .Input(\"x\", loss_scale_val_lbn)\n                              .Output(\"y\")\n                              .ScopeSymbolId(loss_scale_node->op().op_conf().scope_symbol_id())\n                              .Build();\n      job_builder->AddOps(loss_scale_node->parallel_desc().parallel_conf(),\n                          {inv_scale_op.op_conf()});\n\n      auto scalar_mul_op = user_op::UserOpConfWrapperBuilder(\n                               \"OneEmbedding-ModelDiffScale-ScalarMul-\" + NewUniqueId())\n                               .Op(\"scalar_mul_by_tensor\")\n                               .Input(\"x\", *new_embedding_grad_lbn)\n                               .Input(\"scalar\", inv_scale_op.output(\"y\", 0))\n                               .Output(\"y\")\n                               .ScopeSymbolId(embedding_scope_symbol_id)\n                               .Build();\n      job_builder->AddOps(embedding_parallel_conf, {scalar_mul_op.op_conf()});\n      *new_embedding_grad_lbn = scalar_mul_op.output(\"y\", 0);\n    } else {\n      *fuse_to_update_down_scale_by_lbn = loss_scale_val_lbn;\n    }\n  } else if (train_conf.has_loss_scale_factor()) {\n    double down_scale_factor = 1.0f / train_conf.loss_scale_factor();\n    scale *= down_scale_factor;\n  }\n  if (has_clip_grad) {\n    auto scalar_mul_op =\n        user_op::UserOpConfWrapperBuilder(\"OneEmbedding-ModelDiffScale-ScalarMul-\" + NewUniqueId())\n            .Op(\"scalar_mul\")\n            .Input(\"in\", *new_embedding_grad_lbn)\n            .Output(\"out\")\n            .Attr<bool>(\"has_float_operand\", true)\n            .Attr<double>(\"float_operand\", scale)\n            .Attr<bool>(\"has_int_operand\", false)\n            .Attr<int64_t>(\"int_operand\", 0)\n            .ScopeSymbolId(embedding_scope_symbol_id)\n            .Build();\n    job_builder->AddOps(embedding_parallel_conf, {scalar_mul_op.op_conf()});\n    *new_embedding_grad_lbn = 
scalar_mul_op.output(\"out\", 0);\n    *fuse_to_update_scale = 1.0;\n  } else {\n    *fuse_to_update_scale = scale;\n  }\n}\n\nbool IsSupportFusedUpdatePut(const bool is_full_cache, const bool enable_auto_mixed_precision,\n                             const bool is_sgd, const std::string& down_scale_by_lbn,\n                             const std::string& skip_if_lbn, const float l1, const float l2,\n                             const float weight_decay) {\n  if (!ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_FUSE_UPDATE_PUT\", true)) { return false; }\n  if (!is_full_cache) { return false; }\n  if (!enable_auto_mixed_precision) { return false; }\n  if (!is_sgd) { return false; }\n  if (!ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_GRADIENT_SHUFFLE_USE_FP16\", true)) {\n    return false;\n  }\n  if (!down_scale_by_lbn.empty()) { return false; }\n  if (!skip_if_lbn.empty()) { return false; }\n  if (l1 != 0) { return false; }\n  if (l2 != 0) { return false; }\n  if (weight_decay != 0) { return false; }\n  return true;\n}\n\nvoid BuildEmbeddingUpdate(\n    JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,\n    const ParallelConf& embedding_parallel_conf, const int64_t embedding_scope_symbol_id,\n    const bool is_full_cache, const int64_t embedding_size, const int64_t line_size, const float l1,\n    const float l2, const std::string& embedding_name, const OptimizerConf& optimizer_conf,\n    const user_op::UserOpConfWrapper& embedding_op, const std::string& num_unique_ids_lbn,\n    const std::string& unique_ids_lbn, const std::string& unique_values_lbn,\n    const std::string& embedding_grad_lbn, const std::string& learning_rate_lbn,\n    std::string* new_embedding_grad_lbn, std::string* state_initializer,\n    OperatorConf* embedding_update_new_op_conf) {\n  const TrainConf& train_conf = job_builder->job().job_conf().train_conf();\n  const bool has_clip_grad = optimizer_conf.has_clip_conf();\n  *new_embedding_grad_lbn = embedding_grad_lbn;\n  std::string update_skip_if_lbn;\n  std::string fuse_to_update_down_scale_by_lbn;\n  double fuse_to_update_scale = 1.0;\n  ScaleGrad(ctx, op_graph, job_builder, embedding_parallel_conf, embedding_scope_symbol_id,\n            has_clip_grad, embedding_grad_lbn, new_embedding_grad_lbn, &update_skip_if_lbn,\n            &fuse_to_update_down_scale_by_lbn, &fuse_to_update_scale);\n\n  if (IsSupportFusedUpdatePut(is_full_cache, ctx->job_desc().enable_auto_mixed_precision(),\n                              optimizer_conf.has_naive_conf(), fuse_to_update_down_scale_by_lbn,\n                              update_skip_if_lbn, l1, l2,\n                              optimizer_conf.weight_decay_conf().weight_decay_rate())) {\n    user_op::UserOpConfWrapperBuilder fused_embedding_update_put_op_builder(\n        embedding_op.op_name() + \"_fused_embedding_update_put\" + NewUniqueId());\n    user_op::UserOpConfWrapper fused_embedding_update_put_op =\n        fused_embedding_update_put_op_builder.OpTypeName(\"one_embedding_fused_sgd_update_put\")\n            .Input(\"num_unique_ids\", num_unique_ids_lbn)\n            .Input(\"unique_ids\", unique_ids_lbn)\n            .Input(\"unique_embeddings\", unique_values_lbn)\n            .Input(\"embedding_grad\", *new_embedding_grad_lbn)\n            .Input(\"learning_rate\", learning_rate_lbn)\n            .Attr<double>(\"scale\", fuse_to_update_scale)\n            .Attr<std::string>(\"embedding_name\", embedding_name)\n            .Attr<int64_t>(\"embedding_size\", embedding_size)\n            
.Attr<int64_t>(\"line_size\", line_size)\n            .ScopeSymbolId(embedding_scope_symbol_id)\n            .Build();\n    *embedding_update_new_op_conf = fused_embedding_update_put_op.op_conf();\n    if (!ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION\", false)) {\n      embedding_update_new_op_conf->set_stream_name_hint(embedding_name + \"_EMBEDDING\");\n    }\n    return;\n  }\n\n  auto AddAdamBiasCorrectionFactorOp = [&](float beta_val,\n                                           const std::string& op_name) -> std::string {\n    user_op::UserOpConfWrapperBuilder op_builder(embedding_op.op_name() + op_name);\n    const auto adam_bias_correction_factor_op =\n        op_builder.OpTypeName(\"adam_bias_correction_factor\")\n            .Input(\"train_step\", train_conf.train_step_lbn())\n            .Attr<float>(\"beta\", beta_val)\n            .Output(\"out\")\n            .ScopeSymbolId(embedding_scope_symbol_id)\n            .Build();\n    job_builder->AddOps(embedding_parallel_conf, {adam_bias_correction_factor_op.op_conf()});\n    return adam_bias_correction_factor_op.output(\"out\", 0);\n  };\n  user_op::UserOpConfWrapperBuilder embedding_update_op_builder(\n      embedding_op.op_name() + \"_embedding_update\" + NewUniqueId());\n  std::vector<float> state_constant_init_values;\n  if (optimizer_conf.has_naive_conf()) {\n    embedding_update_op_builder.OpTypeName(\"one_embedding_sgd_update\");\n  } else if (optimizer_conf.has_momentum_conf()) {\n    embedding_update_op_builder.OpTypeName(\"one_embedding_momentum_update\")\n        .Attr<float>(\"beta\", optimizer_conf.momentum_conf().beta());\n  } else if (optimizer_conf.has_adam_conf()) {\n    const AdamModelUpdateConf& adam_conf = optimizer_conf.adam_conf();\n    if (adam_conf.smart_decay()) {\n      CHECK(adam_conf.do_bias_correction())\n          << \"when use smart decay adam, do_bias_correction should be true. 
but got \"\n          << adam_conf.do_bias_correction();\n      embedding_update_op_builder.OpTypeName(\"one_embedding_smart_decay_sparse_adam_update\")\n          .Input(\"train_step\", train_conf.train_step_lbn())\n          .Attr<float>(\"beta1\", adam_conf.beta1())\n          .Attr<float>(\"beta2\", adam_conf.beta2())\n          .Attr<float>(\"epsilon\", adam_conf.epsilon())\n          .Attr<bool>(\"do_bias_correction\", adam_conf.do_bias_correction());\n    } else {\n      embedding_update_op_builder.OpTypeName(\"one_embedding_adam_update\")\n          .Attr<float>(\"beta1\", adam_conf.beta1())\n          .Attr<float>(\"beta2\", adam_conf.beta2())\n          .Attr<float>(\"epsilon\", adam_conf.epsilon())\n          .Attr<bool>(\"do_bias_correction\", adam_conf.do_bias_correction());\n      if (adam_conf.do_bias_correction()) {\n        const std::string bias_correction1_lbn =\n            AddAdamBiasCorrectionFactorOp(adam_conf.beta1(), \"adam_bias_correction_factor1\");\n        const std::string bias_correction2_lbn =\n            AddAdamBiasCorrectionFactorOp(adam_conf.beta2(), \"adam_bias_correction_factor2\");\n        embedding_update_op_builder.Input(\"bias_correction1\", bias_correction1_lbn)\n            .Input(\"bias_correction2\", bias_correction2_lbn);\n      }\n    }\n  } else if (optimizer_conf.has_adagrad_conf()) {\n    const AdagradModelUpdateConf& adagrad_conf = optimizer_conf.adagrad_conf();\n    state_constant_init_values.push_back(adagrad_conf.initial_accumulator_value());\n    embedding_update_op_builder.OpTypeName(\"one_embedding_adagrad_update\")\n        .Input(\"train_step\", train_conf.train_step_lbn())\n        .Attr<float>(\"lr_decay\", adagrad_conf.lr_decay())\n        .Attr<float>(\"epsilon\", adagrad_conf.epsilon());\n  } else if (optimizer_conf.has_ftrl_conf()) {\n    const FtrlModelUpdateConf& ftrl_conf = optimizer_conf.ftrl_conf();\n    state_constant_init_values.push_back(ftrl_conf.initial_accumulator_value());\n    // For `z`, its init value is 0.0.\n    state_constant_init_values.push_back(0.0);\n    embedding_update_op_builder.OpTypeName(\"one_embedding_ftrl_update\")\n        .Attr<float>(\"lr_power\", ftrl_conf.lr_power())\n        .Attr<float>(\"lambda1\", ftrl_conf.lambda1())\n        .Attr<float>(\"lambda2\", ftrl_conf.lambda2())\n        .Attr<float>(\"beta\", ftrl_conf.beta());\n  } else {\n    UNIMPLEMENTED();\n  }\n  MakeConstantInitializerAttr(embedding_size, line_size, state_constant_init_values,\n                              state_initializer);\n\n  embedding_update_op_builder.Input(\"num_unique_ids\", num_unique_ids_lbn)\n      .Input(\"unique_embeddings\", unique_values_lbn)\n      .Input(\"learning_rate\", learning_rate_lbn)\n      .Attr<float>(\"weight_decay\", optimizer_conf.weight_decay_conf().weight_decay_rate())\n      .Attr<float>(\"l1\", l1)\n      .Attr<float>(\"l2\", l2)\n      .Output(\"updated_unique_embeddings\");\n  if (!update_skip_if_lbn.empty()) {\n    embedding_update_op_builder.Input(\"skip_if\", update_skip_if_lbn);\n  }\n  if (!fuse_to_update_down_scale_by_lbn.empty()) {\n    CHECK(!has_clip_grad);\n    embedding_update_op_builder.Input(\"down_scale_by_tensor\", fuse_to_update_down_scale_by_lbn);\n  }\n  user_op::UserOpConfWrapper embedding_update_op =\n      embedding_update_op_builder.Input(\"embedding_grad\", *new_embedding_grad_lbn)\n          .Attr<double>(\"scale\", fuse_to_update_scale)\n          .Attr<std::string>(\"embedding_name\", embedding_name)\n          .Attr<int64_t>(\"embedding_size\", 
embedding_size)\n          .Attr<int64_t>(\"line_size\", line_size)\n          .ScopeSymbolId(embedding_scope_symbol_id)\n          .Build();\n  *embedding_update_new_op_conf = embedding_update_op.op_conf();\n  if (!ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION\", false)) {\n    embedding_update_new_op_conf->set_stream_name_hint(embedding_name + \"_EMBEDDING\");\n  }\n\n  user_op::UserOpConfWrapperBuilder embedding_put_op_builder(embedding_op.op_name()\n                                                             + \"_embedding_put\" + NewUniqueId());\n  user_op::UserOpConfWrapper embedding_put_op =\n      embedding_put_op_builder.OpTypeName(\"embedding_put\")\n          .Input(\"num_unique_ids\", num_unique_ids_lbn)\n          .Input(\"unique_ids\", unique_ids_lbn)\n          .Input(\"unique_embeddings\", embedding_update_op.output(\"updated_unique_embeddings\", 0))\n          .Attr<std::string>(\"embedding_name\", embedding_name)\n          .ScopeSymbolId(embedding_scope_symbol_id)\n          .Build();\n  OperatorConf embedding_put_new_op_conf = embedding_put_op.op_conf();\n  if (!ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION\", false)) {\n    embedding_put_new_op_conf.set_stream_name_hint(embedding_name + \"_EMBEDDING\");\n  }\n  job_builder->AddOps(embedding_parallel_conf, {embedding_put_new_op_conf});\n}\n\nvoid UpdateConsumerOpConf(const OpNode* consumer, const LogicalBlobId& out,\n                          const std::string& new_out_lbn,\n                          HashMap<std::string, OperatorConf>* op_name2op_conf) {\n  const std::string& consumer_op_name = consumer->op().op_name();\n  if (op_name2op_conf->find(consumer_op_name) == op_name2op_conf->end()) {\n    (*op_name2op_conf)[consumer_op_name] = consumer->op().op_conf();\n  }\n  for (const std::string& ibn : consumer->op().input_bns()) {\n    if (consumer->op().BnInOp2Lbi(ibn) == out) {\n      OperatorConf& consumer_op_conf = op_name2op_conf->at(consumer_op_name);\n      const auto& new_val = new_out_lbn;\n      const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, new_val);\n      CHECK_EQ(GenLogicalBlobName(out), old_val);\n    }\n  }\n}\n\nstd::string GlobalAbsMaxMin(JobBuilder* job_builder,\n                            const HashMap<std::string, std::string>& shadow_op_name2grad_lbn,\n                            float p, const std::string& total_norm_lbn, bool max_or_min,\n                            const ParallelConf& embedding_parallel_conf,\n                            const int64_t embedding_scope_symbol_id,\n                            const ParallelConf& parallel_conf, const int64_t scope_symbol_id) {\n  bool has_split = true;\n  std::string multi_reduce_op_type_name =\n      has_split ? (max_or_min ? \"local_multi_reduce_max_abs\" : \"local_multi_reduce_min_abs\")\n                : (max_or_min ? 
\"multi_reduce_max_abs\" : \"multi_reduce_min_abs\");\n  std::string multi_reduce_op_name =\n      \"OneEmbedding-ClipGradient-GlobalNorm-MultiReduceXimumAbs-\" + NewUniqueId();\n  auto multi_reduce_op_builder = user_op::UserOpConfWrapperBuilder(multi_reduce_op_name)\n                                     .Op(multi_reduce_op_type_name)\n                                     .Output(\"y\")\n                                     .ScopeSymbolId(embedding_scope_symbol_id);\n  for (const auto& pair : shadow_op_name2grad_lbn) {\n    const std::string& grad_lbn = pair.second;\n    multi_reduce_op_builder.Input(\"x\", grad_lbn);\n  }\n  auto multi_reduce_op = multi_reduce_op_builder.Build();\n  job_builder->AddOps(embedding_parallel_conf, {multi_reduce_op.op_conf()});\n  std::string embedding_reduce_lbn = multi_reduce_op.output(\"y\", 0);\n  if (has_split) {\n    std::string group_reduce_op_type_name = max_or_min ? \"reduce_max\" : \"reduce_min\";\n    std::string group_reduce_op_name =\n        \"OneEmbedding-ClipGradient-GlobalNorm-GroupReduceXimum-\" + NewUniqueId();\n    auto group_reduce_op = user_op::UserOpConfWrapperBuilder(group_reduce_op_name)\n                               .Op(group_reduce_op_type_name)\n                               .Input(\"input_tensor\", multi_reduce_op.output(\"y\", 0))\n                               .Output(\"output_tensor\")\n                               .Attr(\"axis\", std::vector<int32_t>{0})\n                               .Attr(\"keepdims\", false)\n                               .ScopeSymbolId(embedding_scope_symbol_id)\n                               .Build();\n    job_builder->AddOps(embedding_parallel_conf, {group_reduce_op.op_conf()});\n    embedding_reduce_lbn = group_reduce_op.output(\"output_tensor\", 0);\n  }\n  if (!total_norm_lbn.empty()) {\n    auto stack_op_builder = user_op::UserOpConfWrapperBuilder(\n                                \"OneEmbedding-ClipGradient-GlobalNorm-GlobalStack-\" + NewUniqueId())\n                                .Op(\"stack\")\n                                .Input(\"in\", embedding_reduce_lbn)\n                                .Input(\"in\", total_norm_lbn)\n                                .Output(\"out\")\n                                .Attr(\"axis\", int64_t(0))\n                                .Attr(\"max_dim_size\", static_cast<int64_t>(2))\n                                .ScopeSymbolId(scope_symbol_id);\n    auto stack_op = stack_op_builder.Build();\n    job_builder->AddOps(parallel_conf, {stack_op.op_conf()});\n    std::string reduce_op_type_name = max_or_min ? 
\"reduce_max\" : \"reduce_min\";\n    std::string reduce_op_name =\n        \"OneEmbedding-ClipGradient-GlobalNorm-GlobalReduceXimum-\" + NewUniqueId();\n    auto reduce_op = user_op::UserOpConfWrapperBuilder(reduce_op_name)\n                         .Op(reduce_op_type_name)\n                         .Input(\"input_tensor\", stack_op.output(\"out\", 0))\n                         .Output(\"output_tensor\")\n                         .Attr(\"axis\", std::vector<int32_t>{0})\n                         .Attr(\"keepdims\", false)\n                         .ScopeSymbolId(scope_symbol_id)\n                         .Build();\n    job_builder->AddOps(parallel_conf, {reduce_op.op_conf()});\n    return reduce_op.output(\"output_tensor\", 0);\n  } else {\n    return embedding_reduce_lbn;\n  }\n}\n\nstd::string GlobalNorm(JobBuilder* job_builder,\n                       const HashMap<std::string, std::string>& shadow_op_name2grad_lbn, float p,\n                       const std::string& total_norm_lbn,\n                       const ParallelConf& embedding_parallel_conf,\n                       const int64_t embedding_scope_symbol_id, const ParallelConf& parallel_conf,\n                       const int64_t scope_symbol_id) {\n  auto multi_reduce_sum_op_builder =\n      user_op::UserOpConfWrapperBuilder(\"OneEmbedding-ClipGradient-GlobalNorm-MultiReduceSumPowAbs-\"\n                                        + NewUniqueId())\n          .Op(\"multi_reduce_sum_pow_abs\")\n          .Attr(\"p\", static_cast<float>(p))\n          .Output(\"y\")\n          .ScopeSymbolId(embedding_scope_symbol_id);\n  for (const auto& pair : shadow_op_name2grad_lbn) {\n    const std::string grad_lbn = pair.second;\n    multi_reduce_sum_op_builder.Input(\"x\", grad_lbn);\n  }\n  const auto multi_reduce_sum_op = multi_reduce_sum_op_builder.Build();\n  job_builder->AddOps(embedding_parallel_conf, {multi_reduce_sum_op.op_conf()});\n  const std::string& embedding_sum_pow_abs_lbn = multi_reduce_sum_op.output(\"y\", 0);\n  std::string global_pow_in_lbn;\n  if (!total_norm_lbn.empty()) {\n    auto pow_op = user_op::UserOpConfWrapperBuilder(\n                      \"OneEmbedding-ClipGradient-GlobalNorm-GlobalPow-\" + NewUniqueId())\n                      .Op(\"scalar_pow\")\n                      .Input(\"in\", total_norm_lbn)\n                      .Attr(\"float_operand\", static_cast<double>(p))\n                      .Attr(\"has_float_operand\", true)\n                      .Output(\"out\")\n                      .ScopeSymbolId(scope_symbol_id)\n                      .Build();\n    job_builder->AddOps(parallel_conf, {pow_op.op_conf()});\n    user_op::UserOpConfWrapperBuilder add_op_builder(\"OneEmbedding-ClipGradient-GlobalNorm-Add-\"\n                                                     + NewUniqueId());\n    const auto add_op = add_op_builder.Op(\"add_n\")\n                            .Input(\"in\", embedding_sum_pow_abs_lbn)\n                            .Input(\"in\", pow_op.output(\"out\", 0))\n                            .Output(\"out\")\n                            .ScopeSymbolId(scope_symbol_id)\n                            .Build();\n    job_builder->AddOps(parallel_conf, {add_op.op_conf()});\n    global_pow_in_lbn = add_op.output(\"out\", 0);\n  } else {\n    global_pow_in_lbn = embedding_sum_pow_abs_lbn;\n  }\n  auto global_pow_op = user_op::UserOpConfWrapperBuilder(\n                           \"OneEmbedding-ClipGradient-GlobalNorm-GlobalPow-\" + NewUniqueId())\n                           .Op(\"scalar_pow\")\n                  
         .Input(\"in\", global_pow_in_lbn)\n                           .Attr(\"float_operand\", static_cast<double>(1.0 / p))\n                           .Attr(\"has_float_operand\", true)\n                           .Output(\"out\")\n                           .ScopeSymbolId(scope_symbol_id)\n                           .Build();\n  job_builder->AddOps(parallel_conf, {global_pow_op.op_conf()});\n  return global_pow_op.output(\"out\", 0);\n}\n\nstd::string GetClampCoeff(JobBuilder* job_builder, const std::string& total_norm_lbn,\n                          float max_norm, const ParallelConf& parallel_conf,\n                          const int64_t scope_symbol_id) {\n  auto add_eps_ops = user_op::UserOpConfWrapperBuilder(\n                         \"OneEmbedding-ClipGradient-GlobalNorm-AddEps-\" + NewUniqueId())\n                         .Op(\"scalar_add\")\n                         .Input(\"in\", total_norm_lbn)\n                         .Attr(\"float_operand\", 1e-6)\n                         .Attr(\"has_float_operand\", true)\n                         .Output(\"out\")\n                         .ScopeSymbolId(scope_symbol_id)\n                         .Build();\n  job_builder->AddOps(parallel_conf, {add_eps_ops.op_conf()});\n\n  auto inv_op =\n      user_op::UserOpConfWrapperBuilder(\"OneEmbedding-ClipGradient-GlobalNorm-Inv-\" + NewUniqueId())\n          .Op(\"reciprocal_no_nan\")\n          .Input(\"x\", add_eps_ops.output(\"out\", 0))\n          .Output(\"y\")\n          .ScopeSymbolId(scope_symbol_id)\n          .Build();\n  job_builder->AddOps(parallel_conf, {inv_op.op_conf()});\n\n  auto coeff_op = user_op::UserOpConfWrapperBuilder(\"OneEmbedding-ClipGradient-GlobalNorm-Coeff-\"\n                                                    + NewUniqueId())\n                      .Op(\"scalar_mul\")\n                      .Input(\"in\", inv_op.output(\"y\", 0))\n                      .Attr(\"float_operand\", static_cast<double>(max_norm))\n                      .Attr(\"has_float_operand\", true)\n                      .Output(\"out\")\n                      .ScopeSymbolId(scope_symbol_id)\n                      .Build();\n  job_builder->AddOps(parallel_conf, {coeff_op.op_conf()});\n\n  auto clamp_coeff_op = user_op::UserOpConfWrapperBuilder(\n                            \"OneEmbedding-ClipGradient-GlobalNorm-Clamp-\" + NewUniqueId())\n                            .Op(\"clip_by_scalar_max\")\n                            .Input(\"x\", coeff_op.output(\"out\", 0))\n                            .Attr(\"floating_max\", 1.0)\n                            .Output(\"y\")\n                            .ScopeSymbolId(scope_symbol_id)\n                            .Build();\n  job_builder->AddOps(parallel_conf, {clamp_coeff_op.op_conf()});\n  return clamp_coeff_op.output(\"y\", 0);\n}\n\nvoid ClipGradByGlobalNorm(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,\n                          const OptimizerConf& optimizer_conf,\n                          const HashMap<std::string, std::string>& shadow_op_name2grad_lbn,\n                          const HashMap<std::string, OperatorConf>& grad_lbn2update_op_conf,\n                          const ParallelConf& embedding_parallel_conf,\n                          const int64_t embedding_scope_symbol_id,\n                          HashMap<std::string, OperatorConf>* op_name2op_conf) {\n  const ClipByGlobalNormConf& conf = optimizer_conf.clip_conf().clip_by_global_norm();\n  double norm_type = conf.norm_type();\n  auto clip_by_global_norm_pass_state =\n 
     CHECK_JUST(ctx->MutableState<ClipByGlobalNormJobPassState>(\"clip_by_global_norm_state\"));\n\n  const auto NewGlobalNorm = [&](const std::string& total_norm_lbn,\n                                 const ParallelConf& parallel_conf,\n                                 const int64_t scope_symbol_id) -> std::string {\n    if (std::isinf(norm_type) && norm_type > 0) {\n      return GlobalAbsMaxMin(job_builder, shadow_op_name2grad_lbn, norm_type, total_norm_lbn, true,\n                             embedding_parallel_conf, embedding_scope_symbol_id, parallel_conf,\n                             scope_symbol_id);\n    } else if (std::isinf(norm_type) && norm_type < 0) {\n      UNIMPLEMENTED()\n          << \"one_embedding sets invalid gradient values to 0, so abs_reduce_min is unsupported.\";\n      return GlobalAbsMaxMin(job_builder, shadow_op_name2grad_lbn, norm_type, total_norm_lbn, false,\n                             embedding_parallel_conf, embedding_scope_symbol_id, parallel_conf,\n                             scope_symbol_id);\n    } else {\n      return GlobalNorm(job_builder, shadow_op_name2grad_lbn, norm_type, total_norm_lbn,\n                        embedding_parallel_conf, embedding_scope_symbol_id, parallel_conf,\n                        scope_symbol_id);\n    }\n  };\n  bool has_total_norm_state = false;\n  std::string variable_op_name;\n  for (const auto& var_op_name : optimizer_conf.variable_op_names()) {\n    if (clip_by_global_norm_pass_state->HasTotalNormState(var_op_name)) {\n      has_total_norm_state = true;\n      variable_op_name = var_op_name;\n      break;\n    }\n  }\n  std::string coeff_lbn;\n  if (has_total_norm_state) {\n    // has_total_norm_state means some other gradients share an optimizer group with the\n    // embedding_grads: total_norm_lbn is the global norm of those gradients, so embedding_grads\n    // must fold their global norm into total_norm_lbn and update its consumers. There is no\n    // need to compute the clamp coeff because it was already built in the autograd pass.\n    const std::shared_ptr<ClipByGlobalNormJobPassState::TotalNormState>& total_norm_state =\n        clip_by_global_norm_pass_state->GetTotalNormState(variable_op_name);\n    const LogicalBlobId total_norm_lbi = GenLogicalBlobId(total_norm_state->total_norm_lbn());\n    std::string new_total_norm_lbn =\n        NewGlobalNorm(total_norm_state->total_norm_lbn(), total_norm_state->parallel_conf(),\n                      total_norm_state->scope_symbol_id());\n    const OpNode* total_norm_lbn_producer = op_graph.OpNode4OpName(total_norm_lbi.op_name());\n    for (const OpEdge* out_edge : total_norm_lbn_producer->out_edges()) {\n      const OpNode* consumer = out_edge->dst_node();\n      UpdateConsumerOpConf(consumer, total_norm_lbi, new_total_norm_lbn, op_name2op_conf);\n    }\n    total_norm_state->set_total_norm_lbn(new_total_norm_lbn);\n    coeff_lbn = total_norm_state->coeff_lbn();\n  } else {\n    // no total_norm_state means no other gradients share an optimizer group with embedding_grad,\n    // so embedding_grad computes the global norm and clips independently.\n    const std::string& new_total_norm_lbn =\n        NewGlobalNorm(\"\", embedding_parallel_conf, embedding_scope_symbol_id);\n    coeff_lbn = GetClampCoeff(job_builder, new_total_norm_lbn, conf.max_norm(),\n                              embedding_parallel_conf, embedding_scope_symbol_id);\n  }\n  for (const auto& pair : shadow_op_name2grad_lbn) {\n    const std::string& grad_lbn = pair.second;\n    const auto& it = 
grad_lbn2update_op_conf.find(grad_lbn);\n    CHECK(it != grad_lbn2update_op_conf.end());\n    OperatorConf update_op_conf = it->second;\n    *(*update_op_conf.mutable_user_conf()->mutable_input())[\"scale_by_tensor\"].mutable_s() =\n        StdVec2PbRpf<std::string>({coeff_lbn});\n    job_builder->AddOps(embedding_parallel_conf, {update_op_conf});\n  }\n}\n\nvoid FilterCurGradLbnAndUpdateOpConfPairs(\n    const ::google::protobuf::RepeatedPtrField<std::string>& variables,\n    const HashMap<std::string, std::string>& shadow_op_name2grad_lbn,\n    HashMap<std::string, std::string>* cur_shadow_op_name2grad_lbn) {\n  for (const std::string& variable : variables) {\n    const auto& it = shadow_op_name2grad_lbn.find(variable);\n    if (it != shadow_op_name2grad_lbn.end()) {\n      (*cur_shadow_op_name2grad_lbn)[variable] = it->second;\n    }\n  }\n}\n\nvoid FilterEmbeddingGradients(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,\n                              const HashMap<std::string, std::string>& shadow_op_name2grad_lbn,\n                              const HashMap<std::string, OperatorConf>& grad_lbn2update_op_conf,\n                              const ParallelConf& embedding_parallel_conf,\n                              const int64_t embedding_scope_symbol_id,\n                              HashMap<std::string, OperatorConf>* op_name2op_conf) {\n  for (const auto& optimizer_conf : job_builder->job().job_conf().train_conf().optimizer_conf()) {\n    HashMap<std::string, std::string> cur_shadow_op_name2grad_lbn;\n    FilterCurGradLbnAndUpdateOpConfPairs(optimizer_conf.variable_op_names(),\n                                         shadow_op_name2grad_lbn, &cur_shadow_op_name2grad_lbn);\n    if (!optimizer_conf.has_clip_conf()) {\n      for (const auto& pair : cur_shadow_op_name2grad_lbn) {\n        const auto& it = grad_lbn2update_op_conf.find(pair.second);\n        CHECK(it != grad_lbn2update_op_conf.end());\n        job_builder->AddOps(embedding_parallel_conf, {it->second});\n      }\n    } else {\n      ClipGradByGlobalNorm(ctx, op_graph, job_builder, optimizer_conf, cur_shadow_op_name2grad_lbn,\n                           grad_lbn2update_op_conf, embedding_parallel_conf,\n                           embedding_scope_symbol_id, op_name2op_conf);\n    }\n  }\n}\n\nbool IsRelatedOp(const OperatorConf& op) {\n  return op.has_user_conf() && (op.user_conf().op_type_name() == \"one_embedding_fused_lookup\");\n}\n\nbool NeedDoPass(const Job& job) {\n  return std::any_of(job.net().op().cbegin(), job.net().op().cend(), IsRelatedOp);\n}\n\n}  // namespace\n\nclass ReplaceEmbeddingOps final : public JobPass {\n public:\n  ReplaceEmbeddingOps() = default;\n  ~ReplaceEmbeddingOps() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().IsTrain(); }\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder, JobPassCtx* ctx) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    if (!NeedDoPass(*job)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder, ctx);\n  }\n};\n\nMaybe<void> ReplaceEmbeddingOps::Apply(const OpGraph& op_graph, JobBuilder* job_builder,\n                                       JobPassCtx* ctx) const {\n  ParallelConf embedding_parallel_conf;\n  int64_t embedding_scope_symbol_id = 0;\n  HashMap<std::string, OperatorConf> op_name2op_conf;\n  
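// shadow_op_name2grad_lbn maps each shadow variable op name to the gradient lbn produced for\n  // it; grad_lbn2update_op_conf maps that gradient lbn to its embedding update op conf, which\n  // is added to the job later by FilterEmbeddingGradients (with gradient clipping if configured).\n  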
HashMap<std::string, std::string> shadow_op_name2grad_lbn;\n  HashMap<std::string, OperatorConf> grad_lbn2update_op_conf;\n  op_graph.ForEachNode([&](const OpNode* op_node) {\n    const OperatorConf& op_conf = op_node->op().op_conf();\n    if (!op_conf.has_user_conf()) { return; }\n    if (!(op_conf.user_conf().op_type_name() == \"one_embedding_fused_lookup\")) { return; }\n    std::vector<OperatorConf> add_ops;\n    std::vector<std::string> delete_op_names;\n    const user_op::UserOpConfWrapper embedding_op(op_node->op().op_conf());\n    const OpNode* shadow_producer =\n        op_graph.OpNode4OpName(GenLogicalBlobId(embedding_op.input(\"shadow\", 0)).op_name());\n    std::string shadow_op_name;\n    if (shadow_producer->op().op_conf().has_variable_conf()) {\n      shadow_op_name = shadow_producer->op().op_name();\n    } else if (shadow_producer->op().op_conf().has_user_conf()\n               && shadow_producer->op().op_conf().user_conf().op_type_name() == \"cast\") {\n      const user_op::UserOpConfWrapper shadow_cast_op(shadow_producer->op().op_conf());\n      const OpNode* cast_producer =\n          op_graph.OpNode4OpName(GenLogicalBlobId(shadow_cast_op.input(\"in\", 0)).op_name());\n      CHECK(cast_producer->op().op_conf().has_variable_conf()) << cast_producer->op().op_name();\n      shadow_op_name = cast_producer->op().op_name();\n      delete_op_names.push_back(shadow_cast_op.op_name());\n    } else {\n      UNIMPLEMENTED() << \"shadow must be a variable, or a cast of a variable\";\n    }\n    // assume all embeddings have the same placement\n    embedding_scope_symbol_id = embedding_op.op_conf().scope_symbol_id();\n    embedding_parallel_conf = op_node->parallel_desc().parallel_conf();\n    const std::string& embedding_name = embedding_op.attr<std::string>(\"embedding_name\");\n    const int64_t line_size = embedding_op.attr<int64_t>(\"line_size\");\n    const int64_t embedding_size = embedding_op.attr<int64_t>(\"embedding_size\");\n    const bool is_full_cache = embedding_op.attr<bool>(\"is_full_cache\");\n    const int64_t seed = embedding_op.attr<int64_t>(\"seed\");\n    const int64_t parallel_num = op_node->parallel_desc().parallel_num();\n    const bool use_system_gather =\n        (parallel_num == 1 && ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_USE_SYSTEM_GATHER\", true));\n    std::string new_embeddings_lbn;\n\n    // prefetch cannot execute in advance when it consumes id_shuffle_copy_out, because\n    // id_shuffle_copy_out's register_num is 1, 
so we set the id_shuffle outputs to\n    // prefetch_num_unique_ids_lbn etc. and let prefetch consume them for pipelining.\n    std::string prefetch_num_unique_ids_lbn;\n    std::string prefetch_unique_ids_lbn;\n    std::string prefetch_unique_table_ids_lbn;\n    std::string inner_inverse_unique_partition_indices_lbn;\n    std::string num_unique_ids_lbn;\n    std::string unique_ids_lbn;\n    std::string unique_table_ids_lbn;\n    std::string inverse_indices_lbn;\n    std::string num_unique_matrix_lbn;\n\n    BuildIdShuffle(use_system_gather, embedding_name, embedding_op, &add_ops,\n                   &prefetch_num_unique_ids_lbn, &prefetch_unique_ids_lbn,\n                   &prefetch_unique_table_ids_lbn, &inner_inverse_unique_partition_indices_lbn,\n                   &num_unique_ids_lbn, &unique_ids_lbn, &unique_table_ids_lbn,\n                   &inverse_indices_lbn, &num_unique_matrix_lbn);\n    const bool is_train_job = job_builder->job().job_conf().has_train_conf();\n    const bool no_optimizer_states = (embedding_size == line_size);\n    const bool has_embedding_prefetch = (!is_full_cache) && (is_train_job || no_optimizer_states);\n\n    OperatorConf embedding_prefetch_op_conf;\n    OperatorConf embedding_lookup_op_conf;\n    // embedding lookup op\n    std::string embedding_lbn, unique_values_lbn;\n    BuildEmbeddingLookup(\n        ctx, job_builder, embedding_size, line_size, embedding_name, seed, has_embedding_prefetch,\n        embedding_parallel_conf, embedding_op, prefetch_num_unique_ids_lbn, prefetch_unique_ids_lbn,\n        prefetch_unique_table_ids_lbn, num_unique_ids_lbn, unique_ids_lbn, unique_table_ids_lbn,\n        &embedding_lbn, &unique_values_lbn, &embedding_prefetch_op_conf, &embedding_lookup_op_conf);\n\n    if (use_system_gather) {\n      user_op::UserOpConfWrapperBuilder gather_op_builder(embedding_op.op_name()\n                                                          + \"_one_embedding_gather\");\n      user_op::UserOpConfWrapper gather_op =\n          gather_op_builder.OpTypeName(\"one_embedding_gather\")\n              .Input(\"in\", embedding_lbn)\n              .Input(\"indices\", inverse_indices_lbn)\n              .Output(\"out\")\n              .Attr<int64_t>(\"embedding_size\", embedding_size)\n              .Attr<std::string>(\"embedding_name\", embedding_name)\n              .ScopeSymbolId(embedding_scope_symbol_id)\n              .Build();\n      add_ops.push_back(gather_op.op_conf());\n      new_embeddings_lbn = gather_op.output(\"out\", 0);\n    } else {\n      // embedding shuffle op\n      BuildEmbeddingShuffle(job_builder, embedding_name, embedding_size, embedding_parallel_conf,\n                            embedding_op, inverse_indices_lbn,\n                            inner_inverse_unique_partition_indices_lbn, num_unique_matrix_lbn,\n                            embedding_lbn, &add_ops, &new_embeddings_lbn);\n    }\n    delete_op_names.push_back(embedding_op.op_name());\n\n    const LogicalBlobId out = GenLogicalBlobId(embedding_op.output(\"embeddings\", 0));\n    for (const OpEdge* out_edge : op_node->out_edges()) {\n      const OpNode* consumer = out_edge->dst_node();\n      UpdateConsumerOpConf(consumer, out, new_embeddings_lbn, &op_name2op_conf);\n    }\n    std::string state_initializer;\n    // find update op\n    const OpNode* producer =\n        op_graph.OpNode4OpName(GenLogicalBlobId(embedding_op.input(\"ids\", 0)).op_name());\n    for (OpEdge* edge : producer->out_edges()) {\n      const OpNode* consumer = edge->dst_node();\n      if 
(consumer->op().op_conf().has_user_conf()) {\n        const user_op::UserOpConfWrapper update_op_conf(consumer->op().op_conf());\n        if (update_op_conf.op_type_name() != \"one_embedding_fused_lookup_grad\") { continue; }\n        if (update_op_conf.attr<std::string>(\"embedding_name\")\n            != embedding_op.attr<std::string>(\"embedding_name\")) {\n          continue;\n        }\n        delete_op_names.push_back(update_op_conf.op_name());\n\n        OptimizerConf embedding_optimizer_conf;\n        bool found_embedding_optimizer = false;\n        for (const auto& optimizer_conf :\n             job_builder->job().job_conf().train_conf().optimizer_conf()) {\n          for (const auto& name : optimizer_conf.variable_op_names()) {\n            if (name == shadow_op_name) {\n              embedding_optimizer_conf = optimizer_conf;\n              found_embedding_optimizer = true;\n              break;\n            }\n          }\n          if (found_embedding_optimizer) { break; }\n        }\n        CHECK(found_embedding_optimizer) << \"no optimizer found for variable \" << shadow_op_name;\n\n        std::string embedding_grad_lbn;\n        BuildEmbeddingGradientShuffle(\n            ctx, op_graph, job_builder, op_node, embedding_name, embedding_size, use_system_gather,\n            embedding_parallel_conf, embedding_scope_symbol_id, embedding_op, inverse_indices_lbn,\n            inner_inverse_unique_partition_indices_lbn, num_unique_matrix_lbn,\n            update_op_conf.input(\"embedding_grad\", 0), embedding_optimizer_conf.has_clip_conf(),\n            &embedding_grad_lbn);\n\n        const OpNode* shadow_node = op_graph.OpNode4OpName(shadow_op_name);\n        const VariableOpConf& shadow_variable_conf = shadow_node->op().op_conf().variable_conf();\n        float l1 = 0.0;\n        float l2 = 0.0;\n        if (shadow_variable_conf.has_regularizer()) {\n          const RegularizerConf& regularizer_conf = shadow_variable_conf.regularizer();\n          if (regularizer_conf.has_l1_l2_conf()) {\n            l1 = regularizer_conf.l1_l2_conf().l1();\n            l2 = regularizer_conf.l1_l2_conf().l2();\n          }\n        }\n        const std::string& learning_rate_lbn = embedding_optimizer_conf.learning_rate_lbn();\n\n        std::string new_embedding_grad_lbn;\n        OperatorConf embedding_update_op_conf;\n        BuildEmbeddingUpdate(ctx, op_graph, job_builder, embedding_parallel_conf,\n                             embedding_scope_symbol_id, is_full_cache, embedding_size, line_size,\n                             l1, l2, embedding_name, embedding_optimizer_conf, embedding_op,\n                             num_unique_ids_lbn, unique_ids_lbn, unique_values_lbn,\n                             embedding_grad_lbn, learning_rate_lbn, &new_embedding_grad_lbn,\n                             &state_initializer, &embedding_update_op_conf);\n        shadow_op_name2grad_lbn[shadow_op_name] = new_embedding_grad_lbn;\n        grad_lbn2update_op_conf[new_embedding_grad_lbn] = std::move(embedding_update_op_conf);\n      }\n    }\n    if ((state_initializer.empty()) && !no_optimizer_states) {\n      CHECK(!is_train_job) << \"a train job must have set the state initializer\";\n      MakeConstantInitializerAttr(embedding_size, line_size, {}, &state_initializer);\n    }\n    auto state_initializer_attr = ::oneflow::AttrValue();\n    state_initializer_attr.set_at_string(state_initializer);\n    if (has_embedding_prefetch) {\n      
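// Attach the serialized optimizer-state initializer to the prefetch op as its\n      // \"state_initializer\" attr; the lookup op below receives the same attr.\n      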
(*(embedding_prefetch_op_conf.mutable_user_conf()->mutable_attr()))[\"state_initializer\"] =\n          state_initializer_attr;\n      add_ops.push_back(embedding_prefetch_op_conf);\n    }\n    (*(embedding_lookup_op_conf.mutable_user_conf()->mutable_attr()))[\"state_initializer\"] =\n        state_initializer_attr;\n    add_ops.push_back(embedding_lookup_op_conf);\n    job_builder->DelOps(delete_op_names);\n    job_builder->AddOps(embedding_parallel_conf, add_ops);\n  });\n  if (shadow_op_name2grad_lbn.size() > 0) {\n    FilterEmbeddingGradients(ctx, op_graph, job_builder, shadow_op_name2grad_lbn,\n                             grad_lbn2update_op_conf, embedding_parallel_conf,\n                             embedding_scope_symbol_id, &op_name2op_conf);\n    JUST(DynamicLossScaleAddGradient(ctx, op_graph, job_builder, shadow_op_name2grad_lbn,\n                                     embedding_scope_symbol_id, embedding_parallel_conf));\n  }\n  for (const auto& pair : op_name2op_conf) { job_builder->MutOpsOnlyOnce({pair.second}); }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_JOB_PASS(\"ReplaceEmbeddingOps\", ReplaceEmbeddingOps);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/rmsprop_optm.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/optimizer.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::string GenVariableOutputLbn(const OperatorConf& op_conf) {\n  CHECK(op_conf.has_variable_conf());\n  return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out());\n}\n\nOperatorConf GenerateRmspropHelperVariableOpConf(const VariableOp& op, const std::string& name,\n                                                 const float initial_value) {\n  OperatorConf helper_variable_op(op.op_conf());\n  helper_variable_op.set_name(op.op_name() + \"-\" + name);\n  helper_variable_op.mutable_variable_conf()->set_out(\"out\");\n  InitializerConf constant_initializer;\n  constant_initializer.mutable_constant_conf()->set_value(initial_value);\n  *(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer;\n  helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id());\n  return helper_variable_op;\n}\n\nvoid GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node,\n                             const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf,\n                             JobBuilder* job_builder) {\n  const VariableOp* var_op = dynamic_cast<const VariableOp*>(&var_op_node.op());\n  CHECK_NOTNULL(var_op);\n  OperatorConf mean_square_var(GenerateRmspropHelperVariableOpConf(*var_op, \"mean_square\", 0.f));\n  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {mean_square_var});\n\n  user_op::UserOpConfWrapperBuilder rmsprop_update_op_builder(var_op->op_name() + \"_optimizer\");\n  const RMSPropModelUpdateConf& rmsprop_conf = optimizer_conf.rmsprop_conf();\n  bool centered = rmsprop_conf.centered();\n  rmsprop_update_op_builder.OpTypeName(\"rmsprop_update\")\n      .Input(\"model\", GenLogicalBlobName(var_op->BnInOp2Lbi(\"out\")))\n      .Input(\"model_diff\", model_diff_lbn)\n      .Input(\"learning_rate\", optimizer_conf.learning_rate_lbn())\n      .Input(\"mean_square\", GenVariableOutputLbn(mean_square_var))\n      .Attr<bool>(\"centered\", centered)\n      .Attr<float>(\"epsilon\", rmsprop_conf.epsilon())\n      .Attr<float>(\"decay_rate\", rmsprop_conf.decay_rate())\n      .Attr<float>(\"weight_decay\", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))\n      .ScopeSymbolId(var_op->op_conf().scope_symbol_id());\n\n  if (optimizer_conf.has_lr_scale()) {\n    rmsprop_update_op_builder.Attr<float>(\"learning_rate_scale\", optimizer_conf.lr_scale());\n  }\n\n  SetDynamicLossScaleSkipIf(ctx, &rmsprop_update_op_builder);\n\n  if (centered) {\n    OperatorConf mean_gradient_var(\n        GenerateRmspropHelperVariableOpConf(*var_op, \"mean_gradient\", 0.f));\n    job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {mean_gradient_var});\n    rmsprop_update_op_builder.Input(\"mean_gradient\", GenVariableOutputLbn(mean_gradient_var));\n  
}\n\n  user_op::UserOpConfWrapper rmsprop_update_op = rmsprop_update_op_builder.Build();\n  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {rmsprop_update_op.op_conf()});\n}\n\n}  // namespace\n\nREGISTER_OPTIMIZER(OptimizerConf::kRmspropConf, &GenerateOptimizerOpConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/sequential_one_embedding_shuffle_ops_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass SequentialOneEmbeddingOpsPass final : public JobPass {\n public:\n  SequentialOneEmbeddingOpsPass() = default;\n  ~SequentialOneEmbeddingOpsPass() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const {\n    return ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION\", false);\n  }\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n};\n\nMaybe<void> SequentialOneEmbeddingOpsPass::Apply(const OpGraph& op_graph,\n                                                 JobBuilder* job_builder) const {\n  HashMap<std::string, std::vector<std::string>> stream_name_hint2shuffle_op_names;\n  op_graph.TopoForEachNode([&](const OpNode* op_node) {\n    if (!(IsUserOpWithTypeName(op_node->op().op_conf(), \"id_shuffle\")\n          || IsUserOpWithTypeName(op_node->op().op_conf(), \"embedding_shuffle\")\n          || IsUserOpWithTypeName(op_node->op().op_conf(), \"embedding_gradient_shuffle\"))) {\n      return;\n    }\n    OperatorConf op_conf = op_node->op().op_conf();\n    std::string stream_name;\n    if (op_conf.has_stream_name_hint()) {\n      stream_name = op_conf.stream_name_hint();\n    } else {\n      stream_name = \"DEFAULT\";\n    }\n    const auto& it = stream_name_hint2shuffle_op_names.find(stream_name);\n    if (it != stream_name_hint2shuffle_op_names.end()) {\n      if (it->second.size() > 0) {\n        std::string pre_shuffle_op_name = it->second.back();\n        op_conf.add_ctrl_in_op_name(pre_shuffle_op_name);\n        job_builder->MutOpsOnlyOnce({op_conf});\n      }\n      it->second.push_back(op_conf.name());\n    } else {\n      std::vector<std::string> shuffle_ops{op_conf.name()};\n      CHECK(stream_name_hint2shuffle_op_names.emplace(stream_name, shuffle_ops).second);\n    }\n  });\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nREGISTER_JOB_PASS(\"SequentialOneEmbeddingOpsPass\", SequentialOneEmbeddingOpsPass);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/sgd_optm.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/optimizer.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nvoid GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node,\n                             const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf,\n                             JobBuilder* job_builder) {\n  const VariableOp* var_op = dynamic_cast<const VariableOp*>(&var_op_node.op());\n  CHECK_NOTNULL(var_op);\n  user_op::UserOpConfWrapperBuilder sgd_update_op_builder(var_op->op_name() + \"_optimizer\");\n  sgd_update_op_builder.OpTypeName(\"sgd_update\")\n      .Input(\"model\", GenLogicalBlobName(var_op->BnInOp2Lbi(\"out\")))\n      .Input(\"model_diff\", model_diff_lbn)\n      .Input(\"learning_rate\", optimizer_conf.learning_rate_lbn())\n      .Attr<float>(\"weight_decay\", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))\n      .ScopeSymbolId(var_op->op_conf().scope_symbol_id());\n  if (optimizer_conf.has_lr_scale()) {\n    sgd_update_op_builder.Attr<float>(\"learning_rate_scale\", optimizer_conf.lr_scale());\n  }\n  SetDynamicLossScaleSkipIf(ctx, &sgd_update_op_builder);\n  user_op::UserOpConfWrapper sgd_update_op = sgd_update_op_builder.Build();\n  job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {sgd_update_op.op_conf()});\n}\n\n}  // namespace\n\nREGISTER_OPTIMIZER(OptimizerConf::kNaiveConf, &GenerateOptimizerOpConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/source_user_op_auto_tick.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/autotick.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass MutUserOpConTickInputHelper final : public MutOpConTickInputHelper {\n public:\n  MutUserOpConTickInputHelper() : MutOpConTickInputHelper() {}\n\n  bool VirtualIsTickInputBound() const override { return !op_conf().user_conf().input().empty(); }\n\n  OperatorConf NewTickInputBoundOpConf(const std::string& lbn) const override {\n    OperatorConf ret(op_conf());\n    (*ret.mutable_user_conf()->mutable_input())[user_op::kUserSourceOpTickInputArgName].add_s(lbn);\n    return ret;\n  }\n};\n\n}  // namespace\n\nREGISTER_AUTO_TICK(OperatorConf::kUserConf, MutUserOpConTickInputHelper);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/split_sparse_softmax_cross_entropy_op_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool NeedDoPass(const Job& job) {\n  return std::any_of(job.net().op().cbegin(), job.net().op().cend(), [&](const OperatorConf& op) {\n    return op.has_user_conf() && op.user_conf().op_type_name() == \"sparse_softmax_cross_entropy_ms\";\n  });\n}\n\nvoid UpdateProbConsumerOpConf(const std::string& new_prob_lbn, const OpNode* op_node,\n                              JobBuilder* job_builder) {\n  for (const OpEdge* edge : op_node->out_edges()) {\n    OpNode* out_node = edge->dst_node();\n    OperatorConf new_conf = out_node->op().op_conf();\n    if (new_conf.has_user_conf()\n        && new_conf.user_conf().op_type_name() == \"sparse_softmax_cross_entropy_ms_grad\") {\n      CHECK_EQ(GenLogicalBlobName(out_node->op().BnInOp2Lbi(\"prob_0\")),\n               ReplaceInputLbnInOpCustomizedConf(&new_conf, \"prob_0\", new_prob_lbn));\n      job_builder->MutOpsOnlyOnce({new_conf});\n    }\n  }\n}\n\nclass SplitSparseSoftmaxCrossEntropyOpPass final : public JobPass {\n public:\n  SplitSparseSoftmaxCrossEntropyOpPass() = default;\n  ~SplitSparseSoftmaxCrossEntropyOpPass() override = default;\n\n  Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    if (!NeedDoPass(*job)) { return Maybe<void>::Ok(); }\n    const OpGraph op_graph(*job);\n    JobBuilder job_builder(job);\n    return Apply(op_graph, &job_builder);\n  }\n};\n\nMaybe<void> SplitSparseSoftmaxCrossEntropyOpPass::Apply(const OpGraph& op_graph,\n                                                        JobBuilder* job_builder) const {\n  std::vector<std::string> to_del_op_names;\n  HashMap<std::string, OperatorConf> consumer_op_name2op_confs;\n  op_graph.ForEachNode([&](const OpNode* node) {\n    const OperatorConf& op_conf = node->op().op_conf();\n    if (!op_conf.has_user_conf()) { return; }\n    if (op_conf.user_conf().op_type_name() != \"sparse_softmax_cross_entropy_ms\") { return; }\n\n    const int64_t scope_symbol_id = node->op().op_conf().scope_symbol_id();\n    user_op::UserOpConfWrapper cur_op(op_conf);\n    const std::string& op_prediction_blob_name = cur_op.input(\"prediction\", 0);\n    const std::string& op_label_blob_name = cur_op.input(\"label\", 0);\n    const int32_t split_axis =\n        node->LogicalBlobDesc4Lbi(node->op().BnInOp2Lbi(\"prediction_0\")).shape().NumAxes() - 1;\n    const std::vector<int32_t> axis_vec(1, split_axis);\n\n    const std::string& op_name = node->op().op_name();\n    const auto& prediction_nd_sbp = node->NdSbp4BnInOp(\"prediction_0\");\n\n    NdSbp stat_distribution_for_consumer;\n\n    bool has_split_axis_parallel = false;\n    CHECK_EQ(prediction_nd_sbp.sbp_parallel_size(), node->parallel_desc().hierarchy()->NumAxes());\n    for (int64_t i = 0; i < 
node->parallel_desc().hierarchy()->NumAxes(); ++i) {\n      const auto& sbp = prediction_nd_sbp.sbp_parallel(i);\n      if (sbp.has_split_parallel() && sbp.split_parallel().axis() == split_axis) {\n        has_split_axis_parallel = true;\n        stat_distribution_for_consumer.add_sbp_parallel()->mutable_broadcast_parallel();\n      } else {\n        CHECK(!sbp.has_partial_sum_parallel());\n        *stat_distribution_for_consumer.add_sbp_parallel() = SbpParallel(sbp);\n      }\n    }\n\n    if (!has_split_axis_parallel) { return; }\n    to_del_op_names.push_back(op_name);\n\n    auto reduce_max_device_stage_op =\n        user_op::UserOpConfWrapperBuilder(op_name + \"-split_softmax_reduce_max_device_stage\")\n            .Op(\"reduce_max_device_stage\")\n            .Input(\"in\", op_prediction_blob_name)\n            .Output(\"out\")\n            .Output(\"mask\")\n            .Output(\"count\")\n            .Attr(\"axis\", axis_vec)\n            .ScopeSymbolId(scope_symbol_id)\n            .Build();\n    job_builder->AddOps(node->parallel_desc().parallel_conf(),\n                        {reduce_max_device_stage_op.op_conf()});\n    NdSbpSignature reduce_max_device_stage_signature;\n    (*reduce_max_device_stage_signature.mutable_bn_in_op2nd_sbp())[\"in_0\"] =\n        NdSbp(prediction_nd_sbp);\n    (*reduce_max_device_stage_signature.mutable_bn_in_op2nd_sbp())[\"out_0\"] =\n        NdSbp(prediction_nd_sbp);\n    (*reduce_max_device_stage_signature.mutable_bn_in_op2nd_sbp())[\"mask_0\"] =\n        NdSbp(prediction_nd_sbp);\n    (*reduce_max_device_stage_signature.mutable_bn_in_op2nd_sbp())[\"count_0\"] =\n        NdSbp(prediction_nd_sbp);\n    job_builder->AddNdSbpSignature4OpName(reduce_max_device_stage_op.op_name(),\n                                          reduce_max_device_stage_signature);\n\n    auto reduce_max_global_stage_op =\n        user_op::UserOpConfWrapperBuilder(op_name + \"-split_softmax_reduce_max_global_stage\")\n            .Op(\"reduce_max_global_stage\")\n            .Input(\"in\", reduce_max_device_stage_op.output(\"out\", 0))\n            .Input(\"device_count\", reduce_max_device_stage_op.output(\"count\", 0))\n            .Output(\"out\")\n            .Output(\"mask\")\n            .Attr(\"axis\", axis_vec)\n            .Attr(\"keepdims\", true)\n            .ScopeSymbolId(scope_symbol_id)\n            .Build();\n    job_builder->AddOps(node->parallel_desc().parallel_conf(),\n                        {reduce_max_global_stage_op.op_conf()});\n    NdSbpSignature reduce_max_global_stage_signature;\n    (*reduce_max_global_stage_signature.mutable_bn_in_op2nd_sbp())[\"in_0\"] =\n        stat_distribution_for_consumer;\n    (*reduce_max_global_stage_signature.mutable_bn_in_op2nd_sbp())[\"device_count_0\"] =\n        stat_distribution_for_consumer;\n    (*reduce_max_global_stage_signature.mutable_bn_in_op2nd_sbp())[\"out_0\"] =\n        stat_distribution_for_consumer;\n    job_builder->AddNdSbpSignature4OpName(reduce_max_global_stage_op.op_name(),\n                                          reduce_max_global_stage_signature);\n\n    auto broadcast_sub_max_op =\n        user_op::UserOpConfWrapperBuilder(op_name + \"-split_softmax_sub_max\")\n            .Op(\"broadcast_sub\")\n            .Input(\"x\", op_prediction_blob_name)\n            .Input(\"y\", reduce_max_global_stage_op.output(\"out\", 0))\n            .Output(\"z\")\n            .ScopeSymbolId(scope_symbol_id)\n            .Build();\n    job_builder->AddOps(node->parallel_desc().parallel_conf(), 
{broadcast_sub_max_op.op_conf()});\n\n    auto exp_op = user_op::UserOpConfWrapperBuilder(op_name + \"-split_softmax_exp\")\n                      .Op(\"exp\")\n                      .Input(\"x\", broadcast_sub_max_op.output(\"z\", 0))\n                      .Output(\"y\")\n                      .ScopeSymbolId(scope_symbol_id)\n                      .Build();\n    job_builder->AddOps(node->parallel_desc().parallel_conf(), {exp_op.op_conf()});\n\n    auto reduce_sum_op = user_op::UserOpConfWrapperBuilder(op_name + \"-split_softmax_reduce_sum\")\n                             .Op(\"reduce_sum\")\n                             .Input(\"input_tensor\", exp_op.output(\"y\", 0))\n                             .Output(\"output_tensor\")\n                             .Attr(\"axis\", axis_vec)\n                             .Attr(\"keepdims\", true)\n                             .ScopeSymbolId(scope_symbol_id)\n                             .Build();\n    job_builder->AddOps(node->parallel_desc().parallel_conf(), {reduce_sum_op.op_conf()});\n\n    std::string reduce_sum_op_out;\n    if (node->parallel_desc().hierarchy()->NumAxes() > 1) {\n      std::vector<std::string> nd_sbp_conf;\n      for (const auto& sbp_parallel : stat_distribution_for_consumer.sbp_parallel()) {\n        nd_sbp_conf.emplace_back(SbpParallelToString(sbp_parallel));\n      }\n      auto parallel_cast_sum_op =\n          user_op::UserOpConfWrapperBuilder(op_name + \"-split_softmax_reduce_sum_cast_P2B\")\n              .Op(\"hierarchical_parallel_cast\")\n              .Input(\"in\", reduce_sum_op.output(\"output_tensor\", 0))\n              .Output(\"out\")\n              .Attr<std::vector<std::string>>(\"nd_sbp\", nd_sbp_conf)\n              .Attr<std::string>(\"grad_mode\", \"auto\")\n              .Attr<std::vector<std::string>>(\"grad_nd_sbp\", std::vector<std::string>())\n              .ScopeSymbolId(scope_symbol_id)\n              .Build();\n      job_builder->AddOps(node->parallel_desc().parallel_conf(), {parallel_cast_sum_op.op_conf()});\n      reduce_sum_op_out = parallel_cast_sum_op.output(\"out\", 0);\n    } else {\n      reduce_sum_op_out = reduce_sum_op.output(\"output_tensor\", 0);\n    }\n\n    auto broadcast_div_op = user_op::UserOpConfWrapperBuilder(op_name + \"-split_softmax_div\")\n                                .Op(\"broadcast_div\")\n                                .Input(\"x\", exp_op.output(\"y\", 0))\n                                .Input(\"y\", reduce_sum_op_out)\n                                .Output(\"z\")\n                                .ScopeSymbolId(scope_symbol_id)\n                                .Build();\n    job_builder->AddOps(node->parallel_desc().parallel_conf(), {broadcast_div_op.op_conf()});\n\n    auto log_op = user_op::UserOpConfWrapperBuilder(op_name + \"-log\")\n                      .Op(\"log\")\n                      .Input(\"x\", reduce_sum_op_out)\n                      .Output(\"y\")\n                      .ScopeSymbolId(scope_symbol_id)\n                      .Build();\n    job_builder->AddOps(node->parallel_desc().parallel_conf(), {log_op.op_conf()});\n\n    auto broadcast_sub_op = user_op::UserOpConfWrapperBuilder(op_name + \"-broadcast_add\")\n                                .Op(\"broadcast_sub\")\n                                .Input(\"x\", broadcast_sub_max_op.output(\"z\", 0))\n                                .Input(\"y\", log_op.output(\"y\", 0))\n                                .Output(\"z\")\n                                .ScopeSymbolId(scope_symbol_id)\n             
                   .Build();\n    job_builder->AddOps(node->parallel_desc().parallel_conf(), {broadcast_sub_op.op_conf()});\n\n    auto nll_op = user_op::UserOpConfWrapperBuilder(op_name + \"-nll\")\n                      .Op(\"nll\")\n                      .Input(\"input\", broadcast_sub_op.output(\"z\", 0))\n                      .Input(\"target\", op_label_blob_name)\n                      .Output(\"output\")\n                      .Output(\"out_weight\")\n                      .Attr<int64_t>(\"ignore_index\", -100)\n                      .ScopeSymbolId(scope_symbol_id)\n                      .Build();\n    job_builder->AddOps(node->parallel_desc().parallel_conf(), {nll_op.op_conf()});\n\n    const std::string& prob_lbn = cur_op.output(\"prob\", 0);\n    const std::string& out_lbn = cur_op.output(\"out\", 0);\n    const std::string& new_prob_lbn = broadcast_div_op.output(\"z\", 0);\n    const std::string& new_out_lbn = nll_op.output(\"output\", 0);\n\n    for (const OpEdge* out_edge : node->out_edges()) {\n      const OpNode* consumer = out_edge->dst_node();\n      const std::string& consumer_op_name = consumer->op().op_name();\n      if (consumer_op_name2op_confs.find(consumer_op_name) == consumer_op_name2op_confs.end()) {\n        consumer_op_name2op_confs[consumer_op_name] = consumer->op().op_conf();\n      }\n      OperatorConf& consumer_op_conf = consumer_op_name2op_confs[consumer_op_name];\n      for (const std::string& ibn : consumer->op().input_bns()) {\n        const std::string& input_lbn = GenLogicalBlobName(consumer->op().BnInOp2Lbi(ibn));\n        if (input_lbn == prob_lbn) {\n          const auto& old_lbn =\n              ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, new_prob_lbn);\n          CHECK_EQ(old_lbn, prob_lbn);\n        } else if (input_lbn == out_lbn) {\n          const auto& old_lbn =\n              ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, new_out_lbn);\n          CHECK_EQ(old_lbn, out_lbn);\n        } else {\n          // does not care\n        }\n      }\n    }\n  });\n  for (const auto& pair : consumer_op_name2op_confs) { job_builder->MutOpsOnlyOnce({pair.second}); }\n  job_builder->DelOps(to_del_op_names);\n  return Maybe<void>::Ok();\n}\n\nREGISTER_JOB_PASS(\"SplitSparseSoftmaxCrossEntropyOpPass\", SplitSparseSoftmaxCrossEntropyOpPass);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/system_op_fill_job_name_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n#include \"oneflow/core/job/job.pb.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass SystemOpFillJobNamePass final : public JobPass {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SystemOpFillJobNamePass);\n  SystemOpFillJobNamePass() = default;\n  ~SystemOpFillJobNamePass() override = default;\n\n  bool IsEnabled(const JobPassCtx& ctx) const { return true; }\n\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {\n    const std::string& job_name = job->job_conf().job_name();\n    for (OperatorConf& op_conf : *job->mutable_net()->mutable_op()) {\n      if (op_conf.has_input_conf()) {\n        op_conf.mutable_input_conf()->set_job_name(job_name);\n      } else if (op_conf.has_wait_and_send_ids_conf()) {\n        op_conf.mutable_wait_and_send_ids_conf()->set_job_name(job_name);\n      } else if (op_conf.has_output_conf()) {\n        op_conf.mutable_output_conf()->set_job_name(job_name);\n      } else if (op_conf.has_return_conf()) {\n        op_conf.mutable_return_conf()->set_job_name(job_name);\n      } else if (op_conf.has_callback_notify_conf()) {\n        op_conf.mutable_callback_notify_conf()->set_job_name(job_name);\n      } else {\n        // do nothing\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_JOB_PASS(\"SystemOpFillJobNamePass\", SystemOpFillJobNamePass);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/tick_autotick.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/autotick.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass MutTickOpConTickInputHelper final : public MutOpConTickInputHelper {\n public:\n  MutTickOpConTickInputHelper() : MutOpConTickInputHelper() {}\n\n  bool VirtualIsTickInputBound() const override { return op_conf().tick_conf().tick_size() > 0; }\n\n  OperatorConf NewTickInputBoundOpConf(const std::string& lbn) const override {\n    OperatorConf ret(op_conf());\n    ret.mutable_tick_conf()->add_tick(lbn);\n    return ret;\n  }\n};\n\n}  // namespace\n\nREGISTER_AUTO_TICK(OperatorConf::kTickConf, MutTickOpConTickInputHelper);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/job_rewriter/variable_autotick.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job_rewriter/autotick.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass MutVariableOpConTickInputHelper final : public MutOpConTickInputHelper {\n public:\n  MutVariableOpConTickInputHelper() : MutOpConTickInputHelper() {}\n\n  bool VirtualIsTickInputBound() const override { return op_conf().variable_conf().has_tick(); }\n\n  OperatorConf NewTickInputBoundOpConf(const std::string& lbn) const override {\n    OperatorConf ret(op_conf());\n    ret.mutable_variable_conf()->set_tick(lbn);\n    return ret;\n  }\n};\n\n}  // namespace\n\nREGISTER_AUTO_TICK(OperatorConf::kVariableConf, MutVariableOpConTickInputHelper);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/assign_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n\nnamespace oneflow {\n\nclass AssignKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AssignKernel);\n  AssignKernel() = default;\n  ~AssignKernel() override = default;\n\n private:\n  bool IsStateless() const override { return false; }\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\nvoid AssignKernel::ForwardDataContent(KernelContext* ctx) const {\n  const Blob* value = ctx->BnInOp2Blob(\"value\");\n  Blob* ref = ctx->BnInOp2Blob(\"ref\");\n  AutoMemcpy(ctx->stream(), ref, value);\n}\n\nREGISTER_KERNEL(OperatorConf::kAssignConf, AssignKernel);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/blob_access_checker_kernel_observer.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/blob_access_checker_kernel_observer.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename HandlerT>\nvoid ForEachObnAndIsHeaderInferedBeforeCompute(KernelContext* kernel_ctx, const Kernel* kernel,\n                                               const HandlerT& Handler) {\n  const auto& modifier_map =\n      kernel->op_attribute().arg_modifier_signature().obn2output_blob_modifier();\n  for (const std::string& obn : kernel->op_attribute().output_bns()) {\n    Blob* blob = kernel_ctx->BnInOp2Blob(obn);\n    if (blob) {\n      bool is_header_infered_before_compute = modifier_map.at(obn).header_infered_before_compute();\n      Handler(obn, is_header_infered_before_compute);\n    }\n  }\n}\n\ntemplate<typename HandlerT>\nvoid ForEachObnAndIsMutableByConsumer(KernelContext* kernel_ctx, const Kernel* kernel,\n                                      const HandlerT& Handler) {\n  const auto& modifier_map =\n      kernel->op_attribute().arg_modifier_signature().obn2output_blob_modifier();\n  for (const std::string& obn : kernel->op_attribute().output_bns()) {\n    Blob* blob = kernel_ctx->BnInOp2Blob(obn);\n    if (blob) {\n      bool is_mutable_by_consumer = modifier_map.at(obn).is_mutable();\n      Handler(obn, is_mutable_by_consumer);\n    }\n  }\n}\n\nvoid SetOutputBlobProducerInferAccessChecker(KernelContext* kernel_ctx, const Kernel* kernel) {\n  ForEachObnAndIsHeaderInferedBeforeCompute(\n      kernel_ctx, kernel, [&](const std::string& obn, bool _) {\n        kernel_ctx->BnInOp2Blob(obn)->set_blob_access_checker(\n            Singleton<BlobAccessCheckerIf<true, false>>::Get());\n      });\n}\n\nvoid SetOutputBlobProducerComputeAccessChecker(KernelContext* kernel_ctx, const Kernel* kernel) {\n  ForEachObnAndIsHeaderInferedBeforeCompute(\n      kernel_ctx, kernel, [&](const std::string& obn, bool is_header_infered_before_compute) {\n        const BlobAccessChecker* checker = nullptr;\n        if (is_header_infered_before_compute) {\n          checker = Singleton<BlobAccessCheckerIf<false, true>>::Get();\n        } else {\n          checker = Singleton<BlobAccessCheckerIf<true, true>>::Get();\n        }\n        kernel_ctx->BnInOp2Blob(obn)->set_blob_access_checker(checker);\n      });\n}\n\nvoid SetOutputBlobConsumerAccessChecker(KernelContext* kernel_ctx, const Kernel* kernel) {\n  ForEachObnAndIsMutableByConsumer(\n      kernel_ctx, kernel, [&](const std::string& obn, bool is_mutable) {\n        const BlobAccessChecker* checker = nullptr;\n        if (is_mutable) {\n          checker = Singleton<BlobAccessCheckerIf<false, true>>::Get();\n        } else {\n          checker = Singleton<BlobAccessCheckerIf<false, false>>::Get();\n        }\n        kernel_ctx->BnInOp2Blob(obn)->set_blob_access_checker(checker);\n      });\n}\n\n}  // namespace\n\nvoid 
BlobAccessCheckerKernelObserver::WillForwardHeader(KernelContext* kernel_ctx,\n                                                        const Kernel* kernel) {\n  SetOutputBlobProducerInferAccessChecker(kernel_ctx, kernel);\n}\n\nvoid BlobAccessCheckerKernelObserver::WillForwardDataContent(KernelContext* kernel_ctx,\n                                                             const Kernel* kernel) {\n  SetOutputBlobProducerComputeAccessChecker(kernel_ctx, kernel);\n}\n\nvoid BlobAccessCheckerKernelObserver::DidForwardDataContent(KernelContext* kernel_ctx,\n                                                            const Kernel* kernel) {\n  SetOutputBlobConsumerAccessChecker(kernel_ctx, kernel);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/blob_access_checker_kernel_observer.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_BLOB_ACCESS_CHECKER_KERNEL_OBSERVER_H_\n#define ONEFLOW_CORE_KERNEL_BLOB_ACCESS_CHECKER_KERNEL_OBSERVER_H_\n\n#include \"oneflow/core/kernel/kernel_observer.h\"\n\nnamespace oneflow {\n\nclass BlobAccessCheckerKernelObserver final : public KernelObserver {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BlobAccessCheckerKernelObserver);\n  BlobAccessCheckerKernelObserver() = default;\n  ~BlobAccessCheckerKernelObserver() override = default;\n\n  void WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override;\n\n  void WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override;\n  void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_BLOB_ACCESS_CHECKER_KERNEL_OBSERVER_H_\n"
  },
  {
    "path": "oneflow/core/kernel/blob_tensor_view.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/blob_tensor_view.h\"\n#include \"oneflow/core/register/blob.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nBlobTensorView::BlobTensorView(Blob* blob) : blob_(blob) {}\n\nShapeView BlobTensorView::shape_view() const { return blob_->shape(); }\n\nMutShapeView BlobTensorView::mut_shape_view() { return *blob_->mut_shape_view(); }\n\nconst Stride& BlobTensorView::stride() const { return blob_->stride(); }\n\nDataType BlobTensorView::data_type() const { return blob_->data_type(); }\n\nMemoryFormat BlobTensorView::memory_format() const { return blob_->memory_format(); }\n\nconst MemoryCase& BlobTensorView::mem_case() const { return blob_->mem_case(); }\n\nconst void* BlobTensorView::raw_dptr() const { return blob_->dptr(); }\n\nvoid* BlobTensorView::mut_raw_dptr() { return blob_->mut_dptr(); }\n\nvoid BlobTensorView::Reset(Blob* blob) { blob_ = blob; }\n\nBlob* BlobTensorView::blob() const { return blob_; }\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/blob_tensor_view.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_BLOB_TENSOR_VIEW_H_\n#define ONEFLOW_CORE_KERNEL_BLOB_TENSOR_VIEW_H_\n\n#include \"oneflow/core/framework/user_op_tensor.h\"\n\nnamespace oneflow {\n\nclass Blob;\n\nnamespace user_op {\n\nclass BlobTensorView final : public Tensor {\n public:\n  explicit BlobTensorView(Blob* blob);\n  ~BlobTensorView() = default;\n\n  ShapeView shape_view() const override;\n  MutShapeView mut_shape_view() override;\n  const Stride& stride() const override;\n  DataType data_type() const override;\n  MemoryFormat memory_format() const override;\n  const MemoryCase& mem_case() const override;\n  const void* raw_dptr() const override;\n  void* mut_raw_dptr() override;\n\n  void Reset(Blob* blob);\n  Blob* blob() const;\n\n private:\n  Blob* blob_;\n};\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_BLOB_TENSOR_VIEW_H_\n"
  },
  {
    "path": "oneflow/core/kernel/boxing_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/operator/op_conf_util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/core/common/blocking_counter.h\"\n#include \"oneflow/core/ep/include/primitive/add.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass BoxingKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BoxingKernel);\n  BoxingKernel() = default;\n  ~BoxingKernel() = default;\n\n private:\n  void VirtualKernelInit(KernelContext* ctx) override;\n  void ForwardDataContent(KernelContext* ctx) const override;\n\n  PbRpf<std::string> ibn_0_;\n  PbRpf<std::string> obn_0_;\n};\n\nnamespace {\n\nPbRpf<std::string> ConstructPbRpf(const std::string& s) {\n  PbRpf<std::string> ret;\n  ret.Reserve(1);\n  ret.Add()->assign(s);\n  return ret;\n}\n\ntemplate<typename T>\nvoid CalcSumOfBlobs(KernelContext* ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob,\n                    const PbRpf<std::string>& src_bns, const std::string& dst_bn) {\n  Blob* dst_blob = BnInOp2Blob(dst_bn);\n  std::unique_ptr<ep::primitive::Add> primitive =\n      ep::primitive::NewPrimitive<ep::primitive::AddFactory>(DeviceType::kCPU,\n                                                             dst_blob->data_type());\n  CHECK(primitive);\n  std::vector<const void*> srcs(src_bns.size());\n  FOR_RANGE(size_t, i, 0, src_bns.size()) {\n    Blob* src_blob_i = BnInOp2Blob(src_bns.Get(i));\n    srcs[i] = src_blob_i->dptr<T>();\n  }\n  primitive->Launch(ctx->stream(), srcs.data(), srcs.size(), dst_blob->mut_dptr<T>(),\n                    dst_blob->static_shape().elem_cnt());\n}\n\nvoid CopyFromFirstToOtherBlobs(KernelContext* ctx,\n                               const std::function<Blob*(const std::string&)>& BnInOp2Blob,\n                               const PbRpf<std::string>& bns) {\n  const Blob* blob_0 = BnInOp2Blob(bns.Get(0));\n  FOR_RANGE(size_t, i, 1, bns.size()) {\n    AutoMemcpy(ctx->stream(), BnInOp2Blob(bns.Get(i)), blob_0);\n  }\n}\n\nclass DataContentDesc final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DataContentDesc);\n  DataContentDesc() = delete;\n  ~DataContentDesc() = default;\n\n  DataContentDesc(std::function<Blob*(const std::string&)> BnInOp2Blob,\n                  const PbRpf<std::string>* bns, int32_t axis) {\n    BnInOp2Blob_ = BnInOp2Blob;\n    seg_num_ = BnInOp2Blob(bns->Get(0))->static_shape().Count(0, axis);\n    elem_sum_.assign(bns->size(), 0);\n    FOR_RANGE(size_t, i, 0, elem_sum_.size()) {\n      elem_sum_[i] = BnInOp2Blob(bns->Get(i))->static_shape().Count(axis);\n      if (i > 0) { elem_sum_[i] += elem_sum_[i - 1]; }\n    }\n    bns_ = bns;\n    axis_ = axis;\n  }\n\n  size_t OneElemSize() const { return GetSizeOfDataType(BnInOp2Blob_(bns_->Get(0))->data_type()); }\n\n  int64_t TotalElemNum() const 
{ return seg_num_ * elem_sum_.back(); }\n\n  template<typename DptrT, DptrT* (*GetDptrT)(Blob*)>\n  std::tuple<int64_t, DptrT*> CalcContinuousElemNumStartFrom(int64_t idx) const {\n    std::tuple<int64_t, DptrT*> ret(0, nullptr);\n    int64_t seg_idx = idx / elem_sum_.back();\n    int64_t idx_in_seg = idx % elem_sum_.back();\n    auto elem_sum_it = std::upper_bound(elem_sum_.begin(), elem_sum_.end(), idx_in_seg);\n    CHECK(elem_sum_it != elem_sum_.end());\n    std::get<0>(ret) = *elem_sum_it - idx_in_seg;\n    int64_t bn_idx = elem_sum_it - elem_sum_.begin();\n    int64_t idx_in_blob = idx_in_seg;\n    if (bn_idx > 0) { idx_in_blob -= elem_sum_[bn_idx - 1]; }\n    Blob* blob = BnInOp2Blob_(bns_->Get(bn_idx));\n    std::get<1>(ret) = GetDptrT(blob)\n                       + (seg_idx * blob->static_shape().Count(axis_) + idx_in_blob)\n                             * GetSizeOfDataType(blob->data_type());\n    return ret;\n  }\n\n private:\n  std::function<Blob*(const std::string&)> BnInOp2Blob_;\n  int64_t seg_num_;\n  std::vector<int64_t> elem_sum_;\n  const PbRpf<std::string>* bns_;\n  int32_t axis_;\n};\n\nstatic const char* GetConstDptr(Blob* blob) { return blob->dptr<char>(); }\nstatic char* GetMutDptr(Blob* blob) { return blob->mut_dptr<char>(); }\n\nvoid ConcatSplitPartDataContent(ep::Stream* stream, const DataContentDesc& in_desc,\n                                const DataContentDesc& out_desc, int32_t part_id,\n                                int32_t part_num) {\n  size_t one_elem_size = in_desc.OneElemSize();\n  BalancedSplitter bs(in_desc.TotalElemNum(), part_num);\n  Range range = bs.At(part_id);\n  int64_t in_idx = range.begin();\n  int64_t in_elem_num = 0;\n  const char* in_ptr = nullptr;\n  int64_t out_idx = range.begin();\n  int64_t out_elem_num = 0;\n  char* out_ptr = nullptr;\n\n  while (in_elem_num > 0 || out_elem_num > 0 || in_idx < range.end() || out_idx < range.end()) {\n    if (in_elem_num == 0) {\n      std::tie(in_elem_num, in_ptr) =\n          in_desc.CalcContinuousElemNumStartFrom<const char, GetConstDptr>(in_idx);\n      in_elem_num = std::min(in_elem_num, range.end() - in_idx);\n      if (in_elem_num == 0) { break; }\n      in_idx += in_elem_num;\n    }\n    if (out_elem_num == 0) {\n      std::tie(out_elem_num, out_ptr) =\n          out_desc.CalcContinuousElemNumStartFrom<char, GetMutDptr>(out_idx);\n      out_elem_num = std::min(out_elem_num, range.end() - out_idx);\n      if (out_elem_num == 0) { break; }\n      out_idx += out_elem_num;\n    }\n    int64_t copy_elem_num = std::min(in_elem_num, out_elem_num);\n    size_t copy_size = copy_elem_num * one_elem_size;\n    Memcpy<DeviceType::kCPU>(stream, out_ptr, in_ptr, copy_size);\n    in_elem_num -= copy_elem_num;\n    out_elem_num -= copy_elem_num;\n    in_ptr += copy_size;\n    out_ptr += copy_size;\n  }\n  CHECK_EQ(in_elem_num, 0);\n  CHECK_EQ(out_elem_num, 0);\n  CHECK_EQ(in_idx, range.end());\n  CHECK_EQ(out_idx, range.end());\n}\n\nvoid ConcatSplitDataContent(ep::Stream* stream,\n                            const std::function<Blob*(const std::string&)>& BnInOp2Blob,\n                            const PbRpf<std::string>& concat_bns, int32_t concat_axis,\n                            const PbRpf<std::string>& split_bns, int32_t split_axis) {\n  DataContentDesc in_desc(BnInOp2Blob, &concat_bns, concat_axis);\n  DataContentDesc out_desc(BnInOp2Blob, &split_bns, split_axis);\n  CHECK_EQ(in_desc.TotalElemNum(), out_desc.TotalElemNum());\n  CHECK_EQ(in_desc.OneElemSize(), out_desc.OneElemSize());\n  static const 
size_t min_byte_one_part = 128;\n  int32_t part_num = in_desc.TotalElemNum() * in_desc.OneElemSize() / min_byte_one_part;\n  part_num = std::min(part_num, Singleton<ThreadPool>::Get()->thread_num());\n  if (part_num >= 2) {\n    BlockingCounter bc(part_num);\n    FOR_RANGE(int32_t, part_id, 0, part_num) {\n      Singleton<ThreadPool>::Get()->AddWork(\n          [stream, &in_desc, &out_desc, part_id, &part_num, &bc]() {\n            ConcatSplitPartDataContent(stream, in_desc, out_desc, part_id, part_num);\n            bc.Decrease();\n          });\n    }\n    bc.WaitForeverUntilCntEqualZero();\n  } else {\n    ConcatSplitPartDataContent(stream, in_desc, out_desc, 0, 1);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nvoid BoxingKernel<T>::VirtualKernelInit(KernelContext* ctx) {\n  const std::string& ibn_0 = op_attribute().input_bns(0);\n  const std::string& obn_0 = op_attribute().output_bns(0);\n  ibn_0_ = ConstructPbRpf(ibn_0);\n  obn_0_ = ConstructPbRpf(obn_0);\n}\n\ntemplate<typename T>\nvoid BoxingKernel<T>::ForwardDataContent(KernelContext* ctx) const {\n  const BoxingOpConf& boxing_conf = op_conf().boxing_conf();\n  ep::Stream* stream = ctx->stream();\n  const auto BnInOp2Blob = [ctx](const std::string& bn) { return ctx->BnInOp2Blob(bn); };\n  if (boxing_conf.in_box_case() == BoxingOpConf::kConcatBox) {\n    if (boxing_conf.out_box_case() == BoxingOpConf::kSplitBox) {\n      ConcatSplitDataContent(stream, BnInOp2Blob, op_attribute().input_bns(),\n                             boxing_conf.concat_box().axis(), op_attribute().output_bns(),\n                             boxing_conf.split_box().axis());\n    } else if (boxing_conf.out_box_case() == BoxingOpConf::kCloneBox) {\n      ConcatSplitDataContent(stream, BnInOp2Blob, op_attribute().input_bns(),\n                             boxing_conf.concat_box().axis(), obn_0_, 0);\n      CopyFromFirstToOtherBlobs(ctx, BnInOp2Blob, op_attribute().output_bns());\n    } else {\n      UNIMPLEMENTED();\n    }\n  } else if (boxing_conf.in_box_case() == BoxingOpConf::kAddBox) {\n    if (boxing_conf.out_box_case() == BoxingOpConf::kSplitBox) {\n      CalcSumOfBlobs<T>(ctx, BnInOp2Blob, op_attribute().input_bns(), \"middle\");\n      ConcatSplitDataContent(stream, BnInOp2Blob, ConstructPbRpf(\"middle\"), 0,\n                             op_attribute().output_bns(), boxing_conf.split_box().axis());\n    } else if (boxing_conf.out_box_case() == BoxingOpConf::kCloneBox) {\n      CalcSumOfBlobs<T>(ctx, BnInOp2Blob, op_attribute().input_bns(), obn_0_.Get(0));\n      CopyFromFirstToOtherBlobs(ctx, BnInOp2Blob, op_attribute().output_bns());\n    } else {\n      UNIMPLEMENTED();\n    }\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kBoxingConf, BoxingKernel,\n                               ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/boxing_zeros_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/kernel/kernel_context.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n\nnamespace oneflow {\n\nclass BoxingZerosKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BoxingZerosKernel);\n  BoxingZerosKernel() = default;\n  ~BoxingZerosKernel() override = default;\n\n private:\n  void VirtualKernelInit(KernelContext* ctx) override;\n  void ForwardDataContent(KernelContext* ctx) const override;\n\n  std::unique_ptr<ep::primitive::Memset> primitive_;\n};\n\nvoid BoxingZerosKernel::VirtualKernelInit(KernelContext* ctx) {\n  primitive_ =\n      ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->stream()->device_type());\n  CHECK(primitive_);\n}\n\nvoid BoxingZerosKernel::ForwardDataContent(KernelContext* ctx) const {\n  Blob* out = ctx->BnInOp2Blob(\"out\");\n  primitive_->Launch(ctx->stream(), out->mut_dptr(), 0, out->ByteSizeOfBlobBody());\n}\n\nREGISTER_KERNEL(OperatorConf::kBoxingZerosConf, BoxingZerosKernel);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/broadcast_to_compatible_with_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nclass BroadcastToCompatibleWithKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastToCompatibleWithKernel);\n  BroadcastToCompatibleWithKernel() = default;\n  ~BroadcastToCompatibleWithKernel() = default;\n\n private:\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\ntemplate<DeviceType device_type, typename T>\nvoid BroadcastToCompatibleWithKernel<device_type, T>::ForwardDataContent(KernelContext* ctx) const {\n  const Blob* x = ctx->BnInOp2Blob(\"x\");\n  Blob* y = ctx->BnInOp2Blob(\"y\");\n  const auto& broadcast_axes =\n      this->kernel_conf().broadcast_to_compatible_with_conf().broadcast_axes();\n  int64_t num_axes = y->shape().NumAxes();\n  Shape x_extend_shape = CreateLeftExtendedShape(x->shape(), num_axes);\n  FOR_RANGE(int64_t, i, 0, num_axes) {\n    if (std::find(broadcast_axes.begin(), broadcast_axes.end(), i) == broadcast_axes.end()) {\n      CHECK_EQ(x_extend_shape.At(i), y->shape().At(i));\n    } else {\n      CHECK_EQ(x_extend_shape.At(i), 1);\n    }\n  }\n  NdarrayUtil<device_type, T>::BroadcastTo(ctx->stream(), XpuVarNdarray<T>(y, num_axes),\n                                           XpuVarNdarray<const T>(x, num_axes));\n}\n\n#define REGISTTER_BROADCAST_TO_COMPATIBLE_WITH_KERNEL(device_type_v, dtype_pair)                 \\\n  REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(                                                         \\\n      OperatorConf::kBroadcastToCompatibleWithConf, device_type_v, OF_PP_PAIR_FIRST(dtype_pair), \\\n      BroadcastToCompatibleWithKernel<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTTER_BROADCAST_TO_COMPATIBLE_WITH_KERNEL, DEVICE_TYPE_SEQ,\n                                 ARITHMETIC_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ)\n\n#if defined(WITH_CUDA)\nREGISTTER_BROADCAST_TO_COMPATIBLE_WITH_KERNEL(DeviceType::kCUDA, (float16, DataType::kFloat16))\n#endif\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/callback_notify_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/common/buffer_manager.h\"\n#include \"oneflow/core/job/job_instance.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/common/buffer_manager.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass CallbackNotifyKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CallbackNotifyKernel);\n  CallbackNotifyKernel() = default;\n  ~CallbackNotifyKernel() = default;\n\n private:\n  bool IsStateless() const override { return false; }\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\ntemplate<typename T>\nvoid CallbackNotifyKernel<T>::ForwardDataContent(KernelContext* ctx) const {\n  auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<JobInstance>>>::Get();\n  std::string buffer_name;\n  CHECK(this->op_conf().callback_notify_conf().has_job_name());\n  buffer_name = GetCallbackNotifierBufferName(this->op_conf().callback_notify_conf().job_name());\n  std::shared_ptr<JobInstance> job_instance;\n  BufferStatus buffer_status = buffer_mgr->Get(buffer_name)->TryReceive(&job_instance);\n  CHECK_NE(buffer_status, kBufferStatusEmpty);\n  if (buffer_status == kBufferStatusSuccess) { job_instance->Finish(); }\n}\n\nADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kCallbackNotifyConf, CallbackNotifyKernel,\n                               INT_DATA_TYPE_SEQ);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/case_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/case_kernel.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nvoid CaseKernel<T>::VirtualKernelInit(KernelContext* ctx) {\n  ctx->set_state(std::make_shared<CaseStatus>());\n}\n\ntemplate<typename T>\nvoid CaseKernel<T>::ForwardDataContent(KernelContext* ctx) const {\n  CaseStatus* const case_status = CHECK_NOTNULL(dynamic_cast<CaseStatus*>(ctx->state().get()));\n  if (case_status->cmd == kCaseCmdHandleInput) {\n    int64_t cur_selected_id = static_cast<int64_t>(ctx->BnInOp2Blob(\"in\")->dptr<T>()[0]);\n    case_status->select_id2request_cnt[cur_selected_id] += 1;\n  } else if (case_status->cmd == kCaseCmdHandleOutput) {\n    int64_t cur_selected_id = case_status->cur_selected_id;\n    CHECK_GT(case_status->select_id2request_cnt[cur_selected_id], 0);\n    case_status->select_id2request_cnt[cur_selected_id] -= 1;\n    if (case_status->select_id2request_cnt[cur_selected_id] == 0) {\n      case_status->select_id2request_cnt.erase(cur_selected_id);\n    }\n    *(ctx->BnInOp2Blob(GenRepeatedBn(\"out\", cur_selected_id))->mut_dptr<T>()) = cur_selected_id;\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kCaseConf, CaseKernel, INT_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/case_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_CASE_KERNEL_H_\n#define ONEFLOW_CORE_KERNEL_CASE_KERNEL_H_\n\n#include \"oneflow/core/kernel/kernel.h\"\n\nnamespace oneflow {\n\nenum CaseCmd {\n  kCaseCmdInvalid = 0,\n  kCaseCmdHandleInput = 1,\n  kCaseCmdHandleOutput = 2,\n};\n\nstruct CaseStatus final : public KernelState {\n  CaseStatus() : cmd(kCaseCmdInvalid), cur_selected_id(-1) {}\n  ~CaseStatus() = default;\n\n  CaseCmd cmd;\n  int64_t cur_selected_id;\n  HashMap<int64_t, int64_t> select_id2request_cnt;\n};\n\ntemplate<typename T>\nclass CaseKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CaseKernel);\n  CaseKernel() = default;\n  ~CaseKernel() override = default;\n\n private:\n  void VirtualKernelInit(KernelContext* ctx) override;\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_CASE_KERNEL_H_\n"
  },
  {
    "path": "oneflow/core/kernel/chain_kernel_observer.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/chain_kernel_observer.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n\nnamespace oneflow {\n\nvoid ChainKernelObserver::WillForward(KernelContext* kernel_ctx, const Kernel* kernel) {\n  for (const auto& observer : kernel_observers_) { observer->WillForward(kernel_ctx, kernel); }\n}\n\nvoid ChainKernelObserver::DidForward(KernelContext* kernel_ctx, const Kernel* kernel) {\n  for (const auto& observer : kernel_observers_) { observer->DidForward(kernel_ctx, kernel); }\n}\n\nvoid ChainKernelObserver::WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) {\n  for (const auto& observer : kernel_observers_) {\n    observer->WillForwardHeader(kernel_ctx, kernel);\n  }\n}\n\nvoid ChainKernelObserver::DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) {\n  for (const auto& observer : kernel_observers_) { observer->DidForwardHeader(kernel_ctx, kernel); }\n}\n\nvoid ChainKernelObserver::WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) {\n  for (const auto& observer : kernel_observers_) {\n    observer->WillForwardDataContent(kernel_ctx, kernel);\n  }\n}\n\nvoid ChainKernelObserver::DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) {\n  for (const auto& observer : kernel_observers_) {\n    observer->DidForwardDataContent(kernel_ctx, kernel);\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/chain_kernel_observer.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_CHAIN_KERNEL_OBSERVER_H_\n#define ONEFLOW_CORE_KERNEL_CHAIN_KERNEL_OBSERVER_H_\n\n#include \"oneflow/core/kernel/kernel_observer.h\"\n\nnamespace oneflow {\n\nclass ChainKernelObserver final : public KernelObserver {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ChainKernelObserver);\n  explicit ChainKernelObserver(std::vector<std::shared_ptr<KernelObserver>> kernel_observers)\n      : kernel_observers_(std::move(kernel_observers)) {}\n  ~ChainKernelObserver() override = default;\n\n  void WillForward(KernelContext* kernel_ctx, const Kernel* kernel) override;\n  void DidForward(KernelContext* kernel_ctx, const Kernel* kernel) override;\n\n  void WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override;\n  void DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override;\n\n  void WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override;\n  void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override;\n\n private:\n  std::vector<std::shared_ptr<KernelObserver>> kernel_observers_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_CHAIN_KERNEL_OBSERVER_H_\n"
  },
  {
    "path": "oneflow/core/kernel/collective_boxing_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/job/collective_boxing/scheduler.h\"\n#include \"oneflow/core/common/blocking_counter.h\"\n#include \"oneflow/core/graph/boxing/collective_boxing_util.h\"\n#include \"oneflow/core/lazy/actor/collective_boxing_actor_context.h\"\n\nnamespace oneflow {\n\nusing namespace boxing::collective;\n\nnamespace {\n\nCollectiveBoxingActorContext* GetCollectiveBoxingActorContext(KernelContext* kernel_ctx) {\n  auto* actor_context_provider = CHECK_NOTNULL(dynamic_cast<ActorContextProvider*>(kernel_ctx));\n  return CHECK_NOTNULL(\n      dynamic_cast<CollectiveBoxingActorContext*>(actor_context_provider->GetActorContext()));\n}\n\nclass CollectiveBoxingKernelState final : public KernelState {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingKernelState);\n  explicit CollectiveBoxingKernelState(const RankDesc& rank_desc)\n      : request_handle_(Singleton<Scheduler>::Get()->CreateRequestHandle(rank_desc)) {}\n  ~CollectiveBoxingKernelState() override {\n    Singleton<Scheduler>::Get()->DestroyRequestHandle(request_handle_);\n  }\n  RequestHandle* request_handle() { return request_handle_; }\n\n private:\n  RequestHandle* request_handle_ = nullptr;\n};\n\nclass CollectiveBoxingGenericKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingGenericKernel);\n  CollectiveBoxingGenericKernel() = default;\n  ~CollectiveBoxingGenericKernel() override = default;\n\n private:\n  void VirtualKernelInit(KernelContext* ctx) override;\n  bool IsKernelLaunchSynchronized() const override { return false; }\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\nvoid CollectiveBoxingGenericKernel::VirtualKernelInit(KernelContext* ctx) {\n  const RankDesc& rank_desc = this->op_conf().collective_boxing_generic_conf().rank_desc();\n  ctx->set_state(std::make_shared<CollectiveBoxingKernelState>(rank_desc));\n}\n\nvoid CollectiveBoxingGenericKernel::ForwardDataContent(KernelContext* ctx) const {\n  RequestHandle* request_handle =\n      CHECK_NOTNULL(dynamic_cast<CollectiveBoxingKernelState*>(ctx->state().get()))\n          ->request_handle();\n  const void* send_buff = nullptr;\n  void* recv_buff = nullptr;\n  const RankDesc& rank_desc = this->op_conf().collective_boxing_generic_conf().rank_desc();\n  const DataType data_type = rank_desc.op_desc().data_type();\n  if (GenericOpHasInput(rank_desc)) {\n    const Blob* in = ctx->BnInOp2Blob(\"in\");\n    CHECK_EQ(in->data_type(), data_type);\n    CHECK(in->shape() == ShapeView(GenericOpGetInputShape(rank_desc)));\n    send_buff = in->dptr();\n  }\n  if (GenericOpHasOutput(rank_desc)) {\n    Blob* out = ctx->BnInOp2Blob(\"out\");\n    CHECK_EQ(out->data_type(), data_type);\n    CHECK(out->shape() == ShapeView(GenericOpGetOutputShape(rank_desc)));\n    recv_buff = out->mut_dptr();\n  }\n  auto* actor_ctx = GetCollectiveBoxingActorContext(ctx);\n  
actor_ctx->Schedule(request_handle, send_buff, recv_buff);\n}\n\nREGISTER_KERNEL(OperatorConf::kCollectiveBoxingGenericConf, CollectiveBoxingGenericKernel);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/collective_boxing_pack_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n\nnamespace oneflow {\n\nclass CollectiveBoxingPackKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingPackKernel);\n  CollectiveBoxingPackKernel() = default;\n  ~CollectiveBoxingPackKernel() override = default;\n\n private:\n  bool IsStateless() const override { return false; }\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\nvoid CollectiveBoxingPackKernel::ForwardDataContent(KernelContext* ctx) const {\n  const Blob* in = ctx->BnInOp2Blob(\"in\");\n  Blob* out = ctx->BnInOp2Blob(\"out\");\n  const CollectiveBoxingPackOpConf& pack_conf = this->op_conf().collective_boxing_pack_conf();\n  const int64_t num_ranks = pack_conf.num_ranks();\n  const Shape logical_shape(pack_conf.logical_shape());\n  const bool need_transpose = !((pack_conf.dst_sbp_parallel().has_split_parallel()\n                                 && pack_conf.dst_sbp_parallel().split_parallel().axis() == 0)\n                                || pack_conf.dst_sbp_parallel().has_broadcast_parallel()\n                                || pack_conf.dst_sbp_parallel().has_partial_sum_parallel());\n  if (need_transpose) {\n    const int64_t dst_split_axis = pack_conf.dst_sbp_parallel().split_parallel().axis();\n    DimVector transpose_in_dim_vec = logical_shape.dim_vec();\n    if (pack_conf.src_sbp_parallel().has_split_parallel()) {\n      const int64_t src_split_axis = pack_conf.src_sbp_parallel().split_parallel().axis();\n      transpose_in_dim_vec[src_split_axis] = transpose_in_dim_vec.at(src_split_axis) / num_ranks;\n    }\n    CHECK_EQ(transpose_in_dim_vec.at(dst_split_axis) % num_ranks, 0);\n    transpose_in_dim_vec[dst_split_axis] = transpose_in_dim_vec.at(dst_split_axis) / num_ranks;\n    transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + dst_split_axis, num_ranks);\n    std::vector<int32_t> perm;\n    perm.emplace_back(dst_split_axis);\n    FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) {\n      if (i != dst_split_axis) { perm.emplace_back(i); }\n    }\n    auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n        ctx->stream()->device_type(), transpose_in_dim_vec.size());\n    CHECK(transpose);\n    transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(),\n                      transpose_in_dim_vec.data(), in->dptr(), perm.data(), out->mut_dptr());\n  } else {\n    AutoMemcpy(ctx->stream(), out, in);\n  }\n}\n\nREGISTER_KERNEL(OperatorConf::kCollectiveBoxingPackConf, CollectiveBoxingPackKernel);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/collective_boxing_unpack_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n\nnamespace oneflow {\n\nclass CollectiveBoxingUnpackKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingUnpackKernel);\n  CollectiveBoxingUnpackKernel() = default;\n  ~CollectiveBoxingUnpackKernel() override = default;\n\n private:\n  bool IsStateless() const override { return false; }\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\nvoid CollectiveBoxingUnpackKernel::ForwardDataContent(KernelContext* ctx) const {\n  const Blob* in = ctx->BnInOp2Blob(\"in\");\n  Blob* out = ctx->BnInOp2Blob(\"out\");\n  const CollectiveBoxingUnpackOpConf& unpack_conf = this->op_conf().collective_boxing_unpack_conf();\n  const int64_t num_ranks = unpack_conf.num_ranks();\n  const Shape logical_shape(unpack_conf.logical_shape());\n  // skip 0size tensor boxing\n  if (logical_shape.elem_cnt() == 0) { return; }\n  const bool need_transpose = !((unpack_conf.src_sbp_parallel().has_split_parallel()\n                                 && unpack_conf.src_sbp_parallel().split_parallel().axis() == 0)\n                                || unpack_conf.src_sbp_parallel().has_broadcast_parallel()\n                                || unpack_conf.src_sbp_parallel().has_partial_sum_parallel());\n  if (need_transpose) {\n    const int64_t src_split_axis = unpack_conf.src_sbp_parallel().split_parallel().axis();\n    DimVector transpose_in_dim_vec = logical_shape.dim_vec();\n    CHECK_EQ(transpose_in_dim_vec.at(src_split_axis) % num_ranks, 0);\n    transpose_in_dim_vec[src_split_axis] = transpose_in_dim_vec.at(src_split_axis) / num_ranks;\n    if (unpack_conf.dst_sbp_parallel().has_split_parallel()) {\n      const int64_t dst_split_axis = unpack_conf.dst_sbp_parallel().split_parallel().axis();\n      CHECK_EQ(transpose_in_dim_vec.at(dst_split_axis) % num_ranks, 0);\n      transpose_in_dim_vec[dst_split_axis] = transpose_in_dim_vec.at(dst_split_axis) / num_ranks;\n    }\n    transpose_in_dim_vec.insert(transpose_in_dim_vec.begin(), num_ranks);\n    std::vector<int32_t> perm;\n    FOR_RANGE(int64_t, i, 1, transpose_in_dim_vec.size()) { perm.emplace_back(i); }\n    perm.insert(perm.begin() + src_split_axis, 0);\n    auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n        ctx->stream()->device_type(), transpose_in_dim_vec.size());\n    CHECK(transpose);\n    transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(),\n                      transpose_in_dim_vec.data(), in->dptr(), perm.data(), out->mut_dptr());\n  } else {\n    AutoMemcpy(ctx->stream(), out, in);\n  }\n}\n\nREGISTER_KERNEL(OperatorConf::kCollectiveBoxingUnpackConf, CollectiveBoxingUnpackKernel);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/constant_like_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n\nnamespace oneflow {\n\nclass ConstantLikeKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ConstantLikeKernel);\n  ConstantLikeKernel() : is_init_(false) {}\n  ~ConstantLikeKernel() = default;\n\n private:\n  mutable bool is_init_;\n\n  void ForwardDataContent(KernelContext* ctx) const override {\n    if (is_init_) { return; }\n    Blob* out_blob = ctx->BnInOp2Blob(\"out\");\n    Scalar value;\n    const auto& conf = this->op_conf().constant_like_conf();\n    if (conf.has_int_operand()) {\n      value = Scalar(conf.int_operand());\n    } else if (conf.has_float_operand()) {\n      value = Scalar(conf.float_operand());\n    } else {\n      UNIMPLEMENTED();\n    }\n    std::unique_ptr<ep::primitive::Fill> primitive =\n        ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->stream()->device_type(),\n                                                                out_blob->data_type());\n    CHECK(primitive);\n    primitive->Launch(ctx->stream(), out_blob->mut_dptr(), value,\n                      out_blob->static_shape().elem_cnt());\n    is_init_ = true;\n  }\n};\n\nREGISTER_KERNEL(OperatorConf::kConstantLikeConf, ConstantLikeKernel);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/cpu_check_numerics_kernel_observer.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_CPU_CHECK_NUMERICS_KERNEL_OBSERVER_H_\n#define ONEFLOW_CORE_KERNEL_CPU_CHECK_NUMERICS_KERNEL_OBSERVER_H_\n\n#include \"oneflow/core/kernel/kernel_observer.h\"\n\nnamespace oneflow {\n\nclass CpuCheckNumericsKernelObserver final : public KernelObserver {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuCheckNumericsKernelObserver);\n  CpuCheckNumericsKernelObserver() = default;\n  ~CpuCheckNumericsKernelObserver() override = default;\n\n  void DidForwardDataContent(KernelContext* ctx, const Kernel* kernel) override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_CPU_CHECK_NUMERICS_KERNEL_OBSERVER_H_\n"
  },
  {
    "path": "oneflow/core/kernel/cpu_numerics_kernel_observer.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/cpu_check_numerics_kernel_observer.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nbool HasNotFinite(const int64_t elem_cnt, const T* data_ptr) {\n  FOR_RANGE(int64_t, i, 0, elem_cnt) {\n    if (!std::isfinite(data_ptr[i])) { return true; }\n  }\n  return false;\n}\n\nbool HasNotFiniteCpu(ep::Stream* stream, const Blob* blob) {\n  const DataType dtype = blob->data_type();\n  const int64_t elem_cnt = blob->shape().elem_cnt();\n  if (dtype == kFloat) {\n    return HasNotFinite<float>(elem_cnt, blob->dptr<float>());\n  } else if (dtype == kDouble) {\n    return HasNotFinite<double>(elem_cnt, blob->dptr<double>());\n  } else {\n    return false;\n  }\n}\n\nvoid DumpBlob(KernelContext* ctx, const std::string& bn) {\n  Blob* blob = ctx->BnInOp2Blob(bn);\n  if (blob != nullptr) {\n    std::ofstream ofs(bn);\n    ofs.write(blob->dptr<char>(), blob->ByteSizeOfBlobBody());\n  }\n}\n\nvoid DumpBlobs(KernelContext* ctx, const Kernel* kernel) {\n  for (const auto& obn : kernel->op_attribute().output_bns()) { DumpBlob(ctx, obn); }\n  for (const auto& ibn : kernel->op_attribute().input_bns()) { DumpBlob(ctx, ibn); }\n}\n\n}  // namespace\n\nvoid CpuCheckNumericsKernelObserver::DidForwardDataContent(KernelContext* ctx,\n                                                           const Kernel* kernel) {\n  for (const auto& obn : kernel->op_attribute().output_bns()) {\n    Blob* blob = ctx->BnInOp2Blob(obn);\n    if (blob != nullptr) {\n      bool has_not_finite = HasNotFiniteCpu(ctx->stream(), blob);\n      if (has_not_finite\n          && ParseBooleanFromEnv(\"ONEFLOW_DEBUG_KERNEL_SYNC_CHECK_NUMERICS_DUMP\", false)) {\n        DumpBlobs(ctx, kernel);\n      }\n      CHECK(!has_not_finite) << kernel->op_conf().name() << \" : \" << obn << \" has nan or inf\";\n    }\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/critical_section_callback_tick_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/common/buffer_manager.h\"\n#include \"oneflow/core/job/critical_section_instance.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/common/buffer_manager.h\"\n\nnamespace oneflow {\n\nclass CriticalSectionCallbackTickKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CriticalSectionCallbackTickKernel);\n  CriticalSectionCallbackTickKernel() = default;\n  ~CriticalSectionCallbackTickKernel() = default;\n\n private:\n  bool IsStateless() const override { return false; }\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\nvoid CriticalSectionCallbackTickKernel::ForwardDataContent(KernelContext* ctx) const {\n  auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::Get();\n  CHECK(op_conf().has_critical_section_callback_tick_conf());\n  const std::string& buffer_name = op_conf().critical_section_callback_tick_conf().buffer_name();\n  std::shared_ptr<CriticalSectionInstance> critical_section_instance;\n  BufferStatus buffer_status = buffer_mgr->Get(buffer_name)->TryReceive(&critical_section_instance);\n  CHECK_EQ(buffer_status, kBufferStatusSuccess);\n  critical_section_instance->Finish();\n}\n\nREGISTER_KERNEL(OperatorConf::kCriticalSectionCallbackTickConf, CriticalSectionCallbackTickKernel);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/critical_section_wait_tick_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/common/buffer_manager.h\"\n#include \"oneflow/core/job/critical_section_instance.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/common/buffer_manager.h\"\n\nnamespace oneflow {\n\nclass CriticalSectionWaitTickKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CriticalSectionWaitTickKernel);\n  CriticalSectionWaitTickKernel() = default;\n  ~CriticalSectionWaitTickKernel() = default;\n\n private:\n  bool IsStateless() const override { return false; }\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\nvoid CriticalSectionWaitTickKernel::ForwardDataContent(KernelContext* ctx) const {\n  auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::Get();\n  CHECK(this->op_conf().has_critical_section_wait_tick_conf());\n  const std::string& buffer_name = this->op_conf().critical_section_wait_tick_conf().buffer_name();\n  std::shared_ptr<CriticalSectionInstance> critical_section_instance;\n  BufferStatus buffer_status = buffer_mgr->Get(buffer_name)->Pull(&critical_section_instance);\n  CHECK_EQ(buffer_status, kBufferStatusSuccess);\n}\n\nREGISTER_KERNEL(OperatorConf::kCriticalSectionWaitTickConf, CriticalSectionWaitTickKernel);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/cuda_check_numerics_kernel_observer.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/cuda_check_numerics_kernel_observer.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__device__ bool IsNotFinite(T x) {\n  return !isfinite(x);\n}\n\ntemplate<>\n__device__ bool IsNotFinite<half>(half x) {\n#if __CUDA_ARCH__ >= 530\n  return (__hisinf(x) || __hisnan(x));\n#else\n  __trap();\n  return true;\n#endif\n}\n\ntemplate<typename T>\n__global__ void HasNotFiniteGpuKernel(const int64_t n, const T* x, volatile bool* has_not_finite) {\n  if (*has_not_finite) { return; }\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) {\n    if (IsNotFinite(x[i])) {\n      *has_not_finite = true;\n      return;\n    }\n  }\n}\n\ntemplate<typename T>\nbool HasNotFinite(ep::Stream* stream, const int64_t elem_cnt, const T* data_ptr,\n                  bool* has_not_finite_host, bool* has_not_finite_device) {\n  OF_CUDA_CHECK(cudaMemsetAsync(has_not_finite_device, 0, sizeof(bool),\n                                stream->As<ep::CudaStream>()->cuda_stream()));\n  HasNotFiniteGpuKernel<T>\n      <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n         stream->As<ep::CudaStream>()->cuda_stream()>>>(elem_cnt, data_ptr, has_not_finite_device);\n  OF_CUDA_CHECK(cudaMemcpyAsync(has_not_finite_host, has_not_finite_device, sizeof(bool),\n                                cudaMemcpyDefault, stream->As<ep::CudaStream>()->cuda_stream()));\n  OF_CUDA_CHECK(cudaStreamSynchronize(stream->As<ep::CudaStream>()->cuda_stream()));\n  return *has_not_finite_host;\n}\n\nbool HasNotFiniteGpu(ep::Stream* stream, const Blob* blob, bool* has_not_finite_host,\n                     bool* has_not_finite_device) {\n  auto* cuda_stream = stream->As<ep::CudaStream>();\n  const DataType dtype = blob->data_type();\n  const int64_t elem_cnt = blob->shape().elem_cnt();\n  if (elem_cnt == 0) { return false; }\n  if (dtype == kFloat) {\n    return HasNotFinite<float>(stream, elem_cnt, blob->dptr<float>(), has_not_finite_host,\n                               has_not_finite_device);\n  } else if (dtype == kDouble) {\n    return HasNotFinite<double>(stream, elem_cnt, blob->dptr<double>(), has_not_finite_host,\n                                has_not_finite_device);\n  } else if (dtype == kFloat16) {\n    if (cuda_stream->cuda_arch() >= 530) {\n      return HasNotFinite<half>(stream, elem_cnt, blob->dptr<half>(), has_not_finite_host,\n                                has_not_finite_device);\n    } else {\n      LOG(FATAL) << \"use half need nvcc arch >= 530\";\n      return true;\n    }\n  } else {\n    return false;\n  }\n}\n\nvoid DumpBlob(KernelContext* ctx, const std::string& bn) {\n  Blob* blob = ctx->BnInOp2Blob(bn);\n  if (blob != nullptr) {\n    std::vector<char> buffer(blob->ByteSizeOfBlobBody());\n    OF_CUDA_CHECK(\n        cudaMemcpy(buffer.data(), blob->dptr(), 
blob->ByteSizeOfBlobBody(), cudaMemcpyDefault));\n    OF_CUDA_CHECK(cudaDeviceSynchronize());\n    std::ofstream ofs(bn);\n    ofs.write(buffer.data(), blob->ByteSizeOfBlobBody());\n  }\n}\n\nvoid DumpBlobs(KernelContext* ctx, const Kernel* kernel) {\n  for (const auto& obn : kernel->op_attribute().output_bns()) { DumpBlob(ctx, obn); }\n  for (const auto& ibn : kernel->op_attribute().input_bns()) { DumpBlob(ctx, ibn); }\n}\n\n}  // namespace\n\nCudaCheckNumericsKernelObserver::CudaCheckNumericsKernelObserver()\n    : has_not_finite_host_(nullptr), has_not_finite_device_(nullptr) {\n  OF_CUDA_CHECK(cudaGetDevice(&device_id_));\n  OF_CUDA_CHECK(cudaMallocHost(&has_not_finite_host_, sizeof(bool)));\n  OF_CUDA_CHECK(cudaMalloc(&has_not_finite_device_, sizeof(bool)));\n}\n\nCudaCheckNumericsKernelObserver::~CudaCheckNumericsKernelObserver() {\n  CudaCurrentDeviceGuard guard(device_id_);\n  OF_CUDA_CHECK(cudaFreeHost(has_not_finite_host_));\n  OF_CUDA_CHECK(cudaFree(has_not_finite_device_));\n}\n\nvoid CudaCheckNumericsKernelObserver::DidForwardDataContent(KernelContext* ctx,\n                                                            const Kernel* kernel) {\n  for (const auto& obn : kernel->op_attribute().output_bns()) {\n    Blob* blob = ctx->BnInOp2Blob(obn);\n    if (blob != nullptr) {\n      bool has_not_finite =\n          HasNotFiniteGpu(ctx->stream(), blob, has_not_finite_host_, has_not_finite_device_);\n      if (has_not_finite\n          && ParseBooleanFromEnv(\"ONEFLOW_DEBUG_KERNEL_SYNC_CHECK_NUMERICS_DUMP\", false)) {\n        DumpBlobs(ctx, kernel);\n      }\n      CHECK(!has_not_finite) << kernel->op_conf().name() << \" : \" << obn << \" has nan or inf\";\n    }\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/cuda_check_numerics_kernel_observer.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_CUDA_CHECK_NUMERICS_KERNEL_OBSERVER_H_\n#define ONEFLOW_CORE_KERNEL_CUDA_CHECK_NUMERICS_KERNEL_OBSERVER_H_\n\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/kernel/kernel_observer.h\"\n\nnamespace oneflow {\n\nclass CudaCheckNumericsKernelObserver final : public KernelObserver {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaCheckNumericsKernelObserver);\n  CudaCheckNumericsKernelObserver();\n  ~CudaCheckNumericsKernelObserver() override;\n\n  void DidForwardDataContent(KernelContext* ctx, const Kernel* kernel) override;\n\n private:\n  bool* has_not_finite_host_;\n  bool* has_not_finite_device_;\n  int device_id_;\n};\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_CORE_KERNEL_CUDA_CHECK_NUMERICS_KERNEL_OBSERVER_H_\n"
  },
  {
    "path": "oneflow/core/kernel/cuda_graph_support.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_KERNEL_CUDA_GRAPH_SUPPORT_H_\n#define ONEFLOW_CORE_KERNEL_CUDA_GRAPH_SUPPORT_H_\n\nnamespace oneflow {\n\nnamespace user_op {\n\nclass KernelInitContext;\nclass KernelComputeContext;\nclass OpKernelState;\nclass OpKernelCache;\n\nclass CudaGraphSupport {\n public:\n  CudaGraphSupport() = default;\n  virtual ~CudaGraphSupport() = default;\n\n  virtual bool IsCudaGraphSupported(KernelInitContext* ctx, OpKernelState* state) const {\n    return true;\n  }\n\n  virtual bool IsReadyForCapture(KernelComputeContext* ctx, OpKernelState* state,\n                                 const OpKernelCache* cache) const {\n    return true;\n  }\n};\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_CUDA_GRAPH_SUPPORT_H_\n"
  },
  {
    "path": "oneflow/core/kernel/distribute_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/kernel/kernel_context.h\"\n\nnamespace oneflow {\n\nclass DistributeAddKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DistributeAddKernel);\n  DistributeAddKernel() = default;\n  ~DistributeAddKernel() = default;\n\n private:\n  void ForwardDataContent(KernelContext* ctx) const override;\n  const Blob* GetInBlob(KernelContext* ctx) const;\n};\n\nvoid DistributeAddKernel::ForwardDataContent(KernelContext* ctx) const {\n  AutoMemcpy(ctx->stream(), ctx->BnInOp2Blob(\"out\"), GetInBlob(ctx));\n}\n\nconst Blob* DistributeAddKernel::GetInBlob(KernelContext* ctx) const {\n  const Blob* in_blob = nullptr;\n  FOR_RANGE(int, i, 0, this->op_attribute().input_bns().size()) {\n    const Blob* cur_blob = ctx->BnInOp2Blob(this->op_attribute().input_bns().Get(i));\n    if (cur_blob != nullptr && cur_blob != in_blob) {\n      CHECK_ISNULL(in_blob);\n      in_blob = cur_blob;\n    }\n  }\n  return in_blob;\n}\n\nREGISTER_KERNEL(OperatorConf::kDistributeAddConf, DistributeAddKernel);\n\nclass DistributeCloneKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DistributeCloneKernel);\n  DistributeCloneKernel() = default;\n  ~DistributeCloneKernel() = default;\n\n private:\n  void ForwardDataContent(KernelContext* ctx) const override;\n  Blob* GetOutBlob(KernelContext* ctx) const;\n};\n\nvoid DistributeCloneKernel::ForwardDataContent(KernelContext* ctx) const {\n  AutoMemcpy(ctx->stream(), GetOutBlob(ctx), ctx->BnInOp2Blob(\"in\"));\n}\n\nBlob* DistributeCloneKernel::GetOutBlob(KernelContext* ctx) const {\n  Blob* out_blob = nullptr;\n  FOR_RANGE(int, i, 0, this->op_attribute().output_bns().size()) {\n    Blob* cur_blob = ctx->BnInOp2Blob(this->op_attribute().output_bns().Get(i));\n    if (cur_blob != nullptr && cur_blob != out_blob) {\n      CHECK_ISNULL(out_blob);\n      out_blob = cur_blob;\n    }\n  }\n  return out_blob;\n}\n\nREGISTER_KERNEL(OperatorConf::kDistributeCloneConf, DistributeCloneKernel);\n\nclass DistributeConcatKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DistributeConcatKernel);\n  DistributeConcatKernel() = default;\n  ~DistributeConcatKernel() = default;\n\n private:\n  void ForwardDataContent(KernelContext* ctx) const override;\n  const Blob* GetInBlob(KernelContext* ctx) const;\n};\n\nvoid DistributeConcatKernel::ForwardDataContent(KernelContext* ctx) const {\n  AutoMemcpy(ctx->stream(), ctx->BnInOp2Blob(\"out\"), GetInBlob(ctx));\n}\n\nconst Blob* DistributeConcatKernel::GetInBlob(KernelContext* ctx) const {\n  const Blob* in_blob = nullptr;\n  FOR_RANGE(int, i, 0, this->op_attribute().input_bns().size()) {\n    const Blob* cur_blob = ctx->BnInOp2Blob(this->op_attribute().input_bns().Get(i));\n    if (cur_blob != nullptr && cur_blob != in_blob) {\n      CHECK_ISNULL(in_blob);\n      in_blob = cur_blob;\n    }\n  }\n  return 
in_blob;\n}\n\nREGISTER_KERNEL(OperatorConf::kDistributeConcatConf, DistributeConcatKernel);\n\nclass DistributeSplitKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DistributeSplitKernel);\n  DistributeSplitKernel() = default;\n  ~DistributeSplitKernel() = default;\n\n private:\n  void ForwardDataContent(KernelContext* ctx) const override;\n  void ForwardShape(KernelContext* ctx) const override;\n  Blob* GetOutBlob(KernelContext* ctx) const;\n};\n\nvoid DistributeSplitKernel::ForwardDataContent(KernelContext* ctx) const {\n  AutoMemcpy(ctx->stream(), GetOutBlob(ctx), ctx->BnInOp2Blob(\"in\"));\n}\n\nvoid DistributeSplitKernel::ForwardShape(KernelContext* ctx) const {\n  Blob* out_blob = GetOutBlob(ctx);\n  out_blob->mut_shape_view()->set_shape(ctx->BnInOp2Blob(\"in\")->shape());\n}\n\nBlob* DistributeSplitKernel::GetOutBlob(KernelContext* ctx) const {\n  Blob* out_blob = nullptr;\n  FOR_RANGE(int, i, 0, this->op_attribute().output_bns().size()) {\n    Blob* cur_blob = ctx->BnInOp2Blob(this->op_attribute().output_bns().Get(i));\n    if (cur_blob != nullptr && cur_blob != out_blob) {\n      CHECK_ISNULL(out_blob);\n      out_blob = cur_blob;\n    }\n  }\n  return out_blob;\n}\n\nREGISTER_KERNEL(OperatorConf::kDistributeSplitConf, DistributeSplitKernel);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/dynamic_reshape_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n\nnamespace oneflow {\n\nclass DynamicReshapeKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DynamicReshapeKernel);\n  DynamicReshapeKernel() = default;\n  ~DynamicReshapeKernel() override = default;\n\n private:\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\nvoid DynamicReshapeKernel::ForwardDataContent(KernelContext* ctx) const {\n  const Blob* in_blob = ctx->BnInOp2Blob(\"in\");\n  Blob* out_blob = ctx->BnInOp2Blob(\"out\");\n  AutoMemcpy(ctx->stream(), out_blob, in_blob);\n}\n\nREGISTER_KERNEL(OperatorConf::kDynamicReshapeConf, DynamicReshapeKernel);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/dynamic_reshape_like_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n\nnamespace oneflow {\n\nclass DynamicReshapeLikeKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DynamicReshapeLikeKernel);\n  DynamicReshapeLikeKernel() = default;\n  ~DynamicReshapeLikeKernel() override = default;\n\n private:\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\nvoid DynamicReshapeLikeKernel::ForwardDataContent(KernelContext* ctx) const {\n  const Blob* in_blob = ctx->BnInOp2Blob(\"x\");\n  Blob* out_blob = ctx->BnInOp2Blob(\"y\");\n  AutoMemcpy(ctx->stream(), out_blob, in_blob);\n}\n\nREGISTER_KERNEL(OperatorConf::kDynamicReshapeLikeConf, DynamicReshapeLikeKernel);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/esac_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/esac_kernel.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nvoid EsacKernel<T>::VirtualKernelInit(KernelContext* ctx) {\n  ctx->set_state(std::make_shared<EsacKernelState>());\n}\n\ntemplate<typename T>\nvoid EsacKernel<T>::ForwardDataContent(KernelContext* ctx) const {\n  T value =\n      static_cast<T>(CHECK_NOTNULL(dynamic_cast<EsacKernelState*>(ctx->state().get()))->value);\n  *(ctx->BnInOp2Blob(\"out\")->mut_dptr<T>()) = value;\n}\n\nADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kEsacConf, EsacKernel, INT_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/esac_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_ESAC_KERNEL_H_\n#define ONEFLOW_CORE_KERNEL_ESAC_KERNEL_H_\n\n#include \"oneflow/core/kernel/kernel.h\"\n\nnamespace oneflow {\n\nstruct EsacKernelState : public KernelState {\n  int64_t value{};\n};\n\ntemplate<typename T>\nclass EsacKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(EsacKernel);\n  EsacKernel() = default;\n  ~EsacKernel() override = default;\n\n private:\n  void VirtualKernelInit(KernelContext* ctx) override;\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_ESAC_KERNEL_H_\n"
  },
  {
    "path": "oneflow/core/kernel/identity_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/kernel/kernel_context.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n\nnamespace oneflow {\n\nclass IdentityKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(IdentityKernel);\n  IdentityKernel() = default;\n  ~IdentityKernel() = default;\n\n private:\n  void ForwardDataContent(KernelContext* ctx) const override;\n  void ForwardHeader(KernelContext* ctx) const override;\n};\n\nvoid IdentityKernel::ForwardDataContent(KernelContext* ctx) const {\n  const Blob* in_blob = ctx->BnInOp2Blob(\"in\");\n  Blob* out_blob = ctx->BnInOp2Blob(\"out\");\n  AutoMemcpy(ctx->stream(), out_blob, in_blob);\n}\n\nvoid IdentityKernel::ForwardHeader(KernelContext* ctx) const {\n  ctx->BnInOp2Blob(\"out\")->CopyHeaderFrom(ctx->BnInOp2Blob(\"in\"));\n}\n\nREGISTER_KERNEL(OperatorConf::kIdentityConf, IdentityKernel);\nREGISTER_KERNEL(OperatorConf::kCopyConf, IdentityKernel);\nREGISTER_KERNEL(OperatorConf::kCastToLocalConf, IdentityKernel);\nREGISTER_KERNEL(OperatorConf::kCastFromLocalConf, IdentityKernel);\nREGISTER_KERNEL(OperatorConf::kBoxingIdentityConf, IdentityKernel);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/image_decoder_random_crop_resize_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/error.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n#include \"oneflow/core/common/channel.h\"\n#include \"oneflow/core/common/blocking_counter.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"oneflow/user/image/random_crop_generator.h\"\n#include \"oneflow/user/image/jpeg_decoder.h\"\n#include <opencv2/opencv.hpp>\n\n#ifdef WITH_CUDA\n\n#include <cuda.h>\n\n#if CUDA_VERSION >= 10020\n\n#define WITH_NVJPEG\n\n#include <nvjpeg.h>\n#include <npp.h>\n\n#endif  // CUDA_VERSION >= 10020\n\n#endif  // WITH_CUDA\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr int kNumChannels = 3;\n\nstruct Task {\n  const unsigned char* data;\n  size_t length;\n  unsigned char* dst;\n  RandomCropGenerator* crop_generator;\n};\n\nstruct Work {\n  std::shared_ptr<std::vector<Task>> tasks;\n  unsigned char* workspace = nullptr;\n  size_t workspace_size = 0;\n  std::shared_ptr<BlockingCounter> done_counter;\n  std::shared_ptr<std::atomic<int>> task_counter;\n};\n\nstruct ROI {\n  int x;\n  int y;\n  int w;\n  int h;\n};\n\nclass ROIGenerator {\n public:\n  virtual ~ROIGenerator() = default;\n  virtual void Generate(int width, int height, ROI* roi) const = 0;\n};\n\nclass RandomCropROIGenerator : public ROIGenerator {\n public:\n  explicit RandomCropROIGenerator(RandomCropGenerator* crop_generator)\n      : crop_generator_(crop_generator) {}\n  ~RandomCropROIGenerator() override = default;\n\n  void Generate(int width, int height, ROI* roi) const override {\n    CropWindow window;\n    crop_generator_->GenerateCropWindow({height, width}, &window);\n    roi->x = window.anchor.At(1);\n    roi->y = window.anchor.At(0);\n    roi->w = window.shape.At(1);\n    roi->h = window.shape.At(0);\n  }\n\n private:\n  RandomCropGenerator* crop_generator_;\n};\n\nclass NoChangeROIGenerator : public ROIGenerator {\n public:\n  ~NoChangeROIGenerator() override = default;\n\n  void Generate(int width, int height, ROI* roi) const override {\n    roi->x = 0;\n    roi->y = 0;\n    roi->w = width;\n    roi->h = height;\n  }\n};\n\nvoid GenerateRandomCropRoi(RandomCropGenerator* crop_generator, int width, int height, int* roi_x,\n                           int* roi_y, int* roi_width, int* roi_height) {\n  CropWindow window;\n  crop_generator->GenerateCropWindow({height, width}, &window);\n  *roi_x = window.anchor.At(1);\n  *roi_y = window.anchor.At(0);\n  *roi_width = window.shape.At(1);\n  *roi_height = window.shape.At(0);\n}\n\nclass DecodeHandle {\n public:\n  DecodeHandle() = default;\n  virtual ~DecodeHandle() = default;\n\n  virtual void DecodeRandomCropResize(const unsigned char* data, size_t length,\n                                      RandomCropGenerator* crop_generator, unsigned char* workspace,\n                                      size_t workspace_size, unsigned char* dst, int target_width,\n                        
              int target_height) = 0;\n  virtual void WarmupOnce(int warmup_size, unsigned char* workspace, size_t workspace_size) = 0;\n  virtual void Synchronize() = 0;\n};\n\nusing DecodeHandleFactory = std::function<std::shared_ptr<DecodeHandle>()>;\ntemplate<DeviceType device_type>\nDecodeHandleFactory CreateDecodeHandleFactory(int target_width, int target_height);\n\nclass CpuDecodeHandle final : public DecodeHandle {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuDecodeHandle);\n  CpuDecodeHandle() = default;\n  ~CpuDecodeHandle() override = default;\n\n  void DecodeRandomCropResize(const unsigned char* data, size_t length,\n                              RandomCropGenerator* crop_generator, unsigned char* workspace,\n                              size_t workspace_size, unsigned char* dst, int target_width,\n                              int target_height) override;\n  void WarmupOnce(int warmup_size, unsigned char* workspace, size_t workspace_size) override {\n    // do nothing\n  }\n  void Synchronize() override {\n    // do nothing\n  }\n};\n\nbool CpuJpegDecodeRandomCropResize(const unsigned char* data, size_t length,\n                                   RandomCropGenerator* crop_generator, unsigned char* workspace,\n                                   size_t workspace_size, unsigned char* dst, int target_width,\n                                   int target_height) {\n  cv::Mat image_mat;\n  if (!JpegPartialDecodeRandomCropImage(data, length, crop_generator, workspace, workspace_size,\n                                        &image_mat)) {\n    return false;\n  }\n\n  cv::Mat dst_mat(target_height, target_width, CV_8UC3, dst, cv::Mat::AUTO_STEP);\n  cv::resize(image_mat, dst_mat, cv::Size(target_width, target_height), 0, 0, cv::INTER_LINEAR);\n  return true;\n}\n\nvoid OpencvDecodeRandomCropResize(const unsigned char* data, size_t length,\n                                  RandomCropGenerator* crop_generator, unsigned char* dst,\n                                  int target_width, int target_height) {\n  cv::Mat image =\n      cv::imdecode(cv::Mat(1, length, CV_8UC1, const_cast<unsigned char*>(data)), cv::IMREAD_COLOR);\n  cv::Mat cropped;\n  if (crop_generator) {\n    cv::Rect roi;\n    GenerateRandomCropRoi(crop_generator, image.cols, image.rows, &roi.x, &roi.y, &roi.width,\n                          &roi.height);\n    image(roi).copyTo(cropped);\n  } else {\n    cropped = image;\n  }\n  cv::Mat resized;\n  cv::resize(cropped, resized, cv::Size(target_width, target_height), 0, 0, cv::INTER_LINEAR);\n  cv::Mat dst_mat(target_height, target_width, CV_8UC3, dst, cv::Mat::AUTO_STEP);\n  cv::cvtColor(resized, dst_mat, cv::COLOR_BGR2RGB);\n}\n\nvoid CpuDecodeHandle::DecodeRandomCropResize(const unsigned char* data, size_t length,\n                                             RandomCropGenerator* crop_generator,\n                                             unsigned char* workspace, size_t workspace_size,\n                                             unsigned char* dst, int target_width,\n                                             int target_height) {\n  if (CpuJpegDecodeRandomCropResize(data, length, crop_generator, workspace, workspace_size, dst,\n                                    target_width, target_height)) {\n    return;\n  }\n\n  // The partial JPEG decoder rejected this image; fall back to a full OpenCV decode.\n  OpencvDecodeRandomCropResize(data, length, crop_generator, dst, target_width, target_height);\n}\n\ntemplate<>\nDecodeHandleFactory CreateDecodeHandleFactory<DeviceType::kCPU>(int target_width,\n                                                              
  int target_height) {\n  return []() -> std::shared_ptr<DecodeHandle> { return std::make_shared<CpuDecodeHandle>(); };\n}\n\n#if defined(WITH_NVJPEG)\n\nint GpuDeviceMalloc(void** p, size_t s) { return (int)cudaMalloc(p, s); }\n\nint GpuDeviceFree(void* p) { return (int)cudaFree(p); }\n\nint GpuPinnedMalloc(void** p, size_t s, unsigned int flags) {\n  return (int)cudaHostAlloc(p, s, flags);\n}\n\nint GpuPinnedFree(void* p) { return (int)cudaFreeHost(p); }\n\nvoid InitNppStreamContext(NppStreamContext* ctx, int dev, cudaStream_t stream) {\n  ctx->hStream = stream;\n  ctx->nCudaDeviceId = dev;\n  OF_CUDA_CHECK(\n      cudaDeviceGetAttribute(&ctx->nMultiProcessorCount, cudaDevAttrMultiProcessorCount, dev));\n  OF_CUDA_CHECK(cudaDeviceGetAttribute(&ctx->nMaxThreadsPerMultiProcessor,\n                                       cudaDevAttrMaxThreadsPerMultiProcessor, dev));\n  OF_CUDA_CHECK(\n      cudaDeviceGetAttribute(&ctx->nMaxThreadsPerBlock, cudaDevAttrMaxThreadsPerBlock, dev));\n  int smem_per_block = 0;\n  OF_CUDA_CHECK(cudaDeviceGetAttribute(&smem_per_block, cudaDevAttrMaxSharedMemoryPerBlock, dev));\n  ctx->nSharedMemPerBlock = smem_per_block;\n  OF_CUDA_CHECK(cudaDeviceGetAttribute(&ctx->nCudaDevAttrComputeCapabilityMajor,\n                                       cudaDevAttrComputeCapabilityMajor, dev));\n  OF_CUDA_CHECK(cudaDeviceGetAttribute(&ctx->nCudaDevAttrComputeCapabilityMinor,\n                                       cudaDevAttrComputeCapabilityMinor, dev));\n  OF_CUDA_CHECK(cudaStreamGetFlags(stream, &ctx->nStreamFlags));\n}\n\nclass GpuDecodeHandle final : public DecodeHandle {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(GpuDecodeHandle);\n  explicit GpuDecodeHandle(int dev, int target_width, int target_height);\n  ~GpuDecodeHandle() override;\n\n  void DecodeRandomCropResize(const unsigned char* data, size_t length,\n                              RandomCropGenerator* crop_generator, unsigned char* workspace,\n                              size_t workspace_size, unsigned char* dst, int target_width,\n                              int target_height) override;\n  void WarmupOnce(int warmup_size, unsigned char* workspace, size_t workspace_size) override;\n  void Synchronize() override;\n\n private:\n  void DecodeRandomCrop(const unsigned char* data, size_t length, ROIGenerator* roi_generator,\n                        unsigned char* dst, size_t dst_max_length, int* dst_width, int* dst_height);\n  void Decode(const unsigned char* data, size_t length, unsigned char* dst, size_t dst_max_length,\n              int* dst_width, int* dst_height);\n  void CropResize(const unsigned char* src, int src_width, int src_height,\n                  ROIGenerator* roi_generator, unsigned char* dst, int dst_width, int dst_height);\n\n  cudaStream_t cuda_stream_ = nullptr;\n  nvjpegHandle_t jpeg_handle_ = nullptr;\n  nvjpegJpegState_t jpeg_state_ = nullptr;\n  nvjpegJpegState_t hw_jpeg_state_ = nullptr;\n  nvjpegBufferPinned_t jpeg_pinned_buffer_ = nullptr;\n  nvjpegBufferDevice_t jpeg_device_buffer_ = nullptr;\n  nvjpegDecodeParams_t jpeg_decode_params_ = nullptr;\n  nvjpegJpegDecoder_t jpeg_decoder_ = nullptr;\n  nvjpegJpegDecoder_t hw_jpeg_decoder_ = nullptr;\n  nvjpegJpegStream_t jpeg_stream_ = nullptr;\n  NppStreamContext npp_stream_ctx_{};\n  nvjpegDevAllocator_t dev_allocator_{};\n  nvjpegPinnedAllocator_t pinned_allocator_{};\n  CpuDecodeHandle fallback_handle_;\n  unsigned char* fallback_buffer_{};\n  size_t fallback_buffer_size_;\n  bool warmup_done_;\n  bool 
use_hardware_acceleration_;\n};\n\nGpuDecodeHandle::GpuDecodeHandle(int dev, int target_width, int target_height)\n    : warmup_done_(false), use_hardware_acceleration_(false) {\n  OF_CUDA_CHECK(cudaStreamCreateWithFlags(&cuda_stream_, cudaStreamNonBlocking));\n  dev_allocator_.dev_malloc = &GpuDeviceMalloc;\n  dev_allocator_.dev_free = &GpuDeviceFree;\n  pinned_allocator_.pinned_malloc = &GpuPinnedMalloc;\n  pinned_allocator_.pinned_free = &GpuPinnedFree;\n  OF_NVJPEG_CHECK(nvjpegCreateEx(NVJPEG_BACKEND_DEFAULT, &dev_allocator_, &pinned_allocator_, 0,\n                                 &jpeg_handle_));\n  OF_NVJPEG_CHECK(nvjpegDecoderCreate(jpeg_handle_, NVJPEG_BACKEND_DEFAULT, &jpeg_decoder_));\n  OF_NVJPEG_CHECK(nvjpegDecoderStateCreate(jpeg_handle_, jpeg_decoder_, &jpeg_state_));\n#if NVJPEG_VER_MAJOR >= 11\n  if (ParseBooleanFromEnv(\"ONEFLOW_DECODER_ENABLE_NVJPEG_HARDWARE_ACCELERATION\", true)\n      && nvjpegDecoderCreate(jpeg_handle_, NVJPEG_BACKEND_HARDWARE, &hw_jpeg_decoder_)\n             == NVJPEG_STATUS_SUCCESS) {\n    OF_NVJPEG_CHECK(nvjpegDecoderStateCreate(jpeg_handle_, hw_jpeg_decoder_, &hw_jpeg_state_));\n    use_hardware_acceleration_ = true;\n  } else {\n    hw_jpeg_decoder_ = nullptr;\n    hw_jpeg_state_ = nullptr;\n  }\n#endif\n  OF_NVJPEG_CHECK(nvjpegBufferPinnedCreate(jpeg_handle_, &pinned_allocator_, &jpeg_pinned_buffer_));\n  OF_NVJPEG_CHECK(nvjpegBufferDeviceCreate(jpeg_handle_, &dev_allocator_, &jpeg_device_buffer_));\n  OF_NVJPEG_CHECK(nvjpegDecodeParamsCreate(jpeg_handle_, &jpeg_decode_params_));\n  OF_NVJPEG_CHECK(nvjpegJpegStreamCreate(jpeg_handle_, &jpeg_stream_));\n  InitNppStreamContext(&npp_stream_ctx_, dev, cuda_stream_);\n  fallback_buffer_size_ = target_width * target_height * kNumChannels;\n  OF_CUDA_CHECK(cudaMallocHost(&fallback_buffer_, fallback_buffer_size_));\n}\n\nGpuDecodeHandle::~GpuDecodeHandle() {\n  OF_CUDA_CHECK(cudaStreamSynchronize(cuda_stream_));\n  OF_NVJPEG_CHECK(nvjpegJpegStreamDestroy(jpeg_stream_));\n  OF_NVJPEG_CHECK(nvjpegDecodeParamsDestroy(jpeg_decode_params_));\n  OF_NVJPEG_CHECK(nvjpegBufferDeviceDestroy(jpeg_device_buffer_));\n  OF_NVJPEG_CHECK(nvjpegBufferPinnedDestroy(jpeg_pinned_buffer_));\n  OF_NVJPEG_CHECK(nvjpegJpegStateDestroy(jpeg_state_));\n  OF_NVJPEG_CHECK(nvjpegDecoderDestroy(jpeg_decoder_));\n  if (use_hardware_acceleration_) {\n    OF_NVJPEG_CHECK(nvjpegJpegStateDestroy(hw_jpeg_state_));\n    OF_NVJPEG_CHECK(nvjpegDecoderDestroy(hw_jpeg_decoder_));\n  }\n  OF_NVJPEG_CHECK(nvjpegDestroy(jpeg_handle_));\n  OF_CUDA_CHECK(cudaStreamDestroy(cuda_stream_));\n  OF_CUDA_CHECK(cudaFreeHost(fallback_buffer_));\n}\n\nvoid GpuDecodeHandle::DecodeRandomCrop(const unsigned char* data, size_t length,\n                                       ROIGenerator* roi_generator, unsigned char* dst,\n                                       size_t dst_max_length, int* dst_width, int* dst_height) {\n  // https://docs.nvidia.com/cuda/archive/10.2/nvjpeg/index.html#nvjpeg-decoupled-decode-api\n  OF_NVJPEG_CHECK(nvjpegJpegStreamParse(jpeg_handle_, data, length, 0, 0, jpeg_stream_));\n  unsigned int orig_width = 0;\n  unsigned int orig_height = 0;\n  OF_NVJPEG_CHECK(nvjpegJpegStreamGetFrameDimensions(jpeg_stream_, &orig_width, &orig_height));\n  ROI roi{};\n  roi_generator->Generate(static_cast<int>(orig_width), static_cast<int>(orig_height), &roi);\n  CHECK_LE(roi.w * roi.h * kNumChannels, dst_max_length);\n  nvjpegImage_t image;\n  image.channel[0] = dst;\n  image.pitch[0] = roi.w * kNumChannels;\n  
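// Decode to interleaved RGB (NVJPEG_OUTPUT_RGBI) so the buffer holds HWC uint8 pixels.\n  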
OF_NVJPEG_CHECK(nvjpegDecodeParamsSetOutputFormat(jpeg_decode_params_, NVJPEG_OUTPUT_RGBI));\n\n  nvjpegJpegDecoder_t jpeg_decoder = nullptr;\n  nvjpegJpegState_t jpeg_state = nullptr;\n  int is_hardware_acceleration_supported = -1;\n  if (use_hardware_acceleration_) {\n    nvjpegDecoderJpegSupported(hw_jpeg_decoder_, jpeg_stream_, jpeg_decode_params_,\n                               &is_hardware_acceleration_supported);\n  }\n  if (is_hardware_acceleration_supported == 0) {\n    jpeg_decoder = hw_jpeg_decoder_;\n    jpeg_state = hw_jpeg_state_;\n  } else {\n    jpeg_decoder = jpeg_decoder_;\n    jpeg_state = jpeg_state_;\n  }\n  if (roi.x != 0 || roi.y != 0 || roi.w != orig_width || roi.h != orig_height) {\n    // The hardware-accelerated decoder does not support nvjpegDecodeParamsSetROI;\n    // only the software path ever requests a non-trivial ROI here.\n    OF_NVJPEG_CHECK(nvjpegDecodeParamsSetROI(jpeg_decode_params_, roi.x, roi.y, roi.w, roi.h));\n  } else {\n    OF_NVJPEG_CHECK(nvjpegDecodeParamsSetROI(jpeg_decode_params_, 0, 0, -1, -1));\n  }\n  OF_NVJPEG_CHECK(nvjpegStateAttachPinnedBuffer(jpeg_state, jpeg_pinned_buffer_));\n  OF_NVJPEG_CHECK(nvjpegStateAttachDeviceBuffer(jpeg_state, jpeg_device_buffer_));\n  OF_NVJPEG_CHECK(nvjpegDecodeJpegHost(jpeg_handle_, jpeg_decoder, jpeg_state, jpeg_decode_params_,\n                                       jpeg_stream_));\n  OF_NVJPEG_CHECK(nvjpegDecodeJpegTransferToDevice(jpeg_handle_, jpeg_decoder, jpeg_state,\n                                                   jpeg_stream_, cuda_stream_));\n  OF_NVJPEG_CHECK(\n      nvjpegDecodeJpegDevice(jpeg_handle_, jpeg_decoder, jpeg_state, &image, cuda_stream_));\n  *dst_width = roi.w;\n  *dst_height = roi.h;\n}\n\nvoid GpuDecodeHandle::Decode(const unsigned char* data, size_t length, unsigned char* dst,\n                             size_t dst_max_length, int* dst_width, int* dst_height) {\n  NoChangeROIGenerator no_change_roi_generator;\n  DecodeRandomCrop(data, length, &no_change_roi_generator, dst, dst_max_length, dst_width,\n                   dst_height);\n}\n\nvoid GpuDecodeHandle::CropResize(const unsigned char* src, int src_width, int src_height,\n                                 ROIGenerator* roi_generator, unsigned char* dst, int dst_width,\n                                 int dst_height) {\n  ROI roi{};\n  roi_generator->Generate(static_cast<int>(src_width), static_cast<int>(src_height), &roi);\n  const NppiSize src_size{\n      .width = src_width,\n      .height = src_height,\n  };\n  const NppiRect src_rect{\n      .x = roi.x,\n      .y = roi.y,\n      .width = roi.w,\n      .height = roi.h,\n  };\n  const NppiSize dst_size{\n      .width = dst_width,\n      .height = dst_height,\n  };\n  const NppiRect dst_rect{\n      .x = 0,\n      .y = 0,\n      .width = dst_width,\n      .height = dst_height,\n  };\n  NppStatus status =\n      nppiResize_8u_C3R_Ctx(src, src_width * kNumChannels, src_size, src_rect, dst,\n                            dst_width * kNumChannels, dst_size, dst_rect, NPPI_INTER_LINEAR,\n                            npp_stream_ctx_);\n  CHECK_GE(status, NPP_SUCCESS);\n}\n\nvoid GpuDecodeHandle::DecodeRandomCropResize(const unsigned char* data, size_t length,\n                                             RandomCropGenerator* crop_generator,\n                                             unsigned char* workspace, size_t workspace_size,\n                                             unsigned char* dst, int target_width,\n                                             int target_height) {\n  int width[NVJPEG_MAX_COMPONENT];\n  int height[NVJPEG_MAX_COMPONENT];\n  nvjpegChromaSubsampling_t subsampling{};\n 
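// nvjpegGetImageInfo doubles as a validity probe: if the bitstream cannot be parsed, the\n  // image is decoded on the CPU into the pinned fallback buffer and copied to the device.\n 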
 int num_components = 0;\n  nvjpegStatus_t status =\n      nvjpegGetImageInfo(jpeg_handle_, data, length, &num_components, &subsampling, width, height);\n  if (status != NVJPEG_STATUS_SUCCESS) {\n    CHECK_LE(target_width * target_height * kNumChannels, fallback_buffer_size_);\n    fallback_handle_.DecodeRandomCropResize(data, length, crop_generator, nullptr, 0,\n                                            fallback_buffer_, target_width, target_height);\n    OF_CUDA_CHECK(cudaMemcpyAsync(dst, fallback_buffer_,\n                                  target_width * target_height * kNumChannels, cudaMemcpyDefault,\n                                  cuda_stream_));\n    return;\n  }\n  NoChangeROIGenerator no_change_roi_generator;\n  RandomCropROIGenerator random_crop_roi_generator(crop_generator);\n  if (use_hardware_acceleration_) {\n    int w = 0;\n    int h = 0;\n    DecodeRandomCrop(data, length, &no_change_roi_generator, workspace, workspace_size, &w, &h);\n    CropResize(workspace, w, h, &random_crop_roi_generator, dst, target_width, target_height);\n  } else {\n    int w = 0;\n    int h = 0;\n    DecodeRandomCrop(data, length, &random_crop_roi_generator, workspace, workspace_size, &w, &h);\n    CropResize(workspace, w, h, &no_change_roi_generator, dst, target_width, target_height);\n  }\n}\n\nvoid GpuDecodeHandle::WarmupOnce(int warmup_size, unsigned char* workspace, size_t workspace_size) {\n  if (warmup_done_) { return; }\n  warmup_size = std::min(static_cast<int>(std::sqrt(workspace_size / kNumChannels)), warmup_size);\n  cv::Mat image = cv::Mat::zeros(cv::Size(warmup_size, warmup_size), CV_8UC3);\n  cv::randu(image, cv::Scalar(0, 0, 0), cv::Scalar(255, 255, 255));\n  std::vector<unsigned char> data;\n  cv::imencode(\".jpg\", image, data, {});\n  int decoded_width = 0;\n  int decoded_height = 0;\n  Decode(data.data(), data.size(), workspace, workspace_size, &decoded_width, &decoded_height);\n  Synchronize();\n  if (use_hardware_acceleration_) {\n    // Note(guoran): the hardware-accelerated jpeg decoder supports baseline decoding only, so\n    // a progressive image is encoded here to also warm up the software jpeg decoder.\n    cv::imencode(\".jpg\", image, data, {cv::IMWRITE_JPEG_PROGRESSIVE, 1});\n    Decode(data.data(), data.size(), workspace, workspace_size, &decoded_width, &decoded_height);\n    Synchronize();\n  }\n  warmup_done_ = true;\n}\n\nvoid GpuDecodeHandle::Synchronize() { OF_CUDA_CHECK(cudaStreamSynchronize(cuda_stream_)); }\n\ntemplate<>\nDecodeHandleFactory CreateDecodeHandleFactory<DeviceType::kCUDA>(int target_width,\n                                                                 int target_height) {\n  int dev = 0;\n  OF_CUDA_CHECK(cudaGetDevice(&dev));\n  return [dev, target_width, target_height]() -> std::shared_ptr<DecodeHandle> {\n    OF_CUDA_CHECK(cudaSetDevice(dev));\n    return std::make_shared<GpuDecodeHandle>(dev, target_width, target_height);\n  };\n}\n\n#endif  // defined(WITH_NVJPEG)\n\nclass Worker final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Worker);\n  Worker(const std::function<std::shared_ptr<DecodeHandle>()>& handle_factory, int target_width,\n         int target_height, int warmup_size) {\n    worker_thread_ = std::thread(&Worker::PollWork, this, handle_factory, target_width,\n                                 target_height, warmup_size);\n  }\n  ~Worker() {\n    work_queue_.Close();\n    worker_thread_.join();\n  }\n\n  void Enqueue(std::shared_ptr<Work>& work) { work_queue_.Send(work); }\n\n private:\n  Channel<std::shared_ptr<Work>> work_queue_;\n  std::thread worker_thread_;\n\n  void 
PollWork(const std::function<std::shared_ptr<DecodeHandle>()>& handle_factory,\n                int target_width, int target_height, int warmup_size) {\n    OF_PROFILER_NAME_THIS_HOST_THREAD(\"_cuda_img_decode\");\n    std::shared_ptr<DecodeHandle> handle = handle_factory();\n    std::shared_ptr<Work> work;\n    while (true) {\n      ChannelStatus status = work_queue_.Receive(&work);\n      if (status == ChannelStatus::kChannelStatusErrorClosed) { break; }\n      CHECK_EQ(status, ChannelStatus::kChannelStatusSuccess);\n      handle->WarmupOnce(warmup_size, work->workspace, work->workspace_size);\n      while (true) {\n        const int task_id = work->task_counter->fetch_add(1, std::memory_order_relaxed);\n        if (task_id >= work->tasks->size()) { break; }\n        const Task& task = work->tasks->at(task_id);\n        handle->DecodeRandomCropResize(task.data, task.length, task.crop_generator, work->workspace,\n                                       work->workspace_size, task.dst, target_width, target_height);\n        handle->Synchronize();\n      }\n      work->done_counter->Decrease();\n    }\n  }\n};\n\n}  // namespace\n\ntemplate<DeviceType device_type>\nclass ImageDecoderRandomCropResizeKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ImageDecoderRandomCropResizeKernel);\n  ImageDecoderRandomCropResizeKernel() = default;\n  ~ImageDecoderRandomCropResizeKernel() override = default;\n\n private:\n  void VirtualKernelInit(KernelContext* ctx) override;\n  void ForwardDataContent(KernelContext* ctx) const override;\n\n  std::vector<std::unique_ptr<RandomCropGenerator>> random_crop_generators_;\n  std::vector<std::unique_ptr<Worker>> workers_;\n};\n\ntemplate<DeviceType device_type>\nvoid ImageDecoderRandomCropResizeKernel<device_type>::VirtualKernelInit(KernelContext* ctx) {\n  const ImageDecoderRandomCropResizeOpConf& conf =\n      this->op_conf().image_decoder_random_crop_resize_conf();\n  const int64_t batch_size =\n      this->kernel_conf().image_decoder_random_crop_resize_conf().batch_size();\n  random_crop_generators_.resize(batch_size);\n  std::seed_seq seq{this->kernel_conf().image_decoder_random_crop_resize_conf().seed()};\n  std::vector<int> seeds(batch_size);\n  seq.generate(seeds.begin(), seeds.end());\n  AspectRatioRange aspect_ratio_range{\n      conf.random_aspect_ratio_min(),\n      conf.random_aspect_ratio_max(),\n  };\n  AreaRange area_range{\n      conf.random_area_min(),\n      conf.random_area_max(),\n  };\n  for (int64_t i = 0; i < batch_size; ++i) {\n    random_crop_generators_.at(i).reset(\n        new RandomCropGenerator(aspect_ratio_range, area_range, seeds.at(i), conf.num_attempts()));\n  }\n  workers_.resize(conf.num_workers());\n  for (int64_t i = 0; i < conf.num_workers(); ++i) {\n    workers_.at(i).reset(new Worker(\n        CreateDecodeHandleFactory<device_type>(conf.target_width(), conf.target_height()),\n        conf.target_width(), conf.target_height(), conf.warmup_size()));\n  }\n}\n\ntemplate<DeviceType device_type>\nvoid ImageDecoderRandomCropResizeKernel<device_type>::ForwardDataContent(KernelContext* ctx) const {\n  const ImageDecoderRandomCropResizeOpConf& conf =\n      this->op_conf().image_decoder_random_crop_resize_conf();\n  const Blob* in = ctx->BnInOp2Blob(\"in\");\n  Blob* out = ctx->BnInOp2Blob(\"out\");\n  Blob* tmp = ctx->BnInOp2Blob(\"tmp\");\n  CHECK_EQ(in->data_type(), DataType::kTensorBuffer);\n  CHECK_EQ(out->data_type(), DataType::kUInt8);\n  const ShapeView& in_shape = in->shape();\n  const int64_t 
num_in_axes = in_shape.NumAxes();\n  const ShapeView& out_shape = out->shape();\n  const int64_t num_out_axes = out_shape.NumAxes();\n  CHECK_EQ(num_out_axes, num_in_axes + 3);\n  for (int i = 0; i < num_in_axes; ++i) { CHECK_EQ(out_shape.At(i), in_shape.At(i)); }\n  CHECK_EQ(out_shape.At(num_in_axes), conf.target_height());\n  CHECK_EQ(out_shape.At(num_in_axes + 1), conf.target_width());\n  CHECK_EQ(out_shape.At(num_in_axes + 2), kNumChannels);\n  CHECK_EQ(tmp->data_type(), DataType::kUInt8);\n  const int64_t batch_size = in_shape.elem_cnt();\n  const auto* buffers = in->dptr<TensorBuffer>();\n  auto* out_ptr = out->mut_dptr<unsigned char>();\n  const int64_t out_instance_size = conf.target_height() * conf.target_width() * kNumChannels;\n  auto* workspace_ptr = tmp->mut_dptr<unsigned char>();\n  size_t workspace_size_per_worker = tmp->shape().elem_cnt() / workers_.size();\n  std::shared_ptr<BlockingCounter> done_counter(new BlockingCounter(workers_.size()));\n  std::shared_ptr<std::atomic<int>> task_counter(new std::atomic<int>(0));\n  std::shared_ptr<std::vector<Task>> tasks(new std::vector<Task>(batch_size));\n  for (int64_t task_id = 0; task_id < batch_size; ++task_id) {\n    const TensorBuffer* buffer = buffers + task_id;\n    CHECK_EQ(buffer->data_type(), DataType::kUInt8);\n    tasks->at(task_id).data = buffer->data<unsigned char>();\n    tasks->at(task_id).length = buffer->elem_cnt();\n    tasks->at(task_id).dst = out_ptr + task_id * out_instance_size;\n    tasks->at(task_id).crop_generator = random_crop_generators_.at(task_id).get();\n  }\n  // Larger images will be processed first, balancing the work time of the workers.\n  std::sort(tasks->begin(), tasks->end(),\n            [](const Task& a, const Task& b) { return b.length < a.length; });\n  for (int64_t worker_id = 0; worker_id < workers_.size(); ++worker_id) {\n    std::shared_ptr<Work> work(new Work());\n    work->tasks = tasks;\n    work->workspace = workspace_ptr + worker_id * workspace_size_per_worker;\n    work->workspace_size = workspace_size_per_worker;\n    work->done_counter = done_counter;\n    work->task_counter = task_counter;\n    workers_.at(worker_id)->Enqueue(work);\n  }\n  done_counter->WaitForeverUntilCntEqualZero();\n}\n\nNEW_REGISTER_KERNEL(OperatorConf::kImageDecoderRandomCropResizeConf,\n                    ImageDecoderRandomCropResizeKernel<DeviceType::kCPU>)\n    .SetIsMatchedPred([](const KernelConf& conf) -> bool {\n      return conf.op_attribute().op_conf().device_tag() == \"cpu\";\n    });\n\n#if defined(WITH_NVJPEG)\n\nNEW_REGISTER_KERNEL(OperatorConf::kImageDecoderRandomCropResizeConf,\n                    ImageDecoderRandomCropResizeKernel<DeviceType::kCUDA>)\n    .SetIsMatchedPred([](const KernelConf& conf) -> bool {\n      return conf.op_attribute().op_conf().device_tag() == \"cuda\";\n    });\n\n#endif  // defined(WITH_NVJPEG)\n\n}  // namespace oneflow\n"
  },
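  {
    "path": "editor_examples/task_claiming_sketch.cpp",
    "content": "// Editor's note: an illustrative sketch only, NOT part of the OneFlow source tree. It mirrors\n// the task-claiming scheme used by the Worker threads in\n// image_decoder_random_crop_resize_kernel.cpp above: tasks are sorted largest-first and every\n// worker claims the next index from a shared atomic counter, so large tasks start early and the\n// tail of small tasks evens out the per-worker load. All names below are hypothetical.\n#include <algorithm>\n#include <atomic>\n#include <cstdio>\n#include <thread>\n#include <vector>\n\nstruct Task {\n  int length = 0;\n};\n\nint main() {\n  std::vector<Task> tasks{{5}, {1}, {9}, {3}, {7}};\n  // Larger tasks first, matching the kernel's sort on Task::length.\n  std::sort(tasks.begin(), tasks.end(),\n            [](const Task& a, const Task& b) { return b.length < a.length; });\n  std::atomic<int> task_counter{0};\n  auto worker_fn = [&](int worker_id) {\n    while (true) {\n      // fetch_add hands out each task index exactly once across all workers.\n      const int task_id = task_counter.fetch_add(1, std::memory_order_relaxed);\n      if (task_id >= static_cast<int>(tasks.size())) { break; }\n      std::printf(\"worker %d takes task with length %d\\n\", worker_id, tasks[task_id].length);\n    }\n  };\n  std::vector<std::thread> workers;\n  for (int i = 0; i < 2; ++i) { workers.emplace_back(worker_fn, i); }\n  for (auto& t : workers) { t.join(); }\n  return 0;\n}\n"
  },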
  {
    "path": "oneflow/core/kernel/input_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/common/buffer_manager.h\"\n#include \"oneflow/core/job/critical_section_instance.h\"\n#include \"oneflow/core/job/global_for.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass InputKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(InputKernel);\n  InputKernel() = default;\n  ~InputKernel() = default;\n\n private:\n  void ForwardDataContent(KernelContext* ctx) const override {\n    CHECK(this->op_conf().input_conf().has_job_name());\n    const auto& job_name = this->op_conf().input_conf().job_name();\n    const auto& op_name = this->op_conf().name();\n    auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::Get();\n    auto* buffer = buffer_mgr->Get(GetInputBufferName(job_name, op_name));\n    std::shared_ptr<CriticalSectionInstance> critical_section_instance;\n    BufferStatus buffer_status = buffer->TryReceive(&critical_section_instance);\n    CHECK_NE(buffer_status, kBufferStatusEmpty);\n    if (buffer_status == kBufferStatusSuccess) {\n      critical_section_instance->AccessBlobByOpName(ctx->stream(), ctx->BnInOp2Blob(\"out\"),\n                                                    op_name);\n    }\n  }\n  void ForwardHeader(KernelContext* ctx) const override {}\n};\n\n}  // namespace\n\nREGISTER_KERNEL(OperatorConf::kInputConf, InputKernel);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/kernel/runtime_blob_shape_infer_helper.h\"\n#include \"oneflow/core/kernel/kernel_observer.h\"\n#include \"oneflow/core/vm/sync_vm_mode_guard.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool IsAllBlobEmpty(const PbRpf<std::string>& bns, KernelContext* ctx) {\n  for (const auto& bn : bns) {\n    Blob* blob = ctx->BnInOp2Blob(bn);\n    if (blob && !blob->IsBodyEmpty()) { return false; }\n  }\n  return true;\n}\n\n}  // namespace\n\nKernel::Kernel() = default;\n\nKernel::~Kernel() = default;\n\nvoid Kernel::InitBase(const KernelConf& kernel_conf) {\n  if (shape_infer_helper_) { return; }\n  kernel_conf_ = kernel_conf;\n  shape_infer_helper_.reset(\n      new RuntimeBlobShapeInferHelper(this->op_conf(), this->kernel_conf(), this));\n}\n\nvoid Kernel::Init(const KernelConf& kernel_conf, KernelContext* ctx) {\n  SyncVmModeGuard guard(SyncVmMode::kEnable);\n  InitBase(kernel_conf);\n  VirtualKernelInit(ctx);\n}\n\nvoid Kernel::Launch(KernelContext* ctx) const {\n  SyncVmModeGuard guard(SyncVmMode::kEnable);\n  ctx->WillForward(ctx, this);\n  Forward(ctx);\n  ctx->DidForward(ctx, this);\n}\n\nvoid Kernel::Forward(KernelContext* ctx) const {\n  ctx->WillForwardHeader(ctx, this);\n  ForwardHeader(ctx);\n  ctx->DidForwardHeader(ctx, this);\n  if ((!kernel_conf_.all_blobs_are_static()) && IsAllBlobEmpty(op_attribute().output_bns(), ctx)\n      && IsStateless()) {\n    return;\n  }\n  ctx->WillForwardDataContent(ctx, this);\n  ForwardDataContent(ctx);\n  ctx->DidForwardDataContent(ctx, this);\n}\n\nvoid Kernel::ForwardHeader(KernelContext* ctx) const {\n  if (!kernel_conf_.all_blobs_are_static()) { ForwardShape(ctx); }\n}\n\nvoid Kernel::ForwardShape(KernelContext* ctx) const {\n  return shape_infer_helper_->InferShape(\n      [ctx](const std::string& bn) { return ctx->BnInOp2Blob(bn); });\n}\n\nstd::unique_ptr<const Kernel> ConstructKernel(const KernelConf& conf, KernelContext* kernel_ctx) {\n  auto op_type = conf.op_attribute().op_conf().op_type_case();\n  CHECK_NE(op_type, OperatorConf::OpTypeCase::OP_TYPE_NOT_SET)\n      << \" ERROR! KernelConf: \" << conf.DebugString() << \" has NOT set op_type_case\";\n  Kernel* rptr = kernel_registration::CreateKernel(conf);\n  if (rptr == nullptr) { rptr = NewObj<int32_t, Kernel>(op_type, conf); }\n  CHECK_NOTNULL(rptr);\n  rptr->Init(conf, kernel_ctx);\n  return std::unique_ptr<const Kernel>(rptr);\n}\n\n}  // namespace oneflow\n"
  },
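  {
    "path": "editor_examples/two_tier_kernel_lookup_sketch.cpp",
    "content": "// Editor's note: an illustrative sketch only, NOT part of the OneFlow source tree.\n// ConstructKernel in kernel.cpp above first asks the predicate-based registry\n// (kernel_registration::CreateKernel) and only falls back to the classic class registry keyed\n// by op_type when that returns nullptr. This standalone analogue shows the two-tier lookup;\n// all names are hypothetical.\n#include <cstdio>\n#include <functional>\n#include <map>\n#include <memory>\n\nstruct FakeKernel {\n  const char* name;\n};\n\nusing Creator = std::function<FakeKernel*()>;\n\nstd::map<int, Creator>& PredicateRegistry() {\n  static std::map<int, Creator> registry;\n  return registry;\n}\n\nstd::map<int, Creator>& FallbackRegistry() {\n  static std::map<int, Creator> registry;\n  return registry;\n}\n\nstd::unique_ptr<FakeKernel> Construct(int op_type) {\n  FakeKernel* kernel = nullptr;\n  auto it = PredicateRegistry().find(op_type);\n  if (it != PredicateRegistry().end()) { kernel = it->second(); }  // first tier\n  if (kernel == nullptr) {\n    auto fallback = FallbackRegistry().find(op_type);\n    if (fallback != FallbackRegistry().end()) { kernel = fallback->second(); }  // second tier\n  }\n  return std::unique_ptr<FakeKernel>(kernel);  // nullptr means nothing was registered\n}\n\nint main() {\n  FallbackRegistry()[7] = [] { return new FakeKernel{\"fallback\"}; };\n  std::unique_ptr<FakeKernel> kernel = Construct(7);\n  std::printf(\"%s\\n\", kernel ? kernel->name : \"none\");\n  return 0;\n}\n"
  },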
  {
    "path": "oneflow/core/kernel/kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_KERNEL_H_\n#define ONEFLOW_CORE_KERNEL_KERNEL_H_\n\n#include \"oneflow/core/kernel/kernel.pb.h\"\n#include \"oneflow/core/kernel/kernel_registration.h\"\n#include \"oneflow/core/kernel/kernel_context.h\"\n#include \"oneflow/core/register/blob.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\nclass JobDesc;\nclass RuntimeBlobShapeInferHelper;\n\nclass Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Kernel);\n  virtual ~Kernel();\n\n  void Init(const KernelConf& kernel_conf, KernelContext* ctx);\n  void Launch(KernelContext* ctx) const;\n\n  const OperatorConf& op_conf() const { return op_attribute().op_conf(); }\n  const OpAttribute& op_attribute() const { return kernel_conf().op_attribute(); }\n  const KernelConf& kernel_conf() const { return kernel_conf_; }\n  /*\n   * return true means all below must be guaranteed when `Launch` function return:\n   * 1) all out blob header has been set (e.g. SyncSetHeadKernel)\n   * 2) all asynchronous task has been queued (e.g. NCCL related kernel)\n   */\n  virtual bool IsKernelLaunchSynchronized() const { return true; }\n\n  void SystemForwardHeader(KernelContext* ctx) const { ForwardHeader(ctx); }\n  void SystemForwardDataContent(KernelContext* ctx) const { ForwardDataContent(ctx); }\n  virtual void Forward(KernelContext* ctx) const;\n\n protected:\n  Kernel();\n  void InitBase(const KernelConf&);\n  virtual void VirtualKernelInit(KernelContext* ctx) {}\n\n  virtual void ForwardHeader(KernelContext* ctx) const;\n  virtual void ForwardShape(KernelContext* ctx) const;\n  // TODO(niuchong) : rename ForwardDataContent to ForwardBody\n  virtual void ForwardDataContent(KernelContext* ctx) const = 0;\n  virtual bool IsStateless() const { return false; }\n\n private:\n  std::unique_ptr<RuntimeBlobShapeInferHelper> shape_infer_helper_;\n  KernelConf kernel_conf_;\n};\n\n#define REGISTER_KERNEL(k, KernelType) \\\n  REGISTER_CLASS_WITH_ARGS(int32_t, k, Kernel, KernelType, const KernelConf&)\n#define REGISTER_KERNEL_CREATOR(k, f) \\\n  REGISTER_CLASS_CREATOR(int32_t, k, Kernel, f, const KernelConf&)\n\nstd::unique_ptr<const Kernel> ConstructKernel(const KernelConf& kernel_conf, KernelContext* ctx);\n\n}  // namespace oneflow\n\n#define MAKE_KERNEL_CREATOR_ENTRY(kernel_class, device_type, data_type_pair) \\\n  {GetHashKey(device_type, OF_PP_PAIR_SECOND(data_type_pair)),               \\\n   []() { return new kernel_class<device_type, OF_PP_PAIR_FIRST(data_type_pair)>(); }},\n\n#define ADD_DEFAULT_KERNEL_CREATOR(op_type_case, kernel_class, data_type_seq)                \\\n  namespace {                                                                                \\\n                                                                                             \\\n  Kernel* OF_PP_CAT(CreateKernel, __LINE__)(const KernelConf& kernel_conf) {                 \\\n    static const HashMap<std::string, 
std::function<Kernel*()>> creators = {                 \\\n        OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_KERNEL_CREATOR_ENTRY, (kernel_class),          \\\n                                         DEVICE_TYPE_SEQ, data_type_seq)};                   \\\n    DeviceType device_type =                                                                 \\\n        CHECK_JUST(DeviceType4DeviceTag(kernel_conf.op_attribute().op_conf().device_tag())); \\\n    auto key = GetHashKey(device_type, kernel_conf.data_type());                             \\\n    auto it = creators.find(key);                                                            \\\n    if (it == creators.end()) {                                                              \\\n      LOG(FATAL) << \"Error! Cannot find kernel creator: \" << kernel_conf.DebugString()       \\\n                 << \" with device_type = \" << device_type                                    \\\n                 << \", dtype = \" << kernel_conf.data_type();                                 \\\n    }                                                                                        \\\n    return (it->second)();                                                                   \\\n  }                                                                                          \\\n                                                                                             \\\n  REGISTER_KERNEL_CREATOR(op_type_case, OF_PP_CAT(CreateKernel, __LINE__));                  \\\n  }\n\n#define MAKE_DEVICE_TYPE_KERNEL_CREATOR_ENTRY(kernel_class, device_type) \\\n  {device_type, []() { return new kernel_class<device_type>(); }},\n\n#define ADD_DEVICE_TYPE_KERNEL_CREATOR(op_type_case, kernel_class)                              \\\n  namespace {                                                                                   \\\n                                                                                                \\\n  Kernel* OF_PP_CAT(CreateKernel, __LINE__)(const KernelConf& kernel_conf) {                    \\\n    static const HashMap<int, std::function<Kernel*()>> creators = {                            \\\n        OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_DEVICE_TYPE_KERNEL_CREATOR_ENTRY, (kernel_class), \\\n                                         DEVICE_TYPE_SEQ)};                                     \\\n    DeviceType device_type =                                                                    \\\n        CHECK_JUST(DeviceType4DeviceTag(kernel_conf.op_attribute().op_conf().device_tag()));    \\\n    auto it = creators.find(device_type);                                                       \\\n    if (it == creators.end()) {                                                                 \\\n      LOG(FATAL) << \"Error! 
Cannot find kernel creator: \" << kernel_conf.DebugString()          \\\n                 << \" with device_type = \" << device_type;                                      \\\n    }                                                                                           \\\n    return (it->second)();                                                                      \\\n  }                                                                                             \\\n                                                                                                \\\n  REGISTER_KERNEL_CREATOR(op_type_case, OF_PP_CAT(CreateKernel, __LINE__));                     \\\n  }\n\n#define MAKE_CPU_KERNEL_CREATOR_ENTRY(kernel_class, data_type_pair) \\\n  {OF_PP_PAIR_SECOND(data_type_pair),                               \\\n   []() { return new kernel_class<OF_PP_PAIR_FIRST(data_type_pair)>(); }},\n\n#define ADD_CPU_DEFAULT_KERNEL_CREATOR(op_type_case, kernel_class, data_type_seq)       \\\n  namespace {                                                                           \\\n                                                                                        \\\n  Kernel* CreateKernel(const KernelConf& kernel_conf) {                                 \\\n    static const HashMap<int, std::function<Kernel*()>> creators = {                    \\\n        OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_CPU_KERNEL_CREATOR_ENTRY, (kernel_class), \\\n                                         data_type_seq)};                               \\\n    auto it = creators.find(kernel_conf.data_type());                                   \\\n    if (it == creators.end()) {                                                         \\\n      LOG(FATAL) << \"Error! Cannot find kernel creator: \" << kernel_conf.DebugString()  \\\n                 << \" with dtype = \" << kernel_conf.data_type();                        \\\n    }                                                                                   \\\n    return (it->second)();                                                              \\\n  }                                                                                     \\\n                                                                                        \\\n  REGISTER_KERNEL_CREATOR(op_type_case, CreateKernel);                                  \\\n  }\n\n#endif  // ONEFLOW_CORE_KERNEL_KERNEL_H_\n"
  },
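  {
    "path": "editor_examples/kernel_creator_macro_expansion_sketch.cpp",
    "content": "// Editor's note: an illustrative sketch only, NOT part of the OneFlow source tree. It is a\n// hand-written analogue of what ADD_DEFAULT_KERNEL_CREATOR in kernel.h above expands to: a\n// creator function owning a static map from (device_type, data_type) to a kernel factory.\n// All names are hypothetical and the integer constants are stand-ins for the real enums.\n#include <cstdio>\n#include <functional>\n#include <map>\n#include <utility>\n\nstruct KernelBase {\n  virtual ~KernelBase() = default;\n};\n\ntemplate<int device_type, typename T>\nstruct MyKernel final : KernelBase {};\n\nconstexpr int kCPU = 0;\nconstexpr int kCUDA = 1;\nconstexpr int kFloat = 2;\n\nKernelBase* CreateMyKernel(int device_type, int data_type) {\n  // The macro builds one such table per registered kernel class.\n  static const std::map<std::pair<int, int>, std::function<KernelBase*()>> creators = {\n      {{kCPU, kFloat}, [] { return new MyKernel<kCPU, float>(); }},\n      {{kCUDA, kFloat}, [] { return new MyKernel<kCUDA, float>(); }},\n  };\n  auto it = creators.find({device_type, data_type});\n  if (it == creators.end()) {\n    std::printf(\"no kernel creator for device=%d dtype=%d\\n\", device_type, data_type);\n    return nullptr;\n  }\n  return it->second();\n}\n\nint main() {\n  delete CreateMyKernel(kCPU, kFloat);\n  return 0;\n}\n"
  },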
  {
    "path": "oneflow/core/kernel/kernel.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/common/data_type.proto\";\nimport \"oneflow/core/common/dtype_signature.proto\";\nimport \"oneflow/core/operator/op_attribute.proto\";\nimport \"oneflow/core/job/placement.proto\";\nimport \"oneflow/core/register/blob_desc.proto\";\n\nmessage DecodeRandomKernelConf {\n  required uint32 random_seed = 1;\n}\n\nmessage ShapeElemCntKernelConf {\n  repeated int32 axis = 1;\n}\n\nmessage UserKernelConf {\n  map<string, BlobDescProto> bn_in_op2blob_desc = 1;\n}\n\nmessage SyncDynamicResizeKernelConf {\n  required DataType size_data_type = 1;\n}\n\nmessage BroadcastToCompatibleWithKernelConf {\n  repeated int64 broadcast_axes = 1;\n}\n\nmessage ImageDecoderRandomCropResizeKernelConf {\n  required int64 seed = 1;\n  required int64 batch_size = 2;\n}\n\nmessage KernelConf {\n  required DataType data_type = 2;\n  required bool all_blobs_are_static = 6;\n  required DTypeSignature dtype_signature = 7;\n  optional ParallelContext parallel_ctx = 8;\n  optional OpAttribute op_attribute = 9;\n  optional string op_attribute_ref = 10;\n\n  oneof kernel_type {\n    UserKernelConf user_conf = 100;\n    DecodeRandomKernelConf decode_random_conf = 103;\n    SyncDynamicResizeKernelConf sync_dynamic_resize_conf = 360;\n\n    ShapeElemCntKernelConf shape_elem_cnt_conf = 412;\n    BroadcastToCompatibleWithKernelConf broadcast_to_compatible_with_conf = 428;\n    ImageDecoderRandomCropResizeKernelConf image_decoder_random_crop_resize_conf = 429;\n  }\n}\n"
  },
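  {
    "path": "editor_examples/kernel_conf_oneof_dispatch_sketch.cpp",
    "content": "// Editor's note: an illustrative sketch only, NOT part of the OneFlow source tree, and it\n// assumes the standard C++ code protoc generates for the oneof in kernel.proto above: a\n// kernel_type_case() accessor plus one enum constant per oneof field. A KernelConf therefore\n// carries at most one kernel_type payload, and dispatch on it looks like this:\n#include \"oneflow/core/kernel/kernel.pb.h\"\n\nconst char* KernelTypeName(const oneflow::KernelConf& conf) {\n  switch (conf.kernel_type_case()) {\n    case oneflow::KernelConf::kUserConf: return \"user_conf\";\n    case oneflow::KernelConf::kDecodeRandomConf: return \"decode_random_conf\";\n    case oneflow::KernelConf::kImageDecoderRandomCropResizeConf:\n      return \"image_decoder_random_crop_resize_conf\";\n    case oneflow::KernelConf::KERNEL_TYPE_NOT_SET: return \"kernel_type not set\";\n    default: return \"other kernel_type\";\n  }\n}\n"
  },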
  {
    "path": "oneflow/core/kernel/kernel_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_KERNEL_CONTEXT_H_\n#define ONEFLOW_CORE_KERNEL_KERNEL_CONTEXT_H_\n\n#include \"oneflow/core/kernel/kernel_observer.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\nclass Blob;\nclass JobDesc;\n\nclass KernelState {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(KernelState);\n  KernelState() = default;\n  virtual ~KernelState() = default;\n};\n\nclass KernelContext : public KernelObserver {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(KernelContext);\n  KernelContext() = default;\n  virtual ~KernelContext() = default;\n\n  virtual ep::Stream* stream() const = 0;\n  virtual Blob* BnInOp2Blob(const std::string& bn) const = 0;\n  virtual const std::shared_ptr<KernelState>& state() const = 0;\n  virtual void set_state(std::shared_ptr<KernelState> state) = 0;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_KERNEL_CONTEXT_H_\n"
  },
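  {
    "path": "editor_examples/kernel_state_downcast_sketch.cpp",
    "content": "// Editor's note: an illustrative sketch only, NOT part of the OneFlow source tree. KernelState\n// in kernel_context.h above is an empty base that lets a kernel stash per-instance data on its\n// context via set_state()/state(); retrieval downcasts back to the concrete type. The trimmed\n// KernelState copy here and MyCounterState are hypothetical stand-ins.\n#include <cassert>\n#include <cstdint>\n#include <memory>\n\nclass KernelState {\n public:\n  KernelState() = default;\n  virtual ~KernelState() = default;\n};\n\nclass MyCounterState final : public KernelState {\n public:\n  int64_t counter = 0;\n};\n\nint main() {\n  // What a kernel would hand to ctx->set_state(...):\n  std::shared_ptr<KernelState> state = std::make_shared<MyCounterState>();\n  // What a later Forward call would do with ctx->state():\n  auto counter_state = std::dynamic_pointer_cast<MyCounterState>(state);\n  assert(counter_state != nullptr);\n  counter_state->counter += 1;\n  return 0;\n}\n"
  },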
  {
    "path": "oneflow/core/kernel/kernel_observer.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_KERNEL_OBSERVER_H_\n#define ONEFLOW_CORE_KERNEL_KERNEL_OBSERVER_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nclass Kernel;\nclass KernelContext;\nclass Blob;\n\nclass KernelObserver {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(KernelObserver);\n  KernelObserver() = default;\n  virtual ~KernelObserver() = default;\n\n  virtual void WillForward(KernelContext* kernel_ctx, const Kernel* kernel) {}\n  virtual void DidForward(KernelContext* kernel_ctx, const Kernel* kernel) {}\n\n  virtual void WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) {}\n  virtual void DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) {}\n\n  virtual void WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) {}\n  virtual void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) {}\n};\n\nclass KernelObserverProvider {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(KernelObserverProvider);\n  KernelObserverProvider() = default;\n  virtual ~KernelObserverProvider() = default;\n\n  virtual KernelObserver* GetKernelObserver() = 0;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_KERNEL_OBSERVER_H_\n"
  },
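  {
    "path": "editor_examples/timing_observer_sketch.cpp",
    "content": "// Editor's note: an illustrative sketch only, NOT part of the OneFlow source tree. The\n// Will*/Did* pairs in kernel_observer.h above bracket each phase of Kernel::Launch, so an\n// observer can, for example, time ForwardDataContent. This standalone analogue uses a\n// trimmed-down observer interface with hypothetical names.\n#include <chrono>\n#include <cstdio>\n\nclass Observer {\n public:\n  virtual ~Observer() = default;\n  virtual void WillForwardDataContent() {}\n  virtual void DidForwardDataContent() {}\n};\n\nclass TimingObserver final : public Observer {\n public:\n  void WillForwardDataContent() override { start_ = std::chrono::steady_clock::now(); }\n  void DidForwardDataContent() override {\n    const auto elapsed = std::chrono::steady_clock::now() - start_;\n    const auto us = std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count();\n    std::printf(\"ForwardDataContent took %lld us\\n\", static_cast<long long>(us));\n  }\n\n private:\n  std::chrono::steady_clock::time_point start_;\n};\n\nint main() {\n  TimingObserver observer;\n  observer.WillForwardDataContent();\n  // ... the kernel body would run here ...\n  observer.DidForwardDataContent();\n  return 0;\n}\n"
  },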
  {
    "path": "oneflow/core/kernel/kernel_registration.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel_registration.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n\nnamespace oneflow {\n\nnamespace kernel_registration {\n\nnamespace {\n\nHashMap<OperatorConf::OpTypeCase, std::vector<KernelRegistryVal>>* MutKernelRegistry() {\n  static HashMap<OperatorConf::OpTypeCase, std::vector<KernelRegistryVal>> creators;\n  return &creators;\n}\n\n}  // namespace\n\nKernelRegistrarBuilder& KernelRegistrarBuilder::SetCreateFn(CreateFn fn) {\n  registry_val_.func = fn;\n  return *this;\n}\n\nKernelRegistrarBuilder& KernelRegistrarBuilder::SetIsMatchedPred(IsMatchedPredicator fn) {\n  registry_val_.cons.SetIsMatchedPred(fn);\n  return *this;\n}\n\nvoid KernelRegistrarBuilder::Finalize(OperatorConf::OpTypeCase* op_type,\n                                      KernelRegistryVal* val) const {\n  *op_type = op_type_;\n  val->func = registry_val_.func;\n  val->cons = registry_val_.cons;\n}\n\nKernelRegistrar::KernelRegistrar(const KernelRegistrarBuilder& builder) {\n  auto* creators = MutKernelRegistry();\n  OperatorConf::OpTypeCase op_type;\n  KernelRegistryVal val;\n  builder.Finalize(&op_type, &val);\n  (*creators)[op_type].emplace_back(std::move(val));\n}\n\nKernel* CreateKernel(const KernelConf& kernel_conf) {\n  auto op_type = kernel_conf.op_attribute().op_conf().op_type_case();\n  auto kernel_registry = MutKernelRegistry();\n  if (kernel_registry->find(op_type) == kernel_registry->end()) { return nullptr; }\n  const auto& registry_vals = kernel_registry->at(op_type);\n\n  Kernel* ret = nullptr;\n  bool is_matched = false;\n  for (const KernelRegistryVal& val : registry_vals) {\n    if (val.cons.IsMatched(kernel_conf)) {\n      CHECK(!is_matched)\n          << \"There are more than one kernel constraints satisfied by kernel conf of \"\n          << static_cast<size_t>(op_type);\n      is_matched = true;\n      ret = val.func();\n    }\n  }\n  // TODO: print more info when failed\n  return ret;\n}\n\n}  // namespace kernel_registration\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/kernel_registration.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_KERNEL_REGISTRATION_H_\n#define ONEFLOW_CORE_KERNEL_KERNEL_REGISTRATION_H_\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/kernel/kernel.pb.h\"\n#include \"oneflow/core/operator/op_conf_util.h\"\n\nnamespace oneflow {\n\nclass Kernel;\n\nnamespace kernel_registration {\n\nusing CreateFn = std::function<Kernel*()>;\nusing IsMatchedPredicator = std::function<bool(const KernelConf&)>;\n\nclass KernelConstraint final {\n public:\n  KernelConstraint() = default;\n  ~KernelConstraint() = default;\n\n  bool IsMatched(const KernelConf& conf) const { return predicator_(conf); }\n  void SetIsMatchedPred(IsMatchedPredicator pred) { predicator_ = pred; }\n\n private:\n  IsMatchedPredicator predicator_;\n};\n\nstruct KernelRegistryVal final {\n  KernelRegistryVal() : func(), cons() {}\n\n  CreateFn func;\n  KernelConstraint cons;\n};\n\nclass KernelRegistrarBuilder final {\n public:\n  explicit KernelRegistrarBuilder(OperatorConf::OpTypeCase op_type)\n      : op_type_(op_type), registry_val_() {}\n  KernelRegistrarBuilder& SetCreateFn(CreateFn fn);\n  KernelRegistrarBuilder& SetIsMatchedPred(IsMatchedPredicator fn);\n\n  void Finalize(OperatorConf::OpTypeCase* op_type, KernelRegistryVal* val) const;\n\n private:\n  OperatorConf::OpTypeCase op_type_;\n  KernelRegistryVal registry_val_;\n};\n\nstruct KernelRegistrar final {\n  KernelRegistrar(const KernelRegistrarBuilder&);\n};\n\nKernel* CreateKernel(const KernelConf& kernel_conf);\n\n}  // namespace kernel_registration\n\n#define NEW_REGISTER_KERNEL(op_type, ...)                                           \\\n  static kernel_registration::KernelRegistrar OF_PP_CAT(g_registrar, __COUNTER__) = \\\n      kernel_registration::KernelRegistrarBuilder(op_type).SetCreateFn(             \\\n          []() { return new __VA_ARGS__(); })\n\n#define REGISTER_KERNEL_WITH_NOTHING(op_type, ...)                                           \\\n  NEW_REGISTER_KERNEL(op_type, __VA_ARGS__).SetIsMatchedPred([](const KernelConf&) -> bool { \\\n    return true;                                                                             \\\n  });\n\n#define REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(op_type, device, dtype, ...)                        
\\\n  NEW_REGISTER_KERNEL(op_type, __VA_ARGS__).SetIsMatchedPred([](const KernelConf& conf) -> bool { \\\n    return (*CHECK_JUST(DeviceTag4DeviceType(device))                                             \\\n            == conf.op_attribute().op_conf().device_tag())                                        \\\n           && (GetDataType<dtype>::value == conf.data_type());                                    \\\n  });\n\n#define REGISTER_KERNEL_WITH_DEVICE(op_type, device, ...)                                         \\\n  NEW_REGISTER_KERNEL(op_type, __VA_ARGS__).SetIsMatchedPred([](const KernelConf& conf) -> bool { \\\n    return (*CHECK_JUST(DeviceTag4DeviceType(device))                                             \\\n            == conf.op_attribute().op_conf().device_tag());                                       \\\n  });\n\n#define REGISTER_KERNEL_HELPER_CPU_FLOATING(op_type, kernel)               \\\n  REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(op_type, DeviceType::kCPU, float,  \\\n                                        kernel<DeviceType::kCPU, float>)   \\\n  REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(op_type, DeviceType::kCPU, double, \\\n                                        kernel<DeviceType::kCPU, double>)\n\n#define REGISTER_KERNEL_HELPER_CUDA_FLOATING(op_type, kernel)               \\\n  REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(op_type, DeviceType::kCUDA, float,  \\\n                                        kernel<DeviceType::kCUDA, float>)   \\\n  REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(op_type, DeviceType::kCUDA, double, \\\n                                        kernel<DeviceType::kCUDA, double>)\n\n#define REGISTER_KERNEL_HELPER_CUDA_HALF(op_type, kernel)                    \\\n  REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(op_type, DeviceType::kCUDA, float16, \\\n                                        kernel<DeviceType::kCUDA, float16>)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_KERNEL_REGISTRATION_H_\n"
  },
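  {
    "path": "editor_examples/kernel_registration_usage_sketch.cpp",
    "content": "// Editor's note: an illustrative usage sketch only, NOT part of the OneFlow source tree. With\n// the macros in kernel_registration.h above, a kernel templated on device and dtype registers\n// one entry per (device, dtype) constraint, and the matching predicate is generated from the\n// macro arguments. MyAddKernel and OperatorConf::kMyAddConf are hypothetical, so the\n// registration line is left commented out; a real OpTypeCase value is required to compile it.\n#include \"oneflow/core/kernel/kernel.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nclass MyAddKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MyAddKernel);\n  MyAddKernel() = default;\n  ~MyAddKernel() override = default;\n\n private:\n  void ForwardDataContent(KernelContext* ctx) const override { /* elementwise add */ }\n};\n\n// Would match KernelConf instances whose device_tag is \"cpu\" and data_type is float:\n// REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(OperatorConf::kMyAddConf, DeviceType::kCPU, float,\n//                                       MyAddKernel<DeviceType::kCPU, float>);\n\n}  // namespace oneflow\n"
  },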
  {
    "path": "oneflow/core/kernel/kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/device_type.pb.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/register/register_manager.h\"\n#include \"oneflow/core/memory/memory_case_util.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n\nnamespace oneflow {\n\nvoid AutoMemcpy(ep::Stream* stream, void* dst, const void* src, size_t sz,\n                const MemoryCase& dst_mem_case, const MemoryCase& src_mem_case) {\n  ep::primitive::MemcpyKind kind{};\n  if (stream->device_type() == DeviceType::kCPU) {\n    CHECK(memory::IsHostMem(src_mem_case));\n    if (dst_mem_case.device_type() != DeviceType::kMeta) { CHECK(memory::IsHostMem(dst_mem_case)); }\n    kind = ep::primitive::MemcpyKind::kDtoD;\n  } else {\n    if (memory::IsHostMem(src_mem_case)) {\n      CHECK(!memory::IsHostMem(dst_mem_case));\n      kind = ep::primitive::MemcpyKind::kHtoD;\n    } else if (memory::IsHostMem(dst_mem_case)) {\n      CHECK(!memory::IsHostMem(src_mem_case));\n      kind = ep::primitive::MemcpyKind::kDtoH;\n    } else {\n      kind = ep::primitive::MemcpyKind::kDtoD;\n    }\n  }\n  std::unique_ptr<ep::primitive::Memcpy> primitive =\n      ep::primitive::NewPrimitive<ep::primitive::MemcpyFactory>(stream->device_type(), kind);\n  CHECK(primitive);\n  primitive->Launch(stream, dst, src, sz);\n}\n\nvoid AutoMemcpy(ep::Stream* stream, Blob* dst, const Blob* src) {\n  const size_t body_bytes = src->ByteSizeOfBlobBody();\n  CHECK_EQ(dst->ByteSizeOfBlobBody(), body_bytes);\n  AutoMemcpy(stream, dst->mut_dptr(), src->dptr(), body_bytes, dst->mem_case(), src->mem_case());\n}\n\nvoid SyncAutoMemcpy(ep::Stream* stream, void* dst, const void* src, size_t sz,\n                    const MemoryCase& dst_mem_case, const MemoryCase& src_mem_case) {\n  AutoMemcpy(stream, dst, src, sz, dst_mem_case, src_mem_case);\n  CHECK_JUST(stream->Sync());\n}\n\nvoid AutoMemset(ep::Stream* stream, void* dst, const char value, size_t sz,\n                const MemoryCase& /*dst_mem_case*/) {\n  std::unique_ptr<ep::primitive::Memset> primitive =\n      ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(stream->device_type());\n  primitive->Launch(stream, dst, value, sz);\n}\n\n}  //  namespace oneflow\n"
  },
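  {
    "path": "editor_examples/memcpy_kind_decision_sketch.cpp",
    "content": "// Editor's note: an illustrative sketch only, NOT part of the OneFlow source tree. It restates\n// the direction choice made by AutoMemcpy in kernel_util.cpp above as a pure function over\n// \"is host memory\" flags so the decision table is easy to read; Kind is a hypothetical\n// stand-in for ep::primitive::MemcpyKind.\n#include <cassert>\n\nenum class Kind { kDtoD, kHtoD, kDtoH };\n\n// device_is_cpu: the stream's device type; src_host/dst_host: where the two buffers live.\nKind ChooseMemcpyKind(bool device_is_cpu, bool src_host, bool dst_host) {\n  if (device_is_cpu) { return Kind::kDtoD; }  // CPU streams treat host-to-host copies as kDtoD\n  if (src_host) { return Kind::kHtoD; }       // host -> device upload\n  if (dst_host) { return Kind::kDtoH; }       // device -> host download\n  return Kind::kDtoD;                         // device -> device\n}\n\nint main() {\n  assert(ChooseMemcpyKind(true, true, true) == Kind::kDtoD);\n  assert(ChooseMemcpyKind(false, true, false) == Kind::kHtoD);\n  assert(ChooseMemcpyKind(false, false, true) == Kind::kDtoH);\n  assert(ChooseMemcpyKind(false, false, false) == Kind::kDtoD);\n  return 0;\n}\n"
  },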
  {
    "path": "oneflow/core/kernel/kernel_util.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_KERNEL_UTIL_CUH_\n#define ONEFLOW_CORE_KERNEL_KERNEL_UTIL_CUH_\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/device/cuda_pseudo_half.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename std::enable_if<IsFloating<T>::value>::type* = nullptr>\nOF_DEVICE_FUNC T MaxWithLogThreshold(T x) {\n  const T threshold = 1e-20;\n  return x > threshold ? x : threshold;\n}\n\ntemplate<typename T, typename std::enable_if<IsIntegral<T>::value>::type* = nullptr>\nOF_DEVICE_FUNC T MaxWithLogThreshold(T x) {\n  return x;\n}\n\n#if defined(__CUDACC__)\n__device__ __forceinline__ half MaxWithLogThreshold(half x) {\n  half threshold = hexp2(__float2half(-14.0));\n  if (__hgt(x, threshold)) { return x; }\n  return threshold;\n}\n#endif\n\ntemplate<typename T>\nOF_DEVICE_FUNC T SafeLog(T x) {\n  return logf(MaxWithLogThreshold(x));\n}\n\n#if defined(__CUDACC__)\n__device__ __forceinline__ half SafeLog(half x) { return hlog(MaxWithLogThreshold(x)); }\n#endif\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_KERNEL_UTIL_CUH_\n"
  },
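  {
    "path": "editor_examples/safe_log_sketch.cpp",
    "content": "// Editor's note: an illustrative sketch only, NOT part of the OneFlow source tree. SafeLog in\n// kernel_util.cuh above clamps its argument before taking the log so log(0) never yields -inf;\n// this host-only float version shows the effect. 1e-20 matches the floating-point threshold\n// above (the half path instead uses 2^-14, the smallest normal half value).\n#include <cmath>\n#include <cstdio>\n\nfloat MaxWithLogThreshold(float x) {\n  const float threshold = 1e-20f;\n  return x > threshold ? x : threshold;\n}\n\nfloat SafeLog(float x) { return std::log(MaxWithLogThreshold(x)); }\n\nint main() {\n  std::printf(\"std::log(0.0f) -> %f\\n\", std::log(0.0f));  // -inf\n  std::printf(\"SafeLog(0.0f)  -> %f\\n\", SafeLog(0.0f));   // log(1e-20) ~= -46.05\n  return 0;\n}\n"
  },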
  {
    "path": "oneflow/core/kernel/kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_KERNEL_UTIL_H_\n#define ONEFLOW_CORE_KERNEL_KERNEL_UTIL_H_\n\n#include \"oneflow/core/common/blas.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/device/cudnn_util.h\"\n#include \"oneflow/core/kernel/kernel_context.h\"\n#include \"oneflow/core/common/switch_func.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/register/blob.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\nclass Blob;\nclass MemoryCase;\nclass StreamContext;\n\nvoid AutoMemcpy(ep::Stream* stream, void* dst, const void* src, size_t sz,\n                const MemoryCase& dst_mem_case, const MemoryCase& src_mem_case);\nvoid AutoMemcpy(ep::Stream* stream, Blob* dst, const Blob* src);\nvoid SyncAutoMemcpy(ep::Stream* stream, void* dst, const void* src, size_t sz,\n                    const MemoryCase& dst_mem_case, const MemoryCase& src_mem_case);\nvoid AutoMemset(ep::Stream* stream, void* dst, const char value, size_t sz,\n                const MemoryCase& dst_mem_case);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/kernel/learning_rate_schedule_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <sys/types.h>\n#include <unistd.h>\n\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/job/graph_scope_vars.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n\nnamespace oneflow {\n\nclass LearningRateScheduleKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LearningRateScheduleKernel);\n  LearningRateScheduleKernel() = default;\n  ~LearningRateScheduleKernel() override = default;\n\n private:\n  void VirtualKernelInit(KernelContext* ctx) override {\n    if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {\n      pid_t pid = getpid();\n      log_stream_ = TeePersistentLogStream::Create(std::to_string(pid) + \"-train_step2lr.csv\");\n      (*log_stream_) << \"train_step, lr\\n\";\n    }\n    if (IsOpenGraphVerboseStepLr()) { print_step_lr_ = true; }\n  }\n\n  void ForwardDataContent(KernelContext* ctx) const override;\n  bool print_step_lr_ = false;\n  std::unique_ptr<TeePersistentLogStream> log_stream_;\n};\n\nnamespace {\n\ndouble GetDecayedLearningRate(const LearningRateDecayConf& conf, double base_lr, int64_t step);\n\ndouble ConstantLearningRate(double base_lr, double factor, int64_t total_step, int64_t cur_step) {\n  CHECK_GE(total_step, 0);\n  CHECK_GE(factor, 0.0);\n  CHECK_LE(factor, 1.0);\n  if (cur_step < total_step) { return base_lr * factor; }\n  return base_lr;\n}\n\ndouble LinearLearningRate(double base_lr, double start_factor, double end_factor,\n                          int64_t total_step, int64_t cur_step) {\n  CHECK_GE(total_step, 0);\n  CHECK_GE(start_factor, 0.0);\n  CHECK_LE(start_factor, 1.0);\n  CHECK_GE(end_factor, 0.0);\n  CHECK_LE(end_factor, 1.0);\n  double multiplier = end_factor;\n  double c_step_f = float(cur_step);\n  double t_step_f = float(total_step);\n  if (cur_step < total_step) {\n    multiplier = start_factor + (end_factor - start_factor) * (c_step_f / t_step_f);\n  }\n  return base_lr * multiplier;\n}\n\ndouble ExponentialDecayedLearningRate(const ExponentialDecayConf& conf, double lr,\n                                      int64_t cur_batch_num) {\n  CHECK_GT(conf.decay_batches(), 0);\n  double p = static_cast<double>(cur_batch_num) / static_cast<double>(conf.decay_batches());\n  if (conf.staircase()) { p = std::floor(p); }\n  return lr * std::pow(conf.decay_rate(), p);\n}\n\ndouble InverseTimeDecayedLearningRate(const InverseTimeDecayConf& conf, double lr,\n                                      int64_t cur_batch_num) {\n  CHECK_GT(conf.decay_batches(), 0);\n  double p = static_cast<double>(cur_batch_num) / static_cast<double>(conf.decay_batches());\n  if (conf.staircase()) { p = std::floor(p); }\n  return lr / (1.0 + conf.decay_rate() * p);\n}\n\ndouble NaturalExpDecayedLearningRate(const NaturalExpDecayConf& conf, double lr,\n                                     
int64_t cur_batch_num) {\n  CHECK_GT(conf.decay_batches(), 0);\n  double p = static_cast<double>(cur_batch_num) / static_cast<double>(conf.decay_batches());\n  if (conf.staircase()) { p = std::floor(p); }\n  return lr * std::exp(-conf.decay_rate() * p);\n}\n\ndouble PiecewiseConstantLearningRate(const PiecewiseConstantConf& conf, double lr,\n                                     int64_t cur_batch_num) {\n  const PbRf<int64_t>& boundaries = conf.boundaries();\n  const PbRf<double>& values = conf.values();\n  CHECK_EQ(boundaries.size() + 1, values.size());\n  size_t i = 0;\n  for (; i < boundaries.size(); ++i) {\n    if (cur_batch_num <= boundaries[i]) { break; }\n  }\n  return values[i];\n}\n\ndouble PolynomialDecayedLearningRate(const PolynomialDecayConf& conf, double lr,\n                                     int64_t cur_batch_num) {\n  CHECK_GT(conf.decay_batches(), 0);\n  double cur_batch = static_cast<double>(cur_batch_num);\n  double decay_batches = static_cast<double>(conf.decay_batches());\n  if (conf.cycle()) {\n    if (cur_batch_num == 0) { cur_batch = 1.0; }\n    decay_batches = decay_batches * std::ceil(cur_batch / decay_batches);\n  } else {\n    cur_batch = std::min(cur_batch, decay_batches);\n  }\n  return (lr - conf.end_learning_rate()) * std::pow(1.0 - (cur_batch / decay_batches), conf.power())\n         + conf.end_learning_rate();\n}\n\ndouble CosineDecayedLearningRate(const CosineDecayConf& conf, double lr, int64_t cur_batch_num) {\n  CHECK_GT(conf.decay_batches(), 0);\n  const double PI = std::atan(1.0) * 4.0;\n  double cur_batch = static_cast<double>(cur_batch_num);\n  double decay_batches = static_cast<double>(conf.decay_batches());\n  cur_batch = std::min(cur_batch, decay_batches);\n  double cosine_decay = 0.5 * (1.0 + std::cos(PI * cur_batch / decay_batches));\n  double decayed = (1.0 - conf.alpha()) * cosine_decay + conf.alpha();\n  return lr * decayed;\n}\n\ndouble CosineAnnealingDecayedLearningRate(const CosineAnnealingDecayConf& conf, double lr,\n                                          int64_t cur_batch_num) {\n  CHECK_GT(conf.t_max(), 0);\n  if (0 == cur_batch_num) { return lr; }\n\n  const double PI = std::atan(1.0) * 4.0;\n  const double eta_min = conf.eta_min();\n  CHECK_LT(eta_min, lr);\n  const double t_max_d = static_cast<double>(conf.t_max());\n  const double cur_batch_num_d = static_cast<double>(cur_batch_num);\n\n  return eta_min + (((lr - eta_min) * (1 + std::cos(PI * (cur_batch_num_d / t_max_d)))) / 2);\n}\n\ndouble LinearCosineDecayedLearningRate(const LinearCosineDecayConf& conf, double lr,\n                                       int64_t cur_batch_num) {\n  CHECK_GT(conf.decay_batches(), 0);\n  const double PI = std::atan(1.0) * 4.0;\n  double cur_batch = static_cast<double>(cur_batch_num);\n  double decay_batches = static_cast<double>(conf.decay_batches());\n  cur_batch = std::min(cur_batch, decay_batches);\n  double linear_decay = (decay_batches - cur_batch) / decay_batches;\n  double cosine_decay =\n      0.5 * (1.0 + std::cos(PI * 2.0 * conf.num_periods() * cur_batch / decay_batches));\n  double decayed = (conf.alpha() + linear_decay) * cosine_decay + conf.beta();\n  return lr * decayed;\n}\n\ndouble PiecewiseScalingLearningRate(const PiecewiseScalingConf& conf, double lr,\n                                    int64_t cur_batch_num) {\n  const PbRf<int64_t>& boundaries = conf.boundaries();\n  const PbRf<double>& scales = conf.scales();\n  CHECK_EQ(boundaries.size() + 1, scales.size());\n  size_t i = 0;\n  for (; i < boundaries.size(); ++i) {\n    
if (cur_batch_num <= boundaries[i]) { break; }\n  }\n  return scales[i] * lr;\n}\n\ndouble StepLearningRate(const StepConf& conf, double lr, int64_t cur_batch_num) {\n  const int64_t step_size = conf.step_size();\n  CHECK_GE(step_size, 1);\n  const double gamma = conf.gamma();\n\n  double cur_batch = static_cast<double>(cur_batch_num);\n  double step = static_cast<double>(step_size);\n  size_t i = std::floor(cur_batch / step);\n\n  return lr * std::pow(gamma, i);\n}\n\ndouble MultiStepLearningRate(const MultiStepConf& conf, double lr, int64_t cur_batch_num) {\n  const PbRf<int64_t>& milestones = conf.milestones();\n  CHECK_GE(milestones.size(), 1);\n  const double gamma = conf.gamma();\n\n  size_t i = 0;\n  if (cur_batch_num < milestones[milestones.size() - 1]) {\n    for (; i < milestones.size(); ++i) {\n      if (cur_batch_num < milestones[i]) { break; }\n    }\n  } else {\n    i = milestones.size();\n  }\n\n  return lr * std::pow(gamma, i);\n}\n\ndouble CosineAnnealingWarmRestartsLearningRate(const CosineAnnealingWarmRestartsConf& conf,\n                                               const double base_lr, const int64_t step) {\n  int64_t epoch_steps = conf.t_initial();\n  int64_t epoch = step / epoch_steps;\n  int64_t step_in_epoch = step - (epoch_steps * epoch);\n  if (conf.t_mult() > 1) {\n    epoch = static_cast<int64_t>(std::floor(\n        std::log(1 - step / conf.t_initial() * (1 - conf.t_mult())) / std::log(conf.t_mult())));\n    int64_t interval = std::pow(conf.t_mult(), epoch);\n    epoch_steps = interval * conf.t_initial();\n    step_in_epoch = step\n                    - static_cast<int64_t>(std::floor(static_cast<double>(1 - interval)\n                                                      / (1 - conf.t_mult()) * conf.t_initial()));\n  }\n  double lr = conf.eta_min();\n  if (conf.restart_limit() == 0 || (conf.restart_limit() > 0 && epoch < conf.restart_limit())) {\n    double gamma = std::pow(conf.decay_rate(), epoch);\n    lr = lr + 0.5 * (base_lr * gamma - lr) * (1 + std::cos(M_PI * step_in_epoch / epoch_steps));\n  }\n  return lr;\n}\n\ndouble SequentialScheduler(const SequentialSchedulerConf& conf, const double base_lr,\n                           const int64_t step) {\n  CHECK_GE(conf.schedulers_size(), 1);\n  CHECK_EQ(conf.milestones_size(), conf.schedulers_size() - 1);\n  CHECK_EQ(conf.interval_rescaling_size(), conf.milestones_size());\n\n  int64_t cur_step = step;\n  size_t scheduler_idx = 0;\n  for (size_t i = 0; i < conf.milestones_size(); ++i) {\n    if (step < conf.milestones(i)) {\n      break;\n    } else {\n      if (conf.interval_rescaling(i)) { cur_step = step - conf.milestones(i); }\n      scheduler_idx++;\n    }\n  }\n  return GetDecayedLearningRate(conf.schedulers(scheduler_idx), base_lr, cur_step);\n}\n\ndouble GetDecayedLearningRate(const LearningRateDecayConf& conf, double lr, int64_t cur_batch_num) {\n  if (conf.has_exponential_conf()) {\n    return ExponentialDecayedLearningRate(conf.exponential_conf(), lr, cur_batch_num);\n  } else if (conf.has_inverse_time_conf()) {\n    return InverseTimeDecayedLearningRate(conf.inverse_time_conf(), lr, cur_batch_num);\n  } else if (conf.has_natural_exp_conf()) {\n    return NaturalExpDecayedLearningRate(conf.natural_exp_conf(), lr, cur_batch_num);\n  } else if (conf.has_piecewise_constant_conf()) {\n    return PiecewiseConstantLearningRate(conf.piecewise_constant_conf(), lr, cur_batch_num);\n  } else if (conf.has_polynomial_conf()) {\n    return PolynomialDecayedLearningRate(conf.polynomial_conf(), lr, 
cur_batch_num);\n  } else if (conf.has_cosine_conf()) {\n    return CosineDecayedLearningRate(conf.cosine_conf(), lr, cur_batch_num);\n  } else if (conf.has_cosine_annealing_conf()) {\n    return CosineAnnealingDecayedLearningRate(conf.cosine_annealing_conf(), lr, cur_batch_num);\n  } else if (conf.has_linear_cosine_conf()) {\n    return LinearCosineDecayedLearningRate(conf.linear_cosine_conf(), lr, cur_batch_num);\n  } else if (conf.has_piecewise_scaling_conf()) {\n    return PiecewiseScalingLearningRate(conf.piecewise_scaling_conf(), lr, cur_batch_num);\n  } else if (conf.has_step_conf()) {\n    return StepLearningRate(conf.step_conf(), lr, cur_batch_num);\n  } else if (conf.has_multi_step_conf()) {\n    return MultiStepLearningRate(conf.multi_step_conf(), lr, cur_batch_num);\n  } else if (conf.has_constant_lr_conf()) {\n    return ConstantLearningRate(lr, conf.constant_lr_conf().factor(),\n                                conf.constant_lr_conf().total_iters(), cur_batch_num);\n  } else if (conf.has_linear_lr_conf()) {\n    return LinearLearningRate(lr, conf.linear_lr_conf().start_factor(),\n                              conf.linear_lr_conf().end_factor(),\n                              conf.linear_lr_conf().total_iters(), cur_batch_num);\n  } else if (conf.has_cosine_annealing_warm_restarts_conf()) {\n    return CosineAnnealingWarmRestartsLearningRate(conf.cosine_annealing_warm_restarts_conf(), lr,\n                                                   cur_batch_num);\n  } else if (conf.has_sequential_scheduler_conf()) {\n    return SequentialScheduler(conf.sequential_scheduler_conf(), lr, cur_batch_num);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\n}  // namespace\n\nvoid LearningRateScheduleKernel::ForwardDataContent(KernelContext* ctx) const {\n  const LearningRateScheduleOpConf& conf = this->op_conf().learning_rate_schedule_conf();\n  const int64_t train_step = *ctx->BnInOp2Blob(\"train_step\")->dptr<int64_t>();\n  float learning_rate = conf.learning_rate();\n  if (conf.has_learning_rate_decay()) {\n    learning_rate = GetDecayedLearningRate(conf.learning_rate_decay(), learning_rate, train_step);\n  }\n  // NOTE(lixiang): Setting verbose=True prints the step and the adjusted learning rate.\n  if (unlikely(print_step_lr_)) {\n    std::cout << \"Last step \" << train_step << \" adjusting learning rate to \" << learning_rate\n              << std::endl;\n  }\n  *ctx->BnInOp2Blob(\"out\")->mut_dptr<float>() = learning_rate;\n  if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {\n    (*log_stream_) << std::to_string(train_step) << \", \" << std::to_string(learning_rate) << \"\\n\";\n    log_stream_->Flush();\n  }\n}\n\nREGISTER_KERNEL(OperatorConf::kLearningRateScheduleConf, LearningRateScheduleKernel);\n\n}  // namespace oneflow\n"
  },
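  {
    "path": "editor_examples/lr_schedule_worked_example.cpp",
    "content": "// Editor's note: an illustrative sketch only, NOT part of the OneFlow source tree. It\n// reimplements two of the schedules from learning_rate_schedule_kernel.cpp above standalone so\n// the formulas can be checked with concrete numbers.\n#include <cmath>\n#include <cstdint>\n#include <cstdio>\n#include <vector>\n\n// lr * decay_rate^(step / decay_batches); staircase floors the exponent.\ndouble ExponentialDecay(double lr, double decay_rate, double decay_batches, int64_t step,\n                        bool staircase) {\n  double p = static_cast<double>(step) / decay_batches;\n  if (staircase) { p = std::floor(p); }\n  return lr * std::pow(decay_rate, p);\n}\n\n// scales[i] applies until boundaries[i] is passed; the final scale applies afterwards, so\n// scales must hold exactly one more entry than boundaries.\ndouble PiecewiseScaling(double lr, const std::vector<int64_t>& boundaries,\n                        const std::vector<double>& scales, int64_t step) {\n  size_t i = 0;\n  for (; i < boundaries.size(); ++i) {\n    if (step <= boundaries[i]) { break; }\n  }\n  return scales[i] * lr;\n}\n\nint main() {\n  // 0.1 * 0.9^(1000 / 500) = 0.1 * 0.81 = 0.081\n  std::printf(\"%f\\n\", ExponentialDecay(0.1, 0.9, 500.0, 1000, false));\n  // step 150 lies in (100, 200], so the second scale applies: 0.1 * 0.1 = 0.01\n  std::printf(\"%f\\n\", PiecewiseScaling(0.1, {100, 200}, {1.0, 0.1, 0.01}, 150));\n  return 0;\n}\n"
  },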
  {
    "path": "oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/job/eager_nccl_comm_manager.h\"\n#include \"oneflow/core/register/tensor_slice_copier.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/include/primitive/add.h\"\n#include \"oneflow/core/operator/nccl_send_recv_boxing_op_util.h\"\n#include \"oneflow/user/kernels/collective_communication/include/all_to_all.h\"\n\n#if (defined(WITH_CUDA) && (NCCL_VERSION_CODE > 2700)) || defined(WITH_NPU) || defined(WITH_MLU)\n\nnamespace oneflow {\n\nclass CclSendRecvBoxingKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CclSendRecvBoxingKernel);\n  CclSendRecvBoxingKernel() = default;\n  ~CclSendRecvBoxingKernel() override = default;\n\n  const std::vector<std::shared_ptr<TensorSliceCopier>>& in_tensor_slice_copier_vec() const {\n    return in_tensor_slice_copier_vec_;\n  }\n  const std::vector<std::shared_ptr<TensorSliceCopier>>& out_tensor_slice_copier_vec() const {\n    return out_tensor_slice_copier_vec_;\n  }\n  const std::vector<uint64_t>& send_elem_cnts() const { return send_elem_cnts_; }\n  const std::vector<uint64_t>& recv_elem_cnts() const { return recv_elem_cnts_; }\n  const bool has_input() const { return has_input_; }\n  const bool has_output() const { return has_output_; }\n  ccl::CclComm ccl_comm() const { return GetOrCreate().ccl_comm; }\n\n private:\n  struct Comm {\n    explicit Comm(ccl::CclComm comm) : ccl_comm(comm) {}\n    ccl::CclComm ccl_comm;\n  };\n\n  void Init() const {\n    ParallelDesc parallel_desc(parallel_conf_);\n    EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    ccl::CclComm ccl_comm =\n        comm_mgr->GetCclCommForParallelDescAndStreamName(parallel_desc, stream_name_);\n    ccl_comm_.reset(new Comm(ccl_comm));\n  }\n\n  const Comm& GetOrCreate() const {\n    if (!ccl_comm_) { Init(); }\n    return *ccl_comm_;\n  }\n\n  void VirtualKernelInit(KernelContext* ctx) override;\n  void ForwardDataContent(KernelContext* ctx) const override;\n\n  std::string stream_name_;\n  ParallelConf parallel_conf_;\n  mutable std::unique_ptr<Comm> ccl_comm_;\n  bool src_nd_sbp_no_partial_parallel_;\n  std::vector<std::shared_ptr<TensorSliceCopier>> in_tensor_slice_copier_vec_;\n  std::vector<std::shared_ptr<TensorSliceCopier>> out_tensor_slice_copier_vec_;\n  std::vector<uint64_t> send_elem_cnts_;\n  std::vector<uint64_t> recv_elem_cnts_;\n  bool has_input_;\n  bool has_output_;\n};\n\nvoid CclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const {\n  Blob* buf = ctx->BnInOp2Blob(\"buf\");\n  ccl::CclComm ccl_comm = this->ccl_comm();\n  const std::vector<uint64_t>& send_elem_cnts = this->send_elem_cnts();\n  const std::vector<uint64_t>& recv_elem_cnts = this->recv_elem_cnts();\n  const int64_t parallel_num = this->kernel_conf().parallel_ctx().parallel_num();\n  
const DataType data_type = buf->data_type();\n  const size_t dtype_size = GetSizeOfDataType(data_type);\n  std::vector<void*> send_in_ptr;\n  std::vector<void*> recv_out_ptr;\n  std::vector<uint64_t> send_offsets;\n  std::vector<uint64_t> recv_offsets;\n  char* buf_ptr = buf->mut_dptr<char>();\n  uint64_t offset = 0;\n  if (this->has_input()) {\n    for (int64_t i = 0; i < parallel_num; ++i) {\n      void* send_ptr = reinterpret_cast<void*>(buf_ptr + offset);\n      send_in_ptr.push_back(send_ptr);\n      send_offsets.push_back(offset);\n      offset += send_elem_cnts.at(i) * dtype_size;\n    }\n  }\n  const uint64_t recv_offset = offset;\n  if (this->has_output()) {\n    for (int64_t i = 0; i < parallel_num; ++i) {\n      void* recv_ptr = reinterpret_cast<void*>(buf_ptr + offset);\n      recv_out_ptr.push_back(recv_ptr);\n      recv_offsets.push_back(offset - recv_offset);\n      offset += recv_elem_cnts.at(i) * dtype_size;\n    }\n  }\n  if (this->has_input()) {\n    const Blob* in = ctx->BnInOp2Blob(\"in\");\n    const std::vector<std::shared_ptr<TensorSliceCopier>>& in_tensor_slice_copier_vec =\n        this->in_tensor_slice_copier_vec();\n    for (int64_t i = 0; i < parallel_num; ++i) {\n      if (in_tensor_slice_copier_vec.at(i)) {\n        in_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), send_in_ptr.at(i), in->dptr());\n      }\n    }\n  }\n\n  if (this->has_input() || this->has_output()) {\n    std::unique_ptr<ccl::AllToAll> all_to_all = ccl::NewCollectiveCommunication<ccl::AllToAll>(\n        ctx->stream()->device_type(), data_type, data_type, parallel_num);\n    void* send_buf = reinterpret_cast<void*>(buf_ptr);\n    void* recv_buf = reinterpret_cast<void*>(buf_ptr + recv_offset);\n    all_to_all->Launch(ctx->stream(), send_buf, send_elem_cnts.data(), send_offsets.data(),\n                       recv_buf, recv_elem_cnts.data(), recv_offsets.data(), ccl_comm,\n                       this->has_input(), this->has_output());\n  }\n\n  if (!this->has_output()) { return; }\n  Blob* out = ctx->BnInOp2Blob(\"out\");\n  const std::vector<std::shared_ptr<TensorSliceCopier>>& out_tensor_slice_copier_vec =\n      this->out_tensor_slice_copier_vec();\n\n  if (src_nd_sbp_no_partial_parallel_) {\n    for (int64_t i = 0; i < parallel_num; ++i) {\n      if (out_tensor_slice_copier_vec.at(i)) {\n        out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out->mut_dptr(), recv_out_ptr.at(i));\n      }\n    }\n  } else {\n    std::unique_ptr<ep::primitive::Add> add_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::AddFactory>(ctx->stream()->device_type(),\n                                                               out->data_type());\n    CHECK(add_primitive);\n    std::unique_ptr<ep::primitive::Memset> memset_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->stream()->device_type());\n    CHECK(memset_primitive);\n    bool is_first_slice = true;\n    for (int64_t i = 0; i < parallel_num; ++i) {\n      if (out_tensor_slice_copier_vec.at(i)) {\n        if (is_first_slice) {\n          is_first_slice = false;\n          if (recv_elem_cnts.at(i) != out->shape().elem_cnt()) {\n            // if not same shape, memset out\n            memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0,\n                                     out->shape().elem_cnt() * dtype_size);\n          }\n          out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out->mut_dptr(),\n                                                  recv_out_ptr.at(i));\n        } 
else {\n          if (recv_elem_cnts.at(i) == out->shape().elem_cnt()) {\n            add_primitive->Launch(ctx->stream(), out->dptr(), recv_out_ptr.at(i), out->mut_dptr(),\n                                  out->shape().elem_cnt());\n          } else {\n            void* out_buf = reinterpret_cast<void*>(buf_ptr + offset);\n            memset_primitive->Launch(ctx->stream(), out_buf, 0,\n                                     out->shape().elem_cnt() * dtype_size);\n            out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out_buf, recv_out_ptr.at(i));\n            add_primitive->Launch(ctx->stream(), out->dptr(), out_buf, out->mut_dptr(),\n                                  out->shape().elem_cnt());\n          }\n        }\n      }\n    }\n  }\n}\n\nvoid CclSendRecvBoxingKernel::VirtualKernelInit(KernelContext* ctx) {\n  const NcclSendRecvBoxingOpConf& conf = this->op_conf().nccl_send_recv_boxing_conf();\n  if (this->op_conf().has_stream_name_hint()) {\n    stream_name_ = this->op_conf().stream_name_hint();\n  } else {\n    stream_name_ = EagerCclCommMgr::kDefaultCclStreamName;\n  }\n  parallel_conf_ = conf.parallel_conf();\n  const int64_t parallel_id = this->kernel_conf().parallel_ctx().parallel_id();\n  ParallelDesc parallel_desc(parallel_conf_);\n  ParallelDesc src_parallel_desc(conf.src_parallel_conf());\n  ParallelDesc dst_parallel_desc(conf.dst_parallel_conf());\n  const NdSbp& src_nd_sbp = conf.src_nd_sbp();\n  const NdSbp& dst_nd_sbp = conf.dst_nd_sbp();\n  has_input_ = conf.has_input();\n  has_output_ = conf.has_output();\n  src_nd_sbp_no_partial_parallel_ = !NdSbpHasPartialParallel(src_nd_sbp);\n  const DataType data_type = this->kernel_conf().data_type();\n  const DeviceType device_type = parallel_desc.device_type();\n  const Shape& logical_shape = Shape(conf.logical_shape());\n  const int64_t parallel_num = parallel_desc.parallel_num();\n\n  std::vector<TensorSliceView> src_send_intersections;\n  std::vector<TensorSliceView> dst_recv_intersections;\n  GetRankSendRecvIntersection(parallel_id, parallel_desc, src_parallel_desc, dst_parallel_desc,\n                              src_nd_sbp, dst_nd_sbp, logical_shape, &src_send_intersections,\n                              &dst_recv_intersections);\n  // if parallel_id exists in src parallel desc, has send\n  int64_t src_parallel_id = GetMappedParallelId(parallel_id, parallel_desc, src_parallel_desc);\n  if (src_parallel_id != -1) {\n    CHECK_EQ(src_send_intersections.size(), parallel_num);\n    send_elem_cnts_.resize(parallel_num);\n    in_tensor_slice_copier_vec_.resize(parallel_num);\n    const TensorSliceView& cur_rank_in_slice = GetTensorSliceView4ParallelId(\n        *src_parallel_desc.hierarchy(), src_nd_sbp, logical_shape, src_parallel_id);\n    for (int64_t i = 0; i < parallel_num; ++i) {\n      const TensorSliceView& intersection = src_send_intersections.at(i);\n      if (!intersection.IsEmpty()) {\n        send_elem_cnts_.at(i) = intersection.shape().elem_cnt();\n        in_tensor_slice_copier_vec_.at(i).reset(\n            new TensorSliceCopier(intersection, cur_rank_in_slice, data_type, device_type));\n      }\n    }\n  } else {\n    CHECK_EQ(src_send_intersections.size(), 0);\n  }\n\n  // if parallel_id exists in dst parallel desc, has recv\n  int64_t dst_parallel_id = GetMappedParallelId(parallel_id, parallel_desc, dst_parallel_desc);\n  if (dst_parallel_id != -1) {\n    CHECK_EQ(dst_recv_intersections.size(), parallel_num);\n    recv_elem_cnts_.resize(parallel_num);\n    
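// receive-side mirror of the send-side setup above: build one TensorSliceCopier per rank\n    // whose receive intersection with this rank's output slice is non-empty\n    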
out_tensor_slice_copier_vec_.resize(parallel_num);\n    const TensorSliceView& cur_rank_out_slice = GetTensorSliceView4ParallelId(\n        *dst_parallel_desc.hierarchy(), dst_nd_sbp, logical_shape, dst_parallel_id);\n    for (int64_t i = 0; i < parallel_num; ++i) {\n      const TensorSliceView& intersection = dst_recv_intersections.at(i);\n      if (!intersection.IsEmpty()) {\n        recv_elem_cnts_.at(i) = intersection.shape().elem_cnt();\n        out_tensor_slice_copier_vec_.at(i).reset(\n            new TensorSliceCopier(cur_rank_out_slice, intersection, data_type, device_type));\n      }\n    }\n  } else {\n    CHECK_EQ(dst_recv_intersections.size(), 0);\n  }\n}\n\n// TODO: replace all kNcclxxxConf with kCclxxxConf(for multi devices)\nREGISTER_KERNEL(OperatorConf::kNcclSendRecvBoxingConf, CclSendRecvBoxingKernel);\n\nREGISTER_SYSTEM_OP_KERNEL_UNIFIED_CCL_COMM_INIT(OperatorConf::kNcclSendRecvBoxingConf);\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA || WITH_NPU || WITH_MLU\n"
  },
  {
    "path": "oneflow/core/kernel/new_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_NEW_KERNEL_UTIL_H_\n#define ONEFLOW_CORE_KERNEL_NEW_KERNEL_UTIL_H_\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass Stream;\n\n}\n\ntemplate<DeviceType device_type>\nvoid Memcpy(ep::Stream* stream, void* dst, const void* src, size_t sz) {\n  CHECK_EQ(device_type, stream->device_type()) << \"Device type mismatch\";\n  std::unique_ptr<ep::primitive::Memcpy> primitive =\n      ep::primitive::NewPrimitive<ep::primitive::MemcpyFactory>(stream->device_type(),\n                                                                ep::primitive::MemcpyKind::kDtoD);\n  CHECK(primitive) << \"Can not create Memcpy primitive for device type \" << device_type;\n  primitive->Launch(stream, dst, src, sz);\n}\n\ntemplate<DeviceType device_type>\nvoid Memset(ep::Stream* stream, void* dst, const char value, size_t sz) {\n  CHECK_EQ(device_type, stream->device_type()) << \"Device type mismatch\";\n  std::unique_ptr<ep::primitive::Memset> primitive =\n      ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(stream->device_type());\n  CHECK(primitive) << \"Can not create Memset primitive for device type \" << device_type;\n  primitive->Launch(stream, dst, value, sz);\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_NEW_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/kernel/nop_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n\nnamespace oneflow {\n\nclass NopKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NopKernel);\n  NopKernel() = default;\n  ~NopKernel() = default;\n\n private:\n  void ForwardDataContent(KernelContext* ctx) const override {}\n};\n\nREGISTER_KERNEL(OperatorConf::kVariableConf, NopKernel);\nREGISTER_KERNEL(OperatorConf::kTickConf, NopKernel);\nREGISTER_KERNEL(OperatorConf::kSinkTickConf, NopKernel);\nREGISTER_KERNEL(OperatorConf::kAccTickConf, NopKernel);\nREGISTER_KERNEL(OperatorConf::kCopyCommNetConf, NopKernel);\nREGISTER_KERNEL(OperatorConf::kDeviceTickConf, NopKernel);\nREGISTER_KERNEL(OperatorConf::kDstSubsetTickConf, NopKernel);\nREGISTER_KERNEL(OperatorConf::kSourceTickConf, NopKernel);\nREGISTER_KERNEL(OperatorConf::kSrcSubsetTickConf, NopKernel);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/output_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/common/buffer_manager.h\"\n#include \"oneflow/core/job/critical_section_instance.h\"\n#include \"oneflow/core/job/global_for.h\"\n\nnamespace oneflow {\n\nclass OutputKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(OutputKernel);\n  OutputKernel() = default;\n  ~OutputKernel() = default;\n\n private:\n  void ForwardDataContent(KernelContext* ctx) const override;\n  void ForwardHeader(KernelContext* ctx) const override;\n};\n\nvoid OutputKernel::ForwardDataContent(KernelContext* ctx) const {\n  CHECK(this->op_conf().output_conf().has_job_name());\n  const auto& job_name = this->op_conf().output_conf().job_name();\n  const auto& op_name = this->op_conf().name();\n  auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::Get();\n  auto* buffer = buffer_mgr->Get(GetOutputBufferName(job_name, op_name));\n  std::shared_ptr<CriticalSectionInstance> critical_section_instance;\n  BufferStatus buffer_status = buffer->TryReceive(&critical_section_instance);\n  CHECK_NE(buffer_status, kBufferStatusEmpty);\n  if (buffer_status == kBufferStatusSuccess) {\n    critical_section_instance->AccessBlobByOpName(ctx->stream(), ctx->BnInOp2Blob(\"in\"), op_name);\n  }\n}\n\nvoid OutputKernel::ForwardHeader(KernelContext* ctx) const {\n  // Do nothing.\n}\n\nREGISTER_KERNEL(OperatorConf::kOutputConf, OutputKernel);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/profiler_kernel_observer.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/profiler_kernel_observer.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"oneflow/core/profiler/kernel.h\"\n\nnamespace oneflow {\n\nvoid ProfilerKernelObserver::WillForwardDataContent(KernelContext* kernel_ctx,\n                                                    const Kernel* kernel) {\n  OF_PROFILER_ONLY_CODE(profiler::TraceKernelForwardDataContentStart(kernel_ctx, kernel));\n}\n\nvoid ProfilerKernelObserver::DidForwardDataContent(KernelContext* kernel_ctx,\n                                                   const Kernel* kernel) {\n  OF_PROFILER_ONLY_CODE(profiler::TraceKernelForwardDataContentEnd(kernel_ctx, kernel));\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/profiler_kernel_observer.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_PROFILER_KERNEL_OBSERVER_H_\n#define ONEFLOW_CORE_KERNEL_PROFILER_KERNEL_OBSERVER_H_\n\n#include \"oneflow/core/kernel/kernel_observer.h\"\n\nnamespace oneflow {\n\nclass ProfilerKernelObserver final : public KernelObserver {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ProfilerKernelObserver);\n  ProfilerKernelObserver() = default;\n  ~ProfilerKernelObserver() override = default;\n\n  void WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override;\n  void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_PROFILER_KERNEL_OBSERVER_H_\n"
  },
  {
    "path": "oneflow/core/kernel/random_generator.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/random_generator.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nvoid RandomGenerator<DeviceType::kCPU>::Uniform(const int64_t elem_cnt, T* dptr) {\n  Uniform(elem_cnt, GetZeroVal<T>(), GetOneVal<T>(), dptr);\n}\n\ntemplate<typename T>\nvoid RandomGenerator<DeviceType::kCPU>::Uniform(const int64_t elem_cnt, const T min, const T max,\n                                                T* dptr) {\n  CHECK_GE(elem_cnt, 0);\n  CHECK(dptr);\n  CHECK_LE(min, max);\n  std::uniform_real_distribution<T> random_distribution(min, std::nextafter(max, GetMaxVal<T>()));\n  for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = random_distribution(mt19937_generator_); }\n}\n\n#define INITIATE_CPU_RANDOM_GENERATOR_UNIFORM(T, typeproto)                                        \\\n  template void RandomGenerator<DeviceType::kCPU>::Uniform<T>(const int64_t elem_cnt, T* dptr);    \\\n  template void RandomGenerator<DeviceType::kCPU>::Uniform<T>(const int64_t elem_cnt, const T min, \\\n                                                              const T max, T* dptr);\n\nOF_PP_FOR_EACH_TUPLE(INITIATE_CPU_RANDOM_GENERATOR_UNIFORM, FLOATING_DATA_TYPE_SEQ);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/random_generator.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/random_generator.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nvoid RngUniformGpu(const curandGenerator_t& gen, int64_t n, T* ret);\n\ntemplate<>\nvoid RngUniformGpu<float>(const curandGenerator_t& gen, int64_t n, float* ret) {\n  OF_CURAND_CHECK(curandGenerateUniform(gen, ret, n));\n}\n\ntemplate<>\nvoid RngUniformGpu<double>(const curandGenerator_t& gen, int64_t n, double* ret) {\n  OF_CURAND_CHECK(curandGenerateUniformDouble(gen, ret, n));\n}\n\n}  // namespace\n\nRandomGenerator<DeviceType::kCUDA>::RandomGenerator(int64_t seed, ep::Stream* stream) {\n  OF_CURAND_CHECK(curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT));\n  OF_CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(curand_generator_, seed));\n  OF_CURAND_CHECK(curandSetStream(curand_generator_, stream->As<ep::CudaStream>()->cuda_stream()));\n}\n\nRandomGenerator<DeviceType::kCUDA>::~RandomGenerator() {\n  OF_CURAND_CHECK(curandDestroyGenerator(curand_generator_));\n}\n\ntemplate<typename T>\nvoid RandomGenerator<DeviceType::kCUDA>::Uniform(const int64_t elem_cnt, T* dptr) {\n  RngUniformGpu(curand_generator_, elem_cnt, dptr);\n}\n\n#define INITIATE_CUDA_RANDOM_GENERATOR_UNIFORM(T, typeproto) \\\n  template void RandomGenerator<DeviceType::kCUDA>::Uniform<T>(const int64_t elem_cnt, T* dptr);\n\nOF_PP_FOR_EACH_TUPLE(INITIATE_CUDA_RANDOM_GENERATOR_UNIFORM, FLOATING_DATA_TYPE_SEQ);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/random_generator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_RANDOM_GENERATOR_H_\n#define ONEFLOW_CORE_KERNEL_RANDOM_GENERATOR_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/job/resource.pb.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type>\nclass RandomGenerator;\n\ntemplate<>\nclass RandomGenerator<DeviceType::kCPU> final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RandomGenerator);\n  RandomGenerator(int64_t seed, ep::Stream* stream) : mt19937_generator_(seed) {}\n  ~RandomGenerator() {}\n\n  template<typename T>\n  void Uniform(const int64_t elem_cnt, T* dptr);\n  template<typename T>\n  void Uniform(const int64_t elem_cnt, const T min, const T max, T* dptr);\n\n private:\n  std::mt19937 mt19937_generator_;\n};\n\ntemplate<>\nclass RandomGenerator<DeviceType::kCUDA> final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RandomGenerator);\n  RandomGenerator(int64_t seed, ep::Stream* stream);\n  ~RandomGenerator();\n\n  template<typename T>\n  void Uniform(const int64_t elem_cnt, T* dptr);\n\n private:\n#ifdef WITH_CUDA\n  curandGenerator_t curand_generator_;\n#endif\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_RANDOM_GENERATOR_H_\n"
  },
  {
    "path": "oneflow/core/kernel/reentrant_lock_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/reentrant_lock_kernel.h\"\n\nnamespace oneflow {\n\nstd::string ReentrantLockStatus::kEmptyIbn = \"reentrant_lock_status_empty_ibn\";\n\nvoid ReentrantLockStatus::Init(const KernelConf& kernel_conf) {\n  const auto& conf = kernel_conf.op_attribute().op_conf().reentrant_lock_conf();\n  cur_ibn_ = \"\";\n  cur_act_id_ = -1;\n  acquired_lock_to_be_sent_ = false;\n  total_queued_request_lock_num_ = 0;\n  total_acquired_lock_num_ = 0;\n  lock_id2queued_request_act_id_.resize(conf.lock_id2intersecting_lock_ids_size());\n  lock_id2acquired_num_.resize(conf.lock_id2intersecting_lock_ids_size());\n  for (const Int64List& ids : conf.lock_id2intersecting_lock_ids()) {\n    lock_id2intersecting_lock_ids_.emplace_back(\n        std::vector<int64_t>(ids.value().begin(), ids.value().end()));\n  }\n}\n\nbool ReentrantLockStatus::TryAcquireLock(int64_t lock_id) {\n  CHECK_EQ(lock_id2queued_request_act_id_.at(lock_id).empty(), false);\n  int64_t act_id = lock_id2queued_request_act_id_.at(lock_id).front();\n  bool blocked = false;\n  for (int64_t intersect_lock_id : lock_id2intersecting_lock_ids_.at(lock_id)) {\n    if (lock_id2acquired_num_.at(intersect_lock_id) > 0\n        || (lock_id2queued_request_act_id_.at(intersect_lock_id).empty() == false\n            && lock_id2queued_request_act_id_.at(intersect_lock_id).front() < act_id)) {\n      blocked = true;\n      break;\n    }\n  }\n  if (blocked) { return false; }\n  lock_id2queued_request_act_id_.at(lock_id).pop();\n  --total_queued_request_lock_num_;\n  ++lock_id2acquired_num_.at(lock_id);\n  ++total_acquired_lock_num_;\n  return true;\n}\n\nvoid ReentrantLockStatus::RequestLock(int64_t lock_id, std::queue<int64_t>* unlocked_ids) {\n  lock_id2queued_request_act_id_.at(lock_id).push(cur_act_id());\n  ++total_queued_request_lock_num_;\n  if (TryAcquireLock(lock_id)) { unlocked_ids->push(lock_id); }\n}\n\nvoid ReentrantLockStatus::ReleaseLock(int64_t lock_id, std::queue<int64_t>* unlocked_ids) {\n  CHECK_GT(lock_id2acquired_num_.at(lock_id), 0);\n  CHECK_GT(total_acquired_lock_num_, 0);\n  --lock_id2acquired_num_.at(lock_id);\n  --total_acquired_lock_num_;\n  size_t unlocked_cnt = 0;\n  do {\n    unlocked_cnt = 0;\n    auto ReleaseRelatedLockId = [&](int64_t related_lock_id) {\n      if (lock_id2queued_request_act_id_.at(related_lock_id).empty()) { return; }\n      if (TryAcquireLock(related_lock_id)) {\n        unlocked_ids->push(related_lock_id);\n        ++unlocked_cnt;\n      }\n    };\n    ReleaseRelatedLockId(lock_id);\n    for (int64_t id : lock_id2intersecting_lock_ids_.at(lock_id)) { ReleaseRelatedLockId(id); }\n  } while (unlocked_cnt > 0);\n}\n\ntemplate<typename T>\nvoid ReentrantLockKernel<T>::VirtualKernelInit(KernelContext* ctx) {\n  ctx->set_state(std::make_shared<ReentrantLockStatus>());\n}\n\ntemplate<typename T>\nvoid ReentrantLockKernel<T>::ForwardDataContent(KernelContext* ctx) const {\n  
auto* const status = CHECK_NOTNULL(dynamic_cast<ReentrantLockStatus*>(ctx->state().get()));\n  if (status->cur_ibn() == \"start\") {\n    T lock_id = *ctx->BnInOp2Blob(\"start\")->dptr<T>();\n    status->RequestLock(lock_id, status->mut_cur_unlocked_ids());\n  } else if (status->cur_ibn() == \"end\") {\n    status->ReleaseLock(*ctx->BnInOp2Blob(\"end\")->dptr<T>(), status->mut_cur_unlocked_ids());\n  } else {\n    CHECK_EQ(status->cur_ibn(), ReentrantLockStatus::kEmptyIbn);\n  }\n  if (status->cur_unlocked_ids().size() > 0) {\n    T lock_id = status->cur_unlocked_ids().front();\n    status->mut_cur_unlocked_ids()->pop();\n    *ctx->BnInOp2Blob(\"out\")->mut_dptr<T>() = lock_id;\n    status->set_acquired_lock_to_be_sent(true);\n  } else {\n    status->set_acquired_lock_to_be_sent(false);\n  }\n}\n\nADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kReentrantLockConf, ReentrantLockKernel,\n                               INT_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/reentrant_lock_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_REENTRANT_LOCK_KERNEL_H_\n#define ONEFLOW_CORE_KERNEL_REENTRANT_LOCK_KERNEL_H_\n\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/graph/graph.h\"\n\nnamespace oneflow {\n\nclass ReentrantLockStatus final : public KernelState {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ReentrantLockStatus);\n  ReentrantLockStatus() = default;\n  ~ReentrantLockStatus() = default;\n\n  void Init(const KernelConf& kernel_conf);\n\n  static std::string kEmptyIbn;\n\n  // true: success\n  // false: failed\n  void RequestLock(int64_t lock_id, std::queue<int64_t>* unlocked_ids);\n\n  // return lock_id if any other lock acquired\n  // -1: no other lock acquired\n  void ReleaseLock(int64_t lock_id, std::queue<int64_t>* unlocked_ids);\n\n  const std::queue<int64_t>& cur_unlocked_ids() const { return cur_unlocked_ids_; }\n  std::queue<int64_t>* mut_cur_unlocked_ids() { return &cur_unlocked_ids_; }\n\n  // Getters\n  const std::string& cur_ibn() const { return cur_ibn_; }\n  int64_t cur_act_id() const { return cur_act_id_; }\n  bool acquired_lock_to_be_sent() const { return acquired_lock_to_be_sent_; }\n  size_t total_queued_request_lock_num() const { return total_queued_request_lock_num_; }\n  size_t total_acquired_lock_num() const { return total_acquired_lock_num_; }\n\n  // Setters\n  void set_cur_ibn(const std::string& ibn) { cur_ibn_ = ibn; }\n  void set_cur_act_id(int64_t act_id) { cur_act_id_ = act_id; }\n  void set_acquired_lock_to_be_sent(bool val) { acquired_lock_to_be_sent_ = val; }\n\n private:\n  // true: success\n  // false: failed\n  bool TryAcquireLock(int64_t lock_id);\n\n  std::string cur_ibn_;\n  int64_t cur_act_id_{};\n  bool acquired_lock_to_be_sent_{};\n  size_t total_queued_request_lock_num_{};\n  size_t total_acquired_lock_num_{};\n  std::vector<std::queue<int64_t>> lock_id2queued_request_act_id_;\n  std::vector<size_t> lock_id2acquired_num_;\n  std::vector<std::vector<int64_t>> lock_id2intersecting_lock_ids_;\n  std::queue<int64_t> cur_unlocked_ids_;\n};\n\ntemplate<typename T>\nclass ReentrantLockKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ReentrantLockKernel);\n  ReentrantLockKernel() = default;\n  ~ReentrantLockKernel() override = default;\n\n private:\n  void VirtualKernelInit(KernelContext* ctx) override;\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_REENTRANT_LOCK_KERNEL_H_\n"
  },
  {
    "path": "oneflow/core/kernel/return_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/common/buffer_manager.h\"\n#include \"oneflow/core/job/critical_section_instance.h\"\n#include \"oneflow/core/job/global_for.h\"\n\nnamespace oneflow {\n\nclass ReturnKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ReturnKernel);\n  ReturnKernel() = default;\n  ~ReturnKernel() = default;\n\n private:\n  void ForwardDataContent(KernelContext* ctx) const override;\n  void ForwardHeader(KernelContext* ctx) const override;\n};\n\nvoid ReturnKernel::ForwardDataContent(KernelContext* ctx) const {\n  CHECK(this->op_conf().return_conf().has_job_name());\n  const auto& job_name = this->op_conf().return_conf().job_name();\n  const auto& op_name = this->op_conf().name();\n  auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::Get();\n  auto* buffer = buffer_mgr->Get(GetOutputBufferName(job_name, op_name));\n  std::shared_ptr<CriticalSectionInstance> critical_section_instance;\n  BufferStatus buffer_status = buffer->TryReceive(&critical_section_instance);\n  CHECK_NE(buffer_status, kBufferStatusEmpty);\n  if (buffer_status == kBufferStatusSuccess) {\n    critical_section_instance->AccessBlobByOpName(ctx->stream(), ctx->BnInOp2Blob(\"in\"), op_name);\n  }\n}\n\nvoid ReturnKernel::ForwardHeader(KernelContext* ctx) const {\n  // Do nothing.\n}\n\nREGISTER_KERNEL(OperatorConf::kReturnConf, ReturnKernel);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/runtime_blob_shape_infer_helper.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/runtime_blob_shape_infer_helper.h\"\n#include \"oneflow/core/register/blob.h\"\n#include \"oneflow/core/common/cached_caller.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n\nnamespace oneflow {\n\nRuntimeBlobShapeInferHelper::RuntimeBlobShapeInferHelper(const OperatorConf& op_conf,\n                                                         const KernelConf& kernel_conf,\n                                                         const void* scope) {\n  op_ = CHECK_JUST(ConstructOp(op_conf));\n  const OpAttribute& op_attribute = kernel_conf.op_attribute();\n  if (op_attribute.has_parallel_conf_signature()\n      && op_attribute.parallel_conf_signature().has_op_parallel_conf()) {\n    CHECK_JUST(op_->FillOpParallelDesc(\n        ParallelDesc(op_attribute.parallel_conf_signature().op_parallel_conf())));\n  }\n  if (op_attribute.has_sbp_signature()) {\n    sbp_signature_.reset(new SbpSignature(op_attribute.sbp_signature()));\n    CHECK_JUST(op_->FillSbpSignature(*sbp_signature_));\n  }\n  op_->ForEachBnInOp([&](const std::string& bn_in_op) { bn_in_op2blob_desc_[bn_in_op].reset(); });\n  if (op_attribute.has_logical_blob_desc_signature()) {\n    HashMap<std::string, std::unique_ptr<BlobDesc>> bn_in_op2logical_blob_desc;\n    const auto& blob_desc_signature_map =\n        op_attribute.logical_blob_desc_signature().bn_in_op2blob_desc();\n    for (const auto& pair : blob_desc_signature_map) {\n      bn_in_op2logical_blob_desc[pair.first].reset(new BlobDesc(pair.second));\n    }\n    auto GetLogicalBlobDesc4BnInOp = [&](const std::string& bn) -> BlobDesc* {\n      if (bn_in_op2logical_blob_desc.find(bn) != bn_in_op2logical_blob_desc.end()) {\n        return bn_in_op2logical_blob_desc.at(bn).get();\n      }\n      return nullptr;\n    };\n    CHECK_JUST(op_->FillLogicalInBlobDesc(GetLogicalBlobDesc4BnInOp));\n    CHECK_JUST(op_->FillLogicalOutBlobDesc(GetLogicalBlobDesc4BnInOp));\n  }\n  if (kernel_conf.has_parallel_ctx()) {\n    parallel_ctx_.reset(new ParallelContext(kernel_conf.parallel_ctx()));\n  }\n  op_infer_cache_key_.scope = scope;\n  op_infer_cache_key_.op_conf_sym = op_->GetOpConfWithoutOpNameAndLbn();\n  op_infer_cache_key_.ibn_idx2shape_sym.resize(op_->input_bns().size());\n  op_infer_cache_key_.dtype_signature_sym = SymbolOf(kernel_conf.dtype_signature());\n}\n\nvoid RuntimeBlobShapeInferHelper::UpdateInputBlobDescs7OpInferCacheKey(\n    std::function<Blob*(const std::string&)> BnInOp2Blob) {\n  auto ResetBlobDescAndGetShapeSym = [&](const std::string& ibn) -> Symbol<Shape> {\n    const Blob* blob = BnInOp2Blob(ibn);\n    if (blob == nullptr) { return Symbol<Shape>(); }\n    BlobDesc* blob_desc = BlobDesc4BnInOp(ibn, blob->blob_desc());\n    Shape blob_shape = blob_desc->shape();\n    blob_shape.LeftOnesExtendedAssign(blob->shape());\n    blob_desc->set_shape(blob_shape);\n    Stride blob_stride = 
blob_desc->stride();\n    blob_stride.CheckNumAxesIdenticalAndAssign(blob->stride());\n    blob_desc->set_stride(blob_stride);\n    return SymbolOf(blob_desc->shape());\n  };\n  const auto& input_bns = op_->input_bns();\n  FOR_RANGE(int, i, 0, input_bns.size()) {\n    op_infer_cache_key_.ibn_idx2shape_sym.at(i) = ResetBlobDescAndGetShapeSym(input_bns.Get(i));\n  }\n}\n\nBlobDesc* RuntimeBlobShapeInferHelper::BlobDesc4BnInOp(const std::string& bn_in_op,\n                                                       const BlobDesc& blob_desc) {\n  auto it = bn_in_op2blob_desc_.find(bn_in_op);\n  if (it == bn_in_op2blob_desc_.end()) { return nullptr; }\n  if (!it->second) { it->second.reset(new BlobDesc(blob_desc)); }\n  return it->second.get();\n}\n\nvoid RuntimeBlobShapeInferHelper::InferShape(\n    const std::function<Blob*(const std::string&)>& BnInOp2Blob) {\n  UpdateInputBlobDescs7OpInferCacheKey(BnInOp2Blob);\n  auto Infer = [&](const OpInferCacheKey& key) -> std::shared_ptr<const OpInferCacheValue> {\n    auto CachedBlobDesc4BnInOp = WithResultCached([&](const std::string& bn_in_op) -> BlobDesc* {\n      const Blob* blob = BnInOp2Blob(bn_in_op);\n      if (blob == nullptr) { return nullptr; }\n      return BlobDesc4BnInOp(bn_in_op, blob->blob_desc());\n    });\n    CHECK_JUST(op_->InferOutBlobDescsIf(CachedBlobDesc4BnInOp, parallel_ctx_.get()));\n    auto* ret = new OpInferCacheValue();\n    ret->obn_idx2shape_sym.resize(op_->output_bns().size());\n    FOR_RANGE(int, i, 0, op_->output_bns().size()) {\n      const auto& obn = op_->output_bns().Get(i);\n      const auto& blob_desc = bn_in_op2blob_desc_.at(obn);\n      ret->obn_idx2shape_sym.at(i).reset(blob_desc->shape());\n      auto* blob = BnInOp2Blob(obn);\n      if (blob == nullptr) { continue; }\n      CHECK_EQ(blob->data_type(), blob_desc->data_type());\n      CHECK_EQ(blob->blob_desc().is_dynamic(), blob_desc->is_dynamic());\n    }\n    return std::shared_ptr<const OpInferCacheValue>(ret);\n  };\n  size_t cache_size = Singleton<ResourceDesc, ForSession>::Get()->thread_local_cache_max_size();\n  const auto& shape_infer_ret = ThreadLocalCachedCall(cache_size, Infer, op_infer_cache_key_);\n  const auto& obn_idx2shape_sym = shape_infer_ret->obn_idx2shape_sym;\n  FOR_RANGE(int, i, 0, op_->output_bns().size()) {\n    const auto& obn = op_->output_bns().Get(i);\n    auto* blob = BnInOp2Blob(obn);\n    if (blob == nullptr) { continue; }\n    if (blob->blob_desc().is_dynamic()) {\n      blob->mut_shape_view()->set_shape(*obn_idx2shape_sym.at(i));\n    } else {\n      CHECK(*obn_idx2shape_sym.at(i) == blob->static_shape());\n    }\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/runtime_blob_shape_infer_helper.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_RUNTIME_BLOB_SHAPE_INFER_HELPER_H_\n#define ONEFLOW_CORE_KERNEL_RUNTIME_BLOB_SHAPE_INFER_HELPER_H_\n\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/operator/op_infer_cache.h\"\n\nnamespace oneflow {\n\nclass Blob;\nclass BlobDesc;\n\nclass RuntimeBlobShapeInferHelper final {\n public:\n  RuntimeBlobShapeInferHelper(const OperatorConf& op_conf, const KernelConf& kernel_conf,\n                              const void* scope);\n  ~RuntimeBlobShapeInferHelper() = default;\n\n  void InferShape(const std::function<Blob*(const std::string&)>& BnInOp2Blob);\n\n private:\n  void UpdateInputBlobDescs7OpInferCacheKey(std::function<Blob*(const std::string&)> BnInOp2Blob);\n  BlobDesc* BlobDesc4BnInOp(const std::string& bn_in_op, const BlobDesc& rt_blob_desc);\n\n  std::shared_ptr<Operator> op_;\n  HashSet<std::string> ibns_;\n  HashMap<std::string, std::unique_ptr<BlobDesc>> bn_in_op2blob_desc_;\n  std::unique_ptr<ParallelContext> parallel_ctx_;\n  std::unique_ptr<SbpSignature> sbp_signature_;\n  OpInferCacheKey op_infer_cache_key_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_RUNTIME_BLOB_SHAPE_INFER_HELPER_H_\n"
  },
  {
    "path": "oneflow/core/kernel/shape_elem_cnt_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nclass ShapeElemCntKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ShapeElemCntKernel);\n  ShapeElemCntKernel() = default;\n  ~ShapeElemCntKernel() override = default;\n\n private:\n  void ForwardDataContent(KernelContext* ctx) const override;\n  int32_t GetShapePartialElemCnt(const ShapeView& shape) const;\n};\n\ntemplate<DeviceType device_type, typename T>\nvoid ShapeElemCntKernel<device_type, T>::ForwardDataContent(KernelContext* ctx) const {\n  const T elem_cnt = GetShapePartialElemCnt(ctx->BnInOp2Blob(\"x\")->shape());\n  std::unique_ptr<ep::primitive::Fill> fill =\n      ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->stream()->device_type(),\n                                                              ctx->BnInOp2Blob(\"y\")->data_type());\n  CHECK(fill);\n  fill->Launch(ctx->stream(), ctx->BnInOp2Blob(\"y\")->mut_dptr(), elem_cnt, 1);\n}\n\ntemplate<DeviceType device_type, typename T>\nint32_t ShapeElemCntKernel<device_type, T>::GetShapePartialElemCnt(const ShapeView& shape) const {\n  int32_t ret = 1;\n  for (int32_t axis : this->kernel_conf().shape_elem_cnt_conf().axis()) { ret *= shape.At(axis); }\n  return ret;\n}\n\nADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kShapeElemCntConf, ShapeElemCntKernel,\n                           ARITHMETIC_DATA_TYPE_SEQ);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/slice_boxing_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/register/tensor_slice_copier.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/ep/include/primitive/add.h\"\n#include \"oneflow/core/ep/include/primitive/copy_nd.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n\nnamespace oneflow {\n\nclass SliceBoxingKernel : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SliceBoxingKernel);\n  SliceBoxingKernel() = default;\n  ~SliceBoxingKernel() override = default;\n\n protected:\n  virtual const SliceBoxingConf& GetCustomizedBoxingConf() const = 0;\n\n  const std::vector<std::shared_ptr<TensorSliceCopier>>& tensor_slice_copier_vec() const;\n\n private:\n  void VirtualKernelInit(KernelContext* ctx) override;\n\n  std::vector<std::shared_ptr<TensorSliceCopier>> tensor_slice_copier_vec_;\n};\n\nclass SliceBoxingCopyKernel final : public SliceBoxingKernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SliceBoxingCopyKernel);\n  SliceBoxingCopyKernel() = default;\n  ~SliceBoxingCopyKernel() override = default;\n\n private:\n  virtual const SliceBoxingConf& GetCustomizedBoxingConf() const override;\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\nclass SliceBoxingAddKernel final : public SliceBoxingKernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SliceBoxingAddKernel);\n  SliceBoxingAddKernel() = default;\n  ~SliceBoxingAddKernel() override = default;\n\n private:\n  virtual const SliceBoxingConf& GetCustomizedBoxingConf() const override;\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\nvoid SliceBoxingKernel::VirtualKernelInit(KernelContext* ctx) {\n  const SliceBoxingConf& conf = GetCustomizedBoxingConf();\n  if (/*is_0size_tensor=*/std::any_of(conf.out_shape().dim().begin(), conf.out_shape().dim().end(),\n                                      [](int64_t dim) { return dim == 0; })) {\n    return;\n  }\n  const TensorSliceView out_slice(conf.out_slice());\n  for (const TensorSliceViewProto& in_slice_proto : conf.in_slice()) {\n    const TensorSliceView in_slice(in_slice_proto);\n    tensor_slice_copier_vec_.emplace_back(new TensorSliceCopier(\n        out_slice, in_slice, this->kernel_conf().data_type(), ctx->stream()->device_type()));\n  }\n}\n\nconst std::vector<std::shared_ptr<TensorSliceCopier>>& SliceBoxingKernel::tensor_slice_copier_vec()\n    const {\n  return tensor_slice_copier_vec_;\n}\n\nconst SliceBoxingConf& SliceBoxingCopyKernel::GetCustomizedBoxingConf() const {\n  return this->op_conf().slice_boxing_copy_conf().slice_boxing_conf();\n}\n\nvoid SliceBoxingCopyKernel::ForwardDataContent(KernelContext* ctx) const {\n  Blob* out = ctx->BnInOp2Blob(\"out\");\n  if (out->shape_view().elem_cnt() == 0) { return; }\n  FOR_RANGE(int64_t, i, 0, this->op_attribute().input_bns().size()) {\n    const Blob* in_i = ctx->BnInOp2Blob(GenRepeatedBn(\"in\", i));\n    
this->tensor_slice_copier_vec().at(i)->Copy(ctx->stream(), out, in_i);\n  }\n}\n\nconst SliceBoxingConf& SliceBoxingAddKernel::GetCustomizedBoxingConf() const {\n  return this->op_conf().slice_boxing_add_conf().slice_boxing_conf();\n}\n\nvoid SliceBoxingAddKernel::ForwardDataContent(KernelContext* ctx) const {\n  Blob* out = ctx->BnInOp2Blob(\"out\");\n  if (out->shape_view().elem_cnt() == 0) { return; }\n  std::unique_ptr<ep::primitive::Add> primitive =\n      ep::primitive::NewPrimitive<ep::primitive::AddFactory>(ctx->stream()->device_type(),\n                                                             out->data_type());\n  CHECK(primitive);\n  std::unique_ptr<ep::primitive::Memset> memset_primitive =\n      ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->stream()->device_type());\n  CHECK(memset_primitive);\n  FOR_RANGE(int64_t, i, 0, this->op_attribute().input_bns().size()) {\n    const Blob* in_i = ctx->BnInOp2Blob(GenRepeatedBn(\"in\", i));\n    if (i == 0) {\n      if (in_i->shape().NumAxes() == 0 && out->shape().NumAxes() == 0) {\n        AutoMemcpy(ctx->stream(), out, in_i);\n      } else {\n        this->tensor_slice_copier_vec().at(i)->Copy(ctx->stream(), out, in_i);\n      }\n    } else {\n      if (in_i->shape() == out->shape()) {\n        primitive->Launch(ctx->stream(), out->dptr(), in_i->dptr(), out->mut_dptr(),\n                          out->shape().elem_cnt());\n      } else {\n        Blob* buf = ctx->BnInOp2Blob(\"buf\");\n        memset_primitive->Launch(ctx->stream(), buf->mut_dptr(), 0,\n                                 buf->shape().elem_cnt() * GetSizeOfDataType(buf->data_type()));\n        this->tensor_slice_copier_vec().at(i)->Copy(ctx->stream(), buf, in_i);\n        primitive->Launch(ctx->stream(), out->dptr(), buf->dptr(), out->mut_dptr(),\n                          out->shape().elem_cnt());\n      }\n    }\n  }\n}\n\nREGISTER_KERNEL(OperatorConf::kSliceBoxingCopyConf, SliceBoxingCopyKernel);\nREGISTER_KERNEL(OperatorConf::kSliceBoxingAddConf, SliceBoxingAddKernel);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/sync_check_kernel_observer.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/sync_check_kernel_observer.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n\nnamespace oneflow {\n\nvoid SyncCheckKernelObserver::DidForwardDataContent(KernelContext* kernel_ctx,\n                                                    const Kernel* kernel) {\n  CHECK_JUST_MSG(kernel_ctx->stream()->Sync(), kernel->op_conf().name());\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/sync_check_kernel_observer.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_SYNC_CHECK_KERNEL_OBSERVER_H_\n#define ONEFLOW_CORE_KERNEL_SYNC_CHECK_KERNEL_OBSERVER_H_\n\n#include \"oneflow/core/kernel/kernel_observer.h\"\n\nnamespace oneflow {\n\nclass SyncCheckKernelObserver final : public KernelObserver {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SyncCheckKernelObserver);\n  SyncCheckKernelObserver() = default;\n  ~SyncCheckKernelObserver() override = default;\n\n  void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_SYNC_CHECK_KERNEL_OBSERVER_H_\n"
  },
  {
    "path": "oneflow/core/kernel/sync_dynamic_resize_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/register/register_desc.h\"\n#include \"oneflow/core/lazy/actor/actor_context.h\"\n#include \"oneflow/core/memory/memory_case_util.h\"\n\n#include <cstddef>\n#include <cstdint>\n#include <memory>\n#include <mutex>\n#include <queue>\n\nnamespace oneflow {\n\n#ifdef WITH_CUDA\n\nnamespace {\n\nclass CudaHostMem {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaHostMem);\n  CudaHostMem(const size_t size) { OF_CUDA_CHECK(cudaMallocHost(&ptr_, size)); }\n  ~CudaHostMem() { OF_CUDA_CHECK(cudaFreeHost(ptr_)); }\n  void* Ptr() const { return ptr_; }\n\n private:\n  void* ptr_;\n};\n\n}  // namespace\n\ntemplate<typename SizeType>\nclass SyncDynamicResizeGPUKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SyncDynamicResizeGPUKernel);\n  SyncDynamicResizeGPUKernel() = default;\n  ~SyncDynamicResizeGPUKernel() override = default;\n\n private:\n  bool IsKernelLaunchSynchronized() const override { return false; }\n\n  void ForwardDataContent(KernelContext* ctx) const override {\n    const SyncDynamicResizeOpConf& conf = this->op_conf().sync_dynamic_resize_conf();\n    CHECK_EQ(conf.axis(), 0);\n    std::shared_ptr<CudaHostMem> cuda_host_mem_ptr;\n    {\n      std::lock_guard<std::mutex> lock(mutex_);\n      if (queue_.empty()) {\n        cuda_host_mem_ptr.reset(new CudaHostMem(sizeof(SizeType)));\n      } else {\n        cuda_host_mem_ptr = queue_.front();\n        queue_.pop();\n      }\n    }\n    const Blob* in = ctx->BnInOp2Blob(\"in\");\n    const Blob* size = ctx->BnInOp2Blob(\"size\");\n    Blob* out = ctx->BnInOp2Blob(\"out\");\n    AutoMemcpy(ctx->stream(), out->mut_dptr(), in->dptr(), in->ByteSizeOfBlobBody(),\n               out->mem_case(), in->mem_case());\n    AutoMemcpy(ctx->stream(), cuda_host_mem_ptr->Ptr(), size->dptr(), sizeof(SizeType),\n               memory::MakeHostMemCase(), size->mem_case());\n    const auto& UpdateShape = [out, cuda_host_mem_ptr, conf, this]() {\n      const int64_t new_size = *reinterpret_cast<SizeType*>(cuda_host_mem_ptr->Ptr());\n      CHECK_GE(new_size, 0);\n      CHECK_LE(new_size, out->shape_view().At(conf.axis()));\n      // NOTE(Liang Depeng): `mut_shape_view` should be used here to get the blob's `MutShapeView`\n      //                     pointer. But this callback is called after `Kernel::Forward` function's\n      //                     execution and the header check is already been set to false at that\n      //                     moment. 
So we have to choose the `ForceMutShapeView` function with\n      //                     header checker disabled.\n      out->ForceMutShapeView()->Set(conf.axis(), new_size);\n      std::lock_guard<std::mutex> lock(mutex_);\n      queue_.push(cuda_host_mem_ptr);\n    };\n    if (conf.eager()) {\n      CHECK_JUST(ctx->stream()->Sync());\n      UpdateShape();\n    } else {\n      auto* actor_context_provider = CHECK_NOTNULL(dynamic_cast<ActorContextProvider*>(ctx));\n      actor_context_provider->GetActorContext()->AddCallback(UpdateShape);\n    }\n  }\n\n  mutable std::queue<std::shared_ptr<CudaHostMem>> queue_;\n  mutable std::mutex mutex_;\n};\n\n#define REGISTER_SYNC_DYNAMIC_RESIZE_GPU_KERNEL(stype)                                         \\\n  NEW_REGISTER_KERNEL(OperatorConf::kSyncDynamicResizeConf, SyncDynamicResizeGPUKernel<stype>) \\\n      .SetIsMatchedPred([](const KernelConf& kernel_conf) {                                    \\\n        return (kernel_conf.op_attribute().op_conf().device_tag() == \"cuda\"                    \\\n                && GetDataType<stype>::value                                                   \\\n                       == kernel_conf.sync_dynamic_resize_conf().size_data_type());            \\\n      })\nREGISTER_SYNC_DYNAMIC_RESIZE_GPU_KERNEL(int8_t);\nREGISTER_SYNC_DYNAMIC_RESIZE_GPU_KERNEL(int32_t);\nREGISTER_SYNC_DYNAMIC_RESIZE_GPU_KERNEL(int64_t);\n\n#endif  // WITH_CUDA\n\ntemplate<typename SizeType>\nclass SyncDynamicResizeCPUKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SyncDynamicResizeCPUKernel);\n  SyncDynamicResizeCPUKernel() = default;\n  ~SyncDynamicResizeCPUKernel() override = default;\n\n private:\n  bool IsKernelLaunchSynchronized() const override { return false; }\n  void ForwardDataContent(KernelContext* ctx) const override {\n    const SyncDynamicResizeOpConf& conf = this->op_conf().sync_dynamic_resize_conf();\n    CHECK_EQ(conf.axis(), 0);\n    const Blob* in = ctx->BnInOp2Blob(\"in\");\n    const Blob* size = ctx->BnInOp2Blob(\"size\");\n    Blob* out = ctx->BnInOp2Blob(\"out\");\n    AutoMemcpy(ctx->stream(), out->mut_dptr(), in->dptr(), in->ByteSizeOfBlobBody(),\n               out->mem_case(), in->mem_case());\n    const SizeType new_size = *size->dptr<SizeType>();\n    CHECK_GE(new_size, 0);\n    CHECK_LE(new_size, out->shape_view().At(conf.axis()));\n    out->mut_shape_view()->Set(conf.axis(), new_size);\n  }\n};\n\n#define REGISTER_SYNC_DYNAMIC_RESIZE_CPU_KERNEL(stype)                                         \\\n  NEW_REGISTER_KERNEL(OperatorConf::kSyncDynamicResizeConf, SyncDynamicResizeCPUKernel<stype>) \\\n      .SetIsMatchedPred([](const KernelConf& kernel_conf) {                                    \\\n        return (kernel_conf.op_attribute().op_conf().device_tag() == \"cpu\"                     \\\n                && GetDataType<stype>::value                                                   \\\n                       == kernel_conf.sync_dynamic_resize_conf().size_data_type());            \\\n      })\nREGISTER_SYNC_DYNAMIC_RESIZE_CPU_KERNEL(int8_t);\nREGISTER_SYNC_DYNAMIC_RESIZE_CPU_KERNEL(int32_t);\nREGISTER_SYNC_DYNAMIC_RESIZE_CPU_KERNEL(int64_t);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/total_loss_instance_num_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/kernel.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass TotalLossInstanceNumKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TotalLossInstanceNumKernel);\n  TotalLossInstanceNumKernel() = default;\n  ~TotalLossInstanceNumKernel() override = default;\n\n private:\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\ntemplate<typename T>\nvoid TotalLossInstanceNumKernel<T>::ForwardDataContent(KernelContext* ctx) const {\n  const auto& input_bns = this->op_attribute().input_bns();\n  T first_val = ctx->BnInOp2Blob(input_bns.Get(0))->template dptr<T>()[0];\n  for (const std::string& ibn : input_bns) {\n    CHECK_EQ(ctx->BnInOp2Blob(ibn)->template dptr<T>()[0], first_val);\n  }\n  ctx->BnInOp2Blob(\"out\")->template mut_dptr<T>()[0] = first_val;\n}\n\nADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kTotalLossInstanceNumConf, TotalLossInstanceNumKernel,\n                               ARITHMETIC_DATA_TYPE_SEQ);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/user_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/user_kernel.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/framework/op_kernel_infer_cache.h\"\n#include \"oneflow/core/framework/user_op_tensor.h\"\n#include \"oneflow/core/kernel/blob_tensor_view.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool IsAllBlobEmpty(const PbRpf<std::string>& bns,\n                    const std::function<Blob*(const std::string& bn)>& BnInOp2Blob) {\n  for (const auto& bn : bns) {\n    Blob* blob = BnInOp2Blob(bn);\n    if (blob && !blob->IsBodyEmpty()) { return false; }\n  }\n  return true;\n}\n\n}  // namespace\n\nusing Arg2Tensor =\n    HashMap<std::pair<std::string, int32_t>, std::unique_ptr<user_op::BlobTensorView>>;\nusing ArgVec = std::vector<std::pair<std::string, int32_t>>;\n\nnamespace {\n\nvoid FillTensorDescWithBlob(const Blob* blob, user_op::NaiveTensorDesc* tensor_desc) {\n  BlobDescProto proto;\n  blob->blob_desc().shape().ToProto(proto.mutable_shape());\n  blob->blob_desc().stride().ToProto(proto.mutable_stride());\n  proto.set_data_type(blob->blob_desc().data_type());\n  proto.set_is_dynamic(blob->blob_desc().is_dynamic());\n  *tensor_desc = proto;\n  Shape tensor_desc_shape = tensor_desc->shape();\n  tensor_desc_shape.CheckNumAxesIdenticalAndAssign(blob->shape());\n  tensor_desc->set_shape(tensor_desc_shape);\n  Stride tensor_desc_stride = tensor_desc->stride();\n  tensor_desc_stride.CheckNumAxesIdenticalAndAssign(blob->stride());\n  tensor_desc->set_stride(tensor_desc_stride);\n}\n\n}  // namespace\n\nclass UserKernelBaseContext {\n public:\n  explicit UserKernelBaseContext(const KernelConf& kernel_conf) {\n    CHECK(kernel_conf.has_user_conf());\n    CHECK(kernel_conf.op_attribute().op_conf().has_user_conf());\n\n    auto InitInOrOut = [&](const PbMap<std::string, UserOpConf::ListString>& arg_map,\n                           ArgVec* arg_vec) {\n      for (auto it = arg_map.begin(); it != arg_map.end(); ++it) {\n        for (int32_t i = 0; i < it->second.s_size(); ++i) {\n          arg_vec->emplace_back(std::make_pair(it->first, i));\n        }\n      }\n    };\n    InitInOrOut(kernel_conf.op_attribute().op_conf().user_conf().input(), &inputs_);\n    InitInOrOut(kernel_conf.op_attribute().op_conf().user_conf().output(), &outputs_);\n    device_type_ =\n        CHECK_JUST(DeviceType4DeviceTag(kernel_conf.op_attribute().op_conf().device_tag()));\n    parallel_ctx_ = kernel_conf.parallel_ctx();\n    for (const auto& pair : kernel_conf.user_conf().bn_in_op2blob_desc()) {\n      arg2bn_and_tensor_desc_.emplace(\n          
GenUnRepeatedBn(pair.first),\n          std::make_pair(pair.first, user_op::NaiveTensorDesc(pair.second)));\n    }\n  }\n  ~UserKernelBaseContext() = default;\n\n  DeviceType device_type() const { return device_type_; }\n  const ParallelContext& parallel_ctx() const { return parallel_ctx_; }\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const {\n    auto it = arg2bn_and_tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2bn_and_tensor_desc_.end()) { return nullptr; }\n    return &(it->second.second);\n  }\n\n  const ArgVec& inputs() const { return inputs_; }\n  const ArgVec& outputs() const { return outputs_; }\n\n private:\n  friend class UserKernelInitAndCacheContext;\n  HashMap<std::pair<std::string, int32_t>, std::pair<std::string, user_op::NaiveTensorDesc>>\n      arg2bn_and_tensor_desc_;\n  ArgVec inputs_;\n  ArgVec outputs_;\n  DeviceType device_type_;\n  ParallelContext parallel_ctx_;\n};\n\nclass UserKernelInitAndCacheContext final : public user_op::KernelInitContext,\n                                            public user_op::KernelCacheContext {\n public:\n  explicit UserKernelInitAndCacheContext(ep::Stream* stream, const KernelConf& kernel_conf)\n      : user_op_conf_(kernel_conf.op_attribute().op_conf()),\n        stream_(stream),\n        base_ctx_(UserKernelBaseContext(kernel_conf)),\n        parallel_desc_(kernel_conf.op_attribute().parallel_conf_signature().op_parallel_conf()) {\n    nd_sbp_signature_ = NdSbpSignature(kernel_conf.op_attribute().nd_sbp_signature());\n    if (kernel_conf.op_attribute().has_sbp_signature()) {\n      sbp_signature_ = SbpSignature(kernel_conf.op_attribute().sbp_signature());\n    }\n    bool is_dynamic = false;\n    for (const auto& pair : kernel_conf.user_conf().bn_in_op2blob_desc()) {\n      if (pair.second.is_dynamic()) {\n        is_dynamic = true;\n        break;\n      }\n    }\n    if (!is_dynamic || parallel_ctx().parallel_num() == 1) {\n      for (const auto& pair :\n           kernel_conf.op_attribute().logical_blob_desc_signature().bn_in_op2blob_desc()) {\n        arg2logical_tensor_desc_.emplace(GenUnRepeatedBn(pair.first),\n                                         user_op::NaiveTensorDesc(pair.second));\n      }\n    }\n  }\n  ~UserKernelInitAndCacheContext() override = default;\n\n  ep::Stream* stream() override { return stream_; }\n\n  void UpdateTensorWithCorrBlob(const std::function<Blob*(const std::string&)>& BnInOp2Blob) {\n    for (auto& pair : base_ctx_.arg2bn_and_tensor_desc_) {\n      const std::string& bn = pair.second.first;\n      auto& tensor_desc = pair.second.second;\n      Blob* blob = BnInOp2Blob(bn);\n      CHECK(blob != nullptr) << \"Blob \" << bn << \" is not found in cache context.\";\n      if (blob->blob_desc().is_dynamic()) {\n        Shape shape;\n        blob->shape().ToShape(&shape);\n        tensor_desc.set_shape(shape);\n      }\n    }\n  }\n\n  DeviceType device_type() const override { return base_ctx_.device_type(); }\n  const ParallelContext& parallel_ctx() const override { return base_ctx_.parallel_ctx(); }\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const override {\n    return base_ctx_.TensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                            
                                   int32_t index) const override {\n    auto it = arg2logical_tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2logical_tensor_desc_.end()) {\n      return nullptr;\n    } else {\n      return &(it->second);\n    }\n  }\n  const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,\n                                                 int32_t index) const override {\n    CHECK_EQ(parallel_desc_.hierarchy()->NumAxes(), 1);\n    const auto& bn2sbp = sbp_signature_.bn_in_op2sbp_parallel();\n    std::string bn = GenRepeatedBn(arg_name, index);\n    auto it = bn2sbp.find(bn);\n    CHECK(it != bn2sbp.end());\n    return it->second;\n  }\n\n  const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    const auto& bn2nd_sbp = nd_sbp_signature_.bn_in_op2nd_sbp();\n    std::string bn = GenRepeatedBn(arg_name, index);\n    auto it = bn2nd_sbp.find(bn);\n    CHECK(it != bn2nd_sbp.end());\n    return it->second;\n  }\n\n  const ArgVec& inputs() const override { return base_ctx_.inputs(); }\n  const ArgVec& outputs() const override { return base_ctx_.outputs(); }\n  const ParallelDesc& parallel_desc() const override { return parallel_desc_; }\n\n private:\n  const user_op::UserOpConfWrapper& user_op_conf() const override { return user_op_conf_; }\n\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return user_op_conf().Attr4Name(attr_name);\n  }\n\n  user_op::UserOpConfWrapper user_op_conf_;\n  ep::Stream* stream_;\n  UserKernelBaseContext base_ctx_;\n  SbpSignature sbp_signature_;\n  HashMap<std::pair<std::string, int32_t>, user_op::NaiveTensorDesc> arg2logical_tensor_desc_;\n  ParallelDesc parallel_desc_;\n  NdSbpSignature nd_sbp_signature_;\n};\n\nusing UserKernelInitContext = UserKernelInitAndCacheContext;\nusing UserKernelCacheContext = UserKernelInitAndCacheContext;\n\nclass UserKernelOpInferContext : public user_op::InferContext {\n public:\n  explicit UserKernelOpInferContext(const KernelConf& kernel_conf)\n      : user_op_conf_(kernel_conf.op_attribute().op_conf()),\n        parallel_ctx_(kernel_conf.parallel_ctx()),\n        nd_sbp_signature_(kernel_conf.op_attribute().nd_sbp_signature()),\n        parallel_desc_(kernel_conf.op_attribute().parallel_conf_signature().op_parallel_conf()) {\n    if (kernel_conf.op_attribute().has_sbp_signature()) {\n      sbp_signature_ = SbpSignature(kernel_conf.op_attribute().sbp_signature());\n    }\n    auto InitTensorDesc = [&](const PbMap<std::string, UserOpConf::ListString>& arg_map,\n                              ArgVec* arg_vec) {\n      for (auto it = arg_map.begin(); it != arg_map.end(); ++it) {\n        const std::string& arg_name = it->first;\n        for (int32_t i = 0; i < it->second.s_size(); ++i) {\n          std::pair<std::string, int32_t> arg_pair = std::make_pair(arg_name, i);\n          arg_vec->emplace_back(arg_pair);\n          arg2tensor_desc_.emplace(arg_pair, nullptr);\n        }\n      }\n    };\n    InitTensorDesc(kernel_conf.op_attribute().op_conf().user_conf().input(), &inputs_);\n    InitTensorDesc(kernel_conf.op_attribute().op_conf().user_conf().output(), &outputs_);\n    for (const auto& pair :\n         kernel_conf.op_attribute().logical_blob_desc_signature().bn_in_op2blob_desc()) {\n      arg2logical_tensor_desc_.emplace(GenUnRepeatedBn(pair.first),\n                                       user_op::NaiveTensorDesc(pair.second));\n    }\n  }\n  
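// Tensor descs in arg2tensor_desc_ start out as null placeholders (see the constructor above)\n  // and are materialized from the corresponding blobs by UpdateArg2TensorDesc() before shape\n  // inference runs.\n  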
~UserKernelOpInferContext() override = default;\n\n  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                               int32_t index) const override {\n    auto it = arg2logical_tensor_desc_.find(std::make_pair(arg_name, index));\n    CHECK(it != arg2logical_tensor_desc_.end())\n        << \"Arg (\" << arg_name << \",\" << index << \") is not found\";\n    return &(it->second);\n  }\n\n  const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name,\n                                             int32_t index) const override {\n    return *TensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n  const user_op::TensorDesc& OutputTensorDesc(const std::string& arg_name,\n                                              int32_t index) const override {\n    return *TensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n  user_op::TensorDesc* MutOutputTensorDesc(const std::string& arg_name, int32_t index) override {\n    return MutTensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) { return nullptr; }\n    return it->second.get();\n  }\n  user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) { return nullptr; }\n    return it->second.get();\n  }\n  const Shape& InputShape(const std::string& arg_name, int32_t index) const override {\n    return Shape4ArgNameAndIndex(arg_name, index);\n  }\n  const Shape& OutputShape(const std::string& arg_name, int32_t index) const override {\n    return Shape4ArgNameAndIndex(arg_name, index);\n  }\n  void SetOutputShape(const std::string& arg_name, int32_t index, const Shape& shape) override {\n    SetShape4ArgNameAndIndex(arg_name, index, shape);\n  }\n  const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    return TensorDesc4ArgNameAndIndex(arg_name, index)->shape();\n  }\n  void SetShape4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                const Shape& shape) override {\n    return MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_shape(shape);\n  }\n  const Stride& InputStride(const std::string& arg_name, int32_t index) const override {\n    return Stride4ArgNameAndIndex(arg_name, index);\n  }\n  const Stride& OutputStride(const std::string& arg_name, int32_t index) const override {\n    return Stride4ArgNameAndIndex(arg_name, index);\n  }\n  void SetOutputStride(const std::string& arg_name, int32_t index, const Stride& stride) override {\n    return SetStride4ArgNameAndIndex(arg_name, index, stride);\n  }\n  const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    return TensorDesc4ArgNameAndIndex(arg_name, index)->stride();\n  }\n  void SetStride4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                 const Stride& stride) override {\n    return MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_stride(stride);\n  }\n  DataType InputDType(const std::string& arg_name, int32_t index) const override {\n    return Dtype4ArgNameAndIndex(arg_name, index);\n  }\n  DataType OutputDType(const 
std::string& arg_name, int32_t index) const override {\n    return Dtype4ArgNameAndIndex(arg_name, index);\n  }\n  void SetOutputDType(const std::string& arg_name, int32_t index, DataType data_type) override {\n    return SetDtype4ArgNameAndIndex(arg_name, index, data_type);\n  }\n  DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    return TensorDesc4ArgNameAndIndex(arg_name, index)->data_type();\n  }\n  void SetDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                DataType data_type) override {\n    return MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_data_type(data_type);\n  }\n\n  MemoryFormat InputMemoryFormat(const std::string& arg_name, int32_t index) const override {\n    return MemoryFormat4ArgNameAndIndex(arg_name, index);\n  }\n  MemoryFormat OutputMemoryFormat(const std::string& arg_name, int32_t index) const override {\n    return MemoryFormat4ArgNameAndIndex(arg_name, index);\n  }\n  void SetOutputMemoryFormat(const std::string& arg_name, int32_t index,\n                             MemoryFormat memory_format) override {\n    return SetMemoryFormat4ArgNameAndIndex(arg_name, index, memory_format);\n  }\n  MemoryFormat MemoryFormat4ArgNameAndIndex(const std::string& arg_name,\n                                            int32_t index) const override {\n    return TensorDesc4ArgNameAndIndex(arg_name, index)->memory_format();\n  }\n  void SetMemoryFormat4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                       MemoryFormat memory_format) override {\n    MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_memory_format(memory_format);\n  }\n\n  bool InputIsDynamic(const std::string& arg_name, int32_t index) const override {\n    return IsDynamic4ArgNameAndIndex(arg_name, index);\n  }\n  bool OutputIsDynamic(const std::string& arg_name, int32_t index) const override {\n    return IsDynamic4ArgNameAndIndex(arg_name, index);\n  }\n  void SetOutputIsDynamic(const std::string& arg_name, int32_t index, bool is_dynamic) override {\n    return SetIsDynamic4ArgNameAndIndex(arg_name, index, is_dynamic);\n  }\n  bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    return TensorDesc4ArgNameAndIndex(arg_name, index)->is_dynamic();\n  }\n  void SetIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                    bool is_dynamic) override {\n    return MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_is_dynamic(is_dynamic);\n  }\n\n  const ArgVec& inputs() const override { return inputs_; }\n  const ArgVec& outputs() const override { return outputs_; }\n  const ParallelContext& parallel_ctx() const override { return parallel_ctx_; };\n  const ParallelDesc& parallel_desc() const override { return parallel_desc_; }\n  const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,\n                                                 int32_t index) const override {\n    CHECK_EQ(parallel_desc_.hierarchy()->NumAxes(), 1);\n    const auto& bn2sbp = sbp_signature_.bn_in_op2sbp_parallel();\n    std::string bn = GenRepeatedBn(arg_name, index);\n    auto it = bn2sbp.find(bn);\n    CHECK(it != bn2sbp.end());\n    return it->second;\n  }\n  const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    const auto& bn2nd_sbp = nd_sbp_signature_.bn_in_op2nd_sbp();\n    std::string bn = GenRepeatedBn(arg_name, index);\n    auto it = 
bn2nd_sbp.find(bn);\n    CHECK(it != bn2nd_sbp.end());\n    return it->second;\n  }\n  void UpdateArg2TensorDesc(const std::function<Blob*(const std::string&)>& BnInOp2Blob) {\n    for (auto& pair : arg2tensor_desc_) {\n      const auto& arg_pair = pair.first;\n      std::unique_ptr<user_op::NaiveTensorDesc>* arg_tensor_desc_ptr = &pair.second;\n      Blob* blob = BnInOp2Blob(GenRepeatedBn(arg_pair.first, arg_pair.second));\n      CHECK_NOTNULL(blob);\n      if (*arg_tensor_desc_ptr) {\n        Shape tensor_desc_shape = (*arg_tensor_desc_ptr)->shape();\n        tensor_desc_shape.CheckNumAxesIdenticalAndAssign(blob->shape());\n        (*arg_tensor_desc_ptr)->set_shape(tensor_desc_shape);\n        Stride tensor_desc_stride = (*arg_tensor_desc_ptr)->stride();\n        tensor_desc_stride.CheckNumAxesIdenticalAndAssign(blob->stride());\n        (*arg_tensor_desc_ptr)->set_stride(tensor_desc_stride);\n      } else {\n        arg_tensor_desc_ptr->reset(new user_op::NaiveTensorDesc());\n        FillTensorDescWithBlob(blob, arg_tensor_desc_ptr->get());\n      }\n    }\n  }\n\n  int64_t parallel_num() const override { return parallel_ctx_.parallel_num(); }\n\n  const std::string& input(const std::string& arg_name, int32_t index) const override {\n    return user_op_conf().input(arg_name, index);\n  }\n  const std::string& output(const std::string& arg_name, int32_t index) const override {\n    return user_op_conf().output(arg_name, index);\n  }\n  bool has_input(const std::string& arg_name, int32_t index) const override {\n    return user_op_conf().has_input(arg_name, index);\n  }\n  bool has_output(const std::string& arg_name, int32_t index) const override {\n    return user_op_conf().has_output(arg_name, index);\n  }\n  int32_t input_size(const std::string& arg_name) const override {\n    return user_op_conf().input_size(arg_name);\n  }\n  int32_t output_size(const std::string& arg_name) const override {\n    return user_op_conf().output_size(arg_name);\n  }\n  const std::string& op_name() const override { return user_op_conf().op_name(); }\n  const std::string& op_type_name() const override { return user_op_conf().op_type_name(); }\n  const std::string& op_loc() const override { return user_op_conf_.op_conf().loc(); }\n\n private:\n  const user_op::UserOpConfWrapper& user_op_conf() const { return user_op_conf_; }\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return user_op_conf().Attr4Name(attr_name);\n  }\n\n  user_op::UserOpConfWrapper user_op_conf_;\n  ArgVec inputs_;\n  ArgVec outputs_;\n  ParallelContext parallel_ctx_;\n  SbpSignature sbp_signature_;\n  NdSbpSignature nd_sbp_signature_;\n  ParallelDesc parallel_desc_;\n  HashMap<std::pair<std::string, int32_t>, std::unique_ptr<user_op::NaiveTensorDesc>>\n      arg2tensor_desc_;\n  HashMap<std::pair<std::string, int32_t>, user_op::NaiveTensorDesc> arg2logical_tensor_desc_;\n};\n\nclass UserKernelInferContext final : public user_op::KernelInferContext {\n public:\n  explicit UserKernelInferContext(ep::Stream* stream, const KernelConf& kernel_conf)\n      : user_op_conf_(kernel_conf.op_attribute().op_conf()),\n        stream_(stream),\n        base_ctx_(UserKernelBaseContext(kernel_conf)),\n        op_infer_ctx_(kernel_conf) {\n    auto InitArg2Blob = [this](const PbMap<std::string, UserOpConf::ListString>& arg_map) {\n      for (auto it = arg_map.begin(); it != arg_map.end(); ++it) {\n        const std::string& arg_name = it->first;\n        for (int32_t i = 0; i < 
it->second.s_size(); ++i) {\n          arg2tensor_.emplace(std::make_pair(arg_name, i), nullptr);\n        }\n      }\n    };\n    InitArg2Blob(kernel_conf.op_attribute().op_conf().user_conf().input());\n    InitArg2Blob(kernel_conf.op_attribute().op_conf().user_conf().output());\n\n    const auto* op_reg_val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(\n        kernel_conf.op_attribute().op_conf().user_conf().op_type_name());\n    CHECK_NOTNULL(op_reg_val);\n    if (op_reg_val->physical_tensor_desc_infer_fn) {\n      tensor_desc_infer_fn_ = op_reg_val->physical_tensor_desc_infer_fn;\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n  ~UserKernelInferContext() = default;\n\n  DeviceType device_type() const override { return base_ctx_.device_type(); }\n  const ParallelContext& parallel_ctx() const override { return base_ctx_.parallel_ctx(); }\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const override {\n    return base_ctx_.TensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n  const ArgVec& inputs() const override { return base_ctx_.inputs(); }\n  const ArgVec& outputs() const override { return base_ctx_.outputs(); }\n\n  ep::Stream* stream() override { return stream_; }\n  user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t arg_index) override {\n    auto it = arg2tensor_.find(std::make_pair(arg_name, arg_index));\n    CHECK(it != arg2tensor_.end()) << \"Arg (\" << arg_name << \",\" << arg_index << \") is not found\";\n    return it->second.get();\n  }\n  ShapeView ShapeView4ArgNameAndIndex(const std::string& arg_name, int32_t arg_index) override {\n    user_op::Tensor* arg_tensor = Tensor4ArgNameAndIndex(arg_name, arg_index);\n    CHECK(arg_tensor != nullptr) << \"Tensor of arg (\" << arg_name << \",\" << arg_index\n                                 << \") is not found\";\n    return arg_tensor->shape_view();\n  }\n  MutShapeView MutShapeView4ArgNameAndIndex(const std::string& arg_name,\n                                            int32_t arg_index) override {\n    user_op::Tensor* arg_tensor = Tensor4ArgNameAndIndex(arg_name, arg_index);\n    CHECK(arg_tensor != nullptr) << \"Tensor of arg (\" << arg_name << \",\" << arg_index\n                                 << \") is not found\";\n    return arg_tensor->mut_shape_view();\n  }\n\n  user_op::InferContext* MutOpInferContext() override { return &op_infer_ctx_; }\n  const user_op::TensorDescInferFn& GetOpInferFn() const override { return tensor_desc_infer_fn_; }\n\n  void UpdateArg2Tensor(const std::function<Blob*(const std::string&)>& BnInOp2Blob) {\n    for (auto& pair : arg2tensor_) {\n      const auto& arg_pair = pair.first;\n      std::unique_ptr<user_op::BlobTensorView>* arg_tensor_ptr = &pair.second;\n      Blob* blob = BnInOp2Blob(GenRepeatedBn(arg_pair.first, arg_pair.second));\n      if (blob == nullptr) { continue; }\n      if (*arg_tensor_ptr) {\n        arg_tensor_ptr->get()->Reset(blob);\n      } else {\n        arg_tensor_ptr->reset(new user_op::BlobTensorView(blob));\n      }\n    }\n  }\n\n private:\n  const user_op::UserOpConfWrapper& user_op_conf() const override { return user_op_conf_; }\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return user_op_conf().Attr4Name(attr_name);\n  }\n\n  user_op::UserOpConfWrapper user_op_conf_;\n  ep::Stream* stream_;\n  UserKernelBaseContext base_ctx_;\n  
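// op_infer_ctx_ is exposed to the op's shape-inference routine via MutOpInferContext().\n  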
UserKernelOpInferContext op_infer_ctx_;\n  user_op::TensorDescInferFn tensor_desc_infer_fn_;\n  HashMap<std::pair<std::string, int32_t>, std::unique_ptr<user_op::BlobTensorView>> arg2tensor_;\n};\n\nnamespace {\n\nstruct BnTensorPair {\n  std::string bn;\n  std::unique_ptr<user_op::BlobTensorView> tensor;\n};\n\nBnTensorPair MakeBnTensorPair(const std::string& bn) {\n  BnTensorPair pair;\n  pair.bn = bn;\n  return pair;\n}\n\nBnTensorPair MakeBnTensorPair(const std::string& bn,\n                              std::unique_ptr<user_op::BlobTensorView>&& tensor) {\n  BnTensorPair pair;\n  pair.bn = bn;\n  pair.tensor = std::move(tensor);\n  return pair;\n}\n\n}  // namespace\n\nclass UserKernelComputeContext final : public user_op::KernelComputeContext {\n public:\n  explicit UserKernelComputeContext(ep::Stream* stream, const KernelConf& kernel_conf)\n      : user_op_conf_(kernel_conf.op_attribute().op_conf()),\n        stream_(stream),\n        base_ctx_(kernel_conf) {\n    auto InitInOrOut = [&](const PbMap<std::string, UserOpConf::ListString>& arg_map) {\n      for (const auto& it : arg_map) {\n        const std::string& arg_name = it.first;\n        for (int32_t i = 0; i < it.second.s_size(); ++i) {\n          arg2bn_tensor_pair_.emplace(std::make_pair(arg_name, i),\n                                      MakeBnTensorPair(GenRepeatedBn(arg_name, i)));\n        }\n      }\n    };\n    InitInOrOut(kernel_conf.op_attribute().op_conf().user_conf().input());\n    InitInOrOut(kernel_conf.op_attribute().op_conf().user_conf().output());\n    arg2bn_tensor_pair_.emplace(std::make_pair(\"tmp_buffer\", 0),\n                                MakeBnTensorPair(GenRepeatedBn(\"tmp_buffer\", 0)));\n  }\n  ~UserKernelComputeContext() = default;\n\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const override {\n    return base_ctx_.TensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n\n  user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {\n    auto it = arg2bn_tensor_pair_.find(std::make_pair(arg_name, index));\n    if (it == arg2bn_tensor_pair_.end()) { return nullptr; }\n    return it->second.tensor.get();\n  }\n  ep::Stream* stream() override { return stream_; }\n\n  bool UpdateTensorWithCorrBlob(const std::function<Blob*(const std::string&)>& BnInOp2Blob) {\n    bool updated = false;\n    for (auto& pair : arg2bn_tensor_pair_) {\n      std::unique_ptr<user_op::BlobTensorView>* arg_tensor_ptr = &pair.second.tensor;\n      Blob* blob = BnInOp2Blob(pair.second.bn);\n      if (blob == nullptr) {\n        if (*arg_tensor_ptr) {\n          arg_tensor_ptr->reset(nullptr);\n          updated = true;\n        }\n      } else {\n        if (*arg_tensor_ptr) {\n          if (arg_tensor_ptr->get()->blob() != blob) {\n            arg_tensor_ptr->get()->Reset(blob);\n            updated = true;\n          } else {\n            if (blob->blob_desc().is_dynamic()) { updated = true; }\n          }\n        } else {\n          arg_tensor_ptr->reset(new user_op::BlobTensorView(blob));\n          updated = true;\n        }\n      }\n    }\n    return updated;\n  }\n\n  DeviceType device_type() const override { return base_ctx_.device_type(); }\n  const ParallelContext& parallel_ctx() const override { return base_ctx_.parallel_ctx(); }\n\n  const ArgVec& inputs() const override { return base_ctx_.inputs(); }\n  const ArgVec& outputs() const override { return 
base_ctx_.outputs(); }\n\n private:\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return user_op_conf().Attr4Name(attr_name);\n  }\n\n  const user_op::UserOpConfWrapper& user_op_conf() const override { return user_op_conf_; }\n\n  user_op::UserOpConfWrapper user_op_conf_;\n  ep::Stream* stream_;\n  HashMap<std::pair<std::string, int32_t>, BnTensorPair> arg2bn_tensor_pair_;\n  UserKernelBaseContext base_ctx_;\n};\n\n// kernel registry context used in kernel creation\nclass UserKernelRegContext final : public user_op::KernelRegContext {\n public:\n  explicit UserKernelRegContext(const KernelConf& kernel_conf)\n      : user_op_conf_(kernel_conf.op_attribute().op_conf()),\n        base_ctx_(UserKernelBaseContext(kernel_conf)) {}\n  ~UserKernelRegContext() = default;\n\n  DeviceType device_type() const override { return base_ctx_.device_type(); }\n  const ParallelContext& parallel_ctx() const override { return base_ctx_.parallel_ctx(); }\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const override {\n    return base_ctx_.TensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n  const ArgVec& inputs() const override { return base_ctx_.inputs(); }\n  const ArgVec& outputs() const override { return base_ctx_.outputs(); }\n\n  const user_op::UserOpConfWrapper& user_op_conf() const override { return user_op_conf_; }\n\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return user_op_conf().Attr4Name(attr_name);\n  }\n\n private:\n  user_op::UserOpConfWrapper user_op_conf_;\n  UserKernelBaseContext base_ctx_;\n};\n\nUserKernel::~UserKernel() = default;\n\nvoid UserKernel::InitUserKernel(ep::Stream* stream) {\n  ctx_.reset(new UserKernelComputeContext(stream, kernel_conf()));\n  infer_ctx_.reset(new UserKernelInferContext(stream, kernel_conf()));\n  cache_ctx_.reset(new UserKernelCacheContext(stream, kernel_conf()));\n  infer_cache_.reset(new user_op::OpKernelInferCache(kernel_conf(), this));\n  {\n    const std::string& op_type_name =\n        kernel_conf().op_attribute().op_conf().user_conf().op_type_name();\n    const user_op::OpKernelRegistryResult* kernel_reg_val =\n        CHECK_JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(\n            op_type_name, UserKernelRegContext(kernel_conf())));\n    CHECK_NOTNULL(kernel_reg_val);\n    kernel_.reset(kernel_reg_val->create_fn());\n  }\n}\n\nstd::shared_ptr<user_op::OpKernelState> UserKernel::CreateOpKernelState(KernelContext* ctx) {\n  UserKernelInitContext init_ctx(ctx->stream(), kernel_conf());\n  return kernel_->CreateOpKernelState(&init_ctx);\n}\n\nconst std::shared_ptr<user_op::OpKernelState>& UserKernel::GetOpKernelState() const {\n  return opkernel_state_;\n}\n\nvoid UserKernel::ForwardUserKernel(const std::function<Blob*(const std::string&)>& BnInOp2Blob,\n                                   user_op::OpKernelState* opkernel_state) const {\n  const bool updated = ctx_->UpdateTensorWithCorrBlob(BnInOp2Blob);\n\n  if (updated) {\n    cache_ctx_->UpdateTensorWithCorrBlob(BnInOp2Blob);\n    kernel_->InitOpKernelCacheWithFlags(cache_ctx_.get(), user_op::OpKernelCache::kAttrNotChanged,\n                                        &opkernel_cache_);\n  } else {\n    // do nothing\n  }\n#ifdef WITH_CUDA_GRAPHS\n  bool current_scope_capturing = false;\n  if (cuda_graph_exec_) {\n    auto* 
cuda_stream = dynamic_cast<ep::CudaStream*>(ctx_->stream());\n    if (!cuda_stream->IsGraphCapturing()) {\n      if (cuda_graph_exec_->IsInstantiated() && (!updated)) {\n        cuda_stream->LaunchGraph(cuda_graph_exec_.get());\n        return;\n      }\n      const auto* cuda_graph_support =\n          CHECK_NOTNULL(dynamic_cast<const user_op::CudaGraphSupport*>(kernel_.get()));\n      if (cuda_graph_support->IsReadyForCapture(ctx_.get(), opkernel_state,\n                                                opkernel_cache_.get())) {\n        current_scope_capturing = true;\n        cuda_stream->BeginGraphCapture();\n      }\n    }\n  }\n#endif  // WITH_CUDA_GRAPHS\n\n  kernel_->Compute(ctx_.get(), opkernel_state, opkernel_cache_.get());\n\n#ifdef WITH_CUDA_GRAPHS\n  if (cuda_graph_exec_ && current_scope_capturing) {\n    auto* cuda_stream = dynamic_cast<ep::CudaStream*>(ctx_->stream());\n    cuda_stream->EndGraphCapture(cuda_graph_exec_.get());\n    cuda_stream->LaunchGraph(cuda_graph_exec_.get());\n  }\n#endif  // WITH_CUDA_GRAPHS\n}\n\nbool UserKernel::IsCudaGraphSupported() const {\n#ifdef WITH_CUDA_GRAPHS\n  return cuda_graph_exec_.get() != nullptr;\n#else\n  return false;\n#endif  // WITH_CUDA_GRAPHS\n}\n\nbool UserKernel::IsReadyForCudaGraphCapture(KernelContext* ctx) const {\n  const auto* cuda_graph_support = dynamic_cast<const user_op::CudaGraphSupport*>(kernel_.get());\n  if (cuda_graph_support == nullptr) { return false; }\n  return cuda_graph_support->IsReadyForCapture(ctx_.get(), opkernel_state_.get(),\n                                               opkernel_cache_.get());\n}\n\nvoid UserKernel::VirtualKernelInit(KernelContext* ctx) {\n  InitUserKernel(ctx->stream());\n  CHECK(opkernel_state_.get() == nullptr);\n  opkernel_state_ = CreateOpKernelState(ctx);\n  kernel_->InitOpKernelCacheWithFlags(cache_ctx_.get(), user_op::OpKernelCache::kAllMayChanged,\n                                      &opkernel_cache_);\n#ifdef WITH_CUDA_GRAPHS\n  if (ParseBooleanFromEnv(\"ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH\", false)\n      && (!ParseBooleanFromEnv(\"ONEFLOW_GRAPH_ENABLE_STREAM_ORDERED_MEMORY_ALLOCATION\", false))) {\n    UserKernelInitContext init_ctx(ctx->stream(), kernel_conf());\n    auto* cuda_stream = dynamic_cast<ep::CudaStream*>(ctx->stream());\n    const auto* cuda_graph_support = dynamic_cast<const user_op::CudaGraphSupport*>(kernel_.get());\n    if (cuda_stream != nullptr) {\n      if (cuda_graph_support != nullptr\n          && cuda_graph_support->IsCudaGraphSupported(&init_ctx, opkernel_state_.get())) {\n        cuda_graph_exec_.reset(new ep::CudaGraphExecutable());\n        VLOG(3) << \"CUDA Graphs Kernel: \" << op_conf().name() << \" (\"\n                << op_conf().user_conf().op_type_name() << \")\";\n      } else {\n        VLOG(3) << \"CUDA Graphs not supported: \" << op_conf().name() << \" (\"\n                << op_conf().user_conf().op_type_name() << \")\";\n      }\n    }\n  }\n#endif  // WITH_CUDA_GRAPHS\n}\n\nvoid UserKernel::ForwardDataContent(KernelContext* ctx) const {\n  const auto BnInOp2Blob = [ctx](const std::string& bn) { return ctx->BnInOp2Blob(bn); };\n  ForwardUserKernel(BnInOp2Blob, opkernel_state_.get());\n}\n\nvoid UserKernel::ForwardShape(KernelContext* ctx) const {\n  const auto BnInOp2Blob = [ctx](const std::string& bn) { return ctx->BnInOp2Blob(bn); };\n  infer_ctx_->UpdateArg2Tensor(BnInOp2Blob);\n  infer_cache_->UpdateCacheKey(infer_ctx_.get());\n  if (!infer_cache_->IsCacheHit()) {\n    auto* op_infer_ctx = 
dynamic_cast<UserKernelOpInferContext*>(infer_ctx_->MutOpInferContext());\n    CHECK_NOTNULL(op_infer_ctx);\n    op_infer_ctx->UpdateArg2TensorDesc(BnInOp2Blob);\n    kernel_->InferShape(infer_ctx_.get());\n    for (const auto& out_arg_pair : infer_ctx_->outputs()) {\n      const Shape& static_shape =\n          infer_ctx_->TensorDesc4ArgNameAndIndex(out_arg_pair.first, out_arg_pair.second)->shape();\n      const ShapeView& shape_view =\n          infer_ctx_->ShapeView4ArgNameAndIndex(out_arg_pair.first, out_arg_pair.second);\n      CHECK_LE(shape_view.elem_cnt(), static_shape.elem_cnt())\n          << \"InferShape of OpKernel (op_type_name: \" << op_conf().user_conf().op_type_name()\n          << \", op_name: \" << op_conf().name()\n          << \") raise error, output arg's (name: \" << out_arg_pair.first\n          << \", index: \" << out_arg_pair.second << \") runtime shape \" << shape_view.ToString()\n          << \" surpass the limit of static shape \" << static_shape.ToString();\n    }\n    infer_cache_->UpdateCacheValue(infer_ctx_.get());\n  } else {\n    std::shared_ptr<const OpInferCacheValue> cache_value_ptr = infer_cache_->GetCacheValue();\n    FOR_RANGE(int, i, 0, infer_ctx_->outputs().size()) {\n      const auto& out_arg_pair = infer_ctx_->outputs().at(i);\n      MutShapeView mut_shape_view =\n          infer_ctx_->MutShapeView4ArgNameAndIndex(out_arg_pair.first, out_arg_pair.second);\n      mut_shape_view.set_shape(*cache_value_ptr->obn_idx2shape_sym.at(i));\n    }\n  }\n}\n\nbool UserKernel::IsStateless() const { return !kernel_->AlwaysComputeWhenAllOutputsEmpty(); }\nNEW_REGISTER_KERNEL(OperatorConf::kUserConf, UserKernel).SetIsMatchedPred([](const KernelConf&) {\n  return true;\n});\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/user_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/framework/op_kernel_infer_cache.h\"\n#include \"oneflow/core/framework/user_op_tensor.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\n#endif  // WITH_CUDA\n\nnamespace oneflow {\n\nclass UserKernelComputeContext;\nclass UserKernelInferContext;\nclass UserKernelInitAndCacheContext;\n\nnamespace user_op {\nclass OpKernelCache;\n}\n\nclass UserKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(UserKernel);\n  UserKernel() = default;\n  ~UserKernel() override;\n\n  void InitUserKernel(ep::Stream* stream);\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(KernelContext* ctx);\n  const std::shared_ptr<user_op::OpKernelState>& GetOpKernelState() const;\n  void ForwardUserKernel(const std::function<Blob*(const std::string&)>& BnInOp2Blob,\n                         user_op::OpKernelState* opkernel_state) const;\n  bool IsCudaGraphSupported() const;\n  bool IsReadyForCudaGraphCapture(KernelContext* ctx) const;\n\n private:\n  void VirtualKernelInit(KernelContext* ctx) override;\n\n  void ForwardDataContent(KernelContext* ctx) const override;\n  void ForwardShape(KernelContext* ctx) const override;\n\n  bool IsStateless() const override;\n  bool IsKernelLaunchSynchronized() const override { return kernel_->IsKernelLaunchSynchronized(); }\n\n  mutable std::shared_ptr<user_op::OpKernelCache> opkernel_cache_;\n  std::shared_ptr<user_op::OpKernelState> opkernel_state_;\n  std::unique_ptr<const user_op::OpKernel> kernel_;\n  std::unique_ptr<UserKernelComputeContext> ctx_;\n  std::unique_ptr<UserKernelInitAndCacheContext> cache_ctx_;\n  std::unique_ptr<UserKernelInferContext> infer_ctx_;\n  std::unique_ptr<user_op::OpKernelInferCache> infer_cache_;\n#ifdef WITH_CUDA_GRAPHS\n  std::unique_ptr<ep::CudaGraphExecutable> cuda_graph_exec_;\n#endif  // WITH_CUDA_GRAPHS\n};\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/util/cuda_half_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_UTIL_CUDA_HALF_UTIL_H_\n#define ONEFLOW_CORE_KERNEL_UTIL_CUDA_HALF_UTIL_H_\n\n#include \"oneflow/core/device/cuda_util.h\"\n\nnamespace oneflow {\n\n#define HALF_CHECK_FAILED                                             \\\n  printf(\"half operations are only supported when CUDA_ARCH >= 530\"); \\\n  assert(false)\n\n__inline__ __device__ half hone() { return __float2half(1.0); }\n__inline__ __device__ half hzero() { return __float2half(0.0); }\n\n__inline__ half float16_2half(float16 x) {\n  // TODO: Potential loss of accuracy\n  half* ret = reinterpret_cast<half*>(&x);\n  return *ret;\n}\n\n__inline__ float16 half2float16(half x) {\n  // TODO: Potential loss of accuracy\n  float16* ret = reinterpret_cast<float16*>(&x);\n  return *ret;\n}\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_UTIL_CUDA_HALF_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/kernel/util/numeric_limits.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// reference: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/NumericLimits.cuh\n#pragma once\n#include <limits.h>\n#include <math.h>\n#include <float.h>\n\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n\n// numeric_limits.cuh is a holder for numeric limits definitions of commonly used\n// types. This header is very specific to ROCm HIP and may be removed in the future.\n\n// The lower_bound and upper_bound constants are same as lowest and max for\n// integral types, but are -inf and +inf for floating point types. They are\n// useful in implementing min, max, etc.\n\nnamespace oneflow {\nnamespace detail {\n\n#if defined(__CUDACC__)\n#define OF_NUMERICS_FUNC static inline __host__ __device__\n#else\n#define OF_NUMERICS_FUNC static inline\n#endif\n\ntemplate<typename T>\nstruct numeric_limits {};\n\n// WARNING: the following oneflow::numeric_limits definitions are there only to support\n//          HIP compilation for the moment. Use std::numeric_limits if you are not\n//          compiling for ROCm.\n//          from @colesbury: \"The functions on numeric_limits aren't marked with\n//          __device__ which is why they don't work with ROCm. 
\nnamespace {\n// ROCm doesn't like INFINITY either.\nconstexpr double inf = INFINITY;\n}  // namespace\n\ntemplate<>\nstruct numeric_limits<bool> {\n  OF_NUMERICS_FUNC bool lowest() { return false; }\n  OF_NUMERICS_FUNC bool max() { return true; }\n  OF_NUMERICS_FUNC bool lower_bound() { return false; }\n  OF_NUMERICS_FUNC bool upper_bound() { return true; }\n};\n\ntemplate<>\nstruct numeric_limits<uint8_t> {\n  OF_NUMERICS_FUNC uint8_t lowest() { return 0; }\n  OF_NUMERICS_FUNC uint8_t max() { return UINT8_MAX; }\n  OF_NUMERICS_FUNC uint8_t lower_bound() { return 0; }\n  OF_NUMERICS_FUNC uint8_t upper_bound() { return UINT8_MAX; }\n};\n\ntemplate<>\nstruct numeric_limits<int8_t> {\n  OF_NUMERICS_FUNC int8_t lowest() { return INT8_MIN; }\n  OF_NUMERICS_FUNC int8_t max() { return INT8_MAX; }\n  OF_NUMERICS_FUNC int8_t lower_bound() { return INT8_MIN; }\n  OF_NUMERICS_FUNC int8_t upper_bound() { return INT8_MAX; }\n};\n\ntemplate<>\nstruct numeric_limits<int16_t> {\n  OF_NUMERICS_FUNC int16_t lowest() { return INT16_MIN; }\n  OF_NUMERICS_FUNC int16_t max() { return INT16_MAX; }\n  OF_NUMERICS_FUNC int16_t lower_bound() { return INT16_MIN; }\n  OF_NUMERICS_FUNC int16_t upper_bound() { return INT16_MAX; }\n};\n\ntemplate<>\nstruct numeric_limits<int32_t> {\n  OF_NUMERICS_FUNC int32_t lowest() { return INT32_MIN; }\n  OF_NUMERICS_FUNC int32_t max() { return INT32_MAX; }\n  OF_NUMERICS_FUNC int32_t lower_bound() { return INT32_MIN; }\n  OF_NUMERICS_FUNC int32_t upper_bound() { return INT32_MAX; }\n};\n\ntemplate<>\nstruct numeric_limits<int64_t> {\n#ifdef _MSC_VER\n  OF_NUMERICS_FUNC int64_t lowest() { return _I64_MIN; }\n  OF_NUMERICS_FUNC int64_t max() { return _I64_MAX; }\n  OF_NUMERICS_FUNC int64_t lower_bound() { return _I64_MIN; }\n  OF_NUMERICS_FUNC int64_t upper_bound() { return _I64_MAX; }\n#else\n  OF_NUMERICS_FUNC int64_t lowest() { return INT64_MIN; }\n  OF_NUMERICS_FUNC int64_t max() { return INT64_MAX; }\n  OF_NUMERICS_FUNC int64_t lower_bound() { return INT64_MIN; }\n  OF_NUMERICS_FUNC int64_t upper_bound() { return INT64_MAX; }\n#endif\n};\n\ntemplate<>\nstruct numeric_limits<float> {\n  OF_NUMERICS_FUNC float lowest() { return -FLT_MAX; }\n  OF_NUMERICS_FUNC float max() { return FLT_MAX; }\n  OF_NUMERICS_FUNC float lower_bound() { return -static_cast<float>(inf); }\n  OF_NUMERICS_FUNC float upper_bound() { return static_cast<float>(inf); }\n};\n\n#if defined(__CUDACC__)\nstatic __device__ unsigned short int HALF_LOWEST = 0xfbff;\nstatic __device__ unsigned short int HALF_MAX = 0x7bff;\nstatic __device__ unsigned short int HALF_LOWER_BOUND = 0xfc00;\nstatic __device__ unsigned short int HALF_UPPER_BOUND = 0x7c00;\ntemplate<>\nstruct numeric_limits<half> {\n  static inline __device__ half lowest() { return *reinterpret_cast<const __half*>(&HALF_LOWEST); }\n  static inline __device__ half max() { return *reinterpret_cast<const __half*>(&HALF_MAX); }\n  static inline __device__ half lower_bound() {\n    return *reinterpret_cast<const __half*>(&HALF_LOWER_BOUND);\n  }\n  static inline __device__ half upper_bound() {\n    return *reinterpret_cast<const __half*>(&HALF_UPPER_BOUND);\n  }\n};\n\n#if CUDA_VERSION >= 11000\n\nstatic __device__ unsigned short int NV_BFLOAT16_LOWEST = 0xff7f;\nstatic __device__ unsigned short int NV_BFLOAT16_MAX = 0x7f7f;\nstatic __device__ unsigned short int NV_BFLOAT16_LOWER_BOUND = 0xff80;\nstatic __device__ unsigned short int NV_BFLOAT16_UPPER_BOUND = 0x7f80;\ntemplate<>\nstruct 
numeric_limits<nv_bfloat16> {\n  static inline __device__ nv_bfloat16 lowest() {\n    return *reinterpret_cast<const __nv_bfloat16*>(&NV_BFLOAT16_LOWEST);\n  }\n  static inline __device__ nv_bfloat16 max() {\n    return *reinterpret_cast<const __nv_bfloat16*>(&NV_BFLOAT16_MAX);\n  }\n  static inline __device__ nv_bfloat16 lower_bound() {\n    return *reinterpret_cast<const __nv_bfloat16*>(&NV_BFLOAT16_LOWER_BOUND);\n  }\n  static inline __device__ nv_bfloat16 upper_bound() {\n    return *reinterpret_cast<const __nv_bfloat16*>(&NV_BFLOAT16_UPPER_BOUND);\n  }\n};\n\n#endif  // CUDA_VERSION >= 11000\n\n#endif  // defined(__CUDACC__)\n\ntemplate<>\nstruct numeric_limits<double> {\n  OF_NUMERICS_FUNC double lowest() { return -DBL_MAX; }\n  OF_NUMERICS_FUNC double max() { return DBL_MAX; }\n  OF_NUMERICS_FUNC double lower_bound() { return -inf; }\n  OF_NUMERICS_FUNC double upper_bound() { return inf; }\n};\n\n}  // namespace detail\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/util/numerics.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// reference: https://github.com/pytorch/pytorch/blob/master/aten/src/THC/THCNumerics.cuh\n#ifndef ONEFLOW_CORE_KERNEL_UTIL_NUMERICS_H\n#define ONEFLOW_CORE_KERNEL_UTIL_NUMERICS_H\n#pragma once\n\n#include <limits.h>\n#include <math.h>\n#include <float.h>\n#include <cstdlib>\n#include <assert.h>\n\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/util/numeric_limits.cuh\"\n\nnamespace oneflow {\nnamespace detail {\n\ntemplate<typename T>\nstruct numerics {};\n\ntemplate<typename T>\nOF_NUMERICS_FUNC T powi(T a, T b) {\n  assert(numerics<T>::ge(b, 0));\n  T result = 1;\n  while (b) {\n    if (b & 1) { result *= a; }\n    b /= 2;\n    a *= a;\n  }\n  return result;\n}\n\ntemplate<>\nstruct numerics<uint8_t> {\n  OF_NUMERICS_FUNC uint8_t min() { return detail::numeric_limits<uint8_t>::lowest(); }\n  OF_NUMERICS_FUNC uint8_t max() { return detail::numeric_limits<uint8_t>::max(); }\n  OF_NUMERICS_FUNC uint8_t lower_bound() { return detail::numeric_limits<uint8_t>::lower_bound(); }\n  OF_NUMERICS_FUNC uint8_t upper_bound() { return detail::numeric_limits<uint8_t>::upper_bound(); }\n\n  OF_NUMERICS_FUNC bool lt(uint8_t a, uint8_t b) { return a < b; }\n  OF_NUMERICS_FUNC bool le(uint8_t a, uint8_t b) { return a <= b; }\n  OF_NUMERICS_FUNC bool gt(uint8_t a, uint8_t b) { return a > b; }\n  OF_NUMERICS_FUNC bool ge(uint8_t a, uint8_t b) { return a >= b; }\n  OF_NUMERICS_FUNC bool eq(uint8_t a, uint8_t b) { return a == b; }\n  OF_NUMERICS_FUNC bool ne(uint8_t a, uint8_t b) { return a != b; }\n\n  OF_NUMERICS_FUNC uint8_t add(uint8_t a, uint8_t b) { return a + b; }\n  OF_NUMERICS_FUNC uint8_t mul(uint8_t a, uint8_t b) { return a * b; }\n  OF_NUMERICS_FUNC uint8_t sub(uint8_t a, uint8_t b) { return a - b; }\n  OF_NUMERICS_FUNC uint8_t div(uint8_t a, uint8_t b) { return a / b; }\n  OF_NUMERICS_FUNC uint8_t pow(uint8_t a, uint8_t b) { return powi<uint8_t>(a, b); }\n  OF_NUMERICS_FUNC bool isnan(uint8_t a) { return false; }\n  OF_NUMERICS_FUNC bool isinf(uint8_t a) { return false; }\n};\n\n#ifdef _MSC_VER\n// Suppress warning C4804: '/': unsafe use of type 'bool' in operation\n#pragma warning(push)\n#pragma warning(disable : 4804)\n#endif\n\ntemplate<>\nstruct numerics<bool> {\n  OF_NUMERICS_FUNC bool min() { return detail::numeric_limits<bool>::lowest(); }\n  OF_NUMERICS_FUNC bool max() { return detail::numeric_limits<bool>::max(); }\n  OF_NUMERICS_FUNC bool lower_bound() { return detail::numeric_limits<bool>::lower_bound(); }\n  OF_NUMERICS_FUNC bool upper_bound() { return detail::numeric_limits<bool>::upper_bound(); }\n\n  OF_NUMERICS_FUNC bool lt(bool a, bool b) { return a < b; }\n  OF_NUMERICS_FUNC bool le(bool a, bool b) { return a <= b; }\n  OF_NUMERICS_FUNC bool gt(bool a, bool b) { return a > b; }\n  OF_NUMERICS_FUNC bool ge(bool a, bool b) { return a >= b; }\n  OF_NUMERICS_FUNC bool eq(bool a, bool b) { 
return a == b; }\n  OF_NUMERICS_FUNC bool ne(bool a, bool b) { return a != b; }\n  OF_NUMERICS_FUNC bool add(bool a, bool b) { return a + b; }\n  OF_NUMERICS_FUNC bool mul(bool a, bool b) { return a && b; }\n  OF_NUMERICS_FUNC bool sub(bool a, bool b) { return a - b; }\n  OF_NUMERICS_FUNC bool div(bool a, bool b) { return a / b; }\n  OF_NUMERICS_FUNC bool isnan(bool a) { return false; }\n  OF_NUMERICS_FUNC bool isinf(bool a) { return false; }\n};\n\n#ifdef _MSC_VER\n#pragma warning(pop)\n#endif\n\ntemplate<>\nstruct numerics<int8_t> {\n  OF_NUMERICS_FUNC int8_t min() { return detail::numeric_limits<int8_t>::lowest(); }\n  OF_NUMERICS_FUNC int8_t max() { return detail::numeric_limits<int8_t>::max(); }\n  OF_NUMERICS_FUNC int8_t lower_bound() { return detail::numeric_limits<int8_t>::lower_bound(); }\n  OF_NUMERICS_FUNC int8_t upper_bound() { return detail::numeric_limits<int8_t>::upper_bound(); }\n\n  OF_NUMERICS_FUNC bool lt(int8_t a, int8_t b) { return a < b; }\n  OF_NUMERICS_FUNC bool le(int8_t a, int8_t b) { return a <= b; }\n  OF_NUMERICS_FUNC bool gt(int8_t a, int8_t b) { return a > b; }\n  OF_NUMERICS_FUNC bool ge(int8_t a, int8_t b) { return a >= b; }\n  OF_NUMERICS_FUNC bool eq(int8_t a, int8_t b) { return a == b; }\n  OF_NUMERICS_FUNC bool ne(int8_t a, int8_t b) { return a != b; }\n\n  OF_NUMERICS_FUNC int8_t add(int8_t a, int8_t b) { return a + b; }\n  OF_NUMERICS_FUNC int8_t mul(int8_t a, int8_t b) { return a * b; }\n  OF_NUMERICS_FUNC int8_t sub(int8_t a, int8_t b) { return a - b; }\n  OF_NUMERICS_FUNC int8_t div(int8_t a, int8_t b) { return a / b; }\n  OF_NUMERICS_FUNC int8_t pow(int8_t a, int8_t b) { return powi<int8_t>(a, b); }\n  OF_NUMERICS_FUNC bool isnan(int8_t a) { return false; }\n  OF_NUMERICS_FUNC bool isinf(int8_t a) { return false; }\n};\n\ntemplate<>\nstruct numerics<int16_t> {\n  OF_NUMERICS_FUNC int16_t min() { return detail::numeric_limits<int16_t>::lowest(); }\n  OF_NUMERICS_FUNC int16_t max() { return detail::numeric_limits<int16_t>::max(); }\n  OF_NUMERICS_FUNC int16_t lower_bound() { return detail::numeric_limits<int16_t>::lower_bound(); }\n  OF_NUMERICS_FUNC int16_t upper_bound() { return detail::numeric_limits<int16_t>::upper_bound(); }\n\n  OF_NUMERICS_FUNC bool lt(int16_t a, int16_t b) { return a < b; }\n  OF_NUMERICS_FUNC bool le(int16_t a, int16_t b) { return a <= b; }\n  OF_NUMERICS_FUNC bool gt(int16_t a, int16_t b) { return a > b; }\n  OF_NUMERICS_FUNC bool ge(int16_t a, int16_t b) { return a >= b; }\n  OF_NUMERICS_FUNC bool eq(int16_t a, int16_t b) { return a == b; }\n  OF_NUMERICS_FUNC bool ne(int16_t a, int16_t b) { return a != b; }\n\n  OF_NUMERICS_FUNC int16_t add(int16_t a, int16_t b) { return a + b; }\n  OF_NUMERICS_FUNC int16_t mul(int16_t a, int16_t b) { return a * b; }\n  OF_NUMERICS_FUNC int16_t sub(int16_t a, int16_t b) { return a - b; }\n  OF_NUMERICS_FUNC int16_t div(int16_t a, int16_t b) { return a / b; }\n  OF_NUMERICS_FUNC int16_t pow(int16_t a, int16_t b) { return powi<int16_t>(a, b); }\n  OF_NUMERICS_FUNC bool isnan(int16_t a) { return false; }\n  OF_NUMERICS_FUNC bool isinf(int16_t a) { return false; }\n};\n\ntemplate<>\nstruct numerics<int32_t> {\n  OF_NUMERICS_FUNC int32_t min() { return detail::numeric_limits<int32_t>::lowest(); }\n  OF_NUMERICS_FUNC int32_t max() { return detail::numeric_limits<int32_t>::max(); }\n  OF_NUMERICS_FUNC int32_t lower_bound() { return detail::numeric_limits<int32_t>::lower_bound(); }\n  OF_NUMERICS_FUNC int32_t upper_bound() { return detail::numeric_limits<int32_t>::upper_bound(); }\n\n  
OF_NUMERICS_FUNC bool lt(int32_t a, int32_t b) { return a < b; }\n  OF_NUMERICS_FUNC bool le(int32_t a, int32_t b) { return a <= b; }\n  OF_NUMERICS_FUNC bool gt(int32_t a, int32_t b) { return a > b; }\n  OF_NUMERICS_FUNC bool ge(int32_t a, int32_t b) { return a >= b; }\n  OF_NUMERICS_FUNC bool eq(int32_t a, int32_t b) { return a == b; }\n  OF_NUMERICS_FUNC bool ne(int32_t a, int32_t b) { return a != b; }\n\n
  OF_NUMERICS_FUNC int32_t add(int32_t a, int32_t b) { return a + b; }\n  OF_NUMERICS_FUNC int32_t mul(int32_t a, int32_t b) { return a * b; }\n  OF_NUMERICS_FUNC int32_t sub(int32_t a, int32_t b) { return a - b; }\n  OF_NUMERICS_FUNC int32_t div(int32_t a, int32_t b) { return a / b; }\n  OF_NUMERICS_FUNC int32_t pow(int32_t a, int32_t b) { return powi<int32_t>(a, b); }\n  OF_NUMERICS_FUNC bool isnan(int32_t a) { return false; }\n  OF_NUMERICS_FUNC bool isinf(int32_t a) { return false; }\n};\n\n
template<>\nstruct numerics<int64_t> {\n  OF_NUMERICS_FUNC int64_t min() { return detail::numeric_limits<int64_t>::lowest(); }\n  OF_NUMERICS_FUNC int64_t max() { return detail::numeric_limits<int64_t>::max(); }\n  OF_NUMERICS_FUNC int64_t lower_bound() { return detail::numeric_limits<int64_t>::lower_bound(); }\n  OF_NUMERICS_FUNC int64_t upper_bound() { return detail::numeric_limits<int64_t>::upper_bound(); }\n\n
  OF_NUMERICS_FUNC bool lt(int64_t a, int64_t b) { return a < b; }\n  OF_NUMERICS_FUNC bool le(int64_t a, int64_t b) { return a <= b; }\n  OF_NUMERICS_FUNC bool gt(int64_t a, int64_t b) { return a > b; }\n  OF_NUMERICS_FUNC bool ge(int64_t a, int64_t b) { return a >= b; }\n  OF_NUMERICS_FUNC bool eq(int64_t a, int64_t b) { return a == b; }\n  OF_NUMERICS_FUNC bool ne(int64_t a, int64_t b) { return a != b; }\n\n
  OF_NUMERICS_FUNC int64_t add(int64_t a, int64_t b) { return a + b; }\n  OF_NUMERICS_FUNC int64_t mul(int64_t a, int64_t b) { return a * b; }\n  OF_NUMERICS_FUNC int64_t sub(int64_t a, int64_t b) { return a - b; }\n  OF_NUMERICS_FUNC int64_t div(int64_t a, int64_t b) { return a / b; }\n  OF_NUMERICS_FUNC int64_t pow(int64_t a, int64_t b) { return powi<int64_t>(a, b); }\n  OF_NUMERICS_FUNC bool isnan(int64_t a) { return false; }\n  OF_NUMERICS_FUNC bool isinf(int64_t a) { return false; }\n};\n\n
// DEPRECATED: use math functions from std and cuda math API (if needed)\ntemplate<>\nstruct numerics<float> {\n  OF_NUMERICS_FUNC float min() { return detail::numeric_limits<float>::lowest(); }\n  OF_NUMERICS_FUNC float max() { return detail::numeric_limits<float>::max(); }\n  OF_NUMERICS_FUNC float lower_bound() { return detail::numeric_limits<float>::lower_bound(); }\n  OF_NUMERICS_FUNC float upper_bound() { return detail::numeric_limits<float>::upper_bound(); }\n\n
  OF_NUMERICS_FUNC bool lt(float a, float b) { return a < b; }\n  OF_NUMERICS_FUNC bool le(float a, float b) { return a <= b; }\n  OF_NUMERICS_FUNC bool gt(float a, float b) { return a > b; }\n  OF_NUMERICS_FUNC bool ge(float a, float b) { return a >= b; }\n  OF_NUMERICS_FUNC bool eq(float a, float b) { return a == b; }\n  OF_NUMERICS_FUNC bool ne(float a, float b) { return a != b; }\n\n
  OF_NUMERICS_FUNC float sqrt(float a) { return sqrtf(a); }\n  OF_NUMERICS_FUNC float atan(float a) { return atanf(a); }\n  OF_NUMERICS_FUNC float add(float a, float b) { return a + b; }\n  OF_NUMERICS_FUNC float div(float a, float b) { return a / b; }\n  OF_NUMERICS_FUNC float mul(float a, float b) { return a * b; }\n  OF_NUMERICS_FUNC float sub(float a, float b) { return a - b; }\n  OF_NUMERICS_FUNC float pow(float a, float b) { return powf(a, b); }\n  OF_NUMERICS_FUNC bool isnan(float a) { return ::isnan(a); }\n  OF_NUMERICS_FUNC bool isinf(float a) { return ::isinf(a); }\n};\n\n
#if defined(__CUDACC__)\ntemplate<>\nstruct numerics<half> {\n  OF_NUMERICS_FUNC bool isnan(half a) { return ::isnan((float)a); }\n};\n#endif\n\n
template<>\nstruct numerics<double> {\n  OF_NUMERICS_FUNC double min() { return detail::numeric_limits<double>::lowest(); }\n  OF_NUMERICS_FUNC double max() { return detail::numeric_limits<double>::max(); }\n  OF_NUMERICS_FUNC double lower_bound() { return detail::numeric_limits<double>::lower_bound(); }\n  OF_NUMERICS_FUNC double upper_bound() { return detail::numeric_limits<double>::upper_bound(); }\n\n
  OF_NUMERICS_FUNC bool lt(double a, double b) { return a < b; }\n  OF_NUMERICS_FUNC bool le(double a, double b) { return a <= b; }\n  OF_NUMERICS_FUNC bool gt(double a, double b) { return a > b; }\n  OF_NUMERICS_FUNC bool ge(double a, double b) { return a >= b; }\n  OF_NUMERICS_FUNC bool eq(double a, double b) { return a == b; }\n  OF_NUMERICS_FUNC bool ne(double a, double b) { return a != b; }\n\n
  OF_NUMERICS_FUNC double sqrt(double a) { return ::sqrt(a); }\n  OF_NUMERICS_FUNC double atan(double a) { return ::atan(a); }\n  OF_NUMERICS_FUNC double add(double a, double b) { return a + b; }\n  OF_NUMERICS_FUNC double div(double a, double b) { return a / b; }\n  OF_NUMERICS_FUNC double mul(double a, double b) { return a * b; }\n  OF_NUMERICS_FUNC double sub(double a, double b) { return a - b; }\n  OF_NUMERICS_FUNC double pow(double a, double b) { return ::pow(a, b); }\n  OF_NUMERICS_FUNC bool isnan(double a) { return ::isnan(a); }\n  OF_NUMERICS_FUNC bool isinf(double a) { return ::isinf(a); }\n};\n\n
}  // namespace detail\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_UTIL_NUMERICS_H\n"
  },
  {
    "path": "oneflow/core/kernel/wait_and_send_ids_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/kernel/wait_and_send_ids_kernel.h\"\n#include \"oneflow/core/common/buffer_manager.h\"\n#include \"oneflow/core/job/job_instance.h\"\n#include \"oneflow/core/job/global_for.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nvoid WaitAndSendIdsKernel<T>::VirtualKernelInit(KernelContext* ctx) {\n  ctx->set_state(std::make_shared<WaitAndSendIdsStatus>());\n}\n\ntemplate<typename T>\nvoid WaitAndSendIdsKernel<T>::ForwardDataContent(KernelContext* ctx) const {\n  auto* status = CHECK_NOTNULL(dynamic_cast<WaitAndSendIdsStatus*>(ctx->state().get()));\n  if (status->out_idx_ >= status->out_num_) {\n    CHECK(this->op_conf().wait_and_send_ids_conf().has_job_name());\n    const auto& job_name = this->op_conf().wait_and_send_ids_conf().job_name();\n    auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<JobInstance>>>::Get();\n    auto* buffer = buffer_mgr->Get(GetSourceTickBufferName(job_name));\n    status->in_id_ = 0;\n    {\n      std::shared_ptr<JobInstance> job_instance;\n      status->buffer_status_ = buffer->Pull(&job_instance);\n    }\n    if (status->buffer_status_ == kBufferStatusErrorClosed) { return; }\n    status->out_idx_ = 0;\n    status->out_num_ = 1;\n  }\n\n  *ctx->BnInOp2Blob(\"out\")->mut_dptr<T>() = 0;\n  ++status->out_idx_;\n}\n\nADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kWaitAndSendIdsConf, WaitAndSendIdsKernel,\n                               INT_DATA_TYPE_SEQ);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/kernel/wait_and_send_ids_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNEL_WAIT_AND_SEND_IDS_KERNEL_H_\n#define ONEFLOW_CORE_KERNEL_WAIT_AND_SEND_IDS_KERNEL_H_\n\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/common/buffer_manager.h\"\n\nnamespace oneflow {\n\nstruct WaitAndSendIdsStatus final : public KernelState {\n  BufferStatus buffer_status_;\n  int64_t in_id_;\n  int64_t out_idx_;\n  size_t out_num_;\n};\n\ntemplate<typename T>\nclass WaitAndSendIdsKernel final : public Kernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(WaitAndSendIdsKernel);\n  WaitAndSendIdsKernel() = default;\n  ~WaitAndSendIdsKernel() = default;\n\n private:\n  void VirtualKernelInit(KernelContext* ctx) override;\n  void ForwardDataContent(KernelContext* ctx) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNEL_WAIT_AND_SEND_IDS_KERNEL_H_\n"
  },
  {
    "path": "oneflow/core/lazy/actor/acc_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor.h\"\n\nnamespace oneflow {\n\nclass AccActor final : public Actor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AccActor);\n  AccActor() = default;\n  ~AccActor() override = default;\n\n private:\n  void Act() override;\n  void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override;\n\n  void VirtualActorInit(const TaskProto& proto) override;\n\n  int32_t acc_cnt_{};\n  int32_t max_acc_cnt_{};\n};\n\nvoid AccActor::VirtualActorInit(const TaskProto& proto) {\n  const Shape& in_time_shape = Singleton<RegstMgr>::Get()\n                                   ->RegstDesc4RegstDescId(Name2SoleRegstDescId(\"in\"))\n                                   .data_regst_time_shape();\n  const Shape& out_time_shape = Singleton<RegstMgr>::Get()\n                                    ->RegstDesc4RegstDescId(Name2SoleRegstDescId(\"out\"))\n                                    .data_regst_time_shape();\n  CHECK_GE(in_time_shape.elem_cnt(), out_time_shape.elem_cnt());\n  max_acc_cnt_ = in_time_shape.elem_cnt() / out_time_shape.elem_cnt();\n  acc_cnt_ = 0;\n  OF_SET_MSG_HANDLER(&AccActor::HandlerNormal);\n}\n\nvoid AccActor::Act() {\n  if (acc_cnt_ == 0) {\n    Regst* out_regst = GetNaiveCurWriteable(\"out\");\n    Regst* in_regst = GetNaiveCurReadable(\"in\");\n    const Blob* in_blob = in_regst->GetMutSoleBlob();\n    Blob* out_blob = out_regst->GetMutSoleBlob();\n    const size_t size = in_blob->ByteSizeOfBlobBody();\n    CHECK_EQ(out_blob->ByteSizeOfBlobBody(), size);\n    AutoMemcpy(actor_ctx()->stream_ctx()->stream(), out_blob->ForceMutDptr(), in_blob->dptr(), size,\n               out_blob->mem_case(), in_blob->mem_case());\n  } else {\n    AsyncLaunchKernel();\n  }\n  acc_cnt_ += 1;\n}\n\nvoid AccActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {\n  if (acc_cnt_ == max_acc_cnt_) {\n    HandleProducedNaiveDataRegstToConsumer();\n    acc_cnt_ = 0;\n  }\n}\n\nREGISTER_ACTOR(TaskType::kAcc, AccActor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nclass AccCtrlTickActor : public Actor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AccCtrlTickActor);\n  AccCtrlTickActor()\n      : acc_cnt_(0),\n        max_acc_num_(0),\n        last_micro_batch_input_output_mutex_(false),\n        consumed_tick_regst_desc_id_(-1),\n        produced_tick_regst_desc_id_(-1){};\n  virtual ~AccCtrlTickActor() = default;\n\n private:\n  // NOTE(chengcheng): Empty rs for naive and inplace regst, all regst is customized.\n  std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName()\n      override {\n    return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{});\n  }\n  std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedProducedRegstDescName()\n      override {\n    return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{});\n  }\n\n  bool IsCustomizedReadReady() const override {\n    return (!last_micro_batch_input_output_mutex_) && consumed_tick_rs_.IsCurSlotReady();\n  }\n  bool IsCustomizedWriteReady() const override { return produced_tick_rs_.IsCurSlotReady(); }\n\n  void NormalProcessCustomizedEordMsg(const ActorMsg&) override {}\n  bool IsCustomizedReadAlwaysUnReadyFromNow() const override {\n    // all Messages are flushed\n    return ReceiveEordMsg(consumed_tick_regst_desc_id_);\n  }\n\n  void VirtualActorInit(const TaskProto& proto) override;\n  void Act() override;\n  void AsyncSendCustomizedProducedRegstMsgToConsumer() override;\n  void AsyncSendCustomizedConsumedRegstMsgToProducer() override;\n  void UpdtStateAsCustomizedProducedRegst(Regst* regst) override;\n  void NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) override;\n\n  int32_t acc_cnt_;\n  int32_t max_acc_num_;\n  bool last_micro_batch_input_output_mutex_;\n  int64_t consumed_tick_regst_desc_id_;\n  int64_t produced_tick_regst_desc_id_;\n  RegstSlot consumed_tick_rs_;\n  RegstSlot produced_tick_rs_;\n};\n\nvoid AccCtrlTickActor::VirtualActorInit(const TaskProto& proto) {\n  acc_cnt_ = 0;\n  const OperatorConf& op_conf =\n      proto.exec_sequence().exec_node(0).kernel_conf().op_attribute().op_conf();\n  max_acc_num_ = user_op::UserOpConfWrapper(op_conf).attr<int32_t>(\"max_acc_num\");\n\n  // NOTE(chengcheng): check time shape equal max_acc_num\n  const Shape& in_time_shape = Singleton<RegstMgr>::Get()\n                                   ->RegstDesc4RegstDescId(Name2SoleRegstDescId(\"in\"))\n                                   .data_regst_time_shape();\n  const Shape& out_time_shape = Singleton<RegstMgr>::Get()\n                                    ->RegstDesc4RegstDescId(Name2SoleRegstDescId(\"out\"))\n                                    .data_regst_time_shape();\n  CHECK_EQ(in_time_shape.elem_cnt() % out_time_shape.elem_cnt(), 0);\n  CHECK_EQ(in_time_shape.elem_cnt() / 
out_time_shape.elem_cnt(), max_acc_num_);\n  CHECK_GT(max_acc_num_, 1);\n\n  // input\n  const auto& consumed_ids = proto.consumed_regst_desc_id();\n  CHECK_EQ(consumed_ids.size(), 1);\n  auto in_it = consumed_ids.find(\"in\");\n  CHECK(in_it != consumed_ids.end());\n  CHECK_EQ(in_it->second.regst_desc_id_size(), 1);\n  consumed_tick_regst_desc_id_ = in_it->second.regst_desc_id(0);\n  consumed_tick_rs_.InsertRegstDescId(consumed_tick_regst_desc_id_);\n  consumed_tick_rs_.InitedDone();\n\n  // output\n  CHECK_EQ(proto.produced_regst_desc().size(), 1);\n\n  const auto& produced_ids = proto.produced_regst_desc();\n  CHECK_EQ(produced_ids.size(), 1);\n  auto out_it = produced_ids.find(\"out\");\n  CHECK(out_it != produced_ids.end());\n  const RegstDescProto& out_regst_desc = out_it->second;\n  produced_tick_regst_desc_id_ = out_regst_desc.regst_desc_id();\n  produced_tick_rs_.InsertRegstDescId(produced_tick_regst_desc_id_);\n  produced_tick_rs_.InitedDone();\n\n  ForEachProducedRegst([&](Regst* regst) {\n    CHECK_EQ(regst->regst_desc_id(), produced_tick_regst_desc_id_);\n    CHECK_EQ(0, produced_tick_rs_.TryPushBackRegst(regst));\n  });\n\n  OF_SET_MSG_HANDLER(&AccCtrlTickActor::HandlerNormal);\n}\n\nvoid AccCtrlTickActor::Act() {\n  acc_cnt_ += 1;\n  if (acc_cnt_ == max_acc_num_) {\n    CHECK(!last_micro_batch_input_output_mutex_);\n    last_micro_batch_input_output_mutex_ = true;\n    acc_cnt_ = 0;\n  }\n}\n\nvoid AccCtrlTickActor::AsyncSendCustomizedProducedRegstMsgToConsumer() {\n  if (last_micro_batch_input_output_mutex_) {\n    CHECK(consumed_tick_rs_.IsCurSlotReady());  // inplace consume\n    CHECK(produced_tick_rs_.IsCurSlotReady());\n    Regst* const tick_regst = produced_tick_rs_.Front(produced_tick_regst_desc_id_);\n    CHECK_GT(HandleRegstToConsumer(tick_regst), 0);\n    produced_tick_rs_.PopFrontRegsts({produced_tick_regst_desc_id_});\n  }\n}\n\nvoid AccCtrlTickActor::AsyncSendCustomizedConsumedRegstMsgToProducer() {\n  if (!last_micro_batch_input_output_mutex_) {\n    Regst* const tick_regst = consumed_tick_rs_.Front(consumed_tick_regst_desc_id_);\n    CHECK_NOTNULL(tick_regst);\n    AsyncSendRegstMsgToProducer(tick_regst);\n    CHECK_EQ(0, consumed_tick_rs_.TryPopFrontRegst(consumed_tick_regst_desc_id_));\n  }\n}\n\nvoid AccCtrlTickActor::UpdtStateAsCustomizedProducedRegst(Regst* regst) {\n  CHECK(last_micro_batch_input_output_mutex_);\n  CHECK_EQ(regst->regst_desc_id(), produced_tick_regst_desc_id_);\n  CHECK_EQ(produced_tick_rs_.TryPushBackRegst(regst), 0);\n\n  Regst* in_regst = consumed_tick_rs_.Front(consumed_tick_regst_desc_id_);\n  CHECK(in_regst);\n  AsyncSendRegstMsgToProducer(in_regst);\n  CHECK_EQ(0, consumed_tick_rs_.TryPopFrontRegst(consumed_tick_regst_desc_id_));\n  last_micro_batch_input_output_mutex_ = false;\n}\n\nvoid AccCtrlTickActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) {\n  CHECK_EQ(0, consumed_tick_rs_.TryPushBackRegst(msg.regst()));\n}\n\nREGISTER_ACTOR(TaskType::kAccCtrlTick, AccCtrlTickActor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/acc_tick_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor.h\"\n\nnamespace oneflow {\n\nclass AccTickActor : public Actor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AccTickActor);\n  AccTickActor() = default;\n  virtual ~AccTickActor() = default;\n\n protected:\n  void VirtualActorInit(const TaskProto& proto) override;\n\n private:\n  void Act() override;\n  void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override;\n\n  int32_t acc_cnt_;\n  int32_t max_acc_cnt_;\n};\n\nvoid AccTickActor::VirtualActorInit(const TaskProto& proto) {\n  const Shape& in_time_shape = Singleton<RegstMgr>::Get()\n                                   ->RegstDesc4RegstDescId(Name2SoleRegstDescId(\"in\"))\n                                   .data_regst_time_shape();\n  const Shape& out_time_shape = Singleton<RegstMgr>::Get()\n                                    ->RegstDesc4RegstDescId(Name2SoleRegstDescId(\"out\"))\n                                    .data_regst_time_shape();\n  CHECK_EQ(in_time_shape.elem_cnt() % out_time_shape.elem_cnt(), 0);\n\n  acc_cnt_ = 0;\n  max_acc_cnt_ = in_time_shape.elem_cnt() / out_time_shape.elem_cnt();\n  OF_SET_MSG_HANDLER(&AccTickActor::HandlerNormal);\n}\n\nvoid AccTickActor::Act() { acc_cnt_ += 1; }\n\nvoid AccTickActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {\n  if (acc_cnt_ == max_acc_cnt_) {\n    HandleProducedNaiveDataRegstToConsumer();\n    acc_cnt_ = 0;\n  }\n}\n\nREGISTER_ACTOR(TaskType::kAccTick, AccTickActor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/runtime_job_descs.h\"\n#include \"oneflow/core/lazy/stream_context/include/stream_context.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass KernelContextImpl : public KernelContext, public ActorContextProvider {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(KernelContextImpl);\n  explicit KernelContextImpl(ActorContext* actor_ctx)\n      : actor_ctx_(actor_ctx),\n        stream_ctx_(actor_ctx->stream_ctx()),\n        state_(nullptr),\n        stream_kernel_observer_(nullptr) {\n    auto* kernel_observer_provider = dynamic_cast<KernelObserverProvider*>(stream_ctx_);\n    if (kernel_observer_provider != nullptr) {\n      stream_kernel_observer_ = kernel_observer_provider->GetKernelObserver();\n    }\n  }\n  ~KernelContextImpl() = default;\n\n  ep::Stream* stream() const override { return stream_ctx_->stream(); }\n\n  ActorContext* GetActorContext() const override { return actor_ctx_; }\n\n  Blob* BnInOp2Blob(const std::string& bn) const override { return bn_in_op2blob_fn_(bn); }\n\n  const std::shared_ptr<KernelState>& state() const override { return state_; }\n\n  void set_state(std::shared_ptr<KernelState> state) override { state_ = std::move(state); }\n\n  void WillForward(KernelContext* kernel_ctx, const Kernel* kernel) override;\n  void DidForward(KernelContext* kernel_ctx, const Kernel* kernel) override;\n\n  void WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override;\n  void DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override;\n\n  void WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override;\n  void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override;\n\n  void UpdateBnInOp2BlobFn(std::function<Blob*(const std::string&)> fn) {\n    bn_in_op2blob_fn_ = std::move(fn);\n  }\n\n private:\n  ActorContext* actor_ctx_;\n  StreamContext* stream_ctx_;\n  std::function<Blob*(const std::string&)> bn_in_op2blob_fn_;\n  std::shared_ptr<KernelState> state_;\n  KernelObserver* stream_kernel_observer_;\n};\n\nvoid KernelContextImpl::WillForward(KernelContext* kernel_ctx, const Kernel* kernel) {\n  Singleton<KernelObserver>::Get()->WillForward(kernel_ctx, kernel);\n  if (stream_kernel_observer_ != nullptr) {\n    stream_kernel_observer_->WillForward(kernel_ctx, kernel);\n  }\n}\n\nvoid KernelContextImpl::DidForward(KernelContext* kernel_ctx, const Kernel* kernel) {\n  CHECK_JUST_MSG(kernel_ctx->stream()->GetAsyncError(), kernel->op_conf().name());\n  Singleton<KernelObserver>::Get()->DidForward(kernel_ctx, kernel);\n  if (stream_kernel_observer_ != nullptr) {\n    stream_kernel_observer_->DidForward(kernel_ctx, kernel);\n  }\n}\n\nvoid KernelContextImpl::WillForwardHeader(KernelContext* kernel_ctx, const Kernel* 
kernel) {\n  Singleton<KernelObserver>::Get()->WillForwardHeader(kernel_ctx, kernel);\n  if (stream_kernel_observer_ != nullptr) {\n    stream_kernel_observer_->WillForwardHeader(kernel_ctx, kernel);\n  }\n}\n\nvoid KernelContextImpl::DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) {\n  Singleton<KernelObserver>::Get()->DidForwardHeader(kernel_ctx, kernel);\n  if (stream_kernel_observer_ != nullptr) {\n    stream_kernel_observer_->DidForwardHeader(kernel_ctx, kernel);\n  }\n}\n\nvoid KernelContextImpl::WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) {\n  Singleton<KernelObserver>::Get()->WillForwardDataContent(kernel_ctx, kernel);\n  if (stream_kernel_observer_ != nullptr) {\n    stream_kernel_observer_->WillForwardDataContent(kernel_ctx, kernel);\n  }\n}\n\nvoid KernelContextImpl::DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) {\n  Singleton<KernelObserver>::Get()->DidForwardDataContent(kernel_ctx, kernel);\n  if (stream_kernel_observer_ != nullptr) {\n    stream_kernel_observer_->DidForwardDataContent(kernel_ctx, kernel);\n  }\n}\n\nvoid CheckInplaceRegstDescId(const TaskProto& task_proto) {\n  HashSet<int64_t> consumed_regst_desc_ids;\n  for (const auto& pair : task_proto.consumed_regst_desc_id()) {\n    for (int64_t id : pair.second.regst_desc_id()) { consumed_regst_desc_ids.insert(id); }\n  }\n  for (const auto& pair : task_proto.produced_regst_desc()) {\n    if (pair.second.has_inplace_consumed_regst_desc_id() == false) { continue; }\n    int64_t in_regst_desc_id = pair.second.inplace_consumed_regst_desc_id();\n    CHECK(consumed_regst_desc_ids.find(in_regst_desc_id) != consumed_regst_desc_ids.end());\n  }\n}\n\n}  // namespace\n\nActor::~Actor() = default;\n\nvoid Actor::Init(const JobDesc* job_desc, ActorContext* actor_ctx) {\n  actor_ctx_ = actor_ctx;\n  const TaskProto& task_proto = actor_ctx->task_proto();\n  actor_id_ = task_proto.task_id();\n  thrd_id_ = ThrdId4ActorId(actor_id_);\n  job_id_ = task_proto.job_id();\n  act_cnt_ = 0;\n  op_name_ = \"NULL_OP\";\n  debug_ = EnableActorDebugLog();\n  for (const ExecNodeProto& node : task_proto.exec_sequence().exec_node()) {\n    ExecKernel ek;\n    ek.kernel_ctx.reset(new KernelContextImpl(actor_ctx));\n    ek.kernel = ConstructKernel(node.kernel_conf(), ek.kernel_ctx.get());\n    exec_kernel_vec_.emplace_back(std::move(ek));\n    op_name_ = node.kernel_conf().op_attribute().op_conf().name();\n  }\n\n  is_kernel_launch_synchronized_ =\n      std::all_of(exec_kernel_vec_.cbegin(), exec_kernel_vec_.cend(),\n                  [](const ExecKernel& ek) { return ek.kernel->IsKernelLaunchSynchronized(); });\n  if (!is_kernel_launch_synchronized_) { CHECK_EQ(exec_kernel_vec_.size(), 1); }\n\n  remaining_eord_cnt_ = 0;\n  msg_handler_ = nullptr;\n  eord_regst_desc_ids_.clear();\n\n  for (const auto& pair : task_proto.produced_regst_desc()) {\n    Singleton<RegstMgr>::Get()->NewRegsts(pair.second, [this](Regst* regst) {\n      produced_regsts_[regst->regst_desc_id()].emplace_back(regst);\n    });\n    int64_t regst_desc_id = pair.second.regst_desc_id();\n    CHECK(name2regst_desc_id_.insert({pair.first, {regst_desc_id}}).second);\n    if (pair.second.regst_desc_type().has_ctrl_regst_desc()) {\n      produced_ctrl_regst_desc_ids_.insert(regst_desc_id);\n    }\n  }\n  for (const auto& pair : produced_regsts_) {\n    for (const auto& regst : pair.second) { produced_regst2reading_cnt_[regst.get()] = 0; }\n  }\n\n  for (const auto& pair : task_proto.consumed_regst_desc_id()) {\n    
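// A consumed regst name must not collide with a produced one, and each consumed regst desc id\n    // will deliver exactly one EORD, hence remaining_eord_cnt_ grows by the id count below.\n    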
CHECK(name2regst_desc_id_.find(pair.first) == name2regst_desc_id_.end());\n    std::vector<int64_t>& regst_desc_id_vec = name2regst_desc_id_[pair.first];\n    for (int64_t regst_desc_id : pair.second.regst_desc_id()) {\n      regst_desc_id_vec.emplace_back(regst_desc_id);\n    }\n    remaining_eord_cnt_ += pair.second.regst_desc_id_size();\n    if (pair.first == \"in_ctrl\") {\n      consumed_ctrl_regst_desc_ids_.insert(regst_desc_id_vec.begin(), regst_desc_id_vec.end());\n    }\n  }\n\n  total_reading_cnt_ = 0;\n  is_inplace_consumed_eord_ = false;\n  CheckInplaceRegstDescId(task_proto);\n  TakeOverInplaceConsumedAndProduced(task_proto.produced_regst_desc());\n  is_naive_consumed_eord_ = false;\n  TakeOverNaiveConsumed(task_proto.consumed_regst_desc_id());\n  TakeOverNaiveProduced(task_proto.produced_regst_desc());\n  InitBnInOp2BlobInfo(task_proto);\n  VirtualActorInit(task_proto);\n}\n\nvoid Actor::TakeOverInplaceConsumedAndProduced(\n    const PbMap<std::string, RegstDescProto>& produced_ids) {\n  for (const auto& pair : produced_ids) {\n    int64_t out_regst_desc_id = pair.second.regst_desc_id();\n    if (pair.second.has_inplace_consumed_regst_desc_id() == false) { continue; }\n    int64_t in_regst_desc_id = pair.second.inplace_consumed_regst_desc_id();\n    inplace_regst_desc_id_in2out_.insert(std::make_pair(in_regst_desc_id, out_regst_desc_id));\n    inplace_regst_desc_id_out2in_.insert(std::make_pair(out_regst_desc_id, in_regst_desc_id));\n    inplace_consumed_rs_.InsertRegstDescId(in_regst_desc_id);\n    inplace_produced_rs_.InsertRegstDescId(out_regst_desc_id);\n  }\n  inplace_consumed_rs_.InitedDone();\n  inplace_produced_rs_.InitedDone();\n  for (const auto& pair : produced_regsts_) {\n    if (inplace_produced_rs_.HasRegstDescId(pair.first)) {\n      for (const auto& regst : pair.second) {\n        CHECK_EQ(0, inplace_produced_rs_.TryPushBackRegst(regst.get()));\n        if (regst->consumers_actor_id().size() == 0) {\n          CHECK(inplace_in_ids_with_no_out_consumed_\n                    .emplace(inplace_regst_desc_id_out2in_.at(pair.first))\n                    .second);\n        }\n      }\n    }\n  }\n}\n\nvoid Actor::TakeOverNaiveConsumed(const PbMap<std::string, RegstDescIdSet>& consumed_ids) {\n  auto res = GetNaiveOrCustomizedConsumedRegstDescName();\n  bool is_naive_names = res.first == RegstNameType::kNaive;\n  const HashSet<std::string>& names = res.second;\n\n  for (const auto& pair : consumed_ids) {\n    bool find_the_name = names.find(pair.first) != names.end();\n    if (is_naive_names == find_the_name || pair.first == \"in_ctrl\") {\n      for (int64_t regst_desc_id : pair.second.regst_desc_id()) {\n        if (inplace_consumed_rs_.HasRegstDescId(regst_desc_id)) { continue; }\n        naive_consumed_rs_.InsertRegstDescId(regst_desc_id);\n      }\n    }\n  }\n  naive_consumed_rs_.InitedDone();\n}\n\nvoid Actor::TakeOverNaiveProduced(const PbMap<std::string, RegstDescProto>& produced_ids) {\n  auto res = GetNaiveOrCustomizedProducedRegstDescName();\n  bool is_naive_names = res.first == RegstNameType::kNaive;\n  const HashSet<std::string>& names = res.second;\n\n  for (const auto& pair : produced_ids) {\n    bool find_the_name = names.find(pair.first) != names.end();\n    if (inplace_produced_rs_.HasRegstDescId(pair.second.regst_desc_id())) { continue; }\n    if (is_naive_names == find_the_name || pair.first.substr(0, 9) == \"out_ctrl_\") {\n      naive_produced_rs_.InsertRegstDescId(pair.second.regst_desc_id());\n    }\n  }\n  naive_produced_rs_.InitedDone();\n\n 
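 // Pre-populate the naive produced slot: freshly constructed regsts start out writeable.\n 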
 for (const auto& pair : produced_regsts_) {\n    if (naive_produced_rs_.HasRegstDescId(pair.first) == false) { continue; }\n    for (const auto& regst : pair.second) {\n      CHECK_EQ(0, naive_produced_rs_.TryPushBackRegst(regst.get()));\n    }\n  }\n}\n\nvoid Actor::InitBnInOp2BlobInfo(const TaskProto& task_proto) {\n  for (int64_t i = 0; i < exec_kernel_vec_.size(); ++i) {\n    ExecKernel& ek = exec_kernel_vec_.at(i);\n    const ExecNodeProto& node = task_proto.exec_sequence().exec_node(i);\n    for (auto& pair : node.kernel_conf().op_attribute().arg_signature().bn_in_op2lbi()) {\n      BlobInfo blob_info;\n      blob_info.lbi = pair.second;\n      const std::string& bn = pair.first;\n      auto regst_desc_id_it = node.bn_in_op2regst_desc_id().find(bn);\n      if (regst_desc_id_it != node.bn_in_op2regst_desc_id().end()\n          && Singleton<RegstMgr>::Get()->HasRegstDescId(regst_desc_id_it->second)) {\n        const int64_t regst_desc_id = regst_desc_id_it->second;\n        blob_info.regst_desc_id = regst_desc_id;\n        const RtRegstDesc& regst_desc =\n            Singleton<RegstMgr>::Get()->RegstDesc4RegstDescId(regst_desc_id);\n        blob_info.ordinal = regst_desc.GetOrdinalForLbi(blob_info.lbi);\n        if (naive_produced_rs_.HasRegstDescId(regst_desc_id)) {\n          blob_info.rs = &naive_produced_rs_;\n        } else if (inplace_produced_rs_.HasRegstDescId(regst_desc_id)) {\n          blob_info.rs = &inplace_produced_rs_;\n        } else if (naive_consumed_rs_.HasRegstDescId(regst_desc_id)) {\n          blob_info.rs = &naive_consumed_rs_;\n        } else if (inplace_consumed_rs_.HasRegstDescId(regst_desc_id)) {\n          blob_info.rs = &inplace_consumed_rs_;\n        } else {\n          blob_info.rs = nullptr;\n        }\n      } else {\n        blob_info.regst_desc_id = -1;\n        blob_info.ordinal = -1;\n        blob_info.rs = nullptr;\n      }\n      ek.bn_in_op2blob_info.emplace(bn, std::move(blob_info));\n    }\n  }\n}\n\nvoid Actor::ForEachProducedRegst(const std::function<void(Regst*)>& Handler) const {\n  for (const auto& pair : produced_regsts_) {\n    for (const auto& regst : pair.second) { Handler(regst.get()); }\n  }\n}\n\nint64_t Actor::Name2SoleRegstDescId(const std::string& name) const {\n  auto find_it = name2regst_desc_id_.find(name);\n  if (find_it != name2regst_desc_id_.end()) {\n    CHECK_EQ(find_it->second.size(), 1);\n    return find_it->second.front();\n  }\n  return -1;\n}\n\nconst std::vector<int64_t>& Actor::Name2RegstDescIds(const std::string& name) const {\n  return name2regst_desc_id_.at(name);\n}\n\nint64_t Actor::ReadingCnt4ProducedRegst(Regst* regst) const {\n  return produced_regst2reading_cnt_.at(regst);\n}\n\nvoid Actor::IncreaseReadingCnt4ProducedRegst(Regst* regst, int64_t val) {\n  produced_regst2reading_cnt_.at(regst) += val;\n}\n\nvoid Actor::ForEachCurNaiveReadableDataRegst(const std::function<void(const Regst*)>& func) const {\n  naive_consumed_rs_.ForEachFrontRegst([func](int64_t regst_desc_id, Regst* regst) {\n    if (Singleton<RegstMgr>::Get()->HasProducerTaskId4RegstDescId(regst_desc_id)) { return; }\n    if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) { func(regst); }\n  });\n}\n\nbool Actor::ReceiveEordMsg(int64_t regst_desc_id) const {\n  return eord_regst_desc_ids_.find(regst_desc_id) != eord_regst_desc_ids_.end();\n}\n\nint Actor::HandlerNormal(const ActorMsg& msg) {\n  if (msg.msg_type() == ActorMsgType::kEordMsg) {\n    remaining_eord_cnt_ -= 1;\n    
CHECK(eord_regst_desc_ids_.insert(msg.eord_regst_desc_id()).second);\n    if (naive_consumed_rs_.HasRegstDescId(msg.eord_regst_desc_id())) {\n      is_naive_consumed_eord_ = true;\n    } else if (inplace_consumed_rs_.HasRegstDescId(msg.eord_regst_desc_id())) {\n      is_inplace_consumed_eord_ = true;\n    } else {\n      NormalProcessCustomizedEordMsg(msg);\n    }\n  } else if (msg.msg_type() == ActorMsgType::kRegstMsg) {\n    if (msg.SrcMachineId() == GlobalProcessCtx::Rank()) {\n      Regst* regst = msg.regst();\n      if (naive_consumed_rs_.HasRegstDescId(regst->regst_desc_id())) {\n        CHECK_EQ(0, naive_consumed_rs_.TryPushBackRegst(regst));\n        const auto& rdeq = naive_consumed_rs_.RegstDeq4RegstDescId(regst->regst_desc_id());\n        CHECK(rdeq.empty() == false);\n        if (rdeq.front()->regst_desc()->regst_desc_type().has_data_regst_desc()) {\n          NormalProcessNaiveReadableDataRegstMsg(rdeq);\n        }\n      } else if (inplace_consumed_rs_.HasRegstDescId(regst->regst_desc_id())) {\n        CHECK_EQ(0, inplace_consumed_rs_.TryPushBackRegst(regst));\n      } else if (TryUpdtStateAsProducedRegst(regst) == 0) {\n        // do nothing\n      } else {\n        NormalProcessCustomizedReadableRegstMsg(msg);\n      }\n    } else {\n      if (NormalTryProcessReadableMsgFromOtherMachine(msg) == false) {\n        // process ctrl msg from other rank\n        if (IsConsumedCtrlRegstDescId(msg.regst_desc_id())) {\n          Regst* regst = msg.regst();\n          CHECK(naive_consumed_rs_.HasRegstDescId(msg.regst_desc_id()));\n          CHECK(Singleton<RegstMgr>::Get()->HasProducerTaskId4RegstDescId(msg.regst_desc_id()));\n          CHECK_EQ(0, naive_consumed_rs_.TryPushBackRegst(regst, msg.regst_desc_id()));\n          const auto& rdeq = naive_consumed_rs_.RegstDeq4RegstDescId(msg.regst_desc_id());\n          CHECK(rdeq.empty() == false);\n        } else {\n          CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst()), 0);\n        }\n      }\n    }\n\n    if (debug_) {\n      LOG(INFO) << \" Actor: \" << actor_id_ << \" op: \" << op_name_ << \" in act_cnt: [ \" << act_cnt_\n                << \" ] , Recv ActorMsg from: \" << msg.src_actor_id()\n                << \" to: \" << msg.dst_actor_id() << \" with regst: \" << msg.regst_desc_id();\n    }\n    ActUntilFail();\n  } else if (msg.msg_type() == ActorMsgType::kCmdMsg) {\n    CHECK_EQ(msg.actor_cmd(), ActorCmd::kStart);\n    ActUntilFail();\n  } else {\n    UNIMPLEMENTED();\n  }\n  // handler halts\n  bool has_naive_or_inplace = naive_consumed_rs_.total_regst_desc_cnt() != 0\n                              || inplace_consumed_rs_.total_regst_desc_cnt() != 0;\n  bool naive_or_inplace_eord_and_empty =\n      (is_naive_consumed_eord_ || is_inplace_consumed_eord_)\n      && (naive_consumed_rs_.available_regst_desc_cnt() == 0\n          && inplace_consumed_rs_.available_regst_desc_cnt() == 0);\n  bool customized_eord = IsCustomizedReadAlwaysUnReadyFromNow();\n  if ((has_naive_or_inplace && naive_or_inplace_eord_and_empty)\n      || (!has_naive_or_inplace && customized_eord)) {\n    CHECK_EQ(naive_consumed_rs_.available_regst_desc_cnt(), 0);\n    AsyncReturnAllCustomizedReadableRegst();\n    AsyncSendEORDMsgForAllProducedRegstDesc();\n    if (remaining_eord_cnt_ == 0 && total_reading_cnt_ == 0) {\n      OF_SET_MSG_HANDLER(nullptr);\n      return 1;\n    } else {\n      OF_SET_MSG_HANDLER(&Actor::HandlerZombie);\n      return 0;\n    }\n  }\n  return 0;\n}\n\nint Actor::HandlerZombie(const ActorMsg& msg) {\n  if (msg.msg_type() == 
ActorMsgType::kEordMsg) {\n    CHECK_GE(remaining_eord_cnt_, 1);\n    remaining_eord_cnt_ -= 1;\n  } else if (msg.msg_type() == ActorMsgType::kRegstMsg) {\n    if (TryUpdtStateAsProducedRegst(msg.regst()) != 0) { AsyncSendRegstMsgToProducer(msg.regst()); }\n  } else {\n    UNIMPLEMENTED();\n  }\n  if (remaining_eord_cnt_ == 0 && total_reading_cnt_ == 0) {\n    msg_handler_ = nullptr;\n    return 1;\n  }\n  return 0;\n}\n\nvoid Actor::ActUntilFail() {\n  if (debug_) {\n    // NOTE(chengcheng): using if(debug_) code hack to minimize debug code cost when debug off.\n    LOG(INFO) << \" Actor: \" << actor_id_ << \" op: \" << op_name_ << \" Try to act before act_cnt: [ \"\n              << act_cnt_ << \" ] . And IsReadReady: \" << IsReadReady()\n              << \" IsWriteReady: \" << IsWriteReady();\n  }\n  while (IsReadReady() && IsWriteReady()) {\n    PrepareProducedNaiveInplaceDataRegst();\n\n    if (debug_) {\n      LOG(INFO) << \" Actor: \" << actor_id_ << \" op: \" << op_name_ << \" Try to act act_cnt: [ \"\n                << act_cnt_ << \" ] before launch kernel.\";\n    }\n\n    Act();\n\n    AsyncSendCustomizedProducedRegstMsgToConsumer();\n    AsyncSendNaiveProducedRegstMsgToConsumer();\n    AsyncSendInplaceProducedRegstMsgToConsumer();\n\n    AsyncSendCustomizedConsumedRegstMsgToProducer();\n    AsyncSendNaiveConsumedRegstMsgToProducer();\n    AsyncRetInplaceConsumedRegstIfNoConsumer();\n\n    AsyncSendQueuedMsg();\n\n    if (debug_) {\n      LOG(INFO) << \" Actor: \" << actor_id_ << \" op: \" << op_name_ << \" Finish act act_cnt: [ \"\n                << act_cnt_++ << \" ].\";\n    }\n  }\n  // NOTE(liujuncheng): return inplace consumed\n  AsyncSendQueuedMsg();\n}\n\nvoid Actor::AsyncSendNaiveProducedRegstMsgToConsumer() {\n  VirtualAsyncSendNaiveProducedRegstMsgToConsumer();\n  AsyncSendProducedCtrlRegstMsgToConsumer();\n}\n\nvoid Actor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {\n  HandleProducedNaiveDataRegstToConsumer();\n}\n\nvoid Actor::AsyncSendInplaceProducedRegstMsgToConsumer() {\n  VirtualAsyncSendInplaceProducedRegstMsgToConsumer();\n}\n\nvoid Actor::AsyncRetInplaceConsumedRegstIfNoConsumer() {\n  tmp_regst_desc_id_vec_.clear();\n  inplace_consumed_rs_.ForChosenRegstDeq(\n      [&](int64_t regst_desc_id) {\n        return inplace_in_ids_with_no_out_consumed_.find(regst_desc_id)\n               != inplace_in_ids_with_no_out_consumed_.end();\n      },\n      [&](const std::deque<Regst*>& deq) {\n        if (!deq.empty()) {\n          Regst* in_regst = deq.front();\n          CHECK(in_regst);\n          AsyncSendRegstMsgToProducer(in_regst);\n          tmp_regst_desc_id_vec_.emplace_back(in_regst->regst_desc_id());\n        }\n      });\n  inplace_consumed_rs_.PopFrontRegsts(tmp_regst_desc_id_vec_);\n}\n\nvoid Actor::VirtualAsyncSendInplaceProducedRegstMsgToConsumer() {\n  HandleProducedInplaceDataRegstToConsumer();\n}\n\nvoid Actor::AsyncSendNaiveConsumedRegstMsgToProducer() {\n  VirtualAsyncSendNaiveConsumedRegstMsgToProducer();\n  AsyncSendConsumedCtrlRegstMsgToProducer();\n}\n\nvoid Actor::VirtualAsyncSendNaiveConsumedRegstMsgToProducer() {\n  HandleConsumedNaiveDataRegstToProducer();\n}\n\nvoid Actor::AsyncSendConsumedCtrlRegstMsgToProducer() {\n  auto IsChosenRegstDescId = [this](int64_t regst_desc_id) {\n    return IsConsumedCtrlRegstDescId(regst_desc_id) && ConsumedCtrlRegstValid(regst_desc_id);\n  };\n\n  tmp_regst_desc_id_vec_.clear();\n  naive_consumed_rs_.ForChosenRegstDeq(IsChosenRegstDescId, [&](int64_t regst_desc_id,\n                                
                                const std::deque<Regst*>& reg_deq) {\n    CHECK(reg_deq.empty() == false);\n    auto producer_task_id = Singleton<RegstMgr>::Get()->ProducerTaskId4RegstDescId(regst_desc_id);\n    Regst* regst = reg_deq.front();\n    CHECK_GE(reg_deq.size(), 1);\n    // must access regst before sending it to producer\n    tmp_regst_desc_id_vec_.emplace_back(regst_desc_id);\n    EnqueueAsyncMsg(ActorMsg::BuildRegstMsgToProducer(actor_id_, producer_task_id, regst));\n  });\n  naive_consumed_rs_.PopFrontRegsts(tmp_regst_desc_id_vec_);\n}\n\nvoid Actor::AsyncSendProducedCtrlRegstMsgToConsumer() {\n  auto IsChosenRegstDescId = [this](int64_t regst_desc_id) {\n    return IsProducedCtrlRegstDescId(regst_desc_id) && ProducedCtrlRegstValid(regst_desc_id);\n  };\n\n  tmp_regst_desc_id_vec_.clear();\n  naive_produced_rs_.ForChosenFrontRegst(IsChosenRegstDescId, [&](Regst* regst) {\n    CHECK(regst->regst_desc()->regst_desc_type().has_ctrl_regst_desc());\n    int64_t real_consumer_cnt = HandleRegstToConsumer(regst);\n    if (real_consumer_cnt > 0) { tmp_regst_desc_id_vec_.emplace_back(regst->regst_desc_id()); }\n  });\n  naive_produced_rs_.PopFrontRegsts(tmp_regst_desc_id_vec_);\n}\n\nint64_t Actor::HandleRegstToConsumer(Regst* regst) {\n  auto regst_reading_cnt_it = produced_regst2reading_cnt_.find(regst);\n  CHECK_EQ(regst_reading_cnt_it->second, 0);\n\n  int64_t real_consumer_cnt = 0;\n  ActorMsg tpl_msg = ActorMsg::BuildRegstMsgToConsumer(actor_id_, 0, regst);\n  for (int64_t consumer : regst->consumers_actor_id()) {\n    tpl_msg.set_dst_actor_id(consumer);\n    EnqueueAsyncMsg(tpl_msg);\n    real_consumer_cnt += 1;\n  }\n  total_reading_cnt_ += real_consumer_cnt;\n  regst_reading_cnt_it->second += real_consumer_cnt;\n  return real_consumer_cnt;\n}\n\nbool Actor::IsReadReady() const {\n  return naive_consumed_rs_.IsCurSlotReady() && inplace_consumed_rs_.IsCurSlotReady()\n         && IsCustomizedReadReady();\n}\n\nbool Actor::IsWriteReady() const {\n  return naive_produced_rs_.IsCurSlotReady() && inplace_produced_rs_.IsCurSlotReady()\n         && IsCustomizedWriteReady();\n}\n\nvoid Actor::AsyncLaunchKernel(std::function<Regst*(int64_t)> Regst4RegstDescId) {\n  for (const ExecKernel& ek : exec_kernel_vec_) {\n    CHECK_NOTNULL(dynamic_cast<KernelContextImpl*>(ek.kernel_ctx.get()))\n        ->UpdateBnInOp2BlobFn([&](const std::string& bn_in_op) -> Blob* {\n          const auto blob_info_it = ek.bn_in_op2blob_info.find(bn_in_op);\n          if (blob_info_it == ek.bn_in_op2blob_info.cend()) { return nullptr; }\n          const BlobInfo& info = blob_info_it->second;\n          if (info.regst_desc_id == -1) { return nullptr; }\n          Regst* regst = nullptr;\n          if (info.rs != nullptr) {\n            regst = info.rs->Front(info.regst_desc_id);\n          } else {\n            regst = Regst4RegstDescId(info.regst_desc_id);\n          }\n          if (regst == nullptr) { return nullptr; }\n          if (info.ordinal >= 0) {\n            return regst->GetBlobByOrdinal(info.ordinal);\n          } else {\n            return regst->GetBlobByLbi(info.lbi);\n          }\n        });\n    ek.kernel->Launch(ek.kernel_ctx.get());\n  }\n}\n\nvoid Actor::AsyncLaunchKernel() {\n  AsyncLaunchKernel([](int64_t) -> Regst* {\n    UNIMPLEMENTED();\n    return nullptr;\n  });\n}\n\nvoid Actor::PrepareProducedNaiveInplaceDataRegst() {\n  naive_produced_rs_.ForEachFrontRegst([&](Regst* regst) {\n    if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) {\n      if 
(regst->allocation_type() == RegstAllocationType::kStreamOrdered) {\n        CHECK(regst->body_mem_ptr() == nullptr);\n        void* body_ptr = nullptr;\n        CHECK_JUST(actor_ctx_->stream_ctx()->stream()->AllocAsync(\n            &body_ptr, regst->regst_desc()->BodyByteSize4OneRegst()));\n        regst->ResetBodyMemPtr(body_ptr);\n      } else if (regst->allocation_type() == RegstAllocationType::kStatic) {\n        // do nothing\n      } else {\n        UNIMPLEMENTED();\n      }\n    }\n  });\n\n  inplace_produced_rs_.ForEachFrontRegst([&](Regst* regst) {\n    CHECK(regst->regst_desc()->regst_desc_type().has_data_regst_desc());\n    const int64_t in_regst_desc_id = inplace_regst_desc_id_out2in_.at(regst->regst_desc_id());\n    Regst* in_regst = inplace_consumed_rs_.Front(in_regst_desc_id);\n    CHECK(in_regst != nullptr);\n    if (regst->allocation_type() == RegstAllocationType::kStreamOrdered) {\n      CHECK(regst->body_mem_ptr() == nullptr);\n      regst->ResetBodyMemPtr(in_regst->body_mem_ptr());\n    } else if (regst->allocation_type() == RegstAllocationType::kStatic) {\n      // do nothing\n    } else {\n      UNIMPLEMENTED();\n    }\n  });\n}\n\nvoid Actor::HandleProducedNaiveDataRegstToConsumer() {\n  tmp_regst_desc_id_vec_.clear();\n  naive_produced_rs_.ForEachFrontRegst([&](Regst* regst) {\n    if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) {\n      int64_t real_consumer_cnt = HandleRegstToConsumer(regst);\n      if (real_consumer_cnt > 0) {\n        tmp_regst_desc_id_vec_.emplace_back(regst->regst_desc_id());\n      } else {\n        if (regst->allocation_type() == RegstAllocationType::kStreamOrdered) {\n          CHECK_JUST(actor_ctx_->stream_ctx()->stream()->FreeAsync(regst->body_mem_ptr()));\n          regst->ResetBodyMemPtr(nullptr);\n        } else if (regst->allocation_type() == RegstAllocationType::kStatic) {\n          // do nothing\n        } else {\n          UNIMPLEMENTED();\n        }\n      }\n    }\n  });\n  naive_produced_rs_.PopFrontRegsts(tmp_regst_desc_id_vec_);\n}\n\nvoid Actor::HandleProducedInplaceDataRegstToConsumer() {\n  tmp_regst_desc_id_vec_.clear();\n  inplace_produced_rs_.ForEachFrontRegst([&](Regst* regst) {\n    CHECK(regst->regst_desc()->regst_desc_type().has_data_regst_desc());\n    int64_t real_consumer_cnt = HandleRegstToConsumer(regst);\n    if (real_consumer_cnt > 0) {\n      tmp_regst_desc_id_vec_.emplace_back(regst->regst_desc_id());\n    } else {\n      if (regst->allocation_type() == RegstAllocationType::kStreamOrdered) {\n        regst->ResetBodyMemPtr(nullptr);\n      } else if (regst->allocation_type() == RegstAllocationType::kStatic) {\n        // do nothing\n      } else {\n        UNIMPLEMENTED();\n      }\n    }\n  });\n  inplace_produced_rs_.PopFrontRegsts(tmp_regst_desc_id_vec_);\n}\n\nvoid Actor::HandleConsumedNaiveDataRegstToProducer() {\n  tmp_regst_desc_id_vec_.clear();\n  naive_consumed_rs_.ForEachFrontRegst([&](int64_t regst_desc_id, Regst* regst) {\n    if (IsConsumedCtrlRegstDescId(regst_desc_id)) { return; }\n    if (regst->regst_desc()->regst_desc_type().has_data_regst_desc()) {\n      // must access regst before sending it to producer\n      tmp_regst_desc_id_vec_.emplace_back(regst->regst_desc_id());\n      EnqueueAsyncMsg(\n          ActorMsg::BuildRegstMsgToProducer(actor_id_, regst->producer_actor_id(), regst));\n    }\n  });\n  naive_consumed_rs_.PopFrontRegsts(tmp_regst_desc_id_vec_);\n}\n\nvoid Actor::AsyncSendEORDMsgForAllProducedRegstDesc() {\n  for (auto& pair : produced_regsts_) {\n    
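// Tell every consumer of each produced regst desc that no further regsts will be produced.\n    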
CHECK(!pair.second.empty());\n    const RtRegstDesc* regst_desc = pair.second.front()->regst_desc();\n    AddCallback([regst_desc]() {\n      for (int64_t consumer : regst_desc->consumers_actor_id()) {\n        Singleton<ActorMsgBus>::Get()->SendMsg(\n            ActorMsg::BuildEordMsg(consumer, regst_desc->regst_desc_id()));\n      }\n    });\n  }\n}\n\nvoid Actor::AsyncSendRegstMsgToProducer(Regst* regst) {\n  AsyncSendRegstMsgToProducer(regst, regst->producer_actor_id());\n}\n\nvoid Actor::AsyncSendRegstMsgToProducer(Regst* regst, int64_t producer) {\n  // must access regst before sending it to producer\n  int64_t regst_desc_id = regst->regst_desc_id();\n  EnqueueAsyncMsg(ActorMsg::BuildRegstMsgToProducer(actor_id_, producer, regst));\n  naive_consumed_rs_.TryPopFrontRegst(regst_desc_id);\n}\n\nRegst* Actor::GetSoleProducedRegst4RegstDescId(int64_t regst_desc_id) const {\n  auto it = produced_regsts_.find(regst_desc_id);\n  CHECK(it != produced_regsts_.end());\n  CHECK_EQ(it->second.size(), 1);\n  return it->second.front().get();\n}\n\nint Actor::TryUpdtStateAsProducedRegst(Regst* regst) {\n  auto reading_cnt_it = produced_regst2reading_cnt_.find(regst);\n  if (reading_cnt_it == produced_regst2reading_cnt_.end()) { return -1; }\n  CHECK(produced_regsts_.find(regst->regst_desc_id()) != produced_regsts_.end());\n  CHECK_GE(reading_cnt_it->second, 1);\n  reading_cnt_it->second -= 1;\n  total_reading_cnt_ -= 1;\n\n  if (debug_) {\n    LOG(INFO) << \" Actor: \" << actor_id_ << \" op: \" << op_name_ << \" in act_cnt: [ \" << act_cnt_\n              << \" ] recv produce_regst: \" << regst->regst_desc_id()\n              << \" and the total_reading_cnt_ is : \" << total_reading_cnt_ << \" now.\";\n  }\n\n  if (reading_cnt_it->second != 0) { return 0; }\n\n  if (inplace_produced_rs_.TryPushBackRegst(regst) == 0) {\n    int64_t in_regst_desc_id = inplace_regst_desc_id_out2in_.at(regst->regst_desc_id());\n    Regst* in_regst = inplace_consumed_rs_.Front(in_regst_desc_id);\n    CHECK(in_regst);\n    if (regst->allocation_type() == RegstAllocationType::kStreamOrdered) {\n      regst->ResetBodyMemPtr(nullptr);\n    } else if (regst->allocation_type() == RegstAllocationType::kStatic) {\n      // do nothing\n    } else {\n      UNIMPLEMENTED();\n    }\n    AsyncSendRegstMsgToProducer(in_regst);\n    CHECK_EQ(0, inplace_consumed_rs_.TryPopFrontRegst(in_regst_desc_id));\n  } else if (naive_produced_rs_.TryPushBackRegst(regst) == 0) {\n    if (regst->allocation_type() == RegstAllocationType::kStreamOrdered) {\n      CHECK_JUST(actor_ctx_->stream_ctx()->stream()->FreeAsync(regst->body_mem_ptr()));\n      regst->ResetBodyMemPtr(nullptr);\n    } else if (regst->allocation_type() == RegstAllocationType::kStatic) {\n      // do nothing\n    } else {\n      UNIMPLEMENTED();\n    }\n  } else {\n    UpdtStateAsCustomizedProducedRegst(regst);\n  }\n  return 0;\n}\n\nvoid Actor::EnqueueAsyncMsg(const ActorMsg& msg) {\n  if (is_kernel_launch_synchronized_ && thrd_id_ == ThrdId4ActorId(msg.dst_actor_id())) {\n    sync_msg_queue_.emplace_back(msg);\n  } else {\n    async_msg_queue_.emplace_back(msg);\n  }\n\n  if (debug_ && msg.msg_type() == ActorMsgType::kRegstMsg) {\n    LOG(INFO) << \" Actor: \" << actor_id_ << \" op: \" << op_name_ << \" in act_cnt: [ \" << act_cnt_\n              << \" ] post ActorMsg from: \" << msg.src_actor_id() << \" to: \" << msg.dst_actor_id()\n              << \" with regst: \" << msg.regst_desc_id();\n  }\n}\n\nRegst* Actor::GetNaiveOrInplaceCurReadable(int64_t regst_desc_id) const {\n  
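// Prefer the naive consumed slot; fall back to the inplace consumed slot when it is empty.\n  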
Regst* regst = naive_consumed_rs_.Front(regst_desc_id);\n  if (regst == nullptr) { regst = inplace_consumed_rs_.Front(regst_desc_id); }\n  return regst;\n}\n\nRegst* Actor::GetNaiveOrInplaceCurWriteable(int64_t regst_desc_id) const {\n  Regst* regst = naive_produced_rs_.Front(regst_desc_id);\n  if (regst == nullptr) { regst = inplace_produced_rs_.Front(regst_desc_id); }\n  return regst;\n}\n\nRegst* Actor::GetNaiveCurReadable(int64_t regst_desc_id) const {\n  return naive_consumed_rs_.Front(regst_desc_id);\n}\n\nRegst* Actor::GetNaiveCurWriteable(int64_t regst_desc_id) const {\n  return naive_produced_rs_.Front(regst_desc_id);\n}\n\nvoid Actor::AsyncSendQueuedMsg() {\n  if (!sync_msg_queue_.empty()) {\n    Singleton<ActorMsgBus>::Get()->SendMsgsWithoutCommNet(sync_msg_queue_.data(),\n                                                          sync_msg_queue_.size(), thrd_id_);\n    sync_msg_queue_.clear();\n  }\n  if (!async_msg_queue_.empty()) {\n    std::deque<ActorMsg> msgs;\n    msgs.swap(async_msg_queue_);\n    AddCallback([msgs]() {\n      for (const ActorMsg& msg : msgs) { Singleton<ActorMsgBus>::Get()->SendMsg(msg); }\n    });\n  }\n}\n\nvoid Actor::AddCallback(std::function<void()> callback) {\n  actor_ctx_->AddCallback(std::move(callback));\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/actor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_LAZY_ACTOR_ACTOR_H_\n#define ONEFLOW_CORE_LAZY_ACTOR_ACTOR_H_\n\n#include \"oneflow/core/lazy/actor/actor_base.h\"\n#include \"oneflow/core/lazy/actor/actor_message_bus.h\"\n#include \"oneflow/core/job/task.pb.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/kernel/kernel_context.h\"\n#include \"oneflow/core/register/register_manager.h\"\n#include \"oneflow/core/lazy/actor/register_slot.h\"\n\nnamespace oneflow {\n\nclass Actor : public ActorBase {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Actor);\n  virtual ~Actor();\n\n  void Init(const JobDesc* job_desc, ActorContext* actor_ctx) override;\n\n  // 1: success, and actor finish\n  // 0: success, and actor not finish\n  int ProcessMsg(const ActorMsg& msg) override { return (this->*msg_handler_)(msg); }\n\n  int64_t machine_id() const { return MachineId4ActorId(actor_id_); }\n  int64_t actor_id() const { return actor_id_; }\n  int64_t job_id() const { return job_id_; }\n\n protected:\n  struct BlobInfo {\n    LogicalBlobId lbi;\n    int64_t regst_desc_id;\n    int64_t ordinal;\n    RegstSlot* rs;\n  };\n  struct ExecKernel {\n    std::unique_ptr<const Kernel> kernel;\n    HashMap<std::string, BlobInfo> bn_in_op2blob_info;\n    std::unique_ptr<KernelContext> kernel_ctx;\n  };\n  using MsgHandler = int (Actor::*)(const ActorMsg&);\n  enum class RegstNameType { kNaive = 0, kCustomized };\n\n  // Util\n  Actor() = default;\n  bool ReceiveAllEordMsg() const { return remaining_eord_cnt_ == 0; }\n  bool ReceiveEordMsg(int64_t regst_desc_id) const;\n  virtual void VirtualActorInit(const TaskProto&) {}\n  int64_t Name2SoleRegstDescId(const std::string& name) const;\n  const std::vector<int64_t>& Name2RegstDescIds(const std::string& name) const;\n  ActorContext* actor_ctx() const { return actor_ctx_; }\n  const std::vector<ExecKernel>& exec_kernel_vec() { return exec_kernel_vec_; }\n  void ForEachCurNaiveReadableDataRegst(const std::function<void(const Regst*)>&) const;\n\n  int64_t ReadingCnt4ProducedRegst(Regst* regst) const;\n  void IncreaseReadingCnt4ProducedRegst(Regst* regst, int64_t val);\n  void IncreaseTotalReadingCnt(int64_t val) { total_reading_cnt_ += val; }\n\n  // Msg Handler\n  void set_msg_handler(MsgHandler val) { msg_handler_ = val; }\n#define OF_SET_MSG_HANDLER(val)                                 \\\n  do {                                                          \\\n    VLOG(3) << \"actor \" << actor_id() << \" switch to \" << #val; \\\n    set_msg_handler(static_cast<MsgHandler>(val));              \\\n  } while (0)\n\n  // Common Handlers and related virtual method\n  int HandlerNormal(const ActorMsg& msg);\n  int HandlerZombie(const ActorMsg& msg);\n\n  virtual bool ConsumedCtrlRegstValid(int64_t regst_desc_id) const { return true; }\n  virtual bool ProducedCtrlRegstValid(int64_t regst_desc_id) const { return true; }\n\n  void AsyncLaunchKernel(std::function<Regst*(int64_t)> 
Regst4RegstDescId);\n  void AsyncLaunchKernel();\n\n  // Util For Derived Actor to Send Msg\n  void EnqueueAsyncMsg(const ActorMsg&);\n  void HandleProducedNaiveDataRegstToConsumer();\n  void PrepareProducedNaiveInplaceDataRegst();\n  void HandleProducedInplaceDataRegstToConsumer();\n\n  void HandleConsumedNaiveDataRegstToProducer();\n  void AsyncSendRegstMsgToProducer(Regst*);\n  void AsyncSendRegstMsgToProducer(Regst*, int64_t producer);\n  void AsyncSendEORDMsgForAllProducedRegstDesc();\n  void AsyncSendQueuedMsg();\n\n  // Get Regst\n  Regst* GetNaiveCurReadable(int64_t regst_desc_id) const;\n  Regst* GetNaiveCurReadable(const std::string& name) const {\n    return GetNaiveCurReadable(Name2SoleRegstDescId(name));\n  }\n  Regst* GetNaiveOrInplaceCurReadable(int64_t regst_desc_id) const;\n  Regst* GetNaiveOrInplaceCurReadable(const std::string& name) const {\n    return GetNaiveOrInplaceCurReadable(Name2SoleRegstDescId(name));\n  }\n  Regst* GetNaiveCurWriteable(int64_t regst_desc_id) const;\n  Regst* GetNaiveCurWriteable(const std::string& name) const {\n    return GetNaiveCurWriteable(Name2SoleRegstDescId(name));\n  }\n  Regst* GetNaiveOrInplaceCurWriteable(int64_t regst_desc_id) const;\n  Regst* GetNaiveOrInplaceCurWriteable(const std::string& name) const {\n    return GetNaiveOrInplaceCurWriteable(Name2SoleRegstDescId(name));\n  }\n  Regst* GetSoleProducedRegst4RegstDescId(int64_t regst_desc_id) const;\n  void ForEachProducedRegst(const std::function<void(Regst*)>&) const;\n  int64_t HandleRegstToConsumer(Regst* regst);\n\n protected:\n  bool IsConsumedCtrlRegstDescId(int64_t regst_desc_id) {\n    return consumed_ctrl_regst_desc_ids_.find(regst_desc_id) != consumed_ctrl_regst_desc_ids_.end();\n  }\n  bool IsProducedCtrlRegstDescId(int64_t regst_desc_id) {\n    return produced_ctrl_regst_desc_ids_.find(regst_desc_id) != produced_ctrl_regst_desc_ids_.end();\n  }\n\n  // Process Msg\n  virtual void NormalProcessNaiveReadableDataRegstMsg(const std::deque<Regst*>&) {}\n  virtual bool NormalTryProcessReadableMsgFromOtherMachine(const ActorMsg&) { return false; }\n  int TryUpdtStateAsProducedRegst(Regst* regst);\n\n  // Act\n  void ActUntilFail();\n  virtual void Act() { UNIMPLEMENTED(); }\n\n  // Ready\n  bool IsReadReady() const;\n  bool IsWriteReady() const;\n\n  // Naive, Inplace Or Customized\n  virtual void TakeOverInplaceConsumedAndProduced(\n      const PbMap<std::string, RegstDescProto>& produced_ids);\n  void TakeOverNaiveConsumed(const PbMap<std::string, RegstDescIdSet>& consumed_ids);\n  void TakeOverNaiveProduced(const PbMap<std::string, RegstDescProto>& produced_ids);\n  void InitBnInOp2BlobInfo(const TaskProto& task_proto);\n\n  // Send Msgs\n  void AsyncSendNaiveProducedRegstMsgToConsumer();\n  virtual void VirtualAsyncSendNaiveProducedRegstMsgToConsumer();\n  virtual void VirtualAsyncSendInplaceProducedRegstMsgToConsumer();\n  void AsyncSendInplaceProducedRegstMsgToConsumer();\n  void AsyncSendNaiveConsumedRegstMsgToProducer();\n  virtual void VirtualAsyncSendNaiveConsumedRegstMsgToProducer();\n  void AsyncSendConsumedCtrlRegstMsgToProducer();\n  void AsyncSendProducedCtrlRegstMsgToConsumer();\n\n  // Customized Consumed virtual func\n  virtual void ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)>) const {}\n  virtual void NormalProcessCustomizedEordMsg(const ActorMsg&) {}\n  virtual void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) { UNIMPLEMENTED(); }\n  virtual bool IsCustomizedReadReady() const { return true; }\n  virtual bool 
IsCustomizedReadAlwaysUnReadyFromNow() const { return false; }\n  virtual std::pair<RegstNameType, HashSet<std::string>>\n  GetNaiveOrCustomizedConsumedRegstDescName() {\n    return std::make_pair(RegstNameType::kCustomized, HashSet<std::string>{});\n  }\n  virtual void AsyncSendCustomizedProducedRegstMsgToConsumer() {}\n  virtual void AsyncReturnAllCustomizedReadableRegst() {}\n\n  // Customized Produced virtual func\n  virtual void UpdtStateAsCustomizedProducedRegst(Regst* regst) { UNIMPLEMENTED(); }\n  virtual bool IsCustomizedWriteReady() const { return true; }\n  virtual std::pair<RegstNameType, HashSet<std::string>>\n  GetNaiveOrCustomizedProducedRegstDescName() {\n    return std::make_pair(RegstNameType::kCustomized, HashSet<std::string>{});\n  }\n  virtual void AsyncSendCustomizedConsumedRegstMsgToProducer() {}\n  void AsyncRetInplaceConsumedRegstIfNoConsumer();\n\n  virtual void AddCallback(std::function<void()> callback);\n\n  int64_t actor_id_;\n  int64_t thrd_id_;\n  int64_t job_id_;\n  std::vector<ExecKernel> exec_kernel_vec_;\n  HashMap<std::string, std::vector<int64_t>> name2regst_desc_id_;\n  MsgHandler msg_handler_;\n  ActorContext* actor_ctx_;\n  HashSet<int64_t> eord_regst_desc_ids_;\n  int64_t remaining_eord_cnt_;\n\n  HashMap<int64_t, std::vector<std::unique_ptr<Regst>>> produced_regsts_;\n  HashMap<Regst*, int64_t> produced_regst2reading_cnt_;\n  int64_t total_reading_cnt_;\n\n  RegstSlot naive_produced_rs_;\n  RegstSlot naive_consumed_rs_;\n  bool is_naive_consumed_eord_;\n\n  HashSet<int64_t> produced_ctrl_regst_desc_ids_;\n  HashSet<int64_t> consumed_ctrl_regst_desc_ids_;\n\n  RegstSlot inplace_consumed_rs_;\n  RegstSlot inplace_produced_rs_;\n  bool is_inplace_consumed_eord_;\n  HashSet<int64_t> inplace_in_ids_with_no_out_consumed_;\n  HashMap<int64_t, int64_t> inplace_regst_desc_id_in2out_;\n  HashMap<int64_t, int64_t> inplace_regst_desc_id_out2in_;\n\n  std::deque<ActorMsg> async_msg_queue_;\n  std::vector<ActorMsg> sync_msg_queue_;\n  bool is_kernel_launch_synchronized_;\n  std::vector<int64_t> tmp_regst_desc_id_vec_;\n\n  // for debug\n  std::string op_name_;\n  bool debug_;\n  int64_t act_cnt_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_LAZY_ACTOR_ACTOR_H_\n"
  },
  {
    "path": "oneflow/core/lazy/actor/actor_base.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/runtime_job_descs.h\"\n\nnamespace oneflow {\n\nstd::unique_ptr<ActorBase> NewActor(ActorContext* actor_ctx) {\n  ActorBase* rptr = NewObj<int32_t, ActorBase>(actor_ctx->task_proto().task_type());\n  const auto& job_descs = *Singleton<RuntimeJobDescs>::Get();\n  rptr->Init(&job_descs.job_desc(actor_ctx->task_proto().job_id()), actor_ctx);\n  return std::unique_ptr<ActorBase>(rptr);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/actor_base.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_LAZY_ACTOR_ACTOR_BASE_H_\n#define ONEFLOW_CORE_LAZY_ACTOR_ACTOR_BASE_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/lazy/actor/actor_context.h\"\n\nnamespace oneflow {\n\nclass JobDesc;\nclass TaskProto;\nclass StreamContext;\nclass ActorMsg;\n\nclass ActorBase {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ActorBase);\n  ActorBase() = default;\n  virtual ~ActorBase() = default;\n\n  virtual void Init(const JobDesc* job_desc, ActorContext* actor_ctx) = 0;\n\n  // 1: success, and actor finish\n  // 0: success, and actor not finish\n  virtual int ProcessMsg(const ActorMsg& msg) = 0;\n};\n\nstd::unique_ptr<ActorBase> NewActor(ActorContext* actor_ctx);\n\n#define REGISTER_ACTOR(task_type, ActorType) \\\n  REGISTER_CLASS(int32_t, task_type, ActorBase, ActorType)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_LAZY_ACTOR_ACTOR_BASE_H_\n"
  },
  {
    "path": "oneflow/core/lazy/actor/actor_context.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor_context.h\"\n#include \"oneflow/core/lazy/actor/generic_actor_context.h\"\n\nnamespace oneflow {\n\nstd::unique_ptr<ActorContext> NewActorContext(const TaskProto& task_proto,\n                                              StreamContext* stream_ctx) {\n  ActorContext* ctx = nullptr;\n  if (IsClassRegistered<int32_t, ActorContext>(task_proto.task_type())) {\n    ctx = NewObj<int32_t, ActorContext>(task_proto.task_type());\n  } else {\n    ctx = new GenericActorContext();\n  }\n  ctx->Init(task_proto, stream_ctx);\n  return std::unique_ptr<ActorContext>(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/actor_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_LAZY_ACTOR_ACTOR_CONTEXT_H_\n#define ONEFLOW_CORE_LAZY_ACTOR_ACTOR_CONTEXT_H_\n\n#include \"oneflow/core/lazy/stream_context/include/stream_context.h\"\n#include \"oneflow/core/job/task.pb.h\"\n\nnamespace oneflow {\n\nclass ActorContext {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ActorContext);\n  ActorContext() = default;\n  virtual ~ActorContext() = default;\n\n  virtual void Init(const TaskProto& task_proto, StreamContext* stream_ctx) = 0;\n  virtual void AddCallback(std::function<void()> callback) = 0;\n\n  virtual StreamContext* stream_ctx() const = 0;\n  virtual const TaskProto& task_proto() const = 0;\n};\n\nclass ActorContextProvider {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ActorContextProvider);\n  ActorContextProvider() = default;\n  virtual ~ActorContextProvider() = default;\n\n  virtual ActorContext* GetActorContext() const = 0;\n};\n\nstd::unique_ptr<ActorContext> NewActorContext(const TaskProto& task_proto,\n                                              StreamContext* stream_ctx);\n\n#define REGISTER_ACTOR_CONTEXT(task_type, ActorContextType) \\\n  REGISTER_CLASS(int32_t, task_type, ActorContext, ActorContextType)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_LAZY_ACTOR_ACTOR_CONTEXT_H_\n"
  },
  {
    "path": "oneflow/core/lazy/actor/actor_message.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor_message.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/id_manager.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool IsSoleBlobAndDynamicEmpty(Regst* regst) {\n  if (regst == nullptr) { return false; }\n  if (regst->GetBlobSize() != 1) { return false; }\n  Blob* sole_blob = regst->GetMutSoleBlob();\n  if (!regst->GetSoleBlob()->IsBodyEmpty()) { return false; }\n  const auto& shape = sole_blob->shape();\n  for (int i = 0; i < shape.NumAxes(); ++i) {\n    if (shape.At(i) != 0) { return false; }\n  }\n  return true;\n}\n\n}  // namespace\n\nActorMsg ActorMsg::BuildRegstMsgToConsumer(int64_t producer, int64_t consumer,\n                                           Regst* regst_raw_ptr) {\n  ActorMsg msg{};\n  msg.src_actor_id_ = producer;\n  msg.dst_actor_id_ = consumer;\n  msg.msg_type_ = ActorMsgType::kRegstMsg;\n  msg.regst_wrapper_.regst = regst_raw_ptr;\n  msg.regst_wrapper_.comm_net_token = nullptr;\n  msg.regst_wrapper_.regst_desc_id = regst_raw_ptr->regst_desc_id();\n  msg.regst_wrapper_.has_sole_empty_blob = IsSoleBlobAndDynamicEmpty(regst_raw_ptr);\n  msg.regst_wrapper_.is_data_regst_to_consumer =\n      regst_raw_ptr->regst_desc()->regst_desc_type().has_data_regst_desc();\n  return msg;\n}\n\nActorMsg ActorMsg::BuildRegstMsgToProducer(int64_t consumer, int64_t producer,\n                                           Regst* regst_raw_ptr) {\n  ActorMsg msg{};\n  msg.src_actor_id_ = consumer;\n  msg.dst_actor_id_ = producer;\n  msg.msg_type_ = ActorMsgType::kRegstMsg;\n  msg.regst_wrapper_.regst = regst_raw_ptr;\n  msg.regst_wrapper_.regst_desc_id = -1;\n  msg.regst_wrapper_.comm_net_token = nullptr;\n  // you can NOT access the regst ptr when multi nodes, because the address is in another machine\n  msg.regst_wrapper_.has_sole_empty_blob = false;\n  msg.regst_wrapper_.is_data_regst_to_consumer = false;\n  return msg;\n}\n\nActorMsg ActorMsg::BuildEordMsg(int64_t consumer, int64_t regst_desc_id) {\n  ActorMsg msg{};\n  msg.src_actor_id_ = -1;\n  msg.dst_actor_id_ = consumer;\n  msg.msg_type_ = ActorMsgType::kEordMsg;\n  msg.eord_regst_desc_id_ = regst_desc_id;\n  return msg;\n}\n\nActorMsg ActorMsg::BuildCommandMsg(int64_t dst_actor_id, ActorCmd cmd) {\n  ActorMsg msg{};\n  msg.src_actor_id_ = -1;\n  msg.dst_actor_id_ = dst_actor_id;\n  msg.msg_type_ = ActorMsgType::kCmdMsg;\n  msg.actor_cmd_ = cmd;\n  return msg;\n}\n\nint64_t ActorMsg::SrcMachineId() const { return MachineId4ActorId(src_actor_id_); }\n\nActorCmd ActorMsg::actor_cmd() const {\n  CHECK_EQ(msg_type_, ActorMsgType::kCmdMsg);\n  return actor_cmd_;\n}\n\nRegst* ActorMsg::regst() const {\n  CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg);\n  return regst_wrapper_.regst;\n}\n\nint64_t ActorMsg::regst_desc_id() const {\n  CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg);\n  // FIXME(liujunchneg): regst_desc_id for remote returned regst\n  if 
(MachineId4ActorId(src_actor_id_) == GlobalProcessCtx::Rank()) {\n    return regst_wrapper_.regst->regst_desc_id();\n  } else {\n    return regst_wrapper_.regst_desc_id;\n  }\n}\n\nint64_t ActorMsg::comm_net_sequence_number() const {\n  CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg);\n  return regst_wrapper_.comm_net_sequence_number;\n}\n\nvoid ActorMsg::set_comm_net_sequence_number(int64_t sequence_number) {\n  CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg);\n  regst_wrapper_.comm_net_sequence_number = sequence_number;\n}\n\nvoid* ActorMsg::comm_net_token() const {\n  CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg);\n  return regst_wrapper_.comm_net_token;\n}\n\nvoid ActorMsg::set_comm_net_token(void* token) {\n  CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg);\n  regst_wrapper_.comm_net_token = token;\n}\n\nbool ActorMsg::has_sole_empty_blob() const {\n  CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg);\n  return regst_wrapper_.has_sole_empty_blob;\n}\n\nint64_t ActorMsg::eord_regst_desc_id() const {\n  CHECK_EQ(msg_type_, ActorMsgType::kEordMsg);\n  return eord_regst_desc_id_;\n}\n\nbool ActorMsg::IsDataRegstMsgToConsumer() const {\n  return msg_type_ == ActorMsgType::kRegstMsg && regst_wrapper_.is_data_regst_to_consumer;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/actor_message.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_LAZY_ACTOR_ACTOR_MESSAGE_H_\n#define ONEFLOW_CORE_LAZY_ACTOR_ACTOR_MESSAGE_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/register/register.h\"\n\nnamespace oneflow {\n\nenum class ActorCmd {\n  kStart = 0,  // Source Actor\n  kStopThread,\n  kConstructActor\n};\n\nenum class ActorMsgType : int8_t { kRegstMsg = 0, kEordMsg, kCmdMsg };\n\nclass ActorMsg final {\n public:\n  ActorMsg() = default;\n  ~ActorMsg() = default;\n\n  // Build Msg\n  static ActorMsg BuildRegstMsgToConsumer(int64_t producer, int64_t consumer, Regst*);\n  static ActorMsg BuildRegstMsgToProducer(int64_t consumer, int64_t producer, Regst*);\n  static ActorMsg BuildEordMsg(int64_t consumer, int64_t regst_desc_id);\n  static ActorMsg BuildCommandMsg(int64_t dst_actor_id, ActorCmd cmd);\n\n  // Getters\n  int64_t SrcMachineId() const;\n  int64_t src_actor_id() const { return src_actor_id_; }\n  int64_t dst_actor_id() const { return dst_actor_id_; }\n  ActorMsgType msg_type() const { return msg_type_; }\n  ActorCmd actor_cmd() const;\n  Regst* regst() const;\n  int64_t regst_desc_id() const;\n  void* comm_net_token() const;\n  void set_comm_net_token(void* token);\n  bool has_sole_empty_blob() const;\n  int64_t eord_regst_desc_id() const;\n  bool IsDataRegstMsgToConsumer() const;\n  int64_t comm_net_sequence_number() const;\n  void set_comm_net_sequence_number(int64_t sequence_number);\n\n  // Serialize\n  template<typename StreamT>\n  void Serialize(StreamT& out_stream) const {\n    out_stream.Write(this, sizeof(ActorMsg));\n  }\n  template<typename StreamT>\n  void Deserialize(StreamT& in_stream) {\n    in_stream.Read(this, sizeof(ActorMsg));\n  }\n\n  void set_dst_actor_id(int64_t actor_id) { dst_actor_id_ = actor_id; }\n\n private:\n  struct RegstWrapper {\n    Regst* regst;\n    void* comm_net_token;\n    int64_t comm_net_sequence_number;\n    int64_t regst_desc_id;\n    bool has_sole_empty_blob;\n    bool is_data_regst_to_consumer;\n  };\n\n  int64_t src_actor_id_;\n  int64_t dst_actor_id_;\n  union {\n    ActorCmd actor_cmd_;\n    RegstWrapper regst_wrapper_;\n    int64_t eord_regst_desc_id_;\n  };\n  ActorMsgType msg_type_;\n};\n\ntemplate<typename StreamT>\nStreamT& operator<<(StreamT& out_stream, const ActorMsg& msg) {\n  msg.Serialize(out_stream);\n}\n\ntemplate<typename StreamT>\nStreamT& operator>>(StreamT& in_stream, const ActorMsg& msg) {\n  msg.Deserialize(in_stream);\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_LAZY_ACTOR_ACTOR_MESSAGE_H_\n"
  },
  {
    "path": "oneflow/core/lazy/actor/actor_message_bus.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor_message_bus.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/id_manager.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/core/comm_network/comm_network.h\"\n\nnamespace oneflow {\n\nvoid ActorMsgBus::SendMsg(const ActorMsg& msg) {\n  int64_t dst_machine_id = MachineId4ActorId(msg.dst_actor_id());\n  if (dst_machine_id == GlobalProcessCtx::Rank()) {\n    SendMsgWithoutCommNet(msg);\n  } else {\n    if (msg.IsDataRegstMsgToConsumer()) {\n      int64_t comm_net_sequence;\n      {\n        std::unique_lock<std::mutex> lock(\n            regst_desc_id_dst_actor_id2comm_net_sequence_number_mutex_);\n        int64_t& comm_net_sequence_ref =\n            regst_desc_id_dst_actor_id2comm_net_sequence_number_[std::make_pair(\n                msg.regst_desc_id(), msg.dst_actor_id())];\n        comm_net_sequence = comm_net_sequence_ref;\n        comm_net_sequence_ref += 1;\n      }\n      ActorMsg new_msg = msg;\n      new_msg.set_comm_net_sequence_number(comm_net_sequence);\n      Singleton<CommNet>::Get()->SendActorMsg(dst_machine_id, new_msg);\n    } else {\n      Singleton<CommNet>::Get()->SendActorMsg(dst_machine_id, msg);\n    }\n  }\n}\n\nvoid ActorMsgBus::SendMsgWithoutCommNet(const ActorMsg& msg) {\n  CHECK_EQ(MachineId4ActorId(msg.dst_actor_id()), GlobalProcessCtx::Rank());\n  int64_t thrd_id = ThrdId4ActorId(msg.dst_actor_id());\n  Singleton<ThreadMgr>::Get()->GetThrd(thrd_id)->EnqueueActorMsg(msg);\n}\n\nvoid ActorMsgBus::SendMsgsWithoutCommNet(const ActorMsg* msgs, size_t n, int64_t thrd_id) {\n  Singleton<ThreadMgr>::Get()->GetThrd(thrd_id)->EnqueueActorMsg(msgs, msgs + n);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/actor_message_bus.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_LAZY_ACTOR_ACTOR_MESSAGE_BUS_H_\n#define ONEFLOW_CORE_LAZY_ACTOR_ACTOR_MESSAGE_BUS_H_\n\n#include \"oneflow/core/lazy/actor/actor_message.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nclass ActorMsgBus final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ActorMsgBus);\n  ~ActorMsgBus() = default;\n\n  void SendMsg(const ActorMsg& msg);\n  void SendMsgWithoutCommNet(const ActorMsg& msg);\n  void SendMsgsWithoutCommNet(const ActorMsg* msgs, size_t n, int64_t thrd_id);\n\n private:\n  friend class Singleton<ActorMsgBus>;\n  ActorMsgBus() = default;\n  HashMap<std::pair<int64_t, int64_t>, int64_t>\n      regst_desc_id_dst_actor_id2comm_net_sequence_number_;\n  std::mutex regst_desc_id_dst_actor_id2comm_net_sequence_number_mutex_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_LAZY_ACTOR_ACTOR_MESSAGE_BUS_H_\n"
  },
  {
    "path": "oneflow/core/lazy/actor/boxing_zeros_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/naive_actor.h\"\n\nnamespace oneflow {\n\nclass BoxingZerosActor : public NaiveActor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BoxingZerosActor);\n  BoxingZerosActor() = default;\n  ~BoxingZerosActor() override = default;\n\n  void VirtualActorInit(const TaskProto& task_proto) override {\n    NaiveActor::VirtualActorInit(task_proto);\n    out_inited_ = false;\n  }\n\n private:\n  void Act() override {\n    if (!out_inited_) {\n      NaiveActor::Act();\n      out_inited_ = true;\n    }\n  }\n\n  void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override {\n    HandleProducedNaiveDataRegstToConsumer();\n  }\n\n  bool out_inited_;\n};\n\nREGISTER_ACTOR(TaskType::kBoxingZeros, BoxingZerosActor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/callback_notify_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/sink_actor.h\"\n\nnamespace oneflow {\n\nclass CallbackNotifyActor final : public SinkActor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CallbackNotifyActor);\n  CallbackNotifyActor() = default;\n  ~CallbackNotifyActor() = default;\n\n private:\n};\n\nREGISTER_ACTOR(TaskType::kCallbackNotify, CallbackNotifyActor);\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/case_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor.h\"\n#include \"oneflow/core/kernel/case_kernel.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass CaseActor final : public Actor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CaseActor);\n  CaseActor() : case_status_(nullptr) {}\n  ~CaseActor() override = default;\n\n protected:\n  bool IsCustomizedReadReady() const override;\n  bool IsCustomizedWriteReady() const override;\n  bool IsCustomizedReadAlwaysUnReadyFromNow() const override;\n  void UpdtStateAsCustomizedProducedRegst(Regst* regst) override;\n  void AsyncSendCustomizedProducedRegstMsgToConsumer() override;\n  void AsyncSendCustomizedConsumedRegstMsgToProducer() override;\n  void ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)>) const override;\n  void VirtualActorInit(const TaskProto&) override;\n  bool ProducedCtrlRegstValid(int64_t regst_desc_id) const override;\n  void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) override;\n  void NormalProcessCustomizedEordMsg(const ActorMsg&) override {}\n  std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName()\n      override {\n    return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{});\n  }\n  std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedProducedRegstDescName()\n      override {\n    return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{});\n  }\n\n private:\n  void Act() override;\n  void TakeOverConsumedRegst(const PbMap<std::string, RegstDescIdSet>& consumed_ids);\n  void TakeOverProducedRegst(const PbMap<std::string, RegstDescProto>& produced_ids);\n  bool IsInputOrOutputReady() const;\n  int64_t GetCurSelectId() const;\n\n  HashMap<int64_t, int64_t> out_bn_id2regst_desc_id_;\n  int64_t consumed_regst_desc_id_{};\n  RegstSlot consumed_rs_;\n  HashMap<int64_t, RegstSlot> regst_desc_id2produced_rs_;\n  CaseStatus* case_status_;\n};\n\nvoid CaseActor::VirtualActorInit(const TaskProto& task_proto) {\n  CHECK_EQ(1, exec_kernel_vec().size());\n  case_status_ =\n      CHECK_NOTNULL(dynamic_cast<CaseStatus*>(exec_kernel_vec().at(0).kernel_ctx->state().get()));\n  const int32_t output_bns_size =\n      task_proto.exec_sequence().exec_node().Get(0).kernel_conf().op_attribute().output_bns_size();\n  FOR_RANGE(int64_t, i, 0, output_bns_size) {\n    const int64_t regst_desc_id =\n        exec_kernel_vec().at(0).bn_in_op2blob_info.at(GenRepeatedBn(\"out\", i)).regst_desc_id;\n    CHECK(out_bn_id2regst_desc_id_.emplace(i, regst_desc_id).second);\n  }\n  TakeOverConsumedRegst(task_proto.consumed_regst_desc_id());\n  TakeOverProducedRegst(task_proto.produced_regst_desc());\n  OF_SET_MSG_HANDLER(&CaseActor::HandlerNormal);\n}\n\nvoid CaseActor::TakeOverConsumedRegst(const PbMap<std::string, RegstDescIdSet>& consumed_ids) {\n  CHECK_EQ(consumed_ids.size(), 1);\n  const auto& pair = *consumed_ids.begin();\n  
CHECK_EQ(pair.second.regst_desc_id_size(), 1);\n  consumed_regst_desc_id_ = pair.second.regst_desc_id(0);\n  consumed_rs_.InsertRegstDescId(consumed_regst_desc_id_);\n  consumed_rs_.InitedDone();\n}\n\nvoid CaseActor::TakeOverProducedRegst(const PbMap<std::string, RegstDescProto>& produced_ids) {\n  for (const auto& pair : produced_ids) {\n    CHECK(pair.second.regst_desc_type().has_data_regst_desc());\n    CHECK_EQ(pair.second.has_inplace_consumed_regst_desc_id(), false);\n    const int64_t regst_desc_id = pair.second.regst_desc_id();\n    regst_desc_id2produced_rs_[regst_desc_id].InsertRegstDescId(regst_desc_id);\n    regst_desc_id2produced_rs_.at(regst_desc_id).InitedDone();\n  }\n  ForEachProducedRegst([&](Regst* regst) {\n    const int64_t regst_desc_id = regst->regst_desc_id();\n    CHECK_EQ(0, regst_desc_id2produced_rs_.at(regst_desc_id).TryPushBackRegst(regst));\n  });\n}\n\n// Act() is called twice for each output:\n// the first call sets cur_selected_id\n// the second call outputs cur_selected_id\nvoid CaseActor::Act() {\n  Regst* const consumed_regst = consumed_rs_.Front(consumed_regst_desc_id_);\n  case_status_->cur_selected_id = GetCurSelectId();\n  case_status_->cmd =\n      (case_status_->cur_selected_id == -1 ? kCaseCmdHandleInput : kCaseCmdHandleOutput);\n  AsyncLaunchKernel([&](int64_t regst_desc_id) -> Regst* {\n    if (consumed_regst_desc_id_ == regst_desc_id) { return consumed_regst; }\n    return regst_desc_id2produced_rs_.at(regst_desc_id).Front(regst_desc_id);\n  });\n}\n\nvoid CaseActor::UpdtStateAsCustomizedProducedRegst(Regst* regst) {\n  const int64_t regst_desc_id = regst->regst_desc_id();\n  CHECK_EQ(0, regst_desc_id2produced_rs_.at(regst_desc_id).TryPushBackRegst(regst));\n}\n\nbool CaseActor::IsCustomizedReadReady() const { return IsInputOrOutputReady(); }\n\nbool CaseActor::IsCustomizedWriteReady() const { return IsInputOrOutputReady(); }\n\nbool CaseActor::IsCustomizedReadAlwaysUnReadyFromNow() const {\n  return ReceiveEordMsg(consumed_regst_desc_id_) && case_status_->select_id2request_cnt.size() == 0;\n}\n\nbool CaseActor::IsInputOrOutputReady() const {\n  if (GetCurSelectId() != -1) { return true; }\n  return consumed_rs_.IsCurSlotReady();\n}\n\nint64_t CaseActor::GetCurSelectId() const {\n  for (const auto& pair : case_status_->select_id2request_cnt) {\n    CHECK_GT(pair.second, 0);\n    const int64_t regst_desc_id = out_bn_id2regst_desc_id_.at(pair.first);\n    if (regst_desc_id2produced_rs_.at(regst_desc_id).IsCurSlotReady()) { return pair.first; }\n  }\n  return -1;\n}\n\nvoid CaseActor::ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)> Handler) const {\n  Handler(consumed_rs_.Front(consumed_regst_desc_id_));\n}\n\nvoid CaseActor::AsyncSendCustomizedConsumedRegstMsgToProducer() {\n  if (case_status_->cmd != kCaseCmdHandleInput) { return; }\n  Regst* const cur_regst = consumed_rs_.Front(consumed_regst_desc_id_);\n  CHECK_NOTNULL(cur_regst);\n  AsyncSendRegstMsgToProducer(cur_regst);\n  CHECK_EQ(0, consumed_rs_.TryPopFrontRegst(consumed_regst_desc_id_));\n}\n\nvoid CaseActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) {\n  CHECK_EQ(0, consumed_rs_.TryPushBackRegst(msg.regst()));\n}\n\nvoid CaseActor::AsyncSendCustomizedProducedRegstMsgToConsumer() {\n  if (case_status_->cmd != kCaseCmdHandleOutput) { return; }\n  const int64_t regst_desc_id = out_bn_id2regst_desc_id_.at(case_status_->cur_selected_id);\n  Regst* const regst = regst_desc_id2produced_rs_.at(regst_desc_id).Front(regst_desc_id);\n  
CHECK_GT(HandleRegstToConsumer(regst), 0);\n  regst_desc_id2produced_rs_.at(regst_desc_id).PopFrontRegsts({regst_desc_id});\n}\n\nbool CaseActor::ProducedCtrlRegstValid(int64_t regst_desc_id) const { return true; }\n\nREGISTER_ACTOR(kCase, CaseActor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/collective_boxing_actor_context.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/collective_boxing_actor_context.h\"\n#include \"oneflow/core/job/collective_boxing/scheduler.h\"\n\nnamespace oneflow {\n\nusing namespace boxing::collective;\n\nvoid CollectiveBoxingActorContext::Init(const TaskProto& task_proto, StreamContext* stream_ctx) {\n  stream_ctx_ = stream_ctx;\n  task_proto_ = task_proto;\n  scheduled_count_ = 0;\n  completed_count_ = 0;\n}\n\nvoid CollectiveBoxingActorContext::AddCallback(std::function<void()> callback) {\n  std::lock_guard<std::mutex> lock(mutex_);\n  if (scheduled_count_ == completed_count_) {\n    callback();\n  } else {\n    callbacks_.emplace_back(std::make_pair(scheduled_count_ - 1, std::move(callback)));\n  }\n}\n\nvoid CollectiveBoxingActorContext::Schedule(RequestHandle* handle, const void* send_buff,\n                                            void* recv_buff) {\n  std::lock_guard<std::mutex> lock(mutex_);\n  auto request = std::make_shared<boxing::collective::RuntimeRequestInfo>();\n  request->send_buff = send_buff;\n  request->recv_buff = recv_buff;\n  const size_t schedule_id = scheduled_count_;\n  request->callback = [schedule_id, this](const Maybe<void>& status) {\n    CHECK(status.IsOk());\n    this->SetCompleted(schedule_id);\n  };\n  Singleton<Scheduler>::Get()->Schedule(handle, request);\n  scheduled_count_ += 1;\n}\n\nvoid CollectiveBoxingActorContext::SetCompleted(size_t schedule_id) {\n  std::lock_guard<std::mutex> lock(mutex_);\n  CHECK_EQ(schedule_id, completed_count_);\n  while (!callbacks_.empty() && callbacks_.front().first == schedule_id) {\n    callbacks_.front().second();\n    callbacks_.pop_front();\n  }\n  completed_count_ += 1;\n}\n\nStreamContext* CollectiveBoxingActorContext::stream_ctx() const { return stream_ctx_; }\n\nconst TaskProto& CollectiveBoxingActorContext::task_proto() const { return task_proto_; }\n\nREGISTER_ACTOR_CONTEXT(TaskType::kCollectiveBoxingGeneric, CollectiveBoxingActorContext);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/collective_boxing_actor_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_LAZY_ACTOR_COLLECTIVE_BOXING_ACTOR_CONTEXT_H_\n#define ONEFLOW_CORE_LAZY_ACTOR_COLLECTIVE_BOXING_ACTOR_CONTEXT_H_\n\n#include \"oneflow/core/lazy/actor/actor_context.h\"\n#include \"oneflow/core/job/collective_boxing/scheduler.h\"\n\nnamespace oneflow {\n\nclass CollectiveBoxingActorContext : public ActorContext {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingActorContext);\n  CollectiveBoxingActorContext() = default;\n  ~CollectiveBoxingActorContext() override = default;\n\n  void Init(const TaskProto& task_proto, StreamContext* stream_ctx) override;\n  void AddCallback(std::function<void()> callback) override;\n  void Schedule(boxing::collective::RequestHandle* handle, const void* send_buff, void* recv_buff);\n  void SetCompleted(size_t schedule_id);\n\n  StreamContext* stream_ctx() const override;\n  const TaskProto& task_proto() const override;\n\n private:\n  StreamContext* stream_ctx_{};\n  TaskProto task_proto_{};\n  size_t scheduled_count_{};\n  size_t completed_count_{};\n  std::mutex mutex_;\n  std::deque<std::pair<size_t, std::function<void()>>> callbacks_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_LAZY_ACTOR_COLLECTIVE_BOXING_ACTOR_CONTEXT_H_\n"
  },
  {
    "path": "oneflow/core/lazy/actor/copy_comm_net_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor.h\"\n#include \"oneflow/core/comm_network/comm_network.h\"\n#include \"oneflow/core/register/register.h\"\n\nnamespace oneflow {\n\nclass CopyCommNetActor final : public Actor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CopyCommNetActor);\n  CopyCommNetActor() = default;\n  ~CopyCommNetActor();\n\n private:\n  struct RegstCtx {\n    void* comm_net_token;\n    Regst* regst_raw_ptr;\n    int64_t producer;\n    bool has_sole_empty_blob;\n  };\n\n  void VirtualActorInit(const TaskProto&) override;\n\n  std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName()\n      override {\n    return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{});\n  }\n  void ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)>) const override;\n  void NormalProcessCustomizedEordMsg(const ActorMsg&) override { is_in_eord_ = true; }\n  bool NormalTryProcessReadableMsgFromOtherMachine(const ActorMsg&) override;\n  void Act() override;\n  void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override;\n  void AsyncSendCustomizedConsumedRegstMsgToProducer() override;\n  bool IsCustomizedReadReady() const override;\n  bool IsCustomizedReadAlwaysUnReadyFromNow() const override;\n  void AsyncReturnAllCustomizedReadableRegst() override;\n  void AddCallback(std::function<void()> callback) override;\n  bool is_in_eord_;\n  HashMap<int64_t, RegstCtx> sequence_number2regst_ctx_;\n  void* actor_read_id_;\n  int64_t next_sequence_number_;\n  int64_t in_regst_desc_id_;\n};\n\nCopyCommNetActor::~CopyCommNetActor() {\n  Singleton<CommNet>::Get()->DeleteActorReadId(actor_read_id_);\n}\n\nvoid CopyCommNetActor::VirtualActorInit(const TaskProto& task_proto) {\n  is_in_eord_ = false;\n  next_sequence_number_ = 0;\n  in_regst_desc_id_ = Name2SoleRegstDescId(\"copy_in\");\n  actor_read_id_ = Singleton<CommNet>::Get()->NewActorReadId();\n  OF_SET_MSG_HANDLER(&CopyCommNetActor::HandlerNormal);\n}\n\nvoid CopyCommNetActor::ForEachCurCustomizedReadableRegst(\n    std::function<void(const Regst*)> handler) const {\n  handler(sequence_number2regst_ctx_.at(next_sequence_number_).regst_raw_ptr);\n}\n\nbool CopyCommNetActor::NormalTryProcessReadableMsgFromOtherMachine(const ActorMsg& msg) {\n  RegstCtx regst_ctx;\n  regst_ctx.comm_net_token = msg.comm_net_token();\n  regst_ctx.regst_raw_ptr = msg.regst();\n  regst_ctx.producer = msg.src_actor_id();\n  regst_ctx.has_sole_empty_blob = msg.has_sole_empty_blob();\n  CHECK(sequence_number2regst_ctx_.emplace(msg.comm_net_sequence_number(), regst_ctx).second);\n  return true;\n}\n\nvoid CopyCommNetActor::Act() {\n  // readable\n  auto readable_it = sequence_number2regst_ctx_.find(next_sequence_number_);\n  void* readable_token = readable_it->second.comm_net_token;\n  int64_t src_actor_id = readable_it->second.producer;\n  int64_t src_machine_id = MachineId4ActorId(src_actor_id);\n  // writeable\n  
Regst* writeable_regst = GetNaiveCurWriteable(\"copy_out\");\n  if (readable_it->second.has_sole_empty_blob) {\n    // skip the CommNet read if the regst's dynamic body is empty\n    Blob* data_blob = writeable_regst->GetMutSoleBlob();\n    Shape empty_shape = data_blob->static_shape();\n    for (int i = 0; i < empty_shape.NumAxes(); ++i) { empty_shape.Set(i, 0); }\n    data_blob->mut_shape_view()->set_shape(empty_shape);\n  } else {\n    void* writeable_token = writeable_regst->comm_net_token();\n    // async read from the remote machine\n    Singleton<CommNet>::Get()->Read(actor_read_id_, src_machine_id, readable_token,\n                                    writeable_token);\n  }\n}\n\nvoid CopyCommNetActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {\n  HandleProducedNaiveDataRegstToConsumer();\n}\n\nvoid CopyCommNetActor::AsyncSendCustomizedConsumedRegstMsgToProducer() {\n  auto readable_it = sequence_number2regst_ctx_.find(next_sequence_number_);\n  EnqueueAsyncMsg(ActorMsg::BuildRegstMsgToProducer(actor_id(), readable_it->second.producer,\n                                                    readable_it->second.regst_raw_ptr));\n  sequence_number2regst_ctx_.erase(readable_it);\n  next_sequence_number_ += 1;\n}\n\nbool CopyCommNetActor::IsCustomizedReadReady() const {\n  return sequence_number2regst_ctx_.find(next_sequence_number_) != sequence_number2regst_ctx_.end();\n}\n\nbool CopyCommNetActor::IsCustomizedReadAlwaysUnReadyFromNow() const {\n  return is_in_eord_ && sequence_number2regst_ctx_.empty();\n}\n\nvoid CopyCommNetActor::AsyncReturnAllCustomizedReadableRegst() {\n  CHECK(sequence_number2regst_ctx_.empty());\n}\n\nvoid CopyCommNetActor::AddCallback(std::function<void()> callback) {\n  Singleton<CommNet>::Get()->AddReadCallBack(actor_read_id_, callback);\n}\n\nREGISTER_ACTOR(TaskType::kCopyCommNet, CopyCommNetActor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/esac_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/kernel/esac_kernel.h\"\n\nnamespace oneflow {\n\nclass EsacActor final : public Actor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(EsacActor);\n  EsacActor() = default;\n  ~EsacActor() override = default;\n\n protected:\n  void VirtualActorInit(const TaskProto&) override;\n  int64_t InBnId4RegstDescId(int64_t id) const { return regst_desc_id2in_bn_id_.at(id); }\n\n  bool ProducedCtrlRegstValid(int64_t regst_desc_id) const override;\n\n private:\n  void Act() override;\n  void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) override;\n  void ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)>) const override;\n  bool IsCustomizedReadReady() const override;\n  void NormalProcessCustomizedEordMsg(const ActorMsg&) override {}\n  bool IsCustomizedReadAlwaysUnReadyFromNow() const override {\n    return ReceiveAllEordMsg() && consumed_rs_.available_regst_desc_cnt() == 0;\n  }\n  void AsyncReturnAllCustomizedReadableRegst() override;\n  std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName()\n      override {\n    return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{});\n  }\n  void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override;\n  void AsyncSendCustomizedConsumedRegstMsgToProducer() override;\n  int64_t GetCurProcessedRegstDescId() const;\n\n  RegstSlot consumed_rs_;\n  int64_t cur_processed_regst_desc_id_{};\n  HashMap<int64_t, int64_t> regst_desc_id2in_bn_id_;\n};\n\nvoid EsacActor::VirtualActorInit(const TaskProto& task_proto) {\n  CHECK_EQ(1, exec_kernel_vec().size());\n  const int32_t input_bns_size =\n      task_proto.exec_sequence().exec_node().Get(0).kernel_conf().op_attribute().input_bns_size();\n  FOR_RANGE(int64_t, i, 0, input_bns_size) {\n    const int64_t regst_desc_id =\n        exec_kernel_vec().at(0).bn_in_op2blob_info.at(GenRepeatedBn(\"in\", i)).regst_desc_id;\n    CHECK(regst_desc_id2in_bn_id_.emplace(regst_desc_id, i).second);\n  }\n  for (const auto& pair : task_proto.consumed_regst_desc_id()) {\n    for (const int64_t regst_desc_id : pair.second.regst_desc_id()) {\n      consumed_rs_.InsertRegstDescId(regst_desc_id);\n    }\n  }\n  consumed_rs_.InitedDone();\n  cur_processed_regst_desc_id_ = -1;\n  OF_SET_MSG_HANDLER(&EsacActor::HandlerNormal);\n}\n\nvoid EsacActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) {\n  CHECK_EQ(0, consumed_rs_.TryPushBackRegst(msg.regst()));\n}\n\nbool EsacActor::IsCustomizedReadReady() const { return -1 != GetCurProcessedRegstDescId(); }\n\nvoid EsacActor::ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)> handler) const {\n  handler(consumed_rs_.Front(cur_processed_regst_desc_id_));\n}\n\nvoid EsacActor::Act() {\n  cur_processed_regst_desc_id_ = GetCurProcessedRegstDescId();\n  Regst* cur_regst = 
consumed_rs_.Front(cur_processed_regst_desc_id_);\n  CHECK(cur_regst);\n  int64_t in_bn_id = InBnId4RegstDescId(cur_processed_regst_desc_id_);\n  CHECK_EQ(exec_kernel_vec().size(), 1);\n  CHECK_NOTNULL(dynamic_cast<EsacKernelState*>(exec_kernel_vec().at(0).kernel_ctx->state().get()))\n      ->value = in_bn_id;\n  AsyncLaunchKernel([&](int64_t regst_desc_id) -> Regst* {\n    if (cur_processed_regst_desc_id_ != regst_desc_id) { return nullptr; }\n    return cur_regst;\n  });\n}\n\nvoid EsacActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {\n  HandleProducedNaiveDataRegstToConsumer();\n}\n\nvoid EsacActor::AsyncSendCustomizedConsumedRegstMsgToProducer() {\n  Regst* cur_regst = consumed_rs_.Front(cur_processed_regst_desc_id_);\n  CHECK(cur_regst);\n  AsyncSendRegstMsgToProducer(cur_regst);\n  CHECK_EQ(0, consumed_rs_.TryPopFrontRegst(cur_processed_regst_desc_id_));\n  cur_processed_regst_desc_id_ = -1;\n}\n\nvoid EsacActor::AsyncReturnAllCustomizedReadableRegst() {\n  CHECK_EQ(-1, cur_processed_regst_desc_id_);\n  CHECK_EQ(0, consumed_rs_.available_regst_desc_cnt());\n}\n\nbool EsacActor::ProducedCtrlRegstValid(int64_t regst_desc_id) const { return true; }\n\nint64_t EsacActor::GetCurProcessedRegstDescId() const {\n  int64_t cur_processed_regst_desc_id = -1;\n  consumed_rs_.ForChosenRegstDeq(\n      [&cur_processed_regst_desc_id](int64_t) { return cur_processed_regst_desc_id == -1; },\n      [&cur_processed_regst_desc_id](const std::deque<Regst*>& reg_deq) {\n        if (reg_deq.empty()) { return; }\n        cur_processed_regst_desc_id = reg_deq.front()->regst_desc_id();\n      });\n  return cur_processed_regst_desc_id;\n}\n\nREGISTER_ACTOR(kEsac, EsacActor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/generic_actor_context.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/generic_actor_context.h\"\n\nnamespace oneflow {\n\nvoid GenericActorContext::Init(const TaskProto& task_proto, StreamContext* stream_ctx) {\n  stream_ctx_ = stream_ctx;\n  task_proto_ = task_proto;\n}\n\nvoid GenericActorContext::AddCallback(std::function<void()> callback) {\n  CHECK_JUST(stream_ctx_->AddCallback(std::move(callback)));\n}\n\nStreamContext* GenericActorContext::stream_ctx() const { return stream_ctx_; }\n\nconst TaskProto& GenericActorContext::task_proto() const { return task_proto_; }\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/generic_actor_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_LAZY_ACTOR_GENERIC_ACTOR_CONTEXT_H_\n#define ONEFLOW_CORE_LAZY_ACTOR_GENERIC_ACTOR_CONTEXT_H_\n\n#include \"oneflow/core/lazy/actor/actor_context.h\"\n\nnamespace oneflow {\n\nclass GenericActorContext : public ActorContext {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(GenericActorContext);\n  GenericActorContext() = default;\n  ~GenericActorContext() override = default;\n\n  void Init(const TaskProto& task_proto, StreamContext* stream_ctx) override;\n  void AddCallback(std::function<void()> callback) override;\n\n  StreamContext* stream_ctx() const override;\n  const TaskProto& task_proto() const override;\n\n private:\n  StreamContext* stream_ctx_{};\n  TaskProto task_proto_{};\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_LAZY_ACTOR_GENERIC_ACTOR_CONTEXT_H_\n"
  },
  {
    "path": "oneflow/core/lazy/actor/input_wise_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/input_wise_actor.h\"\n\nnamespace oneflow {\n\nvoid InputWiseActor::Init(const TaskProto& task_proto) {\n  CHECK_EQ(1, exec_kernel_vec().size());\n  const auto& input_bns =\n      task_proto.exec_sequence().exec_node().Get(0).kernel_conf().op_attribute().input_bns();\n  HashMap<std::string, int64_t> ibn2in_bn_id;\n  for (int64_t i = 0; i < input_bns.size(); ++i) {\n    CHECK(ibn2in_bn_id.emplace(input_bns.Get(i), i).second);\n  }\n  for (const auto& pair : exec_kernel_vec().at(0).bn_in_op2blob_info) {\n    auto it = ibn2in_bn_id.find(pair.first);\n    if (it != ibn2in_bn_id.end()) {\n      CHECK(regst_desc_id2in_bn_id_.emplace(pair.second.regst_desc_id, it->second).second);\n    }\n  }\n\n  for (const auto& pair : task_proto.consumed_regst_desc_id()) {\n    for (int64_t regst_desc_id : pair.second.regst_desc_id()) {\n      consumed_rs_.InsertRegstDescId(regst_desc_id);\n      CHECK(regst_desc_id2is_processed_.emplace(regst_desc_id, false).second);\n    }\n  }\n  consumed_rs_.InitedDone();\n  cur_processed_regst_desc_id_ = -1;\n  processed_regst_desc_id_cnt_ = 0;\n  OF_SET_MSG_HANDLER(&InputWiseActor::HandlerNormal);\n}\n\nvoid InputWiseActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) {\n  CHECK_EQ(0, consumed_rs_.TryPushBackRegst(msg.regst()));\n}\n\nbool InputWiseActor::IsCustomizedReadReady() const { return -1 != GetCurProcessedRegstDescId(); }\n\nvoid InputWiseActor::ForEachCurCustomizedReadableRegst(\n    std::function<void(const Regst*)> handler) const {\n  handler(consumed_rs_.Front(cur_processed_regst_desc_id_));\n}\n\nvoid InputWiseActor::Act() {\n  cur_processed_regst_desc_id_ = GetCurProcessedRegstDescId();\n  Regst* cur_regst = consumed_rs_.Front(cur_processed_regst_desc_id_);\n  CHECK(cur_regst);\n  AsyncLaunchKernel([&](int64_t regst_desc_id) -> Regst* {\n    if (cur_processed_regst_desc_id_ != regst_desc_id) { return nullptr; }\n    return cur_regst;\n  });\n  processed_regst_desc_id_cnt_ += 1;\n  regst_desc_id2is_processed_.at(cur_processed_regst_desc_id_) = true;\n}\n\nvoid InputWiseActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {\n  if (processed_regst_desc_id_cnt_ == regst_desc_id2is_processed_.size()) {\n    HandleProducedNaiveDataRegstToConsumer();\n    for (auto& pair : regst_desc_id2is_processed_) {\n      CHECK(pair.second);\n      pair.second = false;\n    }\n    processed_regst_desc_id_cnt_ = 0;\n  }\n}\n\nvoid InputWiseActor::AsyncSendCustomizedConsumedRegstMsgToProducer() {\n  Regst* cur_regst = consumed_rs_.Front(cur_processed_regst_desc_id_);\n  CHECK(cur_regst);\n  AsyncSendRegstMsgToProducer(cur_regst);\n  CHECK_EQ(0, consumed_rs_.TryPopFrontRegst(cur_processed_regst_desc_id_));\n  cur_processed_regst_desc_id_ = -1;\n}\n\nvoid InputWiseActor::AsyncReturnAllCustomizedReadableRegst() {\n  CHECK_EQ(-1, cur_processed_regst_desc_id_);\n  CHECK_EQ(0, processed_regst_desc_id_cnt_);\n  
CHECK_EQ(0, consumed_rs_.available_regst_desc_cnt());\n}\n\nbool InputWiseActor::ProducedCtrlRegstValid(int64_t regst_desc_id) const { return true; }\n\nint64_t InputWiseActor::GetCurProcessedRegstDescId() const {\n  int64_t cur_processed_regst_desc_id = -1;\n  consumed_rs_.ForChosenRegstDeq(\n      // capture by reference so the scan stops once an unprocessed regst desc has been chosen\n      [&cur_processed_regst_desc_id](int64_t) { return cur_processed_regst_desc_id == -1; },\n      [this, &cur_processed_regst_desc_id](const std::deque<Regst*>& reg_deq) {\n        if (reg_deq.empty()) { return; }\n        int64_t regst_desc_id = reg_deq.front()->regst_desc_id();\n        if (regst_desc_id2is_processed_.at(regst_desc_id) == false) {\n          cur_processed_regst_desc_id = regst_desc_id;\n        }\n      });\n  return cur_processed_regst_desc_id;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/input_wise_actor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_LAZY_ACTOR_INPUT_WISE_ACTOR_H_\n#define ONEFLOW_CORE_LAZY_ACTOR_INPUT_WISE_ACTOR_H_\n\n#include \"oneflow/core/lazy/actor/actor.h\"\n\nnamespace oneflow {\n\nclass InputWiseActor : public Actor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(InputWiseActor);\n  InputWiseActor() = default;\n  ~InputWiseActor() = default;\n\n  using Actor::Init;\n\n protected:\n  void Init(const TaskProto&);\n  int64_t cur_processed_regst_desc_id() const { return cur_processed_regst_desc_id_; }\n  int64_t processed_regst_desc_id_cnt() const { return processed_regst_desc_id_cnt_; }\n  int64_t RegstDescNum() const { return consumed_rs_.total_regst_desc_cnt(); }\n  int64_t InBnId4RegstDescId(int64_t id) const { return regst_desc_id2in_bn_id_.at(id); }\n\n  bool ProducedCtrlRegstValid(int64_t regst_desc_id) const override;\n\n private:\n  void Act() override;\n  void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) override;\n  void ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)>) const override;\n  bool IsCustomizedReadReady() const override;\n  void NormalProcessCustomizedEordMsg(const ActorMsg&) override {}\n  bool IsCustomizedReadAlwaysUnReadyFromNow() const override {\n    return ReceiveAllEordMsg() && consumed_rs_.available_regst_desc_cnt() == 0;\n  }\n  void AsyncReturnAllCustomizedReadableRegst() override;\n  std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName()\n      override {\n    return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{});\n  }\n  void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override;\n  void AsyncSendCustomizedConsumedRegstMsgToProducer() override;\n\n  int64_t GetCurProcessedRegstDescId() const;\n\n  RegstSlot consumed_rs_;\n  HashMap<int64_t, bool> regst_desc_id2is_processed_;\n  int64_t processed_regst_desc_id_cnt_;\n  int64_t cur_processed_regst_desc_id_;\n\n  HashMap<int64_t, int64_t> regst_desc_id2in_bn_id_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_LAZY_ACTOR_INPUT_WISE_ACTOR_H_\n"
  },
  {
    "path": "oneflow/core/lazy/actor/light_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/lazy/actor/actor_base.h\"\n#include \"oneflow/core/register/register.h\"\n#include \"oneflow/core/kernel/kernel_context.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/job/id_manager.h\"\n#include \"oneflow/core/register/register_manager.h\"\n#include \"oneflow/core/lazy/actor/actor_message.h\"\n#include \"oneflow/core/lazy/actor/actor_message_bus.h\"\n#include \"oneflow/core/thread/thread.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/core/job/runtime_job_descs.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n#include \"oneflow/core/kernel/user_kernel.h\"\n#include \"oneflow/core/lazy/stream_context/include/stream_context.h\"\n\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\n#endif  // WITH_CUDA\n\nnamespace oneflow {\n\nnamespace {\n\nenum RegstType : int8_t {\n  kInvalid = 0,\n  kProduced,\n  kConsumed,\n};\n\ntemplate<typename IndexType>\nstruct ProducedRegstState {\n  IndexType reading_cnt;\n  IndexType max_reading_cnt;\n};\n\nstruct ConsumedRegstState {\n  bool ready;\n  bool eord;\n};\n\ntemplate<typename IndexType>\nstruct RegstState {\n  Regst* regst;\n  RegstType regst_type;\n  union {\n    ProducedRegstState<IndexType> produced;\n    ConsumedRegstState consumed;\n  };\n};\n\nstruct KernelInfo {\n  std::unique_ptr<const Kernel> kernel;\n  HashMap<std::string, Blob*> bn_in_op2blob;\n  std::shared_ptr<KernelState> state;\n};\n\nstruct DebugInfo {\n  int64_t actor_id;\n  std::string op_name;\n  int64_t act_cnt;\n  DebugInfo() : actor_id(-1), op_name(\"\"), act_cnt(-1) {}\n};\n\ntemplate<typename IndexType, int max_size>\nstruct ArrayBaseIndex {\n  ArrayBaseIndex() { std::memset(this, 0, sizeof(*this)); }\n\n  inline IndexType Size() const { return size; }\n\n  void Reserve(IndexType new_size) { CHECK_LE(new_size, max_size); }\n\n  inline IndexType Lookup(int64_t v) const {\n    for (IndexType i = 0; i < size; ++i) {\n      if (arr[i] == v) { return i; }\n    }\n    CHECK(false);\n    return -1;\n  }\n\n  bool Contains(int64_t v) const {\n    for (IndexType i = 0; i < size; ++i) {\n      if (arr[i] == v) { return true; }\n    }\n    return false;\n  }\n\n  IndexType Add(int64_t v) {\n    CHECK_LT(size, max_size);\n    const IndexType index = size;\n    size += 1;\n    arr[index] = v;\n    return index;\n  }\n\n  void GetValues(std::vector<int64_t>* values) const {\n    values->resize(size);\n    for (IndexType i = 0; i < size; ++i) { values->at(i) = arr[i]; }\n  }\n\n  std::array<int64_t, max_size> arr;\n  IndexType size;\n};\n\ntemplate<typename IndexType>\nstruct MapBaseIndex {\n  inline IndexType Size() const { return index_map.size(); }\n\n  void Reserve(IndexType size) { index_map.reserve(size); }\n\n  inline IndexType Lookup(int64_t v) {\n    auto it = index_map.find(v);\n    CHECK(it != index_map.end());\n    
return it->second;\n  }\n\n  bool Contains(int64_t v) { return index_map.count(v) > 0; }\n\n  IndexType Add(int64_t v) {\n    const IndexType index = index_map.size();\n    CHECK(index_map.emplace(v, index).second);\n    return index;\n  }\n\n  void GetValues(std::vector<int64_t>* values) const {\n    values->resize(index_map.size());\n    for (const auto& pair : index_map) { values->at(pair.second) = pair.first; }\n  }\n\n  HashMap<int64_t, IndexType> index_map;\n};\n\ntemplate<typename IndexType, int max_size>\nstruct ArrayBaseStateContainer {\n  ArrayBaseStateContainer() { std::memset(this, 0, sizeof(*this)); }\n\n  void Resize(IndexType new_size) {\n    CHECK_LE(new_size, max_size);\n    size = new_size;\n  }\n\n  inline IndexType Size() const { return size; }\n\n  inline RegstState<IndexType>& Get(IndexType index) {\n    CHECK_LT(index, size);\n    return arr[index];\n  }\n\n  std::array<RegstState<IndexType>, max_size> arr;\n  IndexType size;\n};\n\ntemplate<typename IndexType>\nstruct VectorBaseStateContainer {\n  void Resize(IndexType new_size) { vec.resize(new_size); }\n\n  inline IndexType Size() const { return static_cast<IndexType>(vec.size()); }\n\n  inline RegstState<IndexType>& Get(IndexType index) { return vec.at(index); }\n\n  std::vector<RegstState<IndexType>> vec;\n};\n\nbool IsInplaceRegstDesc(const RegstDescProto& regst_desc) {\n  return regst_desc.has_inplace_consumed_regst_desc_id() && regst_desc.consumer_task_id_size() > 0;\n}\n\nsize_t GetRegstDescCount(const TaskProto& task) {\n  size_t regst_cnt = task.produced_regst_desc().size();\n  for (const auto& pair : task.consumed_regst_desc_id()) {\n    regst_cnt += pair.second.regst_desc_id_size();\n  }\n  return regst_cnt;\n}\n\nsize_t GetConsumerCount(const TaskProto& task) {\n  size_t consumer_cnt = 0;\n  for (const auto& pair : task.produced_regst_desc()) {\n    consumer_cnt += pair.second.consumer_task_id_size();\n  }\n  return consumer_cnt;\n}\n\nbool NeedExecKernelWhenInplace(const TaskProto& task) {\n  int64_t data_regst_cnt = 0;\n  for (const auto& pair : task.produced_regst_desc()) {\n    if (pair.second.regst_desc_type().has_data_regst_desc()) {\n      if (data_regst_cnt != 0) { return true; }\n      data_regst_cnt += 1;\n      const DataRegstDesc& regst_desc = pair.second.regst_desc_type().data_regst_desc();\n      if (regst_desc.lbi2blob_desc().size() != 1) { return true; }\n      if (regst_desc.lbi2blob_desc().begin()->blob_desc().is_dynamic()) { return true; }\n    }\n  }\n  if (data_regst_cnt != 1) { return true; }\n  if (task.exec_sequence().exec_node().size() != 1) { return true; }\n  const OperatorConf& op_conf =\n      task.exec_sequence().exec_node(0).kernel_conf().op_attribute().op_conf();\n  if (!op_conf.has_user_conf()) { return true; }\n  const std::string& op_type = op_conf.user_conf().op_type_name();\n  const bool is_const_inplace_op_type = (op_type == \"expand_dims\") || (op_type == \"squeeze\")\n                                        || (op_type == \"reshape\") || (op_type == \"reshape_like\")\n                                        || (op_type == \"transpose\");\n  if (!is_const_inplace_op_type) { return true; }\n  return false;\n}\n\n#ifdef WITH_CUDA_GRAPHS\n\nbool IsCUDAGraphSupported(const Kernel* kernel) {\n  auto* user_kernel = dynamic_cast<const UserKernel*>(kernel);\n  return (user_kernel != nullptr && user_kernel->IsCudaGraphSupported());\n}\n\n#endif  // WITH_CUDA_GRAPHS\n\ntemplate<int exec_kernel, int inplace, typename IndexType, typename RegstIndex,\n         typename 
StateContainer, bool dynamic_allocation, int debug>\nclass LightActor : public ActorBase, public KernelContext, public ActorContextProvider {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LightActor);\n  explicit LightActor(ActorContext* actor_ctx)\n      : thread_(nullptr),\n        actor_ctx_(actor_ctx),\n        stream_ctx_(actor_ctx->stream_ctx()),\n        stream_kernel_observer_(nullptr) {\n    auto* kernel_observer_provider = dynamic_cast<KernelObserverProvider*>(stream_ctx_);\n    if (kernel_observer_provider != nullptr) {\n      stream_kernel_observer_ = kernel_observer_provider->GetKernelObserver();\n    }\n  }\n  ~LightActor() override {\n    for (IndexType i = 0; i < index2state_.Size(); ++i) {\n      auto& state = index2state_.Get(i);\n      if (state.regst_type == RegstType::kProduced) { delete state.regst; }\n    }\n  }\n\n  void Init(const JobDesc* job_desc, ActorContext* actor_ctx) override {\n    const TaskProto& task_proto = actor_ctx->task_proto();\n    CHECK_EQ(task_proto.exec_sequence().exec_node_size(), 1);\n    if (debug) {\n      debug_info_[0].reset(new DebugInfo());\n      debug_info_[0]->op_name =\n          task_proto.exec_sequence().exec_node(0).kernel_conf().op_attribute().op_conf().name();\n      debug_info_[0]->actor_id = task_proto.task_id();\n      debug_info_[0]->act_cnt = 0;\n    }\n    if (exec_kernel) {\n      kernel_info_[0].reset(new KernelInfo());\n      const KernelConf& kernel_conf = task_proto.exec_sequence().exec_node(0).kernel_conf();\n      kernel_info_[0]->kernel = ConstructKernel(kernel_conf, this);\n#ifdef WITH_CUDA_GRAPHS\n      auto* cuda_stream = dynamic_cast<ep::CudaStream*>(actor_ctx->stream_ctx()->stream());\n      if (cuda_stream != nullptr && kernel_conf.all_blobs_are_static()\n          && IsCUDAGraphSupported(kernel_info_[0]->kernel.get())) {\n        cuda_graph_exec_[0].reset(new ep::CudaGraphExecutable());\n      }\n#endif\n    }\n    const int64_t thrd_id = ThrdId4ActorId(task_proto.task_id());\n    thread_ = Singleton<ThreadMgr>::Get()->GetThrd(thrd_id);\n    total_reading_cnt_ = 0;\n    max_total_reading_cnt_ = 0;\n    remaining_eord_cnt_ = 0;\n    ready_consumed_ = 0;\n    max_ready_consumed_ = 0;\n\n    const IndexType regst_cnt = GetRegstDescCount(task_proto);\n    regst_desc_id_index_.Reserve(regst_cnt);\n    index2state_.Resize(regst_cnt);\n\n    IndexType inplace_produced_index = -1;\n    IndexType inplace_consumed_index = -1;\n    int64_t inplace_consumed_regst_desc_id = -1;\n\n    for (const auto& pair : task_proto.produced_regst_desc()) {\n      const RegstDescProto& regst_desc = pair.second;\n      if (IsInplaceRegstDesc(regst_desc)) {\n        CHECK_EQ(inplace_consumed_regst_desc_id, -1);\n        inplace_consumed_regst_desc_id = regst_desc.inplace_consumed_regst_desc_id();\n      }\n    }\n\n    for (const auto& pair : task_proto.consumed_regst_desc_id()) {\n      for (int64_t regst_desc_id : pair.second.regst_desc_id()) {\n        const IndexType index = regst_desc_id_index_.Add(regst_desc_id);\n        auto& state = index2state_.Get(index);\n        state.regst_type = RegstType::kConsumed;\n        state.consumed.ready = false;\n        state.consumed.eord = false;\n        remaining_eord_cnt_ += 1;\n        max_ready_consumed_ += 1;\n        if (regst_desc_id == inplace_consumed_regst_desc_id) { inplace_consumed_index = index; }\n      }\n    }\n\n    for (const auto& pair : task_proto.produced_regst_desc()) {\n      const RegstDescProto& regst_desc = pair.second;\n      const IndexType index = 
regst_desc_id_index_.Add(regst_desc.regst_desc_id());\n      auto& state = index2state_.Get(index);\n\n      Singleton<RegstMgr>::Get()->NewRegsts(regst_desc, [&state](Regst* regst) {\n        CHECK(state.regst == nullptr);\n        state.regst = regst;\n      });\n      state.produced.max_reading_cnt = regst_desc.consumer_task_id_size();\n      state.regst_type = RegstType::kProduced;\n      state.produced.reading_cnt = 0;\n      max_total_reading_cnt_ += state.produced.max_reading_cnt;\n      if (IsInplaceRegstDesc(regst_desc)) {\n        CHECK_EQ(inplace_produced_index, -1);\n        inplace_produced_index = index;\n      }\n    }\n\n    if (inplace) {\n      CHECK_NE(inplace_produced_index, -1);\n      CHECK_NE(inplace_consumed_index, -1);\n      inplace_produced_index_[0] = inplace_produced_index;\n      inplace_consumed_index_[0] = inplace_consumed_index;\n    } else {\n      CHECK_EQ(inplace_produced_index, -1);\n      CHECK_EQ(inplace_consumed_index, -1);\n    }\n  }\n\n  int ProcessMsg(const ActorMsg& msg) override {\n    HandleActorMsg(msg);\n    if (debug) {\n      LOG(INFO) << \" Actor: \" << debug_info_[0]->actor_id << \" op: \" << debug_info_[0]->op_name\n                << \" in act_cnt: [ \" << debug_info_[0]->act_cnt\n                << \" ]  IsWriteReady: \" << (total_reading_cnt_ == 0)\n                << \" IsReadReady: \" << (ready_consumed_ == max_ready_consumed_)\n                << \" \\n details: { total_reading_cnt = \" << static_cast<int64_t>(total_reading_cnt_)\n                << \" (expect: 0) , ready_consumed_ = \" << static_cast<int64_t>(ready_consumed_)\n                << \" (expect: \" << static_cast<int64_t>(max_ready_consumed_) << \") }\";\n    }\n\n    if (total_reading_cnt_ != 0) { return 0; }\n    if (ready_consumed_ == max_ready_consumed_) {\n      ActOnce();\n      return 0;\n    }\n    if (OF_PREDICT_FALSE(ready_consumed_ == 0 && remaining_eord_cnt_ == 0)) {\n      SendEORDMsg();\n      return 1;\n    }\n    return 0;\n  }\n\n private:\n  void InitBnInOp2Blob() {\n    if (exec_kernel) {\n      const ExecNodeProto& node = actor_ctx_->task_proto().exec_sequence().exec_node(0);\n      for (auto& pair : node.kernel_conf().op_attribute().arg_signature().bn_in_op2lbi()) {\n        const std::string& bn = pair.first;\n        auto regst_desc_id_it = node.bn_in_op2regst_desc_id().find(bn);\n        if (regst_desc_id_it == node.bn_in_op2regst_desc_id().end()) {\n          CHECK(kernel_info_[0]->bn_in_op2blob.emplace(bn, nullptr).second);\n          continue;\n        }\n        if (!regst_desc_id_index_.Contains(regst_desc_id_it->second)) {\n          CHECK(kernel_info_[0]->bn_in_op2blob.emplace(bn, nullptr).second);\n          continue;\n        }\n        Regst* regst =\n            index2state_.Get(regst_desc_id_index_.Lookup(regst_desc_id_it->second)).regst;\n        if (regst == nullptr) {\n          LOG(WARNING) << \"null regst found, op: \"\n                       << node.kernel_conf().op_attribute().op_conf().name();\n          CHECK(kernel_info_[0]->bn_in_op2blob.emplace(bn, nullptr).second);\n          continue;\n        }\n        Blob* blob = regst->GetBlobByLbi(pair.second);\n        if (!blob) {\n          LOG(WARNING) << \"null blob found, op: \"\n                       << node.kernel_conf().op_attribute().op_conf().name();\n        }\n        CHECK(kernel_info_[0]->bn_in_op2blob.emplace(bn, blob).second);\n      }\n    }\n  }\n\n  void InitActMsg() {\n    const bool is_kernel_launch_synchronized =\n        (!exec_kernel) || 
kernel_info_[0]->kernel->IsKernelLaunchSynchronized();\n    const int64_t actor_id = actor_ctx_->task_proto().task_id();\n    const int64_t thrd_id = ThrdId4ActorId(actor_id);\n    auto IsSyncMsg = [&](const ActorMsg& msg) {\n      return is_kernel_launch_synchronized && thrd_id == ThrdId4ActorId(msg.dst_actor_id());\n    };\n    auto EnqueueActorMsg = [&](const ActorMsg& msg) {\n      if (IsSyncMsg(msg)) {\n        sync_post_act_msgs_.emplace_back(msg);\n      } else {\n        async_post_act_msgs_.emplace_back(msg);\n      }\n    };\n    std::vector<int64_t> index2regst_desc_id;\n    regst_desc_id_index_.GetValues(&index2regst_desc_id);\n    for (IndexType i = 0; i < index2state_.Size(); ++i) {\n      const auto& state = index2state_.Get(i);\n      if (state.regst_type == RegstType::kProduced) {\n        for (int64_t consumer : state.regst->consumers_actor_id()) {\n          EnqueueActorMsg(ActorMsg::BuildRegstMsgToConsumer(actor_id, consumer, state.regst));\n        }\n      } else if (state.regst_type == RegstType::kConsumed) {\n        const int64_t regst_desc_id = index2regst_desc_id.at(i);\n        int64_t producer = -1;\n        if (Singleton<RegstMgr>::Get()->HasProducerTaskId4RegstDescId(regst_desc_id)) {\n          producer = Singleton<RegstMgr>::Get()->ProducerTaskId4RegstDescId(regst_desc_id);\n        } else {\n          producer = state.regst->producer_actor_id();\n        }\n        ActorMsg msg = ActorMsg::BuildRegstMsgToProducer(actor_id, producer, state.regst);\n        if (inplace && i == inplace_consumed_index_[0]) {\n          if (IsSyncMsg(msg)) {\n            return_inplace_consumed_fn_[0] = [this, msg]() { thread_->EnqueueActorMsg(msg); };\n          } else {\n            return_inplace_consumed_fn_[0] = [this, msg]() {\n              actor_ctx_->AddCallback([msg] { Singleton<ActorMsgBus>::Get()->SendMsg(msg); });\n            };\n          }\n        } else {\n          EnqueueActorMsg(msg);\n        }\n      } else {\n        UNIMPLEMENTED();\n      }\n    }\n  }\n\n  inline void ResetState() {\n    total_reading_cnt_ = max_total_reading_cnt_;\n    ready_consumed_ = 0;\n    for (IndexType i = 0; i < index2state_.Size(); ++i) {\n      auto& state = index2state_.Get(i);\n      if (state.regst_type == RegstType::kProduced) {\n        state.produced.reading_cnt = state.produced.max_reading_cnt;\n        if (dynamic_allocation && state.produced.max_reading_cnt == 0\n            && state.regst->regst_desc()->regst_desc_type().has_data_regst_desc()) {\n          if (state.regst->allocation_type() == RegstAllocationType::kStreamOrdered) {\n            if (inplace && i == inplace_produced_index_[0]) {\n              // do nothing\n            } else {\n              CHECK_JUST(\n                  actor_ctx_->stream_ctx()->stream()->FreeAsync(state.regst->body_mem_ptr()));\n            }\n            state.regst->ResetBodyMemPtr(nullptr);\n          } else if (state.regst->allocation_type() == RegstAllocationType::kStatic) {\n            // do nothing\n          } else {\n            UNIMPLEMENTED();\n          }\n        }\n      } else if (state.regst_type == RegstType::kConsumed) {\n        state.consumed.ready = false;\n      } else {\n        UNIMPLEMENTED();\n      }\n    }\n  }\n\n  inline void HandleActorMsg(const ActorMsg& msg) {\n    if (OF_PREDICT_TRUE(msg.msg_type() == ActorMsgType::kRegstMsg)) {\n      HandleRegstMsg(msg);\n    } else if (msg.msg_type() == ActorMsgType::kEordMsg) {\n      HandleEordMsg(msg);\n    } else if (msg.msg_type() == 
ActorMsgType::kCmdMsg) {\n      CHECK_EQ(msg.actor_cmd(), ActorCmd::kStart);\n    } else {\n      UNIMPLEMENTED() << msg.msg_type() << \" \" << actor_ctx_->task_proto().task_id();\n    }\n  }\n\n  void HandleEordMsg(const ActorMsg& msg) {\n    const IndexType index = regst_desc_id_index_.Lookup(msg.eord_regst_desc_id());\n    auto& state = index2state_.Get(index);\n    CHECK_EQ(state.regst_type, RegstType::kConsumed);\n    CHECK_EQ(state.consumed.eord, false);\n    state.consumed.eord = true;\n    CHECK_GT(remaining_eord_cnt_, 0);\n    remaining_eord_cnt_ -= 1;\n  }\n\n  inline void HandleRegstMsg(const ActorMsg& msg) {\n    int64_t regst_desc_id = msg.regst_desc_id();\n    if (regst_desc_id == -1) { regst_desc_id = msg.regst()->regst_desc_id(); }\n\n    if (debug) {\n      LOG(INFO) << \" Actor: \" << debug_info_[0]->actor_id << \" op: \" << debug_info_[0]->op_name\n                << \" in act_cnt: [ \" << debug_info_[0]->act_cnt\n                << \" ] , Recv ActorMsg from: \" << msg.src_actor_id()\n                << \" to: \" << msg.dst_actor_id() << \" with regst: \" << regst_desc_id;\n    }\n\n    const IndexType index = regst_desc_id_index_.Lookup(regst_desc_id);\n    auto& state = index2state_.Get(index);\n    if (state.regst_type == RegstType::kProduced) {\n      CHECK_GT(state.produced.reading_cnt, 0);\n      state.produced.reading_cnt -= 1;\n      CHECK_GT(total_reading_cnt_, 0);\n      total_reading_cnt_ -= 1;\n\n      if (dynamic_allocation && state.produced.reading_cnt == 0\n          && state.regst->regst_desc()->regst_desc_type().has_data_regst_desc()) {\n        if (state.regst->allocation_type() == RegstAllocationType::kStreamOrdered) {\n          if (inplace && index == inplace_produced_index_[0]) {\n            // do nothing\n          } else {\n            CHECK_JUST(actor_ctx_->stream_ctx()->stream()->FreeAsync(state.regst->body_mem_ptr()));\n          }\n          state.regst->ResetBodyMemPtr(nullptr);\n        } else if (state.regst->allocation_type() == RegstAllocationType::kStatic) {\n          // do nothing\n        } else {\n          UNIMPLEMENTED();\n        }\n      }\n\n      if (inplace && index == inplace_produced_index_[0] && state.produced.reading_cnt == 0) {\n        return_inplace_consumed_fn_[0]();\n      }\n    } else if (state.regst_type == RegstType::kConsumed) {\n      CHECK_EQ(state.consumed.ready, false);\n      CHECK_EQ(state.consumed.eord, false);\n      if (state.regst == nullptr) {\n        state.regst = msg.regst();\n      } else {\n        CHECK(state.regst == msg.regst());\n      }\n      // Mark the input ready; ResetState() clears the flag after each act, and the\n      // CHECK above catches a duplicate delivery within one act.\n      state.consumed.ready = true;\n      ready_consumed_ += 1;\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n\n  inline void ActOnce() {\n    if (OF_PREDICT_FALSE(sync_post_act_msgs_.empty() && async_post_act_msgs_.empty())) {\n      InitBnInOp2Blob();\n      InitActMsg();\n    }\n\n    for (IndexType i = 0; i < index2state_.Size(); ++i) {\n      auto& state = index2state_.Get(i);\n      if (dynamic_allocation && state.regst_type == RegstType::kProduced\n          && state.regst->regst_desc()->regst_desc_type().has_data_regst_desc()) {\n        if (state.regst->allocation_type() == RegstAllocationType::kStreamOrdered) {\n          CHECK(state.regst->body_mem_ptr() == nullptr);\n          void* body_ptr = nullptr;\n          if (inplace && i == inplace_produced_index_[0]) {\n            body_ptr = index2state_.Get(inplace_consumed_index_[0]).regst->body_mem_ptr();\n          } else {\n            CHECK_JUST(actor_ctx_->stream_ctx()->stream()->AllocAsync(\n                &body_ptr, 
state.regst->regst_desc()->BodyByteSize4OneRegst()));\n          }\n          state.regst->ResetBodyMemPtr(body_ptr);\n        } else if (state.regst->allocation_type() == RegstAllocationType::kStatic) {\n          // do nothing\n        } else {\n          UNIMPLEMENTED();\n        }\n      }\n    }\n\n    if (debug) {\n      LOG(INFO) << \" Actor: \" << debug_info_[0]->actor_id << \" op: \" << debug_info_[0]->op_name\n                << \" Try to act act_cnt: [ \" << debug_info_[0]->act_cnt\n                << \" ] before launch kernel.\";\n    }\n\n    if (exec_kernel) { LaunchKernel(); }\n\n    ResetState();\n    thread_->EnqueueActorMsg(sync_post_act_msgs_.cbegin(), sync_post_act_msgs_.cend());\n    if (!async_post_act_msgs_.empty()) {\n      actor_ctx_->AddCallback([this]() {\n        for (const auto& msg : async_post_act_msgs_) {\n          Singleton<ActorMsgBus>::Get()->SendMsg(msg);\n        }\n      });\n    }\n\n    if (debug) {\n      for (const auto& msg : sync_post_act_msgs_) {\n        LOG(INFO) << \" Actor: \" << debug_info_[0]->actor_id << \" op: \" << debug_info_[0]->op_name\n                  << \" in act_cnt: [ \" << debug_info_[0]->act_cnt\n                  << \" ] Sync post ActorMsg from: \" << msg.src_actor_id()\n                  << \" to: \" << msg.dst_actor_id() << \" with regst: \" << msg.regst_desc_id();\n      }\n      for (const auto& msg : async_post_act_msgs_) {\n        LOG(INFO) << \" Actor: \" << debug_info_[0]->actor_id << \" op: \" << debug_info_[0]->op_name\n                  << \" in act_cnt: [ \" << debug_info_[0]->act_cnt\n                  << \" ] Async post ActorMsg from: \" << msg.src_actor_id()\n                  << \" to: \" << msg.dst_actor_id() << \" with regst: \" << msg.regst_desc_id();\n      }\n      LOG(INFO) << \" Actor: \" << debug_info_[0]->actor_id << \" op: \" << debug_info_[0]->op_name\n                << \" Finish act act_cnt: [ \" << debug_info_[0]->act_cnt++ << \" ].\";\n    }\n  }\n\n  inline void LaunchKernel() {\n#ifdef WITH_CUDA_GRAPHS\n    bool is_capturing = false;\n    if (cuda_graph_exec_[0]) {\n      auto* cuda_stream = stream_ctx_->stream()->As<ep::CudaStream>();\n      if (cuda_graph_exec_[0]->IsInstantiated()) {\n        cuda_stream->LaunchGraph(cuda_graph_exec_[0].get());\n        return;\n      }\n      auto* user_kernel =\n          CHECK_NOTNULL(dynamic_cast<const UserKernel*>(kernel_info_[0]->kernel.get()));\n      if (user_kernel->IsReadyForCudaGraphCapture(this)) {\n        is_capturing = true;\n        cuda_stream->BeginGraphCapture();\n      }\n    }\n#endif\n    kernel_info_[0]->kernel->Launch(this);\n#ifdef WITH_CUDA_GRAPHS\n    if (cuda_graph_exec_[0] && is_capturing) {\n      auto* cuda_stream = stream_ctx_->stream()->As<ep::CudaStream>();\n      cuda_stream->EndGraphCapture(cuda_graph_exec_[0].get());\n      cuda_stream->LaunchGraph(cuda_graph_exec_[0].get());\n    }\n#endif\n  }\n\n  void SendEORDMsg() {\n    for (IndexType i = 0; i < index2state_.Size(); ++i) {\n      auto& state = index2state_.Get(i);\n      if (state.regst_type != RegstType::kProduced) { continue; }\n      const RtRegstDesc* regst_desc = state.regst->regst_desc();\n      actor_ctx_->AddCallback([regst_desc]() {\n        for (int64_t consumer : regst_desc->consumers_actor_id()) {\n          Singleton<ActorMsgBus>::Get()->SendMsg(\n              ActorMsg::BuildEordMsg(consumer, regst_desc->regst_desc_id()));\n        }\n      });\n    }\n  }\n\n  ep::Stream* stream() const override { return stream_ctx_->stream(); }\n\n  ActorContext* 
GetActorContext() const override { return actor_ctx_; }\n\n  Blob* BnInOp2Blob(const std::string& bn) const override {\n    if (exec_kernel) {\n      auto it = kernel_info_[0]->bn_in_op2blob.find(bn);\n      if (it == kernel_info_[0]->bn_in_op2blob.end()) {\n        return nullptr;\n      } else {\n        return it->second;\n      }\n    } else {\n      return nullptr;\n    }\n  }\n\n  const std::shared_ptr<KernelState>& state() const override {\n    if (exec_kernel) {\n      return kernel_info_[0]->state;\n    } else {\n      static const std::shared_ptr<KernelState> null_state;\n      return null_state;\n    }\n  }\n\n  void set_state(std::shared_ptr<KernelState> state) override {\n    CHECK(exec_kernel);\n    kernel_info_[0]->state = std::move(state);\n  }\n\n  void WillForward(KernelContext* kernel_ctx, const Kernel* kernel) override {\n    Singleton<KernelObserver>::Get()->WillForward(kernel_ctx, kernel);\n    if (stream_kernel_observer_ != nullptr) {\n      stream_kernel_observer_->WillForward(kernel_ctx, kernel);\n    }\n  }\n\n  void DidForward(KernelContext* kernel_ctx, const Kernel* kernel) override {\n    CHECK_JUST_MSG(kernel_ctx->stream()->GetAsyncError(), kernel->op_conf().name());\n    Singleton<KernelObserver>::Get()->DidForward(kernel_ctx, kernel);\n    if (stream_kernel_observer_ != nullptr) {\n      stream_kernel_observer_->DidForward(kernel_ctx, kernel);\n    }\n  }\n\n  void WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override {\n    Singleton<KernelObserver>::Get()->WillForwardHeader(kernel_ctx, kernel);\n    if (stream_kernel_observer_ != nullptr) {\n      stream_kernel_observer_->WillForwardHeader(kernel_ctx, kernel);\n    }\n  }\n\n  void DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override {\n    Singleton<KernelObserver>::Get()->DidForwardHeader(kernel_ctx, kernel);\n    if (stream_kernel_observer_ != nullptr) {\n      stream_kernel_observer_->DidForwardHeader(kernel_ctx, kernel);\n    }\n  }\n\n  void WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override {\n    Singleton<KernelObserver>::Get()->WillForwardDataContent(kernel_ctx, kernel);\n    if (stream_kernel_observer_ != nullptr) {\n      stream_kernel_observer_->WillForwardDataContent(kernel_ctx, kernel);\n    }\n  }\n\n  void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override {\n    Singleton<KernelObserver>::Get()->DidForwardDataContent(kernel_ctx, kernel);\n    if (stream_kernel_observer_ != nullptr) {\n      stream_kernel_observer_->DidForwardDataContent(kernel_ctx, kernel);\n    }\n  }\n\n  RegstIndex regst_desc_id_index_;\n  StateContainer index2state_;\n  IndexType total_reading_cnt_;\n  IndexType ready_consumed_;\n  IndexType max_total_reading_cnt_;\n  IndexType max_ready_consumed_;\n  IndexType remaining_eord_cnt_;\n  IndexType inplace_produced_index_[inplace];\n  IndexType inplace_consumed_index_[inplace];\n  std::function<void()> return_inplace_consumed_fn_[inplace];\n  Thread* thread_;\n  std::unique_ptr<KernelInfo> kernel_info_[exec_kernel];\n#ifdef WITH_CUDA_GRAPHS\n  std::unique_ptr<ep::CudaGraphExecutable> cuda_graph_exec_[exec_kernel];\n#endif\n  ActorContext* actor_ctx_;\n  StreamContext* stream_ctx_;\n  std::vector<ActorMsg> sync_post_act_msgs_;\n  std::vector<ActorMsg> async_post_act_msgs_;\n  KernelObserver* stream_kernel_observer_;\n\n  // for debug\n  std::unique_ptr<DebugInfo> debug_info_[debug];\n};\n\ntemplate<int kernel_exec, int inplace, typename IndexType, typename RegstIndex,\n   
      typename StateContainer, bool dynamic_allocation>\nActorBase* DispatchNewLightActorDebug(ActorContext* actor_ctx) {\n  const bool debug = EnableActorDebugLog();\n  if (debug) {\n    return new LightActor<kernel_exec, inplace, IndexType, RegstIndex, StateContainer,\n                          dynamic_allocation, 1>(actor_ctx);\n  } else {\n    return new LightActor<kernel_exec, inplace, IndexType, RegstIndex, StateContainer,\n                          dynamic_allocation, 0>(actor_ctx);\n  }\n}\n\ntemplate<int kernel_exec, int inplace, typename IndexType, typename RegstIndex,\n         typename StateContainer>\nActorBase* DispatchNewLightActorDynamicAlloc(ActorContext* actor_ctx) {\n  const bool dynamic_allocation =\n      ParseBooleanFromEnv(\"ONEFLOW_GRAPH_ENABLE_STREAM_ORDERED_MEMORY_ALLOCATION\", false);\n  if (dynamic_allocation) {\n    return DispatchNewLightActorDebug<kernel_exec, inplace, IndexType, RegstIndex, StateContainer,\n                                      true>(actor_ctx);\n  } else {\n    return DispatchNewLightActorDebug<kernel_exec, inplace, IndexType, RegstIndex, StateContainer,\n                                      false>(actor_ctx);\n  }\n}\n\ntemplate<int kernel_exec, int inplace, typename IndexType>\nActorBase* DispatchNewLightActorMaxSize(ActorContext* actor_ctx) {\n  const size_t regst_desc_count = GetRegstDescCount(actor_ctx->task_proto());\n  if (regst_desc_count <= 2) {\n    return DispatchNewLightActorDynamicAlloc<kernel_exec, inplace, IndexType,\n                                             ArrayBaseIndex<IndexType, 2>,\n                                             ArrayBaseStateContainer<IndexType, 2>>(actor_ctx);\n  } else if (regst_desc_count <= 4) {\n    return DispatchNewLightActorDynamicAlloc<kernel_exec, inplace, IndexType,\n                                             ArrayBaseIndex<IndexType, 4>,\n                                             ArrayBaseStateContainer<IndexType, 4>>(actor_ctx);\n  } else if (regst_desc_count <= 8) {\n    return DispatchNewLightActorDynamicAlloc<kernel_exec, inplace, IndexType,\n                                             ArrayBaseIndex<IndexType, 8>,\n                                             ArrayBaseStateContainer<IndexType, 8>>(actor_ctx);\n  } else {\n    return DispatchNewLightActorDynamicAlloc<kernel_exec, inplace, IndexType,\n                                             MapBaseIndex<IndexType>,\n                                             VectorBaseStateContainer<IndexType>>(actor_ctx);\n  }\n}\n\ntemplate<int kernel_exec, int inplace>\nActorBase* DispatchNewLightActorIndexType(ActorContext* actor_ctx) {\n  size_t size = std::max(GetRegstDescCount(actor_ctx->task_proto()),\n                         GetConsumerCount(actor_ctx->task_proto()));\n  if (size <= static_cast<size_t>(std::numeric_limits<int8_t>::max())) {\n    return DispatchNewLightActorMaxSize<kernel_exec, inplace, int8_t>(actor_ctx);\n  } else if (size <= static_cast<size_t>(std::numeric_limits<int32_t>::max())) {\n    return DispatchNewLightActorMaxSize<kernel_exec, inplace, int32_t>(actor_ctx);\n  } else {\n    return nullptr;\n  }\n}\n\ntemplate<int kernel_exec>\nActorBase* DispatchNewLightActorInplace(ActorContext* actor_ctx) {\n  const auto& produced_regst_desc = actor_ctx->task_proto().produced_regst_desc();\n  const size_t inplace_produced_regst_cnt =\n      std::count_if(produced_regst_desc.cbegin(), produced_regst_desc.cend(),\n                    [](const PbMapPair<std::string, RegstDescProto>& pair) {\n                      
return pair.second.has_inplace_consumed_regst_desc_id();\n                    });\n  if (inplace_produced_regst_cnt > 1) { return nullptr; }\n  bool inplace = false;\n  for (const auto& pair : produced_regst_desc) {\n    const RegstDescProto& regst_desc = pair.second;\n    if (IsInplaceRegstDesc(regst_desc)) {\n      CHECK_EQ(inplace, false);\n      inplace = true;\n    }\n  }\n  if (inplace) {\n    if (kernel_exec && NeedExecKernelWhenInplace(actor_ctx->task_proto())) {\n      return DispatchNewLightActorIndexType<1, 1>(actor_ctx);\n    } else {\n      return DispatchNewLightActorIndexType<0, 1>(actor_ctx);\n    }\n  } else {\n    return DispatchNewLightActorIndexType<kernel_exec, 0>(actor_ctx);\n  }\n}\n\nActorBase* NewLightActorWithKernel(ActorContext* actor_ctx) {\n  return DispatchNewLightActorInplace<1>(actor_ctx);\n}\n\nActorBase* NewLightActorWithoutKernel(ActorContext* actor_ctx) {\n  return DispatchNewLightActorInplace<0>(actor_ctx);\n}\n\nActorBase* TryNewLightActorWithoutInit(ActorContext* actor_ctx) {\n  const TaskProto& task_proto = actor_ctx->task_proto();\n  if (!task_proto.all_register_num_eq_one_hint()) { return nullptr; }\n  if (task_proto.exec_sequence().exec_node_size() != 1) { return nullptr; }\n  if (task_proto.task_type() == TaskType::kNormalForward) {\n    const OperatorConf& op_conf =\n        task_proto.exec_sequence().exec_node(0).kernel_conf().op_attribute().op_conf();\n    if (op_conf.has_variable_conf()) {\n      return NewLightActorWithoutKernel(actor_ctx);\n    } else {\n      return NewLightActorWithKernel(actor_ctx);\n    }\n  } else if (task_proto.task_type() == TaskType::kCopyHd) {\n    return NewLightActorWithKernel(actor_ctx);\n  } else if (task_proto.task_type() == TaskType::kTick) {\n    return NewLightActorWithoutKernel(actor_ctx);\n  } else if (task_proto.task_type() == TaskType::kCollectiveBoxingGeneric) {\n    return NewLightActorWithKernel(actor_ctx);\n  } else {\n    return nullptr;\n  }\n}\n\n}  // namespace\n\nstd::unique_ptr<ActorBase> TryNewLightActor(ActorContext* actor_ctx) {\n  ActorBase* actor = TryNewLightActorWithoutInit(actor_ctx);\n  if (actor != nullptr) {\n    const auto& job_descs = *Singleton<RuntimeJobDescs>::Get();\n    actor->Init(&job_descs.job_desc(actor_ctx->task_proto().job_id()), actor_ctx);\n  }\n  return std::unique_ptr<ActorBase>(actor);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/light_actor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_LAZY_ACTOR_LIGHT_ACTOR_H_\n#define ONEFLOW_CORE_LAZY_ACTOR_LIGHT_ACTOR_H_\n\n#include \"oneflow/core/lazy/actor/actor_base.h\"\n\nnamespace oneflow {\n\nstd::unique_ptr<ActorBase> TryNewLightActor(ActorContext* ctx);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_LAZY_ACTOR_LIGHT_ACTOR_H_\n"
  },
  {
    "path": "oneflow/core/lazy/actor/naive_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/naive_actor.h\"\n\nnamespace oneflow {\n\nvoid NaiveActor::Act() {\n  AsyncLaunchKernel([&](int64_t regst_desc_id) -> Regst* { return nullptr; });\n}\n\nvoid NaiveActor::VirtualActorInit(const TaskProto&) {\n  OF_SET_MSG_HANDLER(&NaiveActor::HandlerNormal);\n}\n\nREGISTER_ACTOR(TaskType::kNormalForward, NaiveActor);\nREGISTER_ACTOR(TaskType::kDistributeConcat, NaiveActor);\nREGISTER_ACTOR(TaskType::kDistributeSplit, NaiveActor);\nREGISTER_ACTOR(TaskType::kSliceBoxing, NaiveActor);\nREGISTER_ACTOR(TaskType::kBoxingIdentity, NaiveActor);\nREGISTER_ACTOR(TaskType::kCollectiveBoxingPack, NaiveActor);\nREGISTER_ACTOR(TaskType::kCollectiveBoxingUnpack, NaiveActor);\nREGISTER_ACTOR(TaskType::kNcclSendRecvBoxing, NaiveActor);\nREGISTER_ACTOR(TaskType::kDecodeH2D, NaiveActor);\nREGISTER_ACTOR(TaskType::kCriticalSectionWaitTick, NaiveActor);\nREGISTER_ACTOR(TaskType::kCopyHd, NaiveActor);\nREGISTER_ACTOR(TaskType::kCollectiveBoxingGeneric, NaiveActor);\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/naive_actor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_LAZY_ACTOR_NAIVE_ACTOR_H_\n#define ONEFLOW_CORE_LAZY_ACTOR_NAIVE_ACTOR_H_\n\n#include \"oneflow/core/lazy/actor/actor.h\"\n\nnamespace oneflow {\n\nclass NaiveActor : public Actor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NaiveActor);\n  NaiveActor() = default;\n  ~NaiveActor() override = default;\n\n  void VirtualActorInit(const TaskProto&) override;\n\n protected:\n  void Act() override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_LAZY_ACTOR_NAIVE_ACTOR_H_\n"
  },
  {
    "path": "oneflow/core/lazy/actor/pack_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor.h\"\n#include \"oneflow/core/kernel/user_kernel.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n\nnamespace oneflow {\n\nclass PackActor final : public Actor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PackActor);\n  PackActor() = default;\n  ~PackActor() = default;\n\n private:\n  void VirtualActorInit(const TaskProto& proto) override;\n  void Act() override;\n  void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override;\n  void VirtualAsyncSendNaiveConsumedRegstMsgToProducer() override;\n\n  size_t total_pack_num_;\n  size_t act_num_cnt_;\n};\n\nvoid PackActor::VirtualActorInit(const TaskProto& proto) {\n  const Shape& in_time_shape = Singleton<RegstMgr>::Get()\n                                   ->RegstDesc4RegstDescId(Name2SoleRegstDescId(\"in\"))\n                                   .data_regst_time_shape();\n  total_pack_num_ = in_time_shape.At(in_time_shape.NumAxes() - 1);\n  act_num_cnt_ = 0;\n  OF_SET_MSG_HANDLER(&PackActor::HandlerNormal);\n}\n\nvoid PackActor::Act() {\n  CHECK_GE(exec_kernel_vec().size(), 1);\n  auto user_kernel = dynamic_cast<const UserKernel*>(exec_kernel_vec().at(0).kernel.get());\n  CHECK_NOTNULL(user_kernel);\n  auto state = dynamic_cast<OpKernelStateWrapper<std::pair<size_t, size_t>>*>(\n      user_kernel->GetOpKernelState().get());\n  CHECK_NOTNULL(state);\n  state->Mutable()->first = act_num_cnt_;\n  state->Mutable()->second = total_pack_num_;\n  AsyncLaunchKernel();\n  act_num_cnt_ += 1;\n}\n\nvoid PackActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {\n  if (act_num_cnt_ == total_pack_num_) { HandleProducedNaiveDataRegstToConsumer(); }\n}\n\nvoid PackActor::VirtualAsyncSendNaiveConsumedRegstMsgToProducer() {\n  HandleConsumedNaiveDataRegstToProducer();\n  if (act_num_cnt_ == total_pack_num_) { act_num_cnt_ = 0; }\n}\n\nREGISTER_ACTOR(TaskType::kPack, PackActor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/reentrant_lock_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor.h\"\n#include \"oneflow/core/kernel/reentrant_lock_kernel.h\"\n\nnamespace oneflow {\n\nclass ReentrantLockActor final : public Actor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ReentrantLockActor);\n  ReentrantLockActor() : reentrant_lock_status_(nullptr) {}\n  ~ReentrantLockActor() override = default;\n\n protected:\n  void VirtualActorInit(const TaskProto&) override;\n\n private:\n  void Act() override;\n  void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) override;\n  void ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)>) const override;\n  bool IsCustomizedReadReady() const override;\n  void NormalProcessCustomizedEordMsg(const ActorMsg&) override {}\n  bool IsCustomizedReadAlwaysUnReadyFromNow() const override;\n  void AsyncReturnAllCustomizedReadableRegst() override;\n  std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName()\n      override {\n    return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{});\n  }\n  void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override;\n  void AsyncSendCustomizedConsumedRegstMsgToProducer() override;\n  int64_t GetCurProcessedRegstDescId() const;\n\n  const std::string& Ibn4RegstDescId(int64_t id) const;\n\n  RegstSlot consumed_rs_;\n  int64_t cur_processed_regst_desc_id_{};\n  HashMap<int64_t, std::string> regst_desc_id2ibn_;\n  ReentrantLockStatus* reentrant_lock_status_;\n  int64_t eord_regst_desc_id_{};\n  int64_t act_id_{};\n};\n\nvoid ReentrantLockActor::VirtualActorInit(const TaskProto& task_proto) {\n  CHECK_EQ(1, exec_kernel_vec().size());\n  reentrant_lock_status_ = CHECK_NOTNULL(\n      dynamic_cast<ReentrantLockStatus*>(exec_kernel_vec().at(0).kernel_ctx->state().get()));\n  act_id_ = 0;\n  const auto& kernel_conf = task_proto.exec_sequence().exec_node().Get(0).kernel_conf();\n  const auto& ibns = kernel_conf.op_attribute().input_bns();\n  for (const auto& ibn : ibns) {\n    int64_t regst_desc_id = exec_kernel_vec().at(0).bn_in_op2blob_info.at(ibn).regst_desc_id;\n    if (ibn == \"start\") { eord_regst_desc_id_ = regst_desc_id; }\n    CHECK(regst_desc_id2ibn_.emplace(regst_desc_id, ibn).second);\n  }\n  for (const auto& pair : task_proto.consumed_regst_desc_id()) {\n    for (const int64_t regst_desc_id : pair.second.regst_desc_id()) {\n      consumed_rs_.InsertRegstDescId(regst_desc_id);\n    }\n  }\n  consumed_rs_.InitedDone();\n  cur_processed_regst_desc_id_ = -1;\n  reentrant_lock_status_->Init(kernel_conf);\n  OF_SET_MSG_HANDLER(&ReentrantLockActor::HandlerNormal);\n}\n\nvoid ReentrantLockActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) {\n  CHECK_EQ(0, consumed_rs_.TryPushBackRegst(msg.regst()));\n}\n\nbool ReentrantLockActor::IsCustomizedReadReady() const {\n  return reentrant_lock_status_->cur_unlocked_ids().size() > 0\n         || -1 != GetCurProcessedRegstDescId();\n}\n\nvoid 
ReentrantLockActor::ForEachCurCustomizedReadableRegst(\n    std::function<void(const Regst*)> handler) const {\n  handler(consumed_rs_.Front(cur_processed_regst_desc_id_));\n}\n\nconst std::string& ReentrantLockActor::Ibn4RegstDescId(int64_t id) const {\n  const auto iter = regst_desc_id2ibn_.find(id);\n  if (iter == regst_desc_id2ibn_.end()) { return ReentrantLockStatus::kEmptyIbn; }\n  return iter->second;\n}\n\nvoid ReentrantLockActor::Act() {\n  cur_processed_regst_desc_id_ = GetCurProcessedRegstDescId();\n  Regst* const cur_regst = consumed_rs_.Front(cur_processed_regst_desc_id_);\n  reentrant_lock_status_->set_cur_ibn(Ibn4RegstDescId(cur_processed_regst_desc_id_));\n  reentrant_lock_status_->set_cur_act_id(act_id_);\n  act_id_ += 1;\n  AsyncLaunchKernel([&](int64_t regst_desc_id) -> Regst* {\n    if (cur_processed_regst_desc_id_ != regst_desc_id) { return nullptr; }\n    return cur_regst;\n  });\n}\n\nbool ReentrantLockActor::IsCustomizedReadAlwaysUnReadyFromNow() const {\n  return ReceiveEordMsg(eord_regst_desc_id_)\n         && reentrant_lock_status_->total_queued_request_lock_num() == 0\n         && reentrant_lock_status_->total_acquired_lock_num() == 0;\n}\n\nvoid ReentrantLockActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {\n  if (reentrant_lock_status_->acquired_lock_to_be_sent() == false) { return; }\n  HandleProducedNaiveDataRegstToConsumer();\n}\n\nvoid ReentrantLockActor::AsyncSendCustomizedConsumedRegstMsgToProducer() {\n  Regst* const cur_regst = consumed_rs_.Front(cur_processed_regst_desc_id_);\n  if (cur_regst == nullptr) { return; }\n  AsyncSendRegstMsgToProducer(cur_regst);\n  CHECK_EQ(0, consumed_rs_.TryPopFrontRegst(cur_processed_regst_desc_id_));\n  cur_processed_regst_desc_id_ = -1;\n}\n\nvoid ReentrantLockActor::AsyncReturnAllCustomizedReadableRegst() {\n  CHECK_EQ(-1, cur_processed_regst_desc_id_);\n  CHECK_EQ(0, consumed_rs_.available_regst_desc_cnt());\n}\n\nint64_t ReentrantLockActor::GetCurProcessedRegstDescId() const {\n  int64_t cur_processed_regst_desc_id = -1;\n  consumed_rs_.ForChosenRegstDeq(\n      [&cur_processed_regst_desc_id](int64_t) { return cur_processed_regst_desc_id == -1; },\n      [&cur_processed_regst_desc_id](const std::deque<Regst*>& reg_deq) {\n        if (reg_deq.empty()) { return; }\n        cur_processed_regst_desc_id = reg_deq.front()->regst_desc_id();\n      });\n  return cur_processed_regst_desc_id;\n}\n\nREGISTER_ACTOR(kReentrantLock, ReentrantLockActor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/register_slot.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/register_slot.h\"\n\nnamespace oneflow {\n\nint64_t RegstSlot::GetReadyRegstSize(int64_t regst_desc_id) const {\n  CHECK(is_inited_);\n  auto it = regst_desc_id2regsts_.find(regst_desc_id);\n  if (it == regst_desc_id2regsts_.end()) { return -1; }\n  return it->second.size();\n}\n\nbool RegstSlot::HasRegstDescId(int64_t regst_desc_id) const {\n  CHECK(is_inited_);\n  return regst_desc_id2regsts_.find(regst_desc_id) != regst_desc_id2regsts_.end();\n}\n\nconst std::deque<Regst*>& RegstSlot::RegstDeq4RegstDescId(int64_t regst_desc_id) const {\n  CHECK(is_inited_);\n  return regst_desc_id2regsts_.at(regst_desc_id);\n}\n\nint RegstSlot::TryPushBackRegst(Regst* regst) {\n  return TryPushBackRegst(regst, regst->regst_desc_id());\n}\n\nint RegstSlot::TryPushBackRegst(Regst* regst, int64_t regst_desc_id) {\n  CHECK(is_inited_);\n  auto it = regst_desc_id2regsts_.find(regst_desc_id);\n  if (it == regst_desc_id2regsts_.end()) { return -1; }\n  if (it->second.empty()) { available_regst_desc_cnt_ += 1; }\n  it->second.emplace_back(regst);\n  return 0;\n}\n\nint RegstSlot::TryPopFrontRegst(int64_t regst_desc_id) {\n  CHECK(is_inited_);\n  auto it = regst_desc_id2regsts_.find(regst_desc_id);\n  if (it == regst_desc_id2regsts_.end()) { return -1; }\n  CHECK(it->second.empty() == false);\n  it->second.pop_front();\n  if (it->second.empty()) { available_regst_desc_cnt_ -= 1; }\n  return 0;\n}\n\nvoid RegstSlot::PopFrontRegsts(const std::vector<int64_t>& regst_desc_ids) {\n  CHECK(is_inited_);\n  for (int64_t regst_desc_id : regst_desc_ids) { CHECK_EQ(0, TryPopFrontRegst(regst_desc_id)); }\n}\n\nvoid RegstSlot::InsertRegstDescId(int64_t regst_desc_id) {\n  CHECK(is_inited_ == false);\n  CHECK(regst_desc_id2regsts_.emplace(regst_desc_id, std::deque<Regst*>()).second);\n}\n\nRegst* RegstSlot::Front(int64_t regst_desc_id) const {\n  CHECK(is_inited_);\n  auto it = regst_desc_id2regsts_.find(regst_desc_id);\n  if (it == regst_desc_id2regsts_.end()) { return nullptr; }\n  if (it->second.empty()) { return nullptr; }\n  return it->second.front();\n}\n\nRegst* RegstSlot::SoleFront() const {\n  CHECK(is_inited_);\n  CHECK_EQ(1, total_regst_desc_cnt());\n  auto it = regst_desc_id2regsts_.begin();\n  if (it->second.empty()) { return nullptr; }\n  return it->second.front();\n}\n\nRegst* RegstSlot::FirstFront() const {\n  CHECK(is_inited_);\n  CHECK_GE(total_regst_desc_cnt(), 1);\n  auto it = regst_desc_id2regsts_.begin();\n  if (it->second.empty()) { return nullptr; }\n  return it->second.front();\n}\n\nvoid RegstSlot::InitedDone() {\n  CHECK(is_inited_ == false);\n  is_inited_ = true;\n}\n\nvoid RegstSlot::ForChosenFrontRegst(const std::function<bool(int64_t)>& IsChosenRegstDescId,\n                                    const std::function<void(Regst*)>& Handler) const {\n  for (const auto& kv : regst_desc_id2regsts_) {\n    if (IsChosenRegstDescId(kv.first)) {\n      
CHECK(kv.second.empty() == false);\n      Handler(kv.second.front());\n    }\n  }\n}\n\nvoid RegstSlot::ForChosenFrontRegst(\n    const std::function<bool(int64_t)>& IsChosenRegstDescId,\n    const std::function<void(int64_t regst_desc_id, Regst*)>& Handler) const {\n  for (const auto& kv : regst_desc_id2regsts_) {\n    if (IsChosenRegstDescId(kv.first)) {\n      CHECK(kv.second.empty() == false);\n      Handler(kv.first, kv.second.front());\n    }\n  }\n}\n\nvoid RegstSlot::ForChosenRegstDeq(\n    const std::function<bool(int64_t)>& IsChosenRegstDescId,\n    const std::function<void(const std::deque<Regst*>&)>& Handler) const {\n  for (const auto& kv : regst_desc_id2regsts_) {\n    if (IsChosenRegstDescId(kv.first)) { Handler(kv.second); }\n  }\n}\n\nvoid RegstSlot::ForChosenRegstDeq(\n    const std::function<bool(int64_t)>& IsChosenRegstDescId,\n    const std::function<void(int64_t regst_desc_id, const std::deque<Regst*>&)>& Handler) const {\n  for (const auto& kv : regst_desc_id2regsts_) {\n    if (IsChosenRegstDescId(kv.first)) { Handler(kv.first, kv.second); }\n  }\n}\n\nvoid RegstSlot::ForEachFrontRegst(const std::function<void(Regst*)>& Handler) const {\n  ForChosenFrontRegst([](int64_t) { return true; }, Handler);\n}\n\nvoid RegstSlot::ForEachFrontRegst(\n    const std::function<void(int64_t regst_desc_id, Regst*)>& Handler) const {\n  ForChosenFrontRegst([](int64_t) { return true; }, Handler);\n}\n\nvoid RegstSlot::ForEachRegstDeq(\n    const std::function<void(const std::deque<Regst*>&)>& Handler) const {\n  ForChosenRegstDeq([](int64_t) { return true; }, Handler);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/register_slot.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_LAZY_ACTOR_REGISTER_SLOT_H_\n#define ONEFLOW_CORE_LAZY_ACTOR_REGISTER_SLOT_H_\n\n#include \"oneflow/core/register/register_manager.h\"\n\nnamespace oneflow {\n\nclass RegstSlot final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RegstSlot);\n  RegstSlot() : regst_desc_id2regsts_(), available_regst_desc_cnt_(0), is_inited_(false) {}\n  ~RegstSlot() = default;\n\n  bool is_inited() const { return is_inited_; }\n  size_t total_regst_desc_cnt() const { return regst_desc_id2regsts_.size(); }\n  size_t available_regst_desc_cnt() const { return available_regst_desc_cnt_; }\n\n  int64_t GetReadyRegstSize(int64_t regst_desc_id) const;\n  bool IsCurSlotReady() const { return available_regst_desc_cnt() == total_regst_desc_cnt(); }\n  bool HasRegstDescId(int64_t regst_desc_id) const;\n  const std::deque<Regst*>& RegstDeq4RegstDescId(int64_t regst_desc_id) const;\n  void ForEachFrontRegst(const std::function<void(Regst*)>&) const;\n  void ForEachFrontRegst(const std::function<void(int64_t regst_desc_id, Regst*)>&) const;\n  void ForEachRegstDeq(const std::function<void(const std::deque<Regst*>&)>&) const;\n  void ForChosenFrontRegst(const std::function<bool(int64_t)>&,\n                           const std::function<void(Regst*)>&) const;\n  void ForChosenFrontRegst(const std::function<bool(int64_t)>&,\n                           const std::function<void(int64_t regst_desc_id, Regst*)>&) const;\n  void ForChosenRegstDeq(const std::function<bool(int64_t)>&,\n                         const std::function<void(const std::deque<Regst*>&)>&) const;\n  void ForChosenRegstDeq(\n      const std::function<bool(int64_t)>&,\n      const std::function<void(int64_t regst_desc_id, const std::deque<Regst*>&)>&) const;\n\n  Regst* Front(int64_t regst_desc_id) const;\n  Regst* SoleFront() const;\n  Regst* FirstFront() const;\n\n  // 0: success, -1: cannot find regst_desc_id\n  int TryPushBackRegst(Regst* regst);\n  int TryPushBackRegst(Regst* regst, int64_t regst_desc_id);\n  int TryPopFrontRegst(int64_t regst_desc_id);\n\n  void PopFrontRegsts(const std::vector<int64_t>& regst_desc_ids);\n\n  void InitedDone();\n  void InsertRegstDescId(int64_t regst_desc_id);\n\n private:\n  HashMap<int64_t, std::deque<Regst*>> regst_desc_id2regsts_;\n  size_t available_regst_desc_cnt_;\n  bool is_inited_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_LAZY_ACTOR_REGISTER_SLOT_H_\n"
  },
  {
    "path": "oneflow/core/lazy/actor/repeat_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nclass RepeatActor final : public Actor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RepeatActor);\n  RepeatActor()\n      : repeat_count_(0),\n        repeat_num_(0),\n        wait_all_regst_return_(false),\n        consumed_var_regst_desc_id_(-1),\n        produced_repeat_var_regst_desc_id_(-1){};\n  ~RepeatActor() override = default;\n\n private:\n  // NOTE(chengcheng): Empty rs for naive and inplace regst, all regst is customized.\n  std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName()\n      override {\n    return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{});\n  }\n  std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedProducedRegstDescName()\n      override {\n    return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{});\n  }\n  void TakeOverInplaceConsumedAndProduced(\n      const PbMap<std::string, RegstDescProto>& produced_ids) override {\n    // NOTE(chengcheng): all regst is customized.\n    inplace_consumed_rs_.InitedDone();\n    inplace_produced_rs_.InitedDone();\n  }\n\n  bool IsCustomizedReadReady() const override {\n    return (!wait_all_regst_return_) && consumed_var_rs_.IsCurSlotReady();\n  }\n  bool IsCustomizedWriteReady() const override {\n    return (!wait_all_regst_return_) && produced_repeat_var_rs_.IsCurSlotReady();\n  }\n\n  void NormalProcessCustomizedEordMsg(const ActorMsg&) override {}\n  bool IsCustomizedReadAlwaysUnReadyFromNow() const override {\n    // all Messages are flushed\n    return ReceiveEordMsg(consumed_var_regst_desc_id_);\n  }\n\n  void VirtualActorInit(const TaskProto& proto) override;\n  void Act() override;\n  void AsyncSendCustomizedProducedRegstMsgToConsumer() override;\n  void AsyncSendCustomizedConsumedRegstMsgToProducer() override;\n  void UpdtStateAsCustomizedProducedRegst(Regst* regst) override;\n  void NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) override;\n\n  int32_t repeat_count_;\n  int32_t repeat_num_;\n  bool wait_all_regst_return_;\n  int64_t consumed_var_regst_desc_id_;\n  int64_t produced_repeat_var_regst_desc_id_;\n  RegstSlot consumed_var_rs_;\n  RegstSlot produced_repeat_var_rs_;\n};\n\nvoid RepeatActor::VirtualActorInit(const TaskProto& proto) {\n  repeat_count_ = 0;\n  const OperatorConf& op_conf =\n      proto.exec_sequence().exec_node(0).kernel_conf().op_attribute().op_conf();\n  repeat_num_ = user_op::UserOpConfWrapper(op_conf).attr<int32_t>(\"repeat_num\");\n\n  const Shape& in_time_shape = Singleton<RegstMgr>::Get()\n                                   ->RegstDesc4RegstDescId(Name2SoleRegstDescId(\"in\"))\n                                   .data_regst_time_shape();\n  const Shape& out_time_shape = Singleton<RegstMgr>::Get()\n                                    
->RegstDesc4RegstDescId(Name2SoleRegstDescId(\"out\"))\n                                    .data_regst_time_shape();\n  CHECK_GE(out_time_shape.NumAxes(), 1);\n  CHECK_EQ(in_time_shape.NumAxes() + 1, out_time_shape.NumAxes());\n  FOR_RANGE(int64_t, i, 0, in_time_shape.NumAxes()) {\n    CHECK_EQ(in_time_shape.At(i), out_time_shape.At(i));\n  }\n  CHECK_EQ(repeat_num_, out_time_shape.At(out_time_shape.NumAxes() - 1));\n\n  // input\n  const auto& consumed_ids = proto.consumed_regst_desc_id();\n  auto in_it = consumed_ids.find(\"in\");\n  CHECK(in_it != consumed_ids.end());\n  CHECK_EQ(in_it->second.regst_desc_id_size(), 1);\n  consumed_var_regst_desc_id_ = in_it->second.regst_desc_id(0);\n  consumed_var_rs_.InsertRegstDescId(consumed_var_regst_desc_id_);\n  consumed_var_rs_.InitedDone();\n\n  // output\n  const auto& produced_ids = proto.produced_regst_desc();\n  auto out_it = produced_ids.find(\"out\");\n  CHECK(out_it != produced_ids.end());\n  const RegstDescProto& out_regst_desc = out_it->second;\n  CHECK(!out_regst_desc.enable_reuse_mem());\n  CHECK_EQ(out_regst_desc.register_num(), 1);\n  // check inplace\n  CHECK_EQ(out_regst_desc.inplace_consumed_regst_desc_id(), consumed_var_regst_desc_id_);\n  produced_repeat_var_regst_desc_id_ = out_regst_desc.regst_desc_id();\n  produced_repeat_var_rs_.InsertRegstDescId(produced_repeat_var_regst_desc_id_);\n  produced_repeat_var_rs_.InitedDone();\n\n  // NOTE(chengcheng): repeat actor may has output ctrl regst. ctrl regst also need hack regst num.\n  for (const auto& pair : proto.produced_regst_desc()) {\n    const RegstDescProto& regst_desc = pair.second;\n    int64_t regst_desc_id = regst_desc.regst_desc_id();\n    // This iter begins from 1 because first ctrl regst was already inserted in\n    // TakeOverNaiveProduced\n    for (int64_t i = 1; i < repeat_num_; ++i) {\n      Singleton<RegstMgr>::Get()->NewRegsts(regst_desc, [this, regst_desc_id](Regst* regst) {\n        produced_regsts_[regst_desc_id].emplace_back(regst);\n        produced_regst2reading_cnt_[regst] = 0;\n        if (regst_desc_id != produced_repeat_var_regst_desc_id_) {\n          CHECK_EQ(0, naive_produced_rs_.TryPushBackRegst(regst));\n        }\n      });\n    }\n  }\n\n  ForEachProducedRegst([&](Regst* regst) {\n    if (regst->regst_desc_id() == produced_repeat_var_regst_desc_id_) {\n      CHECK_EQ(0, produced_repeat_var_rs_.TryPushBackRegst(regst));\n    }\n  });\n\n  for (const auto& pair : proto.produced_regst_desc()) {\n    const RegstDescProto& regst_desc = pair.second;\n    int64_t regst_desc_id = regst_desc.regst_desc_id();\n    if (regst_desc_id == produced_repeat_var_regst_desc_id_) {\n      CHECK_EQ(produced_repeat_var_rs_.GetReadyRegstSize(regst_desc_id), repeat_num_);\n    } else {\n      CHECK_EQ(naive_produced_rs_.GetReadyRegstSize(regst_desc_id), repeat_num_);\n    }\n  }\n\n  OF_SET_MSG_HANDLER(&RepeatActor::HandlerNormal);\n}\n\nvoid RepeatActor::Act() {\n  repeat_count_ += 1;\n\n  if (repeat_count_ == repeat_num_) {\n    wait_all_regst_return_ = true;\n    repeat_count_ = 0;\n  }\n\n  Regst* out_regst = produced_repeat_var_rs_.Front(produced_repeat_var_regst_desc_id_);\n  Regst* in_regst = consumed_var_rs_.Front(consumed_var_regst_desc_id_);\n  CHECK(out_regst && in_regst);\n  CHECK(out_regst->body_mem_ptr() == in_regst->body_mem_ptr());\n  CHECK(out_regst->header_mem_ptr() == in_regst->header_mem_ptr());\n  CHECK_EQ(out_regst->regst_desc()->MainByteSize4OneRegst(),\n           in_regst->regst_desc()->MainByteSize4OneRegst());\n  
CHECK_EQ(out_regst->regst_desc()->SeparatedHeaderByteSize4OneRegst(),\n           in_regst->regst_desc()->SeparatedHeaderByteSize4OneRegst());\n}\n\nvoid RepeatActor::AsyncSendCustomizedProducedRegstMsgToConsumer() {\n  CHECK(produced_repeat_var_rs_.IsCurSlotReady());\n  Regst* const repeat_var_regst = produced_repeat_var_rs_.Front(produced_repeat_var_regst_desc_id_);\n  CHECK_GT(HandleRegstToConsumer(repeat_var_regst), 0);\n  produced_repeat_var_rs_.PopFrontRegsts({produced_repeat_var_regst_desc_id_});\n}\n\nvoid RepeatActor::AsyncSendCustomizedConsumedRegstMsgToProducer() {\n  // NOTE(chengcheng): do nothing. consumed var regst will return in inplace done.\n}\n\nvoid RepeatActor::UpdtStateAsCustomizedProducedRegst(Regst* regst) {\n  CHECK_EQ(regst->regst_desc_id(), produced_repeat_var_regst_desc_id_);\n  CHECK_EQ(produced_repeat_var_rs_.TryPushBackRegst(regst), 0);\n\n  if (wait_all_regst_return_\n      && produced_repeat_var_rs_.GetReadyRegstSize(produced_repeat_var_regst_desc_id_)\n             == repeat_num_) {\n    Regst* in_regst = consumed_var_rs_.Front(consumed_var_regst_desc_id_);\n    CHECK(in_regst);\n    AsyncSendRegstMsgToProducer(in_regst);\n    CHECK_EQ(0, consumed_var_rs_.TryPopFrontRegst(consumed_var_regst_desc_id_));\n    wait_all_regst_return_ = false;\n  }\n}\n\nvoid RepeatActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) {\n  CHECK_EQ(0, consumed_var_rs_.TryPushBackRegst(msg.regst()));\n}\nREGISTER_ACTOR(TaskType::kRepeat, RepeatActor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/sink_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/sink_actor.h\"\n\nnamespace oneflow {\n\nvoid SinkActor::VirtualActorInit(const TaskProto& proto) {\n  OF_SET_MSG_HANDLER(&SinkActor::HandlerNormal);\n  VirtualSinkActorInit(proto);\n}\n\nvoid SinkActor::Act() { AsyncLaunchKernel(); }\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/sink_actor.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_LAZY_ACTOR_SINK_ACTOR_H_\n#define ONEFLOW_CORE_LAZY_ACTOR_SINK_ACTOR_H_\n\n#include \"oneflow/core/lazy/actor/actor.h\"\n\nnamespace oneflow {\n\nclass SinkActor : public Actor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SinkActor);\n  SinkActor() = default;\n  virtual ~SinkActor() = default;\n\n protected:\n  virtual void VirtualSinkActorInit(const TaskProto&) {}\n\n private:\n  void VirtualActorInit(const TaskProto&) override;\n  void Act() override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_LAZY_ACTOR_SINK_ACTOR_H_\n"
  },
  {
    "path": "oneflow/core/lazy/actor/source_tick_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor.h\"\n#include \"oneflow/core/job/runtime_context.h\"\n#include \"oneflow/core/record/record.pb.h\"\n\nnamespace oneflow {\n\nclass SourceTickActor final : public Actor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SourceTickActor);\n  SourceTickActor() = default;\n  ~SourceTickActor() = default;\n\n private:\n  void VirtualActorInit(const TaskProto&) override;\n  void Act() override;\n  std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName()\n      override {\n    return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{});\n  }\n  bool IsCustomizedReadReady() const override;\n  bool IsCustomizedReadAlwaysUnReadyFromNow() const override { return !IsCustomizedReadReady(); }\n\n  int HandlerWaitToStart(const ActorMsg&);\n};\n\nvoid SourceTickActor::VirtualActorInit(const TaskProto& task_proto) {\n  OF_SET_MSG_HANDLER(&SourceTickActor::HandlerWaitToStart);\n}\n\nvoid SourceTickActor::Act() {}\n\nbool SourceTickActor::IsCustomizedReadReady() const {\n  // NOTE(chengcheng): SourceTickActor CANNOT be used and need delete in the future\n  return true;\n}\n\nint SourceTickActor::HandlerWaitToStart(const ActorMsg& msg) {\n  CHECK_EQ(msg.actor_cmd(), ActorCmd::kStart);\n  OF_SET_MSG_HANDLER(&SourceTickActor::HandlerNormal);\n  return ProcessMsg(msg);\n}\n\nREGISTER_ACTOR(kSourceTick, SourceTickActor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/ssp_variable_proxy_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n\nnamespace oneflow {\n\nclass SspVariableProxyActor final : public Actor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SspVariableProxyActor);\n  SspVariableProxyActor() = default;\n  ~SspVariableProxyActor() override = default;\n\n protected:\n  std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName()\n      override {\n    return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{});\n  }\n  std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedProducedRegstDescName()\n      override {\n    return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{});\n  }\n  bool IsCustomizedReadReady() const override { return consumed_var_rs_.IsCurSlotReady(); }\n  bool IsCustomizedWriteReady() const override {\n    int64_t cur_staleness = (received_var_piece_id_ - ack_msg_returned_ref_piece_id_);\n    return ((cur_staleness <= staleness() /* bounded staleness */)\n            && (produced_value_rs_.IsCurSlotReady()\n                /* able to send messages to consumers of output `value` */))\n           || (produced_ref_rs_.IsCurSlotReady()\n               /* able to send or to flush messages to consumers of output `ref` */);\n  }\n  void NormalProcessCustomizedEordMsg(const ActorMsg&) override {}\n  bool IsCustomizedReadAlwaysUnReadyFromNow() const override {\n    // all Messages are flushed\n    return ReceiveEordMsg(consumed_var_regst_desc_id_)\n           && (received_var_piece_id_ <= ack_msg_returned_value_piece_id_ + 1\n               /* there is no need to wait the last piece */)\n           && (received_var_piece_id_ == ack_msg_returned_ref_piece_id_);\n  }\n\n  void UpdtStateAsCustomizedProducedRegst(Regst* regst) override {\n    if (regst->regst_desc_id() == produced_value_regst_desc_id_) {\n      ++ack_msg_returned_value_piece_id_;\n      CHECK_EQ(regst, GetRingBufferValueRegst(ack_msg_returned_value_piece_id_));\n      CHECK_EQ(0, produced_value_rs_.TryPushBackRegst(regst));\n      if (ack_msg_returned_ref_piece_id_ == ack_msg_returned_value_piece_id_\n          /* All mutable consumers to ref regst has done their job */) {\n        // The updated ref regst are not synced into value regst yet.\n        SyncRefRegstIntoValueRegst(ack_msg_returned_value_piece_id_);\n      } else if (ack_msg_returned_ref_piece_id_ > ack_msg_returned_value_piece_id_) {\n        // The ACK of ref resgt can just be slightly earlier than the one of value regst.\n        // `slightly` means `ack_msg_returned_ref_piece_id_ == ack_msg_returned_value_piece_id_`\n        UNIMPLEMENTED();\n      } else {\n        // Do nothing. 
The ref data is not updated yet.\n      }\n    } else if (regst->regst_desc_id() == produced_ref_regst_desc_id_) {\n      ++ack_msg_returned_ref_piece_id_;\n      CHECK_EQ(regst, ref_regst_);\n      if (ack_msg_returned_value_piece_id_ >= ack_msg_returned_ref_piece_id_\n          /* All const consumers to value regst has done their job */) {\n        SyncRefRegstIntoValueRegst(ack_msg_returned_ref_piece_id_);\n      } else {\n        // Do nothing. The ACK of value regst will do the sync work\n      }\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n\n  void AsyncSendCustomizedProducedRegstMsgToConsumer() override {\n    if (consumed_var_rs_.IsCurSlotReady() && produced_value_rs_.IsCurSlotReady()) {\n      Regst* const value_regst = produced_value_rs_.Front(produced_value_regst_desc_id_);\n      if (value_regst->consumers_actor_id().empty()) {\n        ++ack_msg_returned_value_piece_id_;\n      } else {\n        CHECK_EQ(value_regst, GetRingBufferValueRegst(received_var_piece_id_));\n        CHECK_GT(HandleRegstToConsumer(value_regst), 0);\n        produced_value_rs_.PopFrontRegsts({produced_value_regst_desc_id_});\n      }\n    }\n    if ((ack_msg_returned_ref_piece_id_ < received_var_piece_id_)\n        && produced_ref_rs_.IsCurSlotReady()) {\n      Regst* const ref_regst = produced_ref_rs_.Front(produced_ref_regst_desc_id_);\n      if (ref_regst->consumers_actor_id().empty()) {\n        ++ack_msg_returned_ref_piece_id_;\n      } else {\n        CHECK_GT(HandleRegstToConsumer(ref_regst), 0);\n        produced_ref_rs_.PopFrontRegsts({produced_ref_regst_desc_id_});\n      }\n    }\n  }\n\n  void AsyncSendCustomizedConsumedRegstMsgToProducer() override {\n    Regst* const var_regst = consumed_var_rs_.Front(consumed_var_regst_desc_id_);\n    CHECK_NOTNULL(var_regst);\n    AsyncSendRegstMsgToProducer(var_regst);\n    CHECK_EQ(0, consumed_var_rs_.TryPopFrontRegst(consumed_var_regst_desc_id_));\n  }\n\n  void ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)> Handler) const override {\n    Handler(consumed_var_rs_.Front(consumed_var_regst_desc_id_));\n  }\n\n  void TakeOverInplaceConsumedAndProduced(\n      const PbMap<std::string, RegstDescProto>& produced_ids) override {\n    inplace_consumed_rs_.InitedDone();\n    inplace_produced_rs_.InitedDone();\n  }\n\n  void VirtualActorInit(const TaskProto& task_proto) override {\n    CheckInplaceBetweenVarAndRef(task_proto);\n    TakeOverVarRegst(task_proto.consumed_regst_desc_id());\n    TakeOverRefRegst(task_proto.produced_regst_desc());\n    TakeOverValueRegst(task_proto.produced_regst_desc());\n    OF_SET_MSG_HANDLER(&SspVariableProxyActor::HandlerNormal);\n  }\n\n  bool ProducedCtrlRegstValid(int64_t regst_desc_id) const override { return true; }\n\n  void NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) override {\n    if (var_regst_ == nullptr) {\n      var_regst_ = msg.regst();\n    } else {\n      CHECK_EQ(var_regst_, msg.regst());\n    }\n    CHECK_EQ(0, consumed_var_rs_.TryPushBackRegst(var_regst_));\n    ++received_var_piece_id_;\n  }\n\n private:\n  void Act() override {\n    if (received_var_piece_id_ == 0) {\n      // Initialize all value regsts\n      for (int64_t piece_id = 0; piece_id < staleness(); ++piece_id) {\n        CopyRefToValue(GetRingBufferValueRegst(piece_id));\n      }\n    } else {\n      // Do nothing, value regsts are updated in UpdtStateAsCustomizedProducedRegst\n    }\n  }\n\n  void CheckInplaceBetweenVarAndRef(const TaskProto& task_proto) const {\n    int64_t var_id = 
task_proto.consumed_regst_desc_id().at(\"var\").regst_desc_id(0);\n    const auto& ref_regst_desc_proto = task_proto.produced_regst_desc().at(\"ref\");\n    CHECK_EQ(ref_regst_desc_proto.inplace_consumed_regst_desc_id(), var_id);\n  }\n\n  void TakeOverVarRegst(const PbMap<std::string, RegstDescIdSet>& consumed_ids) {\n    received_var_piece_id_ = -1;\n    consumed_var_regst_desc_id_ = consumed_ids.at(\"var\").regst_desc_id(0);\n    consumed_var_rs_.InsertRegstDescId(consumed_var_regst_desc_id_);\n    consumed_var_rs_.InitedDone();\n    var_regst_ = nullptr;\n  }\n\n  void TakeOverRefRegst(const PbMap<std::string, RegstDescProto>& produced_ids) {\n    ack_msg_returned_ref_piece_id_ = -1;\n    produced_ref_regst_desc_id_ = produced_ids.at(\"ref\").regst_desc_id();\n    produced_ref_rs_.InsertRegstDescId(produced_ref_regst_desc_id_);\n    produced_ref_rs_.InitedDone();\n    ref_regst_ = nullptr;\n    ForEachProducedRegst([&](Regst* regst) {\n      if (regst->regst_desc_id() != produced_ref_regst_desc_id_) { return; }\n      CHECK(ref_regst_ == nullptr) << \"regst_num of ref_regst must equal 1\";\n      CHECK_EQ(0, produced_ref_rs_.TryPushBackRegst(regst));\n      ref_regst_ = regst;\n    });\n  }\n\n  void TakeOverValueRegst(const PbMap<std::string, RegstDescProto>& produced_ids) {\n    ack_msg_returned_value_piece_id_ = -1;\n    produced_value_regst_desc_id_ = produced_ids.at(\"value\").regst_desc_id();\n    produced_value_rs_.InsertRegstDescId(produced_value_regst_desc_id_);\n    produced_value_rs_.InitedDone();\n    ForEachProducedRegst([&](Regst* regst) {\n      if (regst->regst_desc_id() != produced_value_regst_desc_id_) { return; }\n      CHECK_EQ(0, produced_value_rs_.TryPushBackRegst(regst));\n      value_regst_ring_buffer_.push_back(regst);\n    });\n  }\n\n  void SyncRefRegstIntoValueRegst(int64_t released_piece_id) {\n    CopyRefToValue(GetRingBufferValueRegst(released_piece_id));\n    CHECK_EQ(0, produced_ref_rs_.TryPushBackRegst(ref_regst_));\n  }\n\n  void CopyRefToValue(Regst* value_regst) {\n    AsyncLaunchKernel([&](int64_t regst_desc_id) -> Regst* {\n      if (regst_desc_id == consumed_var_regst_desc_id_) {\n        return var_regst_;\n      } else if (regst_desc_id == produced_ref_regst_desc_id_) {\n        return ref_regst_;\n      } else if (regst_desc_id == produced_value_regst_desc_id_) {\n        return value_regst;\n      } else {\n        UNIMPLEMENTED();\n      }\n    });\n  }\n\n  Regst* GetRingBufferValueRegst(int64_t value_piece_id) const {\n    return value_regst_ring_buffer_.at(value_piece_id % staleness());\n  }\n\n  size_t staleness() const { return value_regst_ring_buffer_.size(); }\n\n  // input var\n  int64_t received_var_piece_id_;\n  int64_t consumed_var_regst_desc_id_;\n  RegstSlot consumed_var_rs_;\n  Regst* var_regst_;\n  // output ref\n  // consumers has used the ref regst\n  int64_t ack_msg_returned_ref_piece_id_;\n  int64_t produced_ref_regst_desc_id_;\n  RegstSlot produced_ref_rs_;\n  Regst* ref_regst_;\n  // output value\n  // consumers has used the value regst\n  int64_t ack_msg_returned_value_piece_id_;\n  int64_t produced_value_regst_desc_id_;\n  RegstSlot produced_value_rs_;\n  std::vector<Regst*> value_regst_ring_buffer_;\n};\n\nREGISTER_ACTOR(TaskType::kSspVariableProxy, SspVariableProxyActor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/tick_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/naive_actor.h\"\n\nnamespace oneflow {\n\nclass TickActor final : public NaiveActor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TickActor);\n  TickActor() = default;\n  ~TickActor() = default;\n\n private:\n  void Act() override {}\n};\n\nREGISTER_ACTOR(kTick, TickActor);\nREGISTER_ACTOR(kDeviceTick, TickActor);\nREGISTER_ACTOR(kSrcSubsetTick, TickActor);\nREGISTER_ACTOR(kDstSubsetTick, TickActor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/unpack_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor.h\"\n#include \"oneflow/core/kernel/user_kernel.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n\nnamespace oneflow {\n\nclass UnpackActor final : public Actor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(UnpackActor);\n  UnpackActor() = default;\n  ~UnpackActor() override = default;\n\n private:\n  void VirtualActorInit(const TaskProto& proto) override;\n  void Act() override;\n  void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override;\n  void VirtualAsyncSendNaiveConsumedRegstMsgToProducer() override;\n  bool ConsumedCtrlRegstValid(int64_t regst_desc_id) const override;\n\n  size_t total_unpack_num_;\n  size_t act_num_cnt_;\n};\n\nvoid UnpackActor::VirtualActorInit(const TaskProto& proto) {\n  const Shape& out_time_shape = Singleton<RegstMgr>::Get()\n                                    ->RegstDesc4RegstDescId(Name2SoleRegstDescId(\"out\"))\n                                    .data_regst_time_shape();\n  total_unpack_num_ = out_time_shape.At(out_time_shape.NumAxes() - 1);\n  act_num_cnt_ = 0;\n  OF_SET_MSG_HANDLER(&UnpackActor::HandlerNormal);\n}\n\nvoid UnpackActor::Act() {\n  CHECK_GE(exec_kernel_vec().size(), 1);\n  auto user_kernel = dynamic_cast<const UserKernel*>(exec_kernel_vec().at(0).kernel.get());\n  CHECK_NOTNULL(user_kernel);\n  auto state = dynamic_cast<OpKernelStateWrapper<std::pair<size_t, size_t>>*>(\n      user_kernel->GetOpKernelState().get());\n  CHECK_NOTNULL(state);\n  state->Mutable()->first = act_num_cnt_;\n  state->Mutable()->second = total_unpack_num_;\n  AsyncLaunchKernel();\n  act_num_cnt_ += 1;\n}\n\nvoid UnpackActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {\n  HandleProducedNaiveDataRegstToConsumer();\n}\n\nvoid UnpackActor::VirtualAsyncSendNaiveConsumedRegstMsgToProducer() {\n  if (act_num_cnt_ == total_unpack_num_) {\n    HandleConsumedNaiveDataRegstToProducer();\n    act_num_cnt_ = 0;\n  }\n}\n\nbool UnpackActor::ConsumedCtrlRegstValid(int64_t regst_desc_id) const { return act_num_cnt_ == 0; }\n\nREGISTER_ACTOR(TaskType::kUnpack, UnpackActor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/actor/wait_and_send_ids_actor.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/actor/actor.h\"\n#include \"oneflow/core/kernel/wait_and_send_ids_kernel.h\"\n#include \"oneflow/core/job/runtime_context.h\"\n#include \"oneflow/core/record/record.pb.h\"\n\nnamespace oneflow {\n\nclass WaitAndSendIdsActor final : public Actor {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(WaitAndSendIdsActor);\n  WaitAndSendIdsActor() : wait_and_send_ids_status_(nullptr) {}\n  ~WaitAndSendIdsActor() = default;\n\n private:\n  void VirtualActorInit(const TaskProto&) override;\n  void Act() override;\n  std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName()\n      override {\n    return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{});\n  }\n  void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override;\n  bool IsCustomizedReadReady() const override;\n  bool IsCustomizedReadAlwaysUnReadyFromNow() const override { return !IsCustomizedReadReady(); }\n\n  int HandlerWaitToStart(const ActorMsg&);\n\n  WaitAndSendIdsStatus* wait_and_send_ids_status_;\n};\n\nvoid WaitAndSendIdsActor::VirtualActorInit(const TaskProto& task_proto) {\n  CHECK_EQ(exec_kernel_vec().size(), 1);\n  wait_and_send_ids_status_ = CHECK_NOTNULL(\n      dynamic_cast<WaitAndSendIdsStatus*>(exec_kernel_vec().at(0).kernel_ctx->state().get()));\n  wait_and_send_ids_status_->buffer_status_ = kBufferStatusSuccess;\n  wait_and_send_ids_status_->in_id_ = 0;\n  wait_and_send_ids_status_->out_idx_ = 0;\n  wait_and_send_ids_status_->out_num_ = 0;\n  OF_SET_MSG_HANDLER(&WaitAndSendIdsActor::HandlerWaitToStart);\n}\n\nvoid WaitAndSendIdsActor::Act() {\n  CHECK_LE(wait_and_send_ids_status_->out_idx_, wait_and_send_ids_status_->out_num_);\n  AsyncLaunchKernel();\n}\n\nvoid WaitAndSendIdsActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {\n  if (wait_and_send_ids_status_->buffer_status_ == kBufferStatusSuccess) {\n    HandleProducedNaiveDataRegstToConsumer();\n  }\n}\n\nbool WaitAndSendIdsActor::IsCustomizedReadReady() const {\n  return wait_and_send_ids_status_->buffer_status_ == kBufferStatusSuccess;\n}\n\nint WaitAndSendIdsActor::HandlerWaitToStart(const ActorMsg& msg) {\n  CHECK_EQ(msg.actor_cmd(), ActorCmd::kStart);\n  OF_SET_MSG_HANDLER(&WaitAndSendIdsActor::HandlerNormal);\n  return ProcessMsg(msg);\n}\n\nREGISTER_ACTOR(kWaitAndSendIds, WaitAndSendIdsActor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/stream_context/common/generic_stream_context.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/stream_context/include/generic_stream_context.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/ep/include/active_device_guard.h\"\n\nnamespace oneflow {\n\nGenericStreamContext::GenericStreamContext(const StreamId& stream_id) : stream_(nullptr) {\n  device_ =\n      std::dynamic_pointer_cast<ep::Device>(Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(\n          stream_id.device_type(), stream_id.device_index()));\n  CHECK(device_);\n  ep::ActiveDeviceGuard guard(device_.get());\n  stream_ = dynamic_cast<ep::Stream*>(device_->CreateStream());\n  CHECK(stream_ != nullptr);\n  poller_thread_ = std::thread([this]() {\n    CHECK_JUST(stream_->OnExecutionContextSetup());\n    std::pair<ep::Event*, std::function<void()>> cb_event;\n    while (cb_event_chan_.Receive(&cb_event) == kChannelStatusSuccess) {\n      CHECK_JUST(cb_event.first->Sync());\n      cb_event.second();\n      device_->DestroyEvent(cb_event.first);\n    }\n    CHECK_JUST(stream_->OnExecutionContextTeardown());\n  });\n}\n\nGenericStreamContext::~GenericStreamContext() {\n  ep::ActiveDeviceGuard guard(device_.get());\n  cb_event_chan_.Close();\n  poller_thread_.join();\n  device_->DestroyStream(stream_);\n}\n\nMaybe<void> GenericStreamContext::AddCallback(std::function<void()> callback) {\n  ep::Event* event = device_->CreateEvent();\n  stream_->RecordEvent(event);\n  cb_event_chan_.Send(std::make_pair(event, std::move(callback)));\n  return Maybe<void>::Ok();\n}\n\nep::Stream* GenericStreamContext::stream() { return stream_; }\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/stream_context/cpu/cpu_stream_context.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/stream_context/include/stream_context.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/device/event_record.h\"\n#include \"oneflow/core/kernel/chain_kernel_observer.h\"\n#include \"oneflow/core/kernel/cpu_check_numerics_kernel_observer.h\"\n#include \"oneflow/core/graph/stream_id.h\"\n#include \"oneflow/core/ep/cpu/cpu_stream.h\"\n#include \"oneflow/core/ep/cpu/cpu_device.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n\nnamespace oneflow {\n\nclass CpuStreamContext : public StreamContext, public KernelObserverProvider {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuStreamContext);\n  CpuStreamContext();\n  ~CpuStreamContext() override;\n\n  ep::Stream* stream() override;\n  Maybe<void> AddCallback(std::function<void()> callback) override;\n  KernelObserver* GetKernelObserver() override;\n  DeviceType device_type() const override { return DeviceType::kCPU; }\n\n private:\n  std::shared_ptr<ep::Device> device_;\n  ep::Stream* stream_;\n  std::unique_ptr<KernelObserver> kernel_observer_;\n};\n\nCpuStreamContext::CpuStreamContext() : stream_(nullptr) {\n  device_ = Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCPU, 0);\n  stream_ = device_->CreateStream();  // NOLINT\n  std::vector<std::shared_ptr<KernelObserver>> kernel_observers;\n  if (ParseBooleanFromEnv(\"ONEFLOW_DEBUG_KERNEL_SYNC_CHECK_NUMERICS\", false)) {\n    kernel_observers.emplace_back(new CpuCheckNumericsKernelObserver());\n  }\n  kernel_observer_.reset(new ChainKernelObserver(kernel_observers));\n}\n\nCpuStreamContext::~CpuStreamContext() { device_->DestroyStream(stream_); }\n\nep::Stream* CpuStreamContext::stream() { return stream_; }\n\nMaybe<void> CpuStreamContext::AddCallback(std::function<void()> callback) {\n  callback();\n  return Maybe<void>::Ok();\n}\n\nKernelObserver* CpuStreamContext::GetKernelObserver() { return kernel_observer_.get(); }\n\nREGISTER_STREAM_CONTEXT_CREATOR_WITH_STREAM_ID(DeviceType::kCPU,\n                                               ([](const StreamId& stream_id) -> StreamContext* {\n                                                 return new CpuStreamContext();\n                                               }));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/lazy/stream_context/cuda/cuda_stream_context.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/lazy/stream_context/include/stream_context.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/kernel/chain_kernel_observer.h\"\n#include \"oneflow/core/kernel/cuda_check_numerics_kernel_observer.h\"\n#include \"oneflow/core/graph/stream_id.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/ep/cuda/cuda_device.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/common/channel.h\"\n\n#ifdef WITH_CUDA\n#include <cublas_v2.h>\n\nnamespace oneflow {\n\nnamespace {\n\nclass CudaStreamContext : public StreamContext, public KernelObserverProvider {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaStreamContext);\n  explicit CudaStreamContext(int device_index);\n  ~CudaStreamContext() override;\n\n  Maybe<void> AddCallback(std::function<void()> callback) override;\n  DeviceType device_type() const override { return DeviceType::kCUDA; }\n  KernelObserver* GetKernelObserver() override;\n\n  ep::Stream* stream() override;\n\n private:\n  ep::CudaStream* stream_;\n  Channel<std::pair<ep::Event*, std::function<void()>>> cb_event_chan_;\n  std::thread poller_thread_;\n  int device_index_;\n  std::unique_ptr<KernelObserver> kernel_observer_;\n  std::shared_ptr<ep::CudaDevice> device_;\n};\n\nCudaStreamContext::CudaStreamContext(int device_index)\n    : stream_(nullptr), device_index_(device_index) {\n  CudaCurrentDeviceGuard guard(device_index_);\n  device_ = std::dynamic_pointer_cast<ep::CudaDevice>(\n      Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, device_index));\n  CHECK(device_);\n  stream_ = dynamic_cast<ep::CudaStream*>(device_->CreateStream());\n  CHECK(stream_ != nullptr);\n\n  std::vector<std::shared_ptr<KernelObserver>> kernel_observers;\n  if (ParseBooleanFromEnv(\"ONEFLOW_DEBUG_KERNEL_SYNC_CHECK_NUMERICS\", false)) {\n    LOG(WARNING) << \"Environment variable ONEFLOW_DEBUG_KERNEL_SYNC_CHECK_NUMERICS has been set \"\n                    \"to a truthy \"\n                    \"value, it will impact performance\";\n    kernel_observers.emplace_back(new CudaCheckNumericsKernelObserver());\n  }\n  kernel_observer_.reset(new ChainKernelObserver(kernel_observers));\n\n  poller_thread_ = std::thread([this]() {\n    CHECK_JUST(stream_->OnExecutionContextSetup());\n    OF_PROFILER_NAME_THIS_HOST_THREAD(\"_cuda\" + std::to_string(device_index_) + \" Poller : (\"\n                                      + std::to_string(device_index_) + \")\");\n    std::pair<ep::Event*, std::function<void()>> cb_event;\n    while (cb_event_chan_.Receive(&cb_event) == kChannelStatusSuccess) {\n      CHECK_JUST(cb_event.first->Sync());\n      cb_event.second();\n      device_->DestroyEvent(cb_event.first);\n    }\n    CHECK_JUST(stream_->OnExecutionContextTeardown());\n  
});\n}\n\nCudaStreamContext::~CudaStreamContext() {\n  CudaCurrentDeviceGuard guard(device_index_);\n  cb_event_chan_.Close();\n  poller_thread_.join();\n  device_->DestroyStream(stream_);\n}\n\nMaybe<void> CudaStreamContext::AddCallback(std::function<void()> callback) {\n  ep::Event* event = device_->CreateEvent();\n  stream_->RecordEvent(event);\n  cb_event_chan_.Send(std::make_pair(event, std::move(callback)));\n  return Maybe<void>::Ok();\n}\n\nKernelObserver* CudaStreamContext::GetKernelObserver() { return kernel_observer_.get(); }\n\nep::Stream* CudaStreamContext::stream() { return stream_; }\n\nREGISTER_STREAM_CONTEXT_CREATOR_WITH_STREAM_ID(\n    DeviceType::kCUDA, ([](const StreamId& stream_id) -> StreamContext* {\n      CHECK_EQ(stream_id.device_type(), DeviceType::kCUDA);\n      return new CudaStreamContext(stream_id.device_index());\n    }));\n\n}  // namespace\n\n}  // namespace oneflow\n\n#endif\n"
  },
  {
    "path": "oneflow/core/lazy/stream_context/include/generic_stream_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_LAZY_STREAM_CONTEXT_GENERIC_STREAM_CONTEXT_H_\n#define ONEFLOW_CORE_LAZY_STREAM_CONTEXT_GENERIC_STREAM_CONTEXT_H_\n\n#include \"oneflow/core/lazy/stream_context/include/stream_context.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/graph/stream_id.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/ep/include/device.h\"\n#include \"oneflow/core/common/channel.h\"\n\nnamespace oneflow {\n\nclass GenericStreamContext : public StreamContext {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(GenericStreamContext);\n  explicit GenericStreamContext(const StreamId& stream_id);\n  ~GenericStreamContext() override;\n\n  Maybe<void> AddCallback(std::function<void()> callback) override;\n  DeviceType device_type() const override { return stream_->device_type(); }\n\n  ep::Stream* stream() override;\n\n private:\n  ep::Stream* stream_;\n  Channel<std::pair<ep::Event*, std::function<void()>>> cb_event_chan_;\n  std::thread poller_thread_;\n  std::shared_ptr<ep::Device> device_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_LAZY_STREAM_CONTEXT_GENERIC_STREAM_CONTEXT_H_\n"
  },
  {
    "path": "oneflow/core/lazy/stream_context/include/stream_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_LAZY_STREAM_CONTEXT_STREAM_CONTEXT_H_\n#define ONEFLOW_CORE_LAZY_STREAM_CONTEXT_STREAM_CONTEXT_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/auto_registration_factory.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\nclass StreamContext {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(StreamContext);\n  StreamContext() = default;\n  virtual ~StreamContext() = default;\n\n  virtual ep::Stream* stream() = 0;\n  virtual Maybe<void> AddCallback(std::function<void()> callback) = 0;\n  virtual DeviceType device_type() const = 0;\n};\n\n#define REGISTER_STREAM_CONTEXT_CREATOR_WITH_STREAM_ID(device, creator) \\\n  REGISTER_CLASS_CREATOR(int, device, StreamContext, creator, const StreamId&)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_LAZY_STREAM_CONTEXT_STREAM_CONTEXT_H_\n"
  },
  {
    "path": "oneflow/core/memory/chunk_manager.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/memory/chunk_manager.h\"\n#include \"oneflow/core/memory/memory_allocator.h\"\n#include \"oneflow/core/memory/memory_case_util.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n\nnamespace oneflow {\n\nvoid ChunkMgr::GetChunkProtosByMemZoneUniqueId(int64_t mem_zone_uid,\n                                               std::vector<const ChunkProto*>* chunks) const {\n  std::unique_lock<std::mutex> guard(mutex_);\n  chunks->clear();\n  auto chunk_ids_it = mzuid2chunk_ids_.find(mem_zone_uid);\n  if (chunk_ids_it != mzuid2chunk_ids_.end()) {\n    const auto& chunk_ids = chunk_ids_it->second;\n    chunks->reserve(chunk_ids.size());\n    for (int64_t chunk_id : chunk_ids) {\n      auto chunk_it = chunk_id2chunk_proto_.find(chunk_id);\n      CHECK(chunk_it != chunk_id2chunk_proto_.end());\n      chunks->emplace_back(chunk_it->second.get());\n    }\n  }\n}\n\nvoid ChunkMgr::AddChunkProto(const ChunkProto& chunk) {\n  std::unique_lock<std::mutex> guard(mutex_);\n  const int64_t mem_zone_uid = memory::GetUniqueMemCaseId(chunk.machine_id(), chunk.mem_case());\n  CHECK(\n      chunk_id2chunk_proto_.emplace(chunk.chunk_id(), std::make_unique<ChunkProto>(chunk)).second);\n  auto chunk_ids_it = mzuid2chunk_ids_.find(mem_zone_uid);\n  if (chunk_ids_it == mzuid2chunk_ids_.end()) {\n    chunk_ids_it = mzuid2chunk_ids_.emplace(mem_zone_uid, HashSet<int64_t>()).first;\n  }\n  CHECK(chunk_ids_it->second.insert(chunk.chunk_id()).second);\n}\n\nchar* ChunkMgr::FindOrCreateChunk(const ChunkProto& chunk) {\n  std::unique_lock<std::mutex> guard(mutex_);\n  CHECK_EQ(GlobalProcessCtx::Rank(), chunk.machine_id());\n  auto it = chunk_id2chunk_.find(chunk.chunk_id());\n  if (it == chunk_id2chunk_.end()) {\n    char* chunk_ptr =\n        Singleton<MemoryAllocator>::Get()->Allocate(chunk.mem_case(), chunk.mem_size());\n    it = chunk_id2chunk_.emplace(chunk.chunk_id(), ChunkWithPtr(chunk_ptr, chunk)).first;\n  } else {\n    const ChunkProto& store_proto = it->second.chunk_proto;\n    CHECK_EQ(chunk.chunk_id(), store_proto.chunk_id());\n    CHECK_EQ(chunk.machine_id(), store_proto.machine_id());\n    CHECK(chunk.mem_case() == store_proto.mem_case());\n    CHECK_EQ(chunk.mem_size(), store_proto.mem_size());\n  }\n  return it->second.ptr;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/memory/chunk_manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_MEMORY_CHUNK_MANAGER_H_\n#define ONEFLOW_CORE_MEMORY_CHUNK_MANAGER_H_\n\n#include <mutex>\n\n#include \"oneflow/core/job/id_manager.h\"\n#include \"oneflow/core/memory/memory_block.pb.h\"\n#include \"oneflow/core/memory/memory_allocator.h\"\n\nnamespace oneflow {\n\nclass ChunkMgr final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ChunkMgr);\n  ChunkMgr() = default;\n  ~ChunkMgr() = default;\n\n  // Compiler\n  void GetChunkProtosByMemZoneUniqueId(int64_t mem_zone_uid,\n                                       std::vector<const ChunkProto*>* chunks) const;\n  void AddChunkProto(const ChunkProto& chunk);\n\n  // Runtime\n  char* FindOrCreateChunk(const ChunkProto& chunk);\n\n private:\n  // for master compiler in PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan\n  HashMap<int64_t, HashSet<int64_t>> mzuid2chunk_ids_;\n  HashMap<int64_t, std::unique_ptr<ChunkProto>> chunk_id2chunk_proto_;\n\n  struct ChunkWithPtr {\n    char* ptr;\n    ChunkProto chunk_proto;\n    ChunkWithPtr(char* p, const ChunkProto& c_p) : ptr(p), chunk_proto(c_p) {}\n  };\n\n  // for runtime\n  HashMap<int64_t, ChunkWithPtr> chunk_id2chunk_;\n  mutable std::mutex mutex_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_MEMORY_CHUNK_MANAGER_H_\n"
  },
  {
    "path": "oneflow/core/memory/memory_allocator.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/memory/memory_allocator.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/register/blob.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n#include \"oneflow/core/record/record.pb.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::shared_ptr<ep::Device> GetAllocationDevice(const MemoryCase& mem_case) {\n  auto device = Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(mem_case.device_type(),\n                                                                       mem_case.device_id());\n  CHECK(device);\n  return device;\n}\n\nep::AllocationOptions GetAllocationOptions(const MemoryCase& mem_case) {\n  ep::AllocationOptions options{};\n  if (mem_case.has_pinned_device_type() && mem_case.has_pinned_device_id()) {\n    options.SetPinnedDevice(mem_case.pinned_device_type(), mem_case.pinned_device_id());\n  }\n  return options;\n}\n\n}  // namespace\n\nvoid* MemoryAllocatorImpl::Allocate(const MemoryCase& mem_case, size_t size) {\n  void* ptr = nullptr;\n  std::shared_ptr<ep::Device> device = GetAllocationDevice(mem_case);\n  ep::AllocationOptions options = GetAllocationOptions(mem_case);\n  CHECK_JUST(device->Alloc(options, &ptr, size));\n  return ptr;\n}\n\nvoid MemoryAllocatorImpl::Deallocate(void* ptr, const MemoryCase& mem_case) {\n  std::shared_ptr<ep::Device> device = GetAllocationDevice(mem_case);\n  ep::AllocationOptions options = GetAllocationOptions(mem_case);\n  device->Free(options, ptr);\n}\n\nvoid* MemoryAllocatorImpl::AllocateUnPinnedHostMem(size_t size) {\n  void* ptr = aligned_alloc(kHostAlignSize, size);\n  CHECK_NOTNULL(ptr);\n  return ptr;\n}\n\nvoid MemoryAllocatorImpl::DeallocateUnPinnedHostMem(void* ptr) {\n  free(ptr);  // NOLINT\n}\n\nMemoryAllocator::~MemoryAllocator() {\n  for (const std::function<void()>& deleter : deleters_) { deleter(); }\n}\n\nchar* MemoryAllocator::Allocate(const MemoryCase& mem_case, std::size_t size) {\n  char* dptr = static_cast<char*>(MemoryAllocatorImpl::Allocate(mem_case, size));\n  deleters_.push_front(std::bind(&MemoryAllocator::Deallocate, this, dptr, mem_case));\n  return dptr;\n}\n\nvoid MemoryAllocator::Deallocate(char* dptr, const MemoryCase& mem_case) {\n  MemoryAllocatorImpl::Deallocate(static_cast<void*>(dptr), mem_case);\n}\n\nvoid InitNonPODTypeBlobIfNeed(MemoryAllocator* allocator, Blob* blob_ptr) {\n  const BlobDesc& blob_desc = blob_ptr->blob_desc();\n  if (blob_desc.data_type() == kOFRecord) {\n    int64_t elem_cnt = blob_desc.shape().elem_cnt();\n    FOR_RANGE(int64_t, idx, 0, elem_cnt) {\n      allocator->PlacementNew(&blob_ptr->mut_dptr<OFRecord>()[idx]);\n    }\n  }\n  if (blob_desc.data_type() == kTensorBuffer) {\n    int64_t elem_cnt = blob_desc.shape().elem_cnt();\n    FOR_RANGE(int64_t, idx, 
0, elem_cnt) {\n      allocator->PlacementNew(&blob_ptr->mut_dptr<TensorBuffer>()[idx]);\n    }\n  }\n}\n\nvoid InitNonPODTypeEagerBlobObjectIfNeed(MemoryAllocator* allocator,\n                                         vm::EagerBlobObject* eager_blob_object_ptr) {\n  if (eager_blob_object_ptr->data_type() == kOFRecord) {\n    int64_t elem_cnt = eager_blob_object_ptr->shape().elem_cnt();\n    FOR_RANGE(int64_t, idx, 0, elem_cnt) {\n      allocator->PlacementNew(&eager_blob_object_ptr->mut_dptr<OFRecord>()[idx]);\n    }\n  }\n  if (eager_blob_object_ptr->data_type() == kTensorBuffer) {\n    int64_t elem_cnt = eager_blob_object_ptr->shape().elem_cnt();\n    FOR_RANGE(int64_t, idx, 0, elem_cnt) {\n      allocator->PlacementNew(&eager_blob_object_ptr->mut_dptr<TensorBuffer>()[idx]);\n    }\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/memory/memory_allocator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_MEMORY_MEMORY_ALLOCATOR_H_\n#define ONEFLOW_CORE_MEMORY_MEMORY_ALLOCATOR_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/memory/memory_case_util.h\"\n\nnamespace oneflow {\n\nnamespace vm {\nclass EagerBlobObject;\n}\n\nclass MemoryAllocator final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(MemoryAllocator);\n  MemoryAllocator() = default;\n  ~MemoryAllocator();\n\n  char* Allocate(const MemoryCase& mem_case, std::size_t size);\n  template<typename T>\n  T* PlacementNew(T* mem_ptr);\n\n private:\n  void Deallocate(char* dptr, const MemoryCase& mem_case);\n\n  std::mutex deleters_mutex_;\n  std::list<std::function<void()>> deleters_;\n};\n\nclass Blob;\nvoid InitNonPODTypeBlobIfNeed(MemoryAllocator* allocator, Blob* blob_ptr);\nvoid InitNonPODTypeEagerBlobObjectIfNeed(MemoryAllocator* allocator,\n                                         vm::EagerBlobObject* eager_blob_object_ptr);\n\ntemplate<typename T>\nT* MemoryAllocator::PlacementNew(T* mem_ptr) {\n  T* obj = new (mem_ptr) T();\n  {\n    std::unique_lock<std::mutex> lock(deleters_mutex_);\n    deleters_.push_front([obj] { obj->~T(); });\n  }\n  CHECK_EQ(mem_ptr, obj);\n  return obj;\n}\n\nstruct MemoryAllocatorImpl final {\n  static void* Allocate(const MemoryCase& mem_case, size_t size);\n  static void Deallocate(void* ptr, const MemoryCase& mem_case);\n  static void* AllocateUnPinnedHostMem(size_t size);\n  static void DeallocateUnPinnedHostMem(void* ptr);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_MEMORY_MEMORY_ALLOCATOR_H_\n"
  },
  {
    "path": "oneflow/core/memory/memory_block.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/memory/memory_case.proto\";\n\nmessage MemBlockProto {\n  required int64 mem_block_id = 1;\n  repeated int64 job_id = 2;\n  required int64 machine_id = 3;\n  required MemoryCase mem_case = 4;\n  required bool enable_reuse_mem = 5;\n  optional int64 chunk_id = 6 [default = -1];\n  optional int64 chunk_offset = 7 [default = -1];\n  required int64 mem_size = 8;\n  // NOTE(chengcheng): thrd id hint is used by packed separated block group order.\n  optional int64 thrd_id_hint = 9 [default = -1];\n  // NOTE(chengcheng): mark this block memory is shared with EagerParameter.\n  optional string variable_op_name = 10 [default = \"\"];\n  optional bool is_separated_header = 11 [default = false];\n}\n\nmessage ChunkProto {\n  required int64 chunk_id = 1;\n  repeated int64 job_id = 2;\n  required int64 machine_id = 3;\n  required MemoryCase mem_case = 4;\n  required int64 mem_size = 5;\n}\n\nmessage MemBlockAndChunkList {\n  repeated MemBlockProto mem_block = 1;\n  repeated ChunkProto chunk = 2;\n}\n"
  },
  {
    "path": "oneflow/core/memory/memory_case.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/common/device_type.proto\";\n\nmessage MemoryCase {\n  required DeviceType device_type = 1;\n  required int64 device_id = 2;\n  optional DeviceType pinned_device_type = 3;\n  optional int64 pinned_device_id = 4;\n}\n"
  },
  {
    "path": "oneflow/core/memory/memory_case_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/memory/memory_case_util.h\"\n\n#include <google/protobuf/util/message_differencer.h>\n\nnamespace oneflow {\n\nnamespace memory {\n\nbool EqualsIgnorePinnedDevice(const MemoryCase& a, const MemoryCase& b) {\n  if (a.device_type() != b.device_type()) { return false; }\n  if (a.device_id() != b.device_id()) { return false; }\n  return true;\n}\n\nvoid GetPinnedHostMemoryCase(const MemoryCase& mem_case, MemoryCase* ret) {\n  ret->set_device_type(DeviceType::kCPU);\n  ret->set_device_id(0);\n  if (!IsHostMem(mem_case)) {\n    ret->set_pinned_device_type(mem_case.device_type());\n    ret->set_pinned_device_id(mem_case.device_id());\n  }\n}\n\nMemoryCase GetPinnedHostMemoryCase(const MemoryCase& mem_case) {\n  MemoryCase ret;\n  GetPinnedHostMemoryCase(mem_case, &ret);\n  return ret;\n}\n\n// clang-format off\n// MemCaseId encoding (bits)\n// | reserved | node_index | device_type | device_index | reserved | pinned_device_type | pinned_device_index |\n// | --- 1 -- | --- 19 --- | ---- 5 ---- | ----- 7 ---- | -- 20 -- | ------- 5 -------- | ------- 7 --------- |\n// | ---------------------- 32 ------------------------ | ---------------------- 32 ------------------------- |\n// clang-format on\n\nnamespace {\n\nconstexpr size_t kDeviceIndexBits = 7;\nconstexpr size_t kDeviceTypeBits = 5;\nconstexpr size_t kDeviceTypeShift = kDeviceIndexBits;\nconstexpr size_t kNodeIndexShift = kDeviceTypeShift + kDeviceTypeBits;\nconstexpr size_t kPinnedDeviceShift = 32;\n\n}  // namespace\n\nint64_t GetMemCaseId(const MemoryCase& mem_case) {\n  uint32_t high = 0;\n  high |= static_cast<uint32_t>(mem_case.device_id());\n  high |= static_cast<uint32_t>(mem_case.device_type()) << kDeviceTypeShift;\n  uint32_t low = 0;\n  if (mem_case.has_pinned_device_id()) {\n    low |= static_cast<uint32_t>(mem_case.pinned_device_id());\n  }\n  if (mem_case.has_pinned_device_type()) {\n    low |= static_cast<uint32_t>(mem_case.pinned_device_type()) << kDeviceTypeShift;\n  }\n  int64_t id = 0;\n  id |= static_cast<int64_t>(high) << kPinnedDeviceShift;\n  id |= static_cast<int64_t>(low);\n  return id;\n}\n\nint64_t GetUniqueMemCaseId(int64_t machine_id, const MemoryCase& mem_case) {\n  int64_t id = 0;\n  id |= (machine_id << kNodeIndexShift << kPinnedDeviceShift);\n  id |= GetMemCaseId(mem_case);\n  return id;\n}\n\nstd::shared_ptr<MemoryCase> MakeMemCaseShared(const DeviceType device_type,\n                                              const int64_t device_id) {\n  auto mem_case_ptr = std::make_shared<MemoryCase>();\n  mem_case_ptr->set_device_type(device_type);\n  // We consider that there is only one cpu physical device.\n  // As non-cpu devices, a logical device map to a physical device,\n  // however as cpu devices, all logical devices map to a single physical device.\n  if (device_type == DeviceType::kCPU) {\n    mem_case_ptr->set_device_id(0);\n  } else {\n    
mem_case_ptr->set_device_id(device_id);\n  }\n  return mem_case_ptr;\n}\n\nMemoryCase MakeHostMemCase() {\n  MemoryCase mem_case;\n  mem_case.set_device_type(DeviceType::kCPU);\n  mem_case.set_device_id(0);\n  return mem_case;\n}\n\nbool IsHostMem(const MemoryCase& mem_case) { return mem_case.device_type() == DeviceType::kCPU; }\n\n}  // namespace memory\n\nbool operator==(const MemoryCase& lhs, const MemoryCase& rhs) {\n  return google::protobuf::util::MessageDifferencer::Equals(lhs, rhs);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/memory/memory_case_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_MEMORY_MEMORY_CASE_UTIL_H_\n#define ONEFLOW_CORE_MEMORY_MEMORY_CASE_UTIL_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/memory/memory_case.pb.h\"\n\nnamespace oneflow {\n\nnamespace memory {\n\nbool EqualsIgnorePinnedDevice(const MemoryCase& a, const MemoryCase& b);\nvoid GetPinnedHostMemoryCase(const MemoryCase& mem_case, MemoryCase* ret);\nMemoryCase GetPinnedHostMemoryCase(const MemoryCase& mem_case);\nint64_t GetMemCaseId(const MemoryCase& mem_case);\nint64_t GetUniqueMemCaseId(int64_t machine_id, const MemoryCase& mem_case);\nstd::shared_ptr<MemoryCase> MakeMemCaseShared(const DeviceType device_type,\n                                              const int64_t device_id);\nMemoryCase MakeHostMemCase();\nbool IsHostMem(const MemoryCase& mem_case);\n\n}  // namespace memory\n\nbool operator==(const MemoryCase& lhs, const MemoryCase& rhs);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_MEMORY_MEMORY_CASE_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/memory/memory_zone.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/memory/memory_zone.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr size_t kMemZoneIdDeviceTypeShift = MemZoneId::kDeviceIndexBits;\nconstexpr size_t kMemZoneIdRankShift = kMemZoneIdDeviceTypeShift + MemZoneId::kDeviceTypeBits;\n\nconstexpr int64_t kMemZoneIdRankInt64Mask = ((int64_t{1} << MemZoneId::kRankBits) - 1)\n                                            << kMemZoneIdRankShift;\nconstexpr int64_t kMemZoneIdDeviceTypeInt64Mask = ((int64_t{1} << MemZoneId::kDeviceTypeBits) - 1)\n                                                  << kMemZoneIdDeviceTypeShift;\nconstexpr int64_t kMemZoneIdDeviceIndexInt64Mask = (int64_t{1} << MemZoneId::kDeviceIndexBits) - 1;\n\n}  // namespace\n\nconst MemZoneId kInvalidMemZoneId = MemZoneId{0, DeviceType::kInvalidDevice, 0};\n\nMemZoneId GetNodeCPUMemZoneId(MemZoneId::rank_t node_index) {\n  return MemZoneId{node_index, DeviceType::kCPU, 0};\n}\n\nint64_t EncodeMemZoneIdToInt64(const MemZoneId& mem_zone_id) {\n  int64_t id = static_cast<int64_t>(mem_zone_id.device_index());\n  id |= static_cast<int64_t>(mem_zone_id.device_type()) << kMemZoneIdDeviceTypeShift;\n  id |= static_cast<int64_t>(mem_zone_id.rank()) << kMemZoneIdRankShift;\n  return id;\n}\n\nMemZoneId DecodeMemZoneIdFromInt64(int64_t mem_zone_id) {\n  int64_t rank = (mem_zone_id & kMemZoneIdRankInt64Mask) >> kMemZoneIdRankShift;\n  int64_t device_type = (mem_zone_id & kMemZoneIdDeviceTypeInt64Mask) >> kMemZoneIdDeviceTypeShift;\n  int64_t device_index = mem_zone_id & kMemZoneIdDeviceIndexInt64Mask;\n  return MemZoneId(static_cast<MemZoneId::rank_t>(rank), static_cast<DeviceType>(device_type),\n                   static_cast<MemZoneId::device_index_t>(device_index));\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/memory/memory_zone.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_\n#define ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_\n\n#include \"oneflow/core/device/device_id.h\"\n\nnamespace oneflow {\n\nusing MemZoneId = DeviceId;\n\nint64_t EncodeMemZoneIdToInt64(const MemZoneId&);\nMemZoneId DecodeMemZoneIdFromInt64(int64_t);\n\nMemZoneId GetNodeCPUMemZoneId(MemZoneId::rank_t node_index);\n\nextern const MemZoneId kInvalidMemZoneId;\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/binary_func.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_BINARY_FUNC_H_\n#define ONEFLOW_CORE_NDARRAY_BINARY_FUNC_H_\n\n#include <cstdint>\n#include <climits>\n#include <cfloat>\n#include <cmath>\n\n#if defined(__CUDACC__)\n#include <cuda_fp16.h>\n#endif\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/common/util.h\"\nnamespace oneflow {\n\n#define ARITHMETIC_BINARY_FUNC_NAME_SEQ (Add)(Sub)(Mul)(Div)(Min)(Max)(FloorMod)(FMod)(Pow)\n#define LOGICAL_BINARY_FUNC_NAME_SEQ (EQ)(NE)(GT)(GE)(LT)(LE)(AND)(OR)(XOR)\n\n#define PREPEND_PREFIX_BINARY_FUNC(name) OF_PP_CAT(BinaryFunc, name)\n#define ARITHMETIC_BINARY_FUNC_SEQ \\\n  OF_PP_SEQ_MAP(PREPEND_PREFIX_BINARY_FUNC, ARITHMETIC_BINARY_FUNC_NAME_SEQ)\n#define LOGICAL_BINARY_FUNC_SEQ \\\n  OF_PP_SEQ_MAP(PREPEND_PREFIX_BINARY_FUNC, LOGICAL_BINARY_FUNC_NAME_SEQ)\n\n#define REDUCE_BINARY_FUNC_NAME_SEQ (Sum)(Max)(Min)(Prod)(Any)(All)\n#define ARITHMETIC_REDUCE_BINARY_FUNC_NAME_SEQ (Sum)(Max)(Min)(Prod)\n#define LOGICAL_REDUCE_BINARY_FUNC_NAME_SEQ (Any)(All)\n#define REDUCE_BINARY_FUNC_SEQ \\\n  OF_PP_SEQ_MAP(PREPEND_PREFIX_BINARY_FUNC, REDUCE_BINARY_FUNC_NAME_SEQ)\n#define REDUCE_COMPLEX_BINARY_FUNC_SEQ OF_PP_SEQ_MAP(PREPEND_PREFIX_BINARY_FUNC, (Sum))\n#define ARITHMETIC_REDUCE_BINARY_FUNC_SEQ \\\n  OF_PP_SEQ_MAP(PREPEND_PREFIX_BINARY_FUNC, ARITHMETIC_REDUCE_BINARY_FUNC_NAME_SEQ)\n#define LOGICAL_REDUCE_BINARY_FUNC_SEQ \\\n  OF_PP_SEQ_MAP(PREPEND_PREFIX_BINARY_FUNC, LOGICAL_REDUCE_BINARY_FUNC_NAME_SEQ)\n#define NANSUM_REDUCE_BINARY_FUNC_SEQ OF_PP_SEQ_MAP(PREPEND_PREFIX_BINARY_FUNC, (NanSum))\n\n#define NO_HALF_UTIL_FOUND         \\\n  printf(\"cuda arch must >= 530\"); \\\n  assert(false);                   \\\n  return __float2half(0.0)\ntemplate<template<typename> class BinaryFunc, typename T>\nstruct BinaryFuncTrait final {\n  typedef typename std::remove_const<decltype(\n      BinaryFunc<T>::Invoke(std::declval<const T>(), std::declval<const T>()))>::type return_type;\n};\n\n#define SPECIALIZE_CONST_TYPE_BINARY_FUNC(func_struct)                                        \\\n  template<typename T>                                                                        \\\n  struct func_struct<const T> final {                                                         \\\n    static OF_DEVICE_FUNC const typename BinaryFuncTrait<func_struct, T>::return_type Invoke( \\\n        const T x, const T y) {                                                               \\\n      return func_struct<T>::Invoke(x, y);                                                    \\\n    }                                                                                         \\\n  }\n\ntemplate<typename T>\nstruct BinaryFuncNanSum final {\n  static OF_DEVICE_FUNC T Invoke(const T x, const T y) {\n#if defined(__CUDACC__)\n    if (isnan(x)) return isnan(y) ? T{0} : y;\n    return isnan(y) ? x : x + y;\n#else\n    if (std::isnan(x)) return std::isnan(y) ? 
T{0} : y;\n    return std::isnan(y) ? x : x + y;\n#endif\n  }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncNanSum);\n\ntemplate<typename T>\nstruct BinaryFuncAdd final {\n  static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x + y; }\n};\ntemplate<typename T>\nstruct BinaryFuncSum final {\n  static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return BinaryFuncAdd<T>::Invoke(x, y); }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncAdd);\n\ntemplate<typename T>\nstruct BinaryFuncSub final {\n  static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x - y; }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncSub);\n\ntemplate<typename T>\nstruct BinaryFuncMul final {\n  static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x * y; }\n};\ntemplate<>\nstruct BinaryFuncMul<bool> final {\n  static OF_DEVICE_FUNC bool Invoke(const bool x, const bool y) { return x && y; }\n};\ntemplate<typename T>\nstruct BinaryFuncProd final {\n  static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return BinaryFuncMul<T>::Invoke(x, y); }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncMul);\n\ntemplate<typename T>\nstruct BinaryFuncDiv final {\n  static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x / y; }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncDiv);\n\ntemplate<typename T>\nstruct BinaryFuncFloorMod final {\n  static OF_DEVICE_FUNC T Invoke(const T x, const T y) {\n#if defined(__CUDACC__)\n    T trunc_mod = x % y;\n    return (trunc_mod != T(0)) && ((y < T(0)) != (trunc_mod < T(0))) ? trunc_mod + y : trunc_mod;\n#else\n    T trunc_mod = x % y;\n    return (trunc_mod != T(0)) && ((y < T(0)) != (trunc_mod < T(0))) ? trunc_mod + y : trunc_mod;\n#endif\n  }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncFloorMod);\n\ntemplate<>\nstruct BinaryFuncFloorMod<uint8_t> final {\n  static OF_DEVICE_FUNC uint8_t Invoke(const uint8_t x, const uint8_t y) {\n#if defined(__CUDACC__)\n    uint8_t trunc_mod = x % y;\n    return trunc_mod;\n#else\n    uint8_t trunc_mod = x % y;\n    return trunc_mod;\n#endif\n  }\n};\n\ntemplate<typename T>\nstruct BinaryFuncFMod final {\n  static OF_DEVICE_FUNC T Invoke(const T x, const T y) {\n#if defined(__CUDACC__)\n    T trunc_mod = x % y;\n    return trunc_mod;\n#else\n    T trunc_mod = x % y;\n    return trunc_mod;\n#endif\n  }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncFMod);\n\ntemplate<typename T>\nstruct BinaryFuncPow final {\n  static OF_DEVICE_FUNC const T Invoke(const T x, const T y) {\n#if defined(__CUDACC__)\n    return powf(x, y);\n#else\n    return std::pow(x, y);\n#endif\n  }\n};\n\ntemplate<>\nstruct BinaryFuncPow<bool> final {\n  static OF_DEVICE_FUNC bool Invoke(const bool x, const bool y) {\n#if defined(__CUDACC__)\n    return static_cast<bool>(powf(static_cast<float>(x), static_cast<float>(y)));\n#else\n    return static_cast<bool>(std::pow(static_cast<float>(x), static_cast<float>(y)));\n#endif\n  }\n};\n\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncPow);\n\ntemplate<>\nstruct BinaryFuncPow<float16> final {\n  static inline const float16 Invoke(const float16 x, const float16 y) {\n    return static_cast<float16>(std::pow(static_cast<float>(x), static_cast<float>(y)));\n  }\n};\n\n#if defined(__CUDACC__)\ntemplate<>\nstruct BinaryFuncPow<double> final {\n  static OF_DEVICE_FUNC double Invoke(const double x, const double y) { return pow(x, y); }\n};\n\ntemplate<>\nstruct BinaryFuncPow<float> final {\n  static __device__ __forceinline__ float Invoke(const float x, const float y) {\n    return powf(x, y);\n  
}\n};\n\ntemplate<>\nstruct BinaryFuncPow<half> final {\n  static __device__ __forceinline__ half Invoke(const half x, const half y) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530\n    return __float2half(powf(__half2float(x), __half2float(y)));\n#else\n    NO_HALF_UTIL_FOUND;\n#endif\n  }\n};\n#endif  // defined(__CUDACC__)\n\ntemplate<typename T>\nstruct BinaryFuncFloorDiv final {\n  static OF_DEVICE_FUNC T Invoke(const T x, const T y) {\n#if defined(__CUDACC__)\n    return floor(fdividef(x, y));\n#else\n    return std::floor(x / y);\n#endif\n  }\n};\n\ntemplate<typename T>\nstruct BinaryFuncMax final {\n  static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x > y ? x : y; }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncMax);\n\ntemplate<typename T>\nstruct BinaryFuncMin final {\n  static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x < y ? x : y; }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncMin);\n\ntemplate<typename T>\nstruct BinaryFuncEQ final {\n  static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x == y; }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncEQ);\n\ntemplate<typename T>\nstruct BinaryFuncNE final {\n  static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x != y; }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncNE);\n\ntemplate<typename T>\nstruct BinaryFuncGT final {\n  static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x > y; }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncGT);\n\ntemplate<typename T>\nstruct BinaryFuncGE final {\n  static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x >= y; }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncGE);\n\ntemplate<typename T>\nstruct BinaryFuncLT final {\n  static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x < y; }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncLT);\n\ntemplate<typename T>\nstruct BinaryFuncLE final {\n  static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x <= y; }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncLE);\n\ntemplate<typename T>\nstruct BinaryFuncAND final {\n  static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x && y; }\n};\ntemplate<typename T>\nstruct BinaryFuncAll final {\n  static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return BinaryFuncAND<T>::Invoke(x, y); }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncAND);\n\ntemplate<typename T>\nstruct BinaryFuncOR final {\n  static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x || y; }\n};\ntemplate<typename T>\nstruct BinaryFuncAny final {\n  static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return BinaryFuncOR<T>::Invoke(x, y); }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncOR);\n\ntemplate<typename T>\nstruct BinaryFuncXOR final {\n  static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return (!x) != (!y); }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncXOR);\n\ntemplate<typename T>\nstruct BinaryFuncBitwiseAnd final {\n  static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x & y; }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncBitwiseAnd);\n\ntemplate<typename T>\nstruct BinaryFuncBitwiseOr final {\n  static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x | y; }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncBitwiseOr);\n\ntemplate<typename T>\nstruct BinaryFuncBitwiseXor final {\n  static OF_DEVICE_FUNC T Invoke(const T x, const T y) { return x ^ y; }\n};\nSPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncBitwiseXor);\n\n#if 
defined(__CUDACC__)\ntemplate<>\nstruct BinaryFuncAdd<half> final {\n  static __device__ __forceinline__ half Invoke(const half x, const half y) { return __hadd(x, y); }\n};\n\ntemplate<>\nstruct BinaryFuncNanSum<half> final {\n  static __device__ __forceinline__ half Invoke(const half x, const half y) {\n    if (isnan(__half2float(x))) return isnan(__half2float(y)) ? half(0.0) : y;\n    return isnan(__half2float(y)) ? x : __hadd(x, y);\n  }\n};\n\ntemplate<>\nstruct BinaryFuncSub<half> final {\n  static __device__ __forceinline__ half Invoke(const half x, const half y) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530\n    return __hsub(x, y);\n#else\n    NO_HALF_UTIL_FOUND;\n#endif\n  }\n};\n\ntemplate<>\nstruct BinaryFuncMul<half> final {\n  static __device__ __forceinline__ half Invoke(const half x, const half y) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530\n    return __hmul(x, y);\n#else\n    NO_HALF_UTIL_FOUND;\n#endif\n  }\n};\n\ntemplate<>\nstruct BinaryFuncDiv<half> final {\n  static __device__ __forceinline__ half Invoke(const half x, const half y) {\n#if __CUDA_ARCH__ >= 530\n    return __hdiv(x, y);\n#else\n    NO_HALF_UTIL_FOUND;\n#endif\n  }\n};\n\ntemplate<>\nstruct BinaryFuncMax<half> final {\n  static __device__ __forceinline__ half Invoke(const half x, const half y) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530\n    return __hgt(x, y) ? x : y;\n#else\n    NO_HALF_UTIL_FOUND;\n#endif\n  }\n};\n\ntemplate<>\nstruct BinaryFuncMin<half> final {\n  static __device__ __forceinline__ half Invoke(const half x, const half y) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530\n    return __hlt(x, y) ? x : y;\n#else\n    NO_HALF_UTIL_FOUND;\n#endif\n  }\n};\n\ntemplate<>\nstruct BinaryFuncAdd<cuComplex> final {\n  static __device__ __forceinline__ cuComplex Invoke(const cuComplex x, const cuComplex y) {\n    return cuComplex{x.x + y.x, x.y + y.y};\n  }\n};\n\ntemplate<>\nstruct BinaryFuncSub<cuComplex> final {\n  static __device__ __forceinline__ cuComplex Invoke(const cuComplex x, const cuComplex y) {\n    return cuComplex{x.x - y.x, x.y - y.y};\n  }\n};\n\ntemplate<>\nstruct BinaryFuncMul<cuComplex> final {\n  static __device__ __forceinline__ cuComplex Invoke(const cuComplex x, const cuComplex y) {\n    return cuCmulf(x, y);\n  }\n};\n\ntemplate<>\nstruct BinaryFuncAdd<cuDoubleComplex> final {\n  static __device__ __forceinline__ cuDoubleComplex Invoke(const cuDoubleComplex x,\n                                                           const cuDoubleComplex y) {\n    return cuDoubleComplex{x.x + y.x, x.y + y.y};\n  }\n};\n\ntemplate<>\nstruct BinaryFuncSub<cuDoubleComplex> final {\n  static __device__ __forceinline__ cuDoubleComplex Invoke(const cuDoubleComplex x,\n                                                           const cuDoubleComplex y) {\n    return cuDoubleComplex{x.x - y.x, x.y - y.y};\n  }\n};\n\ntemplate<>\nstruct BinaryFuncMul<cuDoubleComplex> final {\n  static __device__ __forceinline__ cuDoubleComplex Invoke(const cuDoubleComplex x,\n                                                           const cuDoubleComplex y) {\n    return cuCmul(x, y);\n  }\n};\n\n#endif  // defined(__CUDACC__)\n\n#if defined(__CUDACC__)\n\ntemplate<>\nstruct BinaryFuncFloorMod<float> final {\n  static __device__ __forceinline__ float Invoke(const float x, const float y) {\n    const float trunc_mod = fmodf(x, y);\n    return (trunc_mod != 0) && ((y < 0) != (trunc_mod < 0)) ? 
trunc_mod + y : trunc_mod;\n  }\n};\n\ntemplate<>\nstruct BinaryFuncFloorMod<double> final {\n  static __device__ __forceinline__ double Invoke(const double x, const double y) {\n    const double trunc_mod = fmod(x, y);\n    return (trunc_mod != 0) && ((y < 0) != (trunc_mod < 0)) ? trunc_mod + y : trunc_mod;\n  }\n};\n\ntemplate<>\nstruct BinaryFuncFloorMod<half> final {\n  static __device__ __forceinline__ half Invoke(const half x, const half y) {\n#if __CUDA_ARCH__ >= 530\n    const half trunc_mod = __float2half(fmodf(__half2float(x), __half2float(y)));\n    return __hne(trunc_mod, GetZeroVal<half>())\n                   && __hlt(y, GetZeroVal<half>()) != __hlt(trunc_mod, half(0))\n               ? trunc_mod + y\n               : trunc_mod;\n#else\n    NO_HALF_UTIL_FOUND;\n#endif\n  }\n};\n\n#else\n\ntemplate<>\nstruct BinaryFuncFloorMod<float> final {\n  static inline float Invoke(const float x, const float y) {\n    const float trunc_mod = std::fmod(x, y);\n    return (trunc_mod != 0) && ((y < 0) != (trunc_mod < 0)) ? trunc_mod + y : trunc_mod;\n  }\n};\n\ntemplate<>\nstruct BinaryFuncFloorMod<double> final {\n  static inline double Invoke(const double x, const double y) {\n    const double trunc_mod = std::fmod(x, y);\n    return (trunc_mod != 0) && ((y < 0) != (trunc_mod < 0)) ? trunc_mod + y : trunc_mod;\n  }\n};\n\ntemplate<>\nstruct BinaryFuncFloorMod<float16> final {\n  static inline float16 Invoke(const float16 x, const float16 y) {\n    const float trunc_mod = std::fmod(static_cast<float>(x), static_cast<float>(y));\n    return (trunc_mod != float(0)) && ((y < float(0)) != (trunc_mod < float(0)))\n               ? static_cast<float16>(trunc_mod + y)\n               : static_cast<float16>(trunc_mod);\n  }\n};\n\n#endif  // defined(__CUDACC__)\n\n#if defined(__CUDACC__)\n\ntemplate<>\nstruct BinaryFuncFMod<float> final {\n  static __device__ __forceinline__ float Invoke(const float x, const float y) {\n    const float trunc_mod = fmodf(x, y);\n    return trunc_mod;\n  }\n};\n\ntemplate<>\nstruct BinaryFuncFMod<double> final {\n  static __device__ __forceinline__ double Invoke(const double x, const double y) {\n    const double trunc_mod = fmod(x, y);\n    return trunc_mod;\n  }\n};\n\ntemplate<>\nstruct BinaryFuncFMod<half> final {\n  static __device__ __forceinline__ half Invoke(const half x, const half y) {\n#if __CUDA_ARCH__ >= 530\n    const half trunc_mod = __float2half(fmodf(__half2float(x), __half2float(y)));\n    return trunc_mod;\n#else\n    NO_HALF_UTIL_FOUND;\n#endif\n  }\n};\n#else\ntemplate<>\nstruct BinaryFuncFMod<float> final {\n  static inline float Invoke(const float x, const float y) {\n    const float trunc_mod = std::fmod(x, y);\n    return trunc_mod;\n  }\n};\n\ntemplate<>\nstruct BinaryFuncFMod<double> final {\n  static inline double Invoke(const double x, const double y) {\n    const double trunc_mod = std::fmod(x, y);\n    return trunc_mod;\n  }\n};\n\ntemplate<>\nstruct BinaryFuncFMod<float16> final {\n  static inline float16 Invoke(const float16 x, const float16 y) {\n    const float trunc_mod = std::fmod(static_cast<float>(x), static_cast<float>(y));\n    return static_cast<float16>(trunc_mod);\n  }\n};\n\n#endif  // defined(__CUDACC__)\n\n#if defined(__CUDACC__)\n\ntemplate<>\nstruct BinaryFuncFloorDiv<uint8_t> final {\n  static __device__ __forceinline__ uint8_t Invoke(uint8_t x, uint8_t y) { return x / y; }\n};\n\ntemplate<>\nstruct BinaryFuncFloorDiv<int8_t> final {\n  static __device__ __forceinline__ int8_t Invoke(int8_t x, int8_t y) { return x / 
y; }\n};\n\ntemplate<>\nstruct BinaryFuncFloorDiv<int32_t> final {\n  static __device__ __forceinline__ int32_t Invoke(int32_t x, int32_t y) { return x / y; }\n};\n\ntemplate<>\nstruct BinaryFuncFloorDiv<int64_t> final {\n  static __device__ __forceinline__ int64_t Invoke(int64_t x, int64_t y) { return x / y; }\n};\n\ntemplate<>\nstruct BinaryFuncFloorDiv<half> final {\n  static __device__ __forceinline__ half Invoke(const half x, const half y) {\n#if __CUDA_ARCH__ >= 530\n    return __float2half(floor(fdividef(__half2float(x), __half2float(y))));\n#else\n    NO_HALF_UTIL_FOUND;\n#endif\n  }\n};\n#else\ntemplate<>\nstruct BinaryFuncFloorDiv<float16> final {\n  static inline float16 Invoke(float16 x, float16 y) {\n    return static_cast<float16>(std::floor(static_cast<float>(x) / static_cast<float>(y)));\n  }\n};\n\n#endif  // defined(__CUDACC__)\ntemplate<typename T, template<typename> class binary_func>\nstruct UnitOfBinaryFunc;\n\n#define SPECIALIZE_UNIT_OF_BINARY_FUNC(binary_func, get_val) \\\n  template<typename T>                                       \\\n  struct UnitOfBinaryFunc<T, binary_func> final {            \\\n    static OF_DEVICE_FUNC T Val() { return get_val<T>(); }   \\\n  };\nSPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncAdd, GetZeroVal);\nSPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncNanSum, GetZeroVal);\nSPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncSum, GetZeroVal);\nSPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncMul, GetOneVal);\nSPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncProd, GetOneVal);\nSPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncMax, GetMinVal);\nSPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncMin, GetMaxVal);\nSPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncAny, GetZeroVal);\nSPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncAll, GetOneVal);\n#undef SPECIALIZE_UNIT_OF_BINARY_FUNC\n\n/*\nThese placeholder specializations are used for `GetBinaryBroadcastSbpSignature` in\noneflow/user/ops/math_binary_broadcast_ops.cpp\n*/\n#define SPECIALIZE_FOR_SBP(binary_func) \\\n  template<typename T>                  \\\n  struct binary_func final {};\n\nSPECIALIZE_FOR_SBP(BinaryFuncIEN);\nSPECIALIZE_FOR_SBP(BinaryFuncINN);\nSPECIALIZE_FOR_SBP(BinaryFuncZeta);\n#undef SPECIALIZE_FOR_SBP\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_BINARY_FUNC_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/cpu_concat_var_ndarray.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_CPU_CONCAT_VAR_NDARRAY_H_\n#define ONEFLOW_CORE_NDARRAY_CPU_CONCAT_VAR_NDARRAY_H_\n\n#include \"oneflow/core/ndarray/cpu_ndarray.h\"\n#include \"oneflow/core/ndarray/cpu_var_ndarray.h\"\n#include \"oneflow/core/common/range.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, int NDIMS, int CONCAT_AXES>\nclass CpuConcatVarNdarray : public CpuNdarray<T, NDIMS> {\n public:\n  static const bool immutable = false;\n  static_assert(CONCAT_AXES >= 0 && CONCAT_AXES < NDIMS, \"CONCAT_AXES should be a valid dim\");\n  CpuConcatVarNdarray(const std::vector<CpuVarNdarray<T, NDIMS>>& var_ndarrays)\n      : CpuNdarray<T, NDIMS>(CalcConcatenatedShape(var_ndarrays)),\n        var_ndarrays_(var_ndarrays),\n        dim_ranges_(CalcDimRanges(var_ndarrays)),\n        contiguous_lens_(CalcContiguousLens(var_ndarrays)) {}\n  ~CpuConcatVarNdarray() = default;\n\n  template<typename XT>\n  void CopyFrom(const XT& ndarray) {\n    CpuNdarrayCopy(this, ndarray);\n  }\n  void GetMutPtrAndContiguousSize(int64_t offset, T** ptr, size_t* size) const {\n    int64_t dim[NDIMS] = {0};\n    this->xpu_shape().template Offset2Coordinate<NDIMS>(offset, dim);\n    int32_t var_index = 0;\n    this->GetVarNdarrayIndexAndInputDim(dim[CONCAT_AXES], &var_index, &dim[CONCAT_AXES]);\n    int64_t input_offset =\n        this->var_ndarray(var_index).xpu_shape().template Coordinate2Offset<NDIMS>(dim);\n    this->GetMutPtrAndMinContiguousSize(var_index, input_offset, ptr, size);\n  }\n\n protected:\n  ALWAYS_INLINE void GetVarNdarrayIndexAndInputDim(int64_t output_dim, int32_t* var_index,\n                                                   int64_t* input_dim) const {\n    *var_index = CpuVarNdarrayIndex4OutputDim(output_dim);\n    *input_dim = output_dim - dim_ranges_[*var_index].begin();\n  }\n  ALWAYS_INLINE const CpuVarNdarray<T, NDIMS> var_ndarray(int32_t var_index) const {\n    return var_ndarrays_[var_index];\n  }\n  ALWAYS_INLINE void GetMutPtrAndMinContiguousSize(int32_t var_index, int64_t var_offset, T** ptr,\n                                                   size_t* size) const {\n    size_t var_contiguous_size = 0;\n    var_ndarray(var_index).GetMutPtrAndContiguousSize(var_offset, ptr, &var_contiguous_size);\n    *size = std::min(var_contiguous_size,\n                     static_cast<size_t>(contiguous_lens_[var_index]\n                                         - var_offset % contiguous_lens_[var_index]));\n  }\n\n private:\n  ALWAYS_INLINE int32_t CpuVarNdarrayIndex4OutputDim(int64_t output_dim) const {\n    // TODO change to bianry search\n    FOR_RANGE(int32_t, i, 0, dim_ranges_.size()) {\n      if (output_dim >= dim_ranges_[i].begin() && output_dim < dim_ranges_[i].end()) { return i; }\n    }\n    UNIMPLEMENTED();\n  }\n  XpuShape CalcConcatenatedShape(const std::vector<CpuVarNdarray<T, NDIMS>>& var_ndarrays) const {\n    CheckInputShape(var_ndarrays);\n    XpuShape 
xpu_shape(var_ndarrays[0].xpu_shape());\n    int64_t axes_dim_num = 0;\n    FOR_RANGE(int32_t, i, 0, var_ndarrays.size()) {\n      axes_dim_num += var_ndarrays[i].xpu_shape().At(CONCAT_AXES);\n    }\n    xpu_shape.Set(CONCAT_AXES, axes_dim_num);\n    return xpu_shape;\n  }\n  void CheckInputShape(const std::vector<CpuVarNdarray<T, NDIMS>>& var_ndarrays) const {\n    FOR_RANGE(int32_t, i, 1, var_ndarrays.size()) {\n      FOR_RANGE(int32_t, j, 0, NDIMS) {\n        if (j == CONCAT_AXES) { continue; }\n        CHECK_EQ(var_ndarrays[0].xpu_shape().At(j), var_ndarrays[i].xpu_shape().At(j));\n      }\n    }\n  }\n  std::vector<Range> CalcDimRanges(const std::vector<CpuVarNdarray<T, NDIMS>>& var_ndarrays) const {\n    int64_t axes_dim_num = 0;\n    std::vector<Range> ret;\n    FOR_RANGE(int32_t, i, 0, var_ndarrays.size()) {\n      ret.emplace_back(\n          Range(axes_dim_num, axes_dim_num + var_ndarrays[i].xpu_shape().At(CONCAT_AXES)));\n      axes_dim_num += var_ndarrays[i].xpu_shape().At(CONCAT_AXES);\n    }\n    return ret;\n  }\n  std::vector<size_t> CalcContiguousLens(\n      const std::vector<CpuVarNdarray<T, NDIMS>>& var_ndarrays) const {\n    std::vector<size_t> ret(var_ndarrays.size(), 0);\n    FOR_RANGE(int32_t, i, 0, var_ndarrays.size()) {\n      ret[i] = var_ndarrays[i].xpu_shape().Count(CONCAT_AXES);\n    }\n    return ret;\n  }\n  const std::vector<CpuVarNdarray<T, NDIMS>> var_ndarrays_;\n  const std::vector<Range> dim_ranges_;\n  const std::vector<size_t> contiguous_lens_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_CPU_CONCAT_VAR_NDARRAY_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/cpu_concat_var_ndarray_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/cpu_ndarray_builder.h\"\n#include <gtest/gtest.h>\n\nnamespace oneflow {\n\nnamespace test {\n\nTEST(CpuConcatVarNdarray, two_elem_concat) {\n  std::vector<int32_t> x0_data{0};\n  std::vector<int32_t> x1_data{1};\n  std::vector<int32_t> buffer{-1, -1};\n  std::vector<int32_t> expected{0, 1};\n  CpuNdarrayBuilder<int32_t, 1> ndarray;\n  auto x0 = ndarray.Var(Shape{1LL}, x0_data.data());\n  auto x1 = ndarray.Var(Shape{1LL}, x1_data.data());\n  ndarray.Var(Shape{2LL}, buffer.data()).CopyFrom(ndarray.Concatenate({x0, x1}));\n  ASSERT_EQ(memcmp(buffer.data(), expected.data(), sizeof(int32_t) * 2), 0);\n}\n\nTEST(CpuConcatVarNdarray, two_elem_concat_assign) {\n  std::vector<int32_t> x0_data{-1};\n  std::vector<int32_t> x1_data{-1};\n  std::vector<int32_t> buffer{0, 1};\n  CpuNdarrayBuilder<int32_t, 1> ndarray;\n  auto x0 = ndarray.Var(Shape{1LL}, x0_data.data());\n  auto x1 = ndarray.Var(Shape{1LL}, x1_data.data());\n  ndarray.Concatenate({x0, x1}).CopyFrom(ndarray.Var(Shape{2LL}, buffer.data()));\n  ASSERT_EQ(x0_data[0], 0);\n  ASSERT_EQ(x1_data[0], 1);\n}\n\nTEST(CpuConcatVarNdarray, 2d_concat) {\n  // clang-format off\n std::vector<int32_t> x0_data{\n   0, 1, 2,\n   5, 6, 7,\n };\n std::vector<int32_t> x1_data{\n            3, 4,\n            8, 9,\n };\n std::vector<int32_t> expected{\n   0, 1, 2, 3, 4,\n   5, 6, 7, 8, 9,\n };\n std::vector<int32_t> buffer(10, -1);\n  // clang-format on\n  CpuNdarrayBuilder<int32_t, 2> ndarray;\n  auto x0 = ndarray.Var(Shape{2LL, 3LL}, x0_data.data());\n  auto x1 = ndarray.Var(Shape{2LL, 2LL}, x1_data.data());\n  ndarray.Var(Shape{2LL, 5LL}, buffer.data()).CopyFrom(ndarray.template Concatenate<1>({x0, x1}));\n  ASSERT_EQ(memcmp(buffer.data(), expected.data(), sizeof(int32_t) * 10), 0);\n}\n\nTEST(CpuConcatVarNdarray, 2d_concat_assign) {\n  // clang-format off\n std::vector<int32_t> x_data{\n   0, 1, 2, 3, 4,\n   5, 6, 7, 8, 9,\n };\n std::vector<int32_t> y0_buffer(6, -1);\n std::vector<int32_t> y1_buffer(4, -1);\n std::vector<int32_t> y0_expected{\n   0, 1, 2,\n   5, 6, 7,\n };\n std::vector<int32_t> y1_expected{\n            3, 4,\n            8, 9,\n };\n  // clang-format on\n  CpuNdarrayBuilder<int32_t, 2> ndarray;\n  auto x = ndarray.Var(Shape{2LL, 5LL}, x_data.data());\n  auto y0 = ndarray.Var(Shape{2LL, 3LL}, y0_buffer.data());\n  auto y1 = ndarray.Var(Shape{2LL, 2LL}, y1_buffer.data());\n  ndarray.template Concatenate<1>({y0, y1}).CopyFrom(x);\n  ASSERT_EQ(memcmp(y0_buffer.data(), y0_expected.data(), sizeof(int32_t) * 6), 0);\n  ASSERT_EQ(memcmp(y1_buffer.data(), y1_expected.data(), sizeof(int32_t) * 4), 0);\n}\n\nTEST(CpuConcatVarNdarray, 3d_concat) {\n  // clang-format off\n std::vector<int32_t> x0_data{\n   0, 1, 2,\n   5, 6, 7,\n\n   10,11,12,\n   15,16,17 \n };\n std::vector<int32_t> x1_data{\n            3, 4,\n            8, 9,\n\t      \n            13,14,\n            18,19,\n };\n 
std::vector<int32_t> expected{\n   0, 1, 2, 3, 4,\n   5, 6, 7, 8, 9,\n     \n   10,11,12,13,14,\n   15,16,17,18,19,\n };\n std::vector<int32_t> buffer(20, -1);\n  // clang-format on\n  CpuNdarrayBuilder<int32_t, 3> ndarray;\n  auto x0 = ndarray.Var(Shape{2LL, 2LL, 3LL}, x0_data.data());\n  auto x1 = ndarray.Var(Shape{2LL, 2LL, 2LL}, x1_data.data());\n  ndarray.Var(Shape{2LL, 2LL, 5LL}, buffer.data())\n      .CopyFrom(ndarray.template Concatenate<2>({x0, x1}));\n  ASSERT_EQ(memcmp(buffer.data(), expected.data(), sizeof(int32_t) * 20), 0);\n}\n\nTEST(CpuConcatVarNdarray, 3d_concat_assign) {\n  // clang-format off\n std::vector<int32_t> x_data{\n   0, 1, 2, 3, 4,\n   5, 6, 7, 8, 9,\n     \n   10,11,12,13,14,\n   15,16,17,18,19,\n };\n std::vector<int32_t> y0_expected{\n   0, 1, 2,\n   5, 6, 7,\n\n   10,11,12,\n   15,16,17 \n };\n std::vector<int32_t> y1_expected{\n            3, 4,\n            8, 9,\n     \n            13,14,\n            18,19,\n };\n std::vector<int32_t> y0_buffer(2*2*3, -1);\n std::vector<int32_t> y1_buffer(2*2*2, -1);\n  // clang-format on\n  CpuNdarrayBuilder<int32_t, 3> ndarray;\n  auto x = ndarray.Var(Shape{2LL, 2LL, 5LL}, x_data.data());\n  auto y0 = ndarray.Var(Shape{2LL, 2LL, 3LL}, y0_buffer.data());\n  auto y1 = ndarray.Var(Shape{2LL, 2LL, 2LL}, y1_buffer.data());\n  ndarray.template Concatenate<2>({y0, y1}).CopyFrom(x);\n  ASSERT_EQ(memcmp(y0_buffer.data(), y0_expected.data(), sizeof(int32_t) * y0_expected.size()), 0);\n  ASSERT_EQ(memcmp(y1_buffer.data(), y1_expected.data(), sizeof(int32_t) * y1_expected.size()), 0);\n}\n\n}  // namespace test\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/cpu_ndarray.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_H_\n#define ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_H_\n\n#include <climits>\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/ndarray/xpu_shape.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, int NDIMS>\nclass CpuNdarray {\n public:\n  using dtype = T;\n  static const int ndims = NDIMS;\n\n  ALWAYS_INLINE const XpuShape& xpu_shape() const { return xpu_shape_; }\n\n protected:\n  explicit CpuNdarray(const Shape& shape) : xpu_shape_(shape) {}\n  explicit CpuNdarray(const XpuShape& xpu_shape) : xpu_shape_(xpu_shape) {}\n  virtual ~CpuNdarray() = default;\n\n private:\n  XpuShape xpu_shape_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/cpu_ndarray_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_HELPER_H_\n#define ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_HELPER_H_\n\n#include \"oneflow/core/ndarray/cpu_var_ndarray.h\"\n#include \"oneflow/core/ndarray/cpu_slice_var_ndarray.h\"\n#include \"oneflow/core/ndarray/cpu_concat_var_ndarray.h\"\n\nnamespace oneflow {\n\ntemplate<typename default_data_type, int default_ndims>\nclass CpuNdarrayBuilder final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuNdarrayBuilder);\n  CpuNdarrayBuilder() = default;\n  ~CpuNdarrayBuilder() = default;\n\n  template<typename T = default_data_type, int NDIMS = default_ndims>\n  CpuVarNdarray<T, NDIMS> Var(const Shape& shape, T* ptr) const {\n    return CpuVarNdarray<T, NDIMS>(shape, ptr);\n  }\n  template<typename T = default_data_type, int NDIMS = default_ndims>\n  CpuVarNdarray<T, NDIMS> Var(const ShapeView& shape_view, T* ptr) const {\n    return CpuVarNdarray<T, NDIMS>(shape_view, ptr);\n  }\n  template<int CONCAT_AXES = 0, typename T = default_data_type, int NDIMS = default_ndims>\n  CpuConcatVarNdarray<T, NDIMS, CONCAT_AXES> Concatenate(\n      const std::vector<CpuVarNdarray<T, NDIMS>>& var_ndarrays) const {\n    return CpuConcatVarNdarray<T, NDIMS, CONCAT_AXES>(var_ndarrays);\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_HELPER_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/cpu_ndarray_copy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_COPY_H_\n#define ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_COPY_H_\n\n#include \"oneflow/core/ndarray/cpu_ndarray.h\"\n\nnamespace oneflow {\n\ntemplate<typename YT, typename XT, typename T = typename YT::dtype>\nvoid CpuNdarrayCopy(YT* y_ndarray, const XT& x_ndarray) {\n  CHECK_EQ(y_ndarray->xpu_shape().ElemNum(), x_ndarray.xpu_shape().ElemNum());\n  T* dst_ptr = nullptr;\n  size_t dst_size = 0;\n  T* src_ptr = nullptr;\n  size_t src_size = 0;\n  int64_t cur_index = 0;\n  size_t total_elem_cnt = y_ndarray->xpu_shape().ElemNum();\n  while (cur_index < total_elem_cnt) {\n    if (dst_size == 0) { y_ndarray->GetMutPtrAndContiguousSize(cur_index, &dst_ptr, &dst_size); }\n    if (src_size == 0) { x_ndarray.GetMutPtrAndContiguousSize(cur_index, &src_ptr, &src_size); }\n    if (src_size == 0) { break; }\n    size_t cp_size = std::min(dst_size, src_size);\n    if (cp_size == 1) {\n      *dst_ptr = *src_ptr;\n    } else {\n      memcpy(dst_ptr, src_ptr, sizeof(T) * cp_size);\n    }\n    dst_ptr += cp_size;\n    src_ptr += cp_size;\n    dst_size -= cp_size;\n    src_size -= cp_size;\n    cur_index += cp_size;\n  }\n  CHECK_EQ(dst_size, 0);\n  CHECK_EQ(src_size, 0);\n  CHECK_EQ(cur_index, total_elem_cnt);\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_COPY_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/cpu_slice_var_ndarray.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_CPU_SLICE_NDARRAY_H_\n#define ONEFLOW_CORE_NDARRAY_CPU_SLICE_NDARRAY_H_\n\n#include \"oneflow/core/ndarray/slice.h\"\n#include \"oneflow/core/ndarray/cpu_ndarray.h\"\n#include \"oneflow/core/ndarray/cpu_ndarray_copy.h\"\n\nnamespace oneflow {\n\ntemplate<typename XT>\nclass CpuSliceVarNdarray : public CpuNdarray<typename XT::dtype, XT::ndims> {\n public:\n  CpuSliceVarNdarray(XT&& x, std::array<Slice, XT::ndims>&& slices)\n      : CpuNdarray<typename XT::dtype, XT::ndims>(\n          BoundedSlices2Shape(BoundSlices(x, std::move(slices)))),\n        x_(x),\n        slices_(std::move(slices)) {\n    SetContiguousLength(slices);\n  }\n  virtual ~CpuSliceVarNdarray() = default;\n\n  CpuSliceVarNdarray<CpuSliceVarNdarray<XT>> operator()(Slice&& slice0) {\n    static_assert(XT::ndims == 1, \"NDIMS error\");\n    return CpuSliceVarNdarray<CpuSliceVarNdarray<XT>>(std::move(*this), {slice0});\n  }\n  CpuSliceVarNdarray<CpuSliceVarNdarray<XT>> operator()(Slice&& slice0, Slice&& slice1) {\n    static_assert(XT::ndims == 2, \"NDIMS error\");\n    return CpuSliceVarNdarray<CpuSliceVarNdarray<XT>>(std::move(*this), {slice0, slice1});\n  }\n  CpuSliceVarNdarray<CpuSliceVarNdarray<XT>> operator()(Slice&& slice0, Slice&& slice1,\n                                                        Slice&& slice2) {\n    static_assert(XT::ndims == 3, \"NDIMS error\");\n    return CpuSliceVarNdarray<CpuSliceVarNdarray<XT>>(std::move(*this), {slice0, slice1, slice2});\n  }\n  CpuSliceVarNdarray<CpuSliceVarNdarray<XT>> operator()(Slice&& slice0, Slice&& slice1,\n                                                        Slice&& slice2, Slice&& slice3) {\n    static_assert(XT::ndims == 4, \"NDIMS error\");\n    return CpuSliceVarNdarray<CpuSliceVarNdarray<XT>>(std::move(*this),\n                                                      {slice0, slice1, slice2, slice3});\n  }\n  CpuSliceVarNdarray<CpuSliceVarNdarray<XT>> operator()(Slice&& slice0, Slice&& slice1,\n                                                        Slice&& slice2, Slice&& slice3,\n                                                        Slice&& slice4) {\n    static_assert(XT::ndims == 5, \"NDIMS error\");\n    return CpuSliceVarNdarray<CpuSliceVarNdarray<XT>>(std::move(*this),\n                                                      {slice0, slice1, slice2, slice3, slice4});\n  }\n\n  template<typename AT>\n  void CopyFrom(const AT& ndarray) {\n    CpuNdarrayCopy(this, ndarray);\n  }\n\n  ALWAYS_INLINE void GetMutPtrAndContiguousSize(int64_t offset, typename XT::dtype** ptr,\n                                                size_t* size) const {\n    int64_t dim[XT::ndims] = {0};\n    this->xpu_shape().template Offset2Coordinate<XT::ndims>(offset, dim);\n    for (int i = 0; i < XT::ndims; ++i) { dim[i] = this->slice(i).Get(dim[i]); }\n    size_t x_offset = this->x().xpu_shape().template Coordinate2Offset<XT::ndims>(dim);\n   
 this->GetMutPtrAndMinContiguousSize(offset, x_offset, ptr, size);\n  }\n\n protected:\n  ALWAYS_INLINE const XT& x() const { return x_; }\n  ALWAYS_INLINE const Slice& slice(int32_t dim) const { return slices_[dim]; }\n  ALWAYS_INLINE void GetMutPtrAndMinContiguousSize(int64_t offset, int64_t x_offset,\n                                                   typename XT::dtype** ptr, size_t* size) const {\n    size_t x_contiguous_size = 0;\n    this->x().GetMutPtrAndContiguousSize(x_offset, ptr, &x_contiguous_size);\n    size_t slice_contiguous_size = (contiguous_len_ - offset % contiguous_len_);\n    *size = std::min(x_contiguous_size, slice_contiguous_size);\n  }\n\n private:\n  static std::array<Slice, XT::ndims>&& BoundSlices(const XT& x,\n                                                    std::array<Slice, XT::ndims>&& slices) {\n    FOR_RANGE(int32_t, i, 0, XT::ndims) { slices[i].Bound(x.xpu_shape().At(i)); }\n    return std::move(slices);\n  }\n  static Shape BoundedSlices2Shape(const std::array<Slice, XT::ndims>& bounded_slices) {\n    DimVector dim_vec;\n    for (const Slice& slice : bounded_slices) {\n      CHECK_GT(slice.Size(), 0);\n      dim_vec.emplace_back(slice.Size());\n    }\n    return Shape(dim_vec);\n  }\n  void SetContiguousLength(const std::array<Slice, XT::ndims>& bounded_slices) {\n    contiguous_len_ = 1;\n    for (int i = XT::ndims - 1; i >= 0; --i) {\n      if (bounded_slices[i].IsContiguous()) { contiguous_len_ *= bounded_slices[i].Size(); }\n      if (!(bounded_slices[i].IsContiguous() && bounded_slices[i].IsCoveringAll())) { break; }\n    }\n  }\n  const XT& x_;\n  std::array<Slice, XT::ndims> slices_;\n  size_t contiguous_len_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_CPU_SLICE_NDARRAY_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/cpu_slice_var_ndarray_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/cpu_ndarray_builder.h\"\n#include <gtest/gtest.h>\n\nnamespace oneflow {\n\nnamespace test {\n\nTEST(CpuSliceVarNdarray, one_elem_assign) {\n  std::vector<int32_t> data({1});\n  std::vector<int32_t> buffer({0});\n  CpuNdarrayBuilder<int32_t, 1> ndarray;\n  auto&& data_ndarray = ndarray.Var(Shape{1LL}, data.data());\n  auto&& buffer_ndarray = ndarray.Var(Shape{1LL}, buffer.data());\n  buffer_ndarray(0).CopyFrom(data_ndarray(0));\n  ASSERT_EQ(data[0], buffer[0]);\n}\n\nTEST(CpuSliceVarNdarray, one_elem_assign_slice_on_slice) {\n  std::vector<int32_t> data({1});\n  std::vector<int32_t> buffer({0});\n  CpuNdarrayBuilder<int32_t, 1> ndarray;\n  auto&& data_ndarray = ndarray.Var(Shape{1LL}, data.data());\n  auto&& buffer_ndarray = ndarray.Var(Shape{1LL}, buffer.data());\n  buffer_ndarray(0)(0).CopyFrom(data_ndarray(0)(0));\n  ASSERT_EQ(data[0], buffer[0]);\n}\n\nTEST(CpuSliceVarNdarray, 1d_assign) {\n  std::vector<int32_t> data({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});\n  std::vector<int32_t> buffer(10, 0);\n  CpuNdarrayBuilder<int32_t, 1> ndarray;\n  auto&& data_ndarray = ndarray.Var(Shape{10LL}, data.data());\n  auto&& buffer_ndarray = ndarray.Var(Shape{10LL}, buffer.data());\n  buffer_ndarray({}).CopyFrom(data_ndarray({}));\n  ASSERT_EQ(memcmp(data.data(), buffer.data(), sizeof(int32_t) * 10), 0);\n}\n\nTEST(CpuSliceVarNdarray, 1d_slice_assign) {\n  std::vector<int32_t> data({1, 2, 3, 4, 5, 6, 7, 8});\n  std::vector<int32_t> buffer(10, 100);\n  std::vector<int32_t> expected({100, 1, 2, 3, 4, 5, 6, 7, 8, 100});\n  CpuNdarrayBuilder<int32_t, 1> ndarray;\n  auto&& data_ndarray = ndarray.Var(Shape{static_cast<int64_t>(data.size())}, data.data());\n  auto&& buffer_ndarray = ndarray.Var(Shape{10LL}, buffer.data());\n  ASSERT_EQ(buffer_ndarray({1, -1}).xpu_shape(), XpuShape(Shape({8})));\n  buffer_ndarray({1, -1}).CopyFrom(data_ndarray({}));\n  ASSERT_EQ(memcmp(expected.data(), buffer.data(), sizeof(int32_t) * 10), 0);\n}\n\nTEST(CpuSliceVarNdarray, 1d_slice) {\n  std::vector<int32_t> data({100, 1, 2, 3, 4, 5, 6, 7, 8, 100});\n  std::vector<int32_t> buffer(8, 100);\n  std::vector<int32_t> expected({1, 2, 3, 4, 5, 6, 7, 8});\n  CpuNdarrayBuilder<int32_t, 1> ndarray;\n  auto&& data_ndarray = ndarray.Var(Shape{static_cast<int64_t>(data.size())}, data.data());\n  auto&& buffer_ndarray = ndarray.Var(Shape{static_cast<int64_t>(buffer.size())}, buffer.data());\n  buffer_ndarray({}).CopyFrom(data_ndarray({1, -1}));\n  ASSERT_EQ(memcmp(expected.data(), buffer.data(), sizeof(int32_t) * buffer.size()), 0);\n}\n\nTEST(CpuSliceVarNdarray, 2d_slice) {\n  // clang-format off\n  std::vector<int32_t> data({\n      100, 100, 100, 100,\n      100, 0,   1,   100,\n      100, 2,   3,   100,\n      100, 100, 100, 100,\n  });\n  // clang-format on\n  std::vector<int32_t> buffer(4, 100);\n  std::vector<int32_t> expected({0, 1, 2, 3});\n  CpuNdarrayBuilder<int32_t, 2> ndarray;\n  auto&& 
data_ndarray = ndarray.Var(Shape{4LL, 4LL}, data.data());\n  auto&& buffer_ndarray = ndarray.Var(Shape{2LL, 2LL}, buffer.data());\n  buffer_ndarray({}, {}).CopyFrom(data_ndarray({1, -1}, {1, -1}));\n  ASSERT_EQ(memcmp(expected.data(), buffer.data(), sizeof(int32_t) * buffer.size()), 0);\n}\n\nTEST(CpuSliceVarNdarray, 2d_slice_assign) {\n  std::vector<int32_t> data({0, 1, 2, 3});\n  std::vector<int32_t> buffer(16, 100);\n  // clang-format off\n  std::vector<int32_t> expected({\n      100, 100, 100, 100,\n      100, 0,   1,   100,\n      100, 2,   3,   100,\n      100, 100, 100, 100,\n  });\n  // clang-format on\n  CpuNdarrayBuilder<int32_t, 2> ndarray;\n  auto&& data_ndarray = ndarray.Var(Shape{2LL, 2LL}, data.data());\n  auto&& buffer_ndarray = ndarray.Var(Shape{4LL, 4LL}, buffer.data());\n  buffer_ndarray({1, -1}, {1, -1}).CopyFrom(data_ndarray({}, {}));\n  ASSERT_EQ(memcmp(expected.data(), buffer.data(), sizeof(int32_t) * buffer.size()), 0);\n}\n\nTEST(CpuSliceVarNdarray, 2d_slice_reverse) {\n  // clang-format off\n  std::vector<int32_t> data({\n      100, 100, 100, 100,\n      100, 0,   1,   100,\n      100, 2,   3,   100,\n      100, 100, 100, 100,\n  });\n  std::vector<int32_t> buffer(16, 100);\n  std::vector<int32_t> expected({\n      100, 100, 100, 100,\n      100, 2,   3,   100,\n      100, 0,   1,   100,\n      100, 100, 100, 100,\n  });\n  // clang-format on\n  CpuNdarrayBuilder<int32_t, 2> ndarray;\n  auto&& data_ndarray = ndarray.Var(Shape{4LL, 4LL}, data.data());\n  auto&& buffer_ndarray = ndarray.Var(Shape{4LL, 4LL}, buffer.data());\n  buffer_ndarray({1, -1}, {1, -1}).CopyFrom(data_ndarray({-2, 0, -1}, {1, -1}));\n  ASSERT_EQ(memcmp(expected.data(), buffer.data(), sizeof(int32_t) * buffer.size()), 0);\n}\n\nTEST(CpuSliceVarNdarray, 3d_slice) {\n  // clang-format off\n  std::vector<int32_t> data({\n      100, 100, 100, 100,\n      100, 0,   1,   100,\n      100, 2,   3,   100,\n      100, 100, 100, 100,\n\t\n      100, 100, 100, 100,\n      100, 4,   5,   100,\n      100, 6,   7,   100,\n      100, 100, 100, 100,\n  });\n  std::vector<int32_t> buffer(8, -1);\n  std::vector<int32_t> expected({\n      0, 1,\n      2, 3,\n\n      4, 5,\n      6, 7\n  });\n  // clang-format on\n  CpuNdarrayBuilder<int32_t, 3> ndarray;\n  auto&& data_ndarray = ndarray.Var(Shape{2LL, 4LL, 4LL}, data.data());\n  auto&& buffer_ndarray = ndarray.Var(Shape{2LL, 2LL, 2LL}, buffer.data());\n  buffer_ndarray.CopyFrom(data_ndarray({}, {1, -1}, {1, -1}));\n  ASSERT_EQ(memcmp(expected.data(), buffer.data(), sizeof(int32_t) * buffer.size()), 0);\n}\n\nTEST(CpuSliceVarNdarray, 3d_slice_assign) {\n  // clang-format off\n  std::vector<int32_t> data({\n      0, 1,\n      2, 3,\n\n      4, 5,\n      6, 7\n  });\n  std::vector<int32_t> buffer(32, 100);\n  std::vector<int32_t> expected({\n      100, 100, 100, 100,\n      100, 0,   1,   100,\n      100, 2,   3,   100,\n      100, 100, 100, 100,\n\t\n      100, 100, 100, 100,\n      100, 4,   5,   100,\n      100, 6,   7,   100,\n      100, 100, 100, 100,\n  });\n  // clang-format on\n  CpuNdarrayBuilder<int32_t, 3> ndarray;\n  auto&& data_ndarray = ndarray.Var(Shape{2LL, 2LL, 2LL}, data.data());\n  auto&& buffer_ndarray = ndarray.Var(Shape{2LL, 4LL, 4LL}, buffer.data());\n  buffer_ndarray({}, {1, -1}, {1, -1}).CopyFrom(data_ndarray);\n  ASSERT_EQ(memcmp(expected.data(), buffer.data(), sizeof(int32_t) * buffer.size()), 0);\n}\n\n}  // namespace test\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/cpu_var_ndarray.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_CPU_VAR_NDARRAY_H_\n#define ONEFLOW_CORE_NDARRAY_CPU_VAR_NDARRAY_H_\n\n#include \"oneflow/core/ndarray/cpu_ndarray.h\"\n#include \"oneflow/core/ndarray/cpu_ndarray_copy.h\"\n\nnamespace oneflow {\n\nclass Slice;\ntemplate<typename XT>\nclass CpuSliceVarNdarray;\n\ntemplate<typename T, int NDIMS>\nclass CpuVarNdarray : public CpuNdarray<T, NDIMS> {\n public:\n  CpuVarNdarray(const CpuVarNdarray&) = default;\n  CpuVarNdarray(const Shape& shape, T* ptr)\n      : CpuNdarray<T, NDIMS>(shape), ptr_(ptr), len_(shape.elem_cnt()) {\n    CHECK_GT(len_, 0);\n  }\n  CpuVarNdarray(const ShapeView& shape_view, T* ptr)\n      : CpuNdarray<T, NDIMS>(XpuShape(shape_view)), ptr_(ptr), len_(shape_view.elem_cnt()) {\n    CHECK_GT(len_, 0);\n  }\n  virtual ~CpuVarNdarray() = default;\n\n  CpuSliceVarNdarray<CpuVarNdarray<T, NDIMS>> operator()(Slice&& slice0) {\n    static_assert(NDIMS == 1, \"NDIMS error\");\n    return CpuSliceVarNdarray<CpuVarNdarray<T, NDIMS>>(std::move(*this), {slice0});\n  }\n  CpuSliceVarNdarray<CpuVarNdarray<T, NDIMS>> operator()(Slice&& slice0, Slice&& slice1) {\n    static_assert(NDIMS == 2, \"NDIMS error\");\n    return CpuSliceVarNdarray<CpuVarNdarray<T, NDIMS>>(std::move(*this), {slice0, slice1});\n  }\n  CpuSliceVarNdarray<CpuVarNdarray<T, NDIMS>> operator()(Slice&& slice0, Slice&& slice1,\n                                                         Slice&& slice2) {\n    static_assert(NDIMS == 3, \"NDIMS error\");\n    return CpuSliceVarNdarray<CpuVarNdarray<T, NDIMS>>(std::move(*this), {slice0, slice1, slice2});\n  }\n  CpuSliceVarNdarray<CpuVarNdarray<T, NDIMS>> operator()(Slice&& slice0, Slice&& slice1,\n                                                         Slice&& slice2, Slice&& slice3) {\n    static_assert(NDIMS == 4, \"NDIMS error\");\n    return CpuSliceVarNdarray<CpuVarNdarray<T, NDIMS>>(std::move(*this),\n                                                       {slice0, slice1, slice2, slice3});\n  }\n  CpuSliceVarNdarray<CpuVarNdarray<T, NDIMS>> operator()(Slice&& slice0, Slice&& slice1,\n                                                         Slice&& slice2, Slice&& slice3,\n                                                         Slice&& slice4) {\n    static_assert(NDIMS == 5, \"NDIMS error\");\n    return CpuSliceVarNdarray<CpuVarNdarray<T, NDIMS>>(std::move(*this),\n                                                       {slice0, slice1, slice2, slice3, slice4});\n  }\n\n  template<typename XT>\n  void CopyFrom(const XT& ndarray) {\n    CpuNdarrayCopy(this, ndarray);\n  }\n\n  ALWAYS_INLINE void GetMutPtrAndContiguousSize(int64_t offset, T** ptr, size_t* size) const {\n    *ptr = ptr_ + offset;\n    *size = len_ - offset;\n  }\n\n protected:\n  ALWAYS_INLINE T* ptr() const { return ptr_; }\n  ALWAYS_INLINE size_t len() const { return len_; }\n\n private:\n  T* const ptr_;\n  size_t len_;\n};\n\n}  // namespace 
oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_CPU_VAR_NDARRAY_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/cpu_var_ndarray_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/cpu_ndarray_builder.h\"\n#include <gtest/gtest.h>\n\nnamespace oneflow {\n\nnamespace test {\n\nTEST(CpuVarNdarray, one_elem_assign) {\n  std::vector<int32_t> data({1});\n  std::vector<int32_t> buffer({0});\n  CpuNdarrayBuilder<int32_t, 1> ndarray;\n  auto&& data_ndarray = ndarray.Var(Shape{1LL}, data.data());\n  auto&& buffer_ndarray = ndarray.Var(Shape{1LL}, buffer.data());\n  buffer_ndarray.CopyFrom(data_ndarray);\n  ASSERT_EQ(data[0], buffer[0]);\n}\n\nTEST(CpuVarNdarray, 1d_assign) {\n  std::vector<int32_t> data({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});\n  std::vector<int32_t> buffer(10, 0);\n  CpuNdarrayBuilder<int32_t, 1> ndarray;\n  auto&& data_ndarray = ndarray.Var(Shape{10LL}, data.data());\n  auto&& buffer_ndarray = ndarray.Var(Shape{10LL}, buffer.data());\n  buffer_ndarray.CopyFrom(data_ndarray);\n  ASSERT_EQ(memcmp(data.data(), buffer.data(), sizeof(int32_t) * 10), 0);\n}\n\n}  // namespace test\n\n}  // namespace oneflow\n"
  },
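  {
    "path": "oneflow/core/ndarray/cpu_var_ndarray_example_test.cpp",
    "content": "// Editor-added usage sketch, NOT part of upstream OneFlow: extends the\n// Var/CopyFrom pattern of cpu_var_ndarray_test.cpp to NDIMS=2. Assumes\n// CpuNdarrayBuilder<T, 2>::Var accepts a 2-D Shape, mirroring the 1-D usage\n// in the neighboring test.\n#include \"oneflow/core/ndarray/cpu_ndarray_builder.h\"\n#include <gtest/gtest.h>\n#include <cstring>\n#include <vector>\n\nnamespace oneflow {\n\nnamespace test {\n\nTEST(CpuVarNdarrayExample, 2d_assign) {\n  std::vector<int32_t> data({0, 1, 2, 3, 4, 5});\n  std::vector<int32_t> buffer(6, 0);\n  CpuNdarrayBuilder<int32_t, 2> ndarray;\n  auto&& data_ndarray = ndarray.Var(Shape{2LL, 3LL}, data.data());\n  auto&& buffer_ndarray = ndarray.Var(Shape{2LL, 3LL}, buffer.data());\n  // CopyFrom drives CpuNdarrayCopy, which walks both buffers through\n  // GetMutPtrAndContiguousSize.\n  buffer_ndarray.CopyFrom(data_ndarray);\n  ASSERT_EQ(memcmp(data.data(), buffer.data(), sizeof(int32_t) * 6), 0);\n}\n\n}  // namespace test\n\n}  // namespace oneflow\n"
  },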
  {
    "path": "oneflow/core/ndarray/ndarray_apply_binary.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BINARY_H_\n#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BINARY_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/ndarray/ndarray_apply_binary_core.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, template<typename> class binary_func,\n         typename Enable = void>\nstruct NdarrayApplyBinary;\n\ntemplate<DeviceType device_type, typename T, template<typename> class binary_func>\nstruct NdarrayApplyBinary<\n    device_type, T, binary_func,\n    typename std::enable_if<std::is_same<T, typename DevDType<device_type, T>::type>::value>::type>\n    final {\n  static void Apply(ep::Stream* stream,\n                    const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,\n                    const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {\n    NdarrayApplyBinaryCoreWrapper<device_type, T, binary_func>::Apply(stream, y, a, b);\n  }\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                           const XpuVarNdarray<const T>& x) {\n    NdarrayApplyBinaryCoreWrapper<device_type, T, binary_func>::InplaceApply(stream, y, x);\n  }\n};\n\ntemplate<DeviceType device_type, typename T, template<typename> class binary_func>\nstruct NdarrayApplyBinary<\n    device_type, T, binary_func,\n    typename std::enable_if<!std::is_same<T, typename DevDType<device_type, T>::type>::value>::type>\n    final {\n  using NewT = typename DevDType<device_type, T>::type;\n  static void Apply(ep::Stream* stream,\n                    const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,\n                    const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {\n    return NdarrayApplyBinary<device_type, NewT, binary_func>::Apply(\n        stream, reinterpret_cast<const XpuVarNdarray<NewT>&>(y),\n        reinterpret_cast<const XpuVarNdarray<const NewT>&>(a),\n        reinterpret_cast<const XpuVarNdarray<const NewT>&>(b));\n  }\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                           const XpuVarNdarray<const T>& x) {\n    return NdarrayApplyBinary<device_type, NewT, binary_func>::InplaceApply(\n        stream, reinterpret_cast<const XpuVarNdarray<NewT>&>(y),\n        reinterpret_cast<const XpuVarNdarray<const NewT>&>(x));\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BINARY_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_apply_binary_core.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/ndarray_apply_binary_core.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, template<typename> class binary_func>\nstruct NdarrayApplyBinaryCoreWrapper<DeviceType::kCPU, T, binary_func> final {\n  static void Apply(ep::Stream* stream,\n                    const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,\n                    const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {\n    NdarrayApplyBinaryCore<T, binary_func>::Apply(y.shape().ElemNum(), y.ptr(), a.ptr(), b.ptr());\n  }\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                           const XpuVarNdarray<const T>& x) {\n    NdarrayApplyBinaryCore<T, binary_func>::InplaceApply(y.shape().ElemNum(), y.ptr(), x.ptr());\n  }\n};\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_apply_binary_core.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/ndarray_apply_binary_core.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, template<typename> class binary_func>\n__global__ void NdarrayApplyBinaryApplyGpu(size_t n,\n                                           typename BinaryFuncTrait<binary_func, T>::return_type* y,\n                                           const T* a, const T* b) {\n  NdarrayApplyBinaryCore<T, binary_func>::Apply(n, y, a, b);\n}\n\ntemplate<typename T, template<typename> class binary_func>\n__global__ void NdarrayApplyBinaryInplaceApplyGpu(size_t n, T* y, const T* x) {\n  NdarrayApplyBinaryCore<T, binary_func>::InplaceApply(n, y, x);\n}\n\n}  // namespace\n\ntemplate<typename T, template<typename> class binary_func>\nstruct NdarrayApplyBinaryCoreWrapper<DeviceType::kCUDA, T, binary_func> final {\n  static void Apply(ep::Stream* stream,\n                    const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,\n                    const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {\n    size_t n = y.host_shape().HostElemNum();\n    if (n == 0) { return; }\n    RUN_CUDA_KERNEL((NdarrayApplyBinaryApplyGpu<T, binary_func>), stream, n, n, y.host_ptr(),\n                    a.host_ptr(), b.host_ptr());\n  }\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                           const XpuVarNdarray<const T>& x) {\n    size_t n = y.host_shape().HostElemNum();\n    if (n == 0) { return; }\n    RUN_CUDA_KERNEL((NdarrayApplyBinaryInplaceApplyGpu<T, binary_func>), stream, n, n, y.host_ptr(),\n                    x.host_ptr());\n  }\n};\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_apply_binary_core.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BINARY_CORE_H_\n#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BINARY_CORE_H_\n\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ndarray/xpu_binary_func_ndarray.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/ndarray/xpu_util.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, template<typename> class binary_func>\nstruct NdarrayApplyBinaryCoreWrapper final {\n  static void Apply(ep::Stream* stream,\n                    const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,\n                    const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b);\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                           const XpuVarNdarray<const T>& x);\n};\n\ntemplate<typename T, template<typename> class binary_func>\nstruct NdarrayApplyBinaryCore final {\n  OF_DEVICE_FUNC static void Apply(size_t n,\n                                   typename BinaryFuncTrait<binary_func, T>::return_type* y,\n                                   const T* a, const T* b) {\n    XPU_1D_KERNEL_LOOP_BEGIN(i, n);\n    y[i] = binary_func<T>::Invoke(a[i], b[i]);\n    XPU_1D_KERNEL_LOOP_END();\n  }\n  OF_DEVICE_FUNC static void InplaceApply(size_t n, T* y, const T* x) {\n    XPU_1D_KERNEL_LOOP_BEGIN(i, n);\n    y[i] = binary_func<T>::Invoke(y[i], x[i]);\n    XPU_1D_KERNEL_LOOP_END();\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BINARY_CORE_H_\n"
  },
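  {
    "path": "oneflow/core/ndarray/ndarray_apply_binary_core_example_test.cpp",
    "content": "// Editor-added usage sketch, NOT part of upstream OneFlow: exercises the\n// element-wise loop of NdarrayApplyBinaryCore directly on host memory; on\n// CPU, XPU_1D_KERNEL_LOOP degenerates to a sequential loop. Assumes the\n// BinaryFuncAdd functor from oneflow/core/ndarray/binary_func.h.\n#include \"oneflow/core/ndarray/ndarray_apply_binary_core.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\n#include <gtest/gtest.h>\n#include <vector>\n\nnamespace oneflow {\n\nnamespace test {\n\nTEST(NdarrayApplyBinaryCoreExample, add) {\n  std::vector<float> a({1.f, 2.f, 3.f});\n  std::vector<float> b({10.f, 20.f, 30.f});\n  std::vector<float> y(3, 0.f);\n  // Apply: y[i] = a[i] + b[i]\n  NdarrayApplyBinaryCore<float, BinaryFuncAdd>::Apply(3, y.data(), a.data(), b.data());\n  ASSERT_FLOAT_EQ(y[2], 33.f);\n  // InplaceApply: y[i] = y[i] + b[i]\n  NdarrayApplyBinaryCore<float, BinaryFuncAdd>::InplaceApply(3, y.data(), b.data());\n  ASSERT_FLOAT_EQ(y[2], 63.f);\n}\n\n}  // namespace test\n\n}  // namespace oneflow\n"
  },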
  {
    "path": "oneflow/core/ndarray/ndarray_apply_broadcast_binary.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_H_\n#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_H_\n\n#include \"oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h\"\n#include \"oneflow/core/ndarray/ndarray_apply_binary.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, template<typename> class binary_func,\n         typename Enable = void>\nstruct NdarrayApplyBroadcastBinary;\n\ntemplate<DeviceType device_type, typename T, template<typename> class binary_func>\nstruct NdarrayApplyBroadcastBinary<\n    device_type, T, binary_func,\n    typename std::enable_if<std::is_same<T, typename DevDType<device_type, T>::type>::value>::type>\n    final {\n  using RetT = typename BinaryFuncTrait<binary_func, T>::return_type;\n  static void Apply(ep::Stream* stream, const XpuVarNdarray<RetT>& y,\n                    const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {\n    if (a.shape() == b.shape()) {\n      return NdarrayApplyBinary<device_type, T, binary_func>::Apply(stream, y, a, b);\n    }\n    if (TryInplaceApply<std::is_same<RetT, T>::value>(stream, y, a, b)) { return; }\n    CheckBroadcastable(y, a, b);\n    DimVector simplified_y_dim;\n    DimVector simplified_a_dim;\n    DimVector simplified_b_dim;\n    SimplifyBroadcastShapes(y.shape(), a.shape(), b.shape(), &simplified_y_dim, &simplified_a_dim,\n                            &simplified_b_dim);\n    return SwitchApply(SwitchCase(simplified_y_dim.size()), stream,\n                       XpuVarNdarray<RetT>(Shape(simplified_y_dim), y.ptr()),\n                       XpuVarNdarray<const T>(Shape(simplified_a_dim), a.ptr()),\n                       XpuVarNdarray<const T>(Shape(simplified_b_dim), b.ptr()));\n  }\n\n  template<bool enabled>\n  static typename std::enable_if<enabled, bool>::type TryInplaceApply(\n      ep::Stream* stream, const XpuVarNdarray<RetT>& y, const XpuVarNdarray<const T>& a,\n      const XpuVarNdarray<const T>& b) {\n    bool is_inplace = (y.shape() == a.shape() && y.ptr() == a.ptr());\n    if (is_inplace) { InplaceApply(stream, y, b); }\n    return is_inplace;\n  }\n\n  template<bool enabled>\n  static typename std::enable_if<!enabled, bool>::type TryInplaceApply(\n      ep::Stream* stream, const XpuVarNdarray<RetT>& y, const XpuVarNdarray<const T>& a,\n      const XpuVarNdarray<const T>& b) {\n    return false;\n  }\n\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                           const XpuVarNdarray<const T>& x) {\n    if (y.shape() == x.shape()) {\n      return NdarrayApplyBinary<device_type, T, binary_func>::InplaceApply(stream, y, x);\n    }\n    CheckBroadcastable(y, reinterpret_cast<const XpuVarNdarray<const T>&>(y), x);\n    DimVector simplified_y_dim;\n    DimVector simplified_x_dim;\n    SimplifyBroadcastShapes(y.shape(), x.shape(), &simplified_y_dim, 
&simplified_x_dim);\n    return SwitchInplaceApply(SwitchCase(simplified_y_dim.size()), stream,\n                              XpuVarNdarray<T>(Shape(simplified_y_dim), y.ptr()),\n                              XpuVarNdarray<const T>(Shape(simplified_x_dim), x.ptr()));\n  }\n\n private:\n#define MAKE_NDARRAY_BROADCAST_BINARY(func_name, NDIMS) \\\n  NdarrayApplyBroadcastBinaryCoreWrapper<device_type, T, NDIMS, binary_func>::func_name\n  DEFINE_STATIC_SWITCH_FUNC(void, Apply, MAKE_NDARRAY_BROADCAST_BINARY, MAKE_NDIM_CTRV_SEQ(DIM_SEQ))\n#undef MAKE_NDARRAY_BROADCAST_BINARY\n\n#define MAKE_NDARRAY_INPLACE_BROADCAST_BINARY(func_name, NDIMS) \\\n  NdarrayApplyBroadcastInplaceBinaryCoreWrapper<device_type, T, NDIMS, binary_func>::func_name\n  DEFINE_STATIC_SWITCH_FUNC(void, InplaceApply, MAKE_NDARRAY_INPLACE_BROADCAST_BINARY,\n                            MAKE_NDIM_CTRV_SEQ(DIM_SEQ))\n#undef MAKE_NDARRAY_INPLACE_BROADCAST_BINARY\n\n  static void CheckBroadcastable(\n      const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,\n      const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {\n    CHECK_EQ(y.shape().NumAxes(), a.shape().NumAxes());\n    CHECK_EQ(y.shape().NumAxes(), b.shape().NumAxes());\n    for (int i = 0; i < y.shape().NumAxes(); ++i) {\n      CHECK_EQ(y.shape().At(i), (a.shape().At(i) == 0 || b.shape().At(i) == 0)\n                                    ? 0\n                                    : std::max(a.shape().At(i), b.shape().At(i)));\n      if (a.shape().At(i) != b.shape().At(i)) {\n        CHECK(a.shape().At(i) == 1 || b.shape().At(i) == 1);\n      }\n    }\n  }\n};\n\ntemplate<DeviceType device_type, typename T, template<typename> class binary_func>\nstruct NdarrayApplyBroadcastBinary<\n    device_type, T, binary_func,\n    typename std::enable_if<!std::is_same<T, typename DevDType<device_type, T>::type>::value>::type>\n    final {\n  using NewT = typename DevDType<device_type, T>::type;\n  static void Apply(ep::Stream* stream,\n                    const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,\n                    const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {\n    return NdarrayApplyBroadcastBinary<device_type, NewT, binary_func>::Apply(\n        stream, reinterpret_cast<const XpuVarNdarray<NewT>&>(y),\n        reinterpret_cast<const XpuVarNdarray<const NewT>&>(a),\n        reinterpret_cast<const XpuVarNdarray<const NewT>&>(b));\n  }\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                           const XpuVarNdarray<const T>& x) {\n    return NdarrayApplyBroadcastBinary<device_type, NewT, binary_func>::InplaceApply(\n        stream, reinterpret_cast<const XpuVarNdarray<NewT>&>(y),\n        reinterpret_cast<const XpuVarNdarray<const NewT>&>(x));\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_H_\n"
  },
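  {
    "path": "oneflow/core/ndarray/ndarray_apply_broadcast_binary_example_test.cpp",
    "content": "// Editor-added usage sketch, NOT part of upstream OneFlow: broadcasts a\n// (1, 3) operand against a (2, 3) operand with BinaryFuncAdd on CPU.\n// Assumptions: BinaryFuncAdd comes from binary_func.h, and the CPU core\n// wrappers in this directory never touch the ep::Stream*, so nullptr is\n// passed here; a real kernel should forward its own stream.\n#include \"oneflow/core/ndarray/ndarray_apply_broadcast_binary.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\n#include <gtest/gtest.h>\n#include <vector>\n\nnamespace oneflow {\n\nnamespace test {\n\nTEST(NdarrayApplyBroadcastBinaryExample, row_broadcast_add) {\n  std::vector<float> a({0.f, 1.f, 2.f, 3.f, 4.f, 5.f});  // shape (2, 3)\n  std::vector<float> b({10.f, 20.f, 30.f});               // shape (1, 3)\n  std::vector<float> y(6, 0.f);\n  NdarrayApplyBroadcastBinary<DeviceType::kCPU, float, BinaryFuncAdd>::Apply(\n      /*stream=*/nullptr, XpuVarNdarray<float>(Shape({2, 3}), y.data()),\n      XpuVarNdarray<const float>(Shape({2, 3}), a.data()),\n      XpuVarNdarray<const float>(Shape({1, 3}), b.data()));\n  ASSERT_FLOAT_EQ(y[0], 10.f);  // 0 + 10\n  ASSERT_FLOAT_EQ(y[5], 35.f);  // 5 + 30\n}\n\n}  // namespace test\n\n}  // namespace oneflow\n"
  },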
  {
    "path": "oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, int NDIMS, template<typename> class binary_func>\nstruct NdarrayApplyBroadcastBinaryCoreWrapper<DeviceType::kCPU, T, NDIMS, binary_func> final {\n  static void Apply(ep::Stream* stream,\n                    const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,\n                    const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {\n    NdarrayApplyBroadcastBinaryCore<T, NDIMS, binary_func>::Apply(y, a, b);\n  }\n};\n\ntemplate<typename T, int NDIMS, template<typename> class binary_func>\nstruct NdarrayApplyBroadcastInplaceBinaryCoreWrapper<DeviceType::kCPU, T, NDIMS, binary_func>\n    final {\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                           const XpuVarNdarray<const T>& x) {\n    NdarrayApplyBroadcastBinaryCore<T, NDIMS, binary_func>::InplaceApply(y, x);\n  }\n};\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename Index>\nstruct XY2XFunctor final {\n  __host__ __device__ XY2XFunctor(Index dim_y) : dim_y_(dim_y) {}\n\n  __host__ __device__ Index operator()(Index idx) const { return idx / dim_y_; }\n\n  Index dim_y_;\n};\n\ntemplate<typename Index>\nstruct XY2YFunctor final {\n  __host__ __device__ XY2YFunctor(Index dim_y) : dim_y_(dim_y) {}\n\n  __host__ __device__ Index operator()(Index idx) const { return idx % dim_y_; }\n\n  Index dim_y_;\n};\n\ntemplate<typename Index>\nstruct XYZ2XZFunctor final {\n  __host__ __device__ XYZ2XZFunctor(Index dim_y, Index dim_z)\n      : dim_yz_(dim_y * dim_z), dim_z_(dim_z) {}\n\n  __host__ __device__ Index operator()(Index idx) const {\n    const Index x = idx / dim_yz_;\n    const Index z = (idx % dim_yz_) % dim_z_;\n    return x * dim_z_ + z;\n  }\n\n  Index dim_yz_;\n  Index dim_z_;\n};\n\ntemplate<typename Index>\nstruct XYZ2YFunctor final {\n  __host__ __device__ XYZ2YFunctor(Index dim_y, Index dim_z)\n      : dim_yz_(dim_y * dim_z), dim_z_(dim_z) {}\n\n  __host__ __device__ Index operator()(Index idx) const { return (idx % dim_yz_) / dim_z_; }\n\n  Index dim_yz_;\n  Index dim_z_;\n};\n\ntemplate<typename T, typename K, template<typename> class binary_func, typename OffsetFunctor>\n__global__ void PartialBroadcastGpu(K n, typename BinaryFuncTrait<binary_func, T>::return_type* y,\n                                    const T* a, const T* b, OffsetFunctor offset_functor) {\n  CUDA_1D_KERNEL_LOOP_T(K, i, n) { y[i] = binary_func<T>::Invoke(a[i], b[offset_functor(i)]); }\n}\n\ntemplate<typename T, int NDIMS, template<typename> class binary_func>\n__global__ void GpuBroadcastBinaryFunc(\n    const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type> y,\n    const XpuVarNdarray<const T> a, const XpuVarNdarray<const T> b) {\n  NdarrayApplyBroadcastBinaryCore<T, NDIMS, binary_func>::Apply(y, a, b);\n}\ntemplate<typename T, int NDIMS, template<typename> class binary_func>\n__global__ void GpuInplaceBroadcastBinaryFunc(const XpuVarNdarray<T> y,\n                                              const XpuVarNdarray<const T> x) {\n  NdarrayApplyBroadcastBinaryCore<T, NDIMS, binary_func>::InplaceApply(y, x);\n}\n\n}  // namespace\n\ntemplate<typename T, int NDIMS, template<typename> class binary_func>\nstruct NdarrayApplyBroadcastBinaryCoreWrapper<DeviceType::kCUDA, T, NDIMS, binary_func> final {\n  static void Apply(ep::Stream* stream,\n                    const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,\n                    const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {\n    size_t n = y.host_shape().HostElemNum();\n    if (n == 0) { return; }\n    if (IsKernelSafeInt32(n) && PartialBroadcast<int32_t>(stream, y, a, b)) { return; }\n    if (!IsKernelSafeInt32(n) && 
PartialBroadcast<int64_t>(stream, y, a, b)) { return; }\n    RUN_CUDA_KERNEL((GpuBroadcastBinaryFunc<T, NDIMS, binary_func>), stream, n, y, a, b);\n  }\n\n  template<typename K>\n  static bool PartialBroadcast(\n      ep::Stream* stream,\n      const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,\n      const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {\n    size_t n = y.host_shape().HostElemNum();\n    if (y.host_shape() == a.host_shape()) {\n      if (y.host_shape().NumAxes() == 2) {\n        const K y_dim0 = y.host_shape().At(0);\n        const K y_dim1 = y.host_shape().At(1);\n        const K b_dim0 = b.host_shape().At(0);\n        const K b_dim1 = b.host_shape().At(1);\n        if (b_dim0 == y_dim0 && b_dim1 == 1) {\n          XY2XFunctor<K> xy2x(y_dim1);\n          RUN_CUDA_KERNEL((PartialBroadcastGpu<T, K, binary_func, XY2XFunctor<K>>), stream, n, n,\n                          y.host_ptr(), a.host_ptr(), b.host_ptr(), xy2x);\n          return true;\n        }\n        if (b_dim0 == 1 && b_dim1 == y_dim1) {\n          XY2YFunctor<K> xy2y(y_dim1);\n          RUN_CUDA_KERNEL((PartialBroadcastGpu<T, K, binary_func, XY2YFunctor<K>>), stream, n, n,\n                          y.host_ptr(), a.host_ptr(), b.host_ptr(), xy2y);\n          return true;\n        }\n      }\n      if (y.host_shape().NumAxes() == 3) {\n        const K y_dim0 = y.host_shape().At(0);\n        const K y_dim1 = y.host_shape().At(1);\n        const K y_dim2 = y.host_shape().At(2);\n        const K b_dim0 = b.host_shape().At(0);\n        const K b_dim1 = b.host_shape().At(1);\n        const K b_dim2 = b.host_shape().At(2);\n        if (b_dim0 == y_dim0 && b_dim1 == 1 && b_dim2 == y_dim2) {\n          XYZ2XZFunctor<K> xyz2xz(y_dim1, y_dim2);\n          RUN_CUDA_KERNEL((PartialBroadcastGpu<T, K, binary_func, XYZ2XZFunctor<K>>), stream, n, n,\n                          y.host_ptr(), a.host_ptr(), b.host_ptr(), xyz2xz);\n          return true;\n        }\n        if (b_dim0 == 1 && b_dim1 == y_dim1 && b_dim2 == 1) {\n          XYZ2YFunctor<K> xyz2y(y_dim1, y_dim2);\n          RUN_CUDA_KERNEL((PartialBroadcastGpu<T, K, binary_func, XYZ2YFunctor<K>>), stream, n, n,\n                          y.host_ptr(), a.host_ptr(), b.host_ptr(), xyz2y);\n          return true;\n        }\n      }\n    }\n    return false;\n  }\n};\n\ntemplate<typename T, int NDIMS, template<typename> class binary_func>\nstruct NdarrayApplyBroadcastInplaceBinaryCoreWrapper<DeviceType::kCUDA, T, NDIMS, binary_func>\n    final {\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                           const XpuVarNdarray<const T>& x) {\n    size_t n = y.host_shape().HostElemNum();\n    XpuVarNdarray<const T> a(y.host_shape(), y.host_ptr());\n    using NBB = NdarrayApplyBroadcastBinaryCoreWrapper<DeviceType::kCUDA, T, NDIMS, binary_func>;\n    if (n == 0) { return; }\n    if (IsKernelSafeInt32(n) && NBB::template PartialBroadcast<int32_t>(stream, y, a, x)) {\n      return;\n    }\n    if (!IsKernelSafeInt32(n) && NBB::template PartialBroadcast<int64_t>(stream, y, a, x)) {\n      return;\n    }\n    RUN_CUDA_KERNEL((GpuInplaceBroadcastBinaryFunc<T, NDIMS, binary_func>), stream, n, y, x);\n  }\n};\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_CORE_H_\n#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_CORE_H_\n\n#include \"oneflow/core/ndarray/xpu_util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/ndarray/xpu_broadcast_ndarray.h\"\n#include \"oneflow/core/ndarray/xpu_binary_func_ndarray.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, int NDIMS, template<typename> class binary_func>\nstruct NdarrayApplyBroadcastBinaryCoreWrapper final {\n  static void Apply(ep::Stream* stream,\n                    const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,\n                    const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b);\n};\n\ntemplate<DeviceType device_type, typename T, int NDIMS, template<typename> class binary_func>\nstruct NdarrayApplyBroadcastInplaceBinaryCoreWrapper final {\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                           const XpuVarNdarray<const T>& x);\n};\n\ntemplate<typename T, int NDIMS, template<typename> class binary_func>\nstruct NdarrayApplyBroadcastBinaryCore final {\n  OF_DEVICE_FUNC static void Apply(\n      const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,\n      const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {\n    const auto& ret =\n        a.Broadcast(y.shape()).template BinaryFunc<binary_func>(b.Broadcast(y.shape()));\n    y.template Assign<NDIMS>(ret);\n  }\n  OF_DEVICE_FUNC static void InplaceApply(const XpuVarNdarray<T>& y,\n                                          const XpuVarNdarray<const T>& x) {\n    y.template BinaryAssign<binary_func, NDIMS>(x.Broadcast(y.shape()));\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_CORE_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_apply_broadcast_unary.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_H_\n#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_H_\n\n#include \"oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, template<typename> class unary_func,\n         typename Enable = void>\nstruct NdarrayApplyBroadcastUnary;\n\ntemplate<DeviceType device_type, typename T, template<typename> class unary_func>\nstruct NdarrayApplyBroadcastUnary<\n    device_type, T, unary_func,\n    typename std::enable_if<std::is_same<T, typename DevDType<device_type, T>::type>::value>::type>\n    final {\n  static void Apply(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                    const XpuVarNdarray<const T>& x) {\n    CheckBroadcastable(y, x);\n    DimVector simplified_y_dim;\n    DimVector simplified_x_dim;\n    SimplifyBroadcastShapes(y.shape(), x.shape(), &simplified_y_dim, &simplified_x_dim);\n    SwitchApply(SwitchCase(simplified_y_dim.size()), stream,\n                XpuVarNdarray<T>(Shape(simplified_y_dim), y.ptr()),\n                XpuVarNdarray<const T>(Shape(simplified_x_dim), x.ptr()));\n  }\n\n private:\n#define DEFINE_NDARRAY_BROADCAST_UNARY(func_name, NDIMS) \\\n  NdarrayApplyBroadcastUnaryCoreWrapper<device_type, T, NDIMS, unary_func>::func_name\n  DEFINE_STATIC_SWITCH_FUNC(void, Apply, DEFINE_NDARRAY_BROADCAST_UNARY,\n                            MAKE_NDIM_CTRV_SEQ(DIM_SEQ))\n#undef DEFINE_NDARRAY_BROADCAST_UNARY\n  static void CheckBroadcastable(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) {\n    CHECK_EQ(y.shape().NumAxes(), x.shape().NumAxes());\n    for (int i = 0; i < y.shape().NumAxes(); ++i) {\n      CHECK(x.shape().At(i) == 1 || x.shape().At(i) == y.shape().At(i));\n    }\n  }\n};\n\ntemplate<DeviceType device_type, typename T, template<typename> class unary_func>\nstruct NdarrayApplyBroadcastUnary<\n    device_type, T, unary_func,\n    typename std::enable_if<!std::is_same<T, typename DevDType<device_type, T>::type>::value>::type>\n    final {\n  static void Apply(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                    const XpuVarNdarray<const T>& x) {\n    using NewT = typename DevDType<device_type, T>::type;\n    return NdarrayApplyBroadcastUnary<device_type, NewT, unary_func>::Apply(\n        stream, reinterpret_cast<const XpuVarNdarray<NewT>&>(y),\n        reinterpret_cast<const XpuVarNdarray<const NewT>&>(x));\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, int NDIMS, template<typename> class unary_func>\nstruct NdarrayApplyBroadcastUnaryCoreWrapper<DeviceType::kCPU, T, NDIMS, unary_func> final {\n  static void Apply(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                    const XpuVarNdarray<const T>& x) {\n    NdarrayApplyBroadcastUnaryCore<T, NDIMS, unary_func>::Apply(y, x);\n  }\n};\n\n#define INSTANTIATE_BROADCAST_UNARY_FUNC(dtype_pair, NDIMS, unary_func) \\\n  template struct NdarrayApplyBroadcastUnaryCoreWrapper<                \\\n      DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, unary_func>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BROADCAST_UNARY_FUNC,\n                                 ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ\n                                     COMPLEX_DATA_TYPE_SEQ,\n                                 DIM_SEQ, ARITHMETIC_UNARY_FUNC_SEQ)\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h\"\n#include \"oneflow/core/ep/cuda/primitive/type_seq.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, int NDIMS, template<typename> class unary_func>\n__global__ void GpuBroadcastUnaryFunc(const XpuVarNdarray<T> y, const XpuVarNdarray<const T> x) {\n  NdarrayApplyBroadcastUnaryCore<T, NDIMS, unary_func>::Apply(y, x);\n}\n\n}  // namespace\n\ntemplate<typename T, int NDIMS, template<typename> class unary_func>\nstruct NdarrayApplyBroadcastUnaryCoreWrapper<DeviceType::kCUDA, T, NDIMS, unary_func> final {\n  static void Apply(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                    const XpuVarNdarray<const T>& x) {\n    size_t n = y.host_shape().HostElemNum();\n    if (n == 0) { return; }\n    RUN_CUDA_KERNEL((GpuBroadcastUnaryFunc<T, NDIMS, unary_func>), stream, n, y, x);\n  }\n};\n\n#define INSTANTIATE_BROADCAST_UNARY_FUNC(dtype_pair, NDIMS, unary_func) \\\n  template struct NdarrayApplyBroadcastUnaryCoreWrapper<                \\\n      DeviceType::kCUDA, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, unary_func>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BROADCAST_UNARY_FUNC,\n                                 ARITHMETIC_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ\n                                     CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ,\n                                 DIM_SEQ, ARITHMETIC_UNARY_FUNC_SEQ)\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_CORE_H_\n#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_CORE_H_\n\n#include \"oneflow/core/ndarray/xpu_util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/ndarray/xpu_broadcast_ndarray.h\"\n#include \"oneflow/core/ndarray/xpu_unary_func_ndarray.h\"\n#include \"oneflow/core/ndarray/unary_func.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, int NDIMS, template<typename> class unary_func>\nstruct NdarrayApplyBroadcastUnaryCoreWrapper final {\n  static void Apply(ep::Stream* stream, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x);\n};\n\ntemplate<typename T, int NDIMS, template<typename> class unary_func>\nstruct NdarrayApplyBroadcastUnaryCore final {\n  OF_DEVICE_FUNC static void Apply(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) {\n    y.template Assign<NDIMS>(x.Broadcast(y.shape()).template UnaryFunc<unary_func>());\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_CORE_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_apply_unary.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_H_\n#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/ndarray/ndarray_apply_unary_core.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, template<typename> class unary_func,\n         typename Enable = void>\nstruct NdarrayApplyUnary;\n\ntemplate<DeviceType device_type, typename T, template<typename> class unary_func>\nstruct NdarrayApplyUnary<\n    device_type, T, unary_func,\n    typename std::enable_if<std::is_same<T, typename DevDType<device_type, T>::type>::value>::type>\n    final {\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y) {\n    NdarrayApplyUnaryCoreWrapper<device_type, T, unary_func>::InplaceApply(stream, y);\n  }\n};\n\ntemplate<DeviceType device_type, typename T, template<typename> class unary_func>\nstruct NdarrayApplyUnary<\n    device_type, T, unary_func,\n    typename std::enable_if<!std::is_same<T, typename DevDType<device_type, T>::type>::value>::type>\n    final {\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y) {\n    using NewT = typename DevDType<device_type, T>::type;\n    return NdarrayApplyUnary<device_type, NewT, unary_func>::InplaceApply(\n        stream, reinterpret_cast<const XpuVarNdarray<NewT>&>(y));\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_apply_unary_core.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/ndarray_apply_unary_core.h\"\n#include \"oneflow/core/ndarray/unary_func.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, template<typename> class unary_func>\nstruct NdarrayApplyUnaryCoreWrapper<DeviceType::kCPU, T, unary_func> final {\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y) {\n    NdarrayApplyUnaryCore<T, unary_func>::InplaceApply(y.ptr(), y.shape().ElemNum());\n  }\n};\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_apply_unary_core.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/ndarray_apply_unary_core.h\"\n#include \"oneflow/core/ndarray/unary_func.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, template<typename> class unary_func>\n__global__ void NdarrayApplyUnaryInplaceApplyGpu(T* ptr, size_t n) {\n  NdarrayApplyUnaryCore<T, unary_func>::InplaceApply(ptr, n);\n}\n\n}  // namespace\n\ntemplate<typename T, template<typename> class unary_func>\nstruct NdarrayApplyUnaryCoreWrapper<DeviceType::kCUDA, T, unary_func> final {\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y) {\n    size_t n = y.host_shape().HostElemNum();\n    if (n == 0) { return; }\n    RUN_CUDA_KERNEL((NdarrayApplyUnaryInplaceApplyGpu<T, unary_func>), stream, n, y.host_ptr(), n);\n  }\n};\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_apply_unary_core.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_CORE_H_\n#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_CORE_H_\n\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ndarray/xpu_unary_func_ndarray.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/ndarray/xpu_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, template<typename> class unary_func>\nstruct NdarrayApplyUnaryCoreWrapper final {\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y);\n};\n\ntemplate<typename T, template<typename> class unary_func>\nstruct NdarrayApplyUnaryCore final {\n  OF_DEVICE_FUNC static void InplaceApply(T* y, size_t n) {\n    XPU_1D_KERNEL_LOOP(i, n) { y[i] = unary_func<T>::Invoke(y[i]); }\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_CORE_H_\n"
  },
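  {
    "path": "oneflow/core/ndarray/ndarray_apply_unary_core_example_test.cpp",
    "content": "// Editor-added usage sketch, NOT part of upstream OneFlow: runs the in-place\n// unary loop on host memory; on CPU, XPU_1D_KERNEL_LOOP is a sequential loop.\n// Assumes the UnaryFuncNegative functor from oneflow/core/ndarray/unary_func.h.\n#include \"oneflow/core/ndarray/ndarray_apply_unary_core.h\"\n#include \"oneflow/core/ndarray/unary_func.h\"\n#include <gtest/gtest.h>\n#include <vector>\n\nnamespace oneflow {\n\nnamespace test {\n\nTEST(NdarrayApplyUnaryCoreExample, negative_inplace) {\n  std::vector<float> y({1.f, -2.f, 3.f});\n  // InplaceApply: y[i] = -y[i]\n  NdarrayApplyUnaryCore<float, UnaryFuncNegative>::InplaceApply(y.data(), y.size());\n  ASSERT_FLOAT_EQ(y[0], -1.f);\n  ASSERT_FLOAT_EQ(y[1], 2.f);\n}\n\n}  // namespace test\n\n}  // namespace oneflow\n"
  },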
  {
    "path": "oneflow/core/ndarray/ndarray_assign_core.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/ndarray_assign_core.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename X, int NDIMS>\nstruct NdarrayAssignCoreWrapper<DeviceType::kCPU, T, X, NDIMS> final {\n  static void Assign(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                     const XpuReducedNdarray<X, NDIMS>& reduced) {\n    NdarrayAssignCore<T, X, NDIMS>::Assign(y, reduced);\n  }\n  static void Assign(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                     const XpuVarNdarray<const X>& x) {\n    NdarrayAssignCore<T, X, NDIMS>::Assign(y, x);\n  }\n};\n\n#define INSTANTIATE_NDARRAY_ASSIGN(ret_dtype_pair, dtype_pair, NDIMS)                          \\\n  template struct NdarrayAssignCoreWrapper<DeviceType::kCPU, OF_PP_PAIR_FIRST(ret_dtype_pair), \\\n                                           OF_PP_PAIR_FIRST(dtype_pair), NDIMS>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n    INSTANTIATE_NDARRAY_ASSIGN,\n    ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ,\n    ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ,\n    DIM_SEQ);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_ASSIGN, COMPLEX_DATA_TYPE_SEQ,\n                                 COMPLEX_DATA_TYPE_SEQ, DIM_SEQ);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_assign_core.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/ndarray_assign_core.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/cuda/primitive/type_seq.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename X, int NDIMS>\n__global__ void NdarrayAssignReducedGpu(XpuVarNdarray<T> y,\n                                        const XpuReducedNdarray<X, NDIMS> reduced) {\n  NdarrayAssignCore<T, X, NDIMS>::Assign(y, reduced);\n}\n\ntemplate<typename T, typename X, int NDIMS>\n__global__ void NdarrayAssignGpu(XpuVarNdarray<T> y, const XpuVarNdarray<const X> x) {\n  NdarrayAssignCore<T, X, NDIMS>::Assign(y, x);\n}\n\n}  // namespace\n\ntemplate<typename T, typename X, int NDIMS>\nstruct NdarrayAssignCoreWrapper<DeviceType::kCUDA, T, X, NDIMS> final {\n  static void Assign(ep::Stream* ctx, const XpuVarNdarray<T>& y,\n                     const XpuReducedNdarray<X, NDIMS>& reduced) {\n    size_t n = y.host_shape().HostElemNum();\n    if (n == 0) { return; }\n    RUN_CUDA_KERNEL((NdarrayAssignReducedGpu<T, X, NDIMS>), ctx, n, y, reduced);\n  }\n  static void Assign(ep::Stream* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const X>& x) {\n    size_t n = y.host_shape().HostElemNum();\n    if (n == 0) { return; }\n    RUN_CUDA_KERNEL((NdarrayAssignGpu<T, X, NDIMS>), ctx, n, y, x);\n  }\n};\n\n#define INSTANTIATE_NDARRAY_ASSIGN(ret_dtype_pair, dtype_pair, NDIMS)                           \\\n  template struct NdarrayAssignCoreWrapper<DeviceType::kCUDA, OF_PP_PAIR_FIRST(ret_dtype_pair), \\\n                                           OF_PP_PAIR_FIRST(dtype_pair), NDIMS>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n    INSTANTIATE_NDARRAY_ASSIGN,\n    ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ,\n    ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, DIM_SEQ);\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_ASSIGN, HALF_DATA_TYPE_SEQ, HALF_DATA_TYPE_SEQ,\n                                 DIM_SEQ);\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_ASSIGN, CUDA_PRIMITIVE_COMPLEX64_TYPE_SEQ,\n                                 CUDA_PRIMITIVE_COMPLEX64_TYPE_SEQ, DIM_SEQ);\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_ASSIGN, CUDA_PRIMITIVE_COMPLEX128_TYPE_SEQ,\n                                 CUDA_PRIMITIVE_COMPLEX128_TYPE_SEQ, DIM_SEQ);\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_assign_core.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_ASSIGN_CORE_H_\n#define ONEFLOW_CORE_NDARRAY_NDARRAY_ASSIGN_CORE_H_\n\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/ndarray/xpu_reduced_ndarray.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, typename X, int NDIMS>\nstruct NdarrayAssignCoreWrapper final {\n  static void Assign(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                     const XpuReducedNdarray<X, NDIMS>& reduced);\n  static void Assign(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                     const XpuVarNdarray<const X>& x);\n};\n\ntemplate<typename T, typename X, int NDIMS>\nstruct NdarrayAssignCore final {\n  OF_DEVICE_FUNC static void Assign(const XpuVarNdarray<T>& y,\n                                    const XpuReducedNdarray<X, NDIMS>& reduced) {\n    y.template Assign<NDIMS>(reduced);\n  }\n\n  OF_DEVICE_FUNC static void Assign(const XpuVarNdarray<T>& y, const XpuVarNdarray<const X>& x) {\n    y.template Assign<NDIMS>(x);\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_ASSIGN_CORE_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_reduce.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_H_\n#define ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_H_\n\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/ndarray/ndarray_reduce_impl.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, template<typename> class binary_func,\n         typename Enable = void>\nstruct NdarrayReduce;\n\ntemplate<DeviceType device_type, typename T, template<typename> class binary_func>\nstruct NdarrayReduce<\n    device_type, T, binary_func,\n    typename std::enable_if<std::is_same<T, typename DevDType<device_type, T>::type>::value>::type>\n    final {\n  using RetT = typename BinaryFuncTrait<binary_func, T>::return_type;\n  static void Reduce(ep::Stream* stream, const XpuVarNdarray<RetT>& origin_y,\n                     const XpuVarNdarray<const T>& origin_x, const XpuVarNdarray<T>& tmp_storage) {\n    DimVector simplified_x_dim;\n    DimVector simplified_y_dim;\n    TrySimplifyDims(origin_x.shape(), origin_y.shape(), &simplified_x_dim, &simplified_y_dim);\n    XpuVarNdarray<RetT> y(Shape(simplified_y_dim), origin_y.ptr());\n    XpuVarNdarray<const T> x(Shape(simplified_x_dim), origin_x.ptr());\n\n    CHECK_EQ(y.shape().NumAxes(), x.shape().NumAxes());\n    if (NdarrayNoReduce<device_type, T, binary_func>::Matched(y, x)) {\n      NdarrayNoReduce<device_type, T, binary_func>::Reduce(stream, y, x, tmp_storage);\n    } else if (NdarrayScalarReduce<device_type, T, binary_func>::Matched(y, x)) {\n      NdarrayScalarReduce<device_type, T, binary_func>::Reduce(stream, y, x, tmp_storage);\n    } else if (NdarrayMatrixRowReduce<device_type, T, binary_func>::Matched(y, x)) {\n      NdarrayMatrixRowReduce<device_type, T, binary_func>::Reduce(stream, y, x, tmp_storage);\n    } else if (NdarrayMatrixColReduce<device_type, T, binary_func>::Matched(y, x)) {\n      NdarrayMatrixColReduce<device_type, T, binary_func>::Reduce(stream, y, x, tmp_storage);\n    } else if (NdarrayXYZCubeXZReduce<device_type, T, binary_func>::Matched(y, x)) {\n      NdarrayXYZCubeXZReduce<device_type, T, binary_func>::Reduce(stream, y, x, tmp_storage);\n    } else {\n      NdarrayDefaultReduce<device_type, T, binary_func>::Reduce(stream, y, x, tmp_storage);\n    }\n  }\n\n  static void TrySimplifyDims(const XpuShape& x, const XpuShape& y, DimVector* simplified_x,\n                              DimVector* simplified_y) {\n    CHECK_EQ(y.NumAxes(), x.NumAxes());\n    CHECK(y.At(0) == 1 || y.At(0) == x.At(0));\n    CHECK(simplified_x->empty());\n    CHECK(simplified_y->empty());\n    simplified_x->emplace_back(x.At(0));\n    simplified_y->emplace_back(y.At(0));\n    bool prev_axis_is_reduced = (y.At(0) == 1);\n    FOR_RANGE(int, i, 1, x.NumAxes()) {\n      const int64_t x_dim = x.At(i);\n      const int64_t y_dim = y.At(i);\n      const bool cur_axis_is_reduced = (y_dim 
== 1);\n      CHECK(cur_axis_is_reduced || y_dim == x_dim);\n      if (cur_axis_is_reduced == prev_axis_is_reduced) {\n        simplified_x->back() *= x_dim;\n        simplified_y->back() *= y_dim;\n      } else {\n        simplified_x->emplace_back(x_dim);\n        simplified_y->emplace_back(y_dim);\n      }\n      prev_axis_is_reduced = cur_axis_is_reduced;\n    }\n  }\n};\n\ntemplate<DeviceType device_type, typename T, template<typename> class binary_func>\nstruct NdarrayReduce<\n    device_type, T, binary_func,\n    typename std::enable_if<!std::is_same<T, typename DevDType<device_type, T>::type>::value>::type>\n    final {\n  static void Reduce(ep::Stream* stream, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x,\n                     const XpuVarNdarray<T>& tmp_storage) {\n    using NewT = typename DevDType<device_type, T>::type;\n    return NdarrayReduce<device_type, NewT, binary_func>::Reduce(\n        stream, reinterpret_cast<const XpuVarNdarray<NewT>&>(y),\n        reinterpret_cast<const XpuVarNdarray<const NewT>&>(x),\n        reinterpret_cast<const XpuVarNdarray<NewT>&>(tmp_storage));\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_reduce_impl.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/ndarray/ndarray_reduce_impl.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\n\nnamespace oneflow {\n\n#define SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(struct_name)                                        \\\n  template<typename T, template<typename> class binary_func>                                   \\\n  struct struct_name<DeviceType::kCPU, T, binary_func> final {                                 \\\n    using RetT = typename BinaryFuncTrait<binary_func, T>::return_type;                        \\\n    static bool Matched(const XpuVarNdarray<RetT>& y, const XpuVarNdarray<const T>& x) {       \\\n      return false;                                                                            \\\n    }                                                                                          \\\n    static void Reduce(ep::Stream* stream, const XpuVarNdarray<RetT>& y,                       \\\n                       const XpuVarNdarray<const T>& x, const XpuVarNdarray<T>& tmp_storage) { \\\n      UNIMPLEMENTED();                                                                         \\\n    }                                                                                          \\\n  }\nSPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(NdarrayScalarReduce);\nSPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(NdarrayMatrixRowReduce);\nSPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(NdarrayMatrixColReduce);\nSPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(NdarrayXYZCubeXZReduce);\n#undef SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL\n\n#define INSTANTIATE_NDARRAY_REDUCE_IMPL(dtype, binary_func)                                       \\\n  template struct NdarrayScalarReduce<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype), binary_func>;    \\\n  template struct NdarrayMatrixRowReduce<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype), binary_func>; \\\n  template struct NdarrayMatrixColReduce<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype), binary_func>; \\\n  template struct NdarrayXYZCubeXZReduce<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype), binary_func>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL,\n                                 ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ\n                                     UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ,\n                                 REDUCE_BINARY_FUNC_SEQ);\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL, FLOATING_DATA_TYPE_SEQ,\n                                 NANSUM_REDUCE_BINARY_FUNC_SEQ);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL, COMPLEX_DATA_TYPE_SEQ,\n                                 REDUCE_BINARY_FUNC_SEQ);\n\ntemplate<typename T, int NDIMS, template<typename> class binary_func>\nstruct NdarrayReduceCoreWrapper<DeviceType::kCPU, T, NDIMS, binary_func> final {\n  static void ReduceAxis(ep::Stream* stream, const 
XpuReducedNdarray<T, NDIMS>& dst_reduced,\n                         const XpuReducedNdarray<T, NDIMS>& x, int axis) {\n    NdarrayReduceCore<T, NDIMS, binary_func>::ReduceAxis(dst_reduced, x, axis);\n  }\n};\n\n#define INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER(dtype_pair, NDIMS, binary_func)                   \\\n  template struct NdarrayReduceCoreWrapper<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, \\\n                                           binary_func>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER,\n                                 ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ\n                                     UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ,\n                                 DIM_SEQ, REDUCE_BINARY_FUNC_SEQ);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER, COMPLEX_DATA_TYPE_SEQ,\n                                 DIM_SEQ, REDUCE_COMPLEX_BINARY_FUNC_SEQ);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER, FLOATING_DATA_TYPE_SEQ,\n                                 DIM_SEQ, NANSUM_REDUCE_BINARY_FUNC_SEQ);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_reduce_impl.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cub/cub.cuh>\n#include \"oneflow/core/kernel/util/numerics.cuh\"\n#include \"oneflow/core/ndarray/ndarray_reduce_impl.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/permutation_iterator.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/ep/cuda/primitive/type_seq.h\"\n\nnamespace cub {\nstruct Prod {\n  template<typename T>\n  __host__ __device__ __forceinline__ T operator()(const T& a, const T& b) const {\n    return a * b;\n  }\n};\nstruct Any {\n  template<typename T, typename U>\n  __host__ __device__ __forceinline__ T operator()(const T& a, const U& b) const {\n    return a || b;\n  }\n};\nstruct All {\n  template<typename T, typename U>\n  __host__ __device__ __forceinline__ T operator()(const T& a, const U& b) const {\n    return a && b;\n  }\n};\nstruct NanSum {\n  template<typename T>\n  __host__ __device__ __forceinline__ T operator()(const T& a, const T& b) const {\n    if (oneflow::detail::numerics<T>::isnan(a))\n      return oneflow::detail::numerics<T>::isnan(b) ? T{0} : b;\n    return oneflow::detail::numerics<T>::isnan(b) ? 
a : a + b;\n  }\n};\n\n}  // namespace cub\n\n__host__ __device__ __forceinline__ cuComplex operator+(const cuComplex& a, const cuComplex& b) {\n  return cuComplex{a.x + b.x, a.y + b.y};\n}\n\n__host__ __device__ __forceinline__ cuDoubleComplex operator+(const cuDoubleComplex& a,\n                                                              const cuDoubleComplex& b) {\n  return cuDoubleComplex{a.x + b.x, a.y + b.y};\n}\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<template<typename> class R, typename T, typename K, typename RetT>\n__global__ void MatrixColReduceBy1ThreadPerColumn(K num_elems, K num_cols, const T* in, RetT* out) {\n  CUDA_1D_KERNEL_LOOP_T(K, j, num_cols) {\n    K index = j;\n    T sum = in[index];\n    for (index += num_cols; index < num_elems; index += num_cols) {\n      sum = R<T>::Invoke(sum, in[index]);\n    }\n    out[j] = sum;\n  }\n}\n\ntemplate<typename T>\nstruct WithAlign2 {\n  union {\n    T value;\n    int32_t padding;\n  };\n};\n\ntemplate<template<typename> class R, typename T, typename K, typename RetT>\n__global__ void MatrixColReduceByWarpBlock(K num_elems, K num_cols, const T* in, RetT* out) {\n  const K thread_col = threadIdx.x % kCudaWarpSize;\n  const K thread_row = threadIdx.x / kCudaWarpSize;\n  const K thread_dim_row = blockDim.x / kCudaWarpSize;\n  const K num_valid_threads = thread_dim_row * num_cols;  // ASSERT: always <= num_elems\n  const K col = blockIdx.x * kCudaWarpSize + thread_col;\n  __shared__ WithAlign2<T> partial_values[kCudaWarpSize * kCudaWarpSize];\n  if (col < num_cols) {\n    K index = thread_row * num_cols + col;\n    T val = in[index];\n    for (index += num_valid_threads; index < num_elems; index += num_valid_threads) {\n      val = R<T>::Invoke(val, in[index]);\n    }\n    partial_values[threadIdx.x].value = val;\n  }\n  __syncthreads();\n  if (col < num_cols && thread_row == 0) {\n    int index = thread_col;\n    T val = partial_values[index].value;\n    for (index += kCudaWarpSize; index < blockDim.x; index += kCudaWarpSize) {\n      val = R<T>::Invoke(val, partial_values[index].value);\n    }\n    out[col] = val;\n  }\n}\n\ntemplate<template<typename> class R, typename T, typename K, typename RetT>\nvoid MatrixColReduceBy1BlockLayer(ep::Stream* stream, K num_elems, K num_cols, const T* in,\n                                  RetT* out) {\n  CHECK_LE(num_cols, kCudaMaxBlocksNum * kCudaWarpSize);\n  const K num_rows = num_elems / num_cols;\n  CHECK_GT(num_rows, 0);\n  if (num_rows < kCudaWarpSize) {\n    RUN_CUDA_KERNEL((MatrixColReduceBy1ThreadPerColumn<R, T, K, RetT>), stream, num_cols, num_elems,\n                    num_cols, in, out);\n  } else {\n    const int num_blocks = (num_cols + kCudaWarpSize - 1) / kCudaWarpSize;\n    const int num_threads = kCudaWarpSize * kCudaWarpSize;\n    auto Reduce = &MatrixColReduceByWarpBlock<R, T, K, RetT>;\n    Reduce<<<num_blocks, num_threads, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        num_elems, num_cols, in, out);\n  }\n}\n\nconst static int32_t kNumRows4OneBlockLayer = kCudaWarpSize * kCudaWarpSize;\nconst static int32_t kNumCols4OneBlockLayer = kCudaMaxBlocksNum * kCudaWarpSize / 2;\n\ntemplate<template<typename> class R, typename T, typename K>\nvoid MatrixColReduceK(ep::Stream* stream, K num_rows, K num_cols, const T* in,\n                      typename BinaryFuncTrait<R, T>::return_type* out, T* tmp) {\n  K num_elems = num_rows * num_cols;\n  if (num_rows < kNumRows4OneBlockLayer || num_cols > kNumCols4OneBlockLayer) {\n    MatrixColReduceBy1BlockLayer<R, T, 
K, typename BinaryFuncTrait<R, T>::return_type>(\n        stream, num_elems, num_cols, in, out);\n  } else {\n    int scale_shift = 1;\n    for (; true; ++scale_shift) {\n      if ((num_rows >> scale_shift) < kNumRows4OneBlockLayer) { break; }\n      if ((num_cols << scale_shift) > kNumCols4OneBlockLayer) { break; }\n    }\n    MatrixColReduceBy1BlockLayer<R, T, K, T>(stream, num_elems, (num_cols << scale_shift), in, tmp);\n    // recursively calls MatrixColReduceK(...) log32(num_rows) times at most\n    MatrixColReduceK<R, T, K>(stream, (1 << scale_shift), num_cols, tmp, out, tmp);\n  }\n}\n\ntemplate<template<typename> class R, typename T>\nvoid MatrixColReduce(ep::Stream* stream, int64_t num_rows, int64_t num_cols, const T* in,\n                     typename BinaryFuncTrait<R, T>::return_type* out, T* tmp) {\n  if (IsKernelSafeInt32(num_rows * num_cols)) {\n    return MatrixColReduceK<R, T, int32_t>(stream, num_rows, num_cols, in, out, tmp);\n  } else {\n    return MatrixColReduceK<R, T, int64_t>(stream, num_rows, num_cols, in, out, tmp);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, template<typename> class binary_func>\nstruct CubFunctor4BinaryFunc;\n\n#define SPECIALIZE_CUB_FUNCTOR_4_BINARY_FUNC(func_name)          \\\n  template<typename T>                                           \\\n  struct CubFunctor4BinaryFunc<T, BinaryFunc##func_name> final { \\\n    using type = cub::func_name;                                 \\\n  };\nOF_PP_FOR_EACH_ATOMIC(SPECIALIZE_CUB_FUNCTOR_4_BINARY_FUNC, REDUCE_BINARY_FUNC_NAME_SEQ(NanSum));\n#undef SPECIALIZE_CUB_FUNCTOR_4_BINARY_FUNC\n\nstruct RowOffsetFunctor final {\n  OF_DEVICE_FUNC explicit RowOffsetFunctor(int32_t num_cols) : num_cols_(num_cols) {}\n  OF_DEVICE_FUNC int32_t operator()(const int32_t& x) const { return x * num_cols_; }\n  int32_t num_cols_;\n};\n\ntemplate<typename T, template<typename> class binary_func>\nstruct NdarrayScalarReduce<DeviceType::kCUDA, T, binary_func> final {\n  using RetT = typename BinaryFuncTrait<binary_func, T>::return_type;\n  static bool Matched(const XpuVarNdarray<RetT>& y, const XpuVarNdarray<const T>& x) {\n    return y.shape().ElemNum() == 1;\n  }\n\n  static void Reduce(ep::Stream* stream, const XpuVarNdarray<RetT>& y,\n                     const XpuVarNdarray<const T>& x, const XpuVarNdarray<T>& tmp_storage) {\n    CHECK(Matched(y, x));\n    size_t x_size = x.shape().ElemNum();\n    size_t tmp_storage_bytes = 0;\n    auto DoReduce = [&](T* tmp_storage_ptr) {\n      int retcode = cub::DeviceReduce::Reduce(\n          tmp_storage_ptr, tmp_storage_bytes, x.ptr(), y.ptr(), x_size,\n          typename CubFunctor4BinaryFunc<T, binary_func>::type(),\n          UnitOfBinaryFunc<T, binary_func>::Val(), stream->As<ep::CudaStream>()->cuda_stream());\n      CHECK_EQ(retcode, 0) << \"cub::DeviceReduce::Reduce error\";\n    };\n    DoReduce(nullptr);\n    CHECK_GE(tmp_storage.shape().ElemNum() * sizeof(T), tmp_storage_bytes);\n    DoReduce(tmp_storage.ptr());\n  }\n};\n\ntemplate<typename T, template<typename> class binary_func>\nstruct NdarrayMatrixRowReduce<DeviceType::kCUDA, T, binary_func> final {\n  using RetT = typename BinaryFuncTrait<binary_func, T>::return_type;\n  static bool Matched(const XpuVarNdarray<RetT>& y, const XpuVarNdarray<const T>& x) {\n    if (y.shape().ElemNum() > GetMaxVal<int32_t>()) { return false; }\n    if (x.shape().NumAxes() != 2) { return false; }\n    if (y.shape().NumAxes() != 2) { return false; }\n    return x.shape().At(0) == y.shape().At(0) && y.shape().At(1) == 1;\n  }\n\n  
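// NOTE: a row reduce maps each output element to one contiguous input segment:\n  // RowOffsetFunctor over a counting iterator yields segment offsets 0, num_cols,\n  // 2 * num_cols, ... (e.g. for a 3x4 matrix the bounds are 0, 4, 8, 12), so a single\n  // cub::DeviceSegmentedReduce::Reduce call below reduces every row in parallel.\n  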
static void Reduce(ep::Stream* stream, const XpuVarNdarray<RetT>& y,\n                     const XpuVarNdarray<const T>& x, const XpuVarNdarray<T>& tmp_storage) {\n    CHECK(Matched(y, x));\n    int32_t num_rows = y.shape().ElemNum();\n    int32_t num_cols = x.shape().ElemNum() / y.shape().ElemNum();\n    RowOffsetFunctor get_row_offset(num_cols);\n    cub::CountingInputIterator<int32_t> counting_input_it(0);\n    cub::TransformInputIterator<int32_t, RowOffsetFunctor, cub::CountingInputIterator<int32_t>>\n        transform_input_iter(counting_input_it, get_row_offset);\n    size_t tmp_storage_bytes = 0;\n    auto DoReduce = [&](T* tmp_storage_ptr) {\n      int retcode = cub::DeviceSegmentedReduce::Reduce(\n          tmp_storage_ptr, tmp_storage_bytes, x.ptr(), y.ptr(), num_rows, transform_input_iter,\n          transform_input_iter + 1, typename CubFunctor4BinaryFunc<T, binary_func>::type(),\n          UnitOfBinaryFunc<T, binary_func>::Val(), stream->As<ep::CudaStream>()->cuda_stream());\n      CHECK_EQ(retcode, 0) << \"cub::DeviceSegmentedReduce::Reduce error\";\n    };\n    DoReduce(nullptr);\n    CHECK_GE(tmp_storage.shape().ElemNum() * sizeof(T), tmp_storage_bytes);\n    DoReduce(tmp_storage.ptr());\n  }\n};\n\ntemplate<typename T, template<typename> class binary_func>\nstruct NdarrayMatrixColReduce<DeviceType::kCUDA, T, binary_func> final {\n  using RetT = typename BinaryFuncTrait<binary_func, T>::return_type;\n  static bool Matched(const XpuVarNdarray<RetT>& y, const XpuVarNdarray<const T>& x) {\n    if (y.shape().ElemNum() > GetMaxVal<int32_t>()) { return false; }\n    if (x.shape().NumAxes() != 2) { return false; }\n    if (y.shape().NumAxes() != 2) { return false; }\n    return y.shape().At(0) == 1 && x.shape().At(1) == y.shape().At(1);\n  }\n\n  struct XY2YXFunctor final {\n    __host__ __device__ XY2YXFunctor(int32_t dim_x, int32_t dim_y) : dim_x_(dim_x), dim_y_(dim_y) {}\n\n    __host__ __device__ int32_t operator()(const int32_t& idx) const {\n      const int32_t y = idx / dim_x_;\n      const int32_t x = idx % dim_x_;\n      return x * dim_y_ + y;\n    }\n\n    int32_t dim_x_;\n    int32_t dim_y_;\n  };\n\n  static void Reduce(ep::Stream* stream, const XpuVarNdarray<RetT>& y,\n                     const XpuVarNdarray<const T>& x, const XpuVarNdarray<T>& tmp_storage) {\n    CHECK(Matched(y, x));\n    int64_t num_rows = x.shape().At(0);\n    int64_t num_cols = x.shape().At(1);\n    if (num_cols < kNumCols4OneBlockLayer) {\n      return MatrixColReduce<binary_func, T>(stream, num_rows, num_cols, x.host_ptr(), y.host_ptr(),\n                                             tmp_storage.host_ptr());\n    }\n    RowOffsetFunctor get_row_offset(num_rows);\n    cub::CountingInputIterator<int32_t> counting_input_it(0);\n    cub::TransformInputIterator<int32_t, RowOffsetFunctor, cub::CountingInputIterator<int32_t>>\n        transform_input_iter(counting_input_it, get_row_offset);\n\n    
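// NOTE: columns are strided in the row-major buffer; rather than materializing a\n    // transpose, XY2YXFunctor remaps logical column-major indices onto the original\n    // storage (e.g. a 2x3 matrix is visited as elements 0, 3, 1, 4, 2, 5), so each\n    // column appears to cub::DeviceSegmentedReduce as one contiguous segment.\n    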
XY2YXFunctor xy2yx(x.shape().At(0), x.shape().At(1));\n    using XY2YxIndexIter =\n        cub::TransformInputIterator<int32_t, XY2YXFunctor, cub::CountingInputIterator<int32_t>>;\n    XY2YxIndexIter xy2yx_iter(counting_input_it, xy2yx);\n    PermutationIterator<const T, const T*, XY2YxIndexIter> x_iter(x.ptr(), xy2yx_iter);\n    size_t tmp_storage_bytes = 0;\n    auto DoReduce = [&](T* tmp_storage_ptr) {\n      int retcode = cub::DeviceSegmentedReduce::Reduce(\n          tmp_storage_ptr, tmp_storage_bytes, x_iter, y.ptr(), num_cols, transform_input_iter,\n          transform_input_iter + 1, typename CubFunctor4BinaryFunc<T, binary_func>::type(),\n          UnitOfBinaryFunc<T, binary_func>::Val(), stream->As<ep::CudaStream>()->cuda_stream());\n      CHECK_EQ(retcode, 0) << \"cub::DeviceSegmentedReduce::Reduce error\";\n    };\n    DoReduce(nullptr);\n    CHECK_GE(tmp_storage.shape().ElemNum() * sizeof(T), tmp_storage_bytes);\n    DoReduce(tmp_storage.ptr());\n  }\n};\n\ntemplate<typename T, template<typename> class binary_func>\nstruct NdarrayXYZCubeXZReduce<DeviceType::kCUDA, T, binary_func> final {\n  using RetT = typename BinaryFuncTrait<binary_func, T>::return_type;\n  static bool Matched(const XpuVarNdarray<RetT>& y, const XpuVarNdarray<const T>& x) {\n    if (y.shape().ElemNum() > GetMaxVal<int32_t>()) { return false; }\n    if (x.shape().NumAxes() != 3) { return false; }\n    if (y.shape().NumAxes() != 3) { return false; }\n    return y.shape().At(0) == 1 && x.shape().At(1) == y.shape().At(1) && y.shape().At(2) == 1;\n  }\n\n  struct XYZ2YxzFunctor final {\n    __host__ __device__ XYZ2YxzFunctor(int32_t dim_x, int32_t dim_y, int32_t dim_z)\n        : dim_z_(dim_z), dim_xz_(dim_x * dim_z), dim_yz_(dim_y * dim_z) {}\n\n    __host__ __device__ int32_t operator()(const int32_t& idx) const {\n      const int32_t y = idx / dim_xz_;\n      const int32_t xz_idx = idx % dim_xz_;\n      const int32_t x = xz_idx / dim_z_;\n      const int32_t z = xz_idx % dim_z_;\n      return x * dim_yz_ + y * dim_z_ + z;\n    }\n\n    int32_t dim_z_;\n    int32_t dim_xz_;\n    int32_t dim_yz_;\n  };\n\n  static void Reduce(ep::Stream* stream, const XpuVarNdarray<RetT>& y,\n                     const XpuVarNdarray<const T>& x, const XpuVarNdarray<T>& tmp_storage) {\n    CHECK(Matched(y, x));\n    int32_t num_rows = y.shape().ElemNum();\n    int32_t num_cols = x.shape().ElemNum() / y.shape().ElemNum();\n\n    RowOffsetFunctor get_row_offset(num_cols);\n    cub::CountingInputIterator<int32_t> counting_input_it(0);\n    cub::TransformInputIterator<int32_t, RowOffsetFunctor, cub::CountingInputIterator<int32_t>>\n        transform_input_iter(counting_input_it, get_row_offset);\n\n    XYZ2YxzFunctor xyz2yxz(x.shape().At(0), x.shape().At(1), x.shape().At(2));\n    using XYZ2YxzIndexIter =\n        cub::TransformInputIterator<int32_t, XYZ2YxzFunctor, cub::CountingInputIterator<int32_t>>;\n    XYZ2YxzIndexIter xyz2yxz_iter(counting_input_it, xyz2yxz);\n    PermutationIterator<const T, const T*, XYZ2YxzIndexIter> x_iter(x.ptr(), xyz2yxz_iter);\n    size_t tmp_storage_bytes = 0;\n    auto DoReduce = [&](T* tmp_storage_ptr) {\n      int retcode = cub::DeviceSegmentedReduce::Reduce(\n          tmp_storage_ptr, tmp_storage_bytes, x_iter, y.ptr(), num_rows, transform_input_iter,\n          transform_input_iter + 1, typename CubFunctor4BinaryFunc<T, binary_func>::type(),\n          UnitOfBinaryFunc<T, binary_func>::Val(), stream->As<ep::CudaStream>()->cuda_stream());\n      CHECK_EQ(retcode, 0) << \"cub::DeviceSegmentedReduce::Reduce error\";\n    };\n    DoReduce(nullptr);\n    CHECK_GE(tmp_storage.shape().ElemNum() * sizeof(T), tmp_storage_bytes);\n    DoReduce(tmp_storage.ptr());\n  }\n};\n\nnamespace {\n\ntemplate<typename T, int NDIMS, template<typename> class binary_func>\n__global__ void NdarrayReduceGpuInplaceReduceAxis(const XpuReducedNdarray<T, NDIMS> dst_reduced,\n                                                  const XpuReducedNdarray<T, NDIMS> x, int axis) {\n  NdarrayReduceCore<T, NDIMS, binary_func>::ReduceAxis(dst_reduced, x, axis);\n}\n\n}  // namespace\n\ntemplate<typename T, int NDIMS, 
template<typename> class binary_func>\nstruct NdarrayReduceCoreWrapper<DeviceType::kCUDA, T, NDIMS, binary_func> final {\n  static void ReduceAxis(ep::Stream* stream, const XpuReducedNdarray<T, NDIMS>& dst_reduced,\n                         const XpuReducedNdarray<T, NDIMS>& x, int axis) {\n    size_t n = x.host_shape().HostElemNum();\n    RUN_CUDA_KERNEL((NdarrayReduceGpuInplaceReduceAxis<T, NDIMS, binary_func>), stream, n,\n                    dst_reduced, x, axis);\n  }\n};\n\n#define INSTANTIATE_NDARRAY_REDUCE_IMPL(dtype, binary_func)                                        \\\n  template struct NdarrayScalarReduce<DeviceType::kCUDA, OF_PP_PAIR_FIRST(dtype), binary_func>;    \\\n  template struct NdarrayMatrixRowReduce<DeviceType::kCUDA, OF_PP_PAIR_FIRST(dtype), binary_func>; \\\n  template struct NdarrayMatrixColReduce<DeviceType::kCUDA, OF_PP_PAIR_FIRST(dtype), binary_func>; \\\n  template struct NdarrayXYZCubeXZReduce<DeviceType::kCUDA, OF_PP_PAIR_FIRST(dtype), binary_func>;\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL,\n                                 ARITHMETIC_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ\n                                     UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ,\n                                 ARITHMETIC_REDUCE_BINARY_FUNC_SEQ);\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL, FLOATING_DATA_TYPE_SEQ,\n                                 NANSUM_REDUCE_BINARY_FUNC_SEQ);\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL,\n                                 ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ\n                                     BOOL_DATA_TYPE_SEQ,\n                                 LOGICAL_REDUCE_BINARY_FUNC_SEQ);\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ,\n                                 REDUCE_COMPLEX_BINARY_FUNC_SEQ);\n\n#define INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER(dtype_pair, NDIMS, binary_func)                    \\\n  template struct NdarrayReduceCoreWrapper<DeviceType::kCUDA, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, \\\n                                           binary_func>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER,\n                                 ARITHMETIC_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ\n                                     UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ,\n                                 DIM_SEQ, ARITHMETIC_REDUCE_BINARY_FUNC_SEQ);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER, FLOATING_DATA_TYPE_SEQ,\n                                 DIM_SEQ, NANSUM_REDUCE_BINARY_FUNC_SEQ);\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER,\n                                 ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ\n                                     BOOL_DATA_TYPE_SEQ,\n                                 DIM_SEQ, LOGICAL_REDUCE_BINARY_FUNC_SEQ);\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER,\n                                 CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ, DIM_SEQ,\n                                 REDUCE_COMPLEX_BINARY_FUNC_SEQ);\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_reduce_impl.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_IMPL_H_\n#define ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_IMPL_H_\n\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/switch_func.h\"\n#include \"oneflow/core/ndarray/xpu_ndarray_assign.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\n\nnamespace oneflow {\n\n#define DECLARE_NDARRAY_REDUCE_IMPL(struct_name)                                       \\\n  template<DeviceType device_type, typename T, template<typename> class binary_func>   \\\n  struct struct_name final {                                                           \\\n    static bool Matched(                                                               \\\n        const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y, \\\n        const XpuVarNdarray<const T>& x);                                              \\\n    static void Reduce(                                                                \\\n        ep::Stream* ctx,                                                               \\\n        const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y, \\\n        const XpuVarNdarray<const T>& x, const XpuVarNdarray<T>& tmp_storage);         \\\n  }\nDECLARE_NDARRAY_REDUCE_IMPL(NdarrayScalarReduce);\nDECLARE_NDARRAY_REDUCE_IMPL(NdarrayMatrixRowReduce);\nDECLARE_NDARRAY_REDUCE_IMPL(NdarrayMatrixColReduce);\nDECLARE_NDARRAY_REDUCE_IMPL(NdarrayXYZCubeXZReduce);\n#undef DECLARE_NDARRAY_REDUCE_IMPL\n\ntemplate<DeviceType device_type, typename T, template<typename> class binary_func,\n         typename Enable = void>\nstruct NdarrayNoReduce;\n\ntemplate<DeviceType device_type, typename T, template<typename> class binary_func>\nstruct NdarrayNoReduce<device_type, T, binary_func,\n                       typename std::enable_if<std::is_same<\n                           T, typename BinaryFuncTrait<binary_func, T>::return_type>::value>::type>\n    final {\n  using RetT = typename BinaryFuncTrait<binary_func, T>::return_type;\n  static bool Matched(const XpuVarNdarray<RetT>& y, const XpuVarNdarray<const T>& x) {\n    return x.shape() == y.shape();\n  }\n  static void Reduce(ep::Stream* ctx, const XpuVarNdarray<RetT>& y, const XpuVarNdarray<const T>& x,\n                     const XpuVarNdarray<T>& tmp_storage) {\n    if (std::is_same<binary_func<T>, BinaryFuncNanSum<T>>()) {\n      XpuNdarrayAssign<device_type, RetT>::AssignNanSum(ctx, y, x);\n    } else {\n      XpuNdarrayAssign<device_type, RetT>::Assign(ctx, y, x);\n    }\n  }\n};\n\ntemplate<DeviceType device_type, typename T, template<typename> class binary_func>\nstruct NdarrayNoReduce<device_type, T, binary_func,\n                       typename std::enable_if<!std::is_same<\n                           T, typename BinaryFuncTrait<binary_func, T>::return_type>::value>::type>\n    final {\n  using RetT = 
typename BinaryFuncTrait<binary_func, T>::return_type;\n  static bool Matched(const XpuVarNdarray<RetT>& y, const XpuVarNdarray<const T>& x) {\n    return x.shape() == y.shape();\n  }\n  static void Reduce(ep::Stream* ctx, const XpuVarNdarray<RetT>& y, const XpuVarNdarray<const T>& x,\n                     const XpuVarNdarray<T>& tmp_storage) {\n    return SwitchReduce(SwitchCase(y.shape().NumAxes()), ctx, y, x, tmp_storage);\n  }\n\n private:\n#define DEFINE_NDARRAY_REDUCE(func_name, NDIMS) func_name<NDIMS>\n  DEFINE_STATIC_SWITCH_FUNC(void, Reduce, DEFINE_NDARRAY_REDUCE, MAKE_NDIM_CTRV_SEQ(DIM_SEQ))\n#undef DEFINE_NDARRAY_REDUCE\n\n  template<int NDIMS>\n  static void Reduce(ep::Stream* ctx, const XpuVarNdarray<RetT>& y, const XpuVarNdarray<const T>& x,\n                     const XpuVarNdarray<T>& tmp_storage) {\n    XpuNdarrayAssign<device_type, RetT>::template Assign<NDIMS>(ctx, y, x);\n  }\n};\n\ntemplate<DeviceType device_type, typename T, int NDIMS, template<typename> class binary_func>\nstruct NdarrayReduceCoreWrapper final {\n  static void ReduceAxis(ep::Stream* ctx, const XpuReducedNdarray<T, NDIMS>& dst_reduced,\n                         const XpuReducedNdarray<T, NDIMS>& x, int axis);\n};\n\ntemplate<DeviceType device_type, typename T, template<typename> class binary_func>\nstruct NdarrayDefaultReduce final {\n  using RetT = typename BinaryFuncTrait<binary_func, T>::return_type;\n  static void Reduce(ep::Stream* ctx, const XpuVarNdarray<RetT>& y, const XpuVarNdarray<const T>& x,\n                     const XpuVarNdarray<T>& tmp_storage) {\n    return SwitchReduce(SwitchCase(y.shape().NumAxes()), ctx, y, x, tmp_storage);\n  }\n\n private:\n#define DEFINE_NDARRAY_REDUCE(func_name, NDIMS) func_name<NDIMS>\n  DEFINE_STATIC_SWITCH_FUNC(void, Reduce, DEFINE_NDARRAY_REDUCE, MAKE_NDIM_CTRV_SEQ(DIM_SEQ))\n#undef DEFINE_NDARRAY_REDUCE\n\n  template<int NDIMS>\n  static void Reduce(ep::Stream* ctx, const XpuVarNdarray<RetT>& y, const XpuVarNdarray<const T>& x,\n                     const XpuVarNdarray<T>& tmp_storage) {\n    XpuVarNdarray<T> storage(x.shape(), tmp_storage.ptr());\n    XpuShape cur_shape(x.shape());\n    CHECK_EQ(y.shape().NumAxes(), x.shape().NumAxes());\n    CHECK(x.shape() != y.shape());\n    XpuNdarrayAssign<device_type, T>::Assign(ctx, storage, x);\n    for (int i = 0; i < x.shape().NumAxes(); ++i) {\n      if (y.shape().At(i) == x.shape().At(i)) { continue; }\n      CHECK_EQ(y.shape().At(i), 1);\n      CHECK_GT(x.shape().At(i), y.shape().At(i));\n      InplaceReduceAxis<NDIMS>(ctx, i, storage, &cur_shape);\n    }\n    XpuReducedNdarray<T, NDIMS> reduced(y.shape(), storage);\n    XpuNdarrayAssign<device_type, RetT>::template Assign<NDIMS>(ctx, y, reduced);\n  }\n\n  template<int NDIMS>\n  static void InplaceReduceAxis(ep::Stream* ctx, int axis, const XpuVarNdarray<T>& inplace,\n                                XpuShape* cur_shape) {\n    int64_t target_elem_num = cur_shape->ElemNum() / cur_shape->At(axis);\n    while (cur_shape->At(axis) > 1) {\n      int64_t shrink = 8 + std::sqrt(target_elem_num);\n      XpuReducedNdarray<T, NDIMS> from(*cur_shape, inplace);\n      int64_t new_dim_value = (cur_shape->At(axis) + (shrink - 1)) / shrink;\n      cur_shape->Set(axis, new_dim_value);\n      XpuReducedNdarray<T, NDIMS> to(*cur_shape, inplace);\n      NdarrayReduceCoreWrapper<device_type, T, NDIMS, binary_func>::ReduceAxis(ctx, to, from, axis);\n    }\n  }\n};\n\ntemplate<typename T, int NDIMS, template<typename> class binary_func>\nstruct NdarrayReduceCore final {\n  
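// NOTE: ReduceAxis assigns one thread per output element; each thread folds the\n  // reduced axis with stride dst_dim_val. E.g. when an axis of extent 10 shrinks\n  // to 4, the thread at coord[axis] == 1 combines inputs at axis positions 1, 5, 9.\n  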
template<typename X>\n  OF_DEVICE_FUNC static void ReduceAxis(const XpuReducedNdarray<T, NDIMS>& dst_reduced, const X& x,\n                                        int axis) {\n    size_t n = dst_reduced.shape().ElemNum();\n    int64_t dst_dim_val = dst_reduced.shape().At(axis);\n    XPU_1D_KERNEL_LOOP_BEGIN(i, n);\n    T* dst_reduced_ptr = dst_reduced.template Mut(i);\n    int64_t coord[NDIMS];\n    dst_reduced.shape().template Offset2Coordinate<NDIMS>(i, coord);\n    T reduced = UnitOfBinaryFunc<T, binary_func>::Val();\n    while (coord[axis] < x.shape().At(axis)) {\n      reduced = binary_func<T>::Invoke(reduced, x.template Get<NDIMS>(coord));\n      coord[axis] += dst_dim_val;\n    }\n    *dst_reduced_ptr = reduced;\n    XPU_1D_KERNEL_LOOP_END();\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_IMPL_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/ndarray_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_UTIL_H_\n#define ONEFLOW_CORE_NDARRAY_NDARRAY_UTIL_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray_builder.h\"\n#include \"oneflow/core/ndarray/ndarray_reduce.h\"\n#include \"oneflow/core/ndarray/ndarray_apply_unary.h\"\n#include \"oneflow/core/ndarray/ndarray_apply_binary.h\"\n#include \"oneflow/core/ndarray/ndarray_apply_broadcast_unary.h\"\n#include \"oneflow/core/ndarray/ndarray_apply_broadcast_binary.h\"\n#include \"oneflow/core/ndarray/xpu_reduced_ndarray.h\"\n#include \"oneflow/core/ndarray/xpu_ndarray_assign.h\"\n#include \"oneflow/core/common/switch_func.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nstruct NdarrayUtil final {\n  static XpuVarNdarrayBuilder<const T> GetValNdarrayBuilder() {\n    return XpuVarNdarrayBuilder<const T>();\n  }\n  static XpuVarNdarrayBuilder<T> GetVarNdarrayBuilder() { return XpuVarNdarrayBuilder<T>(); }\n\n  static void Assign(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                     const XpuVarNdarray<const T>& x) {\n    return XpuNdarrayAssign<device_type, T>::Assign(stream, y, x);\n  }\n\n  static void BroadcastTo(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                          const XpuVarNdarray<const T>& x) {\n    return BroadcastIdentity(stream, y, x);\n  }\n\n#define DEFINE_UNARY_FUNC(func_name)                                                         \\\n  static void func_name(                                                                     \\\n      ep::Stream* stream,                                                                    \\\n      const XpuVarNdarray<typename UnaryFuncTrait<UnaryFunc##func_name, T>::return_type>& y, \\\n      const XpuVarNdarray<const T>& x) {                                                     \\\n    return ApplyUnary<UnaryFunc##func_name>(stream, y, x);                                   \\\n  }\n  OF_PP_FOR_EACH_ATOMIC(DEFINE_UNARY_FUNC, ARITHMETIC_UNARY_FUNC_NAME_SEQ)\n#undef DEFINE_UNARY_FUNC\n\n#define DEFINE_ARITHMETIC_BINARY_FUNC(func_name)                                               \\\n  static void func_name(                                                                       \\\n      ep::Stream* stream,                                                                      \\\n      const XpuVarNdarray<typename BinaryFuncTrait<BinaryFunc##func_name, T>::return_type>& y, \\\n      const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {                      \\\n    return ApplyBinary<BinaryFunc##func_name>(stream, y, a, b);                                \\\n  }\n  OF_PP_FOR_EACH_ATOMIC(DEFINE_ARITHMETIC_BINARY_FUNC, ARITHMETIC_BINARY_FUNC_NAME_SEQ)\n#undef DEFINE_ARITHMETIC_BINARY_FUNC\n\n#define DEFINE_LOGICAL_BINARY_FUNC(func_name)               
                                   \\\n  static void func_name(                                                                       \\\n      ep::Stream* stream,                                                                      \\\n      const XpuVarNdarray<typename BinaryFuncTrait<BinaryFunc##func_name, T>::return_type>& y, \\\n      const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {                      \\\n    return ApplyBinary<BinaryFunc##func_name>(stream, y, a, b);                                \\\n  }\n  OF_PP_FOR_EACH_ATOMIC(DEFINE_LOGICAL_BINARY_FUNC, LOGICAL_BINARY_FUNC_NAME_SEQ)\n#undef DEFINE_LOGICAL_BINARY_FUNC\n\n#define DEFINE_BROADCAST_UNARY_FUNC(func_name)                                               \\\n  static void Broadcast##func_name(                                                          \\\n      ep::Stream* stream,                                                                    \\\n      const XpuVarNdarray<typename UnaryFuncTrait<UnaryFunc##func_name, T>::return_type>& y, \\\n      const XpuVarNdarray<const T>& x) {                                                     \\\n    return BroadcastApplyUnary<UnaryFunc##func_name>(stream, y, x);                          \\\n  }\n  OF_PP_FOR_EACH_ATOMIC(DEFINE_BROADCAST_UNARY_FUNC, ARITHMETIC_UNARY_FUNC_NAME_SEQ)\n#undef DEFINE_BROADCAST_UNARY_FUNC\n\n#define DEFINE_BROADCAST_ARITHMETIC_BINARY_FUNC(func_name)                                     \\\n  static void Broadcast##func_name(                                                            \\\n      ep::Stream* stream,                                                                      \\\n      const XpuVarNdarray<typename BinaryFuncTrait<BinaryFunc##func_name, T>::return_type>& y, \\\n      const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {                      \\\n    return BroadcastApplyBinary<BinaryFunc##func_name>(stream, y, a, b);                       \\\n  }\n  OF_PP_FOR_EACH_ATOMIC(DEFINE_BROADCAST_ARITHMETIC_BINARY_FUNC, ARITHMETIC_BINARY_FUNC_NAME_SEQ)\n#undef DEFINE_BROADCAST_ARITHMETIC_BINARY_FUNC\n\n#define DEFINE_BROADCAST_LOGICAL_BINARY_FUNC(func_name)                                        \\\n  static void Broadcast##func_name(                                                            \\\n      ep::Stream* stream,                                                                      \\\n      const XpuVarNdarray<typename BinaryFuncTrait<BinaryFunc##func_name, T>::return_type>& y, \\\n      const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {                      \\\n    return BroadcastApplyBinary<BinaryFunc##func_name>(stream, y, a, b);                       \\\n  }\n  OF_PP_FOR_EACH_ATOMIC(DEFINE_BROADCAST_LOGICAL_BINARY_FUNC, LOGICAL_BINARY_FUNC_NAME_SEQ)\n#undef DEFINE_BROADCAST_LOGICAL_BINARY_FUNC\n\n#define DEFINE_INPLACE_UNARY_FUNC(func_name)                                      \\\n  static void Inplace##func_name(ep::Stream* stream, const XpuVarNdarray<T>& y) { \\\n    InplaceApply<UnaryFunc##func_name>(stream, y);                                \\\n  }\n  OF_PP_FOR_EACH_ATOMIC(DEFINE_INPLACE_UNARY_FUNC, ARITHMETIC_UNARY_FUNC_NAME_SEQ)\n#undef DEFINE_INPLACE_UNARY_FUNC\n\n#define DEFINE_INPLACE_BINARY_FUNC(func_name)                                   \\\n  static void Inplace##func_name(ep::Stream* stream, const XpuVarNdarray<T>& y, \\\n                                 const XpuVarNdarray<const T>& x) {             \\\n    InplaceApply<BinaryFunc##func_name>(stream, y, x);         
                 \\\n  }\n  OF_PP_FOR_EACH_ATOMIC(DEFINE_INPLACE_BINARY_FUNC, ARITHMETIC_BINARY_FUNC_NAME_SEQ)\n#undef DEFINE_INPLACE_BINARY_FUNC\n\n#define DEFINE_INPLACE_BROADCAST_BINARY_FUNC(func_name)                                  \\\n  static void InplaceBroadcast##func_name(ep::Stream* stream, const XpuVarNdarray<T>& y, \\\n                                          const XpuVarNdarray<const T>& x) {             \\\n    return InplaceBroadcastApply<BinaryFunc##func_name>(stream, y, x);                   \\\n  }\n  OF_PP_FOR_EACH_ATOMIC(DEFINE_INPLACE_BROADCAST_BINARY_FUNC, ARITHMETIC_BINARY_FUNC_NAME_SEQ)\n#undef DEFINE_INPLACE_BROADCAST_BINARY_FUNC\n\n#define DEFINE_REDUCE_FUNC(func_name)                                                 \\\n  static void Reduce##func_name(ep::Stream* stream, const XpuVarNdarray<T>& y,        \\\n                                const XpuVarNdarray<const T>& x,                      \\\n                                const XpuVarNdarray<T>& tmp_storage) {                \\\n    return NdarrayReduce<device_type, T, BinaryFunc##func_name>::Reduce(stream, y, x, \\\n                                                                        tmp_storage); \\\n  }\n  OF_PP_FOR_EACH_ATOMIC(DEFINE_REDUCE_FUNC, REDUCE_BINARY_FUNC_NAME_SEQ)\n#undef DEFINE_REDUCE_FUNC\n\n private:\n  template<template<typename> class unary_func>\n  static void BroadcastApplyUnary(\n      ep::Stream* stream,\n      const XpuVarNdarray<typename UnaryFuncTrait<unary_func, T>::return_type>& y,\n      const XpuVarNdarray<const T>& x) {\n    CHECK_EQ(x.shape().NumAxes(), y.shape().NumAxes());\n    return NdarrayApplyBroadcastUnary<device_type, T, unary_func>::Apply(stream, y, x);\n  }\n\n  template<template<typename> class binary_func>\n  static void BroadcastApplyBinary(\n      ep::Stream* stream,\n      const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,\n      const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {\n    CHECK_EQ(a.shape().NumAxes(), y.shape().NumAxes());\n    CHECK_EQ(b.shape().NumAxes(), y.shape().NumAxes());\n    return NdarrayApplyBroadcastBinary<device_type, T, binary_func>::Apply(stream, y, a, b);\n  }\n\n  template<template<typename> class binary_func>\n  static void InplaceBroadcastApply(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                                    const XpuVarNdarray<const T>& x) {\n    static_assert(std::is_same<T, typename BinaryFuncTrait<binary_func, T>::return_type>::value,\n                  \"T must be same with BinaryFuncTrait<binary_func, T>::return_type\");\n    CHECK_EQ(x.shape().NumAxes(), y.shape().NumAxes());\n    return NdarrayApplyBroadcastBinary<device_type, T, binary_func>::InplaceApply(stream, y, x);\n  }\n\n  template<template<typename> class unary_func>\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y) {\n    static_assert(std::is_same<T, typename UnaryFuncTrait<unary_func, T>::return_type>::value,\n                  \"T must be same with UnaryFuncTrait<unary_func, T>::return_type\");\n    return NdarrayApplyUnary<device_type, T, unary_func>::InplaceApply(stream, y);\n  }\n\n  template<template<typename> class binary_func>\n  static void InplaceApply(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                           const XpuVarNdarray<const T>& x) {\n    static_assert(std::is_same<T, typename BinaryFuncTrait<binary_func, T>::return_type>::value,\n                  \"T must be same with BinaryFuncTrait<binary_func, T>::return_type\");\n    return 
NdarrayApplyBinary<device_type, T, binary_func>::InplaceApply(stream, y, x);\n  }\n\n  template<template<typename> class unary_func>\n  static void ApplyUnary(\n      ep::Stream* stream,\n      const XpuVarNdarray<typename UnaryFuncTrait<unary_func, T>::return_type>& y,\n      const XpuVarNdarray<const T>& x) {\n    return NdarrayApplyUnary<device_type, T, unary_func>::Apply(stream, y, x);\n  }\n\n  template<template<typename> class binary_func>\n  static void ApplyBinary(\n      ep::Stream* stream,\n      const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,\n      const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {\n    if (a.host_ptr() == y.host_ptr()) {\n      CHECK(a.host_shape() == y.host_shape());\n      return NdarrayApplyBinary<device_type, T, binary_func>::InplaceApply(stream, y, b);\n    } else {\n      return NdarrayApplyBinary<device_type, T, binary_func>::Apply(stream, y, a, b);\n    }\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/slice.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/slice.h\"\n\nnamespace oneflow {\n\nSlice::Slice(const std::initializer_list<int64_t>& l) {\n  DimVector vec(l);\n  value_capacity_ = 0;\n  if (vec.size() == 0) {\n    start_ = kStart;\n    end_ = kEnd;\n    stride_ = 1;\n  } else if (vec.size() == 1) {\n    start_ = vec[0];\n    end_ = kEnd;\n    stride_ = 1;\n  } else if (vec.size() == 2) {\n    start_ = vec[0];\n    end_ = vec[1];\n    stride_ = 1;\n  } else if (vec.size() == 3) {\n    start_ = vec[0];\n    end_ = vec[1];\n    stride_ = vec[2];\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nbool Slice::IsBounded() const {\n  CHECK_NE(stride_, 0);\n  if (value_capacity_ == 0) { return false; }\n  return (start_ >= 0) && (start_ <= value_capacity_ - (stride_ < 0)) && (end_ >= 0 - (stride_ < 0))\n         && (end_ <= value_capacity_);\n}\n\nconst Slice& Slice::Bound(size_t value_capacity) {\n  CHECK_GT(value_capacity, 0);\n  if (value_capacity_ == value_capacity) { return *this; }\n  CHECK_EQ(value_capacity_, 0);\n  value_capacity_ = value_capacity;\n  if (start_ != kStart && start_ < 0) { start_ += value_capacity_; }\n  if (end_ != kStart && end_ < 0) { end_ += value_capacity_; }\n  if (start_ == kStart) { start_ = 0; }\n  if (end_ == kEnd) { end_ = value_capacity_; }\n  if (start_ == kEnd) { start_ = value_capacity_ - (stride_ < 0); }\n  if (end_ == kStart) { end_ = 0 - (stride_ < 0); }\n  CHECK_NE(stride_, 0);\n  CHECK_GE(start_, 0);\n  CHECK_LE(start_, value_capacity_);\n  CHECK_GE(end_, 0);\n  CHECK_LE(end_, value_capacity_);\n  return *this;\n}\n\nsize_t Slice::Size() const {\n  CHECK(IsBounded());\n  if (stride_ > 0 && start_ >= end_) { return 0; }\n  if (stride_ < 0 && start_ <= end_) { return 0; }\n  return ((end_ - start_) + (stride_ - ((stride_ > 0) - (stride_ < 0)))) / stride_;\n}\n\nbool Slice::IsContiguous() const {\n  CHECK(IsBounded());\n  return stride_ == 1;\n}\nbool Slice::IsCoveringAll() const {\n  CHECK(IsBounded());\n  return start_ == 0 && end_ == value_capacity_;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/slice.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_SLICE_H_\n#define ONEFLOW_CORE_NDARRAY_SLICE_H_\n\n#include \"oneflow/core/ndarray/cpu_ndarray.h\"\n\nnamespace oneflow {\n\nclass Slice final {\n public:\n  static const int64_t kStart = LLONG_MIN;\n  static const int64_t kEnd = LLONG_MAX;\n\n  Slice(const Slice&) = default;\n  Slice(int64_t index) : start_(index), end_(index + 1), stride_(1), value_capacity_(0) {}\n  Slice(const std::initializer_list<int64_t>& l);\n  ~Slice() = default;\n\n  const Slice& Bound(size_t value_capacity);\n\n  ALWAYS_INLINE int64_t Get(int64_t index) const { return start_ + index * stride_; }\n  bool IsBounded() const;\n  size_t Size() const;\n  bool IsContiguous() const;\n  bool IsCoveringAll() const;\n\n private:\n  int64_t start_;\n  int64_t end_;\n  int64_t stride_;\n  size_t value_capacity_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_SLICE_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/slice_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/slice.h\"\n#include <gtest/gtest.h>\n\nnamespace oneflow {\n\nnamespace test {\n\nTEST(Slice, size) {\n  Slice slice({-2, 0, -1});\n  slice.Bound(4);\n  ASSERT_EQ(slice.Size(), 2);\n}\n\nTEST(Slice, contiguous) {\n  Slice slice({0, -1, 1});\n  slice.Bound(4);\n  ASSERT_TRUE(slice.IsContiguous());\n  ASSERT_FALSE(slice.IsCoveringAll());\n}\n\nTEST(Slice, is_covering_all) {\n  Slice slice({});\n  slice.Bound(4);\n  ASSERT_TRUE(slice.IsCoveringAll());\n  ASSERT_TRUE(slice.IsContiguous());\n}\n\n}  // namespace test\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/unary_func.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_UNARY_FUNC_H_\n#define ONEFLOW_CORE_NDARRAY_UNARY_FUNC_H_\n\n#if defined(__CUDACC__)\n#include <cuda_fp16.hpp>\n#endif\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\n#define ARITHMETIC_UNARY_FUNC_NAME_SEQ (Identity)(Negative)(Exp)\n\n#define PREPEND_PREFIX_UNARY_FUNC(name) OF_PP_CAT(UnaryFunc, name)\n#define ARITHMETIC_UNARY_FUNC_SEQ \\\n  OF_PP_SEQ_MAP(PREPEND_PREFIX_UNARY_FUNC, ARITHMETIC_UNARY_FUNC_NAME_SEQ)\n\ntemplate<template<typename> class UnaryFunc, typename T>\nstruct UnaryFuncTrait final {\n  typedef typename std::remove_const<decltype(UnaryFunc<T>::Invoke(*(const T*)nullptr))>::type\n      return_type;\n};\n\n#define SPECIALIZE_CONST_TYPE_UNARY_FUNC(func_struct)                                     \\\n  template<typename T>                                                                    \\\n  struct func_struct<const T> final {                                                     \\\n    static OF_DEVICE_FUNC const T Invoke(const T x) { return func_struct<T>::Invoke(x); } \\\n  }\n\ntemplate<typename T>\nstruct UnaryFuncIdentity final {\n  static OF_DEVICE_FUNC const T Invoke(const T x) { return x; }\n};\n\ntemplate<typename T>\nstruct UnaryFuncNegative final {\n  static OF_DEVICE_FUNC const T Invoke(const T x) { return -x; }\n};\nSPECIALIZE_CONST_TYPE_UNARY_FUNC(UnaryFuncNegative);\n\ntemplate<typename T>\nstruct UnaryFuncExp final {\n  static OF_DEVICE_FUNC const T Invoke(const T x) {\n#if defined(__CUDA_ARCH__)\n    if (std::is_same<T, double>::value) {\n      return static_cast<T>(exp(static_cast<double>(x)));\n    } else {\n      return static_cast<T>(exp(static_cast<float>(x)));\n    }\n#else\n    return std::exp(x);\n#endif  // defined(__CUDA_ARCH__)\n  }\n};\n\ntemplate<>\nstruct UnaryFuncExp<bool> final {\n  static OF_DEVICE_FUNC bool Invoke(const bool x) {\n#if defined(__CUDA_ARCH__)\n    return static_cast<bool>(exp(static_cast<float>(x)));\n#else\n    return static_cast<bool>(std::exp(static_cast<float>(x)));\n#endif  // defined(__CUDA_ARCH__)\n  }\n};\nSPECIALIZE_CONST_TYPE_UNARY_FUNC(UnaryFuncExp);\n\ntemplate<>\nstruct UnaryFuncExp<float16> final {\n  static OF_DEVICE_FUNC const float16 Invoke(const float16 x) {\n#if defined(__CUDA_ARCH__)\n    half res = static_cast<half>(exp(static_cast<float>(*reinterpret_cast<const half*>(&x))));\n    return *reinterpret_cast<float16*>(&res);\n#else\n    return float16(std::exp(static_cast<float>(x)));\n#endif  // defined(__CUDA_ARCH__)\n  }\n};\n#define NO_HALF_UTIL_FOUND         \\\n  printf(\"cuda arch must >= 530\"); \\\n  assert(false);                   \\\n  return __float2half(0.0)\n\n#if defined(__CUDACC__)\ntemplate<>\nstruct UnaryFuncNegative<half> final {\n  static __device__ __forceinline__ const half Invoke(const half x) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530\n    return 
__hneg(x);\n#else\n    NO_HALF_UTIL_FOUND;\n#endif\n  }\n};\ntemplate<>\nstruct UnaryFuncExp<half> final {\n  static __device__ __forceinline__ const half Invoke(const half x) {\n    return __float2half(std::exp(__half2float(x)));\n  }\n};\n\ntemplate<>\nstruct UnaryFuncNegative<cuComplex> final {\n  static __device__ __forceinline__ const cuComplex Invoke(const cuComplex x) {\n    return cuComplex{-x.x, -x.y};\n  }\n};\ntemplate<>\nstruct UnaryFuncExp<cuComplex> final {\n  static __device__ __forceinline__ const cuComplex Invoke(const cuComplex x) {\n    return cuComplex{exp(x.x) * cos(x.y), exp(x.x) * sin(x.y)};\n  }\n};\n\ntemplate<>\nstruct UnaryFuncNegative<cuDoubleComplex> final {\n  static __device__ __forceinline__ const cuDoubleComplex Invoke(const cuDoubleComplex x) {\n    return cuDoubleComplex{-x.x, -x.y};\n  }\n};\ntemplate<>\nstruct UnaryFuncExp<cuDoubleComplex> final {\n  static __device__ __forceinline__ const cuDoubleComplex Invoke(const cuDoubleComplex x) {\n    return cuDoubleComplex{exp(x.x) * cos(x.y), exp(x.x) * sin(x.y)};\n  }\n};\n#endif\n\ntemplate<typename T>\nstruct UnaryFuncLogicalNot final {\n  static OF_DEVICE_FUNC bool Invoke(const T x) { return !x; }\n};\nSPECIALIZE_CONST_TYPE_UNARY_FUNC(UnaryFuncLogicalNot);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_UNARY_FUNC_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/xpu_binary_func_ndarray.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_XPU_BINARY_FUNC_NDARRAY_H_\n#define ONEFLOW_CORE_NDARRAY_XPU_BINARY_FUNC_NDARRAY_H_\n\n#include \"oneflow/core/ndarray/binary_func.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, template<typename> class binary_func, typename A, typename B>\nclass XpuBinaryFuncNdarray final {\n public:\n  OF_DEVICE_FUNC XpuBinaryFuncNdarray(const A& a, const B& b) : a_(a), b_(b) {}\n\n  template<int NDIMS>\n  OF_DEVICE_FUNC typename BinaryFuncTrait<binary_func, T>::return_type Get(int64_t offset) const {\n    return binary_func<T>::Invoke(a_.template Get<NDIMS>(offset), b_.template Get<NDIMS>(offset));\n  }\n\n private:\n  const A a_;\n  const B b_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_XPU_BINARY_FUNC_NDARRAY_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/xpu_broadcast_ndarray.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_XPU_BROADCAST_NDARRAY_H_\n#define ONEFLOW_CORE_NDARRAY_XPU_BROADCAST_NDARRAY_H_\n\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/ndarray/xpu_ndarray_base.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, int NDIMS>\nstruct XpuBroadcastNdarrayUtil;\n\ntemplate<typename T>\nclass XpuBroadcastNdarray final : public XpuNdarrayBase<XpuBroadcastNdarray<T>, T> {\n public:\n  OF_DEVICE_FUNC XpuBroadcastNdarray(const XpuShape& shape, const XpuVarNdarray<T>& var)\n      : shape_(shape), var_(var) {}\n  ~XpuBroadcastNdarray() = default;\n\n  template<int NDIMS>\n  OF_DEVICE_FUNC T Get(int64_t offset) const {\n    int64_t coord[NDIMS];\n    shape_.template Offset2Coordinate<NDIMS>(offset, coord);\n    XpuBroadcastNdarrayUtil<T, NDIMS>::SrcCoordinate(var_.shape(), coord);\n    return var_.template Get<NDIMS>(coord);\n  }\n\n  OF_DEVICE_FUNC const XpuShape& shape() const { return shape_; }\n  OF_DEVICE_FUNC const XpuVarNdarray<T>& var() const { return var_; }\n\n private:\n  const XpuShape shape_;\n  const XpuVarNdarray<T> var_;\n};\n\n#define IMPLACE_SET_SRC_COORD(i) coord[i] %= src_shape.At(i);\n#define SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(n)                                                \\\n  template<typename T>                                                                          \\\n  struct XpuBroadcastNdarrayUtil<T, n + 1> final {                                              \\\n    OF_DEVICE_FUNC static void SrcCoordinate(const XpuShape& src_shape, int64_t coord[n + 1]) { \\\n      OF_PP_FOR_EACH_TUPLE(IMPLACE_SET_SRC_COORD, GET_SEQ(n));                                  \\\n    }                                                                                           \\\n  }\nSPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(0);\nSPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(1);\nSPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(2);\nSPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(3);\nSPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(4);\nSPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(5);\nSPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(6);\n#undef SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL\n#undef IMPLACE_SET_SRC_COORD\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_XPU_BROADCAST_NDARRAY_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/xpu_ndarray_assign.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/ndarray_assign_core.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename X, int NDIMS>\n__global__ void NdarrayAssignReducedGpu(XpuVarNdarray<T> y,\n                                        const XpuReducedNdarray<X, NDIMS> reduced) {\n  NdarrayAssignCore<T, X, NDIMS>::Assign(y, reduced);\n}\n\ntemplate<typename T, typename X, int NDIMS>\n__global__ void NdarrayAssignGpu(XpuVarNdarray<T> y, const XpuVarNdarray<const X> x) {\n  NdarrayAssignCore<T, X, NDIMS>::Assign(y, x);\n}\n\n}  // namespace\n\ntemplate<typename T, typename X, int NDIMS>\nstruct NdarrayAssignCoreWrapper<DeviceType::kCUDA, T, X, NDIMS> final {\n  static void Assign(ep::Stream* stream, XpuVarNdarray<T>* y,\n                     const XpuReducedNdarray<X, NDIMS>& reduced) {\n    size_t n = y->host_shape().HostElemNum();\n    RUN_CUDA_KERNEL((NdarrayAssignReducedGpu<T, X, NDIMS>), stream, n, *y, reduced);\n  }\n  static void Assign(ep::Stream* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const X>& x) {\n    size_t n = y.host_shape().HostElemNum();\n    if (n == 0) { return; }\n    RUN_CUDA_KERNEL((NdarrayAssignGpu<T, X, NDIMS>), ctx, n, y, x);\n  }\n};\n\n#define INSTANTIATE_NDARRAY_ASSIGN(ret_dtype_pair, dtype_pair, NDIMS)                           \\\n  template struct NdarrayAssignCoreWrapper<DeviceType::kCUDA, OF_PP_PAIR_FIRST(ret_dtype_pair), \\\n                                           OF_PP_PAIR_FIRST(dtype_pair), NDIMS>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n    INSTANTIATE_NDARRAY_ASSIGN,\n    ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ,\n    ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, DIM_SEQ);\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_ASSIGN, HALF_DATA_TYPE_SEQ, HALF_DATA_TYPE_SEQ,\n                                 DIM_SEQ);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/xpu_ndarray_assign.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_XPU_ASSIGN_H_\n#define ONEFLOW_CORE_NDARRAY_XPU_ASSIGN_H_\n\n#include \"oneflow/core/ndarray/ndarray_assign_core.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/include/primitive/elementwise_unary.h\"\n#include \"oneflow/core/ep/include/primitive/unary_op.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, typename Enable = void>\nstruct XpuNdarrayAssign;\n\ntemplate<DeviceType device_type, typename T>\nstruct XpuNdarrayAssign<\n    device_type, T,\n    typename std::enable_if<std::is_same<T, typename DevDType<device_type, T>::type>::value>::type>\n    final {\n  template<int NDIMS, typename X>\n  static void Assign(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                     const XpuReducedNdarray<X, NDIMS>& reduced) {\n    NdarrayAssignCoreWrapper<device_type, T, X, NDIMS>::Assign(stream, y, reduced);\n  }\n  template<int NDIMS, typename X>\n  static void Assign(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                     const XpuVarNdarray<const X>& x) {\n    NdarrayAssignCoreWrapper<device_type, T, X, NDIMS>::Assign(stream, y, x);\n  }\n  static void Assign(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                     const XpuVarNdarray<const T>& x) {\n    CHECK(y.shape() == x.shape());\n    if (x.ptr() == y.ptr()) { return; }\n    Memcpy<device_type>(stream, y.ptr(), x.ptr(), y.shape().ElemNum() * sizeof(T));\n  }\n\n  static void AssignNanSum(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                           const XpuVarNdarray<const T>& x) {\n    CHECK(y.shape() == x.shape());  // NOLINT\n    CHECK_EQ(device_type, stream->device_type()) << \"Device type mismatch\";\n    std::unique_ptr<ep::primitive::ElementwiseUnary> primitive =\n        ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n            device_type, ep::primitive::UnaryOp::kNanAssign, GetDataType<T>(), GetDataType<T>());\n    CHECK(primitive) << \"Can not create NanSum primitive for device type \" << device_type;\n    primitive->Launch(stream, x.ptr(), y.ptr(), y.shape().ElemNum());\n  }\n};\n\ntemplate<DeviceType device_type, typename T>\nstruct XpuNdarrayAssign<\n    device_type, T,\n    typename std::enable_if<!std::is_same<T, typename DevDType<device_type, T>::type>::value>::type>\n    final {\n  using NewT = typename DevDType<device_type, T>::type;\n  template<int NDIMS>\n  static void Assign(ep::Stream* stream, const XpuVarNdarray<T>& y,\n                     const XpuReducedNdarray<T, NDIMS>& reduced) {\n    XpuNdarrayAssign<device_type, NewT>::Assign(\n        stream, reinterpret_cast<const XpuVarNdarray<NewT>&>(y),\n        reinterpret_cast<const XpuReducedNdarray<NewT, NDIMS>&>(reduced));\n  }\n\n  static void Assign(ep::Stream* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) {\n    XpuNdarrayAssign<device_type, NewT>::Assign(\n        ctx, 
reinterpret_cast<const XpuVarNdarray<NewT>&>(y),\n        reinterpret_cast<const XpuVarNdarray<const NewT>&>(x));\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_XPU_ASSIGN_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/xpu_ndarray_base.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_XPU_NDARRAY_BASE_H_\n#define ONEFLOW_CORE_NDARRAY_XPU_NDARRAY_BASE_H_\n\n#include \"oneflow/core/ndarray/xpu_shape.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, template<typename> class unary_func, typename X>\nclass XpuUnaryFuncNdarray;\ntemplate<typename T, template<typename> class binary_func, typename A, typename B>\nclass XpuBinaryFuncNdarray;\ntemplate<typename T>\nclass XpuBroadcastNdarray;\ntemplate<typename T, int, typename X>\nclass XpuTransposeNdarray;\ntemplate<typename T, int, typename X>\nclass XpuReshapeNdarray;\n\ntemplate<typename DerivedT, typename T>\nclass XpuNdarrayBase {\n public:\n  XpuNdarrayBase() = default;\n  ~XpuNdarrayBase() = default;\n\n  template<template<typename> class unary_func>\n  OF_DEVICE_FUNC XpuUnaryFuncNdarray<T, unary_func, DerivedT> UnaryFunc() const {\n    return XpuUnaryFuncNdarray<T, unary_func, DerivedT>(*static_cast<const DerivedT*>(this));\n  }\n  template<template<typename> class binary_func, typename X>\n  OF_DEVICE_FUNC XpuBinaryFuncNdarray<T, binary_func, DerivedT, X> BinaryFunc(const X& x) const {\n    return XpuBinaryFuncNdarray<T, binary_func, DerivedT, X>(*static_cast<const DerivedT*>(this),\n                                                             x);\n  }\n  OF_DEVICE_FUNC XpuBroadcastNdarray<const T> Broadcast(const XpuShape& shape) const {\n    return XpuBroadcastNdarray<const T>(shape, *static_cast<const DerivedT*>(this));\n  }\n  template<int NDIMS>\n  OF_DEVICE_FUNC XpuTransposeNdarray<T, NDIMS, DerivedT> Transpose(\n      const int64_t perm[NDIMS]) const {\n    return XpuTransposeNdarray<T, NDIMS, DerivedT>(*static_cast<const DerivedT*>(this), perm);\n  }\n  template<int NDIMS>\n  OF_DEVICE_FUNC XpuReshapeNdarray<T, NDIMS, DerivedT> Reshape(const int64_t shape[NDIMS]) {\n    return XpuReshapeNdarray<T, NDIMS, DerivedT>(*static_cast<const DerivedT*>(this), shape);\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_XPU_NDARRAY_BASE_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/xpu_reduced_ndarray.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_XPU_REDUCED_NDARRAY_H_\n#define ONEFLOW_CORE_NDARRAY_XPU_REDUCED_NDARRAY_H_\n\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/ndarray/xpu_util.h\"\n#include \"oneflow/core/ndarray/unary_func.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, int NDIMS, typename X = XpuVarNdarray<T>>\nclass XpuReducedNdarray final {\n public:\n  OF_DEVICE_FUNC XpuReducedNdarray(const XpuShape& shape, const X& data)\n      : shape_(shape), data_(data) {}\n\n  OF_DEVICE_FUNC const XpuShape& shape() const { return shape_; }\n  const XpuShape& host_shape() const { return shape_; }\n  OF_DEVICE_FUNC const X& data() const { return data_; }\n\n  template<int ndims = NDIMS>\n  OF_DEVICE_FUNC T Get(int64_t offset) const {\n    int64_t coord[NDIMS];\n    shape_.template Offset2Coordinate<NDIMS>(offset, coord);\n    return Get(coord);\n  }\n\n  template<int ndims = NDIMS>\n  OF_DEVICE_FUNC T Get(int64_t coord[ndims]) const {\n    return data_.template Get<ndims>(coord);\n  }\n\n  template<int ndims = NDIMS>\n  OF_DEVICE_FUNC T* Mut(int64_t offset) const {\n    int64_t coord[NDIMS];\n    shape_.template Offset2Coordinate<NDIMS>(offset, coord);\n    return Mut(coord);\n  }\n\n  template<int ndims = NDIMS>\n  OF_DEVICE_FUNC T* Mut(int64_t coord[NDIMS]) const {\n    return data_.template Mut<NDIMS>(coord);\n  }\n\n private:\n  XpuShape shape_;\n  X data_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_XPU_REDUCED_NDARRAY_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/xpu_reshape_ndarray.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_XPU_RESHAPE_NDARRAY_H_\n#define ONEFLOW_CORE_NDARRAY_XPU_RESHAPE_NDARRAY_H_\n\nnamespace oneflow {\n\ntemplate<typename T, int NDIMS, typename X = XpuVarNdarray<T>>\nclass XpuReshapeNdarray final {\n public:\n  OF_DEVICE_FUNC XpuReshapeNdarray(const X& x, const int64_t dim[NDIMS])\n      : x_(x), shape_(dim, NDIMS) {}\n\n  template<int ndims = NDIMS>\n  OF_DEVICE_FUNC T Get(int64_t offset) const {\n    return x_.template Get<ndims>(offset);\n  }\n  template<int ndims = NDIMS>\n  OF_DEVICE_FUNC T* Mut(int64_t offset) const {\n    return x_.template Mut<ndims>(offset);\n  }\n  template<int ndims = NDIMS>\n  OF_DEVICE_FUNC T Get(int64_t coord[ndims]) const {\n    return Get<ndims>(Coord2Offset(coord));\n  }\n  template<int ndims = NDIMS>\n  OF_DEVICE_FUNC T* Mut(int64_t coord[NDIMS]) const {\n    return Get<NDIMS>(Coord2Offset(coord));\n  }\n\n private:\n  OF_DEVICE_FUNC int64_t Coord2Offset(const int64_t coord[NDIMS]) const {\n    return XpuShapeUtil<NDIMS>::Coord2Offset(shape_, coord);\n  }\n  const X& x_;\n  XpuShape shape_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_XPU_RESHAPE_NDARRAY_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/xpu_shape.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ndarray/xpu_shape.h\"\n\nnamespace oneflow {\n\nXpuShape::XpuShape(const int64_t dim[], int num_axes) {\n  num_axes_ = num_axes;\n  int i = 0;\n  for (; i < num_axes_; ++i) { dim_[i] = dim[i]; }\n  UpdateDimElemNumAndElemNum();\n  for (; i < sizeof(dim_) / sizeof(dim_[0]); ++i) {\n    dim_[i] = 1;\n    dim_elem_num_[i] = 1;\n  }\n}\n\nXpuShape::XpuShape(const Shape& shape) {\n  num_axes_ = shape.NumAxes();\n  int i = 0;\n  for (; i < num_axes_; ++i) { dim_[i] = shape.At(i); }\n  UpdateDimElemNumAndElemNum();\n  for (; i < sizeof(dim_) / sizeof(dim_[0]); ++i) {\n    dim_[i] = 1;\n    dim_elem_num_[i] = 1;\n  }\n}\n\nXpuShape::XpuShape(const ShapeView& shape) {\n  num_axes_ = shape.NumAxes();\n  int i = 0;\n  for (; i < num_axes_; ++i) { dim_[i] = shape.At(i); }\n  UpdateDimElemNumAndElemNum();\n  for (; i < sizeof(dim_) / sizeof(dim_[0]); ++i) {\n    dim_[i] = 1;\n    dim_elem_num_[i] = 1;\n  }\n}\n\nXpuShape::XpuShape(const ShapeView& shape, int ndims_left_extend_to) {\n  if (shape.NumAxes() == 1 && ndims_left_extend_to == 0) {\n    num_axes_ = 0;\n    int i = 0;\n    dim_[i] = shape.At(i);\n    UpdateDimElemNumAndElemNum();\n    for (; i < sizeof(dim_) / sizeof(dim_[0]); ++i) { dim_[i] = 1; }\n  } else {\n    CHECK_LE(shape.NumAxes(), ndims_left_extend_to);\n    num_axes_ = ndims_left_extend_to;\n    size_t left_ones_num = num_axes_ - shape.NumAxes();\n    int i = 0;\n    for (; i < left_ones_num; ++i) { dim_[i] = 1; }\n    for (; i < num_axes_; ++i) { dim_[i] = shape.At(i - left_ones_num); }\n    UpdateDimElemNumAndElemNum();\n    for (; i < sizeof(dim_) / sizeof(dim_[0]); ++i) {\n      dim_[i] = 1;\n      dim_elem_num_[i] = 1;\n    }\n  }\n}\n\nbool XpuShape::operator==(const XpuShape& rhs) const {\n  if (num_axes_ != rhs.num_axes_) { return false; }\n  if (elem_num_ != rhs.elem_num_) { return false; }\n  for (int i = 0; i < num_axes_; ++i) {\n    if (dim_[i] != rhs.dim_[i]) { return false; }\n    if (dim_elem_num_[i] != rhs.dim_elem_num_[i]) { return false; }\n  }\n  return true;\n}\n\nvoid SimplifyBroadcastShapes(const XpuShape& y, const XpuShape& b, DimVector* simplified_y,\n                             DimVector* simplified_b) {\n  DimVector simplified_a;\n  SimplifyBroadcastShapes(y, y, b, simplified_y, &simplified_a, simplified_b);\n}\n\nvoid SimplifyBroadcastShapes(const XpuShape& y, const XpuShape& a, const XpuShape& b,\n                             DimVector* simplified_y, DimVector* simplified_a,\n                             DimVector* simplified_b) {\n  CHECK_EQ(y.NumAxes(), a.NumAxes());\n  CHECK_EQ(b.NumAxes(), a.NumAxes());\n  CHECK(simplified_y->empty());\n  CHECK(simplified_a->empty());\n  CHECK(simplified_b->empty());\n  simplified_y->emplace_back(y.At(0));\n  simplified_a->emplace_back(a.At(0));\n  simplified_b->emplace_back(b.At(0));\n  bool a_prev_axis_is_broadcast = (a.At(0) == 1);\n  bool b_prev_axis_is_broadcast = (b.At(0) == 
1);\n  FOR_RANGE(int, i, 1, y.NumAxes()) {\n    const int64_t y_dim = y.At(i);\n    const int64_t a_dim = a.At(i);\n    const int64_t b_dim = b.At(i);\n    if ((a_dim == 1) && (b_dim == 1)) { continue; }\n    const bool a_cur_axis_is_broadcast = (a_dim == 1);\n    const bool b_cur_axis_is_broadcast = (b_dim == 1);\n    if (a_prev_axis_is_broadcast == a_cur_axis_is_broadcast\n        && b_prev_axis_is_broadcast == b_cur_axis_is_broadcast) {\n      simplified_y->back() *= y_dim;\n      simplified_a->back() *= a_dim;\n      simplified_b->back() *= b_dim;\n    } else {\n      simplified_y->emplace_back(y_dim);\n      simplified_a->emplace_back(a_dim);\n      simplified_b->emplace_back(b_dim);\n    }\n    a_prev_axis_is_broadcast = a_cur_axis_is_broadcast;\n    b_prev_axis_is_broadcast = b_cur_axis_is_broadcast;\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/ndarray/xpu_shape.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_XPU_SHAPE_H_\n#define ONEFLOW_CORE_NDARRAY_XPU_SHAPE_H_\n\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ndarray/xpu_util.h\"\n\nnamespace oneflow {\n\ntemplate<int NDIMS>\nstruct XpuShapeUtil;\n\nclass XpuShape final {\n public:\n  explicit XpuShape(const Shape& shape);\n  explicit XpuShape(const ShapeView& shape);\n  explicit XpuShape(const ShapeView& shape, int ndims_left_extend_to);\n  OF_DEVICE_FUNC XpuShape(const int64_t dim[], int num_axes);\n  XpuShape(const XpuShape&) = default;\n\n  OF_DEVICE_FUNC int64_t At(int64_t dim) const { return dim_[dim]; }\n  OF_DEVICE_FUNC int64_t DimElemNum(int64_t dim) const { return dim_elem_num_[dim]; }\n  OF_DEVICE_FUNC int64_t Count(int64_t dim) const { return At(dim) * DimElemNum(dim); }\n\n  OF_DEVICE_FUNC size_t ElemNum() const { return elem_num_; }\n  OF_DEVICE_FUNC size_t NumAxes() const { return num_axes_; }\n  size_t HostElemNum() const { return elem_num_; }\n  bool operator==(const XpuShape&) const;\n  bool operator!=(const XpuShape& rhs) const { return !(*this == rhs); }\n\n  OF_DEVICE_FUNC void Set(int64_t axis, int64_t value) {\n    dim_[axis] = value;\n    UpdateDimElemNumAndElemNum();\n  }\n\n  template<int NDIMS>\n  OF_DEVICE_FUNC int64_t Coordinate2Offset(const int64_t coord[NDIMS]) const {\n    return XpuShapeUtil<NDIMS>::Coordinate2Offset(*this, coord);\n  }\n  template<int NDIMS>\n  OF_DEVICE_FUNC void Offset2Coordinate(int64_t offset, int64_t coord[NDIMS]) const {\n    XpuShapeUtil<NDIMS>::Offset2Coordinate(*this, offset, coord);\n  }\n\n  OF_DEVICE_FUNC void UpdateDimElemNumAndElemNum() {\n    elem_num_ = 1;\n    for (int i = num_axes_ - 1; i >= 0; --i) {\n      dim_elem_num_[i] = elem_num_;\n      elem_num_ *= dim_[i];\n    }\n  }\n\n  std::string ToString() const { return ShapeView(dim_, num_axes_).ToString(); }\n\n private:\n  size_t num_axes_;\n  size_t elem_num_;\n  int64_t dim_[OF_PP_SEQ_SIZE(DIM_SEQ)];\n  int64_t dim_elem_num_[OF_PP_SEQ_SIZE(DIM_SEQ)];\n};\n\ntemplate<>\nstruct XpuShapeUtil<1> final {\n  OF_DEVICE_FUNC static int64_t Coordinate2Offset(const XpuShape& shape, const int64_t coord[1]) {\n    return coord[0];\n  }\n  OF_DEVICE_FUNC static void Offset2Coordinate(const XpuShape& shape, int64_t offset,\n                                               int64_t coord[1]) {\n    coord[0] = offset;\n  }\n};\n\n#define COORD_MUL_STRIDE(i) coord[i] * shape.DimElemNum(i) +\n#define EXTRACT_COORD(i)                   \\\n  coord[i] = offset / shape.DimElemNum(i); \\\n  offset %= shape.DimElemNum(i);\n\n#define SPECIALIZE_XPU_SHAPE_UTIL(n)                                                    \\\n  template<>                                                                            \\\n  struct XpuShapeUtil<n + 2> final {            
                                        \\\n    OF_DEVICE_FUNC static int64_t Coordinate2Offset(const XpuShape& shape,              \\\n                                                    const int64_t coord[n + 2]) {       \\\n      return OF_PP_FOR_EACH_TUPLE(COORD_MUL_STRIDE, GET_SEQ(n)) coord[n + 1];           \\\n    }                                                                                   \\\n    OF_DEVICE_FUNC static void Offset2Coordinate(const XpuShape& shape, int64_t offset, \\\n                                                 int64_t coord[n + 2]) {                \\\n      OF_PP_FOR_EACH_TUPLE(EXTRACT_COORD, GET_SEQ(n));                                  \\\n      coord[n + 1] = offset;                                                            \\\n    }                                                                                   \\\n  };\n\nSPECIALIZE_XPU_SHAPE_UTIL(0);\nSPECIALIZE_XPU_SHAPE_UTIL(1);\nSPECIALIZE_XPU_SHAPE_UTIL(2);\nSPECIALIZE_XPU_SHAPE_UTIL(3);\nSPECIALIZE_XPU_SHAPE_UTIL(4);\nSPECIALIZE_XPU_SHAPE_UTIL(5);\n#undef SPECIALIZE_XPU_SHAPE_UTIL\n#undef EXTRACT_COORD\n#undef COORD_MUL_STRIDE\n\nvoid SimplifyBroadcastShapes(const XpuShape& y, const XpuShape& b, DimVector* simplified_y,\n                             DimVector* simplified_b);\n\nvoid SimplifyBroadcastShapes(const XpuShape& y, const XpuShape& a, const XpuShape& b,\n                             DimVector* simplified_y, DimVector* simplified_a,\n                             DimVector* simplified_b);\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_XPU_SHAPE_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/xpu_transpose_ndarray.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_XPU_TRANSPOSE_NDARRAY_H_\n#define ONEFLOW_CORE_NDARRAY_XPU_TRANSPOSE_NDARRAY_H_\n\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, int NDIMS, typename X = XpuVarNdarray<T>>\nclass XpuTransposeNdarray final {\n public:\n  OF_DEVICE_FUNC XpuTransposeNdarray(const X& x, const int64_t perm[NDIMS])\n      : x_(x), shape_(x.shape()) {\n    for (int i = 0; i < NDIMS; ++i) {\n      perm_[i] = perm[i];\n      shape_.Set(i, x.shape().At(perm[i]));\n    }\n  }\n\n  template<int ndims, typename = typename std::enable_if<ndims == NDIMS>::type>\n  OF_DEVICE_FUNC T Get(int64_t offset) const {\n    int64_t coord[NDIMS];\n    Offset2Coord(offset, coord);\n    return Get(coord);\n  }\n\n  template<int ndims, typename = typename std::enable_if<ndims == NDIMS>::type>\n  OF_DEVICE_FUNC T* Mut(int64_t offset) const {\n    int64_t coord[NDIMS];\n    Offset2Coord(offset, coord);\n    return Mut(coord);\n  }\n\n  template<int ndims, typename = typename std::enable_if<ndims == NDIMS>::type>\n  OF_DEVICE_FUNC T Get(int64_t coord[ndims]) const {\n    int64_t permuted_coord[NDIMS];\n    PermuteCoord(coord, permuted_coord);\n    return x_.template Get<ndims>(permuted_coord);\n  }\n\n  template<int ndims, typename = typename std::enable_if<ndims == NDIMS>::type>\n  OF_DEVICE_FUNC T* Mut(int64_t coord[NDIMS]) const {\n    int64_t permuted_coord[NDIMS];\n    PermuteCoord(coord, permuted_coord);\n    return x_.template Mut<NDIMS>(permuted_coord);\n  }\n\n private:\n  OF_DEVICE_FUNC void Offset2Coord(int64_t offset, int64_t coord[NDIMS]) const {\n    shape_.template Offset2Coordinate<NDIMS>(offset, coord);\n  }\n\n  OF_DEVICE_FUNC void PermuteCoord(const int64_t coord[NDIMS],\n                                   int64_t permuted_coord[NDIMS]) const {\n    for (int i = 0; i < NDIMS; ++i) { permuted_coord[perm_[i]] = coord[i]; }\n  }\n\n  const X& x_;\n  XpuShape shape_;\n  int64_t perm_[NDIMS];\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_XPU_TRANSPOSE_NDARRAY_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/xpu_unary_func_ndarray.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_UNARY_FUNC_NDARRAY_H_\n#define ONEFLOW_CORE_UNARY_FUNC_NDARRAY_H_\n\nnamespace oneflow {\n\ntemplate<typename T, template<typename> class unary_func, typename X>\nclass XpuUnaryFuncNdarray final {\n public:\n  OF_DEVICE_FUNC XpuUnaryFuncNdarray(const X& x) : x_(x) {}\n\n  template<int NDIMS>\n  OF_DEVICE_FUNC T Get(int64_t offset) const {\n    return unary_func<T>::Invoke(x_.template Get<NDIMS>(offset));\n  }\n\n private:\n  const X& x_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_UNARY_FUNC_NDARRAY_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/xpu_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_XPU_UTIL_H_\n#define ONEFLOW_CORE_NDARRAY_XPU_UTIL_H_\n\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\n#if defined(__CUDACC__)\n#define XPU_1D_KERNEL_LOOP_BEGIN(i, n) CUDA_1D_KERNEL_LOOP(i, n) {\n#define XPU_1D_KERNEL_LOOP_END() }\n#else\n#define XPU_1D_KERNEL_LOOP_BEGIN(i, n) MultiThreadLoop(n, [&](size_t i) {\n#define XPU_1D_KERNEL_LOOP_END() \\\n  });\n#endif\n\n#if defined(__CUDACC__)\n#define XPU_1D_KERNEL_LOOP(i, n) CUDA_1D_KERNEL_LOOP(i, n)\n#else\n#define XPU_1D_KERNEL_LOOP(i, n) FOR_RANGE(int64_t, i, 0, n)\n#endif\n\n#if defined(__CUDACC__)\n#define XPU_BLOAD_THREAD_2D_KERNEL_LOOP(i, j, m, n)     \\\n  for (int64_t i = blockIdx.x; i < (m); i += gridDim.x) \\\n    for (int64_t j = threadIdx.x; j < (n); j += blockDim.x)\n#else\n#define XPU_BLOAD_THREAD_2D_KERNEL_LOOP(i, j, m, n) \\\n  for (int64_t i = 0; i < (m); ++i)                 \\\n    for (int64_t j = 0; j < (n); ++j)\n#endif\n\n#if defined(__CUDACC__)\n#define OF_GLOBAL_FUNC __global__\n#else\n#define OF_GLOBAL_FUNC\n#endif\n\n#define GET_SEQ(n) OF_PP_CAT(OF_PP_CAT(GET_SEQ_, n), )\n#define GET_SEQ_0 OF_PP_MAKE_TUPLE_SEQ(0)\n#define GET_SEQ_1 GET_SEQ_0 OF_PP_MAKE_TUPLE_SEQ(1)\n#define GET_SEQ_2 GET_SEQ_1 OF_PP_MAKE_TUPLE_SEQ(2)\n#define GET_SEQ_3 GET_SEQ_2 OF_PP_MAKE_TUPLE_SEQ(3)\n#define GET_SEQ_4 GET_SEQ_3 OF_PP_MAKE_TUPLE_SEQ(4)\n#define GET_SEQ_5 GET_SEQ_4 OF_PP_MAKE_TUPLE_SEQ(5)\n#define GET_SEQ_6 GET_SEQ_5 OF_PP_MAKE_TUPLE_SEQ(6)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_XPU_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/xpu_var_ndarray.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_H_\n#define ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_H_\n\n#include \"oneflow/core/ndarray/xpu_shape.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/register/blob.h\"\n#include \"oneflow/core/ndarray/xpu_util.h\"\n#include \"oneflow/core/ndarray/xpu_ndarray_base.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass XpuVarNdarray final : public XpuNdarrayBase<XpuVarNdarray<T>, T> {\n public:\n  XpuVarNdarray(const Blob* blob, int ndims_left_extend_to)\n      : shape_(blob->shape(), ndims_left_extend_to),\n        ptr_(blob->dptr<typename std::remove_const<T>::type>()) {}\n  XpuVarNdarray(Blob* blob, int ndims_left_extend_to)\n      : shape_(blob->shape(), ndims_left_extend_to), ptr_(blob->mut_dptr<T>()) {}\n  XpuVarNdarray(const Shape& shape, T* ptr) : shape_(shape), ptr_(ptr) {}\n  XpuVarNdarray(const ShapeView& shape, T* ptr) : shape_(shape), ptr_(ptr) {}\n  XpuVarNdarray(const ShapeView& shape, T* ptr, int ndims_left_extend_to)\n      : shape_(shape, ndims_left_extend_to), ptr_(ptr) {}\n  ~XpuVarNdarray() = default;\n  ALWAYS_INLINE XpuVarNdarray(const XpuVarNdarray&) = default;\n  OF_DEVICE_FUNC ALWAYS_INLINE XpuVarNdarray(const XpuShape& shape, T* ptr)\n      : shape_(shape), ptr_(ptr) {}\n\n  const XpuShape& host_shape() const { return shape_; }\n  T* host_ptr() const { return ptr_; }\n\n  OF_DEVICE_FUNC const XpuShape& shape() const { return shape_; }\n  OF_DEVICE_FUNC T* ptr() const { return ptr_; }\n\n  template<int NDIMS>\n  OF_DEVICE_FUNC T Get(int64_t offset) const {\n    return ptr_[offset];\n  }\n  template<int NDIMS>\n  OF_DEVICE_FUNC T Get(int64_t coord[NDIMS]) const {\n    return ptr_[shape().template Coordinate2Offset<NDIMS>(coord)];\n  }\n\n  template<int NDIMS>\n  OF_DEVICE_FUNC T* Mut(int64_t offset) const {\n    return ptr_ + offset;\n  }\n\n  template<int NDIMS>\n  OF_DEVICE_FUNC T* Mut(int64_t coord[NDIMS]) const {\n    return ptr_ + shape().template Coordinate2Offset<NDIMS>(coord);\n  }\n\n  template<int NDIMS, typename X>\n  OF_DEVICE_FUNC void Assign(const X& x) const {\n    size_t n = shape_.ElemNum();\n    XPU_1D_KERNEL_LOOP_BEGIN(i, n);\n    ptr_[i] = x.template Get<NDIMS>(i);\n    XPU_1D_KERNEL_LOOP_END();\n  }\n\n  template<template<typename> class binary_func, int NDIMS, typename X>\n  OF_DEVICE_FUNC void BinaryAssign(const X& x) const {\n    size_t n = shape_.ElemNum();\n    XPU_1D_KERNEL_LOOP_BEGIN(i, n);\n    T* ptr_i = ptr_ + i;\n    *ptr_i = binary_func<T>::Invoke(*ptr_i, x.template Get<NDIMS>(i));\n    XPU_1D_KERNEL_LOOP_END();\n  }\n\n private:\n  XpuShape shape_;\n  T* ptr_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_H_\n"
  },
  {
    "path": "oneflow/core/ndarray/xpu_var_ndarray_builder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_BUILDER_H_\n#define ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_BUILDER_H_\n\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass XpuVarNdarrayBuilder final {\n public:\n  XpuVarNdarrayBuilder() = default;\n  XpuVarNdarrayBuilder(const XpuVarNdarrayBuilder&) = default;\n  ~XpuVarNdarrayBuilder() = default;\n\n  XpuVarNdarray<T> operator()(const Shape& shape, T* ptr) const {\n    return XpuVarNdarray<T>(shape, ptr);\n  }\n  template<typename DT = T>\n  typename std::enable_if<!std::is_same<DT, const DT>::value, XpuVarNdarray<DT>>::type operator()(\n      Blob* blob, int ndims_extend_to) const {\n    return XpuVarNdarray<DT>(blob, ndims_extend_to);\n  }\n  template<typename DT = T>\n  typename std::enable_if<!std::is_same<DT, const DT>::value, XpuVarNdarray<const DT>>::type\n  operator()(const Blob* blob, int ndims_extend_to) const {\n    return XpuVarNdarray<const DT>(blob, ndims_extend_to);\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_BUILDER_H_\n"
  },
  {
    "path": "oneflow/core/operator/acc_tick_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/acc_tick_op.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp) {\n  *GetBlobDesc4BnInOp(\"acc\") = *GetBlobDesc4BnInOp(\"one\");\n  GetBlobDesc4BnInOp(\"acc\")->set_shape(Shape({1LL}));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> AccTickOp::InitFromOpConf() {\n  CHECK(op_conf().has_acc_tick_conf());\n\n  EnrollInputBn(\"one\", false);\n  EnrollOutputBn(\"acc\", false);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AccTickOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  return InferBlobDescs(BlobDesc4BnInOp);\n}\n\nMaybe<void> AccTickOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  return InferBlobDescs(GetBlobDesc4BnInOp);\n}\n\nMaybe<void> AccTickOp::InferOpTimeShape(\n    const std::function<Maybe<const Shape>(const std::string&)>& GetTimeShape4BnInOp,\n    std::shared_ptr<const Shape>* time_shape) const {\n  const int32_t max_acc_num = op_conf().acc_tick_conf().max_acc_num();\n  std::shared_ptr<const Shape> in_shape = JUST(GetTimeShape4BnInOp(\"one\"));\n  CHECK_EQ_OR_RETURN(in_shape->elem_cnt() % max_acc_num, 0);\n  DimVector in_dim_vec = in_shape->dim_vec();\n  std::shared_ptr<Shape> op_time_shape;\n  if (in_dim_vec.back() == max_acc_num) {\n    in_dim_vec.pop_back();\n    op_time_shape.reset(new Shape(in_dim_vec));\n  } else if (in_dim_vec.back() % max_acc_num == 0) {\n    in_dim_vec.back() /= max_acc_num;\n    op_time_shape.reset(new Shape(in_dim_vec));\n  } else {\n    op_time_shape.reset(new Shape({in_shape->elem_cnt() / max_acc_num}));\n  }\n  *time_shape = op_time_shape;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AccTickOp::GetSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    SbpSignatureList* sbp_sig_list) const {\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP(OperatorConf::kAccTickConf, AccTickOp);\nREGISTER_TICK_TOCK_OP(OperatorConf::kAccTickConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/acc_tick_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_ACC_TICK_OP_H_\n#define ONEFLOW_CORE_OPERATOR_ACC_TICK_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass AccTickOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AccTickOp);\n  AccTickOp() = default;\n  ~AccTickOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n  Maybe<void> InferOpTimeShape(\n      const std::function<Maybe<const Shape>(const std::string&)>& GetTimeShape4BnInOp,\n      std::shared_ptr<const Shape>* time_shape) const override;\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override;\n\n private:\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_ACC_TICK_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/arg_modifier_signature.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage InputBlobModifier {\n  optional bool is_mutable = 1 [default = false];\n  optional bool requires_grad = 3 [default = false];\n}\n\nmessage OutputBlobModifier {\n  optional bool is_mutable = 1 [default = false];\n  optional bool requires_grad = 2 [default = false];\n  optional bool header_infered_before_compute = 3 [default = true];\n  oneof inplace_type {\n    string mutable_inplace_ibn = 20;\n    string const_inplace_ibn = 21;\n  }\n}\n\nmessage ArgModifierSignature {\n  map<string, InputBlobModifier> ibn2input_blob_modifier = 1;\n  map<string, OutputBlobModifier> obn2output_blob_modifier = 2;\n}\n"
  },
  {
    "path": "oneflow/core/operator/assign_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass AssignOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AssignOp);\n  AssignOp() = default;\n  ~AssignOp() override = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override;\n};\n\nMaybe<void> AssignOp::InitFromOpConf() {\n  CHECK(op_conf().has_assign_conf());\n  EnrollInputBn(\"ref\")->set_is_mutable(true);\n  EnrollInputBn(\"value\");\n  return Maybe<void>::Ok();\n}\n\nstd::string DebugString(const BlobDesc& blob_desc) {\n  BlobDescProto blob_desc_proto;\n  blob_desc.ToProto(&blob_desc_proto);\n  return blob_desc_proto.DebugString();\n}\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  CHECK_OR_RETURN(*BlobDesc4BnInOp(\"ref\") == *BlobDesc4BnInOp(\"value\"))\n      << \"\\nref_blob_desc: \" << DebugString(*BlobDesc4BnInOp(\"ref\"))\n      << \"\\nvalue_blob_desc: \" << DebugString(*BlobDesc4BnInOp(\"value\"));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> AssignOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  return InferBlobDescs(BlobDesc4BnInOp);\n}\n\nMaybe<void> AssignOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  return InferBlobDescs(GetBlobDesc4BnInOp);\n}\n\nMaybe<void> AssignOp::GetSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    SbpSignatureList* sbp_sig_list) const {\n  SbpSignatureBuilder()\n      .Split(input_bns(), 0)\n      .MakeSplitSignatureListBuilder(\n          JUST(LogicalBlobDesc4Ibn(input_bns().Get(0))).shape().NumAxes())\n      .Build(sbp_sig_list);\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP(OperatorConf::kAssignConf, AssignOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/boxing_identity_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/register/tensor_slice_view.h\"\n\nnamespace oneflow {\n\nclass BoxingIdentityOp : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BoxingIdentityOp);\n  BoxingIdentityOp() = default;\n  ~BoxingIdentityOp() override = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  LogicalBlobId lbi4ibn(const std::string& input_bn) const override;\n  LogicalBlobId lbi4obn(const std::string& output_bn) const override;\n};\n\nMaybe<void> BoxingIdentityOp::InitFromOpConf() {\n  EnrollInputBn(\"in\", false);\n  EnrollOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nLogicalBlobId BoxingIdentityOp::lbi4ibn(const std::string& input_bn) const {\n  return this->op_conf().boxing_identity_conf().lbi();\n}\n\nLogicalBlobId BoxingIdentityOp::lbi4obn(const std::string& output_bn) const {\n  return this->op_conf().boxing_identity_conf().lbi();\n}\n\nMaybe<void> BoxingIdentityOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  *GetBlobDesc4BnInOp(\"out\") = *GetBlobDesc4BnInOp(\"in\");\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP(OperatorConf::kBoxingIdentityConf, BoxingIdentityOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/boxing_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/boxing_op.h\"\n#include \"oneflow/core/common/protobuf.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nvoid EraseEmptyBnInVec(const std::function<const BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n                       PbRpf<std::string>* bns) {\n  size_t idx_available = 0;\n  for (size_t i = 0; i < bns->size(); ++i) {\n    if (GetBlobDesc4BnInOp((*bns)[i])) {\n      if (i != idx_available) { (*bns)[idx_available] = (*bns)[i]; }\n      ++idx_available;\n    }\n  }\n  bns->erase(bns->begin() + idx_available, bns->end());\n}\n\n}  // namespace\n\nvoid BoxingOp::VirtualGenKernelConf(\n    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {\n  OpAttribute* op_attribute = kernel_conf->mutable_op_attribute();\n  EraseEmptyBnInVec(GetBlobDesc4BnInOp, op_attribute->mutable_input_bns());\n  EraseEmptyBnInVec(GetBlobDesc4BnInOp, op_attribute->mutable_output_bns());\n}\n\nMaybe<void> BoxingOp::InitFromOpConf() {\n  CHECK(op_conf().has_boxing_conf());\n  const BoxingOpConf& boxing_conf = op_conf().boxing_conf();\n\n  for (int32_t i = 0; i < boxing_conf.in_num(); ++i) {\n    EnrollInputBn(\"in_\" + std::to_string(i), false);\n  }\n  if (boxing_conf.in_box_case() == BoxingOpConf::kAddBox\n      && boxing_conf.out_box_case() == BoxingOpConf::kSplitBox) {\n    EnrollTmpBn(\"middle\");\n  }\n  for (int32_t i = 0; i < boxing_conf.out_num(); ++i) {\n    EnrollOutputBn(\"out_\" + std::to_string(i), false);\n  }\n  return Maybe<void>::Ok();\n}\n\nLogicalBlobId BoxingOp::lbi4ibn(const std::string& input_bn) const {\n  return op_conf().boxing_conf().lbi();\n}\n\nLogicalBlobId BoxingOp::lbi4obn(const std::string& output_bn) const {\n  return op_conf().boxing_conf().lbi();\n}\n\nSymbol<OperatorConf> BoxingOp::GetOpConfWithoutOpNameAndLbn() const {\n  OperatorConf op_conf(this->op_conf());\n  op_conf.set_name(\"undefined-op-name\");\n  CHECK(op_conf.has_boxing_conf());\n  auto* boxing_conf = op_conf.mutable_boxing_conf();\n  LogicalBlobId empty_logical_blob_id;\n  *boxing_conf->mutable_lbi() = empty_logical_blob_id;\n  return SymbolOf(op_conf);\n}\n\nMaybe<void> BoxingOp::InferBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp, bool is_logical) const {\n  const BoxingOpConf& conf = op_conf().boxing_conf();\n  BlobDesc* first_in_blob = BlobDesc4BnInOp(input_bns().Get(0));\n  if (conf.in_box_case() == BoxingOpConf::kAddBox) {\n    const Shape& first_in_blob_shape = first_in_blob->shape();\n    for (const std::string& ibn : input_bns()) {\n      CHECK_EQ_OR_RETURN(first_in_blob_shape, BlobDesc4BnInOp(ibn)->shape());\n    }\n  }\n\n  DimVector data_tmp_blob_shape_vec = BlobDesc4BnInOp(input_bns().Get(0))->shape().dim_vec();\n  JUST(InferTmpBlobDesc(BlobDesc4BnInOp, &data_tmp_blob_shape_vec, is_logical));\n\n  if (conf.out_box_case() == 
BoxingOpConf::kSplitBox) {\n    const BoxSplitConf& split_conf = conf.split_box();\n    CHECK_GE_OR_RETURN(split_conf.axis(), 0);\n    CHECK_LT_OR_RETURN(split_conf.axis(), data_tmp_blob_shape_vec.size());\n    FOR_RANGE(size_t, i, 0, output_bns().size()) {\n      BlobDesc* out_blob_desc = BlobDesc4BnInOp(output_bns().Get(i));\n      *out_blob_desc = *first_in_blob;\n      CHECK_GT_OR_RETURN(split_conf.part_num(i), 0);\n      data_tmp_blob_shape_vec[split_conf.axis()] = split_conf.part_num(i);\n      out_blob_desc->set_shape(Shape(data_tmp_blob_shape_vec));\n    }\n  } else if (conf.out_box_case() == BoxingOpConf::kCloneBox) {\n    for (const std::string& obn : output_bns()) {\n      BlobDesc* out_blob_desc = BlobDesc4BnInOp(obn);\n      *out_blob_desc = *first_in_blob;\n      out_blob_desc->set_shape(Shape(data_tmp_blob_shape_vec));\n    }\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BoxingOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  return InferBlobDescs(BlobDesc4BnInOp, true);\n}\n\nMaybe<void> BoxingOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  return InferBlobDescs(GetBlobDesc4BnInOp, false);\n}\n\nMaybe<void> BoxingOp::InferTmpBlobDesc(\n    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, DimVector* data_tmp_vec_ptr,\n    bool is_logical) const {\n  const BoxingOpConf& conf = op_conf().boxing_conf();\n  if (conf.in_box_case() == BoxingOpConf::kConcatBox) {\n    int32_t concat_axis = conf.concat_box().axis();\n    CHECK_GE_OR_RETURN(concat_axis, 0);\n    FOR_RANGE(size_t, ib_idx, 1, input_bns().size()) {\n      const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp(input_bns().Get(ib_idx));\n      const DimVector& in_blob_shape_vec = in_blob_desc->shape().dim_vec();\n      CHECK_LT_OR_RETURN(concat_axis, in_blob_shape_vec.size());\n      FOR_RANGE(size_t, i, 0, in_blob_shape_vec.size()) {\n        if (i == concat_axis) {\n          (*data_tmp_vec_ptr)[i] += in_blob_shape_vec[i];\n        } else {\n          CHECK_EQ_OR_RETURN((*data_tmp_vec_ptr)[i], in_blob_shape_vec[i]);\n        }\n      }\n    }\n  }\n\n  CHECK_NE_OR_RETURN(conf.out_box_case(), BoxingOpConf::OUT_BOX_NOT_SET);\n  if (conf.in_box_case() == BoxingOpConf::kAddBox\n      && conf.out_box_case() == BoxingOpConf::kSplitBox) {\n    if (!is_logical) {\n      BlobDesc* data_tmp_blob_desc = GetBlobDesc4BnInOp(SoleTbn());\n      data_tmp_blob_desc->set_shape(Shape(*data_tmp_vec_ptr));\n      data_tmp_blob_desc->set_data_type(GetBlobDesc4BnInOp(input_bns().Get(0))->data_type());\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BoxingOp::InferSbpSignature(\n    SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n    const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n    std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n    const ParallelDesc& parallel_desc) const {\n  auto* bn2sbp = sbp_signature->mutable_bn_in_op2sbp_parallel();\n  const SbpParallel& sbp_parallel = JUST(SbpInferHint4Ibn(input_bns().Get(0)))->sbp_parallel();\n  FOR_RANGE(int32_t, i, 0, input_bns().size()) {\n    CHECK_OR_RETURN(sbp_parallel == JUST(SbpInferHint4Ibn(input_bns().Get(i)))->sbp_parallel());\n  }\n  (*bn2sbp)[input_bns().Get(0)] = sbp_parallel;\n  (*bn2sbp)[output_bns().Get(0)] = sbp_parallel;\n  
return Maybe<void>::Ok();\n}\n\nREGISTER_OP(OperatorConf::kBoxingConf, BoxingOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/boxing_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_BOXING_OP_H_\n#define ONEFLOW_CORE_OPERATOR_BOXING_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass BoxingOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BoxingOp);\n  BoxingOp() = default;\n  ~BoxingOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n protected:\n  void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n                            const ParallelContext* parallel_ctx,\n                            KernelConf* kernel_conf) const override;\n\n  void AddLbi2OutputIndex(const LogicalBlobId& lbi, int32_t output_index) override {}\n\n private:\n  Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n                             bool is_logical) const;\n  Maybe<void> InferSbpSignature(\n      SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n      const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n      std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n      const ParallelDesc& parallel_desc) const override;\n  LogicalBlobId lbi4ibn(const std::string& input_bn) const override;\n  LogicalBlobId lbi4obn(const std::string& output_bn) const override;\n  Maybe<void> InferTmpBlobDesc(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n                               DimVector* data_tmp_vec_ptr, bool is_logical) const;\n  Symbol<OperatorConf> GetOpConfWithoutOpNameAndLbn() const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_BOXING_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/boxing_zeros_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/protobuf.h\"\n\nnamespace oneflow {\n\nclass BoxingZerosOp : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BoxingZerosOp);\n  BoxingZerosOp() = default;\n  ~BoxingZerosOp() override = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  LogicalBlobId lbi4ibn(const std::string& input_bn) const override;\n  LogicalBlobId lbi4obn(const std::string& output_bn) const override;\n};\n\nMaybe<void> BoxingZerosOp::InitFromOpConf() {\n  EnrollOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nLogicalBlobId BoxingZerosOp::lbi4ibn(const std::string& input_bn) const {\n  return this->op_conf().boxing_zeros_conf().lbi();\n}\n\nLogicalBlobId BoxingZerosOp::lbi4obn(const std::string& output_bn) const {\n  return this->op_conf().boxing_zeros_conf().lbi();\n}\n\nMaybe<void> BoxingZerosOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  const BoxingZerosOpConf& conf = this->op_conf().boxing_zeros_conf();\n  BlobDesc* out = GetBlobDesc4BnInOp(\"out\");\n  out->set_data_type(conf.data_type());\n  out->set_shape(Shape(conf.shape()));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP(OperatorConf::kBoxingZerosConf, BoxingZerosOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/broadcast_to_compatible_with_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/shape_view.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> GetBroadcastShape(const Shape& a_shape, const Shape& b_shape, Shape* broadcast_shape) {\n  Shape max_shape = Shape::Ones(std::max(a_shape.NumAxes(), b_shape.NumAxes()));\n  Shape a_extend_shape = CreateLeftExtendedShape(ShapeView(a_shape), max_shape.NumAxes());\n  Shape b_extend_shape = CreateLeftExtendedShape(ShapeView(b_shape), max_shape.NumAxes());\n  FOR_RANGE(int64_t, i, 0, max_shape.NumAxes()) {\n    CHECK_OR_RETURN(a_extend_shape.At(i) == 1 || b_extend_shape.At(i) == 1\n                    || a_extend_shape.At(i) == b_extend_shape.At(i))\n        << \"shape \" << a_shape.ToString() << \" and shape \" << b_shape.ToString()\n        << \" are not broadcastable\";\n    max_shape.Set(i, std::max(a_extend_shape.At(i), b_extend_shape.At(i)));\n  }\n  *broadcast_shape = max_shape;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferBlobDescs(const OperatorConf& op_conf,\n                           const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  int64_t num_compatibles = op_conf.broadcast_to_compatible_with_conf().compatible_size();\n  const BlobDesc* x_desc = BlobDesc4BnInOp(\"x\");\n  Shape broadcasted_shape(x_desc->shape());\n  FOR_RANGE(int64_t, i, 0, num_compatibles) {\n    const BlobDesc* compatible_i = BlobDesc4BnInOp(GenRepeatedBn(\"compatible\", i));\n    JUST(GetBroadcastShape(broadcasted_shape, compatible_i->shape(), &broadcasted_shape));\n  }\n  BlobDesc* y_desc = BlobDesc4BnInOp(\"y\");\n  y_desc->CopyFrom(*x_desc);\n  y_desc->set_shape(broadcasted_shape);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nclass BroadcastToCompatibleWithOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BroadcastToCompatibleWithOp);\n  BroadcastToCompatibleWithOp() = default;\n  ~BroadcastToCompatibleWithOp() override = default;\n\n  Maybe<void> InitFromOpConf() override {\n    CHECK(op_conf().has_broadcast_to_compatible_with_conf());\n    EnrollInputBn(\"x\");\n    EnrollRepeatedInputBn(\"compatible\", false);\n    EnrollOutputBn(\"y\");\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    return InferBlobDescs(op_conf(), BlobDesc4BnInOp);\n  }\n\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override {\n    return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp);\n  }\n\n private:\n  void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n                            const ParallelContext* parallel_ctx,\n                            KernelConf* kernel_conf) const override {\n    
auto* conf = kernel_conf->mutable_broadcast_to_compatible_with_conf();\n    const BlobDesc* x_desc = GetBlobDesc4BnInOp(\"x\");\n    const BlobDesc* y_desc = GetBlobDesc4BnInOp(\"y\");\n    Shape x_extend_shape =\n        CreateLeftExtendedShape(ShapeView(x_desc->shape()), y_desc->shape().NumAxes());\n    FOR_RANGE(int64_t, i, 0, y_desc->shape().NumAxes()) {\n      if (x_extend_shape.At(i) == 1 && y_desc->shape().At(i) != 1)\n        conf->mutable_broadcast_axes()->Add(i);\n    }\n  }\n\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override {\n    Shape broadcasted_shape{1};\n    for (const std::string& ibn : input_bns()) {\n      const Shape& input_shape = JUST(LogicalBlobDesc4Ibn(ibn)).shape();\n      JUST(GetBroadcastShape(broadcasted_shape, input_shape, &broadcasted_shape));\n    }\n\n    const int64_t broadcast_num_axes = broadcasted_shape.NumAxes();\n    HashMap<std::string, Shape> ibn2extend_shape;\n    for (const std::string& ibn : input_bns()) {\n      const Shape& input_shape = JUST(LogicalBlobDesc4Ibn(ibn)).shape();\n      CHECK_OR_RETURN(\n          ibn2extend_shape\n              .emplace(ibn, CreateLeftExtendedShape(ShapeView(input_shape), broadcast_num_axes))\n              .second);\n    }\n\n    FOR_RANGE(int64_t, i, 0, broadcast_num_axes) {\n      if (broadcasted_shape.At(i) == 1) { continue; }\n      SbpSignature sbp_sig;\n      for (const auto& pair : ibn2extend_shape) {\n        if (pair.second.At(i) == 1) {\n          (*sbp_sig.mutable_bn_in_op2sbp_parallel())[pair.first].mutable_broadcast_parallel();\n        } else {\n          (*sbp_sig.mutable_bn_in_op2sbp_parallel())[pair.first].mutable_split_parallel()->set_axis(\n              i - (broadcast_num_axes - pair.second.NumAxes()));\n        }\n      }\n      (*sbp_sig.mutable_bn_in_op2sbp_parallel())[\"y\"].mutable_split_parallel()->set_axis(i);\n      *sbp_sig_list->mutable_sbp_signature()->Add() = sbp_sig;\n    }\n\n    PbRpf<std::string> compatible_bns;\n    int64_t num_compatibles = op_conf().broadcast_to_compatible_with_conf().compatible_size();\n    FOR_RANGE(int64_t, i, 0, num_compatibles) {\n      *compatible_bns.Add() = GenRepeatedBn(\"compatible\", i);\n    }\n    SbpSignatureBuilder()\n        .PartialSum(\"x\")\n        .Broadcast(compatible_bns)\n        .PartialSum(\"y\")\n        .Build(sbp_sig_list->mutable_sbp_signature()->Add());\n    SbpSignatureBuilder()\n        .Broadcast(\"x\")\n        .PartialSum(compatible_bns)\n        .Broadcast(\"y\")\n        .Build(sbp_sig_list->mutable_sbp_signature()->Add());\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP(OperatorConf::kBroadcastToCompatibleWithConf, BroadcastToCompatibleWithOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/callback_notify_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/callback_notify_op.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nMaybe<void> CallbackNotifyOp::InitFromOpConf() {\n  CHECK(op_conf().has_callback_notify_conf());\n  EnrollInputBn(\"in\", false);\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  CHECK_OR_RETURN(BlobDesc4BnInOp(\"in\")->shape() == Shape({1}));\n  CHECK_OR_RETURN(IsIntegralDataType(BlobDesc4BnInOp(\"in\")->data_type()));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> CallbackNotifyOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1);\n  return InferBlobDescs(BlobDesc4BnInOp);\n}\n\nMaybe<void> CallbackNotifyOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1);\n  return InferBlobDescs(GetBlobDesc4BnInOp);\n}\n\nMaybe<void> CallbackNotifyOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {\n  return Maybe<void>::Ok();\n}\n\nREGISTER_CPU_OP(OperatorConf::kCallbackNotifyConf, CallbackNotifyOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/callback_notify_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_CALLBACK_NOTIFY_OP_H_\n#define ONEFLOW_CORE_OPERATOR_CALLBACK_NOTIFY_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass CallbackNotifyOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CallbackNotifyOp);\n  CallbackNotifyOp() = default;\n  ~CallbackNotifyOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_CALLBACK_NOTIFY_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/case_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/case_op.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nMaybe<void> CaseOp::InitFromOpConf() {\n  EnrollInputBn(\"in\", false);\n  EnrollRepeatedOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const Operator& op,\n                           const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  const BlobDesc* in = BlobDesc4BnInOp(\"in\");\n  CHECK_EQ_OR_RETURN(in->shape().elem_cnt(), 1);\n  const DataType data_type = in->data_type();\n  CHECK_OR_RETURN(IsIntegralDataType(data_type));\n  for (const std::string& obn : op.output_bns()) {\n    BlobDesc* out = BlobDesc4BnInOp(obn);\n    out->set_shape(Shape({1}));\n    out->set_data_type(data_type);\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\nMaybe<void> CaseOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  return InferBlobDescs(*this, BlobDesc4BnInOp);\n}\n\nMaybe<void> CaseOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  return InferBlobDescs(*this, GetBlobDesc4BnInOp);\n}\n\nMaybe<void> CaseOp::GetSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    SbpSignatureList* sbp_sig_list) const {\n  return Maybe<void>::Ok();\n}\n\nREGISTER_CPU_OP(OperatorConf::kCaseConf, CaseOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/case_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_CASE_OP_H_\n#define ONEFLOW_CORE_OPERATOR_CASE_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass CaseOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CaseOp);\n  CaseOp() = default;\n  ~CaseOp() override = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_CASE_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/collective_boxing_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/graph/boxing/collective_boxing_util.h\"\n\nnamespace oneflow {\n\nusing namespace boxing::collective;\n\nclass CollectiveBoxingGenericOp : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingGenericOp);\n  CollectiveBoxingGenericOp() = default;\n  ~CollectiveBoxingGenericOp() override = default;\n\n private:\n  Maybe<void> InitFromOpConf() override {\n    CHECK(op_conf().has_collective_boxing_generic_conf());\n    const RankDesc& rank_desc = op_conf().collective_boxing_generic_conf().rank_desc();\n    if (GenericOpHasInput(rank_desc)) { EnrollInputBn(\"in\", false); }\n    if (GenericOpHasOutput(rank_desc)) { EnrollOutputBn(\"out\", false); }\n    return Maybe<void>::Ok();\n  }\n\n  LogicalBlobId lbi4ibn(const std::string& input_bn) const override {\n    return this->op_conf().collective_boxing_generic_conf().lbi();\n  }\n\n  LogicalBlobId lbi4obn(const std::string& output_bn) const override {\n    return this->op_conf().collective_boxing_generic_conf().lbi();\n  }\n\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override {\n    const RankDesc& rank_desc = op_conf().collective_boxing_generic_conf().rank_desc();\n    const DataType data_type = rank_desc.op_desc().data_type();\n    if (GenericOpHasInput(rank_desc)) {\n      const BlobDesc* in = GetBlobDesc4BnInOp(\"in\");\n      CHECK_OR_RETURN(!in->is_dynamic());\n      CHECK_EQ_OR_RETURN(in->data_type(), data_type);\n      CHECK_EQ_OR_RETURN(in->shape(), GenericOpGetInputShape(rank_desc));\n    }\n    if (GenericOpHasOutput(rank_desc)) {\n      BlobDesc* out = GetBlobDesc4BnInOp(\"out\");\n      out->set_data_type(data_type);\n      out->set_shape(GenericOpGetOutputShape(rank_desc));\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP(OperatorConf::kCollectiveBoxingGenericConf, CollectiveBoxingGenericOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/collective_boxing_pack_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/register/tensor_slice_view.h\"\n\nnamespace oneflow {\n\nclass CollectiveBoxingPackOp : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingPackOp);\n  CollectiveBoxingPackOp() = default;\n  ~CollectiveBoxingPackOp() override = default;\n\n  Maybe<void> InitFromOpConf() override;\n\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  LogicalBlobId lbi4ibn(const std::string& input_bn) const override;\n  LogicalBlobId lbi4obn(const std::string& output_bn) const override;\n};\n\nMaybe<void> CollectiveBoxingPackOp::InitFromOpConf() {\n  EnrollInputBn(\"in\", false);\n  EnrollOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nLogicalBlobId CollectiveBoxingPackOp::lbi4ibn(const std::string& input_bn) const {\n  return this->op_conf().collective_boxing_pack_conf().lbi();\n}\n\nLogicalBlobId CollectiveBoxingPackOp::lbi4obn(const std::string& output_bn) const {\n  return this->op_conf().collective_boxing_pack_conf().lbi();\n}\n\nMaybe<void> CollectiveBoxingPackOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp(\"in\");\n  BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(\"out\");\n  *CHECK_NOTNULL(out_blob_desc) = *CHECK_NOTNULL(in_blob_desc);  // NOLINT\n  out_blob_desc->set_shape(Shape({in_blob_desc->shape().elem_cnt()}));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP(OperatorConf::kCollectiveBoxingPackConf, CollectiveBoxingPackOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/collective_boxing_unpack_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/register/tensor_slice_view.h\"\n\nnamespace oneflow {\n\nclass CollectiveBoxingUnpackOp : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingUnpackOp);\n  CollectiveBoxingUnpackOp() = default;\n  ~CollectiveBoxingUnpackOp() override = default;\n\n  Maybe<void> InitFromOpConf() override;\n\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  LogicalBlobId lbi4ibn(const std::string& input_bn) const override;\n  LogicalBlobId lbi4obn(const std::string& output_bn) const override;\n};\n\nMaybe<void> CollectiveBoxingUnpackOp::InitFromOpConf() {\n  EnrollInputBn(\"in\", false);\n  EnrollOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nLogicalBlobId CollectiveBoxingUnpackOp::lbi4ibn(const std::string& input_bn) const {\n  return this->op_conf().collective_boxing_unpack_conf().lbi();\n}\n\nLogicalBlobId CollectiveBoxingUnpackOp::lbi4obn(const std::string& output_bn) const {\n  return this->op_conf().collective_boxing_unpack_conf().lbi();\n}\n\nMaybe<void> CollectiveBoxingUnpackOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  const CollectiveBoxingUnpackOpConf& unpack_conf = this->op_conf().collective_boxing_unpack_conf();\n  const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp(\"in\");\n  BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(\"out\");\n  *out_blob_desc = *in_blob_desc;\n\n  Shape out_shape(unpack_conf.logical_shape());\n  if (unpack_conf.dst_sbp_parallel().has_split_parallel()) {\n    const int64_t dst_split_axis = unpack_conf.dst_sbp_parallel().split_parallel().axis();\n    out_shape.Set(dst_split_axis, out_shape.At(dst_split_axis) / unpack_conf.num_ranks());\n  }\n  CHECK_EQ_OR_RETURN(out_shape.elem_cnt(), in_blob_desc->shape().elem_cnt());\n  out_blob_desc->set_shape(out_shape);\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP(OperatorConf::kCollectiveBoxingUnpackConf, CollectiveBoxingUnpackOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/constant_like_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const OperatorConf& op_conf,\n                           const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  const ConstantLikeOpConf& conf = op_conf.constant_like_conf();\n  BlobDesc* out_blob_desc = BlobDesc4BnInOp(\"out\");\n  *out_blob_desc = *BlobDesc4BnInOp(\"like\");\n  if (conf.has_data_type()) { out_blob_desc->set_data_type(conf.data_type()); }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nclass ConstantLikeOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ConstantLikeOp);\n  ConstantLikeOp() = default;\n  ~ConstantLikeOp() = default;\n\n  Maybe<void> InitFromOpConf() override {\n    CHECK(op_conf().has_constant_like_conf());\n    EnrollInputBn(\"like\", false);\n    EnrollOutputBn(\"out\", false);\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    return InferBlobDescs(op_conf(), BlobDesc4BnInOp);\n  }\n\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override {\n    return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp);\n  }\n\n private:\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override {\n    SbpSignatureBuilder()\n        .Split(\"like\", 0)\n        .Split(\"out\", 0)\n        .MakeSplitSignatureListBuilder(JUST(LogicalBlobDesc4Ibn(\"like\")).shape().NumAxes())\n        .Build(sbp_sig_list);\n    SbpSignatureBuilder().PartialSum(\"like\").Broadcast(\"out\").Build(\n        sbp_sig_list->mutable_sbp_signature()->Add());\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP(OperatorConf::kConstantLikeConf, ConstantLikeOp);\nREGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kConstantLikeConf, 1);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/copy_comm_net_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/copy_comm_net_op.h\"\n\nnamespace oneflow {\n\nMaybe<void> CopyCommNetOp::InitFromOpConf() {\n  EnrollInputBn(\"in\", false);\n  EnrollOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nLogicalBlobId CopyCommNetOp::lbi4obn(const std::string& output_bn) const {\n  return this->op_conf().copy_comm_net_conf().lbi();\n}\n\nLogicalBlobId CopyCommNetOp::lbi4ibn(const std::string& input_bn) const {\n  return this->op_conf().copy_comm_net_conf().lbi();\n}\n\nREGISTER_OP(OperatorConf::kCopyCommNetConf, CopyCommNetOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/copy_comm_net_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_COPY_COMM_NET_OP_H_\n#define ONEFLOW_CORE_OPERATOR_COPY_COMM_NET_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass CopyCommNetOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CopyCommNetOp);\n  CopyCommNetOp() = default;\n  ~CopyCommNetOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n\n private:\n  LogicalBlobId lbi4ibn(const std::string& input_bn) const override;\n  LogicalBlobId lbi4obn(const std::string& output_bn) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_COPY_COMM_NET_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/critical_section_callback_tick_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  BlobDesc* blob_desc = BlobDesc4BnInOp(\"out\");\n  blob_desc->set_shape(Shape({1}));\n  blob_desc->set_data_type(DataType::kInt8);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nclass CriticalSectionCallbackTickOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CriticalSectionCallbackTickOp);\n  CriticalSectionCallbackTickOp() = default;\n  ~CriticalSectionCallbackTickOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override;\n};\n\nMaybe<void> CriticalSectionCallbackTickOp::InitFromOpConf() {\n  CHECK(op_conf().has_critical_section_callback_tick_conf());\n  EnrollRepeatedInputBn(\"tick\", false);\n  EnrollOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CriticalSectionCallbackTickOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  return InferBlobDescs(BlobDesc4BnInOp);\n}\n\nMaybe<void> CriticalSectionCallbackTickOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  return InferBlobDescs(GetBlobDesc4BnInOp);\n}\n\nMaybe<void> CriticalSectionCallbackTickOp::GetSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    SbpSignatureList* sbp_sig_list) const {\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kCriticalSectionCallbackTickConf, 128);\nREGISTER_OP(OperatorConf::kCriticalSectionCallbackTickConf, CriticalSectionCallbackTickOp);\nREGISTER_TICK_TOCK_OP(OperatorConf::kCriticalSectionCallbackTickConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/critical_section_wait_tick_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  BlobDesc* blob_desc = BlobDesc4BnInOp(\"out\");\n  blob_desc->set_shape(Shape({1}));\n  blob_desc->set_data_type(DataType::kInt8);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nclass CriticalSectionWaitTickOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CriticalSectionWaitTickOp);\n  CriticalSectionWaitTickOp() = default;\n  ~CriticalSectionWaitTickOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override;\n};\n\nMaybe<void> CriticalSectionWaitTickOp::InitFromOpConf() {\n  CHECK_OR_RETURN(op_conf().has_critical_section_wait_tick_conf());\n  EnrollRepeatedInputBn(\"tick\", false);\n  EnrollOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CriticalSectionWaitTickOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  return InferBlobDescs(BlobDesc4BnInOp);\n}\n\nMaybe<void> CriticalSectionWaitTickOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  return InferBlobDescs(GetBlobDesc4BnInOp);\n}\n\nMaybe<void> CriticalSectionWaitTickOp::GetSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    SbpSignatureList* sbp_sig_list) const {\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kCriticalSectionWaitTickConf, 2);\nREGISTER_OP(OperatorConf::kCriticalSectionWaitTickConf, CriticalSectionWaitTickOp);\nREGISTER_TICK_TOCK_OP(OperatorConf::kCriticalSectionWaitTickConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/cwise_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/cwise_op.h\"\n\nnamespace oneflow {\n\nMaybe<void> CWiseOp::InitFromOpConf() {\n  EnrollRepeatedInputBn(\"in\");\n  EnrollOutputBn(\"out\")->set_mutable_inplace_ibn(\"in_0\");\n  VirtualInitFromOpConf();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CWiseOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  const BlobDesc* in_0_blob_desc = BlobDesc4BnInOp(input_bns().Get(0));\n  for (size_t i = 1; i < input_bns().size(); ++i) {\n    const auto* blob_desc = BlobDesc4BnInOp(input_bns().Get(i));\n    CHECK_OR_RETURN(*in_0_blob_desc == *blob_desc);\n  }\n  *BlobDesc4BnInOp(\"out\") = *in_0_blob_desc;\n  return VirtualInferBlobDescs(BlobDesc4BnInOp, nullptr);\n}\n\nMaybe<void> CWiseOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  const BlobDesc* in_0_blob_desc = GetBlobDesc4BnInOp(input_bns().Get(0));\n  for (size_t i = 1; i < input_bns().size(); ++i) {\n    const auto* blob_desc = GetBlobDesc4BnInOp(input_bns().Get(i));\n    CHECK_OR_RETURN(*in_0_blob_desc == *blob_desc);\n  }\n  *GetBlobDesc4BnInOp(\"out\") = *in_0_blob_desc;\n  return VirtualInferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/cwise_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_CWISE_OP_H_\n#define ONEFLOW_CORE_OPERATOR_CWISE_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass CWiseOp : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CWiseOp);\n  CWiseOp() = default;\n  virtual ~CWiseOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n protected:\n  virtual void VirtualInitFromOpConf() { UNIMPLEMENTED(); }\n\n  virtual Maybe<void> VirtualInferBlobDescs(\n      std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const {\n    return Maybe<void>::Ok();\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_CWISE_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/decode_random_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_DECODE_RANDOM_OP_H_\n#define ONEFLOW_CORE_OPERATOR_DECODE_RANDOM_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass DecodeRandomOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DecodeRandomOp);\n  DecodeRandomOp() = default;\n  ~DecodeRandomOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override;\n  void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n                            const ParallelContext* parallel_ctx,\n                            KernelConf* kernel_conf) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_DECODE_RANDOM_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/device_tick_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/device_tick_op.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nMaybe<void> DeviceTickOp::InitFromOpConf() {\n  CHECK(op_conf().has_device_tick_conf());\n  EnrollRepeatedInputBn(\"tick\", false);\n  EnrollOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  BlobDesc* blob_desc = BlobDesc4BnInOp(\"out\");\n  blob_desc->set_shape(Shape({1}));\n  blob_desc->set_data_type(DataType::kInt8);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> DeviceTickOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  return InferBlobDescs(BlobDesc4BnInOp);\n}\n\nMaybe<void> DeviceTickOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  return InferBlobDescs(GetBlobDesc4BnInOp);\n}\n\nMaybe<void> DeviceTickOp::GetSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    SbpSignatureList* sbp_sig_list) const {\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DeviceTickOp::InferOpTimeShape(\n    const std::function<Maybe<const Shape>(const std::string&)>& GetTimeShape4BnInOp,\n    std::shared_ptr<const Shape>* time_shape) const {\n  std::shared_ptr<const Shape> in_time_shape;\n  for (const auto& bn : input_bns()) {\n    std::shared_ptr<const Shape> ts = JUST(GetTimeShape4BnInOp(bn));\n    if (!in_time_shape) {\n      in_time_shape = ts;\n    } else {\n      CHECK_OR_RETURN(*in_time_shape == *ts);\n    }\n  }\n  if (this->op_conf().device_tick_conf().has_time_shape()) {\n    if (!in_time_shape) {\n      in_time_shape.reset(new Shape(this->op_conf().device_tick_conf().time_shape()));\n    } else {\n      CHECK_OR_RETURN(in_time_shape->elem_cnt()\n                      == Shape(this->op_conf().device_tick_conf().time_shape()).elem_cnt());\n    }\n  }\n  if (in_time_shape) {\n    *time_shape = in_time_shape;\n  } else {\n    *time_shape = std::make_shared<const Shape>(Shape({1, 1}));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP(OperatorConf::kDeviceTickConf, DeviceTickOp);\nREGISTER_TICK_TOCK_OP(OperatorConf::kDeviceTickConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/device_tick_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_DEVICE_TICK_OP_H_\n#define ONEFLOW_CORE_OPERATOR_DEVICE_TICK_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass DeviceTickOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DeviceTickOp);\n  DeviceTickOp() = default;\n  ~DeviceTickOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n  Maybe<void> InferOpTimeShape(\n      const std::function<Maybe<const Shape>(const std::string&)>& GetTimeShape4BnInOp,\n      std::shared_ptr<const Shape>* time_shape) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_DEVICE_TICK_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/distribute_add_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/job/scope.h\"\n\nnamespace oneflow {\n\nclass DistributeAddOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DistributeAddOp);\n  DistributeAddOp() = default;\n  ~DistributeAddOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> InferBlobParallelDesc() override;\n  Maybe<void> InferSbpSignature(\n      SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n      const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n      std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n      const ParallelDesc& parallel_desc) const override;\n};\n\nMaybe<void> DistributeAddOp::InitFromOpConf() {\n  CHECK(op_conf().has_distribute_add_conf());\n\n  EnrollRepeatedInputBn(\"in\");\n  EnrollOutputBn(\"out\");\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeAddOp::InferBlobParallelDesc() {\n  HashMap<std::string, std::shared_ptr<const ParallelDesc>> bn2parallel_desc;\n  const std::shared_ptr<const ParallelDesc> op_parallel_desc = JUST(GetOpParallelDesc());\n  FOR_RANGE(int, i, 0, input_bns().size()) {\n    bn2parallel_desc[input_bns().Get(i)] =\n        std::make_shared<const ParallelDesc>(op_parallel_desc->GetParallelIdOnlyParallelConf(i));\n  }\n  bn2parallel_desc[\"out\"] = op_parallel_desc;\n  JUST(FillBlobParallelDesc([&](const std::string& bn) -> Maybe<const ParallelDesc> {\n    auto it = bn2parallel_desc.find(bn);\n    CHECK_OR_RETURN(it != bn2parallel_desc.end());\n    return it->second;\n  }));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeAddOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  const BlobDesc* in_0 = BlobDesc4BnInOp(input_bns().Get(0));\n  FOR_RANGE(int, i, 1, output_bns().size()) {\n    const BlobDesc* in_i = BlobDesc4BnInOp(input_bns().Get(i));\n    CHECK_OR_RETURN(*in_i == *in_0);\n  }\n  *BlobDesc4BnInOp(\"out\") = *in_0;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeAddOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  const BlobDesc* first_blob_desc = nullptr;\n  FOR_RANGE(int, i, 0, input_bns().size()) {\n    first_blob_desc = GetBlobDesc4BnInOp(input_bns().Get(i));\n    if (first_blob_desc != nullptr) { break; 
}\n  }\n  CHECK_NOTNULL(first_blob_desc);\n  *GetBlobDesc4BnInOp(\"out\") = *first_blob_desc;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeAddOp::InferSbpSignature(\n    SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n    const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n    std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n    const ParallelDesc& parallel_desc) const {\n  CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), input_bns().size());\n  const auto& first_in_hint = *JUST(SbpInferHint4Ibn(input_bns().Get(0)));\n  FOR_RANGE(int, i, 0, input_bns().size()) {\n    const auto& in_sbp_infer_hint = *JUST(SbpInferHint4Ibn(input_bns().Get(i)));\n    CHECK_EQ_OR_RETURN(1, in_sbp_infer_hint.parallel_desc().parallel_num());\n    CHECK_EQ_OR_RETURN(first_in_hint.logical_blob_desc().shape(),\n                       in_sbp_infer_hint.logical_blob_desc().shape());\n  }\n  auto* bn2sbp = sbp_signature->mutable_bn_in_op2sbp_parallel();\n  for (const auto& ibn : input_bns()) { (*bn2sbp)[ibn].mutable_partial_sum_parallel(); }\n  (*bn2sbp)[\"out\"].mutable_partial_sum_parallel();\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP(OperatorConf::kDistributeAddConf, DistributeAddOp);\nREGISTER_DISABLE_INPUT_BOXING_GROUP(OperatorConf::kDistributeAddConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/distribute_clone_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/job/scope.h\"\n\nnamespace oneflow {\n\nclass DistributeCloneOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DistributeCloneOp);\n  DistributeCloneOp() = default;\n  ~DistributeCloneOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n\n private:\n  Maybe<void> InferBlobParallelDesc() override;\n  Maybe<void> InferSbpSignature(\n      SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n      const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n      std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n};\n\nMaybe<void> DistributeCloneOp::InitFromOpConf() {\n  CHECK(op_conf().has_distribute_clone_conf());\n\n  EnrollInputBn(\"in\");\n  EnrollRepeatedOutputBnWithSetter(\"out\", [&](OutputBlobModifier* ob_modifier) {\n    ob_modifier->set_is_mutable(op_conf().distribute_clone_conf().is_variable_ref());\n  });\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeCloneOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  const auto& in_blob_desc = *BlobDesc4BnInOp(\"in\");\n  FOR_RANGE(int, i, 0, output_bns().size()) {\n    BlobDesc* blob_desc = BlobDesc4BnInOp(output_bns().Get(i));\n    *blob_desc = in_blob_desc;\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeCloneOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  const auto& in_blob_desc = *GetBlobDesc4BnInOp(\"in\");\n  if (parallel_ctx->parallel_num() > 1) {\n    CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), output_bns().size());\n    auto* out_blob_desc = GetBlobDesc4BnInOp(output_bns().Get(parallel_ctx->parallel_id()));\n    *out_blob_desc = in_blob_desc;\n    return Maybe<void>::Ok();\n  }\n  FOR_RANGE(int, i, 0, output_bns().size()) {\n    BlobDesc* blob_desc = GetBlobDesc4BnInOp(output_bns().Get(i));\n    if (blob_desc != nullptr) { *blob_desc = in_blob_desc; }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeCloneOp::InferBlobParallelDesc() {\n  HashMap<std::string, std::shared_ptr<const ParallelDesc>> bn2parallel_desc;\n  const std::shared_ptr<const ParallelDesc> op_parallel_desc = JUST(GetOpParallelDesc());\n  
bn2parallel_desc[\"in\"] = op_parallel_desc;\n  FOR_RANGE(int, i, 0, output_bns().size()) {\n    bn2parallel_desc[output_bns().Get(i)] =\n        std::make_shared<const ParallelDesc>(op_parallel_desc->GetParallelIdOnlyParallelConf(i));\n  }\n  JUST(FillBlobParallelDesc([&](const std::string& bn) -> Maybe<const ParallelDesc> {\n    auto it = bn2parallel_desc.find(bn);\n    CHECK_OR_RETURN(it != bn2parallel_desc.end());\n    return it->second;\n  }));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeCloneOp::InferSbpSignature(\n    SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n    const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n    std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n    const ParallelDesc& parallel_desc) const {\n  CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), output_bns().size());\n  const SbpInferHint& in_hint = *JUST(SbpInferHint4Ibn(\"in\"));\n  CHECK_OR_RETURN(in_hint.parallel_desc() == parallel_desc);\n  SbpSignatureBuilder().Broadcast(output_bns()).Build(sbp_signature);\n  auto* bn2sbp = sbp_signature->mutable_bn_in_op2sbp_parallel();\n  (*bn2sbp)[\"in\"].mutable_broadcast_parallel();\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP(OperatorConf::kDistributeCloneConf, DistributeCloneOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/distribute_concat_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/job/scope.h\"\n\nnamespace oneflow {\n\nclass DistributeConcatOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DistributeConcatOp);\n  DistributeConcatOp() = default;\n  ~DistributeConcatOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> InferBlobParallelDesc() override;\n  Maybe<void> InferSbpSignature(\n      SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n      const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n      std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n      const ParallelDesc& parallel_desc) const override;\n\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override;\n\n  int32_t FixAxis(const int32_t axis, const int64_t num_axes) const;\n};\n\nMaybe<void> DistributeConcatOp::InitFromOpConf() {\n  CHECK(op_conf().has_distribute_concat_conf());\n\n  EnrollRepeatedInputBn(\"in\");\n  EnrollOutputBn(\"out\");\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeConcatOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  const auto& conf = op_conf().distribute_concat_conf();\n  BlobDesc* out = BlobDesc4BnInOp(\"out\");\n  *out = *BlobDesc4BnInOp(input_bns().Get(0));\n  const int32_t concat_axis = FixAxis(conf.axis(), out->shape().NumAxes());\n  int64_t concat_dim_size = out->shape().At(concat_axis);\n  for (size_t i = 1; i < input_bns().size(); ++i) {\n    const BlobDesc* in_i = BlobDesc4BnInOp(input_bns().Get(i));\n    for (int64_t j = 0; j < in_i->shape().NumAxes(); ++j) {\n      if (j == concat_axis) {\n        concat_dim_size += in_i->shape().At(j);\n      } else {\n        CHECK_EQ_OR_RETURN(out->shape().At(j), in_i->shape().At(j));\n      }\n    }\n    CHECK_EQ_OR_RETURN(in_i->data_type(), out->data_type());\n  }\n  Shape output = out->shape();\n  output.Set(concat_axis, concat_dim_size);\n  out->set_shape(output);\n  out->set_is_dynamic(false);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeConcatOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  if (parallel_ctx->parallel_num() > 1) {\n   
 const auto* in_blob_desc = GetBlobDesc4BnInOp(input_bns().Get(parallel_ctx->parallel_id()));\n    BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(\"out\");\n    *out_blob_desc = *in_blob_desc;\n    out_blob_desc->set_is_dynamic(false);\n    return Maybe<void>::Ok();\n  }\n  const auto& conf = op_conf().distribute_concat_conf();\n  const BlobDesc* first_blob_desc = nullptr;\n  int first_blob_desc_idx = -1;\n  FOR_RANGE(int, i, 0, input_bns().size()) {\n    first_blob_desc = GetBlobDesc4BnInOp(input_bns().Get(i));\n    if (first_blob_desc != nullptr) {\n      first_blob_desc_idx = i;\n      break;\n    }\n  }\n  CHECK_NOTNULL(first_blob_desc);\n  DimVector out_dim_vec = first_blob_desc->shape().dim_vec();\n  int32_t concat_axis = FixAxis(conf.axis(), out_dim_vec.size());\n  for (size_t i = 0; i < input_bns().size(); ++i) {\n    const BlobDesc* in_i_blob_desc = GetBlobDesc4BnInOp(input_bns().Get(i));\n    if (in_i_blob_desc == nullptr) { continue; }\n    if (first_blob_desc_idx == i) { continue; }\n    for (int64_t j = 0; j < in_i_blob_desc->shape().NumAxes(); ++j) {\n      if (j == concat_axis) {\n        out_dim_vec[j] += in_i_blob_desc->shape().At(j);\n      } else {\n        CHECK_EQ_OR_RETURN(out_dim_vec[j], in_i_blob_desc->shape().At(j));\n      }\n    }\n    CHECK_EQ_OR_RETURN(in_i_blob_desc->data_type(), first_blob_desc->data_type());\n  }\n  BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(\"out\");\n  *out_blob_desc = *first_blob_desc;\n  out_blob_desc->set_shape(Shape(out_dim_vec));\n  out_blob_desc->set_is_dynamic(false);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeConcatOp::InferBlobParallelDesc() {\n  HashMap<std::string, std::shared_ptr<const ParallelDesc>> bn2parallel_desc;\n  const std::shared_ptr<const ParallelDesc> op_parallel_desc = JUST(GetOpParallelDesc());\n  FOR_RANGE(int, i, 0, input_bns().size()) {\n    bn2parallel_desc[input_bns().Get(i)] =\n        std::make_shared<const ParallelDesc>(op_parallel_desc->GetParallelIdOnlyParallelConf(i));\n  }\n  bn2parallel_desc[\"out\"] = op_parallel_desc;\n  JUST(FillBlobParallelDesc([&](const std::string& bn) -> Maybe<const ParallelDesc> {\n    auto it = bn2parallel_desc.find(bn);\n    CHECK_OR_RETURN(it != bn2parallel_desc.end());\n    return it->second;\n  }));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeConcatOp::InferSbpSignature(\n    SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n    const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n    std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n    const ParallelDesc& parallel_desc) const {\n  CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), input_bns().size());\n  auto LogicalBlobDesc4Ibn = [&](const std::string& ibn) -> Maybe<const BlobDesc&> {\n    const SbpInferHint* sbp_infer_hint = JUST(SbpInferHint4Ibn(ibn));\n    return Maybe<const BlobDesc&>(sbp_infer_hint->logical_blob_desc());\n  };\n  {\n    // check parallel_num and dimention\n    const auto& conf = op_conf().distribute_concat_conf();\n    const int64_t num_axes = JUST(LogicalBlobDesc4Ibn(input_bns().Get(0))).shape().NumAxes();\n    const int32_t axis = FixAxis(conf.axis(), num_axes);\n    int64_t dim = 0;\n    FOR_RANGE(int, i, 0, input_bns().size()) {\n      const auto& in_parallel_desc = JUST(SbpInferHint4Ibn(input_bns().Get(i)))->parallel_desc();\n      CHECK_EQ_OR_RETURN(1, in_parallel_desc.parallel_num());\n      dim += JUST(LogicalBlobDesc4Ibn(input_bns().Get(i))).shape().At(axis);\n    }\n    BalancedSplitter bs(dim, 
parallel_desc.parallel_num());\n    FOR_RANGE(int, i, 0, input_bns().size()) {\n      CHECK_EQ_OR_RETURN(JUST(LogicalBlobDesc4Ibn(input_bns().Get(i))).shape().At(axis),\n                         bs.At(i).size());\n    }\n  }\n  SbpSignatureList sbp_sig_list;\n  JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, &sbp_sig_list));\n  *sbp_signature = sbp_sig_list.sbp_signature().Get(0);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeConcatOp::GetSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    SbpSignatureList* sbp_sig_list) const {\n  const auto& conf = op_conf().distribute_concat_conf();\n  const int64_t num_axes = JUST(LogicalBlobDesc4Ibn(input_bns().Get(0))).shape().NumAxes();\n  const int32_t axis = FixAxis(conf.axis(), num_axes);\n  SbpSignatureBuilder()\n      .Broadcast(input_bns())\n      .Split(output_bns(), axis)\n      .Build(sbp_sig_list->mutable_sbp_signature()->Add());\n  return Maybe<void>::Ok();\n}\n\nint32_t DistributeConcatOp::FixAxis(const int32_t axis, const int64_t num_axes) const {\n  // Normalize a possibly negative axis first, then validate the normalized value.\n  int32_t ret = axis;\n  if (axis < 0) { ret += num_axes; }\n  CHECK_GE(ret, 0);\n  CHECK_LT(ret, num_axes);\n  return ret;\n}\n\nREGISTER_OP(OperatorConf::kDistributeConcatConf, DistributeConcatOp);\nREGISTER_DISABLE_INPUT_BOXING_GROUP(OperatorConf::kDistributeConcatConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/distribute_split_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/job/scope.h\"\n\nnamespace oneflow {\n\nclass DistributeSplitOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DistributeSplitOp);\n  DistributeSplitOp() = default;\n  ~DistributeSplitOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n\n private:\n  Maybe<void> InferBlobParallelDesc() override;\n  Maybe<void> InferSbpSignature(\n      SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n      const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n      std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override;\n\n  int32_t FixAxis(const int32_t axis, const int64_t num_axes) const;\n};\n\nMaybe<void> DistributeSplitOp::InitFromOpConf() {\n  CHECK(op_conf().has_distribute_split_conf());\n  EnrollInputBn(\"in\");\n  EnrollRepeatedOutputBnWithSetter(\"out\", [&](OutputBlobModifier* ob_modifier) {\n    ob_modifier->set_header_infered_before_compute(false);\n    ob_modifier->set_is_mutable(op_conf().distribute_split_conf().is_variable_ref());\n  });\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeSplitOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  const auto& in_blob_desc = *BlobDesc4BnInOp(\"in\");\n  CHECK_EQ(parallel_desc.parallel_num(), output_bns().size());\n  const auto& conf = op_conf().distribute_split_conf();\n  const int32_t split_axis = FixAxis(conf.axis(), in_blob_desc.shape().NumAxes());\n  BalancedSplitter bs(in_blob_desc.shape().At(split_axis), parallel_desc.parallel_num());\n  FOR_RANGE(int, i, 0, parallel_desc.parallel_num()) {\n    BlobDesc* out_blob_desc = BlobDesc4BnInOp(output_bns().Get(i));\n    *out_blob_desc = in_blob_desc;\n    Shape output = out_blob_desc->shape();\n    output.Set(split_axis, bs.At(i).size());\n    out_blob_desc->set_shape(output);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeSplitOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  const auto& in_blob_desc = 
*GetBlobDesc4BnInOp(\"in\");\n  if (parallel_ctx->parallel_num() > 1) {\n    CHECK_EQ(parallel_ctx->parallel_num(), output_bns().size());\n    auto* out_blob_desc = GetBlobDesc4BnInOp(output_bns().Get(parallel_ctx->parallel_id()));\n    *out_blob_desc = in_blob_desc;\n    return Maybe<void>::Ok();\n  }\n  const auto& conf = op_conf().distribute_split_conf();\n  int32_t split_axis = FixAxis(conf.axis(), in_blob_desc.shape().NumAxes());\n  std::vector<BlobDesc*> out_blob_descs;\n  out_blob_descs.reserve(output_bns().size());\n  FOR_RANGE(int, i, 0, output_bns().size()) {\n    BlobDesc* blob_desc = GetBlobDesc4BnInOp(output_bns().Get(i));\n    if (blob_desc != nullptr) { out_blob_descs.emplace_back(blob_desc); }\n  }\n  BalancedSplitter bs(in_blob_desc.shape().At(split_axis), out_blob_descs.size());\n  FOR_RANGE(int, i, 0, out_blob_descs.size()) {\n    *out_blob_descs.at(i) = in_blob_desc;\n    Shape output = out_blob_descs.at(i)->shape();  // NOLINT\n    output.Set(split_axis, bs.At(i).size());\n    out_blob_descs.at(i)->set_shape(output);  // NOLINT\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeSplitOp::InferBlobParallelDesc() {\n  HashMap<std::string, std::shared_ptr<const ParallelDesc>> bn2parallel_desc;\n  const std::shared_ptr<const ParallelDesc> op_parallel_desc = JUST(GetOpParallelDesc());\n  bn2parallel_desc[\"in\"] = op_parallel_desc;\n  FOR_RANGE(int, i, 0, output_bns().size()) {\n    bn2parallel_desc[output_bns().Get(i)] =\n        std::make_shared<const ParallelDesc>(op_parallel_desc->GetParallelIdOnlyParallelConf(i));\n  }\n  JUST(FillBlobParallelDesc([&](const std::string& bn) -> Maybe<const ParallelDesc> {\n    auto it = bn2parallel_desc.find(bn);\n    CHECK_OR_RETURN(it != bn2parallel_desc.end());\n    return it->second;\n  }));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeSplitOp::InferSbpSignature(\n    SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n    const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n    std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n    const ParallelDesc& parallel_desc) const {\n  CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), output_bns().size());\n  auto LogicalBlobDesc4Ibn = [&](const std::string& ibn) -> Maybe<const BlobDesc&> {\n    const SbpInferHint* sbp_infer_hint = JUST(SbpInferHint4Ibn(ibn));\n    return Maybe<const BlobDesc&>(sbp_infer_hint->logical_blob_desc());\n  };\n  SbpSignatureList sbp_sig_list;\n  JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, &sbp_sig_list));\n  *sbp_signature = sbp_sig_list.sbp_signature().Get(0);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DistributeSplitOp::GetSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    SbpSignatureList* sbp_sig_list) const {\n  const auto& conf = op_conf().distribute_split_conf();\n  const int64_t num_axes = JUST(LogicalBlobDesc4Ibn(\"in\")).shape().NumAxes();\n  const int32_t axis = FixAxis(conf.axis(), num_axes);\n  SbpSignatureBuilder()\n      .Split(input_bns(), axis)\n      .Broadcast(output_bns())\n      .Build(sbp_sig_list->mutable_sbp_signature()->Add());\n  return Maybe<void>::Ok();\n}\n\nint32_t DistributeSplitOp::FixAxis(const int32_t axis, const int64_t num_axes) const {\n  // Normalize a possibly negative axis first, then validate the normalized value.\n  int32_t ret = axis;\n  if (axis < 0) { ret += num_axes; }\n  CHECK_GE(ret, 0);\n  CHECK_LT(ret, num_axes);\n  return ret;\n}\n\nREGISTER_OP(OperatorConf::kDistributeSplitConf, DistributeSplitOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/dst_subset_tick_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  BlobDesc* blob_desc = BlobDesc4BnInOp(\"out\");\n  blob_desc->set_shape(Shape({1}));\n  blob_desc->set_data_type(DataType::kInt8);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nclass DstSubsetTickOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DstSubsetTickOp);\n  DstSubsetTickOp() = default;\n  ~DstSubsetTickOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override;\n};\n\nMaybe<void> DstSubsetTickOp::InitFromOpConf() {\n  CHECK(op_conf().has_dst_subset_tick_conf());\n  EnrollRepeatedInputBn(\"in\", false);\n  EnrollOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DstSubsetTickOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  return InferBlobDescs(BlobDesc4BnInOp);\n}\n\nMaybe<void> DstSubsetTickOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  return InferBlobDescs(GetBlobDesc4BnInOp);\n}\n\nMaybe<void> DstSubsetTickOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {\n  SbpSignatureBuilder()\n      .Broadcast(input_bns())\n      .Broadcast(output_bns())\n      .Build(sbp_sig_list->mutable_sbp_signature()->Add());\n  return Maybe<void>::Ok();\n}\n\nREGISTER_CPU_OP(OperatorConf::kDstSubsetTickConf, DstSubsetTickOp);\nREGISTER_TICK_TOCK_OP(OperatorConf::kDstSubsetTickConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/dynamic_reshape_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass DynamicReshapeOp final : public Operator {\n public:\n  Maybe<void> InitFromOpConf() override {\n    CHECK(op_conf().has_dynamic_reshape_conf());\n    EnrollInputBn(\"in\");\n    EnrollOutputBn(\"out\")->set_const_inplace_ibn(\"in\");\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    const DynamicReshapeOpConf& conf = op_conf().dynamic_reshape_conf();\n    const BlobDesc* in = BlobDesc4BnInOp(\"in\");\n    BlobDesc* out = BlobDesc4BnInOp(\"out\");\n    *out = *in;\n    DimVector out_dim_vec(conf.shape().dim().begin(), conf.shape().dim().end());\n    int32_t inferred_axis = -1;\n    int32_t product = 1;\n    for (int32_t i = 0; i < out_dim_vec.size(); ++i) {\n      if (out_dim_vec.at(i) == -1) {\n        CHECK_EQ_OR_RETURN(-1, inferred_axis);\n        inferred_axis = i;\n      } else {\n        CHECK_GT_OR_RETURN(out_dim_vec.at(i), 0);\n        product *= out_dim_vec.at(i);\n      }\n    }\n    if (inferred_axis >= 0) {\n      CHECK_GE_OR_RETURN(product, 1);\n      CHECK_EQ_OR_RETURN(in->shape().elem_cnt() % product, 0);\n      out_dim_vec.at(inferred_axis) = in->shape().elem_cnt() / product;\n    }\n    out->set_shape(Shape(out_dim_vec));\n    CHECK_EQ_OR_RETURN(in->shape().elem_cnt(), out->shape().elem_cnt());\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override {\n    const auto* sbp_signature = JUST(this->sbp_signature());\n    const DynamicReshapeOpConf& conf = op_conf().dynamic_reshape_conf();\n    const BlobDesc* in = GetBlobDesc4BnInOp(\"in\");\n    BlobDesc* out = GetBlobDesc4BnInOp(\"out\");\n    *out = *in;\n    DimVector out_dim_vec(conf.shape().dim().begin(), conf.shape().dim().end());\n    if (parallel_ctx->parallel_num() > 1) {\n      // global strategy\n      //   ONLY support sbp: S(0); and -1 must at axis 0\n      const auto& out_sbp_it = sbp_signature->bn_in_op2sbp_parallel().find(\"out\");\n      CHECK_OR_RETURN(out_sbp_it != sbp_signature->bn_in_op2sbp_parallel().end());\n      const SbpParallel& out_sbp = out_sbp_it->second;\n      const auto& in_sbp_it = sbp_signature->bn_in_op2sbp_parallel().find(\"in\");\n      CHECK_OR_RETURN(in_sbp_it != sbp_signature->bn_in_op2sbp_parallel().end());\n      const SbpParallel& in_sbp = in_sbp_it->second;\n      if (out_sbp.has_split_parallel()) {\n        CHECK_EQ_OR_RETURN(out_sbp.split_parallel().axis(), 0);\n        CHECK_EQ_OR_RETURN(out_dim_vec.at(0), -1);\n        CHECK_OR_RETURN(in_sbp.has_split_parallel());\n        CHECK_EQ_OR_RETURN(in_sbp.split_parallel().axis(), 0);\n      }\n    }\n    int32_t inferred_axis = 
-1;\n    int32_t product = 1;\n    for (int32_t i = 0; i < out_dim_vec.size(); ++i) {\n      if (out_dim_vec.at(i) == -1) {\n        CHECK_EQ_OR_RETURN(-1, inferred_axis);\n        inferred_axis = i;\n      } else {\n        CHECK_GT_OR_RETURN(out_dim_vec.at(i), 0);\n        product *= out_dim_vec.at(i);\n      }\n    }\n    if (inferred_axis >= 0) {\n      CHECK_GE_OR_RETURN(product, 1);\n      CHECK_EQ_OR_RETURN(in->shape().elem_cnt() % product, 0);\n      out_dim_vec.at(inferred_axis) = in->shape().elem_cnt() / product;\n    }\n    out->set_shape(Shape(out_dim_vec));\n    CHECK_EQ_OR_RETURN(in->shape().elem_cnt(), out->shape().elem_cnt());\n    return Maybe<void>::Ok();\n  }\n\n private:\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override {\n    SbpSignatureBuilder()\n        .Split(input_bns(), 0)\n        .Split(output_bns(), 0)\n        .Build(sbp_sig_list->mutable_sbp_signature()->Add());\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP(OperatorConf::kDynamicReshapeConf, DynamicReshapeOp);\n\nclass DynamicReshapeLikeOp final : public Operator {\n public:\n  Maybe<void> InitFromOpConf() override {\n    CHECK(op_conf().has_dynamic_reshape_like_conf());\n    EnrollInputBn(\"x\");\n    EnrollOutputBn(\"y\");\n    EnrollInputBn(\"like\", false);\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    CHECK_EQ_OR_RETURN(BlobDesc4BnInOp(\"x\")->shape().elem_cnt(),\n                       BlobDesc4BnInOp(\"like\")->shape().elem_cnt());\n    BlobDesc4BnInOp(\"y\")->CopyFrom(*BlobDesc4BnInOp(\"like\"));\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override {\n    CHECK_EQ_OR_RETURN(GetBlobDesc4BnInOp(\"x\")->shape().elem_cnt(),\n                       GetBlobDesc4BnInOp(\"like\")->shape().elem_cnt());\n    GetBlobDesc4BnInOp(\"y\")->CopyFrom(*GetBlobDesc4BnInOp(\"like\"));\n    return Maybe<void>::Ok();\n  }\n\n private:\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override {\n    SbpSignatureBuilder()\n        .Split(input_bns(), 0)\n        .Split(output_bns(), 0)\n        .Build(sbp_sig_list->mutable_sbp_signature()->Add());\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP(OperatorConf::kDynamicReshapeLikeConf, DynamicReshapeLikeOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/esac_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/esac_op.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nMaybe<void> EsacOp::InitFromOpConf() {\n  EnrollRepeatedInputBn(\"in\", false);\n  EnrollOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const OperatorConf& op_conf,\n                           const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  BlobDesc* out = BlobDesc4BnInOp(\"out\");\n  out->set_shape(Shape({1}));\n  const DataType data_type = op_conf.esac_conf().data_type();\n  CHECK_OR_RETURN(IsIntegralDataType(data_type));\n  out->set_data_type(data_type);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> EsacOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  return InferBlobDescs(op_conf(), BlobDesc4BnInOp);\n}\n\nMaybe<void> EsacOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp);\n}\n\nMaybe<void> EsacOp::GetSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    SbpSignatureList* sbp_sig_list) const {\n  return Maybe<void>::Ok();\n}\n\nREGISTER_CPU_OP(OperatorConf::kEsacConf, EsacOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/esac_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_ESAC_OP_H_\n#define ONEFLOW_CORE_OPERATOR_ESAC_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass EsacOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(EsacOp);\n  EsacOp() = default;\n  ~EsacOp() override = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_ESAC_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/identity_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n#include \"oneflow/core/job/local_sig_infer_hint.h\"\n#include \"oneflow/core/common/protobuf.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  *BlobDesc4BnInOp(\"out\") = *BlobDesc4BnInOp(\"in\");\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass IdentityOpTpl final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(IdentityOpTpl);\n  IdentityOpTpl() = default;\n  ~IdentityOpTpl() override = default;\n\n  Maybe<void> InitFromOpConf() override {\n    EnrollInputBn(\"in\");\n    EnrollOutputBn(\"out\")->set_const_inplace_ibn(\"in\");\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    return InferBlobDescs(BlobDesc4BnInOp);\n  }\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override {\n    return InferBlobDescs(GetBlobDesc4BnInOp);\n  }\n\n private:\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override {\n    const auto bns = StdVec2PbRpf<std::string>({\"in\", \"out\"});\n    SbpSignatureBuilder().PartialSum(bns).Build(sbp_sig_list->mutable_sbp_signature()->Add());\n    const int64_t num_axes = JUST(LogicalBlobDesc4Ibn(\"in\")).shape().NumAxes();\n    SbpSignatureBuilder().Split(bns, 0).MakeSplitSignatureListBuilder(num_axes).Build(sbp_sig_list);\n    return Maybe<void>::Ok();\n  }\n};\n\nstruct IdentityOp {};\nREGISTER_OP(OperatorConf::kIdentityConf, IdentityOpTpl<IdentityOp>);\n\nstruct CopyOp {};\nREGISTER_OP(OperatorConf::kCopyConf, IdentityOpTpl<CopyOp>);\n\nclass LocalCastOp : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LocalCastOp);\n  LocalCastOp() = default;\n  virtual ~LocalCastOp() override = default;\n\n  Maybe<void> InitFromOpConf() override {\n    EnrollInputBn(\"in\");\n    EnrollOutputBn(\"out\")->set_const_inplace_ibn(\"in\");\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    return InferBlobDescs(BlobDesc4BnInOp);\n  }\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override {\n    return InferBlobDescs(GetBlobDesc4BnInOp);\n  }\n\n private:\n};\n\nnamespace {\n\nclass CastToLocalOp : public LocalCastOp {\n public:\n  
OF_DISALLOW_COPY_AND_MOVE(CastToLocalOp);\n  CastToLocalOp() = default;\n  ~CastToLocalOp() override = default;\n\n private:\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    BlobDesc* out = BlobDesc4BnInOp(\"out\");\n    *out = *BlobDesc4BnInOp(\"in\");\n    const SbpParallel& conf_sbp = SbpParallel(op_conf().cast_to_local_conf().sbp_parallel());\n    if (conf_sbp.has_split_parallel()) {\n      const int64_t axis = conf_sbp.split_parallel().axis();\n      CHECK_GE_OR_RETURN(axis, 0);\n      CHECK_LT_OR_RETURN(axis, out->shape().NumAxes());\n      const int64_t dim_value = out->shape().At(axis);\n      const int64_t parallel_num = parallel_desc.parallel_num();\n      CHECK_EQ_OR_RETURN(dim_value % parallel_num, 0);\n      Shape output = out->shape();\n      output.Set(axis, dim_value / parallel_num);\n      out->set_shape(output);\n    }\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> InferSbpSignature(\n      SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n      const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n      std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n      const ParallelDesc& parallel_desc) const override {\n    CHECK_NE_OR_RETURN(op_conf().cast_to_local_conf().sbp_parallel().parallel_type_case(),\n                       SbpParallel::PARALLEL_TYPE_NOT_SET)\n        << \"attribute sbp_parallel not set.\";\n    const auto& ibn_hint = *JUST(SbpInferHint4Ibn(\"in\"));\n    CHECK_EQ_OR_RETURN(ibn_hint.parallel_desc().parallel_num(), parallel_desc.parallel_num());\n    auto* map = sbp_signature->mutable_bn_in_op2sbp_parallel();\n    const SbpParallel& conf_sbp = SbpParallel(op_conf().cast_to_local_conf().sbp_parallel());\n    CHECK_OR_RETURN(ibn_hint.sbp_parallel() == conf_sbp);\n    (*map)[\"in\"] = ibn_hint.sbp_parallel();\n    (*map)[\"out\"] = conf_sbp;\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> InferLocalSignature(\n      std::function<Maybe<const LocalSigInferHint*>(const std::string&)> LocalSigInferHint4Ibn,\n      bool is_local_parallel_view_conf, const ParallelDesc& parallel_desc) override {\n    const auto& in_infer_hint = *JUST(LocalSigInferHint4Ibn(\"in\"));\n    CHECK_OR_RETURN(!in_infer_hint.is_local_parallel_view())\n        << \"wrong use of CastToLocalOp: 
`in' shouldn't be a local blob\";\n    CHECK_EQ_OR_RETURN(in_infer_hint.parallel_desc().parallel_num(), parallel_desc.parallel_num());\n    MutOptLocalParallel(\"in\")->clear_local_parallel();\n    MutOptLocalParallel(\"out\")->mutable_local_parallel();\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP(OperatorConf::kCastToLocalConf, CastToLocalOp);\n\n}  // namespace\n\nnamespace {\n\nclass CastFromLocalOp : public LocalCastOp {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CastFromLocalOp);\n  CastFromLocalOp() = default;\n  ~CastFromLocalOp() override = default;\n\n private:\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    BlobDesc* out = BlobDesc4BnInOp(\"out\");\n    *out = *BlobDesc4BnInOp(\"in\");\n    const SbpParallel& conf_sbp = SbpParallel(op_conf().cast_from_local_conf().sbp_parallel());\n    if (conf_sbp.has_split_parallel()) {\n      const int64_t axis = conf_sbp.split_parallel().axis();\n      CHECK_GE_OR_RETURN(axis, 0);\n      CHECK_LT_OR_RETURN(axis, out->shape().NumAxes());\n      Shape output = out->shape();\n      output.Set(axis, out->shape().At(axis) * parallel_desc.parallel_num());\n      out->set_shape(output);\n    }\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> InferSbpSignature(\n      SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n      const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n      std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n      const ParallelDesc& parallel_desc) const override {\n    CHECK_NE_OR_RETURN(op_conf().cast_from_local_conf().sbp_parallel().parallel_type_case(),\n                       SbpParallel::PARALLEL_TYPE_NOT_SET)\n        << \"attribute sbp_parallel not set.\";\n    const auto& ibn_hint = *JUST(SbpInferHint4Ibn(\"in\"));\n    CHECK_EQ_OR_RETURN(ibn_hint.parallel_desc().parallel_num(), parallel_desc.parallel_num());\n    auto* map = sbp_signature->mutable_bn_in_op2sbp_parallel();\n    (*map)[\"in\"] = ibn_hint.sbp_parallel();\n    (*map)[\"out\"] = SbpParallel(op_conf().cast_from_local_conf().sbp_parallel());\n    return Maybe<void>::Ok();\n  }\n  Maybe<void> InferLocalSignature(\n      std::function<Maybe<const LocalSigInferHint*>(const std::string&)> LocalSigInferHint4Ibn,\n      bool is_local_parallel_view_conf, const ParallelDesc& parallel_desc) override {\n    const auto& in_infer_hint = *JUST(LocalSigInferHint4Ibn(\"in\"));\n    CHECK_OR_RETURN(in_infer_hint.is_local_parallel_view())\n        << \"wrong use of CastFromLocalOp: `in' should be a local blob\";\n    CHECK_EQ_OR_RETURN(in_infer_hint.parallel_desc().parallel_num(), parallel_desc.parallel_num());\n    MutOptLocalParallel(\"in\")->mutable_local_parallel();\n    MutOptLocalParallel(\"out\")->clear_local_parallel();\n    return Maybe<void>::Ok();\n  }\n};\n\nREGISTER_OP(OperatorConf::kCastFromLocalConf, CastFromLocalOp);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/image_decoder_random_crop_resize_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/job/scope.h\"\n\n#ifdef WITH_CUDA\n#include <cuda.h>\n#endif  // WITH_CUDA\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const OperatorConf& op_conf,\n                           const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  const ImageDecoderRandomCropResizeOpConf& conf = op_conf.image_decoder_random_crop_resize_conf();\n  const BlobDesc* in = BlobDesc4BnInOp(\"in\");\n  BlobDesc* out = BlobDesc4BnInOp(\"out\");\n  CHECK_EQ_OR_RETURN(in->data_type(), DataType::kTensorBuffer);\n  *out = *in;\n  out->set_data_type(DataType::kUInt8);\n  DimVector out_dim_vec = in->shape().dim_vec();\n  out_dim_vec.emplace_back(conf.target_height());\n  out_dim_vec.emplace_back(conf.target_width());\n  out_dim_vec.emplace_back(3);\n  out->set_shape(Shape(out_dim_vec));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nclass ImageDecoderRandomCropResizeOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ImageDecoderRandomCropResizeOp);\n  ImageDecoderRandomCropResizeOp() = default;\n  ~ImageDecoderRandomCropResizeOp() override = default;\n\n private:\n  Maybe<void> InitFromOpConf() override {\n    EnrollInputBn(\"in\", false);\n    EnrollOutputBn(\"out\", false);\n    EnrollTmpBn(\"tmp\");\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    return InferBlobDescs(this->op_conf(), BlobDesc4BnInOp);\n  }\n\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override {\n    return InferBlobDescs(this->op_conf(), GetBlobDesc4BnInOp);\n  }\n\n  Maybe<void> InferInternalBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx, const JobDesc* job_desc) const override {\n    const ImageDecoderRandomCropResizeOpConf& conf =\n        this->op_conf().image_decoder_random_crop_resize_conf();\n    BlobDesc* tmp = GetBlobDesc4BnInOp(\"tmp\");\n    tmp->set_data_type(DataType::kUInt8);\n    tmp->set_shape(Shape({conf.max_num_pixels() * 3 * conf.num_workers()}));\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override {\n    SbpSignatureBuilder()\n        .Split(\"in\", 0)\n        .Split(\"out\", 0)\n        .MakeSplitSignatureListBuilder(JUST(LogicalBlobDesc4Ibn(\"in\")).shape().NumAxes())\n        .Build(sbp_sig_list);\n    return Maybe<void>::Ok();\n  }\n\n  void VirtualGenKernelConf(std::function<const BlobDesc*(const 
std::string&)> GetBlobDesc4BnInOp,\n                            const ParallelContext* parallel_ctx,\n                            KernelConf* kernel_conf) const override {\n    const ImageDecoderRandomCropResizeOpConf& conf =\n        this->op_conf().image_decoder_random_crop_resize_conf();\n    int64_t seed;\n    if (conf.has_seed()) {\n      seed = conf.seed();\n    } else {\n      std::random_device rd;\n      seed = rd();\n    }\n    std::seed_seq seq{seed};\n    std::vector<int64_t> seeds(parallel_ctx->parallel_num());\n    seq.generate(seeds.begin(), seeds.end());\n    kernel_conf->mutable_image_decoder_random_crop_resize_conf()->set_seed(\n        seeds.at(parallel_ctx->parallel_id()));\n    kernel_conf->mutable_image_decoder_random_crop_resize_conf()->set_batch_size(\n        GetBlobDesc4BnInOp(\"in\")->shape().elem_cnt());\n  }\n\n  Maybe<void> InferBlobParallelDesc() override {\n    HashMap<std::string, std::shared_ptr<const ParallelDesc>> bn2parallel_desc;\n    const std::shared_ptr<const ParallelDesc> op_parallel_desc = JUST(GetOpParallelDesc());\n    bn2parallel_desc[\"out\"] = op_parallel_desc;\n    if (device_type() == DeviceType::kCPU) {\n      bn2parallel_desc[\"in\"] = op_parallel_desc;\n    } else if (device_type() == DeviceType::kCUDA) {\n      std::shared_ptr<ParallelDesc> in_parallel_desc =\n          std::make_shared<ParallelDesc>(*op_parallel_desc);\n      in_parallel_desc->set_device_type(DeviceType::kCPU);\n      bn2parallel_desc[\"in\"] = in_parallel_desc;\n    } else {\n      UNIMPLEMENTED_THEN_RETURN();\n    }\n    JUST(FillBlobParallelDesc([&](const std::string& bn) -> Maybe<const ParallelDesc> {\n      auto it = bn2parallel_desc.find(bn);\n      CHECK_OR_RETURN(it != bn2parallel_desc.end());\n      return it->second;\n    }));\n    return Maybe<void>::Ok();\n  }\n};\n\n#if defined(WITH_CUDA) && CUDA_VERSION >= 10020\nREGISTER_OP(OperatorConf::kImageDecoderRandomCropResizeConf, ImageDecoderRandomCropResizeOp);\n#else\nREGISTER_CPU_OP(OperatorConf::kImageDecoderRandomCropResizeConf, ImageDecoderRandomCropResizeOp);\n#endif\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/input_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/operator/input_op.h\"\n#include \"oneflow/core/operator/interface_op_util.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nnamespace {\nMaybe<void> InferInputOpNdSbpSignature(NdSbpSignature* nd_sbp_signature,\n                                       const ParallelDesc& parallel_desc,\n                                       const OperatorConf& op_conf) {\n  const auto& parallel_hierarchy = parallel_desc.hierarchy();\n  const InterfaceBlobConf& blob_conf = op_conf.input_conf().blob_conf();\n  if (op_conf.input_conf().has_tick()) {\n    NdSbp& tick_nd_sbp = (*nd_sbp_signature->mutable_bn_in_op2nd_sbp())[\"tick\"];\n    tick_nd_sbp.clear_sbp_parallel();\n    FOR_RANGE(int64_t, i, 0, parallel_hierarchy->NumAxes()) {\n      tick_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();\n    }\n  }\n  NdSbp& out_nd_sbp = (*nd_sbp_signature->mutable_bn_in_op2nd_sbp())[\"out\"];\n  JUST(InterfaceOpUtil::ParseNdSbpFromBlobConf(blob_conf, parallel_desc, &out_nd_sbp));\n  return Maybe<void>::Ok();\n}\n}  // namespace\n\nMaybe<void> InputOp::InitFromOpConf() {\n  CHECK(op_conf().has_input_conf());\n  if (op_conf().input_conf().has_tick()) { EnrollInputBn(\"tick\", false); }\n  OutputBlobModifier* modifier = EnrollOutputBn(\"out\", false);\n  modifier->set_is_mutable(true);\n  modifier->set_header_infered_before_compute(false);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InputOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  BlobDesc* out_blob_desc = BlobDesc4BnInOp(\"out\");\n  JUST(InterfaceOpUtil::InferLogicalOutBlobDesc(op_conf().input_conf().blob_conf(), out_blob_desc,\n                                                parallel_desc));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InputOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(\"out\");\n  JUST(InterfaceOpUtil::InferOutBlobDesc(op_conf().input_conf().blob_conf(), out_blob_desc,\n                                         parallel_ctx, *JUST(GetOpParallelDesc())));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InputOp::InferSbpSignature(\n    SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n    const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n    std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n    const ParallelDesc& parallel_desc) const {\n  JUST(InterfaceOpUtil::GetInputLikeOpSbpSignature(op_conf().input_conf().blob_conf(), input_bns(),\n                                                   output_bns(), sbp_signature));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> 
InputOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {\n  JUST(InterfaceOpUtil::GetInputLikeOpSbpSignature(op_conf().input_conf().blob_conf(), input_bns(),\n                                                   output_bns(),\n                                                   sbp_sig_list->mutable_sbp_signature()->Add()));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InputOp::GetNdSbpSignatureList(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {\n  NdSbpSignature nd_sbp_signature;\n  JUST(InferInputOpNdSbpSignature(&nd_sbp_signature, parallel_desc, op_conf()));\n  nd_sbp_sig_list->emplace_back(nd_sbp_signature);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InputOp::InferNdSbpSignature(\n    NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints,\n    const ParallelDesc& parallel_desc,\n    std::function<Maybe<const NdSbpInferHint*>(const std::string&)> NdSbpInferHint4Ibn) const {\n  JUST(InferInputOpNdSbpSignature(nd_sbp_signature, parallel_desc, op_conf()));\n  return Maybe<void>::Ok();\n}\n\nSymbol<OperatorConf> InputOp::GetOpConfWithoutOpNameAndLbn() const {\n  return SymbolOf(this->op_conf());\n}\n\nREGISTER_OP(OperatorConf::kInputConf, InputOp);\nREGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kInputConf, 1);\nREGISTER_INTERFACE_OP(OperatorConf::kInputConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/input_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_INPUT_OP_H_\n#define ONEFLOW_CORE_OPERATOR_INPUT_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass InputOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(InputOp);\n  InputOp() : Operator() {}\n  ~InputOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferSbpSignature(\n      SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n      const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n      std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n      const ParallelDesc& parallel_desc) const override;\n\n  Maybe<void> GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override;\n  Maybe<void> GetNdSbpSignatureList(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      const ParallelDesc& parallel_desc,\n      std::vector<NdSbpSignature>* nd_sbp_sig_list) const override;\n  Symbol<OperatorConf> GetOpConfWithoutOpNameAndLbn() const override;\n  Maybe<void> InferNdSbpSignature(NdSbpSignature* nd_sbp_signature,\n                                  const NdSbpSignature& nd_sbp_constraints,\n                                  const ParallelDesc& parallel_desc,\n                                  std::function<Maybe<const NdSbpInferHint*>(const std::string&)>\n                                      NdSbpInferHint4Ibn) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_INPUT_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/interface_blob_conf.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/common/shape.proto\";\nimport \"oneflow/core/common/data_type.proto\";\nimport \"oneflow/core/job/sbp_parallel.proto\";\n\nmessage InterfaceBlobConf {\n  optional ShapeProto shape = 1;\n  optional DataType data_type = 2;\n  optional bool is_dynamic = 3;\n  optional NdSbp nd_sbp = 4;\n}\n"
  },
  {
    "path": "oneflow/core/operator/interface_op_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/interface_op_util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nvoid CheckShape(const Shape& shape) {\n  FOR_RANGE(int, i, 1, shape.NumAxes()) { CHECK_GE(shape.At(i), 0); }\n}\n\nMaybe<void> GetSbpSignature(const InterfaceBlobConf& blob_conf, const PbRpf<std::string>& input_bns,\n                            const PbRpf<std::string>& output_bns, SbpSignature* sbp_signature,\n                            bool is_for_input_op) {\n  if (!blob_conf.has_nd_sbp()) {\n    SbpSignatureBuilder().Broadcast(input_bns).Broadcast(output_bns).Build(sbp_signature);\n    return Maybe<void>::Ok();\n  }\n  CHECK_EQ_OR_RETURN(blob_conf.nd_sbp().sbp_parallel_size(), 1);\n  const auto& sbp_parallel = blob_conf.nd_sbp().sbp_parallel(0);\n  if (sbp_parallel.has_split_parallel()) {\n    int64_t num_axes = blob_conf.shape().dim_size();\n    int64_t split_axis = sbp_parallel.split_parallel().axis();\n    CHECK_GE_OR_RETURN(split_axis, 0);\n    CHECK_LT_OR_RETURN(split_axis, num_axes);\n    SbpSignatureBuilder sbp_signature_builder;\n    if (is_for_input_op) {\n      // broadcast tick args for InputOp\n      sbp_signature_builder.Broadcast(input_bns);\n    } else {\n      sbp_signature_builder.Split(input_bns, split_axis);\n    }\n    sbp_signature_builder.Split(output_bns, split_axis).Build(sbp_signature);\n  } else {\n    SbpSignatureBuilder().Broadcast(input_bns).Broadcast(output_bns).Build(sbp_signature);\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> InterfaceOpUtil::InferOutBlobDesc(const InterfaceBlobConf& blob_conf,\n                                              BlobDesc* out_blob_desc,\n                                              const ParallelContext* parallel_ctx,\n                                              const ParallelDesc& parallel_desc) {\n  NdSbp nd_sbp;\n  JUST(ParseNdSbpFromBlobConf(blob_conf, parallel_desc, &nd_sbp));\n  out_blob_desc->set_shape(\n      *JUST(GetPhysicalShape(Shape(blob_conf.shape()), nd_sbp, parallel_desc, *parallel_ctx)));\n  out_blob_desc->set_data_type(blob_conf.data_type());\n  out_blob_desc->set_is_dynamic(blob_conf.is_dynamic());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InterfaceOpUtil::InferLogicalOutBlobDesc(const InterfaceBlobConf& blob_conf,\n                                                     BlobDesc* out_blob_desc,\n                                                     const ParallelDesc& parallel_desc) {\n  CHECK_OR_RETURN(blob_conf.has_shape());\n  out_blob_desc->set_shape(Shape(blob_conf.shape()));\n  CheckShape(out_blob_desc->shape());\n  if (out_blob_desc->shape().NumAxes() > 0) { CHECK_GT(out_blob_desc->shape().At(0), 0); }\n  CHECK_OR_RETURN(blob_conf.has_data_type());\n  out_blob_desc->set_data_type(blob_conf.data_type());\n  CHECK_OR_RETURN(blob_conf.has_is_dynamic());\n  out_blob_desc->set_is_dynamic(blob_conf.is_dynamic());\n  
return Maybe<void>::Ok();\n}\n\nMaybe<void> InterfaceOpUtil::GetInputLikeOpSbpSignature(const InterfaceBlobConf& blob_conf,\n                                                        const PbRpf<std::string>& input_bns,\n                                                        const PbRpf<std::string>& output_bns,\n                                                        SbpSignature* sbp_signature) {\n  JUST(GetSbpSignature(blob_conf, input_bns, output_bns, sbp_signature, true));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InterfaceOpUtil::GetOutputLikeOpSbpSignature(const InterfaceBlobConf& blob_conf,\n                                                         const PbRpf<std::string>& input_bns,\n                                                         const PbRpf<std::string>& output_bns,\n                                                         SbpSignature* sbp_signature) {\n  JUST(GetSbpSignature(blob_conf, input_bns, output_bns, sbp_signature, false));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InterfaceOpUtil::InitBlobConf(InterfaceBlobConf* blob_conf,\n                                          const ParallelBlobConf& parallel_blob_conf) {\n  BlobDesc blob_desc(parallel_blob_conf.logical_blob_desc_conf());\n  blob_desc.shape().ToProto(blob_conf->mutable_shape());\n  blob_conf->set_data_type(blob_desc.data_type());\n  blob_conf->set_is_dynamic(blob_desc.is_dynamic());\n  *blob_conf->mutable_nd_sbp() = parallel_blob_conf.nd_sbp();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InterfaceOpUtil::ParseNdSbpFromBlobConf(const InterfaceBlobConf& blob_conf,\n                                                    const ParallelDesc& parallel_desc,\n                                                    NdSbp* nd_sbp) {\n  const int64_t num_axes = parallel_desc.hierarchy()->NumAxes();\n  if (blob_conf.has_nd_sbp()) {\n    *nd_sbp = NdSbp(blob_conf.nd_sbp());\n  } else {\n    nd_sbp->clear_sbp_parallel();\n    FOR_RANGE(int64_t, i, 0, num_axes) { nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel(); }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/interface_op_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_INTERFACE_OP_UTIL_H_\n#define ONEFLOW_CORE_OPERATOR_INTERFACE_OP_UTIL_H_\n\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n#include \"oneflow/core/job/job.pb.h\"\n\nnamespace oneflow {\n\nstruct InterfaceOpUtil final {\n  static Maybe<void> InferOutBlobDesc(const InterfaceBlobConf& blob_conf, BlobDesc* out_blob_desc,\n                                      const ParallelContext* parallel_ctx,\n                                      const ParallelDesc& parallel_desc);\n  static Maybe<void> InferLogicalOutBlobDesc(const InterfaceBlobConf& blob_conf,\n                                             BlobDesc* out_blob_desc,\n                                             const ParallelDesc& parallel_desc);\n  static Maybe<void> GetInputLikeOpSbpSignature(const InterfaceBlobConf& blob_conf,\n                                                const PbRpf<std::string>& input_bns,\n                                                const PbRpf<std::string>& output_bns,\n                                                SbpSignature* sbp_signature);\n  static Maybe<void> GetOutputLikeOpSbpSignature(const InterfaceBlobConf& blob_conf,\n                                                 const PbRpf<std::string>& input_bns,\n                                                 const PbRpf<std::string>& output_bns,\n                                                 SbpSignature* sbp_signature);\n  static Maybe<void> InitBlobConf(InterfaceBlobConf* blob_conf,\n                                  const ParallelBlobConf& parallel_blob_conf);\n\n  static Maybe<void> ParseNdSbpFromBlobConf(const InterfaceBlobConf& blob_conf,\n                                            const ParallelDesc& parallel_desc, NdSbp* nd_sbp);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_INTERFACE_OP_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/operator/learning_rate_schedule_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass LearningRateScheduleOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LearningRateScheduleOp);\n  LearningRateScheduleOp() = default;\n  ~LearningRateScheduleOp() override = default;\n\n  Maybe<void> InitFromOpConf() override;\n  virtual Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override;\n};\n\nMaybe<void> LearningRateScheduleOp::InitFromOpConf() {\n  CHECK(op_conf().has_learning_rate_schedule_conf());\n  EnrollInputBn(\"train_step\");\n  EnrollOutputBn(\"out\");\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  const BlobDesc* train_step = BlobDesc4BnInOp(\"train_step\");\n  CHECK_EQ(train_step->shape().elem_cnt(), 1);\n  CHECK_EQ(train_step->data_type(), DataType::kInt64);\n  BlobDesc* out = BlobDesc4BnInOp(\"out\");\n  out->set_shape(Shape({1}));\n  out->set_data_type(DataType::kFloat);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> LearningRateScheduleOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  return InferBlobDescs(BlobDesc4BnInOp);\n}\n\nMaybe<void> LearningRateScheduleOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  return InferBlobDescs(GetBlobDesc4BnInOp);\n}\n\nMaybe<void> LearningRateScheduleOp::GetSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    SbpSignatureList* sbp_sig_list) const {\n  return Maybe<void>::Ok();\n}\n\nREGISTER_CPU_OP(OperatorConf::kLearningRateScheduleConf, LearningRateScheduleOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/nccl_send_recv_boxing_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/operator/nccl_send_recv_boxing_op_util.h\"\n\nnamespace oneflow {\n\nclass NcclSendRecvBoxingOp : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NcclSendRecvBoxingOp);\n  NcclSendRecvBoxingOp() = default;\n  ~NcclSendRecvBoxingOp() override = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferInternalBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx, const JobDesc* job_desc) const override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  LogicalBlobId lbi4ibn(const std::string& input_bn) const override;\n  LogicalBlobId lbi4obn(const std::string& output_bn) const override;\n};\n\nMaybe<void> NcclSendRecvBoxingOp::InitFromOpConf() {\n  const NcclSendRecvBoxingOpConf& conf = this->op_conf().nccl_send_recv_boxing_conf();\n  if (conf.has_input()) { EnrollInputBn(\"in\", false); }\n  if (conf.has_output()) { EnrollOutputBn(\"out\", false); }\n  EnrollTmpBn(\"buf\");\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> NcclSendRecvBoxingOp::InferInternalBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx, const JobDesc* job_desc) const {\n  BlobDesc* buf = GetBlobDesc4BnInOp(\"buf\");\n  const NcclSendRecvBoxingOpConf& conf = this->op_conf().nccl_send_recv_boxing_conf();\n  const NdSbp& src_nd_sbp = conf.src_nd_sbp();\n  const NdSbp& dst_nd_sbp = conf.dst_nd_sbp();\n  ParallelDesc parallel_desc(conf.parallel_conf());\n  ParallelDesc in_parallel_desc(conf.src_parallel_conf());\n  ParallelDesc out_parallel_desc(conf.dst_parallel_conf());\n  const int64_t parallel_num = parallel_desc.parallel_num();\n  const int64_t parallel_id = parallel_ctx->parallel_id();\n  const Shape& logical_shape = Shape(conf.logical_shape());\n  std::vector<TensorSliceView> src_send_intersections;\n  std::vector<TensorSliceView> dst_recv_intersections;\n  GetRankSendRecvIntersection(parallel_id, parallel_desc, in_parallel_desc, out_parallel_desc,\n                              src_nd_sbp, dst_nd_sbp, logical_shape, &src_send_intersections,\n                              &dst_recv_intersections);\n  int64_t buf_count = 0;\n  if (conf.has_input()) {\n    const BlobDesc* in = GetBlobDesc4BnInOp(\"in\");\n    buf->set_data_type(in->data_type());\n    CHECK_EQ(src_send_intersections.size(), parallel_num);\n    for 
(int64_t i = 0; i < parallel_num; ++i) {\n      const TensorSliceView& intersection = JUST(VectorAt(src_send_intersections, i));\n      if (!intersection.IsEmpty()) { buf_count += intersection.shape().elem_cnt(); }\n    }\n  }\n  if (conf.has_output()) {\n    const BlobDesc* out = GetBlobDesc4BnInOp(\"out\");\n    buf->set_data_type(out->data_type());\n    for (int64_t i = 0; i < parallel_num; ++i) {\n      const TensorSliceView& intersection = JUST(VectorAt(dst_recv_intersections, i));\n      if (!intersection.IsEmpty()) { buf_count += intersection.shape().elem_cnt(); }\n    }\n    if (NdSbpHasPartialParallel(src_nd_sbp)) {\n      // Note: when src_nd_sbp has partial_sum, an extra out-sized buffer is needed to stage data\n      // before accumulating it into out.\n      buf_count += out->shape().elem_cnt();\n    }\n  }\n  buf->set_shape(Shape({buf_count}));\n  return Maybe<void>::Ok();\n}\n\nLogicalBlobId NcclSendRecvBoxingOp::lbi4ibn(const std::string& input_bn) const {\n  return this->op_conf().nccl_send_recv_boxing_conf().lbi();\n}\n\nLogicalBlobId NcclSendRecvBoxingOp::lbi4obn(const std::string& output_bn) const {\n  return this->op_conf().nccl_send_recv_boxing_conf().lbi();\n}\n\nMaybe<void> NcclSendRecvBoxingOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  const NcclSendRecvBoxingOpConf& conf = this->op_conf().nccl_send_recv_boxing_conf();\n  const Shape& logical_shape = Shape(conf.logical_shape());\n  const ParallelDesc& parallel_desc = ParallelDesc(conf.parallel_conf());\n  const int64_t machine_id = JUST(parallel_desc.MachineId4ParallelId(parallel_ctx->parallel_id()));\n  const int64_t device_index = JUST(parallel_desc.DeviceId4ParallelId(parallel_ctx->parallel_id()));\n  if (conf.has_input()) {\n    const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp(\"in\");\n    const NdSbp& src_nd_sbp = conf.src_nd_sbp();\n    const ParallelDesc& src_parallel_desc = ParallelDesc(conf.src_parallel_conf());\n    int64_t src_parallel_id =\n        JUST(src_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index));\n    std::shared_ptr<Shape> in_shape =\n        JUST(GetPhysicalShape(logical_shape, src_nd_sbp, src_parallel_desc, src_parallel_id));\n    CHECK_EQ_OR_RETURN(*in_shape, in_blob_desc->shape())\n        << \"Non-matching shape of blobs for pieces of nccl send recv\";\n  }\n  if (conf.has_output()) {\n    BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(\"out\");\n    const NdSbp& dst_nd_sbp = conf.dst_nd_sbp();\n    const ParallelDesc& dst_parallel_desc = ParallelDesc(conf.dst_parallel_conf());\n    int64_t dst_parallel_id =\n        JUST(dst_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index));\n    std::shared_ptr<Shape> out_shape =\n        JUST(GetPhysicalShape(logical_shape, dst_nd_sbp, dst_parallel_desc, dst_parallel_id));\n    out_blob_desc->set_shape(*out_shape);\n    out_blob_desc->set_data_type(conf.data_type());\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP(OperatorConf::kNcclSendRecvBoxingConf, NcclSendRecvBoxingOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/nccl_send_recv_boxing_op_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/operator/nccl_send_recv_boxing_op_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n// Go through all the ranks while transfer between two nd sbps with no PartialSum under the same\n// placement.\n// NOTE: We need to make sure no partial sums in the sbps of the producer and consumer.\nvoid DfsTraverseRanks4NdSbp(\n    int32_t depth, std::vector<int64_t>& in_parallel_ids,\n    const std::vector<int64_t>& out_parallel_ids, const Shape& in_parallel_hierarchy,\n    const NdIndexOffsetHelper<int64_t, SHAPE_MAX_AXIS_SIZE>& in_hierarchy_index_helper,\n    const NdSbp& in_nd_sbp, const std::function<void(int32_t)>& visit) {\n  if (depth >= in_parallel_hierarchy.NumAxes()) {\n    visit(in_hierarchy_index_helper.NdIndexToOffset(in_parallel_ids.data(),\n                                                    in_parallel_hierarchy.NumAxes()));\n    return;\n  }\n  if (in_nd_sbp.sbp_parallel(depth).has_broadcast_parallel()) {\n    // If Broadcast in the sbp of the producer, only visit those ranks with the same id as the\n    // current rank along the depth-dimension.\n    in_parallel_ids[depth] = out_parallel_ids[depth];\n    DfsTraverseRanks4NdSbp(depth + 1, in_parallel_ids, out_parallel_ids, in_parallel_hierarchy,\n                           in_hierarchy_index_helper, in_nd_sbp, visit);\n  } else {\n    // If Split or PartialSum, go through all the ranks along the depth-dimension.\n    for (int64_t i = 0; i < in_parallel_hierarchy.dim_vec().at(depth); i++) {\n      in_parallel_ids[depth] = i;\n      DfsTraverseRanks4NdSbp(depth + 1, in_parallel_ids, out_parallel_ids, in_parallel_hierarchy,\n                             in_hierarchy_index_helper, in_nd_sbp, visit);\n    }\n  }\n}\n\nbool NdSbpNoPartialParallel(const NdSbp& nd_sbp) {\n  CHECK_GT(nd_sbp.sbp_parallel_size(), 0);\n  FOR_RANGE(int64_t, i, 0, nd_sbp.sbp_parallel_size()) {\n    if (nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) { return false; }\n  }\n  return true;\n}\n\n}  // namespace\n\nint64_t GetMappedParallelId(const int64_t from_parallel_id, const ParallelDesc& from_parallel_desc,\n                            const ParallelDesc& to_parallel_desc) {\n  const int64_t machine_id = CHECK_JUST(from_parallel_desc.MachineId4ParallelId(from_parallel_id));\n  const int64_t device_index = CHECK_JUST(from_parallel_desc.DeviceId4ParallelId(from_parallel_id));\n  if (to_parallel_desc.Containing(machine_id, device_index)) {\n    return CHECK_JUST(to_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index));\n  } else {\n    return -1;\n  }\n}\n\nvoid GetRankSendRecvIntersection(int64_t parallel_id, const ParallelDesc& parallel_desc,\n                                 const ParallelDesc& in_parallel_desc,\n                                 const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp,\n                                 const NdSbp& 
out_nd_sbp, const Shape& logical_shape,\n                                 std::vector<TensorSliceView>* send_intersections,\n                                 std::vector<TensorSliceView>* recv_intersections) {\n  const int64_t parallel_num = parallel_desc.parallel_num();\n  CHECK_LT(parallel_id, parallel_num);\n\n  const std::vector<TensorSliceView>& in_slices =\n      GetTensorSliceView(*in_parallel_desc.hierarchy(), in_nd_sbp, logical_shape);\n  const std::vector<TensorSliceView>& out_slices =\n      GetTensorSliceView(*out_parallel_desc.hierarchy(), out_nd_sbp, logical_shape);\n\n  const auto& in_parallel_hierarchy = in_parallel_desc.hierarchy();\n  int32_t in_hierarchy_dimension = in_parallel_hierarchy->NumAxes();\n  const NdIndexOffsetHelper<int64_t, SHAPE_MAX_AXIS_SIZE> in_hierarchy_index_helper(\n      in_parallel_hierarchy->dim_vec().data(), in_hierarchy_dimension);\n\n  const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));\n  const int64_t device_index = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));\n  const int64_t in_parallel_num = in_parallel_desc.parallel_num();\n  const int64_t out_parallel_num = out_parallel_desc.parallel_num();\n  // What cur rank receives; only relevant when cur rank has output.\n  if (out_parallel_desc.Containing(machine_id, device_index)) {\n    recv_intersections->resize(parallel_num);\n    int64_t out_id =\n        CHECK_JUST(out_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index));\n    const TensorSliceView& cur_rank_out_slice = out_slices.at(out_id);\n    const auto& add_to_recv_intersections = [&](int32_t send_id) {\n      const TensorSliceView& in_slice = in_slices.at(send_id);\n      const TensorSliceView& intersection = cur_rank_out_slice.Intersect(in_slice);\n      if (intersection.IsEmpty()) { return; }\n      const int64_t merged_id = GetMappedParallelId(send_id, in_parallel_desc, parallel_desc);\n      recv_intersections->at(merged_id) = intersection;\n    };\n    int64_t corresponding_in_id = 0;\n    // For example [[0, 1], [2, 3]] -> [[1, 3], [5, 6]]\n    if (in_parallel_desc.Containing(machine_id, device_index)) {\n      // 1 and 3 are in [[0, 1], [2, 3]], use the same id in the producer parallel description\n      // The id of 1 is (0, 1), the id of 3 is (1, 1)\n      corresponding_in_id =\n          CHECK_JUST(in_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index));\n    } else {\n      // 5 and 6 are not in [[0, 1], [2, 3]]\n      // Then the id does not matter\n      corresponding_in_id = out_id % in_parallel_num;\n    }\n    std::vector<int64_t> in_parallel_ids(in_hierarchy_dimension);\n    // The corresponding parallel id of a consumer rank in the producer parallel description\n    std::vector<int64_t> out_parallel_ids(in_hierarchy_dimension);\n    in_hierarchy_index_helper.OffsetToNdIndex(corresponding_in_id, out_parallel_ids.data(),\n                                              in_hierarchy_dimension);\n    DfsTraverseRanks4NdSbp(0, in_parallel_ids, out_parallel_ids, *in_parallel_hierarchy,\n                           in_hierarchy_index_helper, in_nd_sbp, add_to_recv_intersections);\n  }\n\n  // What cur rank sends; only relevant when cur rank has input.\n  if (in_parallel_desc.Containing(machine_id, device_index)) {\n    send_intersections->resize(parallel_num);\n    int64_t in_id =\n        CHECK_JUST(in_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index));\n    const TensorSliceView& cur_rank_in_slice = in_slices.at(in_id);\n    for (int64_t recv_i = 0; recv_i < out_parallel_num; ++recv_i) {\n    
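  // Record a send intersection only when the DFS visit reaches this rank itself (send_id == in_id).\n    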
  const auto& add_to_send_intersections = [&](int32_t send_id) {\n        if (send_id != in_id) { return; }\n        const TensorSliceView& out_slice = out_slices.at(recv_i);\n        const TensorSliceView& intersection = out_slice.Intersect(cur_rank_in_slice);\n        if (intersection.IsEmpty()) { return; }\n        const int64_t merged_id = GetMappedParallelId(recv_i, out_parallel_desc, parallel_desc);\n        send_intersections->at(merged_id) = intersection;\n      };\n      int64_t out_device_id = CHECK_JUST(out_parallel_desc.DeviceId4ParallelId(recv_i));\n      int64_t out_machine_id = CHECK_JUST(out_parallel_desc.MachineId4ParallelId(recv_i));\n      int64_t corresponding_in_id = 0;\n      // For example [[0, 1], [2, 3]] -> [[1, 3], [5, 6]]\n      if (in_parallel_desc.Containing(out_machine_id, out_device_id)) {\n        // 1 and 3 are in [[0, 1], [2, 3]], use the same id in the producer parallel description\n        // The id of 1 is (0, 1), the id of 3 is (1, 1)\n        corresponding_in_id =\n            CHECK_JUST(in_parallel_desc.ParallelId4MachineDeviceId(out_machine_id, out_device_id));\n      } else {\n        // 5 and 6 are not in [[0, 1], [2, 3]]\n        // Then the id does not matter\n        corresponding_in_id = recv_i % in_parallel_num;\n      }\n      std::vector<int64_t> in_parallel_ids(in_hierarchy_dimension);\n      // The corresponding parallel id of a consumer rank in the producer parallel description\n      std::vector<int64_t> out_parallel_ids(in_hierarchy_dimension);\n      in_hierarchy_index_helper.OffsetToNdIndex(corresponding_in_id, out_parallel_ids.data(),\n                                                in_hierarchy_dimension);\n      DfsTraverseRanks4NdSbp(0, in_parallel_ids, out_parallel_ids, *in_parallel_hierarchy,\n                             in_hierarchy_index_helper, in_nd_sbp, add_to_send_intersections);\n    }\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/nccl_send_recv_boxing_op_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/register/tensor_slice_view.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n\nnamespace oneflow {\n\nint64_t GetMappedParallelId(const int64_t from_parallel_id, const ParallelDesc& from_parallel_desc,\n                            const ParallelDesc& to_parallel_desc);\n\nvoid GetRankSendRecvIntersection(int64_t parallel_id, const ParallelDesc& parallel_desc,\n                                 const ParallelDesc& in_parallel_desc,\n                                 const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp,\n                                 const NdSbp& out_nd_sbp, const Shape& logical_shape,\n                                 std::vector<TensorSliceView>* send_intersections,\n                                 std::vector<TensorSliceView>* recv_intersections);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/op_attribute.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/register/logical_blob_id.proto\";\nimport \"oneflow/core/register/blob_desc.proto\";\nimport \"oneflow/core/operator/op_conf.proto\";\nimport \"oneflow/core/operator/arg_modifier_signature.proto\";\nimport \"oneflow/core/job/sbp_parallel.proto\";\nimport \"oneflow/core/job/local_parallel.proto\";\nimport \"oneflow/core/job/blob_lifetime_signature.proto\";\nimport \"oneflow/core/job/parallel_signature.proto\";\nimport \"oneflow/core/job/parallel_conf_signature.proto\";\n\nmessage OpAttribute {\n  repeated string input_bns = 1;\n  repeated string output_bns = 2;\n  repeated string tmp_bns = 3;\n\n  required OperatorConf op_conf = 50;\n\n  // inter-node signature\n  required ArgSignature arg_signature = 100;\n  required ArgModifierSignature arg_modifier_signature = 101;\n  optional BlobLastUsedSignature blob_last_used_signature = 102;\n  optional BlobBackwardUsedSignature blob_backward_used_signature = 103;\n\n  // op node signature\n  optional SbpSignature sbp_signature = 104;\n  optional LocalSignature local_signature = 105;\n  optional BlobDescSignature logical_blob_desc_signature = 106;\n  optional ParallelConfSignature parallel_conf_signature = 109;\n  optional NdSbpSignature nd_sbp_signature = 110;\n}\n\nmessage OpAttributeList {\n  repeated OpAttribute op_attribute = 1;\n}\n"
  },
  {
    "path": "oneflow/core/operator/op_conf.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/common/shape.proto\";\nimport \"oneflow/core/common/data_type.proto\";\nimport \"oneflow/core/common/device_type.proto\";\nimport \"oneflow/core/record/record.proto\";\nimport \"oneflow/core/job/resource.proto\";\nimport \"oneflow/core/register/logical_blob_id.proto\";\nimport \"oneflow/core/register/tensor_slice_view.proto\";\nimport \"oneflow/core/framework/user_op_conf.proto\";\nimport \"oneflow/core/job/sbp_parallel.proto\";\nimport \"oneflow/core/graph/boxing/collective_boxing.proto\";\nimport \"oneflow/core/job/initializer_conf.proto\";\nimport \"oneflow/core/job/regularizer_conf.proto\";\nimport \"oneflow/core/job/placement.proto\";\nimport \"oneflow/core/job/learning_rate_schedule_conf.proto\";\nimport \"oneflow/core/operator/interface_blob_conf.proto\";\nimport \"oneflow/core/register/blob_desc.proto\";\n\n\nenum ActivationType {\n  kNone = 0;\n  kTanH = 1;\n  kSigmoid = 2;\n  kRelu = 3;\n}\n\nmessage DistributeConcatOpConf {\n  repeated string in = 1;\n  required string out = 2;\n  required int32 axis = 3;\n}\n\nmessage DistributeSplitOpConf {\n  required string in = 1;\n  repeated string out = 2;\n  required int32 axis = 3;\n  optional bool is_variable_ref = 4 [default = false];\n}\n\nmessage DistributeCloneOpConf {\n  required string in = 1;\n  repeated string out = 2;\n  optional bool is_variable_ref = 3 [default = false];\n}\n\nmessage DistributeAddOpConf {\n  repeated string in = 1;\n  required string out = 2;\n}\n\nmessage CopyCommNetOpConf {\n  required LogicalBlobId lbi = 2;\n}\n\nmessage BoxConcatConf {\n  required int32 axis = 1;\n}\n\nmessage BoxAddConf {\n}\n\nmessage BoxSplitConf {\n  required int32 axis = 1;\n  repeated int32 part_num = 2;\n}\n\nmessage BoxCloneConf {\n}\n\nmessage BoxingOpConf {\n  required LogicalBlobId lbi = 1;\n  required int32 in_num = 2;\n  required int32 out_num = 3;\n\n  oneof in_box {\n    BoxConcatConf concat_box = 4;\n    BoxAddConf add_box = 5;\n  }\n  oneof out_box {\n    BoxSplitConf split_box = 6;\n    BoxCloneConf clone_box = 7;\n  }\n}\n\nmessage DynamicReshapeOpConf {\n  required string in = 1;\n  required string out = 2;\n  required ShapeProto shape = 3;\n}\n\nmessage DynamicReshapeLikeOpConf {\n  required string x = 1;\n  required string y = 2;\n  required string like = 3;\n}\n\nmessage FeedInputOpConf {\n  // NOTE(chengcheng): define in/out key as UserOp ibn/obn.\n  required string in_0 = 1;\n  required string out_0 = 2;\n}\n\nmessage FeedVariableOpConf {\n  required string in_0 = 1;\n  required string out_0 = 2;\n}\n\nmessage FetchOutputOpConf {\n  required string in_0 = 1;\n  required string out_0 = 2;\n}\n\nmessage InputOpConf {\n  optional string tick = 1;\n  required string out = 2;\n  required InterfaceBlobConf blob_conf = 3;\n  optional string job_name = 4;\n}\n\nmessage ReturnOpConf {\n  required string in = 1;\n  required string out = 2;\n  optional string job_name = 3;\n}\n\nmessage OutputOpConf {\n  required string in = 1;\n  required string out = 2;\n  required InterfaceBlobConf blob_conf = 3;\n  optional string job_name = 4;\n}\n\nmessage VariableOpConf {\n  optional string tick = 1;\n  required string out = 2;\n  required ShapeProto shape = 3;\n  optional DataType data_type = 4;\n  oneof initialize {\n    InitializerConf initializer = 5;\n    InitializeWithSnapshotConf initialize_with_snapshot = 6;\n  }\n  optional string model_name = 7 [default = \"weight\"];\n  optional int64 random_seed = 9;\n  optional RegularizerConf regularizer = 
10;\n  optional bool trainable = 11 [default = true];\n  repeated string nd_sbp = 12;\n}\n\nmessage TickOpConf {\n  repeated string tick = 1;\n  required string out = 2;\n}\n\nmessage CriticalSectionWaitTickOpConf {\n  repeated string tick = 1;\n  required string out = 2;\n  required string buffer_name = 3;\n}\n\nmessage CriticalSectionCallbackTickOpConf {\n  repeated string tick = 1;\n  required string out = 2;\n  required string buffer_name = 3;\n}\n\nmessage DeviceTickOpConf {\n  repeated string tick = 1;\n  required string out = 2;\n  optional ShapeProto time_shape = 3;\n}\n\nmessage WaitAndSendIdsOpConf {\n  required string out = 1;\n  required string wait_buffer_name = 2;\n  repeated Int64List id_list = 3;\n  required DataType data_type = 4 [default = kInt32];\n  optional string job_name = 5;\n}\n\nmessage CallbackNotifyOpConf {\n  required string in = 1;\n  repeated string callback_buffer_name = 2;\n  optional string job_name = 3;\n}\n\nmessage ReentrantLockOpConf {\n  required string start = 1;\n  optional string end = 2;\n  required string out = 3;\n  repeated Int64List lock_id2intersecting_lock_ids = 4;\n}\n\nmessage SrcSubsetTickOpConf {\n  repeated string in = 1;\n  required string out = 2;\n}\n\nmessage DstSubsetTickOpConf {\n  repeated string in = 1;\n  required string out = 2;\n}\n\nmessage SourceTickOpConf {\n  required string out = 1;\n}\n\nmessage SinkTickOpConf {\n  repeated string tick = 1;\n  required string out = 2;\n}\n\nmessage TotalLossInstanceNumOpConf {\n  repeated string in = 1;\n  required string out = 2;\n}\n\nmessage ShapeElemCntAxisConf {\n  repeated int32 axis = 1;\n}\n\nmessage ShapeElemCntRangeAxisConf {\n  // closed interval: [begin_axis, end_axis]\n  optional int32 begin_axis = 1 [default = 0];\n  optional int32 end_axis = 2 [default = -1];\n}\n\nmessage ShapeElemCntOpConf {\n  required string x = 1;\n  required string y = 2;\n  optional DataType data_type = 3 [default = kInt32];\n  oneof axis_conf {\n    ShapeElemCntAxisConf exclude_axis_conf = 4;\n    ShapeElemCntAxisConf include_axis_conf = 5;\n    ShapeElemCntRangeAxisConf range_axis_conf = 6;\n  }\n}\n\nmessage AccTickOpConf {\n  // in\n  required string one = 1;\n  // out\n  required string acc = 2;\n  optional int32 max_acc_num = 3 [default = 1];\n}\n\nmessage IdentityOpConf {\n  required string in = 1;\n  required string out = 2;\n}\n\nmessage CopyOpConf {\n  required string in = 1;\n  required string out = 2;\n}\n\nmessage CastToLocalOpConf {\n  required string in = 1;\n  required string out = 2;\n  required SbpParallel sbp_parallel = 3;\n}\n\nmessage CastFromLocalOpConf {\n  required string in = 1;\n  required string out = 2;\n  required SbpParallel sbp_parallel = 3;\n}\n\nmessage CaseOpConf {\n  required string in = 1;\n  repeated string out = 2;\n}\n\nmessage EsacOpConf {\n  repeated string in = 1;\n  required string out = 2;\n  optional DataType data_type = 3 [default=kInt32];\n}\n\nmessage AssignOpConf {\n  required string ref = 1;\n  required string value = 2;\n}\n\nmessage LearningRateScheduleOpConf {\n  required string train_step = 1;\n  required string out = 2;\n  required float learning_rate = 3;\n  optional LearningRateDecayConf learning_rate_decay = 4;\n}\n\nmessage SliceBoxingConf {\n  required LogicalBlobId lbi = 1;\n  repeated TensorSliceViewProto in_slice = 2;\n  required TensorSliceViewProto out_slice = 3;\n  optional ShapeProto out_shape = 4;\n}\n\nmessage SliceBoxingCopyOpConf {\n  required SliceBoxingConf slice_boxing_conf = 1;\n}\n\nmessage SliceBoxingAddOpConf {\n  
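// The add counterpart of SliceBoxingCopyOpConf: input slices covering the same output region are summed.\n  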
required SliceBoxingConf slice_boxing_conf = 1;\n}\n\nmessage ConstantLikeOpConf {\n  required string like = 1;\n  required string out = 2;\n  optional DataType data_type = 3;\n  oneof scalar_operand {\n    int64 int_operand = 4;\n    double float_operand = 5;\n  }\n}\n\nmessage SyncDynamicResizeOpConf {\n  required string in = 1;\n  required string size = 2;\n  required string out = 3;\n  required int64 axis = 4;\n  optional bool eager = 5 [default = false];\n}\n\nmessage BroadcastToCompatibleWithOpConf {\n  required string x = 1;\n  repeated string compatible = 2;\n  required string y = 3;\n}\n\nmessage CollectiveBoxingGenericOpConf {\n  required LogicalBlobId lbi = 1;\n  required boxing.collective.RankDesc rank_desc = 2;\n}\n\nmessage BoxingIdentityOpConf {\n  required LogicalBlobId lbi = 1;\n}\n\nmessage CollectiveBoxingPackOpConf {\n  required LogicalBlobId lbi = 1;\n  required SbpParallel src_sbp_parallel = 2;\n  required SbpParallel dst_sbp_parallel = 3;\n  required int64 num_ranks = 4;\n  required ShapeProto logical_shape = 5;\n}\n\nmessage CollectiveBoxingUnpackOpConf {\n  required LogicalBlobId lbi = 1;\n  required SbpParallel src_sbp_parallel = 2;\n  required SbpParallel dst_sbp_parallel = 3;\n  required int64 num_ranks = 4;\n  required ShapeProto logical_shape = 5;\n}\n\nmessage ImageDecoderRandomCropResizeOpConf {\n  required string in = 1;\n  required string out = 2;\n  required int64 target_width = 3;\n  required int64 target_height = 4;\n  optional int64 num_workers = 5 [default = 3];\n  optional int64 max_num_pixels = 6 [default = 67108864];\n  optional int64 warmup_size = 7 [default = 6400];\n  optional int64 seed = 8;\n  optional int64 num_attempts = 9 [default = 10];\n  optional float random_area_min = 10 [default = 0.08];\n  optional float random_area_max = 11 [default = 1.0];\n  optional float random_aspect_ratio_min = 12 [default = 0.75];\n  optional float random_aspect_ratio_max = 13 [default = 1.333333];\n}\n\nmessage BoxingZerosOpConf {\n  required LogicalBlobId lbi = 1;\n  required ShapeProto shape = 2;\n  required DataType data_type = 3;\n}\n\nmessage NcclSendRecvBoxingOpConf {\n  required LogicalBlobId lbi = 1;\n  required NdSbp src_nd_sbp = 2;\n  required NdSbp dst_nd_sbp = 3;\n  required ParallelConf parallel_conf = 4;\n  required ParallelConf src_parallel_conf = 5;\n  required ParallelConf dst_parallel_conf = 6;\n  required ShapeProto logical_shape = 7;\n  required DataType data_type = 8;\n  required bool has_input = 9;\n  required bool has_output = 10;\n}\n\nmessage OperatorConf {\n  required string name = 1;\n  optional string device_tag = 4 [default = \"invalid_device\"];\n  repeated string ctrl_in_op_name = 7;\n  optional int64 scope_symbol_id = 8;\n  optional string stream_name_hint = 9;\n  optional string pass_tag = 10;\n  optional string loc = 11 [default = \"\"];\n  optional int64 logical_chain_id = 12 [default = -1];\n  optional int64 order_in_logical_chain = 13 [default = -1];\n  optional string calculation_pass_name = 14 [default = \"forward_pass\"];\n  oneof op_type {\n    // system op\n    CopyCommNetOpConf copy_comm_net_conf = 106;\n    BoxingOpConf boxing_conf = 108;\n    VariableOpConf variable_conf = 122;\n    TickOpConf tick_conf = 124;\n    CriticalSectionWaitTickOpConf critical_section_wait_tick_conf = 125;\n    CriticalSectionCallbackTickOpConf critical_section_callback_tick_conf = 126;\n    TotalLossInstanceNumOpConf total_loss_instance_num_conf = 131;\n    ShapeElemCntOpConf shape_elem_cnt_conf = 132;\n    SrcSubsetTickOpConf 
src_subset_tick_conf = 133;\n    DstSubsetTickOpConf dst_subset_tick_conf = 134;\n    SourceTickOpConf source_tick_conf = 135;\n    SinkTickOpConf sink_tick_conf = 136;\n    InputOpConf input_conf = 137;\n    OutputOpConf output_conf = 138;\n    WaitAndSendIdsOpConf wait_and_send_ids_conf = 139;\n    ReentrantLockOpConf reentrant_lock_conf = 140;\n    CallbackNotifyOpConf callback_notify_conf = 141;\n    AccTickOpConf acc_tick_conf = 144;\n    ReturnOpConf return_conf = 146;\n    DistributeConcatOpConf distribute_concat_conf = 155;\n    DistributeSplitOpConf distribute_split_conf = 156;\n    DistributeCloneOpConf distribute_clone_conf = 157;\n    DistributeAddOpConf distribute_add_conf = 158;\n    DeviceTickOpConf device_tick_conf = 159;\n    SliceBoxingCopyOpConf slice_boxing_copy_conf = 166;\n    SliceBoxingAddOpConf slice_boxing_add_conf = 167;\n    CollectiveBoxingGenericOpConf collective_boxing_generic_conf = 170;\n    BoxingIdentityOpConf boxing_identity_conf = 171;\n    CollectiveBoxingPackOpConf collective_boxing_pack_conf = 174;\n    CollectiveBoxingUnpackOpConf collective_boxing_unpack_conf = 175;\n    BoxingZerosOpConf boxing_zeros_conf = 176;\n    NcclSendRecvBoxingOpConf nccl_send_recv_boxing_conf = 177;\n    UserOpConf user_conf = 199;\n\n    // domain op\n    DynamicReshapeOpConf dynamic_reshape_conf = 203;\n    DynamicReshapeLikeOpConf dynamic_reshape_like_conf = 287;\n    IdentityOpConf identity_conf = 290;\n    CaseOpConf case_conf = 291;\n    EsacOpConf esac_conf = 292;\n    AssignOpConf assign_conf = 296;\n    LearningRateScheduleOpConf learning_rate_schedule_conf = 298;\n    ConstantLikeOpConf constant_like_conf = 339;\n    SyncDynamicResizeOpConf sync_dynamic_resize_conf = 340;\n    CopyOpConf copy_conf = 343;\n    CastToLocalOpConf cast_to_local_conf = 344;\n    CastFromLocalOpConf cast_from_local_conf = 345;\n    ImageDecoderRandomCropResizeOpConf image_decoder_random_crop_resize_conf = 349;\n\n    // math op\n    BroadcastToCompatibleWithOpConf broadcast_to_compatible_with_conf = 525;\n\n    // NOTE(chengcheng): Lazy 1.0 system ops.\n    //   Feed EagerTensor to interface op.\n    //   Note that the FeedXxx ops are only used to build CustomOpExpr and have NO operator impl.\n    FeedInputOpConf feed_input_conf = 600;\n    FeedVariableOpConf feed_variable_conf = 601;\n    //   Fetch EagerTensor from output op\n    FetchOutputOpConf fetch_output_conf = 602;\n  }\n}\n\nmessage OpNameRelations {\n  map<string, string> src_op_name2dst_op_name = 1;\n}\n\nmessage OpNameGroups {\n  message OpNameGroup {\n    repeated string op_name = 1;\n  }\n  repeated OpNameGroup op_name_group = 2;\n}\n"
  },
  {
    "path": "oneflow/core/operator/op_conf_symbol.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/op_conf_symbol.h\"\n\nnamespace oneflow {\n\nOperatorConfSymbol::OperatorConfSymbol(int64_t symbol_id, const OperatorConf& op_conf)\n    : symbol_id_(symbol_id), op_conf_(op_conf) {}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/op_conf_symbol.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_OP_CONF_SYMBOL_H_\n#define ONEFLOW_CORE_OPERATOR_OP_CONF_SYMBOL_H_\n\n#include <string>\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n\nnamespace oneflow {\n\nclass OperatorConfSymbol final {\n public:\n  OperatorConfSymbol(const OperatorConfSymbol&) = delete;\n  OperatorConfSymbol(OperatorConfSymbol&&) = delete;\n  OperatorConfSymbol(int64_t symbol_id, const OperatorConf& op_conf);\n\n  ~OperatorConfSymbol() = default;\n\n  const OperatorConf& op_conf() const { return op_conf_; }\n  const OperatorConf& data() const { return op_conf_; }\n  const Optional<int64_t>& symbol_id() const { return symbol_id_; }\n\n private:\n  Optional<int64_t> symbol_id_;\n  OperatorConf op_conf_;\n  std::shared_ptr<OperatorConf> data_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_OP_CONF_SYMBOL_H_\n"
  },
  {
    "path": "oneflow/core/operator/op_conf_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_OP_CONF_UTIL_H_\n#define ONEFLOW_CORE_OPERATOR_OP_CONF_UTIL_H_\n\n#include \"oneflow/core/operator/op_conf.pb.h\"\n\nnamespace std {\n\ntemplate<>\nstruct hash<::oneflow::OperatorConf::OpTypeCase> {\n  std::size_t operator()(const ::oneflow::OperatorConf::OpTypeCase& op_type) const {\n    return std::hash<int>()(static_cast<size_t>(op_type));\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_OPERATOR_OP_CONF_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/operator/op_infer_cache.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_OP_INFER_CACHE_H_\n#define ONEFLOW_CORE_OPERATOR_OP_INFER_CACHE_H_\n\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/dtype_signature.h\"\n#include \"oneflow/core/common/symbol.h\"\n\nnamespace oneflow {\n\nstruct OpInferCacheKey final {\n  const void* scope;\n  Symbol<OperatorConf> op_conf_sym;\n  Symbol<DTypeSignature> dtype_signature_sym;\n  std::vector<Symbol<Shape>> ibn_idx2shape_sym;\n};\n\nstruct OpInferCacheValue final {\n  std::vector<Symbol<Shape>> obn_idx2shape_sym;\n};\n\ninline bool operator==(const OpInferCacheKey& lhs, const OpInferCacheKey& rhs) {\n  return lhs.scope == rhs.scope && lhs.op_conf_sym == rhs.op_conf_sym\n         && lhs.dtype_signature_sym == rhs.dtype_signature_sym\n         && lhs.ibn_idx2shape_sym == rhs.ibn_idx2shape_sym;\n}\n\ninline bool operator!=(const OpInferCacheKey& lhs, const OpInferCacheKey& rhs) {\n  return !(lhs == rhs);\n}\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::OpInferCacheKey> final {\n  size_t operator()(const oneflow::OpInferCacheKey& op_infer_cache_key) const {\n    using namespace oneflow;\n    size_t ibn_idx2shape_sym_hash_value = 0;\n    for (const auto& shape_sym : op_infer_cache_key.ibn_idx2shape_sym) {\n      AddHash(&ibn_idx2shape_sym_hash_value, shape_sym);\n    }\n    return Hash(op_infer_cache_key.scope, op_infer_cache_key.op_conf_sym,\n                ibn_idx2shape_sym_hash_value, op_infer_cache_key.dtype_signature_sym);\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_OPERATOR_OP_INFER_CACHE_H_\n"
  },
  {
    "path": "oneflow/core/operator/op_node_signature.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/job/sbp_parallel.proto\";\nimport \"oneflow/core/job/local_parallel.proto\";\nimport \"oneflow/core/register/blob_desc.proto\";\nimport \"oneflow/core/job/parallel_signature.proto\";\n\nmessage OpNodeSignature {\n  optional SbpSignature sbp_signature = 1;\n  optional LocalSignature local_signature = 2;\n  optional BlobDescSignature logical_blob_desc_signature = 3;\n  optional ParallelSignature parallel_signature = 5;\n}\n"
  },
  {
    "path": "oneflow/core/operator/operator.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <utility>\n#include \"oneflow/core/auto_parallel/algorithm_util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n#include \"oneflow/core/job/local_sig_infer_hint.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/operator/op_node_signature.pb.h\"\n#include \"oneflow/core/job/nd_sbp_infer_hint.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/framework/placement_sbp_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nDataType GetDataTypeFromBnInOpVec(\n    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n    const PbRpf<std::string>& bn_in_ops) {\n  for (const std::string& bn_in_op : bn_in_ops) {\n    const BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn_in_op);\n    if (blob_desc) { return blob_desc->data_type(); }\n  }\n  return DataType::kInvalidDataType;\n}\n\nMaybe<Operator> CheckAndConstructOp(std::shared_ptr<const OperatorConf> op_conf) {\n  Operator* rptr = NewObj<int32_t, Operator>(op_conf->op_type_case(), *op_conf);\n  DeviceType device_type = JUST(DeviceType4DeviceTag(op_conf->device_tag()));\n  if (IsCpuOnly(*op_conf)) { CHECK_EQ_OR_RETURN(device_type, DeviceType::kCPU); }\n  JUST(rptr->Init(op_conf));\n  return std::shared_ptr<Operator>(rptr);\n}\n\n}  // namespace\n\nOperator::Operator() : device_type_(DeviceType::kInvalidDevice) {}\n\nMaybe<void> Operator::Init(const OperatorConf& op_conf) {\n  return Init(std::make_shared<const OperatorConf>(op_conf));\n}\n\nMaybe<void> Operator::Init(std::shared_ptr<const OperatorConf> op_conf) {\n  op_conf_ = std::move(op_conf);\n  device_type_ = JUST(DeviceType4DeviceTag(op_conf_->device_tag()));\n  JUST(InitFromOpConf());\n  input_output_bns_.Reserve(input_bns().size() + output_bns().size());\n  for (const auto& bn : input_bns()) { *input_output_bns_.Add() = bn; }\n  for (const auto& bn : output_bns()) { *input_output_bns_.Add() = bn; }\n  return Maybe<void>::Ok();\n}\n\nconst LogicalBlobId& Operator::BnInOp2Lbi(const std::string& bn_in_op) const {\n  return arg_signature_.bn_in_op2lbi().at(bn_in_op);\n}\n\nconst OperatorConf& Operator::op_conf() const {\n  CHECK(op_conf_);\n  return *op_conf_;\n}\n\nstd::shared_ptr<const OperatorConf> Operator::shared_op_conf() const { return op_conf_; }\n\nDeviceType Operator::device_type() const { return device_type_; }\n\nconst std::string& Operator::SoleIbn() 
const {\n  CHECK_EQ(input_bns().size(), 1) << \", op_name \" << op_name();\n  return input_bns().Get(0);\n}\nconst std::string& Operator::SoleObn() const {\n  CHECK_EQ(output_bns().size(), 1) << \", op_name \" << op_name();\n  return output_bns().Get(0);\n}\nconst std::string& Operator::SoleTbn() const {\n  CHECK_EQ(tmp_bns().size(), 1);\n  return tmp_bns().Get(0);\n}\n\nMaybe<const std::string*> Operator::obn4lbi(const LogicalBlobId& lbi) const {\n  const auto& it = lbi2output_index_.find(lbi);\n  CHECK_OR_RETURN(it != lbi2output_index_.end())\n      << \"no logical blob id found. lbn: \" << lbi.op_name() << \"/\" << lbi.blob_name();\n  return &output_bns().Get(it->second);\n}\n\nconst PbRpf<std::string>& Operator::input_bns() const { return input_bns_; }\n\nconst PbRpf<std::string>& Operator::output_bns() const { return output_bns_; }\n\nconst PbRpf<std::string>& Operator::tmp_bns() const { return tmp_bns_; }\n\nconst PbRpf<std::string>& Operator::input_output_bns() const { return input_output_bns_; }\n\nMaybe<void> Operator::InferParallelSignatureIf() {\n  JUST(InferBlobParallelDesc());\n  return Maybe<void>::Ok();\n}\n\nMaybe<const ParallelDesc> Operator::GetParallelDesc4BnInOp(const std::string& bn) const {\n  CHECK_OR_RETURN(bn2parallel_desc_);\n  auto it = bn2parallel_desc_->find(bn);\n  CHECK_OR_RETURN(it != bn2parallel_desc_->end());\n  return it->second;\n}\n\nMaybe<void> Operator::FillBlobParallelDesc(\n    const std::function<Maybe<const ParallelDesc>(const std::string&)>& ParallelDesc4Bn) {\n  CHECK_OR_RETURN(!bn2parallel_desc_);\n  bn2parallel_desc_.reset(new HashMap<std::string, std::shared_ptr<const ParallelDesc>>);\n  for (const auto& bn : input_output_bns()) {\n    auto blob_parallel_desc = JUST(ParallelDesc4Bn(bn));\n    CHECK(bn2parallel_desc_->emplace(bn, blob_parallel_desc).second);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::InferBlobParallelDesc() {\n  JUST(FillBlobParallelDesc(\n      [&](const std::string& bn) -> Maybe<const ParallelDesc> { return GetOpParallelDesc(); }));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::FillOpParallelDesc(const ParallelDesc& parallel_desc) {\n  return FillOpParallelDesc(std::make_shared<const ParallelDesc>(parallel_desc));\n}\n\nMaybe<void> Operator::FillOpParallelDesc(std::shared_ptr<const ParallelDesc> parallel_desc) {\n  CHECK_OR_RETURN(!op_parallel_desc_);\n  op_parallel_desc_ = std::move(parallel_desc);\n  return Maybe<void>::Ok();\n}\n\nMaybe<const ParallelDesc> Operator::GetOpParallelDesc() const {\n  CHECK_OR_RETURN(op_parallel_desc_);\n  return op_parallel_desc_;\n}\n\nnamespace {\n\nMaybe<void> FillLogicalBlobDesc(\n    const std::function<Maybe<const BlobDesc>(int32_t)>& BlobDesc4Index,\n    const PbRpf<std::string>& bns,\n    std::unique_ptr<std::vector<std::shared_ptr<const BlobDesc>>>* index2logical_blob_desc_ptr) {\n  CHECK_OR_RETURN(!(*index2logical_blob_desc_ptr));\n  index2logical_blob_desc_ptr->reset(new std::vector<std::shared_ptr<const BlobDesc>>());\n  (*index2logical_blob_desc_ptr)->reserve(bns.size());\n  for (int32_t i = 0; i < bns.size(); ++i) {\n    (*index2logical_blob_desc_ptr)->emplace_back(JUST(BlobDesc4Index(i)));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FillLogicalBlobDesc(\n    const std::function<const BlobDesc&(const std::string&)>& BlobDesc4BnInOp,\n    const PbRpf<std::string>& bns,\n    std::unique_ptr<std::vector<std::shared_ptr<const BlobDesc>>>* index2logical_blob_desc_ptr) {\n  CHECK_OR_RETURN(!(*index2logical_blob_desc_ptr));\n  
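// Allocate the index-to-BlobDesc table, then deep-copy one BlobDesc per blob name.\n  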
index2logical_blob_desc_ptr->reset(new std::vector<std::shared_ptr<const BlobDesc>>());\n  (*index2logical_blob_desc_ptr)->reserve(bns.size());\n  for (const auto& bn : bns) {\n    const BlobDesc& blob_desc = BlobDesc4BnInOp(bn);\n    (*index2logical_blob_desc_ptr)->emplace_back(std::make_shared<const BlobDesc>(blob_desc));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FillLogicalBlobDesc(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const PbRpf<std::string>& bns,\n    std::unique_ptr<std::vector<std::shared_ptr<const BlobDesc>>>* index2logical_blob_desc_ptr) {\n  JUST(FillLogicalBlobDesc(\n      [&](const std::string& bn) -> const BlobDesc& {\n        const BlobDesc* blob_desc = BlobDesc4BnInOp(bn);\n        CHECK_NOTNULL(blob_desc);\n        return *blob_desc;\n      },\n      bns, index2logical_blob_desc_ptr));\n  return Maybe<void>::Ok();\n}\n\nMaybe<const BlobDesc> GetLogicalBlobDesc(\n    const std::unique_ptr<std::vector<std::shared_ptr<const BlobDesc>>>& index2logical_blob_desc,\n    int32_t index) {\n  CHECK_OR_RETURN(index2logical_blob_desc);\n  CHECK_LT_OR_RETURN(index, index2logical_blob_desc->size());\n  return index2logical_blob_desc->at(index);\n}\n\nMaybe<void> FillLogicalBlobDescSignature(\n    const PbRpf<std::string>& bns,\n    const std::unique_ptr<std::vector<std::shared_ptr<const BlobDesc>>>& index2logical_blob_desc,\n    PbMap<std::string, BlobDescProto>* bn_in_op2blob_desc) {\n  CHECK_OR_RETURN(index2logical_blob_desc);\n  CHECK_EQ_OR_RETURN(bns.size(), index2logical_blob_desc->size());\n  for (int32_t i = 0; i < bns.size(); ++i) {\n    index2logical_blob_desc->at(i)->ToProto(&(*bn_in_op2blob_desc)[bns.Get(i)]);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<bool> SupportNonContiguous(const Operator* op) {\n  const auto& op_conf = op->op_conf();\n  if (op_conf.has_user_conf()) {\n    const auto* registry =\n        user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_conf.user_conf().op_type_name());\n    CHECK_NOTNULL_OR_RETURN(registry)\n        << \"The op(operation) \" << op_conf.user_conf().op_type_name()\n        << \" is not found. 
Please check whether it has been registered correctly.\";\n    return registry->non_contiguous_supported;\n  }\n  return false;\n}\n\n}  // namespace\n\nMaybe<void> Operator::FillLogicalInBlobDesc(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  JUST(FillLogicalBlobDesc(BlobDesc4BnInOp, input_bns(), &input_index2logical_blob_desc_));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::FillLogicalInBlobDesc(\n    const std::function<const BlobDesc&(const std::string&)>& BlobDesc4BnInOp) {\n  JUST(FillLogicalBlobDesc(BlobDesc4BnInOp, input_bns(), &input_index2logical_blob_desc_));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::FillLogicalInBlobDesc(\n    const std::function<Maybe<const BlobDesc>(int32_t)>& BlobDesc4InputIndex) {\n  JUST(FillLogicalBlobDesc(BlobDesc4InputIndex, input_bns(), &input_index2logical_blob_desc_));\n  return Maybe<void>::Ok();\n}\n\nMaybe<const BlobDesc> Operator::GetLogicalBlobDesc4Ibn(const std::string& ibn) const {\n  return GetLogicalBlobDesc4InputIndex(JUST(GetInputIndex(ibn)));\n}\n\nMaybe<const BlobDesc> Operator::GetLogicalBlobDesc4InputIndex(int32_t index) const {\n  return GetLogicalBlobDesc(input_index2logical_blob_desc_, index);\n}\n\nMaybe<void> Operator::FillLogicalOutBlobDesc(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  JUST(FillLogicalBlobDesc(BlobDesc4BnInOp, output_bns(), &output_index2logical_blob_desc_));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::FillLogicalOutBlobDesc(\n    const std::function<const BlobDesc&(const std::string&)>& BlobDesc4BnInOp) {\n  JUST(FillLogicalBlobDesc(BlobDesc4BnInOp, output_bns(), &output_index2logical_blob_desc_));\n  return Maybe<void>::Ok();\n}\n\nMaybe<const BlobDesc> Operator::GetLogicalBlobDesc4Obn(const std::string& obn) const {\n  return GetLogicalBlobDesc4OutputIndex(JUST(GetOutputIndex(obn)));\n}\n\nMaybe<const BlobDesc> Operator::GetLogicalBlobDesc4OutputIndex(int32_t index) const {\n  return GetLogicalBlobDesc(output_index2logical_blob_desc_, index);\n}\n\nMaybe<const BlobDesc*> Operator::GetLogicalBlobDescPtr4OutputIndex(int32_t index) const {\n  CHECK_OR_RETURN(output_index2logical_blob_desc_);\n  CHECK_LT_OR_RETURN(index, output_index2logical_blob_desc_->size());\n  CHECK_OR_RETURN(output_index2logical_blob_desc_->at(index));\n  return output_index2logical_blob_desc_->at(index).get();\n}\n\nMaybe<const BlobDesc> Operator::GetLogicalBlobDesc4BnInOp(const std::string& bn) const {\n  const auto& it = bn2index_pair_.find(bn);\n  CHECK_OR_RETURN(it != bn2index_pair_.end());\n  if (it->second.first == BlobNameTag::kInputBlobName) {\n    return GetLogicalBlobDesc4InputIndex(it->second.second);\n  } else if (it->second.first == BlobNameTag::kOutputBlobName) {\n    return GetLogicalBlobDesc4OutputIndex(it->second.second);\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n}\n\nMaybe<void> Operator::InferLogicalOutBlobDescsIf() {\n  CHECK_OR_RETURN(input_index2logical_blob_desc_);\n  CHECK_OR_RETURN(!output_index2logical_blob_desc_);\n  std::vector<std::shared_ptr<BlobDesc>> output_logical_blob_desc_vec;\n  output_logical_blob_desc_vec.resize(output_bns().size());\n  for (auto& blob_desc : output_logical_blob_desc_vec) {\n    blob_desc.reset(new BlobDesc(DataType::kInvalidDataType, MemoryFormat::kContiguous));\n  }\n  std::vector<std::shared_ptr<BlobDesc>> in_logical_blob_desc_vec;\n  in_logical_blob_desc_vec.resize(input_bns().size());\n  auto BlobDesc4BnInOp = [&](const std::string& bn) -> BlobDesc* {\n    const 
auto& it = bn2index_pair_.find(bn);\n    CHECK(it != bn2index_pair_.end());\n    if (it->second.first == BlobNameTag::kInputBlobName) {\n      auto& ptr = in_logical_blob_desc_vec.at(it->second.second);\n      if (!ptr) { ptr.reset(new BlobDesc(*input_index2logical_blob_desc_->at(it->second.second))); }\n      return ptr.get();\n    } else if (it->second.first == BlobNameTag::kOutputBlobName) {\n      return output_logical_blob_desc_vec.at(it->second.second).get();\n    } else {\n      UNIMPLEMENTED();\n      return nullptr;\n    }\n  };\n  JUST(InferLogicalOutBlobDescs(BlobDesc4BnInOp, *JUST(GetOpParallelDesc())));\n  output_index2logical_blob_desc_.reset(new std::vector<std::shared_ptr<const BlobDesc>>());\n  output_index2logical_blob_desc_->resize(output_bns().size());\n  for (int32_t i = 0; i < output_bns().size(); ++i) {\n    auto& out_blob_desc = output_logical_blob_desc_vec[i];\n    // initialize stride by shape if the op does not support non-contiguous\n    if (!JUST(SupportNonContiguous(this))) {\n      out_blob_desc->set_stride(Stride(out_blob_desc->shape()));\n    }\n    CHECK_EQ_OR_RETURN(out_blob_desc->stride().size(), out_blob_desc->shape().size())\n        << Error::RuntimeError() << \"stride and shape size mismatch since stride is \"\n        << out_blob_desc->stride().ToString() << \" but shape is \"\n        << out_blob_desc->shape().ToString();\n    (*output_index2logical_blob_desc_)[i] = out_blob_desc;\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::InferBlobDescsIf(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx, const JobDesc* job_desc) const {\n  JUST(InferOutBlobDescsIf(GetBlobDesc4BnInOp, parallel_ctx));\n  JUST(InferInternalBlobDescsIf(GetBlobDesc4BnInOp, parallel_ctx, job_desc));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::InferOutBlobDescsIf(\n    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  JUST(InferOutBlobDescs(GetBlobDesc4BnInOp, parallel_ctx));\n  for (const auto& bn : output_bns()) {\n    BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(bn);\n    // initialize stride by shape if the op does not support non-contiguous\n    if (!JUST(SupportNonContiguous(this))) {\n      out_blob_desc->set_stride(Stride(out_blob_desc->shape()));\n    }\n    CHECK_EQ_OR_RETURN(out_blob_desc->stride().size(), out_blob_desc->shape().size())\n        << Error::RuntimeError() << \"stride and shape size mismatch since stride is \"\n        << out_blob_desc->stride().ToString() << \" but shape is \"\n        << out_blob_desc->shape().ToString();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  if (parallel_ctx->parallel_num() == 1) {\n    JUST(InferLogicalOutBlobDescs(GetBlobDesc4BnInOp, *JUST(GetOpParallelDesc())));\n  } else {\n    const auto& nd_sbp_signature = JUST(this->nd_sbp_signature());\n    const auto& parallel_desc = JUST(this->GetOpParallelDesc());\n    for (const auto& bn : input_bns()) {\n      const auto& nd_sbp = nd_sbp_signature->bn_in_op2nd_sbp().at(bn);\n      std::shared_ptr<const BlobDesc> in_logical = JUST(GetLogicalBlobDesc4Ibn(bn));\n      CHECK_OR_RETURN(\n          *JUST(GetPhysicalShape(in_logical->shape(), nd_sbp, *parallel_desc, *parallel_ctx))\n          == GetBlobDesc4BnInOp(bn)->shape());\n    }\n    for (const auto& bn : 
output_bns()) {\n      BlobDesc* desc = GetBlobDesc4BnInOp(bn);\n      *desc = *JUST(GetLogicalBlobDesc4Obn(bn));\n      const auto& nd_sbp = nd_sbp_signature->bn_in_op2nd_sbp().at(bn);\n      desc->set_shape(\n          *JUST(GetPhysicalShape(desc->shape(), nd_sbp, *parallel_desc, *parallel_ctx)));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::InferInternalBlobDescsIf(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx, const JobDesc* job_desc) const {\n  return InferInternalBlobDescs(GetBlobDesc4BnInOp, parallel_ctx, job_desc);\n}\n\nMaybe<void> Operator::InferInternalBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx, const JobDesc* job_desc) const {\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::InferInplaceObn2IbnIf(\n    HashMap<std::string, std::string>* mut_inplace_obn2ibn,\n    HashMap<std::string, std::string>* con_inplace_obn2ibn,\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  return InferInplaceObn2Ibn(mut_inplace_obn2ibn, con_inplace_obn2ibn, GetBlobDesc4BnInOp,\n                             parallel_ctx);\n}\n\nMaybe<void> Operator::InferInplaceObn2Ibn(\n    HashMap<std::string, std::string>* mut_inplace_obn2ibn,\n    HashMap<std::string, std::string>* con_inplace_obn2ibn,\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  for (const std::string& obn : output_bns()) {\n    const auto& obn_modifier = OutputBlobModifier4Obn(obn);\n    if (obn_modifier.has_mutable_inplace_ibn()) {\n      mut_inplace_obn2ibn->emplace(obn, obn_modifier.mutable_inplace_ibn());\n    } else if (obn_modifier.has_const_inplace_ibn()) {\n      con_inplace_obn2ibn->emplace(obn, obn_modifier.const_inplace_ibn());\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::FillInputBlobTimeShape(\n    const std::function<Maybe<const Shape>(int32_t)>& GetTimeShape4InputIndex) {\n  CHECK_OR_RETURN(!input_index2time_shape_);\n  input_index2time_shape_.reset(new std::vector<std::shared_ptr<const Shape>>());\n  input_index2time_shape_->reserve(input_bns().size());\n  for (int32_t i = 0; i < input_bns().size(); ++i) {\n    std::shared_ptr<const Shape> time_shape = JUST(GetTimeShape4InputIndex(i));\n    if ((!input_blob_fastest_time_shape_)\n        || input_blob_fastest_time_shape_->elem_cnt() < time_shape->elem_cnt()) {\n      input_blob_fastest_time_shape_ = time_shape;\n    }\n    input_index2time_shape_->emplace_back(time_shape);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::InferOpTimeShapeIf() {\n  CHECK_OR_RETURN(!op_time_shape_);\n  CHECK_OR_RETURN(input_index2time_shape_);\n  auto GetTimeShape4BnInOp = [&](const std::string& ibn) -> Maybe<const Shape> {\n    const auto& it = bn2index_pair_.find(ibn);\n    CHECK_OR_RETURN(it != bn2index_pair_.end());\n    CHECK_EQ_OR_RETURN(it->second.first, kInputBlobName);\n    return input_index2time_shape_->at(it->second.second);\n  };\n  JUST(InferOpTimeShape(GetTimeShape4BnInOp, &op_time_shape_));\n  if (input_blob_fastest_time_shape_\n      && input_blob_fastest_time_shape_->elem_cnt() > op_time_shape_->elem_cnt()) {\n    input_output_fastest_time_shape_ = input_blob_fastest_time_shape_;\n  } else {\n    input_output_fastest_time_shape_ = op_time_shape_;\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> 
Operator::InferOpTimeShape(\n    const std::function<Maybe<const Shape>(const std::string&)>& GetTimeShape4BnInOp,\n    std::shared_ptr<const Shape>* time_shape) const {\n  if (!input_bns().empty()) {\n    std::shared_ptr<const Shape> first_time_shape = input_index2time_shape_->at(0);\n    for (int64_t i = 1; i < input_bns().size(); ++i) {\n      CHECK_EQ_OR_RETURN(*input_index2time_shape_->at(i), *first_time_shape);\n    }\n    *time_shape = first_time_shape;\n  } else {\n    *time_shape = std::make_shared<const Shape>(Shape({1, 1}));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<const Shape> Operator::GetOpTimeShape() const {\n  CHECK_OR_RETURN(op_time_shape_);\n  return op_time_shape_;\n}\n\nMaybe<const Shape> Operator::GetInputBlobFastestTimeShape() const {\n  return input_blob_fastest_time_shape_;\n}\n\nMaybe<const Shape> Operator::GetInputOutputFastestTimeShape() const {\n  return input_output_fastest_time_shape_;\n}\n\nMaybe<void> Operator::GetSbpSignaturesIf(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const {\n  JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, hierarchy_value, sbp_sig_list));\n  SbpSignatureBuilder()\n      .Broadcast(input_bns())\n      .Broadcast(output_bns())\n      .Build(sbp_sig_list->mutable_sbp_signature()->Add());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::EnumerateNdSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::GetNdSbpSignatureList(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {\n  // Get 1D sbp signature list\n  HashMap<int32_t, SbpSignatureList> hierarchy_value2sbp_sig_list;\n  // hierarchy value is the value at the dimension corresponding to the current SBP\n  // For example, 2 machines, 4 gpus per machine, hierarchy = [2, 4]\n  // Suppose we have nd_sbp = (S0, B)\n  // The hierarchy value corresponding to S0 is 2\n  // The hierarchy value corresponding to B is 4.\n  for (int32_t hierarchy_value : *parallel_desc.hierarchy()) {\n    if (hierarchy_value2sbp_sig_list.find(hierarchy_value) == hierarchy_value2sbp_sig_list.end()) {\n      auto* sbp_sig_list = &hierarchy_value2sbp_sig_list[hierarchy_value];\n      JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, hierarchy_value, sbp_sig_list));\n      CHECK_GT_OR_RETURN(sbp_sig_list->sbp_signature_size(), 0)\n          << op_name()\n          << \" gets no sbp signature from GetSbpSignaturesIf function for hierarchy value: \"\n          << hierarchy_value;\n    }\n  }\n\n  int32_t sbp_dimension = parallel_desc.hierarchy()->NumAxes();\n  NdSbpSignature nd_sbp_sig;\n  SbpSignatureToNdSbpSignature(hierarchy_value2sbp_sig_list.begin()->second.sbp_signature(0),\n                               &nd_sbp_sig);\n  ResizeNdSbpSignature(nd_sbp_sig, sbp_dimension);\n  // ND sbp signature list would be direct product of 1D sbp signatures\n  CHECK_OR_RETURN(nd_sbp_sig_list->empty());\n  DfsGetNdSbpSignature(nd_sbp_sig, 0, sbp_dimension, *parallel_desc.hierarchy(),\n                       hierarchy_value2sbp_sig_list, nd_sbp_sig_list);\n  JUST(EnumerateNdSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> 
Operator::GetValidNdSbpSignatureList(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list,\n    bool check_output) const {\n  JUST(GetNdSbpSignatureList(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list));\n  // Keep only the valid Nd SBPs\n  JUST(FilterNdSbpSignatureListByLogicalShape(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list,\n                                              check_output));\n  CHECK_OR_RETURN(nd_sbp_sig_list->size() > 0)\n      << \"Empty sbp signature after filtering for \" << op_name();\n  return Maybe<void>::Ok();\n}\n\nOperator::DumpNdSbpSignatureForOpConfFn Operator::GetDumpNdSbpSignatureForOpConfFn() const {\n  return [](const NdSbpSignature& nd_sbp_sig, OperatorConf* op_conf) -> Maybe<void> {\n    return Maybe<void>::Ok();\n  };\n}\n\nvoid Operator::ForEachBnInOp(const std::function<void(const std::string&)>& Handler) const {\n  for (const std::string& bn_in_op : input_bns()) { Handler(bn_in_op); }\n  for (const std::string& bn_in_op : output_bns()) { Handler(bn_in_op); }\n  for (const std::string& bn_in_op : tmp_bns()) { Handler(bn_in_op); }\n}\n\nMaybe<void> Operator::FillSbpSignature(const SbpSignature& sbp_signature) {\n  NdSbpSignature nd_sbp_signature;\n  SbpSignatureToNdSbpSignature(sbp_signature, &nd_sbp_signature);\n  JUST(FillNdSbpSignature(nd_sbp_signature));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::FillNdSbpSignature(const NdSbpSignature& signature) {\n  CHECK_OR_RETURN(!nd_sbp_signature_);\n  CHECK_OR_RETURN(!sbp_signature_);\n  nd_sbp_signature_.reset(new NdSbpSignature(signature));\n  CHECK_OR_RETURN(op_parallel_desc_);\n  if (op_parallel_desc_->hierarchy()->NumAxes() == 1) {\n    SbpSignature sbp_signature;\n    NdSbpSignatureToSbpSignature(signature, &sbp_signature);\n    sbp_signature_.reset(new SbpSignature(sbp_signature));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::InferSbpSignatureIf(\n    const SbpSignature& sbp_sig_conf,\n    const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n    const std::function<Maybe<const SbpInferHint*>(const std::string&)>& SbpInferHint4Ibn,\n    const ParallelDesc& parallel_desc) {\n  SbpSignature signature;\n\n  JUST(InferSbpSignature(&signature, sbp_sig_conf, CalcOrderValue4SbpSig, SbpInferHint4Ibn,\n                         parallel_desc));\n\n  JUST(FillSbpSignature(signature));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::InferSbpSignature(\n    SbpSignature* inferred_sbp_signature, const SbpSignature& sbp_sig_conf,\n    const HashMap<std::string, SbpInferHint>& ibn2sbp_infer_hint) const {\n  auto SbpInferHint4Ibn = [&](const std::string& ibn) -> Maybe<const SbpInferHint*> {\n    auto it = ibn2sbp_infer_hint.find(ibn);\n    if (it == ibn2sbp_infer_hint.end()) {\n      return Error::CheckFailedError()\n             << \"cannot find corresponding SbpInferHint for input_blob_name: \" << ibn;\n    }\n    return &(it->second);\n  };\n  std::function<int32_t(const SbpSignature&)> CalcOrderValue4SbpSig;\n  auto OrderValue4SourceDefaultSplit0 = [&](const std::string& bn,\n                                            const SbpParallel& sbp_parallel) -> int32_t {\n    return -1 * (sbp_parallel.has_split_parallel() && sbp_parallel.split_parallel().axis() == 0);\n  };\n  auto OrderValue4SbpHint = [&](const std::string& ibn,\n                                const SbpParallel& sbp_parallel) -> int32_t {\n    const auto* hint = 
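/* the producer's inferred sbp and placement for this input */ 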
CHECK_JUST(SbpInferHint4Ibn(ibn));\n    // NOTE(chengcheng): one-to-one connection.\n    return -10\n           * (hint->sbp_parallel() == sbp_parallel\n              && hint->parallel_desc().parallel_num() == op_parallel_desc_->parallel_num());\n  };\n\n  if (sbp_sig_conf.bn_in_op2sbp_parallel().empty()) {\n    CalcOrderValue4SbpSig = [&](const SbpSignature& sbp_signature) -> int32_t {\n      int32_t order_value = 0;\n      if (input_bns().size() > 0) {\n        // NOTE(chengcheng): non-source ops are ordered only by input sbp match.\n        for (const auto& ibn : input_bns()) {\n          const auto& sbp_parallel_it = sbp_signature.bn_in_op2sbp_parallel().find(ibn);\n          CHECK(sbp_parallel_it != sbp_signature.bn_in_op2sbp_parallel().end());\n          order_value += OrderValue4SbpHint(ibn, sbp_parallel_it->second);\n        }\n      } else {\n        // NOTE(chengcheng): source ops default to split(0).\n        //   ONLY data source ops consider order here; variable op sbp is set by the user.\n        for (const auto& obn : output_bns()) {\n          const auto& sbp_parallel_it = sbp_signature.bn_in_op2sbp_parallel().find(obn);\n          CHECK(sbp_parallel_it != sbp_signature.bn_in_op2sbp_parallel().end());\n          order_value += OrderValue4SourceDefaultSplit0(obn, sbp_parallel_it->second);\n        }\n      }\n      return order_value;\n    };\n  } else {\n    CalcOrderValue4SbpSig = [](const SbpSignature&) -> int32_t { return 0; };\n  }\n\n  JUST(InferSbpSignature(inferred_sbp_signature, sbp_sig_conf, CalcOrderValue4SbpSig,\n                         SbpInferHint4Ibn, *op_parallel_desc_));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::FilterAndCheckValidSbpSignatureListByLogicalShape(\n    const SbpSignatureList& total_sbp_sig_list,\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    const ParallelDesc& parallel_desc, SbpSignatureList* valid_sbp_sig_list) const {\n  auto GetOpDebugShapeStr = [&]() -> std::string {\n    std::string ret = \"\";\n    if (op_conf().has_user_conf()) {\n      ret += (\"op_type_name = \" + op_conf().user_conf().op_type_name() + \", \");\n    }\n    for (const auto& ibn : input_bns()) {\n      ret +=\n          (\" ibn:(\" + ibn + \") lbn:(\" + GenLogicalBlobName(BnInOp2Lbi(ibn))\n           + \") logical_shape = \" + CHECK_JUST(LogicalBlobDesc4Ibn(ibn)).shape().DebugStr() + \", \");\n    }\n    return ret;\n  };\n  for (const auto& sbp_signature : total_sbp_sig_list.sbp_signature()) {\n    bool is_valid = true;\n\n    for (const auto& ibn : input_bns()) {\n      const auto& sbp_parallel_it = sbp_signature.bn_in_op2sbp_parallel().find(ibn);\n      CHECK_OR_RETURN(sbp_parallel_it != sbp_signature.bn_in_op2sbp_parallel().end());\n      const SbpParallel& sbp_parallel = sbp_parallel_it->second;\n      const Shape& logical_shape = JUST(LogicalBlobDesc4Ibn(ibn)).shape();\n      // NOTE(chengcheng): disable split when logical shape cannot split at this axis\n      if (sbp_parallel.has_split_parallel()) {\n        const int64_t axis = sbp_parallel.split_parallel().axis();\n        CHECK_OR_RETURN(axis >= 0 && axis < logical_shape.NumAxes())\n            << \"The sbp signature is invalid because the split axis >= the number of shape axes. In op: [\"\n            << op_name() << \"] ibn: (\" << ibn << \") the split axis = \" << axis\n            << \". And the logical_shape = \" << logical_shape.DebugStr()\n            << \". 
This Op debug str = {\" << GetOpDebugShapeStr() << \"}\";\n        if (logical_shape.At(axis) < parallel_desc.parallel_num()) {\n          // NOTE(chengcheng): cannot split at this axis!\n          is_valid = false;\n          break;\n        }\n      }\n    }\n    if (is_valid) { *valid_sbp_sig_list->mutable_sbp_signature()->Add() = sbp_signature; }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::FilterNdSbpSignatureListByLogicalShape(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list,\n    bool check_output) const {\n  auto FilterSbp4Blobs = [&](const PbRpf<std::string>& bns,\n                             const NdSbpSignature& nd_sbp_sig) -> Maybe<bool> {\n    // {in_0 : (S(6), B), in_1 : (S(0), S(1)), out : (B, S(1))}\n    // Look through the blob names, e.g. in_0 and in_1.\n    for (const auto& ibn : bns) {\n      const auto& nd_sbp_it = nd_sbp_sig.bn_in_op2nd_sbp().find(ibn);\n      // Error out if the blob name is missing from the signature.\n      CHECK_OR_RETURN(nd_sbp_it != nd_sbp_sig.bn_in_op2nd_sbp().end());\n      const auto& nd_sbp = nd_sbp_it->second;\n      Shape logical_shape = JUST(LogicalBlobDesc4Ibn(ibn)).shape();\n      const auto& parallel_hierarchy = parallel_desc.hierarchy();\n      // Treat 1D sbp and nD sbp differently. Please refer to\n      // JobBuildAndInferCtx::CheckOpBlobSplitability\n      // for more details.\n      if (JUST(FilterNdSbpByLogicalShape(nd_sbp, logical_shape, *parallel_hierarchy))) {\n        return true;\n      }\n    }\n    return false;\n  };\n  // Go down from the tail to the head, since we might drop the tail.\n  for (int32_t sbp_id = nd_sbp_sig_list->size() - 1; sbp_id >= 0; sbp_id--) {\n    if (JUST(FilterSbp4Blobs(input_bns(), JUST(VectorAt(*nd_sbp_sig_list, sbp_id))))\n        || (check_output\n            && JUST(FilterSbp4Blobs(output_bns(), JUST(VectorAt(*nd_sbp_sig_list, sbp_id)))))) {\n      // Remove the Nd SBP candidate\n      (*nd_sbp_sig_list)[sbp_id] = JUST(VectorAt(*nd_sbp_sig_list, nd_sbp_sig_list->size() - 1));\n      nd_sbp_sig_list->pop_back();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::GreedilyFindMinCopyCostNdSbp(\n    NdSbpSignature* nd_sbp_signature,\n    const std::function<Maybe<const NdSbpInferHint*>(const std::string&)>& NdSbpInferHint4Ibn,\n    const std::vector<NdSbpSignature>& nd_sbp_sig_list) const {\n  int32_t select_sbp_idx = -1;\n  double min_copy_cost = GetValidMaxCopyCost();\n  // We notice that we have a lot of inquiries asking for the cost.\n  // If the candidate list only has one entry, select it to reduce the inquiries.\n  // Normally, we support all the sbp combinations for boxing. Therefore, we do not need to worry\n  // about the case where we cannot transfer to this sbp signature. 
Even if we do not support such\n  // transfer, a report would be sent in boxing_with_middle_nodes.cpp.\n  if (nd_sbp_sig_list.size() == 1) {\n    select_sbp_idx = 0;\n  } else {\n    std::vector<bool> requires_same_sbp(input_bns().size());\n    for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) {\n      const auto& ibn = input_bns().at(ibn_id);\n      const auto& blob_modifier_ = InputBlobModifier4Ibn(ibn);\n      requires_same_sbp[ibn_id] =\n          (blob_modifier_.has_is_mutable() && blob_modifier_.is_mutable())\n          || NotSupportBoxingDataType(\n              JUST(NdSbpInferHint4Ibn(ibn))->logical_blob_desc().data_type());\n    }\n    // SBP_INFER_RULE_TAG = 1, pick the sbp signature which matches all the producers\n    //                          or has the lowest cost\n    // SBP_INFER_RULE_TAG = 2, pick the sbp signature which matches as much as possible\n    // SBP_INFER_RULE_TAG = 3, pick the sbp signature which has the lowest cost\n    static int32_t infer_rule = ParseIntegerFromEnv(\"SBP_INFER_RULE_TAG\", 1);\n    for (int32_t i = 0; i < nd_sbp_sig_list.size(); ++i) {\n      double total_copy_cost = 0.0;\n      double sum_priority_ratio = 0.0;\n      // The initial ratio does not need to be a large one,\n      // since any copy cost less than infinity would reset min_sum_priority_ratio.\n      double min_sum_priority_ratio = 0.0;\n      bool same_sbp_before_reduce = true;\n      for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) {\n        const auto& ibn = input_bns().at(ibn_id);\n        const auto& producer_infer_hint4ibn = JUST(NdSbpInferHint4Ibn(ibn));\n        same_sbp_before_reduce &= producer_infer_hint4ibn->nd_sbp()\n                                  == JUST(VectorAt(nd_sbp_sig_list, i)).bn_in_op2nd_sbp().at(ibn);\n        // Skip the computation of priority ratio if SBP_INFER_RULE_TAG = 3\n        if (infer_rule != SbpInferRuleTag::kMinCost) {\n          double priority_ratio = ComputeSbpInferPriority(\n              producer_infer_hint4ibn->nd_sbp(),\n              JUST(VectorAt(nd_sbp_sig_list, i)).bn_in_op2nd_sbp().at(ibn),\n              producer_infer_hint4ibn->parallel_desc(), *JUST(GetParallelDesc4BnInOp(ibn)),\n              requires_same_sbp[ibn_id], producer_infer_hint4ibn->logical_blob_desc().shape());\n          sum_priority_ratio += priority_ratio;\n          // We do not accept any blob which has a priority ratio greater than 1\n          // (compared against 1.5 to tolerate floating-point error)\n          if (priority_ratio > 1.5) {\n            total_copy_cost = GetMaxVal<float>();\n            break;\n          }\n          // If SBP_INFER_RULE_TAG = 2 and the input blob has a matched sbp,\n          // skip the computation of the transfer cost\n          if (infer_rule == SbpInferRuleTag::kMatchAMAP && priority_ratio == 0.0) { continue; }\n        }\n        // Compute the cost and add them up\n        total_copy_cost += JUST(ComputeCopyCostBetweenNdSbp(\n            producer_infer_hint4ibn->nd_sbp(),\n            JUST(VectorAt(nd_sbp_sig_list, i)).bn_in_op2nd_sbp().at(ibn),\n            producer_infer_hint4ibn->logical_blob_desc(), producer_infer_hint4ibn->parallel_desc(),\n            *JUST(GetParallelDesc4BnInOp(ibn)), requires_same_sbp[ibn_id]));\n        // Reduce inquiries when the current cost is larger than the minimum cost\n        // For SBP_INFER_RULE_TAG = 1, do not prune it since the all-matched case\n        // might have larger cost.\n        if (infer_rule != SbpInferRuleTag::kAllMatch && total_copy_cost > min_copy_cost) { break; }\n      }\n      // For 
SBP_INFER_RULE_TAG = 1, select the all-matched case if found\n      if (infer_rule == SbpInferRuleTag::kAllMatch && same_sbp_before_reduce\n          && sum_priority_ratio == 0.0) {\n        select_sbp_idx = i;\n        break;\n      }\n      // Otherwise, select the case with the lowest cost\n      if (total_copy_cost < min_copy_cost * kFloatDeviationMinus      // Strict less than\n          || (total_copy_cost <= min_copy_cost * kFloatDeviationPlus  // Loose equal\n              && sum_priority_ratio < min_sum_priority_ratio)) {\n        select_sbp_idx = i;\n        min_copy_cost = total_copy_cost;\n        min_sum_priority_ratio = sum_priority_ratio;  // NOLINT(clang-analyzer-deadcode.DeadStores)\n      }\n    }\n    // Can't find any available sbp\n    if (select_sbp_idx == -1) {\n      std::ostringstream err;\n      err << \"op: `\" << op_name() << \"` can't find available sbp signature.\" << std::endl;\n      err << \"candidate nd sbp signatures are: \"\n          << *JUST(NdSbpSignatureListToString(nd_sbp_sig_list, input_bns(), output_bns()));\n      err << \", but inputs sbp are:\";\n      for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) {\n        const auto& ibn = input_bns().at(ibn_id);\n        const NdSbp& nd_sbp = JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp();\n        err << \" \" << ibn << \": \" << NdSbpToString(nd_sbp);\n        if (requires_same_sbp[ibn_id]) { err << \" [ transfer disabled ]\"; }\n        err << \";\";\n      }\n\n      return Error::RuntimeError() << err.str();\n    }\n  }\n  nd_sbp_signature->CopyFrom(nd_sbp_sig_list.at(select_sbp_idx));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::InferSbpSignature(\n    SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n    const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n    std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n    const ParallelDesc& parallel_desc) const {\n  // get op sbp signatures\n  auto LogicalBlobDesc4Ibn = [&](const std::string& ibn) -> Maybe<const BlobDesc&> {\n    const SbpInferHint* sbp_infer_hint = JUST(SbpInferHint4Ibn(ibn));\n    return Maybe<const BlobDesc&>(sbp_infer_hint->logical_blob_desc());\n  };\n  SbpSignatureList valid_sbp_sig_list;\n  {\n    SbpSignatureList sbp_sig_candidates;\n    // For 1d sbp, hierarchy value = parallel num\n    JUST(\n        GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc.parallel_num(), &sbp_sig_candidates));\n    // filter sbp signatures by logical shape\n    JUST(FilterAndCheckValidSbpSignatureListByLogicalShape(sbp_sig_candidates, LogicalBlobDesc4Ibn,\n                                                           parallel_desc, &valid_sbp_sig_list));\n  }\n  // filter sbp signatures by sbp signature conf\n  SbpSignatureList filtered_sbp_sigs_by_conf;\n  FilterSbpSignatureList(valid_sbp_sig_list, sbp_sig_conf, &filtered_sbp_sigs_by_conf);\n  CHECK_GT_OR_RETURN(filtered_sbp_sigs_by_conf.sbp_signature_size(), 0)\n      << op_name() << \" has no matching sbp after filtering valid sbp list \"\n      << valid_sbp_sig_list.DebugString() << \" with sbp hint \" << sbp_sig_conf.DebugString();\n  if (filtered_sbp_sigs_by_conf.sbp_signature_size() == 1) {\n    *sbp_signature = *filtered_sbp_sigs_by_conf.sbp_signature().begin();\n    return Maybe<void>::Ok();\n  }\n  // sort sbp signatures by copy cost, then return the one with the least cost\n  HashMap<std::string, const SbpParallel*> ibn2producer_sbp_parallel;\n  for (const auto& ibn : input_bns()) {\n    
ibn2producer_sbp_parallel[ibn] = &(JUST(SbpInferHint4Ibn(ibn))->sbp_parallel());\n  }\n  std::vector<const SbpSignature*> sorted_sbp_signatures;\n  SortSbpSignatureListByCopyCost(filtered_sbp_sigs_by_conf, input_bns(), SbpInferHint4Ibn,\n                                 CalcOrderValue4SbpSig, &sorted_sbp_signatures);\n  *sbp_signature = *sorted_sbp_signatures.at(0);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::InferNdSbpSignatureIf(\n    const NdSbpSignature& nd_sbp_constraints, const ParallelDesc& parallel_desc,\n    std::function<Maybe<const NdSbpInferHint*>(const std::string&)> NdSbpInferHint4Ibn) {\n  NdSbpSignature nd_sbp_signature;\n  JUST(InferNdSbpSignature(&nd_sbp_signature, nd_sbp_constraints, parallel_desc,\n                           NdSbpInferHint4Ibn));\n  JUST(FillNdSbpSignature(nd_sbp_signature));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Operator::InferNdSbpSignature(\n    NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints,\n    const ParallelDesc& parallel_desc,\n    std::function<Maybe<const NdSbpInferHint*>(const std::string&)> NdSbpInferHint4Ibn) const {\n  const auto& parallel_hierarchy = parallel_desc.hierarchy();\n  CHECK_GT(parallel_hierarchy->NumAxes(), 0);\n  if (parallel_hierarchy->NumAxes() == 1) {\n    // Infer 1d sbp\n    HashMap<std::string, SbpInferHint> ibn2sbp_infer_hint;\n    for (const auto& ibn : input_bns()) {\n      const NdSbpInferHint* hint = JUST(NdSbpInferHint4Ibn(ibn));\n      ibn2sbp_infer_hint.emplace(ibn,\n                                 SbpInferHint(&hint->parallel_desc(), &hint->logical_blob_desc(),\n                                              &hint->nd_sbp().sbp_parallel(0)));\n    }\n    SbpSignature sbp_constraints;\n    NdSbpSignatureToSbpSignature(nd_sbp_constraints, &sbp_constraints);\n    SbpSignature sbp_signature;\n    JUST(InferSbpSignature(&sbp_signature, sbp_constraints, ibn2sbp_infer_hint));\n    SbpSignatureToNdSbpSignature(sbp_signature, nd_sbp_signature);\n    return Maybe<void>::Ok();\n  } else {\n    // Infer nd sbp\n    const auto LogicalBlobDesc4Ibn = [&](const std::string& ibn) -> Maybe<const BlobDesc&> {\n      return JUST(NdSbpInferHint4Ibn(ibn))->logical_blob_desc();\n    };\n    std::vector<NdSbpSignature> nd_sbp_sig_list;\n    JUST(GetValidNdSbpSignatureList(LogicalBlobDesc4Ibn, parallel_desc, &nd_sbp_sig_list,\n                                    /*check_output=*/false));\n    // Filter nd_sbp according to `nd_sbp_constraints`\n    for (int32_t i = nd_sbp_sig_list.size() - 1; i >= 0; --i) {\n      // If any blob does not match nd_sbp_constraints, the candidate nd_sbp will be deleted.\n      if (/*not_match=*/std::any_of(input_bns().begin(), input_bns().end(), [&](const auto& ibn) {\n        const auto nd_sbp_constraints_it = nd_sbp_constraints.bn_in_op2nd_sbp().find(ibn);\n        if (nd_sbp_constraints_it != nd_sbp_constraints.bn_in_op2nd_sbp().end()) {\n          return nd_sbp_sig_list.at(i).bn_in_op2nd_sbp().at(ibn) != nd_sbp_constraints_it->second;\n        }\n        return false;\n      })) {\n        nd_sbp_sig_list.at(i) = nd_sbp_sig_list.back();\n        nd_sbp_sig_list.pop_back();\n      }\n    }\n    CHECK_OR_RETURN(!nd_sbp_sig_list.empty())\n        << \"Empty sbp signature after filtering for \" << op_name();\n    JUST(GreedilyFindMinCopyCostNdSbp(nd_sbp_signature, NdSbpInferHint4Ibn, nd_sbp_sig_list));\n    return Maybe<void>::Ok();\n  }\n}\n\nMaybe<void> Operator::InferLocalSignatureIf(\n    std::function<Maybe<const LocalSigInferHint*>(const std::string&)> 
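/* returns each input's local-vs-global parallel view hint */ 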
LocalSigInferHint4Ibn,\n    bool is_local_parallel_view_conf, const ParallelDesc& parallel_desc) {\n  return InferLocalSignature(std::move(LocalSigInferHint4Ibn), is_local_parallel_view_conf,\n                             parallel_desc);\n}\n\n// Compute the time complexity for the given blob descriptions and sbp signature.\n// A function is used in place of a HashMap from logical blob id to blob description pointer.\nMaybe<double> Operator::GetComputeComplexity(\n    NdSbpSignature* sbp_signature,\n    std::function<const BlobDesc&(const std::string& bn)> logical_blob_desc4bn,\n    const ParallelDesc& parallel_desc) const {\n  const auto& sbp_bn_in_op2nd_sbp = sbp_signature->bn_in_op2nd_sbp();\n  double complexity = 0;\n  const auto& parallel_hierarchy = *parallel_desc.hierarchy();\n\n  auto ComputeComplexity4Blobs = [&](const PbRpf<std::string>& bns) -> Maybe<void> {\n    for (const auto& bn : bns) {\n      const BlobDesc& logical_blob_desc = logical_blob_desc4bn(bn);\n      const NdSbp& nd_sbp = sbp_bn_in_op2nd_sbp.at(bn);\n      CHECK_EQ_OR_RETURN(nd_sbp.sbp_parallel_size(), parallel_hierarchy.NumAxes())\n          << \"At this moment, the dimension of nd SBP should be equal to the depth of hierarchy in \"\n          << \"parallel description.\";\n\n      double total_cost = logical_blob_desc.shape().elem_cnt();\n      for (int32_t sbp_dim = 0; sbp_dim < nd_sbp.sbp_parallel_size(); sbp_dim++) {\n        const auto& sbp = nd_sbp.sbp_parallel(sbp_dim);\n        if (sbp.has_split_parallel()) {\n          const int64_t axis = sbp.split_parallel().axis();\n          if (axis >= logical_blob_desc.shape().NumAxes()\n              || logical_blob_desc.shape().At(axis) < parallel_hierarchy.At(sbp_dim)) {\n            complexity = GetMaxVal<float>();\n            return Maybe<void>::Ok();\n          } else {\n            total_cost /= parallel_hierarchy.At(sbp_dim);\n          }\n        }\n      }\n      complexity += total_cost;\n    }\n    return Maybe<void>::Ok();\n  };\n  JUST(ComputeComplexity4Blobs(input_bns()));\n  JUST(ComputeComplexity4Blobs(output_bns()));\n  return complexity;\n}\n\nstd::string DebugString4LocalHint(\n    std::function<Maybe<const LocalSigInferHint*>(const std::string&)> LocalSigInferHint4Ibn,\n    const Operator& op) {\n  std::string ret;\n  for (const auto& ibn : op.input_bns()) {\n    const auto& infer_hint = *CHECK_JUST(LocalSigInferHint4Ibn(ibn));\n    bool is_local = infer_hint.is_local_parallel_view();\n    ret += \"arg: \" + ibn + \", is_local: \" + (is_local ? 
\"true\" : \"false\") + \"\\n\";\n  }\n  return ret;\n}\n\nMaybe<void> Operator::InferLocalSignature(\n    std::function<Maybe<const LocalSigInferHint*>(const std::string&)>\n        LocalSigInferHint4Ibn,  // NOLINT\n    bool is_local_parallel_view_conf, const ParallelDesc& parallel_desc) {\n  HashSet<bool> is_local_parallel_view_values;\n  for (const auto& ibn : input_bns()) {\n    const auto& infer_hint = *JUST(LocalSigInferHint4Ibn(ibn));\n    is_local_parallel_view_values.insert(infer_hint.is_local_parallel_view());\n  }\n  CHECK_LE_OR_RETURN(is_local_parallel_view_values.size(), 1)\n      << \"mixed parallel_views are disallowed.\"\n      << \"\\n=========== is_mirrrored_conf ===========\\n\"\n      << DebugString4LocalHint(LocalSigInferHint4Ibn, *this) << \"\\n=========== op_cnf ===========\\n\"\n      << op_conf().DebugString();\n  if (is_local_parallel_view_values.size() == 1) {\n    is_local_parallel_view_conf = *is_local_parallel_view_values.begin();\n  }\n  if (is_local_parallel_view_conf) {\n    for (const auto& ibn : input_bns()) {\n      const auto& infer_hint = *JUST(LocalSigInferHint4Ibn(ibn));\n      CHECK_EQ_OR_RETURN(infer_hint.parallel_desc().parallel_num(), parallel_desc.parallel_num());\n    }\n  }\n  const auto SetIsLocalParallel = [&](const std::string& bn_in_op) {\n    if (is_local_parallel_view_conf) {\n      MutOptLocalParallel(bn_in_op)->mutable_local_parallel();\n    } else {\n      MutOptLocalParallel(bn_in_op)->clear_local_parallel();\n    }\n  };\n  for (const auto& ibn : input_bns()) { SetIsLocalParallel(ibn); }\n  for (const auto& obn : output_bns()) { SetIsLocalParallel(obn); }\n  return Maybe<void>::Ok();\n}\n\nMaybe<const SbpSignature*> Operator::sbp_signature() const {\n  CHECK_OR_RETURN(sbp_signature_) << \"sbp signature not infered\";\n  return sbp_signature_.get();\n}\n\nMaybe<const NdSbpSignature*> Operator::nd_sbp_signature() const {\n  CHECK_OR_RETURN(nd_sbp_signature_) << \"parallel distribution signature not infered\";\n  return nd_sbp_signature_.get();\n}\n\nBlobLastUsedSignature* Operator::mut_blob_last_used_signature() {\n  if (!blob_last_used_signature_) { blob_last_used_signature_.reset(new BlobLastUsedSignature()); }\n  return blob_last_used_signature_.get();\n}\n\nBlobBackwardUsedSignature* Operator::mut_blob_backward_used_signature() {\n  if (!blob_backward_used_signature_) {\n    blob_backward_used_signature_.reset(new BlobBackwardUsedSignature());\n  }\n  return blob_backward_used_signature_.get();\n}\n\nMaybe<const SbpParallel*> Operator::SbpParallel4BnInOp(const std::string& bn_in_op) const {\n  CHECK_OR_RETURN(sbp_signature_) << \"sbp signature not infered\";\n  const auto& map = sbp_signature_->bn_in_op2sbp_parallel();\n  const auto& iter = map.find(bn_in_op);\n  CHECK_OR_RETURN(iter != map.end()) << \"blob_name \" << bn_in_op << \" not found in sbp signature\";\n  return &iter->second;\n}\n\nMaybe<const NdSbp*> Operator::NdSbp4BnInOp(const std::string& bn_in_op) const {\n  CHECK_OR_RETURN(nd_sbp_signature_) << \"parallel distribution signature not infered\";\n  const auto& map = nd_sbp_signature_->bn_in_op2nd_sbp();\n  const auto& iter = map.find(bn_in_op);\n  CHECK_OR_RETURN(iter != map.end()) << \"op_name \" << op_name() << \" blob_name \" << bn_in_op\n                                     << \" not found in parallel distribution\";\n  return &iter->second;\n}\n\nMaybe<const OptLocalParallel*> Operator::OptLocalParallel4BnInOp(\n    const std::string& bn_in_op) const {\n  CHECK_OR_RETURN(local_signature_) << \"local signature 
not infered\";\n  const auto& map = local_signature_->bn_in_op2opt_local_parallel();\n  const auto& iter = map.find(bn_in_op);\n  CHECK_OR_RETURN(iter != map.end()) << \"blob_name \" << bn_in_op << \" not found in local signature\";\n  return &iter->second;\n}\n\nOptLocalParallel* Operator::MutOptLocalParallel(const std::string& bn_in_op) {\n  if (!local_signature_) { local_signature_.reset(new LocalSignature()); }\n  auto* map = local_signature_->mutable_bn_in_op2opt_local_parallel();\n  return &(*map)[bn_in_op];\n}\n\nnamespace {\n\nbool HasBlobDescWithField(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n                          const PbRpf<std::string>& bn_in_ops,\n                          std::function<bool(const BlobDesc*)> Predicator4BlobDesc) {\n  for (const std::string& bn_in_op : bn_in_ops) {\n    const BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn_in_op);\n    if (blob_desc && Predicator4BlobDesc(blob_desc)) { return true; }\n  }\n  return false;\n}\n\n}  // namespace\n\nvoid Operator::GenKernelConf(\n    const std::function<const BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {\n  auto* dtype_signature = kernel_conf->mutable_dtype_signature();\n  for (const std::string& ibn : input_bns()) {\n    const BlobDesc* blob_desc = GetBlobDesc4BnInOp(ibn);\n    if (blob_desc == nullptr) { continue; }\n    (*dtype_signature->mutable_name2dtype())[ibn] = blob_desc->data_type();\n  }\n\n  CHECK_JUST(ToOpAttribute(kernel_conf->mutable_op_attribute()));\n  kernel_conf->set_all_blobs_are_static(\n      !HasBlobDescWithField(GetBlobDesc4BnInOp, output_bns(),\n                            [](const BlobDesc* blob_desc) { return blob_desc->is_dynamic(); }));\n  {\n    DataType data_type = GetDataTypeFromBnInOpVec(GetBlobDesc4BnInOp, output_bns());\n    if (data_type == DataType::kInvalidDataType) {\n      data_type = GetDataTypeFromBnInOpVec(GetBlobDesc4BnInOp, input_bns());\n    }\n    kernel_conf->set_data_type(data_type);\n  }\n\n  if (parallel_ctx != nullptr) { *(kernel_conf->mutable_parallel_ctx()) = *parallel_ctx; }\n\n  VirtualGenKernelConf(GetBlobDesc4BnInOp, parallel_ctx, kernel_conf);\n}\n\nvoid Operator::VirtualGenKernelConf(\n    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {}\n\nvoid Operator::AddLbi2OutputIndex(const LogicalBlobId& lbi, int32_t output_index) {\n  CHECK(lbi2output_index_.emplace(lbi, output_index).second);\n}\n\nstd::string Operator::Bn2ConfName(const std::string& bn) const {\n  return GetStrValInPbFdOrPbRpf(GetCustomizedConf(), bn);\n}\n\nLogicalBlobId Operator::lbi4ibn(const std::string& input_bn) const {\n  return GenLogicalBlobId(Bn2ConfName(input_bn));\n}\nLogicalBlobId Operator::lbi4obn(const std::string& output_bn) const {\n  LogicalBlobId ret;\n  ret.set_op_name(op_name());\n  ret.set_blob_name(Bn2ConfName(output_bn));\n  return ret;\n}\nLogicalBlobId Operator::tbn2lbi(const std::string& tmp_bn) const {\n  LogicalBlobId ret;\n  ret.set_op_name(op_name());\n  ret.set_blob_name(tmp_bn);\n  return ret;\n}\n\nvoid Operator::EnrollTmpBn(const std::string& tbn) {\n  *tmp_bns_.Add() = tbn;\n  CHECK(mut_bn_in_op2lbi()->insert({tbn, tbn2lbi(tbn)}).second);\n}\n\nInputBlobModifier* Operator::EnrollInputBn(const std::string& ibn, bool has_diff) {\n  LogicalBlobId lbi = lbi4ibn(ibn);\n  auto* map = arg_modifier_signature_.mutable_ibn2input_blob_modifier();\n  const auto& 
pair = map->insert({ibn, InputBlobModifier()});\n  CHECK(pair.second);\n  const int32_t input_index = input_bns_.size();\n  CHECK(\n      bn2index_pair_.emplace(ibn, std::make_pair(BlobNameTag::kInputBlobName, input_index)).second);\n  *input_bns_.Add() = ibn;\n  CHECK(mut_bn_in_op2lbi()->insert({ibn, lbi}).second);\n  auto* ret = &pair.first->second;\n  ret->set_requires_grad(has_diff);\n  return ret;\n}\n\nconst InputBlobModifier& Operator::InputBlobModifier4Ibn(const std::string& ibn) const {\n  return arg_modifier_signature_.ibn2input_blob_modifier().at(ibn);\n}\n\nconst OutputBlobModifier& Operator::OutputBlobModifier4Obn(const std::string& obn) const {\n  return arg_modifier_signature_.obn2output_blob_modifier().at(obn);\n}\n\nInputBlobModifier* Operator::MutInputBlobModifier4Ibn(const std::string& ibn) {\n  auto* map = arg_modifier_signature_.mutable_ibn2input_blob_modifier();\n  return &map->at(ibn);\n}\n\nOutputBlobModifier* Operator::MutOutputBlobModifier4Obn(const std::string& obn) {\n  auto* map = arg_modifier_signature_.mutable_obn2output_blob_modifier();\n  return &map->at(obn);\n}\n\nvoid Operator::EnrollRepeatedInputBn(const std::string& ibn_prefix, int32_t num, bool has_diff) {\n  FOR_RANGE(int32_t, i, 0, num) { EnrollInputBn(GenRepeatedBn(ibn_prefix, i), has_diff); }\n}\n\nvoid Operator::EnrollRepeatedInputBn(const std::string& ibn_prefix, bool has_diff) {\n  EnrollRepeatedInputBn(ibn_prefix, GetPbRpfFromCustomizedConf<std::string>(ibn_prefix).size(),\n                        has_diff);\n}\n\nvoid Operator::EnrollRepeatedInputBn(const std::string& ibn_prefix, int32_t num) {\n  EnrollRepeatedInputBn(ibn_prefix, num, true);\n}\n\nvoid Operator::EnrollRepeatedInputBn(const std::string& ibn_prefix) {\n  EnrollRepeatedInputBn(ibn_prefix, true);\n}\n\nOutputBlobModifier* Operator::EnrollOutputBn(const std::string& obn, bool has_diff) {\n  LogicalBlobId lbi = lbi4obn(obn);\n  auto* map = arg_modifier_signature_.mutable_obn2output_blob_modifier();\n  const auto& pair = map->insert({obn, OutputBlobModifier()});\n  CHECK(pair.second);\n  auto* ret = &pair.first->second;\n  const int32_t output_index = output_bns_.size();\n  CHECK(bn2index_pair_.emplace(obn, std::make_pair(BlobNameTag::kOutputBlobName, output_index))\n            .second);\n  AddLbi2OutputIndex(lbi, output_index);\n  *output_bns_.Add() = obn;\n  CHECK(mut_bn_in_op2lbi()->insert({obn, lbi}).second);\n  ret->set_requires_grad(has_diff);\n  return ret;\n}\n\nvoid Operator::EnrollRepeatedOutputBnWithSetter(\n    const std::string& obn_prefix, int32_t num, bool has_diff,\n    const std::function<void(OutputBlobModifier*)>& ModifierSetter) {\n  FOR_RANGE(int32_t, i, 0, num) {\n    ModifierSetter(EnrollOutputBn(GenRepeatedBn(obn_prefix, i), has_diff));\n  }\n}\n\nvoid Operator::EnrollRepeatedOutputBnWithSetter(\n    const std::string& obn_prefix, bool has_diff,\n    const std::function<void(OutputBlobModifier*)>& ModifierSetter) {\n  EnrollRepeatedOutputBnWithSetter(obn_prefix,\n                                   GetPbRpfFromCustomizedConf<std::string>(obn_prefix).size(),\n                                   has_diff, ModifierSetter);\n}\n\nvoid Operator::EnrollRepeatedOutputBnWithSetter(\n    const std::string& obn_prefix, int32_t num,\n    const std::function<void(OutputBlobModifier*)>& ModifierSetter) {\n  EnrollRepeatedOutputBnWithSetter(obn_prefix, num, true, ModifierSetter);\n}\n\nvoid Operator::EnrollRepeatedOutputBnWithSetter(\n    const std::string& obn_prefix, const std::function<void(OutputBlobModifier*)>& 
ModifierSetter) {\n  EnrollRepeatedOutputBnWithSetter(obn_prefix, true, ModifierSetter);\n}\n\nvoid Operator::EnrollRepeatedOutputBn(const std::string& obn_prefix, int32_t num, bool has_diff) {\n  FOR_RANGE(int32_t, i, 0, num) { EnrollOutputBn(GenRepeatedBn(obn_prefix, i), has_diff); }\n}\n\nvoid Operator::EnrollRepeatedOutputBn(const std::string& obn_prefix, bool has_diff) {\n  EnrollRepeatedOutputBn(obn_prefix, GetPbRpfFromCustomizedConf<std::string>(obn_prefix).size(),\n                         has_diff);\n}\n\nvoid Operator::EnrollRepeatedOutputBn(const std::string& obn_prefix, int32_t num) {\n  EnrollRepeatedOutputBn(obn_prefix, num, true);\n}\n\nvoid Operator::EnrollRepeatedOutputBn(const std::string& obn_prefix) {\n  EnrollRepeatedOutputBn(obn_prefix, true);\n}\n\nstd::string GenRepeatedBn(const std::string& bn_prefix, int32_t idx) {\n  CHECK_GE(idx, 0);\n  return bn_prefix + \"_\" + std::to_string(idx);\n}\n\nstd::pair<std::string, int32_t> GenUnRepeatedBn(const std::string& bn) {\n  return GetFieldNameAndIndex4StrVal(bn);\n}\n\nbool IsCpuOnly(const std::string& user_op_type_name) {\n  auto* registration_val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(user_op_type_name);\n  CHECK(registration_val != nullptr) << \"user_op_type_name: \" << user_op_type_name;\n  return registration_val->cpu_only_supported;\n}\n\nbool IsCpuOnly(const OperatorConf& op_conf) {\n  OperatorConf::OpTypeCase op_type_case = op_conf.op_type_case();\n  using CpuOnly = OnlyCpuSupportPredicator;\n  auto* ptr = NewObj<int32_t, CpuOnly>(op_type_case);\n  CHECK(ptr != nullptr) << \"op_conf\\n\" << op_conf.DebugString();\n  if (*std::unique_ptr<CpuOnly>(ptr)) { return true; }\n  if (!op_conf.has_user_conf()) { return false; }\n  return IsCpuOnly(op_conf.user_conf().op_type_name());\n}\n\nMaybe<Operator> ConstructOp(const OperatorConf& op_conf, DeviceType device_type) {\n  std::shared_ptr<OperatorConf> dev_op_conf = std::make_shared<OperatorConf>(op_conf);\n  dev_op_conf->set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(device_type)));\n  auto op = JUST(CheckAndConstructOp(dev_op_conf));\n  return op;\n}\n\nMaybe<Operator> ConstructOp(const OperatorConf& op_conf) {\n  if (IsCpuOnly(op_conf)) { return JUST(ConstructOp(op_conf, DeviceType::kCPU)); }\n  return CheckAndConstructOp(std::make_shared<OperatorConf>(op_conf));\n}\n\nSymbol<OperatorConf> Operator::GetOpConfWithoutOpNameAndLbn() const {\n  OperatorConf op_conf(this->op_conf());\n  op_conf.set_name(\"undefined-op-name\");\n  PbMessage* op_type_conf = MutableMessageInPbMessage(&op_conf, op_conf.op_type_case());\n  for (const auto& ibn : input_bns()) {\n    if (!HasStrFieldInPbFdOrPbRpf(*op_type_conf, ibn)) { continue; }\n    ReplaceInputLbnInOpCustomizedConf(&op_conf, ibn, \"undefined-op-name/undefined-ibn\");\n  }\n  return SymbolOf(op_conf);\n}\n\nstd::shared_ptr<OpAttribute> Operator::GetOpAttributeWithoutOpNameAndLbn() const {\n  auto op_attribute = std::make_shared<OpAttribute>();\n  CHECK_JUST(ToOpAttribute(op_attribute.get()));\n  op_attribute->mutable_sbp_signature();\n  *op_attribute->mutable_op_conf() = *GetOpConfWithoutOpNameAndLbn();\n  return op_attribute;\n}\n\nMaybe<int32_t> Operator::GetInputIndex(const std::string& ibn) const {\n  auto it = bn2index_pair_.find(ibn);\n  CHECK_OR_RETURN(it != bn2index_pair_.end());\n  CHECK_EQ_OR_RETURN(it->second.first, BlobNameTag::kInputBlobName);\n  return it->second.second;\n}\n\nMaybe<int32_t> Operator::GetOutputIndex(const std::string& obn) const {\n  auto it = bn2index_pair_.find(obn);\n  
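// obn must have been enrolled via EnrollOutputBn, which fills bn2index_pair_.\n  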
CHECK_OR_RETURN(it != bn2index_pair_.end());\n  CHECK_EQ_OR_RETURN(it->second.first, BlobNameTag::kOutputBlobName);\n  return it->second.second;\n}\n\nMaybe<int32_t> Operator::GetOutputIndex(const LogicalBlobId& lbi) const {\n  auto it = lbi2output_index_.find(lbi);\n  CHECK_OR_RETURN(it != lbi2output_index_.end());\n  return it->second;\n}\n\nMaybe<void> Operator::ToOpAttribute(OpAttribute* op_attribute) const {\n  *op_attribute->mutable_input_bns() = input_bns_;\n  *op_attribute->mutable_output_bns() = output_bns_;\n  *op_attribute->mutable_tmp_bns() = tmp_bns_;\n  *op_attribute->mutable_op_conf() = op_conf();\n  *op_attribute->mutable_arg_signature() = arg_signature_;\n  *op_attribute->mutable_arg_modifier_signature() = arg_modifier_signature_;\n  if (blob_last_used_signature_) {\n    *op_attribute->mutable_blob_last_used_signature() = *blob_last_used_signature_;\n  } else {\n    op_attribute->clear_blob_last_used_signature();\n  }\n  if (blob_backward_used_signature_) {\n    *op_attribute->mutable_blob_backward_used_signature() = *blob_backward_used_signature_;\n  } else {\n    op_attribute->clear_blob_backward_used_signature();\n  }\n  if (sbp_signature_) {\n    *op_attribute->mutable_sbp_signature() = *sbp_signature_;\n  } else {\n    op_attribute->clear_sbp_signature();\n  }\n  if (nd_sbp_signature_) {\n    *op_attribute->mutable_nd_sbp_signature() = *nd_sbp_signature_;\n  } else {\n    op_attribute->clear_nd_sbp_signature();\n  }\n  if (local_signature_) {\n    *op_attribute->mutable_local_signature() = *local_signature_;\n  } else {\n    op_attribute->clear_local_signature();\n  }\n  if (input_index2logical_blob_desc_) {\n    JUST(FillLogicalBlobDescSignature(\n        input_bns(), input_index2logical_blob_desc_,\n        op_attribute->mutable_logical_blob_desc_signature()->mutable_bn_in_op2blob_desc()));\n  }\n  if (output_index2logical_blob_desc_) {\n    JUST(FillLogicalBlobDescSignature(\n        output_bns(), output_index2logical_blob_desc_,\n        op_attribute->mutable_logical_blob_desc_signature()->mutable_bn_in_op2blob_desc()));\n  }\n  if (op_parallel_desc_) {\n    *op_attribute->mutable_parallel_conf_signature()->mutable_op_parallel_conf() =\n        op_parallel_desc_->parallel_conf();\n  }\n  if (bn2parallel_desc_) {\n    auto* map = op_attribute->mutable_parallel_conf_signature()->mutable_bn_in_op2parallel_conf();\n    for (const auto& pair : *bn2parallel_desc_) {\n      const bool has_same_parallel_conf_as_op = *op_parallel_desc_ == *pair.second;\n      if (!has_same_parallel_conf_as_op) { (*map)[pair.first] = pair.second->parallel_conf(); }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nLogicalBlobId GenLogicalBlobId(const std::string& lbn) {\n  LogicalBlobId lbi;\n  size_t pos = lbn.find('/');\n  CHECK_NE(pos, std::string::npos) << \"lbn: \" << lbn;\n  lbi.set_op_name(lbn.substr(0, pos));\n  std::string blob_name_with_hint = lbn.substr(pos + 1);\n  size_t vbar_pos = blob_name_with_hint.rfind('|');\n  std::string blob_name_with_split_hint = blob_name_with_hint.substr(0, vbar_pos);\n  size_t split_pos = blob_name_with_split_hint.rfind(':');\n  lbi.set_blob_name(blob_name_with_split_hint.substr(0, split_pos));\n  return lbi;\n}\n\nMaybe<bool> GetSbpParallelInLbnOrNothing(const std::string& lbn, SbpParallel* sbp) {\n  size_t vbar_pos = lbn.rfind('|');\n  std::string lbn_with_split_hint = lbn.substr(0, vbar_pos);\n  size_t pos = lbn_with_split_hint.rfind(':');\n  CHECK_NE(pos, lbn_with_split_hint.length() - 1);\n  if (pos == std::string::npos) { return false; }\n  std::string 
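/* e.g. \"S0\" or \"B\" taken from an lbn such as \"my_op/out:S0\" (illustrative) */ 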
split_hint = lbn_with_split_hint.substr(pos + 1);\n  if (split_hint[0] == 'S') {\n    std::string axis_str = split_hint.substr(1);\n    CHECK_OR_RETURN(IsStrInt(axis_str));\n    sbp->mutable_split_parallel()->set_axis(oneflow_cast<int64_t>(axis_str));\n  } else if (split_hint[0] == 'B') {\n    sbp->mutable_broadcast_parallel();\n  } else {\n    return Error::CheckFailedError()\n           << \"split hint only supports 'S' or 'B', but got: \" << split_hint[0];\n  }\n  return true;\n}\n\nMaybe<bool> ParseDisableBoxingFlag(const std::string& lbn_with_hint, bool* disable_boxing) {\n  size_t pos = lbn_with_hint.rfind('|');\n  if (pos == std::string::npos) { return false; }\n  CHECK_NE(pos, lbn_with_hint.length() - 1);\n  std::string disable_boxing_str = lbn_with_hint.substr(pos + 1);\n  CHECK_OR_RETURN(IsStrInt(disable_boxing_str));\n  *disable_boxing = oneflow_cast<int64_t>(disable_boxing_str);\n  return true;\n}\n\nstd::string GetInputLbnInOpCustomizedConf(const OperatorConf& op_conf,\n                                          const std::string& fd_name_may_have_idx) {\n  const PbMessage& msg = GetMessageInPbMessage(op_conf, op_conf.op_type_case());\n  const PbMessage* msg_ptr = &msg;\n  const UserOpConf* user_conf = dynamic_cast<const UserOpConf*>(msg_ptr);\n  if (user_conf) {\n    std::pair<std::string, int32_t> pair = GetFieldNameAndIndex4StrVal(fd_name_may_have_idx);\n    if (user_conf->input().find(pair.first) != user_conf->input().end()) {\n      return user_conf->input().at(pair.first).s(pair.second);\n    } else {\n      LOG(WARNING) << \"cannot find input arg val in user op conf. (arg_name = \" << pair.first\n                   << \", id = \" << std::to_string(pair.second) << \")\";\n      return \"\";\n    }\n  } else {\n    return GetStrValInPbFdOrPbRpf(msg, fd_name_may_have_idx);\n  }\n}\n\n// return old value\nstd::string ReplaceInputLbnInOpTypeConf(PbMessage* msg, const std::string& fd_name_may_have_idx,\n                                        const std::string& new_val) {\n  UserOpConf* user_conf = dynamic_cast<UserOpConf*>(msg);\n  std::string old_val;\n  if (user_conf) {\n    std::pair<std::string, int32_t> pair = GetFieldNameAndIndex4StrVal(fd_name_may_have_idx);\n    CHECK(user_conf->input().find(pair.first) != user_conf->input().end())\n        << \"cannot find input arg val in user op conf. 
(arg_name = \" << pair.first\n        << \", id = \" << std::to_string(pair.second) << \")\\n\"\n        << \" new lbn = \" << new_val;\n    old_val = user_conf->input().at(pair.first).s(pair.second);\n    (*(user_conf->mutable_input()))[pair.first].set_s(pair.second, new_val);\n  } else {\n    old_val = ReplaceStrValInPbFdOrPbRpf(msg, fd_name_may_have_idx, new_val);\n  }\n  return old_val;\n}\n\nstd::string ReplaceInputLbnInOpCustomizedConf(OperatorConf* op_conf,\n                                              const std::string& fd_name_may_have_idx,\n                                              const std::string& new_val) {\n  PbMessage* op_type_conf = MutableMessageInPbMessage(op_conf, op_conf->op_type_case());\n  return ReplaceInputLbnInOpTypeConf(op_type_conf, fd_name_may_have_idx, new_val);\n}\n\nbool operator==(const OperatorConf& lhs, const OperatorConf& rhs) {\n  return PbMd().Equals(lhs, rhs);\n}\n\nnamespace {\n\nMaybe<void> InferOpOutSbpParallel(\n    Operator* op, const OpNodeSignature& upstream_signature,\n    const std::function<const BlobDesc&(const std::string&)>& ConstBlobDesc4Ibn,\n    const SbpSignature& sbp_sig_conf, const ParallelDesc& parallel_desc) {\n  const auto& SbpParallel4Ibn = [&](const std::string& ibn) -> const SbpParallel* {\n    const auto& map = upstream_signature.sbp_signature().bn_in_op2sbp_parallel();\n    return &map.at(ibn);\n  };\n  HashMap<std::string, SbpInferHint> ibn2sbp_infer_hint;\n  for (const std::string& ibn : op->input_bns()) {\n    const ParallelDesc* pd = &parallel_desc;\n    const BlobDesc* logical_blob_desc = &ConstBlobDesc4Ibn(ibn);\n    const SbpParallel* sbp_parallel = SbpParallel4Ibn(ibn);\n    ibn2sbp_infer_hint.emplace(ibn, SbpInferHint(pd, logical_blob_desc, sbp_parallel));\n  }\n  SbpSignature sbp_signature;\n  JUST(op->InferSbpSignature(&sbp_signature, sbp_sig_conf, ibn2sbp_infer_hint));\n  JUST(op->FillSbpSignature(sbp_signature));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferLocalSignature(Operator* op, const OpNodeSignature& upstream_signature,\n                                bool is_local, const ParallelDesc& parallel_desc) {\n  HashMap<std::string, LocalSigInferHint> ibn2local_sig_infer_hint;\n  for (const std::string& ibn : op->input_bns()) {\n    const auto& map = upstream_signature.local_signature().bn_in_op2opt_local_parallel();\n    const auto& opt_local_parallel = map.at(ibn);\n    ibn2local_sig_infer_hint.emplace(\n        ibn, LocalSigInferHint(&parallel_desc, opt_local_parallel.has_local_parallel()));\n  }\n  const auto& LocalSigInferHint4Ibn =\n      [&](const std::string& ibn) -> Maybe<const LocalSigInferHint*> {\n    const auto& iter = ibn2local_sig_infer_hint.find(ibn);\n    CHECK_OR_RETURN(iter != ibn2local_sig_infer_hint.end()) << \"input blob not found. 
ibn: \" << ibn;\n    return &iter->second;\n  };\n  JUST(op->InferLocalSignatureIf(LocalSigInferHint4Ibn, is_local, parallel_desc));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckOpInputSignature(const Operator& op, const OpNodeSignature& upstream_signature) {\n  for (const auto& ibn : op.input_bns()) {\n    {\n      CHECK_OR_RETURN(upstream_signature.has_logical_blob_desc_signature());\n      const auto& map = upstream_signature.logical_blob_desc_signature().bn_in_op2blob_desc();\n      CHECK_OR_RETURN(map.find(ibn) != map.end());\n    }\n    {\n      CHECK_OR_RETURN(upstream_signature.has_sbp_signature());\n      const auto& map = upstream_signature.sbp_signature().bn_in_op2sbp_parallel();\n      CHECK_OR_RETURN(map.find(ibn) != map.end());  // NOLINT\n    }\n    {\n      CHECK_OR_RETURN(upstream_signature.has_local_signature());  // NOLINT\n      const auto& map = upstream_signature.local_signature().bn_in_op2opt_local_parallel();\n      CHECK_OR_RETURN(map.find(ibn) != map.end());\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<Operator> ConstructAndInferOp(const OperatorConf& op_conf,\n                                    const OpNodeSignature& upstream_signature, const Scope& scope) {\n  const auto& parallel_desc = *JUST(scope.GetParallelDesc(op_conf));\n  bool is_local = scope.opt_local_parallel_conf().has_local_parallel();\n  const auto& op = JUST(ConstructOp(op_conf));\n  JUST(CheckOpInputSignature(*op, upstream_signature));\n  JUST(op->FillOpParallelDesc(parallel_desc));\n  HashMap<std::string, std::unique_ptr<BlobDesc>> bn_in_op2blob_desc;\n  for (const auto& ibn : op->input_bns()) {\n    const auto& map = upstream_signature.logical_blob_desc_signature().bn_in_op2blob_desc();\n    bn_in_op2blob_desc[ibn].reset(new BlobDesc(map.at(ibn)));\n  }\n  const auto& ConstBlobDesc4Ibn = [&](const std::string& ibn) -> const BlobDesc& {\n    return *bn_in_op2blob_desc.at(ibn);\n  };\n  JUST(op->FillLogicalInBlobDesc(ConstBlobDesc4Ibn));\n  // infer is_local\n  JUST(InferLocalSignature(op.get(), upstream_signature, is_local, parallel_desc));\n  SbpSignature sbp_sig_conf;\n  // iner sbp\n  JUST(InferOpOutSbpParallel(op.get(), upstream_signature, ConstBlobDesc4Ibn, sbp_sig_conf,\n                             parallel_desc));\n  // infer logical blob_desc\n  JUST(op->InferLogicalOutBlobDescsIf());\n  return op;\n}\n\nnamespace {\n\ntemplate<typename SbpParallelT>\nMaybe<Shape> Get1dHierarchyPhysicalShape(const Shape& logical_shape,\n                                         const SbpParallelT& sbp_parallel,\n                                         const int64_t parallel_num, const int64_t parallel_id) {\n  std::shared_ptr<Shape> physical = std::make_shared<Shape>(logical_shape);\n\n  if (sbp_parallel.has_split_parallel()) {\n    const int64_t axis = sbp_parallel.split_parallel().axis();\n    if (logical_shape.At(axis) > 0) {\n      CHECK_GE_OR_RETURN(logical_shape.At(axis), parallel_num);\n      const BalancedSplitter bs(logical_shape.At(axis), parallel_num);\n      physical->Set(axis, bs.At(parallel_id).size());\n    }\n  } else if (sbp_parallel.has_broadcast_parallel() || sbp_parallel.has_partial_sum_parallel()) {\n    // do nothing\n  } else {\n    UNIMPLEMENTED();\n  }\n  return physical;\n}\n\nMaybe<Shape> GetNdHierarchyPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp,\n                                         const ParallelDesc& parallel_desc,\n                                         const int64_t parallel_id) {\n  const auto& parallel_hierarchy 
= *parallel_desc.hierarchy();\n  std::shared_ptr<Shape> physical = std::make_shared<Shape>(logical_shape);\n  Stride hierarchy_stride(parallel_hierarchy);\n  FOR_RANGE(int64_t, i, 0, parallel_hierarchy.NumAxes()) {\n    const auto& sbp_parallel = nd_sbp.sbp_parallel(i);\n    if (sbp_parallel.has_split_parallel()) {\n      const int64_t split_axis = sbp_parallel.split_parallel().axis();\n      // Both the lazy and eager modes support unbalanced splitting now\n      if (physical->At(split_axis) > 0) {\n        CHECK_GE_OR_RETURN(physical->At(split_axis), parallel_hierarchy.At(i))\n            << Error::RuntimeError() << \"Expected size at split axis (\" << split_axis\n            << \") of logical shape to be greater than or equal to parallel num, but got \"\n               \"logical_shape: \"\n            << logical_shape.ToString()\n            << \", placement: \" << *JUST(PlacementToString(SymbolOf(parallel_desc)))\n            << \", nd_sbp: \" << NdSbpToString(SymbolOf(nd_sbp));\n        const BalancedSplitter bs(physical->At(split_axis), parallel_hierarchy.At(i));\n        physical->Set(split_axis, bs.At(CalcIndex4Axis(parallel_id, hierarchy_stride, i)).size());\n      }\n    }\n  }\n  return physical;\n}\n\n}  // namespace\n\nMaybe<Shape> GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp,\n                              const ParallelDesc& parallel_desc, int64_t parallel_id) {\n  CHECK_GE_OR_RETURN(parallel_id, 0);\n  CHECK_LT_OR_RETURN(parallel_id, parallel_desc.hierarchy()->elem_cnt());\n  CHECK_EQ_OR_RETURN(parallel_desc.hierarchy()->NumAxes(), nd_sbp.sbp_parallel_size());\n  if (parallel_desc.hierarchy()->NumAxes() == 1) {\n    return Get1dHierarchyPhysicalShape(logical_shape, nd_sbp.sbp_parallel(0),\n                                       parallel_desc.hierarchy()->elem_cnt(), parallel_id);\n  } else {\n    return GetNdHierarchyPhysicalShape(logical_shape, nd_sbp, parallel_desc, parallel_id);\n  }\n}\n\nMaybe<Shape> GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp,\n                              const ParallelDesc& parallel_desc,\n                              const ParallelContext& parallel_ctx) {\n  return GetPhysicalShape(logical_shape, nd_sbp, parallel_desc, parallel_ctx.parallel_id());\n}\n\nMaybe<Shape> GetLogicalShape(const Shape& physical_shape, const NdSbp& nd_sbp,\n                             const ParallelDesc& parallel_desc) {\n  const auto& parallel_hierarchy = *parallel_desc.hierarchy();\n  CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), nd_sbp.sbp_parallel_size());\n  std::shared_ptr<Shape> logical_shape = std::make_shared<Shape>(physical_shape);\n  for (int i = parallel_hierarchy.NumAxes() - 1; i >= 0; --i) {\n    const auto& sbp_parallel = nd_sbp.sbp_parallel(i);\n    if (sbp_parallel.has_split_parallel()) {\n      const int64_t split_axis = sbp_parallel.split_parallel().axis();\n      logical_shape->Set(split_axis, logical_shape->At(split_axis) * parallel_hierarchy.At(i));\n    }\n  }\n  return logical_shape;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/operator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_OPERATOR_H_\n#define ONEFLOW_CORE_OPERATOR_OPERATOR_H_\n\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/auto_registration_factory.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/job/local_parallel.pb.h\"\n#include \"oneflow/core/operator/op_conf_util.h\"\n#include \"oneflow/core/register/blob_desc.h\"\n#include \"oneflow/core/job/job_builder.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n#include \"oneflow/core/kernel/kernel.pb.h\"\n#include \"oneflow/core/job/nd_sbp_infer_hint.h\"\n\nnamespace oneflow {\n\nclass LocalSigInferHint;\nclass OpNodeSignature;\nclass Scope;\n\nclass Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Operator);\n  Operator();\n  virtual ~Operator() = default;\n\n  //\n  Maybe<void> Init(const OperatorConf& op_conf);\n  Maybe<void> Init(std::shared_ptr<const OperatorConf> op_conf);\n  virtual Maybe<void> InitFromOpConf() = 0;\n\n  // bn_in_op <-> lbi\n  const LogicalBlobId& BnInOp2Lbi(const std::string& bn_in_op) const;\n\n  // Getters\n  const std::string& op_name() const { return op_conf().name(); }\n  const std::string& op_loc() const { return op_conf().loc(); }\n  DeviceType device_type() const;\n  const OperatorConf& op_conf() const;\n  std::shared_ptr<const OperatorConf> shared_op_conf() const;\n  const PbMessage& GetCustomizedConf() const {\n    return GetMessageInPbMessage(op_conf(), op_conf().op_type_case());\n  }\n\n  template<typename T>\n  T GetValFromCustomizedConf(const std::string& field_name) const {\n    return GetValFromPbMessage<T>(GetCustomizedConf(), field_name);\n  }\n\n  template<typename T>\n  const PbRpf<T>& GetPbRpfFromCustomizedConf(const std::string& field_name) const {\n    return GetPbRpfFromPbMessage<T>(GetCustomizedConf(), field_name);\n  }\n\n  const std::string& SoleIbn() const;\n  const std::string& SoleObn() const;\n  const std::string& SoleTbn() const;\n  Maybe<const std::string*> obn4lbi(const LogicalBlobId& lbi) const;\n\n  const PbRpf<std::string>& input_bns() const;\n  const PbRpf<std::string>& output_bns() const;\n  const PbRpf<std::string>& tmp_bns() const;\n  const PbRpf<std::string>& input_output_bns() const;\n\n  Maybe<void> FillOpParallelDesc(const ParallelDesc& parallel_desc);\n  Maybe<void> FillOpParallelDesc(std::shared_ptr<const ParallelDesc> parallel_desc);\n  Maybe<const ParallelDesc> GetOpParallelDesc() const;\n\n  Maybe<void> InferParallelSignatureIf();\n  Maybe<const ParallelDesc> GetParallelDesc4BnInOp(const std::string& bn) const;\n\n  Maybe<void> FillLogicalInBlobDesc(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp);\n  Maybe<void> 
FillLogicalInBlobDesc(\n      const std::function<const BlobDesc&(const std::string&)>& BlobDesc4BnInOp);\n  Maybe<void> FillLogicalInBlobDesc(\n      const std::function<Maybe<const BlobDesc>(int32_t)>& BlobDesc4InputIndex);\n  Maybe<const BlobDesc> GetLogicalBlobDesc4Ibn(const std::string& ibn) const;\n  Maybe<const BlobDesc> GetLogicalBlobDesc4InputIndex(int32_t index) const;\n  Maybe<void> FillLogicalOutBlobDesc(\n      const std::function<const BlobDesc&(const std::string&)>& BlobDesc4BnInOp);\n  Maybe<void> FillLogicalOutBlobDesc(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp);\n  Maybe<const BlobDesc> GetLogicalBlobDesc4Obn(const std::string& obn) const;\n  Maybe<const BlobDesc> GetLogicalBlobDesc4OutputIndex(int32_t index) const;\n  Maybe<const BlobDesc*> GetLogicalBlobDescPtr4OutputIndex(int32_t index) const;\n  Maybe<const BlobDesc> GetLogicalBlobDesc4BnInOp(const std::string& bn) const;\n  Maybe<void> InferLogicalOutBlobDescsIf();\n  virtual Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const = 0;\n\n  // Read: shape of input_blobs\n  // Write: shape of output_blobs\n  Maybe<void> InferBlobDescsIf(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext*, const JobDesc* job_desc) const;\n\n  Maybe<void> InferOutBlobDescsIf(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n                                  const ParallelContext*) const;\n\n  Maybe<void> InferInternalBlobDescsIf(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx, const JobDesc* job_desc) const;\n\n  Maybe<void> InferInplaceObn2IbnIf(\n      HashMap<std::string, std::string>* mut_inplace_obn2ibn,\n      HashMap<std::string, std::string>* con_inplace_obn2ibn,\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const;\n\n  Maybe<void> FillInputBlobTimeShape(\n      const std::function<Maybe<const Shape>(int32_t)>& GetTimeShape4InputIndex);\n  Maybe<void> InferOpTimeShapeIf();\n  virtual Maybe<void> InferOpTimeShape(\n      const std::function<Maybe<const Shape>(const std::string&)>& GetTimeShape4BnInOp,\n      std::shared_ptr<const Shape>* time_shape) const;\n  Maybe<const Shape> GetOpTimeShape() const;\n  Maybe<const Shape> GetInputBlobFastestTimeShape() const;\n  Maybe<const Shape> GetInputOutputFastestTimeShape() const;\n\n  Maybe<void> InferSbpSignature(SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n                                const HashMap<std::string, SbpInferHint>& ibn2sbp_infer_hint) const;\n  Maybe<void> FillSbpSignature(const SbpSignature& sbp_signature);\n  Maybe<void> FillNdSbpSignature(const NdSbpSignature& signature);\n  Maybe<void> InferSbpSignatureIf(\n      const SbpSignature& sbp_sig_conf,\n      const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n      const std::function<Maybe<const SbpInferHint*>(const std::string&)>& SbpInferHint4Ibn,\n      const ParallelDesc& parallel_desc);\n  Maybe<void> InferNdSbpSignatureIf(\n      const NdSbpSignature& nd_sbp_constraints, const ParallelDesc& parallel_desc,\n      std::function<Maybe<const NdSbpInferHint*>(const std::string&)> NdSbpInferHint4Ibn);\n\n  // The function type that decides how to dump nd_sbp into op_conf\n  using DumpNdSbpSignatureForOpConfFn =\n      
std::function<Maybe<void>(const NdSbpSignature& nd_sbp_sig, OperatorConf* op_conf)>;\n  virtual DumpNdSbpSignatureForOpConfFn GetDumpNdSbpSignatureForOpConfFn() const;\n\n  // Infer blob's LocalSignature\n  Maybe<void> InferLocalSignatureIf(\n      std::function<Maybe<const LocalSigInferHint*>(const std::string&)> LocalSigInferHint4Ibn,\n      bool is_local_parallel_view_conf, const ParallelDesc& parallel_desc);\n\n  void GenKernelConf(const std::function<const BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n                     const ParallelContext*, KernelConf*) const;\n  const InputBlobModifier& InputBlobModifier4Ibn(const std::string& ibn) const;\n  const OutputBlobModifier& OutputBlobModifier4Obn(const std::string& obn) const;\n  Maybe<const SbpParallel*> SbpParallel4BnInOp(const std::string& bn_in_op) const;\n  Maybe<const NdSbp*> NdSbp4BnInOp(const std::string& bn_in_op) const;\n  Maybe<const OptLocalParallel*> OptLocalParallel4BnInOp(const std::string& bn_in_op) const;\n\n  Maybe<void> GetSbpSignaturesIf(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const;\n  virtual Maybe<void> EnumerateNdSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const;\n  virtual Maybe<void> GetNdSbpSignatureList(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const;\n  virtual Maybe<double> GetComputeComplexity(\n      NdSbpSignature* sbp_signature,\n      std::function<const BlobDesc&(const std::string& bn)> logical_blob_desc4bn,\n      const ParallelDesc& parallel_desc) const;\n  // TODO: Will infer blob shape before inferring sbp and delete the check_output later\n  Maybe<void> GetValidNdSbpSignatureList(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list,\n      bool check_output) const;\n\n  void ForEachBnInOp(const std::function<void(const std::string&)>&) const;\n\n  virtual Symbol<OperatorConf> GetOpConfWithoutOpNameAndLbn() const;\n  std::shared_ptr<OpAttribute> GetOpAttributeWithoutOpNameAndLbn() const;\n\n  Maybe<const SbpSignature*> sbp_signature() const;\n  Maybe<const NdSbpSignature*> nd_sbp_signature() const;\n  BlobLastUsedSignature* mut_blob_last_used_signature();\n  BlobBackwardUsedSignature* mut_blob_backward_used_signature();\n\n  Maybe<int32_t> GetInputIndex(const std::string& ibn) const;\n  Maybe<int32_t> GetOutputIndex(const std::string& obn) const;\n  Maybe<int32_t> GetOutputIndex(const LogicalBlobId& lbi) const;\n\n  Maybe<void> ToOpAttribute(OpAttribute* op_attribute) const;\n\n protected:\n  Maybe<void> FillBlobParallelDesc(\n      const std::function<Maybe<const ParallelDesc>(const std::string&)>& ParallelDesc4Bn);\n  virtual Maybe<void> InferBlobParallelDesc();\n  virtual Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const;\n  virtual Maybe<void> InferInternalBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx, const JobDesc* job_desc) const;\n  virtual Maybe<void> 
GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const {\n    return GetSbpSignatures(LogicalBlobDesc4Ibn, sbp_sig_list);\n  }\n  virtual Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const {\n    return GetSbpSignatures(sbp_sig_list);\n  }\n  virtual Maybe<void> InferSbpSignature(\n      SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n      const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n      std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n      const ParallelDesc& parallel_desc) const;\n  virtual Maybe<void> InferNdSbpSignature(\n      NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints,\n      const ParallelDesc& parallel_desc,\n      std::function<Maybe<const NdSbpInferHint*>(const std::string&)> NdSbpInferHint4Ibn) const;\n  virtual Maybe<void> GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {\n    OF_UNIMPLEMENTED() << \" GetSbpSignatures unimplemented, op name: \" << op_name();\n  }\n  virtual Maybe<void> InferLocalSignature(\n      std::function<Maybe<const LocalSigInferHint*>(const std::string&)> LocalSigInferHint4Ibn,\n      bool is_local_parallel_view_conf, const ParallelDesc& parallel_desc);\n\n  virtual Maybe<void> InferInplaceObn2Ibn(\n      HashMap<std::string, std::string>* mut_inplace_obn2ibn,\n      HashMap<std::string, std::string>* con_inplace_obn2ibn,\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const;\n\n  virtual void VirtualGenKernelConf(\n      std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*,\n      KernelConf*) const;\n\n  virtual void AddLbi2OutputIndex(const LogicalBlobId& lbi, int32_t output_index);\n\n  virtual LogicalBlobId lbi4ibn(const std::string& input_bn) const;\n  virtual LogicalBlobId lbi4obn(const std::string& output_bn) const;\n\n  // enroll data blobs\n  void EnrollTmpBn(const std::string& dtbn);\n  void EnrollRepeatedInputBn(const std::string& ibn_prefix, int32_t num, bool has_diff);\n  void EnrollRepeatedInputBn(const std::string& ibn_prefix, bool has_diff);\n  void EnrollRepeatedInputBn(const std::string& ibn_prefix, int32_t num);\n  void EnrollRepeatedInputBn(const std::string& ibn_prefix);\n\n  void EnrollRepeatedOutputBn(const std::string& obn_prefix, int32_t num, bool has_diff);\n  void EnrollRepeatedOutputBn(const std::string& obn_prefix, bool has_diff);\n  void EnrollRepeatedOutputBn(const std::string& obn_prefix, int32_t num);\n  void EnrollRepeatedOutputBn(const std::string& obn_prefix);\n\n  void EnrollRepeatedOutputBnWithSetter(\n      const std::string& obn_prefix, int32_t num, bool has_diff,\n      const std::function<void(OutputBlobModifier*)>& ModifierSetter);\n  void EnrollRepeatedOutputBnWithSetter(\n      const std::string& obn_prefix, bool has_diff,\n      const std::function<void(OutputBlobModifier*)>& ModifierSetter);\n  void EnrollRepeatedOutputBnWithSetter(\n      const std::string& obn_prefix, int32_t num,\n      const std::function<void(OutputBlobModifier*)>& ModifierSetter);\n  void EnrollRepeatedOutputBnWithSetter(\n      const std::string& obn_prefix,\n      const std::function<void(OutputBlobModifier*)>& ModifierSetter);\n\n  InputBlobModifier* 
EnrollInputBn(const std::string& ibn, bool has_diff);\n  InputBlobModifier* EnrollInputBn(const std::string& ibn) { return EnrollInputBn(ibn, true); }\n  OutputBlobModifier* EnrollOutputBn(const std::string& obn, bool has_diff);\n  OutputBlobModifier* EnrollOutputBn(const std::string& obn) { return EnrollOutputBn(obn, true); }\n\n  InputBlobModifier* MutInputBlobModifier4Ibn(const std::string& ibn);\n  OutputBlobModifier* MutOutputBlobModifier4Obn(const std::string& obn);\n  OptLocalParallel* MutOptLocalParallel(const std::string& bn_in_op);\n\n private:\n  enum BlobNameTag {\n    kInputBlobName,\n    kOutputBlobName,\n  };\n  Maybe<void> FilterAndCheckValidSbpSignatureListByLogicalShape(\n      const SbpSignatureList& total_sbp_sig_list,\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      const ParallelDesc& parallel_desc, SbpSignatureList* valid_sbp_sig_list) const;\n  // TODO(wyg): make both 1d and nd sbp use this function to filter and check\n  Maybe<void> FilterNdSbpSignatureListByLogicalShape(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list,\n      bool check_output) const;\n  Maybe<void> GreedilyFindMinCopyCostNdSbp(\n      NdSbpSignature* nd_sbp_signature,\n      const std::function<Maybe<const NdSbpInferHint*>(const std::string&)>& NdSbpInferHint4Ibn,\n      const std::vector<NdSbpSignature>& nd_sbp_sig_list) const;\n\n  LogicalBlobId tbn2lbi(const std::string& data_tmp_bn) const;\n  std::string Bn2ConfName(const std::string& bn) const;\n  PbMap<std::string, LogicalBlobId>* mut_bn_in_op2lbi() {\n    return arg_signature_.mutable_bn_in_op2lbi();\n  }\n\n  std::shared_ptr<const OperatorConf> op_conf_;\n  std::shared_ptr<const ParallelDesc> op_parallel_desc_;\n  std::unique_ptr<HashMap<std::string, std::shared_ptr<const ParallelDesc>>> bn2parallel_desc_;\n  std::unique_ptr<std::vector<std::shared_ptr<const BlobDesc>>> input_index2logical_blob_desc_;\n  std::unique_ptr<std::vector<std::shared_ptr<const BlobDesc>>> output_index2logical_blob_desc_;\n  std::unique_ptr<std::vector<std::shared_ptr<const Shape>>> input_index2time_shape_;\n  std::shared_ptr<const Shape> input_blob_fastest_time_shape_;\n  std::shared_ptr<const Shape> input_output_fastest_time_shape_;\n  std::shared_ptr<const Shape> op_time_shape_;\n  std::shared_ptr<const SbpSignature> sbp_signature_;\n  std::shared_ptr<const NdSbpSignature> nd_sbp_signature_;\n  PbRpf<std::string> input_bns_;\n  PbRpf<std::string> output_bns_;\n  PbRpf<std::string> tmp_bns_;\n  PbRpf<std::string> input_output_bns_;\n  DeviceType device_type_;\n  ArgSignature arg_signature_;\n  ArgModifierSignature arg_modifier_signature_;\n  std::unique_ptr<BlobLastUsedSignature> blob_last_used_signature_;\n  std::unique_ptr<BlobBackwardUsedSignature> blob_backward_used_signature_;\n  std::unique_ptr<LocalSignature> local_signature_;\n\n  HashMap<std::string, std::pair<BlobNameTag, int32_t>> bn2index_pair_;\n  HashMap<LogicalBlobId, int32_t> lbi2output_index_;\n};\n\nstd::string GenRepeatedBn(const std::string& bn_prefix, int32_t idx);\nstd::pair<std::string, int32_t> GenUnRepeatedBn(const std::string& bn);\n\nbool IsCpuOnly(const std::string& user_op_type_name);\nbool IsCpuOnly(const OperatorConf& op_conf);\n\nstruct OnlyCpuSupportPredicator {\n  OnlyCpuSupportPredicator(bool only_cpu) : only_cpu_(only_cpu) {}\n  operator bool() { return only_cpu_; }\n\n private:\n  bool 
only_cpu_;\n};\n\nstruct RuntimeRegstNum4OpSameOutputBlob final {\n  RuntimeRegstNum4OpSameOutputBlob(size_t num) : num_(num) {}\n  RuntimeRegstNum4OpSameOutputBlob(std::function<size_t()> get_num)\n      : get_num_(new std::function<size_t()>(get_num)) {}\n  operator size_t() {\n    if (!get_num_) { return num_; }\n    return (*this->get_num_)();\n  }\n\n private:\n  size_t num_;\n  std::unique_ptr<std::function<size_t()>> get_num_;\n};\n\n#define REGISTER_OP(op_type_case, OpType)                                       \\\n  REGISTER_CLASS_CREATOR(int32_t, op_type_case, OnlyCpuSupportPredicator,       \\\n                         ([] { return new OnlyCpuSupportPredicator(false); })); \\\n  REGISTER_CLASS_WITH_ARGS(int32_t, op_type_case, Operator, OpType, const OperatorConf&)\n\n#define REGISTER_CPU_OP(op_type_case, OpType)                                  \\\n  REGISTER_CLASS_CREATOR(int32_t, op_type_case, OnlyCpuSupportPredicator,      \\\n                         ([] { return new OnlyCpuSupportPredicator(true); })); \\\n  REGISTER_CLASS_WITH_ARGS(int32_t, op_type_case, Operator, OpType, const OperatorConf&)\n\n#define REGISTER_OP_CREATOR(op_type_case, creator)                              \\\n  REGISTER_CLASS_CREATOR(int32_t, op_type_case, OnlyCpuSupportPredicator,       \\\n                         ([] { return new OnlyCpuSupportPredicator(false); })); \\\n  REGISTER_CLASS_CREATOR(int32_t, op_type_case, Operator, creator, const OperatorConf&)\n\n#define REGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(op_type_case, num)                 \\\n  REGISTER_CLASS_CREATOR(int32_t, op_type_case, RuntimeRegstNum4OpSameOutputBlob, \\\n                         ([] { return new RuntimeRegstNum4OpSameOutputBlob(num); }))\n\n#define REGISTER_USER_OP_SAME_OUTPUT_BLOB_REGST_NUM(op_type_name, num)                \\\n  REGISTER_CLASS_CREATOR(std::string, op_type_name, RuntimeRegstNum4OpSameOutputBlob, \\\n                         ([] { return new RuntimeRegstNum4OpSameOutputBlob(num); }))\n\n#define REGISTER_USER_OP_SAME_OUTPUT_BLOB_REGST_NUM_WITH_FUNC(op_type_name, func)     \\\n  REGISTER_CLASS_CREATOR(std::string, op_type_name, RuntimeRegstNum4OpSameOutputBlob, \\\n                         ([] { return new RuntimeRegstNum4OpSameOutputBlob(func); }));\n\nstruct IsInterfaceOpConf4OpTypeCase final {};\n\n#define REGISTER_INTERFACE_OP(op_type_case)                                   \\\n  REGISTER_CLASS_CREATOR(int32_t, op_type_case, IsInterfaceOpConf4OpTypeCase, \\\n                         ([] { return new IsInterfaceOpConf4OpTypeCase(); }))\n\nstruct DisableInputBoxingGroup final {};\n\n#define REGISTER_DISABLE_INPUT_BOXING_GROUP(op_type_case)                \\\n  REGISTER_CLASS_CREATOR(int32_t, op_type_case, DisableInputBoxingGroup, \\\n                         ([] { return new DisableInputBoxingGroup(); }))\n\nstruct IsTickTockOpTypeCase final {};\n\n#define REGISTER_TICK_TOCK_OP(op_type_case)                           \\\n  REGISTER_CLASS_CREATOR(int32_t, op_type_case, IsTickTockOpTypeCase, \\\n                         ([] { return new IsTickTockOpTypeCase; }))\n\nMaybe<Operator> ConstructOp(const OperatorConf& op_conf);\nMaybe<Operator> ConstructOp(const OperatorConf& op_conf, DeviceType device_type);\n\ninline OpBlobArg GenOpBlobArg(const std::string& op_name, const std::string& bn_in_op) {\n  OpBlobArg oba;\n  oba.set_op_name(op_name);\n  oba.set_bn_in_op(bn_in_op);\n  return oba;\n}\n\nLogicalBlobId GenLogicalBlobId(const std::string& lbn);\n\ninline std::string GenLogicalBlobName(const std::string& 
op_name, const std::string& blob_name) {\n  return op_name + \"/\" + blob_name;\n}\n\ninline std::string GenLogicalBlobName(const LogicalBlobId& lbi) {\n  CHECK_EQ(lbi.has_op_name(), true);\n  CHECK_EQ(lbi.has_blob_name(), true);\n  return GenLogicalBlobName(lbi.op_name(), lbi.blob_name());\n}\n\nMaybe<bool> GetSbpParallelInLbnOrNothing(const std::string& lbn, SbpParallel* sbp);\nMaybe<bool> ParseDisableBoxingFlag(const std::string& lbn_with_hint, bool* disable_boxing);\n\nstd::string GetInputLbnInOpCustomizedConf(const OperatorConf& op_conf,\n                                          const std::string& fd_name_may_have_idx);\n\n// return old value\nstd::string ReplaceInputLbnInOpCustomizedConf(OperatorConf* op_conf,\n                                              const std::string& fd_name_may_have_idx,\n                                              const std::string& new_val);\n\nbool operator==(const OperatorConf& lhs, const OperatorConf& rhs);\n\nMaybe<Operator> ConstructAndInferOp(const OperatorConf& op_conf,\n                                    const OpNodeSignature& upstream_signature, const Scope& scope);\n\nMaybe<Shape> GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp,\n                              const ParallelDesc& parallel_desc,\n                              const ParallelContext& parallel_ctx);\n\nMaybe<Shape> GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp,\n                              const ParallelDesc& parallel_desc, int64_t parallel_id);\n\nMaybe<Shape> GetLogicalShape(const Shape& physical_shape, const NdSbp& nd_sbp,\n                             const ParallelDesc& parallel_desc);\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::OperatorConf> final {\n  size_t operator()(const oneflow::OperatorConf& op_conf) const {\n    std::string serialized;\n    op_conf.SerializeToString(&serialized);\n    return std::hash<std::string>()(serialized);\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_CORE_OPERATOR_OPERATOR_H_\n"
  },
  {
    "path": "oneflow/core/operator/operator_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator_util.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n\nnamespace oneflow {\n\nsize_t DhwOffset(const std::string& data_format) {\n  if (data_format == \"channels_first\") {\n    return 2;\n  } else if (data_format == \"channels_last\") {\n    return 1;\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nstd::vector<int32_t> Get3DVecInOpConf(const PbRf<int32_t>& field_vals, int32_t NDims) {\n  std::vector<int32_t> vec;\n  vec.reserve(3);\n  FOR_RANGE(uint8_t, dim, 0, 3) {\n    int64_t index = static_cast<int64_t>(dim) - (3 - NDims);\n    if (index < 0) {\n      vec.emplace_back(1);\n    } else {\n      vec.emplace_back(field_vals.Get(index));\n    }\n  }\n  return vec;\n}\n\nint64_t GetInDim(const ShapeView& shape, const std::string& data_format, int32_t dim,\n                 int32_t NDims) {\n  int64_t offset = 0;\n  if (data_format == \"channels_last\") {\n    offset = 1;\n  } else if (data_format == \"channels_first\") {\n    offset = 2;\n  } else {\n    UNIMPLEMENTED();\n  }\n  int64_t index = offset + static_cast<int64_t>(dim) - static_cast<int64_t>(3 - NDims);\n  if (index < offset) {\n    return 1;\n  } else {\n    return shape.At(index);\n  }\n}\n\nvoid GetWindowedOutputSize(int64_t input_size, int32_t filter_size, int32_t dilation_rate,\n                           int32_t stride, const std::string& padding_type, bool ceil_mode,\n                           int64_t* output_size, int32_t* padding_before, int32_t* padding_after) {\n  CHECK_GT(stride, 0);\n  CHECK_GE(dilation_rate, 1);\n\n  int32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1;\n  if (padding_type == \"customized\") {\n    if (output_size) {\n      *output_size = (input_size + *padding_before + *padding_after - effective_filter_size + stride\n                      + (ceil_mode ? 
stride - 1 : 0))\n                     / stride;\n      CHECK_GE((*output_size), 0);\n    }\n  } else if (padding_type == \"valid\") {\n    if (output_size) { *output_size = (input_size - effective_filter_size + stride) / stride; }\n    if (padding_before) { *padding_before = 0; }\n    if (padding_after) { *padding_after = 0; }\n  } else {\n    int64_t tmp_output_size = (input_size + stride - 1) / stride;\n    if (output_size) { *output_size = tmp_output_size; }\n    const int32_t padding_needed = std::max(\n        0,\n        static_cast<int32_t>((tmp_output_size - 1) * stride + effective_filter_size - input_size));\n    const int32_t padding_small = padding_needed / 2;\n    const int32_t padding_large = padding_needed - padding_needed / 2;\n    if (padding_type == \"same_upper\") {\n      if (padding_before) { *padding_before = padding_small; }\n      if (padding_after) { *padding_after = padding_large; }\n    } else if (padding_type == \"same_lower\") {\n      if (padding_before) { *padding_before = padding_large; }\n      if (padding_after) { *padding_after = padding_small; }\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n  if (output_size) { CHECK_GE((*output_size), 0); }\n}\n\nvoid GetWindowedOutputSize(int64_t input_size, int32_t filter_size, int32_t dilation_rate,\n                           int32_t stride, const std::string& padding_type, int64_t* output_size,\n                           int32_t* padding_before, int32_t* padding_after) {\n  CHECK_GT(stride, 0);\n  CHECK_GE(dilation_rate, 1);\n\n  int32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1;\n  if (padding_type == \"valid\") {\n    if (output_size) { *output_size = (input_size - effective_filter_size + stride) / stride; }\n    if (padding_before) { *padding_before = 0; }\n    if (padding_after) { *padding_after = 0; }\n  } else if (padding_type == \"same\") {\n    int64_t tmp_output_size = (input_size + stride - 1) / stride;\n    if (output_size) { *output_size = tmp_output_size; }\n    const int32_t padding_needed = std::max(\n        0,\n        static_cast<int32_t>((tmp_output_size - 1) * stride + effective_filter_size - input_size));\n    // For odd values of total padding, add more padding at the 'right'\n    // side of the given dimension.\n    if (padding_before) { *padding_before = padding_needed / 2; }\n    if (padding_after) { *padding_after = padding_needed - padding_needed / 2; }\n  } else {\n    UNIMPLEMENTED();\n  }\n  if (output_size) { CHECK_GE((*output_size), 0); }\n}\n\nvoid GetWindowedOutputSize(int64_t input_size, int32_t filter_size, int32_t stride,\n                           const std::string& padding_type, int64_t* output_size,\n                           int32_t* padding_before, int32_t* padding_after) {\n  GetWindowedOutputSize(input_size, filter_size, 1, stride, padding_type, output_size,\n                        padding_before, padding_after);\n}\n\nvoid GetWindowedOutputSize(int64_t input_size, int32_t filter_size, int32_t stride,\n                           const std::string& padding_type, int64_t* output_size,\n                           int32_t* padding_size) {\n  GetWindowedOutputSize(input_size, filter_size, stride, padding_type, output_size, padding_size,\n                        nullptr);\n}\n\nvoid Get3DOutputSize(const DimVector& in, const std::vector<int32_t>& pool_size,\n                     const std::vector<int32_t>& strides, const std::string& padding_type,\n                     DimVector* out, std::vector<int32_t>* padding) {\n  Get3DOutputSize(in, 
pool_size, strides, padding_type, out, padding, nullptr, nullptr);\n}\n\nvoid Get3DOutputSize(const DimVector& in, const std::vector<int32_t>& pool_size,\n                     const std::vector<int32_t>& strides, const std::string& padding_type,\n                     DimVector* out, std::vector<int32_t>* padding_before,\n                     std::vector<int32_t>* padding_after) {\n  Get3DOutputSize(in, pool_size, strides, padding_type, out, padding_before, padding_after,\n                  nullptr);\n}\n\nvoid Get3DOutputSize(const DimVector& in, const std::vector<int32_t>& pool_size,\n                     const std::vector<int32_t>& strides, const std::string& padding_type,\n                     DimVector* out, std::vector<int32_t>* padding_before,\n                     std::vector<int32_t>* padding_after, std::vector<int32_t>* dilation_rate) {\n  CHECK(out);\n  out->clear();\n  out->resize(3);\n  if (padding_before) {\n    padding_before->clear();\n    padding_before->resize(3);\n  }\n  if (padding_after) {\n    padding_after->clear();\n    padding_after->resize(3);\n  }\n  FOR_RANGE(size_t, i, 0, 3) {\n    int64_t* out_ptr = &(*out).at(i);\n    int32_t* padding_before_ptr = padding_before ? (&(*padding_before).at(i)) : nullptr;\n    int32_t* padding_after_ptr = padding_after ? (&(*padding_after).at(i)) : nullptr;\n    if (dilation_rate) {\n      GetWindowedOutputSize(in.at(i), pool_size.at(i), dilation_rate->at(i), strides.at(i),\n                            padding_type, out_ptr, padding_before_ptr, padding_after_ptr);\n    } else {\n      GetWindowedOutputSize(in.at(i), pool_size.at(i), strides.at(i), padding_type, out_ptr,\n                            padding_before_ptr, padding_after_ptr);\n    }\n  }\n}\n\nvoid Get3DOutputSize(const DimVector& in, const std::vector<int32_t>& pool_size,\n                     const std::vector<int32_t>& strides, const std::string& padding_type,\n                     const bool ceil_mode, std::vector<int32_t>* dilation_rate, DimVector* out,\n                     std::vector<int32_t>* padding_before, std::vector<int32_t>* padding_after) {\n  CHECK(out);\n  out->clear();\n  out->resize(3);\n  FOR_RANGE(size_t, i, 0, 3) {\n    int64_t* out_ptr = &(*out).at(i);\n    if (dilation_rate) {\n      GetWindowedOutputSize(in.at(i), pool_size.at(i), dilation_rate->at(i), strides.at(i),\n                            padding_type, ceil_mode, out_ptr, &(padding_before->at(i)),\n                            &(padding_after->at(i)));\n    } else {\n      GetWindowedOutputSize(in.at(i), pool_size.at(i), 1, strides.at(i), padding_type, ceil_mode,\n                            out_ptr, &(padding_before->at(i)), &(padding_after->at(i)));\n    }\n  }\n}\n\nvoid GetConvOutAndPad(const ShapeView& in_blob_shape, const PbMessage& conv_conf, DimVector* out,\n                      std::vector<int32_t>* pad_small_side, std::vector<int32_t>* pad_large_side) {\n  int32_t opkernel_dim = in_blob_shape.NumAxes() - 2;\n  if (out) { out->assign(opkernel_dim, 0); }\n  if (pad_small_side) { pad_small_side->assign(opkernel_dim, 0); }\n  if (pad_large_side) { pad_large_side->assign(opkernel_dim, 0); }\n  const auto& data_format = GetValFromPbMessage<std::string>(conv_conf, \"data_format\");\n  const std::string& padding = GetValFromPbMessage<std::string>(conv_conf, \"padding\");\n  const PbRf<int32_t>& dilation_rate = GetPbRfFromPbMessage<int32_t>(conv_conf, \"dilation_rate\");\n  const auto& strides = GetPbRfFromPbMessage<int32_t>(conv_conf, \"strides\");\n  const PbRf<int32_t>& kernel_size = 
GetPbRfFromPbMessage<int32_t>(conv_conf, \"kernel_size\");\n  FOR_RANGE(int32_t, i, 0, opkernel_dim) {\n    GetWindowedOutputSize(in_blob_shape.At(DhwOffset(data_format) + i), kernel_size.Get(i),\n                          dilation_rate.Get(i), strides.Get(i), padding,\n                          out ? &(out->at(i)) : nullptr,\n                          pad_small_side ? &(pad_small_side->at(i)) : nullptr,\n                          pad_large_side ? &(pad_large_side->at(i)) : nullptr);\n  }\n}\n\nvoid GetConvOutAndPad(const ShapeView& in_blob_shape, const user_op::UserOpConfWrapper& conv_conf,\n                      DimVector* out, std::vector<int32_t>* pad_small_side,\n                      std::vector<int32_t>* pad_large_side) {\n  int32_t opkernel_dim = in_blob_shape.NumAxes() - 2;\n  if (out) { out->assign(opkernel_dim, 0); }\n  if (pad_small_side) { pad_small_side->assign(opkernel_dim, 0); }\n  if (pad_large_side) { pad_large_side->assign(opkernel_dim, 0); }\n  const auto& data_format = conv_conf.attr<std::string>(\"data_format\");\n  const auto& padding = conv_conf.attr<std::string>(\"padding\");\n  const auto& strides = conv_conf.attr<std::vector<int32_t>>(\"strides\");\n  const auto& dilation_rate = conv_conf.attr<std::vector<int32_t>>(\"dilation_rate\");\n  const auto& kernel_size = conv_conf.attr<std::vector<int32_t>>(\"kernel_size\");\n  FOR_RANGE(int32_t, i, 0, opkernel_dim) {\n    GetWindowedOutputSize(in_blob_shape.At(DhwOffset(data_format) + i), kernel_size.at(i),\n                          dilation_rate.at(i), strides.at(i), padding,\n                          out ? &(out->at(i)) : nullptr,\n                          pad_small_side ? &(pad_small_side->at(i)) : nullptr,\n                          pad_large_side ? &(pad_large_side->at(i)) : nullptr);\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/operator_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_OPERATOR_UTIL_H_\n#define ONEFLOW_CORE_OPERATOR_OPERATOR_UTIL_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/common/protobuf.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nclass UserOpConfWrapper;\n}\n\nsize_t DhwOffset(const std::string& data_format);\n\nstd::vector<int32_t> Get3DVecInOpConf(const PbRf<int32_t>& field_vals, int32_t NDims);\n\nint64_t GetInDim(const ShapeView& shape, const std::string& data_format, int32_t dim, int32_t NDim);\n\nvoid GetWindowedOutputSize(int64_t input_size, int32_t filter_size, int32_t stride,\n                           const std::string& padding_type, int64_t* output_size,\n                           int32_t* padding_before, int32_t* padding_after);\n\nvoid GetWindowedOutputSize(int64_t input_size, int32_t filter_size, int32_t stride,\n                           const std::string& padding_type, int64_t* output_size,\n                           int32_t* padding_size);\n\nvoid GetWindowedOutputSize(int64_t input_size, int32_t filter_size, int32_t dilation_rate,\n                           int32_t stride, const std::string& padding_type, int64_t* output_size,\n                           int32_t* padding_before, int32_t* padding_after);\n\nvoid GetWindowedOutputSize(int64_t input_size, int32_t filter_size, int32_t dilation_rate,\n                           int32_t stride, const std::string& padding_type, bool ceil_mode,\n                           int64_t* output_size, int32_t* padding_before, int32_t* padding_after);\n\nvoid Get3DOutputSize(const DimVector& in, const std::vector<int32_t>& pool_size,\n                     const std::vector<int32_t>& strides, const std::string& padding_type,\n                     DimVector* out, std::vector<int32_t>* padding);\n\nvoid Get3DOutputSize(const DimVector& in, const std::vector<int32_t>& pool_size,\n                     const std::vector<int32_t>& strides, const std::string& padding_type,\n                     DimVector* out, std::vector<int32_t>* padding_before,\n                     std::vector<int32_t>* padding_after);\n\nvoid Get3DOutputSize(const DimVector& in, const std::vector<int32_t>& pool_size,\n                     const std::vector<int32_t>& strides, const std::string& padding_type,\n                     DimVector* out, std::vector<int32_t>* padding_before,\n                     std::vector<int32_t>* padding_after, std::vector<int32_t>* dilation_rate);\n\nvoid Get3DOutputSize(const DimVector& in, const std::vector<int32_t>& pool_size,\n                     const std::vector<int32_t>& strides, const std::string& padding_type,\n                     const bool ceil_mode, std::vector<int32_t>* dilation_rate, DimVector* out,\n                     std::vector<int32_t>* padding_before, std::vector<int32_t>* padding_after);\n\nvoid GetConvOutAndPad(const ShapeView& in_blob_shape, const 
PbMessage& conv_conf, DimVector* out,\n                      std::vector<int32_t>* pad_small_side, std::vector<int32_t>* pad_large_side);\n\nvoid GetConvOutAndPad(const ShapeView& in_blob_shape, const user_op::UserOpConfWrapper& conv_conf,\n                      DimVector* out, std::vector<int32_t>* pad_small_side,\n                      std::vector<int32_t>* pad_large_side);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_OPERATOR_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/operator/output_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/interface_op_util.h\"\n#include \"oneflow/core/operator/output_op.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n#include \"oneflow/core/job/env_desc.h\"\n\nnamespace oneflow {\n\nnamespace {\nMaybe<void> InferOutputOpNdSbpSignature(NdSbpSignature* nd_sbp_signature,\n                                        const ParallelDesc& parallel_desc,\n                                        const OperatorConf& op_conf) {\n  const InterfaceBlobConf& blob_conf = op_conf.output_conf().blob_conf();\n  NdSbp& in_nd_sbp = (*nd_sbp_signature->mutable_bn_in_op2nd_sbp())[\"in\"];\n  NdSbp& out_nd_sbp = (*nd_sbp_signature->mutable_bn_in_op2nd_sbp())[\"out\"];\n  JUST(InterfaceOpUtil::ParseNdSbpFromBlobConf(blob_conf, parallel_desc, &in_nd_sbp));\n  JUST(InterfaceOpUtil::ParseNdSbpFromBlobConf(blob_conf, parallel_desc, &out_nd_sbp));\n  return Maybe<void>::Ok();\n}\n}  // anonymous namespace\n\nMaybe<void> OutputOp::InitFromOpConf() {\n  CHECK(op_conf().has_output_conf());\n  EnrollInputBn(\"in\");\n  EnrollOutputBn(\"out\")->set_is_mutable(true);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> OutputOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  BlobDesc* out_blob_desc = BlobDesc4BnInOp(\"out\");\n  *out_blob_desc = *BlobDesc4BnInOp(\"in\");\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> OutputOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp(\"in\");\n  BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(\"out\");\n  // NOTE(chengcheng):\n  //   In multi-client, in blob shape maybe changed and NOT equal with output_conf.blob_conf,\n  //   and the output op actually is return op (used in single-client) with NO blob conf.\n  *out_blob_desc = *in_blob_desc;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> OutputOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {\n  SbpSignature* sbp = sbp_sig_list->mutable_sbp_signature()->Add();\n  CHECK_EQ_OR_RETURN(JUST(GetOpParallelDesc())->hierarchy()->NumAxes(), 1)\n      << \"Only support 1d sbp now.\";\n  // Get sbp from BlobConf\n  const InterfaceBlobConf& blob_conf = op_conf().output_conf().blob_conf();\n  // TODO: make sure blob_conf must set nd_sbp\n  CHECK_OR_RETURN(blob_conf.has_nd_sbp());\n  const SbpParallel& sbp_parallel = SbpParallel(blob_conf.nd_sbp().sbp_parallel(0));\n  if (sbp_parallel.has_broadcast_parallel()) {\n    SbpSignatureBuilder().Broadcast(\"in\").Broadcast(\"out\").Build(sbp);\n  } else if (sbp_parallel.has_partial_sum_parallel()) {\n    SbpSignatureBuilder().PartialSum(\"in\").PartialSum(\"out\").Build(sbp);\n  } else if (sbp_parallel.has_split_parallel()) {\n    int64_t split_axis = sbp_parallel.split_parallel().axis();\n    
SbpSignatureBuilder().Split(\"in\", split_axis).Split(\"out\", split_axis).Build(sbp);\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> OutputOp::GetNdSbpSignatureList(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {\n  NdSbpSignature nd_sbp_signature;\n  JUST(InferOutputOpNdSbpSignature(&nd_sbp_signature, parallel_desc, op_conf()));\n  nd_sbp_sig_list->emplace_back(nd_sbp_signature);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> OutputOp::InferSbpSignature(\n    SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n    const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n    std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n    const ParallelDesc& parallel_desc) const {\n  JUST(InterfaceOpUtil::GetOutputLikeOpSbpSignature(op_conf().output_conf().blob_conf(),\n                                                    input_bns(), output_bns(), sbp_signature));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> OutputOp::InferNdSbpSignature(\n    NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints,\n    const ParallelDesc& parallel_desc,\n    std::function<Maybe<const NdSbpInferHint*>(const std::string&)> NdSbpInferHint4Ibn) const {\n  JUST(InferOutputOpNdSbpSignature(nd_sbp_signature, parallel_desc, op_conf()));\n  return Maybe<void>::Ok();\n}\n\nSymbol<OperatorConf> OutputOp::GetOpConfWithoutOpNameAndLbn() const {\n  return SymbolOf(this->op_conf());\n}\n\nREGISTER_OP(OperatorConf::kOutputConf, OutputOp);\nREGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kOutputConf, 1);\nREGISTER_INTERFACE_OP(OperatorConf::kOutputConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/output_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_OUTPUT_OP_H_\n#define ONEFLOW_CORE_OPERATOR_OUTPUT_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass OutputOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(OutputOp);\n  OutputOp() = default;\n  ~OutputOp() override = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferSbpSignature(\n      SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n      const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n      std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n      const ParallelDesc& parallel_desc) const override;\n  Symbol<OperatorConf> GetOpConfWithoutOpNameAndLbn() const override;\n  Maybe<void> GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override;\n  Maybe<void> GetNdSbpSignatureList(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      const ParallelDesc& parallel_desc,\n      std::vector<NdSbpSignature>* nd_sbp_sig_list) const override;\n  Maybe<void> InferNdSbpSignature(NdSbpSignature* nd_sbp_signature,\n                                  const NdSbpSignature& nd_sbp_constraints,\n                                  const ParallelDesc& parallel_desc,\n                                  std::function<Maybe<const NdSbpInferHint*>(const std::string&)>\n                                      NdSbpInferHint4Ibn) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_OUTPUT_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/reduce_sbp_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/reduce_sbp_util.h\"\n\nnamespace oneflow {\n\nbool ReduceSbpUtil::IsReduceAxisSplitted(const SbpInferHint& ibn_hint,\n                                         const HashSet<int64_t>& reduced_axes) {\n  if (ibn_hint.sbp_parallel().has_split_parallel() == false) { return false; }\n  if (reduced_axes.empty()) { return true; }\n  return reduced_axes.find(ibn_hint.sbp_parallel().split_parallel().axis()) != reduced_axes.end();\n}\n\nstd::function<bool(int32_t)> ReduceSbpUtil::MakePredicatorIsReducedAxis(const PbRf<int32_t>& axes,\n                                                                        int32_t num_axes) {\n  HashSet<int32_t> axes_set = {axes.begin(), axes.end()};\n  return MakePredicatorIsReducedAxis(axes_set, num_axes);\n}\n\nstd::function<bool(int32_t)> ReduceSbpUtil::MakePredicatorIsReducedAxis(\n    const HashSet<int32_t>& axes, int32_t num_axes) {\n  auto axis_set = std::make_shared<HashSet<int32_t>>(axes);\n  return [axis_set](int32_t axis) -> bool { return axis_set->find(axis) != axis_set->end(); };\n}\n\nvoid ReduceSbpUtil::GetRegularAxes(int64_t num_axes, const std::vector<int32_t>& reduce_axes,\n                                   HashSet<int32_t>* axes) {\n  for (auto axis : reduce_axes) { axes->insert(ShiftNegativeAxis(axis, num_axes)); }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/reduce_sbp_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_REDUCE_SBP_UTIL_H_\n#define ONEFLOW_CORE_OPERATOR_REDUCE_SBP_UTIL_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/sbp_infer_hint.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nstruct ReduceSbpUtil final {\n  static bool IsReduceAxisSplitted(const SbpInferHint& ibn_hint,\n                                   const HashSet<int64_t>& reduced_axes);\n  static std::function<bool(int32_t)> MakePredicatorIsReducedAxis(const HashSet<int32_t>& axes,\n                                                                  int32_t num_axes);\n  static std::function<bool(int32_t)> MakePredicatorIsReducedAxis(const PbRf<int32_t>& axes,\n                                                                  int32_t num_axes);\n  static void GetRegularAxes(int64_t num_axes, const std::vector<int32_t>& reduce_axes,\n                             HashSet<int32_t>* axes);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_REDUCE_SBP_UTIL_H_\n"
  },
  {
    "path": "oneflow/core/operator/reentrant_lock_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/reentrant_lock_op.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nMaybe<void> ReentrantLockOp::InitFromOpConf() {\n  EnrollInputBn(\"start\", false);\n  if (op_conf().reentrant_lock_conf().has_end()) { EnrollInputBn(\"end\", false); }\n  EnrollOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  const BlobDesc* start = BlobDesc4BnInOp(\"start\");\n  const DataType data_type = start->data_type();\n  CHECK_OR_RETURN(IsIntegralDataType(data_type));\n  BlobDesc* out = BlobDesc4BnInOp(\"out\");\n  out->set_shape(Shape({1}));\n  out->set_data_type(data_type);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> ReentrantLockOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1);\n  return InferBlobDescs(BlobDesc4BnInOp);\n}\n\nMaybe<void> ReentrantLockOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1);\n  return InferBlobDescs(GetBlobDesc4BnInOp);\n}\n\nMaybe<void> ReentrantLockOp::GetSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    SbpSignatureList* sbp_sig_list) const {\n  return Maybe<void>::Ok();\n}\n\nREGISTER_CPU_OP(OperatorConf::kReentrantLockConf, ReentrantLockOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/reentrant_lock_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_REENTRANT_LOCK_OP_H_\n#define ONEFLOW_CORE_OPERATOR_REENTRANT_LOCK_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass ReentrantLockOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ReentrantLockOp);\n  ReentrantLockOp() = default;\n  ~ReentrantLockOp() override = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_REENTRANT_LOCK_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/return_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/return_op.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n#include \"oneflow/core/operator/interface_op_util.h\"\n\nnamespace oneflow {\n\nMaybe<void> ReturnOp::InitFromOpConf() {\n  CHECK(op_conf().has_return_conf());\n  EnrollInputBn(\"in\");\n  EnrollOutputBn(\"out\")->set_is_mutable(true);\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  *BlobDesc4BnInOp(\"out\") = *BlobDesc4BnInOp(\"in\");\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> ReturnOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  return InferBlobDescs(BlobDesc4BnInOp);\n}\n\nMaybe<void> ReturnOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  return InferBlobDescs(GetBlobDesc4BnInOp);\n}\n\nMaybe<void> ReturnOp::InferSbpSignature(\n    SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n    const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n    std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n    const ParallelDesc& parallel_desc) const {\n  const auto& in_sbp_infer_hint = *JUST(SbpInferHint4Ibn(\"in\"));\n  CHECK_EQ_OR_RETURN(in_sbp_infer_hint.parallel_desc().parallel_num(),\n                     parallel_desc.parallel_num());\n  if (in_sbp_infer_hint.sbp_parallel().has_partial_sum_parallel()) {\n    SbpSignatureBuilder().Broadcast(input_bns()).Broadcast(output_bns()).Build(sbp_signature);\n  } else {\n    auto* bn2sbp = sbp_signature->mutable_bn_in_op2sbp_parallel();\n    (*bn2sbp)[\"in\"] = in_sbp_infer_hint.sbp_parallel();\n    (*bn2sbp)[\"out\"] = in_sbp_infer_hint.sbp_parallel();\n  }\n  return Maybe<void>::Ok();\n}\n\nSymbol<OperatorConf> ReturnOp::GetOpConfWithoutOpNameAndLbn() const {\n  return SymbolOf(this->op_conf());\n}\n\nREGISTER_OP(OperatorConf::kReturnConf, ReturnOp);\nREGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kReturnConf, 1);\nREGISTER_INTERFACE_OP(OperatorConf::kReturnConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/return_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_RETURN_OP_H_\n#define ONEFLOW_CORE_OPERATOR_RETURN_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass ReturnOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ReturnOp);\n  ReturnOp() = default;\n  ~ReturnOp() override = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> InferSbpSignature(\n      SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n      const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n      std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n      const ParallelDesc& parallel_desc) const override;\n  Symbol<OperatorConf> GetOpConfWithoutOpNameAndLbn() const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_RETURN_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/scalar_op_base.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/scalar_op_base.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nMaybe<void> ScalarOpBase::InitFromOpConf() {\n  EnrollInputBn(\"in\");\n  EnrollInputBn(\"scalar\");\n  EnrollOutputBn(\"out\")->set_mutable_inplace_ibn(\"in\");\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ScalarOpBase::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp(\"in\");\n  const BlobDesc* scalar_blob_desc = GetBlobDesc4BnInOp(\"scalar\");\n  CHECK_EQ_OR_RETURN(in_blob_desc->data_type(), scalar_blob_desc->data_type());\n  CHECK_EQ_OR_RETURN(scalar_blob_desc->shape().elem_cnt(), 1);\n  BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(\"out\");\n  *out_blob_desc = *in_blob_desc;\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ScalarOpBase::GetSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    SbpSignatureList* sbp_sig_list) const {\n  const Shape& in_shape = JUST(LogicalBlobDesc4Ibn(\"in\")).shape();\n  FOR_RANGE(int64_t, i, 0, in_shape.NumAxes()) {\n    SbpSignatureBuilder().Split(\"in\", i).Broadcast(\"scalar\").Split(\"out\", i).Build(\n        sbp_sig_list->mutable_sbp_signature()->Add());\n  }\n  JUST(VirtualGetSbpSignatures(LogicalBlobDesc4Ibn, sbp_sig_list));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/scalar_op_base.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_SCALAR_OP_BASE_H_\n#define ONEFLOW_CORE_OPERATOR_SCALAR_OP_BASE_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass ScalarOpBase : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ScalarOpBase);\n  ScalarOpBase() = default;\n  ~ScalarOpBase() override = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n protected:\n  virtual Maybe<void> VirtualGetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const {\n    return Maybe<void>::Ok();\n  }\n\n private:\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_SCALAR_OP_BASE_H_\n"
  },
  {
    "path": "oneflow/core/operator/shape_elem_cnt_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/shape_elem_cnt_op.h\"\n#include \"oneflow/core/operator/reduce_sbp_util.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nHashSet<int32_t> GetInclusiveAxes(const ShapeElemCntOpConf& conf, int32_t num_axes) {\n  HashSet<int32_t> ret;\n  if (conf.has_exclude_axis_conf()) {\n    HashSet<int32_t> exclude_axes(conf.exclude_axis_conf().axis().begin(),\n                                  conf.exclude_axis_conf().axis().end());\n    FOR_RANGE(int32_t, i, 0, num_axes) {\n      if (exclude_axes.find(i) == exclude_axes.end()\n          && exclude_axes.find(i - num_axes) == exclude_axes.end()) {\n        ret.insert(i);\n      }\n    }\n  } else if (conf.has_include_axis_conf()) {\n    for (int32_t axis : conf.include_axis_conf().axis()) {\n      if (axis < 0) { axis += num_axes; }\n      CHECK_GE(axis, 0);\n      CHECK_LT(axis, num_axes);\n      ret.insert(axis);\n    }\n  } else if (conf.has_range_axis_conf()) {\n    TODO();\n  } else {\n    UNIMPLEMENTED();\n  }\n  return ret;\n}\n\n}  // namespace\n\nMaybe<void> ShapeElemCntOp::InitFromOpConf() {\n  EnrollInputBn(\"x\", false);\n  EnrollOutputBn(\"y\", false);\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const OperatorConf& op_conf,\n                           const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  BlobDesc4BnInOp(\"y\")->set_data_type(op_conf.shape_elem_cnt_conf().data_type());\n  BlobDesc4BnInOp(\"y\")->set_shape(Shape({}));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> ShapeElemCntOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  return InferBlobDescs(op_conf(), BlobDesc4BnInOp);\n}\n\nMaybe<void> ShapeElemCntOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp);\n}\n\nvoid ShapeElemCntOp::VirtualGenKernelConf(\n    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {\n  int32_t num_axes = GetBlobDesc4BnInOp(\"x\")->shape().NumAxes();\n  const HashSet<int32_t>& inclusive_axis =\n      GetInclusiveAxes(op_conf().shape_elem_cnt_conf(), num_axes);\n  *kernel_conf->mutable_shape_elem_cnt_conf()->mutable_axis() = {inclusive_axis.begin(),\n                                                                 inclusive_axis.end()};\n}\n\nMaybe<void> ShapeElemCntOp::GetSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    SbpSignatureList* sbp_sig_list) const {\n  int32_t num_axes = JUST(LogicalBlobDesc4Ibn(\"x\")).shape().NumAxes();\n  const auto& inclusive_axes = 
GetInclusiveAxes(op_conf().shape_elem_cnt_conf(), num_axes);\n  auto IsReducedAxis = ReduceSbpUtil::MakePredicatorIsReducedAxis(inclusive_axes, num_axes);\n  FOR_RANGE(int64_t, i, 0, num_axes) {\n    if (IsReducedAxis(i)) {\n      SbpSignatureBuilder()\n          .Split(input_bns(), i)\n          .PartialSum(output_bns())\n          .Build(sbp_sig_list->mutable_sbp_signature()->Add());\n    } else {\n      SbpSignatureBuilder()\n          .Split(input_bns(), i)\n          .Broadcast(output_bns())\n          .Build(sbp_sig_list->mutable_sbp_signature()->Add());\n    }\n  }\n  if (num_axes == 0) {\n    SbpSignatureBuilder()\n        .PartialSum(input_bns())\n        .PartialSum(output_bns())\n        .Build(sbp_sig_list->mutable_sbp_signature()->Add());\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP(OperatorConf::kShapeElemCntConf, ShapeElemCntOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/shape_elem_cnt_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_SHAPE_ELEM_CNT_H_\n#define ONEFLOW_CORE_OPERATOR_SHAPE_ELEM_CNT_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass ShapeElemCntOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ShapeElemCntOp);\n  ShapeElemCntOp() = default;\n  ~ShapeElemCntOp() override = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override;\n  void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n                            const ParallelContext*, KernelConf*) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_SHAPE_ELEM_CNT_H_\n"
  },
  {
    "path": "oneflow/core/operator/sink_tick_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/sink_tick_op.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nMaybe<void> SinkTickOp::InitFromOpConf() {\n  CHECK(op_conf().has_sink_tick_conf());\n  EnrollRepeatedInputBn(\"tick\", false);\n  EnrollOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  BlobDesc* blob_desc = BlobDesc4BnInOp(\"out\");\n  blob_desc->set_shape(Shape({1}));\n  blob_desc->set_data_type(DataType::kInt8);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> SinkTickOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  return InferBlobDescs(BlobDesc4BnInOp);\n}\n\nMaybe<void> SinkTickOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  return InferBlobDescs(GetBlobDesc4BnInOp);\n}\n\nMaybe<void> SinkTickOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {\n  SbpSignatureBuilder()\n      .Broadcast(input_bns())\n      .Broadcast(output_bns())\n      .Build(sbp_sig_list->mutable_sbp_signature()->Add());\n  return Maybe<void>::Ok();\n}\n\nREGISTER_CPU_OP(OperatorConf::kSinkTickConf, SinkTickOp);\nREGISTER_TICK_TOCK_OP(OperatorConf::kSinkTickConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/sink_tick_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_SINK_TICK_OP_H_\n#define ONEFLOW_CORE_OPERATOR_SINK_TICK_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass SinkTickOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SinkTickOp);\n  SinkTickOp() = default;\n  ~SinkTickOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_SINK_TICK_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/slice_boxing_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/register/tensor_slice_view.h\"\n\nnamespace oneflow {\n\nclass SliceBoxingOp : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SliceBoxingOp);\n  SliceBoxingOp() = default;\n  ~SliceBoxingOp() override = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n protected:\n  virtual const SliceBoxingConf& GetCustomizedBoxingConf() const = 0;\n  virtual void VirtualInitFromOpConf(){};\n\n private:\n  LogicalBlobId lbi4ibn(const std::string& input_bn) const override;\n  LogicalBlobId lbi4obn(const std::string& output_bn) const override;\n};\n\nclass SliceBoxingCopyOp final : public SliceBoxingOp {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SliceBoxingCopyOp);\n  SliceBoxingCopyOp() = default;\n  ~SliceBoxingCopyOp() override = default;\n\n private:\n  const SliceBoxingConf& GetCustomizedBoxingConf() const override {\n    return op_conf().slice_boxing_copy_conf().slice_boxing_conf();\n  }\n  Symbol<OperatorConf> GetOpConfWithoutOpNameAndLbn() const override;\n};\n\nclass SliceBoxingAddOp final : public SliceBoxingOp {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SliceBoxingAddOp);\n  SliceBoxingAddOp() = default;\n  ~SliceBoxingAddOp() override = default;\n\n private:\n  const SliceBoxingConf& GetCustomizedBoxingConf() const override {\n    return op_conf().slice_boxing_add_conf().slice_boxing_conf();\n  }\n  void VirtualInitFromOpConf() override;\n  Maybe<void> InferInternalBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx, const JobDesc* job_desc) const override;\n  Symbol<OperatorConf> GetOpConfWithoutOpNameAndLbn() const override;\n};\n\nMaybe<void> SliceBoxingOp::InitFromOpConf() {\n  EnrollRepeatedInputBn(\"in\", GetCustomizedBoxingConf().in_slice_size(), false);\n  EnrollOutputBn(\"out\");\n  VirtualInitFromOpConf();\n  return Maybe<void>::Ok();\n}\n\nLogicalBlobId SliceBoxingOp::lbi4ibn(const std::string& input_bn) const {\n  return GetCustomizedBoxingConf().lbi();\n}\n\nLogicalBlobId SliceBoxingOp::lbi4obn(const std::string& output_bn) const {\n  return GetCustomizedBoxingConf().lbi();\n}\n\nMaybe<void> SliceBoxingOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  const SliceBoxingConf& slice_boxing_conf = GetCustomizedBoxingConf();\n  const PbRpf<TensorSliceViewProto>& in_slice_proto = 
slice_boxing_conf.in_slice();\n  const TensorSliceViewProto& out_slice_proto = slice_boxing_conf.out_slice();\n  const BlobDesc* in_0 = GetBlobDesc4BnInOp(GenRepeatedBn(\"in\", 0));\n  const DataType data_type = in_0->data_type();\n  FOR_RANGE(int64_t, i, 1, input_bns().size()) {\n    const BlobDesc* in_i = GetBlobDesc4BnInOp(GenRepeatedBn(\"in\", i));\n    CHECK_EQ(in_i->data_type(), data_type);\n  }\n  FOR_RANGE(int64_t, i, 0, input_bns().size()) {\n    const BlobDesc* in_i = GetBlobDesc4BnInOp(GenRepeatedBn(\"in\", i));\n    const TensorSliceView in_i_slice(in_slice_proto.Get(i));\n    CHECK_EQ(in_i->shape().elem_cnt(), in_i_slice.shape().elem_cnt());\n  }\n  const TensorSliceView out_slice(out_slice_proto);\n  BlobDesc* out = GetBlobDesc4BnInOp(\"out\");\n  out->set_data_type(data_type);\n  if (slice_boxing_conf.has_out_shape()) {\n    const Shape out_shape(slice_boxing_conf.out_shape());\n    CHECK_EQ(out_shape.elem_cnt(), out_slice.shape().elem_cnt());\n    out->set_shape(out_shape);\n  } else {\n    out->set_shape(out_slice.shape());\n  }\n  return Maybe<void>::Ok();\n}\n\nSymbol<OperatorConf> SliceBoxingCopyOp::GetOpConfWithoutOpNameAndLbn() const {\n  OperatorConf op_conf(this->op_conf());\n  op_conf.set_name(\"undefined-op-name\");\n  CHECK(op_conf.has_slice_boxing_copy_conf());\n  auto* boxing_conf = op_conf.mutable_slice_boxing_copy_conf();\n  LogicalBlobId empty_logical_blob_id{};\n  *boxing_conf->mutable_slice_boxing_conf()->mutable_lbi() = empty_logical_blob_id;\n  return SymbolOf(op_conf);\n}\n\nvoid SliceBoxingAddOp::VirtualInitFromOpConf() { EnrollTmpBn(\"buf\"); }\n\nMaybe<void> SliceBoxingAddOp::InferInternalBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx, const JobDesc* job_desc) const {\n  *GetBlobDesc4BnInOp(\"buf\") = *GetBlobDesc4BnInOp(\"out\");\n  return Maybe<void>::Ok();\n}\n\nSymbol<OperatorConf> SliceBoxingAddOp::GetOpConfWithoutOpNameAndLbn() const {\n  OperatorConf op_conf(this->op_conf());\n  op_conf.set_name(\"undefined-op-name\");\n  CHECK(op_conf.has_slice_boxing_add_conf());\n  auto* boxing_conf = op_conf.mutable_slice_boxing_add_conf();\n  LogicalBlobId empty_logical_blob_id{};\n  *boxing_conf->mutable_slice_boxing_conf()->mutable_lbi() = empty_logical_blob_id;\n  return SymbolOf(op_conf);\n}\n\nREGISTER_OP(OperatorConf::kSliceBoxingCopyConf, SliceBoxingCopyOp);\nREGISTER_OP(OperatorConf::kSliceBoxingAddConf, SliceBoxingAddOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/source_tick_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/source_tick_op.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nMaybe<void> SourceTickOp::InitFromOpConf() {\n  CHECK(op_conf().has_source_tick_conf());\n  CHECK(op_conf().ctrl_in_op_name().empty());\n  EnrollOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SourceTickOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  BlobDesc* blob_desc = BlobDesc4BnInOp(\"out\");\n  blob_desc->set_shape(Shape({1}));\n  blob_desc->set_data_type(DataType::kInt8);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SourceTickOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1);\n  BlobDesc* blob_desc = GetBlobDesc4BnInOp(\"out\");\n  blob_desc->set_shape(Shape({1}));\n  blob_desc->set_data_type(DataType::kInt8);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SourceTickOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {\n  SbpSignatureBuilder().Broadcast(output_bns()).Build(sbp_sig_list->mutable_sbp_signature()->Add());\n  return Maybe<void>::Ok();\n}\n\nREGISTER_CPU_OP(OperatorConf::kSourceTickConf, SourceTickOp);\nREGISTER_TICK_TOCK_OP(OperatorConf::kSourceTickConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/source_tick_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_SOURCE_TICK_OP_H_\n#define ONEFLOW_CORE_OPERATOR_SOURCE_TICK_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass SourceTickOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SourceTickOp);\n  SourceTickOp() = default;\n  ~SourceTickOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_SOURCE_TICK_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/src_subset_tick_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nclass SrcSubsetTickOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SrcSubsetTickOp);\n  SrcSubsetTickOp() = default;\n  ~SrcSubsetTickOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override;\n};\n\nMaybe<void> SrcSubsetTickOp::InitFromOpConf() {\n  CHECK(op_conf().has_src_subset_tick_conf());\n  EnrollRepeatedInputBn(\"in\", false);\n  EnrollOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  BlobDesc* blob_desc = BlobDesc4BnInOp(\"out\");\n  blob_desc->set_shape(Shape({1}));\n  blob_desc->set_data_type(DataType::kInt8);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> SrcSubsetTickOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  return InferBlobDescs(BlobDesc4BnInOp);\n}\n\nMaybe<void> SrcSubsetTickOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  return InferBlobDescs(GetBlobDesc4BnInOp);\n}\n\nMaybe<void> SrcSubsetTickOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {\n  SbpSignatureBuilder()\n      .Broadcast(input_bns())\n      .Broadcast(output_bns())\n      .Build(sbp_sig_list->mutable_sbp_signature()->Add());\n  return Maybe<void>::Ok();\n}\n\nREGISTER_CPU_OP(OperatorConf::kSrcSubsetTickConf, SrcSubsetTickOp);\nREGISTER_TICK_TOCK_OP(OperatorConf::kSrcSubsetTickConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/sync_dynamic_resize_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const OperatorConf& op_conf,\n                           const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  const SyncDynamicResizeOpConf& conf = op_conf.sync_dynamic_resize_conf();\n  CHECK_EQ_OR_RETURN(conf.axis(), 0);\n  const BlobDesc* in = BlobDesc4BnInOp(\"in\");\n  const BlobDesc* size = BlobDesc4BnInOp(\"size\");\n  CHECK_EQ_OR_RETURN(size->shape().elem_cnt(), 1);\n  CHECK_OR_RETURN(IsIntegralDataType(size->data_type()));\n  BlobDesc* out = BlobDesc4BnInOp(\"out\");\n  *out = *in;\n  out->set_is_dynamic(true);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nclass SyncDynamicResizeOp : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SyncDynamicResizeOp);\n  SyncDynamicResizeOp() = default;\n  ~SyncDynamicResizeOp() override = default;\n\n  Maybe<void> InitFromOpConf() override {\n    EnrollInputBn(\"in\");\n    EnrollInputBn(\"size\", false);\n    EnrollOutputBn(\"out\")->set_header_infered_before_compute(false);\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override {\n    return InferBlobDescs(op_conf(), BlobDesc4BnInOp);\n  }\n\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override {\n    return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp);\n  }\n\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override {\n    return Maybe<void>::Ok();\n  }\n\n  void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n                            const ParallelContext* parallel_ctx,\n                            KernelConf* kernel_conf) const override {\n    kernel_conf->mutable_sync_dynamic_resize_conf()->set_size_data_type(\n        GetBlobDesc4BnInOp(\"size\")->data_type());\n  }\n};\n\nREGISTER_OP(OperatorConf::kSyncDynamicResizeConf, SyncDynamicResizeOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/tick_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/tick_op.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  BlobDesc* blob_desc = BlobDesc4BnInOp(\"out\");\n  blob_desc->set_shape(Shape({1}));\n  blob_desc->set_data_type(DataType::kInt8);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> TickOp::InitFromOpConf() {\n  CHECK(op_conf().has_tick_conf());\n  EnrollRepeatedInputBn(\"tick\", false);\n  EnrollOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> TickOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  return InferBlobDescs(BlobDesc4BnInOp);\n}\n\nMaybe<void> TickOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  return InferBlobDescs(GetBlobDesc4BnInOp);\n}\n\nMaybe<void> TickOp::GetSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    SbpSignatureList* sbp_sig_list) const {\n  return Maybe<void>::Ok();\n}\n\nREGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kTickConf, 2);\nREGISTER_OP(OperatorConf::kTickConf, TickOp);\nREGISTER_TICK_TOCK_OP(OperatorConf::kTickConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/tick_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_TICK_OP_H_\n#define ONEFLOW_CORE_OPERATOR_TICK_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass TickOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TickOp);\n  TickOp() = default;\n  ~TickOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      SbpSignatureList* sbp_sig_list) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_TICK_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/total_loss_instance_num_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/total_loss_instance_num_op.h\"\n\nnamespace oneflow {\n\nvoid TotalLossInstanceNumOp::VirtualInitFromOpConf() {\n  CHECK(op_conf().has_total_loss_instance_num_conf());\n}\n\nMaybe<void> TotalLossInstanceNumOp::VirtualInferBlobDescs(\n    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  for (const std::string& ibn : input_bns()) {\n    CHECK_OR_RETURN(*GetBlobDesc4BnInOp(ibn) == *GetBlobDesc4BnInOp(input_bns().Get(0)));\n  }\n  return Maybe<void>::Ok();\n}\n\nREGISTER_CPU_OP(OperatorConf::kTotalLossInstanceNumConf, TotalLossInstanceNumOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/total_loss_instance_num_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_TOTAL_LOSS_INSTANCE_NUM_OP_H_\n#define ONEFLOW_CORE_OPERATOR_TOTAL_LOSS_INSTANCE_NUM_OP_H_\n\n#include \"oneflow/core/operator/cwise_op.h\"\n\nnamespace oneflow {\n\nclass TotalLossInstanceNumOp final : public CWiseOp {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TotalLossInstanceNumOp);\n  TotalLossInstanceNumOp() = default;\n  ~TotalLossInstanceNumOp() = default;\n\n  void VirtualInitFromOpConf() override;\n  Maybe<void> VirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n                                    const ParallelContext* parallel_ctx) const override;\n\n private:\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_TOTAL_LOSS_INSTANCE_NUM_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/user_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/sbp_context.h\"\n#include \"oneflow/core/common/tensor_desc.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/operator/user_op.h\"\n#include \"oneflow/core/framework/infer_output_blob_time_shape_fn_context.h\"\n#include \"oneflow/core/framework/infer_nd_sbp_fn_context.h\"\n#include \"oneflow/core/framework/compute_complexity_fn_context.h\"\n#include \"oneflow/core/framework/get_nd_sbp_signature_list_context.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nBlobDesc* FindValidBlobDescOfBnsInOp(\n    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n    const PbRpf<std::string>& bn_in_ops) {\n  BlobDesc* valid = nullptr;\n  for (const std::string& bn_in_op : bn_in_ops) {\n    BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn_in_op);\n    if (blob_desc) {\n      const bool is_dynamic = blob_desc->is_dynamic();\n      if (valid == nullptr || is_dynamic) {\n        valid = blob_desc;\n        if (is_dynamic) { break; }\n      }\n    }\n  }\n  return valid;\n}\n\nuser_op::NaiveTensorDesc GenTensorDescFromBlobDesc(const BlobDesc* blob_desc) {\n  user_op::NaiveTensorDesc tensor_desc;\n  tensor_desc.set_shape(blob_desc->shape());\n  tensor_desc.set_stride(blob_desc->stride());\n  tensor_desc.set_data_type(blob_desc->data_type());\n  tensor_desc.set_memory_format(blob_desc->memory_format());\n  tensor_desc.set_is_dynamic(blob_desc->is_dynamic());\n  return tensor_desc;\n}\n\n}  // namespace\n\n// kernel registry context used in infer functions of user op\nclass UserOpKernelRegContext final : public user_op::KernelRegContext {\n public:\n  using ArgVec = std::vector<std::pair<std::string, int32_t>>;\n\n  explicit UserOpKernelRegContext(const UserOp* user_op,\n                                  std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n                                  const ParallelContext* parallel_ctx)\n      : user_op_conf_(user_op->op_conf()) {\n    const auto& op_conf = user_op->op_conf();\n    CHECK(op_conf.has_user_conf());\n\n    device_type_ = CHECK_JUST(DeviceType4DeviceTag(op_conf.device_tag()));\n    parallel_ctx_ = parallel_ctx;\n\n    auto InitInOrOut = [&](const PbMap<std::string, UserOpConf::ListString>& arg_map,\n                           ArgVec* arg_vec) {\n      for (auto it = arg_map.begin(); it != arg_map.end(); ++it) {\n        for (int32_t i = 0; i < it->second.s_size(); ++i) {\n          arg_vec->emplace_back(std::make_pair(it->first, i));\n        }\n      }\n    };\n    InitInOrOut(op_conf.user_conf().input(), &inputs_);\n    InitInOrOut(op_conf.user_conf().output(), &outputs_);\n\n    {\n#define INSERT_TO_ARG2TENSOR_DESC(prefix)                                                \\\n  for (const auto& bn : user_op->prefix##_bns()) {                                       \\\n    const BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn); 
                                 \\\n    if (!blob_desc) { continue; }                                                        \\\n    arg2tensor_desc_.emplace(GenUnRepeatedBn(bn), GenTensorDescFromBlobDesc(blob_desc)); \\\n  }\n\n      INSERT_TO_ARG2TENSOR_DESC(input)\n      INSERT_TO_ARG2TENSOR_DESC(output)\n      INSERT_TO_ARG2TENSOR_DESC(tmp)\n\n#undef INSERT_TO_ARG2TENSOR_DESC\n    }\n  }\n  ~UserOpKernelRegContext() = default;\n\n  DeviceType device_type() const override { return device_type_; }\n\n  const ParallelContext& parallel_ctx() const override { return *parallel_ctx_; }\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const override {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) { return nullptr; }\n    return &(it->second);\n  }\n  const ArgVec& inputs() const override { return inputs_; }\n  const ArgVec& outputs() const override { return outputs_; }\n\n  const user_op::UserOpConfWrapper& user_op_conf() const override { return user_op_conf_; }\n\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return user_op_conf().Attr4Name(attr_name);\n  }\n\n private:\n  const user_op::UserOpConfWrapper user_op_conf_;\n  ArgVec inputs_;\n  ArgVec outputs_;\n  DeviceType device_type_;\n  const ParallelContext* parallel_ctx_;\n  HashMap<std::pair<std::string, int32_t>, user_op::NaiveTensorDesc> arg2tensor_desc_;\n};\n\nclass UserOpInferContext final : public user_op::InferContext {\n public:\n  using ArgVec = std::vector<std::pair<std::string, int32_t>>;\n\n  UserOpInferContext(const UserOp* op, const ParallelContext* parallel_ctx, const JobDesc* job_desc,\n                     const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp)\n      : op_(op), parallel_ctx_(parallel_ctx), job_desc_(job_desc) {\n    bn2logical_tensor_desc_.reset(new HashMap<std::string, user_op::NaiveTensorDesc>());\n    auto InitTensorDesc = [&](const ArgVec& arg_vec, const PbRpf<std::string>& bns) {\n      CHECK_EQ(arg_vec.size(), bns.size());\n      for (int32_t i = 0; i < arg_vec.size(); ++i) {\n        const auto& bn_i = bns.Get(i);\n        BlobDesc* blob = GetBlobDesc4BnInOp(bns.Get(i));\n        CHECK(blob != nullptr) << bn_i;\n        arg2tensor_desc_.emplace(arg_vec.at(i), GenTensorDescFromBlobDesc(blob));\n      }\n    };\n    InitTensorDesc(op->inputs(), op->input_bns());\n    InitTensorDesc(op->outputs(), op->output_bns());\n  }\n  ~UserOpInferContext() override = default;\n\n  const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name,\n                                             int32_t index) const override {\n    return *TensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n  const user_op::TensorDesc& OutputTensorDesc(const std::string& arg_name,\n                                              int32_t index) const override {\n    return *TensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n  user_op::TensorDesc* MutOutputTensorDesc(const std::string& arg_name, int32_t index) override {\n    return MutTensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) { return nullptr; 
}\n    return &it->second;\n  }\n  user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) { return nullptr; };\n    return &(it->second);\n  }\n  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                               int32_t index) const override {\n    const std::string bn = GenRepeatedBn(arg_name, index);\n    const auto it = bn2logical_tensor_desc_->find(bn);\n    if (it != bn2logical_tensor_desc_->end()) {\n      return &it->second;\n    } else {\n      // Lazily fetch the logical blob desc and cache it; emplace returns the cached entry.\n      std::shared_ptr<const BlobDesc> blob_desc = CHECK_JUST(op_->GetLogicalBlobDesc4BnInOp(bn));\n      return &(bn2logical_tensor_desc_->emplace(bn, GenTensorDescFromBlobDesc(blob_desc.get()))\n                   .first->second);\n    }\n  }\n  const Shape& InputShape(const std::string& arg_name, int32_t index) const override {\n    return Shape4ArgNameAndIndex(arg_name, index);\n  }\n  const Shape& OutputShape(const std::string& arg_name, int32_t index) const override {\n    return Shape4ArgNameAndIndex(arg_name, index);\n  }\n  void SetOutputShape(const std::string& arg_name, int32_t index, const Shape& shape) override {\n    SetShape4ArgNameAndIndex(arg_name, index, shape);\n  }\n  const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) {\n      thread_local static Shape non_shape;\n      return non_shape;\n    };\n    return it->second.shape();\n  }\n  void SetShape4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                const Shape& shape) override {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) { return; };\n    return it->second.set_shape(shape);\n  }\n  const Stride& InputStride(const std::string& arg_name, int32_t index) const override {\n    return Stride4ArgNameAndIndex(arg_name, index);\n  }\n  const Stride& OutputStride(const std::string& arg_name, int32_t index) const override {\n    return Stride4ArgNameAndIndex(arg_name, index);\n  }\n  void SetOutputStride(const std::string& arg_name, int32_t index, const Stride& stride) override {\n    return SetStride4ArgNameAndIndex(arg_name, index, stride);\n  }\n  const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) {\n      thread_local static Stride non_stride;\n      return non_stride;\n    };\n    return it->second.stride();\n  }\n  void SetStride4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                 const Stride& stride) override {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) { return; };\n    return it->second.set_stride(stride);\n  }\n  DataType InputDType(const std::string& arg_name, int32_t index) const override {\n    return Dtype4ArgNameAndIndex(arg_name, index);\n  }\n  DataType OutputDType(const std::string& arg_name, int32_t index) const override {\n    return Dtype4ArgNameAndIndex(arg_name, index);\n  }\n  void SetOutputDType(const std::string& 
arg_name, int32_t index, DataType data_type) override {\n    return SetDtype4ArgNameAndIndex(arg_name, index, data_type);\n  }\n  DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) { return DataType::kInvalidDataType; };\n    return it->second.data_type();\n  }\n  void SetDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                DataType data_type) override {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) { return; };\n    return it->second.set_data_type(data_type);\n  }\n\n  MemoryFormat InputMemoryFormat(const std::string& arg_name, int32_t index) const override {\n    return MemoryFormat4ArgNameAndIndex(arg_name, index);\n  }\n  MemoryFormat OutputMemoryFormat(const std::string& arg_name, int32_t index) const override {\n    return MemoryFormat4ArgNameAndIndex(arg_name, index);\n  }\n  void SetOutputMemoryFormat(const std::string& arg_name, int32_t index,\n                             MemoryFormat memory_format) override {\n    return SetMemoryFormat4ArgNameAndIndex(arg_name, index, memory_format);\n  }\n  MemoryFormat MemoryFormat4ArgNameAndIndex(const std::string& arg_name,\n                                            int32_t index) const override {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) { return MemoryFormat::kContiguous; };\n    return it->second.memory_format();\n  }\n  void SetMemoryFormat4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                       MemoryFormat memory_format) override {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) { return; };\n    return it->second.set_memory_format(memory_format);\n  }\n\n  bool InputIsDynamic(const std::string& arg_name, int32_t index) const override {\n    return IsDynamic4ArgNameAndIndex(arg_name, index);\n  }\n  bool OutputIsDynamic(const std::string& arg_name, int32_t index) const override {\n    return IsDynamic4ArgNameAndIndex(arg_name, index);\n  }\n  void SetOutputIsDynamic(const std::string& arg_name, int32_t index, bool is_dynamic) override {\n    return SetIsDynamic4ArgNameAndIndex(arg_name, index, is_dynamic);\n  }\n  bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) { return false; };\n    return it->second.is_dynamic();\n  }\n  void SetIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                    bool is_dynamic) override {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) { return; };\n    return it->second.set_is_dynamic(is_dynamic);\n  }\n\n  const ArgVec& inputs() const override { return op_->inputs(); }\n  const ArgVec& outputs() const override { return op_->outputs(); }\n  const ParallelContext& parallel_ctx() const override { return *parallel_ctx_; };\n  const ParallelDesc& parallel_desc() const override {\n    return *CHECK_JUST(op_->GetOpParallelDesc());\n  };\n  const JobDesc* job_desc() const override {\n    CHECK_NOTNULL(job_desc_);\n    return job_desc_;\n  }\n\n  const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,\n    
                                             int32_t index) const override {\n    CHECK_EQ(CHECK_JUST(op_->GetOpParallelDesc())->hierarchy()->NumAxes(), 1);\n    const auto& bn2sbp = CHECK_JUST(op_->sbp_signature())->bn_in_op2sbp_parallel();\n    std::string bn = GenRepeatedBn(arg_name, index);\n    auto it = bn2sbp.find(bn);\n    CHECK(it != bn2sbp.end());\n    return it->second;\n  }\n\n  const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    const auto& bn2nd_sbp = CHECK_JUST(op_->nd_sbp_signature())->bn_in_op2nd_sbp();\n    std::string bn = GenRepeatedBn(arg_name, index);\n    auto it = bn2nd_sbp.find(bn);\n    CHECK(it != bn2nd_sbp.end());\n    return it->second;\n  }\n\n  int64_t parallel_num() const override {\n    return CHECK_JUST(op_->GetOpParallelDesc())->parallel_num();\n  }\n\n  const std::string& input(const std::string& arg_name, int32_t index) const override {\n    return user_op_conf().input(arg_name, index);\n  }\n  const std::string& output(const std::string& arg_name, int32_t index) const override {\n    return user_op_conf().output(arg_name, index);\n  }\n  bool has_input(const std::string& arg_name, int32_t index) const override {\n    return user_op_conf().has_input(arg_name, index);\n  }\n  bool has_output(const std::string& arg_name, int32_t index) const override {\n    return user_op_conf().has_output(arg_name, index);\n  }\n  int32_t input_size(const std::string& arg_name) const override {\n    return user_op_conf().input_size(arg_name);\n  }\n  int32_t output_size(const std::string& arg_name) const override {\n    return user_op_conf().output_size(arg_name);\n  }\n  const std::string& op_name() const override { return user_op_conf().op_name(); }\n  const std::string& op_type_name() const override { return user_op_conf().op_type_name(); }\n\n  const std::string& op_loc() const override { return op_->op_loc(); }\n\n private:\n  const user_op::UserOpConfWrapper& user_op_conf() const { return op_->user_op_conf(); }\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return user_op_conf().Attr4Name(attr_name);\n  }\n\n  const UserOp* op_;\n  const ParallelContext* parallel_ctx_;\n  const JobDesc* job_desc_;\n  HashMap<std::pair<std::string, int32_t>, user_op::NaiveTensorDesc> arg2tensor_desc_;\n  std::unique_ptr<HashMap<std::string, user_op::NaiveTensorDesc>> bn2logical_tensor_desc_;\n};\n\nclass UserOpSbpContext : public user_op::SbpContext {\n public:\n  using ArgVec = std::vector<std::pair<std::string, int32_t>>;\n\n  UserOpSbpContext(const UserOp* op, SbpSignatureList* sbp_sig_list,\n                   std::function<Maybe<const BlobDesc&>(const std::string&)> LogicalBlobDesc4Ibn,\n                   int32_t hierarchy_value)\n      : op_(op), sbp_sig_list_(sbp_sig_list), hierarchy_value_(hierarchy_value) {\n    const auto& user_op_conf = op->op_conf().user_conf();\n    for (auto it = user_op_conf.input().begin(); it != user_op_conf.input().end(); ++it) {\n      const std::string& arg_name = it->first;\n      for (int32_t i = 0; i < it->second.s_size(); ++i) {\n        const BlobDesc* blob = &CHECK_JUST(LogicalBlobDesc4Ibn(GenRepeatedBn(arg_name, i)));\n        arg2tensor_desc_.emplace(std::make_pair(arg_name, i), GenTensorDescFromBlobDesc(blob));\n      }\n    }\n  }\n  ~UserOpSbpContext() override = default;\n\n  const user_op::TensorDesc& LogicalTensorDesc4InputArgNameAndIndex(\n      const std::string& input_arg_name, int32_t index) const override 
{\n    auto it = arg2tensor_desc_.find(std::make_pair(input_arg_name, index));\n    CHECK(it != arg2tensor_desc_.end())\n        << \"Cannot find input_arg_name : \" << input_arg_name << \" input_arg_index : \" << index;\n    return it->second;\n  }\n  const ArgVec& inputs() const override { return op_->inputs(); }\n  const ArgVec& outputs() const override { return op_->outputs(); }\n  const user_op::UserOpConfWrapper& user_op_conf() const override { return op_->user_op_conf(); }\n\n  user_op::UserOpSbpSignatureBuilder NewBuilder() override {\n    return user_op::UserOpSbpSignatureBuilder(sbp_sig_list_);\n  }\n\n  DeviceType device_type() const override { return op_->device_type(); }\n\n  int64_t parallel_num() const override {\n    return CHECK_JUST(op_->GetOpParallelDesc())->parallel_num();\n  }\n\n  int64_t hierarchy_value() const override { return hierarchy_value_; }\n\n private:\n  const UserOp* op_;\n  SbpSignatureList* sbp_sig_list_;\n  HashMap<std::pair<std::string, int32_t>, user_op::NaiveTensorDesc> arg2tensor_desc_;\n  int32_t hierarchy_value_;\n};\n\nclass UserOpInferSbpSignatureFnContext : public user_op::InferSbpSignatureFnContext {\n public:\n  using ArgVec = std::vector<std::pair<std::string, int32_t>>;\n\n  UserOpInferSbpSignatureFnContext(\n      const UserOp* op, SbpSignature* signature, const SbpSignature& sbp_signature_conf,\n      std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn)\n      : op_(op),\n        signature_(signature),\n        sbp_signature_conf_(sbp_signature_conf),\n        sbp_infer_hint4ibn_fn_(std::move(SbpInferHint4Ibn)) {\n    const auto& user_op_conf = op->op_conf().user_conf();\n    for (const auto& it : user_op_conf.input()) {\n      const std::string& arg_name = it.first;\n      for (int32_t i = 0; i < it.second.s_size(); ++i) {\n        auto hint = CHECK_JUST(sbp_infer_hint4ibn_fn_(GenRepeatedBn(arg_name, i)));\n        arg2tensor_desc_.emplace(std::make_pair(arg_name, i),\n                                 GenTensorDescFromBlobDesc(&hint->logical_blob_desc()));\n        arg2sbp_parallel_hint_.emplace(std::make_pair(arg_name, i), hint->sbp_parallel());\n      }\n    }\n  }\n  ~UserOpInferSbpSignatureFnContext() override = default;\n\n  const user_op::TensorDesc& LogicalTensorDesc4InputArgNameAndIndex(\n      const std::string& input_arg_name, int32_t index) const override {\n    auto it = arg2tensor_desc_.find(std::make_pair(input_arg_name, index));\n    CHECK(it != arg2tensor_desc_.end())\n        << \"Cannot find input_arg_name : \" << input_arg_name << \" input_arg_index : \" << index;\n    return it->second;\n  }\n  const ArgVec& inputs() const override { return op_->inputs(); }\n  const ArgVec& outputs() const override { return op_->outputs(); }\n  SbpSignature* mutable_sbp_signature() override { return signature_; }\n  const SbpSignature& sbp_signature_conf() const override { return sbp_signature_conf_; }\n\n  const SbpParallel& SbpParallelHint4InputArgNameAndIndex(const std::string& input_arg_name,\n                                                          int32_t index) const override {\n    auto it = arg2sbp_parallel_hint_.find(std::make_pair(input_arg_name, index));\n    CHECK(it != arg2sbp_parallel_hint_.end())\n        << \"Cannot find input_arg_name : \" << input_arg_name << \" input_arg_index : \" << index;\n    return it->second;\n  }\n\n  const user_op::UserOpConfWrapper& user_op_conf() const override { return op_->user_op_conf(); }\n\n  DeviceType device_type() const override { return 
op_->device_type(); }\n\n  int64_t parallel_num() const override {\n    return CHECK_JUST(op_->GetOpParallelDesc())->parallel_num();\n  }\n\n private:\n  const UserOp* op_;\n  HashMap<std::pair<std::string, int32_t>, user_op::NaiveTensorDesc> arg2tensor_desc_;\n  HashMap<std::pair<std::string, int32_t>, SbpParallel> arg2sbp_parallel_hint_;\n  SbpSignature* signature_;\n  SbpSignature sbp_signature_conf_;\n  std::function<Maybe<const SbpInferHint*>(const std::string&)> sbp_infer_hint4ibn_fn_;\n};\n\nclass UserOpInferOutputBlobTimeShapeFnContext : public user_op::InferOutputBlobTimeShapeFnContext {\n public:\n  UserOpInferOutputBlobTimeShapeFnContext(\n      const UserOp* op,\n      const std::function<Maybe<const Shape>(const std::string&)>& GetTimeShape4BnInOp,\n      Shape* output_blob_time_shape)\n      : op_(op), output_blob_time_shape_(output_blob_time_shape) {\n    for (const auto& it : op->op_conf().user_conf().input()) {\n      const std::string& arg_name = it.first;\n      for (int32_t i = 0; i < it.second.s_size(); ++i) {\n        std::string ibn = GenRepeatedBn(arg_name, i);\n        arg2time_shape_.emplace(std::make_pair(arg_name, i), *CHECK_JUST(GetTimeShape4BnInOp(ibn)));\n      }\n    }\n  }\n  ~UserOpInferOutputBlobTimeShapeFnContext() override = default;\n\n  const Shape& TimeShape4InputArgNameAndIndex(const std::string& arg_name, int32_t index) override {\n    return arg2time_shape_.at(std::make_pair(arg_name, index));\n  }\n\n  const user_op::UserOpConfWrapper& user_op_conf() const override { return op_->user_op_conf(); }\n\n  Shape* mut_output_blob_time_shape() override { return output_blob_time_shape_; };\n\n private:\n  const UserOp* op_;\n  HashMap<std::pair<std::string, int32_t>, Shape> arg2time_shape_;\n  Shape* output_blob_time_shape_;\n};\n\nclass UserOpInferNdSbpFnContext : public user_op::InferNdSbpFnContext {\n public:\n  using ArgVec = std::vector<std::pair<std::string, int32_t>>;\n  UserOpInferNdSbpFnContext(\n      const UserOp* op, NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints,\n      std::function<Maybe<const NdSbpInferHint*>(const std::string&)> NdSbpInferHint4Ibn)\n      : op_(op),\n        nd_sbp_signature_(nd_sbp_signature),\n        nd_sbp_constraints_(nd_sbp_constraints),\n        nd_sbp_infer_hint4ibn_fn_(std::move(NdSbpInferHint4Ibn)) {\n    const auto& user_op_conf = op->op_conf().user_conf();\n    for (const auto& it : user_op_conf.input()) {\n      const std::string& arg_name = it.first;\n      for (int32_t i = 0; i < it.second.s_size(); ++i) {\n        auto hint = CHECK_JUST(nd_sbp_infer_hint4ibn_fn_(GenRepeatedBn(arg_name, i)));\n        CHECK(arg2tensor_desc_\n                  .emplace(std::make_pair(arg_name, i),\n                           GenTensorDescFromBlobDesc(&hint->logical_blob_desc()))\n                  .second);\n      }\n    }\n  }\n  ~UserOpInferNdSbpFnContext() override = default;\n\n  const user_op::TensorDesc& LogicalTensorDesc4InputArgNameAndIndex(\n      const std::string& input_arg_name, int32_t index) const override {\n    auto it = arg2tensor_desc_.find(std::make_pair(input_arg_name, index));\n    CHECK(it != arg2tensor_desc_.end())\n        << \"Cannot find input_arg_name : \" << input_arg_name << \" input_arg_index : \" << index;\n    return it->second;\n  }\n\n  const NdSbpSignature& nd_sbp_constraints() const override { return nd_sbp_constraints_; }\n\n  NdSbp* NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {\n    return 
&(*nd_sbp_signature_->mutable_bn_in_op2nd_sbp())[GenRepeatedBn(arg_name, index)];\n  }\n\n  const NdSbp& NdSbpHint4InputArgNameAndIndex(const std::string& arg_name,\n                                              int32_t index) const override {\n    auto hint = CHECK_JUST(nd_sbp_infer_hint4ibn_fn_(GenRepeatedBn(arg_name, index)));\n    return hint->nd_sbp();\n  }\n\n  const user_op::UserOpConfWrapper& user_op_conf() const override { return op_->user_op_conf(); }\n\n  int64_t parallel_num() const override {\n    return CHECK_JUST(op_->GetOpParallelDesc())->parallel_num();\n  }\n\n  const Shape& parallel_hierarchy() override {\n    return *(CHECK_JUST(op_->GetOpParallelDesc())->hierarchy());\n  }\n\n  const ArgVec& inputs() const override { return op_->inputs(); }\n  const ArgVec& outputs() const override { return op_->outputs(); }\n\n private:\n  const UserOp* op_;\n  HashMap<std::pair<std::string, int32_t>, user_op::NaiveTensorDesc> arg2tensor_desc_;\n  NdSbpSignature* nd_sbp_signature_;\n  NdSbpSignature nd_sbp_constraints_;\n  std::function<Maybe<const NdSbpInferHint*>(const std::string&)> nd_sbp_infer_hint4ibn_fn_;\n};\n\n// Store information for computing computation cost\n// TODO: Maybe this class could simplify\nclass UserOpComputeComplexityFnContext : public user_op::ComputeComplexityFnContext {\n public:\n  using ArgVec = std::vector<std::pair<std::string, int32_t>>;\n\n  UserOpComputeComplexityFnContext(\n      const OperatorConf& op_conf, const ParallelDesc& parallel_desc,\n      const NdSbpSignature* sbp_signature,\n      std::function<const BlobDesc&(const std::string& bn)> logical_blob_desc4bn)\n      : user_op::ComputeComplexityFnContext(user_op::UserOpConfWrapper(op_conf)),\n        parallel_desc_(parallel_desc),\n        sbp_signature_(sbp_signature) {\n    auto InitInOrOut = [&](const PbMap<std::string, UserOpConf::ListString>& arg_map,\n                           ArgVec* arg_vec) {\n      for (auto it = arg_map.begin(); it != arg_map.end(); ++it) {\n        const std::string& arg_name = it->first;\n        for (int32_t i = 0; i < it->second.s_size(); ++i) {\n          const BlobDesc& blob = logical_blob_desc4bn(GenRepeatedBn(arg_name, i));\n          auto key = std::make_pair(arg_name, i);\n          arg2tensor_desc_.emplace(key, GenTensorDescFromBlobDesc(&blob));\n          arg_vec->emplace_back(std::make_pair(arg_name, i));\n        }\n      }\n    };\n    InitInOrOut(op_conf.user_conf().input(), &inputs_);\n    InitInOrOut(op_conf.user_conf().output(), &outputs_);\n  }\n  ~UserOpComputeComplexityFnContext() override = default;\n\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) override {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) { return nullptr; };\n    return &(it->second);\n  }\n  const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) {\n      thread_local static Shape non_shape;\n      return non_shape;\n    };\n    return it->second.shape();\n  }\n  DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) { return DataType::kInvalidDataType; };\n    return it->second.data_type();\n  }\n  bool 
IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n    if (it == arg2tensor_desc_.end()) { return false; };\n    return it->second.is_dynamic();\n  }\n\n  const NdSbp NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    const auto& bn2sbp = sbp_signature_->bn_in_op2nd_sbp();\n    std::string bn = GenRepeatedBn(arg_name, index);\n    CHECK(bn2sbp.find(bn) != bn2sbp.end());\n    return sbp_signature_->bn_in_op2nd_sbp().at(bn);\n  }\n\n  const ArgVec& inputs() const override { return inputs_; }\n  const ArgVec& outputs() const override { return outputs_; }\n  const ParallelDesc& parallel_desc() const override { return parallel_desc_; };\n  const NdSbpSignature* GetNdSbpSignature() const override { return sbp_signature_; }\n\n private:\n  ArgVec inputs_;\n  ArgVec outputs_;\n  const ParallelDesc parallel_desc_;\n  const NdSbpSignature* sbp_signature_;\n  HashMap<std::pair<std::string, int32_t>, user_op::NaiveTensorDesc> arg2tensor_desc_;\n};\n\nclass UserOpGetNdSbpSignatureListContext : public user_op::GetNdSbpSignatureListContext {\n public:\n  UserOpGetNdSbpSignatureListContext(\n      const UserOp* op,\n      std::function<Maybe<const BlobDesc&>(const std::string&)> LogicalBlobDesc4Ibn,\n      const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list)\n      : user_op::GetNdSbpSignatureListContext(user_op::UserOpConfWrapper(op->user_op_conf())),\n        op_(op),\n        logical_blob_desc4ibn_(std::move(LogicalBlobDesc4Ibn)),\n        parallel_desc_(parallel_desc),\n        nd_sbp_sig_list_(nd_sbp_sig_list) {}\n  ~UserOpGetNdSbpSignatureListContext() override = default;\n\n  std::vector<NdSbpSignature>* MutNdSbpSignatureList() override { return nd_sbp_sig_list_; }\n\n  void AddNdSbpSignature(NdSbpSignature& nd_sbp_sig) override {\n    nd_sbp_sig_list_->emplace_back(nd_sbp_sig);\n  }\n\n  const Shape& parallel_hierarchy() override {\n    return *(CHECK_JUST(op_->GetOpParallelDesc())->hierarchy());\n  }\n\n  const Shape& BlobShape4InputArgNameAndIndex(const std::string& arg_name,\n                                              int32_t index) const override {\n    return CHECK_JUST(logical_blob_desc4ibn_(GenRepeatedBn(arg_name, index))).shape();\n  }\n\n private:\n  const UserOp* op_;\n  std::function<Maybe<const BlobDesc&>(const std::string&)> logical_blob_desc4ibn_;\n  const ParallelDesc parallel_desc_;\n  std::vector<NdSbpSignature>* nd_sbp_sig_list_;\n};\n\nMaybe<void> UserOp::InitFromOpConf() {\n  CHECK_OR_RETURN(op_conf().has_user_conf());\n  for (const auto& pair : op_conf().user_conf().input()) {\n    EnrollRepeatedInputBn(pair.first, pair.second.s_size());\n    for (int32_t i = 0; i < pair.second.s_size(); ++i) {\n      inputs_.emplace_back(std::make_pair(pair.first, i));\n    }\n  }\n  for (const auto& pair : op_conf().user_conf().output()) {\n    EnrollRepeatedOutputBn(pair.first, pair.second.s_size());\n    for (int32_t i = 0; i < pair.second.s_size(); ++i) {\n      outputs_.emplace_back(std::make_pair(pair.first, i));\n    }\n  }\n  EnrollTmpBn(GenRepeatedBn(\"tmp_buffer\", 0));\n  user_op_conf_.reset(new user_op::UserOpConfWrapper(shared_op_conf()));\n  val_ =\n      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_conf().user_conf().op_type_name());\n  if (val_ != nullptr) {\n    if (val_->input_arg_modify_fn) {\n      user_op::GetInputArgModifier GetInputArgModifierFn =\n          [&](const std::string& 
in_arg_name, int32_t in_arg_index) -> user_op::InputArgModifier* {\n        std::string ibn = GenRepeatedBn(in_arg_name, in_arg_index);\n        if (std::find(input_bns().begin(), input_bns().end(), ibn) != input_bns().end()) {\n          return MutInputBlobModifier4Ibn(ibn);\n        }\n        return nullptr;\n      };\n      JUST(val_->input_arg_modify_fn(GetInputArgModifierFn, *user_op_conf_));\n    }\n    if (val_->output_arg_modify_fn) {\n      user_op::GetOutputArgModifier GetOutputArgModifierFn =\n          [&](const std::string& out_arg_name,\n              int32_t out_arg_index) -> user_op::OutputArgModifier* {\n        std::string obn = GenRepeatedBn(out_arg_name, out_arg_index);\n        if (std::find(output_bns().begin(), output_bns().end(), obn) != output_bns().end()) {\n          return MutOutputBlobModifier4Obn(obn);\n        }\n        return nullptr;\n      };\n      JUST(val_->output_arg_modify_fn(GetOutputArgModifierFn, *user_op_conf_));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> UserOp::InferInternalBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx, const JobDesc* job_desc) const {\n  // tmp buffer size must be inferred after out shape/dtype\n  UserOpInferContext infer_ctx(this, parallel_ctx, job_desc, GetBlobDesc4BnInOp);\n  const user_op::OpKernelRegistryResult* kernel_reg_val =\n      JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(\n          op_conf().user_conf().op_type_name(),\n          UserOpKernelRegContext(this, GetBlobDesc4BnInOp, parallel_ctx)));\n  CHECK_OR_RETURN(kernel_reg_val != nullptr)\n      << \"cannot find op_type: \" << op_conf().user_conf().op_type_name() << \" in kernel registry!\";\n\n  size_t tmp_size = kernel_reg_val->infer_tmp_size_fn(&infer_ctx);\n  if (tmp_size > 0) {\n    BlobDesc* tmp_buffer_blob = GetBlobDesc4BnInOp(GenRepeatedBn(\"tmp_buffer\", 0));\n    CHECK_NOTNULL_OR_RETURN(tmp_buffer_blob);\n    tmp_buffer_blob->set_data_type(DataType::kChar);\n    tmp_buffer_blob->set_memory_format(MemoryFormat::kContiguous);\n    tmp_buffer_blob->set_shape(Shape({static_cast<int64_t>(tmp_size)}));\n    tmp_buffer_blob->set_stride(Stride({static_cast<int64_t>(tmp_size)}));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> UserOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  CHECK_OR_RETURN(val_ != nullptr)\n      << \"cannot find op_type: \" << op_conf().user_conf().op_type_name() << \" in op registry!\";\n  // The default method sets output blob desc attrs (such as Dtype, is_dynamic)\n  // from the first input blob desc (if there is one).\n  BlobDesc* first_in_blob_desc = FindValidBlobDescOfBnsInOp(BlobDesc4BnInOp, input_bns());\n  if (first_in_blob_desc) {\n    for (const std::string& obn : output_bns()) {\n      BlobDesc4BnInOp(obn)->CopyFrom(*first_in_blob_desc);\n    }\n  }\n\n  UserOpInferContext infer_ctx(this, nullptr, nullptr, BlobDesc4BnInOp);\n\n  CHECK_OR_RETURN(val_->data_type_infer_fn)\n      << \"No InferDataType function for \" << val_->op_type_name;\n  JUST(val_->data_type_infer_fn(&infer_ctx));\n  JUST(val_->logical_tensor_desc_infer_fn(&infer_ctx));\n  for (const auto& pair : infer_ctx.outputs()) {\n    BlobDesc* out_blob_desc = BlobDesc4BnInOp(GenRepeatedBn(pair.first, pair.second));\n    const user_op::TensorDesc& tensor_desc = infer_ctx.OutputTensorDesc(pair.first, pair.second);\n    
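// If the op declares non_contiguous_supported, the stride inferred by the op is kept;\n    // otherwise a compact (contiguous) stride is recomputed from the inferred shape.\n    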
out_blob_desc->set_data_type(tensor_desc.data_type());\n    out_blob_desc->set_memory_format(tensor_desc.memory_format());\n    out_blob_desc->set_shape(tensor_desc.shape());\n    if (val_->non_contiguous_supported) {\n      out_blob_desc->set_stride(tensor_desc.stride());\n    } else {\n      out_blob_desc->set_stride(Stride(out_blob_desc->shape()));\n    }\n    CHECK_EQ_OR_RETURN(out_blob_desc->stride().size(), out_blob_desc->shape().size())\n        << Error::RuntimeError() << \"stride and shape size mismatch since stride is \"\n        << out_blob_desc->stride().ToString() << \" but shape is \"\n        << out_blob_desc->shape().ToString();\n    out_blob_desc->set_is_dynamic(tensor_desc.is_dynamic());\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> UserOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  CHECK_OR_RETURN(val_ != nullptr)\n      << \"cannot find op_type: \" << op_conf().user_conf().op_type_name() << \" in op registry!\";\n  if (!val_->physical_tensor_desc_infer_fn) {\n    return Operator::InferOutBlobDescs(GetBlobDesc4BnInOp, parallel_ctx);\n  } else {\n    // The default method sets output blob desc attrs (such as Dtype, is_dynamic,\n    // is_tensor_list) from the first input blob desc (if there is one).\n    BlobDesc* first_in_blob_desc = FindValidBlobDescOfBnsInOp(GetBlobDesc4BnInOp, input_bns());\n    if (first_in_blob_desc) {\n      for (const std::string& obn : output_bns()) {\n        GetBlobDesc4BnInOp(obn)->CopyFrom(*first_in_blob_desc);\n      }\n    }\n    UserOpInferContext infer_ctx(this, parallel_ctx, nullptr, GetBlobDesc4BnInOp);\n\n    CHECK_OR_RETURN(val_->data_type_infer_fn)\n        << \"No InferDataType function for \" << val_->op_type_name;\n    JUST(val_->data_type_infer_fn(&infer_ctx));\n    JUST(val_->physical_tensor_desc_infer_fn(&infer_ctx));\n    for (const auto& pair : infer_ctx.outputs()) {\n      BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(GenRepeatedBn(pair.first, pair.second));\n      out_blob_desc->set_data_type(infer_ctx.OutputDType(pair.first, pair.second));\n      out_blob_desc->set_memory_format(infer_ctx.OutputMemoryFormat(pair.first, pair.second));\n      out_blob_desc->set_shape(infer_ctx.OutputShape(pair.first, pair.second));\n      if (val_->non_contiguous_supported) {\n        out_blob_desc->set_stride(infer_ctx.OutputStride(pair.first, pair.second));\n      } else {\n        out_blob_desc->set_stride(Stride(out_blob_desc->shape()));\n      }\n      CHECK_EQ_OR_RETURN(out_blob_desc->stride().size(), out_blob_desc->shape().size())\n          << Error::RuntimeError() << \"stride and shape size mismatch since stride is \"\n          << out_blob_desc->stride().ToString() << \" but shape is \"\n          << out_blob_desc->shape().ToString();\n      out_blob_desc->set_is_dynamic(infer_ctx.OutputIsDynamic(pair.first, pair.second));\n    }\n    return Maybe<void>::Ok();\n  }\n}\n\nMaybe<void> UserOp::InferInplaceObn2Ibn(\n    HashMap<std::string, std::string>* mut_inplace_obn2ibn,\n    HashMap<std::string, std::string>* con_inplace_obn2ibn,\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  UserOpInferContext infer_ctx(this, parallel_ctx, nullptr, GetBlobDesc4BnInOp);\n  
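// The kernel registration entry supplies the inplace proposal function used below.\n  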
const user_op::OpKernelRegistryResult* kernel_reg_val =\n      JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(\n          op_conf().user_conf().op_type_name(),\n          UserOpKernelRegContext(this, GetBlobDesc4BnInOp, parallel_ctx)));\n  CHECK_OR_RETURN(kernel_reg_val != nullptr)\n      << \"cannot find op_type: \" << op_conf().user_conf().op_type_name() << \" in kernel registry!\";\n  HashSet<std::string> bn_in_op_unique_check;\n  user_op::AddInplaceArgPair AddInplaceArgPairFn =\n      [&](const std::string& out_arg_name, int32_t out_arg_index, const std::string& in_arg_name,\n          int32_t in_arg_index, bool is_mutable) -> Maybe<void> {\n    std::string ibn = GenRepeatedBn(in_arg_name, in_arg_index);\n    std::string obn = GenRepeatedBn(out_arg_name, out_arg_index);\n    if (is_mutable) {\n      mut_inplace_obn2ibn->emplace(obn, ibn);\n    } else {\n      con_inplace_obn2ibn->emplace(obn, ibn);\n    }\n    CHECK_OR_RETURN(std::find(input_bns().begin(), input_bns().end(), ibn) != input_bns().end())\n        << \"Cannot find input_arg_name : \" << in_arg_name << \" input_arg_index : \" << in_arg_index\n        << \" in op_name: \" << op_conf().name();\n    CHECK_OR_RETURN(std::find(output_bns().begin(), output_bns().end(), obn) != output_bns().end())\n        << \"Cannot find output_arg_name : \" << out_arg_name\n        << \" output_arg_index : \" << out_arg_index << \" in op_name: \" << op_conf().name();\n\n    std::string repeated_ibn_err_msg =\n        \"Cannot repeatedly set inplace proposal for the same input arg : \" + in_arg_name\n        + \" index : \" + std::to_string(in_arg_index) + \" in op_name: \" + op_conf().name();\n    std::string repeated_obn_err_msg =\n        \"Cannot repeatedly set inplace proposal for the same output arg : \" + out_arg_name\n        + \" index : \" + std::to_string(out_arg_index) + \" in op_name: \" + op_conf().name();\n    CHECK_OR_RETURN(bn_in_op_unique_check.insert(ibn).second) << repeated_ibn_err_msg;\n    CHECK_OR_RETURN(bn_in_op_unique_check.insert(obn).second) << repeated_obn_err_msg;\n    return Maybe<void>::Ok();\n  };\n  JUST(kernel_reg_val->inplace_proposal_fn(infer_ctx, AddInplaceArgPairFn));\n  return Maybe<void>::Ok();\n}\n\nLogicalBlobId UserOp::lbi4ibn(const std::string& input_bn) const {\n  auto pair = GenUnRepeatedBn(input_bn);\n  return GenLogicalBlobId(op_conf().user_conf().input().at(pair.first).s(pair.second));\n}\n\nLogicalBlobId UserOp::lbi4obn(const std::string& output_bn) const {\n  // TODO: remove this workaround and use different lbi for input and output\n  const bool is_copy_hd = op_conf().user_conf().op_type_name() == \"copy_d2h\"\n                          || op_conf().user_conf().op_type_name() == \"copy_h2d\";\n  if (is_copy_hd) { return GenLogicalBlobId(op_conf().user_conf().input().at(\"in\").s(0)); }\n  auto pair = GenUnRepeatedBn(output_bn);\n  auto ret = GenLogicalBlobId(op_conf().user_conf().output().at(pair.first).s(pair.second));\n  CHECK_EQ(ret.op_name(), op_conf().name());\n  CHECK_EQ(ret.blob_name(), output_bn);\n  return ret;\n}\n\nMaybe<void> UserOp::InferSbpSignature(\n    SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n    const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n    std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n    const ParallelDesc& parallel_desc) const {\n  if (val_->sbp_signature_infer_fn) {\n    UserOpInferSbpSignatureFnContext ctx(this, sbp_signature, sbp_sig_conf, SbpInferHint4Ibn);\n    return val_->sbp_signature_infer_fn(&ctx);\n  } else {\n    
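// No custom InferSbpSignature function is registered; fall back to the generic\n    // implementation provided by Operator.\n    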
return Operator::InferSbpSignature(sbp_signature, sbp_sig_conf, CalcOrderValue4SbpSig,\n                                       SbpInferHint4Ibn, parallel_desc);\n  }\n}\n\nMaybe<void> UserOp::GetSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const {\n  CHECK_OR_RETURN(val_ != nullptr)\n      << \"cannot find op_type: \" << op_conf().user_conf().op_type_name() << \" in op registry!\";\n  UserOpSbpContext sbp_ctx(this, sbp_sig_list, LogicalBlobDesc4Ibn, hierarchy_value);\n  JUST(val_->get_sbp_fn(&sbp_ctx));\n  // Add Broadcast for source user op tick input\n  if (val_->op_def.input_size() == 1 && input_bns().size() == 1\n      && val_->op_def.input(0).name() == user_op::kUserSourceOpTickInputArgName) {\n    std::string tick_bn = GenRepeatedBn(user_op::kUserSourceOpTickInputArgName, 0);\n    CHECK_OR_RETURN(input_bns().Get(0) == tick_bn)\n        << \"user op_name: \" << op_conf().name()\n        << \" op_type_name: \" << op_conf().user_conf().op_type_name()\n        << \" set ERROR input arg name : \" << input_bns().Get(0) << \" because NO input in op def\";\n    for (auto& sbp_sig : *sbp_sig_list->mutable_sbp_signature()) {\n      auto* bn2sbp = sbp_sig.mutable_bn_in_op2sbp_parallel();\n      if (bn2sbp->find(tick_bn) == bn2sbp->end()) {\n        (*bn2sbp)[tick_bn].mutable_broadcast_parallel();\n      }\n    }\n  }\n  // Check that every input/output bn has an sbp signature set\n  for (const auto& sbp_sig : sbp_sig_list->sbp_signature()) {\n    const auto& bn2sbp = sbp_sig.bn_in_op2sbp_parallel();\n    for (const auto& ibn : input_bns()) {\n      auto pair = GenUnRepeatedBn(ibn);\n      CHECK_OR_RETURN(bn2sbp.find(ibn) != bn2sbp.end())\n          << \"In op_name: \" << op_conf().name()\n          << \" op_type_name: \" << op_conf().user_conf().op_type_name()\n          << \", input_arg_name : \" << pair.first << \" input_arg_index : \" << pair.second\n          << \" has NOT set its sbp signature\";\n    }\n    for (const auto& obn : output_bns()) {\n      auto pair = GenUnRepeatedBn(obn);\n      CHECK_OR_RETURN(bn2sbp.find(obn) != bn2sbp.end())\n          << \"In op_name: \" << op_conf().name()\n          << \" op_type_name: \" << op_conf().user_conf().op_type_name()\n          << \", output_arg_name : \" << pair.first << \" output_arg_index : \" << pair.second\n          << \" has NOT set its sbp signature\";\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<double> UserOp::GetComputeComplexity(\n    NdSbpSignature* sbp_signature,\n    std::function<const BlobDesc&(const std::string& bn)> logical_blob_desc4bn,\n    const ParallelDesc& parallel_desc) const {\n  if (val_->compute_complexity_fn) {\n    UserOpComputeComplexityFnContext user_op_compute_complexity_fn_context(\n        op_conf(), parallel_desc, sbp_signature, logical_blob_desc4bn);\n    return val_->compute_complexity_fn(&user_op_compute_complexity_fn_context);\n  } else {\n    return Operator::GetComputeComplexity(sbp_signature, logical_blob_desc4bn, parallel_desc);\n  }\n}\n\nOperator::DumpNdSbpSignatureForOpConfFn UserOp::GetDumpNdSbpSignatureForOpConfFn() const {\n  if (val_->dump_nd_sbp_signature_for_op_conf_fn) {\n    return val_->dump_nd_sbp_signature_for_op_conf_fn;\n  } else {\n    return Operator::GetDumpNdSbpSignatureForOpConfFn();\n  }\n}\n\nMaybe<void> UserOp::InferOpTimeShape(\n    const std::function<Maybe<const Shape>(const std::string&)>& GetTimeShape4BnInOp,\n    std::shared_ptr<const Shape>* time_shape) const {\n  if (val_->output_blob_time_shape_infer_fn) {\n    std::shared_ptr<Shape> op_time_shape(new Shape());\n    UserOpInferOutputBlobTimeShapeFnContext 
infer_output_blob_time_shape_fn_ctx(\n        this, GetTimeShape4BnInOp, op_time_shape.get());\n    *time_shape = op_time_shape;\n    return val_->output_blob_time_shape_infer_fn(&infer_output_blob_time_shape_fn_ctx);\n  } else {\n    return Operator::InferOpTimeShape(GetTimeShape4BnInOp, time_shape);\n  }\n}\n\nnamespace {\n\nbool IgnoreInferNdSbpFnWhenFlatHierarchy(const std::string& op_type_name) {\n  return (op_type_name == \"reshape\" || op_type_name == \"reshape_like\");\n}\n\n}  // namespace\n\nMaybe<void> UserOp::InferNdSbpSignature(\n    NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints,\n    const ParallelDesc& parallel_desc,\n    std::function<Maybe<const NdSbpInferHint*>(const std::string&)> NdSbpInferHint4Ibn) const {\n  if (val_->nd_sbp_infer_fn\n      && (parallel_desc.hierarchy()->NumAxes() > 1\n          || !IgnoreInferNdSbpFnWhenFlatHierarchy(this->user_op_conf().op_type_name()))) {\n    UserOpInferNdSbpFnContext ctx(this, nd_sbp_signature, nd_sbp_constraints, NdSbpInferHint4Ibn);\n    JUST(val_->nd_sbp_infer_fn(&ctx));\n  } else {\n    JUST(Operator::InferNdSbpSignature(nd_sbp_signature, nd_sbp_constraints, parallel_desc,\n                                       NdSbpInferHint4Ibn));\n  }\n  std::string tick_bn = GenRepeatedBn(user_op::kUserSourceOpTickInputArgName, 0);\n  if (std::find(input_bns().begin(), input_bns().end(), tick_bn) != input_bns().end()) {\n    auto* map = nd_sbp_signature->mutable_bn_in_op2nd_sbp();\n    if (map->count(tick_bn) == 0) {\n      auto* sbp_list = (*map)[tick_bn].mutable_sbp_parallel();\n      for (int i = 0; i < parallel_desc.hierarchy()->NumAxes(); ++i) {\n        sbp_list->Add()->mutable_broadcast_parallel();\n      }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> UserOp::EnumerateNdSbpSignatures(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {\n  if (val_->enumerate_nd_sbp_signatures_fn) {\n    NdSbpSignature empty_sbp_signature;\n    UserOpGetNdSbpSignatureListContext user_op_get_nd_sbp_list_context(\n        this, LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list);\n    return val_->enumerate_nd_sbp_signatures_fn(&user_op_get_nd_sbp_list_context);\n  } else {\n    return Operator::EnumerateNdSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list);\n  }\n}\n\nMaybe<void> UserOp::GetNdSbpSignatureList(\n    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n    const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {\n  if (val_->get_nd_sbp_list_fn) {\n    NdSbpSignature empty_sbp_signature;\n    UserOpGetNdSbpSignatureListContext user_op_get_nd_sbp_list_context(\n        this, LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list);\n    return val_->get_nd_sbp_list_fn(&user_op_get_nd_sbp_list_context);\n  } else {\n    JUST(Operator::GetNdSbpSignatureList(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list));\n  }\n  return Maybe<void>::Ok();\n}\n\nSymbol<OperatorConf> UserOp::GetOpConfWithoutOpNameAndLbn() const {\n  OperatorConf op_conf(this->op_conf());\n  op_conf.set_name(\"undefined-op-name\");\n  UserOpConf* user_op_conf = op_conf.mutable_user_conf();\n  for (auto& pair : *user_op_conf->mutable_input()) {\n    for (auto& str : *pair.second.mutable_s()) { str = \"undefined-op-name/undefined-ibn\"; }\n  }\n  for (auto& pair : *user_op_conf->mutable_output()) {\n    std::string prefix 
= \"undefined-op-name/\";\n    prefix += pair.first;\n    prefix += \"_\";\n    int i = 0;\n    for (auto& str : *pair.second.mutable_s()) { str = prefix + std::to_string(i++); }\n  }\n  return SymbolOf(op_conf);\n}\n\nvoid UserOp::VirtualGenKernelConf(\n    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {\n  auto user_conf = kernel_conf->mutable_user_conf();\n  ForEachBnInOp([&](const std::string& bn) {\n    const BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn);\n    if (blob_desc) { blob_desc->ToProto(&(*user_conf->mutable_bn_in_op2blob_desc())[bn]); }\n  });\n}\n\nconst user_op::UserOpConfWrapper& UserOp::user_op_conf() const {\n  CHECK(user_op_conf_);\n  return *user_op_conf_;\n}\n\nREGISTER_OP(OperatorConf::kUserConf, UserOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/user_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_USER_OP_H_\n#define ONEFLOW_CORE_OPERATOR_USER_OP_H_\n\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass UserOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(UserOp);\n  UserOp() = default;\n  ~UserOp() = default;\n\n  using ArgVec = std::vector<std::pair<std::string, int32_t>>;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferInternalBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx, const JobDesc* job_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferInplaceObn2Ibn(\n      HashMap<std::string, std::string>* mut_inplace_obn2ibn,\n      HashMap<std::string, std::string>* con_inplace_obn2ibn,\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n  Maybe<double> GetComputeComplexity(\n      NdSbpSignature* sbp_signature,\n      std::function<const BlobDesc&(const std::string& bn)> logical_blob_desc4bn,\n      const ParallelDesc& parallel_desc) const override;\n  Operator::DumpNdSbpSignatureForOpConfFn GetDumpNdSbpSignatureForOpConfFn() const override;\n  Symbol<OperatorConf> GetOpConfWithoutOpNameAndLbn() const override;\n  const user_op::UserOpConfWrapper& user_op_conf() const;\n  const ArgVec& inputs() const { return inputs_; }\n  const ArgVec& outputs() const { return outputs_; }\n\n private:\n  LogicalBlobId lbi4ibn(const std::string& input_bn) const override;\n  LogicalBlobId lbi4obn(const std::string& output_bn) const override;\n  Maybe<void> InferSbpSignature(\n      SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n      const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n      std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> GetSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const override;\n  Maybe<void> EnumerateNdSbpSignatures(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,\n      const ParallelDesc& parallel_desc,\n      std::vector<NdSbpSignature>* nd_sbp_sig_list) const override;\n  Maybe<void> GetNdSbpSignatureList(\n      const std::function<Maybe<const BlobDesc&>(const std::string&)>& 
LogicalBlobDesc4Ibn,\n      const ParallelDesc& parallel_desc,\n      std::vector<NdSbpSignature>* nd_sbp_sig_list) const override;\n  Maybe<void> InferOpTimeShape(\n      const std::function<Maybe<const Shape>(const std::string&)>& GetTimeShape4BnInOp,\n      std::shared_ptr<const Shape>* time_shape) const override;\n  Maybe<void> InferNdSbpSignature(NdSbpSignature* nd_sbp_signature,\n                                  const NdSbpSignature& nd_sbp_constraints,\n                                  const ParallelDesc& parallel_desc,\n                                  std::function<Maybe<const NdSbpInferHint*>(const std::string&)>\n                                      NdSbpInferHint4Ibn) const override;\n  void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,\n                            const ParallelContext* parallel_ctx,\n                            KernelConf* kernel_conf) const override;\n\n  const user_op::OpRegistryResult* val_;\n  std::unique_ptr<user_op::UserOpConfWrapper> user_op_conf_;\n  ArgVec inputs_;\n  ArgVec outputs_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_USER_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/variable_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/variable_op.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> ParseNdSbpFromConf(const VariableOpConf& conf, const ParallelDesc& parallel_desc,\n                               NdSbp* nd_sbp) {\n  const bool has_nd_sbp_conf = (conf.nd_sbp_size() != 0);\n  const int64_t num_axes = parallel_desc.hierarchy()->NumAxes();\n  if (has_nd_sbp_conf) { CHECK_EQ(conf.nd_sbp_size(), num_axes); }\n  nd_sbp->clear_sbp_parallel();\n  FOR_RANGE(int64_t, i, 0, num_axes) {\n    if (has_nd_sbp_conf) {\n      CHECK_OR_RETURN(ParseSbpParallelFromString(conf.nd_sbp(i), nd_sbp->add_sbp_parallel()));\n    } else {\n      nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> VariableOp::InitFromOpConf() {\n  CHECK(op_conf().has_variable_conf());\n  if (op_conf().variable_conf().has_tick()) { EnrollInputBn(\"tick\", false); }\n  bool is_trainable = op_conf().variable_conf().trainable();\n  EnrollOutputBn(\"out\", is_trainable)->set_is_mutable(true);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> VariableOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  const VariableOpConf& variable_conf = op_conf().variable_conf();\n  BlobDesc* out_blob_desc = BlobDesc4BnInOp(\"out\");\n  out_blob_desc->set_shape(Shape(variable_conf.shape()));\n  CHECK_OR_RETURN(variable_conf.has_data_type());\n  out_blob_desc->set_data_type(variable_conf.data_type());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> VariableOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  const VariableOpConf& variable_conf = op_conf().variable_conf();\n  const ParallelDesc& parallel_desc = *JUST(GetOpParallelDesc());\n  BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(\"out\");\n  CHECK_OR_RETURN(variable_conf.has_data_type());\n  out_blob_desc->set_data_type(variable_conf.data_type());\n  NdSbp nd_sbp;\n  JUST(ParseNdSbpFromConf(variable_conf, parallel_desc, &nd_sbp));\n  out_blob_desc->set_shape(\n      *JUST(GetPhysicalShape(Shape(variable_conf.shape()), nd_sbp, parallel_desc, *parallel_ctx)));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> VariableOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {\n  int64_t num_axes = op_conf().variable_conf().shape().dim_size();\n  for (int i = 0; i < num_axes; ++i) {\n    SbpSignatureBuilder()\n        .Broadcast(input_bns())\n        .Split(output_bns(), i)\n        .Build(sbp_sig_list->mutable_sbp_signature()->Add());\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> VariableOp::InferSbpSignature(\n    SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n    
const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n    std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n    const ParallelDesc& parallel_desc) const {\n  CHECK_EQ_OR_RETURN(parallel_desc.hierarchy()->NumAxes(), 1);\n  SbpSignatureBuilder sbp_sig_builder;\n  if (op_conf().variable_conf().nd_sbp_size() != 0) {\n    CHECK_EQ_OR_RETURN(op_conf().variable_conf().nd_sbp_size(), 1);\n    SbpParallel sbp_parallel;\n    CHECK_OR_RETURN(ParseSbpParallelFromString(op_conf().variable_conf().nd_sbp(0), &sbp_parallel));\n    if (sbp_parallel.has_split_parallel()) {\n      sbp_sig_builder.Split(output_bns(), sbp_parallel.split_parallel().axis());\n    } else {\n      sbp_sig_builder.Broadcast(output_bns());\n    }\n  } else {\n    sbp_sig_builder.Broadcast(output_bns());\n  }\n  sbp_sig_builder.Broadcast(input_bns()).Build(sbp_signature);\n  return Maybe<void>::Ok();\n}\n\nSymbol<OperatorConf> VariableOp::GetOpConfWithoutOpNameAndLbn() const {\n  return SymbolOf(this->op_conf());\n}\n\nMaybe<void> VariableOp::InferNdSbpSignature(\n    NdSbpSignature* nd_sbp_signature, const NdSbpSignature& nd_sbp_constraints,\n    const ParallelDesc& parallel_desc,\n    std::function<Maybe<const NdSbpInferHint*>(const std::string&)> NdSbpInferHint4Ibn) const {\n  const auto& parallel_hierarchy = parallel_desc.hierarchy();\n  const VariableOpConf& conf = this->op_conf().variable_conf();\n  NdSbp& out_nd_sbp = (*nd_sbp_signature->mutable_bn_in_op2nd_sbp())[\"out\"];\n  JUST(ParseNdSbpFromConf(conf, parallel_desc, &out_nd_sbp));\n  if (conf.has_tick()) {\n    NdSbp& tick_nd_sbp = (*nd_sbp_signature->mutable_bn_in_op2nd_sbp())[\"tick\"];\n    for (int64_t i = 0; i < parallel_hierarchy->NumAxes(); ++i) {\n      tick_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nOperator::DumpNdSbpSignatureForOpConfFn VariableOp::GetDumpNdSbpSignatureForOpConfFn() const {\n  return [](const NdSbpSignature& nd_sbp_sig, OperatorConf* op_conf) -> Maybe<void> {\n    CHECK_OR_RETURN(op_conf->has_variable_conf()) << \"VariableOp doesn't have variable_conf set\";\n    op_conf->mutable_variable_conf()->clear_nd_sbp();\n    const auto& nd_sbp = nd_sbp_sig.bn_in_op2nd_sbp().at(\"out\");\n\n    for (const auto& sbp_parallel : nd_sbp.sbp_parallel()) {\n      op_conf->mutable_variable_conf()->mutable_nd_sbp()->Add(SbpParallelToString(sbp_parallel));\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\nREGISTER_OP(OperatorConf::kVariableConf, VariableOp);\nREGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kVariableConf, 1);\nREGISTER_INTERFACE_OP(OperatorConf::kVariableConf);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/variable_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_VARIABLE_OP_H_\n#define ONEFLOW_CORE_OPERATOR_VARIABLE_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass VariableOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(VariableOp);\n  VariableOp() : Operator() {}\n  ~VariableOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferSbpSignature(\n      SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,\n      const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,\n      std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override;\n  Symbol<OperatorConf> GetOpConfWithoutOpNameAndLbn() const override;\n  Maybe<void> InferNdSbpSignature(NdSbpSignature* nd_sbp_signature,\n                                  const NdSbpSignature& nd_sbp_constraints,\n                                  const ParallelDesc& parallel_desc,\n                                  std::function<Maybe<const NdSbpInferHint*>(const std::string&)>\n                                      NdSbpInferHint4Ibn) const override;\n  DumpNdSbpSignatureForOpConfFn GetDumpNdSbpSignatureForOpConfFn() const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_VARIABLE_OP_H_\n"
  },
  {
    "path": "oneflow/core/operator/wait_and_send_ids_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/wait_and_send_ids_op.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n\nnamespace oneflow {\n\nMaybe<void> WaitAndSendIdsOp::InitFromOpConf() {\n  CHECK(op_conf().has_wait_and_send_ids_conf());\n  EnrollOutputBn(\"out\", false);\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nMaybe<void> InferBlobDescs(const OperatorConf& op_conf,\n                           const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {\n  BlobDesc4BnInOp(\"out\")->set_shape(Shape({1}));\n  BlobDesc4BnInOp(\"out\")->set_data_type(op_conf.wait_and_send_ids_conf().data_type());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> WaitAndSendIdsOp::InferLogicalOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n    const ParallelDesc& parallel_desc) const {\n  CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1);\n  return InferBlobDescs(op_conf(), BlobDesc4BnInOp);\n}\n\nMaybe<void> WaitAndSendIdsOp::InferOutBlobDescs(\n    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n    const ParallelContext* parallel_ctx) const {\n  CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1);\n  return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp);\n}\n\nMaybe<void> WaitAndSendIdsOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {\n  SbpSignatureBuilder().Broadcast(output_bns()).Build(sbp_sig_list->mutable_sbp_signature()->Add());\n  return Maybe<void>::Ok();\n}\n\nREGISTER_CPU_OP(OperatorConf::kWaitAndSendIdsConf, WaitAndSendIdsOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/operator/wait_and_send_ids_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_OPERATOR_WAIT_AND_SEND_IDS_OP_H_\n#define ONEFLOW_CORE_OPERATOR_WAIT_AND_SEND_IDS_OP_H_\n\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nclass WaitAndSendIdsOp final : public Operator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(WaitAndSendIdsOp);\n  WaitAndSendIdsOp() = default;\n  ~WaitAndSendIdsOp() = default;\n\n  Maybe<void> InitFromOpConf() override;\n  Maybe<void> InferLogicalOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,\n      const ParallelDesc& parallel_desc) const override;\n  Maybe<void> InferOutBlobDescs(\n      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,\n      const ParallelContext* parallel_ctx) const override;\n\n private:\n  Maybe<void> GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_OPERATOR_WAIT_AND_SEND_IDS_OP_H_\n"
  },
  {
    "path": "oneflow/core/persistence/binary_in_stream.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_H_\n#define ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nclass BinaryInStream {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BinaryInStream);\n  virtual ~BinaryInStream() = default;\n\n  // 0: success\n  // -1: eof\n  virtual int32_t Read(char* s, size_t n) = 0;\n\n  virtual uint64_t file_size() const = 0;\n  virtual uint64_t cur_file_pos() const = 0;\n  virtual void set_cur_file_pos(uint64_t val) = 0;\n  virtual bool IsEof() const = 0;\n\n protected:\n  BinaryInStream() = default;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_H_\n"
  },
  {
    "path": "oneflow/core/persistence/binary_in_stream_with_local_copy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/persistence/binary_in_stream_with_local_copy.h\"\n#include \"oneflow/core/persistence/binary_in_stream_without_local_copy.h\"\n#include \"oneflow/core/common/str_util.h\"\n\nnamespace oneflow {\n\nBinaryInStreamWithLocalCopy::BinaryInStreamWithLocalCopy(fs::FileSystem* fs,\n                                                         const std::string& file_path)\n    : once_read_(false) {\n  VLOG(3) << \"New BinaryInStreamWithLocalCopy \" << file_path;\n  in_stream_.reset(new BinaryInStreamWithoutLocalCopy(fs, file_path));\n  local_copy_path_ = JoinPath(FLAGS_log_dir, \"global_fs_buffer\", file_path);\n  out_stream_.reset(new PersistentOutStream(LocalFS(), local_copy_path_));\n  read_mthd_ = &BinaryInStreamWithLocalCopy::ReadAndWriteToLocal;\n}\n\nint32_t BinaryInStreamWithLocalCopy::ReadAndWriteToLocal(char* s, size_t n) {\n  if (Restart()) {\n    CopyToLocalFinish();\n    return Read(s, n);\n  } else {\n    int32_t ret = in_stream_->Read(s, n);\n    CHECK_EQ(ret, 0);\n    out_stream_->Write(s, n);\n    once_read_ = true;\n    return 0;\n  }\n}\n\nbool BinaryInStreamWithLocalCopy::Restart() {\n  return in_stream_->cur_file_pos() == 0 && once_read_;\n}\n\nvoid BinaryInStreamWithLocalCopy::CopyToLocalFinish() {\n  out_stream_.reset();\n  in_stream_.reset(new BinaryInStreamWithoutLocalCopy(LocalFS(), local_copy_path_));\n  read_mthd_ = &BinaryInStreamWithLocalCopy::ReadFromLocal;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/persistence/binary_in_stream_with_local_copy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_WITH_LOCAL_COPY_H_\n#define ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_WITH_LOCAL_COPY_H_\n\n#include \"oneflow/core/persistence/binary_in_stream.h\"\n#include \"oneflow/core/persistence/persistent_out_stream.h\"\n\nnamespace oneflow {\n\nclass BinaryInStreamWithLocalCopy final : public BinaryInStream {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BinaryInStreamWithLocalCopy);\n  BinaryInStreamWithLocalCopy() = delete;\n  ~BinaryInStreamWithLocalCopy() = default;\n\n  BinaryInStreamWithLocalCopy(fs::FileSystem* fs, const std::string& file_path);\n\n  int32_t Read(char* s, size_t n) override { return (this->*read_mthd_)(s, n); }\n\n  uint64_t file_size() const override { return in_stream_->file_size(); }\n  uint64_t cur_file_pos() const override { return in_stream_->cur_file_pos(); }\n  void set_cur_file_pos(uint64_t val) override { in_stream_->set_cur_file_pos(val); }\n  bool IsEof() const override { return in_stream_->IsEof(); }\n\n private:\n  int32_t ReadAndWriteToLocal(char* s, size_t n);\n  int32_t ReadFromLocal(char* s, size_t n) { return in_stream_->Read(s, n); }\n\n  bool Restart();\n  void CopyToLocalFinish();\n\n  bool once_read_;\n  std::unique_ptr<BinaryInStream> in_stream_;\n  std::string local_copy_path_;\n  std::unique_ptr<PersistentOutStream> out_stream_;\n  int32_t (BinaryInStreamWithLocalCopy::*read_mthd_)(char*, size_t);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_WITH_LOCAL_COPY_H_\n"
  },
  {
    "path": "oneflow/core/persistence/binary_in_stream_without_local_copy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/persistence/binary_in_stream_without_local_copy.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include <cstring>\n\nnamespace oneflow {\n\nint32_t BinaryInStreamWithoutLocalCopy::Read(char* s, size_t n) {\n  if (IsEof()) return -1;\n  CHECK_LE(cur_file_pos_ + n, file_size_);\n  file_->Read(cur_file_pos_, n, s);\n  cur_file_pos_ += n;\n  return 0;\n}\n\nBinaryInStreamWithoutLocalCopy::BinaryInStreamWithoutLocalCopy(fs::FileSystem* fs,\n                                                               const std::string& file_path)\n    : cur_file_pos_(0) {\n  fs->NewRandomAccessFile(file_path, &file_);\n  file_size_ = fs->GetFileSize(file_path);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/persistence/binary_in_stream_without_local_copy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_WITHOUT_LOCAL_COPY_H_\n#define ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_WITHOUT_LOCAL_COPY_H_\n\n#include \"oneflow/core/persistence/file_system.h\"\n#include \"oneflow/core/persistence/binary_in_stream.h\"\n\nnamespace oneflow {\n\nclass BinaryInStreamWithoutLocalCopy final : public BinaryInStream {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BinaryInStreamWithoutLocalCopy);\n  BinaryInStreamWithoutLocalCopy() = delete;\n  virtual ~BinaryInStreamWithoutLocalCopy() = default;\n\n  BinaryInStreamWithoutLocalCopy(fs::FileSystem*, const std::string& file_path);\n  int32_t Read(char* s, size_t n) override;\n\n  uint64_t file_size() const override { return file_size_; }\n  uint64_t cur_file_pos() const override { return cur_file_pos_; }\n  void set_cur_file_pos(uint64_t val) override { cur_file_pos_ = val; }\n  bool IsEof() const override { return cur_file_pos_ == file_size_; }\n\n private:\n  std::unique_ptr<fs::RandomAccessFile> file_;\n  uint64_t file_size_;\n  uint64_t cur_file_pos_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PERSISTENCE_BINARY_IN_STREAM_WITHOUT_LOCAL_COPY_H_\n"
  },
  {
    "path": "oneflow/core/persistence/file_system.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/persistence/file_system.h\"\n#include <errno.h>\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/job_set.pb.h\"\n#include \"oneflow/core/persistence/hadoop/hadoop_file_system.h\"\n#include \"oneflow/core/persistence/posix/posix_file_system.h\"\n#include \"oneflow/core/job/job_set.pb.h\"\n\nnamespace oneflow {\n\nnamespace fs {\n\nstd::string FileSystem::SplitRecursiveDir(const std::string& dirname,\n                                          std::vector<std::string>& sub_dirs) {\n  std::string remaining_dir = dirname;\n  while (!remaining_dir.empty()) {\n    bool status = FileExists(remaining_dir);\n    if (status) { break; }\n    // Basename returns \"\" for / ending dirs.\n    if (remaining_dir[remaining_dir.length() - 1] != '/') {\n      sub_dirs.emplace_back(Basename(remaining_dir));\n    }\n    remaining_dir = Dirname(remaining_dir);\n  }\n\n  // sub_dirs contains all the dirs to be created but in reverse order.\n  std::reverse(sub_dirs.begin(), sub_dirs.end());\n  return remaining_dir;\n}\n\nvoid FileSystem::CreateDirIfNotExist(const std::string& dirname) {\n  if (IsDirectory(dirname)) { return; }\n  CreateDir(dirname);\n}\n\nvoid FileSystem::RecursivelyCreateDirIfNotExist(const std::string& dirname) {\n  if (IsDirectory(dirname)) { return; }\n  // sub_dirs contains all the dirs to be created but in reverse order.\n  std::vector<std::string> sub_dirs;\n  std::string remaining_dir = SplitRecursiveDir(dirname, sub_dirs);\n\n  // Now create the directories.\n  std::string built_path = remaining_dir;\n  for (const std::string& sub_dir : sub_dirs) {\n    built_path = JoinPath(built_path, sub_dir);\n    CreateDirIfNotExist(built_path);\n  }\n}\n\nbool FileSystem::IsDirEmpty(const std::string& dirname) { return ListDir(dirname).empty(); }\n\nstd::string FileSystem::TranslateName(const std::string& name) const { return CleanPath(name); }\n\nvoid FileSystem::MakeEmptyDir(const std::string& dirname) {\n  if (IsDirectory(dirname)) { RecursivelyDeleteDir(dirname); }\n  RecursivelyCreateDir(dirname);\n}\n\nvoid FileSystem::RecursivelyDeleteDir(const std::string& dirname) {\n  CHECK(FileExists(dirname));\n  std::deque<std::string> dir_q;      // Queue for the BFS\n  std::vector<std::string> dir_list;  // List of all dirs discovered\n  dir_q.emplace_back(dirname);\n  // ret : Status to be returned.\n  // Do a BFS on the directory to discover all the sub-directories. Remove all\n  // children that are files along the way. 
Then cleanup and remove the\n  // directories in reverse order.\n  while (!dir_q.empty()) {\n    std::string dir = dir_q.front();\n    dir_q.pop_front();\n    dir_list.emplace_back(dir);\n    // GetChildren might fail if we don't have appropriate permissions.\n    std::vector<std::string> children = ListDir(dir);\n    for (const std::string& child : children) {\n      const std::string child_path = JoinPath(dir, child);\n      // If the child is a directory add it to the queue, otherwise delete it.\n      if (IsDirectory(child_path)) {\n        dir_q.emplace_back(child_path);\n      } else {\n        // Delete file might fail because of permissions issues or might be\n        // unimplemented.\n        DelFile(child_path);\n      }\n    }\n  }\n  // Now reverse the list of directories and delete them. The BFS ensures that\n  // we can delete the directories in this order.\n  std::reverse(dir_list.begin(), dir_list.end());\n  for (const std::string& dir : dir_list) {\n    // Delete dir might fail because of permissions issues or might be\n    // unimplemented.\n    DeleteDir(dir);\n  }\n}\n\nvoid FileSystem::RecursivelyCreateDir(const std::string& dirname) {\n  // sub_dirs contains all the dirs to be created but in reverse order.\n  std::vector<std::string> sub_dirs;\n  std::string remaining_dir = SplitRecursiveDir(dirname, sub_dirs);\n\n  // Now create the directories.\n  std::string built_path = remaining_dir;\n  for (const std::string& sub_dir : sub_dirs) {\n    built_path = JoinPath(built_path, sub_dir);\n    CreateDir(built_path);\n  }\n}\n\n}  // namespace fs\n\nvoid CreateLocalFS(std::unique_ptr<fs::FileSystem>& fs) {\n#ifdef OF_PLATFORM_POSIX\n  fs.reset(new fs::PosixFileSystem);\n#else\n  OF_UNIMPLEMENTED();\n#endif\n}\n\nvoid CreateHadoopFS(std::unique_ptr<fs::FileSystem>& fs, const std::string& namenode) {\n  fs.reset(new fs::HadoopFileSystem(namenode));\n}\n\nvoid CreateFileSystemFromEnv(std::unique_ptr<fs::FileSystem>& fs, const std::string& env_prefix) {\n  CHECK(!fs);\n\n  auto fs_type_env = env_prefix + \"_TYPE\";\n  const char* fs_type = std::getenv(fs_type_env.c_str());\n  std::string fs_type_str;\n  if (fs_type) {\n    fs_type_str = ToLower(fs_type);\n  } else {\n    // local file system by default\n    fs_type_str = \"local\";\n  }\n\n  if (fs_type_str == \"local\") {\n    CreateLocalFS(fs);\n  } else if (fs_type_str == \"hdfs\") {\n    auto hdfs_nn_env = env_prefix + \"_HDFS_NAMENODE\";\n    const char* hdfs_namenode = std::getenv(hdfs_nn_env.c_str());\n    if (hdfs_namenode == nullptr) {\n      LOG(FATAL) << \"env \" << hdfs_nn_env << \" must be set when \" << fs_type_env\n                 << \" is set to hdfs\";\n    }\n    CreateHadoopFS(fs, hdfs_namenode);\n  } else {\n    LOG(FATAL) << \"invalid value \" << fs_type << \" of env \" << fs_type_env;\n  }\n}\n\nfs::FileSystem* DataFS() {\n  static std::unique_ptr<fs::FileSystem> data_fs;\n  static std::mutex data_fs_mutex;\n  {\n    std::lock_guard<std::mutex> lock(data_fs_mutex);\n    if (!data_fs) { CreateFileSystemFromEnv(data_fs, \"ONEFLOW_DATA_FILE_SYSTEM\"); }\n  }\n  return data_fs.get();\n}\n\nfs::FileSystem* SnapshotFS() {\n  static std::unique_ptr<fs::FileSystem> snapshot_fs;\n  static std::mutex snapshot_fs_mutex;\n  {\n    std::lock_guard<std::mutex> lock(snapshot_fs_mutex);\n    if (!snapshot_fs) { CreateFileSystemFromEnv(snapshot_fs, \"ONEFLOW_SNAPSHOT_FILE_SYSTEM\"); }\n  }\n  return snapshot_fs.get();\n}\n\nfs::FileSystem* LocalFS() {\n  static std::unique_ptr<fs::FileSystem> local_fs;\n  static 
std::mutex local_fs_mutex;\n  {\n    std::lock_guard<std::mutex> lock(local_fs_mutex);\n    if (!local_fs) { CreateLocalFS(local_fs); }\n  }\n  return local_fs.get();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/persistence/file_system.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PERSISTENCE_FILE_SYSTEM_H_\n#define ONEFLOW_CORE_PERSISTENCE_FILE_SYSTEM_H_\n\n#include \"oneflow/core/common/platform.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace fs {\n\n// A file abstraction for randomly reading the contents of a file.\nclass RandomAccessFile {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RandomAccessFile);\n  RandomAccessFile() = default;\n  virtual ~RandomAccessFile() = default;\n\n  // Reads `n` bytes from the file starting at `offset`.\n  // Sets `*result` to the data that was read.\n  //\n  // Safe for concurrent use by multiple threads.\n  virtual void Read(uint64_t offset, size_t n, char* result) const = 0;\n\n private:\n};\n\n//  A file abstraction for sequential writing.\n//\n// The implementation must provide buffering since callers may append\n// small fragments at a time to the file.\nclass WritableFile {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(WritableFile);\n  WritableFile() = default;\n  virtual ~WritableFile() = default;\n\n  // Append 'data' to the file.\n  virtual void Append(const char* data, size_t n) = 0;\n\n  // Close the file.\n  //\n  // Flush() and de-allocate resources associated with this file\n  virtual void Close() = 0;\n\n  //  Flushes the file and optionally syncs contents to filesystem.\n  //\n  // This should flush any local buffers whose contents have not been\n  // delivered to the filesystem.\n  //\n  // If the process terminates after a successful flush, the contents\n  // may still be persisted, since the underlying filesystem may\n  // eventually flush the contents.  If the OS or machine crashes\n  // after a successful flush, the contents may or may not be\n  // persisted, depending on the implementation.\n  virtual void Flush() = 0;\n\n private:\n};\n\nclass FileSystem {\n public:\n  virtual ~FileSystem() = default;\n\n  // Creates a brand new random access read-only file with the\n  // specified name.\n  //\n  // On success, stores a pointer to the new file in\n  // *result.  On failure stores NULL in *result.\n  //\n  // The returned file may be concurrently accessed by multiple threads.\n  //\n  // The ownership of the returned RandomAccessFile is passed to the caller\n  // and the object should be deleted when is not used.\n  virtual void NewRandomAccessFile(const std::string& fname,\n                                   std::unique_ptr<RandomAccessFile>* result) = 0;\n\n  // Creates an object that writes to a new file with the specified\n  // name.\n  //\n  // Deletes any existing file with the same name and creates a\n  // new file.  On success, stores a pointer to the new file in\n  // *result.  
On failure stores NULL in *result.\n  //\n  // The returned file will only be accessed by one thread at a time.\n  //\n  // The ownership of the returned WritableFile is passed to the caller\n  // and the object should be deleted when it is no longer used.\n  virtual void NewWritableFile(const std::string& fname, std::unique_ptr<WritableFile>* result) = 0;\n\n  // Creates an object that either appends to an existing file, or\n  // writes to a new file (if the file does not exist to begin with).\n  //\n  // On success, stores a pointer to the new file in *result.\n  // On failure stores NULL in *result.\n  //\n  // The returned file will only be accessed by one thread at a time.\n  //\n  // The ownership of the returned WritableFile is passed to the caller\n  // and the object should be deleted when it is no longer used.\n  virtual void NewAppendableFile(const std::string& fname,\n                                 std::unique_ptr<WritableFile>* result) = 0;\n\n  // Returns true if the named path exists and false otherwise.\n  virtual bool FileExists(const std::string& fname) = 0;\n\n  // Returns the immediate children of `dir`.\n  //\n  // The returned paths are relative to 'dir'.\n  virtual std::vector<std::string> ListDir(const std::string& dir) = 0;\n\n  // Deletes the named file.\n  // Named DelFile to avoid clashing with the Windows DeleteFile macro.\n  virtual void DelFile(const std::string& fname) = 0;\n\n  // Creates the specified directory.\n  virtual void CreateDir(const std::string& dirname) = 0;\n  virtual void CreateDirIfNotExist(const std::string& dirname);\n  virtual void RecursivelyCreateDir(const std::string& dirname);\n  void RecursivelyCreateDirIfNotExist(const std::string& dirname);\n\n  // Empty-directory helpers.\n  bool IsDirEmpty(const std::string& dirname);\n  void MakeEmptyDir(const std::string& dirname);\n\n  // Deletes the specified directory.\n  virtual void DeleteDir(const std::string& dirname) = 0;\n\n  // Deletes the specified directory and all subdirectories and files\n  // underneath it.\n  virtual void RecursivelyDeleteDir(const std::string& dirname);\n\n  // Returns the size of `fname`.\n  virtual uint64_t GetFileSize(const std::string& fname) = 0;\n\n  // Overwrites the target if it exists.\n  virtual void RenameFile(const std::string& old_name, const std::string& new_name) = 0;\n\n  // Translate a URI to a filename for the FileSystem implementation.\n  //\n  // The implementation in this class cleans up the path, removing\n  // duplicate /'s, resolving .. and . (more details in\n  // str_util.h CleanPath).\n  virtual std::string TranslateName(const std::string& name) const;\n\n  // Returns whether the given path is a directory or not.\n  virtual bool IsDirectory(const std::string& fname) = 0;\n\n protected:\n  FileSystem() = default;\n\n private:\n  std::string SplitRecursiveDir(const std::string& dirname, std::vector<std::string>& sub_dirs);\n};\n\n}  // namespace fs\n\nfs::FileSystem* LocalFS();\nfs::FileSystem* DataFS();\nfs::FileSystem* SnapshotFS();\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PERSISTENCE_FILE_SYSTEM_H_\n"
  },
  {
    "path": "oneflow/core/persistence/file_system_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include \"oneflow/core/common/process_state.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/persistence/posix/posix_file_system.h\"\n\nnamespace oneflow {\n\nnamespace fs {\n\nvoid TestFileOperation(FileSystem* file_system) {\n  std::string current_dir = GetCwd();\n  StringReplace(&current_dir, '\\\\', '/');\n  std::string file_name = JoinPath(current_dir, \"/tmp_test_file_asdfasdf\");\n  // write\n  std::unique_ptr<WritableFile> writable_file;\n  file_system->NewWritableFile(file_name, &writable_file);\n  std::string write_content = \"oneflow-file-system-test\";\n  writable_file->Append(write_content.substr(0, 10).c_str(), 10);\n  writable_file->Flush();\n  writable_file->Append(write_content.substr(10, 14).c_str(), 14);\n  writable_file->Close();\n  // write append\n  std::string append_content = \"append-text\";\n  std::unique_ptr<WritableFile> appendable_file;\n  file_system->NewAppendableFile(file_name, &appendable_file);\n  appendable_file->Append(append_content.c_str(), 11);\n  appendable_file->Flush();\n  appendable_file->Close();\n  // rename\n  std::string new_file_name = file_name + \"_new\";\n  file_system->RenameFile(file_name, new_file_name);\n  file_system->RenameFile(new_file_name, file_name);\n  // read\n  std::unique_ptr<RandomAccessFile> random_access_file;\n  file_system->NewRandomAccessFile(file_name, &random_access_file);\n  uint64_t file_size = file_system->GetFileSize(file_name);\n  ASSERT_EQ(file_size, 35);\n  char* read_array = new char[file_size];\n  random_access_file->Read(0, file_size, read_array);\n  std::string read_content(read_array, file_size);\n  ASSERT_EQ(write_content + append_content, read_content);\n  file_system->DelFile(file_name);\n  delete[] read_array;\n}\n\nvoid TestDirOperation(FileSystem* file_system) {\n  std::string current_dir = GetCwd();\n  StringReplace(&current_dir, '\\\\', '/');\n  std::string test_root_path = JoinPath(current_dir, \"/tmp_test_dir_asdfasdf\");\n  if (file_system->IsDirectory(test_root_path)) {\n    ASSERT_TRUE(file_system->ListDir(test_root_path).empty());\n  } else {\n    file_system->CreateDir(test_root_path);\n  }\n  std::string file_name = JoinPath(test_root_path, \"/direct_file_\");\n  std::string content = \"test_file\";\n  std::unique_ptr<WritableFile> file_a;\n  std::unique_ptr<WritableFile> file_b;\n  file_system->NewWritableFile(file_name + \"_a\", &file_a);\n  file_a->Append(content.c_str(), 9);\n  file_a->Close();\n  file_system->NewWritableFile(file_name + \"_b\", &file_b);\n  file_b->Append(content.c_str(), 9);\n  file_b->Close();\n  std::string child_dir = JoinPath(test_root_path, \"/direct_dir\");\n  file_system->CreateDir(child_dir);\n  ASSERT_EQ(file_system->ListDir(test_root_path).size(), 3);\n  file_system->DeleteDir(child_dir);\n  ASSERT_TRUE(!file_system->IsDirectory(child_dir));\n  file_system->RecursivelyDeleteDir(test_root_path);\n 
  ASSERT_TRUE(!file_system->IsDirectory(test_root_path));\n}\n\nvoid TestMultiThreadsDirOperation(FileSystem* file_system) {\n  std::string current_dir = GetCwd();\n  StringReplace(&current_dir, '\\\\', '/');\n  std::string test_root_path = JoinPath(current_dir, \"tmp_multithread_test_dir\");\n  std::vector<std::thread> thread_vector;\n  for (int i = 0; i < 10; i++) {\n    thread_vector.emplace_back(\n        [&]() { file_system->RecursivelyCreateDirIfNotExist(test_root_path); });\n  }\n  for (int i = 0; i < 10; i++) { thread_vector[i].join(); }\n  ASSERT_TRUE(file_system->IsDirectory(test_root_path));\n  // clean up the directory created by the worker threads\n  file_system->RecursivelyDeleteDir(test_root_path);\n}\n\nvoid TestFileSystem(FileSystem* file_system) {\n  TestFileOperation(file_system);\n  TestDirOperation(file_system);\n  TestMultiThreadsDirOperation(file_system);\n}\n\n}  // namespace fs\n\nTEST(file_system, write_and_read) {\n#ifdef OF_PLATFORM_POSIX\n  fs::FileSystem* file_system = new fs::PosixFileSystem();\n  fs::TestFileSystem(file_system);\n  delete file_system;\n#endif\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/persistence/hadoop/hadoop_file_system.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/persistence/hadoop/hadoop_file_system.h\"\n#include <mutex>\n#include \"oneflow/core/common/str_util.h\"\n\n#ifdef OF_PLATFORM_POSIX\n\n#include <dlfcn.h>\n\n#endif  // OF_PLATFORM_POSIX\n\n#define FS_RETURN_FALSE_IF_FALSE(val) \\\n  if (!val) {                         \\\n    PLOG(WARNING);                    \\\n    return false;                     \\\n  }\n\nnamespace oneflow {\n\nnamespace fs {\n\nnamespace internal {\n\n#ifdef OF_PLATFORM_POSIX\n\nbool GetSymbolFromLibrary(void* handle, const char* symbol_name, void** symbol) {\n  *symbol = dlsym(handle, symbol_name);\n  if (!*symbol) {\n    PLOG(WARNING) << dlerror();\n    return false;\n  }\n  return true;\n}\n\nbool LoadLibrary(const char* library_filename, void** handle) {\n  *handle = dlopen(library_filename, RTLD_NOW | RTLD_LOCAL);\n  if (!*handle) {\n    PLOG(WARNING) << dlerror();\n    return false;\n  }\n  return true;\n}\n\n#endif  // OF_PLATFORM_POSIX\n\n}  // namespace internal\n\ntemplate<typename R, typename... Args>\nbool BindFunc(void* handle, const char* name, std::function<R(Args...)>* func) {\n  void* symbol_ptr = nullptr;\n  FS_RETURN_FALSE_IF_FALSE(internal::GetSymbolFromLibrary(handle, name, &symbol_ptr));\n  *func = reinterpret_cast<R (*)(Args...)>(symbol_ptr);\n  return true;\n}\n\nvoid LibHDFS::LoadAndBind() {\n  auto TryLoadAndBind = [this](const char* name, void** handle) -> bool {\n    FS_RETURN_FALSE_IF_FALSE(internal::LoadLibrary(name, handle));\n#define BIND_HDFS_FUNC(function) FS_RETURN_FALSE_IF_FALSE(BindFunc(*handle, #function, &function));\n\n    BIND_HDFS_FUNC(hdfsBuilderConnect);\n    BIND_HDFS_FUNC(hdfsNewBuilder);\n    BIND_HDFS_FUNC(hdfsBuilderSetNameNode);\n    BIND_HDFS_FUNC(hdfsConfGetStr);\n    BIND_HDFS_FUNC(hdfsBuilderSetKerbTicketCachePath);\n    BIND_HDFS_FUNC(hdfsCloseFile);\n    BIND_HDFS_FUNC(hdfsPread);\n    BIND_HDFS_FUNC(hdfsWrite);\n    BIND_HDFS_FUNC(hdfsHFlush);\n    BIND_HDFS_FUNC(hdfsHSync);\n    BIND_HDFS_FUNC(hdfsOpenFile);\n    BIND_HDFS_FUNC(hdfsExists);\n    BIND_HDFS_FUNC(hdfsListDirectory);\n    BIND_HDFS_FUNC(hdfsFreeFileInfo);\n    BIND_HDFS_FUNC(hdfsDelete);\n    BIND_HDFS_FUNC(hdfsCreateDirectory);\n    BIND_HDFS_FUNC(hdfsGetPathInfo);\n    BIND_HDFS_FUNC(hdfsRename);\n#undef BIND_HDFS_FUNC\n    return true;\n  };\n\n  // libhdfs.so won't be in the standard locations. 
Use the path as specified\n  // in the libhdfs documentation.\n  const char* kLibHdfsDso = \"libhdfs.so\";\n  char* hdfs_home = getenv(\"HADOOP_HOME\");\n  if (hdfs_home == nullptr) {\n    PLOG(WARNING) << \"Environment variable HADOOP_HOME not set\";\n    status_ = false;\n    return;\n  }\n  std::string path = JoinPath(hdfs_home, \"lib\", \"native\", kLibHdfsDso);\n  status_ = TryLoadAndBind(path.c_str(), &handle_);\n  if (!status_) {\n    // try loading libhdfs.so using the dynamic loader's search path in case\n    // libhdfs.so is installed in a non-standard location\n    status_ = TryLoadAndBind(kLibHdfsDso, &handle_);\n  }\n}\n\nHadoopFileSystem::HadoopFileSystem(const std::string& namenode)\n    : namenode_(namenode), hdfs_(LibHDFS::Load()) {}\n\nbool HadoopFileSystem::Connect(hdfsFS* fs) {\n  FS_RETURN_FALSE_IF_FALSE(hdfs_->status());\n  hdfsBuilder* builder = hdfs_->hdfsNewBuilder();\n  hdfs_->hdfsBuilderSetNameNode(builder, namenode_.c_str());\n  // KERB_TICKET_CACHE_PATH will be removed in the future: KRB5CCNAME is the\n  // built-in Kerberos environment variable, so KERB_TICKET_CACHE_PATH and the\n  // related code are unnecessary.\n  char* ticket_cache_path = getenv(\"KERB_TICKET_CACHE_PATH\");\n  if (ticket_cache_path != nullptr) {\n    hdfs_->hdfsBuilderSetKerbTicketCachePath(builder, ticket_cache_path);\n  }\n  *fs = hdfs_->hdfsBuilderConnect(builder);\n  if (*fs == nullptr) {\n    PLOG(WARNING) << \"HDFS connect failed\";\n    return false;\n  }\n  return true;\n}\n\nclass HDFSRandomAccessFile : public RandomAccessFile {\n public:\n  HDFSRandomAccessFile(const std::string& filename, const std::string& hdfs_filename, LibHDFS* hdfs,\n                       hdfsFS fs, hdfsFile file)\n      : filename_(filename), hdfs_filename_(hdfs_filename), hdfs_(hdfs), fs_(fs), file_(file) {}\n\n  ~HDFSRandomAccessFile() override {\n    if (file_ != nullptr) {\n      std::unique_lock<std::mutex> lock(mu_);\n      hdfs_->hdfsCloseFile(fs_, file_);\n    }\n  }\n\n  void Read(uint64_t offset, size_t n, char* result) const override {\n    char* dst = result;\n    bool eof_retried = false;\n    while (n > 0) {\n      // We lock inside the loop rather than outside so we don't block other\n      // concurrent readers.\n      std::unique_lock<std::mutex> lock(mu_);\n      tSize r =\n          hdfs_->hdfsPread(fs_, file_, static_cast<tOffset>(offset), dst, static_cast<tSize>(n));\n      if (r > 0) {\n        dst += r;\n        n -= r;\n        offset += r;\n      } else if (!eof_retried && r == 0) {\n        // Always reopen the file upon reaching EOF to see if there's more data.\n        // If writers are streaming contents while others are concurrently\n        // reading, HDFS requires that we reopen the file to see updated\n        // contents.\n        PCHECK(file_ == nullptr || hdfs_->hdfsCloseFile(fs_, file_) == 0) << filename_;\n        file_ = hdfs_->hdfsOpenFile(fs_, hdfs_filename_.c_str(), O_RDONLY, 0, 0, 0);\n        PCHECK(file_ != nullptr) << filename_;\n        eof_retried = true;\n      } else if (eof_retried && r == 0) {\n        PLOG(FATAL) << \"Read fewer bytes than requested\";\n        return;\n      } else if (errno == EINTR || errno == EAGAIN) {\n        // hdfsPread may return EINTR too. 
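EAGAIN gets the same treatment. 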
Just retry.\n      } else {\n        PLOG(FATAL) << filename_;\n        return;\n      }\n    }\n  }\n\n private:\n  std::string filename_;\n  std::string hdfs_filename_;\n  LibHDFS* hdfs_;\n  hdfsFS fs_;\n\n  mutable std::mutex mu_;\n  mutable hdfsFile file_;\n};\n\nvoid HadoopFileSystem::NewRandomAccessFile(const std::string& fname,\n                                           std::unique_ptr<RandomAccessFile>* result) {\n  hdfsFS fs = nullptr;\n  CHECK(Connect(&fs));\n\n  hdfsFile file = hdfs_->hdfsOpenFile(fs, TranslateName(fname).c_str(), O_RDONLY, 0, 0, 0);\n  PCHECK(file != nullptr) << fname;\n  result->reset(new HDFSRandomAccessFile(fname, TranslateName(fname), hdfs_, fs, file));\n  CHECK_NOTNULL(result->get());\n}\n\nclass HDFSWritableFile : public WritableFile {\n public:\n  HDFSWritableFile(const std::string& fname, LibHDFS* hdfs, hdfsFS fs, hdfsFile file)\n      : filename_(fname), hdfs_(hdfs), fs_(fs), file_(file) {}\n\n  ~HDFSWritableFile() override {\n    if (file_ != nullptr) { Close(); }\n  }\n\n  void Append(const char* data, size_t n) override {\n    PCHECK(hdfs_->hdfsWrite(fs_, file_, data, static_cast<tSize>(n)) != -1) << filename_;\n  }\n\n  void Close() override {\n    int32_t result = hdfs_->hdfsCloseFile(fs_, file_);\n    hdfs_ = nullptr;\n    fs_ = nullptr;\n    file_ = nullptr;\n    PCHECK(result == 0) << filename_;\n  }\n\n  void Flush() override { PCHECK(hdfs_->hdfsHFlush(fs_, file_) == 0) << filename_; }\n\n private:\n  std::string filename_;\n  LibHDFS* hdfs_;\n  hdfsFS fs_;\n  hdfsFile file_;\n};\n\nvoid HadoopFileSystem::NewWritableFile(const std::string& fname,\n                                       std::unique_ptr<WritableFile>* result) {\n  hdfsFS fs = nullptr;\n  CHECK(Connect(&fs));\n\n  hdfsFile file = hdfs_->hdfsOpenFile(fs, TranslateName(fname).c_str(), O_WRONLY, 0, 0, 0);\n  PCHECK(file != nullptr) << fname;\n  result->reset(new HDFSWritableFile(fname, hdfs_, fs, file));\n  CHECK_NOTNULL(result->get());\n}\n\nvoid HadoopFileSystem::NewAppendableFile(const std::string& fname,\n                                         std::unique_ptr<WritableFile>* result) {\n  hdfsFS fs = nullptr;\n  CHECK(Connect(&fs));\n\n  hdfsFile file =\n      hdfs_->hdfsOpenFile(fs, TranslateName(fname).c_str(), O_WRONLY | O_APPEND, 0, 0, 0);\n  PCHECK(file != nullptr) << fname;\n  result->reset(new HDFSWritableFile(fname, hdfs_, fs, file));\n  CHECK_NOTNULL(result->get());\n}\n\nbool HadoopFileSystem::FileExists(const std::string& fname) {\n  hdfsFS fs = nullptr;\n  CHECK(Connect(&fs));\n  if (hdfs_->hdfsExists(fs, TranslateName(fname).c_str()) == 0) { return true; }\n  return false;\n}\n\nstd::vector<std::string> HadoopFileSystem::ListDir(const std::string& dir) {\n  std::vector<std::string> result;\n  hdfsFS fs = nullptr;\n  CHECK(Connect(&fs));\n\n  // hdfsListDirectory returns nullptr if the directory is empty. 
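It also returns\n  // nullptr on error, so a null result alone is ambiguous. 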
Do a separate\n  // check to verify the directory exists first.\n  CHECK(IsDirectory(dir)) << \"directory not found, path: \" << dir;\n\n  int entries = 0;\n  hdfsFileInfo* info = hdfs_->hdfsListDirectory(fs, TranslateName(dir).c_str(), &entries);\n  if (info == nullptr) {\n    // Assume it's an empty directory.\n    return result;\n  }\n  for (int i = 0; i < entries; i++) { result.emplace_back(Basename(info[i].mName)); }\n  hdfs_->hdfsFreeFileInfo(info, entries);\n  return result;\n}\n\nvoid HadoopFileSystem::DelFile(const std::string& fname) {\n  hdfsFS fs = nullptr;\n  CHECK(Connect(&fs));\n  PCHECK(hdfs_->hdfsDelete(fs, TranslateName(fname).c_str(), /*recursive=*/0) == 0) << fname;\n}\n\nvoid HadoopFileSystem::CreateDir(const std::string& dir) {\n  hdfsFS fs = nullptr;\n  CHECK(Connect(&fs));\n\n  PCHECK(hdfs_->hdfsCreateDirectory(fs, TranslateName(dir).c_str()) == 0) << dir;\n}\n\nvoid HadoopFileSystem::DeleteDir(const std::string& dir) {\n  hdfsFS fs = nullptr;\n  CHECK(Connect(&fs));\n\n  // Count the number of entries in the directory, and only delete if it's\n  // non-empty. This is consistent with the interface, but note that there's\n  // a race condition where a file may be added after this check, in which\n  // case the directory will still be deleted.\n  int entries = 0;\n  hdfsFileInfo* info = hdfs_->hdfsListDirectory(fs, TranslateName(dir).c_str(), &entries);\n  if (info != nullptr) { hdfs_->hdfsFreeFileInfo(info, entries); }\n  // Due to HDFS bug HDFS-8407, we can't distinguish between an error and an\n  // empty folder; especially for Kerberos-enabled setups, EAGAIN is quite\n  // common even when the call is actually successful. Check again with\n  // IsDirectory.\n  if (info == nullptr && errno != 0) {\n    CHECK(IsDirectory(dir)) << \"directory not found, path: \" << dir;\n  }\n  PCHECK(entries == 0) << dir << \": cannot delete a non-empty directory\";\n  PCHECK(hdfs_->hdfsDelete(fs, TranslateName(dir).c_str(), /*recursive=*/1) == 0) << dir;\n}\n\nvoid HadoopFileSystem::RecursivelyDeleteDir(const std::string& dirname) {\n  hdfsFS fs = nullptr;\n  CHECK(Connect(&fs));\n\n  PCHECK(hdfs_->hdfsDelete(fs, TranslateName(dirname).c_str(), /*recursive=*/1) == 0) << dirname;\n}\n\nuint64_t HadoopFileSystem::GetFileSize(const std::string& fname) {\n  hdfsFS fs = nullptr;\n  CHECK(Connect(&fs));\n\n  hdfsFileInfo* info = hdfs_->hdfsGetPathInfo(fs, TranslateName(fname).c_str());\n  PCHECK(info != nullptr) << fname;\n  uint64_t ret = info->mSize;\n  hdfs_->hdfsFreeFileInfo(info, 1);\n  return ret;\n}\n\nvoid HadoopFileSystem::RenameFile(const std::string& old_name, const std::string& new_name) {\n  hdfsFS fs = nullptr;\n  CHECK(Connect(&fs));\n\n  PCHECK(hdfs_->hdfsExists(fs, TranslateName(new_name).c_str()) != 0\n         || hdfs_->hdfsDelete(fs, TranslateName(new_name).c_str(), /*recursive=*/0) == 0)\n      << new_name;\n\n  PCHECK(hdfs_->hdfsRename(fs, TranslateName(old_name).c_str(), TranslateName(new_name).c_str())\n         == 0)\n      << old_name;\n}\n\nbool HadoopFileSystem::IsDirectory(const std::string& fname) {\n  hdfsFS fs = nullptr;\n  CHECK(Connect(&fs));\n\n  hdfsFileInfo* info = hdfs_->hdfsGetPathInfo(fs, TranslateName(fname).c_str());\n  if (info == nullptr) { return false; }\n  // Free the info struct before returning so it is not leaked when the path\n  // exists but is not a directory.\n  bool is_directory = (info->mKind == kObjectKindDirectory);\n  hdfs_->hdfsFreeFileInfo(info, 1);\n  return is_directory;\n}\n\n}  // namespace fs\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/persistence/hadoop/hadoop_file_system.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PERSISTENCE_HADOOP_HADOOP_FILE_SYSTEM_H_\n#define ONEFLOW_CORE_PERSISTENCE_HADOOP_HADOOP_FILE_SYSTEM_H_\n\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/core/persistence/file_system.h\"\n#include \"oneflow/core/persistence/hadoop/hdfs.h\"\n\nextern \"C\" {\nstruct hdfs_internal;\ntypedef hdfs_internal* hdfsFS;\n}\n\nnamespace oneflow {\n\nnamespace fs {\n\nclass LibHDFS {\n public:\n  static LibHDFS* Load() {\n    static LibHDFS* lib = []() -> LibHDFS* {\n      LibHDFS* lib = new LibHDFS;\n      lib->LoadAndBind();\n      return lib;\n    }();\n    return lib;\n  }\n\n  // The status, if any, from failure to load.\n  // true is OK\n  // false is non-OK\n  bool status() { return status_; }\n\n  std::function<hdfsFS(hdfsBuilder*)> hdfsBuilderConnect;\n  std::function<hdfsBuilder*()> hdfsNewBuilder;\n  std::function<void(hdfsBuilder*, const char*)> hdfsBuilderSetNameNode;\n  std::function<int(const char*, char**)> hdfsConfGetStr;\n  std::function<void(hdfsBuilder*, const char* kerbTicketCachePath)>\n      hdfsBuilderSetKerbTicketCachePath;\n  std::function<int(hdfsFS, hdfsFile)> hdfsCloseFile;\n  std::function<tSize(hdfsFS, hdfsFile, tOffset, void*, tSize)> hdfsPread;\n  std::function<tSize(hdfsFS, hdfsFile, const void*, tSize)> hdfsWrite;\n  std::function<int(hdfsFS, hdfsFile)> hdfsHFlush;\n  std::function<int(hdfsFS, hdfsFile)> hdfsHSync;\n  std::function<hdfsFile(hdfsFS, const char*, int, int, short, tSize)> hdfsOpenFile;\n  std::function<int(hdfsFS, const char*)> hdfsExists;\n  std::function<hdfsFileInfo*(hdfsFS, const char*, int*)> hdfsListDirectory;\n  std::function<void(hdfsFileInfo*, int)> hdfsFreeFileInfo;\n  std::function<int(hdfsFS, const char*, int recursive)> hdfsDelete;\n  std::function<int(hdfsFS, const char*)> hdfsCreateDirectory;\n  std::function<hdfsFileInfo*(hdfsFS, const char*)> hdfsGetPathInfo;\n  std::function<int(hdfsFS, const char*, const char*)> hdfsRename;\n\n private:\n  void LoadAndBind();\n  bool status_;\n  void* handle_ = nullptr;\n};\n\nclass HadoopFileSystem final : public FileSystem {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(HadoopFileSystem);\n  HadoopFileSystem() = delete;\n  ~HadoopFileSystem() = default;\n\n  HadoopFileSystem(const std::string&);\n\n  void NewRandomAccessFile(const std::string& fname,\n                           std::unique_ptr<RandomAccessFile>* result) override;\n\n  void NewWritableFile(const std::string& fname, std::unique_ptr<WritableFile>* result) override;\n\n  void NewAppendableFile(const std::string& fname, std::unique_ptr<WritableFile>* result) override;\n\n  bool FileExists(const std::string& fname) override;\n\n  std::vector<std::string> ListDir(const std::string& dir) override;\n\n  void DelFile(const std::string& fname) override;\n\n  void CreateDir(const std::string& dirname) override;\n\n  void DeleteDir(const std::string& dirname) 
override;\n\n  void RecursivelyDeleteDir(const std::string& dirname) override;\n\n  uint64_t GetFileSize(const std::string& fname) override;\n\n  void RenameFile(const std::string& old_name, const std::string& new_name) override;\n\n  bool IsDirectory(const std::string& fname) override;\n\n private:\n  bool Connect(hdfsFS* fs);\n  std::string namenode_;\n  LibHDFS* hdfs_;\n};\n\n}  // namespace fs\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PERSISTENCE_HADOOP_HADOOP_FILE_SYSTEM_H_\n"
  },
  {
    "path": "oneflow/core/persistence/hadoop/hdfs.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PERSISTENCE_HADOOP_HDFS_H_\n#define ONEFLOW_CORE_PERSISTENCE_HADOOP_HDFS_H_\n\n#include <errno.h>  /* for EINTERNAL, etc. */\n#include <fcntl.h>  /* for O_RDONLY, O_WRONLY */\n#include <stdint.h> /* for uint64_t, etc. */\n#include <time.h>   /* for time_t */\n\n/*\n * Support export of DLL symbols during libhdfs build, and import of DLL symbols\n * during client application build.  A client application may optionally define\n * symbol LIBHDFS_DLL_IMPORT in its build.  This is not strictly required, but\n * the compiler can produce more efficient code with it.\n */\n#ifdef WIN32\n#ifdef LIBHDFS_DLL_EXPORT\n#define LIBHDFS_EXTERNAL __declspec(dllexport)\n#elif LIBHDFS_DLL_IMPORT\n#define LIBHDFS_EXTERNAL __declspec(dllimport)\n#else\n#define LIBHDFS_EXTERNAL\n#endif\n#else\n#ifdef LIBHDFS_DLL_EXPORT\n#define LIBHDFS_EXTERNAL __attribute__((visibility(\"default\")))\n#elif LIBHDFS_DLL_IMPORT\n#define LIBHDFS_EXTERNAL __attribute__((visibility(\"default\")))\n#else\n#define LIBHDFS_EXTERNAL\n#endif\n#endif\n\n#ifndef O_RDONLY\n#define O_RDONLY 1\n#endif\n\n#ifndef O_WRONLY\n#define O_WRONLY 2\n#endif\n\n#ifndef EINTERNAL\n#define EINTERNAL 255\n#endif\n\n#define ELASTIC_BYTE_BUFFER_POOL_CLASS \"org/apache/hadoop/io/ElasticByteBufferPool\"\n\n/** All APIs set errno to meaningful values */\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n/**\n * Some utility decls used in libhdfs.\n */\nstruct hdfsBuilder;\ntypedef int32_t tSize;    /// size of data for read/write io ops\ntypedef time_t tTime;     /// time type in seconds\ntypedef int64_t tOffset;  /// offset within the file\ntypedef uint16_t tPort;   /// port\ntypedef enum tObjectKind {\n  kObjectKindFile = 'F',\n  kObjectKindDirectory = 'D',\n} tObjectKind;\n\n/**\n * The C reflection of org.apache.org.hadoop.FileSystem .\n */\nstruct hdfs_internal;\ntypedef struct hdfs_internal* hdfsFS;\n\nstruct hdfsFile_internal;\ntypedef struct hdfsFile_internal* hdfsFile;\n\nstruct hadoopRzOptions;\n\nstruct hadoopRzBuffer;\n\n/**\n * Determine if a file is open for read.\n *\n * @param file     The HDFS file\n * @return         1 if the file is open for read; 0 otherwise\n */\nLIBHDFS_EXTERNAL\nint hdfsFileIsOpenForRead(hdfsFile file);\n\n/**\n * Determine if a file is open for write.\n *\n * @param file     The HDFS file\n * @return         1 if the file is open for write; 0 otherwise\n */\nLIBHDFS_EXTERNAL\nint hdfsFileIsOpenForWrite(hdfsFile file);\n\nstruct hdfsReadStatistics {\n  uint64_t totalBytesRead;\n  uint64_t totalLocalBytesRead;\n  uint64_t totalShortCircuitBytesRead;\n  uint64_t totalZeroCopyBytesRead;\n};\n\n/**\n * Get read statistics about a file.  This is only applicable to files\n * opened for reading.\n *\n * @param file     The HDFS file\n * @param stats    (out parameter) on a successful return, the read\n *                 statistics.  Unchanged otherwise.  
You must free the\n *                 returned statistics with hdfsFileFreeReadStatistics.\n * @return         0 if the statistics were successfully returned,\n *                 -1 otherwise.  On a failure, please check errno against\n *                 ENOTSUP.  webhdfs, LocalFilesystem, and so forth may\n *                 not support read statistics.\n */\nLIBHDFS_EXTERNAL\nint hdfsFileGetReadStatistics(hdfsFile file, struct hdfsReadStatistics** stats);\n\n/**\n * @param stats    HDFS read statistics for a file.\n *\n * @return the number of remote bytes read.\n */\nLIBHDFS_EXTERNAL\nint64_t hdfsReadStatisticsGetRemoteBytesRead(const struct hdfsReadStatistics* stats);\n\n/**\n * Clear the read statistics for a file.\n *\n * @param file      The file to clear the read statistics of.\n *\n * @return          0 on success; the error code otherwise.\n *                  EINVAL: the file is not open for reading.\n *                  ENOTSUP: the file does not support clearing the read\n *                  statistics.\n *                  Errno will also be set to this code on failure.\n */\nLIBHDFS_EXTERNAL\nint hdfsFileClearReadStatistics(hdfsFile file);\n\n/**\n * Free some HDFS read statistics.\n *\n * @param stats    The HDFS read statistics to free.\n */\nLIBHDFS_EXTERNAL\nvoid hdfsFileFreeReadStatistics(struct hdfsReadStatistics* stats);\n\n/**\n * hdfsConnectAsUser - Connect to a hdfs file system as a specific user\n * Connect to the hdfs.\n * @param nn   The NameNode.  See hdfsBuilderSetNameNode for details.\n * @param port The port on which the server is listening.\n * @param user the user name (this is a hadoop domain user), or NULL, which is\n * equivalent to hdfsConnect(host, port)\n * @return Returns a handle to the filesystem or NULL on error.\n * @deprecated Use hdfsBuilderConnect instead.\n */\nLIBHDFS_EXTERNAL\nhdfsFS hdfsConnectAsUser(const char* nn, tPort port, const char* user);\n\n/**\n * hdfsConnect - Connect to a hdfs file system.\n * Connect to the hdfs.\n * @param nn   The NameNode.  See hdfsBuilderSetNameNode for details.\n * @param port The port on which the server is listening.\n * @return Returns a handle to the filesystem or NULL on error.\n * @deprecated Use hdfsBuilderConnect instead.\n */\nLIBHDFS_EXTERNAL\nhdfsFS hdfsConnect(const char* nn, tPort port);\n\n/**\n * hdfsConnectAsUserNewInstance - Connect to an hdfs file system.\n *\n * Forces a new instance to be created\n *\n * @param nn     The NameNode.  See hdfsBuilderSetNameNode for details.\n * @param port   The port on which the server is listening.\n * @param user   The user name to use when connecting\n * @return       Returns a handle to the filesystem or NULL on error.\n * @deprecated   Use hdfsBuilderConnect instead.\n */\nLIBHDFS_EXTERNAL\nhdfsFS hdfsConnectAsUserNewInstance(const char* nn, tPort port, const char* user);\n\n/**\n * hdfsConnectNewInstance - Connect to an hdfs file system.\n *\n * Forces a new instance to be created\n *\n * @param nn     The NameNode.  
See hdfsBuilderSetNameNode for details.\n * @param port   The port on which the server is listening.\n * @return       Returns a handle to the filesystem or NULL on error.\n * @deprecated   Use hdfsBuilderConnect instead.\n */\nLIBHDFS_EXTERNAL\nhdfsFS hdfsConnectNewInstance(const char* nn, tPort port);\n\n/**\n * Connect to HDFS using the parameters defined by the builder.\n *\n * The HDFS builder will be freed, whether or not the connection was\n * successful.\n *\n * Every successful call to hdfsBuilderConnect should be matched with a call\n * to hdfsDisconnect, when the hdfsFS is no longer needed.\n *\n * @param bld    The HDFS builder\n * @return       Returns a handle to the filesystem, or NULL on error.\n */\nLIBHDFS_EXTERNAL\nhdfsFS hdfsBuilderConnect(struct hdfsBuilder* bld);\n\n/**\n * Create an HDFS builder.\n *\n * @return The HDFS builder, or NULL on error.\n */\nLIBHDFS_EXTERNAL\nstruct hdfsBuilder* hdfsNewBuilder(void);\n\n/**\n * Force the builder to always create a new instance of the FileSystem,\n * rather than possibly finding one in the cache.\n *\n * @param bld The HDFS builder\n */\nLIBHDFS_EXTERNAL\nvoid hdfsBuilderSetForceNewInstance(struct hdfsBuilder* bld);\n\n/**\n * Set the HDFS NameNode to connect to.\n *\n * @param bld  The HDFS builder\n * @param nn   The NameNode to use.\n *\n *             If the string given is 'default', the default NameNode\n *             configuration will be used (from the XML configuration files)\n *\n *             If NULL is given, a LocalFileSystem will be created.\n *\n *             If the string starts with a protocol type such as file:// or\n *             hdfs://, this protocol type will be used.  If not, the\n *             hdfs:// protocol type will be used.\n *\n *             You may specify a NameNode port in the usual way by\n *             passing a string of the format hdfs://<hostname>:<port>.\n *             Alternately, you may set the port with\n *             hdfsBuilderSetNameNodePort.  However, you must not pass the\n *             port in two different ways.\n */\nLIBHDFS_EXTERNAL\nvoid hdfsBuilderSetNameNode(struct hdfsBuilder* bld, const char* nn);\n\n/**\n * Set the port of the HDFS NameNode to connect to.\n *\n * @param bld The HDFS builder\n * @param port The port.\n */\nLIBHDFS_EXTERNAL\nvoid hdfsBuilderSetNameNodePort(struct hdfsBuilder* bld, tPort port);\n\n/**\n * Set the username to use when connecting to the HDFS cluster.\n *\n * @param bld The HDFS builder\n * @param userName The user name.  The string will be shallow-copied.\n */\nLIBHDFS_EXTERNAL\nvoid hdfsBuilderSetUserName(struct hdfsBuilder* bld, const char* userName);\n\n/**\n * Set the path to the Kerberos ticket cache to use when connecting to\n * the HDFS cluster.\n *\n * @param bld The HDFS builder\n * @param kerbTicketCachePath The Kerberos ticket cache path.  The string\n *                            will be shallow-copied.\n */\nLIBHDFS_EXTERNAL\nvoid hdfsBuilderSetKerbTicketCachePath(struct hdfsBuilder* bld, const char* kerbTicketCachePath);\n\n/**\n * Free an HDFS builder.\n *\n * It is normally not necessary to call this function since\n * hdfsBuilderConnect frees the builder.\n *\n * @param bld The HDFS builder\n */\nLIBHDFS_EXTERNAL\nvoid hdfsFreeBuilder(struct hdfsBuilder* bld);\n\n/**\n * Set a configuration string for an HdfsBuilder.\n *\n * @param key      The key to set.\n * @param val      The value, or NULL to set no value.\n *                 This will be shallow-copied.  
You are responsible for\n *                 ensuring that it remains valid until the builder is\n *                 freed.\n *\n * @return         0 on success; nonzero error code otherwise.\n */\nLIBHDFS_EXTERNAL\nint hdfsBuilderConfSetStr(struct hdfsBuilder* bld, const char* key, const char* val);\n\n/**\n * Get a configuration string.\n *\n * @param key      The key to find\n * @param val      (out param) The value.  This will be set to NULL if the\n *                 key isn't found.  You must free this string with\n *                 hdfsConfStrFree.\n *\n * @return         0 on success; nonzero error code otherwise.\n *                 Failure to find the key is not an error.\n */\nLIBHDFS_EXTERNAL\nint hdfsConfGetStr(const char* key, char** val);\n\n/**\n * Get a configuration integer.\n *\n * @param key      The key to find\n * @param val      (out param) The value.  This will NOT be changed if the\n *                 key isn't found.\n *\n * @return         0 on success; nonzero error code otherwise.\n *                 Failure to find the key is not an error.\n */\nLIBHDFS_EXTERNAL\nint hdfsConfGetInt(const char* key, int32_t* val);\n\n/**\n * Free a configuration string found with hdfsConfGetStr.\n *\n * @param val      A configuration string obtained from hdfsConfGetStr\n */\nLIBHDFS_EXTERNAL\nvoid hdfsConfStrFree(char* val);\n\n/**\n * hdfsDisconnect - Disconnect from the hdfs file system.\n * Disconnect from hdfs.\n * @param fs The configured filesystem handle.\n * @return Returns 0 on success, -1 on error.\n *         Even if there is an error, the resources associated with the\n *         hdfsFS will be freed.\n */\nLIBHDFS_EXTERNAL\nint hdfsDisconnect(hdfsFS fs);\n\n/**\n * hdfsOpenFile - Open a hdfs file in given mode.\n * @param fs The configured filesystem handle.\n * @param path The full path to the file.\n * @param flags - an | of bits/fcntl.h file flags - supported flags are\n * O_RDONLY, O_WRONLY (meaning create or overwrite i.e., implies O_TRUNC),\n * O_WRONLY|O_APPEND. 
Other flags are generally ignored other than (O_RDWR ||\n * (O_EXCL & O_CREAT)) which return NULL and set errno to ENOTSUP.\n * @param bufferSize Size of buffer for read/write - pass 0 if you want\n * to use the default configured values.\n * @param replication Block replication - pass 0 if you want to use\n * the default configured values.\n * @param blocksize Size of block - pass 0 if you want to use the\n * default configured values.\n * @return Returns the handle to the open file or NULL on error.\n */\nLIBHDFS_EXTERNAL\nhdfsFile hdfsOpenFile(hdfsFS fs, const char* path, int flags, int bufferSize, short replication,\n                      tSize blocksize);\n\n/**\n * hdfsTruncateFile - Truncate a hdfs file to given length.\n * @param fs The configured filesystem handle.\n * @param path The full path to the file.\n * @param newlength The size the file is to be truncated to\n * @return 1 if the file has been truncated to the desired newlength\n *         and is immediately available to be reused for write operations\n *         such as append.\n *         0 if a background process of adjusting the length of the last\n *         block has been started, and clients should wait for it to\n *         complete before proceeding with further file updates.\n *         -1 on error.\n */\nint hdfsTruncateFile(hdfsFS fs, const char* path, tOffset newlength);\n\n/**\n * hdfsUnbufferFile - Reduce the buffering done on a file.\n *\n * @param file  The file to unbuffer.\n * @return      0 on success\n *              ENOTSUP if the file does not support unbuffering\n *              Errno will also be set to this value.\n */\nLIBHDFS_EXTERNAL\nint hdfsUnbufferFile(hdfsFile file);\n\n/**\n * hdfsCloseFile - Close an open file.\n * @param fs The configured filesystem handle.\n * @param file The file handle.\n * @return Returns 0 on success, -1 on error.\n *         On error, errno will be set appropriately.\n *         If the hdfs file was valid, the memory associated with it will\n *         be freed at the end of this call, even if there was an I/O\n *         error.\n */\nLIBHDFS_EXTERNAL\nint hdfsCloseFile(hdfsFS fs, hdfsFile file);\n\n/**\n * hdfsExists - Checks if a given path exists on the filesystem\n * @param fs The configured filesystem handle.\n * @param path The path to look for\n * @return Returns 0 on success, -1 on error.\n */\nLIBHDFS_EXTERNAL\nint hdfsExists(hdfsFS fs, const char* path);\n\n/**\n * hdfsSeek - Seek to given offset in file.\n * This works only for files opened in read-only mode.\n * @param fs The configured filesystem handle.\n * @param file The file handle.\n * @param desiredPos Offset into the file to seek into.\n * @return Returns 0 on success, -1 on error.\n */\nLIBHDFS_EXTERNAL\nint hdfsSeek(hdfsFS fs, hdfsFile file, tOffset desiredPos);\n\n/**\n * hdfsTell - Get the current offset in the file, in bytes.\n * @param fs The configured filesystem handle.\n * @param file The file handle.\n * @return Current offset, -1 on error.\n */\nLIBHDFS_EXTERNAL\ntOffset hdfsTell(hdfsFS fs, hdfsFile file);\n\n/**\n * hdfsRead - Read data from an open file.\n * @param fs The configured filesystem handle.\n * @param file The file handle.\n * @param buffer The buffer to copy read bytes into.\n * @param length The length of the buffer.\n * @return      On success, a positive number indicating how many bytes\n *              were read.\n *              On end-of-file, 0.\n *              On error, -1.  
Errno will be set to the error code.\n *              Just like the POSIX read function, hdfsRead will return -1\n *              and set errno to EINTR if data is temporarily unavailable,\n *              but we are not yet at the end of the file.\n */\nLIBHDFS_EXTERNAL\ntSize hdfsRead(hdfsFS fs, hdfsFile file, void* buffer, tSize length);\n\n/**\n * hdfsPread - Positional read of data from an open file.\n * @param fs The configured filesystem handle.\n * @param file The file handle.\n * @param position Position from which to read\n * @param buffer The buffer to copy read bytes into.\n * @param length The length of the buffer.\n * @return      See hdfsRead\n */\nLIBHDFS_EXTERNAL\ntSize hdfsPread(hdfsFS fs, hdfsFile file, tOffset position, void* buffer, tSize length);\n\n/**\n * hdfsWrite - Write data into an open file.\n * @param fs The configured filesystem handle.\n * @param file The file handle.\n * @param buffer The data.\n * @param length The no. of bytes to write.\n * @return Returns the number of bytes written, -1 on error.\n */\nLIBHDFS_EXTERNAL\ntSize hdfsWrite(hdfsFS fs, hdfsFile file, const void* buffer, tSize length);\n\n/**\n * hdfsFlush - Flush the data.\n * @param fs The configured filesystem handle.\n * @param file The file handle.\n * @return Returns 0 on success, -1 on error.\n */\nLIBHDFS_EXTERNAL\nint hdfsFlush(hdfsFS fs, hdfsFile file);\n\n/**\n * hdfsHFlush - Flush out the data in client's user buffer. After the\n * return of this call, new readers will see the data.\n * @param fs configured filesystem handle\n * @param file file handle\n * @return 0 on success, -1 on error and sets errno\n */\nLIBHDFS_EXTERNAL\nint hdfsHFlush(hdfsFS fs, hdfsFile file);\n\n/**\n * hdfsHSync - Similar to posix fsync: flush out the data in the client's\n * user buffer all the way to the disk device (though the disk may keep\n * it in its cache).\n * @param fs configured filesystem handle\n * @param file file handle\n * @return 0 on success, -1 on error and sets errno\n */\nLIBHDFS_EXTERNAL\nint hdfsHSync(hdfsFS fs, hdfsFile file);\n\n/**\n * hdfsAvailable - Number of bytes that can be read from this\n * input stream without blocking.\n * @param fs The configured filesystem handle.\n * @param file The file handle.\n * @return Returns available bytes; -1 on error.\n */\nLIBHDFS_EXTERNAL\nint hdfsAvailable(hdfsFS fs, hdfsFile file);\n\n/**\n * hdfsCopy - Copy file from one filesystem to another.\n * @param srcFS The handle to source filesystem.\n * @param src The path of source file.\n * @param dstFS The handle to destination filesystem.\n * @param dst The path of destination file.\n * @return Returns 0 on success, -1 on error.\n */\nLIBHDFS_EXTERNAL\nint hdfsCopy(hdfsFS srcFS, const char* src, hdfsFS dstFS, const char* dst);\n\n/**\n * hdfsMove - Move file from one filesystem to another.\n * @param srcFS The handle to source filesystem.\n * @param src The path of source file.\n * @param dstFS The handle to destination filesystem.\n * @param dst The path of destination file.\n * @return Returns 0 on success, -1 on error.\n */\nLIBHDFS_EXTERNAL\nint hdfsMove(hdfsFS srcFS, const char* src, hdfsFS dstFS, const char* dst);\n\n/**\n * hdfsDelete - Delete file.\n * @param fs The configured filesystem handle.\n * @param path The path of the file.\n * @param recursive if path is a directory and set to\n * non-zero, the directory is deleted else throws an exception. 
In\n * case of a file the recursive argument is irrelevant.\n * @return Returns 0 on success, -1 on error.\n */\nLIBHDFS_EXTERNAL\nint hdfsDelete(hdfsFS fs, const char* path, int recursive);\n\n/**\n * hdfsRename - Rename file.\n * @param fs The configured filesystem handle.\n * @param oldPath The path of the source file.\n * @param newPath The path of the destination file.\n * @return Returns 0 on success, -1 on error.\n */\nLIBHDFS_EXTERNAL\nint hdfsRename(hdfsFS fs, const char* oldPath, const char* newPath);\n\n/**\n * hdfsGetWorkingDirectory - Get the current working directory for\n * the given filesystem.\n * @param fs The configured filesystem handle.\n * @param buffer The user-buffer to copy path of cwd into.\n * @param bufferSize The length of user-buffer.\n * @return Returns buffer, NULL on error.\n */\nLIBHDFS_EXTERNAL\nchar* hdfsGetWorkingDirectory(hdfsFS fs, char* buffer, size_t bufferSize);\n\n/**\n * hdfsSetWorkingDirectory - Set the working directory. All relative\n * paths will be resolved relative to it.\n * @param fs The configured filesystem handle.\n * @param path The path of the new 'cwd'.\n * @return Returns 0 on success, -1 on error.\n */\nLIBHDFS_EXTERNAL\nint hdfsSetWorkingDirectory(hdfsFS fs, const char* path);\n\n/**\n * hdfsCreateDirectory - Make the given file and all non-existent\n * parents into directories.\n * @param fs The configured filesystem handle.\n * @param path The path of the directory.\n * @return Returns 0 on success, -1 on error.\n */\nLIBHDFS_EXTERNAL\nint hdfsCreateDirectory(hdfsFS fs, const char* path);\n\n/**\n * hdfsSetReplication - Set the replication of the specified\n * file to the supplied value\n * @param fs The configured filesystem handle.\n * @param path The path of the file.\n * @return Returns 0 on success, -1 on error.\n */\nLIBHDFS_EXTERNAL\nint hdfsSetReplication(hdfsFS fs, const char* path, int16_t replication);\n\n/**\n * hdfsFileInfo - Information about a file/directory.\n */\ntypedef struct {\n  tObjectKind mKind;  /* file or directory */\n  char* mName;        /* the name of the file */\n  tTime mLastMod;     /* the last modification time for the file in seconds */\n  tOffset mSize;      /* the size of the file in bytes */\n  short mReplication; /* the count of replicas */\n  tOffset mBlockSize; /* the block size for the file */\n  char* mOwner;       /* the owner of the file */\n  char* mGroup;       /* the group associated with the file */\n  short mPermissions; /* the permissions associated with the file */\n  tTime mLastAccess;  /* the last access time for the file in seconds */\n} hdfsFileInfo;\n\n/**\n * hdfsListDirectory - Get list of files/directories for a given\n * directory-path. hdfsFreeFileInfo should be called to deallocate memory.\n * @param fs The configured filesystem handle.\n * @param path The path of the directory.\n * @param numEntries Set to the number of files/directories in path.\n * @return Returns a dynamically-allocated array of hdfsFileInfo\n * objects; NULL on error.\n */\nLIBHDFS_EXTERNAL\nhdfsFileInfo* hdfsListDirectory(hdfsFS fs, const char* path, int* numEntries);\n\n/**\n * hdfsGetPathInfo - Get information about a path as a (dynamically\n * allocated) single hdfsFileInfo struct. 
hdfsFreeFileInfo should be\n * called when the pointer is no longer needed.\n * @param fs The configured filesystem handle.\n * @param path The path of the file.\n * @return Returns a dynamically-allocated hdfsFileInfo object;\n * NULL on error.\n */\nLIBHDFS_EXTERNAL\nhdfsFileInfo* hdfsGetPathInfo(hdfsFS fs, const char* path);\n\n/**\n * hdfsFreeFileInfo - Free up the hdfsFileInfo array (including fields)\n * @param hdfsFileInfo The array of dynamically-allocated hdfsFileInfo\n * objects.\n * @param numEntries The size of the array.\n */\nLIBHDFS_EXTERNAL\nvoid hdfsFreeFileInfo(hdfsFileInfo* hdfsFileInfo, int numEntries);\n\n/**\n * hdfsFileIsEncrypted: determine if a file is encrypted based on its\n * hdfsFileInfo.\n * @return -1 if there was an error (errno will be set), 0 if the file is\n *         not encrypted, 1 if the file is encrypted.\n */\nLIBHDFS_EXTERNAL\nint hdfsFileIsEncrypted(hdfsFileInfo* hdfsFileInfo);\n\n/**\n * hdfsGetHosts - Get hostnames where a particular block (determined by\n * pos & blocksize) of a file is stored. The last element in the array\n * is NULL. Due to replication, a single block could be present on\n * multiple hosts.\n * @param fs The configured filesystem handle.\n * @param path The path of the file.\n * @param start The start of the block.\n * @param length The length of the block.\n * @return Returns a dynamically-allocated 2-d array of blocks-hosts;\n * NULL on error.\n */\nLIBHDFS_EXTERNAL\nchar*** hdfsGetHosts(hdfsFS fs, const char* path, tOffset start, tOffset length);\n\n/**\n * hdfsFreeHosts - Free up the structure returned by hdfsGetHosts\n * @param hdfsFileInfo The array of dynamically-allocated hdfsFileInfo\n * objects.\n * @param numEntries The size of the array.\n */\nLIBHDFS_EXTERNAL\nvoid hdfsFreeHosts(char*** blockHosts);\n\n/**\n * hdfsGetDefaultBlockSize - Get the default blocksize.\n *\n * @param fs            The configured filesystem handle.\n * @deprecated          Use hdfsGetDefaultBlockSizeAtPath instead.\n *\n * @return              Returns the default blocksize, or -1 on error.\n */\nLIBHDFS_EXTERNAL\ntOffset hdfsGetDefaultBlockSize(hdfsFS fs);\n\n/**\n * hdfsGetDefaultBlockSizeAtPath - Get the default blocksize at the\n * filesystem indicated by a given path.\n *\n * @param fs            The configured filesystem handle.\n * @param path          The given path will be used to locate the actual\n *                      filesystem.  The full path does not have to exist.\n *\n * @return              Returns the default blocksize, or -1 on error.\n */\nLIBHDFS_EXTERNAL\ntOffset hdfsGetDefaultBlockSizeAtPath(hdfsFS fs, const char* path);\n\n/**\n * hdfsGetCapacity - Return the raw capacity of the filesystem.\n * @param fs The configured filesystem handle.\n * @return Returns the raw-capacity; -1 on error.\n */\nLIBHDFS_EXTERNAL\ntOffset hdfsGetCapacity(hdfsFS fs);\n\n/**\n * hdfsGetUsed - Return the total raw size of all files in the filesystem.\n * @param fs The configured filesystem handle.\n * @return Returns the total-size; -1 on error.\n */\nLIBHDFS_EXTERNAL\ntOffset hdfsGetUsed(hdfsFS fs);\n\n/**\n * Change the user and/or group of a file or directory.\n *\n * @param fs            The configured filesystem handle.\n * @param path          the path to the file or directory\n * @param owner         User string.  Set to NULL for 'no change'\n * @param group         Group string.  
Set to NULL for 'no change'\n * @return              0 on success else -1\n */\nLIBHDFS_EXTERNAL\nint hdfsChown(hdfsFS fs, const char* path, const char* owner, const char* group);\n\n/**\n * hdfsChmod\n * @param fs The configured filesystem handle.\n * @param path the path to the file or directory\n * @param mode the bitmask to set it to\n * @return 0 on success else -1\n */\nLIBHDFS_EXTERNAL\nint hdfsChmod(hdfsFS fs, const char* path, short mode);\n\n/**\n * hdfsUtime\n * @param fs The configured filesystem handle.\n * @param path the path to the file or directory\n * @param mtime new modification time or -1 for no change\n * @param atime new access time or -1 for no change\n * @return 0 on success else -1\n */\nLIBHDFS_EXTERNAL\nint hdfsUtime(hdfsFS fs, const char* path, tTime mtime, tTime atime);\n\n/**\n * Allocate a zero-copy options structure.\n *\n * You must free all options structures allocated with this function using\n * hadoopRzOptionsFree.\n *\n * @return            A zero-copy options structure, or NULL if one could\n *                    not be allocated.  If NULL is returned, errno will\n *                    contain the error number.\n */\nLIBHDFS_EXTERNAL\nstruct hadoopRzOptions* hadoopRzOptionsAlloc(void);\n\n/**\n * Determine whether we should skip checksums in read0.\n *\n * @param opts        The options structure.\n * @param skip        Nonzero to skip checksums sometimes; zero to always\n *                    check them.\n *\n * @return            0 on success; -1 plus errno on failure.\n */\nLIBHDFS_EXTERNAL\nint hadoopRzOptionsSetSkipChecksum(struct hadoopRzOptions* opts, int skip);\n\n/**\n * Set the ByteBufferPool to use with read0.\n *\n * @param opts        The options structure.\n * @param className   If this is NULL, we will not use any\n *                    ByteBufferPool.  If this is non-NULL, it will be\n *                    treated as the name of the pool class to use.\n *                    For example, you can use\n *                    ELASTIC_BYTE_BUFFER_POOL_CLASS.\n *\n * @return            0 if the ByteBufferPool class was found and\n *                    instantiated;\n *                    -1 plus errno otherwise.\n */\nLIBHDFS_EXTERNAL\nint hadoopRzOptionsSetByteBufferPool(struct hadoopRzOptions* opts, const char* className);\n\n/**\n * Free a hadoopRzOptionsFree structure.\n *\n * @param opts        The options structure to free.\n *                    Any associated ByteBufferPool will also be freed.\n */\nLIBHDFS_EXTERNAL\nvoid hadoopRzOptionsFree(struct hadoopRzOptions* opts);\n\n/**\n * Perform a byte buffer read.\n * If possible, this will be a zero-copy (mmap) read.\n *\n * @param file       The file to read from.\n * @param opts       An options structure created by hadoopRzOptionsAlloc.\n * @param maxLength  The maximum length to read.  We may read fewer bytes\n *                   than this length.\n *\n * @return           On success, we will return a new hadoopRzBuffer.\n *                   This buffer will continue to be valid and readable\n *                   until it is released by readZeroBufferFree.  Failure to\n *                   release a buffer will lead to a memory leak.\n *                   You can access the data within the hadoopRzBuffer with\n *                   hadoopRzBufferGet.  If you have reached EOF, the data\n *                   within the hadoopRzBuffer will be NULL.  
You must still\n *                   free hadoopRzBuffer instances containing NULL.\n *\n *                   On failure, we will return NULL plus an errno code.\n *                   errno = EOPNOTSUPP indicates that we could not do a\n *                   zero-copy read, and there was no ByteBufferPool\n *                   supplied.\n */\nLIBHDFS_EXTERNAL\nstruct hadoopRzBuffer* hadoopReadZero(hdfsFile file, struct hadoopRzOptions* opts,\n                                      int32_t maxLength);\n\n/**\n * Determine the length of the buffer returned from readZero.\n *\n * @param buffer     a buffer returned from readZero.\n * @return           the length of the buffer.\n */\nLIBHDFS_EXTERNAL\nint32_t hadoopRzBufferLength(const struct hadoopRzBuffer* buffer);\n\n/**\n * Get a pointer to the raw buffer returned from readZero.\n *\n * To find out how many bytes this buffer contains, call\n * hadoopRzBufferLength.\n *\n * @param buffer     a buffer returned from readZero.\n * @return           a pointer to the start of the buffer.  This will be\n *                   NULL when end-of-file has been reached.\n */\nLIBHDFS_EXTERNAL\nconst void* hadoopRzBufferGet(const struct hadoopRzBuffer* buffer);\n\n/**\n * Release a buffer obtained through readZero.\n *\n * @param file       The hdfs stream that created this buffer.  This must be\n *                   the same stream you called hadoopReadZero on.\n * @param buffer     The buffer to release.\n */\nLIBHDFS_EXTERNAL\nvoid hadoopRzBufferFree(hdfsFile file, struct hadoopRzBuffer* buffer);\n\n#ifdef __cplusplus\n}\n#endif\n\n#undef LIBHDFS_EXTERNAL\n#endif /*ONEFLOW_CORE_PERSISTENCE_HADOOP_HDFS_H_*/\n\n/**\n * vim: ts=4: sw=4: et\n */\n"
  },
  {
    "path": "oneflow/core/persistence/persistent_in_stream.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/persistence/persistent_in_stream.h\"\n#include \"oneflow/core/persistence/binary_in_stream_with_local_copy.h\"\n#include \"oneflow/core/persistence/binary_in_stream_without_local_copy.h\"\n#include \"oneflow/core/job/job_set.pb.h\"\n#include <cstring>\n#include \"oneflow/core/common/constant.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr size_t kDefaultBufferSize = 32 * 1024;  // 32KB\n\nsize_t GetBufferSize() {\n  const char* buf_size_str = std::getenv(\"ONEFLOW_PERSISTENT_IN_STREAM_BUFFER_SIZE_BYTES\");\n  if (buf_size_str) {\n    int buf_size = atoi(buf_size_str);\n    if (buf_size > 0) {\n      return buf_size;\n    } else {\n      LOG(WARNING) << \"invalid env ONEFLOW_PERSISTENT_IN_STREAM_BUFFER_SIZE_BYTES \" << buf_size_str\n                   << \", default size \" << kDefaultBufferSize << \" is set\";\n      return kDefaultBufferSize;\n    }\n  }\n  return kDefaultBufferSize;\n}\n\n}  // namespace\n\nPersistentInStream::PersistentInStream(fs::FileSystem* fs,\n                                       const std::vector<std::string>& file_paths, uint64_t offset,\n                                       bool cyclic, bool with_local_copy)\n    : PersistentInStream(kInvalidSessionId, fs, file_paths, offset, cyclic, with_local_copy) {}\n\nPersistentInStream::PersistentInStream(int64_t session_id, fs::FileSystem* fs,\n                                       const std::vector<std::string>& file_paths, uint64_t offset,\n                                       bool cyclic, bool with_local_copy) {\n  if (with_local_copy) { CHECK_EQ(offset, 0); }\n  std::vector<std::shared_ptr<BinaryInStream>> streams;\n  for (auto& file_path : file_paths) {\n    if (with_local_copy) {\n      streams.emplace_back(new BinaryInStreamWithLocalCopy(fs, file_path));\n    } else {\n      streams.emplace_back(new BinaryInStreamWithoutLocalCopy(fs, file_path));\n    }\n  }\n  if (cyclic) {\n    stream_scanner_.reset(new CyclicStreamScanner(fs, streams, offset));\n  } else {\n    stream_scanner_.reset(new AcyclicStreamScanner(fs, streams, offset));\n  }\n  buffer_.resize(GetBufferSize() + 1);\n  cur_buf_begin_ = buffer_.data();\n  cur_buf_end_ = buffer_.data();\n  *cur_buf_end_ = '\\0';\n}\n\nPersistentInStream::PersistentInStream(fs::FileSystem* fs,\n                                       const std::vector<std::string>& file_paths, bool cyclic,\n                                       bool with_local_copy)\n    : PersistentInStream(fs, file_paths, 0, cyclic, with_local_copy) {}\n\nPersistentInStream::PersistentInStream(fs::FileSystem* fs, const std::string& file_path,\n                                       uint64_t offset, bool cyclic, bool with_local_copy)\n    : PersistentInStream(fs, std::vector<std::string>({file_path}), offset, cyclic,\n                         with_local_copy) {}\n\nPersistentInStream::PersistentInStream(fs::FileSystem* fs, const std::string& file_path,\n     
                                  uint64_t offset)\n    : PersistentInStream(fs, file_path, offset, false, false) {}\n\nPersistentInStream::PersistentInStream(fs::FileSystem* fs, const std::string& file_path)\n    : PersistentInStream(fs, file_path, 0, false, false) {}\n\nPersistentInStream::PersistentInStream(int64_t session_id, fs::FileSystem* fs,\n                                       const std::string& file_path)\n    : PersistentInStream(session_id, fs, std::vector<std::string>({file_path}), 0, false, false) {}\n\nint32_t PersistentInStream::ReadLine(std::string* l) {\n  if (IsEof()) { return -1; }\n  l->clear();\n  while (*cur_buf_begin_ != '\\n') {\n    if (cur_buf_begin_ == cur_buf_end_) {\n      UpdateBuffer();\n      if (cur_buf_begin_ == cur_buf_end_) {\n        return 0;\n      } else {\n        continue;\n      }\n    }\n    l->push_back(*cur_buf_begin_++);\n  }\n  ++cur_buf_begin_;\n  return 0;\n}\n\nint32_t PersistentInStream::ReadFully(char* s, size_t n) {\n  if (IsEof()) { return -1; }\n  while (n) {\n    if (cur_buf_begin_ == cur_buf_end_) { UpdateBuffer(); }\n    CHECK_LT(cur_buf_begin_, cur_buf_end_);\n    int64_t copy_size = std::min<int64_t>(cur_buf_end_ - cur_buf_begin_, n);\n    std::memcpy(s, cur_buf_begin_, static_cast<size_t>(copy_size));\n    s += copy_size;\n    cur_buf_begin_ += copy_size;\n    n -= copy_size;\n  }\n  return 0;\n}\n\nvoid PersistentInStream::UpdateBuffer() {\n  CHECK_EQ(cur_buf_begin_, cur_buf_end_);\n  uint64_t n = stream_scanner_->UpdateBuffer(&buffer_);\n  cur_buf_begin_ = buffer_.data();\n  cur_buf_end_ = buffer_.data() + n;\n  *cur_buf_end_ = '\\0';\n}\n\nbool PersistentInStream::IsEof() const {\n  return cur_buf_begin_ == cur_buf_end_ && stream_scanner_->IsEof();\n}\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/persistence/persistent_in_stream.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PERSISTENCE_PERSISTENT_IN_STREAM_H_\n#define ONEFLOW_CORE_PERSISTENCE_PERSISTENT_IN_STREAM_H_\n\n#include \"oneflow/core/persistence/file_system.h\"\n#include \"oneflow/core/persistence/stream_scanner.h\"\n\nnamespace oneflow {\n\nclass PersistentInStream {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PersistentInStream);\n  virtual ~PersistentInStream() {}\n  PersistentInStream(fs::FileSystem* fs, const std::vector<std::string>& file_paths,\n                     uint64_t offset, bool cyclic, bool with_local_copy);\n  PersistentInStream(fs::FileSystem* fs, const std::vector<std::string>& file_paths, bool cyclic,\n                     bool with_local_copy);\n  PersistentInStream(fs::FileSystem* fs, const std::string& file_path, uint64_t offset, bool cyclic,\n                     bool with_local_copy);\n  PersistentInStream(fs::FileSystem* fs, const std::string& file_path, uint64_t offset);\n  PersistentInStream(fs::FileSystem* fs, const std::string& file_path);\n  PersistentInStream(int64_t session_id, fs::FileSystem* fs, const std::string& file_path);\n\n  PersistentInStream(int64_t session_id, fs::FileSystem* fs,\n                     const std::vector<std::string>& file_paths, uint64_t offset, bool cyclic,\n                     bool with_local_copy);\n\n  // 0: success\n  // -1: eof\n  int32_t ReadLine(std::string* l);\n  int32_t ReadFully(char* s, size_t n);\n\n private:\n  bool IsEof() const;\n  void UpdateBuffer();\n\n  std::unique_ptr<StreamScanner> stream_scanner_;\n\n  std::vector<char> buffer_;\n  char* cur_buf_begin_;\n  char* cur_buf_end_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PERSISTENCE_PERSISTENT_IN_STREAM_H_\n"
  },
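The header above is the entire public surface of PersistentInStream. A minimal usage sketch, not part of the repository dump: it assumes `LocalFS()` yields a usable `fs::FileSystem*` (as it does in tee_persistent_log_stream.cpp below), and the path is a placeholder.

```cpp
// Sketch only: read a text file line by line through PersistentInStream.
#include "oneflow/core/persistence/persistent_in_stream.h"

namespace oneflow {

void DumpLinesExample() {
  // LocalFS() and the path are assumptions for illustration.
  PersistentInStream in_stream(LocalFS(), "/tmp/example.txt");  // acyclic, no local copy
  std::string line;
  // ReadLine() returns 0 on success and -1 once the stream is exhausted.
  while (in_stream.ReadLine(&line) == 0) { LOG(INFO) << line; }
}

}  // namespace oneflow
```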
  {
    "path": "oneflow/core/persistence/persistent_out_stream.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/persistence/persistent_out_stream.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/control/ctrl_client.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n\nnamespace oneflow {\n\nPersistentOutStream::PersistentOutStream(fs::FileSystem* fs, const std::string& file_path) {\n  std::string file_dir = Dirname(file_path);\n  OfCallOnce(GlobalProcessCtx::LogDirEntry() + \"/\" + file_dir, fs,\n             &fs::FileSystem::RecursivelyCreateDirIfNotExist, file_dir);\n  fs->NewWritableFile(file_path, &file_);\n}\n\nPersistentOutStream::~PersistentOutStream() { file_->Close(); }\n\nPersistentOutStream& PersistentOutStream::Write(const char* s, size_t n) {\n  file_->Append(s, n);\n  return *this;\n}\n\nvoid PersistentOutStream::Flush() { file_->Flush(); }\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/persistence/persistent_out_stream.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PERSISTENCE_PERSISTENT_OUT_STREAM_H_\n#define ONEFLOW_CORE_PERSISTENCE_PERSISTENT_OUT_STREAM_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/persistence/file_system.h\"\n\nnamespace oneflow {\n\nclass PersistentOutStream final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PersistentOutStream);\n  PersistentOutStream() = delete;\n  ~PersistentOutStream();\n\n  PersistentOutStream(fs::FileSystem*, const std::string& file_path);\n\n  // Write block of data\n  // Inserts the first n characters of the array pointed by s into the stream.\n  PersistentOutStream& Write(const char* s, size_t n);\n\n  void Flush();\n\n private:\n  std::unique_ptr<fs::WritableFile> file_;\n};\n\ntemplate<typename T>\ntypename std::enable_if<std::is_fundamental<T>::value, PersistentOutStream&>::type operator<<(\n    PersistentOutStream& out_stream, const T& x) {\n  const char* x_ptr = reinterpret_cast<const char*>(&x);\n  size_t n = sizeof(x);\n  out_stream.Write(x_ptr, n);\n  return out_stream;\n}\n\ninline PersistentOutStream& operator<<(PersistentOutStream& out_stream, const std::string& s) {\n  out_stream.Write(s.c_str(), s.size());\n  return out_stream;\n}\n\ntemplate<size_t n>\nPersistentOutStream& operator<<(PersistentOutStream& out_stream, const char (&s)[n]) {\n  out_stream.Write(s, strlen(s));\n  return out_stream;\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PERSISTENCY_PERSISTENT_OUT_STREAM_H\n"
  },
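A hedged sketch of how the stream and its inserters compose. Note the gotcha the header implies: the fundamental-type overload writes raw bytes, not formatted text. `LocalFS()` and the output path are assumptions for illustration.

```cpp
// Sketch only: PersistentOutStream usage.
#include "oneflow/core/persistence/persistent_out_stream.h"

namespace oneflow {

void WriteExample() {
  PersistentOutStream out(LocalFS(), "/tmp/example.out");  // placeholder destination
  out << "answer=";          // char-array overload, strlen-delimited
  int32_t answer = 42;
  out << answer;             // fundamental types are written as raw bytes, not text
  out << std::string("\n");  // std::string overload
  out.Flush();
}  // the destructor closes the underlying WritableFile

}  // namespace oneflow
```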
  {
    "path": "oneflow/core/persistence/posix/posix_file_system.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/persistence/posix/posix_file_system.h\"\n\n#ifdef OF_PLATFORM_POSIX\n\n#include <dirent.h>\n#include <errno.h>\n#include <fcntl.h>\n#include <stdio.h>\n#include <sys/mman.h>\n#include <sys/stat.h>\n#include <sys/time.h>\n#include <sys/types.h>\n#include <time.h>\n#include <unistd.h>\n\nnamespace oneflow {\n\nnamespace fs {\n\nclass PosixRandomAccessFile : public RandomAccessFile {\n private:\n  std::string fname_;\n  int fd_;\n\n public:\n  PosixRandomAccessFile(const std::string& fname, int fd) : fname_(fname), fd_(fd) {}\n  ~PosixRandomAccessFile() override { close(fd_); }\n\n  void Read(uint64_t offset, size_t n, char* result) const override {\n    char* dst = result;\n    while (n > 0) {\n      ssize_t r = pread(fd_, dst, n, static_cast<off_t>(offset));\n      if (r > 0) {\n        dst += r;\n        n -= r;\n        offset += r;\n      } else if (r == 0) {\n        PLOG(FATAL) << \"Read EOF\";\n        return;\n      } else if (errno == EINTR || errno == EAGAIN) {\n        // Retry\n      } else {\n        PLOG(FATAL) << \"Fail to read file \" << fname_;\n        return;\n      }\n    }\n  }\n};\n\nclass PosixWritableFile : public WritableFile {\n private:\n  std::string fname_;\n  FILE* file_;\n\n public:\n  PosixWritableFile(const std::string& fname, FILE* file) : fname_(fname), file_(file) {}\n\n  ~PosixWritableFile() override {\n    if (file_ != nullptr) { fclose(file_); }\n  }\n\n  void Append(const char* data, size_t n) override {\n    PCHECK(fwrite(data, sizeof(char), n, file_) == n)\n        << \"Fail to append to file \" << fname_ << \", errno is \" << errno;\n  }\n\n  void Close() override {\n    Flush();\n    PCHECK(fclose(file_) == 0) << \"Fail to close file \" << fname_ << \", errno is \" << errno;\n    file_ = nullptr;\n  }\n\n  void Flush() override {\n    PCHECK(fflush(file_) == 0) << \"Fail to flush file \" << fname_ << \", errno is \" << errno;\n  }\n};\n\nvoid PosixFileSystem::NewRandomAccessFile(const std::string& fname,\n                                          std::unique_ptr<RandomAccessFile>* result) {\n  std::string translated_fname = TranslateName(fname);\n  int fd = open(translated_fname.c_str(), O_RDONLY);\n  PCHECK(fd >= 0) << \"Fail to open file \" << fname << \", errno is \" << errno;\n  result->reset(new PosixRandomAccessFile(fname, fd));\n  CHECK_NOTNULL(result->get());\n}\n\nvoid PosixFileSystem::NewWritableFile(const std::string& fname,\n                                      std::unique_ptr<WritableFile>* result) {\n  std::string translated_fname = TranslateName(fname);\n  FILE* f = fopen(translated_fname.c_str(), \"w\");\n  PCHECK(f != nullptr) << \"Fail to open file \" << fname << \", errno is \" << errno;\n  result->reset(new PosixWritableFile(translated_fname, f));\n  CHECK_NOTNULL(result->get());\n}\n\nvoid PosixFileSystem::NewAppendableFile(const std::string& fname,\n                                        
std::unique_ptr<WritableFile>* result) {\n  std::string translated_name = TranslateName(fname);\n  FILE* f = fopen(translated_name.c_str(), \"a\");\n  PCHECK(f != nullptr) << \"Fail to open file \" << fname << \", errno is \" << errno;\n  result->reset(new PosixWritableFile(translated_name, f));\n  CHECK_NOTNULL(result->get());\n}\n\nbool PosixFileSystem::FileExists(const std::string& fname) {\n  if (access(TranslateName(fname).c_str(), F_OK) == 0) { return true; }\n  return false;\n}\n\nstd::vector<std::string> PosixFileSystem::ListDir(const std::string& dir) {\n  std::string translated_dir = TranslateName(dir);\n  std::vector<std::string> result;\n  DIR* d = opendir(translated_dir.c_str());\n  PCHECK(d != nullptr) << \"Fail to open dir \" << dir << \", errno is \" << errno;\n  struct dirent* entry;\n  while ((entry = readdir(d)) != nullptr) {\n    if (strcmp(entry->d_name, \".\") == 0 || strcmp(entry->d_name, \"..\") == 0) { continue; }\n    result.emplace_back(entry->d_name);\n  }\n  closedir(d);\n  return result;\n}\n\nvoid PosixFileSystem::DelFile(const std::string& fname) {\n  PCHECK(unlink(TranslateName(fname).c_str()) == 0)\n      << \"Fail to delete file \" << fname << \", errno is \" << errno;\n}\n\nvoid PosixFileSystem::CreateDir(const std::string& dirname) {\n  PCHECK(mkdir(TranslateName(dirname).c_str(), 0755) == 0)\n      << \"Fail to create dir \" << dirname << \", errno is \" << errno;\n}\n\nvoid PosixFileSystem::CreateDirIfNotExist(const std::string& dirname) {\n  int ret = mkdir(TranslateName(dirname).c_str(), 0755);\n  PCHECK(ret == 0 || (errno == EEXIST && IsDirectory(dirname)))\n      << \"Fail to create dir \" << dirname << \", errno is \" << errno;\n}\n\nvoid PosixFileSystem::DeleteDir(const std::string& dirname) {\n  PCHECK(rmdir(TranslateName(dirname).c_str()) == 0)\n      << \"Fail to delete dir \" << dirname << \", errno is \" << errno;\n}\n\nuint64_t PosixFileSystem::GetFileSize(const std::string& fname) {\n  struct stat sbuf;\n  PCHECK(stat(TranslateName(fname).c_str(), &sbuf) == 0)\n      << \"Fail to load statistics of \" << fname << \", errno is \" << errno;\n  return sbuf.st_size;\n}\n\nvoid PosixFileSystem::RenameFile(const std::string& old_name, const std::string& new_name) {\n  PCHECK(rename(TranslateName(old_name).c_str(), TranslateName(new_name).c_str()) == 0)\n      << \"Fail to rename file from \" << old_name << \" to \" << new_name << \", errno is \" << errno;\n}\n\nbool PosixFileSystem::IsDirectory(const std::string& fname) {\n  struct stat sbuf;\n  if (stat(TranslateName(fname).c_str(), &sbuf) == 0 && S_ISDIR(sbuf.st_mode)) { return true; }\n  return false;\n}\n\n}  // namespace fs\n\n}  // namespace oneflow\n\n#endif  // OF_PLATFORM_POSIX\n"
  },
  {
    "path": "oneflow/core/persistence/posix/posix_file_system.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PERSISTENCE_POSIX_POSIX_FILE_SYSTEM_H_\n#define ONEFLOW_CORE_PERSISTENCE_POSIX_POSIX_FILE_SYSTEM_H_\n\n#include \"oneflow/core/persistence/file_system.h\"\n\n#ifdef OF_PLATFORM_POSIX\n\nnamespace oneflow {\n\nnamespace fs {\n\nclass PosixFileSystem final : public FileSystem {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(PosixFileSystem);\n  PosixFileSystem() = default;\n  ~PosixFileSystem() = default;\n\n  void NewRandomAccessFile(const std::string& fname,\n                           std::unique_ptr<RandomAccessFile>* result) override;\n\n  void NewWritableFile(const std::string& fname, std::unique_ptr<WritableFile>* result) override;\n\n  void NewAppendableFile(const std::string& fname, std::unique_ptr<WritableFile>* result) override;\n\n  bool FileExists(const std::string& fname) override;\n\n  std::vector<std::string> ListDir(const std::string& dir) override;\n\n  void DelFile(const std::string& fname) override;\n\n  void CreateDir(const std::string& dirname) override;\n\n  void CreateDirIfNotExist(const std::string& dirname) override;\n\n  void DeleteDir(const std::string& dirname) override;\n\n  uint64_t GetFileSize(const std::string& fname) override;\n\n  void RenameFile(const std::string& old_name, const std::string& new_name) override;\n\n  bool IsDirectory(const std::string& fname) override;\n\n private:\n};\n\n}  // namespace fs\n\n}  // namespace oneflow\n\n#endif  // OF_PLATFORM_POSIX\n\n#endif  // ONEFLOW_CORE_PERSISTENCE_POSIX_POSIX_FILE_SYSTEM_H_\n"
  },
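A sketch of driving the POSIX backend directly on a POSIX build (not part of the repository; paths are placeholders). Failures abort via PCHECK inside the implementation rather than returning error codes, so the happy path reads straight through.

```cpp
// Sketch only: exercising the PosixFileSystem interface.
#include "oneflow/core/persistence/posix/posix_file_system.h"

namespace oneflow {

void PosixFsExample() {
  fs::PosixFileSystem posix_fs;
  posix_fs.CreateDirIfNotExist("/tmp/of_demo");  // EEXIST on an existing dir is tolerated
  std::unique_ptr<fs::WritableFile> file;
  posix_fs.NewWritableFile("/tmp/of_demo/a.bin", &file);
  file->Append("abc", 3);
  file->Close();
  CHECK_EQ(posix_fs.GetFileSize("/tmp/of_demo/a.bin"), 3);
  for (const std::string& name : posix_fs.ListDir("/tmp/of_demo")) { LOG(INFO) << name; }
}

}  // namespace oneflow
```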
  {
    "path": "oneflow/core/persistence/stream_scanner.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/persistence/stream_scanner.h\"\n#include \"oneflow/core/persistence/binary_in_stream_without_local_copy.h\"\n#include \"oneflow/core/persistence/binary_in_stream_with_local_copy.h\"\n\nnamespace oneflow {\n\nStreamScanner::StreamScanner(fs::FileSystem* fs,\n                             const std::vector<std::shared_ptr<BinaryInStream>>& streams,\n                             uint64_t offset)\n    : whole_file_offset_(offset) {\n  stream_num_ = streams.size();\n  whole_file_size_ = 0;\n  int64_t idx = 0;\n  for (auto& stream : streams) {\n    AddStream(fs, stream, idx);\n    ++idx;\n  }\n  CHECK_LE(whole_file_offset_, whole_file_size_);\n  whole_file_pos_ = whole_file_offset_;\n}\n\nvoid StreamScanner::AddStream(fs::FileSystem* fs, const std::shared_ptr<BinaryInStream>& stream,\n                              int64_t idx) {\n  uint64_t cur_file_size = stream->file_size();\n  if (whole_file_offset_ < whole_file_size_) {\n    stream->set_cur_file_pos(0);\n  } else if (whole_file_size_ <= whole_file_offset_\n             && whole_file_offset_ < whole_file_size_ + cur_file_size) {\n    stream->set_cur_file_pos(whole_file_offset_ - whole_file_size_);\n    cur_stream_id_ = idx;\n  } else if (whole_file_offset_ >= whole_file_size_ + cur_file_size) {\n    stream->set_cur_file_pos(0);  // works for both cyclic and acyclic cases\n  }\n\n  streams_.emplace_back(stream);\n  whole_file_size_ += cur_file_size;\n}\n\nbool StreamScanner::IsEof() const { return whole_file_pos_ == whole_file_size_; }\n\nuint64_t StreamScanner::UpdateBuffer(std::vector<char>* buffer) {\n  if (cur_stream_id_ == stream_num_) return 0;\n  uint64_t n =\n      std::min<uint64_t>(buffer->size() - 1, streams_[cur_stream_id_]->file_size()\n                                                 - streams_[cur_stream_id_]->cur_file_pos());\n  if (n == 0) { return 0; }\n  streams_[cur_stream_id_]->Read(buffer->data(), n);\n  AddNForCurFilePos(n);\n  return n;\n}\n\nvoid AcyclicStreamScanner::AddNForCurFilePos(uint64_t n) {\n  whole_file_pos_ += n;\n  if (streams_[cur_stream_id_]->IsEof()) { ++cur_stream_id_; }\n}\n\nvoid CyclicStreamScanner::AddNForCurFilePos(uint64_t n) {\n  whole_file_pos_ = (whole_file_pos_ + n) % whole_file_size_;\n  if (streams_[cur_stream_id_]->IsEof()) {\n    streams_[cur_stream_id_]->set_cur_file_pos(0);\n    ++cur_stream_id_;\n    if (cur_stream_id_ == stream_num_) {\n      CHECK_EQ(whole_file_pos_, 0);\n      cur_stream_id_ = 0;\n    }\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/persistence/stream_scanner.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PERSISTENCE_STREAM_SCANNER_H_\n#define ONEFLOW_CORE_PERSISTENCE_STREAM_SCANNER_H_\n\n#include <vector>\n#include <string>\n#include \"oneflow/core/persistence/binary_in_stream.h\"\n#include \"oneflow/core/persistence/file_system.h\"\n\nnamespace oneflow {\n\nclass StreamScanner {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(StreamScanner);\n  virtual ~StreamScanner() {}\n\n  StreamScanner(fs::FileSystem* fs, const std::vector<std::shared_ptr<BinaryInStream>>& streams,\n                uint64_t offset);\n  bool IsEof() const;\n  uint64_t UpdateBuffer(std::vector<char>* buffer);\n\n protected:\n  virtual void AddNForCurFilePos(uint64_t n) = 0;\n\n  std::vector<std::shared_ptr<BinaryInStream>> streams_;\n  uint64_t whole_file_size_;\n  uint64_t whole_file_pos_;\n  int32_t cur_stream_id_;\n  int32_t stream_num_;\n  uint64_t whole_file_offset_;\n\n private:\n  void AddStream(fs::FileSystem* fs, const std::shared_ptr<BinaryInStream>& stream, int64_t idx);\n};\n\nclass CyclicStreamScanner final : public StreamScanner {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CyclicStreamScanner);\n  CyclicStreamScanner(fs::FileSystem* fs,\n                      const std::vector<std::shared_ptr<BinaryInStream>>& streams, uint64_t offset)\n      : StreamScanner(fs, streams, offset) {}\n\n protected:\n  void AddNForCurFilePos(uint64_t n) override;\n};\n\nclass AcyclicStreamScanner final : public StreamScanner {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AcyclicStreamScanner);\n  AcyclicStreamScanner(fs::FileSystem* fs,\n                       const std::vector<std::shared_ptr<BinaryInStream>>& streams, uint64_t offset)\n      : StreamScanner(fs, streams, offset) {}\n\n protected:\n  void AddNForCurFilePos(uint64_t n) override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PERSISTENCE_STREAM_SCANNER_H_\n"
  },
  {
    "path": "oneflow/core/persistence/tee_persistent_log_stream.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include <google/protobuf/text_format.h>\n\nnamespace oneflow {\n\nTeePersistentLogStream::TeePersistentLogStream(const std::string& path) {\n  destinations_.emplace_back(LocalFS(), FLAGS_log_dir);\n  branches_.reserve(destinations_.size());\n  for (const auto& destination : destinations_) {\n    branches_.emplace_back(std::make_unique<PersistentOutStream>(\n        destination.mut_file_system(), JoinPath(destination.base_dir(), path)));\n  }\n}\n\nTeePersistentLogStream::~TeePersistentLogStream() { Flush(); }\n\nstd::unique_ptr<TeePersistentLogStream> TeePersistentLogStream::Create(const std::string& path) {\n  auto stream_ptr = new TeePersistentLogStream(path);\n  return std::unique_ptr<TeePersistentLogStream>(stream_ptr);\n}\n\nvoid TeePersistentLogStream::Flush() {\n  for (const auto& branch : branches_) { branch->Flush(); }\n};\n\nvoid TeePersistentLogStream::Write(const char* s, size_t n) {\n  for (const auto& branch : branches_) { branch->Write(s, n); }\n};\n\nvoid TeePersistentLogStream::Write(const std::string& str) { this->Write(str.data(), str.size()); }\n\nvoid TeePersistentLogStream::Write(const PbMessage& proto) {\n  std::string output;\n  google::protobuf::TextFormat::PrintToString(proto, &output);\n  this->Write(output);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/persistence/tee_persistent_log_stream.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PERSISTENCE_TEE_PERSISTENT_LOG_STREAM_H_\n#define ONEFLOW_CORE_PERSISTENCE_TEE_PERSISTENT_LOG_STREAM_H_\n\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/persistence/persistent_out_stream.h\"\n\nnamespace oneflow {\n\nclass LogStreamDestination final {\n public:\n  LogStreamDestination(fs::FileSystem* file_system, const std::string& base_dir)\n      : file_system_(file_system), base_dir_(base_dir) {}\n  ~LogStreamDestination() = default;\n  fs::FileSystem* mut_file_system() const { return file_system_; };\n  const std::string& base_dir() const { return base_dir_; };\n\n private:\n  fs::FileSystem* file_system_;\n  std::string base_dir_;\n};\n\nclass TeePersistentLogStream final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TeePersistentLogStream);\n  ~TeePersistentLogStream();\n\n  void Write(const char* s, size_t n);\n  void Write(const std::string& str);\n  void Write(const PbMessage& proto);\n\n  static std::unique_ptr<TeePersistentLogStream> Create(const std::string& path);\n  void Flush();\n\n private:\n  explicit TeePersistentLogStream(const std::string& path);\n  std::vector<LogStreamDestination> destinations_;\n  std::vector<std::unique_ptr<PersistentOutStream>> branches_;\n};\n\ninline TeePersistentLogStream& operator<<(TeePersistentLogStream& log_stream,\n                                          const std::string& s) {\n  log_stream.Write(s.c_str(), s.size());\n  return log_stream;\n}\n\ninline std::unique_ptr<TeePersistentLogStream>& operator<<(\n    std::unique_ptr<TeePersistentLogStream>& log_stream, const std::string& s) {\n  log_stream->Write(s.c_str(), s.size());\n  return log_stream;\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PERSISTENCE_TEE_PERSISTENT_LOG_STREAM_H_\n"
  },
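A hedged sketch of the tee pattern above: every `Write()` fans out to all configured destinations (currently the single `FLAGS_log_dir` branch on `LocalFS()`). The relative path and the `job_conf` argument are placeholders.

```cpp
// Sketch only: TeePersistentLogStream usage.
#include "oneflow/core/persistence/tee_persistent_log_stream.h"

namespace oneflow {

void LogJobConfExample(const PbMessage& job_conf) {
  auto log_stream = TeePersistentLogStream::Create("dump/job_conf.prototxt");
  log_stream->Write(job_conf);                   // serialized via protobuf TextFormat
  log_stream << std::string("# end of dump\n");  // unique_ptr operator<< overload
}  // the destructor flushes every branch

}  // namespace oneflow
```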
  {
    "path": "oneflow/core/platform/include/ibv.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#if defined(WITH_RDMA)\n#ifndef ONEFLOW_CORE_PLATFORM_INCLUDE_IBV_H_\n#define ONEFLOW_CORE_PLATFORM_INCLUDE_IBV_H_\n#include \"oneflow/core/platform/include/wrapper.h\"\n#include <infiniband/verbs.h>\n\nnamespace oneflow {\n\nnamespace ibv {\n// has to add extern otherwise it fails to compile at changes meaning of functions\nextern \"C\" typedef struct IBV {\n#define IBV_APIS(_)       \\\n  _(ibv_free_device_list) \\\n  _(ibv_destroy_qp)       \\\n  _(ibv_query_gid)        \\\n  _(ibv_fork_init)        \\\n  _(ibv_open_device)      \\\n  _(ibv_destroy_cq)       \\\n  _(ibv_alloc_pd)         \\\n  _(ibv_modify_qp)        \\\n  _(ibv_dealloc_pd)       \\\n  _(ibv_get_device_list)  \\\n  _(ibv_close_device)     \\\n  _(ibv_create_qp)        \\\n  _(ibv_dereg_mr)         \\\n  _(ibv_create_cq)        \\\n  _(ibv_query_device)     \\\n  _(ibv_get_device_name)\n\n#define DECLARE_ONE(name) decltype(&name) name;\n  IBV_APIS(DECLARE_ONE)\n#undef DECLARE_ONE\n  // for a function is not only a function but also a macro,\n  // it requires an alternative name\n  struct ibv_mr* (*ibv_reg_mr_wrap)(struct ibv_pd* pd, void* addr, size_t length, int access);\n  int (*ibv_query_port_wrap)(struct ibv_context* context, uint8_t port_num,\n                             struct ibv_port_attr* port_attr);\n} IBV;\n\nbool IsAvailable();\n\nextern IBV wrapper;\n\n}  // namespace ibv\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PLATFORM_INCLUDE_IBV_H_\n#endif  // WITH_RDMA\n"
  },
  {
    "path": "oneflow/core/platform/include/pthread_fork.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PLATFORM_INCLUDE_PTHREAD_FORK_H_\n#define ONEFLOW_CORE_PLATFORM_INCLUDE_PTHREAD_FORK_H_\n\nnamespace oneflow {\n\nnamespace pthread_fork {\n\nbool IsForkedSubProcess();\n\nextern const char* kOfCudaNotSupportInForkedSubProcess;\n\n}  // namespace pthread_fork\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PLATFORM_INCLUDE_PTHREAD_FORK_H_\n"
  },
  {
    "path": "oneflow/core/platform/include/wrapper.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PLATFORM_INCLUDE_WRAPPER_H_\n#define ONEFLOW_CORE_PLATFORM_INCLUDE_WRAPPER_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace platform {\n\nclass DynamicLibrary {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DynamicLibrary);\n  ~DynamicLibrary();\n\n  static std::unique_ptr<DynamicLibrary> Load(const std::vector<std::string>& names);\n  void* LoadSym(const char* name);\n#ifdef __linux__\n  std::string AbsolutePath();\n#endif  // __linux__\n\n private:\n  DynamicLibrary(void* handle) : handle_(handle){};\n  void* handle_ = nullptr;\n};\n\n}  // namespace platform\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PLATFORM_INCLUDE_WRAPPER_H_\n"
  },
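A sketch of the load-by-any-name / resolve-symbol pattern that the libibverbs wrapper below builds on, demonstrated with libm as a familiar stand-in library (the library names and the `cos` symbol are illustration choices, not anything OneFlow loads).

```cpp
// Sketch only: DynamicLibrary usage.
#include "oneflow/core/platform/include/wrapper.h"

namespace oneflow {

void DynamicLibraryExample() {
  auto lib = platform::DynamicLibrary::Load({"libm.so.6", "libm.so"});
  if (!lib) { return; }  // Load() returns an empty unique_ptr if every name fails
  // LoadSym() aborts on a missing symbol, so only request symbols known to exist.
  auto* cosine = reinterpret_cast<double (*)(double)>(lib->LoadSym("cos"));
  LOG(INFO) << "cos(0) = " << cosine(0.0);  // 1
}

}  // namespace oneflow
```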
  {
    "path": "oneflow/core/platform/lib/ibv_wrapper.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#if defined(WITH_RDMA)\n#include \"oneflow/core/platform/include/ibv.h\"\n\nnamespace oneflow {\n\nnamespace ibv {\n\nstd::vector<std::string> GetLibPaths() {\n  const char* custom_path = std::getenv(\"ONEFLOW_LIBIBVERBS_PATH\");\n  if (custom_path == nullptr) {\n    return {\"libibverbs.so.1\", \"libibverbs.so\"};\n  } else {\n    return {custom_path};\n  }\n}\n\nplatform::DynamicLibrary* GetIBVLibraryPtr() {\n  static std::unique_ptr<platform::DynamicLibrary> lib =\n      platform::DynamicLibrary::Load(GetLibPaths());\n  return lib.get();\n}\n\nplatform::DynamicLibrary& GetIBVLibrary() {\n  platform::DynamicLibrary* lib = GetIBVLibraryPtr();\n  CHECK(lib != nullptr) << \"fail to find libibverbs\";\n  return *lib;\n}\n\ntemplate<typename FUNC>\nFUNC LoadSymbol(const char* name, FUNC* save) {\n  auto fn = reinterpret_cast<FUNC>(GetIBVLibrary().LoadSym(name));\n  if (!fn) {\n    std::cerr << \"Can't load libibverbs symbol \" << name << \"\\n\";\n    abort();\n  };\n  *save = fn;\n  return fn;\n}\n\nbool IsAvailable() { return GetIBVLibraryPtr() != nullptr; }\n\nnamespace _stubs {\n\nvoid ibv_free_device_list(struct ibv_device** list) {\n  return LoadSymbol(__func__, &wrapper.ibv_free_device_list)(list);\n}\n\nstruct ibv_mr* ibv_reg_mr_wrap(struct ibv_pd* pd, void* addr, size_t length, int access) {\n  return LoadSymbol(\"ibv_reg_mr\", &wrapper.ibv_reg_mr_wrap)(pd, addr, length, access);\n}\n\nint ibv_destroy_qp(struct ibv_qp* qp) { return LoadSymbol(__func__, &wrapper.ibv_destroy_qp)(qp); }\n\nint ibv_query_gid(struct ibv_context* context, uint8_t port_num, int index, union ibv_gid* gid) {\n  return LoadSymbol(__func__, &wrapper.ibv_query_gid)(context, port_num, index, gid);\n}\n\nint ibv_fork_init(void) { return LoadSymbol(__func__, &wrapper.ibv_fork_init)(); }\n\nint ibv_query_port_wrap(struct ibv_context* context, uint8_t port_num,\n                        struct ibv_port_attr* port_attr) {\n  return LoadSymbol(\"ibv_query_port\", &wrapper.ibv_query_port_wrap)(context, port_num, port_attr);\n}\n\nstruct ibv_context* ibv_open_device(struct ibv_device* device) {\n  return LoadSymbol(__func__, &wrapper.ibv_open_device)(device);\n}\n\nint ibv_destroy_cq(struct ibv_cq* cq) { return LoadSymbol(__func__, &wrapper.ibv_destroy_cq)(cq); }\n\nstruct ibv_pd* ibv_alloc_pd(struct ibv_context* context) {\n  return LoadSymbol(__func__, &wrapper.ibv_alloc_pd)(context);\n}\n\nint ibv_modify_qp(struct ibv_qp* qp, struct ibv_qp_attr* attr, int attr_mask) {\n  return LoadSymbol(__func__, &wrapper.ibv_modify_qp)(qp, attr, attr_mask);\n}\n\nint ibv_dealloc_pd(struct ibv_pd* pd) { return LoadSymbol(__func__, &wrapper.ibv_dealloc_pd)(pd); }\n\nstruct ibv_device** ibv_get_device_list(int* num_devices) {\n  return LoadSymbol(__func__, &wrapper.ibv_get_device_list)(num_devices);\n}\n\nint ibv_close_device(struct ibv_context* context) {\n  return LoadSymbol(__func__, 
&wrapper.ibv_close_device)(context);\n}\n\nstruct ibv_qp* ibv_create_qp(struct ibv_pd* pd, struct ibv_qp_init_attr* qp_init_attr) {\n  return LoadSymbol(__func__, &wrapper.ibv_create_qp)(pd, qp_init_attr);\n}\n\nint ibv_dereg_mr(struct ibv_mr* mr) { return LoadSymbol(__func__, &wrapper.ibv_dereg_mr)(mr); }\n\nstruct ibv_cq* ibv_create_cq(struct ibv_context* context, int cqe, void* cq_context,\n                             struct ibv_comp_channel* channel, int comp_vector) {\n  return LoadSymbol(__func__, &wrapper.ibv_create_cq)(context, cqe, cq_context, channel,\n                                                      comp_vector);\n}\n\nint ibv_query_device(struct ibv_context* context, struct ibv_device_attr* device_attr) {\n  return LoadSymbol(__func__, &wrapper.ibv_query_device)(context, device_attr);\n}\n\nconst char* ibv_get_device_name(struct ibv_device* device) {\n  return LoadSymbol(__func__, &wrapper.ibv_get_device_name)(device);\n}\n\n}  // namespace _stubs\n\nIBV wrapper = {\n#define _REFERENCE_MEMBER(name) _stubs::name,\n    IBV_APIS(_REFERENCE_MEMBER)\n#undef _REFERENCE_MEMBER\n        _stubs::ibv_reg_mr_wrap,\n    _stubs::ibv_query_port_wrap};\n\n}  // namespace ibv\n}  // namespace oneflow\n#endif  // WITH_RDMA\n"
  },
  {
    "path": "oneflow/core/platform/lib/pthread_fork.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/platform/include/pthread_fork.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/vm/sync_vm_mode_guard.h\"\n\nnamespace oneflow {\n\nnamespace pthread_fork {\n\nstatic bool is_fork = false;\n\nbool IsForkedSubProcess() { return is_fork; }\nstatic void SetIsForkedSubProcess() { is_fork = true; }\n\nnamespace {\nvoid CurrentRankVmSync() {\n  if (SyncVmModeGuard::IsCurrentSyncVmMode()) { return; }\n  // Instructions in forked subprocesses are not dispatched to vm,\n  // so no need to sync vm in these processes.\n  if (!is_fork && Singleton<VirtualMachine>::Get() != nullptr) {\n    CHECK_JUST(vm::CurrentRankSync());\n  }\n}\n}  // namespace\n\nvoid RegisterForkCallback() { pthread_atfork(&CurrentRankVmSync, nullptr, &SetIsForkedSubProcess); }\nCOMMAND(RegisterForkCallback());\n\nconst char* kOfCudaNotSupportInForkedSubProcess =\n    \"Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you \"\n    \"must add 'multiprocessing.set_start_method(\\\"spawn\\\")' in '__main__' if you are using \"\n    \"Python's multiprocessing\";\n\n}  // namespace pthread_fork\n\n}  // namespace oneflow\n"
  },
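A standalone sketch of the pthread_atfork contract the file above relies on: the prepare handler runs in the parent before fork() (where CurrentRankVmSync drains the vm), and the child handler runs only in the forked child, which is exactly how SetIsForkedSubProcess() makes IsForkedSubProcess() true there and nowhere else. This demo is not OneFlow code.

```cpp
// Sketch only: pthread_atfork child-handler semantics.
#include <pthread.h>
#include <sys/wait.h>
#include <unistd.h>
#include <cstdio>

static bool is_fork = false;

int main() {
  pthread_atfork(/*prepare=*/nullptr, /*parent=*/nullptr,
                 /*child=*/+[] { is_fork = true; });  // runs only in the child
  pid_t pid = fork();
  if (pid == 0) {
    std::printf("child:  is_fork=%d\n", is_fork);  // prints 1
    _exit(0);
  }
  waitpid(pid, nullptr, 0);
  std::printf("parent: is_fork=%d\n", is_fork);  // prints 0
  return 0;
}
```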
  {
    "path": "oneflow/core/platform/lib/wrapper.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/platform/include/wrapper.h\"\n#include <dlfcn.h>\n\n#ifdef __linux__\n#include <link.h>\n#endif  // __linux__\n\nnamespace oneflow {\nnamespace platform {\n\nnamespace {\n\nvoid* OpenSymbol(void* handle, const char* name) {\n  void* ret = dlsym(handle, name);\n  if (!ret) {\n    std::cerr << \"Error in dlopen or dlsym: \" << dlerror() << \"\\n\";\n    abort();\n  }\n  return ret;\n}\n\n}  // namespace\n\n// original implementation is from pytorch:\n// https://github.com/pytorch/pytorch/blob/259d19a7335b32c4a27a018034551ca6ae997f6b/aten/src/ATen/DynamicLibrary.cpp\n\nstd::unique_ptr<DynamicLibrary> DynamicLibrary::Load(const std::vector<std::string>& names) {\n  for (const std::string& name : names) {\n    void* handle = dlopen(name.c_str(), RTLD_LOCAL | RTLD_NOW);\n    if (handle != nullptr) {\n      DynamicLibrary* lib = new DynamicLibrary(handle);\n      return std::unique_ptr<DynamicLibrary>(lib);\n    }\n  }\n  return std::unique_ptr<DynamicLibrary>();\n}\n\nvoid* DynamicLibrary::LoadSym(const char* name) { return OpenSymbol(handle_, name); }\n\n#ifdef __linux__\nstd::string DynamicLibrary::AbsolutePath() {\n  struct link_map* map;\n  dlinfo(handle_, RTLD_DI_LINKMAP, &map);\n  return map->l_name;\n}\n#endif  // __linux__\n\nDynamicLibrary::~DynamicLibrary() { dlclose(handle_); }\n\n}  // namespace platform\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/profiler/event.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"fmt/core.h\"\n#include \"fmt/format.h\"\n#include \"oneflow/core/profiler/event.h\"\n#include \"oneflow/core/profiler/util.h\"\n\nusing json = nlohmann::json;\n\nnamespace oneflow {\n\nnamespace profiler {\nnlohmann::json IEvent::ToJson() { return json{{\"name\", name_}, {\"time\", GetDuration<double>()}}; }\n\nvoid IEvent::SetStartedAt(double t) { started_at_ = t; }\n\nvoid IEvent::SetFinishedAt(double t) { finished_at_ = t; }\n\nvoid IEvent::Start() { SetStartedAt(GetTimeNow()); }\n\nvoid IEvent::Finish() { SetFinishedAt(GetTimeNow()); }\n\nbool IEvent::IsChildOf(const IEvent* e) {\n  if (!e) { return false; }\n  if (this == e) { return false; }\n  return GetStartedAt<double>() >= e->GetStartedAt<double>()\n         && GetFinishedAt<double>() <= e->GetFinishedAt<double>();\n}\n\nconst std::string& IEvent::GetName() const { return name_; }\n\nnlohmann::json CustomEvent::ToJson() {\n  auto j = IEvent::ToJson();\n  j[\"type\"] = EventType::kCustom;\n  j[\"custom_type\"] = type_;\n  return j;\n}\n\nstd::shared_ptr<CustomEvent> CustomEvent::Create(const std::string& name, CustomEventType type) {\n  return std::shared_ptr<CustomEvent>(new CustomEvent(name, type));\n}\n\nnlohmann::json KernelEvent::ToJson() {\n  auto j = IEvent::ToJson();\n  j[\"type\"] = EventType::kOneflowKernel;\n  for (const auto& desc : description_) {\n    j[\"description\"][desc.first] = {desc.second.first, desc.second.second};\n  }\n#if defined(WITH_CUDA)\n  j[\"memory_size\"] = memory_size_;\n  if (!children_.empty()) { j[\"children\"] = children_; }\n#endif  // WITH_CUDA\n  return j;\n}\n\nstd::shared_ptr<KernelEvent> KernelEvent::Create(const std::string& name,\n                                                 const Description& description) {\n  return std::shared_ptr<KernelEvent>(new KernelEvent(name, description));\n}\n\n}  // namespace profiler\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/profiler/event.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PROFILER_EVENT_H_\n#define ONEFLOW_CORE_PROFILER_EVENT_H_\n\n#include <functional>\n#include <memory>\n#include <vector>\n#include \"nlohmann/json.hpp\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/shape_view.h\"\n\nnamespace oneflow {\n\nnamespace profiler {\n\nclass ProfileManager;\n\nenum class EventType {\n  kCustom,        // has three kinds\n  kOneflowKernel  // OneFlow cpu/cuda kernel\n};\nenum class CustomEventType {\n  kDefault,     // for record_function\n  kCudaKernel,  // cuda kernel\n  kCudaRuntime  // something like cudaLaunchKernel\n};\nenum class EventTimeUnit { kNS, kUS };\n\nclass IEvent {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(IEvent);\n\n  IEvent() = delete;\n  IEvent(const std::string& name, EventTimeUnit time_unit) : name_(name), time_unit_(time_unit) {}\n\n  virtual nlohmann::json ToJson();\n  virtual ~IEvent() = default;\n\n  virtual void Start();\n  virtual void Finish();\n  bool IsChildOf(const IEvent* e);\n\n  const std::string& GetName() const;\n  template<typename T>\n  const T GetDuration(EventTimeUnit time_unit = EventTimeUnit::kUS) const;\n  template<typename T>\n  const T GetStartedAt(EventTimeUnit time_unit = EventTimeUnit::kUS) const;\n  template<typename T>\n  const T GetFinishedAt(EventTimeUnit time_unit = EventTimeUnit::kUS) const;\n\n protected:\n  virtual void SetStartedAt(double t);\n  virtual void SetFinishedAt(double t);\n\n  std::string name_;\n  EventTimeUnit time_unit_;\n  double started_at_ = 0;\n  double finished_at_ = 0;\n};\n\ninline double ConvertTime(double time_, EventTimeUnit src_time_unit, EventTimeUnit dst_time_unit) {\n  if (src_time_unit == EventTimeUnit::kNS && dst_time_unit == EventTimeUnit::kUS) {\n    return time_ / 1000;\n  }\n  if (src_time_unit == EventTimeUnit::kUS && dst_time_unit == EventTimeUnit::kNS) {\n    return time_ * 1000;\n  }\n  return time_;\n}\n\ntemplate<>\nconst inline double IEvent::GetStartedAt<double>(EventTimeUnit time_unit) const {\n  return ConvertTime(started_at_, time_unit_, time_unit);\n}\n\ntemplate<>\nconst inline time_t IEvent::GetStartedAt<time_t>(EventTimeUnit time_unit) const {\n  return static_cast<time_t>(GetStartedAt<double>(time_unit));\n}\n\ntemplate<>\nconst inline double IEvent::GetFinishedAt<double>(EventTimeUnit time_unit) const {\n  return ConvertTime(finished_at_, time_unit_, time_unit);\n}\n\ntemplate<>\nconst inline time_t IEvent::GetFinishedAt<time_t>(EventTimeUnit time_unit) const {\n  return static_cast<time_t>(GetFinishedAt<double>(time_unit));\n}\n\ntemplate<>\nconst inline double IEvent::GetDuration<double>(EventTimeUnit time_unit) const {\n  return GetFinishedAt<double>(time_unit) - GetStartedAt<double>(time_unit);\n}\n\ntemplate<>\nconst inline time_t IEvent::GetDuration<time_t>(EventTimeUnit time_unit) const {\n  return static_cast<time_t>(GetDuration<double>(time_unit));\n}\n\nclass CustomEvent final : public 
IEvent {\n public:\n  friend class ProfileManager;\n\n  nlohmann::json ToJson() override;\n\n  static std::shared_ptr<CustomEvent> Create(const std::string& name,\n                                             CustomEventType type = CustomEventType::kDefault);\n\n private:\n  CustomEventType type_;\n  CustomEvent(const std::string& custom_name, CustomEventType type)\n      : IEvent(custom_name,\n               type == CustomEventType::kDefault ? EventTimeUnit::kNS : EventTimeUnit::kUS),\n        type_(type) {}\n};\n\nclass KernelEvent final : public IEvent {\n public:\n  using Description = std::map<std::string, std::pair<std::string, int64_t>>;\n\n  nlohmann::json ToJson() override;\n\n  static std::shared_ptr<KernelEvent> Create(const std::string& name,\n                                             const Description& description);\n\n#if defined(WITH_CUDA)\n  void SetMemorySize(int64_t memory_size) { memory_size_ = memory_size; }\n  void AddChildEvent(const std::shared_ptr<IEvent>& e) { children_.emplace(e); }\n  bool AddChildEventIfSo(const std::shared_ptr<IEvent>& e) {\n    if (e->IsChildOf(dynamic_cast<IEvent*>(this))) {\n      children_.emplace(e);\n      return true;\n    }\n    return false;\n  }\n  bool HasChildEvent(const std::shared_ptr<IEvent>& e) { return children_.count(e); }\n  void WalkAmongChildren(const std::function<void(const std::shared_ptr<IEvent>& e)>& f) const {\n    for (const auto& x : children_) { f(x); }\n  }\n#endif  // WITH_CUDA\n\n private:\n  KernelEvent(const std::string& kernel_name, const Description& description)\n      : IEvent(kernel_name, EventTimeUnit::kNS), description_(description) {}\n\n#if defined(WITH_CUDA)\n  int64_t memory_size_ = -1;\n  std::set<std::shared_ptr<IEvent>> children_;\n#endif  // WITH_CUDA\n\n  const Description description_;\n};\n\n}  // namespace profiler\n}  // namespace oneflow\n\nnamespace nlohmann {\n\ninline void to_json(json& j, const std::shared_ptr<::oneflow::profiler::IEvent>& event) {\n  j = event->ToJson();\n}\n\n}  // namespace nlohmann\n\n#endif  // ONEFLOW_CORE_PROFILER_EVENT_H_\n"
  },
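A hedged sketch of timing a region with the event API above. `GetTimeNow()` used inside Start()/Finish() comes from oneflow/core/profiler/util.h; the work being measured here is a placeholder.

```cpp
// Sketch only: CustomEvent timing and JSON serialization.
#include "oneflow/core/profiler/event.h"

namespace oneflow {

void TimeRegionExample() {
  auto event = profiler::CustomEvent::Create("my_region");  // CustomEventType::kDefault
  event->Start();
  // ... code being measured ...
  event->Finish();
  // Timestamps are stored in the event's native unit (ns for kDefault) and
  // converted on read; the default query unit is microseconds (kUS).
  LOG(INFO) << event->GetName() << " took " << event->GetDuration<double>() << " us";
  LOG(INFO) << event->ToJson().dump();  // e.g. {"custom_type":...,"name":"my_region","time":...}
}

}  // namespace oneflow
```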
  {
    "path": "oneflow/core/profiler/event_recorder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/profiler/event_recorder.h\"\n#include \"oneflow/core/profiler/profile_manager.h\"\n#include \"oneflow/core/common/shape_view.h\"\n\nnamespace oneflow {\nnamespace profiler {\n\nMaybe<void> EventRecorder::RegisterEventToProfileManager(const std::shared_ptr<IEvent>& event) {\n  auto* pmgr = JUST(SingletonMaybe<ProfileManager>());\n  pmgr->events_.push(event_);\n  return Maybe<void>::Ok();\n}\n\nstd::shared_ptr<EventRecorder> EventRecorder::CreateCustomEventRecorder(const std::string& name) {\n  return std::make_shared<EventRecorder>(CustomEvent::Create(name));\n}\n\nMaybe<EventRecorder> EventRecorder::CreateKernelEventRecorder(\n    const std::string& name,\n#if defined(WITH_CUDA)\n    const std::function<int64_t()>& memory_size_getter,\n#endif\n    const DescriptionGetter& input_shapes_getter, const DescriptionGetter& attrs_getter) {\n  auto pmgr = Singleton<ProfileManager>::Get();\n  if (pmgr) {\n    const auto description_getter = [pmgr, input_shapes_getter, attrs_getter]() {\n      KernelEvent::Description desc;\n      if (pmgr->record_shapes_) { desc[\"input_shapes\"] = input_shapes_getter(); }\n      if (pmgr->record_attrs_) { desc[\"attrs\"] = attrs_getter(); }\n      return desc;\n    };\n#if defined(WITH_CUDA)\n    if (pmgr->use_cpu_ || pmgr->use_cuda_) {\n      auto event = KernelEvent::Create(name, description_getter());\n      if (pmgr->use_cuda_) {\n        if (pmgr->record_bandwidth_) { event->SetMemorySize(memory_size_getter()); }\n      }\n      return std::make_shared<EventRecorder>(event);\n    }\n#else\n    if (pmgr->use_cpu_) {\n      return std::make_shared<EventRecorder>(KernelEvent::Create(name, description_getter()));\n    }\n#endif  // WITH_CUDA\n  }\n\n  std::shared_ptr<EventRecorder> null_recorder;\n  return null_recorder;\n}\n\n}  // namespace profiler\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/profiler/event_recorder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_\n#define ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/profiler/event.h\"\n\nnamespace oneflow {\nnamespace profiler {\n\nclass EventRecorder {\n public:\n  using DescriptionGetter = std::function<std::pair<std::string, int64_t>()>;\n\n  OF_DISALLOW_COPY_AND_MOVE(EventRecorder);\n\n  explicit EventRecorder(const std::shared_ptr<IEvent>& event) : event_(event) {\n    CHECK_JUST(RegisterEventToProfileManager(event));\n    event_->Start();\n  }\n\n  Maybe<void> RegisterEventToProfileManager(const std::shared_ptr<IEvent>& event);\n\n  ~EventRecorder() {\n    if (event_) {\n      event_->Finish();\n      event_.reset();\n    }\n  }\n\n  static std::shared_ptr<EventRecorder> CreateCustomEventRecorder(const std::string& name);\n\n  static Maybe<EventRecorder> CreateKernelEventRecorder(\n      const std::string& name,\n#if defined(WITH_CUDA)\n      const std::function<int64_t()>& memory_size_getter,\n#endif\n      const DescriptionGetter& input_shapes_getter, const DescriptionGetter& attrs_getter);\n\n private:\n  std::shared_ptr<IEvent> event_;\n};\n\n}  // namespace profiler\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_\n"
  },
  {
    "path": "oneflow/core/profiler/kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/profiler/kernel.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/lazy/actor/actor_context.h\"\n\nnamespace oneflow {\n\nnamespace profiler {\n\nnamespace {\n\nbool profile_cuda_memory_bandwidth = false;\nbool profile_kernel_forward_range = false;\n\nvoid Init() {\n  profile_cuda_memory_bandwidth =\n      ParseBooleanFromEnv(\"ONEFLOW_PROFILER_KERNEL_PROFILE_CUDA_MEMORY_BANDWIDTH\", false);\n  profile_kernel_forward_range =\n      ParseBooleanFromEnv(\"ONEFLOW_PROFILER_KERNEL_PROFILE_KERNEL_FORWARD_RANGE\", false);\n}\n\nCOMMAND(Init());\n\n#if defined(WITH_CUDA)\nthread_local cudaEvent_t cuda_memory_bandwidth_profile_start_event = nullptr;\nthread_local cudaEvent_t cuda_memory_bandwidth_profile_end_event = nullptr;\n#endif  // WITH_CUDA\n\n}  // namespace\n\nvoid TraceKernelForwardDataContentStart(KernelContext* kernel_ctx, const Kernel* kernel) {\n#if defined(WITH_CUDA)\n  if (profile_cuda_memory_bandwidth) {\n    auto* actor_context_provider = dynamic_cast<ActorContextProvider*>(kernel_ctx);\n    auto* cuda_stream = dynamic_cast<ep::CudaStream*>(kernel_ctx->stream());\n    if (cuda_stream != nullptr && actor_context_provider != nullptr) {\n      CHECK(cuda_memory_bandwidth_profile_start_event == nullptr);\n      CHECK(cuda_memory_bandwidth_profile_end_event == nullptr);\n      OF_CUDA_CHECK(cudaEventCreate(&cuda_memory_bandwidth_profile_start_event));\n      OF_CUDA_CHECK(cudaEventCreate(&cuda_memory_bandwidth_profile_end_event));\n      OF_CUDA_CHECK(\n          cudaEventRecord(cuda_memory_bandwidth_profile_start_event, cuda_stream->cuda_stream()));\n    }\n  }\n  if (profile_kernel_forward_range) { OF_PROFILER_RANGE_PUSH(kernel->op_conf().name()); }\n#endif  // WITH_CUDA\n}\n\nvoid TraceKernelForwardDataContentEnd(KernelContext* kernel_ctx, const Kernel* kernel) {\n#if defined(WITH_CUDA)\n  if (profile_kernel_forward_range) { OF_PROFILER_RANGE_POP(); }\n  // The memory bandwidth profiler only works in lazy mode.\n  if (profile_cuda_memory_bandwidth) {\n    auto* cuda_stream = dynamic_cast<ep::CudaStream*>(kernel_ctx->stream());\n    auto* actor_context_provider = dynamic_cast<ActorContextProvider*>(kernel_ctx);\n    if (cuda_stream != nullptr && actor_context_provider != nullptr) {\n      cudaEvent_t start_event = cuda_memory_bandwidth_profile_start_event;\n      cudaEvent_t end_event = cuda_memory_bandwidth_profile_end_event;\n      cuda_memory_bandwidth_profile_start_event = nullptr;\n      cuda_memory_bandwidth_profile_end_event = nullptr;\n      CHECK_NOTNULL(start_event);\n      CHECK_NOTNULL(end_event);\n      OF_CUDA_CHECK(cudaEventRecord(end_event, cuda_stream->cuda_stream()));\n      int64_t memory_size = 0;\n      for (const auto& bn : kernel->op_attribute().input_bns()) {\n        const Blob* blob = 
kernel_ctx->BnInOp2Blob(bn);\n        if (blob) { memory_size += blob->ByteSizeOfBlobBody(); }\n      }\n      for (const auto& bn : kernel->op_attribute().output_bns()) {\n        const Blob* blob = kernel_ctx->BnInOp2Blob(bn);\n        if (blob) { memory_size += blob->ByteSizeOfBlobBody(); }\n      }\n      const std::string op_name = kernel->op_conf().name();\n      actor_context_provider->GetActorContext()->AddCallback(\n          [start_event, end_event, memory_size, op_name]() {\n            float elapsed_ms = 0;\n            OF_CUDA_CHECK(cudaEventElapsedTime(&elapsed_ms, start_event, end_event));\n            OF_CUDA_CHECK(cudaEventDestroy(start_event));\n            OF_CUDA_CHECK(cudaEventDestroy(end_event));\n            double bandwidth =\n                static_cast<double>(memory_size) / (1024.0 * 1024.0 * 1024.0) / (elapsed_ms / 1000);\n            LOG(INFO) << \"PROFILER::KERNEL::CUDA_MEMORY_BANDWIDTH op_name: \" << op_name\n                      << \" elapsed(ms): \" << elapsed_ms << \" memory_size(Byte): \" << memory_size\n                      << \" bandwidth(GB/s): \" << bandwidth;\n          });\n    }\n  }\n#endif  // WITH_CUDA\n}\n\n}  // namespace profiler\n\n}  // namespace oneflow\n"
  },
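  {
    "path": "oneflow/core/profiler/kernel_bandwidth_example.cpp",
    "content": "/*\nIllustrative sketch, NOT part of the original tree: this hypothetical file only\nreplays the bandwidth arithmetic used by the CUDA memory-bandwidth profiler in\noneflow/core/profiler/kernel.cpp (enabled with\nONEFLOW_PROFILER_KERNEL_PROFILE_CUDA_MEMORY_BANDWIDTH=1):\nbandwidth(GB/s) = memory_size(Byte) / 2^30 / (elapsed_ms / 1000).\n*/\n#include <cstdint>\n#include <iostream>\n\nint main() {\n  const int64_t memory_size = 256LL * 1024 * 1024;  // bytes read plus written by one kernel\n  const float elapsed_ms = 2.0f;                    // as measured by cudaEventElapsedTime\n  const double bandwidth =\n      static_cast<double>(memory_size) / (1024.0 * 1024.0 * 1024.0) / (elapsed_ms / 1000);\n  std::cout << \"bandwidth(GB/s): \" << bandwidth << std::endl;  // prints 125\n  return 0;\n}\n"
  },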
  {
    "path": "oneflow/core/profiler/kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PROFILER_KERNEL_H_\n#define ONEFLOW_CORE_PROFILER_KERNEL_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nclass Kernel;\nclass KernelContext;\nclass Blob;\n\nnamespace profiler {\n\nvoid TraceKernelForwardDataContentStart(KernelContext* kernel_ctx, const Kernel* kernel);\n\nvoid TraceKernelForwardDataContentEnd(KernelContext* kernel_ctx, const Kernel* kernel);\n\n}  // namespace profiler\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PROFILER_KERNEL_H_\n"
  },
  {
    "path": "oneflow/core/profiler/kineto_shim.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#if defined(WITH_CUDA)\n\n#include \"oneflow/core/profiler/kineto_shim.h\"\n#include \"libkineto.h\"\n\nnamespace oneflow {\n\nnamespace profiler {\nnamespace {\n\nconst std::set<libkineto::ActivityType> cpuTypes{\n    libkineto::ActivityType::CPU_OP,          libkineto::ActivityType::CPU_INSTANT_EVENT,\n    libkineto::ActivityType::USER_ANNOTATION, libkineto::ActivityType::EXTERNAL_CORRELATION,\n    libkineto::ActivityType::CUDA_RUNTIME,  // something like cudaLaunchKernel\n    libkineto::ActivityType::PYTHON_FUNCTION,\n};\n\nconst std::set<libkineto::ActivityType> cudaTypes = {\n    libkineto::ActivityType::GPU_MEMCPY, libkineto::ActivityType::GPU_MEMSET,\n    libkineto::ActivityType::CONCURRENT_KERNEL,  // cuda kernel\n    // CUDA_RUNTIME appears in both cpuTypes and cudaTypes.\n    libkineto::ActivityType::CUDA_RUNTIME,  // something like cudaLaunchKernel\n};\n}  // namespace\n\nActivityTraceWrapper::ActivityTraceWrapper(std::unique_ptr<interface_trace_t> trace)\n    : trace_(std::move(trace)), saved_{false} {}\n\nActivityTraceWrapper::operator bool() const { return trace_ != nullptr; }\n\nvoid ActivityTraceWrapper::save(const std::string& path) {\n  //   TORCH_CHECK(!saved_, \"Trace is already saved.\");\n  //   TORCH_CHECK(trace_ != nullptr, \"Missing trace.\")\n  trace_->save(path);\n  saved_ = true;\n}\n\nvoid PrepareTrace(const bool cpuOnly, const ActivitySet& activities) {\n  if (!libkineto::api().isProfilerRegistered()) {\n    libkineto_init(/*cpuOnly=*/cpuOnly, /*logOnError=*/true);\n    libkineto::api().suppressLogMessages();\n  }\n\n  if (!libkineto::api().isProfilerInitialized()) { libkineto::api().initProfilerIfRegistered(); }\n\n  std::set<libkineto::ActivityType> k_activities;\n  if (activities.count(ActivityType::CPU)) {\n    k_activities.insert(cpuTypes.begin(), cpuTypes.end());\n  }\n  if (activities.count(ActivityType::CUDA)) {\n    k_activities.insert(cudaTypes.begin(), cudaTypes.end());\n  }\n\n  libkineto::api().activityProfiler().prepareTrace(k_activities);\n}\n\nvoid StartTrace() { libkineto::api().activityProfiler().startTrace(); }\n\nActivityTraceWrapper StopTrace() {\n  return ActivityTraceWrapper{libkineto::api().activityProfiler().stopTrace()};\n}\n\n}  // namespace profiler\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/core/profiler/kineto_shim.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PROFILER_KINETO_SHIM_H_\n#define ONEFLOW_CORE_PROFILER_KINETO_SHIM_H_\n\n#if defined(WITH_CUDA)\n\n#include <string>\n#include <memory>\n#include <set>\n\nnamespace libkineto {\n\nenum class ActivityType;\nclass ActivityTraceInterface;\n\n}  // namespace libkineto\n\nnamespace oneflow {\n\nnamespace profiler {\n\nenum class ActivityType {\n  CPU = 0,\n  CUDA,\n};\n\nusing interface_trace_t = libkineto::ActivityTraceInterface;\n\nstruct ActivityTraceWrapper {\n  explicit ActivityTraceWrapper(std::unique_ptr<interface_trace_t> trace);\n  ActivityTraceWrapper() = default;\n  ActivityTraceWrapper(ActivityTraceWrapper&&) = default;\n  ActivityTraceWrapper(const ActivityTraceWrapper&) = delete;\n  explicit operator bool() const;\n  void save(const std::string& path);\n\n  const std::unique_ptr<interface_trace_t>& get() { return trace_; }\n\n private:\n  std::unique_ptr<interface_trace_t> trace_;\n  bool saved_ = false;  // Kineto's save is destructive\n};\n\nusing ActivitySet = std::set<ActivityType>;\nvoid PrepareTrace(const bool cpuOnly, const ActivitySet& activities);\nvoid StartTrace();\nActivityTraceWrapper StopTrace();\n\n}  // namespace profiler\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_CORE_PROFILER_KINETO_SHIM_H_\n"
  },
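  {
    "path": "oneflow/core/profiler/kineto_shim_example.cpp",
    "content": "/*\nHypothetical usage sketch, NOT part of the original tree: it drives the shim\nAPI declared in kineto_shim.h above. It assumes a CUDA build (WITH_CUDA), and\nRunProfiledWorkload() stands in for the code being traced; \"trace.json\" is an\narbitrary output path.\n*/\n#include \"oneflow/core/profiler/kineto_shim.h\"\n\nvoid RunProfiledWorkload();  // assumed: launches the kernels of interest\n\nvoid CollectKinetoTrace() {\n  using namespace oneflow::profiler;\n  const ActivitySet activities{ActivityType::CPU, ActivityType::CUDA};\n  PrepareTrace(/*cpuOnly=*/false, activities);\n  StartTrace();\n  RunProfiledWorkload();\n  ActivityTraceWrapper trace = StopTrace();\n  // Kineto's save is destructive, so the wrapper saves at most once.\n  if (trace) { trace.save(\"trace.json\"); }\n}\n"
  },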
  {
    "path": "oneflow/core/profiler/profile_manager.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <memory>\n#include <unordered_map>\n#include \"fmt/core.h\"\n#include \"nlohmann/json.hpp\"\n#include \"oneflow/core/profiler/kineto_shim.h\"\n#include \"oneflow/core/profiler/profile_manager.h\"\n#include \"oneflow/core/profiler/event.h\"\n#if defined(WITH_CUDA)\n#include <libkineto.h>\n#endif  // WITH_CUDA\n\nusing json = nlohmann::json;\n\nnamespace oneflow {\nnamespace profiler {\n\nstd::string ProfileManager::RegisterEventRecorder(\n    const std::shared_ptr<EventRecorder>& event_recorder, const std::string& name) {\n  std::string recorder_key = GetNextEventRecorderKey(name);\n  event_recorders_.emplace(recorder_key, event_recorder);\n  return recorder_key;\n}\n\nvoid ProfileManager::UnregisterEventRecorder(const std::string& event_recorder_key) {\n  if (event_recorders_.find(event_recorder_key) != event_recorders_.end()) {\n    event_recorders_.erase(event_recorder_key);\n  }\n}\n\nstd::string ProfileManager::DumpResultsJson() {\n  const json j = ExportEvents();\n  return j.dump();\n}\n\nstd::vector<std::shared_ptr<IEvent>> ProfileManager::ExportEvents() {\n#if defined(WITH_CUDA)\n  auto trace = StopTrace();\n  const auto& kineto_events = *(trace.get()->activities());\n  std::set<std::shared_ptr<IEvent>> custom_events;\n  std::unordered_map<std::shared_ptr<IEvent>, int64_t> corr_ids;\n\n  const std::vector<std::pair<libkineto::ActivityType, CustomEventType>> type_pairs = {\n      {libkineto::ActivityType::CUDA_RUNTIME, CustomEventType::kCudaRuntime},\n      {libkineto::ActivityType::CONCURRENT_KERNEL, CustomEventType::kCudaKernel}};\n\n  for (const auto& evt_ptr : kineto_events) {\n    if (evt_ptr == nullptr) { continue; }\n    const auto& activity = *evt_ptr;\n    for (auto& pair : type_pairs) {\n      if (activity.type() == pair.first) {\n        auto custom_event = CustomEvent::Create(activity.name(), pair.second);\n        custom_event->SetStartedAt(static_cast<time_t>(activity.timestamp()));\n        custom_event->SetFinishedAt(static_cast<time_t>(activity.timestamp())\n                                    + activity.duration());\n        custom_events.emplace(custom_event);\n        corr_ids[custom_event] = activity.correlationId();\n      }\n    }\n  }\n#endif  // WITH_CUDA\n  std::vector<std::shared_ptr<IEvent>> events;\n  while (!events_.empty()) {\n    auto evt = events_.front();\n    events_.pop();\n#if defined(WITH_CUDA)\n    auto evt_kernel = std::dynamic_pointer_cast<KernelEvent>(evt);\n    if (evt_kernel) {\n      std::set<int64_t> current_corr_ids;\n      if (!custom_events.empty()) {\n        for (const auto& x : custom_events) {\n          if (evt_kernel->AddChildEventIfSo(x)) { current_corr_ids.insert(corr_ids[x]); }\n        }\n        for (const auto& x : custom_events) {\n          if (!evt_kernel->HasChildEvent(x) && current_corr_ids.count(corr_ids[x])) {\n            evt_kernel->AddChildEvent(x);\n          }\n        }\n        
evt_kernel->WalkAmongChildren(\n            [&custom_events](const std::shared_ptr<IEvent>& child) { custom_events.erase(child); });\n      }\n    }\n#endif  // WITH_CUDA\n    events.emplace_back(evt);\n  }\n  return events;\n}\n\nstd::string ProfileManager::GetNextEventRecorderKey(const std::string& name) {\n  if (event_recorders_last_id_.find(name) == event_recorders_last_id_.end()) {\n    event_recorders_last_id_[name] = 0;\n  } else {\n    event_recorders_last_id_[name]++;\n  }\n  return fmt::format(\"{}.{}\", name, event_recorders_last_id_[name]);\n}\n\n}  // namespace profiler\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/core/profiler/profile_manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PROFILER_PROFILE_MANAGER_H_\n#define ONEFLOW_CORE_PROFILER_PROFILE_MANAGER_H_\n\n#include <memory>\n#include <queue>\n#include <set>\n#include <unordered_map>\n#include \"oneflow/core/profiler/kineto_shim.h\"\n\nnamespace oneflow {\nnamespace profiler {\n\nclass IEvent;\nclass EventRecorder;\n\nclass ProfileManager {\n public:\n  friend class EventRecorder;\n\n  ProfileManager(bool use_cpu, bool use_cuda, bool record_shapes, bool record_attrs,\n                 bool record_bandwidth)\n      : use_cpu_(use_cpu),\n        use_cuda_(use_cuda),\n        record_shapes_(record_shapes),\n        record_attrs_(record_attrs),\n        record_bandwidth_(record_bandwidth) {\n#if defined(WITH_CUDA)\n    std::set<ActivityType> activities{};\n    if (use_cpu) { activities.insert(ActivityType::CPU); }\n    if (use_cuda) { activities.insert(ActivityType::CUDA); }\n    PrepareTrace(/*cpuOnly*/ false, activities);\n    StartTrace();\n#endif  // WITH_CUDA\n  }\n\n  std::string RegisterEventRecorder(const std::shared_ptr<EventRecorder>& event_recorder,\n                                    const std::string& name);\n  void UnregisterEventRecorder(const std::string& event_recorder_key);\n  std::string DumpResultsJson();\n\n private:\n  bool use_cpu_;\n  bool use_cuda_;\n  bool record_shapes_;\n  bool record_attrs_;\n  bool record_bandwidth_;\n\n  std::queue<std::shared_ptr<IEvent>> events_;\n  std::unordered_map<std::string, std::shared_ptr<EventRecorder>> event_recorders_;\n  // To prevent releasing EventRecorders of the same name.\n  std::unordered_map<std::string, int64_t> event_recorders_last_id_;\n\n  std::string GetNextEventRecorderKey(const std::string& name);\n  std::vector<std::shared_ptr<IEvent>> ExportEvents();\n};\n\n}  // namespace profiler\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PROFILER_PROFILE_MANAGER_H_\n"
  },
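  {
    "path": "oneflow/core/profiler/recorder_key_sketch.cpp",
    "content": "/*\nStandalone sketch, NOT part of the original tree: it mirrors the key scheme of\nProfileManager::GetNextEventRecorderKey above, where recorders registered under\nthe same name get distinct keys \"<name>.0\", \"<name>.1\", ... The map and helper\nare local stand-ins, and std::to_string replaces fmt::format.\n*/\n#include <cassert>\n#include <cstdint>\n#include <string>\n#include <unordered_map>\n\nstd::unordered_map<std::string, int64_t> last_id;\n\nstd::string NextKey(const std::string& name) {\n  auto it = last_id.find(name);\n  if (it == last_id.end()) {\n    last_id[name] = 0;  // first recorder of this name\n  } else {\n    ++it->second;  // later recorders bump the per-name counter\n  }\n  return name + \".\" + std::to_string(last_id[name]);\n}\n\nint main() {\n  assert(NextKey(\"matmul\") == \"matmul.0\");\n  assert(NextKey(\"matmul\") == \"matmul.1\");\n  assert(NextKey(\"relu\") == \"relu.0\");\n  return 0;\n}\n"
  },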
  {
    "path": "oneflow/core/profiler/profiler.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"oneflow/core/profiler/profile_manager.h\"\n#include \"oneflow/core/profiler/kineto_shim.h\"\n#include \"oneflow/core/profiler/event_recorder.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#ifdef WITH_CUDA\n#include \"oneflow/core/device/cuda_util.h\"\n#include <nvtx3/nvToolsExt.h>\n#include <sys/syscall.h>\n#include <iostream>\n#include <cuda_profiler_api.h>\n#endif  // WITH_CUDA\n\nnamespace oneflow {\n\nnamespace profiler {\n\nvoid NameThisHostThread(const std::string& name) {\n#ifdef WITH_CUDA\n  static thread_local std::unique_ptr<std::string> thread_name_prefix;\n  if (!thread_name_prefix) {\n    thread_name_prefix.reset(\n        new std::string(GetStringFromEnv(\"ONEFLOW_PROFILER_HOST_THREAD_NAME_PREFIX\", \"\")));\n  }\n  const std::string name_with_prefix = *thread_name_prefix + name;\n  nvtxNameOsThreadA(syscall(SYS_gettid), name_with_prefix.c_str());\n#endif  // WITH_CUDA\n}\n\nvoid RangePush(const std::string& name) {\n#ifdef OF_ENABLE_PROFILER\n  nvtxRangePushA(name.c_str());\n#endif  // OF_ENABLE_PROFILER\n}\n\nvoid RangePop() {\n#ifdef OF_ENABLE_PROFILER\n  nvtxRangePop();\n#endif  // OF_ENABLE_PROFILER\n}\n\nRangeGuard::RangeGuard(const std::string& name) {\n#ifdef OF_ENABLE_PROFILER\n  RangePush(name);\n#endif  // OF_ENABLE_PROFILER\n}\n\nRangeGuard::~RangeGuard() {\n#ifdef OF_ENABLE_PROFILER\n  RangePop();\n#endif  // OF_ENABLE_PROFILER\n}\n\nvoid LogHostMemoryUsage(const std::string& name) {\n#ifdef OF_ENABLE_PROFILER\n  int64_t vm_pages;\n  int64_t rss_pages;\n  std::ifstream ifs(\"/proc/self/statm\");\n  ifs >> vm_pages >> rss_pages;\n  ifs.close();\n  const int64_t page_size = sysconf(_SC_PAGE_SIZE);\n  LOG(INFO) << \"HostMemoryUsage: \" << name << \" VM \" << vm_pages * page_size << \" RSS \"\n            << rss_pages * page_size;\n#endif  // OF_ENABLE_PROFILER\n}\n\nvoid ProfilerStart() {\n#ifdef OF_ENABLE_PROFILER\n  OF_CUDA_CHECK(cudaProfilerStart());\n#endif  // OF_ENABLE_PROFILER\n}\n\nvoid ProfilerStop() {\n#ifdef OF_ENABLE_PROFILER\n  OF_CUDA_CHECK(cudaProfilerStop());\n#endif  // OF_ENABLE_PROFILER\n}\n\nvoid EnableProfiler(bool use_cpu, bool use_cuda, bool record_shapes, bool record_attrs,\n                    bool record_bandwidth) {\n  CHECK_JUST(vm::ClusterSync());\n  if (Singleton<ProfileManager>::Get() == nullptr) {\n    Singleton<ProfileManager>::New(use_cpu, use_cuda, record_shapes, record_attrs,\n                                   record_bandwidth);\n  }\n}\n\n// DisableProfilerAndReturnResult will return a json of profile results.\nMaybe<std::string> DisableProfilerAndReturnResult() {\n  JUST(vm::ClusterSync());\n#if defined(WITH_CUDA)\n  OF_CUDA_CHECK(cudaDeviceSynchronize());\n#endif  // WITH_CUDA\n  auto* pmgr = JUST(SingletonMaybe<ProfileManager>());\n  std::string results = pmgr->DumpResultsJson();\n  Singleton<ProfileManager>::Delete();\n  return 
results;\n}\n\nMaybe<std::string> StartRecord(const std::string& name) {\n  auto* pmgr = JUST(SingletonMaybe<ProfileManager>());\n  JUST(vm::ClusterSync());\n  return pmgr->RegisterEventRecorder(profiler::EventRecorder::CreateCustomEventRecorder(name),\n                                     name);\n}\n\nMaybe<void> EndRecord(const std::string& event_recorder_key) {\n  auto* pmgr = JUST(SingletonMaybe<ProfileManager>());\n  JUST(vm::ClusterSync());\n  pmgr->UnregisterEventRecorder(event_recorder_key);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace profiler\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/profiler/profiler.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_PROFILER_PROFILER_H_\n#define ONEFLOW_CORE_PROFILER_PROFILER_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace profiler {\n\nvoid NameThisHostThread(const std::string& name);\n\nvoid RangePush(const std::string& name);\n\nvoid RangePop();\n\nvoid LogHostMemoryUsage(const std::string& name);\n\nvoid ProfilerStart();\n\nvoid ProfilerStop();\n\nclass RangeGuardCtx;\n\nclass RangeGuard final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RangeGuard);\n  explicit RangeGuard(const std::string& name);\n  ~RangeGuard();\n\n private:\n  std::shared_ptr<RangeGuardCtx> ctx_;\n};\n\n#define OF_PROFILER_NAME_THIS_HOST_THREAD(name) ::oneflow::profiler::NameThisHostThread(name)\n\n#ifdef OF_ENABLE_PROFILER\n#define OF_PROFILER_ONLY_CODE(...) __VA_ARGS__\n#define OF_PROFILER_RANGE_PUSH(name) ::oneflow::profiler::RangePush(name)\n#define OF_PROFILER_RANGE_POP() ::oneflow::profiler::RangePop()\n#define OF_PROFILER_RANGE_GUARD(name) \\\n  ::oneflow::profiler::RangeGuard OF_PP_CAT(_of_profiler_range_guard_, __COUNTER__)(name)\n#define OF_PROFILER_LOG_HOST_MEMORY_USAGE(name) ::oneflow::profiler::LogHostMemoryUsage(name)\n#else\n#define OF_PROFILER_ONLY_CODE(...)\n#define OF_PROFILER_RANGE_PUSH(name)\n#define OF_PROFILER_RANGE_POP()\n#define OF_PROFILER_RANGE_GUARD(name)\n#define OF_PROFILER_LOG_HOST_MEMORY_USAGE(name)\n#endif\n\nvoid EnableProfiler(bool use_cpu, bool use_cuda, bool record_shapes, bool record_attrs,\n                    bool record_bandwidth);\n\n// DisableProfilerAndReturnResult will return a json of profile results.\nMaybe<std::string> DisableProfilerAndReturnResult();\n\nMaybe<std::string> StartRecord(const std::string& name);\n\nMaybe<void> EndRecord(const std::string& event_recorder_key);\n\n}  // namespace profiler\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PROFILER_PROFILER_H_\n"
  },
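  {
    "path": "oneflow/core/profiler/profiler_api_example.cpp",
    "content": "/*\nHypothetical end-to-end sketch, NOT part of the original tree: it chains the\nentry points declared in profiler.h above. Workload() is a placeholder for the\nregion being profiled, CHECK_JUST is OneFlow's abort-on-error unwrapping of\nMaybe<T>, and the NVTX range macros expand to nothing unless the build defines\nOF_ENABLE_PROFILER.\n*/\n#include \"oneflow/core/profiler/profiler.h\"\n\nvoid Workload();  // assumed: the region being profiled\n\nvoid ProfileOnce() {\n  using namespace oneflow;\n  profiler::EnableProfiler(/*use_cpu=*/true, /*use_cuda=*/true, /*record_shapes=*/false,\n                           /*record_attrs=*/false, /*record_bandwidth=*/false);\n  const std::string key = CHECK_JUST(profiler::StartRecord(\"my_region\"));\n  {\n    OF_PROFILER_RANGE_GUARD(\"my_region\");  // NVTX range push/pop around the workload\n    Workload();\n  }\n  CHECK_JUST(profiler::EndRecord(key));\n  // A json string describing every recorded event.\n  const std::string results = CHECK_JUST(profiler::DisableProfilerAndReturnResult());\n  (void)results;\n}\n"
  },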
  {
    "path": "oneflow/core/profiler/util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_PROFILER_UTIL_H_\n#define ONEFLOW_CORE_PROFILER_UTIL_H_\n\n#include <cstdint>\n#include <time.h>\n\nnamespace oneflow {\n\nnamespace profiler {\n\nusing time_t = int64_t;\n\ninline time_t GetTimeNow(bool allow_monotonic = false) {\n  struct timespec t {};\n  auto mode = CLOCK_REALTIME;\n  if (allow_monotonic) { mode = CLOCK_MONOTONIC; }\n  clock_gettime(mode, &t);\n  return static_cast<time_t>(t.tv_sec) * 1000000000 + static_cast<time_t>(t.tv_nsec);\n}\n\n}  // namespace profiler\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_PROFILER_UTIL_H_\n"
  },
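  {
    "path": "oneflow/core/profiler/util_example.cpp",
    "content": "/*\nHypothetical sketch, NOT part of the original tree: it times a host-side loop\nwith the nanosecond clock from util.h above; the loop is only a placeholder\nworkload.\n*/\n#include <iostream>\n#include \"oneflow/core/profiler/util.h\"\n\nint main() {\n  const auto start = oneflow::profiler::GetTimeNow(/*allow_monotonic=*/true);\n  volatile double acc = 0;\n  for (int i = 0; i < 1000000; ++i) { acc = acc + i * 0.5; }  // placeholder workload\n  const auto end = oneflow::profiler::GetTimeNow(/*allow_monotonic=*/true);\n  // GetTimeNow returns nanoseconds, so divide by 1e6 for milliseconds.\n  std::cout << \"elapsed(ms): \" << static_cast<double>(end - start) / 1e6 << std::endl;\n  return 0;\n}\n"
  },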
  {
    "path": "oneflow/core/record/coco.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/record/record.proto\";\n\nmessage PolygonList {\n    repeated FloatList polygons = 1;\n}\n"
  },
  {
    "path": "oneflow/core/record/record.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage BytesList {\n  repeated bytes value = 1;\n}\n\nmessage FloatList {\n  repeated float value = 1 [packed = true];\n}\n\nmessage DoubleList {\n  repeated double value = 1 [packed = true];\n}\n\nmessage Int32List {\n  repeated int32 value = 1 [packed = true];\n}\n\nmessage Int64List {\n  repeated int64 value = 1 [packed = true];\n}\n\nmessage Feature {\n  oneof kind {\n    BytesList bytes_list = 1;\n    FloatList float_list = 2;\n    DoubleList double_list = 3;\n    Int32List int32_list = 4;\n    Int64List int64_list = 5;\n  }\n}\n\nmessage OFRecord {\n  map<string, Feature> feature = 1;\n}\n"
  },
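  {
    "path": "oneflow/core/record/record_example.cpp",
    "content": "/*\nHypothetical sketch, NOT part of the original tree: it fills an OFRecord via\nthe protobuf classes generated from record.proto above. The feature names and\nvalues are made up, and the generated header path assumes protoc's default\nlayout.\n*/\n#include <iostream>\n#include \"oneflow/core/record/record.pb.h\"\n\nint main() {\n  oneflow::OFRecord record;\n  // Each map entry holds one Feature; its oneof `kind` is selected by the\n  // first mutable_*_list() call.\n  oneflow::Feature& label = (*record.mutable_feature())[\"label\"];\n  label.mutable_int32_list()->add_value(7);\n  oneflow::Feature& pixels = (*record.mutable_feature())[\"pixels\"];\n  for (int i = 0; i < 4; ++i) { pixels.mutable_float_list()->add_value(0.25f * i); }\n  std::cout << record.DebugString();\n  return 0;\n}\n"
  },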
  {
    "path": "oneflow/core/register/blob.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/register/blob.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\nBlob::Blob(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr) {\n  Init(mem_case, blob_desc, header_ptr, nullptr, 0);\n}\n\nBlob::Blob(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr,\n           char* body_ptr) {\n  Init(mem_case, blob_desc, header_ptr, body_ptr, 0);\n}\n\nBlob::Blob(const MemoryCase& mem_case,  // NOLINT，Blob::Blob(...) { // NOLINT\n           const BlobDesc* blob_desc, char* header_ptr, char* body_ptr, const int64_t offset) {\n  Init(mem_case, blob_desc, header_ptr, body_ptr, offset);\n}\n\nvoid Blob::Init(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr,\n                char* body_ptr, const int64_t offset) {\n  mem_case_ = mem_case;\n  blob_desc_ = blob_desc;\n  storage_offset_ = offset;\n  dptr_ = body_ptr;\n  header_ptr_ = header_ptr;\n  this->blob_access_checker_ = Singleton<BlobAccessCheckerIf<true, true>>::Get();\n  int64_t* shape_ptr = reinterpret_cast<int64_t*>(header_ptr);\n  shape_view_.reset(new ShapeView(shape_ptr, static_shape().NumAxes()));\n  if (blob_desc->is_dynamic()) {\n    mut_shape_view_.reset(new MutShapeView(shape_ptr, static_shape().NumAxes()));\n  }\n  MutShapeView(shape_ptr, static_shape().NumAxes()).set_shape(static_shape());\n}\n\nvoid Blob::CopyHeaderFrom(const Blob* rhs) {\n  size_t header_size = blob_desc().ByteSizeOfBlobHeader();\n  CHECK_EQ(header_size, rhs->blob_desc().ByteSizeOfBlobHeader());\n  if (this == rhs || header_size == 0) { return; }\n  std::memcpy(header_ptr_, rhs->header_ptr(), header_size);\n}\n\nchar* Blob::mut_contiguous_header_ptr() {\n  // check header and body is continuous\n  CHECK_EQ(header_ptr() + blob_desc_->AlignedByteSizeOfBlobHeader(), dptr<char>());\n  return header_ptr_;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/register/blob.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_REGISTER_BLOB_H_\n#define ONEFLOW_CORE_REGISTER_BLOB_H_\n\n#include \"oneflow/core/job/resource.pb.h\"\n#include \"oneflow/core/memory/memory_case.pb.h\"\n#include \"oneflow/core/register/blob_desc.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/common/symbol.h\"\n\nnamespace oneflow {\n\nclass BlobAccessChecker {\n public:\n  virtual void CheckHeaderMutable() const = 0;\n  virtual void CheckBodyMutable() const = 0;\n};\n\ntemplate<bool is_header_mutable, bool is_body_mutable>\nclass BlobAccessCheckerIf final : public BlobAccessChecker {\n public:\n  void CheckHeaderMutable() const override {\n    CHECK(is_header_mutable)\n        << \"header mutable check not passed, blob's shape is not mutable at this moment!\";\n  }\n\n  void CheckBodyMutable() const override {\n    CHECK(is_body_mutable)\n        << \"body mutable check not passed, blob's data is not mutable at this moment!\";\n  }\n};\n\nclass Blob final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Blob);\n  Blob(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr);\n  Blob(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr, char* body_ptr);\n  Blob(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr, char* body_ptr,\n       const int64_t offset);\n  virtual ~Blob() = default;\n\n  DataType data_type() const { return blob_desc_->data_type(); }\n  MemoryFormat memory_format() const { return blob_desc_->memory_format(); }\n  const char* header_ptr() const { return header_ptr_; }\n  [[deprecated(\n      \"\\\"mut_header_ptr\\\" will be removed in Bolb. Please avoid to use this method whenever \"\n      \"possible. 
Almost all methods of `mut_header_ptr` are also in `Blob`.\")]] char*\n  mut_header_ptr() {\n    return header_ptr_;\n  }\n  char* mut_contiguous_header_ptr();\n  const BlobDesc& blob_desc() const { return *blob_desc_; }\n  const BlobDesc* blob_desc_ptr() const { return blob_desc_; }\n\n  template<typename T = void>\n  const T* dptr() const {\n    CheckDataType<T>(data_type());\n    return reinterpret_cast<T*>(static_cast<char*>(dptr_)\n                                + storage_offset_ * GetSizeOfDataType(data_type()));\n  }\n  template<typename T = void>\n  T* mut_dptr() {\n    this->blob_access_checker()->CheckBodyMutable();\n    CheckDataType<T>(data_type());\n    return reinterpret_cast<T*>(static_cast<char*>(dptr_)\n                                + storage_offset_ * GetSizeOfDataType(data_type()));\n  }\n  template<typename T = void>\n  T* ForceMutDptr() {\n    CheckDataType<T>(data_type());\n    return reinterpret_cast<T*>(static_cast<char*>(dptr_)\n                                + storage_offset_ * GetSizeOfDataType(data_type()));\n  }\n  template<typename T = void>\n  const T* raw_dptr() const {\n    CheckDataType<T>(data_type());\n    return static_cast<T*>(dptr_);\n  }\n  template<typename T = void>\n  T* mut_raw_dptr() {\n    this->blob_access_checker()->CheckBodyMutable();\n    CheckDataType<T>(data_type());\n    return static_cast<T*>(dptr_);\n  }\n\n  // shape\n  const Shape& static_shape() const { return blob_desc_->shape(); }\n  const ShapeView& shape_view() const { return *shape_view_; }\n  const ShapeView& shape() const { return *shape_view_; }\n  MutShapeView* mut_shape_view() {\n    this->blob_access_checker()->CheckHeaderMutable();\n    return mut_shape_view_.get();\n  }\n  MutShapeView* ForceMutShapeView() { return mut_shape_view_.get(); }\n\n  // stride\n  const Stride& stride() const { return blob_desc_->stride(); }\n\n  void reset_dptr(char* dptr) { dptr_ = dptr; }\n\n  void CopyHeaderFrom(const Blob* rhs);\n  bool IsBodyEmpty() const { return shape().elem_cnt() == 0; }\n\n  size_t AlignedTotalByteSize() const { return blob_desc_->AlignedTotalByteSize(); }\n  const MemoryCase& mem_case() const { return mem_case_; }\n\n  size_t ByteSizeOfBlobBody() const { return blob_desc_->ByteSizeOfBlobBody(); }\n  size_t AlignedByteSizeOfBlobBody() const { return blob_desc_->AlignedByteSizeOfBlobBody(); }\n\n  void set_blob_access_checker(const BlobAccessChecker* blob_access_checker) {\n    this->blob_access_checker_ = blob_access_checker;\n  }\n\n  const BlobAccessChecker* blob_access_checker() { return this->blob_access_checker_; }\n\n private:\n  void Init(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* header_ptr, char* body_ptr,\n            const int64_t offset);\n\n  const BlobAccessChecker* blob_access_checker_;\n  MemoryCase mem_case_;\n  const BlobDesc* blob_desc_;\n  void* dptr_;\n  char* header_ptr_;\n  int64_t storage_offset_;\n  std::unique_ptr<ShapeView> shape_view_;\n  std::unique_ptr<MutShapeView> mut_shape_view_;\n};\n\n#define INIT_GLOBAL_BLOB_MUTABLE_CHECKER(is_header_mutable, is_body_mutable)                \\\n  COMMAND(Singleton<BlobAccessCheckerIf<is_header_mutable, is_body_mutable>>::SetAllocated( \\\n      new BlobAccessCheckerIf<is_header_mutable, is_body_mutable>()))\n\nINIT_GLOBAL_BLOB_MUTABLE_CHECKER(false, false);\nINIT_GLOBAL_BLOB_MUTABLE_CHECKER(false, true);\nINIT_GLOBAL_BLOB_MUTABLE_CHECKER(true, false);\nINIT_GLOBAL_BLOB_MUTABLE_CHECKER(true, true);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_REGISTER_BLOB_H_\n"
  },
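  {
    "path": "oneflow/core/register/blob_access_checker_sketch.cpp",
    "content": "/*\nStandalone sketch, NOT part of the original tree: it replays the\nBlobAccessChecker pattern from blob.h above, where one singleton checker per\nmutability combination is swapped onto a blob to decide whether mut_dptr() /\nmut_shape_view() style accessors may be called at the moment; assert stands in\nfor OneFlow's CHECK.\n*/\n#include <cassert>\n\nstruct AccessChecker {\n  virtual ~AccessChecker() = default;\n  virtual void CheckBodyMutable() const = 0;\n};\n\ntemplate<bool is_body_mutable>\nstruct AccessCheckerIf final : AccessChecker {\n  void CheckBodyMutable() const override { assert(is_body_mutable && \"body not mutable now\"); }\n};\n\nint main() {\n  static const AccessCheckerIf<true> writable;\n  static const AccessCheckerIf<false> frozen;\n  const AccessChecker* checker = &writable;\n  checker->CheckBodyMutable();  // passes: body writes currently allowed\n  checker = &frozen;\n  // checker->CheckBodyMutable();  // would abort: body is frozen at this moment\n  (void)checker;\n  return 0;\n}\n"
  },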
  {
    "path": "oneflow/core/register/blob_desc.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/register/blob_desc.h\"\n\nnamespace oneflow {\n\nbool CompareLbiBlobDescPair(const LbiBlobDescPair& lhs, const LbiBlobDescPair& rhs) {\n  return lhs.lbi() < rhs.lbi();\n}\n\nBlobDesc::BlobDesc(const Shape& shape, DataType dtype, MemoryFormat memory_format, bool is_dynamic)\n    : shape_(SymbolOf(shape)),\n      stride_(SymbolOf(Stride(shape))),\n      data_type_(dtype),\n      memory_format_(memory_format),\n      is_dynamic_(is_dynamic) {}\nBlobDesc::BlobDesc(const Shape& shape, const Stride& stride, DataType dtype,\n                   MemoryFormat memory_format, bool is_dynamic)\n    : shape_(SymbolOf(shape)),\n      stride_(SymbolOf(stride)),\n      data_type_(dtype),\n      memory_format_(memory_format),\n      is_dynamic_(is_dynamic) {}\nBlobDesc::BlobDesc(Symbol<Shape> shape, Symbol<Stride> stride, DataType dtype,\n                   MemoryFormat memory_format, bool is_dynamic)\n    : shape_(shape),\n      stride_(stride),\n      data_type_(dtype),\n      memory_format_(memory_format),\n      is_dynamic_(is_dynamic) {}\nBlobDesc::BlobDesc(const Shape& shape, DataType dtype, MemoryFormat memory_format)\n    : BlobDesc(shape, Stride(shape), dtype, memory_format, false) {}\nBlobDesc::BlobDesc(const Shape& shape, const Stride& stride, DataType dtype,\n                   MemoryFormat memory_format)\n    : BlobDesc(shape, stride, dtype, memory_format, false) {}\nBlobDesc::BlobDesc(Symbol<Shape> shape, Symbol<Stride> stride, DataType dtype,\n                   MemoryFormat memory_format)\n    : BlobDesc(shape, stride, dtype, memory_format, false) {}\nBlobDesc::BlobDesc(DataType dtype, MemoryFormat memory_format)\n    : BlobDesc(Shape(), Stride(), dtype, memory_format, false) {}\n\nBlobDesc::BlobDesc(const BlobDescProto& proto)\n    : shape_(SymbolOf(Shape(proto.shape()))),\n      stride_(SymbolOf(Stride(proto.stride()))),\n      data_type_(proto.data_type()),\n      memory_format_(proto.memory_format()),\n      is_dynamic_(proto.is_dynamic()) {}\n\nBlobDesc::BlobDesc(const BlobDesc& other)\n    : shape_(other.shape_),\n      stride_(other.stride_),\n      data_type_(other.data_type()),\n      memory_format_(other.memory_format()),\n      is_dynamic_(other.is_dynamic()) {}\n\nvoid BlobDesc::ToProto(BlobDescProto* proto) const {\n  shape().ToProto(proto->mutable_shape());\n  stride().ToProto(proto->mutable_stride());\n  proto->set_data_type(data_type_);\n  proto->set_memory_format(memory_format_);\n  proto->set_is_dynamic(is_dynamic_);\n}\n\nBlobDesc& BlobDesc::operator=(const BlobDesc& rhs) {\n  this->CopyFrom(rhs);\n  return *this;\n}\n\nvoid BlobDesc::CopyFrom(const BlobDesc& other) {\n  set_shape(other.shape());\n  set_stride(other.stride());\n  set_data_type(other.data_type());\n  set_memory_format(other.memory_format());\n  set_is_dynamic(other.is_dynamic());\n}\n\nvoid BlobDesc::set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; }\n\nbool 
BlobDesc::operator==(const BlobDesc& rhs) const {\n  return (shape() == rhs.shape()) && (stride() == rhs.stride()) && (data_type() == rhs.data_type())\n         && (memory_format() == rhs.memory_format()) && (is_dynamic() == rhs.is_dynamic());\n}\n\nsize_t BlobDesc::ByteSizeOfBlobHeader() const {\n  return shape().is_initialized() ? shape().NumAxes() * sizeof(int64_t) : 0;\n}\n\nsize_t BlobDesc::AlignedByteSizeOfBlobHeader() const {\n  return shape().is_initialized()\n             ? RoundUp(shape().NumAxes() * sizeof(int64_t), kBlobHeaderAlignSize)\n             : RoundUp(0, kBlobHeaderAlignSize);\n}\n\nsize_t BlobDesc::ByteSizeOfBlobBody() const {\n  return shape().is_initialized() ? shape().elem_cnt() * GetSizeOfDataType(data_type()) : 0;\n}\n\nsize_t BlobDesc::AlignedByteSizeOfBlobBody() const {\n  return RoundUp(ByteSizeOfBlobBody(), kBlobBodyAlignSize);\n}\n\nsize_t BlobDesc::AlignedTotalByteSize() const {\n  return AlignedByteSizeOfBlobHeader() + AlignedByteSizeOfBlobBody();\n}\n\n}  // namespace oneflow\n"
  },
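  {
    "path": "oneflow/core/register/blob_desc_size_sketch.cpp",
    "content": "/*\nHypothetical standalone check, NOT part of the original tree: it replays the\nsize formulas of blob_desc.cpp above for one concrete case, a dense float\ntensor of shape (8, 3, 224, 224). The alignment constants are illustrative\nplaceholders; the real kBlobHeaderAlignSize / kBlobBodyAlignSize are defined\nelsewhere in OneFlow.\n*/\n#include <cstddef>\n#include <cstdint>\n#include <iostream>\n\nsize_t RoundUp(size_t n, size_t align) { return (n + align - 1) / align * align; }\n\nint main() {\n  const size_t kHeaderAlign = 64;  // placeholder\n  const size_t kBodyAlign = 512;   // placeholder\n  const size_t num_axes = 4;       // shape (8, 3, 224, 224)\n  const size_t elem_cnt = 8 * 3 * 224 * 224;\n  const size_t header = RoundUp(num_axes * sizeof(int64_t), kHeaderAlign);  // 32 -> 64\n  const size_t body = RoundUp(elem_cnt * sizeof(float), kBodyAlign);  // 4816896, already aligned\n  std::cout << \"aligned total: \" << header + body << std::endl;  // 4816960\n  return 0;\n}\n"
  },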
  {
    "path": "oneflow/core/register/blob_desc.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_REGISTER_BLOB_DESC_H_\n#define ONEFLOW_CORE_REGISTER_BLOB_DESC_H_\n\n#include <memory>\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/memory_format.pb.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/register/blob_desc.pb.h\"\n#include \"oneflow/core/register/register_desc.pb.h\"\n\nnamespace oneflow {\n\nclass BlobDesc final {\n public:\n  BlobDesc() = delete;\n  ~BlobDesc() = default;\n\n  // NOTE(chengcheng): Cannot using std::make_shared in header file, because it will cause\n  //  Segmentation fault with unknown reason.\n  BlobDesc(const Shape& shape, DataType dtype, MemoryFormat memory_format, bool is_dynamic);\n  BlobDesc(const Shape& shape, const Stride& stride, DataType dtype, MemoryFormat memory_format,\n           bool is_dynamic);\n  BlobDesc(Symbol<Shape> shape, Symbol<Stride> stride, DataType dtype, MemoryFormat memory_format,\n           bool is_dynamic);\n\n  BlobDesc(const Shape& shape, DataType dtype, MemoryFormat memory_format);\n  BlobDesc(const Shape& shape, const Stride& stride, DataType dtype, MemoryFormat memory_format);\n  BlobDesc(Symbol<Shape> shape, Symbol<Stride> stride, DataType dtype, MemoryFormat memory_format);\n  explicit BlobDesc(DataType dtype, MemoryFormat memory_format);\n  explicit BlobDesc(const BlobDescProto& proto);\n  explicit BlobDesc(const BlobDesc&);\n\n  BlobDesc& operator=(const BlobDesc&);\n\n  const Shape& shape() const {\n    CHECK(shape_.operator bool());\n    return *shape_;\n  }\n  const Stride& stride() const {\n    CHECK(stride_.operator bool());\n    return *stride_;\n  }\n  const std::shared_ptr<const Shape>& shape_ptr() const { return shape_.shared_from_symbol(); }\n  const std::shared_ptr<const Stride>& stride_ptr() const { return stride_.shared_from_symbol(); }\n\n  void set_shape(const Shape& shape) { this->shape_ = SymbolOf(shape); }\n  void set_stride(const Stride& stride) { this->stride_ = SymbolOf(stride); }\n\n  DataType data_type() const { return data_type_; }\n  void set_data_type(DataType data_type) { data_type_ = data_type; }\n\n  MemoryFormat memory_format() const { return memory_format_; }\n  void set_memory_format(MemoryFormat memory_format) { memory_format_ = memory_format; }\n\n  bool is_dynamic() const { return is_dynamic_; }\n  void set_is_dynamic(bool is_dynamic);\n\n  bool operator==(const BlobDesc&) const;\n  void ToProto(BlobDescProto*) const;\n\n  void CopyFrom(const BlobDesc&);\n\n  size_t ByteSizeOfBlobHeader() const;\n  size_t ByteSizeOfBlobBody() const;\n  size_t AlignedByteSizeOfBlobHeader() const;\n  size_t AlignedByteSizeOfBlobBody() const;\n  size_t AlignedTotalByteSize() const;\n\n private:\n  Symbol<Shape> shape_;\n  Symbol<Stride> stride_;\n  
DataType data_type_;\n  MemoryFormat memory_format_;\n  bool is_dynamic_;\n};\n\nbool CompareLbiBlobDescPair(const LbiBlobDescPair& lhs, const LbiBlobDescPair& rhs);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_REGISTER_BLOB_DESC_H_\n"
  },
  {
    "path": "oneflow/core/register/blob_desc.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/common/shape.proto\";\nimport \"oneflow/core/common/sequential.proto\";\nimport \"oneflow/core/common/data_type.proto\";\nimport \"oneflow/core/common/memory_format.proto\";\n\nmessage BlobDescProto {\n  required ShapeProto shape = 1;\n  required Int64ListProto stride = 2;\n  required DataType data_type = 3;\n  required bool is_dynamic = 4;\n  required MemoryFormat memory_format = 5;\n}\n\nmessage BlobDescSignature {\n  map<string, BlobDescProto> bn_in_op2blob_desc = 1;\n}\n"
  },
  {
    "path": "oneflow/core/register/logical_blob_id.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage LogicalBlobId {\n  optional string op_name = 1;\n  optional string blob_name = 2;\n}\n\nmessage LogicalBlobIdPair {\n  required LogicalBlobId first = 1;\n  required LogicalBlobId second = 2;\n}\n\nmessage LogicalBlobIdPairs {\n  repeated LogicalBlobIdPair pair = 1;\n}\n\nmessage LogicalBlobIdGroups {\n  message LogicalBlobIdGroup {\n    repeated LogicalBlobId lbi = 1;\n  }\n  repeated LogicalBlobIdGroup lbi_group = 2;\n}\n\nmessage ArgSignature {\n  map<string, LogicalBlobId> bn_in_op2lbi = 1;\n}\n"
  },
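  {
    "path": "oneflow/core/register/logical_blob_id_example.cpp",
    "content": "/*\nHypothetical sketch, NOT part of the original tree: it fills a LogicalBlobId\nfrom the classes generated out of logical_blob_id.proto above; printing it as\n\"op_name/blob_name\" is only a display convention here, not an API of the\nmessage.\n*/\n#include <iostream>\n#include \"oneflow/core/register/logical_blob_id.pb.h\"\n\nint main() {\n  oneflow::LogicalBlobId lbi;\n  lbi.set_op_name(\"conv1\");\n  lbi.set_blob_name(\"out\");\n  std::cout << lbi.op_name() << \"/\" << lbi.blob_name() << std::endl;  // conv1/out\n  return 0;\n}\n"
  },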
  {
    "path": "oneflow/core/register/op_blob_arg.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nmessage OpBlobArg {\n  required string op_name = 1;\n  // blob name in op\n  required string bn_in_op = 2;\n}\n\nmessage OpBlobArgPair {\n  required OpBlobArg first = 1;\n  required OpBlobArg second = 2;\n}\n\nmessage OpBlobArgPairs {\n  repeated OpBlobArgPair pair = 1;\n}\n\nmessage OpBlobArgList {\n  repeated OpBlobArg oba = 1;\n}\n"
  },
  {
    "path": "oneflow/core/register/op_blob_arg_info.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_REGISTER_OP_BLOB_ARG_INFO_H_\n#define ONEFLOW_CORE_REGISTER_OP_BLOB_ARG_INFO_H_\n\n#include \"oneflow/core/register/op_blob_arg.pb.h\"\n\nnamespace oneflow {\n\nstruct InplaceObasInfo {\n  OpBlobArgList mut_in_obas;\n  OpBlobArgPairs mut_inplace_oba_pairs;\n  OpBlobArgPairs con_inplace_oba_pairs;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_REGISTER_OP_BLOB_ARG_INFO_H_\n"
  },
  {
    "path": "oneflow/core/register/register.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/register/register.h\"\n#include \"oneflow/core/comm_network/comm_network.h\"\n#include \"oneflow/core/memory/memory_allocator.h\"\n\nnamespace oneflow {\n\nconst std::vector<int64_t>& Regst::consumers_actor_id() const {\n  return regst_desc_->consumers_actor_id();\n}\n\nRegst::Regst(const RtRegstDesc* regst_desc, RegstAllocationType allocation_type)\n    : regst_desc_(regst_desc),\n      header_mem_ptr_(nullptr),\n      body_mem_ptr_(nullptr),\n      comm_net_token_(nullptr),\n      allocation_type_(allocation_type) {\n  sorted_blob_vec_.resize(regst_desc->lbi_num());\n}\n\nRegst::~Regst() {\n  if (comm_net_token_ != nullptr) { Singleton<CommNet>::Get()->UnRegisterMemory(comm_net_token_); }\n}\n\nvoid Regst::Init(void* header_mem_ptr) {\n  CHECK(header_mem_ptr_ == nullptr);\n  header_mem_ptr_ = header_mem_ptr;\n  regst_desc_->ForEachBlobDescOffsetInOnRegst([&](int64_t ordinal, const LogicalBlobId& lbi,\n                                                  const BlobDesc* blob_desc, int64_t body_offset,\n                                                  int64_t header_offset) {\n    sorted_blob_vec_.at(ordinal).reset(\n        new Blob(regst_desc_->mem_case(), blob_desc,\n                 reinterpret_cast<char*>(header_mem_ptr_) + header_offset));\n  });\n}\n\nvoid Regst::ResetBodyMemPtr(void* body_mem_ptr) {\n  if (body_mem_ptr_ == body_mem_ptr) { return; }\n  body_mem_ptr_ = body_mem_ptr;\n  if (body_mem_ptr_ == nullptr) {\n    for (auto& blob : sorted_blob_vec_) { blob->reset_dptr(nullptr); }\n  } else {\n    regst_desc_->ForEachBlobDescOffsetInOnRegst([&](int64_t ordinal, const LogicalBlobId& lbi,\n                                                    const BlobDesc* blob_desc, int64_t body_offset,\n                                                    int64_t header_offset) {\n      sorted_blob_vec_.at(ordinal)->reset_dptr(reinterpret_cast<char*>(body_mem_ptr_)\n                                               + body_offset);\n      InitNonPODTypeBlobIfNeed(Singleton<MemoryAllocator>::Get(),\n                               sorted_blob_vec_.at(ordinal).get());\n    });\n  }\n}\n\nBlob* Regst::GetBlobByOrdinal(int64_t ordinal) { return sorted_blob_vec_.at(ordinal).get(); }\n\nBlob* Regst::GetBlobByLbi(const LogicalBlobId& lbi) {\n  const int64_t ordinal = regst_desc_->GetOrdinalForLbi(lbi);\n  if (ordinal >= 0) {\n    return sorted_blob_vec_.at(ordinal).get();\n  } else {\n    return nullptr;\n  }\n}\n\nvoid Regst::SetBlobByOrdinal(int64_t ordinal, std::unique_ptr<Blob>&& blob) {\n  CHECK(!sorted_blob_vec_.at(ordinal));\n  sorted_blob_vec_.at(ordinal).swap(blob);\n}\n\nBlob* Regst::GetMutSoleBlob() {\n  CHECK_EQ(GetBlobSize(), 1);\n  return sorted_blob_vec_.front().get();\n}\n\nconst Blob* Regst::GetSoleBlob() const {\n  CHECK_EQ(GetBlobSize(), 1);\n  return sorted_blob_vec_.front().get();\n}\n\nvoid* Regst::comm_net_token() {\n  void* token = 
comm_net_token_.load(std::memory_order_relaxed);\n  if (token != nullptr) { return token; }\n  {\n    std::lock_guard<std::mutex> lock(comm_net_token_mutex_);\n    token = comm_net_token_;\n    if (token != nullptr) { return token; }\n    CHECK(body_mem_ptr_ != nullptr);\n    CHECK(header_mem_ptr_ != nullptr);\n    CHECK(reinterpret_cast<char*>(header_mem_ptr_) + regst_desc_->HeaderByteSize4OneRegst()\n          == body_mem_ptr_);\n    token = Singleton<CommNet>::Get()->RegisterMemory(header_mem_ptr_,\n                                                      this->regst_desc()->MainByteSize4OneRegst());\n    comm_net_token_ = token;\n    return token;\n  }\n}\n\n}  // namespace oneflow\n"
  },
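  {
    "path": "oneflow/core/register/comm_net_token_sketch.cpp",
    "content": "/*\nStandalone sketch, NOT part of the original tree: it isolates the\ndouble-checked locking idiom of Regst::comm_net_token() above, where a relaxed\natomic load is the fast path and a mutex serializes the one-time registration;\nExpensiveRegister() stands in for CommNet::RegisterMemory.\n*/\n#include <atomic>\n#include <mutex>\n\nvoid* ExpensiveRegister();  // assumed: must run exactly once\n\nclass TokenHolder {\n public:\n  void* token() {\n    void* t = token_.load(std::memory_order_relaxed);\n    if (t != nullptr) { return t; }  // fast path: already registered\n    std::lock_guard<std::mutex> lock(mutex_);\n    t = token_.load();  // re-check under the lock\n    if (t == nullptr) {\n      t = ExpensiveRegister();\n      token_.store(t);\n    }\n    return t;\n  }\n\n private:\n  std::atomic<void*> token_{nullptr};\n  std::mutex mutex_;\n};\n"
  },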
  {
    "path": "oneflow/core/register/register.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_REGISTER_REGISTER_H_\n#define ONEFLOW_CORE_REGISTER_REGISTER_H_\n\n#include \"oneflow/core/register/blob.h\"\n#include \"oneflow/core/register/runtime_register_desc.h\"\n\nnamespace oneflow {\n\nenum class RegstAllocationType {\n  kInvalid = 0,\n  kStatic = 1,\n  kStreamOrdered = 2,\n};\n\nclass Regst final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Regst);\n  ~Regst();\n\n  // Getters\n  int64_t regst_desc_id() const {\n    CHECK(regst_desc_ != nullptr);\n    return regst_desc_->regst_desc_id();\n  }\n\n  void Init(void* header_mem_ptr);\n  void ResetBodyMemPtr(void* body_mem_ptr);\n  int64_t producer_actor_id() const { return regst_desc_->producer_actor_id(); }\n  const std::vector<int64_t>& consumers_actor_id() const;\n  const RtRegstDesc* regst_desc() const { return regst_desc_; }\n  Blob* GetBlobByOrdinal(int64_t ordinal);\n  Blob* GetBlobByLbi(const LogicalBlobId& lbi);\n  const Blob* GetSoleBlob() const;\n  Blob* GetMutSoleBlob();\n  int64_t GetBlobSize() const { return static_cast<int64_t>(sorted_blob_vec_.size()); }\n\n  void* comm_net_token();\n\n  void* header_mem_ptr() const { return header_mem_ptr_; }\n\n  void* body_mem_ptr() const { return body_mem_ptr_; }\n\n  RegstAllocationType allocation_type() const { return allocation_type_; }\n\n private:\n  friend class RegstMgr;\n  Regst(const RtRegstDesc* regst_desc, RegstAllocationType allocation_type);\n\n  void SetBlobByOrdinal(int64_t ordinal, std::unique_ptr<Blob>&& blob);\n\n  const RtRegstDesc* regst_desc_;\n  std::vector<std::unique_ptr<Blob>> sorted_blob_vec_;\n\n  void* header_mem_ptr_;\n  void* body_mem_ptr_;\n\n  std::atomic<void*> comm_net_token_;\n  std::mutex comm_net_token_mutex_;\n  RegstAllocationType allocation_type_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_REGISTER_REGISTER_H_\n"
  },
  {
    "path": "oneflow/core/register/register_desc.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/register/register_desc.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/graph/copy_task_node.h\"\n#include \"oneflow/core/job/id_manager.h\"\n#include \"oneflow/core/register/runtime_register_desc.h\"\n#include \"oneflow/core/memory/memory_case_util.h\"\n\nnamespace oneflow {\n\nRegstDesc::RegstDesc() {\n  regst_desc_id_ = Singleton<IDMgr>::Get()->NewRegstDescId();  // NOLINT\n  producer_ = nullptr;\n  min_register_num_ = 1;\n  max_register_num_ = kMaxRegisterNum;\n  enable_reuse_mem_ = false;\n  mem_block_id_ = -1;\n  mem_block_offset_ = -1;\n  hint_inplace_consumed_regst_desc_id_ = -1;\n  force_inplace_consumed_regst_desc_id_ = -1;\n}\n\nint64_t RegstDesc::mem_block_offset() const {\n  CHECK_GE(mem_block_offset_, 0);\n  return mem_block_offset_;\n}\n\nvoid RegstDesc::AddConsumer(const TaskNode* new_consumer) {\n  CHECK(consumers_.insert(new_consumer).second);\n}\n\nvoid RegstDesc::DeleteConsumer(const TaskNode* consumer) {\n  CHECK_EQ(consumers_.erase(consumer), 1);\n}\n\nvoid RegstDesc::UpdtMinRegstNumIfNeed(int32_t val) {\n  CHECK_LE(val, max_register_num_);\n  min_register_num_ = std::max(min_register_num_, val);\n}\nvoid RegstDesc::UpdtMaxRegstNumIfNeed(int32_t val) {\n  CHECK_GE(val, min_register_num_);\n  max_register_num_ = std::min(max_register_num_, val);\n}\n\nvoid RegstDesc::CopyBlobDescFrom(const RegstDesc* rhs) {\n  CHECK(lbi2blob_desc_.empty());\n  for (const auto& pair : rhs->lbi2blob_desc_) {\n    const LogicalBlobId& lbi = pair.first;\n    AddLbi(lbi);\n  }\n  CopyBlobDescWithoutAddLbi(rhs);\n}\n\nvoid RegstDesc::CopyMemBlockInfoFrom(const RegstDesc* rhs) {\n  enable_reuse_mem_ = rhs->enable_reuse_mem_;\n  mem_block_id_ = rhs->mem_block_id_;\n  mem_block_offset_ = rhs->mem_block_offset_;\n}\n\nvoid RegstDesc::CopyBlobDescWithoutAddLbi(const RegstDesc* rhs) {\n  for (const auto& pair : lbi2blob_desc_) {\n    auto rhs_it = rhs->lbi2blob_desc_.find(pair.first);\n    if (rhs_it != rhs->lbi2blob_desc_.end()) { *(pair.second) = *(rhs_it->second); }\n  }\n}\n\nBlobDesc* RegstDesc::AddLbi(const LogicalBlobId& lbi) {\n  CHECK(lbi2blob_desc_.find(lbi) == lbi2blob_desc_.end());\n  BlobDesc* blob_desc = new BlobDesc(GlobalJobDesc().DefaultDataType(), MemoryFormat::kContiguous);\n  lbi2blob_desc_[lbi].reset(blob_desc);\n  return blob_desc;\n}\n\nconst BlobDesc* RegstDesc::GetBlobDesc(const LogicalBlobId& lbi) const {\n  return const_cast<RegstDesc*>(this)->MutBlobDesc(lbi);\n}\n\nbool RegstDesc::HasLbi(const LogicalBlobId& lbi) const {\n  return lbi2blob_desc_.find(lbi) != lbi2blob_desc_.end();\n}\n\nBlobDesc* RegstDesc::MutBlobDesc(const LogicalBlobId& lbi) {\n  auto it = lbi2blob_desc_.find(lbi);\n  if (it != lbi2blob_desc_.end()) {\n    return it->second.get();\n  } else {\n    return nullptr;\n  }\n}\n\nconst BlobDesc* RegstDesc::SoleBlobDesc() const {\n  CHECK_EQ(1, lbi2blob_desc_.size());\n  return 
(*lbi2blob_desc_.begin()).second.get();\n}\n\nBlobDesc* RegstDesc::MutSoleBlobDesc() { return const_cast<BlobDesc*>(SoleBlobDesc()); }\n\nvoid RegstDesc::ForEachLbi(std::function<void(const LogicalBlobId&)> func) const {\n  for (const auto& p : lbi2blob_desc_) { func(p.first); }\n}\n\nvoid RegstDesc::EraseUninitializedShapeBlob() {\n  EraseIf<LogicalBlobId, std::unique_ptr<BlobDesc>>(\n      &lbi2blob_desc_, [](HashMap<LogicalBlobId, std::unique_ptr<BlobDesc>>::iterator it) {\n        return !it->second->shape().is_initialized();\n      });\n}\n\nvoid RegstDesc::InitFromProtoExceptConsumers(const RegstDescProto& proto) {\n  regst_desc_id_ = proto.regst_desc_id();\n  CHECK_EQ(proto.producer_task_id(), producer_->task_id());\n  regst_desc_type_ = proto.regst_desc_type();\n  if (regst_desc_type_.has_data_regst_desc()) {\n    const DataRegstDesc& data_regst_desc_proto = proto.regst_desc_type().data_regst_desc();\n    for (const auto& pair : data_regst_desc_proto.lbi2blob_desc()) {\n      *AddLbi(pair.lbi()) = BlobDesc(pair.blob_desc());\n    }\n    CHECK(!data_regst_desc_proto.has_time_shape());\n  } else if (regst_desc_type_.has_ctrl_regst_desc()) {\n    // do nothing\n  } else {\n    UNIMPLEMENTED();\n  }\n  min_register_num_ = proto.min_register_num();\n  max_register_num_ = proto.max_register_num();\n  min_register_num_ = proto.register_num();\n  mem_case_ = proto.mem_case();\n  enable_reuse_mem_ = proto.enable_reuse_mem();\n  mem_block_id_ = proto.mem_block_id();\n  mem_block_offset_ = proto.mem_block_offset();\n  hint_inplace_consumed_regst_desc_id_ = proto.hint_inplace_consumed_regst_desc_id();\n  force_inplace_consumed_regst_desc_id_ = proto.force_inplace_consumed_regst_desc_id();\n}\n\nvoid RegstDesc::ToProto(RegstDescProto* ret, bool check) const {\n  ret->set_regst_desc_id(regst_desc_id_);\n  ret->set_producer_task_id(producer_->task_id());\n  for (const TaskNode* consumer : consumers_) { ret->add_consumer_task_id(consumer->task_id()); }\n  *(ret->mutable_regst_desc_type()) = regst_desc_type_;\n  if (regst_desc_type_.has_data_regst_desc()) {\n    DataRegstDesc* data_regst_desc_proto =\n        ret->mutable_regst_desc_type()->mutable_data_regst_desc();\n    for (const auto& pair : lbi2blob_desc_) {\n      LbiBlobDescPair* pb_pair = data_regst_desc_proto->mutable_lbi2blob_desc()->Add();\n      *(pb_pair->mutable_lbi()) = pair.first;\n      pair.second->ToProto(pb_pair->mutable_blob_desc());\n    }\n    if (check) { CHECK(data_regst_time_shape_); }\n    if (data_regst_time_shape_) {\n      data_regst_time_shape_->ToProto(data_regst_desc_proto->mutable_time_shape());\n    }\n  } else if (regst_desc_type_.has_ctrl_regst_desc()) {\n    // do nothing\n  } else {\n    UNIMPLEMENTED();\n  }\n  ret->set_min_register_num(min_register_num_);\n  ret->set_max_register_num(max_register_num_);\n  ret->set_register_num(min_register_num_);\n  *(ret->mutable_mem_case()) = mem_case_;\n  ret->set_enable_reuse_mem(enable_reuse_mem_);\n  ret->set_mem_block_id(mem_block_id_);\n  ret->set_mem_block_offset(mem_block_offset_);\n  if (check) {\n    CHECK(hint_inplace_consumed_regst_desc_id_ == -1 || force_inplace_consumed_regst_desc_id_ == -1)\n        << \"They are oneof fields\";\n  }\n  if (hint_inplace_consumed_regst_desc_id_ != -1) {\n    ret->set_hint_inplace_consumed_regst_desc_id(hint_inplace_consumed_regst_desc_id_);\n  } else if (force_inplace_consumed_regst_desc_id_ != -1) {\n    ret->set_force_inplace_consumed_regst_desc_id(force_inplace_consumed_regst_desc_id_);\n  } else {\n    // do nothing\n  
}\n}\n\nbool RegstDesc::HasSameMemSize(const RegstDesc* rhs) {\n  return SoleBlobDesc()->AlignedTotalByteSize() == rhs->SoleBlobDesc()->AlignedTotalByteSize();\n}\n\nbool RegstDesc::HasSameBlobDescs(const RegstDesc* rhs) {\n  if (rhs->lbi2blob_desc_.size() != lbi2blob_desc_.size()) { return false; }\n  for (const auto& pair : rhs->lbi2blob_desc_) {\n    auto iter = lbi2blob_desc_.find(pair.first);\n    if (iter == lbi2blob_desc_.end()) { return false; }\n    if (!(*(pair.second.get()) == *(iter->second.get()))) { return false; }\n  }\n  return true;\n}\n\nvoid InitCtrlRegstDesc(int64_t producer_task_id, RegstDescProto* ctrl_regst_proto) {\n  CHECK_NOTNULL(ctrl_regst_proto);\n  ctrl_regst_proto->set_regst_desc_id(Singleton<IDMgr>::Get()->NewRegstDescId());\n  ctrl_regst_proto->set_producer_task_id(producer_task_id);\n  ctrl_regst_proto->set_min_register_num(1);\n  ctrl_regst_proto->set_max_register_num(1);\n  ctrl_regst_proto->set_register_num(1);\n  ctrl_regst_proto->mutable_regst_desc_type()->mutable_ctrl_regst_desc();\n  *ctrl_regst_proto->mutable_mem_case() = memory::MakeHostMemCase();\n  ctrl_regst_proto->set_enable_reuse_mem(false);\n  ctrl_regst_proto->set_mem_block_id(-1);\n  ctrl_regst_proto->set_mem_block_offset(-1);\n}\n\n}  // namespace oneflow\n"
  },
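`RegstDesc::HasSameBlobDescs` in `register_desc.cpp` above compares two hash maps of owning pointers by deep value equality rather than pointer identity: the sizes must match, every key must be present on both sides, and the pointed-to descriptors must compare equal. A minimal standalone sketch of the same pattern, using `std::unordered_map` and a toy stand-in for `BlobDesc` (all names here are illustrative, not OneFlow API):

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>

// Toy stand-in for BlobDesc: value equality, held by unique_ptr like lbi2blob_desc_.
struct FakeBlobDesc {
  int64_t byte_size = 0;
  bool operator==(const FakeBlobDesc& rhs) const { return byte_size == rhs.byte_size; }
};

using DescMap = std::unordered_map<std::string, std::unique_ptr<FakeBlobDesc>>;

// Deep equality over owning-pointer maps: equal sizes, same keys,
// and the *values* (not the pointers) compare equal.
bool HasSameBlobDescs(const DescMap& lhs, const DescMap& rhs) {
  if (lhs.size() != rhs.size()) { return false; }
  for (const auto& pair : rhs) {
    auto iter = lhs.find(pair.first);
    if (iter == lhs.end()) { return false; }
    if (!(*pair.second == *iter->second)) { return false; }
  }
  return true;
}

int main() {
  DescMap a;
  DescMap b;
  a.emplace("out_0", std::make_unique<FakeBlobDesc>(FakeBlobDesc{16}));
  b.emplace("out_0", std::make_unique<FakeBlobDesc>(FakeBlobDesc{16}));
  assert(HasSameBlobDescs(a, b));
}
```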
  {
    "path": "oneflow/core/register/register_desc.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_REGISTER_REGISTER_DESC_H_\n#define ONEFLOW_CORE_REGISTER_REGISTER_DESC_H_\n\n#include \"oneflow/core/register/blob_desc.h\"\n#include \"oneflow/core/register/register_desc.pb.h\"\n\nnamespace oneflow {\n\nconst int32_t kMaxRegisterNum = std::numeric_limits<int32_t>::max();\n\nvoid InitCtrlRegstDesc(int64_t producer_task_id, RegstDescProto* ctrl_regst_proto);\n\nclass TaskNode;\n\nclass RegstDesc final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RegstDesc);\n  RegstDesc();\n  ~RegstDesc() = default;\n\n  // regst_desc_id\n  int64_t regst_desc_id() const { return regst_desc_id_; }\n\n  // producer_, consumers_\n  const TaskNode* producer() const { return producer_; }\n  void set_producer(const TaskNode* val) { producer_ = val; }\n  const HashSet<const TaskNode*>& consumers() const { return consumers_; }\n  void AddConsumer(const TaskNode*);\n  void DeleteConsumer(const TaskNode*);\n\n  // min_register_num_, max_register_num_\n  int32_t min_register_num() const { return min_register_num_; }\n  void UpdtMinRegstNumIfNeed(int32_t val);\n  int32_t max_register_num() const { return max_register_num_; }\n  void UpdtMaxRegstNumIfNeed(int32_t val);\n\n  // lbi2blob_desc_\n  void CopyBlobDescFrom(const RegstDesc*);\n  void CopyBlobDescWithoutAddLbi(const RegstDesc*);\n  BlobDesc* AddLbi(const LogicalBlobId&);\n  const BlobDesc* GetBlobDesc(const LogicalBlobId& lbi) const;\n  bool HasLbi(const LogicalBlobId& lbi) const;\n  BlobDesc* MutBlobDesc(const LogicalBlobId& lbi);\n  const BlobDesc* SoleBlobDesc() const;\n  BlobDesc* MutSoleBlobDesc();\n  void ForEachLbi(std::function<void(const LogicalBlobId&)> func) const;\n  size_t NumOfLbi() const { return lbi2blob_desc_.size(); }\n\n  // mem\n  const MemoryCase& mem_case() const { return mem_case_; }\n  MemoryCase* mut_mem_case() { return &mem_case_; }\n  bool enable_reuse_mem() const { return enable_reuse_mem_; }\n  void set_enable_reuse_mem(bool enable_reuse_mem) { enable_reuse_mem_ = enable_reuse_mem; }\n  int64_t mem_block_offset() const;\n  void set_mem_block_offset(int64_t val) { mem_block_offset_ = val; }\n  void set_hint_inplace_consumed_regst_desc_id(int64_t val) {\n    CHECK_EQ(force_inplace_consumed_regst_desc_id_, -1);\n    hint_inplace_consumed_regst_desc_id_ = val;\n  }\n  bool has_force_inplace_consumed_regst_desc_id() {\n    return force_inplace_consumed_regst_desc_id_ != -1;\n  }\n  void set_force_inplace_consumed_regst_desc_id(int64_t val) {\n    CHECK_EQ(hint_inplace_consumed_regst_desc_id_, -1);\n    force_inplace_consumed_regst_desc_id_ = val;\n  }\n  int32_t mem_block_id() const { return mem_block_id_; }\n  void set_mem_block_id(int32_t val) { mem_block_id_ = val; }\n  bool HasSetMemBlockId() { return mem_block_id_ != -1; }\n  void CopyMemBlockInfoFrom(const RegstDesc*);\n\n  const std::shared_ptr<Shape>& data_regst_time_shape() const {\n    CHECK(regst_desc_type_.has_data_regst_desc());\n    
CHECK(data_regst_time_shape_);\n    return data_regst_time_shape_;\n  }\n  std::shared_ptr<Shape>* mut_data_regst_time_shape() {\n    CHECK(regst_desc_type_.has_data_regst_desc());\n    return &data_regst_time_shape_;\n  }\n  RegstDescTypeProto* mut_regst_desc_type() { return &regst_desc_type_; }\n  const RegstDescTypeProto& regst_desc_type() const { return regst_desc_type_; }\n  bool HasSameMemSize(const RegstDesc*);\n\n  // util\n  void EraseUninitializedShapeBlob();\n  void InitFromProtoExceptConsumers(const RegstDescProto& proto);\n  void ToProto(RegstDescProto* proto) const { ToProto(proto, /*check*/ true); }\n  void ToProto(RegstDescProto*, bool check) const;\n  bool HasSameBlobDescs(const RegstDesc*);\n\n private:\n  int64_t regst_desc_id_;\n  const TaskNode* producer_;\n  HashSet<const TaskNode*> consumers_;\n  int32_t min_register_num_;\n  int32_t max_register_num_;\n\n  HashMap<LogicalBlobId, std::unique_ptr<BlobDesc>> lbi2blob_desc_;\n\n  MemoryCase mem_case_;\n  RegstDescTypeProto regst_desc_type_;\n  bool enable_reuse_mem_;\n  int32_t mem_block_id_;\n  int64_t mem_block_offset_;\n  int64_t hint_inplace_consumed_regst_desc_id_;\n  int64_t force_inplace_consumed_regst_desc_id_;\n\n  std::shared_ptr<Shape> data_regst_time_shape_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_REGISTER_REGISTER_DESC_H_\n"
  },
  {
    "path": "oneflow/core/register/register_desc.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/register/blob_desc.proto\";\nimport \"oneflow/core/register/logical_blob_id.proto\";\nimport \"oneflow/core/memory/memory_case.proto\";\nimport \"oneflow/core/common/shape.proto\";\n\nmessage LbiBlobDescPair {\n  required LogicalBlobId lbi = 1;\n  required BlobDescProto blob_desc = 2;\n}\n\nmessage DataRegstDesc {\n  repeated LbiBlobDescPair lbi2blob_desc = 1;\n  optional ShapeProto time_shape = 3;\n}\n\nmessage CtrlRegstDesc {\n}\n\nmessage RegstDescTypeProto {\n  oneof type {\n    DataRegstDesc data_regst_desc = 1;\n    CtrlRegstDesc ctrl_regst_desc = 3;\n  }\n}\n\nmessage RegstDescProto {\n  required int64 regst_desc_id = 1;\n  required int64 producer_task_id = 2;\n  repeated int64 consumer_task_id = 3;\n  required int32 min_register_num = 4;\n  required int32 max_register_num = 5;\n  required int32 register_num = 6;\n  required MemoryCase mem_case = 7;\n  required RegstDescTypeProto regst_desc_type = 8;\n  required bool enable_reuse_mem = 9;\n  required int64 mem_block_id = 10;\n  required int64 mem_block_offset = 11;\n  optional int64 separated_header_mem_block_id = 12 [default = -1];\n  optional int64 inplace_consumed_regst_desc_id = 13 [default = -1];\n  oneof inplace_info_type {\n    int64 hint_inplace_consumed_regst_desc_id = 14 [default = -1];\n    int64 force_inplace_consumed_regst_desc_id = 15 [default = -1];\n  }\n  // NOTE(chengcheng): mark this regst memory is shared with EagerParameter.\n  optional string variable_op_name = 16 [default = \"\"];\n  // NOTE(chengcheng): for mem block debug.\n  optional int64 mem_block_total_actor_count = 20 [default = -1];\n  optional int64 alloc_before_actor = 21 [default = -1];\n  optional int64 free_after_actor = 22 [default = -1];\n}\n"
  },
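The `hint_inplace_consumed_regst_desc_id` / `force_inplace_consumed_regst_desc_id` pair in `register_desc.proto` lives in a proto2 `oneof`, so at most one of them can be set at a time; `RegstDesc` mirrors this invariant with the paired `CHECK_EQ`s in its setters and the `check` branch of `ToProto`. A hedged sketch of how the generated C++ API behaves (this assumes the standard protobuf codegen for the file above and is not compiled here):

```cpp
#include "oneflow/core/register/register_desc.pb.h"  // generated header

void OneofDemo() {
  oneflow::RegstDescProto proto;
  proto.set_hint_inplace_consumed_regst_desc_id(7);
  // Setting the other member of the oneof automatically clears the hint.
  proto.set_force_inplace_consumed_regst_desc_id(9);
  bool has_hint = proto.has_hint_inplace_consumed_regst_desc_id();    // false
  bool has_force = proto.has_force_inplace_consumed_regst_desc_id();  // true
  // An unset oneof member reads back as its declared default (-1 here).
  int64_t hint = proto.hint_inplace_consumed_regst_desc_id();  // -1
  (void)has_hint;
  (void)has_force;
  (void)hint;
}
```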
  {
    "path": "oneflow/core/register/register_manager.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/register/register_manager.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/register/blob.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n#include \"oneflow/core/comm_network/comm_network.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/memory/memory_case.pb.h\"\n#include \"oneflow/core/memory/memory_allocator.h\"\n#include \"oneflow/core/memory/chunk_manager.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstruct PackedChunkInfo {\n  MemoryCase mem_case;\n  int64_t size;\n  std::vector<const MemBlockProto*> blocks;\n  PackedChunkInfo(const MemoryCase& mem) {\n    mem_case = mem;\n    size = 0;\n  }\n};\n\nstd::shared_ptr<ep::Device> GetDeviceByMemoryCase(const MemoryCase& mem_case) {\n  return Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(mem_case.device_type(),\n                                                                mem_case.device_id());\n}\n\nvoid InitDataRegst(Regst* regst, char* main_mem_ptr, char* separated_header_mem_ptr) {\n  auto* rt_regst_desc = regst->regst_desc();\n  size_t separated_header_mem_size = rt_regst_desc->SeparatedHeaderByteSize4OneRegst();\n  char* cur_body_pointer = nullptr;\n  char* cur_header_pointer = nullptr;\n  if (separated_header_mem_size > 0) {\n    MemoryCase host_mem_case = memory::MakeHostMemCase();\n    if (separated_header_mem_ptr == nullptr) {\n      separated_header_mem_ptr =\n          Singleton<MemoryAllocator>::Get()->Allocate(host_mem_case, separated_header_mem_size);\n    }\n    cur_header_pointer = separated_header_mem_ptr;\n    cur_body_pointer = main_mem_ptr;\n  } else {\n    CHECK(separated_header_mem_ptr == nullptr);\n    cur_header_pointer = main_mem_ptr;\n    if (main_mem_ptr == nullptr) {\n      cur_body_pointer = nullptr;\n    } else {\n      cur_body_pointer =\n          main_mem_ptr + rt_regst_desc->GetSoleBlobDesc()->AlignedByteSizeOfBlobHeader();\n    }\n  }\n  if (regst->allocation_type() == RegstAllocationType::kStatic) {\n    CHECK(cur_body_pointer != nullptr || rt_regst_desc->TotalBodyByteSize4AllRegst() == 0);\n  } else if (regst->allocation_type() == RegstAllocationType::kStreamOrdered) {\n    CHECK(cur_body_pointer == nullptr);\n  } else {\n    UNIMPLEMENTED();\n  }\n  regst->Init(cur_header_pointer);\n  regst->ResetBodyMemPtr(cur_body_pointer);\n}\n\n}  // namespace\n\nRegstMgr::RegstMgr() : stream_ordered_memory_allocation_enabled_(false) {\n  stream_ordered_memory_allocation_enabled_ =\n      ParseBooleanFromEnv(\"ONEFLOW_GRAPH_ENABLE_STREAM_ORDERED_MEMORY_ALLOCATION\", false);\n}\n\nbool RegstMgr::IsStreamOrderedMemoryAllocationCase(const MemoryCase& mem_case) const {\n  if (!stream_ordered_memory_allocation_enabled_) { return false; }\n  const auto& 
device = GetDeviceByMemoryCase(mem_case);\n  return device->IsStreamOrderedMemoryAllocationSupported();\n}\n\nvoid RegstMgr::AddPlan(\n    const Plan& plan,\n    const HashMap<std::string, vm::EagerBlobObject*>& variable_op_name2eager_blob_object) {\n  int64_t this_machine_id = GlobalProcessCtx::Rank();\n\n  HashMap<int64_t, char*> chunk_id2ptr;\n  for (const ChunkProto& chunk : plan.block_chunk_list().chunk()) {\n    if (chunk.machine_id() != this_machine_id) { continue; }\n    if (chunk.mem_size() == 0) { continue; }\n    if (IsStreamOrderedMemoryAllocationCase(chunk.mem_case())) { continue; }\n    char* chunk_ptr = Singleton<ChunkMgr>::Get()->FindOrCreateChunk(chunk);\n    CHECK(chunk_id2ptr.emplace(chunk.chunk_id(), chunk_ptr).second);\n  }\n\n  HashSet<int64_t> all_block_ids;\n  HashMap<int64_t, PackedChunkInfo> zone_id2packed_chunk;\n  for (const MemBlockProto& mem_block : plan.block_chunk_list().mem_block()) {\n    if (mem_block.machine_id() != this_machine_id) { continue; }\n    if (mem_block.mem_size() == 0) { continue; }\n    const int64_t mem_block_id = mem_block.mem_block_id();\n    CHECK(all_block_ids.insert(mem_block_id).second);\n\n    if (mem_block.has_chunk_id()) {\n      if (IsStreamOrderedMemoryAllocationCase(mem_block.mem_case())) {\n        CHECK(mem_block.enable_reuse_mem());\n        CHECK(stream_ordered_allocation_mem_block_ids_.emplace(mem_block_id).second);\n        continue;\n      }\n      CHECK(mem_block.has_chunk_offset());\n      CHECK(chunk_id2ptr.find(mem_block.chunk_id()) != chunk_id2ptr.end());\n      char* mem_block_ptr = chunk_id2ptr.at(mem_block.chunk_id()) + mem_block.chunk_offset();\n      CHECK(mem_block_id2ptr_.emplace(mem_block_id, mem_block_ptr).second)\n          << \" duplicated mem_block_id \" << mem_block_id;\n      CHECK(!mem_block.has_variable_op_name());\n    } else if (mem_block.has_variable_op_name()) {\n      // NOTE(chengcheng): bind mem_block_ptr to variable blob header_ptr and body_ptr\n      CHECK(!mem_block.enable_reuse_mem());\n      const std::string& var_name = mem_block.variable_op_name();\n      CHECK(!var_name.empty());\n      auto it = variable_op_name2eager_blob_object.find(var_name);\n      CHECK(it != variable_op_name2eager_blob_object.end())\n          << \" CANNOT find variable op name: \" << var_name;\n      CHECK(mem_block.has_is_separated_header());\n      vm::EagerBlobObject* var_blob = it->second;\n      CHECK(var_blob) << \" variable op name: \" << var_name << \" in rank: \" << this_machine_id\n                      << \" must NOT be NULL.\";\n      if (mem_block.is_separated_header()) {\n        CHECK_GE(var_blob->AlignedByteSizeOfBlobHeader(), mem_block.mem_size());\n        CHECK_GE(mem_block.mem_size(), var_blob->ByteSizeOfBlobHeader());\n        CHECK(mem_block_id2ptr_.emplace(mem_block_id, var_blob->mut_header_ptr()).second);\n        CHECK(memory::IsHostMem(mem_block.mem_case()));\n      } else {\n        CHECK_GE(var_blob->AlignedByteSizeOfBlobBody(), mem_block.mem_size());\n        CHECK_GE(mem_block.mem_size(), var_blob->ByteSizeOfBlobBody());\n        CHECK(mem_block_id2ptr_.emplace(mem_block_id, var_blob->mut_dptr<char>()).second);\n        // NOTE(chengcheng):\n        //   The CPU eager var tensor mem case is host_mem WITHOUT cuda pinned, but the Lazy\n        //   Compiler will set the variable op output blob mem_case with cuda pinned memory if\n        //   this output blob is consumed by a GPU op. 
We can safely ignore this difference because\n        //   it only costs a little performance while remaining correct.\n        //   Note that this problem does NOT affect tensor.to(\"cuda\") or tensor.to_global().\n        CHECK(memory::EqualsIgnorePinnedDevice(mem_block.mem_case(), var_blob->mem_case()))\n            << \" variable op name: \" << var_name << \" in rank: \" << this_machine_id\n            << \" bind eager tensor failed. The eager var tensor mem_case is : \"\n            << var_blob->mem_case().DebugString()\n            << \" but graph expected_mem block mem_case is : \" << mem_block.mem_case().DebugString();\n      }\n    } else {\n      int64_t zone_id = memory::GetMemCaseId(mem_block.mem_case());\n      if (zone_id2packed_chunk.find(zone_id) == zone_id2packed_chunk.end()) {\n        zone_id2packed_chunk.emplace(zone_id, PackedChunkInfo(mem_block.mem_case()));\n      }\n      PackedChunkInfo* packed_chunk = &(zone_id2packed_chunk.at(zone_id));\n      packed_chunk->blocks.emplace_back(&mem_block);\n      packed_chunk->size += mem_block.mem_size();\n      CHECK(packed_chunk->mem_case == mem_block.mem_case());\n    }\n  }\n\n  for (auto& pair : zone_id2packed_chunk) {\n    PackedChunkInfo* packed_chunk = &pair.second;\n    char* ptr =\n        Singleton<MemoryAllocator>::Get()->Allocate(packed_chunk->mem_case, packed_chunk->size);\n    // sort blocks by thrd id (then block id) so the packed layout is deterministic\n    std::vector<const MemBlockProto*>* blocks = &(packed_chunk->blocks);\n    std::sort(blocks->begin(), blocks->end(),\n              [](const MemBlockProto* lhs, const MemBlockProto* rhs) {\n                if (lhs->thrd_id_hint() == rhs->thrd_id_hint()) {\n                  return lhs->mem_block_id() < rhs->mem_block_id();\n                }\n                return lhs->thrd_id_hint() < rhs->thrd_id_hint();\n              });\n    int64_t offset = 0;\n    for (const MemBlockProto* block : packed_chunk->blocks) {\n      CHECK(mem_block_id2ptr_.emplace(block->mem_block_id(), ptr + offset).second);\n      offset += block->mem_size();\n    }\n    CHECK_EQ(offset, packed_chunk->size);\n  }\n\n  for (int64_t mem_block_id : all_block_ids) {\n    if (mem_block_id2ptr_.find(mem_block_id) != mem_block_id2ptr_.end()) {\n      CHECK(stream_ordered_allocation_mem_block_ids_.find(mem_block_id)\n            == stream_ordered_allocation_mem_block_ids_.end());\n    } else {\n      CHECK(stream_ordered_allocation_mem_block_ids_.find(mem_block_id)\n            != stream_ordered_allocation_mem_block_ids_.end());\n    }\n  }\n\n  for (const TaskProto& task : plan.task()) {\n    if (task.machine_id() != this_machine_id) { continue; }\n    for (const auto& pair : task.produced_regst_desc()) {\n      const RegstDescProto& regst_desc = pair.second;\n      const int64_t regst_desc_id = regst_desc.regst_desc_id();\n      CHECK(regst_desc_id2rt_regst_desc_\n                .emplace(regst_desc_id, std::make_unique<const RtRegstDesc>(regst_desc))\n                .second);\n    }\n  }\n  for (const auto& pair : plan.ctrl_regst_desc_info().ctrl_regst_desc_id2producer_task_id()) {\n    CHECK(ctrl_regst_desc_id2producer_task_id_.emplace(pair.first, pair.second).second);\n  }\n}\n\nvoid RegstMgr::AddPlan(const Plan& plan) {\n  HashMap<std::string, vm::EagerBlobObject*> variable_op_name2eager_blob_object;\n  AddPlan(plan, variable_op_name2eager_blob_object);\n}\n\nvoid RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto,\n                         std::function<void(Regst*)> OneRegstDone) {\n  const int64_t regst_desc_id = regst_desc_proto.regst_desc_id();\n  const 
RegstDescTypeProto& regst_desc_type = regst_desc_proto.regst_desc_type();\n  const RtRegstDesc* rt_regst_desc = regst_desc_id2rt_regst_desc_.at(regst_desc_id).get();\n  char* main_mem_ptr = nullptr;\n  char* separated_header_mem_ptr = nullptr;\n  int64_t mem_block_id = regst_desc_proto.mem_block_id();\n  int64_t header_block_id = regst_desc_proto.separated_header_mem_block_id();\n  if (mem_block_id != -1 && mem_block_id2ptr_.find(mem_block_id) != mem_block_id2ptr_.end()) {\n    main_mem_ptr = mem_block_id2ptr_.at(mem_block_id) + regst_desc_proto.mem_block_offset();\n  }\n  if (header_block_id != -1 && mem_block_id2ptr_.find(header_block_id) != mem_block_id2ptr_.end()) {\n    separated_header_mem_ptr = mem_block_id2ptr_.at(header_block_id);\n  }\n  RegstAllocationType allocation_type = stream_ordered_allocation_mem_block_ids_.find(mem_block_id)\n                                                == stream_ordered_allocation_mem_block_ids_.end()\n                                            ? RegstAllocationType::kStatic\n                                            : RegstAllocationType::kStreamOrdered;\n  for (int64_t i = 0; i < rt_regst_desc->register_num(); ++i) {\n    Regst* regst = new Regst(rt_regst_desc, allocation_type);\n    if (regst_desc_type.has_data_regst_desc()) {\n      InitDataRegst(regst, main_mem_ptr, separated_header_mem_ptr);\n      if (main_mem_ptr != nullptr) { main_mem_ptr += rt_regst_desc->MainByteSize4OneRegst(); }\n      if (separated_header_mem_ptr != nullptr) {\n        separated_header_mem_ptr += rt_regst_desc->SeparatedHeaderByteSize4OneRegst();\n      }\n    } else if (regst_desc_type.has_ctrl_regst_desc()) {\n      // do nothing\n    } else {\n      UNIMPLEMENTED();\n    }\n    OneRegstDone(regst);\n  }\n}\n\nconst RtRegstDesc& RegstMgr::RegstDesc4RegstDescId(int64_t regst_desc_id) const {\n  const auto& it = regst_desc_id2rt_regst_desc_.find(regst_desc_id);\n  CHECK(it != regst_desc_id2rt_regst_desc_.end());\n  return *it->second;\n}\n\nbool RegstMgr::HasRegstDescId(int64_t regst_desc_id) const {\n  return regst_desc_id2rt_regst_desc_.find(regst_desc_id) != regst_desc_id2rt_regst_desc_.end();\n}\n\nint64_t RegstMgr::ProducerTaskId4RegstDescId(int64_t regst_desc_id) const {\n  const auto& it = ctrl_regst_desc_id2producer_task_id_.find(regst_desc_id);\n  CHECK(it != ctrl_regst_desc_id2producer_task_id_.end());\n  return it->second;\n}\n\nbool RegstMgr::HasProducerTaskId4RegstDescId(int64_t regst_desc_id) const {\n  return ctrl_regst_desc_id2producer_task_id_.find(regst_desc_id)\n         != ctrl_regst_desc_id2producer_task_id_.end();\n}\n\n}  // namespace oneflow\n"
  },
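For blocks that are neither chunk-backed nor variable-backed, `RegstMgr::AddPlan` packs all blocks of one memory zone into a single allocation: the blocks are sorted by `(thrd_id_hint, mem_block_id)` and laid out back to back, so every block's pointer is `base + running offset` and the offsets sum to the allocation size. A standalone sketch of that layout step (toy `FakeMemBlock` in place of `MemBlockProto`; illustrative only):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

struct FakeMemBlock {  // stand-in for MemBlockProto
  int64_t mem_block_id;
  int64_t thrd_id_hint;
  int64_t mem_size;
};

// Assign each block a deterministic offset inside one packed chunk,
// mirroring the sort + running-offset loop in RegstMgr::AddPlan.
std::unordered_map<int64_t, int64_t> PackBlocks(std::vector<const FakeMemBlock*>& blocks,
                                                int64_t total_size) {
  std::sort(blocks.begin(), blocks.end(), [](const FakeMemBlock* lhs, const FakeMemBlock* rhs) {
    if (lhs->thrd_id_hint == rhs->thrd_id_hint) { return lhs->mem_block_id < rhs->mem_block_id; }
    return lhs->thrd_id_hint < rhs->thrd_id_hint;
  });
  std::unordered_map<int64_t, int64_t> block_id2offset;
  int64_t offset = 0;
  for (const FakeMemBlock* block : blocks) {
    block_id2offset.emplace(block->mem_block_id, offset);
    offset += block->mem_size;
  }
  assert(offset == total_size);  // the packed chunk was sized as the sum of its blocks
  return block_id2offset;
}

int main() {
  FakeMemBlock b0{/*id=*/0, /*thrd=*/1, /*size=*/256};
  FakeMemBlock b1{/*id=*/1, /*thrd=*/0, /*size=*/128};
  std::vector<const FakeMemBlock*> blocks{&b0, &b1};
  auto offsets = PackBlocks(blocks, 384);
  assert(offsets.at(1) == 0 && offsets.at(0) == 128);  // b1 sorts first (smaller thrd hint)
}
```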
  {
    "path": "oneflow/core/register/register_manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_REGISTER_REGISTER_MANAGER_H_\n#define ONEFLOW_CORE_REGISTER_REGISTER_MANAGER_H_\n\n#include <mutex>\n\n#include \"oneflow/core/job/id_manager.h\"\n#include \"oneflow/core/job/plan.pb.h\"\n#include \"oneflow/core/job/runtime_context.h\"\n#include \"oneflow/core/memory/memory_allocator.h\"\n#include \"oneflow/core/register/blob.h\"\n#include \"oneflow/core/register/logical_blob_id.pb.h\"\n#include \"oneflow/core/register/register.h\"\n#include \"oneflow/core/record/record.pb.h\"\n\nnamespace oneflow {\n\nnamespace vm {\nclass EagerBlobObject;\n}\n\nclass RegstMgr final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RegstMgr);\n  RegstMgr();\n  ~RegstMgr() = default;\n\n  void AddPlan(\n      const Plan& plan,\n      const HashMap<std::string, vm::EagerBlobObject*>& variable_op_name2eager_blob_object);\n  void AddPlan(const Plan& plan);\n  void NewRegsts(const RegstDescProto& regst_desc_proto, std::function<void(Regst*)> OneRegstDone);\n  const RtRegstDesc& RegstDesc4RegstDescId(int64_t regst_desc_id) const;\n  bool HasRegstDescId(int64_t regst_desc_id) const;\n  int64_t ProducerTaskId4RegstDescId(int64_t regst_desc_id) const;\n  bool HasProducerTaskId4RegstDescId(int64_t regst_desc_id) const;\n\n private:\n  bool IsStreamOrderedMemoryAllocationCase(const MemoryCase& mem_case) const;\n\n  HashMap<int64_t, std::unique_ptr<const RtRegstDesc>> regst_desc_id2rt_regst_desc_;\n  HashMap<int64_t, char*> mem_block_id2ptr_;\n  HashSet<int64_t> stream_ordered_allocation_mem_block_ids_;\n  HashMap<int64_t, int64_t> ctrl_regst_desc_id2producer_task_id_;\n  std::mutex mutex_;\n  bool stream_ordered_memory_allocation_enabled_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_REGISTER_REGISTER_MANAGER_H_\n"
  },
  {
    "path": "oneflow/core/register/runtime_register_desc.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/register/runtime_register_desc.h\"\n#include \"oneflow/core/memory/memory_case_util.h\"\n#include \"oneflow/core/common/protobuf.h\"\n\nnamespace oneflow {\n\nRtRegstDesc::RtRegstDesc(const RegstDescProto& proto)\n    : one_regst_header_size_(0), one_regst_body_size_(0) {\n  regst_desc_id_ = proto.regst_desc_id();\n  producer_actor_id_ = proto.producer_task_id();\n  consumers_actor_id_ = PbRf2StdVec(proto.consumer_task_id());\n  register_num_ = proto.register_num();\n  mem_case_ = proto.mem_case();\n  regst_desc_type_ = proto.regst_desc_type();\n  if (proto.regst_desc_type().has_data_regst_desc()) {\n    const DataRegstDesc& data_regst_desc = proto.regst_desc_type().data_regst_desc();\n    std::vector<LbiBlobDescPair> lbi_pairs(\n        {data_regst_desc.lbi2blob_desc().cbegin(), data_regst_desc.lbi2blob_desc().cend()});\n    std::sort(lbi_pairs.begin(), lbi_pairs.end(), &CompareLbiBlobDescPair);\n    CHECK_EQ(lbi_pairs.size(), 1);\n    sorted_blob_desc_vec_.reserve(lbi_pairs.size());\n    sorted_lbi_vec_.reserve(lbi_pairs.size());\n    for (int64_t i = 0; i < lbi_pairs.size(); ++i) {\n      const LbiBlobDescPair& pair = lbi_pairs.at(i);\n      sorted_blob_desc_vec_.emplace_back(std::make_unique<const BlobDesc>(pair.blob_desc()));\n      sorted_lbi_vec_.emplace_back(pair.lbi());\n      lbi2blob_desc_ordinal_.emplace(pair.lbi(), i);\n    }\n    CHECK(data_regst_desc.has_time_shape());\n    data_regst_time_shape_.reset(new Shape(data_regst_desc.time_shape()));\n  } else {\n    sorted_blob_desc_vec_.emplace_back(\n        std::make_unique<const BlobDesc>(BlobDesc(DataType::kChar, MemoryFormat::kContiguous)));\n  }\n  for (const auto& blob_desc_ : sorted_blob_desc_vec_) {\n    one_regst_header_size_ += blob_desc_->AlignedByteSizeOfBlobHeader();\n    one_regst_body_size_ += blob_desc_->AlignedByteSizeOfBlobBody();\n  }\n\n  if ((!memory::IsHostMem(proto.mem_case()))\n      || (proto.has_variable_op_name() && !proto.variable_op_name().empty())) {\n    // NOTE(chengcheng): When this regst is shared with EagerBlobObject, header is ALWAYS separated.\n    has_separated_header_ = true;\n  } else {\n    has_separated_header_ = false;\n  }\n}\n\nint64_t RtRegstDesc::GetOrdinalForLbi(const LogicalBlobId& lbi) const {\n  auto it = lbi2blob_desc_ordinal_.find(lbi);\n  if (it != lbi2blob_desc_ordinal_.cend()) {\n    return it->second;\n  } else {\n    return -1;\n  }\n}\n\nconst BlobDesc* RtRegstDesc::GetBlobDescFromLbi(const LogicalBlobId& lbi) const {\n  auto it = lbi2blob_desc_ordinal_.find(lbi);\n  if (it == lbi2blob_desc_ordinal_.end()) {\n    return nullptr;\n  } else {\n    return GetBlobDescByOrdinal(it->second);\n  }\n}\n\nconst BlobDesc* RtRegstDesc::GetBlobDescByOrdinal(int64_t ordinal) const {\n  return sorted_blob_desc_vec_.at(ordinal).get();\n}\n\nconst LogicalBlobId& RtRegstDesc::GetLbiByOrdinal(int64_t ordinal) const {\n  return 
sorted_lbi_vec_.at(ordinal);\n}\n\nconst BlobDesc* RtRegstDesc::GetSoleBlobDesc() const {\n  CHECK_EQ(sorted_blob_desc_vec_.size(), 1);\n  return sorted_blob_desc_vec_.at(0).get();\n}\n\nsize_t RtRegstDesc::TotalByteSize4AllRegst() const {\n  return (one_regst_header_size_ + one_regst_body_size_) * register_num_;\n}\n\nsize_t RtRegstDesc::TotalMainByteSize4AllRegst() const {\n  return MainByteSize4OneRegst() * register_num_;\n}\n\nsize_t RtRegstDesc::TotalBodyByteSize4AllRegst() const {\n  return BodyByteSize4OneRegst() * register_num_;\n}\n\nsize_t RtRegstDesc::MainByteSize4OneRegst() const {\n  if (has_separated_header_) {\n    return one_regst_body_size_;\n  } else {\n    return one_regst_body_size_ + one_regst_header_size_;\n  }\n}\n\nsize_t RtRegstDesc::BodyByteSize4OneRegst() const { return one_regst_body_size_; }\n\nsize_t RtRegstDesc::HeaderByteSize4OneRegst() const { return one_regst_header_size_; }\n\nsize_t RtRegstDesc::TotalSeparatedHeaderByteSize4AllRegst() const {\n  return SeparatedHeaderByteSize4OneRegst() * register_num_;\n}\n\nsize_t RtRegstDesc::SeparatedHeaderByteSize4OneRegst() const {\n  if (has_separated_header_) {\n    // NOTE(chengcheng): Header size need to be aligned for XRT memory allocate\n    return one_regst_header_size_;\n  } else {\n    return 0;\n  }\n}\n\nconst Shape& RtRegstDesc::data_regst_time_shape() const {\n  CHECK(regst_desc_type_.has_data_regst_desc());\n  CHECK(data_regst_time_shape_);\n  return *data_regst_time_shape_;\n}\n\nvoid RtRegstDesc::ForEachBlobDescOffsetInOnRegst(\n    const std::function<void(int64_t ordinal, const LogicalBlobId& lbi, const BlobDesc* desc,\n                             int64_t body_offset, int64_t header_offset)>& Handler) const {\n  int64_t cur_body_offset = 0;\n  int64_t cur_header_offset = 0;\n  for (int64_t i = 0; i < sorted_blob_desc_vec_.size(); ++i) {\n    const BlobDesc* blob_desc = sorted_blob_desc_vec_.at(i).get();\n    const LogicalBlobId& lbi = sorted_lbi_vec_.at(i);\n    Handler(i, lbi, blob_desc, cur_body_offset, cur_header_offset);\n    cur_body_offset += blob_desc->AlignedByteSizeOfBlobBody();\n    cur_header_offset += blob_desc->AlignedByteSizeOfBlobHeader();\n  }\n}\n\n}  // namespace oneflow\n"
  },
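All of the size accessors in `runtime_register_desc.cpp` reduce to one rule: the main segment of a register always holds the body, and also holds the header unless the header is separated (which happens for non-host memory or variable-backed regsts). A compact standalone restatement of that arithmetic (names are illustrative):

```cpp
#include <cassert>
#include <cstddef>

struct RegstSizes {  // stand-in for the RtRegstDesc size fields
  size_t header_size;
  size_t body_size;
  bool has_separated_header;

  // Main segment: body only when the header lives in its own host buffer.
  size_t MainByteSize4OneRegst() const {
    return has_separated_header ? body_size : body_size + header_size;
  }
  size_t SeparatedHeaderByteSize4OneRegst() const {
    return has_separated_header ? header_size : 0;
  }
  size_t TotalByteSize4AllRegst(size_t register_num) const {
    return (header_size + body_size) * register_num;
  }
};

int main() {
  RegstSizes s{/*header_size=*/64, /*body_size=*/4096, /*has_separated_header=*/true};
  // However the header is placed, one register always covers header + body bytes.
  assert(s.MainByteSize4OneRegst() + s.SeparatedHeaderByteSize4OneRegst() == 64 + 4096);
  assert(s.TotalByteSize4AllRegst(2) == 2 * (64 + 4096));
}
```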
  {
    "path": "oneflow/core/register/runtime_register_desc.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_REGISTER_RUNTIME_REGISTER_DESC_H_\n#define ONEFLOW_CORE_REGISTER_RUNTIME_REGISTER_DESC_H_\n\n#include \"oneflow/core/memory/memory_case.pb.h\"\n#include \"oneflow/core/register/blob_desc.h\"\n#include \"oneflow/core/register/register_desc.pb.h\"\n\nnamespace oneflow {\n\nclass RtRegstDesc {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RtRegstDesc);\n  RtRegstDesc() = delete;\n  ~RtRegstDesc() = default;\n\n  RtRegstDesc(const RegstDescProto& regst_desc_proto);\n\n  int64_t regst_desc_id() const { return regst_desc_id_; }\n  int64_t producer_actor_id() const { return producer_actor_id_; }\n  const std::vector<int64_t>& consumers_actor_id() const { return consumers_actor_id_; }\n  int64_t register_num() const { return register_num_; }\n  const MemoryCase& mem_case() const { return mem_case_; }\n  const RegstDescTypeProto& regst_desc_type() const { return regst_desc_type_; }\n\n  int64_t lbi_num() const { return sorted_lbi_vec_.size(); }\n  int64_t GetOrdinalForLbi(const LogicalBlobId& lbi) const;\n  const BlobDesc* GetBlobDescFromLbi(const LogicalBlobId& lbi) const;\n  const BlobDesc* GetBlobDescByOrdinal(int64_t ordinal) const;\n  const BlobDesc* GetSoleBlobDesc() const;\n  const LogicalBlobId& GetLbiByOrdinal(int64_t ordinal) const;\n  size_t TotalByteSize4AllRegst() const;\n  size_t TotalMainByteSize4AllRegst() const;\n  size_t TotalBodyByteSize4AllRegst() const;\n  size_t TotalSeparatedHeaderByteSize4AllRegst() const;\n  size_t SeparatedHeaderByteSize4OneRegst() const;\n  size_t MainByteSize4OneRegst() const;\n  size_t BodyByteSize4OneRegst() const;\n  size_t HeaderByteSize4OneRegst() const;\n  const Shape& data_regst_time_shape() const;\n\n  void ForEachBlobDescOffsetInOnRegst(\n      const std::function<void(int64_t ordinal, const LogicalBlobId& lbi, const BlobDesc* desc,\n                               int64_t body_offset, int64_t header_offset)>& Handler) const;\n\n private:\n  int64_t regst_desc_id_;\n  int64_t producer_actor_id_;\n  std::vector<int64_t> consumers_actor_id_;\n  int64_t register_num_;\n  RegstDescTypeProto regst_desc_type_;\n  MemoryCase mem_case_;\n  HashMap<LogicalBlobId, int64_t> lbi2blob_desc_ordinal_;\n  std::unique_ptr<Shape> data_regst_time_shape_;\n  std::vector<std::unique_ptr<const BlobDesc>> sorted_blob_desc_vec_;\n  std::vector<LogicalBlobId> sorted_lbi_vec_;\n\n  bool has_separated_header_;\n  size_t one_regst_header_size_;\n  size_t one_regst_body_size_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_REGISTER_RUNTIME_REGISTER_DESC_H_\n"
  },
  {
    "path": "oneflow/core/register/tensor_slice_copier.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/register/tensor_slice_copier.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nTensorSliceView GetRawTenserSliceView(const TensorSliceView& view, DataType data_type) {\n  const size_t size_of_data_type = GetSizeOfDataType(data_type);\n  if (size_of_data_type == 1) {\n    return view;\n  } else {\n    std::vector<Range> range_vec = view.range_vec();\n    if (!view.IsEmpty()) {\n      range_vec.back().mut_begin() = range_vec.back().begin() * size_of_data_type;\n      range_vec.back().mut_end() = range_vec.back().end() * size_of_data_type;\n    }\n    return TensorSliceView(range_vec);\n  }\n}\n\n}  // namespace\n\nTensorSliceCopier::TensorSliceCopier(const TensorSliceView& dst_view,\n                                     const TensorSliceView& src_view,\n                                     const TensorSliceView& copy_view, const DataType data_type,\n                                     const DeviceType device_type)\n    : dst_view_(dst_view), src_view_(src_view), extent_(copy_view.shape()), data_type_(data_type) {\n  copy_nd_primitive_ = ep::primitive::NewPrimitive<ep::primitive::CopyNdFactory>(\n      device_type, dst_view_.shape().NumAxes());\n  CHECK(dst_view.Contains(copy_view));\n  CHECK(src_view.Contains(copy_view));\n  dst_pos_ = copy_view.OffsetTo(dst_view);\n  src_pos_ = copy_view.OffsetTo(src_view);\n}\n\nTensorSliceCopier::TensorSliceCopier(const TensorSliceView& dst_view,\n                                     const TensorSliceView& src_view, const DataType data_type,\n                                     const DeviceType device_type)\n    : TensorSliceCopier(dst_view, src_view, dst_view.Intersect(src_view), data_type, device_type) {}\n\nvoid TensorSliceCopier::Copy(ep::Stream* stream, void* dst, const void* src) const {\n  copy_nd_primitive_->Launch(stream, data_type_, dst_view_.shape().NumAxes(), dst,\n                             dst_view_.shape().dim_vec().data(), dst_pos_.dim_vec().data(), src,\n                             src_view_.shape().dim_vec().data(), src_pos_.dim_vec().data(),\n                             extent_.dim_vec().data());\n}\n\nvoid TensorSliceCopier::Copy(ep::Stream* stream, Blob* dst_blob, const Blob* src_blob) const {\n  CHECK_EQ(dst_blob->data_type(), data_type_);\n  CHECK_EQ(src_blob->data_type(), data_type_);\n  CHECK_EQ(dst_view_.shape().elem_cnt(), dst_blob->shape().elem_cnt());\n  CHECK_EQ(src_view_.shape().elem_cnt(), src_blob->shape().elem_cnt());\n  Copy(stream, dst_blob->mut_dptr(), src_blob->dptr());\n}\n\n}  // namespace oneflow\n"
  },
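`TensorSliceCopier` reduces an N-d slice copy to four quantities: the copy region (by default the intersection of the two views), its offset inside the destination view, its offset inside the source view, and its extent; the actual data movement is then delegated to the `CopyNd` primitive. A standalone one-axis sketch of that geometry with half-open ranges (illustrative names):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

struct Range1d {  // half-open [begin, end); one axis of a TensorSliceView
  int64_t begin;
  int64_t end;
};

// The per-axis quantities the TensorSliceCopier constructor derives.
struct CopyPlan1d {
  int64_t dst_pos;  // offset of the copy region inside the destination view
  int64_t src_pos;  // offset of the copy region inside the source view
  int64_t extent;   // number of elements to copy on this axis
};

CopyPlan1d MakeCopyPlan(const Range1d& dst, const Range1d& src) {
  // Copy region = intersection of the two views (Intersect in the real code).
  Range1d copy{std::max(dst.begin, src.begin), std::min(dst.end, src.end)};
  assert(copy.begin <= copy.end);  // the views must overlap
  // Positions are what OffsetTo computes: begin minus the enclosing view's begin.
  return CopyPlan1d{copy.begin - dst.begin, copy.begin - src.begin, copy.end - copy.begin};
}

int main() {
  // dst covers [0, 8), src covers [4, 12): the shared region is [4, 8).
  CopyPlan1d plan = MakeCopyPlan(Range1d{0, 8}, Range1d{4, 12});
  assert(plan.dst_pos == 4 && plan.src_pos == 0 && plan.extent == 4);
}
```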
  {
    "path": "oneflow/core/register/tensor_slice_copier.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_REGISTER_TENSOR_SLICE_COPIER_H_\n#define ONEFLOW_CORE_REGISTER_TENSOR_SLICE_COPIER_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/register/tensor_slice_view.h\"\n#include \"oneflow/core/register/blob.h\"\n#include \"oneflow/core/ep/include/primitive/copy_nd.h\"\n\nnamespace oneflow {\n\nclass TensorSliceCopier final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TensorSliceCopier);\n  TensorSliceCopier(const TensorSliceView& dst_view, const TensorSliceView& src_view,\n                    const TensorSliceView& copy_view, DataType data_type, DeviceType device_type);\n  TensorSliceCopier(const TensorSliceView& dst_view, const TensorSliceView& src_view,\n                    DataType data_type, DeviceType device_type);\n  virtual ~TensorSliceCopier() = default;\n\n  void Copy(ep::Stream* stream, void* dst, const void* src) const;\n  void Copy(ep::Stream* stream, Blob* dst_blob, const Blob* src_blob) const;\n\n private:\n  const TensorSliceView dst_view_;\n  const TensorSliceView src_view_;\n  NdIndex dst_pos_;\n  NdIndex src_pos_;\n  Shape extent_;\n  const DataType data_type_;\n  std::unique_ptr<ep::primitive::CopyNd> copy_nd_primitive_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_REGISTER_TENSOR_SLICE_COPIER_H_\n"
  },
  {
    "path": "oneflow/core/register/tensor_slice_view.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/register/tensor_slice_view.h\"\n\nnamespace oneflow {\n\nTensorSliceView::TensorSliceView(const std::initializer_list<Range>& ranges) : range_vec_(ranges) {\n  UpdateShape();\n}\n\nTensorSliceView::TensorSliceView(const std::vector<Range>& ranges) : range_vec_(ranges) {\n  UpdateShape();\n}\n\nTensorSliceView::TensorSliceView(const TensorSliceViewProto& proto) {\n  range_vec_.resize(proto.dim_size());\n  std::transform(proto.dim().cbegin(), proto.dim().cend(), range_vec_.begin(),\n                 [](const RangeProto& rp) { return Range(rp); });\n  UpdateShape();\n}\n\nTensorSliceView::TensorSliceView(const Shape& shape) {\n  range_vec_.resize(shape.dim_vec().size());\n  std::transform(shape.dim_vec().cbegin(), shape.dim_vec().cend(), range_vec_.begin(),\n                 [](const int64_t dim_size) { return Range(0, dim_size); });\n  UpdateShape();\n}\n\nTensorSliceView& TensorSliceView::operator=(const TensorSliceView& other) {\n  range_vec_ = other.range_vec_;\n  UpdateShape();\n  return *this;\n}\n\nbool TensorSliceView::operator==(const TensorSliceView& rhs) const {\n  return range_vec_ == rhs.range_vec_;\n}\n\nbool TensorSliceView::operator!=(const TensorSliceView& rhs) const { return !(*this == rhs); }\n\nvoid TensorSliceView::UpdateShape() {\n  DimVector dim_vec(range_vec_.size());\n  std::transform(range_vec_.cbegin(), range_vec_.cend(), dim_vec.begin(),\n                 [](const Range& range) { return range.size(); });\n  shape_ = Shape(dim_vec);\n}\n\nbool TensorSliceView::IsEmpty() const { return range_vec_.empty(); }\n\nbool TensorSliceView::Contains(const TensorSliceView& other) const {\n  if (other.IsEmpty()) { return true; }\n  CHECK_EQ(NumAxes(), other.NumAxes());\n  FOR_RANGE(int64_t, i, 0, NumAxes()) {\n    if (range_vec_.at(i).begin() > other.range_vec_.at(i).begin()\n        || range_vec_.at(i).end() < other.range_vec_.at(i).end()) {\n      return false;\n    }\n  }\n  return true;\n}\n\nTensorSliceView TensorSliceView::Intersect(const TensorSliceView& other) const {\n  if (IsEmpty() || other.IsEmpty()) { return TensorSliceView(); }\n  CHECK_EQ(other.range_vec_.size(), range_vec_.size());\n  std::vector<Range> intersection_vec;\n  intersection_vec.reserve(range_vec_.size());\n  const Range zero(0, 0);\n  FOR_RANGE(int64_t, i, 0, range_vec_.size()) {\n    const Range intersection = FindIntersectant(range_vec_.at(i), other.range_vec_.at(i));\n    if (intersection == zero) {\n      return TensorSliceView();\n    } else {\n      intersection_vec.emplace_back(intersection);\n    }\n  }\n  return TensorSliceView(intersection_vec);\n}\n\nconst Range& TensorSliceView::At(int64_t index) const { return range_vec_.at(index); }\n\nconst Shape& TensorSliceView::shape() const { return shape_; }\n\nconst std::vector<Range>& TensorSliceView::range_vec() const { return range_vec_; }\n\nsize_t TensorSliceView::NumAxes() const { return 
range_vec_.size(); }\n\nNdIndex TensorSliceView::OffsetTo(const TensorSliceView& other) const {\n  CHECK_EQ(other.NumAxes(), NumAxes());\n  DimVector indices_vec(range_vec_.size());\n  std::transform(range_vec_.cbegin(), range_vec_.cend(), other.range_vec_.cbegin(),\n                 indices_vec.begin(),\n                 [](const Range& lhs, const Range& rhs) { return lhs.begin() - rhs.begin(); });\n  return NdIndex(indices_vec);\n}\n\nvoid TensorSliceView::ToProto(TensorSliceViewProto* proto) const {\n  for (const Range& range : range_vec_) { range.ToProto(proto->mutable_dim()->Add()); }\n}\n\nTensorSliceView TensorSliceView::Concatenate(std::vector<TensorSliceView>& slices, int64_t axis) {\n  CHECK_GT(slices.size(), 0);\n  const int64_t num_axes = slices.front().shape().NumAxes();\n  FOR_RANGE(int64_t, i, 1, slices.size()) { CHECK_EQ(slices.at(i).NumAxes(), num_axes); }\n  CHECK_GE(axis, 0);\n  CHECK_LT(axis, num_axes);\n  FOR_RANGE(int64_t, i, 0, num_axes) {\n    if (i == axis) {\n      // every adjacent pair must be contiguous along the concatenation axis\n      CHECK(std::adjacent_find(slices.cbegin(), slices.cend(),\n                               [&](const TensorSliceView& lhs, const TensorSliceView& rhs) {\n                                 return lhs.At(i).end() != rhs.At(i).begin();\n                               })\n            == slices.cend());\n    } else {\n      // all slices must agree on every non-concatenation axis\n      const Range& dim_range = slices.front().At(i);\n      CHECK(std::all_of(slices.cbegin() + 1, slices.cend(),\n                        [&](const TensorSliceView& view) { return view.At(i) == dim_range; }));\n    }\n  }\n  std::vector<Range> range_vec = slices.front().range_vec();\n  range_vec.at(axis).mut_end() = slices.back().At(axis).end();\n  return TensorSliceView(range_vec);\n}\n\n}  // namespace oneflow\n"
  },
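`TensorSliceView::Concatenate` only accepts slices that tile the concatenation axis without gaps or overlaps: each slice's range must end exactly where the next one begins, which is what the `std::adjacent_find` check enforces. The same contiguity test in standalone form:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

struct Range1d {  // half-open [begin, end)
  int64_t begin;
  int64_t end;
};

// True iff the ranges sit back to back along the concatenation axis;
// adjacent_find returns the first offending pair, or end() if there is none.
bool IsContiguous(const std::vector<Range1d>& ranges) {
  return std::adjacent_find(ranges.cbegin(), ranges.cend(),
                            [](const Range1d& lhs, const Range1d& rhs) {
                              return lhs.end != rhs.begin;  // gap or overlap
                            })
         == ranges.cend();
}

int main() {
  assert(IsContiguous({{0, 4}, {4, 9}, {9, 16}}));
  assert(!IsContiguous({{0, 4}, {5, 9}}));  // gap between 4 and 5
}
```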
  {
    "path": "oneflow/core/register/tensor_slice_view.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_REGISTER_TENSOR_SLICE_VIEW_H_\n#define ONEFLOW_CORE_REGISTER_TENSOR_SLICE_VIEW_H_\n\n#include \"oneflow/core/common/range.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/nd_index.h\"\n#include \"oneflow/core/register/tensor_slice_view.pb.h\"\n\nnamespace oneflow {\n\nclass TensorSliceView final {\n public:\n  TensorSliceView() = default;\n  TensorSliceView(const std::initializer_list<Range>& ranges);\n  explicit TensorSliceView(const std::vector<Range>& ranges);\n  explicit TensorSliceView(const TensorSliceViewProto& proto);\n  explicit TensorSliceView(const Shape& shape);\n\n  TensorSliceView& operator=(const TensorSliceView& other);\n  bool operator==(const TensorSliceView& rhs) const;\n  bool operator!=(const TensorSliceView& rhs) const;\n\n  bool IsEmpty() const;\n  TensorSliceView Intersect(const TensorSliceView& other) const;\n  bool Contains(const TensorSliceView& other) const;\n  const Range& At(int64_t index) const;\n  const Shape& shape() const;\n  const std::vector<Range>& range_vec() const;\n  size_t NumAxes() const;\n  NdIndex OffsetTo(const TensorSliceView& other) const;\n  void ToProto(TensorSliceViewProto* proto) const;\n\n  static TensorSliceView Concatenate(std::vector<TensorSliceView>& slices, int64_t axis);\n\n private:\n  std::vector<Range> range_vec_;\n  Shape shape_;\n\n  void UpdateShape();\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_REGISTER_TENSOR_SLICE_VIEW_H_\n"
  },
  {
    "path": "oneflow/core/register/tensor_slice_view.proto",
    "content": "syntax = \"proto2\";\npackage oneflow;\n\nimport \"oneflow/core/common/range.proto\";\n\nmessage TensorSliceViewProto {\n    repeated RangeProto dim = 1;\n}\n"
  },
  {
    "path": "oneflow/core/rpc/include/base.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_RPC_INCLUDE_BASE_CTRL_\n#define ONEFLOW_CORE_RPC_INCLUDE_BASE_CTRL_\n\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/control/control.pb.h\"\n#include \"oneflow/core/control/ctrl_bootstrap.pb.h\"\n\nnamespace oneflow {\n\n#define CTRL_METHOD_SEQ               \\\n  OF_PP_MAKE_TUPLE_SEQ(LoadServer)    \\\n  OF_PP_MAKE_TUPLE_SEQ(Barrier)       \\\n  OF_PP_MAKE_TUPLE_SEQ(TryLock)       \\\n  OF_PP_MAKE_TUPLE_SEQ(NotifyDone)    \\\n  OF_PP_MAKE_TUPLE_SEQ(WaitUntilDone) \\\n  OF_PP_MAKE_TUPLE_SEQ(PushKV)        \\\n  OF_PP_MAKE_TUPLE_SEQ(ClearKV)       \\\n  OF_PP_MAKE_TUPLE_SEQ(PullKV)        \\\n  OF_PP_MAKE_TUPLE_SEQ(Clear)         \\\n  OF_PP_MAKE_TUPLE_SEQ(IncreaseCount) \\\n  OF_PP_MAKE_TUPLE_SEQ(EraseCount)\n\n#define CatRequest(method) method##Request,\n#define CatReqponse(method) method##Response,\n#define CatEnum(method) k##method,\n#define CatName(method) \"/oneflow.CtrlService/\" OF_PP_STRINGIZE(method),\n\n#define MAKE_META_DATA()                                                                       \\\n  enum class CtrlMethod { OF_PP_FOR_EACH_TUPLE(CatEnum, CTRL_METHOD_SEQ) };                    \\\n  static const char* g_method_name[] = {OF_PP_FOR_EACH_TUPLE(CatName, CTRL_METHOD_SEQ)};       \\\n  using CtrlRequestTuple = std::tuple<OF_PP_FOR_EACH_TUPLE(CatRequest, CTRL_METHOD_SEQ) void>; \\\n  using CtrlResponseTuple = std::tuple<OF_PP_FOR_EACH_TUPLE(CatReqponse, CTRL_METHOD_SEQ) void>;\n\nMAKE_META_DATA()\n\nconstexpr const size_t kCtrlMethodNum = OF_PP_SEQ_SIZE(CTRL_METHOD_SEQ);\n\ntemplate<CtrlMethod ctrl_method>\nusing CtrlRequest =\n    typename std::tuple_element<static_cast<size_t>(ctrl_method), CtrlRequestTuple>::type;\n\ntemplate<CtrlMethod ctrl_method>\nusing CtrlResponse =\n    typename std::tuple_element<static_cast<size_t>(ctrl_method), CtrlResponseTuple>::type;\n\ninline const char* GetMethodName(CtrlMethod method) {\n  return g_method_name[static_cast<int32_t>(method)];\n}\n\nclass CtrlClient {\n public:\n  explicit CtrlClient(const ProcessCtx& process_ctx);\n  CtrlClient() = default;\n  virtual ~CtrlClient() = default;\n\n  virtual void Barrier(const std::string& barrier_name) = 0;\n  virtual void Barrier(const std::string& barrier_name, int32_t barrier_num) = 0;\n\n  virtual TryLockResult TryLock(const std::string& name) = 0;\n  virtual void NotifyDone(const std::string& name) = 0;\n  virtual void WaitUntilDone(const std::string& name) = 0;\n\n  virtual void PushKV(const std::string& k, std::function<void(std::string*)> VSetter) = 0;\n  virtual void PushKV(const std::string& k, const std::string& v) = 0;\n  virtual void PushKV(const std::string& k, const PbMessage& msg) = 0;\n  virtual void PushMasterKV(const std::string& k, const PbMessage& msg) = 0;\n  template<typename T>\n  typename 
std::enable_if<std::is_arithmetic<T>::value>::type PushKVT(const std::string& k, T v) {\n    PushKV(k, std::to_string(v));\n  }\n\n  virtual void ClearKV(const std::string& k) = 0;\n  virtual void ClearMasterKV(const std::string& k) = 0;\n\n  virtual void PullKV(const std::string& k, std::function<void(const std::string&)> VGetter) = 0;\n  virtual void PullKV(const std::string& k, std::string* v) = 0;\n  virtual void PullKV(const std::string& k, PbMessage* msg) = 0;\n  virtual void PullMasterKV(const std::string& k, PbMessage* msg) = 0;\n  template<typename T>\n  typename std::enable_if<std::is_arithmetic<T>::value>::type PullKVT(const std::string& k, T* v) {\n    std::string v_str;\n    PullKV(k, &v_str);\n    *v = oneflow_cast<T>(v_str);\n  }\n\n  virtual void Clear() = 0;\n  virtual int32_t IncreaseCount(const std::string& k, int32_t v) = 0;\n  int32_t IncreaseCount(const std::string& k) { return IncreaseCount(k, 1); }\n  virtual void EraseCount(const std::string& k) = 0;\n};\n\n#define FILE_LINE_STR __FILE__ \":\" OF_PP_STRINGIZE(__LINE__)\n#define OF_ENV_BARRIER() oneflow::Singleton<oneflow::CtrlClient>::Get()->Barrier(FILE_LINE_STR)\n#define OF_SESSION_BARRIER()                               \\\n  oneflow::Singleton<oneflow::CtrlClient>::Get()->Barrier( \\\n      FILE_LINE_STR, Singleton<ResourceDesc, ForSession>::Get()->process_ranks().size())\n\nstatic void OfCallOnce(const std::string& name, std::function<void()> f) {\n  TryLockResult lock_ret = Singleton<CtrlClient>::Get()->TryLock(name);\n  if (lock_ret == TryLockResult::kLocked) {\n    f();\n    Singleton<CtrlClient>::Get()->NotifyDone(name);\n  } else if (lock_ret == TryLockResult::kDone) {\n  } else if (lock_ret == TryLockResult::kDoing) {\n    Singleton<CtrlClient>::Get()->WaitUntilDone(name);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename Self, typename F, typename Arg, typename... Args>\nstatic void OfCallOnce(const std::string& name, Self self, F f, Arg&& arg, Args&&... args) {\n  std::function<void()> fn =\n      std::bind(f, self, std::forward<Arg>(arg), std::forward<Args>(args)...);\n  OfCallOnce(name, std::move(fn));\n}\n\ntemplate<typename Self, typename F>\nstatic void OfCallOnce(const std::string& name, Self self, F f) {\n  std::function<void()> fn = std::bind(f, self, name);\n  OfCallOnce(name, std::move(fn));\n}\n\ntemplate<typename F, typename Arg, typename... Args>\nstatic void OfCallOnce(const std::string& name, F f, Arg&& arg, Args&&... args) {\n  std::function<void()> fn = std::bind(f, std::forward<Arg>(arg), std::forward<Args>(args)...);\n  OfCallOnce(name, std::move(fn));\n}\n\nclass RpcManager {\n public:\n  RpcManager() = default;\n  virtual ~RpcManager() = default;\n  virtual Maybe<void> Bootstrap() = 0;\n  virtual Maybe<void> CreateServer() = 0;\n  virtual Maybe<void> CreateClient() = 0;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_RPC_INCLUDE_BASE_CTRL_\n"
  },
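`OfCallOnce` in `base.h` is effectively a distributed once-flag built on three control-plane states: the first caller to win `TryLock` (`kLocked`) runs `f` and posts `NotifyDone`; callers that arrive while it runs see `kDoing` and block in `WaitUntilDone`; late callers see `kDone` and return immediately. A process-local analogue of the same protocol using a mutex and condition variable (illustrative, not the RPC-backed implementation):

```cpp
#include <condition_variable>
#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>

enum class TryLockResult { kLocked, kDoing, kDone };  // mirrors the control-plane states

class LocalOnceRegistry {
 public:
  // Run f exactly once per name; concurrent callers block until the winner finishes.
  void CallOnce(const std::string& name, const std::function<void()>& f) {
    TryLockResult r = TryLock(name);
    if (r == TryLockResult::kLocked) {
      f();
      NotifyDone(name);
    } else if (r == TryLockResult::kDoing) {
      WaitUntilDone(name);
    }  // kDone: already ran, nothing to do
  }

 private:
  TryLockResult TryLock(const std::string& name) {
    std::lock_guard<std::mutex> lock(mtx_);
    auto it = state_.find(name);
    if (it == state_.end()) {
      state_[name] = TryLockResult::kDoing;
      return TryLockResult::kLocked;  // this caller owns the critical section
    }
    return it->second;  // kDoing or kDone
  }
  void NotifyDone(const std::string& name) {
    std::lock_guard<std::mutex> lock(mtx_);
    state_[name] = TryLockResult::kDone;
    cv_.notify_all();
  }
  void WaitUntilDone(const std::string& name) {
    std::unique_lock<std::mutex> lock(mtx_);
    cv_.wait(lock, [&] { return state_.at(name) == TryLockResult::kDone; });
  }

  std::mutex mtx_;
  std::condition_variable cv_;
  std::unordered_map<std::string, TryLockResult> state_;
};
```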
  {
    "path": "oneflow/core/rpc/include/ctrl.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_RPC_INCLUDE_CTRL_\n#define ONEFLOW_CORE_RPC_INCLUDE_CTRL_\n\n#ifdef RPC_BACKEND_GRPC\n#include \"oneflow/core/rpc/include/grpc.h\"\n#endif  // RPC_BACKEND_GRPC\n\n#ifdef RPC_BACKEND_LOCAL\n#include \"oneflow/core/rpc/include/local.h\"\n#endif  // RPC_BACKEND_LOCAL\n\n#endif  // ONEFLOW_CORE_RPC_INCLUDE_CTRL_\n"
  },
  {
    "path": "oneflow/core/rpc/include/global_process_ctx.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_RPC_INCLUDE_GLOBAL_PROCESS_CTX_\n#define ONEFLOW_CORE_RPC_INCLUDE_GLOBAL_PROCESS_CTX_\n\n#include <string>\n\nnamespace oneflow {\n\nstruct GlobalProcessCtx {\n  static void GetMachineIdAndDeviceId(int64_t rank, int64_t* machine_id, int64_t* device_id);\n  static void GetCurrentMachineIdAndDeviceId(int64_t* machine_id, int64_t* device_id);\n  static int64_t Rank();\n  static int64_t LocalRank();\n  static int64_t LocalRank(int64_t rank);\n  static int64_t NodeId(int64_t process_id);\n  static int64_t NodeSize();\n  static int64_t ThisNodeId();\n  static int64_t NumOfProcessPerNode();\n  static bool IsThisProcessMaster();\n  static size_t WorldSize();\n  static std::string LogDirEntry();\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_RPC_INCLUDE_GLOBAL_PROCESS_CTX_\n"
  },
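`GlobalProcessCtx` assumes a flat, even topology: `WorldSize()` ranks split across `NodeSize()` equal nodes, so node id and local rank are just integer division and remainder (see the definitions in `global_process_ctx.cpp` below). A standalone restatement of that arithmetic:

```cpp
#include <cassert>
#include <cstdint>

// world_size ranks split evenly across node_size nodes, the layout
// GlobalProcessCtx assumes (it checks WorldSize() % NodeSize() == 0).
struct ProcessTopology {
  int64_t world_size;
  int64_t node_size;

  int64_t NumOfProcessPerNode() const { return world_size / node_size; }
  int64_t NodeId(int64_t rank) const { return rank / NumOfProcessPerNode(); }
  int64_t LocalRank(int64_t rank) const { return rank % NumOfProcessPerNode(); }
};

int main() {
  ProcessTopology topo{/*world_size=*/8, /*node_size=*/2};  // 2 nodes x 4 ranks each
  assert(topo.NumOfProcessPerNode() == 4);
  assert(topo.NodeId(5) == 1);     // rank 5 lives on the second node
  assert(topo.LocalRank(5) == 1);  // and is its second process
}
```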
  {
    "path": "oneflow/core/rpc/include/grpc.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_RPC_INCLUDE_GRPC_H_\n#define ONEFLOW_CORE_RPC_INCLUDE_GRPC_H_\n\n#include \"oneflow/core/control/rpc_client.h\"\n#include \"oneflow/core/rpc/include/base.h\"\n#include \"oneflow/core/control/ctrl_bootstrap.pb.h\"\n\nnamespace oneflow {\n\nclass GrpcCtrlClient final : public CtrlClient {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(GrpcCtrlClient);\n  explicit GrpcCtrlClient(const ProcessCtx& process_ctx);\n  ~GrpcCtrlClient() override;\n\n  void Barrier(const std::string& barrier_name) override;\n  void Barrier(const std::string& barrier_name, int32_t barrier_num) override;\n\n  TryLockResult TryLock(const std::string& name) override;\n  void NotifyDone(const std::string& name) override;\n  void WaitUntilDone(const std::string& name) override;\n\n  void PushKV(const std::string& k, std::function<void(std::string*)> VSetter) override;\n  void PushKV(const std::string& k, const std::string& v) override;\n  void PushKV(const std::string& k, const PbMessage& msg) override;\n  void PushMasterKV(const std::string& k, const PbMessage& msg) override;\n\n  void ClearKV(const std::string& k) override;\n  void ClearMasterKV(const std::string& k) override;\n\n  void PullKV(const std::string& k, std::function<void(const std::string&)> VGetter) override;\n  void PullKV(const std::string& k, std::string* v) override;\n  void PullKV(const std::string& k, PbMessage* msg) override;\n  void PullMasterKV(const std::string& k, PbMessage* msg) override;\n  void Clear() override;\n  int32_t IncreaseCount(const std::string& k, int32_t v) override;\n  void EraseCount(const std::string& k) override;\n  void StopHeartbeat();\n\n private:\n  const ProcessCtx& process_ctx() const { return process_ctx_; }\n  ProcessCtx process_ctx_;\n  bool need_heartbeat_thread_stop_;\n  std::mutex need_heartbeat_thread_stop_mtx_;\n  std::condition_variable need_heartbeat_thread_stop_cv_;\n  std::thread heartbeat_thread_;\n  RpcClient rpc_client_;\n};\n\nclass GrpcRpcManager : public RpcManager {\n public:\n  GrpcRpcManager() = default;\n  ~GrpcRpcManager() override;\n  Maybe<void> Bootstrap() override;\n  Maybe<void> CreateServer() override;\n  Maybe<void> CreateClient() override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_RPC_INCLUDE_GRPC_H_\n"
  },
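`GrpcCtrlClient` owns a heartbeat thread whose shutdown is coordinated through the `need_heartbeat_thread_stop_` flag, mutex, and condition variable declared above; the usual shape of such a loop is a `wait_for` with a timeout, so the thread both ticks periodically and can be woken early to exit. A generic sketch of that pattern (the interval and the `beat` callback are placeholders, not the client's actual RPC):

```cpp
#include <chrono>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <thread>

class HeartbeatThread {
 public:
  // Invoke beat() every `interval` until Stop() is called.
  HeartbeatThread(std::function<void()> beat, std::chrono::seconds interval)
      : thread_([this, beat, interval] {
          std::unique_lock<std::mutex> lock(mtx_);
          // wait_for returns true as soon as stop_ is set (exit the loop);
          // otherwise it times out and we send one more heartbeat.
          while (!cv_.wait_for(lock, interval, [this] { return stop_; })) { beat(); }
        }) {}

  void Stop() {
    {
      std::lock_guard<std::mutex> lock(mtx_);
      stop_ = true;
    }
    cv_.notify_all();
  }

  ~HeartbeatThread() {
    Stop();
    thread_.join();
  }

 private:
  bool stop_ = false;
  std::mutex mtx_;
  std::condition_variable cv_;
  std::thread thread_;  // declared last: the flag, mutex, and cv must outlive the loop
};
```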
  {
    "path": "oneflow/core/rpc/include/local.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_RPC_INCLUDE_LOCAL_H_\n#define ONEFLOW_CORE_RPC_INCLUDE_LOCAL_H_\n\n#include <string>\n#include <unordered_map>\n#include \"oneflow/core/common/blocking_counter.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/control/ctrl_bootstrap.pb.h\"\n#include \"oneflow/core/rpc/include/base.h\"\n\nnamespace oneflow {\n\nclass LocalCtrlClient : public CtrlClient {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LocalCtrlClient);\n  explicit LocalCtrlClient(const ProcessCtx& process_ctx);\n  ~LocalCtrlClient() override = default;\n\n  void Barrier(const std::string& barrier_name) override;\n  void Barrier(const std::string& barrier_name, int32_t barrier_num) override;\n\n  TryLockResult TryLock(const std::string& name) override;\n  void NotifyDone(const std::string& name) override;\n  void WaitUntilDone(const std::string& name) override;\n\n  void PushKV(const std::string& k, std::function<void(std::string*)> VSetter) override;\n  void PushKV(const std::string& k, const std::string& v) override;\n  void PushKV(const std::string& k, const PbMessage& msg) override;\n  void PushMasterKV(const std::string& k, const PbMessage& msg) override;\n\n  void ClearKV(const std::string& k) override;\n  void ClearMasterKV(const std::string& k) override;\n\n  void PullKV(const std::string& k, std::function<void(const std::string&)> VGetter) override;\n  void PullKV(const std::string& k, std::string* v) override;\n  void PullKV(const std::string& k, PbMessage* msg) override;\n  void PullMasterKV(const std::string& k, PbMessage* msg) override;\n  void Clear() override;\n  int32_t IncreaseCount(const std::string& k, int32_t v) override;\n  void EraseCount(const std::string& k) override;\n\n  HashSet<std::string> done_names_;\n  HashSet<std::string> doing_names_;\n  std::mutex done_names_mtx_;\n  std::condition_variable done_names_cv_;\n  HashMap<std::string, std::string> kv_;\n  std::mutex kv_mtx_;\n  std::condition_variable kv_cv_;\n  HashMap<std::string, int32_t> counter_;\n  std::mutex counter_mtx_;\n  HashMap<std::string, std::shared_ptr<BlockingCounter>> barrier_counter_;\n  std::mutex barrier_counter_mtx_;\n};\n\nclass LocalRpcManager : public RpcManager {\n public:\n  LocalRpcManager() = default;\n  ~LocalRpcManager() override;\n  Maybe<void> Bootstrap() override;\n  Maybe<void> CreateServer() override { return Maybe<void>::Ok(); }\n  Maybe<void> CreateClient() override;\n};\n\nclass DryRunRpcManager : public RpcManager {\n public:\n  DryRunRpcManager() = default;\n  ~DryRunRpcManager() override;\n  Maybe<void> Bootstrap() override;\n  Maybe<void> CreateServer() override { return Maybe<void>::Ok(); }\n  Maybe<void> CreateClient() override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_RPC_INCLUDE_LOCAL_H_\n"
  },
  {
    "path": "oneflow/core/rpc/include/manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_RPC_INCLUDE_MANAGER_H_\n#define ONEFLOW_CORE_RPC_INCLUDE_MANAGER_H_\n\n#ifdef RPC_BACKEND_GRPC\n#include \"oneflow/core/rpc/include/grpc.h\"\n#endif  // RPC_BACKEND_GRPC\n\n#ifdef RPC_BACKEND_LOCAL\n#include \"oneflow/core/rpc/include/local.h\"\n#endif  // RPC_BACKEND_LOCAL\n\n#endif  // ONEFLOW_CORE_RPC_INCLUDE_MANAGER_H_\n"
  },
  {
    "path": "oneflow/core/rpc/lib/global_process_ctx.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/control/ctrl_bootstrap.pb.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\n\nvoid GlobalProcessCtx::GetMachineIdAndDeviceId(int64_t rank, int64_t* machine_id,\n                                               int64_t* device_id) {\n  *machine_id = rank;\n  *device_id = rank % NumOfProcessPerNode();\n}\n\nvoid GlobalProcessCtx::GetCurrentMachineIdAndDeviceId(int64_t* machine_id, int64_t* device_id) {\n  *machine_id = Rank();\n  *device_id = LocalRank();\n}\n\nint64_t GlobalProcessCtx::Rank() {\n  CHECK_NOTNULL(Singleton<ProcessCtx>::Get());\n  return Singleton<ProcessCtx>::Get()->rank();\n}\n\nint64_t GlobalProcessCtx::LocalRank() {\n  char* local_rank_env = std::getenv(\"LOCAL_RANK\");\n  if (!local_rank_env) {\n    static int64_t local_rank = Rank() % NumOfProcessPerNode();\n    return local_rank;\n  }\n  CHECK(IsStrInt(local_rank_env));\n  static int64_t local_rank = std::stol(local_rank_env);\n  return local_rank;\n}\n\nint64_t GlobalProcessCtx::NodeSize() {\n  CHECK_NOTNULL(Singleton<ProcessCtx>::Get());\n  return Singleton<ProcessCtx>::Get()->node_size();\n}\n\nint64_t GlobalProcessCtx::ThisNodeId() {\n  CHECK_NOTNULL(Singleton<ProcessCtx>::Get());\n  return NodeId(Rank());\n}\n\nint64_t GlobalProcessCtx::NodeId(int64_t process_id) {\n  CHECK_NOTNULL(Singleton<ProcessCtx>::Get());\n  return process_id / NumOfProcessPerNode();\n}\n\nint64_t GlobalProcessCtx::NumOfProcessPerNode() {\n  CHECK_NOTNULL(Singleton<ProcessCtx>::Get());\n  CHECK_EQ(WorldSize() % NodeSize(), 0);\n  return int64_t(WorldSize() / NodeSize());\n}\n\nbool GlobalProcessCtx::IsThisProcessMaster() {\n  CHECK_NOTNULL(Singleton<ProcessCtx>::Get());\n  return Singleton<ProcessCtx>::Get()->rank() == 0;\n}\n\nsize_t GlobalProcessCtx::WorldSize() {\n  CHECK_NOTNULL(Singleton<ProcessCtx>::Get());\n  return Singleton<ProcessCtx>::Get()->ctrl_addr().size();\n}\n\nstd::string GlobalProcessCtx::LogDirEntry() {\n  CHECK_NOTNULL(Singleton<ProcessCtx>::Get());\n  const auto& process_ctx = *Singleton<ProcessCtx>::Get();\n  const auto& addr = process_ctx.ctrl_addr(process_ctx.rank());\n  CHECK(addr.has_host());\n  return addr.host() + \"-\" + std::to_string(addr.port()) + \"-\" + std::to_string(process_ctx.rank());\n}\n\n/* static */ int64_t GlobalProcessCtx::LocalRank(int64_t rank) {\n  return rank % NumOfProcessPerNode();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/rpc/lib/grpc.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef RPC_BACKEND_GRPC\n\n#include \"oneflow/core/control/ctrl_bootstrap.h\"\n#include \"oneflow/core/control/ctrl_server.h\"\n#include \"oneflow/core/rpc/include/grpc.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<int> GetCtrlPort(const EnvDesc& env_desc) {\n  int port = 0;\n  if (env_desc.has_bootstrap_conf_ctrl_port()) { port = env_desc.bootstrap_conf_ctrl_port(); }\n  return port;\n}\n\n}  // namespace\n\nMaybe<void> GrpcRpcManager::Bootstrap() {\n  std::shared_ptr<CtrlBootstrap> ctrl_bootstrap;\n  auto& env_desc = *Singleton<EnvDesc>::Get();\n  if (env_desc.has_ctrl_bootstrap_conf()) {\n    ctrl_bootstrap.reset(new RankInfoCtrlBootstrap(env_desc.bootstrap_conf()));\n  } else {\n    ctrl_bootstrap.reset(new HostListCtrlBootstrap(env_desc));\n  }\n  JUST(ctrl_bootstrap->InitProcessCtx(Singleton<CtrlServer>::Get()->port(),\n                                      Singleton<ProcessCtx>::Get()));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GrpcRpcManager::CreateServer() {\n  Singleton<CtrlServer>::New(JUST(GetCtrlPort(*Singleton<EnvDesc>::Get())));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GrpcRpcManager::CreateClient() {\n  auto* client = new GrpcCtrlClient(*Singleton<ProcessCtx>::Get());\n  Singleton<CtrlClient>::SetAllocated(client);\n  return Maybe<void>::Ok();\n}\n\nGrpcRpcManager::~GrpcRpcManager() {\n  auto* grpc_client = dynamic_cast<GrpcCtrlClient*>(Singleton<CtrlClient>::Get());\n  CHECK_NOTNULL(grpc_client);\n  grpc_client->StopHeartbeat();\n  OF_ENV_BARRIER();\n  Singleton<CtrlClient>::Delete();\n  CHECK_NOTNULL(Singleton<CtrlServer>::Get());\n  Singleton<CtrlServer>::Delete();\n}\n\n}  // namespace oneflow\n\n#endif  // RPC_BACKEND_GRPC\n"
  },
  {
    "path": "oneflow/core/rpc/lib/local.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef RPC_BACKEND_LOCAL\n\n#include \"glog/logging.h\"\n#include \"oneflow/core/job/env_desc.h\"\n#include \"oneflow/core/rpc/include/local.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n\nnamespace oneflow {\n\nLocalCtrlClient::LocalCtrlClient(const ProcessCtx& process_ctx) {\n  CHECK(process_ctx.ctrl_addr_size() == 1);\n  CHECK(process_ctx.node_size() == 1);\n}\n\nvoid LocalCtrlClient::Barrier(const std::string& barrier_name) {\n  Barrier(barrier_name, Singleton<EnvDesc>::Get()->TotalMachineNum());\n}\n\nvoid LocalCtrlClient::Barrier(const std::string& barrier_name, int32_t barrier_num) {\n  std::shared_ptr<BlockingCounter> counter;\n  bool is_first = false;\n  {\n    std::unique_lock<std::mutex> lck(barrier_counter_mtx_);\n    auto it = barrier_counter_.find(barrier_name);\n    if (it == barrier_counter_.end()) {\n      is_first = true;\n      counter = std::make_shared<BlockingCounter>(barrier_num);\n      CHECK(barrier_counter_.emplace(barrier_name, counter).second);\n    } else {\n      counter = it->second;\n    }\n  }\n  counter->Decrease();\n  counter->WaitForeverUntilCntEqualZero();\n  if (is_first) {\n    std::unique_lock<std::mutex> lck(barrier_counter_mtx_);\n    CHECK_EQ(barrier_counter_.erase(barrier_name), 1);\n  }\n}\n\nTryLockResult LocalCtrlClient::TryLock(const std::string& name) {\n  std::unique_lock<std::mutex> lck(done_names_mtx_);\n  if (done_names_.find(name) != done_names_.end()) {\n    return TryLockResult::kDone;\n  } else if (doing_names_.find(name) != doing_names_.end()) {\n    return TryLockResult::kDoing;\n  } else {\n    doing_names_.insert(name);\n    return TryLockResult::kLocked;\n  }\n}\n\nvoid LocalCtrlClient::NotifyDone(const std::string& name) {\n  std::unique_lock<std::mutex> lck(done_names_mtx_);\n  done_names_.insert(name);\n  CHECK_EQ(doing_names_.erase(name), 1);\n  done_names_cv_.notify_all();\n}\n\nvoid LocalCtrlClient::WaitUntilDone(const std::string& name) {\n  std::unique_lock<std::mutex> lck(done_names_mtx_);\n  VLOG(3) << \"waiting for name: \" << name;\n  done_names_cv_.wait(lck);\n  CHECK(done_names_.find(name) != done_names_.end());\n}\n\nvoid LocalCtrlClient::PushKV(const std::string& k, std::function<void(std::string*)> VSetter) {\n  std::unique_lock<std::mutex> lck(kv_mtx_);\n  VSetter(&kv_[k]);\n  kv_cv_.notify_all();\n}\n\nvoid LocalCtrlClient::PushKV(const std::string& k, const std::string& v) {\n  PushKV(k, [&](std::string* o) { *o = v; });\n}\n\nvoid LocalCtrlClient::PushKV(const std::string& k, const PbMessage& msg) {\n  PushKV(k, [&](std::string* o) { msg.SerializeToString(o); });\n}\n\nvoid LocalCtrlClient::PushMasterKV(const std::string& k, const PbMessage& msg) {\n  PushKV(k, [&](std::string* o) { msg.SerializeToString(o); });\n}\n\nvoid LocalCtrlClient::ClearKV(const std::string& k) {\n  std::unique_lock<std::mutex> lck(kv_mtx_);\n  kv_.erase(k);\n}\n\nvoid 
LocalCtrlClient::ClearMasterKV(const std::string& k) { ClearKV(k); }\n\nvoid LocalCtrlClient::PullKV(const std::string& k,\n                             std::function<void(const std::string&)> VGetter) {\n  std::unique_lock<std::mutex> lck(kv_mtx_);\n  while (true) {\n    auto it = kv_.find(k);\n    if (it == kv_.end()) {\n      VLOG(3) << \"waiting for key: \" << k;\n      kv_cv_.wait(lck);\n    } else {\n      VGetter(it->second);\n      break;\n    }\n  }\n}\n\nvoid LocalCtrlClient::PullKV(const std::string& k, std::string* v) {\n  PullKV(k, [&](const std::string& i) { *v = i; });\n}\n\nvoid LocalCtrlClient::PullKV(const std::string& k, PbMessage* msg) {\n  PullKV(k, [&](const std::string& i) { msg->ParseFromString(i); });\n}\n\nvoid LocalCtrlClient::PullMasterKV(const std::string& k, PbMessage* msg) {\n  PullKV(k, [&](const std::string& i) { msg->ParseFromString(i); });\n}\n\nvoid LocalCtrlClient::Clear() {\n  {\n    std::unique_lock<std::mutex> lck(done_names_mtx_);\n    done_names_.clear();\n    done_names_cv_.notify_all();\n  }\n  {\n    std::unique_lock<std::mutex> lck(kv_mtx_);\n    kv_.clear();\n    kv_cv_.notify_all();\n  }\n}\n\nint32_t LocalCtrlClient::IncreaseCount(const std::string& k, int32_t v) {\n  std::unique_lock<std::mutex> lck(counter_mtx_);\n  auto it = counter_.find(k);\n  if (it == counter_.end()) {\n    counter_[k] = 1;\n    return 1;\n  } else {\n    const int32_t new_val = it->second + 1;\n    counter_[k] = new_val;\n    return new_val;\n  }\n}\n\nvoid LocalCtrlClient::EraseCount(const std::string& k) {\n  std::unique_lock<std::mutex> lck(counter_mtx_);\n  counter_.erase(k);\n}\n\nclass DryRunCtrlClient : public CtrlClient {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DryRunCtrlClient);\n  explicit DryRunCtrlClient(const ProcessCtx& process_ctx)\n      : local_ctrl_client_{std::unique_ptr<LocalCtrlClient>(new LocalCtrlClient(process_ctx))} {\n    CHECK(process_ctx.ctrl_addr_size() == 1);\n    CHECK(process_ctx.node_size() == 1);\n  }\n  ~DryRunCtrlClient() override = default;\n\n  void Barrier(const std::string& barrier_name) override {\n    Barrier(barrier_name, Singleton<EnvDesc>::Get()->TotalMachineNum());\n  }\n  void Barrier(const std::string& barrier_name, int32_t barrier_num) override {\n    VLOG(3) << \"skipping barrier in dry run, barrier name: \" << barrier_name\n            << \", barrier num: \" << barrier_num;\n  }\n\n  TryLockResult TryLock(const std::string& name) override {\n    return local_ctrl_client_->TryLock(name);\n  }\n  void NotifyDone(const std::string& name) override { local_ctrl_client_->NotifyDone(name); }\n  void WaitUntilDone(const std::string& name) override { local_ctrl_client_->WaitUntilDone(name); }\n\n  void PushKV(const std::string& k, std::function<void(std::string*)> VSetter) override {\n    local_ctrl_client_->PushKV(k, VSetter);\n  }\n  void PushKV(const std::string& k, const std::string& v) override {\n    local_ctrl_client_->PushKV(k, v);\n  }\n  void PushKV(const std::string& k, const PbMessage& msg) override {\n    local_ctrl_client_->PushKV(k, msg);\n  }\n  void PushMasterKV(const std::string& k, const PbMessage& msg) override {\n    local_ctrl_client_->PushMasterKV(k, msg);\n  }\n\n  void ClearKV(const std::string& k) override { local_ctrl_client_->ClearKV(k); }\n  void ClearMasterKV(const std::string& k) override { local_ctrl_client_->ClearMasterKV(k); }\n\n  void PullKV(const std::string& k, std::function<void(const std::string&)> VGetter) override {\n    local_ctrl_client_->PullKV(k, VGetter);\n  }\n  void PullKV(const 
std::string& k, std::string* v) override { local_ctrl_client_->PullKV(k, v); }\n  void PullKV(const std::string& k, PbMessage* msg) override { local_ctrl_client_->PullKV(k, msg); }\n  void PullMasterKV(const std::string& k, PbMessage* msg) override {\n    local_ctrl_client_->PullMasterKV(k, msg);\n  }\n  void Clear() override { local_ctrl_client_->Clear(); }\n  int32_t IncreaseCount(const std::string& k, int32_t v) override {\n    return local_ctrl_client_->IncreaseCount(k, v);\n  }\n  void EraseCount(const std::string& k) override { local_ctrl_client_->EraseCount(k); }\n\n private:\n  std::unique_ptr<LocalCtrlClient> local_ctrl_client_;\n};\n\nvoid SetLocalProcessCtx(oneflow::ProcessCtx* ctx) {\n  Address* addr = ctx->add_ctrl_addr();\n  addr->set_host(\"localhost\");\n  ctx->set_rank(0);\n  ctx->set_node_size(1);\n}\n\nMaybe<void> LocalRpcManager::Bootstrap() {\n  SetLocalProcessCtx(Singleton<ProcessCtx>::Get());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LocalRpcManager::CreateClient() {\n  auto* client = new LocalCtrlClient(*Singleton<ProcessCtx>::Get());\n  Singleton<CtrlClient>::SetAllocated(client);\n  return Maybe<void>::Ok();\n}\n\nLocalRpcManager::~LocalRpcManager() { Singleton<CtrlClient>::Delete(); }\n\nMaybe<void> DryRunRpcManager::Bootstrap() {\n  SetLocalProcessCtx(Singleton<ProcessCtx>::Get());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DryRunRpcManager::CreateClient() {\n  auto* client = new DryRunCtrlClient(*Singleton<ProcessCtx>::Get());\n  Singleton<CtrlClient>::SetAllocated(client);\n  return Maybe<void>::Ok();\n}\n\nDryRunRpcManager::~DryRunRpcManager() { Singleton<CtrlClient>::Delete(); }\n\n}  // namespace oneflow\n\n#endif  // RPC_BACKEND_LOCAL\n"
  },
  {
    "path": "oneflow/core/summary/event.proto",
    "content": "syntax = \"proto2\";\npackage oneflow.summary;\n\nimport \"oneflow/core/summary/summary.proto\";\n\nmessage Event {\n  required double wall_time = 1;\n  optional int64 step = 2;\n  oneof what {\n    string file_version = 3;\n    bytes graph_def = 4;\n    Summary summary = 5;\n    bytes meta_graph_def = 9;\n  }\n}\n"
  },
  {
    "path": "oneflow/core/summary/graph.proto",
    "content": "syntax = \"proto2\";\npackage oneflow.summary;\n\nimport \"oneflow/core/framework/user_op_attr.proto\";\n\nmessage GraphDef {\n  repeated NodeDef node = 1;\n  required int32 version = 2 [deprecated = true];\n}\n\n\nmessage NodeDef {\n  required string name = 1;\n  required string op = 2;\n  repeated string input = 3;\n  optional string device = 4;\n  map<string, AttrValue> attr = 5;\n}\n"
  },
  {
    "path": "oneflow/core/summary/plugin_data.proto",
    "content": "syntax = \"proto2\";\npackage oneflow.summary;\n\nimport \"google/protobuf/struct.proto\";\n\nmessage HParamsPluginData {\n  required int32 version = 1;\n  oneof data {\n    SessionStartInfo session_start_info = 3;\n  }\n}\n\nmessage SessionStartInfo {\n  map<string, google.protobuf.Value> hparams = 1;\n  required string group_name = 4;\n  required double start_time_secs = 5;\n  map<string, google.protobuf.Value> metrics = 6;\n}\n"
  },
  {
    "path": "oneflow/core/summary/projector.proto",
    "content": "syntax = \"proto2\";\npackage oneflow.summary;\n\nmessage MetaData {\n  enum ProjectorType {\n    EMBEDDING = 0;\n    EXCEPTION = 1;\n  }\n  required ProjectorType type = 1;\n  //Metadata specific information\n  optional string content = 2;\n}\n\nmessage Tensor {\n  message TensorShape {\n    message Dim {\n      required int64 size = 1;\n      optional string name = 2;\n    }\n    repeated Dim dim = 1;\n  }\n  required string dtype = 1;\n  required TensorShape shape = 2;\n  optional bytes content = 3;\n}\n\nmessage Sample{\n  enum SampleType {\n    IMAGE = 0;\n    AUDIO = 1;\n    TEXT = 2;\n  }\n  required string name = 1;\n  required SampleType type = 2;\n  required Tensor X = 3;\n}\n\nmessage Projector {\n  required string tag = 1;\n  optional int64 step = 2;\n  required double WALL_TIME = 3;\n  required Tensor value = 4;\n  optional Tensor label = 5;\n}\n\nmessage SummaryProjector {\n  required MetaData metadata = 6;\n  optional Sample sample = 2;\n  repeated Projector projector = 1;\n}\n"
  },
  {
    "path": "oneflow/core/summary/summary.proto",
    "content": "syntax = \"proto2\";\npackage oneflow.summary;\n\nimport \"oneflow/core/summary/tensor.proto\";\n\nmessage SummaryMetadata {\n  message PluginData {\n    required string plugin_name = 1;\n    optional bytes content = 2;\n  }\n  required PluginData plugin_data = 1;\n  optional string display_name = 2;\n  optional string summary_description = 3;\n};\n\nmessage HistogramProto {\n  required double min = 1;\n  required double max = 2;\n  required double num = 3;\n  required double sum = 4;\n  required double sum_squares = 5;\n  repeated double bucket_limit = 6 [packed = true];\n  repeated double bucket = 7 [packed = true];\n};\n\nmessage Image {\n    required int32 height = 1;\n    required int32 width = 2;\n    required int32 colorspace = 3;\n    required bytes encoded_image_string = 4;\n};\n\nmessage Summary {\n  message Value {\n    optional string node_name = 7;\n    required string tag = 1;\n    optional SummaryMetadata metadata = 9;\n    oneof value {\n      float simple_value = 2;\n      bytes obsolete_old_style_histogram = 3;\n      Image image = 4;\n      HistogramProto histo = 5;\n      //Audio audio = 6;\n      TensorProto tensor = 8;\n    }\n  }\n  repeated Value value = 1;\n}\n"
  },
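  {
    "path": "oneflow/core/summary/summary_proto_example_test.cpp",
    "content": "// NOTE: editorial sketch, NOT a file from the OneFlow source tree; the path and\n// test name are hypothetical. It only illustrates how the proto2 messages defined in\n// oneflow/core/summary/event.proto and summary.proto compose: an Event carries a\n// required wall_time, an optional step, and a Summary in its `what` oneof.\n#include <gtest/gtest.h>\n#include \"oneflow/core/summary/event.pb.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(SummaryEventProto, BuildAndRoundTrip) {\n  summary::Event event;\n  event.set_wall_time(1234.5);  // `wall_time` is the only required field of Event\n  event.set_step(7);\n  // Setting `summary` selects that branch of the `what` oneof.\n  summary::Summary::Value* value = event.mutable_summary()->add_value();\n  value->set_tag(\"loss\");          // `tag` is required in Summary.Value\n  value->set_simple_value(0.25f);  // selects the `simple_value` branch of the `value` oneof\n\n  std::string serialized;\n  // In proto2, serialization requires every required field (wall_time, tag) to be set.\n  ASSERT_TRUE(event.SerializeToString(&serialized));\n  summary::Event parsed;\n  ASSERT_TRUE(parsed.ParseFromString(serialized));\n  EXPECT_EQ(parsed.summary().value(0).tag(), \"loss\");\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },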
  {
    "path": "oneflow/core/summary/tensor.proto",
    "content": "syntax = \"proto2\";\npackage oneflow.summary;\n\nmessage TensorProto {\n  required TensorDataType dtype = 1;\n  required TensorShapeProto tensor_shape = 2;\n  optional int32 version_number = 3;\n  optional bytes tensor_content = 4;\n  repeated float float_val = 5 [packed = true];\n  repeated double double_val = 6 [packed = true];\n  repeated int32 int_val = 7 [packed = true];\n  repeated bytes string_val = 8;\n  repeated int64 int64_val = 9 [packed = true];\n  repeated bool bool_val = 10 [packed = true];\n  repeated uint32 uint32_val = 11 [packed = true];\n  repeated uint64 uint64_val = 12 [packed = true];\n  repeated int32 half_val = 13 [packed = true];\n};\n\nmessage TensorShapeProto {\n  message Dim {\n    required int64 size = 1;\n    optional string name = 2;\n  };\n  repeated Dim dim = 2;\n};\n\n\nenum TensorDataType {\n  DT_INVALID = 0;\n  DT_FLOAT = 1;\n  DT_DOUBLE = 2;\n  DT_INT32 = 3;\n  DT_UINT8 = 4;\n  DT_INT16 = 5;\n  DT_INT8 = 6;\n  DT_STRING = 7;\n  DT_INT64 = 8;\n  DT_UINT16 = 9;\n  DT_HALF = 10;\n  DT_UINT32 = 11;\n  DT_UINT64 = 12;\n}\n"
  },
  {
    "path": "oneflow/core/thread/is_main_thread_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <gtest/gtest.h>\n#include <thread>\n#include \"oneflow/core/thread/thread_manager.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(IsMainThread, IsMainThread) {\n  EXPECT_TRUE(IsMainThread());\n  auto non_main_thread = std::thread([&]() { EXPECT_FALSE(IsMainThread()); });\n  non_main_thread.join();\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/thread/thread.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/thread/thread.h\"\n#include \"oneflow/core/job/runtime_context.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/lazy/actor/actor.h\"\n#include \"oneflow/core/lazy/actor/light_actor.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"oneflow/core/lazy/stream_context/include/stream_context.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/lazy/stream_context/include/generic_stream_context.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n\nnamespace oneflow {\n\nThread::Thread(const StreamId& stream_id) : thrd_id_(EncodeStreamIdToInt64(stream_id)) {\n  local_msg_queue_enabled_ = ParseBooleanFromEnv(\"ONEFLOW_THREAD_ENABLE_LOCAL_MESSAGE_QUEUE\", true);\n  light_actor_enabled_ = ParseBooleanFromEnv(\"ONEFLOW_ACTOR_ENABLE_LIGHT_ACTOR\", true);\n  if (IsClassRegistered<int, StreamContext, const StreamId&>(stream_id.device_id().device_type(),\n                                                             stream_id)) {\n    stream_ctx_.reset(NewObj<int, StreamContext, const StreamId&>(\n        stream_id.device_id().device_type(), stream_id));\n  } else {\n    stream_ctx_.reset(new GenericStreamContext(stream_id));\n  }\n\n  actor_thread_ = std::thread([this, stream_id]() {\n    LazyMode::Guard guard(true);\n    OF_PROFILER_NAME_THIS_HOST_THREAD(\"_\" + ToString(stream_id.device_id().device_type())\n                                      + std::to_string(stream_id.device_id().device_index())\n                                      + \"_actor\");\n    CHECK_JUST(stream_ctx_->stream()->OnExecutionContextSetup());\n    PollMsgChannel();\n    CHECK_JUST(stream_ctx_->stream()->OnExecutionContextTeardown());\n  });\n}\n\nThread::~Thread() {\n  actor_thread_.join();\n  CHECK(id2task_.empty());\n  msg_channel_.Close();\n}\n\nvoid Thread::AddTask(const TaskProto& task) {\n  std::unique_lock<std::mutex> lck(id2task_mtx_);\n  CHECK(id2task_.emplace(task.task_id(), task).second);\n}\n\nvoid Thread::PollMsgChannel() {\n  while (true) {\n    if (local_msg_queue_.empty()) {\n      CHECK_EQ(msg_channel_.ReceiveMany(&local_msg_queue_), kChannelStatusSuccess);\n    }\n    ActorMsg msg = std::move(local_msg_queue_.front());\n    local_msg_queue_.pop();\n    if (msg.msg_type() == ActorMsgType::kCmdMsg) {\n      if (msg.actor_cmd() == ActorCmd::kStopThread) {\n        CHECK(id2actor_ptr_.empty())\n            << \" RuntimeError! 
Thread: \" << thrd_id_\n            << \" NOT empty when stop with actor num: \" << id2actor_ptr_.size();\n        break;\n      } else if (msg.actor_cmd() == ActorCmd::kConstructActor) {\n        ConstructActor(msg.dst_actor_id());\n        continue;\n      } else {\n        // do nothing\n      }\n    }\n    int64_t actor_id = msg.dst_actor_id();\n    auto actor_it = id2actor_ptr_.find(actor_id);\n    CHECK(actor_it != id2actor_ptr_.end());\n    int process_msg_ret = actor_it->second.second->ProcessMsg(msg);\n    if (process_msg_ret == 1) {\n      VLOG(3) << \"thread \" << thrd_id_ << \" deconstruct actor \" << actor_id;\n      auto job_id_it = id2job_id_.find(actor_id);\n      const int64_t job_id = job_id_it->second;\n      id2job_id_.erase(job_id_it);\n      id2actor_ptr_.erase(actor_it);\n      Singleton<RuntimeCtx>::Get()->DecreaseCounter(GetRunningActorCountKeyByJobId(job_id));\n    } else {\n      CHECK_EQ(process_msg_ret, 0);\n    }\n  }\n}\n\nvoid Thread::ConstructActor(int64_t actor_id) {\n  std::unique_lock<std::mutex> lck(id2task_mtx_);\n  auto task_it = id2task_.find(actor_id);\n  const TaskProto& task = task_it->second;\n  std::unique_ptr<ActorContext> actor_ctx = NewActorContext(task, stream_ctx_.get());\n  CHECK(actor_ctx);\n  std::unique_ptr<ActorBase> actor_ptr;\n  if (light_actor_enabled_) { actor_ptr = TryNewLightActor(actor_ctx.get()); }\n  if (!actor_ptr) {\n    actor_ptr = NewActor(actor_ctx.get());\n    VLOG(3) << \"Thread \" << thrd_id_ << \" construct Actor \" << TaskType_Name(task.task_type())\n            << \" \" << actor_id;\n  } else {\n    VLOG(3) << \"Thread \" << thrd_id_ << \" construct LightActor \" << TaskType_Name(task.task_type())\n            << \" \" << actor_id;\n  }\n  CHECK(id2actor_ptr_.emplace(actor_id, std::make_pair(std::move(actor_ctx), std::move(actor_ptr)))\n            .second);\n  CHECK(id2job_id_.emplace(actor_id, task.job_id()).second);\n  id2task_.erase(task_it);\n  Singleton<RuntimeCtx>::Get()->DecreaseCounter(\"constructing_actor_cnt\");\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/thread/thread.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_THREAD_THREAD_H_\n#define ONEFLOW_CORE_THREAD_THREAD_H_\n\n#include \"oneflow/core/lazy/actor/actor_message_bus.h\"\n#include \"oneflow/core/common/channel.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/task.pb.h\"\n#include \"oneflow/core/lazy/actor/actor.h\"\n#include \"oneflow/core/lazy/actor/actor_context.h\"\n\nnamespace oneflow {\n\nclass StreamContext;\n\nclass Thread {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Thread);\n  explicit Thread(const StreamId& stream_id);\n  virtual ~Thread();\n\n  void AddTask(const TaskProto&);\n\n  Channel<ActorMsg>* GetMsgChannelPtr() { return &msg_channel_; }\n\n  inline void EnqueueActorMsg(const ActorMsg& msg) {\n    if (UseLocalMsgQueue()) {\n      local_msg_queue_.push(msg);\n    } else {\n      msg_channel_.Send(msg);\n    }\n  }\n\n  template<typename InputIt>\n  inline void EnqueueActorMsg(InputIt first, InputIt last) {\n    if (UseLocalMsgQueue()) {\n      for (auto it = first; it != last; ++it) { local_msg_queue_.push(*it); }\n    } else {\n      for (auto it = first; it != last; ++it) { msg_channel_.Send(*it); }\n    }\n  }\n\n protected:\n  void PollMsgChannel();\n\n private:\n  void ConstructActor(int64_t actor_id);\n\n  inline bool UseLocalMsgQueue() const {\n    return local_msg_queue_enabled_ && std::this_thread::get_id() == actor_thread_.get_id();\n  }\n\n  HashMap<int64_t, TaskProto> id2task_;\n  std::mutex id2task_mtx_;\n\n  std::thread actor_thread_;\n  Channel<ActorMsg> msg_channel_;\n  HashMap<int64_t, std::pair<std::unique_ptr<ActorContext>, std::unique_ptr<ActorBase>>>\n      id2actor_ptr_;\n  HashMap<int64_t, int64_t> id2job_id_;\n  std::queue<ActorMsg> local_msg_queue_;\n  bool local_msg_queue_enabled_;\n  int64_t thrd_id_;\n  bool light_actor_enabled_;\n  std::unique_ptr<StreamContext> stream_ctx_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_THREAD_THREAD_H_\n"
  },
  {
    "path": "oneflow/core/thread/thread_global_id.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/thread/thread_global_id.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nint64_t* MutThreadLocalUniqueGlobalId() {\n  static thread_local int64_t global_id = kThreadGlobalIdMain;\n  return &global_id;\n}\n\n}  // namespace\n\nint64_t GetThisThreadGlobalId() { return *MutThreadLocalUniqueGlobalId(); }\n\nThreadGlobalIdGuard::ThreadGlobalIdGuard(int64_t thread_global_id)\n    : old_thread_global_id_(GetThisThreadGlobalId()) {\n  if (old_thread_global_id_ != kThreadGlobalIdMain) {\n    CHECK_EQ(old_thread_global_id_, thread_global_id)\n        << \"nested ThreadGlobalIdGuard disabled. old thread_global_id: \" << old_thread_global_id_\n        << \", new thread_global_id:\" << thread_global_id;\n  }\n  *MutThreadLocalUniqueGlobalId() = thread_global_id;\n}\n\nThreadGlobalIdGuard::~ThreadGlobalIdGuard() {\n  *MutThreadLocalUniqueGlobalId() = old_thread_global_id_;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/thread/thread_global_id.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_THREAD_GLOBAL_UNIQUE_ID_H_\n#define ONEFLOW_CORE_THREAD_GLOBAL_UNIQUE_ID_H_\n\n#include <string>\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/optional.h\"\n\nnamespace oneflow {\n\nconst static int kThreadGlobalIdDefaultWorker = 0;\nconst static int kThreadGlobalIdMain = 7;\n\nint64_t GetThisThreadGlobalId();\n\nclass ThreadGlobalIdGuard final {\n public:\n  explicit ThreadGlobalIdGuard(int64_t thread_global_id);\n  ~ThreadGlobalIdGuard();\n\n private:\n  int64_t old_thread_global_id_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_THREAD_GLOBAL_UNIQUE_ID_H_\n"
  },
  {
    "path": "oneflow/core/thread/thread_manager.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n\nnamespace oneflow {\n\nThreadMgr::~ThreadMgr() {\n  for (auto& thread_pair : threads_) {\n    ActorMsg msg = ActorMsg::BuildCommandMsg(-1, ActorCmd::kStopThread);\n    thread_pair.second->GetMsgChannelPtr()->Send(msg);\n    thread_pair.second.reset();\n    VLOG(1) << \" Actor thread: \" << thread_pair.first << \" finished when process exits.\";\n  }\n}\n\nThread* ThreadMgr::GetThrd(int64_t thrd_id) {\n  auto iter = threads_.find(thrd_id);\n  CHECK(iter != threads_.end()) << \" Thread: \" << thrd_id << \" not found\";\n  return iter->second.get();\n}\n\nvoid ThreadMgr::AddThreads(const HashSet<int64_t>& thread_ids) {\n  const int64_t this_rank = GlobalProcessCtx::Rank();\n  for (int64_t thrd_id : thread_ids) {\n    const auto& it = threads_.find(thrd_id);\n    if (it != threads_.end()) {\n      // NOTE(chengcheng): check thread is not null.\n      CHECK(it->second) << \" RuntimeError! Thread: \" << thrd_id << \" in manager must be NOT null.\";\n      VLOG(1) << \" Actor thread: \" << thrd_id << \" reused.\";\n      continue;\n    }\n    StreamId stream_id = DecodeStreamIdFromInt64(thrd_id);\n    if (stream_id.rank() != this_rank) { continue; }\n    Thread* thread = new Thread(stream_id);\n    CHECK_NOTNULL(thread);\n    threads_[thrd_id].reset(thread);\n    VLOG(1) << \" Actor thread: \" << thrd_id << \" created.\";\n  }\n}\n\nvoid ThreadMgr::DeleteThreads(const HashSet<int64_t>& thread_ids) {\n  std::unique_lock<std::mutex> lock(mutex4del_threads_);\n  for (int64_t thrd_id : thread_ids) {\n    const auto& it = threads_.find(thrd_id);\n    CHECK((it != threads_.end()) && (it->second))\n        << \" RuntimeError! Actor thread: \" << thrd_id << \" non-existent but want to delete\";\n    auto& thread = it->second;\n    ActorMsg msg = ActorMsg::BuildCommandMsg(-1, ActorCmd::kStopThread);\n    thread->GetMsgChannelPtr()->Send(msg);\n    thread.reset();\n    VLOG(1) << \" Actor thread: \" << thrd_id << \" finished when the graph is destructed.\";\n    threads_.erase(it);\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/thread/thread_manager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_THREAD_THREAD_MANAGER_H_\n#define ONEFLOW_CORE_THREAD_THREAD_MANAGER_H_\n\n#include <mutex>\n#include \"oneflow/core/common/channel.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/common/auto_registration_factory.h\"\n#include \"oneflow/core/common/blocking_counter.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/cpp_attribute.h\"\n#include \"oneflow/core/thread/thread.h\"\n#include \"oneflow/core/thread/thread_pool.h\"\n#include \"oneflow/core/platform/include/pthread_fork.h\"\n\nnamespace oneflow {\n\nclass Plan;\n\nclass ThreadMgr final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ThreadMgr);\n  ThreadMgr() = default;\n  ~ThreadMgr();\n\n  void AddThreads(const HashSet<int64_t>& thread_ids);\n  void DeleteThreads(const HashSet<int64_t>& thread_ids);\n  Thread* GetThrd(int64_t thrd_id);\n\n private:\n  friend class Singleton<ThreadMgr>;\n\n  HashMap<int64_t, std::unique_ptr<Thread>> threads_;\n  std::mutex mutex4del_threads_;\n};\n\n// Use limit_thread_num to config the max thread num.\n// limit_thread_num == -1 means no limit, use the max avaliable thread num of the ThreadPool.\n// limit_thread_num == 0 means use the current thread.\ntemplate<typename DoEachT>\nvoid MultiThreadLoop(size_t work_num, const DoEachT& DoEachWork, int64_t limit_thread_num = -1) {\n  if (work_num == 0) { return; }\n  if (unlikely(pthread_fork::IsForkedSubProcess() || Singleton<ThreadPool>::Get() == nullptr\n               || limit_thread_num == 0)) {\n    FOR_RANGE(size_t, i, 0, work_num) { DoEachWork(i); }\n    return;\n  }\n  size_t thread_num = Singleton<ThreadPool>::Get()->thread_num();\n  if (limit_thread_num > 0) {\n    thread_num = std::min(thread_num, static_cast<size_t>(limit_thread_num));\n  }\n  thread_num = std::min(work_num, thread_num);\n  BalancedSplitter bs(work_num, thread_num);\n  BlockingCounter bc(thread_num);\n  FOR_RANGE(size_t, range_id, 0, thread_num) {\n    Singleton<ThreadPool>::Get()->AddWork([&bc, &bs, range_id, DoEachWork] {\n      size_t start = bs.At(range_id).begin();\n      size_t end = bs.At(range_id).end();\n      FOR_RANGE(size_t, i, start, end) { DoEachWork(i); }\n      bc.Decrease();\n    });\n  }\n  // busy loop wait.\n  bc.WaitForeverUntilCntEqualZero();\n}\n\ninline bool* MutIsMainThread() {\n  thread_local bool is_main_thread = false;\n  return &is_main_thread;\n}\n\ninline bool IsMainThread() { return *MutIsMainThread(); }\ninline void SetIsMainThread(bool is_main_thread) { *MutIsMainThread() = is_main_thread; }\n\nCOMMAND(SetIsMainThread(true));\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_THREAD_THREAD_MANAGER_H_\n"
  },
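  {
    "path": "oneflow/core/thread/multi_thread_loop_example_test.cpp",
    "content": "// NOTE: editorial sketch, NOT a file from the OneFlow source tree; the path and\n// test name are hypothetical. It demonstrates MultiThreadLoop() from\n// thread_manager.h: work indices are split by BalancedSplitter across the\n// Singleton<ThreadPool> workers, and when no ThreadPool singleton exists (as in\n// this standalone test) the helper degrades to a plain serial loop.\n#include <gtest/gtest.h>\n#include <vector>\n#include \"oneflow/core/thread/thread_manager.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(MultiThreadLoop, VisitsEveryIndexOnce) {\n  std::vector<int64_t> out(1000, 0);\n  // Each index i in [0, work_num) is handed to DoEachWork exactly once; disjoint\n  // writes to `out` need no extra synchronization.\n  MultiThreadLoop(out.size(), [&](size_t i) { out[i] = static_cast<int64_t>(i) * i; });\n  FOR_RANGE(size_t, i, 0, out.size()) { EXPECT_EQ(out[i], static_cast<int64_t>(i) * i); }\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },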
  {
    "path": "oneflow/core/thread/thread_pool.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/thread/thread_pool.h\"\n#include \"oneflow/core/vm/sync_vm_mode_guard.h\"\n\nnamespace oneflow {\n\nThreadPool::ThreadPool(int32_t thread_num)\n    : work_chans_(thread_num), threads_(thread_num), work_cnt_(0) {\n  FOR_RANGE(int32_t, i, 0, thread_num) {\n    Channel<std::function<void()>>* chan = &(work_chans_.at(i));\n    threads_[i] = std::thread([chan]() {\n      SyncVmModeGuard guard(SyncVmMode::kEnable);\n      std::function<void()> work;\n      while (chan->Receive(&work) == kChannelStatusSuccess) { work(); }\n    });\n  }\n}\n\nThreadPool::~ThreadPool() {\n  FOR_RANGE(int32_t, i, 0, work_chans_.size()) {\n    work_chans_.at(i).Close();\n    threads_.at(i).join();\n  }\n}\n\nvoid ThreadPool::AddWork(const std::function<void()>& work) {\n  const size_t cur_chan_idx =\n      work_cnt_.fetch_add(1, std::memory_order_relaxed) % work_chans_.size();\n  work_chans_.at(cur_chan_idx).Send(work);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/thread/thread_pool.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_THREAD_THREAD_POOL_H_\n#define ONEFLOW_CORE_THREAD_THREAD_POOL_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/channel.h\"\n\nnamespace oneflow {\n\nclass ThreadPool final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ThreadPool);\n  ThreadPool() = delete;\n  ThreadPool(int32_t thread_num);\n  ~ThreadPool();\n\n  int32_t thread_num() const { return threads_.size(); }\n  void AddWork(const std::function<void()>& work);\n\n private:\n  std::vector<Channel<std::function<void()>>> work_chans_;\n  std::vector<std::thread> threads_;\n\n  std::atomic<size_t> work_cnt_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_THREAD_THREAD_POOL_H_\n"
  },
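  {
    "path": "oneflow/core/thread/thread_pool_example_test.cpp",
    "content": "// NOTE: editorial sketch, NOT a file from the OneFlow source tree; the path and\n// test name are hypothetical. It shows the intended ThreadPool usage pattern:\n// submit closures with AddWork() (distributed round-robin over the per-thread\n// channels) and block on a BlockingCounter until every closure has run, the same\n// wait pattern MultiThreadLoop() in thread_manager.h uses.\n#include <gtest/gtest.h>\n#include <atomic>\n#include \"oneflow/core/common/blocking_counter.h\"\n#include \"oneflow/core/thread/thread_pool.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(ThreadPool, RunsAllSubmittedWork) {\n  ThreadPool pool(4);  // 4 worker threads, each draining its own work channel\n  const int32_t kWorkNum = 100;\n  std::atomic<int32_t> done_cnt(0);\n  BlockingCounter bc(kWorkNum);\n  FOR_RANGE(int32_t, i, 0, kWorkNum) {\n    pool.AddWork([&]() {\n      done_cnt.fetch_add(1, std::memory_order_relaxed);\n      bc.Decrease();\n    });\n  }\n  bc.WaitForeverUntilCntEqualZero();  // returns once all 100 closures have finished\n  EXPECT_EQ(done_cnt.load(), kWorkNum);\n  // ~ThreadPool closes each channel and joins the workers.\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },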
  {
    "path": "oneflow/core/thread/thread_runtime.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_THREAD_THREAD_RUNTIME_H_\n#define ONEFLOW_CORE_THREAD_THREAD_RUNTIME_H_\n\n#include <functional>\n#include \"oneflow/core/common/blocking_counter.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/thread/thread.h\"\n#include \"oneflow/core/thread/thread_pool.h\"\n#include \"oneflow/core/platform/include/pthread_fork.h\"\n\n#ifdef WITH_TBB\n#include <tbb/blocked_range.h>\n#include <tbb/parallel_for.h>\n#include <tbb/global_control.h>\n#endif\n\n#ifdef WITH_OMP\n#include <omp.h>\n#endif\n\nnamespace oneflow {\nnamespace thread {\n\nnamespace {\n\nusing CallableT = std::function<void(int64_t, int64_t)>;\n\nvoid SeqFor(int64_t begin, int64_t end, const CallableT& func) { func(begin, end); }\n\nsize_t DivUp(size_t x, size_t y) { return (x + y - 1) / y; }\n\n}  // namespace\n\nclass RuntimeBase {\n public:\n  void ParallelFor(int64_t begin, int64_t end, const CallableT& func, size_t num_threads,\n                   size_t grain_size) {\n    if (begin >= end) { return; }\n    if (num_threads == 1) { return SeqFor(begin, end, func); }\n    ParallelForImpl(begin, end, func, num_threads, grain_size);\n  }\n\n private:\n  virtual void ParallelForImpl(int64_t begin, int64_t end, const CallableT& func,\n                               size_t num_threads, size_t grain_size) = 0;\n};\n\nclass SeqRuntime final : public RuntimeBase {\n private:\n  void ParallelForImpl(int64_t begin, int64_t end, const CallableT& func, size_t num_threads,\n                       size_t grain_size) override {\n    return SeqFor(begin, end, func);\n  }\n};\n\nclass OfRuntime final : public RuntimeBase {\n private:\n  void ParallelForImpl(int64_t begin, int64_t end, const CallableT& func, size_t num_threads,\n                       size_t grain_size) override {\n    if (unlikely(pthread_fork::IsForkedSubProcess()) || Singleton<ThreadPool>::Get() == nullptr) {\n      return SeqFor(begin, end, func);\n    }\n    const size_t num_elements = end - begin;\n    num_threads = std::min(num_elements, num_threads);\n    BalancedSplitter bs(num_elements, num_threads);\n    BlockingCounter bc(num_threads);\n\n    FOR_RANGE(size_t, range_id, 0, num_threads) {\n      Singleton<ThreadPool>::Get()->AddWork([&bc, &bs, range_id, func] {\n        const size_t begin_ = bs.At(range_id).begin();\n        const size_t end_ = bs.At(range_id).end();\n        SeqFor(begin_, end_, func);\n        bc.Decrease();\n      });\n    }\n    // buzy loop wait.\n    bc.WaitForeverUntilCntEqualZero();\n  }\n};\n\n#if WITH_TBB\nclass TbbRuntime final : public RuntimeBase {\n private:\n  void ParallelForImpl(int64_t begin, int64_t end, const CallableT& func, size_t num_threads,\n                       size_t grain_size) override {\n    tbb::global_control global_thread_limit(tbb::global_control::max_allowed_parallelism,\n                                            num_threads);\n    const size_t 
chunk_size = std::max(DivUp((end - begin), num_threads), grain_size);\n\n    tbb::parallel_for(\n        tbb::blocked_range<int64_t>(begin, end, chunk_size),\n        [&func](const tbb::blocked_range<int64_t>& r) { SeqFor(r.begin(), r.end(), func); },\n        tbb::static_partitioner{});\n  }\n};\n#endif\n\n#if WITH_OMP\nclass OmpRuntime final : public RuntimeBase {\n private:\n  void ParallelForImpl(int64_t begin, int64_t end, const CallableT& func, size_t num_threads,\n                       size_t grain_size) override {\n    num_threads = std::min(DivUp((end - begin), grain_size), num_threads);\n#pragma omp parallel num_threads(num_threads)\n    {\n      int64_t omp_num_thread = omp_get_num_threads();\n      int64_t chunk_size = DivUp((end - begin), omp_num_thread);\n      int64_t omp_tid = omp_get_thread_num();\n      int64_t thread_begin_index = begin + omp_tid * chunk_size;\n      int64_t thread_end_index = std::min(end, chunk_size + thread_begin_index);\n\n      if (thread_begin_index < end) { SeqFor(thread_begin_index, thread_end_index, func); }\n    }\n  }\n};\n#endif\n\n}  // namespace thread\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_THREAD_THREAD_RUNTIME_H_\n"
  },
  {
    "path": "oneflow/core/thread/thread_runtime_factory.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <fmt/core.h>\n#include <unordered_map>\n#include \"oneflow/core/thread/thread_runtime_factory.h\"\n#include \"oneflow/core/thread/thread_runtime.h\"\n\nnamespace oneflow {\nnamespace thread {\n\nnamespace {\n\ntemplate<typename T>\nstd::shared_ptr<thread::RuntimeBase> CreateRuntime() {\n  return std::shared_ptr<thread::RuntimeBase>(std::make_shared<T>());\n}\n\n}  // namespace\n\nMaybe<thread::RuntimeBase> RuntimeFactory::Create(RuntimeType type) {\n  if (type == RuntimeType::kOf) { return CreateRuntime<thread::OfRuntime>(); }\n  const auto format_error_msg = [](const auto& name, const auto& option) {\n    return fmt::format(\"{} is not enabled, you should compile oneflow with \"\n                       \"`-DCPU_THREADING_RUNTIMES={}`\",\n                       name, option);\n  };\n\n  if (type == RuntimeType::kTbb) {\n    if (!IsTbbEnabled()) { return Error::RuntimeError() << format_error_msg(\"OneTBB\", \"TBB\"); }\n#ifdef WITH_TBB\n    return CreateRuntime<thread::TbbRuntime>();\n#endif\n  }\n  if (type == RuntimeType::kOmp) {\n    if (!IsOmpEnabled()) { return Error::RuntimeError() << format_error_msg(\"OpenMP\", \"OMP\"); }\n#ifdef WITH_OMP\n    return CreateRuntime<thread::OmpRuntime>();\n#endif\n  }\n  return CreateRuntime<thread::SeqRuntime>();\n}\n\nMaybe<thread::RuntimeBase> RuntimeFactory::Create(const std::string& type) {\n  std::unordered_map<std::string, RuntimeType> types{\n      {\"SEQ\", RuntimeType::kSeq},\n      {\"OF\", RuntimeType::kOf},\n      {\"TBB\", RuntimeType::kTbb},\n      {\"OMP\", RuntimeType::kOmp},\n  };\n  if (types.find(type) == types.end()) {\n    return Error::RuntimeError() << fmt::format(\"Not supportted cpu threading runtime: {}\", type);\n  }\n  return Create(types[type]);\n}\n\n}  // namespace thread\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/core/thread/thread_runtime_factory.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_THREAD_THREAD_EXECUTOR_FACTORY_H_\n#define ONEFLOW_CORE_THREAD_THREAD_EXECUTOR_FACTORY_H_\n\n#include \"oneflow/core/thread/thread_runtime.h\"\n\nnamespace oneflow {\nnamespace thread {\n\nconstexpr bool IsTbbEnabled() {\n#ifdef WITH_TBB\n  return true;\n#else\n  return false;\n#endif\n}\n\nconstexpr bool IsOmpEnabled() {\n#ifdef WITH_OMP\n  return true;\n#else\n  return false;\n#endif\n}\n\nenum class RuntimeType {\n  kSeq,\n  kOf,\n  kTbb,\n  kOmp,\n};\n\nclass RuntimeFactory {\n public:\n  static Maybe<thread::RuntimeBase> Create(RuntimeType type);\n  static Maybe<thread::RuntimeBase> Create(const std::string& type);\n};\n\n}  // namespace thread\n}  // namespace oneflow\n#endif  // ONEFLOW_CORE_THREAD_THREAD_EXECUTOR_FACTORY_H_\n"
  },
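  {
    "path": "oneflow/core/thread/thread_runtime_factory_example_test.cpp",
    "content": "// NOTE: editorial sketch, NOT a file from the OneFlow source tree; the path and\n// test name are hypothetical. It exercises thread::RuntimeFactory from\n// thread_runtime_factory.h: Create() maps \"SEQ\"/\"OF\"/\"TBB\"/\"OMP\" to a RuntimeBase\n// whose ParallelFor() applies func over sub-ranges of [begin, end). \"TBB\"/\"OMP\"\n// yield a RuntimeError unless the matching -DCPU_THREADING_RUNTIMES option was\n// compiled in, so the always-available \"SEQ\" runtime is used here.\n#include <gtest/gtest.h>\n#include <vector>\n#include \"oneflow/core/thread/thread_runtime_factory.h\"\n\nnamespace oneflow {\nnamespace test {\n\nTEST(RuntimeFactory, SeqRuntimeCoversWholeRange) {\n  const auto runtime = CHECK_JUST(thread::RuntimeFactory::Create(\"SEQ\"));\n  std::vector<int> hits(64, 0);\n  // func receives a half-open sub-range [b, e); with num_threads == 1 the base\n  // class short-circuits to a single sequential call over the full range.\n  runtime->ParallelFor(\n      0, hits.size(),\n      [&](int64_t b, int64_t e) {\n        for (int64_t i = b; i < e; ++i) { hits[i] += 1; }\n      },\n      /*num_threads=*/1, /*grain_size=*/1);\n  for (int h : hits) { EXPECT_EQ(h, 1); }\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },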
  {
    "path": "oneflow/core/transport/transport.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef __linux__\n\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/transport/transport.h\"\n\nnamespace oneflow {\n\nTransport::Transport() {\n  comm_net_ = Singleton<EpollCommNet>::Get();  // NOLINT\n  this_machine_id_ = GlobalProcessCtx::Rank();\n  CHECK(comm_net_ != nullptr);\n  // maybe need new read id for each dst machine id, maybe need 2 * machine num read ids\n  read_id_ = comm_net_->NewActorReadId();\n  msg_poller_ = std::thread([this]() { PollMsgChannel(); });\n}\n\nTransport::~Transport() {\n  msg_channel_.Close();\n  msg_poller_.join();\n  comm_net_->DeleteActorReadId(read_id_);\n}\n\nvoid Transport::EnqueueTransportMsg(const TransportMsg& msg) {\n  CHECK_EQ(msg_channel_.Send(msg), kChannelStatusSuccess);\n}\n\nvoid Transport::PollMsgChannel() {\n  TransportMsg msg;\n  while (true) {\n    ChannelStatus stat = msg_channel_.Receive(&msg);\n    if (stat != kChannelStatusSuccess) {\n      CHECK_EQ(stat, kChannelStatusErrorClosed);\n      break;\n    }\n    switch (msg.type) {\n      case TransportMsgType::kSend: {\n        HandlerAchievedTransportSendMsgFromSrcMachine(msg);\n        break;\n      }\n      case TransportMsgType::kAck: {\n        HandlerAchievedTransportAckMsgFromDstMachine(msg);\n        break;\n      }\n      default: UNIMPLEMENTED(); break;\n    }\n  }\n}\n\nvoid Transport::HandlerAchievedTransportSendMsgFromSrcMachine(const TransportMsg& msg) {\n  // This machine is dst machine, and receive Send msg from source machine\n  // Maybe we need create TransportStatus,\n  // or we need update TransportStatus and DoRead().\n  CHECK_EQ(msg.type, TransportMsgType::kSend);\n  CHECK(msg.src_mem_token != nullptr);\n  CHECK(msg.dst_mem_token == nullptr);\n  uint64_t token = msg.token;\n  CHECK(token != -1);\n\n  // There are two ways to trigger the creation of TransportStatus:\n  //   1. The time (T_A) when the dst machine receives SendMsg from src machine\n  //   2. 
The time (T_B) when method Receive() called by the dst machine.\n  // Because of T_ A and t_ B are both protected by the lock(status_mutex_), so the creation of\n  // TransportStatus will NOT trigger at the same time.\n  //\n  // T_ A maybe earlier than t_ B, maybe later.\n  //\n  // In either case, the earlier one is responsible for creating the TransportStatus, and the later\n  // one is responsible for checking the TransportStatus and then calling the DoRead() operation.\n\n  // prepare transport status for this token.\n  // store callback.\n  TransportStatus* stat = nullptr;\n\n  // if recv_before_send is true, it means the Receive() method has been called before this handler\n  bool recv_before_send = false;\n  {\n    std::unique_lock<std::mutex> lock(status_mutex_);\n    auto it = token2status_.find(token);\n    if (it == token2status_.end()) {\n      token2status_.emplace(token, TransportStatus(token));\n      stat = &(token2status_.at(token));\n\n      // init stat\n      // These three members must be initialized in the block protected by lock\n      //  to prevent multi-threaded access bugs\n      stat->size = msg.size;\n      stat->src_machine_id = msg.src_machine_id;\n      stat->dst_machine_id = msg.dst_machine_id;\n    } else {\n      recv_before_send = true;\n      stat = &(it->second);\n      CHECK_GE(stat->size, msg.size);  // NOTE(chengcheng): Recv size may larger than Send size.\n      stat->size = msg.size;           // NOTE(chengcheng): msg.size always is smaller one.\n    }\n\n    stat->is_send_ready = true;\n    CHECK(stat->src_mem_token == nullptr);\n    // src_mem_token MUST init in the block protected by lock\n    stat->src_mem_token = msg.src_mem_token;\n  }\n\n  if (recv_before_send) {\n    // it means the local machine has call Transport::Receive() before this handler\n    // check status\n    CHECK_EQ(stat->src_machine_id, msg.src_machine_id);\n    CHECK_EQ(stat->dst_machine_id, msg.dst_machine_id);\n\n    // the recv is ready, and the send is ready too, so call DoRead();\n    DoRead(token);\n  }\n}\n\nvoid Transport::HandlerAchievedTransportAckMsgFromDstMachine(const TransportMsg& msg) {\n  // This machine is src machine, and receive Ack msg from dst machine. The Send/Receive pair of\n  // this token is all done. 
So we can call callback function and erase TransportStatus.\n  CHECK_EQ(msg.type, TransportMsgType::kAck);\n  CHECK(msg.src_mem_token != nullptr);\n  CHECK(msg.dst_mem_token != nullptr);\n  uint64_t token = msg.token;\n  CHECK(token != -1);\n  std::function<void()> callback;\n\n  // get status from map\n  {\n    std::unique_lock<std::mutex> lock(status_mutex_);\n    auto it = token2status_.find(token);\n    CHECK(it != token2status_.end());\n    TransportStatus* stat = &(it->second);\n\n    // check msg == stat\n    CHECK_EQ(stat->src_mem_token, msg.src_mem_token);\n    CHECK_EQ(stat->size, msg.size);\n    CHECK_EQ(stat->src_machine_id, msg.src_machine_id);\n    CHECK_EQ(stat->dst_machine_id, msg.dst_machine_id);\n    CHECK(stat->callback != nullptr);\n\n    callback = stat->callback;\n\n    // Recovery status\n    token2status_.erase(it);\n  }\n\n  // UnRegisterMemory\n  comm_net_->UnRegisterMemory(msg.src_mem_token);\n\n  // Do Send callback\n  callback();\n}\n\nvoid Transport::Send(uint64_t token, int64_t dst_machine_id, const void* ptr, std::size_t size,\n                     std::function<void()> callback) {\n  void* mut_ptr = const_cast<void*>(ptr);\n\n  // handler for send to local machine\n  if (dst_machine_id == this_machine_id_) {\n    SendToLocalMachine(token, mut_ptr, size, callback);\n    return;\n  }\n\n  // prepare transport status for this token.\n  // store callback.\n  TransportStatus* stat = nullptr;\n  {\n    std::unique_lock<std::mutex> lock(status_mutex_);\n    CHECK(token2status_.find(token)\n          == token2status_.end());  // this token must be first add to status\n    token2status_.emplace(token, TransportStatus(token));\n    stat = &(token2status_.at(token));\n  }\n  stat->callback = callback;\n  stat->is_send_ready = true;\n  stat->is_recv_ready = false;\n  stat->src_mem_token = comm_net_->RegisterMemory(mut_ptr, size);\n  stat->dst_mem_token = nullptr;\n  stat->size = size;\n  stat->src_machine_id = this_machine_id_;\n  stat->dst_machine_id = dst_machine_id;\n\n  // Send msg to dst machine\n  TransportMsg msg;\n  msg.token = token;\n  msg.src_machine_id = stat->src_machine_id;\n  msg.dst_machine_id = stat->dst_machine_id;\n  msg.size = size;\n  msg.src_mem_token = stat->src_mem_token;\n  msg.dst_mem_token = stat->dst_mem_token;\n  msg.type = TransportMsgType::kSend;\n  comm_net_->SendTransportMsg(msg.dst_machine_id, msg);\n}\n\nvoid Transport::Receive(uint64_t token, int64_t src_machine_id, void* ptr, std::size_t max_size,\n                        std::function<void()> callback) {\n  // handler for receive from local machine\n  if (src_machine_id == this_machine_id_) {\n    RecvFromLocalMachine(token, ptr, max_size, callback);\n    return;\n  }\n\n  // prepare transport status for this token.\n  // store callback.\n  TransportStatus* stat = nullptr;\n\n  // if recv_before_send is true, it means the SendMsg has been handled before this Receive called.\n  bool send_before_recv = false;\n  {\n    std::unique_lock<std::mutex> lock(status_mutex_);\n    auto it = token2status_.find(token);\n    if (it == token2status_.end()) {\n      token2status_.emplace(token, TransportStatus(token));\n      stat = &(token2status_.at(token));\n\n      // init stat\n      // These three members must be initialized in the block protected by lock\n      //  to prevent multi-threaded access bugs\n      stat->size = max_size;\n      stat->src_machine_id = src_machine_id;\n      stat->dst_machine_id = this_machine_id_;\n    } else {\n      send_before_recv = true;\n      stat = 
void Transport::Receive(uint64_t token, int64_t src_machine_id, void* ptr, std::size_t max_size,\n                        std::function<void()> callback) {\n  // handler for receive from local machine\n  if (src_machine_id == this_machine_id_) {\n    RecvFromLocalMachine(token, ptr, max_size, callback);\n    return;\n  }\n\n  // prepare transport status for this token.\n  // store callback.\n  TransportStatus* stat = nullptr;\n\n  // if send_before_recv is true, the Send msg has been handled before this Receive() call.\n  bool send_before_recv = false;\n  {\n    std::unique_lock<std::mutex> lock(status_mutex_);\n    auto it = token2status_.find(token);\n    if (it == token2status_.end()) {\n      token2status_.emplace(token, TransportStatus(token));\n      stat = &(token2status_.at(token));\n\n      // init stat\n      // These three members must be initialized in the block protected by the lock\n      //  to prevent multi-threaded access bugs\n      stat->size = max_size;\n      stat->src_machine_id = src_machine_id;\n      stat->dst_machine_id = this_machine_id_;\n    } else {\n      send_before_recv = true;\n      stat = &(it->second);\n    }\n\n    stat->callback = callback;\n    stat->is_recv_ready = true;\n    // NOTE(chengcheng): Store dst_ptr so that we can create dst_mem_token in DoRead()\n    stat->dst_ptr = ptr;\n  }\n\n  if (send_before_recv) {\n    // it means the source machine has already sent its msg to this machine\n    // check status\n    CHECK_LE(stat->size,\n             max_size);  // NOTE(chengcheng): Receive max_size may be larger than Send size.\n    CHECK_EQ(stat->src_machine_id, src_machine_id);\n    CHECK_EQ(stat->dst_machine_id, this_machine_id_);\n\n    // the recv is ready, and the send is ready too, so call DoRead();\n    DoRead(token);\n  }\n}\n\nvoid Transport::DoRead(uint64_t token) {\n  TransportStatus* stat = nullptr;\n  {\n    std::unique_lock<std::mutex> lock(status_mutex_);\n    auto it = token2status_.find(token);\n    CHECK(it != token2status_.end());\n    stat = &(it->second);\n\n    // dst_mem_token MUST be initialized in the block protected by the lock\n    CHECK(stat->dst_mem_token == nullptr);\n    // NOTE(chengcheng): ONLY at this time is stat->size the real size assigned by Send\n    stat->dst_mem_token = comm_net_->RegisterMemory(stat->dst_ptr, stat->size);\n  }\n  CHECK(stat->is_send_ready && stat->is_recv_ready);\n  CHECK(stat->src_mem_token != nullptr);\n  CHECK(stat->dst_mem_token != nullptr);\n  CHECK(stat->src_machine_id != -1);\n  CHECK(stat->dst_machine_id != -1);\n  CHECK(stat->size != -1);\n  CHECK(stat->callback);\n  comm_net_->Read(read_id_, stat->src_machine_id, stat->src_mem_token, stat->dst_mem_token);\n  comm_net_->AddReadCallBack(read_id_, [stat, this]() {\n    // Send ack message to source machine\n    TransportMsg msg;\n    msg.token = stat->token;\n    msg.src_machine_id = stat->src_machine_id;\n    msg.dst_machine_id = stat->dst_machine_id;\n    msg.size = stat->size;\n    msg.src_mem_token = stat->src_mem_token;\n    msg.dst_mem_token = stat->dst_mem_token;\n    msg.type = TransportMsgType::kAck;\n    comm_net_->SendTransportMsg(msg.src_machine_id, msg);\n\n    // UnRegisterMemory\n    comm_net_->UnRegisterMemory(msg.dst_mem_token);\n\n    // Do Receive callback\n    stat->callback();\n\n    // Release the status entry\n    {\n      std::unique_lock<std::mutex> lock(status_mutex_);\n      auto it = token2status_.find(stat->token);\n      CHECK(it != token2status_.end());\n      token2status_.erase(it);\n    }\n  });\n}\n\nvoid Transport::SendToLocalMachine(uint64_t token, void* ptr, std::size_t size,\n                                   std::function<void()> callback) {\n  bool need_do_copy = false;\n  bool need_do_callback = false;\n  std::function<void()> receive_callback;\n  void* dst_ptr = nullptr;\n  {\n    std::unique_lock<std::mutex> lock(local_copy_lock_);\n    auto it = token2local_copy_status_.find(token);\n    if (it == token2local_copy_status_.end()) {\n      // init local copy status\n      token2local_copy_status_.emplace(token, CopyStatusOnLocalMachine(token, ptr, size, callback));\n    } else {\n      need_do_callback = true;\n      receive_callback = std::move(it->second.callback);\n\n      dst_ptr = it->second.ptr;\n      CHECK(size <= it->second.size);  // NOTE(chengcheng): Recv size may be larger than Send size.\n\n      if (ptr != dst_ptr) { need_do_copy = true; }\n\n      // erase local copy status\n      token2local_copy_status_.erase(it);\n    }\n  }\n\n  if (need_do_copy) { memcpy(dst_ptr, ptr, size); }\n\n  if (need_do_callback) {\n    callback();\n    receive_callback();\n  }\n}\n\n
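// Local-transfer sketch (illustrative; `src`, `dst`, `n` and the callbacks are assumed names):\n// whichever of SendToLocalMachine()/RecvFromLocalMachine() arrives first parks a\n// CopyStatusOnLocalMachine; the second one performs the memcpy (if the pointers differ) and\n// fires both callbacks:\n//\n//   Send(token, this_id, src, n, on_sent);     // parks {token, src, n, on_sent}\n//   Receive(token, this_id, dst, n, on_recv);  // memcpy(dst, src, n); on_recv(); on_sent();\n\n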
void Transport::RecvFromLocalMachine(uint64_t token, void* ptr, std::size_t max_size,\n                                     std::function<void()> callback) {\n  bool need_do_copy = false;\n  bool need_do_callback = false;\n  std::function<void()> send_callback;\n  void* src_ptr = nullptr;\n  std::size_t size = -1;\n  {\n    std::unique_lock<std::mutex> lock(local_copy_lock_);\n    auto it = token2local_copy_status_.find(token);\n    if (it == token2local_copy_status_.end()) {\n      // init local copy status\n      token2local_copy_status_.emplace(token,\n                                       CopyStatusOnLocalMachine(token, ptr, max_size, callback));\n    } else {\n      need_do_callback = true;\n      send_callback = std::move(it->second.callback);\n\n      src_ptr = it->second.ptr;\n      size = it->second.size;\n      CHECK(max_size >= size);  // NOTE(chengcheng): Recv max_size may be larger than Send size.\n\n      if (ptr != src_ptr) { need_do_copy = true; }\n\n      // erase local copy status\n      token2local_copy_status_.erase(it);\n    }\n  }\n\n  if (need_do_copy) { memcpy(ptr, src_ptr, size); }\n\n  if (need_do_callback) {\n    callback();\n    send_callback();\n  }\n}\n\n}  // namespace oneflow\n\n#endif  // __linux__\n"
  },
  {
    "path": "oneflow/core/transport/transport.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef __linux__\n\n#ifndef ONEFLOW_CORE_TRANSPORT_TRANSPORT_H_\n#define ONEFLOW_CORE_TRANSPORT_TRANSPORT_H_\n\n#include \"oneflow/core/common/channel.h\"\n#include \"oneflow/core/comm_network/epoll/epoll_comm_network.h\"\n#include \"oneflow/core/transport/transport_message.h\"\n\nnamespace oneflow {\n\n// Transport supports sending and receiving data between two machines, which is identified by\n// a unique token.\n//\n// Suppose machine A wants to send a piece of data to machine B. Singleton<Transport> both need\n// created on machine A and machine B respectively.\n//\n// Machin A need call:\n//   Singleton<Transport>::Get()->Send(token, B, data_ptr_A, data_size_A, callback_after_send);\n// Machin B need call:\n//   Singleton<Transport>::Get()->Receive(token, A, data_ptr_B, data_size_B,\n//   callback_after_receive);\n//\n// data_size_A <= data_size_B\n//\n// Both call: Send()/Receive() will be executed asynchronously.\n//\n// When the data transmission is completed, the callbacks of the two machines callback_after_send()\n// and callback_after_receive() will be executed on their respective machines.\n//\n// Transport supports send and receive data on local machine.\n//\nclass Transport {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Transport);\n  virtual ~Transport();\n\n  void Send(uint64_t token, int64_t dst_machine_id, const void* ptr, std::size_t size,\n            std::function<void()> callback);\n  void Receive(uint64_t token, int64_t src_machine_id, void* ptr, std::size_t max_size,\n               std::function<void()> callback);\n  void EnqueueTransportMsg(const TransportMsg& msg);\n\n private:\n  void PollMsgChannel();\n  void HandlerAchievedTransportSendMsgFromSrcMachine(const TransportMsg& msg);\n  void HandlerAchievedTransportAckMsgFromDstMachine(const TransportMsg& msg);\n  void DoRead(uint64_t token);\n  void SendToLocalMachine(uint64_t token, void* ptr, std::size_t size,\n                          std::function<void()> callback);\n  void RecvFromLocalMachine(uint64_t token, void* ptr, std::size_t max_size,\n                            std::function<void()> callback);\n\n  // TODO(chengcheng)\n  // Singleton<Transport> has a dependency on Singleton<CommNet> which should be initialized first.\n  friend class Singleton<Transport>;\n  Transport();\n\n  // TransportStatus stores all the information that Transport needs in a Send / Receive process.\n  //\n  // At the sender (source machine), the TransportStatus stores the callback from the Send().\n  // At the receiver (destination machine), the TransportStatus stores the callback from Receive().\n  //\n  // In the process of one transmission between two machines, the TransportStatus will be created,\n  // changed and finally deleted by sending and receiving messages for many times.\n  struct TransportStatus {\n    const uint64_t token;\n    std::function<void()> callback;\n    bool is_send_ready;\n    bool 
  struct TransportStatus {\n    const uint64_t token;\n    std::function<void()> callback;\n    bool is_send_ready;\n    bool is_recv_ready;\n    void* src_mem_token;\n    void* dst_mem_token;\n    // NOTE(chengcheng): must store dst_ptr in status when Receive max_size > Send size\n    void* dst_ptr;\n    std::size_t size;\n    int64_t src_machine_id;\n    int64_t dst_machine_id;\n    TransportStatus(uint64_t tk)\n        : token(tk),\n          callback(nullptr),\n          is_send_ready(false),\n          is_recv_ready(false),\n          src_mem_token(nullptr),\n          dst_mem_token(nullptr),\n          dst_ptr(nullptr),\n          size(-1),\n          src_machine_id(-1),\n          dst_machine_id(-1) {}\n  };\n\n  // CopyStatusOnLocalMachine is the stored state that supports local data transfer.\n  //\n  // This state stores only the most necessary information.\n  //\n  // When Send() is called first, it stores the token, pointer, size and callback of the sender.\n  // In this way, when Receive() is called, the copy and both callbacks can be executed.\n  //\n  // When Receive() is called first, it stores the token, pointer, size and callback of the\n  // receiver. In this way, when Send() is called, the copy and both callbacks can be executed.\n  struct CopyStatusOnLocalMachine {\n    const uint64_t token;\n    void* ptr;\n    std::size_t size;\n    std::function<void()> callback;\n    CopyStatusOnLocalMachine(uint64_t tk, void* p, std::size_t s, std::function<void()> cb)\n        : token(tk), ptr(p), size(s), callback(std::move(cb)) {}\n  };\n\n  // Store the TransportStatus for each token (Send/Receive pair).\n  // The map token2status_ must be protected by status_mutex_ whenever it is read or changed.\n  std::mutex status_mutex_;\n  HashMap<uint64_t, TransportStatus> token2status_;\n\n  // for local copy\n  std::mutex local_copy_lock_;\n  HashMap<uint64_t, CopyStatusOnLocalMachine> token2local_copy_status_;\n\n  int64_t this_machine_id_;\n  void* read_id_;\n  EpollCommNet* comm_net_;\n\n  Channel<TransportMsg> msg_channel_;\n  std::thread msg_poller_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_TRANSPORT_TRANSPORT_H_\n\n#endif  // __linux__\n"
  },
  {
    "path": "oneflow/core/transport/transport_message.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_TRANSPORT_TRANSPORT_MESSAGE_H_\n#define ONEFLOW_CORE_TRANSPORT_TRANSPORT_MESSAGE_H_\n\n#include \"oneflow/core/common/platform.h\"\n#include \"oneflow/core/common/util.h\"\n\n#ifdef __linux__\n\nnamespace oneflow {\n\nenum class TransportMsgType {\n  kInvalid = 0,\n  kSend = 1,  // send msg from local to remote transport\n  kAck = 2,   // this token transmission task is down\n};\n\nstruct TransportMsg {\n  uint64_t token;\n  void* src_mem_token;\n  void* dst_mem_token;\n  std::size_t size;\n  int64_t src_machine_id;\n  int64_t dst_machine_id;\n  TransportMsgType type;\n};\n\n}  // namespace oneflow\n\n#endif  // __linux__\n\n#endif  // ONEFLOW_CORE_TRANSPORT_TRANSPORT_MESSAGE_H_\n"
  },
  {
    "path": "oneflow/core/vm/access_blob_arg_cb_instruction_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_ACCESS_BLOB_ARG_CB_INSTRUCTION_POLICY_H_\n#define ONEFLOW_CORE_VM_ACCESS_BLOB_ARG_CB_INSTRUCTION_POLICY_H_\n\n#include <functional>\n#include <memory>\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/vm/instruction_policy.h\"\n#include \"oneflow/core/vm/instruction_policy_util.h\"\n#include \"oneflow/core/eager/local_dep_object.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/framework/tensor_storage.h\"\n#include \"oneflow/core/intrusive/list.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/vm/op_call_instruction_policy.h\"\n#include \"oneflow/core/vm/stream_policy.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass AccessBlobArgCbInstructionPolicy final : public InstructionPolicy {\n public:\n  AccessBlobArgCbInstructionPolicy(\n      const std::shared_ptr<EagerBlobObject>& eager_blob_object,\n      const std::function<void(ep::Stream*, const std::shared_ptr<vm::EagerBlobObject>&)>& callback,\n      const std::string& modifier)\n      : eager_blob_object_(eager_blob_object),\n        callback_(callback),\n        modifier_(modifier),\n        input_dependences_(),\n        output_dependences_() {\n    ForEachConstDependence(InstructionPolicyUtil::SetInserter(&input_dependences_));\n    ForEachMutDependence(InstructionPolicyUtil::SetInserter(&output_dependences_));\n    ForEachMut2Dependence(InstructionPolicyUtil::SetInserter(&output_dependences_));\n    stream_sequential_dependence_ = nullptr;\n  }\n  ~AccessBlobArgCbInstructionPolicy() = default;\n\n  const std::shared_ptr<EagerBlobObject>& eager_blob_object() const { return eager_blob_object_; }\n\n  const DependenceVector& input_dependences() const override { return input_dependences_; }\n  const DependenceVector& output_dependences() const override { return output_dependences_; }\n\n  void ForEachConstDependence(const std::function<void(Dependence* compute)>& DoEach) const {\n    if (modifier_ == \"const\") {\n      DoEach(CHECK_JUST(eager_blob_object_->compute_local_dep_object()));\n    }\n  }\n\n  void ForEachMutDependence(const std::function<void(Dependence* compute)>& DoEach) const {\n    if (modifier_ == \"mut\") { DoEach(CHECK_JUST(eager_blob_object_->compute_local_dep_object())); }\n  }\n\n  void ForEachMut2Dependence(const std::function<void(Dependence* compute)>& DoEach) const {\n    if (modifier_ == \"mut2\") { DoEach(CHECK_JUST(eager_blob_object_->compute_local_dep_object())); }\n  }\n\n  std::string DebugName(const Instruction& instruction) const override {\n    return \"AccessBlobByCallback\";\n  }\n  Maybe<void> Prepare(Instruction* instruction) override { return Maybe<void>::Ok(); }\n  void Compute(Instruction* instruction) override {\n    StreamPolicy* stream_policy = instruction->mut_stream_policy();\n    auto rematable_storage =\n        
std::dynamic_pointer_cast<RematableTensorStorage>(eager_blob_object()->tensor_storage());\n\n    if (rematable_storage && !rematable_storage->is_in_memory()) {\n      OpCallInstructionPolicy tmp_op = rematable_storage->compute_op();\n      CHECK_JUST(Recompute(&tmp_op, instruction->mut_stream()));\n    }\n    callback_(stream_policy->stream(), eager_blob_object());\n    if (rematable_storage && (modifier_ == \"mut\" || modifier_ == \"mut2\")) {\n      rematable_storage->set_eviction_disabled(true);\n    }\n  }\n\n private:\n  std::shared_ptr<EagerBlobObject> eager_blob_object_;\n  std::function<void(ep::Stream*, const std::shared_ptr<vm::EagerBlobObject>&)> callback_;\n  const std::string modifier_;\n  DependenceVector input_dependences_;\n  DependenceVector output_dependences_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n#endif  // ONEFLOW_CORE_VM_ACCESS_BLOB_ARG_CB_INSTRUCTION_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/allocate_tensor_instruction_policy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/allocate_tensor_instruction_policy.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nAllocateTensorInstructionPolicy::AllocateTensorInstructionPolicy(\n    const EagerBlobObjectList& eager_blob_objects, vm::Stream* vm_stream)\n    : eager_blob_objects_(eager_blob_objects) {\n  stream_sequential_dependence_ = vm_stream->schedule_local_dep_object().get();\n  for (const auto& eager_blob_object : eager_blob_objects) {\n    output_dependences_.push_back(CHECK_JUST(eager_blob_object->compute_local_dep_object()));\n  }\n}\n\nstd::string AllocateTensorInstructionPolicy::DebugName(const vm::Instruction& instruction) const {\n  return \"AllocateTensor\";\n}\n\nvoid AllocateTensorInstructionPolicy::Compute(Instruction* instruction) {\n  Allocator* allocator = instruction->mut_stream()->mut_stream_policy()->mut_allocator();\n  for (const auto& eager_blob_object : eager_blob_objects_) {\n    CHECK_JUST(eager_blob_object->TryAllocateBlobBodyMemory(allocator));\n  }\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/allocate_tensor_instruction_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_ALLOCATE_INSTRUCTION_POLICY_H_\n#define ONEFLOW_CORE_VM_ALLOCATE_INSTRUCTION_POLICY_H_\n\n#include <memory>\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/vm/instruction_policy.h\"\n#include \"oneflow/core/vm/stream.h\"\n\nnamespace oneflow {\n\nnamespace vm {\n\nclass AllocateTensorInstructionPolicy final : public InstructionPolicy {\n public:\n  AllocateTensorInstructionPolicy(const EagerBlobObjectList& eager_blob_objects,\n                                  vm::Stream* vm_stream);\n  AllocateTensorInstructionPolicy(const AllocateTensorInstructionPolicy&) = delete;\n  AllocateTensorInstructionPolicy(AllocateTensorInstructionPolicy&&) = delete;\n\n  ~AllocateTensorInstructionPolicy() override = default;\n\n  const DependenceVector& input_dependences() const override {\n    static thread_local DependenceVector input_dependences{};\n    return input_dependences;\n  }\n  const DependenceVector& output_dependences() const override { return output_dependences_; }\n\n  InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAtAnyPosition; }\n\n  std::string DebugName(const vm::Instruction& instruction) const override;\n\n private:\n  Maybe<void> Prepare(Instruction* instruction) override { return Maybe<void>::Ok(); }\n  void Compute(Instruction* instruction) override;\n\n  EagerBlobObjectList eager_blob_objects_;\n  DependenceVector output_dependences_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_ALLOCATE_INSTRUCTION_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/allocator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_ALLOCATOR_H_\n#define ONEFLOW_CORE_VM_ALLOCATOR_H_\n\n#include <cstddef>\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/throw.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass Allocator {\n public:\n  virtual ~Allocator() = default;\n\n  virtual Maybe<void> Allocate(char** mem_ptr, std::size_t size) = 0;\n  virtual void Deallocate(char* mem_ptr, std::size_t size) = 0;\n  virtual void DeviceReset() = 0;\n\n protected:\n  Allocator() = default;\n};\n\nclass UnimplementedAllocator final : public Allocator {\n public:\n  explicit UnimplementedAllocator(const std::string& debug_str) : debug_str_(debug_str) {}\n  virtual ~UnimplementedAllocator() = default;\n\n  Maybe<void> Allocate(char** mem_ptr, std::size_t size) override {\n    UNIMPLEMENTED_THEN_RETURN() << debug_str_;\n  }\n\n  void Deallocate(char* mem_ptr, std::size_t size) override { LOG(FATAL) << debug_str_; }\n  void DeviceReset() override { LOG(FATAL) << debug_str_; }\n\n private:\n  std::string debug_str_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_ALLOCATOR_H_\n"
  },
  {
    "path": "oneflow/core/vm/barrier_instruction_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_BARRIER_INSTRUCTION_POLICY_H_\n#define ONEFLOW_CORE_VM_BARRIER_INSTRUCTION_POLICY_H_\n\n#include \"oneflow/core/vm/instruction_policy.h\"\nnamespace oneflow {\nnamespace vm {\n\nclass BarrierInstructionPolicy final : public InstructionPolicy {\n public:\n  BarrierInstructionPolicy(const std::function<void()>& callback) : callback_(callback) {\n    stream_sequential_dependence_ = nullptr;\n  }\n  ~BarrierInstructionPolicy() override = default;\n\n  const DependenceVector& input_dependences() const override {\n    static DependenceVector dependences{};\n    return dependences;\n  }\n  const DependenceVector& output_dependences() const override {\n    static DependenceVector dependences{};\n    return dependences;\n  }\n\n  bool IsBarrier() const override { return true; }\n\n  std::string DebugName(const vm::Instruction& instruction) const override { return \"Barrier\"; }\n  Maybe<void> Prepare(Instruction* instruction) override { return Maybe<void>::Ok(); }\n  void Compute(Instruction* instruction) override { return callback_(); }\n\n private:\n  std::function<void()> callback_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_BARRIER_INSTRUCTION_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/bin_allocator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_BIN_ALLOCATOR_H_\n#define ONEFLOW_CORE_VM_BIN_ALLOCATOR_H_\n\n#include <cstdint>\n#include \"oneflow/core/vm/allocator.h\"\n#include \"oneflow/core/vm/caching_allocator.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\nnamespace vm {\n\ntemplate<typename ThreadLock>\nclass BinAllocator final : public CachingAllocator {\n public:\n  explicit BinAllocator(size_t alignment, std::unique_ptr<Allocator>&& backend);\n  ~BinAllocator();\n\n  Maybe<void> Allocate(char** mem_ptr, std::size_t size) override;\n  void Deallocate(char* mem_ptr, std::size_t size) override;\n  void DeviceReset() override {\n    typename ThreadLock::RAIIGuard guard(thread_lock_);\n    backend_->DeviceReset();\n  }\n  void Shrink() override {\n    typename ThreadLock::RAIIGuard guard(thread_lock_);\n    DeallocateFreeBlockForGarbageCollection();\n  }\n\n private:\n  static constexpr int32_t kInvalidBinNum = -1;\n  static constexpr int32_t kBinNumSize = 20;\n\n  // Piece is the basic memory unit of BinAllocator.\n  // A Piece is either is free(is_free = true) or in used(is_free = false).\n  // If the Piece is_free = true, the pointer to the piece will be stored in the Bin structure of\n  // the corresponding BinSize. Pieces are stored in a linked list. The Piece's prev and next are\n  // continuous with the current Piece in physical memory.\n  struct Piece {\n    size_t size = 0;\n    char* ptr = nullptr;\n    bool is_free = false;\n    Piece* prev = nullptr;\n    Piece* next = nullptr;\n    int32_t bin_num = kInvalidBinNum;\n  };\n\n  // Bin is a structure that stores a set of pieces which is free and has similar size, and\n  // these Pieces are arger than the size of bin\n  //\n  // BinAllocator has a set of Bin structures according to the binary multiple increasing relation,\n  // which is used to quickly index and find the free Piece of appropriate size when Allocate()\n  //\n  // The size of the smallest bin is 512 (512 is the smallest unit Allocated by BinAllocator,\n  // and the memory size of all Allocated will be multiples of 512, 512 is kCudaMemAllocAlignSize).\n  // The size of each Bin is twice the size of the previous Bin, like\n  //    BinNum:   Bin0, Bin1, Bin2, Bin3, ..., Bin19\n  //    BinSize:  512, 1024, 2048, 4096, ... 
  struct Bin {\n    size_t size = 0;\n\n    struct PieceCmp {\n      bool operator()(const Piece* lhs, const Piece* rhs) const {\n        if (lhs->size != rhs->size) { return lhs->size < rhs->size; }\n        return lhs->ptr < rhs->ptr;\n      }\n    };\n    std::set<Piece*, PieceCmp> pieces;\n  };\n\n  // Block is a large chunk of physical memory that is actually allocated.\n  // There may be many consecutive, disjoint Pieces distributed over the Block memory\n  struct Block {\n    size_t size = 0;\n    char* ptr = nullptr;\n    Piece* start_piece = nullptr;\n    Block(Piece* p) : size(p->size), ptr(p->ptr), start_piece(p) {}\n  };\n\n  size_t BinSize4BinNum(int32_t bin_num) { return kCudaMemAllocAlignSize << bin_num; }\n\n  int32_t BinNum4BinSize(size_t size) {\n    uint64_t value = std::max(size, kCudaMemAllocAlignSize) >> 9;\n    return std::min(kBinNumSize - 1, static_cast<int32_t>(63 ^ __builtin_clzll(value)));\n  }\n\n  // Try to find a free Piece in the Bins whose size is at least aligned_size.\n  // Return nullptr on failure\n  Piece* FindPiece(size_t aligned_size);\n\n  // Insert the free Piece into the appropriate Bin whose bin size is not larger than the piece\n  void InsertPiece2Bin(Piece* piece);\n\n  // Create a new empty Piece or recycle a Piece from recycle_piece_list_\n  Piece* AllocatePiece();\n  // Delete a Piece and move it into the linked list recycle_piece_list_\n  void DeallocatePiece(Piece* piece);\n\n  // Insert a {piece->ptr, piece} pair into the ptr2piece_ map so the Piece can be found when\n  // Deallocate() is called\n  void MarkPiece(Piece* piece);\n  // Erase the {piece->ptr, piece} pair from ptr2piece_ because the ptr is no longer valid\n  // Usually called before DeallocatePiece()\n  void UnMarkPiece(Piece* piece);\n\n  void MergeNeighbourFreePiece(Piece* lhs, Piece* rhs);\n  void RemovePieceFromBin(Piece* piece);\n\n  Maybe<bool> AllocateBlockToExtendTotalMem(size_t aligned_size);\n  bool DeallocateFreeBlockForGarbageCollection();\n\n  const size_t alignment_;\n  const std::unique_ptr<Allocator> backend_;\n  ThreadLock thread_lock_;\n  size_t total_memory_bytes_;\n  HashMap<char*, Block> mem_ptr2block_;\n\n  std::vector<Bin> bins_;\n  std::vector<std::unique_ptr<Piece>> pieces_;\n  HashMap<char*, Piece*> ptr2piece_;\n  Piece* recycle_piece_list_;\n};\n\nnamespace {\n\ninline size_t MemAlignedBytes(size_t bytes, size_t alignment) { return RoundUp(bytes, alignment); }\n\ninline bool IsAlignedSize(size_t size, size_t alignment) { return size % alignment == 0; }\n\nstatic const size_t kPieceSplitThreshold = 128 << 20;  // 128MiB\n\n}  // namespace\n\ntemplate<typename ThreadLock>\nBinAllocator<ThreadLock>::BinAllocator(size_t alignment, std::unique_ptr<Allocator>&& backend)\n    : CachingAllocator(),\n      alignment_(alignment),\n      backend_(std::move(backend)),\n      total_memory_bytes_(0),\n      recycle_piece_list_(nullptr) {\n  CHECK_GE(alignment, 1);\n  CHECK_EQ(1 << static_cast<int>(std::log2(alignment)), alignment);\n  bins_.resize(kBinNumSize);\n  for (int i = 0; i < kBinNumSize; ++i) {\n    size_t bin_size = BinSize4BinNum(i);\n    bins_.at(i).size = bin_size;\n    CHECK_EQ(BinNum4BinSize(bin_size), i);\n    CHECK_EQ(BinNum4BinSize(bin_size + alignment_ - 1), i);\n    CHECK_EQ(BinNum4BinSize(bin_size * 2 - 1), i);\n    CHECK_EQ(BinNum4BinSize(bin_size * 2), i == (kBinNumSize - 1) ? 
i : i + 1);\n  }\n}\n\ntemplate<typename ThreadLock>\nBinAllocator<ThreadLock>::~BinAllocator() {\n  if (total_memory_bytes_ == 0) {\n    CHECK_EQ(mem_ptr2block_.size(), 0);\n    return;\n  }\n  for (auto& pair : mem_ptr2block_) { backend_->Deallocate(pair.first, pair.second.size); }\n}\n\ntemplate<typename ThreadLock>\nvoid BinAllocator<ThreadLock>::InsertPiece2Bin(Piece* piece) {\n  CHECK(piece->is_free && piece->bin_num == kInvalidBinNum);\n  int32_t bin_num = BinNum4BinSize(piece->size);\n  piece->bin_num = bin_num;\n  CHECK(bins_.at(bin_num).pieces.insert(piece).second);\n}\n\ntemplate<typename ThreadLock>\nvoid BinAllocator<ThreadLock>::RemovePieceFromBin(Piece* piece) {\n  CHECK(piece->is_free);\n  CHECK_NE(piece->bin_num, kInvalidBinNum);\n  CHECK_GT(bins_.at(piece->bin_num).pieces.erase(piece), 0);\n  piece->bin_num = kInvalidBinNum;\n}\n\ntemplate<typename ThreadLock>\ntypename BinAllocator<ThreadLock>::Piece* BinAllocator<ThreadLock>::AllocatePiece() {\n  if (recycle_piece_list_) {\n    Piece* ret = recycle_piece_list_;\n    recycle_piece_list_ = recycle_piece_list_->next;\n    return ret;\n  } else {\n    pieces_.emplace_back(new Piece());\n    return pieces_.at(pieces_.size() - 1).get();\n  }\n}\n\ntemplate<typename ThreadLock>\nvoid BinAllocator<ThreadLock>::DeallocatePiece(Piece* piece) {\n  piece->ptr = nullptr;\n  piece->size = 0;\n  piece->bin_num = kInvalidBinNum;\n  piece->is_free = true;\n  piece->prev = nullptr;\n  piece->next = recycle_piece_list_;\n  recycle_piece_list_ = piece;\n}\n\ntemplate<typename ThreadLock>\nvoid BinAllocator<ThreadLock>::MarkPiece(Piece* piece) {\n  CHECK_NOTNULL(piece->ptr);\n  CHECK(ptr2piece_.emplace(piece->ptr, piece).second);\n}\ntemplate<typename ThreadLock>\nvoid BinAllocator<ThreadLock>::UnMarkPiece(Piece* piece) {\n  CHECK_NOTNULL(piece->ptr);\n  auto it = ptr2piece_.find(piece->ptr);\n  CHECK(it != ptr2piece_.end());\n  ptr2piece_.erase(it);\n}\n\ntemplate<typename ThreadLock>\ntypename BinAllocator<ThreadLock>::Piece* BinAllocator<ThreadLock>::FindPiece(size_t aligned_size) {\n  CHECK(IsAlignedSize(aligned_size, alignment_));\n  for (int32_t bin_num = BinNum4BinSize(aligned_size); bin_num < kBinNumSize; ++bin_num) {\n    Bin* bin = &bins_.at(bin_num);\n    for (auto it = bin->pieces.begin(); it != bin->pieces.end(); ++it) {\n      Piece* piece = *it;\n      CHECK(piece->is_free);\n      CHECK_NOTNULL(piece->ptr);\n      CHECK_EQ(piece->bin_num, bin_num);\n      CHECK(IsAlignedSize(piece->size, alignment_));\n      if (piece->size >= aligned_size) {\n        bin->pieces.erase(it);\n        piece->bin_num = kInvalidBinNum;\n        piece->is_free = false;\n        if (piece->size >= aligned_size * 2 || piece->size - aligned_size >= kPieceSplitThreshold) {\n          Piece* new_piece = AllocatePiece();\n          new_piece->ptr = piece->ptr + aligned_size;\n          new_piece->size = piece->size - aligned_size;\n          piece->size = aligned_size;\n\n          Piece* next_p = piece->next;\n          piece->next = new_piece;\n          new_piece->prev = piece;\n          new_piece->next = next_p;\n          if (next_p != nullptr) { next_p->prev = new_piece; }\n\n          new_piece->is_free = true;\n          new_piece->bin_num = kInvalidBinNum;\n          CHECK(IsAlignedSize(piece->size, alignment_));\n          CHECK(IsAlignedSize(new_piece->size, alignment_));\n          InsertPiece2Bin(new_piece);\n          MarkPiece(new_piece);\n        }\n        return piece;\n      }\n    }\n  }\n  return nullptr;\n}\n\ntemplate<typename 
ThreadLock>\nvoid BinAllocator<ThreadLock>::MergeNeighbourFreePiece(Piece* lhs, Piece* rhs) {\n  CHECK(lhs->is_free);\n  CHECK(rhs->is_free);\n  CHECK(lhs->next == rhs);\n  CHECK(lhs == rhs->prev);\n  CHECK(lhs->ptr + lhs->size == rhs->ptr);\n\n  lhs->size += rhs->size;\n  lhs->next = rhs->next;\n  if (rhs->next != nullptr) { rhs->next->prev = lhs; }\n  UnMarkPiece(rhs);\n  DeallocatePiece(rhs);\n}\n\ntemplate<typename ThreadLock>\nMaybe<bool> BinAllocator<ThreadLock>::AllocateBlockToExtendTotalMem(size_t aligned_size) {\n  CHECK_OR_RETURN(IsAlignedSize(aligned_size, alignment_)) << \"not aligned\";\n\n  size_t allocate_bytes = aligned_size;\n  if (allocate_bytes < 1048576) {\n    // Allocate 2MB if `allocate_bytes` is less than 1MB\n    allocate_bytes = 2097152;\n  } else if (allocate_bytes < 10485760) {\n    // Allocate 20MB if `allocate_bytes` is between 1MB and 10MB\n    allocate_bytes = 20971520;\n  } else {\n    // Round up to 2MB if `allocate_bytes` is larger than 10MB\n    allocate_bytes = RoundUp(allocate_bytes, 2097152);\n  }\n  const size_t final_allocate_bytes = MemAlignedBytes(allocate_bytes, alignment_);\n\n  if (final_allocate_bytes < aligned_size) { return false; }\n\n  char* mem_ptr = nullptr;\n  JUST(backend_->Allocate(&mem_ptr, final_allocate_bytes));\n  if (mem_ptr == nullptr) { return false; }\n\n  // extend success\n  total_memory_bytes_ += final_allocate_bytes;\n\n  Piece* piece = AllocatePiece();\n  piece->size = final_allocate_bytes;\n  piece->ptr = mem_ptr;\n  piece->prev = nullptr;\n  piece->next = nullptr;\n  piece->is_free = true;\n  piece->bin_num = kInvalidBinNum;\n  InsertPiece2Bin(piece);\n  MarkPiece(piece);\n\n  CHECK_OR_RETURN(mem_ptr2block_.emplace(mem_ptr, Block(piece)).second) << \"mem_ptr already exists\";\n\n  return true;\n}\n\ntemplate<typename ThreadLock>\nbool BinAllocator<ThreadLock>::DeallocateFreeBlockForGarbageCollection() {\n  size_t total_free_bytes = 0;\n  HashSet<char*> free_block_ptrs;\n  for (const auto& pair : mem_ptr2block_) {\n    const Block& block = pair.second;\n    bool all_free = true;\n    Piece* p = block.start_piece;\n    while (p != nullptr) {\n      if (!(p->is_free)) {\n        all_free = false;\n        break;\n      }\n      p = p->next;\n    }\n\n    if (all_free) {\n      total_free_bytes += block.size;\n      free_block_ptrs.insert(pair.first);\n    }\n  }\n\n  total_memory_bytes_ -= total_free_bytes;\n\n  if (total_free_bytes > 0) {\n    VLOG(3) << \"BinAllocator tries to deallocate free blocks for garbage collection. 
\"\n            << \" deallocate free bytes : \" << total_free_bytes;\n    for (char* ptr : free_block_ptrs) {\n      auto it = mem_ptr2block_.find(ptr);\n      CHECK(it != mem_ptr2block_.end());\n      const Block& block = it->second;\n\n      // delete all Piece on Block\n      size_t piece_size_sum = 0;\n      Piece* p = block.start_piece;\n      CHECK_EQ(block.ptr, block.start_piece->ptr);\n      CHECK_EQ(block.ptr, ptr);\n      while (p != nullptr) {\n        Piece* next_p = p->next;\n        piece_size_sum += p->size;\n        RemovePieceFromBin(p);\n        UnMarkPiece(p);\n        DeallocatePiece(p);\n        p = next_p;\n      }\n      CHECK_EQ(block.size, piece_size_sum);\n\n      mem_ptr2block_.erase(it);\n      backend_->Deallocate(ptr, block.size);\n    }\n  }\n  return total_free_bytes > 0;\n}\n\ntemplate<typename ThreadLock>\nMaybe<void> BinAllocator<ThreadLock>::Allocate(char** mem_ptr, std::size_t size) {\n  typename ThreadLock::RAIIGuard guard(thread_lock_);\n  if (size == 0) {\n    *mem_ptr = nullptr;\n    return Maybe<void>::Ok();\n  }\n  size_t aligned_size = MemAlignedBytes(size, alignment_);\n\n  Piece* piece = FindPiece(aligned_size);\n\n  if (piece == nullptr) {\n    if (JUST(AllocateBlockToExtendTotalMem(aligned_size))) { piece = FindPiece(aligned_size); }\n  }\n\n  CHECK_NOTNULL_OR_RETURN(piece)\n      << Error::OutOfMemoryError() << \"Error! : Out of memory when allocate size : \" << size\n      << \".\\n The total_memory_bytes allocated by this BinAllocator is : \" << total_memory_bytes_;\n\n  if (piece == nullptr) {\n    backend_->DeviceReset();\n    LOG(FATAL) << \"Error! : Out of memory when allocate size : \" << size\n               << \".\\n The total_memory_bytes allocated by this BinAllocator is : \"\n               << total_memory_bytes_;\n  }\n  CHECK_NOTNULL_OR_RETURN(piece->ptr) << \"invalid piece null ptr\";\n  CHECK_OR_RETURN(ptr2piece_.find(piece->ptr) != ptr2piece_.end()) << \"piece is not found\";\n  *mem_ptr = piece->ptr;\n  return Maybe<void>::Ok();\n}\n\ntemplate<typename ThreadLock>\nvoid BinAllocator<ThreadLock>::Deallocate(char* mem_ptr, std::size_t size) {\n  if (mem_ptr == nullptr) { return; }\n  typename ThreadLock::RAIIGuard guard(thread_lock_);\n\n  auto it = ptr2piece_.find(mem_ptr);\n  CHECK(it != ptr2piece_.end()) << \"Error! : Try deallocate mem_ptr non-existent. mem ptr = \"\n                                << mem_ptr << \" size = \" << size;\n  Piece* piece = it->second;\n  CHECK_NOTNULL(piece);\n  CHECK_EQ(piece->ptr, mem_ptr);\n  CHECK(!piece->is_free);\n\n  piece->is_free = true;\n\n  Piece* last_piece_insert_to_bin = piece;\n  Piece* next_p = piece->next;\n  Piece* prev_p = piece->prev;\n\n  if (next_p != nullptr && next_p->is_free) {\n    CHECK_EQ(next_p->ptr, piece->ptr + piece->size);\n    RemovePieceFromBin(next_p);\n    MergeNeighbourFreePiece(piece, next_p);\n  }\n\n  if (prev_p != nullptr && prev_p->is_free) {\n    CHECK_EQ(piece->ptr, prev_p->ptr + prev_p->size);\n    RemovePieceFromBin(prev_p);\n    MergeNeighbourFreePiece(prev_p, piece);\n    last_piece_insert_to_bin = prev_p;\n  }\n  InsertPiece2Bin(last_piece_insert_to_bin);\n}\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_BIN_ALLOCATOR_H_\n"
  },
  {
    "path": "oneflow/core/vm/bin_allocator_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <memory>\n#ifdef WITH_CUDA\n#include \"gtest/gtest.h\"\n#include \"oneflow/core/vm/bin_allocator.h\"\n#include \"oneflow/core/vm/thread_safe_guard.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass CudaBackendAllocator final : public CachingAllocator {\n public:\n  explicit CudaBackendAllocator(int64_t device_id) : device_id_(device_id) {}\n  ~CudaBackendAllocator() override = default;\n\n  Maybe<void> Allocate(char** mem_ptr, std::size_t size) override;\n  void Deallocate(char* mem_ptr, std::size_t size) override;\n  void DeviceReset() override;\n  void Shrink() override{};\n\n private:\n  int64_t device_id_;\n};\n\nMaybe<void> CudaBackendAllocator::Allocate(char** mem_ptr, std::size_t size) {\n  cudaSetDevice(device_id_);\n  if (cudaMalloc(mem_ptr, size) != cudaSuccess) { *mem_ptr = nullptr; }\n  return Maybe<void>::Ok();\n}\n\nvoid CudaBackendAllocator::Deallocate(char* mem_ptr, std::size_t size) {\n  cudaSetDevice(device_id_);\n  OF_CUDA_CHECK(cudaFree(mem_ptr));\n}\n\nvoid CudaBackendAllocator::DeviceReset() {\n  cudaSetDevice(device_id_);\n  // NOTE(chengcheng): In some corner case on ubuntu, cuda memory not released even if OOM.\n  //   So there need release all cuda memory allocated by this process before core dump.\n  LOG(WARNING) << \"OOM error is detected, process will exit. 
It will start to reset the CUDA \"\n               << \"device to release device memory.\";\n  OF_CUDA_CHECK(cudaDeviceReset());\n}\n\nTEST(CudaBinAllocator, cuda_allocator) {\n  int gpu_num = -1;\n  cudaGetDeviceCount(&gpu_num);\n  if (gpu_num <= 0) {\n    LOG(INFO) << \"CudaBinAllocator Test: Skipped because there is no GPU device.\";\n    return;\n  }\n  ASSERT_TRUE(cudaSuccess == cudaSetDevice(0));\n  size_t free_bytes = -1;\n  size_t total_bytes = -1;\n  const size_t remain_bytes = 50 * 1048576;\n  ASSERT_TRUE(cudaSuccess == cudaMemGetInfo(&free_bytes, &total_bytes));\n  if (free_bytes <= remain_bytes || free_bytes - remain_bytes < remain_bytes) {\n    LOG(INFO)\n        << \"CudaBinAllocator Test: Skipped because allocatable memory on GPU 0 is under 50MiB\";\n    return;\n  }\n  std::unique_ptr<Allocator> allo(new BinAllocator<ThreadSafeLock>(\n      kCudaMemAllocAlignSize, std::make_unique<CudaBackendAllocator>(0)));\n  Allocator* a = allo.get();\n  std::vector<char*> ptrs;\n  for (int i = 0; i < 512; ++i) {\n    char* ptr = nullptr;\n    CHECK_JUST(a->Allocate(&ptr, 1));\n    ASSERT_TRUE(ptr != nullptr);\n    ptrs.emplace_back(ptr);\n  }\n  std::sort(ptrs.begin(), ptrs.end());\n  for (int i = 0; i < 512; ++i) {\n    if (i > 0) {\n      ASSERT_TRUE(ptrs.at(i) != ptrs.at(i - 1));\n      ASSERT_TRUE(std::abs(ptrs.at(i) - ptrs.at(i - 1)) >= kCudaMemAllocAlignSize);\n    }\n    a->Deallocate(ptrs.at(i), 1);\n  }\n\n  ptrs.clear();\n  for (int i = 0; i < 2048; ++i) {\n    char* ptr = nullptr;\n    CHECK_JUST(a->Allocate(&ptr, 10000));\n    ASSERT_TRUE(ptr != nullptr);\n    ptrs.emplace_back(ptr);\n  }\n  std::sort(ptrs.begin(), ptrs.end());\n  for (int i = 0; i < 2048; ++i) {\n    if (i > 0) {\n      ASSERT_TRUE(ptrs.at(i) != ptrs.at(i - 1));\n      ASSERT_TRUE(std::abs(ptrs.at(i) - ptrs.at(i - 1)) >= kCudaMemAllocAlignSize);\n    }\n    a->Deallocate(ptrs.at(i), 10000);\n  }\n\n  char* data_ptr_1 = nullptr;\n  CHECK_JUST(a->Allocate(&data_ptr_1, 2048 * sizeof(float)));\n\n  char* data_ptr_2 = nullptr;\n  CHECK_JUST(a->Allocate(&data_ptr_2, 4096 * sizeof(double)));\n\n  ASSERT_TRUE(data_ptr_1 != data_ptr_2);\n  if (data_ptr_1 < data_ptr_2) {\n    ASSERT_TRUE(data_ptr_1 + 2048 * sizeof(float) <= data_ptr_2);\n  } else {\n    ASSERT_TRUE(data_ptr_2 + 4096 * sizeof(double) <= data_ptr_1);\n  }\n\n  a->Deallocate(data_ptr_2, 4096 * sizeof(double));\n  a->Deallocate(data_ptr_1, 2048 * sizeof(float));\n}\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/core/vm/caching_allocator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_CACHING_ALLOCATOR_H_\n#define ONEFLOW_CORE_VM_CACHING_ALLOCATOR_H_\n\n#include <cstddef>\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/vm/allocator.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass CachingAllocator : public Allocator {\n public:\n  virtual ~CachingAllocator() = default;\n  virtual void Shrink() = 0;\n\n protected:\n  CachingAllocator() = default;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_CACHING_ALLOCATOR_H_\n"
  },
  {
    "path": "oneflow/core/vm/control_stream_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_CONTROL_STREAM_POLICY_H_\n#define ONEFLOW_CORE_VM_CONTROL_STREAM_POLICY_H_\n\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/vm/naive_instruction_status_querier.h\"\n#include \"oneflow/core/vm/stream_policy.h\"\n#include \"oneflow/core/vm/vm_object.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass ControlStreamPolicy final : public StreamPolicy {\n public:\n  ControlStreamPolicy() = default;\n  ~ControlStreamPolicy() = default;\n\n  vm::Allocator* mut_allocator() override { return (vm::Allocator*)nullptr; }\n\n  DeviceType device_type() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return DeviceType::kInvalidDevice;\n  }\n\n  ep::Stream* stream() override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return nullptr;\n  }\n\n  void InitInstructionStatus(const Stream& stream,\n                             InstructionStatusBuffer* status_buffer) const override {\n    static_assert(sizeof(NaiveInstrStatusQuerier) < kInstructionStatusBufferBytes, \"\");\n    NaiveInstrStatusQuerier::PlacementNew(status_buffer->mut_buffer());\n  }\n  void DeleteInstructionStatus(const Stream& stream,\n                               InstructionStatusBuffer* status_buffer) const override {\n    auto* ptr = NaiveInstrStatusQuerier::MutCast(status_buffer->mut_buffer());\n    ptr->~NaiveInstrStatusQuerier();\n  }\n  bool QueryInstructionStatusLaunched(const Stream& stream,\n                                      const InstructionStatusBuffer& status_buffer) const override {\n    return NaiveInstrStatusQuerier::Cast(status_buffer.buffer())->launched();\n  }\n  bool QueryInstructionStatusDone(const Stream& stream,\n                                  const InstructionStatusBuffer& status_buffer) const override {\n    return NaiveInstrStatusQuerier::Cast(status_buffer.buffer())->done();\n  }\n  void Run(Instruction* instruction) const override {\n    instruction->Compute();\n    auto* status_buffer = instruction->mut_status_buffer();\n    NaiveInstrStatusQuerier::MutCast(status_buffer->mut_buffer())->set_done();\n  }\n\n  bool OnSchedulerThread(StreamType) const override { return true; }\n  bool SupportingTransportInstructions() const override { return false; }\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_CONTROL_STREAM_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/critical_section_instruction_policy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/vm/critical_section_instruction_policy.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/device/ep_based_event_record.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/core/vm/vm_object.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nvoid CriticalSectionBeginInstructionPolicy::ForEachDependence(\n    const std::function<void(Dependence*)>& DoEach) const {\n  for (const auto& eager_blob_object : *eager_blob_objects_) {\n    DoEach(CHECK_JUST(eager_blob_object->compute_local_dep_object()));\n  }\n}\n\nvoid CriticalSectionBeginInstructionPolicy::ForEachMutDependence(\n    const std::function<void(Dependence*)>& DoEach) const {\n  DoEach(vm_stream_->schedule_local_dep_object().get());\n}\n\nvoid CriticalSectionBeginInstructionPolicy::FinishInvalidInterfaceEventRecords() {\n  for (const auto& op_name : interfaces_op_names()) {\n    size_t index = CHECK_JUST(MapAt(op_name2interface_index_, op_name));\n    if (!interfaces_valid().at(index)) {\n      const auto& iter = op_name2end_event_record_->find(op_name);\n      CHECK(iter != op_name2end_event_record_->end());\n      iter->second->Init(std::make_shared<NaiveEventRecord>());\n    }\n  }\n}\n\nvoid CriticalSectionBeginInstructionPolicy::Finish() {\n  for (const auto& pair : *op_name2end_event_record_) {\n    pair.second->TryInit(std::make_shared<NaiveEventRecord>());\n  }\n}\n\nvoid InputCriticalSectionBeginInstructionPolicy::AccessBlobByOpName(ep::Stream* stream, Blob* blob,\n                                                                    const std::string& op_name) {\n  int64_t i = CHECK_JUST(MapAt(op_name2interface_index_, op_name));\n  CHECK(interfaces_valid().at(i));\n  const auto& eager_blob_object = eager_blob_objects_->at(i);\n  {\n    size_t header_size = blob->blob_desc().ByteSizeOfBlobHeader();\n    CHECK_EQ(header_size, eager_blob_object->shape().NumAxes() * sizeof(int64_t));\n    CHECK_EQ(blob->static_shape(), eager_blob_object->shape());\n  }\n  const auto& end_event_record = op_name2end_event_record_->at(op_name);\n  if (eager_blob_object->dptr() == nullptr) {\n    end_event_record->Init(std::make_shared<NaiveEventRecord>());\n  } else {\n    {\n      const size_t body_bytes = blob->ByteSizeOfBlobBody();\n      CHECK_EQ(eager_blob_object->ByteSizeOfBlobBody(), body_bytes);\n      AutoMemcpy(stream, blob->mut_dptr(), eager_blob_object->dptr(), body_bytes, blob->mem_case(),\n                 eager_blob_object->mem_case());\n    }\n    end_event_record->Init(EpBasedEventRecord::MakeEventRecord(stream));\n  }\n}\n\nvoid OutputCriticalSectionBeginInstructionPolicy::AccessBlobByOpName(ep::Stream* stream, Blob* blob,\n                                                                     const std::string& op_name) {\n  
int64_t i = CHECK_JUST(MapAt(op_name2interface_index_, op_name));\n  CHECK(interfaces_valid().at(i));\n  auto& eager_blob_object = eager_blob_objects_->at(i);\n  CHECK_EQ(blob->static_shape(), eager_blob_object->shape());\n  const auto& end_event_record = op_name2end_event_record_->at(op_name);\n  if (eager_blob_object->dptr() == nullptr) {\n    end_event_record->Init(std::make_shared<NaiveEventRecord>());\n  } else {\n    {\n      const size_t body_bytes = blob->ByteSizeOfBlobBody();\n      CHECK_EQ(eager_blob_object->ByteSizeOfBlobBody(), body_bytes);\n      AutoMemcpy(stream, eager_blob_object->mut_dptr(), blob->dptr(), body_bytes,\n                 eager_blob_object->mem_case(), blob->mem_case());\n    }\n    end_event_record->Init(EpBasedEventRecord::MakeEventRecord(stream));\n  }\n}\n\nvoid CriticalSectionEndInstructionPolicy::ForEachDependence(\n    const std::function<void(vm::Dependence*)>& DoEach) const {\n  DoEach(CHECK_JUST(eager_blob_object_->compute_local_dep_object()));\n}\n\nvoid CriticalSectionEndInstructionPolicy::ForEachMutDependence(\n    const std::function<void(vm::Dependence*)>& DoEach) const {\n  DoEach(vm_stream_->schedule_local_dep_object().get());\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/critical_section_instruction_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_CRITICAL_SECTION_INSTRUCTION_POLICY_H_\n#define ONEFLOW_CORE_VM_CRITICAL_SECTION_INSTRUCTION_POLICY_H_\n\n#include \"oneflow/core/common/buffer_manager.h\"\n#include \"oneflow/core/device/event_record.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/framework/nn_graph_if.h\"\n#include \"oneflow/core/job/critical_section_instance.h\"\n#include \"oneflow/core/vm/critical_section_status_querier.h\"\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/vm/instruction_policy.h\"\n#include \"oneflow/core/vm/instruction_policy_util.h\"\n#include \"oneflow/core/vm/stream.h\"\n\nnamespace oneflow {\n\nnamespace vm {\n\nclass CriticalSectionBeginInstructionPolicy\n    : public InstructionPolicy,\n      public std::enable_shared_from_this<CriticalSectionBeginInstructionPolicy> {\n public:\n  CriticalSectionBeginInstructionPolicy(const CriticalSectionBeginInstructionPolicy&) = delete;\n  CriticalSectionBeginInstructionPolicy(CriticalSectionBeginInstructionPolicy&&) = delete;\n  CriticalSectionBeginInstructionPolicy& operator=(const CriticalSectionBeginInstructionPolicy&) =\n      delete;\n  CriticalSectionBeginInstructionPolicy& operator=(CriticalSectionBeginInstructionPolicy&&) =\n      delete;\n  virtual ~CriticalSectionBeginInstructionPolicy() = default;\n  explicit CriticalSectionBeginInstructionPolicy(\n      const std::shared_ptr<NNGraphIf>& nn_graph, const EagerBlobObjectListPtr& eager_blob_objects,\n      const std::shared_ptr<HashMap<std::string, std::shared_ptr<SharedEventRecord>>>&\n          op_name2end_event_record,\n      Stream* vm_stream)\n      : nn_graph_(nn_graph),\n        eager_blob_objects_(eager_blob_objects),\n        op_name2end_event_record_(op_name2end_event_record),\n        vm_stream_(vm_stream) {}\n\n  std::string DebugName(const Instruction& instruction) const override {\n    return \"CriticalSectionBegin\";\n  }\n  Maybe<void> Prepare(Instruction* instruction) override { return Maybe<void>::Ok(); }\n  void Compute(vm::Instruction* instruction) override {\n    OF_PROFILER_RANGE_GUARD(\"CriticalSectionBegin\");\n    {\n      const auto& critical_section_instance = MakeCriticalSectionInstance();\n      const auto& job_name = critical_section_instance->job_name();\n      auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::Get();\n      for (int i = 0; i < interfaces_op_names().size(); ++i) {\n        if (interfaces_valid().at(i)) {\n          const std::string& interface_op_name = interfaces_op_names().at(i);\n          const auto& buffer_name = GetInterfaceBufferName(job_name, interface_op_name);\n          buffer_mgr->Get(buffer_name)->Push(critical_section_instance);\n        }\n      }\n      const auto& callback_buffer_name = GetInterfaceCriticalSectionCallbackBufferName(job_name);\n      
buffer_mgr->Get(callback_buffer_name)->Push(critical_section_instance);\n      const auto& wait_buffer_name = GetInterfaceCriticalSectionWaitBufferName(job_name);\n      buffer_mgr->Get(wait_buffer_name)->Push(critical_section_instance);\n    }\n    {\n      auto* status_buffer_data = instruction->mut_status_buffer()->mut_buffer();\n      auto* status_querier = CriticalSectionStatusQuerier::MutCast(status_buffer_data);\n      status_querier->SetLaunched(std::make_shared<NaiveEventRecord>());\n    }\n  }\n  const std::shared_ptr<NNGraphIf>& nn_graph() const { return nn_graph_; }\n  const EagerBlobObjectListPtr& eager_blob_objects() const { return eager_blob_objects_; }\n\n  void ForEachDependence(const std::function<void(Dependence* compute)>&) const;\n\n  void ForEachMutDependence(const std::function<void(Dependence* compute)>&) const;\n\n  virtual const std::vector<std::string>& interfaces_op_names() const = 0;\n  virtual const std::vector<bool>& interfaces_valid() const = 0;\n  virtual std::string GetInterfaceBufferName(const std::string& job_name,\n                                             const std::string& op_name) const = 0;\n  virtual std::string GetInterfaceCriticalSectionCallbackBufferName(\n      const std::string& job_name) const = 0;\n  virtual std::string GetInterfaceCriticalSectionWaitBufferName(\n      const std::string& job_name) const = 0;\n  virtual void AccessBlobByOpName(ep::Stream* stream, Blob* blob, const std::string& op_name) = 0;\n\n  void FinishInvalidInterfaceEventRecords();\n  void Finish();\n\n protected:\n  std::shared_ptr<NNGraphIf> nn_graph_;\n  EagerBlobObjectListPtr eager_blob_objects_;\n  std::shared_ptr<HashMap<std::string, std::shared_ptr<SharedEventRecord>>>\n      op_name2end_event_record_;\n  HashMap<std::string, size_t> op_name2interface_index_;\n  Stream* vm_stream_;\n\n private:\n  class NaiveCriticalSectionInstance final : public CriticalSectionInstance {\n   public:\n    NaiveCriticalSectionInstance(const std::shared_ptr<CriticalSectionBeginInstructionPolicy>&\n                                     critical_section_begin_instruction_policy,\n                                 const std::string& job_name)\n        : CriticalSectionInstance(),\n          critical_section_begin_instruction_policy_(critical_section_begin_instruction_policy),\n          job_name_(job_name) {}\n\n    ~NaiveCriticalSectionInstance() override = default;\n\n    const std::string& job_name() const override { return job_name_; }\n\n    void AccessBlobByOpName(ep::Stream* stream, Blob* blob,\n                            const std::string& op_name) const override {\n      critical_section_begin_instruction_policy_->AccessBlobByOpName(stream, blob, op_name);\n    }\n    void Finish() const override { critical_section_begin_instruction_policy_->Finish(); }\n\n   private:\n    std::shared_ptr<CriticalSectionBeginInstructionPolicy>\n        critical_section_begin_instruction_policy_;\n    std::string job_name_;\n  };\n\n  std::shared_ptr<CriticalSectionInstance> MakeCriticalSectionInstance() {\n    return std::make_shared<NaiveCriticalSectionInstance>(this->shared_from_this(),\n                                                          nn_graph_->job_name());\n  }\n};\n\nclass InputCriticalSectionBeginInstructionPolicy final\n    : public CriticalSectionBeginInstructionPolicy {\n public:\n  InputCriticalSectionBeginInstructionPolicy(\n      const std::shared_ptr<NNGraphIf>& nn_graph, const EagerBlobObjectListPtr& eager_blob_objects,\n      const 
std::shared_ptr<HashMap<std::string, std::shared_ptr<SharedEventRecord>>>&\n          op_name2end_event_record,\n      Stream* vm_stream)\n      : CriticalSectionBeginInstructionPolicy(nn_graph, eager_blob_objects,\n                                              op_name2end_event_record, vm_stream),\n        input_dependences_(),\n        output_dependences_() {\n    ForEachConstDependence(InstructionPolicyUtil::SetInserter(&input_dependences_));\n    ForEachMutDependence(InstructionPolicyUtil::SetInserter(&output_dependences_));\n    ForEachMut2Dependence(InstructionPolicyUtil::SetInserter(&output_dependences_));\n    CHECK_EQ(nn_graph->inputs_op_names().size(), eager_blob_objects->size());\n    CHECK_EQ(nn_graph->inputs_op_names().size(), nn_graph->inputs_valid().size());\n    for (int i = 0; i < nn_graph->inputs_op_names().size(); ++i) {\n      CHECK(op_name2interface_index_.emplace(nn_graph->inputs_op_names().at(i), i).second);\n    }\n  }\n\n  ~InputCriticalSectionBeginInstructionPolicy() override = default;\n\n  const DependenceVector& input_dependences() const override { return input_dependences_; }\n  const DependenceVector& output_dependences() const override { return output_dependences_; }\n\n  // for inputs\n  void ForEachConstDependence(const std::function<void(Dependence* compute)>& DoEach) const {\n    ForEachDependence(DoEach);\n  }\n\n  // for outputs\n  const std::vector<std::string>& interfaces_op_names() const override {\n    return nn_graph_->inputs_op_names();\n  }\n  const std::vector<bool>& interfaces_valid() const override { return nn_graph_->inputs_valid(); }\n  std::string GetInterfaceBufferName(const std::string& job_name,\n                                     const std::string& op_name) const override {\n    return GetInputBufferName(job_name, op_name);\n  }\n  std::string GetInterfaceCriticalSectionCallbackBufferName(\n      const std::string& job_name) const override {\n    return GetInputCriticalSectionCallbackBufferName(job_name);\n  }\n  std::string GetInterfaceCriticalSectionWaitBufferName(\n      const std::string& job_name) const override {\n    return GetInputCriticalSectionWaitBufferName(job_name);\n  }\n  void AccessBlobByOpName(ep::Stream* stream, Blob* blob, const std::string& op_name) override;\n  void ForEachMut2Dependence(const std::function<void(Dependence* compute)>&) const {}\n\n private:\n  DependenceVector input_dependences_;\n  DependenceVector output_dependences_;\n};\n\nclass OutputCriticalSectionBeginInstructionPolicy final\n    : public CriticalSectionBeginInstructionPolicy {\n public:\n  OutputCriticalSectionBeginInstructionPolicy(\n      const std::shared_ptr<NNGraphIf>& nn_graph, const EagerBlobObjectListPtr& eager_blob_objects,\n      const std::shared_ptr<HashMap<std::string, std::shared_ptr<SharedEventRecord>>>&\n          op_name2end_event_record,\n      Stream* vm_stream)\n      : CriticalSectionBeginInstructionPolicy(nn_graph, eager_blob_objects,\n                                              op_name2end_event_record, vm_stream),\n        input_dependences_(),\n        output_dependences_() {\n    ForEachConstDependence(InstructionPolicyUtil::SetInserter(&input_dependences_));\n    ForEachMutDependence(InstructionPolicyUtil::SetInserter(&output_dependences_));\n    ForEachMut2Dependence(InstructionPolicyUtil::SetInserter(&output_dependences_));\n    CHECK_EQ(nn_graph->outputs_op_names().size(), eager_blob_objects->size());\n    CHECK_EQ(nn_graph->outputs_op_names().size(), nn_graph->outputs_valid().size());\n    for (int i = 0; 
i < nn_graph->outputs_op_names().size(); ++i) {\n      CHECK(op_name2interface_index_.emplace(nn_graph->outputs_op_names().at(i), i).second);\n    }\n  }\n\n  ~OutputCriticalSectionBeginInstructionPolicy() override = default;\n\n  const DependenceVector& input_dependences() const override { return input_dependences_; }\n  const DependenceVector& output_dependences() const override { return output_dependences_; }\n\n  // for inputs\n  void ForEachConstDependence(const std::function<void(Dependence* compute)>&) const {}\n\n  // for outputs\n  void ForEachMut2Dependence(const std::function<void(Dependence* compute)>& DoEach) const {\n    ForEachDependence(DoEach);\n  }\n\n  const std::vector<std::string>& interfaces_op_names() const override {\n    return nn_graph_->outputs_op_names();\n  }\n  const std::vector<bool>& interfaces_valid() const override { return nn_graph_->outputs_valid(); }\n  std::string GetInterfaceBufferName(const std::string& job_name,\n                                     const std::string& op_name) const override {\n    return GetOutputBufferName(job_name, op_name);\n  }\n  std::string GetInterfaceCriticalSectionCallbackBufferName(\n      const std::string& job_name) const override {\n    return GetOutputCriticalSectionCallbackBufferName(job_name);\n  }\n  std::string GetInterfaceCriticalSectionWaitBufferName(\n      const std::string& job_name) const override {\n    return GetOutputCriticalSectionWaitBufferName(job_name);\n  }\n  void AccessBlobByOpName(ep::Stream* stream, Blob* blob, const std::string& op_name) override;\n\n private:\n  DependenceVector input_dependences_;\n  DependenceVector output_dependences_;\n};\n\nclass CriticalSectionEndInstructionPolicy : public InstructionPolicy {\n public:\n  CriticalSectionEndInstructionPolicy(const CriticalSectionEndInstructionPolicy&) = delete;\n  CriticalSectionEndInstructionPolicy(CriticalSectionEndInstructionPolicy&&) = delete;\n  CriticalSectionEndInstructionPolicy& operator=(const CriticalSectionEndInstructionPolicy&) =\n      delete;\n  CriticalSectionEndInstructionPolicy& operator=(CriticalSectionEndInstructionPolicy&&) = delete;\n  CriticalSectionEndInstructionPolicy(const std::shared_ptr<EagerBlobObject>& eager_blob_object,\n                                      const std::shared_ptr<SharedEventRecord>& event_record,\n                                      vm::Stream* vm_stream)\n      : eager_blob_object_(eager_blob_object), event_record_(event_record), vm_stream_(vm_stream) {}\n  virtual ~CriticalSectionEndInstructionPolicy() = default;\n\n  std::string DebugName(const Instruction& instruction) const override {\n    return \"CriticalSectionEnd\";\n  }\n  Maybe<void> Prepare(Instruction* instruction) override { return Maybe<void>::Ok(); }\n  void Compute(Instruction* instruction) override {\n    auto* status_buffer_data = instruction->mut_status_buffer()->mut_buffer();\n    auto* status_querier = CriticalSectionStatusQuerier::MutCast(status_buffer_data);\n    status_querier->SetLaunched(event_record());\n  }\n  const std::shared_ptr<SharedEventRecord>& event_record() const { return event_record_; }\n\n  void ForEachDependence(const std::function<void(vm::Dependence* compute)>&) const;\n\n  void ForEachMutDependence(const std::function<void(vm::Dependence* compute)>&) const;\n\n private:\n  std::shared_ptr<EagerBlobObject> eager_blob_object_;\n  std::shared_ptr<SharedEventRecord> event_record_;\n  vm::Stream* vm_stream_;\n};\n\nclass InputCriticalSectionEndInstructionPolicy final : public 
CriticalSectionEndInstructionPolicy {\n public:\n  InputCriticalSectionEndInstructionPolicy(\n      const std::shared_ptr<EagerBlobObject>& eager_blob_object,\n      const std::shared_ptr<SharedEventRecord>& event_record, vm::Stream* vm_stream)\n      : CriticalSectionEndInstructionPolicy(eager_blob_object, event_record, vm_stream),\n        input_dependences_(),\n        output_dependences_() {\n    ForEachConstDependence(InstructionPolicyUtil::SetInserter(&input_dependences_));\n    ForEachMutDependence(InstructionPolicyUtil::SetInserter(&output_dependences_));\n    ForEachMut2Dependence(InstructionPolicyUtil::SetInserter(&output_dependences_));\n  }\n  ~InputCriticalSectionEndInstructionPolicy() override = default;\n\n  const DependenceVector& input_dependences() const override { return input_dependences_; }\n  const DependenceVector& output_dependences() const override { return output_dependences_; }\n\n  void ForEachConstDependence(const std::function<void(vm::Dependence* compute)>& DoEach) const {\n    ForEachDependence(DoEach);\n  }\n\n  void ForEachMut2Dependence(const std::function<void(vm::Dependence* compute)>&) const {}\n\n private:\n  DependenceVector input_dependences_;\n  DependenceVector output_dependences_;\n};\n\nclass OutputCriticalSectionEndInstructionPolicy final : public CriticalSectionEndInstructionPolicy {\n public:\n  OutputCriticalSectionEndInstructionPolicy(\n      const std::shared_ptr<EagerBlobObject>& eager_blob_object,\n      const std::shared_ptr<SharedEventRecord>& event_record, vm::Stream* vm_stream)\n      : CriticalSectionEndInstructionPolicy(eager_blob_object, event_record, vm_stream),\n        input_dependences_(),\n        output_dependences_() {\n    ForEachConstDependence(InstructionPolicyUtil::SetInserter(&input_dependences_));\n    ForEachMutDependence(InstructionPolicyUtil::SetInserter(&output_dependences_));\n    ForEachMut2Dependence(InstructionPolicyUtil::SetInserter(&output_dependences_));\n  }\n  ~OutputCriticalSectionEndInstructionPolicy() override = default;\n\n  const DependenceVector& input_dependences() const override { return input_dependences_; }\n  const DependenceVector& output_dependences() const override { return output_dependences_; }\n\n  // for inputs\n  void ForEachConstDependence(const std::function<void(vm::Dependence* compute)>&) const {}\n\n  // for outputs\n  void ForEachMut2Dependence(const std::function<void(vm::Dependence* compute)>& DoEach) const {\n    ForEachDependence(DoEach);\n  }\n\n private:\n  DependenceVector input_dependences_;\n  DependenceVector output_dependences_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n#endif  // ONEFLOW_CORE_VM_CRITICAL_SECTION_INSTRUCTION_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/critical_section_status_querier.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_CRITICAL_SECTION_QUERIER_H_\n#define ONEFLOW_CORE_VM_CRITICAL_SECTION_QUERIER_H_\n\n#include <atomic>\n#include <memory>\n#include \"oneflow/core/device/event_record.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass CriticalSectionStatusQuerier final {\n public:\n  ~CriticalSectionStatusQuerier() = default;\n\n  bool QueryLaunched() const { return launched_; }\n  bool QueryDone() const { return launched_ && event_record_->QueryDone(); }\n\n  void SetLaunched(const std::shared_ptr<EventRecord>& event_record) {\n    // No lock needed. This function will be called only one time.\n    // In most cases, errors will be successfully detected by CHECK\n    // even though run in different threads.\n    CHECK(!launched_);\n    event_record_ = event_record;\n    launched_ = true;\n  }\n\n  static const CriticalSectionStatusQuerier* Cast(const char* mem_ptr) {\n    return reinterpret_cast<const CriticalSectionStatusQuerier*>(mem_ptr);\n  }\n  static CriticalSectionStatusQuerier* MutCast(char* mem_ptr) {\n    return reinterpret_cast<CriticalSectionStatusQuerier*>(mem_ptr);\n  }\n  static CriticalSectionStatusQuerier* PlacementNew(char* mem_ptr) {\n    return new (mem_ptr) CriticalSectionStatusQuerier();\n  }\n\n private:\n  explicit CriticalSectionStatusQuerier() : launched_(false) {}\n\n  std::atomic<bool> launched_;\n  std::shared_ptr<EventRecord> event_record_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_CRITICAL_SECTION_QUERIER_H_\n"
  },
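  {
    "path": "oneflow/core/vm/critical_section_status_querier_usage_sketch.cpp",
    "content": "// NOTE: Editor-added illustrative sketch, not part of the original OneFlow sources.\n// It walks the intended lifecycle of CriticalSectionStatusQuerier over a raw instruction\n// status buffer; the function name and its arguments are hypothetical.\n#include <memory>\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/device/event_record.h\"\n#include \"oneflow/core/vm/critical_section_status_querier.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nvoid CriticalSectionStatusQuerierUsageSketch(char* status_buffer,\n                                             const std::shared_ptr<EventRecord>& event_record) {\n  // Construct the querier in-place inside the status buffer.\n  auto* querier = CriticalSectionStatusQuerier::PlacementNew(status_buffer);\n  CHECK(!querier->QueryLaunched());  // not launched until SetLaunched() runs\n  // SetLaunched() may be called exactly once; its internal CHECK guards misuse.\n  querier->SetLaunched(event_record);\n  // Done once launched and the underlying event record reports completion.\n  if (querier->QueryDone()) { /* the instruction may be retired */ }\n  // Mirror PlacementNew with an explicit destructor call.\n  querier->~CriticalSectionStatusQuerier();\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },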
  {
    "path": "oneflow/core/vm/critical_section_stream_policy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/vm/critical_section_stream_policy.h\"\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/vm/thread_ctx.h\"\n#include \"oneflow/core/vm/critical_section_status_querier.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nvoid CriticalSectionStreamPolicy::InitInstructionStatus(\n    const Stream& stream, InstructionStatusBuffer* status_buffer) const {\n  static_assert(sizeof(CriticalSectionStatusQuerier) < kInstructionStatusBufferBytes, \"\");\n  CriticalSectionStatusQuerier::PlacementNew(status_buffer->mut_buffer());\n}\n\nvoid CriticalSectionStreamPolicy::DeleteInstructionStatus(\n    const Stream& stream, InstructionStatusBuffer* status_buffer) const {\n  auto* ptr = CriticalSectionStatusQuerier::MutCast(status_buffer->mut_buffer());\n  ptr->~CriticalSectionStatusQuerier();\n}\n\nbool CriticalSectionStreamPolicy::QueryInstructionStatusLaunched(\n    const Stream& stream, const InstructionStatusBuffer& status_buffer) const {\n  return CriticalSectionStatusQuerier::Cast(status_buffer.buffer())->QueryLaunched();\n}\n\nbool CriticalSectionStreamPolicy::QueryInstructionStatusDone(\n    const Stream& stream, const InstructionStatusBuffer& status_buffer) const {\n  return CriticalSectionStatusQuerier::Cast(status_buffer.buffer())->QueryDone();\n}\n\nvoid CriticalSectionStreamPolicy::Run(Instruction* instruction) const { instruction->Compute(); }\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/critical_section_stream_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_VM_CRITICAL_SECTION_STREAM_POLICY_H_\n#define ONEFLOW_CORE_VM_CRITICAL_SECTION_STREAM_POLICY_H_\n\n#include \"oneflow/core/vm/stream_policy.h\"\n#include \"oneflow/core/vm/instruction.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass CriticalSectionStreamPolicy final : public StreamPolicy {\n public:\n  CriticalSectionStreamPolicy() = default;\n  virtual ~CriticalSectionStreamPolicy() = default;\n\n  vm::Allocator* mut_allocator() override { return (vm::Allocator*)nullptr; }\n\n  DeviceType device_type() const override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return DeviceType::kInvalidDevice;\n  }\n\n  ep::Stream* stream() override {\n    PRINT_BUG_PROMPT_AND_ABORT();\n    return nullptr;\n  }\n\n  void InitInstructionStatus(const Stream& stream,\n                             InstructionStatusBuffer* status_buffer) const override;\n  void DeleteInstructionStatus(const Stream& stream,\n                               InstructionStatusBuffer* status_buffer) const override;\n  bool QueryInstructionStatusLaunched(const Stream& stream,\n                                      const InstructionStatusBuffer& status_buffer) const override;\n  bool QueryInstructionStatusDone(const Stream& stream,\n                                  const InstructionStatusBuffer& status_buffer) const override;\n  void Run(Instruction* instruction) const override;\n  bool SupportingTransportInstructions() const override { return false; }\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_CRITICAL_SECTION_STREAM_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/ep_backend_allocator.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/ep_backend_allocator.h\"\n#include \"oneflow/core/ep/include/device.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nMaybe<void> EpBackendAllocator::Allocate(char** mem_ptr, std::size_t size) {\n  return ep_device_->Alloc(allocation_options_, reinterpret_cast<void**>(mem_ptr), size);\n}\n\nvoid EpBackendAllocator::Deallocate(char* mem_ptr, std::size_t size) {\n  ep_device_->Free(allocation_options_, mem_ptr);\n}\n\nvoid EpBackendAllocator::DeviceReset() {\n  if (ep_device_->device_type() != DeviceType::kCPU) {\n    // NOTE(chengcheng): In some corner case on ubuntu, cuda memory not released even if OOM.\n    //   So there need release all cuda memory allocated by this process before core dump.\n    LOG(WARNING) << \"OOM error is detected, process will exit. And it will start to reset \"\n                 << \"device for releasing device memory.\";\n    ep_device_->Reset();\n  }\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
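  {
    "path": "oneflow/core/vm/ep_backend_allocator_usage_sketch.cpp",
    "content": "// NOTE: Editor-added illustrative sketch, not part of the original OneFlow sources.\n// It allocates and frees device memory through EpBackendAllocator; the function name and\n// the fixed size are hypothetical, and the caller is assumed to pass a valid ep::Device.\n#include <memory>\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/ep/include/device.h\"\n#include \"oneflow/core/vm/ep_backend_allocator.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nMaybe<void> EpBackendAllocatorUsageSketch(const std::shared_ptr<ep::Device>& ep_device) {\n  EpBackendAllocator allocator(ep_device, ep::AllocationOptions{});\n  char* mem_ptr = nullptr;\n  constexpr std::size_t kSize = 1024;\n  JUST(allocator.Allocate(&mem_ptr, kSize));  // forwards to ep::Device::Alloc\n  allocator.Deallocate(mem_ptr, kSize);       // forwards to ep::Device::Free\n  return Maybe<void>::Ok();\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },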
  {
    "path": "oneflow/core/vm/ep_backend_allocator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_CUDA_BACKEND_ALLOCATOR_H_\n#define ONEFLOW_CORE_VM_CUDA_BACKEND_ALLOCATOR_H_\n\n#include <cstdint>\n#include \"oneflow/core/vm/allocator.h\"\n#include \"oneflow/core/ep/include/allocation_options.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass Device;\n\n}\n\nnamespace vm {\n\nclass EpBackendAllocator final : public Allocator {\n public:\n  explicit EpBackendAllocator(const std::shared_ptr<ep::Device>& ep_device,\n                              const ep::AllocationOptions& allocation_options)\n      : ep_device_(ep_device), allocation_options_(allocation_options) {}\n  ~EpBackendAllocator() override = default;\n\n  Maybe<void> Allocate(char** mem_ptr, std::size_t size) override;\n  void Deallocate(char* mem_ptr, std::size_t size) override;\n  void DeviceReset() override;\n\n private:\n  std::shared_ptr<ep::Device> ep_device_;\n  ep::AllocationOptions allocation_options_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_CUDA_BACKEND_ALLOCATOR_H_\n"
  },
  {
    "path": "oneflow/core/vm/ep_backend_host_allocator.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/ep_backend_host_allocator.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/ep/include/device.h\"\n\nnamespace oneflow {\n\nnamespace vm {\n\nMaybe<void> EpBackendHostAllocator::Allocate(char** mem_ptr, std::size_t size) {\n  JUST(ep_device_->AllocPinned(allocation_options_, reinterpret_cast<void**>(mem_ptr), size));\n  return Maybe<void>::Ok();\n}\n\nvoid EpBackendHostAllocator::Deallocate(char* mem_ptr, std::size_t size) {\n  ep_device_->FreePinned(allocation_options_, mem_ptr);\n}\n\n}  // namespace vm\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/ep_backend_host_allocator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_CUDA_BACKEND_HOST_ALLOCATOR_H_\n#define ONEFLOW_CORE_VM_CUDA_BACKEND_HOST_ALLOCATOR_H_\n\n#include <cstdint>\n#include \"oneflow/core/vm/allocator.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/ep/include/allocation_options.h\"\n\nnamespace oneflow {\n\nnamespace ep {\n\nclass Device;\n\n}\n\nnamespace vm {\n\nclass EpBackendHostAllocator final : public Allocator {\n public:\n  explicit EpBackendHostAllocator(const std::shared_ptr<ep::Device>& ep_device,\n                                  const ep::AllocationOptions& allocation_options)\n      : ep_device_(ep_device), allocation_options_(allocation_options) {}\n  ~EpBackendHostAllocator() override = default;\n\n  Maybe<void> Allocate(char** mem_ptr, std::size_t size) override;\n  void Deallocate(char* mem_ptr, std::size_t size) override;\n  void DeviceReset() override {}\n\n private:\n  std::shared_ptr<ep::Device> ep_device_;\n  ep::AllocationOptions allocation_options_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_CUDA_BACKEND_HOST_ALLOCATOR_H_\n"
  },
  {
    "path": "oneflow/core/vm/ep_d2h_stream_policy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/vm/ep_d2h_stream_policy.h\"\n#include <memory>\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/core/vm/thread_ctx.h\"\n#include \"oneflow/core/vm/ep_optional_event_record_status_querier.h\"\n#include \"oneflow/core/vm/ep_backend_host_allocator.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nnamespace {\n\nstd::unique_ptr<BinAllocator<ThreadSafeLock>> CreateEpBackendHostAllocator(Symbol<Device> device) {\n  DeviceType device_type = device->enum_type();\n  size_t device_index = device->device_id();\n  auto ep_device =\n      Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(device_type, device_index);\n  auto ep_backend_allocator =\n      std::make_unique<EpBackendHostAllocator>(ep_device, ep::AllocationOptions{});\n  return std::make_unique<BinAllocator<ThreadSafeLock>>(ep::kMaxAlignmentRequirement,\n                                                        std::move(ep_backend_allocator));\n}\n\n}  // namespace\n\nEpD2HStreamPolicy::EpD2HStreamPolicy(Symbol<Device> device)\n    : EpStreamPolicyBase(device, CreateEpBackendHostAllocator(device)) {}\n\nvoid EpD2HStreamPolicy::InitInstructionStatus(const Stream& stream,\n                                              InstructionStatusBuffer* status_buffer) const {\n  static_assert(sizeof(EpOptionalEventRecordStatusQuerier) < kInstructionStatusBufferBytes, \"\");\n  EpStreamPolicyBase* ep_stream_policy_base =\n      dynamic_cast<EpStreamPolicyBase*>(const_cast<Stream&>(stream).mut_stream_policy());\n  CHECK_NOTNULL(ep_stream_policy_base);\n  auto* ep_event_provider = ep_stream_policy_base->ep_event_provider();\n  auto* data_ptr = status_buffer->mut_buffer();\n  const auto& ep_event = CHECK_NOTNULL(ep_event_provider)->GetReusedEpEvent();\n  EpOptionalEventRecordStatusQuerier::PlacementNew(data_ptr, ep_event);\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/ep_d2h_stream_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_EP_D2H_STREAM_POLICY_H_\n#define ONEFLOW_CORE_VM_EP_D2H_STREAM_POLICY_H_\n\n#include \"oneflow/core/vm/ep_stream_policy_base.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass EpD2HStreamPolicy final : public EpStreamPolicyBase {\n public:\n  EpD2HStreamPolicy(Symbol<Device> device);\n  ~EpD2HStreamPolicy() override = default;\n\n  void InitInstructionStatus(const Stream& stream,\n                             InstructionStatusBuffer* status_buffer) const override;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_EP_D2H_STREAM_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/ep_event.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/ep_event.h\"\n\nnamespace oneflow {\n\nEpEvent::EpEvent(ep::Device* device) : device_(device), event_(nullptr) {\n  device_->SetAsActiveDevice();\n  event_ = device_->CreateEvent();  // NOLINT\n}\n\nEpEvent::~EpEvent() {\n  device_->SetAsActiveDevice();\n  device_->DestroyEvent(event_);\n}\n\nbool EpEvent::Query() const {\n  device_->SetAsActiveDevice();\n  return CHECK_JUST(event_->QueryDone());\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/ep_event.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_EP_EVENT_H_\n#define ONEFLOW_CORE_VM_EP_EVENT_H_\n\n#include \"oneflow/core/ep/include/device.h\"\n#include \"oneflow/core/ep/include/event.h\"\n#include \"oneflow/core/common/single_thread_obj_pool.h\"\n\nnamespace oneflow {\n\nclass EpEvent final {\n public:\n  EpEvent(const EpEvent&) = delete;\n  EpEvent(EpEvent&&) = delete;\n\n  EpEvent(ep::Device* device);\n  ~EpEvent();\n\n  bool Query() const;\n\n  ep::Device* mut_device() { return device_; }\n\n  ep::Event* mut_event() { return event_; }\n\n private:\n  ep::Device* device_;\n  ep::Event* event_;\n};\n\nclass EpEventProvider {\n public:\n  EpEventProvider(const EpEventProvider&) = delete;\n  EpEventProvider(EpEventProvider&&) = delete;\n  virtual ~EpEventProvider() = default;\n\n  virtual std::shared_ptr<EpEvent> GetReusedEpEvent() = 0;\n\n protected:\n  EpEventProvider() = default;\n};\n\nclass SingleThreadEpEventProvider final : public EpEventProvider {\n public:\n  SingleThreadEpEventProvider(const SingleThreadEpEventProvider&) = delete;\n  SingleThreadEpEventProvider(SingleThreadEpEventProvider&&) = delete;\n  explicit SingleThreadEpEventProvider(ep::Device* device)\n      : EpEventProvider(), events_(new SingleThreadPoolType()), device_(device) {}\n  ~SingleThreadEpEventProvider() = default;\n\n  std::shared_ptr<EpEvent> GetReusedEpEvent() override { return events_->make_shared(device_); }\n\n private:\n  using SingleThreadPoolType =\n      obj_pool::SingleThreadObjPool<EpEvent, obj_pool::kDisableReconstruct>;\n  std::shared_ptr<SingleThreadPoolType> events_;\n  ep::Device* device_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_EP_EVENT_H_\n"
  },
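  {
    "path": "oneflow/core/vm/ep_event_usage_sketch.cpp",
    "content": "// NOTE: Editor-added illustrative sketch, not part of the original OneFlow sources.\n// It obtains a pooled EpEvent from SingleThreadEpEventProvider and polls it; the function\n// name is hypothetical and `device` is assumed to outlive the provider.\n#include <memory>\n#include \"oneflow/core/vm/ep_event.h\"\n\nnamespace oneflow {\n\nvoid EpEventProviderUsageSketch(ep::Device* device) {\n  SingleThreadEpEventProvider provider(device);\n  std::shared_ptr<EpEvent> event = provider.GetReusedEpEvent();\n  // Query() activates the device and asks the backend whether the event has fired;\n  // real code polls from the VM scheduler instead of spinning like this.\n  while (!event->Query()) {}\n}\n\n}  // namespace oneflow\n"
  },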
  {
    "path": "oneflow/core/vm/ep_optional_event_record_status_querier.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/ep_optional_event_record_status_querier.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nvoid EpOptionalEventRecordStatusQuerier::SetLaunched(ep::Stream* stream) {\n  CHECK(!launched_);\n  if (ep_event_) {\n    ep_event_->mut_device()->SetAsActiveDevice();\n    stream->RecordEvent(ep_event_->mut_event());\n  }\n  launched_ = true;\n}\n\nEpOptionalEventRecordStatusQuerier::~EpOptionalEventRecordStatusQuerier() {}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/ep_optional_event_record_status_querier.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_EP_OPTIONAL_EVENT_RECORD_STATUS_QUERIER_H_\n#define ONEFLOW_CORE_VM_EP_OPTIONAL_EVENT_RECORD_STATUS_QUERIER_H_\n\n#include <atomic>\n#include \"oneflow/core/vm/ep_event.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass EpOptionalEventRecordStatusQuerier {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(EpOptionalEventRecordStatusQuerier);\n  ~EpOptionalEventRecordStatusQuerier();\n\n  bool launched() const { return launched_; }\n\n  bool done() const { return launched_ && (ep_event_ == nullptr || ep_event_->Query()); }\n\n  void SetLaunched(ep::Stream* stream);\n\n  void reset_ep_event(const std::shared_ptr<EpEvent>& ep_event) { ep_event_ = ep_event; }\n\n  static const EpOptionalEventRecordStatusQuerier* Cast(const char* mem_ptr) {\n    return reinterpret_cast<const EpOptionalEventRecordStatusQuerier*>(mem_ptr);\n  }\n  static EpOptionalEventRecordStatusQuerier* MutCast(char* mem_ptr) {\n    return reinterpret_cast<EpOptionalEventRecordStatusQuerier*>(mem_ptr);\n  }\n  static EpOptionalEventRecordStatusQuerier* PlacementNew(\n      char* mem_ptr, const std::shared_ptr<EpEvent>& ep_event) {\n    return new (mem_ptr) EpOptionalEventRecordStatusQuerier(ep_event);\n  }\n\n private:\n  explicit EpOptionalEventRecordStatusQuerier(const std::shared_ptr<EpEvent>& ep_event)\n      : launched_(false), ep_event_(ep_event) {}\n\n  std::atomic<bool> launched_;\n  std::shared_ptr<EpEvent> ep_event_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_EP_OPTIONAL_EVENT_RECORD_STATUS_QUERIER_H_\n"
  },
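  {
    "path": "oneflow/core/vm/ep_optional_event_record_status_querier_usage_sketch.cpp",
    "content": "// NOTE: Editor-added illustrative sketch, not part of the original OneFlow sources.\n// It walks the launched/done protocol over a raw status buffer; the function name and its\n// arguments are hypothetical. A null event means the instruction counts as done as soon as\n// it has been launched.\n#include <memory>\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/vm/ep_optional_event_record_status_querier.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nvoid EpStatusQuerierUsageSketch(char* status_buffer, const std::shared_ptr<EpEvent>& ep_event,\n                                ep::Stream* ep_stream) {\n  auto* querier = EpOptionalEventRecordStatusQuerier::PlacementNew(status_buffer, ep_event);\n  CHECK(!querier->launched());\n  // Records ep_event on the stream (if present) and marks the instruction as launched.\n  querier->SetLaunched(ep_stream);\n  // True once launched and the recorded event (if any) has been reached.\n  if (querier->done()) { /* the instruction may be retired */ }\n  querier->~EpOptionalEventRecordStatusQuerier();\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },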
  {
    "path": "oneflow/core/vm/ep_record_event_instruction_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_EP_RECORD_EVENT_INSTRUCTION_POLICY_H_\n#define ONEFLOW_CORE_VM_EP_RECORD_EVENT_INSTRUCTION_POLICY_H_\n\n#include <memory>\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/vm/ep_optional_event_record_status_querier.h\"\n#include \"oneflow/core/vm/instruction_policy.h\"\n#include \"oneflow/core/eager/local_dep_object.h\"\n#include \"oneflow/core/vm/ep_stream_policy_base.h\"\n#include \"oneflow/core/vm/stream.h\"\n\nnamespace oneflow {\nnamespace vm {\nclass EpRecordEventInstructionPolicy final : public InstructionPolicy {\n public:\n  EpRecordEventInstructionPolicy(\n      small_vector<intrusive::shared_ptr<LocalDepObject>>&& compute_local_dep_objects,\n      const std::string& modifier)\n      : compute_local_dep_objects_(std::move(compute_local_dep_objects)),\n        modifier_(modifier),\n        input_dependences_(),\n        output_dependences_() {\n    ForEachConstDependence([&](auto* dep) { input_dependences_.emplace_back(dep); });\n    ForEachMutDependence([&](auto* dep) { output_dependences_.emplace_back(dep); });\n    ForEachMut2Dependence([&](auto* dep) { output_dependences_.emplace_back(dep); });\n  }\n\n  ~EpRecordEventInstructionPolicy() override = default;\n  const DependenceVector& input_dependences() const override { return input_dependences_; }\n  const DependenceVector& output_dependences() const override { return output_dependences_; }\n\n  template<typename DoEachT>\n  void ForEachConstDependence(const DoEachT& DoEach) const {\n    if (modifier_ == \"const\") {\n      for (const auto& dep : compute_local_dep_objects_) { DoEach(dep.get()); }\n    }\n  }\n\n  template<typename DoEachT>\n  void ForEachMutDependence(const DoEachT& DoEach) const {\n    if (modifier_ == \"mut\") {\n      for (const auto& dep : compute_local_dep_objects_) { DoEach(dep.get()); }\n    }\n  }\n\n  template<typename DoEachT>\n  void ForEachMut2Dependence(const DoEachT& DoEach) const {\n    if (modifier_ == \"mut2\") {\n      for (const auto& dep : compute_local_dep_objects_) { DoEach(dep.get()); }\n    }\n  }\n  InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAsTailOnly; }\n\n  void InitInstructionStatus(Instruction* instruction) override {\n    auto* status_buffer = instruction->mut_status_buffer();\n    auto* stream = instruction->mut_stream();\n    instruction->stream_policy().InitInstructionStatus(*stream, status_buffer);\n    EpStreamPolicyBase* ep_stream_policy_base =\n        dynamic_cast<EpStreamPolicyBase*>(stream->mut_stream_policy());\n    CHECK_NOTNULL(ep_stream_policy_base);\n    auto* ep_event_provider = ep_stream_policy_base->ep_event_provider();\n    const auto& ep_event = CHECK_NOTNULL(ep_event_provider)->GetReusedEpEvent();\n    auto* data_ptr = status_buffer->mut_buffer();\n    EpOptionalEventRecordStatusQuerier::MutCast(data_ptr)->reset_ep_event(ep_event);\n  }\n  Maybe<void> 
Prepare(vm::Instruction* instruction) override { return Maybe<void>::Ok(); }\n  std::string DebugName(const vm::Instruction&) const override { return \"RecordEvent\"; }\n  void Compute(vm::Instruction* instruction) override {}\n\n private:\n  small_vector<intrusive::shared_ptr<LocalDepObject>> compute_local_dep_objects_;\n  const std::string modifier_;\n  DependenceVector input_dependences_;\n  DependenceVector output_dependences_;\n};\n\n}  // namespace vm\n\nstruct GetRecordEventInstructionPolicy : public StreamTypeVisitor<GetRecordEventInstructionPolicy> {\n  template<typename... Args>\n  static Maybe<vm::InstructionPolicy> VisitCompute(DeviceType device_type, Args&&... args) {\n    return std::shared_ptr<vm::InstructionPolicy>(\n        new vm::EpRecordEventInstructionPolicy(std::forward<Args>(args)...));\n  }\n  template<typename... Args>\n  static Maybe<vm::InstructionPolicy> VisitHost2Device(DeviceType device_type, Args&&... args) {\n    return std::shared_ptr<vm::InstructionPolicy>(\n        new vm::EpRecordEventInstructionPolicy(std::forward<Args>(args)...));\n  }\n  template<typename... Args>\n  static Maybe<vm::InstructionPolicy> VisitDevice2Host(DeviceType device_type, Args&&... args) {\n    return std::shared_ptr<vm::InstructionPolicy>(\n        new vm::EpRecordEventInstructionPolicy(std::forward<Args>(args)...));\n  }\n  template<typename... Args>\n  static Maybe<vm::InstructionPolicy> VisitCcl(DeviceType device_type, Args&&... args) {\n    return std::shared_ptr<vm::InstructionPolicy>(\n        new vm::EpRecordEventInstructionPolicy(std::forward<Args>(args)...));\n  }\n  template<typename... Args>\n  static Maybe<vm::InstructionPolicy> VisitBarrier(DeviceType device_type, Args&&... args) {\n    UNIMPLEMENTED_THEN_RETURN() << \"EpRecordEvent instruction not supported in Barrier stream\";\n  }\n  template<typename... Args>\n  static Maybe<vm::InstructionPolicy> VisitCriticalSection(DeviceType device_type, Args&&... args) {\n    UNIMPLEMENTED_THEN_RETURN()\n        << \"EpRecordEvent instruction not supported in CriticalSection stream\";\n  }\n  template<typename... Args>\n  static Maybe<vm::InstructionPolicy> VisitLazyJobLauncher(DeviceType device_type, Args&&... args) {\n    UNIMPLEMENTED_THEN_RETURN()\n        << \"EpRecordEvent instruction not supported in LaunchLazyJob stream\";\n  }\n  template<typename... Args>\n  static Maybe<vm::InstructionPolicy> VisitPinnedCompute(DeviceType device_type, Args&&... args) {\n    return std::shared_ptr<vm::InstructionPolicy>(\n        new vm::EpRecordEventInstructionPolicy(std::forward<Args>(args)...));\n  }\n};\n\n}  // namespace oneflow\n#endif  // ONEFLOW_CORE_VM_EP_RECORD_EVENT_INSTRUCTION_POLICY_H_\n"
  },
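  {
    "path": "oneflow/core/vm/ep_record_event_instruction_policy_usage_sketch.cpp",
    "content": "// NOTE: Editor-added illustrative sketch, not part of the original OneFlow sources.\n// It builds an EpRecordEventInstructionPolicy directly; the helper name is hypothetical.\n// The modifier string decides which dependence bucket the objects land in: \"const\" fills\n// input_dependences(), while \"mut\" and \"mut2\" fill output_dependences().\n#include <memory>\n#include <utility>\n#include \"oneflow/core/vm/ep_record_event_instruction_policy.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nstd::shared_ptr<InstructionPolicy> MakeRecordEventPolicySketch(\n    small_vector<intrusive::shared_ptr<LocalDepObject>>&& deps) {\n  // With \"mut\", every dependence is treated as written by the record-event instruction.\n  return std::make_shared<EpRecordEventInstructionPolicy>(std::move(deps), \"mut\");\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },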
  {
    "path": "oneflow/core/vm/ep_stream_policy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/vm/ep_stream_policy.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/stream_type.h\"\n#include \"oneflow/core/vm/remat/allocator.h\"\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/core/vm/thread_ctx.h\"\n#include \"oneflow/core/vm/ep_optional_event_record_status_querier.h\"\n#include \"oneflow/core/vm/ep_backend_allocator.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/vm/remat/util.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nnamespace {\n\nstd::unique_ptr<vm::Allocator> CreateEpBackendDeviceAllocator(Symbol<Device> device) {\n  DeviceType device_type = device->enum_type();\n  size_t device_index = device->device_id();\n\n  if (device->rematable()) {\n    return std::make_unique<vm::DtrEpAllocatorProxy>(\n        Singleton<remat::AllocatorManager>::Get()->CreateOrGetAllocator(device_type, device_index));\n  } else {\n    auto ep_device =\n        Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(device_type, device_index);\n    auto ep_backend_allocator =\n        std::make_unique<EpBackendAllocator>(ep_device, ep::AllocationOptions{});\n    return std::make_unique<BinAllocator<ThreadSafeLock>>(ep::kMaxAlignmentRequirement,\n                                                          std::move(ep_backend_allocator));\n  }\n}\n\n}  // namespace\n\nEpStreamPolicy::EpStreamPolicy(Symbol<Device> device)\n    : EpStreamPolicyBase(device, CreateEpBackendDeviceAllocator(device)) {}\n\nvoid EpStreamPolicy::InitInstructionStatus(const Stream& stream,\n                                           InstructionStatusBuffer* status_buffer) const {\n  static_assert(sizeof(EpOptionalEventRecordStatusQuerier) < kInstructionStatusBufferBytes, \"\");\n  auto* data_ptr = status_buffer->mut_buffer();\n  EpOptionalEventRecordStatusQuerier::PlacementNew(data_ptr, nullptr);\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/ep_stream_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_EP_STREAM_POLICY_H_\n#define ONEFLOW_CORE_VM_EP_STREAM_POLICY_H_\n\n#include \"oneflow/core/vm/ep_stream_policy_base.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass EpStreamPolicy final : public EpStreamPolicyBase {\n public:\n  EpStreamPolicy(Symbol<Device> device);\n  ~EpStreamPolicy() override = default;\n\n  void InitInstructionStatus(const Stream& stream,\n                             InstructionStatusBuffer* status_buffer) const override;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_EP_STREAM_POLICY_H_\n"
  },
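  {
    "path": "oneflow/core/vm/ep_stream_policy_usage_sketch.cpp",
    "content": "// NOTE: Editor-added illustrative sketch, not part of the original OneFlow sources.\n// It constructs an EpStreamPolicy and touches the lazily created device; the function name\n// is hypothetical.\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/vm/ep_stream_policy.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nvoid EpStreamPolicyUsageSketch(Symbol<Device> device) {\n  EpStreamPolicy policy(device);\n  // The ep::Device, ep::Stream and event provider are all created on first use.\n  ep::Device* ep_device = policy.GetOrCreateEpDevice();\n  ep_device->SetAsActiveDevice();\n  CHECK_NOTNULL(policy.mut_allocator());\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },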
  {
    "path": "oneflow/core/vm/ep_stream_policy_base.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/vm/ep_stream_policy_base.h\"\n#include <memory>\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/core/vm/thread_ctx.h\"\n#include \"oneflow/core/vm/ep_optional_event_record_status_querier.h\"\n#include \"oneflow/core/vm/ep_backend_host_allocator.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nvoid EpStreamPolicyBase::DeleteInstructionStatus(const Stream& stream,\n                                                 InstructionStatusBuffer* status_buffer) const {\n  auto* ptr = EpOptionalEventRecordStatusQuerier::MutCast(status_buffer->mut_buffer());\n  ptr->~EpOptionalEventRecordStatusQuerier();\n}\n\nbool EpStreamPolicyBase::QueryInstructionStatusLaunched(\n    const Stream& stream, const InstructionStatusBuffer& status_buffer) const {\n  return EpOptionalEventRecordStatusQuerier::Cast(status_buffer.buffer())->launched();\n}\n\nbool EpStreamPolicyBase::QueryInstructionStatusDone(\n    const Stream& stream, const InstructionStatusBuffer& status_buffer) const {\n  return EpOptionalEventRecordStatusQuerier::Cast(status_buffer.buffer())->done();\n}\n\nvoid EpStreamPolicyBase::Run(Instruction* instruction) const {\n  OF_PROFILER_RANGE_GUARD(\"S:\" + instruction->DebugName());\n  auto* stream = instruction->mut_stream();\n  EpStreamPolicyBase* ep_stream_policy_base =\n      dynamic_cast<EpStreamPolicyBase*>(stream->mut_stream_policy());\n  CHECK_NOTNULL(ep_stream_policy_base);\n  auto* ep_device = ep_stream_policy_base->GetOrCreateEpDevice();\n  ep_device->SetAsActiveDevice();\n  instruction->Compute();\n  char* data_ptr = instruction->mut_status_buffer()->mut_buffer();\n  EpOptionalEventRecordStatusQuerier::MutCast(data_ptr)->SetLaunched(\n      stream->mut_stream_policy()->stream());\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/ep_stream_policy_base.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_EP_STREAM_POLICY_BASE_H_\n#define ONEFLOW_CORE_VM_EP_STREAM_POLICY_BASE_H_\n\n#include \"oneflow/core/vm/stream_policy.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/vm/ep_event.h\"\n#include \"oneflow/core/vm/bin_allocator.h\"\n#include \"oneflow/core/vm/thread_safe_guard.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass EpStreamPolicyBase : public StreamPolicy {\n public:\n  EpStreamPolicyBase(Symbol<Device> device, std::unique_ptr<vm::Allocator>&& backend_allocator)\n      : device_(device),\n        ep_event_provier_(),\n        ep_stream_(nullptr),\n        ep_allocator_(std::move(backend_allocator)) {}\n  virtual ~EpStreamPolicyBase() override {\n    if (ep_stream_ != nullptr) {\n      CHECK(ep_device_);\n      ep_device_->DestroyStream(ep_stream_);\n    }\n  }\n\n  ep::Stream* stream() override { return GetOrCreateEpStream(); }\n\n  vm::Allocator* mut_allocator() override { return ep_allocator_.get(); }\n\n  DeviceType device_type() const override { return device_->enum_type(); }\n\n  EpEventProvider* ep_event_provider() {\n    if (unlikely(ep_event_provier_ == nullptr)) {\n      ep_event_provier_.reset(new SingleThreadEpEventProvider(GetOrCreateEpDevice()));\n    }\n    return ep_event_provier_.get();\n  }\n\n  ep::Device* GetOrCreateEpDevice() const {\n    if (unlikely(ep_device_ == nullptr)) {\n      ep_device_ = Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(device_->enum_type(),\n                                                                          device_->device_id());\n      CHECK(ep_device_);\n    }\n    return ep_device_.get();\n  }\n\n  bool SupportingTransportInstructions() const override { return true; }\n\n  void DeleteInstructionStatus(const Stream& stream,\n                               InstructionStatusBuffer* status_buffer) const override;\n\n  bool QueryInstructionStatusLaunched(const Stream& stream,\n                                      const InstructionStatusBuffer& status_buffer) const override;\n\n  bool QueryInstructionStatusDone(const Stream& stream,\n                                  const InstructionStatusBuffer& status_buffer) const override;\n\n  void Run(Instruction* instruction) const override;\n\n private:\n  ep::Stream* GetOrCreateEpStream() const {\n    if (unlikely(ep_stream_ == nullptr)) {\n      ep_stream_ = GetOrCreateEpDevice()->CreateStream();\n      CHECK(ep_stream_ != nullptr);\n    }\n    return ep_stream_;\n  }\n\n  Symbol<Device> device_;\n  std::unique_ptr<EpEventProvider> ep_event_provier_;\n  mutable std::shared_ptr<ep::Device> ep_device_;\n  mutable ep::Stream* ep_stream_;\n  std::unique_ptr<vm::Allocator> ep_allocator_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_EP_STREAM_POLICY_BASE_H_\n"
  },
  {
    "path": "oneflow/core/vm/event_recorded_ep_stream_policy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/vm/event_recorded_ep_stream_policy.h\"\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/core/vm/thread_ctx.h\"\n#include \"oneflow/core/vm/ep_optional_event_record_status_querier.h\"\n#include \"oneflow/core/vm/ep_backend_allocator.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\nnamespace vm {\n\n/*static*/ std::unique_ptr<BinAllocator<ThreadSafeLock>>\nEventRecordedEpStreamPolicy::CreateEpBackendDeviceAllocator(Symbol<Device> device) {\n  DeviceType device_type = device->enum_type();\n  size_t device_index = device->device_id();\n  auto ep_device =\n      Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(device_type, device_index);\n  auto ep_backend_allocator =\n      std::make_unique<EpBackendAllocator>(ep_device, ep::AllocationOptions{});\n  return std::make_unique<BinAllocator<ThreadSafeLock>>(ep::kMaxAlignmentRequirement,\n                                                        std::move(ep_backend_allocator));\n}\n\nEventRecordedEpStreamPolicy::EventRecordedEpStreamPolicy(Symbol<Device> device,\n                                                         std::unique_ptr<vm::Allocator>&& allocator)\n    : EpStreamPolicyBase(device, std::move(allocator)) {}\n\nvoid EventRecordedEpStreamPolicy::InitInstructionStatus(\n    const Stream& stream, InstructionStatusBuffer* status_buffer) const {\n  static_assert(sizeof(EpOptionalEventRecordStatusQuerier) < kInstructionStatusBufferBytes, \"\");\n  EpStreamPolicyBase* ep_stream_policy_base =\n      dynamic_cast<EpStreamPolicyBase*>(const_cast<Stream&>(stream).mut_stream_policy());\n  CHECK_NOTNULL(ep_stream_policy_base);\n  auto* ep_event_provider = ep_stream_policy_base->ep_event_provider();\n  auto* data_ptr = status_buffer->mut_buffer();\n  const auto& ep_event = CHECK_NOTNULL(ep_event_provider)->GetReusedEpEvent();\n  EpOptionalEventRecordStatusQuerier::PlacementNew(data_ptr, ep_event);\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
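  {
    "path": "oneflow/core/vm/event_recorded_ep_stream_policy_usage_sketch.cpp",
    "content": "// NOTE: Editor-added illustrative sketch, not part of the original OneFlow sources.\n// It wires EventRecordedEpStreamPolicy to its default bin allocator for a device symbol;\n// the factory name is hypothetical.\n#include <memory>\n#include <utility>\n#include \"oneflow/core/vm/event_recorded_ep_stream_policy.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nstd::unique_ptr<StreamPolicy> MakeEventRecordedEpStreamPolicySketch(Symbol<Device> device) {\n  auto allocator = EventRecordedEpStreamPolicy::CreateEpBackendDeviceAllocator(device);\n  return std::make_unique<EventRecordedEpStreamPolicy>(device, std::move(allocator));\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },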
  {
    "path": "oneflow/core/vm/event_recorded_ep_stream_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_EVENT_RECORDED_EP_STREAM_POLICY_H_\n#define ONEFLOW_CORE_VM_EVENT_RECORDED_EP_STREAM_POLICY_H_\n\n#include \"oneflow/core/vm/ep_stream_policy_base.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass EventRecordedEpStreamPolicy final : public EpStreamPolicyBase {\n public:\n  EventRecordedEpStreamPolicy(Symbol<Device> device, std::unique_ptr<vm::Allocator>&& allocator);\n  ~EventRecordedEpStreamPolicy() override = default;\n\n  void InitInstructionStatus(const Stream& stream,\n                             InstructionStatusBuffer* status_buffer) const override;\n\n  static std::unique_ptr<BinAllocator<ThreadSafeLock>> CreateEpBackendDeviceAllocator(\n      Symbol<Device> device);\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_EVENT_RECORDED_EP_STREAM_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/fuse_instruction_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_FUSE_INSTRUCTION_POLICY_H_\n#define ONEFLOW_CORE_VM_FUSE_INSTRUCTION_POLICY_H_\n\n#include <functional>\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/vm/instruction_policy_util.h\"\n#include \"oneflow/core/vm/vm_object.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass FuseInstructionPolicy final : public InstructionPolicy {\n public:\n  explicit FuseInstructionPolicy(InstructionList&& instruction_list)\n      : instruction_list_(), input_dependences_(), output_dependences_() {\n    instruction_list.MoveTo(&instruction_list_);\n    auto ReadOnlyDepsInserter = InstructionPolicyUtil::SetInserter(&input_dependences_);\n    auto WritableDepsInserter = InstructionPolicyUtil::SetInserter(&output_dependences_);\n    auto* last_instruction = instruction_list_.Last();\n    INTRUSIVE_UNSAFE_FOR_EACH_PTR(instruction, &instruction_list_) {\n      if (instruction == last_instruction) {\n        CHECK(instruction->instruction_policy().fuse_type() == kEnableInstructionFuseAsTailOnly\n              || instruction->instruction_policy().fuse_type()\n                     == kEnableInstructionFuseAtAnyPosition);\n      } else {\n        CHECK(instruction->instruction_policy().fuse_type() == kEnableInstructionFuseAtAnyPosition);\n      }\n      if (unlikely(stream_sequential_dependence_ == nullptr)) {\n        stream_sequential_dependence_ =\n            instruction->instruction_policy().stream_sequential_dependence();\n      } else {\n        CHECK_EQ(stream_sequential_dependence_,\n                 instruction->instruction_policy().stream_sequential_dependence());\n      }\n      for (auto* dep : instruction->instruction_policy().input_dependences()) {\n        ReadOnlyDepsInserter(dep);\n      }\n      for (auto* dep : instruction->instruction_policy().output_dependences()) {\n        WritableDepsInserter(dep);\n      }\n    }\n  }\n\n  ~FuseInstructionPolicy() override = default;\n\n  const DependenceVector& input_dependences() const override { return input_dependences_; }\n  const DependenceVector& output_dependences() const override { return output_dependences_; }\n\n  InstructionList* mut_instruction_list() { return &instruction_list_; }\n\n private:\n  Maybe<void> Prepare(Instruction* instruction) override {\n    INTRUSIVE_UNSAFE_FOR_EACH_PTR(instruction, mut_instruction_list()) {\n      JUST(instruction->Prepare());\n    }\n    return Maybe<void>::Ok();\n  }\n  void Compute(Instruction* instruction) override {\n    OF_PROFILER_RANGE_GUARD(\"F:\" + instruction->DebugName());\n    INTRUSIVE_UNSAFE_FOR_EACH_PTR(instruction, mut_instruction_list()) { instruction->Compute(); }\n  }\n  void InitInstructionStatus(Instruction* instruction) override {\n    auto* last_instruction = CHECK_NOTNULL(mut_instruction_list()->Last());\n    last_instruction->mut_instruction_policy()->InitInstructionStatusIf(instruction);\n  }\n\n  std::string DebugName(const 
Instruction&) const override { return \"Fuse\"; }\n\n  InstructionList instruction_list_;\n  DependenceVector input_dependences_;\n  DependenceVector output_dependences_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_FUSE_INSTRUCTION_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/global_sync_instruction_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_GLOBAL_SYNC_INSTRUCTION_POLICY_H_\n#define ONEFLOW_CORE_VM_GLOBAL_SYNC_INSTRUCTION_POLICY_H_\n\n#include \"oneflow/core/rpc/include/base.h\"\n#include \"oneflow/core/vm/instruction_policy.h\"\nnamespace oneflow {\nnamespace vm {\n\nclass GlobalSyncInstructionPolicy final : public InstructionPolicy {\n public:\n  GlobalSyncInstructionPolicy() = default;\n  ~GlobalSyncInstructionPolicy() override = default;\n\n  const DependenceVector& input_dependences() const override {\n    static DependenceVector dependences{};\n    return dependences;\n  }\n  const DependenceVector& output_dependences() const override {\n    static DependenceVector dependences{};\n    return dependences;\n  }\n\n  bool IsBarrier() const override { return true; }\n\n  std::string DebugName(const vm::Instruction& instruction) const override { return \"GlobalSync\"; }\n  Maybe<void> Prepare(Instruction* instruction) override { return Maybe<void>::Ok(); }\n  void Compute(Instruction* instruction) override { OF_ENV_BARRIER(); }\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_GLOBAL_SYNC_INSTRUCTION_POLICY_H_\n"
  },
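  {
    "path": "examples/vm/no_dependences_sketch.cpp",
    "content": "// Editor's sketch: a hypothetical, standalone illustration (not part of the\n// original tree) of an idiom used by GlobalSyncInstructionPolicy above:\n// returning a function-local static empty DependenceVector by const reference,\n// so dependence-free barrier instructions pay no per-call allocation. Only the\n// C++ standard library is assumed; Dependence/DependenceVector are stand-ins\n// for the types in vm_object.h.\n#include <cassert>\n#include <vector>\n\nstruct Dependence {};  // stand-in for vm::Dependence\nusing DependenceVector = std::vector<Dependence*>;\n\n// Keeps a uniform signature with policies that own real dependences.\nconst DependenceVector& NoDependences() {\n  static DependenceVector empty{};\n  return empty;\n}\n\nint main() {\n  // Every call yields the same object, and it is always empty.\n  assert(&NoDependences() == &NoDependences());\n  assert(NoDependences().empty());\n  return 0;\n}\n"
  },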
  {
    "path": "oneflow/core/vm/instruction.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/core/vm/thread_ctx.h\"\n#include \"oneflow/core/vm/virtual_machine_engine.h\"\n#include \"oneflow/core/framework/stream_get_stream_type_name.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/cpp_attribute.h\"\n#include \"oneflow/extension/stack/foreign_stack_getter.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nstd::string Instruction::DebugName() const {\n  std::string instr_name = instruction_policy().DebugName(*this);\n  return instr_name + \":s_\" + GetStreamTypeName::Visit(stream().stream_type());\n}\n\nvoid Instruction::__Init__(Stream* stream,\n                           std::shared_ptr<InstructionPolicy>&& instruction_policy) {\n  stream_ = stream;\n  instruction_policy_ = std::move(instruction_policy);\n  if (IsMainThread()) {\n    if (auto* stack_getter = Singleton<ForeignStackGetter>::Get()) {\n      foreign_frame_ = stack_getter->GetCurrentFrame();\n    }\n  }\n}\n\nvoid Instruction::InitStatus() { instruction_policy_->InitInstructionStatusIf(this); }\n\nMaybe<void> Instruction::Prepare() {\n  ForeignFrameThreadLocalGuard guard(foreign_frame_);\n  return instruction_policy_->PrepareIf(this);\n}\nvoid Instruction::Compute() {\n  ForeignFrameThreadLocalGuard guard(foreign_frame_);\n  instruction_policy_->ComputeIf(this);\n}\n\nvoid Instruction::DeleteStatusAndCheckEdges() {\n  OF_PROFILER_RANGE_GUARD(\"Instruction::DeleteStatusAndCheckEdges\");\n  instruction_policy_->DeleteInstructionStatusIf(this);\n  INTRUSIVE_FOR_EACH_PTR(edge, mut_in_edges()) {\n    Instruction* in_instruction = edge->mut_src_instruction();\n    LOG(FATAL) << \"unerased edge: \" << in_instruction->DebugName() << \" -> \" << this->DebugName();\n  }\n  INTRUSIVE_FOR_EACH_PTR(edge, mut_out_edges()) {\n    Instruction* out_instruction = edge->mut_dst_instruction();\n    LOG(FATAL) << \"unerased edge: \" << this->DebugName() << \" -> \" << out_instruction->DebugName();\n  }\n}\n\nbool Instruction::Launched() const {\n  return stream_policy().QueryInstructionStatusLaunched(stream(), status_buffer());\n}\n\nbool Instruction::Done() const {\n  return stream_policy().QueryInstructionStatusDone(stream(), status_buffer());\n}\n\nStreamPolicy* Instruction::mut_stream_policy() { return mut_stream()->mut_stream_policy(); }\n\nconst StreamPolicy& Instruction::stream_policy() const { return stream().stream_policy(); }\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/instruction.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_VPU_INSTRUCTION__H_\n#define ONEFLOW_CORE_VM_VPU_INSTRUCTION__H_\n\n#include <cstring>\n#include <memory>\n#include <mutex>\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/intrusive/intrusive.h\"\n#include \"oneflow/core/intrusive/object_pool.h\"\n#include \"oneflow/core/vm/vm_object.h\"\n#include \"oneflow/core/vm/instruction_policy.h\"\n#include \"oneflow/core/vm/stream_policy.h\"\n#include \"oneflow/extension/stack/foreign_stack_getter.h\"\n\nnamespace oneflow {\n\nclass Stream;\n\nnamespace vm {\n\nstatic const int kInstructionStatusBufferBytes = 64;\n\nclass InstructionStatusBuffer final {\n public:\n  InstructionStatusBuffer() = default;\n  ~InstructionStatusBuffer() = default;\n\n  const char* buffer() const { return &buffer_[0]; }\n  char* mut_buffer() { return &buffer_[0]; }\n\n private:\n  char buffer_[kInstructionStatusBufferBytes];\n};\n\nclass Instruction;\nclass InstructionEdge final\n    : public intrusive::Base,\n      public intrusive::EnableObjectPool<InstructionEdge,\n                                         intrusive::kThreadUnsafeAndDisableDestruct> {\n public:\n  InstructionEdge()\n      : intrusive_ref_(),\n        src_instruction_(),\n        dst_instruction_(),\n        in_edge_hook_(),\n        out_edge_hook_() {}\n  void __Init__() {\n    clear_src_instruction();\n    clear_dst_instruction();\n  }\n  // Getters\n  bool has_src_instruction() const { return src_instruction_ != nullptr; }\n  bool has_dst_instruction() const { return dst_instruction_ != nullptr; }\n  const Instruction& src_instruction() const { return *src_instruction_; }\n  const Instruction& dst_instruction() const { return *dst_instruction_; }\n  // Setters\n  void set_src_instruction(Instruction* val) { src_instruction_ = val; }\n  void set_dst_instruction(Instruction* val) { dst_instruction_ = val; }\n  void clear_src_instruction() { src_instruction_ = nullptr; }\n  void clear_dst_instruction() { dst_instruction_ = nullptr; }\n  Instruction* mut_src_instruction() { return src_instruction_; }\n  Instruction* mut_dst_instruction() { return dst_instruction_; }\n  // methods\n  void __Init__(Instruction* src_instruction, Instruction* dst_instruction) {\n    __Init__();\n    set_src_instruction(src_instruction);\n    set_dst_instruction(dst_instruction);\n  }\n\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n private:\n  intrusive::Ref intrusive_ref_;\n  // fields\n  Instruction* src_instruction_;\n  Instruction* dst_instruction_;\n\n public:\n  // list hooks\n  intrusive::ListHook in_edge_hook_;\n  intrusive::ListHook out_edge_hook_;\n};\n\nclass Stream;\nclass Instruction final : public intrusive::Base {\n public:\n  // types\n  using InEdgeList = intrusive::List<INTRUSIVE_FIELD(InstructionEdge, in_edge_hook_)>;\n  using OutEdgeList = intrusive::List<INTRUSIVE_FIELD(InstructionEdge, out_edge_hook_)>;\n  using 
DependenceAccessList =\n      intrusive::List<INTRUSIVE_FIELD(DependenceAccess, instruction_access_hook_)>;\n\n  void __Init__(Stream* stream, std::shared_ptr<InstructionPolicy>&& instruction_policy);\n\n  // Getters\n  const Stream& stream() const { return *stream_; }\n  const InstructionStatusBuffer& status_buffer() const { return status_buffer_; }\n  const intrusive::ListHook& main_instruction_hook() const { return main_instruction_hook_; }\n  const InstructionPolicy& instruction_policy() const { return *instruction_policy_; }\n  std::string DebugName() const;\n\n  const intrusive::ListHook& dispatched_instruction_hook() const {\n    return dispatched_instruction_hook_;\n  }\n  const intrusive::ListHook& lively_instruction_hook() const { return lively_instruction_hook_; }\n  const intrusive::ListHook& worker_pending_instruction_hook() const {\n    return worker_pending_instruction_hook_;\n  }\n  const intrusive::ListHook& barrier_instruction_hook() const { return barrier_instruction_hook_; }\n  const InEdgeList& in_edges() const { return in_edges_; }\n  const OutEdgeList& out_edges() const { return out_edges_; }\n  const DependenceAccessList& access_list() const { return access_list_; }\n\n  Maybe<void> Prepare();\n  void Compute();\n\n  // Setters\n  Stream* mut_stream() { return stream_; }\n  InstructionStatusBuffer* mut_status_buffer() { return &status_buffer_; }\n  InstructionPolicy* mut_instruction_policy() { return instruction_policy_.get(); }\n  InEdgeList* mut_in_edges() { return &in_edges_; }\n  OutEdgeList* mut_out_edges() { return &out_edges_; }\n  DependenceAccessList* mut_access_list() { return &access_list_; }\n\n  // methods\n  void InitStatus();\n  void DeleteStatusAndCheckEdges();\n  bool Launched() const;\n  bool Done() const;\n  StreamPolicy* mut_stream_policy();\n  const StreamPolicy& stream_policy() const;\n  std::shared_ptr<Frame> foreign_frame() const { return foreign_frame_; }\n\n  intrusive::Ref::RefCntType ref_cnt() const { return intrusive_ref_.ref_cnt(); }\n\n  // used for instructions building, pending to scheduler, constructing DAG, pending to callback\n  // thread and so on.\n  // lifetime of barrier instructions:\n  //\n  //   |<-----main_instruction_hook_----->|\n  //                                    |<-----------lively_instruction_hook_---------------->|\n  //                                          |<---------barrier_instruction_hook_--------->|\n  //\n  //\n  // lifetime of non-barrier instructions:\n  //\n  //   |<-----main_instruction_hook_----->|\n  //                                    |<-----------lively_instruction_hook_---------------->|\n  //                                          |<-------dispatched_instruction_hook_-------->|\n  //                                               |<--worker_pending_instruction_hook_-->|\n  //\n  //\n  intrusive::ListHook main_instruction_hook_;\n  // dispatched to Stream\n  intrusive::ListHook dispatched_instruction_hook_;\n  // valid during vm processing\n  intrusive::ListHook lively_instruction_hook_;\n  // pending to ThreadCtx\n  intrusive::ListHook worker_pending_instruction_hook_;\n  // for barrier instruction\n  intrusive::ListHook barrier_instruction_hook_;\n\n private:\n  friend class intrusive::Ref;\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n  Instruction()\n      : main_instruction_hook_(),\n        dispatched_instruction_hook_(),\n        lively_instruction_hook_(),\n        worker_pending_instruction_hook_(),\n        barrier_instruction_hook_(),\n        
access_list_(),\n        in_edges_(),\n        out_edges_(),\n        intrusive_ref_(),\n        stream_(),\n        instruction_policy_(),\n        status_buffer_() {}\n\n  // lists\n  DependenceAccessList access_list_;\n  InEdgeList in_edges_;\n  OutEdgeList out_edges_;\n\n  // fields\n  intrusive::Ref intrusive_ref_;\n  Stream* stream_;\n  std::shared_ptr<InstructionPolicy> instruction_policy_;\n  InstructionStatusBuffer status_buffer_;\n  std::shared_ptr<Frame> foreign_frame_;\n};\n\nusing InstructionList = intrusive::List<INTRUSIVE_FIELD(Instruction, main_instruction_hook_)>;\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_VPU_INSTRUCTION__H_\n"
  },
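  {
    "path": "examples/vm/status_buffer_placement_new_sketch.cpp",
    "content": "// Editor's sketch: a hypothetical, standalone illustration (not part of the\n// original tree) of the inline status-buffer idiom in instruction.h above:\n// each Instruction owns a fixed 64-byte buffer, and stream policies\n// placement-new a small status querier into it, guarded by a static_assert on\n// its size. DoneFlag is a stand-in for NaiveInstrStatusQuerier; only the C++\n// standard library is assumed.\n#include <atomic>\n#include <cassert>\n#include <new>\n\nstatic const int kStatusBufferBytes = 64;  // mirrors kInstructionStatusBufferBytes\n\nclass DoneFlag {\n public:\n  static DoneFlag* PlacementNew(char* mem_ptr) { return new (mem_ptr) DoneFlag(); }\n  static DoneFlag* MutCast(char* mem_ptr) { return reinterpret_cast<DoneFlag*>(mem_ptr); }\n  bool done() const { return done_; }\n  void set_done() { done_ = true; }\n\n private:\n  DoneFlag() : done_(false) {}\n  std::atomic<bool> done_;\n};\n\nint main() {\n  // The querier must fit in the inline buffer, as InitInstructionStatus asserts.\n  static_assert(sizeof(DoneFlag) < kStatusBufferBytes, \"\");\n  alignas(DoneFlag) char buffer[kStatusBufferBytes];\n  DoneFlag::PlacementNew(buffer);          // InitInstructionStatus\n  assert(!DoneFlag::MutCast(buffer)->done());\n  DoneFlag::MutCast(buffer)->set_done();   // completion path\n  assert(DoneFlag::MutCast(buffer)->done());\n  DoneFlag::MutCast(buffer)->~DoneFlag();  // DeleteInstructionStatus\n  return 0;\n}\n"
  },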
  {
    "path": "oneflow/core/vm/instruction_fuse_type.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_INSTRUCTION_FUSE_TYPE_H_\n#define ONEFLOW_CORE_VM_INSTRUCTION_FUSE_TYPE_H_\n\nnamespace oneflow {\nnamespace vm {\n\nenum InstructionFuseType {\n  kInvalidInstructionFuseType = 0,\n  kDisableInstructionFuse,\n  kEnableInstructionFuseAtAnyPosition,\n  kEnableInstructionFuseAsTailOnly,\n};\n\n}\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_INSTRUCTION_FUSE_TYPE_H_\n"
  },
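  {
    "path": "examples/vm/fuse_type_check_sketch.cpp",
    "content": "// Editor's sketch: a hypothetical, standalone restatement (not part of the\n// original tree) of the rule FuseInstructionPolicy enforces over these enum\n// values: every fused instruction must be fusable at any position, except the\n// last, which may also be tail-only. IsFusableSequence is an illustrative\n// helper, not a OneFlow API.\n#include <cassert>\n#include <cstddef>\n#include <vector>\n\nenum InstructionFuseType {\n  kInvalidInstructionFuseType = 0,\n  kDisableInstructionFuse,\n  kEnableInstructionFuseAtAnyPosition,\n  kEnableInstructionFuseAsTailOnly,\n};\n\nbool IsFusableSequence(const std::vector<InstructionFuseType>& types) {\n  if (types.empty()) { return false; }\n  for (std::size_t i = 0; i + 1 < types.size(); ++i) {\n    if (types[i] != kEnableInstructionFuseAtAnyPosition) { return false; }\n  }\n  return types.back() == kEnableInstructionFuseAtAnyPosition\n         || types.back() == kEnableInstructionFuseAsTailOnly;\n}\n\nint main() {\n  assert(IsFusableSequence({kEnableInstructionFuseAtAnyPosition, kEnableInstructionFuseAsTailOnly}));\n  assert(!IsFusableSequence({kEnableInstructionFuseAsTailOnly, kEnableInstructionFuseAtAnyPosition}));\n  assert(!IsFusableSequence({kDisableInstructionFuse}));\n  return 0;\n}\n"
  },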
  {
    "path": "oneflow/core/vm/instruction_policy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/instruction_policy.h\"\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nvoid InstructionPolicy::InitInstructionStatus(Instruction* instruction) {\n  instruction->stream_policy().InitInstructionStatus(instruction->stream(),\n                                                     instruction->mut_status_buffer());\n}\n\nvoid InstructionPolicy::DeleteInstructionStatus(Instruction* instruction) {\n  instruction->stream_policy().DeleteInstructionStatus(instruction->stream(),\n                                                       instruction->mut_status_buffer());\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/instruction_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_INSTRUCTION_POLICY_H_\n#define ONEFLOW_CORE_VM_INSTRUCTION_POLICY_H_\n\n#include <functional>\n#include <vector>\n#include <memory>\n#include \"oneflow/core/intrusive/intrusive.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"oneflow/core/vm/instruction_fuse_type.h\"\n#include \"oneflow/core/vm/vm_object.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass EagerBlobObject;\nclass Stream;\n\nclass InstructionPolicy {\n public:\n  virtual ~InstructionPolicy() = default;\n\n  // Same stream.\n  virtual bool Prescheduleable(const vm::Stream* src, const vm::Stream* dst) const {\n    return src == dst;\n  }\n\n  virtual const DependenceVector& input_dependences() const = 0;\n  virtual const DependenceVector& output_dependences() const = 0;\n  virtual Dependence* stream_sequential_dependence() const { return stream_sequential_dependence_; }\n\n  virtual bool IsBarrier() const { return false; }\n  virtual InstructionFuseType fuse_type() const { return kDisableInstructionFuse; }\n  virtual std::string DebugName(const Instruction&) const = 0;\n\n  Maybe<void> PrepareIf(Instruction* instruction) {\n    OF_PROFILER_RANGE_GUARD(std::string(\"Prepare:\") + DebugName(*instruction));\n    return Prepare(instruction);\n  }\n\n  void ComputeIf(Instruction* instruction) {\n    OF_PROFILER_RANGE_GUARD(std::string(\"Compute:\") + DebugName(*instruction));\n    Compute(instruction);\n  }\n\n  void InitInstructionStatusIf(Instruction* instruction) { InitInstructionStatus(instruction); }\n\n  void DeleteInstructionStatusIf(Instruction* instruction) { DeleteInstructionStatus(instruction); }\n\n protected:\n  InstructionPolicy() : stream_sequential_dependence_(nullptr) {}\n\n  Dependence* stream_sequential_dependence_;\n\n private:\n  // Usually for Allocating and deallocating tensors.\n  virtual Maybe<void> Prepare(Instruction* instruction) = 0;\n  virtual void Compute(Instruction* instruction) = 0;\n  virtual void InitInstructionStatus(Instruction* instruction);\n  virtual void DeleteInstructionStatus(Instruction* instruction);\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_INSTRUCTION_POLICY_H_\n"
  },
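  {
    "path": "examples/vm/non_virtual_interface_sketch.cpp",
    "content": "// Editor's sketch: a hypothetical, standalone illustration (not part of the\n// original tree) of the non-virtual-interface idiom in instruction_policy.h\n// above: the public ComputeIf wrapper brackets the private virtual Compute\n// hook with a profiler range. RangeGuard is a stand-in for\n// OF_PROFILER_RANGE_GUARD; only the C++ standard library is assumed.\n#include <iostream>\n#include <string>\n\n// Opens a range on construction and closes it on destruction.\nclass RangeGuard {\n public:\n  explicit RangeGuard(const std::string& name) : name_(name) { std::cout << \"begin \" << name_ << \"\\n\"; }\n  ~RangeGuard() { std::cout << \"end \" << name_ << \"\\n\"; }\n\n private:\n  std::string name_;\n};\n\nclass Policy {\n public:\n  virtual ~Policy() = default;\n  virtual std::string DebugName() const = 0;\n  // Callers always get the profiling range; subclasses override only the hook.\n  void ComputeIf() {\n    RangeGuard guard(\"Compute:\" + DebugName());\n    Compute();\n  }\n\n private:\n  virtual void Compute() = 0;\n};\n\nclass NopPolicy final : public Policy {\n public:\n  std::string DebugName() const override { return \"Nop\"; }\n\n private:\n  void Compute() override {}\n};\n\nint main() {\n  NopPolicy policy;\n  policy.ComputeIf();  // prints the begin/end range around the empty compute\n  return 0;\n}\n"
  },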
  {
    "path": "oneflow/core/vm/instruction_policy_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_INSTRUCTION_POLICY_UTIL_H_\n#define ONEFLOW_CORE_VM_INSTRUCTION_POLICY_UTIL_H_\n\n#include <functional>\n#include <set>\n#include \"oneflow/core/vm/vm_object.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nstruct InstructionPolicyUtil {\n  static std::function<void(Dependence*)> SetInserter(DependenceVector* dependences) {\n    auto existed =\n        std::make_shared<std::set<Dependence*>>(dependences->begin(), dependences->end());\n    return [dependences, existed](Dependence* object) {\n      if (existed->insert(object).second) { dependences->push_back(object); }\n    };\n  }\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_INSTRUCTION_POLICY_UTIL_H_\n"
  },
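  {
    "path": "examples/vm/set_inserter_sketch.cpp",
    "content": "// Editor's sketch: a hypothetical, standalone illustration (not part of the\n// original tree) of InstructionPolicyUtil::SetInserter above: the returned\n// closure appends a dependence only on first sight, using a shared std::set\n// as the membership test. int* stands in for vm::Dependence*.\n#include <cassert>\n#include <functional>\n#include <memory>\n#include <set>\n#include <vector>\n\nstd::function<void(int*)> SetInserter(std::vector<int*>* dependences) {\n  auto existed =\n      std::make_shared<std::set<int*>>(dependences->begin(), dependences->end());\n  return [dependences, existed](int* object) {\n    // set::insert(...).second is true only for first-time insertions.\n    if (existed->insert(object).second) { dependences->push_back(object); }\n  };\n}\n\nint main() {\n  int a = 0;\n  int b = 0;\n  std::vector<int*> dependences;\n  auto inserter = SetInserter(&dependences);\n  inserter(&a);\n  inserter(&b);\n  inserter(&a);  // duplicate, ignored\n  assert(dependences.size() == 2);\n  return 0;\n}\n"
  },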
  {
    "path": "oneflow/core/vm/lazy_job_instruction_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_LAZY_JOB_INSTRUCTION_POLICY_H_\n#define ONEFLOW_CORE_VM_LAZY_JOB_INSTRUCTION_POLICY_H_\n\n#include \"oneflow/core/common/buffer_manager.h\"\n#include \"oneflow/core/common/of_unused.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/framework/nn_graph_if.h\"\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/job/job_instance.h\"\n#include \"oneflow/core/vm/instruction_policy.h\"\n#include \"oneflow/core/vm/instruction_policy_util.h\"\n#include \"oneflow/core/vm/naive_instruction_status_querier.h\"\n#include \"oneflow/core/vm/lazy_job_stream_policy.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n#include <robin_hood.h>\n\nnamespace oneflow {\n\nclass LazyJobInstance final : public JobInstance {\n public:\n  LazyJobInstance(const LazyJobInstance&) = delete;\n  LazyJobInstance(LazyJobInstance&&) = delete;\n  ~LazyJobInstance() override = default;\n  LazyJobInstance(const std::string& job_name, const std::function<void()>& finish_cb)\n      : job_name_(job_name), finish_cb_(finish_cb) {}\n\n  std::string job_name() const override { return job_name_; }\n  void Finish() const override { finish_cb_(); }\n\n private:\n  const std::string job_name_;\n  const std::function<void()> finish_cb_;\n};\n\nnamespace vm {\n\nclass LaunchLazyJobInstructionPolicy final : public InstructionPolicy {  // NOLINT\n public:\n  LaunchLazyJobInstructionPolicy(const LaunchLazyJobInstructionPolicy&) = delete;\n  LaunchLazyJobInstructionPolicy(LaunchLazyJobInstructionPolicy&&) = delete;\n  ~LaunchLazyJobInstructionPolicy() = default;\n\n  LaunchLazyJobInstructionPolicy(const std::shared_ptr<NNGraphIf>& nn_graph,\n                                 const EagerBlobObjectListPtr& param_blob_objects)\n      : nn_graph_(nn_graph),\n        param_blob_objects_(param_blob_objects),\n        input_dependences_(),\n        output_dependences_() {\n    robin_hood::unordered_flat_map<Dependence*, bool> unique_map;\n    ForEachConstDependence([&](Dependence* compute) {\n      if (unique_map.emplace(compute, true).second) { input_dependences_.emplace_back(compute); }\n    });\n    unique_map.clear();\n    output_dependences_.reserve(param_blob_objects_->size());\n    unique_map.reserve(param_blob_objects_->size());\n    ForEachMutDependence([&](Dependence* compute) {\n      if (unique_map.emplace(compute, true).second) { output_dependences_.emplace_back(compute); }\n    });\n    ForEachMut2Dependence([&](Dependence* compute) {\n      if (unique_map.emplace(compute, true).second) { output_dependences_.emplace_back(compute); }\n    });\n  }\n\n  const DependenceVector& input_dependences() const override { return input_dependences_; }\n  const DependenceVector& output_dependences() const override { return output_dependences_; }\n\n  void ForEachConstDependence(const std::function<void(Dependence* compute)>&) const {}\n\n  void 
ForEachMutDependence(const std::function<void(Dependence* compute)>& DoEach) const {\n    for (const auto& eager_blob_object : *param_blob_objects_) {\n      DoEach(CHECK_JUST(eager_blob_object->compute_local_dep_object()));\n    }\n    DoEach(CHECK_JUST(SingletonMaybe<VirtualMachine>())\n               ->FindOrCreateTransportLocalDepObject()\n               .Mutable());\n  }\n\n  void ForEachMut2Dependence(const std::function<void(Dependence* compute)>&) const {}\n\n  std::string DebugName(const Instruction&) const override { return \"LaunchLazyJob\"; }\n  Maybe<void> Prepare(Instruction* instruction) override { return Maybe<void>::Ok(); }\n  void Compute(Instruction* instruction) override {\n    VLOG(3) << \" VM try launch Graph: \" << nn_graph_->job_name()\n            << \" in run_cnt: \" << nn_graph_->run_cnt() << \" START.\";\n    auto* lazy_job_stream_policy = GetLazyJobStreamPolicy(instruction);\n    {\n      OF_PROFILER_RANGE_GUARD(\"WaitUntilQueueEmptyIfFrontNNGraphNotEquals\");\n      lazy_job_stream_policy->WaitUntilQueueEmptyIfFrontNNGraphNotEquals(nn_graph_);\n      VLOG(3) << \" VM launch Graph: \" << nn_graph_->job_name()\n              << \" in run_cnt: \" << nn_graph_->run_cnt()\n              << \" WaitUntilQueueEmptyIfFrontNNGraphNotEquals.\";\n    }\n    {\n      OF_PROFILER_RANGE_GUARD(\"Send all buffers to BufferMgr\");\n      const auto& job_instance = MakeJobInstance(instruction);\n      const auto& job_name = job_instance->job_name();\n      auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<JobInstance>>>::Get();\n      buffer_mgr->Get(GetCallbackNotifierBufferName(job_name))->Push(job_instance);\n      VLOG(3) << \" VM Push CallbackNotifier to Graph: \" << nn_graph_->job_name()\n              << \" in run_cnt: \" << nn_graph_->run_cnt();\n      buffer_mgr->Get(GetSourceTickBufferName(job_name))->Push(job_instance);\n      VLOG(3) << \" VM Push SourceTick to Graph: \" << nn_graph_->job_name()\n              << \" in run_cnt: \" << nn_graph_->run_cnt();\n    }\n    OF_PROFILER_RANGE_GUARD(\"EnqueueNNGraph\");\n    lazy_job_stream_policy->EnqueueNNGraph(nn_graph_);\n    VLOG(3) << \" VM Enqueue Graph: \" << nn_graph_->job_name()\n            << \" run_cnt: \" << nn_graph_->run_cnt() << \" END.\";\n    nn_graph_->NextRunCnt();\n  }\n\n private:\n  LazyJobStreamPolicy* GetLazyJobStreamPolicy(Instruction* instruction) const {\n    StreamPolicy* stream_policy = instruction->mut_stream()->mut_stream_policy();\n    LazyJobStreamPolicy* lazy_job_stream_policy = dynamic_cast<LazyJobStreamPolicy*>(stream_policy);\n    CHECK_NOTNULL(lazy_job_stream_policy);\n    return lazy_job_stream_policy;\n  }\n\n  std::shared_ptr<LazyJobInstance> MakeJobInstance(Instruction* instruction) const {\n    const auto& FinishCb = [this, instruction]() {\n      auto* lazy_job_stream_policy = GetLazyJobStreamPolicy(instruction);\n      lazy_job_stream_policy->DequeueNNGraph();\n      auto* status_buffer = instruction->mut_status_buffer();\n      NaiveInstrStatusQuerier::MutCast(status_buffer->mut_buffer())->set_done();\n    };\n    return std::make_shared<LazyJobInstance>(nn_graph_->job_name(), FinishCb);\n  }\n\n  std::shared_ptr<NNGraphIf> nn_graph_;\n  EagerBlobObjectListPtr param_blob_objects_;\n  DependenceVector input_dependences_;\n  DependenceVector output_dependences_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n#endif  // ONEFLOW_CORE_VM_LAZY_JOB_INSTRUCTION_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/lazy_job_stream_policy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/vm/lazy_job_stream_policy.h\"\n#include \"oneflow/core/vm/thread_ctx.h\"\n#include \"oneflow/core/vm/naive_instruction_status_querier.h\"\n#include \"oneflow/core/framework/nn_graph_if.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nvoid LazyJobStreamPolicy::WaitUntilQueueEmptyIfFrontNNGraphNotEquals(\n    const std::shared_ptr<NNGraphIf>& nn_graph) {\n  std::unique_lock<std::mutex> lock(mutex_);\n  if (queue_.empty()) { return; }\n  const auto& last_nn_graph = queue_.front().lock();\n  if (!last_nn_graph) { return; }\n  if (last_nn_graph == nn_graph) { return; }\n  cond_.wait(lock, [this]() { return queue_.empty(); });\n}\n\nvoid LazyJobStreamPolicy::EnqueueNNGraph(const std::shared_ptr<NNGraphIf>& nn_graph) {\n  std::unique_lock<std::mutex> lock(mutex_);\n  queue_.emplace(nn_graph);\n}\n\nvoid LazyJobStreamPolicy::DequeueNNGraph() {\n  std::unique_lock<std::mutex> lock(mutex_);\n  queue_.pop();\n  cond_.notify_all();\n}\n\nvoid LazyJobStreamPolicy::InitInstructionStatus(const Stream& stream,\n                                                InstructionStatusBuffer* status_buffer) const {\n  static_assert(sizeof(NaiveInstrStatusQuerier) < kInstructionStatusBufferBytes, \"\");\n  NaiveInstrStatusQuerier::PlacementNew(status_buffer->mut_buffer());\n}\n\nvoid LazyJobStreamPolicy::DeleteInstructionStatus(const Stream& stream,\n                                                  InstructionStatusBuffer* status_buffer) const {\n  auto* ptr = NaiveInstrStatusQuerier::MutCast(status_buffer->mut_buffer());\n  ptr->~NaiveInstrStatusQuerier();\n}\n\nbool LazyJobStreamPolicy::QueryInstructionStatusLaunched(\n    const Stream& stream, const InstructionStatusBuffer& status_buffer) const {\n  return NaiveInstrStatusQuerier::Cast(status_buffer.buffer())->launched();\n}\n\nbool LazyJobStreamPolicy::QueryInstructionStatusDone(\n    const Stream& stream, const InstructionStatusBuffer& status_buffer) const {\n  return NaiveInstrStatusQuerier::Cast(status_buffer.buffer())->done();\n}\n\nvoid LazyJobStreamPolicy::Run(Instruction* instruction) const { instruction->Compute(); }\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
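  {
    "path": "examples/vm/serialized_launch_queue_sketch.cpp",
    "content": "// Editor's sketch: a hypothetical, standalone illustration (not part of the\n// original tree) of the mutex/condition-variable discipline in\n// lazy_job_stream_policy.cpp above: EnqueueNNGraph pushes under the lock,\n// DequeueNNGraph pops and notifies, and the wait helper blocks until the\n// queue drains when a different graph is already in flight. A plain int id\n// stands in for std::weak_ptr<NNGraphIf>.\n#include <condition_variable>\n#include <mutex>\n#include <queue>\n#include <thread>\n\nclass SerializedLaunchQueue {\n public:\n  void WaitUntilEmptyIfFrontNotEquals(int id) {\n    std::unique_lock<std::mutex> lock(mutex_);\n    if (queue_.empty() || queue_.front() == id) { return; }\n    cond_.wait(lock, [this]() { return queue_.empty(); });\n  }\n  void Enqueue(int id) {\n    std::unique_lock<std::mutex> lock(mutex_);\n    queue_.push(id);\n  }\n  void Dequeue() {\n    std::unique_lock<std::mutex> lock(mutex_);\n    queue_.pop();\n    cond_.notify_all();\n  }\n\n private:\n  std::queue<int> queue_;\n  std::mutex mutex_;\n  std::condition_variable cond_;\n};\n\nint main() {\n  SerializedLaunchQueue queue;\n  queue.Enqueue(1);\n  // A worker retires graph 1 while this thread waits to launch graph 2.\n  std::thread worker([&queue]() { queue.Dequeue(); });\n  queue.WaitUntilEmptyIfFrontNotEquals(2);\n  worker.join();\n  queue.Enqueue(2);\n  queue.Dequeue();\n  return 0;\n}\n"
  },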
  {
    "path": "oneflow/core/vm/lazy_job_stream_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_VM_LAZY_JOB_STREAM_POLICY_H_\n#define ONEFLOW_CORE_VM_LAZY_JOB_STREAM_POLICY_H_\n\n#include \"oneflow/core/vm/stream_policy.h\"\n#include \"oneflow/core/vm/instruction.h\"\n\nnamespace oneflow {\n\nclass NNGraphIf;\n\nnamespace vm {\n\nclass LazyJobStreamPolicy final : public StreamPolicy {\n public:\n  LazyJobStreamPolicy() = default;\n  virtual ~LazyJobStreamPolicy() = default;\n\n  vm::Allocator* mut_allocator() override { return (vm::Allocator*)nullptr; }\n\n  DeviceType device_type() const override {\n    UNIMPLEMENTED();\n    return DeviceType::kInvalidDevice;\n  }\n\n  ep::Stream* stream() override {\n    UNIMPLEMENTED();\n    return nullptr;\n  }\n\n  void WaitUntilQueueEmptyIfFrontNNGraphNotEquals(const std::shared_ptr<NNGraphIf>& nn_graph);\n\n  void EnqueueNNGraph(const std::shared_ptr<NNGraphIf>& nn_graph);\n\n  void DequeueNNGraph();\n\n  void InitInstructionStatus(const Stream& stream,\n                             InstructionStatusBuffer* status_buffer) const override;\n  void DeleteInstructionStatus(const Stream& stream,\n                               InstructionStatusBuffer* status_buffer) const override;\n  bool QueryInstructionStatusLaunched(const Stream& stream,\n                                      const InstructionStatusBuffer& status_buffer) const override;\n  bool QueryInstructionStatusDone(const Stream& stream,\n                                  const InstructionStatusBuffer& status_buffer) const override;\n  void Run(Instruction* instruction) const override;\n  bool SupportingTransportInstructions() const override { return false; }\n\n private:\n  std::queue<std::weak_ptr<NNGraphIf>> queue_;\n  std::mutex mutex_;\n  std::condition_variable cond_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_LAZY_JOB_STREAM_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/naive_instruction_status_querier.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_NAIVE_VM_INSTRUCTION_STATUS_QUERIER_H_\n#define ONEFLOW_CORE_VM_NAIVE_VM_INSTRUCTION_STATUS_QUERIER_H_\n\n#include <atomic>\n\nnamespace oneflow {\nnamespace vm {\n\nclass NaiveInstrStatusQuerier {\n public:\n  ~NaiveInstrStatusQuerier() = default;\n\n  bool launched() const { return done_; }\n  bool done() const { return done_; }\n  void set_done() { done_ = true; }\n\n  static const NaiveInstrStatusQuerier* Cast(const char* mem_ptr) {\n    return reinterpret_cast<const NaiveInstrStatusQuerier*>(mem_ptr);\n  }\n  static NaiveInstrStatusQuerier* MutCast(char* mem_ptr) {\n    return reinterpret_cast<NaiveInstrStatusQuerier*>(mem_ptr);\n  }\n  static NaiveInstrStatusQuerier* PlacementNew(char* mem_ptr) {\n    return new (mem_ptr) NaiveInstrStatusQuerier();\n  }\n\n private:\n  NaiveInstrStatusQuerier() : done_(false) {}\n  std::atomic<bool> done_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_NAIVE_VM_INSTRUCTION_STATUS_QUERIER_H_\n"
  },
  {
    "path": "oneflow/core/vm/op_call_instruction_policy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/vm/op_call_instruction_policy.h\"\n#include <fmt/format.h>\n#include <algorithm>\n#include \"oneflow/core/common/env_var/vm.h\"\n#include \"oneflow/core/vm/allocator.h\"\n#include \"oneflow/core/vm/remat/allocator.h\"\n#include \"oneflow/core/vm/remat/disjoint_set.h\"\n#include \"oneflow/core/vm/remat/env.h\"\n#include \"oneflow/core/vm/remat/util.h\"\n#include \"oneflow/user/kernels/stateful_opkernel.h\"\n#include \"oneflow/core/eager/dev_vm_dep_object_consume_mode.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/framework/stream_is_comm_net_stream.h\"\n#include \"oneflow/core/framework/stream_get_stream_type_name.h\"\n#include \"oneflow/core/vm/stream_get_allocator_stream_type.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"fmt/core.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nstruct OpCallInstructionUtil final {\n  static inline Maybe<void> Prepare(OpCallInstructionPolicy* op_call_instruction_policy,\n                                    Instruction* instruction) {\n    VLOG_REMAT(1) << \"prepare \" << op_call_instruction_policy->opkernel().op_type_name()\n                  << std::endl;\n    if (unlikely(op_call_instruction_policy->need_temp_storage())) {\n      InferTempStorageSize(op_call_instruction_policy);\n    }\n    return Maybe<void>::Ok();\n  }\n\n  static inline Maybe<void> Compute(OpCallInstructionPolicy* op_call_instruction_policy,\n                                    vm::Stream* vm_stream, bool first, bool recompute) {\n    Allocator* allocator = vm_stream->mut_stream_policy()->mut_allocator();\n    const auto [remat_helper, inputs_rematable, outputs_rematable] =\n        InitRematInfo(op_call_instruction_policy, vm_stream);\n    const auto& current_op_type_name = op_call_instruction_policy->opkernel().op_type_name();\n    ThreadLocalGuard<remat::CurrentOpTypeName> current_op_type_name_guard({current_op_type_name});\n    if (inputs_rematable || outputs_rematable) {\n      VLOG_REMAT(2) << \"set current op type name to \" << current_op_type_name << std::endl;\n      VLOG_REMAT(2) << \"op: \" << op_call_instruction_policy->opkernel().op_type_name() << std::endl;\n      VLOG_REMAT(2) << \"input_rematable: \" << inputs_rematable\n                    << \", output_rematable: \" << outputs_rematable << std::endl;\n    }\n    if (inputs_rematable) { JUST(remat_helper->RematInputs(vm_stream, first, ComputeFnForRemat)); }\n    JUST(AllocateOutputBlobsMemory(op_call_instruction_policy, allocator, vm_stream));\n    if (unlikely(op_call_instruction_policy->need_temp_storage())) {\n      JUST(TryAllocateTempStorage(op_call_instruction_policy, allocator));\n    }\n    ep::Stream* stream = vm_stream->mut_stream_policy()->stream();\n    user_op::OpKernelState* state = nullptr;\n    user_op::OpKernelCache* cache = nullptr;\n    if (op_call_instruction_policy->user_opkernel()->has_state_or_cache()) {\n      
TryInitOpKernelStateAndCache(op_call_instruction_policy, stream, &state, &cache);\n    }\n    OpKernelCompute(op_call_instruction_policy, stream, state, cache);\n    if (unlikely(op_call_instruction_policy->need_temp_storage())) {\n      DeallocateTempStorage(op_call_instruction_policy, allocator);\n    }\n    if (inputs_rematable) { JUST(remat_helper->EagerlyEvictRemattedTensors(first)); }\n    if (inputs_rematable || outputs_rematable) {\n      JUST(remat_helper->UpdateRematInfo(first, recompute, inputs_rematable, outputs_rematable));\n    }\n    return Maybe<void>::Ok();\n  }\n\n private:\n  static inline void InferTempStorageSize(OpCallInstructionPolicy* op_call_instruction_policy) {\n    auto* tmp_tensor = op_call_instruction_policy->mut_call_ctx()->mut_tmp_tensor();\n    size_t temp_size = op_call_instruction_policy->opkernel().InferTmpSize(\n        op_call_instruction_policy->mut_call_ctx(), op_call_instruction_policy->user_opkernel());\n    tmp_tensor->set_tmp_buffer_size(temp_size);\n  }\n\n  static inline void TryInitOpKernelStateAndCache(\n      OpCallInstructionPolicy* op_call_instruction_policy, ep::Stream* stream,\n      user_op::OpKernelState** state, user_op::OpKernelCache** cache) {\n    OF_PROFILER_RANGE_GUARD(\"TryInitOpKernelStateAndCache\");\n    if (likely(op_call_instruction_policy->op_interp_ctx().state)) {\n      *state = op_call_instruction_policy->op_interp_ctx().state.get();\n      // set state to nullptr so that state initialization in TryInitOpKernelStateAndCache will be\n      // skipped.\n      state = nullptr;\n    }\n    op_call_instruction_policy->mut_opkernel()->TryInitOpKernelStateAndCache(\n        op_call_instruction_policy->mut_call_ctx(), stream,\n        op_call_instruction_policy->user_opkernel(), state, cache);\n  }\n\n  // Returns true if allocation happened.\n  static inline Maybe<void> AllocateOutputBlobsMemory(\n      OpCallInstructionPolicy* op_call_instruction_policy, Allocator* allocator,\n      const vm::Stream* vm_stream) {\n    OF_PROFILER_RANGE_GUARD(\"AllocateOutputBlobsMemory\");\n    StreamType stream_type = vm_stream->stream_type();\n    StreamType allocator_stream_type = JUST(GetAllocatorStreamType::Visit(stream_type));\n    for (const auto& blob_object : op_call_instruction_policy->outputs()) {\n      if (JUST(blob_object->TryAllocateBlobBodyMemory(allocator))) {\n        CHECK_OR_RETURN(stream_type == allocator_stream_type)\n            << \"no allocator supported on stream type \" << GetStreamTypeName::Visit(stream_type);\n        if (auto* dtr_allocator = dynamic_cast<vm::DtrEpAllocatorProxy*>(allocator)) {\n          dtr_allocator->allocator->LinkStorageAndPtr(\n              dynamic_cast<RematableTensorStorage*>(blob_object->tensor_storage().get()),\n              static_cast<const char*>(blob_object->dptr()));\n        }\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n\n  static inline Maybe<void> TryAllocateTempStorage(\n      OpCallInstructionPolicy* op_call_instruction_policy, Allocator* allocator) {\n    OF_PROFILER_RANGE_GUARD(\"TryAllocateTempStorage\");\n    auto* tmp_tensor = op_call_instruction_policy->mut_call_ctx()->mut_tmp_tensor();\n    size_t byte_size = tmp_tensor->tmp_buffer_size();\n    if (byte_size > 0) {\n      char* mem_ptr = nullptr;\n      JUST(allocator->Allocate(&mem_ptr, byte_size));\n      tmp_tensor->set_tmp_buffer_ptr(mem_ptr);\n    }\n    return Maybe<void>::Ok();\n  }\n\n  static inline void DeallocateTempStorage(OpCallInstructionPolicy* op_call_instruction_policy,\n                      
                     Allocator* allocator) {\n    auto* tmp_tensor = op_call_instruction_policy->mut_call_ctx()->mut_tmp_tensor();\n    allocator->Deallocate(tmp_tensor->mut_tmp_buffer_ptr(), tmp_tensor->tmp_buffer_size());\n    tmp_tensor->set_tmp_buffer_ptr(nullptr);\n  }\n\n  static inline void OpKernelCompute(OpCallInstructionPolicy* op_call_instruction_policy,\n                                     ep::Stream* stream, user_op::OpKernelState* state,\n                                     user_op::OpKernelCache* cache) {\n    auto* user_kernel = op_call_instruction_policy->user_opkernel();\n    op_call_instruction_policy->mut_opkernel()->Compute(op_call_instruction_policy->mut_call_ctx(),\n                                                        stream, user_kernel, state, cache);\n  }\n\n  static inline Maybe<void> ComputeFnForRemat(OpCallInstructionPolicy* op_call_instruction_policy,\n                                              vm::Stream* vm_stream) {\n    return Compute(op_call_instruction_policy, vm_stream, false, true);\n  }\n\n  static inline std::tuple<std::unique_ptr<RematHelper>, bool, bool> InitRematInfo(\n      OpCallInstructionPolicy* op_call_instruction_policy, vm::Stream* vm_stream) {\n    bool inputs_rematable = false;\n    bool outputs_rematable = false;\n    if (op_call_instruction_policy->opkernel().op_type_name() == \"copy\") {\n      inputs_rematable =\n          op_call_instruction_policy->inputs()[0]->tensor_storage()->device()->rematable();\n      outputs_rematable =\n          op_call_instruction_policy->outputs()[0]->tensor_storage()->device()->rematable();\n    } else {\n      inputs_rematable = vm_stream->device()->rematable();\n      outputs_rematable = vm_stream->device()->rematable();\n    }\n    std::unique_ptr<RematHelper> remat_helper;\n    if (inputs_rematable || outputs_rematable) {\n      remat_helper = std::make_unique<RematHelper>(*op_call_instruction_policy, inputs_rematable,\n                                                   outputs_rematable);\n    }\n    return std::make_tuple(std::move(remat_helper), inputs_rematable, outputs_rematable);\n  }\n};\n\nOpCallInstructionPolicy::OpCallInstructionPolicy(\n    Stream* vm_stream, const std::shared_ptr<one::StatefulOpKernel>& opkernel,\n    EagerBlobObjectList&& inputs, EagerBlobObjectList&& outputs,\n    const std::shared_ptr<const one::GlobalTensorInferResult>& global_tensor_infer_result,\n    const one::OpExprInterpContext& op_interp_ctx,\n    const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode)\n    : vm_stream_(vm_stream),\n      call_ctx_(ComposedAttrMap(op_interp_ctx.attrs, opkernel->base_attrs()), std::move(inputs),\n                std::move(outputs), global_tensor_infer_result, op_interp_ctx,\n                opkernel->mem_case()),\n      opkernel_(opkernel),\n      user_opkernel_(nullptr),\n      infer_tmp_size_fn_(nullptr),\n      need_temp_storage_(false),\n      dev_vm_dep_object_consume_mode_(dev_vm_dep_object_consume_mode),\n      input_dependences_(),\n      output_dependences_() {\n  ForEachConstDependence([&](auto* dep) { input_dependences_.emplace_back(dep); });\n  ForEachMutDependence([&](auto* dep) { output_dependences_.emplace_back(dep); });\n  ForEachMut2Dependence([&](auto* dep) { output_dependences_.emplace_back(dep); });\n  InitStreamSequentialDependence();\n}\n\nMaybe<void> OpCallInstructionPolicy::Init() {\n  return mut_opkernel()->ChooseOpKernel(&call_ctx_, &user_opkernel_, &need_temp_storage_);\n}\n\nOpCallInstructionPolicy::OpCallInstructionPolicy(const 
DtrOpCallInstructionPolicy& policy)\n    : vm_stream_(policy.vm_stream_),\n      call_ctx_(policy.dtr_call_ctx_),\n      opkernel_(policy.opkernel_),\n      user_opkernel_(policy.user_opkernel_),\n      infer_tmp_size_fn_(policy.infer_tmp_size_fn_),\n      need_temp_storage_(policy.need_temp_storage_),\n      dev_vm_dep_object_consume_mode_(policy.dev_vm_dep_object_consume_mode_),\n      input_dependences_(policy.input_dependences_),\n      output_dependences_(policy.output_dependences_) {}\n\ntemplate<typename DoEachT>\nvoid OpCallInstructionPolicy::ForEachConstDependence(const DoEachT& DoEach) const {\n  const auto& input_list = inputs();\n  for (int64_t index : opkernel().input_tuple_indexes4const_ibns()) {\n    const auto& input = input_list.at(index);\n    DoEach(CHECK_JUST(input->compute_local_dep_object()));\n  }\n}\n\nvoid OpCallInstructionPolicy::InitStreamSequentialDependence() {\n  auto* device_schedule_dep_object = vm_stream_->schedule_local_dep_object().get();\n  if (IsCommNetStream::Visit(vm_stream_->stream_type())) {\n    // Sequantialize nccl instructions to avoid deadlock\n    stream_sequential_dependence_ = device_schedule_dep_object;\n  } else {\n    // Sequantialize instructions to avoid explosive memory allocation of source ops\n    if (dev_vm_dep_object_consume_mode() == one::DevVmDepObjectConsumeMode::MUTABLE) {\n      stream_sequential_dependence_ = device_schedule_dep_object;\n    } else if (opkernel().input_tuple_indexes4const_ibns().empty()\n               && opkernel().input_tuple_indexes4mut_ibns().empty()) {\n      stream_sequential_dependence_ = device_schedule_dep_object;\n    }\n  }\n}\n\ntemplate<typename DoEachT>\nvoid OpCallInstructionPolicy::ForEachMutDependence(const DoEachT& DoEach) const {\n  for (const auto& transport_dependence : vm_stream_->transport_dependences()) {\n    DoEach(transport_dependence.get());\n  }\n\n  const auto& input_list = inputs();\n  for (int64_t index : opkernel().input_tuple_indexes4mut_ibns()) {\n    const auto& input = input_list.at(index);\n    DoEach(CHECK_JUST(input->compute_local_dep_object()));\n  }\n  const auto& output_list = outputs();\n  for (int64_t index : opkernel().output_tuple_indexes4mut_obns()) {\n    const auto& output = output_list.at(index);\n    DoEach(CHECK_JUST(output->compute_local_dep_object()));\n  }\n}\n\ntemplate<typename DoEachT>\nvoid OpCallInstructionPolicy::ForEachMut2Dependence(const DoEachT& DoEach) const {\n  const auto& output_list = outputs();\n  for (int64_t index : opkernel().output_tuple_indexes4mut2_obns()) {\n    const auto& output = output_list.at(index);\n    DoEach(CHECK_JUST(output->compute_local_dep_object()));\n  }\n}\n\nMaybe<void> OpCallInstructionPolicy::Prepare(vm::Instruction* instruction) {\n  return OpCallInstructionUtil::Prepare(this, instruction);\n}\n\nvoid OpCallInstructionPolicy::Compute(vm::Instruction* instruction) {\n  CHECK_JUST_MSG(OpCallInstructionUtil::Compute(this, instruction->mut_stream(), true, false),\n                 instruction->DebugName());\n}\n\nstd::string OpCallInstructionPolicy::DebugName(const vm::Instruction& instruction) const {\n  return opkernel().op_type_name() + \":OpCall\";\n}\n\nMaybe<void> Recompute(OpCallInstructionPolicy* op_call_instruction_policy, vm::Stream* vm_stream) {\n  VLOG_REMAT(1) << \"recompute \" << op_call_instruction_policy->opkernel().op_type_name()\n                << \" manually\";\n  return OpCallInstructionUtil::Compute(op_call_instruction_policy, vm_stream, true, true);\n}\n\n}  // namespace vm\n}  // namespace 
oneflow\n"
  },
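  {
    "path": "examples/vm/temp_storage_lifecycle_sketch.cpp",
    "content": "// Editor's sketch: a hypothetical, standalone illustration (not part of the\n// original tree) of the temp-storage lifecycle in OpCallInstructionUtil\n// above: the size is inferred during Prepare, the buffer is allocated only\n// when needed (TryAllocateTempStorage), and it is released right after the\n// kernel computes (DeallocateTempStorage). MallocAllocator is a stand-in for\n// vm::Allocator.\n#include <cstddef>\n#include <cstdlib>\n\nstruct MallocAllocator {\n  void Allocate(char** mem_ptr, std::size_t size) {\n    *mem_ptr = static_cast<char*>(std::malloc(size));\n  }\n  void Deallocate(char* mem_ptr, std::size_t /*size*/) { std::free(mem_ptr); }\n};\n\nvoid RunKernelWithTempStorage(MallocAllocator* allocator, std::size_t tmp_buffer_size) {\n  char* tmp_buffer_ptr = nullptr;\n  if (tmp_buffer_size > 0) { allocator->Allocate(&tmp_buffer_ptr, tmp_buffer_size); }\n  // ... the kernel's Compute would consume tmp_buffer_ptr here ...\n  if (tmp_buffer_size > 0) { allocator->Deallocate(tmp_buffer_ptr, tmp_buffer_size); }\n}\n\nint main() {\n  MallocAllocator allocator;\n  RunKernelWithTempStorage(&allocator, 0);     // no temp storage: allocation skipped\n  RunKernelWithTempStorage(&allocator, 1024);  // scratch buffer lives across the launch\n  return 0;\n}\n"
  },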
  {
    "path": "oneflow/core/vm/op_call_instruction_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_OP_CALL_INSTRUCTION_POLICY_H_\n#define ONEFLOW_CORE_VM_OP_CALL_INSTRUCTION_POLICY_H_\n\n#include <memory>\n#include \"oneflow/core/eager/call_context.h\"\n#include \"oneflow/core/eager/dev_vm_dep_object_consume_mode.h\"\n#include \"oneflow/core/framework/user_op_kernel_registry.h\"\n#include \"oneflow/core/vm/instruction_policy.h\"\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/user/kernels/stateful_opkernel.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nclass OpKernel;\n\n}  // namespace user_op\n\nnamespace vm {\n\nclass DtrOpCallInstructionPolicy;\n\nclass OpCallInstructionPolicy final : public InstructionPolicy {\n public:\n  OpCallInstructionPolicy(const OpCallInstructionPolicy& other) = default;\n  OpCallInstructionPolicy(OpCallInstructionPolicy&& other) = default;\n  OpCallInstructionPolicy& operator=(const OpCallInstructionPolicy& other) = delete;\n  OpCallInstructionPolicy& operator=(OpCallInstructionPolicy&& other) = delete;\n  ~OpCallInstructionPolicy() override = default;\n\n  template<typename... Args>\n  static Maybe<OpCallInstructionPolicy> New(Args&&... args) {\n    auto* ptr = new OpCallInstructionPolicy(std::forward<Args>(args)...);\n    JUST(ptr->Init());\n    return std::shared_ptr<OpCallInstructionPolicy>(ptr);\n  }\n\n  const one::StatefulOpKernel& opkernel() const { return *opkernel_; }\n  const EagerBlobObjectList& inputs() const { return call_ctx_.inputs(); }\n  const EagerBlobObjectList& outputs() const { return call_ctx_.outputs(); }\n  EagerBlobObjectList& mut_inputs() { return call_ctx_.mut_inputs(); }\n  EagerBlobObjectList& mut_outputs() { return call_ctx_.mut_outputs(); }\n  const ComposedAttrMap& composed_attrs() const { return call_ctx_.composed_attrs(); }\n  const one::OpExprInterpContext& op_interp_ctx() const { return call_ctx_.op_interp_ctx(); }\n  const one::DevVmDepObjectConsumeMode& dev_vm_dep_object_consume_mode() const {\n    return dev_vm_dep_object_consume_mode_;\n  }\n\n  one::StatefulOpKernel* mut_opkernel() { return opkernel_.get(); }\n\n  template<typename DoEachT>\n  Maybe<void> ForEachOutputTensor(const DoEachT& DoEach) {\n    for (const auto& output : outputs()) { JUST(DoEach(output.get())); }\n    return Maybe<void>::Ok();\n  }\n\n  const DependenceVector& input_dependences() const override { return input_dependences_; }\n  const DependenceVector& output_dependences() const override { return output_dependences_; }\n\n  template<typename DoEachT>\n  void ForEachConstDependence(const DoEachT& DoEach) const;\n\n  template<typename DoEachT>\n  void ForEachMutDependence(const DoEachT& DoEach) const;\n\n  template<typename DoEachT>\n  void ForEachMut2Dependence(const DoEachT& DoEach) const;\n\n  bool need_temp_storage() const { return need_temp_storage_; }\n  const user_op::OpKernel* user_opkernel() const { return user_opkernel_; }\n  const user_op::InferTmpSizeFn& infer_tmp_size_fn() const 
{ return *infer_tmp_size_fn_; }\n\n  const std::shared_ptr<const one::GlobalTensorInferResult>& global_tensor_infer_result() const {\n    return call_ctx_.global_tensor_infer_result();\n  }\n\n  const eager::CallContext& call_ctx() const { return call_ctx_; }\n  eager::CallContext* mut_call_ctx() { return &call_ctx_; }\n\n  Stream* vm_stream() const { return vm_stream_; }\n\n  InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAtAnyPosition; }\n\n  std::string DebugName(const vm::Instruction& instruction) const override;\n\n  explicit OpCallInstructionPolicy(const DtrOpCallInstructionPolicy& policy);\n\n private:\n  OpCallInstructionPolicy(\n      Stream* vm_stream, const std::shared_ptr<one::StatefulOpKernel>& opkernel,\n      EagerBlobObjectList&& inputs, EagerBlobObjectList&& outputs,\n      const std::shared_ptr<const one::GlobalTensorInferResult>& global_tensor_infer_result,\n      const one::OpExprInterpContext& op_interp_ctx,\n      const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode);\n  Maybe<void> Init();\n  void InitStreamSequentialDependence();\n  Maybe<void> Prepare(Instruction* instruction) override;\n  void Compute(Instruction* instruction) override;\n\n  Stream* vm_stream_;\n  eager::CallContext call_ctx_;\n  std::shared_ptr<one::StatefulOpKernel> opkernel_;\n  const user_op::OpKernel* user_opkernel_;\n  const user_op::InferTmpSizeFn* infer_tmp_size_fn_;\n  bool need_temp_storage_;\n  const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode_;\n  DependenceVector input_dependences_;\n  DependenceVector output_dependences_;\n  friend class DtrOpCallInstructionPolicy;\n};\n\nclass DtrOpCallInstructionPolicy {\n  Stream* vm_stream_;\n  eager::DtrCallContext dtr_call_ctx_;\n  std::shared_ptr<one::StatefulOpKernel> opkernel_;\n  const user_op::OpKernel* user_opkernel_;\n  const user_op::InferTmpSizeFn* infer_tmp_size_fn_;\n  bool need_temp_storage_;\n  const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode_;\n  DependenceVector input_dependences_;\n  DependenceVector output_dependences_;\n\n public:\n  explicit DtrOpCallInstructionPolicy(const OpCallInstructionPolicy& op)\n      : vm_stream_(op.vm_stream()),\n        dtr_call_ctx_(op.call_ctx()),\n        opkernel_(op.opkernel_),\n        user_opkernel_(op.user_opkernel_),\n        infer_tmp_size_fn_(op.infer_tmp_size_fn_),\n        need_temp_storage_(op.need_temp_storage()),\n        dev_vm_dep_object_consume_mode_(op.dev_vm_dep_object_consume_mode()),\n        input_dependences_(op.input_dependences()),\n        output_dependences_(op.output_dependences()) {}\n  friend class OpCallInstructionPolicy;\n  EagerBlobObjectList& mut_inputs() { return dtr_call_ctx_.mut_inputs(); }\n  WeakEagerBlobObjectList& mut_outputs() { return dtr_call_ctx_.mut_outputs(); }\n  const one::StatefulOpKernel& opkernel() const { return *opkernel_; }\n};\n\nMaybe<void> Recompute(OpCallInstructionPolicy* op_call_instruction_policy, vm::Stream* vm_stream);\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_OP_CALL_INSTRUCTION_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/pinned_ep_stream_policy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/vm/pinned_ep_stream_policy.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/stream_type.h\"\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/core/vm/thread_ctx.h\"\n#include \"oneflow/core/vm/ep_optional_event_record_status_querier.h\"\n#include \"oneflow/core/vm/ep_backend_host_allocator.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nnamespace {\n\nstd::unique_ptr<BinAllocator<ThreadSafeLock>> CreatePinedEpBackendHostAllocator(\n    Symbol<Device> device) {\n  // TODO:(zhaoluyang) empty/cast/copy op support pin_memory_device\n  DeviceType device_type = device->enum_type();\n  size_t device_index = device->device_id();\n  auto ep_device =\n      Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(device_type, device_index);\n  ep::AllocationOptions options{};\n  options.SetPinnedDevice(device_type, device_index);\n  auto ep_backend_allocator = std::make_unique<EpBackendHostAllocator>(ep_device, options);\n  return std::make_unique<BinAllocator<ThreadSafeLock>>(ep::kMaxAlignmentRequirement,\n                                                        std::move(ep_backend_allocator));\n}\n\n}  // namespace\n\nPinnedEpStreamPolicy::PinnedEpStreamPolicy(Symbol<Device> device)\n    : EpStreamPolicyBase(device, CreatePinedEpBackendHostAllocator(device)) {}\n\nvoid PinnedEpStreamPolicy::InitInstructionStatus(const Stream& stream,\n                                                 InstructionStatusBuffer* status_buffer) const {\n  static_assert(sizeof(EpOptionalEventRecordStatusQuerier) < kInstructionStatusBufferBytes, \"\");\n  auto* data_ptr = status_buffer->mut_buffer();\n  EpOptionalEventRecordStatusQuerier::PlacementNew(data_ptr, nullptr);\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/pinned_ep_stream_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_PINNED_EP_STREAM_POLICY_H_\n#define ONEFLOW_CORE_VM_PINNED_EP_STREAM_POLICY_H_\n\n#include \"oneflow/core/vm/ep_stream_policy_base.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass PinnedEpStreamPolicy final : public EpStreamPolicyBase {\n public:\n  PinnedEpStreamPolicy(Symbol<Device> device);\n  ~PinnedEpStreamPolicy() override = default;\n\n  void InitInstructionStatus(const Stream& stream,\n                             InstructionStatusBuffer* status_buffer) const override;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_PINNED_EP_STREAM_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/probe.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_VM_PROBE_H_\n#define ONEFLOW_CORE_VM_PROBE_H_\n\n#include \"oneflow/core/intrusive/intrusive.h\"\n\nnamespace oneflow {\nnamespace vm {\n\ntemplate<typename ProbeFunction>\nclass Probe final : public intrusive::Base {\n public:\n  Probe(const Probe&) = delete;\n  Probe(Probe&&) = delete;\n\n  Probe() = default;\n  ~Probe() = default;\n\n  void __Init__(const ProbeFunction& probe_function) { probe_function_ = probe_function; }\n\n  const ProbeFunction& probe_function() const { return probe_function_; }\n\n private:\n  friend class intrusive::Ref;\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n  // fields\n  intrusive::Ref intrusive_ref_;\n  ProbeFunction probe_function_;\n\n public:\n  // hooks\n  intrusive::ListHook probe_hook_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_PROBE_H_\n"
  },
  {
    "path": "oneflow/core/vm/ref_cnt_instruction_status_querier.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_REF_CNT_VM_INSTRUCTION_STATUS_QUERIER_H_\n#define ONEFLOW_CORE_VM_REF_CNT_VM_INSTRUCTION_STATUS_QUERIER_H_\n\n#include <atomic>\n#include <memory>\n\nnamespace oneflow {\nnamespace vm {\n\nclass RefCntInstrStatusQuerier {\n public:\n  ~RefCntInstrStatusQuerier() = default;\n\n  bool done() const { return launched_ && *ref_cnt_ == 0; }\n  void SetRefCntAndSetLaunched(const std::shared_ptr<std::atomic<int64_t>>& ref_cnt) {\n    // No lock needed. This function will be called only one time.\n    // In most cases, errors will be successfully detected by CHECK\n    // even though run in different threads.\n    CHECK(!launched_);\n    ref_cnt_ = ref_cnt;\n    launched_ = true;\n  }\n\n  static const RefCntInstrStatusQuerier* Cast(const char* mem_ptr) {\n    return reinterpret_cast<const RefCntInstrStatusQuerier*>(mem_ptr);\n  }\n  static RefCntInstrStatusQuerier* MutCast(char* mem_ptr) {\n    return reinterpret_cast<RefCntInstrStatusQuerier*>(mem_ptr);\n  }\n  static RefCntInstrStatusQuerier* PlacementNew(char* mem_ptr) {\n    return new (mem_ptr) RefCntInstrStatusQuerier();\n  }\n\n private:\n  RefCntInstrStatusQuerier() : launched_(false), ref_cnt_() {}\n\n  std::atomic<bool> launched_;\n  std::shared_ptr<std::atomic<int64_t>> ref_cnt_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_REF_CNT_VM_INSTRUCTION_STATUS_QUERIER_H_\n"
  },
  {
    "path": "oneflow/core/vm/release_tensor_instruction_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_RELEASE_TENSOR_INSTRUCTION_POLICY_H_\n#define ONEFLOW_CORE_VM_RELEASE_TENSOR_INSTRUCTION_POLICY_H_\n\n#include <functional>\n#include <memory>\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/intrusive/intrusive.h\"\n#include \"oneflow/core/vm/ep_optional_event_record_status_querier.h\"\n#include \"oneflow/core/eager/local_dep_object.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/core/framework/stream_need_soft_sync.h\"\n\nnamespace oneflow {\n\nnamespace vm {\n\nclass EagerBlobObject;\n\nclass ReleaseTensorInstructionPolicy : public InstructionPolicy {\n public:\n  ReleaseTensorInstructionPolicy(const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,\n                                 const Optional<vm::Stream*>& stream)\n      : eager_blob_object_(eager_blob_object), output_dependences_() {\n    output_dependences_.push_back(CHECK_JUST(eager_blob_object->compute_local_dep_object()));\n    if (stream.has_value()) {\n      stream_sequential_dependence_ = CHECK_JUST(stream)->schedule_local_dep_object().get();\n    }\n  }\n  ~ReleaseTensorInstructionPolicy() override = default;\n\n  const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object() const {\n    return eager_blob_object_;\n  }\n\n  const DependenceVector& input_dependences() const override {\n    static thread_local DependenceVector empty{};\n    return empty;\n  }\n\n  const DependenceVector& output_dependences() const override { return output_dependences_; }\n\n  Dependence* stream_sequential_dependence() const override {\n    return stream_sequential_dependence_;\n  }\n\n protected:\n  void Release(const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) const {\n    CHECK_JUST(eager_blob_object->DeallocateBlobDataPtr());\n  }\n\n private:\n  void InitInstructionStatus(Instruction* instruction) override {\n    auto* status_buffer = instruction->mut_status_buffer();\n    auto* stream = instruction->mut_stream();\n    instruction->stream_policy().InitInstructionStatus(*stream, status_buffer);\n    auto* data_ptr = status_buffer->mut_buffer();\n    EpOptionalEventRecordStatusQuerier::MutCast(data_ptr)->reset_ep_event(nullptr);\n  }\n  std::shared_ptr<vm::EagerBlobObject> eager_blob_object_;\n  DependenceVector output_dependences_;\n};\n\nclass FastReleaseTensorInstructionPolicy final : public ReleaseTensorInstructionPolicy {\n public:\n  using ReleaseTensorInstructionPolicy::ReleaseTensorInstructionPolicy;\n\n  bool Prescheduleable(const vm::Stream* src, const vm::Stream* dst) const override {\n    return false;\n  }\n\n private:\n 
 std::string DebugName(const vm::Instruction& instruction) const override {\n    return \"FastReleaseTensor\";\n  }\n\n  Maybe<void> Prepare(vm::Instruction* instruction) override {\n    DataType data_type = eager_blob_object()->data_type();\n    CHECK_OR_RETURN(IsTriviallyCopyableDataType(data_type));\n    if (eager_blob_object()->tensor_storage()->is_allocated_in_vm()) {\n      Release(eager_blob_object());\n    }\n    return Maybe<void>::Ok();\n  }\n\n  void Compute(vm::Instruction* instruction) override {\n    if (!eager_blob_object()->tensor_storage()->is_allocated_in_vm()) {\n      Release(eager_blob_object());\n    }\n  }\n};\n\nclass SlowReleaseTensorInstructionPolicy final : public ReleaseTensorInstructionPolicy {\n public:\n  using ReleaseTensorInstructionPolicy::ReleaseTensorInstructionPolicy;\n\n private:\n  std::string DebugName(const vm::Instruction& instruction) const override {\n    return \"SlowReleaseTensor\";\n  }\n\n  Maybe<void> Prepare(vm::Instruction* instruction) override { return Maybe<void>::Ok(); }\n\n  void Compute(vm::Instruction* instruction) override { Release(eager_blob_object()); }\n};\n\nstruct MakeReleaseTensorInstructionPolicy\n    : public StreamTypeVisitor<MakeReleaseTensorInstructionPolicy> {\n  static Maybe<vm::InstructionPolicy> VisitCompute(\n      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,\n      const Optional<vm::Stream*>& stream) {\n    return Make(eager_blob_object, stream);\n  }\n  static Maybe<vm::InstructionPolicy> VisitHost2Device(\n      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,\n      const Optional<vm::Stream*>& stream) {\n    return Make(eager_blob_object, stream);\n  }\n  static Maybe<vm::InstructionPolicy> VisitDevice2Host(\n      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,\n      const Optional<vm::Stream*>& stream) {\n    return Make(eager_blob_object, stream);\n  }\n  static Maybe<vm::InstructionPolicy> VisitCcl(\n      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,\n      const Optional<vm::Stream*>& stream) {\n    return Make(eager_blob_object, stream);\n  }\n  static Maybe<vm::InstructionPolicy> VisitBarrier(\n      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,\n      const Optional<vm::Stream*>& stream) {\n    UNIMPLEMENTED_THEN_RETURN() << \"ReleaseTensor instruction not supported in Barrier stream\";\n  }\n  static Maybe<vm::InstructionPolicy> VisitCriticalSection(\n      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,\n      const Optional<vm::Stream*>& stream) {\n    UNIMPLEMENTED_THEN_RETURN()\n        << \"ReleaseTensor instruction not supported in CriticalSection stream\";\n  }\n  static Maybe<vm::InstructionPolicy> VisitLazyJobLauncher(\n      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,\n      const Optional<vm::Stream*>& stream) {\n    UNIMPLEMENTED_THEN_RETURN()\n        << \"ReleaseTensor instruction not supported in LaunchLazyJob stream\";\n  }\n  static Maybe<vm::InstructionPolicy> VisitPinnedCompute(\n      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,\n      const Optional<vm::Stream*>& stream) {\n    return VisitCompute(eager_blob_object, stream);\n  }\n\n private:\n  static Maybe<vm::InstructionPolicy> Make(\n      const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,\n      const Optional<vm::Stream*>& stream) {\n    DataType data_type = eager_blob_object->data_type();\n    if (!IsTriviallyCopyableDataType(data_type)) {\n      return 
std::shared_ptr<vm::InstructionPolicy>(\n          new vm::SlowReleaseTensorInstructionPolicy(eager_blob_object, stream));\n    }\n    Symbol<oneflow::Stream> last_used_stream = JUST(eager_blob_object->last_used_stream());\n    DeviceType device_type = last_used_stream->device()->enum_type();\n    if (NeedSoftSync::Visit(last_used_stream->stream_type(), device_type)) {\n      return std::shared_ptr<vm::InstructionPolicy>(\n          new vm::SlowReleaseTensorInstructionPolicy(eager_blob_object, stream));\n    } else {\n      return std::shared_ptr<vm::InstructionPolicy>(\n          new vm::FastReleaseTensorInstructionPolicy(eager_blob_object, stream));\n    }\n  }\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_RELEASE_TENSOR_INSTRUCTION_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/remat/allocator.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <iterator>\n#include <vector>\n#include \"nlohmann/json.hpp\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n#include \"oneflow/core/common/thread_local_guard.h\"\n#include \"oneflow/core/ep/include/device_manager_registry.h\"\n#include \"oneflow/core/profiler/util.h\"\n\n#include \"oneflow/core/common/env_var/remat.h\"\n#include \"oneflow/core/vm/ep_backend_allocator.h\"\n#include \"oneflow/core/vm/remat/allocator.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/vm/remat/env.h\"\n#include \"oneflow/core/vm/remat/util.h\"\n#include \"oneflow/core/vm/thread_safe_guard.h\"\n#include \"oneflow/core/vm/remat/disjoint_set.h\"\n#include <iostream>\n\nnamespace oneflow {\nnamespace vm {\n\nnamespace {\n\ninline size_t CudaMemAlignedBytes(size_t bytes) { return RoundUp(bytes, kCudaMemAllocAlignSize); }\n\ninline bool IsAlignedSize(size_t size) { return size % kCudaMemAllocAlignSize == 0; }\n\ninline double bytes2Mb(size_t bytes) { return bytes * 1. / 1024 / 1024; }\n\nstatic constexpr size_t kSmallPieceThreshold = 10 * 1024;  // 10 KB\n\ninline bool ShouldBeHeldBySmallPiece(size_t size) {\n  return Singleton<remat::Env>::Get()->is_small_pieces_optimization_enabled()\n         && size <= kSmallPieceThreshold;\n}\n\nstd::vector<size_t> GroupNumToIndexes(size_t group_num) {\n  switch (group_num) {\n    case 1: return {0};\n    case 2: return {0, 1};\n    case 3: return {0, 1, 2};\n    case 4: return {0, 1, 3, 2};\n    case 6: return {3, 1, 0, 5, 4, 2};\n  }\n  UNIMPLEMENTED();\n}\n\n}  // namespace\n\nRematEpAllocator::RematEpAllocator(size_t alignment, std::unique_ptr<Allocator>&& backend)\n    : Allocator(),\n      alignment_(alignment),\n      backend_(std::move(backend)),\n      memory_size_(0),\n      recycle_piece_list_(nullptr),\n      normal_group_num_(EnvInteger<ONEFLOW_REMAT_GROUP_NUM>()),\n      group_indexes_(GroupNumToIndexes(normal_group_num_)),\n      cur_group_index_id_(normal_group_num_ > 1 ? 
1 : 0),\n      cur_group_index_id_high_cost_(0),\n      enable_left_and_right_(normal_group_num_ > 1) {\n  free_pieces_overlapping_with_group_.resize(normal_group_num_ + 1);\n}\n\nRematEpAllocator::~RematEpAllocator() {\n  if (memory_ != nullptr) { backend_->Deallocate(static_cast<char*>(memory_), memory_size_); }\n}\n\nRematEpAllocator::offset_t RematEpAllocator::get_offset(const char* mem_ptr) const {\n  return mem_ptr - (char*)memory_;\n}\n\nvoid RematEpAllocator::LinkStorageAndPtr(RematableTensorStorage* storage, const char* mem_ptr) {\n  Piece* piece = ptr2piece_.at(mem_ptr);\n  piece->tensor = storage;\n  CHECK_NOTNULL(piece->tensor);\n  VLOG(1) << \"tensor \" << piece->tensor->id() << \" is allocated at \" << get_offset(mem_ptr)\n          << \", left: \" << piece->is_left;\n}\n\nMaybe<bool> RematEpAllocator::InSmallMemoryArea(void* ptr) {\n  CHECK_NOTNULL_OR_RETURN(small_piece_area_ptr_);\n  CHECK_GE_OR_RETURN(ptr, memory_);\n  CHECK_LT_OR_RETURN(ptr, (char*)memory_ + memory_size_);\n  // compare pointer by raw < or > is undefined behavior\n  return std::greater_equal<>{}(ptr, small_piece_area_ptr_);\n}\n\nRematEpAllocator::Piece* RematEpAllocator::AllocatePiece() {\n  if (recycle_piece_list_) {\n    Piece* ret = recycle_piece_list_;\n    recycle_piece_list_ = recycle_piece_list_->next;\n    return ret;\n  } else {\n    pieces_.emplace_back(new Piece());\n    return pieces_.at(pieces_.size() - 1).get();\n  }\n}\n\nvoid RematEpAllocator::DeallocatePiece(Piece* piece) {\n  piece->ptr = nullptr;\n  piece->size = 0;\n  CHECK(piece->is_free);\n  piece->prev = nullptr;\n  piece->next = recycle_piece_list_;\n  piece->is_left = true;\n  recycle_piece_list_ = piece;\n}\n\nvoid RematEpAllocator::InsertPiece2PtrMap(Piece* piece) {\n  VLOG(2) << \"insert piece, offset \" << get_offset(piece->ptr);\n  CHECK_NOTNULL(piece->ptr);\n  CHECK(ptr2piece_.emplace(piece->ptr, piece).second);\n}\n\nvoid RematEpAllocator::ErasePieceFromPtrMap(Piece* piece) {\n  VLOG(2) << \"erase piece, offset \" << get_offset(piece->ptr);\n  CHECK_NOTNULL(piece->ptr);\n  auto it = ptr2piece_.find(piece->ptr);\n  CHECK(it != ptr2piece_.end());\n  ptr2piece_.erase(it);\n}\n\ndouble get_cost(const vm::RematableTensorStorage* storage) {\n  if (storage == nullptr) { return 0.; }\n  double cost = CHECK_JUST(storage->cost(0));\n\n  CHECK(!std::isnan(cost));\n  return cost;\n}\n\ndouble get_cost(const vm::RematableTensorStorage* storage, size_t size) {\n  if (storage == nullptr) { return 0.; }\n  double cost = CHECK_JUST(storage->cost(size));\n\n  CHECK(!std::isnan(cost));\n  return cost;\n}\n\nvoid RematEpAllocator::CheckPieces() {\n  auto it = ptr2piece_.cbegin();\n  for (int i = 0; i < ptr2piece_.size(); ++i) {\n    Piece* piece = it->second;\n    if (piece->tensor == nullptr) { CHECK(piece->is_free); }\n    if (piece->is_free) { CHECK_ISNULL(piece->tensor); }\n    if (i != 0) {\n      CHECK_EQ(piece->prev->next, piece);\n      CHECK_EQ(piece->prev->ptr + piece->prev->size, piece->ptr);\n      auto it2 = it;\n      --it2;\n      CHECK_EQ(piece->prev, it2->second);\n    }\n    if (i != ptr2piece_.size() - 1) {\n      CHECK_EQ(piece->next->prev, piece);\n      CHECK_EQ(piece->ptr + piece->size, piece->next->ptr);\n      auto it2 = it;\n      ++it2;\n      CHECK_EQ(piece->next, it2->second);\n    }\n    it++;\n  }\n}\n\nvoid RematEpAllocator::DisplayAllPieces() {\n  std::cout << \"ops: \" << Singleton<remat::Env>::Get()->ops.size() << std::endl;\n  for (const auto& pair : ptr2piece_) {\n    Piece* piece = pair.second;\n    
std::stringstream ss;\n    ss << \"piece \" << piece << \", \" << (void*)piece->ptr << \", \" << piece->size << \", \";\n    if (piece->tensor) {\n      ss << \"ebo: \" << piece->tensor << \", id: \" << piece->tensor->id() << \", cost: \"\n         << (piece->tensor->is_eviction_disabled() ? \"disabled\"\n                                                   : std::to_string(get_cost(piece->tensor)))\n         << \", pinned: \" << piece->tensor->num_pinned()\n         << \", evictable: \" << piece->tensor->is_evictable()\n         << \", compute op: \" << piece->tensor->compute_op_type_name();\n    } else {\n      ss << \"no tensor\";\n    }\n    std::cout << ss.str() << std::endl;\n  }\n}\n\nvoid RematEpAllocator::Display() {\n  double total_free_piece_bytes = 0.;\n  for (const auto& free_list : free_pieces_overlapping_with_group_) {\n    for (auto it = free_list.begin(); it != free_list.end(); ++it) {\n      Piece* piece = *it;\n      CHECK(piece->is_free);\n      CHECK_NOTNULL(piece->ptr);\n      CHECK(IsAlignedSize(piece->size));\n      std::cout << \"memory: \" << piece->size * 1. / 1024 / 1024 << \"MB\" << std::endl;\n      total_free_piece_bytes += piece->size;\n    }\n  }\n  std::cout << \"total_free_piece_bytes: \" << bytes2Mb(total_free_piece_bytes) << \"MB\"\n            << \", total allocate bytes: \" << bytes2Mb(total_allocate_bytes_) << \"MB\"\n            << \", total deallocate bytes: \" << bytes2Mb(total_deallocate_bytes_) << \"MB\"\n            << std::endl;\n}\n\n// op-guided placement can only be enabled after left-right mode is enabled\n\nRematEpAllocator::offset_t RematEpAllocator::FindProperPositionInGroup(Piece* piece,\n                                                                       size_t group_idx,\n                                                                       size_t request_size) const {\n  const offset_t grp_left_bound = group_boundaries_[group_idx].first;\n  const offset_t grp_right_bound = group_boundaries_[group_idx].second;\n  const offset_t piece_left_bound = get_offset(piece->ptr);\n  const offset_t piece_right_bound = piece_left_bound + piece->size;\n  const bool is_right =\n      enable_left_and_right_ && (group_idx % 2 == 1) && group_idx != normal_group_num_;\n#define PNT3(var) VLOG(3) << OF_PP_STRINGIZE(var) << \": \" << var << std::endl\n  PNT3(group_idx);\n  PNT3(grp_left_bound);\n  PNT3(grp_right_bound);\n  PNT3(piece_left_bound);\n  PNT3(piece_right_bound);\n  PNT3(is_right);\n  PNT3(request_size);\n\n  if (is_right) {\n    if (grp_right_bound < piece_right_bound) {\n      if (grp_right_bound - request_size > piece_left_bound) {\n        return grp_right_bound - request_size;\n      }\n    }\n    // half of tensor in group\n    if (piece_right_bound - request_size / 2 < grp_right_bound) {\n      return piece_right_bound - request_size;\n    }\n  } else {\n    if (grp_left_bound > piece_left_bound) {\n      if (grp_left_bound + request_size < piece_right_bound) { return grp_left_bound; }\n    }\n    // half of tensor in group\n    if (piece_left_bound + request_size / 2 > grp_left_bound) { return piece_left_bound; }\n  }\n  return SIZE_MAX;\n}\n\nvoid RematEpAllocator::InsertToFreeList(Piece* piece) {\n  const offset_t piece_left = get_offset(piece->ptr);\n  const offset_t piece_right = piece_left + piece->size;\n  VLOG(3) << \"piece_left: \" << piece_left << \", right: \" << piece_right << std::endl;\n  for (size_t i = 0; i < group_boundaries_.size(); i++) {\n    VLOG(3) << \"g left: \" << group_boundaries_[i].first\n            << \", right: \" << 
group_boundaries_[i].second << std::endl;\n    if ((piece_left >= group_boundaries_[i].first && piece_left < group_boundaries_[i].second)\n        || (piece_right > group_boundaries_[i].first\n            && piece_right <= group_boundaries_[i].second)) {\n      VLOG(3) << \"overlap\" << std::endl;\n      free_pieces_overlapping_with_group_[i].insert(piece);\n    }\n  }\n}\n\nvoid RematEpAllocator::EraseFromFreeList(Piece* piece) {\n  VLOG(3) << \"erase \" << get_offset(piece->ptr);\n  // NOTE: very strange bug:\n  // std::set::erase(key) returns 2 instead of 0 or 1, which conflicts with the documentation,\n  // so erase by iterator instead.\n  for (auto& free_list : free_pieces_overlapping_with_group_) {\n    for (auto it = free_list.begin(); it != free_list.end(); it++) {\n      if ((*it)->ptr == piece->ptr) {\n        free_list.erase(it);\n        break;\n      }\n    }\n  }\n}\n\nauto RematEpAllocator::AllocateMemoryInPiece(Piece* piece, offset_t offset_in_piece, size_t size)\n    -> Piece* {\n  auto SplitPiece = [this](Piece* piece, offset_t offset_in_piece) -> Piece* {\n    // offset_in_piece must be strictly less than piece->size so that\n    // new_piece has a non-zero size\n    CHECK_LT(offset_in_piece, piece->size);\n    Piece* new_piece = AllocatePiece();\n    new_piece->ptr = piece->ptr + offset_in_piece;\n    VLOG(2) << get_offset(piece->ptr);\n    new_piece->size = piece->size - offset_in_piece;\n    piece->size = offset_in_piece;\n\n    Piece* next_p = piece->next;\n    piece->next = new_piece;\n    new_piece->prev = piece;\n    new_piece->next = next_p;\n    if (next_p != nullptr) { next_p->prev = new_piece; }\n    InsertPiece2PtrMap(new_piece);\n\n    CHECK(IsAlignedSize(piece->size));\n    CHECK(IsAlignedSize(new_piece->size));\n    return new_piece;\n  };\n  auto SplitPiece3 = [&SplitPiece](\n                         Piece* piece, offset_t offset1_in_piece,\n                         offset_t offset2_in_piece) -> std::tuple<Piece*, Piece*, Piece*> {\n    Piece* piece1 = nullptr;\n    Piece* piece2 = nullptr;\n    Piece* piece3 = nullptr;\n    bool has_piece3 = offset2_in_piece != piece->size;\n    if (offset1_in_piece > 0) {\n      piece1 = piece;\n      piece2 = SplitPiece(piece, offset1_in_piece);\n    } else {\n      piece1 = nullptr;\n      piece2 = piece;\n    }\n    if (has_piece3) { piece3 = SplitPiece(piece2, offset2_in_piece - offset1_in_piece); }\n    return {piece1, piece2, piece3};\n  };\n  auto pieces = SplitPiece3(piece, offset_in_piece, offset_in_piece + size);\n  EraseFromFreeList(piece);\n  Piece *piece1 = std::get<0>(pieces), *piece2 = std::get<1>(pieces), *piece3 = std::get<2>(pieces);\n  if (piece1 != nullptr) {\n    // piece1 is already free\n    InsertToFreeList(piece1);\n  }\n  // piece2->is_free = false;\n  if (piece3 != nullptr) {\n    piece3->is_free = true;\n    InsertToFreeList(piece3);\n  }\n  return piece2;\n}\n\nsize_t RematEpAllocator::iterate_group_index(bool high) const {\n  if (normal_group_num_ == 1) { return 0; }\n  auto is_high_group = [](size_t idx) -> bool { return (idx / 2) % 2 == (idx % 2); };\n  if (high) {\n    size_t index;  // NOLINT\n    do {\n      cur_group_index_id_high_cost_ = (cur_group_index_id_high_cost_ + 1) % normal_group_num_;\n      index = group_indexes_[cur_group_index_id_high_cost_];\n    } while (!is_high_group(index));\n    return index;\n  } else {\n    size_t index;  // NOLINT\n    do {\n      cur_group_index_id_ = (cur_group_index_id_ + 1) % normal_group_num_;\n      index = group_indexes_[cur_group_index_id_];\n    } while (is_high_group(index));\n    
return index;\n  }\n}\n\nsize_t RematEpAllocator::group_index(bool high) const {\n  if (high) {\n    return group_indexes_[cur_group_index_id_high_cost_];\n  } else {\n    return group_indexes_[cur_group_index_id_];\n  }\n}\n\nvoid RematEpAllocator::InitMemory() {\n  memory_size_ = Singleton<remat::Env>::Get()->budget_in_bytes();\n  CHECK_JUST(backend_->Allocate(&memory_, memory_size_));\n  LOG(INFO) << \"memory_: \" << (void*)memory_ << \", size: \" << memory_size_;\n  const size_t small_piece_area_size =\n      Singleton<remat::Env>::Get()->is_small_pieces_optimization_enabled()\n          ? 1024 * kSmallPieceThreshold\n          : 0;\n  const size_t normal_area_size = memory_size_ - small_piece_area_size;\n  small_piece_area_ptr_ = memory_ + normal_area_size;\n\n  if (enable_left_and_right_) { CHECK_EQ(normal_group_num_ % 2, 0); }\n  const size_t effective_normal_group_num =\n      enable_left_and_right_ ? normal_group_num_ / 2 : normal_group_num_;\n  const std::vector<offset_t> boundary_tmp = [&]() {\n    const size_t mem_per_group = normal_area_size / effective_normal_group_num;\n    std::vector<offset_t> boundary_tmp;\n    for (size_t i = 0, b = 0; i < effective_normal_group_num; i++, b += mem_per_group) {\n      boundary_tmp.push_back(b);\n    }\n    boundary_tmp.push_back(normal_area_size);\n    return boundary_tmp;\n  }();\n  for (size_t i = 0; i < effective_normal_group_num; i++) {\n    group_boundaries_.emplace_back(boundary_tmp[i], boundary_tmp[i + 1]);\n    if (enable_left_and_right_) {\n      group_boundaries_.emplace_back(boundary_tmp[i], boundary_tmp[i + 1]);\n    }\n  }\n  if (normal_area_size != memory_size_) {\n    group_boundaries_.emplace_back(normal_area_size, memory_size_);\n  }\n\n  Piece* piece = AllocatePiece();\n  piece->size = memory_size_;\n  piece->ptr = memory_;\n  piece->prev = nullptr;\n  piece->next = nullptr;\n  piece->is_free = true;\n  piece->tensor = nullptr;\n  InsertToFreeList(piece);\n  InsertPiece2PtrMap(piece);\n}\n\nMaybe<RematEpAllocator::Piece*> RematEpAllocator::FindPiece(size_t aligned_size,\n                                                            bool after_eviction) {\n  CHECK_OR_RETURN(IsAlignedSize(aligned_size));\n\n  if (memory_ == nullptr) { InitMemory(); }\n\n  // NOLINTNEXTLINE\n  const bool is_high_op = [&]() {\n    std::vector<std::string> high_compute_cost_names{\"conv2d\", \"conv_data_grad\", \"conv_filter_grad\",\n                                                     \"add_n\",  \"matmul\",         \"batch_matmul\"};\n    const auto current_op_type_name =\n        CHECK_JUST(ThreadLocalGuard<remat::CurrentOpTypeName>::Current())->value;\n    PNT3(current_op_type_name);\n    if (std::find(high_compute_cost_names.cbegin(), high_compute_cost_names.cend(),\n                  current_op_type_name)\n        != high_compute_cost_names.cend()) {\n      return true;\n    }\n    return false;\n  }();\n\n  size_t group_idx = [&]() -> size_t {\n    if (ShouldBeHeldBySmallPiece(aligned_size)) { return normal_group_num_; }\n    // if (after_eviction) { return true; }\n    return group_index(is_high_op);\n  }();\n  PNT3(aligned_size);\n  size_t iterate_num = 0;\n  do {\n    const auto& free_pieces = free_pieces_overlapping_with_group_[group_idx];\n    PNT3(group_idx);\n    PNT3(free_pieces.size());\n    for (auto it = free_pieces.begin(); it != free_pieces.end(); ++it) {\n      Piece* piece = *it;\n      CHECK_OR_RETURN(piece->is_free);\n      CHECK_NOTNULL(piece->ptr);\n      CHECK_OR_RETURN(IsAlignedSize(piece->size));\n      
PNT3(get_offset(piece->ptr));\n      PNT3(piece->size);\n      if (piece->size >= aligned_size) {\n        const offset_t offset_in_memory = FindProperPositionInGroup(piece, group_idx, aligned_size);\n        PNT3(offset_in_memory);\n        if (offset_in_memory != SIZE_MAX) {\n          const offset_t offset_in_piece = offset_in_memory - get_offset(piece->ptr);\n          auto ret = AllocateMemoryInPiece(piece, offset_in_piece, aligned_size);\n          CheckPieces();\n          return ret;\n        }\n      }\n    }\n    // Update group_idx only if allocation in this group fails,\n    // so that multiple outputs of a single op are placed in the same group.\n    group_idx = iterate_group_index(is_high_op);\n    iterate_num++;\n  } while (!ShouldBeHeldBySmallPiece(aligned_size) && iterate_num < normal_group_num_);\n\n  return nullptr;\n}\n\nvoid RematEpAllocator::MergeNeighbourFreePiece(Piece* lhs, Piece* rhs) {\n  CHECK(lhs->is_free);\n  CHECK(rhs->is_free);\n  CHECK(lhs->next == rhs);\n  CHECK(lhs == rhs->prev);\n  CHECK(lhs->ptr + lhs->size == rhs->ptr);\n\n  lhs->size += rhs->size;\n  lhs->next = rhs->next;\n  if (rhs->next != nullptr) { rhs->next->prev = lhs; }\n  ErasePieceFromPtrMap(rhs);\n  DeallocatePiece(rhs);\n}\n\nMaybe<RematEpAllocator::Piece*> RematEpAllocator::EvictAndFindPieceLoop(size_t required_size,\n                                                                        bool consider_neighbor) {\n  VLOG(2) << \"required size: \" << required_size;\n  auto GetSizeIncludingNeighborhood = [](auto it, auto begin, auto end) -> size_t {\n    size_t size = it->second->size;\n    if (it != begin) {\n      for (auto t = std::prev(it); t->second->tensor == nullptr; t--) {\n        size += t->second->size;\n        if (t == begin) { break; }\n      }\n    }\n    if (it != end) {\n      for (auto t = std::next(it); t != end && t->second->tensor == nullptr; t++) {\n        size += t->second->size;\n      }\n    }\n    return size;\n  };\n\n  while (true) {\n    double min_cost = std::numeric_limits<double>::max();\n    vm::RematableTensorStorage* min_tensor = nullptr;\n    for (auto it = ptr2piece_.begin();\n         it != ptr2piece_.end() && !JUST(InSmallMemoryArea(it->second->ptr)); it++) {\n      auto* tensor = it->second->tensor;\n      if (tensor != nullptr && !tensor->is_pinned() && tensor->is_evictable()) {\n        auto cur_op_cost =\n            consider_neighbor ? 
get_cost(\n                tensor, GetSizeIncludingNeighborhood(it, ptr2piece_.begin(), ptr2piece_.end()))\n                              : get_cost(tensor);\n        if (cur_op_cost < min_cost) {\n          min_cost = cur_op_cost;\n          min_tensor = tensor;\n        }\n      }\n    }\n    if (min_tensor) {\n      min_tensor->Evict(false);\n      Piece* piece = JUST(FindPiece(required_size, true));\n      if (piece != nullptr) { return piece; }\n    } else {\n      return Error::RuntimeError() << \"Cannot find a piece to evict\";\n    }\n  }\n}\n\nMaybe<RematEpAllocator::Piece*> RematEpAllocator::EvictAndFindPieceOnce(size_t required_size) {\n  VLOG(2) << \"required size: \" << required_size;\n  auto start = ptr2piece_.begin();\n  auto end = ptr2piece_.begin();\n  size_t total_size = 0;\n  double cost_except_size = 0;\n  double min_cost = std::numeric_limits<double>::max();\n  auto min_start = start;\n  auto min_end = start;\n  std::vector<double> costs;\n  costs.reserve(ptr2piece_.size());\n  size_t start_i = 0;\n  size_t end_i = 0;\n  while (end != ptr2piece_.end() && !JUST(InSmallMemoryArea(end->second->ptr))) {\n    if (total_size < required_size) {\n      auto* end_tensor = end->second->tensor;\n      if (end_tensor != nullptr && (end_tensor->is_pinned() || !end_tensor->is_evictable())) {\n        VLOG(2) << \"skip tensor: \" << end_tensor << \", size: \" << end_tensor->blob_bytes()\n                << \", compute op \" << end_tensor->compute_op_type_name()\n                << \", num_pinned: \" << end_tensor->num_pinned()\n                << \", is_evictable: \" << end_tensor->is_evictable();\n        end++;\n        costs.push_back(0);\n        end_i++;\n        start = end;\n        start_i = end_i;\n        total_size = 0;\n        cost_except_size = 0;\n        continue;\n      }\n      total_size += end->second->size;\n      auto cur_op_cost = get_cost(end_tensor);\n      costs.push_back(cur_op_cost);\n      cost_except_size += cur_op_cost;\n      VLOG(2) << \"move end, include op: \"\n              << (end_tensor != nullptr ? end_tensor->compute_op_type_name() : \"no tensor\")\n              << \", size: \" << end->second->size << \", total_size: \" << total_size\n              << \", total cost: \" << cost_except_size << \", cur op cost: \" << cur_op_cost;\n      end++;\n      end_i++;\n    } else {\n      auto* start_tensor = start->second->tensor;\n      total_size -= start->second->size;\n      // start_tensor is back in the pool, update_after_pseudo_compute\n      double cur_op_cost = costs[start_i];\n      cost_except_size -= cur_op_cost;\n      VLOG(2) << \"move start, exclude op: \"\n              << (start_tensor != nullptr ? 
start_tensor->compute_op_type_name() : \"no tensor\")\n              << \", size: \" << start->second->size << \", total_size: \" << total_size\n              << \", total cost: \" << cost_except_size << \", cur op cost: \" << cur_op_cost;\n      start++;\n      start_i++;\n    }\n    double cost = cost_except_size;\n    if (total_size >= required_size && cost < min_cost) {\n      min_cost = cost;\n      min_start = start;\n      min_end = end;\n      VLOG(2) << \"record, min_cost: \" << min_cost;\n    }\n  }\n  // CHECK(min_end != start);\n  // Collect piece pointers into a new container first, because Evict() will invalidate\n  // the iterators.\n  std::vector<Piece*> pieces_to_be_evicted;\n  for (auto it = min_start; it != min_end; ++it) {\n    Piece* piece = it->second;\n    pieces_to_be_evicted.push_back(piece);\n  }\n  if (IsInDebugMode()) {\n    for (auto* piece : pieces_to_be_evicted) {\n      LOG(INFO) << \"release dptr: \" << get_offset(piece->ptr) << \", size: \" << piece->size\n                << \", cost: \" << get_cost(piece->tensor) << \", compute op: \"\n                << (piece->tensor != nullptr ? piece->tensor->compute_op_type_name() : \"no\")\n                << \", id: \"\n                << (piece->tensor != nullptr ? std::to_string(piece->tensor->id()) : \"no\");\n    }\n  }\n  size_t evict_size = 0;\n  for (auto* piece : pieces_to_be_evicted) {\n    evict_size += piece->size;\n    // NOTE: Evict() triggers the merge and deallocation of neighbouring free pieces.\n    // E.g., given two contiguous pieces (relu, no_tensor), after relu is evicted,\n    // no_tensor will be deallocated. Currently deallocation only sets the tensor to\n    // nullptr instead of really freeing it, so no bug occurs. It is tricky and fragile.\n    if (piece->tensor != nullptr) {\n      CHECK_OR_RETURN(!ShouldBeHeldBySmallPiece(piece->size));\n      piece->tensor->Evict(false);\n    }\n  }\n  VLOG(2) << \"evict size: \" << evict_size;\n\n  if (!pieces_to_be_evicted.empty()) { return CHECK_NOTNULL(JUST(FindPiece(required_size, true))); }\n  return nullptr;\n}\n\nMaybe<void> RematEpAllocator::Allocate(char** mem_ptr, std::size_t size) {\n  if (size == 0) {\n    *mem_ptr = nullptr;\n    return Maybe<void>::Ok();\n  }\n  ReentrantThreadSafeLock::RAIIGuard guard(thread_lock_);\n  size_t aligned_size = CudaMemAlignedBytes(size);\n\n  Piece* piece = JUST(FindPiece(aligned_size, false));\n\n  if (piece == nullptr) {\n    if (first_time) {\n      if (EnvBool<ONEFLOW_REMAT_DISPLAY_IN_FIRST_TIME>()) { DisplayAllPieces(); }\n      first_time = false;\n    }\n    const auto started_at = profiler::GetTimeNow();\n    const size_t evict_num1 = Singleton<remat::Env>::Get()->forced_eviction_num();\n    if (EnvBool<ONEFLOW_REMAT_HEURISTIC_DTE>()) {\n      piece = JUST(EvictAndFindPieceLoop(aligned_size, true));\n    } else if (EnvBool<ONEFLOW_REMAT_HEURISTIC_DTR>()) {\n      piece = JUST(EvictAndFindPieceLoop(aligned_size, false));\n    } else {\n      piece = JUST(EvictAndFindPieceOnce(aligned_size));\n    }\n    const size_t evict_num2 = Singleton<remat::Env>::Get()->forced_eviction_num();\n    const auto duration = profiler::GetTimeNow() - started_at;\n    search_free_mem_cost_.emplace_back(size, evict_num2 - evict_num1, duration);\n    if (EnvBool<ONEFLOW_REMAT_RECORD_MEM_FRAG_RATE>()) {\n      size_t free_mem = 0;\n      for (const auto& pair : ptr2piece_) {\n        Piece* piece = pair.second;\n        if (piece->is_free) {\n          CHECK_ISNULL_OR_RETURN(piece->tensor);\n          free_mem += piece->size;\n        }\n      }\n      
remat::append_memory_frag_info_and_get(free_mem, memory_size_);\n    }\n  }\n\n  if (piece == nullptr) { DisplayAllPieces(); }\n\n  CHECK_OR_RETURN(piece != nullptr) << \"Error: out of memory when allocating size \" << size;\n  CHECK_NOTNULL(piece->ptr);\n  CHECK_OR_RETURN(ptr2piece_.find(piece->ptr) != ptr2piece_.end());\n  LOG(INFO) << \"allocate offset: \" << get_offset(piece->ptr) << \", size: \" << piece->size\n            << std::endl;\n  *mem_ptr = piece->ptr;\n  total_allocate_bytes_ += size;\n  piece->is_free = false;\n\n  return Maybe<void>::Ok();\n}\n\nvoid RematEpAllocator::Deallocate(char* mem_ptr, std::size_t size) {\n  if (mem_ptr == nullptr) { return; }\n  ReentrantThreadSafeLock::RAIIGuard guard(thread_lock_);\n\n  auto it = ptr2piece_.find(mem_ptr);\n  CHECK(it != ptr2piece_.end()) << \"Error: tried to deallocate a non-existent mem_ptr. mem_ptr = \"\n                                << mem_ptr << \", size = \" << size;\n  Piece* piece = it->second;\n  CHECK_NOTNULL(piece);\n  CHECK_EQ(piece->ptr, mem_ptr);\n  CHECK(!piece->is_free);\n\n  if (auto* tensor = piece->tensor) {\n    CHECK_JUST(remat::DisjointSet::update_after_release(tensor));\n  }\n\n  piece->is_free = true;\n  piece->tensor = nullptr;\n  piece->is_left = true;\n\n  Piece* last_piece_insert_to_free_list = piece;\n  Piece* next_p = piece->next;\n  Piece* prev_p = piece->prev;\n\n  VLOG(2) << \"deallocate offset: \" << get_offset(piece->ptr) << \", size: \" << piece->size\n          << \", prev: \" << prev_p << \", next: \" << next_p;\n\n  if (next_p != nullptr && next_p->is_free) {\n    CHECK_EQ(next_p->ptr, piece->ptr + piece->size);\n    EraseFromFreeList(next_p);\n    VLOG(2) << \"merge with next_p\";\n    MergeNeighbourFreePiece(piece, next_p);\n  }\n\n  if (prev_p != nullptr && prev_p->is_free) {\n    CHECK_EQ(piece->ptr, prev_p->ptr + prev_p->size);\n    EraseFromFreeList(prev_p);\n    VLOG(2) << \"merge with prev_p\";\n    MergeNeighbourFreePiece(prev_p, piece);\n    last_piece_insert_to_free_list = prev_p;\n  }\n  InsertToFreeList(last_piece_insert_to_free_list);\n  total_deallocate_bytes_ += size;\n  CheckPieces();\n}\n\nsize_t RematEpAllocator::allocated_memory() {\n  CHECK_GE(total_allocate_bytes_, total_deallocate_bytes_);\n  return total_allocate_bytes_ - total_deallocate_bytes_;\n}\n\nvoid RematEpAllocator::DeviceReset() {\n  ReentrantThreadSafeLock::RAIIGuard guard(thread_lock_);\n  backend_->DeviceReset();\n}\n\nnlohmann::json RematEpAllocator::DumpSearchFreeMemCost() {\n  return {{\"overhead\", search_free_mem_cost_}};\n}\n\n}  // namespace vm\n\nvm::RematEpAllocator* remat::AllocatorManager::CreateOrGetAllocator(DeviceType device_type,\n                                                                    size_t device_index) {\n  auto key = std::make_pair(device_type, device_index);\n  auto it = allocators_.find(key);\n  if (it == allocators_.end()) {\n    auto ep_device =\n        Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(device_type, device_index);\n    auto ep_backend_allocator =\n        std::make_unique<vm::EpBackendAllocator>(ep_device, ep::AllocationOptions{});\n    auto allocator = std::make_unique<vm::RematEpAllocator>(ep::kMaxAlignmentRequirement,\n                                                            std::move(ep_backend_allocator));\n    allocators_.emplace(key, std::move(allocator));\n    return allocators_.at(key).get();\n  } else {\n    return it->second.get();\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/remat/allocator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_DTR_EP_ALLOCATOR_H_\n#define ONEFLOW_CORE_VM_DTR_EP_ALLOCATOR_H_\n\n#include <cstdint>\n#include \"oneflow/core/common/env_var/remat.h\"\n#include \"oneflow/core/ep/include/device.h\"\n#include \"oneflow/core/vm/allocator.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"nlohmann/json.hpp\"\n#include \"oneflow/core/vm/thread_safe_guard.h\"\n\nnamespace oneflow {\n\nnamespace vm {\n\nclass EagerBlobObject;\nclass RematableTensorStorage;\n\nclass RematEpAllocator final : public Allocator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RematEpAllocator);\n  explicit RematEpAllocator(size_t alignment, std::unique_ptr<Allocator>&& backend);\n  ~RematEpAllocator() override;\n  void DeviceReset() override;\n\n  Maybe<void> Allocate(char** mem_ptr, std::size_t size) override;\n  void Deallocate(char* mem_ptr, std::size_t size) override;\n  void LinkStorageAndPtr(RematableTensorStorage* storage, const char* mem_ptr);\n  void CheckPieces();\n  void DisplayAllPieces();\n  nlohmann::json DumpSearchFreeMemCost();\n  size_t allocated_memory();\n  void set_left(bool is_left) { left = is_left; }\n  bool left = true;\n\n  size_t iterate_group_index(bool high) const;\n\n  bool first_time = true;\n\n private:\n  const size_t alignment_;\n  const std::unique_ptr<Allocator> backend_;\n  ReentrantThreadSafeLock thread_lock_;\n\n  using offset_t = size_t;\n\n  offset_t get_offset(const char* mem_ptr) const;\n\n  // Piece is the basic memory unit of CudaAllocator.\n  // A Piece is either is free(is_free = true) or in used(is_free = false).\n  // Pieces are stored in a linked list. 
A Piece's prev and next are\n  // contiguous with the current Piece in physical memory.\n  struct Piece {\n    size_t size = 0;\n    char* ptr = nullptr;\n    bool is_free = true;\n    Piece* prev = nullptr;\n    Piece* next = nullptr;\n    vm::RematableTensorStorage* tensor = nullptr;\n    bool is_left = true;\n  };\n\n  Maybe<bool> InSmallMemoryArea(void* ptr);\n\n  offset_t FindProperPositionInGroup(Piece* piece, size_t group_idx, size_t request_size) const;\n\n  Piece* AllocateMemoryInPiece(Piece* piece, offset_t offset_in_piece, size_t size);\n\n  void InsertToFreeList(Piece* piece);\n  void EraseFromFreeList(Piece* piece);\n\n  void InitMemory();\n\n  // Try to find a free Piece whose size is at least aligned_size.\n  // Returns nullptr on failure.\n  Maybe<Piece*> FindPiece(size_t aligned_size, bool after_eviction);\n  void Display();\n\n  // Create a new empty Piece or recycle one from recycle_piece_list_\n  Piece* AllocatePiece();\n  // Retire a Piece by moving it into the linked list recycle_piece_list_\n  void DeallocatePiece(Piece* piece);\n\n  // Insert a {piece->ptr, piece} pair into the ptr2piece_ map so that the Piece can be\n  // looked up in Deallocate()\n  void InsertPiece2PtrMap(Piece* piece);\n  // Erase the {piece->ptr, piece} pair from ptr2piece_ once the ptr is no longer valid.\n  // Usually called before DeallocatePiece()\n  void ErasePieceFromPtrMap(Piece* piece);\n\n  void MergeNeighbourFreePiece(Piece* lhs, Piece* rhs);\n\n  Maybe<Piece*> EvictAndFindPieceOnce(size_t required_size);\n  Maybe<Piece*> EvictAndFindPieceLoop(size_t required_size, bool consider_neighbor);\n\n  char* memory_ = nullptr;\n  size_t memory_size_;\n  void* small_piece_area_ptr_ = nullptr;\n\n  // holds the lifetime of all Pieces\n  std::vector<std::unique_ptr<Piece>> pieces_;\n  struct PieceCmp {\n    bool operator()(const Piece* lhs, const Piece* rhs) const {\n      if (lhs->size != rhs->size) { return lhs->size < rhs->size; }\n      // comparing pointers with raw < or > is undefined behavior\n      return std::less<>{}(lhs->ptr, rhs->ptr);\n    }\n  };\n  std::vector<std::set<Piece*, PieceCmp>> free_pieces_overlapping_with_group_;\n  // std::map is sorted by key, so we can find contiguous memory by it\n  std::map<const char*, Piece*> ptr2piece_;\n  std::vector<std::tuple<size_t, int, int64_t>> search_free_mem_cost_;\n  Piece* recycle_piece_list_;\n  size_t total_allocate_bytes_ = 0;\n  size_t total_deallocate_bytes_ = 0;\n\n  // -----\n  size_t normal_group_num_;\n  std::vector<size_t> group_indexes_;\n  mutable size_t cur_group_index_id_;\n  mutable size_t cur_group_index_id_high_cost_;\n  bool enable_left_and_right_;\n  std::vector<std::pair<offset_t, offset_t>> group_boundaries_;\n\n  size_t group_index(bool high) const;\n};\n\nclass DtrEpAllocatorProxy final : public Allocator {\n public:\n  explicit DtrEpAllocatorProxy(vm::RematEpAllocator* allocator) : allocator(allocator) {}\n  void DeviceReset() override { allocator->DeviceReset(); }\n\n  Maybe<void> Allocate(char** mem_ptr, std::size_t size) override {\n    return allocator->Allocate(mem_ptr, size);\n  }\n  void Deallocate(char* mem_ptr, std::size_t size) override {\n    allocator->Deallocate(mem_ptr, size);\n  }\n  vm::RematEpAllocator* const allocator;\n};\n\n}  // namespace vm\n\nnamespace remat {\nclass AllocatorManager {\n public:\n  vm::RematEpAllocator* CreateOrGetAllocator(DeviceType device_type, size_t device_index);\n\n private:\n  std::unordered_map<std::pair<DeviceType, size_t>, std::unique_ptr<vm::RematEpAllocator>>\n      allocators_;\n};\n\n}  
// namespace remat\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_DTR_EP_ALLOCATOR_H_\n"
  },
  {
    "path": "oneflow/core/vm/remat/disjoint_set.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/vm/remat/disjoint_set.h\"\n\n#include \"oneflow/core/vm/op_call_instruction_policy.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/vm/remat/allocator.h\"\n\nnamespace oneflow {\n\nnamespace remat {\n\nvoid DisjointSet::merge(std::shared_ptr<DisjNode>& x, std::shared_ptr<DisjNode>& y) {\n  auto parent_x = find_father(x);\n  auto parent_y = find_father(y);\n  if (parent_x.get() == parent_y.get()) { return; }\n\n  parent_y->set_compute_time(parent_y->compute_time() + parent_x->compute_time());\n  parent_x->set_parent(parent_y);\n}\n\nstd::shared_ptr<DisjNode> DisjointSet::find_father(std::shared_ptr<DisjNode>& x) {\n  if (x->is_root()) {\n    return x;\n  } else {\n    auto fa = x->parent();\n    auto y = find_father(fa);\n    x->set_parent(y);\n    return y;\n  }\n}\n\nvoid DisjointSet::update_after_compute(vm::RematableTensorStorage* obj) {\n  auto fa = find_father(obj->node);\n  fa->set_compute_time(fa->compute_time() - obj->node->compute_time());\n  obj->node->reset(obj->compute_time());\n}\n\nMaybe<void> DisjointSet::update_after_release(vm::RematableTensorStorage* obj) {\n  CHECK_NOTNULL_OR_RETURN(obj);\n  if (obj->is_eviction_disabled()) { return Maybe<void>::Ok(); }\n\n  const auto merge_nodes = [&obj](const auto& eager_blob_objects) {\n    for (int i = 0; i < eager_blob_objects.size(); ++i) {\n      if (auto storage = std::dynamic_pointer_cast<vm::RematableTensorStorage>(\n              eager_blob_objects[i]->tensor_storage());\n          storage && !storage->is_in_memory()) {\n        merge(storage->node, obj->node);\n      }\n    }\n  };\n\n  auto operand = obj->compute_op();\n  merge_nodes(operand.inputs());\n  merge_nodes(operand.outputs());\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace remat\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/remat/disjoint_set.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#pragma once\n\n#include <memory>\n\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nnamespace vm {\nclass RematableTensorStorage;\n}\n\nnamespace remat {\n\nclass DisjNode {\n public:\n  explicit DisjNode(double time) : compute_time_(time), parent_(nullptr), cnt_(1) {}\n\n  bool is_root() { return !bool(parent_); }\n\n  void set_parent(std::shared_ptr<DisjNode>& parent) { parent_ = parent; }\n  void set_compute_time(double new_time) { compute_time_ = new_time; }\n\n  void set_cnt(int cnt) { cnt_ = cnt; }\n  void add_cnt() { cnt_++; }\n  void reduce_cnt() { cnt_--; }\n\n  double compute_time() { return compute_time_; }\n  std::shared_ptr<DisjNode> parent() { return parent_; }\n  int cnt() { return cnt_; }\n\n  void reset(double t) {\n    compute_time_ = t;\n    parent_.reset();\n  }\n\n private:\n  double compute_time_;\n  std::shared_ptr<DisjNode> parent_;\n  int cnt_;\n};\n\nclass DisjointSet {\n public:\n  static void merge(std::shared_ptr<DisjNode>& x, std::shared_ptr<DisjNode>& y);\n  static std::shared_ptr<DisjNode> find_father(std::shared_ptr<DisjNode>& x);\n  static void update_after_compute(vm::RematableTensorStorage* obj);\n  static Maybe<void> update_after_release(vm::RematableTensorStorage* obj);\n};\n\n}  // namespace remat\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/remat/env.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/remat/env.h\"\n\n#include \"nlohmann/json.hpp\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/vm/op_call_instruction_policy.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\n\nnamespace remat {\n\nvm::OpCallInstructionPolicy Env::update_tensor_with_storage(\n    vm::RematableTensorStorage* storage, const vm::OpCallInstructionPolicy& current_compute_op) {\n  // TODO: set disjnode properly\n  auto new_storage = std::make_shared<vm::RematableTensorStorage>(storage->device());\n  std::unordered_map<vm::EagerBlobObject*, std::shared_ptr<vm::EagerBlobObject>> old2new;\n  auto update = [&new_storage, &old2new](std::shared_ptr<vm::EagerBlobObject>& old) {\n    auto it = old2new.find(old.get());\n    if (it != old2new.end()) {\n      old = it->second;\n    } else {\n      auto local_tensor_meta = old->tensor_meta();\n      const auto& eager_blob_object = std::make_shared<vm::EagerBlobObject>(\n          std::make_shared<MemoryCase>(old->mem_case()), local_tensor_meta, old->mut_tensor_meta(),\n          local_tensor_meta->dtype(), local_tensor_meta->memory_format(), new_storage);\n      eager_blob_object->set_storage_offset(old->storage_offset());\n      old2new.emplace(old.get(), eager_blob_object);\n      old = eager_blob_object;\n    }\n  };\n  auto update_output = [&old2new, &new_storage](std::weak_ptr<vm::EagerBlobObject>& old) {\n    auto it = old2new.find(CHECK_NOTNULL(old.lock()).get());\n    if (it != old2new.end()) {\n      old = it->second;\n    } else {\n      auto old_locked = old.lock();\n      auto local_tensor_meta = old_locked->tensor_meta();\n      const auto& eager_blob_object = std::make_shared<vm::EagerBlobObject>(\n          std::make_shared<MemoryCase>(old_locked->mem_case()), local_tensor_meta,\n          old_locked->mut_tensor_meta(), local_tensor_meta->dtype(),\n          local_tensor_meta->memory_format(), new_storage);\n      eager_blob_object->set_storage_offset(old_locked->storage_offset());\n      old2new.emplace(old_locked.get(), eager_blob_object);\n      old = eager_blob_object;\n    }\n  };\n  for (int i = ops.size() - 1; i >= 0; i--) {\n    auto& op = ops[i];\n    for (int j = 0; j < op->mut_inputs().size(); j++) {\n      auto& x = op->mut_inputs()[j];\n      if (x == nullptr) {\n        LOG(INFO) << \"No.\" << j << \" input of \" << op->opkernel().op_type_name() << \" is nullptr\"\n                  << std::endl;\n        continue;\n      }\n      if (x->tensor_storage().get() == storage) {\n        vm::EagerBlobObject* old_ptr = x.get();\n        update(x);\n        VLOG(1) << \"update input of \" << op->opkernel().op_type_name() << \" from \" << old_ptr\n                << \" (storage \" << storage << \") to \" << x.get() << \" (storage \"\n                << new_storage.get() << \"), op addr \" << op << std::endl;\n      }\n    }\n    for (int j = 0; j < 
\n  for (int i = ops.size() - 1; i >= 0; i--) {\n    auto& op = ops[i];\n    for (int j = 0; j < op->mut_inputs().size(); j++) {\n      auto& x = op->mut_inputs()[j];\n      if (x == nullptr) {\n        LOG(INFO) << \"No.\" << j << \" input of \" << op->opkernel().op_type_name() << \" is nullptr\"\n                  << std::endl;\n        continue;\n      }\n      if (x->tensor_storage().get() == storage) {\n        vm::EagerBlobObject* old_ptr = x.get();\n        update(x);\n        VLOG(1) << \"update input of \" << op->opkernel().op_type_name() << \" from \" << old_ptr\n                << \" (storage \" << storage << \") to \" << x.get() << \" (storage \"\n                << new_storage.get() << \"), op addr \" << op << std::endl;\n      }\n    }\n    for (int j = 0; j < op->mut_outputs().size(); j++) {\n      auto& y = op->mut_outputs()[j];\n      if (y.lock() == nullptr) {\n        LOG(INFO) << \"No.\" << j << \" output of \" << op->opkernel().op_type_name() << \" is nullptr\"\n                  << std::endl;\n        continue;\n      }\n      if (CHECK_NOTNULL(y.lock())->tensor_storage().get() == storage) {\n        vm::EagerBlobObject* old_ptr = y.lock().get();\n        update_output(y);\n        VLOG(1) << \"update output of \" << op->opkernel().op_type_name() << \" from \" << old_ptr\n                << \" (storage \" << storage << \") to \" << y.lock().get() << \" (storage \"\n                << new_storage.get() << \"), op addr \" << op << std::endl;\n      }\n    }\n  }\n  vm::OpCallInstructionPolicy new_compute_op = current_compute_op;\n  // only update inputs\n  for (auto& x : new_compute_op.mut_inputs()) {\n    if (x->tensor_storage().get() == storage) {\n      vm::EagerBlobObject* old_ptr = x.get();\n      update(x);\n      VLOG(1) << \"update input of \" << new_compute_op.opkernel().op_type_name() << \" from \"\n              << old_ptr << \" to \" << x.get() << std::endl;\n    }\n  }\n  VLOG(1) << \"update_tensor_with_storage: storage \" << storage->id();\n  // set compute_op_ and compute_time_\n  new_storage->set_compute_op(storage->dtr_compute_op(), storage->compute_time());\n  // set blob_bytes_\n  new_storage->set_blob_dptr(nullptr, storage->blob_bytes());\n  // set is_initialized_\n  new_storage->set_initialized();\n  // set last_access_time_\n  new_storage->Access();\n  storage->clear_compute_op();\n  return new_compute_op;\n}\n\nvoid Env::add_eviction_num(bool eager_eviction) {\n  if (eager_eviction) {\n    eager_eviction_num_++;\n  } else {\n    forced_eviction_num_++;\n  }\n}\n\nEnv::~Env() {\n  LOG(INFO) << \"forced eviction num: \" << forced_eviction_num_;\n  LOG(INFO) << \"eager eviction num: \" << eager_eviction_num_;\n  LOG(INFO) << \"recomputation num: \" << recomputation_num_;\n  LOG(INFO) << \"duration: \" << time_now_;\n\n  const char* prefix = std::getenv(\"ONEFLOW_REMAT_SUMMARY_FILE_PREFIX\");\n  if (prefix != nullptr && GlobalProcessCtx::LocalRank() == 0) {\n    using json = nlohmann::json;\n    json cpp_summary{{\"forced eviction\", forced_eviction_num_},\n                     {\"eager eviction\", eager_eviction_num_},\n                     {\"recomputation\", recomputation_num_},\n                     {\"dataset time\", time_now_}};\n\n    json full_json;\n    // std::fstream has strange default append semantics\n    {\n      std::ifstream fs(std::string(prefix) + \".json\");\n      if (fs.is_open()) { fs >> full_json; }\n    }\n    full_json.merge_patch(cpp_summary);\n    {\n      std::ofstream fs(std::string(prefix) + \".json\");\n      fs << full_json;\n    }\n  }\n}\n\n}  // namespace remat\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/remat/env.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#pragma once\n\n#include \"oneflow/core/common/env_var/remat.h\"\n#include \"oneflow/core/common/util.h\"\n\n#define VLOG_REMAT(verbose_level) \\\n  if (Singleton<remat::Env>::Get()->log_enabled()) VLOG(verbose_level)\n\nnamespace oneflow {\n\nnamespace vm {\nclass RematableTensorStorage;\nclass OpCallInstructionPolicy;\nclass DtrOpCallInstructionPolicy;\n}  // namespace vm\n\nnamespace remat {\n\nclass Env {\n public:\n  Env() = default;\n  ~Env();\n  OF_DISALLOW_COPY_AND_MOVE(Env);\n  double time_now() { return time_now_; }\n  void add_time(double time) { time_now_ += time; }\n  void remove_compute_op(vm::DtrOpCallInstructionPolicy* op) {\n    ops.erase(std::remove(ops.begin(), ops.end(), op), ops.end());\n  }\n  vm::OpCallInstructionPolicy update_tensor_with_storage(\n      vm::RematableTensorStorage* storage, const vm::OpCallInstructionPolicy& current_compute_op);\n\n  std::vector<vm::DtrOpCallInstructionPolicy*> ops;\n\n  void add_eviction_num(bool eager_eviction);\n\n  int eager_eviction_num() const { return eager_eviction_num_; }\n  int forced_eviction_num() const { return forced_eviction_num_; }\n\n  void add_recomputation_num() { recomputation_num_++; }\n  int recomputation_num() const { return recomputation_num_; }\n\n  void clear_stats() {\n    time_now_ = 0;\n    eager_eviction_num_ = 0;\n    forced_eviction_num_ = 0;\n    recomputation_num_ = 0;\n  }\n\n  std::set<vm::RematableTensorStorage*> need_eager_eviction_storages;\n\n  void set_budget_in_bytes(size_t budget_in_bytes) { budget_in_bytes_ = budget_in_bytes; }\n  size_t budget_in_bytes() const { return budget_in_bytes_; }\n\n  void set_small_pieces_optimization(bool enabled) { small_pieces_optimization_ = enabled; }\n  bool is_small_pieces_optimization_enabled() const { return small_pieces_optimization_; }\n\n  bool log_enabled() const { return EnvBool<ONEFLOW_REMAT_LOG>(); }\n\n private:\n  double time_now_ = 0;\n\n  int eager_eviction_num_ = 0;\n  int forced_eviction_num_ = 0;\n  int recomputation_num_ = 0;\n\n  size_t budget_in_bytes_ = 0;\n  bool small_pieces_optimization_ = true;\n};\n\nstruct CurrentOpTypeName {\n  std::string value;\n};\n\n}  // namespace remat\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/remat/util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/vm/remat/util.h\"\n\n#include <algorithm>\n\n#include \"nlohmann/json.hpp\"\n#include \"oneflow/core/common/env_var/remat.h\"\n#include \"oneflow/core/common/env_var/vm.h\"\n#include \"oneflow/core/eager/tensor_storage.h\"\n#include \"oneflow/core/framework/compute_complexity_fn_context.h\"\n#include \"oneflow/core/vm/op_call_instruction_policy.h\"\n#include \"oneflow/core/vm/remat/env.h\"\n#include \"oneflow/core/vm/remat/disjoint_set.h\"\n#include \"oneflow/user/kernels/stateful_opkernel.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n\nnamespace oneflow {\n\nnamespace remat {\n\ndouble append_memory_frag_info_and_get(size_t free_mem, size_t threshold) {\n  static size_t num = 0;\n  // maintain a summation of memory frag rate\n  static double memory_frag_rate_sum = 0;\n  if (threshold > 0) {\n    memory_frag_rate_sum += (1. * free_mem / threshold);\n    num++;\n  }\n  return memory_frag_rate_sum / num;\n}\n\nnamespace {\n\nstd::string SortKey(const std::string& key) {\n  const auto shape_finish_at = key.rfind(\")\");\n  if (shape_finish_at == std::string::npos || shape_finish_at + 2 == key.size()) { return key; }\n  const auto name_and_shape = key.substr(0, shape_finish_at + 1);\n  auto attrs = key.substr(shape_finish_at + 2);\n  if (attrs.substr(attrs.size() - 2) == \", \") { attrs = attrs.substr(0, attrs.size() - 2); }\n\n  const auto need_find_next = [](const std::string& s, size_t index) {\n    const size_t final_pos = index + 2;\n    if (final_pos >= s.size()) { return false; }\n    if (s.at(index + 1) != ' ') { return true; }\n    if (!(s.at(final_pos) >= 'a' && s.at(final_pos) <= 'z')) { return true; }\n    return false;\n  };\n\n  const auto split = [&need_find_next](const std::string& s, std::vector<std::string>& tokens,\n                                       const std::string& delimiters) {\n    std::string::size_type lastPos = s.find_first_not_of(delimiters, 0);\n    std::string::size_type pos = s.find_first_of(delimiters, lastPos);\n    while (std::string::npos != pos && need_find_next(s, pos)) {\n      pos = s.find_first_of(delimiters, pos + 1);\n    }\n    while (std::string::npos != pos || std::string::npos != lastPos) {\n      tokens.push_back(s.substr(lastPos, pos - lastPos));\n      lastPos = s.find_first_not_of(delimiters, pos);\n      pos = s.find_first_of(delimiters, lastPos);\n      while (std::string::npos != pos && need_find_next(s, pos)) {\n        pos = s.find_first_of(delimiters, pos + 1);\n      }\n    }\n  };\n  std::vector<std::string> attrs_splited;\n  split(attrs, attrs_splited, \", \");\n  std::sort(attrs_splited.begin(), attrs_splited.end());\n  return fmt::format(\"{} {}, \", name_and_shape, fmt::join(attrs_splited, \", \"));\n}\n\nusing json = nlohmann::json;\n\njson LoadTimeDataset() {\n  json j;\n  if (const char* c = std::getenv(\"ONEFLOW_REMAT_OP_TIME_DATASET\")) {\n    std::ifstream 
\n\njson LoadTimeDataset() {\n  json j;\n  if (const char* c = std::getenv(\"ONEFLOW_REMAT_OP_TIME_DATASET\")) {\n    std::ifstream i(c);\n    i >> j;\n    i.close();\n  }\n  json new_j;\n\n  for (json::iterator iter = j.begin(); iter != j.end(); ++iter) {\n    new_j[SortKey(iter.key())] = iter.value();\n  }\n  return new_j;\n}\n\nMaybe<double> GetDatasetComputeTime(const json& j, const vm::OpCallInstructionPolicy& operand) {\n  static const std::vector<std::string> zero_time_list{\n      \"empty\", \"identity\", \"constant\", \"copy\", \"zero_like\", \"expand_dims\", \"flatten\", \"reduce_sum\",\n      \"reshape\", \"reshape_like\", \"squeeze\", \"transpose\", \"nll\", \"nll_grad\", \"uniform\",\n      \"uniform_int\", \"fill_\", \"slice_update\", \"normal\",\n      // ddp\n      \"eager_ccl_broadcast\", \"eager_ccl_all_reduce\", \"eager_ccl_touch\", \"scalar_mul\",\n\n      // \"adaptive_avg_pool2d\",\n      // \"adaptive_avg_pool2d_grad\"\n  };\n  for (const auto& x : zero_time_list) {\n    if (operand.opkernel().op_type_name() == x) { return 0; }\n  }\n\n  const std::string op_type_str = operand.opkernel().op_type_name();\n  const std::string input_shape_str = [&]() {\n    std::stringstream ss;\n    for (size_t i = 0; i < operand.inputs().size(); i++) {\n      ss << operand.inputs().at(i)->shape();\n      if (i != operand.inputs().size() - 1) { ss << \", \"; }\n    }\n    return ss.str();\n  }();\n  const std::string attr_str = operand.composed_attrs().ToString();\n  std::string key = op_type_str + \" \" + input_shape_str + \" \" + attr_str;\n  key = SortKey(key);\n  CHECK_OR_RETURN(j.contains(key)) << \"key \" << key << \" not found\";\n  CHECK_OR_RETURN(j[key].is_number_float()) << \"key \" << key << \" is not float, but \" << j[key];\n  return j[key].get<double>();\n}\n\nstatic Maybe<double> GetComputeComplexityEstimatedBySize(\n    const vm::OpCallInstructionPolicy& operand) {\n  const auto& inputs = operand.inputs();\n  const auto& outputs = operand.outputs();\n  size_t estimated_compute_time = 0;\n  for (const auto& input : inputs) { estimated_compute_time += input->shape().elem_cnt(); }\n  for (const auto& output : outputs) { estimated_compute_time += output->shape().elem_cnt(); }\n  return estimated_compute_time;\n}\n\nint32_t TryGetTensorTupleIndex(const std::unordered_map<std::string, std::vector<int32_t>>&\n                                   arg_name2bn_index2tensor_tuple_index,\n                               const std::string& arg_name, const int32_t arg_index) {\n  auto it = arg_name2bn_index2tensor_tuple_index.find(arg_name);\n  if (it != arg_name2bn_index2tensor_tuple_index.end()) { return it->second.at(arg_index); }\n  return -1;\n}\n\nclass SingleDeviceOpComputeComplexityFnContext : public user_op::ComputeComplexityFnContext {\n public:\n  using ArgVec = std::vector<std::pair<std::string, int32_t>>;\n\n  SingleDeviceOpComputeComplexityFnContext(const OperatorConf& op_conf,\n                                           const vm::EagerBlobObjectList& inputs,\n                                           const vm::EagerBlobObjectList& outputs,\n                                           const ArgTuple* input_arg_tuple,\n                                           const ArgTuple* output_arg_tuple)\n      : user_op::ComputeComplexityFnContext(user_op::UserOpConfWrapper(op_conf)),\n        input_tensors_(inputs),\n        output_tensors_(outputs),\n        input_arg_tuple_(input_arg_tuple),\n        output_arg_tuple_(output_arg_tuple) {}\n  ~SingleDeviceOpComputeComplexityFnContext() override = default;\n\n#define RETURN_IF_FOUND(inputs, outputs, post_action)                                             \\\n  
int32_t i = TryGetTensorTupleIndex(input_arg_tuple_->arg_name2bn_index2tensor_tuple_index(),    \\\n                                     arg_name, index);                                            \\\n  if (i >= 0) { return (inputs).at(i) post_action; }                                              \\\n  i = TryGetTensorTupleIndex(output_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), arg_name, \\\n                             index);                                                              \\\n  if (i >= 0) { return (outputs).at(i) post_action; }\n\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) override {\n    RETURN_IF_FOUND(input_tensors_, output_tensors_, ->tensor_meta().shared_from_symbol().get());\n    return nullptr;\n  }\n  const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    RETURN_IF_FOUND(input_tensors_, output_tensors_, ->shape())\n    UNIMPLEMENTED_THEN_THROW();\n  }\n  DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    RETURN_IF_FOUND(input_tensors_, output_tensors_, ->data_type())\n    UNIMPLEMENTED_THEN_THROW();\n  }\n  bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    return false;\n  }\n\n  const NdSbp NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    static NdSbp nd_sbp = []() {\n      NdSbp nd_sbp;\n      nd_sbp.add_sbp_parallel()->broadcast_parallel();\n      return nd_sbp;\n    }();\n    return nd_sbp;\n  }\n\n  const ArgVec& inputs() const override { UNIMPLEMENTED_THEN_THROW(); }\n  const ArgVec& outputs() const override { UNIMPLEMENTED_THEN_THROW(); }\n  const ParallelDesc& parallel_desc() const override {\n    static ParallelDesc parallel_desc = []() {\n      ParallelConf parallel_conf;\n      parallel_conf.set_device_tag(\"cpu\");\n      parallel_conf.add_device_name(\"0:0-0\");\n      return ParallelDesc(parallel_conf);\n    }();\n    return parallel_desc;\n  }\n  const NdSbpSignature* GetNdSbpSignature() const override { UNIMPLEMENTED_THEN_THROW(); }\n\n private:\n  const vm::EagerBlobObjectList& input_tensors_;\n  const vm::EagerBlobObjectList& output_tensors_;\n  const ArgTuple* input_arg_tuple_;\n  const ArgTuple* output_arg_tuple_;\n};\n\nMaybe<double> GetComputeComplexity(const vm::OpCallInstructionPolicy& operand) {\n  const auto& op_conf = operand.opkernel().op_conf();\n  auto registry =\n      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_conf.user_conf().op_type_name());\n\n  if (registry->compute_complexity_fn) {\n    SingleDeviceOpComputeComplexityFnContext ctx(op_conf, operand.inputs(), operand.outputs(),\n                                                 operand.opkernel().input_arg_tuple(),\n                                                 operand.opkernel().output_arg_tuple());\n    return registry->compute_complexity_fn(&ctx);\n  } else {\n    return GetComputeComplexityEstimatedBySize(operand);\n  }\n}\n}  // namespace\n\nMaybe<double> GetComputeTime(const vm::OpCallInstructionPolicy& operand) {\n  static const json time_dataset = LoadTimeDataset();\n  if (!time_dataset.empty()) { return GetDatasetComputeTime(time_dataset, operand); }\n  return GetComputeComplexity(operand);\n}\n\n}  // namespace remat\n\nnamespace vm {
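\n\n// Note: RematHelper keeps op_call_instruction_policy by reference (see util.h), so the\n// policy passed in must outlive the helper; the constructor below only snapshots the\n// input/output storages.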
\nRematHelper::RematHelper(const OpCallInstructionPolicy& op_call_instruction_policy)\n    : op_call_instruction_policy_(op_call_instruction_policy) {\n  const auto save_eager_blob_object_storages = [](const auto& eager_blob_objects,\n                                                  auto& storage_container) {\n    storage_container.reserve(eager_blob_objects.size());\n    for (const auto& x : eager_blob_objects) {\n      storage_container.emplace_back(\n          std::dynamic_pointer_cast<RematableTensorStorage>(x->tensor_storage()));\n    }\n  };\n  save_eager_blob_object_storages(op_call_instruction_policy_.inputs(), input_storages_);\n  save_eager_blob_object_storages(op_call_instruction_policy_.outputs(), output_storages_);\n}\n\nRematHelper::RematHelper(const OpCallInstructionPolicy& op_call_instruction_policy,\n                         bool inputs_rematable, bool outputs_rematable)\n    : RematHelper(op_call_instruction_policy) {\n  if (outputs_rematable) {\n    storage_is_initialized_.reserve(output_storages_.size());\n    for (auto& storage : output_storages_) {\n      storage_is_initialized_.push_back(storage->is_initialized());\n    }\n    if (!inputs_rematable) {\n      for (auto& storage : output_storages_) {\n        VLOG_REMAT(1) << \"set storage \" << storage->id() << \" unevictable\" << std::endl;\n        storage->set_eviction_disabled(true);\n      }\n    }\n  }\n}\n\nMaybe<void> RematHelper::_IncReferenceNumOfRecomputedTensor(\n    int& pinned_num, std::set<const DtrOpCallInstructionPolicy*>& visited_ops) {\n  VLOG_REMAT(1) << \"op is \" << op_call_instruction_policy_.opkernel().op_type_name();\n  for (int i = 0; i < input_storages_.size(); i++) {\n    auto& storage = input_storages_[i];\n    storage->Pin();\n    VLOG_REMAT(1) << \"No.\" << i << \" input is in memory? \" << storage->is_in_memory();\n    if (!storage->is_in_memory()) {\n      OpCallInstructionPolicy tmp_op = storage->compute_op();\n      if (!storage->is_needed_by_backward()) {\n        Singleton<remat::Env>::Get()->need_eager_eviction_storages.insert(storage.get());\n      }\n\n      if (visited_ops.find(storage->dtr_compute_op().get()) == visited_ops.end()) {\n        visited_ops.insert(storage->dtr_compute_op().get());\n        RematHelper new_helper(tmp_op);\n        JUST(new_helper._IncReferenceNumOfRecomputedTensor(pinned_num, visited_ops));\n      }\n    } else {\n      pinned_num++;\n    }\n  }\n  VLOG_REMAT(1) << \"op \" << op_call_instruction_policy_.opkernel().op_type_name() << \" end\";\n  return Maybe<void>::Ok();\n}\n\nMaybe<int> RematHelper::IncReferenceNumOfRecomputedTensor() {\n  int pinned_num = 0;\n  std::set<const DtrOpCallInstructionPolicy*> visited_ops;\n  JUST(_IncReferenceNumOfRecomputedTensor(pinned_num, visited_ops));\n  return pinned_num;\n}\n\nMaybe<void> RematHelper::RematInputs(\n    vm::Stream* vm_stream, bool first,\n    const std::function<Maybe<void>(OpCallInstructionPolicy*, vm::Stream*)>& compute_fn) {\n  CHECK_OR_RETURN(!ThreadLocalEnvBool<ONEFLOW_VM_MULTI_THREAD>());\n  if (first) { JUST(IncReferenceNumOfRecomputedTensor()); }\n  VLOG_REMAT(1) << \"compute \" << op_call_instruction_policy_.opkernel().op_type_name() << std::endl;\n  VLOG_REMAT(1) << \"input num \" << op_call_instruction_policy_.inputs().size() << std::endl;\n\n  for (int i = 0; i < input_storages_.size(); i++) {\n    auto& storage = input_storages_[i];\n    if (!storage->is_in_memory()) {\n      VLOG_REMAT(1) << \"recompute No.\" << i << \" input by \" << storage->compute_op_type_name()\n                    << \". 
Storage id: \" << storage->id();\n      OpCallInstructionPolicy tmp_op = storage->compute_op();\n      JUST(compute_fn(&tmp_op, vm_stream));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RematHelper::EagerlyEvictRemattedTensors(bool first) {\n  auto& need_eager_eviction_storages = Singleton<remat::Env>::Get()->need_eager_eviction_storages;\n  for (auto& storage : input_storages_) {\n    storage->Unpin();\n    if (storage->num_pinned() == 0 && need_eager_eviction_storages.count(storage.get()) > 0) {\n      need_eager_eviction_storages.erase(storage.get());\n      storage->Evict(true);\n    }\n  }\n  if (first) {\n    if (!need_eager_eviction_storages.empty()) {\n      for (const auto& storage : need_eager_eviction_storages) {\n        VLOG_REMAT(1) << \"not empty, storage id: \" << storage->id();\n      }\n    }\n    CHECK_OR_RETURN(need_eager_eviction_storages.empty());\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RematHelper::UpdateRematInfo(bool first, bool recompute, bool include_input,\n                                         bool include_output) {\n  if (include_output) {\n    const std::unique_ptr<OpCallInstructionPolicy> compute_op = [&]() {\n      auto compute_op = std::make_unique<OpCallInstructionPolicy>(op_call_instruction_policy_);\n      for (int i = 0; i < output_storages_.size(); i++) {\n        const auto& storage = output_storages_[i];\n        VLOG_REMAT(1) << \"output \" << i << \" storage id: \" << storage->id();\n        if (storage->is_eviction_disabled()) { continue; }\n        if (storage_is_initialized_[i] && !recompute) {\n          VLOG_REMAT(1) << \"storage->is_initialized(), op is \" << storage->compute_op_type_name()\n                        << std::endl;\n          compute_op = std::make_unique<OpCallInstructionPolicy>(\n              Singleton<remat::Env>::Get()->update_tensor_with_storage(\n                  storage.get(), op_call_instruction_policy_));\n        }\n      }\n      return compute_op;\n    }();\n    std::shared_ptr<DtrOpCallInstructionPolicy> dtr_compute_op =\n        std::make_shared<DtrOpCallInstructionPolicy>(*compute_op);\n    double compute_time = JUST(remat::GetComputeTime(*compute_op));\n    for (auto& storage : output_storages_) {\n      storage->Pin();\n      if (!recompute && !storage->is_eviction_disabled()) {\n        storage->set_compute_op(dtr_compute_op, compute_time);\n      }\n      storage->Unpin();\n      storage->Access();\n      remat::DisjointSet::update_after_compute(storage.get());\n    }\n  }\n  if (include_input) {\n    for (int i : op_call_instruction_policy_.opkernel().input_tuple_indexes4mut_ibns()) {\n      input_storages_[i]->set_eviction_disabled(true);\n    }\n\n    for (auto& storage : input_storages_) { storage->Access(); }\n  }\n\n  if (recompute) { Singleton<remat::Env>::Get()->add_recomputation_num(); }\n  Singleton<remat::Env>::Get()->add_time(JUST(remat::GetComputeTime(op_call_instruction_policy_)));\n  VLOG_REMAT(1) << \"end compute \" << op_call_instruction_policy_.opkernel().op_type_name()\n                << std::endl;\n  return Maybe<void>::Ok();\n}\n\n}  // namespace vm\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/remat/util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#pragma once\n\n#include <memory>\n\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\n\nnamespace vm {\nclass OpCallInstructionPolicy;\n}\n\nnamespace remat {\n\ndouble append_memory_frag_info_and_get(size_t free_mem, size_t threshold);\n\nMaybe<double> GetComputeTime(const vm::OpCallInstructionPolicy& operand);\n\n}  // namespace remat\n\nnamespace vm {\n\nclass RematableTensorStorage;\nclass Stream;\nclass DtrOpCallInstructionPolicy;\n\n// This class is mainly for holding RematableTensorStorage vector so that we do not\n// need to generate them every time.\nclass RematHelper {\n public:\n  explicit RematHelper(const OpCallInstructionPolicy& op_call_instruction_policy);\n  RematHelper(const OpCallInstructionPolicy& op_call_instruction_policy, bool inputs_rematable,\n              bool outputs_rematable);\n\n  Maybe<void> RematInputs(\n      vm::Stream* vm_stream, bool first,\n      const std::function<Maybe<void>(OpCallInstructionPolicy*, vm::Stream*)>& compute_fn);\n  Maybe<void> EagerlyEvictRemattedTensors(bool first);\n  Maybe<void> UpdateRematInfo(bool first, bool recompute, bool include_input, bool include_output);\n\n private:\n  Maybe<int> IncReferenceNumOfRecomputedTensor();\n  Maybe<void> _IncReferenceNumOfRecomputedTensor(\n      int& pinned_num, std::set<const DtrOpCallInstructionPolicy*>& visited_ops);\n  const OpCallInstructionPolicy& op_call_instruction_policy_;\n  std::vector<std::shared_ptr<RematableTensorStorage>> input_storages_;\n  std::vector<std::shared_ptr<RematableTensorStorage>> output_storages_;\n  std::vector<bool> storage_is_initialized_;\n};\n\n}  // namespace vm\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/stream.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/core/vm/thread_ctx.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/cpp_attribute.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/vm/stream_create_stream_policy.h\"\n#include \"oneflow/core/framework/stream_on_independent_thread.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nvoid Stream::__Init__(ThreadCtx* thread_ctx, Symbol<Device> device, StreamType stream_type,\n                      const intrusive::shared_ptr<Dependence>& schedule_local_dep_object,\n                      const std::vector<intrusive::shared_ptr<Dependence>>& transport_dependences) {\n  set_thread_ctx(thread_ctx);\n  device_ = device;\n  stream_type_ = stream_type;\n  stream_policy_ = CHECK_JUST(CreateStreamPolicy::Visit(stream_type, device));\n  schedule_local_dep_object_ = schedule_local_dep_object;\n  transport_dependences_ = transport_dependences;\n  on_scheduler_thread_ = stream_policy_->OnSchedulerThread(stream_type);\n}\n\nint64_t Stream::device_id() const { return device_->device_id(); }\n\nchar* Stream::CheckSizeAndGetTmpSmallPinnedMemPtr(size_t size) {\n  static constexpr int kSmallSize = 512;\n  CHECK_LE(size, kSmallSize);\n  if (!static_cast<bool>(small_pinned_mem_ptr_)) {\n    auto* ep_device = stream_policy_->stream()->device();\n    void* mem_ptr = nullptr;\n    CHECK_JUST(ep_device->AllocPinned(ep::AllocationOptions{}, &mem_ptr, kSmallSize));\n    std::function<void(char*)> Deleter = [ep_device](char* ptr) {\n      ep_device->FreePinned(ep::AllocationOptions{}, ptr);\n    };\n    char* ptr = reinterpret_cast<char*>(mem_ptr);\n    small_pinned_mem_ptr_ = decltype(small_pinned_mem_ptr_)(ptr, Deleter);\n  }\n  return small_pinned_mem_ptr_.get();\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/stream.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_STREAM_H_\n#define ONEFLOW_CORE_VM_STREAM_H_\n\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/stream_type.h\"\n#include \"oneflow/core/vm/stream_policy.h\"\n\nnamespace oneflow {\n\nclass Device;\n\nnamespace vm {\n\nclass ThreadCtx;\nclass MirroredObject;\nclass Dependence;\n\nclass Stream final : public intrusive::Base {\n public:\n  // types\n  using DispatchedInstructionList =\n      intrusive::List<INTRUSIVE_FIELD(Instruction, dispatched_instruction_hook_)>;\n\n  // Getters\n  const StreamPolicy& stream_policy() const { return *stream_policy_; }\n  const ThreadCtx& thread_ctx() const { return *thread_ctx_; }\n  bool has_thread_ctx() const { return thread_ctx_ != nullptr; }\n  const intrusive::ListHook& active_stream_hook() const { return active_stream_hook_; }\n  const DispatchedInstructionList& running_instruction_list() const {\n    return running_instruction_list_;\n  }\n\n  // Setters\n  StreamPolicy* mut_stream_policy() { return stream_policy_.get(); }\n  ThreadCtx* mut_thread_ctx() { return thread_ctx_; }\n  void set_thread_ctx(ThreadCtx* val) { thread_ctx_ = val; }\n  void clear_thread_ctx() { thread_ctx_ = nullptr; }\n  DispatchedInstructionList* mut_running_instruction_list() { return &running_instruction_list_; }\n\n  // methods\n  void __Init__(ThreadCtx* thread_ctx, Symbol<Device> device, StreamType stream_type,\n                const intrusive::shared_ptr<Dependence>& schedule_local_dep_object,\n                const std::vector<intrusive::shared_ptr<Dependence>>& transport_dependences);\n  int64_t device_id() const;\n  Symbol<Device> device() const { return device_; }\n  StreamType stream_type() const { return stream_type_; }\n  bool on_scheduler_thread() const { return on_scheduler_thread_; }\n\n  const intrusive::shared_ptr<Dependence>& schedule_local_dep_object() const {\n    return schedule_local_dep_object_;\n  }\n\n  const std::vector<intrusive::shared_ptr<Dependence>>& transport_dependences() const {\n    return transport_dependences_;\n  }\n\n  char* CheckSizeAndGetTmpSmallPinnedMemPtr(size_t size);\n\n private:\n  void MoveToFreeList(intrusive::shared_ptr<Instruction>&& instruction);\n  void MoveFromZombieListToFreeList();\n\n  friend class intrusive::Ref;\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n  Stream()\n      : intrusive_ref_(),\n        thread_ctx_(),\n        device_(),\n        stream_type_(StreamType::kInvalid),\n        stream_policy_(),\n        on_scheduler_thread_(false),\n        small_pinned_mem_ptr_(),\n        running_instruction_list_(),\n        active_stream_hook_(),\n        thread_ctx_stream_hook_() {}\n  intrusive::Ref intrusive_ref_;\n  // fields\n  ThreadCtx* thread_ctx_;\n  Symbol<Device> device_;\n  StreamType stream_type_;\n  std::shared_ptr<StreamPolicy> 
stream_policy_;\n  bool on_scheduler_thread_;\n  std::unique_ptr<char, std::function<void(char*)>> small_pinned_mem_ptr_;\n  // lists\n  DispatchedInstructionList running_instruction_list_;\n\n  intrusive::shared_ptr<Dependence> schedule_local_dep_object_;\n  std::vector<intrusive::shared_ptr<Dependence>> transport_dependences_;\n\n public:\n  // list hooks\n  intrusive::ListHook active_stream_hook_;\n  intrusive::ListHook thread_ctx_stream_hook_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_STREAM_H_\n"
  },
  {
    "path": "oneflow/core/vm/stream_create_stream_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_STREAM_CREATE_STREAM_POLICY_H_\n#define ONEFLOW_CORE_VM_STREAM_CREATE_STREAM_POLICY_H_\n\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/common/stream_type.h\"\n#include \"oneflow/core/vm/control_stream_policy.h\"\n#include \"oneflow/core/vm/event_recorded_ep_stream_policy.h\"\n#include \"oneflow/core/vm/critical_section_stream_policy.h\"\n#include \"oneflow/core/vm/ep_d2h_stream_policy.h\"\n#include \"oneflow/core/vm/ep_stream_policy.h\"\n#include \"oneflow/core/vm/pinned_ep_stream_policy.h\"\n#include \"oneflow/core/vm/lazy_job_stream_policy.h\"\n\nnamespace oneflow {\n\nclass Device;\n\nstruct CreateStreamPolicy final : public StreamTypeVisitor<CreateStreamPolicy> {\n  static Maybe<vm::StreamPolicy> VisitCompute(Symbol<Device> device) {\n    return std::shared_ptr<vm::StreamPolicy>(new vm::EpStreamPolicy(device));\n  }\n  static Maybe<vm::StreamPolicy> VisitHost2Device(Symbol<Device> device) {\n    std::unique_ptr<vm::Allocator> allocator{};\n    if (device->enum_type() == DeviceType::kCPU) {\n      allocator = vm::EventRecordedEpStreamPolicy::CreateEpBackendDeviceAllocator(device);\n    } else {\n      allocator =\n          std::make_unique<vm::UnimplementedAllocator>(\"allocator is not supported on h2d stream.\");\n    }\n    return std::shared_ptr<vm::StreamPolicy>(\n        new vm::EventRecordedEpStreamPolicy(device, std::move(allocator)));\n  }\n  static Maybe<vm::StreamPolicy> VisitDevice2Host(Symbol<Device> device) {\n    return std::shared_ptr<vm::StreamPolicy>(new vm::EpD2HStreamPolicy(device));\n  }\n  static Maybe<vm::StreamPolicy> VisitCcl(Symbol<Device> device) {\n    auto allocator = vm::EventRecordedEpStreamPolicy::CreateEpBackendDeviceAllocator(device);\n    return std::shared_ptr<vm::StreamPolicy>(\n        new vm::EventRecordedEpStreamPolicy(device, std::move(allocator)));\n  }\n  static Maybe<vm::StreamPolicy> VisitBarrier(Symbol<Device> device) {\n    return std::shared_ptr<vm::StreamPolicy>(new vm::ControlStreamPolicy());\n  }\n  static Maybe<vm::StreamPolicy> VisitCriticalSection(Symbol<Device> device) {\n    return std::shared_ptr<vm::StreamPolicy>(new vm::CriticalSectionStreamPolicy());\n  }\n  static Maybe<vm::StreamPolicy> VisitLazyJobLauncher(Symbol<Device> device) {\n    return std::shared_ptr<vm::StreamPolicy>(new vm::LazyJobStreamPolicy());\n  }\n  static Maybe<vm::StreamPolicy> VisitPinnedCompute(Symbol<Device> device) {\n    return std::shared_ptr<vm::StreamPolicy>(new vm::PinnedEpStreamPolicy(device));\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_STREAM_CREATE_STREAM_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/stream_get_allocator_stream_type.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_STREAM_GET_ALLOCATOR_STREAM_TYPE_H_\n#define ONEFLOW_CORE_VM_STREAM_GET_ALLOCATOR_STREAM_TYPE_H_\n\n#include \"oneflow/core/common/stream_type.h\"\n\nnamespace oneflow {\n\nstruct GetAllocatorStreamType final : public StreamTypeVisitor<GetAllocatorStreamType> {\n  static Maybe<StreamType> VisitCompute() { return StreamType::kCompute; }\n  static Maybe<StreamType> VisitHost2Device() { return StreamType::kCompute; }\n  static Maybe<StreamType> VisitCcl() { return StreamType::kCompute; }\n  static Maybe<StreamType> VisitPinnedCompute() { return StreamType::kPinnedCompute; }\n  static Maybe<StreamType> VisitDevice2Host() { return StreamType::kDevice2Host; }\n  static Maybe<StreamType> VisitBarrier() {\n    UNIMPLEMENTED_THEN_RETURN() << \"no allocator supported on 'barrier' stream_type.\";\n  }\n  static Maybe<StreamType> VisitCriticalSection() {\n    UNIMPLEMENTED_THEN_RETURN() << \"no allocator supported on 'critical_section' stream_type.\";\n  }\n  static Maybe<StreamType> VisitLazyJobLauncher() {\n    UNIMPLEMENTED_THEN_RETURN() << \"no allocator supported on 'lazy_job_launcher' stream_type.\";\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_STREAM_GET_ALLOCATOR_STREAM_TYPE_H_\n"
  },
  {
    "path": "oneflow/core/vm/stream_policy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/stream_policy.h\"\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/framework/stream_on_independent_thread.h\"\n#include \"oneflow/core/framework/stream_is_comm_net_stream.h\"\n#include \"oneflow/core/common/env_var/vm.h\"\n#include \"oneflow/core/thread/thread_global_id.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nbool StreamPolicy::OnSchedulerThread(StreamType stream_type) const {\n  if (StreamOnIndependentThread::Visit(stream_type)) { return false; }\n  return !ThreadLocalEnvBool<ONEFLOW_VM_COMPUTE_ON_WORKER_THREAD>();\n}\n\nvoid StreamPolicy::RunIf(Instruction* instruction) const {\n  if (IsCommNetStream::Visit(instruction->stream().stream_type())\n      && ThreadLocalEnvBool<ONEFLOW_VM_MULTI_THREAD>()) {\n    ThreadGlobalIdGuard guard{kThreadGlobalIdDefaultWorker};\n    Run(instruction);\n  } else {\n    Run(instruction);\n  }\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/stream_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_STREAM_POLICY_H_\n#define ONEFLOW_CORE_VM_STREAM_POLICY_H_\n\n#include <string>\n#include <typeindex>\n#include \"oneflow/core/framework/nn_graph_if.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/job/resource.pb.h\"\n#include \"oneflow/core/common/stream_type.h\"\n#include \"oneflow/core/common/symbol.h\"\n\nnamespace oneflow {\n\nclass EpEventProvider;\n\nnamespace ep {\n\nclass Device;\nclass Stream;\n\n}  // namespace ep\n\nnamespace vm {\n\nclass Allocator;\nclass Stream;\nclass InstructionStatusBuffer;\nclass Instruction;\n\nclass StreamPolicy {\n public:\n  virtual ~StreamPolicy() = default;\n\n  virtual ep::Stream* stream() = 0;\n  virtual vm::Allocator* mut_allocator() = 0;\n  virtual DeviceType device_type() const = 0;\n\n  virtual void InitInstructionStatus(const Stream& stream,\n                                     InstructionStatusBuffer* status_buffer) const = 0;\n  virtual void DeleteInstructionStatus(const Stream& stream,\n                                       InstructionStatusBuffer* status_buffer) const = 0;\n  virtual bool QueryInstructionStatusLaunched(\n      const Stream& stream, const InstructionStatusBuffer& status_buffer) const = 0;\n  virtual bool QueryInstructionStatusDone(const Stream& stream,\n                                          const InstructionStatusBuffer& status_buffer) const = 0;\n  virtual bool OnSchedulerThread(StreamType stream_type) const;\n  virtual bool SupportingTransportInstructions() const = 0;\n\n  void RunIf(Instruction* instruction) const;\n\n protected:\n  StreamPolicy() = default;\n\n private:\n  virtual void Run(Instruction* instruction) const = 0;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_STREAM_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/stream_record_event_instruction_policy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/stream_record_event_instruction_policy.h\"\n#include \"oneflow/core/vm/ep_event.h\"\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/core/ep/cuda/cuda_event.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/ep/cuda/cuda_device.h\"\n#include \"oneflow/core/vm/ep_stream_policy_base.h\"\n#include \"oneflow/core/vm/ep_optional_event_record_status_querier.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nStreamRecordEventInstructionPolicy::StreamRecordEventInstructionPolicy(\n    const small_vector<intrusive::shared_ptr<LocalDepObject>>& dependences)\n    : dependences_(dependences), input_dependences_(), output_dependences_() {\n  for (const auto& dep : dependences_) { output_dependences_.push_back(dep.get()); }\n}\n\nvoid StreamRecordEventInstructionPolicy::InitInstructionStatus(Instruction* instruction) {\n  auto* stream = instruction->mut_stream();\n  {\n    auto* ep_stream_policy_base =\n        CHECK_NOTNULL(dynamic_cast<EpStreamPolicyBase*>(instruction->mut_stream_policy()));\n    ep_stream_policy_base->InitInstructionStatus(*stream, instruction->mut_status_buffer());\n    auto* ep_event_provider = ep_stream_policy_base->ep_event_provider();\n    const auto& ep_event = CHECK_NOTNULL(ep_event_provider)->GetReusedEpEvent();\n    mut_ep_event() = ep_event;\n  }\n  {\n    auto* status_buffer = instruction->mut_status_buffer();\n    instruction->stream_policy().InitInstructionStatus(*stream, status_buffer);\n    auto* data_ptr = status_buffer->mut_buffer();\n    EpOptionalEventRecordStatusQuerier::MutCast(data_ptr)->reset_ep_event(nullptr);\n  }\n}\n\nvoid StreamRecordEventInstructionPolicy::Compute(vm::Instruction* instruction) {\n  const auto& ep_event = mut_ep_event();\n  // Record event.\n  auto* stream_policy =\n      dynamic_cast<EpStreamPolicyBase*>(instruction->mut_stream()->mut_stream_policy());\n  CHECK_NOTNULL(stream_policy)->stream()->RecordEvent(ep_event->mut_event());\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/stream_record_event_instruction_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_VM_STREAM_RECORD_EVENT_INSTRUCTION_POLICY_H_\n#define ONEFLOW_CORE_VM_STREAM_RECORD_EVENT_INSTRUCTION_POLICY_H_\n\n#include <functional>\n#include \"oneflow/core/eager/local_dep_object.h\"\n#include \"oneflow/core/vm/instruction_policy.h\"\n#include \"oneflow/core/common/op_args_reserved_size.h\"\n#include \"oneflow/core/common/small_vector.h\"\n\nnamespace oneflow {\nclass EpEvent;\nnamespace vm {\n\nclass Stream;\n\nclass StreamRecordEventInstructionPolicy final : public vm::InstructionPolicy {\n public:\n  StreamRecordEventInstructionPolicy(\n      const small_vector<intrusive::shared_ptr<LocalDepObject>>& dependences);\n  ~StreamRecordEventInstructionPolicy() = default;\n\n  std::string DebugName(const vm::Instruction&) const override { return \"StreamRecordEvent\"; }\n\n  void InitInstructionStatus(Instruction* instruction) override;\n  Maybe<void> Prepare(vm::Instruction* instruction) override { return Maybe<void>::Ok(); }\n  void Compute(vm::Instruction* instruction) override;\n\n  const DependenceVector& input_dependences() const override { return input_dependences_; }\n  const DependenceVector& output_dependences() const override { return output_dependences_; }\n\n  std::shared_ptr<EpEvent>& mut_ep_event() { return ep_event_; }\n\n private:\n  small_vector<intrusive::shared_ptr<LocalDepObject>> dependences_;\n  DependenceVector input_dependences_;\n  DependenceVector output_dependences_;\n  std::shared_ptr<EpEvent> ep_event_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_STREAM_RECORD_EVENT_INSTRUCTION_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/stream_wait_event_instruction_policy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/stream_wait_event_instruction_policy.h\"\n#include \"oneflow/core/vm/ep_event.h\"\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/core/vm/ep_stream_policy_base.h\"\n#include \"oneflow/core/vm/ep_optional_event_record_status_querier.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nStreamWaitEventInstructionPolicy::StreamWaitEventInstructionPolicy(\n    const small_vector<intrusive::shared_ptr<LocalDepObject>>& dependences,\n    const std::shared_ptr<StreamRecordEventInstructionPolicy>&\n        stream_record_event_instruction_policy)\n    : dependences_(dependences),\n      input_dependences_(),\n      output_dependences_(),\n      stream_record_event_instruction_policy_(stream_record_event_instruction_policy) {\n  for (const auto& dep : dependences_) { output_dependences_.push_back(dep.get()); }\n}\n\nvoid StreamWaitEventInstructionPolicy::DeleteInstructionStatus(Instruction* instruction) {\n  auto* stream = instruction->mut_stream();\n  instruction->stream_policy().DeleteInstructionStatus(*stream, instruction->mut_status_buffer());\n  stream_record_event_instruction_policy_->mut_ep_event().reset();\n}\n\nvoid StreamWaitEventInstructionPolicy::Compute(vm::Instruction* instruction) {\n  const auto& ep_event = stream_record_event_instruction_policy_->mut_ep_event();\n  // Wait event.\n  auto* ep_stream_policy_base =\n      dynamic_cast<EpStreamPolicyBase*>(instruction->mut_stream()->mut_stream_policy());\n  CHECK_NOTNULL(ep_stream_policy_base);\n  auto* ep_stream = ep_stream_policy_base->stream();\n  CHECK_EQ(ep_event->mut_device(), ep_stream->device())\n      << \"only support waiting events from same device\";\n  ep_event->mut_device()->SetAsActiveDevice();\n  ep_stream->WaitEvent(ep_event->mut_event());\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/stream_wait_event_instruction_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_VM_STREAM_WAIT_EVENT_INSTRUCTION_POLICY_H_\n#define ONEFLOW_CORE_VM_STREAM_WAIT_EVENT_INSTRUCTION_POLICY_H_\n\n#include <functional>\n#include \"oneflow/core/eager/local_dep_object.h\"\n#include \"oneflow/core/vm/instruction_policy.h\"\n#include \"oneflow/core/common/op_args_reserved_size.h\"\n#include \"oneflow/core/common/small_vector.h\"\n#include \"oneflow/core/vm/stream_record_event_instruction_policy.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass Stream;\n\nclass StreamWaitEventInstructionPolicy final : public vm::InstructionPolicy {\n public:\n  StreamWaitEventInstructionPolicy(\n      const small_vector<intrusive::shared_ptr<LocalDepObject>>& dependences,\n      const std::shared_ptr<StreamRecordEventInstructionPolicy>&\n          stream_record_event_instruction_policy);\n  ~StreamWaitEventInstructionPolicy() = default;\n\n  std::string DebugName(const vm::Instruction&) const override { return \"StreamWaitEvent\"; }\n\n  void DeleteInstructionStatus(Instruction* instruction) override;\n  Maybe<void> Prepare(vm::Instruction* instruction) override { return Maybe<void>::Ok(); }\n  void Compute(vm::Instruction* instruction) override;\n\n  const DependenceVector& input_dependences() const override { return input_dependences_; }\n  const DependenceVector& output_dependences() const override { return output_dependences_; }\n\n private:\n  small_vector<intrusive::shared_ptr<LocalDepObject>> dependences_;\n  DependenceVector input_dependences_;\n  DependenceVector output_dependences_;\n  std::shared_ptr<StreamRecordEventInstructionPolicy> stream_record_event_instruction_policy_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_STREAM_WAIT_EVENT_INSTRUCTION_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/stream_wait_instruction_policy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/stream_wait_instruction_policy.h\"\n#include \"oneflow/core/vm/ep_event.h\"\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/core/vm/ep_stream_policy_base.h\"\n#include \"oneflow/core/vm/ep_optional_event_record_status_querier.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nStreamWaitInstructionPolicy::StreamWaitInstructionPolicy(\n    small_vector<intrusive::shared_ptr<LocalDepObject>>&& dependences, vm::Stream* from_vm_stream,\n    vm::Stream* to_vm_stream)\n    : dependences_(std::move(dependences)),\n      input_dependences_(),\n      output_dependences_(),\n      from_vm_stream_(from_vm_stream) {\n  for (const auto& dep : dependences_) { output_dependences_.push_back(dep.get()); }\n  stream_sequential_dependence_ = to_vm_stream->schedule_local_dep_object().get();\n}\n\nbool StreamWaitInstructionPolicy::Prescheduleable(const Stream* src, const Stream* dst) const {\n  return &src->thread_ctx() == &dst->thread_ctx();\n}\n\nvoid StreamWaitInstructionPolicy::InitInstructionStatus(Instruction* instruction) {\n  auto* stream = instruction->mut_stream();\n  auto* ep_stream_policy_base =\n      CHECK_NOTNULL(dynamic_cast<EpStreamPolicyBase*>(instruction->mut_stream_policy()));\n  ep_stream_policy_base->InitInstructionStatus(*stream, instruction->mut_status_buffer());\n  auto* ep_event_provider = ep_stream_policy_base->ep_event_provider();\n  const auto& ep_event = CHECK_NOTNULL(ep_event_provider)->GetReusedEpEvent();\n  mut_ep_event() = ep_event;\n}\n\nvoid StreamWaitInstructionPolicy::DeleteInstructionStatus(Instruction* instruction) {\n  auto* stream = instruction->mut_stream();\n  instruction->stream_policy().DeleteInstructionStatus(*stream, instruction->mut_status_buffer());\n  mut_ep_event().reset();\n}\n\nvoid StreamWaitInstructionPolicy::Compute(vm::Instruction* instruction) {\n  const auto& ep_event = mut_ep_event();\n  {\n    // Record event.\n    auto* from_naive_stream_policy =\n        dynamic_cast<EpStreamPolicyBase*>(mut_from_vm_stream()->mut_stream_policy());\n    CHECK_NOTNULL(from_naive_stream_policy);\n    auto* from_stream = from_naive_stream_policy->stream();\n    from_stream->RecordEvent(ep_event->mut_event());\n  }\n  {\n    // Wait event.\n    auto* to_ep_stream_policy_base =\n        dynamic_cast<EpStreamPolicyBase*>(instruction->mut_stream()->mut_stream_policy());\n    CHECK_NOTNULL(to_ep_stream_policy_base);\n    auto* to_ep_stream = to_ep_stream_policy_base->stream();\n    CHECK_EQ(ep_event->mut_device(), to_ep_stream->device())\n        << \"only support waiting events from same device\";\n    ep_event->mut_device()->SetAsActiveDevice();\n    to_ep_stream->WaitEvent(ep_event->mut_event());\n  }\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/stream_wait_instruction_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_CORE_VM_STREAM_WAIT_INSTRUCTION_POLICY_H_\n#define ONEFLOW_CORE_VM_STREAM_WAIT_INSTRUCTION_POLICY_H_\n\n#include <functional>\n#include \"oneflow/core/eager/local_dep_object.h\"\n#include \"oneflow/core/vm/instruction_policy.h\"\n#include \"oneflow/core/common/op_args_reserved_size.h\"\n#include \"oneflow/core/common/small_vector.h\"\n\nnamespace oneflow {\nclass EpEvent;\nnamespace vm {\n\nclass Stream;\n\nclass StreamWaitInstructionPolicy final : public vm::InstructionPolicy {\n public:\n  StreamWaitInstructionPolicy(small_vector<intrusive::shared_ptr<LocalDepObject>>&& dependences,\n                              vm::Stream* from_vm_stream, vm::Stream* to_vm_stream);\n  ~StreamWaitInstructionPolicy() = default;\n\n  std::string DebugName(const vm::Instruction&) const override { return \"StreamWait\"; }\n\n  bool Prescheduleable(const Stream* src, const Stream* dst) const override;\n  void InitInstructionStatus(Instruction* instruction) override;\n  void DeleteInstructionStatus(Instruction* instruction) override;\n  Maybe<void> Prepare(vm::Instruction* instruction) override { return Maybe<void>::Ok(); }\n  void Compute(vm::Instruction* instruction) override;\n\n  const DependenceVector& input_dependences() const override { return input_dependences_; }\n  const DependenceVector& output_dependences() const override { return output_dependences_; }\n\n private:\n  vm::Stream* mut_from_vm_stream() { return from_vm_stream_; }\n  std::shared_ptr<EpEvent>& mut_ep_event() { return ep_event_; }\n\n  small_vector<intrusive::shared_ptr<LocalDepObject>> dependences_;\n  DependenceVector input_dependences_;\n  DependenceVector output_dependences_;\n  vm::Stream* from_vm_stream_;\n  std::shared_ptr<EpEvent> ep_event_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_STREAM_WAIT_INSTRUCTION_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/symbol_storage.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/symbol_storage.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/job_desc.h\"\n#include \"oneflow/core/job/scope.h\"\n#include \"oneflow/core/operator/op_conf_symbol.h\"\n\nnamespace oneflow {\n\nnamespace symbol {\n\nnamespace detail {\n\ntemplate<>\nMaybe<ParallelDesc> NewSymbol<ParallelDesc>(\n    int64_t symbol_id, const typename ConstructArgType4Symbol<ParallelDesc>::type& data) {\n  return ParallelDesc::New(symbol_id, data);\n}\n\ntemplate<>\nMaybe<JobDesc> NewSymbol<JobDesc>(int64_t symbol_id,\n                                  const typename ConstructArgType4Symbol<JobDesc>::type& data) {\n  return JobDesc::New(symbol_id, data);\n}\n\ntemplate<>\nMaybe<Scope> NewSymbol<Scope>(int64_t symbol_id,\n                              const typename ConstructArgType4Symbol<Scope>::type& data) {\n  return Scope::New(symbol_id, data);\n}\n\ntemplate<>\nMaybe<OperatorConfSymbol> NewSymbol<OperatorConfSymbol>(\n    int64_t symbol_id, const typename ConstructArgType4Symbol<OperatorConfSymbol>::type& data) {\n  return std::make_shared<OperatorConfSymbol>(symbol_id, data);\n}\n\n}  // namespace detail\n\n}  // namespace symbol\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/symbol_storage.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_STORAGE_H_\n#define ONEFLOW_CORE_VM_STORAGE_H_\n\n#include <mutex>\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\n\nclass OperatorConfSymbol;\nclass OperatorConf;\n\nclass ParallelDesc;\nclass ParallelConf;\n\nclass JobDesc;\nclass JobConfigProto;\n\nclass Scope;\nclass ScopeProto;\n\nnamespace symbol {\n\ntemplate<typename T>\nstruct ConstructArgType4Symbol final {\n  using type = T;\n};\n\ntemplate<>\nstruct ConstructArgType4Symbol<OperatorConfSymbol> final {\n  using type = OperatorConf;\n};\n\ntemplate<>\nstruct ConstructArgType4Symbol<ParallelDesc> final {\n  using type = ParallelConf;\n};\n\ntemplate<>\nstruct ConstructArgType4Symbol<JobDesc> final {\n  using type = JobConfigProto;\n};\n\ntemplate<>\nstruct ConstructArgType4Symbol<Scope> final {\n  using type = ScopeProto;\n};\n\nnamespace detail {\n\ntemplate<typename T>\nMaybe<T> NewSymbol(int64_t symbol_id, const typename ConstructArgType4Symbol<T>::type& data) {\n  return std::make_shared<T>(data);\n}\n\ntemplate<>\nMaybe<OperatorConfSymbol> NewSymbol<OperatorConfSymbol>(\n    int64_t symbol_id, const typename ConstructArgType4Symbol<OperatorConfSymbol>::type& data);\n\ntemplate<>\nMaybe<ParallelDesc> NewSymbol<ParallelDesc>(\n    int64_t symbol_id, const typename ConstructArgType4Symbol<ParallelDesc>::type& data);\n\ntemplate<>\nMaybe<JobDesc> NewSymbol<JobDesc>(int64_t symbol_id,\n                                  const typename ConstructArgType4Symbol<JobDesc>::type& data);\n\ntemplate<>\nMaybe<Scope> NewSymbol<Scope>(int64_t symbol_id,\n                              const typename ConstructArgType4Symbol<Scope>::type& data);\n\n}  // namespace detail\n\ntemplate<typename T>\nclass Storage final {\n public:\n  Storage(const Storage&) = delete;\n  Storage(Storage&&) = delete;\n\n  Storage() = default;\n  ~Storage() = default;\n\n  bool Has(int64_t symbol_id) const {\n    std::unique_lock<std::mutex> lock(mutex_);\n    return symbol_id2symbol_.find(symbol_id) != symbol_id2symbol_.end();\n  }\n\n  bool Has(const typename ConstructArgType4Symbol<T>::type& symbol_data) const {\n    std::unique_lock<std::mutex> lock(mutex_);\n    const auto& iter = data2symbol_id_.find(symbol_data);\n    return iter != data2symbol_id_.end();\n  }\n\n  Maybe<const T&> MaybeGet(int64_t symbol_id) const { return *JUST(MaybeGetPtr(symbol_id)); }\n\n  Maybe<const T&> MaybeGet(const typename ConstructArgType4Symbol<T>::type& data) const {\n    return *JUST(MaybeGetPtr(data));\n  }\n\n  const T& Get(int64_t symbol_id) const { return *GetPtr(symbol_id); }\n\n  const T& Get(const typename ConstructArgType4Symbol<T>::type& data) const {\n    return *GetPtr(data);\n  }\n\n  Maybe<T> MaybeGetPtr(int64_t symbol_id) const {\n    std::unique_lock<std::mutex> lock(mutex_);\n    const auto& iter = 
Maybe<T> MaybeGetPtr(int64_t symbol_id) const {\n    std::unique_lock<std::mutex> lock(mutex_);\n    const auto& iter = symbol_id2symbol_.find(symbol_id);\n    CHECK_OR_RETURN(iter != symbol_id2symbol_.end()) << \"symbol_id: \" << symbol_id;\n    return iter->second;\n  }\n\n  Maybe<T> MaybeGetPtr(const typename ConstructArgType4Symbol<T>::type& data) const {\n    std::unique_lock<std::mutex> lock(mutex_);\n    const auto& iter = data2symbol_id_.find(data);\n    CHECK_OR_RETURN(iter != data2symbol_id_.end());\n    return JUST(MapAt(symbol_id2symbol_, iter->second));\n  }\n\n  const std::shared_ptr<T>& GetPtr(int64_t symbol_id) const {\n    std::unique_lock<std::mutex> lock(mutex_);\n    const auto& iter = symbol_id2symbol_.find(symbol_id);\n    CHECK(iter != symbol_id2symbol_.end()) << \"symbol_id: \" << symbol_id;\n    return iter->second;\n  }\n\n  const std::shared_ptr<T>& GetPtr(const typename ConstructArgType4Symbol<T>::type& data) const {\n    std::unique_lock<std::mutex> lock(mutex_);\n    const auto& iter = data2symbol_id_.find(data);\n    CHECK(iter != data2symbol_id_.end());\n    return CHECK_JUST(MapAt(symbol_id2symbol_, iter->second));\n  }\n\n  Maybe<void> Add(int64_t symbol_id, const typename ConstructArgType4Symbol<T>::type& data) {\n    CHECK_GT_OR_RETURN(symbol_id, 0);\n    const auto& ptr = JUST(detail::NewSymbol<T>(symbol_id, data));\n    std::unique_lock<std::mutex> lock(mutex_);\n    CHECK_OR_RETURN(symbol_id2symbol_.emplace(symbol_id, ptr).second);\n    data2symbol_id_[data] = symbol_id;\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<void> TryAdd(int64_t symbol_id, const typename ConstructArgType4Symbol<T>::type& data) {\n    CHECK_GT_OR_RETURN(symbol_id, 0);\n    const auto& ptr = JUST(detail::NewSymbol<T>(symbol_id, data));\n    std::unique_lock<std::mutex> lock(mutex_);\n    const auto& iter = symbol_id2symbol_.find(symbol_id);\n    if (iter != symbol_id2symbol_.end()) {\n      CHECK_OR_RETURN(data2symbol_id_.find(data) != data2symbol_id_.end());\n      return Maybe<void>::Ok();\n    }\n    CHECK_OR_RETURN(symbol_id2symbol_.emplace(symbol_id, ptr).second);\n    data2symbol_id_[data] = symbol_id;\n    return Maybe<void>::Ok();\n  }\n\n  Maybe<T> FindOrCreate(const typename ConstructArgType4Symbol<T>::type& symbol_data,\n                        const std::function<Maybe<int64_t>()>& Create) {\n    int64_t symbol_id = JUST(Create());\n    const auto& ptr = JUST(detail::NewSymbol<T>(symbol_id, symbol_data));\n    std::unique_lock<std::mutex> lock(mutex_);\n    const auto& iter = data2symbol_id_.find(symbol_data);\n    if (iter != data2symbol_id_.end()) { return JUST(MapAt(symbol_id2symbol_, iter->second)); }\n    CHECK_OR_RETURN(symbol_id2symbol_.emplace(symbol_id, ptr).second);\n    data2symbol_id_[symbol_data] = symbol_id;\n    return JUST(MapAt(symbol_id2symbol_, symbol_id));\n  }\n\n  void Clear(int64_t symbol_id) {\n    std::unique_lock<std::mutex> lock(mutex_);\n    auto iter = symbol_id2symbol_.find(symbol_id);\n    if (iter != symbol_id2symbol_.end()) {\n      data2symbol_id_.erase(iter->second->data());\n      symbol_id2symbol_.erase(symbol_id);\n    }\n  }\n  void ClearAll() {\n    std::unique_lock<std::mutex> lock(mutex_);\n    symbol_id2symbol_.clear();\n    data2symbol_id_.clear();\n  }\n\n private:\n  mutable std::mutex mutex_;\n  HashMap<int64_t, std::shared_ptr<T>> symbol_id2symbol_;\n  HashMap<typename ConstructArgType4Symbol<T>::type, int64_t> data2symbol_id_;\n};\n\n}  // namespace symbol\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_SYMBOL_STORAGE_H_\n"
  },
  {
    "path": "oneflow/core/vm/sync_access_instruction_policy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/sync_access_instruction_policy.h\"\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nSyncAccessInstructionPolicy::SyncAccessInstructionPolicy()\n    : host_mem_case_(memory::MakeHostMemCase()),\n      btb_(),\n      mem_ptr_(nullptr),\n      bytes_(0),\n      eager_blob_object_(nullptr) {\n  ResetBase(nullptr, 0, nullptr);\n}\n\nvoid SyncAccessInstructionPolicy::ResetBase(char* mem_ptr, size_t bytes,\n                                            EagerBlobObject* eager_blob_object) {\n  btb_.Reset();\n  mem_ptr_ = mem_ptr;\n  bytes_ = bytes;\n  eager_blob_object_ = eager_blob_object;\n}\n\nnamespace {\n\nvoid FastCopy(char* dst, const char* src, size_t bytes) {\n  switch (bytes) {\n    case 1: {\n      *dst = *src;\n      return;\n    }\n    case 2: {\n      *reinterpret_cast<int16_t*>(dst) = *reinterpret_cast<const int16_t*>(src);\n      return;\n    }\n    case 4: {\n      *reinterpret_cast<int32_t*>(dst) = *reinterpret_cast<const int32_t*>(src);\n      return;\n    }\n    case 8: {\n      *reinterpret_cast<int64_t*>(dst) = *reinterpret_cast<const int64_t*>(src);\n      return;\n    }\n    case 16: {\n      using Bit128 = std::pair<int64_t, int64_t>;\n      *reinterpret_cast<Bit128*>(dst) = *reinterpret_cast<const Bit128*>(src);\n      return;\n    }\n    default: UNIMPLEMENTED() << \"FastCopy on bytes \" << bytes << \" not supported.\";\n  }\n}\n\n}  // namespace\n\nvoid SyncReadInstructionPolicy::Compute(Instruction* instruction) {\n  StreamPolicy* stream_policy = instruction->mut_stream_policy();\n  char* pinned_buffer = instruction->mut_stream()->CheckSizeAndGetTmpSmallPinnedMemPtr(bytes_);\n  mut_btb()->mut_notifier()->Notify();\n  SyncAutoMemcpy(stream_policy->stream(), pinned_buffer, eager_blob_object_->mut_dptr(), bytes_,\n                 host_mem_case_, eager_blob_object_->mem_case());\n  FastCopy(mem_ptr_, pinned_buffer, bytes_);\n  mut_btb()->mut_spin_counter()->Decrease();\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/sync_access_instruction_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_SYNC_ACCESS_INSTRUCTION_POLICY_H_\n#define ONEFLOW_CORE_VM_SYNC_ACCESS_INSTRUCTION_POLICY_H_\n\n#include <functional>\n#include <memory>\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/vm/instruction_policy.h\"\n#include \"oneflow/core/vm/instruction_policy_util.h\"\n#include \"oneflow/core/eager/local_dep_object.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/framework/tensor_storage.h\"\n#include \"oneflow/core/common/blocking_then_busy.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/vm/stream_policy.h\"\n#include \"oneflow/core/memory/memory_case_util.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass SyncAccessInstructionPolicy : public InstructionPolicy {\n public:\n  SyncAccessInstructionPolicy();\n  virtual ~SyncAccessInstructionPolicy() = default;\n\n  Maybe<void> Prepare(Instruction* instruction) override { return Maybe<void>::Ok(); }\n\n  BlockingThenBusy* mut_btb() { return &btb_; }\n\n protected:\n  void ResetBase(char* mem_ptr, size_t bytes, EagerBlobObject* eager_blob_object);\n\n  const MemoryCase host_mem_case_;\n  BlockingThenBusy btb_;\n  char* mem_ptr_;\n  size_t bytes_;\n  EagerBlobObject* eager_blob_object_;\n};\n\nclass SyncReadInstructionPolicy final : public SyncAccessInstructionPolicy {\n public:\n  SyncReadInstructionPolicy() = default;\n  ~SyncReadInstructionPolicy() = default;\n\n  const DependenceVector& input_dependences() const override {\n    CHECK_EQ(input_dependences_.size(), 1);\n    return input_dependences_;\n  }\n\n  const DependenceVector& output_dependences() const override {\n    static thread_local DependenceVector empty{};\n    return empty;\n  }\n\n  std::string DebugName(const Instruction& instruction) const override { return \"SyncRead\"; }\n\n  void Reset(char* mem_ptr, size_t bytes, EagerBlobObject* eager_blob_object) {\n    ResetBase(mem_ptr, bytes, eager_blob_object);\n    if (likely(input_dependences_.size())) { input_dependences_.clear(); }\n    input_dependences_.push_back(CHECK_JUST(eager_blob_object->compute_local_dep_object()));\n  }\n\n  void Compute(Instruction* instruction) override;\n\n private:\n  DependenceVector input_dependences_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n#endif  // ONEFLOW_CORE_VM_SYNC_ACCESS_INSTRUCTION_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/sync_vm_mode_guard.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_SYNC_VM_MODE_GUARD_H_\n#define ONEFLOW_CORE_VM_SYNC_VM_MODE_GUARD_H_\n\n#include \"oneflow/core/common/thread_local_guard.h\"\n\nnamespace oneflow {\n\nenum class SyncVmMode {\n  kInvalid = 0,\n  kEnable = 1,\n  kDisable = 2,\n};\n\nclass SyncVmModeGuard final : public ThreadLocalGuard<SyncVmMode> {\n public:\n  using ThreadLocalGuard<SyncVmMode>::ThreadLocalGuard;\n  ~SyncVmModeGuard() = default;\n\n  static bool IsCurrentSyncVmMode() {\n    const auto& opt_sync_mode = Current();\n    return opt_sync_mode.has_value() && CHECK_JUST(opt_sync_mode) == SyncVmMode::kEnable;\n  }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_SYNC_VM_MODE_GUARD_H_\n"
  },
  {
    "path": "oneflow/core/vm/thread_ctx.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/thread_ctx.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nThreadCtx::ThreadCtx()\n    : intrusive_ref_(),\n      stream_list_(),\n      worker_pending_instruction_mutex_(),\n      worker_pending_instruction_list_(&worker_pending_instruction_mutex_),\n      notifier_(),\n      transport_dependence_(intrusive::make_shared<vm::Dependence>()),\n      thread_ctx_hook_() {}\n\nsize_t ThreadCtx::TryReceiveAndRun() {\n  intrusive::List<INTRUSIVE_FIELD(Instruction, worker_pending_instruction_hook_)> tmp_list;\n  mut_worker_pending_instruction_list()->MoveTo(&tmp_list);\n  size_t size = tmp_list.size();\n  INTRUSIVE_FOR_EACH(instruction, &tmp_list) {\n    tmp_list.Erase(instruction.Mutable());\n    const StreamPolicy& stream_policy = instruction->stream().stream_policy();\n    stream_policy.RunIf(instruction.Mutable());\n  }\n  return size;\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/thread_ctx.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_THREAD__H_\n#define ONEFLOW_CORE_VM_THREAD__H_\n\n#include <functional>\n#include \"oneflow/core/intrusive/intrusive.h\"\n#include \"oneflow/core/intrusive/mutexed_list.h\"\n#include \"oneflow/core/common/notifier.h\"\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/core/vm/vm_object.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nusing WorkerPendingInstructionMutexedList =\n    intrusive::MutexedList<INTRUSIVE_FIELD(Instruction, worker_pending_instruction_hook_)>;\n\nclass ThreadCtx final : public intrusive::Base {\n public:\n  // types\n  using StreamList = intrusive::List<INTRUSIVE_FIELD(Stream, thread_ctx_stream_hook_)>;\n\n  // Getters\n  const StreamList& stream_list() const { return stream_list_; }\n\n  // Setters\n  StreamList* mut_stream_list() { return &stream_list_; }\n  WorkerPendingInstructionMutexedList* mut_worker_pending_instruction_list() {\n    return &worker_pending_instruction_list_;\n  }\n\n  // methods\n  size_t TryReceiveAndRun();\n\n  Notifier* mut_notifier() { return &notifier_; }\n\n  const intrusive::shared_ptr<vm::Dependence>& transport_dependence() const {\n    return transport_dependence_;\n  };\n\n private:\n  friend class intrusive::Ref;\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n  ThreadCtx();\n\n  intrusive::Ref intrusive_ref_;\n  // lists\n  StreamList stream_list_;\n  std::mutex worker_pending_instruction_mutex_;\n  WorkerPendingInstructionMutexedList worker_pending_instruction_list_;\n  Notifier notifier_;\n  intrusive::shared_ptr<vm::Dependence> transport_dependence_;\n\n public:\n  // list hooks\n  intrusive::ListHook thread_ctx_hook_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_THREAD__H_\n"
  },
  {
    "path": "oneflow/core/vm/thread_safe_guard.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_THREAD_SAFE_ALLOCATOR_H_\n#define ONEFLOW_CORE_VM_THREAD_SAFE_ALLOCATOR_H_\n\n#include <cstdint>\n#include <memory>\n#include <mutex>\n#include <thread>\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace vm {\nclass ThreadSafeLock final {\n public:\n  ThreadSafeLock() = default;\n  ~ThreadSafeLock() = default;\n  OF_DISALLOW_COPY_AND_MOVE(ThreadSafeLock);\n\n  class RAIIGuard final {\n   public:\n    explicit RAIIGuard(ThreadSafeLock& lock) : guard_(lock.mutex4guard) {}\n    ~RAIIGuard() = default;\n    OF_DISALLOW_COPY_AND_MOVE(RAIIGuard);\n\n   private:\n    std::unique_lock<std::mutex> guard_;\n  };\n\n private:\n  std::mutex mutex4guard;\n};\n\nclass ReentrantThreadSafeLock final {\n public:\n  ReentrantThreadSafeLock() = default;\n  ~ReentrantThreadSafeLock() = default;\n  OF_DISALLOW_COPY_AND_MOVE(ReentrantThreadSafeLock);\n\n  class RAIIGuard final {\n   public:\n    explicit RAIIGuard(ReentrantThreadSafeLock& lock) : guard_(lock.mutex4guard) {}\n    ~RAIIGuard() = default;\n    OF_DISALLOW_COPY_AND_MOVE(RAIIGuard);\n\n   private:\n    std::unique_lock<std::recursive_mutex> guard_;\n  };\n\n private:\n  std::recursive_mutex mutex4guard;\n};\n}  // namespace vm\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_THREAD_SAFE_ALLOCATOR_H_\n"
  },
  {
    "path": "oneflow/core/vm/touch_tensors_instruction_policy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_TOUCH_TENSORS_INSTRUCTION_POLICY_H_\n#define ONEFLOW_CORE_VM_TOUCH_TENSORS_INSTRUCTION_POLICY_H_\n\n#include \"oneflow/core/vm/instruction_policy.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/vm/instruction_policy_util.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass TouchTensorsInstructionPolicy final : public InstructionPolicy {\n public:\n  explicit TouchTensorsInstructionPolicy(const vm::EagerBlobObjectList& eager_blob_objects)\n      : eager_blob_objects_(eager_blob_objects) {\n    const auto& Insert = InstructionPolicyUtil::SetInserter(&input_dependences_);\n    for (const auto& eager_blob_object : eager_blob_objects_) {\n      Insert(CHECK_JUST(eager_blob_object->compute_local_dep_object()));\n    }\n  }\n  ~TouchTensorsInstructionPolicy() = default;\n\n  const DependenceVector& input_dependences() const override { return input_dependences_; }\n  const DependenceVector& output_dependences() const override {\n    static DependenceVector empty{};\n    return empty;\n  }\n\n  std::string DebugName(const vm::Instruction& instruction) const override {\n    return \"TouchTensors\";\n  }\n  Maybe<void> Prepare(vm::Instruction* instruction) override { return Maybe<void>::Ok(); }\n  void Compute(vm::Instruction* instruction) override {}\n\n private:\n  vm::EagerBlobObjectList eager_blob_objects_;\n  DependenceVector input_dependences_;\n};\n\n}  // namespace vm\n}  // namespace oneflow\n#endif  // ONEFLOW_CORE_VM_TOUCH_TENSORS_INSTRUCTION_POLICY_H_\n"
  },
  {
    "path": "oneflow/core/vm/virtual_machine.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <typeinfo>\n#include <thread>\n#include <chrono>\n#include \"oneflow/core/vm/sync_vm_mode_guard.h\"\n#include \"oneflow/core/vm/barrier_instruction_policy.h\"\n#include \"oneflow/core/vm/caching_allocator.h\"\n#include \"oneflow/core/vm/global_sync_instruction_policy.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/vm/allocator.h\"\n#include \"oneflow/core/common/blocking_counter.h\"\n#include \"oneflow/core/common/cpp_attribute.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/common/foreign_lock_helper.h\"\n#include \"oneflow/core/thread/thread_global_id.h\"\n#include \"oneflow/core/framework/transport_token.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/framework/stream_on_independent_thread.h\"\n#include \"oneflow/core/framework/stream_is_comm_net_stream.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"oneflow/core/platform/include/pthread_fork.h\"\n#include \"oneflow/core/common/env_var/env_var.h\"\n#include \"oneflow/core/common/env_var/vm.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/framework/stream_get_stream_type_name.h\"\n#include \"oneflow/core/framework/stream_mgr.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nint MicrosecondsFrom(const T& start) {\n  return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now()\n                                                               - start)\n      .count();\n}\n\nMaybe<void> ForEachThreadCtx(vm::VirtualMachineEngine* engine,\n                             const std::function<Maybe<void>(vm::ThreadCtx*)>& DoEach) {\n  INTRUSIVE_UNSAFE_FOR_EACH_PTR(thread_ctx, engine->mut_thread_ctx_list()) {\n    JUST(DoEach(thread_ctx));\n  }\n  return Maybe<void>::Ok();\n}\n\nvoid GetSchedulerThreadInitializer(std::function<void()>* Initializer) {\n  *Initializer = [&]() { OF_PROFILER_NAME_THIS_HOST_THREAD(\"_VM::Scheduler\"); };\n}\n\nvoid WorkerLoop(vm::ThreadCtx* thread_ctx, const std::function<void(vm::ThreadCtx*)>& Initializer) {\n  SyncVmModeGuard guard(SyncVmMode::kEnable);\n  Initializer(thread_ctx);\n  constexpr static size_t kExpireMicroseconds = 200;\n  while (thread_ctx->mut_notifier()->WaitAndClearNotifiedCnt() == kNotifierStatusSuccess) {\n    std::chrono::time_point<std::chrono::steady_clock> start{};\n    do {\n      while (thread_ctx->TryReceiveAndRun()) { start = std::chrono::steady_clock::now(); }\n      std::this_thread::yield();\n    } while (MicrosecondsFrom(start) < kExpireMicroseconds);\n  }\n}\n\n}  // namespace\n\nVirtualMachine::VirtualMachine()\n    : 
multi_thread_(ThreadLocalEnvBool<ONEFLOW_VM_MULTI_THREAD>()),\n      threads_closed_(false),\n      scheduler_stopped_(false) {\n  // Class VirtualMachineEngine only cares about the basic logic of the vm, while class\n  // VirtualMachine manages threads and condition variables.\n  // In order to notify threads in VirtualMachineEngine, a notification callback lambda should be\n  // taken as an argument for VirtualMachineEngine's constructor.\n  engine_ = intrusive::make_shared<vm::VirtualMachineEngine>();\n  OF_PROFILER_NAME_THIS_HOST_THREAD(\"_Main\");\n\n  if (multi_thread_) {\n    std::function<void()> SchedulerInitializer;\n    GetSchedulerThreadInitializer(&SchedulerInitializer);\n    schedule_thread_ = std::thread(&VirtualMachine::ScheduleLoop, this, SchedulerInitializer);\n  }\n  transport_dependence_.Reset();\n}\n\nnamespace {\n\nMaybe<Symbol<Stream>> GetBarrierStream() {\n  auto device = JUST(Device::New(\"cpu\"));\n  return Stream::New(device, StreamType::kBarrier);\n}\n\nvoid MakeBarrierInstructions(vm::InstructionList* list,\n                             const std::function<void()>& BarrierCallback) {\n  auto* vm = Singleton<VirtualMachine>::Get();\n  {\n    auto stream = CHECK_JUST(GetBarrierStream());\n    auto instruction = intrusive::make_shared<vm::Instruction>(\n        CHECK_JUST(vm->GetVmStream(stream)), std::make_shared<vm::GlobalSyncInstructionPolicy>());\n    list->EmplaceBack(std::move(instruction));\n  }\n  {\n    auto stream = CHECK_JUST(GetBarrierStream());\n    auto instruction = intrusive::make_shared<vm::Instruction>(\n        CHECK_JUST(vm->GetVmStream(stream)),\n        std::make_shared<vm::BarrierInstructionPolicy>(BarrierCallback));\n    list->EmplaceBack(std::move(instruction));\n  }\n}\n\n}  // namespace\n\nvoid VirtualMachine::ControlSync() {\n  auto bc = std::make_shared<BlockingCounter>(1);\n  vm::InstructionList list;\n  MakeBarrierInstructions(&list, [bc] { bc->Decrease(); });\n  CHECK_JUST(Receive(&list));\n  CHECK_JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));\n}\n\nMaybe<void> VirtualMachine::CloseVMThreads() {\n  CHECK_OR_RETURN(!threads_closed_) << \"vm threads closed\";\n  ControlSync();\n  pending_notifier_.Close();\n  if (multi_thread_) {\n    schedule_thread_.join();\n  } else {\n    // For technical reasons, worker threads are always created even in single thread mode.\n    JUST(CloseWorkerThreads());\n  }\n  threads_closed_ = true;\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nclass SingleThreadScheduleCtx : public vm::ScheduleCtx {\n public:\n  SingleThreadScheduleCtx() = default;\n  ~SingleThreadScheduleCtx() = default;\n\n  void OnWorkerLoadPending(vm::ThreadCtx* thread_ctx) const override {\n    while (thread_ctx->TryReceiveAndRun() > 0) {}\n  }\n};\n\nvoid ScheduleUntilVMEmpty(vm::VirtualMachineEngine* vm, const vm::ScheduleCtx& schedule_ctx) {\n  do { vm->Schedule(schedule_ctx); } while (!(vm->SchedulerEmpty()));\n}\n\n}  // namespace\n\nMaybe<void> VirtualMachine::BlockingRunProbeFunc(\n    const std::function<bool(vm::VirtualMachineEngine*)>& probe_func) {\n  JUST(Singleton<ForeignLockHelper>::Get()->WithScopedRelease([&, this]() -> Maybe<void> {\n    auto bc = std::make_shared<BlockingCounter>(1);\n    engine_->InsertProbe([bc, probe_func](vm::VirtualMachineEngine* engine) {\n      if (!probe_func(engine)) { return false; }\n      bc->Decrease();\n      return true;\n    });\n    if (threads_closed_ || !multi_thread_) {\n      ScheduleUntilVMEmpty(engine_.Mutable(), SingleThreadScheduleCtx());\n    } else {\n      pending_notifier_.Notify();\n    }\n    JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));\n    return Maybe<void>::Ok();\n  }));\n  return Maybe<void>::Ok();\n}\n\n
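// Note: BlockingRunProbeFunc above is the probe-and-wait building block used below: it enqueues\n// probe_func into the engine, wakes the scheduler (or drives it inline when threads are closed or\n// in single-thread mode), and blocks the caller until the probe returns true.\n// VirtualMachine::Receive inlines the same pattern for its high-water-mark wait.\n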
Maybe<void> VirtualMachine::ShrinkAllMem() {\n  auto try_shrink_mem = [](vm::VirtualMachineEngine* engine) -> bool {\n    if (engine->mut_active_stream_list()->size()) { return false; }\n    INTRUSIVE_FOR_EACH_PTR(thread_ctx, engine->mut_thread_ctx_list()) {\n      INTRUSIVE_FOR_EACH_PTR(stream, thread_ctx->mut_stream_list()) {\n        vm::Allocator* allocator = stream->mut_stream_policy()->mut_allocator();\n        if (allocator) {\n          auto* cache = dynamic_cast<vm::CachingAllocator*>(allocator);\n          if (cache != nullptr) { cache->Shrink(); }\n        }\n      }\n    }\n    return true;\n  };\n  return BlockingRunProbeFunc(try_shrink_mem);\n}\n\nVirtualMachine::~VirtualMachine() {\n  if (!threads_closed_) { CHECK_JUST(CloseVMThreads()); }\n  RunMainThreadPendingTasks();\n  CHECK(engine_->SchedulerEmpty());\n  engine_.Reset();\n}\n\nstd::function<Maybe<bool>()> VirtualMachine::GetPredicatorNoMoreInstructionsFinished() {\n  auto last_total_erased = std::make_shared<size_t>(0);\n  auto* vm = Singleton<VirtualMachine>::Get();\n  if (vm != nullptr) { *last_total_erased = vm->engine_->total_erased_instruction_cnt(); }\n  return [last_total_erased]() -> Maybe<bool> {\n    auto* vm = Singleton<VirtualMachine>::Get();\n    CHECK_NOTNULL_OR_RETURN(vm) << \"virtual machine not initialized.\";\n    CHECK_OR_RETURN(!vm->NoMoreErasedInstructions(last_total_erased.get()))\n        << \"blocking instructions\\n\"\n        << vm->GetBlockingDebugString();\n    return false;\n  };\n}\n\nbool VirtualMachine::NoMoreErasedInstructions(size_t* last_total_erased_instruction_cnt) const {\n  size_t cnt = engine_->total_erased_instruction_cnt();\n  bool no_more_erased = (*last_total_erased_instruction_cnt == cnt);\n  *last_total_erased_instruction_cnt = cnt;\n  return no_more_erased;\n}\n\nstd::string VirtualMachine::GetBlockingDebugString() {\n  size_t limit = EnvInteger<ONEFLOW_VM_BLOCKING_DEBUG_INSTRUCTIONS_DISPLAY_LIMIT>();\n  return engine_->GetLivelyInstructionListDebugString(limit);\n}\n\nvoid VirtualMachine::RunMainThreadPendingTasks() {\n  std::unique_lock lock(main_thread_pending_tasks_mutex_);\n  for (const auto& main_thread_pending_task : main_thread_pending_tasks_) {\n    main_thread_pending_task();\n  }\n  main_thread_pending_tasks_.clear();\n}\n\nMaybe<void> VirtualMachine::Receive(vm::InstructionList* instruction_list) {\n  SyncVmModeGuard guard(SyncVmMode::kEnable);\n  RunMainThreadPendingTasks();\n  if (unlikely(pthread_fork::IsForkedSubProcess())) {\n    INTRUSIVE_FOR_EACH_PTR(instruction, instruction_list) {\n      const auto& device = instruction->stream().device();\n      CHECK_OR_RETURN(device->enum_type() == DeviceType::kCPU)\n          << pthread_fork::kOfCudaNotSupportInForkedSubProcess;\n      JUST(instruction->Prepare());\n      instruction->Compute();\n    }\n    instruction_list->Clear();\n  } else if (unlikely(threads_closed_ || !multi_thread_)) {\n    JUST(RunInCurrentThread(instruction_list));\n  } else {\n
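    // Backpressure: if too many instructions are in flight, release the foreign lock and block\n    // until the scheduler drains the engine below the low-water mark probed below.\n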
    const int64_t kHighWaterMark = GetInstructionHighWaterMark();\n    if (engine_->flying_instruction_cnt() > kHighWaterMark) {\n      JUST(Singleton<ForeignLockHelper>::Get()->WithScopedRelease([&, this]() -> Maybe<void> {\n        auto bc = std::make_shared<BlockingCounter>(1);\n        engine_->InsertProbe([bc](vm::VirtualMachineEngine* engine) {\n          const int64_t kLowWaterMark = GetInstructionLowWaterMark();\n          if (engine->flying_instruction_cnt() > kLowWaterMark) { return false; }\n          bc->Decrease();\n          return true;\n        });\n        pending_notifier_.Notify();\n        JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));\n        return Maybe<void>::Ok();\n      }));\n    }\n    if (JUST(engine_->Receive(instruction_list))) {\n      // old scheduler_pending_instruction_list is empty.\n      pending_notifier_.Notify();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> VirtualMachine::NotifyOrRunScheduler() {\n  if (unlikely(pthread_fork::IsForkedSubProcess() || threads_closed_ || !multi_thread_)) {\n    ScheduleUntilVMEmpty(engine_.Mutable(), SingleThreadScheduleCtx());\n  } else {\n    pending_notifier_.Notify();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> VirtualMachine::CloseWorkerThreads() {\n  JUST(ForEachThreadCtx(engine_.Mutable(), [&](vm::ThreadCtx* thread_ctx) -> Maybe<void> {\n    thread_ctx->mut_notifier()->Close();\n    return Maybe<void>::Ok();\n  }));\n  {\n    std::unique_lock<std::mutex> lock(worker_threads_mutex_);\n    for (const auto& worker_thread : worker_threads_) { worker_thread->join(); }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> VirtualMachine::RunInCurrentThread(vm::InstructionList* instr_list) {\n  CHECK_OR_RETURN(engine_->SchedulerEmpty())\n      << \"vm scheduler not empty, a fatal error may have occurred\";\n  JUST(engine_->Receive(instr_list));\n  ScheduleUntilVMEmpty(engine_.Mutable(), SingleThreadScheduleCtx());\n  return Maybe<void>::Ok();\n}\n\n
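// Note: MultiThreadScheduleCtx is the counterpart of SingleThreadScheduleCtx above; instead of\n// running pending worker instructions inline, it only notifies the worker thread owning the\n// ThreadCtx and lets WorkerLoop pick them up.\n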
namespace {\n\nclass MultiThreadScheduleCtx : public vm::ScheduleCtx {\n public:\n  MultiThreadScheduleCtx() = default;\n  ~MultiThreadScheduleCtx() = default;\n\n  void OnWorkerLoadPending(vm::ThreadCtx* thread_ctx) const override {\n    thread_ctx->mut_notifier()->Notify();\n  }\n};\n\n}  // namespace\n\nvoid VirtualMachine::ScheduleLoop(const std::function<void()>& Initializer) {\n  SyncVmModeGuard guard(SyncVmMode::kEnable);\n  Initializer();\n  MultiThreadScheduleCtx schedule_ctx{};\n  while (pending_notifier_.WaitAndClearNotifiedCnt() == kNotifierStatusSuccess) {\n    OF_PROFILER_RANGE_GUARD(\"VirtualMachine::ScheduleLoop\");\n    auto start = std::chrono::steady_clock::now();\n    static constexpr int kWorkingMicroseconds = 1000;\n    // Every time this thread wakes up, engine_ is scheduled for about `kWorkingMicroseconds`.\n    // The cost of an OS thread switch is about 5-10 microseconds, so doing more scheduling in a\n    // single wakeup achieves higher performance.\n    do {\n      // Use SchedulerThreadUnsafeEmpty to avoid acquiring the mutex lock.\n      // It's safe to use SchedulerThreadUnsafeEmpty here: pending_notifier_.notified_cnt_ will be\n      // greater than zero whenever an inconsistency between\n      // engine_->pending_instruction_list.list_head_.list_head_.container_ and\n      // engine_->pending_instruction_list.list_head_.list_head_.size_ occurs, hence the pending\n      // instructions will get handled in the next iteration.\n      // VirtualMachine::Receive may be less efficient if the thread-safe version\n      // `engine_->SchedulerEmpty()` were used here, because VirtualMachine::ScheduleLoop would be\n      // more likely to hold the mutex lock.\n      do {\n        const size_t total_inserted = engine_->total_inserted_instruction_cnt();\n        const size_t total_erased = engine_->total_erased_instruction_cnt();\n        engine_->Schedule(schedule_ctx);\n        if (ThreadLocalEnvBool<ONEFLOW_VM_ENABLE_SCHEDULE_YIELD>()\n            && total_inserted == engine_->total_inserted_instruction_cnt()\n            && total_erased == engine_->total_erased_instruction_cnt()) {  // nothing handled.\n          std::this_thread::yield();\n        }\n      } while (!engine_->SchedulerThreadUnsafeEmpty());\n    } while (MicrosecondsFrom(start) < kWorkingMicroseconds);\n  }\n  ScheduleUntilVMEmpty(engine_.Mutable(), schedule_ctx);\n  CHECK_JUST(CloseWorkerThreads());\n  scheduler_stopped_ = true;\n}\n\nintrusive::shared_ptr<vm::Dependence> VirtualMachine::FindOrCreateScheduleDependence(\n    Symbol<Stream> stream) {\n  std::unique_lock<std::recursive_mutex> lock(stream_and_thread_ctx_mutex_);\n  intrusive::shared_ptr<vm::Dependence>* ptr = &stream2dependence_[stream];\n  if (!*ptr) { *ptr = intrusive::make_shared<vm::Dependence>(); }\n  return *ptr;\n}\n\nintrusive::shared_ptr<vm::Dependence> VirtualMachine::FindOrCreateTransportLocalDepObject() {\n  std::unique_lock<std::recursive_mutex> lock(stream_and_thread_ctx_mutex_);\n  if (!transport_dependence_) { transport_dependence_ = intrusive::make_shared<vm::Dependence>(); }\n  return transport_dependence_;\n}\n\nMaybe<vm::Stream*> VirtualMachine::CreateStream(Symbol<Stream> stream) {\n  std::unique_lock<std::recursive_mutex> lock(stream_and_thread_ctx_mutex_);\n  vm::ThreadCtx* thread_ctx =\n      JUST(FindOrCreateThreadCtx(stream->device(), stream->stream_type(), stream->thread_uid()));\n  return JUST(CreateStream(thread_ctx, stream));\n}\n\nMaybe<vm::Stream*> VirtualMachine::GetVmStream(Symbol<Stream> stream) {\n  if (stream->unique_stream_id() >= unique_stream_id2vm_stream_.size()) {\n    std::unique_lock<std::recursive_mutex> lock(stream_and_thread_ctx_mutex_);\n    if (stream->unique_stream_id() >= unique_stream_id2vm_stream_.size()) {\n      auto* stream_mgr = JUST(SingletonMaybe<StreamMgr>());\n      for (int i = unique_stream_id2vm_stream_.size(); i <= stream->unique_stream_id(); ++i) {\n        Symbol<Stream> cur_stream = JUST(stream_mgr->GetStreamSymbol(i));\n        CHECK_EQ_OR_RETURN(cur_stream->unique_stream_id(), i)\n            << \"invalid Stream::unique_stream_id()\";\n        unique_stream_id2vm_stream_.SetOrAdd(cur_stream->unique_stream_id(),\n                                             JUST(CreateStream(cur_stream)));\n      }\n    }\n  }\n  return JUST(VectorAt(unique_stream_id2vm_stream_, stream->unique_stream_id()));\n}\n\n
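// Note: streams whose StreamType runs on an independent thread get a dedicated ThreadCtx per\n// (DeviceType, StreamType) pair, while all other streams share ThreadCtxs keyed by thread_uid.\n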
Maybe<vm::ThreadCtx*> VirtualMachine::FindOrCreateThreadCtx(Symbol<Device> device,\n                                                            StreamType stream_type,\n                                                            size_t thread_uid) {\n  std::unique_lock<std::recursive_mutex> lock(stream_and_thread_ctx_mutex_);\n  vm::ThreadCtx** thread_ctx_ptr = nullptr;\n  if (StreamOnIndependentThread::Visit(stream_type)) {\n    auto key = std::make_pair(device->enum_type(), stream_type);\n    thread_ctx_ptr = &device_type_stream_type_2independent_thread_ctx_[key];\n  } else {\n    thread_ctx_ptr = &thread_uid2shared_thread_ctx_[thread_uid];\n  }\n  if (*thread_ctx_ptr == nullptr) {\n    *thread_ctx_ptr = JUST(CreateThreadCtx(device, stream_type, thread_uid));\n  }\n  return *thread_ctx_ptr;\n}\n\nMaybe<vm::ThreadCtx*> VirtualMachine::CreateThreadCtx(Symbol<Device> device, StreamType stream_type,\n                                                      size_t thread_uid) {\n  std::unique_lock<std::recursive_mutex> lock(stream_and_thread_ctx_mutex_);\n  // thread_ctx_ptr is a shared_ptr because the probe below may run after a wait timeout.\n  auto thread_ctx_ptr = std::make_shared<vm::ThreadCtx*>(nullptr);\n  {\n    auto bc = std::make_shared<BlockingCounter>(1);\n    engine_->InsertProbe([thread_ctx_ptr, bc](vm::VirtualMachineEngine* engine) {\n      auto thread_ctx = intrusive::make_shared<vm::ThreadCtx>();\n      engine->mut_thread_ctx_list()->PushBack(thread_ctx.Mutable());\n      *thread_ctx_ptr = thread_ctx.Mutable();\n      bc->Decrease();\n      return true;\n    });\n    JUST(NotifyOrRunScheduler());\n    JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));\n  }\n  auto* thread_ctx = *thread_ctx_ptr;\n  {\n    const std::string thread_tag = [&] {\n      std::string device_tag = *CHECK_JUST(DeviceTag4DeviceType(device->enum_type()));\n      if (StreamOnIndependentThread::Visit(stream_type)) {\n        return device_tag + GetStreamTypeName::Visit(stream_type);\n      } else {\n        return std::to_string(thread_uid);\n      }\n    }();\n    const auto& WorkerInitializer = [thread_tag](vm::ThreadCtx* thread_ctx) {\n      OF_PROFILER_NAME_THIS_HOST_THREAD(\"_VM::Worker_\" + thread_tag);\n    };\n    auto thread = std::make_unique<std::thread>(&WorkerLoop, thread_ctx, WorkerInitializer);\n    {\n      std::unique_lock<std::mutex> lock(worker_threads_mutex_);\n      worker_threads_.push_back(std::move(thread));\n    }\n  }\n  return thread_ctx;\n}\n\nMaybe<vm::Stream*> VirtualMachine::CreateStream(vm::ThreadCtx* thread_ctx, Symbol<Stream> stream) {\n  std::unique_lock<std::recursive_mutex> lock(stream_and_thread_ctx_mutex_);\n  intrusive::shared_ptr<vm::Dependence> schedule_dependence =\n      FindOrCreateScheduleDependence(stream);\n  std::vector<intrusive::shared_ptr<vm::Dependence>> transport_dependences{};\n  if (IsCommNetStream::Visit(stream->stream_type())) {\n    transport_dependences.push_back(FindOrCreateTransportLocalDepObject());\n  }\n  auto vm_stream =\n      intrusive::make_shared<vm::Stream>(thread_ctx, stream->device(), stream->stream_type(),\n                                         schedule_dependence, transport_dependences);\n\n  auto bc = std::make_shared<BlockingCounter>(1);\n  engine_->InsertProbe([&vm_stream, thread_ctx, bc](vm::VirtualMachineEngine* engine) {\n    thread_ctx->mut_stream_list()->PushBack(vm_stream.Mutable());\n    bc->Decrease();\n    return true;\n  });\n  JUST(NotifyOrRunScheduler());\n  JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));\n  return vm_stream.Mutable();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/virtual_machine.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_VIRTUAL_MACHINE_H_\n#define ONEFLOW_CORE_VM_VIRTUAL_MACHINE_H_\n\n#include <mutex>\n#include \"oneflow/core/common/notifier.h\"\n#include \"oneflow/core/vm/virtual_machine_engine.h\"\n#include \"oneflow/core/thread/thread_pool.h\"\n#include \"oneflow/core/common/stream_type.h\"\n#include \"oneflow/core/common/steady_vector.h\"\n\nnamespace oneflow {\n\nclass InstructionsBuilder;\nclass Device;\n\nclass VirtualMachine final {\n public:\n  VirtualMachine(const VirtualMachine&) = delete;\n  VirtualMachine(VirtualMachine&&) = delete;\n  VirtualMachine();\n  ~VirtualMachine();\n\n  static std::function<Maybe<bool>()> GetPredicatorNoMoreInstructionsFinished();\n\n  intrusive::shared_ptr<vm::Dependence> FindOrCreateTransportLocalDepObject();\n\n  std::string GetBlockingDebugString();\n\n  Maybe<void> Receive(vm::InstructionList* instr_list);\n\n  Maybe<void> CloseVMThreads();\n\n  // Never called in vm work threads.\n  // VM sync must be called to ensure all working instructions are finished.\n  Maybe<void> ShrinkAllMem();\n  Maybe<vm::Stream*> GetVmStream(Symbol<Stream> stream);\n\n  size_t flying_instruction_cnt() const { return engine().flying_instruction_cnt(); }\n\n  void add_main_thread_pending_task(std::function<void()> task) {\n    std::unique_lock lock(main_thread_pending_tasks_mutex_);\n    main_thread_pending_tasks_.push_back(std::move(task));\n  }\n\n private:\n  friend class InstructionsBuilder;\n\n  void ScheduleLoop(const std::function<void()>& Initializer);\n\n  intrusive::shared_ptr<vm::Dependence> FindOrCreateScheduleDependence(Symbol<Stream> stream);\n  bool NoMoreErasedInstructions(size_t* last_total_erased_instruction_cnt) const;\n\n  const vm::VirtualMachineEngine& engine() const { return *engine_; }\n  vm::VirtualMachineEngine* mut_engine() { return engine_.Mutable(); }\n\n  void ControlSync();\n  Maybe<vm::ThreadCtx*> FindOrCreateThreadCtx(Symbol<Device> device, StreamType stream_type,\n                                              size_t thread_uid);\n  Maybe<vm::ThreadCtx*> CreateThreadCtx(Symbol<Device> device, StreamType stream_type,\n                                        size_t thread_uid);\n  Maybe<vm::Stream*> CreateStream(Symbol<Stream> stream);\n\n  Maybe<vm::Stream*> CreateStream(vm::ThreadCtx* thread_ctx, Symbol<Stream> stream);\n\n  Maybe<void> RunInCurrentThread(vm::InstructionList* instr_list);\n\n  Maybe<void> BlockingRunProbeFunc(const std::function<bool(vm::VirtualMachineEngine*)>& prob_func);\n\n  Maybe<void> NotifyOrRunScheduler();\n\n  Maybe<void> CloseWorkerThreads();\n\n  void RunMainThreadPendingTasks();\n\n  bool multi_thread_;\n  bool threads_closed_;\n  bool scheduler_stopped_;\n  intrusive::shared_ptr<vm::VirtualMachineEngine> engine_;\n\n  // for asynchronized execution\n  std::mutex worker_threads_mutex_;\n  std::list<std::unique_ptr<std::thread>> worker_threads_;\n\n  // for vm::Stream and vm::ThreadCtx\n  
std::recursive_mutex stream_and_thread_ctx_mutex_;\n  HashMap<size_t, vm::ThreadCtx*> thread_uid2shared_thread_ctx_;\n  HashMap<std::pair<DeviceType, StreamType>, vm::ThreadCtx*>\n      device_type_stream_type_2independent_thread_ctx_;\n  HashMap<Symbol<Stream>, intrusive::shared_ptr<vm::Dependence>> stream2dependence_;\n  intrusive::shared_ptr<vm::Dependence> transport_dependence_;\n  SteadyVector<vm::Stream*> unique_stream_id2vm_stream_;\n\n  std::thread schedule_thread_;\n  Notifier pending_notifier_;\n\n  std::mutex main_thread_pending_tasks_mutex_;\n  std::vector<std::function<void()>> main_thread_pending_tasks_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_VIRTUAL_MACHINE_H_\n"
  },
  {
    "path": "oneflow/core/vm/virtual_machine_engine.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/virtual_machine_engine.h\"\n#include \"oneflow/core/common/env_var/vm.h\"\n#include \"oneflow/core/vm/caching_allocator.h\"\n#include \"oneflow/core/vm/fuse_instruction_policy.h\"\n#include \"oneflow/core/vm/release_tensor_instruction_policy.h\"\n#include \"oneflow/core/vm/allocator.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/cpp_attribute.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/platform/include/pthread_fork.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"oneflow/core/common/cpp_attribute.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/common/foreign_lock_helper.h\"\n#include \"oneflow/extension/stack/foreign_stack_getter.h\"\n\nnamespace oneflow {\n\nnamespace vm {\n\nvoid VirtualMachineEngine::ReleaseInstruction(Instruction* instruction) {\n  OF_PROFILER_RANGE_GUARD(\"R:\" + instruction->DebugName());\n  auto* access_list = instruction->mut_access_list();\n  INTRUSIVE_FOR_EACH(access, access_list) {\n    CHECK_GT(access->ref_cnt(), 1);\n    access_list->Erase(access.Mutable());\n    auto* dependence = access->mut_dependence();\n    if (unlikely(!access->rw_mutexed_object_access_hook().empty())) {\n      dependence->mut_access_list()->Erase(access.Mutable());\n    }\n  }\n  auto* out_edges = instruction->mut_out_edges();\n  INTRUSIVE_FOR_EACH_PTR(out_edge, out_edges) {\n    Instruction* out_instruction = out_edge->mut_dst_instruction();\n    // Edges are erased only if the instruction is completed.\n    out_edges->Erase(out_edge);\n    out_instruction->mut_in_edges()->Erase(out_edge);\n    if (Dispatchable(out_instruction)) {\n      OF_PROFILER_RANGE_GUARD(\"E:\" + out_instruction->DebugName());\n      mut_ready_instruction_list()->PushBack(out_instruction);\n    }\n  }\n}\n\n// Handle pending instructions, and try schedule them to ready list.\nvoid VirtualMachineEngine::HandleLocalPending() {\n  OF_PROFILER_RANGE_GUARD(\"HandleLocalPending\");\n  InstructionList pending_instructions;\n  FetchAndTryFusePendingInstructions(&pending_instructions);\n  INTRUSIVE_FOR_EACH_PTR(instruction, &pending_instructions) {\n    const auto& instruction_policy = instruction->instruction_policy();\n    instruction->InitStatus();\n    LivelyInstructionListPushBack(instruction);\n    if (unlikely(instruction_policy.IsBarrier())) {\n      mut_barrier_instruction_list()->PushBack(instruction);\n    } else {\n      ConsumeDependences(instruction);\n      if (likely(Dispatchable(instruction))) {\n        mut_ready_instruction_list()->PushBack(instruction);\n      }\n    }\n  }\n}\n\nnamespace {\n\nbool FusableBetween(InstructionFuseType fuse_type, Instruction* instruction,\n                    Instruction* prev_instruction) {\n  if (unlikely(instruction->instruction_policy().fuse_type() != fuse_type)) { return 
false; }\n  auto* stream = instruction->mut_stream();\n  if (unlikely(stream == nullptr)) { return false; }\n  auto* sequential_dep = instruction->instruction_policy().stream_sequential_dependence();\n  if (unlikely(sequential_dep == nullptr)) { return false; }\n\n  if (unlikely(prev_instruction == nullptr)) { return true; }\n  if (unlikely(stream != prev_instruction->mut_stream())) { return false; }\n  if (unlikely(sequential_dep\n               != prev_instruction->instruction_policy().stream_sequential_dependence())) {\n    return false;\n  }\n  return true;\n}\n\n}  // namespace\n\nvoid VirtualMachineEngine::MakeAndAppendFusedInstruction(\n    InstructionList&& fused_instruction_list, InstructionList* /*out*/ pending_instructions) {\n  if (unlikely(fused_instruction_list.size() == 0)) { return; }\n  if (unlikely(fused_instruction_list.size() == 1)) {\n    fused_instruction_list.MoveTo(pending_instructions);\n    return;\n  }\n  auto* begin = fused_instruction_list.Begin();\n  auto instruction = intrusive::make_shared<Instruction>(\n      begin->mut_stream(),\n      std::make_shared<FuseInstructionPolicy>(std::move(fused_instruction_list)));\n  pending_instructions->EmplaceBack(std::move(instruction));\n}\n\nvoid VirtualMachineEngine::FetchAndTryFusePendingInstructions(\n    InstructionList* /*out*/ pending_instructions) {\n  size_t window_size = ThreadLocalEnvInteger<ONEFLOW_VM_PENDING_HANDLE_WINDOW_SIZE>();\n  InstructionList fused_instruction_list;\n  INTRUSIVE_FOR_EACH_PTR(instruction, mut_local_pending_instruction_list()) {\n    if (window_size-- <= 0) { break; }\n    auto* fuse_begin = fused_instruction_list.Begin();\n    if (likely(FusableBetween(kEnableInstructionFuseAtAnyPosition, instruction, fuse_begin))) {\n      // fuse\n      mut_local_pending_instruction_list()->MoveToDstBack(instruction, &fused_instruction_list);\n    } else if (likely(FusableBetween(kEnableInstructionFuseAsTailOnly, instruction, fuse_begin))) {\n      // fuse\n      mut_local_pending_instruction_list()->MoveToDstBack(instruction, &fused_instruction_list);\n      MakeAndAppendFusedInstruction(std::move(fused_instruction_list), pending_instructions);\n    } else {\n      // no fuse\n      MakeAndAppendFusedInstruction(std::move(fused_instruction_list), pending_instructions);\n      mut_local_pending_instruction_list()->MoveToDstBack(instruction, pending_instructions);\n    }\n  }\n  MakeAndAppendFusedInstruction(std::move(fused_instruction_list), pending_instructions);\n}\n\nstd::string VirtualMachineEngine::GetLivelyInstructionListDebugString(int64_t debug_cnt) {\n  std::stringstream ss;\n  INTRUSIVE_UNSAFE_FOR_EACH_PTR(instruction, mut_lively_instruction_list()) {\n    if (--debug_cnt <= 0) { break; }\n    ss << instruction->DebugName() << \" ptr: \" << instruction\n       << \" dispatched:\" << (instruction->dispatched_instruction_hook().empty() ? \"0\" : \"1\")\n       << \" launched:\" << (instruction->Launched() ? \"1\" : \"0\")\n       << \" done:\" << (instruction->Done() ? 
\"1\" : \"0\");\n    INTRUSIVE_UNSAFE_FOR_EACH_PTR(edge, instruction->mut_in_edges()) {\n      ss << \" dep-ptr:\" << &edge->src_instruction();\n    }\n    ss << \"\\n\";\n  }\n  return ss.str();\n}\n\nvoid VirtualMachineEngine::LivelyInstructionListPushBack(Instruction* instruction) {\n  ++total_inserted_instruction_cnt_;\n  mut_lively_instruction_list()->PushBack(instruction);\n}\n\nvoid VirtualMachineEngine::InsertProbe(\n    const std::function<bool(VirtualMachineEngine*)>& ProbeFunction) {\n  probe_list_.EmplaceBack(intrusive::make_shared<VmProbe>(ProbeFunction));\n}\n\nvoid VirtualMachineEngine::HandleLocalProbe() {\n  OF_PROFILER_RANGE_GUARD(\"HandleLocalProbe\");\n  if (unlikely(local_probe_list_.size())) {\n    OF_PROFILER_RANGE_PUSH(\"HandleLocalProbe\");\n    INTRUSIVE_FOR_EACH_PTR(probe, &local_probe_list_) {\n      if (probe->probe_function()(this)) { local_probe_list_.Erase(probe); }\n    }\n    OF_PROFILER_RANGE_POP();\n  }\n}\n\nintrusive::shared_ptr<Instruction> VirtualMachineEngine::LivelyInstructionListErase(\n    Instruction* instruction) {\n  ++total_erased_instruction_cnt_;\n  return mut_lively_instruction_list()->Erase(instruction);\n}\n\n// Collect ready instructions onto ready_instruction_list_\nvoid VirtualMachineEngine::ReleaseFinishedInstructions(const ScheduleCtx& schedule_ctx) {\n  INTRUSIVE_FOR_EACH_PTR(stream, mut_active_stream_list()) {\n    while (true) {\n      auto* instruction_ptr = stream->mut_running_instruction_list()->Begin();\n      if (instruction_ptr == nullptr) { break; }\n      if (!(instruction_ptr->in_edges().empty() && instruction_ptr->Done())) { break; }\n      ReleaseInstruction(instruction_ptr);\n      // Prevent destructing instruction_ptr.\n      intrusive::shared_ptr<Instruction> instruction =\n          stream->mut_running_instruction_list()->Erase(instruction_ptr);\n      LivelyInstructionListErase(instruction_ptr);\n      instruction_ptr->DeleteStatusAndCheckEdges();\n    }\n    if (stream->running_instruction_list().empty()) { mut_active_stream_list()->Erase(stream); }\n  }\n}\n\nDependenceAccess* VirtualMachineEngine::AccessDependence(OperandAccessType access_type,\n                                                         Dependence* dependence,\n                                                         Instruction* instruction) {\n  auto access = access_pool_.make_shared(instruction, dependence, access_type);\n  auto* ptr = access.Mutable();\n  instruction->mut_access_list()->PushBack(ptr);\n  dependence->mut_access_list()->EmplaceBack(std::move(access));\n  return ptr;\n}\n\nvoid VirtualMachineEngine::TryConnectInstruction(Instruction* src_instruction,\n                                                 Instruction* dst_instruction) {\n  if (unlikely(src_instruction == dst_instruction)) { return; }\n  if (likely(EdgeDispatchable(src_instruction, dst_instruction))) { return; }\n  auto edge = instruction_edge_pool_.make_shared(src_instruction, dst_instruction);\n  src_instruction->mut_out_edges()->PushBack(edge.Mutable());\n  dst_instruction->mut_in_edges()->PushBack(edge.Mutable());\n}\n\nvoid VirtualMachineEngine::ConnectInstructionsByWrite(DependenceAccess* dst_access) {\n  CHECK(dst_access->is_mut_operand());\n  auto* dependence = dst_access->mut_dependence();\n  auto* dst_instruction = dst_access->mut_instruction();\n  auto* access_list = dependence->mut_access_list();\n  if (likely(access_list->Begin() == dst_access)) { return; }\n  INTRUSIVE_FOR_EACH_PTR(src_access, access_list) {\n    if (unlikely(src_access == dst_access)) { 
\nintrusive::shared_ptr<Instruction> VirtualMachineEngine::LivelyInstructionListErase(\n    Instruction* instruction) {\n  ++total_erased_instruction_cnt_;\n  return mut_lively_instruction_list()->Erase(instruction);\n}\n\n// Collect ready instructions onto ready_instruction_list_ while releasing finished instructions.\nvoid VirtualMachineEngine::ReleaseFinishedInstructions(const ScheduleCtx& schedule_ctx) {\n  INTRUSIVE_FOR_EACH_PTR(stream, mut_active_stream_list()) {\n    while (true) {\n      auto* instruction_ptr = stream->mut_running_instruction_list()->Begin();\n      if (instruction_ptr == nullptr) { break; }\n      if (!(instruction_ptr->in_edges().empty() && instruction_ptr->Done())) { break; }\n      ReleaseInstruction(instruction_ptr);\n      // Prevent destructing instruction_ptr.\n      intrusive::shared_ptr<Instruction> instruction =\n          stream->mut_running_instruction_list()->Erase(instruction_ptr);\n      LivelyInstructionListErase(instruction_ptr);\n      instruction_ptr->DeleteStatusAndCheckEdges();\n    }\n    if (stream->running_instruction_list().empty()) { mut_active_stream_list()->Erase(stream); }\n  }\n}\n\nDependenceAccess* VirtualMachineEngine::AccessDependence(OperandAccessType access_type,\n                                                         Dependence* dependence,\n                                                         Instruction* instruction) {\n  auto access = access_pool_.make_shared(instruction, dependence, access_type);\n  auto* ptr = access.Mutable();\n  instruction->mut_access_list()->PushBack(ptr);\n  dependence->mut_access_list()->EmplaceBack(std::move(access));\n  return ptr;\n}\n\nvoid VirtualMachineEngine::TryConnectInstruction(Instruction* src_instruction,\n                                                 Instruction* dst_instruction) {\n  if (unlikely(src_instruction == dst_instruction)) { return; }\n  if (likely(EdgeDispatchable(src_instruction, dst_instruction))) { return; }\n  auto edge = instruction_edge_pool_.make_shared(src_instruction, dst_instruction);\n  src_instruction->mut_out_edges()->PushBack(edge.Mutable());\n  dst_instruction->mut_in_edges()->PushBack(edge.Mutable());\n}\n\nvoid VirtualMachineEngine::ConnectInstructionsByWrite(DependenceAccess* dst_access) {\n  CHECK(dst_access->is_mut_operand());\n  auto* dependence = dst_access->mut_dependence();\n  auto* dst_instruction = dst_access->mut_instruction();\n  auto* access_list = dependence->mut_access_list();\n  if (likely(access_list->Begin() == dst_access)) { return; }\n  INTRUSIVE_FOR_EACH_PTR(src_access, access_list) {\n    if (unlikely(src_access == dst_access)) { break; }\n    TryConnectInstruction(src_access->mut_instruction(), dst_instruction);\n    access_list->Erase(src_access);\n  }\n}\n\nvoid VirtualMachineEngine::ConnectInstructionsByRead(DependenceAccess* dst_access) {\n  CHECK(dst_access->is_const_operand());\n  auto* dependence = dst_access->mut_dependence();\n  auto* dst_instruction = dst_access->mut_instruction();\n  auto* first = dependence->mut_access_list()->Begin();\n  if (first->is_mut_operand()) {\n    TryConnectInstruction(first->mut_instruction(), dst_instruction);\n  } else if (first->is_const_operand()) {\n    // do nothing\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nvoid VirtualMachineEngine::ConsumeDependences(Instruction* instruction) {\n  const auto& instruction_policy = instruction->instruction_policy();\n  auto* stream_sequential_dep = instruction_policy.stream_sequential_dependence();\n  if (likely(stream_sequential_dep != nullptr)) {\n    ConnectInstructionsByWrite(\n        AccessDependence(kMutableOperandAccess, stream_sequential_dep, instruction));\n  }\n  // Connect instructions by write before connecting by read.\n  for (auto* dependence : instruction_policy.output_dependences()) {\n    ConnectInstructionsByWrite(AccessDependence(kMutableOperandAccess, dependence, instruction));\n  }\n  for (auto* dependence : instruction_policy.input_dependences()) {\n    ConnectInstructionsByRead(AccessDependence(kConstOperandAccess, dependence, instruction));\n  }\n}\n
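\n// Worked example (editor's sketch, not part of the original source): for one\n// dependence d accessed in program order as w0(write), r0(read), r1(read),\n// w1(write):\n//   - r0 and r1 each get an in-edge from w0 via ConnectInstructionsByRead (RAW);\n//   - w1 gets in-edges from w0, r0 and r1 via ConnectInstructionsByWrite\n//     (WAW/WAR), which also erases those older accesses, so each writer only\n//     sees the accesses made since the previous writer;\n//   - edges to already-dispatched sources are skipped (see TryConnectInstruction).\n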
\nbool VirtualMachineEngine::EdgeDispatchable(const Instruction* src, const Instruction* dst) const {\n  return dst->instruction_policy().Prescheduleable(&src->stream(), &dst->stream())\n         && !src->dispatched_instruction_hook().empty() /* dispatched */;\n}\n\nbool VirtualMachineEngine::Dispatchable(Instruction* instruction) const {\n  if (unlikely(!instruction->dispatched_instruction_hook().empty())) { return false; }\n  INTRUSIVE_UNSAFE_FOR_EACH_PTR(edge, instruction->mut_in_edges()) {\n    const auto* src_instruction = &edge->src_instruction();\n    if (unlikely(!EdgeDispatchable(src_instruction, instruction))) { return false; }\n  }\n  return true;\n}\n\n// Dispatch ready instructions and put prescheduled instructions onto ready_instruction_list_.\nvoid VirtualMachineEngine::DispatchAndPrescheduleInstructions(const ScheduleCtx& schedule_ctx) {\n  OF_PROFILER_RANGE_GUARD(\"DispatchAndPrescheduleInstructions\");\n  ReadyInstructionList tmp_ready_instruction_list;\n  mut_ready_instruction_list()->MoveTo(&tmp_ready_instruction_list);\n  INTRUSIVE_FOR_EACH(instruction, &tmp_ready_instruction_list) {\n    // Erases `instruction` from tmp_ready_instruction_list before dispatching, because\n    // `instruction.dispatched_instruction_hook_` is used in DispatchInstruction.\n    tmp_ready_instruction_list.Erase(instruction.Mutable());\n    OF_PROFILER_RANGE_GUARD(\"D:\" + instruction->DebugName());\n    DispatchInstruction(instruction.Mutable(), schedule_ctx);\n    // preschedule instructions\n    INTRUSIVE_UNSAFE_FOR_EACH_PTR(edge, instruction->mut_out_edges()) {\n      auto* out_instruction = edge->mut_dst_instruction();\n      if (Dispatchable(out_instruction)) {\n        OF_PROFILER_RANGE_GUARD(\"P:\" + out_instruction->DebugName());\n        mut_ready_instruction_list()->PushBack(out_instruction);\n      }\n    }\n  }\n}\n\nnamespace {\n\nstd::string DebugDeviceReset(vm::Stream* stream) {\n  stream->mut_stream_policy()->mut_allocator()->DeviceReset();\n  return \"reset device\";\n}\n\n}  // namespace\n\nvoid VirtualMachineEngine::DispatchInstruction(Instruction* instruction,\n                                               const ScheduleCtx& schedule_ctx) {\n  ForeignFrameThreadLocalGuard guard(instruction->foreign_frame());\n  auto* stream = instruction->mut_stream();\n  // Prepare\n  {\n    const auto& ret = TRY(instruction->Prepare());\n    if (unlikely(!ret.IsOk())) {\n      if (ret.error()->has_out_of_memory_error()) {\n        CHECK_JUST_MSG(ret, std::stringstream() << DebugDeviceReset(stream));\n      } else {\n        CHECK_JUST(ret);\n      }\n    }\n  }\n  stream->mut_running_instruction_list()->PushBack(instruction);\n  if (stream->active_stream_hook().empty()) { mut_active_stream_list()->PushBack(stream); }\n  // Compute\n  if (OnSchedulerThread(*stream)) {\n    stream->stream_policy().RunIf(instruction);\n  } else {\n    stream->mut_thread_ctx()->mut_worker_pending_instruction_list()->PushBack(instruction);\n    schedule_ctx.OnWorkerLoadPending(stream->mut_thread_ctx());\n  }\n}\n\n// Returns true if old scheduler_pending_instruction_list is empty\nMaybe<bool> VirtualMachineEngine::Receive(InstructionList* compute_instruction_list) {\n  OF_PROFILER_RANGE_GUARD(\"vm:Receive\");\n#ifdef OF_ENABLE_PROFILER\n  INTRUSIVE_UNSAFE_FOR_EACH_PTR(compute_instruction, compute_instruction_list) {\n    OF_PROFILER_RANGE_GUARD(compute_instruction->DebugName());\n  }\n#endif\n\n  bool old_list_empty = mut_pending_instruction_list()->MoveFrom(compute_instruction_list);\n  return old_list_empty;\n}\n\nbool VirtualMachineEngine::OnSchedulerThread(const Stream& stream) {\n  return stream.on_scheduler_thread() || pthread_fork::IsForkedSubProcess();\n}\n\n// Barrier instructions are run after all previous lively instructions.\n//\n// `instruction.lively_instruction_hook_` is linked to `vm.lively_instruction_list_` for all\n// instructions. `instruction.barrier_instruction_hook_` is linked to `vm.barrier_instruction_list_`\n// only for barrier instructions.\n//\n//\n//  e.g. case0: waiting for other instructions to finish.\n//\n//  +---------------------------+   +---------------------------+   +---------------------------+\n//  |      virtual_machine      |   |        instruction0       |   |        instruction1       |\n//  +---------------------------+   +---------------------------+   +---------------------------+\n//  |            ...            |   |            ...            |   |            ...            |\n//  |---------------------------|   |---------------------------|   |---------------------------|\n//  | lively_instruction_list_  |<->| lively_instruction_hook_  |<->| lively_instruction_hook_  |\n//  |---------------------------|   |---------------------------|   |---------------------------|\n//  |            ...            |   |            ...            |   |            ...            |\n//  |---------------------------|   |---------------------------|   |---------------------------|\n//  | barrier_instruction_list_ |<+ | barrier_instruction_hook_ | +>| barrier_instruction_hook_ |\n//  |---------------------------| | |---------------------------| | |---------------------------|\n//  |            ...            | | |            ...            | | |            ...            |\n//  +---------------------------+ | +---------------------------+ | +---------------------------+\n//                                |                               |\n//                                +-------------------------------+\n//\n// `instruction1` is a barrier instruction with barrier_instruction_hook_ linked, while\n// instruction0 is not. From the `virtual_machine`'s view, `barrier_instruction_list_.Begin() !=\n// lively_instruction_list_.Begin()`, so it is not yet time to run barrier instruction\n// `barrier_instruction_list_.Begin()`.\n//\n//\n//  e.g. case1: run barrier instructions.\n//\n//  +---------------------------+   +---------------------------+   +---------------------------+\n//  |      virtual_machine      |   |        instruction0       |   |        instruction1       |\n//  +---------------------------+   +---------------------------+   +---------------------------+\n//  |            ...            |   |            ...            |   |            ...            |\n//  |---------------------------|   |---------------------------|   |---------------------------|\n//  | lively_instruction_list_  |<->| lively_instruction_hook_  |<->| lively_instruction_hook_  |\n//  |---------------------------|   |---------------------------|   |---------------------------|\n//  |            ...            |   |            ...            |   |            ...            |\n//  |---------------------------|   |---------------------------|   |---------------------------|\n//  | barrier_instruction_list_ |<->| barrier_instruction_hook_ |   | barrier_instruction_hook_ |\n//  |---------------------------|   |---------------------------|   |---------------------------|\n//  |            ...            |   |            ...            |   |            ...            |\n//  +---------------------------+   +---------------------------+   +---------------------------+\n//\n// `instruction0` is a barrier instruction with barrier_instruction_hook_ linked.\n// From the `virtual_machine`'s view, `barrier_instruction_list_.Begin() ==\n// lively_instruction_list_.Begin()`, so it is time to run barrier instruction\n// `barrier_instruction_list_.Begin()`.\n//\n//\n// With the introduction of barrier_instruction_list_/barrier_instruction_hook_, the function\n// VirtualMachineEngine::Schedule can achieve higher performance. In practice, barrier\n// instructions are rarely received by the vm, so there is no need for the vm to run\n// VirtualMachineEngine::TryRunBarrierInstruction every time VirtualMachineEngine::Schedule runs. On\n// the other hand, `barrier_instruction_hook_.size() == 0` is more lightweight than\n// `lively_instruction_list_.Begin()?->instruction_policy().IsBarrier()`.\n//\n
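// Illustrative walkthrough (editor's sketch, not part of the original source):\n// if lively_instruction_list_ holds [i0, i1, i2] in issue order and only i2 is a\n// barrier, then barrier_instruction_list_.Begin() == i2 != i0, so\n// TryRunBarrierInstruction returns early; once i0 and i1 retire and are erased\n// from the lively list, the two Begin()s coincide and the barrier runs inline on\n// the scheduler thread.\n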
void VirtualMachineEngine::TryRunBarrierInstruction(const ScheduleCtx& schedule_ctx) {\n  auto* sequential_instruction = mut_barrier_instruction_list()->Begin();\n  CHECK_NOTNULL(sequential_instruction);\n  if (likely(sequential_instruction != mut_lively_instruction_list()->Begin())) { return; }\n  // All instructions before `sequential_instruction` are handled now, so it's time to handle\n  // `sequential_instruction`.\n  OF_PROFILER_RANGE_GUARD(\"TryRunBarrierInstruction\");\n  const auto& instruction_policy = sequential_instruction->instruction_policy();\n  CHECK(instruction_policy.IsBarrier());\n  CHECK(OnSchedulerThread(sequential_instruction->stream()));\n  const StreamPolicy& stream_policy = sequential_instruction->stream().stream_policy();\n  stream_policy.RunIf(sequential_instruction);\n  mut_barrier_instruction_list()->Erase(sequential_instruction);\n  LivelyInstructionListErase(sequential_instruction);\n}\n\nvoid VirtualMachineEngine::Schedule(const ScheduleCtx& schedule_ctx) {\n  // Release finished instructions and try to move newly-ready instructions in the DAG onto the\n  // ready list.\n  if (unlikely(mut_active_stream_list()->size())) { ReleaseFinishedInstructions(schedule_ctx); }\n  // Try to run the first barrier instruction.\n  if (unlikely(mut_barrier_instruction_list()->size())) { TryRunBarrierInstruction(schedule_ctx); }\n  // Handle pending instructions, and try to schedule them to the ready list.\n  // Use thread_unsafe_size to avoid acquiring the mutex lock.\n  // The inconsistency between pending_instruction_list.list_head_.list_head_.container_ and\n  // pending_instruction_list.list_head_.list_head_.size_ is not a fatal error because\n  // VirtualMachineEngine::Schedule is always in a busy loop. All instructions will get handled\n  // eventually.\n  //  VirtualMachineEngine::Receive may be less efficient if the thread-safe version\n  //  `pending_instruction_list().size()` were used here, because VirtualMachineEngine::Schedule\n  //  would then be more likely to acquire the mutex lock.\n  if (unlikely(local_pending_instruction_list().size())) {\n    HandleLocalPending();\n  } else if (unlikely(pending_instruction_list().thread_unsafe_size())) {\n    // MoveTo is under a lock.\n    mut_pending_instruction_list()->MoveTo(mut_local_pending_instruction_list());\n    if (local_pending_instruction_list().size()) { HandleLocalPending(); }\n  }\n  // Dispatch ready instructions and try to move newly-ready instructions in the DAG onto the\n  // ready list.\n  if (unlikely(mut_ready_instruction_list()->size())) {\n    DispatchAndPrescheduleInstructions(schedule_ctx);\n  }\n  // Handle scheduler probes.\n  if (unlikely(local_probe_list_.size())) {\n    HandleLocalProbe();\n  } else if (unlikely(probe_list_.thread_unsafe_size())) {\n    probe_list_.MoveTo(&local_probe_list_);\n    if (local_probe_list_.size()) { HandleLocalProbe(); }\n  }\n}\n\nbool VirtualMachineEngine::SchedulerThreadUnsafeEmpty() const {\n  return pending_instruction_list().thread_unsafe_size() == 0\n         && local_pending_instruction_list().empty() && lively_instruction_list_.empty()\n         && active_stream_list().empty() && probe_list_.thread_unsafe_size() == 0\n         && local_probe_list_.empty();\n}\n\nbool VirtualMachineEngine::SchedulerEmpty() const {\n  // Hook and size are both checked in pending_instruction_list().empty().\n  return pending_instruction_list().empty() && probe_list_.empty() && SchedulerThreadUnsafeEmpty();\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/virtual_machine_engine.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_VIRTUAL_MACHINE_ENGINE_H_\n#define ONEFLOW_CORE_VM_VIRTUAL_MACHINE_ENGINE_H_\n\n#include <mutex>\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/vm/stream.h\"\n#include \"oneflow/core/vm/thread_ctx.h\"\n#include \"oneflow/core/vm/vm_object.h\"\n#include \"oneflow/core/common/range.h\"\n#include \"oneflow/core/intrusive/mutexed_list.h\"\n#include \"oneflow/core/intrusive/object_pool.h\"\n#include \"oneflow/core/vm/probe.h\"\n\nnamespace oneflow {\n\nnamespace vm {\n\nclass ThreadCtx;\n\nclass ScheduleCtx {\n public:\n  ScheduleCtx() = default;\n  virtual ~ScheduleCtx() = default;\n\n  virtual void OnWorkerLoadPending(vm::ThreadCtx* thread_ctx) const = 0;\n};\n\nusing ReadyInstructionList =\n    intrusive::List<INTRUSIVE_FIELD(Instruction, dispatched_instruction_hook_)>;\n\nclass VirtualMachineEngine final : public intrusive::Base {\n public:\n  // types\n  using ActiveStreamList = intrusive::List<INTRUSIVE_FIELD(Stream, active_stream_hook_)>;\n  using ThreadCtxList = intrusive::List<INTRUSIVE_FIELD(ThreadCtx, thread_ctx_hook_)>;\n  using InstructionList = intrusive::List<INTRUSIVE_FIELD(Instruction, main_instruction_hook_)>;\n  using LivelyInstructionList =\n      intrusive::List<INTRUSIVE_FIELD(Instruction, lively_instruction_hook_)>;\n  using BarrierInstructionList =\n      intrusive::List<INTRUSIVE_FIELD(Instruction, barrier_instruction_hook_)>;\n  using InstructionMutexedList =\n      intrusive::MutexedList<INTRUSIVE_FIELD(Instruction, Instruction::main_instruction_hook_)>;\n\n  // Getters\n  std::size_t flying_instruction_cnt() const {\n    return pending_instruction_list().thread_unsafe_size() + local_pending_instruction_list().size()\n           + (total_inserted_instruction_cnt() - total_erased_instruction_cnt());\n  }\n  size_t total_inserted_instruction_cnt() const { return total_inserted_instruction_cnt_; }\n  size_t total_erased_instruction_cnt() const { return total_erased_instruction_cnt_; }\n  void InsertProbe(const std::function<bool(VirtualMachineEngine*)>& ProbeFunction);\n  const ActiveStreamList& active_stream_list() const { return active_stream_list_; }\n  const ThreadCtxList& thread_ctx_list() const { return thread_ctx_list_; }\n  const LivelyInstructionList& lively_instruction_list() const { return lively_instruction_list_; }\n  const BarrierInstructionList& barrier_instruction_list() const {\n    return barrier_instruction_list_;\n  }\n  const InstructionMutexedList& pending_instruction_list() const {\n    return pending_instruction_list_;\n  }\n  const InstructionList& local_pending_instruction_list() const {\n    return local_pending_instruction_list_;\n  }\n  // Setters\n  ActiveStreamList* mut_active_stream_list() { return &active_stream_list_; }\n  ThreadCtxList* mut_thread_ctx_list() { return &thread_ctx_list_; }\n  LivelyInstructionList* 
mut_lively_instruction_list() { return &lively_instruction_list_; }\n  BarrierInstructionList* mut_barrier_instruction_list() { return &barrier_instruction_list_; }\n  InstructionMutexedList* mut_pending_instruction_list() { return &pending_instruction_list_; }\n  InstructionList* mut_local_pending_instruction_list() { return &local_pending_instruction_list_; }\n  // Returns true if old scheduler_pending_instruction_list is empty\n  Maybe<bool> Receive(InstructionList* instr_list);\n  void Schedule(const ScheduleCtx& schedule_ctx);\n  bool SchedulerThreadUnsafeEmpty() const;\n  bool SchedulerEmpty() const;\n  std::string GetLivelyInstructionListDebugString(int64_t debug_cnt);\n  void MoveToGarbageListAndNotifyGC(const ScheduleCtx& schedule_ctx);\n\n private:\n  ReadyInstructionList* mut_ready_instruction_list() { return &ready_instruction_list_; }\n\n  void ReleaseFinishedInstructions(const ScheduleCtx& schedule_ctx);\n  void HandleLocalPending();\n  void FetchAndTryFusePendingInstructions(InstructionList* /*out*/ pending_instructions);\n  void MakeAndAppendFusedInstruction(InstructionList&& fused_instruction_list,\n                                     InstructionList* /*out*/ pending_instructions);\n  void TryRunBarrierInstruction(const ScheduleCtx& schedule_ctx);\n  void DispatchAndPrescheduleInstructions(const ScheduleCtx& schedule_ctx);\n  bool OnSchedulerThread(const vm::Stream& stream);\n\n  void ReleaseInstruction(Instruction* instruction);\n\n  void TryConnectInstruction(Instruction* src_instruction, Instruction* dst_instruction);\n  void ConnectInstructionsByWrite(DependenceAccess* dst_access);\n  void ConnectInstructionsByRead(DependenceAccess* dst_access);\n  DependenceAccess* AccessDependence(OperandAccessType access_type, Dependence* dependence,\n                                     Instruction* instruction);\n  void ConsumeDependences(Instruction* instruction);\n  void DispatchInstruction(Instruction* instruction, const ScheduleCtx& schedule_ctx);\n\n  bool EdgeDispatchable(const Instruction* src, const Instruction* dst) const;\n  bool Dispatchable(Instruction* instruction) const;\n\n  void TryDispatchReadyInstructions();\n\n  void LivelyInstructionListPushBack(Instruction* instruction);\n  intrusive::shared_ptr<Instruction> LivelyInstructionListErase(Instruction* instruction);\n  void HandleLocalProbe();\n\n  friend class intrusive::Ref;\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n  VirtualMachineEngine()\n      : intrusive_ref_(),\n        active_stream_list_(),\n        thread_ctx_list_(),\n        pending_instruction_mutex_(),\n        pending_instruction_list_(&pending_instruction_mutex_),\n        local_pending_instruction_list_(),\n        ready_instruction_list_(),\n        lively_instruction_list_(),\n        total_inserted_instruction_cnt_(0),\n        total_erased_instruction_cnt_(0),\n        probe_mutex_(),\n        probe_list_(&probe_mutex_),\n        local_probe_list_(),\n        barrier_instruction_list_() {}\n  intrusive::Ref intrusive_ref_;\n  // lists or maps\n  // Do not change the order of the following fields\n  ActiveStreamList active_stream_list_;\n  ThreadCtxList thread_ctx_list_;\n  std::mutex pending_instruction_mutex_;\n  InstructionMutexedList pending_instruction_list_;\n  // local_pending_instruction_list_ should be considered as a cache of pending_instruction_list_.\n  InstructionList local_pending_instruction_list_;\n  ReadyInstructionList ready_instruction_list_;\n  LivelyInstructionList lively_instruction_list_;\n  
size_t total_inserted_instruction_cnt_;\n  size_t total_erased_instruction_cnt_;\n\n  using VmProbe = Probe<std::function<bool(VirtualMachineEngine*)>>;\n  std::mutex probe_mutex_;\n  intrusive::MutexedList<INTRUSIVE_FIELD(VmProbe, probe_hook_)> probe_list_;\n  intrusive::List<INTRUSIVE_FIELD(VmProbe, probe_hook_)> local_probe_list_;\n\n  BarrierInstructionList barrier_instruction_list_;\n  DependenceAccess::object_pool_type access_pool_;\n  InstructionEdge::object_pool_type instruction_edge_pool_;\n};\n\n}  // namespace vm\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_VIRTUAL_MACHINE_ENGINE_H_\n"
  },
  {
    "path": "oneflow/core/vm/virtual_machine_scope.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/virtual_machine_scope.h\"\n#include \"oneflow/core/vm/virtual_machine_engine.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nVirtualMachineScope::VirtualMachineScope(const Resource& resource) {\n  Singleton<VirtualMachine>::New();\n}\n\nVirtualMachineScope::~VirtualMachineScope() { Singleton<VirtualMachine>::Delete(); }\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/virtual_machine_scope.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/resource.pb.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass VirtualMachineScope {\n public:\n  VirtualMachineScope(const Resource& resource);\n  ~VirtualMachineScope();\n};\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/vm_object.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/vm/vm_object.h\"\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nvoid DependenceAccess::__Init__() {\n  clear_instruction();\n  clear_dependence();\n}\n\nvoid DependenceAccess::__Init__(Instruction* instruction, Dependence* dependence,\n                                OperandAccessType access_type) {\n  __Init__();\n  set_instruction(instruction);\n  set_dependence(dependence);\n  set_access_type(access_type);\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/vm_object.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_VM_OBJECT_H_\n#define ONEFLOW_CORE_VM_VM_OBJECT_H_\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/intrusive/intrusive.h\"\n#include \"oneflow/core/intrusive/object_pool.h\"\n\nnamespace oneflow {\n\nnamespace vm {\n\nclass Instruction;\nclass Dependence;\n\nusing DependenceVector = std::vector<Dependence*>;\n\nenum OperandAccessType {\n  kConstOperandAccess = 0,\n  kMutableOperandAccess,\n};\n\nclass DependenceAccess final\n    : public intrusive::Base,\n      public intrusive::EnableObjectPool<DependenceAccess,\n                                         intrusive::kThreadUnsafeAndDisableDestruct> {\n public:\n  void __Init__();\n  // Getters\n  OperandAccessType access_type() const { return access_type_; }\n  bool has_instruction() const { return instruction_ != nullptr; }\n  bool has_dependence() const { return dependence_ != nullptr; }\n  const Instruction& instruction() const { return *instruction_; }\n  const Dependence& dependence() const { return *dependence_; }\n  const intrusive::ListHook& rw_mutexed_object_access_hook() const {\n    return rw_mutexed_object_access_hook_;\n  }\n\n  // Setters\n  void set_access_type(OperandAccessType val) { access_type_ = val; }\n  void set_instruction(Instruction* val) { instruction_ = val; }\n  void set_dependence(Dependence* val) { dependence_ = val; }\n  void clear_instruction() { instruction_ = nullptr; }\n  void clear_dependence() { dependence_ = nullptr; }\n  Instruction* mut_instruction() { return instruction_; }\n  Dependence* mut_dependence() { return dependence_; }\n\n  // methods\n  void __Init__(Instruction* instruction, Dependence* dependence, OperandAccessType access_type);\n\n  bool is_const_operand() const { return kConstOperandAccess == access_type(); }\n  bool is_mut_operand() const { return kMutableOperandAccess == access_type(); }\n\n  intrusive::Ref::RefCntType ref_cnt() const { return intrusive_ref_.ref_cnt(); }\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }  // NOLINT\n\n private:\n  friend class intrusive::Ref;\n\n  DependenceAccess()\n      : intrusive_ref_(),\n        access_type_(),\n        instruction_(),\n        dependence_(),\n        instruction_access_hook_(),\n        rw_mutexed_object_access_hook_() {}\n  intrusive::Ref intrusive_ref_;\n  // fields\n  OperandAccessType access_type_;\n  Instruction* instruction_;\n  Dependence* dependence_;\n\n public:\n  // list hooks\n  intrusive::ListHook instruction_access_hook_;\n  intrusive::ListHook rw_mutexed_object_access_hook_;\n};  // NOLINT\n\nclass Dependence final : public intrusive::Base {\n public:\n  // types\n  using DependenceAccessList =\n      intrusive::List<INTRUSIVE_FIELD(DependenceAccess, rw_mutexed_object_access_hook_)>;\n\n  // Setters\n  DependenceAccessList* mut_access_list() { return &access_list_; }\n\n  // methods\n  void __Init__() {}\n\n  intrusive::Ref::RefCntType 
ref_cnt() const { return intrusive_ref_.ref_cnt(); }\n\n private:\n  friend class intrusive::Ref;\n  intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }\n\n  Dependence() : intrusive_ref_(), access_list_() {}\n\n  intrusive::Ref intrusive_ref_;\n  // list hooks\n  DependenceAccessList access_list_;\n};\n\n}  // namespace vm\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_VM_OBJECT_H_\n"
  },
  {
    "path": "oneflow/core/vm/vm_sync.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_SYNC_H_\n#define ONEFLOW_CORE_VM_SYNC_H_\n\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nMaybe<void> ClusterSync();\nMaybe<void> CurrentRankSync();\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_SYNC_H_\n"
  },
  {
    "path": "oneflow/core/vm/vm_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/blocking_counter.h\"\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/job/cluster_instruction.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n#include \"oneflow/core/vm/virtual_machine.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nMaybe<void> Run(vm::InstructionList* instruction_list) {\n  auto* virtual_machine = JUST(SingletonMaybe<VirtualMachine>());\n  JUST(virtual_machine->Receive(instruction_list));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ClusterSync() {\n  auto bc = std::make_shared<BlockingCounter>(1);\n  JUST(PhysicalRun([bc](InstructionsBuilder* builder) -> Maybe<void> {\n    JUST(builder->GlobalSync());\n    JUST(builder->Barrier([bc]() { bc->Decrease(); }));\n    return Maybe<void>::Ok();\n  }));\n  JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CurrentRankSync() {\n  auto bc = std::make_shared<BlockingCounter>(1);\n  JUST(PhysicalRun([bc](InstructionsBuilder* builder) -> Maybe<void> {\n    JUST(builder->Barrier([bc]() { bc->Decrease(); }));\n    return Maybe<void>::Ok();\n  }));\n  JUST(bc->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace vm\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/core/vm/vm_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_VM_H_\n#define ONEFLOW_CORE_VM_H_\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/intrusive/intrusive.h\"\n#include \"oneflow/core/vm/instruction.h\"\n#include \"oneflow/core/vm/vm_sync.h\"\n\nnamespace oneflow {\nnamespace vm {\n\nclass Instruction;\n\nMaybe<void> Run(vm::InstructionList* instruction_list);\nMaybe<void> ClusterSync();\nMaybe<void> CurrentRankSync();\n\n}  // namespace vm\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_VM_H_\n"
  },
  {
    "path": "oneflow/extension/python/numpy.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <pybind11/pybind11.h>\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/common/registry_error.h\"\n#include \"oneflow/extension/python/numpy_internal.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nnamespace numpy {\n\nNumPyArrayInternal::NumPyArrayInternal(PyObject* obj, const std::function<void()>& deleter)\n    : obj_((PyArrayObject*)obj), deleter_(deleter) {\n  CHECK_OR_THROW(PyArray_Check(obj)) << \"The object is not a numpy array.\";\n  CHECK_OR_THROW(PyArray_ISCONTIGUOUS(obj_)) << \"Contiguous array is expected.\";\n  size_ = PyArray_SIZE(obj_);\n  data_ = PyArray_DATA(obj_);\n}\n\nNumPyArrayInternal::~NumPyArrayInternal() {\n  if (deleter_) { deleter_(); }\n}\n\nMaybe<int> OFDataTypeToNumpyType(DataType of_data_type) {\n  switch (of_data_type) {\n    case DataType::kBool: return NPY_BOOL;\n    case DataType::kFloat: return NPY_FLOAT32;\n    case DataType::kDouble: return NPY_FLOAT64;\n    case DataType::kInt8: return NPY_INT8;\n    case DataType::kInt16: return NPY_INT16;\n    case DataType::kChar: return NPY_INT8;\n    case DataType::kInt32: return NPY_INT32;\n    case DataType::kInt64: return NPY_INT64;\n    case DataType::kUInt8: return NPY_UINT8;\n    case DataType::kFloat16: return NPY_FLOAT16;\n    case DataType::kComplex64: return NPY_COMPLEX64;\n    case DataType::kComplex128: return NPY_COMPLEX128;\n    default:\n      return Error::InvalidValueError() << \"OneFlow data type \" << DataType_Name(of_data_type)\n                                        << \" is not valid to Numpy data type.\";\n  }\n}\n\nMaybe<DataType> NumpyTypeToOFDataType(int np_type) {\n  switch (np_type) {\n    case NPY_BOOL: return DataType::kBool;\n    case NPY_FLOAT32: return DataType::kFloat;\n    case NPY_FLOAT64: return DataType::kDouble;\n    case NPY_INT8: return DataType::kInt8;\n    case NPY_INT16: return DataType::kInt16;\n    case NPY_INT32: return DataType::kInt32;\n    case NPY_INT64:\n    case NPY_LONGLONG: return DataType::kInt64;\n    case NPY_UINT8: return DataType::kUInt8;\n    case NPY_FLOAT16: return DataType::kFloat16;\n    case NPY_COMPLEX64: return DataType::kComplex64;\n    case NPY_COMPLEX128: return DataType::kComplex128;\n    default:\n      return Error::InvalidValueError() << \"Numpy data type \" << std::to_string(np_type)\n                                        << \" is not valid to OneFlow data type.\";\n  }\n}\n\nMaybe<DataType> GetOFDataTypeFromNpArray(PyArrayObject* array) {\n  int np_array_type = PyArray_TYPE(array);\n  return NumpyTypeToOFDataType(np_array_type);\n}\n\nstd::vector<size_t> OFShapeToNumpyShape(const DimVector& fixed_vec) {\n  size_t ndim = fixed_vec.size();\n  auto result = std::vector<size_t>(ndim);\n  for (int i = 0; i < ndim; i++) { result[i] = fixed_vec.at(i); }\n  return result;\n}\n\n// NumPy strides use bytes. 
std::vector<size_t> OFStrideToNumpyStride(const Stride& stride, const DataType data_type) {\n  size_t ndim = stride.size();\n  auto result = std::vector<size_t>(ndim);\n  int byte_per_elem = GetSizeOfDataType(data_type);\n  for (int i = 0; i < ndim; i++) { result[i] = stride.at(i) * byte_per_elem; }\n  return result;\n}\n\nbool PyArrayCheckLongScalar(PyObject* obj) {\n  return PyArray_CheckScalar(obj) && PyDataType_ISINTEGER(PyArray_DescrFromScalar(obj));\n}\n\nbool PyArrayCheckFloatScalar(PyObject* obj) {\n  return PyArray_CheckScalar(obj) && PyDataType_ISFLOAT(PyArray_DescrFromScalar(obj));\n}\n\nbool PyArrayCheckBoolScalar(PyObject* obj) {\n  return PyArray_CheckScalar(obj) && PyDataType_ISBOOL(PyArray_DescrFromScalar(obj));\n}\n\nbool PyArrayCheckComplexScalar(PyObject* obj) {\n  return PyArray_CheckScalar(obj) && PyDataType_ISCOMPLEX(PyArray_DescrFromScalar(obj));\n}\n\n// Executing any numpy c api before _import_array() results in a segfault\n// NOTE: this InitNumpyCAPI() works because of `PY_ARRAY_UNIQUE_SYMBOL`\n// defined in numpy_internal.h\n// Reference:\n// https://numpy.org/doc/stable/reference/c-api/array.html#importing-the-api\nMaybe<void> InitNumpyCAPI() {\n  CHECK_ISNULL_OR_RETURN(PyArray_API);\n  CHECK_EQ_OR_RETURN(_import_array(), 0)\n      << \". Unable to import Numpy array, try to upgrade Numpy version!\";\n  return Maybe<void>::Ok();\n}\n\n}  // namespace numpy\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/extension/python/numpy.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_EXTENSION_PYTHON_NUMPY_H_\n#define ONEFLOW_EXTENSION_PYTHON_NUMPY_H_\n\n#define NO_IMPORT_ARRAY\n#include \"oneflow/extension/python/numpy_internal.h\"\n\nnamespace oneflow {\n\nclass NumPyArrayPtr final {\n public:\n  NumPyArrayPtr(PyObject* obj)\n      : internal_(std::make_shared<numpy::NumPyArrayInternal>(obj, []() -> void {})) {}\n  NumPyArrayPtr(PyObject* obj, const std::function<void()>& deleter)\n      : internal_(std::make_shared<numpy::NumPyArrayInternal>(obj, deleter)) {}\n\n  void* data() const { return internal_->data(); }\n\n  size_t size() const { return internal_->size(); }\n\n private:\n  std::shared_ptr<numpy::NumPyArrayInternal> internal_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_EXTENSION_PYTHON_NUMPY_H_\n"
  },
  {
    "path": "oneflow/extension/python/numpy_internal.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n// ************************\n//\n// NOTE: Do NOT include this file (numpy_internal.h) directly.\n// Include numpy.h instead.\n//\n// ************************\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/small_vector.h\"\n#include \"oneflow/core/common/shape_vec.h\"\n\n// PyArrayObject cannot be forward declared, or a compile error will occur\n\n// https://numpy.org/doc/stable/reference/c-api/array.html?highlight=array%20api#importing-the-api\n#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION\n#define PY_ARRAY_UNIQUE_SYMBOL oneflow_ARRAY_API\n#include <numpy/arrayobject.h>\n\nnamespace oneflow {\n\nclass Stride;\n\nnamespace numpy {\n\nclass NumPyArrayInternal final {\n public:\n  NumPyArrayInternal(PyObject* obj, const std::function<void()>& deleter);\n  ~NumPyArrayInternal();\n\n  void* data() const { return data_; }\n\n  size_t size() const { return size_; }\n\n private:\n  PyArrayObject* obj_;\n  void* data_;\n  size_t size_;\n  std::function<void()> deleter_;\n};\n\nMaybe<int> OFDataTypeToNumpyType(DataType of_data_type);\n\nMaybe<DataType> NumpyTypeToOFDataType(int np_array_type);\n\nMaybe<DataType> GetOFDataTypeFromNpArray(PyArrayObject* array);\n\nstd::vector<size_t> OFShapeToNumpyShape(const DimVector& fixed_vec);\n\nstd::vector<size_t> OFStrideToNumpyStride(const Stride& stride, const DataType data_type);\n\nbool PyArrayCheckLongScalar(PyObject* obj);\n\nbool PyArrayCheckFloatScalar(PyObject* obj);\n\nbool PyArrayCheckBoolScalar(PyObject* obj);\n\nbool PyArrayCheckComplexScalar(PyObject* obj);\n\nMaybe<void> InitNumpyCAPI();\n\n}  // namespace numpy\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/extension/python/py_compute.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/extension/python/py_compute.h\"\n\n#define PY_SSIZE_T_CLEAN\n#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION\n#include <numpy/arrayobject.h>\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/framework/user_op_tensor.h\"\n#include \"oneflow/core/framework/util.h\"\n#include \"oneflow/extension/python/numpy.h\"\n\nnamespace oneflow {\nnamespace pyext {\n\nnamespace {\nstatic PyObject* py_kernels_dic = nullptr;\n\n#define TENSOR_MEM_CAST(dtype) static_cast<void*>(const_cast<dtype*>(tensor->dptr<dtype>()))\n\nvoid* TensorToMem(const user_op::Tensor* tensor) {\n  switch (tensor->data_type()) {\n    case DataType::kFloat: return TENSOR_MEM_CAST(float);\n    case DataType::kDouble: return TENSOR_MEM_CAST(double);\n    case DataType::kBool: return TENSOR_MEM_CAST(bool);\n    case DataType::kInt8: return TENSOR_MEM_CAST(int8_t);\n    case DataType::kInt32: return TENSOR_MEM_CAST(int32_t);\n    case DataType::kInt64: return TENSOR_MEM_CAST(int64_t);\n    case DataType::kUInt8: return TENSOR_MEM_CAST(uint8_t);\n    case DataType::kFloat16: return TENSOR_MEM_CAST(float16);\n    default:\n      LOG(FATAL) << \"OneFlow data type \" << DataType_Name(tensor->data_type())\n                 << \" is not supported yet.\";\n      return nullptr;\n  }\n}\n\nvoid TensorToNumpy(const user_op::Tensor* tensor, PyObject** arg_ptr) {\n  if (tensor == nullptr) {\n    Py_INCREF(Py_None);\n    *arg_ptr = Py_None;\n    return;\n  }\n  int type_num = CHECK_JUST(numpy::OFDataTypeToNumpyType(tensor->data_type()));\n  VLOG(3) << \"Tensor data type \" << DataType_Name(tensor->data_type()) << \" Numpy type \"\n          << type_num;\n  int dim_size = tensor->shape_view().NumAxes();\n  npy_intp dims[dim_size];\n  FOR_RANGE(size_t, i, 0, dim_size) { dims[i] = tensor->shape_view().At(i); }\n\n  void* data = TensorToMem(tensor);\n  auto* np_array =\n      reinterpret_cast<PyArrayObject*>(PyArray_SimpleNewFromData(dim_size, dims, type_num, data));\n  // Numpy will not release the data\n  PyArray_CLEARFLAGS(np_array, NPY_ARRAY_OWNDATA);\n  *arg_ptr = reinterpret_cast<PyObject*>(np_array);\n}\n\n#define TENSOR_MEM_ASSIGN(dtype)                                                     \\\n  do {                                                                               \\\n    dtype* array_data = static_cast<dtype*>(array_data_ptr);                         \\\n    FOR_RANGE(int64_t, i, 0, size) { tensor->mut_dptr<dtype>()[i] = array_data[i]; } \\\n  } while (0)\n\nvoid MemToTensor(void* array_data_ptr, const size_t size, user_op::Tensor* tensor) {\n  switch (tensor->data_type()) {\n    case DataType::kFloat: TENSOR_MEM_ASSIGN(float); break;\n    case DataType::kDouble: TENSOR_MEM_ASSIGN(double); break;\n    case DataType::kBool: TENSOR_MEM_ASSIGN(bool); break;\n    case DataType::kInt8: TENSOR_MEM_ASSIGN(int8_t); break;\n    case DataType::kInt32: TENSOR_MEM_ASSIGN(int32_t); 
break;\n    case DataType::kInt64: TENSOR_MEM_ASSIGN(int64_t); break;\n    case DataType::kUInt8: TENSOR_MEM_ASSIGN(uint8_t); break;\n    case DataType::kFloat16: TENSOR_MEM_ASSIGN(float16); break;\n    default:\n      LOG(FATAL) << \"OneFlow data type \" << DataType_Name(tensor->data_type())\n                 << \" is not supported yet.\";\n  }\n}\n\nvoid NumpyToTensor(PyObject* arg, user_op::Tensor* tensor) {\n  PyObject* ro_array = PyArray_FromAny(arg, nullptr, 0, 0, NPY_ARRAY_CARRAY_RO, nullptr);\n  // PyArray_FromAny has increased the reference count\n  Py_DECREF(ro_array);\n  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(ro_array);\n\n  DataType of_data_type = CHECK_JUST(numpy::GetOFDataTypeFromNpArray(array));\n  CHECK_EQ(of_data_type, tensor->data_type())\n      << \"Numpy to OneFlow data type \" << DataType_Name(of_data_type)\n      << \" is not equal to OneFlow tensor data type \" << DataType_Name(tensor->data_type());\n\n  int64_t array_elem_cnt = 1;\n  FOR_RANGE(int, i, 0, PyArray_NDIM(array)) { array_elem_cnt *= PyArray_SHAPE(array)[i]; }\n  CHECK_EQ(array_elem_cnt, tensor->shape_view().elem_cnt())\n      << \"Numpy array element count \" << array_elem_cnt\n      << \" is not equal to OneFlow tensor element count \" << tensor->shape_view().elem_cnt();\n\n  void* array_data_ptr = PyArray_DATA(array);\n  MemToTensor(array_data_ptr, array_elem_cnt, tensor);\n}\n\nvoid MakePyInputs(const UserOpDef& op_def, user_op::KernelComputeContext* ctx,\n                  PyObject** py_inputs) {\n  const size_t kernel_in_num = ctx->inputs().size();\n  const size_t def_in_num = op_def.input_size();\n  CHECK_EQ(kernel_in_num, def_in_num) << \"kernel input num \" << kernel_in_num\n                                      << \" not equal to definition input num \" << def_in_num;\n  PyObject* py_list = PyList_New(def_in_num);\n  CHECK(py_list);\n\n  FOR_RANGE(size_t, i, 0, def_in_num) {\n    PyObject* arg = nullptr;\n    const std::string& arg_name = op_def.input(i).name();\n    VLOG(3) << \"input arg_name \" << arg_name;\n    // Multiple inputs under one symbolic arg name are not supported.\n    int32_t index = 0;\n    TensorToNumpy(ctx->Tensor4ArgNameAndIndex(arg_name, index), &arg);\n    arg = PyArray_Return(reinterpret_cast<PyArrayObject*>(arg));\n    PyList_SetItem(py_list, i, arg);\n  }\n  *py_inputs = Py_BuildValue(\"(N)\", py_list);\n  CHECK(*py_inputs);\n}\n\nvoid GetPyOutputs(const UserOpDef& op_def, user_op::KernelComputeContext* ctx,\n                  PyObject* py_outputs) {\n  const size_t kernel_out_num = ctx->outputs().size();\n  const size_t def_out_num = op_def.output_size();\n  CHECK_EQ(kernel_out_num, def_out_num) << \"kernel output num \" << kernel_out_num\n                                        << \" not equal to definition output num \" << def_out_num;\n  if (PyList_Check(py_outputs)) {\n    FOR_RANGE(size_t, i, 0, def_out_num) {\n      const std::string& arg_name = op_def.output(i).name();\n      VLOG(3) << \"output arg_name \" << arg_name;\n      int32_t index = 0;\n      NumpyToTensor(PyList_GetItem(py_outputs, i), ctx->Tensor4ArgNameAndIndex(arg_name, index));\n    }\n  } else if (PyArray_Check(py_outputs)) {\n    const std::string& arg_name = ctx->outputs().at(0).first;\n    VLOG(3) << \"output arg_name \" << arg_name;\n    int32_t index = 0;\n    NumpyToTensor(py_outputs, ctx->Tensor4ArgNameAndIndex(arg_name, index));\n  } else {\n    LOG(FATAL) << \"Unexpected PyObject was returned: \" << Py_TYPE(py_outputs)->tp_name;\n  }\n}\n\n}  // namespace\n\n
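// Editor's note (illustrative, not part of the original source): PyRegisterKernels\n// is expected to be called exactly once with a Python dict mapping op module names\n// to Python modules; PyCompute later looks a module up by name and invokes its\n// \"forward\"/\"backward\" attribute on numpy inputs.\nvoid 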
PyRegisterKernels(PyObject* py_kernels) {\n  if (py_kernels_dic == nullptr) {\n    py_kernels_dic = py_kernels;\n    Py_INCREF(py_kernels_dic);\n  } else {\n    LOG(FATAL) << \"RegisterPyKernels should only be called once.\";\n  }\n}\n\nvoid PyCompute(user_op::KernelComputeContext* ctx, const std::string& py_func_name) {\n  const std::string& op_type_name = ctx->op_type_name();\n  const user_op::OpRegistryResult* val =\n      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type_name);\n  CHECK(val) << \"Op op_type_name \" << op_type_name << \" has no definition.\";\n  const UserOpDef& op_def = val->op_def;\n\n  // get GIL\n  PyGILState_STATE py_gil_st;\n  py_gil_st = PyGILState_Ensure();\n  // prepare for numpy c api\n  if (PyArray_API == nullptr) { _import_array(); }\n\n  PyObject *py_str, *py_module, *py_func;\n  PyObject *py_inputs, *py_outputs;\n\n  // get python kernel\n  // Strip a trailing \"_forward\"/\"_backward\" to recover the op module name,\n  // e.g. op_type_name \"foo_forward\" resolves to op module name \"foo\".\n  static const std::string forward_suffix = \"_forward\";\n  static const std::string backward_suffix = \"_backward\";\n  std::string op_module_name = op_type_name;\n  if (op_type_name.size() > forward_suffix.size()\n      && op_type_name.rfind(forward_suffix) == (op_type_name.size() - forward_suffix.size())) {\n    op_module_name = op_type_name.substr(0, op_type_name.size() - forward_suffix.size());\n  }\n  if (op_type_name.size() > backward_suffix.size()\n      && op_type_name.rfind(backward_suffix) == (op_type_name.size() - backward_suffix.size())) {\n    op_module_name = op_type_name.substr(0, op_type_name.size() - backward_suffix.size());\n  }\n  py_str = PyUnicode_DecodeFSDefault(op_module_name.c_str());\n  CHECK(py_kernels_dic) << \"py_kernels_dic should not be nullptr.\";\n  py_module = PyDict_GetItem(py_kernels_dic, py_str);\n  CHECK(py_module) << op_module_name << \" has no python kernel.\";\n  Py_DECREF(py_str);\n\n  // get func\n  py_func = PyObject_GetAttrString(py_module, py_func_name.c_str());\n  if (py_func == nullptr || !PyCallable_Check(py_func)) {\n    Py_DECREF(py_module);\n    PyErr_Print();\n    LOG(FATAL) << \"Python kernel \" << op_module_name << \" has no callable attribute \"\n               << py_func_name << \".\";\n  }\n\n  // get numpy input\n  MakePyInputs(op_def, ctx, &py_inputs);\n\n  // call func\n  py_outputs = PyEval_CallObject(py_func, py_inputs);\n  Py_DECREF(py_inputs);\n  if (py_outputs == nullptr) {\n    PyErr_Print();\n    LOG(FATAL) << \"Python kernel function \" << py_func_name << \" raised an exception.\";\n  }\n\n  // get numpy output\n  GetPyOutputs(op_def, ctx, py_outputs);\n\n  Py_XDECREF(py_func);\n  Py_DECREF(py_outputs);\n\n  // release GIL\n  PyGILState_Release(py_gil_st);\n}\n\n}  // namespace pyext\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/extension/python/py_compute.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_EXTENSION_PYTHON_PY_COMPUTE_H_\n#define ONEFLOW_EXTENSION_PYTHON_PY_COMPUTE_H_\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\nnamespace pyext {\nvoid PyRegisterKernels(PyObject* py_kernels);\nvoid PyCompute(user_op::KernelComputeContext* ctx, const std::string& py_func_name);\n}  // namespace pyext\n}  // namespace oneflow\n\n#endif  // ONEFLOW_EXTENSION_PYTHON_PY_COMPUTE_H_\n"
  },
  {
    "path": "oneflow/extension/python/py_kernel_caller.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/extension/python/py_kernel_caller.h\"\n#include \"oneflow/extension/python/py_compute.h\"\n\nnamespace oneflow {\nvoid PyForwardKernel::Compute(user_op::KernelComputeContext* ctx) const {\n  ::oneflow::pyext::PyCompute(ctx, \"forward\");\n}\n\nvoid PyBackwardKernel::Compute(user_op::KernelComputeContext* ctx) const {\n  ::oneflow::pyext::PyCompute(ctx, \"backward\");\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/extension/python/py_kernel_caller.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_EXTENSION_PYTHON_PY_KERNEL_CALLER_H_\n#define ONEFLOW_EXTENSION_PYTHON_PY_KERNEL_CALLER_H_\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\nclass PyForwardKernel final : public user_op::OpKernel {\n public:\n  PyForwardKernel() = default;\n  ~PyForwardKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nclass PyBackwardKernel final : public user_op::OpKernel {\n public:\n  PyBackwardKernel() = default;\n  ~PyBackwardKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_EXTENSION_PYTHON_PY_KERNEL_CALLER_H_\n"
  },
  {
    "path": "oneflow/extension/python/py_kernel_registry.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/extension/python/py_kernel_registry.h\"\n#include \"oneflow/extension/python/py_compute.h\"\n#include \"oneflow/extension/python/py_kernel_caller.h\"\n\nnamespace oneflow {\nnamespace pyext {\n\nMaybe<void> RegisterPyKernelCaller(const std::string& op_module_name) {\n  // register python op kernel\n  auto reg = user_op::UserOpRegistryMgr::Get()\n                 .CheckAndGetOpKernelRegistry(op_module_name + \"_forward\")\n                 .SetCreateFn<PyForwardKernel>()\n                 .SetIsMatchedHob(((user_op::HobDeviceType() == DeviceType::kCPU)\n                                   && (user_op::HobDeviceSubTag() == \"py\")));\n  JUST(user_op::UserOpRegistryMgr::Get().Register(JUST(reg.Finish()).GetResult()));\n  // register python grad op kernel\n  auto grad_reg = user_op::UserOpRegistryMgr::Get()\n                      .CheckAndGetOpKernelRegistry(op_module_name + \"_backward\")\n                      .SetCreateFn<PyBackwardKernel>()\n                      .SetIsMatchedHob(((user_op::HobDeviceType() == DeviceType::kCPU)\n                                        && (user_op::HobDeviceSubTag() == \"py\")));\n  JUST(user_op::UserOpRegistryMgr::Get().Register(JUST(grad_reg.Finish()).GetResult()));\n  return Maybe<void>::Ok();\n}\n\nvoid RegisterPyKernels(PyObject* py_kernels) { PyRegisterKernels(py_kernels); }\n\n}  // namespace pyext\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/extension/python/py_kernel_registry.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_EXTENSION_PYTHON_PY_KERNEL_REGISTRY_H_\n#define ONEFLOW_EXTENSION_PYTHON_PY_KERNEL_REGISTRY_H_\n\n#include <string>\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include \"oneflow/core/common/maybe.h\"\n\nnamespace oneflow {\nnamespace pyext {\nMaybe<void> RegisterPyKernelCaller(const std::string& op_module_name);\nvoid RegisterPyKernels(PyObject* py_kernels);\n}  // namespace pyext\n}  // namespace oneflow\n\n#endif  // ONEFLOW_EXTENSION_PYTHON_PY_KERNEL_REGISTRY_H_\n"
  },
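  {
    "path": "oneflow/extension/python/py_kernel_registry_example.md",
    "content": "A minimal sketch of how `RegisterPyKernelCaller` is expected to be called. The op module name `my_op` and the helper `RegisterMyOpKernels` are hypothetical; the sketch assumes the `my_op_forward` and `my_op_backward` user ops are registered elsewhere, since `RegisterPyKernelCaller` only attaches the Python kernel callers (CPU device, `py` sub-tag) to those two ops.\n\n```cpp\n#include \"oneflow/extension/python/py_kernel_registry.h\"\n\nnamespace oneflow {\n\n// Hypothetical registration helper for a Python-defined op module named \"my_op\".\nMaybe<void> RegisterMyOpKernels() {\n  // Registers PyForwardKernel for \"my_op_forward\" and PyBackwardKernel for \"my_op_backward\".\n  JUST(pyext::RegisterPyKernelCaller(\"my_op\"));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n```\n"
  },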
  {
    "path": "oneflow/extension/stack/foreign_stack_getter.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_EXTENSION_STACK_STACK_GETTER_H_\n#define ONEFLOW_EXTENSION_STACK_STACK_GETTER_H_\n\n#include <cstdint>\n#include <utility>\n#include \"oneflow/core/common/thread_local_guard.h\"\n\nnamespace oneflow {\n\nclass Frame {\n public:\n  virtual ~Frame() = default;\n};\n\nusing ForeignFrameThreadLocalGuard = ThreadLocalGuard<std::shared_ptr<Frame>>;\n\nclass ForeignStackGetter {\n public:\n  virtual ~ForeignStackGetter() = default;\n  virtual std::shared_ptr<Frame> GetCurrentFrame() const = 0;\n  virtual std::string GetFormattedStack(std::shared_ptr<Frame> frame) const = 0;\n};\n}  // namespace oneflow\n\n#endif  // ONEFLOW_EXTENSION_STACK_STACK_GETTER_H_\n"
  },
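  {
    "path": "oneflow/extension/stack/foreign_stack_getter_example.md",
    "content": "A minimal sketch of the capture/format split behind `ForeignStackGetter`: `GetCurrentFrame` is cheap and called eagerly, while `GetFormattedStack` is expensive and deferred until an error is actually reported. The `Singleton<ForeignStackGetter>` lookup mirrors how `stack_getter.cpp` installs the getter; `FormatCurrentStack` is a hypothetical helper.\n\n```cpp\n#include <memory>\n#include <string>\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/extension/stack/foreign_stack_getter.h\"\n\nnamespace oneflow {\n\n// Hypothetical helper: snapshot the current foreign (Python) stack and format it.\nstd::string FormatCurrentStack() {\n  auto* getter = Singleton<ForeignStackGetter>::Get();\n  if (getter == nullptr) { return \"  <unknown>\\\\n\"; }\n  // Cheap: just pins the current frame chain.\n  std::shared_ptr<Frame> frame = getter->GetCurrentFrame();\n  // Expensive: reads source files and renders the trace.\n  return getter->GetFormattedStack(frame);\n}\n\n}  // namespace oneflow\n```\n"
  },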
  {
    "path": "oneflow/extension/stack/python/custom_eval_frame.c",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n// see https://bugs.python.org/issue23644 for why this file is written\n// as .c instead of .cpp\n\n#include \"oneflow/extension/stack/python/custom_eval_frame.h\"\n\n#define PY_SSIZE_T_CLEAN\n#include <Python.h>\n#undef _PyGC_FINALIZED\n#include <frameobject.h>\n#include <pystate.h>\n// see https://bugs.python.org/issue35886\n#if PY_VERSION_HEX >= 0x03080000\n#define Py_BUILD_CORE\n#include \"internal/pycore_pystate.h\"\n#undef Py_BUILD_CORE\n#endif\n\ninline static void EnableCustomEvalFrame(PyThreadState* tstate, _PyFrameEvalFunction eval_func) {\n#if PY_VERSION_HEX >= 0x03090000\n  if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) != eval_func) {\n    _PyInterpreterState_SetEvalFrameFunc(tstate->interp, eval_func);\n  }\n#else\n  if (tstate->interp->eval_frame != eval_func) {\n    tstate->interp->eval_frame = eval_func;\n  }\n#endif\n}\n\nvoid EnableCustomEvalFrameForCurrentThread(PyFrameEvalFunc eval_func) {\n  return EnableCustomEvalFrame(PyThreadState_GET(), eval_func);\n}\n"
  },
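  {
    "path": "oneflow/extension/stack/python/custom_eval_frame_example.md",
    "content": "A minimal sketch of a hook that could be installed through `EnableCustomEvalFrameForCurrentThread`. `MyEvalFrame` and `InstallMyEvalFrame` are hypothetical; only the pre-3.9 signature is sketched for brevity (the header selects the right `PyFrameEvalFunc` signature via `PY_VERSION_HEX`), and `_PyEval_EvalFrameDefault` is the default evaluator defined by PEP 523.\n\n```cpp\n#include <Python.h>\n#include \"oneflow/extension/stack/python/custom_eval_frame.h\"\n\n#if PY_VERSION_HEX < 0x03090000\n// Hypothetical hook: observe every frame, then delegate to the default evaluator.\nstatic PyObject* MyEvalFrame(struct _frame* frame, int throw_flag) {\n  // ... record `frame` here, as PyStackGetter::RecordAndEvalFrame does ...\n  return _PyEval_EvalFrameDefault(frame, throw_flag);\n}\n\n// Installed once per thread, e.g. during module initialization.\nvoid InstallMyEvalFrame() { EnableCustomEvalFrameForCurrentThread(&MyEvalFrame); }\n#endif\n```\n"
  },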
  {
    "path": "oneflow/extension/stack/python/custom_eval_frame.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_API_PYTHON_CUSTOM_EVAL_FRAME_H_\n#define ONEFLOW_API_PYTHON_CUSTOM_EVAL_FRAME_H_\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n\n#if PY_VERSION_HEX >= 0x03090000\ntypedef PyObject* (*PyFrameEvalFunc)(struct _ts*, struct _frame*, int);\n#else\ntypedef PyObject* (*PyFrameEvalFunc)(struct _frame*, int);\n#endif\nvoid EnableCustomEvalFrameForCurrentThread(PyFrameEvalFunc eval_func);\n\n#ifdef __cplusplus\n}\n#endif\n\n#endif  // ONEFLOW_API_PYTHON_CUSTOM_EVAL_FRAME_H_\n"
  },
  {
    "path": "oneflow/extension/stack/python/stack_getter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/extension/stack/python/stack_getter.h\"\n\n#include <utility>\n\n#include \"fmt/core.h\"\n#include \"fmt/color.h\"\n#include \"fmt/ostream.h\"\n#include \"pybind11/pybind11.h\"\n\n#if PY_VERSION_HEX >= 0x030b0000\n#ifndef Py_BUILD_CORE\n#define Py_BUILD_CORE 1\n#endif\n#include \"internal/pycore_frame.h\"\n#endif\n\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/framework/shut_down_util.h\"\n#include \"oneflow/core/common/foreign_lock_helper.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n#include \"oneflow/core/job/graph_scope_vars.h\"\n#include \"oneflow/extension/stack/foreign_stack_getter.h\"\n#include \"oneflow/extension/stack/python/custom_eval_frame.h\"\n\nnamespace py = pybind11;\n\nnamespace oneflow {\n\nnamespace {\nstd::string PyUnicodeToStdString(const PyObject* py_str) {\n  return PyBytes_AsString(PyUnicode_AsEncodedString(const_cast<PyObject*>(py_str), \"utf-8\", \"~E~\"));\n}\n#if PY_VERSION_HEX < 0x03090000\nPyCodeObject* PyFrame_GetCode(PyFrameObject* frame) {\n  assert(frame != NULL);\n  PyCodeObject* code = frame->f_code;\n  assert(code != NULL);\n  Py_INCREF(code);\n  return code;\n}\n#endif\n}  // namespace\n\nclass PyFrame final : public Frame {\n public:\n  // There is no need to increase the reference count of these cpython objects\n  // because they must be alive during the lifetime of `PyFrame`.\n  PyFrame(PyFrameObject* frame, std::shared_ptr<PyFrame> back)\n      : cpython_frame(frame), lineno(0), back(std::move(back)) {\n    PyCodeObject* code = PyFrame_GetCode(frame);\n    filename = code->co_filename;\n    funcname = code->co_name;\n    Py_DECREF(code);\n  }\n  ~PyFrame() = default;\n  OF_DISALLOW_COPY_AND_MOVE(PyFrame);\n\n  PyObject* filename;\n  PyObject* funcname;\n  PyFrameObject* cpython_frame;\n  int lineno;\n  std::shared_ptr<PyFrame> back;\n};\n\nclass PyStackGetter final : public ForeignStackGetter {\n public:\n  PyStackGetter() {\n    auto* frame = PyEval_GetFrame();\n    // Get the first frame. 
It assumes `import oneflow` is called in global scope,\n    while (frame->f_back != nullptr) { frame = frame->f_back; }\n    current_frame_ = std::make_shared<PyFrame>(frame, nullptr);\n  }\n  // indended to be called in main thread.\n  std::shared_ptr<Frame> GetCurrentFrame() const override {\n    if (IsShuttingDown() || !current_frame_) { return nullptr; }\n    // See `RecordAndEvalFrame` for documentation.\n    current_frame_->lineno = PyFrame_GetLineNumber(current_frame_->cpython_frame);\n    return current_frame_;\n  }\n\n  // bad path, performance is not a concern.\n  std::string GetFormattedStack(std::shared_ptr<Frame> frame) const override {\n    if (frame == nullptr) { return \"  <unknown>\\n\"; }\n    std::string buffer;\n    const auto* py_frame = dynamic_cast<const PyFrame*>(frame.get());\n    py::gil_scoped_acquire acquire;\n    while (py_frame != nullptr) {\n      const auto& lineno = py_frame->lineno;\n      const std::string line_text = [&]() -> std::string {\n        std::string line_text;\n        std::ifstream ifs(PyUnicodeToStdString(py_frame->filename));\n        if (!ifs.is_open()) { return \"<unknown>\"; }\n        for (int j = 0; j < lineno; ++j) { std::getline(ifs, line_text); }\n        line_text.erase(line_text.find_last_not_of(' ') + 1);  // suffixing spaces\n        line_text.erase(0, line_text.find_first_not_of(' '));  // prefixing spaces\n        return line_text;\n      }();\n      // immitate python's stack trace format\n      fmt::format_to(std::back_inserter(buffer), \"  File \\\"{}\\\", line {}, in {}\\n    {}\\n\",\n                     PyUnicodeToStdString(py_frame->filename), lineno,\n                     PyUnicodeToStdString(py_frame->funcname), line_text);\n      py_frame = py_frame->back.get();\n    }\n    return buffer;\n  };\n\n#if PY_VERSION_HEX >= 0x03090000\n  PyObject* RecordAndEvalFrame(PyThreadState* tstate, PyFrameObject* frame,\n#else\n  PyObject* RecordAndEvalFrame(PyFrameObject* frame,\n#endif\n                               int throw_flag) {\n    // Example:\n    // >> def f(): # Line 1\n    // >>   pass   # Line 2\n    // >> f()      # Line 3\n    //\n    // When we call f(), `RecordAndEvalFrame` is triggered and the `frame`\n    // argument is the frame of function `f`, which is Line 1 at that time. 
It is not\n    // useful to us, but we can adjust it in `GetCurrentFrame` method.\n    //\n    PushFrame(frame);\n#if PY_VERSION_HEX >= 0x03090000\n    if (tstate == NULL) { tstate = PyThreadState_GET(); }\n#if PY_VERSION_HEX >= 0x030b0000\n    PyObject* ret = _PyEval_EvalFrameDefault(tstate, frame->f_frame, throw_flag);\n#else\n    PyObject* ret = _PyEval_EvalFrameDefault(tstate, frame, throw_flag);\n#endif\n#else\n    PyObject* ret = _PyEval_EvalFrameDefault(frame, throw_flag);\n#endif\n    PopFrame();\n    return ret;\n  }\n\n private:\n  std::shared_ptr<PyFrame> current_frame_;\n\n  void PushFrame(PyFrameObject* frame) {\n    if (auto* f = frame->f_back) { current_frame_->lineno = PyFrame_GetLineNumber(f); }\n    current_frame_ = std::make_shared<PyFrame>(frame, current_frame_);\n  }\n  void PopFrame() {\n    CHECK_NOTNULL(current_frame_);\n    current_frame_ = current_frame_->back;\n  }\n};\n\n#if PY_VERSION_HEX >= 0x03090000\nPyObject* RecordAndEvalFrame(PyThreadState* tstate, PyFrameObject* frame,\n#else\nPyObject* RecordAndEvalFrame(PyFrameObject* frame,\n#endif\n                             int throw_flag) {\n  using namespace oneflow;\n  return dynamic_cast<PyStackGetter*>(Singleton<ForeignStackGetter>::Get())\n#if PY_VERSION_HEX >= 0x03090000\n      ->RecordAndEvalFrame(tstate, frame, throw_flag);\n#else\n      ->RecordAndEvalFrame(frame, throw_flag);\n#endif\n}\n\nvoid RegisterPyStackGetter() {\n  if (!IsPythonStackGetterEnabled()) { return; }\n  Singleton<ForeignStackGetter>::Delete();\n  Singleton<ForeignStackGetter>::SetAllocated(new PyStackGetter());\n  EnableCustomEvalFrameForCurrentThread(&RecordAndEvalFrame);\n}\n\nnamespace {\n\n// get a formatted stack frame representation\nstd::string get_python_frame_str_repr(PyFrameObject* frame) {\n  if (frame == NULL) return \"\";\n  std::string buffer;\n  PyCodeObject* code = PyFrame_GetCode(frame);\n  std::string file_name = PyUnicodeToStdString(code->co_filename);\n  std::string code_name = PyUnicodeToStdString(code->co_name);\n  Py_DECREF(code);\n  int line_number = PyFrame_GetLineNumber(frame);\n\n  fmt::format_to(std::back_inserter(buffer), \"File \\\"{}\\\", line {}, in {}\", file_name, line_number,\n                 code_name);\n\n  std::string line_text;\n  const bool debug_mode = GetGraphDebugMode() || IsInDebugMode();\n  if (debug_mode) {\n    const auto& GetCurSrc = [&file_name, line_number]() -> std::string {\n      std::string line_text;\n      std::ifstream ifs(file_name);\n      if (!ifs.is_open()) { return \"<unknown>\"; }\n      for (int j = 0; j < line_number; ++j) { std::getline(ifs, line_text); }\n      line_text.erase(line_text.find_last_not_of(' ') + 1);  // suffixing spaces\n      line_text.erase(0, line_text.find_first_not_of(' '));  // prefixing spaces\n      return line_text;\n    };\n    line_text = GetCurSrc();\n    buffer += \", source < \" + line_text + \" >; \";\n  } else {\n    buffer += \"; \";\n  }\n\n  return buffer;\n}\n\nbool check_if_python_file_should_be_filtered(const std::string& path) {\n  const auto& paths_to_be_kept = GetPythonPathsToBeKeptForDebugging();\n  for (int i = 0; i < paths_to_be_kept.size(); ++i) {\n    const std::string& path_to_be_kept = paths_to_be_kept[i];\n    if (path.size() > path_to_be_kept.size()) {\n      if (path.substr(0, path_to_be_kept.size()) == path_to_be_kept) { return false; }\n    }\n  }\n\n  const auto& paths_to_be_filtered = GetPythonPathsToBeFilteredForDebugging();\n  for (int i = 0; i < paths_to_be_filtered.size(); ++i) {\n    const std::string& 
path_to_be_filtered = paths_to_be_filtered[i];\n    if (path.size() > path_to_be_filtered.size()) {\n      if (path.substr(0, path_to_be_filtered.size()) == path_to_be_filtered) { return true; }\n    }\n  }\n\n  return false;\n}\n\nbool check_if_frame_should_be_filtered(PyFrameObject* frame) {\n  // `PyFrame_GetCode` returns a new reference; release it after use.\n  PyCodeObject* code = PyFrame_GetCode(frame);\n  std::string frame_file_name = PyUnicodeToStdString(code->co_filename);\n  Py_DECREF(code);\n  return check_if_python_file_should_be_filtered(frame_file_name);\n}\n\nbool check_if_should_skip_this_frame(PyFrameObject* frame) {\n  const bool only_user_py_stack = GetGraphDebugOnlyUserPyStack();\n  if (only_user_py_stack) { return check_if_frame_should_be_filtered(frame); }\n  return false;\n}\n\nint32_t get_cur_stack_depth() {\n  int32_t current_stack_depth = 0;\n  PyFrameObject* f = PyEval_GetFrame();\n  while (f) {\n    if (check_if_should_skip_this_frame(f)) {\n      f = f->f_back;\n      continue;\n    }\n\n    current_stack_depth++;\n    f = f->f_back;\n  }\n  return current_stack_depth;\n}\n\nstd::string get_cur_frame_stack_str() {\n  const int32_t max_stack_depth = GetGraphDebugMaxPyStackDepth();\n  std::string cur_f_str;\n  PyFrameObject* cur_frame = PyEval_GetFrame();\n\n  int i = 0;\n  while (i < max_stack_depth) {\n    if (cur_frame == NULL) break;\n\n    if (check_if_should_skip_this_frame(cur_frame)) {\n      cur_frame = cur_frame->f_back;\n      continue;\n    }\n    cur_f_str += get_python_frame_str_repr(cur_frame);\n    cur_frame = cur_frame->f_back;\n    i++;\n  }\n\n  // show how many stack frames remain to be shown in debug mode\n  const bool debug_mode = GetGraphDebugMode() || IsInDebugMode();\n  if (debug_mode) {\n    const int32_t current_stack_depth = get_cur_stack_depth();\n    if (current_stack_depth > max_stack_depth) {\n      cur_f_str += \"... \" + std::to_string(current_stack_depth - max_stack_depth) + \" more\";\n    }\n  } else {\n    if (cur_frame != NULL) { cur_f_str += \" ... more\"; }\n  }\n\n  return cur_f_str;\n}\n\n}  // namespace\n\nPythonFrameGuard::PythonFrameGuard() {\n  if (OF_PREDICT_FALSE(LazyMode::is_enabled())) {\n    prev_frame_str_ = DispatchFrame::get_str();\n    DispatchFrame::set_str(get_cur_frame_stack_str());\n  }\n}\nPythonFrameGuard::~PythonFrameGuard() {\n  if (OF_PREDICT_FALSE(LazyMode::is_enabled())) { DispatchFrame::set_str(prev_frame_str_); }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/extension/stack/python/stack_getter.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_EXTENSION_STACK_PYTHON_STACK_GETTER\n#define ONEFLOW_EXTENSION_STACK_PYTHON_STACK_GETTER\n\n#include <string>\nnamespace oneflow {\nvoid RegisterPyStackGetter();\n\nclass PythonFrameGuard {\n public:\n  PythonFrameGuard();\n  ~PythonFrameGuard();\n\n private:\n  std::string prev_frame_str_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_EXTENSION_STACK_PYTHON_STACK_GETTER\n"
  },
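  {
    "path": "oneflow/extension/stack/python/stack_getter_example.md",
    "content": "A minimal sketch of the intended call pattern for this header: `RegisterPyStackGetter()` is invoked once (with the GIL held) during interpreter-side initialization, and `PythonFrameGuard` brackets a lazy-mode dispatch so that `DispatchFrame` carries the current Python stack string for that op. `InitPythonBindings` and `DispatchOneOp` are hypothetical.\n\n```cpp\n#include \"oneflow/extension/stack/python/stack_getter.h\"\n\nnamespace oneflow {\n\nvoid InitPythonBindings() {\n  // No-op unless the Python stack getter is enabled by the environment.\n  RegisterPyStackGetter();\n}\n\nvoid DispatchOneOp() {\n  // Saves the previous dispatch-frame string and installs the current one\n  // (only when LazyMode is enabled); restores it on scope exit.\n  PythonFrameGuard guard;\n  // ... build and dispatch the op here ...\n}\n\n}  // namespace oneflow\n```\n"
  },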
  {
    "path": "oneflow/extension/stack/stacktrace.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n/*\n * backward.hpp\n * Copyright 2013 Google Inc. All Rights Reserved.\n *\n * Permission is hereby granted, free of charge, to any person obtaining a copy\n * of this software and associated documentation files (the \"Software\"), to deal\n * in the Software without restriction, including without limitation the rights\n * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n * copies of the Software, and to permit persons to whom the Software is\n * furnished to do so, subject to the following conditions:\n *\n * The above copyright notice and this permission notice shall be included in\n * all copies or substantial portions of the Software.\n *\n * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n * SOFTWARE.\n */\n\n#ifndef H_6B9572DA_A64B_49E6_B234_051480991C89\n#define H_6B9572DA_A64B_49E6_B234_051480991C89\n\n#ifndef __cplusplus\n#error \"It's not going to compile without a C++ compiler...\"\n#endif\n\n#if defined(BACKWARD_CXX11)\n#elif defined(BACKWARD_CXX98)\n#else\n#if __cplusplus >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1800)\n#define BACKWARD_CXX11\n#define BACKWARD_ATLEAST_CXX11\n#define BACKWARD_ATLEAST_CXX98\n#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)\n#define BACKWARD_ATLEAST_CXX17\n#endif\n#else\n#define BACKWARD_CXX98\n#define BACKWARD_ATLEAST_CXX98\n#endif\n#endif\n\n// You can define one of the following (or leave it to the auto-detection):\n//\n// #define BACKWARD_SYSTEM_LINUX\n//\t- specialization for linux\n//\n// #define BACKWARD_SYSTEM_DARWIN\n//\t- specialization for Mac OS X 10.5 and later.\n//\n// #define BACKWARD_SYSTEM_WINDOWS\n//  - specialization for Windows (Clang 9 and MSVC2017)\n//\n// #define BACKWARD_SYSTEM_UNKNOWN\n//\t- placebo implementation, does nothing.\n//\n#if defined(BACKWARD_SYSTEM_LINUX)\n#elif defined(BACKWARD_SYSTEM_DARWIN)\n#elif defined(BACKWARD_SYSTEM_UNKNOWN)\n#elif defined(BACKWARD_SYSTEM_WINDOWS)\n#else\n#if defined(__linux) || defined(__linux__)\n#define BACKWARD_SYSTEM_LINUX\n#elif defined(__APPLE__)\n#define BACKWARD_SYSTEM_DARWIN\n#elif defined(_WIN32)\n#define BACKWARD_SYSTEM_WINDOWS\n#else\n#define BACKWARD_SYSTEM_UNKNOWN\n#endif\n#endif\n\n#define NOINLINE __attribute__((noinline))\n\n#include <algorithm>\n#include <cctype>\n#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n#include <filesystem>\n#include <fstream>\n#include <iomanip>\n#include <iostream>\n#include <limits>\n#include <new>\n#include <sstream>\n#include 
<streambuf>\n#include <string>\n#include <vector>\n#include <exception>\n#include <iterator>\n#include <regex>\n\n#if defined(BACKWARD_SYSTEM_LINUX)\n\n// On linux, backtrace can back-trace or \"walk\" the stack using the following\n// libraries:\n//\n// #define BACKWARD_HAS_UNWIND 1\n//  - unwind comes from libgcc, but I saw an equivalent inside clang itself.\n//  - with unwind, the stacktrace is as accurate as it can possibly be, since\n//  this is used by the C++ runtime in gcc/clang for stack unwinding on\n//  exception.\n//  - normally libgcc is already linked to your program by default.\n//\n// #define BACKWARD_HAS_LIBUNWIND 1\n//  - libunwind provides, in some cases, a more accurate stacktrace as it knows\n//  to decode signal handler frames and lets us edit the context registers when\n//  unwinding, allowing stack traces over bad function references.\n//\n// #define BACKWARD_HAS_BACKTRACE == 1\n//  - backtrace seems to be a little bit more portable than libunwind, but on\n//  linux, it uses unwind anyway, but abstract away a tiny information that is\n//  sadly really important in order to get perfectly accurate stack traces.\n//  - backtrace is part of the (e)glib library.\n//\n// The default is:\n// #define BACKWARD_HAS_UNWIND == 1\n//\n// Note that only one of the define should be set to 1 at a time.\n//\n#if BACKWARD_HAS_UNWIND == 1\n#elif BACKWARD_HAS_LIBUNWIND == 1\n#elif BACKWARD_HAS_BACKTRACE == 1\n#else\n#undef BACKWARD_HAS_UNWIND\n#define BACKWARD_HAS_UNWIND 1\n#undef BACKWARD_HAS_LIBUNWIND\n#define BACKWARD_HAS_LIBUNWIND 0\n#undef BACKWARD_HAS_BACKTRACE\n#define BACKWARD_HAS_BACKTRACE 0\n#endif\n\n// On linux, backward can extract detailed information about a stack trace\n// using one of the following libraries:\n//\n// #define BACKWARD_HAS_DW 1\n//  - libdw gives you the most juicy details out of your stack traces:\n//    - object filename\n//    - function name\n//    - source filename\n//    - line and column numbers\n//    - source code snippet (assuming the file is accessible)\n//    - variable names (if not optimized out)\n//    - variable values (not supported by backward-cpp)\n//  - You need to link with the lib \"dw\":\n//    - apt-get install libdw-dev\n//    - g++/clang++ -ldw ...\n//\n// #define BACKWARD_HAS_BFD 1\n//  - With libbfd, you get a fair amount of details:\n//    - object filename\n//    - function name\n//    - source filename\n//    - line numbers\n//    - source code snippet (assuming the file is accessible)\n//  - You need to link with the lib \"bfd\":\n//    - apt-get install binutils-dev\n//    - g++/clang++ -lbfd ...\n//\n// #define BACKWARD_HAS_DWARF 1\n//  - libdwarf gives you the most juicy details out of your stack traces:\n//    - object filename\n//    - function name\n//    - source filename\n//    - line and column numbers\n//    - source code snippet (assuming the file is accessible)\n//    - variable names (if not optimized out)\n//    - variable values (not supported by backward-cpp)\n//  - You need to link with the lib \"dwarf\":\n//    - apt-get install libdwarf-dev\n//    - g++/clang++ -ldwarf ...\n//\n// #define BACKWARD_HAS_BACKTRACE_SYMBOL 1\n//  - backtrace provides minimal details for a stack trace:\n//    - object filename\n//    - function name\n//  - backtrace is part of the (e)glib library.\n//\n// The default is:\n// #define BACKWARD_HAS_BACKTRACE_SYMBOL == 1\n//\n// Note that only one of the define should be set to 1 at a time.\n//\n#if BACKWARD_HAS_DW == 1\n#elif BACKWARD_HAS_BFD == 1\n#elif BACKWARD_HAS_DWARF == 
1\n#elif BACKWARD_HAS_BACKTRACE_SYMBOL == 1\n#else\n#undef BACKWARD_HAS_DW\n#define BACKWARD_HAS_DW 0\n#undef BACKWARD_HAS_BFD\n#define BACKWARD_HAS_BFD 0\n#undef BACKWARD_HAS_DWARF\n#define BACKWARD_HAS_DWARF 0\n#undef BACKWARD_HAS_BACKTRACE_SYMBOL\n#define BACKWARD_HAS_BACKTRACE_SYMBOL 1\n#endif\n\n#include <cxxabi.h>\n#include <fcntl.h>\n#ifdef __ANDROID__\n// Old Android API levels define _Unwind_Ptr in both link.h and unwind.h.\n// Rename the one in link.h as we are not going to be using it.\n#define _Unwind_Ptr _Unwind_Ptr_Custom\n#include <link.h>\n#undef _Unwind_Ptr\n#else\n#include <link.h>\n#endif\n#if defined(__ppc__) || defined(__powerpc) || defined(__powerpc__) || defined(__POWERPC__)\n// Linux kernel header required for the struct pt_regs definition\n// to access the NIP (Next Instruction Pointer) register value\n#include <asm/ptrace.h>\n#endif\n#include <signal.h>\n#include <sys/stat.h>\n#include <syscall.h>\n#include <unistd.h>\n#ifndef _GNU_SOURCE\n#define _GNU_SOURCE\n#include <dlfcn.h>\n#undef _GNU_SOURCE\n#else\n#include <dlfcn.h>\n#endif\n\n#if BACKWARD_HAS_BFD == 1\n//              NOTE: defining PACKAGE{,_VERSION} is required before including\n//                    bfd.h on some platforms, see also:\n//                    https://sourceware.org/bugzilla/show_bug.cgi?id=14243\n#ifndef PACKAGE\n#define PACKAGE\n#endif\n#ifndef PACKAGE_VERSION\n#define PACKAGE_VERSION\n#endif\n#include <bfd.h>\n#endif\n\n#if BACKWARD_HAS_DW == 1\n#include <dwarf.h>\n#include <elfutils/libdw.h>\n#include <elfutils/libdwfl.h>\n#endif\n\n#if BACKWARD_HAS_DWARF == 1\n#include <algorithm>\n#include <dwarf.h>\n#include <libdwarf.h>\n#include <libelf.h>\n#include <map>\n#endif\n\n#if (BACKWARD_HAS_BACKTRACE == 1) || (BACKWARD_HAS_BACKTRACE_SYMBOL == 1)\n// then we shall rely on backtrace\n#include <execinfo.h>\n#endif\n\n#endif  // defined(BACKWARD_SYSTEM_LINUX)\n\n#if defined(BACKWARD_SYSTEM_DARWIN)\n// On Darwin, backtrace can back-trace or \"walk\" the stack using the following\n// libraries:\n//\n// #define BACKWARD_HAS_UNWIND 1\n//  - unwind comes from libgcc, but I saw an equivalent inside clang itself.\n//  - with unwind, the stacktrace is as accurate as it can possibly be, since\n//  this is used by the C++ runtime in gcc/clang for stack unwinding on\n//  exception.\n//  - normally libgcc is already linked to your program by default.\n//\n// #define BACKWARD_HAS_LIBUNWIND 1\n//  - libunwind comes from clang, which implements an API-compatible version.\n//  - libunwind provides, in some cases, a more accurate stacktrace as it knows\n//  to decode signal handler frames and lets us edit the context registers when\n//  unwinding, allowing stack traces over bad function references.\n//\n// #define BACKWARD_HAS_BACKTRACE == 1\n//  - backtrace is available by default, though it does not produce as much\n//  information as another library might.\n//\n// The default is:\n// #define BACKWARD_HAS_UNWIND == 1\n//\n// Note that only one of these defines should be set to 1 at a time.\n//\n#if BACKWARD_HAS_UNWIND == 1\n#elif BACKWARD_HAS_BACKTRACE == 1\n#elif BACKWARD_HAS_LIBUNWIND == 1\n#else\n#undef BACKWARD_HAS_UNWIND\n#define BACKWARD_HAS_UNWIND 1\n#undef BACKWARD_HAS_BACKTRACE\n#define BACKWARD_HAS_BACKTRACE 0\n#undef BACKWARD_HAS_LIBUNWIND\n#define BACKWARD_HAS_LIBUNWIND 0\n#endif\n\n// On Darwin, backward can extract detailed information about a stack trace\n// using one of the following libraries:\n//\n// #define BACKWARD_HAS_BACKTRACE_SYMBOL 1\n//  - backtrace provides minimal details 
for a stack trace:\n//    - object filename\n//    - function name\n//\n// The default is:\n// #define BACKWARD_HAS_BACKTRACE_SYMBOL == 1\n//\n#if BACKWARD_HAS_BACKTRACE_SYMBOL == 1\n#else\n#undef BACKWARD_HAS_BACKTRACE_SYMBOL\n#define BACKWARD_HAS_BACKTRACE_SYMBOL 1\n#endif\n\n#include <cxxabi.h>\n#include <fcntl.h>\n#include <pthread.h>\n#include <signal.h>\n#include <sys/stat.h>\n#include <unistd.h>\n\n#if (BACKWARD_HAS_BACKTRACE == 1) || (BACKWARD_HAS_BACKTRACE_SYMBOL == 1)\n#include <execinfo.h>\n#endif\n#endif  // defined(BACKWARD_SYSTEM_DARWIN)\n\n#if defined(BACKWARD_SYSTEM_WINDOWS)\n\n#include <condition_variable>\n#include <mutex>\n#include <thread>\n\n#include <basetsd.h>\n\n#ifdef _WIN64\ntypedef SSIZE_T ssize_t;\n#else\ntypedef int ssize_t;\n#endif\n\n#ifndef NOMINMAX\n#define NOMINMAX\n#endif\n#include <windows.h>\n#include <winnt.h>\n\n#include <psapi.h>\n#include <signal.h>\n\n#ifndef __clang__\n#undef NOINLINE\n#define NOINLINE __declspec(noinline)\n#endif\n\n#ifdef _MSC_VER\n#pragma comment(lib, \"psapi.lib\")\n#pragma comment(lib, \"dbghelp.lib\")\n#endif\n\n// Comment / packing is from stackoverflow:\n// https://stackoverflow.com/questions/6205981/windows-c-stack-trace-from-a-running-app/28276227#28276227\n// Some versions of imagehlp.dll lack the proper packing directives themselves\n// so we need to do it.\n#pragma pack(push, before_imagehlp, 8)\n#include <imagehlp.h>\n#pragma pack(pop, before_imagehlp)\n\n// TODO maybe these should be undefined somewhere else?\n#undef BACKWARD_HAS_UNWIND\n#undef BACKWARD_HAS_BACKTRACE\n#if BACKWARD_HAS_PDB_SYMBOL == 1\n#else\n#undef BACKWARD_HAS_PDB_SYMBOL\n#define BACKWARD_HAS_PDB_SYMBOL 1\n#endif\n\n#endif\n\n#if BACKWARD_HAS_UNWIND == 1\n\n#include <unwind.h>\n// while gcc's unwind.h defines something like that:\n//  extern _Unwind_Ptr _Unwind_GetIP (struct _Unwind_Context *);\n//  extern _Unwind_Ptr _Unwind_GetIPInfo (struct _Unwind_Context *, int *);\n//\n// clang's unwind.h defines something like this:\n//  uintptr_t _Unwind_GetIP(struct _Unwind_Context* __context);\n//\n// Even if the _Unwind_GetIPInfo can be linked to, it is not declared, worse we\n// cannot just redeclare it because clang's unwind.h doesn't define _Unwind_Ptr\n// anyway.\n//\n// Luckily we can play on the fact that the guard macros have a different name:\n#ifdef __CLANG_UNWIND_H\n// In fact, this function still comes from libgcc (on my different linux boxes,\n// clang links against libgcc).\n#include <inttypes.h>\nextern \"C\" uintptr_t _Unwind_GetIPInfo(_Unwind_Context*, int*);\n#endif\n\n#endif  // BACKWARD_HAS_UNWIND == 1\n\n#if BACKWARD_HAS_LIBUNWIND == 1\n#define UNW_LOCAL_ONLY\n#include <libunwind.h>\n#endif  // BACKWARD_HAS_LIBUNWIND == 1\n\n#ifdef BACKWARD_ATLEAST_CXX11\n#include <unordered_map>\n#include <utility>  // for std::swap\nnamespace backward {\nnamespace details {\ntemplate<typename K, typename V>\nstruct hashtable {\n  typedef std::unordered_map<K, V> type;\n};\nusing std::move;\n}  // namespace details\n}  // namespace backward\n#else  // NOT BACKWARD_ATLEAST_CXX11\n#define nullptr NULL\n#define override\n#include <map>\nnamespace backward {\nnamespace details {\ntemplate<typename K, typename V>\nstruct hashtable {\n  typedef std::map<K, V> type;\n};\ntemplate<typename T>\nconst T& move(const T& v) {\n  return v;\n}\ntemplate<typename T>\nT& move(T& v) {\n  return v;\n}\n}  // namespace details\n}  // namespace backward\n#endif  // BACKWARD_ATLEAST_CXX11\n\nnamespace backward {\nnamespace details {\n#if 
defined(BACKWARD_SYSTEM_WINDOWS)\nconst char kBackwardPathDelimiter[] = \";\";\n#else\nconst char kBackwardPathDelimiter[] = \":\";\n#endif\n}  // namespace details\n}  // namespace backward\n\nnamespace backward {\n\nnamespace system_tag {\nstruct linux_tag;  // seems that I cannot call that \"linux\" because the name\n// is already defined... so I am adding _tag everywhere.\nstruct darwin_tag;\nstruct windows_tag;\nstruct unknown_tag;\n\n#if defined(BACKWARD_SYSTEM_LINUX)\ntypedef linux_tag current_tag;\n#elif defined(BACKWARD_SYSTEM_DARWIN)\ntypedef darwin_tag current_tag;\n#elif defined(BACKWARD_SYSTEM_WINDOWS)\ntypedef windows_tag current_tag;\n#elif defined(BACKWARD_SYSTEM_UNKNOWN)\ntypedef unknown_tag current_tag;\n#else\n#error \"May I please get my system defines?\"\n#endif\n}  // namespace system_tag\n\nnamespace trace_resolver_tag {\n#if defined(BACKWARD_SYSTEM_LINUX)\nstruct libdw;\nstruct libbfd;\nstruct libdwarf;\nstruct backtrace_symbol;\n\n#if BACKWARD_HAS_DW == 1\ntypedef libdw current;\n#elif BACKWARD_HAS_BFD == 1\ntypedef libbfd current;\n#elif BACKWARD_HAS_DWARF == 1\ntypedef libdwarf current;\n#elif BACKWARD_HAS_BACKTRACE_SYMBOL == 1\ntypedef backtrace_symbol current;\n#else\n#error \"You shall not pass, until you know what you want.\"\n#endif\n#elif defined(BACKWARD_SYSTEM_DARWIN)\nstruct backtrace_symbol;\n\n#if BACKWARD_HAS_BACKTRACE_SYMBOL == 1\ntypedef backtrace_symbol current;\n#else\n#error \"You shall not pass, until you know what you want.\"\n#endif\n#elif defined(BACKWARD_SYSTEM_WINDOWS)\nstruct pdb_symbol;\n#if BACKWARD_HAS_PDB_SYMBOL == 1\ntypedef pdb_symbol current;\n#else\n#error \"You shall not pass, until you know what you want.\"\n#endif\n#endif\n}  // namespace trace_resolver_tag\n\nnamespace details {\n\ntemplate<typename T>\nstruct rm_ptr {\n  typedef T type;\n};\n\ntemplate<typename T>\nstruct rm_ptr<T*> {\n  typedef T type;\n};\n\ntemplate<typename T>\nstruct rm_ptr<const T*> {\n  typedef const T type;\n};\n\ntemplate<typename R, typename T, R (*F)(T)>\nstruct deleter {\n  template<typename U>\n  void operator()(U& ptr) const {\n    (*F)(ptr);\n  }\n};\n\ntemplate<typename T>\nstruct default_delete {\n  void operator()(T& ptr) const { delete ptr; }\n};\n\ntemplate<typename T, typename Deleter = deleter<void, void*, &::free>>\nclass handle {\n  struct dummy;\n  T _val;\n  bool _empty;\n\n#ifdef BACKWARD_ATLEAST_CXX11\n  handle(const handle&) = delete;\n  handle& operator=(const handle&) = delete;\n#endif\n\n public:\n  ~handle() {\n    if (!_empty) { Deleter()(_val); }\n  }\n\n  explicit handle() : _val(), _empty(true) {}\n  explicit handle(T val) : _val(val), _empty(false) {\n    if (!_val) _empty = true;\n  }\n\n#ifdef BACKWARD_ATLEAST_CXX11\n  handle(handle&& from) : _empty(true) { swap(from); }\n  handle& operator=(handle&& from) {\n    swap(from);\n    return *this;\n  }\n#else\n  explicit handle(const handle& from) : _empty(true) {\n    // some sort of poor man's move semantic.\n    swap(const_cast<handle&>(from));\n  }\n  handle& operator=(const handle& from) {\n    // some sort of poor man's move semantic.\n    swap(const_cast<handle&>(from));\n    return *this;\n  }\n#endif\n\n  void reset(T new_val) {\n    handle tmp(new_val);\n    swap(tmp);\n  }\n\n  void update(T new_val) {\n    _val = new_val;\n    _empty = !static_cast<bool>(new_val);\n  }\n\n  operator const dummy*() const {\n    if (_empty) { return nullptr; }\n    return reinterpret_cast<const dummy*>(_val);\n  }\n  T get() { return _val; }\n  T release() {\n    _empty = true;\n   
 return _val;\n  }\n  void swap(handle& b) {\n    using std::swap;\n    swap(b._val, _val);      // can throw, we are safe here.\n    swap(b._empty, _empty);  // should not throw: if you cannot swap two\n    // bools without throwing... It's a lost cause anyway!\n  }\n\n  T& operator->() { return _val; }\n  const T& operator->() const { return _val; }\n\n  typedef typename rm_ptr<T>::type& ref_t;\n  typedef const typename rm_ptr<T>::type& const_ref_t;\n  ref_t operator*() { return *_val; }\n  const_ref_t operator*() const { return *_val; }\n  ref_t operator[](size_t idx) { return _val[idx]; }\n\n  // Watch out, we've got a badass over here\n  T* operator&() {\n    _empty = false;\n    return &_val;\n  }\n};\n\nnamespace {\n// how many args to keep in template params\n// e.g. {std::vector, 1} means std::vector<T0, T1, T2, ..., Tn> -> std::vector<T0>\nstatic std::unordered_map<std::string, size_t> class2keepsize{\n    {\"std::vector\", 1},\n    {\"Maybe\", 1},\n};\n\nclass SignatureType {\n public:\n  SignatureType(const std::string& name, const std::vector<SignatureType>& args,\n                const std::string& specifier)\n      : name(name), args(args), specifier(specifier){};\n  std::string name;\n  std::vector<SignatureType> args;\n  std::string specifier;\n  using pss = std::pair<std::string, std::string>;\n\n  size_t get_keep_size(const std::string& name) {\n    auto it = class2keepsize.find(name);\n    if (it == class2keepsize.end()) {\n      return 0;\n    } else {\n      return it->second;\n    }\n  }\n\n  std::string to_string() {\n    std::string str_args, str_specifer;\n\n    if (args.empty()) {\n      str_args = \"\";\n    } else {\n      str_args = \"<\";\n      size_t keep_size = get_keep_size(name);\n      if (keep_size == 0) { keep_size = args.size(); }\n      for (int i = 0; i < keep_size; i++) {\n        SignatureType type = args[i];\n        str_args += type.to_string() + ((i != (keep_size - 1)) ? 
\", \" : \"\");\n      };\n      str_args += \"> \";\n    }\n\n    return name + str_args + specifier;\n  }\n\n  static std::pair<std::vector<SignatureType>, std::string> parse_args(std::string s) {\n    std::vector<SignatureType> args;\n    while (s[0] != '>') {\n      s = s.substr(1, s.size() - 1);\n      auto type_and_rest = parse_type(s);\n      s = type_and_rest.second;\n      args.push_back(type_and_rest.first);\n    }\n    return {args, s.substr(1, s.size() - 1)};\n  }\n\n  static pss parse_spaces(const std::string& inp) {\n    size_t pos = inp.find_first_not_of(\" \");\n    if (pos == 0) {\n      return {\"\", inp};\n    } else {\n      return {inp.substr(0, pos), inp.substr(pos, inp.size() - pos)};\n    }\n  }\n\n  static pss parse_type_specifier(const std::string& inp) {\n    static std::vector<std::string> specifier_list{\n        \"const&\",\n        \"const\",\n        \"volatile\",\n    };\n    for (const auto& specifier : specifier_list) {\n      if (inp.rfind(specifier, 0) == 0) {\n        return {specifier, inp.substr(specifier.size(), inp.size() - specifier.size())};\n      }\n    }\n    return {\"\", inp};\n  }\n\n  static pss parse_simple_type_id(const std::string& inp) {\n    auto rest = parse_spaces(inp).second;\n    std::smatch found;\n    std::regex_search(rest, found, std::regex(\"^((\\\\w|:|\\\\*|&)+)\"));\n    std::string name = found[0];\n    return {name, rest.substr(name.size(), rest.size() - name.size())};\n  }\n\n  static std::pair<SignatureType, std::string> parse_type_id(const std::string& inp) {\n    auto rest = parse_spaces(inp).second;\n    std::smatch found;\n    std::regex_search(rest, found, std::regex(\"^((\\\\w|:|\\\\*|&)+)\"));\n    std::string name = found[0];\n    rest = rest.substr(name.size(), rest.size() - name.size());\n    rest = parse_spaces(rest).second;\n    auto spec_and_rest = parse_type_specifier(rest);\n    auto type_spec = spec_and_rest.first;\n    rest = spec_and_rest.second;\n    return {SignatureType(name, {}, type_spec), rest};\n  }\n\n  static std::pair<SignatureType, std::string> parse_type(const std::string& inp) {\n    auto name_and_rest = parse_simple_type_id(inp);\n    auto type_name = name_and_rest.first;\n    auto rest = name_and_rest.second;\n    std::vector<SignatureType> args;\n    if (rest[0] == '<') {\n      auto args_and_rest = parse_args(rest);\n      args = args_and_rest.first;\n      rest = args_and_rest.second;\n    }\n    rest = parse_spaces(rest).second;\n    auto specifier_and_rest = parse_type_specifier(rest);\n    auto type_spec = specifier_and_rest.first;\n    rest = specifier_and_rest.second;\n    return {SignatureType(type_name, args, type_spec), rest};\n  }\n};\n\nstd::string replace_each(const std::string& signature, const std::string& src,\n                         const std::string& dst) {\n  std::string result;\n  std::string::size_type substr_begin = 0;\n  for (std::string::size_type pos = 0;\n       signature.npos != (pos = signature.find(src.data(), pos, src.length()));) {\n    result.insert(result.end(), signature.begin() + substr_begin, signature.begin() + pos);\n    result += dst;\n    substr_begin = pos + src.length();\n    pos = substr_begin;\n  }\n  result.insert(result.end(), signature.begin() + substr_begin, signature.end());\n  return result;\n}\n\nstd::string replace(const std::string& signature) {\n  static std::vector<std::pair<std::string, std::string>> replace_pairs = {\n      {\"oneflow::one::\", \"\"},\n      {\"oneflow::\", \"\"},\n      {\"std::__cxx11::basic_string<char, 
std::char_traits<char>, std::allocator<char> >\",\n       \"std::string\"},\n  };\n  std::string result = signature;\n  for (const auto& p : replace_pairs) { result = replace_each(result, p.first, p.second); }\n  return result;\n}\n\nstd::string simplify_type(const std::string& inp, const std::string& type_name) {\n  std::string result;\n  std::string::size_type begin = 0;\n  std::string::size_type pos = 0;\n  for (; inp.npos != (pos = inp.find(type_name.data(), pos, type_name.length()));) {\n    result.insert(result.end(), inp.begin() + begin, inp.begin() + pos);\n    auto type_and_rest = SignatureType::parse_type(inp.substr(pos, inp.size() - pos));\n    result += type_and_rest.first.to_string();\n    begin = inp.size() - type_and_rest.second.size();\n    pos = begin;\n  }\n  result.insert(result.end(), inp.begin() + begin, inp.end());\n  return result;\n}\n\nstd::string simplify(const std::string& inp) {\n  std::string result = replace(inp);\n  for (const auto& type_pair : class2keepsize) {\n    auto type_name = type_pair.first;\n    result = simplify_type(result, type_name);\n  }\n  return result;\n}\n}  // namespace\n\n// Default demangler implementation (do nothing).\ntemplate<typename TAG>\nstruct demangler_impl {\n  static std::string demangle(const char* funcname) { return funcname; }\n};\n\n#if defined(BACKWARD_SYSTEM_LINUX) || defined(BACKWARD_SYSTEM_DARWIN)\n\ntemplate<>\nstruct demangler_impl<system_tag::current_tag> {\n  demangler_impl() : _demangle_buffer_length(0) {}\n\n  std::string demangle(const char* funcname) {\n    using namespace details;\n    char* result =\n        abi::__cxa_demangle(funcname, _demangle_buffer.get(), &_demangle_buffer_length, nullptr);\n    if (result) {\n      _demangle_buffer.update(result);\n      // Modify: simplify func signature\n      return simplify(result);\n      // return result;\n    }\n    return funcname;\n  }\n\n private:\n  details::handle<char*> _demangle_buffer;\n  size_t _demangle_buffer_length;\n};\n\n#endif  // BACKWARD_SYSTEM_LINUX || BACKWARD_SYSTEM_DARWIN\n\nstruct demangler : public demangler_impl<system_tag::current_tag> {};\n\n// Split a string on the platform's PATH delimiter.  
Example: if delimiter\n// is \":\" then:\n//   \"\"              --> []\n//   \":\"             --> [\"\",\"\"]\n//   \"::\"            --> [\"\",\"\",\"\"]\n//   \"/a/b/c\"        --> [\"/a/b/c\"]\n//   \"/a/b/c:/d/e/f\" --> [\"/a/b/c\",\"/d/e/f\"]\n//   etc.\ninline std::vector<std::string> split_source_prefixes(const std::string& s) {\n  std::vector<std::string> out;\n  size_t last = 0;\n  size_t next = 0;\n  size_t delimiter_size = sizeof(kBackwardPathDelimiter) - 1;\n  while ((next = s.find(kBackwardPathDelimiter, last)) != std::string::npos) {\n    out.push_back(s.substr(last, next - last));\n    last = next + delimiter_size;\n  }\n  if (last <= s.length()) { out.push_back(s.substr(last)); }\n  return out;\n}\n\n}  // namespace details\n\n/*************** A TRACE ***************/\n\nstruct Trace {\n  void* addr;\n  size_t idx;\n\n  Trace() : addr(nullptr), idx(0) {}\n\n  explicit Trace(void* _addr, size_t _idx) : addr(_addr), idx(_idx) {}\n};\n\nstruct ResolvedTrace : public Trace {\n  struct SourceLoc {\n    std::string function;\n    std::string filename;\n    unsigned line;\n    unsigned col;\n\n    SourceLoc() : line(0), col(0) {}\n\n    bool operator==(const SourceLoc& b) const {\n      return function == b.function && filename == b.filename && line == b.line && col == b.col;\n    }\n\n    bool operator!=(const SourceLoc& b) const { return !(*this == b); }\n  };\n\n  // In which binary object this trace is located.\n  std::string object_filename;\n\n  // The function in the object that contains the trace. This is not the same\n  // as source.function, which can be a function inlined in object_function.\n  std::string object_function;\n\n  // The source location of this trace. It is possible for filename to be\n  // empty and for line/col to be invalid (value 0) if this information\n  // couldn't be deduced, for example if there is no debug information in the\n  // binary object.\n  SourceLoc source;\n\n  // An optional list of \"inliners\". All the successive source locations\n  // from where the source location of the trace (the attribute right above)\n  // is inlined. It is especially useful when you compile with optimizations.\n  typedef std::vector<SourceLoc> source_locs_t;\n  source_locs_t inliners;\n\n  ResolvedTrace() : Trace() {}\n  ResolvedTrace(const Trace& mini_trace) : Trace(mini_trace) {}\n};\n\n/*************** STACK TRACE ***************/\n\n// default implementation.\ntemplate<typename TAG>\nclass StackTraceImpl {\n public:\n  size_t size() const { return 0; }\n  Trace operator[](size_t) const { return Trace(); }\n  size_t load_here(size_t = 0) { return 0; }\n  size_t load_from(void*, size_t = 0, void* = nullptr, void* = nullptr) { return 0; }\n  size_t thread_id() const { return 0; }\n  void skip_n_firsts(size_t) {}\n};\n\nclass StackTraceImplBase {\n public:\n  StackTraceImplBase() : _thread_id(0), _skip(0), _context(nullptr), _error_addr(nullptr) {}\n\n  size_t thread_id() const { return _thread_id; }\n\n  void skip_n_firsts(size_t n) { _skip = n; }\n\n protected:\n  void load_thread_info() {\n#ifdef BACKWARD_SYSTEM_LINUX\n#ifndef __ANDROID__\n    _thread_id = static_cast<size_t>(syscall(SYS_gettid));\n#else\n    _thread_id = static_cast<size_t>(gettid());\n#endif\n    if (_thread_id == static_cast<size_t>(getpid())) {\n      // If the thread is the main one, let's hide that.\n      // I like to keep little secrets sometimes.\n      _thread_id = 0;\n    }\n#elif defined(BACKWARD_SYSTEM_DARWIN)\n    _thread_id = reinterpret_cast<size_t>(pthread_self());\n    if (pthread_main_np() == 1) {\n      // If the thread is the main one, let's hide that.\n      _thread_id = 0;\n    }\n#endif\n  }\n\n  void set_context(void* context) { _context = context; }\n  void* context() const { return _context; }\n\n  void set_error_addr(void* error_addr) { _error_addr = error_addr; }\n  void* error_addr() const { return _error_addr; }\n\n  size_t skip_n_firsts() const { return _skip; }\n\n private:\n  size_t _thread_id;\n  size_t _skip;\n  void* _context;\n  void* _error_addr;\n};\n\nclass StackTraceImplHolder : public StackTraceImplBase {\n public:\n  size_t size() const {\n    return (_stacktrace.size() >= skip_n_firsts()) ? 
_stacktrace.size() - skip_n_firsts() : 0;\n  }\n  Trace operator[](size_t idx) const {\n    if (idx >= size()) { return Trace(); }\n    return Trace(_stacktrace[idx + skip_n_firsts()], idx);\n  }\n  void* const* begin() const {\n    if (size()) { return &_stacktrace[skip_n_firsts()]; }\n    return nullptr;\n  }\n\n protected:\n  std::vector<void*> _stacktrace;\n};\n\n#if BACKWARD_HAS_UNWIND == 1\n\nnamespace details {\n\ntemplate<typename F>\nclass Unwinder {\n public:\n  size_t operator()(F& f, size_t depth) {\n    _f = &f;\n    _index = -1;\n    _depth = depth;\n    _Unwind_Backtrace(&this->backtrace_trampoline, this);\n    if (_index == -1) {\n      // _Unwind_Backtrace has failed to obtain any backtraces\n      return 0;\n    } else {\n      return static_cast<size_t>(_index);\n    }\n  }\n\n private:\n  F* _f;\n  ssize_t _index;\n  size_t _depth;\n\n  static _Unwind_Reason_Code backtrace_trampoline(_Unwind_Context* ctx, void* self) {\n    return (static_cast<Unwinder*>(self))->backtrace(ctx);\n  }\n\n  _Unwind_Reason_Code backtrace(_Unwind_Context* ctx) {\n    if (_index >= 0 && static_cast<size_t>(_index) >= _depth) return _URC_END_OF_STACK;\n\n    int ip_before_instruction = 0;\n    uintptr_t ip = _Unwind_GetIPInfo(ctx, &ip_before_instruction);\n\n    if (!ip_before_instruction) {\n      // calculating 0-1 for unsigned, looks like a possible bug to sanitizers,\n      // so let's do it explicitly:\n      if (ip == 0) {\n        ip = std::numeric_limits<uintptr_t>::max();  // set it to 0xffff... (as\n                                                     // from casting 0-1)\n      } else {\n        ip -= 1;  // else just normally decrement it (no overflow/underflow will\n                  // happen)\n      }\n    }\n\n    if (_index >= 0) {  // ignore first frame.\n      (*_f)(static_cast<size_t>(_index), reinterpret_cast<void*>(ip));\n    }\n    _index += 1;\n    return _URC_NO_REASON;\n  }\n};\n\ntemplate<typename F>\nsize_t unwind(F f, size_t depth) {\n  Unwinder<F> unwinder;\n  return unwinder(f, depth);\n}\n\n}  // namespace details\n\ntemplate<>\nclass StackTraceImpl<system_tag::current_tag> : public StackTraceImplHolder {\n public:\n  NOINLINE\n  size_t load_here(size_t depth = 32, void* context = nullptr, void* error_addr = nullptr) {\n    load_thread_info();\n    set_context(context);\n    set_error_addr(error_addr);\n    if (depth == 0) { return 0; }\n    _stacktrace.resize(depth);\n    size_t trace_cnt = details::unwind(callback(*this), depth);\n    _stacktrace.resize(trace_cnt);\n    skip_n_firsts(0);\n    return size();\n  }\n  size_t load_from(void* addr, size_t depth = 32, void* context = nullptr,\n                   void* error_addr = nullptr) {\n    load_here(depth + 8, context, error_addr);\n\n    for (size_t i = 0; i < _stacktrace.size(); ++i) {\n      if (_stacktrace[i] == addr) {\n        skip_n_firsts(i);\n        break;\n      }\n    }\n\n    _stacktrace.resize(std::min(_stacktrace.size(), skip_n_firsts() + depth));\n    return size();\n  }\n\n private:\n  struct callback {\n    StackTraceImpl& self;\n    callback(StackTraceImpl& _self) : self(_self) {}\n\n    void operator()(size_t idx, void* addr) { self._stacktrace[idx] = addr; }\n  };\n};\n\n#elif BACKWARD_HAS_LIBUNWIND == 1\n\ntemplate<>\nclass StackTraceImpl<system_tag::current_tag> : public StackTraceImplHolder {\n public:\n  __attribute__((noinline)) size_t load_here(size_t depth = 32, void* _context = nullptr,\n                                             void* _error_addr = nullptr) {\n    
set_context(_context);\n    set_error_addr(_error_addr);\n    load_thread_info();\n    if (depth == 0) { return 0; }\n    _stacktrace.resize(depth + 1);\n\n    int result = 0;\n\n    unw_context_t ctx;\n    size_t index = 0;\n\n    // Add the tail call. If the Instruction Pointer is the crash address it\n    // means we got a bad function pointer dereference, so we \"unwind\" the\n    // bad pointer manually by using the return address pointed to by the\n    // Stack Pointer as the Instruction Pointer and letting libunwind do\n    // the rest\n\n    if (context()) {\n      ucontext_t* uctx = reinterpret_cast<ucontext_t*>(context());\n// x86_64\n#ifdef REG_RIP\n      if (uctx->uc_mcontext.gregs[REG_RIP] == reinterpret_cast<greg_t>(error_addr())) {\n        uctx->uc_mcontext.gregs[REG_RIP] =\n            *reinterpret_cast<size_t*>(uctx->uc_mcontext.gregs[REG_RSP]);\n      }\n      _stacktrace[index] = reinterpret_cast<void*>(uctx->uc_mcontext.gregs[REG_RIP]);\n      ++index;\n      ctx = *reinterpret_cast<unw_context_t*>(uctx);\n// x86_32\n#elif defined(REG_EIP)\n      if (uctx->uc_mcontext.gregs[REG_EIP] == reinterpret_cast<greg_t>(error_addr())) {\n        uctx->uc_mcontext.gregs[REG_EIP] =\n            *reinterpret_cast<size_t*>(uctx->uc_mcontext.gregs[REG_ESP]);\n      }\n      _stacktrace[index] = reinterpret_cast<void*>(uctx->uc_mcontext.gregs[REG_EIP]);\n      ++index;\n      ctx = *reinterpret_cast<unw_context_t*>(uctx);\n#elif defined(__arm__)\n      // libunwind uses its own context type for ARM unwinding.\n      // Copy the registers from the signal handler's context so we can\n      // unwind\n      unw_getcontext(&ctx);\n      ctx.regs[UNW_ARM_R0] = uctx->uc_mcontext.arm_r0;\n      ctx.regs[UNW_ARM_R1] = uctx->uc_mcontext.arm_r1;\n      ctx.regs[UNW_ARM_R2] = uctx->uc_mcontext.arm_r2;\n      ctx.regs[UNW_ARM_R3] = uctx->uc_mcontext.arm_r3;\n      ctx.regs[UNW_ARM_R4] = uctx->uc_mcontext.arm_r4;\n      ctx.regs[UNW_ARM_R5] = uctx->uc_mcontext.arm_r5;\n      ctx.regs[UNW_ARM_R6] = uctx->uc_mcontext.arm_r6;\n      ctx.regs[UNW_ARM_R7] = uctx->uc_mcontext.arm_r7;\n      ctx.regs[UNW_ARM_R8] = uctx->uc_mcontext.arm_r8;\n      ctx.regs[UNW_ARM_R9] = uctx->uc_mcontext.arm_r9;\n      ctx.regs[UNW_ARM_R10] = uctx->uc_mcontext.arm_r10;\n      ctx.regs[UNW_ARM_R11] = uctx->uc_mcontext.arm_fp;\n      ctx.regs[UNW_ARM_R12] = uctx->uc_mcontext.arm_ip;\n      ctx.regs[UNW_ARM_R13] = uctx->uc_mcontext.arm_sp;\n      ctx.regs[UNW_ARM_R14] = uctx->uc_mcontext.arm_lr;\n      ctx.regs[UNW_ARM_R15] = uctx->uc_mcontext.arm_pc;\n\n      // If we have crashed in the PC use the LR instead, as this was\n      // a bad function dereference\n      if (reinterpret_cast<unsigned long>(error_addr()) == uctx->uc_mcontext.arm_pc) {\n        ctx.regs[UNW_ARM_R15] = uctx->uc_mcontext.arm_lr - sizeof(unsigned long);\n      }\n      _stacktrace[index] = reinterpret_cast<void*>(ctx.regs[UNW_ARM_R15]);\n      ++index;\n#elif defined(__APPLE__) && defined(__x86_64__)\n      unw_getcontext(&ctx);\n      // OS X's implementation of libunwind uses its own context object\n      // so we need to convert the passed context to libunwind's format\n      // (information about the data layout taken from unw_getcontext.s\n      // in Apple's libunwind source\n      ctx.data[0] = uctx->uc_mcontext->__ss.__rax;\n      ctx.data[1] = uctx->uc_mcontext->__ss.__rbx;\n      ctx.data[2] = uctx->uc_mcontext->__ss.__rcx;\n      ctx.data[3] = uctx->uc_mcontext->__ss.__rdx;\n      ctx.data[4] = uctx->uc_mcontext->__ss.__rdi;\n      
ctx.data[5] = uctx->uc_mcontext->__ss.__rsi;\n      ctx.data[6] = uctx->uc_mcontext->__ss.__rbp;\n      ctx.data[7] = uctx->uc_mcontext->__ss.__rsp;\n      ctx.data[8] = uctx->uc_mcontext->__ss.__r8;\n      ctx.data[9] = uctx->uc_mcontext->__ss.__r9;\n      ctx.data[10] = uctx->uc_mcontext->__ss.__r10;\n      ctx.data[11] = uctx->uc_mcontext->__ss.__r11;\n      ctx.data[12] = uctx->uc_mcontext->__ss.__r12;\n      ctx.data[13] = uctx->uc_mcontext->__ss.__r13;\n      ctx.data[14] = uctx->uc_mcontext->__ss.__r14;\n      ctx.data[15] = uctx->uc_mcontext->__ss.__r15;\n      ctx.data[16] = uctx->uc_mcontext->__ss.__rip;\n\n      // If the IP is the same as the crash address we have a bad function\n      // dereference. The caller's address is pointed to by %rsp, so we\n      // dereference that value and set it to be the next frame's IP.\n      if (uctx->uc_mcontext->__ss.__rip == reinterpret_cast<__uint64_t>(error_addr())) {\n        ctx.data[16] = *reinterpret_cast<__uint64_t*>(uctx->uc_mcontext->__ss.__rsp);\n      }\n      _stacktrace[index] = reinterpret_cast<void*>(ctx.data[16]);\n      ++index;\n#elif defined(__APPLE__)\n      unw_getcontext(&ctx);\n      // TODO: Convert the ucontext_t to libunwind's unw_context_t like\n      // we do in 64 bits\n      if (ctx.uc_mcontext->__ss.__eip == reinterpret_cast<greg_t>(error_addr())) {\n        ctx.uc_mcontext->__ss.__eip = ctx.uc_mcontext->__ss.__esp;\n      }\n      _stacktrace[index] = reinterpret_cast<void*>(ctx.uc_mcontext->__ss.__eip);\n      ++index;\n#endif\n    }\n\n    unw_cursor_t cursor;\n    if (context()) {\n#if defined(UNW_INIT_SIGNAL_FRAME)\n      result = unw_init_local2(&cursor, &ctx, UNW_INIT_SIGNAL_FRAME);\n#else\n      result = unw_init_local(&cursor, &ctx);\n#endif\n    } else {\n      unw_getcontext(&ctx);\n      result = unw_init_local(&cursor, &ctx);\n    }\n\n    if (result != 0) return 1;\n\n    unw_word_t ip = 0;\n\n    while (index <= depth && unw_step(&cursor) > 0) {\n      result = unw_get_reg(&cursor, UNW_REG_IP, &ip);\n      if (result == 0) {\n        _stacktrace[index] = reinterpret_cast<void*>(--ip);\n        ++index;\n      }\n    }\n    --index;\n\n    _stacktrace.resize(index + 1);\n    skip_n_firsts(0);\n    return size();\n  }\n\n  size_t load_from(void* addr, size_t depth = 32, void* context = nullptr,\n                   void* error_addr = nullptr) {\n    load_here(depth + 8, context, error_addr);\n\n    for (size_t i = 0; i < _stacktrace.size(); ++i) {\n      if (_stacktrace[i] == addr) {\n        skip_n_firsts(i);\n        _stacktrace[i] = (void*)((uintptr_t)_stacktrace[i]);\n        break;\n      }\n    }\n\n    _stacktrace.resize(std::min(_stacktrace.size(), skip_n_firsts() + depth));\n    return size();\n  }\n};\n\n#elif defined(BACKWARD_HAS_BACKTRACE)\n\ntemplate<>\nclass StackTraceImpl<system_tag::current_tag> : public StackTraceImplHolder {\n public:\n  NOINLINE\n  size_t load_here(size_t depth = 32, void* context = nullptr, void* error_addr = nullptr) {\n    set_context(context);\n    set_error_addr(error_addr);\n    load_thread_info();\n    if (depth == 0) { return 0; }\n    _stacktrace.resize(depth + 1);\n    size_t trace_cnt = backtrace(&_stacktrace[0], _stacktrace.size());\n    _stacktrace.resize(trace_cnt);\n    skip_n_firsts(1);\n    return size();\n  }\n\n  size_t load_from(void* addr, size_t depth = 32, void* context = nullptr,\n                   void* error_addr = nullptr) {\n    load_here(depth + 8, context, error_addr);\n\n    for (size_t i = 0; i < 
_stacktrace.size(); ++i) {\n      if (_stacktrace[i] == addr) {\n        skip_n_firsts(i);\n        _stacktrace[i] = (void*)((uintptr_t)_stacktrace[i] + 1);\n        break;\n      }\n    }\n\n    _stacktrace.resize(std::min(_stacktrace.size(), skip_n_firsts() + depth));\n    return size();\n  }\n};\n\n#elif defined(BACKWARD_SYSTEM_WINDOWS)\n\ntemplate<>\nclass StackTraceImpl<system_tag::current_tag> : public StackTraceImplHolder {\n public:\n  // We have to load the machine type from the image info\n  // So we first initialize the resolver, and it tells us this info\n  void set_machine_type(DWORD machine_type) { machine_type_ = machine_type; }\n  void set_context(CONTEXT* ctx) { ctx_ = ctx; }\n  void set_thread_handle(HANDLE handle) { thd_ = handle; }\n\n  NOINLINE\n  size_t load_here(size_t depth = 32, void* context = nullptr, void* error_addr = nullptr) {\n    set_context(static_cast<CONTEXT*>(context));\n    set_error_addr(error_addr);\n    CONTEXT localCtx;  // used when no context is provided\n\n    if (depth == 0) { return 0; }\n\n    if (!ctx_) {\n      ctx_ = &localCtx;\n      RtlCaptureContext(ctx_);\n    }\n\n    if (!thd_) { thd_ = GetCurrentThread(); }\n\n    HANDLE process = GetCurrentProcess();\n\n    STACKFRAME64 s;\n    memset(&s, 0, sizeof(STACKFRAME64));\n\n    // TODO: 32 bit context capture\n    s.AddrStack.Mode = AddrModeFlat;\n    s.AddrFrame.Mode = AddrModeFlat;\n    s.AddrPC.Mode = AddrModeFlat;\n#ifdef _M_X64\n    s.AddrPC.Offset = ctx_->Rip;\n    s.AddrStack.Offset = ctx_->Rsp;\n    s.AddrFrame.Offset = ctx_->Rbp;\n#else\n    s.AddrPC.Offset = ctx_->Eip;\n    s.AddrStack.Offset = ctx_->Esp;\n    s.AddrFrame.Offset = ctx_->Ebp;\n#endif\n\n    if (!machine_type_) {\n#ifdef _M_X64\n      machine_type_ = IMAGE_FILE_MACHINE_AMD64;\n#else\n      machine_type_ = IMAGE_FILE_MACHINE_I386;\n#endif\n    }\n\n    for (;;) {\n      // NOTE: this only works if PDBs are already loaded!\n      SetLastError(0);\n      if (!StackWalk64(machine_type_, process, thd_, &s, ctx_, NULL, SymFunctionTableAccess64,\n                       SymGetModuleBase64, NULL))\n        break;\n\n      if (s.AddrReturn.Offset == 0) break;\n\n      _stacktrace.push_back(reinterpret_cast<void*>(s.AddrPC.Offset));\n\n      if (size() >= depth) break;\n    }\n\n    return size();\n  }\n\n  size_t load_from(void* addr, size_t depth = 32, void* context = nullptr,\n                   void* error_addr = nullptr) {\n    load_here(depth + 8, context, error_addr);\n\n    for (size_t i = 0; i < _stacktrace.size(); ++i) {\n      if (_stacktrace[i] == addr) {\n        skip_n_firsts(i);\n        break;\n      }\n    }\n\n    _stacktrace.resize(std::min(_stacktrace.size(), skip_n_firsts() + depth));\n    return size();\n  }\n\n private:\n  DWORD machine_type_ = 0;\n  HANDLE thd_ = 0;\n  CONTEXT* ctx_ = nullptr;\n};\n\n#endif\n\nclass StackTrace : public StackTraceImpl<system_tag::current_tag> {};\n\n/*************** TRACE RESOLVER ***************/\n\nclass TraceResolverImplBase {\n public:\n  virtual ~TraceResolverImplBase() {}\n\n  virtual void load_addresses(void* const* addresses, int address_count) {\n    (void)addresses;\n    (void)address_count;\n  }\n\n  template<class ST>\n  void load_stacktrace(ST& st) {\n    load_addresses(st.begin(), static_cast<int>(st.size()));\n  }\n\n  virtual ResolvedTrace resolve(ResolvedTrace t) { return t; }\n\n protected:\n  std::string demangle(const char* funcname) { return _demangler.demangle(funcname); }\n\n private:\n  details::demangler _demangler;\n};\n\ntemplate<typename 
TAG>\nclass TraceResolverImpl;\n\n#ifdef BACKWARD_SYSTEM_UNKNOWN\n\ntemplate<>\nclass TraceResolverImpl<system_tag::unknown_tag> : public TraceResolverImplBase {};\n\n#endif\n\n#ifdef BACKWARD_SYSTEM_LINUX\n\nclass TraceResolverLinuxBase : public TraceResolverImplBase {\n public:\n  TraceResolverLinuxBase() : argv0_(get_argv0()), exec_path_(read_symlink(\"/proc/self/exe\")) {}\n  std::string resolve_exec_path(Dl_info& symbol_info) const {\n    // mutates symbol_info.dli_fname to be filename to open and returns filename\n    // to display\n    if (symbol_info.dli_fname == argv0_) {\n      // dladdr returns argv[0] in dli_fname for symbols contained in\n      // the main executable, which is not a valid path if the\n      // executable was found by a search of the PATH environment\n      // variable; In that case, we actually open /proc/self/exe, which\n      // is always the actual executable (even if it was deleted/replaced!)\n      // but display the path that /proc/self/exe links to.\n      // However, this right away reduces probability of successful symbol\n      // resolution, because libbfd may try to find *.debug files in the\n      // same dir, in case symbols are stripped. As a result, it may try\n      // to find a file /proc/self/<exe_name>.debug, which obviously does\n      // not exist. /proc/self/exe is a last resort. First load attempt\n      // should go for the original executable file path.\n      symbol_info.dli_fname = \"/proc/self/exe\";\n      return exec_path_;\n    } else {\n      return symbol_info.dli_fname;\n    }\n  }\n\n private:\n  std::string argv0_;\n  std::string exec_path_;\n\n  static std::string get_argv0() {\n    std::string argv0;\n    std::ifstream ifs(\"/proc/self/cmdline\");\n    std::getline(ifs, argv0, '\\0');\n    return argv0;\n  }\n\n  static std::string read_symlink(std::string const& symlink_path) {\n    std::string path;\n    path.resize(100);\n\n    while (true) {\n      ssize_t len = ::readlink(symlink_path.c_str(), &*path.begin(), path.size());\n      if (len < 0) { return \"\"; }\n      if (static_cast<size_t>(len) == path.size()) {\n        path.resize(path.size() * 2);\n      } else {\n        path.resize(static_cast<std::string::size_type>(len));\n        break;\n      }\n    }\n\n    return path;\n  }\n};\n\ntemplate<typename STACKTRACE_TAG>\nclass TraceResolverLinuxImpl;\n\n#if BACKWARD_HAS_BACKTRACE_SYMBOL == 1\n\ntemplate<>\nclass TraceResolverLinuxImpl<trace_resolver_tag::backtrace_symbol> : public TraceResolverLinuxBase {\n public:\n  void load_addresses(void* const* addresses, int address_count) override {\n    if (address_count == 0) { return; }\n    _symbols.reset(backtrace_symbols(addresses, address_count));\n  }\n\n  ResolvedTrace resolve(ResolvedTrace trace) override {\n    char* filename = _symbols[trace.idx];\n    char* funcname = filename;\n    while (*funcname && *funcname != '(') { funcname += 1; }\n    trace.object_filename.assign(filename,\n                                 funcname);  // ok even if funcname is the ending\n                                             // \\0 (then we assign entire string)\n\n    if (*funcname) {  // if it's not end of string (e.g. 
from last frame ip==0)\n      funcname += 1;\n      char* funcname_end = funcname;\n      while (*funcname_end && *funcname_end != ')' && *funcname_end != '+') { funcname_end += 1; }\n      *funcname_end = '\\0';\n      trace.object_function = this->demangle(funcname);\n      trace.source.function = trace.object_function;  // we cannot do better.\n    }\n    return trace;\n  }\n\n private:\n  details::handle<char**> _symbols;\n};\n\n#endif  // BACKWARD_HAS_BACKTRACE_SYMBOL == 1\n\n#if BACKWARD_HAS_BFD == 1\n\ntemplate<>\nclass TraceResolverLinuxImpl<trace_resolver_tag::libbfd> : public TraceResolverLinuxBase {\n public:\n  TraceResolverLinuxImpl() : _bfd_loaded(false) {}\n\n  ResolvedTrace resolve(ResolvedTrace trace) override {\n    Dl_info symbol_info;\n\n    // trace.addr is a virtual address in memory pointing to some code.\n    // Let's try to find from which loaded object it comes from.\n    // The loaded object can be yourself btw.\n    if (!dladdr(trace.addr, &symbol_info)) {\n      return trace;  // dat broken trace...\n    }\n\n    // Now we get in symbol_info:\n    // .dli_fname:\n    //\t\tpathname of the shared object that contains the address.\n    // .dli_fbase:\n    //\t\twhere the object is loaded in memory.\n    // .dli_sname:\n    //\t\tthe name of the nearest symbol to trace.addr, we expect a\n    //\t\tfunction name.\n    // .dli_saddr:\n    //\t\tthe exact address corresponding to .dli_sname.\n\n    if (symbol_info.dli_sname) { trace.object_function = demangle(symbol_info.dli_sname); }\n\n    if (!symbol_info.dli_fname) { return trace; }\n\n    trace.object_filename = resolve_exec_path(symbol_info);\n    bfd_fileobject* fobj;\n    // Before rushing to resolution we need to ensure the executable\n    // file can still be used. For that we compare the inode numbers of\n    // what is stored at the executable's file path and at dli_fname,\n    // which does not necessarily equal the executable.\n    // It can be a shared library, or /proc/self/exe, and in the\n    // latter case has drawbacks. See the exec path resolution for\n    // details. In short, the dli object should be used only as\n    // the last resort.\n    // If inode numbers are equal, we know dli_fname and the\n    // executable file are the same. This is guaranteed by Linux,\n    // because if the executable file is changed/deleted, it will\n    // be done in a new inode. The old file will be preserved in\n    // /proc/self/exe, and may even have inode 0. The latter can\n    // happen if the inode was actually reused, and the file was\n    // kept only in the main memory.\n    //\n    struct stat obj_stat;\n    struct stat dli_stat;\n    if (stat(trace.object_filename.c_str(), &obj_stat) == 0\n        && stat(symbol_info.dli_fname, &dli_stat) == 0 && obj_stat.st_ino == dli_stat.st_ino) {\n      // The executable file and the shared object containing the\n      // address are the same file. It is safe to use the original path;\n      // this is preferable. Libbfd will search for stripped debug\n      // symbols in the same directory.\n      fobj = load_object_with_bfd(trace.object_filename);\n    } else {\n      // The original object file was *deleted*! 
The only hope is\n      // that the debug symbols are either inside the shared\n      // object file, or are in the same directory, and this is\n      // not /proc/self/exe.\n      fobj = nullptr;\n    }\n    if (fobj == nullptr || !fobj->handle) {\n      fobj = load_object_with_bfd(symbol_info.dli_fname);\n      if (!fobj->handle) { return trace; }\n    }\n\n    find_sym_result* details_selected;  // to be filled.\n\n    // trace.addr is the next instruction to be executed after returning\n    // from the nested stack frame. In C++ this usually relates to the next\n    // statement right after the function call that led to a new stack\n    // frame. This is not usually what you want to see when printing out a\n    // stacktrace...\n    find_sym_result details_call_site =\n        find_symbol_details(fobj, trace.addr, symbol_info.dli_fbase);\n    details_selected = &details_call_site;\n\n#if BACKWARD_HAS_UNWIND == 0\n    // ...this is why we also try to resolve the symbol that is right\n    // before the return address. If we are lucky enough, we will get the\n    // line of the function that was called. But if the code is optimized,\n    // we might get something absolutely not related since the compiler\n    // can reschedule the return address with inline functions and\n    // tail-call optimization (among other things that I don't even know\n    // or cannot even dream about with my tiny limited brain).\n    find_sym_result details_adjusted_call_site =\n        find_symbol_details(fobj, (void*)(uintptr_t(trace.addr) - 1), symbol_info.dli_fbase);\n\n    // In debug mode, we should always get the right thing(TM).\n    if (details_call_site.found && details_adjusted_call_site.found) {\n      // Ok, we assume that details_adjusted_call_site is a better estimation.\n      details_selected = &details_adjusted_call_site;\n      trace.addr = (void*)(uintptr_t(trace.addr) - 1);\n    }\n\n    if (details_selected == &details_call_site && details_call_site.found) {\n      // we have to re-resolve the symbol in order to reset some\n      // internal state in BFD... so we can call backtrace_inliners\n      // thereafter...\n      details_call_site = find_symbol_details(fobj, trace.addr, symbol_info.dli_fbase);\n    }\n#endif  // BACKWARD_HAS_UNWIND\n\n    if (details_selected->found) {\n      if (details_selected->filename) { trace.source.filename = details_selected->filename; }\n      trace.source.line = details_selected->line;\n\n      if (details_selected->funcname) {\n        // this time we get the name of the function where the code is\n        // located, instead of the function where the address is\n        // located. In short, if the code was inlined, we get the\n        // function corresponding to the code. Else we already got it in\n        // trace.object_function.\n        trace.source.function = demangle(details_selected->funcname);\n\n        if (!symbol_info.dli_sname) {\n          // in case dladdr failed to find the symbol name of\n          // the function, we might as well try to put something\n          // here.\n          trace.object_function = trace.source.function;\n        }\n      }\n\n      // Maybe the source of the trace got inlined inside the function\n      // (trace.source.function). Let's see if we can get all the inlined\n      // calls along the way up to the initial call site.\n      trace.inliners = backtrace_inliners(fobj, *details_selected);\n\n#if 0\n\t\t\tif (trace.inliners.size() == 0) {\n\t\t\t\t// Maybe the trace was not inlined... 
or maybe it was and we\n\t\t\t\t// are lacking the debug information. Let's try to make the\n\t\t\t\t// world better and see if we can get the line number of the\n\t\t\t\t// function (trace.source.function) now.\n\t\t\t\t//\n\t\t\t\t// We will get the location of where the function start (to be\n\t\t\t\t// exact: the first instruction that really start the\n\t\t\t\t// function), not where the name of the function is defined.\n\t\t\t\t// This can be quite far away from the name of the function\n\t\t\t\t// btw.\n\t\t\t\t//\n\t\t\t\t// If the source of the function is the same as the source of\n\t\t\t\t// the trace, we cannot say if the trace was really inlined or\n\t\t\t\t// not.  However, if the filename of the source is different\n\t\t\t\t// between the function and the trace... we can declare it as\n\t\t\t\t// an inliner.  This is not 100% accurate, but better than\n\t\t\t\t// nothing.\n\n\t\t\t\tif (symbol_info.dli_saddr) {\n\t\t\t\t\tfind_sym_result details = find_symbol_details(fobj,\n\t\t\t\t\t\t\tsymbol_info.dli_saddr,\n\t\t\t\t\t\t\tsymbol_info.dli_fbase);\n\n\t\t\t\t\tif (details.found) {\n\t\t\t\t\t\tResolvedTrace::SourceLoc diy_inliner;\n\t\t\t\t\t\tdiy_inliner.line = details.line;\n\t\t\t\t\t\tif (details.filename) {\n\t\t\t\t\t\t\tdiy_inliner.filename = details.filename;\n\t\t\t\t\t\t}\n\t\t\t\t\t\tif (details.funcname) {\n\t\t\t\t\t\t\tdiy_inliner.function = demangle(details.funcname);\n\t\t\t\t\t\t} else {\n\t\t\t\t\t\t\tdiy_inliner.function = trace.source.function;\n\t\t\t\t\t\t}\n\t\t\t\t\t\tif (diy_inliner != trace.source) {\n\t\t\t\t\t\t\ttrace.inliners.push_back(diy_inliner);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t}\n#endif\n    }\n\n    return trace;\n  }\n\n private:\n  bool _bfd_loaded;\n\n  typedef details::handle<bfd*, details::deleter<bfd_boolean, bfd*, &bfd_close>> bfd_handle_t;\n\n  typedef details::handle<asymbol**> bfd_symtab_t;\n\n  struct bfd_fileobject {\n    bfd_handle_t handle;\n    bfd_vma base_addr;\n    bfd_symtab_t symtab;\n    bfd_symtab_t dynamic_symtab;\n  };\n\n  typedef details::hashtable<std::string, bfd_fileobject>::type fobj_bfd_map_t;\n  fobj_bfd_map_t _fobj_bfd_map;\n\n  bfd_fileobject* load_object_with_bfd(const std::string& filename_object) {\n    using namespace details;\n\n    if (!_bfd_loaded) {\n      using namespace details;\n      bfd_init();\n      _bfd_loaded = true;\n    }\n\n    fobj_bfd_map_t::iterator it = _fobj_bfd_map.find(filename_object);\n    if (it != _fobj_bfd_map.end()) { return &it->second; }\n\n    // this new object is empty for now.\n    bfd_fileobject* r = &_fobj_bfd_map[filename_object];\n\n    // we do the work temporary in this one;\n    bfd_handle_t bfd_handle;\n\n    int fd = open(filename_object.c_str(), O_RDONLY);\n    bfd_handle.reset(bfd_fdopenr(filename_object.c_str(), \"default\", fd));\n    if (!bfd_handle) {\n      close(fd);\n      return r;\n    }\n\n    if (!bfd_check_format(bfd_handle.get(), bfd_object)) {\n      return r;  // not an object? 
You lose.\n    }\n\n    if ((bfd_get_file_flags(bfd_handle.get()) & HAS_SYMS) == 0) {\n      return r;  // that's what happens when you forget to compile in debug.\n    }\n\n    ssize_t symtab_storage_size = bfd_get_symtab_upper_bound(bfd_handle.get());\n\n    ssize_t dyn_symtab_storage_size = bfd_get_dynamic_symtab_upper_bound(bfd_handle.get());\n\n    if (symtab_storage_size <= 0 && dyn_symtab_storage_size <= 0) {\n      return r;  // weird, is the file corrupted?\n    }\n\n    bfd_symtab_t symtab, dynamic_symtab;\n    ssize_t symcount = 0, dyn_symcount = 0;\n\n    if (symtab_storage_size > 0) {\n      symtab.reset(static_cast<bfd_symbol**>(malloc(static_cast<size_t>(symtab_storage_size))));\n      symcount = bfd_canonicalize_symtab(bfd_handle.get(), symtab.get());\n    }\n\n    if (dyn_symtab_storage_size > 0) {\n      dynamic_symtab.reset(\n          static_cast<bfd_symbol**>(malloc(static_cast<size_t>(dyn_symtab_storage_size))));\n      dyn_symcount = bfd_canonicalize_dynamic_symtab(bfd_handle.get(), dynamic_symtab.get());\n    }\n\n    if (symcount <= 0 && dyn_symcount <= 0) {\n      return r;  // damned, that's a stripped file that you got there!\n    }\n\n    r->handle = move(bfd_handle);\n    r->symtab = move(symtab);\n    r->dynamic_symtab = move(dynamic_symtab);\n    return r;\n  }\n\n  struct find_sym_result {\n    bool found;\n    const char* filename;\n    const char* funcname;\n    unsigned int line;\n  };\n\n  struct find_sym_context {\n    TraceResolverLinuxImpl* self;\n    bfd_fileobject* fobj;\n    void* addr;\n    void* base_addr;\n    find_sym_result result;\n  };\n\n  find_sym_result find_symbol_details(bfd_fileobject* fobj, void* addr, void* base_addr) {\n    find_sym_context context;\n    context.self = this;\n    context.fobj = fobj;\n    context.addr = addr;\n    context.base_addr = base_addr;\n    context.result.found = false;\n    bfd_map_over_sections(fobj->handle.get(), &find_in_section_trampoline,\n                          static_cast<void*>(&context));\n    return context.result;\n  }\n\n  static void find_in_section_trampoline(bfd*, asection* section, void* data) {\n    find_sym_context* context = static_cast<find_sym_context*>(data);\n    context->self->find_in_section(reinterpret_cast<bfd_vma>(context->addr),\n                                   reinterpret_cast<bfd_vma>(context->base_addr), context->fobj,\n                                   section, context->result);\n  }\n\n  void find_in_section(bfd_vma addr, bfd_vma base_addr, bfd_fileobject* fobj, asection* section,\n                       find_sym_result& result) {\n    if (result.found) return;\n\n#ifdef bfd_get_section_flags\n    if ((bfd_get_section_flags(fobj->handle.get(), section) & SEC_ALLOC) == 0)\n#else\n    if ((bfd_section_flags(section) & SEC_ALLOC) == 0)\n#endif\n      return;  // a debug section is never loaded automatically.\n\n#ifdef bfd_get_section_vma\n    bfd_vma sec_addr = bfd_get_section_vma(fobj->handle.get(), section);\n#else\n    bfd_vma sec_addr = bfd_section_vma(section);\n#endif\n#ifdef bfd_get_section_size\n    bfd_size_type size = bfd_get_section_size(section);\n#else\n    bfd_size_type size = bfd_section_size(section);\n#endif\n\n    // are we in the boundaries of the section?\n    if (addr < sec_addr || addr >= sec_addr + size) {\n      addr -= base_addr;  // oops, a relocated object, let's try again...\n      if (addr < sec_addr || addr >= sec_addr + size) { return; }\n    }\n\n
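    // Ask BFD for the nearest line information for the (section-relative)\n    // address, first from the regular symbol table and, failing that, from\n    // the dynamic one.\n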
\"-Wzero-as-null-pointer-constant\"\n#endif\n    if (!result.found && fobj->symtab) {\n      result.found =\n          bfd_find_nearest_line(fobj->handle.get(), section, fobj->symtab.get(), addr - sec_addr,\n                                &result.filename, &result.funcname, &result.line);\n    }\n\n    if (!result.found && fobj->dynamic_symtab) {\n      result.found =\n          bfd_find_nearest_line(fobj->handle.get(), section, fobj->dynamic_symtab.get(),\n                                addr - sec_addr, &result.filename, &result.funcname, &result.line);\n    }\n#if defined(__clang__)\n#pragma clang diagnostic pop\n#endif\n  }\n\n  ResolvedTrace::source_locs_t backtrace_inliners(bfd_fileobject* fobj,\n                                                  find_sym_result previous_result) {\n    // This function can be called ONLY after a SUCCESSFUL call to\n    // find_symbol_details. The state is global to the bfd_handle.\n    ResolvedTrace::source_locs_t results;\n    while (previous_result.found) {\n      find_sym_result result;\n      result.found = bfd_find_inliner_info(fobj->handle.get(), &result.filename, &result.funcname,\n                                           &result.line);\n\n      if (result.found) /* and not (\n                              cstrings_eq(previous_result.filename,\n                           result.filename) and\n                           cstrings_eq(previous_result.funcname, result.funcname)\n                              and result.line == previous_result.line\n                              )) */\n      {\n        ResolvedTrace::SourceLoc src_loc;\n        src_loc.line = result.line;\n        if (result.filename) { src_loc.filename = result.filename; }\n        if (result.funcname) { src_loc.function = demangle(result.funcname); }\n        results.push_back(src_loc);\n      }\n      previous_result = result;\n    }\n    return results;\n  }\n\n  bool cstrings_eq(const char* a, const char* b) {\n    if (!a || !b) { return false; }\n    return strcmp(a, b) == 0;\n  }\n};\n#endif  // BACKWARD_HAS_BFD == 1\n\n#if BACKWARD_HAS_DW == 1\n\ntemplate<>\nclass TraceResolverLinuxImpl<trace_resolver_tag::libdw> : public TraceResolverLinuxBase {\n public:\n  TraceResolverLinuxImpl() : _dwfl_handle_initialized(false) {}\n\n  ResolvedTrace resolve(ResolvedTrace trace) override {\n    using namespace details;\n\n    Dwarf_Addr trace_addr = reinterpret_cast<Dwarf_Addr>(trace.addr);\n\n    if (!_dwfl_handle_initialized) {\n      // initialize dwfl...\n      _dwfl_cb.reset(new Dwfl_Callbacks);\n      _dwfl_cb->find_elf = &dwfl_linux_proc_find_elf;\n      _dwfl_cb->find_debuginfo = &dwfl_standard_find_debuginfo;\n      _dwfl_cb->debuginfo_path = 0;\n\n      _dwfl_handle.reset(dwfl_begin(_dwfl_cb.get()));\n      _dwfl_handle_initialized = true;\n\n      if (!_dwfl_handle) { return trace; }\n\n      // ...from the current process.\n      dwfl_report_begin(_dwfl_handle.get());\n      int r = dwfl_linux_proc_report(_dwfl_handle.get(), getpid());\n      dwfl_report_end(_dwfl_handle.get(), NULL, NULL);\n      if (r < 0) { return trace; }\n    }\n\n    if (!_dwfl_handle) { return trace; }\n\n    // find the module (binary object) that contains the trace's address.\n    // This is not using any debug information, but the addresses ranges of\n    // all the currently loaded binary object.\n    Dwfl_Module* mod = dwfl_addrmodule(_dwfl_handle.get(), trace_addr);\n    if (mod) {\n      // now that we found it, lets get the name of it, this will be the\n      // full path to the running 
binary or one of the loaded libraries.\n      const char* module_name = dwfl_module_info(mod, 0, 0, 0, 0, 0, 0, 0);\n      if (module_name) { trace.object_filename = module_name; }\n      // We also look for the name of the symbol at or before this\n      // address. This is found by walking the symtab. We should get the\n      // symbol corresponding to the function (mangled) containing the\n      // address. If the code corresponding to the address was inlined,\n      // this is the name of the outermost inliner function.\n      const char* sym_name = dwfl_module_addrname(mod, trace_addr);\n      if (sym_name) { trace.object_function = demangle(sym_name); }\n    }\n\n    // now let's get serious, and find out the source location (file and\n    // line number) of the address.\n\n    // This function will look in .debug_aranges for the address and map it\n    // to the location of the compilation unit DIE in .debug_info and\n    // return it.\n    Dwarf_Addr mod_bias = 0;\n    Dwarf_Die* cudie = dwfl_module_addrdie(mod, trace_addr, &mod_bias);\n\n#if 1\n    if (!cudie) {\n      // Sadly clang does not generate the section .debug_aranges, thus\n      // dwfl_module_addrdie will fail early. Clang also doesn't set\n      // the lowpc/highpc/range info for every compilation unit.\n      //\n      // So in order to save the world:\n      // for every compilation unit, we will iterate over every single\n      // DIE. Normally functions should have a lowpc/highpc/range, which\n      // we will use to infer the compilation unit.\n\n      // note that this is probably quite inefficient.\n      while ((cudie = dwfl_module_nextcu(mod, cudie, &mod_bias))) {\n        Dwarf_Die die_mem;\n        Dwarf_Die* fundie = find_fundie_by_pc(cudie, trace_addr - mod_bias, &die_mem);\n        if (fundie) { break; }\n      }\n    }\n#endif\n\n//#define BACKWARD_I_DO_NOT_RECOMMEND_TO_ENABLE_THIS_HORRIBLE_PIECE_OF_CODE\n#ifdef BACKWARD_I_DO_NOT_RECOMMEND_TO_ENABLE_THIS_HORRIBLE_PIECE_OF_CODE\n    if (!cudie) {\n      // If it's still not enough, let's dive deeper and try\n      // to save the world again: for every compilation unit, we will\n      // load the corresponding .debug_line section, and see if we can\n      // find our address in it.\n\n      Dwarf_Addr cfi_bias;\n      Dwarf_CFI* cfi_cache = dwfl_module_eh_cfi(mod, &cfi_bias);\n\n      Dwarf_Addr bias;\n      while ((cudie = dwfl_module_nextcu(mod, cudie, &bias))) {\n        if (dwarf_getsrc_die(cudie, trace_addr - bias)) {\n          // ...but if we get a match, it might be a false positive\n          // because our (address - bias) might as well be valid in a\n          // different compilation unit. 
So we throw our last card on\n          // the table and look up the address in the .eh_frame\n          // section.\n\n          handle<Dwarf_Frame*> frame;\n          dwarf_cfi_addrframe(cfi_cache, trace_addr - cfi_bias, &frame);\n          if (frame) { break; }\n        }\n      }\n    }\n#endif\n\n    if (!cudie) {\n      return trace;  // this time we lost the game :/\n    }\n\n    // Now that we have a compilation unit DIE, this function will be able\n    // to load the corresponding section in .debug_line (if not already\n    // loaded) and hopefully find the source location mapped to our\n    // address.\n    Dwarf_Line* srcloc = dwarf_getsrc_die(cudie, trace_addr - mod_bias);\n\n    if (srcloc) {\n      const char* srcfile = dwarf_linesrc(srcloc, 0, 0);\n      if (srcfile) { trace.source.filename = srcfile; }\n      int line = 0, col = 0;\n      dwarf_lineno(srcloc, &line);\n      dwarf_linecol(srcloc, &col);\n      trace.source.line = static_cast<unsigned>(line);\n      trace.source.col = static_cast<unsigned>(col);\n    }\n\n    deep_first_search_by_pc(cudie, trace_addr - mod_bias, inliners_search_cb(trace));\n    if (trace.source.function.size() == 0) {\n      // fallback.\n      trace.source.function = trace.object_function;\n    }\n\n    return trace;\n  }\n\n private:\n  typedef details::handle<Dwfl*, details::deleter<void, Dwfl*, &dwfl_end>> dwfl_handle_t;\n  details::handle<Dwfl_Callbacks*, details::default_delete<Dwfl_Callbacks*>> _dwfl_cb;\n  dwfl_handle_t _dwfl_handle;\n  bool _dwfl_handle_initialized;\n\n  // defined here because in C++98, template functions cannot take locally\n  // defined types... grrr.\n  struct inliners_search_cb {\n    void operator()(Dwarf_Die* die) {\n      switch (dwarf_tag(die)) {\n        const char* name;\n        case DW_TAG_subprogram:\n          if ((name = dwarf_diename(die))) { trace.source.function = name; }\n          break;\n\n        case DW_TAG_inlined_subroutine:\n          ResolvedTrace::SourceLoc sloc;\n          Dwarf_Attribute attr_mem;\n\n          if ((name = dwarf_diename(die))) { sloc.function = name; }\n          if ((name = die_call_file(die))) { sloc.filename = name; }\n\n          Dwarf_Word line = 0, col = 0;\n          dwarf_formudata(dwarf_attr(die, DW_AT_call_line, &attr_mem), &line);\n          dwarf_formudata(dwarf_attr(die, DW_AT_call_column, &attr_mem), &col);\n          sloc.line = static_cast<unsigned>(line);\n          sloc.col = static_cast<unsigned>(col);\n\n          trace.inliners.push_back(sloc);\n          break;\n      };\n    }\n    ResolvedTrace& trace;\n    inliners_search_cb(ResolvedTrace& t) : trace(t) {}\n  };\n\n  static bool die_has_pc(Dwarf_Die* die, Dwarf_Addr pc) {\n    Dwarf_Addr low, high;\n\n    // continuous range\n    if (dwarf_hasattr(die, DW_AT_low_pc) && dwarf_hasattr(die, DW_AT_high_pc)) {\n      if (dwarf_lowpc(die, &low) != 0) { return false; }\n      if (dwarf_highpc(die, &high) != 0) {\n        Dwarf_Attribute attr_mem;\n        Dwarf_Attribute* attr = dwarf_attr(die, DW_AT_high_pc, &attr_mem);\n        Dwarf_Word value;\n        if (dwarf_formudata(attr, &value) != 0) { return false; }\n        high = low + value;\n      }\n      return pc >= low && pc < high;\n    }\n\n    // non-continuous range.\n    Dwarf_Addr base;\n    ptrdiff_t offset = 0;\n    while ((offset = dwarf_ranges(die, offset, &base, &low, &high)) > 0) {\n      if (pc >= low && pc < high) { return true; }\n    }\n    return false;\n  }\n\n
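  // Recursively walk the children of parent_die looking for a function DIE\n  // (DW_TAG_subprogram or DW_TAG_inlined_subroutine) whose PC range covers\n  // pc; this is the fallback used above when .debug_aranges is missing.\n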
  static Dwarf_Die* find_fundie_by_pc(Dwarf_Die* parent_die, Dwarf_Addr pc, Dwarf_Die* result) {\n    if (dwarf_child(parent_die, result) != 0) { return 0; }\n\n    Dwarf_Die* die = result;\n    do {\n      switch (dwarf_tag(die)) {\n        case DW_TAG_subprogram:\n        case DW_TAG_inlined_subroutine:\n          if (die_has_pc(die, pc)) { return result; }\n      };\n      bool declaration = false;\n      Dwarf_Attribute attr_mem;\n      dwarf_formflag(dwarf_attr(die, DW_AT_declaration, &attr_mem), &declaration);\n      if (!declaration) {\n        // let's be curious and look deeper in the tree,\n        // functions are not necessarily at the first level, but\n        // might be nested inside a namespace, structure etc.\n        Dwarf_Die die_mem;\n        Dwarf_Die* indie = find_fundie_by_pc(die, pc, &die_mem);\n        if (indie) {\n          *result = die_mem;\n          return result;\n        }\n      }\n    } while (dwarf_siblingof(die, result) == 0);\n    return 0;\n  }\n\n  template<typename CB>\n  static bool deep_first_search_by_pc(Dwarf_Die* parent_die, Dwarf_Addr pc, CB cb) {\n    Dwarf_Die die_mem;\n    if (dwarf_child(parent_die, &die_mem) != 0) { return false; }\n\n    bool branch_has_pc = false;\n    Dwarf_Die* die = &die_mem;\n    do {\n      bool declaration = false;\n      Dwarf_Attribute attr_mem;\n      dwarf_formflag(dwarf_attr(die, DW_AT_declaration, &attr_mem), &declaration);\n      if (!declaration) {\n        // let's be curious and look deeper in the tree, functions are\n        // not necessarily at the first level, but might be nested\n        // inside a namespace, structure, a function, an inlined\n        // function etc.\n        branch_has_pc = deep_first_search_by_pc(die, pc, cb);\n      }\n      if (!branch_has_pc) { branch_has_pc = die_has_pc(die, pc); }\n      if (branch_has_pc) { cb(die); }\n    } while (dwarf_siblingof(die, &die_mem) == 0);\n    return branch_has_pc;\n  }\n\n  static const char* die_call_file(Dwarf_Die* die) {\n    Dwarf_Attribute attr_mem;\n    Dwarf_Word file_idx = 0;\n\n    dwarf_formudata(dwarf_attr(die, DW_AT_call_file, &attr_mem), &file_idx);\n\n    if (file_idx == 0) { return 0; }\n\n    Dwarf_Die die_mem;\n    Dwarf_Die* cudie = dwarf_diecu(die, &die_mem, 0, 0);\n    if (!cudie) { return 0; }\n\n    Dwarf_Files* files = 0;\n    size_t nfiles;\n    dwarf_getsrcfiles(cudie, &files, &nfiles);\n    if (!files) { return 0; }\n\n    return dwarf_filesrc(files, file_idx, 0, 0);\n  }\n};\n#endif  // BACKWARD_HAS_DW == 1\n\n#if BACKWARD_HAS_DWARF == 1\n\ntemplate<>\nclass TraceResolverLinuxImpl<trace_resolver_tag::libdwarf> : public TraceResolverLinuxBase {\n public:\n  TraceResolverLinuxImpl() : _dwarf_loaded(false) {}\n\n  ResolvedTrace resolve(ResolvedTrace trace) override {\n    // trace.addr is a virtual address in memory pointing to some code.\n    // Let's try to find from which loaded object it comes from.\n    // The loaded object can be yourself btw.\n\n    Dl_info symbol_info;\n    int dladdr_result = 0;\n#if defined(__GLIBC__)\n    link_map* link_map;\n    // We request the link map so we can get information about offsets\n    dladdr_result =\n        dladdr1(trace.addr, &symbol_info, reinterpret_cast<void**>(&link_map), RTLD_DL_LINKMAP);\n#else\n    // Android doesn't have dladdr1. 
Don't use the linker map.\n    dladdr_result = dladdr(trace.addr, &symbol_info);\n#endif\n    if (!dladdr_result) {\n      return trace;  // dat broken trace...\n    }\n\n    // Now we get in symbol_info:\n    // .dli_fname:\n    //      pathname of the shared object that contains the address.\n    // .dli_fbase:\n    //      where the object is loaded in memory.\n    // .dli_sname:\n    //      the name of the nearest symbol to trace.addr, we expect a\n    //      function name.\n    // .dli_saddr:\n    //      the exact address corresponding to .dli_sname.\n    //\n    // And in link_map:\n    // .l_addr:\n    //      difference between the address in the ELF file and the address\n    //      in memory\n    // l_name:\n    //      absolute pathname where the object was found\n\n    if (symbol_info.dli_sname) { trace.object_function = demangle(symbol_info.dli_sname); }\n\n    if (!symbol_info.dli_fname) { return trace; }\n\n    trace.object_filename = resolve_exec_path(symbol_info);\n    dwarf_fileobject& fobj = load_object_with_dwarf(symbol_info.dli_fname);\n    if (!fobj.dwarf_handle) {\n      return trace;  // sad, we couldn't load the object :(\n    }\n\n#if defined(__GLIBC__)\n    // Convert the address to a module relative one by looking at\n    // the module's loading address in the link map\n    Dwarf_Addr address =\n        reinterpret_cast<uintptr_t>(trace.addr) - reinterpret_cast<uintptr_t>(link_map->l_addr);\n#else\n    Dwarf_Addr address = reinterpret_cast<uintptr_t>(trace.addr);\n#endif\n\n    if (trace.object_function.empty()) {\n      symbol_cache_t::iterator it = fobj.symbol_cache.lower_bound(address);\n\n      if (it != fobj.symbol_cache.end()) {\n        if (it->first != address) {\n          if (it != fobj.symbol_cache.begin()) { --it; }\n        }\n        trace.object_function = demangle(it->second.c_str());\n      }\n    }\n\n    // Get the Compilation Unit DIE for the address\n    Dwarf_Die die = find_die(fobj, address);\n\n    if (!die) {\n      return trace;  // this time we lost the game :/\n    }\n\n    // libdwarf doesn't give us direct access to its objects, it always\n    // allocates a copy for the caller. We keep that copy alive in a cache\n    // and we deallocate it later when it's no longer required.\n    die_cache_entry& die_object = get_die_cache(fobj, die);\n    if (die_object.isEmpty()) return trace;  // We have no line section for this DIE\n\n    die_linemap_t::iterator it = die_object.line_section.lower_bound(address);\n\n    if (it != die_object.line_section.end()) {\n      if (it->first != address) {\n        if (it == die_object.line_section.begin()) {\n          // If we are on the first item of the line section\n          // but the address does not match it means that\n          // the address is below the range of the DIE. 
Give up.\n          return trace;\n        } else {\n          --it;\n        }\n      }\n    } else {\n      return trace;  // We didn't find the address.\n    }\n\n    // Get the Dwarf_Line that the address points to and call libdwarf\n    // to get source file, line and column info.\n    Dwarf_Line line = die_object.line_buffer[it->second];\n    Dwarf_Error error = DW_DLE_NE;\n\n    char* filename;\n    if (dwarf_linesrc(line, &filename, &error) == DW_DLV_OK) {\n      trace.source.filename = std::string(filename);\n      dwarf_dealloc(fobj.dwarf_handle.get(), filename, DW_DLA_STRING);\n    }\n\n    Dwarf_Unsigned number = 0;\n    if (dwarf_lineno(line, &number, &error) == DW_DLV_OK) {\n      trace.source.line = number;\n    } else {\n      trace.source.line = 0;\n    }\n\n    if (dwarf_lineoff_b(line, &number, &error) == DW_DLV_OK) {\n      trace.source.col = number;\n    } else {\n      trace.source.col = 0;\n    }\n\n    std::vector<std::string> namespace_stack;\n    deep_first_search_by_pc(fobj, die, address, namespace_stack,\n                            inliners_search_cb(trace, fobj, die));\n\n    dwarf_dealloc(fobj.dwarf_handle.get(), die, DW_DLA_DIE);\n\n    return trace;\n  }\n\n public:\n  static int close_dwarf(Dwarf_Debug dwarf) { return dwarf_finish(dwarf, NULL); }\n\n private:\n  bool _dwarf_loaded;\n\n  typedef details::handle<int, details::deleter<int, int, &::close>> dwarf_file_t;\n\n  typedef details::handle<Elf*, details::deleter<int, Elf*, &elf_end>> dwarf_elf_t;\n\n  typedef details::handle<Dwarf_Debug, details::deleter<int, Dwarf_Debug, &close_dwarf>>\n      dwarf_handle_t;\n\n  typedef std::map<Dwarf_Addr, int> die_linemap_t;\n\n  typedef std::map<Dwarf_Off, Dwarf_Off> die_specmap_t;\n\n  struct die_cache_entry {\n    die_specmap_t spec_section;\n    die_linemap_t line_section;\n    Dwarf_Line* line_buffer;\n    Dwarf_Signed line_count;\n    Dwarf_Line_Context line_context;\n\n    inline bool isEmpty() {\n      return line_buffer == NULL || line_count == 0 || line_context == NULL || line_section.empty();\n    }\n\n    die_cache_entry() : line_buffer(0), line_count(0), line_context(0) {}\n\n    ~die_cache_entry() {\n      if (line_context) { dwarf_srclines_dealloc_b(line_context); }\n    }\n  };\n\n  typedef std::map<Dwarf_Off, die_cache_entry> die_cache_t;\n\n  typedef std::map<uintptr_t, std::string> symbol_cache_t;\n\n  struct dwarf_fileobject {\n    dwarf_file_t file_handle;\n    dwarf_elf_t elf_handle;\n    dwarf_handle_t dwarf_handle;\n    symbol_cache_t symbol_cache;\n\n    // Die cache\n    die_cache_t die_cache;\n    die_cache_entry* current_cu;\n  };\n\n  typedef details::hashtable<std::string, dwarf_fileobject>::type fobj_dwarf_map_t;\n  fobj_dwarf_map_t _fobj_dwarf_map;\n\n  static bool cstrings_eq(const char* a, const char* b) {\n    if (!a || !b) { return false; }\n    return strcmp(a, b) == 0;\n  }\n\n  dwarf_fileobject& load_object_with_dwarf(const std::string& filename_object) {\n    if (!_dwarf_loaded) {\n      // Set the ELF library operating version\n      // If that fails there's nothing we can do\n      _dwarf_loaded = elf_version(EV_CURRENT) != EV_NONE;\n    }\n\n    fobj_dwarf_map_t::iterator it = _fobj_dwarf_map.find(filename_object);\n    if (it != _fobj_dwarf_map.end()) { return it->second; }\n\n    // this new object is empty for now\n    dwarf_fileobject& r = _fobj_dwarf_map[filename_object];\n\n    dwarf_file_t file_handle;\n    file_handle.reset(open(filename_object.c_str(), O_RDONLY));\n    if (file_handle.get() < 0) { return r; }\n\n 
   // Try to get an ELF handle. We need to read the ELF sections\n    // because we want to see if there is a .gnu_debuglink section\n    // that points to a split debug file\n    dwarf_elf_t elf_handle;\n    elf_handle.reset(elf_begin(file_handle.get(), ELF_C_READ, NULL));\n    if (!elf_handle) { return r; }\n\n    const char* e_ident = elf_getident(elf_handle.get(), 0);\n    if (!e_ident) { return r; }\n\n    // Get the number of sections\n    // We use the new APIs as elf_getshnum is deprecated\n    size_t shdrnum = 0;\n    if (elf_getshdrnum(elf_handle.get(), &shdrnum) == -1) { return r; }\n\n    // Get the index to the string section\n    size_t shdrstrndx = 0;\n    if (elf_getshdrstrndx(elf_handle.get(), &shdrstrndx) == -1) { return r; }\n\n    std::string debuglink;\n    // Iterate through the ELF sections to try to get a gnu_debuglink\n    // note and also to cache the symbol table.\n    // We go the preprocessor way to avoid having to create templated\n    // classes or using gelf (which might throw a compiler error if 64 bit\n    // is not supported\n#define ELF_GET_DATA(ARCH)                                                                 \\\n  Elf_Scn* elf_section = 0;                                                                \\\n  Elf_Data* elf_data = 0;                                                                  \\\n  Elf##ARCH##_Shdr* section_header = 0;                                                    \\\n  Elf_Scn* symbol_section = 0;                                                             \\\n  size_t symbol_count = 0;                                                                 \\\n  size_t symbol_strings = 0;                                                               \\\n  Elf##ARCH##_Sym* symbol = 0;                                                             \\\n  const char* section_name = 0;                                                            \\\n                                                                                           \\\n  while ((elf_section = elf_nextscn(elf_handle.get(), elf_section)) != NULL) {             \\\n    section_header = elf##ARCH##_getshdr(elf_section);                                     \\\n    if (section_header == NULL) { return r; }                                              \\\n                                                                                           \\\n    if ((section_name = elf_strptr(elf_handle.get(), shdrstrndx, section_header->sh_name)) \\\n        == NULL) {                                                                         \\\n      return r;                                                                            \\\n    }                                                                                      \\\n                                                                                           \\\n    if (cstrings_eq(section_name, \".gnu_debuglink\")) {                                     \\\n      elf_data = elf_getdata(elf_section, NULL);                                           \\\n      if (elf_data && elf_data->d_size > 0) {                                              \\\n        debuglink = std::string(reinterpret_cast<const char*>(elf_data->d_buf));           \\\n      }                                                                                    \\\n    }                                                                                      \\\n                                                                                           \\\n    
switch (section_header->sh_type) {                                                     \\\n      case SHT_SYMTAB:                                                                     \\\n        symbol_section = elf_section;                                                      \\\n        symbol_count = section_header->sh_size / section_header->sh_entsize;               \\\n        symbol_strings = section_header->sh_link;                                          \\\n        break;                                                                             \\\n                                                                                           \\\n      /* We use .dynsyms as a last resort, we prefer .symtab */                            \\\n      case SHT_DYNSYM:                                                                     \\\n        if (!symbol_section) {                                                             \\\n          symbol_section = elf_section;                                                    \\\n          symbol_count = section_header->sh_size / section_header->sh_entsize;             \\\n          symbol_strings = section_header->sh_link;                                        \\\n        }                                                                                  \\\n        break;                                                                             \\\n    }                                                                                      \\\n  }                                                                                        \\\n                                                                                           \\\n  if (symbol_section && symbol_count && symbol_strings) {                                  \\\n    elf_data = elf_getdata(symbol_section, NULL);                                          \\\n    symbol = reinterpret_cast<Elf##ARCH##_Sym*>(elf_data->d_buf);                          \\\n    for (size_t i = 0; i < symbol_count; ++i) {                                            \\\n      int type = ELF##ARCH##_ST_TYPE(symbol->st_info);                                     \\\n      if (type == STT_FUNC && symbol->st_value > 0) {                                      \\\n        r.symbol_cache[symbol->st_value] =                                                 \\\n            std::string(elf_strptr(elf_handle.get(), symbol_strings, symbol->st_name));    \\\n      }                                                                                    \\\n      ++symbol;                                                                            \\\n    }                                                                                      \\\n  }\n\n    if (e_ident[EI_CLASS] == ELFCLASS32) {\n      ELF_GET_DATA(32)\n    } else if (e_ident[EI_CLASS] == ELFCLASS64) {\n      // libelf might have been built without 64 bit support\n#if __LIBELF64\n      ELF_GET_DATA(64)\n#endif\n    }\n\n    if (!debuglink.empty()) {\n      // We have a debuglink section! Open an elf instance on that\n      // file instead. 
If we can't open the file, then return\n      // the elf handle we had already opened.\n      dwarf_file_t debuglink_file;\n      debuglink_file.reset(open(debuglink.c_str(), O_RDONLY));\n      if (debuglink_file.get() > 0) {\n        dwarf_elf_t debuglink_elf;\n        debuglink_elf.reset(elf_begin(debuglink_file.get(), ELF_C_READ, NULL));\n\n        // If we have a valid elf handle, return the new elf handle\n        // and file handle and discard the original ones\n        if (debuglink_elf) {\n          elf_handle = move(debuglink_elf);\n          file_handle = move(debuglink_file);\n        }\n      }\n    }\n\n    // Ok, we have a valid ELF handle, let's try to get debug symbols\n    Dwarf_Debug dwarf_debug;\n    Dwarf_Error error = DW_DLE_NE;\n    dwarf_handle_t dwarf_handle;\n\n    int dwarf_result =\n        dwarf_elf_init(elf_handle.get(), DW_DLC_READ, NULL, NULL, &dwarf_debug, &error);\n\n    // We don't do any special handling for DW_DLV_NO_ENTRY.\n    // If we get an error, or the file doesn't have debug information,\n    // we just return.\n    if (dwarf_result != DW_DLV_OK) { return r; }\n\n    dwarf_handle.reset(dwarf_debug);\n\n    r.file_handle = move(file_handle);\n    r.elf_handle = move(elf_handle);\n    r.dwarf_handle = move(dwarf_handle);\n\n    return r;\n  }\n\n  die_cache_entry& get_die_cache(dwarf_fileobject& fobj, Dwarf_Die die) {\n    Dwarf_Error error = DW_DLE_NE;\n\n    // Get the die offset, we use it as the cache key\n    Dwarf_Off die_offset;\n    if (dwarf_dieoffset(die, &die_offset, &error) != DW_DLV_OK) { die_offset = 0; }\n\n    die_cache_t::iterator it = fobj.die_cache.find(die_offset);\n\n    if (it != fobj.die_cache.end()) {\n      fobj.current_cu = &it->second;\n      return it->second;\n    }\n\n    die_cache_entry& de = fobj.die_cache[die_offset];\n    fobj.current_cu = &de;\n\n    Dwarf_Addr line_addr;\n    Dwarf_Small table_count;\n\n    // The addresses in the line section are not fully sorted (they might\n    // be sorted by block of code belonging to the same file), which makes\n    // it necessary to do so before searching is possible.\n    //\n    // As libdwarf allocates a copy of everything, let's get the contents\n    // of the line section and keep it around. We also create a map of\n    // program counter to line table indices so we can search by address\n    // and get the line buffer index.\n    //\n    // To make things more difficult, the same address can span more than\n    // one line, so we need to keep the index pointing to the first line\n    // by using insert instead of the map's [] operator.\n\n    // Get the line context for the DIE\n    if (dwarf_srclines_b(die, 0, &table_count, &de.line_context, &error) == DW_DLV_OK) {\n      // Get the source lines for this line context, to be deallocated\n      // later\n      if (dwarf_srclines_from_linecontext(de.line_context, &de.line_buffer, &de.line_count, &error)\n          == DW_DLV_OK) {\n        // Add all the addresses to our map\n        for (int i = 0; i < de.line_count; i++) {\n          if (dwarf_lineaddr(de.line_buffer[i], &line_addr, &error) != DW_DLV_OK) { line_addr = 0; }\n          de.line_section.insert(std::pair<Dwarf_Addr, int>(line_addr, i));\n        }\n      }\n    }\n\n
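    // At this point line_section maps each program counter to the index of\n    // the first Dwarf_Line covering it, so resolve() can find a source line\n    // with a simple lower_bound() lookup by address.\n\n    // For each CU, cache the function DIEs that contain the\n    // DW_AT_specification attribute. 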
When building with -g3 the function\n    // DIEs are separated in declaration and specification, with the\n    // declaration containing only the name and parameters and the\n    // specification the low/high pc and other compiler attributes.\n    //\n    // We cache those specifications so we don't skip over the declarations,\n    // because they have no pc, and we can do namespace resolution for\n    // DWARF function names.\n    Dwarf_Debug dwarf = fobj.dwarf_handle.get();\n    Dwarf_Die current_die = 0;\n    if (dwarf_child(die, &current_die, &error) == DW_DLV_OK) {\n      for (;;) {\n        Dwarf_Die sibling_die = 0;\n\n        Dwarf_Half tag_value;\n        dwarf_tag(current_die, &tag_value, &error);\n\n        if (tag_value == DW_TAG_subprogram || tag_value == DW_TAG_inlined_subroutine) {\n          Dwarf_Bool has_attr = 0;\n          if (dwarf_hasattr(current_die, DW_AT_specification, &has_attr, &error) == DW_DLV_OK) {\n            if (has_attr) {\n              Dwarf_Attribute attr_mem;\n              if (dwarf_attr(current_die, DW_AT_specification, &attr_mem, &error) == DW_DLV_OK) {\n                Dwarf_Off spec_offset = 0;\n                if (dwarf_formref(attr_mem, &spec_offset, &error) == DW_DLV_OK) {\n                  Dwarf_Off spec_die_offset;\n                  if (dwarf_dieoffset(current_die, &spec_die_offset, &error) == DW_DLV_OK) {\n                    de.spec_section[spec_offset] = spec_die_offset;\n                  }\n                }\n              }\n              dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR);\n            }\n          }\n        }\n\n        int result = dwarf_siblingof(dwarf, current_die, &sibling_die, &error);\n        if (result == DW_DLV_ERROR) {\n          break;\n        } else if (result == DW_DLV_NO_ENTRY) {\n          break;\n        }\n\n        if (current_die != die) {\n          dwarf_dealloc(dwarf, current_die, DW_DLA_DIE);\n          current_die = 0;\n        }\n\n        current_die = sibling_die;\n      }\n    }\n    return de;\n  }\n\n  static Dwarf_Die get_referenced_die(Dwarf_Debug dwarf, Dwarf_Die die, Dwarf_Half attr,\n                                      bool global) {\n    Dwarf_Error error = DW_DLE_NE;\n    Dwarf_Attribute attr_mem;\n\n    Dwarf_Die found_die = NULL;\n    if (dwarf_attr(die, attr, &attr_mem, &error) == DW_DLV_OK) {\n      Dwarf_Off offset;\n      int result = 0;\n      if (global) {\n        result = dwarf_global_formref(attr_mem, &offset, &error);\n      } else {\n        result = dwarf_formref(attr_mem, &offset, &error);\n      }\n\n      if (result == DW_DLV_OK) {\n        if (dwarf_offdie(dwarf, offset, &found_die, &error) != DW_DLV_OK) { found_die = NULL; }\n      }\n      dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR);\n    }\n    return found_die;\n  }\n\n  static std::string get_referenced_die_name(Dwarf_Debug dwarf, Dwarf_Die die, Dwarf_Half attr,\n                                             bool global) {\n    Dwarf_Error error = DW_DLE_NE;\n    std::string value;\n\n    Dwarf_Die found_die = get_referenced_die(dwarf, die, attr, global);\n\n    if (found_die) {\n      char* name;\n      if (dwarf_diename(found_die, &name, &error) == DW_DLV_OK) {\n        if (name) { value = std::string(name); }\n        dwarf_dealloc(dwarf, name, DW_DLA_STRING);\n      }\n      dwarf_dealloc(dwarf, found_die, DW_DLA_DIE);\n    }\n\n    return value;\n  }\n\n  // Returns a spec DIE linked to the passed one. 
The caller should\n  // deallocate the DIE\n  static Dwarf_Die get_spec_die(dwarf_fileobject& fobj, Dwarf_Die die) {\n    Dwarf_Debug dwarf = fobj.dwarf_handle.get();\n    Dwarf_Error error = DW_DLE_NE;\n    Dwarf_Off die_offset;\n    if (fobj.current_cu && dwarf_die_CU_offset(die, &die_offset, &error) == DW_DLV_OK) {\n      die_specmap_t::iterator it = fobj.current_cu->spec_section.find(die_offset);\n\n      // If we have a DIE that completes the current one, check if\n      // that one has the pc we are looking for\n      if (it != fobj.current_cu->spec_section.end()) {\n        Dwarf_Die spec_die = 0;\n        if (dwarf_offdie(dwarf, it->second, &spec_die, &error) == DW_DLV_OK) { return spec_die; }\n      }\n    }\n\n    // Maybe we have an abstract origin DIE with the function information?\n    return get_referenced_die(fobj.dwarf_handle.get(), die, DW_AT_abstract_origin, true);\n  }\n\n  static bool die_has_pc(dwarf_fileobject& fobj, Dwarf_Die die, Dwarf_Addr pc) {\n    Dwarf_Addr low_pc = 0, high_pc = 0;\n    Dwarf_Half high_pc_form = 0;\n    Dwarf_Form_Class return_class;\n    Dwarf_Error error = DW_DLE_NE;\n    Dwarf_Debug dwarf = fobj.dwarf_handle.get();\n    bool has_lowpc = false;\n    bool has_highpc = false;\n    bool has_ranges = false;\n\n    if (dwarf_lowpc(die, &low_pc, &error) == DW_DLV_OK) {\n      // If we have a low_pc check if there is a high pc.\n      // If we don't have a high pc this might mean we have a base\n      // address for the ranges list or just an address.\n      has_lowpc = true;\n\n      if (dwarf_highpc_b(die, &high_pc, &high_pc_form, &return_class, &error) == DW_DLV_OK) {\n        // We do have a high pc. In DWARF 4+ this is an offset from the\n        // low pc, but in earlier versions it's an absolute address.\n\n        has_highpc = true;\n        // In DWARF 2/3 this would be a DW_FORM_CLASS_ADDRESS\n        if (return_class == DW_FORM_CLASS_CONSTANT) { high_pc = low_pc + high_pc; }\n\n        // We have low and high pc, check if our address\n        // is in that range\n        return pc >= low_pc && pc < high_pc;\n      }\n    } else {\n      // Reset the low_pc, in case a failing dwarf_lowpc set it to some\n      // undefined value.\n      low_pc = 0;\n    }\n\n    // Check if DW_AT_ranges is present and search for the PC in the\n    // returned ranges list. We always add the low_pc, as if it is not set\n    // it will be 0, in case we had a DW_AT_low_pc and DW_AT_ranges pair\n    bool result = false;\n\n    Dwarf_Attribute attr;\n    if (dwarf_attr(die, DW_AT_ranges, &attr, &error) == DW_DLV_OK) {\n      Dwarf_Off offset;\n      if (dwarf_global_formref(attr, &offset, &error) == DW_DLV_OK) {\n        Dwarf_Ranges* ranges;\n        Dwarf_Signed ranges_count = 0;\n        Dwarf_Unsigned byte_count = 0;\n\n        if (dwarf_get_ranges_a(dwarf, offset, die, &ranges, &ranges_count, &byte_count, &error)\n            == DW_DLV_OK) {\n          has_ranges = ranges_count != 0;\n          for (int i = 0; i < ranges_count; i++) {\n            if (ranges[i].dwr_addr1 != 0 && pc >= ranges[i].dwr_addr1 + low_pc\n                && pc < ranges[i].dwr_addr2 + low_pc) {\n              result = true;\n              break;\n            }\n          }\n          dwarf_ranges_dealloc(dwarf, ranges, ranges_count);\n        }\n      }\n    }\n\n    // Last attempt. 
We might have a single address set as low_pc.\n    if (!result && low_pc != 0 && pc == low_pc) { result = true; }\n\n    // If we don't have lowpc, highpc and ranges maybe this DIE is a\n    // declaration that relies on a DW_AT_specification DIE that happens\n    // later. Use the specification cache we filled when we loaded this CU.\n    if (!result && (!has_lowpc && !has_highpc && !has_ranges)) {\n      Dwarf_Die spec_die = get_spec_die(fobj, die);\n      if (spec_die) {\n        result = die_has_pc(fobj, spec_die, pc);\n        dwarf_dealloc(dwarf, spec_die, DW_DLA_DIE);\n      }\n    }\n\n    return result;\n  }\n\n  static void get_type(Dwarf_Debug dwarf, Dwarf_Die die, std::string& type) {\n    Dwarf_Error error = DW_DLE_NE;\n\n    Dwarf_Die child = 0;\n    if (dwarf_child(die, &child, &error) == DW_DLV_OK) { get_type(dwarf, child, type); }\n\n    if (child) {\n      type.insert(0, \"::\");\n      dwarf_dealloc(dwarf, child, DW_DLA_DIE);\n    }\n\n    char* name;\n    if (dwarf_diename(die, &name, &error) == DW_DLV_OK) {\n      type.insert(0, std::string(name));\n      dwarf_dealloc(dwarf, name, DW_DLA_STRING);\n    } else {\n      type.insert(0, \"<unknown>\");\n    }\n  }\n\n  static std::string get_type_by_signature(Dwarf_Debug dwarf, Dwarf_Die die) {\n    Dwarf_Error error = DW_DLE_NE;\n\n    Dwarf_Sig8 signature;\n    Dwarf_Bool has_attr = 0;\n    if (dwarf_hasattr(die, DW_AT_signature, &has_attr, &error) == DW_DLV_OK) {\n      if (has_attr) {\n        Dwarf_Attribute attr_mem;\n        if (dwarf_attr(die, DW_AT_signature, &attr_mem, &error) == DW_DLV_OK) {\n          if (dwarf_formsig8(attr_mem, &signature, &error) != DW_DLV_OK) {\n            return std::string(\"<no type signature>\");\n          }\n        }\n        dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR);\n      }\n    }\n\n    Dwarf_Unsigned next_cu_header;\n    Dwarf_Sig8 tu_signature;\n    std::string result;\n    bool found = false;\n\n    while (dwarf_next_cu_header_d(dwarf, 0, 0, 0, 0, 0, 0, 0, &tu_signature, 0, &next_cu_header, 0,\n                                  &error)\n           == DW_DLV_OK) {\n      if (strncmp(signature.signature, tu_signature.signature, 8) == 0) {\n        Dwarf_Die type_cu_die = 0;\n        if (dwarf_siblingof_b(dwarf, 0, 0, &type_cu_die, &error) == DW_DLV_OK) {\n          Dwarf_Die child_die = 0;\n          if (dwarf_child(type_cu_die, &child_die, &error) == DW_DLV_OK) {\n            get_type(dwarf, child_die, result);\n            found = !result.empty();\n            dwarf_dealloc(dwarf, child_die, DW_DLA_DIE);\n          }\n          dwarf_dealloc(dwarf, type_cu_die, DW_DLA_DIE);\n        }\n      }\n    }\n\n    if (found) {\n      while (dwarf_next_cu_header_d(dwarf, 0, 0, 0, 0, 0, 0, 0, 0, 0, &next_cu_header, 0, &error)\n             == DW_DLV_OK) {\n        // Reset the cu header state. Unfortunately, libdwarf's\n        // next_cu_header API keeps its own iterator per Dwarf_Debug\n        // that can't be reset. 
We need to keep fetching elements until\n          // the end.\n      }\n    } else {\n      // If we couldn't resolve the type, just print out the signature\n      std::ostringstream string_stream;\n      string_stream << \"<0x\" << std::hex << std::setfill('0');\n      for (int i = 0; i < 8; ++i) {\n        string_stream << std::setw(2) << std::hex << (int)(unsigned char)(signature.signature[i]);\n      }\n      string_stream << \">\";\n      result = string_stream.str();\n    }\n    return result;\n  }\n\n  struct type_context_t {\n    bool is_const;\n    bool is_typedef;\n    bool has_type;\n    bool has_name;\n    std::string text;\n\n    type_context_t() : is_const(false), is_typedef(false), has_type(false), has_name(false) {}\n  };\n\n  // Types are resolved from right to left: we get the variable name first\n  // and then all specifiers (like const or pointer) in a chain of DW_AT_type\n  // DIEs. Call this function recursively until we get a complete type\n  // string.\n  static void set_parameter_string(dwarf_fileobject& fobj, Dwarf_Die die, type_context_t& context) {\n    char* name;\n    Dwarf_Error error = DW_DLE_NE;\n\n    // typedefs also contain the base type, so we skip it and only\n    // print the typedef name\n    if (!context.is_typedef) {\n      if (dwarf_diename(die, &name, &error) == DW_DLV_OK) {\n        if (!context.text.empty()) { context.text.insert(0, \" \"); }\n        context.text.insert(0, std::string(name));\n        dwarf_dealloc(fobj.dwarf_handle.get(), name, DW_DLA_STRING);\n      }\n    } else {\n      context.is_typedef = false;\n      context.has_type = true;\n      if (context.is_const) {\n        context.text.insert(0, \"const \");\n        context.is_const = false;\n      }\n    }\n\n    bool next_type_is_const = false;\n    bool is_keyword = true;\n\n    Dwarf_Half tag = 0;\n    Dwarf_Bool has_attr = 0;\n    if (dwarf_tag(die, &tag, &error) == DW_DLV_OK) {\n      switch (tag) {\n        case DW_TAG_structure_type:\n        case DW_TAG_union_type:\n        case DW_TAG_class_type:\n        case DW_TAG_enumeration_type:\n          context.has_type = true;\n          if (dwarf_hasattr(die, DW_AT_signature, &has_attr, &error) == DW_DLV_OK) {\n            // If we have a signature it means the type is defined\n            // in .debug_types, so we need to load the DIE pointed\n            // at by the signature and resolve it\n            if (has_attr) {\n              std::string type = get_type_by_signature(fobj.dwarf_handle.get(), die);\n              if (context.is_const) type.insert(0, \"const \");\n\n              if (!context.text.empty()) context.text.insert(0, \" \");\n              context.text.insert(0, type);\n            }\n\n            // Treat enums like typedefs, and skip printing their\n            // base type\n            context.is_typedef = (tag == DW_TAG_enumeration_type);\n          }\n          break;\n        case DW_TAG_const_type: next_type_is_const = true; break;\n        case DW_TAG_pointer_type: context.text.insert(0, \"*\"); break;\n        case DW_TAG_reference_type: context.text.insert(0, \"&\"); break;\n        case DW_TAG_restrict_type: context.text.insert(0, \"restrict \"); break;\n        case DW_TAG_rvalue_reference_type: context.text.insert(0, \"&&\"); break;\n        case DW_TAG_volatile_type: context.text.insert(0, \"volatile \"); break;\n        case DW_TAG_typedef:\n          // Propagate the const-ness to the next type\n          // as typedefs are linked to their base type\n          next_type_is_const = 
context.is_const;\n          context.is_typedef = true;\n          context.has_type = true;\n          break;\n        case DW_TAG_base_type: context.has_type = true; break;\n        case DW_TAG_formal_parameter: context.has_name = true; break;\n        default: is_keyword = false; break;\n      }\n    }\n\n    if (!is_keyword && context.is_const) { context.text.insert(0, \"const \"); }\n\n    context.is_const = next_type_is_const;\n\n    Dwarf_Die ref = get_referenced_die(fobj.dwarf_handle.get(), die, DW_AT_type, true);\n    if (ref) {\n      set_parameter_string(fobj, ref, context);\n      dwarf_dealloc(fobj.dwarf_handle.get(), ref, DW_DLA_DIE);\n    }\n\n    if (!context.has_type && context.has_name) {\n      context.text.insert(0, \"void \");\n      context.has_type = true;\n    }\n  }\n\n  // Resolve the function return type and parameters\n  static void set_function_parameters(std::string& function_name, std::vector<std::string>& ns,\n                                      dwarf_fileobject& fobj, Dwarf_Die die) {\n    Dwarf_Debug dwarf = fobj.dwarf_handle.get();\n    Dwarf_Error error = DW_DLE_NE;\n    Dwarf_Die current_die = 0;\n    std::string parameters;\n    bool has_spec = true;\n    // Check if we have a spec DIE. If we do we use it as it contains\n    // more information, like parameter names.\n    Dwarf_Die spec_die = get_spec_die(fobj, die);\n    if (!spec_die) {\n      has_spec = false;\n      spec_die = die;\n    }\n\n    std::vector<std::string>::const_iterator it = ns.begin();\n    std::string ns_name;\n    for (it = ns.begin(); it < ns.end(); ++it) { ns_name.append(*it).append(\"::\"); }\n\n    if (!ns_name.empty()) { function_name.insert(0, ns_name); }\n\n    // See if we have a function return type. It can be either on the\n    // current die or in its spec one (usually true for inlined functions)\n    std::string return_type = get_referenced_die_name(dwarf, die, DW_AT_type, true);\n    if (return_type.empty()) {\n      return_type = get_referenced_die_name(dwarf, spec_die, DW_AT_type, true);\n    }\n    if (!return_type.empty()) {\n      return_type.append(\" \");\n      function_name.insert(0, return_type);\n    }\n\n    if (dwarf_child(spec_die, &current_die, &error) == DW_DLV_OK) {\n      for (;;) {\n        Dwarf_Die sibling_die = 0;\n\n        Dwarf_Half tag_value;\n        dwarf_tag(current_die, &tag_value, &error);\n\n        if (tag_value == DW_TAG_formal_parameter) {\n          // Ignore artificial (ie, compiler generated) parameters\n          bool is_artificial = false;\n          Dwarf_Attribute attr_mem;\n          if (dwarf_attr(current_die, DW_AT_artificial, &attr_mem, &error) == DW_DLV_OK) {\n            Dwarf_Bool flag = 0;\n            if (dwarf_formflag(attr_mem, &flag, &error) == DW_DLV_OK) { is_artificial = flag != 0; }\n            dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR);\n          }\n\n          if (!is_artificial) {\n            type_context_t context;\n            set_parameter_string(fobj, current_die, context);\n\n            if (parameters.empty()) {\n              parameters.append(\"(\");\n            } else {\n              parameters.append(\", \");\n            }\n            parameters.append(context.text);\n          }\n        }\n\n        int result = dwarf_siblingof(dwarf, current_die, &sibling_die, &error);\n        if (result == DW_DLV_ERROR) {\n          break;\n        } else if (result == DW_DLV_NO_ENTRY) {\n          break;\n        }\n\n        if (current_die != die) {\n          dwarf_dealloc(dwarf, current_die, 
DW_DLA_DIE);\n          current_die = 0;\n        }\n\n        current_die = sibling_die;\n      }\n    }\n    if (parameters.empty()) parameters = \"(\";\n    parameters.append(\")\");\n\n    // If we got a spec DIE, we need to deallocate it\n    if (has_spec) dwarf_dealloc(dwarf, spec_die, DW_DLA_DIE);\n\n    function_name.append(parameters);\n  }\n\n  // defined here because in C++98, template functions cannot take locally\n  // defined types... grrr.\n  struct inliners_search_cb {\n    void operator()(Dwarf_Die die, std::vector<std::string>& ns) {\n      Dwarf_Error error = DW_DLE_NE;\n      Dwarf_Half tag_value;\n      Dwarf_Attribute attr_mem;\n      Dwarf_Debug dwarf = fobj.dwarf_handle.get();\n\n      dwarf_tag(die, &tag_value, &error);\n\n      switch (tag_value) {\n        char* name;\n        case DW_TAG_subprogram:\n          if (!trace.source.function.empty()) break;\n          if (dwarf_diename(die, &name, &error) == DW_DLV_OK) {\n            trace.source.function = std::string(name);\n            dwarf_dealloc(dwarf, name, DW_DLA_STRING);\n          } else {\n            // We don't have a function name in this DIE.\n            // Check if there is a referenced non-defining\n            // declaration.\n            trace.source.function =\n                get_referenced_die_name(dwarf, die, DW_AT_abstract_origin, true);\n            if (trace.source.function.empty()) {\n              trace.source.function =\n                  get_referenced_die_name(dwarf, die, DW_AT_specification, true);\n            }\n          }\n\n          // Append the function parameters, if available\n          set_function_parameters(trace.source.function, ns, fobj, die);\n\n          // If the object function name is empty, it's possible that\n          // there is no dynamic symbol table (maybe the executable\n          // was stripped or not built with -rdynamic). See if we have\n          // a DWARF linkage name to use instead. 
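(The linkage name is the mangled symbol; we demangle it below.) 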
We try both\n          // linkage_name and MIPS_linkage_name because the MIPS tag\n          // was the unofficial one until it was adopted in DWARF4.\n          // Old gcc versions generate MIPS_linkage_name\n          if (trace.object_function.empty()) {\n            details::demangler demangler;\n\n            if (dwarf_attr(die, DW_AT_linkage_name, &attr_mem, &error) != DW_DLV_OK) {\n              if (dwarf_attr(die, DW_AT_MIPS_linkage_name, &attr_mem, &error) != DW_DLV_OK) {\n                break;\n              }\n            }\n\n            char* linkage;\n            if (dwarf_formstring(attr_mem, &linkage, &error) == DW_DLV_OK) {\n              trace.object_function = demangler.demangle(linkage);\n              dwarf_dealloc(dwarf, linkage, DW_DLA_STRING);\n            }\n            dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR);\n          }\n          break;\n\n        case DW_TAG_inlined_subroutine:\n          ResolvedTrace::SourceLoc sloc;\n\n          if (dwarf_diename(die, &name, &error) == DW_DLV_OK) {\n            sloc.function = std::string(name);\n            dwarf_dealloc(dwarf, name, DW_DLA_STRING);\n          } else {\n            // We don't have a name for this inlined DIE, it could\n            // be that there is an abstract origin instead.\n            // Get the DW_AT_abstract_origin value, which is a\n            // reference to the source DIE and try to get its name\n            sloc.function = get_referenced_die_name(dwarf, die, DW_AT_abstract_origin, true);\n          }\n\n          set_function_parameters(sloc.function, ns, fobj, die);\n\n          std::string file = die_call_file(dwarf, die, cu_die);\n          if (!file.empty()) sloc.filename = file;\n\n          Dwarf_Unsigned number = 0;\n          if (dwarf_attr(die, DW_AT_call_line, &attr_mem, &error) == DW_DLV_OK) {\n            if (dwarf_formudata(attr_mem, &number, &error) == DW_DLV_OK) { sloc.line = number; }\n            dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR);\n          }\n\n          if (dwarf_attr(die, DW_AT_call_column, &attr_mem, &error) == DW_DLV_OK) {\n            if (dwarf_formudata(attr_mem, &number, &error) == DW_DLV_OK) { sloc.col = number; }\n            dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR);\n          }\n\n          trace.inliners.push_back(sloc);\n          break;\n      };\n    }\n    ResolvedTrace& trace;\n    dwarf_fileobject& fobj;\n    Dwarf_Die cu_die;\n    inliners_search_cb(ResolvedTrace& t, dwarf_fileobject& f, Dwarf_Die c)\n        : trace(t), fobj(f), cu_die(c) {}\n  };\n\n  static Dwarf_Die find_fundie_by_pc(dwarf_fileobject& fobj, Dwarf_Die parent_die, Dwarf_Addr pc,\n                                     Dwarf_Die result) {\n    Dwarf_Die current_die = 0;\n    Dwarf_Error error = DW_DLE_NE;\n    Dwarf_Debug dwarf = fobj.dwarf_handle.get();\n\n    if (dwarf_child(parent_die, &current_die, &error) != DW_DLV_OK) { return NULL; }\n\n    for (;;) {\n      Dwarf_Die sibling_die = 0;\n      Dwarf_Half tag_value;\n      dwarf_tag(current_die, &tag_value, &error);\n\n      switch (tag_value) {\n        case DW_TAG_subprogram:\n        case DW_TAG_inlined_subroutine:\n          if (die_has_pc(fobj, current_die, pc)) { return current_die; }\n      };\n      bool declaration = false;\n      Dwarf_Attribute attr_mem;\n      if (dwarf_attr(current_die, DW_AT_declaration, &attr_mem, &error) == DW_DLV_OK) {\n        Dwarf_Bool flag = 0;\n        if (dwarf_formflag(attr_mem, &flag, &error) == DW_DLV_OK) { declaration = flag != 0; }\n        dwarf_dealloc(dwarf, attr_mem, 
DW_DLA_ATTR);\n      }\n\n      if (!declaration) {\n        // let's be curious and look deeper in the tree, functions are\n        // not necessarily at the first level, but might be nested\n        // inside a namespace, structure, a function, an inlined\n        // function etc.\n        Dwarf_Die die_mem = 0;\n        Dwarf_Die indie = find_fundie_by_pc(fobj, current_die, pc, die_mem);\n        if (indie) {\n          result = die_mem;\n          return result;\n        }\n      }\n\n      int res = dwarf_siblingof(dwarf, current_die, &sibling_die, &error);\n      if (res == DW_DLV_ERROR) {\n        return NULL;\n      } else if (res == DW_DLV_NO_ENTRY) {\n        break;\n      }\n\n      if (current_die != parent_die) {\n        dwarf_dealloc(dwarf, current_die, DW_DLA_DIE);\n        current_die = 0;\n      }\n\n      current_die = sibling_die;\n    }\n    return NULL;\n  }\n\n  template<typename CB>\n  static bool deep_first_search_by_pc(dwarf_fileobject& fobj, Dwarf_Die parent_die, Dwarf_Addr pc,\n                                      std::vector<std::string>& ns, CB cb) {\n    Dwarf_Die current_die = 0;\n    Dwarf_Debug dwarf = fobj.dwarf_handle.get();\n    Dwarf_Error error = DW_DLE_NE;\n\n    if (dwarf_child(parent_die, &current_die, &error) != DW_DLV_OK) { return false; }\n\n    bool branch_has_pc = false;\n    bool has_namespace = false;\n    for (;;) {\n      Dwarf_Die sibling_die = 0;\n\n      Dwarf_Half tag;\n      if (dwarf_tag(current_die, &tag, &error) == DW_DLV_OK) {\n        if (tag == DW_TAG_namespace || tag == DW_TAG_class_type) {\n          char* ns_name = NULL;\n          if (dwarf_diename(current_die, &ns_name, &error) == DW_DLV_OK) {\n            if (ns_name) {\n              ns.push_back(std::string(ns_name));\n            } else {\n              ns.push_back(\"<unknown>\");\n            }\n            dwarf_dealloc(dwarf, ns_name, DW_DLA_STRING);\n          } else {\n            ns.push_back(\"<unknown>\");\n          }\n          has_namespace = true;\n        }\n      }\n\n      bool declaration = false;\n      Dwarf_Attribute attr_mem;\n      if (tag != DW_TAG_class_type\n          && dwarf_attr(current_die, DW_AT_declaration, &attr_mem, &error) == DW_DLV_OK) {\n        Dwarf_Bool flag = 0;\n        if (dwarf_formflag(attr_mem, &flag, &error) == DW_DLV_OK) { declaration = flag != 0; }\n        dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR);\n      }\n\n      if (!declaration) {\n        // let's be curious and look deeper in the tree, functions are\n        // not necessarily at the first level, but might be nested\n        // inside a namespace, structure, a function, an inlined\n        // function etc.\n        branch_has_pc = deep_first_search_by_pc(fobj, current_die, pc, ns, cb);\n      }\n\n      if (!branch_has_pc) { branch_has_pc = die_has_pc(fobj, current_die, pc); }\n\n      if (branch_has_pc) { cb(current_die, ns); }\n\n      int result = dwarf_siblingof(dwarf, current_die, &sibling_die, &error);\n      if (result == DW_DLV_ERROR) {\n        return false;\n      } else if (result == DW_DLV_NO_ENTRY) {\n        break;\n      }\n\n      if (current_die != parent_die) {\n        dwarf_dealloc(dwarf, current_die, DW_DLA_DIE);\n        current_die = 0;\n      }\n\n      if (has_namespace) {\n        has_namespace = false;\n        ns.pop_back();\n      }\n      current_die = sibling_die;\n    }\n\n    if (has_namespace) { ns.pop_back(); }\n    return branch_has_pc;\n  }\n\n  static std::string die_call_file(Dwarf_Debug dwarf, Dwarf_Die die, Dwarf_Die 
cu_die) {\n    Dwarf_Attribute attr_mem;\n    Dwarf_Error error = DW_DLE_NE;\n    Dwarf_Unsigned file_index;\n\n    std::string file;\n\n    if (dwarf_attr(die, DW_AT_call_file, &attr_mem, &error) == DW_DLV_OK) {\n      if (dwarf_formudata(attr_mem, &file_index, &error) != DW_DLV_OK) { file_index = 0; }\n      dwarf_dealloc(dwarf, attr_mem, DW_DLA_ATTR);\n\n      if (file_index == 0) { return file; }\n\n      char** srcfiles = 0;\n      Dwarf_Signed file_count = 0;\n      if (dwarf_srcfiles(cu_die, &srcfiles, &file_count, &error) == DW_DLV_OK) {\n        if (file_count > 0 && file_index <= static_cast<Dwarf_Unsigned>(file_count)) {\n          file = std::string(srcfiles[file_index - 1]);\n        }\n\n        // Deallocate all strings!\n        for (int i = 0; i < file_count; ++i) { dwarf_dealloc(dwarf, srcfiles[i], DW_DLA_STRING); }\n        dwarf_dealloc(dwarf, srcfiles, DW_DLA_LIST);\n      }\n    }\n    return file;\n  }\n\n  Dwarf_Die find_die(dwarf_fileobject& fobj, Dwarf_Addr addr) {\n    // Let's get to work! First see if we have a debug_aranges section so\n    // we can speed up the search\n\n    Dwarf_Debug dwarf = fobj.dwarf_handle.get();\n    Dwarf_Error error = DW_DLE_NE;\n    Dwarf_Arange* aranges;\n    Dwarf_Signed arange_count;\n\n    Dwarf_Die returnDie;\n    bool found = false;\n    if (dwarf_get_aranges(dwarf, &aranges, &arange_count, &error) != DW_DLV_OK) { aranges = NULL; }\n\n    if (aranges) {\n      // We have aranges. Get the one where our address is.\n      Dwarf_Arange arange;\n      if (dwarf_get_arange(aranges, arange_count, addr, &arange, &error) == DW_DLV_OK) {\n        // We found our address. Get the compilation-unit DIE offset\n        // represented by the given address range.\n        Dwarf_Off cu_die_offset;\n        if (dwarf_get_cu_die_offset(arange, &cu_die_offset, &error) == DW_DLV_OK) {\n          // Get the DIE at the offset returned by the aranges search.\n          // We set is_info to 1 to specify that the offset is from\n          // the .debug_info section (and not .debug_types)\n          int dwarf_result = dwarf_offdie_b(dwarf, cu_die_offset, 1, &returnDie, &error);\n\n          found = dwarf_result == DW_DLV_OK;\n        }\n        dwarf_dealloc(dwarf, arange, DW_DLA_ARANGE);\n      }\n    }\n\n    if (found) return returnDie;  // The caller is responsible for freeing the die\n\n    // The search for aranges failed. Try to find our address by scanning\n    // all compilation units.\n    Dwarf_Unsigned next_cu_header;\n    Dwarf_Half tag = 0;\n    returnDie = 0;\n\n    while (!found\n           && dwarf_next_cu_header_d(dwarf, 1, 0, 0, 0, 0, 0, 0, 0, 0, &next_cu_header, 0, &error)\n                  == DW_DLV_OK) {\n      if (returnDie) dwarf_dealloc(dwarf, returnDie, DW_DLA_DIE);\n\n      if (dwarf_siblingof(dwarf, 0, &returnDie, &error) == DW_DLV_OK) {\n        if ((dwarf_tag(returnDie, &tag, &error) == DW_DLV_OK) && tag == DW_TAG_compile_unit) {\n          if (die_has_pc(fobj, returnDie, addr)) { found = true; }\n        }\n      }\n    }\n\n    if (found) {\n      while (dwarf_next_cu_header_d(dwarf, 1, 0, 0, 0, 0, 0, 0, 0, 0, &next_cu_header, 0, &error)\n             == DW_DLV_OK) {\n        // Reset the cu header state. 
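(Otherwise a later walk would resume mid-way through the unit list.) 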
Libdwarf's next_cu_header API\n        // keeps its own iterator per Dwarf_Debug that can't be reset.\n        // We need to keep fetching elements until the end.\n      }\n    }\n\n    if (found) return returnDie;\n\n    // We couldn't find any compilation units with ranges or a high/low pc.\n    // Try again by looking at all DIEs in all compilation units.\n    Dwarf_Die cudie;\n    while (dwarf_next_cu_header_d(dwarf, 1, 0, 0, 0, 0, 0, 0, 0, 0, &next_cu_header, 0, &error)\n           == DW_DLV_OK) {\n      if (dwarf_siblingof(dwarf, 0, &cudie, &error) == DW_DLV_OK) {\n        Dwarf_Die die_mem = 0;\n        Dwarf_Die resultDie = find_fundie_by_pc(fobj, cudie, addr, die_mem);\n\n        if (resultDie) {\n          found = true;\n          break;\n        }\n      }\n    }\n\n    if (found) {\n      while (dwarf_next_cu_header_d(dwarf, 1, 0, 0, 0, 0, 0, 0, 0, 0, &next_cu_header, 0, &error)\n             == DW_DLV_OK) {\n        // Reset the cu header state. Libdwarf's next_cu_header API\n        // keeps its own iterator per Dwarf_Debug that can't be reset.\n        // We need to keep fetching elements until the end.\n      }\n    }\n\n    if (found) return cudie;\n\n    // We failed.\n    return NULL;\n  }\n};\n#endif  // BACKWARD_HAS_DWARF == 1\n\ntemplate<>\nclass TraceResolverImpl<system_tag::linux_tag>\n    : public TraceResolverLinuxImpl<trace_resolver_tag::current> {};\n\n#endif  // BACKWARD_SYSTEM_LINUX\n\n#ifdef BACKWARD_SYSTEM_DARWIN\n\ntemplate<typename STACKTRACE_TAG>\nclass TraceResolverDarwinImpl;\n\ntemplate<>\nclass TraceResolverDarwinImpl<trace_resolver_tag::backtrace_symbol> : public TraceResolverImplBase {\n public:\n  void load_addresses(void* const* addresses, int address_count) override {\n    if (address_count == 0) { return; }\n    _symbols.reset(backtrace_symbols(addresses, address_count));\n  }\n\n  ResolvedTrace resolve(ResolvedTrace trace) override {\n    // parse:\n    // <n>  <file>  <addr>  <mangled-name> + <offset>\n    char* filename = _symbols[trace.idx];\n\n    // skip \"<n>  \"\n    while (*filename && *filename != ' ') filename++;\n    while (*filename == ' ') filename++;\n\n    // find start of <mangled-name> from end (<file> may contain a space)\n    char* p = filename + strlen(filename) - 1;\n    // skip to start of \" + <offset>\"\n    while (p > filename && *p != ' ') p--;\n    while (p > filename && *p == ' ') p--;\n    while (p > filename && *p != ' ') p--;\n    while (p > filename && *p == ' ') p--;\n    char* funcname_end = p + 1;\n\n    // skip to start of \"<mangled-name>\"\n    while (p > filename && *p != ' ') p--;\n    char* funcname = p + 1;\n\n    // skip to start of \"  <addr>  \"\n    while (p > filename && *p == ' ') p--;\n    while (p > filename && *p != ' ') p--;\n    while (p > filename && *p == ' ') p--;\n\n    // skip \"<file>\", handling the case where it contains a space\n    char* filename_end = p + 1;\n    if (p == filename) {\n      // something went wrong, give up\n      filename_end = filename + strlen(filename);\n      funcname = filename_end;\n    }\n    trace.object_filename.assign(filename, filename_end);  // ok even if filename_end is the ending\n                                                           // \\0 (then we assign entire string)\n\n    if (*funcname) {  // if it's not end of string\n      *funcname_end = '\\0';\n\n      trace.object_function = this->demangle(funcname);\n      trace.object_function += \" \";\n      trace.object_function += (funcname_end + 1);\n      trace.source.function = trace.object_function; 
 // we cannot do better.\n    }\n    return trace;\n  }\n\n private:\n  details::handle<char**> _symbols;\n};\n\ntemplate<>\nclass TraceResolverImpl<system_tag::darwin_tag>\n    : public TraceResolverDarwinImpl<trace_resolver_tag::current> {};\n\n#endif  // BACKWARD_SYSTEM_DARWIN\n\n#ifdef BACKWARD_SYSTEM_WINDOWS\n\n// Load all symbol info\n// Based on:\n// https://stackoverflow.com/questions/6205981/windows-c-stack-trace-from-a-running-app/28276227#28276227\n\nstruct module_data {\n  std::string image_name;\n  std::string module_name;\n  void* base_address;\n  DWORD load_size;\n};\n\nclass get_mod_info {\n  HANDLE process;\n  static const int buffer_length = 4096;\n\n public:\n  get_mod_info(HANDLE h) : process(h) {}\n\n  module_data operator()(HMODULE module) {\n    module_data ret;\n    char temp[buffer_length];\n    MODULEINFO mi;\n\n    GetModuleInformation(process, module, &mi, sizeof(mi));\n    ret.base_address = mi.lpBaseOfDll;\n    ret.load_size = mi.SizeOfImage;\n\n    GetModuleFileNameExA(process, module, temp, sizeof(temp));\n    ret.image_name = temp;\n    GetModuleBaseNameA(process, module, temp, sizeof(temp));\n    ret.module_name = temp;\n    std::vector<char> img(ret.image_name.begin(), ret.image_name.end());\n    std::vector<char> mod(ret.module_name.begin(), ret.module_name.end());\n    SymLoadModule64(process, 0, &img[0], &mod[0], (DWORD64)ret.base_address, ret.load_size);\n    return ret;\n  }\n};\n\ntemplate<>\nclass TraceResolverImpl<system_tag::windows_tag> : public TraceResolverImplBase {\n public:\n  TraceResolverImpl() {\n    HANDLE process = GetCurrentProcess();\n\n    std::vector<module_data> modules;\n    DWORD cbNeeded;\n    std::vector<HMODULE> module_handles(1);\n    SymInitialize(process, NULL, false);\n    DWORD symOptions = SymGetOptions();\n    symOptions |= SYMOPT_LOAD_LINES | SYMOPT_UNDNAME;\n    SymSetOptions(symOptions);\n    EnumProcessModules(process, &module_handles[0],\n                       static_cast<DWORD>(module_handles.size() * sizeof(HMODULE)), &cbNeeded);\n    module_handles.resize(cbNeeded / sizeof(HMODULE));\n    EnumProcessModules(process, &module_handles[0],\n                       static_cast<DWORD>(module_handles.size() * sizeof(HMODULE)), &cbNeeded);\n    std::transform(module_handles.begin(), module_handles.end(), std::back_inserter(modules),\n                   get_mod_info(process));\n    void* base = modules[0].base_address;\n    IMAGE_NT_HEADERS* h = ImageNtHeader(base);\n    image_type = h->FileHeader.Machine;\n  }\n\n  static const int max_sym_len = 255;\n  struct symbol_t {\n    SYMBOL_INFO sym;\n    char buffer[max_sym_len];\n  } sym;\n\n  DWORD64 displacement;\n\n  ResolvedTrace resolve(ResolvedTrace t) override {\n    HANDLE process = GetCurrentProcess();\n\n    char name[256];\n\n    memset(&sym, 0, sizeof(sym));\n    sym.sym.SizeOfStruct = sizeof(SYMBOL_INFO);\n    sym.sym.MaxNameLen = max_sym_len;\n\n    if (!SymFromAddr(process, (ULONG64)t.addr, &displacement, &sym.sym)) {\n      // TODO:  error handling everywhere\n      char* lpMsgBuf;\n      DWORD dw = GetLastError();\n\n      if (FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM\n                             | FORMAT_MESSAGE_IGNORE_INSERTS,\n                         NULL, dw, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (char*)&lpMsgBuf, 0,\n                         NULL)) {\n        std::fprintf(stderr, \"%s\\n\", lpMsgBuf);\n        LocalFree(lpMsgBuf);\n      }\n\n      // abort();\n    }\n    UnDecorateSymbolName(sym.sym.Name, 
(PSTR)name, 256, UNDNAME_COMPLETE);\n\n    DWORD offset = 0;\n    IMAGEHLP_LINE line;\n    if (SymGetLineFromAddr(process, (ULONG64)t.addr, &offset, &line)) {\n      t.object_filename = line.FileName;\n      t.source.filename = line.FileName;\n      t.source.line = line.LineNumber;\n      t.source.col = offset;\n    }\n\n    t.source.function = name;\n    t.object_filename = \"\";\n    t.object_function = name;\n\n    return t;\n  }\n\n  DWORD machine_type() const { return image_type; }\n\n private:\n  DWORD image_type;\n};\n\n#endif\n\nclass TraceResolver : public TraceResolverImpl<system_tag::current_tag> {};\n\n/*************** CODE SNIPPET ***************/\n\nclass SourceFile {\n public:\n  typedef std::vector<std::pair<unsigned, std::string>> lines_t;\n\n  SourceFile() {}\n  SourceFile(const std::string& path) {\n    // 1. If BACKWARD_CXX_SOURCE_PREFIXES is set then assume it contains\n    //    a colon-separated list of path prefixes.  Try prepending each\n    //    to the given path until a valid file is found.\n    const std::vector<std::string>& prefixes = get_paths_from_env_variable();\n    for (size_t i = 0; i < prefixes.size(); ++i) {\n      // Double slashes (//) should not be a problem.\n      std::string new_path = prefixes[i] + '/' + path;\n      _file.reset(new std::ifstream(new_path.c_str()));\n      if (is_open()) break;\n    }\n    // 2. If no valid file is found, then fall back to opening the path as-is.\n    if (!_file || !is_open()) { _file.reset(new std::ifstream(path.c_str())); }\n  }\n  bool is_open() const { return _file->is_open(); }\n\n  lines_t& get_lines(unsigned line_start, unsigned line_count, lines_t& lines) {\n    using namespace std;\n    // This function makes use of the dumbest algo ever:\n    //\t1) seek(0)\n    //\t2) read lines one by one and discard until line_start\n    //\t3) read lines one by one until line_start + line_count\n    //\n    // If you are getting snippets many times from the same file, it is\n    // somewhat of a waste of CPU; feel free to benchmark and propose a\n    // better solution ;)\n\n    _file->clear();\n    _file->seekg(0);\n    string line;\n    unsigned line_idx;\n\n    for (line_idx = 1; line_idx < line_start; ++line_idx) {\n      std::getline(*_file, line);\n      if (!*_file) { return lines; }\n    }\n\n    // think of it like a lambda in C++98 ;)\n    // but look, I will reuse it two times!\n    // What a good boy am I.\n    struct isspace {\n      bool operator()(char c) { return std::isspace(c); }\n    };\n\n    bool started = false;\n    for (; line_idx < line_start + line_count; ++line_idx) {\n      getline(*_file, line);\n      if (!*_file) { return lines; }\n      if (!started) {\n        if (std::find_if(line.begin(), line.end(), not_isspace()) == line.end()) continue;\n        started = true;\n      }\n      lines.push_back(make_pair(line_idx, line));\n    }\n\n    lines.erase(std::find_if(lines.rbegin(), lines.rend(), not_isempty()).base(), lines.end());\n    return lines;\n  }\n\n  lines_t get_lines(unsigned line_start, unsigned line_count) {\n    lines_t lines;\n    return get_lines(line_start, line_count, lines);\n  }\n\n  // there is no find_if_not in C++98; let's do something crappy to\n  // work around it.\n  struct not_isspace {\n    bool operator()(char c) { return !std::isspace(c); }\n  };\n  // and define this one here because C++98 is not happy with locally defined\n  // structs passed to template functions, fuuuu.\n  struct not_isempty {\n    bool operator()(const lines_t::value_type& p) {\n      return 
!(std::find_if(p.second.begin(), p.second.end(), not_isspace()) == p.second.end());\n    }\n  };\n\n  void swap(SourceFile& b) { _file.swap(b._file); }\n\n#ifdef BACKWARD_ATLEAST_CXX11\n  SourceFile(SourceFile&& from) : _file(nullptr) { swap(from); }\n  SourceFile& operator=(SourceFile&& from) {\n    swap(from);\n    return *this;\n  }\n#else\n  explicit SourceFile(const SourceFile& from) {\n    // some sort of poor man's move semantic.\n    swap(const_cast<SourceFile&>(from));\n  }\n  SourceFile& operator=(const SourceFile& from) {\n    // some sort of poor man's move semantic.\n    swap(const_cast<SourceFile&>(from));\n    return *this;\n  }\n#endif\n\n  // Allow adding to the paths gotten from BACKWARD_CXX_SOURCE_PREFIXES after loading the\n  // library; this can be useful when the library is loaded before the locations are known.\n  // Warning: Because this edits the static paths variable, it is *not* intrinsically thread safe\n  static void add_paths_to_env_variable_impl(const std::string& to_add) {\n    get_mutable_paths_from_env_variable().push_back(to_add);\n  }\n\n private:\n  details::handle<std::ifstream*, details::default_delete<std::ifstream*>> _file;\n\n  static std::vector<std::string> get_paths_from_env_variable_impl() {\n    std::vector<std::string> paths;\n    const char* prefixes_str = std::getenv(\"BACKWARD_CXX_SOURCE_PREFIXES\");\n    if (prefixes_str && prefixes_str[0]) { paths = details::split_source_prefixes(prefixes_str); }\n    return paths;\n  }\n\n  static std::vector<std::string>& get_mutable_paths_from_env_variable() {\n    static volatile std::vector<std::string> paths = get_paths_from_env_variable_impl();\n    return const_cast<std::vector<std::string>&>(paths);\n  }\n\n  static const std::vector<std::string>& get_paths_from_env_variable() {\n    return get_mutable_paths_from_env_variable();\n  }\n\n#ifdef BACKWARD_ATLEAST_CXX11\n  SourceFile(const SourceFile&) = delete;\n  SourceFile& operator=(const SourceFile&) = delete;\n#endif\n};\n\nclass SnippetFactory {\n public:\n  typedef SourceFile::lines_t lines_t;\n\n  lines_t get_snippet(const std::string& filename, unsigned line_start, unsigned context_size) {\n    SourceFile& src_file = get_src_file(filename);\n    unsigned start = line_start - context_size / 2;\n    return src_file.get_lines(start, context_size);\n  }\n\n  lines_t get_combined_snippet(const std::string& filename_a, unsigned line_a,\n                               const std::string& filename_b, unsigned line_b,\n                               unsigned context_size) {\n    SourceFile& src_file_a = get_src_file(filename_a);\n    SourceFile& src_file_b = get_src_file(filename_b);\n\n    lines_t lines = src_file_a.get_lines(line_a - context_size / 4, context_size / 2);\n    src_file_b.get_lines(line_b - context_size / 4, context_size / 2, lines);\n    return lines;\n  }\n\n  lines_t get_coalesced_snippet(const std::string& filename, unsigned line_a, unsigned line_b,\n                                unsigned context_size) {\n    SourceFile& src_file = get_src_file(filename);\n\n    using std::max;\n    using std::min;\n    unsigned a = min(line_a, line_b);\n    unsigned b = max(line_a, line_b);\n\n    if ((b - a) < (context_size / 3)) {\n      return src_file.get_lines((a + b - context_size + 1) / 2, context_size);\n    }\n\n    lines_t lines = src_file.get_lines(a - context_size / 4, context_size / 2);\n    src_file.get_lines(b - context_size / 4, context_size / 2, lines);\n    return lines;\n  }\n\n private:\n  typedef details::hashtable<std::string, 
SourceFile>::type src_files_t;\n  src_files_t _src_files;\n\n  SourceFile& get_src_file(const std::string& filename) {\n    src_files_t::iterator it = _src_files.find(filename);\n    if (it != _src_files.end()) { return it->second; }\n    SourceFile& new_src_file = _src_files[filename];\n    new_src_file = SourceFile(filename);\n    return new_src_file;\n  }\n};\n\n/*************** PRINTER ***************/\n\nnamespace ColorMode {\nenum type { automatic, never, always };\n}\n\nclass cfile_streambuf : public std::streambuf {\n public:\n  cfile_streambuf(FILE* _sink) : sink(_sink) {}\n  int_type underflow() override { return traits_type::eof(); }\n  int_type overflow(int_type ch) override {\n    if (traits_type::not_eof(ch) && fputc(ch, sink) != EOF) { return ch; }\n    return traits_type::eof();\n  }\n\n  std::streamsize xsputn(const char_type* s, std::streamsize count) override {\n    return static_cast<std::streamsize>(fwrite(s, sizeof *s, static_cast<size_t>(count), sink));\n  }\n\n#ifdef BACKWARD_ATLEAST_CXX11\n public:\n  cfile_streambuf(const cfile_streambuf&) = delete;\n  cfile_streambuf& operator=(const cfile_streambuf&) = delete;\n#else\n private:\n  cfile_streambuf(const cfile_streambuf&);\n  cfile_streambuf& operator=(const cfile_streambuf&);\n#endif\n\n private:\n  FILE* sink;\n  std::vector<char> buffer;\n};\n\n#ifdef BACKWARD_SYSTEM_LINUX\n\nnamespace Color {\nenum type { yellow = 33, purple = 35, reset = 39 };\n}  // namespace Color\n\nclass Colorize {\n public:\n  Colorize(std::ostream& os) : _os(os), _reset(false), _enabled(false) {}\n\n  void activate(ColorMode::type mode) { _enabled = mode == ColorMode::always; }\n\n  void activate(ColorMode::type mode, FILE* fp) { activate(mode, fileno(fp)); }\n\n  void set_color(Color::type ccode) {\n    if (!_enabled) return;\n\n    // I assume that the terminal can handle basic colors. Seriously I\n    // don't want to deal with all the termcap shit.\n    _os << \"\\033[\" << static_cast<int>(ccode) << \"m\";\n    _reset = (ccode != Color::reset);\n  }\n\n  ~Colorize() {\n    if (_reset) { set_color(Color::reset); }\n  }\n\n private:\n  void activate(ColorMode::type mode, int fd) {\n    activate(mode == ColorMode::automatic && isatty(fd) ? 
ColorMode::always : mode);\n  }\n\n  std::ostream& _os;\n  bool _reset;\n  bool _enabled;\n};\n\n#else  // ndef BACKWARD_SYSTEM_LINUX\n\nnamespace Color {\nenum type { yellow = 0, purple = 0, reset = 0 };\n}  // namespace Color\n\nclass Colorize {\n public:\n  Colorize(std::ostream&) {}\n  void activate(ColorMode::type) {}\n  void activate(ColorMode::type, FILE*) {}\n  void set_color(Color::type) {}\n};\n\n#endif  // BACKWARD_SYSTEM_LINUX\n\nclass Printer {\n public:\n  bool snippet;\n  ColorMode::type color_mode;\n  bool address;\n  bool object;\n  int inliner_context_size;\n  int trace_context_size;\n  bool reverse;\n\n  Printer()\n      : snippet(true),\n        color_mode(ColorMode::automatic),\n        address(false),\n        object(false),\n        // Modify: Show one line by default\n        // inliner_context_size(5),\n        // trace_context_size(7),\n        inliner_context_size(1),\n        trace_context_size(1),\n        reverse(true) {}\n\n  template<typename ST>\n  FILE* print(ST& st, FILE* fp = stderr) {\n    cfile_streambuf obuf(fp);\n    std::ostream os(&obuf);\n    Colorize colorize(os);\n    colorize.activate(color_mode, fp);\n    print_stacktrace(st, os, colorize);\n    return fp;\n  }\n\n  template<typename ST>\n  std::ostream& print(ST& st, std::ostream& os) {\n    Colorize colorize(os);\n    colorize.activate(color_mode);\n    print_stacktrace(st, os, colorize);\n    return os;\n  }\n\n  template<typename IT>\n  FILE* print(IT begin, IT end, FILE* fp = stderr, size_t thread_id = 0) {\n    cfile_streambuf obuf(fp);\n    std::ostream os(&obuf);\n    Colorize colorize(os);\n    colorize.activate(color_mode, fp);\n    print_stacktrace(begin, end, os, thread_id, colorize);\n    return fp;\n  }\n\n  template<typename IT>\n  std::ostream& print(IT begin, IT end, std::ostream& os, size_t thread_id = 0) {\n    Colorize colorize(os);\n    colorize.activate(color_mode);\n    print_stacktrace(begin, end, os, thread_id, colorize);\n    return os;\n  }\n\n  // Modify: skip stacks in python object file\n  static inline bool is_oneflow_file(const std::string& filename) {\n    return std::string(std::filesystem::path(filename).filename()).find(\"oneflow\")\n           != std::string::npos;\n  }\n\n  TraceResolver const& resolver() const { return _resolver; }\n\n private:\n  TraceResolver _resolver;\n  SnippetFactory _snippets;\n\n  template<typename ST>\n  void print_stacktrace(ST& st, std::ostream& os, Colorize& colorize) {\n    print_header(os, st.thread_id());\n    _resolver.load_stacktrace(st);\n    if (reverse) {\n      for (size_t trace_idx = st.size(); trace_idx > 0; --trace_idx) {\n        print_trace(os, _resolver.resolve(st[trace_idx - 1]), colorize);\n      }\n    } else {\n      for (size_t trace_idx = 0; trace_idx < st.size(); ++trace_idx) {\n        print_trace(os, _resolver.resolve(st[trace_idx]), colorize);\n      }\n    }\n    // Modify: Add a new line before Python stack\n    os << std::endl;\n  }\n\n  template<typename IT>\n  void print_stacktrace(IT begin, IT end, std::ostream& os, size_t thread_id, Colorize& colorize) {\n    print_header(os, thread_id);\n    for (; begin != end; ++begin) { print_trace(os, *begin, colorize); }\n  }\n\n  void print_header(std::ostream& os, size_t thread_id) {\n    os << \"Stack trace (most recent call last)\";\n    if (thread_id) { os << \" in thread \" << thread_id; }\n    os << \":\\n\";\n  }\n\n  void print_trace(std::ostream& os, const ResolvedTrace& trace, Colorize& colorize) {\n    // Modify: skip stacks in python object 
file\n    if (!is_oneflow_file(trace.object_filename)) { return; }\n    // Modify: symbol '#', trace idx and indent are not necessary\n    // os << \"#\" << std::left << std::setw(2) << trace.idx << std::right;\n    // bool already_indented = true;\n\n    if (!trace.source.filename.size() || object) {\n      os << \"   Object \\\"\" << trace.object_filename << \"\\\", at \" << trace.addr << \", in \"\n         << trace.object_function << \"\\n\";\n      // Modify: Extra indent is not necessary\n      // already_indented = false;\n    }\n\n    for (size_t inliner_idx = trace.inliners.size(); inliner_idx > 0; --inliner_idx) {\n      // Modify: Extra indent is not necessary\n      // if (!already_indented) { os << \"   \"; }\n      const ResolvedTrace::SourceLoc& inliner_loc = trace.inliners[inliner_idx - 1];\n      print_source_loc(os, \" | \", inliner_loc);\n      if (snippet) {\n        // Modify: Symbol '|' is not necessary\n        // print_snippet(os, \"    | \", inliner_loc, colorize, Color::purple, inliner_context_size);\n        print_snippet(os, \"   \", inliner_loc, colorize, Color::purple, inliner_context_size);\n      }\n      // Modify: Extra indent is not necessary\n      // already_indented = false;\n    }\n\n    if (trace.source.filename.size()) {\n      // Modify: Extra indent is not necessary\n      // if (!already_indented) { os << \"   \"; }\n      // Modify: Adjust the indent\n      // print_source_loc(os, \"   \", trace.source, trace.addr);\n      print_source_loc(os, \"  \", trace.source, trace.addr);\n      if (snippet) {\n        // Modify: Adjust the indent\n        // print_snippet(os, \"      \", trace.source, colorize, Color::yellow, trace_context_size);\n        print_snippet(os, \"   \", trace.source, colorize, Color::yellow, trace_context_size);\n      }\n    }\n  }\n\n  void print_snippet(std::ostream& os, const char* indent,\n                     const ResolvedTrace::SourceLoc& source_loc, Colorize& colorize,\n                     Color::type color_code, int context_size) {\n    using namespace std;\n    typedef SnippetFactory::lines_t lines_t;\n\n    lines_t lines = _snippets.get_snippet(source_loc.filename, source_loc.line,\n                                          static_cast<unsigned>(context_size));\n\n    for (lines_t::const_iterator it = lines.begin(); it != lines.end(); ++it) {\n      if (it->first == source_loc.line) {\n        colorize.set_color(color_code);\n        // Modify: Remove symbol '>' if there is only one line to show\n        //   os << indent << \">\";\n        // } else {\n        //   os << indent << \" \";\n        // }\n        // os << std::setw(4) << it->first << \": \" << it->second << \"\\n\";\n        if (lines.size() > 1) {\n          os << indent << \">\";\n        } else {\n          os << indent << \" \";\n        }\n      } else {\n        os << indent << \" \";\n      }\n      // Guard: on an all-whitespace line find_first_not_of returns npos and\n      // substr(npos) would throw, so print an empty string instead.\n      const auto pos = it->second.find_first_not_of(\" \\t\");\n      const std::string trimmed =\n          pos == std::string::npos ? std::string() : it->second.substr(pos);\n      os << std::setw(4) << trimmed << \"\\n\";\n      if (it->first == source_loc.line) { colorize.set_color(Color::reset); }\n    }\n  }\n\n  void print_source_loc(std::ostream& os, const char* indent,\n                        const ResolvedTrace::SourceLoc& source_loc, void* addr = nullptr) {\n    // Modify: Remove indent and replace 'Source' with 'File'\n    // os << indent << \"Source \\\"\" << source_loc.filename << \"\\\", line \" << source_loc.line << \", in\n    // \"\n    os << \"  File \\\"\" << source_loc.filename << \"\\\", line \" << 
source_loc.line << \", in \"\n       << source_loc.function;\n\n    if (address && addr != nullptr) { os << \" [\" << addr << \"]\"; }\n    os << \"\\n\";\n  }\n};\n\n/*************** SIGNALS HANDLING ***************/\n\n#if defined(BACKWARD_SYSTEM_LINUX) || defined(BACKWARD_SYSTEM_DARWIN)\n\nclass SignalHandling {\n public:\n  static std::vector<int> make_default_signals() {\n    const int posix_signals[] = {\n      // Signals for which the default action is \"Core\".\n      SIGABRT,  // Abort signal from abort(3)\n      SIGBUS,   // Bus error (bad memory access)\n      SIGFPE,   // Floating point exception\n      SIGILL,   // Illegal Instruction\n      SIGIOT,   // IOT trap. A synonym for SIGABRT\n      SIGQUIT,  // Quit from keyboard\n      SIGSEGV,  // Invalid memory reference\n      SIGSYS,   // Bad argument to routine (SVr4)\n      SIGTRAP,  // Trace/breakpoint trap\n      SIGXCPU,  // CPU time limit exceeded (4.2BSD)\n      SIGXFSZ,  // File size limit exceeded (4.2BSD)\n#if defined(BACKWARD_SYSTEM_DARWIN)\n      SIGEMT,  // emulation instruction executed\n#endif\n    };\n    return std::vector<int>(posix_signals,\n                            posix_signals + sizeof posix_signals / sizeof posix_signals[0]);\n  }\n\n  SignalHandling(const std::vector<int>& posix_signals = make_default_signals()) : _loaded(false) {\n    bool success = true;\n\n    const size_t stack_size = 1024 * 1024 * 8;\n    _stack_content.reset(static_cast<char*>(malloc(stack_size)));\n    if (_stack_content) {\n      stack_t ss;\n      ss.ss_sp = _stack_content.get();\n      ss.ss_size = stack_size;\n      ss.ss_flags = 0;\n      if (sigaltstack(&ss, nullptr) < 0) { success = false; }\n    } else {\n      success = false;\n    }\n\n    for (size_t i = 0; i < posix_signals.size(); ++i) {\n      struct sigaction action;\n      memset(&action, 0, sizeof action);\n      action.sa_flags = static_cast<int>(SA_SIGINFO | SA_ONSTACK | SA_NODEFER | SA_RESETHAND);\n      sigfillset(&action.sa_mask);\n      sigdelset(&action.sa_mask, posix_signals[i]);\n#if defined(__clang__)\n#pragma clang diagnostic push\n#pragma clang diagnostic ignored \"-Wdisabled-macro-expansion\"\n#endif\n      action.sa_sigaction = &sig_handler;\n#if defined(__clang__)\n#pragma clang diagnostic pop\n#endif\n\n      int r = sigaction(posix_signals[i], &action, nullptr);\n      if (r < 0) success = false;\n    }\n\n    _loaded = success;\n  }\n\n  bool loaded() const { return _loaded; }\n\n  static void handleSignal(int, siginfo_t* info, void* _ctx) {\n    ucontext_t* uctx = static_cast<ucontext_t*>(_ctx);\n\n    StackTrace st;\n    void* error_addr = nullptr;\n#ifdef REG_RIP  // x86_64\n    error_addr = reinterpret_cast<void*>(uctx->uc_mcontext.gregs[REG_RIP]);\n#elif defined(REG_EIP)  // x86_32\n    error_addr = reinterpret_cast<void*>(uctx->uc_mcontext.gregs[REG_EIP]);\n#elif defined(__arm__)\n    error_addr = reinterpret_cast<void*>(uctx->uc_mcontext.arm_pc);\n#elif defined(__aarch64__)\n#if defined(__APPLE__)\n    error_addr = reinterpret_cast<void*>(uctx->uc_mcontext->__ss.__pc);\n#else\n    error_addr = reinterpret_cast<void*>(uctx->uc_mcontext.pc);\n#endif\n#elif defined(__mips__)\n    error_addr =\n        reinterpret_cast<void*>(reinterpret_cast<struct sigcontext*>(&uctx->uc_mcontext)->sc_pc);\n#elif defined(__ppc__) || defined(__powerpc) || defined(__powerpc__) || defined(__POWERPC__)\n    error_addr = reinterpret_cast<void*>(uctx->uc_mcontext.regs->nip);\n#elif defined(__riscv)\n    error_addr = 
reinterpret_cast<void*>(uctx->uc_mcontext.__gregs[REG_PC]);\n#elif defined(__s390x__)\n    error_addr = reinterpret_cast<void*>(uctx->uc_mcontext.psw.addr);\n#elif defined(__APPLE__) && defined(__x86_64__)\n    error_addr = reinterpret_cast<void*>(uctx->uc_mcontext->__ss.__rip);\n#elif defined(__APPLE__)\n    error_addr = reinterpret_cast<void*>(uctx->uc_mcontext->__ss.__eip);\n#else\n#warning \":/ sorry, ain't know no nothing none not of your architecture!\"\n#endif\n    if (error_addr) {\n      st.load_from(error_addr, 32, reinterpret_cast<void*>(uctx), info->si_addr);\n    } else {\n      st.load_here(32, reinterpret_cast<void*>(uctx), info->si_addr);\n    }\n\n    Printer printer;\n    // Modify: Hide the address in stack when seg fault\n    // printer.address = true;\n    printer.print(st, stderr);\n\n#if (defined(_XOPEN_SOURCE) && _XOPEN_SOURCE >= 700) \\\n    || (defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 200809L)\n    psiginfo(info, nullptr);\n#else\n    (void)info;\n#endif\n  }\n\n private:\n  details::handle<char*> _stack_content;\n  bool _loaded;\n\n#ifdef __GNUC__\n  __attribute__((noreturn))\n#endif\n  static void\n  sig_handler(int signo, siginfo_t* info, void* _ctx) {\n    handleSignal(signo, info, _ctx);\n\n    // try to forward the signal.\n    raise(info->si_signo);\n\n    // terminate the process immediately.\n    puts(\"watf? exit\");\n    _exit(EXIT_FAILURE);\n  }\n};\n\n#endif  // BACKWARD_SYSTEM_LINUX || BACKWARD_SYSTEM_DARWIN\n\n#ifdef BACKWARD_SYSTEM_WINDOWS\n\nclass SignalHandling {\n public:\n  SignalHandling(const std::vector<int>& = std::vector<int>())\n      : reporter_thread_([]() {\n          /* We handle crashes in a utility thread:\n            backward structures and some Windows functions called here\n            need stack space, which we do not have when we encounter a\n            stack overflow.\n            To support reporting stack traces during a stack overflow,\n            we create a utility thread at startup, which waits until a\n            crash happens or the program exits normally. 
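The crashing\n            thread publishes its CONTEXT and blocks until the report has been\n            written. 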
*/\n\n          {\n            std::unique_lock<std::mutex> lk(mtx());\n            cv().wait(lk, [] { return crashed() != crash_status::running; });\n          }\n          if (crashed() == crash_status::crashed) { handle_stacktrace(skip_recs()); }\n          {\n            std::unique_lock<std::mutex> lk(mtx());\n            crashed() = crash_status::ending;\n          }\n          cv().notify_one();\n        }) {\n    SetUnhandledExceptionFilter(crash_handler);\n\n    signal(SIGABRT, signal_handler);\n    _set_abort_behavior(0, _WRITE_ABORT_MSG | _CALL_REPORTFAULT);\n\n    std::set_terminate(&terminator);\n#ifndef BACKWARD_ATLEAST_CXX17\n    std::set_unexpected(&terminator);\n#endif\n    _set_purecall_handler(&terminator);\n    _set_invalid_parameter_handler(&invalid_parameter_handler);\n  }\n  bool loaded() const { return true; }\n\n  ~SignalHandling() {\n    {\n      std::unique_lock<std::mutex> lk(mtx());\n      crashed() = crash_status::normal_exit;\n    }\n\n    cv().notify_one();\n\n    reporter_thread_.join();\n  }\n\n private:\n  static CONTEXT* ctx() {\n    static CONTEXT data;\n    return &data;\n  }\n\n  enum class crash_status { running, crashed, normal_exit, ending };\n\n  static crash_status& crashed() {\n    static crash_status data;\n    return data;\n  }\n\n  static std::mutex& mtx() {\n    static std::mutex data;\n    return data;\n  }\n\n  static std::condition_variable& cv() {\n    static std::condition_variable data;\n    return data;\n  }\n\n  static HANDLE& thread_handle() {\n    static HANDLE handle;\n    return handle;\n  }\n\n  std::thread reporter_thread_;\n\n  // TODO: how not to hardcode these?\n  static const constexpr int signal_skip_recs =\n#ifdef __clang__\n      // With clang, RtlCaptureContext also captures the stack frame of the\n      // current function. Below that, there are 3 internal Windows functions\n      4\n#else\n      // With MSVC cl, RtlCaptureContext misses the stack frame of the current\n      // function. The first entries during StackWalk are the 3 internal Windows\n      // functions\n      3\n#endif\n      ;\n\n  static int& skip_recs() {\n    static int data;\n    return data;\n  }\n\n  static inline void terminator() {\n    crash_handler(signal_skip_recs);\n    abort();\n  }\n\n  static inline void signal_handler(int) {\n    crash_handler(signal_skip_recs);\n    abort();\n  }\n\n  static inline void __cdecl invalid_parameter_handler(const wchar_t*, const wchar_t*,\n                                                       const wchar_t*, unsigned int, uintptr_t) {\n    crash_handler(signal_skip_recs);\n    abort();\n  }\n\n  NOINLINE static LONG WINAPI crash_handler(EXCEPTION_POINTERS* info) {\n    // The exception info supplies a trace from exactly where the issue was,\n    // no need to skip records\n    crash_handler(0, info->ContextRecord);\n    return EXCEPTION_CONTINUE_SEARCH;\n  }\n\n  NOINLINE static void crash_handler(int skip, CONTEXT* ct = nullptr) {\n    if (ct == nullptr) {\n      RtlCaptureContext(ctx());\n    } else {\n      memcpy(ctx(), ct, sizeof(CONTEXT));\n    }\n    DuplicateHandle(GetCurrentProcess(), GetCurrentThread(), GetCurrentProcess(), &thread_handle(),\n                    0, FALSE, DUPLICATE_SAME_ACCESS);\n\n    skip_recs() = skip;\n\n    {\n      std::unique_lock<std::mutex> lk(mtx());\n      crashed() = crash_status::crashed;\n    }\n\n    cv().notify_one();\n\n    {\n      std::unique_lock<std::mutex> lk(mtx());\n      cv().wait(lk, [] { return crashed() != crash_status::crashed; });\n    }\n  }\n\n  
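// Runs on the reporter thread once crash_handler() has published the\n  // crashed thread's CONTEXT (in ctx()) and a duplicated thread handle.\n  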
static void handle_stacktrace(int skip_frames = 0) {\n    // printer creates the TraceResolver, which can supply us a machine type\n    // for stack walking. Without this, StackTrace can only guess using some\n    // macros.\n    // StackTrace also requires that the PDBs are already loaded, which is done\n    // in the constructor of TraceResolver\n    Printer printer;\n\n    StackTrace st;\n    st.set_machine_type(printer.resolver().machine_type());\n    st.set_thread_handle(thread_handle());\n    st.load_here(32 + skip_frames, ctx());\n    st.skip_n_firsts(skip_frames);\n\n    printer.address = true;\n    printer.print(st, std::cerr);\n  }\n};\n\n#endif  // BACKWARD_SYSTEM_WINDOWS\n\n#ifdef BACKWARD_SYSTEM_UNKNOWN\n\nclass SignalHandling {\n public:\n  SignalHandling(const std::vector<int>& = std::vector<int>()) {}\n  bool init() { return false; }\n  bool loaded() { return false; }\n};\n\n#endif  // BACKWARD_SYSTEM_UNKNOWN\n\n}  // namespace backward\n\n#endif /* H_GUARD */\n"
  },
  {
    "path": "oneflow/ir/.gitignore",
    "content": "/build*\nlit.site.cfg.py\n"
  },
  {
    "path": "oneflow/ir/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.13.4)\ninclude(FetchContent)\n# prevent LLVM_DEFINITIONS has a TRUE in it\nunset(result CACHE)\nset(CMAKE_INSTALL_MESSAGE LAZY)\nif(POLICY CMP0068)\n  cmake_policy(SET CMP0068 NEW)\n  set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON)\nendif()\n\nif(POLICY CMP0075)\n  cmake_policy(SET CMP0075 NEW)\nendif()\n\nif(POLICY CMP0077)\n  cmake_policy(SET CMP0077 NEW)\nendif()\n\nif(POLICY CMP0116)\n  cmake_policy(SET CMP0116 OLD)\nendif()\n\nproject(oneflow-dialect LANGUAGES CXX C)\n\n# https://github.com/llvm/llvm-project/issues/55010\nset(LLVM_ABI_BREAKING_CHECKS \"FORCE_OFF\" CACHE STRING \"\")\n\nif(LLVM_PROVIDER STREQUAL \"in-tree\")\n  include(llvm-in-tree.cmake)\nelseif(LLVM_PROVIDER STREQUAL \"install\")\n  include(install-llvm.cmake)\nelse()\n  message(FATAL_ERROR \"LLVM_PROVIDER should be in-tree or install, but got: ${LLVM_PROVIDER}\")\nendif()\n\nset_property(GLOBAL PROPERTY LLVM_INSTALL_DIR ${LLVM_INSTALL_DIR})\nset(MLIR_TABLEGEN_EXE mlir-tblgen)\nset(MLIR_PDLL_TABLEGEN_EXE mlir-pdll)\n\ninclude_directories(${LLVM_INCLUDE_DIRS})\ninclude_directories(${MLIR_INCLUDE_DIRS})\nset(LLVM_INCLUDE_DIRS ${LLVM_INCLUDE_DIRS} PARENT_SCOPE)\nset(MLIR_INCLUDE_DIRS ${MLIR_INCLUDE_DIRS} PARENT_SCOPE)\nset(ONEFLOW_MLIR_SOURCE_INCLUDE_DIRS ${PROJECT_SOURCE_DIR}/include PARENT_SCOPE)\nset(ONEFLOW_MLIR_BINARY_INCLUDE_DIRS ${PROJECT_BINARY_DIR}/include PARENT_SCOPE)\n\ninclude_directories(${PROJECT_SOURCE_DIR}/include)\ninclude_directories(${PROJECT_BINARY_DIR}/include)\nlink_directories(${LLVM_BUILD_LIBRARY_DIR})\nadd_definitions(${LLVM_DEFINITIONS})\n\nif(LLVM_PROVIDER STREQUAL \"in-tree\")\n  add_subdirectory(${CMAKE_SOURCE_DIR}/tools/oneflow-tblgen ${PROJECT_BINARY_DIR}/oneflow-tblgen)\nendif()\n\nfunction(update_rpath)\n  set_property(TARGET ${ARGV0} APPEND PROPERTY BUILD_RPATH \"${LLVM_LIBRARY_DIR}\")\n  set_property(TARGET ${ARGV0} APPEND PROPERTY BUILD_RPATH \"${ONEFLOW_BUILD_ROOT_DIR}\")\n  set_property(TARGET ${ARGV0} APPEND PROPERTY INSTALL_RPATH \"${LLVM_LIBRARY_DIR}\")\n  set_property(TARGET ${ARGV0} APPEND PROPERTY INSTALL_RPATH \"${ONEFLOW_BUILD_ROOT_DIR}\")\nendfunction(update_rpath)\n\nfunction(oneflow_add_mlir_library)\n  add_mlir_library(${ARGV})\n  set_compile_options_to_oneflow_target(${ARGV0})\n  update_rpath(${ARGV0})\nendfunction()\n\nfunction(oneflow_add_mlir_dialect_library)\n  add_mlir_dialect_library(${ARGV})\n  set_compile_options_to_oneflow_target(${ARGV0})\n  update_rpath(${ARGV0})\nendfunction()\n\nfunction(oneflow_add_llvm_tool)\n  add_llvm_tool(${ARGV})\n  llvm_update_compile_flags(oneflow-runner)\n  set_compile_options_to_oneflow_target(${ARGV0})\n  update_rpath(${ARGV0})\nendfunction()\n\nfind_package(Threads REQUIRED)\nset(LLVM_PTHREAD_LIB ${CMAKE_THREAD_LIBS_INIT})\n\nset(LLVM_RUNTIME_OUTPUT_INTDIR ${PROJECT_BINARY_DIR}/bin)\nset(LLVM_LIBRARY_OUTPUT_INTDIR ${PROJECT_BINARY_DIR}/lib)\nif(WITH_MLIR)\n  add_subdirectory(include)\n  add_subdirectory(lib)\n  add_subdirectory(test)\n  add_subdirectory(oneflow-translate)\n  add_subdirectory(oneflow-runtime)\n  add_subdirectory(oneflow-extension)\n  add_subdirectory(oneflow-opt)\n  add_subdirectory(oneflow-runner)\n  add_subdirectory(oneflow-lite)\nendif(WITH_MLIR)\n\nif(BUILD_PYTHON)\n  foreach(llvm_include_dir ${LLVM_INCLUDE_DIRS})\n    if(llvm_include_dir MATCHES \"/include$\")\n      list(APPEND LLVM_INSTALL_INCLUDE_DIRS \"${llvm_include_dir}//\")\n    else()\n      list(APPEND LLVM_INSTALL_INCLUDE_DIRS \"${llvm_include_dir}\")\n    endif()\n  endforeach()\n  install(\n    
DIRECTORY ${LLVM_INSTALL_INCLUDE_DIRS}\n    DESTINATION ${ONEFLOW_INCLUDE_DIR}\n    COMPONENT oneflow_py_include\n    EXCLUDE_FROM_ALL FILES_MATCHING\n    PATTERN llvm/ADT/ArrayRef.h\n    PATTERN llvm/ADT/Hashing.h\n    PATTERN llvm/ADT/iterator.h\n    PATTERN llvm/ADT/None.h\n    PATTERN llvm/ADT/SmallVector.h\n    PATTERN llvm/ADT/STLExtras.h\n    PATTERN llvm/ADT/STLFunctionalExtras.h\n    PATTERN llvm/ADT/DenseMapInfo.h\n    PATTERN llvm/ADT/identity.h\n    PATTERN llvm/ADT/iterator_range.h\n    PATTERN llvm/ADT/Optional.h\n    PATTERN llvm/ADT/STLArrayExtras.h\n    PATTERN llvm/ADT/STLForwardCompat.h\n    PATTERN llvm/ADT/StringRef.h\n    PATTERN llvm/ADT/bit.h\n    PATTERN llvm/Config/abi-breaking.h\n    PATTERN llvm/Config/llvm-config.h\n    PATTERN llvm/Support/Compiler.h\n    PATTERN llvm/Support/DataTypes.h\n    PATTERN llvm/Support/ErrorHandling.h\n    PATTERN llvm/Support/SwapByteOrder.h\n    PATTERN llvm/Support/type_traits.h\n    PATTERN llvm-c/DataTypes.h)\nendif()\n"
  },
  {
    "path": "oneflow/ir/README.md",
    "content": "# OneFlow IR\n\nOneFlow IR, a MLIR dialect\n\n## Code style\n\nInevitably, developers maintaining OneFlow IR would face these challenges:\n- Debugging components related to IR, compiler could be complicated and peculiar.\n- IR subsystems should follow latest changes of OneFlow and MLIR closely.\n\nTo address these problems,\nwithin the IR source code directory,\nthere are some rules must be enforced for all the optimizers, importers, exporters, runners:\n- separate library, include, target\n- MLIR-releted code should follow the style and paradigm of MLIR and LLVM closely\n- ensure every component could be independently compiled and tested\n    - there should be one `CMakeLists.txt` in every sub-directory\n    - don't link anything from OneFlow unless it is necessary for the feature\n\n## Major components\n- ### oneflow-translate\nEverything related to MLIR-OneFlow translation. [read more](oneflow-translate/README.md)\n\n- ### oneflow-opt\nOptimizations on OneFlow MLIR dialect. A CLI to optimize .mlir file. [read more](oneflow-opt/README.md)\n\n- ### OneFlow dialect\nIn the `include` and `lib` directories, there are definitions of MLIR OneFlow dialect and its operators.\n\n- ### OneFlow Kenerl Memory (OKM) Dialect\nIn the `include` and `lib` directories, there are definitions of MLIR OKM dialect and its operators.\nOKM is a dialect which support oneflow using mlir memref style and use-def flow  to optimize memory usage.\n\n- ### OneFlow Kernel Launch (OKL) dialect\nIn the `include` and `lib` directories, there are definitions of MLIR OKL dialect and its operators.\nOKL is a dialect which support oneflow kernel ops launched as a a llvm dialect callee.\n## Parallel Signature\n\n- There is parallel signature as 0 for OneFlow Ops in MLIR. It is implemented as MLIR dialect attribute. Some examples:\n    - 1D SBP\n        ```mlir\n        %100 = \"oneflow.relu\"(%99) {parallel = #sbp.parallel<[#sbp.S<0>] -> [#sbp.S<0>]>, ...\n        ```\n    - multiple inputs and outputs 1D SBP\n        ```mlir\n        %102 = \"oneflow.add_n2\"(%101, %97) {parallel = #sbp.parallel<[#sbp.S<0>, #sbp.S<0>] -> [#sbp.S<0>]>, ...\n        ```\n    - 2D SBP `matmul`\n        ```\n        %120 = \"oneflow.matmul\"(%119, %output_105) {parallel = #sbp.parallel<[[#sbp.S<0>, #sbp.P], #sbp.S<0>] -> [#sbp.S<0>]>, ...\n        ```\n\n- To avoid confusion and potential parsing error, use the term \"parallel\" instead of using \"sbp\" but conceptually and documentally there are the same.\n\n### Principle\n- In IR, The signature should be orthogonal to device placement information althogh in some passes they might be related to each other.\n\n## Development\n\n- To run all the regression tests. The `-j3` option for [`LIT`](https://llvm.org/docs/CommandGuide/lit.html) is to prevent OOM on GPU.\n    ```bash\n    LIT_OPTS=\"-j3\" cmake --build build -t c1 -j24\n    ```\n"
  },
  {
    "path": "oneflow/ir/include/CMakeLists.txt",
    "content": "add_subdirectory(OneFlow)\nadd_subdirectory(Transform)\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/CMakeLists.txt",
    "content": "# set(ONEFLOW_USER_OP_GEN_TD_PATH \"${PROJECT_BINARY_DIR}/include/OneFlow\")\nset(ONEFLOW_USER_OP_GEN_TD_PATH \"${PROJECT_SOURCE_DIR}/include/OneFlow\")\n\nset(LLVM_TARGET_DEFINITIONS OneFlowEnums.td)\nmlir_tablegen(OneFlowEnums.h.inc -gen-enum-decls)\nmlir_tablegen(OneFlowEnums.cpp.inc -gen-enum-defs)\nadd_public_tablegen_target(MLIROneFlowEnumsIncGen)\n\nset(LLVM_TARGET_DEFINITIONS OneFlowPatterns.td)\nset(ONEFLOW_OP_GROUPS_USED_IN_PATTERNS\n    \"SCALAR;UNARY;FUSED;MISC;BINARY;IDEMPOTENT;NORMALIZATION;MATMUL;BROADCAST;CONV;PADDING\")\nforeach(OP_GROUP_NAME IN LISTS ONEFLOW_OP_GROUPS_USED_IN_PATTERNS)\n  list(APPEND LLVM_TABLEGEN_FLAGS \"-DGET_ONEFLOW_${OP_GROUP_NAME}_OP_DEFINITIONS\")\nendforeach()\nmlir_tablegen(OneFlowPatterns.cpp.inc -gen-rewriters)\nadd_public_tablegen_target(MLIROneFlowPatternsIncGen)\n\n# NOTE: seperate conversion and opt with --name\nif(WITH_MLIR_CUDA_CODEGEN)\n  list(APPEND LLVM_TABLEGEN_FLAGS \"-DWITH_MLIR_CUDA_CODEGEN\")\nendif()\nset(LLVM_TARGET_DEFINITIONS OneFlowPasses.td)\nmlir_tablegen(OneFlowPasses.h.inc -gen-pass-decls)\nadd_public_tablegen_target(MLIROneFlowPassIncGen)\n\nset(LLVM_TABLEGEN_FLAGS \"\")\nadd_mlir_interface(OneFlowInterfaces)\n\nset(LLVM_TARGET_DEFINITIONS OneFlowOpGetGen.td)\n\nset(ONEFLOW_OP_GROUPS\n    \"ASSIGN;BINARY;BROADCAST;CONV;CROSS_ENTROPY;CUDA;DATASET;DETECTION;EAGER;FUSED;IDEMPOTENT;IDENTITY;IMAGE;INDICES;INVOLUTION;LOSS;MATH;MATMUL;MISC;NCCL;NORMALIZATION;OPTIMIZER;PADDING;PARALLEL_CAST;POOL;QUANTIZATION;REDUCE;RESHAPE;SCALAR;SOFTMAX;SUMMARY;TENSOR_BUFFER;TEST;TRIGONOMETRIC;UNARY;UPSAMPLE;ONE_EMBEDDING;LINEAR_ALGEBRA;SYSTEM;MLIR_JIT\"\n)\nforeach(OP_GROUP_NAME IN LISTS ONEFLOW_OP_GROUPS)\n  message(STATUS \"Enable OneFlow MLIR op group: ${OP_GROUP_NAME}\")\n  set(ONE_LLVM_TABLEGEN_FLAGS \"-DGET_ONEFLOW_${OP_GROUP_NAME}_OP_DEFINITIONS\")\n  list(APPEND FULL_LLVM_TABLEGEN_FLAGS \"${ONE_LLVM_TABLEGEN_FLAGS}\")\n  set(LLVM_TABLEGEN_FLAGS \"${ONE_LLVM_TABLEGEN_FLAGS}\")\n  string(TOLOWER \"${OP_GROUP_NAME}\" OP_GROUP_NAME_LOWER)\n  set(CPP_INC_FILE \"OneFlow.${OP_GROUP_NAME_LOWER}_ops.cpp.inc\")\n  set(HEADER_INC_FILE \"OneFlow.${OP_GROUP_NAME_LOWER}_ops.h.inc\")\n  mlir_tablegen(${CPP_INC_FILE} -gen-op-defs)\n  mlir_tablegen(${HEADER_INC_FILE} -gen-op-decls)\nendforeach()\nadd_public_tablegen_target(MLIROneFlowOpGroupDefsIncGen)\n\nset(LLVM_TABLEGEN_FLAGS \"${FULL_LLVM_TABLEGEN_FLAGS}\")\nmlir_tablegen(OneFlow.gen_ops.h.inc -gen-op-decls)\nadd_public_tablegen_target(MLIROneFlowOpGroupDeclsIncGen)\n\nset(LLVM_TARGET_DEFINITIONS SBP/SBPOps.td)\nmlir_tablegen(SBPDialect.h.inc -gen-dialect-decls)\nmlir_tablegen(SBPDialect.cpp.inc -gen-dialect-defs)\nmlir_tablegen(SBPAttributes.h.inc -gen-attrdef-decls)\nmlir_tablegen(SBPAttributes.cpp.inc -gen-attrdef-defs)\nadd_public_tablegen_target(MLIRSBPIncGen)\n\nset(LLVM_TARGET_DEFINITIONS OKL/OKLOps.td)\nmlir_tablegen(OKLDialect.h.inc -gen-dialect-decls -dialect=okl)\nmlir_tablegen(OKLDialect.cpp.inc -gen-dialect-defs -dialect=okl)\nmlir_tablegen(OKLOps.h.inc -gen-op-decls)\nmlir_tablegen(OKLOps.cpp.inc -gen-op-defs)\nmlir_tablegen(OKLTypes.h.inc -gen-typedef-decls)\nmlir_tablegen(OKLTypes.cpp.inc -gen-typedef-defs)\nmlir_tablegen(OKLPasses.h.inc -gen-pass-decls)\nmlir_tablegen(OKLEnums.h.inc -gen-enum-decls)\nmlir_tablegen(OKLEnums.cpp.inc -gen-enum-defs)\nmlir_tablegen(OKLAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=okl)\nmlir_tablegen(OKLAttributes.cpp.inc -gen-attrdef-defs 
-attrdefs-dialect=okl)\nadd_public_tablegen_target(MLIROKLIncGen)\n\nset(LLVM_TARGET_DEFINITIONS OKM/OKMOps.td)\nmlir_tablegen(OKMDialect.h.inc -gen-dialect-decls -dialect=okm)\nmlir_tablegen(OKMDialect.cpp.inc -gen-dialect-defs -dialect=okm)\nmlir_tablegen(OKMOps.h.inc -gen-op-decls)\nmlir_tablegen(OKMOps.cpp.inc -gen-op-defs)\nmlir_tablegen(OKMPasses.h.inc -gen-pass-decls)\nmlir_tablegen(OKMAttributes.h.inc -gen-attrdef-decls)\nmlir_tablegen(OKMAttributes.cpp.inc -gen-attrdef-defs)\nadd_public_tablegen_target(MLIROKMIncGen)\n\nset(LLVM_TABLEGEN_FLAGS \"\")\nadd_mlir_dialect(\n  OneFlowOps\n  oneflow\n  DEPENDS\n  MLIRSBPIncGen\n  MLIROneFlowEnumsIncGen\n  MLIROneFlowPatternsIncGen\n  MLIROneFlowPassIncGen\n  MLIROneFlowInterfacesIncGen\n  MLIROneFlowOpGroupDefsIncGen\n  MLIROneFlowOpGroupDeclsIncGen)\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/Conversion/NVVMToCubin.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_CONVERSION_NVVMTOCUBIN_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_CONVERSION_NVVMTOCUBIN_H_\n\n#ifdef WITH_MLIR_CUDA_CODEGEN\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\n\nnamespace gpu {\n\ninline std::string getCubinAnnotation() { return \"gpu.binary\"; }\n\n}  // namespace gpu\n\nnamespace oneflow {\n\nconst char* getArchVersion();\n\nstd::unique_ptr<mlir::Pass> createNVVMToCubinPass();\n\nvoid InitializeLLVMNVPTXBackend();\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // WITH_MLIR_CUDA_CODEGEN\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_CONVERSION_NVVMTOCUBIN_H_"
  },
  {
    "path": "oneflow/ir/include/OneFlow/Conversion/OneFlowToTosa.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_CONVERSION_ONEFLOWTOTOSA_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_CONVERSION_ONEFLOWTOTOSA_H_\n\n#include \"mlir/Dialect/Tosa/IR/TosaOps.h\"\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nstd::unique_ptr<mlir::Pass> createLowerOneFlowToTosaPass();\nstd::unique_ptr<mlir::Pass> createLowerOneFlowToLinalgPass();\nstd::unique_ptr<mlir::Pass> createConvertToSignlessForTosaPass();\nstd::unique_ptr<mlir::Pass> createCastOneFlowOpsToSignlessPass();\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_CONVERSION_ONEFLOWTOTOSA_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/Extension.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_EXTENSION_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_EXTENSION_H_\n#include <unordered_set>\n#include <string>\n\nnamespace oneflow {\n\nusing SharedLibs = std::unordered_set<std::string>;\nSharedLibs* MutSharedLibPaths();\nconst SharedLibs* SharedLibPaths();\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_EXTENSION_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKL/Conversion/Conversion.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_CONVERSION_CONVERSION_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_CONVERSION_CONVERSION_H_\n\n#include \"OneFlow/OKL/Conversion/OKLToLLVM.h\"\n#include \"mlir/IR/BuiltinOps.h\"\n\nnamespace mlir {\nnamespace okl {\n\n// convert okl dialect to llvm dialect\nLogicalResult LowerOKLComputeToLLVM(ModuleOp module);\n\n}  // namespace okl\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_CONVERSION_CONVERSION_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKL/Conversion/OKLToLLVM.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_CONVERSION_OKLTOLLVM_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_CONVERSION_OKLTOLLVM_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\nnamespace okl {\n\n// lower !okl.launcher_ctx to !llvm.ptr<i8>\nstd::unique_ptr<mlir::Pass> createLowerLauncherToLLVMPtrPass();\n\n// lower okl ops to llvm.call @{callee in liboneflow.so}\nstd::unique_ptr<mlir::Pass> createLowerOKLToLLVMCallPass();\n\n// tag {okl.cuda_graph_support} according to its wrapped ops\nstd::unique_ptr<mlir::Pass> createTagCudaGraphSupportPass();\n\nnamespace cuda_graph_support {\n\nstatic const auto TAG_NAME = \"cuda_graph_support\";\n\n}  // namespace cuda_graph_support\n}  // namespace okl\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_CONVERSION_OKLTOLLVM_H_\n"
  },
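The three pass constructors declared in `OKLToLLVM.h` suggest a lowering pipeline. Below is a minimal sketch of wiring them into a pass manager; the ordering shown and the helper name `RunOKLToLLVM` are assumptions (the canonical pipeline lives in the conversion sources), not the library's documented entry point.

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "OneFlow/OKL/Conversion/OKLToLLVM.h"

// Hypothetical pipeline: tag cuda-graph support first, then lower the
// launcher context type, then rewrite okl ops into llvm.call.
mlir::LogicalResult RunOKLToLLVM(mlir::MLIRContext* context, mlir::ModuleOp module) {
  mlir::PassManager pm(context);
  pm.addPass(mlir::okl::createTagCudaGraphSupportPass());     // tag {okl.cuda_graph_support}
  pm.addPass(mlir::okl::createLowerLauncherToLLVMPtrPass());  // !okl.launcher_ctx -> !llvm.ptr<i8>
  pm.addPass(mlir::okl::createLowerOKLToLLVMCallPass());      // okl ops -> llvm.call @...
  return pm.run(module);
}
```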
  {
    "path": "oneflow/ir/include/OneFlow/OKL/Kernel/ComputeContext.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_COMPUTECONTEXT_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_COMPUTECONTEXT_H_\n\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"OneFlow/OKL/Kernel/RegContext.h\"\n#include \"OneFlow/OKL/Kernel/TmpBufferManager.h\"\n\nnamespace oneflow {\nnamespace okl {\nclass ComputeContext final : public user_op::KernelComputeContext {\n public:\n  ComputeContext(RegContext const* reg_ctx, user_op::KernelComputeContext* comp_ctx)\n      : reg_ctx_(reg_ctx),\n        comp_ctx_(comp_ctx),\n        tmp_buffer_(comp_ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0)) {}\n\n  ~ComputeContext() = default;\n\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const override {\n    return reg_ctx_->TensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n\n  ep::Stream* stream() override { return comp_ctx_->stream(); }\n\n  DeviceType device_type() const override { return reg_ctx_->device_type(); }\n  const ParallelContext& parallel_ctx() const override { return comp_ctx_->parallel_ctx(); }\n\n  const ArgVec& inputs() const override { return reg_ctx_->inputs(); }\n  const ArgVec& outputs() const override { return reg_ctx_->outputs(); }\n\n  const user_op::UserOpConfWrapper& user_op_conf() const override {\n    return reg_ctx_->user_op_conf();\n  }\n  user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override;\n\n private:\n  RegContext const* reg_ctx_;\n  KernelComputeContext* comp_ctx_;\n  TmpBufferManager tmp_buffer_;\n\n  std::unordered_map<mlir::oneflow::user_op::ArgID, user_op::Tensor*> tensor_{};\n\n  user_op::Tensor* CreateTensorWithArgNameAndIndex(const std::string& arg_name, int32_t index);\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return user_op_conf().Attr4Name(attr_name);\n  }\n};\n\n}  // namespace okl\n}  // namespace oneflow\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_COMPUTECONTEXT_H_\n"
  },
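As a usage note, `ComputeContext` splits its answers between the two contexts it wraps. A minimal sketch follows, assuming `reg_ctx` and `comp_ctx` already exist; the argument name `"in"` is hypothetical.

```cpp
#include "OneFlow/OKL/Kernel/ComputeContext.h"

// Static tensor metadata is answered by the MLIR-derived RegContext, while
// the stream (and live tensors) come from the real KernelComputeContext.
void QueryBothSources(const oneflow::okl::RegContext* reg_ctx,
                      oneflow::user_op::KernelComputeContext* comp_ctx) {
  oneflow::okl::ComputeContext ctx(reg_ctx, comp_ctx);
  const auto* desc = ctx.TensorDesc4ArgNameAndIndex("in", 0);  // from reg_ctx
  auto* stream = ctx.stream();                                 // from comp_ctx
  (void)desc;
  (void)stream;
}
```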
  {
    "path": "oneflow/ir/include/OneFlow/OKL/Kernel/InferContext.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_INFERCONTEXT_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_INFERCONTEXT_H_\n\n#include \"oneflow/core/kernel/kernel_context.h\"\n#include \"oneflow/core/kernel/user_kernel.h\"\n#include \"OneFlow/OKL/Kernel/RegContext.h\"\n\nnamespace oneflow {\nnamespace okl {\n\nclass InferContext final : public user_op::InferContext {\n public:\n  explicit InferContext(RegContext const* reg_ctx);\n\n  const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name,\n                                             int32_t index) const override {\n    return *LogicalTensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n  const user_op::TensorDesc& OutputTensorDesc(const std::string& arg_name,\n                                              int32_t index) const override {\n    return *LogicalTensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n  user_op::TensorDesc* MutOutputTensorDesc(const std::string&, int32_t) override { TODO(); }\n  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                               int32_t index) const override;\n\n  const Shape& InputShape(const std::string& arg_name, int32_t index) const override;\n\n  const Shape& OutputShape(const std::string&, int32_t) const override { TODO(); }\n  void SetOutputShape(const std::string&, int32_t, const Shape&) override { TODO(); }\n  const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override;\n  void SetShape4ArgNameAndIndex(const std::string&, int32_t, const Shape&) override { TODO(); }\n  const Stride& InputStride(const std::string&, int32_t) const override { TODO(); }\n  const Stride& OutputStride(const std::string&, int32_t) const override { TODO(); }\n  void SetOutputStride(const std::string&, int32_t, const Stride&) override { TODO(); }\n  const Stride& Stride4ArgNameAndIndex(const std::string&, int32_t) const override { TODO(); }\n  void SetStride4ArgNameAndIndex(const std::string&, int32_t, const Stride&) override { TODO(); }\n  DataType InputDType(const std::string&, int32_t) const override { TODO(); }\n  DataType OutputDType(const std::string&, int32_t) const override { TODO(); }\n  void SetOutputDType(const std::string&, int32_t, DataType) override { TODO(); }\n  DataType Dtype4ArgNameAndIndex(const std::string&, int32_t) const override { TODO(); }\n  void SetDtype4ArgNameAndIndex(const std::string&, int32_t, DataType) override { TODO(); }\n  MemoryFormat InputMemoryFormat(const std::string&, int32_t) const override { TODO(); }\n  MemoryFormat OutputMemoryFormat(const std::string&, int32_t) const override { TODO(); }\n  void SetOutputMemoryFormat(const std::string&, int32_t, MemoryFormat) override { TODO(); }\n  MemoryFormat MemoryFormat4ArgNameAndIndex(const std::string&, int32_t) const override { TODO(); }\n  void SetMemoryFormat4ArgNameAndIndex(const 
std::string&, int32_t, MemoryFormat) override {\n    TODO();\n  }\n\n  const std::vector<std::pair<std::string, int32_t>>& inputs() const override {\n    return reg_ctx_->inputs();\n  }\n  const std::vector<std::pair<std::string, int32_t>>& outputs() const override {\n    return reg_ctx_->outputs();\n  }\n\n  const std::string& input(const std::string& arg_name, int32_t index) const override {\n    return reg_ctx_->user_op_conf().input(arg_name, index);\n  }\n  const std::string& output(const std::string& arg_name, int32_t index) const override {\n    return reg_ctx_->user_op_conf().output(arg_name, index);\n  }\n\n  bool has_input(const std::string& arg_name, int32_t index) const override {\n    return reg_ctx_->user_op_conf().has_input(arg_name, index);\n  }\n  bool has_output(const std::string& arg_name, int32_t index) const override {\n    return reg_ctx_->user_op_conf().has_output(arg_name, index);\n  }\n\n  int32_t input_size(const std::string& arg_name) const override {\n    return reg_ctx_->user_op_conf().input_size(arg_name);\n  }\n  int32_t output_size(const std::string& arg_name) const override {\n    return reg_ctx_->user_op_conf().output_size(arg_name);\n  }\n  const std::string& op_name() const override { return reg_ctx_->user_op_conf().op_name(); }\n  const std::string& op_type_name() const override {\n    return reg_ctx_->user_op_conf().op_type_name();\n  }\n  const std::string& op_loc() const override { TODO(); }\n\n  const ParallelContext& parallel_ctx() const override { TODO(); }\n  const ParallelDesc& parallel_desc() const override { TODO(); }\n\n  const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string&, int32_t) const override {\n    TODO();\n  }\n\n  const NdSbp& NdSbp4ArgNameAndIndex(const std::string&, int32_t) const override { TODO(); }\n\n  bool InputIsDynamic(const std::string&, int32_t) const override { TODO(); }\n  bool OutputIsDynamic(const std::string&, int32_t) const override { TODO(); }\n  void SetOutputIsDynamic(const std::string&, int32_t, bool) override { TODO(); }\n  bool IsDynamic4ArgNameAndIndex(const std::string&, int32_t) const override { TODO(); }\n  void SetIsDynamic4ArgNameAndIndex(const std::string&, int32_t, bool) override { TODO(); }\n\n  int64_t parallel_num() const override { TODO(); }\n\n private:\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override;\n\n  RegContext const* reg_ctx_;\n};\n\n}  // namespace okl\n}  // namespace oneflow\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_INFERCONTEXT_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKL/Kernel/InitContext.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_CACHECONTEXT_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_CACHECONTEXT_H_\n\n#include \"OneFlow/OKL/Kernel/RegContext.h\"\n#include \"oneflow/core/common/tensor_desc.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n\nnamespace oneflow {\nnamespace okl {\n\nclass InitContext final : public user_op::KernelCacheContext, public user_op::KernelInitContext {\n public:\n  InitContext(RegContext const* reg_ctx, user_op::KernelComputeContext* compute_ctx)\n      : reg_ctx_(reg_ctx), compute_ctx_(compute_ctx) {}\n\n  DeviceType device_type() const override { return reg_ctx_->device_type(); }\n  const ParallelContext& parallel_ctx() const override { return compute_ctx_->parallel_ctx(); }\n  ep::Stream* stream() override { return compute_ctx_->stream(); }\n\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const override {\n    return reg_ctx_->TensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n\n  const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string&, int32_t) const override {\n    TODO();\n  }\n  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                               int32_t index) const override {\n    return reg_ctx_->TensorDesc4ArgNameAndIndex(arg_name, index);\n  }\n  const ParallelDesc& parallel_desc() const override { TODO(); }\n  const NdSbp& NdSbp4ArgNameAndIndex(const std::string&, int32_t) const override { TODO(); }\n\n  const std::vector<std::pair<std::string, int32_t>>& inputs() const override {\n    return reg_ctx_->inputs();\n  }\n  const std::vector<std::pair<std::string, int32_t>>& outputs() const override {\n    return reg_ctx_->outputs();\n  }\n\n private:\n  RegContext const* reg_ctx_;\n  user_op::KernelComputeContext* compute_ctx_;\n\n  const user_op::UserOpConfWrapper& user_op_conf() const override {\n    return reg_ctx_->user_op_conf();\n  }\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return reg_ctx_->Attr4Name(attr_name);\n  }\n};\n\n}  // namespace okl\n}  // namespace oneflow\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_CACHECONTEXT_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKL/Kernel/JITEngine.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITENGINE_H_\n#define ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITENGINE_H_\n\n#include \"mlir/ExecutionEngine/ExecutionEngine.h\"\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"OneFlow/OKL/Kernel/LauncherContext.h\"\n\nextern \"C\" {\nvoid okl_llvm_func(void* launcher, int64_t index);\n}  // extern \"C\"\n\nnamespace oneflow {\nnamespace okl {\n\nusing LLVMLaunchArgs = std::tuple<LauncherContext*, int>;\n\nclass JITEngine {\n public:\n  explicit JITEngine(mlir::ModuleOp module);\n\n  void Run(const std::string& name, LauncherContext* launcher) const {\n    auto error = engine_->invoke(name, launcher);\n    CHECK(!error) << \"fail to invoke jit engine, error: \" << llvm::toString(std::move(error));\n  }\n\n private:\n  std::unique_ptr<mlir::ExecutionEngine> engine_;\n};\n\nnamespace llvm_func {\n#define C_FUNC_NAME(func) #func\n\nconst auto LLVM_FUNC = C_FUNC_NAME(okl_llvm_func);\n\n#undef C_FUNC_NAME\n}  // namespace llvm_func\n\n}  // namespace okl\n}  // namespace oneflow\n\n#endif  // ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITENGINE_H_\n"
  },
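A minimal usage sketch for `JITEngine`: it assumes `module` has already been lowered to the LLVM dialect and exports an entry function taking the launcher pointer as its single argument. The entry name `"okl_compute"` below is an assumption, not the real symbol.

```cpp
#include "OneFlow/OKL/Kernel/JITEngine.h"
#include "OneFlow/OKL/Kernel/LauncherContext.h"

// JIT-compile the lowered module once, then invoke its entry by name with
// the launcher context as the single argument (see JITEngine::Run above).
void RunLoweredModule(mlir::ModuleOp module, oneflow::okl::LauncherContext* launcher) {
  oneflow::okl::JITEngine engine(module);
  engine.Run("okl_compute", launcher);  // hypothetical entry symbol
}
```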
  {
    "path": "oneflow/ir/include/OneFlow/OKL/Kernel/JITOpInfer.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITOPINFER_H_\n#define ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITOPINFER_H_\n\n#include \"oneflow/core/framework/infer_util.h\"\n\nnamespace oneflow {\n\nnamespace ir {\n\nnamespace jit {\n\nMaybe<void> InferTensorDesc(user_op::InferContext* ctx);\nMaybe<void> SetTensorDataType(user_op::InferContext* ctx);\n\n}  // namespace jit\n\n}  // namespace ir\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITOPINFER_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKL/Kernel/LauncherContext.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_LAUNCHER_CONTEXT_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_LAUNCHER_CONTEXT_H_\n\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"OneFlow/OKL/OKLOps.h\"\n#include \"OneFlow/OKL/Kernel/RegContext.h\"\n#include \"OneFlow/OKL/Kernel/WrapperContext.h\"\n#include \"mlir/IR/Operation.h\"\n\nnamespace oneflow {\nnamespace okl {\n\nclass LauncherContext final {\n public:\n  // compile the mlir to ctx\n  explicit LauncherContext(mlir::ModuleOp module);\n  // infer ctx with okl info\n  bool Infer() { return inferred_; }\n  bool Infer(user_op::KernelComputeContext* compute_context);\n  // launch kernel with index\n  void Launch(int index);\n\n private:\n  bool inferred_ = false;\n\n  std::vector<CompileTimeWrapperContext> compile_ctx_vec_;\n  std::vector<RunTimeWrapperContext> run_ctx_vec_;\n};\n\n}  // namespace okl\n}  // namespace oneflow\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_LAUNCHER_CONTEXT_H_\n"
  },
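The two `Infer` overloads imply a lazy, infer-once flow. A sketch of how a caller might drive it, where `num_wrapped_ops` is an assumed count of the kernels wrapped by this launcher:

```cpp
#include "OneFlow/OKL/Kernel/LauncherContext.h"

// Infer() with no arguments only reports the cached `inferred_` flag; the
// overload taking the compute context does the one-time construction of the
// runtime wrapper contexts. Launch(i) then runs the i-th wrapped kernel.
void ComputeAll(oneflow::okl::LauncherContext& launcher,
                oneflow::user_op::KernelComputeContext* compute_ctx, int num_wrapped_ops) {
  if (!launcher.Infer()) { launcher.Infer(compute_ctx); }
  for (int i = 0; i < num_wrapped_ops; ++i) { launcher.Launch(i); }
}
```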
  {
    "path": "oneflow/ir/include/OneFlow/OKL/Kernel/LauncherState.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_OP_KERNEL_STATE_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_OP_KERNEL_STATE_H_\n\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OKL/OKLDialect.h\"\n#include \"OneFlow/OKL/Kernel/JITEngine.h\"\n#include \"OneFlow/OKL/Kernel/LauncherContext.h\"\n#include \"OneFlow/OKL/Conversion/Conversion.h\"\n#include \"mlir/Dialect/Arith/IR/Arith.h\"\n#include \"mlir/IR/DialectRegistry.h\"\n#include \"mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h\"\n#include \"mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h\"\n\nnamespace oneflow {\nnamespace okl {\n\ninline mlir::DialectRegistry GetRegistry() {\n  mlir::DialectRegistry registry;\n  registry.insert<mlir::oneflow::OneFlowDialect, mlir::okl::OKLDialect, mlir::func::FuncDialect,\n                  mlir::arith::ArithDialect, mlir::LLVM::LLVMDialect>();\n  mlir::registerBuiltinDialectTranslation(registry);\n  mlir::registerLLVMDialectTranslation(registry);\n  return registry;\n}\n\nclass LauncherState final : public user_op::OpKernelState {\n public:\n  explicit LauncherState(user_op::KernelInitContext* ctx);\n  ~LauncherState() = default;\n\n  void DoCompute(user_op::KernelComputeContext* ctx);\n  bool IsCudaGraphSupported(user_op::KernelInitContext* ctx);\n\n private:\n  // manage module(compile)\n  mlir::MLIRContext mlir_ctx_;\n  mlir::OwningOpRef<mlir::ModuleOp> module_;\n\n  // manage context\n  LauncherContext launcher_context_;\n\n  // manage engine(runtime)\n  JITEngine engine_;\n};\n\n}  // namespace okl\n}  // namespace oneflow\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_OP_KERNEL_STATE_H_\n"
  },
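`GetRegistry()` above enumerates every dialect the okl path needs. A small sketch of seeding an `MLIRContext` with it; loading all dialects eagerly is a choice made here for illustration, not something the header mandates.

```cpp
#include "OneFlow/OKL/Kernel/LauncherState.h"

// Build an MLIRContext that can parse and lower okl modules without hitting
// "dialect not registered" errors at parse time.
void PrepareOKLContext() {
  mlir::MLIRContext ctx(oneflow::okl::GetRegistry());
  ctx.loadAllAvailableDialects();  // load eagerly instead of lazily on first use
  // ... parse / lower the okl module with `ctx` here ...
}
```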
  {
    "path": "oneflow/ir/include/OneFlow/OKL/Kernel/README.md",
    "content": "## context相关概念与其生命周期：\n### LauncherState\nLauncherState 是OpKernelState的派生类，在okl kernel的初始化kernel state的阶段被创建。\n``` c++\nLauncherState final : public user_op::OpKernelState\n```\n每个LauncherState拥有一个LauncherContext管理运行的上下文和一个JIT Engine管理运行时引擎。\n    单个okl kernel的资源的管理者\n     - LauncherContext的维护者，负责对于context信息的更新;\n     - JIT Engine的所有者\n\n### LauncherContext\nLauncherContext作为单个okl kernel的上下文的管理者，维护若干有序的oneflow op的上下文资源信息，每个oneflow op对应的上下文资源对应一个专门的WrapperContext作为一个总体的维护者。\n因此LauncherContext下维护一系列编译期状态的WrapperContext和运行时状态的WrapperContext以对应不同阶段的上下文。这些ctx与oneflow op一一对应。\n```\nclass LauncherContext final {\n  bool inferred_ = false;\n\n  std::vector<CompileTimeWrapperContext> compile_ctx_vec_;\n  std::vector<RunTimeWrapperContext> run_ctx_vec_;\n};\n``` \n\n### WrapperContext(op, ctx):\n    单个被okl wrap的oneflow op的管理者，编译期存在的东西在初始化后不可被改变，运行时需要做一个懒汉模式的infer推导流程。\n    1. 推导前\n    - reg_ctx_(op) \n     - device\n     - inputs/outputs\n     - kernel\n     - user config\n    \n    2. 推导后\n    - init_ctx_(reg_ctx_, ctx)\n    - state_(reg_ctx, init_ctx_)\n    - cache_(reg_ctx, init_ctx_)\n    - compute_ctx_(ctx)\n\n```\nclass CompileTimeWrapperContext {\n  std::shared_ptr<const RegContext> reg_ctx_;\n};\n\nclass RunTimeWrapperContext : public CompileTimeWrapperContext {\n  std::shared_ptr<ComputeContext> compute_ctx_;\n  std::shared_ptr<InitContext> init_ctx_;\n\n  std::shared_ptr<user_op::OpKernelState> kernel_state_;\n  std::shared_ptr<user_op::OpKernelCache> kernel_cache_;\n};\n```\nCompileTimeWrapperContext维护着从ir获取得到的上下文信息并加以封装到reg ctx，用于作为后面推导RunTimeWrapperContext的输入之一。\nRunTimeWrapperContext通过CompileTimeWrapperContext的信息以及okl kernel所创建的comp ctx以及tmp buffer等资源组成了单个op运行时计算所需的实际上下文环境。通过创建的init ctx，创建kernel state，kernel cache等资源用于kernel的compute计算。"
  },
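To make the lifecycle above concrete, here is a minimal sketch of the hooks where a LauncherState would be created and used; the surrounding kernel class is hypothetical and omitted.

```cpp
#include <memory>
#include "OneFlow/OKL/Kernel/LauncherState.h"

// Created once, in the kernel-state initialization phase described above.
std::shared_ptr<oneflow::user_op::OpKernelState> CreateState(
    oneflow::user_op::KernelInitContext* ctx) {
  return std::make_shared<oneflow::okl::LauncherState>(ctx);
}

// Each compute call forwards to the state, which lazily infers the wrapper
// contexts and then drives the JIT engine.
void ComputeWithState(oneflow::user_op::KernelComputeContext* ctx,
                      oneflow::okl::LauncherState* state) {
  state->DoCompute(ctx);
}
```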
  {
    "path": "oneflow/ir/include/OneFlow/OKL/Kernel/RegContext.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_REGCONTEXT_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_REGCONTEXT_H_\n\n#include \"oneflow/core/framework/user_op_kernel_registry.h\"\n#include \"OneFlow/UserOpReflection.h\"\n#include \"mlir/IR/Operation.h\"\n\nnamespace oneflow {\nnamespace okl {\n// this context should support querying information about the kernel from representation in MLIR\nusing ArgVec = std::vector<std::pair<std::string, int32_t>>;\nclass RegContext final : public user_op::KernelRegContext {\n public:\n  explicit RegContext(mlir::Operation* op);\n  ~RegContext() = default;\n\n  // override user_op KernelRegContext\n  DeviceType device_type() const override;\n  const ParallelContext& parallel_ctx() const override;\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const override;\n  const ArgVec& inputs() const override;\n  const ArgVec& outputs() const override;\n  const user_op::UserOpConfWrapper& user_op_conf() const override;\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override;\n\n  const size_t GetTmpBufferSize() const;\n  ::mlir::Operation* GetOp() const { return op_; };\n  const user_op::OpKernel* GetKernel() const { return kernel_; };\n\n private:\n  ::mlir::Operation* op_;\n  DeviceType device_type_ = DeviceType::kInvalidDevice;\n  std::unordered_map<mlir::oneflow::user_op::ArgID, user_op::NaiveTensorDesc> arg2tensor_desc_{};\n  ArgVec inputs_;\n  ArgVec outputs_;\n  user_op::UserOpConfWrapper conf_wrapper_;\n\n  const user_op::OpKernelRegistryResult* reg_res_ = nullptr;\n  const user_op::OpKernel* kernel_ = nullptr;\n};\n\n}  // namespace okl\n}  // namespace oneflow\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_REGCONTEXT_H_\n"
  },
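A short sketch of what `RegContext` recovers purely from the MLIR representation of a wrapped op; `op` is assumed to be such an operation.

```cpp
#include "OneFlow/OKL/Kernel/RegContext.h"

// All of the queries below are answered without a running kernel: the op's
// attributes suffice to resolve the device, tmp-buffer demand, and kernel.
void InspectWrappedOp(mlir::Operation* op) {
  oneflow::okl::RegContext reg_ctx(op);
  oneflow::DeviceType dev = reg_ctx.device_type();
  size_t tmp_bytes = reg_ctx.GetTmpBufferSize();
  const oneflow::user_op::OpKernel* kernel = reg_ctx.GetKernel();
  (void)dev;
  (void)tmp_bytes;
  (void)kernel;
}
```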
  {
    "path": "oneflow/ir/include/OneFlow/OKL/Kernel/TmpBufferManager.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_TMP_BUFFER_MANAGER_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_TMP_BUFFER_MANAGER_H_\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/user_op_tensor.h\"\n#include <unordered_map>\n\nnamespace oneflow {\nnamespace okl {\n\nclass TmpBufferManager {\n  class PoolToTensor final : public oneflow::user_op::Tensor {\n   public:\n    explicit PoolToTensor(user_op::Tensor* tensor, const user_op::TensorDesc* tensor_desc,\n                          int64_t offset)\n        : tensor_(tensor),\n          raw_dptr_(reinterpret_cast<char*>(tensor_->mut_raw_dptr()) + offset),\n          tensor_desc_(tensor_desc) {}\n\n    ShapeView shape_view() const override { return tensor_desc_->shape(); }\n    const Stride& stride() const override { return tensor_desc_->stride(); }\n    DataType data_type() const override { return tensor_desc_->data_type(); }\n    MemoryFormat memory_format() const override { return tensor_desc_->memory_format(); }\n    MutShapeView mut_shape_view() override { TODO(); }\n    const MemoryCase& mem_case() const override { return tensor_->mem_case(); }\n\n    const void* raw_dptr() const override { return raw_dptr_; }\n    void* mut_raw_dptr() override { return raw_dptr_; }\n\n   private:\n    user_op::Tensor* tensor_;\n    void* raw_dptr_;\n    const user_op::TensorDesc* tensor_desc_;\n  };\n\n  class PoolToBuffer final : public oneflow::user_op::Tensor {\n   public:\n    explicit PoolToBuffer(user_op::Tensor* tensor, int64_t size, int64_t offset)\n        : tensor_(tensor),\n          raw_dptr_(reinterpret_cast<char*>(tensor_->mut_raw_dptr()) + offset),\n          shape_({size}) {}\n\n    ShapeView shape_view() const override { return shape_; }\n    const Stride& stride() const override { return tensor_->stride(); }\n    DataType data_type() const override { return tensor_->data_type(); }\n    MemoryFormat memory_format() const override { return tensor_->memory_format(); }\n    MutShapeView mut_shape_view() override { return shape_; }\n    const MemoryCase& mem_case() const override { return tensor_->mem_case(); }\n\n    const void* raw_dptr() const override { return raw_dptr_; }\n    void* mut_raw_dptr() override { return raw_dptr_; }\n\n   private:\n    user_op::Tensor* tensor_;\n    void* raw_dptr_;\n    Shape shape_;\n  };\n\n public:\n  static size_t InferTmpSize(user_op::InferContext* ctx);\n\n  explicit TmpBufferManager(user_op::Tensor* tensor) : tensor_(tensor) {}\n  user_op::Tensor* GetPoolTensor(const user_op::TensorDesc* tensor_desc, int64_t offset) {\n    CHECK_LE(offset + tensor_desc->shape().elem_cnt() * GetSizeOfDataType(tensor_desc->data_type()),\n             tensor_->shape_view().elem_cnt());\n    auto res = tensor_map_.insert({tensor_desc, PoolToTensor(tensor_, tensor_desc, 
offset)}).first;\n    return &res->second;\n  }\n\n  user_op::Tensor* GetPoolBuffer(int64_t size, int64_t offset) {\n    auto res = buffer_map_.insert({{size, offset}, PoolToBuffer(tensor_, size, offset)}).first;\n    return &res->second;\n  }\n\n private:\n  std::unordered_map<const user_op::TensorDesc*, PoolToTensor> tensor_map_{};\n  std::unordered_map<std::pair<int64_t, int64_t>, PoolToBuffer> buffer_map_{};\n  user_op::Tensor* tensor_;\n};\n\n}  // namespace okl\n}  // namespace oneflow\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_TMP_BUFFER_MANAGER_H_\n"
  },
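A usage sketch for the pool: both views alias the single `tmp_buffer` tensor at byte offsets, and the manager caches the views it hands out. The concrete size and offsets below are assumptions for illustration.

```cpp
#include "OneFlow/OKL/Kernel/TmpBufferManager.h"

// Carve a typed view (shape/dtype taken from `desc`) and a flat byte view
// out of the pooled tmp_buffer tensor.
void CarvePool(oneflow::user_op::Tensor* tmp_buffer, const oneflow::user_op::TensorDesc* desc) {
  oneflow::okl::TmpBufferManager pool(tmp_buffer);
  oneflow::user_op::Tensor* typed = pool.GetPoolTensor(desc, /*offset=*/0);
  oneflow::user_op::Tensor* raw = pool.GetPoolBuffer(/*size=*/1024, /*offset=*/4096);
  (void)typed;
  (void)raw;
}
```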
  {
    "path": "oneflow/ir/include/OneFlow/OKL/Kernel/WrapperContext.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_WRAPPERCONTEXT_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_WRAPPERCONTEXT_H_\n\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"OneFlow/OKL/Kernel/InitContext.h\"\n#include \"OneFlow/OKL/Kernel/RegContext.h\"\n#include \"OneFlow/OKL/Kernel/ComputeContext.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n\nnamespace oneflow {\nnamespace okl {\n\nclass CompileTimeWrapperContext {\n public:\n  explicit CompileTimeWrapperContext(mlir::Operation* op)\n      : reg_ctx_(std::make_shared<const RegContext>(op)) {}\n\n  CompileTimeWrapperContext(CompileTimeWrapperContext&&) = default;\n\n  RegContext const* GetRegContext() const { return reg_ctx_.get(); }\n\n private:\n  std::shared_ptr<const RegContext> reg_ctx_;\n};\n\nclass RunTimeWrapperContext {\n public:\n  RunTimeWrapperContext(mlir::Operation* op, user_op::KernelComputeContext* ctx)\n      : compile_time_wrapper_ctx_(op),\n        compute_ctx_(std::make_unique<ComputeContext>(GetRegContext(), ctx)),\n        init_ctx_(std::make_unique<InitContext>(GetRegContext(), ctx)),\n        kernel_state_(GetRegContext()->GetKernel()->CreateOpKernelState(init_ctx_.get())),\n        kernel_cache_(GetRegContext()->GetKernel()->InitOpKernelCache(init_ctx_.get())) {}\n\n  void Run() {\n    GetRegContext()->GetKernel()->Compute(compute_ctx_.get(), kernel_state_.get(),\n                                          kernel_cache_.get());\n  }\n\n  RegContext const* GetRegContext() const { return compile_time_wrapper_ctx_.GetRegContext(); }\n\n private:\n  CompileTimeWrapperContext compile_time_wrapper_ctx_;\n  std::unique_ptr<ComputeContext> compute_ctx_;\n  std::unique_ptr<InitContext> init_ctx_;\n\n  std::shared_ptr<user_op::OpKernelState> kernel_state_;\n  std::shared_ptr<user_op::OpKernelCache> kernel_cache_;\n};\n\n}  // namespace okl\n}  // namespace oneflow\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_KERNEL_WRAPPERCONTEXT_H_\n"
  },
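A sketch tying the two wrapper classes to the lifecycle from the kernel README: constructing a `RunTimeWrapperContext` performs the whole "after inference" setup, and `Run()` only dispatches. `op` and `ctx` are assumed inputs.

```cpp
#include "OneFlow/OKL/Kernel/WrapperContext.h"

// The constructor builds the compute/init contexts plus the kernel state and
// cache; Run() then calls the wrapped kernel's Compute with all of them.
void RunWrappedOp(mlir::Operation* op, oneflow::user_op::KernelComputeContext* ctx) {
  oneflow::okl::RunTimeWrapperContext wrapper(op, ctx);
  wrapper.Run();
}
```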
  {
    "path": "oneflow/ir/include/OneFlow/OKL/OKLAttributes.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_OKL_OKLATTRIBUTES_H_\n#define ONEFLOW_IR_INCLUDE_OKL_OKLATTRIBUTES_H_\n\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/Support/LLVM.h\"\n#include \"OneFlow/OKLEnums.h.inc\"\n#define GET_ATTRDEF_CLASSES\n#include \"OneFlow/OKLAttributes.h.inc\"\n\n#endif  // ONEFLOW_IR_INCLUDE_OKL_OKLATTRIBUTES_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKL/OKLAttributes.td",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_IR_INCLUDE_OKL_OKLATTRIBUTES\n#define ONEFLOW_IR_INCLUDE_OKL_OKLATTRIBUTES\n\ninclude \"OneFlow/OKL/OKLDialect.td\"\ninclude \"mlir/IR/AttrTypeBase.td\"\ninclude \"mlir/IR/SymbolInterfaces.td\"\ninclude \"mlir/Interfaces/SideEffectInterfaces.td\"\ninclude \"mlir/Interfaces/InferTypeOpInterface.td\"\n\n\n#endif // ONEFLOW_IR_INCLUDE_OKL_OKLATTRIBUTES\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKL/OKLBase.td",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_OKL_OKLBASE\n#define ONEFLOW_IR_INCLUDE_OKL_OKLBASE\n\ninclude \"OneFlow/OKL/OKLDialect.td\"\ninclude \"mlir/IR/AttrTypeBase.td\"\ninclude \"mlir/IR/SymbolInterfaces.td\"\ninclude \"mlir/Interfaces/SideEffectInterfaces.td\"\ninclude \"mlir/Interfaces/InferTypeOpInterface.td\"\n\nclass OKL_Op<string name, list<Trait> traits = []> :\n    Op<OKL_Dialect, name, traits>;\n\nclass OKL_Type<string name, string typeMnemonic, list<Trait> traits = []>\n    : TypeDef<OKL_Dialect, name, traits> {\n  let mnemonic = typeMnemonic;\n}\n\n#endif // ONEFLOW_IR_INCLUDE_OKL_OKLBASE\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKL/OKLDialect.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_OKL_OKLDIALECT_H_\n#define ONEFLOW_IR_INCLUDE_OKL_OKLDIALECT_H_\n\n#include \"OneFlow/Passes.h\"\n#include \"mlir/IR/Dialect.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n\n#include \"OneFlow/OKLDialect.h.inc\"\n#include \"OneFlow/OKL/OKLOps.h\"\n\n#endif  // ONEFLOW_IR_INCLUDE_OKL_OKLDIALECT_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKL/OKLDialect.td",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_OKL_OKLDIALECT\n#define ONEFLOW_IR_INCLUDE_OKL_OKLDIALECT\n\ninclude \"mlir/IR/OpBase.td\"\n\ndef OKL_Dialect : Dialect {\n    let name = \"okl\";\n    let summary = \"OneFlow Kernel Launch Dialect.\";\n    let description = [{\n        This dialect is the IR of abstract represent of OneFlow Kernel Launch Op.\n    }];\n    let cppNamespace = \"::mlir::okl\";\n    let dependentDialects = [\n        \"func::FuncDialect\"\n    ];\n    let useDefaultTypePrinterParser = 1;\n}\n\n#endif // ONEFLOW_IR_INCLUDE_OKL_OKLDIALECT\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKL/OKLOps.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLOPS_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLOPS_H_\n\n#include \"mlir/IR/Dialect.h\"\n#include \"mlir/IR/OpDefinition.h\"\n#include \"mlir/IR/OpImplementation.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/FunctionInterfaces.h\"\n#include \"mlir/Interfaces/CallInterfaces.h\"\n#include \"mlir/Interfaces/SideEffectInterfaces.h\"\n#include \"mlir/Interfaces/ControlFlowInterfaces.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMDialect.h\"\n#include \"OneFlow/OKL/OKLTypes.h\"\n#include \"OneFlow/OKL/OKLAttributes.h\"\n\nnamespace mlir {\nnamespace func {\nclass FuncOp;\n}  // namespace func\n}  // namespace mlir\n\n#define GET_OP_CLASSES\n#include \"OneFlow/OKLOps.h.inc\"\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLOPS_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKL/OKLOps.td",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_OKL_OKLOPS\n#define ONEFLOW_IR_INCLUDE_OKL_OKLOPS\n\ninclude \"OneFlow/OKL/OKLDialect.td\"\ninclude \"OneFlow/OKL/OKLBase.td\"\ninclude \"OneFlow/OKL/OKLTypes.td\"\ninclude \"mlir/Pass/PassBase.td\"\ninclude \"mlir/Dialect/LLVMIR/LLVMOpBase.td\"\ninclude \"mlir/IR/OpBase.td\"\ninclude \"mlir/IR/EnumAttr.td\"\n\n\ndef GetTensorFromArgOp : OKL_Op<\"get_tensor_from_arg\"> {\n  let summary = \"get tensor as arguments from operands of context\";\n  let description = [{\n    get tensor value from launcher context as arguments.\n  }];\n\n  let arguments = (ins\n    LauncherContextType:$launcher_ctx,\n    I32Attr:$index\n  );\n  let results = (outs AnyTensor);\n}\n\ndef GetTensorFromRetOp : OKL_Op<\"get_tensor_from_ret\"> {\n  let summary = \"get tensor as arguments from results of context\";\n  let description = [{\n    get tensor value from launcher context as arguments.\n  }];\n\n  let arguments = (ins\n    LauncherContextType:$launcher_ctx,\n    I32Attr:$index\n  );\n  let results = (outs AnyTensor);\n}\n\ndef GetTensorAsRetOp : OKL_Op<\"get_tensor_as_ret\"> {\n  let summary = \"get tensor as outcomes from results of context\";\n  let description = [{\n    get tensor value from launcher context as outcomes.\n  }];\n\n  let arguments = (ins\n    LauncherContextType:$launcher_ctx,\n    AnyTensor:$tensor,\n    I32Attr:$index\n  );\n  let results = (outs AnyTensor);\n}\n\ndef PoolToTensorOp : OKL_Op<\"pool_to_tensor\"> {\n  let arguments = (ins\n    LauncherContextType:$launcher_ctx,\n    I64Attr:$offset\n  );\n  let results = (outs AnyTensor);\n}\n\ndef PoolToBufferOp : OKL_Op<\"pool_to_buffer\"> {\n  let arguments = (ins\n    LauncherContextType:$launcher_ctx,\n    I64Attr:$offset\n  );\n  let results = (outs AnyTensor);\n}\n\ndef TensorToPoolOp : OKL_Op<\"tensor_to_pool\"> {\n  let arguments = (ins\n    LauncherContextType:$launcher_ctx,\n    AnyTensor:$tensor,\n    I64Attr:$offset\n  );\n  let results = (outs AnyTensor);\n}\n\ndef WrapperKernelOp : OKL_Op<\"wrapper_kernel\"> {\n  let summary = \"build reg context operation\";\n  let description = [{\n    this context is generated from module op and used on kernel/run_ctx build phase.\n    each wrapped op has their own reg_ctx with their own attrs.\n  }];\n\n  let arguments = (ins\n    I32Attr:$index\n  );\n\n  let regions = (region AnyRegion:$body);\n}\n\ndef ReturnOp : OKL_Op<\"return\", [HasParent<\"WrapperKernelOp\">, Terminator]> {\n  let summary = \"return operation\";\n  let description = [{\n    return oneflow ops in reg context\n    ```\n  }];\n\n  let arguments = (ins Variadic<AnyType>:$operands);\n\n  let builders = [\n    OpBuilder<(ins),\n    [{ build($_builder, $_state, llvm::None); }]>];\n\n  let assemblyFormat = \"attr-dict ($operands^ `:` type($operands))?\";\n}\n\ndef LowerLauncherToLLVMPtrPass : Pass<\"lower-launcher-to-llvm-ptr\", \"ModuleOp\"> {\n  let summary = \"convert okl dialect 
func to llvm dialect\";\n  let constructor = \"mlir::okl::createLowerLauncherToLLVMPtrPass()\";\n}\n\ndef LowerOKLToLLVMCallPass : Pass<\"lower-okl-to-llvm-call\", \"ModuleOp\"> {\n  let summary = \"convert okl dialect ops to llvm dialect llvm.call\";\n  let constructor = \"mlir::okl::createLowerOKLToLLVMCallPass()\";\n}\n\ndef TagCudaGraphSupportPass : Pass<\"tag-cuda-graph-support\", \"ModuleOp\"> {\n  let summary = \"tag cuda graph support according to its wrapped ops\";\n  let constructor = \"mlir::okl::createTagCudaGraphSupportPass()\";\n}\n\n#endif // ONEFLOW_IR_INCLUDE_OKL_OKLOPS\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKL/OKLTypes.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLTYPES_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLTYPES_H_\n\n#include \"mlir/IR/Types.h\"\n\n#define GET_TYPEDEF_CLASSES\n#include \"OneFlow/OKLTypes.h.inc\"\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLTYPES_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKL/OKLTypes.td",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLTYPES\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLTYPES\n\ninclude \"OneFlow/OKL/OKLBase.td\"\ninclude \"mlir/IR/AttrTypeBase.td\"\n\ndef LauncherContextType : OKL_Type<\"LauncherContext\", \"launcher_ctx\">;\n\n#endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_OKLTYPES\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKL/passes.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_PASSES_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_PASSES_H_\n\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"OneFlow/OKL/Conversion/OKLToLLVM.h\"\n\nnamespace mlir {\n\nnamespace okl {\n\n#define GEN_PASS_CLASSES\n#define GEN_PASS_REGISTRATION\n#include \"OneFlow/OKLPasses.h.inc\"\n\n}  // namespace okl\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKL_PASSES_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKM/Conversion/Conversion.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_CONVERSION_CONVERSION_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_CONVERSION_CONVERSION_H_\n\n#include \"mlir/IR/BuiltinOps.h\"\n\nnamespace mlir {\nnamespace okm {\n\nLogicalResult LowerWrapOpsToOKL(ModuleOp module);\n\n}\n}  // namespace mlir\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_CONVERSION_CONVERSION_H_\n"
  },
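A minimal caller sketch for the entry point above; treating failure as fatal is an assumption about the caller's policy, not something the header prescribes.

```cpp
#include "llvm/Support/ErrorHandling.h"
#include "OneFlow/OKM/Conversion/Conversion.h"

// Rewrite okm wrap ops into okl form in place; on failure the module may be
// partially converted, so this sketch simply aborts.
void LowerOrAbort(mlir::ModuleOp module) {
  if (mlir::failed(mlir::okm::LowerWrapOpsToOKL(module))) {
    llvm::report_fatal_error("LowerWrapOpsToOKL failed");
  }
}
```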
  {
    "path": "oneflow/ir/include/OneFlow/OKM/OKMAttributes.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMATTRIBUTES_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMATTRIBUTES_H_\n\n#include \"mlir/IR/BuiltinAttributes.h\"\n\n#define GET_ATTRDEF_CLASSES\n#include \"OneFlow/OKMAttributes.h.inc\"\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMATTRIBUTES_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKM/OKMAttributes.td",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMATTRIBUTES\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMATTRIBUTES\n\ninclude \"OneFlow/OKM/OKMBase.td\"\n\n#endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMATTRIBUTES\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKM/OKMBase.td",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMBASE\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMBASE\n\ninclude \"OneFlow/OKM/OKMDialect.td\"\ninclude \"mlir/IR/AttrTypeBase.td\"\ninclude \"mlir/Pass/PassBase.td\"\n\nclass OKM_Op<string name, list<Trait> traits = []> :\n    Op<OKM_Dialect, name, traits>;\n\nclass OKM_Attr<string name, string attrMnemonic, list<Trait> traits = []>\n    : AttrDef<OKM_Dialect, name, traits> {\n  let mnemonic = attrMnemonic;\n}\n#endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMBASE\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKM/OKMDialect.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMDIALECT_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMDIALECT_H_\n\n#include \"mlir/IR/Dialect.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n\n#include \"OneFlow/OKMDialect.h.inc\"\n#include \"OneFlow/OKM/OKMOps.h\"\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMDIALECT_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKM/OKMDialect.td",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMDIALECT\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMDIALECT\n\ninclude \"mlir/IR/OpBase.td\"\n\ndef OKM_Dialect : Dialect {\n    let name = \"okm\";\n    let summary = \"OneFlow Kernel Memory Dialect.\";\n    let description = [{\n        This dialect is the IR of abstract represent of OneFlow Kernel Launch Op.\n    }];\n    let cppNamespace = \"::mlir::okm\";\n    let dependentDialects = [\n        \"func::FuncDialect\",\n        \"memref::MemRefDialect\"\n    ];\n}\n\n#endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMDIALECT\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKM/OKMOps.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMOPS_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMOPS_H_\n\n#include \"mlir/IR/Dialect.h\"\n#include \"mlir/IR/OpDefinition.h\"\n#include \"mlir/IR/OpImplementation.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"OneFlow/OKM/OKMAttributes.h\"\n\nnamespace mlir {\nnamespace func {\nclass FuncOp;\n}  // namespace func\n}  // namespace mlir\n\n#define GET_OP_CLASSES\n#include \"OneFlow/OKMOps.h.inc\"\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMOPS_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKM/OKMOps.td",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMOPS\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMOPS\n\ninclude \"OneFlow/OKM/OKMAttributes.td\"\ninclude \"OneFlow/OKM/OKMPasses.td\"\ninclude \"mlir/IR/OpBase.td\"\n\ndef ArgToTensorOp : OKM_Op<\"arg_to_tensor\"> {\n  let arguments = (ins\n    I32Attr:$index\n  );\n  let results = (outs AnyTensor);\n}\n\ndef ArgToMemrefOp : OKM_Op<\"arg_to_memref\"> {\n  let arguments = (ins\n    I32Attr:$index\n  );\n  let results = (outs AnyMemRef);\n}\n\ndef RetToMemrefOp : OKM_Op<\"ret_to_memref\"> {\n  let arguments = (ins\n    I32Attr:$index\n  );\n  let results = (outs AnyMemRef);\n}\n\ndef AllocMemrefOp : OKM_Op<\"alloc_memref\"> {\n  let results = (outs AnyMemRef);\n}\n\ndef PlanMemrefOp : OKM_Op<\"plan_memref\"> {\n  let results = (outs AnyMemRef);\n}\n\ndef TensorToRetOp : OKM_Op<\"tensor_to_ret\"> {\n  let arguments = (ins\n    AnyTensor:$tensor,\n    I32Attr:$index\n  );\n  let results = (outs AnyTensor);\n}\n\ndef MemrefToRetOp : OKM_Op<\"memref_to_ret\"> {\n  let arguments = (ins\n    AnyMemRef:$tensor,\n    I32Attr:$index\n  );\n  let results = (outs AnyMemRef);\n}\n\ndef WrapperOp : OKM_Op<\"wrapper_kernel\"> {\n  let arguments = (ins\n    Variadic<AnyType>:$operands\n  );\n  let results = (outs Variadic<AnyType>);\n  let regions = (region AnyRegion:$body);\n}\n\n\ndef ReturnOp : OKM_Op<\"return\", [HasParent<\"WrapperOp\">, Terminator]> {\n  let arguments = (ins Variadic<AnyType>:$operands);\n\n  let builders = [\n    OpBuilder<(ins),\n    [{ build($_builder, $_state, llvm::None); }]>];\n\n  let assemblyFormat = \"attr-dict ($operands^ `:` type($operands))?\";\n}\n\n#endif // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_OKMOPS\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKM/OKMPasses.td",
    "content": "#ifndef ONEFLOW_OKM_PASSES\n#define ONEFLOW_OKM_PASSES\n\ninclude \"OneFlow/OKM/OKMBase.td\"\n\ndef ExtractOKMTensorPass : Pass<\"extract-okm-tensor\", \"ModuleOp\"> {\n  let summary = \"extract okm tensors from args and rets\";\n  let constructor = \"mlir::okm::createExtractOKMTensorPass()\";\n}\n\ndef WrapOKMKernelPass : Pass<\"wrap-okm-kernel\", \"ModuleOp\"> {\n  let summary = \"wrap kernel in okm\";\n  let constructor = \"mlir::okm::createWrapOKMKernelPass()\";\n}\n\ndef OptOKMMemrefPass : Pass<\"opt-okm-memref\", \"ModuleOp\"> {\n  let summary = \"optimize okm memref\";\n  let constructor = \"mlir::okm::createOptOKMMemrefPass()\";\n}\n\ndef ConvertOKMToOKLPass : Pass<\"convert-okm-to-okl\", \"ModuleOp\"> {\n  let summary = \"convert okm to okl\";\n  let constructor = \"mlir::okm::createConvertOKMToOKLPass()\";\n}\n#endif // ONEFLOW_OKM_PASSES\n\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OKM/passes.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_PASSES_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_PASSES_H_\n\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\n\nnamespace okm {\n\nnamespace func_name {\n\nextern const std::string GRAPH_NAME;\nextern const std::string MEM_GRAPH_NAME;\nextern const std::string WRAP_GRAPH_NAME;\nextern const std::string OPT_GRAPH_NAME;\nextern const std::string OKL_GRAPH_NAME;\nextern const std::string OKL_POOL_SIZE_TAG;\n\n}  // namespace func_name\n\nstd::unique_ptr<mlir::Pass> createExtractOKMTensorPass();\nstd::unique_ptr<mlir::Pass> createWrapOKMKernelPass();\nstd::unique_ptr<mlir::Pass> createOptOKMMemrefPass();\nstd::unique_ptr<mlir::Pass> createConvertOKMToOKLPass();\n\n#define GEN_PASS_CLASSES\n#define GEN_PASS_REGISTRATION\n#include \"OneFlow/OKMPasses.h.inc\"\n\n}  // namespace okm\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_OKM_PASSES_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowBase.td",
    "content": "#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWBASE_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWBASE_H_\n\ninclude \"OneFlow/OneFlowDialect.td\"\ninclude \"OneFlow/OneFlowInterfaces.td\"\ninclude \"mlir/IR/SymbolInterfaces.td\"\ninclude \"mlir/Interfaces/SideEffectInterfaces.td\"\ninclude \"/mlir/Interfaces/InferTypeOpInterface.td\"\n\ndef OneFlow_InvalidElement: TypeDef<OneFlow_Dialect, \"InvalidElement\"> {\n  let mnemonic = \"invalid_element\";\n}\ndef OneFlow_CharElement: TypeDef<OneFlow_Dialect, \"CharElement\"> {\n  let mnemonic = \"char_element\";\n}\ndef OneFlow_TensorBufferElement: TypeDef<OneFlow_Dialect, \"TensorBufferElement\"> {\n  let mnemonic = \"tensor_buffer_element\";\n}\ndef OneFlow_OFRecordElement: TypeDef<OneFlow_Dialect, \"OFRecordElement\"> {\n  let mnemonic = \"of_record_element\";\n}\n\ndef OneFlow_OFRecordTensor : TensorOf<[OneFlow_OFRecordElement]>;\ndef OneFlow_TensorBufferTensor : TensorOf<[OneFlow_TensorBufferElement]>;\n\ndef OneFlow_Tensor : TensorOf<[AnyType]>;\n\ndef SI32ArrayAttr : TypedArrayAttrBase<SI32Attr, \"signed 32-bit integer array attribute\"> {}\n\ndef SI64ArrayAttr : TypedArrayAttrBase<SI64Attr, \"signed 64-bit integer array attribute\"> {}\n\ndef ShapeAttr : TypedArrayAttrBase<SI64Attr, \"\"> {}\n\ndef DTArrayAttr : TypedArrayAttrBase<OneFlow_DataType, \"\"> {}\n\ndef ShapeArrayAttr : TypedArrayAttrBase<ShapeAttr, \"\"> {}\n\ndef ComplexDoubleAttr : TypedArrayAttrBase<F64Attr, \"\"> {}\n\ndef BytesAttr : StringBasedAttr<CPred<\"$_self.isa<::mlir::StringAttr>()\">,\n                              \"bytes attribute\">;\n\ndef OneFlow_IsOpConfCompatible : NativeOpTrait<\"IsOpConfCompatible\">;\ndef OneFlow_IsImportCompatible : NativeOpTrait<\"IsImportCompatible\">;\ndef OneFlow_AlternativeOp : NativeOpTrait<\"IsAlternative\">;\ndef OneFlow_TensorSource : NativeOpTrait<\"TensorSource\">;\ndef OneFlow_OnlyExistsInIR : NativeOpTrait<\"OnlyExistsInIR\">;\ndef OneFlow_ElementwiseOp : NativeOpTrait<\"IsElementwise\">;\n\nclass OneFlow_IROp<string mnemonic, list<Trait> traits = []> :\n        Op<OneFlow_Dialect, mnemonic, !listconcat(traits, [OneFlow_OnlyExistsInIR])> {}\n\nclass OneFlow_BaseOp<string mnemonic, list<Trait> traits = []> :\n        Op<OneFlow_Dialect, mnemonic, !listconcat(traits, [OneFlow_IsOpConfCompatible])> {\n  dag op_conf_attrs = (ins\n    StrAttr:$op_name,\n    StrAttr:$device_tag,\n    StrArrayAttr:$device_name, // TODO: change device_name to dict and parse the literal fmt like \"0:0-0\"\n    OptionalAttr<I64Attr>:$scope_symbol_id,\n    OptionalAttr<I64ArrayAttr>:$hierarchy\n  );\n  dag attrs = (ins);\n  dag trait_attrs = (ins);\n  dag user_op_attrs = (ins);\n  dag input = (ins\n    Optional<OneFlow_Tensor>:$UserSourceOpTickInput\n  );\n  dag output = (outs);\n  dag ctrl_input = (ins);\n  dag ctrl_output = (outs);\n  let arguments = !con(\n      input,\n      ctrl_input,\n      op_conf_attrs,\n      trait_attrs,\n      user_op_attrs,\n      attrs\n  );\n  let results = !con(\n    output,\n    ctrl_output\n  );\n  int same_output_regst_num = -1;\n\n  bit has_check_fn = 0;\n  bit has_logical_tensor_desc_infer_fn = 0;\n  bit has_physical_tensor_desc_infer_fn = 0;\n  bit has_get_sbp_fn = 0;\n  bit has_sbp_signature_infer_fn = 0;\n  bit has_data_type_infer_fn = 0;\n  bit has_device_and_stream_infer_fn = 0;\n  bit has_input_arg_modify_fn = 0;\n  bit has_output_arg_modify_fn = 0;\n  bit has_output_blob_time_shape_infer_fn = 0;\n  bit has_nd_sbp_infer_fn = 0;\n  bit has_compute_complexity_fn = 0;\n  bit has_get_nd_sbp_fn = 
0;\n  bit has_enumerate_nd_sbp_signatures_fn = 0;\n  bit has_dump_nd_sbp_signature_for_op_conf_fn = 0;\n}\n\nclass OneFlow_Op<string mnemonic, list<Trait> traits = []> :\n        OneFlow_BaseOp<mnemonic, !listconcat(traits, [AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods<ControlEdgeCompatibleInterface>])> {\n  let ctrl_input = (ins Variadic<AnyType>:$ctrl_inputs);\n  let ctrl_output = (outs Optional<AnyType>:$ctrl_output);\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes,\n    DenseI32ArrayAttr:$result_segment_sizes\n  );\n}\n\nclass OneFlow_UserBaseOp<string mnemonic, list<Trait> traits = [OneFlow_AlternativeOp]> :\n        OneFlow_BaseOp<mnemonic, traits> {\n    let summary = \"\";\n    let user_op_attrs = (ins\n      StrAttr:$op_type_name,\n      // NOTE: vector types must have positive constant sizes, so we can't use I32ElementsAttr\n      I32ArrayAttr:$input_sizes,\n      I32ArrayAttr:$output_sizes\n    );\n}\n\n// Why don't we merge ctrl in/out and data in/out into operand_segment/result_segment_sizes?\n// 1. We only need to erase operand_segment/result_segment_sizes when we are creating a concrete user op\n// 2. Isolating data and ctrl makes debugging easier and the produced IR more human-readable\nclass OneFlow_UserBaseWithCtrlOp<string mnemonic, list<Trait> traits = []> :\n        OneFlow_UserBaseOp<mnemonic, !listconcat(traits, [AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods<ControlEdgeCompatibleInterface>])> {\n    let summary = \"\";\n    let ctrl_input = (ins Variadic<AnyType>:$ctrl_inputs);\n    let ctrl_output = (outs Optional<AnyType>:$ctrl_output);\n    let trait_attrs = (ins\n      DenseI32ArrayAttr:$operand_segment_sizes,\n      DenseI32ArrayAttr:$result_segment_sizes\n    );\n}\n\n\nclass OneFlow_ConvolutionBaseOp<string mnemonic, list<Trait> traits = []> :\n        OneFlow_BaseOp<mnemonic, !listconcat(traits, [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>])> {\n    let summary = \"OneFlow convolution operation\";\n    let description = [{\n    \"The convolution operator consumes an input tensor and a filter, and\"\n    \"computes the output.\"\n    }];\n    let input = (ins\n      OneFlow_Tensor:$in,\n      OneFlow_Tensor:$weight,\n      Optional<OneFlow_Tensor>:$bias,\n      Optional<OneFlow_Tensor>:$_add_to_output\n    );\n    let output = (outs OneFlow_Tensor:$out);\n    let attrs = (ins\n      DefaultValuedAttr<SI32Attr, \"0\">:$filters,\n      SI32ArrayAttr:$padding_before,\n      StrAttr:$data_format,\n      SI32ArrayAttr:$kernel_size,\n      SI32ArrayAttr:$strides,\n      SI32ArrayAttr:$dilation_rate,\n      DefaultValuedAttr<SI32Attr, \"1\">:$groups,\n      DefaultValuedAttr<StrAttr, \"\\\"\\\"\">:$tuning_cache\n    );\n    let trait_attrs = (ins\n      DenseI32ArrayAttr:$operand_segment_sizes\n    );\n    let has_check_fn = 1;\n    let has_logical_tensor_desc_infer_fn = 1;\n    let has_physical_tensor_desc_infer_fn = 1;\n    let has_get_sbp_fn = 1;\n    let has_data_type_infer_fn = 1;\n    let has_compute_complexity_fn = 1;\n}\n\nclass OneFlow_TFPoolBaseOp<string mnemonic, list<Trait> traits = []> :\n        OneFlow_BaseOp<mnemonic, !listconcat(traits, [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>])> {\n    let summary = \"OneFlow pooling operation, align with TensorFlow\";\n    let input = (ins AnyType:$x);\n    let output = (outs AnyType:$y);\n    let attrs = (ins\n    StrAttr:$padding,\n    SI32ArrayAttr:$padding_before,\n    
SI32ArrayAttr:$padding_after,\n    StrAttr:$data_format,\n    SI32ArrayAttr:$pool_size,\n    SI32ArrayAttr:$strides,\n    BoolAttr:$ceil_mode\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_compute_complexity_fn = 1;\n}\n\nclass OneFlow_TFPoolGradBaseOp<string mnemonic, list<Trait> traits = []> :\n        OneFlow_BaseOp<mnemonic, !listconcat(traits, [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>])> {\n  let summary = \"OneFlow pooling grad operation, align with TensorFlow\";\n  let input = (ins\n    AnyType:$x,\n    AnyType:$y,\n    AnyType:$dy\n  );\n  let output = (outs AnyType:$dx);\n  let attrs = (ins\n    StrAttr:$padding,\n    SI32ArrayAttr:$padding_before,\n    SI32ArrayAttr:$padding_after,\n    StrAttr:$data_format,\n    SI32ArrayAttr:$pool_size,\n    SI32ArrayAttr:$strides,\n    BoolAttr:$ceil_mode\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_compute_complexity_fn = 1;\n}\n\n\nclass OneFlow_MaxPoolBaseOp<string mnemonic, list<Trait> traits = []> :\n        OneFlow_BaseOp<mnemonic, !listconcat(traits, [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>])> {\n  let summary = \"OneFlow Max Pooling operation\";\n  let input = (ins\n    AnyType:$x\n  );\n  let output = (outs\n    AnyType:$y,\n    AnyType:$indice\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$padding,\n    StrAttr:$data_format,\n    SI32ArrayAttr:$kernel_size,\n    SI32ArrayAttr:$stride,\n    SI32ArrayAttr:$dilation,\n    DefaultValuedAttr<BoolAttr, \"false\">:$return_indices,\n    DefaultValuedAttr<BoolAttr, \"false\">:$ceil_mode\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_compute_complexity_fn = 1;\n}\n\nclass OneFlow_MaxUnpoolBaseOp<string mnemonic, list<Trait> traits = []> :\n        OneFlow_BaseOp<mnemonic, !listconcat(traits, [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>])> {\n  let summary = \"OneFlow Max Unpooling operation\";\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$indices\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    SI64ArrayAttr:$kernel_size,\n    SI64ArrayAttr:$stride,\n    SI64ArrayAttr:$padding,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_output_size,\n    ShapeAttr:$output_size\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\nclass OneFlow_AvgPoolBaseOp<string mnemonic, list<Trait> traits = []> :\n        OneFlow_BaseOp<mnemonic, !listconcat(traits, [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>])> {\n  let summary = \"OneFlow Average Pooling operation\";\n  let input = (ins\n    AnyType:$x\n  );\n  let output = (outs\n    AnyType:$y\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$padding,\n    StrAttr:$data_format,\n    SI32ArrayAttr:$kernel_size,\n    SI32ArrayAttr:$stride,\n    DefaultValuedAttr<BoolAttr, \"false\">:$ceil_mode,\n    DefaultValuedAttr<BoolAttr, \"false\">:$count_include_pad,\n    DefaultValuedAttr<SI32Attr, \"0\">:$divisor_override\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  
let has_data_type_infer_fn = 1;\n  let has_compute_complexity_fn = 1;\n}\n\nclass OneFlow_MaxPoolGradBaseOp<string mnemonic, list<Trait> traits = []> :\n        OneFlow_BaseOp<mnemonic, !listconcat(traits, [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>])> {\n  let summary = \"OneFlow Max Pooling Grad operation\";\n  let input = (ins\n    AnyType:$x,\n    AnyType:$indice,\n    AnyType:$dy\n  );\n  let output = (outs\n    AnyType:$dx\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$padding,\n    StrAttr:$data_format,\n    SI32ArrayAttr:$kernel_size,\n    SI32ArrayAttr:$stride,\n    SI32ArrayAttr:$dilation,\n    DefaultValuedAttr<BoolAttr, \"false\">:$return_indices,\n    DefaultValuedAttr<BoolAttr, \"false\">:$ceil_mode\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_compute_complexity_fn = 1;\n}\n\nclass OneFlow_AvgPoolGradBaseOp<string mnemonic, list<Trait> traits = []> :\n        OneFlow_BaseOp<mnemonic, !listconcat(traits, [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>])> {\n  let summary = \"OneFlow Average Pooling Grad operation\";\n  let input = (ins\n    AnyType:$x,\n    AnyType:$dy\n  );\n  let output = (outs\n    AnyType:$dx\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$padding,\n    StrAttr:$data_format,\n    SI32ArrayAttr:$kernel_size,\n    SI32ArrayAttr:$stride,\n    DefaultValuedAttr<BoolAttr, \"false\">:$ceil_mode,\n    DefaultValuedAttr<BoolAttr, \"false\">:$count_include_pad,\n    DefaultValuedAttr<SI32Attr, \"0\">:$divisor_override\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_compute_complexity_fn = 1;\n}\n\nclass OneFlow_MaxUnpoolGradBaseOp<string mnemonic, list<Trait> traits = []> :\n        OneFlow_BaseOp<mnemonic, !listconcat(traits, [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>])> {\n  let summary = \"OneFlow Max Unpooling Grad operation\";\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$indices,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\nclass OneFlow_AdaptivePoolBaseOp<string mnemonic, list<Trait> traits = []> :\n       OneFlow_BaseOp<mnemonic, !listconcat(traits, [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>])> {\n  let summary = \"OneFlow adaptive pool operation\";\n  let input = (ins\n    AnyType:$x\n  );\n  let output = (outs AnyType:$y);\n  let attrs = (ins\n    StrAttr:$data_format,\n    SI64ArrayAttr:$output_size\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\nclass OneFlow_AdaptivePoolGradBaseOp<string mnemonic, list<Trait> traits = []> :\n       OneFlow_BaseOp<mnemonic, !listconcat(traits, [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>])> {\n  let summary = \"OneFlow adaptive pool grad operation\";\n  let input = (ins\n    AnyType:$x,\n    AnyType:$dy\n  );\n  let output = (outs AnyType:$dx);\n  let attrs = (ins\n    StrAttr:$data_format,\n    SI64ArrayAttr:$output_size\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let 
has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\nclass OneFlow_UnaryBaseOp<string mnemonic, list<Trait> traits = []> :\n        OneFlow_BaseOp<mnemonic, !listconcat(traits, [SameOperandsAndResultType, NoMemoryEffect])> {\n  let summary = \"\";\n  let input = (ins AnyType:$x);\n  let output = (outs AnyType:$y);\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\nclass OneFlow_AdaptiveMaxPoolBaseOp<string mnemonic, list<Trait> traits = []> :\n       OneFlow_BaseOp<mnemonic, !listconcat(traits, [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>])> {\n  let summary = \"OneFlow adaptive max pool operation\";\n  let input = (ins\n    AnyType:$x\n  );\n  let output = (outs\n    AnyType:$y,\n    AnyType:$index\n  );\n  let attrs = (ins\n    StrAttr:$data_format,\n    SI64ArrayAttr:$output_size\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\nclass OneFlow_AdaptiveMaxPoolGradBaseOp<string mnemonic, list<Trait> traits = []> :\n       OneFlow_BaseOp<mnemonic, !listconcat(traits, [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>])> {\n  let summary = \"OneFlow adaptive max pool grad operation\";\n  let input = (ins\n    AnyType:$dy,\n    AnyType:$x,\n    AnyType:$index\n  );\n  let output = (outs AnyType:$dx);\n  let attrs = (ins\n    StrAttr:$data_format\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_Idempotent : NativeOpTrait<\"IsIdempotentOfIdenticalPlacement\">;\n\nclass OneFlow_IdempotentBaseOp<string mnemonic, list<Trait> traits = []> :\n        OneFlow_UnaryBaseOp<mnemonic, !listconcat(traits, [OneFlow_Idempotent])> {}\n\ndef OneFlow_Involution : NativeOpTrait<\"IsInvolutionOfIdenticalPlacement\">;\n\nclass OneFlow_InvolutionBaseOp<string mnemonic, list<Trait> traits = []> :\n        OneFlow_UnaryBaseOp<mnemonic, !listconcat(traits, [OneFlow_Involution])> {}\n\n#define GET_ONEFLOW_BASE_OP_DEFINITIONS\ninclude \"OneFlow/OneFlowUserOps.td\"\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWBASE_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowDataTypeConversion.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWDATATYPECONVERSION_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWDATATYPECONVERSION_H_\n\n#include \"mlir/IR/Builders.h\"\n#include \"OneFlow/OneFlowSupport.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nType getTypeFromOneFlowDataType(MLIRContext* context, ::oneflow::DataType dt);\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWDATATYPECONVERSION_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowDialect.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_ONEFLOWDIALECT_H\n#define ONEFLOW_ONEFLOWDIALECT_H\n\n#include \"mlir/IR/Dialect.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"OneFlow/SBP/SBPDialect.h\"\n\n#include \"OneFlow/OneFlowOpsDialect.h.inc\"\n\n#endif  // ONEFLOW_ONEFLOWDIALECT_H\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowDialect.td",
    "content": "#ifndef ONEFLOW_DIALECT\n#define ONEFLOW_DIALECT\n\ninclude \"mlir/IR/OpBase.td\"\n\ndef OneFlow_Dialect : Dialect {\n    let name = \"oneflow\";\n    let summary = \"OneFlow MLIR dialect.\";\n    let description = [{\n        This dialect is the IR of OneFlow.\n    }];\n    let cppNamespace = \"::mlir::oneflow\";\n    let dependentDialects = [\n        \"sbp::SBPDialect\",\n        \"func::FuncDialect\"\n    ];\n    let hasConstantMaterializer = 1;\n    let useDefaultTypePrinterParser = 1;\n}\n\n#endif // ONEFLOW_DIALECT\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowEnums.td",
    "content": "#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWENUMS_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWENUMS_H_\n\ninclude \"mlir/IR/OpBase.td\"\ninclude \"mlir/IR/EnumAttr.td\"\n\ndef OneFlow_InvalidDataType : I32EnumAttrCase<\"DT_InvalidDataType\", 0>;\ndef OneFlow_Char : I32EnumAttrCase<\"DT_Char\", 1>;\ndef OneFlow_Float : I32EnumAttrCase<\"DT_Float\", 2>;\ndef OneFlow_Double : I32EnumAttrCase<\"DT_Double\", 3>;\ndef OneFlow_Int8 : I32EnumAttrCase<\"DT_Int8\", 4>;\ndef OneFlow_Int32 : I32EnumAttrCase<\"DT_Int32\", 5>;\ndef OneFlow_Int64 : I32EnumAttrCase<\"DT_Int64\", 6>;\ndef OneFlow_UInt8 : I32EnumAttrCase<\"DT_UInt8\", 7>;\ndef OneFlow_OFRecord : I32EnumAttrCase<\"DT_OFRecord\", 8>;\ndef OneFlow_Float16 : I32EnumAttrCase<\"DT_Float16\", 9>;\ndef OneFlow_TensorBuffer: I32EnumAttrCase<\"DT_TensorBuffer\", 10>;\ndef OneFlow_BFloat16: I32EnumAttrCase<\"DT_BFloat16\", 11>;\ndef OneFlow_Bool: I32EnumAttrCase<\"DT_Bool\", 12>;\n\ndef OneFlow_DataType: I32EnumAttr<\"DataType\", \"OneFlow Data Type enum\",\n  [\n    OneFlow_InvalidDataType,\n    OneFlow_Char,\n    OneFlow_Float,\n    OneFlow_Double,\n    OneFlow_Int8,\n    OneFlow_Int32,\n    OneFlow_Int64,\n    OneFlow_UInt8,\n    OneFlow_OFRecord,\n    OneFlow_Float16,\n    OneFlow_TensorBuffer,\n    OneFlow_BFloat16,\n    OneFlow_Bool\n  ]\n> {\n  let cppNamespace = \"::mlir::oneflow\";\n  let stringToSymbolFnName = \"ConvertToEnum\";\n  let symbolToStringFnName = \"ConvertToString\";\n}\n\ndef OneFlow_Contiguous : I32EnumAttrCase<\"MF_Contiguous\", 0>;\ndef OneFlow_ChannelsLast : I32EnumAttrCase<\"MF_ChannelsLast\", 1>;\ndef OneFlow_Preserve : I32EnumAttrCase<\"MF_Preserve\", 2>;\ndef OneFlow_MemoryFormatCount : I32EnumAttrCase<\"MF_MemoryFormatCount\", 3>;\n\ndef OneFlow_MemoryFormat: I32EnumAttr<\"MemoryFormat\", \"OneFlow Memory Format enum\",\n  [\n    OneFlow_Contiguous,\n    OneFlow_ChannelsLast,\n    OneFlow_Preserve,\n    OneFlow_MemoryFormatCount\n  ]\n> {\n  let cppNamespace = \"::mlir::oneflow\";\n  let stringToSymbolFnName = \"ConvertToMemoryFormat\";\n  let symbolToStringFnName = \"ConvertMemoryFormatToString\";\n}\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWENUMS_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowInterfaces.td",
    "content": "#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWINTERFACES_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWINTERFACES_H_\n\ninclude \"mlir/IR/OpBase.td\"\n\ndef UserOpCompatibleInterface : OpInterface<\"UserOpCompatible\"> {\n  let description = [{\n    Interface to getting the hard-coded bn\n  }];\n\n  let methods = [\n    StaticInterfaceMethod<\"\",\n        \"const std::vector<std::string>*\", \"inputKeys\", (ins), [{\n        static std::vector<std::string> val(mlir::oneflow::support::GetInputKeys(ConcreteOp::getOperationName().split('.').second.str()));\n        return &val;\n    }]>,\n    StaticInterfaceMethod<\"\",\n        \"const std::vector<std::string>*\", \"outputKeys\", (ins), [{\n        static std::vector<std::string> val(mlir::oneflow::support::GetOutputKeys(ConcreteOp::getOperationName().split('.').second.str()));\n        return &val;\n    }]>,\n    InterfaceMethod<\"\",\n        \"std::pair<unsigned, unsigned>\", \"getODSOperandIndexAndLength\", (ins \"unsigned\":$index), [{\n        return $_op.getODSOperandIndexAndLength(index);\n    }]>,\n    InterfaceMethod<\"\",\n        \"std::pair<unsigned, unsigned>\", \"getODSResultIndexAndLength\", (ins \"unsigned\":$index), [{\n        return $_op.getODSResultIndexAndLength(index);\n    }]>\n  ];\n  let cppNamespace = \"::mlir::oneflow\";\n}\n\ndef AlternativeOpTypeNameInterface : OpInterface<\"HasAlternativeOpTypeName\"> {\n  let description = [{\n    Interface to getting control edges\n  }];\n\n  let methods = [\n    StaticInterfaceMethod<\"\",\n        \"std::string\", \"getOriginalOpTypeName\", (ins)\n    >,\n    StaticInterfaceMethod<\"\",\n        \"const std::vector<std::string>*\", \"inputKeys\", (ins), [{\n        static std::vector<std::string> val(mlir::oneflow::support::GetInputKeys(ConcreteOp::getOriginalOpTypeName()));\n        return &val;\n    }]>,\n    StaticInterfaceMethod<\"\",\n        \"const std::vector<std::string>*\", \"outputKeys\", (ins), [{\n        static std::vector<std::string> val(mlir::oneflow::support::GetOutputKeys(ConcreteOp::getOriginalOpTypeName()));\n        return &val;\n    }]>,\n  ];\n  let cppNamespace = \"::mlir::oneflow\";\n}\n\ndef ControlEdgeCompatibleInterface : OpInterface<\"ControlEdgeCompatible\"> {\n  let description = [{\n    Interface to getting control edges\n  }];\n\n  let methods = [\n    InterfaceMethod<\"\",\n        \"::mlir::OperandRange\", \"dataInputOperands\", (ins)\n    >,\n    InterfaceMethod<\"\",\n        \"::mlir::OperandRange\", \"ctrlInputOperands\", (ins)\n    >,\n    InterfaceMethod<\"\",\n        \"::mlir::ResultRange\", \"dataOutputResults\", (ins)\n    >,\n    InterfaceMethod<\"\",\n        \"::mlir::Value\", \"ctrlOutputResult\", (ins)\n    >\n  ];\n  let cppNamespace = \"::mlir::oneflow\";\n}\n\ndef NoGrad : OpInterface<\"NoGrad\"> {\n  let description = [{\n  }];\n  let cppNamespace = \"::mlir::oneflow\";\n}\n\ndef SupportNonContiguous : OpInterface<\"SupportNonContiguous\"> {\n  let description = [{\n  }];\n  let cppNamespace = \"::mlir::oneflow\";\n}\n\ndef CpuOnly : OpInterface<\"CpuOnly\"> {\n  let description = [{\n  }];\n  let cppNamespace = \"::mlir::oneflow\";\n}\n\ndef NCHWCompatibleInterface : OpInterface<\"NCHWCompatible\"> {\n  let description = [{\n    Interface of NCHW compatibility\n  }];\n\n  let methods = [\n    InterfaceMethod<\"\",\n        \"bool\", \"IsNCHW\", (ins)\n    >,\n    InterfaceMethod<\"Create NHWC op and return the new op's results to be transposed\",\n        \"llvm::SmallVector<mlir::Value, 4>\", 
\"NchwToNhwc\", (ins \"llvm::SmallVector<mlir::Value, 4>\": $transposed_inputs, \"PatternRewriter&\": $rewriter)\n    >,\n    InterfaceMethod<\"\",\n        \"llvm::DenseSet<mlir::Value>\", \"OperandsToTranspose\", (ins)\n    >,\n    InterfaceMethod<\"\",\n        \"llvm::DenseSet<mlir::Value>\", \"ResultsToTranspose\", (ins)\n    >,\n  ];\n  let cppNamespace = \"::mlir::oneflow\";\n}\n\ndef BiasAddCompatibleInterface : OpInterface<\"BiasAddCompatible\"> {\n  let description = [{\n    Interface of ops used as bias add\n  }];\n\n  let methods = [\n    InterfaceMethod<\"\",\n        \"bool\", \"isLastDim\", (ins)\n    >,\n    InterfaceMethod<\"\",\n        \"mlir::Value\", \"biasAddGetBias\", (ins)\n    >,\n    InterfaceMethod<\"\",\n        \"mlir::Value\", \"biasAddGetOut\", (ins)\n    >,\n  ];\n  let cppNamespace = \"::mlir::oneflow\";\n}\n\ndef MatMulCompatibleInterface : OpInterface<\"MatMulCompatible\"> {\n  let description = [{\n    Interface of ops used as matmul\n  }];\n\n  let methods = [\n    InterfaceMethod<\"is this a transpose_a=false, transpose_b=true matmul\",\n        \"bool\", \"isLinear\", (ins)\n    >,\n    InterfaceMethod<\"\",\n        \"mlir::Value\", \"matMulGetX\", (ins)\n    >,\n    InterfaceMethod<\"\",\n        \"mlir::Value\", \"matMulGetW\", (ins)\n    >,\n    InterfaceMethod<\"\",\n        \"mlir::Value\", \"matMulGetY\", (ins)\n    >,\n  ];\n  let cppNamespace = \"::mlir::oneflow\";\n}\n\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWINTERFACES_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowOpGetGen.td",
    "content": "include \"OneFlow/OneFlowDialect.td\"\ninclude \"OneFlow/OneFlowEnums.td\"\ninclude \"mlir/Interfaces/SideEffectInterfaces.td\"\ninclude \"OneFlow/OneFlowBase.td\"\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowOpTraits.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWOPTRAITS_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWOPTRAITS_H_\n\n#include \"mlir/IR/OpDefinition.h\"\n#include \"mlir/IR/Operation.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n\nnamespace mlir {\n\nnamespace OpTrait {\n\nnamespace impl {\n\nOpFoldResult foldIdempotentOfIdenticalPlacement(Operation* op);\nOpFoldResult foldInvolutionOfIdenticalPlacement(Operation* op);\nLogicalResult VerifyIsOpConfCompatible(Operation* op);\nLogicalResult VerifyIsImportCompatible(Operation* op);\n\n// trait IsOpConfCompatible\nLogicalResult saveAttrToOpConf(Operation* op, ::oneflow::OperatorConf* op_conf);\nLogicalResult saveAttrsToNamedAttrList(Operation* op, NamedAttrList& named_attr_list);\nStringAttr getOpName(Operation* op);\nStringAttr getDeviceTag(Operation* op);\nArrayAttr getDeviceName(Operation* op);\nIntegerAttr getScopeSymbolID(Operation* op);\nArrayAttr getHierarchy(Operation* op);\n\n}  // namespace impl\n\ntemplate<typename ConcreteType>\nclass IsOpConfCompatible : public TraitBase<ConcreteType, IsOpConfCompatible> {\n public:\n  static StringRef getOpNameAttr() { return \"op_name\"; }\n  static StringRef getDeviceTagAttr() { return \"device_tag\"; }\n  static StringRef getDeviceNameAttr() { return \"device_name\"; }\n  static StringRef getScopeSymbolIDAttr() { return \"scope_symbol_id\"; }\n  static StringRef getHierarchyAttr() { return \"hierarchy\"; }\n  static LogicalResult verifyTrait(Operation* op) { return impl::VerifyIsOpConfCompatible(op); }\n  static LogicalResult dump_attr(Operation* op, ::oneflow::OperatorConf* op_conf) {\n    return impl::saveAttrToOpConf(op, op_conf);\n  }\n  static LogicalResult saveToNamedAttrList(Operation* op, NamedAttrList& named_attr_list) {\n    return impl::saveAttrsToNamedAttrList(op, named_attr_list);\n  }\n  static StringAttr getOpName(Operation* op) { return impl::getOpName(op); }\n  static StringAttr getDeviceTag(Operation* op) { return impl::getDeviceTag(op); }\n  static ArrayAttr getDeviceName(Operation* op) { return impl::getDeviceName(op); }\n  static IntegerAttr getScopeSymbolID(Operation* op) { return impl::getScopeSymbolID(op); }\n  static ArrayAttr getHierarchy(Operation* op) { return impl::getHierarchy(op); }\n};\n\ntemplate<typename ConcreteType>\nclass IsImportCompatible : public TraitBase<ConcreteType, IsImportCompatible> {\n public:\n  static StringRef getOutputLBNsAttr() { return \"output_lbns\"; }\n  static LogicalResult verifyTrait(Operation* op) { return impl::VerifyIsImportCompatible(op); }\n};\n\ntemplate<typename ConcreteType>\nclass IsIdempotentOfIdenticalPlacement\n    : public TraitBase<ConcreteType, IsIdempotentOfIdenticalPlacement> {\n public:\n  static LogicalResult verifyTrait(Operation* op) {\n    static_assert(ConcreteType::template hasTrait<OneResult>(),\n                  \"expected operation to produce one result\");\n    
static_assert(ConcreteType::template hasTrait<OneOperand>(),\n                  \"expected operation to take one operand\");\n    static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),\n                  \"expected operation to preserve type\");\n    static_assert(ConcreteType::template hasTrait<OpTrait::IsOpConfCompatible>(),\n                  \"expected operation to be op conf compatible\");\n    return impl::verifyIsIdempotent(op);\n  }\n\n  static OpFoldResult foldTrait(Operation* op, ArrayRef<Attribute> operands) {\n    return impl::foldIdempotentOfIdenticalPlacement(op);\n  }\n};\n\ntemplate<typename ConcreteType>\nclass IsInvolutionOfIdenticalPlacement\n    : public TraitBase<ConcreteType, IsInvolutionOfIdenticalPlacement> {\n public:\n  static LogicalResult verifyTrait(Operation* op) {\n    static_assert(ConcreteType::template hasTrait<OneResult>(),\n                  \"expected operation to produce one result\");\n    static_assert(ConcreteType::template hasTrait<OneOperand>(),\n                  \"expected operation to take one operand\");\n    static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),\n                  \"expected operation to preserve type\");\n    static_assert(ConcreteType::template hasTrait<OpTrait::IsOpConfCompatible>(),\n                  \"expected operation to be op conf compatible\");\n    return impl::verifyIsInvolution(op);\n  }\n\n  static OpFoldResult foldTrait(Operation* op, ArrayRef<Attribute> operands) {\n    return impl::foldInvolutionOfIdenticalPlacement(op);\n  }\n};\n\ntemplate<typename ConcreteType>\nclass IsAlternative : public TraitBase<ConcreteType, IsAlternative> {\n public:\n  static StringRef getOpTypeNameAttr() { return \"op_type_name\"; }\n  static LogicalResult verifyTrait(Operation* op) {\n    if (op->hasAttrOfType<StringAttr>(getOpTypeNameAttr())) {\n      return success();\n    } else {\n      return op->emitError(\"expected operation to have attribute: \" + getOpTypeNameAttr());\n    }\n  }\n};\n\ntemplate<typename ConcreteType>\nclass TensorSource : public TraitBase<ConcreteType, TensorSource> {\n public:\n  static StringRef getShapeAttrName() { return \"shape\"; }\n  static StringRef getDataTypeAttrName() { return \"data_type\"; }\n  static StringRef getIsDynamicAttrName() { return \"is_dynamic\"; }\n  static StringRef getNdSbpAttrName() { return \"nd_sbp\"; }\n  static StringRef getSbpAttrName() { return \"parallel\"; }\n\n  static LogicalResult verifyTrait(Operation* op) {\n    if (!op->hasAttrOfType<ArrayAttr>(getShapeAttrName())) {\n      return op->emitError(\"expected operation to have attribute: \" + getShapeAttrName());\n    }\n    if (!op->hasAttrOfType<IntegerAttr>(getDataTypeAttrName())) {\n      return op->emitError(\"expected operation to have attribute: \" + getDataTypeAttrName());\n    }\n    return success();\n  }\n};\n\ntemplate<typename ConcreteType>\nclass OnlyExistsInIR : public TraitBase<ConcreteType, OnlyExistsInIR> {};\n\ntemplate<typename ConcreteType>\nclass IsElementwise : public TraitBase<ConcreteType, IsElementwise> {};\n\n}  // namespace OpTrait\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWOPTRAITS_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowOps.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWOPS_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWOPS_H_\n\n#include \"mlir/IR/Dialect.h\"\n#include \"mlir/IR/OpDefinition.h\"\n#include \"mlir/IR/OpImplementation.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/FunctionInterfaces.h\"\n#include \"mlir/Interfaces/CallInterfaces.h\"\n#include \"mlir/Interfaces/SideEffectInterfaces.h\"\n#include \"mlir/Interfaces/ControlFlowInterfaces.h\"\n#include \"mlir/Interfaces/InferTypeOpInterface.h\"\n#include \"mlir/IR/PatternMatch.h\"\n#include \"OneFlow/OneFlowSupport.h\"\n#include \"OneFlow/OneFlowInterfaces.h.inc\"\n#include \"OneFlow/OneFlowOpTraits.h\"\n#include \"OneFlow/SBP/SBPAttributes.h\"\n\nnamespace mlir {\n\nnamespace func {\nclass FuncOp;\n}  // namespace func\n\n}  // namespace mlir\n\n#define GET_OP_CLASSES\n#include \"OneFlow/OneFlowOps.h.inc\"\n#define GET_OP_CLASSES\n#include \"OneFlow/OneFlow.gen_ops.h.inc\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\ntemplate<typename T>\ninline std::string GetOpTypeName(T op) {\n  std::string op_type_name = op->getName().stripDialect().str();\n  if (op->template hasTrait<OpTrait::IsAlternative>()) {\n    op_type_name =\n        op->template getAttrOfType<StringAttr>(OpTrait::IsAlternative<void>::getOpTypeNameAttr())\n            .str();\n  }\n  if (auto alternative_name = dyn_cast<oneflow::HasAlternativeOpTypeName>(op)) {\n    op_type_name = alternative_name.getOriginalOpTypeName();\n  }\n  if (auto user_op = dyn_cast<oneflow::UserOp>(op)) {\n    op_type_name = user_op.getOpTypeName().str();\n  }\n  return op_type_name;\n}\nResultRange GetDataOutputResults(Operation* op);\nOperandRange GetDataInputOperands(Operation* op);\nllvm::Optional<OperandRange> GetCtrlIntputOperands(Operation* op);\nllvm::Optional<OpResult> GetCtrlOutputResult(Operation* op);\n\nArrayAttr getSI32ArrayAttr(::mlir::PatternRewriter& rewriter, ArrayRef<int32_t> values);\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWOPS_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowOps.td",
    "content": "#ifndef ONEFLOW_OPS\n#define ONEFLOW_OPS\n\ninclude \"OneFlow/OneFlowDialect.td\"\ninclude \"OneFlow/OneFlowEnums.td\"\ninclude \"OneFlow/OneFlowInterfaces.td\"\ninclude \"OneFlow/OneFlowBase.td\"\n\ninclude \"mlir/Interfaces/SideEffectInterfaces.td\"\ninclude \"mlir/IR/FunctionInterfaces.td\"\ninclude \"mlir/Interfaces/CallInterfaces.td\"\ninclude \"mlir/Interfaces/ControlFlowInterfaces.td\"\ninclude \"mlir/Pass/PassBase.td\"\n\ninclude \"mlir/IR/AttrTypeBase.td\"\ninclude \"mlir/IR/OpBase.td\"\n\ninclude \"OneFlow/SBP/SBPOps.td\"\n\n\n#ifndef REMOVE_ONEFLOW_MLIR_ONLY_OP_DEFINITIONS\n\ndef OneFlow_UserOp : OneFlow_UserBaseWithCtrlOp<\"user\", [OneFlow_IsImportCompatible]> {\n  let summary = \"\";\n  let input = (ins Variadic<AnyType>:$data_input);\n  let output = (outs Variadic<AnyType>:$data_output);\n  let attrs = (ins\n    StrArrayAttr:$output_lbns\n  );\n  let hasCanonicalizer = 1;\n}\n\ndef OneFlow_ConfOp : OneFlow_BaseOp<\"conf\", [OneFlow_IsImportCompatible]> {\n  let summary = \"This op is mainly used by create its adaptor in importing/exporting\";\n}\n\ndef OneFlow_SystemOp : OneFlow_Op<\"system\", [OneFlow_IsImportCompatible]> {\n  let summary = \"\";\n  let input = (ins Variadic<AnyType>:$data_input);\n  let output = (outs Variadic<AnyType>:$data_output);\n  let attrs = (ins\n    StrArrayAttr:$input_bns,\n    StrArrayAttr:$output_lbns,\n    I32Attr:$op_type_case\n  );\n  let hasCanonicalizer = 1;\n}\n\ndef F32ElementsAttr : FloatElementsAttr<32>;\n\ndef OneFlow_FrozenVariableOp : OneFlow_IROp<\"variable_ir\", [ConstantLike, NoMemoryEffect]> {\n  let summary = \"Auxiliary variable op for constant folding, only exists in IR.\";\n  let arguments = (ins\n    F32ElementsAttr:$value,\n    StrAttr:$op_name,\n    OptionalAttr<OneFlow_DataType>:$data_type,\n    StrAttr:$device_tag,\n    StrArrayAttr:$device_name, // TODO: change device_name to dict and parse the literal fmt like \"0:0-0\"\n    OptionalAttr<I64Attr>:$scope_symbol_id,\n    OptionalAttr<I64ArrayAttr>:$hierarchy,\n    StrArrayAttr:$nd_sbp\n  );\n  let results = (outs\n    AnyType:$output\n  );\n  let hasFolder = 1;\n}\n\ndef OneFlow_Add2Op : OneFlow_BaseOp<\"add_n2\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<AlternativeOpTypeNameInterface>, DeclareOpInterfaceMethods<NCHWCompatibleInterface>]> {\n  let summary = \"\";\n  let input = (ins\n    AnyType:$in0,\n    AnyType:$in1\n  );\n  let output = (outs AnyType:$out);\n}\n\n\nclass OneFlow_ConcreteSystemOp<string mnemonic, list<Trait> traits = []> :\n        OneFlow_BaseOp<mnemonic, !listconcat(traits, [OneFlow_IsOpConfCompatible,\n        OneFlow_IsImportCompatible,\n        DeclareOpInterfaceMethods<ControlEdgeCompatibleInterface>])> {\n  let input = (ins);\n  let output = (ins);\n  let ctrl_input = (ins Variadic<AnyType>:$ctrl_inputs);\n  let ctrl_output = (outs Optional<AnyType>:$ctrl_output);\n  dag required_attrs = (ins StrArrayAttr:$output_lbns);\n  dag custom_attrs = (ins);\n  let attrs = !con(\n    required_attrs,\n    custom_attrs\n  );\n  let hasCanonicalizer = 1;\n}\n\ndef OneFlow_VariableOp : OneFlow_ConcreteSystemOp<\"variable\", [OneFlow_TensorSource]> {\n  let summary = \"\";\n  let input = (ins);\n  let output = (outs AnyType:$output);\n  let custom_attrs = (ins\n    ShapeAttr:$shape,\n    OptionalAttr<OneFlow_DataType>:$data_type,\n    OptionalAttr<StrAttr>:$model_name,\n    OptionalAttr<F32Attr>:$l1_regularization,\n    OptionalAttr<F32Attr>:$l2_regularization,\n    
OptionalAttr<BoolAttr>:$trainable,\n    OptionalAttr<F32Attr>:$float_initializer,\n    OptionalAttr<SI64Attr>:$integer_initializer,\n    OptionalAttr<SBP_ParallelSignatureAttr>:$parallel\n  );\n}\n\ndef OneFlow_InputOp : OneFlow_ConcreteSystemOp<\"input\", [OneFlow_TensorSource]> {\n  let summary = \"\";\n  let input = (ins AnyType:$input);\n  let output = (outs AnyType:$output);\n  let custom_attrs = (ins\n    OptionalAttr<ShapeAttr>:$shape,\n    OptionalAttr<OneFlow_DataType>:$data_type,\n    OptionalAttr<BoolAttr>:$is_dynamic,\n    OptionalAttr<StrArrayAttr>:$nd_sbp,\n    OptionalAttr<StrAttr>:$job_name\n  );\n  let builders = [\n    OpBuilder<(ins\n      \"::oneflow::OperatorConf\":$op_conf\n    )>\n  ];\n}\n\ndef OneFlow_OutputOp : OneFlow_ConcreteSystemOp<\"output\", [OneFlow_TensorSource]> {\n  let summary = \"\";\n  let input = (ins AnyType:$input);\n  let output = (outs AnyType:$output);\n  let custom_attrs = (ins\n    OptionalAttr<ShapeAttr>:$shape,\n    OptionalAttr<OneFlow_DataType>:$data_type,\n    OptionalAttr<BoolAttr>:$is_dynamic,\n    OptionalAttr<StrArrayAttr>:$nd_sbp,\n    OptionalAttr<StrAttr>:$job_name\n  );\n}\n\ndef OneFlow_Job : Op<OneFlow_Dialect, \"job\", [FunctionOpInterface, IsolatedFromAbove, Symbol]>  {\n  let regions = (region AnyRegion:$body);\n\n  let arguments = (ins\n    SymbolNameAttr:$sym_name,\n    TypeAttrOf<FunctionType>:$function_type,\n    OptionalAttr<StrAttr>:$sym_visibility,\n    OptionalAttr<DictArrayAttr>:$arg_attrs,\n    OptionalAttr<DictArrayAttr>:$res_attrs\n  );\n\n  let builders = [OpBuilder<(ins\n    \"StringRef\":$sym_name, \"FunctionType\":$function_type,\n      CArg<\"ArrayRef<NamedAttribute>\", \"{}\">:$attrs\n  )>];\n\n  let extraClassDeclaration = [{\n    bool isDeclaration() { return isExternal(); }\n\n    ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }\n\n    ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }\n\n    LogicalResult verifyType() {\n      auto type = getFunctionTypeAttr().getValue();\n      if (!type.isa<FunctionType>())\n        return emitOpError(\"requires '\" + getFunctionTypeAttrName().str() +\n                           \"' attribute of function type\");\n      return success();\n    }\n  }];\n\n  let hasCustomAssemblyFormat = 1;\n  let hasVerifier = 1;\n}\n\ndef OneFlow_ReturnOp : Op<OneFlow_Dialect, \"return\", [NoMemoryEffect, HasParent<\"Job\">,\n                                MemRefsNormalizable, ReturnLike, Terminator]> {\n  let summary = \"return operation\";\n  let description = [{\n    The \"return\" operation represents a return operation within a Job.\n    The operation takes an optional tensor operand and produces no results.\n    The operand type must match the signature of the job function that contains\n    the operation. 
For example:\n\n    ```mlir\n      job @foo() -> tensor<2xf64> {\n        ...\n        oneflow.return %0 : tensor<2xf64>\n      }\n    ```\n  }];\n\n  let arguments = (ins Variadic<AnyType>:$operands);\n\n  let builders = [\n    OpBuilder<(ins),\n    [{ build($_builder, $_state, llvm::None); }]>];\n\n  let assemblyFormat = \"attr-dict ($operands^ `:` type($operands))?\";\n\n  let hasVerifier = 1;\n}\n\ndef OneFlow_NormalizationInferenceOp : OneFlow_NormalizationBaseOp<\"normalization_infer\", [DeclareOpInterfaceMethods<AlternativeOpTypeNameInterface>, DeclareOpInterfaceMethods<NCHWCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n}\n\n#endif // REMOVE_ONEFLOW_MLIR_ONLY_OP_DEFINITIONS\n\n#endif // ONEFLOW_OPS\n\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowPDLLPatterns.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWPDLLPATTERNS_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWPDLLPATTERNS_H_\n#include \"mlir/IR/PatternMatch.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nvoid populateAllocEliminationPatterns(RewritePatternSet& patterns);\nvoid populateForwardOpPatterns(RewritePatternSet& patterns);\nvoid populateNormalizationOpPatterns(RewritePatternSet& patterns);\nvoid populateFuseConv2DBatchNormPattern(RewritePatternSet& patterns);\nvoid populateFuseOpsWithBackwardImplPattern(RewritePatternSet& patterns);\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWPDLLPATTERNS_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowPasses.td",
    "content": "#ifndef ONEFLOW_PASSES\n#define ONEFLOW_PASSES\n\ninclude \"OneFlow/OneFlowOps.td\"\n\n#ifdef WITH_MLIR_CUDA_CODEGEN\ndef NVVMToCubinPass : InterfacePass<\"nvvm-to-cubin\", \"SymbolOpInterface\"> {\n  let summary = \"convert nvvm ir to cubin\";\n  let constructor = \"mlir::oneflow::createNVVMToCubinPass()\";\n  let options = [\n    Option<\"triple\", \"triple\", \"StringRef\", \"\\\"nvptx64-nvidia-cuda\\\"\", \"Target triple\">,\n    Option<\"chip\", \"chip\", \"StringRef\", \"mlir::oneflow::getArchVersion()\", \"Target architecture\">,\n    Option<\"features\", \"features\", \"StringRef\", \"\\\"+ptx60\\\"\", \"Target features\">,\n  ];\n}\n#endif // WITH_MLIR_CUDA_CODEGEN\n\ndef TestOneFlowTraitFolderPass : Pass<\"test-oneflow-trait-folder\", \"func::FuncOp\"> {\n  let constructor = \"mlir::oneflow::createTestOneFlowTraitFolderPass()\";\n}\n\ndef LowerOneFlowToTosaPass : Pass<\"lower-oneflow-to-tosa\", \"ModuleOp\"> {\n  let summary = \"lower oneflow dialect to tosa dialect\";\n  let constructor = \"mlir::oneflow::createLowerOneFlowToTosaPass()\";\n  let dependentDialects = [\"tosa::TosaDialect\", \"memref::MemRefDialect\", \"mlir::func::FuncDialect\"];\n  let options = [\n    Option<\"variableAsConstant\", \"variable-as-constant\", \"int\", \"0\",\n           \"convert variable op as const op of tosa\">,\n    Option<\"fullyConvert\", \"full\", \"bool\",\n           /*default=*/\"true\", \"Fully convert operations and make OneFlow dialect illegal target\">,\n    Option<\"lowerJob\", \"lower-job\", \"bool\",\n           /*default=*/\"true\", \"Convert oneflow.job to func.func\">,\n  ];\n}\n\ndef LowerOneFlowToLinalgPass : Pass<\"lower-oneflow-to-linalg\", \"ModuleOp\"> {\n  let summary = \"lower oneflow dialect to Linalg dialect\";\n  let constructor = \"mlir::oneflow::createLowerOneFlowToLinalgPass()\";\n  let dependentDialects = [\"linalg::LinalgDialect\", \"memref::MemRefDialect\", \"mlir::func::FuncDialect\"];\n}\n\ndef FuncToOneFlowJobPass : Pass<\"func-to-ofjob\", \"ModuleOp\"> {\n  let summary = \"convert func Ops to oneflow Ops\";\n  let constructor = \"mlir::oneflow::createFuncToOneFlowJobPass()\";\n  let dependentDialects = [\"mlir::func::FuncDialect\"];\n}\n\ndef OneFlowJobToFuncPass : Pass<\"ofjob-to-func\", \"ModuleOp\"> {\n  let summary = \"convert oneflow Ops to func Ops\";\n  let constructor = \"mlir::oneflow::createOneFlowJobToFuncPass()\";\n  let dependentDialects = [\"mlir::func::FuncDialect\"];\n}\n\ndef CastOneFlowOpsToSignlessPass : Pass<\"cast-ofops-to-signless\", \"ModuleOp\"> {\n  let summary = \"cast oneflow ops to singless\";\n  let constructor = \"mlir::oneflow::createCastOneFlowOpsToSignlessPass()\";\n  let dependentDialects = [\"mlir::func::FuncDialect\", \"mlir::BuiltinDialect\"];\n}\n\ndef BufferHostRegisterPass : Pass<\"buffer-host-register\", \"func::FuncOp\"> {\n  let summary = \"\";\n  let constructor = \"mlir::oneflow::createBufferHostRegisterPass()\";\n  let dependentDialects = [\"gpu::GPUDialect\"];\n}\n\ndef GpuCopyArgPass : Pass<\"gpu-copy-arg\", \"func::FuncOp\"> {\n  let summary = \"\";\n  let constructor = \"mlir::oneflow::createGpuCopyArgPass()\";\n  let dependentDialects = [\"memref::MemRefDialect\", \"gpu::GPUDialect\"];\n}\n\ndef OutlineJitFunctionPass : InterfacePass<\"outline-jit-function\", \"FunctionOpInterface\"> {\n  let summary = \"move ops could be jitted to jit function\";\n  let constructor = \"mlir::oneflow::createOutlineJitFunctionPass()\";\n  let dependentDialects = [\"pdl_interp::PDLInterpDialect\", 
\"pdl::PDLDialect\", \"LLVM::LLVMDialect\"];\n  let options = [\n    Option<\"compileToLLVM\", \"compile-to-llvm\", \"bool\",\n           /*default=*/\"true\", \"Convert to llvm dialect in this pass\">,\n  ];\n}\n\ndef AggregateComputeOpsPass : Pass<\"aggregate-compute-ops\", \"ModuleOp\"> {\n  let summary = \"aggregate compute ops together\";\n  let constructor = \"mlir::oneflow::createAggregateComputeOpsPass()\";\n}\n\ndef WrapOpsToKernelLaunchPass : Pass<\"wrap-ops-to-kernel-launch\", \"ModuleOp\"> {\n  let summary = \"wrap user ops with a single kernel launch op in OneFlow Job\";\n  let constructor = \"mlir::oneflow::createWrapOpsToKernelLaunchPass()\";\n}\n\ndef EliminateAllocOpsPass : Pass<\"eliminate-alloc-ops\", \"ModuleOp\"> {\n  let summary = \"eliminate memref.alloc and memref.copy which target is a block argument\";\n  let constructor = \"mlir::oneflow::createEliminateAllocOpsPass()\";\n  let dependentDialects = [\"pdl_interp::PDLInterpDialect\", \"pdl::PDLDialect\"];\n}\n\ndef AppendOneFlowStreamPass : Pass<\"append-ofstream\", \"ModuleOp\"> {\n  let summary = \"append oneflow stream to gpu function arguments\";\n  let constructor = \"mlir::oneflow::createAppendOneFlowStreamPass()\";\n}\n\ndef MgpuToOneFlowStreamPass : Pass<\"mgpu-to-ofstream\", \"ModuleOp\"> {\n  let summary = \"convert mlir abi about mgpu to oneflow stream, this pass should be invoked after append-ofstream pass\";\n  let constructor = \"mlir::oneflow::createMgpuToOneFlowStreamPass()\";\n}\n\ndef FuseIntoExistingOpPass : Pass<\"fuse-into-existing-op\", \"ModuleOp\"> {\n  let summary = \"\";\n  let constructor = \"mlir::oneflow::createFuseIntoExistingOpPass()\";\n  let dependentDialects = [\"pdl_interp::PDLInterpDialect\", \"pdl::PDLDialect\"];\n}\n\ndef InsertOneFlowMemPoolPass : Pass<\"insert-ofmempool\", \"ModuleOp\"> {\n  let summary = \"insert oneflow tmp buffer as a memory pool in mlir codegen\";\n  let constructor = \"mlir::oneflow::createInsertOneFlowMemPoolPass()\";\n}\n\ndef FoldAllocToSubviewPass : Pass<\"fold-alloc-to-subview\", \"func::FuncOp\"> {\n  let summary = \"fold dispersed memref.alloc ops with memory optimize algo to a single memref.alloc op and memref.subview ops\";\n  let constructor = \"mlir::oneflow::createFoldAllocToSubviewPass()\";\n}\n\n\ndef AutoNhwcPass : Pass<\"auto-nhwc\", \"ModuleOp\"> {\n  let summary = \"\";\n  let constructor = \"mlir::oneflow::createAutoNhwcPass()\";\n}\n\ndef PreConvertInferenceOpPass : Pass<\"pre-convert-inference-op\", \"ModuleOp\"> {\n  let summary = \"Convert variable op to variable ir op for constant folding.\";\n  let constructor = \"mlir::oneflow::createPreConvertInferenceOpPass()\";\n}\n\ndef ConvertInferenceOpPass : Pass<\"convert-inference-op\", \"ModuleOp\"> {\n  let summary = \"Convert ops to their inference version and rewrite them with a more performant equivalent in inference workflow.\";\n  let constructor = \"mlir::oneflow::createConvertInferenceOpPass()\";\n  let dependentDialects = [\"pdl_interp::PDLInterpDialect\", \"pdl::PDLDialect\"];\n}\n\ndef PostConvertInferenceOpPass : Pass<\"post-convert-inference-op\", \"ModuleOp\"> {\n  let summary = \"Convert variable ir op to variable op after contant folding.\";\n  let constructor = \"mlir::oneflow::createPostConvertInferenceOpPass()\";\n}\n\n\ndef ConvertToSignlessForTosaPass : Pass<\"convert-to-signless-for-tosa\", \"ModuleOp\"> {\n  let summary = \"convert func type to unsigned before lowering to tosa\";\n  let description = [{\n    In oneflow, int typed tensor is explicit signed. 
Convert them before lowering to TOSA.\n  }];\n  let constructor = \"mlir::oneflow::createConvertToSignlessForTosaPass()\";\n  let dependentDialects = [\"func::FuncDialect\"];\n}\n\ndef CSEWithAttributesIgnored : Pass<\"cse-with-attributes-ignored\", \"ModuleOp\"> {\n  let summary = \"ignore oneflow attributes so that cse can work\";\n  let description = [{\n    cse and ignore oneflow attributes like op name, symbol id, etc.\n  }];\n  let constructor = \"mlir::oneflow::createCSEWithAttributesIgnored()\";\n  let dependentDialects = [];\n}\n\ndef CSEPutAttributes : Pass<\"cse-put-attributes\", \"ModuleOp\"> {\n  let summary = \"put back oneflow attributes after cse\";\n  let description = [{\n    put back oneflow attributes like op name, symbol id, etc.\n  }];\n  let constructor = \"mlir::oneflow::createCSEPutAttributes()\";\n  let dependentDialects = [];\n}\n\ndef GroupMatMul : Pass<\"group-matmul\", \"ModuleOp\"> {\n  let summary = \"group matmul ops together\";\n  let description = [{\n    group matmul ops together and use cudnn batched matmul\n  }];\n  let constructor = \"mlir::oneflow::createGroupMatMul()\";\n  let dependentDialects = [];\n}\n\ndef FuseForwardOps : Pass<\"fuse-forward-only-ops\", \"ModuleOp\"> {\n  let summary = \"fuse forward ops\";\n  let description = [{\n    fuse forward ops. Usually they are ops that act on the output of another op.\n  }];\n  let constructor = \"mlir::oneflow::createFuseForwardOps()\";\n  let dependentDialects = [];\n}\n\ndef FuseOpsWithBackwardImpl : Pass<\"fuse-ops-with-backward-impl\", \"ModuleOp\"> {\n  let summary = \"fuse ops with backward impl\";\n  let description = [{\n    fuse ops with backward impl.\n  }];\n  let constructor = \"mlir::oneflow::createFuseOpsWithBackwardImpl()\";\n  let dependentDialects = [\"pdl_interp::PDLInterpDialect\", \"pdl::PDLDialect\"];\n}\n\ndef FuseNormalizationOps : Pass<\"fuse-normalization-ops\", \"ModuleOp\"> {\n  let summary = \"fuse normalization ops\";\n  let description = [{\n    fuse normalization ops.\n  }];\n  let constructor = \"mlir::oneflow::createFuseNormalizationOps()\";\n  let dependentDialects = [\"pdl_interp::PDLInterpDialect\", \"pdl::PDLDialect\"];\n}\n\n#endif // ONEFLOW_PASSES\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowPatternUtils.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/PatternMatch.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace rewrites {\n\nmlir::IntegerAttr GetDefaultSeed(::mlir::PatternRewriter& rewriter);\nvoid populateRewrites(RewritePatternSet& patterns);\n\n}  // namespace rewrites\n\nnamespace constraints {\n\nvoid populateConstraints(RewritePatternSet& patterns);\n\n}  // namespace constraints\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowPatterns.td",
    "content": "\n#ifndef ONEFLOW_PATTERNS\n#define ONEFLOW_PATTERNS\n\ninclude \"mlir/IR/PatternBase.td\"\ninclude \"OneFlow/OneFlowOps.td\"\ninclude \"mlir/Dialect/MemRef/IR/MemRefOps.td\"\ninclude \"mlir/Dialect/GPU/IR/GPUOps.td\"\n\ndef GetFirstValue :\n  NativeCodeCall<\"*$0.begin()\">;\n\n\ndef IsAddToOutputNone: Constraint<CPred<\"mlir::oneflow::IsAddToOutputNone($0)\">, \"\">;\n\ndef IsTraingTrue: Constraint<CPred<\"$0.getValue()\">, \"\">;\n\ndef IsArg: Constraint<CPred<\"$0.dyn_cast<::mlir::BlockArgument>()\">, \"\">;\ndef getResultTypes : NativeCodeCall<\"$0.getResultTypes()\">;\ndef CreateGPUMemcpyOpFromMemrefCopy : NativeCodeCall<\"::mlir::oneflow::CreateGPUMemcpyOpFromMemrefCopy($_builder, $0)\">;\ndef ReplaceCopyWithGPUPattern : Pat<\n  (\n    CopyOp:$results\n    $src,\n    $dst\n  ),\n  (\n    CreateGPUMemcpyOpFromMemrefCopy $results\n  ),\n  [(IsArg $dst)]\n>;\n\n#endif // ONEFLOW_PATTERNS\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowSupport.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWSUPPORT_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWSUPPORT_H_\n\n#include <string>\n#include <vector>\n\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/MLIRContext.h\"\n#include \"OneFlow/OneFlowEnums.h.inc\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/framework/tensor.h\"\n// This include is not necessary now, but it is here for testing the namespace collision\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace support {\n\nconst ::oneflow::UserOpDef& getUserOpDef(const std::string& op_type_name);\nstatic const std::vector<std::string>* inputKeys() {\n  static std::vector<std::string> val({\"in\"});\n  return &val;\n}\n\nstd::vector<std::string> GetInputKeys(const std::string& op_type_name);\n\nstd::vector<std::string> GetOutputKeys(const std::string& op_type_name);\n\nmlir::DenseElementsAttr TensorToDenseElementsAttr(\n    const std::shared_ptr<::oneflow::one::Tensor>& tensor, MLIRContext* ctx);\n\nstd::shared_ptr<::oneflow::one::Tensor> DenseElementsAttrToTensor(\n    const mlir::Attribute& attr, const mlir::Attribute& device_tag,\n    const mlir::Attribute& device_name);\nvoid DenseElementsAttrToTensor(const mlir::Attribute& attr, const mlir::Attribute& device_tag,\n                               const mlir::Attribute& device_name,\n                               std::shared_ptr<::oneflow::one::Tensor>& tensor);\n\nFailureOr<::oneflow::DataType> FromMLIRTypeToOFDataType(Type mlir_type);\nFailureOr<::oneflow::DataType> FromMLIRDataTypeToOFDataType(::mlir::oneflow::DataType data_type);\nFailureOr<::oneflow::DataType> FromMLIRAttrToOFDataType(Attribute attr);\n\n}  // namespace support\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWSUPPORT_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowTypes.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWTYPES_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWTYPES_H_\n\n#include \"mlir/IR/Types.h\"\n\n#define GET_TYPEDEF_CLASSES\n#include \"OneFlow/OneFlowOpsTypes.h.inc\"\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWTYPES_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowUserOps.td",
    "content": "#ifdef GET_ONEFLOW_ASSIGN_OP_DEFINITIONS\n\ndef OneFlow_AssignUserOp : OneFlow_BaseOp<\"assign\", [NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$ref,\n    OneFlow_Tensor:$value\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_AssignIfOp : OneFlow_BaseOp<\"assign_if\", [NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$ref,\n    OneFlow_Tensor:$value,\n    OneFlow_Tensor:$condition\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_AssignIfNotOp : OneFlow_BaseOp<\"assign_if_not\", [NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$ref,\n    OneFlow_Tensor:$value,\n    OneFlow_Tensor:$condition\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\n#endif // GET_ONEFLOW_ASSIGN_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_BASE_OP_DEFINITIONS\n\n\n\nclass OneFlow_NormalizationBaseOp<string mnemonic, list<Trait> traits = []> : OneFlow_BaseOp<mnemonic, !listconcat(traits, [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>])> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    Optional<OneFlow_Tensor>:$moving_mean,\n    Optional<OneFlow_Tensor>:$moving_variance,\n    OneFlow_Tensor:$gamma,\n    OneFlow_Tensor:$beta,\n    Optional<OneFlow_Tensor>:$_add_to_output\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    Optional<OneFlow_Tensor>:$mean,\n    Optional<OneFlow_Tensor>:$inv_variance\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$axis,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon,\n    DefaultValuedAttr<BoolAttr, \"false\">:$training,\n    DefaultValuedAttr<F32Attr, \"0.\">:$momentum\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes,\n    DenseI32ArrayAttr:$result_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\n#endif // GET_ONEFLOW_BASE_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_BINARY_OP_DEFINITIONS\n\ndef OneFlow_BiasAddOp : OneFlow_BaseOp<\"bias_add\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<NCHWCompatibleInterface>, DeclareOpInterfaceMethods<BiasAddCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$a,\n    OneFlow_Tensor:$b\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$axis\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CastLikeOp : OneFlow_BaseOp<\"cast_like\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$dtype_like\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 
1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_CeluGradOp : OneFlow_BaseOp<\"celu_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_DiagGradOp : OneFlow_BaseOp<\"diag_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$diagonal\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_DiagonalGradOp : OneFlow_BaseOp<\"diagonal_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$offset\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_DotOp : OneFlow_BaseOp<\"dot\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_DropoutGradOp : OneFlow_BaseOp<\"dropout_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$mask\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$scale\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ElementwiseMaximumOp : OneFlow_BaseOp<\"elementwise_maximum\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ElementwiseMinimumOp : OneFlow_BaseOp<\"elementwise_minimum\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_EluGradOp : OneFlow_BaseOp<\"elu_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let 
input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FloordivOp : OneFlow_BaseOp<\"floordiv\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LerpOp : OneFlow_BaseOp<\"lerp\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$start,\n    OneFlow_Tensor:$end,\n    OneFlow_Tensor:$weight\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LerpGradOp : OneFlow_BaseOp<\"lerp_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$start,\n    OneFlow_Tensor:$end,\n    OneFlow_Tensor:$weight,\n    OneFlow_Tensor:$out_diff\n  );\n  let output = (outs\n    OneFlow_Tensor:$start_diff,\n    OneFlow_Tensor:$end_diff,\n    OneFlow_Tensor:$weight_diff\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TruncdivOp : OneFlow_BaseOp<\"truncdiv\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_GeluGradOp : OneFlow_BaseOp<\"gelu_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FastGeluGradOp : OneFlow_BaseOp<\"fast_gelu_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_QuickGeluGradOp : OneFlow_BaseOp<\"quick_gelu_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SquareReLUGradOp : OneFlow_BaseOp<\"square_relu_grad\", [NoMemoryEffect, 
DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_GridSampleOp : OneFlow_BaseOp<\"grid_sample\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$grid\n  );\n  let output = (outs\n    OneFlow_Tensor:$output\n  );\n  let attrs = (ins\n    StrAttr:$interpolation_mode,\n    StrAttr:$padding_mode,\n    DefaultValuedAttr<BoolAttr, \"false\">:$align_corners\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_HardsigmoidGradOp : OneFlow_BaseOp<\"hardsigmoid_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_HardShrinkGradOp : OneFlow_BaseOp<\"hardshrink_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$lambd\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_HardswishGradOp : OneFlow_BaseOp<\"hardswish_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_L1L2RegularizeGradientOp : OneFlow_BaseOp<\"l1_l2_regularize_gradient\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$model,\n    OneFlow_Tensor:$model_diff\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LeakyReluGradOp : OneFlow_BaseOp<\"leaky_relu_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_MaskedFillOp : OneFlow_BaseOp<\"masked_fill\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    
OneFlow_Tensor:$mask\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_bool_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$bool_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_MishGradOp : OneFlow_BaseOp<\"mish_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_NarrowGradOp : OneFlow_BaseOp<\"narrow_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$like\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$dim,\n    DefaultValuedAttr<SI64Attr, \"0\">:$start,\n    DefaultValuedAttr<SI64Attr, \"0\">:$length\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_PowOp : OneFlow_BaseOp<\"pow\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FracOp : OneFlow_BaseOp<\"frac\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_PreluOp : OneFlow_BaseOp<\"prelu\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$alpha\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReluGradOp : OneFlow_BaseOp<\"relu_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SeluGradOp : OneFlow_BaseOp<\"selu_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  
let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SiluGradOp : OneFlow_BaseOp<\"silu_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ThresholdGradOp : OneFlow_BaseOp<\"threshold_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$threshold_val\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SoftShrinkGradOp : OneFlow_BaseOp<\"softshrink_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TfPreluOp : OneFlow_BaseOp<\"tf_prelu\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$alpha\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UnfoldTensorGradOp : OneFlow_BaseOp<\"unfold_tensor_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$dimension,\n    DefaultValuedAttr<SI32Attr, \"0\">:$size,\n    DefaultValuedAttr<SI32Attr, \"0\">:$step\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_XdivyOp : OneFlow_BaseOp<\"xdivy\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_XlogyOp : OneFlow_BaseOp<\"xlogy\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastZetaOp : OneFlow_BaseOp<\"broadcast_zeta\", [NoGrad,NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let 
input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_BINARY_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_BROADCAST_OP_DEFINITIONS\n\ndef OneFlow_BroadcastAddOp : OneFlow_BaseOp<\"broadcast_add\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<NCHWCompatibleInterface>, DeclareOpInterfaceMethods<BiasAddCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastDivOp : OneFlow_BaseOp<\"broadcast_div\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let hasFolder = 1;\n}\n\ndef OneFlow_BroadcastDivGradOp : OneFlow_BaseOp<\"broadcast_div_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$z,\n    OneFlow_Tensor:$dz\n  );\n  let output = (outs\n    OneFlow_Tensor:$dy\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastEqualOp : OneFlow_BaseOp<\"broadcast_equal\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastFloorModOp : OneFlow_BaseOp<\"broadcast_floor_mod\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastFmodOp : OneFlow_BaseOp<\"broadcast_fmod\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastGreaterOp : OneFlow_BaseOp<\"broadcast_greater\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadCastInplaceGreaterOp 
: OneFlow_BaseOp<\"broadcast_inplace_greater\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastGreaterEqualOp : OneFlow_BaseOp<\"broadcast_greater_equal\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastLessOp : OneFlow_BaseOp<\"broadcast_less\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastLessEqualOp : OneFlow_BaseOp<\"broadcast_less_equal\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastIsCloseEqualNanOp : OneFlow_BaseOp<\"broadcast_isclose_eq_nan\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"1e-08\">:$atol,\n    DefaultValuedAttr<F32Attr, \"1e-05\">:$rtol\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastIsCloseNotEqualNanOp : OneFlow_BaseOp<\"broadcast_isclose_neq_nan\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"1e-08\">:$atol,\n    DefaultValuedAttr<F32Attr, \"1e-05\">:$rtol\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastLikeOp : OneFlow_BaseOp<\"broadcast_like\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$like\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$broadcast_axes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_BroadcastLogicalAndOp : OneFlow_BaseOp<\"broadcast_logical_and\", [NoMemoryEffect, NoGrad, 
DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastLogicalOrOp : OneFlow_BaseOp<\"broadcast_logical_or\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastLogicalXorOp : OneFlow_BaseOp<\"broadcast_logical_xor\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastMaximumOp : OneFlow_BaseOp<\"broadcast_maximum\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastMinimumOp : OneFlow_BaseOp<\"broadcast_minimum\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastMulOp : OneFlow_BaseOp<\"broadcast_mul\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let hasFolder = 1;\n}\n\ndef OneFlow_BroadcastNotEqualOp : OneFlow_BaseOp<\"broadcast_not_equal\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastPowOp : OneFlow_BaseOp<\"broadcast_pow\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastSubOp : OneFlow_BaseOp<\"broadcast_sub\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    
OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let hasFolder = 1;\n}\n\ndef OneFlow_BitwiseNotOp : OneFlow_BaseOp<\"bitwise_not\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastBitwiseAndOp : OneFlow_BaseOp<\"broadcast_bitwise_and\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastBitwiseOrOp : OneFlow_BaseOp<\"broadcast_bitwise_or\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BroadcastBitwiseXorOp : OneFlow_BaseOp<\"broadcast_bitwise_xor\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_BROADCAST_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_CONV_OP_DEFINITIONS\n\ndef OneFlow_Conv1DOp : OneFlow_ConvolutionBaseOp<\"conv1d\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_Conv2DOp : OneFlow_ConvolutionBaseOp<\"conv2d\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<NCHWCompatibleInterface>]> {}\n\ndef OneFlow_Conv3DOp : OneFlow_ConvolutionBaseOp<\"conv3d\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_ConvBiasGradOp : OneFlow_BaseOp<\"conv_bias_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$bias_diff\n  );\n  let attrs = (ins\n    StrAttr:$data_format,\n    DefaultValuedAttr<SI32Attr, \"0\">:$num_spatial_dims\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_compute_complexity_fn = 1;\n}\n\ndef OneFlow_ConvDataGradOp : OneFlow_BaseOp<\"conv_data_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$filter,\n    OneFlow_Tensor:$x_like,\n    Optional<OneFlow_Tensor>:$_add_to_output\n  );\n  let output = (outs\n    
OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$num_spatial_dims,\n    SI32ArrayAttr:$padding_before,\n    StrAttr:$data_format,\n    SI32ArrayAttr:$kernel_size,\n    SI32ArrayAttr:$strides,\n    SI32ArrayAttr:$dilation_rate,\n    DefaultValuedAttr<SI32Attr, \"0\">:$groups\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_compute_complexity_fn = 1;\n}\n\ndef OneFlow_ConvFilterGradOp : OneFlow_BaseOp<\"conv_filter_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$filter_diff\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$num_spatial_dims,\n    SI32ArrayAttr:$padding_before,\n    StrAttr:$data_format,\n    SI32ArrayAttr:$kernel_size,\n    SI32ArrayAttr:$strides,\n    SI32ArrayAttr:$dilation_rate,\n    DefaultValuedAttr<SI32Attr, \"0\">:$groups\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_compute_complexity_fn = 1;\n}\n\ndef OneFlow_Deconv1DOp : OneFlow_BaseOp<\"deconv1d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$weight\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$filters,\n    SI32ArrayAttr:$padding_before,\n    StrAttr:$data_format,\n    SI32ArrayAttr:$kernel_size,\n    SI32ArrayAttr:$output_padding,\n    SI32ArrayAttr:$strides,\n    SI32ArrayAttr:$dilation_rate,\n    DefaultValuedAttr<SI32Attr, \"1\">:$groups\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_Deconv2DOp : OneFlow_BaseOp<\"deconv2d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$weight\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$filters,\n    SI32ArrayAttr:$padding_before,\n    StrAttr:$data_format,\n    SI32ArrayAttr:$kernel_size,\n    SI32ArrayAttr:$output_padding,\n    SI32ArrayAttr:$strides,\n    SI32ArrayAttr:$dilation_rate,\n    DefaultValuedAttr<SI32Attr, \"1\">:$groups\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_Deconv3DOp : OneFlow_BaseOp<\"deconv3d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$weight\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$filters,\n    SI32ArrayAttr:$padding_before,\n    StrAttr:$data_format,\n    SI32ArrayAttr:$kernel_size,\n    SI32ArrayAttr:$output_padding,\n    SI32ArrayAttr:$strides,\n    SI32ArrayAttr:$dilation_rate,\n    DefaultValuedAttr<SI32Attr, \"1\">:$groups\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 
1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\ndef OneFlow_DeformConv2dOp : OneFlow_BaseOp<\"deform_conv2d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$offset,\n    OneFlow_Tensor:$weight,\n    Optional<OneFlow_Tensor>:$bias,\n    OneFlow_Tensor:$mask\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$stride_h,\n    DefaultValuedAttr<SI32Attr, \"0\">:$stride_w,\n    DefaultValuedAttr<SI32Attr, \"0\">:$pad_h,\n    DefaultValuedAttr<SI32Attr, \"0\">:$pad_w,\n    DefaultValuedAttr<SI32Attr, \"0\">:$dilation_h,\n    DefaultValuedAttr<SI32Attr, \"0\">:$dilation_w,\n    DefaultValuedAttr<SI32Attr, \"0\">:$groups,\n    DefaultValuedAttr<SI32Attr, \"0\">:$offset_groups,\n    DefaultValuedAttr<BoolAttr, \"false\">:$use_mask\n  );\n\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_DeformConv2dInputGradOp : OneFlow_BaseOp<\"deform_conv2d_input_grad\", [NoMemoryEffect,\nDeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$output_grad,\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$offset,\n    OneFlow_Tensor:$weight,\n    OneFlow_Tensor:$mask\n  );\n  let output = (outs\n    OneFlow_Tensor:$input_grad,\n    //OneFlow_Tensor:$weight_grad,\n    OneFlow_Tensor:$offset_grad,\n    OneFlow_Tensor:$mask_grad\n  );\n\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$stride_h,\n    DefaultValuedAttr<SI32Attr, \"0\">:$stride_w,\n    DefaultValuedAttr<SI32Attr, \"0\">:$pad_h,\n    DefaultValuedAttr<SI32Attr, \"0\">:$pad_w,\n    DefaultValuedAttr<SI32Attr, \"0\">:$dilation_h,\n    DefaultValuedAttr<SI32Attr, \"0\">:$dilation_w,\n    DefaultValuedAttr<SI32Attr, \"0\">:$groups,\n    DefaultValuedAttr<SI32Attr, \"0\">:$offset_groups,\n    DefaultValuedAttr<BoolAttr, \"false\">:$use_mask\n  );\n\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_DeformConv2dParamGradOp : OneFlow_BaseOp<\"deform_conv2d_param_grad\", [NoMemoryEffect,\nDeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$output_grad,\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$offset,\n    OneFlow_Tensor:$weight,\n    OneFlow_Tensor:$mask\n  );\n  let output = (outs\n    OneFlow_Tensor:$weight_grad\n  );\n\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$stride_h,\n    DefaultValuedAttr<SI32Attr, \"0\">:$stride_w,\n    DefaultValuedAttr<SI32Attr, \"0\">:$pad_h,\n    DefaultValuedAttr<SI32Attr, \"0\">:$pad_w,\n    DefaultValuedAttr<SI32Attr, \"0\">:$dilation_h,\n    DefaultValuedAttr<SI32Attr, \"0\">:$dilation_w,\n    DefaultValuedAttr<SI32Attr, \"0\">:$groups,\n    DefaultValuedAttr<SI32Attr, \"0\">:$offset_groups,\n    DefaultValuedAttr<BoolAttr, \"false\">:$use_mask\n  );\n\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_CONV_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_CROSS_ENTROPY_OP_DEFINITIONS\n\ndef OneFlow_BinaryCrossEntropyOp : OneFlow_BaseOp<\"binary_cross_entropy\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    
OneFlow_Tensor:$input,\n    OneFlow_Tensor:$target,\n    Optional<OneFlow_Tensor>:$weight\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_BinaryCrossEntropyGradOp : OneFlow_BaseOp<\"binary_cross_entropy_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$target,\n    Optional<OneFlow_Tensor>:$weight,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BinaryCrossEntropyWithLogitsOp : OneFlow_BaseOp<\"binary_cross_entropy_with_logits\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$target,\n    Optional<OneFlow_Tensor>:$weight,\n    Optional<OneFlow_Tensor>:$pos_weight\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_pos_weight\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_BinaryCrossEntropyWithLogitsGradOp : OneFlow_BaseOp<\"binary_cross_entropy_with_logits_grad\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$target,\n    Optional<OneFlow_Tensor>:$weight,\n    Optional<OneFlow_Tensor>:$pos_weight,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_pos_weight\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BinaryCrossEntropyWithLogitsReduceMeanOp : OneFlow_BaseOp<\"binary_cross_entropy_with_logits_reduce_mean\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$target\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_BinaryCrossEntropyWithLogitsReduceMeanGradOp : OneFlow_BaseOp<\"binary_cross_entropy_with_logits_reduce_mean_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$target,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedBCEReduceMeanFwBwOp : OneFlow_BaseOp<\"fused_bce_reduce_mean_fw_bw\", [NoMemoryEffect, 
DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$target\n  );\n  let output = (outs\n    OneFlow_Tensor:$out,\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    OneFlow_DataType:$out_dtype,\n    DefaultValuedAttr<F64Attr, \"0.\">:$constant_value\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SigmoidCrossEntropyOp : OneFlow_BaseOp<\"sigmoid_cross_entropy\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$prediction,\n    OneFlow_Tensor:$label\n  );\n  let output = (outs\n    OneFlow_Tensor:$loss\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_SigmoidCrossEntropyGradOp : OneFlow_BaseOp<\"sigmoid_cross_entropy_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$prediction,\n    OneFlow_Tensor:$loss_diff,\n    OneFlow_Tensor:$label\n  );\n  let output = (outs\n    OneFlow_Tensor:$prediction_diff\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_SparseCrossEntropyOp : OneFlow_BaseOp<\"sparse_cross_entropy\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$prediction,\n    OneFlow_Tensor:$label\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$depth\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_SparseCrossEntropyGradOp : OneFlow_BaseOp<\"sparse_cross_entropy_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$prediction,\n    OneFlow_Tensor:$label,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$prediction_diff\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$depth\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SparseCrossEntropyMsOp : OneFlow_BaseOp<\"sparse_cross_entropy_ms\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$prediction,\n    OneFlow_Tensor:$label\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$depth\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_SparseCrossEntropyMsGradOp : OneFlow_BaseOp<\"sparse_cross_entropy_ms_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$prediction,\n    OneFlow_Tensor:$label,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$prediction_diff\n  );\n  let 
attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$depth\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_CROSS_ENTROPY_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_CUDA_OP_DEFINITIONS\n\ndef OneFlow_NvtxEndOp : OneFlow_BaseOp<\"nvtx_end\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$mark_prefix\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_NvtxStartOp : OneFlow_BaseOp<\"nvtx_start\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$mark_prefix\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_CUDA_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_DATASET_OP_DEFINITIONS\n\ndef OneFlow_COCOReaderOp : OneFlow_BaseOp<\"COCOReader\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_TensorBufferTensor:$image,\n    OneFlow_Tensor:$image_id,\n    OneFlow_Tensor:$image_size,\n    OneFlow_TensorBufferTensor:$gt_bbox,\n    OneFlow_TensorBufferTensor:$gt_label,\n    OneFlow_TensorBufferTensor:$gt_segm,\n    OneFlow_TensorBufferTensor:$gt_segm_index\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$session_id,\n    StrAttr:$annotation_file,\n    StrAttr:$image_dir,\n    DefaultValuedAttr<SI64Attr, \"0\">:$batch_size,\n    DefaultValuedAttr<BoolAttr, \"true\">:$shuffle_after_epoch,\n    DefaultValuedAttr<SI64Attr, \"-1\">:$random_seed,\n    DefaultValuedAttr<BoolAttr, \"true\">:$group_by_ratio,\n    DefaultValuedAttr<BoolAttr, \"true\">:$remove_images_without_annotations,\n    DefaultValuedAttr<BoolAttr, \"false\">:$stride_partition,\n    StrArrayAttr:$nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_output_arg_modify_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_OFRecordReaderOp : OneFlow_BaseOp<\"OFRecordReader\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$data_dir,\n    DefaultValuedAttr<SI32Attr, \"0\">:$data_part_num,\n    DefaultValuedAttr<SI32Attr, \"0\">:$batch_size,\n    DefaultValuedAttr<StrAttr, \"\\\"part-\\\"\">:$part_name_prefix,\n    DefaultValuedAttr<SI32Attr, \"-1\">:$part_name_suffix_length,\n    DefaultValuedAttr<BoolAttr, \"false\">:$random_shuffle,\n    DefaultValuedAttr<SI64Attr, \"-1\">:$seed,\n    DefaultValuedAttr<SI32Attr, \"1024\">:$shuffle_buffer_size,\n    DefaultValuedAttr<BoolAttr, \"false\">:$shuffle_after_epoch,\n    StrArrayAttr:$nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_output_arg_modify_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n  let has_get_nd_sbp_fn = 1;\n  let 
has_compute_complexity_fn = 1;\n}\n\ndef OneFlow_CtcGreedyDecoderOp : OneFlow_BaseOp<\"ctc_greedy_decoder\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$log_probs,\n    OneFlow_Tensor:$input_lengths\n  );\n  let output = (outs\n    OneFlow_Tensor:$decoded,\n    OneFlow_Tensor:$neg_sum_logits\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$merge_repeated\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_MegatronGptMmapDataLoaderOp : OneFlow_BaseOp<\"megatron_gpt_mmap_data_loader\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Optional<OneFlow_Tensor>:$iteration\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$data_file_prefix,\n    DefaultValuedAttr<SI64Attr, \"0\">:$seq_length,\n    DefaultValuedAttr<SI64Attr, \"1\">:$label_length,\n    DefaultValuedAttr<SI64Attr, \"0\">:$num_samples,\n    DefaultValuedAttr<SI64Attr, \"0\">:$batch_size,\n    OneFlow_DataType:$dtype,\n    SI64ArrayAttr:$split_sizes,\n    DefaultValuedAttr<SI64Attr, \"0\">:$split_index,\n    DefaultValuedAttr<BoolAttr, \"false\">:$shuffle,\n    DefaultValuedAttr<SI64Attr, \"0\">:$random_seed,\n    StrArrayAttr:$nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_OfrecordBytesDecoderOp : OneFlow_BaseOp<\"ofrecord_bytes_decoder\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$name\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_OfrecordImageClassificationReaderOp : OneFlow_BaseOp<\"ofrecord_image_classification_reader\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_Tensor:$image,\n    OneFlow_Tensor:$label\n  );\n  let attrs = (ins\n    StrAttr:$data_dir,\n    DefaultValuedAttr<SI32Attr, \"0\">:$data_part_num,\n    DefaultValuedAttr<SI32Attr, \"0\">:$batch_size,\n    DefaultValuedAttr<StrAttr, \"\\\"part-\\\"\">:$part_name_prefix,\n    DefaultValuedAttr<SI32Attr, \"-1\">:$part_name_suffix_length,\n    DefaultValuedAttr<BoolAttr, \"false\">:$random_shuffle,\n    DefaultValuedAttr<SI64Attr, \"-1\">:$seed,\n    DefaultValuedAttr<SI32Attr, \"1024\">:$shuffle_buffer_size,\n    DefaultValuedAttr<BoolAttr, \"false\">:$shuffle_after_epoch,\n    DefaultValuedAttr<StrAttr, \"\\\"BGR\\\"\">:$color_space,\n    DefaultValuedAttr<StrAttr, \"\\\"encoded\\\"\">:$image_feature_name,\n    DefaultValuedAttr<StrAttr, \"\\\"class/label\\\"\">:$label_feature_name,\n    DefaultValuedAttr<SI32Attr, \"8\">:$decode_buffer_size_per_thread,\n    DefaultValuedAttr<SI32Attr, \"0\">:$num_decode_threads_per_machine\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_output_arg_modify_fn = 1;\n}\n\ndef OneFlow_OfrecordImageDecoderOp : 
OneFlow_BaseOp<\"ofrecord_image_decoder\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$name,\n    DefaultValuedAttr<StrAttr, \"\\\"BGR\\\"\">:$color_space\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_OfrecordImageDecoderRandomCropOp : OneFlow_BaseOp<\"ofrecord_image_decoder_random_crop\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$name,\n    DefaultValuedAttr<StrAttr, \"\\\"BGR\\\"\">:$color_space,\n    DefaultValuedAttr<SI32Attr, \"10\">:$num_attempts,\n    DefaultValuedAttr<SI64Attr, \"-1\">:$seed,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_seed,\n    F32ArrayAttr:$random_area,\n    F32ArrayAttr:$random_aspect_ratio\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_OfrecordRawDecoderOp : OneFlow_BaseOp<\"ofrecord_raw_decoder\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$name,\n    ShapeAttr:$shape,\n    OneFlow_DataType:$data_type,\n    DefaultValuedAttr<BoolAttr, \"false\">:$dim1_varying_length,\n    DefaultValuedAttr<BoolAttr, \"false\">:$truncate\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_RawReaderOp : OneFlow_BaseOp<\"raw_reader\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrArrayAttr:$files,\n    OneFlow_DataType:$data_type,\n    ShapeAttr:$shape,\n    SI64Attr:$batch_size,\n    SI64Attr:$shuffle_block_size,\n    DefaultValuedAttr<BoolAttr, \"true\">:$random_shuffle,\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed,\n    StrArrayAttr:$nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_DATASET_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_DETECTION_OP_DEFINITIONS\n\ndef OneFlow_InTopKOp : OneFlow_BaseOp<\"in_top_k\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$targets,\n    OneFlow_Tensor:$predictions\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$k\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_NmsOp : OneFlow_BaseOp<\"nms\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    
OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$iou_threshold,\n    DefaultValuedAttr<SI32Attr, \"0\">:$keep_n\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ObjectBboxFlipOp : OneFlow_BaseOp<\"object_bbox_flip\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$bbox,\n    OneFlow_Tensor:$image_size,\n    OneFlow_Tensor:$flip_code\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ObjectBboxScaleOp : OneFlow_BaseOp<\"object_bbox_scale\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$bbox,\n    OneFlow_Tensor:$scale\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ObjectSegmentationPolygonFlipOp : OneFlow_BaseOp<\"object_segmentation_polygon_flip\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$poly,\n    OneFlow_Tensor:$image_size,\n    OneFlow_Tensor:$flip_code\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ObjectSegmentationPolygonScaleOp : OneFlow_BaseOp<\"object_segmentation_polygon_scale\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$poly,\n    OneFlow_Tensor:$scale\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ObjectSegmentationPolygonToMaskOp : OneFlow_BaseOp<\"object_segmentation_polygon_to_mask\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$poly,\n    OneFlow_Tensor:$poly_index,\n    OneFlow_Tensor:$image_size\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_RoiAlignOp : OneFlow_BaseOp<\"roi_align\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$rois\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$pooled_h,\n    DefaultValuedAttr<SI32Attr, \"0\">:$pooled_w,\n    DefaultValuedAttr<F32Attr, \"0.\">:$spatial_scale,\n    DefaultValuedAttr<SI32Attr, \"0\">:$sampling_ratio,\n    DefaultValuedAttr<BoolAttr, \"false\">:$aligned\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef 
OneFlow_RoiAlignGradOp : OneFlow_BaseOp<\"roi_align_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x_like,\n    OneFlow_Tensor:$rois\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$pooled_h,\n    DefaultValuedAttr<SI32Attr, \"0\">:$pooled_w,\n    DefaultValuedAttr<F32Attr, \"0.\">:$spatial_scale,\n    DefaultValuedAttr<SI32Attr, \"0\">:$sampling_ratio,\n    DefaultValuedAttr<BoolAttr, \"false\">:$aligned\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TopKOp : OneFlow_BaseOp<\"top_k\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$k,\n    DefaultValuedAttr<BoolAttr, \"false\">:$sorted\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_DETECTION_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_EAGER_OP_DEFINITIONS\n\ndef OneFlow_EagerBToSOp : OneFlow_BaseOp<\"eager_b_to_s\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"-1\">:$out_split_axis,\n    StrAttr:$in_parallel_conf,\n    StrAttr:$out_parallel_conf,\n    ShapeAttr:$shape\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_EagerNaiveSToSOp : OneFlow_BaseOp<\"eager_naive_s_to_s\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"-1\">:$in_split_axis,\n    DefaultValuedAttr<SI64Attr, \"-1\">:$out_split_axis,\n    StrAttr:$in_parallel_conf,\n    StrAttr:$out_parallel_conf,\n    ShapeAttr:$shape\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_EagerCclAllGatherOp : OneFlow_BaseOp<\"eager_ccl_all_gather\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$parallel_conf,\n    ShapeAttr:$output_shape,\n    OneFlow_DataType:$output_dtype\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_EagerCclAllReduceOp : OneFlow_BaseOp<\"eager_ccl_all_reduce\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output 
= (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$parallel_conf\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n}\n\ndef OneFlow_EagerCclBroadcastOp : OneFlow_BaseOp<\"eager_ccl_broadcast\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$parallel_conf,\n    ShapeArrayAttr:$shape_list,\n    DefaultValuedAttr<SI64Attr, \"0\">:$root,\n    DefaultValuedAttr<BoolAttr, \"true\">:$async_launch\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n}\n\ndef OneFlow_EagerCclTouchOp : OneFlow_BaseOp<\"eager_ccl_touch\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$in\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"true\">:$async_launch\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n}\n\ndef OneFlow_EagerCclReduceOp : OneFlow_BaseOp<\"eager_ccl_reduce\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$parallel_conf,\n    DefaultValuedAttr<SI64Attr, \"0\">:$root\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n}\n\ndef OneFlow_EagerCclReduceScatterOp : OneFlow_BaseOp<\"eager_ccl_reduce_scatter\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$parallel_conf,\n    ShapeAttr:$output_shape,\n    OneFlow_DataType:$output_dtype,\n    DefaultValuedAttr<StrAttr, \"\\\"sum\\\"\">:$op_type\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_EagerCclS2SOp : OneFlow_BaseOp<\"eager_ccl_s2s\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"-1\">:$in_split_axis,\n    DefaultValuedAttr<SI64Attr, \"-1\">:$out_split_axis,\n    StrAttr:$parallel_conf\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_EagerPToBOp : OneFlow_BaseOp<\"eager_p_to_b\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    
StrAttr:$in_parallel_conf,\n    StrAttr:$out_parallel_conf,\n    ShapeAttr:$shape\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_EagerPToSOp : OneFlow_BaseOp<\"eager_p_to_s\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"-1\">:$out_split_axis,\n    StrAttr:$in_parallel_conf,\n    StrAttr:$out_parallel_conf,\n    ShapeAttr:$shape\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_EagerSToBOp : OneFlow_BaseOp<\"eager_s_to_b\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"-1\">:$in_split_axis,\n    StrAttr:$in_parallel_conf,\n    StrAttr:$out_parallel_conf,\n    ShapeAttr:$shape\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_EagerSToPOp : OneFlow_BaseOp<\"eager_s_to_p\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"-1\">:$in_split_axis,\n    StrAttr:$in_parallel_conf,\n    StrAttr:$out_parallel_conf,\n    ShapeAttr:$shape\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_EagerSymmetricSToPOp : OneFlow_BaseOp<\"eager_symmetric_s_to_p\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"-1\">:$in_split_axis,\n    StrAttr:$parallel_conf\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_EAGER_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_FUSED_OP_DEFINITIONS\n\ndef OneFlow_FusedLstmCellOp : OneFlow_BaseOp<\"fused_lstm_cell\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input_gates,\n    OneFlow_Tensor:$hidden_gates,\n    OneFlow_Tensor:$cx,\n    Optional<OneFlow_Tensor>:$input_bias,\n    Optional<OneFlow_Tensor>:$hidden_bias\n  );\n  let output = (outs\n    OneFlow_Tensor:$hy,\n    OneFlow_Tensor:$cy,\n    OneFlow_Tensor:$workspace\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let 
has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedLstmCellGradOp : OneFlow_BaseOp<\"fused_lstm_cell_grad\", [NoMemoryEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$grad_hy,\n    OneFlow_Tensor:$grad_cy,\n    OneFlow_Tensor:$cx,\n    OneFlow_Tensor:$cy,\n    OneFlow_Tensor:$workspace\n  );\n  let output = (outs\n    OneFlow_Tensor:$grad_gates,\n    Optional<OneFlow_Tensor>:$grad_cx,\n    Optional<OneFlow_Tensor>:$grad_bias\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$result_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedGruCellOp : OneFlow_BaseOp<\"fused_gru_cell\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input_gates,\n    OneFlow_Tensor:$hidden_gates,\n    OneFlow_Tensor:$hx,\n    Optional<OneFlow_Tensor>:$input_bias,\n    Optional<OneFlow_Tensor>:$hidden_bias\n  );\n  let output = (outs\n    OneFlow_Tensor:$hy,\n    OneFlow_Tensor:$workspace\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedGruCellGradOp : OneFlow_BaseOp<\"fused_gru_cell_grad\", [NoMemoryEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$grad_hy,\n    OneFlow_Tensor:$workspace\n  );\n  let output = (outs\n    OneFlow_Tensor:$grad_input_gates,\n    OneFlow_Tensor:$grad_hidden_gates,\n    Optional<OneFlow_Tensor>:$grad_hx,\n    Optional<OneFlow_Tensor>:$grad_input_bias,\n    Optional<OneFlow_Tensor>:$grad_hidden_bias\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$result_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CudnnFusedNormalizationAddReluOp : OneFlow_BaseOp<\"cudnn_fused_normalization_add_relu\", [NoMemoryEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    Optional<OneFlow_Tensor>:$addend,\n    Optional<OneFlow_Tensor>:$moving_mean,\n    Optional<OneFlow_Tensor>:$moving_variance,\n    OneFlow_Tensor:$gamma,\n    OneFlow_Tensor:$beta\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$reserve_space,\n    Optional<OneFlow_Tensor>:$mean,\n    Optional<OneFlow_Tensor>:$inv_variance\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$axis,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon,\n    DefaultValuedAttr<F32Attr, \"0.\">:$momentum\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes,\n    DenseI32ArrayAttr:$result_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_CudnnFusedNormalizationAddReluGradOp : OneFlow_BaseOp<\"cudnn_fused_normalization_add_relu_grad\", [NoMemoryEffect, 
DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$mean,\n    OneFlow_Tensor:$inv_variance,\n    OneFlow_Tensor:$gamma,\n    OneFlow_Tensor:$beta,\n    OneFlow_Tensor:$reserve_space,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$gamma_diff,\n    OneFlow_Tensor:$beta_diff,\n    OneFlow_Tensor:$dx,\n    Optional<OneFlow_Tensor>:$addend_diff\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$axis,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedBiasAddGeluOp : OneFlow_BaseOp<\"fused_bias_add_gelu\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$a,\n    OneFlow_Tensor:$b\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$axis\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedBiasAddGeluGradOp : OneFlow_BaseOp<\"fused_bias_add_gelu_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$a,\n    OneFlow_Tensor:$b,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$axis\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedBiasAddMaskScaleOp : OneFlow_BaseOp<\"fused_bias_add_mask_scale\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$a,\n    OneFlow_Tensor:$b,\n    OneFlow_Tensor:$mask,\n    Optional<OneFlow_Tensor>:$_add_to_output\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$axis,\n    DefaultValuedAttr<F32Attr, \"1.\">:$scale,\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_FusedCastScaleOp : OneFlow_BaseOp<\"fused_cast_scale\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$scale_by_tensor\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedScaleMaskSoftmaxOp : OneFlow_BaseOp<\"fused_scale_mask_softmax\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$mask\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"1.\">:$scale_value,\n    DefaultValuedAttr<F32Attr, \"0.\">:$mask_fill_value\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let 
has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_FusedScaleMaskSoftmaxDropoutOp : OneFlow_BaseOp<\"fused_scale_mask_softmax_dropout\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$mask,\n    OneFlow_Tensor:$dropout_mask\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$softmax_y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"1.\">:$scale_value,\n    DefaultValuedAttr<F32Attr, \"0.\">:$mask_fill_value,\n    DefaultValuedAttr<F32Attr, \"1.\">:$dropout_scale_value\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_FusedScaleMaskSoftmaxDropoutGradOp : OneFlow_BaseOp<\"fused_scale_mask_softmax_dropout_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$softmax_y,\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$mask,\n    OneFlow_Tensor:$dropout_mask\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$scale_value,\n    DefaultValuedAttr<F32Attr, \"0.\">:$dropout_scale_value\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedBiasAddScaleMaskSoftmaxDropoutOp : OneFlow_BaseOp<\"fused_bias_add_scale_mask_softmax_dropout\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$bias,\n    OneFlow_Tensor:$mask,\n    OneFlow_Tensor:$dropout_mask\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$softmax_y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"1.\">:$scale_value,\n    DefaultValuedAttr<F32Attr, \"0.\">:$mask_fill_value,\n    DefaultValuedAttr<F32Attr, \"1.\">:$dropout_scale_value\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_FusedScaleMaskSoftmaxGradOp : OneFlow_BaseOp<\"fused_scale_mask_softmax_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$mask\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$scale_value\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedScaleTrilOp : OneFlow_BaseOp<\"fused_scale_tril\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$diagonal,\n    DefaultValuedAttr<F64Attr, \"0.\">:$floating_fill_value,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integer_fill_value,\n    DefaultValuedAttr<BoolAttr, \"false\">:$is_floating_fill_value,\n    DefaultValuedAttr<F64Attr, \"1.\">:$floating_scale_value,\n    DefaultValuedAttr<SI64Attr, \"1\">:$integer_scale_value,\n    
DefaultValuedAttr<BoolAttr, \"false\">:$is_floating_scale_value\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedSelfAttentionQueryMulKeyAndValueOp : OneFlow_BaseOp<\"fused_self_attention_query_mul_key_and_value\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$hidden_states\n  );\n  let output = (outs\n    OneFlow_Tensor:$query_mul_key,\n    OneFlow_Tensor:$value\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$head_size,\n    DefaultValuedAttr<F32Attr, \"0.\">:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedSelfAttentionQueryMulKeyAndValueGradOp : OneFlow_BaseOp<\"fused_self_attention_query_mul_key_and_value_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$query_mul_key_grad,\n    OneFlow_Tensor:$value_grad,\n    OneFlow_Tensor:$hidden_states\n  );\n  let output = (outs\n    OneFlow_Tensor:$hidden_states_grad\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedTrilScaleSoftmaxMaskScaleOp : OneFlow_BaseOp<\"fused_tril_scale_softmax_mask_scale\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$mask\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$softmax_y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$diagonal,\n    DefaultValuedAttr<F32Attr, \"0.\">:$tril_fill_value,\n    DefaultValuedAttr<F32Attr, \"1.\">:$tril_scale_value,\n    DefaultValuedAttr<F32Attr, \"1.\">:$mask_scale_value\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_FusedTrilScaleSoftmaxMaskScaleGradOp : OneFlow_BaseOp<\"fused_tril_scale_softmax_mask_scale_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$softmax_y,\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$mask\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$diagonal,\n    DefaultValuedAttr<F32Attr, \"0.\">:$tril_scale_value,\n    DefaultValuedAttr<F32Attr, \"0.\">:$mask_scale_value\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n\ndef OneFlow_NormalizationAddReluGradOp : OneFlow_BaseOp<\"normalization_add_relu_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$mean,\n    OneFlow_Tensor:$inv_variance,\n    OneFlow_Tensor:$gamma,\n    OneFlow_Tensor:$beta,\n    OneFlow_Tensor:$reserve_space,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$gamma_diff,\n    OneFlow_Tensor:$beta_diff,\n    OneFlow_Tensor:$dx,\n    Optional<OneFlow_Tensor>:$addend_diff\n  
);\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$axis,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n\ndef OneFlow_FusedDotFeatureInteractionOp : OneFlow_BaseOp<\"fused_dot_feature_interaction\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$features,\n    Optional<OneFlow_Tensor>:$output_concat,\n    Optional<OneFlow_Tensor>:$num_valid_sparse_feature,\n    Optional<OneFlow_Tensor>:$sparse_feature,\n    Optional<OneFlow_Tensor>:$sparse_indices\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$self_interaction,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_output_concat,\n    DefaultValuedAttr<SI32Attr, \"0\">:$output_padding,\n    DefaultValuedAttr<StrAttr, \"\\\"none\\\"\">:$pooling\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedDotFeatureInteractionGradOp : OneFlow_BaseOp<\"fused_dot_feature_interaction_grad\", [NoMemoryEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    Variadic<OneFlow_Tensor>:$features,\n    Optional<OneFlow_Tensor>:$num_valid_sparse_feature,\n    Optional<OneFlow_Tensor>:$sparse_feature,\n    Optional<OneFlow_Tensor>:$sparse_indices\n  );\n  let output = (outs\n    Variadic<OneFlow_Tensor>:$features_grad,\n    Optional<OneFlow_Tensor>:$output_concat_grad,\n    Optional<OneFlow_Tensor>:$sparse_feature_grad\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$self_interaction,\n    DefaultValuedAttr<SI32Attr, \"0\">:$output_concat_grad_dim,\n    DefaultValuedAttr<StrAttr, \"\\\"none\\\"\">:$pooling\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedCrossFeatureInteractionOp : OneFlow_BaseOp<\"fused_cross_feature_interaction\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$weight,\n    OneFlow_Tensor:$bias,\n    OneFlow_Tensor:$x0\n  );\n  let output = (outs\n    OneFlow_Tensor:$out,\n    OneFlow_Tensor:$matmul_result\n  );\n  let attrs = (ins\n    StrAttr:$interaction_mode\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n\ndef OneFlow_FusedCrossFeatureInteractionV1GradOp : OneFlow_BaseOp<\"fused_cross_feature_interaction_v1_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, NoGrad]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$weight,\n    OneFlow_Tensor:$x0,\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$matmul_result\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx0,\n    OneFlow_Tensor:$dw,\n    OneFlow_Tensor:$dx,\n    OneFlow_Tensor:$dbias\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedCrossFeatureInteractionV2GradOp : 
OneFlow_BaseOp<\"fused_cross_feature_interaction_v2_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, NoGrad]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$weight,\n    OneFlow_Tensor:$bias,\n    OneFlow_Tensor:$x0,\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$matmul_result\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx0,\n    OneFlow_Tensor:$dw,\n    OneFlow_Tensor:$dx,\n    OneFlow_Tensor:$dbias\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScaledDotProductFlashAttentionOp : OneFlow_BaseOp<\"scaled_dot_product_flash_attention\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$query,\n    OneFlow_Tensor:$key,\n    OneFlow_Tensor:$value,\n    Optional<OneFlow_Tensor>:$alibi_slopes_\n  );\n  let output = (outs\n    OneFlow_Tensor:$out,\n    OneFlow_Tensor:$softmax_lse,\n    OneFlow_Tensor:$rng_state\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$p_dropout,\n    DefaultValuedAttr<F32Attr, \"0.\">:$softmax_scale,\n    DefaultValuedAttr<BoolAttr, \"false\">:$is_causal,\n    SI32Attr:$window_size_left,\n    SI32Attr:$window_size_right,\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScaledDotProductFlashAttentionGradOp : OneFlow_BaseOp<\"scaled_dot_product_flash_attention_grad\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$grad_out,\n    OneFlow_Tensor:$query,\n    OneFlow_Tensor:$key,\n    OneFlow_Tensor:$value,\n    OneFlow_Tensor:$out,\n    OneFlow_Tensor:$softmax_lse,\n    OneFlow_Tensor:$rng_state,\n    Optional<OneFlow_Tensor>:$alibi_slopes_\n  );\n  let output = (outs\n    OneFlow_Tensor:$grad_q,\n    OneFlow_Tensor:$grad_k,\n    OneFlow_Tensor:$grad_v\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$p_dropout,\n    DefaultValuedAttr<F32Attr, \"0.\">:$softmax_scale,\n    DefaultValuedAttr<BoolAttr, \"false\">:$is_causal,\n    SI32Attr:$window_size_left,\n    SI32Attr:$window_size_right\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedMultiHeadAttentionInferenceOp : OneFlow_BaseOp<\"fused_multi_head_attention_inference\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$query,\n    OneFlow_Tensor:$key,\n    OneFlow_Tensor:$value,\n    Optional<OneFlow_Tensor>:$attn_bias,\n    Optional<OneFlow_Tensor>:$query_seq_start,\n    Optional<OneFlow_Tensor>:$key_seq_start,\n    Optional<OneFlow_Tensor>:$key_seq_len\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    SI64Attr:$query_head_size,\n    DefaultValuedAttr<SI64Attr, \"0\">:$query_max_seq_len,\n    DefaultValuedAttr<SI64Attr, \"0\">:$key_max_seq_len,\n    F64Attr:$scale,\n    DefaultValuedAttr<SI64Attr, \"0\">:$causal_diagonal_offset,\n    DefaultValuedAttr<StrAttr, \"\\\"none\\\"\">:$attn_mask_type,\n    StrAttr:$query_layout,\n    StrAttr:$key_layout,\n    StrAttr:$value_layout,\n    StrAttr:$output_layout\n  );\n  let 
has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedAttentionConcatPastKeyValueOp : OneFlow_BaseOp<\"fused_attention_concat_past_key_value\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$key,\n    OneFlow_Tensor:$value,\n    Optional<OneFlow_Tensor>:$past_key,\n    Optional<OneFlow_Tensor>:$past_value\n  );\n  let output = (outs\n    OneFlow_Tensor:$output_key,\n    OneFlow_Tensor:$output_value\n  );\n  let attrs = (ins\n    StrAttr:$past_key_layout,\n    StrAttr:$past_value_layout,\n    StrAttr:$key_layout,\n    StrAttr:$value_layout,\n    SI64Attr:$key_head_size\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedFastGeluMulOp : OneFlow_BaseOp<\"fused_fast_gelu_mul\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$multiplier\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedFastGeluMulGradOp : OneFlow_BaseOp<\"fused_fast_gelu_mul_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$out_diff,\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$multiplier\n  );\n  let output = (outs\n    OneFlow_Tensor:$in_diff,\n    OneFlow_Tensor:$multiplier_diff\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedGetBounddingBoxesCoordOp : OneFlow_BaseOp<\"fused_get_boundding_boxes_coord\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x1,\n    OneFlow_Tensor:$y1,\n    OneFlow_Tensor:$w1,\n    OneFlow_Tensor:$h1,\n    OneFlow_Tensor:$x2,\n    OneFlow_Tensor:$y2,\n    OneFlow_Tensor:$w2,\n    OneFlow_Tensor:$h2\n  );\n  let output = (outs\n    OneFlow_Tensor:$b1_x1,\n    OneFlow_Tensor:$b1_x2,\n    OneFlow_Tensor:$b1_y1,\n    OneFlow_Tensor:$b1_y2,\n    OneFlow_Tensor:$b2_x1,\n    OneFlow_Tensor:$b2_x2,\n    OneFlow_Tensor:$b2_y1,\n    OneFlow_Tensor:$b2_y2\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedGetBounddingBoxesCoordGradOp : OneFlow_BaseOp<\"fused_get_boundding_boxes_coord_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$b1_x1_diff,\n    OneFlow_Tensor:$b1_x2_diff,\n    OneFlow_Tensor:$b1_y1_diff,\n    OneFlow_Tensor:$b1_y2_diff,\n    OneFlow_Tensor:$b2_x1_diff,\n    OneFlow_Tensor:$b2_x2_diff,\n    OneFlow_Tensor:$b2_y1_diff,\n    OneFlow_Tensor:$b2_y2_diff\n  );\n  let output = (outs\n    OneFlow_Tensor:$x1_diff,\n    OneFlow_Tensor:$y1_diff,\n    OneFlow_Tensor:$w1_diff,\n    OneFlow_Tensor:$h1_diff,\n    OneFlow_Tensor:$x2_diff,\n    OneFlow_Tensor:$y2_diff,\n    OneFlow_Tensor:$w2_diff,\n    OneFlow_Tensor:$h2_diff\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let 
has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedGetCiouResultOp : OneFlow_BaseOp<\"fused_get_ciou_result\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$v,\n    OneFlow_Tensor:$iou,\n    OneFlow_Tensor:$rho2,\n    OneFlow_Tensor:$c2\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$alpha\n  );\n  let attrs = (ins\n    F32Attr:$eps\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedGetCiouResultGradOp : OneFlow_BaseOp<\"fused_get_ciou_result_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$alpha,\n    OneFlow_Tensor:$rho2,\n    OneFlow_Tensor:$c2\n  );\n  let output = (outs\n    OneFlow_Tensor:$dv,\n    OneFlow_Tensor:$diou,\n    OneFlow_Tensor:$drho2,\n    OneFlow_Tensor:$dc2\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedGetIouOp : OneFlow_BaseOp<\"fused_get_iou\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$w1,\n    OneFlow_Tensor:$h1,\n    OneFlow_Tensor:$w2,\n    OneFlow_Tensor:$h2,\n    OneFlow_Tensor:$inter\n  );\n  let output = (outs\n    OneFlow_Tensor:$iou\n  );\n  let attrs = (ins\n    F32Attr:$eps\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedGetIouGradOp : OneFlow_BaseOp<\"fused_get_iou_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$diou,\n    OneFlow_Tensor:$w1,\n    OneFlow_Tensor:$h1,\n    OneFlow_Tensor:$w2,\n    OneFlow_Tensor:$h2,\n    OneFlow_Tensor:$inter\n  );\n  let output = (outs\n    OneFlow_Tensor:$dw1,\n    OneFlow_Tensor:$dh1,\n    OneFlow_Tensor:$dinter\n  );\n  let attrs = (ins\n    F32Attr:$eps\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedCenterOp : OneFlow_BaseOp<\"fused_get_center_dist\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$b1_x1,\n    OneFlow_Tensor:$b1_x2,\n    OneFlow_Tensor:$b2_x1,\n    OneFlow_Tensor:$b2_x2,\n    OneFlow_Tensor:$b1_y1,\n    OneFlow_Tensor:$b1_y2,\n    OneFlow_Tensor:$b2_y1,\n    OneFlow_Tensor:$b2_y2\n  );\n  let output = (outs\n    OneFlow_Tensor:$rho2\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedCenterGradOp : OneFlow_BaseOp<\"fused_get_center_dist_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$b1_x1,\n    OneFlow_Tensor:$b1_x2,\n    OneFlow_Tensor:$b2_x1,\n    OneFlow_Tensor:$b2_x2,\n    OneFlow_Tensor:$b1_y1,\n    OneFlow_Tensor:$b1_y2,\n    OneFlow_Tensor:$b2_y1,\n    OneFlow_Tensor:$b2_y2,\n    OneFlow_Tensor:$rho2_diff\n  );\n  let output = (outs\n    OneFlow_Tensor:$b1_x1_diff,\n    
OneFlow_Tensor:$b1_x2_diff,\n    OneFlow_Tensor:$b2_x1_diff,\n    OneFlow_Tensor:$b2_x2_diff,\n    OneFlow_Tensor:$b1_y1_diff,\n    OneFlow_Tensor:$b1_y2_diff,\n    OneFlow_Tensor:$b2_y1_diff,\n    OneFlow_Tensor:$b2_y2_diff\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedGetCiouDiagonalAngleOp : OneFlow_BaseOp<\"fused_get_ciou_diagonal_angle\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$w1,\n    OneFlow_Tensor:$h1,\n    OneFlow_Tensor:$w2,\n    OneFlow_Tensor:$h2\n  );\n  let output = (outs\n    OneFlow_Tensor:$v\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"1e-08\">:$eps\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedGetCiouDiagonalAngleGradOp : OneFlow_BaseOp<\"fused_get_ciou_diagonal_angle_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$w1,\n    OneFlow_Tensor:$h1,\n    OneFlow_Tensor:$w2,\n    OneFlow_Tensor:$h2,\n    OneFlow_Tensor:$v_diff\n  );\n  let output = (outs\n    OneFlow_Tensor:$w1_diff,\n    OneFlow_Tensor:$h1_diff,\n    OneFlow_Tensor:$w2_diff,\n    OneFlow_Tensor:$h2_diff\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"1e-08\">:$eps\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedGetIntersectionAreaOp : OneFlow_BaseOp<\"fused_get_intersection_area\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$b1_x1,\n    OneFlow_Tensor:$b1_x2,\n    OneFlow_Tensor:$b2_x1,\n    OneFlow_Tensor:$b2_x2,\n    OneFlow_Tensor:$b1_y1,\n    OneFlow_Tensor:$b1_y2,\n    OneFlow_Tensor:$b2_y1,\n    OneFlow_Tensor:$b2_y2\n  );\n  let output = (outs\n    OneFlow_Tensor:$inter\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedGetIntersectionAreaGradOp : OneFlow_BaseOp<\"fused_get_intersection_area_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$b1_x1,\n    OneFlow_Tensor:$b1_x2,\n    OneFlow_Tensor:$b2_x1,\n    OneFlow_Tensor:$b2_x2,\n    OneFlow_Tensor:$b1_y1,\n    OneFlow_Tensor:$b1_y2,\n    OneFlow_Tensor:$b2_y1,\n    OneFlow_Tensor:$b2_y2,\n    OneFlow_Tensor:$inter_diff\n  );\n  let output = (outs\n    OneFlow_Tensor:$b1_x1_diff,\n    OneFlow_Tensor:$b1_x2_diff,\n    OneFlow_Tensor:$b2_x1_diff,\n    OneFlow_Tensor:$b2_x2_diff,\n    OneFlow_Tensor:$b1_y1_diff,\n    OneFlow_Tensor:$b1_y2_diff,\n    OneFlow_Tensor:$b2_y1_diff,\n    OneFlow_Tensor:$b2_y2_diff\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n\ndef OneFlow_FusedGetConvexDiagonalSquaredOp : OneFlow_BaseOp<\"fused_get_convex_diagonal_squared\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$b1_x1,\n    OneFlow_Tensor:$b1_x2,\n    OneFlow_Tensor:$b2_x1,\n    OneFlow_Tensor:$b2_x2,\n    OneFlow_Tensor:$b1_y1,\n    
OneFlow_Tensor:$b1_y2,\n    OneFlow_Tensor:$b2_y1,\n    OneFlow_Tensor:$b2_y2\n  );\n  let output = (outs\n    OneFlow_Tensor:$c2\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"1e-08\">:$eps\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedGetConvexDiagonalSquaredGradOp : OneFlow_BaseOp<\"fused_get_convex_diagonal_squared_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$c2_diff,\n    OneFlow_Tensor:$b1_x1,\n    OneFlow_Tensor:$b1_x2,\n    OneFlow_Tensor:$b2_x1,\n    OneFlow_Tensor:$b2_x2,\n    OneFlow_Tensor:$b1_y1,\n    OneFlow_Tensor:$b1_y2,\n    OneFlow_Tensor:$b2_y1,\n    OneFlow_Tensor:$b2_y2\n  );\n  let output = (outs\n    OneFlow_Tensor:$b1_x1_diff,\n    OneFlow_Tensor:$b1_x2_diff,\n    OneFlow_Tensor:$b2_x1_diff,\n    OneFlow_Tensor:$b2_x2_diff,\n    OneFlow_Tensor:$b1_y1_diff,\n    OneFlow_Tensor:$b1_y2_diff,\n    OneFlow_Tensor:$b2_y1_diff,\n    OneFlow_Tensor:$b2_y2_diff\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"1e-08\">:$eps\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedScaleMaskBiasSoftmaxGradOp : OneFlow_BaseOp<\"fused_scale_mask_bias_softmax_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.35355\">:$scale\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedScaleMaskBiasSoftmaxOp : OneFlow_BaseOp<\"fused_scale_mask_bias_softmax\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$mask,\n    Optional<OneFlow_Tensor>:$bias\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.125\">:$scale,\n    DefaultValuedAttr<BoolAttr, \"false\">:$inplace\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedCodegeexQkvReshapeOp : OneFlow_BaseOp<\"fused_codegeex_qkv_reshape\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$query,\n    OneFlow_Tensor:$key,\n    OneFlow_Tensor:$value\n  );\n  let output = (outs\n    OneFlow_Tensor:$new_query,\n    OneFlow_Tensor:$new_key,\n    OneFlow_Tensor:$new_value\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"1\">:$num_attention_heads\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedClipGradOp : OneFlow_BaseOp<\"fused_clip_grad\", [NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$model_diff\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"2.\">:$max_norm,\n    DefaultValuedAttr<F32Attr, \"1.\">:$norm_type\n  );\n  let 
has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_NonContiguousBinaryOp : OneFlow_BaseOp<\"noncontiguous_binary_op\", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$lhs,\n    OneFlow_Tensor:$rhs\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<StrAttr, \"\\\"add\\\"\">:$op,\n    DefaultValuedAttr<BoolAttr, \"false\">:$inplace\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_NonContiguousBinaryOpGrad : OneFlow_BaseOp<\"noncontiguous_binary_op_grad\", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$lhs,\n    OneFlow_Tensor:$rhs\n  );\n  let output = (outs\n    OneFlow_Tensor:$dlhs,\n    OneFlow_Tensor:$drhs\n  );\n  let attrs = (ins\n    DefaultValuedAttr<StrAttr, \"\\\"add\\\"\">:$op,\n    DefaultValuedAttr<BoolAttr, \"false\">:$inplace\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_FUSED_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_IDEMPOTENT_OP_DEFINITIONS\n\ndef OneFlow_AbsOp : OneFlow_IdempotentBaseOp<\"abs\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_CeilOp : OneFlow_IdempotentBaseOp<\"ceil\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_FloorOp : OneFlow_IdempotentBaseOp<\"floor\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_OnesLikeOp : OneFlow_IdempotentBaseOp<\"ones_like\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let same_output_regst_num = 1;\n  let has_nd_sbp_infer_fn = 1;\n  let input = (ins AnyType:$like);\n  let output = (outs AnyType:$out);\n}\n\ndef OneFlow_ReluOp : OneFlow_IdempotentBaseOp<\"relu\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<NCHWCompatibleInterface>]> {}\n\ndef OneFlow_RintOp : OneFlow_IdempotentBaseOp<\"rint\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_RoundOp : OneFlow_IdempotentBaseOp<\"round\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_SignOp : OneFlow_IdempotentBaseOp<\"sign\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\n#endif // GET_ONEFLOW_IDEMPOTENT_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_IDENTITY_OP_DEFINITIONS\n\ndef OneFlow_AmpWhiteIdentityOp : OneFlow_BaseOp<\"amp_white_identity\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_AmpBlackIdentityOp : OneFlow_BaseOp<\"amp_black_identity\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    
OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_IdentityOp : OneFlow_BaseOp<\"identity\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_IdentityBufferOp : OneFlow_BaseOp<\"identity_buffer\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$buffer_size\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TupleIdentityOp : OneFlow_BaseOp<\"tuple_identity\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$in\n  );\n  let output = (outs\n    Variadic<OneFlow_Tensor>:$out\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_sbp_signature_infer_fn = 1;\n}\n\ndef OneFlow_PinnedIdentityOp : OneFlow_BaseOp<\"pinned_identity\", [DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let summary = \"mark that the defining op of the operand cannot be erased\";\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_IDENTITY_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_IMAGE_OP_DEFINITIONS\n\ndef OneFlow_ImageBatchAlignOp : OneFlow_BaseOp<\"image_batch_align\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    ShapeAttr:$shape,\n    OneFlow_DataType:$data_type,\n    DefaultValuedAttr<SI32Attr, \"0\">:$alignment,\n    DefaultValuedAttr<BoolAttr, \"false\">:$dynamic_out\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_output_arg_modify_fn = 1;\n}\n\ndef OneFlow_ImageDecodeOp : OneFlow_BaseOp<\"image_decode\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<StrAttr, \"\\\"BGR\\\"\">:$color_space,\n    OneFlow_DataType:$data_type\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ImageFlipOp : OneFlow_BaseOp<\"image_flip\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$flip_code\n  );\n  let output = (outs\n    
OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ImageRandomCropOp : OneFlow_BaseOp<\"image_random_crop\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"10\">:$num_attempts,\n    DefaultValuedAttr<SI64Attr, \"-1\">:$seed,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_seed,\n    F32ArrayAttr:$random_area,\n    F32ArrayAttr:$random_aspect_ratio\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_ImageResizeKeepAspectRatioOp : OneFlow_BaseOp<\"image_resize_keep_aspect_ratio\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out,\n    OneFlow_Tensor:$size,\n    OneFlow_Tensor:$scale\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$target_size,\n    DefaultValuedAttr<SI32Attr, \"0\">:$min_size,\n    DefaultValuedAttr<SI32Attr, \"0\">:$max_size,\n    DefaultValuedAttr<BoolAttr, \"false\">:$resize_longer,\n    DefaultValuedAttr<StrAttr, \"\\\"bilinear\\\"\">:$interpolation_type\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ImageResizeToFixedOp : OneFlow_BaseOp<\"image_resize_to_fixed\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out,\n    OneFlow_Tensor:$scale\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$target_width,\n    DefaultValuedAttr<SI64Attr, \"0\">:$target_height,\n    DefaultValuedAttr<SI64Attr, \"3\">:$channels,\n    OneFlow_DataType:$data_type,\n    DefaultValuedAttr<StrAttr, \"\\\"bilinear\\\"\">:$interpolation_type\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_IMAGE_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_INDICES_OP_DEFINITIONS\n\ndef OneFlow_ArgSortOp : OneFlow_BaseOp<\"arg_sort\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$direction\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ArgmaxOp : OneFlow_BaseOp<\"argmax\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ArgwhereOp : OneFlow_BaseOp<\"argwhere\", [NoMemoryEffect, 
NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input\n  );\n  let output = (outs\n    OneFlow_Tensor:$output,\n    OneFlow_Tensor:$output_size\n  );\n  let attrs = (ins\n    OneFlow_DataType:$dtype\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BatchGatherOp : OneFlow_BaseOp<\"batch_gather\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$indices\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_DimGatherOp : OneFlow_BaseOp<\"dim_gather\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$index\n  );\n  let output = (outs\n    OneFlow_Tensor:$output\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$dim\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_DimScatterAddOp : OneFlow_BaseOp<\"dim_scatter_add\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$index,\n    OneFlow_Tensor:$src\n  );\n  let output = (outs\n    OneFlow_Tensor:$output\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$dim\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_DimScatterAddLikeOp : OneFlow_BaseOp<\"dim_scatter_add_like\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$like,\n    OneFlow_Tensor:$index,\n    OneFlow_Tensor:$src\n  );\n  let output = (outs\n    OneFlow_Tensor:$output\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$dim\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_DimScatterAddScalarOp : OneFlow_BaseOp<\"dim_scatter_add_scalar\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$index\n  );\n  let output = (outs\n    OneFlow_Tensor:$output\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$src_scalar,\n    DefaultValuedAttr<SI32Attr, \"0\">:$dim\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_DimScatterMulOp : OneFlow_BaseOp<\"dim_scatter_mul\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$index,\n    OneFlow_Tensor:$src\n  );\n  let output = (outs\n    OneFlow_Tensor:$output\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$dim\n  );\n  let 
has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_DimScatterMulScalarOp : OneFlow_BaseOp<\"dim_scatter_mul_scalar\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$index\n  );\n  let output = (outs\n    OneFlow_Tensor:$output\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$src_scalar,\n    DefaultValuedAttr<SI32Attr, \"0\">:$dim\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_DimScatterUpdateOp : OneFlow_BaseOp<\"dim_scatter_update\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$index,\n    OneFlow_Tensor:$src\n  );\n  let output = (outs\n    OneFlow_Tensor:$output\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$dim\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_DimScatterUpdateScalarOp : OneFlow_BaseOp<\"dim_scatter_update_scalar\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$index\n  );\n  let output = (outs\n    OneFlow_Tensor:$output\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$src_scalar,\n    DefaultValuedAttr<SI32Attr, \"0\">:$dim\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_EmbeddingRenormOp : OneFlow_BaseOp<\"embedding_renorm\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$indices\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$max_norm,\n    DefaultValuedAttr<F64Attr, \"2.\">:$norm_type\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_EmbeddingOp : OneFlow_BaseOp<\"embedding\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$weight,\n    OneFlow_Tensor:$indices\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"-1\">:$padding_idx,\n    DefaultValuedAttr<BoolAttr, \"false\">:$scale_grad_by_freq\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_FusedApplyRotaryEmbOp : OneFlow_BaseOp<\"fused_apply_rotary_emb\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    Optional<OneFlow_Tensor>:$cos,\n    Optional<OneFlow_Tensor>:$sin,\n    Optional<OneFlow_Tensor>:$position_ids\n  );\n  let output = (outs\n    
OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<StrAttr, \"\\\"BHMK\\\"\">:$x_layout,\n    DefaultValuedAttr<StrAttr, \"\\\"BHMK\\\"\">:$output_layout,\n    DefaultValuedAttr<StrAttr, \"\\\"plane\\\"\">:$mode,\n    DefaultValuedAttr<SI64Attr, \"0\">:$tensor_index,\n    DefaultValuedAttr<F32Attr, \"1e4\">:$base,\n    DefaultValuedAttr<SI64Attr, \"0\">:$k_size,\n    DefaultValuedAttr<SI64Attr, \"0\">:$rotary_size\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_EmbeddingGradOp : OneFlow_BaseOp<\"embedding_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$weight,\n    OneFlow_Tensor:$indices\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"-1\">:$padding_idx,\n    DefaultValuedAttr<BoolAttr, \"false\">:$scale_grad_by_freq\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_GatherOp : OneFlow_BaseOp<\"gather\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$indices\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$axis\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_GatherNdOp : OneFlow_BaseOp<\"gather_nd\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$params,\n    OneFlow_Tensor:$indices\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_GenerateRandomBatchPermutationIndicesOp : OneFlow_BaseOp<\"generate_random_batch_permutation_indices\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ImageTargetResizeOp : OneFlow_BaseOp<\"image_target_resize\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out,\n    OneFlow_Tensor:$size,\n    OneFlow_Tensor:$scale\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$target_size,\n    DefaultValuedAttr<SI32Attr, \"0\">:$max_size\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SliceOp : OneFlow_BaseOp<\"slice\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    
OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    SI64ArrayAttr:$start,\n    SI64ArrayAttr:$stop,\n    SI64ArrayAttr:$step\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SliceUpdateOp : OneFlow_BaseOp<\"slice_update\", [SupportNonContiguous, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$ref,\n    OneFlow_Tensor:$value\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    SI64ArrayAttr:$start,\n    SI64ArrayAttr:$stop,\n    SI64ArrayAttr:$step\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SliceGradOp : OneFlow_BaseOp<\"slice_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    ShapeAttr:$like_shape,\n    SI64ArrayAttr:$start,\n    SI64ArrayAttr:$stop,\n    SI64ArrayAttr:$step\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_ScatterNdOp : OneFlow_BaseOp<\"scatter_nd\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$indices,\n    OneFlow_Tensor:$updates\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    ShapeAttr:$shape\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_ScatterNdLikeOp : OneFlow_BaseOp<\"scatter_nd_like\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$like,\n    OneFlow_Tensor:$indices,\n    OneFlow_Tensor:$updates\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TensorScatterNdAddOp : OneFlow_BaseOp<\"tensor_scatter_nd_add\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$params,\n    OneFlow_Tensor:$updates,\n    OneFlow_Tensor:$indices\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_TensorScatterNdUpdateOp : OneFlow_BaseOp<\"tensor_scatter_nd_update\", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$params,\n    OneFlow_Tensor:$updates,\n    OneFlow_Tensor:$indices\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_UnsortedBatchSegmentSumOp : OneFlow_BaseOp<\"unsorted_batch_segment_sum\", 
[NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$data,\n    OneFlow_Tensor:$segment_ids\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$num_segments\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_UnsortedSegmentSumOp : OneFlow_BaseOp<\"unsorted_segment_sum\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$data,\n    OneFlow_Tensor:$segment_ids\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$axis,\n    DefaultValuedAttr<SI64Attr, \"0\">:$num_segments\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_UnsortedSegmentSumLikeOp : OneFlow_BaseOp<\"unsorted_segment_sum_like\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$data,\n    OneFlow_Tensor:$segment_ids,\n    OneFlow_Tensor:$like\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$axis\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_WhereOp : OneFlow_BaseOp<\"where\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$condition,\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_MedianOp : OneFlow_BaseOp<\"median\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input\n  );\n  let output = (outs\n    OneFlow_Tensor:$output\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_MedianWithIndicesOp : OneFlow_BaseOp<\"median_with_indices\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input\n  );\n  let output = (outs\n    OneFlow_Tensor:$values,\n    OneFlow_Tensor:$indices\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SearchSortedOp : OneFlow_BaseOp<\"searchsorted\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$sorted_sequence,\n    OneFlow_Tensor:$values\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$out_int32,\n    DefaultValuedAttr<BoolAttr, \"false\">:$right\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let 
has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SearchSortedScalarOp : OneFlow_BaseOp<\"searchsorted_scalar\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$sorted_sequence\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$out_int32,\n    DefaultValuedAttr<BoolAttr, \"false\">:$right,\n    DefaultValuedAttr<F32Attr, \"0.\">:$values\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ModeOp : OneFlow_BaseOp<\"mode\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input\n  );\n  let output = (outs\n    OneFlow_Tensor:$values,\n    OneFlow_Tensor:$indices\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_INDICES_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_INVOLUTION_OP_DEFINITIONS\n\ndef OneFlow_NegativeOp : OneFlow_InvolutionBaseOp<\"negative\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_ReciprocalOp : OneFlow_InvolutionBaseOp<\"reciprocal\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\n#endif // GET_ONEFLOW_INVOLUTION_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_LOSS_OP_DEFINITIONS\n\ndef OneFlow_CombinedMarginLossOp : OneFlow_BaseOp<\"combined_margin_loss\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$label\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$theta\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$m1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$m2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$m3,\n    DefaultValuedAttr<SI64Attr, \"0\">:$depth\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_CombinedMarginLossGradOp : OneFlow_BaseOp<\"combined_margin_loss_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$label,\n    OneFlow_Tensor:$theta\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$m1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$m2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$m3,\n    DefaultValuedAttr<SI64Attr, \"0\">:$depth\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CtcLossOp : OneFlow_BaseOp<\"ctc_loss\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$log_probs,\n    OneFlow_Tensor:$targets,\n    OneFlow_Tensor:$input_lengths,\n    OneFlow_Tensor:$target_lengths\n  );\n  let output = (outs\n    OneFlow_Tensor:$loss,\n    OneFlow_Tensor:$alpha\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$max_target_length,\n    DefaultValuedAttr<SI32Attr, \"0\">:$blank,\n    DefaultValuedAttr<BoolAttr, \"false\">:$zero_infinity\n  );\n 
 let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CtcLossGradOp : OneFlow_BaseOp<\"ctc_loss_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$grad_out,\n    OneFlow_Tensor:$log_probs,\n    OneFlow_Tensor:$targets,\n    OneFlow_Tensor:$input_lengths,\n    OneFlow_Tensor:$target_lengths,\n    OneFlow_Tensor:$loss,\n    OneFlow_Tensor:$alpha\n  );\n  let output = (outs\n    OneFlow_Tensor:$grad\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$max_target_length,\n    DefaultValuedAttr<SI32Attr, \"0\">:$blank,\n    DefaultValuedAttr<BoolAttr, \"false\">:$zero_infinity\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_DynamicLossScaleScheduleOp : OneFlow_BaseOp<\"dynamic_loss_scale_schedule\", [DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$count_not_finite,\n    OneFlow_Tensor:$loss_scale,\n    OneFlow_Tensor:$good_step_counter\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"2000\">:$increment_period,\n    DefaultValuedAttr<F32Attr, \"2.\">:$multiplier\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_KlDivLossOp : OneFlow_BaseOp<\"kl_div_loss\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$target\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$log_target\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_KlDivLossGradOp : OneFlow_BaseOp<\"kl_div_loss_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$target,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$log_target\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SmoothL1LossOp : OneFlow_BaseOp<\"smooth_l1_loss\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$target\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$beta\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_SmoothL1LossGradOp : OneFlow_BaseOp<\"smooth_l1_loss_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$target,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, 
\"0.\">:$beta\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_LOSS_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_MATH_OP_DEFINITIONS\n\ndef OneFlow_AbsGradOp : OneFlow_BaseOp<\"abs_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ErfOp : OneFlow_BaseOp<\"erf\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ErfGradOp : OneFlow_BaseOp<\"erf_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ExpOp : OneFlow_BaseOp<\"exp\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ExpGradOp : OneFlow_BaseOp<\"exp_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_Exp2Op : OneFlow_BaseOp<\"exp2\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_Exp2GradOp : OneFlow_BaseOp<\"exp2_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_Expm1Op : OneFlow_BaseOp<\"expm1\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_Expm1GradOp : OneFlow_BaseOp<\"expm1_grad\", [NoMemoryEffect, 
DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FloordivXGradOp : OneFlow_BaseOp<\"floordiv_x_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dz\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FloordivYGradOp : OneFlow_BaseOp<\"floordiv_y_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dz\n  );\n  let output = (outs\n    OneFlow_Tensor:$dy\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TruncdivXGradOp : OneFlow_BaseOp<\"truncdiv_x_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dz\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TruncdivYGradOp : OneFlow_BaseOp<\"truncdiv_y_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dz\n  );\n  let output = (outs\n    OneFlow_Tensor:$dy\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LgammaOp : OneFlow_BaseOp<\"lgamma\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LgammaGradOp : OneFlow_BaseOp<\"lgamma_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_DigammaOp : OneFlow_BaseOp<\"digamma\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_DigammaGradOp : OneFlow_BaseOp<\"digamma_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = 
(outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TrigammaOp : OneFlow_BaseOp<\"trigamma\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LogOp : OneFlow_BaseOp<\"log\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_Log1pOp : OneFlow_BaseOp<\"log1p\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_Log1pGradOp : OneFlow_BaseOp<\"log1p_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_Log2GradOp : OneFlow_BaseOp<\"log2_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_Log10GradOp : OneFlow_BaseOp<\"log10_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LogGradOp : OneFlow_BaseOp<\"log_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LogSigmoidOp : OneFlow_BaseOp<\"log_sigmoid\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LogSigmoidGradOp : OneFlow_BaseOp<\"log_sigmoid_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n  
  OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReciprocalGradOp : OneFlow_BaseOp<\"reciprocal_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReciprocalNoNanOp : OneFlow_BaseOp<\"reciprocal_no_nan\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReciprocalNoNanGradOp : OneFlow_BaseOp<\"reciprocal_no_nan_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_RsqrtOp : OneFlow_BaseOp<\"rsqrt\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_RsqrtGradOp : OneFlow_BaseOp<\"rsqrt_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SigmoidOp : OneFlow_BaseOp<\"sigmoid\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SigmoidGradOp : OneFlow_BaseOp<\"sigmoid_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SoftplusOp : OneFlow_BaseOp<\"softplus\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"1.\">:$beta,\n    DefaultValuedAttr<F64Attr, \"20.\">:$threshold\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let 
has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SoftplusGradOp : OneFlow_BaseOp<\"softplus_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"1.\">:$beta,\n    DefaultValuedAttr<F64Attr, \"20.\">:$threshold\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SoftsignGradOp : OneFlow_BaseOp<\"softsign_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_VarOp : OneFlow_BaseOp<\"var\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input\n  );\n  let output = (outs\n    OneFlow_Tensor:$output\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$dim,\n    DefaultValuedAttr<BoolAttr, \"true\">:$unbiased,\n    DefaultValuedAttr<BoolAttr, \"false\">:$keepdim,\n    OneFlow_DataType:$dtype\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SqrtOp : OneFlow_BaseOp<\"sqrt\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let hasFolder = 1;\n}\n\ndef OneFlow_SqrtGradOp : OneFlow_BaseOp<\"sqrt_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SquareOp : OneFlow_BaseOp<\"square\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SquareGradOp : OneFlow_BaseOp<\"square_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_XlogyXGradOp : OneFlow_BaseOp<\"xlogy_x_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dz\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  
let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_XlogyYGradOp : OneFlow_BaseOp<\"xlogy_y_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dz\n  );\n  let output = (outs\n    OneFlow_Tensor:$dy\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CumsumOp : OneFlow_BaseOp<\"cumsum\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    SI64Attr:$dim\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CumProdOp : OneFlow_BaseOp<\"cumprod\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    SI64Attr:$dim\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CumProdGradOp : OneFlow_BaseOp<\"cumprod_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$output,\n    OneFlow_Tensor:$input\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    SI64Attr:$dim\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ErfInvOp : OneFlow_BaseOp<\"erfinv\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FftC2COp : OneFlow_BaseOp<\"fft_c2c\", [SupportNonContiguous, NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n\n  let attrs = (ins\n    SI64ArrayAttr:$dims,\n    BoolAttr:$forward,\n    SI32Attr:$norm_mode,\n    F64Attr:$norm_fct\n  );\n\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FftR2COp : OneFlow_BaseOp<\"fft_r2c\", [SupportNonContiguous, NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n\n  let attrs = (ins\n    SI64ArrayAttr:$dims,\n    SI32Attr:$norm_mode,\n    F64Attr:$norm_fct,\n    BoolAttr:$onesided\n  );\n\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FftC2ROp : OneFlow_BaseOp<\"fft_c2r\", [SupportNonContiguous, NoMemoryEffect, 
DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n\n  let attrs = (ins\n    SI64ArrayAttr:$dims,\n    SI32Attr:$norm_mode,\n    F64Attr:$norm_fct,\n    SI64Attr:$last_dim_size\n  );\n\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_StftOp : OneFlow_BaseOp<\"stft\", [SupportNonContiguous, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    Optional<OneFlow_Tensor>:$window\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$n_fft,\n    DefaultValuedAttr<SI32Attr, \"0\">:$hop_length,\n    DefaultValuedAttr<SI32Attr, \"0\">:$win_length,\n    DefaultValuedAttr<BoolAttr, \"true\">:$center,\n    DefaultValuedAttr<StrAttr, \"\\\"reflect\\\"\">:$pad_mode,\n    DefaultValuedAttr<BoolAttr, \"false\">:$normalized,\n    DefaultValuedAttr<BoolAttr, \"false\">:$onesided,\n    DefaultValuedAttr<BoolAttr, \"false\">:$return_complex\n  );\n\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_compute_complexity_fn = 1;\n}\n#endif // GET_ONEFLOW_MATH_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_MATMUL_OP_DEFINITIONS\n\ndef OneFlow_BatchMatmulOp : OneFlow_BaseOp<\"batch_matmul\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$a,\n    OneFlow_Tensor:$b,\n    Optional<OneFlow_Tensor>:$_add_to_output\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$transpose_a,\n    DefaultValuedAttr<BoolAttr, \"false\">:$transpose_b,\n    DefaultValuedAttr<F64Attr, \"1.\">:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_compute_complexity_fn = 1;\n}\n\ndef OneFlow_BroadcastMatmulOp : OneFlow_BaseOp<\"broadcast_matmul\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<MatMulCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$a,\n    OneFlow_Tensor:$b,\n    Optional<OneFlow_Tensor>:$_add_to_output\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$transpose_a,\n    DefaultValuedAttr<BoolAttr, \"false\">:$transpose_b,\n    DefaultValuedAttr<F64Attr, \"1.\">:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_compute_complexity_fn = 1;\n}\n\ndef OneFlow_BroadcastMatmulGradBOp : OneFlow_BaseOp<\"broadcast_matmul_grad_b\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$a,\n    OneFlow_Tensor:$b,\n    Optional<OneFlow_Tensor>:$_add_to_output\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"1.\">:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_compute_complexity_fn = 1;\n}\n\ndef 
OneFlow_DistributedPartialFcSampleOp : OneFlow_BaseOp<\"distributed_partial_fc_sample\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$weight,\n    OneFlow_Tensor:$label\n  );\n  let output = (outs\n    OneFlow_Tensor:$mapped_label,\n    OneFlow_Tensor:$sampled_label,\n    OneFlow_Tensor:$sampled_weight\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$num_sample,\n    DefaultValuedAttr<SI64Attr, \"-1\">:$seed\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_DistributedPartialFcSampleDisableBoxingOp : OneFlow_BaseOp<\"distributed_partial_fc_sample_disable_boxing\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$sampled_weight_diff,\n    OneFlow_Tensor:$sampled_label\n  );\n  let output = (outs\n    OneFlow_Tensor:$boxing_disabled_sampled_weight_diff,\n    OneFlow_Tensor:$boxing_disabled_sampled_label\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ErfcOp : OneFlow_BaseOp<\"erfc\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ErfcGradOp : OneFlow_BaseOp<\"erfc_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_MatmulOp : OneFlow_BaseOp<\"matmul\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<MatMulCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$a,\n    OneFlow_Tensor:$b,\n    Optional<OneFlow_Tensor>:$_add_to_output\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$transpose_a,\n    DefaultValuedAttr<BoolAttr, \"false\">:$transpose_b,\n    DefaultValuedAttr<F64Attr, \"1.\">:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_compute_complexity_fn = 1;\n}\n\ndef OneFlow_MatrixVectorProductOp : OneFlow_BaseOp<\"matrix_vector_product\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$a,\n    OneFlow_Tensor:$b\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_MatrixVectorProductGradAOp : OneFlow_BaseOp<\"matrix_vector_product_grad_a\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$b\n  );\n  let output = (outs\n    
OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_MatrixVectorProductGradBOp : OneFlow_BaseOp<\"matrix_vector_product_grad_b\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$a\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_VectorMatrixProductOp : OneFlow_BaseOp<\"vector_matrix_product\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$a,\n    OneFlow_Tensor:$b\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_VectorMatrixProductGradAOp : OneFlow_BaseOp<\"vector_matrix_product_grad_a\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$b\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_VectorMatrixProductGradBOp : OneFlow_BaseOp<\"vector_matrix_product_grad_b\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$a\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n\ndef OneFlow_CublasFusedMLPOp : OneFlow_BaseOp<\"cublas_fused_mlp\", [NoMemoryEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    Variadic<OneFlow_Tensor>:$weights,\n    Variadic<OneFlow_Tensor>:$biases\n  );\n  let output = (outs\n    OneFlow_Tensor:$out,\n    Variadic<OneFlow_Tensor>:$cublas_aux,\n    Variadic<OneFlow_Tensor>:$hidden\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$skip_final_activation\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CublasFusedMLPGradOp : OneFlow_BaseOp<\"cublas_fused_mlp_grad\", [NoMemoryEffect, NoGrad, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x,\n    Variadic<OneFlow_Tensor>:$weights,\n    Variadic<OneFlow_Tensor>:$cublas_aux,\n    Variadic<OneFlow_Tensor>:$hidden\n  );\n  let output = (outs\n    OneFlow_Tensor:$d_x,\n    Variadic<OneFlow_Tensor>:$d_biases,\n    Variadic<OneFlow_Tensor>:$d_weights\n  );\n  let attrs = (ins\n    F32ArrayAttr:$alpha_list\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CublasBiasAddReluMatmulGradOp : 
OneFlow_BaseOp<\"cublas_bias_add_relu_matmul_grad\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$weight,\n    OneFlow_Tensor:$aux\n  );\n  let output = (outs\n    OneFlow_Tensor:$d_grad,\n    OneFlow_Tensor:$d_bias\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"1.\">:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CublasMatmulBiasAddGradOp : OneFlow_BaseOp<\"cublas_matmul_bias_add_grad\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$w_grad,\n    OneFlow_Tensor:$b_grad\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedMatmulBiasOp : OneFlow_BaseOp<\"fused_matmul_bias\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<MatMulCompatibleInterface>, DeclareOpInterfaceMethods<BiasAddCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$weight,\n    OneFlow_Tensor:$bias,\n    Optional<OneFlow_Tensor>:$_add_to_output\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"1.\">:$alpha,\n    DefaultValuedAttr<F64Attr, \"1.\">:$beta\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n}\n\ndef OneFlow_FusedMatmulBiasAddReluDropoutOp : OneFlow_BaseOp<\"fused_matmul_bias_add_relu_dropout\", [NoMemoryEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    Variadic<OneFlow_Tensor>:$weights,\n    Variadic<OneFlow_Tensor>:$biases\n  );\n  let output = (outs\n    OneFlow_Tensor:$out,\n    Variadic<OneFlow_Tensor>:$cublas_aux,\n    Variadic<OneFlow_Tensor>:$hidden\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$skip_final_activation,\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed,\n    F32ArrayAttr:$dropout_rate_list\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedReluDropoutGradOp : OneFlow_BaseOp<\"fused_relu_dropout_grad\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$mask\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$scale\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedGluOp : OneFlow_BaseOp<\"fused_glu\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$w,\n    Optional<OneFlow_Tensor>:$b,\n    Optional<OneFlow_Tensor>:$v,\n    Optional<OneFlow_Tensor>:$c\n  );\n  let attrs = (ins\n    DefaultValuedAttr<StrAttr, \"\\\"none\\\"\">:$activation,\n    
DefaultValuedAttr<BoolAttr, \"false\">:$has_bias,\n    DefaultValuedAttr<BoolAttr, \"false\">:$is_split\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$matmul_wx,\n    Optional<OneFlow_Tensor>:$matmul_vx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedGluWithoutLinearGradOp : OneFlow_BaseOp<\"fused_glu_without_linear_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$matmul_wx,\n    Optional<OneFlow_Tensor>:$matmul_vx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<StrAttr, \"\\\"none\\\"\">:$activation\n  );\n  let output = (outs\n    OneFlow_Tensor:$d_matmul_wx,\n    Optional<OneFlow_Tensor>:$d_matmul_vx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_GroupedMatmulBiasOp : OneFlow_BaseOp<\"grouped_matmul_bias\", [NoMemoryEffect, AttrSizedOperandSegments,  DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$xs,\n    Variadic<OneFlow_Tensor>:$weights,\n    Variadic<OneFlow_Tensor>:$biases\n  );\n  let output = (outs\n    Variadic<OneFlow_Tensor>:$ys\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_MATMUL_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_MISC_OP_DEFINITIONS\n\ndef OneFlow_CategoricalOrdinalEncodeOp : OneFlow_BaseOp<\"CategoricalOrdinalEncode\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$table,\n    OneFlow_Tensor:$size,\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$hash_precomputed\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_AddNOp : OneFlow_BaseOp<\"add_n\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let hasCanonicalizer = 1;\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ArangeOp : OneFlow_BaseOp<\"arange\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$integer_start,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integer_delta,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integer_limit,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_start,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_delta,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_limit,\n    OneFlow_DataType:$dtype,\n    StrArrayAttr:$nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let 
has_data_type_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_BinCountOp : OneFlow_BaseOp<\"bincount\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    Optional<OneFlow_Tensor>:$weight\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$size\n  );\n  let has_data_type_infer_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n}\n\ndef OneFlow_CoinFlipOp : OneFlow_BaseOp<\"coin_flip\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.5\">:$probability,\n    DefaultValuedAttr<SI64Attr, \"0\">:$batch_size,\n    DefaultValuedAttr<SI64Attr, \"-1\">:$seed,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_seed,\n    StrArrayAttr:$nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_ConcatOp : OneFlow_BaseOp<\"cat\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<NCHWCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$axis,\n    DefaultValuedAttr<SI64Attr, \"0\">:$max_dim_size\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TensorConstantOp : OneFlow_BaseOp<\"tensor_constant\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    OneFlow_DataType:$dtype,\n    ShapeAttr:$shape,\n    StrArrayAttr:$nd_sbp\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_ConstantOp : OneFlow_BaseOp<\"constant\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    ComplexDoubleAttr:$complex_value,\n    DefaultValuedAttr<F64Attr, \"0.\">:$floating_value,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integer_value,\n    DefaultValuedAttr<BoolAttr, \"false\">:$is_floating_value,\n    DefaultValuedAttr<BoolAttr, \"false\">:$is_complex_value,\n    OneFlow_DataType:$dtype,\n    ShapeAttr:$shape,\n    StrArrayAttr:$nd_sbp\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_DropoutOp : OneFlow_BaseOp<\"dropout\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    Optional<OneFlow_Tensor>:$_add_to_output\n  );\n  let output = (outs\n    OneFlow_Tensor:$out,\n    OneFlow_Tensor:$mask\n  );\n  let attrs = (ins\n    
DefaultValuedAttr<F32Attr, \"0.\">:$rate,\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ElementwiseMaximumBackwardOp : OneFlow_BaseOp<\"elementwise_maximum_backward\", [NoMemoryEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dz,\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    Optional<OneFlow_Tensor>:$dx,\n    Optional<OneFlow_Tensor>:$dy\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$result_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ElementwiseMinimumBackwardOp : OneFlow_BaseOp<\"elementwise_minimum_backward\", [NoMemoryEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dz,\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    Optional<OneFlow_Tensor>:$dx,\n    Optional<OneFlow_Tensor>:$dy\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$result_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_EmptyOp : OneFlow_BaseOp<\"empty\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    OneFlow_DataType:$dtype,\n    ShapeAttr:$shape,\n    StrArrayAttr:$nd_sbp,\n    DefaultValuedAttr<BoolAttr, \"false\">:$pin_memory,\n    StrAttr:$device_type,\n    DefaultValuedAttr<SI64Attr, \"0\">:$device_id\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n}\n\ndef OneFlow_EyeOp : OneFlow_BaseOp<\"eye\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$rows,\n    DefaultValuedAttr<SI64Attr, \"0\">:$cols,\n    OneFlow_DataType:$dtype,\n    StrArrayAttr:$nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_GridSampleGradOp : OneFlow_BaseOp<\"grid_sample_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$doutput,\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$grid\n  );\n  let output = (outs\n    OneFlow_Tensor:$dinput,\n    OneFlow_Tensor:$dgrid\n  );\n  let attrs = (ins\n    StrAttr:$interpolation_mode,\n    StrAttr:$padding_mode,\n    DefaultValuedAttr<BoolAttr, \"false\">:$align_corners\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_MultiCountNotFiniteOp : OneFlow_BaseOp<\"multi_count_not_finite\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> 
{\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_MultiSquareSumOp : OneFlow_BaseOp<\"multi_square_sum\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_MultiReduceSumPowAbsOp : OneFlow_BaseOp<\"multi_reduce_sum_pow_abs\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0\">:$p\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n}\n\ndef OneFlow_MultiReduceMaxAbsOp : OneFlow_BaseOp<\"multi_reduce_max_abs\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n}\n\ndef OneFlow_MultiReduceMinAbsOp : OneFlow_BaseOp<\"multi_reduce_min_abs\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n}\n\ndef OneFlow_LocalMultiReduceMaxAbsOp : OneFlow_BaseOp<\"local_multi_reduce_max_abs\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n}\n\ndef OneFlow_LocalMultiReduceMinAbsOp : OneFlow_BaseOp<\"local_multi_reduce_min_abs\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n}\n\ndef OneFlow_NLLOp : OneFlow_BaseOp<\"nll\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$target,\n    Optional<OneFlow_Tensor>:$weight\n  );\n  let output = (outs\n    OneFlow_Tensor:$output,\n    OneFlow_Tensor:$out_weight\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$ignore_index\n  );\n  let has_data_type_infer_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_NLLGradOp : OneFlow_BaseOp<\"nll_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$out_grad,\n    OneFlow_Tensor:$input,\n    
OneFlow_Tensor:$target,\n    Optional<OneFlow_Tensor>:$weight\n  );\n  let output = (outs\n    OneFlow_Tensor:$in_grad\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$ignore_index\n  );\n  let has_data_type_infer_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n}\n\ndef OneFlow_PowXGradOp : OneFlow_BaseOp<\"pow_x_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dz\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_PowYGradOp : OneFlow_BaseOp<\"pow_y_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dz\n  );\n  let output = (outs\n    OneFlow_Tensor:$dy\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_PreluGradOp : OneFlow_BaseOp<\"prelu_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$alpha\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx,\n    OneFlow_Tensor:$alpha_diff\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"true\">:$alpha_requires_grad\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_RandpermOp : OneFlow_BaseOp<\"randperm\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$n,\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed,\n    StrArrayAttr:$nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_RecvOp : OneFlow_BaseOp<\"recv\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$src_process_id,\n    OneFlow_DataType:$dtype,\n    ShapeAttr:$shape,\n    StrAttr:$device_type,\n    DefaultValuedAttr<SI64Attr, \"0\">:$device_id\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n}\n\ndef OneFlow_SendOp : OneFlow_BaseOp<\"send\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$dst_process_id\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n}\n\ndef OneFlow_SplitLikeOp : OneFlow_BaseOp<\"split_like\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    
Variadic<OneFlow_Tensor>:$like\n  );\n  let output = (outs\n    Variadic<OneFlow_Tensor>:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$axis\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_SspVariableProxyOp : OneFlow_BaseOp<\"ssp_variable_proxy\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$var\n  );\n  let output = (outs\n    OneFlow_Tensor:$ref,\n    OneFlow_Tensor:$value\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"1\">:$buffer_size\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_output_arg_modify_fn = 1;\n}\n\ndef OneFlow_TfPreluGradOp : OneFlow_BaseOp<\"tf_prelu_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$alpha\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx,\n    OneFlow_Tensor:$alpha_diff\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UniformOp : OneFlow_BaseOp<\"uniform\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$from,\n    DefaultValuedAttr<F64Attr, \"1.\">:$to,\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed,\n    OneFlow_DataType:$dtype,\n    ShapeAttr:$shape,\n    StrArrayAttr:$nd_sbp\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n  let has_dump_nd_sbp_signature_for_op_conf_fn = 1;\n}\n\ndef OneFlow_UniformIntOp : OneFlow_BaseOp<\"uniform_int\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$from,\n    DefaultValuedAttr<SI64Attr, \"1\">:$to,\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed,\n    OneFlow_DataType:$dtype,\n    ShapeAttr:$shape,\n    StrArrayAttr:$nd_sbp\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_ExponentialOp : OneFlow_BaseOp<\"exponential\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed,\n    DefaultValuedAttr<F32Attr, \"1.0\">:$lambd,\n    OneFlow_DataType:$dtype,\n    ShapeAttr:$out_shape,\n    StrArrayAttr:$nd_sbp\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_MultinomialWithReplacementOp : OneFlow_BaseOp<\"multinomial_with_replacement\", [NoMemoryEffect, NoGrad, 
DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    Optional<OneFlow_Tensor>:$prefix_sum\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed,\n    DefaultValuedAttr<SI32Attr, \"1\">:$num_samples,\n    DefaultValuedAttr<BoolAttr, \"true\">:$replacement\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UniqueOp : OneFlow_BaseOp<\"unique\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$idx,\n    OneFlow_Tensor:$num_unique\n  );\n  let attrs = (ins\n    OneFlow_DataType:$out_idx,\n    DefaultValuedAttr<BoolAttr, \"true\">:$sorted\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UniqueWithCountsOp : OneFlow_BaseOp<\"unique_with_counts\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$idx,\n    OneFlow_Tensor:$num_unique,\n    OneFlow_Tensor:$count\n  );\n  let attrs = (ins\n    OneFlow_DataType:$out_idx,\n    DefaultValuedAttr<BoolAttr, \"true\">:$sorted\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_XdivyXGradOp : OneFlow_BaseOp<\"xdivy_x_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dz\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_XdivyYGradOp : OneFlow_BaseOp<\"xdivy_y_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dz\n  );\n  let output = (outs\n    OneFlow_Tensor:$dy\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_StackOp : OneFlow_BaseOp<\"stack\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$axis,\n    DefaultValuedAttr<SI64Attr, \"0\">:$max_dim_size\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_StackGradOp : OneFlow_BaseOp<\"stack_grad\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    Variadic<OneFlow_Tensor>:$like\n  );\n  let output = (outs\n    Variadic<OneFlow_Tensor>:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$axis\n  );\n  let has_check_fn = 1;\n  let 
has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_FusedWeightedSumOp : OneFlow_BaseOp<\"fused_weighted_sum\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    F32ArrayAttr:$weights,\n    DefaultValuedAttr<F32Attr, \"1.0\">:$alpha\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_DependOp : OneFlow_BaseOp<\"depend\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$depend_tensor\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_MISC_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_NCCL_OP_DEFINITIONS\n\ndef OneFlow__ncclLogical_2DSameDim0All2allOp : OneFlow_BaseOp<\"_nccl_logical_2D_same_dim0_all2all\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrArrayAttr:$src_reduced_nd_sbp,\n    StrArrayAttr:$dst_reduced_nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow__ncclLogical_2DSameDim0AllGatherOp : OneFlow_BaseOp<\"_nccl_logical_2D_same_dim0_all_gather\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrArrayAttr:$src_reduced_nd_sbp,\n    StrArrayAttr:$dst_reduced_nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow__ncclLogical_2DSameDim0AllGatherNoncontinuousOp : OneFlow_BaseOp<\"_nccl_logical_2D_same_dim0_all_gather_noncontinuous\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrArrayAttr:$src_reduced_nd_sbp,\n    StrArrayAttr:$dst_reduced_nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow__ncclLogical_2DSameDim0AllReduceOp : OneFlow_BaseOp<\"_nccl_logical_2D_same_dim0_all_reduce\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrArrayAttr:$src_reduced_nd_sbp,\n    StrArrayAttr:$dst_reduced_nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn 
= 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow__ncclLogical_2DSameDim1AllReduceOp : OneFlow_BaseOp<\"_nccl_logical_2D_same_dim1_all_reduce\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrArrayAttr:$src_reduced_nd_sbp,\n    StrArrayAttr:$dst_reduced_nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow__ncclLogicalAllGatherOp : OneFlow_BaseOp<\"_nccl_logical_all_gather\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrArrayAttr:$src_reduced_nd_sbp,\n    StrArrayAttr:$dst_reduced_nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow__ncclLogicalAllGatherNoncontinuousOp : OneFlow_BaseOp<\"_nccl_logical_all_gather_noncontinuous\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrArrayAttr:$src_reduced_nd_sbp,\n    StrArrayAttr:$dst_reduced_nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow__ncclLogicalAllReduceOp : OneFlow_BaseOp<\"_nccl_logical_all_reduce\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrArrayAttr:$src_reduced_nd_sbp,\n    StrArrayAttr:$dst_reduced_nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow__ncclLogicalReduceScatterOp : OneFlow_BaseOp<\"_nccl_logical_reduce_scatter\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrArrayAttr:$src_reduced_nd_sbp,\n    StrArrayAttr:$dst_reduced_nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow__ncclLogicalReduceScatterNoncontinuousOp : OneFlow_BaseOp<\"_nccl_logical_reduce_scatter_noncontinuous\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrArrayAttr:$src_reduced_nd_sbp,\n    StrArrayAttr:$dst_reduced_nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow__ncclLogicalS2sOp : 
OneFlow_BaseOp<\"_nccl_logical_s2s\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrArrayAttr:$src_reduced_nd_sbp,\n    StrArrayAttr:$dst_reduced_nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow__ncclLogicalSendRecvOp : OneFlow_BaseOp<\"_nccl_logical_send_recv\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrArrayAttr:$src_reduced_nd_sbp,\n    StrArrayAttr:$dst_reduced_nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow__ncclLogicalFusionOp : OneFlow_BaseOp<\"_nccl_logical_fusion\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$in\n  );\n  let output = (outs\n    Variadic<OneFlow_Tensor>:$out\n  );\n  let attrs = (ins\n    StrArrayAttr:$src_nd_sbp_str_list,\n    StrArrayAttr:$dst_nd_sbp_str_list,\n    StrArrayAttr:$nccl_type_list\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_NCCL_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_NORMALIZATION_OP_DEFINITIONS\n\ndef OneFlow_NormalizationAddReluOp : OneFlow_BaseOp<\"normalization_add_relu\", [NoMemoryEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<InferTypeOpInterface, [\"refineReturnTypes\"]>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    Optional<OneFlow_Tensor>:$addend,\n    Optional<OneFlow_Tensor>:$moving_mean,\n    Optional<OneFlow_Tensor>:$moving_variance,\n    OneFlow_Tensor:$gamma,\n    OneFlow_Tensor:$beta\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$reserve_space,\n    Optional<OneFlow_Tensor>:$mean,\n    Optional<OneFlow_Tensor>:$inv_variance\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$axis,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon,\n    DefaultValuedAttr<BoolAttr, \"false\">:$training,\n    DefaultValuedAttr<F32Attr, \"0.\">:$momentum\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes,\n    DenseI32ArrayAttr:$result_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_BatchNormStatsOp : OneFlow_BaseOp<\"batch_norm_stats\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input\n  );\n  let output = (outs\n    OneFlow_Tensor:$mean,\n    OneFlow_Tensor:$invstd\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"1\">:$axis,\n    DefaultValuedAttr<F32Attr, \"0.00001\">:$eps\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let 
has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BatchNormGatherStatsWithCountsOp : OneFlow_BaseOp<\"batch_norm_gather_stats_with_counts\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$mean,\n    OneFlow_Tensor:$invstd,\n    OneFlow_Tensor:$counts,\n    Optional<OneFlow_Tensor>:$running_mean,\n    Optional<OneFlow_Tensor>:$running_var\n  );\n  let output = (outs\n    OneFlow_Tensor:$global_mean,\n    OneFlow_Tensor:$global_invstd\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.00001\">:$eps,\n    DefaultValuedAttr<F32Attr, \"0.9\">:$momentum\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_BatchNormElemtOp : OneFlow_BaseOp<\"batch_norm_elemt\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$weight,\n    OneFlow_Tensor:$bias,\n    OneFlow_Tensor:$mean,\n    OneFlow_Tensor:$invstd\n  );\n  let output = (outs\n    OneFlow_Tensor:$output\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"1\">:$axis,\n    DefaultValuedAttr<F32Attr, \"0.00001\">:$eps\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BatchNormBackwardReduceOp : OneFlow_BaseOp<\"batch_norm_backward_reduce\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$grad_out,\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$mean,\n    OneFlow_Tensor:$invstd\n  );\n  let output = (outs\n    OneFlow_Tensor:$sum_dy,\n    OneFlow_Tensor:$sum_dy_xmu,\n    OneFlow_Tensor:$grad_weight,\n    OneFlow_Tensor:$grad_bias\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"1\">:$axis\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BatchNormBackwardElemtOp : OneFlow_BaseOp<\"batch_norm_backward_elemt\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$grad_out,\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$mean,\n    OneFlow_Tensor:$invstd,\n    OneFlow_Tensor:$weight,\n    OneFlow_Tensor:$sum_dy,\n    OneFlow_Tensor:$sum_dy_xmu,\n    OneFlow_Tensor:$count\n  );\n  let output = (outs\n    OneFlow_Tensor:$grad_in\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"1\">:$axis\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CropMirrorNormalizeFromTensorbufferOp : OneFlow_BaseOp<\"crop_mirror_normalize_from_tensorbuffer\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    Optional<OneFlow_Tensor>:$mirror\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<StrAttr, \"\\\"BGR\\\"\">:$color_space,\n    DefaultValuedAttr<StrAttr, \"\\\"NCHW\\\"\">:$output_layout,\n    F32ArrayAttr:$mean,\n    F32ArrayAttr:$std,\n    
DefaultValuedAttr<SI64Attr, \"0\">:$crop_h,\n    DefaultValuedAttr<SI64Attr, \"0\">:$crop_w,\n    DefaultValuedAttr<F32Attr, \"0.5\">:$crop_pos_x,\n    DefaultValuedAttr<F32Attr, \"0.5\">:$crop_pos_y,\n    OneFlow_DataType:$output_dtype\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CropMirrorNormalizeFromUint8Op : OneFlow_BaseOp<\"crop_mirror_normalize_from_uint8\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    Optional<OneFlow_Tensor>:$mirror\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<StrAttr, \"\\\"BGR\\\"\">:$color_space,\n    DefaultValuedAttr<StrAttr, \"\\\"NCHW\\\"\">:$output_layout,\n    F32ArrayAttr:$mean,\n    F32ArrayAttr:$std,\n    DefaultValuedAttr<SI64Attr, \"0\">:$crop_h,\n    DefaultValuedAttr<SI64Attr, \"0\">:$crop_w,\n    DefaultValuedAttr<F32Attr, \"0.5\">:$crop_pos_x,\n    DefaultValuedAttr<F32Attr, \"0.5\">:$crop_pos_y,\n    OneFlow_DataType:$output_dtype\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ImageNormalizeOp : OneFlow_BaseOp<\"image_normalize\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    F32ArrayAttr:$std,\n    F32ArrayAttr:$mean\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_L2NormalizeOp : OneFlow_BaseOp<\"l2_normalize\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$square_x_sum\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$axis,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_L2NormalizeGradOp : OneFlow_BaseOp<\"l2_normalize_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$square_x_sum\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$axis,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LayerNormOp : OneFlow_BaseOp<\"layer_norm\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    Optional<OneFlow_Tensor>:$beta,\n    Optional<OneFlow_Tensor>:$gamma\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$mean,\n    OneFlow_Tensor:$inv_variance\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$center,\n    DefaultValuedAttr<BoolAttr, \"false\">:$scale,\n    DefaultValuedAttr<SI64Attr, \"0\">:$begin_norm_axis,\n    
DefaultValuedAttr<SI64Attr, \"0\">:$begin_params_axis,\n    DefaultValuedAttr<F64Attr, \"0.\">:$epsilon\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SkipLayerNormOp : OneFlow_BaseOp<\"skip_layer_norm\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    Optional<OneFlow_Tensor>:$gamma,\n    Optional<OneFlow_Tensor>:$beta,\n    Optional<OneFlow_Tensor>:$bias,\n    Optional<OneFlow_Tensor>:$skip\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$mean,\n    OneFlow_Tensor:$inv_variance\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.00001\">:$epsilon,\n    DefaultValuedAttr<F64Attr, \"1.0\">:$alpha\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LayerNormGradOp : OneFlow_BaseOp<\"layer_norm_grad\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$mean,\n    OneFlow_Tensor:$inv_variance,\n    Optional<OneFlow_Tensor>:$gamma,\n    Optional<OneFlow_Tensor>:$_add_to_output\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$begin_norm_axis,\n    DefaultValuedAttr<F64Attr, \"0.\">:$epsilon\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FuseLayerNormGradOp : OneFlow_BaseOp<\"fuse_layer_norm_grad\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$mean,\n    OneFlow_Tensor:$inv_variance,\n    Optional<OneFlow_Tensor>:$gamma,\n    Optional<OneFlow_Tensor>:$_add_to_output\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx,\n    OneFlow_Tensor:$gamma_diff,\n    OneFlow_Tensor:$beta_diff\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$begin_norm_axis,\n    DefaultValuedAttr<SI64Attr, \"0\">:$begin_params_axis,\n    DefaultValuedAttr<F64Attr, \"0.\">:$epsilon\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes,\n    DenseI32ArrayAttr:$result_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LayerNormParamGradOp : OneFlow_BaseOp<\"layer_norm_param_grad\", [NoMemoryEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$mean,\n    OneFlow_Tensor:$inv_variance\n  );\n  let output = (outs\n    Optional<OneFlow_Tensor>:$beta_diff,\n    Optional<OneFlow_Tensor>:$gamma_diff\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$begin_params_axis\n  );\n  let trait_attrs = (ins\n    
DenseI32ArrayAttr:$result_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_NormalOp : OneFlow_BaseOp<\"normal\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$mean,\n    DefaultValuedAttr<F64Attr, \"1.\">:$std,\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed,\n    OneFlow_DataType:$dtype,\n    ShapeAttr:$shape,\n    StrArrayAttr:$nd_sbp\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_NormalizationOp : OneFlow_NormalizationBaseOp<\"normalization\", [AttrSizedResultSegments, DeclareOpInterfaceMethods<NCHWCompatibleInterface>]> {\n  let hasCanonicalizer = 1;\n}\n\ndef OneFlow_NormalizationGradOp : OneFlow_BaseOp<\"normalization_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$mean,\n    OneFlow_Tensor:$inv_variance,\n    OneFlow_Tensor:$gamma\n  );\n  let output = (outs\n    OneFlow_Tensor:$gamma_diff,\n    OneFlow_Tensor:$beta_diff,\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$axis,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_GroupNormOp : OneFlow_BaseOp<\"group_norm\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<NCHWCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    Optional<OneFlow_Tensor>:$beta,\n    Optional<OneFlow_Tensor>:$gamma\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$mean,\n    OneFlow_Tensor:$inv_variance\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$affine,\n    DefaultValuedAttr<SI32Attr, \"0\">:$num_groups,\n    DefaultValuedAttr<F64Attr, \"0.\">:$epsilon,\n    DefaultValuedAttr<StrAttr, \"\\\"channels_first\\\"\">:$data_format,\n    DefaultValuedAttr<StrAttr, \"\\\"none\\\"\">:$activation\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_GroupNormGradOp : OneFlow_BaseOp<\"group_norm_grad\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$mean,\n    OneFlow_Tensor:$inv_variance,\n    Optional<OneFlow_Tensor>:$gamma\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$num_groups,\n    DefaultValuedAttr<F64Attr, \"0.\">:$epsilon\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_GroupNormParamGradOp : 
OneFlow_BaseOp<\"group_norm_param_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$mean,\n    OneFlow_Tensor:$inv_variance\n  );\n  let output = (outs\n    OneFlow_Tensor:$dgamma,\n    OneFlow_Tensor:$dbeta\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_RmsNormOp : OneFlow_BaseOp<\"rms_norm\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    Optional<OneFlow_Tensor>:$weight\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$inv_rms\n  );\n  let attrs = (ins\n    ShapeAttr:$normalized_shape,\n    DefaultValuedAttr<F32Attr, \"0.00001\">:$epsilon\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_RmsNormParamGradOp : OneFlow_BaseOp<\"rms_norm_param_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$inv_rms\n  );\n  let output = (outs\n    OneFlow_Tensor:$weight_grad\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_RmsNormGradOp : OneFlow_BaseOp<\"rms_norm_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$inv_rms,\n    Optional<OneFlow_Tensor>:$weight\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SkipRmsNormOp : OneFlow_BaseOp<\"skip_rms_norm\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    Optional<OneFlow_Tensor>:$weight,\n    Optional<OneFlow_Tensor>:$bias,\n    Optional<OneFlow_Tensor>:$skip\n  );\n  let output = (outs\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$inv_rms\n  );\n  let attrs = (ins\n    ShapeAttr:$normalized_shape,\n    DefaultValuedAttr<F32Attr, \"0.00001\">:$epsilon,\n    DefaultValuedAttr<F32Attr, \"1.0\">:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_NORMALIZATION_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_OPTIMIZER_OP_DEFINITIONS\n\ndef OneFlow_AdagradUpdateOp : OneFlow_BaseOp<\"adagrad_update\", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$model,\n    OneFlow_Tensor:$model_diff,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if,\n    Optional<OneFlow_Tensor>:$train_step,\n    OneFlow_Tensor:$sum\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$train_step_val,\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    
DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$lr_decay,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_AdamBiasCorrectionFactorOp : OneFlow_BaseOp<\"adam_bias_correction_factor\", [NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$train_step\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.9\">:$beta\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_AdamUpdateOp : OneFlow_BaseOp<\"adam_update\", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$model,\n    OneFlow_Tensor:$model_diff,\n    Optional<OneFlow_Tensor>:$model_copy,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if,\n    Optional<OneFlow_Tensor>:$bias_correction1,\n    Optional<OneFlow_Tensor>:$bias_correction2,\n    OneFlow_Tensor:$m,\n    OneFlow_Tensor:$v,\n    Optional<OneFlow_Tensor>:$max_v\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F32Attr, \"1.\">:$bias_correction1_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$bias_correction2_val,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.9\">:$beta1,\n    DefaultValuedAttr<F32Attr, \"0.999\">:$beta2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay,\n    DefaultValuedAttr<BoolAttr, \"false\">:$amsgrad,\n    DefaultValuedAttr<BoolAttr, \"true\">:$do_bias_correction\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_IndexedSlicesAdamUpdateOp : OneFlow_BaseOp<\"indexed_slices_adam_update\", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$model,\n    OneFlow_Tensor:$model_diff_indices,\n    OneFlow_Tensor:$model_diff_values,\n    OneFlow_Tensor:$learning_rate,\n    Optional<OneFlow_Tensor>:$bias_correction1,\n    Optional<OneFlow_Tensor>:$bias_correction2,\n    OneFlow_Tensor:$m,\n    OneFlow_Tensor:$v,\n    Optional<OneFlow_Tensor>:$max_v\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F32Attr, \"0.9\">:$beta1,\n    DefaultValuedAttr<F32Attr, \"0.999\">:$beta2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay,\n    DefaultValuedAttr<BoolAttr, 
\"false\">:$amsgrad,\n    DefaultValuedAttr<BoolAttr, \"true\">:$do_bias_correction\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_IndexedSlicesMomentumUpdateOp : OneFlow_BaseOp<\"indexed_slices_momentum_update\", [NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$model,\n    OneFlow_Tensor:$model_diff_indices,\n    OneFlow_Tensor:$model_diff_values,\n    OneFlow_Tensor:$learning_rate,\n    OneFlow_Tensor:$momentum\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F32Attr, \"0.9\">:$beta,\n    DefaultValuedAttr<F32Attr, \"0.0\">:$dampening,\n    DefaultValuedAttr<BoolAttr, \"false\">:$nesterov,\n    DefaultValuedAttr<BoolAttr, \"false\">:$maximize,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_IndexedSlicesSgdUpdateOp : OneFlow_BaseOp<\"indexed_slices_sgd_update\", [NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$model,\n    OneFlow_Tensor:$model_diff_indices,\n    OneFlow_Tensor:$model_diff_values,\n    OneFlow_Tensor:$learning_rate\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_LambUpdateOp : OneFlow_BaseOp<\"lamb_update\", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$model,\n    OneFlow_Tensor:$model_diff,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if,\n    Optional<OneFlow_Tensor>:$bias_correction1,\n    Optional<OneFlow_Tensor>:$bias_correction2,\n    OneFlow_Tensor:$m,\n    OneFlow_Tensor:$v\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F32Attr, \"1.\">:$bias_correction1_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$bias_correction2_val,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.9\">:$beta1,\n    DefaultValuedAttr<F32Attr, \"0.999\">:$beta2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay,\n    DefaultValuedAttr<BoolAttr, \"true\">:$do_bias_correction\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_LarsUpdateOp : OneFlow_BaseOp<\"lars_update\", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  
let input = (ins\n    OneFlow_Tensor:$model,\n    OneFlow_Tensor:$model_diff,\n    OneFlow_Tensor:$learning_rate,\n    OneFlow_Tensor:$momentum,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.9\">:$momentum_beta,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon,\n    DefaultValuedAttr<F32Attr, \"0.0001\">:$lars_coefficient,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_MomentumUpdateOp : OneFlow_BaseOp<\"momentum_update\", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$model,\n    OneFlow_Tensor:$model_diff,\n    OneFlow_Tensor:$momentum,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.9\">:$beta,\n    DefaultValuedAttr<F32Attr, \"0.0\">:$dampening,\n    DefaultValuedAttr<BoolAttr, \"false\">:$nesterov,\n    DefaultValuedAttr<BoolAttr, \"false\">:$maximize,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_RmspropUpdateOp : OneFlow_BaseOp<\"rmsprop_update\", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$model,\n    OneFlow_Tensor:$model_diff,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if,\n    OneFlow_Tensor:$mean_square,\n    Optional<OneFlow_Tensor>:$mean_gradient\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<BoolAttr, \"false\">:$centered,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon,\n    DefaultValuedAttr<F32Attr, \"0.99\">:$decay_rate,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_SgdUpdateOp : OneFlow_BaseOp<\"sgd_update\", [NoGrad, AttrSizedOperandSegments, 
DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$model,\n    OneFlow_Tensor:$model_diff,\n    Optional<OneFlow_Tensor>:$model_copy,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_FtrlUpdateOp : OneFlow_BaseOp<\"ftrl_update\", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$model,\n    OneFlow_Tensor:$model_diff,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$skip_if,\n    OneFlow_Tensor:$accumulate,\n    OneFlow_Tensor:$z\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay,\n    DefaultValuedAttr<F32Attr, \"0.\">:$lr_power,\n    DefaultValuedAttr<F32Attr, \"0.\">:$lambda1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$lambda2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$beta\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_AdadeltaUpdateOp : OneFlow_BaseOp<\"adadelta_update\", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$model,\n    OneFlow_Tensor:$model_diff,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$skip_if,\n    OneFlow_Tensor:$square_avgs,\n    OneFlow_Tensor:$acc_deltas\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay,\n    DefaultValuedAttr<F32Attr, \"0.9\">:$rho,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon,\n    DefaultValuedAttr<BoolAttr, \"false\">:$maximize\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_MultiTensorSgdUpdateOp : OneFlow_BaseOp<\"multi_tensor_sgd_update\", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$model,\n    
Variadic<OneFlow_Tensor>:$model_diff,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_MultiTensorMomentumUpdateOp : OneFlow_BaseOp<\"multi_tensor_momentum_update\", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$model,\n    Variadic<OneFlow_Tensor>:$model_diff,\n    Variadic<OneFlow_Tensor>:$momentum_buf,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay,\n    DefaultValuedAttr<F32Attr, \"0.\">:$momentum,\n    DefaultValuedAttr<F32Attr, \"0.\">:$dampening,\n    DefaultValuedAttr<BoolAttr, \"false\">:$nesterov,\n    DefaultValuedAttr<BoolAttr, \"false\">:$maximize\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_MultiTensorAdamUpdateOp : OneFlow_BaseOp<\"multi_tensor_adam_update\", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$model,\n    Variadic<OneFlow_Tensor>:$model_diff,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if,\n    Optional<OneFlow_Tensor>:$bias_correction1,\n    Optional<OneFlow_Tensor>:$bias_correction2,\n    Variadic<OneFlow_Tensor>:$m,\n    Variadic<OneFlow_Tensor>:$v\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F32Attr, \"1.\">:$bias_correction1_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$bias_correction2_val,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.9\">:$beta1,\n    DefaultValuedAttr<F32Attr, \"0.999\">:$beta2,\n    DefaultValuedAttr<F32Attr, \"0.00001\">:$epsilon,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay,\n    DefaultValuedAttr<BoolAttr, \"false\">:$amsgrad,\n    DefaultValuedAttr<BoolAttr, \"true\">:$do_bias_correction\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let 
has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_MultiTensorSgdUpdateWithCastOp : OneFlow_BaseOp<\"multi_tensor_sgd_update_with_cast\", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$model,\n    Variadic<OneFlow_Tensor>:$model_diff,\n    Variadic<OneFlow_Tensor>:$model_copy,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_MultiTensorMomentumUpdateWithCastOp : OneFlow_BaseOp<\"multi_tensor_momentum_update_with_cast\", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$model,\n    Variadic<OneFlow_Tensor>:$model_diff,\n    Variadic<OneFlow_Tensor>:$model_copy,\n    Variadic<OneFlow_Tensor>:$momentum_buf,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay,\n    DefaultValuedAttr<F32Attr, \"0.\">:$momentum,\n    DefaultValuedAttr<F32Attr, \"0.\">:$dampening,\n    DefaultValuedAttr<BoolAttr, \"false\">:$nesterov,\n    DefaultValuedAttr<BoolAttr, \"false\">:$maximize\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_MultiTensorAdamUpdateWithCastOp : OneFlow_BaseOp<\"multi_tensor_adam_update_with_cast\", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$model,\n    Variadic<OneFlow_Tensor>:$model_diff,\n    Variadic<OneFlow_Tensor>:$model_copy,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if,\n    Optional<OneFlow_Tensor>:$bias_correction1,\n    Optional<OneFlow_Tensor>:$bias_correction2,\n    Variadic<OneFlow_Tensor>:$m,\n    Variadic<OneFlow_Tensor>:$v\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$learning_rate_scale,\n    DefaultValuedAttr<F32Attr, \"1.\">:$bias_correction1_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$bias_correction2_val,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    
DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.9\">:$beta1,\n    DefaultValuedAttr<F32Attr, \"0.999\">:$beta2,\n    DefaultValuedAttr<F32Attr, \"0.00001\">:$epsilon,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay,\n    DefaultValuedAttr<BoolAttr, \"false\">:$amsgrad,\n    DefaultValuedAttr<BoolAttr, \"true\">:$do_bias_correction\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_MultiTensorYoloV5WeightUpdateOp : OneFlow_BaseOp<\"multi_tensor_yolov5_weight_update\", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    Variadic<OneFlow_Tensor>:$model,\n    Variadic<OneFlow_Tensor>:$model_update\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$d\n  );\n  let trait_attrs = (ins\n    DenseI32ArrayAttr:$operand_segment_sizes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\n#endif // GET_ONEFLOW_OPTIMIZER_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_PADDING_OP_DEFINITIONS\n\n\ndef OneFlow_PadOp : OneFlow_BaseOp<\"pad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    SI64ArrayAttr:$padding_before,\n    SI64ArrayAttr:$padding_after,\n    SI64ArrayAttr:$padding,\n    DefaultValuedAttr<F64Attr, \"0.\">:$floating_constant_value,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integral_constant_value\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReplicationPad1DOp : OneFlow_BaseOp<\"replication_pad1d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    SI64ArrayAttr:$padding\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_ReplicationPad1DGradOp : OneFlow_BaseOp<\"replication_pad1d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    SI64ArrayAttr:$padding\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReflectionPad1DOp : OneFlow_BaseOp<\"reflection_pad1d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    SI64ArrayAttr:$padding\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_ReflectionPad1DGradOp 
: OneFlow_BaseOp<\"reflection_pad1d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    SI64ArrayAttr:$padding\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReflectionPad2DOp : OneFlow_BaseOp<\"reflection_pad2d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    SI64ArrayAttr:$padding\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_ReflectionPad2DGradOp : OneFlow_BaseOp<\"reflection_pad2d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    SI64ArrayAttr:$padding\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReplicationPad2DOp : OneFlow_BaseOp<\"replication_pad2d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    SI64ArrayAttr:$padding\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_ReplicationPad2DGradOp : OneFlow_BaseOp<\"replication_pad2d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    SI64ArrayAttr:$padding\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SamePaddingOp : OneFlow_BaseOp<\"same_padding\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    StrAttr:$padding,\n    StrAttr:$data_format,\n    SI32ArrayAttr:$kernel_size,\n    SI32ArrayAttr:$strides,\n    SI32ArrayAttr:$dilation_rate\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SamePaddingGradOp : OneFlow_BaseOp<\"same_padding_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x_like,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    StrAttr:$padding,\n    StrAttr:$data_format,\n    SI32ArrayAttr:$kernel_size,\n    SI32ArrayAttr:$strides,\n    SI32ArrayAttr:$dilation_rate\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // 
GET_ONEFLOW_PADDING_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_PARALLEL_CAST_OP_DEFINITIONS\n\ndef OneFlow_HierarchicalParallelCastOp : OneFlow_BaseOp<\"hierarchical_parallel_cast\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrArrayAttr:$nd_sbp,\n    StrAttr:$grad_mode,\n    StrArrayAttr:$grad_nd_sbp\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n  let has_get_nd_sbp_fn = 1;\n}\n\ndef OneFlow_HierarchicalParallelCastLikeOp : OneFlow_BaseOp<\"hierarchical_parallel_cast_like\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$like\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_ParallelCastOp : OneFlow_BaseOp<\"parallel_cast\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$sbp_parallel,\n    StrAttr:$grad_sbp_parallel\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_sbp_signature_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_PARALLEL_CAST_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_POOL_OP_DEFINITIONS\n\ndef OneFlow_AdaptiveAvgPool1DOp : OneFlow_AdaptivePoolBaseOp<\"adaptive_avg_pool1d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_AdaptiveAvgPool1DGradOp : OneFlow_AdaptivePoolGradBaseOp<\"adaptive_avg_pool1d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_AdaptiveAvgPool2DOp : OneFlow_AdaptivePoolBaseOp<\"adaptive_avg_pool2d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_AdaptiveAvgPool2DGradOp : OneFlow_AdaptivePoolGradBaseOp<\"adaptive_avg_pool2d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_AdaptiveAvgPool3DOp : OneFlow_AdaptivePoolBaseOp<\"adaptive_avg_pool3d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_AdaptiveAvgPool3DGradOp : OneFlow_AdaptivePoolGradBaseOp<\"adaptive_avg_pool3d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_AdaptiveMaxPool1DOp : OneFlow_AdaptiveMaxPoolBaseOp<\"adaptive_max_pool1d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\ndef OneFlow_AdaptiveMaxPool1DGradOp : OneFlow_AdaptiveMaxPoolGradBaseOp<\"adaptive_max_pool1d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_AdaptiveMaxPool2DOp : OneFlow_AdaptiveMaxPoolBaseOp<\"adaptive_max_pool2d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\ndef OneFlow_AdaptiveMaxPool2DGradOp : OneFlow_AdaptiveMaxPoolGradBaseOp<\"adaptive_max_pool2d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_AdaptiveMaxPool3DOp : 
OneFlow_AdaptiveMaxPoolBaseOp<\"adaptive_max_pool3d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\ndef OneFlow_AdaptiveMaxPool3DGradOp : OneFlow_AdaptiveMaxPoolGradBaseOp<\"adaptive_max_pool3d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_AvgPool1DOp : OneFlow_AvgPoolBaseOp<\"avg_pool_1d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_AvgPool1DGradOp : OneFlow_AvgPoolGradBaseOp<\"avg_pool_1d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_AvgPool2DOp : OneFlow_AvgPoolBaseOp<\"avg_pool_2d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_AvgPool2DGradOp : OneFlow_AvgPoolGradBaseOp<\"avg_pool_2d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_AvgPool3DOp : OneFlow_AvgPoolBaseOp<\"avg_pool_3d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_AvgPool3DGradOp : OneFlow_AvgPoolGradBaseOp<\"avg_pool_3d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_MaxPool1DOp : OneFlow_MaxPoolBaseOp<\"max_pool_1d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_MaxPool1DGradOp : OneFlow_MaxPoolGradBaseOp<\"max_pool_1d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_MaxPool2DOp : OneFlow_MaxPoolBaseOp<\"max_pool_2d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<NCHWCompatibleInterface>]> {}\n\ndef OneFlow_MaxPool2DGradOp : OneFlow_MaxPoolGradBaseOp<\"max_pool_2d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_MaxPool3DOp : OneFlow_MaxPoolBaseOp<\"max_pool_3d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_MaxPool3DGradOp : OneFlow_MaxPoolGradBaseOp<\"max_pool_3d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_MaxUnpool1DOp : OneFlow_MaxUnpoolBaseOp<\"max_unpool_1d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_MaxUnpool2DOp : OneFlow_MaxUnpoolBaseOp<\"max_unpool_2d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_MaxUnpool3DOp : OneFlow_MaxUnpoolBaseOp<\"max_unpool_3d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_MaxUnpool1DGradOp : OneFlow_MaxUnpoolGradBaseOp<\"max_unpool_1d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_MaxUnpool2DGradOp : OneFlow_MaxUnpoolGradBaseOp<\"max_unpool_2d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_MaxUnpool3DGradOp : OneFlow_MaxUnpoolGradBaseOp<\"max_unpool_3d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_TfAvgPool1DOp : OneFlow_TFPoolBaseOp<\"tf_avg_pool_1d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_TfAvgPool1DGradOp : OneFlow_TFPoolGradBaseOp<\"tf_avg_pool_1d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_TfAvgPool2DOp : OneFlow_TFPoolBaseOp<\"tf_avg_pool_2d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef 
OneFlow_TfAvgPool2DGradOp : OneFlow_TFPoolGradBaseOp<\"tf_avg_pool_2d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_TfAvgPool3DOp : OneFlow_TFPoolBaseOp<\"tf_avg_pool_3d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_TfAvgPool3DGradOp : OneFlow_TFPoolGradBaseOp<\"tf_avg_pool_3d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_TfMaxPool1DOp : OneFlow_TFPoolBaseOp<\"tf_max_pool_1d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_TfMaxPool1DGradOp : OneFlow_TFPoolGradBaseOp<\"tf_max_pool_1d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_TfMaxPool2DOp : OneFlow_TFPoolBaseOp<\"tf_max_pool_2d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_TfMaxPool2DGradOp : OneFlow_TFPoolGradBaseOp<\"tf_max_pool_2d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_TfMaxPool3DOp : OneFlow_TFPoolBaseOp<\"tf_max_pool_3d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\ndef OneFlow_TfMaxPool3DGradOp : OneFlow_TFPoolGradBaseOp<\"tf_max_pool_3d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}\n\n#endif // GET_ONEFLOW_POOL_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_QUANTIZATION_OP_DEFINITIONS\n\ndef OneFlow_FakeQuantizationOp : OneFlow_BaseOp<\"fake_quantization\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$scale,\n    OneFlow_Tensor:$zero_point\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<StrAttr, \"\\\"google\\\"\">:$quantization_formula,\n    DefaultValuedAttr<SI32Attr, \"8\">:$quantization_bit,\n    DefaultValuedAttr<StrAttr, \"\\\"symmetric\\\"\">:$quantization_scheme\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_MinMaxObserverOp : OneFlow_BaseOp<\"min_max_observer\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$scale,\n    OneFlow_Tensor:$zero_point\n  );\n  let attrs = (ins\n    DefaultValuedAttr<StrAttr, \"\\\"google\\\"\">:$quantization_formula,\n    DefaultValuedAttr<SI32Attr, \"8\">:$quantization_bit,\n    DefaultValuedAttr<StrAttr, \"\\\"symmetric\\\"\">:$quantization_scheme,\n    DefaultValuedAttr<BoolAttr, \"true\">:$per_layer_quantization\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_MovingAverageMinMaxObserverOp : OneFlow_BaseOp<\"moving_average_min_max_observer\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$current_train_step,\n    OneFlow_Tensor:$moving_max,\n    OneFlow_Tensor:$moving_min\n  );\n  let output = (outs\n    OneFlow_Tensor:$scale,\n    OneFlow_Tensor:$zero_point\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, 
\"false\">:$training,\n    DefaultValuedAttr<StrAttr, \"\\\"google\\\"\">:$quantization_formula,\n    DefaultValuedAttr<SI64Attr, \"0\">:$stop_update_after_iters,\n    DefaultValuedAttr<SI32Attr, \"8\">:$quantization_bit,\n    DefaultValuedAttr<StrAttr, \"\\\"symmetric\\\"\">:$quantization_scheme,\n    DefaultValuedAttr<F32Attr, \"0.95\">:$momentum\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_QuantizationOp : OneFlow_BaseOp<\"quantization\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$scale,\n    OneFlow_Tensor:$zero_point\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<StrAttr, \"\\\"google\\\"\">:$quantization_formula,\n    DefaultValuedAttr<SI32Attr, \"8\">:$quantization_bit,\n    DefaultValuedAttr<StrAttr, \"\\\"symmetric\\\"\">:$quantization_scheme\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_GroupwiseDequantizeOp : OneFlow_BaseOp<\"groupwise_dequantize\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$scale,\n    Optional<OneFlow_Tensor>:$zero\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"8\">:$num_bits,\n    DefaultValuedAttr<BoolAttr, \"true\">:$symmetric,\n    SI64Attr:$group_dim,\n    SI64Attr:$group_size\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FusedLinearWithGroupwiseQuantizedWeightOp : OneFlow_BaseOp<\"fused_linear_with_groupwise_quantized_weight\", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$w,\n    OneFlow_Tensor:$w_scale,\n    Optional<OneFlow_Tensor>:$w_zero,\n    Optional<OneFlow_Tensor>:$b\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"8\">:$num_bits,\n    DefaultValuedAttr<BoolAttr, \"true\">:$symmetric,\n    SI64Attr:$group_dim,\n    SI64Attr:$group_size\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_QUANTIZATION_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_REDUCE_OP_DEFINITIONS\n\ndef OneFlow_IndexedSlicesReduceSumOp : OneFlow_BaseOp<\"indexed_slices_reduce_sum\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x_indices,\n    OneFlow_Tensor:$x_values\n  );\n  let output = (outs\n    OneFlow_Tensor:$y_indices,\n    OneFlow_Tensor:$y_values,\n    OneFlow_Tensor:$num_unique\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReduceAllOp : OneFlow_BaseOp<\"reduce_all\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> 
{\n  let input = (ins\n    OneFlow_Tensor:$input_tensor\n  );\n  let output = (outs\n    OneFlow_Tensor:$output_tensor\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axis,\n    DefaultValuedAttr<BoolAttr, \"false\">:$keepdims\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReduceAnyOp : OneFlow_BaseOp<\"reduce_any\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input_tensor\n  );\n  let output = (outs\n    OneFlow_Tensor:$output_tensor\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axis,\n    DefaultValuedAttr<BoolAttr, \"false\">:$keepdims\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReduceMaxOp : OneFlow_BaseOp<\"reduce_max\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input_tensor\n  );\n  let output = (outs\n    OneFlow_Tensor:$output_tensor\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axis,\n    DefaultValuedAttr<BoolAttr, \"false\">:$keepdims\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReduceMaxDeviceStageOp : OneFlow_BaseOp<\"reduce_max_device_stage\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out,\n    OneFlow_Tensor:$mask,\n    OneFlow_Tensor:$count\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axis\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReduceMaxDeviceStageGradOp : OneFlow_BaseOp<\"reduce_max_device_stage_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$out_diff,\n    OneFlow_Tensor:$mask,\n    OneFlow_Tensor:$count\n  );\n  let output = (outs\n    OneFlow_Tensor:$in_diff\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axis\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReduceMaxGlobalStageOp : OneFlow_BaseOp<\"reduce_max_global_stage\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$device_count\n  );\n  let output = (outs\n    OneFlow_Tensor:$out,\n    OneFlow_Tensor:$mask\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axis,\n    DefaultValuedAttr<BoolAttr, \"false\">:$keepdims\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_ReduceMaxGlobalStageGradOp : OneFlow_BaseOp<\"reduce_max_global_stage_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$out_diff,\n    OneFlow_Tensor:$mask,\n    OneFlow_Tensor:$device_count\n  );\n  let output = (outs\n    OneFlow_Tensor:$in_diff\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axis,\n    DefaultValuedAttr<BoolAttr, \"false\">:$keepdims\n 
 );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReduceMinOp : OneFlow_BaseOp<\"reduce_min\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input_tensor\n  );\n  let output = (outs\n    OneFlow_Tensor:$output_tensor\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axis,\n    DefaultValuedAttr<BoolAttr, \"false\">:$keepdims\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReduceMinDeviceStageOp : OneFlow_BaseOp<\"reduce_min_device_stage\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out,\n    OneFlow_Tensor:$mask,\n    OneFlow_Tensor:$count\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axis\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReduceMinDeviceStageGradOp : OneFlow_BaseOp<\"reduce_min_device_stage_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$out_diff,\n    OneFlow_Tensor:$mask,\n    OneFlow_Tensor:$count\n  );\n  let output = (outs\n    OneFlow_Tensor:$in_diff\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axis\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReduceMinGlobalStageOp : OneFlow_BaseOp<\"reduce_min_global_stage\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$device_count\n  );\n  let output = (outs\n    OneFlow_Tensor:$out,\n    OneFlow_Tensor:$mask\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axis,\n    DefaultValuedAttr<BoolAttr, \"false\">:$keepdims\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_ReduceMinGlobalStageGradOp : OneFlow_BaseOp<\"reduce_min_global_stage_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$out_diff,\n    OneFlow_Tensor:$mask,\n    OneFlow_Tensor:$device_count\n  );\n  let output = (outs\n    OneFlow_Tensor:$in_diff\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axis,\n    DefaultValuedAttr<BoolAttr, \"false\">:$keepdims\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReduceProdOp : OneFlow_BaseOp<\"reduce_prod\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input_tensor\n  );\n  let output = (outs\n    OneFlow_Tensor:$output_tensor\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axis,\n    DefaultValuedAttr<BoolAttr, \"false\">:$keepdims\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReduceSumOp : OneFlow_BaseOp<\"reduce_sum\", 
[NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input_tensor\n  );\n  let output = (outs\n    OneFlow_Tensor:$output_tensor\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axis,\n    DefaultValuedAttr<BoolAttr, \"false\">:$keepdims\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReduceNanSumOp : OneFlow_BaseOp<\"reduce_nansum\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input_tensor\n  );\n  let output = (outs\n    OneFlow_Tensor:$output_tensor\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axis,\n    DefaultValuedAttr<BoolAttr, \"false\">:$keepdims\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ReduceSumLikeOp : OneFlow_BaseOp<\"reduce_sum_like\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$like\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axis\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\n#endif // GET_ONEFLOW_REDUCE_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_RESHAPE_OP_DEFINITIONS\n\ndef OneFlow_ReshapeOp : OneFlow_BaseOp<\"reshape\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    ShapeAttr:$shape\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_enumerate_nd_sbp_signatures_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let hasFolder = 1;\n}\n\ndef OneFlow_ReshapeLikeOp : OneFlow_BaseOp<\"reshape_like\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$like\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\n#endif // GET_ONEFLOW_RESHAPE_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_SCALAR_OP_DEFINITIONS\n\ndef OneFlow_ClipByScalarOp : OneFlow_BaseOp<\"clip_by_scalar\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$floating_min,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integral_min,\n    DefaultValuedAttr<F64Attr, \"0.\">:$floating_max,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integral_max\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ClipByScalarGradOp : OneFlow_BaseOp<\"clip_by_scalar_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    
OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$floating_min,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integral_min,\n    DefaultValuedAttr<F64Attr, \"0.\">:$floating_max,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integral_max\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ClipByScalarMaxOp : OneFlow_BaseOp<\"clip_by_scalar_max\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$floating_max,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integral_max\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ClipByScalarMaxGradOp : OneFlow_BaseOp<\"clip_by_scalar_max_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$floating_max,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integral_max\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ClipByScalarMinOp : OneFlow_BaseOp<\"clip_by_scalar_min\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$floating_min,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integral_min\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ClipByScalarMinGradOp : OneFlow_BaseOp<\"clip_by_scalar_min_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$floating_min,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integral_min\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarAddOp : OneFlow_BaseOp<\"scalar_add\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let hasFolder = 1;\n}\n\ndef OneFlow_ScalarAddByTensorOp : OneFlow_BaseOp<\"scalar_add_by_tensor\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$scalar\n  
);\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n// host_scalar_add_by_tensor op exists only for testing host memory input\ndef OneFlow_HostScalarAddByTensorOp : OneFlow_BaseOp<\"host_scalar_add_by_tensor\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$scalar\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarDivByTensorOp : OneFlow_BaseOp<\"scalar_div_by_tensor\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$scalar\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarFloordivOp : OneFlow_BaseOp<\"scalar_floordiv\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarTruncdivOp : OneFlow_BaseOp<\"scalar_truncdiv\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarFmodOp : OneFlow_BaseOp<\"scalar_fmod\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarLogicalAndOp : OneFlow_BaseOp<\"scalar_logical_and\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    
DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarLogicalEqualOp : OneFlow_BaseOp<\"scalar_logical_equal\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarLogicalGreaterOp : OneFlow_BaseOp<\"scalar_logical_greater\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarLogicalGreaterEqualOp : OneFlow_BaseOp<\"scalar_logical_greater_equal\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarLogicalInplaceGreaterOp : OneFlow_BaseOp<\"scalar_logical_inplace_greater\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarLogicalLessOp : OneFlow_BaseOp<\"scalar_logical_less\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  
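// The has_*_fn fields below opt the op into the generated shape, SBP and data-type inference hooks; the same quartet recurs on every scalar op in this section.\n  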
let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarLogicalLessEqualOp : OneFlow_BaseOp<\"scalar_logical_less_equal\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarLogicalNotEqualOp : OneFlow_BaseOp<\"scalar_logical_not_equal\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarLogicalOrOp : OneFlow_BaseOp<\"scalar_logical_or\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarLogicalXorOp : OneFlow_BaseOp<\"scalar_logical_xor\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarMulOp : OneFlow_BaseOp<\"scalar_mul\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef 
OneFlow_ScalarMulByTensorOp : OneFlow_BaseOp<\"scalar_mul_by_tensor\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$scalar\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarDivOp : OneFlow_BaseOp<\"scalar_div\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<NCHWCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarPowOp : OneFlow_BaseOp<\"scalar_pow\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarPowGradOp : OneFlow_BaseOp<\"scalar_pow_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarReversePowOp : OneFlow_BaseOp<\"scalar_reverse_pow\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarReversePowGradOp : OneFlow_BaseOp<\"scalar_reverse_pow_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, 
\"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarSubByTensorOp : OneFlow_BaseOp<\"scalar_sub_by_tensor\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$scalar\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarLerpOp : OneFlow_BaseOp<\"scalar_lerp\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$start,\n    OneFlow_Tensor:$end\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarLerpGradOp : OneFlow_BaseOp<\"scalar_lerp_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$start,\n    OneFlow_Tensor:$end,\n    OneFlow_Tensor:$out_diff\n  );\n  let output = (outs\n    OneFlow_Tensor:$start_diff,\n    OneFlow_Tensor:$end_diff\n  );\n  let attrs = (ins\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_int_operand,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_float_operand,\n    DefaultValuedAttr<SI64Attr, \"0\">:$int_operand,\n    DefaultValuedAttr<F64Attr, \"0.\">:$float_operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarBitwiseAndOp : OneFlow_BaseOp<\"scalar_bitwise_and\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarBitwiseOrOp : OneFlow_BaseOp<\"scalar_bitwise_or\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ScalarBitwiseXorOp : OneFlow_BaseOp<\"scalar_bitwise_xor\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$operand\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let 
has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_SCALAR_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_SOFTMAX_OP_DEFINITIONS\n\ndef OneFlow_LogSoftmaxOp : OneFlow_BaseOp<\"log_softmax\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$prob\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LogSoftmaxGradOp : OneFlow_BaseOp<\"log_softmax_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$prob,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SoftmaxOp : OneFlow_BaseOp<\"softmax\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_compute_complexity_fn = 1;\n}\n\ndef OneFlow_SoftmaxCrossEntropyOp : OneFlow_BaseOp<\"softmax_cross_entropy\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$prediction,\n    OneFlow_Tensor:$label\n  );\n  let output = (outs\n    OneFlow_Tensor:$prob,\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_SoftmaxCrossEntropyGradOp : OneFlow_BaseOp<\"softmax_cross_entropy_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$label,\n    OneFlow_Tensor:$prob\n  );\n  let output = (outs\n    OneFlow_Tensor:$prediction_diff\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SoftmaxGradOp : OneFlow_BaseOp<\"softmax_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SparseSoftmaxCrossEntropyOp : OneFlow_BaseOp<\"sparse_softmax_cross_entropy\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$prediction,\n    OneFlow_Tensor:$label\n  );\n  let output = (outs\n    OneFlow_Tensor:$prob,\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$depth\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\n
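// Backward of sparse_softmax_cross_entropy: consumes the saved prob together with dy and label to produce prediction_diff.\ndef OneFlow_SparseSoftmaxCrossEntropyGradOp : 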
OneFlow_BaseOp<\"sparse_softmax_cross_entropy_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$label,\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$prob\n  );\n  let output = (outs\n    OneFlow_Tensor:$prediction_diff\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$depth\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SparseSoftmaxCrossEntropyMsOp : OneFlow_BaseOp<\"sparse_softmax_cross_entropy_ms\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$prediction,\n    OneFlow_Tensor:$label\n  );\n  let output = (outs\n    OneFlow_Tensor:$prob,\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$depth\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_SparseSoftmaxCrossEntropyMsGradOp : OneFlow_BaseOp<\"sparse_softmax_cross_entropy_ms_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$label,\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$prob\n  );\n  let output = (outs\n    OneFlow_Tensor:$prediction_diff\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$depth\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_SOFTMAX_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_SUMMARY_OP_DEFINITIONS\n\ndef OneFlow_CreateSummaryWriterOp : OneFlow_BaseOp<\"create_summary_writer\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let attrs = (ins\n    StrAttr:$logdir\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FlushSummaryWriterOp : OneFlow_BaseOp<\"flush_summary_writer\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SummaryWriteHistogramOp : OneFlow_BaseOp<\"summary_write_histogram\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$step,\n    OneFlow_Tensor:$tag\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SummaryWriteImageOp : OneFlow_BaseOp<\"summary_write_image\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$step,\n    OneFlow_Tensor:$tag\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SummaryWritePbOp : OneFlow_BaseOp<\"summary_write_pb\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n  
  OneFlow_Tensor:$in,\n    OneFlow_Tensor:$step\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SummaryWriteScalarOp : OneFlow_BaseOp<\"summary_write_scalar\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$step,\n    OneFlow_Tensor:$tag\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_SUMMARY_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_TENSOR_BUFFER_OP_DEFINITIONS\n\ndef OneFlow_GenTensorBufferOp : OneFlow_BaseOp<\"gen_tensor_buffer\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    ShapeAttr:$shape,\n    ShapeArrayAttr:$shape_list,\n    F32ArrayAttr:$value_list,\n    OneFlow_DataType:$data_type,\n    DefaultValuedAttr<BoolAttr, \"false\">:$dynamic_out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TensorBufferToListOfTensorsOp : OneFlow_BaseOp<\"tensor_buffer_to_list_of_tensors\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    Variadic<OneFlow_Tensor>:$out\n  );\n  let attrs = (ins\n    ShapeAttr:$out_shape,\n    OneFlow_DataType:$out_dtype,\n    DefaultValuedAttr<BoolAttr, \"false\">:$dynamic_out\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_output_arg_modify_fn = 1;\n}\n\ndef OneFlow_TensorBufferToListOfTensorsV2Op : OneFlow_BaseOp<\"tensor_buffer_to_list_of_tensors_v2\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    Variadic<OneFlow_Tensor>:$out\n  );\n  let attrs = (ins\n    ShapeArrayAttr:$out_shapes,\n    DTArrayAttr:$out_dtypes,\n    DefaultValuedAttr<BoolAttr, \"false\">:$dynamic_out\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_output_arg_modify_fn = 1;\n}\n\ndef OneFlow_TensorBufferToTensorOp : OneFlow_BaseOp<\"tensor_buffer_to_tensor\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    ShapeAttr:$instance_shape,\n    OneFlow_DataType:$dtype\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TensorToTensorBufferOp : OneFlow_BaseOp<\"tensor_to_tensor_buffer\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$instance_dims\n  );\n  let 
has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_TENSOR_BUFFER_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_TEST_OP_DEFINITIONS\n\ndef OneFlow_ThrowErrorOp : OneFlow_BaseOp<\"throw_error\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_TEST_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_TRIGONOMETRIC_OP_DEFINITIONS\n\ndef OneFlow_AcosOp : OneFlow_BaseOp<\"acos\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_AcosGradOp : OneFlow_BaseOp<\"acos_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_AcoshOp : OneFlow_BaseOp<\"acosh\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_AcoshGradOp : OneFlow_BaseOp<\"acosh_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_AsinOp : OneFlow_BaseOp<\"asin\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_AsinGradOp : OneFlow_BaseOp<\"asin_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_AsinhOp : OneFlow_BaseOp<\"asinh\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n
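// Each trig op pairs with a *_grad op mapping (x, dy) -> dx; for asinh the math is dx = dy / sqrt(x^2 + 1).\ndef OneFlow_AsinhGradOp : OneFlow_BaseOp<\"asinh_grad\", [NoMemoryEffect, 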
DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_AtanOp : OneFlow_BaseOp<\"atan\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_Atan2Op : OneFlow_BaseOp<\"atan2\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y\n  );\n  let output = (outs\n    OneFlow_Tensor:$z\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_Atan2XGradOp : OneFlow_BaseOp<\"atan2_x_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dz\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_Atan2YGradOp : OneFlow_BaseOp<\"atan2_y_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dz\n  );\n  let output = (outs\n    OneFlow_Tensor:$dy\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_AtanGradOp : OneFlow_BaseOp<\"atan_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_AtanhOp : OneFlow_BaseOp<\"atanh\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_AtanhGradOp : OneFlow_BaseOp<\"atanh_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CosOp : OneFlow_BaseOp<\"cos\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn 
= 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CosGradOp : OneFlow_BaseOp<\"cos_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CoshOp : OneFlow_BaseOp<\"cosh\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CoshGradOp : OneFlow_BaseOp<\"cosh_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_HardtanhOp : OneFlow_BaseOp<\"hardtanh\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$min_val,\n    DefaultValuedAttr<F64Attr, \"0.\">:$max_val\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_HardtanhGradOp : OneFlow_BaseOp<\"hardtanh_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$min_val,\n    DefaultValuedAttr<F64Attr, \"0.\">:$max_val\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SinOp : OneFlow_BaseOp<\"sin\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SinGradOp : OneFlow_BaseOp<\"sin_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SinhOp : OneFlow_BaseOp<\"sinh\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SinhGradOp : OneFlow_BaseOp<\"sinh_grad\", [NoMemoryEffect, 
DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TanOp : OneFlow_BaseOp<\"tan\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TanGradOp : OneFlow_BaseOp<\"tan_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TanhOp : OneFlow_BaseOp<\"tanh\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TanhGradOp : OneFlow_BaseOp<\"tanh_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$y,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_NotEqualZeroOp : OneFlow_BaseOp<\"not_equal_zero\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_TRIGONOMETRIC_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_UNARY_OP_DEFINITIONS\n\ndef OneFlow_AccOp : OneFlow_BaseOp<\"acc\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$max_acc_num\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_output_blob_time_shape_infer_fn = 1;\n}\n\ndef OneFlow_AccCtrlTickOp : OneFlow_BaseOp<\"acc_ctrl_tick\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$max_acc_num\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n  let has_output_blob_time_shape_infer_fn = 1;\n}\n\n
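// affine_grid builds a sampling grid from batched affine matrices theta for the given output size (cf. torch.nn.functional.affine_grid).\ndef OneFlow_AffineGridOp : OneFlow_BaseOp<\"affine_grid\", 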
[NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$theta\n  );\n  let output = (outs\n    OneFlow_Tensor:$grid\n  );\n  let attrs = (ins\n    ShapeAttr:$size,\n    DefaultValuedAttr<BoolAttr, \"false\">:$align_corners\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_AffineGridGradOp : OneFlow_BaseOp<\"affine_grid_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dgrid\n  );\n  let output = (outs\n    OneFlow_Tensor:$dtheta\n  );\n  let attrs = (ins\n    ShapeAttr:$size,\n    DefaultValuedAttr<BoolAttr, \"false\">:$align_corners\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_BernoulliOp : OneFlow_BaseOp<\"bernoulli\", [NoMemoryEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    OneFlow_DataType:$dtype,\n    DefaultValuedAttr<SI64Attr, \"-1\">:$seed,\n    DefaultValuedAttr<F64Attr, \"0.\">:$p\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CastOp : OneFlow_BaseOp<\"cast\", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<NCHWCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    OneFlow_DataType:$dtype,\n    DefaultValuedAttr<BoolAttr, \"false\">:$pin_memory\n  );\n  let has_device_and_stream_infer_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_MutableCastOnceOp : OneFlow_BaseOp<\"mutable_cast_once\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    OneFlow_DataType:$dtype\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let same_output_regst_num = 1;\n}\n\ndef OneFlow_CastToStaticShapeOp : OneFlow_BaseOp<\"cast_to_static_shape\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input\n  );\n  let output = (outs\n    OneFlow_Tensor:$output\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CastToTickOp : OneFlow_BaseOp<\"cast_to_tick\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_CeluOp : 
OneFlow_BaseOp<\"celu\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CopyOp : OneFlow_BaseOp<\"copy\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$device_type,\n    DefaultValuedAttr<SI64Attr, \"0\">:$device_id,\n    DefaultValuedAttr<BoolAttr, \"false\">:$pin_memory\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_device_and_stream_infer_fn = 1;\n}\n\ndef OneFlow_CountNotFiniteOp : OneFlow_BaseOp<\"count_not_finite\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_DiagOp : OneFlow_BaseOp<\"diag\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$diagonal\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_DiagonalOp : OneFlow_BaseOp<\"diagonal\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$offset\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_EluOp : OneFlow_BaseOp<\"elu\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ExpandOp : OneFlow_BaseOp<\"expand\", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    ShapeAttr:$expand_shape\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ExpandDimsOp : OneFlow_BaseOp<\"expand_dims\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    
DefaultValuedAttr<SI32Attr, \"0\">:$axis\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FlipOp : OneFlow_BaseOp<\"flip\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$dims\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FoldOp : OneFlow_BaseOp<\"fold\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    StrAttr:$data_format,\n    SI32ArrayAttr:$output_size,\n    SI32ArrayAttr:$kernel_size,\n    SI32ArrayAttr:$strides,\n    SI32ArrayAttr:$padding,\n    SI32ArrayAttr:$dilation_rate\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_GeluOp : OneFlow_BaseOp<\"gelu\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FastGeluOp : OneFlow_BaseOp<\"fast_gelu\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_QuickGeluOp : OneFlow_BaseOp<\"quick_gelu\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SquareReLUOp : OneFlow_BaseOp<\"square_relu\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_HardsigmoidOp : OneFlow_BaseOp<\"hardsigmoid\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_HardShrinkOp : OneFlow_BaseOp<\"hardshrink\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$lambd\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let 
has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_HardswishOp : OneFlow_BaseOp<\"hardswish\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LeakyReluOp : OneFlow_BaseOp<\"leaky_relu\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_RReluOp : OneFlow_BaseOp<\"rrelu\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$output,\n    OneFlow_Tensor:$noise_data\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed,\n    DefaultValuedAttr<F32Attr, \"0.125\">:$lower,\n    DefaultValuedAttr<F32Attr, \"0.3333333333333333\">:$upper,\n    DefaultValuedAttr<BoolAttr, \"false\">:$training\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_Log2Op : OneFlow_BaseOp<\"log2\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_Log10Op : OneFlow_BaseOp<\"log10\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LogicalNotOp : OneFlow_BaseOp<\"logical_not\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_MishOp : OneFlow_BaseOp<\"mish\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_NarrowOp : OneFlow_BaseOp<\"narrow\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$dim,\n    DefaultValuedAttr<SI64Attr, \"0\">:$start,\n    DefaultValuedAttr<SI64Attr, 
\"0\">:$length\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_OneHotOp : OneFlow_BaseOp<\"one_hot\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$indices\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$depth,\n    DefaultValuedAttr<F64Attr, \"0.\">:$floating_on_value,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integer_on_value,\n    DefaultValuedAttr<F64Attr, \"0.\">:$floating_off_value,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integer_off_value,\n    OneFlow_DataType:$dtype\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_PackOp : OneFlow_BaseOp<\"pack\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$pack_num\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_output_blob_time_shape_infer_fn = 1;\n}\n\ndef OneFlow_RandomMaskLikeOp : OneFlow_BaseOp<\"random_mask_like\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$like\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$rate,\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let hasCanonicalizer = 1;\n}\n\ndef OneFlow_RepeatOp : OneFlow_BaseOp<\"repeat\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$repeat_num\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_output_blob_time_shape_infer_fn = 1;\n}\n\ndef OneFlow_Repeat_InterLeaveOp : OneFlow_BaseOp<\"repeat_interleave\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$cumsum\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$repeat_num\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_RollOp : OneFlow_BaseOp<\"roll\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$shifts,\n    SI32ArrayAttr:$dims\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef 
OneFlow_SeluOp : OneFlow_BaseOp<\"selu\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SiluOp : OneFlow_BaseOp<\"silu\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<NCHWCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SoftShrinkOp: OneFlow_BaseOp<\"softshrink\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SoftsignOp : OneFlow_BaseOp<\"softsign\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SortOp : OneFlow_BaseOp<\"sort\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    StrAttr:$direction\n  );\n  let has_check_fn = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SquareSumOp : OneFlow_BaseOp<\"square_sum\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SqrtSquareSumOp : OneFlow_BaseOp<\"sqrt_square_sum\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_SqueezeOp : OneFlow_BaseOp<\"squeeze\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$axes\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ThresholdOp : OneFlow_BaseOp<\"threshold\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output 
= (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$threshold_val,\n    DefaultValuedAttr<F64Attr, \"0.\">:$value\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TransposeOp : OneFlow_BaseOp<\"transpose\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input\n  );\n  let output = (outs\n    OneFlow_Tensor:$output\n  );\n  let attrs = (ins\n    SI32ArrayAttr:$perm\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let hasFolder = 1;\n}\n\ndef OneFlow_AsStridedOp : OneFlow_BaseOp<\"as_strided\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input\n  );\n  let output = (outs\n    OneFlow_Tensor:$output\n  );\n  let attrs = (ins\n    SI64ArrayAttr:$size,\n    SI64ArrayAttr:$stride,\n    DefaultValuedAttr<SI64Attr, \"0\">:$storage_offset\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_IndexAddOp : OneFlow_BaseOp<\"index_add\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$index,\n    OneFlow_Tensor:$source\n  );\n  let output = (outs\n    OneFlow_Tensor:$output\n  );\n  let attrs = (ins\n    SI64Attr:$dim,\n    F32Attr:$alpha\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_AsStridedGradOp : OneFlow_BaseOp<\"as_strided_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$input\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    SI64ArrayAttr:$size,\n    SI64ArrayAttr:$stride,\n    DefaultValuedAttr<SI64Attr, \"0\">:$storage_offset\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TrilOp : OneFlow_BaseOp<\"tril\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$diagonal,\n    DefaultValuedAttr<F64Attr, \"0.\">:$floating_fill_value,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integer_fill_value,\n    DefaultValuedAttr<BoolAttr, \"false\">:$is_floating_fill_value\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TriuOp : OneFlow_BaseOp<\"triu\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$diagonal\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TruncOp : 
OneFlow_BaseOp<\"trunc\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_TruncGradOp : OneFlow_BaseOp<\"trunc_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x,\n    OneFlow_Tensor:$dy\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UnfoldOp : OneFlow_BaseOp<\"unfold\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    StrAttr:$data_format,\n    SI32ArrayAttr:$kernel_size,\n    SI32ArrayAttr:$padding,\n    SI32ArrayAttr:$strides,\n    SI32ArrayAttr:$dilation_rate\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UnfoldTensorOp : OneFlow_BaseOp<\"unfold_tensor\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$dimension,\n    DefaultValuedAttr<SI32Attr, \"0\">:$size,\n    DefaultValuedAttr<SI32Attr, \"0\">:$step\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UnpackOp : OneFlow_BaseOp<\"unpack\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"0\">:$unpack_num\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_output_blob_time_shape_infer_fn = 1;\n}\n\ndef OneFlow_ZeroLikeOp : OneFlow_BaseOp<\"zero_like\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$like\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_nd_sbp_infer_fn = 1;\n}\n\ndef OneFlow_ToContiguousOp : OneFlow_BaseOp<\"to_contiguous\", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ConvertMemoryFormatOp : OneFlow_BaseOp<\"convert_memory_format\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    
OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    OneFlow_MemoryFormat:$memory_format\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_IsNanOp : OneFlow_BaseOp<\"isnan\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_IsInfOp : OneFlow_BaseOp<\"isinf\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_IsFiniteOp : OneFlow_BaseOp<\"isfinite\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_RealOp : OneFlow_BaseOp<\"real\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_RealGradOp : OneFlow_BaseOp<\"real_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dout\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ImagOp : OneFlow_BaseOp<\"imag\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ImagGradOp : OneFlow_BaseOp<\"imag_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dout\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_ConjPhysicalOp : OneFlow_BaseOp<\"conj_physical\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_UNARY_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_UPSAMPLE_OP_DEFINITIONS\n\ndef OneFlow_UpsampleBicubic2DOp : OneFlow_BaseOp<\"upsample_bicubic_2d\", 
[NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$height_scale,\n    DefaultValuedAttr<F64Attr, \"0.\">:$width_scale,\n    DefaultValuedAttr<BoolAttr, \"false\">:$align_corners,\n    SI64ArrayAttr:$output_size,\n    StrAttr:$data_format\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UpsampleBicubic2DGradOp : OneFlow_BaseOp<\"upsample_bicubic_2d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$height_scale,\n    DefaultValuedAttr<F64Attr, \"0.\">:$width_scale,\n    DefaultValuedAttr<BoolAttr, \"false\">:$align_corners,\n    SI64ArrayAttr:$output_size,\n    StrAttr:$data_format\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UpsampleBilinear2DOp : OneFlow_BaseOp<\"upsample_bilinear_2d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$height_scale,\n    DefaultValuedAttr<F64Attr, \"0.\">:$width_scale,\n    DefaultValuedAttr<BoolAttr, \"false\">:$align_corners,\n    SI64ArrayAttr:$output_size,\n    StrAttr:$data_format\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UpsampleBilinear2DGradOp : OneFlow_BaseOp<\"upsample_bilinear_2d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$height_scale,\n    DefaultValuedAttr<F64Attr, \"0.\">:$width_scale,\n    DefaultValuedAttr<BoolAttr, \"false\">:$align_corners,\n    SI64ArrayAttr:$output_size,\n    StrAttr:$data_format\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UpsampleLinear1DOp : OneFlow_BaseOp<\"upsample_linear_1d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$scale_factor,\n    DefaultValuedAttr<BoolAttr, \"false\">:$align_corners,\n    SI64ArrayAttr:$output_size,\n    StrAttr:$data_format\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UpsampleLinear1DGradOp : OneFlow_BaseOp<\"upsample_linear_1d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    
DefaultValuedAttr<F64Attr, \"0.\">:$scale_factor,\n    DefaultValuedAttr<BoolAttr, \"false\">:$align_corners,\n    SI64ArrayAttr:$output_size,\n    StrAttr:$data_format\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UpsampleNearest1DOp : OneFlow_BaseOp<\"upsample_nearest_1d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$scale_factor,\n    SI64ArrayAttr:$output_size,\n    StrAttr:$data_format\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UpsampleNearest1DGradOp : OneFlow_BaseOp<\"upsample_nearest_1d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$scale_factor,\n    SI64ArrayAttr:$output_size,\n    StrAttr:$data_format\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UpsampleNearest2DOp : OneFlow_BaseOp<\"upsample_nearest_2d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$height_scale,\n    DefaultValuedAttr<F64Attr, \"0.\">:$width_scale,\n    SI64ArrayAttr:$output_size,\n    StrAttr:$data_format\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UpsampleNearest2DGradOp : OneFlow_BaseOp<\"upsample_nearest_2d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$height_scale,\n    DefaultValuedAttr<F64Attr, \"0.\">:$width_scale,\n    SI64ArrayAttr:$output_size,\n    StrAttr:$data_format\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UpsampleNearest3DOp : OneFlow_BaseOp<\"upsample_nearest_3d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$depth_scale,\n    DefaultValuedAttr<F64Attr, \"0.\">:$height_scale,\n    DefaultValuedAttr<F64Attr, \"0.\">:$width_scale,\n    SI64ArrayAttr:$output_size,\n    StrAttr:$data_format\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UpsampleNearest3DGradOp : OneFlow_BaseOp<\"upsample_nearest_3d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    
OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$depth_scale,\n    DefaultValuedAttr<F64Attr, \"0.\">:$height_scale,\n    DefaultValuedAttr<F64Attr, \"0.\">:$width_scale,\n    SI64ArrayAttr:$output_size,\n    StrAttr:$data_format\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UpsampleTrilinear3DOp : OneFlow_BaseOp<\"upsample_trilinear_3d\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$depth_scale,\n    DefaultValuedAttr<F64Attr, \"0.\">:$height_scale,\n    DefaultValuedAttr<F64Attr, \"0.\">:$width_scale,\n    DefaultValuedAttr<BoolAttr, \"false\">:$align_corners,\n    SI64ArrayAttr:$output_size,\n    StrAttr:$data_format\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UpsampleTrilinear3DGradOp : OneFlow_BaseOp<\"upsample_trilinear_3d_grad\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$dy,\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$dx\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$depth_scale,\n    DefaultValuedAttr<F64Attr, \"0.\">:$height_scale,\n    DefaultValuedAttr<F64Attr, \"0.\">:$width_scale,\n    DefaultValuedAttr<BoolAttr, \"false\">:$align_corners,\n    SI64ArrayAttr:$output_size,\n    StrAttr:$data_format\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n#endif // GET_ONEFLOW_UPSAMPLE_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_ONE_EMBEDDING_OP_DEFINITIONS\n\ndef OneFlow_OneEmbeddingFusedLookupOp : OneFlow_BaseOp<\"one_embedding_fused_lookup\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$shadow,\n    OneFlow_Tensor:$ids,\n    Optional<OneFlow_Tensor>:$table_ids\n  );\n  let output = (outs\n    OneFlow_Tensor:$embeddings\n  );\n  let attrs = (ins\n    OneFlow_DataType:$dtype,\n    StrAttr:$embedding_name,\n    DefaultValuedAttr<SI64Attr, \"0\">:$line_size,\n    DefaultValuedAttr<SI64Attr, \"0\">:$embedding_size,\n    DefaultValuedAttr<BoolAttr, \"false\">:$is_full_cache,\n    DefaultValuedAttr<SI32Attr, \"1\">:$num_tables,\n    DefaultValuedAttr<SI64Attr, \"-1\">:$padding_idx,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_padding_idx,\n    StrAttr:$embedding_tables,\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n  let has_input_arg_modify_fn = 1;\n}\n\ndef OneFlow_OneEmbeddingFusedLookupGradOp : OneFlow_BaseOp<\"one_embedding_fused_lookup_grad\", [DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$ids,\n    OneFlow_Tensor:$embedding_grad\n  );\n  let attrs = (ins\n    StrAttr:$embedding_name,\n    DefaultValuedAttr<SI64Attr, \"0\">:$line_size,\n    DefaultValuedAttr<SI64Attr, \"0\">:$embedding_size\n  );\n  let 
has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_UniqueKeyValuePairOp : OneFlow_BaseOp<\"unique_key_value_pair\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$keys,\n    Optional<OneFlow_Tensor>:$values\n  );\n  let output = (outs\n    OneFlow_Tensor:$num_unique,\n    OneFlow_Tensor:$unique_keys,\n    OneFlow_Tensor:$unique_values,\n    OneFlow_Tensor:$inverse_indices\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"1\">:$num_tables,\n    DefaultValuedAttr<SI64Attr, \"-1\">:$padding_idx,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_padding_idx,\n    StrAttr:$embedding_name\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_IdShuffleOp : OneFlow_BaseOp<\"id_shuffle\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$ids,\n    Optional<OneFlow_Tensor>:$table_ids\n  );\n  let output = (outs\n    OneFlow_Tensor:$num_unique_matrix,\n    OneFlow_Tensor:$inverse_unique_partition_indices,\n    OneFlow_Tensor:$cur_rank_num_unique,\n    OneFlow_Tensor:$cur_rank_unique_ids,\n    OneFlow_Tensor:$cur_rank_unique_table_ids,\n    OneFlow_Tensor:$cur_rank_inverse_indices\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI32Attr, \"1\">:$num_tables,\n    DefaultValuedAttr<SI64Attr, \"-1\">:$padding_idx,\n    DefaultValuedAttr<BoolAttr, \"false\">:$has_padding_idx,\n    StrAttr:$embedding_name\n  );\n  let same_output_regst_num = 2;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_IdShuffleCopyOutOp : OneFlow_BaseOp<\"id_shuffle_copy_out\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$num_unique_matrix,\n    OneFlow_Tensor:$inverse_unique_partition_indices,\n    OneFlow_Tensor:$cur_rank_num_unique,\n    OneFlow_Tensor:$cur_rank_unique_ids,\n    OneFlow_Tensor:$cur_rank_unique_table_ids,\n    OneFlow_Tensor:$cur_rank_inverse_indices\n  );\n  let output = (outs\n    OneFlow_Tensor:$out_num_unique_matrix,\n    OneFlow_Tensor:$out_inverse_unique_partition_indices,\n    OneFlow_Tensor:$out_cur_rank_num_unique,\n    OneFlow_Tensor:$out_cur_rank_unique_ids,\n    OneFlow_Tensor:$out_cur_rank_unique_table_ids,\n    OneFlow_Tensor:$out_cur_rank_inverse_indices\n  );\n  let attrs = (ins\n    StrAttr:$embedding_name\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_OneEmbeddingGatherOp : OneFlow_BaseOp<\"one_embedding_gather\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$indices\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$embedding_size,\n    StrAttr:$embedding_name\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef 
OneFlow_EmbeddingShuffleOp : OneFlow_BaseOp<\"embedding_shuffle\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$cur_rank_embeddings,\n    OneFlow_Tensor:$num_unique_matrix,\n    OneFlow_Tensor:$cur_rank_inverse_indices,\n    OneFlow_Tensor:$inverse_unique_partition_indices\n  );\n  let output = (outs\n    OneFlow_Tensor:$embeddings\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$embedding_size,\n    DefaultValuedAttr<BoolAttr, \"false\">:$skip_last_gather,\n    DefaultValuedAttr<BoolAttr, \"false\">:$is_train,\n    StrAttr:$embedding_name\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_EmbeddingGradientShuffleOp : OneFlow_BaseOp<\"embedding_gradient_shuffle\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$embedding_grad,\n    OneFlow_Tensor:$num_unique_matrix,\n    OneFlow_Tensor:$cur_rank_inverse_indices,\n    OneFlow_Tensor:$inverse_unique_partition_indices\n  );\n  let output = (outs\n    OneFlow_Tensor:$cur_rank_unique_embedding_grad\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$embedding_size,\n    DefaultValuedAttr<BoolAttr, \"false\">:$only_zero_valid_grad,\n    DefaultValuedAttr<BoolAttr, \"false\">:$skip_first_scatter,\n    StrAttr:$embedding_name\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_EmbeddingPrefetchOp : OneFlow_BaseOp<\"embedding_prefetch\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$num_unique_ids,\n    OneFlow_Tensor:$unique_ids,\n    OneFlow_Tensor:$table_ids\n  );\n  let output = (outs\n    OneFlow_Tensor:$context // has no practical meaning; it only ensures that lookup runs after prefetch.\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$line_size,\n    DefaultValuedAttr<SI64Attr, \"0\">:$embedding_size,\n    StrAttr:$embedding_name,\n    StrAttr:$embedding_tables,\n    StrAttr:$state_initializer,\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_EmbeddingLookupOp : OneFlow_BaseOp<\"embedding_lookup\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$num_unique_ids,\n    OneFlow_Tensor:$unique_ids,\n    OneFlow_Tensor:$table_ids,\n    Optional<OneFlow_Tensor>:$context\n  );\n  let output = (outs\n    OneFlow_Tensor:$unique_values,\n    Optional<OneFlow_Tensor>:$embeddings\n  );\n  let attrs = (ins\n    OneFlow_DataType:$dtype,\n    OneFlow_DataType:$embeddings_dtype,\n    DefaultValuedAttr<SI64Attr, \"0\">:$line_size,\n    DefaultValuedAttr<SI64Attr, \"0\">:$embedding_size,\n    StrAttr:$embedding_name,\n    StrAttr:$embedding_tables,\n    StrAttr:$state_initializer,\n    DefaultValuedAttr<SI64Attr, \"0\">:$seed\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef 
OneFlow_OneEmbeddingFusedSgdUpdatePutOp : OneFlow_BaseOp<\"one_embedding_fused_sgd_update_put\", [DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$num_unique_ids,\n    OneFlow_Tensor:$unique_ids,\n    OneFlow_Tensor:$unique_embeddings,\n    OneFlow_Tensor:$embedding_grad,\n    OneFlow_Tensor:$learning_rate\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<SI64Attr, \"0\">:$line_size,\n    DefaultValuedAttr<SI64Attr, \"0\">:$embedding_size,\n    StrAttr:$embedding_name\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_OneEmbeddingSgdUpdateOp : OneFlow_BaseOp<\"one_embedding_sgd_update\", [AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$num_unique_ids,\n    OneFlow_Tensor:$unique_embeddings,\n    OneFlow_Tensor:$embedding_grad,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$down_scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if\n  );\n  let output = (outs\n    OneFlow_Tensor:$updated_unique_embeddings\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay,\n    DefaultValuedAttr<SI64Attr, \"0\">:$line_size,\n    DefaultValuedAttr<SI64Attr, \"0\">:$embedding_size,\n    StrAttr:$embedding_name\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_OneEmbeddingMomentumUpdateOp : OneFlow_BaseOp<\"one_embedding_momentum_update\", [AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$num_unique_ids,\n    OneFlow_Tensor:$unique_embeddings,\n    OneFlow_Tensor:$embedding_grad,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$down_scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if\n  );\n  let output = (outs\n    OneFlow_Tensor:$updated_unique_embeddings\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay,\n    DefaultValuedAttr<F32Attr, \"0.9\">:$beta,\n    DefaultValuedAttr<SI64Attr, \"0\">:$line_size,\n    DefaultValuedAttr<SI64Attr, \"0\">:$embedding_size,\n    StrAttr:$embedding_name\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_OneEmbeddingAdamUpdateOp : OneFlow_BaseOp<\"one_embedding_adam_update\", [AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$num_unique_ids,\n    OneFlow_Tensor:$unique_embeddings,\n    OneFlow_Tensor:$embedding_grad,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    
Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$down_scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if,\n    Optional<OneFlow_Tensor>:$bias_correction1,\n    Optional<OneFlow_Tensor>:$bias_correction2\n  );\n  let output = (outs\n    OneFlow_Tensor:$updated_unique_embeddings\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$bias_correction1_val,\n    DefaultValuedAttr<F32Attr, \"1.\">:$bias_correction2_val,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay,\n    DefaultValuedAttr<F32Attr, \"0.9\">:$beta1,\n    DefaultValuedAttr<F32Attr, \"0.999\">:$beta2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon,\n    DefaultValuedAttr<BoolAttr, \"true\">:$do_bias_correction,\n    DefaultValuedAttr<SI64Attr, \"0\">:$line_size,\n    DefaultValuedAttr<SI64Attr, \"0\">:$embedding_size,\n    StrAttr:$embedding_name\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_OneEmbeddingSmartDecaySparseAdamUpdateOp : OneFlow_BaseOp<\"one_embedding_smart_decay_sparse_adam_update\", [AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$num_unique_ids,\n    OneFlow_Tensor:$unique_embeddings,\n    OneFlow_Tensor:$embedding_grad,\n    Optional<OneFlow_Tensor>:$train_step,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$down_scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if\n  );\n  let output = (outs\n    OneFlow_Tensor:$updated_unique_embeddings\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$train_step_val,\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay,\n    DefaultValuedAttr<F32Attr, \"0.9\">:$beta1,\n    DefaultValuedAttr<F32Attr, \"0.999\">:$beta2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon,\n    DefaultValuedAttr<BoolAttr, \"true\">:$do_bias_correction,\n    DefaultValuedAttr<SI64Attr, \"0\">:$line_size,\n    DefaultValuedAttr<SI64Attr, \"0\">:$embedding_size,\n    StrAttr:$embedding_name\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_OneEmbeddingAdagradUpdateOp : OneFlow_BaseOp<\"one_embedding_adagrad_update\", [AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$num_unique_ids,\n    OneFlow_Tensor:$unique_embeddings,\n    OneFlow_Tensor:$embedding_grad,\n    Optional<OneFlow_Tensor>:$train_step,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$down_scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if\n  );\n  let output = (outs\n    OneFlow_Tensor:$updated_unique_embeddings\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$train_step_val,\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    
DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay,\n    DefaultValuedAttr<F32Attr, \"0.\">:$lr_decay,\n    DefaultValuedAttr<F32Attr, \"0.\">:$epsilon,\n    DefaultValuedAttr<SI64Attr, \"0\">:$line_size,\n    DefaultValuedAttr<SI64Attr, \"0\">:$embedding_size,\n    StrAttr:$embedding_name\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_EmbeddingPutOp : OneFlow_BaseOp<\"embedding_put\", [DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$num_unique_ids,\n    OneFlow_Tensor:$unique_ids,\n    OneFlow_Tensor:$unique_embeddings\n  );\n  let attrs = (ins\n    DefaultValuedAttr<SI64Attr, \"0\">:$line_size,\n    StrAttr:$embedding_name\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_OneEmbeddingFtrlUpdateOp : OneFlow_BaseOp<\"one_embedding_ftrl_update\", [AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$num_unique_ids,\n    OneFlow_Tensor:$unique_embeddings,\n    OneFlow_Tensor:$embedding_grad,\n    Optional<OneFlow_Tensor>:$learning_rate,\n    Optional<OneFlow_Tensor>:$scale_by_tensor,\n    Optional<OneFlow_Tensor>:$down_scale_by_tensor,\n    Optional<OneFlow_Tensor>:$skip_if\n  );\n  let output = (outs\n    OneFlow_Tensor:$updated_unique_embeddings\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F32Attr, \"0.\">:$learning_rate_val,\n    DefaultValuedAttr<F64Attr, \"1.\">:$scale,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$l2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$weight_decay,\n    DefaultValuedAttr<F32Attr, \"0.\">:$lr_power,\n    DefaultValuedAttr<F32Attr, \"0.\">:$lambda1,\n    DefaultValuedAttr<F32Attr, \"0.\">:$lambda2,\n    DefaultValuedAttr<F32Attr, \"0.\">:$beta,\n    DefaultValuedAttr<SI64Attr, \"0\">:$line_size,\n    DefaultValuedAttr<SI64Attr, \"0\">:$embedding_size,\n    StrAttr:$embedding_name\n  );\n  let same_output_regst_num = 1;\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_RocAucScoreOp : OneFlow_BaseOp<\"roc_auc_score\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$label,\n    OneFlow_Tensor:$pred\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FillOp : OneFlow_BaseOp<\"fill_\", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$floating_value,\n    DefaultValuedAttr<SI64Attr, \"0.\">:$integral_value,\n    DefaultValuedAttr<BoolAttr, \"false\">:$is_floating_value\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let 
has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_FillTensorOp : OneFlow_BaseOp<\"fill_tensor_\", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$in,\n    OneFlow_Tensor:$value\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_ONE_EMBEDDING_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_LINEAR_ALGEBRA_OP_DEFINITIONS\n\ndef OneFlow_InvOp : OneFlow_BaseOp<\"inv\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LinalgCrossOp : OneFlow_BaseOp<\"linalg_cross\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$input,\n    OneFlow_Tensor:$other\n  );\n  let attrs = (ins\n    SI64Attr:$dim\n  );\n  let output = (outs\n    OneFlow_Tensor:$out\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_DetOp : OneFlow_BaseOp<\"det\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$y\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_LUDecompositionOp : OneFlow_BaseOp<\"lu_decomposition\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n    OneFlow_Tensor:$x\n  );\n  let output = (outs\n    OneFlow_Tensor:$LU,\n    OneFlow_Tensor:$pivot\n  );\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_LINEAR_ALGEBRA_OP_DEFINITIONS\n\n\n#ifdef GET_ONEFLOW_SYSTEM_OP_DEFINITIONS\n\ndef OneFlow_CopyH2DOp : OneFlow_BaseOp<\"copy_h2d\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n      OneFlow_Tensor:$in\n  );\n  let output = (outs\n      OneFlow_Tensor:$out\n  );\n\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\ndef OneFlow_CopyD2HOp : OneFlow_BaseOp<\"copy_d2h\", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins\n      OneFlow_Tensor:$in\n  );\n  let output = (outs\n      OneFlow_Tensor:$out\n  );\n\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n#endif // GET_ONEFLOW_SYSTEM_OP_DEFINITIONS\n\ninclude \"mlir/Interfaces/CallInterfaces.td\"\n\nclass OneFlow_JITLikeOp <string mnemonic> : OneFlow_BaseOp<mnemonic, [CallOpInterface, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let input = (ins Variadic<AnyType>:$in);\n  let output = (outs Variadic<AnyType>:$out);\n  let attrs = (ins\n    
FlatSymbolRefAttr:$callee,\n    BytesAttr:$mlir_assembly\n  );\n  let builders = [\n    OpBuilder<(ins \"func::FuncOp\":$callee,\n      \"NamedAttrList\":$attributes,\n      CArg<\"ValueRange\", \"{}\">:$in), [{\n      $_state.addOperands(in);\n      $_state.addAttributes(attributes);\n      $_state.addAttribute(\"callee\", SymbolRefAttr::get(callee));\n      $_state.addTypes(callee.getFunctionType().getResults());\n    }]>\n  ];\n  let extraClassDeclaration = [{\n    operand_range getArgOperands() {\n      return {arg_operand_begin(), arg_operand_end()};\n    }\n\n    operand_iterator arg_operand_begin() { return operand_begin(); }\n    operand_iterator arg_operand_end() { return operand_end(); }\n    CallInterfaceCallable getCallableForCallee() {\n      return (*this)->getAttrOfType<SymbolRefAttr>(\"callee\");\n    }\n    \n    void setCalleeFromCallable(CallInterfaceCallable callee) {\n      (*this)->setAttr(\"callee\", callee.get<SymbolRefAttr>());\n    }\n  }];\n  let assemblyFormat = [{\n    $callee `(` $in `)` attr-dict `:` functional-type($in, results)\n  }];\n  let has_logical_tensor_desc_infer_fn = 1;\n  let has_physical_tensor_desc_infer_fn = 1;\n  let has_get_sbp_fn = 1;\n  let has_data_type_infer_fn = 1;\n}\n\n\n#ifdef GET_ONEFLOW_MLIR_JIT_OP_DEFINITIONS\n\ndef OneFlow_MlirJitOp : OneFlow_JITLikeOp<\"mlir_jit\"> {}\n\ndef OneFlow_KernelLaunchOp : OneFlow_JITLikeOp<\"kernel_launch\"> {}\n\n#endif // GET_ONEFLOW_MLIR_JIT_OP_DEFINITIONS\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/OneFlowUtils.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWUTILS_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWUTILS_H_\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Support/LLVM.h\"\n#include \"llvm/ADT/StringExtras.h\"\nnamespace mlir {\nnamespace oneflow {\nvoid CheckEnableIRPrinting(mlir::PassManager& pm);\n// sanitize identifier to make the special name allowed as a legal token\nStringRef SanitizeIdentifier(StringRef name, SmallString<16>& buffer,\n                             StringRef allowedPunctChars = \"$._\", bool allowTrailingDigit = true);\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWUTILS_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/Passes.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_PASSES_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_PASSES_H_\n\n#include \"mlir/Dialect/PDL/IR/PDL.h\"\n#include \"mlir/Dialect/PDLInterp/IR/PDLInterp.h\"\n#include \"mlir/Dialect/MemRef/IR/MemRef.h\"\n#include \"mlir/Dialect/Tosa/IR/TosaOps.h\"\n#include \"mlir/Dialect/SCF/IR/SCF.h\"\n#include \"mlir/Dialect/GPU/IR/GPUDialect.h\"\n#include \"mlir/Dialect/LLVMIR/NVVMDialect.h\"\n#include \"mlir/Dialect/Linalg/IR/Linalg.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"OneFlow/Conversion/OneFlowToTosa.h\"\n#include \"OneFlow/Transform/OneFlowMemPool.h\"\n#include \"OneFlow/Transform/BufferHostRegister.h\"\n#include \"OneFlow/Transform/ConvertInferenceOp.h\"\n#include \"OneFlow/Transform/OutlineAndFuse.h\"\n#include \"OneFlow/Transform/AutoNhwc.h\"\n#include \"OneFlow/Transform/AggregateOps.h\"\n#include \"OneFlow/Transform/FuncOps.h\"\n#include \"OneFlow/Transform/CSEWithAttributesIgnored.h\"\n#include \"OneFlow/Transform/OneFlowStream.h\"\n#include \"OneFlow/Transform/EliminateAllocOps.h\"\n#include \"OneFlow/Transform/TraitFolder.h\"\n\n#ifdef WITH_MLIR_CUDA_CODEGEN\n#include \"OneFlow/Conversion/NVVMToCubin.h\"\n#endif  // WITH_MLIR_CUDA_CODEGEN\n\nnamespace mlir {\n\nnamespace oneflow {\n\n#define GEN_PASS_CLASSES\n#define GEN_PASS_REGISTRATION\n#include \"OneFlow/OneFlowPasses.h.inc\"\n\nLogicalResult LowerModuleToLLVM(mlir::MLIRContext* context, ModuleOp module);\n#ifdef WITH_MLIR_CUDA_CODEGEN\nLogicalResult LowerModuleToCUDALLVM(mlir::MLIRContext* context, ModuleOp module);\n#endif  // WITH_MLIR_CUDA_CODEGEN\nvoid populateWrapOpsToKernelLaunchPatterns(::mlir::RewritePatternSet& patterns,\n                                           const std::string& mode);\nvoid populateFuserForExistingOp(::mlir::RewritePatternSet& patterns);\nvoid populateGpuHelperPatterns(::mlir::RewritePatternSet& patterns);\nvoid populateAutoNhwcPatterns(::mlir::RewritePatternSet& patterns);\n\nvoid populatePreConvertInferenceOp(::mlir::RewritePatternSet& patterns);\nvoid populateConvertInferenceOp(::mlir::RewritePatternSet& patterns);\nvoid populatePostConvertInferenceOp(::mlir::RewritePatternSet& patterns);\n\nnamespace okl_func {\nconst auto OKL_FUNC = \"_mlir_okl_subgraph\";\n}  // namespace okl_func\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_PASSES_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/SBP/SBPAttributes.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_SBP_SBPATTRIBUTES_H_\n#define ONEFLOW_IR_INCLUDE_SBP_SBPATTRIBUTES_H_\n\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/Support/LLVM.h\"\n#define GET_ATTRDEF_CLASSES\n#include \"OneFlow/SBPAttributes.h.inc\"\n\n#endif  // ONEFLOW_IR_INCLUDE_SBP_SBPATTRIBUTES_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/SBP/SBPBase.td",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_SBP_SBPBASE\n#define ONEFLOW_IR_INCLUDE_SBP_SBPBASE\n\ninclude \"OneFlow/SBP/SBPDialect.td\"\ninclude \"mlir/IR/AttrTypeBase.td\"\ninclude \"mlir/IR/SymbolInterfaces.td\"\ninclude \"mlir/Interfaces/SideEffectInterfaces.td\"\ninclude \"mlir/Interfaces/InferTypeOpInterface.td\"\n\n\nclass SBP_Attr<string name, string attrMnemonic, list<Trait> traits = []>\n    : AttrDef<SBP_Dialect, name, traits> {\n  let mnemonic = attrMnemonic;\n}\n\ndef SBP_SplitAttr : SBP_Attr<\"Split\", \"S\"> {\n  let summary = \"Signature S\";\n  let description = [{\n    signature split, representing a sharded tensor at the `axis`\n  }];\n  let parameters = (ins \"int\":$axis);\n  let assemblyFormat = \"`<` $axis `>`\";\n}\n\ndef SBP_BroadcastAttr : SBP_Attr<\"Broadcast\", \"B\"> {\n  let summary = \"Signature B\";\n  let description = [{\n    signature broadcast, representing a tensor to be duplicated\n  }];\n}\n\ndef SBP_PartialSumAttr : SBP_Attr<\"PartialSum\", \"P\"> {\n  let summary = \"Signature P\";\n  let description = [{\n    signature partial sum, representing a shareded tensor will be reduced lazily\n  }];\n}\n\ndef SBP_AnyAttr : SBP_Attr<\"Any\", \"Any\"> {\n  let summary = \"Signature Any\";\n  let description = [{\n    signature any, representing one of sbp tensor;\n  }];\n}\n\ndef SBP_ParallelSignatureAttr : SBP_Attr<\"ParallelSignature\", \"parallel\"> {\n  let summary = \"Parallel signature of OneFlow Op, aka. SBP\";\n  let description = [{\n    To represent a signature, with a arrow in beween, pass two listes corepondent to the data input and data output tensors. For example:\n    ```\n    #sbp.parallel<[#sbp.S<0>] -> [#sbp.S<0>]>\n    ```\n    One level nested list is used to represent a 2D parallelism signature. For example:\n    ```\n    #sbp.parallel<[[#sbp.S<0>, #sbp.P]] -> [#sbp.S<0>]>\n    ```\n  }];\n  let parameters = (ins \"ArrayAttr\":$inputs, \"ArrayAttr\":$outputs);\n  let assemblyFormat = \"`<` custom<SBP>($inputs) ` ` `->` ` ` custom<SBP>($outputs) `>`\";\n}\n\n#endif // ONEFLOW_IR_INCLUDE_SBP_SBPBASE\n"
  },
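For illustration, here is a minimal C++ sketch of building the attributes defined in `SBPBase.td` above. It assumes the standard `get` builders that mlir-tblgen emits for `AttrDef`s and an assumed header path; treat both as assumptions rather than the project's confirmed API:

```cpp
#include "OneFlow/SBPAttributes.h"  // assumed header path for the generated AttrDefs
#include "mlir/IR/Builders.h"

// Builds #sbp.parallel<[#sbp.S<0>] -> [#sbp.S<0>]>, the first example
// from the ParallelSignatureAttr description above.
mlir::Attribute buildS0ToS0(mlir::MLIRContext* ctx) {
  mlir::Builder b(ctx);
  mlir::Attribute s0 = mlir::sbp::SplitAttr::get(ctx, /*axis=*/0);  // #sbp.S<0>
  return mlir::sbp::ParallelSignatureAttr::get(ctx, b.getArrayAttr({s0}),
                                               b.getArrayAttr({s0}));
}
```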
  {
    "path": "oneflow/ir/include/OneFlow/SBP/SBPDialect.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_SBP_SBPDIALECT_H_\n#define ONEFLOW_IR_INCLUDE_SBP_SBPDIALECT_H_\n\n#include \"mlir/IR/Dialect.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n\n#include \"OneFlow/SBPDialect.h.inc\"\n\n#endif  // ONEFLOW_IR_INCLUDE_SBP_SBPDIALECT_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/SBP/SBPDialect.td",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_SBP_SBPDIALECT\n#define ONEFLOW_IR_INCLUDE_SBP_SBPDIALECT\n\ninclude \"mlir/IR/OpBase.td\"\n\ndef SBP_Dialect : Dialect {\n    let name = \"sbp\";\n    let summary = \"S(split)B(broadcast)P(partial sum) MLIR dialect.\";\n    let description = [{\n        This dialect is the IR of S(split)B(broadcast)P(partial sum).\n    }];\n    let cppNamespace = \"::mlir::sbp\";\n    let dependentDialects = [\n        \"func::FuncDialect\"\n    ];\n    let extraClassDeclaration = [{\n        void registerAttributes();\n    }];\n    let useDefaultAttributePrinterParser = 1;\n}\n\n#endif // ONEFLOW_IR_INCLUDE_SBP_SBPDIALECT\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/SBP/SBPImporter.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_SBP_SBPIMPORTER_H_\n#define ONEFLOW_IR_INCLUDE_SBP_SBPIMPORTER_H_\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job/sbp_parallel.pb.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"OneFlow/OneFlowOps.h\"\n\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/MLIRContext.h\"\n\n#include <functional>\n#include <string>\n\nnamespace mlir {\nnamespace oneflow {\n\nclass SBPTranslation {\n public:\n  static mlir::LogicalResult PrintSbpAttrToString(mlir::Attribute sbp_attr, std::string& sbp);\n  static mlir::Attribute ConvertSBPToString(mlir::Builder& builder,\n                                            mlir::sbp::ParallelSignatureAttr& parallel);\n  static mlir::Attribute ConvertNdSbpToPsig(mlir::Builder& builder,\n                                            const std::vector<std::string>& nd_sbp,\n                                            const int nd_size);\n};\n\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_SBP_SBPIMPORTER_H_\n"
  },
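As a usage sketch for the importer declared above: `ConvertNdSbpToPsig` takes OneFlow-style nd-SBP strings and produces a parallel-signature attribute. The textual forms `"S(0)"` / `"B"` follow OneFlow's usual SBP string encoding, but treat them (and this call pattern) as assumptions:

```cpp
#include <string>
#include <vector>
#include "OneFlow/SBP/SBPImporter.h"

// Hypothetical call: a 2D signature that splits on axis 0, then broadcasts.
mlir::Attribute importNdSbp(mlir::MLIRContext* ctx) {
  mlir::Builder builder(ctx);
  std::vector<std::string> nd_sbp{"S(0)", "B"};  // assumed textual SBP forms
  return mlir::oneflow::SBPTranslation::ConvertNdSbpToPsig(builder, nd_sbp, /*nd_size=*/2);
}
```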
  {
    "path": "oneflow/ir/include/OneFlow/SBP/SBPOps.td",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_SBP_SBPOPS\n#define ONEFLOW_IR_INCLUDE_SBP_SBPOPS\n\ninclude \"OneFlow/SBP/SBPDialect.td\"\ninclude \"OneFlow/SBP/SBPBase.td\"\n\ninclude \"mlir/Interfaces/SideEffectInterfaces.td\"\ninclude \"mlir/IR/FunctionInterfaces.td\"\ninclude \"mlir/Interfaces/CallInterfaces.td\"\ninclude \"mlir/Interfaces/ControlFlowInterfaces.td\"\ninclude \"mlir/Pass/PassBase.td\"\n\ninclude \"mlir/IR/OpBase.td\"\n\n#endif // ONEFLOW_IR_INCLUDE_SBP_SBPOPS\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/Transform/AggregateOps.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_AGGREGATE_COMPUTE_OPS_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_AGGREGATE_COMPUTE_OPS_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\nnamespace oneflow {\n\nstd::unique_ptr<mlir::Pass> createAggregateComputeOpsPass();\n\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_AGGREGATE_COMPUTE_OPS_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/Transform/AutoNhwc.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_AUTONHWC_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_AUTONHWC_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nstd::unique_ptr<mlir::Pass> createAutoNhwcPass();\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_AUTONHWC_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/Transform/BufferHostRegister.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_BUFFERHOSTREGISTER_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_BUFFERHOSTREGISTER_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nstd::unique_ptr<mlir::Pass> createBufferHostRegisterPass();\nstd::unique_ptr<mlir::Pass> createGpuCopyArgPass();\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_BUFFERHOSTREGISTER_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/Transform/CSEWithAttributesIgnored.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_CSEWITHATTRIBUTESIGNORED_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_CSEWITHATTRIBUTESIGNORED_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nstruct CSEState {\n  llvm::DenseMap<Operation*, IntegerAttr> scopeSymbolIDs;\n  llvm::DenseMap<Operation*, StringAttr> opNames;\n};\nstd::unique_ptr<mlir::Pass> createCSEWithAttributesIgnored();\nstd::unique_ptr<mlir::Pass> createCSEPutAttributes();\nstd::pair<std::unique_ptr<Pass>, std::unique_ptr<Pass>> createCSEPasses(\n    std::shared_ptr<CSEState> state);\nvoid registerCSEPasses(std::shared_ptr<CSEState> state);\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_CSEWITHATTRIBUTESIGNORED_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/Transform/ConvertInferenceOp.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_CONVERTINFERENCE_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_CONVERTINFERENCE_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nstd::unique_ptr<mlir::Pass> createPreConvertInferenceOpPass();\n\nstd::unique_ptr<mlir::Pass> createConvertInferenceOpPass();\n\nstd::unique_ptr<mlir::Pass> createPostConvertInferenceOpPass();\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_CONVERTINFERENCE_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/Transform/EliminateAllocOps.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ELIMINATE_ALLOC_OPS_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ELIMINATE_ALLOC_OPS_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nstd::unique_ptr<mlir::Pass> createEliminateAllocOpsPass();\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ELIMINATE_ALLOC_OPS_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/Transform/FuncOps.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_FUNCOPS_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_FUNCOPS_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nstd::unique_ptr<mlir::Pass> createOneFlowJobToFuncPass();\n\nstd::unique_ptr<mlir::Pass> createFuncToOneFlowJobPass();\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_FUNCOPS_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/Transform/OneFlow MLIR CodeGen ABI.md",
    "content": "mlir生成的llvm最终的参数列表为：\n\n - 缓存池相关信息\n - 输入1相关信息 ... 输入n相关信息\n - 输出1相关信息 ... 输出n相关信息\n - stream 相关信息\n\n基于上述abi设计相关pass\n - append-ofstream\n - insert-ofmempool"
  },
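A minimal sketch of what an entry point following this ABI could look like at the C level, assuming one input and one output; the function name and `void*` types are illustrative only, not the actual lowering:

```cpp
// Hypothetical signature honoring the argument order described above:
// memory pool first, then inputs, then outputs, then the stream handle.
extern "C" void jit_kernel(void* mempool,       // memory-pool related info
                           const void* input0,  // input 1 ... input n
                           void* output0,       // output 1 ... output n
                           void* stream);       // stream info, e.g. a cudaStream_t
```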
  {
    "path": "oneflow/ir/include/OneFlow/Transform/OneFlowMemPool.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ONEFLOW_MEMPOOL_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ONEFLOW_MEMPOOL_H_\n\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n\nnamespace mlir {\nnamespace oneflow {\n\nnamespace codegen {\nnamespace mempool {\n\ninline const std::string MEMPOOL_ATTR_NAME = \"oneflow.mempool\";\n\n}  // namespace mempool\n}  // namespace codegen\n\nvoid applyFoldAlloc(func::FuncOp op);\n\nstd::unique_ptr<mlir::Pass> createFoldAllocToSubviewPass();\nstd::unique_ptr<mlir::Pass> createInsertOneFlowMemPoolPass();\n\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ONEFLOW_MEMPOOL_H_"
  },
  {
    "path": "oneflow/ir/include/OneFlow/Transform/OneFlowStream.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ONEFLOW_STREAM_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ONEFLOW_STREAM_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\nnamespace oneflow {\n\nstd::unique_ptr<mlir::Pass> createAppendOneFlowStreamPass();\nstd::unique_ptr<mlir::Pass> createMgpuToOneFlowStreamPass();\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_ONEFLOW_STREAM_H_"
  },
  {
    "path": "oneflow/ir/include/OneFlow/Transform/OutlineAndFuse.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_OUTLINEJITFUNCTION_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_OUTLINEJITFUNCTION_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace wrap_mode {\ninline const std::string SIMPLE = \"simple\";\ninline const std::string CUDA_GRAPH = \"cuda_graph\";\n}  // namespace wrap_mode\n\nnamespace jit {\ninline const std::string RAW_GRAPH = \"oneflow.raw_graph\";\n}\n\nstd::unique_ptr<mlir::Pass> createWrapOpsToKernelLaunchPass();\nstd::unique_ptr<mlir::Pass> createOutlineJitFunctionPass();\nstd::unique_ptr<mlir::Pass> createFuseIntoExistingOpPass();\nstd::unique_ptr<mlir::Pass> createGroupMatMul();\nstd::unique_ptr<mlir::Pass> createFuseForwardOps();\nstd::unique_ptr<mlir::Pass> createFuseOpsWithBackwardImpl();\nstd::unique_ptr<mlir::Pass> createFuseNormalizationOps();\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_OUTLINEJITFUNCTION_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/Transform/TraitFolder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_TRAIT_FOLDER_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_TRAIT_FOLDER_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\nnamespace oneflow {\n\nstd::unique_ptr<mlir::Pass> createTestOneFlowTraitFolderPass();\n\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_TRAIT_FOLDER_H_"
  },
  {
    "path": "oneflow/ir/include/OneFlow/Transform/TransposeHelpers.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_TRANSPOSEHELPERS_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_TRANSPOSEHELPERS_H_\n\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/Value.h\"\n#include \"OneFlow/OneFlowOps.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nRankedTensorType getNHWCType(RankedTensorType t);\nRankedTensorType getNHWCType(Type t);\nRankedTensorType getNHWCType(Value v);\nRankedTensorType getNCHWType(RankedTensorType t);\nRankedTensorType getNCHWType(Type t);\nRankedTensorType getNCHWType(Value v);\nllvm::SmallVector<Type, 4> getNHWCResultTypes(NCHWCompatible op);\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_TRANSFORM_TRANSPOSEHELPERS_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/UserOpConversion.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_USEROPCONVERSION_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_USEROPCONVERSION_H_\n#include \"OneFlow/OneFlowOps.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace user_op {\n\n::oneflow::ShapeProto getAttrAsShape(mlir::Attribute& attr);\n::oneflow::Int64ListProto getAttrAsStride(mlir::Attribute& attr);\n::oneflow::AttrType queryAttrType(const std::string& op_type_name, const std::string& attr_name);\nLogicalResult saveAttrDictionaryToOpConf(DictionaryAttr attributes,\n                                         ::oneflow::OperatorConf* op_conf);\nLogicalResult ConvertUserOpAttributes(llvm::StringRef op_type_name, ValueRange operands,\n                                      DictionaryAttr attributes, ::oneflow::OperatorConf& op_conf);\nLogicalResult ConvertUserOpAttributes(Operation* op, ::oneflow::OperatorConf& op_conf);\nLogicalResult ConvertUserOpAttributes(\n    Operation* op, ::oneflow::OperatorConf& op_conf,\n    bool is_mapping_size /* the input and output size should be mapped after building kernel and\n                            provide information for the next query*/\n    = false);\nLogicalResult ConvertUserOpInputs(llvm::StringRef op_type_name, ValueRange operands,\n                                  DictionaryAttr attributes, ::oneflow::UserOpConf* user_conf);\n::oneflow::ParallelConf getParallelConfFromAttrDictionary(DictionaryAttr attributes);\n::oneflow::ParallelConf getParallelConfFromAttrs(Attribute device_name_attr,\n                                                 Attribute device_tag_attr);\n::oneflow::DeviceType getDeviceTypeFromAttrDictionary(DictionaryAttr attributes);\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_USEROPCONVERSION_H_\n"
  },
  {
    "path": "oneflow/ir/include/OneFlow/UserOpReflection.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_USEROPRELFECTION_H_\n#define ONEFLOW_IR_INCLUDE_ONEFLOW_USEROPRELFECTION_H_\n#include \"OneFlow/OneFlowOps.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace user_op {\n\ntemplate<template<typename T> class Trait>\nLogicalResult GetFilteredSegmentKeyAndSizes(Operation* op, std::vector<std::string>& keys,\n                                            std::vector<int32_t>& sizes);\ntemplate<template<typename T> class Trait>\nLogicalResult GetFilteredSegmentKeyAndSizes(llvm::StringRef op_type_name, size_t valueSize,\n                                            DictionaryAttr attributes,\n                                            std::vector<std::string>& keys,\n                                            std::vector<int32_t>& sizes);\n\nstruct Source {\n  enum {\n    INPUT,\n    OUTPUT,\n    BUFFER,\n    INVALID,\n  } type;\n  int offset;\n};\nSource GetOpSourceByName(Operation* op, const std::string& to_find);\n\nusing ArgID = std::pair<std::string, int32_t>;\n\ntemplate<template<typename T> class Trait>\nclass ArgIds {\n public:\n  explicit ArgIds(Operation* op);\n  ArgIds(llvm::StringRef op_type_name, size_t valueSize, DictionaryAttr attributes);\n  std::vector<ArgID>::const_iterator begin() const { return ids_.begin(); }\n  std::vector<ArgID>::const_iterator end() const { return ids_.end(); }\n\n private:\n  std::vector<ArgID> ids_;\n};\n\nllvm::Optional<std::string> GetOutputLbn(OpResult result);\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n}  // namespace mlir\n#endif  // ONEFLOW_IR_INCLUDE_ONEFLOW_USEROPRELFECTION_H_\n"
  },
  {
    "path": "oneflow/ir/include/Transform/CMakeLists.txt",
    "content": "set(LLVM_TARGET_DEFINITIONS TransformDialectExtension.td)\nmlir_tablegen(TransformDialectExtension.h.inc -gen-op-decls)\nmlir_tablegen(TransformDialectExtension.cpp.inc -gen-op-defs)\nmlir_tablegen(TransformDialectExtensionTypes.h.inc -gen-typedef-decls -typedefs-dialect=transform)\nmlir_tablegen(TransformDialectExtensionTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=transform)\nadd_public_tablegen_target(MLIROneFlowTransformDialectExtensionIncGen)\n"
  },
  {
    "path": "oneflow/ir/include/Transform/TransformDialectExtension.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_H_\n#define ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_H_\n\n#include \"mlir/Dialect/PDL/IR/PDLTypes.h\"\n#include \"mlir/Dialect/Transform/IR/MatchInterfaces.h\"\n#include \"mlir/Dialect/Transform/IR/TransformInterfaces.h\"\n#include \"mlir/Dialect/Transform/IR/TransformTypes.h\"\n#include \"mlir/IR/OpImplementation.h\"\n#include \"mlir/IR/PatternMatch.h\"\n\nnamespace mlir {\nclass DialectRegistry;\n\nnamespace oneflow {\nnamespace transform_dialect {\n/// Registers the test extension to the Transform dialect.\nvoid registerTransformDialectExtension(::mlir::DialectRegistry& registry);\nvoid registerTransformDialectEraseSchedulePass();\nvoid registerTransformDialectInterpreterPass();\n\nstruct ApplyPatternsOpPatterns {\n  bool canonicalization = false;\n  bool cse = false;\n};\n\n}  // namespace transform_dialect\n\n}  // namespace oneflow\n}  // namespace mlir\n\n#define GET_TYPEDEF_CLASSES\n#include \"Transform/TransformDialectExtensionTypes.h.inc\"\n\n#define GET_OP_CLASSES\n#include \"Transform/TransformDialectExtension.h.inc\"\n\n#endif  // ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_H_\n"
  },
  {
    "path": "oneflow/ir/include/Transform/TransformDialectExtension.td",
    "content": "#ifndef ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_TD_\n#define ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_TD_\n\ninclude \"mlir/Interfaces/SideEffectInterfaces.td\"\ninclude \"mlir/IR/AttrTypeBase.td\"\ninclude \"mlir/IR/OpBase.td\"\ninclude \"mlir/Dialect/Transform/IR/MatchInterfaces.td\"\ninclude \"mlir/Dialect/Transform/IR/TransformDialect.td\"\ninclude \"mlir/Dialect/Transform/IR/TransformInterfaces.td\"\ninclude \"mlir/Dialect/PDL/IR/PDLTypes.td\"\n\nclass ProduceNoneProto<string mnemonic, list<Trait> traits = []> : \n  Op<Transform_Dialect, mnemonic,\n    traits # [FunctionalStyleTransformOpTrait,\n     MemoryEffectsOpInterface,\n     TransformOpInterface,\n     TransformEachOpTrait]> {\n  let arguments = (ins TransformHandleTypeInterface:$target);\n  let results = (outs);\n\n  let assemblyFormat = \"$target attr-dict `:` functional-type($target, results)\";\n  let cppNamespace = \"mlir::oneflow::transform_dialect\";\n\n  let extraClassDeclaration = [{\n    ::mlir::DiagnosedSilenceableFailure applyToOne(\n        ::mlir::Operation *target,\n        ::mlir::transform::ApplyToEachResultList &results,\n        ::mlir::transform::TransformState &state);\n\n    void getEffects(\n        SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {\n      ::mlir::transform::onlyReadsHandle(getTarget(), effects);\n      ::mlir::transform::modifiesPayload(effects);\n    }\n  }];\n}\n\ndef CSEOp : ProduceNoneProto<\"oneflow.cse\"> {\n  let description = [{\n    cse in transform dialect.\n  }];\n}\n\ndef CanonicalizationOp : ProduceNoneProto<\"oneflow.canonicalization\"> {\n  let description = [{\n    canonicalization in transform dialect.\n  }];\n}\n\ndef ExplicitLinalgOutcomeOp : ProduceNoneProto<\"oneflow.explicit_linalg_outcome\"> {\n  let description = [{\n    fold unit-extent dimensions in operands/results of linalg ops on tensors via rank-reducing slice in transform dialect.\n  }];\n}\n\ndef EliminateCopyOp : ProduceNoneProto<\"oneflow.eliminate_copy\"> {\n  let description = [{\n    eliminate memref.copy if its target equals to source or comes from block arguments.\n  }];\n}\n\ndef FoldAllocOp : ProduceNoneProto<\"oneflow.fold_alloc\"> {\n  let description = [{\n    fold memref.alloc to a single one and subview on it.\n  }];\n}\n\ndef ResultsToOutParamsOp : ProduceNoneProto<\"oneflow.results_to_out_params\"> {\n  let description = [{\n    move results to out params.\n  }];\n}\n\n#endif // ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_TD_\n"
  },
  {
    "path": "oneflow/ir/include/Transform/TransformStateExtension.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_STATE_EXTENSION_H_\n#define ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_STATE_EXTENSION_H_\n\n#include \"mlir/Dialect/Transform/IR/TransformInterfaces.h\"\n\nnamespace mlir {\nnamespace oneflow {\n\nnamespace transform_dialect {\nclass TransformStateExtension : public ::mlir::transform::TransformState::Extension {\n public:\n  TransformStateExtension(::mlir::transform::TransformState& state, StringAttr message)\n      : Extension(state), message(message) {}\n\n  StringRef getMessage() const { return message.getValue(); }\n\n  LogicalResult updateMapping(Operation* previous, Operation* updated);\n\n private:\n  StringAttr message;\n};\n\n}  // namespace transform_dialect\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_STATE_EXTENSION_H_\n"
  },
  {
    "path": "oneflow/ir/install-llvm.cmake",
    "content": "message(\"-- LLVM_MONO_REPO_URL: \" ${LLVM_MONO_REPO_URL})\nmessage(\"-- LLVM_MONO_REPO_MD5: \" ${LLVM_MONO_REPO_MD5})\nFetchContent_Declare(llvm_monorepo)\nFetchContent_GetProperties(llvm_monorepo)\n\nif(NOT llvm_monorepo_POPULATED)\n  FetchContent_Populate(llvm_monorepo URL ${LLVM_MONO_REPO_URL} URL_HASH MD5=${LLVM_MONO_REPO_MD5})\n  set(LLVM_INSTALL_DIR ${THIRD_PARTY_DIR}/llvm)\n\n  execute_process(\n    COMMAND\n      \"${CMAKE_COMMAND}\" ${llvm_monorepo_SOURCE_DIR}/llvm\n      -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} # this is required in newer version of LLVM\n      -DCMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER}\n      -DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER}\n      -DCMAKE_CUDA_COMPILER_LAUNCHER=${CMAKE_CUDA_COMPILER_LAUNCHER}\n      -DCMAKE_EXE_LINKER_FLAGS_INIT=${CMAKE_EXE_LINKER_FLAGS_INIT}\n      -DCMAKE_MODULE_LINKER_FLAGS_INIT=${CMAKE_MODULE_LINKER_FLAGS_INIT}\n      -DCMAKE_SHARED_LINKER_FLAGS_INIT=${CMAKE_SHARED_LINKER_FLAGS_INIT}\n      -DCMAKE_INSTALL_PREFIX=${LLVM_INSTALL_DIR} -DCMAKE_INSTALL_MESSAGE=${CMAKE_INSTALL_MESSAGE}\n      -DLLVM_ENABLE_RTTI=ON # turn this on to make it compatible with protobuf\n      -DLLVM_ENABLE_EH=ON # turn this on to make it compatible with half (the library)\n      -DLLVM_BUILD_EXAMPLES=OFF -DLLVM_BUILD_TOOLS=OFF -DLLVM_INCLUDE_EXAMPLES=OFF\n      -DLLVM_INCLUDE_TESTS=OFF -DLLVM_INCLUDE_BENCHMARKS=OFF -DLLVM_TARGETS_TO_BUILD=host\\;NVPTX\n      -DLLVM_ENABLE_ASSERTIONS=ON -DLLVM_ENABLE_PROJECTS=mlir -DLLVM_APPEND_VC_REV=OFF\n      -DLLVM_ENABLE_ZLIB=OFF -DLLVM_INSTALL_UTILS=ON -DBUILD_SHARED_LIBS=${BUILD_SHARED_LIBS}\n      -DLLVM_ENABLE_OCAMLDOC=OFF -DLLVM_ENABLE_BINDINGS=OFF\n      -DLLVM_ENABLE_TERMINFO=OFF # Disable terminfo in llvm so that oneflow doesn't need to link against it\n      -DMLIR_ENABLE_CUDA_RUNNER=${WITH_MLIR_CUDA_CODEGEN}\n      -DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER} -DINJA_URL=${INJA_URL}\n      -DINJA_URL_HASH=${INJA_URL_HASH} -DJSON_URL=${JSON_URL} -DJSON_URL_HASH=${JSON_URL_HASH}\n      -DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER} -DLLVM_EXTERNAL_PROJECTS=OneFlowTableGen\n      -DLLVM_EXTERNAL_ONEFLOWTABLEGEN_SOURCE_DIR=${CMAKE_SOURCE_DIR}/tools/oneflow-tblgen -G\n      ${CMAKE_GENERATOR}\n    WORKING_DIRECTORY ${llvm_monorepo_BINARY_DIR}\n    RESULT_VARIABLE ret)\n  if(ret EQUAL \"1\")\n    message(FATAL_ERROR \"Bad exit status\")\n  endif()\n  include(ProcessorCount)\n  ProcessorCount(PROC_NUM)\n  if(WITH_MLIR)\n    set(INSTALL_ALL \"install\")\n  endif()\n  execute_process(\n    COMMAND \"${CMAKE_COMMAND}\" --build . 
-j${PROC_NUM} --target ${INSTALL_ALL}\n            install-oneflow-tblgen install-mlir-headers\n    WORKING_DIRECTORY ${llvm_monorepo_BINARY_DIR} RESULT_VARIABLE ret)\n  if(ret EQUAL \"1\")\n    message(FATAL_ERROR \"Bad exit status\")\n  endif()\nendif()\n\nset(LLVM_INCLUDE_DIRS ${llvm_monorepo_SOURCE_DIR}/llvm/include;${llvm_monorepo_BINARY_DIR}/include)\n\nif(WITH_MLIR)\n  set(LLVM_DIR ${LLVM_INSTALL_DIR}/lib/cmake/llvm)\n  set(MLIR_DIR ${LLVM_INSTALL_DIR}/lib/cmake/mlir)\n  find_package(MLIR REQUIRED CONFIG)\n\n  message(STATUS \"Using MLIRConfig.cmake in: ${MLIR_DIR}\")\n  message(STATUS \"Using LLVMConfig.cmake in: ${LLVM_DIR}\")\n\n  set(MLIR_BINARY_DIR ${llvm_monorepo_BINARY_DIR})\n\n  list(APPEND CMAKE_MODULE_PATH \"${MLIR_CMAKE_DIR}\")\n  list(APPEND CMAKE_MODULE_PATH \"${LLVM_CMAKE_DIR}\")\n  include(TableGen)\n  include(AddLLVM)\n  include(AddMLIR)\n  include(HandleLLVMOptions)\n  set(LLVM_EXTERNAL_LIT \"${llvm_monorepo_BINARY_DIR}/bin/llvm-lit\" CACHE STRING \"\" FORCE)\nendif()\n"
  },
  {
    "path": "oneflow/ir/lib/CMakeLists.txt",
    "content": "add_subdirectory(OneFlow)\nadd_subdirectory(Transform)\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/CMakeLists.txt",
    "content": "get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)\nmessage(STATUS \"MLIR_DIALECT_LIBS: ${dialect_libs}\")\nif(WITH_MLIR_CUDA_CODEGEN)\n  set(MLIR_GPU_LIBS MLIRGPUToNVVMTransforms MLIRNVVMToLLVMIRTranslation)\nendif(WITH_MLIR_CUDA_CODEGEN)\n\nset(ONEFLOW_OP_GROUPS\n    \"ASSIGN;BINARY;BROADCAST;CONV;CROSS_ENTROPY;CUDA;DATASET;DETECTION;EAGER;FUSED;IDEMPOTENT;IDENTITY;IMAGE;INDICES;INVOLUTION;LOSS;MATH;MATMUL;MISC;NCCL;NORMALIZATION;OPTIMIZER;PADDING;PARALLEL_CAST;POOL;QUANTIZATION;REDUCE;RESHAPE;SCALAR;SOFTMAX;SUMMARY;TENSOR_BUFFER;TEST;TRIGONOMETRIC;UNARY;UPSAMPLE;ONE_EMBEDDING;LINEAR_ALGEBRA;SYSTEM;MLIR_JIT\"\n)\n\nforeach(OP_GROUP_NAME IN LISTS ONEFLOW_OP_GROUPS)\n  string(TOLOWER \"${OP_GROUP_NAME}\" OP_GROUP_NAME_LOWER)\n  set(CPP_FILE \"OneFlow.${OP_GROUP_NAME_LOWER}_ops.cpp\")\n  list(APPEND GROUPED_OP_CPP_FILES \"${CPP_FILE}\")\n  configure_file(OneFlowOpGetGen.cpp.in ${CPP_FILE} @ONLY)\nendforeach()\n\nadd_subdirectory(PDLL)\n\noneflow_add_mlir_dialect_library(\n  MLIROneFlow\n  OKM/OKMDialect.cpp\n  OKM/passes.cpp\n  OKM/Conversion/Conversion.cpp\n  OKL/OKLDialect.cpp\n  OKL/OKLOps.cpp\n  OKL/OKLTypes.cpp\n  OKL/Conversion/OKLToLLVM.cpp\n  OKL/Conversion/CudaGraphSupport.cpp\n  OKL/Conversion/Conversion.cpp\n  OKL/Kernel/InferContext.cpp\n  OKL/Kernel/KernelLaunchOp.cpp\n  OKL/Kernel/LauncherState.cpp\n  OKL/Kernel/LauncherContext.cpp\n  OKL/Kernel/ComputeContext.cpp\n  OKL/Kernel/RegContext.cpp\n  OKL/Kernel/TmpBufferManager.cpp\n  OKL/Kernel/JITOpInfer.cpp\n  OKL/Kernel/JITEngine.cpp\n  SBP/SBPDialect.cpp\n  SBP/SBPAttributes.cpp\n  SBP/SBPImporter.cpp\n  OneFlowDialect.cpp\n  OneFlowTypes.cpp\n  OneFlowInferReturnTypes.cpp\n  OneFlowOps.cpp\n  OneFlowOpTraits.cpp\n  OneFlowSupport.cpp\n  OneFlowUtils.cpp\n  OneFlowDataTypeConversion.cpp\n  UserOpReflection.cpp\n  UserOpConversion.cpp\n  OneFlowOpFolders.cpp\n  Conversion/OneFlowToTosa.cpp\n  Conversion/OneFlowToLinalg.cpp\n  Conversion/NVVMToCubin.cpp\n  Transform/BufferHostRegister.cpp\n  Transform/OutlineAndFuse.cpp\n  Transform/JITPasses.cpp\n  Transform/AutoNhwc.cpp\n  Transform/ConvertInferenceOp.cpp\n  Transform/AggregateOps.cpp\n  Transform/EliminateAllocOps.cpp\n  Transform/FuncOps.cpp\n  Transform/CSEWithAttributesIgnored.cpp\n  Transform/GroupMatMulOps.cpp\n  Transform/AutoNHWCOps.cpp\n  Transform/OneFlowMemPool.cpp\n  Transform/OneFlowStream.cpp\n  Transform/TraitFolder.cpp\n  TransposeHelpers.cpp\n  Passes.cpp\n  OneFlowCanonicalizers.cpp\n  OneFlowRewrites.cpp\n  ${GROUPED_OP_CPP_FILES}\n  ADDITIONAL_HEADER_DIRS\n  ${PROJECT_SOURCE_DIR}/include/OneFlow\n  DEPENDS\n  MLIROneFlowOpsIncGen\n  prepare_oneflow_third_party\n  LINK_LIBS\n  PUBLIC\n  ${dialect_libs}\n  MLIRTosaToLinalg\n  MLIRTosaToTensor\n  MLIRMemRefToLLVM\n  MLIRLinalgToLLVM\n  MLIRSCFToGPU\n  MLIRReconcileUnrealizedCasts\n  ${MLIR_GPU_LIBS}\n  MLIRIR\n  MLIRBytecodeWriter\n  MLIROneFlowPDLLPatterns\n  MLIRExecutionEngine\n  oneflow)\n\nif(WITH_MLIR_CUDA_CODEGEN)\n  find_library(CUDA_DRIVER_LIBRARY cuda)\n  target_link_libraries(MLIROneFlow PRIVATE ${CUDA_DRIVER_LIBRARY})\n  include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})\nendif(WITH_MLIR_CUDA_CODEGEN)\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Conversion/NVVMToCubin.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_MLIR_CUDA_CODEGEN\n#include \"oneflow/core/common/util.h\"\n#include \"OneFlow/Passes.h\"\n#include \"mlir/Dialect/GPU/IR/GPUDialect.h\"\n#include \"mlir/Dialect/GPU/Transforms/Passes.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMDialect.h\"\n#include \"mlir/Support/FileUtilities.h\"\n#include \"mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h\"\n#include \"mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h\"\n#include \"mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h\"\n#include \"mlir/Target/LLVMIR/Export.h\"\n#include \"llvm/ADT/StringSet.h\"\n#include \"llvm/Analysis/TargetTransformInfo.h\"\n#include \"llvm/Bitcode/BitcodeReader.h\"\n#include \"llvm/IR/LegacyPassManager.h\"\n#include \"llvm/IR/Module.h\"\n#include \"llvm/IR/PassManager.h\"\n#include \"llvm/IR/Verifier.h\"\n#include \"llvm/Linker/Linker.h\"\n#include \"llvm/MC/TargetRegistry.h\"\n#include \"llvm/Passes/PassBuilder.h\"\n#include \"llvm/Support/MemoryBuffer.h\"\n#include \"llvm/Support/TargetSelect.h\"\n#include \"llvm/Support/FileSystem.h\"\n#include \"llvm/Target/TargetMachine.h\"\n#include \"llvm/Transforms/IPO.h\"\n#include \"llvm/Transforms/IPO/AlwaysInliner.h\"\n#include \"llvm/Transforms/IPO/Internalize.h\"\n#include \"llvm/Transforms/Scalar/DCE.h\"\n#include \"llvm/Transforms/Vectorize/LoopVectorize.h\"\n#include \"llvm/Transforms/Vectorize/SLPVectorizer.h\"\n\n#include <cuda.h>\n#include <cuda_runtime_api.h>\n\nstatic void emitCudaError(const llvm::Twine& expr, const char* buffer, CUresult result,\n                          mlir::Location loc) {\n  const char* error;\n  cuGetErrorString(result, &error);\n  emitError(loc, expr.concat(\" failed with error code \")\n                     .concat(llvm::Twine{error})\n                     .concat(\"[\")\n                     .concat(buffer)\n                     .concat(\"]\"));\n}\n\n#define RETURN_ON_CUDA_ERROR(expr)                       \\\n  do {                                                   \\\n    if (auto status = (expr)) {                          \\\n      emitCudaError(#expr, jitErrorBuffer, status, loc); \\\n      return {};                                         \\\n    }                                                    \\\n  } while (false)\n\nnamespace mlir {\nnamespace oneflow {\n\nconst char* getArchVersion() {\n  static std::string version;\n  if (!version.empty()) return version.c_str();\n  cudaDeviceProp prop{};\n  cudaError_t err = cudaGetDeviceProperties(&prop, 0);\n  if (err != cudaSuccess) {\n    printf(\"%s\\n\", cudaGetErrorString(err));\n    exit(1);\n  }\n  version = \"sm_\" + std::to_string(prop.major) + std::to_string(prop.minor);\n  return version.c_str();\n}\n\nnamespace {\n\nconst std::string& getLibDevice() {\n  static std::string p;\n  if (!p.empty()) return p;\n  const auto toolkit_env_name = \"CUDA_TOOLKIT_ROOT_DIR\";\n  p = ::oneflow::GetStringFromEnv(toolkit_env_name, 
\"/usr/local/cuda/\")\n      + \"nvvm/libdevice/libdevice.10.bc\";\n  if (llvm::sys::fs::exists(p)) return p;\n  LOG(FATAL) << \"Could not find file: \" << p << \". Please check you cuda toolkit directory and set \"\n             << toolkit_env_name << \" correctly as an environment variable\";\n}\n\nLogicalResult linkLibdevice(llvm::Module& llvmModule, llvm::LLVMContext& llvmContext) {\n  // Note: infer libdevice path from environment variable\n  auto libDevice = getLibDevice();\n\n  // Note: load raw data from file\n  std::string errorMessage;\n  auto libDeviceBuf = openInputFile(libDevice, &errorMessage);\n  if (!libDeviceBuf) LOG(FATAL) << \"Open File error when link libdevice: \" << errorMessage;\n\n  // Note: load module from raw data\n  auto moduleOrErr = llvm::getOwningLazyBitcodeModule(std::move(libDeviceBuf), llvmContext);\n  if (!moduleOrErr) LOG(FATAL) << \"Failed to load: \" << libDevice << \"\\n\";\n  std::unique_ptr<llvm::Module> libDeviceModule = std::move(moduleOrErr.get());\n\n  // Note: link libdevice with module\n  if (llvm::Linker::linkModules(llvmModule, std::move(libDeviceModule),\n                                llvm::Linker::Flags::LinkOnlyNeeded,\n                                [](llvm::Module& M, const llvm::StringSet<>& GS) {\n                                  llvm::internalizeModule(M, [&GS](const llvm::GlobalValue& GV) {\n                                    return !GV.hasName() || (GS.count(GV.getName()) == 0);\n                                  });\n                                })) {\n    LOG(FATAL) << \"failed to link libdevice module\\n\";\n  }\n\n  return success();\n}\n\nstd::optional<std::string> translateToISA(llvm::Module& llvmModule,\n                                          llvm::TargetMachine& targetMachine) {\n  llvmModule.setDataLayout(targetMachine.createDataLayout());\n\n  // TODO(yuhao): optimizeLlvm\n\n  std::string targetISA;\n  llvm::raw_string_ostream stream(targetISA);\n\n  {  // Note: Drop pstream after this to prevent the ISA from being stuck buffering\n    llvm::buffer_ostream pstream(stream);\n    llvm::legacy::PassManager codegenPasses;\n\n    if (targetMachine.addPassesToEmitFile(codegenPasses, pstream, nullptr, llvm::CGFT_AssemblyFile))\n      return std::nullopt;\n\n    codegenPasses.run(llvmModule);\n  }\n  return stream.str();\n}\n\nclass NVVMToCubinPass : public NVVMToCubinPassBase<NVVMToCubinPass> {\n  std::unique_ptr<llvm::Module> translateToLLVMIR(llvm::LLVMContext& llvmContext) {\n    return translateModuleToLLVMIR(getOperation(), llvmContext, \"LLVMDialectModule\");\n  }\n\n public:\n  std::unique_ptr<llvm::TargetMachine> createTargetMachine();\n  std::unique_ptr<std::vector<char>> serializeISA(const std::string& isa);\n\n  void runOnOperation() override;\n\n  void getDependentDialects(::mlir::DialectRegistry& registry) const override {\n    registerLLVMDialectTranslation(registry);\n    registerNVVMDialectTranslation(registry);\n    registerGPUDialectTranslation(registry);\n    registerLLVMDialectTranslation(registry);\n  }\n};\n\nstd::unique_ptr<llvm::TargetMachine> NVVMToCubinPass::createTargetMachine() {\n  Location loc = getOperation().getLoc();\n  std::string error;\n  const llvm::Target* target = ::llvm::TargetRegistry::lookupTarget(triple.str(), error);\n  if (!target) {\n    emitError(loc, Twine(\"failed to lookup target: \") + error);\n    return {};\n  }\n  llvm::TargetMachine* machine =\n      target->createTargetMachine(triple.str(), chip.str(), features.str(), {}, {});\n  if (!machine) {\n    
emitError(loc, \"failed to create target machine\");\n    return {};\n  }\n\n  return std::unique_ptr<llvm::TargetMachine>{machine};\n}\n\nstd::unique_ptr<std::vector<char>> NVVMToCubinPass::serializeISA(const std::string& isa) {\n  Location loc = getOperation().getLoc();\n  char jitErrorBuffer[4096] = {0};\n\n  RETURN_ON_CUDA_ERROR(cuInit(0));\n\n  // Note: Linking requires a device context.\n  CUdevice device;\n  RETURN_ON_CUDA_ERROR(cuDeviceGet(&device, 0));\n  CUcontext context;\n  RETURN_ON_CUDA_ERROR(cuCtxCreate(&context, 0, device));\n  CUlinkState linkState;\n\n  CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};\n  void* jitOptionsVals[] = {jitErrorBuffer, reinterpret_cast<void*>(sizeof(jitErrorBuffer))};\n\n  RETURN_ON_CUDA_ERROR(cuLinkCreate(2,              /* number of jit options */\n                                    jitOptions,     /* jit options */\n                                    jitOptionsVals, /* jit option values */\n                                    &linkState));\n\n  auto kernelName = getOperation().getName().str();\n  RETURN_ON_CUDA_ERROR(cuLinkAddData(linkState, CUjitInputType::CU_JIT_INPUT_PTX,\n                                     const_cast<void*>(static_cast<const void*>(isa.c_str())),\n                                     isa.length(), kernelName.c_str(),\n                                     0,       /* number of jit options */\n                                     nullptr, /* jit options */\n                                     nullptr  /* jit option values */\n                                     ));\n\n  void* cubinData;\n  size_t cubinSize;\n  RETURN_ON_CUDA_ERROR(cuLinkComplete(linkState, &cubinData, &cubinSize));\n\n  char* cubinAsChar = static_cast<char*>(cubinData);\n  auto result = std::make_unique<std::vector<char>>(cubinAsChar, cubinAsChar + cubinSize);\n\n  RETURN_ON_CUDA_ERROR(cuLinkDestroy(linkState));\n  RETURN_ON_CUDA_ERROR(cuCtxDestroy(context));\n\n  return result;\n}\n\nvoid NVVMToCubinPass::runOnOperation() {\n  llvm::LLVMContext llvmContext;\n  std::unique_ptr<llvm::Module> llvmModule = translateToLLVMIR(llvmContext);\n  if (!llvmModule) return signalPassFailure();\n  if (failed(linkLibdevice(*llvmModule, llvmContext))) { return signalPassFailure(); }\n\n  // Note: Lower the LLVM IR module to target ISA.\n  std::unique_ptr<llvm::TargetMachine> targetMachine = createTargetMachine();\n  if (!targetMachine) return signalPassFailure();\n\n  std::optional<std::string> maybeTargetISA = translateToISA(*llvmModule, *targetMachine);\n\n  if (!maybeTargetISA.has_value()) return signalPassFailure();\n\n  std::string targetISA = std::move(*maybeTargetISA);\n\n  // Note: Serialize the target ISA.\n  std::unique_ptr<std::vector<char>> blob = serializeISA(targetISA);\n  if (!blob) return signalPassFailure();\n\n  // Note: Add the blob as module attribute.\n  auto attr = StringAttr::get(&getContext(), StringRef(blob->data(), blob->size()));\n  getOperation()->setAttr(gpu::getCubinAnnotation(), attr);\n}\n\n}  // namespace\n\nstd::unique_ptr<mlir::Pass> createNVVMToCubinPass() { return std::make_unique<NVVMToCubinPass>(); }\n\nvoid InitializeLLVMNVPTXBackend() {\n  LLVMInitializeNVPTXTarget();\n  LLVMInitializeNVPTXTargetInfo();\n  LLVMInitializeNVPTXTargetMC();\n  LLVMInitializeNVPTXAsmPrinter();\n}\n\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // WITH_MLIR_CUDA_CODEGEN"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Conversion/OneFlowToLinalg.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/Passes.h\"\n#include \"mlir/Dialect/Arith/IR/Arith.h\"\n#include \"mlir/Dialect/Linalg/Utils/Utils.h\"\n#include \"mlir/Dialect/Math/IR/Math.h\"\n#include \"mlir/Dialect/Tensor/IR/Tensor.h\"\n#include \"mlir/Transforms/DialectConversion.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace {\n\nstd::tuple<SmallVector<::mlir::utils::IteratorType>, SmallVector<AffineMap>>\ncomputeIteratorTypesAndIndexingMaps(int64_t inputRank, int64_t dim, OpBuilder& builder,\n                                    bool allParallel = false) {\n  SmallVector<::mlir::utils::IteratorType> iteratorTypes(inputRank,\n                                                         ::mlir::utils::IteratorType::parallel);\n  if (!allParallel) iteratorTypes[dim] = ::mlir::utils::IteratorType::reduction;\n  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, builder.getContext());\n  SmallVector<AffineExpr, 2> affineExprs;\n  for (int i = 0; i < inputRank; i++) {\n    if (i != dim) affineExprs.push_back(mlir::getAffineDimExpr(i, builder.getContext()));\n  }\n  auto reductionMap = AffineMap::get(inputRank, 0, affineExprs, builder.getContext());\n  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};\n  return std::make_tuple(iteratorTypes, indexingMaps);\n}\n\ntemplate<typename T>\nstatic Value reduce(Value input, Value output, int64_t dim, Location loc, OpBuilder& builder) {\n  auto inputType = input.getType().cast<ShapedType>();\n  ArrayRef<int64_t> inputShape = inputType.getShape();\n  int64_t inputRank = inputShape.size();\n  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(inputRank, dim, builder);\n  auto genericOp = builder.create<linalg::GenericOp>(\n      loc, output.getType(), input, output, indexingMaps, iteratorTypes,\n      [&](OpBuilder& b, Location loc, ValueRange args) {\n        Value result = b.create<T>(loc, args[0], args[1]);\n        b.create<linalg::YieldOp>(loc, result);\n      });\n  return genericOp.getResult(0);\n}\n\nstatic Value subtractAndExp(Value input, Value max, Value output, int64_t dim, Location loc,\n                            OpBuilder& builder) {\n  auto inputType = input.getType().cast<ShapedType>();\n  ArrayRef<int64_t> inputShape = inputType.getShape();\n  int64_t inputRank = inputShape.size();\n  auto [iteratorTypes, indexingMaps] =\n      computeIteratorTypesAndIndexingMaps(inputRank, dim, builder, true);\n  indexingMaps.push_back(indexingMaps[0]);\n  auto genericOp = builder.create<linalg::GenericOp>(\n      loc, input.getType(), ValueRange{input, max}, output, indexingMaps, iteratorTypes,\n      [&](OpBuilder& b, Location loc, ValueRange args) {\n        Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);\n        Value result = b.create<math::ExpOp>(loc, diff);\n        b.create<linalg::YieldOp>(loc, result);\n      });\n  return 
genericOp.getResult(0);\n}\n\nstatic Value computeSoftmax(Value numerator, Value denominator, Value output, int64_t dim,\n                            Location loc, OpBuilder& builder) {\n  auto inputType = numerator.getType().cast<ShapedType>();\n  ArrayRef<int64_t> inputShape = inputType.getShape();\n  int64_t inputRank = inputShape.size();\n  auto [iteratorTypes, indexingMaps] =\n      computeIteratorTypesAndIndexingMaps(inputRank, dim, builder, true);\n  indexingMaps.push_back(indexingMaps[0]);\n  auto genericOp = builder.create<linalg::GenericOp>(\n      loc, numerator.getType(), ValueRange{numerator, denominator}, output, indexingMaps,\n      iteratorTypes, [&](OpBuilder& b, Location loc, ValueRange args) {\n        Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);\n        b.create<linalg::YieldOp>(loc, result);\n      });\n  return genericOp.getResult(0);\n}\n\n/// Given an N-dimensional tensor x, this op converts\n/// softmax(x) to the following sequence of operations:\n///\n/// 1. Compute the max of x along dimension d. This results\n///    in a N-1 dimensional tensor m.\n///    m = max(x, dim = d)\n///\n/// 2. Subtract m from x and exponentiate. This results in\n///    a N dimensional tensor z.\n///    z = exp(x - m)\n///\n/// 3. Compute the sum of z along dimension d. This results in\n///    a N-1 dimensional tensor l.\n///    l = sum(z, dim = d)\n///\n/// 4. Divide z and l. This gives the N-dimensional softmax.\n///    softmax = z / l\n///\n\n// Implementation above is from IREE.\n// https://github.com/google/iree/blob/b339919814f10589f779b39c3ab7c6575716dab6/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/DecomposeSoftmax.cpp\n\nSmallVector<OpFoldResult> createDimValues(OpBuilder& b, Location loc, Value rankedTensor) {\n  auto tensorTy = rankedTensor.getType().cast<RankedTensorType>();\n  SmallVector<OpFoldResult> dims;\n  for (const auto& en : llvm::enumerate(tensorTy.getShape())) {\n    if (ShapedType::isDynamic(en.value())) {\n      dims.push_back(b.createOrFold<tensor::DimOp>(loc, rankedTensor, en.index()));\n    } else {\n      dims.push_back(b.getIndexAttr(en.value()));\n    }\n  }\n  return dims;\n}\n\nstruct SoftmaxOpLowering final : public OpConversionPattern<SoftmaxOp> {\n public:\n  using OpConversionPattern<SoftmaxOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(SoftmaxOp softmaxOp, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    OpBuilder::InsertionGuard guard(rewriter);\n    rewriter.setInsertionPoint(softmaxOp);\n    Location loc = softmaxOp.getLoc();\n    Value input = softmaxOp.getIn();\n    ShapedType inputType = input.getType().cast<ShapedType>();\n    Type elementType = inputType.getElementType();\n    int64_t reductionDim = inputType.getRank() - 1;\n    SmallVector<OpFoldResult> dims = createDimValues(rewriter, loc, input);\n    Value outputNd = rewriter.create<tensor::EmptyOp>(loc, dims, elementType);\n    dims.erase(dims.begin() + reductionDim);\n    // Compute max along dim\n    Value output = rewriter.create<tensor::EmptyOp>(loc, dims, elementType);\n    Value largeNegative =\n        rewriter.create<arith::ConstantOp>(loc, rewriter.getFloatAttr(elementType, -1.0e30));\n    Value negativeInit =\n        rewriter.create<linalg::FillOp>(loc, Value{largeNegative}, output).result();\n    Value max = reduce<arith::MaxFOp>(input, negativeInit, reductionDim, loc, rewriter);\n    // Subtract max from input and exponentiate\n    Value numerator = 
subtractAndExp(input, max, outputNd, reductionDim, loc, rewriter);\n    // Compute sum along dim\n    Value zero = rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(elementType));\n    Value zeroInit = rewriter.create<linalg::FillOp>(loc, Value{zero}, output).result();\n    Value denominator = reduce<arith::AddFOp>(numerator, zeroInit, reductionDim, loc, rewriter);\n    // Compute softmax\n    Value result = computeSoftmax(numerator, denominator, outputNd, reductionDim, loc, rewriter);\n    rewriter.replaceOp(softmaxOp, {result});\n    return success();\n  }\n};\n\nstruct OneFlowLoweringToLinalgPass\n    : public LowerOneFlowToLinalgPassBase<OneFlowLoweringToLinalgPass> {\n  void runOnOperation() {\n    MLIRContext* context = &getContext();\n    ConversionTarget target(*context);\n    target.addLegalDialect<memref::MemRefDialect, mlir::func::FuncDialect, tosa::TosaDialect,\n                           linalg::LinalgDialect, tensor::TensorDialect, arith::ArithDialect,\n                           math::MathDialect>();\n    RewritePatternSet patterns(context);\n    patterns.add<SoftmaxOpLowering>(context);\n    if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) {\n      signalPassFailure();\n    }\n  }\n};\n\n}  // namespace\n\nstd::unique_ptr<Pass> createLowerOneFlowToLinalgPass() {\n  return std::make_unique<OneFlowLoweringToLinalgPass>();\n}\n\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Conversion/OneFlowToTosa.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowOps.h\"\n#include <cstdint>\n#include <iostream>\n#include <string>\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/Passes.h\"\n#include \"llvm/ADT/ArrayRef.h\"\n#include \"llvm/ADT/STLExtras.h\"\n#include \"llvm/Support/Casting.h\"\n#include \"mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h\"\n#include \"mlir/Conversion/TosaToLinalg/TosaToLinalg.h\"\n#include \"mlir/Dialect/Affine/IR/AffineOps.h\"\n#include \"mlir/Dialect/Arith/IR/Arith.h\"\n#include \"mlir/Dialect/Linalg/Passes.h\"\n#include \"mlir/Dialect/MemRef/IR/MemRef.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/Dialect/Func/Transforms/Passes.h\"\n#include \"mlir/Dialect/Tensor/IR/Tensor.h\"\n#include \"mlir/Dialect/Tosa/IR/TosaOps.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/BuiltinDialect.h\"\n#include \"mlir/IR/BuiltinTypeInterfaces.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/Diagnostics.h\"\n#include \"mlir/IR/OpImplementation.h\"\n\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/Transforms/DialectConversion.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n#include \"mlir/Transforms/Passes.h\"\n#include \"oneflow/core/framework/op_expr_grad_function.h\"\n#include \"oneflow/core/framework/variable_tensor_mgr.h\"\n\n#include <limits>\n\nnamespace mlir {\n\nnamespace oneflow {\n\nType convertToSignless(MLIRContext* context, Type type) {\n  if (auto ranked_tensor = type.dyn_cast<RankedTensorType>()) {\n    if (auto intTy = ranked_tensor.getElementType().dyn_cast<IntegerType>()) {\n      if (!intTy.isSignless()) {\n        return RankedTensorType::get(\n            ranked_tensor.getShape(),\n            IntegerType::get(context, intTy.getWidth(),\n                             mlir::IntegerType::SignednessSemantics::Signless));\n      }\n    }\n  }\n  return type;\n}\n\nFunctionType convertToSignlessFuncType(MLIRContext* context, FunctionType funcType) {\n  llvm::SmallVector<Type, 4> inputs;\n  llvm::SmallVector<Type, 4> results;\n  for (auto arg : funcType.getInputs()) { inputs.push_back(convertToSignless(context, arg)); }\n  for (auto res : funcType.getResults()) { results.push_back(convertToSignless(context, res)); }\n  return FunctionType::get(context, inputs, results);\n}\n\nbool isSignLessTensorOrOther(Type type) {\n  if (auto ranked_tensor = type.dyn_cast<RankedTensorType>()) {\n    if (auto intTy = ranked_tensor.getElementType().dyn_cast<IntegerType>()) {\n      if (intTy.isUnsigned()) { return false; }\n      if (intTy.isSigned()) { return false; }\n    }\n  }\n  return true;\n}\nbool allSignless(mlir::TypeRange types) {\n  for (auto type : types) {\n    if (!isSignLessTensorOrOther(type)) { return false; }\n  }\n  return true;\n}\n\nbool allSignless(FunctionType funcType) {\n  for (auto arg : funcType.getInputs()) {\n    if 
(!isSignLessTensorOrOther(arg)) { return false; }\n  }\n  for (auto res : funcType.getResults()) {\n    if (!isSignLessTensorOrOther(res)) { return false; }\n  }\n  return true;\n}\n\nValue CreateTransposeValue(Location& loc, ConversionPatternRewriter& rewriter, Value input,\n                           ArrayRef<int32_t> perms) {\n  int perms_size = perms.size();\n  auto transpose_perms = rewriter.create<tosa::ConstOp>(\n      loc, RankedTensorType::get({perms_size}, rewriter.getI32Type()),\n      rewriter.getI32TensorAttr(perms));\n  const auto shape_type = input.getType().cast<ShapedType>();\n  std::vector<int64_t> ranked_type;\n  for (const auto& index : perms) ranked_type.push_back(shape_type.getDimSize(index));\n  return rewriter.create<tosa::TransposeOp>(\n      loc, RankedTensorType::get(ranked_type, shape_type.getElementType()), input, transpose_perms);\n}\n\nRankedTensorType CreateTransposeType(ShapedType output, ArrayRef<int32_t> perms) {\n  std::vector<int64_t> ranked_type;\n  for (auto index : perms) ranked_type.push_back(output.getDimSize(index));\n  return RankedTensorType::get(ranked_type, output.getElementType());\n}\n\nValue CreateBNOp(Location loc, ConversionPatternRewriter& rewriter, Type output_type, Value x,\n                 Value mean, Value variance, Value epsilon, Value gamma, Value beta) {\n  // sub_op0 = sub(x, mean)\n  auto sub_op0 = rewriter.create<tosa::SubOp>(loc, output_type, x, mean);\n  // add_op0 = add(variance, epsilon)\n  auto add_op0 = rewriter.create<tosa::AddOp>(loc, variance.getType(), variance, epsilon);\n  // rsqrt_op = rsqrt(add_op0)\n  auto rsqrt_op = rewriter.create<tosa::RsqrtOp>(loc, variance.getType(), add_op0);\n  // mul_op0 = mul(sub_op0, rsqrt_op)\n  auto mul_op0 = rewriter.create<tosa::MulOp>(loc, output_type, sub_op0, rsqrt_op, 0);\n  // mul_op1 = mul(mul_op0, gamma)\n  auto mul_op1 = rewriter.create<tosa::MulOp>(loc, output_type, mul_op0, gamma, 0);\n  // batch_norm = add(mul_op1, beta)\n  Value batch_norm = rewriter.create<tosa::AddOp>(loc, output_type, mul_op1, beta);\n  return batch_norm;\n}\n\nstruct ScalarMulByTensorOpLowering final : public OpConversionPattern<ScalarMulByTensorOp> {\n public:\n  using OpConversionPattern<ScalarMulByTensorOp>::OpConversionPattern;\n\n  LogicalResult matchAndRewrite(ScalarMulByTensorOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    Value scalar = op.getScalar();\n    rewriter.replaceOpWithNewOp<tosa::MulOp>(\n        op,\n        /* output */ op->getResultTypes().front().cast<TensorType>(),\n        /* input1 */ op.getX(),\n        /* input2 */ scalar,\n        /* shift */ rewriter.getIntegerAttr(rewriter.getI32Type(), 0));\n    return success();\n  }\n};\n\nstruct JobLowering final : public OpConversionPattern<Job> {\n public:\n  using OpConversionPattern<Job>::OpConversionPattern;\n  LogicalResult matchAndRewrite(Job op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    auto func_type = convertToSignlessFuncType(op->getContext(), op.getFunctionType());\n    auto func = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getName(), func_type);\n    rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());\n    rewriter.eraseOp(op);\n    return success();\n  }\n};\n\nstruct ReturnOpLowering final : public OpConversionPattern<ReturnOp> {\n public:\n  using OpConversionPattern<ReturnOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(ReturnOp op, OpAdaptor adaptor,\n  
                              ConversionPatternRewriter& rewriter) const override {\n    rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op,\n                                                      /* operands */ op.getOperands());\n    return success();\n  }\n};\n\nstruct InputOpLowering final : public OpConversionPattern<InputOp> {\n public:\n  using OpConversionPattern<InputOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(InputOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    // TODO: support more ways of passing data between tosa and oneflow\n    const auto newValues = op.getInput();\n    const auto is_block_arg = newValues.dyn_cast<BlockArgument>() != nullptr;\n    if (!is_block_arg) { return op->emitError(\"input is not a block argument\"); }\n    rewriter.replaceOp(op, newValues);\n    return success();\n  }\n};\n\nstruct OutputOpLowering final : public OpConversionPattern<OutputOp> {\n public:\n  using OpConversionPattern<OutputOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(OutputOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    // TODO: support more ways of passing data between tosa and oneflow\n    const auto newValues = op.getInput();\n    rewriter.replaceOp(op, newValues);\n    return success();\n  }\n};\n\nstruct VariableOpLowering final : public OpConversionPattern<VariableOp> {\n public:\n  using OpConversionPattern<VariableOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(VariableOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    const auto mgr = ::oneflow::Singleton<::oneflow::VariableTensorMgr>::Get();\n    if (!mgr) { return op->emitError(\"global variable tensor manager is missing\"); }\n\n    const auto tensor = CHECK_JUST(mgr->Get(op.getOpName().str()));\n    if (!tensor) { return op->emitError(\"tensor is null\"); }\n    const auto value = support::TensorToDenseElementsAttr(tensor, rewriter.getContext());\n    const auto output = op.getOutput().getType();\n\n    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output, value);\n    return success();\n  }\n};\n\nstruct VariableOpToConstLowering final : public OpConversionPattern<VariableOp> {\n public:\n  VariableOpToConstLowering(TypeConverter& typeConverter, MLIRContext* context, int const_val)\n      : OpConversionPattern<VariableOp>(typeConverter, context), const_val_(const_val){};\n\n  using OpConversionPattern<VariableOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(VariableOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    const auto output = op.getOutput().getType();\n    const auto type = output.cast<ShapedType>().getElementType();\n\n    // TODO: more control over this scope with a flag\n    if (type.isa<FloatType>()) {\n      const auto float_attr = rewriter.getFloatAttr(type, const_val_);\n      auto value = DenseElementsAttr::get(output.cast<ShapedType>(), float_attr);\n\n      rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output, value);\n    } else if (auto integerType = type.dyn_cast<IntegerType>()) {\n      const auto int_attr =\n          rewriter.getIntegerAttr(type, APInt(type.cast<IntegerType>().getWidth(), const_val_));\n      auto value = DenseElementsAttr::get(output.cast<ShapedType>(), int_attr);\n\n      rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output, value);\n    } else {\n      return op->emitError(\n    
      \"OneFlow variable op lower to TOSA const op only support integer and float value now\");\n    }\n\n    return success();\n  }\n\n private:\n  int const_val_;\n};\n\nstruct CastOpLowering final : public OpConversionPattern<CastOp> {\n public:\n  using OpConversionPattern<CastOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(CastOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    auto output = op.getOut().getType();\n    auto input = op.getIn();\n    rewriter.replaceOpWithNewOp<tosa::CastOp>(op, output, input);\n    return success();\n  }\n};\n\nstruct ReluOpLowering final : public OpConversionPattern<ReluOp> {\n public:\n  using OpConversionPattern<ReluOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(ReluOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    const auto output = op.getY().getType();\n    auto input = op.getX();\n\n    auto ranked_output = llvm::dyn_cast_or_null<RankedTensorType>(output);\n    auto value =\n        DenseElementsAttr::get(output.cast<ShapedType>(),\n                               rewriter.getZeroAttr(ranked_output ? ranked_output.getElementType()\n                                                                  : rewriter.getI64Type()));\n    tosa::ConstOp zeros = rewriter.create<tosa::ConstOp>(op.getLoc(), output, value);\n    rewriter.replaceOpWithNewOp<tosa::MaximumOp>(op, output, input, zeros);\n    return success();\n  }\n};\n\nstruct BroadcastAddOpLowering final : public OpConversionPattern<BroadcastAddOp> {\n public:\n  using OpConversionPattern<BroadcastAddOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(BroadcastAddOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    const auto output = op.getZ().getType();\n    auto input1 = op.getX();\n    auto input2 = op.getY();\n\n    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, output, input1, input2);\n    return success();\n  }\n};\n\nstruct Add2OpLowering final : public OpConversionPattern<Add2Op> {\n public:\n  using OpConversionPattern<Add2Op>::OpConversionPattern;\n  LogicalResult matchAndRewrite(Add2Op op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    const auto output = op.getOut().getType();\n    auto input1 = op.getIn0();\n    auto input2 = op.getIn1();\n\n    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, output, input1, input2);\n    return success();\n  }\n};\n\nstruct AvgPool2DOpLowering final : public OpConversionPattern<AvgPool2DOp> {\n public:\n  using OpConversionPattern<AvgPool2DOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(AvgPool2DOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    auto get_pair_int64_from_array = [](ArrayAttr arr) -> std::pair<int64_t, int64_t> {\n      return {arr.getValue()[0].cast<IntegerAttr>().getSInt(),\n              arr.getValue()[1].cast<IntegerAttr>().getSInt()};\n    };\n\n    auto stride_pairs = get_pair_int64_from_array(op.getStride());\n    auto pad_pairs = get_pair_int64_from_array(op.getPadding());\n    auto kernel_pairs = get_pair_int64_from_array(op.getKernelSize());\n\n    auto loc = op.getLoc();\n    auto perms = {0, 2, 3, 1};\n\n    const auto kernel = rewriter.getDenseI64ArrayAttr({kernel_pairs.first, kernel_pairs.second});\n    const auto stride = 
rewriter.getDenseI64ArrayAttr({stride_pairs.first, stride_pairs.second});\n    const auto pad = rewriter.getDenseI64ArrayAttr(\n        {pad_pairs.first, pad_pairs.second, pad_pairs.first, pad_pairs.second});\n\n    auto input = CreateTransposeValue(loc, rewriter, op.getX(), perms);\n    auto output = CreateTransposeType(op.getY().getType().cast<ShapedType>(), perms);\n\n    auto avg_pool2d = rewriter.create<tosa::AvgPool2dOp>(loc, output, input, kernel, stride, pad);\n\n    auto out = CreateTransposeValue(loc, rewriter, avg_pool2d, {0, 3, 1, 2});\n    rewriter.replaceOp(op, {out});\n    return success();\n  }\n};\n\nstruct MaxPool2DOpLowering final : public OpConversionPattern<MaxPool2DOp> {\n public:\n  using OpConversionPattern<MaxPool2DOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(MaxPool2DOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    auto get_pair_int64_from_array = [](ArrayAttr arr) -> std::pair<int64_t, int64_t> {\n      return {arr.getValue()[0].cast<IntegerAttr>().getSInt(),\n              arr.getValue()[1].cast<IntegerAttr>().getSInt()};\n    };\n    // TODO: support returning indices\n    if (op.getReturnIndices()) { return op->emitError(\"returning indices is not supported yet\"); }\n    auto stride_pairs = get_pair_int64_from_array(op.getStride());\n    auto kernel_pairs = get_pair_int64_from_array(op.getKernelSize());\n    auto pad_pairs = get_pair_int64_from_array(op.getPadding());\n\n    auto loc = op.getLoc();\n\n    const auto kernel = rewriter.getDenseI64ArrayAttr({kernel_pairs.first, kernel_pairs.second});\n    const auto stride = rewriter.getDenseI64ArrayAttr({stride_pairs.first, stride_pairs.second});\n    const auto pad = rewriter.getDenseI64ArrayAttr(\n        {pad_pairs.first, pad_pairs.second, pad_pairs.first, pad_pairs.second});\n\n    auto input = op.getX();\n    auto out_type = op.getY().getType().cast<ShapedType>();\n\n    Value y;\n    if (op.IsNCHW()) {\n      auto perms = {0, 2, 3, 1};\n      auto reverse_perms = {0, 3, 1, 2};\n      input = CreateTransposeValue(loc, rewriter, input, perms);\n      out_type = CreateTransposeType(out_type, perms);\n      auto max_pool2d =\n          rewriter.create<tosa::MaxPool2dOp>(loc, out_type, input, kernel, stride, pad);\n      y = CreateTransposeValue(loc, rewriter, max_pool2d, reverse_perms);\n    } else {\n      y = rewriter.create<tosa::MaxPool2dOp>(loc, out_type, input, kernel, stride, pad);\n    }\n\n    auto indice_output = convertToSignless(op->getContext(), op.getIndice().getType());\n    auto value = DenseElementsAttr::get(indice_output.cast<ShapedType>(),\n                                        rewriter.getZeroAttr(rewriter.getI64Type()));\n    tosa::ConstOp indice = rewriter.create<tosa::ConstOp>(loc, indice_output, value);\n    rewriter.replaceOp(op, {y, indice});\n    return success();\n  }\n};\n\nstruct ReshapeOpLowering final : public OpConversionPattern<ReshapeOp> {\n public:\n  using OpConversionPattern<ReshapeOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    auto output = op.getOut().getType();\n    auto input = op.getIn();\n    llvm::SmallVector<int64_t> new_shape;\n    for (const auto& dim_attr : op.getShape()) {\n      new_shape.push_back(dim_attr.cast<IntegerAttr>().getSInt());\n    }\n    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, output, input,\n                                        
         rewriter.getDenseI64ArrayAttr(new_shape));\n    return success();\n  }\n};\n\n// transpose the last two dims of the tensor. Reshape it to 3D if it is 2D.\nValue transposeAndReshapeIfRequired(Location loc, ConversionPatternRewriter& rewriter, Value matrix,\n                                    bool transpose) {\n  auto shape_type = matrix.getType().cast<ShapedType>();\n  CHECK(shape_type.getRank() == 2 || shape_type.getRank() == 3);\n  if (transpose) {\n    if (shape_type.getRank() == 2) {\n      matrix = CreateTransposeValue(loc, rewriter, matrix, {1, 0});\n      shape_type = matrix.getType().cast<ShapedType>();\n      llvm::SmallVector<int64_t, 4> reshape_dims{1, shape_type.getDimSize(0),\n                                                 shape_type.getDimSize(1)};\n      auto reshape_type = RankedTensorType::get(reshape_dims, shape_type.getElementType());\n      return rewriter.create<tosa::ReshapeOp>(loc, reshape_type, matrix,\n                                              rewriter.getDenseI64ArrayAttr(reshape_dims));\n    } else if (shape_type.getRank() == 3) {\n      return CreateTransposeValue(loc, rewriter, matrix, {0, 2, 1});\n    } else {\n      return Value{};\n    }\n  } else if (shape_type.getRank() == 2) {\n    llvm::SmallVector<int64_t, 4> reshape_dims{1, shape_type.getDimSize(0),\n                                               shape_type.getDimSize(1)};\n    auto reshape_type = RankedTensorType::get(reshape_dims, shape_type.getElementType());\n    return rewriter.create<tosa::ReshapeOp>(loc, reshape_type, matrix,\n                                            rewriter.getDenseI64ArrayAttr(reshape_dims));\n  }\n  return matrix;\n}\n\n// Reshape: 2D -> 3D -> tosa.matmul -> 3D -> 2D\nstruct MatmulOpLowering final : public OpConversionPattern<MatmulOp> {\n public:\n  using OpConversionPattern<MatmulOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(MatmulOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    auto a = transposeAndReshapeIfRequired(op->getLoc(), rewriter, op.getA(), op.getTransposeA());\n    auto b = transposeAndReshapeIfRequired(op->getLoc(), rewriter, op.getB(), op.getTransposeB());\n\n    const auto out_shape_type = op.getOut().getType().cast<ShapedType>();\n    const auto out_reshape_type =\n        RankedTensorType::get({1, out_shape_type.getDimSize(0), out_shape_type.getDimSize(1)},\n                              out_shape_type.getElementType());\n\n    auto matmul = rewriter.create<tosa::MatMulOp>(op.getLoc(), out_reshape_type, a, b);\n    const auto new_shape =\n        rewriter.getDenseI64ArrayAttr({out_shape_type.getDimSize(0), out_shape_type.getDimSize(1)});\n\n    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, out_shape_type, matmul, new_shape);\n    return success();\n  }\n};\n\nstruct BatchMatmulOpLowering final : public OpConversionPattern<BatchMatmulOp> {\n public:\n  using OpConversionPattern<BatchMatmulOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(BatchMatmulOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    auto a = transposeAndReshapeIfRequired(op->getLoc(), rewriter, op.getA(), op.getTransposeA());\n    auto b = transposeAndReshapeIfRequired(op->getLoc(), rewriter, op.getB(), op.getTransposeB());\n    rewriter.replaceOpWithNewOp<tosa::MatMulOp>(op, op.getOut().getType(), a, b);\n    return success();\n  }\n};\n\nstruct NormalizationInferenceOpLowering final\n    : public 
OpConversionPattern<NormalizationInferenceOp> {\n public:\n  using OpConversionPattern<NormalizationInferenceOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(NormalizationInferenceOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    auto loc = op->getLoc();\n\n    const auto epsilon_type = RankedTensorType::get({}, rewriter.getF32Type());\n    auto epsilon = rewriter.create<tosa::ConstOp>(\n        loc, epsilon_type, DenseElementsAttr::get(epsilon_type, op.getEpsilon()));\n    auto mean = op.getMovingMean();\n    auto variance = op.getMovingVariance();\n    auto gamma = op.getGamma();\n    auto beta = op.getBeta();\n    auto output_type = op.getY().getType();\n    Value x = op.getX();\n\n    if (op.IsNCHW()) {\n      const auto perms = {0, 2, 3, 1};\n      x = CreateTransposeValue(loc, rewriter, x, perms);\n      output_type = CreateTransposeType(output_type, perms);\n    }\n\n    auto batch_norm =\n        oneflow::CreateBNOp(loc, rewriter, output_type, x, mean, variance, epsilon, gamma, beta);\n\n    if (op.IsNCHW()) {\n      const auto reverse_perms = {0, 3, 1, 2};\n      batch_norm = CreateTransposeValue(loc, rewriter, batch_norm, reverse_perms);\n    }\n    rewriter.replaceOp(op, {batch_norm});\n    return success();\n  }\n};\n\nstruct NormalizationOpLowering final : public OpConversionPattern<NormalizationOp> {\n public:\n  using OpConversionPattern<NormalizationOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(NormalizationOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    auto loc = op->getLoc();\n\n    const auto epsilon_type = RankedTensorType::get({}, rewriter.getF32Type());\n    auto epsilon = rewriter.create<tosa::ConstOp>(\n        loc, epsilon_type, DenseElementsAttr::get(epsilon_type, op.getEpsilon()));\n    auto mean = op.getMovingMean();\n    auto variance = op.getMovingVariance();\n    auto gamma = op.getGamma();\n    auto beta = op.getBeta();\n    auto output_type = op.getY().getType();\n    Value x = op.getX();\n\n    if (op.IsNCHW()) {\n      const auto perms = {0, 2, 3, 1};\n      x = CreateTransposeValue(loc, rewriter, x, perms);\n      output_type = CreateTransposeType(output_type, perms);\n    }\n\n    auto batch_norm =\n        oneflow::CreateBNOp(loc, rewriter, output_type, x, mean, variance, epsilon, gamma, beta);\n\n    if (op.IsNCHW()) {\n      const auto reverse_perms = {0, 3, 1, 2};\n      batch_norm = CreateTransposeValue(loc, rewriter, batch_norm, reverse_perms);\n    }\n    auto moving_mean = op.getMovingMean();\n    auto moving_variance = op.getMovingVariance();\n\n    rewriter.replaceOp(op, {batch_norm, moving_mean, moving_variance});\n    return success();\n  }\n};\n\nstruct Conv2DOpLowering final : public OpConversionPattern<Conv2DOp> {\n public:\n  using OpConversionPattern<Conv2DOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(Conv2DOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    auto get_pair_int64_from_array = [](ArrayAttr arr) -> std::pair<int64_t, int64_t> {\n      return {arr.getValue()[0].cast<IntegerAttr>().getSInt(),\n              arr.getValue()[1].cast<IntegerAttr>().getSInt()};\n    };\n\n    auto stride_pairs = get_pair_int64_from_array(op.getStrides());\n    auto pad_pairs = get_pair_int64_from_array(op.getPaddingBeforeAttr());\n    auto dilation_pairs = get_pair_int64_from_array(op.getDilationRate());\n\n    
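// Pack the (h, w) attribute pairs into the dense-array attributes built below for\n    // tosa.conv2d: a 4-element pad array plus 2-element stride and dilation arrays.\n    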
const auto pad = rewriter.getDenseI64ArrayAttr(\n        {pad_pairs.first, pad_pairs.second, pad_pairs.first, pad_pairs.second});\n    const auto stride = rewriter.getDenseI64ArrayAttr({stride_pairs.first, stride_pairs.second});\n    const auto dilation =\n        rewriter.getDenseI64ArrayAttr({dilation_pairs.first, dilation_pairs.second});\n\n    auto bias = op.getBias();\n    auto loc = op.getLoc();\n    if (!bias) {\n      const auto output_shape = op.getOut().getType().cast<ShapedType>();\n      // support nhwc\n      const auto output_channels = output_shape.getDimSize(op.IsNCHW() ? 1 : 3);\n      const auto bias_elem_type = output_shape.getElementType();\n      const auto type = RankedTensorType::get(output_channels, bias_elem_type);\n      bias = rewriter.create<tosa::ConstOp>(\n          op.getLoc(), type, DenseElementsAttr::get(type, rewriter.getZeroAttr(bias_elem_type)));\n    }\n\n    Value in = op.getIn();\n    Value weight = op.getWeight();\n    auto out_type = op.getOut().getType().cast<ShapedType>();\n    if (out_type.getRank() != 4) {\n      op->dump();\n      LOG(FATAL) << \"Failed to lower oneflow op\";\n    }\n    // support nhwc\n    if (op.IsNCHW()) {\n      const auto perms = {0, 2, 3, 1};\n      const auto reverse_perms = {0, 3, 1, 2};\n      in = CreateTransposeValue(loc, rewriter, in, perms);\n      weight = CreateTransposeValue(loc, rewriter, weight, perms);\n      out_type = CreateTransposeType(out_type, perms);\n      auto conv2d =\n          rewriter.create<tosa::Conv2DOp>(loc, out_type, in, weight, bias, pad, stride, dilation);\n\n      auto res = CreateTransposeValue(loc, rewriter, conv2d, reverse_perms);\n      rewriter.replaceOp(op, {res});\n    } else {\n      rewriter.replaceOpWithNewOp<tosa::Conv2DOp>(op, out_type, in, weight, bias, pad, stride,\n                                                  dilation);\n    }\n    return success();\n  }\n};\n\nstruct TransposeOpLowering final : public OpConversionPattern<TransposeOp> {\n public:\n  using OpConversionPattern<TransposeOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(TransposeOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    llvm::SmallVector<int32_t, 4> perms{};\n    for (auto dim : op.getPerm().getAsValueRange<mlir::IntegerAttr>()) {\n      perms.push_back(dim.getSExtValue());\n    }\n    llvm::SmallVector<int64_t, 4> perms_shape(op.getPerm().size(), 1);\n    auto perms_op = rewriter.create<tosa::ConstOp>(\n        op->getLoc(), RankedTensorType::get(perms_shape, rewriter.getI32Type()),\n        rewriter.getI32TensorAttr(perms));\n    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(op, op.getOutput().getType(), op.getInput(),\n                                                   perms_op.getOutput());\n    return success();\n  }\n};\n\nstruct CastInputConversion final : public OpRewritePattern<InputOp> {\n public:\n  explicit CastInputConversion(mlir::MLIRContext* context)\n      : OpRewritePattern<InputOp>(context, /*benefit=*/0) {}\n  mlir::LogicalResult matchAndRewrite(InputOp op, mlir::PatternRewriter& rewriter) const override {\n    auto outType = op.getOutput().getType();\n    if (isSignLessTensorOrOther(outType)) { return failure(); }\n    if (op->hasOneUse()) {\n      if (auto cast =\n              llvm::dyn_cast<UnrealizedConversionCastOp>(op.getOutput().use_begin()->getOwner())) {\n        if (isSignLessTensorOrOther(cast.getResult(0).getType())) { return failure(); }\n      }\n    }\n    InputOp cloned = 
rewriter.create<InputOp>(op->getLoc(), op.getResultTypes(), op->getOperands(),\n                                              op->getAttrs());\n    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(\n        op, convertToSignless(getContext(), op.getOutput().getType()), cloned.getOutput());\n    return success();\n  }\n};\n\nstruct CastVariableConversion final : public OpRewritePattern<VariableOp> {\n public:\n  explicit CastVariableConversion(mlir::MLIRContext* context)\n      : OpRewritePattern<VariableOp>(context, /*benefit=*/0) {}\n  mlir::LogicalResult matchAndRewrite(VariableOp op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    auto outType = op.getOutput().getType();\n    if (isSignLessTensorOrOther(outType)) { return failure(); }\n    if (op->hasOneUse()) {\n      if (auto cast =\n              llvm::dyn_cast<UnrealizedConversionCastOp>(op.getOutput().use_begin()->getOwner())) {\n        if (isSignLessTensorOrOther(cast.getResult(0).getType())) { return failure(); }\n      }\n    }\n    if (op.getOutput().getUses().empty()) { return failure(); }\n    VariableOp cloned = rewriter.create<VariableOp>(op->getLoc(), op.getResultTypes(),\n                                                    op->getOperands(), op->getAttrs());\n    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(\n        op, convertToSignless(getContext(), op.getOutput().getType()), cloned.getOutput());\n    return success();\n  }\n};\n\nnamespace {\n\nclass CastOneFlowOpsToSignlessPass\n    : public CastOneFlowOpsToSignlessPassBase<CastOneFlowOpsToSignlessPass> {\n  void getDependentDialects(::mlir::DialectRegistry& registry) const override {\n    registry.insert<oneflow::OneFlowDialect>();\n  }\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(&getContext());\n    patterns.add<oneflow::CastInputConversion, oneflow::CastVariableConversion>(op->getContext());\n\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\nstruct OneFlowLoweringToTosaPass : public LowerOneFlowToTosaPassBase<OneFlowLoweringToTosaPass> {\n  void runOnOperation() override;\n};\n\nstruct ConvertToSignlessForTosaPass\n    : public ConvertToSignlessForTosaPassBase<ConvertToSignlessForTosaPass> {\n  void runOnOperation() override;\n};\n\n}  // namespace\n\nstd::unique_ptr<Pass> createLowerOneFlowToTosaPass() {\n  return std::make_unique<OneFlowLoweringToTosaPass>();\n}\n\nstd::unique_ptr<Pass> createConvertToSignlessForTosaPass() {\n  return std::make_unique<ConvertToSignlessForTosaPass>();\n}\n\nvoid OneFlowLoweringToTosaPass::runOnOperation() {\n  MLIRContext* context = &getContext();\n  ConversionTarget target(*context);\n  target.addLegalDialect<memref::MemRefDialect, mlir::func::FuncDialect, tosa::TosaDialect,\n                         tensor::TensorDialect, arith::ArithDialect, BuiltinDialect>();\n  if (fullyConvert) { target.addIllegalDialect<OneFlowDialect>(); }\n\n  TypeConverter typeConverter;\n  typeConverter.addConversion([context](Type type) { return convertToSignless(context, type); });\n  typeConverter.addSourceMaterialization(\n      [&](OpBuilder& builder, Type resultType, ValueRange inputs, Location loc) -> Optional<Value> {\n        CHECK_EQ(inputs.size(), 1) << \"expect to materialize a single value\";\n        return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs).getResult(0);\n      });\n  typeConverter.addTargetMaterialization(\n      [&](OpBuilder& builder, Type 
resultType, ValueRange inputs, Location loc) -> Optional<Value> {\n        CHECK_EQ(inputs.size(), 1) << \"expect to materialize a single value\";\n        return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs).getResult(0);\n      });\n  RewritePatternSet patterns(context);\n\n  // check if the pass is triggered by python based on the presence of the variable tensor manager\n  if (fullyConvert) {\n    if (::oneflow::Singleton<::oneflow::VariableTensorMgr>::Get()) {\n      patterns.add<VariableOpLowering>(typeConverter, context);\n    } else {\n      patterns.add<VariableOpToConstLowering>(typeConverter, context, this->variableAsConstant);\n    }\n  }\n  patterns.add<CastOpLowering, ScalarMulByTensorOpLowering, ReluOpLowering, Conv2DOpLowering,\n               AvgPool2DOpLowering, ReshapeOpLowering, Add2OpLowering, MaxPool2DOpLowering,\n               MatmulOpLowering, BatchMatmulOpLowering, BroadcastAddOpLowering,\n               NormalizationOpLowering, NormalizationInferenceOpLowering, TransposeOpLowering>(\n      typeConverter, context);\n  if (lowerJob) {\n    patterns.add<InputOpLowering, OutputOpLowering, JobLowering, ReturnOpLowering>(typeConverter,\n                                                                                   context);\n  }\n  if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) {\n    signalPassFailure();\n    LOG(ERROR) << \"Failed to lower OneFlow to Tosa\";\n    getOperation()->dump();\n  }\n}\n\nstruct ConvertReturnToSignlessPattern : public OpRewritePattern<func::ReturnOp> {\n  explicit ConvertReturnToSignlessPattern(::mlir::MLIRContext* context)\n      : OpRewritePattern<func::ReturnOp>(context, /*benefit=*/1) {}\n  ::mlir::LogicalResult matchAndRewrite(func::ReturnOp op,\n                                        ::mlir::PatternRewriter& rewriter) const override {\n    // skip if no operand type needs conversion\n    if (allSignless(op.getOperandTypes())) { return failure(); }\n    llvm::SmallVector<Type, 1> results;\n    for (auto res : op->getOperandTypes()) {\n      results.push_back(convertToSignless(op->getContext(), res));\n    }\n    auto uc = rewriter.create<UnrealizedConversionCastOp>(op->getLoc(), results, op.getOperands());\n    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, op->getResultTypes(), uc->getResults(),\n                                                op->getAttrs());\n    return success();\n  }\n};\n\nstruct ConvertFuncToSignlessPattern : public OpRewritePattern<func::FuncOp> {\n  explicit ConvertFuncToSignlessPattern(::mlir::MLIRContext* context)\n      : OpRewritePattern<func::FuncOp>(context, /*benefit=*/1) {}\n  ::mlir::LogicalResult matchAndRewrite(func::FuncOp op,\n                                        ::mlir::PatternRewriter& rewriter) const override {\n    if (allSignless(op.getFunctionType())) { return failure(); }\n    auto ft = convertToSignlessFuncType(op->getContext(), op.getFunctionType());\n    auto func = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getName(), ft);\n    IRMapping bvm;\n    op.getRegion().cloneInto(&func.getRegion(), bvm);\n    for (auto& block : func.getBody().getBlocks()) {\n      for (auto arg : block.getArguments()) {\n        auto new_type = convertToSignless(op.getContext(), arg.getType());\n        arg.setType(new_type);\n        for (auto* use : arg.getUsers()) {\n          if (auto input = llvm::dyn_cast_or_null<InputOp>(use)) {\n            input.getOutput().setType(new_type);\n          }\n        }\n      }\n    }\n    rewriter.eraseOp(op);\n    
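// The freshly cloned body may still contain func.return ops with signed operand types;\n    // greedily apply ConvertReturnToSignlessPattern so unrealized casts bring the returned\n    // values in line with the new signless function type.\n    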
RewritePatternSet patterns(func->getContext());\n    patterns.add<ConvertReturnToSignlessPattern>(func->getContext());\n    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));\n    return success();\n  }\n};\n\nvoid ConvertToSignlessForTosaPass::runOnOperation() {\n  Operation* op = getOperation();\n  RewritePatternSet patterns(op->getContext());\n  patterns.add<ConvertFuncToSignlessPattern>(op->getContext());\n  (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n}\n\nstd::unique_ptr<Pass> createCastOneFlowOpsToSignlessPass() {\n  return std::make_unique<CastOneFlowOpsToSignlessPass>();\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKL/Conversion/Conversion.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OKL/Conversion/Conversion.h\"\n#include \"OneFlow/OKL/Conversion/OKLToLLVM.h\"\n#include \"OneFlow/Passes.h\"\n#include \"OneFlow/Transform/OutlineAndFuse.h\"\n#include \"mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h\"\n#include \"mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"oneflow/ir/include/OneFlow/OneFlowUtils.h\"\n\nnamespace mlir {\nnamespace okl {\n\nLogicalResult LowerOKLComputeToLLVM(ModuleOp module) {\n  PassManager pm(module->getContext());\n  pm.addPass(createLowerLauncherToLLVMPtrPass());    // lower-launcher-to-llvm-ptr\n  pm.addPass(createLowerOKLToLLVMCallPass());        // lower-okl-to-llvm-call\n  pm.addPass(createConvertFuncToLLVMPass());         // convert-func-to-llvm\n  pm.addPass(createReconcileUnrealizedCastsPass());  // reconcile-unrealized-casts\n  oneflow::CheckEnableIRPrinting(pm);\n  return pm.run(module);\n}\n\n}  // namespace okl\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKL/Conversion/CudaGraphSupport.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"OneFlow/OKL/Kernel/JITEngine.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"OneFlow/OKL/Kernel/RegContext.h\"\n#include \"OneFlow/OKL/OKLDialect.h\"\n#include \"OneFlow/OKL/OKLOps.h\"\n#include \"OneFlow/OKL/OKLTypes.h\"\n#include \"OneFlow/OKL/passes.h\"\n#include \"OneFlow/OKM/passes.h\"\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/Passes.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMDialect.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMTypes.h\"\n#include \"mlir/IR/IRMapping.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/OperationSupport.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/Transforms/DialectConversion.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n#include \"mlir/Transforms/Passes.h\"\n#include \"llvm/Support/raw_ostream.h\"\n\nnamespace mlir {\nnamespace okl {\nstruct TagCudaGraphSupportPattern final : public mlir::OpRewritePattern<func::FuncOp> {\n  static mlir::Operation* FindOneFlowOp(mlir::Operation* op) {\n    mlir::Operation* reg_op = nullptr;\n    for (auto& op_it : op->getRegion(0).front().getOperations()) {\n      if (op_it.getDialect()->getNamespace() != \"oneflow\") { continue; }\n      reg_op = &op_it;\n      break;\n    }\n    return reg_op;\n  }\n\n  static LogicalResult CheckChild(func::FuncOp func) {\n    using namespace ::oneflow::user_op;\n    for (auto& op : func->getRegion(0).front()) {\n      if (auto reg_ctx_op = llvm::dyn_cast_or_null<mlir::okl::WrapperKernelOp>(&op)) {\n        // iter reg context op\n        const auto reg_op = FindOneFlowOp(&op);\n        if (!reg_op) {\n          func->emitError(\"Failed to find reg_op in okl.build_reg_context_op\");\n          return failure();\n        }\n        // generate kernel from oneflow.{compute op}\n        ::oneflow::okl::RegContext reg_ctx(reg_op);\n        auto* kernel = const_cast<OpKernel*>(reg_ctx.GetKernel());\n\n        // check whether cuda graph support is base class\n        if (const auto* cuda_graph_support = dynamic_cast<CudaGraphSupport*>(kernel)) {\n          // TODO: more check\n          continue;\n        }\n        return failure();\n      }\n    }\n    return success();\n  }\n\n public:\n  explicit TagCudaGraphSupportPattern(mlir::MLIRContext* context)\n      : OpRewritePattern<func::FuncOp>(context, /*benefit=*/0) {}\n  mlir::LogicalResult matchAndRewrite(func::FuncOp op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    const auto tag_name = mlir::okl::cuda_graph_support::TAG_NAME;\n    // check whether this op is okl init context function  op\n    if 
(!op.getSymName().startswith(mlir::okm::func_name::OKL_GRAPH_NAME)) { return failure(); }\n    // skip if this op has already been tagged\n    if (op->getAttr(tag_name).dyn_cast_or_null<BoolAttr>() != nullptr) { return failure(); }\n    // check whether its children are all cuda graph supported\n    const auto outcome = succeeded(CheckChild(op));\n\n    // set cuda graph support tag on init_context and compute function ops\n    op->setAttr(tag_name, rewriter.getBoolAttr(outcome));\n    return success();\n  }\n};\n\nnamespace {\nstruct TagCudaGraphSupportPass : public TagCudaGraphSupportPassBase<TagCudaGraphSupportPass> {\n  void runOnOperation() override;\n\n  void getDependentDialects(DialectRegistry& registry) const override {\n    registry.insert<okl::OKLDialect>();\n  }\n};\n}  // namespace\n\nstd::unique_ptr<Pass> createTagCudaGraphSupportPass() {\n  return std::make_unique<TagCudaGraphSupportPass>();\n}\n\nvoid TagCudaGraphSupportPass::runOnOperation() {\n  MLIRContext* context = &getContext();\n  RewritePatternSet patterns(context);\n\n  patterns.add<TagCudaGraphSupportPattern>(context);\n\n  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));\n}\n\n}  // namespace okl\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKL/Conversion/OKLToLLVM.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"OneFlow/OKL/Kernel/JITEngine.h\"\n#include \"OneFlow/OKL/OKLDialect.h\"\n#include \"OneFlow/OKL/OKLOps.h\"\n#include \"OneFlow/OKL/OKLTypes.h\"\n#include \"OneFlow/OKL/passes.h\"\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/Passes.h\"\n#include \"llvm/Support/raw_ostream.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMDialect.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMTypes.h\"\n#include \"mlir/IR/IRMapping.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/OperationSupport.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/Transforms/DialectConversion.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n#include \"mlir/Transforms/Passes.h\"\n\nnamespace mlir {\nnamespace okl {\n\ntemplate<typename Wrap, typename T>\nModuleOp GetModuleOpFromJobBodyOp(T op) {\n  auto parent_func_op = op->template getParentOfType<Wrap>();\n  if (!parent_func_op) { return nullptr; }\n  return parent_func_op->template getParentOfType<ModuleOp>();\n}\n\n// use this func to union the ptr type in this conversion phase.\nLLVM::LLVMPointerType GetPtrType(::mlir::PatternRewriter& rewriter) {\n  return LLVM::LLVMPointerType::get(IntegerType::get(rewriter.getContext(), 8));\n}\n\nstruct WrapperKernelOpLowering final : public OpConversionPattern<WrapperKernelOp> {\n  using OpConversionPattern<WrapperKernelOp>::OpConversionPattern;\n  using OpAdaptor = typename WrapperKernelOp::Adaptor;\n\n  static LLVM::LLVMFuncOp DeclareLaunchFunc(::mlir::PatternRewriter& rewriter, ModuleOp* module) {\n    LLVM::LLVMFuncOp func;\n    const auto func_name = ::oneflow::okl::llvm_func::LLVM_FUNC;\n    if (!(func = module->lookupSymbol<LLVM::LLVMFuncOp>(func_name))) {\n      OpBuilder::InsertionGuard guard(rewriter);\n      rewriter.setInsertionPointToStart(module->getBody());\n\n      auto void_type = LLVM::LLVMVoidType::get(rewriter.getContext());\n      auto func_type = LLVM::LLVMFunctionType::get(\n          void_type, {GetPtrType(rewriter), rewriter.getI64Type()}, false);\n\n      func = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), func_name, func_type,\n                                               LLVM::Linkage::External);\n      func->setAttr(\"llvm.emit_c_interface\", mlir::UnitAttr::get(rewriter.getContext()));\n    }\n    return func;\n  }\n\n  LogicalResult matchAndRewrite(WrapperKernelOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    auto module = GetModuleOpFromJobBodyOp<func::FuncOp>(op);\n    if (!module) { LOG(FATAL) << \"Failed to lowering llvm call because of op is not in a module\"; };\n\n    auto launch_func = DeclareLaunchFunc(rewriter, &module);\n    
auto launcher_ctx = op->getParentOfType<func::FuncOp>().getBody().getArgument(0);\n    auto index_op = rewriter.create<LLVM::ConstantOp>(op->getLoc(), rewriter.getI64Type(),\n                                                      rewriter.getIndexAttr(op.getIndex()));\n    auto new_op = rewriter.create<LLVM::CallOp>(op->getLoc(), launch_func,\n                                                ValueRange{launcher_ctx, index_op});\n    rewriter.replaceOp(op, new_op.getResults());\n    return success();\n  }\n};\n\n// erase the type of okl.launcher_ctx and use an opaque ptr instead\n// (okl.launcher_ctx -> llvm.ptr<i8>)\nstruct RewriteFunctionArgsPattern final : public mlir::OpRewritePattern<func::FuncOp> {\n  static LogicalResult ConvertLauncherToLLVMPtr(func::FuncOp op, mlir::PatternRewriter& rewriter) {\n    auto func_type = rewriter.getFunctionType({GetPtrType(rewriter)}, {});\n    auto func = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getSymName(), func_type);\n    func->setAttr(\"llvm.emit_c_interface\", mlir::UnitAttr::get(rewriter.getContext()));\n    IRMapping bvm;\n    op.getRegion().cloneInto(&func.getRegion(), bvm);\n    auto& block = func.getBody().getBlocks().front();\n    auto launcher_ctx = block.getArgument(0);\n\n    OpBuilder::InsertionGuard guard(rewriter);\n    rewriter.setInsertionPointToStart(&block);\n    auto cast_op = rewriter.create<UnrealizedConversionCastOp>(op->getLoc(), launcher_ctx.getType(),\n                                                               launcher_ctx);\n    launcher_ctx.setType(GetPtrType(rewriter));\n    launcher_ctx.replaceAllUsesExcept(cast_op->getResult(0), {cast_op});\n    rewriter.setInsertionPointToEnd(&block);\n    rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(&block.back(), ValueRange());\n    rewriter.eraseOp(op);\n    return success();\n  }\n\n public:\n  explicit RewriteFunctionArgsPattern(mlir::MLIRContext* context)\n      : OpRewritePattern<func::FuncOp>(context, /*benefit=*/0) {}\n  mlir::LogicalResult matchAndRewrite(func::FuncOp op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    if (op.getNumArguments() == 1\n        && op.getArgumentTypes().begin()->isa<okl::LauncherContextType>()) {\n      return ConvertLauncherToLLVMPtr(op, rewriter);\n    }\n    return failure();\n  }\n};\n\nnamespace {\nstruct LowerLauncherToLLVMPtrPass\n    : public LowerLauncherToLLVMPtrPassBase<LowerLauncherToLLVMPtrPass> {\n  void runOnOperation() override;\n\n  void getDependentDialects(DialectRegistry& registry) const override {\n    registry.insert<LLVM::LLVMDialect>();\n    registry.insert<okl::OKLDialect>();\n  }\n};\n\nstruct LowerOKLToLLVMCallPass : public LowerOKLToLLVMCallPassBase<LowerOKLToLLVMCallPass> {\n  void runOnOperation() override;\n\n  void getDependentDialects(DialectRegistry& registry) const override {\n    registry.insert<LLVM::LLVMDialect>();\n    registry.insert<okl::OKLDialect>();\n  }\n};\n}  // namespace\n\nstd::unique_ptr<Pass> createLowerOKLToLLVMCallPass() {\n  return std::make_unique<LowerOKLToLLVMCallPass>();\n}\nstd::unique_ptr<Pass> createLowerLauncherToLLVMPtrPass() {\n  return std::make_unique<LowerLauncherToLLVMPtrPass>();\n}\n\nvoid LowerLauncherToLLVMPtrPass::runOnOperation() {\n  MLIRContext* context = &getContext();\n  RewritePatternSet patterns(context);\n\n  patterns.add<RewriteFunctionArgsPattern>(context);\n\n  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));\n}\n\nvoid LowerOKLToLLVMCallPass::runOnOperation() {\n  MLIRContext* context = 
&getContext();\n  ConversionTarget target(*context);\n  target.addLegalDialect<LLVM::LLVMDialect>();\n  target.addIllegalDialect<okl::OKLDialect>();\n\n  auto llvm_ptr_type = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));\n  TypeConverter typeConverter;\n  typeConverter.addConversion([&](mlir::okl::LauncherContextType type) { return llvm_ptr_type; });\n\n  RewritePatternSet patterns(context);\n\n  patterns.add<WrapperKernelOpLowering>(typeConverter, context);\n\n  if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) {\n    signalPassFailure();\n    getOperation()->emitError(\"Failed to lower OKL to LLVM Call\");\n  }\n}\n\n}  // namespace okl\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKL/Kernel/ComputeContext.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OKL/Kernel/ComputeContext.h\"\n#include \"llvm/ADT/TypeSwitch.h\"\n#include \"llvm/Support/Casting.h\"\n#include \"OneFlow/OKL/OKLOps.h\"\n#include \"oneflow/core/common/shape_view.h\"\n\nnamespace oneflow {\nnamespace okl {\n\nuser_op::Tensor* ComputeContext::CreateTensorWithArgNameAndIndex(const std::string& arg_name,\n                                                                 int32_t index) {\n  auto op = reg_ctx_->GetOp();\n  auto source = mlir::oneflow::user_op::GetOpSourceByName(op, arg_name);\n\n  if (source.type == mlir::oneflow::user_op::Source::OUTPUT) {\n    if (op->getNumResults() <= index + source.offset) { return nullptr; }\n    mlir::Value val = op->getResult(index + source.offset);\n    auto use = *val.getUsers().begin();\n    if (auto ret_op = llvm::dyn_cast_or_null<mlir::okl::GetTensorAsRetOp>(use)) {\n      return comp_ctx_->Tensor4ArgNameAndIndex(\"out\", ret_op.getIndex());\n    }\n    if (auto pool_op = llvm::dyn_cast_or_null<mlir::okl::TensorToPoolOp>(use)) {\n      return tmp_buffer_.GetPoolTensor(TensorDesc4ArgNameAndIndex(arg_name, index),\n                                       pool_op.getOffset());\n    }\n    op->emitError(\"Failed to find \" + std::to_string(index) + \"in outputs\");\n    exit(1);\n  }\n\n  if (source.type == mlir::oneflow::user_op::Source::INPUT) {\n    if (op->getNumOperands() <= index + source.offset) { return nullptr; }\n    mlir::Value val = op->getOperand(index + source.offset);\n    auto define_op = val.getDefiningOp();\n    return llvm::TypeSwitch<::mlir::Operation*, user_op::Tensor*>(define_op)\n        .Case([&](mlir::okl::GetTensorFromArgOp elem) {\n          return comp_ctx_->Tensor4ArgNameAndIndex(\"in\", elem.getIndex());\n        })\n        .Case([&](mlir::okl::GetTensorFromRetOp elem) {\n          return comp_ctx_->Tensor4ArgNameAndIndex(\"out\", elem.getIndex());\n        })\n        .Case([&](mlir::okl::PoolToTensorOp elem) {\n          return tmp_buffer_.GetPoolTensor(TensorDesc4ArgNameAndIndex(arg_name, index),\n                                           elem.getOffset());\n        })\n        .Default([&](::mlir::Operation* op) {\n          op->dump();\n          LOG(FATAL) << \"Signature: \" << arg_name << \" Not supported\";\n          return nullptr;\n        });\n  }\n\n  if (source.type == mlir::oneflow::user_op::Source::BUFFER) {\n    auto wrap = op->getParentOfType<mlir::okl::WrapperKernelOp>();\n    for (auto& op : wrap.getBody().front()) {\n      if (auto pool_to_buffer = llvm::dyn_cast_or_null<mlir::okl::PoolToBufferOp>(op)) {\n        return tmp_buffer_.GetPoolBuffer(pool_to_buffer.getType().getShape()[0],\n                                         pool_to_buffer.getOffset());\n      }\n    }\n  }\n\n  op->emitError(\"Failed to check source type\");\n  exit(1);\n}\nuser_op::Tensor* ComputeContext::Tensor4ArgNameAndIndex(const std::string& arg_name,\n                            
                             int32_t index) {\n  auto it = tensor_.find({arg_name, index});\n  if (it != tensor_.end()) { return it->second; }\n  user_op::Tensor* res = CreateTensorWithArgNameAndIndex(arg_name, index);\n  tensor_[{arg_name, index}] = res;\n  return res;\n}\n\n}  // namespace okl\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKL/Kernel/InferContext.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OKL/Kernel/InferContext.h\"\n#include \"mlir/IR/MLIRContext.h\"\n#include \"mlir/Parser/Parser.h\"\n#include \"llvm/Support/Casting.h\"\n\nnamespace oneflow {\nnamespace okl {\nusing namespace user_op;\n\nInferContext::InferContext(const RegContext* reg_ctx) : reg_ctx_(reg_ctx) {}\n\nconst TensorDesc* InferContext::LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                                  int32_t index) const {\n  return reg_ctx_->TensorDesc4ArgNameAndIndex(arg_name, index);\n}\n\nconst Shape& InferContext::InputShape(const std::string& arg_name, int32_t index) const {\n  return Shape4ArgNameAndIndex(arg_name, index);\n}\n\nconst Shape& InferContext::Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const {\n  return LogicalTensorDesc4ArgNameAndIndex(arg_name, index)->shape();\n}\n\nconst std::shared_ptr<const AttrVal>& InferContext::Attr4Name(const std::string& attr_name) const {\n  return reg_ctx_->user_op_conf().Attr4Name(attr_name);\n}\n\n}  // namespace okl\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKL/Kernel/JITEngine.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/Extension.h\"\n#include \"llvm/ExecutionEngine/ExecutionEngine.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMDialect.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/Operation.h\"\n#include \"OneFlow/OKL/Kernel/JITEngine.h\"\n#include \"OneFlow/OKL/Kernel/ComputeContext.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n\nextern \"C\" {\nvoid okl_llvm_func(void* launcher, int64_t index) {\n  static_cast<typename std::tuple_element_t<0, oneflow::okl::LLVMLaunchArgs>>(launcher)->Launch(\n      index);\n}\n}  // extern \"C\"\n\nnamespace oneflow {\n\nSharedLibs* MutSharedLibPaths() {\n  static SharedLibs libs = {};\n  return &libs;\n}\n\nconst SharedLibs* SharedLibPaths() { return MutSharedLibPaths(); }\n}  // namespace oneflow\n\noneflow::okl::JITEngine::JITEngine(mlir::ModuleOp module) {\n  llvm::SmallVector<llvm::StringRef, 4> ext_libs(\n      {oneflow::SharedLibPaths()->begin(), oneflow::SharedLibPaths()->end()});\n  mlir::ExecutionEngineOptions jitOptions;\n  jitOptions.transformer = {};\n  jitOptions.jitCodeGenOptLevel = llvm::None;\n  jitOptions.sharedLibPaths = ext_libs;\n\n  auto jit_or_error = mlir::ExecutionEngine::create(module, jitOptions);\n  CHECK(!!jit_or_error) << \"failed to create JIT exe engine, \"\n                        << llvm::toString((jit_or_error).takeError());\n  jit_or_error->swap(engine_);\n}\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKL/Kernel/JITOpInfer.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/Passes.h\"\n#include \"OneFlow/OneFlowSupport.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/device_type.pb.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/IR/Block.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/OwningOpRef.h\"\n#include \"mlir/IR/Types.h\"\n#include \"mlir/InitAllDialects.h\"\n#include \"mlir/Parser/Parser.h\"\n\n#include \"llvm/ADT/SmallVector.h\"\n#include \"llvm/Support/raw_ostream.h\"\n\nnamespace oneflow {\n\nnamespace ir {\n\nnamespace jit {\n\nstatic Maybe<mlir::FunctionType> GetFunctionType(user_op::InferContext* ctx,\n                                                 mlir::OwningOpRef<mlir::ModuleOp>& module) {\n  mlir::func::FuncOp funcOp = mlir::SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(\n      module.get(), mlir::SymbolRefAttr::get(module->getContext(), ctx->op_name()));\n  CHECK_OR_RETURN(funcOp) << \"Fail to find funcOp of symbol \" << ctx->op_name();\n  const auto funcType = funcOp.getFunctionType();\n  CHECK_EQ_OR_RETURN(funcType.getNumInputs(), ctx->input_size(\"in\"))\n      << \"input size mismatch with mlir assembly\";\n  CHECK_EQ_OR_RETURN(funcType.getNumResults(), ctx->output_size(\"out\"))\n      << \"output size mismatch with mlir assembly\";\n  int32_t arg_i = 0;\n  for (mlir::Type arg_type : funcType.getInputs()) {\n    if (auto rankedTensorType = arg_type.dyn_cast<mlir::RankedTensorType>()) {\n      CHECK_EQ_OR_RETURN(\n          (Shape{rankedTensorType.getShape().begin(), rankedTensorType.getShape().end()}),\n          ctx->InputShape(\"in\", arg_i))\n          << \"arg #\" << arg_i;\n      const auto data_type =\n          mlir::oneflow::support::FromMLIRTypeToOFDataType(rankedTensorType.getElementType());\n      if (mlir::failed(data_type)) { exit(1); }\n      CHECK_EQ_OR_RETURN(data_type.value(), ctx->InputDType(\"in\", arg_i)) << \"arg #\" << arg_i;\n      arg_i += 1;\n    } else {\n      std::string arg_type_str = \"\";\n      llvm::raw_string_ostream os(arg_type_str);\n      arg_type.print(os);\n      THROW(RuntimeError) << \"Unsupported arg type \" << arg_type_str;\n    }\n  }\n  return funcType;\n}\n\nMaybe<void> SetTensorDataType(user_op::InferContext* ctx) {\n  auto mlir_assembly = ctx->Attr<std::vector<char>>(\"mlir_assembly\");\n  mlir::DialectRegistry registry;\n  mlir::registerAllDialects(registry);\n  mlir::MLIRContext context(registry);\n  context.loadDialect<mlir::func::FuncDialect>();\n  
context.loadDialect<mlir::oneflow::OneFlowDialect>();\n\n  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::parseSourceString<mlir::ModuleOp>(\n      llvm::StringRef(mlir_assembly.data(), mlir_assembly.size() - 1), &context);\n  if (!module) {\n    LOG(ERROR) << \"Fail to load mlir assembly\";\n    exit(1);\n  }\n\n  if ((*module)->hasAttr(mlir::oneflow::jit::RAW_GRAPH)) {\n    auto raw_graph = (*module)->getAttr(mlir::oneflow::jit::RAW_GRAPH).cast<mlir::StringAttr>();\n    if (raw_graph)\n      module = mlir::parseSourceString<mlir::ModuleOp>(raw_graph.strref(), module->getContext());\n  }\n\n  auto funcType = *JUST(GetFunctionType(ctx, module));\n  int32_t res_i = 0;\n  for (mlir::Type res_type : funcType.getResults()) {\n    if (auto rankedTensorType = res_type.dyn_cast<mlir::RankedTensorType>()) {\n      const auto data_type =\n          mlir::oneflow::support::FromMLIRTypeToOFDataType(rankedTensorType.getElementType());\n      if (mlir::failed(data_type)) { exit(1); }\n      ctx->SetDtype4ArgNameAndIndex(\"out\", res_i, data_type.value());\n      res_i += 1;\n    } else {\n      std::string res_type_str = \"\";\n      llvm::raw_string_ostream os(res_type_str);\n      res_type.print(os);\n      THROW(RuntimeError) << \"Unsupported result type \" << res_type_str;\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferTensorDesc(user_op::InferContext* ctx) {\n  auto mlir_assembly = ctx->Attr<std::vector<char>>(\"mlir_assembly\");\n  mlir::DialectRegistry registry;\n  mlir::registerAllDialects(registry);\n  mlir::MLIRContext context(registry);\n  context.loadDialect<mlir::func::FuncDialect>();\n  context.loadDialect<mlir::oneflow::OneFlowDialect>();\n\n  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::parseSourceString<mlir::ModuleOp>(\n      llvm::StringRef(mlir_assembly.data(), mlir_assembly.size() - 1), &context);\n  if (!module) {\n    LOG(ERROR) << \"Fail to load mlir assembly\";\n    exit(1);\n  }\n\n  if ((*module)->hasAttr(mlir::oneflow::jit::RAW_GRAPH)) {\n    auto raw_graph = (*module)->getAttr(mlir::oneflow::jit::RAW_GRAPH).cast<mlir::StringAttr>();\n    if (raw_graph)\n      module = mlir::parseSourceString<mlir::ModuleOp>(raw_graph.strref(), module->getContext());\n  }\n\n  auto funcType = *JUST(GetFunctionType(ctx, module));\n  int32_t res_i = 0;\n  for (mlir::Type res_type : funcType.getResults()) {\n    if (auto rankedTensorType = res_type.dyn_cast<mlir::RankedTensorType>()) {\n      ctx->SetOutputShape(\n          \"out\", res_i,\n          Shape{rankedTensorType.getShape().begin(), rankedTensorType.getShape().end()});\n\n      const auto data_type =\n          mlir::oneflow::support::FromMLIRTypeToOFDataType(rankedTensorType.getElementType());\n      if (mlir::failed(data_type)) { exit(1); }\n      ctx->SetOutputDType(\"out\", res_i, data_type.value());\n      llvm::SmallVector<int64_t> strides;\n      int64_t _;\n      auto mem_type =\n          mlir::MemRefType::get(rankedTensorType.getShape(), rankedTensorType.getElementType());\n      if (failed(mlir::getStridesAndOffset(mem_type, strides, _))) {\n        LOG(FATAL) << \"Fail to get stride from memory type\";\n      }\n      ctx->SetOutputStride(\"out\", res_i, Stride(strides.begin(), strides.end()));\n      res_i += 1;\n    } else {\n      std::string res_type_str = \"\";\n      llvm::raw_string_ostream os(res_type_str);\n      res_type.print(os);\n      THROW(RuntimeError) << \"Unsupported result type \" << res_type_str;\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace jit\n\n}  // namespace 
ir\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKL/Kernel/KernelLaunchOp.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"OneFlow/OKL/Conversion/Conversion.h\"\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/UserOpReflection.h\"\n#include \"OneFlow/Passes.h\"\n#include \"OneFlow/Extension.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/common/switch_func.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/kernel/blob_tensor_view.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"OneFlow/OKL/Kernel/JITOpInfer.h\"\n#include \"OneFlow/OKL/Kernel/JITEngine.h\"\n#include \"OneFlow/OKL/Kernel/LauncherState.h\"\n#include \"OneFlow/OKL/Kernel/TmpBufferManager.h\"\n\n#include \"mlir/IR/DialectRegistry.h\"\n#include \"mlir/Parser/Parser.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/Dialect/Linalg/IR/Linalg.h\"\n#include \"mlir/ExecutionEngine/ExecutionEngine.h\"\n#include \"mlir/ExecutionEngine/MemRefUtils.h\"\n#include \"mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h\"\n\n#include \"llvm/Support/Error.h\"\n#include \"llvm/Support/TargetSelect.h\"\n\n#include <memory>\n#include <tuple>\n#include <utility>\n#include <sys/types.h>\n\nnamespace oneflow {\n\nMaybe<void> KernelLaunchOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return ir::jit::InferTensorDesc(ctx);\n}\n\nMaybe<void> KernelLaunchOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return ir::jit::InferTensorDesc(ctx);\n}\n\nMaybe<void> KernelLaunchOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> KernelLaunchOp::InferDataType(user_op::InferContext* ctx) {\n  return ir::jit::SetTensorDataType(ctx);\n}\n\nnamespace {\n\nusing namespace oneflow::okl;\n\ntemplate<typename T>\nclass KernelLaunchKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  KernelLaunchKernel() = default;\n  ~KernelLaunchKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    // use ctx to create module, reg_ctx and fn;\n    std::shared_ptr<user_op::OpKernelState> res(new LauncherState(ctx));\n    return res;\n  }\n\n  bool IsCudaGraphSupported(user_op::KernelInitContext* ctx,\n                            user_op::OpKernelState* state) const override {\n    return dynamic_cast<LauncherState*>(state)->IsCudaGraphSupported(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    
auto* okl_state = dynamic_cast<LauncherState*>(state);\n    okl_state->DoCompute(ctx);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_KERNEL_LAUNCH_CPU_KERNEL(dtype)                                                \\\n  REGISTER_USER_KERNEL(\"kernel_launch\")                                                         \\\n      .SetCreateFn<KernelLaunchKernel<dtype>>()                                                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                           \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value))        \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                       \\\n        return oneflow::okl::TmpBufferManager::InferTmpSize(ctx);                               \\\n      })                                                                                        \\\n      .SetInplaceProposalFn([](const user_op::InferContext&,                                    \\\n                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \\\n        return Maybe<void>::Ok();                                                               \\\n      });\n\nREGISTER_KERNEL_LAUNCH_CPU_KERNEL(float)\nREGISTER_KERNEL_LAUNCH_CPU_KERNEL(double)\nREGISTER_KERNEL_LAUNCH_CPU_KERNEL(int32_t)\nREGISTER_KERNEL_LAUNCH_CPU_KERNEL(int64_t)\n#undef REGISTER_KERNEL_LAUNCH_CPU_KERNEL\n\n#define REGISTER_KERNEL_LAUNCH_GPU_KERNEL(dtype)                                                \\\n  REGISTER_USER_KERNEL(\"kernel_launch\")                                                         \\\n      .SetCreateFn<KernelLaunchKernel<dtype>>()                                                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                          \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value))        \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                       \\\n        return oneflow::okl::TmpBufferManager::InferTmpSize(ctx);                               \\\n      })                                                                                        \\\n      .SetInplaceProposalFn([](const user_op::InferContext&,                                    \\\n                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \\\n        return Maybe<void>::Ok();                                                               \\\n      });\n\nREGISTER_KERNEL_LAUNCH_GPU_KERNEL(float)\nREGISTER_KERNEL_LAUNCH_GPU_KERNEL(double)\nREGISTER_KERNEL_LAUNCH_GPU_KERNEL(int8_t)\nREGISTER_KERNEL_LAUNCH_GPU_KERNEL(int32_t)\nREGISTER_KERNEL_LAUNCH_GPU_KERNEL(int64_t)\n\n#if CUDA_VERSION >= 11000\nREGISTER_KERNEL_LAUNCH_GPU_KERNEL(half)\nREGISTER_KERNEL_LAUNCH_GPU_KERNEL(nv_bfloat16)\n#endif\n\n#undef REGISTER_KERNEL_LAUNCH_GPU_KERNEL\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKL/Kernel/LauncherContext.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OKL/Kernel/WrapperContext.h\"\n#include \"OneFlow/OKM/passes.h\"\n#include \"OneFlow/Passes.h\"\n#include \"llvm/Support/ErrorHandling.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/Operation.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"OneFlow/OKL/OKLOps.h\"\n#include \"OneFlow/OKL/Kernel/RegContext.h\"\n#include \"OneFlow/OKL/Kernel/ComputeContext.h\"\n#include \"OneFlow/OKL/Kernel/LauncherContext.h\"\n#include \"llvm/ADT/TypeSwitch.h\"\n\nnamespace oneflow {\nnamespace okl {\n\nLauncherContext::LauncherContext(mlir::ModuleOp module) {\n  mlir::Operation* func;\n  module->walk([&](mlir::func::FuncOp op) {\n    if (op.getSymName().startswith(mlir::okm::func_name::OKL_GRAPH_NAME)) { func = op; }\n  });\n  if (!func) { LOG(FATAL) << \"Not Found okl_func in mlir ir\"; }\n  auto& ops = func->getRegion(0).front();\n\n  for (auto& op : ops) {\n    llvm::TypeSwitch<mlir::Operation*>(&op)\n        .Case([&](mlir::okl::WrapperKernelOp elem) {\n          mlir::Operation* reg_op = nullptr;\n          for (auto& op_it : op.getRegion(0).front().getOperations()) {\n            if (op_it.getDialect()->getNamespace() == \"oneflow\") {\n              reg_op = &op_it;\n              break;\n            }\n          }\n\n          if (!reg_op) { LOG(FATAL) << \"Failed to find reg_op in okl.build_reg_context_op\"; }\n          compile_ctx_vec_.emplace_back(reg_op);\n        })\n        .Case([&](mlir::func::ReturnOp elem) {})\n        .Default([&](mlir::Operation* elem) {\n          elem->dump();\n          LOG(FATAL) << \"Fail to parse this op in okl init context\";\n        });\n  }\n}\n\nbool LauncherContext::Infer(user_op::KernelComputeContext* compute_context) {\n  // if this context has been inferred before, it won't be rebuilt later\n  if (inferred_) { return inferred_; }\n\n  for (auto& elem : compile_ctx_vec_) {\n    run_ctx_vec_.emplace_back(elem.GetRegContext()->GetOp(), compute_context);\n  }\n  inferred_ = compile_ctx_vec_.size() == run_ctx_vec_.size();\n  return inferred_;\n}\n\nvoid LauncherContext::Launch(int index) {\n  if (!inferred_) { LOG(FATAL) << \"Not infer yet when launch kernels\"; }\n  run_ctx_vec_[index].Run();\n}\n\n}  // namespace okl\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKL/Kernel/LauncherState.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"OneFlow/OKL/Conversion/Conversion.h\"\n#include \"OneFlow/OKM/Conversion/Conversion.h\"\n#include \"OneFlow/Passes.h\"\n#include \"OneFlow/OKM/passes.h\"\n#include \"llvm/ADT/StringRef.h\"\n#include \"llvm/Support/ErrorHandling.h\"\n#include \"mlir/Dialect/Arith/IR/Arith.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/IR/DialectRegistry.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OKL/OKLDialect.h\"\n#include \"mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h\"\n#include \"mlir/Parser/Parser.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMDialect.h\"\n#include \"OneFlow/OKL/Kernel/JITEngine.h\"\n#include \"OneFlow/OKL/Kernel/LauncherContext.h\"\n#include \"OneFlow/OKL/Kernel/LauncherState.h\"\n\nnamespace oneflow {\nnamespace okl {\n\nnamespace {\n\nmlir::OwningOpRef<mlir::ModuleOp> GetModule(user_op::KernelInitContext* ctx,\n                                            mlir::MLIRContext* mlir_ctx) {\n  auto mlir_assembly = ctx->Attr<std::vector<char>>(\"mlir_assembly\");\n  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::parseSourceString<mlir::ModuleOp>(\n      llvm::StringRef(mlir_assembly.data(), mlir_assembly.size() - 1), mlir_ctx);\n  if (!module) { LOG(FATAL) << \"Fail to load mlir assembly\"; }\n  // lower oneflow wrap ops into okl dialect\n  if (failed(mlir::okm::LowerWrapOpsToOKL(*module))) {\n    LOG(FATAL) << \"Fail lowering kernel launch Module to okm and okl ir\";\n  }\n  return module;\n}\n\nJITEngine GetEngine(mlir::ModuleOp module) {\n  if (failed(mlir::okl::LowerOKLComputeToLLVM(module))) {\n    LOG(FATAL) << \"Fail lowering okl compute Module to llvm ir\";\n  }\n  return JITEngine(module);\n}\n\n}  // namespace\n\nLauncherState::LauncherState(user_op::KernelInitContext* ctx)\n    : mlir_ctx_(GetRegistry()),\n      module_(GetModule(ctx, &mlir_ctx_)),\n      launcher_context_(module_->clone()),\n      engine_(GetEngine(module_->clone())) {}\n\nbool LauncherState::IsCudaGraphSupported(user_op::KernelInitContext* ctx) {\n  const auto tag_name = mlir::okl::cuda_graph_support::TAG_NAME;\n  if (const auto func = module_->lookupSymbol(mlir::okm::func_name::OKL_GRAPH_NAME)) {\n    if (const auto is_supported = func->getAttr(tag_name).dyn_cast_or_null<mlir::BoolAttr>()) {\n      return is_supported.getValue();\n    }\n  }\n  return false;\n}\n\nvoid LauncherState::DoCompute(user_op::KernelComputeContext* ctx) {\n  launcher_context_.Infer(ctx);\n  engine_.Run(mlir::okm::func_name::OKL_GRAPH_NAME, &launcher_context_);\n}\n\n}  // namespace okl\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKL/Kernel/RegContext.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/UserOpConversion.h\"\n#include \"OneFlow/UserOpReflection.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/user_op_attr.pb.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/kernel/blob_tensor_view.h\"\n#include \"oneflow/core/memory/memory_case.pb.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"OneFlow/OKL/Kernel/InferContext.h\"\n#include \"OneFlow/OKL/Kernel/RegContext.h\"\n#include \"oneflow/core/framework/user_op_kernel_registry.h\"\n#include \"oneflow/ir/oneflow-translate/include/OneFlow/MLIROneFlowTranslation.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/IR/OpDefinition.h\"\n#include \"mlir/IR/Operation.h\"\n\nnamespace oneflow {\nnamespace okl {\n\nstatic user_op::UserOpConfWrapper GetConfWrapper(mlir::Operation* op,\n                                                 bool is_mapping_size = false) {\n  OperatorConf op_conf;\n  if (mlir::failed(mlir::oneflow::user_op::ConvertUserOpAttributes(op, op_conf, is_mapping_size))) {\n    op->emitError(\"fail to convert user op attributes\");\n    exit(1);\n  }\n  auto conf_wrapper_ = user_op::UserOpConfWrapper(std::make_shared<OperatorConf>(op_conf));\n  return conf_wrapper_;\n}\n\nRegContext::RegContext(mlir::Operation* op) : op_(op), conf_wrapper_(GetConfWrapper(op, true)) {\n  const auto handle_operands_or_results =\n      [&op, this](const auto& arg_ids, const auto& get_operand_or_result, ArgVec& arg_vec) {\n        for (const auto& obj_id : ::llvm::enumerate(arg_ids)) {\n          user_op::NaiveTensorDesc tensor_desc{};\n          auto obj = get_operand_or_result(op, obj_id.index());\n          if (auto rankedTensorType = obj.getType().template dyn_cast<mlir::RankedTensorType>()) {\n            tensor_desc.set_shape(\n                Shape{rankedTensorType.getShape().begin(), rankedTensorType.getShape().end()});\n            const auto data_type =\n                mlir::oneflow::support::FromMLIRTypeToOFDataType(rankedTensorType.getElementType());\n            if (mlir::failed(data_type)) { exit(1); }\n            tensor_desc.set_data_type(data_type.value());\n            llvm::SmallVector<int64_t> strides;\n            int64_t _;\n            auto mem_type = mlir::MemRefType::get(rankedTensorType.getShape(),\n                                                  rankedTensorType.getElementType());\n            if (failed(mlir::getStridesAndOffset(mem_type, strides, _))) {\n              LOG(FATAL) << \"Fail to get stride from memory type\";\n            }\n            tensor_desc.set_stride(Stride(strides.begin(), strides.end()));\n            // TODO: set is_dynamic\n          } else {\n            LOG(FATAL) << \"Unranked tensor type not supported\";\n          }\n          CHECK(arg2tensor_desc_.emplace(obj_id.value(), tensor_desc).second) << \"duplicate key\";\n         
 arg_vec.push_back(obj_id.value());\n        }\n      };\n  handle_operands_or_results(\n      ::mlir::oneflow::user_op::ArgIds<mlir::OpTrait::AttrSizedOperandSegments>(op),\n      [](auto& x, size_t index) { return x->getOperand(index); }, inputs_);\n  handle_operands_or_results(\n      ::mlir::oneflow::user_op::ArgIds<mlir::OpTrait::AttrSizedResultSegments>(op),\n      [](auto& x, size_t index) { return x->getResult(index); }, outputs_);\n\n  auto dev_tag = mlir::OpTrait::IsOpConfCompatible<void>::getDeviceTag(op);\n  if (dev_tag == \"cpu\") {\n    device_type_ = DeviceType::kCPU;\n  } else if (dev_tag == \"cuda\") {\n    device_type_ = DeviceType::kCUDA;\n  } else {\n    LOG(FATAL) << \"Unsupported device tag: \" << dev_tag.str();\n  }\n  auto op_name = GetOp()->getName().stripDialect().str();\n  if (const auto op_type_name =\n          GetOp()->getAttr(\"op_type_name\").dyn_cast_or_null<mlir::StringAttr>()) {\n    op_name = op_type_name.str();\n  }\n\n  reg_res_ =\n      CHECK_JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(op_name, *this));\n  kernel_ = reg_res_->create_fn();\n}\n\nDeviceType RegContext::device_type() const { return device_type_; }\nconst ParallelContext& RegContext::parallel_ctx() const {\n  TODO() << \"create parallel_ctx from op in mlir\";\n  ParallelContext* parallel_ctx = nullptr;\n  return *parallel_ctx;\n}\nconst user_op::TensorDesc* RegContext::TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                                  int32_t index) const {\n  auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));\n  if (it == arg2tensor_desc_.end()) { return nullptr; }\n  return &(it->second);\n}\nconst ArgVec& RegContext::inputs() const { return inputs_; }\nconst ArgVec& RegContext::outputs() const { return outputs_; }\n\n// TODO: more information is needed\nconst user_op::UserOpConfWrapper& RegContext::user_op_conf() const { return conf_wrapper_; }\n\nconst std::shared_ptr<const user_op::AttrVal>& RegContext::Attr4Name(\n    const std::string& attr_name) const {\n  return user_op_conf().Attr4Name(attr_name);\n}\n\nconst size_t RegContext::GetTmpBufferSize() const {\n  if (reg_res_->need_temp_storage) {\n    InferContext infer_ctx(this);\n    return reg_res_->infer_tmp_size_fn(&infer_ctx);\n  }\n  return 0;\n}\n\n}  // namespace okl\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKL/Kernel/TmpBufferManager.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OKL/Kernel/TmpBufferManager.h\"\n#include \"OneFlow/OKL/Kernel/LauncherState.h\"\n#include \"OneFlow/OKL/OKLOps.h\"\n#include \"OneFlow/OKM/Conversion/Conversion.h\"\n#include \"OneFlow/OKM/passes.h\"\n#include \"OneFlow/Passes.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/MLIRContext.h\"\n#include \"mlir/Parser/Parser.h\"\n#include \"llvm/Support/Casting.h\"\n\nnamespace oneflow {\nnamespace okl {\n\nsize_t TmpBufferManager::InferTmpSize(user_op::InferContext* ctx) {\n  using namespace user_op;\n  mlir::MLIRContext mlir_ctx(GetRegistry());\n\n  auto mlir_assembly = ctx->Attr<std::vector<char>>(\"mlir_assembly\");\n  mlir::OwningOpRef<mlir::ModuleOp> module = mlir::parseSourceString<mlir::ModuleOp>(\n      llvm::StringRef(mlir_assembly.data(), mlir_assembly.size() - 1), &mlir_ctx);\n  if (!module) { LOG(FATAL) << \"Fail to load mlir assembly\"; }\n  if (failed(mlir::okm::LowerWrapOpsToOKL(*module))) {\n    LOG(ERROR) << \"Fail lowering kernel launch Module to okl ir\";\n    exit(1);\n  }\n\n  size_t pool_size = 0;\n  module->walk([&](mlir::func::FuncOp op) {\n    if (op.getSymName().startswith(mlir::okm::func_name::OKL_GRAPH_NAME)) {\n      if (auto pool_size_attr =\n              op->getAttrOfType<mlir::IntegerAttr>(mlir::okm::func_name::OKL_POOL_SIZE_TAG)) {\n        pool_size = pool_size_attr.getInt();\n      }\n    }\n  });\n  return pool_size;\n}\n\n}  // namespace okl\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKL/OKLDialect.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OKL/OKLDialect.h\"\n#include \"OneFlow/OKL/OKLOps.h\"\n#include \"OneFlow/OKL/OKLTypes.h\"\n#include \"OneFlow/OKL/OKLAttributes.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/Passes.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"OneFlow/OKLDialect.cpp.inc\"\n#include \"mlir/IR/Dialect.h\"\n#include \"mlir/IR/TypeRange.h\"\n#include \"mlir/IR/Dialect.h\"\n#include \"llvm/ADT/SmallVector.h\"\n#include \"llvm/ADT/TypeSwitch.h\"\n#include \"mlir/IR/Attributes.h\"\n#include \"mlir/IR/DialectImplementation.h\"\n#include \"mlir/Support/LogicalResult.h\"\n\n#define GET_ATTRDEF_CLASSES\n#include \"OneFlow/OKLAttributes.cpp.inc\"\nnamespace mlir {\n\nnamespace okl {\n\nvoid OKLDialect::initialize() {\n  addOperations<\n#define GET_OP_LIST\n#include \"OneFlow/OKLOps.cpp.inc\"\n      >();\n  addTypes<\n#define GET_TYPEDEF_LIST\n#include \"OneFlow/OKLTypes.cpp.inc\"\n      >();\n  addAttributes<\n#define GET_ATTRDEF_LIST\n#include \"OneFlow/OKLAttributes.cpp.inc\"\n      >();\n}\n\n}  // namespace okl\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKL/OKLOps.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OKL/OKLDialect.h\"\n#include \"OneFlow/OKL/OKLTypes.h\"\n#include \"OneFlow/OKL/OKLOps.h\"\n#include \"OneFlow/OKL/OKLAttributes.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/Dialect.h\"\n#include \"mlir/IR/TypeRange.h\"\n#include \"mlir/IR/Dialect.h\"\n\n#include \"OneFlow/OKLEnums.cpp.inc\"\n\n#define GET_OP_CLASSES\n#include \"OneFlow/OKLOps.cpp.inc\"\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKL/OKLTypes.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OKL/OKLDialect.h\"\n#include \"OneFlow/OKL/OKLTypes.h\"\n#include \"mlir/IR/DialectImplementation.h\"\n#include \"llvm/ADT/TypeSwitch.h\"\n\n#define GET_TYPEDEF_CLASSES\n#include \"OneFlow/OKLTypes.cpp.inc\"\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKL/README-OriginVersion.md",
    "content": "# 初版OKL设计文档\noneflow kernel launch dialect\n\n将oneflow kernel引入mlir执行。\n\n## 编译期\n\n### 1. FromGraphToMLIR\n - GraphToJob\n - JobToOneFlowDialect\n\n### 2. OneFlowDialectToOKLDialect\n通过三个Pass将OneFlow转换成okl的ir形式。\n- extract-kernel-launch-tensor\n- trim-return-to-void\n- lower-to-okl\n``` mlir\n module {\n  func.func @wrap0(%arg0: tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) attributes {llvm.emit_c_interface} {\n    %0 = \"oneflow.relu\"(%arg0) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"relu-0\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n    %1 = \"oneflow.tanh\"(%0) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"tanh-1\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n    return %0, %1 : tensor<2xf32>, tensor<2xf32>\n  }\n}\n```\n-extract-kernel-launch-tensor\n\n将tensor的输入流转换为ctx中获取\n``` mlir\nmodule {\n  func.func @wrap0(%arg0: !okl.launcher_ctx) -> (tensor<2xf32>, tensor<2xf32>) {\n    %0 = \"okl.get_tensor_from_arg\"(%arg0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32>\n    %1 = \"oneflow.relu\"(%0) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"relu-0\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n    %2 = \"oneflow.tanh\"(%1) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"tanh-1\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n    %3 = \"okl.get_tensor_as_ret\"(%arg0, %1) {index = 0 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n    %4 = \"okl.get_tensor_as_ret\"(%arg0, %2) {index = 1 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n    return %3, %4 : tensor<2xf32>, tensor<2xf32>\n  }\n}\n```\n-trim-return-to-void\n\n将tensor的输出流删除掉\n```mlir\nmodule {\n  func.func @wrap0(%arg0: !okl.launcher_ctx) {\n    %0 = \"okl.get_tensor_from_arg\"(%arg0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32>\n    %1 = \"oneflow.relu\"(%0) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"relu-0\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n    %2 = \"oneflow.tanh\"(%1) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"tanh-1\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n    %3 = \"okl.get_tensor_as_ret\"(%arg0, %1) {index = 0 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n    %4 = \"okl.get_tensor_as_ret\"(%arg0, %2) {index = 1 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n    return\n  }\n}\n```\n-lower-to-okl\n\n将oneflow kernel op用okl wrapper_kernel封装起来，并通过okl op编译推导对应tensor流的信息。\n```mlir\nmodule {\n  func.func @okl_func(%arg0: !okl.launcher_ctx) {\n    \"okl.wrapper_kernel\"() ({\n      %0 = \"okl.get_tensor_from_arg\"(%arg0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32>\n      %1 = \"oneflow.relu\"(%0) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"relu-0\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n      %2 = \"okl.get_tensor_as_ret\"(%arg0, %1) {index = 0 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n      okl.return\n    }) {index = 0 : i32} : () -> ()\n    \"okl.wrapper_kernel\"() ({\n      %0 = \"okl.get_tensor_from_ret\"(%arg0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32>\n      %1 = \"oneflow.tanh\"(%0) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"tanh-1\", scope_symbol_id = 
### 3. OKLDialectToLLVMDialect\nFour passes convert the OKL IR into the LLVM IR form that the runtime consumes:\n- lower-launcher-to-llvm-ptr\n- lower-okl-to-llvm-call\n- reconcile-unrealized-casts\n- convert-func-to-llvm\n\n-lower-launcher-to-llvm-ptr\n\nConverts the ctx into an llvm.ptr, so that the ctx is passed around as an llvm.ptr.\n```mlir\nmodule {\n  func.func @okl_func(%arg0: !llvm.ptr<i8>) attributes {llvm.emit_c_interface} {\n    %0 = builtin.unrealized_conversion_cast %arg0 : !llvm.ptr<i8> to !okl.launcher_ctx\n    \"okl.wrapper_kernel\"() ({\n      %1 = \"okl.get_tensor_from_arg\"(%0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32>\n      %2 = \"oneflow.relu\"(%1) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"relu-0\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n      %3 = \"okl.get_tensor_as_ret\"(%0, %2) {index = 0 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n      okl.return\n    }) {index = 0 : i32} : () -> ()\n    \"okl.wrapper_kernel\"() ({\n      %1 = \"okl.get_tensor_from_ret\"(%0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32>\n      %2 = \"oneflow.tanh\"(%1) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"tanh-1\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n      %3 = \"okl.get_tensor_as_ret\"(%0, %2) {index = 1 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n      okl.return\n    }) {index = 1 : i32} : () -> ()\n    return\n  }\n}\n```\n-lower-okl-to-llvm-call\n\nConverts the okl wrapper_kernel ops into llvm call operations.\n```mlir\nmodule {\n  llvm.func @okl_llvm_func(!llvm.ptr<i8>, i64) attributes {llvm.emit_c_interface}\n  func.func @okl_func(%arg0: !llvm.ptr<i8>) attributes {llvm.emit_c_interface} {\n    %0 = builtin.unrealized_conversion_cast %arg0 : !llvm.ptr<i8> to !okl.launcher_ctx\n    %1 = llvm.mlir.constant(0 : index) : i64\n    llvm.call @okl_llvm_func(%arg0, %1) : (!llvm.ptr<i8>, i64) -> ()\n    %2 = llvm.mlir.constant(1 : index) : i64\n    llvm.call @okl_llvm_func(%arg0, %2) : (!llvm.ptr<i8>, i64) -> ()\n    return\n  }\n}\n```\n-reconcile-unrealized-casts\n-convert-func-to-llvm\n\nTogether these convert the module into LLVM IR that can be run directly:\n```mlir\nmodule attributes {llvm.data_layout = \"\"} {\n  llvm.func @okl_llvm_func(!llvm.ptr<i8>, i64) attributes {llvm.emit_c_interface}\n  llvm.func @okl_func(%arg0: !llvm.ptr<i8>) attributes {llvm.emit_c_interface} {\n    %0 = llvm.mlir.constant(0 : index) : i64\n    llvm.call @okl_llvm_func(%arg0, %0) : (!llvm.ptr<i8>, i64) -> ()\n    %1 = llvm.mlir.constant(1 : index) : i64\n    llvm.call @okl_llvm_func(%arg0, %1) : (!llvm.ptr<i8>, i64) -> ()\n    llvm.return\n  }\n  llvm.func @_mlir_ciface_okl_func(%arg0: !llvm.ptr<i8>) attributes {llvm.emit_c_interface} {\n    llvm.call @okl_func(%arg0) : (!llvm.ptr<i8>) -> ()\n    llvm.return\n  }\n}\n```\n
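\nThe lowered module is what the JIT executes; a minimal sketch of driving it with mlir::ExecutionEngine (the repo wraps this in JITEngine, and the launcher pointer below stands in for the LauncherContext* the runtime passes in):\n```cpp\n// Sketch only: JIT the lowered module and call @okl_func through the\n// packed C interface generated by llvm.emit_c_interface.\nmlir::ExecutionEngineOptions options;\nauto engine = mlir::ExecutionEngine::create(module, options);\nCHECK(!!engine) << llvm::toString(engine.takeError());\nvoid* launcher = nullptr;  // assumption: a LauncherContext* provided by the runtime\nllvm::SmallVector<void*> args{&launcher};\nif (auto error = (*engine)->invokePacked(\"okl_func\", args)) {\n  LOG(FATAL) << llvm::toString(std::move(error));\n}\n```\n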
\n## Runtime\n\nThe OKL dialect IR is not only the output of the final compile-time stage; it is also the input to runtime initialization, where it is used to build the various runtime ctx objects that prepare everything the compute phase needs.\n\nAn OKL kernel contains an entire subgraph, so it has to manage the ctx resources of the subgraph's ordered sub-ops. These are created during initialization by LauncherState, which holds a LauncherContext that uniformly manages the resources of all sub-ops in the subgraph.\n\nLauncherContext holds an ordered list of CompileTimeWrapperContext objects, one for each sub-op before Infer, and a list of RunTimeWrapperContext objects, one for each sub-op after Infer.\n\nThe resources held by these two kinds of ctx:\n```\nclass CompileTimeWrapperContext {\n  std::shared_ptr<const RegContext> reg_ctx_;\n};\nclass RunTimeWrapperContext : public CompileTimeWrapperContext {\n  std::shared_ptr<ComputeContext> compute_ctx_;\n  std::shared_ptr<InitContext> init_ctx_;\n  std::shared_ptr<user_op::OpKernelState> kernel_state_;\n  std::shared_ptr<user_op::OpKernelCache> kernel_cache_;\n};\n```\nCompileTimeWrapperContext is mainly the reg_ctx, the required input for Infer.\n\nRunTimeWrapperContext holds all the resources a sub-op needs for runtime computation, chiefly the compute_ctx together with the kernel state and cache."
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKM/Conversion/Conversion.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OKL/Conversion/Conversion.h\"\n#include \"OneFlow/OKM/passes.h\"\n#include \"OneFlow/OneFlowUtils.h\"\n\nnamespace mlir {\nnamespace okm {\n\nLogicalResult LowerWrapOpsToOKL(ModuleOp module) {\n  PassManager pm(module->getContext());\n  pm.addPass(createExtractOKMTensorPass());\n  pm.addPass(createWrapOKMKernelPass());\n  pm.addPass(createOptOKMMemrefPass());\n  pm.addPass(createConvertOKMToOKLPass());\n  pm.addPass(okl::createTagCudaGraphSupportPass());\n  oneflow::CheckEnableIRPrinting(pm);\n  return pm.run(module);\n}\n\n}  // namespace okm\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKM/OKMDialect.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OKM/OKMDialect.h\"\n#include \"OneFlow/OKM/OKMOps.h\"\n#include \"OneFlow/OKM/OKMAttributes.h\"\n\n#include \"OneFlow/OKM/passes.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/Passes.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"OneFlow/OKMDialect.cpp.inc\"\n#include \"mlir/IR/Dialect.h\"\n#include \"mlir/IR/TypeRange.h\"\n#include \"mlir/IR/Dialect.h\"\n#include \"llvm/ADT/SmallVector.h\"\n#include \"llvm/ADT/TypeSwitch.h\"\n#include \"mlir/IR/Attributes.h\"\n#include \"mlir/IR/DialectImplementation.h\"\n#include \"mlir/Support/LogicalResult.h\"\n\n#define GET_ATTRDEF_CLASSES\n#include \"OneFlow/OKMAttributes.cpp.inc\"\n\n#define GET_OP_CLASSES\n#include \"OneFlow/OKMOps.cpp.inc\"\n\nnamespace mlir {\n\nnamespace okm {\n\nvoid OKMDialect::initialize() {\n  addOperations<\n#define GET_OP_LIST\n#include \"OneFlow/OKMOps.cpp.inc\"\n      >();\n  addAttributes<\n#define GET_ATTRDEF_LIST\n#include \"OneFlow/OKMAttributes.cpp.inc\"\n      >();\n}\n\n}  // namespace okm\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OKM/passes.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/intra_job_mem_sharing_util.h\"\n#include \"OneFlow/OKL/OKLDialect.h\"\n#include \"OneFlow/OKL/OKLOps.h\"\n#include \"OneFlow/OKL/Kernel/RegContext.h\"\n#include \"OneFlow/OKM/OKMDialect.h\"\n#include \"OneFlow/OKM/OKMOps.h\"\n#include \"OneFlow/OKM/passes.h\"\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"llvm/Support/Casting.h\"\n#include \"llvm/Support/raw_ostream.h\"\n#include \"mlir-c/BuiltinTypes.h\"\n#include \"mlir/Dialect/Arith/IR/Arith.h\"\n#include \"mlir/Dialect/Bufferization/IR/Bufferization.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/Dialect/MemRef/IR/MemRef.h\"\n#include \"mlir/IR/IRMapping.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/ImplicitLocOpBuilder.h\"\n#include \"mlir/IR/Operation.h\"\n#include \"mlir/IR/PatternMatch.h\"\n#include \"mlir/IR/Region.h\"\n#include \"mlir/IR/ValueRange.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n#include \"llvm/ADT/SmallVector.h\"\n#include \"llvm/ADT/TypeSwitch.h\"\n\nnamespace mlir {\nnamespace okm {\n\nnamespace func_name {\n\nconst std::string GRAPH_NAME = \"_mlir_oneflow_subgraph\";\nconst std::string MEM_GRAPH_NAME = \"okm_subgraph\";\nconst std::string WRAP_GRAPH_NAME = \"okm_wrap_subgraph\";\nconst std::string OPT_GRAPH_NAME = \"okm_alloc_subgraph\";\nconst std::string OKL_GRAPH_NAME = \"okl_subgraph\";\nconst std::string OKL_POOL_SIZE_TAG = \"pool_size\";\n\n}  // namespace func_name\n\nstruct ExtractOKMTensorPattern : public mlir::OpRewritePattern<func::FuncOp> {\n  static void ExtractArgTensors(func::FuncOp op, mlir::PatternRewriter& rewriter) {\n    auto& body = op.getBody();\n    OpBuilder::InsertionGuard guard(rewriter);\n    rewriter.setInsertionPointToStart(&body.front());\n\n    for (const auto& arg : llvm::enumerate(op.getBody().getArguments())) {\n      auto tensor =\n          rewriter.create<okm::ArgToTensorOp>(op->getLoc(), arg.value().getType(), arg.index());\n      arg.value().replaceAllUsesWith(tensor);\n    }\n  }\n\n  static void ExtractRetTensors(func::FuncOp op, mlir::PatternRewriter& rewriter) {\n    auto& return_op = op.getBody().front().back();\n    OpBuilder::InsertionGuard guard(rewriter);\n    rewriter.setInsertionPoint(&return_op);\n\n    llvm::SmallVector<Value> returns;\n    for (const auto& ret_val : llvm::enumerate(return_op.getOperands())) {\n      auto new_ret = rewriter.create<okm::TensorToRetOp>(op->getLoc(), ret_val.value().getType(),\n                                                         ret_val.value(), ret_val.index());\n      returns.push_back(new_ret);\n    }\n\n    rewriter.replaceOpWithNewOp<func::ReturnOp>(&return_op, ValueRange{returns});\n  }\n\n  explicit ExtractOKMTensorPattern(mlir::MLIRContext* context)\n      : OpRewritePattern<func::FuncOp>(context, 
/*benefit=*/0) {}\n  mlir::LogicalResult matchAndRewrite(func::FuncOp op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    const auto sym_name = op.getSymName();\n    if (sym_name.startswith(func_name::GRAPH_NAME)) {\n      // rename function\n      const auto index = sym_name.substr(func_name::GRAPH_NAME.size());\n      const auto rename = func_name::MEM_GRAPH_NAME + index;\n      op.setSymNameAttr(rewriter.getStringAttr(rename));\n      // extract tensors\n      ExtractArgTensors(op, rewriter);\n      ExtractRetTensors(op, rewriter);\n      return success();\n    }\n    return failure();\n  }\n};\n\nclass ExtractOKMTensorPass : public ExtractOKMTensorPassBase<ExtractOKMTensorPass> {\n  void getDependentDialects(DialectRegistry& registry) const override {\n    registry.insert<oneflow::OneFlowDialect>();\n    registry.insert<OKMDialect>();\n  }\n\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    patterns.add<ExtractOKMTensorPattern>(patterns.getContext());\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\nstd::unique_ptr<Pass> createExtractOKMTensorPass() {\n  return std::make_unique<ExtractOKMTensorPass>();\n}\n\nstruct WrapOKMKernelPattern : public mlir::OpRewritePattern<func::FuncOp> {\n  static Value AllocOrMapOutTensor(Value res, mlir::PatternRewriter& rewriter) {\n    if (auto type = res.getType().dyn_cast_or_null<TensorType>()) {\n      int ret_index = -1;\n      for (auto use : res.getUsers()) {\n        if (auto to_ret = llvm::dyn_cast_or_null<TensorToRetOp>(use)) {\n          ret_index = to_ret.getIndex();\n          break;\n        }\n      }\n      auto mem_type = MemRefType::get(type.getShape(), type.getElementType());\n      auto out =\n          (ret_index == -1)\n              ? 
rewriter.create<PlanMemrefOp>(rewriter.getUnknownLoc(), mem_type)\n              : rewriter.create<RetToMemrefOp>(rewriter.getUnknownLoc(), mem_type, ret_index);\n      return out->getResult(0);\n    }\n    return nullptr;\n  }\n\n  static void CreateWrapOp(Operation* op, mlir::PatternRewriter& rewriter, IRMapping& mapper,\n                           const llvm::SmallVector<Type>& mem_outs_types,\n                           const llvm::SmallVector<Value>& map_ins) {\n    auto wrapper_op = rewriter.create<WrapperOp>(op->getLoc(), mem_outs_types, ValueRange(map_ins));\n    for (auto elem : llvm::zip(op->getResults(), wrapper_op->getResults())) {\n      mapper.map(std::get<0>(elem), std::get<1>(elem));\n    }\n    auto& wrap_block = wrapper_op.getBody().emplaceBlock();\n    OpBuilder::InsertionGuard insertGuard(rewriter);\n    rewriter.setInsertionPointToStart(&wrap_block);\n    ImplicitLocOpBuilder nb(rewriter.getUnknownLoc(), rewriter);\n    IRMapping wrap_mapper;\n    for (auto in : llvm::zip(op->getOperands(), wrapper_op.getOperands())) {\n      auto to_tensor = rewriter.create<mlir::bufferization::ToTensorOp>(rewriter.getUnknownLoc(),\n                                                                        std::get<1>(in));\n      wrap_mapper.map(std::get<0>(in), to_tensor);\n    }\n    auto new_op = nb.clone(*op, wrap_mapper);\n    SmallVector<Value> outs;\n    for (auto out : new_op->getResults()) {\n      if (auto type = out.getType().dyn_cast_or_null<TensorType>()) {\n        auto mem_type = MemRefType::get(type.getShape(), type.getElementType());\n        auto to_memref = rewriter.create<mlir::bufferization::ToMemrefOp>(rewriter.getUnknownLoc(),\n                                                                          mem_type, out);\n        outs.push_back(to_memref);\n      } else {\n        llvm::errs() << \"Fail to identify op type in wrap okm kernel\";\n        exit(1);\n      }\n    }\n    rewriter.create<ReturnOp>(rewriter.getUnknownLoc(), ValueRange(outs));\n  }\n\n  static void HandleOneFlowOp(Operation* op, mlir::PatternRewriter& rewriter, IRMapping& mapper) {\n    // record outs type\n    llvm::SmallVector<Type> mem_outs_types;\n    for (auto it : op->getResultTypes()) {\n      if (auto type = it.dyn_cast_or_null<TensorType>()) {\n        auto mem_type = MemRefType::get(type.getShape(), type.getElementType());\n        mem_outs_types.push_back(mem_type);\n      } else {\n        llvm::errs() << \"Fail to identify op type in wrap okm kernel\";\n        exit(1);\n      }\n    }\n    llvm::SmallVector<Value> map_ins;\n    // record ins\n    for (auto in : op->getOperands()) {\n      auto mirror = mapper.lookup(in);\n      if (auto wrap_op = llvm::dyn_cast_or_null<okm::WrapperOp>(mirror.getDefiningOp())) {\n        int idx = 0;\n        for (auto res : wrap_op->getResults()) {\n          if (mirror == res) { break; }\n          ++idx;\n        }\n        Operation* oneflow_op = nullptr;\n        auto& ops = wrap_op.getBody().front();\n        for (auto& op : ops) {\n          if (oneflow::OneFlowDialect::getDialectNamespace().equals(\n                  op.getDialect()->getNamespace())) {\n            oneflow_op = &op;\n          }\n        }\n        if (!oneflow_op) { LOG(FATAL) << \"Fail to find oneflow op in wrap op\"; }\n        mirror =\n            wrap_op->getOperand(oneflow_op->getNumOperands() + idx).getDefiningOp()->getResult(0);\n      }\n      map_ins.push_back(mirror);\n    }\n    // append alloc outs after ins\n    for (auto out : op->getResults()) {\n      if 
(auto new_out = AllocOrMapOutTensor(out, rewriter)) {\n        map_ins.push_back(new_out);\n      } else {\n        llvm::errs() << \"Fail to alloc or map op in wrap okm kernel\";\n        exit(1);\n      }\n    }\n    if (int64_t buffer_size = ::oneflow::okl::RegContext(op).GetTmpBufferSize()) {\n      auto type = MemRefType::get({buffer_size}, rewriter.getI8Type());\n      auto tmp_buffer = rewriter.create<PlanMemrefOp>(rewriter.getUnknownLoc(), type)->getResult(0);\n      map_ins.push_back(tmp_buffer);\n    }\n\n    CreateWrapOp(op, rewriter, mapper, mem_outs_types, map_ins);\n  }\n\n  static func::FuncOp WrapOps(func::FuncOp func, mlir::PatternRewriter& rewriter,\n                              const std::string& func_name) {\n    OpBuilder::InsertionGuard insertGuard(rewriter);\n    auto func_type = rewriter.getFunctionType({}, {});\n    rewriter.setInsertionPoint(func);\n    auto wrap_func = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), func_name, func_type);\n    auto& block = wrap_func.getBody().emplaceBlock();\n    rewriter.setInsertionPointToStart(&block);\n\n    auto& ops = func.getBody().front();\n    IRMapping mapper;\n    for (auto& op : ops) {\n      llvm::TypeSwitch<Operation*>(&op)\n          .Case<ArgToTensorOp>([&](ArgToTensorOp op) {\n            auto mem_type = MemRefType::get(op.getType().getShape(), op.getType().getElementType());\n            auto mem_op = rewriter.create<ArgToMemrefOp>(op->getLoc(), mem_type, op.getIndex());\n            mapper.map(Value(op), mem_op);\n          })\n          .Case<TensorToRetOp>([&](TensorToRetOp op) {\n            auto mem_type = MemRefType::get(op.getType().getShape(), op.getType().getElementType());\n            rewriter.create<MemrefToRetOp>(op->getLoc(), mem_type, mapper.lookup(op.getTensor()),\n                                           op.getIndex());\n          })\n          .Default([&](Operation* op) {\n            if (oneflow::OneFlowDialect::getDialectNamespace().equals(\n                    op->getDialect()->getNamespace())) {\n              HandleOneFlowOp(op, rewriter, mapper);\n            }\n          });\n    }\n    rewriter.create<func::ReturnOp>(rewriter.getUnknownLoc());\n    return wrap_func;\n  }\n\n  explicit WrapOKMKernelPattern(mlir::MLIRContext* context)\n      : OpRewritePattern<func::FuncOp>(context, /*benefit=*/0) {}\n  mlir::LogicalResult matchAndRewrite(func::FuncOp op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    const auto sym_name = op.getSymName();\n    if (sym_name.startswith(func_name::MEM_GRAPH_NAME)) {\n      // rename function\n      const auto index = sym_name.substr(func_name::MEM_GRAPH_NAME.size()).str();\n      const std::string rename = func_name::WRAP_GRAPH_NAME + index;\n      // wrap kernels\n      WrapOps(op, rewriter, rename);\n      rewriter.eraseOp(op);\n    }\n    return success();\n  }\n};\n\nclass WrapOKMKernelPass : public WrapOKMKernelPassBase<WrapOKMKernelPass> {\n  void getDependentDialects(DialectRegistry& registry) const override {\n    registry.insert<oneflow::OneFlowDialect>();\n    registry.insert<OKMDialect>();\n    registry.insert<bufferization::BufferizationDialect>();\n  }\n\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    patterns.add<WrapOKMKernelPattern>(patterns.getContext());\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\nstd::unique_ptr<Pass> createWrapOKMKernelPass() { return 
std::make_unique<WrapOKMKernelPass>(); }\n\nnamespace {\nvoid MemSizeFirst(func::FuncOp func, mlir::PatternRewriter& rewriter) {\n  OpBuilder::InsertionGuard insertGuard(rewriter);\n  auto& ops = func.getBody().front();\n\n  rewriter.setInsertionPointToStart(&ops);\n  auto mem_type = MemRefType::get({0}, rewriter.getI8Type());\n  auto global_buffer = rewriter.create<memref::AllocOp>(rewriter.getUnknownLoc(), mem_type);\n\n  ::oneflow::HashMap<Operation*, int32_t> op2lifetime;\n  int32_t idx = 0;\n  for (auto& op : ops) {\n    if (auto wrap_op = llvm::dyn_cast_or_null<WrapperOp>(op)) { op2lifetime[&op] = idx++; }\n  }\n\n  ::oneflow::HashMap<Operation*, size_t> val2size;\n  ::oneflow::HashMap<Operation*, std::pair<int32_t, int32_t>> val2lifetime;\n  for (auto& op : ops) {\n    if (auto alloc_op = llvm::dyn_cast_or_null<PlanMemrefOp>(op)) {\n      // get size\n      MemRefType type = op.getResult(0).getType().dyn_cast<MemRefType>();\n      int64_t size = type.getElementTypeBitWidth() / 8;\n      for (int64_t i : type.getShape()) { size *= i; }\n      int align = ::oneflow::kBlobBodyAlignSize;\n      size = (size / align + ((size % align) != 0)) * align;\n      val2size[&op] = size;\n\n      // get life time\n      int min = INT_MAX, max = 0;\n      for (auto use : op.getUsers()) {\n        if (auto wrap_op = llvm::dyn_cast_or_null<WrapperOp>(use)) {\n          auto op_val = op2lifetime[use];\n          min = std::min(min, op_val);\n          max = std::max(max, op_val + 1);\n        }\n      }\n      val2lifetime[&op] = {min, max};\n    }\n  }\n\n  ::oneflow::MemBlockResultInfo<Operation*> res;\n  ::oneflow::MemReusedMemSizeFirstAlgo(false, val2lifetime, val2size, &res);\n\n  auto val2offset = res.regst_desc2offset;\n  for (auto [op, offset] : val2offset) {\n    if (auto plan_op = llvm::dyn_cast_or_null<PlanMemrefOp>(op)) {\n      rewriter.setInsertionPoint(plan_op);\n      auto off_set = rewriter.create<arith::ConstantIndexOp>(rewriter.getUnknownLoc(), offset);\n      auto type = plan_op->getResult(0).getType();\n      rewriter.replaceOpWithNewOp<memref::ViewOp>(plan_op, type, global_buffer, off_set,\n                                                  ValueRange{});\n    }\n  }\n\n  mem_type = MemRefType::get({static_cast<long>(res.mem_block_size)}, rewriter.getI8Type());\n  rewriter.setInsertionPoint(global_buffer);\n  rewriter.replaceOpWithNewOp<AllocMemrefOp>(global_buffer, mem_type);\n}\n}  // namespace\nstruct OptOKMMemrefPattern : public mlir::OpRewritePattern<func::FuncOp> {\n  explicit OptOKMMemrefPattern(mlir::MLIRContext* context)\n      : OpRewritePattern<func::FuncOp>(context, /*benefit=*/0) {}\n  mlir::LogicalResult matchAndRewrite(func::FuncOp op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    const auto sym_name = op.getSymName();\n    if (sym_name.startswith(func_name::WRAP_GRAPH_NAME)) {\n      const auto index = sym_name.substr(func_name::WRAP_GRAPH_NAME.size()).str();\n      const std::string rename = func_name::OPT_GRAPH_NAME + index;\n      op.setSymNameAttr(rewriter.getStringAttr(rename));\n      MemSizeFirst(op, rewriter);\n    }\n    return success();\n  }\n};\n\nclass OptOKMMemrefPass : public OptOKMMemrefPassBase<OptOKMMemrefPass> {\n  void getDependentDialects(DialectRegistry& registry) const override {\n    registry.insert<oneflow::OneFlowDialect>();\n    registry.insert<OKMDialect>();\n    registry.insert<bufferization::BufferizationDialect>();\n    registry.insert<arith::ArithDialect>();\n  }\n\n  void 
runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    patterns.add<OptOKMMemrefPattern>(patterns.getContext());\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\nstd::unique_ptr<Pass> createOptOKMMemrefPass() { return std::make_unique<OptOKMMemrefPass>(); }\n\nstruct ConvertOKMToOKLPattern : public mlir::OpRewritePattern<func::FuncOp> {\n  static void ConvertOpToOKL(mlir::Operation& it, func::FuncOp& wrap_func, WrapperOp wrap_mem_op,\n                             mlir::PatternRewriter& rewriter, int& index) {\n    auto wrap_okl_op = rewriter.create<okl::WrapperKernelOp>(rewriter.getUnknownLoc(), index++);\n    wrap_okl_op.getBody().emplaceBlock();\n    OpBuilder::InsertionGuard insertGuard(rewriter);\n    rewriter.setInsertionPointToStart(&wrap_okl_op.getBody().front());\n\n    IRMapping mapper;\n    auto ins_num = it.getNumOperands();\n    auto outs_num = it.getNumResults() + ins_num;\n    for (int idx = 0; idx < ins_num; ++idx) {\n      auto val = llvm::TypeSwitch<Operation*, Value>(wrap_mem_op->getOperand(idx).getDefiningOp())\n                     .Case<ArgToMemrefOp>([&](ArgToMemrefOp op) {\n                       return rewriter.create<okl::GetTensorFromArgOp>(\n                           rewriter.getUnknownLoc(),\n                           memref::getTensorTypeFromMemRefType(op->getResult(0).getType()),\n                           wrap_func.getArgument(0), op.getIndex());\n                     })\n                     .Case<RetToMemrefOp>([&](RetToMemrefOp op) {\n                       return rewriter.create<okl::GetTensorFromRetOp>(\n                           rewriter.getUnknownLoc(),\n                           memref::getTensorTypeFromMemRefType(op->getResult(0).getType()),\n                           wrap_func.getArgument(0), op.getIndex());\n                     })\n                     .Case<memref::ViewOp>([&](memref::ViewOp op) {\n                       auto offset = rewriter.getI64IntegerAttr(\n                           llvm::dyn_cast<arith::ConstantIndexOp>(op.getByteShift().getDefiningOp())\n                               .value());\n                       return rewriter.create<okl::PoolToTensorOp>(\n                           rewriter.getUnknownLoc(),\n                           memref::getTensorTypeFromMemRefType(op->getResult(0).getType()),\n                           wrap_func.getArgument(0), offset);\n                     })\n                     .Default([&](Operation*) { return Value{}; });\n      mapper.map(it.getOperand(idx), val);\n    }\n    ImplicitLocOpBuilder new_block(rewriter.getUnknownLoc(), rewriter);\n    auto new_op = new_block.clone(it, mapper);\n    for (int idx = ins_num; idx < outs_num; ++idx) {\n      llvm::TypeSwitch<Operation*>(wrap_mem_op->getOperand(idx).getDefiningOp())\n          .Case<RetToMemrefOp>([&](RetToMemrefOp op) {\n            return rewriter.create<okl::GetTensorAsRetOp>(\n                rewriter.getUnknownLoc(),\n                memref::getTensorTypeFromMemRefType(op->getResult(0).getType()),\n                wrap_func.getArgument(0), new_op->getResult(idx - ins_num), op.getIndex());\n          })\n          .Case<memref::ViewOp>([&](memref::ViewOp op) {\n            auto offset = rewriter.getI64IntegerAttr(\n                llvm::dyn_cast<arith::ConstantIndexOp>(op.getByteShift().getDefiningOp()).value());\n            return rewriter.create<okl::TensorToPoolOp>(\n                rewriter.getUnknownLoc(),\n                
memref::getTensorTypeFromMemRefType(op->getResult(0).getType()),\n                wrap_func.getArgument(0), new_op->getResult(idx - ins_num), offset);\n          })\n          .Default([&](Operation*) { return Value{}; });\n    }\n    if (outs_num + 1 == wrap_mem_op->getNumOperands()) {\n      auto op = llvm::dyn_cast<memref::ViewOp>(wrap_mem_op->getOperand(outs_num).getDefiningOp());\n\n      auto offset = rewriter.getI64IntegerAttr(\n          llvm::dyn_cast<arith::ConstantIndexOp>(op.getByteShift().getDefiningOp()).value());\n      rewriter.create<okl::PoolToBufferOp>(\n          rewriter.getUnknownLoc(), memref::getTensorTypeFromMemRefType(op->getResult(0).getType()),\n          wrap_func.getArgument(0), offset);\n    }\n\n    rewriter.create<okl::ReturnOp>(rewriter.getUnknownLoc());\n  }\n\n  static func::FuncOp BuildOKLGraph(func::FuncOp func, mlir::PatternRewriter& rewriter,\n                                    const std::string& func_name) {\n    OpBuilder::InsertionGuard insertGuard(rewriter);\n    rewriter.setInsertionPoint(func);\n\n    auto func_type = rewriter.getFunctionType(\n        {mlir::okl::LauncherContextType::get(rewriter.getContext())}, TypeRange{});\n    auto wrap_func = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), func_name, func_type);\n    auto& block = wrap_func.getBody().emplaceBlock();\n    wrap_func.getBody().addArguments(mlir::okl::LauncherContextType::get(rewriter.getContext()),\n                                     rewriter.getUnknownLoc());\n    rewriter.setInsertionPointToStart(&block);\n\n    llvm::SmallVector<Operation*> raw_ops;\n    for (auto& op : func.getBody().front()) { raw_ops.push_back(&op); }\n    auto index = 0;\n    for (auto op : raw_ops) {\n      if (auto alloc_op = llvm::dyn_cast_or_null<okm::AllocMemrefOp>(op)) {\n        if (auto mem_type = alloc_op->getResult(0).getType().dyn_cast_or_null<MemRefType>()) {\n          wrap_func->setAttr(func_name::OKL_POOL_SIZE_TAG,\n                             rewriter.getI64IntegerAttr(mem_type.getShape().front()));\n        }\n      }\n      if (auto wrap_mem_op = llvm::dyn_cast_or_null<WrapperOp>(op)) {\n        auto& wrap_ops = wrap_mem_op.getBody().front();\n        for (auto& it : wrap_ops) {\n          if (oneflow::OneFlowDialect::getDialectNamespace().equals(\n                  it.getDialect()->getNamespace())) {\n            ConvertOpToOKL(it, wrap_func, wrap_mem_op, rewriter, index);\n          }\n        }\n      }\n    }\n    rewriter.setInsertionPointToEnd(&block);\n    rewriter.create<func::ReturnOp>(rewriter.getUnknownLoc());\n    return wrap_func;\n  }\n\n  explicit ConvertOKMToOKLPattern(mlir::MLIRContext* context)\n      : OpRewritePattern<func::FuncOp>(context, /*benefit=*/0) {}\n  mlir::LogicalResult matchAndRewrite(func::FuncOp op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    const auto sym_name = op.getSymName();\n    if (sym_name.startswith(func_name::OPT_GRAPH_NAME)) {\n      const auto index = sym_name.substr(func_name::OPT_GRAPH_NAME.size()).str();\n      const std::string rename = func_name::OKL_GRAPH_NAME;\n      BuildOKLGraph(op, rewriter, rename);\n      rewriter.eraseOp(op);\n    }\n    return success();\n  }\n};\n\nclass ConvertOKMToOKLPass : public ConvertOKMToOKLPassBase<ConvertOKMToOKLPass> {\n  void getDependentDialects(DialectRegistry& registry) const override {\n    registry.insert<oneflow::OneFlowDialect>();\n    registry.insert<OKMDialect>();\n    registry.insert<bufferization::BufferizationDialect>();\n 
   registry.insert<arith::ArithDialect>();\n    registry.insert<okl::OKLDialect>();\n  }\n\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    patterns.add<ConvertOKMToOKLPattern>(patterns.getContext());\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\nstd::unique_ptr<Pass> createConvertOKMToOKLPass() {\n  return std::make_unique<ConvertOKMToOKLPass>();\n}\n\n}  // namespace okm\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OneFlowCanonicalizers.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/random_generator.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/OneFlowPatternUtils.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace {\n\nstruct PutSeed : public OpRewritePattern<RandomMaskLikeOp> {\n  explicit PutSeed(MLIRContext* context)\n      : OpRewritePattern<RandomMaskLikeOp>(context, /*benefit=*/1) {}\n  LogicalResult matchAndRewrite(oneflow::RandomMaskLikeOp op,\n                                PatternRewriter& rewriter) const override {\n    if (op->hasAttr(op.getSeedAttrName())) {\n      return failure();\n    } else {\n      op->setAttr(op.getSeedAttrName(), rewrites::GetDefaultSeed(rewriter));\n      return success();\n    }\n  }\n};\n\n}  // namespace\n\nvoid RandomMaskLikeOp::getCanonicalizationPatterns(RewritePatternSet& results,\n                                                   MLIRContext* context) {\n  results.insert<PutSeed>(context);\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OneFlowDataTypeConversion.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowDataTypeConversion.h\"\n#include \"OneFlow/OneFlowTypes.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nType getTypeFromOneFlowDataType(MLIRContext* context, ::oneflow::DataType dt) {\n  if (dt == ::oneflow::DataType::kInvalidDataType) { return InvalidElementType::get(context); }\n  if (dt == ::oneflow::DataType::kChar) { return CharElementType::get(context); }\n  if (dt == ::oneflow::DataType::kFloat16) { return FloatType::getF16(context); }\n  if (dt == ::oneflow::DataType::kFloat) { return FloatType::getF32(context); }\n  if (dt == ::oneflow::DataType::kDouble) { return FloatType::getF64(context); }\n  if (dt == ::oneflow::DataType::kInt8) {\n    return IntegerType::get(context, 8, IntegerType::Signed);\n  }\n  if (dt == ::oneflow::DataType::kInt32) {\n    return IntegerType::get(context, 32, IntegerType::Signed);\n  }\n  if (dt == ::oneflow::DataType::kInt64) {\n    return IntegerType::get(context, 64, IntegerType::Signed);\n  }\n  if (dt == ::oneflow::DataType::kOFRecord) { return OFRecordElementType::get(context); }\n  if (dt == ::oneflow::DataType::kTensorBuffer) { return TensorBufferElementType::get(context); }\n  if (dt == ::oneflow::DataType::kBool) {\n    return IntegerType::get(context, 8, IntegerType::Signed);\n  }\n  if (dt == ::oneflow::DataType::kUInt8) {\n    return IntegerType::get(context, 8, IntegerType::Unsigned);\n  }\n  if (dt == ::oneflow::DataType::kUInt16) {\n    return IntegerType::get(context, 16, IntegerType::Unsigned);\n  }\n  if (dt == ::oneflow::DataType::kUInt32) {\n    return IntegerType::get(context, 32, IntegerType::Unsigned);\n  }\n  if (dt == ::oneflow::DataType::kUInt64) {\n    return IntegerType::get(context, 64, IntegerType::Unsigned);\n  }\n  if (dt == ::oneflow::DataType::kUInt128) {\n    return IntegerType::get(context, 128, IntegerType::Unsigned);\n  }\n  llvm::errs() << \"unsupported oneflow data type: \" << dt << \"\\n\";\n  return Type();\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OneFlowDialect.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/OneFlowTypes.h\"\n#include \"OneFlow/OneFlowOpsDialect.cpp.inc\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/Dialect.h\"\n#include \"mlir/IR/TypeRange.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nvoid OneFlowDialect::initialize() {\n  addOperations<\n#define GET_OP_LIST\n#include \"OneFlow/OneFlowOps.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.assign_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.binary_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.broadcast_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.conv_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.cross_entropy_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.cuda_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.dataset_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.detection_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.eager_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.fused_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.idempotent_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.identity_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.image_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.indices_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.involution_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.loss_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.math_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.matmul_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.misc_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.nccl_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.normalization_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.optimizer_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.padding_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.parallel_cast_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.pool_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.quantization_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.reduce_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.reshape_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.scalar_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.softmax_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.summary_ops.cpp.inc\"\n      ,\n#define 
GET_OP_LIST\n#include \"OneFlow/OneFlow.tensor_buffer_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.trigonometric_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.unary_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.upsample_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.one_embedding_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.linear_algebra_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.system_ops.cpp.inc\"\n      ,\n#define GET_OP_LIST\n#include \"OneFlow/OneFlow.mlir_jit_ops.cpp.inc\"\n      >();\n  addTypes<\n#define GET_TYPEDEF_LIST\n#include \"OneFlow/OneFlowOpsTypes.cpp.inc\"\n      >();\n}\n\nmlir::Operation* OneFlowDialect::materializeConstant(mlir::OpBuilder& builder,\n                                                     mlir::Attribute value, mlir::Type type,\n                                                     mlir::Location loc) {\n  return builder.create<FrozenVariableOp>(loc, type, ValueRange(),\n                                          value.cast<mlir::DictionaryAttr>().getValue());\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OneFlowInferReturnTypes.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/UserOpConversion.h\"\n#include \"OneFlow/UserOpReflection.h\"\n#include \"OneFlow/OneFlowDataTypeConversion.h\"\n#include \"mlir/Support/LogicalResult.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace {\n\nstd::unique_ptr<::oneflow::BlobDesc> getBlobDescFromTensorType(TensorType tensor_type) {\n  auto data_type = mlir::oneflow::support::FromMLIRTypeToOFDataType(tensor_type.getElementType());\n  if (mlir::succeeded(data_type)) {\n    auto shape_from_mlir = new ::oneflow::Shape(llvm::SmallVector<int64_t, 4>(\n        {tensor_type.getShape().begin(), tensor_type.getShape().end()}));\n    return std::make_unique<::oneflow::BlobDesc>(*shape_from_mlir, data_type.value(),\n                                                 ::oneflow::MemoryFormat::kContiguous);\n  }\n  tensor_type.dump();\n  LOG(FATAL) << \"fail to get BlobDesc from TensorType\";\n}\n\nType getTensorTypeFromBlobDesc(MLIRContext* context, const ::oneflow::BlobDesc* blob_desc) {\n  if (auto type = getTypeFromOneFlowDataType(context, blob_desc->data_type())) {\n    return RankedTensorType::get(\n        llvm::SmallVector<int64_t, 4>(\n            {blob_desc->shape().dim_vec().begin(), blob_desc->shape().dim_vec().end()}),\n        type);\n  } else {\n    return Type{};\n  }\n}\n\nstatic auto MagicalOpName = \"INFER_MAGICAL\";\nLogicalResult ConvertUserOp(llvm::StringRef op_type_name, ::oneflow::OperatorConf& op_conf,\n                            ValueRange operands, DictionaryAttr attributes) {\n  oneflow::ConfOpAdaptor conf_op_adaptor(operands, attributes);\n  op_conf.set_name(MagicalOpName);\n  CHECK(\n      user_op::ConvertUserOpInputs(op_type_name, operands, attributes, op_conf.mutable_user_conf())\n          .succeeded());\n  if (!succeeded(user_op::ConvertUserOpAttributes(op_type_name, operands, attributes, op_conf))) {\n    return failure();\n  }\n  return success();\n}\n\nsize_t getResultSize(DictionaryAttr attributes) {\n  const StringRef attr_name = OpTrait::AttrSizedResultSegments<void>::getResultSegmentSizeAttr();\n  const DenseI32ArrayAttr& size_attr =\n      attributes.get(attr_name).dyn_cast_or_null<DenseI32ArrayAttr>();\n  CHECK(size_attr) << \"Attr \" << attr_name.str() << \" is not found or not DenseI32ArrayAttr\";\n  auto size = 0;\n  for (auto s : size_attr.asArrayRef()) { size += s; }\n  return size;\n}\n\n::mlir::LogicalResult inferReturnTypesWithOpTypeName(\n    llvm::StringRef op_type_name, ::mlir::MLIRContext* context, ::mlir::ValueRange operands,\n    ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,\n    ::llvm::SmallVectorImpl<::mlir::Type>& inferredReturnTypes) {\n  ::oneflow::OperatorConf op_conf{};\n  CHECK(ConvertUserOp(op_type_name, op_conf, operands, attributes).succeeded());\n  std::unordered_map<std::string, std::unique_ptr<::oneflow::BlobDesc>> lbi2logical_blob_desc_;\n  auto operand_ids =\n      
user_op::ArgIds<OpTrait::AttrSizedOperandSegments>(op_type_name, operands.size(), attributes);\n  auto operand_index = 0;\n  for (const auto& idOperand : llvm::zip(operand_ids, operands)) {\n    const auto& arg_name = std::get<0>(idOperand).first;\n    const auto& arg_id = std::get<0>(idOperand).second;\n    const auto operand = std::get<1>(idOperand);\n    auto blob_desc = getBlobDescFromTensorType(operand.getType().cast<TensorType>());\n    auto bn = ::oneflow::GenRepeatedBn(arg_name, arg_id);\n    lbi2logical_blob_desc_.emplace(bn, std::move(blob_desc));\n    operand_index += 1;\n  }\n  auto result_ids = user_op::ArgIds<OpTrait::AttrSizedResultSegments>(\n      op_type_name, getResultSize(attributes), attributes);\n  for (const auto& result_id : result_ids) {\n    const auto& arg_name = result_id.first;\n    const auto& arg_id = result_id.second;\n    const auto bn = ::oneflow::GenRepeatedBn(arg_name, arg_id);\n    auto blob_desc = std::make_unique<::oneflow::BlobDesc>(::oneflow::kInvalidDataType,\n                                                           ::oneflow::MemoryFormat::kContiguous);\n    lbi2logical_blob_desc_.emplace(bn, std::move(blob_desc));\n    (*op_conf.mutable_user_conf()->mutable_output())[arg_name].add_s(\n        ::oneflow::GenLogicalBlobName(op_conf.name(), bn));\n  }\n  auto op = CHECK_JUST(ConstructOp(op_conf, user_op::getDeviceTypeFromAttrDictionary(attributes)));\n  auto GetLogicalBlobDesc4BnInOp = [&](const std::string& bn) -> ::oneflow::BlobDesc* {\n    auto it = lbi2logical_blob_desc_.find(bn);\n    if (it == lbi2logical_blob_desc_.end()) {\n      LOG(FATAL) << \"fail to find blob name in op: \" << bn;\n    }\n    return it->second.get();\n  };\n  ::oneflow::ParallelConf parallel_conf = user_op::getParallelConfFromAttrDictionary(attributes);\n  ::oneflow::ParallelDesc parallel_desc{parallel_conf};\n  CHECK_JUST(op->FillOpParallelDesc(parallel_desc));\n  CHECK_JUST(op->InferLogicalOutBlobDescs(GetLogicalBlobDesc4BnInOp, parallel_desc));\n  for (const auto& result_id : result_ids) {\n    const auto& arg_name = result_id.first;\n    const auto& arg_id = result_id.second;\n    const auto bn = ::oneflow::GenRepeatedBn(arg_name, arg_id);\n    const auto* desc = lbi2logical_blob_desc_.at(bn).get();\n    if (auto t = getTensorTypeFromBlobDesc(context, desc)) { inferredReturnTypes.push_back(t); }\n  }\n  return success();\n}\n\n}  // namespace\n\n::mlir::LogicalResult NormalizationAddReluOp::refineReturnTypes(\n    ::mlir::MLIRContext* context, ::llvm::Optional<::mlir::Location> location,\n    ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,\n    ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,\n    ::llvm::SmallVectorImpl<::mlir::Type>& inferredReturnTypes) {\n  return success();\n}\n\n::mlir::LogicalResult NormalizationAddReluOp::inferReturnTypes(\n    ::mlir::MLIRContext* context, ::llvm::Optional<::mlir::Location> location,\n    ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,\n    ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,\n    ::llvm::SmallVectorImpl<::mlir::Type>& inferredReturnTypes) {\n  return inferReturnTypesWithOpTypeName(\"normalization_add_relu\", context, operands, attributes,\n                                        regions, inferredReturnTypes);\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <functional>\n#include <memory>\n#include <vector>\n#include \"OneFlow/OneFlowOpTraits.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"mlir/IR/Attributes.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/OpDefinition.h\"\n#include \"mlir/IR/OperationSupport.h\"\n#include \"mlir/IR/Value.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/shape_vec.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/framework/variable_tensor_mgr.h\"\n\nnamespace mlir {\nnamespace oneflow {\nnamespace {\n\nnamespace functional = ::oneflow::one::functional;\nusing TensorPtr = std::shared_ptr<::oneflow::one::Tensor>;\nusing MaybeTensor = ::oneflow::Maybe<::oneflow::one::Tensor>;\n\nStringAttr GenNewVariableOpName(MLIRContext* ctx, const std::string& key = \"\") {\n  if (key == \"\") { return StringAttr::get(ctx, \"variable_\" + ::oneflow::NewUniqueId()); }\n  return StringAttr::get(ctx, \"variable_\" + key + \"_\" + ::oneflow::NewUniqueId());\n}\n\nbool MLIRDataTypesAreSame(const std::vector<DataType>& data_types) {\n  if (data_types.empty() || data_types.size() == 1) { return true; }\n  bool result = true;\n  const auto first_data_type = data_types[0];\n  for (size_t i = 1; i < data_types.size(); ++i) { result &= (first_data_type == data_types[i]); }\n  return result;\n}\n\nbool DictionaryAttrsHaveSameDataType(const std::vector<mlir::DictionaryAttr>& attrs) {\n  std::vector<DataType> data_types;\n  for (const auto& attr : attrs) {\n    data_types.push_back(attr.get(OpTrait::TensorSource<void>::getDataTypeAttrName())\n                             .cast<DataTypeAttr>()\n                             .getValue());\n  }\n  return MLIRDataTypesAreSame(data_types);\n}\n\nOpFoldResult UnaryFold(MLIRContext* ctx, ArrayRef<Attribute> operands,\n                       const std::function<MaybeTensor(const TensorPtr&)>& f) {\n  ::oneflow::LazyMode::Guard guard{false};\n  if (!operands.front()) { return {}; }  // Important!\n\n  const auto attr_dict = operands.front().cast<mlir::DictionaryAttr>();\n  auto attrs = NamedAttrList(attr_dict);\n  const auto tensor = support::DenseElementsAttrToTensor(\n      attr_dict.get(\"value\"), attr_dict.get(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr()),\n      attr_dict.get(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr()));\n  const auto result = f(tensor).GetPtrOrThrow();\n  attrs.set(\"value\", support::TensorToDenseElementsAttr(result, ctx));\n  attrs.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(), GenNewVariableOpName(ctx));\n  attrs.set(OpTrait::TensorSource<void>::getDataTypeAttrName(),\n            attr_dict.get(OpTrait::TensorSource<void>::getDataTypeAttrName()));\n\n  return attrs.getDictionary(ctx);\n}\n\nOpFoldResult 
BinaryFold(MLIRContext* ctx, ArrayRef<Attribute> operands,\n                        const std::function<MaybeTensor(const TensorPtr&, const TensorPtr&)>& f) {\n  ::oneflow::LazyMode::Guard guard{false};\n  if (!(operands.front() && operands.back())) { return {}; }  // Important!\n  auto lhs_attr_dict = operands.front().cast<mlir::DictionaryAttr>();\n  auto rhs_attr_dict = operands.back().cast<mlir::DictionaryAttr>();\n  if (!DictionaryAttrsHaveSameDataType({lhs_attr_dict, rhs_attr_dict})) {\n    llvm::errs()\n        << \"Input tensors should have the same data type in binary operation of constant folding.\"\n        << \"\\n\";\n    return nullptr;\n  }\n\n  auto attrs = NamedAttrList(lhs_attr_dict);\n  const auto lhs_tensor = support::DenseElementsAttrToTensor(\n      lhs_attr_dict.get(\"value\"),\n      lhs_attr_dict.get(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr()),\n      lhs_attr_dict.get(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr()));\n  const auto rhs_tensor = support::DenseElementsAttrToTensor(\n      rhs_attr_dict.get(\"value\"),\n      rhs_attr_dict.get(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr()),\n      rhs_attr_dict.get(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr()));\n\n  const auto result = f(lhs_tensor, rhs_tensor).GetPtrOrThrow();\n\n  attrs.set(\"value\", support::TensorToDenseElementsAttr(result, ctx));\n  attrs.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(), GenNewVariableOpName(ctx));\n  attrs.set(OpTrait::TensorSource<void>::getDataTypeAttrName(),\n            lhs_attr_dict.get(OpTrait::TensorSource<void>::getDataTypeAttrName()));\n\n  return attrs.getDictionary(ctx);\n}\n\n}  // namespace\n\nOpFoldResult FrozenVariableOp::fold(FoldAdaptor adaptor) {\n  NamedAttrList attrs;\n  attrs.set(getValueAttrName(), getValueAttr());\n  attrs.set(getOpNameAttrName(), getOpNameAttr());\n  attrs.set(getDataTypeAttrName(), getDataTypeAttr());\n  attrs.set(getDeviceTagAttrName(), getDeviceTagAttr());\n  attrs.set(getDeviceNameAttrName(), getDeviceNameAttr());\n  attrs.set(getScopeSymbolIdAttrName(), getScopeSymbolIdAttr());\n  attrs.set(getHierarchyAttrName(), getHierarchyAttr());\n  attrs.set(getNdSbpAttrName(), getNdSbpAttr());\n  return DictionaryAttr::get(getContext(), attrs);\n}\n\nOpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {\n  auto operands = adaptor.getOperands();\n  return UnaryFold(getContext(), operands, [this](const auto& tensor) {\n    std::vector<int32_t> perm_;\n    for (auto& x : getPerm().getValue()) { perm_.emplace_back(x.cast<IntegerAttr>().getSInt()); }\n    return functional::Transpose(tensor, perm_);\n  });\n}\n\nOpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {\n  auto operands = adaptor.getOperands();\n  return UnaryFold(getContext(), operands, [this](const auto& tensor) {\n    std::vector<int64_t> shape_vec;\n    for (auto& x : getShape().getValue()) {\n      shape_vec.emplace_back(x.cast<mlir::IntegerAttr>().getValue().getSExtValue());\n    }\n    return functional::Reshape(\n        tensor, ::oneflow::Shape(::oneflow::DimVector(shape_vec.begin(), shape_vec.end())));\n  });\n}\n\nOpFoldResult ScalarAddOp::fold(FoldAdaptor adaptor) {\n  auto operands = adaptor.getOperands();\n  return UnaryFold(getContext(), operands, [this](const auto& tensor) -> MaybeTensor {\n    if (getHasIntOperand()) { return functional::ScalarAdd(tensor, getIntOperand(), 1, false); }\n    if (getHasFloatOperand()) {\n      return functional::ScalarAdd(tensor, getFloatOperand().convertToDouble(), 1, false);\n    }\n    
emitError(\"Scalar op must have an int operand or a float operand.\");\n    return TensorPtr();\n  });\n}\n\nOpFoldResult SqrtOp::fold(FoldAdaptor adaptor) {\n  auto operands = adaptor.getOperands();\n  return UnaryFold(getContext(), operands, functional::Sqrt);\n}\n\nOpFoldResult BroadcastMulOp::fold(FoldAdaptor adaptor) {\n  auto operands = adaptor.getOperands();\n  return BinaryFold(getContext(), operands, functional::Mul);\n}\n\nOpFoldResult BroadcastDivOp::fold(FoldAdaptor adaptor) {\n  auto operands = adaptor.getOperands();\n  return BinaryFold(getContext(), operands, functional::Div);\n}\n\nOpFoldResult BroadcastSubOp::fold(FoldAdaptor adaptor) {\n  auto operands = adaptor.getOperands();\n  return BinaryFold(getContext(), operands, [](const auto& lhs, const auto& rhs) -> MaybeTensor {\n    return functional::Sub(lhs, rhs, /*alpha=*/1.0, false);\n  });\n}\n\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OneFlowOpGetGen.cpp.in",
    "content": "#include <iostream>\n#include <string>\n#include \"llvm/ADT/STLExtras.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/OpImplementation.h\"\n#include \"llvm/ADT/StringSet.h\"\n#include \"mlir/Support/LLVM.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOpTraits.h\"\n#include \"OneFlow/OneFlowSupport.h\"\n#include \"OneFlow/OneFlowInterfaces.h.inc\"\n#include \"OneFlow/OneFlowTypes.h\"\n\n#define GET_OP_CLASSES\n#include \"OneFlow/OneFlow.@OP_GROUP_NAME_LOWER@_ops.h.inc\"\n#define GET_OP_CLASSES\n#include \"OneFlow/OneFlow.@OP_GROUP_NAME_LOWER@_ops.cpp.inc\"\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OneFlowOpTraits.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/UserOpConversion.h\"\n\nnamespace mlir {\n\nnamespace OpTrait {\n\nnamespace {\n\n// TODO: merge all ctrl input and output when folding op\nbool HaveIdenticalPlacement(mlir::Operation* a, mlir::Operation* b) {\n  const bool has_identical_dev_tag =\n      IsOpConfCompatible<void>::getDeviceTag(a) == IsOpConfCompatible<void>::getDeviceTag(b);\n  const bool has_identical_dev_name =\n      IsOpConfCompatible<void>::getDeviceName(a) == IsOpConfCompatible<void>::getDeviceName(b);\n  return has_identical_dev_tag && has_identical_dev_name;\n}\n\n}  // namespace\n\nnamespace impl {\n\nOpFoldResult foldIdempotentOfIdenticalPlacement(Operation* op) {\n  auto* argument_op = op->getOperand(0).getDefiningOp();\n  if (argument_op && op->getName() == argument_op->getName()\n      && HaveIdenticalPlacement(op, argument_op)) {\n    return op->getOperand(0);\n  }\n  return {};\n}\n\nOpFoldResult foldInvolutionOfIdenticalPlacement(Operation* op) {\n  auto* argument_op = op->getOperand(0).getDefiningOp();\n  if (argument_op && op->getName() == argument_op->getName()\n      && HaveIdenticalPlacement(op, argument_op)) {\n    return argument_op->getOperand(0);\n  }\n  return {};\n}\n\nLogicalResult VerifyIsOpConfCompatible(Operation* op) {\n  for (auto attr : {\n           IsOpConfCompatible<void>::getOpNameAttr(),\n           IsOpConfCompatible<void>::getDeviceTagAttr(),\n       }) {\n    if (!op->hasAttrOfType<StringAttr>(attr)) {\n      return op->emitError(\"expected operation to have attribute: \" + attr);\n    }\n  }\n  if (!op->hasAttrOfType<ArrayAttr>(IsOpConfCompatible<void>::getDeviceNameAttr())) {\n    return op->emitError(\"expected operation to have attribute: \"\n                         + IsOpConfCompatible<void>::getDeviceNameAttr());\n  }\n  return success();\n}\n\nLogicalResult VerifyIsImportCompatible(Operation* op) {\n  if (auto output_lbns =\n          op->getAttrOfType<ArrayAttr>(IsImportCompatible<void>::getOutputLBNsAttr())) {\n    if (auto cec = dyn_cast<oneflow::ControlEdgeCompatible>(op)) {\n      if (cec.dataOutputResults().size() != output_lbns.size()) {\n        return op->emitError(\"expected number of data output results to be \"\n                             + std::to_string(output_lbns.size()) + \" but got \"\n                             + std::to_string(cec.dataOutputResults().size()));\n      }\n    } else {\n      return op->emitError(\"expected to support ControlEdgeCompatible\");\n    }\n  } else {\n    return op->emitError(\"expected operation to have attribute: \"\n                         + IsImportCompatible<void>::getOutputLBNsAttr());\n  }\n  return success();\n}\n\nLogicalResult saveAttrToOpConf(Operation* op, ::oneflow::OperatorConf* op_conf) {\n  return oneflow::user_op::saveAttrDictionaryToOpConf(op->getAttrDictionary(), op_conf);\n}\n\nStringAttr getOpName(Operation* op) {\n  
assert(op->hasTrait<OpTrait::IsOpConfCompatible>());\n  return op->getAttrOfType<StringAttr>(OpTrait::IsOpConfCompatible<void>::getOpNameAttr());\n}\nStringAttr getDeviceTag(Operation* op) {\n  assert(op->hasTrait<OpTrait::IsOpConfCompatible>());\n  return op->getAttrOfType<StringAttr>(IsOpConfCompatible<void>::getDeviceTagAttr());\n}\nArrayAttr getDeviceName(Operation* op) {\n  assert(op->hasTrait<OpTrait::IsOpConfCompatible>());\n  return op->getAttrOfType<ArrayAttr>(IsOpConfCompatible<void>::getDeviceNameAttr());\n}\n\nIntegerAttr getScopeSymbolID(Operation* op) {\n  assert(op->hasTrait<OpTrait::IsOpConfCompatible>());\n  return op->getAttrOfType<IntegerAttr>(IsOpConfCompatible<void>::getScopeSymbolIDAttr());\n}\nArrayAttr getHierarchy(Operation* op) {\n  assert(op->hasTrait<OpTrait::IsOpConfCompatible>());\n  return op->getAttrOfType<ArrayAttr>(IsOpConfCompatible<void>::getHierarchyAttr());\n}\n\nLogicalResult saveAttrsToNamedAttrList(Operation* op, NamedAttrList& attributes) {\n  attributes.set(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr(),\n                 OpTrait::IsOpConfCompatible<void>::getDeviceTag(op));\n  attributes.set(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr(),\n                 OpTrait::IsOpConfCompatible<void>::getDeviceName(op));\n  if (auto hierarchy = OpTrait::IsOpConfCompatible<void>::getHierarchy(op)) {\n    attributes.set(OpTrait::IsOpConfCompatible<void>::getHierarchyAttr(), hierarchy);\n  }\n  attributes.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(),\n                 OpTrait::IsOpConfCompatible<void>::getOpName(op));\n  if (auto scope_symbol_id = OpTrait::IsOpConfCompatible<void>::getScopeSymbolID(op)) {\n    attributes.set(OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr(), scope_symbol_id);\n  }\n  return success();\n}\n\n}  // namespace impl\n\n}  // namespace OpTrait\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OneFlowOps.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowSupport.h\"\n#include \"OneFlow/SBP/SBPAttributes.h\"\n#include \"OneFlow/Transform/TransposeHelpers.h\"\n#include \"llvm/ADT/StringRef.h\"\n#include \"mlir/IR/Attributes.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/OperationSupport.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/shape_vec.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/framework/tensor_util.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/vm/vm_util.h\"\n\n#include \"llvm/ADT/STLExtras.h\"\n#include \"llvm/ADT/StringSet.h\"\n\n#include \"llvm/Support/Casting.h\"\n#include \"mlir/IR/Attributes.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/MLIRContext.h\"\n#include \"mlir/IR/OpImplementation.h\"\n#include \"mlir/IR/OperationSupport.h\"\n#include \"mlir/IR/FunctionImplementation.h\"\n#include \"mlir/Support/LLVM.h\"\n#include \"mlir/Support/LogicalResult.h\"\n\n#include <iostream>\n#include <memory>\n#include <string>\n#include <vector>\n\nnamespace mlir {\n\nnamespace oneflow {\n\nOperandRange UserOp::dataInputOperands() { return getDataInput(); }\nOperandRange UserOp::ctrlInputOperands() { return getCtrlInputs(); }\nResultRange UserOp::dataOutputResults() { return getDataOutput(); }\nValue UserOp::ctrlOutputResult() { return getCtrlOutput(); }\n\nOperandRange SystemOp::dataInputOperands() { return getDataInput(); }\nOperandRange SystemOp::ctrlInputOperands() { return getCtrlInputs(); }\nResultRange SystemOp::dataOutputResults() { return getDataOutput(); }\nValue SystemOp::ctrlOutputResult() { return getCtrlOutput(); }\n\nOperandRange VariableOp::dataInputOperands() { return {operand_begin(), operand_begin()}; }\nOperandRange VariableOp::ctrlInputOperands() { return getCtrlInputs(); }\nResultRange VariableOp::dataOutputResults() { return getOutput().dyn_cast<OpResult>(); }\nValue VariableOp::ctrlOutputResult() { return getCtrlOutput(); }\n\nOperandRange InputOp::dataInputOperands() { return getODSOperands(0); }\nOperandRange InputOp::ctrlInputOperands() { return getCtrlInputs(); }\nResultRange InputOp::dataOutputResults() { return getOutput().dyn_cast<OpResult>(); }\nValue InputOp::ctrlOutputResult() { return getCtrlOutput(); }\n\nOperandRange OutputOp::dataInputOperands() { return getODSOperands(0); }\nOperandRange OutputOp::ctrlInputOperands() { return getCtrlInputs(); }\nResultRange OutputOp::dataOutputResults() { return getOutput().dyn_cast<OpResult>(); }\nValue OutputOp::ctrlOutputResult() { return getCtrlOutput(); 
}\n\nstatic ParseResult parseConstantOp(OpAsmParser& parser, OperationState& result) {\n  mlir::DenseElementsAttr value;\n  if (parser.parseOptionalAttrDict(result.attributes)\n      || parser.parseAttribute(value, \"value\", result.attributes)) {\n    return failure();\n  }\n  result.addTypes(value.getType());\n  return success();\n}\n\nArrayAttr getSI32ArrayAttr(::mlir::PatternRewriter& rewriter, ArrayRef<int32_t> values) {\n  auto attrs = llvm::to_vector<8>(llvm::map_range(\n      values, [&](int32_t v) -> Attribute { return rewriter.getSI32IntegerAttr(v); }));\n  return rewriter.getArrayAttr(attrs);\n}\n\nnamespace {\n\nLogicalResult TrimRedundantCtrl(Operation* op, PatternRewriter& rewriter) {\n  auto ctrl_out = GetCtrlOutputResult(op);\n  auto data_outputs = GetDataOutputResults(op);\n  if (ctrl_out && ctrl_out.value().use_empty()) {\n    const int32_t num_data_outputs = data_outputs.size();\n    NamedAttrList attributes(op->getAttrs());\n    if (op->hasTrait<OpTrait::AttrSizedResultSegments>()) {\n      attributes.erase(OpTrait::AttrSizedResultSegments<void>::getResultSegmentSizeAttr());\n      attributes.push_back(\n          rewriter.getNamedAttr(OpTrait::AttrSizedResultSegments<void>::getResultSegmentSizeAttr(),\n                                rewriter.getDenseI32ArrayAttr({num_data_outputs, 0})));\n    }\n    OperationState state(op->getLoc(), op->getName(), op->getOperands(), data_outputs.getTypes(),\n                         attributes);\n    auto created = rewriter.create(state);\n    for (auto data_output : data_outputs) {\n      data_output.replaceAllUsesWith(created->getOpResult(data_output.getResultNumber()));\n    }\n    op->erase();\n    return success();\n  }\n  return failure();\n}\n\nbool IsCtrlOutTrimmed(UserOp& op) { return !op.getCtrlOutput(); }\n\nbool IsCtrlInAbsent(UserOp& op) {\n  if (!op->hasAttrOfType<DenseI32ArrayAttr>(\n          OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr()))\n    op.dump();\n  return op.getCtrlInputs().empty();\n}\n\n}  // namespace\n\ntemplate<typename T>\nstatic void getValuesFromIntArrayAttribute(ArrayAttr attr, SmallVector<T>& arrayValues) {\n  for (Attribute val : attr.getValue()) {\n    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());\n  }\n}\n\nstruct ConcreteUserOps : public OpRewritePattern<UserOp> {\n  explicit ConcreteUserOps(MLIRContext* context)\n      : OpRewritePattern<UserOp>(context, /*benefit=*/1) {}\n  LogicalResult matchAndRewrite(UserOp op, PatternRewriter& rewriter) const override {\n    if (succeeded(TrimRedundantCtrl(op, rewriter))) { return success(); }\n    // In principle, a concrete user op has no ctrl input/output. Some benefits:\n    // 1. simplify things\n    // 2. make conversion and code gen more doable\n    // 3. 
enable the reuse of established MLIR infra like built-in traits\n    if (IsCtrlOutTrimmed(op) && IsCtrlInAbsent(op)) {\n      NamedAttrList attributes(op->getAttrDictionary());\n      attributes.erase(op.getInputSizesAttrName());\n      attributes.erase(op.getOutputSizesAttrName());\n      attributes.erase(op.getOutputLbnsAttrName());\n      attributes.erase(OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr());\n      attributes.erase(OpTrait::AttrSizedResultSegments<void>::getResultSegmentSizeAttr());\n      llvm::SmallVector<int32_t> input_sizes, output_sizes;\n      getValuesFromIntArrayAttribute(op.getInputSizes(), input_sizes);\n      getValuesFromIntArrayAttribute(op.getOutputSizes(), output_sizes);\n      if (!input_sizes.empty()) {\n        attributes.push_back(rewriter.getNamedAttr(\n            OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),\n            rewriter.getDenseI32ArrayAttr(input_sizes)));\n      }\n      if (!output_sizes.empty()) {\n        attributes.push_back(rewriter.getNamedAttr(\n            OpTrait::AttrSizedResultSegments<void>::getResultSegmentSizeAttr(),\n            rewriter.getDenseI32ArrayAttr(output_sizes)));\n      }\n      OperationState state(op->getLoc(), OneFlowDialect::getDialectNamespace().str() + \".\"\n                                             + op.getOpTypeName().str());\n      state.addAttributes(attributes);\n      state.addOperands(op.getODSOperands(0) /* data in */);\n      state.addTypes(op.getODSResults(0 /* data out */).getTypes());\n      if (auto created = rewriter.create(state)) {\n        if (created->hasTrait<OpTrait::AttrSizedOperandSegments>() == false) {\n          created->removeAttr(OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr());\n        }\n        if (created->hasTrait<OpTrait::AttrSizedResultSegments>() == false) {\n          created->removeAttr(OpTrait::AttrSizedResultSegments<void>::getResultSegmentSizeAttr());\n        }\n        if (created->hasTrait<OpTrait::IsAlternative>() == false) {\n          created->removeAttr(OpTrait::IsAlternative<void>::getOpTypeNameAttr());\n        }\n        rewriter.replaceOp(op, created->getResults());\n      } else {\n        op->emitError(\"Fail to convert opaque user op to concrete op when creating: \"\n                      + op.getOpTypeName());\n        op->dump();\n        return failure();\n      }\n    }\n    return success();\n  }\n};\n\nvoid UserOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) {\n  results.insert<ConcreteUserOps>(context);\n}\n\nstruct ConcreteSystemOps : public OpRewritePattern<SystemOp> {\n  explicit ConcreteSystemOps(MLIRContext* context)\n      : OpRewritePattern<SystemOp>(context, /*benefit=*/1) {}\n  LogicalResult matchAndRewrite(oneflow::SystemOp op, PatternRewriter& rewriter) const override {\n    return TrimRedundantCtrl(op, rewriter);\n  }\n};\n\nvoid SystemOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) {\n  results.insert<ConcreteSystemOps>(context);\n}\n\nstruct ConvertAddOpWithArity : public OpRewritePattern<AddNOp> {\n  explicit ConvertAddOpWithArity(MLIRContext* context)\n      : OpRewritePattern<AddNOp>(context, /*benefit=*/1) {}\n  LogicalResult matchAndRewrite(AddNOp op, PatternRewriter& rewriter) const override {\n    const auto arity = op.getIn().size();\n    if (arity == 2) {\n      NamedAttrList attributes = op->getAttrs();\n      attributes.set(OpTrait::IsAlternative<void>::getOpTypeNameAttr(),\n              
       rewriter.getStringAttr(\"add_n\"));\n      if (auto created_op = rewriter.replaceOpWithNewOp<Add2Op>(op, op->getResultTypes(),\n                                                                op.getOperands(), attributes)) {\n        return success();\n      } else {\n        op->emitError(\"Fail to convert add op with arity: \") << arity;\n        op->dump();\n        return failure();\n      }\n    }\n    return failure();\n  }\n};\n\nvoid AddNOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) {\n  results.insert<ConvertAddOpWithArity>(context);\n}\n\ntemplate<typename OpType>\nstruct ConcreteSystemOpPattern : public OpRewritePattern<OpType> {\n  explicit ConcreteSystemOpPattern(MLIRContext* context)\n      : OpRewritePattern<OpType>(context, /*benefit=*/1) {}\n  LogicalResult matchAndRewrite(OpType op, PatternRewriter& rewriter) const override {\n    if (op.getCtrlOutput() && op.getCtrlOutput().use_empty()) {\n      NamedAttrList attributes(op->getAttrDictionary());\n      if (auto created = rewriter.create<OpType>(op->getLoc(), op.getOutput().getType(),\n                                                 op->getOperands(), attributes)) {\n        op.getOutput().replaceAllUsesWith(\n            created->getResult(op.getOutput().template cast<OpResult>().getResultNumber()));\n        op->erase();\n        return success();\n      }\n    }\n    return failure();\n  }\n};\n\nvoid VariableOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) {\n  results.insert<ConcreteSystemOpPattern<VariableOp>>(context);\n}\n\nvoid InputOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) {\n  results.insert<ConcreteSystemOpPattern<InputOp>>(context);\n}\n\nvoid OutputOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) {\n  results.insert<ConcreteSystemOpPattern<OutputOp>>(context);\n}\n\nstd::string Add2Op::getOriginalOpTypeName() { return \"add_n\"; }\nstd::string NormalizationInferenceOp::getOriginalOpTypeName() { return \"normalization\"; }\n\nvoid Job::build(OpBuilder& builder, OperationState& state, StringRef name, FunctionType type,\n                llvm::ArrayRef<mlir::NamedAttribute> attrs) {\n  state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name));\n  state.addAttribute(Job::getFunctionTypeAttrName(state.name), TypeAttr::get(type));\n  state.attributes.append(attrs.begin(), attrs.end());\n\n  state.addRegion();\n}\n\nParseResult Job::parse(OpAsmParser& parser, OperationState& result) {\n  auto buildFuncType = [](Builder& builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,\n                          function_interface_impl::VariadicFlag,\n                          std::string&) { return builder.getFunctionType(argTypes, results); };\n  return mlir::function_interface_impl::parseFunctionOp(\n      parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType,\n      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));\n}\n\nvoid Job::print(OpAsmPrinter& p) {\n  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,\n                                           getFunctionTypeAttrName(), getArgAttrsAttrName(),\n                                           getResAttrsAttrName());\n}\n\nLogicalResult Job::verify() {\n  // If this function is external there is nothing to do.\n  if (isExternal()) return success();\n\n  // Verify that the argument list of the function and the arg list of 
the entry\n  // block line up.  The trait already verified that the number of arguments is\n  // the same between the signature and the block.\n  auto fnInputTypes = getFunctionType().getInputs();\n  Block& entryBlock = front();\n  for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)\n    if (fnInputTypes[i] != entryBlock.getArgument(i).getType())\n      return emitOpError(\"type of entry block argument #\")\n             << i << '(' << entryBlock.getArgument(i).getType()\n             << \") must match the type of the corresponding argument in \"\n             << \"function signature(\" << fnInputTypes[i] << ')';\n\n  return success();\n}\n\nLogicalResult ReturnOp::verify() {\n  auto job = cast<Job>((*this)->getParentOp());\n\n  // The operand number and types must match the function signature.\n  const auto& results = job.getFunctionType().getResults();\n  if (getNumOperands() != results.size())\n    return emitOpError(\"has \") << getNumOperands() << \" operands, but enclosing function (@\"\n                               << job.getName() << \") returns \" << results.size();\n\n  for (unsigned i = 0, e = results.size(); i != e; ++i)\n    if (getOperand(i).getType() != results[i])\n      return emitError() << \"type of return operand \" << i << \" (\" << getOperand(i).getType()\n                         << \") doesn't match function result type (\" << results[i] << \")\"\n                         << \" in function @\" << job.getName();\n\n  return success();\n}\n\nstruct NormalizationInferencePattern : public OpRewritePattern<NormalizationOp> {\n  explicit NormalizationInferencePattern(MLIRContext* context)\n      : OpRewritePattern<NormalizationOp>(context, /*benefit=*/1) {}\n  LogicalResult matchAndRewrite(oneflow::NormalizationOp op,\n                                PatternRewriter& rewriter) const override {\n    if (op.getMean() || op.getInvVariance()) return failure();\n    if (auto created_op = rewriter.replaceOpWithNewOp<NormalizationInferenceOp>(\n            op, op->getResultTypes(), op.getOperands(), op->getAttrs())) {\n      return success();\n    }\n    op.emitError(\"Failed to create inference bn op\");\n    return failure();\n  }\n};\n\nvoid NormalizationOp::getCanonicalizationPatterns(RewritePatternSet& results,\n                                                  MLIRContext* context) {\n  results.insert<NormalizationInferencePattern>(context);\n}\n\nResultRange GetDataOutputResults(Operation* op) {\n  if (auto cec = dyn_cast<ControlEdgeCompatible>(op)) {\n    return cec.dataOutputResults();\n  } else {\n    return op->getResults();\n  }\n}\n\nOperandRange GetDataInputOperands(Operation* op) {\n  if (auto cec = dyn_cast<ControlEdgeCompatible>(op)) {\n    return cec.dataInputOperands();\n  } else {\n    return op->getOperands();\n  }\n}\n\nllvm::Optional<OperandRange> GetCtrlIntputOperands(Operation* op) {\n  if (auto cec = dyn_cast<ControlEdgeCompatible>(op)) {\n    return cec.ctrlInputOperands();\n  } else {\n    return llvm::None;\n  }\n}\n\nllvm::Optional<OpResult> GetCtrlOutputResult(Operation* op) {\n  if (auto cec = dyn_cast<ControlEdgeCompatible>(op)) {\n    if (auto ctrl_out = cec.ctrlOutputResult()) { return ctrl_out.cast<OpResult>(); }\n  }\n  return llvm::None;\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#include \"OneFlow/OneFlowEnums.cpp.inc\"\n\n#define GET_OP_CLASSES\n#include \"OneFlow/OneFlowOps.cpp.inc\"\n#include \"OneFlow/OneFlowInterfaces.cpp.inc\"\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OneFlowRewrites.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n//===- TestPDLByteCode.cpp - Test PDLL functionality ----------------------===//\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n//===----------------------------------------------------------------------===//\n\n#include \"OneFlow/UserOpConversion.h\"\n#include \"mlir/Dialect/PDL/IR/PDL.h\"\n#include \"mlir/Dialect/PDLInterp/IR/PDLInterp.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/ImplicitLocOpBuilder.h\"\n#include \"mlir/Parser/Parser.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n#include \"OneFlow/OneFlowPDLLPatterns.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#include \"OneFlow/OneFlowUtils.h\"\n#include \"mlir/IR/IRMapping.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\nusing namespace mlir;\n\n#include \"oneflow/ir/lib/OneFlow/PDLL/ForwardOpPatterns.h.inc\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace {\n\nstatic std::atomic<int64_t> uniqID{0};\n\nstd::string getUniqName(llvm::StringRef name) {\n  uniqID += 1;\n  return name.str() + \"-mlir-gen-\" + std::to_string(uniqID);\n}\n\nstatic Operation* CopyUserOpAttrs(PatternRewriter& rewriter, Operation* src, Operation* dst) {\n  dst->setAttr(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr(),\n               OpTrait::IsOpConfCompatible<void>::getDeviceTag(src));\n  dst->setAttr(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr(),\n               OpTrait::IsOpConfCompatible<void>::getDeviceName(src));\n  if (auto hierarchy = OpTrait::IsOpConfCompatible<void>::getHierarchy(src)) {\n    dst->setAttr(OpTrait::IsOpConfCompatible<void>::getHierarchyAttr(), hierarchy);\n  }\n  if (auto scope_symbol_id = OpTrait::IsOpConfCompatible<void>::getScopeSymbolID(src)) {\n    dst->setAttr(OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr(), scope_symbol_id);\n  }\n  dst->setAttr(\n      OpTrait::IsOpConfCompatible<void>::getOpNameAttr(),\n      rewriter.getStringAttr(getUniqName(OpTrait::IsOpConfCompatible<void>::getOpName(src).str())));\n  return dst;\n}\n\nstatic Operation* BuildFusedBiasAddMaskScaleOpWithRate(PatternRewriter& rewriter, Value a, Value b,\n                                                       Value mask, Attribute axis, Attribute rate,\n                                                       Operation* dropout) {\n  auto dropout_op = llvm::dyn_cast<DropoutOp>(dropout);\n  assert(dropout_op);\n  SmallVector<Value, 4> operands;\n  operands.push_back(a);\n  operands.push_back(b);\n  operands.push_back(mask);\n  NamedAttrList attributes;\n  attributes.set(\"axis\", axis);\n  float scale = 1.0f;\n  float rate_float = 
rate.cast<FloatAttr>().getValueAsDouble();\n  if (rate_float < 1.0f) { scale = 1.0f / (1.0f - rate_float); }\n  attributes.set(\"scale\", rewriter.getF32FloatAttr(scale));\n  return rewriter.create<FusedBiasAddMaskScaleOp>(\n      dropout_op->getLoc(), dropout_op.getOut().getType(), operands, attributes);\n}\n\nstatic Operation* CreateConv2dAndErasePad(PatternRewriter& rewriter, Value x, Value weight,\n                                          Attribute padding_before, Attribute data_format,\n                                          Operation* conv) {\n  auto conv_op = llvm::dyn_cast<Conv2DOp>(conv);\n  assert(conv_op);\n  SmallVector<Value, 4> operands;\n  operands.push_back(x);\n  operands.push_back(weight);\n  NamedAttrList attributes = conv_op->getAttrs();\n  llvm::SmallVector<int32_t> padding_before_array;\n\n  attributes.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(),\n                 rewriter.getStringAttr(OpTrait::IsOpConfCompatible<void>::getOpName(conv).str()\n                                        + \"-fuse-conv\"));\n\n  if (data_format.cast<StringAttr>().str() == \"channels_first\") {\n    for (auto val : padding_before.cast<ArrayAttr>().getValue().take_back(2)) {\n      padding_before_array.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());\n    }\n  } else {\n    padding_before_array.push_back(padding_before.cast<ArrayAttr>()\n                                       .getValue()[1]\n                                       .cast<IntegerAttr>()\n                                       .getValue()\n                                       .getSExtValue());\n    padding_before_array.push_back(padding_before.cast<ArrayAttr>()\n                                       .getValue()[2]\n                                       .cast<IntegerAttr>()\n                                       .getValue()\n                                       .getSExtValue());\n  }\n\n  attributes.set(conv_op.getPaddingBeforeAttrName(),\n                 getSI32ArrayAttr(rewriter, padding_before_array));\n  return rewriter.create<Conv2DOp>(conv_op->getLoc(), conv_op.getOut().getType(), operands,\n                                   attributes);\n}\n\nIntegerAttr getSI64IntegerAttr(::mlir::PatternRewriter& rewriter, int64_t value) {\n  return IntegerAttr::get(rewriter.getIntegerType(64, /*isSigned=*/true),\n                          APInt(64, value, /*isSigned=*/true));\n}\n\nstatic Attribute GetHeadSizeFromTranpose(PatternRewriter& rewriter, Operation* transpose) {\n  auto transpose_op = llvm::dyn_cast<TransposeOp>(transpose);\n  CHECK(transpose_op);\n  return getSI64IntegerAttr(rewriter,\n                            transpose_op.getOutput().getType().cast<ShapedType>().getDimSize(3));\n}\nNamedAttrList GetUserOpCommonAttrs(MLIRContext* ctx, const std::string& op_name) {\n  NamedAttrList attrs;\n  attrs.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(), StringAttr::get(ctx, op_name));\n  attrs.set(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr(), StringAttr::get(ctx, \"cpu\"));\n  attrs.set(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr(),\n            ArrayAttr::get(ctx, llvm::to_vector<8>(llvm::map_range(ArrayRef<StringRef>({\"@0:0\"}),\n                                                                   [&](StringRef v) -> Attribute {\n                                                                     return StringAttr::get(ctx, v);\n                                                                   }))));\n  return attrs;\n}\nstatic Operation* 
CreateConv2DBatchNorm(PatternRewriter& rewriter, Attribute epsilon,\n                                        Operation* conv, Operation* bn) {\n  auto conv_op = llvm::dyn_cast<oneflow::Conv2DOp>(conv);\n  auto bn_op = llvm::dyn_cast<oneflow::NormalizationInferenceOp>(bn);\n  auto ctx = rewriter.getContext();\n  NamedAttrList attributes = conv_op->getAttrs();\n\n  attributes.set(OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),\n                 rewriter.getDenseI32ArrayAttr({1, 1, 1, 0}));\n\n  SmallVector<Value, 4> operands;\n  operands.push_back(conv_op.getIn());\n\n  // deal with weight\n  auto add_op_attrs = GetUserOpCommonAttrs(ctx, \"scalar_add\");\n  add_op_attrs.set(\"has_float_operand\", BoolAttr::get(ctx, true));\n\n  double epsilon_attr = epsilon.cast<FloatAttr>().getValueAsDouble();\n  add_op_attrs.set(\"float_operand\", rewriter.getF64FloatAttr(epsilon_attr));\n\n  auto add_op = rewriter.create<oneflow::ScalarAddOp>(\n      conv_op->getLoc(), conv_op.getOut().getType(),\n      SmallVector<Value, 4>({bn_op.getMovingVariance()}), add_op_attrs);\n\n  auto sqrt_op = rewriter.create<oneflow::SqrtOp>(conv_op->getLoc(), conv_op.getOut().getType(),\n                                                  SmallVector<Value, 4>({add_op.getOut()}),\n                                                  GetUserOpCommonAttrs(ctx, \"sqrt\"));\n\n  auto div_op = rewriter.create<oneflow::BroadcastDivOp>(\n      conv_op->getLoc(), conv_op.getOut().getType(),\n      SmallVector<Value, 4>({bn_op.getGamma(), sqrt_op.getY()}), GetUserOpCommonAttrs(ctx, \"div\"));\n\n  auto bn_gamma_variable_op =\n      llvm::dyn_cast<oneflow::FrozenVariableOp>(bn_op.getGamma().getDefiningOp());\n\n  CHECK(bn_gamma_variable_op) << \"Gamma of batchnorm should be a FrozenVariableOp.\";\n\n  auto bn_gamma_shape =\n      bn_gamma_variable_op.getValue().getType().cast<mlir::RankedTensorType>().getShape();\n\n  auto conv_weight_variable_op =\n      llvm::dyn_cast<oneflow::FrozenVariableOp>(conv_op.getWeight().getDefiningOp());\n\n  CHECK(conv_weight_variable_op) << \"Weight of conv2d should be a FrozenVariableOp.\";\n\n  auto conv_weight_shape =\n      conv_weight_variable_op.getValue().getType().cast<mlir::RankedTensorType>().getShape();\n\n  std::vector<int64_t> bn_gamma_new_shape({bn_gamma_shape.front()});\n  for (int i = 1; i < conv_weight_shape.size(); ++i) { bn_gamma_new_shape.emplace_back(1); }\n  auto reshape_op_attrs = GetUserOpCommonAttrs(ctx, \"reshape\");\n  reshape_op_attrs.set(\n      \"shape\",\n      ArrayAttr::get(ctx, llvm::to_vector<8>(llvm::map_range(\n                              ArrayRef<int64_t>(bn_gamma_new_shape), [&](int64_t v) -> Attribute {\n                                return getSI64IntegerAttr(rewriter, v);\n                              }))));\n  auto reshape_op =\n      rewriter.create<oneflow::ReshapeOp>(conv_op->getLoc(), conv_op.getOut().getType(),\n                                          SmallVector<Value, 4>({div_op.getZ()}), reshape_op_attrs);\n\n  auto mul_op = rewriter.create<oneflow::BroadcastMulOp>(\n      conv_op->getLoc(), conv_op.getOut().getType(),\n      SmallVector<Value, 4>({conv_op.getWeight(), reshape_op.getOut()}),\n      GetUserOpCommonAttrs(ctx, \"multiply\"));\n  operands.push_back(mul_op.getZ());\n\n  // deal with bias\n  CHECK(!conv_op.getBias())\n      << \"Fusing conv2d and batch_norm only supports conv2d without bias now.\";\n\n  auto mul_op_bias = rewriter.create<oneflow::BroadcastMulOp>(\n      conv_op->getLoc(), conv_op.getOut().getType(),\n   
   SmallVector<Value, 4>({bn_op.getMovingMean(), div_op.getZ()}),\n      GetUserOpCommonAttrs(ctx, \"multiply_bias\"));\n  auto sub_op_bias = rewriter.create<oneflow::BroadcastSubOp>(\n      conv_op->getLoc(), conv_op.getOut().getType(),\n      SmallVector<Value, 4>({bn_op.getBeta(), mul_op_bias.getZ()}),\n      GetUserOpCommonAttrs(ctx, \"sub_bias\"));\n  operands.push_back(sub_op_bias.getZ());\n\n  auto new_conv_op = rewriter.create<oneflow::Conv2DOp>(\n      conv_op->getLoc(), conv_op.getOut().getType(), operands, attributes);\n\n  return new_conv_op;\n}\n\nstatic LogicalResult IsPaddingCouldBeAssimilatedIntoConv(PatternRewriter& rewriter,\n                                                         Attribute padding_before,\n                                                         Attribute padding_after,\n                                                         Attribute data_format) {\n  if (padding_before.cast<ArrayAttr>().size() == 4 && padding_after.cast<ArrayAttr>().size() == 4) {\n    if (padding_before.cast<ArrayAttr>().getValue().equals(\n            padding_after.cast<ArrayAttr>().getValue())) {\n      if (data_format.cast<StringAttr>().str() == \"channels_first\") {\n        return success(padding_before.cast<ArrayAttr>()\n                               .getValue()[0]\n                               .cast<IntegerAttr>()\n                               .getValue()\n                               .getSExtValue()\n                           == 0\n                       && padding_before.cast<ArrayAttr>()\n                                  .getValue()[1]\n                                  .cast<IntegerAttr>()\n                                  .getValue()\n                                  .getSExtValue()\n                              == 0);\n      }\n      if (data_format.cast<StringAttr>().str() == \"channels_last\") {\n        return success(padding_before.cast<ArrayAttr>()\n                               .getValue()[0]\n                               .cast<IntegerAttr>()\n                               .getValue()\n                               .getSExtValue()\n                           == 0\n                       && padding_before.cast<ArrayAttr>()\n                                  .getValue()[3]\n                                  .cast<IntegerAttr>()\n                                  .getValue()\n                                  .getSExtValue()\n                              == 0);\n      }\n    }\n  }\n  return failure();\n}\nstatic LogicalResult IsNotNestedInJit(PatternRewriter& rewriter, Operation* mul) {\n  return success(mul->getParentOfType<oneflow::Job>());\n}\n\nstatic LogicalResult IsScalarTensor(PatternRewriter& rewriter, Value value) {\n  if (auto tensor = value.getType().dyn_cast<RankedTensorType>()) {\n    return success(tensor.getNumElements() == 1);\n  }\n  return failure();\n}\n\nstatic float mha_scale_max_diff = 1e-5;\n\nstatic LogicalResult IsScalarEqualSqrtDim(PatternRewriter& rewriter, Value query_reshape,\n                                          Attribute scalar_div_operand) {\n  auto query_reshape_shape = query_reshape.getType().dyn_cast<ShapedType>();\n  double scalar_div_operand_attr = scalar_div_operand.cast<FloatAttr>().getValueAsDouble();\n  return success(\n      std::abs(std::sqrt(query_reshape_shape.getShape().back()) - scalar_div_operand_attr)\n      < mha_scale_max_diff);\n}\n\nstatic LogicalResult IsScalarEqualSqrtDimReciprocal(PatternRewriter& rewriter, Value query_reshape,\n                                                    
Attribute scalar_div_operand) {\n  auto query_reshape_shape = query_reshape.getType().dyn_cast<ShapedType>();\n  double scalar_div_operand_attr = scalar_div_operand.cast<FloatAttr>().getValueAsDouble();\n  return success(\n      std::abs(std::sqrt(query_reshape_shape.getShape().back()) - (1 / scalar_div_operand_attr))\n      < mha_scale_max_diff);\n}\n\nstatic Attribute GetReciprocal(PatternRewriter& rewriter, Attribute a) {\n  return rewriter.getF64FloatAttr(1 / a.cast<FloatAttr>().getValueAsDouble());\n}\n\n}  // namespace\n\nnamespace rewrites {\n\nvoid populateRewrites(RewritePatternSet& patterns) {\n  patterns.getPDLPatterns().registerRewriteFunction(\"BuildFusedBiasAddMaskScaleOpWithRate\",\n                                                    BuildFusedBiasAddMaskScaleOpWithRate);\n  patterns.getPDLPatterns().registerRewriteFunction(\"CopyUserOpAttrs\", CopyUserOpAttrs);\n  patterns.getPDLPatterns().registerRewriteFunction(\"GetHeadSizeFromTranpose\",\n                                                    GetHeadSizeFromTranpose);\n  patterns.getPDLPatterns().registerRewriteFunction(\"CreateConv2dAndErasePad\",\n                                                    CreateConv2dAndErasePad);\n  patterns.getPDLPatterns().registerRewriteFunction(\"CreateConv2DBatchNorm\", CreateConv2DBatchNorm);\n  patterns.getPDLPatterns().registerRewriteFunction(\"GetReciprocal\", GetReciprocal);\n}\n\nmlir::IntegerAttr GetDefaultSeed(::mlir::PatternRewriter& rewriter) {\n  const auto gen = CHECK_JUST(::oneflow::one::DefaultAutoGenerator());\n  return getSI64IntegerAttr(rewriter, (int64_t)gen->current_seed());\n}\n\n}  // namespace rewrites\n\nnamespace constraints {\n\nvoid populateConstraints(RewritePatternSet& patterns) {\n  auto& pdll_patterns = patterns.getPDLPatterns();\n\n#define PDLL_REGISTER(NAME) pdll_patterns.registerConstraintFunction(#NAME, NAME);\n\n  PDLL_REGISTER(IsPaddingCouldBeAssimilatedIntoConv);\n  PDLL_REGISTER(IsNotNestedInJit);\n  PDLL_REGISTER(IsScalarTensor);\n  PDLL_REGISTER(IsScalarEqualSqrtDim);\n  PDLL_REGISTER(IsScalarEqualSqrtDimReciprocal);\n\n#undef PDLL_REGISTER\n}\n\n}  // namespace constraints\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OneFlowSupport.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowTypes.h\"\n#include \"mlir/IR/Attributes.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/MLIRContext.h\"\n#include \"oneflow/ir/include/OneFlow/OneFlowSupport.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_util.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/memory/memory_case_util.h\"\n#include \"oneflow/core/common/data_type.h\"\n\n#include <iostream>\n#include <vector>\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace support {\n\nstd::vector<std::string> GetInputKeys(const std::string& op_type_name) {\n  std::vector<std::string> ret{};\n  for (auto& arg : getUserOpDef(op_type_name).input()) { ret.push_back(arg.name()); }\n  return ret;\n}\n\nstd::vector<std::string> GetOutputKeys(const std::string& op_type_name) {\n  std::vector<std::string> ret{};\n  for (auto& arg : getUserOpDef(op_type_name).output()) { ret.push_back(arg.name()); }\n  return ret;\n}\n\nnamespace {\n\n::oneflow::Symbol<::oneflow::Device> MakeDevice(const mlir::Attribute& device_tag_attr,\n                                                const mlir::Attribute& device_name_attr) {\n  const auto device_tag = device_tag_attr.cast<mlir::StringAttr>().str();\n  const auto device_name =\n      device_name_attr.cast<mlir::ArrayAttr>().getValue().front().cast<mlir::StringAttr>().str();\n  const std::string device_info =\n      device_tag == \"gpu\" ? 
\"cuda\" : device_tag + device_name.substr(device_name.rfind(\":\"));\n  return ::oneflow::Device::ParseAndNew(device_info).GetOrThrow();\n}\n\ntemplate<typename T, typename MLIR_T>\nmlir::DenseElementsAttr __TensorToDenseElementsAttr(\n    const std::shared_ptr<::oneflow::one::Tensor>& tensor, const MLIR_T& mlir_type) {\n  ::oneflow::LazyMode::Guard guard{false};\n  const auto tensor_ = ::oneflow::one::functional::ToContiguous(tensor).GetPtrOrThrow();\n  auto shape = tensor_->shape();\n  std::vector<int64_t> shape_vec(shape->dim_vec().begin(), shape->dim_vec().end());\n  std::vector<T> data(shape->elem_cnt());\n  const auto& callback =\n      [&](::oneflow::ep::Stream* stream,\n          const std::shared_ptr<::oneflow::vm::EagerBlobObject>& eager_blob_object) {\n        ::oneflow::AutoMemcpy(stream, data.data(), eager_blob_object->dptr(),\n                              data.size() * sizeof(T), ::oneflow::memory::MakeHostMemCase(),\n                              eager_blob_object->mem_case());\n      };\n  ::oneflow::one::SyncAccessTensorWithTimeOut(tensor_, callback, \"const\").GetOrThrow();\n  return mlir::DenseElementsAttr::get(mlir::RankedTensorType::get(shape_vec, mlir_type),\n                                      llvm::makeArrayRef(data));\n}\n\ntemplate<typename T>\nstd::shared_ptr<::oneflow::one::Tensor> __DenseElementsAttrToTensor(\n    const mlir::DenseElementsAttr dense_attr, const mlir::Attribute& device_tag_attr,\n    const mlir::Attribute& device_name_attr, const ::oneflow::DataType& dtype) {\n  const auto dense_type = dense_attr.getType().cast<mlir::RankedTensorType>();\n  std::vector<int64_t> shape = dense_type.getShape().vec();\n\n  const auto device = MakeDevice(device_tag_attr, device_name_attr);\n\n  std::shared_ptr<::oneflow::one::Tensor> tensor =\n      ::oneflow::one::functional::Empty(\n          ::oneflow::Shape(::oneflow::DimVector(shape.begin(), shape.end())),\n          ::oneflow::DType::Get(dtype).GetOrThrow(), device, /*requires_grad=*/false,\n          /*pin_memory=*/false)\n          .GetPtrOrThrow();\n\n  std::vector<T> data(dense_attr.getValues<T>().begin(), dense_attr.getValues<T>().end());\n  const auto& callback =\n      [&](::oneflow::ep::Stream* stream,\n          const std::shared_ptr<::oneflow::vm::EagerBlobObject>& eager_blob_object) {\n        ::oneflow::AutoMemcpy(stream, eager_blob_object->mut_dptr(), data.data(),\n                              tensor->shape()->elem_cnt() * sizeof(T),\n                              eager_blob_object->mem_case(), ::oneflow::memory::MakeHostMemCase());\n      };\n  ::oneflow::one::SyncAccessTensorWithTimeOut(tensor, callback, \"mut\").GetOrThrow();\n  return tensor;\n}\n\ntemplate<typename T>\nvoid __DenseElementsAttrToTensor(const mlir::DenseElementsAttr dense_attr,\n                                 const mlir::Attribute& device_tag_attr,\n                                 const mlir::Attribute& device_name_attr,\n                                 const ::oneflow::DataType& dtype,\n                                 std::shared_ptr<::oneflow::one::Tensor>& tensor) {\n  const auto dense_type = dense_attr.getType().cast<mlir::RankedTensorType>();\n  std::vector<int64_t> shape = dense_type.getShape().vec();\n  int ndim = shape.size();\n  CHECK_EQ(tensor->shape()->size(), ndim);\n  for (int i = 0; i < ndim; ++i) { CHECK_EQ(tensor->shape()->at(i), shape[i]); }\n\n  const auto device = MakeDevice(device_tag_attr, device_name_attr);\n  CHECK(CHECK_JUST(tensor->device()) == device);\n\n  std::vector<T> data;\n  
std::vector<::oneflow::float16> fp16_data;\n  void* dptr = nullptr;\n  const size_t tensor_size =\n      tensor->shape()->elem_cnt() * ::oneflow::GetSizeOfDataType(tensor->dtype()->data_type());\n\n  CHECK_EQ(::oneflow::GetDataType<T>::value, dtype);\n  if (tensor->dtype()->data_type() == ::oneflow::DataType::kFloat16) {\n    for (const T elem : dense_attr.getValues<T>()) {\n      fp16_data.push_back(static_cast<::oneflow::float16>(elem));\n    }\n    CHECK_EQ(fp16_data.size() * sizeof(::oneflow::float16), tensor_size);\n    dptr = fp16_data.data();\n  } else if (tensor->dtype()->data_type() == dtype) {\n    for (const T elem : dense_attr.getValues<T>()) { data.push_back(elem); }\n    CHECK_EQ(data.size() * sizeof(T), tensor_size);\n    dptr = data.data();\n  } else {\n    UNIMPLEMENTED();\n  }\n\n  const auto& callback =\n      [=](::oneflow::ep::Stream* stream,\n          const std::shared_ptr<::oneflow::vm::EagerBlobObject>& eager_blob_object) {\n        ::oneflow::AutoMemcpy(stream, eager_blob_object->mut_dptr(), dptr, tensor_size,\n                              eager_blob_object->mem_case(), ::oneflow::memory::MakeHostMemCase());\n      };\n  ::oneflow::one::SyncAccessTensorWithTimeOut(tensor, callback, \"mut\").GetOrThrow();\n}\n\n}  // namespace\n\nmlir::DenseElementsAttr TensorToDenseElementsAttr(\n    const std::shared_ptr<::oneflow::one::Tensor>& tensor, MLIRContext* ctx) {\n  const auto dtype = tensor->dtype()->data_type();\n  if (dtype == ::oneflow::DataType::kFloat) {\n    return __TensorToDenseElementsAttr<float, mlir::FloatType>(tensor,\n                                                               mlir::FloatType::getF32(ctx));\n  } else if (dtype == ::oneflow::DataType::kInt64) {\n    auto mlir_type =\n        mlir::IntegerType::get(ctx, 64, mlir::IntegerType::SignednessSemantics::Signed);\n    return __TensorToDenseElementsAttr<int64_t, mlir::IntegerType>(tensor, mlir_type);\n  }\n  llvm::errs() << \"Converting oneflow::Tensor to mlir::DenseElementsAttr only supports float32 \"\n                  \"and int64 now.\"\n               << \"\\n\";\n  exit(EXIT_FAILURE);\n}\n\nstd::shared_ptr<::oneflow::one::Tensor> DenseElementsAttrToTensor(\n    const mlir::Attribute& dense_attr, const mlir::Attribute& device_tag_attr,\n    const mlir::Attribute& device_name_attr) {\n  ::oneflow::LazyMode::Guard guard{false};\n  const auto dense_attr_ = dense_attr.cast<mlir::DenseElementsAttr>();\n  const auto dense_element_type = dense_attr_.getElementType();\n  if (dense_element_type.isF32()) {\n    return __DenseElementsAttrToTensor<float>(dense_attr_, device_tag_attr, device_name_attr,\n                                              ::oneflow::DataType::kFloat);\n  }\n  llvm::errs()\n      << \"Converting mlir::DenseElementsAttr to oneflow::Tensor only supports float32 now.\"\n      << \"\\n\";\n  exit(EXIT_FAILURE);\n}\n\nvoid DenseElementsAttrToTensor(const mlir::Attribute& dense_attr,\n                               const mlir::Attribute& device_tag_attr,\n                               const mlir::Attribute& device_name_attr,\n                               std::shared_ptr<::oneflow::one::Tensor>& tensor) {\n  ::oneflow::LazyMode::Guard guard{false};\n  const auto dense_attr_ = dense_attr.cast<mlir::DenseElementsAttr>();\n  const auto dense_element_type = dense_attr_.getElementType();\n  if (dense_element_type.isF32()) {\n    __DenseElementsAttrToTensor<float>(dense_attr_, device_tag_attr, device_name_attr,\n                                       ::oneflow::DataType::kFloat, tensor);\n  } else 
{\n    llvm::errs() << \"Converting mlir::DenseElementsAttr to oneflow::Tensor only supports \"\n                    \"float32 now.\"\n                 << \"\\n\";\n    exit(EXIT_FAILURE);\n  }\n}\n\nFailureOr<::oneflow::DataType> FromMLIRTypeToOFDataType(Type mlir_type) {\n  if (mlir_type.dyn_cast<InvalidElementType>()) { return ::oneflow::DataType::kInvalidDataType; }\n  if (mlir_type.dyn_cast<CharElementType>()) { return ::oneflow::DataType::kChar; }\n  if (mlir_type.dyn_cast<OFRecordElementType>()) { return ::oneflow::DataType::kOFRecord; }\n  if (mlir_type.dyn_cast<TensorBufferElementType>()) { return ::oneflow::DataType::kTensorBuffer; }\n  if (mlir_type.isF16()) { return ::oneflow::DataType::kFloat16; }\n  if (mlir_type.isF32()) { return ::oneflow::DataType::kFloat; }\n  if (mlir_type.isF64()) { return ::oneflow::DataType::kDouble; }\n\n  if (mlir_type.isSignlessInteger(8)) { return ::oneflow::DataType::kBool; }\n  if (mlir_type.isSignlessInteger(16)) { return ::oneflow::DataType::kUInt16; }\n  if (mlir_type.isSignlessInteger(32)) { return ::oneflow::DataType::kUInt32; }\n  if (mlir_type.isSignlessInteger(64)) { return ::oneflow::DataType::kUInt64; }\n  if (mlir_type.isSignlessInteger(128)) { return ::oneflow::DataType::kUInt128; }\n\n  if (mlir_type.isSignedInteger(8)) { return ::oneflow::DataType::kInt8; }\n  if (mlir_type.isSignedInteger(16)) { return ::oneflow::DataType::kInt16; }\n  if (mlir_type.isSignedInteger(32)) { return ::oneflow::DataType::kInt32; }\n  if (mlir_type.isSignedInteger(64)) { return ::oneflow::DataType::kInt64; }\n  if (mlir_type.isSignedInteger(128)) { return ::oneflow::DataType::kInt128; }\n  llvm::errs() << \"Unsupported data type: \" << mlir_type << \"\\n\";\n  return failure();\n}\n\nFailureOr<::oneflow::DataType> FromMLIRDataTypeToOFDataType(::mlir::oneflow::DataType data_type) {\n  switch (data_type) {\n    case ::mlir::oneflow::DataType::DT_InvalidDataType:\n      return ::oneflow::DataType::kInvalidDataType;\n#define DEFINE_ONE_CASE(datatype) \\\n  case ::mlir::oneflow::DataType::DT_##datatype: return ::oneflow::DataType::k##datatype;\n      DEFINE_ONE_CASE(Char)\n      DEFINE_ONE_CASE(Float)\n      DEFINE_ONE_CASE(Double)\n      DEFINE_ONE_CASE(Int8)\n      DEFINE_ONE_CASE(Int32)\n      DEFINE_ONE_CASE(Int64)\n      DEFINE_ONE_CASE(UInt8)\n      DEFINE_ONE_CASE(OFRecord)\n      DEFINE_ONE_CASE(Float16)\n      DEFINE_ONE_CASE(TensorBuffer)\n      DEFINE_ONE_CASE(Bool)\n#undef DEFINE_ONE_CASE\n    default: {\n      return failure();\n    }\n  }\n  return failure();\n}\n\nFailureOr<::oneflow::DataType> FromMLIRAttrToOFDataType(Attribute attr) {\n  const auto data_type_attr = attr.dyn_cast<mlir::oneflow::DataTypeAttr>();\n  // Guard against attributes that are not DataTypeAttr instead of crashing on a null dyn_cast.\n  if (!data_type_attr) { return failure(); }\n  return FromMLIRDataTypeToOFDataType(data_type_attr.getValue());\n}\n\nconst ::oneflow::UserOpDef& getUserOpDef(const std::string& op_type_name) {\n  const ::oneflow::user_op::OpRegistryResult* val =\n      ::oneflow::user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type_name);\n  CHECK(val) << \"Cannot find op_type_name: \" << op_type_name;\n  return val->op_def;\n}\n\n}  // namespace support\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OneFlowTypes.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowTypes.h\"\n#include \"mlir/IR/DialectImplementation.h\"\n#include \"llvm/ADT/TypeSwitch.h\"\n#define GET_TYPEDEF_CLASSES\n#include \"OneFlow/OneFlowOpsTypes.cpp.inc\"\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/OneFlowUtils.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowUtils.h\"\n#include \"oneflow/core/common/util.h\"\nnamespace mlir {\nnamespace oneflow {\n\nvoid CheckEnableIRPrinting(mlir::PassManager& pm) {\n  bool enable_ir_printing =\n      ::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_ENABLE_IR_PRINTING\", false);\n  pm.getContext()->disableMultithreading(enable_ir_printing);\n  if (enable_ir_printing) { pm.enableIRPrinting(); }\n}\n\nStringRef SanitizeIdentifier(StringRef name, SmallString<16>& buffer, StringRef allowedPunctChars,\n                             bool allowTrailingDigit) {\n  assert(!name.empty() && \"Shouldn't have an empty name here\");\n\n  auto copyNameToBuffer = [&] {\n    for (char ch : name) {\n      if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch))\n        buffer.push_back(ch);\n      else if (ch == ' ')\n        buffer.push_back('_');\n      else\n        buffer.append(llvm::utohexstr((unsigned char)ch));\n    }\n  };\n\n  // Check to see if this name is valid. If it starts with a digit, then it\n  // could conflict with the autogenerated numeric ID's, so add an underscore\n  // prefix to avoid problems.\n  if (isdigit(name[0])) {\n    buffer.push_back('_');\n    copyNameToBuffer();\n    return buffer;\n  }\n\n  // If the name ends with a trailing digit, add a '_' to avoid potential\n  // conflicts with autogenerated ID's.\n  if (!allowTrailingDigit && isdigit(name.back())) {\n    copyNameToBuffer();\n    buffer.push_back('_');\n    return buffer;\n  }\n\n  // Check to see that the name consists of only valid identifier characters.\n  for (char ch : name) {\n    if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) {\n      copyNameToBuffer();\n      return buffer;\n    }\n  }\n\n  // If there are no invalid characters, return the original name.\n  return name;\n}\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/PDLL/AllocEliminationPatterns.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/Dialect/PDL/IR/PDL.h\"\n#include \"mlir/Dialect/PDLInterp/IR/PDLInterp.h\"\n#include \"mlir/Parser/Parser.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n#include \"OneFlow/OneFlowPDLLPatterns.h\"\n#include \"mlir/IR/Value.h\"\n\nusing namespace mlir;\n\n#include \"oneflow/ir/lib/OneFlow/PDLL/AllocEliminationPatterns.h.inc\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nvoid populateAllocEliminationPatterns(RewritePatternSet& patterns) {\n  populateGeneratedPDLLPatterns(patterns);\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/PDLL/AllocEliminationPatterns.pdll",
    "content": "#include \"OneFlow/OneFlowOps.td\"\n\nConstraint IsFuncArguments(value: Value) [{\n  return success(llvm::dyn_cast<mlir::BlockArgument>(value));\n}];\n\nPattern {\n  arg: Value;\n  let alloc = op<memref.alloc>();\n  let copy = op<memref.copy>(alloc.0, arg);\n  IsFuncArguments(arg);\n\n  rewrite alloc with {\n    erase copy;\n    replace alloc with arg;\n  };\n}\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/PDLL/CMakeLists.txt",
    "content": "add_mlir_pdll_library(MLIROneFlowPDLLAllocEliminaionPatternsIncGen AllocEliminationPatterns.pdll\n                      AllocEliminationPatterns.h.inc)\nadd_mlir_pdll_library(MLIROneFlowPDLLForwardOpPatternsIncGen ForwardOpPatterns.pdll\n                      ForwardOpPatterns.h.inc)\nadd_mlir_pdll_library(MLIROneFlowPDLLNormalizationPatternsIncGen NormalizationPatterns.pdll\n                      NormalizationPatterns.h.inc)\nadd_mlir_pdll_library(MLIROneFlowPDLLFuseConv2DBatchNormPatternIncGen\n                      FuseConv2DBatchNormPattern.pdll FuseConv2DBatchNormPattern.h.inc)\nadd_mlir_pdll_library(MLIROneFlowPDLLFuseOpsWithBackwardImplPatternsIncGen\n                      FuseOpsWithBackwardImplPattern.pdll FuseOpsWithBackwardImplPattern.h.inc)\noneflow_add_mlir_dialect_library(\n  MLIROneFlowPDLLPatterns\n  AllocEliminationPatterns.cpp\n  ForwardOpPatterns.cpp\n  NormalizationPatterns.cpp\n  FuseConv2DBatchNormPattern.cpp\n  FuseOpsWithBackwardImplPattern.cpp\n  DEPENDS\n  MLIROneFlowPDLLAllocEliminaionPatternsIncGen\n  MLIROneFlowPDLLForwardOpPatternsIncGen\n  MLIROneFlowPDLLNormalizationPatternsIncGen\n  MLIROneFlowPDLLFuseConv2DBatchNormPatternIncGen\n  MLIROneFlowPDLLFuseOpsWithBackwardImplPatternsIncGen)\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/PDLL/ForwardOpPatterns.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"mlir/Dialect/PDL/IR/PDL.h\"\n#include \"mlir/Dialect/PDLInterp/IR/PDLInterp.h\"\n#include \"mlir/Parser/Parser.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n#include \"OneFlow/OneFlowPDLLPatterns.h\"\n\nusing namespace mlir;\n\n#include \"oneflow/ir/lib/OneFlow/PDLL/ForwardOpPatterns.h.inc\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nvoid populateForwardOpPatterns(RewritePatternSet& patterns) {\n  populateGeneratedPDLLPatterns(patterns);\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/PDLL/ForwardOpPatterns.pdll",
    "content": "#include \"OneFlow/OneFlowOps.td\"\n#include \"OneFlowPDLLUtils.pdll\"\n\nPattern {\n  let rate: Attr;\n  let device_name: Attr;\n  let device_tag: Attr;\n  let axis: Attr;\n  let dropout =\n    op<oneflow.dropout>(\n      op<oneflow.bias_add>(a: Value, b: Value) {axis = axis, device_name = device_name, device_tag = device_tag})\n      {rate = rate, device_name = device_name, device_tag = device_tag} -> (out: Type, mask: Type);\n\n  rewrite dropout with {\n    let random_mask_like = CopyUserOpAttrs(dropout, op<oneflow.random_mask_like>(a){rate = rate} -> (mask));\n    let fused_bias_add_mask_scale = CopyUserOpAttrs(dropout, BuildFusedBiasAddMaskScaleOpWithRate(a, b, random_mask_like.0, axis, rate, dropout));\n    replace dropout with (fused_bias_add_mask_scale.0, random_mask_like.0);\n  };\n}\n\nPattern {\n  let device_name: Attr;\n  let device_tag: Attr;\n  let axis: Attr;\n  let gelu =\n    op<oneflow.gelu>(\n      op<oneflow.bias_add>(a: Value, b: Value) {axis = axis, device_name = device_name, device_tag = device_tag})\n      {device_name = device_name, device_tag = device_tag} -> (out: Type);\n\n  rewrite gelu with{\n    replace gelu with CopyUserOpAttrs(gelu, op<oneflow.fused_bias_add_gelu>(a, b){axis = axis} -> (out));\n  };\n}\n\nPattern {\n  let device_name: Attr;\n  let device_tag = attr<\"\\\"cuda\\\"\">;\n  let scalar_div_operand: Attr;\n  let out_shape: Attr;\n\n  let query: Value;\n  let key: Value;\n  let value: Value;\n\n  let query_reshape = op<oneflow.reshape>(query) {device_name = device_name, device_tag = device_tag};\n  let key_reshape = op<oneflow.reshape>(key) {device_name = device_name, device_tag = device_tag};\n  let value_reshape = op<oneflow.reshape>(value) {device_name = device_name, device_tag = device_tag};\n  let query_transpose = op<oneflow.transpose>(query_reshape.0) {device_name = device_name, device_tag = device_tag, perm = attr<\"[0 : si32, 2 : si32, 1 : si32, 3 : si32]\">};\n  let key_transpose = op<oneflow.transpose>(key_reshape.0) {device_name = device_name, device_tag = device_tag, perm = attr<\"[0 : si32, 2 : si32, 3 : si32, 1 : si32]\">};\n  let value_transpose = op<oneflow.transpose>(value_reshape.0) {device_name = device_name, device_tag = device_tag, perm = attr<\"[0 : si32, 2 : si32, 1 : si32, 3 : si32]\">};\n  let scores = op<oneflow.batch_matmul>(query_transpose.0, key_transpose.0) {alpha = attr<\"1.000000e+00 : f64\">, device_name = device_name, device_tag = device_tag, transpose_a = attr<\"false\">, transpose_b = attr<\"false\">};\n  let scores_scaled = op<oneflow.scalar_div>(scores.0) {device_name = device_name, device_tag = device_tag, float_operand = scalar_div_operand, has_float_operand = attr<\"true\">};\n  let attn = op<oneflow.softmax>(scores_scaled.0) {device_name = device_name, device_tag = device_tag};\n  let out = op<oneflow.batch_matmul>(attn.0, value_transpose.0) {alpha = attr<\"1.000000e+00 : f64\">, device_name = device_name, device_tag = device_tag, transpose_a = attr<\"false\">, transpose_b = attr<\"false\">};\n  let out_transpose = op<oneflow.transpose>(out.0) {device_name = device_name, device_tag = device_tag, perm = attr<\"[0 : si32, 2 : si32, 1 : si32, 3 : si32]\">};\n  let out_reshape = op<oneflow.reshape>(out_transpose.0) {device_name = device_name, device_tag = device_tag, shape = out_shape} -> (out_t: Type);\n\n  IsScalarEqualSqrtDim(query_reshape.0, scalar_div_operand);\n\n  rewrite out_reshape with{\n    replace out_reshape with CopyUserOpAttrs(out, 
op<oneflow.fused_multi_head_attention_inference>(query, key, value) {\n      attn_mask_type = attr<\"\\\"none\\\"\">,\n      query_max_seq_len = attr<\"0 : si64\">,\n      key_max_seq_len = attr<\"0 : si64\">,\n      causal_diagonal_offset = attr<\"0 : si64\">,\n      query_head_size = GetHeadSizeFromTranpose(query_transpose),\n      query_layout = attr<\"\\\"BM(HK)\\\"\">,\n      key_layout = attr<\"\\\"BM(HK)\\\"\">,\n      value_layout = attr<\"\\\"BM(HK)\\\"\">,\n      output_layout = attr<\"\\\"BM(HK)\\\"\">,\n      operand_segment_sizes = attr<\"array<i32: 1, 1, 1, 0, 0, 0, 0>\">,\n      scale = GetReciprocal(scalar_div_operand)\n    } -> (out_t));\n  };\n}\n\nPattern {\n  let device_name: Attr;\n  let device_tag = attr<\"\\\"cuda\\\"\">;\n  let batch_matmul_alpha: Attr;\n  let out_shape: Attr;\n\n  let query: Value;\n  let key: Value;\n  let value: Value;\n\n  let value_reshape = op<oneflow.reshape>(value) {device_name = device_name, device_tag = device_tag};\n  let key_reshape = op<oneflow.reshape>(key) {device_name = device_name, device_tag = device_tag};\n  let query_reshape = op<oneflow.reshape>(query) {device_name = device_name, device_tag = device_tag};\n  let value_permute = op<oneflow.transpose>(value_reshape.0) {device_name = device_name, device_tag = device_tag, perm = attr<\"[0 : si32, 2 : si32, 1 : si32, 3 : si32]\">};\n  let key_permute = op<oneflow.transpose>(key_reshape.0) {device_name = device_name, device_tag = device_tag, perm = attr<\"[0 : si32, 2 : si32, 1 : si32, 3 : si32]\">};\n  let query_permute = op<oneflow.transpose>(query_reshape.0) {device_name = device_name, device_tag = device_tag, perm = attr<\"[0 : si32, 2 : si32, 1 : si32, 3 : si32]\">};\n  let value_reshape_to_batch = op<oneflow.reshape>(value_permute.0) {device_name = device_name, device_tag = device_tag};\n  let key_reshape_to_batch = op<oneflow.reshape>(key_permute.0) {device_name = device_name, device_tag = device_tag};\n  let query_reshape_to_batch = op<oneflow.reshape>(query_permute.0) {device_name = device_name, device_tag = device_tag};\n  let key_transpose = op<oneflow.transpose>(key_reshape_to_batch.0) {device_name = device_name, device_tag = device_tag, perm = attr<\"[0 : si32, 2 : si32, 1 : si32]\">};\n  let scores_scaled = op<oneflow.batch_matmul>(query_reshape_to_batch.0, key_transpose.0) {alpha = batch_matmul_alpha, device_name = device_name, device_tag = device_tag, transpose_a = attr<\"false\">, transpose_b = attr<\"false\">};\n  let attn = op<oneflow.softmax>(scores_scaled.0) {device_name = device_name, device_tag = device_tag};\n  let out = op<oneflow.batch_matmul>(attn.0, value_reshape_to_batch.0) {alpha = attr<\"1.000000e+00  : f64\">, device_name = device_name, device_tag = device_tag, transpose_a = attr<\"false\">, transpose_b = attr<\"false\">};\n  let out_reshape_before = op<oneflow.reshape>(out.0) {device_name = device_name, device_tag = device_tag};\n  let out_transpose = op<oneflow.transpose>(out_reshape_before.0) {device_name = device_name, device_tag = device_tag, perm = attr<\"[0 : si32, 2 : si32, 1 : si32, 3 : si32]\">};\n  let out_reshape = op<oneflow.reshape>(out_transpose.0) {device_name = device_name, device_tag = device_tag, shape = out_shape} -> (out_t: Type);\n\n  IsScalarEqualSqrtDimReciprocal(query_reshape.0, batch_matmul_alpha);\n\n  rewrite out_reshape with{\n    replace out_reshape with CopyUserOpAttrs(out, op<oneflow.fused_multi_head_attention_inference>(query, key, value) {\n      attn_mask_type = attr<\"\\\"none\\\"\">,\n      query_max_seq_len = 
attr<\"0 : si64\">,\n      key_max_seq_len = attr<\"0 : si64\">,\n      causal_diagonal_offset = attr<\"0 : si64\">,\n      query_head_size = GetHeadSizeFromTranpose(query_permute),\n      query_layout = attr<\"\\\"BM(HK)\\\"\">,\n      key_layout = attr<\"\\\"BM(HK)\\\"\">,\n      value_layout = attr<\"\\\"BM(HK)\\\"\">,\n      output_layout = attr<\"\\\"BM(HK)\\\"\">,\n      operand_segment_sizes = attr<\"array<i32: 1, 1, 1, 0, 0, 0, 0>\">,\n      scale = batch_matmul_alpha\n    } -> (out_t));\n  };\n}\n\nPattern {\n  let device_name: Attr;\n  let device_tag: Attr;\n  let padding_before: Attr;\n  let padding_after: Attr;\n  let data_format: Attr;\n\n  let conv =\n    op<oneflow.conv2d>(\n      op<oneflow.pad>(x: Value){device_name = device_name, device_tag = device_tag, padding_before = padding_before, padding_after = padding_after}, weight: Value)\n      {device_name = device_name, device_tag = device_tag, data_format = data_format};\n  IsPaddingCouldBeAssimilatedIntoConv(padding_before, padding_after, data_format);\n\n  rewrite conv with{\n    let conv2d_and_erase_pad = CreateConv2dAndErasePad(x, weight, padding_before, data_format, conv);\n    replace conv with CopyUserOpAttrs(conv, conv2d_and_erase_pad);\n  };\n}\n\nPattern {\n  let valueType: Type;\n  let x: Value<valueType>;\n  let cast = op<oneflow.cast>(x) -> (valueType);\n\n  replace cast with x;\n}\n\nPattern {\n  let device_name: Attr;\n  let has_float_operand: Attr;\n  let int_operand: Attr;\n  let float_operand: Attr;\n  let diagonal: Attr;\n  let floating_fill_value: Attr;\n  let integer_fill_value: Attr;\n  let is_floating_fill_value: Attr;\n\n  let tril =\n    op<oneflow.tril>(\n      op<oneflow.scalar_mul>(x: Value)\n      {device_name = device_name, device_tag = attr<\"\\\"cuda\\\"\">, has_float_operand = has_float_operand,\n        int_operand = int_operand, float_operand = float_operand})\n      {device_name = device_name, device_tag = attr<\"\\\"cuda\\\"\">, diagonal = diagonal, floating_fill_value = floating_fill_value,\n        integer_fill_value =integer_fill_value, is_floating_fill_value = is_floating_fill_value} -> (out: Type);\n\n  replace tril with CopyUserOpAttrs(tril, CreatScaleTrilOp(x, diagonal, floating_fill_value, integer_fill_value,\n                                      is_floating_fill_value, float_operand ,int_operand, has_float_operand, out));\n}\n\nPattern {\n  let device_name: Attr;\n  let has_float_operand: Attr;\n  let int_operand: Attr;\n  let float_operand: Attr;\n  let diagonal: Attr;\n  let floating_fill_value: Attr;\n  let integer_fill_value: Attr;\n  let is_floating_fill_value: Attr;\n\n  let scalar =\n    op<oneflow.scalar_mul>(\n      op<oneflow.tril>(x: Value)\n      {device_name = device_name, device_tag = attr<\"\\\"cuda\\\"\">, diagonal = diagonal, floating_fill_value = floating_fill_value,\n        integer_fill_value =integer_fill_value, is_floating_fill_value = is_floating_fill_value })\n      {device_name = device_name, device_tag = attr<\"\\\"cuda\\\"\">, has_float_operand = has_float_operand,\n        int_operand = int_operand, float_operand = float_operand} -> (out: Type);\n\n  replace scalar with CopyUserOpAttrs(scalar, CreatScaleTrilOp(x, diagonal, floating_fill_value, integer_fill_value,\n                                        is_floating_fill_value, float_operand ,int_operand, has_float_operand, out));\n}\n\nPattern {\n  let device_name: Attr;\n  let device_tag: Attr;\n\n  let broadcast_mul = op<oneflow.broadcast_mul>(x: Value, y: Value){device_name = device_name, 
device_tag = device_tag}-> (out: Type);\n\n  IsScalarTensor(y);\n\n  rewrite broadcast_mul with{\n    let scalar_mul = op<oneflow.scalar_mul_by_tensor>(x, y) {device_name = device_name, device_tag = device_tag} -> (out);\n    replace broadcast_mul with CopyUserOpAttrs(broadcast_mul, scalar_mul);\n  };\n}\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/PDLL/FuseConv2DBatchNormPattern.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"mlir/Dialect/PDL/IR/PDL.h\"\n#include \"mlir/Dialect/PDLInterp/IR/PDLInterp.h\"\n#include \"mlir/Parser/Parser.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n#include \"OneFlow/OneFlowPDLLPatterns.h\"\n\nusing namespace mlir;\n\n#include \"oneflow/ir/lib/OneFlow/PDLL/FuseConv2DBatchNormPattern.h.inc\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nvoid populateFuseConv2DBatchNormPattern(RewritePatternSet& patterns) {\n  populateGeneratedPDLLPatterns(patterns);\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/PDLL/FuseConv2DBatchNormPattern.pdll",
    "content": "#include \"OneFlowPDLLUtils.pdll\"\n\nPattern {\n  let device_name: Attr;\n  let device_tag: Attr;\n  let epsilon: Attr;\n  let moving_mean: Value;\n  let moving_variance: Value;\n  let beta: Value;\n  \n  let weight = op<oneflow.variable_ir>;\n  let gamma = op<oneflow.variable_ir>;\n\n  let conv = op<oneflow.conv2d>(x: Value, weight.0){device_name = device_name, device_tag = device_tag};\n  \n  let normalization = op<oneflow.normalization_infer>(conv, moving_mean, moving_variance, gamma.0, beta) {device_name = device_name, device_tag = device_tag, epsilon = epsilon} -> (y: Type);\n\n  rewrite normalization with{\n    let conv2d_bn = CreateConv2DBatchNorm(epsilon, conv, normalization);\n    replace normalization with CopyUserOpAttrs(normalization, conv2d_bn);\n  };\n  \n}\n\n\n\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/PDLL/FuseOpsWithBackwardImplPattern.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"mlir/Dialect/PDL/IR/PDL.h\"\n#include \"mlir/Dialect/PDLInterp/IR/PDLInterp.h\"\n#include \"mlir/Parser/Parser.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n#include \"OneFlow/OneFlowPDLLPatterns.h\"\n\nusing namespace mlir;\n\n#include \"oneflow/ir/lib/OneFlow/PDLL/FuseOpsWithBackwardImplPattern.h.inc\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nvoid populateFuseOpsWithBackwardImplPattern(RewritePatternSet& patterns) {\n  populateGeneratedPDLLPatterns(patterns);\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/PDLL/FuseOpsWithBackwardImplPattern.pdll",
    "content": "#include \"OneFlowPDLLUtils.pdll\"\n\nPattern {\n  let device_name: Attr;\n  let device_tag: Attr;\n\n  let matmul_wx = op<oneflow.broadcast_matmul>(x: Value, w: Value){device_name = device_name, device_tag = device_tag, alpha = attr<\"1.000000e+00 : f64\">};\n  let matmul_wx_add = op<oneflow.broadcast_add>(matmul_wx.0, b: Value){device_name = device_name, device_tag = device_tag} -> (matmul_wx_out: Type);\n\n  let hidden_states = op<oneflow.narrow>(matmul_wx_add.0){device_name = device_name, device_tag = device_tag};\n  let gate = op<oneflow.narrow>(matmul_wx_add.0){device_name = device_name, device_tag = device_tag};\n  let gate_activate = op<oneflow.gelu>(gate.0){device_name = device_name, device_tag = device_tag};\n  let gelu_out = op<oneflow.broadcast_mul>(hidden_states.0,gate_activate.0){device_name = device_name, device_tag = device_tag} -> (out: Type);\n\n  rewrite gelu_out with{\n    let fused_gelu_out = op<oneflow.fused_glu>(x, w, b){activation = attr<\"\\\"gelu\\\"\">, operand_segment_sizes = attr<\"array<i32: 1, 1, 1, 0, 0>\">, device_name = device_name, device_tag = device_tag, has_bias = attr<\"true\">, is_split = attr<\"false\">}-> (out, matmul_wx_out);\n    CopyUserOpAttrs(gelu_out, fused_gelu_out);\n    replace gelu_out with fused_gelu_out.0;\n    replace matmul_wx_add with fused_gelu_out.1;\n  };\n}\n\nPattern {\n  let device_name: Attr;\n  let device_tag: Attr;\n\n  let matmul_wx_add = op<oneflow.fused_matmul_bias>(x: Value, w: Value, b: Value){device_name = device_name, device_tag = device_tag, alpha = attr<\"1.000000e+00 : f64\">} -> (matmul_wx_out: Type);\n\n  let hidden_states = op<oneflow.narrow>(matmul_wx_add.0){device_name = device_name, device_tag = device_tag};\n  let gate = op<oneflow.narrow>(matmul_wx_add.0){device_name = device_name, device_tag = device_tag};\n  let gate_activate = op<oneflow.gelu>(gate.0){device_name = device_name, device_tag = device_tag};\n  let gelu_out = op<oneflow.broadcast_mul>(hidden_states.0,gate_activate.0){device_name = device_name, device_tag = device_tag}-> (out: Type);\n\n  rewrite gelu_out with{\n    let fused_gelu_out = op<oneflow.fused_glu>(x, w, b){activation = attr<\"\\\"gelu\\\"\">, operand_segment_sizes = attr<\"array<i32: 1, 1, 1, 0, 0>\">, device_name = device_name, device_tag = device_tag}-> (out, matmul_wx_out);\n    CopyUserOpAttrs(gelu_out, fused_gelu_out);\n    replace gelu_out with fused_gelu_out.0;\n    replace matmul_wx_add with fused_gelu_out.1;\n  };\n}\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/PDLL/NormalizationPatterns.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"mlir/Dialect/PDL/IR/PDL.h\"\n#include \"mlir/Dialect/PDLInterp/IR/PDLInterp.h\"\n#include \"mlir/Parser/Parser.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n#include \"OneFlow/OneFlowPDLLPatterns.h\"\n\nusing namespace mlir;\n\n#include \"oneflow/ir/lib/OneFlow/PDLL/NormalizationPatterns.h.inc\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nvoid populateNormalizationOpPatterns(RewritePatternSet& patterns) {\n  populateGeneratedPDLLPatterns(patterns);\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/PDLL/NormalizationPatterns.pdll",
    "content": "#include \"OneFlowPDLLUtils.pdll\"\n\nPattern {\n  let device_name: Attr;\n  let device_tag: Attr;\n  let axis: Attr;\n  let epsilon: Attr;\n  let training = attr<\"true\">;\n  let momentum: Attr;\n  let x: Value;\n  let moving_mean: Value;\n  let moving_variance: Value;\n  let gamma: Value;\n  let beta: Value;\n  let addend: Value;\n  let normalization = op<oneflow.normalization>(x, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<\"array<i32: 1, 1, 1, 1, 1, 0>\">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag} -> (y: Type, mean: Type, inv_variance: Type);\n  let relu =\n    op<oneflow.relu>(\n      op<oneflow.add_n2>(normalization.0, addend) {device_name = device_name, device_tag = device_tag})\n      {device_name = device_name, device_tag = device_tag} -> (out: Type);\n\n  rewrite relu with{\n    let fused_bn = CopyUserOpAttrs(normalization, op<oneflow.normalization_add_relu>(x, addend, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<\"array<i32: 1, 1, 1, 1, 1, 1>\">, result_segment_sizes = attr<\"array<i32: 1, 1, 1, 1>\">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag});\n    replace relu with fused_bn.0;\n  };\n}\n\nPattern {\n  let device_name: Attr;\n  let device_tag: Attr;\n  let axis: Attr;\n  let epsilon: Attr;\n  let training = attr<\"true\">;\n  let momentum: Attr;\n  let x: Value;\n  let moving_mean: Value;\n  let moving_variance: Value;\n  let gamma: Value;\n  let beta: Value;\n  let addend: Value;\n  let normalization = op<oneflow.normalization_infer>(x, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<\"array<i32: 1, 1, 1, 1, 1, 0>\">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag} -> (y: Type);\n  let relu =\n    op<oneflow.relu>(\n      op<oneflow.add_n2>(normalization.0, addend) {device_name = device_name, device_tag = device_tag})\n      {device_name = device_name, device_tag = device_tag} -> (out: Type);\n\n  rewrite relu with{\n    let fused_bn = CopyUserOpAttrs(normalization, op<oneflow.normalization_add_relu>(x, addend, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<\"array<i32: 1, 1, 1, 1, 1, 1>\">, result_segment_sizes = attr<\"array<i32: 1, 1, 1, 1>\">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag});\n    replace relu with fused_bn.0;\n  };\n}\n\nPattern {\n  let device_name: Attr;\n  let device_tag: Attr;\n  let axis: Attr;\n  let epsilon: Attr;\n  let training = attr<\"false\">;\n  let momentum: Attr;\n  let x: Value;\n  let moving_mean: Value;\n  let moving_variance: Value;\n  let gamma: Value;\n  let beta: Value;\n  let addend: Value;\n  let normalization = op<oneflow.normalization>(x, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<\"array<i32: 1, 1, 1, 1, 1, 0>\">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag} -> (y: Type, mean: Type, inv_variance: Type);\n  let relu =\n    op<oneflow.relu>(\n      op<oneflow.add_n2>(normalization.0, addend) {device_name = device_name, device_tag = device_tag})\n      {device_name = device_name, device_tag = device_tag} -> (out: Type);\n\n  rewrite relu with{\n    let fused_bn = 
CopyUserOpAttrs(normalization, op<oneflow.normalization_add_relu>(x, addend, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<\"array<i32: 1, 1, 1, 1, 1, 1>\">, result_segment_sizes = attr<\"array<i32: 1, 1, 0, 0>\">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag});\n    replace relu with fused_bn.0;\n  };\n}\n\nPattern {\n  let device_name: Attr;\n  let device_tag: Attr;\n  let axis: Attr;\n  let epsilon: Attr;\n  let training = attr<\"false\">;\n  let momentum: Attr;\n  let x: Value;\n  let moving_mean: Value;\n  let moving_variance: Value;\n  let gamma: Value;\n  let beta: Value;\n  let addend: Value;\n  let normalization = op<oneflow.normalization_infer>(x, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<\"array<i32: 1, 1, 1, 1, 1, 0>\">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag} -> (y: Type);\n  let relu =\n    op<oneflow.relu>(\n      op<oneflow.add_n2>(normalization.0, addend) {device_name = device_name, device_tag = device_tag})\n      {device_name = device_name, device_tag = device_tag} -> (out: Type);\n\n  rewrite relu with{\n    let fused_bn = CopyUserOpAttrs(normalization, op<oneflow.normalization_add_relu>(x, addend, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<\"array<i32: 1, 1, 1, 1, 1, 1>\">, result_segment_sizes = attr<\"array<i32: 1, 1, 0, 0>\">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag});\n    replace relu with fused_bn.0;\n  };\n}\n\nPattern {\n  let device_name: Attr;\n  let device_tag: Attr;\n  let axis: Attr;\n  let epsilon: Attr;\n  let training = attr<\"false\">;\n  let momentum: Attr;\n  let x: Value;\n  let moving_mean: Value;\n  let moving_variance: Value;\n  let gamma: Value;\n  let beta: Value;\n  let normalization = op<oneflow.normalization>(x, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<\"array<i32: 1, 1, 1, 1, 1, 0>\">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag} -> (y: Type, mean: Type, inv_variance: Type);\n  let relu =\n    op<oneflow.relu>(normalization.0) {device_name = device_name, device_tag = device_tag} -> (out: Type);\n\n  rewrite relu with{\n    let fused_bn = CopyUserOpAttrs(normalization, op<oneflow.normalization_add_relu>(x, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<\"array<i32: 1, 0, 1, 1, 1, 1>\">, result_segment_sizes = attr<\"array<i32: 1, 1, 0, 0>\">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag});\n    replace relu with fused_bn.0;\n  };\n}\n\nPattern {\n  let device_name: Attr;\n  let device_tag: Attr;\n  let axis: Attr;\n  let epsilon: Attr;\n  let training = attr<\"false\">;\n  let momentum: Attr;\n  let x: Value;\n  let moving_mean: Value;\n  let moving_variance: Value;\n  let gamma: Value;\n  let beta: Value;\n  let normalization = op<oneflow.normalization_infer>(x, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<\"array<i32: 1, 1, 1, 1, 1, 0>\">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag} -> (y: Type);\n  let relu =\n    op<oneflow.relu>(normalization.0) {device_name = device_name, 
device_tag = device_tag} -> (out: Type);\n\n  rewrite relu with{\n    let fused_bn = CopyUserOpAttrs(normalization, op<oneflow.normalization_add_relu>(x, moving_mean, moving_variance, gamma, beta) {operand_segment_sizes = attr<\"array<i32: 1, 0, 1, 1, 1, 1>\">, result_segment_sizes = attr<\"array<i32: 1, 1, 0, 0>\">, axis = axis, epsilon = epsilon, training = training, momentum = momentum, device_name = device_name, device_tag = device_tag});\n    replace relu with fused_bn.0;\n  };\n}\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/PDLL/OneFlowPDLLUtils.pdll",
    "content": "Rewrite BuildFusedBiasAddMaskScaleOpWithRate(a: Value, b: Value, mask: Value, axis: Attr, rate: Attr, dropout: Op) -> Op;\nRewrite CopyUserOpAttrs(src: Op, dst: Op) -> Op;\n\nRewrite GetHeadSizeFromTranpose(transpose: Op) -> Attr;\nRewrite CreateConv2dAndErasePad(x: Value, weight: Value, padding_before: Attr, data_format: Attr, conv: Op) -> Op;\nRewrite CreatScaleTrilOp(x: Value, diagonal: Attr, floating_fill_value: Attr, integer_fill_value: Attr,\n                          is_floating_fill_value: Attr, float_operand: Attr, int_operand: Attr, has_float_operand: Attr, out: Type) -> Op {\n  let floating_scale_value = float_operand;\n  let integer_scale_value = int_operand;\n  let is_floating_scale_value = has_float_operand;\n  let scale_tril_op = op<oneflow.fused_scale_tril>(x){diagonal = diagonal, floating_fill_value = floating_fill_value, integer_fill_value = integer_fill_value,\n                                                        is_floating_fill_value = is_floating_fill_value, floating_scale_value = floating_scale_value,\n                                                        integer_scale_value = integer_scale_value, is_floating_scale_value = is_floating_scale_value} -> (out);\n  return scale_tril_op;\n}\n\nRewrite CreateConv2DBatchNorm(epsilon: Attr, conv: Op, bn: Op) -> Op;\n\nConstraint IsPaddingCouldBeAssimilatedIntoConv(padding_before: Attr, padding_after: Attr, data_format:Attr);\nConstraint IsNotNestedInJit(mul: Op);\nConstraint IsScalarTensor(value: Value);\nConstraint IsScalarEqualSqrtDim(query_reshape: Value, scalar_div_operand: Attr);\nConstraint IsScalarEqualSqrtDimReciprocal(query_reshape: Value, scalar_div_operand: Attr);\nRewrite GetReciprocal(a: Attr) -> Attr;\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Passes.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"mlir/Dialect/MemRef/Transforms/Passes.h\"\n#include \"OneFlow/Transform/OneFlowMemPool.h\"\n#include \"OneFlow/Transform/EliminateAllocOps.h\"\n#include \"OneFlow/Transform/OneFlowStream.h\"\n#include \"mlir/Conversion/TosaToTensor/TosaToTensor.h\"\n#include \"oneflow/ir/oneflow-translate/include/OneFlow/MLIROneFlowTranslation.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/variable_tensor_mgr.h\"\n#include \"oneflow/core/operator/variable_op.h\"\n#include \"oneflow/core/framework/sbp_context.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#include \"oneflow/core/framework/variable_tensor_mgr.h\"\n#include \"oneflow/core/operator/variable_op.h\"\n#include \"oneflow/core/framework/sbp_context.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowUtils.h\"\n#include \"OneFlow/Passes.h\"\n#include \"OneFlow/OneFlowUtils.h\"\n#include \"OneFlow/OneFlowPatternUtils.h\"\n#include \"OneFlow/OneFlowSupport.h\"\n#include \"OneFlow/SBP/SBPImporter.h\"\n#include \"OneFlow/SBP/SBPAttributes.h\"\n#include \"OneFlow/OKL/OKLOps.h\"\n#include \"OneFlow/OKL/OKLTypes.h\"\n#include \"OneFlow/OKL/Kernel/RegContext.h\"\n#include \"OneFlow/OKM/Conversion/Conversion.h\"\n#include \"OneFlow/Transform/TransposeHelpers.h\"\n#include \"OneFlow/Transform/OutlineAndFuse.h\"\n#include \"OneFlow/OneFlowPDLLPatterns.h\"\n#include \"OneFlow/OKL/passes.h\"\n#include \"OneFlow/OKL/OKLAttributes.h\"\n#include \"OneFlow/OKM/passes.h\"\n#include \"mlir/Dialect/Tosa/Transforms/Passes.h\"\n#include \"mlir/Dialect/LLVMIR/FunctionCallUtils.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMDialect.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMTypes.h\"\n#include \"mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h\"\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/Diagnostics.h\"\n#include \"mlir/IR/SymbolTable.h\"\n#include \"mlir-c/BuiltinAttributes.h\"\n#include \"mlir/IR/Attributes.h\"\n#include \"mlir/IR/OperationSupport.h\"\n#include \"mlir/IR/MLIRContext.h\"\n#include \"mlir/Dialect/Tosa/Transforms/Passes.h\"\n#include \"mlir/Dialect/LLVMIR/FunctionCallUtils.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMDialect.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMTypes.h\"\n#include \"mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/Diagnostics.h\"\n#include \"mlir/IR/Location.h\"\n#include \"mlir/IR/Operation.h\"\n#include \"mlir/IR/SymbolTable.h\"\n#include \"mlir/IR/TypeRange.h\"\n#include \"mlir/Support/LLVM.h\"\n\n#include \"mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h\"\n#include 
\"mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h\"\n#include \"mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h\"\n#include \"mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h\"\n#include \"mlir/Conversion/TosaToLinalg/TosaToLinalg.h\"\n#include \"mlir/Conversion/AffineToStandard/AffineToStandard.h\"\n#include \"mlir/Dialect/Affine/IR/AffineOps.h\"\n#include \"mlir/Dialect/Linalg/Passes.h\"\n#include \"mlir/Dialect/MemRef/IR/MemRef.h\"\n#include \"mlir/Dialect/SCF/Transforms/Passes.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/Dialect/Func/Transforms/Passes.h\"\n#include \"mlir/Dialect/Tensor/Transforms/Passes.h\"\n#include \"mlir/Dialect/Tosa/IR/TosaOps.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/OpImplementation.h\"\n#include \"mlir/IR/OperationSupport.h\"\n#include \"mlir/IR/PatternMatch.h\"\n#include \"mlir/IR/IRMapping.h\"\n#include \"mlir/IR/Value.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/Transforms/DialectConversion.h\"\n#include \"mlir/Transforms/Passes.h\"\n#include \"mlir/Dialect/Bufferization/Transforms/Passes.h\"\n#include \"mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h\"\n\n#include \"llvm/ADT/ArrayRef.h\"\n#include \"llvm/ADT/None.h\"\n#include \"llvm/ADT/DenseSet.h\"\n#include \"llvm/ADT/SmallVector.h\"\n#include \"llvm/ADT/SetOperations.h\"\n#include \"llvm/ADT/STLExtras.h\"\n#include \"llvm/Support/Casting.h\"\n#include \"llvm/Support/FormatVariadic.h\"\n#include \"llvm/Support/ErrorHandling.h\"\n\n#include <algorithm>\n#include <memory>\n#include <vector>\n#include <iostream>\n#include <string>\n\n#ifdef WITH_MLIR_CUDA_CODEGEN\n#include \"mlir/Conversion/GPUCommon/GPUCommonPass.h\"\n#include \"mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h\"\n#include \"mlir/Dialect/GPU/Transforms/Passes.h\"\n#include \"mlir/Conversion/SCFToGPU/SCFToGPUPass.h\"\n\n#endif  // WITH_MLIR_CUDA_CODEGEN\n\n#ifdef WITH_CUDA\n// enable with_cuda_graphs\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#endif  // WITH_CUDA\n\nnamespace mlir {\nnamespace oneflow {\n\nLLVM::LLVMPointerType GetPtr(::mlir::PatternRewriter& rewriter) {\n  return LLVM::LLVMPointerType::get(IntegerType::get(rewriter.getContext(), 8));\n}\n\ntemplate<typename T>\nLogicalResult DumpAssembly(::mlir::PatternRewriter& rewriter, T op, StringRef func_name) {\n  // TODO: now we only need one JIT engine\n  auto parent_func_op = op->template getParentOfType<oneflow::Job>();\n  if (!parent_func_op) { return failure(); }\n  auto parent_module_op = parent_func_op->template getParentOfType<ModuleOp>();\n  if (!parent_module_op) { return failure(); }\n  SymbolTable symbol_table(parent_module_op);\n  std::string mlir;\n  llvm::raw_string_ostream os_mlir(mlir);\n  if (auto found = symbol_table.lookup(func_name)) {\n    found->print(os_mlir);\n  } else {\n    parent_module_op->dump();\n    return op.emitError(\"symbol of jit function not found: \" + op.getOpName());\n  }\n  op->setAttr(\"mlir_assembly\", rewriter.getStringAttr(mlir));\n  return success();\n}\n\nLLVM::LLVMFuncOp DeclareKernelLaunchCInterface(::mlir::PatternRewriter& rewriter,\n                                               mlir::Location loc, ModuleOp* module,\n                                               StringRef c_api_callee, Type llvm_ptr_type) {\n  LLVM::LLVMFuncOp func;\n  if (!(func = module->lookupSymbol<LLVM::LLVMFuncOp>(c_api_callee))) {\n    OpBuilder::InsertionGuard 
guard(rewriter);\n    rewriter.setInsertionPointToStart(module->getBody());\n    auto void_type = LLVM::LLVMVoidType::get(rewriter.getContext());\n    auto func_type = LLVM::LLVMFunctionType::get(void_type, {llvm_ptr_type, llvm_ptr_type}, false);\n    func = rewriter.create<LLVM::LLVMFuncOp>(loc, c_api_callee, func_type, LLVM::Linkage::External);\n\n    func->setAttr(\"llvm.emit_c_interface\", mlir::UnitAttr::get(rewriter.getContext()));\n  }\n  return func;\n}\n\nLLVM::GlobalOp DeclareOrGetGlobalString(::mlir::PatternRewriter& rewriter, mlir::Location loc,\n                                        ModuleOp* module, StringRef func_name) {\n  LLVM::GlobalOp global;\n  StringRef variable = rewriter.getStringAttr(func_name + \"_var\");\n  if (!(global = module->lookupSymbol<LLVM::GlobalOp>(variable))) {\n    OpBuilder::InsertionGuard insertGuard(rewriter);\n    rewriter.setInsertionPointToStart(module->getBody());\n    auto type =\n        LLVM::LLVMArrayType::get(IntegerType::get(rewriter.getContext(), 8), func_name.size());\n    global =\n        rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true, LLVM::Linkage::Internal,\n                                        variable, rewriter.getStringAttr(func_name),\n                                        /*alignment=*/0);\n  }\n  return global;\n}\n\ntemplate<typename Wrap>\nModuleOp GetModuleOpFromJobBodyOp(Operation* op) {\n  auto parent_func_op = op->getParentOfType<Wrap>();\n  if (!parent_func_op) { return nullptr; }\n  return parent_func_op->template getParentOfType<ModuleOp>();\n}\n\nfunc::FuncOp InsertKernelOFFuncOp(::mlir::PatternRewriter& rewriter, Operation* op,\n                                  const std::string& func_name) {\n  auto loc = op->getLoc();\n  auto module = GetModuleOpFromJobBodyOp<func::FuncOp>(op);\n  if (!module) {\n    emitError(loc) << \"null ModuleOp \" << *op;\n    return nullptr;\n  }\n\n  IRMapping mapping;\n  OpBuilder::InsertionGuard guard(rewriter);\n  rewriter.setInsertionPointToStart(module.getBody());\n\n  auto func_type =\n      rewriter.getFunctionType(TypeRange(op->getOperandTypes()), TypeRange(op->getResultTypes()));\n  func::FuncOp func = rewriter.create<func::FuncOp>(loc, func_name, func_type);\n  func->setAttr(\"compiled\", rewriter.getStringAttr(\"true\"));\n  func.getBody().emplaceBlock();\n  for (auto& arg : func_type.getInputs()) { func.getBody().addArguments(arg, loc); }\n  for (auto argument_pair :\n       llvm::zip(ValueRange(op->getOperands()), func.getBody().getArguments())) {\n    mapping.map(std::get<0>(argument_pair), std::get<1>(argument_pair));\n  }\n  rewriter.setInsertionPointToStart(&func.getBody().front());\n  ImplicitLocOpBuilder new_block(loc, rewriter);\n  new_block.clone(*op, mapping);\n  SmallVector<::mlir::Value, 4> mapped_results;\n  for (auto result : ValueRange(op->getResults())) {\n    mapped_results.push_back(mapping.lookup(result));\n  }\n  rewriter.create<func::ReturnOp>(loc, mapped_results);\n  return func;\n}\n\n::llvm::SmallVector<::mlir::Value, 4> CreateGPUMemcpyOpFromMemrefCopy(\n    ::mlir::PatternRewriter& rewriter, ::mlir::memref::CopyOp copyOp) {\n  // NOTE: to get lowered to LLVM, it has to be async\n  ::mlir::ValueRange empty_async_dependencies{};\n  auto token = rewriter.getType<gpu::AsyncTokenType>();\n  auto t0 = rewriter.create<gpu::WaitOp>(copyOp->getLoc(), token, empty_async_dependencies)\n                .getAsyncToken();\n  auto t2 = rewriter\n                .create<gpu::MemcpyOp>(copyOp->getLoc(),\n                                       
/*optional asyncToken*/ token,\n                                       /*asyncDependencies*/ llvm::SmallVector<Value, 1>({t0}),\n                                       /*dst*/ copyOp.getTarget(),\n                                       /*src*/ copyOp.getSource())\n                .getResults();\n  rewriter.create<gpu::WaitOp>(copyOp->getLoc(), llvm::None, t2);\n  return {};\n}\n\nbool HasZeroPadding(mlir::ArrayAttr padding) {\n  for (auto val : padding.getValue()) {\n    if (val.cast<IntegerAttr>().getValue().getSExtValue() != 0) return false;\n  }\n  return true;\n}\n\nNamedAttrList GetUserOpCommonAttrs(MLIRContext* ctx, const std::string& op_name) {\n  NamedAttrList attrs;\n  attrs.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(), StringAttr::get(ctx, op_name));\n  attrs.set(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr(), StringAttr::get(ctx, \"cpu\"));\n  attrs.set(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr(),\n            ArrayAttr::get(ctx, llvm::to_vector<8>(llvm::map_range(ArrayRef<StringRef>({\"@0:0\"}),\n                                                                   [&](StringRef v) -> Attribute {\n                                                                     return StringAttr::get(ctx, v);\n                                                                   }))));\n  return attrs;\n}\n\nstruct ReplaceVariablePattern : public ::mlir::RewritePattern {\n  explicit ReplaceVariablePattern(::mlir::MLIRContext* context)\n      : ::mlir::RewritePattern(\"oneflow.variable\", 1, context, {\"oneflow.variable_ir\"}) {}\n  ::mlir::LogicalResult matchAndRewrite(::mlir::Operation* op0,\n                                        ::mlir::PatternRewriter& rewriter) const override {\n    auto op = ::llvm::dyn_cast<oneflow::VariableOp>(op0);\n    if (!op) return failure();\n    NamedAttrList attrs;\n    if (op.getOpName().str().find(\"FreeEagerTensor\") != std::string::npos) { return failure(); }\n    attrs.set(StringAttr::get(getContext(), \"value\"),\n              support::TensorToDenseElementsAttr(\n                  CHECK_JUST(::oneflow::Singleton<::oneflow::VariableTensorMgr>::Get()->Get(\n                      op.getOpName().str(), ::oneflow::DType::Float())),\n                  rewriter.getContext()));\n    attrs.set(op.getOpNameAttrName(), op.getOpNameAttr());\n    attrs.set(op.getDataTypeAttrName(), op.getDataTypeAttr());\n    attrs.set(op.getDeviceTagAttrName(), op.getDeviceTagAttr());\n    attrs.set(op.getDeviceNameAttrName(), op.getDeviceNameAttr());\n    attrs.set(op.getScopeSymbolIdAttrName(), op.getScopeSymbolIdAttr());\n    attrs.set(op.getHierarchyAttrName(), op.getHierarchyAttr());\n    auto name = FrozenVariableOp::getNdSbpAttrName(\n        OperationName(FrozenVariableOp::getOperationName(), rewriter.getContext()));\n\n    auto parallel_attr = op.getParallelAttr();\n    attrs.set(name, SBPTranslation::ConvertSBPToString(rewriter, parallel_attr));\n    auto op_new = rewriter.create<oneflow::FrozenVariableOp>(op->getLoc(), op.getOutput().getType(),\n                                                             ValueRange(), attrs);\n    rewriter.replaceOp(op0, op_new->getResults());\n    return ::mlir::success();\n  }\n};\n\nstruct ReplaceVariableIrPattern : public ::mlir::RewritePattern {\n  explicit ReplaceVariableIrPattern(::mlir::MLIRContext* context)\n      : ::mlir::RewritePattern(\"oneflow.variable_ir\", 1, context, {\"oneflow.variable\"}) {}\n  ::mlir::LogicalResult matchAndRewrite(::mlir::Operation* op0,\n                                      
  ::mlir::PatternRewriter& rewriter) const override {\n    auto op = ::llvm::dyn_cast<oneflow::FrozenVariableOp>(op0);\n    if (!op) return failure();\n    NamedAttrList attrs;\n    const auto tensor_attr = op.getValue();\n    attrs.set(StringAttr::get(getContext(), \"shape\"),\n              rewriter.getArrayAttr(llvm::to_vector<8>(llvm::map_range(\n                  tensor_attr.getType().cast<mlir::RankedTensorType>().getShape(),\n                  [&](int64_t v) -> Attribute {\n                    return IntegerAttr::get(rewriter.getIntegerType(64, /*isSigned=*/true),\n                                            APInt(64, v, /*isSigned=*/true));\n                  }))));\n    attrs.set(StringAttr::get(getContext(), \"data_type\"),\n              oneflow::DataTypeAttr::get(getContext(), oneflow::DataType::DT_Float));\n    auto output_lbns_attr = rewriter.getStrArrayAttr({op.getOpName().str() + \"/out\"});\n    attrs.set(OpTrait::IsImportCompatible<void>::getOutputLBNsAttr(), output_lbns_attr);\n    attrs.set(op.getOpNameAttrName(), op.getOpNameAttr());\n    attrs.set(op.getDataTypeAttrName(), op.getDataTypeAttr());\n    attrs.set(op.getDeviceTagAttrName(), op.getDeviceTagAttr());\n    attrs.set(op.getDeviceNameAttrName(), op.getDeviceNameAttr());\n    attrs.set(op.getScopeSymbolIdAttrName(), op.getScopeSymbolIdAttr());\n    attrs.set(op.getHierarchyAttrName(), op.getHierarchyAttr());\n    auto name = VariableOp::getParallelAttrName(\n        OperationName(VariableOp::getOperationName(), rewriter.getContext()));\n\n    auto nd_size = op.getHierarchy()->size();\n    ArrayAttr nd_sbp = op.getNdSbp();\n    std::vector<std::string> nd_sbp_str;\n    std::for_each(nd_sbp.begin(), nd_sbp.end(), [&](Attribute elem) {\n      if (auto sbp_str_attr = elem.dyn_cast<StringAttr>()) {\n        nd_sbp_str.push_back(sbp_str_attr.str());\n      }\n    });\n    attrs.set(name, SBPTranslation::ConvertNdSbpToPsig(rewriter, nd_sbp_str, nd_size));\n    auto op_new = rewriter.create<oneflow::VariableOp>(op->getLoc(), op.getOutput().getType(),\n                                                       ValueRange(), attrs);\n    const std::string tensor_name = op.getOpNameAttr().str();\n    const auto data_type = support::FromMLIRAttrToOFDataType(op.getDataTypeAttr());\n    if (failed(data_type)) {\n      op0->emitError(::llvm::formatv(\"unsupported data type: {0}\",\n                                     ConvertToString(op.getDataTypeAttr().getValue())));\n      return ::mlir::failure();\n    }\n    auto var_tensor = CHECK_JUST(\n        ::oneflow::Singleton<::oneflow::VariableTensorMgr>::Get()->Get(op.getOpName().str()));\n    if (var_tensor) {\n      support::DenseElementsAttrToTensor(tensor_attr, op.getDeviceTagAttr(), op.getDeviceNameAttr(),\n                                         var_tensor);\n    } else {\n      CHECK_JUST(::oneflow::Singleton<::oneflow::VariableTensorMgr>::Get()->Set(\n          tensor_name,  // tensor_name can't be replaced by op.op_nameAttr().str() directly when\n                        // compiling with gcc, and I have no idea why.\n                        // It works when compiling with clang; maybe temporary objects are\n                        // released earlier with gcc.\n          support::DenseElementsAttrToTensor(tensor_attr, op.getDeviceTagAttr(),\n                                             op.getDeviceNameAttr()),\n          CHECK_JUST(::oneflow::DType::Get(data_type.value()))));\n    }\n    // replaceOp may deallocate `op0` (and also `op`), so we should not use 
`op` after this call.\n    rewriter.replaceOp(op0, op_new->getResults());\n    return ::mlir::success();\n  }\n};\n\nLogicalResult InitTransposeAttributes(Operation* op, NamedAttrList& transpose_attributes,\n                                      PatternRewriter& rewriter) {\n  if (op->hasTrait<OpTrait::IsOpConfCompatible>()) {\n    return OpTrait::IsOpConfCompatible<void>::saveToNamedAttrList(op, transpose_attributes);\n  } else {\n    op->emitError(\"must be an op of trait IsOpConfCompatible!\");\n    return failure();\n  }\n}\n\nbool IsAddToOutputNone(ValueRange value) { return value.empty(); }\n\nllvm::SmallVector<int32_t> getChannelLastTransposePerm() { return {0, 2, 3, 1}; }\n\nllvm::SmallVector<int32_t> getChannelFirstTransposePerm() { return {0, 3, 1, 2}; }\n\nllvm::SmallVector<mlir::Value, 4> getInputOperandTransposeOp(NCHWCompatible op, Value val,\n                                                             NamedAttrList transpose_attributes,\n                                                             int num_transposed_operand,\n                                                             PatternRewriter& rewriter) {\n  std::string transpose_name = OpTrait::IsOpConfCompatible<void>::getOpName(op).str()\n                               + \"_transpose_input_\" + std::to_string(num_transposed_operand);\n  transpose_attributes.set(llvm::StringRef(OpTrait::IsOpConfCompatible<void>::getOpNameAttr()),\n                           rewriter.getStringAttr(transpose_name));\n  SmallVector<Value, 4> input_operands;\n  input_operands.push_back(val);\n  auto res = rewriter\n                 .create<oneflow::TransposeOp>(op.getLoc(), getNHWCType(val.getType()),\n                                               input_operands, transpose_attributes)\n                 ->getResults();\n  return res;\n}\n\nTransposeOp getResultTransposeOp(NCHWCompatible op, Value val, NamedAttrList transpose_attributes,\n                                 int num_transposed_result, PatternRewriter& rewriter) {\n  std::string transpose_name = OpTrait::IsOpConfCompatible<void>::getOpName(op).str()\n                               + \"_transpose_output_\" + std::to_string(num_transposed_result);\n  transpose_attributes.set(llvm::StringRef(OpTrait::IsOpConfCompatible<void>::getOpNameAttr()),\n                           rewriter.getStringAttr(transpose_name));\n  SmallVector<Value, 4> operands;\n  operands.push_back(val);\n  TransposeOp transpose_op = rewriter.create<oneflow::TransposeOp>(\n      op.getLoc(), getNCHWType(val.getType()), operands, transpose_attributes);\n  return transpose_op;\n}\n\nbool IsInsertTransposeOpBefore(NCHWCompatible op, PatternRewriter& rewriter) {\n  bool insert_transpose_op_flag = false;\n  for (mlir::Value operand : op->getOperands()) {\n    TransposeOp transposeInputOp = operand.getDefiningOp<TransposeOp>();\n    if (!transposeInputOp) continue;\n    const auto perm = transposeInputOp.getPermAttr();\n    if (perm.size() == 4 && perm[0] == rewriter.getSI32IntegerAttr(0)\n        && perm[1] == rewriter.getSI32IntegerAttr(3) && perm[2] == rewriter.getSI32IntegerAttr(1)\n        && perm[3] == rewriter.getSI32IntegerAttr(2)) {\n      insert_transpose_op_flag = true;\n      break;\n    }\n  }\n  return insert_transpose_op_flag;\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#include \"OneFlow/OneFlowPatterns.cpp.inc\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\ntemplate<typename Op>\nstruct FusedConsecutiveAddPattern : public OpRewritePattern<Op> {\n  explicit 
FusedConsecutiveAddPattern(mlir::MLIRContext* context)\n      : OpRewritePattern<Op>(context, /*benefit=*/1) {}\n\n public:\n  LogicalResult matchAndRewrite(Op op, PatternRewriter& rewriter) const override;\n};\n\ntemplate<typename Op>\nLogicalResult TryFusedConsecutiveAdd(Op op, const SmallVector<mlir::Value, 4>& opOperands,\n                                     PatternRewriter& rewriter) {\n  for (mlir::Value operand : opOperands) {\n    if (!operand.getDefiningOp<AddNOp>() && !operand.getDefiningOp<Add2Op>()) { continue; }\n    // check if the operand has only one user\n    LogicalResult checkResult = [&]() {\n      for (const auto& use : operand.getUses()) {\n        if (use.getOwner() != op) { return failure(); }\n      }\n      return success();\n    }();\n    if (failed(checkResult)) { continue; }\n\n    SmallVector<mlir::Value, 4> operands;\n    SmallVector<mlir::Value, 4> inputOpOperands;\n    mlir::Value inputOpResult;\n    if (AddNOp addInputOp = operand.getDefiningOp<AddNOp>()) {\n      inputOpOperands = addInputOp.getIn();\n      inputOpResult = addInputOp.getOut();\n    } else if (Add2Op addInputOp = operand.getDefiningOp<Add2Op>()) {\n      inputOpOperands = {addInputOp.getIn0(), addInputOp.getIn1()};\n      inputOpResult = addInputOp.getOut();\n    }\n    for (mlir::Value operand : opOperands) {\n      if (operand != inputOpResult) {\n        operands.push_back(operand);\n      } else {\n        operands.insert(operands.end(), inputOpOperands.begin(), inputOpOperands.end());\n      }\n    }\n    auto new_op =\n        rewriter.create<AddNOp>(op->getLoc(), op->getResultTypes(), operands, op->getAttrs());\n    rewriter.replaceOp(op, new_op.getOut());\n    return success();\n  }\n  return failure();\n}\n\ntemplate<>\nLogicalResult FusedConsecutiveAddPattern<AddNOp>::matchAndRewrite(AddNOp op,\n                                                                  PatternRewriter& rewriter) const {\n  return TryFusedConsecutiveAdd<AddNOp>(op, op.getIn(), rewriter);\n}\n\ntemplate<>\nLogicalResult FusedConsecutiveAddPattern<Add2Op>::matchAndRewrite(Add2Op op,\n                                                                  PatternRewriter& rewriter) const {\n  return TryFusedConsecutiveAdd<Add2Op>(op, {op.getIn0(), op.getIn1()}, rewriter);\n}\n\nstruct AutoNhwcPattern : public OpInterfaceRewritePattern<NCHWCompatible> {\n  explicit AutoNhwcPattern(mlir::MLIRContext* context)\n      : OpInterfaceRewritePattern<NCHWCompatible>(context, /*benefit=*/1) {}\n\n public:\n  LogicalResult matchAndRewrite(NCHWCompatible op, PatternRewriter& rewriter) const override {\n    if (op->hasTrait<OpTrait::IsOpConfCompatible>()) {\n      for (mlir::Value operand : op.OperandsToTranspose()) {\n        if (operand.getType().cast<mlir::RankedTensorType>().getShape().size() != 4) {\n          return failure();\n        }\n      }\n      const auto device_name = OpTrait::IsOpConfCompatible<void>::getDeviceTag(op)\n                                   .cast<mlir::StringAttr>()\n                                   .getValue()\n                                   .str();\n      if (device_name == \"cpu\") { return failure(); }\n    }\n    llvm::SmallVector<int32_t> perm = getChannelLastTransposePerm();\n    llvm::SmallVector<int32_t> result_perm = getChannelFirstTransposePerm();\n\n    NamedAttrList transpose_attributes;\n    if (InitTransposeAttributes(op, transpose_attributes, rewriter).succeeded()) {\n      transpose_attributes.append(llvm::StringRef(\"perm\"), getSI32ArrayAttr(rewriter, perm));\n    } else 
{\n      return failure();\n    }\n    // when an op has no sense of data_format and its producer is a transpose, we greedily\n    // insert a transpose into this op, seeking more opportunities to eliminate transpose\n    // patterns.\n    const bool greedily_transpose_flag = !op.IsNCHW() && IsInsertTransposeOpBefore(op, rewriter);\n\n    if (op.IsNCHW() || greedily_transpose_flag) {\n      // create transpose op for input operand\n      SmallVector<Value, 4> transposed_operands;\n      llvm::DenseSet<Value> operand_transpose = op.OperandsToTranspose();\n      int num_transposed_operand = 0;\n      for (Value operand : op->getOperands()) {\n        if (operand_transpose.find(operand) != operand_transpose.end()) {\n          SmallVector<Value, 4> input_res = getInputOperandTransposeOp(\n              op, operand, transpose_attributes, num_transposed_operand, rewriter);\n          transposed_operands.push_back(input_res[0]);\n          num_transposed_operand += 1;\n        }\n      }\n      // create NHWC op\n      SmallVector<Value, 4> created_results = op.NchwToNhwc(transposed_operands, rewriter);\n      // create transpose op for results\n      int num_transposed_result = 0;\n      transpose_attributes.set(llvm::StringRef(\"perm\"), getSI32ArrayAttr(rewriter, result_perm));\n      llvm::DenseSet<Value> transpose_result = op.ResultsToTranspose();\n\n      for (Value result : op->getOpResults()) {\n        if (transpose_result.find(result) != transpose_result.end()) {\n          if (auto result_transpose_op =\n                  getResultTransposeOp(op, created_results[num_transposed_result],\n                                       transpose_attributes, num_transposed_result, rewriter)) {\n            result.replaceAllUsesWith(result_transpose_op);\n            num_transposed_result += 1;\n          } else {\n            return failure();\n          }\n        }\n      }\n    }\n    return success();\n  }\n};\n\nbool IsRedundantTransposeMatch(ArrayAttr pre, ArrayAttr afe, mlir::PatternRewriter& rewriter) {\n  const auto prePerm = pre.getValue().vec();\n  const auto afePerm = afe.getValue().vec();\n  if (prePerm.size() == 4 && afePerm.size() == 4) {\n    // handle nchw->nhwc->nchw: (0, 2, 3, 1) -> (0, 3, 1, 2)\n    if (prePerm[0] == afePerm[0] && prePerm[1] == afePerm[3] && prePerm[2] == afePerm[1]\n        && prePerm[3] == afePerm[2] && prePerm[0] == rewriter.getSI32IntegerAttr(0)\n        && prePerm[1] == rewriter.getSI32IntegerAttr(2)\n        && prePerm[2] == rewriter.getSI32IntegerAttr(3)\n        && prePerm[3] == rewriter.getSI32IntegerAttr(1))\n      return true;\n    // handle nhwc->nchw->nhwc: (0, 3, 1, 2) -> (0, 2, 3, 1)\n    if (prePerm[0] == afePerm[0] && prePerm[1] == afePerm[2] && prePerm[2] == afePerm[3]\n        && prePerm[3] == afePerm[1] && prePerm[0] == rewriter.getSI32IntegerAttr(0)\n        && prePerm[1] == rewriter.getSI32IntegerAttr(3)\n        && prePerm[2] == rewriter.getSI32IntegerAttr(1)\n        && prePerm[3] == rewriter.getSI32IntegerAttr(2))\n      return true;\n  }\n  return false;\n}\n\nstruct AutoNhwcEliminateRedundantTransposePattern : public mlir::OpRewritePattern<TransposeOp> {\n  explicit AutoNhwcEliminateRedundantTransposePattern(mlir::MLIRContext* context)\n      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}\n  mlir::LogicalResult matchAndRewrite(TransposeOp op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    mlir::Value transposeInput = op.getOperand();\n    TransposeOp transposeInputOp = 
transposeInput.getDefiningOp<TransposeOp>();\n\n    if (!transposeInputOp\n        || !IsRedundantTransposeMatch(op.getPermAttr(), transposeInputOp.getPermAttr(), rewriter)) {\n      return failure();\n    }\n    rewriter.replaceOp(op, {transposeInputOp.getOperand()});\n    return success();\n  }\n};\n\nstruct LowerToOKLPattern : public mlir::OpRewritePattern<func::FuncOp> {\n  static LogicalResult LowerToOKLOp(::mlir::PatternRewriter& rewriter, Operation* op,\n                                    func::FuncOp okl_func, int index) {\n    auto op_type_name = op->getAttr(\"op_name\").dyn_cast<StringAttr>();\n    auto raw_func = op->getParentOfType<func::FuncOp>();\n    if (!op_type_name) { return failure(); }\n    OpBuilder::InsertionGuard guard(rewriter);\n    rewriter.setInsertionPointToEnd(&okl_func.getBody().back());\n\n    auto loc = op->getLoc();\n\n    auto wrap_kernel = rewriter.create<okl::WrapperKernelOp>(loc, index);\n    wrap_kernel.getBody().emplaceBlock();\n    rewriter.setInsertionPointToEnd(&wrap_kernel.getBody().back());\n\n    IRMapping mapping;\n\n    // map launcher_ctx from wrap func to block\n    mapping.map(raw_func.getArgument(0), okl_func.getArgument(0));\n\n    ImplicitLocOpBuilder new_block(loc, rewriter);\n    for (auto arg : op->getOperands()) {\n      auto define_op = arg.getDefiningOp();\n      if (define_op->getName().getStringRef() == okl::GetTensorFromArgOp::getOperationName()) {\n        new_block.clone(*define_op, mapping);\n      } else {\n        auto find = false;\n        for (auto use : arg.getUsers()) {\n          if (use->getName().getStringRef() == okl::GetTensorAsRetOp::getOperationName()) {\n            find = true;\n            auto index = use->getAttr(\"index\").cast<IntegerAttr>().getInt();\n            auto source = rewriter.create<okl::GetTensorFromRetOp>(op->getLoc(), arg.getType(),\n                                                                   okl_func.getArgument(0), index);\n            mapping.map(arg, source->getResult(0));\n            break;\n          }\n        }\n        if (!find) { op->emitError(\"Failed to find operand source\"); }\n      }\n    }\n    new_block.clone(*op, mapping);\n    for (auto ret : op->getResults()) {\n      auto find = false;\n      for (auto use : ret.getUsers()) {\n        if (use->getName().getStringRef() == okl::GetTensorAsRetOp::getOperationName()) {\n          find = true;\n          new_block.clone(*use, mapping);\n          break;\n        }\n      }\n      if (!find) { op->emitError(\"Failed to find result source\"); }\n    }\n    rewriter.create<okl::ReturnOp>(loc);\n\n    return success();\n  }\n\n  explicit LowerToOKLPattern(mlir::MLIRContext* context)\n      : OpRewritePattern<func::FuncOp>(context, /*benefit=*/0) {}\n  mlir::LogicalResult matchAndRewrite(func::FuncOp op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    ModuleOp module = op->getParentOfType<ModuleOp>();\n    if (!module) { LOG(FATAL) << \"Module not found\"; }\n    if (module.lookupSymbol(okl_func::OKL_FUNC)) { return success(); }\n\n    OpBuilder::InsertionGuard guard(rewriter);\n    rewriter.setInsertionPointAfter(op);\n    auto& block = op.getBody().front();\n    auto loc = op->getLoc();\n\n    auto func_type = rewriter.getFunctionType(\n        {mlir::okl::LauncherContextType::get(rewriter.getContext())}, TypeRange{});\n    auto okl_func = rewriter.create<func::FuncOp>(loc, okl_func::OKL_FUNC, func_type);\n    okl_func.getBody().emplaceBlock();\n    
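// NOTE: the launcher context is the okl function's only argument; LowerToOKLOp above maps it\n    // into each wrapped kernel so tensors are fetched through it at run time\n    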
okl_func.getBody().addArguments(mlir::okl::LauncherContextType::get(rewriter.getContext()),\n                                    loc);\n\n    auto index = 0;\n    for (auto& op : block) {\n      if (!op.hasAttr(\"op_name\")) {\n        if (op.getDialect()->getNamespace() == \"okl\") { continue; }\n        if (isa<func::ReturnOp>(op)) { break; }\n        op.emitError(\"Failed to parse this op in kernel launch wrap func.\");\n      }\n      if (failed(LowerToOKLOp(rewriter, &op, okl_func, index))) {\n        index += 1;\n        op.emitError(\"Failed to lower OneFlow op to okl dialect.\");\n        return failure();\n      }\n      index += 1;\n    }\n\n    rewriter.setInsertionPointToEnd(&okl_func.getBody().back());\n    rewriter.create<func::ReturnOp>(loc);\n    rewriter.eraseOp(op);\n    return success();\n  }\n};\n\n// {func, ins, outs_mapping}\nstd::tuple<func::FuncOp, std::vector<Value>, std::vector<std::vector<int>>>\nCreateWrapFuncAndReturnWithIns(mlir::Location loc, std::vector<Operation*>& wrap_ops,\n                               mlir::PatternRewriter& rewriter, int& name_index) {\n  auto getProto =\n      [&]() -> std::tuple<std::vector<Value>, std::vector<Value>, std::vector<std::vector<int>>> {\n    std::vector<Value> whole_ins, whole_outs, ins, outs;\n    std::vector<std::vector<int>> outs_mapping;\n    for (auto op : wrap_ops) {\n      auto operands = op->getOperands();\n      auto results = op->getResults();\n      for (auto it = operands.begin(); it != operands.end(); ++it) { whole_ins.push_back(*it); }\n\n      std::vector<int> map;\n      auto add_res = [&](mlir::OpResult res) {\n        map.push_back(outs.size());\n        outs.push_back(res);\n      };\n      for (auto it = results.begin(); it != results.end(); ++it) {\n        whole_outs.push_back(*it);\n        for (auto user : (*it).getUsers()) {\n          if (std::find(wrap_ops.begin(), wrap_ops.end(), user) == wrap_ops.end()) {\n            add_res(*it);\n            break;\n          }\n        }\n      }\n      outs_mapping.push_back(map);\n    }\n\n    for (auto in : whole_ins) {\n      if (std::find(whole_outs.begin(), whole_outs.end(), in) == whole_outs.end()) {\n        ins.push_back(in);\n      }\n    }\n    return {ins, outs, outs_mapping};\n  };\n\n  auto [ins, outs, map] = getProto();\n  auto func_type = rewriter.getFunctionType(TypeRange(ValueRange(ArrayRef<Value>(ins))),\n                                            TypeRange(ValueRange(ArrayRef<Value>(outs))));\n  auto func_name = okm::func_name::GRAPH_NAME + std::to_string(name_index++);\n  auto module = GetModuleOpFromJobBodyOp<Job>(wrap_ops[0]);\n  if (!module) { LOG(FATAL) << \"Failed to find parent ModuleOp\"; }\n  OpBuilder::InsertionGuard guard(rewriter);\n  rewriter.setInsertionPointToStart(module.getBody());\n  auto function = rewriter.create<func::FuncOp>(loc, func_name, func_type);\n  function->setAttr(\"llvm.emit_c_interface\", mlir::UnitAttr::get(rewriter.getContext()));\n  function.getBody().emplaceBlock();\n  for (auto arg : ins) { function.getBody().addArgument(arg.getType(), loc); }\n\n  IRMapping mapping;\n  for (auto args_pair : llvm::zip(ins, function.getBody().getArguments())) {\n    mapping.map(std::get<0>(args_pair), std::get<1>(args_pair));\n  }\n  rewriter.setInsertionPointToStart(&function.getBody().front());\n  ImplicitLocOpBuilder new_block(loc, rewriter);\n  for (auto op : wrap_ops) { new_block.clone(*op, mapping); }\n\n  SmallVector<::mlir::Value, 4> mapped_results;\n  for (auto result : outs) { 
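/* collect the clone of each escaping result; these become the wrap function's return values */ 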
mapped_results.push_back(mapping.lookup(result)); }\n  rewriter.create<func::ReturnOp>(loc, mapped_results);\n  return {function, ins, map};\n};\n\nKernelLaunchOp ConsumeOpsToFunc(std::vector<Operation*>& wrap_ops, mlir::PatternRewriter& rewriter,\n                                int& name_index) {\n  if (wrap_ops.size() < 2) {\n    wrap_ops.clear();\n    return nullptr;\n  }\n  auto loc = wrap_ops.front()->getLoc();\n  OpBuilder::InsertionGuard guard(rewriter);\n\n  auto [wrap_func, wrap_ins, map] =\n      CreateWrapFuncAndReturnWithIns(loc, wrap_ops, rewriter, name_index);\n\n  auto func_name = wrap_func.getSymNameAttr();\n  std::vector<NamedAttribute> attrs;\n  for (auto attr : wrap_ops[0]->getAttrs()) {\n    auto attr_list = {\"scope_symbol_id\", \"device_tag\", \"device_name\"};\n    if (std::find(attr_list.begin(), attr_list.end(), attr.getName()) != attr_list.end()) {\n      attrs.push_back(attr);\n    }\n  }\n\n  attrs.emplace_back(rewriter.getStringAttr(\"op_name\"), func_name);\n\n  rewriter.setInsertionPointAfter(wrap_ops.back());\n  auto func = rewriter.create<KernelLaunchOp>(wrap_ops[0]->getLoc(), wrap_func,\n                                              ArrayRef<NamedAttribute>(attrs), wrap_ins);\n\n  if (failed(DumpAssembly(rewriter, func, func_name))) {\n    LOG(FATAL) << \"Failed to dump assembly to kernel launch op.\";\n  }\n  for (auto it : llvm::zip(map, wrap_ops)) {\n    auto op = std::get<1>(it);\n    auto list = std::get<0>(it);\n    if (!list.size()) {\n      op->dropAllUses();\n      rewriter.eraseOp(op);\n      continue;\n    }\n    std::vector<Value> vals;\n    for (auto idx : list) { vals.push_back(func->getResult(idx)); }\n    if (op->getNumResults() == vals.size()) {\n      rewriter.replaceOp(op, vals);\n    } else {  // if op has multiple results but only some of them are used outside, we need to\n              // handle the result mapping manually.\n      int idx = 0;\n      auto results = op->getResults();\n      for (auto it = results.begin(); it != results.end(); ++it) {\n        for (auto user : (*it).getUsers()) {\n          if (std::find(wrap_ops.begin(), wrap_ops.end(), user) == wrap_ops.end()) {\n            (*it).replaceAllUsesWith(func->getResult(list[idx]));\n            idx += 1;\n            break;\n          }\n        }\n      }\n      rewriter.eraseOp(op);\n    }\n  }\n  wrap_ops.clear();\n  return func;\n}\n\nstruct ExtractKernelLaunchTensorPattern : public mlir::OpRewritePattern<func::FuncOp> {\n  static func::FuncOp ExtractArgTensors(func::FuncOp op, mlir::PatternRewriter& rewriter) {\n    auto launcher_ctx_type = okl::LauncherContextType::get(rewriter.getContext());\n    auto return_types = op.getBody().front().back().getOperandTypes();\n    auto func_type = rewriter.getFunctionType({launcher_ctx_type}, return_types);\n\n    auto func = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getName(), func_type);\n    auto& body = func.getBody();\n\n    body.emplaceBlock();\n    body.addArgument(launcher_ctx_type, op->getLoc());\n    auto launcher_ctx = body.getArgument(0);\n\n    OpBuilder::InsertionGuard guard(rewriter);\n    rewriter.setInsertionPointToStart(&body.front());\n\n    IRMapping mapping;\n    for (const auto& arg : llvm::enumerate(op.getBody().getArguments())) {\n      auto tensor = rewriter.create<okl::GetTensorFromArgOp>(func->getLoc(), arg.value().getType(),\n                                                             launcher_ctx, arg.index());\n      mapping.map(arg.value(), tensor);\n    }\n\n    ImplicitLocOpBuilder new_block(func->getLoc(), 
rewriter);\n    for (auto& op : op.getBody().front().getOperations()) { new_block.clone(op, mapping); }\n    rewriter.eraseOp(op);\n    return func;\n  }\n\n  static func::FuncOp ExtractRetTensors(func::FuncOp op, mlir::PatternRewriter& rewriter) {\n    auto& block = op.getBody().front();\n    auto launcher_ctx = op.getArgument(0);\n    auto& return_op = block.back();\n\n    OpBuilder::InsertionGuard guard(rewriter);\n    rewriter.setInsertionPoint(&return_op);\n\n    std::vector<Value> returns;\n    for (const auto& ret_val : llvm::enumerate(return_op.getOperands())) {\n      auto new_ret = rewriter.create<okl::GetTensorAsRetOp>(\n          op->getLoc(), ret_val.value().getType(), launcher_ctx, ret_val.value(), ret_val.index());\n      returns.push_back(new_ret);\n    }\n\n    rewriter.replaceOpWithNewOp<func::ReturnOp>(&return_op, ValueRange{returns});\n    return op;\n  }\n\n  explicit ExtractKernelLaunchTensorPattern(mlir::MLIRContext* context)\n      : OpRewritePattern<func::FuncOp>(context, /*benefit=*/0) {}\n  mlir::LogicalResult matchAndRewrite(func::FuncOp op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    if (op.getBody().getNumArguments()) {\n      // skip if already converted\n      if (op.getBody().getArgument(0).getType().isa<okl::LauncherContextType>()) {\n        return success();\n      }\n    }\n    op = ExtractArgTensors(op, rewriter);\n    op = ExtractRetTensors(op, rewriter);\n    return success();\n  }\n};\n\nstruct TrimReturnAsVoidPattern : public mlir::OpRewritePattern<func::FuncOp> {\n  explicit TrimReturnAsVoidPattern(mlir::MLIRContext* context)\n      : OpRewritePattern<func::FuncOp>(context, /*benefit=*/0) {}\n  mlir::LogicalResult matchAndRewrite(func::FuncOp op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    if (op.getBody().front().back().getNumOperands() == 0) { return success(); }\n    OpBuilder::InsertionGuard guard(rewriter);\n    rewriter.setInsertionPoint(op);\n\n    auto func_type = rewriter.getFunctionType(op.getFunctionType().getInputs(), TypeRange{});\n    auto func = rewriter.create<func::FuncOp>(op.getLoc(), op.getName(), func_type);\n\n    IRMapping bvm;\n    op.getRegion().cloneInto(&func.getRegion(), bvm);\n\n    auto& old_ret = func.getBody().front().back();\n    rewriter.setInsertionPoint(&old_ret);\n    rewriter.replaceOpWithNewOp<func::ReturnOp>(&old_ret);\n    rewriter.eraseOp(op);\n    return success();\n  }\n};\n\nstruct KernelLaunchPattern : public mlir::OpRewritePattern<oneflow::Job> {\n  explicit KernelLaunchPattern(mlir::MLIRContext* context, bool trim = false)\n      : OpRewritePattern<oneflow::Job>(context, /*benefit=*/0) {}\n\n  // if the pre-packed ops are consecutive with the current op, the current op will be packed\n  // together with them.\n  virtual bool IsConsecutive(std::vector<Operation*>&, mlir::Operation*) const { return true; };\n\n  virtual bool IsPackagable(mlir::Operation* op) const {\n    return GetModuleOpFromJobBodyOp<Job>(&(*op)) && op->getAttr(\"op_name\")\n           && dyn_cast<UserOpCompatible>(op)\n           && op->getName().getStringRef() != KernelLaunchOp::getOperationName();\n  }\n\n  mlir::LogicalResult matchAndRewrite(oneflow::Job op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    auto& ops = op->getRegion(0).front();\n    if (ops.empty()) { return success(); }\n\n    int name_index = 0;\n    std::vector<Operation*> current_wrap_ops;\n    for (auto op_it 
= ops.begin(); op_it != ops.end(); ++op_it) {\n      auto current_op = &(*op_it);\n      if (!IsPackagable(current_op)) {\n        ConsumeOpsToFunc(current_wrap_ops, rewriter, name_index);\n        continue;\n      }\n\n      if (!IsConsecutive(current_wrap_ops, current_op)) {\n        ConsumeOpsToFunc(current_wrap_ops, rewriter, name_index);\n      }\n      current_wrap_ops.push_back(current_op);\n    }\n    if (!current_wrap_ops.empty()) { ConsumeOpsToFunc(current_wrap_ops, rewriter, name_index); }\n    return success();\n  }\n};\n\nstruct KernelLaunchSimplePattern : public KernelLaunchPattern {\n  explicit KernelLaunchSimplePattern(mlir::MLIRContext* context) : KernelLaunchPattern(context) {}\n\n  bool IsSameDevice(std::vector<Operation*>& ops, mlir::Operation* op) const {\n    if (ops.empty()) { return true; }\n\n    auto device_tag = op->getAttr(\"device_tag\").dyn_cast_or_null<StringAttr>();\n    auto device_name = op->getAttr(\"device_name\").dyn_cast_or_null<ArrayAttr>();\n    auto cmp_device_tag = ops.front()->getAttr(\"device_tag\").dyn_cast_or_null<StringAttr>();\n    auto cmp_device_name = ops.front()->getAttr(\"device_name\").dyn_cast_or_null<ArrayAttr>();\n\n    if (!device_tag || !device_name || !cmp_device_tag || !cmp_device_name) { return false; }\n\n    auto same_device_tag = device_tag.str() == cmp_device_tag.str();\n    auto same_device_name =\n        std::equal(device_name.begin(), device_name.end(), cmp_device_name.begin(),\n                   [](const Attribute a, const Attribute b) {\n                     auto a_str = a.dyn_cast_or_null<StringAttr>();\n                     auto b_str = b.dyn_cast_or_null<StringAttr>();\n                     if (!a_str || !b_str) { return false; }\n                     return a_str.str() == b_str.str();\n                   });\n\n    return same_device_tag && same_device_name;\n  }\n\n  bool IsConsecutive(std::vector<Operation*>& ops, mlir::Operation* op) const override {\n    if (ops.empty()) { return true; }\n    return IsSameDevice(ops, op);\n  }\n};\n\nstruct KernelLaunchWithCudaGraphPattern : public KernelLaunchSimplePattern {\n  explicit KernelLaunchWithCudaGraphPattern(mlir::MLIRContext* context)\n      : KernelLaunchSimplePattern(context) {}\n\n  bool IsOpCudaGraphSupport(mlir::Operation* op) const {\n    ::oneflow::okl::RegContext reg_ctx(op);\n    auto* kernel = const_cast<::oneflow::user_op::OpKernel*>(reg_ctx.GetKernel());\n    return dynamic_cast<::oneflow::user_op::CudaGraphSupport*>(kernel);\n  }\n\n  bool IsSameCudaGraphSupport(std::vector<Operation*>& ops, mlir::Operation* op) const {\n    if (ops.empty()) { return true; }\n    auto cuda_support = IsOpCudaGraphSupport(op);\n    return cuda_support == IsOpCudaGraphSupport(ops.front());\n  }\n\n  bool IsConsecutive(std::vector<Operation*>& ops, mlir::Operation* op) const override {\n    if (ops.empty()) { return true; }\n    return IsSameDevice(ops, op) && IsSameCudaGraphSupport(ops, op);\n  }\n};\n\nvoid AddLoweringToLinalgMemRefPasses(PassManager& pm) {\n  pm.addPass(createConvertToSignlessForTosaPass());\n  pm.addNestedPass<func::FuncOp>(LLVM::createRequestCWrappersPass());\n  pm.addPass(createLowerOneFlowToTosaPass());\n  pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());\n  pm.addPass(createCSEPass());\n  pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());\n  pm.addNestedPass<func::FuncOp>(tosa::createTosaToTensor());\n  pm.addNestedPass<func::FuncOp>(createLinalgElementwiseOpFusionPass());\n  // TODO: more optimization pass\n  // 
Note: OneShot bufferization with result extract realization.\n  pm.addPass(bufferization::createEmptyTensorEliminationPass());\n  pm.addPass(bufferization::createEmptyTensorToAllocTensorPass());\n\n  auto oneshot_bufferize = bufferization::createOneShotBufferizePass();\n  CHECK(\n      oneshot_bufferize\n          ->initializeOptions(\"create-deallocs=0 bufferize-function-boundaries allow-return-allocs\")\n          .succeeded());\n  pm.addPass(std::move(oneshot_bufferize));\n  pm.addPass(bufferization::createBufferResultsToOutParamsPass());\n  pm.addPass(mlir::oneflow::createEliminateAllocOpsPass());\n  pm.addPass(createCanonicalizerPass());\n}\n\nLogicalResult LowerModuleToLLVM(mlir::MLIRContext* context, ModuleOp module) {\n  mlir::PassManager pm(context);\n  mlir::oneflow::CheckEnableIRPrinting(pm);\n  AddLoweringToLinalgMemRefPasses(pm);\n  pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());\n  pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass());\n  pm.addNestedPass<func::FuncOp>(createFoldAllocToSubviewPass());\n  pm.addPass(createInsertOneFlowMemPoolPass());\n  pm.addPass(createAppendOneFlowStreamPass());\n  pm.addPass(memref::createExpandOpsPass());\n  pm.addPass(memref::createExpandStridedMetadataPass());\n  pm.addPass(createFinalizeMemRefToLLVMConversionPass());\n  pm.addPass(createLowerAffinePass());\n  pm.addPass(createConvertLinalgToLLVMPass());\n  pm.addPass(createConvertFuncToLLVMPass());\n  pm.addPass(createReconcileUnrealizedCastsPass());\n  return pm.run(module);\n}\n\n#ifdef WITH_MLIR_CUDA_CODEGEN\n\nvoid AddLoweringLinalgOnBufferToGpuWithStdPasses(PassManager& pm) {\n  pm.addNestedPass<func::FuncOp>(createConvertLinalgToParallelLoopsPass());\n  pm.addNestedPass<func::FuncOp>(createGpuMapParallelLoopsPass());\n  pm.addNestedPass<func::FuncOp>(createParallelLoopToGpuPass());\n  pm.addNestedPass<func::FuncOp>(createGpuLauchSinkIndexComputationsPass());\n  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());\n  pm.addNestedPass<func::FuncOp>(createCSEPass());\n  pm.addNestedPass<func::FuncOp>(createFoldAllocToSubviewPass());\n  pm.addPass(createInsertOneFlowMemPoolPass());\n  pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());\n  pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass());\n  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());\n  pm.addNestedPass<func::FuncOp>(createCSEPass());\n  pm.addNestedPass<func::FuncOp>(createGpuCopyArgPass());\n}\n\nvoid AddAdheringCubinToGpuModulePasses(PassManager& pm) {\n  pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass());\n  pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass());\n  pm.addNestedPass<gpu::GPUModuleOp>(createLowerGpuOpsToNVVMOpsPass());\n  pm.addNestedPass<gpu::GPUModuleOp>(createNVVMToCubinPass());\n}\n\nvoid AddLoweringGpuToLLVMPasses(PassManager& pm) {\n  pm.addPass(createFinalizeMemRefToLLVMConversionPass());\n  pm.addPass(createLowerAffinePass());\n  pm.addPass(createAppendOneFlowStreamPass());\n  pm.addPass(createGpuToLLVMConversionPass());\n  pm.addPass(createMgpuToOneFlowStreamPass());\n  pm.addPass(createReconcileUnrealizedCastsPass());\n}\n\nLogicalResult LowerModuleToCUDALLVM(mlir::MLIRContext* context, ModuleOp module) {\n  InitializeLLVMNVPTXBackend();\n  mlir::PassManager pm(context);\n  mlir::oneflow::CheckEnableIRPrinting(pm);\n  AddLoweringToLinalgMemRefPasses(pm);\n  AddLoweringLinalgOnBufferToGpuWithStdPasses(pm);\n  pm.addPass(memref::createExpandOpsPass());\n  pm.addPass(memref::createExpandStridedMetadataPass());\n  
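// outline every gpu.launch region into its own gpu.module so the NVVM-to-cubin passes below\n  // can serialize it\n  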
pm.addPass(createGpuKernelOutliningPass());\n  AddAdheringCubinToGpuModulePasses(pm);\n  AddLoweringGpuToLLVMPasses(pm);\n  return pm.run(module);\n}\n\n#endif  // WITH_MLIR_CUDA_CODEGEN\n\nvoid populateWrapOpsToKernelLaunchPatterns(::mlir::RewritePatternSet& patterns,\n                                           const std::string& mode) {\n  if (mode == wrap_mode::SIMPLE) {\n    patterns.add<KernelLaunchSimplePattern>(patterns.getContext());\n  } else if (mode == wrap_mode::CUDA_GRAPH) {\n#ifdef WITH_CUDA_GRAPHS\n    patterns.add<KernelLaunchWithCudaGraphPattern>(patterns.getContext());\n#else\n    patterns.add<KernelLaunchPattern>(patterns.getContext());\n#endif\n  } else {\n    LOG(FATAL) << \"Found an unsupported mode in wrap-ops-to-kernel-launch pass\";\n  }\n}\n\nvoid populateFuserForExistingOp(::mlir::RewritePatternSet& patterns) {\n  populateForwardOpPatterns(patterns);\n  rewrites::populateRewrites(patterns);\n  constraints::populateConstraints(patterns);\n  populateNormalizationOpPatterns(patterns);\n  patterns.add<FusedConsecutiveAddPattern<Add2Op>>(patterns.getContext());\n  patterns.add<FusedConsecutiveAddPattern<AddNOp>>(patterns.getContext());\n}\n\nvoid populateAutoNhwcPatterns(::mlir::RewritePatternSet& patterns) {\n  bool enable_nhwc = ::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_PREFER_NHWC\", false);\n  if (enable_nhwc) {\n    patterns.add<AutoNhwcPattern>(patterns.getContext());\n    patterns.add<AutoNhwcEliminateRedundantTransposePattern>(patterns.getContext());\n  }\n}\n\nvoid populateGpuHelperPatterns(::mlir::RewritePatternSet& patterns) {\n  patterns.add<ReplaceCopyWithGPUPattern>(patterns.getContext());\n}\n\nvoid populatePreConvertInferenceOp(::mlir::RewritePatternSet& patterns) {\n  patterns.add<ReplaceVariablePattern>(patterns.getContext());\n}\n\nvoid populateConvertInferenceOp(::mlir::RewritePatternSet& patterns) {\n  populateFuseConv2DBatchNormPattern(patterns);\n}\n\nvoid populatePostConvertInferenceOp(::mlir::RewritePatternSet& patterns) {\n  patterns.add<ReplaceVariableIrPattern>(patterns.getContext());\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/SBP/SBPAttributes.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/SBP/SBPDialect.h\"\n#include \"OneFlow/SBP/SBPAttributes.h\"\n#include \"llvm/ADT/SmallVector.h\"\n#include \"llvm/ADT/TypeSwitch.h\"\n#include \"llvm/Support/Casting.h\"\n#include \"mlir/IR/Attributes.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/DialectImplementation.h\"\n#include \"mlir/Support/LogicalResult.h\"\n\nusing namespace mlir;\n\nLogicalResult parseSBP(AsmParser& parser, ArrayAttr& args) {\n  if (failed(parser.parseLSquare())) { return failure(); }\n  if (succeeded(parser.parseOptionalRSquare())) {\n    args = parser.getBuilder().getArrayAttr({});\n    return success();\n  }\n  llvm::SmallVector<Attribute> res;\n  llvm::SmallVector<Attribute> nd_list;\n\n  auto parserListElem = [&](llvm::SmallVector<Attribute>& list) {\n    auto loc = parser.getCurrentLocation();\n    if (failed(parser.parseAttribute(list.emplace_back()))) {\n      parser.emitError(loc, \"failed to parse an attribute here\");\n      return failure();\n    }\n    if (list.back().dyn_cast<sbp::SplitAttr>() || list.back().dyn_cast<sbp::BroadcastAttr>()\n        || list.back().dyn_cast<sbp::PartialSumAttr>() || list.back().dyn_cast<sbp::AnyAttr>()) {\n      return success();\n    }\n    parser.emitError(loc, \"failed to parse a sbp attribute here\");\n    return failure();\n  };\n\n  auto parserList = [&]() {\n    nd_list.clear();\n    if (parser.parseCommaSeparatedList([&]() { return parserListElem(nd_list); })\n        || parser.parseRSquare()) {\n      return failure();\n    }\n    res.emplace_back(parser.getBuilder().getArrayAttr(nd_list));\n    return success();\n  };\n\n  if (parser.parseCommaSeparatedList([&]() {\n        if (succeeded(parser.parseOptionalLSquare())) { return parserList(); }\n        return parserListElem(res);\n      })\n      || parser.parseRSquare()) {\n    return failure();\n  }\n  args = parser.getBuilder().getArrayAttr(res);\n  return success();\n}\nvoid printSBP(AsmPrinter& printer, ArrayAttr args) { printer << args; }\n\n#define GET_ATTRDEF_CLASSES\n#include \"OneFlow/SBPAttributes.cpp.inc\"\nnamespace mlir {\n\nnamespace sbp {\n\nvoid SBPDialect::registerAttributes() {\n  addAttributes<\n#define GET_ATTRDEF_LIST\n#include \"OneFlow/SBPAttributes.cpp.inc\"\n      >();\n}\n\n}  // namespace sbp\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/SBP/SBPDialect.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/SBP/SBPDialect.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"OneFlow/SBPDialect.cpp.inc\"\n#include \"mlir/IR/Dialect.h\"\n#include \"mlir/IR/TypeRange.h\"\n\nnamespace mlir {\n\nnamespace sbp {\n\nvoid SBPDialect::initialize() { registerAttributes(); }\n\n}  // namespace sbp\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/SBP/SBPImporter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/SBP/SBPImporter.h\"\n\n#include <vector>\n#include <string>\n\nnamespace mlir {\nnamespace oneflow {\n\nmlir::LogicalResult SBPTranslation::PrintSbpAttrToString(mlir::Attribute sbp_attr,\n                                                         std::string& sbp) {\n  if (auto sbp_s_attr = sbp_attr.dyn_cast<mlir::sbp::SplitAttr>()) {\n    sbp = \"S(\" + std::to_string(sbp_s_attr.getAxis()) + \")\";\n  } else if (auto sbp_b_attr = sbp_attr.dyn_cast<mlir::sbp::BroadcastAttr>()) {\n    sbp = \"B\";\n  } else if (auto sbp_p_attr = sbp_attr.dyn_cast<mlir::sbp::PartialSumAttr>()) {\n    sbp = \"P\";\n  } else if (auto sbp_p_attr = sbp_attr.dyn_cast<mlir::sbp::AnyAttr>()) {\n    sbp = \"\";\n  } else {\n    return mlir::failure();\n  }\n  return mlir::success();\n}\nmlir::Attribute SBPTranslation::ConvertSBPToString(mlir::Builder& builder,\n                                                   mlir::sbp::ParallelSignatureAttr& parallel) {\n  std::vector<std::string> list;\n  for (auto output : parallel.getOutputs()) {\n    if (auto nd_outputs = output.dyn_cast<mlir::ArrayAttr>()) {\n      for (auto nd_output : nd_outputs) {\n        std::string sbp;\n        if (failed(SBPTranslation::PrintSbpAttrToString(nd_output, sbp))) return {};\n        list.push_back(sbp);\n      }\n    } else {\n      std::string sbp;\n      if (failed(SBPTranslation::PrintSbpAttrToString(output, sbp))) return {};\n      list.push_back(sbp);\n    }\n  }\n  return builder.getStrArrayAttr(\n      makeArrayRef(llvm::SmallVector<llvm::StringRef>(list.begin(), list.end())));\n}\n\nmlir::Attribute SBPTranslation::ConvertNdSbpToPsig(mlir::Builder& builder,\n                                                   const std::vector<std::string>& nd_sbp,\n                                                   const int nd_size) {\n  auto ctx = builder.getContext();\n  std::vector<mlir::Attribute> outputs_vec;\n  for (const auto& sbp_data : nd_sbp) {\n    mlir::Attribute attr;\n    if (sbp_data == \"\") {\n      attr = mlir::sbp::AnyAttr::get(ctx);\n    } else {\n      ::oneflow::SbpParallel sbp;\n      ParseSbpParallelFromString(sbp_data, &sbp);\n      if (sbp.has_split_parallel()) {\n        attr = mlir::sbp::SplitAttr::get(ctx, sbp.split_parallel().axis());\n      } else if (sbp.has_broadcast_parallel()) {\n        attr = mlir::sbp::BroadcastAttr::get(ctx);\n      } else if (sbp.has_partial_sum_parallel()) {\n        attr = mlir::sbp::PartialSumAttr::get(ctx);\n      } else {\n        llvm::errs() << \"Unsupported sbp type from nd_sbp: \";\n        for (const auto& sbp_data : nd_sbp) { llvm::errs() << sbp_data << \" \"; }\n        llvm::errs() << \"\\n\";\n        exit(EXIT_FAILURE);\n      }\n    }\n    outputs_vec.push_back(attr);\n  }\n\n  auto inputs = builder.getArrayAttr({});\n  mlir::ArrayAttr outputs;\n\n  std::vector<mlir::Attribute> outputs_vec_nd;\n  for (auto iter = outputs_vec.begin(); iter < 
outputs_vec.end(); iter += nd_size) {\n    outputs_vec_nd.emplace_back(\n        builder.getArrayAttr(std::vector<mlir::Attribute>(iter, iter + nd_size)));\n  }\n  outputs = builder.getArrayAttr(outputs_vec_nd);\n  return mlir::sbp::ParallelSignatureAttr::get(ctx, inputs, outputs);\n}\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Transform/AggregateOps.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OKL/OKLDialect.h\"\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/Passes.h\"\n#include \"llvm/Support/Casting.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMDialect.h\"\n#include \"mlir/IR/IRMapping.h\"\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/Transforms/DialectConversion.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n\n#include <iostream>\n#include <string>\n\nnamespace mlir {\nnamespace oneflow {\n\nstruct AggregateComputeOpsPattern : public mlir::OpRewritePattern<OutputOp> {\n  explicit AggregateComputeOpsPattern(mlir::MLIRContext* context)\n      : OpRewritePattern<OutputOp>(context, /*benefit=*/0) {}\n\n  mlir::LogicalResult matchAndRewrite(OutputOp op, mlir::PatternRewriter& rewriter) const override {\n    if (op->getNumResults() != 1) { return failure(); }\n    if (llvm::isa<oneflow::ReturnOp>(op->getNextNode())) { return failure(); }\n    // oneflow.output only have a single result\n    for (auto user : op->getResult(0).getUsers()) {\n      if (!llvm::isa<oneflow::ReturnOp>(user)) { return failure(); }\n      rewriter.setInsertionPoint(user);\n    }\n\n    auto new_val = rewriter.clone(*op)->getResults();\n    rewriter.replaceOp(op, new_val);\n    return success();\n  };\n};\n\nnamespace {\n\nclass AggregateComputeOpsPass : public AggregateComputeOpsPassBase<AggregateComputeOpsPass> {\n  void getDependentDialects(DialectRegistry& registry) const override {\n    registry.insert<oneflow::OneFlowDialect>();\n  }\n\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    patterns.add<AggregateComputeOpsPattern>(patterns.getContext());\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\n}  // namespace\n\nstd::unique_ptr<Pass> createAggregateComputeOpsPass() {\n  return std::make_unique<AggregateComputeOpsPass>();\n}\n\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Transform/AutoNHWCOps.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/Transform/TransposeHelpers.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nbool Conv2DOp::IsNCHW() { return this->getDataFormat().str() == \"channels_first\"; }\n\nllvm::DenseSet<Value> Conv2DOp::OperandsToTranspose() {\n  if (this->get_addToOutput()) {\n    return {this->getIn(), this->getWeight(), this->get_addToOutput()};\n  } else {\n    return {this->getIn(), this->getWeight()};\n  }\n}\n\nllvm::DenseSet<Value> Conv2DOp::ResultsToTranspose() { return {this->getOut()}; }\n\nllvm::SmallVector<Value, 4> Conv2DOp::NchwToNhwc(llvm::SmallVector<Value, 4> value,\n                                                 PatternRewriter& rewriter) {\n  auto conv_op = *this;\n  SmallVector<Value, 4> operands;\n  operands.push_back(value[0]);\n  operands.push_back(value[1]);\n  if (conv_op.getBias()) operands.push_back(conv_op.getBias());\n  if (this->get_addToOutput()) { operands.push_back(value[2]); }\n  NamedAttrList attributes = conv_op->getAttrs();\n  attributes.set(conv_op.getDataFormatAttrName(), rewriter.getStringAttr(\"channels_last\"));\n  auto res = rewriter\n                 .create<oneflow::Conv2DOp>(conv_op.getLoc(), getNHWCResultTypes(conv_op), operands,\n                                            attributes)\n                 ->getResults();\n  llvm::SmallVector<Value, 4> results;\n  results.push_back(res[0]);\n  return results;\n}\n\nbool BiasAddOp::IsNCHW() { return this->getAxisAttr().getValue().getSExtValue() == 1; }\n\nllvm::DenseSet<Value> BiasAddOp::OperandsToTranspose() { return {this->getA()}; }\n\nllvm::DenseSet<Value> BiasAddOp::ResultsToTranspose() { return {this->getOut()}; }\n\nllvm::SmallVector<Value, 4> BiasAddOp::NchwToNhwc(llvm::SmallVector<Value, 4> value,\n                                                  PatternRewriter& rewriter) {\n  auto bias_add_op = *this;\n  SmallVector<Value, 4> operands;\n  operands.push_back(value[0]);\n  operands.push_back(bias_add_op.getB());\n  NamedAttrList attributes = bias_add_op->getAttrs();\n  attributes.set(bias_add_op.getAxisAttrName(), rewriter.getSI32IntegerAttr(3));\n  auto res = rewriter\n                 .create<oneflow::BiasAddOp>(bias_add_op.getLoc(), getNHWCResultTypes(bias_add_op),\n                                             operands, attributes)\n                 ->getResults();\n  llvm::SmallVector<Value, 4> results;\n  results.push_back(res[0]);\n  return results;\n}\n\nbool BroadcastAddOp::IsNCHW() { return false; }\n\nllvm::DenseSet<Value> BroadcastAddOp::OperandsToTranspose() { return {this->getX(), this->getY()}; }\n\nllvm::DenseSet<Value> BroadcastAddOp::ResultsToTranspose() { return {this->getZ()}; }\n\nllvm::SmallVector<Value, 4> BroadcastAddOp::NchwToNhwc(llvm::SmallVector<Value, 4> values,\n                                                       PatternRewriter& rewriter) {\n  auto broadcast_op = *this;\n  NamedAttrList attributes = 
broadcast_op->getAttrs();\n  auto res = rewriter\n                 .create<oneflow::BroadcastAddOp>(\n                     broadcast_op.getLoc(), getNHWCResultTypes(broadcast_op), values, attributes)\n                 .getZ();\n  llvm::SmallVector<Value, 4> results;\n  results.push_back(res);\n  return results;\n}\n\nbool NormalizationOp::IsNCHW() { return this->getAxisAttr().getValue().getSExtValue() == 1; }\n\nbool NormalizationInferenceOp::IsNCHW() {\n  return this->getAxisAttr().getValue().getSExtValue() == 1;\n}\n\nllvm::DenseSet<Value> NormalizationOp::OperandsToTranspose() { return {this->getX()}; }\n\nllvm::DenseSet<Value> NormalizationInferenceOp::OperandsToTranspose() { return {this->getX()}; }\n\nllvm::DenseSet<Value> NormalizationOp::ResultsToTranspose() { return {this->getY()}; }\n\nllvm::DenseSet<Value> NormalizationInferenceOp::ResultsToTranspose() { return {this->getY()}; }\n\nllvm::SmallVector<Value, 4> NormalizationOp::NchwToNhwc(llvm::SmallVector<Value, 4> value,\n                                                        PatternRewriter& rewriter) {\n  auto normalization_op = *this;\n  SmallVector<Value, 4> operands;\n  operands.push_back(value[0]);\n  if (normalization_op.getMovingMean()) operands.push_back(normalization_op.getMovingMean());\n  if (normalization_op.getMovingVariance())\n    operands.push_back(normalization_op.getMovingVariance());\n  operands.push_back(normalization_op.getGamma());\n  operands.push_back(normalization_op.getBeta());\n  if (normalization_op.get_addToOutput()) operands.push_back(normalization_op.get_addToOutput());\n  NamedAttrList attributes = normalization_op->getAttrs();\n  attributes.set(normalization_op.getAxisAttrName(), rewriter.getSI32IntegerAttr(3));\n  auto res =\n      rewriter\n          .create<oneflow::NormalizationOp>(\n              normalization_op.getLoc(), getNHWCResultTypes(normalization_op), operands, attributes)\n          ->getResults();\n  llvm::SmallVector<Value, 4> results;\n  results.push_back(res[0]);\n  return results;\n}\n\nllvm::SmallVector<Value, 4> NormalizationInferenceOp::NchwToNhwc(llvm::SmallVector<Value, 4> value,\n                                                                 PatternRewriter& rewriter) {\n  auto normalization_op = *this;\n  SmallVector<Value, 4> operands;\n  operands.push_back(value[0]);\n  if (normalization_op.getMovingMean()) operands.push_back(normalization_op.getMovingMean());\n  if (normalization_op.getMovingVariance())\n    operands.push_back(normalization_op.getMovingVariance());\n  operands.push_back(normalization_op.getGamma());\n  operands.push_back(normalization_op.getBeta());\n  if (normalization_op.get_addToOutput()) operands.push_back(normalization_op.get_addToOutput());\n  NamedAttrList attributes = normalization_op->getAttrs();\n  attributes.set(normalization_op.getAxisAttrName(), rewriter.getSI32IntegerAttr(3));\n  auto res =\n      rewriter\n          .create<oneflow::NormalizationInferenceOp>(\n              normalization_op.getLoc(), getNHWCResultTypes(normalization_op), operands, attributes)\n          ->getResults();\n  llvm::SmallVector<Value, 4> results;\n  results.push_back(res[0]);\n  return results;\n}\n\nbool MaxPool2DOp::IsNCHW() { return this->getDataFormat().str() == \"channels_first\"; }\n\nllvm::DenseSet<Value> MaxPool2DOp::OperandsToTranspose() { return {this->getX()}; }\n\nllvm::DenseSet<Value> MaxPool2DOp::ResultsToTranspose() {\n  return {this->getY(), this->getIndice()};\n}\n\nllvm::SmallVector<Value, 4> 
MaxPool2DOp::NchwToNhwc(llvm::SmallVector<Value, 4> value,\n                                                    PatternRewriter& rewriter) {\n  auto max_pool_2d_op = *this;\n  SmallVector<Value, 4> operands;\n  operands.push_back(value[0]);\n  NamedAttrList attributes = max_pool_2d_op->getAttrs();\n  attributes.set(max_pool_2d_op.getDataFormatAttrName(), rewriter.getStringAttr(\"channels_last\"));\n  auto res =\n      rewriter\n          .create<oneflow::MaxPool2DOp>(max_pool_2d_op.getLoc(), getNHWCResultTypes(max_pool_2d_op),\n                                        operands, attributes)\n          ->getResults();\n  llvm::SmallVector<Value, 4> results;\n  results.push_back(res[0]);\n  results.push_back(res[1]);\n  return results;\n}\n\nbool ReluOp::IsNCHW() { return false; }\n\nllvm::DenseSet<Value> ReluOp::OperandsToTranspose() { return {this->getX()}; }\n\nllvm::DenseSet<Value> ReluOp::ResultsToTranspose() { return {this->getY()}; }\n\nllvm::SmallVector<Value, 4> ReluOp::NchwToNhwc(llvm::SmallVector<Value, 4> value,\n                                               PatternRewriter& rewriter) {\n  auto relu_op = *this;\n  SmallVector<Value, 4> operands{value[0]};\n  auto res = rewriter\n                 .create<oneflow::ReluOp>(relu_op.getLoc(), getNHWCResultTypes(relu_op), operands,\n                                          relu_op->getAttrs())\n                 ->getResults();\n  return {res[0]};\n}\n\nbool ScalarDivOp::IsNCHW() { return false; }\n\nllvm::DenseSet<Value> ScalarDivOp::OperandsToTranspose() { return {this->getIn()}; }\n\nllvm::DenseSet<Value> ScalarDivOp::ResultsToTranspose() { return {this->getOut()}; }\n\nllvm::SmallVector<Value, 4> ScalarDivOp::NchwToNhwc(llvm::SmallVector<Value, 4> value,\n                                                    PatternRewriter& rewriter) {\n  auto elementwise_op = *this;\n  SmallVector<Value, 4> operands{value[0]};\n  auto res =\n      rewriter\n          .create<oneflow::ScalarDivOp>(elementwise_op.getLoc(), getNHWCResultTypes(elementwise_op),\n                                        operands, elementwise_op->getAttrs())\n          ->getResults();\n  return {res[0]};\n}\n\nbool SiluOp::IsNCHW() { return false; }\n\nllvm::DenseSet<Value> SiluOp::OperandsToTranspose() { return {this->getIn()}; }\n\nllvm::DenseSet<Value> SiluOp::ResultsToTranspose() { return {this->getOut()}; }\n\nllvm::SmallVector<Value, 4> SiluOp::NchwToNhwc(llvm::SmallVector<Value, 4> value,\n                                               PatternRewriter& rewriter) {\n  auto elementwise_op = *this;\n  SmallVector<Value, 4> operands{value[0]};\n  auto res =\n      rewriter\n          .create<oneflow::SiluOp>(elementwise_op.getLoc(), getNHWCResultTypes(elementwise_op),\n                                   operands, elementwise_op->getAttrs())\n          ->getResults();\n  return {res[0]};\n}\n\nbool CastOp::IsNCHW() { return false; }\n\nllvm::DenseSet<Value> CastOp::OperandsToTranspose() { return {this->getIn()}; }\n\nllvm::DenseSet<Value> CastOp::ResultsToTranspose() { return {this->getOut()}; }\n\nllvm::SmallVector<Value, 4> CastOp::NchwToNhwc(llvm::SmallVector<Value, 4> value,\n                                               PatternRewriter& rewriter) {\n  auto elementwise_op = *this;\n  SmallVector<Value, 4> operands{value[0]};\n  auto res =\n      rewriter\n          .create<oneflow::CastOp>(elementwise_op.getLoc(), getNHWCResultTypes(elementwise_op),\n                                   operands, elementwise_op->getAttrs())\n          ->getResults();\n  return 
{res[0]};\n}\n\nbool Add2Op::IsNCHW() { return false; }\n\nllvm::DenseSet<Value> Add2Op::OperandsToTranspose() { return {this->getIn0(), this->getIn1()}; }\n\nllvm::DenseSet<Value> Add2Op::ResultsToTranspose() { return {this->getOut()}; }\n\nllvm::SmallVector<Value, 4> Add2Op::NchwToNhwc(llvm::SmallVector<Value, 4> value,\n                                               PatternRewriter& rewriter) {\n  auto add2_op = *this;\n  SmallVector<Value, 4> operands{value[0], value[1]};\n  auto res = rewriter\n                 .create<oneflow::Add2Op>(add2_op.getLoc(), getNHWCResultTypes(add2_op), operands,\n                                          add2_op->getAttrs())\n                 ->getResults();\n  return {res[0]};\n}\n\nbool ConcatOp::IsNCHW() { return this->getAxisAttr().getValue().getSExtValue() == 1; }\n\nllvm::DenseSet<Value> ConcatOp::OperandsToTranspose() {\n  llvm::DenseSet<Value> operands;\n  for (auto operand : this->getIn()) { operands.insert(operand); }\n  return operands;\n}\n\nllvm::DenseSet<Value> ConcatOp::ResultsToTranspose() { return {this->getOut()}; }\n\nllvm::SmallVector<Value, 4> ConcatOp::NchwToNhwc(llvm::SmallVector<Value, 4> values,\n                                                 PatternRewriter& rewriter) {\n  auto elementwise_op = *this;\n  NamedAttrList attributes = elementwise_op->getAttrs();\n  attributes.set(elementwise_op.getAxisAttrName(),\n                 IntegerAttr::get(rewriter.getIntegerType(64, /*isSigned=*/true),\n                                  APInt(64, 3, /*isSigned=*/true)));\n  auto out = rewriter\n                 .create<oneflow::ConcatOp>(elementwise_op.getLoc(),\n                                            getNHWCResultTypes(elementwise_op), values, attributes)\n                 .getOut();\n  return {out};\n}\n\nbool GroupNormOp::IsNCHW() { return this->getDataFormat().str() == \"channels_first\"; }\n\nllvm::DenseSet<Value> GroupNormOp::OperandsToTranspose() { return {this->getX()}; }\n\nllvm::DenseSet<Value> GroupNormOp::ResultsToTranspose() { return {this->getY()}; }\n\nllvm::SmallVector<Value, 4> GroupNormOp::NchwToNhwc(llvm::SmallVector<Value, 4> value,\n                                                    PatternRewriter& rewriter) {\n  auto group_norm_op = *this;\n  SmallVector<Value, 4> operands;\n  operands.push_back(value[0]);\n  if (this->getAffine()) {\n    operands.push_back(this->getBeta());\n    operands.push_back(this->getGamma());\n  }\n  NamedAttrList attributes = group_norm_op->getAttrs();\n  attributes.set(group_norm_op.getDataFormatAttrName(), rewriter.getStringAttr(\"channels_last\"));\n  auto res =\n      rewriter\n          .create<oneflow::GroupNormOp>(group_norm_op.getLoc(), getNHWCResultTypes(group_norm_op),\n                                        operands, attributes)\n          ->getResults();\n  llvm::SmallVector<Value, 4> results;\n  results.push_back(res[0]);\n  results.push_back(res[1]);\n  results.push_back(res[2]);\n  return results;\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Transform/AutoNhwc.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <iostream>\n#include <string>\n#include \"OneFlow/Passes.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace {\n\nclass AutoNhwcPass : public AutoNhwcPassBase<AutoNhwcPass> {\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    oneflow::populateAutoNhwcPatterns(patterns);\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\n}  // namespace\n\nstd::unique_ptr<Pass> createAutoNhwcPass() { return std::make_unique<AutoNhwcPass>(); }\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Transform/BufferHostRegister.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <iostream>\n#include <string>\n#include \"OneFlow/Passes.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace {\n\nclass BufferHostRegisterPass : public BufferHostRegisterPassBase<BufferHostRegisterPass> {\n  void runOnOperation() override {\n    getOperation()->walk([&](memref::AllocOp alloc) {\n      auto ranked_type = alloc.getResult().getType().cast<MemRefType>();\n      Type unranked_type =\n          UnrankedMemRefType::get(ranked_type.getElementType(), ranked_type.getMemorySpace());\n      OpBuilder builder(alloc);\n      builder.setInsertionPointAfter(alloc);\n      Value casted = builder.create<memref::CastOp>(alloc->getLoc(), unranked_type, alloc);\n      builder.create<gpu::HostRegisterOp>(alloc->getLoc(), casted);\n    });\n  }\n};\n\nclass GpuCopyArgPass : public GpuCopyArgPassBase<GpuCopyArgPass> {\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    oneflow::populateGpuHelperPatterns(patterns);\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\n}  // namespace\nstd::unique_ptr<Pass> createBufferHostRegisterPass() {\n  return std::make_unique<BufferHostRegisterPass>();\n}\n\nstd::unique_ptr<Pass> createGpuCopyArgPass() { return std::make_unique<GpuCopyArgPass>(); }\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Transform/CSEWithAttributesIgnored.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <iostream>\n#include <string>\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/Passes.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace {\n\nstatic const auto MAGIC_OP_NAME = \"ONEFLOW_ERASE_MAGIC\";\nstatic const auto MAGIC_SCOPE_SYMBOL_ID = 77777;\n\nstruct EraseAttributes : public mlir::OpInterfaceRewritePattern<UserOpCompatible> {\n  explicit EraseAttributes(mlir::MLIRContext* context, std::shared_ptr<CSEState> state)\n      : OpInterfaceRewritePattern<UserOpCompatible>(context, /*benefit=*/1), state_{state} {}\n  mlir::LogicalResult matchAndRewrite(UserOpCompatible op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    if (op->getAttrOfType<StringAttr>(OpTrait::IsOpConfCompatible<void>::getOpNameAttr())\n            .getValue()\n            .str()\n        != MAGIC_OP_NAME) {\n      if (state_) {\n        state_->opNames[op] =\n            op->getAttrOfType<StringAttr>(OpTrait::IsOpConfCompatible<void>::getOpNameAttr());\n        state_->scopeSymbolIDs[op] = op->getAttrOfType<IntegerAttr>(\n            OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr());\n      }\n      op->setAttr(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(),\n                  rewriter.getStringAttr(MAGIC_OP_NAME));\n      op->setAttr(OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr(),\n                  rewriter.getI64IntegerAttr(MAGIC_SCOPE_SYMBOL_ID));\n      return success();\n    } else {\n      return failure();\n    }\n  }\n\n private:\n  std::shared_ptr<CSEState> state_;\n};\n\nstruct PutAttributes : public mlir::OpInterfaceRewritePattern<UserOpCompatible> {\n  explicit PutAttributes(mlir::MLIRContext* context, std::shared_ptr<CSEState> state)\n      : OpInterfaceRewritePattern<UserOpCompatible>(context, /*benefit=*/1), state_{state} {}\n  mlir::LogicalResult matchAndRewrite(UserOpCompatible op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    if (op->getAttrOfType<StringAttr>(OpTrait::IsOpConfCompatible<void>::getOpNameAttr())\n            .getValue()\n            .str()\n        == MAGIC_OP_NAME) {\n      if (state_) {\n        op->setAttr(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(), state_->opNames[op]);\n        op->setAttr(OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr(),\n                    state_->scopeSymbolIDs[op]);\n      }\n      return success();\n    } else {\n      return failure();\n    }\n  }\n\n private:\n  std::shared_ptr<CSEState> state_;\n};\n\nclass CSEWithAttributesIgnored : public CSEWithAttributesIgnoredBase<CSEWithAttributesIgnored> {\n public:\n  explicit CSEWithAttributesIgnored() {}\n  explicit CSEWithAttributesIgnored(std::shared_ptr<CSEState> state) : state_(state) {}\n  void runOnOperation() override {\n    Operation* op = 
getOperation();\n    RewritePatternSet patterns(op->getContext());\n    patterns.add<EraseAttributes>(op->getContext(), state_);\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n\n private:\n  std::shared_ptr<CSEState> state_;\n};\n\nclass CSEPutAttributes : public CSEPutAttributesBase<CSEPutAttributes> {\n public:\n  explicit CSEPutAttributes() {}\n  explicit CSEPutAttributes(std::shared_ptr<CSEState> state) { state_ = state; }\n\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    patterns.add<PutAttributes>(op->getContext(), state_);\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n\n private:\n  std::shared_ptr<CSEState> state_;\n};\n\n}  // namespace\n\nstd::unique_ptr<Pass> createCSEWithAttributesIgnored() {\n  return std::make_unique<CSEWithAttributesIgnored>();\n}\n\nstd::unique_ptr<Pass> createCSEPutAttributes() { return std::make_unique<CSEPutAttributes>(); }\n\nstd::pair<std::unique_ptr<Pass>, std::unique_ptr<Pass>> createCSEPasses(\n    std::shared_ptr<CSEState> state) {\n  return std::make_pair(std::make_unique<CSEWithAttributesIgnored>(state),\n                        std::make_unique<CSEPutAttributes>(state));\n}\n\nvoid registerCSEPasses(std::shared_ptr<CSEState> state) {\n  ::mlir::registerPass([state]() -> std::unique_ptr<::mlir::Pass> {\n    return std::make_unique<CSEWithAttributesIgnored>(state);\n  });\n  ::mlir::registerPass([state]() -> std::unique_ptr<::mlir::Pass> {\n    return std::make_unique<CSEPutAttributes>(state);\n  });\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Transform/ConvertInferenceOp.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <iostream>\n#include <string>\n#include \"OneFlow/Passes.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n#include \"OneFlow/OneFlowPatternUtils.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace {\n\nclass PreConvertInferenceOpPass : public PreConvertInferenceOpPassBase<PreConvertInferenceOpPass> {\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    oneflow::populatePreConvertInferenceOp(patterns);\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\nclass ConvertInferenceOpPass : public ConvertInferenceOpPassBase<ConvertInferenceOpPass> {\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    oneflow::populateConvertInferenceOp(patterns);\n    oneflow::rewrites::populateRewrites(patterns);\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\nclass PostConvertInferenceOpPass\n    : public PostConvertInferenceOpPassBase<PostConvertInferenceOpPass> {\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    oneflow::populatePostConvertInferenceOp(patterns);\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\n}  // namespace\n\nstd::unique_ptr<Pass> createPreConvertInferenceOpPass() {\n  return std::make_unique<PreConvertInferenceOpPass>();\n}\n\nstd::unique_ptr<Pass> createConvertInferenceOpPass() {\n  return std::make_unique<ConvertInferenceOpPass>();\n}\n\nstd::unique_ptr<Pass> createPostConvertInferenceOpPass() {\n  return std::make_unique<PostConvertInferenceOpPass>();\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Transform/EliminateAllocOps.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowPDLLPatterns.h\"\n#include \"OneFlow/Passes.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n\nnamespace mlir {\nnamespace oneflow {\n\nnamespace {\nclass EliminateAllocOpsPass : public EliminateAllocOpsPassBase<EliminateAllocOpsPass> {\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    mlir::oneflow::populateAllocEliminationPatterns(patterns);\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\n}  // namespace\n\nstd::unique_ptr<Pass> createEliminateAllocOpsPass() {\n  return std::make_unique<EliminateAllocOpsPass>();\n}\n\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Transform/FuncOps.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/Passes.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Transforms/DialectConversion.h\"\n#include \"mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n\nnamespace mlir {\n\nnamespace func {\nstruct FuncConversionToOneFlow final : public OpConversionPattern<FuncOp> {\n public:\n  using OpConversionPattern<FuncOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(FuncOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    auto func = rewriter.create<oneflow::Job>(op.getLoc(), op.getName(), op.getFunctionType());\n    rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());\n    rewriter.eraseOp(op);\n    return success();\n  }\n};\n\nstruct ReturnConversionToOneFlow final : public OpConversionPattern<ReturnOp> {\n public:\n  using OpConversionPattern<ReturnOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(ReturnOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    rewriter.replaceOpWithNewOp<oneflow::ReturnOp>(op,\n                                                   /* operands */ op.getOperands());\n    return success();\n  }\n};\n}  // namespace func\n\nnamespace oneflow {\nstruct JobConversionToFunc final : public OpConversionPattern<Job> {\n public:\n  using OpConversionPattern<Job>::OpConversionPattern;\n  LogicalResult matchAndRewrite(Job op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    auto func = rewriter.create<func::FuncOp>(op.getLoc(), op.getName(), op.getFunctionType());\n    rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());\n    rewriter.eraseOp(op);\n    return success();\n  }\n};\n\nstruct ReturnConversionToFunc final : public OpConversionPattern<ReturnOp> {\n public:\n  using OpConversionPattern<ReturnOp>::OpConversionPattern;\n  LogicalResult matchAndRewrite(ReturnOp op, OpAdaptor adaptor,\n                                ConversionPatternRewriter& rewriter) const override {\n    rewriter.replaceOpWithNewOp<func::ReturnOp>(op,\n                                                /* operands */ op.getOperands());\n    return success();\n  }\n};\n\nnamespace {\n\nclass OneFlowJobToFuncPass : public OneFlowJobToFuncPassBase<OneFlowJobToFuncPass> {\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    ConversionTarget target(getContext());\n    target.addLegalDialect<mlir::func::FuncDialect>();\n    RewritePatternSet patterns(&getContext());\n    patterns.add<oneflow::JobConversionToFunc, oneflow::ReturnConversionToFunc>(op->getContext());\n    if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) {\n      
signalPassFailure();\n      LOG(ERROR) << \"Failed to convert oneflow.Job to func.func\";\n      getOperation()->dump();\n    }\n  }\n};\n\nclass FuncToOneFlowJobPass : public FuncToOneFlowJobPassBase<FuncToOneFlowJobPass> {\n  void getDependentDialects(::mlir::DialectRegistry& registry) const override {\n    registry.insert<oneflow::OneFlowDialect>();\n  }\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    ConversionTarget target(getContext());\n    target.addLegalDialect<mlir::oneflow::OneFlowDialect>();\n    RewritePatternSet patterns(&getContext());\n    patterns.add<func::FuncConversionToOneFlow, func::ReturnConversionToOneFlow>(op->getContext());\n    if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) {\n      signalPassFailure();\n      LOG(ERROR) << \"Failed to convert func.func to oneflow.Job\";\n      getOperation()->dump();\n    }\n  }\n};\n\n}  // namespace\n\nstd::unique_ptr<Pass> createOneFlowJobToFuncPass() {\n  return std::make_unique<OneFlowJobToFuncPass>();\n}\n\nstd::unique_ptr<Pass> createFuncToOneFlowJobPass() {\n  return std::make_unique<FuncToOneFlowJobPass>();\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Transform/GroupMatMulOps.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowOps.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\ntemplate<typename OpTy>\nbool isLinearMatmulOp(OpTy op) {\n  const bool isAlphaOne = op.getAlpha().convertToDouble() == 1.0;\n  const bool isLinear = op.getTransposeA() == false && op.getTransposeB() == true;\n  const bool hasNoAddToOutput = !op.get_addToOutput();\n  const bool isCUDA = op.getDeviceTag() == \"cuda\";\n  return isAlphaOne && isLinear && hasNoAddToOutput && isCUDA;\n}\n\nbool MatmulOp::isLinear() { return isLinearMatmulOp(*this); }\n\nValue MatmulOp::matMulGetX() { return getA(); }\n\nValue MatmulOp::matMulGetW() { return getB(); }\n\nValue MatmulOp::matMulGetY() { return getOut(); }\n\nbool BroadcastMatmulOp::isLinear() { return isLinearMatmulOp(*this); }\n\nValue BroadcastMatmulOp::matMulGetX() { return getA(); }\n\nValue BroadcastMatmulOp::matMulGetW() { return getB(); }\n\nValue BroadcastMatmulOp::matMulGetY() { return getOut(); }\n\nbool BiasAddOp::isLastDim() {\n  return getAxis() == -1 || getAxis() == getOut().getType().cast<ShapedType>().getRank() - 1;\n}\n\nValue BiasAddOp::biasAddGetBias() { return getB(); }\n\nValue BiasAddOp::biasAddGetOut() { return getOut(); }\n\nValue BroadcastAddOp::biasAddGetBias() { return getY(); }\n\nValue BroadcastAddOp::biasAddGetOut() { return getZ(); }\n\nbool BroadcastAddOp::isLastDim() { return true; }\n\nValue FusedMatmulBiasOp::matMulGetX() { return getX(); }\n\nValue FusedMatmulBiasOp::matMulGetW() { return getWeight(); }\n\nValue FusedMatmulBiasOp::matMulGetY() { return getOut(); }\n\nnamespace {\n\nbool shouldGroupFusedMatmulBiasOp(FusedMatmulBiasOp& op) {\n  return !op.get_addToOutput() && op.getDeviceTag() == \"cuda\"\n         && op.getAlpha().convertToDouble() == 1.0;\n}\n\n}  // namespace\n\nbool FusedMatmulBiasOp::isLinear() { return shouldGroupFusedMatmulBiasOp(*this); }\n\nbool FusedMatmulBiasOp::isLastDim() { return shouldGroupFusedMatmulBiasOp(*this); }\n\nValue FusedMatmulBiasOp::biasAddGetBias() { return getBias(); }\n\nValue FusedMatmulBiasOp::biasAddGetOut() { return getOut(); }\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Transform/JITPasses.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/OneFlowUtils.h\"\n#include \"OneFlow/Passes.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n#include \"mlir/Dialect/Tensor/IR/Tensor.h\"\n#include \"mlir/InitAllDialects.h\"\n#include \"mlir/Parser/Parser.h\"\n#include \"mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h\"\n#include \"mlir/Bytecode/BytecodeWriter.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace {\n\n// general lowering path:\n// 1. outline linalg ops to a func.func and an oneflow.jit op\n// 2. bufferize the func.func and update oneflow.jit op's tmp buffer size\n\n// 1. collect ops to outline\n// 2. create func.func jit ops to call\n// 3. replace the usages with jit ops' results\n\n// entries: non-oneflow ops which have operands are from oneflow ops\n// exits: result consumed by oneflow ops\n\n// NOTE: we assume all arg values are produced by an oneflow op and won't be an argument\n\nNamedAttrList GetJitOpAttributes(Builder& rewriter, StringRef op_name, int32_t input_size,\n                                 int32_t output_size, Operation* op) {\n  NamedAttrList attributes;\n  attributes.set(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr(),\n                 OpTrait::IsOpConfCompatible<void>::getDeviceTag(op));\n  attributes.set(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr(),\n                 OpTrait::IsOpConfCompatible<void>::getDeviceName(op));\n  if (auto hierarchy = OpTrait::IsOpConfCompatible<void>::getHierarchy(op)) {\n    attributes.set(OpTrait::IsOpConfCompatible<void>::getHierarchyAttr(), hierarchy);\n  }\n  attributes.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(),\n                 rewriter.getStringAttr(op_name));\n  if (auto scope_symbol_id = OpTrait::IsOpConfCompatible<void>::getScopeSymbolID(op)) {\n    attributes.set(OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr(), scope_symbol_id);\n  }\n  return attributes;\n}\n\nbool isOneFlowOp(Operation* op) { return llvm::dyn_cast<OneFlowDialect>(op->getDialect()); }\nclass Outliner {\n private:\n  OpBuilder& builder;\n  Block* body;\n  llvm::DenseSet<Operation*>& visitedOps;\n  std::queue<Operation*> worklist{};\n  void cloneOpsToNewBody(Operation* op, bool defer = false) {\n    if (visitedOps.contains(op)) { return; }\n    for (auto operand : op->getOperands()) {\n      if (!mapping.lookupOrNull(operand)) {\n        if (auto defOp = operand.getDefiningOp()) {\n          if (isOneFlowOp(defOp)) {\n            entries.insert(operand);\n            auto arg = body->addArgument(operand.getType(), operand.getLoc());\n            mapping.map(operand, arg);\n            mappingReversed.map(arg, operand);\n          } else {\n            cloneOpsToNewBody(defOp, true);\n          }\n        }\n      }\n    }\n    ImplicitLocOpBuilder nb(op->getLoc(), builder);\n    nb.clone(*op, mapping);\n    
visitedOps.insert(op);\n\n    for (auto& use : op->getUses()) {\n      auto owner = use.getOwner();\n      if (isOneFlowOp(owner)) {\n        exits.insert(use.get());\n      } else {\n        if (defer) {\n          worklist.push(owner);\n        } else {\n          cloneOpsToNewBody(owner);\n        }\n      }\n    }\n    if (!defer) {\n      while (!worklist.empty()) {\n        auto op = worklist.front();\n        worklist.pop();\n        cloneOpsToNewBody(op);\n      }\n    }\n  }\n\n public:\n  Outliner(OpBuilder& builder, Block* body, Operation* op, llvm::DenseSet<Operation*>& visitedOps)\n      : builder{builder}, body{body}, visitedOps{visitedOps} {\n    cloneOpsToNewBody(op);\n  }\n\n  IRMapping mapping{};\n  IRMapping mappingReversed{};\n  llvm::DenseSet<Value> entries{}, exits{};\n};\n\nstatic std::string JITOpNamePrefix = \"JITOpGenerated\";\nint64_t getCountJITFunction() {\n  static std::atomic_int64_t countJITFunction = 0;\n  return countJITFunction.fetch_add(1);\n}\n\nnamespace {\n\nstd::function<void(mlir::MLIRContext* mlir_ctx, mlir::ModuleOp module)> getLowerFunction(\n    const StringAttr& device_tag) {\n  auto device_tag_str = device_tag.str();\n#ifdef WITH_MLIR_CUDA_CODEGEN\n  if (device_tag_str == \"cuda\") {\n    return [](mlir::MLIRContext* mlir_ctx, mlir::ModuleOp module) {\n      CHECK(mlir::succeeded(mlir::oneflow::LowerModuleToCUDALLVM(mlir_ctx, module)))\n          << \"fail to lower OneFlow to CUDA LLVM\";\n    };\n  }\n#endif  // WITH_MLIR_CUDA_CODEGEN\n  if (device_tag_str == \"cpu\") {\n    return [](mlir::MLIRContext* mlir_ctx, mlir::ModuleOp module) {\n      CHECK(mlir::succeeded(mlir::oneflow::LowerModuleToLLVM(mlir_ctx, module)))\n          << \"fail to lower OneFlow to LLVM\";\n    };\n  }\n  LOG(FATAL) << \"Fail to match lowering function with device tag name: \" << device_tag_str;\n}\nstd::string convertFuncToByte(func::FuncOp& func) {\n  std::string byte;\n  llvm::raw_string_ostream os_byte(byte);\n  mlir::writeBytecodeToFile(func, os_byte);\n  return byte;\n}\n\nstd::string lowerFuncToLLVMByte(const std::string& raw_byte, const StringAttr& device_tag) {\n  mlir::DialectRegistry registry;\n  mlir::registerAllDialects(registry);\n  registry.insert<mlir::oneflow::OneFlowDialect>();\n  mlir::MLIRContext mlir_ctx(registry);\n\n  mlir::OwningOpRef<mlir::ModuleOp> module =\n      ::mlir::parseSourceString<mlir::ModuleOp>(raw_byte, &mlir_ctx);\n  mlir::registerLLVMDialectTranslation(registry);\n  if (::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_STDOUT\", false)) { module->print(llvm::outs()); }\n  getLowerFunction(device_tag)(&mlir_ctx, *module);\n  if (::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_STDOUT\", false)) { module->print(llvm::outs()); }\n  (*module)->setAttr(jit::RAW_GRAPH, StringAttr::get(&mlir_ctx, raw_byte));\n\n  std::string byte;\n  llvm::raw_string_ostream os_byte(byte);\n  mlir::writeBytecodeToFile(*module, os_byte);\n  return byte;\n}\n\n}  // namespace\n\nclass OutlineJitFunctionPass : public OutlineJitFunctionPassBase<OutlineJitFunctionPass> {\n  void runOnOperation() override {\n    llvm::DenseSet<Operation*> entryOps, visitedOps;\n    FunctionOpInterface job = getOperation();\n    auto& operations = job.getFunctionBody().front().getOperations();\n\n    for (auto& op : operations) {\n      if (llvm::dyn_cast<OneFlowDialect>(op.getDialect())) {\n        for (auto result : op.getResults()) {\n          for (auto user : result.getUsers()) {\n            if (!isOneFlowOp(user)) { entryOps.insert(user); }\n          }\n        }\n      
}\n    }\n\n    OpBuilder builder{&getContext()};\n    for (auto entryOp : entryOps) {\n      if (visitedOps.contains(entryOp)) { continue; }\n      OpBuilder::InsertionGuard guard(builder);\n      auto block = new Block();\n      builder.setInsertionPointToStart(block);\n      auto outliner = Outliner(builder, block, entryOp, visitedOps);\n\n      SmallVector<::mlir::Value, 4> entries, exits, mappedExits;\n      SmallVector<Type, 4> argumentTypes, resultTypes;\n\n      for (Value exit : outliner.exits) {\n        exits.push_back(exit);\n        mappedExits.push_back(outliner.mapping.lookup(exit));\n        resultTypes.push_back(exit.getType());\n      }\n      builder.setInsertionPointToEnd(block);\n      builder.create<func::ReturnOp>(entryOp->getLoc(), mappedExits);\n\n      for (auto argument : block->getArguments()) {\n        if (auto found = outliner.mappingReversed.lookupOrNull(argument)) {\n          entries.push_back(found);\n          argumentTypes.push_back(argument.getType());\n        } else {\n          job->emitError() << \"fail to outline, entry not found for argument #\"\n                           << argument.getArgNumber();\n          signalPassFailure();\n        }\n      }\n      auto funcType = builder.getFunctionType(argumentTypes, resultTypes);\n      if (auto mod = job->getParentOfType<ModuleOp>()) {\n        auto name = JITOpNamePrefix + std::to_string(getCountJITFunction());\n        SmallString<16> tempBuffer;\n        name = SanitizeIdentifier(name, tempBuffer);\n\n        builder.setInsertionPointToStart(&mod.getRegion().front());\n        auto function = builder.create<func::FuncOp>(entryOp->getLoc(), name, funcType);\n        function.getBody().push_front(block);\n\n        if (auto lastOp = exits.back().getDefiningOp()) {\n          builder.setInsertionPointAfter(lastOp);\n          NamedAttrList attributes =\n              GetJitOpAttributes(builder, name, argumentTypes.size(), resultTypes.size(),\n                                 entryOp->getOperand(0).getDefiningOp());\n          std::string byte =\n              compileToLLVM.getValue() ? lowerFuncToLLVMByte(\n                  convertFuncToByte(function),\n                  attributes.get(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr())\n                      .cast<StringAttr>())\n                                       : convertFuncToByte(function);\n          auto jitOp = builder.create<MlirJitOp>(entryOp->getLoc(), function, attributes, entries);\n          jitOp->setAttr(\"mlir_assembly\", builder.getStringAttr(byte));\n          for (const auto& old : llvm::enumerate(exits)) {\n            old.value().replaceAllUsesWith(jitOp->getResult(old.index()));\n          }\n        } else {\n          job->emitError() << \"fail to outline, nowhere to replace\";\n          signalPassFailure();\n        }\n      } else {\n        job->emitError() << \"fail to outline\";\n        signalPassFailure();\n      }\n    }\n  }\n};\n\n}  // namespace\n\nstd::unique_ptr<Pass> createOutlineJitFunctionPass() {\n  return std::make_unique<OutlineJitFunctionPass>();\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Transform/OneFlowMemPool.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/Passes.h\"\n#include \"OneFlow/Transform/OneFlowMemPool.h\"\n#include \"llvm/ADT/SmallVector.h\"\n#include \"llvm/Support/Casting.h\"\n#include \"mlir/Dialect/Arith/IR/Arith.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/Dialect/GPU/IR/GPUDialect.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMDialect.h\"\n#include \"mlir/Dialect/MemRef/IR/MemRef.h\"\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/MLIRContext.h\"\n#include \"mlir/IR/Visitors.h\"\n#include \"mlir/Support/LLVM.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n#include \"oneflow/core/common/hash_container.h\"\n#include \"oneflow/core/job/intra_job_mem_sharing_util.h\"\n#include <glog/logging.h>\n#include <algorithm>\n#include <climits>\n#include <tuple>\n#include <vector>\n\nnamespace mlir {\nnamespace oneflow {\nnamespace {\n\nType getMemPoolElemType(MLIRContext* ctx) { return IntegerType::get(ctx, 8); }\n\nconst int align_size_ = ::oneflow::kBlobBodyAlignSize;\n\nstruct AllocOpInfo {\n  Operation* val_ = nullptr;\n  int32_t start_lifetime_ = 0;\n  int32_t end_lifetime_ = 0;\n  size_t size_ = 0;\n};\n\ntemplate<class T>\nstd::vector<AllocOpInfo> getAllocInfoList(T op) {\n  std::vector<memref::AllocOp> list;\n  // collect all memref.alloc ops and gpu.launch_func ops\n  op->walk([&](memref::AllocOp alloc) { list.push_back(alloc); });\n\n  std::vector<AllocOpInfo> ret;\n  for (auto alloc : list) {\n    // compute size\n    MemRefType type = alloc->getResult(0).getType().dyn_cast<MemRefType>();\n    size_t size = type.getElementTypeBitWidth() / 8;\n    for (int64_t i : type.getShape()) { size *= i; }\n    size = (size / align_size_ + ((size % align_size_) != 0)) * align_size_;\n    // compute lifetime\n    // TODO: support lifetime analysis\n    int32_t start_lifetime = 0;\n    int32_t end_lifetime = INT_MAX;\n    ret.push_back({alloc, start_lifetime, end_lifetime, size});\n  }\n  return ret;\n}\n\nvoid replaceAllocwithSubview(func::FuncOp func, OpBuilder& builder,\n                             const ::oneflow::MemBlockResultInfo<Operation*>& ret) {\n  // create the uni memref.alloc op\n  builder.setInsertionPointToStart(&func.getBody().front());\n  auto output_type = MemRefType::get({static_cast<long>(ret.mem_block_size)},\n                                     getMemPoolElemType(func->getContext()));\n  Value mempool = builder.create<memref::AllocOp>(func->getLoc(), output_type);\n  // replace alloc with subview\n  for (auto [op, offset] : ret.regst_desc2offset) {\n    MemRefType type = op->getResult(0).getType().cast<MemRefType>();\n    Value byte_shift = builder.create<arith::ConstantIndexOp>(op->getLoc(), offset);\n    Value new_op =\n        builder.create<memref::ViewOp>(op->getLoc(), type, mempool, byte_shift, ValueRange{});\n    op->replaceAllUsesWith(ValueRange{new_op});\n    
op->erase();\n  }\n}\n\nbool isMemPool(Operation* op) {\n  auto alloc = dyn_cast<memref::AllocOp>(op);\n  if (!alloc) return false;\n  MemRefType type = alloc->getOpResult(0).getType().dyn_cast<MemRefType>();\n  if (!type) return false;\n  return type.getRank() == 1 && type.getElementType() == getMemPoolElemType(op->getContext());\n}\n\nstruct InsertOneFlowMemPoolPattern final : public OpRewritePattern<func::FuncOp> {\n  // getAllocOp(func) -> <is_legal, the_single_alloc_op_or_null>\n  std::pair<bool, memref::AllocOp> getAllocOp(func::FuncOp func) const {\n    memref::AllocOp ret;\n    auto& ops = func.getBody().front();\n    for (auto& op : ops) {\n      if (auto alloc = llvm::dyn_cast_or_null<memref::AllocOp>(op)) {\n        if (ret) return {false, ret};\n        ret = alloc;\n      }\n    }\n    return {true, ret};\n  }\n\n  MemRefType getNullMemType(mlir::PatternRewriter& rewriter) const {\n    return MemRefType::get({1}, getMemPoolElemType(rewriter.getContext()));\n  }\n\n public:\n  explicit InsertOneFlowMemPoolPattern(mlir::MLIRContext* context)\n      : OpRewritePattern<func::FuncOp>(context, /*benefit=*/0) {}\n  mlir::LogicalResult matchAndRewrite(func::FuncOp op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    auto module = op->getParentOfType<ModuleOp>();\n    // the pool was already inserted; nothing to rewrite\n    if (module && module->getAttr(codegen::mempool::MEMPOOL_ATTR_NAME)) return failure();\n\n    auto [is_legal, alloc_op] = getAllocOp(op);\n    if (!is_legal) {\n      LOG(FATAL) << \"you should run -fold-memref-alloc before insert-ofmem-pool pass\";\n      return failure();\n    }\n\n    auto type = alloc_op ? alloc_op->getResult(0).getType().dyn_cast_or_null<MemRefType>()\n                         : getNullMemType(rewriter);\n    if (!type || type.getRank() != 1 || type.getElementType() != getMemPoolElemType(op->getContext())) {\n      LOG(FATAL) << \"the alloc op failed to match memref<?xi8>\";\n      return failure();\n    }\n    llvm::SmallVector<Type> new_operand_types;\n    new_operand_types.push_back(type);\n    for (auto type : op.getFunctionType().getInputs()) { new_operand_types.push_back(type); }\n    auto function_type =\n        rewriter.getFunctionType(new_operand_types, op.getFunctionType().getResults());\n\n    auto func = rewriter.create<func::FuncOp>(op.getLoc(), op.getName(), function_type);\n    for (auto pair : op->getDialectAttrs()) { func->setAttr(pair.getName(), pair.getValue()); }\n    op.getBody().insertArgument(unsigned(0), type, op->getLoc());\n    if (alloc_op) rewriter.replaceOp(alloc_op, {op.getArgument(0)});\n    IRMapping bvm;\n    op.getRegion().cloneInto(&func.getRegion(), bvm);\n    rewriter.eraseOp(op);\n    module->setAttr(codegen::mempool::MEMPOOL_ATTR_NAME,\n                    rewriter.getI64IntegerAttr(type.getDimSize(0)));\n    return success();\n  }\n};\n\nclass InsertOneFlowMemPoolPass : public InsertOneFlowMemPoolPassBase<InsertOneFlowMemPoolPass> {\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    auto ctx = op->getContext();\n    RewritePatternSet patterns(ctx);\n    patterns.add<InsertOneFlowMemPoolPattern>(ctx);\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\nclass FoldAllocToSubviewPass : public FoldAllocToSubviewPassBase<FoldAllocToSubviewPass> {\n  void runOnOperation() override {\n    func::FuncOp op = getOperation();\n    applyFoldAlloc(op);\n  }\n};\n\n}  // namespace\n\nvoid applyFoldAlloc(func::FuncOp op) {\n  std::vector<AllocOpInfo> list;\n  // TODO-1: support cpu memory fold\n  
std::vector<AllocOpInfo> list;\n  // TODO-1: support cpu memory fold\n  // TODO-2: support multiple gpu.launch\n  op->walk([&](gpu::LaunchOp launchOp) { list = getAllocInfoList(launchOp); });\n  op->walk([&](scf::ForallOp forallOp) { list = getAllocInfoList(forallOp); });\n\n  {\n    std::vector<AllocOpInfo> body_list;\n    body_list = getAllocInfoList(op);\n    list.insert(list.end(), body_list.begin(), body_list.end());\n  }\n\n  auto ctx = op->getContext();\n  OpBuilder builder(ctx);\n\n  // Note: nothing to fold if there is no alloc op at all.\n  if (list.empty()) { return; }\n  // Note: a single alloc op whose result type is memref<?xi8> means folding has already been done.\n  if (list.size() == 1 && oneflow::isMemPool(list.front().val_)) { return; }\n\n  ::oneflow::HashMap<Operation*, std::pair<int32_t, int32_t>> val2lifetime;\n  ::oneflow::HashMap<Operation*, size_t> val2size;\n  for (const auto& info : list) {\n    val2lifetime[info.val_] = {info.start_lifetime_, info.end_lifetime_};\n    val2size[info.val_] = info.size_;\n  }\n  ::oneflow::MemBlockResultInfo<Operation*> ret;\n\n  ::oneflow::MemReusedMemSizeFirstAlgo(false, val2lifetime, val2size, &ret);\n  oneflow::replaceAllocwithSubview(op, builder, ret);\n}\n\nstd::unique_ptr<Pass> createInsertOneFlowMemPoolPass() {\n  return std::make_unique<InsertOneFlowMemPoolPass>();\n}\n\nstd::unique_ptr<Pass> createFoldAllocToSubviewPass() {\n  return std::make_unique<FoldAllocToSubviewPass>();\n}\n\n}  // namespace oneflow\n}  // namespace mlir"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Transform/OneFlowStream.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowPDLLPatterns.h\"\n#include \"OneFlow/Passes.h\"\n#include \"llvm/ADT/SmallVector.h\"\n#include \"llvm/Support/Casting.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMDialect.h\"\n#include \"mlir/Support/LLVM.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n\n#include <glog/logging.h>\n#include <functional>\n\nnamespace mlir {\nnamespace oneflow {\n\nnamespace {\n\nstruct MgpuToOneFlowStreamPattern final : public OpRewritePattern<LLVM::CallOp> {\n public:\n  explicit MgpuToOneFlowStreamPattern(mlir::MLIRContext* context)\n      : OpRewritePattern<LLVM::CallOp>(context, /*benefit=*/0) {}\n  mlir::LogicalResult matchAndRewrite(LLVM::CallOp op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    auto ptr_type = LLVM::LLVMPointerType::get(rewriter.getContext());\n    auto func = op->getParentOfType<LLVM::LLVMFuncOp>();\n    auto callee = op.getCallee();\n    if (!func || !callee) return failure();\n    Value stream = func.getArguments().back();\n    if (stream.getType() != ptr_type) {\n      LOG(ERROR) << \"failed to find stream in llvm.func block arguments\";\n      return failure();\n    }\n\n    DenseMap<StringRef,\n             std::pair<std::function<bool(LLVM::CallOp&, Value&)>,\n                       std::function<void(mlir::PatternRewriter&, LLVM::CallOp&, Value&)>>>\n        oneflow_abi = {\n            {\"mgpuStreamCreate\",\n             {[](LLVM::CallOp& op, Value& stream) { return true; },\n              [](mlir::PatternRewriter& rewriter, LLVM::CallOp& op, Value& stream) {\n                rewriter.replaceOp(op, {stream});\n              }}},\n            {\"mgpuLaunchKernel\",\n             {[](LLVM::CallOp& op, Value& stream) {\n                unsigned idx = op->getNumOperands();\n                return op.getOperand(idx - 3) != stream;\n              },\n              [](mlir::PatternRewriter& rewriter, LLVM::CallOp& op, Value& stream) {\n                unsigned idx = op->getNumOperands();\n                auto target = op.getOperand(idx - 3).getDefiningOp();\n                rewriter.replaceOp(target, {stream});\n              }}},\n            // this sync operation is created by gpu-to-llvm-pass from gpu.launch_func op.\n            {\"mgpuStreamSynchronize\",\n             {[](LLVM::CallOp& op, Value& stream) { return true; },\n              [](mlir::PatternRewriter& rewriter, LLVM::CallOp& op, Value& stream) {\n                rewriter.eraseOp(op);\n              }}},\n            {\"mgpuStreamDestroy\",\n             {[](LLVM::CallOp& op, Value& stream) { return true; },\n              [](mlir::PatternRewriter& rewriter, LLVM::CallOp& op, Value& stream) {\n                rewriter.eraseOp(op);\n              }}},\n        };\n    auto out = oneflow_abi.find(callee.value().str());\n    if (out != 
oneflow_abi.end() && out->getSecond().first(op, stream)) {\n      out->getSecond().second(rewriter, op, stream);\n      return success();\n    }\n    // the callee is not one of the mgpu ABI calls handled above; report failure so the greedy\n    // driver does not treat an untouched op as a successful rewrite\n    return failure();\n  }\n};\n\nstruct AppendOneFlowStreamPattern final : public OpRewritePattern<func::FuncOp> {\n public:\n  explicit AppendOneFlowStreamPattern(mlir::MLIRContext* context)\n      : OpRewritePattern<func::FuncOp>(context, /*benefit=*/0) {}\n  mlir::LogicalResult matchAndRewrite(func::FuncOp op,\n                                      mlir::PatternRewriter& rewriter) const override {\n    auto ptr_type = LLVM::LLVMPointerType::get(rewriter.getContext());\n    // nothing to do if the function already takes a stream pointer as its last argument\n    if (!op.getFunctionType().getInputs().empty()\n        && llvm::isa<LLVM::LLVMPointerType>(op.getFunctionType().getInputs().back()))\n      return failure();\n\n    llvm::SmallVector<Type> new_operand_type;\n    for (auto type : op.getFunctionType().getInputs()) { new_operand_type.push_back(type); }\n    new_operand_type.push_back(ptr_type);\n    auto function_type =\n        rewriter.getFunctionType(new_operand_type, op.getFunctionType().getResults());\n\n    auto func = rewriter.create<func::FuncOp>(op.getLoc(), op.getName(), function_type);\n    for (auto pair : op->getDialectAttrs()) { func->setAttr(pair.getName(), pair.getValue()); }\n    op.getBody().addArgument(ptr_type, func->getLoc());\n    IRMapping bvm;\n    op.getRegion().cloneInto(&func.getRegion(), bvm);\n    rewriter.eraseOp(op);\n    return success();\n  }\n};\n\nclass AppendOneFlowStreamPass : public AppendOneFlowStreamPassBase<AppendOneFlowStreamPass> {\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    auto ctx = op->getContext();\n    RewritePatternSet patterns(ctx);\n    patterns.add<AppendOneFlowStreamPattern>(ctx);\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\nclass MgpuToOneFlowStreamPass : public MgpuToOneFlowStreamPassBase<MgpuToOneFlowStreamPass> {\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    auto ctx = op->getContext();\n    RewritePatternSet patterns(ctx);\n    patterns.add<MgpuToOneFlowStreamPattern>(ctx);\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\n}  // namespace\n\nstd::unique_ptr<Pass> createAppendOneFlowStreamPass() {\n  return std::make_unique<AppendOneFlowStreamPass>();\n}\n\nstd::unique_ptr<Pass> createMgpuToOneFlowStreamPass() {\n  return std::make_unique<MgpuToOneFlowStreamPass>();\n}\n\n}  // namespace oneflow\n}  // namespace mlir"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Transform/OutlineAndFuse.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/Transform/OutlineAndFuse.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"OneFlow/OKL/OKLDialect.h\"\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/Passes.h\"\n#include \"OneFlow/OneFlowPDLLPatterns.h\"\n#include \"OneFlow/OneFlowPatternUtils.h\"\n#include \"llvm/Support/Casting.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMDialect.h\"\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/Pass/Pass.h\"\n#include \"mlir/Transforms/DialectConversion.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n\n#include <iostream>\n#include <string>\n\nnamespace mlir {\nnamespace oneflow {\n\nnamespace {\n\nclass WrapOpsToKernelLaunchPass : public WrapOpsToKernelLaunchPassBase<WrapOpsToKernelLaunchPass> {\n public:\n  WrapOpsToKernelLaunchPass() = default;\n  WrapOpsToKernelLaunchPass(const WrapOpsToKernelLaunchPass& other)\n      : WrapOpsToKernelLaunchPassBase(other) {}\n\n  void getDependentDialects(DialectRegistry& registry) const override {\n    registry.insert<oneflow::OneFlowDialect>();\n  }\n\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    populateWrapOpsToKernelLaunchPatterns(patterns, wrap_ops_mode_.c_str());\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n\n private:\n  Option<std::string> wrap_ops_mode_{*this, \"mode\",\n                                     llvm::cl::desc(\"the mode of this pass to wrap ops\"),\n                                     llvm::cl::init(wrap_mode::SIMPLE)};\n};\n\nclass FuseIntoExistingOpPass : public FuseIntoExistingOpPassBase<FuseIntoExistingOpPass> {\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    populateFuserForExistingOp(patterns);\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\nnamespace {\n\nBiasAddCompatible getBiasAddCompatibleOp(MatMulCompatible op) {\n  BiasAddCompatible bias_add;\n  auto self_bias_op = dyn_cast<BiasAddCompatible>(op.getOperation());\n  if (self_bias_op) /* matmul itself is also bias add op */ {\n    bias_add = self_bias_op;\n  } else /* there is bias add op */ {\n    for (auto u : op.matMulGetY().getUsers()) {\n      if (auto b = dyn_cast<BiasAddCompatible>(u)) {\n        bias_add = b;\n        break;\n      }\n    }\n  }\n  if (bias_add && bias_add.isLastDim()) {\n    return bias_add;\n  } else {\n    return BiasAddCompatible{};\n  }\n}\n\n}  // namespace\nstruct GroupMatMulPattern : public mlir::OpInterfaceRewritePattern<MatMulCompatible> {\n  explicit GroupMatMulPattern(mlir::MLIRContext* context)\n      : OpInterfaceRewritePattern<MatMulCompatible>(context, /*benefit=*/1) {}\n  mlir::LogicalResult matchAndRewrite(MatMulCompatible op,\n                                      mlir::PatternRewriter& rewriter) const override 
{\n    if (!op.isLinear()) { return failure(); }\n    auto bias_add = getBiasAddCompatibleOp(op);\n    llvm::SmallVector<MatMulCompatible, 4> all_matmuls{};\n    llvm::SmallVector<BiasAddCompatible, 4> all_bias_adds{};\n    for (auto xUser : op.matMulGetX().getUsers()) {\n      if (auto matmul = dyn_cast<MatMulCompatible>(xUser)) {\n        if (!matmul.isLinear()) { continue; }\n        auto each_bias_add = getBiasAddCompatibleOp(matmul);\n        if (each_bias_add) { all_bias_adds.push_back(each_bias_add); }\n        if (!!bias_add == !!each_bias_add) { all_matmuls.push_back(matmul); }\n      }\n    }\n    // if all_matmuls contains only the op itself, there is no other matmul to group with\n    if (all_matmuls.size() == 1) { return failure(); }\n    llvm::SmallVector<Value, 4> operands{};\n    for (auto matmul : all_matmuls) { operands.push_back(matmul.matMulGetX()); }\n    for (auto matmul : all_matmuls) { operands.push_back(matmul.matMulGetW()); }\n    for (auto bias_adds : all_bias_adds) { operands.push_back(bias_adds.biasAddGetBias()); }\n    llvm::SmallVector<Type, 4> results{};\n    for (auto matmul : all_matmuls) { results.push_back(matmul.matMulGetY().getType()); }\n    NamedAttrList attributes{};\n    attributes.set(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr(),\n                   OpTrait::IsOpConfCompatible<void>::getDeviceTag(op));\n    attributes.set(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr(),\n                   OpTrait::IsOpConfCompatible<void>::getDeviceName(op));\n    if (auto hierarchy = OpTrait::IsOpConfCompatible<void>::getHierarchy(op)) {\n      attributes.set(OpTrait::IsOpConfCompatible<void>::getHierarchyAttr(), hierarchy);\n    }\n    if (auto scope_symbol_id = OpTrait::IsOpConfCompatible<void>::getScopeSymbolID(op)) {\n      attributes.set(OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr(), scope_symbol_id);\n    }\n    attributes.set(OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),\n                   rewriter.getDenseI32ArrayAttr({static_cast<int>(all_matmuls.size()),\n                                                  static_cast<int>(all_matmuls.size()),\n                                                  static_cast<int>(all_bias_adds.size())}));\n    attributes.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(),\n                   rewriter.getStringAttr(\n                       \"grouped_matmul_\" + OpTrait::IsOpConfCompatible<void>::getOpName(op).str()));\n    auto grouped_matmul =\n        rewriter.create<GroupedMatmulBiasOp>(op->getLoc(), results, operands, attributes);\n    if (all_bias_adds.empty()) {\n      for (const auto& matmul : llvm::enumerate(all_matmuls)) {\n        matmul.value().matMulGetY().replaceAllUsesWith(grouped_matmul.getYs()[matmul.index()]);\n      }\n    } else {\n      CHECK(all_bias_adds.size() == all_matmuls.size());\n      for (const auto& bias_add : llvm::enumerate(all_bias_adds)) {\n        bias_add.value().biasAddGetOut().replaceAllUsesWith(\n            grouped_matmul.getYs()[bias_add.index()]);\n      }\n    }\n    return success();\n  }\n};\n\nclass GroupMatMulPass : public GroupMatMulBase<GroupMatMulPass> {\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    patterns.add<GroupMatMulPattern>(op->getContext());\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\nstruct GroupNormActivationPattern : public OpRewritePattern<GroupNormOp> {\n  explicit GroupNormActivationPattern(MLIRContext* 
context)\n      : OpRewritePattern<GroupNormOp>(context, /*benefit=*/1) {}\n  LogicalResult matchAndRewrite(oneflow::GroupNormOp op, PatternRewriter& rewriter) const override {\n    if (op.getActivation() == \"none\") {\n      llvm::SmallVector<Operation*, 4> act_ops{};\n      for (auto& u : op.getY().getUses()) {\n        if (auto act_op = dyn_cast<oneflow::SiluOp>(u.getOwner())) { act_ops.push_back(act_op); }\n      }\n      // without a silu user there is nothing to fuse; bail out so the pattern does not keep\n      // recreating the op without making progress\n      if (act_ops.empty()) { return failure(); }\n      NamedAttrList attributes(op->getAttrs());\n      attributes.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(),\n                     rewriter.getStringAttr(OpTrait::IsOpConfCompatible<void>::getOpName(op).str()\n                                            + \"_with_activation\"));\n      attributes.set(\"activation\", rewriter.getStringAttr(\"silu\"));\n      auto gn_with_act = rewriter.create<GroupNormOp>(op->getLoc(), op->getResultTypes(),\n                                                      op.getOperands(), attributes);\n      for (auto act : act_ops) {\n        if (auto op = dyn_cast<oneflow::SiluOp>(act)) {\n          op.getOut().replaceAllUsesWith(gn_with_act.getY());\n        }\n      }\n      return success();\n    }\n    return failure();\n  }\n};\n\nclass FuseForwardOpsPass : public FuseForwardOpsBase<FuseForwardOpsPass> {\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    patterns.add<GroupNormActivationPattern>(op->getContext());\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\nclass FuseOpsWithBackwardImplPass\n    : public FuseOpsWithBackwardImplBase<FuseOpsWithBackwardImplPass> {\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    populateFuseOpsWithBackwardImplPattern(patterns);\n    rewrites::populateRewrites(patterns);\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\nclass FuseNormalizationOpsPass : public FuseNormalizationOpsBase<FuseNormalizationOpsPass> {\n  void runOnOperation() override {\n    Operation* op = getOperation();\n    RewritePatternSet patterns(op->getContext());\n    populateNormalizationOpPatterns(patterns);\n    rewrites::populateRewrites(patterns);\n    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));\n  }\n};\n\n}  // namespace\n\nstd::unique_ptr<Pass> createWrapOpsToKernelLaunchPass() {\n  return std::make_unique<WrapOpsToKernelLaunchPass>();\n}\n\nstd::unique_ptr<Pass> createFuseIntoExistingOpPass() {\n  return std::make_unique<FuseIntoExistingOpPass>();\n}\n\nstd::unique_ptr<Pass> createGroupMatMul() { return std::make_unique<GroupMatMulPass>(); }\n\nstd::unique_ptr<Pass> createFuseForwardOps() { return std::make_unique<FuseForwardOpsPass>(); }\nstd::unique_ptr<Pass> createFuseOpsWithBackwardImpl() {\n  return std::make_unique<FuseOpsWithBackwardImplPass>();\n}\n\nstd::unique_ptr<Pass> createFuseNormalizationOps() {\n  return std::make_unique<FuseNormalizationOpsPass>();\n}\n\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/Transform/TraitFolder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/Passes.h\"\n#include \"llvm/ADT/SmallVector.h\"\n#include \"llvm/Support/Casting.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/Dialect/LLVMIR/LLVMDialect.h\"\n#include \"mlir/Support/LLVM.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n\n#include <glog/logging.h>\n#include <functional>\n\nnamespace mlir {\nnamespace oneflow {\nnamespace {\nclass TestOneFlowTraitFolderPass\n    : public TestOneFlowTraitFolderPassBase<TestOneFlowTraitFolderPass> {\n  void runOnOperation() override {\n    if (failed(applyPatternsAndFoldGreedily(getOperation(), RewritePatternSet(&getContext())))) {\n      exit(1);\n    }\n  }\n};\n\n}  // namespace\n\nstd::unique_ptr<Pass> createTestOneFlowTraitFolderPass() {\n  return std::make_unique<TestOneFlowTraitFolderPass>();\n}\n\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/TransposeHelpers.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <oneflow/ir/include/OneFlow/Transform/TransposeHelpers.h>\n\nnamespace mlir {\n\nnamespace oneflow {\n\nRankedTensorType getNHWCType(RankedTensorType t) {\n  return RankedTensorType::get({t.getShape()[0], t.getShape()[2], t.getShape()[3], t.getShape()[1]},\n                               t.getElementType());\n}\n\nRankedTensorType getNHWCType(Type t) { return getNHWCType(t.cast<RankedTensorType>()); }\nRankedTensorType getNHWCType(Value v) { return getNHWCType(v.getType()); }\n\nRankedTensorType getNCHWType(RankedTensorType t) {\n  return RankedTensorType::get({t.getShape()[0], t.getShape()[3], t.getShape()[1], t.getShape()[2]},\n                               t.getElementType());\n}\nRankedTensorType getNCHWType(Type t) { return getNCHWType(t.cast<RankedTensorType>()); }\nRankedTensorType getNCHWType(Value v) { return getNCHWType(v.getType()); }\n\nllvm::SmallVector<Type, 4> getNHWCResultTypes(NCHWCompatible op) {\n  llvm::SmallVector<Type, 4> result_types;\n  llvm::DenseSet<Value> transpose_result = op.ResultsToTranspose();\n  for (Value result : op->getOpResults()) {\n    Type t = result.getType();\n    if (transpose_result.find(result) != transpose_result.end()) {\n      result_types.push_back(getNHWCType(t));\n    } else {\n      result_types.push_back(t);\n    }\n  }\n  return result_types;\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/UserOpConversion.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// this file should contains functions to get operands and results with user op name and index\n\n#include \"OneFlow/UserOpConversion.h\"\n#include \"OneFlow/UserOpReflection.h\"\n#include \"oneflow/core/framework/user_op_def.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace user_op {\n\nLogicalResult saveAttrDictionaryToOpConf(DictionaryAttr attributes,\n                                         ::oneflow::OperatorConf* op_conf) {\n  if (auto scope_symbol_id =\n          attributes.get(OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr())\n              .dyn_cast_or_null<IntegerAttr>()) {\n    op_conf->set_scope_symbol_id(scope_symbol_id.getInt());\n  }\n  if (auto op_name = attributes.get(OpTrait::IsOpConfCompatible<void>::getOpNameAttr())\n                         .dyn_cast_or_null<StringAttr>()) {\n    op_conf->set_name(op_name.str());\n  }\n  auto device_tag = attributes.get(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr())\n                        .dyn_cast_or_null<StringAttr>();\n  CHECK(device_tag) << \"attr absent: \"\n                    << OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr().str();\n  op_conf->set_device_tag(device_tag.str());\n  return success();\n}\n\nLogicalResult doConvertUserOpAttributes(llvm::StringRef op_type_name, DictionaryAttr attributes,\n                                        ::oneflow::OperatorConf& op_conf) {\n  auto user_conf = op_conf.mutable_user_conf();\n  op_conf.mutable_user_conf()->set_op_type_name(op_type_name.str());\n  CHECK(saveAttrDictionaryToOpConf(attributes, &op_conf).succeeded());\n  for (auto id_attr : attributes) {\n    auto id = id_attr.getName();\n    // mlir only attrs\n    // TODO: prefix special attributes with \"oneflow.\". 
For example: `oneflow.op_type_name = \"add\"`\n    if (id.strref().equals(\"callee\")\n        || id.strref().equals(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr())\n        || id.strref().equals(OpTrait::IsOpConfCompatible<void>::getHierarchyAttr())\n        || id.strref().equals(OpTrait::IsImportCompatible<void>::getOutputLBNsAttr())\n        || id.strref().equals(OpTrait::IsAlternative<void>::getOpTypeNameAttr())\n        || id.strref().equals(\n            mlir::OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr())\n        || id.strref().equals(\n            mlir::OpTrait::AttrSizedResultSegments<void>::getResultSegmentSizeAttr())) {\n      continue;\n    } else if (id.strref().equals(\"input_sizes\") || id.strref().equals(\"output_sizes\")) {\n      continue;\n    }\n    // convert op conf attributes\n    else if (id.strref().equals(OpTrait::IsOpConfCompatible<void>::getOpNameAttr())) {\n      continue;\n    } else if (id.strref().equals(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr())) {\n      continue;\n    } else if (id.strref().equals(OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr())) {\n      continue;\n    }\n    // convert user conf attributes\n    else {\n      auto attr_name = id.str();\n      Attribute attr = id_attr.getValue();\n      auto user_attr = ::oneflow::AttrValue();\n      const ::oneflow::AttrType attr_type = queryAttrType(op_type_name.str(), attr_name);\n      if (attr_type == ::oneflow::kAtInt32) {\n        user_attr.set_at_int32(attr.dyn_cast<IntegerAttr>().getSInt());\n      } else if (attr_type == ::oneflow::kAtInt64) {\n        user_attr.set_at_int64(attr.dyn_cast<IntegerAttr>().getSInt());\n      } else if (attr_type == ::oneflow::kAtBool) {\n        user_attr.set_at_bool(attr.dyn_cast<BoolAttr>().getValue());\n      } else if (attr_type == ::oneflow::kAtFloat) {\n        user_attr.set_at_float(attr.dyn_cast<FloatAttr>().getValue().convertToFloat());\n      } else if (attr_type == ::oneflow::kAtDouble) {\n        user_attr.set_at_double(attr.dyn_cast<FloatAttr>().getValue().convertToDouble());\n      } else if (attr_type == ::oneflow::kAtString) {\n        user_attr.set_at_string(attr.dyn_cast<StringAttr>().getValue().str());\n      } else if (attr_type == ::oneflow::kAtShape) {\n        *user_attr.mutable_at_shape() = getAttrAsShape(attr);\n      } else if (attr_type == ::oneflow::kAtStride) {\n        *user_attr.mutable_at_stride() = getAttrAsStride(attr);\n      } else if (attr_type == ::oneflow::kAtDataType) {\n        const auto dt = support::FromMLIRAttrToOFDataType(attr);\n        if (succeeded(dt)) {\n          user_attr.set_at_data_type(dt.value());\n        } else {\n          LOG(FATAL) << \"fail to convert op attr to data type, key: \" + id.str();\n          return failure();\n        }\n      } else if (attr_type == ::oneflow::kAtListInt32) {\n        user_attr.mutable_at_list_int32();\n        auto ref = attr.dyn_cast<ArrayAttr>();\n        for (auto v : ref.getValue()) {\n          user_attr.mutable_at_list_int32()->add_val(v.dyn_cast<IntegerAttr>().getSInt());\n        }\n      } else if (attr_type == ::oneflow::kAtListInt64) {\n        user_attr.mutable_at_list_int64();\n        auto ref = attr.dyn_cast<ArrayAttr>();\n        for (auto v : ref.getValue()) {\n          user_attr.mutable_at_list_int64()->add_val(v.dyn_cast<IntegerAttr>().getSInt());\n        }\n      } else if (attr_type == ::oneflow::kAtListFloat) {\n        user_attr.mutable_at_list_float();\n        auto ref = 
attr.dyn_cast<ArrayAttr>();\n        for (auto v : ref.getValue()) {\n          user_attr.mutable_at_list_float()->add_val(\n              v.dyn_cast<FloatAttr>().getValue().convertToFloat());\n        }\n      } else if (attr_type == ::oneflow::kAtListDataType) {\n        for (auto v : attr.dyn_cast<ArrayAttr>().getValue()) {\n          // convert each element, not the enclosing ArrayAttr\n          const auto dt = support::FromMLIRAttrToOFDataType(v);\n          if (succeeded(dt)) {\n            user_attr.mutable_at_list_data_type()->add_val(dt.value());\n          } else {\n            LOG(FATAL) << \"fail to convert op attr to data type, key: \" + id.str();\n            return failure();\n          }\n        }\n      } else if (attr_type == ::oneflow::kAtListShape) {\n        for (auto shape_attr : attr.dyn_cast<ArrayAttr>().getValue()) {\n          ::oneflow::ShapeProto* shape_ptr = user_attr.mutable_at_list_shape()->add_val();\n          *shape_ptr = getAttrAsShape(shape_attr);\n        }\n      } else if (attr_type == ::oneflow::kAtListStride) {\n        for (auto stride_attr : attr.dyn_cast<ArrayAttr>().getValue()) {\n          ::oneflow::Int64ListProto* stride_ptr = user_attr.mutable_at_list_stride()->add_val();\n          *stride_ptr = getAttrAsStride(stride_attr);\n        }\n      } else if (attr_type == ::oneflow::kAtListString) {\n        // attrs like nd_sbp require the list to exist even if it is empty\n        user_attr.mutable_at_list_string();\n        for (auto s : attr.dyn_cast<ArrayAttr>().getValue()) {\n          user_attr.mutable_at_list_string()->add_val(s.dyn_cast<StringAttr>().getValue().str());\n        }\n      } else if (attr_type == ::oneflow::kAtComplexDouble) {\n        // TODO(lml): using an ArrayAttr to represent a complex number is not safe; needs improvement.\n        user_attr.mutable_at_complex_double();\n        auto ref = attr.dyn_cast<ArrayAttr>();\n        user_attr.mutable_at_complex_double()->set_real(\n            ref.getValue()[0].dyn_cast<FloatAttr>().getValue().convertToDouble());\n        user_attr.mutable_at_complex_double()->set_imag(\n            ref.getValue()[1].dyn_cast<FloatAttr>().getValue().convertToDouble());\n      } else {\n        return failure();\n      }\n      (*user_conf->mutable_attr())[id.str()] = user_attr;\n    }\n  }\n  return success();\n}\n\nLogicalResult ConvertUserOpAttributes(llvm::StringRef op_type_name, ValueRange operands,\n                                      DictionaryAttr attributes, ::oneflow::OperatorConf& op_conf) {\n  {\n    std::vector<std::string> keys{};\n    std::vector<int32_t> sizes{};\n    if (failed(user_op::GetFilteredSegmentKeyAndSizes<OpTrait::AttrSizedOperandSegments>(\n            op_type_name, operands.size(), attributes, keys, sizes))) {\n      LOG(FATAL) << \"fail to get filtered segment key and sizes\";\n      return failure();\n    }\n    for (const auto& s : keys) { op_conf.mutable_user_conf()->add_input_order(s); }\n  }\n  return doConvertUserOpAttributes(op_type_name, attributes, op_conf);\n}\n\nLogicalResult ConvertUserOpAttributes(Operation* op, ::oneflow::OperatorConf& op_conf) {\n  std::string op_type_name = GetOpTypeName(op);\n  {\n    std::vector<std::string> keys{};\n    std::vector<int32_t> sizes{};\n    if (failed(user_op::GetFilteredSegmentKeyAndSizes<OpTrait::AttrSizedOperandSegments>(op, keys,\n                                                                                         sizes))) {\n      op->emitError(\"fail to convert user op input order\");\n      return failure();\n    }\n    for (const auto& s : keys) { 
op_conf.mutable_user_conf()->add_input_order(s); }\n  }\n  {\n    std::vector<std::string> keys{};\n    std::vector<int32_t> sizes{};\n    if (failed(user_op::GetFilteredSegmentKeyAndSizes<OpTrait::AttrSizedResultSegments>(op, keys,\n                                                                                        sizes))) {\n      op->emitError(\"fail to convert user op output order\");\n      return failure();\n    }\n    for (const auto& s : keys) { op_conf.mutable_user_conf()->add_output_order(s); }\n  }\n  return doConvertUserOpAttributes(op_type_name, op->getAttrDictionary(), op_conf);\n}\n\nLogicalResult ConvertUserOpAttributes(Operation* op, ::oneflow::OperatorConf& op_conf,\n                                      bool is_mapping_size) {\n  auto user_conf = op_conf.mutable_user_conf();\n  std::string op_type_name = GetOpTypeName(op);\n  op_conf.mutable_user_conf()->set_op_type_name(op_type_name);\n  if (op->hasTrait<OpTrait::IsOpConfCompatible>()) {\n    if (OpTrait::IsOpConfCompatible<void>::dump_attr(op, &op_conf).failed()) {\n      return op->emitError(\"fail to save attr to op_conf\");\n    }\n  }\n\n  auto writeAttrToShape = [](mlir::Attribute& attr, ::oneflow::ShapeProto* shape) {\n    for (auto v : attr.dyn_cast<ArrayAttr>().getValue()) {\n      shape->add_dim(v.dyn_cast<IntegerAttr>().getSInt());\n    }\n  };\n\n  auto writeAttrToStride = [](mlir::Attribute& attr, ::oneflow::Int64ListProto* stride) {\n    for (auto v : attr.dyn_cast<ArrayAttr>().getValue()) {\n      stride->add_dim(v.dyn_cast<IntegerAttr>().getSInt());\n    }\n  };\n\n  for (auto id_attr : op->getAttrDictionary()) {\n    auto id = id_attr.getName();\n    // mlir only attrs\n    // TODO: prefix special attributes with \"oneflow.\". For example: `oneflow.op_type_name = \"add\"`\n    if (id.strref().equals(\"callee\")\n        || id.strref().equals(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr())\n        || id.strref().equals(OpTrait::IsOpConfCompatible<void>::getHierarchyAttr())\n        || id.strref().equals(OpTrait::IsImportCompatible<void>::getOutputLBNsAttr())\n        || id.strref().equals(OpTrait::IsAlternative<void>::getOpTypeNameAttr())\n        || id.strref().equals(\n            mlir::OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr())\n        || id.strref().equals(\n            mlir::OpTrait::AttrSizedResultSegments<void>::getResultSegmentSizeAttr())) {\n      continue;\n    } else if (id.strref().equals(\"input_sizes\") || id.strref().equals(\"output_sizes\")) {\n      continue;\n    }\n    // convert op conf attributes\n    else if (id.strref().equals(OpTrait::IsOpConfCompatible<void>::getOpNameAttr())) {\n      continue;\n    } else if (id.strref().equals(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr())) {\n      continue;\n    } else if (id.strref().equals(OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr())) {\n      continue;\n    }\n    // convert user conf attributes\n    else {\n      auto attr_name = id.str();\n      Attribute attr = id_attr.getValue();\n      auto user_attr = ::oneflow::AttrValue();\n      const ::oneflow::AttrType attr_type = user_op::queryAttrType(op_type_name, attr_name);\n      if (attr_type == ::oneflow::kAtInt32) {\n        user_attr.set_at_int32(attr.dyn_cast<IntegerAttr>().getSInt());\n      } else if (attr_type == ::oneflow::kAtInt64) {\n        user_attr.set_at_int64(attr.dyn_cast<IntegerAttr>().getSInt());\n      } else if (attr_type == ::oneflow::kAtBool) {\n        
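// BoolAttr carries a plain bool, so no signed-value extraction is needed\n        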
user_attr.set_at_bool(attr.dyn_cast<BoolAttr>().getValue());\n      } else if (attr_type == ::oneflow::kAtFloat) {\n        user_attr.set_at_float(attr.dyn_cast<FloatAttr>().getValue().convertToFloat());\n      } else if (attr_type == ::oneflow::kAtDouble) {\n        user_attr.set_at_double(attr.dyn_cast<FloatAttr>().getValue().convertToDouble());\n      } else if (attr_type == ::oneflow::kAtString) {\n        user_attr.set_at_string(attr.dyn_cast<StringAttr>().getValue().str());\n      } else if (attr_type == ::oneflow::kAtShape) {\n        writeAttrToShape(attr, user_attr.mutable_at_shape());\n      } else if (attr_type == ::oneflow::kAtStride) {\n        writeAttrToStride(attr, user_attr.mutable_at_stride());\n      } else if (attr_type == ::oneflow::kAtDataType) {\n        const auto dt = support::FromMLIRAttrToOFDataType(attr);\n        if (succeeded(dt)) {\n          user_attr.set_at_data_type(dt.value());\n        } else {\n          op->emitError() << \"fail to convert op attr to data type, key: \" + id.str();\n          return failure();\n        }\n      } else if (attr_type == ::oneflow::kAtListInt32) {\n        user_attr.mutable_at_list_int32();\n        auto ref = attr.dyn_cast<ArrayAttr>();\n        for (auto v : ref.getValue()) {\n          user_attr.mutable_at_list_int32()->add_val(v.dyn_cast<IntegerAttr>().getSInt());\n        }\n      } else if (attr_type == ::oneflow::kAtListInt64) {\n        user_attr.mutable_at_list_int64();\n        auto ref = attr.dyn_cast<ArrayAttr>();\n        for (auto v : ref.getValue()) {\n          user_attr.mutable_at_list_int64()->add_val(v.dyn_cast<IntegerAttr>().getSInt());\n        }\n      } else if (attr_type == ::oneflow::kAtListFloat) {\n        user_attr.mutable_at_list_float();\n        auto ref = attr.dyn_cast<ArrayAttr>();\n        for (auto v : ref.getValue()) {\n          user_attr.mutable_at_list_float()->add_val(\n              v.dyn_cast<FloatAttr>().getValue().convertToFloat());\n        }\n      } else if (attr_type == ::oneflow::kAtListDataType) {\n        for (auto v : attr.dyn_cast<ArrayAttr>().getValue()) {\n          // convert each element, not the enclosing ArrayAttr\n          const auto dt = support::FromMLIRAttrToOFDataType(v);\n          if (succeeded(dt)) {\n            user_attr.mutable_at_list_data_type()->add_val(dt.value());\n          } else {\n            op->emitError() << \"fail to convert op attr to data type, key: \" + id.str();\n            return failure();\n          }\n        }\n      } else if (attr_type == ::oneflow::kAtListShape) {\n        for (auto shape_attr : attr.dyn_cast<ArrayAttr>().getValue()) {\n          ::oneflow::ShapeProto* shape_ptr = user_attr.mutable_at_list_shape()->add_val();\n          writeAttrToShape(shape_attr, shape_ptr);\n        }\n      } else if (attr_type == ::oneflow::kAtListStride) {\n        for (auto stride_attr : attr.dyn_cast<ArrayAttr>().getValue()) {\n          ::oneflow::Int64ListProto* stride_ptr = user_attr.mutable_at_list_stride()->add_val();\n          writeAttrToStride(stride_attr, stride_ptr);\n        }\n      } else if (attr_type == ::oneflow::kAtListString) {\n        // attrs like nd_sbp require the list to exist even if it is empty\n        user_attr.mutable_at_list_string();\n        for (auto s : attr.dyn_cast<ArrayAttr>().getValue()) {\n          user_attr.mutable_at_list_string()->add_val(s.dyn_cast<StringAttr>().getValue().str());\n        }\n      } else if (attr_type == ::oneflow::kAtComplexDouble) {\n        // TODO(lml): using an ArrayAttr to represent a complex number is not safe; needs improvement.\n        
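// the two-element ArrayAttr stores [0] = real part, [1] = imaginary part\n        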
user_attr.mutable_at_complex_double();\n        auto ref = attr.dyn_cast<ArrayAttr>();\n        user_attr.mutable_at_complex_double()->set_real(\n            ref.getValue()[0].dyn_cast<FloatAttr>().getValue().convertToDouble());\n        user_attr.mutable_at_complex_double()->set_imag(\n            ref.getValue()[1].dyn_cast<FloatAttr>().getValue().convertToDouble());\n      } else if (attr_type == ::oneflow::kAtBytes) {\n        auto value = attr.dyn_cast<StringAttr>().getValue().str();\n        // The trailing null character also needs to be saved.\n        user_attr.mutable_at_bytes()->assign(value.data(), value.size() + 1);\n      } else {\n        op->emitError() << \"fail to convert op attr of name: \" + attr_name;\n        return failure();\n      }\n      (*user_conf->mutable_attr())[id.str()] = user_attr;\n    }\n  }\n  {\n    std::vector<std::string> keys{};\n    std::vector<int32_t> sizes{};\n    if (failed(user_op::GetFilteredSegmentKeyAndSizes<OpTrait::AttrSizedOperandSegments>(op, keys,\n                                                                                         sizes))) {\n      op->emitError(\"fail to convert user op input order\");\n      return failure();\n    }\n    for (const auto& s : keys) { op_conf.mutable_user_conf()->add_input_order(s); }\n\n    if (is_mapping_size) {\n      for (const auto it : llvm::zip(keys, sizes)) {\n        auto key = std::get<0>(it).c_str();\n        auto size = std::get<1>(it);\n        auto tar = op_conf.mutable_user_conf()->mutable_input();\n        auto val = ::oneflow::UserOpConf_ListString::default_instance();\n        tar->insert({key, val});\n        for (int i = 0; i < size; ++i) { tar->at(key).add_s(); }\n      }\n    }\n  }\n  {\n    std::vector<std::string> keys{};\n    std::vector<int32_t> sizes{};\n    if (failed(user_op::GetFilteredSegmentKeyAndSizes<OpTrait::AttrSizedResultSegments>(op, keys,\n                                                                                        sizes))) {\n      op->emitError(\"fail to convert user op output order\");\n      return failure();\n    }\n    for (const auto& s : keys) { op_conf.mutable_user_conf()->add_output_order(s); }\n    if (is_mapping_size) {\n      for (const auto it : llvm::zip(keys, sizes)) {\n        auto key = std::get<0>(it).c_str();\n        auto size = std::get<1>(it);\n        auto tar = op_conf.mutable_user_conf()->mutable_output();\n        auto val = ::oneflow::UserOpConf_ListString::default_instance();\n        tar->insert({key, val});\n        for (int i = 0; i < size; ++i) { tar->at(key).add_s(); }\n      }\n    }\n  }\n  return success();\n}\nLogicalResult ConvertUserOpInputs(llvm::StringRef op_type_name, ValueRange operands,\n                                  DictionaryAttr attributes, ::oneflow::UserOpConf* user_conf) {\n  std::vector<std::string> keys{};\n  std::vector<int32_t> sizes{};\n  CHECK(user_op::GetFilteredSegmentKeyAndSizes<OpTrait::AttrSizedOperandSegments>(\n            op_type_name, operands.size(), attributes, keys, sizes)\n            .succeeded());\n  int32_t input_idx = 0;\n  for (auto tuple : llvm::zip(keys, sizes)) {\n    auto input_key = std::get<0>(tuple);\n    auto input_size = std::get<1>(tuple);\n    for (int32_t i = 0; i < input_size; i++) {\n      auto input_s_ptr = (*user_conf->mutable_input())[input_key].mutable_s()->Add();\n      if (auto result = operands[input_idx].dyn_cast<mlir::OpResult>()) {\n        *(input_s_ptr) = GetOutputLbn(result).value();\n      } else if (auto argument = 
operands[input_idx].dyn_cast<mlir::BlockArgument>()) {\n        *(input_s_ptr) = \"BlockArgument/\" + std::to_string(argument.getArgNumber());\n      } else {\n        LOG(FATAL) << \"fail to convert MLIR result to protobuf, op_type_name: \"\n                          + op_type_name.str();\n        return failure();\n      }\n      input_idx += 1;\n    }\n  }\n  return success();\n}\n\n::oneflow::ShapeProto getAttrAsShape(mlir::Attribute& attr) {\n  ::oneflow::ShapeProto shape{};\n  for (auto v : attr.dyn_cast<ArrayAttr>().getValue()) {\n    shape.add_dim(v.dyn_cast<IntegerAttr>().getSInt());\n  }\n  return shape;\n}\n\n::oneflow::Int64ListProto getAttrAsStride(mlir::Attribute& attr) {\n  ::oneflow::Int64ListProto stride{};\n  for (auto v : attr.dyn_cast<ArrayAttr>().getValue()) {\n    stride.add_dim(v.dyn_cast<IntegerAttr>().getSInt());\n  }\n  return stride;\n}\n\n::oneflow::ParallelConf getParallelConfFromAttrDictionary(DictionaryAttr attributes) {\n  ::oneflow::ParallelConf parallel_conf{};\n  auto device_tag = attributes.get(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr())\n                        .dyn_cast_or_null<StringAttr>();\n  CHECK(device_tag) << \"attr absent: \"\n                    << OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr().str();\n  parallel_conf.set_device_tag(device_tag.str());\n  auto device_name = attributes.get(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr())\n                         .dyn_cast_or_null<ArrayAttr>();\n  CHECK(device_name) << \"attr absent: \"\n                     << OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr().str();\n  for (auto s : device_name.getValue()) {\n    parallel_conf.add_device_name(s.cast<StringAttr>().str());\n  }\n  if (auto hierarchy = attributes.get(OpTrait::IsOpConfCompatible<void>::getHierarchyAttr())\n                           .dyn_cast_or_null<ArrayAttr>()) {\n    for (auto dim : hierarchy.getValue()) {\n      parallel_conf.mutable_hierarchy()->add_dim(dim.template dyn_cast<IntegerAttr>().getInt());\n    }\n  }\n  return parallel_conf;\n}\n\n::oneflow::ParallelConf getParallelConfFromAttrs(Attribute device_name_attr,\n                                                 Attribute device_tag_attr) {\n  ::oneflow::ParallelConf parallel_conf{};\n  auto device_tag = device_tag_attr.dyn_cast_or_null<StringAttr>();\n  CHECK(device_tag) << \"attr absent: \"\n                    << OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr().str();\n  parallel_conf.set_device_tag(device_tag.str());\n  auto device_name = device_name_attr.dyn_cast_or_null<ArrayAttr>();\n  CHECK(device_name) << \"attr absent: \"\n                     << OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr().str();\n  for (auto s : device_name.getValue()) {\n    parallel_conf.add_device_name(s.cast<StringAttr>().str());\n  }\n  return parallel_conf;\n}\n\n::oneflow::DeviceType getDeviceTypeFromAttrDictionary(DictionaryAttr attributes) {\n  ::oneflow::ParallelConf parallel_conf{};\n  auto device_tag = attributes.get(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr())\n                        .dyn_cast_or_null<StringAttr>();\n  CHECK(device_tag) << \"attr absent: \"\n                    << OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr().str();\n  if (device_tag.str() == \"cpu\") {\n    return ::oneflow::DeviceType::kCPU;\n  } else if (device_tag.str() == \"cuda\") {\n    return ::oneflow::DeviceType::kCUDA;\n  } else if (device_tag.str() == \"mlu\") {\n    return ::oneflow::DeviceType::kMLU;\n  } else if (device_tag.str() == 
\"npu\") {\n    return ::oneflow::DeviceType::kNPU;\n  } else if (device_tag.str() == \"xpu\") {\n    return ::oneflow::DeviceType::kXPU;\n  } else {\n    LOG(FATAL) << \"unsupported device tag: \" << device_tag.str();\n    return ::oneflow::DeviceType::kInvalidDevice;\n  }\n}\n\n::oneflow::AttrType queryAttrType(const std::string& op_type_name, const std::string& attr_name) {\n  ::oneflow::user_op::UserOpDefWrapper op_def(support::getUserOpDef(op_type_name));\n  CHECK(op_def.IsAttrName(attr_name)) << attr_name << \" not a attr name for op: \" << op_type_name;\n  return op_def.GetAttrType(attr_name);\n}\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/OneFlow/UserOpReflection.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// this file should contains functions to get operands and results with user op name and index\n\n#include \"OneFlow/UserOpReflection.h\"\n#include \"llvm/ADT/STLExtras.h\"\n#include \"llvm/Support/Casting.h\"\n\nnamespace mlir {\n\nnamespace oneflow {\n\nnamespace user_op {\n\ntemplate<template<typename T> class Trait>\nconst std::vector<std::string>* GetFullKeys(UserOpCompatible& uc, Operation* op);\ntemplate<template<typename T> class Trait>\nstd::vector<std::string> GetFullKeys(UserOp op);\ntemplate<template<typename T> class Trait>\nstd::vector<std::string> GetFullKeys(::llvm::StringRef op_type_name);\n\ntemplate<>\nconst std::vector<std::string>* GetFullKeys<OpTrait::AttrSizedOperandSegments>(UserOpCompatible& uc,\n                                                                               Operation* op) {\n  if (auto alternative_name = dyn_cast<HasAlternativeOpTypeName>(op)) {\n    return alternative_name.inputKeys();\n  }\n  return uc.inputKeys();\n}\n\ntemplate<>\nconst std::vector<std::string>* GetFullKeys<OpTrait::AttrSizedResultSegments>(UserOpCompatible& uc,\n                                                                              Operation* op) {\n  if (auto alternative_name = dyn_cast<HasAlternativeOpTypeName>(op)) {\n    return alternative_name.outputKeys();\n  }\n  return uc.outputKeys();\n}\n\ntemplate<>\nstd::vector<std::string> GetFullKeys<OpTrait::AttrSizedOperandSegments>(UserOp op) {\n  return mlir::oneflow::support::GetInputKeys(op.getOpTypeName().str());\n}\n\ntemplate<>\nstd::vector<std::string> GetFullKeys<OpTrait::AttrSizedResultSegments>(UserOp op) {\n  return mlir::oneflow::support::GetOutputKeys(op.getOpTypeName().str());\n}\n\ntemplate<>\nstd::vector<std::string> GetFullKeys<OpTrait::AttrSizedOperandSegments>(\n    ::llvm::StringRef op_type_name) {\n  return mlir::oneflow::support::GetInputKeys(op_type_name.str());\n}\n\ntemplate<>\nstd::vector<std::string> GetFullKeys<OpTrait::AttrSizedResultSegments>(\n    ::llvm::StringRef op_type_name) {\n  return mlir::oneflow::support::GetOutputKeys(op_type_name.str());\n}\n\ntemplate<template<typename T> class Trait>\nstd::pair<unsigned, unsigned> getODSIndexAndLength(UserOpCompatible& op, unsigned index);\n\ntemplate<>\nstd::pair<unsigned, unsigned> getODSIndexAndLength<OpTrait::AttrSizedOperandSegments>(\n    UserOpCompatible& op, unsigned index) {\n  return op.getODSOperandIndexAndLength(index);\n}\n\ntemplate<>\nstd::pair<unsigned, unsigned> getODSIndexAndLength<OpTrait::AttrSizedResultSegments>(\n    UserOpCompatible& op, unsigned index) {\n  return op.getODSResultIndexAndLength(index);\n}\n\ntemplate<template<typename T> class Trait>\nStringRef GetSegmentSizeAttr();\n\ntemplate<>\nStringRef GetSegmentSizeAttr<OpTrait::AttrSizedOperandSegments>() {\n  return OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();\n}\n\ntemplate<>\nStringRef 
GetSegmentSizeAttr<OpTrait::AttrSizedResultSegments>() {\n  return OpTrait::AttrSizedResultSegments<void>::getResultSegmentSizeAttr();\n}\n\ntemplate<template<typename T> class Trait>\nint32_t GetSingleSegmentSize(Operation*);\n\ntemplate<>\nint32_t GetSingleSegmentSize<OpTrait::AttrSizedOperandSegments>(Operation* op) {\n  return op->getNumOperands();\n}\n\ntemplate<>\nint32_t GetSingleSegmentSize<OpTrait::AttrSizedResultSegments>(Operation* op) {\n  return op->getNumResults();\n}\n\ntemplate<template<typename T> class Trait>\nArrayAttr GetUserOpArgSizes(UserOp);\n\ntemplate<>\nArrayAttr GetUserOpArgSizes<OpTrait::AttrSizedOperandSegments>(UserOp op) {\n  return op.getInputSizes();\n}\n\ntemplate<>\nArrayAttr GetUserOpArgSizes<OpTrait::AttrSizedResultSegments>(UserOp op) {\n  return op.getOutputSizes();\n}\n\ntemplate<template<typename T> class Trait>\nLogicalResult GetUserOpFilteredSegmentKeyAndSizes(UserOp op, std::vector<std::string>& keys,\n                                                  std::vector<int32_t>& sizes) {\n  auto full_keys = GetFullKeys<Trait>(op);\n  for (const auto& key_size_tuple : llvm::zip(full_keys, GetUserOpArgSizes<Trait>(op).getValue())) {\n    const std::string& key = std::get<0>(key_size_tuple);\n    const int32_t size =\n        std::get<1>(key_size_tuple).template cast<IntegerAttr>().getValue().getSExtValue();\n    if (size > 0) {\n      keys.push_back(key);\n      sizes.push_back(size);\n    }\n  }\n  return success();\n}\n\nSource GetOpSourceByName(Operation* op, const std::string& to_find) {\n  if (auto user_op = dyn_cast<UserOpCompatible>(op)) {\n    auto found = [&](std::vector<std::string> keys,\n                     bool find_in_results /*or in operands*/ = false) -> int {\n      auto offset = 0;\n      for (const auto& key : llvm::enumerate(keys)) {\n        if (key.value() == to_find) { return offset; }\n        offset += find_in_results ? 
user_op.getODSResultIndexAndLength(key.index()).second\n                                  : user_op.getODSOperandIndexAndLength(key.index()).second;\n      }\n      return -1;\n    };\n\n    if (auto alternative_name = dyn_cast<HasAlternativeOpTypeName>(op)) {\n      if (auto offset = found(*alternative_name.inputKeys()); offset != -1) {\n        return {Source::INPUT, offset};\n      }\n      if (auto offset = found(*alternative_name.outputKeys(), true); offset != -1) {\n        return {Source::OUTPUT, offset};\n      }\n    }\n\n    if (to_find == \"tmp_buffer\") { return {Source::BUFFER, 0}; }\n\n    if (auto offset = found(*user_op.inputKeys()); offset != -1) { return {Source::INPUT, offset}; }\n    if (auto offset = found(*user_op.outputKeys(), true); offset != -1) {\n      return {Source::OUTPUT, offset};\n    }\n\n    op->emitError(to_find + \" not found in this op\");\n    return {Source::INVALID, -1};\n  }\n  op->emitError(\"ops that do not implement UserOpCompatible are not supported\");\n  return {Source::INVALID, -1};\n}\n\ntemplate<template<typename T> class Trait>\nLogicalResult GetFilteredSegmentKeyAndSizes(Operation* op, std::vector<std::string>& keys,\n                                            std::vector<int32_t>& sizes) {\n  if (auto user_op = dyn_cast<UserOp>(op)) {\n    return GetUserOpFilteredSegmentKeyAndSizes<Trait>(user_op, keys, sizes);\n  }\n  const std::vector<std::string>* full_keys = nullptr;\n  std::vector<int32_t> full_sizes{};\n  auto uc = dyn_cast<UserOpCompatible>(op);\n  if (!uc) {\n    op->emitError(\"interface UserOpCompatible not supported\");\n    return failure();\n  }\n  full_keys = GetFullKeys<Trait>(uc, op);\n  if (op->hasTrait<Trait>()) {\n    const StringRef attr_name = GetSegmentSizeAttr<Trait>();\n    const DenseI32ArrayAttr& size_attr = op->getAttrOfType<DenseI32ArrayAttr>(attr_name);\n    if (!size_attr) return failure();\n    auto segment_sizes = size_attr.asArrayRef();\n    if (full_keys->size() != segment_sizes.size()) {\n      op->emitError() << \"fail to convert op inputs, attr_name: \" << attr_name\n                      << \", full_keys: \" << full_keys->size()\n                      << \", segment_sizes: \" << segment_sizes.size() << \", name: \" << op->getName();\n      op->dump();\n      return failure();\n    }\n    full_sizes = {segment_sizes.begin(), segment_sizes.end()};\n  } else {\n    if (full_keys->size() == 1) {\n      full_sizes.push_back(GetSingleSegmentSize<Trait>(op));\n    } else {\n      for (const auto& key : llvm::enumerate(*full_keys)) {\n        full_sizes.push_back(getODSIndexAndLength<Trait>(uc, key.index()).second);\n      }\n    }\n  }\n  for (const auto& key_size_tuple : llvm::zip(*full_keys, full_sizes)) {\n    const std::string& key = std::get<0>(key_size_tuple);\n    const int32_t size = std::get<1>(key_size_tuple);\n    if (size > 0) {\n      keys.push_back(key);\n      sizes.push_back(size);\n    }\n  }\n  return success();\n}\n\ntemplate<template<typename T> class Trait>\nLogicalResult GetFilteredSegmentKeyAndSizes(llvm::StringRef op_type_name, size_t valueSize,\n                                            DictionaryAttr attributes,\n                                            std::vector<std::string>& keys,\n                                            std::vector<int32_t>& sizes) {\n  const std::vector<std::string> full_keys = GetFullKeys<Trait>(op_type_name);\n  std::vector<int32_t> full_sizes{};\n  const StringRef attr_name = GetSegmentSizeAttr<Trait>();\n  if (auto size_attr = 
attributes.get(attr_name).dyn_cast_or_null<DenseI32ArrayAttr>()) {\n    if (!size_attr) return failure();\n    auto segment_sizes = size_attr.asArrayRef();\n    if (full_keys.size() != segment_sizes.size()) {\n      LOG(FATAL) << \"fail to convert op inputs, attr_name: \" << attr_name.str()\n                 << \", full_keys: \" << full_keys.size()\n                 << \", segment_sizes: \" << segment_sizes.size();\n      return failure();\n    };\n    full_sizes = {segment_sizes.begin(), segment_sizes.end()};\n  } else {\n    if (full_keys.size() == 1) {\n      full_sizes.push_back(valueSize);\n    } else {\n      LOG(FATAL) << \"set attr: \" << attr_name.str();\n    }\n  }\n  for (const auto& key_size_tuple : llvm::zip(full_keys, full_sizes)) {\n    const std::string& key = std::get<0>(key_size_tuple);\n    const int32_t size = std::get<1>(key_size_tuple);\n    if (size > 0) {\n      keys.push_back(key);\n      sizes.push_back(size);\n    }\n  }\n  return success();\n}\n\ntemplate LogicalResult GetFilteredSegmentKeyAndSizes<OpTrait::AttrSizedOperandSegments>(\n    Operation* op, std::vector<std::string>& keys, std::vector<int32_t>& sizes);\ntemplate LogicalResult GetFilteredSegmentKeyAndSizes<OpTrait::AttrSizedResultSegments>(\n    Operation* op, std::vector<std::string>& keys, std::vector<int32_t>& sizes);\ntemplate LogicalResult GetFilteredSegmentKeyAndSizes<OpTrait::AttrSizedOperandSegments>(\n    llvm::StringRef op_type_name, size_t valueSize, DictionaryAttr attributes,\n    std::vector<std::string>& keys, std::vector<int32_t>& sizes);\ntemplate LogicalResult GetFilteredSegmentKeyAndSizes<OpTrait::AttrSizedResultSegments>(\n    llvm::StringRef op_type_name, size_t valueSize, DictionaryAttr attributes,\n    std::vector<std::string>& keys, std::vector<int32_t>& sizes);\n\ntemplate<template<typename T> class Trait>\nArgIds<Trait>::ArgIds(Operation* op) {\n  std::vector<std::string> keys;\n  std::vector<int32_t> sizes;\n  if (failed(GetFilteredSegmentKeyAndSizes<Trait>(op, keys, sizes))) {\n    op->emitError(\"fail to get filtered segment key and sizes\");\n    exit(1);\n  }\n  for (int i = 0; i < keys.size(); i += 1) {\n    auto& key = keys[i];\n    for (size_t j = 0; j < sizes[i]; j += 1) {\n      ArgID id{key, j};\n      ids_.push_back(id);\n    }\n  }\n}\n\ntemplate<template<typename T> class Trait>\nArgIds<Trait>::ArgIds(llvm::StringRef op_type_name, size_t valueSize, DictionaryAttr attributes) {\n  std::vector<std::string> keys{};\n  std::vector<int32_t> sizes{};\n  CHECK(user_op::GetFilteredSegmentKeyAndSizes<Trait>(op_type_name, valueSize, attributes, keys,\n                                                      sizes)\n            .succeeded());\n  for (int i = 0; i < keys.size(); i += 1) {\n    auto& key = keys[i];\n    for (size_t j = 0; j < sizes[i]; j += 1) {\n      ArgID id{key, j};\n      ids_.push_back(id);\n    }\n  }\n}\n\ntemplate oneflow::user_op::ArgIds<OpTrait::AttrSizedOperandSegments>::ArgIds(Operation*);\ntemplate oneflow::user_op::ArgIds<OpTrait::AttrSizedResultSegments>::ArgIds(Operation*);\ntemplate oneflow::user_op::ArgIds<OpTrait::AttrSizedOperandSegments>::ArgIds(\n    llvm::StringRef op_type_name, size_t valueSize, DictionaryAttr attributes);\ntemplate oneflow::user_op::ArgIds<OpTrait::AttrSizedResultSegments>::ArgIds(\n    llvm::StringRef op_type_name, size_t valueSize, DictionaryAttr attributes);\n\nllvm::Optional<std::string> GetOutputLbn(OpResult result) {\n  const auto def_op = result.getDefiningOp();\n  if 
(def_op->hasTrait<OpTrait::IsImportCompatible>()) {\n    return def_op\n        ->getAttrOfType<ArrayAttr>(\n            OpTrait::IsImportCompatible<void>::getOutputLBNsAttr())[result.getResultNumber()]\n        .dyn_cast<StringAttr>()\n        .getValue()\n        .str();\n  } else {\n    std::vector<std::string> def_op_keys{};\n    std::vector<int32_t> def_op_sizes{};\n    if (failed(user_op::GetFilteredSegmentKeyAndSizes<OpTrait::AttrSizedResultSegments>(\n            def_op, def_op_keys, def_op_sizes))) {\n      def_op->emitError(\"fail to get output lbn\");\n      return llvm::None;\n    }\n    const auto result_number = result.getResultNumber();\n    uint32_t size_sum = 0;\n    for (const auto& name_size_tuple : llvm::zip(def_op_keys, def_op_sizes)) {\n      auto name = std::get<0>(name_size_tuple);\n      auto size = std::get<1>(name_size_tuple);\n      if ((size_sum + size) > result_number) {\n        const uint32_t bn_i = result_number - size_sum;\n        return OpTrait::IsOpConfCompatible<void>::getOpName(def_op).str() + \"/\" + name + \"_\"\n               + std::to_string(bn_i);\n      }\n      size_sum += size;\n    }\n  }\n  return llvm::None;\n}\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/Transform/CMakeLists.txt",
    "content": "add_mlir_library(\n  MLIROneFlowTransformDialect\n  TransformDialectExtension.cpp\n  TransformDialectInterpreter.cpp\n  TransformStateExtension.cpp\n  EXCLUDE_FROM_LIBMLIR\n  DEPENDS\n  MLIROneFlowTransformDialectExtensionIncGen\n  LINK_LIBS\n  PUBLIC\n  MLIRIR\n  MLIRPass\n  MLIRPDLDialect\n  MLIRTransformDialect\n  MLIRTransformDialectTransforms)\n"
  },
  {
    "path": "oneflow/ir/lib/Transform/TransformDialectExtension.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/Transform/OneFlowMemPool.h\"\n#include \"OneFlow/OneFlowPDLLPatterns.h\"\n#include \"Transform/TransformDialectExtension.h\"\n#include \"Transform/TransformStateExtension.h\"\n#include \"mlir/Dialect/PDL/IR/PDL.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/Dialect/Transform/IR/TransformDialect.h\"\n#include \"mlir/Dialect/Transform/IR/TransformInterfaces.h\"\n#include \"mlir/IR/OpImplementation.h\"\n#include \"llvm/ADT/STLExtras.h\"\n#include \"llvm/ADT/TypeSwitch.h\"\n#include \"llvm/Support/Compiler.h\"\n#include \"llvm/Support/raw_ostream.h\"\n#include \"mlir/Dialect/Linalg/Transforms/Transforms.h\"\n#include \"mlir/Dialect/Bufferization/Transforms/Passes.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Transforms/Passes.h\"\n\nusing namespace mlir;\nusing namespace mlir::oneflow;\nusing namespace mlir::transform;\n\nnamespace {\nstruct MemrefCopyOpFoldPatterns final : public OpRewritePattern<memref::CopyOp> {\n public:\n  using OpRewritePattern<memref::CopyOp>::OpRewritePattern;\n  LogicalResult matchAndRewrite(memref::CopyOp op, PatternRewriter& rewriter) const override {\n    if (op.getSource() == op.getTarget()) rewriter.eraseOp(op);\n    return success();\n  }\n};\n\n}  // namespace\n\nDiagnosedSilenceableFailure transform_dialect::EliminateCopyOp::applyToOne(\n    Operation* target, transform::ApplyToEachResultList& results,\n    transform::TransformState& state) {\n  MLIRContext* ctx = target->getContext();\n  RewritePatternSet patterns(ctx);\n  patterns.add<MemrefCopyOpFoldPatterns>(patterns.getContext());\n  mlir::oneflow::populateAllocEliminationPatterns(patterns);\n  SmallVector<Operation*> ops;\n  GreedyRewriteConfig config;\n  target->walk([&](Operation* nestedOp) {\n    if (target != nestedOp) ops.push_back(nestedOp);\n  });\n  LogicalResult result = applyOpPatternsAndFold(ops, std::move(patterns), config);\n  if (failed(result)) { return DiagnosedSilenceableFailure::definiteFailure(); }\n  return DiagnosedSilenceableFailure::success();\n}\n\nDiagnosedSilenceableFailure transform_dialect::ExplicitLinalgOutcomeOp::applyToOne(\n    Operation* target, transform::ApplyToEachResultList& results,\n    transform::TransformState& state) {\n  MLIRContext* ctx = target->getContext();\n  RewritePatternSet patterns(ctx);\n  linalg::populateFoldUnitExtentDimsViaSlicesPatterns(patterns);\n  SmallVector<Operation*> ops;\n  GreedyRewriteConfig config;\n  target->walk([&](Operation* nestedOp) {\n    if (target != nestedOp) ops.push_back(nestedOp);\n  });\n  LogicalResult result = applyOpPatternsAndFold(ops, std::move(patterns), config);\n  if (failed(result)) { return DiagnosedSilenceableFailure::definiteFailure(); }\n  return DiagnosedSilenceableFailure::success();\n}\n\nDiagnosedSilenceableFailure transform_dialect::CanonicalizationOp::applyToOne(\n    Operation* target, 
transform::ApplyToEachResultList& results,\n    transform::TransformState& state) {\n  MLIRContext* ctx = target->getContext();\n  RewritePatternSet patterns(ctx);\n  for (Dialect* dialect : ctx->getLoadedDialects()) dialect->getCanonicalizationPatterns(patterns);\n  for (RegisteredOperationName op : ctx->getRegisteredOperations())\n    op.getCanonicalizationPatterns(patterns, ctx);\n  SmallVector<Operation*> ops;\n  GreedyRewriteConfig config;\n  target->walk([&](Operation* nestedOp) {\n    if (target != nestedOp) ops.push_back(nestedOp);\n  });\n  LogicalResult result = applyOpPatternsAndFold(ops, std::move(patterns), config);\n  if (failed(result)) { return DiagnosedSilenceableFailure::definiteFailure(); }\n  return DiagnosedSilenceableFailure::success();\n}\n\nDiagnosedSilenceableFailure transform_dialect::FoldAllocOp::applyToOne(\n    Operation* target, transform::ApplyToEachResultList& results,\n    transform::TransformState& state) {\n  if (auto func = llvm::dyn_cast<func::FuncOp>(target)) { applyFoldAlloc(func); }\n  return DiagnosedSilenceableFailure::success();\n}\n\nDiagnosedSilenceableFailure transform_dialect::ResultsToOutParamsOp::applyToOne(\n    Operation* target, transform::ApplyToEachResultList& results,\n    transform::TransformState& state) {\n  if (auto module = llvm::dyn_cast<ModuleOp>(target)) {\n    if (failed(bufferization::promoteBufferResultsToOutParams(module, {}))) {\n      return DiagnosedSilenceableFailure::definiteFailure();\n    }\n  }\n  return DiagnosedSilenceableFailure::success();\n}\n\nDiagnosedSilenceableFailure transform_dialect::CSEOp::applyToOne(Operation* target,\n                                                                 ApplyToEachResultList& results,\n                                                                 transform::TransformState& state) {\n  auto context = target->getContext();\n  mlir::PassManager pm(context);\n  pm.addPass(createCSEPass());\n  if (failed(pm.run(target))) return mlir::emitDefiniteFailure(target, \"greedy patterns failed\");\n  return DiagnosedSilenceableFailure::success();\n}\n\nnamespace {\nclass OneFlowTransformDialectExtension\n    : public transform::TransformDialectExtension<OneFlowTransformDialectExtension> {\n public:\n  using Base::Base;\n\n  void init() {\n    declareDependentDialect<pdl::PDLDialect>();\n    registerTransformOps<\n#define GET_OP_LIST\n#include \"Transform/TransformDialectExtension.cpp.inc\"\n        >();\n    registerTypes<\n#define GET_TYPEDEF_LIST\n#include \"Transform/TransformDialectExtensionTypes.cpp.inc\"\n        >();\n  }\n};\n}  // namespace\n\n// These are automatically generated by ODS but are not used as the Transform\n// dialect uses a different dispatch mechanism to support dialect extensions.\nLLVM_ATTRIBUTE_UNUSED static OptionalParseResult generatedTypeParser(AsmParser& parser,\n                                                                     StringRef* mnemonic,\n                                                                     Type& value);\nLLVM_ATTRIBUTE_UNUSED static LogicalResult generatedTypePrinter(Type def, AsmPrinter& printer);\n\n#define GET_TYPEDEF_CLASSES\n#include \"Transform/TransformDialectExtensionTypes.cpp.inc\"\n\n#define GET_OP_CLASSES\n#include \"Transform/TransformDialectExtension.cpp.inc\"\n\nvoid mlir::oneflow::transform_dialect::registerTransformDialectExtension(\n    DialectRegistry& registry) {\n  registry.addExtensions<OneFlowTransformDialectExtension>();\n}\n"
  },
  {
    "path": "oneflow/ir/lib/Transform/TransformDialectInterpreter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"mlir/Dialect/Transform/IR/TransformInterfaces.h\"\n#include \"mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/Pass/Pass.h\"\n\nusing namespace mlir;\n\nnamespace {\n/// Simple pass that applies transform dialect ops directly contained in a\n/// module.\n\ntemplate<typename Derived>\nclass OpPassWrapper : public PassWrapper<Derived, OperationPass<>> {};\n\nclass TransformDialectInterpreterPass\n    : public transform::TransformInterpreterPassBase<TransformDialectInterpreterPass,\n                                                     OpPassWrapper> {\n public:\n  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformDialectInterpreterPass)\n\n  TransformDialectInterpreterPass() = default;\n  TransformDialectInterpreterPass(const TransformDialectInterpreterPass& pass)\n      : TransformInterpreterPassBase(pass) {}\n\n  StringRef getArgument() const override { return \"oneflow-transform-dialect-interpreter\"; }\n\n  StringRef getDescription() const override {\n    return \"apply transform dialect operations one by one\";\n  }\n\n  void findOperationsByName(Operation* root, StringRef name,\n                            SmallVectorImpl<Operation*>& operations) {\n    root->walk([&](Operation* op) {\n      if (op->getName().getStringRef() == name) { operations.push_back(op); }\n    });\n  }\n\n  void createParameterMapping(MLIRContext& context, ArrayRef<int> values,\n                              RaggedArray<transform::MappedValue>& result) {\n    SmallVector<transform::MappedValue> storage =\n        llvm::to_vector(llvm::map_range(values, [&](int v) {\n          Builder b(&context);\n          return transform::MappedValue(b.getI64IntegerAttr(v));\n        }));\n    result.push_back(std::move(storage));\n  }\n\n  void createOpResultMapping(Operation* root, StringRef name,\n                             RaggedArray<transform::MappedValue>& extraMapping) {\n    SmallVector<Operation*> operations;\n    findOperationsByName(root, name, operations);\n    SmallVector<Value> results;\n    for (Operation* op : operations) llvm::append_range(results, op->getResults());\n    extraMapping.push_back(results);\n  }\n\n  unsigned numberOfSetOptions(const Option<std::string>& ops, const ListOption<int>& params,\n                              const Option<std::string>& values) {\n    unsigned numSetValues = 0;\n    numSetValues += !ops.empty();\n    numSetValues += !params.empty();\n    numSetValues += !values.empty();\n    return numSetValues;\n  }\n\n  void runOnOperation() override {\n    unsigned firstSetOptions = numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams,\n                                                  bindFirstExtraToResultsOfOps);\n    unsigned secondSetOptions = numberOfSetOptions(bindSecondExtraToOps, bindSecondExtraToParams,\n                               
                    bindSecondExtraToResultsOfOps);\n    auto loc = UnknownLoc::get(&getContext());\n    if (firstSetOptions > 1) {\n      emitError(loc) << \"cannot bind the first extra top-level argument to \"\n                        \"multiple entities\";\n      return signalPassFailure();\n    }\n    if (secondSetOptions > 1) {\n      emitError(loc) << \"cannot bind the second extra top-level argument to \"\n                        \"multiple entities\";\n      return signalPassFailure();\n    }\n    if (firstSetOptions == 0 && secondSetOptions != 0) {\n      emitError(loc) << \"cannot bind the second extra top-level argument \"\n                        \"without binding the first\";\n      return signalPassFailure();\n    }\n\n    RaggedArray<transform::MappedValue> extraMapping;\n    if (!bindFirstExtraToOps.empty()) {\n      SmallVector<Operation*> operations;\n      findOperationsByName(getOperation(), bindFirstExtraToOps.getValue(), operations);\n      extraMapping.push_back(operations);\n    } else if (!bindFirstExtraToParams.empty()) {\n      createParameterMapping(getContext(), bindFirstExtraToParams, extraMapping);\n    } else if (!bindFirstExtraToResultsOfOps.empty()) {\n      createOpResultMapping(getOperation(), bindFirstExtraToResultsOfOps, extraMapping);\n    }\n\n    if (!bindSecondExtraToOps.empty()) {\n      SmallVector<Operation*> operations;\n      findOperationsByName(getOperation(), bindSecondExtraToOps, operations);\n      extraMapping.push_back(operations);\n    } else if (!bindSecondExtraToParams.empty()) {\n      createParameterMapping(getContext(), bindSecondExtraToParams, extraMapping);\n    } else if (!bindSecondExtraToResultsOfOps.empty()) {\n      createOpResultMapping(getOperation(), bindSecondExtraToResultsOfOps, extraMapping);\n    }\n\n    options = options.enableExpensiveChecks(enableExpensiveChecks);\n    if (failed(transform::detail::interpreterBaseRunOnOperationImpl(\n            getOperation(), getArgument(), getSharedTransformModule(), getTransformLibraryModule(),\n            extraMapping, options, transformFileName, transformLibraryFileName, debugPayloadRootTag,\n            debugTransformRootTag, getBinaryName())))\n      return signalPassFailure();\n  }\n\n  Option<bool> enableExpensiveChecks{\n      *this, \"enable-expensive-checks\", llvm::cl::init(false),\n      llvm::cl::desc(\"perform expensive checks to better report errors in the \"\n                     \"transform IR\")};\n\n  Option<std::string> bindFirstExtraToOps{\n      *this, \"bind-first-extra-to-ops\",\n      llvm::cl::desc(\"bind the first extra argument of the top-level op to \"\n                     \"payload operations of the given kind\")};\n  ListOption<int> bindFirstExtraToParams{\n      *this, \"bind-first-extra-to-params\",\n      llvm::cl::desc(\"bind the first extra argument of the top-level op to \"\n                     \"the given integer parameters\")};\n  Option<std::string> bindFirstExtraToResultsOfOps{\n      *this, \"bind-first-extra-to-results-of-ops\",\n      llvm::cl::desc(\"bind the first extra argument of the top-level op to \"\n                     \"results of payload operations of the given kind\")};\n\n  Option<std::string> bindSecondExtraToOps{\n      *this, \"bind-second-extra-to-ops\",\n      llvm::cl::desc(\"bind the second extra argument of the top-level op to \"\n                     \"payload operations of the given kind\")};\n  ListOption<int> bindSecondExtraToParams{\n      *this, \"bind-second-extra-to-params\",\n      llvm::cl::desc(\"bind the second extra argument of 
the top-level op to \"\n                     \"the given integer parameters\")};\n  Option<std::string> bindSecondExtraToResultsOfOps{\n      *this, \"bind-second-extra-to-results-of-ops\",\n      llvm::cl::desc(\"bind the second extra argument of the top-level op to \"\n                     \"results of payload operations of the given kind\")};\n\n  Option<std::string> transformFileName{\n      *this, \"transform-file-name\", llvm::cl::init(\"\"),\n      llvm::cl::desc(\"Optional filename containing a transform dialect specification to \"\n                     \"apply. If left empty, the IR is assumed to contain one top-level \"\n                     \"transform dialect operation somewhere in the module.\")};\n  Option<std::string> debugPayloadRootTag{\n      *this, \"debug-payload-root-tag\", llvm::cl::init(\"\"),\n      llvm::cl::desc(\"Select the operation with 'transform.target_tag' attribute having \"\n                     \"the given value as payload IR root. If empty select the pass anchor \"\n                     \"operation as the payload IR root.\")};\n  Option<std::string> debugTransformRootTag{\n      *this, \"debug-transform-root-tag\", llvm::cl::init(\"\"),\n      llvm::cl::desc(\"Select the operation with 'transform.target_tag' attribute having \"\n                     \"the given value as container IR for top-level transform ops. This \"\n                     \"allows user control on what transformation to apply. If empty, \"\n                     \"select the container of the top-level transform op.\")};\n  Option<std::string> transformLibraryFileName{\n      *this, \"transform-library-file-name\", llvm::cl::init(\"\"),\n      llvm::cl::desc(\"Optional name of the file containing transform dialect symbol \"\n                     \"definitions to be injected into the transform module.\")};\n};\n\nstruct TransformDialectEraseSchedulePass\n    : public PassWrapper<TransformDialectEraseSchedulePass, OperationPass<ModuleOp>> {\n  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformDialectEraseSchedulePass)\n\n  StringRef getArgument() const final { return \"oneflow-transform-dialect-erase-schedule\"; }\n\n  StringRef getDescription() const final { return \"erase transform dialect schedule from the IR\"; }\n\n  void runOnOperation() override {\n    getOperation()->walk<WalkOrder::PreOrder>([&](Operation* nestedOp) {\n      if (isa<transform::TransformOpInterface>(nestedOp)) {\n        nestedOp->erase();\n        return WalkResult::skip();\n      }\n      return WalkResult::advance();\n    });\n  }\n};\n}  // namespace\n\nnamespace mlir {\nnamespace oneflow {\nnamespace transform_dialect {\n/// Registers the test pass for erasing transform dialect ops.\nvoid registerTransformDialectEraseSchedulePass() {\n  PassRegistration<TransformDialectEraseSchedulePass> reg;\n}\n/// Registers the test pass for applying transform dialect ops.\nvoid registerTransformDialectInterpreterPass() {\n  PassRegistration<TransformDialectInterpreterPass> reg;\n}\n}  // namespace transform_dialect\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/lib/Transform/TransformStateExtension.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"Transform/TransformStateExtension.h\"\n\nusing namespace mlir;\n\nLogicalResult mlir::oneflow::transform_dialect::TransformStateExtension::updateMapping(\n    Operation* previous, Operation* updated) {\n  // Update value handles. The new ops should have at least as many results as\n  // the replacement op. Fewer results are acceptable, if those results are not\n  // mapped to any handle.\n  for (auto r = updated->getNumResults(); r < previous->getNumResults(); ++r) {\n    SmallVector<Value> handles;\n    (void)getTransformState().getHandlesForPayloadValue(previous->getResult(r), handles);\n    if (!handles.empty())\n      return emitError(previous->getLoc())\n             << \"cannot replace an op with another op producing fewer results \"\n                \"while tracking handles\";\n  }\n\n  for (auto [oldValue, newValue] : llvm::zip(previous->getResults(), updated->getResults()))\n    if (failed(replacePayloadValue(oldValue, newValue))) return failure();\n\n  // Update op handle.\n  return replacePayloadOp(previous, updated);\n}\n"
  },
  {
    "path": "oneflow/ir/llvm-in-tree.cmake",
    "content": "include(FetchContent)\nmessage(\"-- LLVM_MONO_REPO_URL: \" ${LLVM_MONO_REPO_URL})\nmessage(\"-- LLVM_MONO_REPO_MD5: \" ${LLVM_MONO_REPO_MD5})\nFetchContent_Declare(llvm_monorepo)\nFetchContent_GetProperties(llvm_monorepo)\n\nset(LLVM_INSTALL_DIR ${THIRD_PARTY_DIR}/llvm)\n\nif(NOT llvm_monorepo_POPULATED)\n  FetchContent_Populate(llvm_monorepo URL ${LLVM_MONO_REPO_URL} URL_HASH MD5=${LLVM_MONO_REPO_MD5})\nendif()\n\nset(CMAKE_INSTALL_PREFIX ${LLVM_INSTALL_DIR} CACHE STRING \"\" FORCE)\nset(LLVM_ENABLE_RTTI ON CACHE BOOL \"turn this on to make it compatible with protobuf\")\nset(LLVM_ENABLE_EH ON CACHE BOOL \"turn this on to make it compatible with half (the library)\")\nset(LLVM_ENABLE_TERMINFO OFF\n    CACHE BOOL \"disable terminfo in llvm so that oneflow doesn't need to link against it\")\nset(LLVM_BUILD_EXAMPLES OFF CACHE BOOL \"\")\nset(LLVM_BUILD_TOOLS OFF CACHE BOOL \"\")\nset(LLVM_INCLUDE_EXAMPLES OFF CACHE BOOL \"\")\nset(LLVM_INCLUDE_TESTS OFF CACHE BOOL \"\" FORCE)\nset(MLIR_INCLUDE_TESTS OFF CACHE BOOL \"\" FORCE)\nset(LLVM_INCLUDE_BENCHMARKS OFF CACHE BOOL \"\")\nset(LLVM_TARGETS_TO_BUILD host;NVPTX CACHE STRING \"\")\nset(LLVM_ENABLE_ASSERTIONS ON CACHE BOOL \"\")\nset(LLVM_ENABLE_PROJECTS mlir CACHE STRING \"\")\nset(LLVM_APPEND_VC_REV OFF CACHE BOOL \"\")\nset(LLVM_ENABLE_ZLIB OFF CACHE BOOL \"\")\nset(LLVM_INSTALL_UTILS ON CACHE BOOL \"\")\nset(LLVM_ENABLE_OCAMLDOC OFF CACHE BOOL \"\")\nset(LLVM_ENABLE_BINDINGS OFF CACHE BOOL \"\")\nset(LLVM_OPTIMIZED_TABLEGEN ON CACHE BOOL \"\" FORCE)\nset(MLIR_ENABLE_CUDA_RUNNER ${WITH_MLIR_CUDA_CODEGEN} CACHE BOOL \"\" FORCE)\nset(LLVM_MAIN_SRC_DIR ${llvm_monorepo_SOURCE_DIR}/llvm)\nset(LLVM_BINARY_DIR ${llvm_monorepo_BINARY_DIR})\nset(LLVM_TOOLS_BINARY_DIR ${llvm_monorepo_BINARY_DIR}/bin CACHE STRING \"\" FORCE)\nset(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)\nset(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include)\nset(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)\nset(MLIR_INCLUDE_DIRS \"${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}\")\n\nset(llvm_monorepo_BINARY_DIR ${llvm_monorepo_BINARY_DIR})\ninstall(TARGETS oneflow of_protoobj of_functional_obj EXPORT oneflow DESTINATION lib)\ninstall(EXPORT oneflow DESTINATION lib/oneflow)\nadd_subdirectory(${llvm_monorepo_SOURCE_DIR}/llvm ${llvm_monorepo_BINARY_DIR})\nset(LLVM_INCLUDE_DIRS ${LLVM_MAIN_SRC_DIR}/include;${llvm_monorepo_BINARY_DIR}/include)\nset(LLVM_EXTERNAL_LIT \"${llvm_monorepo_BINARY_DIR}/bin/llvm-lit\" CACHE STRING \"\" FORCE)\nset(LTDL_SHLIB_EXT ${CMAKE_SHARED_LIBRARY_SUFFIX})\nset(LLVM_LIBRARY_DIR \"${llvm_monorepo_BINARY_DIR}/lib\")\n"
  },
  {
    "path": "oneflow/ir/oneflow-extension/CMakeLists.txt",
    "content": "include_directories(${PROJECT_SOURCE_DIR}/oneflow-extension/include)\n\noneflow_add_mlir_library(\n  MLIROneFlowExtension\n  mlir_jit_op.cpp\n  mlir_jit_op_kernel.cpp\n  ir_pass.cpp\n  lr_jit.cpp\n  mlir_gen.cpp\n  DEPENDS\n  LINK_LIBS\n  PUBLIC\n  MLIRIR\n  ${dialect_libs}\n  ${translation_libs}\n  MLIRIR\n  MLIRParser\n  MLIRPass\n  MLIRSPIRVDialect\n  MLIRTranslateLib\n  MLIRSupport\n  MLIROneFlow\n  oneflow\n  MLIRExecutionEngine\n  MLIROneFlowTranslation\n  MLIROneFlowRuntime)\nmlir_check_all_link_libraries(MLIROneFlowExtension)\nadd_custom_target(mex DEPENDS MLIROneFlowExtension)\n"
  },
  {
    "path": "oneflow/ir/oneflow-extension/README.md",
    "content": "# OneFlow extension of MLIR features]\n\n## KernelLaunchOp\n\n### Stage 1\n\n- 1:1 conversion from user op to kernel launch op\n\n### Stage 2\n\n- multi user op merged into one single kernel launch op\n\n### Stage 3\n\n- oneflow-opt and similar non-python execution environment\n- multi-gpu/multi-node compilation support (in the beginning it is all single-node with broadcast SBP signature)\n\n### relationship with MlirJitOp\n\n- the graph of a MlirJitOp might contain one or multiple kernel launch op\n- an op inside the graph of MlirJitOp could be optionally lowered to a kernel launch op\n"
  },
  {
    "path": "oneflow/ir/oneflow-extension/include/CMakeLists.txt",
    "content": "add_subdirectory(OneFlow)\n"
  },
  {
    "path": "oneflow/ir/oneflow-extension/include/OneFlow/CMakeLists.txt",
    "content": "\n"
  },
  {
    "path": "oneflow/ir/oneflow-extension/include/OneFlow/JITOpInfer.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITOPINFER_H_\n#define ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITOPINFER_H_\n\n#include \"oneflow/core/framework/infer_util.h\"\n\nnamespace oneflow {\n\nnamespace ir {\n\nnamespace jit {\n\nMaybe<void> InferTensorDesc(user_op::InferContext* ctx);\nMaybe<void> SetTensorDataType(user_op::InferContext* ctx);\n\n}  // namespace jit\n\n}  // namespace ir\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_JITOPINFER_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-extension/include/OneFlow/OneFlowLRJITRegistry.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_ONEFLOW_LRJIT_REGISTRY_H_\n#define ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_ONEFLOW_LRJIT_REGISTRY_H_\n\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/singleton.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/ir/oneflow-extension/include/PyAst/Ast.h\"\n\n#include <cstdint>\n#include <memory>\n#include <unordered_map>\n#include <utility>\n#include <iostream>\n#include <string>\n\nnamespace mlir {\nclass ExecutionEngine;\n}\n\ntypedef std::pair<std::shared_ptr<mlir::ExecutionEngine>, std::function<double(double, double)>>\n    LRJITRegistry_Store_;\n\nclass LRJITRegistry final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LRJITRegistry);\n  ~LRJITRegistry() = default;\n\n  void Register(const std::string& function_id, pyast::FunctionDef& ast, bool is_dump);\n  std::function<double(double, double)> LookUp(const std::string& function_id);\n\n private:\n  friend class oneflow::Singleton<LRJITRegistry>;\n  LRJITRegistry() = default;\n\n  std::unordered_map<std::string, LRJITRegistry_Store_> functionId2engine_;\n};\n\n#endif  // ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_ONEFLOW_LRJIT_REGISTRY_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-extension/include/OneFlow/OneFlowRoundTrip.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_ROUNDTRIP_H_\n#define ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_ROUNDTRIP_H_\n\n#include \"oneflow/core/job_rewriter/job_pass.h\"\n\nnamespace oneflow {\n\nenum IRPassType : int32_t { kBeforeAD = 0, kAfterAD = 1 };\n\ntemplate<IRPassType ir_pass_type>\nclass IRRoundTrip final : public JobPass {\n public:\n  IRRoundTrip() = default;\n  ~IRRoundTrip() override = default;\n  bool IsEnabled(const JobPassCtx& ctx) const;\n  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_ONEFLOW_ROUNDTRIP_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-extension/include/PyAst/Ast.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_PYAST_AST_H_\n#define ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_PYAST_AST_H_\n\n#include <utility>\n#include <vector>\n#include <string>\n#include <memory>\n\nnamespace pyast {\n\nusing namespace std;\ntypedef string identifier;\n\nclass arg {\n  identifier id;\n\n public:\n  explicit arg(const identifier& arg) : id(arg) {}\n\n  identifier get_arg() { return id; }\n\n  static shared_ptr<arg> arg_(const identifier& arg_) { return make_shared<arg>(arg_); }\n};\n\nclass arguments {\n  vector<shared_ptr<arg>> args;\n\n public:\n  explicit arguments(vector<shared_ptr<arg>> args) : args(std::move(args)) {}\n\n  vector<shared_ptr<arg>> get_args() { return args; }\n\n  static shared_ptr<arguments> arguments_(vector<shared_ptr<arg>> args) {\n    return make_shared<arguments>(args);\n  }\n};\n\nclass stmt {\n public:\n  enum StmtKind {\n    kFunctionDef,\n    kReturn,\n    kAssign,\n    kIf,\n    kRaise,\n    kAssert,\n    kExpr,\n  };\n\n  explicit stmt(StmtKind kind) : kind(kind) {}\n  virtual ~stmt() = default;\n\n  StmtKind get_kind() const { return kind; }\n\n private:\n  StmtKind kind;\n};\n\nclass expr {\n public:\n  enum ExprKind {\n    kBoolOp,\n    kBinOp,\n    kLambda,\n    kCompare,\n    kCall,\n    kNum,\n    kConstant,\n    kAttribute,\n    kName,\n  };\n\n  explicit expr(ExprKind kind) : kind(kind) {}\n  virtual ~expr() = default;\n\n  ExprKind get_kind() const { return kind; }\n\n private:\n  ExprKind kind;\n};\n\nclass FunctionDef : public stmt {\n  identifier name;\n  shared_ptr<arguments> args;\n  vector<shared_ptr<stmt>> body;\n\n public:\n  FunctionDef(identifier name, shared_ptr<arguments> args, vector<shared_ptr<stmt>> body)\n      : stmt(kFunctionDef), name(std::move(name)), args(std::move(args)), body(std::move(body)) {}\n\n  static shared_ptr<FunctionDef> FunctionDef_(identifier name, shared_ptr<arguments> args,\n                                              vector<shared_ptr<stmt>> body) {\n    return make_shared<FunctionDef>(name, args, body);\n  }\n\n  identifier get_name() { return name; }\n  shared_ptr<arguments> get_args() { return args; }\n  vector<shared_ptr<stmt>> get_body() { return body; }\n\n  static bool classof(const stmt* c) { return c->get_kind() == kFunctionDef; }\n};\n\nclass Return : public stmt {\n  shared_ptr<expr> value;\n\n public:\n  explicit Return(shared_ptr<expr> value) : stmt(kReturn), value(std::move(value)) {}\n\n  static shared_ptr<Return> Return_(shared_ptr<expr> value) { return make_shared<Return>(value); }\n\n  shared_ptr<expr> get_value() { return value; }\n\n  static bool classof(const stmt* c) { return c->get_kind() == kReturn; }\n};\n\nclass Assign : public stmt {\n  vector<shared_ptr<expr>> targets;\n  shared_ptr<expr> value;\n\n public:\n  Assign(vector<shared_ptr<expr>> targets, shared_ptr<expr> value)\n      : stmt(kAssign), targets(std::move(targets)), value(std::move(value)) 
{}\n\n  static shared_ptr<Assign> Assign_(vector<shared_ptr<expr>> targets, shared_ptr<expr> value) {\n    return make_shared<Assign>(targets, value);\n  }\n\n  shared_ptr<expr> get_value() { return value; }\n  vector<shared_ptr<expr>> get_targets() { return targets; }\n\n  static bool classof(const stmt* c) { return c->get_kind() == kAssign; }\n};\n\nclass If : public stmt {\n  shared_ptr<expr> test;\n  vector<shared_ptr<stmt>> body;\n  vector<shared_ptr<stmt>> orelse;\n\n public:\n  If(shared_ptr<expr> test, vector<shared_ptr<stmt>> body, vector<shared_ptr<stmt>> orelse)\n      : stmt(kIf), test(std::move(test)), body(std::move(body)), orelse(orelse) {}\n\n  static shared_ptr<If> If_(shared_ptr<expr> test, vector<shared_ptr<stmt>> body,\n                            vector<shared_ptr<stmt>> orelse) {\n    return make_shared<If>(test, body, orelse);\n  }\n\n  shared_ptr<expr> get_test() { return test; }\n  vector<shared_ptr<stmt>> get_body() { return body; }\n  vector<shared_ptr<stmt>> get_orelse() { return orelse; }\n\n  static bool classof(const stmt* c) { return c->get_kind() == kIf; }\n};\n\nclass Raise : public stmt {\n  shared_ptr<expr> exc;\n  shared_ptr<expr> cause;\n\n public:\n  Raise(shared_ptr<expr> exc, shared_ptr<expr> cause)\n      : stmt(kRaise), exc(std::move(exc)), cause(std::move(cause)) {}\n\n  static shared_ptr<Raise> Raise_(shared_ptr<expr> exc, shared_ptr<expr> cause) {\n    return make_shared<Raise>(exc, cause);\n  }\n\n  shared_ptr<expr> get_exc() { return exc; }\n  shared_ptr<expr> get_cause() { return cause; }\n\n  static bool classof(const stmt* c) { return c->get_kind() == kRaise; }\n};\n\nclass Assert : public stmt {\n  shared_ptr<expr> test;\n  shared_ptr<expr> msg;\n\n public:\n  Assert(shared_ptr<expr> test, shared_ptr<expr> msg)\n      : stmt(kAssert), test(std::move(test)), msg(std::move(msg)) {}\n\n  static shared_ptr<Assert> Assert_(shared_ptr<expr> test, shared_ptr<expr> msg) {\n    return make_shared<Assert>(test, msg);\n  }\n  shared_ptr<expr> get_test() { return test; }\n  shared_ptr<expr> get_msg() { return msg; }\n\n  static bool classof(const stmt* c) { return c->get_kind() == kAssert; }\n};\n\nclass Expr : public stmt {\n  shared_ptr<expr> value;\n\n public:\n  explicit Expr(shared_ptr<expr> value) : stmt(kExpr), value(std::move(value)) {}\n\n  static shared_ptr<Expr> Expr_(shared_ptr<expr> value) { return make_shared<Expr>(value); }\n\n  shared_ptr<expr> get_value() { return value; }\n\n  static bool classof(const stmt* c) { return c->get_kind() == kExpr; }\n};\n\nclass BoolOp : public expr {\n public:\n  enum boolop_t {\n    kAnd = 1,\n    kOr,\n  };\n  BoolOp(boolop_t op, vector<shared_ptr<expr>> values)\n      : expr(kBoolOp), op(op), values(std::move(values)) {}\n\n  static shared_ptr<BoolOp> BoolOp_(boolop_t op, vector<shared_ptr<expr>> values) {\n    return make_shared<BoolOp>(op, values);\n  }\n\n  boolop_t get_op() { return op; }\n  vector<shared_ptr<expr>> get_values() { return values; }\n\n  static bool classof(const expr* c) { return c->get_kind() == kBoolOp; }\n\n private:\n  boolop_t op;\n  vector<shared_ptr<expr>> values;\n};\n\nclass BinOp : public expr {\n public:\n  enum operator_t {\n    kAdd = 1,\n    kSub,\n    kMult,\n    kDiv,\n    kPow,\n  };\n\n  BinOp(shared_ptr<expr> left, operator_t op, shared_ptr<expr> right)\n      : expr(kBinOp), left(std::move(left)), right(std::move(right)), op(std::move(op)) {}\n\n  BinOp(shared_ptr<expr> left, int op, shared_ptr<expr> right)\n      : expr(kBinOp), left(std::move(left)), 
right(std::move(right)), op(int2op(op)) {}\n\n  static shared_ptr<BinOp> BinOp_(shared_ptr<expr> left, int op, shared_ptr<expr> right) {\n    return make_shared<BinOp>(left, op, right);\n  }\n\n  static operator_t int2op(int op) { return operator_t(op); }\n\n  operator_t get_op() { return op; }\n  shared_ptr<expr> get_left() { return left; }\n  shared_ptr<expr> get_right() { return right; }\n\n  static bool classof(const expr* c) { return c->get_kind() == kBinOp; }\n\n private:\n  shared_ptr<expr> left;\n  shared_ptr<expr> right;\n  operator_t op;\n};\n\nclass Lambda : public expr {\n  shared_ptr<arguments> args;\n  shared_ptr<expr> body;\n\n public:\n  Lambda(shared_ptr<arguments> args, shared_ptr<expr> body)\n      : expr(kLambda), args(std::move(args)), body(std::move(body)) {}\n\n  static shared_ptr<Lambda> Lambda_(shared_ptr<arguments> args, shared_ptr<expr> body) {\n    return make_shared<Lambda>(args, body);\n  }\n\n  shared_ptr<arguments> get_args() { return args; }\n  shared_ptr<expr> get_body() { return body; }\n\n  static bool classof(const expr* c) { return c->get_kind() == kLambda; }\n};\n\nclass Compare : public expr {\n public:\n  enum cmpop_t {\n    kEq = 1,\n    kNotEq,\n    kLt,\n    kLtE,\n    kGt,\n    kGtE,\n  };\n\n  Compare(shared_ptr<expr> left, vector<cmpop_t> ops, vector<shared_ptr<expr>> comparators)\n      : expr(kCompare),\n        left(std::move(left)),\n        ops(std::move(ops)),\n        comparators(std::move(comparators)) {}\n\n  Compare(shared_ptr<expr> left, const vector<int>& ops, vector<shared_ptr<expr>> comparators)\n      : expr(kCompare),\n        left(std::move(left)),\n        ops(int2op(ops)),\n        comparators(std::move(comparators)) {}\n\n  static shared_ptr<Compare> Compare_(shared_ptr<expr> left, vector<int> ops,\n                                      vector<shared_ptr<expr>> comparators) {\n    return make_shared<Compare>(left, ops, comparators);\n  }\n\n  static vector<cmpop_t> int2op(const vector<int>& op) {\n    vector<cmpop_t> res;\n    for (auto i : op) res.emplace_back(cmpop_t(i));\n    return res;\n  }\n\n  vector<cmpop_t> get_ops() { return ops; }\n  shared_ptr<expr> get_left() { return left; }\n  vector<shared_ptr<expr>> get_comparators() { return comparators; }\n\n  static bool classof(const expr* c) { return c->get_kind() == kCompare; }\n\n private:\n  shared_ptr<expr> left;\n  vector<cmpop_t> ops;\n  vector<shared_ptr<expr>> comparators;\n};\n\nclass Call : public expr {\n  shared_ptr<expr> func;\n  vector<shared_ptr<expr>> args;\n\n public:\n  Call(shared_ptr<expr> func, vector<shared_ptr<expr>> args)\n      : expr(kCall), func(std::move(func)), args(std::move(args)) {}\n\n  static shared_ptr<Call> Call_(shared_ptr<expr> func, vector<shared_ptr<expr>> args) {\n    return make_shared<Call>(func, args);\n  }\n\n  shared_ptr<expr> get_func() { return func; }\n  vector<shared_ptr<expr>> get_args() { return args; }\n\n  static bool classof(const expr* c) { return c->get_kind() == kCall; }\n};\n\nclass Num : public expr {\n  double value;\n\n public:\n  explicit Num(double value) : expr(kNum), value(value) {}\n\n  static shared_ptr<Num> Num_(double value) { return make_shared<Num>(value); }\n\n  double get_value() { return value; }\n  static bool classof(const expr* c) { return c->get_kind() == kNum; }\n};\n\nclass Constant : public expr {\n  double value;\n\n public:\n  explicit Constant(double value) : expr(kConstant), value(value) {}\n\n  static shared_ptr<Constant> Constant_(double value) { return make_shared<Constant>(value); 
}\n\n  double get_value() { return value; }\n  static bool classof(const expr* c) { return c->get_kind() == kConstant; }\n};\n\nclass Attribute : public expr {\n  shared_ptr<expr> value;\n  identifier attr;\n\n public:\n  Attribute(shared_ptr<expr> value, const identifier& attr)\n      : expr(kAttribute), value(std::move(value)), attr(attr) {}\n\n  static shared_ptr<Attribute> Attribute_(shared_ptr<expr> value, const identifier& attr) {\n    return make_shared<Attribute>(value, attr);\n  }\n\n  shared_ptr<expr> get_value() { return value; }\n  identifier get_attr() { return attr; }\n\n  static bool classof(const expr* c) { return c->get_kind() == kAttribute; }\n};\n\nclass Name : public expr {\n  identifier id;\n\n public:\n  explicit Name(const identifier& id) : expr(kName), id(id) {}\n\n  static shared_ptr<Name> Name_(const identifier& id) { return make_shared<Name>(id); }\n\n  identifier get_id() { return id; }\n  static bool classof(const expr* c) { return c->get_kind() == kName; }\n};\n\n}  // namespace pyast\n\n#endif  // ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_PYAST_AST_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-extension/include/PyAst/AstMlirGen.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_PYAST_AST_MLIR_GEN_H_\n#define ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_PYAST_AST_MLIR_GEN_H_\n\n#include \"OneFlow/OneFlowLRJITRegistry.h\"\n#include \"PyAst/Ast.h\"\n\n#include \"mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h\"\n#include \"mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h\"\n#include \"mlir/Conversion/AffineToStandard/AffineToStandard.h\"\n#include \"mlir/Conversion/ArithToLLVM/ArithToLLVM.h\"\n#include \"mlir/Conversion/MathToLLVM/MathToLLVM.h\"\n#include \"mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h\"\n#include \"mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h\"\n#include \"mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h\"\n#include \"mlir/Dialect/Func/Transforms/Passes.h\"\n#include \"mlir/Dialect/Affine/IR/AffineOps.h\"\n#include \"mlir/Dialect/ControlFlow/IR/ControlFlowOps.h\"\n#include \"mlir/Dialect/Arith/IR/Arith.h\"\n#include \"mlir/Dialect/ControlFlow/IR/ControlFlow.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/Dialect/Math/IR/Math.h\"\n#include \"mlir/Dialect/MemRef/IR/MemRef.h\"\n#include \"mlir/Dialect/SCF/IR/SCF.h\"\n#include \"mlir/IR/Attributes.h\"\n#include \"mlir/IR/Block.h\"\n#include \"mlir/IR/OperationSupport.h\"\n#include \"mlir/IR/TypeRange.h\"\n#include \"mlir/IR/Value.h\"\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/OwningOpRef.h\"\n#include \"mlir/IR/MLIRContext.h\"\n#include \"mlir/IR/Verifier.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/Parser/Parser.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h\"\n#include \"mlir/ExecutionEngine/ExecutionEngine.h\"\n#include \"mlir/ExecutionEngine/MemRefUtils.h\"\n#include \"mlir/Dialect/Linalg/Passes.h\"\n#include \"mlir/InitAllDialects.h\"\n#include \"mlir/Transforms/Passes.h\"\n\n#include \"llvm/ADT/ScopedHashTable.h\"\n#include \"llvm/Support/raw_ostream.h\"\n#include \"llvm/Support/TargetSelect.h\"\n#include \"llvm/ADT/StringRef.h\"\n#include \"llvm/ADT/STLExtras.h\"\n#include \"llvm/ADT/TypeSwitch.h\"\n\n#include <numeric>\n#include <any>\n#include <functional>\n#include <memory>\n\nclass BuilderWithSymbolTable {\n protected:\n  mlir::OpBuilder builder_;\n  mlir::ModuleOp theModule_;\n  std::map<std::string, mlir::Value> symbolTable_;\n  mlir::Block* symbolTableForDeclareBlock_{};\n\n  explicit BuilderWithSymbolTable(mlir::MLIRContext& context) : builder_(&context) {}\n  virtual ~BuilderWithSymbolTable() = default;\n\n  mlir::LogicalResult Declare(const std::string& var, mlir::Value value);\n  mlir::Value LoopUp(const std::string& var);\n  mlir::Location Loc(const std::string& file_name = \"unknown\", int line = 0, int col = 0);\n  void Dump();\n};\n\nclass MLIRGenImpl : public BuilderWithSymbolTable {\n public:\n  explicit 
MLIRGenImpl(mlir::MLIRContext& context) : BuilderWithSymbolTable(context) {}\n\n  mlir::ModuleOp GenModule(pyast::FunctionDef* func);\n\n  mlir::Value MlirGen(pyast::Compare* expr);\n  mlir::Value MlirGen(pyast::BinOp* expr);\n  mlir::Value MlirGen(pyast::Call* expr);\n  mlir::Value MlirGen(pyast::Constant* expr);\n  mlir::Value MlirGen(pyast::Name* expr);\n\n  mlir::Value MlirGen(pyast::expr* expr);\n\n  void MlirGen(pyast::If* stmt);\n  void MlirGen(pyast::Assign* stmt);\n  void MlirGen(pyast::Return* stmt);\n\n  void MlirGen(pyast::stmt* stmt);\n};\n\n#endif  // ONEFLOW_IR_ONEFLOW_EXTENSION_INCLUDE_PYAST_AST_MLIR_GEN_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-extension/ir_pass.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <utility>\n#include <vector>\n#include \"oneflow/core/graph/op_graph.h\"\n#include \"OneFlow/OneFlowRoundTrip.h\"\n#include \"oneflow/ir/oneflow-translate/include/OneFlow/MLIROneFlowTranslation.h\"\n#include \"oneflow/core/framework/user_op_def.h\"\n#include \"oneflow/core/framework/user_op_registry.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n#include \"oneflow/core/job/job_ir.h\"\n#include \"oneflow/core/common/env_var/debug_mode.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<IRPassType>\nstd::string IRPassTypeName();\n\ntemplate<>\nstd::string IRPassTypeName<kBeforeAD>() {\n  return \"before_ad\";\n}\n\ntemplate<>\nstd::string IRPassTypeName<kAfterAD>() {\n  return \"after_ad\";\n}\n\ntemplate<IRPassType>\nbool IsLastIRPassForIRPassType();\n\ntemplate<>\nbool IsLastIRPassForIRPassType<kBeforeAD>() {\n  return false;\n}\n\ntemplate<>\nbool IsLastIRPassForIRPassType<kAfterAD>() {\n  return true;\n}\n\ntemplate<IRPassType ir_pass_type>\nclass RoundTripOneFlowJobWrapper : public mlir::oneflow::RoundTripOneFlowJobWrapperInterface {\n public:\n  explicit RoundTripOneFlowJobWrapper(::oneflow::Job* job)\n      : job_(job), op_graph_(*job), job_builder_(job), is_updated_(false) {}\n\n  const Job* job() const override { return job_; }\n\n  bool IsLastIRPass() const override { return IsLastIRPassForIRPassType<ir_pass_type>(); }\n\n  void UpdateJob(::oneflow::Job* new_job) override {\n    CHECK(is_updated_ == false);\n    job_->Swap(new_job);\n    is_updated_ = true;\n  }\n  void DumpLog(const std::string& filename, const std::string& content) override {\n    if (IsInDebugMode()) {\n      TeePersistentLogStream::Create(JoinPath(LogDir(), filename))->Write(content);\n    }\n  }\n\n  const ::oneflow::ParallelConf& ParallelConf4OpName(const std::string& op_name) const override {\n    return job_builder_.ParallelConf4OpName(op_name).GetOrThrow();\n  }\n  const ::oneflow::OperatorConf& OpConf4OpName(const std::string& op_name) const override {\n    return job_builder_.OpConf4OpName(op_name).GetOrThrow();\n  }\n  std::pair<std::vector<std::string>, std::vector<std::string>> InputBns4OpName(\n      const std::string& op_name) const override {\n    auto node = op_graph_.OpNode4OpName(op_name);\n    std::vector<std::string> input_bns{};\n    std::vector<std::string> input_lbns{};\n    for (auto e : node->in_edges()) {\n      for (const auto& lbi_ibn_pair : e->lbi2ibns()) {\n        for (const auto& ibn : lbi_ibn_pair.second) {\n          input_bns.push_back(ibn);\n          input_lbns.push_back(GenLogicalBlobName(lbi_ibn_pair.first));\n        }\n      }\n    }\n    return std::make_pair(input_bns, input_lbns);\n  }\n\n  std::vector<std::string> OutputLbns4OpName(const std::string& op_name) const override {\n    std::unordered_set<std::string> ret{};\n    auto node = op_graph_.OpNode4OpName(op_name);\n    for (auto e : node->out_edges()) {\n      for (const 
auto& lbi : e->lbis()) { ret.insert(GenLogicalBlobName(lbi)); }\n    }\n    return {ret.begin(), ret.end()};\n  }\n\n  std::string ReplaceInputLbnInOpCustomizedConf(::oneflow::OperatorConf* op_conf,\n                                                const std::string& ibn,\n                                                const std::string& new_val) const override {\n    return ::oneflow::ReplaceInputLbnInOpCustomizedConf(op_conf, ibn, new_val);\n  }\n\n  void QueryLogicalBlob(\n      const std::string& lbn,\n      std::function<void(const int64_t* shape_begin, const int64_t* shape_end, DataType dt)> cb)\n      const override {\n    LogicalBlobId lbi = GenLogicalBlobId(lbn);\n    auto& blob_desc = op_graph_.GetLogicalBlobDesc(lbi);\n    cb(blob_desc.shape().dim_vec().begin(), blob_desc.shape().dim_vec().end(),\n       blob_desc.data_type());\n  }\n\n  void TopoForEachOpConf(\n      std::function<void(const ::oneflow::OperatorConf*)> Handler) const override {\n    op_graph_.TopoForEachNodeWithCtrlEdge(\n        [&](OpNode* op_node) { Handler(&op_node->op().op_conf()); });\n  }\n\n  std::string LogDir() {\n    return JoinPath(\"ir_pass\", IRPassTypeName<ir_pass_type>(), job_->job_conf().job_name());\n  }\n\n private:\n  Job* job_;\n  const OpGraph op_graph_;\n  JobBuilder job_builder_;\n  bool is_updated_;\n};\n\n}  // namespace\n\ntemplate<IRPassType ir_pass_type>\nbool IRRoundTrip<ir_pass_type>::IsEnabled(const JobPassCtx& ctx) const {\n  return ParseBooleanFromEnv(\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\", false);\n}\n\nvoid SortJob(Job& job) {\n  auto* ops = job.mutable_net()->mutable_op();\n  std::sort(ops->begin(), ops->end(),\n            [](const oneflow::OperatorConf& l, const oneflow::OperatorConf& r) {\n              return l.name() < r.name();\n            });\n}\n\ntemplate<IRPassType ir_pass_type>\nMaybe<void> IRRoundTrip<ir_pass_type>::Apply(Job* job, JobPassCtx* ctx) const {\n  if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }\n  const OpGraph op_graph(*job);\n  Job job_before{};\n  job_before.CopyFrom(*job);\n  RoundTripOneFlowJobWrapper<ir_pass_type> w(job);\n  SortJob(job_before);\n  if (IsInDebugMode()) {\n    TeePersistentLogStream::Create(JoinPath(w.LogDir(), \"job_before_ir_round_trip.prototxt\"))\n        ->Write(job_before);\n  }\n  mlir::oneflow::RoundTripOneFlowJob(w, [](::oneflow::Job* job, std::string& reason) {\n    // TODO: It is not clear how to define if extra boxing is introduced\n    TODO();\n    return true;\n  });\n  if (IsInDebugMode()) {\n    Job job_after{};\n    job_after.CopyFrom(*job);\n    SortJob(job_after);\n    TeePersistentLogStream::Create(JoinPath(w.LogDir(), \"job_after_ir_round_trip.prototxt\"))\n        ->Write(job_after);\n  }\n  return Maybe<void>::Ok();\n}\n\ntemplate class IRRoundTrip<kBeforeAD>;\ntemplate class IRRoundTrip<kAfterAD>;\n\nMaybe<std::string> ConvertJobToTosaIR(Job* job) {\n  RoundTripOneFlowJobWrapper<kBeforeAD> job_wrapper(job);\n  return ::mlir::oneflow::ConvertJobToTosaIR(job_wrapper);\n}\n\nMaybe<void> SaveJobToIR(Job* job, const std::string& path) {\n  // TODO: check path is valid dir\n  if (IsInDebugMode()) { TeePersistentLogStream::Create(\"saved_job\")->Write(*job); }\n  RoundTripOneFlowJobWrapper<kBeforeAD> job_wrapper(job);\n  ::mlir::oneflow::SaveJobToIR(job_wrapper, path);\n  return Maybe<void>::Ok();\n}\n\nMaybe<std::string> ConvertJobToIR(Job* job) {\n  if (IsInDebugMode()) { TeePersistentLogStream::Create(\"saved_job\")->Write(*job); }\n  RoundTripOneFlowJobWrapper<kBeforeAD> job_wrapper(job);\n  return 
::mlir::oneflow::ConvertJobToIR(job_wrapper);\n}\n\nMaybe<void> LoadJobFromIR(Job* job, const std::string& path) {\n  job->Clear();\n  RoundTripOneFlowJobWrapper<kBeforeAD> job_wrapper(job);\n  ::mlir::oneflow::LoadJobFromIR(job_wrapper, path);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/ir/oneflow-extension/lr_jit.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"PyAst/Ast.h\"\n#include \"PyAst/AstMlirGen.h\"\n\n#include \"mlir/Conversion/AffineToStandard/AffineToStandard.h\"\n#include \"mlir/Conversion/ArithToLLVM/ArithToLLVM.h\"\n#include \"mlir/Conversion/MathToLLVM/MathToLLVM.h\"\n#include \"mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h\"\n#include \"mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h\"\n#include \"mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h\"\n#include \"mlir/Dialect/Arith/Transforms/Passes.h\"\n#include \"mlir/Dialect/Func/Transforms/Passes.h\"\n#include \"mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h\"\n#include \"mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h\"\n#include \"mlir/Dialect/Affine/IR/AffineOps.h\"\n#include \"mlir/Dialect/ControlFlow/IR/ControlFlowOps.h\"\n#include \"mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h\"\n#include \"mlir/Dialect/Arith/IR/Arith.h\"\n#include \"mlir/Dialect/ControlFlow/IR/ControlFlow.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h\"\n#include \"mlir/Dialect/Math/IR/Math.h\"\n#include \"mlir/Dialect/MemRef/IR/MemRef.h\"\n#include \"mlir/Dialect/SCF/IR/SCF.h\"\n#include \"mlir/IR/Attributes.h\"\n#include \"mlir/IR/OperationSupport.h\"\n#include \"mlir/IR/TypeRange.h\"\n#include \"mlir/IR/Value.h\"\n#include \"mlir/InitAllDialects.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/Parser/Parser.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h\"\n#include \"mlir/ExecutionEngine/ExecutionEngine.h\"\n#include \"mlir/ExecutionEngine/MemRefUtils.h\"\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/OwningOpRef.h\"\n#include \"mlir/Dialect/Linalg/Passes.h\"\n#include \"mlir/IR/Verifier.h\"\n#include \"mlir/IR/MLIRContext.h\"\n#include \"mlir/Transforms/Passes.h\"\n\n#include \"llvm/Support/TargetSelect.h\"\n#include \"llvm/Support/raw_ostream.h\"\n#include \"llvm/ADT/STLExtras.h\"\n#include \"llvm/ADT/ScopedHashTable.h\"\n#include \"llvm/Support/raw_ostream.h\"\n#include \"llvm/ADT/TypeSwitch.h\"\n#include \"llvm/ADT/StringRef.h\"\n\n#include <glog/logging.h>\n#include <numeric>\n#include <any>\n#include <functional>\n#include <memory>\n\nusing llvm::ArrayRef;\nusing llvm::ScopedHashTableScope;\nusing llvm::SmallVector;\nusing llvm::StringRef;\nusing llvm::Twine;\n\nstatic struct LLVMInitializer {\n  LLVMInitializer() {\n    llvm::InitializeNativeTarget();\n    llvm::InitializeNativeTargetAsmPrinter();\n  }\n} initializer;\n\nstatic mlir::LogicalResult lowerToLLVMDialect(mlir::ModuleOp module) {\n  mlir::PassManager pm(module.getContext());\n\n  pm.addNestedPass<mlir::func::FuncOp>(mlir::LLVM::createRequestCWrappersPass());\n  pm.addPass(mlir::createCSEPass());\n  pm.addPass(mlir::createCanonicalizerPass());\n  
pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass());\n  pm.addPass(mlir::createConvertFuncToLLVMPass());\n  pm.addPass(mlir::createConvertSCFToCFPass());\n  pm.addPass(mlir::createConvertControlFlowToLLVMPass());\n  pm.addPass(mlir::createConvertMathToLLVMPass());\n  pm.addPass(mlir::arith::createArithExpandOpsPass());\n  pm.addPass(mlir::createArithToLLVMConversionPass());\n  pm.addPass(mlir::createReconcileUnrealizedCastsPass());\n  return pm.run(module);\n}\n\n// generate a simple mlir module for testing\nstatic mlir::OwningOpRef<mlir::ModuleOp> GenModuleForTest(mlir::MLIRContext& context) {\n  std::string moduleStr = R\"mlir(\n  func.func @get_lr(%arg0 : f32, %arg1 : i32) -> f32 attributes { llvm.emit_c_interface } {\n    return %arg0 : f32\n  }\n  )mlir\";\n  mlir::OwningOpRef<mlir::ModuleOp> module =\n      mlir::parseSourceString<mlir::ModuleOp>(moduleStr, &context);\n  return module;\n}\n\n// generate a module op from a function def python ast\nstatic mlir::OwningOpRef<mlir::ModuleOp> GenModule(mlir::MLIRContext& context,\n                                                   pyast::FunctionDef& ast) {\n  using namespace pyast;\n\n  MLIRGenImpl mlir_gen(context);\n  mlir::OwningOpRef<mlir::ModuleOp> module = mlir_gen.GenModule(&ast);\n  // module->dump();\n  return module;\n}\n\n// generate the lr jit registry entry (engine and callable wrapper) from a function def python ast\nstatic LRJITRegistry_Store_ GenFunc(pyast::FunctionDef& ast, bool is_dump) {\n  mlir::DialectRegistry registry;\n  mlir::registerAllDialects(registry);\n  mlir::registerLLVMDialectTranslation(registry);\n  mlir::registerBuiltinDialectTranslation(registry);\n  mlir::MLIRContext context(registry);\n  context.loadDialect<mlir::memref::MemRefDialect>();\n  context.loadDialect<mlir::func::FuncDialect>();\n  context.loadDialect<mlir::arith::ArithDialect>();\n  context.loadDialect<mlir::math::MathDialect>();\n  context.loadDialect<mlir::scf::SCFDialect>();\n  context.loadDialect<mlir::cf::ControlFlowDialect>();\n  context.loadDialect<mlir::affine::AffineDialect>();\n\n  auto module = GenModule(context, ast);\n  if (is_dump) { module->dump(); }\n  // auto module = GenModuleForTest(context);\n  CHECK(!!module) << \"failed to parse module\";\n  CHECK(succeeded(lowerToLLVMDialect(*module))) << \"failed to lower to llvm dialect\";\n  auto jit_or_err = mlir::ExecutionEngine::create(*module);\n  CHECK(jit_or_err) << \"failed to create JIT execution engine, \"\n                    << llvm::toString(jit_or_err.takeError());\n\n  std::shared_ptr<mlir::ExecutionEngine> engine = cantFail(std::move(jit_or_err));\n\n  std::weak_ptr<mlir::ExecutionEngine> engine_ = engine;\n\n  auto func = [engine_](double base_lr, double step) {\n    float res = 0;\n    if (!engine_.expired()) {\n      auto engine = engine_.lock();\n      auto&& out = mlir::ExecutionEngine::result(res);\n      auto base_lr_jit = static_cast<float>(base_lr);\n      auto step_jit = static_cast<float>(step);\n      if (auto err = engine->invoke(\"get_lr\", base_lr_jit, step_jit, out)) {\n        llvm::errs() << \"JIT invocation of get_lr failed: \" << llvm::toString(std::move(err));\n      }\n    }\n    return res;\n  };\n  return {engine, func};\n}\n\nvoid LRJITRegistry::Register(const std::string& function_id, pyast::FunctionDef& ast,\n                             bool is_dump) {\n  auto jit = GenFunc(ast, is_dump);\n  functionId2engine_[function_id] = jit;\n}\n\nstd::function<double(double, double)> LRJITRegistry::LookUp(const std::string& function_id) {\n  auto iter = functionId2engine_.find(function_id);\n  if (iter != functionId2engine_.end()) { return iter->second.second; }\n  llvm::errs() << \"function 
void LRJITRegistry::Register(const std::string& function_id, pyast::FunctionDef& ast,\n                             bool is_dump) {\n  auto jit = GenFunc(ast, is_dump);\n  functionId2engine_[function_id] = jit;\n}\n\n
std::function<double(double, double)> LRJITRegistry::LookUp(const std::string& function_id) {\n  auto iter = functionId2engine_.find(function_id);\n  if (iter != functionId2engine_.end()) { return iter->second.second; }\n  llvm::errs() << \"function '\" << function_id << \"' was not registered before lookup.\\n\";\n  return nullptr;\n}\n"
  },
  {
    "path": "oneflow/ir/oneflow-extension/mlir_gen.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"PyAst/AstMlirGen.h\"\n\n// declare any scope variables in the front of function block to ensure the enough lifetime.\nmlir::LogicalResult BuilderWithSymbolTable::Declare(const std::string& var, mlir::Value value) {\n  auto iter = symbolTable_.find(var);\n  if (iter != symbolTable_.end()) {\n    builder_.create<mlir::memref::StoreOp>(Loc(), value, iter->second);\n    return mlir::failure();\n  }\n\n  auto history_block = builder_.getInsertionBlock();\n  auto history_point = builder_.getInsertionPoint();\n\n  builder_.setInsertionPointToStart(symbolTableForDeclareBlock_);\n\n  auto single_type = mlir::Float32Type::getF32(builder_.getContext());\n  auto type = mlir::MemRefType::get({}, single_type);\n  auto key = builder_.create<mlir::memref::AllocOp>(Loc(), type);\n\n  builder_.setInsertionPoint(history_block, history_point);\n  builder_.create<mlir::memref::StoreOp>(Loc(), value, key);\n  symbolTable_[var] = key;\n  return mlir::success();\n}\n\n// look up memref of the special symbol with variable name\nmlir::Value BuilderWithSymbolTable::LoopUp(const std::string& var) {\n  if (symbolTable_.count(var) == 1) { return symbolTable_[var]; }\n  theModule_->emitError(\"error: unknown variable '\" + var + \"'\");\n  return nullptr;\n}\n\n// generate a location of mlir for ops\nmlir::Location BuilderWithSymbolTable::Loc(const std::string& file_name, int line, int col) {\n  return mlir::FileLineColLoc::get(builder_.getStringAttr(file_name), line, col);\n}\n\n// dump the current whole module up\nvoid BuilderWithSymbolTable::Dump() { theModule_.dump(); }\n\n// generate a module op for lr jit registry from a ast\nmlir::ModuleOp MLIRGenImpl::GenModule(pyast::FunctionDef* func) {\n  theModule_ = mlir::ModuleOp::create(Loc());\n\n  if (failed(verify(theModule_))) {\n    theModule_.emitError(\"module verification error\");\n    return nullptr;\n  }\n\n  builder_.setInsertionPointToEnd(theModule_.getBody());\n\n  auto args = func->get_args()->get_args();\n  auto type = mlir::Float32Type::getF32(builder_.getContext());\n  llvm::SmallVector<mlir::Type> arg_types(args.size(), type);\n  llvm::SmallVector<mlir::Type> res_types(1, type);\n\n  auto func_type = builder_.getFunctionType(arg_types, res_types);\n  auto function = mlir::func::FuncOp::create(Loc(), func->get_name(), func_type);\n\n  auto* entry_block = function.addEntryBlock();\n  symbolTableForDeclareBlock_ = entry_block;\n  theModule_.push_back(function);\n  builder_.setInsertionPointToStart(entry_block);\n\n  for (const auto nameValue : llvm::zip(args, entry_block->getArguments())) {\n    if (failed(Declare(std::get<0>(nameValue)->get_arg(), std::get<1>(nameValue)))) {\n      return nullptr;\n    }\n  }\n\n  builder_.setInsertionPointToStart(entry_block);\n  for (auto& stmt : func->get_body()) { MlirGen(stmt.get()); }\n\n  return theModule_;\n}\n\n// use llvm rtti to dispatch respective code gen tasks of stmt\nvoid 
MLIRGenImpl::MlirGen(pyast::stmt* stmt) {\n  llvm::TypeSwitch<pyast::stmt*>(stmt)\n      .Case<pyast::Return, pyast::Assign, pyast::If>([&](auto* node) { MlirGen(node); })\n      .Default([&](auto* node) { theModule_->emitError(\"StmtKind not supported yet\"); });\n}\n\n
// Use LLVM RTTI to dispatch the respective code gen task for each expr kind.\nmlir::Value MLIRGenImpl::MlirGen(pyast::expr* expr) {\n  mlir::Value res;\n  llvm::TypeSwitch<pyast::expr*>(expr)\n      .Case<pyast::BinOp, pyast::Compare, pyast::Call, pyast::Constant, pyast::Name>(\n          [&](auto* node) { res = MlirGen(node); })\n      .Default([&](auto* node) { theModule_->emitError(\"ExprKind not supported yet\"); });\n  return res;\n}\n\n
void MLIRGenImpl::MlirGen(pyast::If* expr) {\n  auto test = MlirGen(expr->get_test().get());\n\n  // an f32 test value is converted to i1 by comparing it against 0.0 for inequality\n  if (test.getType().isF32()) {\n    auto ne = mlir::arith::CmpFPredicate::ONE;\n    auto zero_attr = builder_.getF32FloatAttr(0);\n    auto zero = builder_.create<mlir::arith::ConstantOp>(Loc(), zero_attr);\n    test = builder_.create<mlir::arith::CmpFOp>(Loc(), ne, test, zero);\n  }\n\n
  mlir::Block* then_block = builder_.createBlock(builder_.getBlock()->getParent());\n  mlir::Block* else_block = builder_.createBlock(builder_.getBlock()->getParent());\n  mlir::Block* after_block = builder_.createBlock(builder_.getBlock()->getParent());\n  builder_.setInsertionPointAfterValue(test);\n  builder_.create<mlir::cf::CondBranchOp>(Loc(), test, then_block, llvm::None, else_block,\n                                          llvm::None);\n\n
  builder_.setInsertionPointToStart(then_block);\n  for (auto& stmt : expr->get_body()) { MlirGen(stmt.get()); }\n  if (then_block->empty() || !llvm::dyn_cast<mlir::func::ReturnOp>(then_block->back())) {\n    builder_.create<mlir::cf::BranchOp>(Loc(), after_block);\n  }\n\n
  builder_.setInsertionPointToStart(else_block);\n  for (auto& stmt : expr->get_orelse()) { MlirGen(stmt.get()); }\n  if (else_block->empty() || !llvm::dyn_cast<mlir::func::ReturnOp>(else_block->back())) {\n    builder_.create<mlir::cf::BranchOp>(Loc(), after_block);\n  }\n\n  builder_.setInsertionPointToStart(after_block);\n}\n\n
mlir::Value MLIRGenImpl::MlirGen(pyast::Compare* expr) {\n  if (expr->get_comparators().size() != 1 || expr->get_ops().size() != 1) {\n    theModule_->emitError(\"compare only supports a single comparison for now\");\n  }\n\n
  mlir::arith::CmpFPredicate op = mlir::arith::CmpFPredicate::OEQ;\n  switch (expr->get_ops()[0]) {\n    case pyast::Compare::kEq: op = mlir::arith::CmpFPredicate::OEQ; break;\n    case pyast::Compare::kNotEq: op = mlir::arith::CmpFPredicate::ONE; break;\n    case pyast::Compare::kLt: op = mlir::arith::CmpFPredicate::OLT; break;\n    case pyast::Compare::kLtE: op = mlir::arith::CmpFPredicate::OLE; break;\n    case pyast::Compare::kGt: op = mlir::arith::CmpFPredicate::OGT; break;\n    case pyast::Compare::kGtE: op = mlir::arith::CmpFPredicate::OGE; break;\n    default: theModule_->emitError(\"Compare op not supported yet\");\n  }\n\n
  auto lhs = MlirGen(expr->get_left().get());\n  auto rhs = MlirGen(expr->get_comparators()[0].get());\n  auto res = builder_.create<mlir::arith::CmpFOp>(Loc(), op, lhs, rhs);\n  return res;\n}\n\n
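// For orientation, a Python condition like `x > 0.0` lowers roughly to (sketch):\n//   %0 = memref.load %x[] : memref<f32>\n//   %c0 = arith.constant 0.000000e+00 : f32\n//   %1 = arith.cmpf ogt, %0, %c0 : f32\n//   cf.cond_br %1, ^then, ^else\n// where both branches fall through to a shared ^after block unless they return.\n\n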
mlir::Value MLIRGenImpl::MlirGen(pyast::BinOp* expr) {\n  auto lhs = MlirGen(expr->get_left().get());\n  auto rhs = MlirGen(expr->get_right().get());\n  mlir::Value res;\n\n
  switch (expr->get_op()) {\n    case pyast::BinOp::kAdd: res = builder_.create<mlir::arith::AddFOp>(Loc(), lhs, rhs); break;\n    case pyast::BinOp::kSub: res = builder_.create<mlir::arith::SubFOp>(Loc(), lhs, rhs); break;\n    case pyast::BinOp::kDiv: res = builder_.create<mlir::arith::DivFOp>(Loc(), lhs, rhs); break;\n    case pyast::BinOp::kMult: res = builder_.create<mlir::arith::MulFOp>(Loc(), lhs, rhs); break;\n    case pyast::BinOp::kPow: res = builder_.create<mlir::math::PowFOp>(Loc(), lhs, rhs); break;\n    default: theModule_->emitError(\"BinOp not supported yet\"); break;\n  }\n\n  return res;\n}\n\n
mlir::Value MLIRGenImpl::MlirGen(pyast::Call* expr) {\n  mlir::Value res;\n  if (expr->get_func()->get_kind() == pyast::expr::kAttribute) {\n    auto func_ = expr->get_func().get();\n    auto func = *dynamic_cast<pyast::Attribute*>(func_);\n    auto func_value = func.get_value();\n\n
    if (func_value->get_kind() != pyast::expr::kName\n        || dynamic_cast<pyast::Name*>(func_value.get())->get_id() != \"math\") {\n      theModule_->emitError(\"only calls into the python math lib are supported\");\n    }\n    if (expr->get_args().size() != 1) {\n      theModule_->emitError(\"an attribute call only supports one argument\");\n    }\n\n
    auto value = MlirGen(expr->get_args()[0].get());\n    auto attr = func.get_attr();\n\n
    if (attr == \"floor\") {\n      res = builder_.create<mlir::math::FloorOp>(Loc(), value);\n    } else if (attr == \"cos\") {\n      res = builder_.create<mlir::math::CosOp>(Loc(), value);\n    } else if (attr == \"ceil\") {\n      res = builder_.create<mlir::math::CeilOp>(Loc(), value);\n    } else {\n      theModule_->emitError(attr + \" not supported yet\");\n    }\n  } else if (expr->get_func()->get_kind() == pyast::expr::kName) {\n    auto func_ = expr->get_func().get();\n    auto func = *dynamic_cast<pyast::Name*>(func_);\n\n
    if (expr->get_args().size() != 2) {\n      theModule_->emitError(\"a name call only supports two arguments\");\n    }\n\n
    auto left = MlirGen(expr->get_args()[0].get());\n    auto right = MlirGen(expr->get_args()[1].get());\n\n    auto attr = func.get_id();\n\n
    if (attr == \"max\") {\n      res = builder_.create<mlir::arith::MaxFOp>(Loc(), left, right);\n    } else if (attr == \"min\") {\n      res = builder_.create<mlir::arith::MinFOp>(Loc(), left, right);\n    } else {\n      theModule_->emitError(attr + \" not supported yet\");\n    }\n\n  } else {\n    theModule_->emitError(\"only attribute and name callees are supported\");\n  }\n\n  return res;\n}\n\n
mlir::Value MLIRGenImpl::MlirGen(pyast::Constant* expr) {\n  float value = expr->get_value();\n  auto constant = builder_.create<mlir::arith::ConstantOp>(Loc(), builder_.getF32FloatAttr(value));\n  return constant;\n}\n\n
mlir::Value MLIRGenImpl::MlirGen(pyast::Name* expr) {\n  auto key = LoopUp(expr->get_id());\n  builder_.setInsertionPointToEnd(builder_.getInsertionBlock());\n  auto value = builder_.create<mlir::memref::LoadOp>(Loc(), key);\n  return value;\n}\n\n
void MLIRGenImpl::MlirGen(pyast::Assign* stmt) {\n  auto value = MlirGen(stmt->get_value().get());\n\n  for (auto& target : stmt->get_targets()) {\n    if (target->get_kind() != pyast::expr::kName) {\n      theModule_->emitError(\"only assignment to a name node is supported\");\n    }\n    auto name = dynamic_cast<pyast::Name*>(target.get())->get_id();\n\n    Declare(name, value);\n  }\n}\n\n
void MLIRGenImpl::MlirGen(pyast::Return* stmt) {\n  auto value = MlirGen(stmt->get_value().get());\n\n  builder_.create<mlir::func::ReturnOp>(Loc(), mlir::ValueRange({value}));\n}\n"
  },
  {
    "path": "oneflow/ir/oneflow-extension/mlir_jit_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowSupport.h\"\n#include \"llvm/Support/raw_ostream.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/device_type.pb.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"OneFlow/OKL/Kernel/JITOpInfer.h\"\n\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/IR/Block.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/OwningOpRef.h\"\n#include \"mlir/IR/Types.h\"\n#include \"mlir/InitAllDialects.h\"\n#include \"mlir/Parser/Parser.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> GetSbpFn(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> MlirJitOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return ir::jit::InferTensorDesc(ctx);\n}\n\nMaybe<void> MlirJitOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return ir::jit::InferTensorDesc(ctx);\n}\n\nMaybe<void> MlirJitOp::GetSbp(user_op::SbpContext* ctx) { return GetSbpFn(ctx); }\n\nMaybe<void> MlirJitOp::InferDataType(user_op::InferContext* ctx) {\n  return ir::jit::SetTensorDataType(ctx);\n  ;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/ir/oneflow-extension/mlir_jit_op_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OKL/Kernel/LauncherState.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/common/switch_func.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include \"oneflow/ir/include/OneFlow/Passes.h\"\n#include \"oneflow/ir/include/OneFlow/Extension.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\n#include \"mlir/Dialect/Tensor/IR/Tensor.h\"\n#include \"mlir/Parser/Parser.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/Dialect/Linalg/IR/Linalg.h\"\n#include \"mlir/ExecutionEngine/ExecutionEngine.h\"\n#include \"mlir/ExecutionEngine/MemRefUtils.h\"\n#include \"mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h\"\n#include \"mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h\"\n#include \"llvm/Support/TargetSelect.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nusing OpaqueMemRefDescriptor = std::shared_ptr<void>;\n\ntemplate<unsigned N, typename T>\nOpaqueMemRefDescriptor CreateMemRefDescriptor(user_op::Tensor* tensor) {\n  using MemRefType = StridedMemRefType<const T, N>;\n  auto desc = new MemRefType();\n  *desc = mlir::detail::makeStridedMemRefDescriptor<N>(\n      tensor->dptr<T>(), tensor->dptr<T>(),\n      {tensor->shape_view().ptr(), tensor->shape_view().ptr() + tensor->shape_view().NumAxes()},\n      {tensor->shape_view().ptr(), tensor->shape_view().ptr() + tensor->shape_view().NumAxes()});\n  auto deleter = [](void const* data) {\n    auto p = static_cast<MemRefType const*>(data);\n    delete p;\n  };\n  return OpaqueMemRefDescriptor(desc, deleter);\n}\n\ntemplate<unsigned N, typename T>\nOpaqueMemRefDescriptor CreateMutMemRefDescriptor(user_op::Tensor* tensor) {\n  using MemRefType = StridedMemRefType<T, N>;\n  auto desc = new MemRefType();\n  *desc = mlir::detail::makeStridedMemRefDescriptor<N>(\n      tensor->mut_dptr<T>(), tensor->mut_dptr<T>(),\n      {tensor->shape_view().ptr(), tensor->shape_view().ptr() + tensor->shape_view().NumAxes()},\n      {tensor->shape_view().ptr(), tensor->shape_view().ptr() + tensor->shape_view().NumAxes()});\n  auto deleter = [](void const* data) {\n    auto p = static_cast<MemRefType const*>(data);\n    delete p;\n  };\n  return OpaqueMemRefDescriptor(desc, deleter);\n}\n\n#define MAKE_STRIDED_MEM_REF_SWITCH_ENTRY(func_name, N, T) func_name<N, T>\nDEFINE_STATIC_SWITCH_FUNC(OpaqueMemRefDescriptor, CreateMemRefDescriptor,\n                          MAKE_STRIDED_MEM_REF_SWITCH_ENTRY, MAKE_NDIM_CTRV_SEQ(DIM_SEQ),\n                          MAKE_DATA_TYPE_CTRV_SEQ(ARITHMETIC_DATA_TYPE_SEQ));\nDEFINE_STATIC_SWITCH_FUNC(OpaqueMemRefDescriptor, CreateMutMemRefDescriptor,\n                          MAKE_STRIDED_MEM_REF_SWITCH_ENTRY, MAKE_NDIM_CTRV_SEQ(DIM_SEQ),\n                          
MAKE_DATA_TYPE_CTRV_SEQ(ARITHMETIC_DATA_TYPE_SEQ));\n#undef MAKE_STRIDED_MEM_REF_SWITCH_ENTRY\n\n
std::string GetMLIRCInterface(const std::string& func_name) {\n  return std::string(\"_mlir_ciface_\") + func_name;\n}\n\n
llvm::SmallVector<OpaqueMemRefDescriptor> GetMLIRCInterfaceArgs(\n    user_op::KernelComputeContext* ctx) {\n  llvm::SmallVector<OpaqueMemRefDescriptor> args{};\n  auto tensor = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n  args.push_back(SwitchCreateMemRefDescriptor(SwitchCase(1, kInt8), tensor));\n  for (auto& pair : ctx->inputs()) {\n    auto tensor = ctx->Tensor4ArgNameAndIndex(pair.first, pair.second);\n    auto ref = SwitchCreateMemRefDescriptor(\n        SwitchCase(tensor->shape_view().NumAxes(), tensor->data_type()), tensor);\n    args.push_back(ref);\n  }\n  for (auto& pair : ctx->outputs()) {\n    auto tensor = ctx->Tensor4ArgNameAndIndex(pair.first, pair.second);\n    auto ref = SwitchCreateMutMemRefDescriptor(\n        SwitchCase(tensor->shape_view().NumAxes(), tensor->data_type()), tensor);\n    args.push_back(ref);\n  }\n  return args;\n}\n\n
mlir::DialectRegistry getDialectRegistry() {\n  mlir::DialectRegistry registry;\n  registry\n      .insert<mlir::oneflow::OneFlowDialect, mlir::func::FuncDialect, mlir::memref::MemRefDialect,\n              mlir::tosa::TosaDialect, mlir::linalg::LinalgDialect, mlir::tensor::TensorDialect>();\n  mlir::registerLLVMDialectTranslation(registry);\n  mlir::registerBuiltinDialectTranslation(registry);\n  return registry;\n}\n\n
void WithMlirContext(\n    user_op::KernelComputeContext* ctx, const llvm::SmallVector<llvm::StringRef, 4>& ext_libs,\n    const std::function<mlir::OwningOpRef<mlir::ModuleOp>(mlir::MLIRContext* mlir_ctx)>& parse,\n    void* stream) {\n  mlir::MLIRContext mlir_ctx(getDialectRegistry());\n  mlir::OwningOpRef<mlir::ModuleOp> module = parse(&mlir_ctx);\n  CHECK(module) << \"failed to parse MLIR, op: \" << ctx->op_name();\n  if (ParseBooleanFromEnv(\"ONEFLOW_MLIR_STDOUT\", false)) { module->print(llvm::outs()); }\n\n
  mlir::ExecutionEngineOptions jitOptions;\n  jitOptions.transformer = {};\n  jitOptions.jitCodeGenOptLevel = std::nullopt;\n  jitOptions.sharedLibPaths = ext_libs;\n\n
  auto jit_or_error = mlir::ExecutionEngine::create(*module, jitOptions);\n  CHECK(!!jit_or_error) << \"failed to create JIT exe engine, \"\n                        << llvm::toString(jit_or_error.takeError());\n  auto jit = std::move(jit_or_error.get());\n  llvm::SmallVector<OpaqueMemRefDescriptor> args /* args must outlive JIT invocation */ =\n      GetMLIRCInterfaceArgs(ctx);\n  // invokePacked expects one void* per argument of the _mlir_ciface_ wrapper, each\n  // pointing at a memref descriptor, plus the trailing stream pointer appended below\n  llvm::SmallVector<void*> packed_args{};\n  for (auto& arg /* arg must be a reference */ : args) { packed_args.push_back(&arg); }\n  packed_args.push_back(&stream);\n  auto error = jit->invokePacked(GetMLIRCInterface(ctx->op_name()), packed_args);\n  CHECK(!error) << \"failed to invoke jit engine, error: \" << llvm::toString(std::move(error));\n}\n\n
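// Reads the memory pool size that the compiler recorded on the module as an integer\n// attribute; falls back to a 1-byte pool so the tmp buffer is always materialized.\n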
size_t inferOneFlowMemPoolSize(user_op::InferContext* ctx) {\n  using namespace user_op;\n  mlir::MLIRContext mlir_ctx(oneflow::okl::GetRegistry());\n  auto mlir_assembly = ctx->Attr<std::vector<char>>(\"mlir_assembly\");\n  auto mlir = mlir::parseSourceString<mlir::ModuleOp>(\n      llvm::StringRef(mlir_assembly.data(), mlir_assembly.size() - 1), &mlir_ctx);\n\n
  auto module = mlir.get();\n  // getAttrOfType returns a null attribute when the attribute is absent, which avoids\n  // casting a null attribute\n  if (auto mempool = module->getAttrOfType<mlir::IntegerAttr>(\n          mlir::oneflow::codegen::mempool::MEMPOOL_ATTR_NAME)) {\n    return mempool.getInt();\n  }\n  // Note: the tmp buffer must still be fetched in the mlir jit op, so return a non-zero\n  // size to avoid a null object error.\n  return 1;\n}\n\n
template<typename T>\nclass MlirJitCpuKernel final : public user_op::OpKernel {\n public:\n  MlirJitCpuKernel() = default;\n  ~MlirJitCpuKernel() = default;\n\n
 private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    llvm::SmallVector<llvm::StringRef, 4> ext_libs(\n        {SharedLibPaths()->begin(), SharedLibPaths()->end()});\n    WithMlirContext(\n        ctx, ext_libs,\n        [&ctx](mlir::MLIRContext* mlir_ctx) {\n          auto mlir_assembly = ctx->Attr<std::vector<char>>(\"mlir_assembly\");\n          return mlir::parseSourceString<mlir::ModuleOp>(\n              llvm::StringRef(mlir_assembly.data(), mlir_assembly.size() - 1), mlir_ctx);\n        },\n        nullptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n
#define REGISTER_MLIR_JIT_CPU_KERNEL(dtype)                                                       \\\n  REGISTER_USER_KERNEL(\"mlir_jit\")                                                                \\\n      .SetCreateFn<MlirJitCpuKernel<dtype>>()                                                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                             \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value))          \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) { return inferOneFlowMemPoolSize(ctx); }) \\\n      .SetInplaceProposalFn([](const user_op::InferContext&,                                      \\\n                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {   \\\n        return Maybe<void>::Ok();                                                                 \\\n      });\n\n
REGISTER_MLIR_JIT_CPU_KERNEL(float)\nREGISTER_MLIR_JIT_CPU_KERNEL(double)\nREGISTER_MLIR_JIT_CPU_KERNEL(int32_t)\nREGISTER_MLIR_JIT_CPU_KERNEL(int64_t)\n\n#undef REGISTER_MLIR_JIT_CPU_KERNEL\n\n
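// The GPU kernel below mirrors the CPU path; the difference is that, when built with\n// CUDA support, the raw CUDA stream is passed through so the JIT'ed code can launch\n// work on the caller's stream.\n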
#ifdef WITH_MLIR_CUDA_CODEGEN\n\ntemplate<typename T>\nclass MlirJitGpuKernel final : public user_op::OpKernel {\n public:\n  MlirJitGpuKernel() = default;\n  ~MlirJitGpuKernel() = default;\n\n
 private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    llvm::SmallVector<llvm::StringRef, 4> ext_libs(\n        {SharedLibPaths()->begin(), SharedLibPaths()->end()});\n    WithMlirContext(\n        ctx, ext_libs,\n        [&ctx](mlir::MLIRContext* mlir_ctx) {\n          auto mlir_assembly = ctx->Attr<std::vector<char>>(\"mlir_assembly\");\n          return mlir::parseSourceString<mlir::ModuleOp>(\n              llvm::StringRef(mlir_assembly.data(), mlir_assembly.size() - 1), mlir_ctx);\n        },\n#ifdef WITH_CUDA\n        ctx->stream()->As<ep::CudaStream>()->cuda_stream());\n#else\n        nullptr);\n#endif  // WITH_CUDA\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n
#define REGISTER_MLIR_JIT_GPU_KERNEL(dtype)                                                       \\\n  REGISTER_USER_KERNEL(\"mlir_jit\")                                                                \\\n      .SetCreateFn<MlirJitGpuKernel<dtype>>()                                                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                            \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value))          \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) { return inferOneFlowMemPoolSize(ctx); }) \\\n      .SetInplaceProposalFn([](const user_op::InferContext&,                                      \\\n                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {   \\\n        return Maybe<void>::Ok();                                                                 \\\n      });\n\n
REGISTER_MLIR_JIT_GPU_KERNEL(float)\nREGISTER_MLIR_JIT_GPU_KERNEL(double)\nREGISTER_MLIR_JIT_GPU_KERNEL(int32_t)\nREGISTER_MLIR_JIT_GPU_KERNEL(int64_t)\n\n#undef REGISTER_MLIR_JIT_GPU_KERNEL\n\n#endif  // WITH_MLIR_CUDA_CODEGEN\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/CMakeLists.txt",
    "content": "include_directories(${PROJECT_BINARY_DIR}/oneflow-lite)\ninclude_directories(${PROJECT_SOURCE_DIR}/oneflow-lite)\ninclude_directories(${PROJECT_SOURCE_DIR}/oneflow-lite/include)\ninclude_directories(${PROJECT_BINARY_DIR}/oneflow-lite/include)\n\nadd_subdirectory(schemas)\nadd_subdirectory(lib)\n\nset(LLVM_LINK_COMPONENTS Support)\nget_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)\n\nadd_llvm_executable(\n  oneflow-lite-compile\n  OneFlowLiteCompileMain.cpp\n  DEPENDS\n  MLIROneFlow\n  lite_schemas\n  OneFlowLiteConversion\n  flatcc-runtime)\n\nset(_origin_prefix \"\\$ORIGIN\")\nif(APPLE)\n  set(_origin_prefix \"@loader_path\")\nendif()\nset_target_properties(\n  oneflow-lite-compile PROPERTIES BUILD_WITH_INSTALL_RPATH OFF BUILD_RPATH \"${_origin_prefix}\"\n                                  INSTALL_RPATH \"${_origin_prefix}\")\n\nllvm_update_compile_flags(oneflow-lite-compile)\n\ntarget_link_libraries(oneflow-lite-compile PRIVATE OneFlowLiteConversion ${dialect_libs}\n                                                   flatcc-runtime)\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/OneFlowLiteCompileMain.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"llvm/ADT/SmallString.h\"\n#include \"llvm/Support/CommandLine.h\"\n#include \"llvm/Support/InitLLVM.h\"\n#include \"llvm/Support/Path.h\"\n#include \"llvm/Support/ToolOutputFile.h\"\n\n#include \"mlir/IR/OwningOpRef.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Parser/Parser.h\"\n#include \"mlir/Support/FileUtilities.h\"\n#include \"mlir/Support/LLVM.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/Support/ToolUtilities.h\"\n#include \"mlir/Transforms/Passes.h\"\n\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/ConvertToLiteExecutable.h\"\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\nLogicalResult Compile(int argc, char** argv) {\n  llvm::InitLLVM y(argc, argv);\n  static llvm::cl::OptionCategory mainOptions(\"OneFlowLite Compile Main Options\");\n\n  llvm::cl::opt<std::string> inputFiledir(llvm::cl::Positional,\n                                          llvm::cl::desc(\"<Input saved model directory>\"),\n                                          llvm::cl::Required, llvm::cl::cat(mainOptions));\n\n  llvm::cl::opt<std::string> outputFilename(\"o\", llvm::cl::desc(\"Output filename\"),\n                                            llvm::cl::value_desc(\"filename\"), llvm::cl::init(\"-\"),\n                                            llvm::cl::cat(mainOptions));\n\n  llvm::cl::list<std::string> targets(\"targets\",\n                                      llvm::cl::desc(\"Target backends for executable compilation\"),\n                                      llvm::cl::ZeroOrMore, llvm::cl::cat(mainOptions));\n\n  llvm::cl::ParseCommandLineOptions(argc, argv, \"OneFlowLite compile\\n\");\n\n  llvm::SmallString<128> inputFilename = StringRef(inputFiledir + \"/model.mlir\");\n  llvm::sys::path::native(inputFilename);\n\n  mlir::MLIRContext context;\n  context.getOrLoadDialect<oneflow::OneFlowDialect>();\n  context.loadDialect<mlir::func::FuncDialect>();\n\n  OwningOpRef<ModuleOp> module = parseSourceFile<ModuleOp>(inputFilename, &context);\n\n  ConvertOptions options;\n  options.checkpointDir = inputFiledir;\n  if (targets.empty()) {\n    options.target = \"host\";\n  } else {\n    if (targets.size() > 1) {\n      llvm::errs() << \"Support only one target currently.\\n\";\n      return failure();\n    }\n    options.target = targets[0];\n  }\n  llvm::errs() << \"Enable compilation for target: \" << options.target << \"\\n\";\n\n  llvm::SmallVector<uint8_t, 32> executable;\n  if (failed(ConvertToLiteExecutable(&context, module.get(), options, &executable))) {\n    return failure();\n  }\n  std::string errorMessage;\n  auto output = mlir::openOutputFile(outputFilename, &errorMessage);\n  if (!output) {\n    llvm::errs() << errorMessage << \"\\n\";\n    return failure();\n  }\n  output->os().write(reinterpret_cast<char*>(executable.data()), executable.size());\n  output->keep();\n  return success();\n}\n\n}  // namespace 
lite\n}  // namespace oneflow\n}  // namespace mlir\n\nint main(int argc, char** argv) {\n  if (mlir::failed(mlir::oneflow::lite::Compile(argc, argv))) { return 1; }\n  return 0;\n}\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/include/OneFlow/ConvertToLiteExecutable.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_CONVERTTOLITEEXECUTABLE_H_\n#define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_CONVERTTOLITEEXECUTABLE_H_\n\n#include \"llvm/ADT/SmallString.h\"\n\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/Support/LLVM.h\"\n\n#include \"OneFlow/FlatbufferUtils.h\"\n\nnamespace mlir {\nnamespace oneflow {\n\nnamespace lite {\n\ntypedef struct ConvertOptions {\n  llvm::SmallString<128> target;\n  llvm::SmallString<128> checkpointDir;\n} ConvertOptions;\n\nLogicalResult ConvertToLiteExecutable(MLIRContext* context, ModuleOp module, ConvertOptions options,\n                                      llvm::SmallVector<uint8_t, 32>* executable);\n\n}  // namespace lite\n\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_CONVERTTOLITEEXECUTABLE_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/include/OneFlow/FlatbufferUtils.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// Copyright 2020 The IREE Authors\n//\n// Licensed under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n\n#ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_FLATBUFFERUTILS_H_\n#define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_FLATBUFFERUTILS_H_\n\n#include <stddef.h>\n#include <stdint.h>\n\n#include <functional>\n\n#include \"llvm/ADT/ArrayRef.h\"\n#include \"llvm/ADT/STLExtras.h\"\n#include \"llvm/ADT/SmallVector.h\"\n#include \"llvm/ADT/StringRef.h\"\n#include \"llvm/Support/raw_ostream.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/MLIRContext.h\"\n#include \"mlir/IR/Types.h\"\n#include \"mlir/Support/LLVM.h\"\n#include \"mlir/Support/LogicalResult.h\"\n\n#pragma GCC diagnostic push\n#pragma GCC diagnostic ignored \"-Wcast-qual\"\n#include \"flatcc/flatcc_builder.h\"\n#include \"flatcc/flatcc_json_printer.h\"\n#include \"flatcc/reflection/reflection_builder.h\"\n#pragma GCC diagnostic pop\n\nnamespace mlir {\nnamespace oneflow {\n\nnamespace lite {\n\n// RAII wrapper for flatcc_builder_t; pass to functions requiring a builder.\n//\n// Usage:\n//   FlatbufferBuilder builder;\n//   // NOTE: FlatBuffers are built bottoms-up so we first generate our [uint8]:\n//   auto dataRef = builder.streamUint8Vec(...);\n//   // ... and then start the table that references it:\n//   my_type_start_as_root(builder);\n//   my_type_uint8_vec_field_add(builder, dataRef);\n//   my_type_end_as_root(builder);\n//   // ... and finally capture the results as an mlir::Attribute.\n//   auto attr = builder.getBufferAttr(mlirContext);\nclass FlatbufferBuilder {\n public:\n  FlatbufferBuilder();\n  ~FlatbufferBuilder();\n\n  operator flatcc_builder_t*() { return &builder; }\n\n  // Creates a string with the given string contents (including zeros).\n  flatbuffers_string_ref_t createString(StringRef value) {\n    if (value.empty()) return 0;\n    return flatbuffers_string_create(*this, value.data(), value.size());\n  }\n\n  // Creates a string vector containing all strings in the given range.\n  template<typename RangeTy>\n  flatbuffers_string_vec_ref_t createStringVec(RangeTy&& Range) {\n    auto stringRefs = llvm::to_vector<8>(llvm::map_range(Range, [&](StringRef value) {\n      return flatbuffers_string_create(*this, value.data(), value.size());\n    }));\n    if (stringRefs.empty()) return 0;\n    return flatbuffers_string_vec_create(*this, stringRefs.data(), stringRefs.size());\n  }\n\n  // Creates an offset vector with the given values. 
The source values will not\n  // be modified.\n  flatbuffers_vec_ref_t createOffsetVec(ArrayRef<flatcc_builder_ref_t> values) {\n    if (values.empty()) return 0;\n    return flatcc_builder_create_offset_vector(*this, values.data(), values.size());\n  }\n\n
  // Creates an offset vector with the given values.\n  // Unlike createOffsetVec this will destroy the input values array during\n  // serialization but will be much faster.\n  flatbuffers_vec_ref_t createOffsetVecDestructive(SmallVectorImpl<flatcc_builder_ref_t>& values) {\n    if (values.empty()) return 0;\n    return flatcc_builder_create_offset_vector_direct(*this, values.data(), values.size());\n  }\n\n
  // Creates an [int32] vec with the contents of the given range.\n  template<typename RangeTy>\n  flatbuffers_int32_vec_ref_t createInt32Vec(RangeTy&& Range) {\n    if (std::empty(Range)) return 0;\n    flatbuffers_int32_vec_start(*this);\n    for (int32_t v : Range) { flatbuffers_int32_vec_push_create(*this, v); }\n    return flatbuffers_int32_vec_end(*this);\n  }\n\n
  // Creates an [int64] vec with the contents of the given range.\n  template<typename RangeTy>\n  flatbuffers_int64_vec_ref_t createInt64Vec(RangeTy&& Range) {\n    if (std::empty(Range)) return 0;\n    flatbuffers_int64_vec_start(*this);\n    for (int64_t v : Range) { flatbuffers_int64_vec_push_create(*this, v); }\n    return flatbuffers_int64_vec_end(*this);\n  }\n\n
  // Provides a raw_ostream that |fn| can use to directly stream into a [uint8]\n  // in the FlatBuffer builder.\n  //\n  // Usage:\n  //   auto ref = builder.streamUint8Vec([&](llvm::raw_ostream &stream) {\n  //     stream << \"foo\";\n  //     return true;\n  //   });\n  //   ...\n  //   my_type_uint8_vec_field_add(builder, ref);  // use vec reference\n  //   ...\n  flatbuffers_uint8_vec_ref_t streamUint8Vec(std::function<bool(raw_ostream& stream)> fn,\n                                             size_t alignment = 16);\n\n
  // Captures the current contents of the flatbuffer builder and returns them\n  // as a shaped `vector<SIZExi8>` dense attr. The builder is left unmodified.\n  DenseIntElementsAttr getBufferAttr(MLIRContext* context);\n\n
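  // Usage sketch for copyToStream below (the file name is illustrative):\n  //   std::error_code ec;\n  //   llvm::raw_fd_ostream os(\"executable.bin\", ec);\n  //   if (!ec) { (void)builder.copyToStream(os); }\n  //\n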
  // Copies the current contents of the flatbuffer builder to the target output\n  // stream. The builder is left unmodified.\n  //\n  // This avoids the significant allocation that can happen when trying\n  // to stitch together all of the pages that were allocated in the emitter as\n  // the FlatBuffer was constructed; here we can just walk over each page and\n  // write it out in order without any allocations.\n  LogicalResult copyToStream(llvm::raw_ostream& output);\n\n
  using print_json_fn_t = int (*)(flatcc_json_printer_t* ctx, const char* buf, size_t bufsiz);\n\n
  // Prints the FlatBuffer in its canonical JSON format to the given stream.\n  // The builder is left unmodified.\n  //\n  // |pretty| enables newlines and indentation; somewhat useful for lit testing\n  // (as large byte buffers end up with a byte per line!).\n  //\n  // |includeDefaults| will force all values, including those that would not\n  // be serialized to the binary format due to the default value (0, etc) being\n  // omitted.\n  //\n  // NOTE: JSON representations will also differ structurally from the binary\n  // format as reused tables are printed wherever they are used as opposed to\n  // referencing the same bytes; meaning that this can't be used to verify that\n  // we are correctly memoizing strings/structures/etc.\n  LogicalResult printJsonToStream(bool pretty, bool includeDefaults, print_json_fn_t printJsonFn,\n                                  llvm::raw_ostream& output);\n\n private:\n  flatcc_builder_t builder;\n};\n\n
// Allows streaming bytes directly into a FlatBuffer `[uint8]` field.\n// The ostream runs unbuffered and routes all writes straight into pages\n// allocated by the FlatBuffer builder as we grow the output.\n//\n// Usage:\n//   flatbuffers_uint8_vec_start(builder);\n//   raw_flatbuffer_uint8_vec_ostream stream(builder);\n//   stream << \"foo\";\n//   stream.flush();  // *********** IMPORTANT ***********\n//   flatbuffers_uint8_vec_ref_t ref = flatbuffers_uint8_vec_end(builder);\nclass raw_flatbuffer_uint8_vec_ostream : public llvm::raw_ostream {\n public:\n  explicit raw_flatbuffer_uint8_vec_ostream(flatcc_builder_t* builder)\n      : raw_ostream(/*unbuffered=*/true), builder(builder) {}\n\n
  ~raw_flatbuffer_uint8_vec_ostream() override { flush(); }\n\n private:\n  void write_impl(const char* Ptr, size_t Size) override {\n    flatbuffers_uint8_vec_append(builder, reinterpret_cast<const uint8_t*>(Ptr), Size);\n    pos += Size;\n  }\n\n
  uint64_t current_pos() const override { return pos - GetNumBytesInBuffer(); }\n\n  flatcc_builder_t* builder;\n  uint64_t pos = 0;\n};\n\n}  // namespace lite\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_FLATBUFFERUTILS_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/include/OneFlow/OneFlowLiteUtils.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_ONEFLOWLITEUTILS_H_\n#define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_ONEFLOWLITEUTILS_H_\n\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/FlatbufferUtils.h\"\n\n#include \"mlir/Support/LLVM.h\"\n#include \"mlir/Support/LogicalResult.h\"\n\nnamespace mlir {\nnamespace oneflow {\n\nnamespace lite {\n\nOperation* getEntryJobOp(ModuleOp module);\nOperation* getEntryJobOp(Operation* op);\n\nStringAttr getValueDevice(Value value);\n\nOptional<StringRef> getLiteStringElementType(Type type);\nOptional<StringRef> getLiteStringElementType(::mlir::oneflow::DataType type);\n\nOptional<::oneflow::AttrType> getUserOpAttrType(StringRef opName, StringRef attrName);\n\nvoid serializeI32Attr(FlatbufferBuilder& builder, Attribute attribute);\nvoid serializeI64Attr(FlatbufferBuilder& builder, Attribute attribute);\nvoid serializeBoolAttr(FlatbufferBuilder& builder, Attribute attribute);\nvoid serializeF32Attr(FlatbufferBuilder& builder, Attribute attribute);\nvoid serializeF64Attr(FlatbufferBuilder& builder, Attribute attribute);\nvoid serializeStringAttr(FlatbufferBuilder& builder, Attribute attribute);\nvoid serializeShapeAttr(FlatbufferBuilder& builder, Attribute attribute);\nvoid serializeStrideAttr(FlatbufferBuilder& builder, Attribute attribute);\nvoid serializeDataTypeAttr(FlatbufferBuilder& builder, Attribute attribute);\nvoid serializeI32sAttr(FlatbufferBuilder& builder, Attribute attribute);\nvoid serializeI64sAttr(FlatbufferBuilder& builder, Attribute attribute);\nvoid serializeF32sAttr(FlatbufferBuilder& builder, Attribute attribute);\nvoid serializeDataTypesAttr(FlatbufferBuilder& builder, Attribute attribute);\nvoid serializeShapesAttr(FlatbufferBuilder& builder, Attribute attribute);\nvoid serializeStridesAttr(FlatbufferBuilder& builder, Attribute attribute);\nvoid serializeStringsAttr(FlatbufferBuilder& builder, Attribute attribute);\n\n}  // namespace lite\n\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_ONEFLOWLITEUTILS_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/include/OneFlow/Transform/FoldVariable.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_FOLDVARIABLE_H_\n#define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_FOLDVARIABLE_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\nstd::unique_ptr<mlir::Pass> createLiteFoldVariablePass();\n\n}  // namespace lite\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_FOLDVARIABLE_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/include/OneFlow/Transform/InferPlacement.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_INFERPLACEMENT_H_\n#define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_INFERPLACEMENT_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\nstd::unique_ptr<mlir::Pass> createLiteInferPlacementPass(StringRef target);\n\n}  // namespace lite\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_INFERPLACEMENT_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/include/OneFlow/Transform/InsertTransferOp.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_INSERTTRANSFEROP_H_\n#define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_INSERTTRANSFEROP_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\nstd::unique_ptr<mlir::Pass> createLiteInsertTransferOpPass();\n\n}  // namespace lite\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_INSERTTRANSFEROP_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/include/OneFlow/Transform/Lowering/LoweringAscend.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERING_LOWERINGASCEND_H_\n#define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERING_LOWERINGASCEND_H_\n\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/OneFlowOpTraits.h\"\n#include \"OneFlow/OneFlowLiteUtils.h\"\n\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/MLIRContext.h\"\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\nLogicalResult loweringAscend(OpBuilder& builder, Operation* callee, StringRef checkpointDir,\n                             llvm::SmallVector<uint8_t, 4>* loweringData);\n\n}  // namespace lite\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERING_LOWERINGASCEND_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/include/OneFlow/Transform/Lowering/LoweringAscendUtils.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERING_LOWERINGASCENDUTILS_H_\n#define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERING_LOWERINGASCENDUTILS_H_\n\n#include <vector>\n\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/OneFlowOpTraits.h\"\n#include \"OneFlow/OneFlowLiteUtils.h\"\n\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/MLIRContext.h\"\n\n// huawei ascend sdk headers\n#pragma GCC diagnostic push\n#pragma GCC diagnostic ignored \"-Wignored-qualifiers\"\n#include \"op_proto/built-in/inc/all_ops.h\"\n#pragma GCC diagnostic pop\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\ninline ge::Shape convertAscendShape(ArrayRef<int64_t> shape) {\n  return ge::Shape(std::vector<int64_t>{shape.begin(), shape.end()});\n}\n\ninline Optional<ge::DataType> convertAscendElementType(Type type) {\n  assert(type.isIntOrFloat());\n  if (type.isF16()) {\n    return ge::DT_FLOAT16;\n  } else if (type.isF32()) {\n    return ge::DT_FLOAT;\n  } else if (type.isF64()) {\n    return ge::DT_DOUBLE;\n  } else if (type.isSignedInteger()) {\n    int bitwidth = type.getIntOrFloatBitWidth();\n    if (bitwidth == 8) {\n      return ge::DT_INT8;\n    } else if (bitwidth == 16) {\n      return ge::DT_INT16;\n    } else if (bitwidth == 32) {\n      return ge::DT_INT32;\n    } else if (bitwidth == 64) {\n      return ge::DT_INT64;\n    } else {\n      return llvm::None;\n    }\n  } else if (type.isUnsignedInteger()) {\n    int bitwidth = type.getIntOrFloatBitWidth();\n    if (bitwidth == 8) {\n      return ge::DT_UINT8;\n    } else if (bitwidth == 16) {\n      return ge::DT_UINT16;\n    } else if (bitwidth == 32) {\n      return ge::DT_UINT32;\n    } else if (bitwidth == 64) {\n      return ge::DT_UINT64;\n    } else {\n      return llvm::None;\n    }\n  } else {\n    return llvm::None;\n  }\n}\n\ninline Optional<ge::DataType> convertAscendElementType(::mlir::oneflow::DataType type) {\n  switch (type) {\n    case ::mlir::oneflow::DataType::DT_Bool: return ge::DT_BOOL;\n    case ::mlir::oneflow::DataType::DT_Char: return ge::DT_UINT8;\n    case ::mlir::oneflow::DataType::DT_Float16: return ge::DT_FLOAT16;\n    case ::mlir::oneflow::DataType::DT_Float: return ge::DT_FLOAT;\n    case ::mlir::oneflow::DataType::DT_Double: return ge::DT_DOUBLE;\n    case ::mlir::oneflow::DataType::DT_Int8: return ge::DT_INT8;\n    case ::mlir::oneflow::DataType::DT_Int32: return ge::DT_INT32;\n    case ::mlir::oneflow::DataType::DT_Int64: return ge::DT_INT64;\n    case ::mlir::oneflow::DataType::DT_UInt8: return ge::DT_UINT8;\n    default: {\n      return llvm::None;\n    }\n  }\n}\n\ninline ge::TensorDesc convertAscendType(Type type) {\n  auto tensorType = type.cast<TensorType>();\n  assert(tensorType && \"type should be tensor type\");\n  auto elementType = 
convertAscendElementType(tensorType.getElementType());\n  if (!elementType) {\n    llvm::errs() << \"element type \" << tensorType.getElementType() << \" is not supported\\n\";\n    exit(1);\n  }\n  return ge::TensorDesc(convertAscendShape(tensorType.getShape()), ge::FORMAT_NCHW,\n                        elementType.value());\n}\n\n
inline ge::TensorDesc convertAscendType(::mlir::oneflow::DataType type, ArrayRef<int64_t> shape) {\n  auto elementType = convertAscendElementType(type);\n  if (!elementType) {\n    llvm::errs() << \"element type \" << static_cast<uint32_t>(type) << \" is not supported\\n\";\n    exit(1);\n  }\n  return ge::TensorDesc(convertAscendShape(shape), ge::FORMAT_NCHW, elementType.value());\n}\n\n
inline ge::TensorDesc convertAscendType(Attribute type, Attribute shape) {\n  SmallVector<int64_t, 4> shapeArray;\n  for (auto v : shape.dyn_cast<ArrayAttr>().getValue()) {\n    shapeArray.push_back(v.dyn_cast<IntegerAttr>().getSInt());\n  }\n  return convertAscendType(type.dyn_cast<mlir::oneflow::DataTypeAttr>().getValue(), shapeArray);\n}\n\n
inline ge::Operator::OpListInt convertPaddings(ArrayAttr paddings) {\n  assert(paddings.size() == 2 || paddings.size() == 4);\n  if (paddings.size() == 2) {\n    int s0 = paddings[0].dyn_cast<IntegerAttr>().getSInt();\n    int s1 = paddings[1].dyn_cast<IntegerAttr>().getSInt();\n    return ge::Operator::OpListInt({s0, s0, s1, s1});\n  } else {\n    int s0 = paddings[0].dyn_cast<IntegerAttr>().getSInt();\n    int s1 = paddings[1].dyn_cast<IntegerAttr>().getSInt();\n    int s2 = paddings[2].dyn_cast<IntegerAttr>().getSInt();\n    int s3 = paddings[3].dyn_cast<IntegerAttr>().getSInt();\n    return ge::Operator::OpListInt({s0, s1, s2, s3});\n  }\n}\n\n
inline ge::Operator::OpListInt convertStrides(ArrayAttr strides) {\n  assert(strides.size() == 2);\n  int s0 = strides[0].dyn_cast<IntegerAttr>().getSInt();\n  int s1 = strides[1].dyn_cast<IntegerAttr>().getSInt();\n  return ge::Operator::OpListInt({1, 1, s0, s1});\n}\n\n
// dilations and kernel_size reuse the {1, 1, h, w} expansion used for strides\ninline ge::Operator::OpListInt convertDilations(ArrayAttr dilations) {\n  return convertStrides(dilations);\n}\n\n
inline ge::Operator::OpListInt convertKernelSize(ArrayAttr kernel_size) {\n  return convertStrides(kernel_size);\n}\n\n
inline StringRef convertDataFormat(StringRef dataFormat) {\n  if (dataFormat == \"nchw\" || dataFormat == \"NCHW\" || dataFormat == \"channels_first\") {\n    return StringRef(\"NCHW\");\n  } else if (dataFormat == \"nhwc\" || dataFormat == \"NHWC\" || dataFormat == \"channels_last\") {\n    return StringRef(\"NHWC\");\n  } else {\n    llvm::errs() << \"unsupported data format \" << dataFormat << \"\\n\";\n    exit(1);\n  }\n}\n\n}  // namespace lite\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERING_LOWERINGASCENDUTILS_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/include/OneFlow/Transform/LoweringLaunchJob.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERINGLAUNCHJOB_H_\n#define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERINGLAUNCHJOB_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\nstd::unique_ptr<mlir::Pass> createLiteLoweringLaunchJobPass(StringRef checkpointDir);\n\n}  // namespace lite\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_LOWERINGLAUNCHJOB_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/include/OneFlow/Transform/MemoryPlanning.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_MEMORYPLANNING_H_\n#define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_MEMORYPLANNING_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\nstruct LiteBufferSegment {\n  StringRef device;\n  size_t size;\n  size_t alignment;\n};\n\nclass LiteBufferStrategy {\n public:\n  LiteBufferStrategy() = default;\n\n  const llvm::SmallVector<LiteBufferSegment, 4>& getSegments() const { return segments; }\n\n  llvm::SmallVector<LiteBufferSegment, 4>& getSegments() { return segments; }\n\n  int getValueSegmentId(Value value) const;\n  size_t getValueSegmentOffset(Value value) const;\n\n  LogicalResult insertValue(Value value, int segmentId, size_t segmentOffset);\n\n private:\n  llvm::SmallVector<LiteBufferSegment, 4> segments;\n  struct ValueSegmentInfo {\n    int segmentId;\n    size_t segmentOffset;\n  };\n  llvm::DenseMap<Value, ValueSegmentInfo> valueSegmentInfos;\n};\n\nstd::unique_ptr<mlir::Pass> createLiteMemoryPlanningPass(LiteBufferStrategy* strategy);\n\n}  // namespace lite\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_MEMORYPLANNING_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/include/OneFlow/Transform/PartitionLaunchJob.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_PARTITIONLAUNCHJOB_H_\n#define ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_PARTITIONLAUNCHJOB_H_\n\n#include \"mlir/Pass/Pass.h\"\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\nstd::unique_ptr<mlir::Pass> createLitePartitionLaunchJobPass();\n\n}  // namespace lite\n}  // namespace oneflow\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_ONEFLOW_LITE_INCLUDE_ONEFLOW_TRANSFORM_PARTITIONLAUNCHJOB_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/lib/CMakeLists.txt",
    "content": "add_subdirectory(OneFlow)\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/lib/OneFlow/CMakeLists.txt",
    "content": "set(LITE_LOWERING_SRCS \"\")\nset(LITE_LOWERING_LIBS \"\")\n\nif(LITE_USE_ASCEND_NPU)\n  include(cmake/FindAscendSdk.cmake)\n  include_directories(${ASCEND_INCLUDE_DIR})\n  include_directories(${ASCEND_INCLUDE_DIR}/../../opp)\n\n  add_definitions(-DLITE_USE_ASCEND_NPU=1)\n\n  list(APPEND LITE_LOWERING_SRCS Transform/Lowering/LoweringAscend.cpp)\n  list(APPEND LITE_LOWERING_LIBS ${ASCEND_LIBRARIES})\nendif()\n\noneflow_add_mlir_library(\n  OneFlowLiteConversion\n  ConvertToLiteExecutable.cpp\n  FlatbufferUtils.cpp\n  OneFlowLiteUtils.cpp\n  Transform/FoldVariable.cpp\n  Transform/InferPlacement.cpp\n  Transform/InsertTransferOp.cpp\n  Transform/MemoryPlanning.cpp\n  Transform/PartitionLaunchJob.cpp\n  Transform/LoweringLaunchJob.cpp\n  ${LITE_LOWERING_SRCS}\n  DEPENDS\n  MLIRIR\n  MLIRParser\n  MLIRPass\n  MLIRSPIRVDialect\n  MLIRTranslateLib\n  MLIRSupport\n  MLIROneFlow\n  MLIROneFlowExtension\n  flatcc-runtime\n  LINK_LIBS\n  MLIRIR\n  ${dialect_libs}\n  ${translation_libs}\n  MLIRParser\n  MLIRPass\n  MLIRSPIRVDialect\n  MLIRTranslateLib\n  MLIRSupport\n  MLIROneFlow\n  oneflow\n  $<TARGET_OBJECTS:of_op_schema>\n  MLIROneFlowExtension\n  ${LITE_LOWERING_LIBS}\n  $<BUILD_INTERFACE:flatcc-runtime>)\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/lib/OneFlow/ConvertToLiteExecutable.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/ConvertToLiteExecutable.h\"\n\n#include \"OneFlow/OneFlowDialect.h\"\n\n// undefine fallthrough to fix the conflicit of flatcc and fmt\n#if defined(fallthrough)\n#undef fallthrough\n#endif\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/OneFlowOpTraits.h\"\n#include \"OneFlow/Passes.h\"\n#include \"OneFlow/OneFlowUtils.h\"\n#include \"OneFlow/OneFlowLiteUtils.h\"\n#include \"OneFlow/Transform/FoldVariable.h\"\n#include \"OneFlow/Transform/InferPlacement.h\"\n#include \"OneFlow/Transform/InsertTransferOp.h\"\n#include \"OneFlow/Transform/LoweringLaunchJob.h\"\n#include \"OneFlow/Transform/MemoryPlanning.h\"\n#include \"OneFlow/Transform/PartitionLaunchJob.h\"\n\n#include \"llvm/ADT/SmallString.h\"\n#include \"llvm/Support/MemoryBuffer.h\"\n#include \"llvm/Support/Path.h\"\n#include \"llvm/Support/ToolOutputFile.h\"\n\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Support/LLVM.h\"\n#include \"mlir/Support/FileUtilities.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/Support/ToolUtilities.h\"\n#include \"mlir/Transforms/Passes.h\"\n\n#pragma GCC diagnostic push\n#pragma GCC diagnostic ignored \"-Wcast-qual\"\n#include \"schemas/executable_generated.h\"\n#pragma GCC diagnostic pop\n\nnamespace mlir {\nnamespace oneflow {\n\nnamespace lite {\n\nstatic flatbuffers_vec_ref_t createLiteOpAttrs(FlatbufferBuilder& builder, Operation* op) {\n  assert((llvm::dyn_cast<oneflow::UserOp>(op) || llvm::dyn_cast<UserOpCompatible>(op))\n         && \"the argument op is not a valid user op\");\n  llvm::SmallVector<oneflow_lite_AttrDef_ref_t, 4> attrDefs;\n  for (auto kv : op->getAttrDictionary()) {\n    auto attrName = kv.getName();\n    Optional<::oneflow::AttrType> attrType =\n        getUserOpAttrType(GetOpTypeName(op), attrName.strref());\n    if (!attrType) { continue; }\n\n    auto attrValue = kv.getValue();\n    StringRef strAttrType;\n    FlatbufferBuilder attrBuilder;\n    if (attrType.value() == ::oneflow::kAtInt32) {\n      strAttrType = \"i32\";\n      serializeI32Attr(attrBuilder, attrValue);\n    } else if (attrType.value() == ::oneflow::kAtInt64) {\n      strAttrType = \"i64\";\n      serializeI64Attr(attrBuilder, attrValue);\n    } else if (attrType.value() == ::oneflow::kAtBool) {\n      strAttrType = \"bool\";\n      serializeBoolAttr(attrBuilder, attrValue);\n    } else if (attrType.value() == ::oneflow::kAtFloat) {\n      strAttrType = \"f32\";\n      serializeF32Attr(attrBuilder, attrValue);\n    } else if (attrType.value() == ::oneflow::kAtDouble) {\n      strAttrType = \"f64\";\n      serializeF64Attr(attrBuilder, attrValue);\n    } else if (attrType.value() == ::oneflow::kAtString) {\n      strAttrType = \"str\";\n      serializeStringAttr(attrBuilder, attrValue);\n    } else if (attrType.value() == ::oneflow::kAtShape) {\n      strAttrType = \"shape\";\n      serializeShapeAttr(attrBuilder, attrValue);\n    } else if 
(attrType.value() == ::oneflow::kAtStride) {\n      strAttrType = \"stride\";\n      serializeStrideAttr(attrBuilder, attrValue);\n    } else if (attrType.value() == ::oneflow::kAtDataType) {\n      strAttrType = \"dtype\";\n      serializeDataTypeAttr(attrBuilder, attrValue);\n    } else if (attrType.value() == ::oneflow::kAtListInt32) {\n      strAttrType = \"i32s\";\n      serializeI32sAttr(attrBuilder, attrValue);\n    } else if (attrType.value() == ::oneflow::kAtListInt64) {\n      strAttrType = \"i64s\";\n      serializeI64sAttr(attrBuilder, attrValue);\n    } else if (attrType.value() == ::oneflow::kAtListFloat) {\n      strAttrType = \"f32s\";\n      serializeF32sAttr(attrBuilder, attrValue);\n    } else if (attrType.value() == ::oneflow::kAtListDataType) {\n      strAttrType = \"dtypes\";\n      serializeDataTypesAttr(attrBuilder, attrValue);\n    } else if (attrType.value() == ::oneflow::kAtListShape) {\n      strAttrType = \"shapes\";\n      serializeShapesAttr(attrBuilder, attrValue);\n    } else if (attrType.value() == ::oneflow::kAtListStride) {\n      strAttrType = \"strides\";\n      serializeStridesAttr(attrBuilder, attrValue);\n    } else if (attrType.value() == ::oneflow::kAtListString) {\n      strAttrType = \"strs\";\n      serializeStringsAttr(attrBuilder, attrValue);\n    } else {\n      llvm::errs() << \"error attribute type: \" << attrType.value() << \"\\n\";\n      exit(1);\n    }\n    oneflow_lite_AttrDef_start(builder);\n    oneflow_lite_AttrDef_type_add(builder, builder.createString(strAttrType));\n    oneflow_lite_AttrDef_key_add(builder, builder.createString(attrName.strref()));\n    oneflow_lite_AttrDef_value_add(builder, builder.streamUint8Vec([&](llvm::raw_ostream& stream) {\n      if (failed(attrBuilder.copyToStream(stream))) { return false; }\n      return true;\n    }));\n    attrDefs.push_back(oneflow_lite_AttrDef_end(builder));\n  }\n  return builder.createOffsetVecDestructive(attrDefs);\n}\n\nstatic flatbuffers_vec_ref_t createLiteVariableOpAttrs(FlatbufferBuilder& builder, VariableOp op,\n                                                       StringRef checkpointDir) {\n  llvm::SmallVector<oneflow_lite_AttrDef_ref_t, 4> attrDefs;\n  {\n    oneflow_lite_AttrDef_start(builder);\n    oneflow_lite_AttrDef_type_add(builder, builder.createString(\"dtype\"));\n    oneflow_lite_AttrDef_key_add(builder, builder.createString(\"dtype\"));\n    FlatbufferBuilder attrBuilder;\n    serializeDataTypeAttr(attrBuilder, op.getDataTypeAttr());\n    oneflow_lite_AttrDef_value_add(builder, builder.streamUint8Vec([&](llvm::raw_ostream& stream) {\n      if (failed(attrBuilder.copyToStream(stream))) { return false; }\n      return true;\n    }));\n    attrDefs.push_back(oneflow_lite_AttrDef_end(builder));\n  }\n  {\n    oneflow_lite_AttrDef_start(builder);\n    oneflow_lite_AttrDef_type_add(builder, builder.createString(\"shape\"));\n    oneflow_lite_AttrDef_key_add(builder, builder.createString(\"shape\"));\n    FlatbufferBuilder attrBuilder;\n    serializeShapeAttr(attrBuilder, op.getShapeAttr());\n    oneflow_lite_AttrDef_value_add(builder, builder.streamUint8Vec([&](llvm::raw_ostream& stream) {\n      if (failed(attrBuilder.copyToStream(stream))) { return false; }\n      return true;\n    }));\n    attrDefs.push_back(oneflow_lite_AttrDef_end(builder));\n  }\n  // serialize weight data\n  oneflow_lite_AttrDef_start(builder);\n  oneflow_lite_AttrDef_type_add(builder, builder.createString(\"u8\"));\n  oneflow_lite_AttrDef_key_add(builder, 
builder.createString(\"value\"));\n\n  llvm::SmallString<128> inputFilename;\n  llvm::sys::path::native(checkpointDir + \"/\" + op.getOpName() + \"/out\", inputFilename);\n  std::string errorMessage;\n  auto input = mlir::openInputFile(inputFilename, &errorMessage);\n  if (!input) {\n    llvm::errs() << errorMessage << \"\\n\";\n    exit(1);\n  }\n  oneflow_lite_AttrDef_value_add(builder, builder.streamUint8Vec([&](llvm::raw_ostream& stream) {\n    stream << input->getBuffer();\n    stream.flush();\n    return true;\n  }));\n  attrDefs.push_back(oneflow_lite_AttrDef_end(builder));\n  return builder.createOffsetVecDestructive(attrDefs);\n}\n\nstatic oneflow_lite_OpDef_ref_t createLiteVariableOpDef(\n    FlatbufferBuilder& builder, VariableOp op, llvm::DenseMap<Value, int>& valueOrdering,\n    const llvm::DenseMap<StringRef, int>& deviceOrdering, StringRef checkpointDir) {\n  oneflow_lite_OpDef_start(builder);\n  oneflow_lite_OpDef_name_add(builder, builder.createString(\"constant\"));\n  oneflow_lite_OpDef_inputs_add(builder, 0);\n\n  auto index = valueOrdering.try_emplace(op.getOutput(), valueOrdering.size()).first->second;\n  oneflow_lite_OpDef_outputs_add(builder,\n                                 builder.createInt32Vec(llvm::SmallVector<int32_t, 4>{index}));\n\n  oneflow_lite_OpDef_attrs_add(builder, createLiteVariableOpAttrs(builder, op, checkpointDir));\n\n  auto it = deviceOrdering.find(op.getDeviceTag());\n  assert(it != deviceOrdering.end());\n  oneflow_lite_OpDef_device_add(builder, it->second);\n  return oneflow_lite_OpDef_end(builder);\n}\n\nstatic oneflow_lite_OpDef_ref_t createLiteOpDef(\n    FlatbufferBuilder& builder, Operation* op, llvm::DenseMap<Value, int>& valueOrdering,\n    const llvm::DenseMap<StringRef, int>& deviceOrdering) {\n  llvm::SmallVector<size_t, 4> inputOrdering;\n  for (const auto& operand : op->getOperands()) {\n    auto it = valueOrdering.find(operand);\n    if (it == valueOrdering.end()) {\n      it = valueOrdering.try_emplace(operand, valueOrdering.size()).first;\n    }\n    inputOrdering.push_back(it->second);\n  }\n  llvm::SmallVector<size_t, 4> outputOrdering;\n  for (const auto& result : op->getResults()) {\n    auto it = valueOrdering.find(result);\n    if (it == valueOrdering.end()) {\n      it = valueOrdering.try_emplace(result, valueOrdering.size()).first;\n    }\n    outputOrdering.push_back(it->second);\n  }\n  oneflow_lite_OpDef_start(builder);\n  oneflow_lite_OpDef_name_add(builder, builder.createString(GetOpTypeName(op)));\n  oneflow_lite_OpDef_inputs_add(builder, builder.createInt32Vec(inputOrdering));\n  oneflow_lite_OpDef_outputs_add(builder, builder.createInt32Vec(outputOrdering));\n\n  oneflow_lite_OpDef_attrs_add(builder, createLiteOpAttrs(builder, op));\n\n  auto device =\n      op->getAttrOfType<StringAttr>(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr());\n  auto it = deviceOrdering.find(device.getValue());\n  assert(it != deviceOrdering.end());\n  oneflow_lite_OpDef_device_add(builder, it->second);\n  return oneflow_lite_OpDef_end(builder);\n}\n\nstatic oneflow_lite_TensorDef_ref_t createLiteTensorDef(FlatbufferBuilder& builder, Value value,\n                                                        int segmentId, size_t segmentOffset) {\n  TensorType type = value.getType().cast<TensorType>();\n  oneflow_lite_TensorDef_start(builder);\n  auto elemType = getLiteStringElementType(type.getElementType());\n  if (!elemType) {\n    llvm::errs() << \"error tensor element type: \" << type.getElementType() << \"\\n\";\n    exit(1);\n  
}\n  oneflow_lite_TensorDef_type_add(builder, builder.createString(elemType.value()));\n  oneflow_lite_TensorDef_layout_add(builder, builder.createString(\"default\"));\n  oneflow_lite_TensorDef_sizes_add(builder, builder.createInt64Vec(type.getShape()));\n  oneflow_lite_TensorDef_strides_add(builder,\n                                     builder.createInt64Vec(llvm::SmallVector<int64_t, 4>{}));\n  oneflow_lite_TensorDef_segment_id_add(builder, segmentId);\n  oneflow_lite_TensorDef_segment_offset_add(builder, segmentOffset);\n  return oneflow_lite_TensorDef_end(builder);\n}\n\nstatic oneflow_lite_BufferSegmentDef_ref_t createLiteBufferSegmentDef(\n    FlatbufferBuilder& builder, const LiteBufferSegment& segment,\n    const llvm::DenseMap<StringRef, int>& deviceOrdering) {\n  auto it = deviceOrdering.find(segment.device);\n  assert(it != deviceOrdering.end());\n  oneflow_lite_BufferSegmentDef_start(builder);\n  oneflow_lite_BufferSegmentDef_size_add(builder, segment.size);\n  oneflow_lite_BufferSegmentDef_device_add(builder, it->second);\n  oneflow_lite_BufferSegmentDef_alignment_add(builder, static_cast<int>(segment.alignment));\n  return oneflow_lite_BufferSegmentDef_end(builder);\n}\n\nLogicalResult ConvertToLiteExecutable(MLIRContext* context, ModuleOp module, ConvertOptions options,\n                                      llvm::SmallVector<uint8_t, 32>* executable) {\n  mlir::PassManager pm(context);\n  pm.addPass(createCanonicalizerPass());\n  pm.addPass(createLiteFoldVariablePass());\n  pm.addPass(createLiteInferPlacementPass(options.target));\n  pm.addPass(createLiteInsertTransferOpPass());\n  pm.addPass(createLitePartitionLaunchJobPass());\n  pm.addPass(createLiteLoweringLaunchJobPass(options.checkpointDir));\n  pm.addPass(createCanonicalizerPass());\n\n  LiteBufferStrategy bufferStrategy;\n  pm.addPass(createLiteMemoryPlanningPass(&bufferStrategy));\n  if (mlir::failed(pm.run(module))) {\n    llvm::errs() << \"Failed to run oneflow lite compilation passes.\\n\";\n    return failure();\n  }\n\n  // llvm::errs() << *module << \"\\n\";\n\n  Operation* entryJobOp = getEntryJobOp(module);\n  if (!entryJobOp) {\n    llvm::errs() << \"Job not found in module: \" << *module;\n    return failure();\n  }\n\n  auto funcName = entryJobOp->getAttrOfType<StringAttr>(\"sym_name\");\n  llvm::SmallVector<StringRef, 4> devices;\n  llvm::DenseMap<StringRef, int> deviceOrdering;\n  for (const auto& segment : bufferStrategy.getSegments()) {\n    int ordering = deviceOrdering.size();\n    if (deviceOrdering.try_emplace(segment.device, ordering).second) {\n      devices.push_back(segment.device);\n    }\n  }\n  FlatbufferBuilder builder;\n  oneflow_lite_ExecutableDef_start_as_root(builder);\n  oneflow_lite_ExecutableDef_version_add(builder, 0);\n  oneflow_lite_ExecutableDef_name_add(builder, builder.createString(funcName.getValue()));\n  oneflow_lite_ExecutableDef_devices_add(builder, builder.createStringVec(devices));\n\n  llvm::DenseMap<Value, int> valueOrdering;\n  llvm::SmallVector<int, 4> inputValueOrdering, outputValueOrdering;\n  llvm::SmallVector<StringRef, 4> inputValueNames, outputValueNames;\n  llvm::SmallVector<oneflow_lite_OpDef_ref_t, 4> opDefs;\n\n  entryJobOp->walk([&](Operation* op) {\n    if (!op->hasTrait<OpTrait::IsOpConfCompatible>()) { return; }\n    if (auto inputOp = llvm::dyn_cast<InputOp>(op)) {\n      auto it = valueOrdering.try_emplace(inputOp.getOutput(), valueOrdering.size()).first;\n      inputValueOrdering.push_back(it->second);\n      inputValueNames.push_back(\n          
op->getAttrOfType<StringAttr>(OpTrait::IsOpConfCompatible<void>::getOpNameAttr())\n              .getValue());\n    } else if (auto outputOp = llvm::dyn_cast<OutputOp>(op)) {\n      auto it = valueOrdering.try_emplace(outputOp.getInput(), valueOrdering.size()).first;\n      outputValueOrdering.push_back(it->second);\n      outputValueNames.push_back(\n          op->getAttrOfType<StringAttr>(OpTrait::IsOpConfCompatible<void>::getOpNameAttr())\n              .getValue());\n    } else if (auto variableOp = llvm::dyn_cast<VariableOp>(op)) {\n      opDefs.push_back(createLiteVariableOpDef(builder, variableOp, valueOrdering, deviceOrdering,\n                                               options.checkpointDir));\n    } else {\n      opDefs.push_back(createLiteOpDef(builder, op, valueOrdering, deviceOrdering));\n    }\n  });\n  oneflow_lite_ExecutableDef_ops_add(builder, builder.createOffsetVecDestructive(opDefs));\n\n  llvm::SmallVector<Value, 4> orderedValues(valueOrdering.size());\n  for (auto it : valueOrdering) { orderedValues[it.second] = it.first; }\n  llvm::SmallVector<oneflow_lite_TensorDef_ref_t, 4> tensorDefs;\n  for (auto value : orderedValues) {\n    int segmentId = bufferStrategy.getValueSegmentId(value);\n    size_t segmentOffset = bufferStrategy.getValueSegmentOffset(value);\n    tensorDefs.push_back(createLiteTensorDef(builder, value, segmentId, segmentOffset));\n  }\n  oneflow_lite_ExecutableDef_operands_add(builder, builder.createOffsetVecDestructive(tensorDefs));\n\n  oneflow_lite_ExecutableDef_inputs_add(builder, builder.createInt32Vec(inputValueOrdering));\n  oneflow_lite_ExecutableDef_outputs_add(builder, builder.createInt32Vec(outputValueOrdering));\n  oneflow_lite_ExecutableDef_input_names_add(builder, builder.createStringVec(inputValueNames));\n  oneflow_lite_ExecutableDef_output_names_add(builder, builder.createStringVec(outputValueNames));\n\n  llvm::SmallVector<oneflow_lite_BufferSegmentDef_ref_t, 4> segmentDefs;\n  for (const auto& segment : bufferStrategy.getSegments()) {\n    segmentDefs.push_back(createLiteBufferSegmentDef(builder, segment, deviceOrdering));\n  }\n  oneflow_lite_ExecutableDef_segments_add(builder, builder.createOffsetVecDestructive(segmentDefs));\n\n  oneflow_lite_ExecutableDef_end_as_root(builder);\n\n  size_t packedSize = flatcc_builder_get_buffer_size(builder);\n  executable->resize(packedSize);\n  flatcc_builder_copy_buffer(builder, executable->data(), packedSize);\n  return success();\n}\n\n}  // namespace lite\n\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/lib/OneFlow/FlatbufferUtils.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n// Copyright 2020 The IREE Authors\n//\n// Licensed under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n\n#include \"OneFlow/FlatbufferUtils.h\"\n\n#include <assert.h>\n#include <stdlib.h>\n\n#include <cstdint>\n#include <type_traits>\n\n#include \"mlir/IR/BuiltinTypes.h\"\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\n// Combines all pages of the FlatBuffer builder into a single contiguous byte\n// buffer and returns the result.\n//\n// NOTE: this is a alloc/copy. We need to have a single contiguous buffer to\n// pass into the elements factory function and the data we have in the\n// builder is paged. If we end up with a custom attribute type for this that\n// does not support storage uniquing then we can directly allocate and copy\n// the pages into the buffer without the extra copy.\nstatic SmallVector<uint8_t, 32> cloneBufferIntoContiguousBytes(FlatbufferBuilder& fbb) {\n  size_t packedSize = flatcc_builder_get_buffer_size(fbb);\n  SmallVector<uint8_t, 32> packedData(packedSize);\n  void* result = flatcc_builder_copy_buffer(fbb, packedData.data(), packedData.size());\n  assert(result && \"flatcc_emitter_t impl failed (non-default?)\");\n  (void)result;\n  return packedData;\n}\n\nFlatbufferBuilder::FlatbufferBuilder() { flatcc_builder_init(&builder); }\n\nFlatbufferBuilder::~FlatbufferBuilder() { flatcc_builder_clear(&builder); }\n\nflatbuffers_uint8_vec_ref_t FlatbufferBuilder::streamUint8Vec(\n    std::function<bool(raw_ostream& stream)> fn, size_t alignment) {\n  flatcc_builder_start_vector(*this, 1, alignment, FLATBUFFERS_COUNT_MAX(1));\n  raw_flatbuffer_uint8_vec_ostream stream(*this);\n  if (!fn(stream)) { return 0; }\n  stream.flush();\n  return flatbuffers_uint8_vec_end(*this);\n}\n\nDenseIntElementsAttr FlatbufferBuilder::getBufferAttr(MLIRContext* context) {\n  // We require direct access to the FlatBuffer bytes so we can pass them to\n  // the attribute constructor (which needs to inspect them all for uniquing).\n  auto bufferData = cloneBufferIntoContiguousBytes(*this);\n\n  // NOTE: ew. OpaqueAttr may be better? 
It does equality checks but won't try\n  // to unique and would let us get a mutable buffer out.\n  return DenseIntElementsAttr::get(\n      VectorType::get({static_cast<int64_t>(bufferData.size())}, IntegerType::get(context, 8)),\n      std::move(bufferData));\n}\n\nLogicalResult FlatbufferBuilder::copyToStream(llvm::raw_ostream& output) {\n  // NOTE: expected to be the default emitter.\n  auto* E = reinterpret_cast<flatcc_emitter_t*>(flatcc_builder_get_emit_context(*this));\n\n  if (!E->front) { return failure(); }\n  if (E->front == E->back) {\n    output.write(reinterpret_cast<char*>(E->front_cursor), E->used);\n    return success();\n  }\n  size_t len = FLATCC_EMITTER_PAGE_SIZE - E->front_left;\n  output.write(reinterpret_cast<char*>(E->front_cursor), len);\n  flatcc_emitter_page_t* p = E->front->next;\n  while (p != E->back) {\n    output.write(reinterpret_cast<char*>(p->page), FLATCC_EMITTER_PAGE_SIZE);\n    p = p->next;\n  }\n  output.write(reinterpret_cast<char*>(p->page), FLATCC_EMITTER_PAGE_SIZE - E->back_left);\n  return success();\n}\n\nLogicalResult FlatbufferBuilder::printJsonToStream(bool pretty, bool includeDefaults,\n                                                   print_json_fn_t printJsonFn,\n                                                   llvm::raw_ostream& output) {\n  // The printer requires direct access to the FlatBuffer bytes so clone here.\n  auto bufferData = cloneBufferIntoContiguousBytes(*this);\n  auto moduleData = ArrayRef<uint8_t>(bufferData.data(), bufferData.size())\n                        .drop_front(sizeof(flatbuffers_uoffset_t));\n\n  flatcc_json_printer_t printer;\n  flatcc_json_printer_init_dynamic_buffer(&printer, /*buffer_size=*/0);\n  flatcc_json_printer_set_indent(&printer, pretty ? 2 : 0);\n  flatcc_json_printer_set_skip_default(&printer, !includeDefaults);\n  flatcc_json_printer_set_force_default(&printer, includeDefaults);\n\n  // Print into the dynamically-resizing buffer. May fail if OOM.\n  int rv =\n      printJsonFn(&printer, reinterpret_cast<const char*>(moduleData.data()), moduleData.size());\n  if (rv == -1) {\n    flatcc_json_printer_clear(&printer);\n    return failure();\n  }\n\n  // Take the buffer from the printer; note that it is 0 terminated and can be\n  // used directly as a cstr if needed.\n  size_t outputSize = 0;\n  char* outputBytes =\n      reinterpret_cast<char*>(flatcc_json_printer_finalize_dynamic_buffer(&printer, &outputSize));\n  output.write(outputBytes, outputSize);\n  free(outputBytes);\n\n  return success();\n}\n\n}  // namespace lite\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/lib/OneFlow/OneFlowLiteUtils.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/OneFlowLiteUtils.h\"\n\n#include \"oneflow/core/framework/user_op_def.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/OneFlowOpTraits.h\"\n#include \"OneFlow/OneFlowUtils.h\"\n\n#include \"mlir/Support/LLVM.h\"\n#include \"mlir/Support/LogicalResult.h\"\n\n#pragma GCC diagnostic push\n#pragma GCC diagnostic ignored \"-Wcast-qual\"\n#include \"schemas/executable_generated.h\"\n#include \"schemas/attributes/bool_generated.h\"\n#include \"schemas/attributes/f32_generated.h\"\n#include \"schemas/attributes/f32s_generated.h\"\n#include \"schemas/attributes/f64_generated.h\"\n#include \"schemas/attributes/i32_generated.h\"\n#include \"schemas/attributes/i32s_generated.h\"\n#include \"schemas/attributes/i64_generated.h\"\n#include \"schemas/attributes/i64s_generated.h\"\n#include \"schemas/attributes/shape_generated.h\"\n#include \"schemas/attributes/shapes_generated.h\"\n#include \"schemas/attributes/str_generated.h\"\n#include \"schemas/attributes/strs_generated.h\"\n#pragma GCC diagnostic pop\n\nnamespace mlir {\nnamespace oneflow {\n\nnamespace lite {\n\nOperation* getEntryJobOp(ModuleOp module) { return getEntryJobOp(module.getOperation()); }\n\nOperation* getEntryJobOp(Operation* op) {\n  Operation* entry = nullptr;\n  op->walk([&](oneflow::Job job) -> WalkResult {\n    entry = job.getOperation();\n    return WalkResult::advance();\n  });\n  return entry;\n}\n\nStringAttr getValueDevice(Value value) {\n  StringAttr device;\n  Operation* op = value.getDefiningOp();\n  if (auto copyOp = dyn_cast<CopyOp>(op)) {\n    device = copyOp.getDeviceTypeAttr();\n  } else {\n    device = value.getDefiningOp()->getAttrOfType<StringAttr>(\n        OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr());\n  }\n  return device;\n}\n\nOptional<StringRef> getLiteStringElementType(Type type) {\n  assert(type.isIntOrFloat());\n  if (type.isF16()) {\n    return StringRef(\"f16\");\n  } else if (type.isBF16()) {\n    return StringRef(\"bf16\");\n  } else if (type.isF32()) {\n    return StringRef(\"f32\");\n  } else if (type.isF64()) {\n    return StringRef(\"f64\");\n  } else if (type.isSignedInteger()) {\n    int bitwidth = type.getIntOrFloatBitWidth();\n    return StringRef(\"i\" + llvm::Twine(bitwidth).str());\n  } else if (type.isUnsignedInteger()) {\n    int bitwidth = type.getIntOrFloatBitWidth();\n    return StringRef(\"u\" + llvm::Twine(bitwidth).str());\n  } else {\n    return llvm::None;\n  }\n}\n\nOptional<StringRef> getLiteStringElementType(::mlir::oneflow::DataType type) {\n  switch (type) {\n    case ::mlir::oneflow::DataType::DT_Bool: return StringRef(\"bool\");\n    case ::mlir::oneflow::DataType::DT_Char: return StringRef(\"char\");\n    case ::mlir::oneflow::DataType::DT_Float16: return StringRef(\"f16\");\n    case ::mlir::oneflow::DataType::DT_Float: 
return StringRef(\"f32\");\n    case ::mlir::oneflow::DataType::DT_Double: return StringRef(\"f64\");\n    case ::mlir::oneflow::DataType::DT_Int8: return StringRef(\"i8\");\n    case ::mlir::oneflow::DataType::DT_Int32: return StringRef(\"i32\");\n    case ::mlir::oneflow::DataType::DT_Int64: return StringRef(\"i64\");\n    case ::mlir::oneflow::DataType::DT_UInt8: return StringRef(\"u8\");\n    default: {\n      return llvm::None;\n    }\n  }\n}\n\nOptional<::oneflow::AttrType> getUserOpAttrType(StringRef opName, StringRef attrName) {\n  const ::oneflow::user_op::OpRegistryResult* val =\n      ::oneflow::user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(opName.str());\n  if (!val) {\n    llvm::errs() << \"unregistered user op: \" << opName << \"\\n\";\n    exit(1);\n  }\n  ::oneflow::user_op::UserOpDefWrapper op_def(val->op_def);\n  if (!op_def.IsAttrName(attrName.str())) { return llvm::None; }\n  return op_def.GetAttrType(attrName.str());\n}\n\nvoid serializeI32Attr(FlatbufferBuilder& builder, Attribute attribute) {\n  oneflow_lite_I32Def_start_as_root(builder);\n  oneflow_lite_I32Def_value_add(builder, attribute.dyn_cast<IntegerAttr>().getSInt());\n  oneflow_lite_I32Def_end_as_root(builder);\n}\n\nvoid serializeI64Attr(FlatbufferBuilder& builder, Attribute attribute) {\n  oneflow_lite_I64Def_start_as_root(builder);\n  oneflow_lite_I64Def_value_add(builder, attribute.dyn_cast<IntegerAttr>().getSInt());\n  oneflow_lite_I64Def_end_as_root(builder);\n}\n\nvoid serializeBoolAttr(FlatbufferBuilder& builder, Attribute attribute) {\n  oneflow_lite_BoolDef_start_as_root(builder);\n  oneflow_lite_BoolDef_value_add(builder, attribute.dyn_cast<BoolAttr>().getValue());\n  oneflow_lite_BoolDef_end_as_root(builder);\n}\n\nvoid serializeF32Attr(FlatbufferBuilder& builder, Attribute attribute) {\n  oneflow_lite_F32Def_start_as_root(builder);\n  oneflow_lite_F32Def_value_add(builder,\n                                attribute.dyn_cast<FloatAttr>().getValue().convertToFloat());\n  oneflow_lite_F32Def_end_as_root(builder);\n}\n\nvoid serializeF64Attr(FlatbufferBuilder& builder, Attribute attribute) {\n  oneflow_lite_F64Def_start_as_root(builder);\n  oneflow_lite_F64Def_value_add(builder,\n                                attribute.dyn_cast<FloatAttr>().getValue().convertToDouble());\n  oneflow_lite_F64Def_end_as_root(builder);\n}\n\nvoid serializeStringAttr(FlatbufferBuilder& builder, Attribute attribute) {\n  oneflow_lite_StringDef_start_as_root(builder);\n  oneflow_lite_StringDef_value_add(\n      builder, builder.createString(attribute.dyn_cast<StringAttr>().getValue()));\n  oneflow_lite_StringDef_end_as_root(builder);\n}\n\nvoid serializeShapeAttr(FlatbufferBuilder& builder, Attribute attribute) {\n  oneflow_lite_ShapeDef_start_as_root(builder);\n  SmallVector<int64_t, 4> shape;\n  for (auto v : attribute.dyn_cast<ArrayAttr>().getValue()) {\n    shape.push_back(v.dyn_cast<IntegerAttr>().getSInt());\n  }\n  oneflow_lite_ShapeDef_value_add(builder, builder.createInt64Vec(shape));\n  oneflow_lite_ShapeDef_end_as_root(builder);\n}\n\nvoid serializeStrideAttr(FlatbufferBuilder& builder, Attribute attribute) {\n  serializeShapeAttr(builder, attribute);\n}\n\nvoid serializeDataTypeAttr(FlatbufferBuilder& builder, Attribute attribute) {\n  oneflow_lite_StringDef_start_as_root(builder);\n  auto dtype =\n      getLiteStringElementType(attribute.dyn_cast<mlir::oneflow::DataTypeAttr>().getValue());\n  if (!dtype) {\n    llvm::errs() << \"error data type: \" << attribute << \"\\n\";\n    exit(1);\n  }\n  
oneflow_lite_StringDef_value_add(builder, builder.createString(dtype.value()));\n  oneflow_lite_StringDef_end_as_root(builder);\n}\n\nvoid serializeI32sAttr(FlatbufferBuilder& builder, Attribute attribute) {\n  oneflow_lite_I32sDef_start_as_root(builder);\n  SmallVector<int32_t, 4> vec;\n  for (auto v : attribute.dyn_cast<ArrayAttr>().getValue()) {\n    vec.push_back(v.dyn_cast<IntegerAttr>().getSInt());\n  }\n  oneflow_lite_I32sDef_value_add(builder, builder.createInt32Vec(vec));\n  oneflow_lite_I32sDef_end_as_root(builder);\n}\n\nvoid serializeI64sAttr(FlatbufferBuilder& builder, Attribute attribute) {\n  oneflow_lite_I64sDef_start_as_root(builder);\n  SmallVector<int64_t, 4> vec;\n  for (auto v : attribute.dyn_cast<ArrayAttr>().getValue()) {\n    vec.push_back(v.dyn_cast<IntegerAttr>().getSInt());\n  }\n  oneflow_lite_I64sDef_value_add(builder, builder.createInt64Vec(vec));\n  oneflow_lite_I64sDef_end_as_root(builder);\n}\n\nvoid serializeF32sAttr(FlatbufferBuilder& builder, Attribute attribute) {\n  oneflow_lite_F32sDef_start_as_root(builder);\n  flatbuffers_float_vec_start(builder);\n  for (auto v : attribute.dyn_cast<ArrayAttr>().getValue()) {\n    flatbuffers_float_vec_push_create(builder, v.dyn_cast<FloatAttr>().getValue().convertToFloat());\n  }\n  oneflow_lite_F32sDef_value_add(builder, flatbuffers_float_vec_end(builder));\n  oneflow_lite_F32sDef_end_as_root(builder);\n}\n\nvoid serializeDataTypesAttr(FlatbufferBuilder& builder, Attribute attribute) {\n  oneflow_lite_StringsDef_start_as_root(builder);\n  llvm::SmallVector<StringRef, 4> dtypes;\n  for (auto v : attribute.dyn_cast<ArrayAttr>().getValue()) {\n    auto dtype = getLiteStringElementType(v.dyn_cast<mlir::oneflow::DataTypeAttr>().getValue());\n    if (!dtype) {\n      llvm::errs() << \"error data type: \" << v << \"\\n\";\n      exit(1);\n    }\n    dtypes.push_back(dtype.value());\n  }\n  oneflow_lite_StringsDef_value_add(builder, builder.createStringVec(dtypes));\n  oneflow_lite_StringsDef_end_as_root(builder);\n}\n\nvoid serializeShapesAttr(FlatbufferBuilder& builder, Attribute attribute) {\n  oneflow_lite_ShapesDef_start_as_root(builder);\n  SmallVector<oneflow_lite_ShapeDef_ref_t, 4> shapeDefs;\n  for (auto v : attribute.dyn_cast<ArrayAttr>().getValue()) {\n    oneflow_lite_ShapeDef_start(builder);\n    SmallVector<int64_t, 4> vec;\n    for (auto p : v.dyn_cast<ArrayAttr>().getValue()) {\n      vec.push_back(p.dyn_cast<IntegerAttr>().getSInt());\n    }\n    oneflow_lite_ShapeDef_value_add(builder, builder.createInt64Vec(vec));\n    shapeDefs.push_back(oneflow_lite_ShapeDef_end(builder));\n  }\n  oneflow_lite_ShapesDef_value_add(builder, builder.createOffsetVecDestructive(shapeDefs));\n  oneflow_lite_ShapesDef_end_as_root(builder);\n}\n\nvoid serializeStridesAttr(FlatbufferBuilder& builder, Attribute attribute) {\n  return serializeShapesAttr(builder, attribute);\n}\n\nvoid serializeStringsAttr(FlatbufferBuilder& builder, Attribute attribute) {\n  oneflow_lite_StringsDef_start_as_root(builder);\n  SmallVector<oneflow_lite_StringDef_ref_t, 4> stringDefs;\n  for (auto v : attribute.dyn_cast<ArrayAttr>().getValue()) {\n    oneflow_lite_StringDef_start(builder);\n    oneflow_lite_StringDef_value_add(builder,\n                                     builder.createString(v.dyn_cast<StringAttr>().getValue()));\n    stringDefs.push_back(oneflow_lite_StringDef_end(builder));\n  }\n  oneflow_lite_StringsDef_value_add(builder, builder.createOffsetVecDestructive(stringDefs));\n  oneflow_lite_StringsDef_end_as_root(builder);\n}\n\n} 
 // namespace lite\n\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/lib/OneFlow/Transform/FoldVariable.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/Transform/FoldVariable.h\"\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\nstruct FoldVariablePass : public PassWrapper<FoldVariablePass, OperationPass<>> {\n  void runOnOperation() override {\n    // TODO\n  }\n};\n\nstd::unique_ptr<mlir::Pass> createLiteFoldVariablePass() {\n  return std::unique_ptr<mlir::Pass>(new FoldVariablePass);\n}\n\n}  // namespace lite\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/lib/OneFlow/Transform/InferPlacement.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/Transform/InferPlacement.h\"\n\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/OneFlowOpTraits.h\"\n\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/MLIRContext.h\"\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\nstatic bool CanScheduleOnTarget(Operation* op, StringRef target) {\n  if (!op->hasTrait<OpTrait::IsOpConfCompatible>()) { return false; }\n  if (llvm::dyn_cast<oneflow::InputOp>(op) || llvm::dyn_cast<oneflow::OutputOp>(op)) {\n    return false;\n  }\n  // TODO()\n  return true;\n}\n\nstruct InferPlacementPass : public PassWrapper<InferPlacementPass, OperationPass<ModuleOp>> {\n  StringRef target_;\n  explicit InferPlacementPass(StringRef target) : target_(target) {}\n\n  void runOnOperation() override;\n};\n\nvoid InferPlacementPass::runOnOperation() {\n  getOperation().walk([&](Operation* op) {\n    if (!op->hasTrait<OpTrait::IsOpConfCompatible>()) { return; }\n    auto target = [&]() -> StringRef {\n      if (CanScheduleOnTarget(op, target_)) { return target_; }\n      return StringRef(\"host\");\n    }();\n\n    OpBuilder builder(&getContext());\n    op->setAttr(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr(),\n                builder.getStringAttr(target));\n  });\n}\n\nstd::unique_ptr<mlir::Pass> createLiteInferPlacementPass(StringRef target) {\n  return std::unique_ptr<mlir::Pass>(new InferPlacementPass(target));\n}\n\n}  // namespace lite\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/lib/OneFlow/Transform/InsertTransferOp.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/Transform/InsertTransferOp.h\"\n\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/OneFlowOpTraits.h\"\n\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/MLIRContext.h\"\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\nstruct InsertTransferOpPass : public PassWrapper<InsertTransferOpPass, OperationPass<ModuleOp>> {\n  void runOnOperation() override;\n\n  StringAttr InferTargetDevice(StringAttr from, StringAttr to) const;\n};\n\nStringAttr InsertTransferOpPass::InferTargetDevice(StringAttr from, StringAttr to) const {\n  auto IsHostDevice = [](StringAttr device) {\n    return device == \"host\" || device == \"cpu\" || device == \"x86\" || device == \"arm\";\n  };\n  return IsHostDevice(from) ? to : from;\n}\n\nvoid InsertTransferOpPass::runOnOperation() {\n  auto opNameAttrkey = OpTrait::IsOpConfCompatible<void>::getOpNameAttr();\n  auto deviceTagAttrKey = OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr();\n  auto deviceNameAttrKey = OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr();\n\n  OpBuilder builder(&getContext());\n\n  getOperation().walk([&](Operation* op) {\n    if (!op->hasTrait<OpTrait::IsOpConfCompatible>()) { return; }\n    auto device = op->getAttrOfType<StringAttr>(deviceTagAttrKey);\n\n    for (Value result : op->getResults()) {\n      llvm::DenseMap<StringAttr, SmallVector<OpOperand*, 4>> operandsToReplace;\n      for (auto& use : result.getUses()) {\n        if (!use.getOwner()->hasTrait<OpTrait::IsOpConfCompatible>()) { continue; }\n        auto use_device = use.getOwner()->getAttrOfType<StringAttr>(deviceTagAttrKey);\n        if (use_device != device) { operandsToReplace[use_device].push_back(&use); }\n      }\n      for (const auto& it : operandsToReplace) {\n        NamedAttrList attrs;\n        attrs.set(opNameAttrkey, builder.getStringAttr(\"copy\"));\n        attrs.set(deviceTagAttrKey, InferTargetDevice(device, it.first));\n        attrs.set(deviceNameAttrKey,\n                  builder.getArrayAttr(llvm::to_vector<8>(llvm::map_range(\n                      ArrayRef<StringRef>({\"@0:0\"}),\n                      [&](StringRef v) -> Attribute { return builder.getStringAttr(v); }))));\n        attrs.set(builder.getStringAttr(\"device_type\"), it.first);\n\n        builder.setInsertionPointAfter(op);\n        SmallVector<mlir::Value, 4> operands{result};\n        auto copy_op =\n            builder.create<oneflow::CopyOp>(op->getLoc(), op->getResultTypes(), operands, attrs);\n\n        for (OpOperand* operand : it.second) {\n          operand->getOwner()->setOperand(operand->getOperandNumber(), copy_op.getOut());\n        }\n      }\n    }\n  });\n}\n\nstd::unique_ptr<mlir::Pass> createLiteInsertTransferOpPass() {\n  return std::unique_ptr<mlir::Pass>(new InsertTransferOpPass());\n}\n\n}  // namespace lite\n}  // namespace oneflow\n}  
// namespace mlir\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/lib/OneFlow/Transform/Lowering/LoweringAscend.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/Transform/Lowering/LoweringAscend.h\"\n\n#include \"OneFlow/Transform/Lowering/LoweringAscendUtils.h\"\n#include <memory>\n#include <vector>\n\n#include \"llvm/Support/MemoryBuffer.h\"\n#include \"llvm/Support/Path.h\"\n#include \"llvm/Support/ToolOutputFile.h\"\n\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/OneFlowOpTraits.h\"\n#include \"OneFlow/OneFlowLiteUtils.h\"\n\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/MLIRContext.h\"\n#include \"mlir/Support/FileUtilities.h\"\n#include \"mlir/Support/ToolUtilities.h\"\n\n// huawei ascend sdk headers\n#pragma GCC diagnostic push\n#pragma GCC diagnostic ignored \"-Wignored-qualifiers\"\n#include \"op_proto/built-in/inc/all_ops.h\"\n#pragma GCC diagnostic pop\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\nclass AscendValue {\n public:\n  AscendValue() = default;\n  AscendValue(const std::shared_ptr<ge::Operator>& op, const ge::TensorDesc& type,\n              StringRef componentName)\n      : op_(op), type_(type), componentName_(componentName), componentIndex_(-1) {}\n  AscendValue(const std::shared_ptr<ge::Operator>& op, const ge::TensorDesc& type,\n              StringRef componentName, int componentIndex)\n      : op_(op), type_(type), componentName_(componentName), componentIndex_(componentIndex) {}\n\n  AscendValue(const AscendValue&) = default;\n\n  const std::shared_ptr<ge::Operator>& getOperation() const { return op_; }\n\n  const ge::TensorDesc& getType() const { return type_; }\n\n  StringRef getComponentName() const { return componentName_; }\n  int getComponentIndex() const { return componentIndex_; }\n\n  StringRef getComponentNameAndIndex() const {\n    if (componentIndex_ < 0) { return componentName_; }\n    auto name = componentName_ + llvm::Twine(componentIndex_);\n    return StringRef(name.str());\n  }\n\n  void setOperation(const std::shared_ptr<ge::Operator>& op) { op_ = op; }\n  void setType(const ge::TensorDesc& type) { type_ = type; }\n  void setComponentName(StringRef componentName) { componentName_ = componentName; }\n  void setComponentIndex(int componentIndex) { componentIndex_ = componentIndex; }\n\n private:\n  std::shared_ptr<ge::Operator> op_;\n  ge::TensorDesc type_;\n  StringRef componentName_;\n  int componentIndex_;\n};\n\nclass AscendCompiler {\n public:\n  AscendCompiler() = default;\n\n  void addInputs(llvm::SmallVector<Value, 4>& operands);\n\n  void lowerOp(VariableOp op, StringRef checkpointDir);\n  void lowerOp(Conv2DOp op);\n  void lowerOp(NormalizationInferenceOp op);\n  void lowerOp(ReluOp op);\n  void lowerOp(MaxPool2DOp op);\n  void lowerOp(AvgPool2DOp op);\n  void lowerOp(Add2Op op);\n  void lowerOp(AdaptiveAvgPool2DOp op);\n  void lowerOp(MatmulOp op);\n  void lowerOp(BroadcastAddOp op);\n  void lowerOp(ReshapeOp op);\n  void lowerOp(func::ReturnOp op);\n\n  void 
serializeToBuffer(llvm::SmallVector<uint8_t, 4>* data);\n\n private:\n  AscendValue getValue(Value value) const {\n    auto it = ascendVals.find(value);\n    assert(it != ascendVals.end());\n    return it->second;\n  }\n\n  template<typename T>\n  std::shared_ptr<T> createOp(llvm::Twine opName) {\n    auto op = std::make_shared<T>(opName.str());\n    ascendOps.push_back(op);\n    return op;\n  }\n\n  llvm::SmallVector<AscendValue, 4> inputs;\n  llvm::SmallVector<AscendValue, 4> results;\n  llvm::SmallVector<std::shared_ptr<ge::Operator>, 4> ascendOps;\n  llvm::DenseMap<Value, AscendValue> ascendVals;\n};\n\nvoid AscendCompiler::serializeToBuffer(llvm::SmallVector<uint8_t, 4>* data) {\n  std::vector<ge::Operator> ins;\n  std::vector<std::pair<ge::Operator, ge::AscendString>> outs;\n  for (auto in : inputs) { ins.push_back(*(in.getOperation())); }\n  for (auto out : results) {\n    outs.push_back(std::make_pair(*(out.getOperation()), out.getComponentNameAndIndex().data()));\n  }\n  ge::Graph graph(\"ascend-graph\");\n  graph.SetInputs(ins).SetOutputs(outs);\n\n  if (!graph.IsValid()) {\n    llvm::errs() << \"ascend graph is invalid\\n\";\n    exit(1);\n  }\n  const char* outputFilename = \".__TMP__ascend_graph\";\n  graph.SaveToFile(outputFilename);\n\n  std::string errorMessage;\n  auto f = mlir::openInputFile(outputFilename, &errorMessage);\n  if (!f) {\n    llvm::errs() << errorMessage << \"\\n\";\n    exit(1);\n  }\n  data->resize(f->getBufferSize());\n  memcpy(data->data(), f->getBufferStart(), data->size());\n\n  // clean temp file\n  if (0 != remove(outputFilename)) {\n    llvm::errs() << \"failed to clean temp file\\n\";\n    exit(1);\n  }\n}\n\nvoid AscendCompiler::addInputs(llvm::SmallVector<Value, 4>& operands) {\n  for (auto operand : llvm::enumerate(operands)) {\n    llvm::Twine opName = \"input_\" + llvm::Twine(operand.index());\n    auto inputOp = createOp<ge::op::Data>(opName.str());\n    auto ascendType = convertAscendType(operand.value().getType());\n    inputOp->update_input_desc_x(ascendType);\n    inputOp->update_output_desc_y(ascendType);\n    inputs.push_back(AscendValue(inputOp, ascendType, \"y\"));\n    ascendVals[operand.value()] = inputs.back();\n  }\n}\n\nvoid AscendCompiler::lowerOp(VariableOp op, StringRef checkpointDir) {\n  auto ascendType = convertAscendType(op.data_typeAttr(), op.shapeAttr());\n  llvm::SmallString<128> inputFilename;\n  llvm::sys::path::native(checkpointDir + \"/\" + op.getOpName() + \"/out\", inputFilename);\n  std::string errorMessage;\n  auto input = mlir::openInputFile(inputFilename, &errorMessage);\n  if (!input) {\n    llvm::errs() << errorMessage << \"\\n\";\n    exit(1);\n  }\n  auto constantOp = createOp<ge::op::Const>(op.getOpName());\n  auto tensor = std::make_shared<ge::Tensor>();\n  tensor->SetTensorDesc(ascendType);\n  tensor->SetData(reinterpret_cast<const uint8_t*>(input->getBufferStart()),\n                  input->getBufferSize());\n  constantOp->set_attr_value(*tensor);\n  ascendVals[op.getOutput()] = AscendValue(constantOp, ascendType, \"y\");\n}\n\n#define SET_INPUT(op, name, value) \\\n  op->set_input_##name##_by_name(*(value.getOperation()), value.getComponentNameAndIndex().data())\n\nvoid AscendCompiler::lowerOp(Conv2DOp op) {\n  auto conv2DOp = createOp<ge::op::Conv2D>(op.getOpName());\n  conv2DOp->set_attr_pads(convertPaddings(op.padding_before()));\n  conv2DOp->set_attr_dilations(convertDilations(op.getDilationRate()));\n  conv2DOp->set_attr_strides(convertStrides(op.getStrides()));\n  
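// The remaining Conv2D attributes (group count and data format) are mapped onto\n  // the corresponding attributes of the Ascend Conv2D operator.\n  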
conv2DOp->set_attr_groups(op.getGroups());\n  conv2DOp->set_attr_data_format(convertDataFormat(op.data_format()).data());\n\n  SET_INPUT(conv2DOp, x, getValue(op.getIn()));\n  SET_INPUT(conv2DOp, filter, getValue(op.getWeight()));\n  if (op.getBias()) { SET_INPUT(conv2DOp, bias, getValue(op.getBias())); }\n  auto outType = convertAscendType(op.getOut().getType());\n  conv2DOp->update_output_desc_y(outType);\n\n  auto output = AscendValue(conv2DOp, outType, \"y\");\n\n  if (op._add_to_output()) {\n    auto addOp = createOp<ge::op::AddV2>(op.getOpName() + \"_add_to_output\");\n    SET_INPUT(addOp, x1, output);\n    SET_INPUT(addOp, x2, getValue(op._add_to_output()));\n    addOp->update_output_desc_y(outType);\n    output = AscendValue(addOp, outType, \"y\");\n  }\n  ascendVals[op.getOut()] = output;\n}\n\nvoid AscendCompiler::lowerOp(NormalizationInferenceOp op) {\n  auto batchNormOp = createOp<ge::op::BNInfer>(op.getOpName());\n  batchNormOp->set_attr_epsilon(op.getEpsilon().convertToFloat());\n\n  SET_INPUT(batchNormOp, x, getValue(op.getX()));\n  SET_INPUT(batchNormOp, mean, getValue(op.getMovingMean()));\n  SET_INPUT(batchNormOp, variance, getValue(op.getMovingVariance()));\n  SET_INPUT(batchNormOp, scale, getValue(op.getGamma()));\n  SET_INPUT(batchNormOp, offset, getValue(op.getBeta()));\n\n  auto outType = convertAscendType(op.getY().getType());\n  batchNormOp->update_output_desc_y(outType);\n\n  auto output = AscendValue(batchNormOp, outType, \"y\");\n  if (op._add_to_output()) {\n    auto addOp = createOp<ge::op::AddV2>(op.getOpName() + \"_add_to_output\");\n    SET_INPUT(addOp, x1, output);\n    SET_INPUT(addOp, x2, getValue(op._add_to_output()));\n    addOp->update_output_desc_y(outType);\n    output = AscendValue(addOp, outType, \"y\");\n  }\n  ascendVals[op.getY()] = output;\n}\n\nvoid AscendCompiler::lowerOp(ReluOp op) {\n  auto reluOp = createOp<ge::op::Relu>(op.getOpName());\n  SET_INPUT(reluOp, x, getValue(op.getX()));\n  auto outType = convertAscendType(op.getY().getType());\n  reluOp->update_output_desc_y(outType);\n  ascendVals[op.getY()] = AscendValue(reluOp, outType, \"y\");\n}\n\nvoid AscendCompiler::lowerOp(MaxPool2DOp op) {\n  auto maxPoolOp = createOp<ge::op::MaxPoolV3>(op.getOpName());\n  maxPoolOp->set_attr_ksize(convertKernelSize(op.getKernelSize()));\n  maxPoolOp->set_attr_pads(convertPaddings(op.getPadding()));\n  maxPoolOp->set_attr_strides(convertStrides(op.getStride()));\n  maxPoolOp->set_attr_ceil_mode(op.ceil_mode());\n  maxPoolOp->set_attr_padding_mode(\"CALCULATED\");\n  maxPoolOp->set_attr_global_pooling(false);\n\n  SET_INPUT(maxPoolOp, x, getValue(op.getX()));\n  auto outType = convertAscendType(op.getY().getType());\n  maxPoolOp->update_output_desc_y(outType);\n  ascendVals[op.getY()] = AscendValue(maxPoolOp, outType, \"y\");\n}\n\nvoid AscendCompiler::lowerOp(AvgPool2DOp op) {\n  auto avgPoolOp = createOp<ge::op::AvgPoolV2>(op.getOpName());\n  avgPoolOp->set_attr_ksize(convertKernelSize(op.getKernelSize()));\n  avgPoolOp->set_attr_pads(convertPaddings(op.getPadding()));\n  avgPoolOp->set_attr_strides(convertStrides(op.getStride()));\n  avgPoolOp->set_attr_ceil_mode(op.ceil_mode());\n  avgPoolOp->set_attr_padding_mode(\"CALCULATED\");\n  avgPoolOp->set_attr_global_pooling(false);\n  avgPoolOp->set_attr_exclusive(!op.count_include_pad());\n\n  SET_INPUT(avgPoolOp, x, getValue(op.getX()));\n  auto outType = convertAscendType(op.getY().getType());\n  avgPoolOp->update_output_desc_y(outType);\n  ascendVals[op.getY()] = AscendValue(avgPoolOp, outType, 
\"y\");\n}\n\nvoid AscendCompiler::lowerOp(Add2Op op) {\n  auto addOp = createOp<ge::op::AddV2>(op.getOpName());\n  SET_INPUT(addOp, x1, getValue(op.getIn0()));\n  SET_INPUT(addOp, x2, getValue(op.getIn1()));\n  auto outType = convertAscendType(op.getOut().getType());\n  addOp->update_output_desc_y(outType);\n  ascendVals[op.getOut()] = AscendValue(addOp, outType, \"y\");\n}\n\nvoid AscendCompiler::lowerOp(AdaptiveAvgPool2DOp op) {\n  auto adaptiveAvgPoolOp = createOp<ge::op::AdaptiveAvgPool2d>(op.getOpName());\n  ArrayAttr output_size = op.output_size();\n  assert(output_size.size() == 2);\n  int64_t s0 = output_size[0].dyn_cast<IntegerAttr>().getSInt();\n  int64_t s1 = output_size[1].dyn_cast<IntegerAttr>().getSInt();\n  adaptiveAvgPoolOp->set_attr_output_size(ge::Operator::OpListInt({s0, s1}));\n  SET_INPUT(adaptiveAvgPoolOp, x, getValue(op.getX()));\n  auto outType = convertAscendType(op.getY().getType());\n  adaptiveAvgPoolOp->update_output_desc_y(outType);\n  ascendVals[op.getY()] = AscendValue(adaptiveAvgPoolOp, outType, \"y\");\n}\n\nvoid AscendCompiler::lowerOp(MatmulOp op) {\n  auto matmulOp = createOp<ge::op::MatMulV2>(op.getOpName());\n  matmulOp->set_attr_transpose_x1(op.getTransposeA());\n  matmulOp->set_attr_transpose_x2(op.getTransposeB());\n\n  SET_INPUT(matmulOp, x1, getValue(op.getA()));\n  SET_INPUT(matmulOp, x2, getValue(op.getB()));\n  auto outType = convertAscendType(op.getOut().getType());\n  matmulOp->update_output_desc_y(outType);\n\n  auto output = AscendValue(matmulOp, outType, \"y\");\n  if (op._add_to_output()) {\n    auto addOp = createOp<ge::op::AddV2>(op.getOpName() + \"_add_to_output\");\n    SET_INPUT(addOp, x1, output);\n    SET_INPUT(addOp, x2, getValue(op._add_to_output()));\n    addOp->update_output_desc_y(outType);\n    output = AscendValue(addOp, outType, \"y\");\n  }\n  ascendVals[op.getOut()] = output;\n}\n\nvoid AscendCompiler::lowerOp(BroadcastAddOp op) {\n  auto addOp = createOp<ge::op::AddV2>(op.getOpName());\n  SET_INPUT(addOp, x1, getValue(op.getX()));\n  SET_INPUT(addOp, x2, getValue(op.getY()));\n  auto outType = convertAscendType(op.getZ().getType());\n  addOp->update_output_desc_y(outType);\n  ascendVals[op.getZ()] = AscendValue(addOp, outType, \"y\");\n}\n\nvoid AscendCompiler::lowerOp(ReshapeOp op) {\n  llvm::SmallVector<int64_t, 4> shape;\n  for (auto v : op.getShape()) { shape.push_back(v.dyn_cast<IntegerAttr>().getSInt()); }\n  auto constantOp = createOp<ge::op::Const>(op.getOpName() + \"_shape\");\n  auto shapeType =\n      ge::TensorDesc(ge::Shape(std::vector<int64_t>{static_cast<int64_t>(shape.size())}),\n                     ge::FORMAT_NCHW, ge::DT_INT64);\n  auto tensor = std::make_shared<ge::Tensor>();\n  tensor->SetTensorDesc(shapeType);\n  tensor->SetData(reinterpret_cast<const uint8_t*>(shape.data()), shape.size() * sizeof(int64_t));\n  constantOp->set_attr_value(*tensor);\n\n  auto reshapeOp = createOp<ge::op::Reshape>(op.getOpName());\n  SET_INPUT(reshapeOp, x, getValue(op.getIn()));\n  SET_INPUT(reshapeOp, shape, (AscendValue(constantOp, shapeType, \"y\")));\n\n  auto outType = convertAscendType(op.getOut().getType());\n  reshapeOp->update_output_desc_y(outType);\n  ascendVals[op.getOut()] = AscendValue(reshapeOp, outType, \"y\");\n}\n\nvoid AscendCompiler::lowerOp(func::ReturnOp op) {\n  for (auto operand : op.getOperands()) { results.push_back(getValue(operand)); }\n}\n\n#undef SET_INPUT\n\nLogicalResult loweringAscend(OpBuilder& builder, Operation* callee, StringRef checkpointDir,\n                             
llvm::SmallVector<uint8_t, 4>* loweringData) {\n  AscendCompiler compiler;\n  llvm::SmallVector<Value, 4> inputs;\n  auto func = dyn_cast<func::FuncOp>(callee);\n  for (auto argument : func.getArguments()) { inputs.push_back(argument); }\n\n  compiler.addInputs(inputs);\n\n  func.getBody().walk([&](Operation* op) {\n    if (auto x = dyn_cast<VariableOp>(op)) {\n      compiler.lowerOp(x, checkpointDir);\n    } else if (auto x = dyn_cast<Conv2DOp>(op)) {\n      compiler.lowerOp(x);\n    } else if (auto x = dyn_cast<NormalizationInferenceOp>(op)) {\n      compiler.lowerOp(x);\n    } else if (auto x = dyn_cast<ReluOp>(op)) {\n      compiler.lowerOp(x);\n    } else if (auto x = dyn_cast<MaxPool2DOp>(op)) {\n      compiler.lowerOp(x);\n    } else if (auto x = dyn_cast<AvgPool2DOp>(op)) {\n      compiler.lowerOp(x);\n    } else if (auto x = dyn_cast<Add2Op>(op)) {\n      compiler.lowerOp(x);\n    } else if (auto x = dyn_cast<AdaptiveAvgPool2DOp>(op)) {\n      compiler.lowerOp(x);\n    } else if (auto x = dyn_cast<MatmulOp>(op)) {\n      compiler.lowerOp(x);\n    } else if (auto x = dyn_cast<BroadcastAddOp>(op)) {\n      compiler.lowerOp(x);\n    } else if (auto x = dyn_cast<ReshapeOp>(op)) {\n      compiler.lowerOp(x);\n    } else if (auto x = dyn_cast<func::ReturnOp>(op)) {\n      compiler.lowerOp(x);\n    } else {\n      llvm::errs() << \"could not lower \" << op->getName() << \" for backend ascend\\n\";\n      exit(1);\n    }\n  });\n  compiler.serializeToBuffer(loweringData);\n  return success();\n}\n\n}  // namespace lite\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/lib/OneFlow/Transform/LoweringLaunchJob.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/Transform/LoweringLaunchJob.h\"\n\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/OneFlowOpTraits.h\"\n#include \"OneFlow/OneFlowLiteUtils.h\"\n\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/MLIRContext.h\"\n\n#ifdef LITE_USE_ASCEND_NPU\n#include \"OneFlow/Transform/Lowering/LoweringAscend.h\"\n#endif  // LITE_USE_ASCEND_NPU\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\nstruct LoweringLaunchJobPass : public PassWrapper<LoweringLaunchJobPass, OperationPass<ModuleOp>> {\n  StringRef checkpointDir;\n\n  explicit LoweringLaunchJobPass(StringRef checkpointDir) : checkpointDir(checkpointDir) {}\n\n  void runOnOperation() override;\n\n  LogicalResult loweringLaunchJob(OpBuilder& builder, Operation* callee, StringRef backend,\n                                  llvm::SmallVector<uint8_t, 4>* loweringData);\n};\n\nLogicalResult LoweringLaunchJobPass::loweringLaunchJob(\n    OpBuilder& builder, Operation* callee, StringRef backend,\n    llvm::SmallVector<uint8_t, 4>* loweringData) {\n  if (backend == \"ascend\") {\n#ifdef LITE_USE_ASCEND_NPU\n    return loweringAscend(builder, callee, checkpointDir, loweringData);\n#else\n    llvm::errs() << \"please recompile with LITE_USE_ASCEND_NPU=ON\\n\";\n    return failure();\n#endif  // LITE_USE_ASCEND_NPU\n  } else {\n    llvm::errs() << \"lowering for backend \" << backend << \" is not supported yet\\n\";\n    return failure();\n  }\n  return success();\n}\n\nvoid LoweringLaunchJobPass::runOnOperation() {\n  SmallVector<Operation*, 4> launchOps;\n  Operation* entryJobOp = getEntryJobOp(getOperation());\n  entryJobOp->walk([&](Operation* op) {\n    if (dyn_cast<oneflow::MlirJitOp>(op)) { launchOps.push_back(op); }\n  });\n\n  SymbolTable symbolTable(getOperation());\n  OpBuilder builder(&getContext());\n\n  // TODO(): register backend converters\n  for (Operation* op : launchOps) {\n    auto launchOp = dyn_cast<oneflow::MlirJitOp>(op);\n    Operation* callee = symbolTable.lookup(launchOp.getCallee());\n    if (!callee) {\n      llvm::errs() << \"can not find a callee named \" << launchOp.getCallee() << \"\\n\";\n      return signalPassFailure();\n    }\n    llvm::SmallVector<uint8_t, 4> loweringData;\n    if (failed(loweringLaunchJob(builder, callee, launchOp.getDeviceTag(), &loweringData))) {\n      llvm::errs() << \"failed to lowerring job \" << launchOp.getCallee() << \"\\n\";\n    }\n    op->setAttr(\"mlir_assembly\",\n                builder.getStringAttr(StringRef(reinterpret_cast<const char*>(loweringData.data()),\n                                                loweringData.size())));\n  }\n}\n\nstd::unique_ptr<mlir::Pass> createLiteLoweringLaunchJobPass(StringRef checkpointDir) {\n  return std::unique_ptr<mlir::Pass>(new LoweringLaunchJobPass(checkpointDir));\n}\n\n}  // namespace lite\n}  // namespace oneflow\n}  // 
namespace mlir\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/lib/OneFlow/Transform/MemoryPlanning.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/Transform/MemoryPlanning.h\"\n\n#include <assert.h>\n\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/OneFlowOpTraits.h\"\n#include \"OneFlow/OneFlowLiteUtils.h\"\n\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/MLIRContext.h\"\n#include \"llvm/ADT/SetVector.h\"\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\nint LiteBufferStrategy::getValueSegmentId(Value value) const {\n  auto it = valueSegmentInfos.find(value);\n  if (it == valueSegmentInfos.end()) { return -1; }\n  return it->second.segmentId;\n}\n\nsize_t LiteBufferStrategy::getValueSegmentOffset(Value value) const {\n  auto it = valueSegmentInfos.find(value);\n  if (it == valueSegmentInfos.end()) { return -1; }\n  return it->second.segmentOffset;\n}\n\nLogicalResult LiteBufferStrategy::insertValue(Value value, int segmentId, size_t segmentOffset) {\n  if (segments.size() < segmentId) {\n    llvm::errs() << \"segmentId is out of boundary.\\n\";\n    return failure();\n  }\n  valueSegmentInfos[value] = ValueSegmentInfo{segmentId, segmentOffset};\n  return success();\n}\n\nclass ValueLiveness {\n public:\n  ValueLiveness() = default;\n\n  void addValue(Value value, size_t liveStart, size_t liveEnd) {\n    liveness[value] = LiveRange{liveStart, liveEnd};\n  }\n\n  bool isLivenessOverlap(Value lhs, Value rhs) {\n    LiveRange lhs_liveness = liveness[lhs];\n    LiveRange rhs_liveness = liveness[rhs];\n    return lhs_liveness.liveEnd < rhs_liveness.liveStart\n           || lhs_liveness.liveStart > rhs_liveness.liveEnd;\n  }\n\n private:\n  struct LiveRange {\n    size_t liveStart;\n    size_t liveEnd;\n  };\n  llvm::DenseMap<Value, LiveRange> liveness;\n};\n\nstruct MemoryPlanningPass : public PassWrapper<MemoryPlanningPass, OperationPass<ModuleOp>> {\n  Operation* entryJobOp;\n  ValueLiveness valueLiveness;\n  llvm::SmallVector<Value, 4> sortedValues;\n  LiteBufferStrategy* bufferStrategy;\n\n  explicit MemoryPlanningPass(LiteBufferStrategy* strategy) : bufferStrategy(strategy) {}\n\n  void runOnOperation() override {\n    entryJobOp = getEntryJobOp(getOperation());\n    if (!entryJobOp) {\n      llvm::errs() << \"Job not found in module: \" << *getOperation();\n      exit(1);\n    }\n    computeValueLiveness();\n    computeValueSizeAndSort();\n    doMemoryPlanning();\n  }\n\n  void computeValueLiveness();\n  void computeValueSizeAndSort();\n  void doMemoryPlanning();\n  bool canShareMemoryWithBlock(Value value, llvm::SmallVector<Value, 4> block);\n};\n\nvoid MemoryPlanningPass::computeValueLiveness() {\n  llvm::SmallVector<Operation*, 4> opList;\n  llvm::DenseMap<Operation*, size_t> opOrdering;\n  llvm::DenseMap<Value, size_t> liveEnds;\n\n  // Compute value liveness\n  entryJobOp->walk([&](Operation* op) {\n    if (!op->hasTrait<OpTrait::IsOpConfCompatible>() || llvm::dyn_cast<OutputOp>(op)) { return; }\n    
opOrdering[op] = opOrdering.size();\n    opList.push_back(op);\n  });\n  for (Operation* op : llvm::reverse(opList)) {\n    size_t ordering = opOrdering[op];\n    for (Value operand : op->getOperands()) {\n      if (liveEnds.find(operand) == liveEnds.end()) { liveEnds[operand] = ordering; }\n    }\n    for (Value result : op->getResults()) {\n      size_t liveEnd = opOrdering.size();\n      const auto& it = liveEnds.find(result);\n      if (it != liveEnds.end()) { liveEnd = it->second; }\n      valueLiveness.addValue(result, ordering, liveEnd);\n    }\n  }\n}\n\nstatic bool isDynamicTensorType(TensorType value) {\n  for (auto dim : value.getShape()) {\n    if (dim == -1) { return true; }\n  }\n  return false;\n}\n\n/// Returns the total bit size of a static tensor type, or 0 if any dimension is dynamic.\nstatic size_t getTensorBitSize(TensorType value) {\n  auto type = value.getElementType();\n  assert(type.isIntOrFloat());\n  if (isDynamicTensorType(value)) { return 0; }\n  int64_t num = 1;\n  for (auto dim : value.getShape()) { num *= dim; }\n  return num * type.getIntOrFloatBitWidth();\n}\n\nvoid MemoryPlanningPass::computeValueSizeAndSort() {\n  llvm::SetVector<Value, llvm::SmallVector<Value, 4>> valueList;\n  entryJobOp->walk([&](Operation* op) {\n    if (!op->hasTrait<OpTrait::IsOpConfCompatible>() || llvm::dyn_cast<InputOp>(op)\n        || llvm::dyn_cast<OutputOp>(op)) {\n      return;\n    }\n    valueList.insert(op->getOperands().begin(), op->getOperands().end());\n    valueList.insert(op->getResults().begin(), op->getResults().end());\n  });\n  sortedValues = valueList.takeVector();\n  llvm::sort(sortedValues.begin(), sortedValues.end(), [](Value lhs, Value rhs) {\n    assert(lhs.getType().isa<TensorType>());\n    assert(rhs.getType().isa<TensorType>());\n    return getTensorBitSize(lhs.getType().cast<TensorType>())\n           > getTensorBitSize(rhs.getType().cast<TensorType>());\n  });\n}\n\nbool MemoryPlanningPass::canShareMemoryWithBlock(Value value,\n                                                 const llvm::SmallVector<Value, 4>& block) {\n  if (isDynamicTensorType(value.getType().cast<TensorType>())) { return false; }\n  auto device = getValueDevice(value);\n  for (auto v : block) {\n    if (device != getValueDevice(v)) { return false; }\n    if (valueLiveness.isLivenessOverlap(value, v)) { return false; }\n  }\n  return true;\n}\n\nvoid MemoryPlanningPass::doMemoryPlanning() {\n  if (sortedValues.empty()) { return; }\n  llvm::SmallVector<llvm::SmallVector<Value, 4>, 4> memoryBlocks;\n  for (auto value : sortedValues) {\n    bool shared = false;\n    for (auto& block : memoryBlocks) {\n      if (canShareMemoryWithBlock(value, block)) {\n        block.push_back(value);\n        shared = true;\n        break;\n      }\n    }\n    if (!shared) { memoryBlocks.push_back(llvm::SmallVector<Value, 4>{value}); }\n  }\n\n  llvm::SmallVector<LiteBufferSegment, 4>& segments = bufferStrategy->getSegments();\n  for (auto& block : memoryBlocks) {\n    auto device = getValueDevice(block.front());\n    int segmentId = segments.size();\n    size_t blockSize = 0;\n    size_t alignment = 512;\n    for (auto value : block) {\n      size_t valueSize = getTensorBitSize(value.getType().cast<TensorType>());\n      if (valueSize > blockSize) { blockSize = valueSize; }\n    }\n    blockSize = (blockSize + 7) / 8;                                  // convert to bytes\n    blockSize = (blockSize + alignment - 1) / alignment * alignment;  // align to 512 bytes\n    segments.push_back(LiteBufferSegment{device.getValue(), blockSize, alignment});\n\n    for (auto value : block) {\n      auto result = 
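\n          // Values in a block have pairwise-disjoint live ranges, so each is placed at offset 0.\n          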
bufferStrategy->insertValue(value, segmentId, /*segmentOffset*/ 0);\n      assert(succeeded(result) && \"failed to insert value to buffer strategy\");\n    }\n  }\n}\n\nstd::unique_ptr<mlir::Pass> createLiteMemoryPlanningPass(LiteBufferStrategy* strategy) {\n  return std::unique_ptr<mlir::Pass>(new MemoryPlanningPass(strategy));\n}\n\n}  // namespace lite\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/lib/OneFlow/Transform/PartitionLaunchJob.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/Transform/PartitionLaunchJob.h\"\n\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/OneFlowOpTraits.h\"\n\n#include \"mlir/IR/IRMapping.h\"\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/MLIRContext.h\"\n\nnamespace mlir {\nnamespace oneflow {\nnamespace lite {\n\nstruct PartitionLaunchJobPass\n    : public PassWrapper<PartitionLaunchJobPass, OperationPass<ModuleOp>> {\n  void runOnOperation() override;\n\n  bool needPartition(StringRef device) const { return device == \"tensorrt\" || device == \"ascend\"; }\n\n  func::FuncOp addCallableFunc(OpBuilder& builder, StringRef callee_name,\n                               const llvm::SmallVector<Value, 4>& operands,\n                               const llvm::SmallVector<Value, 4>& results,\n                               const llvm::SmallVector<Operation*, 4>& block);\n};\n\nfunc::FuncOp PartitionLaunchJobPass::addCallableFunc(\n    OpBuilder& builder, StringRef callee_name, const llvm::SmallVector<Value, 4>& operands,\n    const llvm::SmallVector<Value, 4>& results, const llvm::SmallVector<Operation*, 4>& block) {\n  llvm::SmallVector<Type, 4> operand_types, result_types;\n  for (auto operand : operands) { operand_types.push_back(operand.getType()); }\n  for (auto result : results) { result_types.push_back(result.getType()); }\n\n  auto parentFuncOp = block[0]->getParentOfType<oneflow::Job>();\n  auto parentModuleOp = parentFuncOp->getParentOfType<ModuleOp>();\n\n  Block::iterator insertPt(parentFuncOp->getNextNode());\n  builder.setInsertionPointToStart(parentModuleOp.getBody());\n\n  auto funcType = builder.getFunctionType(operand_types, result_types);\n  auto funcOp = builder.create<func::FuncOp>(block[0]->getLoc(), callee_name, funcType);\n  auto* entryBlock = funcOp.addEntryBlock();\n\n  IRMapping mapping;\n  for (auto operand : llvm::enumerate(operands)) {\n    mapping.map(operand.value(), entryBlock->getArgument(operand.index()));\n  }\n\n  builder.setInsertionPointToStart(entryBlock);\n  for (Operation* op : block) {\n    builder.insert(op->clone(mapping));\n    for (auto result : llvm::enumerate(op->getResults())) {\n      mapping.map(result.value(), entryBlock->back().getResult(result.index()));\n    }\n  }\n  llvm::SmallVector<Value, 4> mappingResults;\n  for (auto result : results) { mappingResults.push_back(mapping.lookup(result)); }\n  builder.create<func::ReturnOp>(block[0]->getLoc(), mappingResults);\n  return funcOp;\n}\n\nvoid PartitionLaunchJobPass::runOnOperation() {\n  // TODO(): refactor\n  llvm::DenseMap<StringRef, llvm::SetVector<Operation*, llvm::SmallVector<Operation*, 4>>>\n      partitionOps;\n  getOperation().walk([&](Operation* op) {\n    if (!op->hasTrait<OpTrait::IsOpConfCompatible>()) { return; }\n    if (dyn_cast<CopyOp>(op)) { return; }\n    auto device =\n        
op->getAttrOfType<StringAttr>(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr());\n    if (!needPartition(device.getValue())) { return; }\n\n    partitionOps[device.getValue()].insert(op);\n  });\n\n  for (auto it : partitionOps) {\n    if (it.second.empty()) { continue; }\n\n    llvm::DenseMap<Value, int> inputVals, resultVals;\n    for (Operation* op : it.second) {\n      for (Value operand : op->getOperands()) {\n        if (!it.second.count(operand.getDefiningOp())) {\n          inputVals.try_emplace(operand, inputVals.size());\n        }\n      }\n      for (Value result : op->getResults()) {\n        for (auto& use : result.getUses()) {\n          if (!it.second.count(use.getOwner())) {\n            resultVals.try_emplace(result, resultVals.size());\n            break;\n          }\n        }\n      }\n    }\n    auto block = it.second.takeVector();\n    // TODO(): check job is acyclic or not\n    llvm::SmallVector<Value, 4> operands(inputVals.size());\n    llvm::SmallVector<Value, 4> results(resultVals.size());\n    for (auto in : inputVals) { operands[in.second] = in.first; }\n    for (auto out : resultVals) { results[out.second] = out.first; }\n\n    OpBuilder builder(&getContext());\n    auto callableFunc =\n        addCallableFunc(builder, it.first.str() + \".launch\", operands, results, block);\n\n    Operation* firstOp = block[0];\n    NamedAttrList attributes;\n    attributes.set(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr(),\n                   OpTrait::IsOpConfCompatible<void>::getDeviceTag(firstOp));\n    attributes.set(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr(),\n                   OpTrait::IsOpConfCompatible<void>::getDeviceName(firstOp));\n    if (auto hierarchy = OpTrait::IsOpConfCompatible<void>::getHierarchy(firstOp)) {\n      attributes.set(OpTrait::IsOpConfCompatible<void>::getHierarchyAttr(), hierarchy);\n    }\n    attributes.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(),\n                   builder.getStringAttr(it.first.str() + \".launch\"));\n    if (auto scope_symbol_id = OpTrait::IsOpConfCompatible<void>::getScopeSymbolID((firstOp))) {\n      attributes.set(OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr(), scope_symbol_id);\n    }\n    builder.setInsertionPointAfter(firstOp);\n\n    auto launchOp =\n        builder.create<MlirJitOp>(firstOp->getLoc(), callableFunc, attributes, operands);\n    launchOp->setAttr(\"mlir_assembly\", builder.getStringAttr(\"\"));\n\n    for (auto result : llvm::enumerate(results)) {\n      result.value().replaceAllUsesWith(launchOp->getOperand(result.index()));\n    }\n    for (Operation* op : block) {\n      op->dropAllUses();\n      op->erase();\n    }\n  }\n}\n\nstd::unique_ptr<mlir::Pass> createLitePartitionLaunchJobPass() {\n  return std::unique_ptr<mlir::Pass>(new PartitionLaunchJobPass());\n}\n\n}  // namespace lite\n}  // namespace oneflow\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/lib/OneFlow/cmake/FindAscendSdk.cmake",
    "content": "find_path(ASCEND_INCLUDE_DIR graph/graph.h\n          PATHS ${ASCEND_HOME_PATH} ${ASCEND_HOME_PATH}/include $ENV{ASCEND_HOME_PATH}\n                $ENV{ASCEND_HOME_PATH}/include)\n\nfind_library(\n  ASCEND_GRAPH_LIBRARY NAMES graph PATHS ${ASCEND_HOME_PATH} ${ASCEND_HOME_PATH}/lib64\n                                         $ENV{ASCEND_HOME_PATH} $ENV{ASCEND_HOME_PATH}/lib64)\n\nif(NOT ASCEND_INCLUDE_DIR OR NOT ASCEND_GRAPH_LIBRARY)\n  message(\n    FATAL_ERROR \"Ascend Sdk was not found. You can set ASCEND_HOME_PATH to specify the search path.\"\n  )\nendif()\n\nadd_library(ascend_graph SHARED IMPORTED GLOBAL)\nset_property(TARGET ascend_graph PROPERTY IMPORTED_LOCATION ${ASCEND_GRAPH_LIBRARY})\n\nset(ASCEND_LIBRARIES ascend_graph)\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/schemas/CMakeLists.txt",
    "content": "include(install_flatcc.cmake)\n\nadd_subdirectory(attributes)\n\nfile(GLOB LITE_SCHEMA_FILES *.fbs)\nflatcc_generate(SCHEMA_SRCS ${LITE_SCHEMA_FILES})\n\nadd_custom_target(lite_schema_gen DEPENDS ${SCHEMA_SRCS} flatcc-runtime)\nadd_library(lite_schemas INTERFACE)\nadd_dependencies(lite_schemas lite_schema_gen lite_attribute_schema_gen)\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/schemas/attributes/CMakeLists.txt",
    "content": "file(GLOB LITE_ATTRIBUTE_SCHEMA_FILES *.fbs)\nflatcc_generate(ATTRIBUTE_SCHEMA_SRCS ${LITE_ATTRIBUTE_SCHEMA_FILES})\n\nadd_custom_target(lite_attribute_schema_gen DEPENDS ${ATTRIBUTE_SCHEMA_SRCS} flatcc-runtime)\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/schemas/attributes/bool.fbs",
    "content": "// Copyright 2020 The OneFlow Authors. All rights reserved.\n// \n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n// \n//     http://www.apache.org/licenses/LICENSE-2.0\n// \n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\nnamespace oneflow_lite;\n\ntable BoolDef {\n  value:bool;\n}\n\nroot_type BoolDef;\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/schemas/attributes/f32.fbs",
    "content": "// Copyright 2020 The OneFlow Authors. All rights reserved.\n// \n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n// \n//     http://www.apache.org/licenses/LICENSE-2.0\n// \n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\nnamespace oneflow_lite;\n\ntable F32Def {\n  value:float;\n}\n\nroot_type F32Def;\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/schemas/attributes/f32s.fbs",
    "content": "// Copyright 2020 The OneFlow Authors. All rights reserved.\n// \n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n// \n//     http://www.apache.org/licenses/LICENSE-2.0\n// \n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\nnamespace oneflow_lite;\n\ntable F32sDef {\n  value:[float];\n}\n\nroot_type F32sDef;\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/schemas/attributes/f64.fbs",
    "content": "// Copyright 2020 The OneFlow Authors. All rights reserved.\n// \n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n// \n//     http://www.apache.org/licenses/LICENSE-2.0\n// \n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\nnamespace oneflow_lite;\n\ntable F64Def {\n  value:double;\n}\n\nroot_type F64Def;\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/schemas/attributes/i32.fbs",
    "content": "// Copyright 2020 The OneFlow Authors. All rights reserved.\n// \n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n// \n//     http://www.apache.org/licenses/LICENSE-2.0\n// \n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\nnamespace oneflow_lite;\n\ntable I32Def {\n  value:int;\n}\n\nroot_type I32Def;\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/schemas/attributes/i32s.fbs",
    "content": "// Copyright 2020 The OneFlow Authors. All rights reserved.\n// \n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n// \n//     http://www.apache.org/licenses/LICENSE-2.0\n// \n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\nnamespace oneflow_lite;\n\ntable I32sDef {\n  value:[int];\n}\n\nroot_type I32sDef;\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/schemas/attributes/i64.fbs",
    "content": "// Copyright 2020 The OneFlow Authors. All rights reserved.\n// \n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n// \n//     http://www.apache.org/licenses/LICENSE-2.0\n// \n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\nnamespace oneflow_lite;\n\ntable I64Def {\n  value:long;\n}\n\nroot_type I64Def;\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/schemas/attributes/i64s.fbs",
    "content": "// Copyright 2020 The OneFlow Authors. All rights reserved.\n// \n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n// \n//     http://www.apache.org/licenses/LICENSE-2.0\n// \n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\nnamespace oneflow_lite;\n\ntable I64sDef {\n  value:[long];\n}\n\nroot_type I64sDef;\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/schemas/attributes/shape.fbs",
    "content": "// Copyright 2020 The OneFlow Authors. All rights reserved.\n// \n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n// \n//     http://www.apache.org/licenses/LICENSE-2.0\n// \n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\nnamespace oneflow_lite;\n\ntable ShapeDef {\n  value:[long];\n}\n\nroot_type ShapeDef;\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/schemas/attributes/shapes.fbs",
    "content": "// Copyright 2020 The OneFlow Authors. All rights reserved.\n// \n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n// \n//     http://www.apache.org/licenses/LICENSE-2.0\n// \n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\ninclude \"shape.fbs\";\n\nnamespace oneflow_lite;\n\ntable ShapesDef {\n  value:[ShapeDef];\n}\n\nroot_type ShapesDef;\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/schemas/attributes/str.fbs",
    "content": "// Copyright 2020 The OneFlow Authors. All rights reserved.\n// \n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n// \n//     http://www.apache.org/licenses/LICENSE-2.0\n// \n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\nnamespace oneflow_lite;\n\ntable StringDef {\n  value:string;\n}\n\nroot_type StringDef;\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/schemas/attributes/strs.fbs",
    "content": "// Copyright 2020 The OneFlow Authors. All rights reserved.\n// \n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n// \n//     http://www.apache.org/licenses/LICENSE-2.0\n// \n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\nnamespace oneflow_lite;\n\ntable StringsDef {\n  value:[string];\n}\n\nroot_type StringsDef;\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/schemas/executable.fbs",
    "content": "// Copyright 2020 The OneFlow Authors. All rights reserved.\n// \n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n// \n//     http://www.apache.org/licenses/LICENSE-2.0\n// \n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\nnamespace oneflow_lite;\n\n// Buffer segment can be regarded as the device memory block\ntable BufferSegmentDef {\n  size:long;\n  // Device the segment belongs\n  device:int;\n  alignment:int;\n}\n\ntable TensorDef {\n  // Type should be one of the primary data type\n  // i8,i16,i32,i64,u8,u16,u32,u64,f8,f16,bf16,f32,f64,bool\n  type:string;\n  layout:string;\n  sizes:[long];\n  strides:[long];\n\n  // Memory planning information about this tensor\n  segment_id:int;\n  segment_offset:long;\n}\n\ntable ParameterDef {\n  // Type should be one of the primary data type\n  // i8,i16,i32,i64,u8,u16,u32,u64,f8,f16,bf16,f32,f64,bool\n  type:string;\n  sizes:[long];\n  buffer:[byte];\n}\n\ntable AttrDef {\n  // Type should be one of the primary data type\n  // i8,i16,i32,i64,u8,u16,u32,u64,f8,f16,bf16,f32,f64,bool,str,param,etc.\n  // or a list of them i8s,i16s,i32s,i64s,u8s,u16s,u32s,u64s,\n  // f8s,f16s,bf16s,f32s,f64s,bools,strs\n  type:string;\n  key:string;\n  value:[byte];\n}\n\ntable OpFunctionDef {\n  name:string;\n  // Code generated for AOT codegen. \n  body:[byte];\n  // Signature of the function call. Signature can be empty to use the\n  // default function signature\n  // \"(t0, t1, ..., tN, r0, r1, ..., rM) -> (tN+1, tN+2, ..., tN+T)\"\n  // in which t means Tensor and r means op attributes\n  signature:string;\n}\n\ntable OpDef {\n  // The operator type name, such as \"conv2d\", \"softmax\"\n  name:string;\n  // Input operand indices\n  inputs:[int];\n  // Output operand indices\n  outputs:[int];\n  // Attributes the operator has\n  attrs:[AttrDef];\n  // Device that executes the operator\n  device:int;\n}\n\ntable ExecutableDef {\n  version:int;\n  name:string;\n\n  // Devices used in this executable\n  devices:[string];\n  ops:[OpDef];\n  operands:[TensorDef];\n  inputs:[int];\n  outputs:[int];\n  input_names:[string];\n  output_names:[string];\n  segments:[BufferSegmentDef];\n\n  // Functions will be registered in the global function table and will\n  // be used firstly, even if those operators functions are available\n  // in the runtime library\n  functions:[OpFunctionDef];\n}\n\nroot_type ExecutableDef;\n"
  },
  {
    "path": "oneflow/ir/oneflow-lite/schemas/install_flatcc.cmake",
    "content": "include(ExternalProject)\n\ninclude(FetchContent)\n\nset(FLATCC_URL https://github.com/dvidelabs/flatcc/archive/refs/tags/v0.6.1.tar.gz)\nuse_mirror(VARIABLE FLATCC_URL URL ${FLATCC_URL})\nmessage(STATUS \"Download flatcc from url: ${FLATCC_URL}\")\n\n#FetchContent_Declare(flatcc URL ${FLATCC_URL})\n#FetchContent_MakeAvailable(flatcc)\nFetchContent_Populate(flatcc URL ${FLATCC_URL} SOURCE_DIR flatcc)\n\nset(FLATCC_ROOT ${CMAKE_CURRENT_BINARY_DIR}/flatcc)\nset(FLATCC_SRCS\n    \"${FLATCC_ROOT}/src/runtime/builder.c\"\n    \"${FLATCC_ROOT}/src/runtime/verifier.c\"\n    \"${FLATCC_ROOT}/src/runtime/emitter.c\"\n    \"${FLATCC_ROOT}/src/runtime/json_parser.c\"\n    \"${FLATCC_ROOT}/src/runtime/json_printer.c\"\n    \"${FLATCC_ROOT}/src/runtime/refmap.c\"\n    \"${FLATCC_ROOT}/config/config.h\")\nset(FLATCC_INCLUDE_DIR ${FLATCC_ROOT}/include)\nadd_library(flatcc-runtime STATIC ${FLATCC_SRCS})\ntarget_include_directories(flatcc-runtime SYSTEM PUBLIC ${FLATCC_INCLUDE_DIR})\n\nadd_executable(\n  flatcc-cli\n  \"${FLATCC_ROOT}/src/cli/flatcc_cli.c\"\n  \"${FLATCC_ROOT}/external/hash/cmetrohash64.c\"\n  \"${FLATCC_ROOT}/external/hash/str_set.c\"\n  \"${FLATCC_ROOT}/external/hash/ptr_set.c\"\n  \"${FLATCC_ROOT}/src/compiler/hash_tables/symbol_table.c\"\n  \"${FLATCC_ROOT}/src/compiler/hash_tables/scope_table.c\"\n  \"${FLATCC_ROOT}/src/compiler/hash_tables/name_table.c\"\n  \"${FLATCC_ROOT}/src/compiler/hash_tables/schema_table.c\"\n  \"${FLATCC_ROOT}/src/compiler/hash_tables/value_set.c\"\n  \"${FLATCC_ROOT}/src/compiler/fileio.c\"\n  \"${FLATCC_ROOT}/src/compiler/parser.c\"\n  \"${FLATCC_ROOT}/src/compiler/semantics.c\"\n  \"${FLATCC_ROOT}/src/compiler/coerce.c\"\n  \"${FLATCC_ROOT}/src/compiler/codegen_schema.c\"\n  \"${FLATCC_ROOT}/src/compiler/flatcc.c\"\n  \"${FLATCC_ROOT}/src/compiler/codegen_c.c\"\n  \"${FLATCC_ROOT}/src/compiler/codegen_c_reader.c\"\n  \"${FLATCC_ROOT}/src/compiler/codegen_c_sort.c\"\n  \"${FLATCC_ROOT}/src/compiler/codegen_c_builder.c\"\n  \"${FLATCC_ROOT}/src/compiler/codegen_c_verifier.c\"\n  \"${FLATCC_ROOT}/src/compiler/codegen_c_sorter.c\"\n  \"${FLATCC_ROOT}/src/compiler/codegen_c_json_parser.c\"\n  \"${FLATCC_ROOT}/src/compiler/codegen_c_json_printer.c\"\n  \"${FLATCC_ROOT}/src/runtime/builder.c\"\n  \"${FLATCC_ROOT}/src/runtime/emitter.c\"\n  \"${FLATCC_ROOT}/src/runtime/refmap.c\")\ntarget_include_directories(flatcc-cli PRIVATE \"${FLATCC_ROOT}/external\" \"${FLATCC_ROOT}/include\"\n                                              \"${FLATCC_ROOT}/config\")\n\n#set(FLATCC_EXE ${CMAKE_CURRENT_BINARY_DIR}/flatcc-cli PARENT_SCOPE)\nset(FLATCC_EXE ${CMAKE_CURRENT_BINARY_DIR}/flatcc-cli)\n\nfunction(FLATCC_GENERATE SRCS)\n  set(${SRCS})\n  foreach(FIL ${ARGN})\n    get_filename_component(ABS_FIL ${FIL} ABSOLUTE)\n    get_filename_component(FIL_WE ${FIL} NAME_WE)\n\n    list(APPEND ${SRCS} \"${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_generated.h\")\n    add_custom_command(\n      OUTPUT \"${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_generated.h\"\n      COMMAND ${FLATCC_EXE} ARGS --builder --verifier\n              --outfile=${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_generated.h -a ${ABS_FIL}\n      DEPENDS ${ABS_FIL} ${FLATCC_EXE}\n      COMMENT \"Running flatcc compiler on ${FIL}\"\n      VERBATIM)\n    set(${SRCS} ${${SRCS}} PARENT_SCOPE)\n  endforeach()\nendfunction()\n"
  },
  {
    "path": "oneflow/ir/oneflow-opt/CMakeLists.txt",
    "content": "get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)\nget_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)\nadd_llvm_executable(oneflow-opt oneflow-opt.cpp)\n\nset(_origin_prefix \"\\$ORIGIN\")\nif(APPLE)\n  set(_origin_prefix \"@loader_path\")\nendif()\nset_target_properties(\n  oneflow-opt PROPERTIES BUILD_WITH_INSTALL_RPATH OFF BUILD_RPATH \"${_origin_prefix}\"\n                         INSTALL_RPATH \"${_origin_prefix}\")\nllvm_update_compile_flags(oneflow-opt)\ntarget_link_libraries(\n  oneflow-opt\n  PRIVATE MLIROneFlow\n          ${dialect_libs}\n          ${conversion_libs}\n          MLIROptLib\n          $<TARGET_OBJECTS:of_op_schema>\n          MLIROneFlowExtension\n          MLIROneFlowTransformDialect)\n\nmlir_check_all_link_libraries(oneflow-opt)\n"
  },
  {
    "path": "oneflow/ir/oneflow-opt/README.md",
    "content": "# OneFlow MLIR Optimizer\n\nThis module includes a CLI optimize a `.mlir` file.\n"
  },
  {
    "path": "oneflow/ir/oneflow-opt/oneflow-opt.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/control/ctrl_bootstrap.pb.h\"\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/Passes.h\"\n#include \"OneFlow/SBP/SBPDialect.h\"\n#include \"OneFlow/OKL/OKLDialect.h\"\n#include \"OneFlow/OKM/OKMDialect.h\"\n#include \"OneFlow/OKL/passes.h\"\n#include \"OneFlow/OKM/passes.h\"\n#include \"Transform/TransformDialectExtension.h\"\n\n#include \"mlir/InitAllDialects.h\"\n#include \"mlir/InitAllPasses.h\"\n#include \"mlir/Tools/mlir-opt/MlirOptMain.h\"\n\nconst auto global_cse_state = std::make_shared<mlir::oneflow::CSEState>();\n\nint32_t main(int32_t argc, char** argv) {\n  ::oneflow::Singleton<::oneflow::ProcessCtx>::New();\n  mlir::registerAllPasses();\n  mlir::oneflow::registerCSEPasses(global_cse_state);\n  mlir::oneflow::registerPasses();\n  mlir::okm::registerPasses();\n  mlir::okl::registerPasses();\n  mlir::oneflow::transform_dialect::registerTransformDialectEraseSchedulePass();\n  mlir::oneflow::transform_dialect::registerTransformDialectInterpreterPass();\n\n  mlir::DialectRegistry registry;\n  // Note: register all mlir dialect and their extension.\n  mlir::registerAllDialects(registry);\n  mlir::oneflow::transform_dialect::registerTransformDialectExtension(registry);\n  registry.insert<mlir::okl::OKLDialect>();\n  registry.insert<mlir::okm::OKMDialect>();\n  registry.insert<mlir::sbp::SBPDialect>();\n  registry.insert<mlir::oneflow::OneFlowDialect>();\n  return failed(mlir::MlirOptMain(argc, argv, \"OneFlow optimizer driver\\n\", registry));\n}\n"
  },
  {
    "path": "oneflow/ir/oneflow-runner/CMakeLists.txt",
    "content": "set(LLVM_LINK_COMPONENTS Core Support nativecodegen native)\n\noneflow_add_llvm_tool(oneflow-runner oneflow-runner.cpp)\n\nset(_origin_prefix \"\\$ORIGIN\")\nif(APPLE)\n  set(_origin_prefix \"@loader_path\")\nendif()\nset_target_properties(\n  oneflow-runner PROPERTIES BUILD_WITH_INSTALL_RPATH OFF BUILD_RPATH \"${_origin_prefix}\"\n                            INSTALL_RPATH \"${_origin_prefix}\")\n\ntarget_link_libraries(\n  oneflow-runner\n  PRIVATE MLIRAnalysis\n          MLIRExecutionEngine\n          MLIRIR\n          MLIRJitRunner\n          MLIRLLVMIRTransforms\n          MLIRLLVMToLLVMIRTranslation\n          MLIRToLLVMIRTranslationRegistration\n          MLIRParser\n          MLIRTargetLLVMIRExport\n          MLIRSupport\n          MLIROneFlow\n          glog::glog)\n"
  },
  {
    "path": "oneflow/ir/oneflow-runner/oneflow-runner.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n//===- mlir-cpu-runner.cpp - MLIR CPU Execution Driver---------------------===//\n//\n// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n// See https://llvm.org/LICENSE.txt for license information.\n// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n//\n//===----------------------------------------------------------------------===//\n//\n// Main entry point to a command line utility that executes an MLIR file on the\n// CPU by  translating MLIR to LLVM IR before JIT-compiling and executing the\n// latter.\n//\n//===----------------------------------------------------------------------===//\n\n#include \"mlir/Dialect/LLVMIR/LLVMDialect.h\"\n#include \"mlir/ExecutionEngine/JitRunner.h\"\n#include \"mlir/ExecutionEngine/OptUtils.h\"\n#include \"mlir/IR/Dialect.h\"\n#include \"mlir/Target/LLVMIR/Dialect/All.h\"\n\n#include \"llvm/Support/InitLLVM.h\"\n#include \"llvm/Support/TargetSelect.h\"\n#include \"OneFlow/OneFlowDialect.h\"\n\nint main(int argc, char** argv) {\n  llvm::InitLLVM y(argc, argv);\n  llvm::InitializeNativeTarget();\n  llvm::InitializeNativeTargetAsmPrinter();\n  // llvm::InitializeNativeTargetAsmParser(); // link fails\n\n  mlir::DialectRegistry registry;\n  mlir::registerAllToLLVMIRTranslations(registry);\n  registry.insert<mlir::oneflow::OneFlowDialect>();\n  return mlir::JitRunnerMain(argc, argv, registry);\n}\n"
  },
  {
    "path": "oneflow/ir/oneflow-runtime/CMakeLists.txt",
    "content": "add_subdirectory(lib)\n"
  },
  {
    "path": "oneflow/ir/oneflow-runtime/lib/CMakeLists.txt",
    "content": "oneflow_add_mlir_library(MLIROneFlowRuntime Runtime.cpp)\nif(WITH_MLIR_CUDA_CODEGEN)\n  set(MLIR_RUNTIME_GPU_LIBS mlir_cuda_runtime)\nendif(WITH_MLIR_CUDA_CODEGEN)\ntarget_link_libraries(MLIROneFlowRuntime PUBLIC -Wl,--no-as-needed ${MLIR_RUNTIME_GPU_LIBS}\n                                                mlir_c_runner_utils -Wl,--as-needed)\n"
  },
  {
    "path": "oneflow/ir/oneflow-runtime/lib/Runtime.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n// This file is added to avoid cmake error\n"
  },
  {
    "path": "oneflow/ir/oneflow-translate/CMakeLists.txt",
    "content": "set(LLVM_LINK_COMPONENTS Support)\n\nget_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)\nget_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS)\n\nset(LLVM_ENABLE_RTTI ON) # turn this on to make it compatible with protobuf\n\ninclude_directories(${PROJECT_SOURCE_DIR}/oneflow-translate/include)\ninclude_directories(${PROJECT_BINARY_DIR}/oneflow-translate/include)\n\nadd_subdirectory(include)\nadd_subdirectory(lib)\n\nadd_llvm_executable(oneflow-translate oneflow-translate.cpp DEPENDS MLIROneFlow\n                    MLIROneFlowTranslation)\n\nset(_origin_prefix \"\\$ORIGIN\")\nif(APPLE)\n  set(_origin_prefix \"@loader_path\")\nendif()\nset_target_properties(\n  oneflow-translate PROPERTIES BUILD_WITH_INSTALL_RPATH OFF BUILD_RPATH \"${_origin_prefix}\"\n                               INSTALL_RPATH \"${_origin_prefix}\")\n\nllvm_update_compile_flags(oneflow-translate)\n\ntarget_link_libraries(oneflow-translate PRIVATE ${dialect_libs} ${translation_libs}\n                      PUBLIC MLIRTranslateLib MLIROneFlowTranslation)\n\nmlir_check_link_libraries(oneflow-translate)\n"
  },
  {
    "path": "oneflow/ir/oneflow-translate/README.md",
    "content": "# OneFlow Translate\n## Import OneFlow Job to MLIR and dump a new Job\n```\njob -> module\nsub graph -> function\n```\n\n### Pipeline\n- Lower case: OneFlow, upper case: MLIR\n- [something]: a step, could be rewrite or other kinds of optimizations\n    ```\n    user op ->  OPAQUE USER OP -> CONCRETE OP -> [OPTIMIZATION] -> user op\n    system op ->  OPAQUE SYSTEM OP -> system op\n    ```\n\n### About blob name\n- MLIR exporters and and exporters should take care of blob names so other components don't touch it.\n\n### About SBP signature\n- There should be a sharding op to store SBP information.\n- Reusing built-in tensor types is pratical and makes it easy to resuse pass interfaces.\n- Implementing a tensor type with SBP is actually working agaist MLIR because pass in MLIR works better with operations.\n\n### Basic principles for a legit rewrite\n\n1. Source op of control edge shouldn't be erased\n2. Erasing, creating op shouldn't introduce boxing\n3. Results' shapes should stay identical\n### Information not included in OpConf\n\n- There are information in job not included in `OpConf`:\n```protobuf\nmessage JobHelperConf {\n  map<string, LogicalBlobIdPairs> tag2lbi_relations = 1;\n  ...\n}\n\nmessage JobParallelViewConf {\n  ...\n}\n```\n\n- Create callbacks wrapping `JobBuilder` MLIR can call to update job helperconfs when it is erasing/building operations.\n"
  },
  {
    "path": "oneflow/ir/oneflow-translate/include/CMakeLists.txt",
    "content": "add_subdirectory(OneFlow)\n"
  },
  {
    "path": "oneflow/ir/oneflow-translate/include/OneFlow/CMakeLists.txt",
    "content": "\n"
  },
  {
    "path": "oneflow/ir/oneflow-translate/include/OneFlow/MLIROneFlowTranslation.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_IR_ONEFLOW_TRANSLATE_INCLUDE_ONEFLOW_MLIRONEFLOWTRANSLATION_H_\n#define ONEFLOW_IR_ONEFLOW_TRANSLATE_INCLUDE_ONEFLOW_MLIRONEFLOWTRANSLATION_H_\n\n#include \"oneflow/core/framework/user_op_def.pb.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/job/sbp_parallel.pb.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"OneFlow/SBP/SBPImporter.h\"\n\n#include \"OneFlow/OneFlowOps.h\"\n\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/MLIRContext.h\"\n\n#include <functional>\n#include <string>\n\nusing UserOpArgs = const ::google::protobuf::Map<std::string, ::oneflow::UserOpConf_ListString>&;\nusing UserOpArgDefs = const ::google::protobuf::RepeatedPtrField<::oneflow::UserOpDef_ArgDef>&;\n\nnamespace mlir {\n\nnamespace oneflow {\n\n// TODO: wrap in a helper namespace\n\nLogicalResult IsAttrBelong2Op(const std::string& op_type_name, const std::string& attr_name);\n\nLogicalResult ConvertUserOpInputs(Operation* op, StringRef op_name,\n                                  ::oneflow::UserOpConf* user_conf);\nLogicalResult ConvertUserOpOutputs(Operation* op, StringRef op_name,\n                                   ::oneflow::UserOpConf* user_conf);\nLogicalResult ConvertCtrlInputs(Operation* op, ::oneflow::OperatorConf& op_conf);\nllvm::Optional<mlir::oneflow::DataTypeAttr> GetDataTypeAttr(MLIRContext* context,\n                                                            ::oneflow::DataType oneflow_value);\nLogicalResult ConvertVariableOpConf(VariableOp op, ::oneflow::OperatorConf* op_conf);\nLogicalResult ConvertInputOpConf(InputOp op, ::oneflow::OperatorConf* op_conf);\nLogicalResult ConvertOutputOpConf(OutputOp op, ::oneflow::OperatorConf* op_conf);\n\nLogicalResult ParseNdSbpFromAttr(ArrayAttr nd_sbp_attr, ::oneflow::NdSbp* nd_sbp);\nAttribute ConvertNdSbpToAttr(Builder& builder, const ::oneflow::NdSbp& nd_sbp);\n\nclass Importer {\n public:\n  Importer(MLIRContext* context, ModuleOp module)\n      : builder_(context),\n        context_(context),\n        module_(module),\n        unknown_loc_(FileLineColLoc::get(context, \"unknown_loc\", 0, 0)) {}\n  virtual ~Importer() = default;\n  LogicalResult namedAttributesFromUserOp(const ::oneflow::OperatorConf& op,\n                                          std::vector<NamedAttribute>& attr_vec);\n  virtual LogicalResult AppendDataInOperand(const std::string& lbn,\n                                            std::vector<::mlir::Value>& operand_vec) {\n    return failure();\n  }\n  virtual LogicalResult AppendDataInOperand(const std::string& key, const int32_t index,\n                                            const std::string& lbn,\n                                            std::vector<::mlir::Value>& operand_vec) {\n    return AppendDataInOperand(lbn, operand_vec);\n  }\n  virtual LogicalResult AppendCtrlInOperand(const ::oneflow::OperatorConf& op,\n       
                                     std::vector<::mlir::Value>& operand_vec) = 0;\n  LogicalResult AppendCtrlOutType(llvm::SmallVector<Type, 8>& out_types);\n  LogicalResult AddOpConf(const ::oneflow::OperatorConf& op, std::vector<NamedAttribute>& attr_vec);\n  LogicalResult AddUserOpInputOutputSegments(const ::oneflow::OperatorConf& op,\n                                             std::vector<NamedAttribute>& attr_vec);\n  virtual LogicalResult AddDeviceName(const ::oneflow::OperatorConf& op,\n                                      std::vector<NamedAttribute>& attr_vec) = 0;\n  LogicalResult AddOperandSegmentSizes(int32_t input_lbns_size, int32_t ctrl_in_size,\n                                       std::vector<NamedAttribute>& attr_vec);\n  LogicalResult AddResultSegmentSizes(int32_t output_lbns_size,\n                                      std::vector<NamedAttribute>& attr_vec);\n  virtual LogicalResult InsertOpResults(const ::oneflow::OperatorConf& op, Operation*) = 0;\n  LogicalResult ProcessUserOp(const ::oneflow::OperatorConf& op);\n  virtual LogicalResult ProcessSystemOp(const ::oneflow::OperatorConf& op) = 0;\n\n  IntegerAttr getSI64IntegerAttr(int64_t value) {\n    return IntegerAttr::get(GetBuilder().getIntegerType(64, /*isSigned=*/true),\n                            APInt(64, value, /*isSigned=*/true));\n  }\n  ArrayAttr getSI32ArrayAttr(ArrayRef<int32_t> values) {\n    auto attrs = llvm::to_vector<8>(llvm::map_range(\n        values, [this](int32_t v) -> Attribute { return GetBuilder().getSI32IntegerAttr(v); }));\n    return GetBuilder().getArrayAttr(attrs);\n  }\n  ArrayAttr getSI64ArrayAttr(ArrayRef<int64_t> values) {\n    auto attrs = llvm::to_vector<8>(\n        llvm::map_range(values, [this](int64_t v) -> Attribute { return getSI64IntegerAttr(v); }));\n    return GetBuilder().getArrayAttr(attrs);\n  }\n\n  ArrayAttr GetAttrFromShape(const ::oneflow::ShapeProto& shape);\n  ArrayAttr GetAttrFromStride(const ::oneflow::Int64ListProto& stride);\n  OpBuilder& GetBuilder() { return builder_; }\n  MLIRContext* GetMLIRContext() { return context_; }\n  ModuleOp& GetModule() { return module_; }\n  Location& GetRootLocation() { return unknown_loc_; }\n  virtual Type GetTensorTypeOfLbn(const std::string& lbn) = 0;\n  void SetOpStateLoc(const ::oneflow::OperatorConf&, OperationState&);\n\n private:\n  OpBuilder builder_;\n  MLIRContext* context_;\n  ModuleOp module_;\n  Location unknown_loc_;\n};\n\nclass RoundTripOneFlowJobWrapperInterface {\n public:\n  virtual ~RoundTripOneFlowJobWrapperInterface() {}\n  virtual const ::oneflow::Job* job() const = 0;\n  virtual void UpdateJob(::oneflow::Job* new_job) = 0;\n  virtual void DumpLog(const std::string& filename, const std::string& content) = 0;\n  virtual const ::oneflow::ParallelConf& ParallelConf4OpName(const std::string& op_name) const = 0;\n  virtual const ::oneflow::OperatorConf& OpConf4OpName(const std::string& op_name) const = 0;\n  virtual std::pair<std::vector<std::string>, std::vector<std::string>> InputBns4OpName(\n      const std::string& op_name) const = 0;\n  virtual std::vector<std::string> OutputLbns4OpName(const std::string& op_name) const = 0;\n  virtual std::string ReplaceInputLbnInOpCustomizedConf(::oneflow::OperatorConf* op_conf,\n                                                        const std::string& ibn,\n                                                        const std::string& new_val) const = 0;\n  virtual void QueryLogicalBlob(\n      const std::string& lbn, std::function<void(const int64_t* shape_begin,\n  
                                               const int64_t* shape_end, ::oneflow::DataType dt)>\n                                  cb) const = 0;\n  virtual void TopoForEachOpConf(\n      std::function<void(const ::oneflow::OperatorConf*)> Handler) const = 0;\n  virtual bool IsLastIRPass() const = 0;\n};\n\nvoid RoundTripOneFlowJob(\n    RoundTripOneFlowJobWrapperInterface& job_wrapper,\n    const std::function<bool(::oneflow::Job* job, std::string& reason)>& is_legit_job);\n\nvoid registerFromOneFlowJobTranslation();\n\nstd::string ConvertJobToTosaIR(RoundTripOneFlowJobWrapperInterface& job_wrapper);\nvoid SaveJobToIR(RoundTripOneFlowJobWrapperInterface& job_wrapper, const std::string& path);\nstd::string ConvertJobToIR(RoundTripOneFlowJobWrapperInterface& job_wrapper);\nvoid LoadJobFromIR(RoundTripOneFlowJobWrapperInterface& job_wrapper, const std::string& path);\n\n}  // namespace oneflow\n\n}  // namespace mlir\n\n#endif  // ONEFLOW_IR_ONEFLOW_TRANSLATE_INCLUDE_ONEFLOW_MLIRONEFLOWTRANSLATION_H_\n"
  },
  {
    "path": "oneflow/ir/oneflow-translate/lib/CMakeLists.txt",
    "content": "add_subdirectory(OneFlow)\n"
  },
  {
    "path": "oneflow/ir/oneflow-translate/lib/OneFlow/CMakeLists.txt",
    "content": "oneflow_add_mlir_library(\n  MLIROneFlowTranslation\n  MLIROneFlowTranslation.cpp\n  Importer.cpp\n  ADDITIONAL_HEADER_DIRS\n  ${PROJECT_SOURCE_DIR}/oneflow-translate/include/OneFlow\n  DEPENDS\n  oneflow_deps\n  LINK_LIBS\n  PUBLIC\n  MLIRIR\n  ${dialect_libs}\n  ${translation_libs}\n  MLIRIR\n  MLIRParser\n  MLIRPass\n  MLIRSPIRVDialect\n  MLIRTranslateLib\n  MLIRSupport\n  MLIROneFlow\n  MLIRTosaToTensor\n  oneflow)\n\nif(BUILD_SHARED_LIBS)\n  get_filename_component(ONEFLOW_BUILD_ROOT_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../../.. ABSOLUTE)\n  get_property(TRANSLATE_INSTALL_RPATH TARGET MLIROneFlowTranslation PROPERTY INSTALL_RPATH)\n  list(APPEND TRANSLATE_INSTALL_RPATH ${PROTOBUF_LIBRARY_DIR})\n  list(APPEND TRANSLATE_INSTALL_RPATH ${ONEFLOW_BUILD_ROOT_DIR})\n  set_target_properties(MLIROneFlowTranslation PROPERTIES INSTALL_RPATH\n                                                          \"${TRANSLATE_INSTALL_RPATH}\")\nendif()\n\nmlir_check_link_libraries(MLIROneFlowTranslation)\n"
  },
  {
    "path": "oneflow/ir/oneflow-translate/lib/OneFlow/Importer.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/UserOpConversion.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/framework/user_op_conf.pb.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/framework/user_op_def.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/SBP/SBPDialect.h\"\n#include \"OneFlow/SBP/SBPAttributes.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/UserOpReflection.h\"\n#include \"OneFlow/OneFlowTypes.h\"\n#include \"OneFlow/OneFlowSupport.h\"\n#include \"OneFlow/Passes.h\"\n#include \"OneFlow/MLIROneFlowTranslation.h\"\n#include \"OneFlow/OneFlowSupport.h\"\n#include \"OneFlow/OneFlowDataTypeConversion.h\"\n\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/IR/Attributes.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/Location.h\"\n#include \"mlir/IR/OperationSupport.h\"\n\n#include \"mlir/IR/UseDefLists.h\"\n#include \"mlir/IR/Value.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Support/LLVM.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/Transforms/Passes.h\"\n#include \"mlir/Tools/mlir-translate/MlirTranslateMain.h\"\n#include \"mlir/Tools/mlir-translate/Translation.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n\n#include \"llvm-c/Core.h\"\n#include \"llvm/ADT/ArrayRef.h\"\n#include \"llvm/ADT/None.h\"\n#include \"llvm/ADT/Optional.h\"\n#include \"llvm/ADT/STLExtras.h\"\n#include \"llvm/ADT/StringRef.h\"\n#include \"llvm/ADT/StringSet.h\"\n#include \"llvm/Support/Casting.h\"\n#include \"llvm/Support/raw_ostream.h\"\n\n#include <google/protobuf/text_format.h>\n\n#include \"oneflow/core/framework/sbp_context.h\"\n#include \"oneflow/core/job/sbp_signature_builder.h\"\nnamespace mlir {\n\nnamespace oneflow {\n\nusing PbMessage = google::protobuf::Message;\n\nnamespace {\n\nusing SizeVec = SmallVector<int32_t, 8>;\n\nSizeVec GetSizesFromArgs(UserOpArgs args, UserOpArgDefs arg_defs) {\n  SizeVec sizes{};\n  llvm::StringSet<> names({});\n  for (const auto& arg : args) { names.insert(arg.first); }\n  for (const auto& arg_def : arg_defs) {\n    int32_t size = 0;\n    if (names.contains(arg_def.name())) { size = args.at(arg_def.name()).s_size(); }\n    sizes.push_back(size);\n  }\n  return sizes;\n}\n\nstd::vector<std::string> GetOutputLbns(const ::oneflow::OperatorConf& op, UserOpArgDefs arg_defs) {\n  SizeVec sizes{};\n  llvm::StringSet<> names_appeared({});\n  std::vector<std::string> output_lbn_vec{};\n  const auto& op_name = op.name();\n  for (const auto& arg : op.user_conf().output()) { names_appeared.insert(arg.first); }\n  for (const auto& arg_def : arg_defs) {\n    const auto& key = arg_def.name();\n    const auto& it = 
op.user_conf().output().find(key);\n    if (it == op.user_conf().output().end()) { continue; }\n    auto result_size = it->second.s_size();\n    if (result_size == 0) { continue; }\n    for (int32_t i = 0; i < result_size; i++) {\n      const auto output_lbn = op_name + \"/\" + key + \"_\" + std::to_string(i);\n      output_lbn_vec.push_back(output_lbn);\n    }\n  }\n  return output_lbn_vec;\n}\n\n}  // namespace\n\nLogicalResult IsAttrBelong2Op(const std::string& op_type_name, const std::string& attr_name) {\n  ::oneflow::user_op::UserOpDefWrapper op_def(support::getUserOpDef(op_type_name));\n  return success(op_def.IsAttrName(attr_name));\n}\n\nLogicalResult Importer::AddUserOpInputOutputSegments(const ::oneflow::OperatorConf& op,\n                                                     std::vector<NamedAttribute>& attr_vec) {\n  if (op.has_user_conf() == false) return failure();\n  const auto& user_conf = op.user_conf();\n  const ::oneflow::UserOpDef& op_def = support::getUserOpDef(op.user_conf().op_type_name());\n  const auto UserOpOperationName = OperationName(UserOp::getOperationName(), GetMLIRContext());\n  attr_vec.push_back(GetBuilder().getNamedAttr(\n      oneflow::UserOp::getInputSizesAttrName(UserOpOperationName),\n      GetBuilder().getI32ArrayAttr(GetSizesFromArgs(user_conf.input(), op_def.input()))));\n  attr_vec.push_back(GetBuilder().getNamedAttr(\n      oneflow::UserOp::getOutputSizesAttrName(UserOpOperationName),\n      GetBuilder().getI32ArrayAttr(GetSizesFromArgs(user_conf.output(), op_def.output()))));\n  auto output_lbns = GetOutputLbns(op, op_def.output());\n  attr_vec.push_back(GetBuilder().getNamedAttr(\n      OpTrait::IsImportCompatible<void>::getOutputLBNsAttr(),\n      GetBuilder().getStrArrayAttr(\n          SmallVector<StringRef, 8>({output_lbns.begin(), output_lbns.end()}))));\n  return success();\n}\n\nllvm::Optional<mlir::oneflow::DataTypeAttr> GetDataTypeAttr(MLIRContext* context,\n                                                            ::oneflow::DataType oneflow_value) {\n  switch (oneflow_value) {\n    case ::oneflow::DataType::kInvalidDataType:\n      return oneflow::DataTypeAttr::get(context, mlir::oneflow::DataType::DT_InvalidDataType);\n      break;\n#define DEFINE_ONE_ELIF(datatype)                                                       \\\n  case ::oneflow::DataType::k##datatype:                                                \\\n    return oneflow::DataTypeAttr::get(context, mlir::oneflow::DataType::DT_##datatype); \\\n    break;\n      DEFINE_ONE_ELIF(Char)\n      DEFINE_ONE_ELIF(Float)\n      DEFINE_ONE_ELIF(Double)\n      DEFINE_ONE_ELIF(Int8)\n      DEFINE_ONE_ELIF(Int32)\n      DEFINE_ONE_ELIF(Int64)\n      DEFINE_ONE_ELIF(UInt8)\n      DEFINE_ONE_ELIF(OFRecord)\n      DEFINE_ONE_ELIF(Float16)\n      DEFINE_ONE_ELIF(TensorBuffer)\n      DEFINE_ONE_ELIF(BFloat16)\n      DEFINE_ONE_ELIF(Bool)\n#undef DEFINE_ONE_ELIF\n    default: llvm::errs() << \"unsupported data type: \" << oneflow_value << \"\\n\"; return llvm::None;\n  }\n}\n\nArrayAttr Importer::GetAttrFromShape(const ::oneflow::ShapeProto& shape) {\n  return GetBuilder().getArrayAttr(llvm::to_vector<8>(llvm::map_range(\n      shape.dim(), [this](int64_t v) -> Attribute { return getSI64IntegerAttr(v); })));\n}\n\nArrayAttr Importer::GetAttrFromStride(const ::oneflow::Int64ListProto& stride) {\n  return GetBuilder().getArrayAttr(llvm::to_vector<8>(llvm::map_range(\n      stride.dim(), [this](int64_t v) -> Attribute { return getSI64IntegerAttr(v); })));\n}\n\nLogicalResult 
Importer::namedAttributesFromUserOp(const ::oneflow::OperatorConf& op,\n                                                  std::vector<NamedAttribute>& attr_vec) {\n  if (op.has_user_conf() == false) {\n    GetModule().emitError(\"Not a user op. op name: \" + op.name());\n    return failure();\n  }\n  for (const google::protobuf::MapPair<class std::basic_string<char>, ::oneflow::AttrValue>& attr :\n       op.user_conf().attr()) {\n    const std::string& name = attr.first;\n    const ::oneflow::AttrValue& value = attr.second;\n    if (value.has_at_int32()) {\n      mlir::NamedAttribute kv =\n          GetBuilder().getNamedAttr(name, GetBuilder().getSI32IntegerAttr(value.at_int32()));\n      attr_vec.emplace_back(kv);\n    } else if (value.has_at_int64()) {\n      mlir::NamedAttribute kv =\n          GetBuilder().getNamedAttr(name, getSI64IntegerAttr(value.at_int64()));\n      attr_vec.emplace_back(kv);\n    }\n#define DEFINE_ONE_ELIF(at_key, get_attr)                                       \\\n  else if (value.has_##at_key()) {                                              \\\n    mlir::NamedAttribute kv =                                                   \\\n        GetBuilder().getNamedAttr(name, GetBuilder().get_attr(value.at_key())); \\\n    attr_vec.emplace_back(kv);                                                  \\\n  }\n    DEFINE_ONE_ELIF(at_bool, getBoolAttr)\n    DEFINE_ONE_ELIF(at_float, getF32FloatAttr)\n    DEFINE_ONE_ELIF(at_double, getF64FloatAttr)\n    DEFINE_ONE_ELIF(at_string, getStringAttr)\n#undef DEFINE_ONE_ELIF\n    else if (value.has_at_shape()) {\n      attr_vec.emplace_back(GetBuilder().getNamedAttr(name, GetAttrFromShape(value.at_shape())));\n    }\n    else if (value.has_at_stride()) {\n      attr_vec.emplace_back(GetBuilder().getNamedAttr(name, GetAttrFromStride(value.at_stride())));\n    }\n#define DEFINE_ONE_ELIF(at_key, get_attr, field)                                         \\\n  else if (value.has_##at_key()) {                                                       \\\n    mlir::NamedAttribute kv = GetBuilder().getNamedAttr(                                 \\\n        name, get_attr({value.at_key().field().begin(), value.at_key().field().end()})); \\\n    attr_vec.emplace_back(kv);                                                           \\\n  }\n    DEFINE_ONE_ELIF(at_list_int32, getSI32ArrayAttr, val)\n    DEFINE_ONE_ELIF(at_list_int64, getSI64ArrayAttr, val)\n    DEFINE_ONE_ELIF(at_list_float, GetBuilder().getF32ArrayAttr, val)\n#undef DEFINE_ONE_ELIF\n    else if (value.has_at_list_string()) {\n      std::vector<llvm::StringRef> r_vec = {value.at_list_string().val().begin(),\n                                            value.at_list_string().val().end()};\n      mlir::NamedAttribute kv =\n          GetBuilder().getNamedAttr(name, GetBuilder().getStrArrayAttr(r_vec));\n      attr_vec.emplace_back(kv);\n    }\n    else if (value.has_at_data_type()) {\n      if (auto dt_attr = GetDataTypeAttr(GetMLIRContext(), value.at_data_type())) {\n        mlir::NamedAttribute kv = GetBuilder().getNamedAttr(name, dt_attr.value());\n        attr_vec.emplace_back(kv);\n      } else {\n        GetModule().emitError(\"fail to convert op attr, key: \" + name);\n        return failure();\n      }\n    }\n    else if (value.has_at_list_data_type()) {\n      auto dt_attr_list =\n          llvm::map_range(value.at_list_data_type().val(), [&](auto t) -> mlir::Attribute {\n            auto dt = GetDataTypeAttr(GetMLIRContext(), static_cast<::oneflow::DataType>(t));\n            
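// NOTE: this conversion runs inside a lambda mapped over the list, so unlike the\n            // scalar at_data_type branch above it cannot return failure(); CHECK aborts\n            // the process on an unsupported element instead.\n            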
CHECK(dt) << \"fail to convert op attr, key: \" + name;\n            return dt.value();\n          });\n      attr_vec.emplace_back(GetBuilder().getNamedAttr(\n          name, GetBuilder().getArrayAttr(llvm::to_vector<8>(dt_attr_list))));\n    }\n    else if (value.has_at_list_shape()) {\n      auto dense_attr_list =\n          llvm::map_range(value.at_list_shape().val(),\n                          [&](const ::oneflow::ShapeProto& s) { return GetAttrFromShape(s); });\n      std::vector<mlir::Attribute> dense_attr_vector{dense_attr_list.begin(),\n                                                     dense_attr_list.end()};\n      attr_vec.emplace_back(\n          GetBuilder().getNamedAttr(name, GetBuilder().getArrayAttr(dense_attr_vector)));\n    }\n    else if (value.has_at_list_stride()) {\n      auto dense_attr_list =\n          llvm::map_range(value.at_list_stride().val(),\n                          [&](const ::oneflow::Int64ListProto& s) { return GetAttrFromStride(s); });\n      std::vector<mlir::Attribute> dense_attr_vector{dense_attr_list.begin(),\n                                                     dense_attr_list.end()};\n      attr_vec.emplace_back(\n          GetBuilder().getNamedAttr(name, GetBuilder().getArrayAttr(dense_attr_vector)));\n    }\n    else if (value.has_at_complex_double()) {\n      std::vector<mlir::Attribute> dense_attr_vector{\n          GetBuilder().getF64FloatAttr(value.at_complex_double().real()),\n          GetBuilder().getF64FloatAttr(value.at_complex_double().imag())};\n      attr_vec.emplace_back(\n          GetBuilder().getNamedAttr(name, GetBuilder().getArrayAttr(dense_attr_vector)));\n    }\n    else {\n      GetModule().emitError(\"can't handle user op attr: \" + name + \", op name: \" + op.name()\n                            + \", op type name: \" + op.user_conf().op_type_name());\n      return failure();\n    }\n  }\n\n  if (failed(AddUserOpInputOutputSegments(op, attr_vec))) {\n    GetModule().emitError(\"fail to add input output segments: \" + op.name());\n    return failure();\n  }\n\n  return success();\n}\n\nLogicalResult Importer::AddOperandSegmentSizes(int32_t input_lbns_size, int32_t ctrl_in_size,\n                                               std::vector<NamedAttribute>& attr_vec) {\n  attr_vec.push_back(GetBuilder().getNamedAttr(\n      mlir::OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),\n      GetBuilder().getDenseI32ArrayAttr({input_lbns_size, ctrl_in_size})));\n  return success();\n}\n\nLogicalResult Importer::AddResultSegmentSizes(int32_t output_lbns_size,\n                                              std::vector<NamedAttribute>& attr_vec) {\n  attr_vec.push_back(GetBuilder().getNamedAttr(\n      mlir::OpTrait::AttrSizedResultSegments<void>::getResultSegmentSizeAttr(),\n      GetBuilder().getDenseI32ArrayAttr(\n          {output_lbns_size, 1} /* {data_out_size, ctrl_out_size} */)));\n  return success();\n}\n\nLogicalResult Importer::AppendCtrlOutType(llvm::SmallVector<Type, 8>& out_types) {\n  out_types.append({RankedTensorType::get({}, GetBuilder().getI1Type())});\n  return success();\n}\n\nLogicalResult Importer::AddOpConf(const ::oneflow::OperatorConf& op,\n                                  std::vector<NamedAttribute>& attr_vec) {\n  attr_vec.push_back(GetBuilder().getNamedAttr(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(),\n                                               GetBuilder().getStringAttr(op.name())));\n  if (op.has_device_tag()) {\n    attr_vec.push_back(\n        
GetBuilder().getNamedAttr(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr(),\n                                  GetBuilder().getStringAttr(op.device_tag())));\n  }\n  attr_vec.push_back(\n      GetBuilder().getNamedAttr(OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr(),\n                                GetBuilder().getI64IntegerAttr(op.scope_symbol_id())));\n  return success();\n}\n\nLogicalResult ParseNdSbpFromAttr(::llvm::ArrayRef<Attribute> nd_sbp_attr,\n                                 ::oneflow::NdSbp* nd_sbp) {\n  for (const auto& sbp_attr : nd_sbp_attr) {\n    auto sbp_str_attr = sbp_attr.dyn_cast<StringAttr>();\n    if (!sbp_str_attr) {\n      llvm::errs() << \"nd_sbp attr is not a StrArrayAttr\";\n      return failure();\n    }\n    auto sbp_strref = sbp_str_attr.getValue();\n    if (sbp_strref.startswith(\"S\")) {\n      if (!(sbp_strref.substr(1, 1) == \"(\" && sbp_strref.endswith(\")\"))) {\n        llvm::errs() << \"invalid sbp S(x) string value: \" << sbp_strref;\n        return failure();\n      }\n      auto split_axis = std::stoi(sbp_strref.substr(2, 1).str());\n      nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(split_axis);\n    } else if (sbp_strref == \"B\") {\n      nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel();\n    } else if (sbp_strref == \"P\") {\n      nd_sbp->add_sbp_parallel()->mutable_partial_sum_parallel();\n    } else {\n      llvm::errs() << \"unsupported nd_sbp string value: \" << sbp_strref;\n      return failure();\n    }\n  }\n  return success();\n}\n\nAttribute ConvertNdSbpToAttr(Builder& builder, const ::oneflow::NdSbp& nd_sbp) {\n  llvm::SmallVector<std::string, 2> sbp_strs;\n  for (const auto& sbp : nd_sbp.sbp_parallel()) {\n    if (sbp.has_split_parallel()) {\n      sbp_strs.emplace_back(\"S(\" + std::to_string(sbp.split_parallel().axis()) + \")\");\n    } else if (sbp.has_broadcast_parallel()) {\n      sbp_strs.emplace_back(\"B\");\n    } else if (sbp.has_partial_sum_parallel()) {\n      sbp_strs.emplace_back(\"P\");\n    } else {\n      llvm::errs() << \"unsupported sbp: \" << nd_sbp.DebugString();\n      exit(EXIT_FAILURE);\n    }\n  }\n  return builder.getStrArrayAttr(\n      makeArrayRef(llvm::SmallVector<StringRef>(sbp_strs.begin(), sbp_strs.end())));\n}\n\nLogicalResult ValidateUserOpConf(const ::oneflow::OperatorConf& op_conf, UserOpArgs args,\n                                 UserOpArgDefs arg_defs) {\n  for (const auto& input_arg : args) {\n    const bool found = std::find_if(arg_defs.begin(), arg_defs.end(),\n                                    [&](const ::oneflow::UserOpDef_ArgDef& arg_def) {\n                                      return input_arg.first == arg_def.name();\n                                    })\n                       != arg_defs.end();\n    if (!found) {\n      llvm::errs() << \"fail to validate user op conf, arg def of arg not found: \" << input_arg.first\n                   << \", op: \\n\"\n                   << op_conf.DebugString() << \"\\n\";\n      return failure();\n    }\n  }\n  return success();\n}\n\nLogicalResult Importer::ProcessUserOp(const ::oneflow::OperatorConf& op) {\n  if (op.has_user_conf() == false) {\n    GetModule().emitError(\"Not a user op. 
op name: \" + op.name());\n    return failure();\n  }\n  std::vector<NamedAttribute> attr_vec;\n  if (failed(AddOpConf(op, attr_vec))) { return failure(); }\n  if (failed(AddDeviceName(op, attr_vec))) { return failure(); }\n  attr_vec.push_back(\n      GetBuilder().getNamedAttr(OpTrait::IsAlternative<void>::getOpTypeNameAttr(),\n                                GetBuilder().getStringAttr(op.user_conf().op_type_name())));\n  std::vector<::mlir::Value> operand_vec;\n  if (failed(namedAttributesFromUserOp(op, attr_vec))) { return failure(); }\n  const auto& op_def = support::getUserOpDef(op.user_conf().op_type_name());\n  if (failed(ValidateUserOpConf(op, op.user_conf().input(), op_def.input()))) { return failure(); }\n  if (failed(ValidateUserOpConf(op, op.user_conf().output(), op_def.output()))) {\n    return failure();\n  }\n  for (const auto& arg_def : op_def.input()) {\n    const auto& key = arg_def.name();\n    auto it = op.user_conf().input().find(key);\n    if (it == op.user_conf().input().end()) { continue; }\n    int32_t index = 0;\n    for (const std::string& lbn : it->second.s()) {\n      if (failed(AppendDataInOperand(key, index, lbn, operand_vec))) { return failure(); }\n      index += 1;\n    }\n  }\n  if (failed(AppendCtrlInOperand(op, operand_vec))) { return failure(); }\n  ::mlir::ValueRange operands(operand_vec);\n\n  Operation* created_op = nullptr;\n\n  auto out_types = llvm::SmallVector<Type, 8>();\n  for (const auto& arg_def : op_def.output()) {\n    const auto& key = arg_def.name();\n    auto it = op.user_conf().output().find(key);\n    if (it == op.user_conf().output().end()) { continue; }\n    for (const auto& output_lbn : it->second.s()) {\n      out_types.push_back(GetTensorTypeOfLbn(output_lbn));\n    }\n  }\n\n  if (failed(AppendCtrlOutType(out_types))) { return failure(); }\n  OperationState state(FileLineColLoc::get(GetMLIRContext(), op.name(), 0, 0),\n                       UserOp::getOperationName());\n  uint32_t data_input_size = 0;\n  uint32_t data_output_size = 0;\n  for (const auto& input : op.user_conf().input()) { data_input_size += input.second.s().size(); }\n  for (const auto& output : op.user_conf().output()) {\n    data_output_size += output.second.s().size();\n  }\n  if (failed(AddOperandSegmentSizes(data_input_size, op.ctrl_in_op_name_size(), attr_vec))) {\n    return failure();\n  }\n  if (failed(AddResultSegmentSizes(data_output_size, attr_vec))) { return failure(); }\n  ArrayRef<NamedAttribute> named_attributes(attr_vec);\n  state.addAttributes(named_attributes);\n  state.addOperands(operands);\n  state.addTypes(out_types);\n  SetOpStateLoc(op, state);\n  created_op = GetBuilder().create(state);\n\n  if (created_op == nullptr) {\n    GetModule()->emitError(\"fail to create \" + op.user_conf().op_type_name()\n                           + \" op, name: \" + op.name());\n    return failure();\n  }\n  if (failed(InsertOpResults(op, created_op))) { return failure(); }\n\n  return success();\n}  // namespace\n\nLogicalResult ConvertCtrlInputs(Operation* op, ::oneflow::OperatorConf& op_conf) {\n  if (op->isRegistered() && !llvm::dyn_cast<oneflow::UserOp>(op)) return success();\n  if (auto ctrl_ins = GetCtrlIntputOperands(op)) {\n    for (auto ctrl_in : ctrl_ins.value()) {\n      op_conf.add_ctrl_in_op_name(\n          OpTrait::IsOpConfCompatible<void>::getOpName(ctrl_in.getDefiningOp()).str());\n    }\n  }\n  return success();\n}\n\nLogicalResult ConvertUserOpInputs(Operation* op, StringRef op_name,\n                                  
::oneflow::UserOpConf* user_conf) {\n  std::vector<std::string> keys{};\n  std::vector<int32_t> sizes{};\n  if (failed(user_op::GetFilteredSegmentKeyAndSizes<OpTrait::AttrSizedOperandSegments>(op, keys,\n                                                                                       sizes))) {\n    op->emitError(\"fail to convert user op inputs\");\n    return failure();\n  }\n  int32_t input_idx = 0;\n  for (auto tuple : llvm::zip(keys, sizes)) {\n    auto input_key = std::get<0>(tuple);\n    auto input_size = std::get<1>(tuple);\n    if (input_size <= 0)\n      return op->emitError(\"input_size <= 0, op: \" + op->getName().getStringRef());\n    for (int32_t i = 0; i < input_size; i++) {\n      if (auto result = GetDataInputOperands(op)[input_idx].dyn_cast<mlir::OpResult>()) {\n        auto input_s_ptr = (*user_conf->mutable_input())[input_key].mutable_s()->Add();\n        *(input_s_ptr) = user_op::GetOutputLbn(result).value();\n        input_idx += 1;\n      } else {\n        op->emitError() << \"fail to convert MLIR result to protobuf, name: \" + op_name;\n        op->dump();\n        return failure();\n      }\n    }\n  }\n  return success();\n}\n\nLogicalResult ConvertUserOpOutputs(Operation* op, StringRef op_name,\n                                   ::oneflow::UserOpConf* user_conf) {\n  std::vector<std::string> keys{};\n  std::vector<int32_t> sizes{};\n  if (failed(user_op::GetFilteredSegmentKeyAndSizes<OpTrait::AttrSizedResultSegments>(op, keys,\n                                                                                      sizes))) {\n    op->emitError(\"fail to convert user op outputs\");\n    return failure();\n  }\n  for (auto tuple : llvm::zip(keys, sizes)) {\n    auto name = std::get<0>(tuple);\n    auto result_size = std::get<1>(tuple);\n    if (result_size == 0) continue;\n    for (int32_t i = 0; i < result_size; i++) {\n      auto out_s_ptr = (*user_conf->mutable_output())[name].mutable_s()->Add();\n      *(out_s_ptr) = op_name.str() + \"/\" + name + \"_\" + std::to_string(i);\n    }\n  }\n  return success();\n}\n\nLogicalResult ConvertDT(::mlir::oneflow::DataType data_type_mlir, ::oneflow::DataType& data_type) {\n  switch (data_type_mlir) {\n    case oneflow::DataType::DT_InvalidDataType:\n      data_type = ::oneflow::DataType::kInvalidDataType;\n      break;\n#define DEFINE_ONE_CASE(datatype) \\\n  case oneflow::DataType::DT_##datatype: data_type = ::oneflow::DataType::k##datatype; break;\n      DEFINE_ONE_CASE(Char)\n      DEFINE_ONE_CASE(Float)\n      DEFINE_ONE_CASE(Double)\n      DEFINE_ONE_CASE(Int8)\n      DEFINE_ONE_CASE(Int32)\n      DEFINE_ONE_CASE(Int64)\n      DEFINE_ONE_CASE(UInt8)\n      DEFINE_ONE_CASE(OFRecord)\n      DEFINE_ONE_CASE(Float16)\n      DEFINE_ONE_CASE(TensorBuffer)\n      DEFINE_ONE_CASE(Bool)\n#undef DEFINE_ONE_CASE\n    default: return failure();\n  }\n  return success();\n}\n\nLogicalResult ConvertDTFromAttr(Attribute attr, ::oneflow::DataType& data_type) {\n  auto dt_attr = attr.dyn_cast<mlir::oneflow::DataTypeAttr>();\n  if (!dt_attr) { return failure(); }\n  return ConvertDT(dt_attr.getValue(), data_type);\n}\n\nvoid Importer::SetOpStateLoc(const ::oneflow::OperatorConf& op_conf, OperationState& state) {\n  if (op_conf.has_loc()) {\n    state.location = (FileLineColLoc::get(GetMLIRContext(), op_conf.loc(), 0, 0));\n  }\n}\n\nLogicalResult ConvertVariableOpConf(VariableOp op, ::oneflow::OperatorConf* op_conf) {\n  op_conf->set_name(op.getOpName().str());\n  op_conf->set_device_tag(op.getDeviceTag().str());\n  if (auto scope_symbol_id = op.getScopeSymbolId()) {\n 
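   // the scope symbol id attribute is optional on the MLIR op; forward it only when present\n 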
   op_conf->set_scope_symbol_id(scope_symbol_id.value());\n  }\n  // TODO: process stream_name_hint\n\n  auto* var_op_conf = op_conf->mutable_variable_conf();\n  var_op_conf->set_out(\"out\");\n\n  if (auto shape_attr =\n          op->getAttrOfType<ArrayAttr>(OpTrait::TensorSource<void>::getShapeAttrName())) {\n    *var_op_conf->mutable_shape() = user_op::getAttrAsShape(shape_attr);\n  }\n\n  if (op->hasAttr(OpTrait::TensorSource<void>::getDataTypeAttrName())) {\n    if (auto dt_mlir = op.getDataType()) {\n      const auto dt = support::FromMLIRDataTypeToOFDataType(dt_mlir.value());\n      if (failed(dt)) { return failure(); }\n      var_op_conf->set_data_type(dt.value());\n    }\n  }\n\n  if (auto model_name = op.getModelNameAttr()) {\n    var_op_conf->set_model_name(model_name.getValue().str());\n  }\n\n  if (auto l1_regularization = op.getL1RegularizationAttr()) {\n    var_op_conf->mutable_regularizer()->mutable_l1_l2_conf()->set_l1(\n        l1_regularization.getValue().convertToFloat());\n  }\n\n  if (auto l2_regularization = op.getL2RegularizationAttr()) {\n    var_op_conf->mutable_regularizer()->mutable_l1_l2_conf()->set_l2(\n        l2_regularization.getValue().convertToFloat());\n  }\n\n  if (auto trainable = op.getTrainableAttr()) { var_op_conf->set_trainable(trainable.getValue()); }\n\n  for (auto output : op.getParallel()->getOutputs()) {\n    if (auto nd_outputs = output.dyn_cast<ArrayAttr>()) {\n      for (auto nd_output : nd_outputs) {\n        std::string sbp{};\n        if (failed(SBPTranslation::PrintSbpAttrToString(nd_output, sbp))) return failure();\n        var_op_conf->add_nd_sbp(sbp);\n      }\n    } else {\n      std::string sbp{};\n      if (failed(SBPTranslation::PrintSbpAttrToString(output, sbp))) return failure();\n      var_op_conf->add_nd_sbp(sbp);\n    }\n  }\n  // all operands are ctrl_inputs\n  for (const auto& operand : op->getOperands()) {\n    op_conf->add_ctrl_in_op_name(\n        OpTrait::IsOpConfCompatible<void>::getOpName(operand.getDefiningOp()).str());\n  }\n  if (auto floatInit = op.getFloatInitializer()) {\n    var_op_conf->mutable_initializer()->mutable_constant_conf()->set_value(\n        floatInit.value().convertToFloat());\n  } else if (auto integerInit = op.getIntegerInitializer()) {\n    var_op_conf->mutable_initializer()->mutable_constant_int_conf()->set_value(integerInit.value());\n  } else {\n    // empty initializer\n    var_op_conf->mutable_initializer()->mutable_empty_conf();\n  }\n\n  return success();\n}\n\nLogicalResult ConvertInputOpConf(InputOp op, ::oneflow::OperatorConf* op_conf) {\n  op_conf->set_name(op.getOpName().str());\n  op_conf->set_device_tag(op.getDeviceTag().str());\n  if (auto scope_symbol_id = op.getScopeSymbolId()) {\n    op_conf->set_scope_symbol_id(scope_symbol_id.value());\n  }\n  // TODO: process stream_name_hint\n\n  auto* input_op_conf = op_conf->mutable_input_conf();\n  input_op_conf->set_out(\"out\");\n\n  if (auto shape_attr =\n          op->getAttrOfType<ArrayAttr>(OpTrait::TensorSource<void>::getShapeAttrName())) {\n    *input_op_conf->mutable_blob_conf()->mutable_shape() = user_op::getAttrAsShape(shape_attr);\n  }\n\n  if (op->hasAttr(OpTrait::TensorSource<void>::getDataTypeAttrName())) {\n    if (auto dt_mlir = op.getDataType()) {\n      const auto dt = support::FromMLIRDataTypeToOFDataType(dt_mlir.value());\n      if (failed(dt)) { return failure(); }\n      input_op_conf->mutable_blob_conf()->set_data_type(dt.value());\n    }\n  }\n\n  if 
(op->hasAttr(OpTrait::TensorSource<void>::getIsDynamicAttrName())) {\n    input_op_conf->mutable_blob_conf()->set_is_dynamic(op.getIsDynamic().value());\n  }\n\n  if (op->hasAttr(OpTrait::TensorSource<void>::getNdSbpAttrName())) {\n    if (failed(ParseNdSbpFromAttr(op.getNdSbp()->getValue(),\n                                  input_op_conf->mutable_blob_conf()->mutable_nd_sbp()))) {\n      return failure();\n    }\n  }\n\n  if (op->hasAttr(\"job_name\")) { input_op_conf->set_job_name(op.getJobName().value().str()); }\n\n  // operand 0 is block argument, others are ctrl_inputs\n  for (size_t i = 1; i < op->getNumOperands(); ++i) {\n    op_conf->add_ctrl_in_op_name(\n        OpTrait::IsOpConfCompatible<void>::getOpName(op->getOperand(i).getDefiningOp()).str());\n  }\n\n  return success();\n}\n\nLogicalResult ConvertOutputOpConf(OutputOp op, ::oneflow::OperatorConf* op_conf) {\n  op_conf->set_name(op.getOpName().str());\n  op_conf->set_device_tag(op.getDeviceTag().str());\n  if (auto scope_symbol_id = op.getScopeSymbolId()) {\n    op_conf->set_scope_symbol_id(scope_symbol_id.value());\n  }\n  // TODO: process stream_name_hint\n\n  auto* output_op_conf = op_conf->mutable_output_conf();\n  output_op_conf->set_out(\"out\");\n\n  if (auto shape_attr =\n          op->getAttrOfType<ArrayAttr>(OpTrait::TensorSource<void>::getShapeAttrName())) {\n    *output_op_conf->mutable_blob_conf()->mutable_shape() = user_op::getAttrAsShape(shape_attr);\n  }\n\n  if (op->hasAttr(OpTrait::TensorSource<void>::getDataTypeAttrName())) {\n    if (auto dt_mlir = op.getDataType()) {\n      const auto dt = support::FromMLIRDataTypeToOFDataType(dt_mlir.value());\n      if (failed(dt)) { return failure(); }\n      output_op_conf->mutable_blob_conf()->set_data_type(dt.value());\n    }\n  }\n\n  if (op->hasAttr(OpTrait::TensorSource<void>::getIsDynamicAttrName())) {\n    output_op_conf->mutable_blob_conf()->set_is_dynamic(op.getIsDynamic().value());\n  }\n\n  if (op->hasAttr(OpTrait::TensorSource<void>::getNdSbpAttrName())) {\n    if (failed(ParseNdSbpFromAttr(op.getNdSbp()->getValue(),\n                                  output_op_conf->mutable_blob_conf()->mutable_nd_sbp()))) {\n      return failure();\n    }\n  }\n\n  if (op->hasAttr(\"job_name\")) { output_op_conf->set_job_name(op.getJobName().value().str()); }\n\n  if (op->getNumOperands() == 0) {\n    op->emitError(\"output op must have at least one input.\");\n    return failure();\n  }\n  auto result = op->getOperand(0).dyn_cast<mlir::OpResult>();\n  auto output_lbn = user_op::GetOutputLbn(result).value();\n  output_op_conf->set_in(output_lbn);\n  for (size_t i = 1; i < op->getNumOperands(); ++i) {\n    op_conf->add_ctrl_in_op_name(\n        OpTrait::IsOpConfCompatible<void>::getOpName(op->getOperand(i).getDefiningOp()).str());\n  }\n  return success();\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/oneflow-translate/lib/OneFlow/MLIROneFlowTranslation.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"OneFlow/Conversion/OneFlowToTosa.h\"\n#include \"OneFlow/OneFlowDataTypeConversion.h\"\n#include \"OneFlow/Transform/FuncOps.h\"\n#include \"OneFlow/UserOpReflection.h\"\n#include \"OneFlow/Transform/AggregateOps.h\"\n#include \"mlir/Conversion/TosaToLinalg/TosaToLinalg.h\"\n#include \"mlir/Conversion/TosaToTensor/TosaToTensor.h\"\n#include \"mlir/Dialect/Linalg/Passes.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/framework/user_op_conf.pb.h\"\n#include \"oneflow/core/job/job.pb.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/operator/interface_blob_conf.pb.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\n#include \"OneFlow/OneFlowDialect.h\"\n#include \"OneFlow/OneFlowOps.h\"\n#include \"OneFlow/OneFlowOpTraits.h\"\n#include \"OneFlow/Passes.h\"\n#include \"OneFlow/MLIROneFlowTranslation.h\"\n#include \"OneFlow/OneFlowUtils.h\"\n#include \"OneFlow/UserOpConversion.h\"\n\n#include \"mlir/Dialect/Tosa/Transforms/Passes.h\"\n#include \"mlir/Dialect/Func/IR/FuncOps.h\"\n#include \"mlir/IR/Attributes.h\"\n#include \"mlir/IR/Builders.h\"\n#include \"mlir/IR/BuiltinAttributes.h\"\n#include \"mlir/IR/BuiltinOps.h\"\n#include \"mlir/IR/BuiltinTypes.h\"\n#include \"mlir/IR/Location.h\"\n#include \"mlir/IR/OperationSupport.h\"\n\n#include \"mlir/IR/UseDefLists.h\"\n#include \"mlir/IR/Value.h\"\n#include \"mlir/IR/Visitors.h\"\n#include \"mlir/Pass/PassManager.h\"\n#include \"mlir/Support/LLVM.h\"\n#include \"mlir/Support/LogicalResult.h\"\n#include \"mlir/Transforms/Passes.h\"\n#include \"mlir/Tools/mlir-translate/MlirTranslateMain.h\"\n#include \"mlir/Tools/mlir-translate/Translation.h\"\n#include \"mlir/Transforms/GreedyPatternRewriteDriver.h\"\n#include \"mlir/Parser/Parser.h\"\n\n#include \"llvm-c/Core.h\"\n#include \"llvm/ADT/ArrayRef.h\"\n#include \"llvm/ADT/None.h\"\n#include \"llvm/ADT/Optional.h\"\n#include \"llvm/ADT/STLExtras.h\"\n#include \"llvm/ADT/StringRef.h\"\n#include \"llvm/Support/Casting.h\"\n#include \"llvm/Support/raw_ostream.h\"\n\n#include <google/protobuf/text_format.h>\n\nnamespace mlir {\n\nnamespace oneflow {\n\nusing PbMessage = google::protobuf::Message;\n\nclass JobImporter : Importer {\n public:\n  JobImporter(RoundTripOneFlowJobWrapperInterface& job_wrapper, MLIRContext* context,\n              ModuleOp module)\n      : Importer(context, module), job_(job_wrapper.job()), job_wrapper_(job_wrapper) {}\n  virtual ~JobImporter() = default;\n  LogicalResult AppendDataInOperand(const std::string& lbn,\n                                    std::vector<::mlir::Value>& operand_vec) override;\n  LogicalResult AppendCtrlInOperand(const ::oneflow::OperatorConf& op,\n                                    std::vector<::mlir::Value>& operand_vec) override;\n  LogicalResult AddDeviceName(const ::oneflow::OperatorConf& op,\n                              
std::vector<NamedAttribute>& attr_vec) override;\n  LogicalResult InsertOpResults(const ::oneflow::OperatorConf& op, Operation*) override;\n\n  LogicalResult ProcessJob();\n  LogicalResult ProcessSystemOp(const ::oneflow::OperatorConf& op) override;\n  LogicalResult ProcessVariableOp(const ::oneflow::OperatorConf& op);\n  LogicalResult ProcessInputOp(const ::oneflow::OperatorConf& op_conf, Block* entry_block,\n                               size_t& input_count);\n  LogicalResult ProcessOutputOp(const ::oneflow::OperatorConf& op_conf);\n\n  LogicalResult TryToUpdateJob();\n  LogicalResult ConvertUserOp(Operation* op, ::oneflow::Job& job);\n  LogicalResult ConvertSystemOp(Operation* op, ::oneflow::Job& job);\n  LogicalResult ConvertVariableOp(VariableOp op, ::oneflow::Job& job);\n  LogicalResult ConvertInputOp(InputOp op, ::oneflow::Job& job);\n  LogicalResult ConvertOutputOp(OutputOp op, ::oneflow::Job& job);\n\n  Type GetTensorTypeOfLbn(const std::string& lbn) override;\n  Type GetInterfaceBlobConfType(const ::oneflow::InterfaceBlobConf& blob_conf);\n\n private:\n  std::unordered_map<std::string, mlir::OpResult> lbn2result_;\n  std::unordered_map<std::string, mlir::OpResult> op_name2ctrl_result_;\n  const ::oneflow::Job* job_;\n  RoundTripOneFlowJobWrapperInterface& job_wrapper_;\n};\n\nLogicalResult JobImporter::AppendCtrlInOperand(const ::oneflow::OperatorConf& op,\n                                               std::vector<::mlir::Value>& operand_vec) {\n  for (auto& ctrl_in_op_name : op.ctrl_in_op_name()) {\n    auto it = op_name2ctrl_result_.find(ctrl_in_op_name);\n    if (it == op_name2ctrl_result_.end()) {\n      GetModule().emitError(\"ctrl edge result of this op not found: \" + ctrl_in_op_name\n                            + \". op being controlled: \" + op.name());\n      return failure();\n    } else {\n      operand_vec.push_back(it->second);\n    }\n  }\n  return success();\n}\n\nLogicalResult JobImporter::AppendDataInOperand(const std::string& lbn,\n                                               std::vector<::mlir::Value>& operand_vec) {\n  auto it = lbn2result_.find(lbn);\n  if (it == lbn2result_.end()) {\n    GetModule().emitError(\"IR result not found for: \" + lbn);\n    return failure();\n  } else {\n    operand_vec.push_back(it->second);\n    return success();\n  }\n}\n\nLogicalResult JobImporter::InsertOpResults(const ::oneflow::OperatorConf& op,\n                                           Operation* created_op) {\n  auto output_lbns =\n      created_op->getAttrOfType<ArrayAttr>(OpTrait::IsImportCompatible<void>::getOutputLBNsAttr());\n  auto data_results = GetDataOutputResults(created_op);\n  if (output_lbns.size() != data_results.size()) {\n    output_lbns.dump();\n    llvm::errs() << \"output_lbns size: \" << output_lbns.size()\n                 << \" != data_results size: \" << data_results.size() << \"\\n\"\n                 << op.DebugString();\n    created_op->getAttrDictionary().dump();\n    created_op->dump();\n    return failure();\n  }\n  for (const auto& data_out : llvm::enumerate(data_results)) {\n    auto data_out_index = data_out.index();\n    lbn2result_.insert({output_lbns[data_out_index].dyn_cast<StringAttr>().getValue().str(),\n                        data_out.value().dyn_cast<OpResult>()});\n  }\n  if (auto ctrl_out = GetCtrlOutputResult(created_op)) {\n    op_name2ctrl_result_.insert(\n        {created_op->getAttrOfType<StringAttr>(OpTrait::IsOpConfCompatible<void>::getOpNameAttr())\n             .getValue()\n             .str(),\n         
ctrl_out->dyn_cast<OpResult>()});\n  }\n  return success();\n}\n\nLogicalResult JobImporter::AddDeviceName(const ::oneflow::OperatorConf& op,\n                                         std::vector<NamedAttribute>& attr_vec) {\n  const ::oneflow::ParallelConf& pc = job_wrapper_.ParallelConf4OpName(op.name());\n  std::vector<llvm::StringRef> device_vec = {pc.device_name().begin(), pc.device_name().end()};\n  attr_vec.push_back(\n      GetBuilder().getNamedAttr(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr(),\n                                GetBuilder().getStrArrayAttr(device_vec)));\n  if (pc.has_hierarchy()) {\n    attr_vec.push_back(GetBuilder().getNamedAttr(\n        OpTrait::IsOpConfCompatible<void>::getHierarchyAttr(),\n        GetBuilder().getI64ArrayAttr({pc.hierarchy().dim().begin(), pc.hierarchy().dim().end()})));\n  }\n  return success();\n}\n\nType JobImporter::GetTensorTypeOfLbn(const std::string& lbn) {\n  Type ret{};\n  job_wrapper_.QueryLogicalBlob(\n      lbn, [this, &ret, &lbn](const int64_t* shape_begin, const int64_t* shape_end,\n                              ::oneflow::DataType dt) {\n        if (auto t = getTypeFromOneFlowDataType(GetMLIRContext(), dt)) {\n          ret = RankedTensorType::get(ArrayRef<int64_t>(shape_begin, shape_end), t);\n        } else {\n          llvm::errs() << \"fail to get data tensor type for: \" << lbn << \"\\n\";\n        }\n      });\n  return ret;\n}\n\nLogicalResult JobImporter::ProcessSystemOp(const ::oneflow::OperatorConf& op) {\n  if (op.has_user_conf()) {\n    GetModule().emitError(\"Not a sys op. op name: \" + op.name());\n    return failure();\n  }\n  if (op.has_variable_conf()) { return ProcessVariableOp(op); }\n\n  auto input_bns_lbns = job_wrapper_.InputBns4OpName(op.name());\n  auto input_bns = input_bns_lbns.first;\n  auto input_lbns = input_bns_lbns.second;\n  auto output_lbns = job_wrapper_.OutputLbns4OpName(op.name());\n  std::vector<NamedAttribute> attr_vec;\n  if (failed(AddOpConf(op, attr_vec))) { return failure(); }\n  if (failed(AddDeviceName(op, attr_vec))) { return failure(); }\n  attr_vec.push_back(GetBuilder().getNamedAttr(\n      \"input_bns\", GetBuilder().getStrArrayAttr(\n                       std::vector<llvm::StringRef>({input_bns.begin(), input_bns.end()}))));\n  attr_vec.push_back(GetBuilder().getNamedAttr(\n      OpTrait::IsImportCompatible<void>::getOutputLBNsAttr(),\n      GetBuilder().getStrArrayAttr(\n          std::vector<llvm::StringRef>({output_lbns.begin(), output_lbns.end()}))));\n  OperationState state(FileLineColLoc::get(GetMLIRContext(), op.name(), 0, 0),\n                       SystemOp::getOperationName());\n  attr_vec.push_back(\n      GetBuilder().getNamedAttr(\"op_type_case\", GetBuilder().getI32IntegerAttr(op.op_type_case())));\n  if (failed(AddOperandSegmentSizes(static_cast<int>(input_lbns.size()), op.ctrl_in_op_name_size(),\n                                    attr_vec))) {\n    return failure();\n  }\n  if (failed(AddResultSegmentSizes(output_lbns.size(), attr_vec))) { return failure(); }\n  state.addAttributes(attr_vec);\n  std::vector<::mlir::Value> operand_vec;\n  for (const auto& input_lbn : input_lbns) {\n    if (failed(AppendDataInOperand(input_lbn, operand_vec))) { return failure(); }\n  }\n  if (failed(AppendCtrlInOperand(op, operand_vec))) { return failure(); }\n  auto out_types = llvm::SmallVector<Type, 8>();\n  for (const auto& output_lbn : output_lbns) {\n    out_types.push_back(GetTensorTypeOfLbn(output_lbn));\n  }\n  if 
(failed(AppendCtrlOutType(out_types))) { return failure(); }\n  state.addOperands(operand_vec);\n  state.addTypes(out_types);\n  if (auto created_op = GetBuilder().create(state)) {\n    if (failed(InsertOpResults(op, created_op))) { return failure(); }\n  } else {\n    GetModule()->emitError(\"fail to create op, name: \" + op.name());\n    return failure();\n  }\n  return success();\n}\n\nLogicalResult JobImporter::ProcessVariableOp(const ::oneflow::OperatorConf& op_conf) {\n  if (!op_conf.has_variable_conf()) {\n    GetModule().emitError(\"Not a variable op. op name: \" + op_conf.name());\n    return failure();\n  }\n\n  if (op_conf.variable_conf().has_tick()) {\n    GetModule().emitError(\"variable op has tick input. op name: \" + op_conf.name());\n    return failure();\n  }\n\n  OperationState state(FileLineColLoc::get(GetMLIRContext(), op_conf.name(), 0, 0),\n                       \"oneflow.variable\");\n  // attrs\n  std::vector<NamedAttribute> attr_vec;\n  if (failed(AddOpConf(op_conf, attr_vec))) { return failure(); }\n  if (failed(AddDeviceName(op_conf, attr_vec))) { return failure(); }\n  // attr output_lbns\n  auto output_lbns_attr = GetBuilder().getStrArrayAttr({op_conf.name() + \"/out\"});\n  attr_vec.emplace_back(GetBuilder().getNamedAttr(\n      OpTrait::IsImportCompatible<void>::getOutputLBNsAttr(), output_lbns_attr));\n  // attr shape\n  auto shape_attr = GetAttrFromShape(op_conf.variable_conf().shape());\n  auto shape_named_attr =\n      GetBuilder().getNamedAttr(OpTrait::TensorSource<void>::getShapeAttrName(), shape_attr);\n  attr_vec.emplace_back(shape_named_attr);\n  // attr data_type\n  if (op_conf.variable_conf().has_data_type()) {\n    attr_vec.emplace_back(GetBuilder().getNamedAttr(\n        OpTrait::TensorSource<void>::getDataTypeAttrName(),\n        GetDataTypeAttr(GetMLIRContext(), op_conf.variable_conf().data_type()).value()));\n  }\n  // attr model_name\n  if (op_conf.variable_conf().has_model_name()) {\n    const std::string& model_name = op_conf.variable_conf().model_name();\n    attr_vec.emplace_back(\n        GetBuilder().getNamedAttr(\"model_name\", GetBuilder().getStringAttr(model_name)));\n  }\n  // attr l1 l2 regularization\n  if (op_conf.variable_conf().has_regularizer()\n      && op_conf.variable_conf().regularizer().has_l1_l2_conf()) {\n    if (op_conf.variable_conf().regularizer().l1_l2_conf().has_l1()) {\n      float l1_regularization = op_conf.variable_conf().regularizer().l1_l2_conf().l1();\n      attr_vec.emplace_back(GetBuilder().getNamedAttr(\n          \"l1_regularization\", GetBuilder().getF32FloatAttr(l1_regularization)));\n    }\n    if (op_conf.variable_conf().regularizer().l1_l2_conf().has_l2()) {\n      float l2_regularization = op_conf.variable_conf().regularizer().l1_l2_conf().l2();\n      attr_vec.emplace_back(GetBuilder().getNamedAttr(\n          \"l2_regularization\", GetBuilder().getF32FloatAttr(l2_regularization)));\n    }\n  }\n  // attr trainable\n  if (op_conf.variable_conf().has_trainable()) {\n    bool trainable = op_conf.variable_conf().trainable();\n    attr_vec.emplace_back(\n        GetBuilder().getNamedAttr(\"trainable\", GetBuilder().getBoolAttr(trainable)));\n  }\n  if (op_conf.variable_conf().has_initializer()) {\n    if (op_conf.variable_conf().initializer().has_constant_conf()) {\n      const mlir::Attribute const_initialize_attr = GetBuilder().getF32FloatAttr(\n          op_conf.variable_conf().initializer().constant_conf().value());\n      attr_vec.emplace_back(GetBuilder().getNamedAttr(\"float_initializer\", 
const_initialize_attr));\n    } else if (op_conf.variable_conf().initializer().has_constant_int_conf()) {\n      const mlir::Attribute const_initialize_attr =\n          getSI64IntegerAttr(op_conf.variable_conf().initializer().constant_int_conf().value());\n      attr_vec.emplace_back(\n          GetBuilder().getNamedAttr(\"integer_initializer\", const_initialize_attr));\n    }\n  }\n  // attr parallel\n  auto conf = this->job_wrapper_.ParallelConf4OpName(op_conf.name());\n\n  auto nd_size = conf.hierarchy().dim().size();\n  auto nd_sbp = op_conf.variable_conf().nd_sbp();\n  auto parallel = mlir::oneflow::SBPTranslation::ConvertNdSbpToPsig(\n      GetBuilder(), std::vector<std::string>(nd_sbp.begin(), nd_sbp.end()), nd_size);\n  attr_vec.emplace_back(\n      GetBuilder().getNamedAttr(OpTrait::TensorSource<void>::getSbpAttrName(), parallel));\n  // add attrs\n  state.addAttributes(attr_vec);\n  // operands\n  std::vector<::mlir::Value> operand_vec;\n  if (failed(AppendCtrlInOperand(op_conf, operand_vec))) { return failure(); }\n  state.addOperands(operand_vec);\n  // result types\n  llvm::SmallVector<Type, 8> out_types;\n  auto output_lbn = op_conf.name() + \"/out\";\n  out_types.push_back(GetTensorTypeOfLbn(output_lbn));\n  if (failed(AppendCtrlOutType(out_types))) { return failure(); }\n  state.addTypes(out_types);\n  SetOpStateLoc(op_conf, state);\n  // create op\n  auto op = GetBuilder().create(state);\n  if (!op) {\n    GetModule()->emitError(\"fail to create op, name: \" + op_conf.name());\n    return failure();\n  }\n  // record result\n  if (op->getNumResults() != 2) {\n    op->emitError(\"variable op should have two results (out and ctrl_output), but got \"\n                  + std::to_string(op->getNumResults()) + \"\\n\");\n    return failure();\n  }\n  if (!lbn2result_.emplace(output_lbn, op->getResult(0)).second) {\n    op->emitError(\"lbn already exists, lbn: \") << output_lbn;\n    return failure();\n  }\n  if (!op_name2ctrl_result_.emplace(op_conf.name(), op->getResult(1)).second) {\n    op->emitError(\"ctrl output already exists, op_name: \") << op_conf.name();\n    return failure();\n  }\n  return success();\n}\n\nLogicalResult JobImporter::ProcessInputOp(const ::oneflow::OperatorConf& op_conf,\n                                          Block* entry_block, size_t& input_count) {\n  if (!op_conf.has_input_conf()) {\n    GetModule().emitError(\"Not an input op. op name: \" + op_conf.name());\n    return failure();\n  }\n\n  if (op_conf.input_conf().has_tick()) {\n    GetModule().emitError(\"input op has tick input. 
op name: \" + op_conf.name());\n    return failure();\n  }\n\n  OperationState state(FileLineColLoc::get(GetMLIRContext(), op_conf.name(), 0, 0),\n                       \"oneflow.input\");\n  // attrs\n  std::vector<NamedAttribute> attr_vec;\n  if (failed(AddOpConf(op_conf, attr_vec))) { return failure(); }\n  if (failed(AddDeviceName(op_conf, attr_vec))) { return failure(); }\n  // attr output_lbns\n  auto output_lbns_attr = GetBuilder().getStrArrayAttr({op_conf.name() + \"/out\"});\n  attr_vec.emplace_back(GetBuilder().getNamedAttr(\n      OpTrait::IsImportCompatible<void>::getOutputLBNsAttr(), output_lbns_attr));\n  // attr shape\n  if (op_conf.input_conf().blob_conf().has_shape()) {\n    auto shape_attr = GetAttrFromShape(op_conf.input_conf().blob_conf().shape());\n    attr_vec.emplace_back(\n        GetBuilder().getNamedAttr(OpTrait::TensorSource<void>::getShapeAttrName(), shape_attr));\n  }\n  // attr data_type\n  if (op_conf.input_conf().blob_conf().has_data_type()) {\n    attr_vec.emplace_back(GetBuilder().getNamedAttr(\n        OpTrait::TensorSource<void>::getDataTypeAttrName(),\n        GetDataTypeAttr(GetMLIRContext(), op_conf.input_conf().blob_conf().data_type()).value()));\n  }\n  // attr is_dynamic\n  if (op_conf.input_conf().blob_conf().has_is_dynamic()) {\n    bool is_dynamic = op_conf.input_conf().blob_conf().is_dynamic();\n    attr_vec.emplace_back(GetBuilder().getNamedAttr(\n        OpTrait::TensorSource<void>::getIsDynamicAttrName(), GetBuilder().getBoolAttr(is_dynamic)));\n  }\n  // attr nd_sbp\n  if (op_conf.input_conf().blob_conf().has_nd_sbp()) {\n    auto nd_sbp_attr = ConvertNdSbpToAttr(GetBuilder(), op_conf.input_conf().blob_conf().nd_sbp());\n    attr_vec.emplace_back(\n        GetBuilder().getNamedAttr(OpTrait::TensorSource<void>::getNdSbpAttrName(), nd_sbp_attr));\n  }\n  // attr job_name\n  if (op_conf.input_conf().has_job_name()) {\n    const std::string& job_name = op_conf.input_conf().job_name();\n    attr_vec.emplace_back(\n        GetBuilder().getNamedAttr(\"job_name\", GetBuilder().getStringAttr(job_name)));\n  }\n  // add attrs\n  state.addAttributes(attr_vec);\n  // operands\n  std::vector<::mlir::Value> operand_vec;\n  operand_vec.emplace_back(entry_block->getArgument(input_count++));\n  if (failed(AppendCtrlInOperand(op_conf, operand_vec))) { return failure(); }\n  state.addOperands(operand_vec);\n  // result types\n  llvm::SmallVector<Type, 8> out_types;\n  auto output_lbn = op_conf.name() + \"/out\";\n  out_types.push_back(GetTensorTypeOfLbn(output_lbn));\n  if (failed(AppendCtrlOutType(out_types))) { return failure(); }\n  state.addTypes(out_types);\n  // create op\n  auto op = GetBuilder().create(state);\n  if (!op) {\n    GetModule()->emitError(\"fail to create op, name: \" + op_conf.name());\n    return failure();\n  }\n  // record result\n  if (op->getNumResults() != 2) {\n    op->emitError(\"input op should has two results (out and ctrl_output), but got \"\n                  + std::to_string(op->getNumResults()) + \"\\n\");\n    return failure();\n  }\n  if (!lbn2result_.emplace(output_lbn, op->getResult(0)).second) {\n    op->emitError(\"lbn already exists, lbn: \") << output_lbn;\n    return failure();\n  }\n  if (!op_name2ctrl_result_.emplace(op_conf.name(), op->getResult(1)).second) {\n    op->emitError(\"ctrl output already exists, op_name: \") << op_conf.name();\n    return failure();\n  }\n  return success();\n}\n\nLogicalResult JobImporter::ProcessOutputOp(const ::oneflow::OperatorConf& op_conf) {\n  if (!op_conf.has_output_conf()) 
{\n    GetModule().emitError(\"Not an output op. op name: \" + op_conf.name());\n    return failure();\n  }\n\n  OperationState state(FileLineColLoc::get(GetMLIRContext(), op_conf.name(), 0, 0),\n                       \"oneflow.output\");\n  // attrs\n  std::vector<NamedAttribute> attr_vec;\n  if (failed(AddOpConf(op_conf, attr_vec))) { return failure(); }\n  if (failed(AddDeviceName(op_conf, attr_vec))) { return failure(); }\n  // attr output_lbns\n  auto output_lbns_attr = GetBuilder().getStrArrayAttr({op_conf.name() + \"/out\"});\n  attr_vec.emplace_back(GetBuilder().getNamedAttr(\n      OpTrait::IsImportCompatible<void>::getOutputLBNsAttr(), output_lbns_attr));\n  // attr shape\n  if (op_conf.output_conf().blob_conf().has_shape()) {\n    auto shape_attr = GetAttrFromShape(op_conf.output_conf().blob_conf().shape());\n    attr_vec.emplace_back(\n        GetBuilder().getNamedAttr(OpTrait::TensorSource<void>::getShapeAttrName(), shape_attr));\n  }\n  // attr data_type\n  if (op_conf.output_conf().blob_conf().has_data_type()) {\n    attr_vec.emplace_back(GetBuilder().getNamedAttr(\n        OpTrait::TensorSource<void>::getDataTypeAttrName(),\n        GetDataTypeAttr(GetMLIRContext(), op_conf.output_conf().blob_conf().data_type()).value()));\n  }\n  // attr is_dynamic\n  if (op_conf.output_conf().blob_conf().has_is_dynamic()) {\n    bool is_dynamic = op_conf.output_conf().blob_conf().is_dynamic();\n    attr_vec.emplace_back(GetBuilder().getNamedAttr(\n        OpTrait::TensorSource<void>::getIsDynamicAttrName(), GetBuilder().getBoolAttr(is_dynamic)));\n  }\n  // attr nd_sbp\n  if (op_conf.output_conf().blob_conf().has_nd_sbp()) {\n    auto nd_sbp_attr = ConvertNdSbpToAttr(GetBuilder(), op_conf.output_conf().blob_conf().nd_sbp());\n    attr_vec.emplace_back(\n        GetBuilder().getNamedAttr(OpTrait::TensorSource<void>::getNdSbpAttrName(), nd_sbp_attr));\n  }\n  // attr job_name\n  if (op_conf.output_conf().has_job_name()) {\n    const std::string& job_name = op_conf.output_conf().job_name();\n    attr_vec.emplace_back(\n        GetBuilder().getNamedAttr(\"job_name\", GetBuilder().getStringAttr(job_name)));\n  }\n  // add attrs\n  state.addAttributes(attr_vec);\n  // operands\n  std::vector<::mlir::Value> operand_vec;\n  auto input_bns_lbns = job_wrapper_.InputBns4OpName(op_conf.name());\n  if (input_bns_lbns.second.size() != 1) {\n    GetModule()->emitError(\"output op should have exactly one input, op_name: \" + op_conf.name());\n    return failure();\n  }\n  if (failed(AppendDataInOperand(input_bns_lbns.second[0], operand_vec))) { return failure(); }\n  if (failed(AppendCtrlInOperand(op_conf, operand_vec))) { return failure(); }\n  state.addOperands(operand_vec);\n  // result types\n  llvm::SmallVector<Type, 8> out_types;\n  auto output_lbn = op_conf.name() + \"/out\";\n  out_types.push_back(GetTensorTypeOfLbn(output_lbn));\n  if (failed(AppendCtrlOutType(out_types))) { return failure(); }\n  state.addTypes(out_types);\n  // create op\n  auto op = GetBuilder().create(state);\n  if (!op) {\n    GetModule()->emitError(\"fail to create op, name: \" + op_conf.name());\n    return failure();\n  }\n  // record result\n  if (op->getNumResults() != 2) {\n    op->emitError(\"output op should have two results (out and ctrl_output), but got \"\n                  + std::to_string(op->getNumResults()) + \"\\n\");\n    return failure();\n  }\n  if (!lbn2result_.emplace(output_lbn, op->getResult(0)).second) {\n    op->emitError(\"lbn already exists, lbn: \") << output_lbn;\n    return failure();\n  }\n  
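// also record the ctrl output under the op name so later ops can reference it as a ctrl input\n  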
if (!op_name2ctrl_result_.emplace(op_conf.name(), op->getResult(1)).second) {\n    op->emitError(\"ctrl output already exists, op_name: \") << op_conf.name();\n    return failure();\n  }\n  return success();\n}\n\nLogicalResult JobImporter::ProcessJob() {\n  llvm::SmallVector<Type, 8> input_types;\n  llvm::SmallVector<Type, 4> result_types;\n  llvm::SmallVector<Value, 4> results;\n  bool is_succeeded = true;\n\n  job_wrapper_.TopoForEachOpConf([&](const ::oneflow::OperatorConf* op_conf) {\n    if (op_conf->has_input_conf()) {\n      auto type = GetInterfaceBlobConfType(op_conf->input_conf().blob_conf());\n      if (type) {\n        input_types.emplace_back(type);\n      } else {\n        GetModule()->emitError(\"fail to collect func arg types for job:\\n\"\n                               + op_conf->DebugString());\n        is_succeeded = false;\n      }\n    }\n  });\n  if (!is_succeeded) { return failure(); }\n\n  auto func_type = GetBuilder().getFunctionType(input_types, std::nullopt);\n  auto job_op =\n      GetBuilder().create<oneflow::Job>(GetRootLocation(), job_->job_conf().job_name(), func_type);\n  auto* entryBlock = job_op.addEntryBlock();\n  GetBuilder().setInsertionPointToStart(entryBlock);\n\n  is_succeeded = true;\n  size_t input_count = 0;\n  job_wrapper_.TopoForEachOpConf([&](const ::oneflow::OperatorConf* op_conf) {\n    if (is_succeeded == false) { return; }\n    if (op_conf->has_user_conf()) {\n      is_succeeded = succeeded(ProcessUserOp(*op_conf));\n    } else if (op_conf->has_input_conf()) {\n      is_succeeded = succeeded(ProcessInputOp(*op_conf, entryBlock, input_count));\n    } else if (op_conf->has_output_conf()) {\n      is_succeeded = succeeded(ProcessOutputOp(*op_conf));\n      if (is_succeeded) {\n        auto result = entryBlock->back().getResult(0);\n        results.emplace_back(result);\n        result_types.emplace_back(result.getType());\n      }\n    } else {\n      is_succeeded = succeeded(ProcessSystemOp(*op_conf));\n    }\n  });\n  if (is_succeeded == false) { return failure(); }\n  mlir::oneflow::ReturnOp return_op;\n  if (!entryBlock->empty()) { return_op = dyn_cast<mlir::oneflow::ReturnOp>(entryBlock->back()); }\n  if (!return_op) { GetBuilder().create<mlir::oneflow::ReturnOp>(GetRootLocation(), results); }\n\n  func_type = GetBuilder().getFunctionType(input_types, result_types);\n  job_op.setFunctionTypeAttr(TypeAttr::get(func_type));\n  GetModule().push_back(job_op);\n  return success();\n}\n\ntemplate<typename OpType, typename AdaptorType>\nvoid UpdatePlacement(OpType* op, AdaptorType& adaptor, ::oneflow::Job& job) {\n  auto* pg = job.mutable_placement()->add_placement_group();\n  pg->mutable_op_set()->add_op_name(adaptor.getOpName().str());\n  pg->mutable_parallel_conf()->set_device_tag(adaptor.getDeviceTag().str());\n  for (auto p : adaptor.getDeviceName()) {\n    pg->mutable_parallel_conf()->add_device_name(\n        p.template dyn_cast<StringAttr>().getValue().str());\n  }\n  if (::llvm::Optional<ArrayAttr> hierarchy = adaptor.getHierarchy()) {\n    for (auto dim : hierarchy->getValue()) {\n      pg->mutable_parallel_conf()->mutable_hierarchy()->add_dim(\n          dim.template dyn_cast<IntegerAttr>().getInt());\n    }\n  }\n}\n\nLogicalResult JobImporter::TryToUpdateJob() {\n  auto new_job = ::oneflow::Job();\n  new_job.CopyFrom(*job_);\n  new_job.clear_net();\n  new_job.mutable_placement()->clear_placement_group();\n\n  Operation* job_op = nullptr;\n  llvm::SmallVector<Value, 4> outputs;\n\n  auto find_first_job = [&](oneflow::Job job) -> 
WalkResult {\n    job_op = job.getOperation();\n    new_job.mutable_job_conf()->set_job_name(job.getSymName().str());\n    return WalkResult::interrupt();\n  };\n\n  GetModule().getOperation()->walk(find_first_job);\n  if (!job_op) {\n    GetModule()->emitError(\"job not found. module op: \") << *GetModule();\n    return failure();\n  }\n\n  auto ConvertOp = [&](Operation* op) -> WalkResult {\n    if (op->hasTrait<OpTrait::IsOpConfCompatible>()) {\n      if (llvm::dyn_cast<oneflow::UserOp>(op)) {\n        if (failed(ConvertUserOp(op, new_job))) {\n          op->emitError(\"failed to convert generic UserOp: \") << *op;\n          return WalkResult::interrupt();\n        }\n      } else if (llvm::dyn_cast<oneflow::SystemOp>(op)) {\n        if (failed(ConvertSystemOp(op, new_job))) {\n          op->emitError(\"failed to convert SystemOp: \") << *op;\n          return WalkResult::interrupt();\n        }\n      } else if (auto variable_op = llvm::dyn_cast<oneflow::VariableOp>(op)) {\n        if (failed(ConvertVariableOp(variable_op, new_job))) {\n          op->emitError(\"failed to process VariableOp: \") << *op;\n          return WalkResult::interrupt();\n        }\n      } else if (llvm::dyn_cast<oneflow::InputOp>(op) || llvm::dyn_cast<oneflow::OutputOp>(op)) {\n        // do nothing and advance\n      } else {\n        if (!dyn_cast<UserOpCompatible>(op)) {\n          op->emitError(\"op is not UserOpCompatible \") << *op;\n          return WalkResult::interrupt();\n        }\n        if (failed(ConvertUserOp(op, new_job))) {\n          op->emitError(\"failed to process UserOp: \") << *op;\n          return WalkResult::interrupt();\n        }\n      }\n    } else if (llvm::dyn_cast<mlir::oneflow::Job>(op)) {\n      // do nothing and advance\n    } else if (op->hasTrait<OpTrait::OnlyExistsInIR>()) {\n      // do nothing and advance\n    } else if (auto return_op = llvm::dyn_cast<mlir::oneflow::ReturnOp>(op)) {\n      for (auto operand : return_op->getOperands()) { outputs.emplace_back(operand); }\n    } else {\n      op->emitError(\"unexpected op: \") << *op;\n      return WalkResult::interrupt();\n    }\n    return WalkResult::advance();\n  };\n  if (job_op->walk(ConvertOp).wasInterrupted()) { return failure(); }\n\n  // add input op\n  auto arguments = llvm::dyn_cast<oneflow::Job>(job_op).getBody().front().getArguments();\n  for (BlockArgument argument : arguments) {\n    for (auto& use : argument.getUses()) {\n      Operation* owner = use.getOwner();\n      if (auto input_op = dyn_cast<oneflow::InputOp>(owner)) {\n        if (failed(ConvertInputOp(input_op, new_job))) { return failure(); }\n      } else {\n        return failure();\n      }\n    }\n  }\n  // add output op\n  for (auto output : outputs) {\n    Operation* owner = output.getDefiningOp();\n    if (auto output_op = dyn_cast<oneflow::OutputOp>(owner)) {\n      if (failed(ConvertOutputOp(output_op, new_job))) { return failure(); }\n    } else {\n      return failure();\n    }\n  }\n\n  job_wrapper_.UpdateJob(&new_job);\n  return success();\n}\n\nLogicalResult JobImporter::ConvertUserOp(Operation* op, ::oneflow::Job& job) {\n  oneflow::ConfOpAdaptor conf_op_adaptor(op->getOperands(), op->getAttrDictionary());\n  UpdatePlacement(op, conf_op_adaptor, job);\n  StringRef op_name = conf_op_adaptor.getOpName();\n\n  auto* op_conf = job.mutable_net()->add_op();\n  auto* user_conf = op_conf->mutable_user_conf();\n  if (!succeeded(ConvertUserOpInputs(op, op_name, user_conf))) {\n    op->emitError(\"fail to convert user op inputs\");\n    
return failure();\n  }\n  if (!succeeded(ConvertUserOpOutputs(op, op_name, user_conf))) {\n    op->emitError(\"fail to convert user op outputs\");\n    return failure();\n  }\n  if (!succeeded(user_op::ConvertUserOpAttributes(op, *op_conf, false))) {\n    op->emitError(\"fail to convert user op attributes\");\n    return failure();\n  }\n  if (!succeeded(ConvertCtrlInputs(op, *op_conf))) {\n    op->emitError(\"fail to convert user op control inputs\");\n    return failure();\n  }\n  return success();\n}\n\nLogicalResult JobImporter::ConvertSystemOp(Operation* op, ::oneflow::Job& job) {\n  oneflow::SystemOpAdaptor system_op_adaptor(op->getOperands(), op->getAttrDictionary());\n  UpdatePlacement(op, system_op_adaptor, job);\n  auto op_name = system_op_adaptor.getOpName().str();\n  ::oneflow::OperatorConf op_conf = job_wrapper_.OpConf4OpName(op_name);\n  for (const auto& ibn : llvm::enumerate(op->getAttrOfType<ArrayAttr>(\"input_bns\"))) {\n    auto result = GetDataInputOperands(op)[ibn.index()].dyn_cast<OpResult>();\n    std::string new_val = user_op::GetOutputLbn(result).value();\n    job_wrapper_.ReplaceInputLbnInOpCustomizedConf(\n        &op_conf, ibn.value().dyn_cast<StringAttr>().getValue().str(), new_val);\n  }\n  if (failed(ConvertCtrlInputs(op, op_conf))) { return failure(); }\n  *(job.mutable_net()->add_op()) = op_conf;\n  return success();\n}\n\nLogicalResult JobImporter::ConvertVariableOp(VariableOp op, ::oneflow::Job& job) {\n  oneflow::VariableOpAdaptor op_adaptor(op->getOperands(), op->getAttrDictionary());\n  UpdatePlacement(&op, op_adaptor, job);\n  auto* op_conf = job.mutable_net()->add_op();\n  return ConvertVariableOpConf(op, op_conf);\n}\n\nLogicalResult JobImporter::ConvertInputOp(InputOp op, ::oneflow::Job& job) {\n  oneflow::InputOpAdaptor op_adaptor(op->getOperands(), op->getAttrDictionary());\n  UpdatePlacement(&op, op_adaptor, job);\n  auto* op_conf = job.mutable_net()->add_op();\n  return ConvertInputOpConf(op, op_conf);\n}\n\nLogicalResult JobImporter::ConvertOutputOp(OutputOp op, ::oneflow::Job& job) {\n  oneflow::OutputOpAdaptor op_adaptor(op->getOperands(), op->getAttrDictionary());\n  UpdatePlacement(&op, op_adaptor, job);\n  auto* op_conf = job.mutable_net()->add_op();\n  return ConvertOutputOpConf(op, op_conf);\n}\n\nType JobImporter::GetInterfaceBlobConfType(const ::oneflow::InterfaceBlobConf& blob_conf) {\n  if (!blob_conf.has_data_type()) { return Type{}; }\n  if (!blob_conf.has_shape()) { return Type{}; }\n  if (auto data_type = getTypeFromOneFlowDataType(GetMLIRContext(), blob_conf.data_type())) {\n    return RankedTensorType::get({blob_conf.shape().dim().begin(), blob_conf.shape().dim().end()},\n                                 data_type);\n  } else {\n    return Type{};\n  }\n}\n\nvoid DumpMLIR(RoundTripOneFlowJobWrapperInterface& job_wrapper, ModuleOp module,\n              const std::string& name) {\n  std::string mlir;\n  llvm::raw_string_ostream os_mlir(mlir);\n  module->print(os_mlir);\n  job_wrapper.DumpLog(name + \".mlir\", mlir);\n}\n\nLogicalResult ApplyRoundTripPatterns(RoundTripOneFlowJobWrapperInterface& job_wrapper,\n                                     MLIRContext* context, OwningOpRef<ModuleOp>& module) {\n  mlir::PassManager pm(context);\n  if (::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_ENABLE_TIMING\", false)) { pm.enableTiming(); }\n  mlir::oneflow::CheckEnableIRPrinting(pm);\n  // this canonicalizer should create concrete ops and create fuse opportunities\n  pm.addPass(createCanonicalizerPass());\n  // we must do auto nhwc and 
eliminate redundant transpose ops first, to avoid inserting redundant\n  // transpose ops due to fuse patterns like normalization_add_relu.\n  pm.addPass(oneflow::createAutoNhwcPass());\n  if (::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_CSE\", false)) {\n    auto cse_state = std::make_shared<CSEState>();\n    auto passes = createCSEPasses(cse_state);\n    pm.addPass(std::move(passes.first));\n    pm.addPass(createCSEPass());\n    pm.addPass(std::move(passes.second));\n  }\n  if (job_wrapper.IsLastIRPass()\n      && ::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_FUSE_FORWARD_OPS\", false)) {\n    pm.addPass(oneflow::createFuseForwardOps());\n    pm.addPass(oneflow::createFuseIntoExistingOpPass());\n  }\n  if (job_wrapper.IsLastIRPass()\n      && ::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_ENABLE_CODEGEN_FUSERS\", false)) {\n    pm.addPass(oneflow::createOneFlowJobToFuncPass());\n    pm.addPass(oneflow::createCastOneFlowOpsToSignlessPass());\n    auto toTosa = oneflow::createLowerOneFlowToTosaPass();\n    CHECK(toTosa->initializeOptions(\"full=0 lower-job=0\").succeeded());\n    pm.addPass(std::move(toTosa));\n    pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());\n    pm.addPass(oneflow::createLowerOneFlowToLinalgPass());\n    pm.addPass(tosa::createTosaToTensor());\n    pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());\n    pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());\n    pm.addPass(createLinalgElementwiseOpFusionPass());\n    pm.addPass(oneflow::createFuncToOneFlowJobPass());\n    pm.addNestedPass<Job>(oneflow::createOutlineJitFunctionPass());\n    pm.addPass(createCanonicalizerPass());\n  }\n  if (!job_wrapper.IsLastIRPass()\n      && ::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL\", false)) {\n    pm.addPass(oneflow::createFuseOpsWithBackwardImpl());\n  }\n  // TODO: support backward or put it in an env flag\n  if (job_wrapper.IsLastIRPass()\n      && ::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_GROUP_MATMUL\", false)) {\n    pm.addPass(oneflow::createGroupMatMul());\n  }\n  if (!job_wrapper.IsLastIRPass()\n      && ::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION\", false)) {\n    pm.addPass(oneflow::createPreConvertInferenceOpPass());\n    pm.addPass(oneflow::createConvertInferenceOpPass());\n    pm.addPass(oneflow::createPostConvertInferenceOpPass());\n  }\n  if (!job_wrapper.IsLastIRPass()\n      && ::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_FUSE_NORMALIZATION_OPS\", false)) {\n    pm.addPass(oneflow::createFuseNormalizationOps());\n  }\n  if (job_wrapper.IsLastIRPass()\n      && ::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_FUSE_KERNEL_LAUNCH\", false)) {\n    pm.addPass(createAggregateComputeOpsPass());\n\n    auto wrap_pass = createWrapOpsToKernelLaunchPass();\n    std::string options =\n        \"mode=\"\n        + (::oneflow::ParseBooleanFromEnv(\"ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH\", false)\n               ? 
wrap_mode::CUDA_GRAPH\n               : wrap_mode::SIMPLE);\n\n    (void)wrap_pass->initializeOptions(options);\n    pm.addPass(std::move(wrap_pass));\n  }\n  pm.addPass(createCanonicalizerPass());\n  if (::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_PRINT_STATS\", false)) {\n    pm.addPass(createPrintOpStatsPass());\n  }\n  std::string graphviz;\n  llvm::raw_string_ostream os_graphviz(graphviz);\n  const bool shouldPrintGraphviz =\n      ::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_PRINT_OP_GRAPH\", false);\n  if (shouldPrintGraphviz) { pm.addPass(createPrintOpGraphPass(os_graphviz)); }\n  if (mlir::failed(pm.run(*module))) {\n    module->emitError(\"Failed to run round-trip passes\");\n    return failure();\n  }\n  if (shouldPrintGraphviz) {\n    job_wrapper.DumpLog(\"RoundTripOneFlowJob.optimized.mlir.dot\", graphviz);\n  }\n  if (::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_DUMPMLIR\", false)) {\n    DumpMLIR(job_wrapper, module.get(), \"RoundTripOneFlowJob.optimized\");\n  }\n  return success();\n}\n\nOwningOpRef<ModuleOp> TranslateOneFlowJobToModule(llvm::StringRef str, MLIRContext* context) {\n  std::string cpp_str = str.str();\n  ::oneflow::Job job;\n  google::protobuf::TextFormat::ParseFromString(cpp_str, &job);\n  context->loadDialect<oneflow::OneFlowDialect>();\n  context->loadDialect<mlir::func::FuncDialect>();\n  OwningOpRef<ModuleOp> module(\n      ModuleOp::create(FileLineColLoc::get(context, \"\", /*line=*/0, /*column=*/0)));\n  return module;\n}\n\nvoid RoundTripOneFlowJob(\n    RoundTripOneFlowJobWrapperInterface& job_wrapper,\n    const std::function<bool(::oneflow::Job* job, std::string& reason)>& is_legit_job) {\n  const ::oneflow::Job* job = job_wrapper.job();\n  mlir::MLIRContext context;\n  context.getOrLoadDialect<oneflow::OneFlowDialect>();\n  context.loadDialect<mlir::func::FuncDialect>();\n\n  OwningOpRef<ModuleOp> module(\n      ModuleOp::create(FileLineColLoc::get(&context, \"\", /*line=*/0, /*column=*/0)));\n  JobImporter imp(job_wrapper, &context, module.get());\n  // TODO: Add flag in job desc to decide whether to run mlir optimizer\n  if (succeeded(imp.ProcessJob())) {\n    if (::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_DUMPMLIR\", false)) {\n      DumpMLIR(job_wrapper, module.get(), \"RoundTripOneFlowJob.imported\");\n    }\n    if (failed(ApplyRoundTripPatterns(job_wrapper, &context, module))) { exit(EXIT_FAILURE); }\n    if (::oneflow::ParseBooleanFromEnv(\"ONEFLOW_MLIR_STDOUT\", false)\n        && job_wrapper.IsLastIRPass()) {\n      // for FileCheck\n      module->print(llvm::outs());\n    }\n    // TODO: Add flag in oneflow to define if failure in MLIR is allowed\n    if (failed(imp.TryToUpdateJob())) {\n      llvm::errs() << \"fail to update job with IR, job will stay intact, job_name: \"\n                   << job->job_conf().job_name() << \"\\n\";\n      exit(EXIT_FAILURE);\n    }\n  } else {\n    llvm::errs() << \"fail to convert job to IR, job_name: \" << job->job_conf().job_name() << \"\\n\";\n    exit(EXIT_FAILURE);\n  }\n}\n\nstd::string ConvertJobToTosaIR(RoundTripOneFlowJobWrapperInterface& job_wrapper) {\n  const ::oneflow::Job* job = job_wrapper.job();\n  mlir::MLIRContext context;\n  context.getOrLoadDialect<oneflow::OneFlowDialect>();\n  context.loadDialect<mlir::func::FuncDialect>();\n\n  OwningOpRef<ModuleOp> module(\n      ModuleOp::create(FileLineColLoc::get(&context, \"\", /*line=*/0, /*column=*/0)));\n  JobImporter imp(job_wrapper, &context, module.get());\n  if (succeeded(imp.ProcessJob())) {\n    mlir::PassManager 
pm(&context);\n    pm.addPass(createCanonicalizerPass());\n    pm.addPass(createConvertToSignlessForTosaPass());\n    pm.addPass(createLowerOneFlowToTosaPass());\n    pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());\n    if (mlir::failed(pm.run(*module))) {\n      module->emitError(\"Failed to run oneflow-to-tosa pass\");\n      exit(EXIT_FAILURE);\n    }\n\n    std::string mlir;\n    llvm::raw_string_ostream os_mlir(mlir);\n    module->print(os_mlir);\n    return mlir;\n  } else {\n    const auto& job_name = job->job_conf().job_name();\n    llvm::errs() << \"fail to convert job to IR, job_name: \" << job_name << \"\\n\";\n    exit(EXIT_FAILURE);\n  }\n}\n\nstd::string ConvertJobToIR(RoundTripOneFlowJobWrapperInterface& job_wrapper) {\n  const ::oneflow::Job* job = job_wrapper.job();\n  mlir::MLIRContext context;\n  context.getOrLoadDialect<oneflow::OneFlowDialect>();\n  context.loadDialect<mlir::func::FuncDialect>();\n\n  OwningOpRef<ModuleOp> module(\n      ModuleOp::create(FileLineColLoc::get(&context, \"\", /*line=*/0, /*column=*/0)));\n  JobImporter imp(job_wrapper, &context, module.get());\n  if (succeeded(imp.ProcessJob())) {\n    mlir::PassManager pm(&context);\n    pm.addPass(createCanonicalizerPass());\n    if (mlir::failed(pm.run(*module))) {\n      module->emitError(\"Failed to run canonicalizer pass\");\n      exit(EXIT_FAILURE);\n    }\n\n    std::string mlir;\n    llvm::raw_string_ostream os_mlir(mlir);\n    module->print(os_mlir);\n    return mlir;\n  } else {\n    const auto& job_name = job->job_conf().job_name();\n    llvm::errs() << \"fail to convert job to IR, job_name: \" << job_name << \"\\n\";\n    exit(EXIT_FAILURE);\n  }\n}\n\nvoid SaveJobToIR(RoundTripOneFlowJobWrapperInterface& job_wrapper, const std::string& path) {\n  const ::oneflow::Job* job = job_wrapper.job();\n  mlir::MLIRContext context;\n  context.getOrLoadDialect<oneflow::OneFlowDialect>();\n  context.loadDialect<mlir::func::FuncDialect>();\n\n  OwningOpRef<ModuleOp> module(\n      ModuleOp::create(FileLineColLoc::get(&context, \"\", /*line=*/0, /*column=*/0)));\n  JobImporter imp(job_wrapper, &context, module.get());\n  if (succeeded(imp.ProcessJob())) {\n    mlir::PassManager pm(&context);\n    pm.addPass(createCanonicalizerPass());\n    if (mlir::failed(pm.run(*module))) {\n      module->emitError(\"Failed to run canonicalizer pass\");\n      exit(EXIT_FAILURE);\n    }\n\n    std::string mlir;\n    llvm::raw_string_ostream os_mlir(mlir);\n    module->print(os_mlir);\n    std::string filename = path + \"/model.mlir\";\n    std::ofstream fs(filename, std::ios::trunc);\n    if (!fs.is_open()) {\n      llvm::errs() << \"fail to open file \" << filename << \"\\n\";\n      exit(EXIT_FAILURE);\n    }\n    fs << mlir;\n    fs.close();\n  } else {\n    const auto& job_name = job->job_conf().job_name();\n    llvm::errs() << \"fail to convert job to IR, job_name: \" << job_name << \"\\n\";\n    exit(EXIT_FAILURE);\n  }\n}\n\nvoid LoadJobFromIR(RoundTripOneFlowJobWrapperInterface& job_wrapper, const std::string& path) {\n  MLIRContext context;\n  context.getOrLoadDialect<oneflow::OneFlowDialect>();\n  context.loadDialect<mlir::func::FuncDialect>();\n  OwningOpRef<ModuleOp> module = parseSourceFile<ModuleOp>(path, &context);\n  if (!module) {\n    llvm::errs() << \"fail to parse file: \" << path << \"\\n\";\n    exit(EXIT_FAILURE);\n  }\n  JobImporter imp(job_wrapper, &context, module.get());\n  if (failed(imp.TryToUpdateJob())) {\n    llvm::errs() << \"fail to load job from IR\\n\";\n    
exit(EXIT_FAILURE);\n  }\n}\n\nvoid registerFromOneFlowJobTranslation() {\n  TranslateToMLIRRegistration fromOneFlowJob(\"import-oneflow-job\", \"import oneflow from job\",\n                                             [](llvm::StringRef str, MLIRContext* context) {\n                                               return TranslateOneFlowJobToModule(str, context);\n                                             });\n}\n\n}  // namespace oneflow\n\n}  // namespace mlir\n"
  },
  {
    "path": "oneflow/ir/oneflow-translate/oneflow-translate.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"OneFlow/MLIROneFlowTranslation.h\"\n#include \"mlir/InitAllTranslations.h\"\n#include \"mlir/Tools/mlir-translate/MlirTranslateMain.h\"\n#include \"mlir/Tools/mlir-translate/Translation.h\"\n#include \"mlir/Support/LogicalResult.h\"\n\nint32_t main(int32_t argc, char** argv) {\n  mlir::registerAllTranslations();\n  mlir::oneflow::registerFromOneFlowJobTranslation();\n\n  return failed(mlir::mlirTranslateMain(argc, argv, \"MLIR Translation Testing Tool\"));\n}\n"
  },
  {
    "path": "oneflow/ir/test/CMakeLists.txt",
    "content": "llvm_canonicalize_cmake_booleans(WITH_MLIR_CUDA_CODEGEN BUILD_CUDA)\n\nmessage(STATUS \"LLVM_TOOLS_BINARY_DIR (used as LLVM_TOOLS_DIR): ${LLVM_TOOLS_BINARY_DIR}\")\nmessage(STATUS \"LLVM_EXTERNAL_LIT: ${LLVM_EXTERNAL_LIT}\")\nconfigure_lit_site_cfg(\n  ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py\n  MAIN_CONFIG ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py)\n\nset(ONEFLOW_TEST_DEPENDS FileCheck count not oneflow-opt oneflow-translate)\n\nadd_lit_testsuite(\n  check-oneflow \"Running the OneFlow MLIR regression tests from: ${CMAKE_CURRENT_SOURCE_DIR}\"\n  ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${ONEFLOW_TEST_DEPENDS})\nset_target_properties(check-oneflow PROPERTIES FOLDER \"Tests\")\nif(LLVM_PROVIDER STREQUAL \"in-tree\")\n  add_dependencies(check-oneflow mlir-cpu-runner)\nendif()\nadd_dependencies(check-oneflow oneflow_internal)\nadd_dependencies(check-oneflow oneflow-runner)\nadd_lit_testsuites(ONEFLOW ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${ONEFLOW_TEST_DEPENDS})\n\nadd_custom_target(c1 DEPENDS check-oneflow)\n"
  },
  {
    "path": "oneflow/ir/test/Frontend/lit.local.cfg",
    "content": "if not config.WITH_ONEFLOW_IREE:\n  config.unsupported = True\n"
  },
  {
    "path": "oneflow/ir/test/Frontend/oneflow_to_iree.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -split-input-file \\\n// RUN: -auto-nhwc \\\n// RUN: -lower-oneflow-to-tosa \\\n// RUN: -tosa-make-broadcastable \\\n// RUN: -verify-diagnostics -o - | FileCheck %s\n\n\n// CHECK-NOT: oneflow\noneflow.job @test_func(%arg0: tensor<1xf32>) -> tensor<1xf32>\n{\n    oneflow.return %arg0 : tensor<1xf32>\n}\n\n\noneflow.job @test_input(%arg0: tensor<1xf32>) -> tensor<1xf32>\n{\n    %res = \"oneflow.input\"(%arg0)\n    {\n        data_type = 2 : i32,\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        hierarchy = [1],\n        is_dynamic = false,\n        nd_sbp = [\"B\"],\n        op_name = \"\",\n        output_lbns = [\"\"],\n        scope_symbol_id = 4611686018427412479 : i64,\n        shape = [1 : si64]\n    } : (tensor<1xf32>) -> tensor<1xf32>\n    oneflow.return %res : tensor<1xf32>\n}\n\n\noneflow.job @test_output(%arg0: tensor<1xf32>) -> tensor<1xf32>\n{\n    %res = \"oneflow.output\"(%arg0)\n    {\n        data_type = 2 : i32,\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        hierarchy = [1],\n        is_dynamic = false,\n        nd_sbp = [\"B\"],\n        op_name = \"\",\n        output_lbns = [\"\"],\n        scope_symbol_id = 4611686018427412479 : i64,\n        shape = [1 : si64]\n    } : (tensor<1xf32>) -> tensor<1xf32>\n    oneflow.return %res : tensor<1xf32>\n}\n\n\noneflow.job @test_variable() -> tensor<64x3x7x7xf32>\n{\n    %res = \"oneflow.variable\"() {\n        data_type = 2 : i32,\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        hierarchy = [1],\n        parallel = #sbp.parallel<[] -> [#sbp.B]>,\n        op_name = \"fw.model.conv1.weight\",\n        output_lbns = [\"fw.model.conv1.weight/out\"],\n        scope_symbol_id = 4611686018427432959 : i64,\n        shape = [64 : si64, 3 : si64, 7 : si64, 7 : si64]\n    } : () -> tensor<64x3x7x7xf32>\n    oneflow.return %res : tensor<64x3x7x7xf32>\n}\n\n\noneflow.job @test_add_n2(%arg0: tensor<1x7x7xf32>, %arg1: tensor<1x7x7xf32>) -> tensor<1x7x7xf32>\n{\n    %res = \"oneflow.add_n2\"(%arg0, %arg1)\n    {\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        hierarchy = [1],\n        op_name = \"\",\n        op_type_name = \"add_n\",\n        output_lbns = [\"\"],\n        scope_symbol_id = 4611686018431205375 : i64\n    } : (tensor<1x7x7xf32>, tensor<1x7x7xf32>) -> tensor<1x7x7xf32>\n    oneflow.return %res: tensor<1x7x7xf32>\n}\n\n\noneflow.job @test_broadcast_add(%arg0: tensor<1x1000xf32>, %arg1: tensor<1000xf32>) -> tensor<1x1000xf32>\n{\n    %res = \"oneflow.broadcast_add\"(%arg0, %arg1)\n    {\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        hierarchy = [1],\n        op_name = \"\",\n        output_lbns = [\"\"],\n        scope_symbol_id = 4611686018431234047 : i64\n    } : (tensor<1x1000xf32>, tensor<1000xf32>) -> tensor<1x1000xf32>\n    oneflow.return %res : tensor<1x1000xf32>\n}\n\n\noneflow.job @test_max_pool_2d(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32>\n{\n    %y, %indice = \"oneflow.max_pool_2d\"(%arg0)\n    {\n        ceil_mode = false,\n        data_format = \"channels_first\",\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        dilation = [1 : si32, 1 : si32],\n        hierarchy = [1], kernel_size = [3 : si32, 3 : si32],\n        op_name = \"\",\n        output_lbns = [\"\", \"\"],\n        padding = [1 : si32, 1 : si32],\n        return_indices = false,\n        scope_symbol_id = 4611686018427502591 : i64,\n   
     stride = [2 : si32, 2 : si32]\n    } : (tensor<1x64x112x112xf32>) -> (tensor<1x64x56x56xf32>, tensor<1x64x56x56xi64>)\n    oneflow.return %y :  tensor<1x64x56x56xf32>\n}\n\n\noneflow.job @test_avg_pool_2d(%arg0: tensor<1x2048x7x7xf32>) -> tensor<1x2048x1x1xf32>\n{\n    %res = \"oneflow.avg_pool_2d\"(%arg0)\n    {\n        ceil_mode = false,\n        count_include_pad = true,\n        data_format = \"channels_first\",\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        divisor_override = 0 : si32,\n        hierarchy = [1],\n        kernel_size = [7 : si32, 7 : si32],\n        op_name = \"model.avgpool-avg_pool_2d-172\",\n        output_lbns = [\"model.avgpool-avg_pool_2d-172/y_0\"],\n        padding = [0 : si32, 0 : si32],\n        scope_symbol_id = 4611686018430775295 : i64,\n        stride = [7 : si32, 7 : si32]\n    } : (tensor<1x2048x7x7xf32>) -> tensor<1x2048x1x1xf32>\n    oneflow.return %res: tensor<1x2048x1x1xf32>\n}\n\n\noneflow.job @test_conv2d(%arg0: tensor<1x3x224x224xf32>, %arg1: tensor<5x3x1x1xf32>) -> tensor<1x5x224x224xf32>\n{\n    %res = \"oneflow.conv2d\"(%arg0, %arg1)\n    {\n        data_format = \"channels_first\",\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        dilation_rate = [1 : si32, 1 : si32],\n        filters = 512 : si32,\n        groups = 1 : si32,\n        hierarchy = [1],\n        kernel_size = [1 : si32, 1 : si32],\n        op_name = \"\",\n        operand_segment_sizes = array<i32: 1, 1, 0, 0>,\n        output_lbns = [\"\"],\n        padding_before = [0 : si32, 0 : si32],\n        scope_symbol_id = 4611686018431012863 : i64,\n        strides = [1 : si32, 1 : si32]\n    } : (tensor<1x3x224x224xf32>, tensor<5x3x1x1xf32>) -> tensor<1x5x224x224xf32>\n    oneflow.return %res : tensor<1x5x224x224xf32>\n}\n\n\noneflow.job @test_matmul(%arg0: tensor<1x2048xf32>, %arg1: tensor<1000x2048xf32>) ->tensor<1x1000xf32>\n{\n    %res = \"oneflow.matmul\"(%arg0, %arg1)\n    {\n        alpha = 1.000000e+00 : f64,\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        hierarchy = [1],\n        op_name = \"\",\n        output_lbns = [\"\"],\n        scope_symbol_id = 4611686018431234047 : i64,\n        transpose_a = false,\n        transpose_b = true\n    } : (tensor<1x2048xf32>, tensor<1000x2048xf32>) -> tensor<1x1000xf32>\n    oneflow.return %res : tensor<1x1000xf32>\n}\n\n\noneflow.job @test_relu(%arg0: tensor<1xf32>) -> tensor<1xf32> {\n    %res = \"oneflow.relu\"(%arg0)\n    {\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        hierarchy = [1],\n        op_name = \"\",\n        output_lbns = [\"\"],\n        scope_symbol_id = 4611686018427424767 : i64\n    } : (tensor<1xf32>) -> tensor<1xf32>\n    oneflow.return %res : tensor<1xf32>\n}\n\noneflow.job @test_bn(\n%x:               tensor<1x64x112x112xf32>,\n%moving_mean:     tensor<64xf32>,\n%moving_variance: tensor<64xf32>,\n%gamma:           tensor<64xf32>,\n%beta:            tensor<64xf32>) -> tensor<1x64x112x112xf32>\n{\n    %y, %mean, %inv_variance = \"oneflow.normalization\"(%x, %moving_mean, %moving_variance, %gamma, %beta)\n    {\n        axis = 1 : si32,\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        epsilon = 9.99999974E-6 : f32,\n        hierarchy = [1],\n        momentum = 0.899999976 : f32,\n        op_name = \"\",\n        operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>,\n        output_lbns = [\"\", \"\", \"\"],\n        result_segment_sizes = array<i32: 1, 1, 1>,\n        
scope_symbol_id = 4611686018427453439 : i64,\n        training = true\n    } : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>)\n    oneflow.return %y: tensor<1x64x112x112xf32>\n}\n\noneflow.job @test_bn_infer(\n%x:               tensor<1x64x112x112xf32>,\n%moving_mean:     tensor<64xf32>,\n%moving_variance: tensor<64xf32>,\n%gamma:           tensor<64xf32>,\n%beta:            tensor<64xf32>) -> tensor<1x64x112x112xf32>\n{\n    %y = \"oneflow.normalization_infer\"(%x, %moving_mean, %moving_variance, %gamma, %beta)\n    {\n        axis = 1 : si32,\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        epsilon = 9.99999974E-6 : f32,\n        hierarchy = [1],\n        momentum = 0.899999976 : f32,\n        op_name = \"\",\n        operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>,\n        output_lbns = [\"\", \"\", \"\"],\n        result_segment_sizes = array<i32: 1, 1, 1>,\n        scope_symbol_id = 4611686018427453439 : i64,\n        training = true\n    } : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>\n    oneflow.return %y: tensor<1x64x112x112xf32>\n}\n"
  },
  {
    "path": "oneflow/ir/test/Frontend/tosa_to_elf.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -pass-pipeline=\"builtin.module(func.func(tosa-to-linalg))\" \\\n// RUN: | oneflow-opt -cse \\\n// RUN: --linalg-fuse-elementwise-ops -empty-tensor-to-alloc-tensor -linalg-bufferize \\\n// RUN: -tensor-bufferize -func-bufferize -buffer-results-to-out-params \\\n// RUN: -convert-linalg-to-loops -convert-math-to-libm -convert-math-to-llvm -convert-scf-to-cf -convert-linalg-to-llvm \\\n// RUN: -convert-func-to-llvm -finalize-memref-to-llvm -reconcile-unrealized-casts --print-after-all \\\n// RUN: | oneflow-translate -mlir-to-llvmir\n\nbuiltin.module {\n  func.func @Graph_0(%arg0: tensor<2xf32>) -> tensor<2xf32> {\n    %0 = \"tosa.cast\"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>\n    %1 = \"tosa.tanh\"(%0) : (tensor<2xf32>) -> tensor<2xf32>\n    %2 = \"tosa.cast\"(%1) : (tensor<2xf32>) -> tensor<2xf32>\n    func.return %2 : tensor<2xf32>\n  }\n}\n"
  },
  {
    "path": "oneflow/ir/test/GPU/lit.local.cfg",
    "content": "if not config.WITH_MLIR_CUDA_CODEGEN:\n  config.unsupported = True\n"
  },
  {
    "path": "oneflow/ir/test/GPU/nvvm_to_cubin.mlir",
    "content": "// RUN: oneflow-opt %s -pass-pipeline=\"builtin.module(gpu.module(nvvm-to-cubin))\" | FileCheck %s\n\n// CHECK: .text.__nv_logf\n// CHECK-SAME: .text.__nv_expf\nmodule attributes {gpu.container_module, oneflow.mempool = 1 : i64} {\n  func.func @JITOpGenerated0(%arg0: memref<1xi8>, %arg1: memref<5xi64>, %arg2: memref<1xf32>, %arg3: memref<5xf32>) attributes {llvm.emit_c_interface} {\n    return\n  }\n  gpu.module @JITOpGenerated0_kernel {\n    llvm.func @__nv_logf(f32) -> f32\n    llvm.func @__nv_expf(f32) -> f32\n    llvm.func @JITOpGenerated0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr, %arg6: !llvm.ptr, %arg7: i64, %arg8: i64, %arg9: i64, %arg10: i64, %arg11: i64, %arg12: !llvm.ptr, %arg13: !llvm.ptr, %arg14: i64, %arg15: i64, %arg16: i64, %arg17: !llvm.ptr, %arg18: !llvm.ptr, %arg19: i64, %arg20: i64, %arg21: i64, %arg22: i64, %arg23: i64, %arg24: !llvm.ptr, %arg25: !llvm.ptr, %arg26: i64, %arg27: i64, %arg28: i64, %arg29: i64, %arg30: i64, %arg31: !llvm.ptr, %arg32: !llvm.ptr, %arg33: i64, %arg34: i64, %arg35: i64, %arg36: i64, %arg37: i64, %arg38: !llvm.ptr, %arg39: !llvm.ptr, %arg40: i64, %arg41: i64, %arg42: i64, %arg43: i64, %arg44: i64, %arg45: !llvm.ptr, %arg46: !llvm.ptr, %arg47: i64, %arg48: i64, %arg49: i64, %arg50: i64, %arg51: i64) attributes {gpu.kernel, nvvm.kernel} {\n      %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>\n      %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> \n      %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> \n      %3 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>\n      %4 = llvm.insertvalue %arg5, %3[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %5 = llvm.insertvalue %arg6, %4[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %6 = llvm.insertvalue %arg7, %5[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %7 = llvm.insertvalue %arg8, %6[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %8 = llvm.insertvalue %arg12, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> \n      %9 = llvm.insertvalue %arg13, %8[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> \n      %10 = llvm.insertvalue %arg17, %3[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %11 = llvm.insertvalue %arg18, %10[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %12 = llvm.insertvalue %arg19, %11[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %13 = llvm.insertvalue %arg20, %12[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %14 = llvm.insertvalue %arg24, %3[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %15 = llvm.insertvalue %arg25, %14[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %16 = llvm.insertvalue %arg26, %15[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %17 = llvm.insertvalue %arg27, %16[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %18 = llvm.insertvalue %arg31, %3[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %19 = llvm.insertvalue %arg32, %18[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %20 = llvm.insertvalue %arg33, 
%19[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %21 = llvm.insertvalue %arg34, %20[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %22 = llvm.insertvalue %arg38, %3[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %23 = llvm.insertvalue %arg39, %22[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %24 = llvm.insertvalue %arg40, %23[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %25 = llvm.insertvalue %arg41, %24[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %26 = llvm.insertvalue %arg45, %3[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %27 = llvm.insertvalue %arg46, %26[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %28 = llvm.insertvalue %arg47, %27[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %29 = llvm.insertvalue %arg48, %28[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> \n      %30 = llvm.mlir.constant(0 : index) : i64\n      %31 = llvm.mlir.constant(4000 : index) : i64\n      %32 = llvm.mlir.constant(1000 : index) : i64\n      %33 = llvm.mlir.constant(-1 : index) : i64\n      %34 = nvvm.read.ptx.sreg.ctaid.x : i32\n      %35 = llvm.sext %34 : i32 to i64\n      %36 = nvvm.read.ptx.sreg.ntid.x : i32\n      %37 = llvm.sext %36 : i32 to i64\n      %38 = nvvm.read.ptx.sreg.tid.x : i32\n      %39 = llvm.sext %38 : i32 to i64\n      %40 = llvm.mul %37, %35  : i64\n      %41 = llvm.add %39, %40  : i64\n      %42 = llvm.icmp \"slt\" %41, %31 : i64\n      llvm.cond_br %42, ^bb1, ^bb2\n    ^bb1:  // pred: ^bb0\n      %43 = llvm.srem %41, %32  : i64\n      %44 = llvm.icmp \"slt\" %43, %30 : i64\n      %45 = llvm.add %43, %32  : i64\n      %46 = llvm.select %44, %45, %43 : i1, i64\n      %47 = llvm.icmp \"slt\" %41, %30 : i64\n      %48 = llvm.sub %33, %41  : i64\n      %49 = llvm.select %47, %48, %41 : i1, i64\n      %50 = llvm.sdiv %49, %32  : i64\n      %51 = llvm.sub %33, %50  : i64\n      %52 = llvm.select %47, %51, %50 : i1, i64\n      %53 = llvm.mul %52, %32  : i64\n      %54 = llvm.add %53, %46  : i64\n      %55 = llvm.getelementptr %arg18[%54] : (!llvm.ptr, i64) -> !llvm.ptr, f16\n      %56 = llvm.load %55 : !llvm.ptr -> f16\n      %57 = llvm.getelementptr %arg6[%54] : (!llvm.ptr, i64) -> !llvm.ptr, f16\n      %58 = llvm.load %57 : !llvm.ptr -> f16\n      %59 = llvm.getelementptr %arg1[%52] : (!llvm.ptr, i64) -> !llvm.ptr, f16\n      %60 = llvm.load %59 : !llvm.ptr -> f16\n      %61 = llvm.getelementptr %arg13[%52] : (!llvm.ptr, i64) -> !llvm.ptr, f16\n      %62 = llvm.load %61 : !llvm.ptr -> f16\n      %63 = llvm.getelementptr %arg25[%54] : (!llvm.ptr, i64) -> !llvm.ptr, f32\n      %64 = llvm.load %63 : !llvm.ptr -> f32\n      %65 = llvm.fpext %60 : f16 to f32\n      %66 = llvm.call @__nv_logf(%65) : (f32) -> f32\n      %67 = llvm.fptrunc %66 : f32 to f16\n      %68 = llvm.fsub %58, %67  : f16\n      %69 = llvm.fpext %68 : f16 to f32\n      %70 = llvm.call @__nv_expf(%69) : (f32) -> f32\n      %71 = llvm.fptrunc %70 : f32 to f16\n      %72 = llvm.fmul %71, %62  : f16\n      %73 = llvm.fsub %56, %72  : f16\n      %74 = llvm.fmul %69, %64  : f32\n      %75 = llvm.fpext %73 : f16 to f32\n      %76 = llvm.getelementptr %arg32[%54] : (!llvm.ptr, i64) -> !llvm.ptr, f16\n      llvm.store %73, %76 : f16, !llvm.ptr\n      %77 = llvm.getelementptr %arg39[%54] : (!llvm.ptr, i64) -> !llvm.ptr, f32\n      llvm.store %74, 
%77 : f32, !llvm.ptr\n      %78 = llvm.getelementptr %arg46[%54] : (!llvm.ptr, i64) -> !llvm.ptr, f32\n      llvm.store %75, %78 : f32, !llvm.ptr\n      llvm.br ^bb2\n    ^bb2:  // 2 preds: ^bb0, ^bb1\n      llvm.return\n    }\n  }\n}"
  },
  {
    "path": "oneflow/ir/test/OneFlow/auto_nhwc/lit.local.cfg",
    "content": "if not config.BUILD_CUDA:\n  config.unsupported = True\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_batchnorm_relu.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK: oneflow.transpose\n\nimport unittest\nimport numpy as np\n\nimport os\n\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef do_nhwc_bacth_norm(test_case, with_cuda):\n    x = flow.randn(2, 3, 4, 5)\n    bn = flow.nn.BatchNorm2d(3)\n    if with_cuda:\n        x = x.cuda()\n        bn.to(\"cuda\")\n\n    eager_batch_norm_res = flow.relu(bn(x))\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = bn\n\n        def build(self, x):\n            return flow.relu(self.m(x))\n\n    graph_to_run = GraphToRun()\n    lazy_batch_norm_res = graph_to_run(x)\n    test_case.assertTrue(\n        np.allclose(\n            eager_batch_norm_res.numpy(),\n            lazy_batch_norm_res.numpy(),\n            rtol=1e-5,\n            atol=1e-5,\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNhwcConv(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_PREFER_NHWC\"] = \"1\"\n\n    def test_nhwc_conv_graph(test_case):\n        import oneflow.sysconfig\n\n        if oneflow.sysconfig.with_cuda():\n            do_nhwc_bacth_norm(test_case, True)\n        # do_nhwc_bacth_norm(test_case, False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_bias_add.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK: oneflow.transpose\n\nimport unittest\nimport numpy as np\nimport os\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef do_nhwc_bias_add(test_case, with_cuda):\n    a = flow.randn(2, 3, 4, 5)\n    b = flow.randn(3)\n    if with_cuda:\n        a = a.cuda()\n        b = b.cuda()\n\n    eager_bias_add_res = flow._C.bias_add(a, b, axis=1)\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, a, b):\n            return flow._C.bias_add(a, b, axis=1)\n\n    graph_to_run = GraphToRun()\n    lazy_bias_add_res = graph_to_run(a, b)\n    test_case.assertTrue(\n        np.allclose(\n            eager_bias_add_res.numpy(), lazy_bias_add_res.numpy(), rtol=1e-5, atol=1e-5\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNhwcBiasAdd(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_PREFER_NHWC\"] = \"1\"\n\n    def test_nhwc_bias_add_graph(test_case):\n        import oneflow.sysconfig\n\n        if oneflow.sysconfig.with_cuda():\n            do_nhwc_bias_add(test_case, True)\n        do_nhwc_bias_add(test_case, False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_conv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK: oneflow.transpose\n\nimport unittest\nimport numpy as np\n\nimport os\n\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef do_nhwc_conv(test_case, with_cuda, with_bias):\n    x = flow.randn(2, 3, 4, 5)\n    conv = flow.nn.Conv2d(3, 4, 2, 1, bias=with_bias)\n    if with_cuda:\n        x = x.cuda()\n        conv.to(\"cuda\")\n\n    eager_conv_x = conv(x)\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.conv = conv\n\n        def build(self, x):\n            return self.conv(x)\n\n    graph_to_run = GraphToRun()\n    lazy_conv_x = graph_to_run(x)\n    test_case.assertTrue(\n        np.allclose(eager_conv_x.numpy(), lazy_conv_x.numpy(), rtol=1e-5, atol=1e-5)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNhwcConv(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_PREFER_NHWC\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION\"] = \"1\"\n\n    def test_nhwc_conv_graph(test_case):\n        do_nhwc_conv(test_case, True, True)\n        do_nhwc_conv(test_case, False, True)\n        do_nhwc_conv(test_case, True, False)\n        do_nhwc_conv(test_case, False, False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_conv2d_maxpool2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK: oneflow.transpose\n\nimport unittest\nimport numpy as np\n\nimport os\n\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef do_nhwc_conv_maxpool(test_case, with_cuda, with_bias):\n    x = flow.randn(2, 3, 4, 5)\n    conv = flow.nn.Conv2d(3, 4, 2, 1, bias=with_bias)\n    maxpool_2d = flow.nn.MaxPool2d(\n        kernel_size=3, padding=1, stride=2, return_indices=False\n    )\n    if with_cuda:\n        x = x.cuda()\n        conv.to(\"cuda\")\n        maxpool_2d.to(\"cuda\")\n\n    eager_x = maxpool_2d(conv(x))\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.conv = conv\n\n        def build(self, x):\n            return maxpool_2d(self.conv(x))\n\n    graph_to_run = GraphToRun()\n    lazy_x = graph_to_run(x)\n    test_case.assertTrue(\n        np.allclose(eager_x.numpy(), lazy_x.numpy(), rtol=1e-5, atol=1e-5)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNhwcConvMaxPool(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_PREFER_NHWC\"] = \"1\"\n\n    def test_nhwc_conv_graph(test_case):\n        do_nhwc_conv_maxpool(test_case, True, True)\n        do_nhwc_conv_maxpool(test_case, True, False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_conv_relu_add.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK: oneflow.transpose\n\nimport unittest\nimport numpy as np\n\nimport os\n\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef do_nhwc_conv(test_case, with_cuda, with_bias):\n    x = flow.randn(2, 3, 4, 5)\n    conv = flow.nn.Conv2d(3, 4, 2, 1, bias=with_bias)\n    if with_cuda:\n        x = x.cuda()\n        conv.to(\"cuda\")\n\n    eager_conv_x = flow.relu(conv(x)) + flow.relu(conv(x))\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.conv = conv\n\n        def build(self, x):\n            return flow.relu(self.conv(x)) + flow.relu(self.conv(x))\n\n    graph_to_run = GraphToRun()\n    lazy_conv_x = graph_to_run(x)\n    print(eager_conv_x.numpy().flatten()[:10])\n    print(lazy_conv_x.numpy().flatten()[:10])\n    test_case.assertTrue(\n        np.allclose(eager_conv_x.numpy(), lazy_conv_x.numpy(), rtol=1e-5, atol=1e-5)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNhwcConv(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_PREFER_NHWC\"] = \"1\"\n\n    def test_nhwc_conv_graph(test_case):\n        do_nhwc_conv(test_case, True, True)\n        do_nhwc_conv(test_case, False, True)\n        do_nhwc_conv(test_case, True, False)\n        do_nhwc_conv(test_case, False, False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_lenet.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK: oneflow.transpose\n\nimport unittest\nimport numpy as np\n\nimport os\n\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.nn as nn\nimport oneflow.nn.functional as F\n\n\nclass LeNet(nn.Module):\n    def __init__(self):\n        super(LeNet, self).__init__()\n        self.conv1 = nn.Conv2d(3, 6, 5)\n        self.conv2 = nn.Conv2d(6, 16, 5)\n        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n        self.fc2 = nn.Linear(120, 84)\n        self.fc3 = nn.Linear(84, 10)\n\n    def forward(self, x):\n        out = F.relu(self.conv1(x))\n        out = F.max_pool2d(out, 2)\n        out = F.relu(self.conv2(out))\n        out = F.max_pool2d(out, 2)\n        out = out.view(out.size(0), -1)\n        out = F.relu(self.fc1(out))\n        out = F.relu(self.fc2(out))\n        out = self.fc3(out)\n        return out\n\n\ndef do_lenet(test_case, with_cuda):\n    x = flow.randn(2, 3, 32, 32)\n    lenet = LeNet()\n    if with_cuda:\n        x = x.cuda()\n        lenet.to(\"cuda\")\n\n    eager_res = lenet(x)\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.lenet = lenet\n\n        def build(self, x):\n            return self.lenet(x)\n\n    graph_to_run = GraphToRun()\n    lazy_res = graph_to_run(x)\n    test_case.assertTrue(\n        np.allclose(eager_res.numpy(), lazy_res.numpy(), rtol=1e-5, atol=1e-5)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLeNet(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_PREFER_NHWC\"] = \"1\"\n\n    def test_nhwc_lenet_graph(test_case):\n        do_lenet(test_case, True)\n        do_lenet(test_case, False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_maxpool_2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK: oneflow.transpose\n\nimport unittest\nimport numpy as np\n\nimport os\n\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef do_nhwc_maxpool_2d(test_case, with_cuda, with_return_induces):\n    x = flow.randn(1, 4, 4, 4)\n    maxpool_2d = flow.nn.MaxPool2d(\n        kernel_size=3, padding=1, stride=3, return_indices=with_return_induces\n    )\n    if with_cuda:\n        x = x.cuda()\n        maxpool_2d.to(\"cuda\")\n\n    eager_maxpool_2d_res = maxpool_2d(x)\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = maxpool_2d\n\n        def build(self, x):\n            return self.m(x)\n\n    graph_to_run = GraphToRun()\n    lazy_maxpool_2d_res = graph_to_run(x)\n    if with_return_induces:\n        test_case.assertTrue(\n            np.allclose(\n                eager_maxpool_2d_res[0].numpy(),\n                lazy_maxpool_2d_res[0].numpy(),\n                rtol=1e-5,\n                atol=1e-5,\n            )\n        )\n    else:\n        test_case.assertTrue(\n            np.allclose(\n                eager_maxpool_2d_res.numpy(),\n                lazy_maxpool_2d_res.numpy(),\n                rtol=1e-5,\n                atol=1e-5,\n            )\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNhwcMaxPool2d(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_PREFER_NHWC\"] = \"1\"\n\n    def test_nhwc_maxpool_2d_graph(test_case):\n        do_nhwc_maxpool_2d(test_case, True, True)\n        do_nhwc_maxpool_2d(test_case, True, False)\n        do_nhwc_maxpool_2d(test_case, False, True)\n        do_nhwc_maxpool_2d(test_case, False, False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_resnet.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK: oneflow.transpose\nimport unittest\nimport numpy as np\n\nfrom typing import Type, Any, Callable, Union, List, Optional\nimport os\n\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow import Tensor\nimport oneflow.nn as nn\n\n\n__all__ = [\n    \"ResNet\",\n    \"resnet50\",\n]\n\n\ndef conv3x3(\n    in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1\n) -> nn.Conv2d:\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(\n        in_planes,\n        out_planes,\n        kernel_size=3,\n        stride=stride,\n        padding=dilation,\n        groups=groups,\n        bias=False,\n        dilation=dilation,\n    )\n\n\ndef conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion: int = 1\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        if groups != 1 or base_width != 64:\n            raise ValueError(\"BasicBlock only supports groups=1 and base_width=64\")\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = norm_layer(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    # Bottleneck in flowvision places the stride for downsampling at 3x3 convolution(self.conv2)\n    # while original implementation places the stride at the first 1x1 convolution(self.conv1)\n    # according to \"Deep residual learning for image recognition\"https://arxiv.org/abs/1512.03385.\n    # This variant is also known as ResNet V1.5 and improves accuracy according to\n    # 
https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.\n\n    expansion: int = 4\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super(Bottleneck, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.0)) * groups\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(inplanes, width)\n        self.bn1 = norm_layer(width)\n        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n        self.bn2 = norm_layer(width)\n        self.conv3 = conv1x1(width, planes * self.expansion)\n        self.bn3 = norm_layer(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(nn.Module):\n    def __init__(\n        self,\n        block: Type[Union[BasicBlock, Bottleneck]],\n        layers: List[int],\n        num_classes: int = 1000,\n        zero_init_residual: bool = False,\n        groups: int = 1,\n        width_per_group: int = 64,\n        replace_stride_with_dilation: Optional[List[bool]] = None,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super(ResNet, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\n                \"replace_stride_with_dilation should be None \"\n                \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation)\n            )\n        self.groups = groups\n        self.base_width = width_per_group\n        self.conv1 = nn.Conv2d(\n            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False\n        )\n        self.bn1 = norm_layer(self.inplanes)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(\n            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]\n        )\n        self.layer3 = self._make_layer(\n            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]\n        )\n        self.layer4 = self._make_layer(\n            block, 512, layers[3], stride=2, 
dilate=replace_stride_with_dilation[2]\n        )\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]\n                elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]\n\n    def _make_layer(\n        self,\n        block: Type[Union[BasicBlock, Bottleneck]],\n        planes: int,\n        blocks: int,\n        stride: int = 1,\n        dilate: bool = False,\n    ) -> nn.Sequential:\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(\n            block(\n                self.inplanes,\n                planes,\n                stride,\n                downsample,\n                self.groups,\n                self.base_width,\n                previous_dilation,\n                norm_layer,\n            )\n        )\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(\n                block(\n                    self.inplanes,\n                    planes,\n                    groups=self.groups,\n                    base_width=self.base_width,\n                    dilation=self.dilation,\n                    norm_layer=norm_layer,\n                )\n            )\n\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, x: Tensor) -> Tensor:\n        # See note [TorchScript super()]\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = flow.flatten(x, 1)\n        x = self.fc(x)\n\n        return x\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self._forward_impl(x)\n\n\ndef _resnet(\n    arch: str,\n    block: Type[Union[BasicBlock, Bottleneck]],\n    layers: List[int],\n    pretrained: bool,\n    progress: bool,\n    **kwargs: Any\n) -> ResNet:\n    model = ResNet(block, layers, **kwargs)\n    return model\n\n\ndef resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-50 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n    Args:\n        
pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet(\"resnet50\", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)\n\n\ndef do_resnet(test_case):\n    x = flow.randn(2, 3, 224, 224)\n    resnet = resnet50()\n    x = x.cuda()\n    resnet.to(\"cuda\")\n\n    eager_res = resnet(x)\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.resnet = resnet\n\n        def build(self, x):\n            return self.resnet(x)\n\n    graph_to_run = GraphToRun()\n    lazy_res = graph_to_run(x)\n    test_case.assertTrue(\n        # TODO(yuhao): tolerances are loose because of high precision loss\n        np.allclose(eager_res.numpy(), lazy_res.numpy(), rtol=1e-4, atol=1e-1)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestResNet(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_PREFER_NHWC\"] = \"1\"\n\n    def test_nhwc_resnet_graph(test_case):\n        do_resnet(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/auto_nhwc/test_nhwc_transpose_eliminate.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK-NOT: oneflow.transpose\n\nimport unittest\nimport numpy as np\n\nimport os\n\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef do_eliminate_transpose(test_case, with_cuda):\n    x = flow.randn(2, 3, 4, 5)\n    if with_cuda:\n        x = x.cuda()\n\n    eager_res = flow.permute(flow.permute(x, (0, 2, 3, 1)), (0, 3, 1, 2))\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, x):\n            return flow.permute(flow.permute(x, (0, 2, 3, 1)), (0, 3, 1, 2))\n\n    graph_to_run = GraphToRun()\n    lazy_res = graph_to_run(x)\n    test_case.assertTrue(\n        np.allclose(eager_res.numpy(), lazy_res.numpy(), rtol=1e-5, atol=1e-5)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNhwcEliminateTranspose(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_PREFER_NHWC\"] = \"1\"\n\n    def test_eliminate_transpose(test_case):\n        do_eliminate_transpose(test_case, True)\n        do_eliminate_transpose(test_case, False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/auto_nhwc/test_resnet101_benchmark.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK: oneflow.transpose\nimport unittest\nimport numpy as np\n\nimport time\nimport datetime\nfrom typing import Type, Any, Callable, Union, List, Optional\nimport os\n\nos.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\nos.environ[\"ONEFLOW_MLIR_PREFER_NHWC\"] = \"1\"\nos.environ[\"ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION\"] = \"1\"\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow import Tensor\nimport oneflow.nn as nn\n\n\n__all__ = [\n    \"ResNet\",\n    \"resnet50\",\n]\n\n\ndef conv3x3(\n    in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1\n) -> nn.Conv2d:\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(\n        in_planes,\n        out_planes,\n        kernel_size=3,\n        stride=stride,\n        padding=dilation,\n        groups=groups,\n        bias=False,\n        dilation=dilation,\n    )\n\n\ndef conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion: int = 1\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        if groups != 1 or base_width != 64:\n            raise ValueError(\"BasicBlock only supports groups=1 and base_width=64\")\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = norm_layer(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    # Bottleneck in flowvision places the stride for downsampling at 3x3 convolution(self.conv2)\n    # while original implementation places the stride at the first 1x1 
convolution(self.conv1)\n    # according to \"Deep residual learning for image recognition\"https://arxiv.org/abs/1512.03385.\n    # This variant is also known as ResNet V1.5 and improves accuracy according to\n    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.\n\n    expansion: int = 4\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super(Bottleneck, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.0)) * groups\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(inplanes, width)\n        self.bn1 = norm_layer(width)\n        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n        self.bn2 = norm_layer(width)\n        self.conv3 = conv1x1(width, planes * self.expansion)\n        self.bn3 = norm_layer(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(nn.Module):\n    def __init__(\n        self,\n        block: Type[Union[BasicBlock, Bottleneck]],\n        layers: List[int],\n        num_classes: int = 1000,\n        zero_init_residual: bool = False,\n        groups: int = 1,\n        width_per_group: int = 64,\n        replace_stride_with_dilation: Optional[List[bool]] = None,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super(ResNet, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\n                \"replace_stride_with_dilation should be None \"\n                \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation)\n            )\n        self.groups = groups\n        self.base_width = width_per_group\n        self.conv1 = nn.Conv2d(\n            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False\n        )\n        self.bn1 = norm_layer(self.inplanes)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(\n            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]\n        )\n        self.layer3 = 
self._make_layer(\n            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]\n        )\n        self.layer4 = self._make_layer(\n            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]\n        )\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]\n                elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]\n\n    def _make_layer(\n        self,\n        block: Type[Union[BasicBlock, Bottleneck]],\n        planes: int,\n        blocks: int,\n        stride: int = 1,\n        dilate: bool = False,\n    ) -> nn.Sequential:\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(\n            block(\n                self.inplanes,\n                planes,\n                stride,\n                downsample,\n                self.groups,\n                self.base_width,\n                previous_dilation,\n                norm_layer,\n            )\n        )\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(\n                block(\n                    self.inplanes,\n                    planes,\n                    groups=self.groups,\n                    base_width=self.base_width,\n                    dilation=self.dilation,\n                    norm_layer=norm_layer,\n                )\n            )\n\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, x: Tensor) -> Tensor:\n        # See note [TorchScript super()]\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = flow.flatten(x, 1)\n        x = self.fc(x)\n\n        return x\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self._forward_impl(x)\n\n\ndef _resnet(\n    arch: str,\n    block: Type[Union[BasicBlock, Bottleneck]],\n    layers: List[int],\n    pretrained: bool,\n    progress: bool,\n    **kwargs: Any\n) -> ResNet:\n    model = ResNet(block, layers, **kwargs)\n    return model\n\n\ndef resnet101(pretrained: bool = False, 
progress: bool = True, **kwargs: Any) -> ResNet:\n    \"\"\"\n    Constructs the ResNet-101 model.\n    .. note::\n        `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`_.\n    Args:\n        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``\n        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``\n    For example:\n    .. code-block:: python\n        >>> import flowvision\n        >>> resnet101 = flowvision.models.resnet101(pretrained=False, progress=True)\n    \"\"\"\n    return _resnet(\n        \"resnet101\", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs\n    )\n\n\ndef bench(forward: Callable, x, n=1000):\n    # warm up\n    for _ in range(5):\n        output = forward(x)\n        res = output.numpy()  # fetch to host to force synchronization\n\n    flow._oneflow_internal.profiler.RangePush(\"eval begin\")\n    start_time = time.time()\n    for _ in range(n):\n        flow._oneflow_internal.profiler.RangePush(\"forward\")\n        output = forward(x)\n        flow._oneflow_internal.profiler.RangePop()\n        flow._oneflow_internal.profiler.RangePush(\"numpy\")\n        res = output.numpy()\n        flow._oneflow_internal.profiler.RangePop()\n    flow._oneflow_internal.profiler.RangePop()\n    total_time = time.time() - start_time\n    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n    print(total_time_str)\n\n\nclass ResNetEvalGraph(nn.Graph):\n    def __init__(self, model):\n        super().__init__()\n        self.model = model\n        self.config.enable_amp(True)\n\n    def build(self, x):\n        y_pred = self.model(x)\n        return y_pred\n\n\ndef main():\n    np.random.seed(42)\n\n    device = flow.device(\"cuda\")\n    model = resnet101()\n    model.eval()\n    model.to(device)\n    batch_size = 64\n    x = flow.randn(batch_size, 3, 224, 224).to(device)\n\n    model_graph = ResNetEvalGraph(model)\n    bench(model_graph, x, n=10)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/conversion/lower_to_tosa.mlir",
    "content": "// RUN: oneflow-opt \\\n// RUN: -lower-oneflow-to-tosa \\\n// RUN: -tosa-make-broadcastable \\\n// RUN: --print-after-all %s\n\nmodule  {\n  func.func @Cast_1__FUSE__ScalarMulByTensor_2(%arg0: tensor<96x96xi64>, %arg1: tensor<1xf32>) -> tensor<96x96xf32> {\n    %0 = \"oneflow.cast\"(%arg0) {device_name = [\"0:0\"], device_tag = \"cpu\", dtype = 2 : i32, hierarchy = [1], op_name = \"Cast_1\", op_type_name = \"cast\", scope_symbol_id = 4611686018427416574 : i64} : (tensor<96x96xi64>) -> tensor<96x96xf32>\n    %1 = \"oneflow.scalar_mul_by_tensor\"(%0, %arg1) {device_name = [\"0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"ScalarMulByTensor_2\", op_type_name = \"scalar_mul_by_tensor\", scope_symbol_id = 4611686018427416574 : i64} : (tensor<96x96xf32>, tensor<1xf32>) -> tensor<96x96xf32>\n    return %1 : tensor<96x96xf32>\n  }\n}\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/conversion/lower_to_tosa_signed.mlir",
    "content": "// RUN: oneflow-opt -convert-to-signless-for-tosa   --mlir-print-ir-before-all --mlir-print-ir-after-all \\\n// RUN: -lower-oneflow-to-tosa \\\n// RUN: -tosa-make-broadcastable \\\n// RUN: -reconcile-unrealized-casts --print-after-all %s\n\nmodule  {\n  func.func @test(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xsi64> {\n    %1, %indice = \"oneflow.max_pool_2d\"(%arg0) {\n      ceil_mode = false,\n      data_format = \"channels_first\",\n      device_name = [\"@0:0\"],\n      device_tag = \"cpu\",\n      dilation = [1 : si32, 1 : si32],\n      hierarchy = [1],\n      kernel_size = [3 : si32, 3 : si32],\n      op_name = \"model.maxpool-max_pool_2d-3\",\n      padding = [1 : si32, 1 : si32],\n      return_indices = false,\n      scope_symbol_id = 49 : i64,\n      stride = [2 : si32, 2 : si32]\n    } : (tensor<1x64x112x112xf32>) -> (tensor<1x64x56x56xf32>, tensor<1x64x56x56xsi64>)\n    return %indice : tensor<1x64x56x56xsi64>\n  }\n}\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/conversion/oneflow_to_tosa.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -split-input-file \\\n// RUN: -auto-nhwc \\\n// RUN: -lower-oneflow-to-tosa \\\n// RUN: -verify-diagnostics -o - \\\n// RUN: | FileCheck %s\n\n\n// CHECK-LABEL: test_func\n// CHECK: return [[V0:%.+]] : tensor<1xf32>\noneflow.job @test_func(%arg0: tensor<1xf32>) -> tensor<1xf32>\n{\n    oneflow.return %arg0 : tensor<1xf32>\n}\n\n\n// CHECK-LABEL: test_input\n// CHECK: return [[V0:%.+]] : tensor<1xf32>\noneflow.job @test_input(%arg0: tensor<1xf32>) -> tensor<1xf32>\n{\n    %res = \"oneflow.input\"(%arg0)\n    {\n        data_type = 2 : i32,\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        hierarchy = [1],\n        is_dynamic = false,\n        nd_sbp = [\"B\"],\n        op_name = \"\",\n        output_lbns = [\"\"],\n        scope_symbol_id = 4611686018427412479 : i64,\n        shape = [1 : si64]\n    } : (tensor<1xf32>) -> tensor<1xf32>\n    oneflow.return %res : tensor<1xf32>\n}\n\n\n// CHECK-LABEL: test_output\n// CHECK: return [[V0:%.+]] : tensor<1xf32>\noneflow.job @test_output(%arg0: tensor<1xf32>) -> tensor<1xf32>\n{\n    %res = \"oneflow.output\"(%arg0)\n    {\n        data_type = 2 : i32,\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        hierarchy = [1],\n        is_dynamic = false,\n        nd_sbp = [\"B\"],\n        op_name = \"\",\n        output_lbns = [\"\"],\n        scope_symbol_id = 4611686018427412479 : i64,\n        shape = [1 : si64]\n    } : (tensor<1xf32>) -> tensor<1xf32>\n    oneflow.return %res : tensor<1xf32>\n}\n\n\n// CHECK-LABEL: test_variable\n// CHECK: [[V0:%.+]] = \"tosa.const\"() \n// CHECK-SAME: {value = dense<0.000000e+00> : tensor<64x3x7x7xf32>}\n// CHECK: return [[V0]] : tensor<64x3x7x7xf32>\noneflow.job @test_variable() -> tensor<64x3x7x7xf32>\n{\n    %res = \"oneflow.variable\"() {\n        data_type = 2 : i32,\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        hierarchy = [1],\n        parallel = #sbp.parallel<[] -> [#sbp.B]>,\n        op_name = \"fw.model.conv1.weight\",\n        output_lbns = [\"fw.model.conv1.weight/out\"],\n        scope_symbol_id = 4611686018427432959 : i64,\n        shape = [64 : si64, 3 : si64, 7 : si64, 7 : si64]\n    } : () -> tensor<64x3x7x7xf32>\n    oneflow.return %res : tensor<64x3x7x7xf32>\n}\n\n\n// CHECK-LABEL: test_add_n2\n// CHECK: [[V0:%.+]] = \"tosa.add\"(%arg0, %arg1) : (tensor<1x7x7xf32>, tensor<1x7x7xf32>) -> tensor<1x7x7xf32>\n// CHECK: return [[V0]] : tensor<1x7x7xf32>\noneflow.job @test_add_n2(%arg0: tensor<1x7x7xf32>, %arg1: tensor<1x7x7xf32>) -> tensor<1x7x7xf32>\n{\n    %res = \"oneflow.add_n2\"(%arg0, %arg1)\n    {\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        hierarchy = [1],\n        op_name = \"\",\n        op_type_name = \"add_n\",\n        output_lbns = [\"\"],\n        scope_symbol_id = 4611686018431205375 : i64\n    } : (tensor<1x7x7xf32>, tensor<1x7x7xf32>) -> tensor<1x7x7xf32>\n    oneflow.return %res: tensor<1x7x7xf32>\n}\n\n\n//CHECK-LABEL: test_broadcast_add\n//CHECK: [[V0:%.+]] = \"tosa.add\"(%arg0, %arg1) : (tensor<1x1000xf32>, tensor<1000xf32>) -> tensor<1x1000xf32>\n//CHECK: return [[V0]] : tensor<1x1000xf32>\noneflow.job @test_broadcast_add(%arg0: tensor<1x1000xf32>, %arg1: tensor<1000xf32>) -> tensor<1x1000xf32>\n{\n    %res = \"oneflow.broadcast_add\"(%arg0, %arg1)\n    {\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        hierarchy = [1],\n        op_name = \"\",\n        output_lbns = [\"\"],\n        
scope_symbol_id = 4611686018431234047 : i64\n    } : (tensor<1x1000xf32>, tensor<1000xf32>) -> tensor<1x1000xf32>\n    oneflow.return %res : tensor<1x1000xf32>\n}\n\n\n// CHECK-LABEL: test_max_pool_2d\n// CHECK: [[V0:%.+]] = \"tosa.const\"() \n// CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}\n// CHECK: [[V1:%.+]] = \"tosa.transpose\"(%arg0, [[V0]]) : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32>\n// CHECK: [[V2:%.+]] = \"tosa.max_pool2d\"([[V1]]) \n// CHECK-SAME: {kernel = array<i64: 3, 3>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 2, 2>}\n// CHECK: [[V3:%.+]] = \"tosa.const\"() \n// CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}\n// CHECK: [[V4:%.+]] = \"tosa.transpose\"([[V2]], [[V3]]) : (tensor<1x56x56x64xf32>, tensor<4xi32>) -> tensor<1x64x56x56xf32>\n// CHECK: [[V5:%.+]] = \"tosa.const\"() \n// CHECK-SAME: {value = dense<0> : tensor<1x64x56x56xi64>}\n// CHECK: return [[V4]] : tensor<1x64x56x56xf32>\noneflow.job @test_max_pool_2d(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32>\n{\n    %y, %indice = \"oneflow.max_pool_2d\"(%arg0)\n    {\n        ceil_mode = false,\n        data_format = \"channels_first\",\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        dilation = [1 : si32, 1 : si32],\n        hierarchy = [1], kernel_size = [3 : si32, 3 : si32],\n        op_name = \"\",\n        output_lbns = [\"\", \"\"],\n        padding = [1 : si32, 1 : si32],\n        return_indices = false,\n        scope_symbol_id = 4611686018427502591 : i64,\n        stride = [2 : si32, 2 : si32]\n    } : (tensor<1x64x112x112xf32>) -> (tensor<1x64x56x56xf32>, tensor<1x64x56x56xi64>)\n    oneflow.return %y : tensor<1x64x56x56xf32>\n}\n\n\n// CHECK-LABEL: test_avg_pool_2d\n// CHECK: [[V0:%.+]] = \"tosa.const\"() \n// CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} \n// CHECK: [[V1:%.+]] = \"tosa.transpose\"(%arg0, [[V0]]) : (tensor<1x2048x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x2048xf32>\n// CHECK: [[V2:%.+]] = \"tosa.avg_pool2d\"([[V1]]) \n// CHECK-SAME: {kernel = array<i64: 7, 7>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 7, 7>}\n// CHECK: [[V3:%.+]] = \"tosa.const\"()\n// CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} \n// CHECK: [[V4:%.+]] = \"tosa.transpose\"([[V2]], [[V3]]) : (tensor<1x1x1x2048xf32>, tensor<4xi32>) -> tensor<1x2048x1x1xf32>\n// CHECK: return [[V4]] : tensor<1x2048x1x1xf32>\noneflow.job @test_avg_pool_2d(%arg0: tensor<1x2048x7x7xf32>) -> tensor<1x2048x1x1xf32>\n{\n    %res = \"oneflow.avg_pool_2d\"(%arg0)\n    {\n        ceil_mode = false,\n        count_include_pad = true,\n        data_format = \"channels_first\",\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        divisor_override = 0 : si32,\n        hierarchy = [1],\n        kernel_size = [7 : si32, 7 : si32],\n        op_name = \"model.avgpool-avg_pool_2d-172\",\n        output_lbns = [\"model.avgpool-avg_pool_2d-172/y_0\"],\n        padding = [0 : si32, 0 : si32],\n        scope_symbol_id = 4611686018430775295 : i64,\n        stride = [7 : si32, 7 : si32]\n    } : (tensor<1x2048x7x7xf32>) -> tensor<1x2048x1x1xf32>\n    oneflow.return %res : tensor<1x2048x1x1xf32>\n}\n\n\n// CHECK-LABEL: test_conv2d\n// CHECK: [[V0:%.+]] = \"tosa.const\"()\n// CHECK-SAME: {value = dense<0.000000e+00> : tensor<5xf32>} \n// CHECK: [[V1:%.+]] = \"tosa.const\"()\n// CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}\n// CHECK: [[V2:%.+]] = \"tosa.transpose\"(%arg0, [[V1]]) : (tensor<1x3x224x224xf32>, 
tensor<4xi32>) -> tensor<1x224x224x3xf32>\n// CHECK: [[V3:%.+]] = \"tosa.const\"()\n// CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} \n// CHECK: [[V4:%.+]] = \"tosa.transpose\"(%arg1, [[V3]]) : (tensor<5x3x1x1xf32>, tensor<4xi32>) -> tensor<5x1x1x3xf32>\n// CHECK: [[V5:%.+]] = \"tosa.conv2d\"([[V2]], [[V4]], [[V0]])\n// CHECK-SAME: {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}\n// CHECK: [[V6:%.+]] = \"tosa.const\"()\n// CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} \n// CHECK: [[V7:%.+]] = \"tosa.transpose\"([[V5]], [[V6]]) : (tensor<1x224x224x5xf32>, tensor<4xi32>) -> tensor<1x5x224x224xf32>\n// CHECK: return [[V7]] : tensor<1x5x224x224xf32>\noneflow.job @test_conv2d(%arg0: tensor<1x3x224x224xf32>, %arg1: tensor<5x3x1x1xf32>) -> tensor<1x5x224x224xf32>\n{\n    %res = \"oneflow.conv2d\"(%arg0, %arg1)\n    {\n        data_format = \"channels_first\",\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        dilation_rate = [1 : si32, 1 : si32],\n        filters = 512 : si32,\n        groups = 1 : si32,\n        hierarchy = [1],\n        kernel_size = [1 : si32, 1 : si32],\n        op_name = \"\",\n        operand_segment_sizes = array<i32: 1, 1, 0, 0>,\n        output_lbns = [\"\"],\n        padding_before = [0 : si32, 0 : si32],\n        scope_symbol_id = 4611686018431012863 : i64,\n        strides = [1 : si32, 1 : si32]\n    } : (tensor<1x3x224x224xf32>, tensor<5x3x1x1xf32>) -> tensor<1x5x224x224xf32>\n    oneflow.return %res : tensor<1x5x224x224xf32>\n}\n\n\n// CHECK-LABEL: test_matmul\n// CHECK: [[V0:%.+]] = \"tosa.reshape\"(%arg0)\n// CHECK: [[V1:%.+]] = \"tosa.const\"()\n// CHECK-SAME: {value = dense<[1, 0]> : tensor<2xi32>} \n// CHECK: [[V2:%.+]] = \"tosa.transpose\"(%arg1, [[V1]]) : (tensor<1000x2048xf32>, tensor<2xi32>) -> tensor<2048x1000xf32>\n// CHECK: [[V3:%.+]] = \"tosa.reshape\"([[V2]])\n// CHECK: [[V4:%.+]] = \"tosa.matmul\"([[V0]], [[V3]]) : (tensor<1x1x2048xf32>, tensor<1x2048x1000xf32>) -> tensor<1x1x1000xf32>\n// CHECK: [[V5:%.+]] = \"tosa.reshape\"([[V4]])\n// CHECK: return [[V5]] : tensor<1x1000xf32>\noneflow.job @test_matmul(%arg0: tensor<1x2048xf32>, %arg1: tensor<1000x2048xf32>) ->tensor<1x1000xf32>\n{\n    %res = \"oneflow.matmul\"(%arg0, %arg1)\n    {\n        alpha = 1.000000e+00 : f64,\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        hierarchy = [1],\n        op_name = \"\",\n        output_lbns = [\"\"],\n        scope_symbol_id = 4611686018431234047 : i64,\n        transpose_a = false,\n        transpose_b = true\n    } : (tensor<1x2048xf32>, tensor<1000x2048xf32>) -> tensor<1x1000xf32>\n    oneflow.return %res : tensor<1x1000xf32>\n}\n\n\n// CHECK-LABEL: test_relu\n// CHECK: [[V0:%.+]] = \"tosa.maximum\"([[V1:%.+]], [[V2:%.+]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>\n// CHECK: return [[V0]] : tensor<1xf32>\noneflow.job @test_relu(%arg0: tensor<1xf32>) -> tensor<1xf32> {\n    %res = \"oneflow.relu\"(%arg0)\n    {\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        hierarchy = [1],\n        op_name = \"\",\n        output_lbns = [\"\"],\n        scope_symbol_id = 4611686018427424767 : i64\n    } : (tensor<1xf32>) -> tensor<1xf32>\n    oneflow.return %res : tensor<1xf32>\n}\n\n// CHECK-LABEL: test_bn\n// CHECK: \"tosa.sub\"\n// CHECK: \"tosa.add\"\n// CHECK: \"tosa.rsqrt\"\n// CHECK: \"tosa.mul\"\n// CHECK: \"tosa.mul\"\n// CHECK: \"tosa.add\"\noneflow.job @test_bn(\n%x:               
tensor<1x64x112x112xf32>,\n%moving_mean:     tensor<64xf32>,\n%moving_variance: tensor<64xf32>,\n%gamma:           tensor<64xf32>,\n%beta:            tensor<64xf32>) -> tensor<1x64x112x112xf32>\n{\n    %y, %mean, %inv_variance = \"oneflow.normalization\"(%x, %moving_mean, %moving_variance, %gamma, %beta)\n    {\n        axis = 1 : si32,\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        epsilon = 9.99999974E-6 : f32,\n        hierarchy = [1],\n        momentum = 0.899999976 : f32,\n        op_name = \"\",\n        operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>,\n        output_lbns = [\"\", \"\", \"\"],\n        result_segment_sizes = array<i32: 1, 1, 1>,\n        scope_symbol_id = 4611686018427453439 : i64,\n        training = true\n    } : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>)\n    oneflow.return %y: tensor<1x64x112x112xf32>\n}\n\n// CHECK-LABEL: test_bn_infer\n// CHECK: \"tosa.sub\"\n// CHECK: \"tosa.add\"\n// CHECK: \"tosa.rsqrt\"\n// CHECK: \"tosa.mul\"\n// CHECK: \"tosa.mul\"\n// CHECK: \"tosa.add\"\noneflow.job @test_bn_infer(\n%x:               tensor<1x64x112x112xf32>,\n%moving_mean:     tensor<64xf32>,\n%moving_variance: tensor<64xf32>,\n%gamma:           tensor<64xf32>,\n%beta:            tensor<64xf32>) -> tensor<1x64x112x112xf32>\n{\n    %y = \"oneflow.normalization_infer\"(%x, %moving_mean, %moving_variance, %gamma, %beta)\n    {\n        axis = 1 : si32,\n        device_name = [\"@0:0\"],\n        device_tag = \"cpu\",\n        epsilon = 9.99999974E-6 : f32,\n        hierarchy = [1],\n        momentum = 0.899999976 : f32,\n        op_name = \"\",\n        operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>,\n        output_lbns = [\"\", \"\", \"\"],\n        result_segment_sizes = array<i32: 1>,\n        scope_symbol_id = 4611686018427453439 : i64,\n        training = true\n    } : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>\n    oneflow.return %y: tensor<1x64x112x112xf32>\n}\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/cse.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -cse-with-attributes-ignored -cse -cse-put-attributes -canonicalize | FileCheck %s\n\nmodule  {\n  func.func @Cast_1__FUSE__ScalarMulByTensor_2(%arg0: tensor<96x96xi64>) -> tensor<96x96xf32> {\n    %0 = \"oneflow.cast\"(%arg0) {device_name = [\"0:0\"], device_tag = \"cpu\", dtype = 2 : i32, hierarchy = [1], op_name = \"Cast_1\", op_type_name = \"cast\", scope_symbol_id = 4611686018427416574 : i64} : (tensor<96x96xi64>) -> tensor<96x96xf32>\n    %1 = \"oneflow.cast\"(%arg0) {device_name = [\"0:0\"], device_tag = \"cpu\", dtype = 2 : i32, hierarchy = [1], op_name = \"Cast_2\", op_type_name = \"cast\", scope_symbol_id = 4611686018427416574 : i64} : (tensor<96x96xi64>) -> tensor<96x96xf32>\n    %2 = \"oneflow.add_n\"(%0, %1) {device_name = [\"0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"ScalarMulByTensor_2\", op_type_name = \"add_n\", scope_symbol_id = 4611686018427416574 : i64} : (tensor<96x96xf32>, tensor<96x96xf32>) -> tensor<96x96xf32>\n    // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = \"oneflow.cast\"\n    // CHECK: \"oneflow.add_n2\"(%[[OUT]], %[[OUT]])\n    // CHECK: op_name = \"ScalarMulByTensor_2\"\n    return %2 : tensor<96x96xf32>\n  }\n  func.func @f2(%input: tensor<2x64x64x320xf16>, %w: tensor<320x320x3x3xf16>, %bias: tensor<320xf16>) -> (tensor<2x64x64x320xf16>, tensor<2x64x64x320xf16>) {\n    %transpose_w = \"oneflow.transpose\"(%w) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.resnets.0.conv1-conv2d-31_transpose_input_1\", perm = [0 : si32, 2 : si32, 3 : si32, 1 : si32], scope_symbol_id = 163 : i64} : (tensor<320x320x3x3xf16>) -> tensor<320x3x3x320xf16>\n    %conv2d = \"oneflow.conv2d\"(%input, %transpose_w, %bias) {data_format = \"channels_last\", device_name = [\"@0:0\"], device_tag = \"cuda\", dilation_rate = [1 : si32, 1 : si32], filters = 320 : si32, groups = 1 : si32, hierarchy = [1], kernel_size = [3 : si32, 3 : si32], op_name = \"unet.down_blocks.0.resnets.0.conv1-conv2d-31\", operand_segment_sizes = array<i32: 1, 1, 1, 0>, padding_before = [1 : si32, 1 : si32], scope_symbol_id = 163 : i64, strides = [1 : si32, 1 : si32]} : (tensor<2x64x64x320xf16>, tensor<320x3x3x320xf16>, tensor<320xf16>) -> tensor<2x64x64x320xf16>\n    %transpose_w1 = \"oneflow.transpose\"(%w) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.resnets.0.conv1-conv2d-31_transpose_input_2\", perm = [0 : si32, 2 : si32, 3 : si32, 1 : si32], scope_symbol_id = 163 : i64} : (tensor<320x320x3x3xf16>) -> tensor<320x3x3x320xf16>\n    %conv2d_2 = \"oneflow.conv2d\"(%input, %transpose_w1, %bias) {data_format = \"channels_last\", device_name = [\"@0:0\"], device_tag = \"cuda\", dilation_rate = [1 : si32, 1 : si32], filters = 320 : si32, groups = 1 : si32, hierarchy = [1], kernel_size = [3 : si32, 3 : si32], op_name = \"unet.down_blocks.0.resnets.0.conv1-conv2d-31\", operand_segment_sizes = array<i32: 1, 1, 1, 0>, padding_before = [1 : si32, 1 : si32], scope_symbol_id = 163 : i64, strides = [1 : si32, 1 : si32]} : (tensor<2x64x64x320xf16>, tensor<320x3x3x320xf16>, tensor<320xf16>) -> tensor<2x64x64x320xf16>\n    return %conv2d, %conv2d_2 : tensor<2x64x64x320xf16>, tensor<2x64x64x320xf16>\n  // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = \"oneflow.conv2d\"\n  // CHECK: scope_symbol_id = 163 : i64\n  // CHECK: return %[[OUT]], %[[OUT]]\n  }\n}\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/cuda_code_gen/gpu_copy_arg.mlir",
    "content": "// RUN: oneflow-opt %s -lower-oneflow-to-tosa -tosa-make-broadcastable \\\n// RUN: | oneflow-opt -pass-pipeline=\"builtin.module(func.func(tosa-to-linalg))\" \\\n// RUN: | oneflow-opt -cse --linalg-fuse-elementwise-ops -linalg-bufferize -convert-linalg-to-parallel-loops -gpu-map-parallel-loops \\\n// RUN: -convert-parallel-loops-to-gpu -gpu-kernel-outlining -buffer-host-register -canonicalize \\\n// RUN: | oneflow-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,lower-affine,convert-gpu-to-nvvm,gpu-to-cubin))' \\\n// RUN: | oneflow-opt --func-bufferize -buffer-results-to-out-params -gpu-copy-arg\nfunc.func @Cast_289__FUSE__ScalarMulByTensor_290(%arg0: tensor<3x3xi64>, %arg1: tensor<1xf32>) -> tensor<3x3xf32> {\n  %0 = \"oneflow.cast\"(%arg0) {device_name = [\"@0:0\"], device_tag = \"cuda\", dtype = 2 : i32, hierarchy = [1], op_name = \"Cast_289\", output_lbns = [\"Cast_289/out_0\"], scope_symbol_id = 4611686018427478014 : i64} : (tensor<3x3xi64>) -> tensor<3x3xf32>\n  %1 = \"oneflow.scalar_mul_by_tensor\"(%0, %arg1) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"ScalarMulByTensor_290\", output_lbns = [\"ScalarMulByTensor_290/y_0\"], scope_symbol_id = 4611686018427478014 : i64} : (tensor<3x3xf32>, tensor<1xf32>) -> tensor<3x3xf32>\n  return %1 : tensor<3x3xf32>\n}\n\n// CHECK: gpu.memcpy  %arg2\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/cuda_code_gen/lit.local.cfg",
    "content": "if not config.WITH_MLIR_CUDA_CODEGEN:\n  config.unsupported = True\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/cuda_code_gen/test_append_oneflow_stream.mlir",
    "content": "// RUN: oneflow-opt %s -append-ofstream | FileCheck %s\n\n// CHECK: func.func @JITOpGenerated0(%arg0: memref<1xf32>, %arg1: memref<5xi64>, %arg2: memref<5xf32>, %arg3: !llvm.ptr) attributes {llvm.emit_c_interface}\n\nmodule attributes {gpu.container_module} {\n  func.func @JITOpGenerated0(%arg0: memref<1xf32>, %arg1: memref<5xi64>, %arg2: memref<5xf32>) attributes {llvm.emit_c_interface} {\n    %c5 = arith.constant 5 : index\n    %c1 = arith.constant 1 : index\n    %collapse_shape = memref.collapse_shape %arg0 [] : memref<1xf32> into memref<f32>\n    gpu.launch_func  @JITOpGenerated0_kernel::@JITOpGenerated0_kernel blocks in (%c5, %c1, %c1) threads in (%c1, %c1, %c1) args(%arg1 : memref<5xi64>, %collapse_shape : memref<f32>, %arg2 : memref<5xf32>)\n    return\n  }\n  gpu.module @JITOpGenerated0_kernel attributes {gpu.binary = \"\"} {\n    llvm.func @JITOpGenerated0_kernel(%arg0: !llvm.ptr<i64>, %arg1: !llvm.ptr<i64>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr<f32>, %arg6: !llvm.ptr<f32>, %arg7: i64, %arg8: !llvm.ptr<f32>, %arg9: !llvm.ptr<f32>, %arg10: i64, %arg11: i64, %arg12: i64) attributes {gpu.kernel, gpu.known_block_size = array<i32: 1, 1, 1>, nvvm.kernel} {\n      %0 = llvm.mlir.undef : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>\n      %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n      %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n      %3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n      %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n      %5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n      %6 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>\n      %7 = llvm.insertvalue %arg5, %6[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)> \n      %8 = llvm.insertvalue %arg6, %7[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)> \n      %9 = llvm.insertvalue %arg7, %8[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)> \n      %10 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>\n      %11 = llvm.insertvalue %arg8, %10[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n      %12 = llvm.insertvalue %arg9, %11[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n      %13 = llvm.insertvalue %arg10, %12[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n      %14 = llvm.insertvalue %arg11, %13[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n      %15 = llvm.insertvalue %arg12, %14[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n      %16 = nvvm.read.ptx.sreg.ctaid.x : i32\n      %17 = llvm.sext %16 : i32 to i64\n      %18 = llvm.extractvalue %5[1] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n      %19 = llvm.getelementptr %18[%17] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>\n      %20 = llvm.load %19 : !llvm.ptr<i64>\n      %21 = llvm.extractvalue %9[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)> \n      %22 = llvm.load %21 : !llvm.ptr<f32>\n      %23 = llvm.sitofp %20 : i64 to f32\n      %24 = llvm.fmul %23, %22  : f32\n      %25 = llvm.extractvalue %15[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x 
i64>, array<1 x i64>)> \n      %26 = llvm.getelementptr %25[%17] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>\n      llvm.store %24, %26 : !llvm.ptr<f32>\n      llvm.return\n    }\n  }\n}\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/cuda_code_gen/test_cast_ops_to_signless.mlir",
    "content": "// RUN: oneflow-opt %s -cast-ofops-to-signless  | FileCheck %s\n// CHECK: unrealized_conversion_cast\nfunc.func @Cast_289__FUSE__ScalarMulByTensor_290() -> tensor<512x2048x1x1xf32> {\n    %output_299 = \"oneflow.variable\"() {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"resnet.layer4.2.conv1.weight\", output_lbns = [\"resnet.layer4.2.conv1.weight/out\"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 1995 : i64, shape = [512 : si64, 2048 : si64, 1 : si64, 1 : si64]} : () -> tensor<512x2048x1x1xsi64>\n    %0 = \"oneflow.cast\"(%output_299) {device_name = [\"0:0\"], device_tag = \"cpu\", dtype = 2 : i32, hierarchy = [1], op_name = \"Cast_1\", op_type_name = \"cast\", scope_symbol_id = 4611686018427416574 : i64} : (tensor<512x2048x1x1xsi64>) -> tensor<512x2048x1x1xf32>\n    func.return %0 : tensor<512x2048x1x1xf32>\n}"
  },
  {
    "path": "oneflow/ir/test/OneFlow/cuda_code_gen/test_fold_alloc_to_subview.mlir",
    "content": "// RUN: oneflow-opt %s -fold-alloc-to-subview\n#map = affine_map<(d0)[s0, s1] -> ((d0 - s0) ceildiv s1)>\n#map1 = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>\nmodule attributes {gpu.container_module} {\n  func.func @JITOpGenerated0(%arg0: memref<1xf32>, %arg1: memref<5xi64>, %arg2: memref<5xf32>) attributes {llvm.emit_c_interface} {\n    %c0 = arith.constant 0 : index\n    %c5 = arith.constant 5 : index\n    %c1 = arith.constant 1 : index\n    %collapse_shape = memref.collapse_shape %arg0 [] : memref<1xf32> into memref<f32>\n    %alloc = memref.alloc() {alignment = 64 : i64} : memref<5xf32>\n    // CHECK-NOT: %alloc = memref.alloc() {alignment = 64 : i64} : memref<5xf32>\n    // CHECK: memref.alloc() : memref<512xi8>\n    // CHECK: memref.view\n    %c1_0 = arith.constant 1 : index\n    %0 = affine.apply #map(%c5)[%c0, %c1]\n    gpu.launch_func  @JITOpGenerated0_kernel::@JITOpGenerated0_kernel blocks in (%0, %c1_0, %c1_0) threads in (%c1_0, %c1_0, %c1_0) args(%arg1 : memref<5xi64>, %alloc : memref<5xf32>)\n    %c1_2 = arith.constant 1 : index\n    %1 = affine.apply #map(%c5)[%c0, %c1]\n    gpu.launch_func  @JITOpGenerated0_kernel_0::@JITOpGenerated0_kernel blocks in (%1, %c1_2, %c1_2) threads in (%c1_2, %c1_2, %c1_2) args(%alloc : memref<5xf32>, %collapse_shape : memref<f32>, %arg2 : memref<5xf32>)\n    return\n  }\n  gpu.module @JITOpGenerated0_kernel {\n    gpu.func @JITOpGenerated0_kernel(%arg0: memref<5xi64>, %arg1: memref<5xf32>) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>} {\n      %0 = gpu.block_id  x\n      %1 = gpu.block_id  y\n      %2 = gpu.block_id  z\n      %3 = gpu.thread_id  x\n      %4 = gpu.thread_id  y\n      %5 = gpu.thread_id  z\n      %6 = gpu.grid_dim  x\n      %7 = gpu.grid_dim  y\n      %8 = gpu.grid_dim  z\n      %9 = gpu.block_dim  x\n      %10 = gpu.block_dim  y\n      %11 = gpu.block_dim  z\n      cf.br ^bb1\n    ^bb1:  // pred: ^bb0\n      %c1 = arith.constant 1 : index\n      %c0 = arith.constant 0 : index\n      %12 = affine.apply #map1(%0)[%c1, %c0]\n      %13 = memref.load %arg0[%12] : memref<5xi64>\n      %14 = arith.sitofp %13 : i64 to f32\n      memref.store %14, %arg1[%12] : memref<5xf32>\n      gpu.return\n    }\n  }\n  gpu.module @JITOpGenerated0_kernel_0 {\n    gpu.func @JITOpGenerated0_kernel(%arg0: memref<5xf32>, %arg1: memref<f32>, %arg2: memref<5xf32>) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>} {\n      %0 = gpu.block_id  x\n      %1 = gpu.block_id  y\n      %2 = gpu.block_id  z\n      %3 = gpu.thread_id  x\n      %4 = gpu.thread_id  y\n      %5 = gpu.thread_id  z\n      %6 = gpu.grid_dim  x\n      %7 = gpu.grid_dim  y\n      %8 = gpu.grid_dim  z\n      %9 = gpu.block_dim  x\n      %10 = gpu.block_dim  y\n      %11 = gpu.block_dim  z\n      cf.br ^bb1\n    ^bb1:  // pred: ^bb0\n      %c1 = arith.constant 1 : index\n      %c0 = arith.constant 0 : index\n      %12 = affine.apply #map1(%0)[%c1, %c0]\n      %13 = memref.load %arg0[%12] : memref<5xf32>\n      %14 = memref.load %arg1[] : memref<f32>\n      %15 = arith.mulf %13, %14 : f32\n      memref.store %15, %arg2[%12] : memref<5xf32>\n      gpu.return\n    }\n  }\n}\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/cuda_code_gen/test_fuser_cast_scale.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK: jit\n\nimport unittest\nimport numpy as np\n\nimport os\n\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass CastModule(flow.nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x, scale):\n        # TODO: also support scale as a scalar, for instance: scale = 7.7\n        return x.to(dtype=flow.float32) * scale\n\n\ndef do_relu_graph(test_case, data, with_cuda):\n    x = flow.tensor(data, dtype=flow.int64)\n    scale = flow.tensor([7.7], dtype=flow.float32)\n    if with_cuda:\n        x = x.cuda()\n        scale = scale.cuda()\n    module_to_run = CastModule()\n    y_eager = module_to_run(x, scale)\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.fw = module_to_run\n\n        def build(self, x, scale):\n            return self.fw(x, scale)\n\n    graph_to_run = GraphToRun()\n    y_lazy = graph_to_run(x, scale)\n    test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFuseCastScale(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_ENABLE_CODEGEN_FUSERS\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_FUSE_FORWARD_OPS\"] = \"1\"\n\n    def test_relu_graph(test_case):\n        import oneflow.sysconfig\n\n        if oneflow.sysconfig.with_cuda():\n            do_relu_graph(test_case, np.array([2.0, 1.0, 0.0, -1.0, -2.0]), True)\n        do_relu_graph(\n            test_case,\n            np.array([[2.0, 1.0, 0.0, -1.0, -2.0], [2.0, 1.0, 0.0, -1.0, -2.0]]),\n            False,\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/cuda_code_gen/test_gpu_all_reduce.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: | oneflow-opt -gpu-kernel-outlining \\\n// RUN: | oneflow-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin))' \\\n// RUN: | oneflow-opt -gpu-to-llvm \\\n// RUN: | oneflow-runner \\\n// RUN:   --shared-libs=%mlir_cuda_runtime \\\n// RUN:   --shared-libs=%mlir_runner_utils \\\n// RUN:   --entry-point-result=void \\\n// RUN: | FileCheck %s\n\nfunc.func @main() {\n  %data = memref.alloc() : memref<2x6xi32>\n  %sum = memref.alloc() : memref<2xi32>\n  %cst0 = arith.constant 0 : i32\n  %cst1 = arith.constant 1 : i32\n  %cst2 = arith.constant 2 : i32\n  %cst4 = arith.constant 4 : i32\n  %cst8 = arith.constant 8 : i32\n  %cst16 = arith.constant 16 : i32\n\n  %cst3 = arith.constant 3 : i32\n  %cst6 = arith.constant 6 : i32\n  %cst7 = arith.constant 7 : i32\n  %cst10 = arith.constant 10 : i32\n  %cst11 = arith.constant 11 : i32\n\n  %c0 = arith.constant 0 : index\n  %c1 = arith.constant 1 : index\n  %c2 = arith.constant 2 : index\n  %c3 = arith.constant 3 : index\n  %c4 = arith.constant 4 : index\n  %c5 = arith.constant 5 : index\n  %c6 = arith.constant 6 : index\n\n  %cast_data = memref.cast %data : memref<2x6xi32> to memref<*xi32>\n  gpu.host_register %cast_data : memref<*xi32>\n  %cast_sum = memref.cast %sum : memref<2xi32> to memref<*xi32>\n  gpu.host_register %cast_sum : memref<*xi32>\n\n  memref.store %cst0, %data[%c0, %c0] : memref<2x6xi32>\n  memref.store %cst1, %data[%c0, %c1] : memref<2x6xi32>\n  memref.store %cst2, %data[%c0, %c2] : memref<2x6xi32>\n  memref.store %cst4, %data[%c0, %c3] : memref<2x6xi32>\n  memref.store %cst8, %data[%c0, %c4] : memref<2x6xi32>\n  memref.store %cst16, %data[%c0, %c5] : memref<2x6xi32>\n\n  memref.store %cst2, %data[%c1, %c0] : memref<2x6xi32>\n  memref.store %cst3, %data[%c1, %c1] : memref<2x6xi32>\n  memref.store %cst6, %data[%c1, %c2] : memref<2x6xi32>\n  memref.store %cst7, %data[%c1, %c3] : memref<2x6xi32>\n  memref.store %cst10, %data[%c1, %c4] : memref<2x6xi32>\n  memref.store %cst11, %data[%c1, %c5] : memref<2x6xi32>\n\n  // MAX\n  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1)\n             threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) {\n    %val = memref.load %data[%bx, %tx] : memref<2x6xi32>\n    %reduced = gpu.all_reduce max %val uniform {} : (i32) -> (i32)\n    memref.store %reduced, %sum[%bx] : memref<2xi32>\n    gpu.terminator\n  }\n\n  call @printMemrefI32(%cast_sum) : (memref<*xi32>) -> ()\n  // CHECK: [16, 11]\n\n  return\n}\n\nfunc.func private @printMemrefI32(memref<*xi32>)\n\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/cuda_code_gen/test_insert_ofmempool.mlir",
    "content": "// RUN: oneflow-opt %s -insert-ofmempool | FileCheck %s\n\n#map = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>\nmodule attributes {gpu.container_module} {\n  // CHECK: func.func @JITOpGenerated0(%[[ARG0:[a-zA-Z0-9_]+]]: memref<512xi8>\n  func.func @JITOpGenerated0(%arg0: memref<1xf32>, %arg1: memref<5xi64>, %arg2: memref<5xf32>) attributes {llvm.emit_c_interface} {\n    %c1 = arith.constant 1 : index\n    %c5 = arith.constant 5 : index\n    %c0 = arith.constant 0 : index\n    // CHECK-NOT: memref.alloc() : memref<512xi8>\n    %alloc = memref.alloc() : memref<512xi8>\n    // CHECK: memref.view %[[ARG0]]\n    %view = memref.view %alloc[%c0][] : memref<512xi8> to memref<5xf32>\n    %collapse_shape = memref.collapse_shape %arg0 [] : memref<1xf32> into memref<f32>\n    gpu.launch_func  @JITOpGenerated0_kernel::@JITOpGenerated0_kernel blocks in (%c5, %c1, %c1) threads in (%c1, %c1, %c1) args(%arg1 : memref<5xi64>, %view : memref<5xf32>)\n    gpu.launch_func  @JITOpGenerated0_kernel_0::@JITOpGenerated0_kernel blocks in (%c5, %c1, %c1) threads in (%c1, %c1, %c1) args(%view : memref<5xf32>, %collapse_shape : memref<f32>, %arg2 : memref<5xf32>)\n    return\n  }\n  gpu.module @JITOpGenerated0_kernel {\n    gpu.func @JITOpGenerated0_kernel(%arg0: memref<5xi64>, %arg1: memref<5xf32>) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>} {\n      %c0 = arith.constant 0 : index\n      %c1 = arith.constant 1 : index\n      %0 = gpu.block_id  x\n      cf.br ^bb1\n    ^bb1:  // pred: ^bb0\n      %1 = affine.apply #map(%0)[%c1, %c0]\n      %2 = memref.load %arg0[%1] : memref<5xi64>\n      %3 = arith.sitofp %2 : i64 to f32\n      memref.store %3, %arg1[%1] : memref<5xf32>\n      gpu.return\n    }\n  }\n  gpu.module @JITOpGenerated0_kernel_0 {\n    gpu.func @JITOpGenerated0_kernel(%arg0: memref<5xf32>, %arg1: memref<f32>, %arg2: memref<5xf32>) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>} {\n      %c0 = arith.constant 0 : index\n      %c1 = arith.constant 1 : index\n      %0 = gpu.block_id  x\n      cf.br ^bb1\n    ^bb1:  // pred: ^bb0\n      %1 = affine.apply #map(%0)[%c1, %c0]\n      %2 = memref.load %arg0[%1] : memref<5xf32>\n      %3 = memref.load %arg1[] : memref<f32>\n      %4 = arith.mulf %2, %3 : f32\n      memref.store %4, %arg2[%1] : memref<5xf32>\n      gpu.return\n    }\n  }\n}"
  },
  {
    "path": "oneflow/ir/test/OneFlow/cuda_code_gen/test_matmul.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK: jit\n\nimport unittest\nimport numpy as np\n\nimport os\n\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass MatMulModule(flow.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.w = flow.nn.Parameter(flow.Tensor(5, 10))\n        self.b = flow.nn.Parameter(flow.Tensor(10))\n\n    def forward(self, x):\n        return flow.matmul(x, self.w) + self.b\n\n\ndef do_matmul_graph(test_case, with_cuda=False):\n    x = flow.randn(2, 5)\n    module_to_run = MatMulModule()\n    if with_cuda:\n        x = x.cuda()\n        module_to_run = module_to_run.to(\"cuda\")\n    y_eager = module_to_run(x)\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.fw = module_to_run\n\n        def build(self, x):\n            return self.fw(x)\n\n    graph_to_run = GraphToRun()\n    y_lazy = graph_to_run(x)\n    test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFuseCastScale(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_ENABLE_CODEGEN_FUSERS\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_FUSE_FORWARD_OPS\"] = \"1\"\n\n    def test_relu_graph(test_case):\n        import oneflow.sysconfig\n\n        if oneflow.sysconfig.with_cuda():\n            do_matmul_graph(test_case, True)\n\n        do_matmul_graph(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/cuda_code_gen/test_mgpu_to_oneflow_stream.mlir",
    "content": "// RUN: oneflow-opt %s -mgpu-to-ofstream\n\nmodule attributes {gpu.container_module} {\n  llvm.mlir.global internal constant @JITOpGenerated0_kernel_JITOpGenerated0_kernel_kernel_name(\"JITOpGenerated0_kernel\\00\") {addr_space = 0 : i32}\n  llvm.mlir.global internal constant @JITOpGenerated0_kernel_gpubin_cst(\"\\7FELF\\02\\01\\013\\07\\00\\00\\00\\00\\00\\00\\00\\02\\00\\BE\\00u\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\0A\\00\\00\\00\\00\\00\\00V\\05V\\00@\\00\\00\\00\\00\\00@\\00\\0C\\00\\01\\00\\00.shstrtab\\00.strtab\\00.symtab\\00.symtab_shndx\\00.nv.info\\00.text.JITOpGenerated0_kernel\\00.nv.info.JITOpGenerated0_kernel\\00.nv.shared.JITOpGenerated0_kernel\\00.nv.constant0.JITOpGenerated0_kernel\\00.rel.nv.constant0.JITOpGenerated0_kernel\\00.debug_frame\\00.rel.debug_frame\\00.rela.debug_frame\\00.nv.callgraph\\00.nv.prototype\\00.nv.rel.action\\00\\00.shstrtab\\00.strtab\\00.symtab\\00.symtab_shndx\\00.nv.info\\00JITOpGenerated0_kernel\\00.text.JITOpGenerated0_kernel\\00.nv.info.JITOpGenerated0_kernel\\00.nv.shared.JITOpGenerated0_kernel\\00.rel.nv.constant0.JITOpGenerated0_kernel\\00.nv.constant0.JITOpGenerated0_kernel\\00_param\\00.debug_frame\\00.rel.debug_frame\\00.rela.debug_frame\\00.nv.callgraph\\00.nv.prototype\\00.nv.rel.action\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00I\\00\\00\\00\\03\\00\\0B\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\D1\\00\\00\\00\\03\\00\\0A\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\FD\\00\\00\\00\\03\\00\\04\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00-\\01\\00\\00\\03\\00\\07\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00I\\01\\00\\00\\03\\00\\08\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\002\\00\\00\\00\\12\\10\\0B\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\02\\00\\00\\00\\00\\00\\00\\FF\\FF\\FF\\FF(\\00\\00\\00\\00\\00\\00\\00\\FF\\FF\\FF\\FF\\FF\\FF\\FF\\FF\\03\\00\\04|\\FF\\FF\\FF\\FF\\0F\\0C\\81\\80\\80(\\00\\08\\FF\\81\\80(\\08\\81\\80\\80(\\00\\00\\00\\00\\00\\00\\00\\FF\\FF\\FF\\FF0\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\F0\\01\\00\\00\\00\\00\\00\\00\\04\\04\\00\\00\\00\\04<\\00\\00\\00\\0C\\81\\80\\80(\\00\\04\\FC\\FF\\FF?\\00\\00\\00\\04\\11\\08\\00\\06\\00\\00\\00\\00\\00\\00\\00\\04/\\08\\00\\06\\00\\00\\00\\0E\\00\\00\\00\\04\\12\\08\\00\\06\\00\\00\\00\\00\\00\\00\\00\\04\\1C\\04\\00\\F0\\00\\00\\00\\03\\1B\\FF\\00\\04\\17\\0C\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\F0!\\00\\04\\17\\0C\\00\\00\\00\\00\\00\\01\\00\\08\\00\\00\\F0!\\00\\04\\17\\0C\\00\\00\\00\\00\\00\\02\\00\\10\\00\\00\\F0!\\00\\04\\17\\0C\\00\\00\\00\\00\\00\\03\\00\\18\\00\\00\\F0!\\00\\04\\17\\0C\\00\\00\\00\\00\\00\\04\\00 
\\00\\00\\F0!\\00\\04\\17\\0C\\00\\00\\00\\00\\00\\05\\00(\\00\\00\\F0!\\00\\04\\17\\0C\\00\\00\\00\\00\\00\\06\\000\\00\\00\\F0!\\00\\04\\17\\0C\\00\\00\\00\\00\\00\\07\\008\\00\\00\\F0!\\00\\04\\17\\0C\\00\\00\\00\\00\\00\\08\\00@\\00\\00\\F0!\\00\\04\\17\\0C\\00\\00\\00\\00\\00\\09\\00H\\00\\00\\F0!\\00\\04\\17\\0C\\00\\00\\00\\00\\00\\0A\\00P\\00\\00\\F0!\\00\\04\\17\\0C\\00\\00\\00\\00\\00\\0B\\00X\\00\\00\\F0!\\00\\04\\17\\0C\\00\\00\\00\\00\\00\\0C\\00`\\00\\00\\F0!\\00\\03\\19h\\00\\04\\0A\\08\\00\\02\\00\\00\\00`\\01h\\00\\015\\00\\00\\047\\04\\00u\\00\\00\\00\\00\\00\\00\\00\\FF\\FF\\FF\\FF\\00\\00\\00\\00\\FE\\FF\\FF\\FF\\00\\00\\00\\00\\FD\\FF\\FF\\FF\\00\\00\\00\\00K\\00\\00\\00\\00\\00\\00\\00\\00\\02\\02\\08\\10\\0A/\\22\\00\\00\\00\\08\\00\\00\\00\\00\\00\\00\\08\\08\\00\\00\\00\\00\\00\\00\\10\\08\\00\\00\\00\\00\\00\\00\\18\\08\\00\\00\\00\\00\\00\\00 \\08\\00\\00\\00\\00\\00\\00(\\08\\00\\00\\00\\00\\00\\000\\08\\00\\00\\00\\00\\00\\008\\08\\00\\00\\00\\00\\01\\00\\00\\08\\00\\00\\00\\00\\01\\00\\08\\08\\00\\00\\00\\00\\01\\00\\10\\08\\00\\00\\00\\00\\01\\00\\18\\08\\00\\00\\00\\00\\01\\00 \\08\\00\\00\\00\\00\\01\\00(\\08\\00\\00\\00\\00\\01\\000\\08\\00\\00\\00\\00\\01\\008\\08\\00\\00\\00\\00\\02\\00\\00\\08\\00\\00\\00\\00\\02\\00\\08\\08\\00\\00\\00\\00\\02\\00\\10\\08\\00\\00\\00\\00\\02\\00\\18\\08\\00\\00\\00\\00\\02\\00 \\08\\00\\00\\00\\00\\02\\00(\\08\\00\\00\\00\\00\\02\\000\\08\\00\\00\\00\\00\\02\\008\\08\\00\\00\\00\\00\\00\\00\\00\\14,\\00\\00\\00\\09\\00\\00\\0C\\00\\00\\00\\00H\\00\\00\\00\\00\\00\\00\\00\\02\\00\\00\\00\\06\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\02
z\\01\\00\\00\\0A\\00\\00\\00\\0F\\00\\00\\00\\C4\\0F\\00\\19y\\06\\00\\00\\00\\00\\00\\00%\\00\\00\\00\\22\\0E\\00\\02x\\03\\00\\08\\00\\00\\00\\00\\0F\\00\\00\\00\\E2\\0F\\00\\B9z\\04\\00\\00F\\00\\00\\00\\0A\\00\\00\\00\\E2\\0F\\00\\02z\\04\\00\\00d\\00\\00\\00\\0F\\00\\00\\00\\E4\\0F\\00\\02z\\05\\00\\00e\\00\\00\\00\\0F\\00\\00\\00\\CA\\0F\\00\\80y\\04\\04\\04\\00\\00\\00\\00\\19\\10\\0C\\00\\A2\\0E\\00%v\\02\\06\\00Z\\00\\00\\03\\02\\8E\\07\\00\\CA\\1F\\00\\80y\\08\\02\\04\\00\\00\\00\\00\\19\\10\\0C\\00\\E8\\0E\\00\\80y\\09\\02\\04\\04\\00\\00\\00\\19\\10\\0C\\00\\E2\\0E\\00\\02x\\07\\00\\04\\00\\00\\00\\00\\0F\\00\\00\\00\\CA\\0F\\00%v\\06\\06\\00j\\00\\00\\07\\02\\8E\\07\\00\\E2\\0F\\00\\12s\\09\\00\\08\\00\\00\\00\\00\\140\\00\\00\\A4\\8E\\00 r\\0B\\04\\09\\00\\00\\00\\00\\00@\\00\\00\\CAO\\00\\85y\\00\\06\\0B\\00\\00\\00\\04\\19\\10\\0C\\00\\E2\\0F\\00My\\00\\00\\00\\00\\00\\00\\00\\00\\80\\03\\00\\EA\\0F\\00Gy\\00\\00\\F0\\FF\\FF\\FF\\FF\\FF\\83\\03\\00\\C0\\0F\\00\\18y\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\C0\\0F\\00\\18y\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\C0\\0F\\00\\18y\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\C0\\0F\\00\\18y\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\C0\\0F\\00\\18y\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\C0\\0F\\00\\18y\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\C0\\0F\\00\\18y\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\C0\\0F\\00\\18y\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\C0\\0F\\00\\18y\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\C0\\0F\\00\\18y\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\C0\\0F\\00\\18y\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\C0\\0F\\00\\18y\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\C0\\0F\\00\\18y\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\C0\\0F\\00\\18y\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\C0\\0F\\00\\18y\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\C0\\0F\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00\\03\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00@\\00\\00\\00\\00\\00\\00\\00:\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\0B\\00\\00\\00\\03\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00z\\01\\00\\00\\00\\00\\00\\00X\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\13\\00\\00\\00\\02\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\D8\\02\\00\\00\\00\\00\\00\\00\\A8\\00\\00\\00\\00\\00\\00\\00\\02\\00\\00\\00\\06\\00\\00\\00\\08\\00\\00\\00\\00\\00\\00\\00\\18\\00\\00\\00\\00\\00\\00\\00\\DF\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\80\\03\\00\\00\\00\\00\\00\\00p\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00)\\00\\00\\00\\00\\00\\00p\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\F0\\03\\00\\00\\00\\00\\00\\00$\\00\\00\\00\\00\\00\\00\\00\\03\\00\\00\\00\\00\\00\\00\\00\\04\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00O\\00\\00\\00\\00\\00\\00p\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\14\\04\\00\\00\\00\\00\\00\\00\\F8\\00\\00\\0
0\\00\\00\\00\\00\\03\\00\\00\\00\\0B\\00\\00\\00\\04\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\0F\\01\\00\\00\\01\\00\\00p\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\0C\\05\\00\\00\\00\\00\\00\\00\\18\\00\\00\\00\\00\\00\\00\\00\\03\\00\\00\\00\\00\\00\\00\\00\\04\\00\\00\\00\\00\\00\\00\\00\\08\\00\\00\\00\\00\\00\\00\\00+\\01\\00\\00\\0B\\00\\00p\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00(\\05\\00\\00\\00\\00\\00\\00\\E0\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\08\\00\\00\\00\\00\\00\\00\\00\\08\\00\\00\\00\\00\\00\\00\\00\\EC\\00\\00\\00\\09\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\08\\06\\00\\00\\00\\00\\00\\00\\10\\00\\00\\00\\00\\00\\00\\00\\03\\00\\00\\00\\04\\00\\00\\00\\08\\00\\00\\00\\00\\00\\00\\00\\10\\00\\00\\00\\00\\00\\00\\00\\91\\00\\00\\00\\01\\00\\00\\00\\02\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\18\\06\\00\\00\\00\\00\\00\\00\\C8\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\0B\\00\\00\\00\\04\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\002\\00\\00\\00\\01\\00\\00\\00\\06\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\08\\00\\00\\00\\00\\00\\00\\00\\02\\00\\00\\00\\00\\00\\00\\03\\00\\00\\00\\06\\00\\00\\0E\\80\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\") {addr_space = 0 : i32}\n  llvm.func @JITOpGenerated0(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr<i64>, %arg6: !llvm.ptr<i64>, %arg7: i64, %arg8: i64, %arg9: i64, %arg10: !llvm.ptr<f32>, %arg11: !llvm.ptr<f32>, %arg12: i64, %arg13: i64, %arg14: i64, %arg15: !llvm.ptr<i8>) attributes {llvm.emit_c_interface} {\n    %0 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>\n    %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %6 = builtin.unrealized_conversion_cast %5 : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> to memref<1xf32>\n    %7 = llvm.mlir.undef : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>\n    %8 = llvm.insertvalue %arg5, %7[0] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n    %9 = llvm.insertvalue %arg6, %8[1] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n    %10 = llvm.insertvalue %arg7, %9[2] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n    %11 = llvm.insertvalue %arg8, %10[3, 0] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n    %12 = llvm.insertvalue %arg9, %11[4, 0] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n    %13 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>\n    %14 = llvm.insertvalue %arg10, %13[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %15 = llvm.insertvalue %arg11, %14[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, 
array<1 x i64>, array<1 x i64>)> \n    %16 = llvm.insertvalue %arg12, %15[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %17 = llvm.insertvalue %arg13, %16[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %18 = llvm.insertvalue %arg14, %17[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %19 = llvm.mlir.constant(5 : index) : i64\n    %20 = llvm.mlir.constant(1 : index) : i64\n    %collapse_shape = memref.collapse_shape %6 [] : memref<1xf32> into memref<f32>\n    %21 = builtin.unrealized_conversion_cast %collapse_shape : memref<f32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64)>\n    %22 = llvm.mlir.addressof @JITOpGenerated0_kernel_gpubin_cst : !llvm.ptr<array<3328 x i8>>\n    %23 = llvm.getelementptr %22[0, 0] : (!llvm.ptr<array<3328 x i8>>) -> !llvm.ptr<i8>\n    %24 = llvm.call @mgpuModuleLoad(%23) : (!llvm.ptr<i8>) -> !llvm.ptr<i8>\n    %25 = llvm.mlir.addressof @JITOpGenerated0_kernel_JITOpGenerated0_kernel_kernel_name : !llvm.ptr<array<23 x i8>>\n    %26 = llvm.getelementptr %25[0, 0] : (!llvm.ptr<array<23 x i8>>) -> !llvm.ptr<i8>\n    %27 = llvm.call @mgpuModuleGetFunction(%24, %26) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> !llvm.ptr<i8>\n    %28 = llvm.mlir.constant(0 : i32) : i32\n    // CHECK-NOT: mgpuStreamCreate\n    %29 = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr<i8>\n    %30 = llvm.extractvalue %12[0] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n    %31 = llvm.extractvalue %12[1] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n    %32 = llvm.extractvalue %12[2] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n    %33 = llvm.extractvalue %12[3, 0] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n    %34 = llvm.extractvalue %12[4, 0] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n    %35 = llvm.extractvalue %21[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)> \n    %36 = llvm.extractvalue %21[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)> \n    %37 = llvm.extractvalue %21[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)> \n    %38 = llvm.extractvalue %18[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %39 = llvm.extractvalue %18[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %40 = llvm.extractvalue %18[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %41 = llvm.extractvalue %18[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %42 = llvm.extractvalue %18[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %43 = llvm.mlir.constant(1 : i32) : i32\n    %44 = llvm.alloca %43 x !llvm.struct<\"\", (ptr<i64>, ptr<i64>, i64, i64, i64, ptr<f32>, ptr<f32>, i64, ptr<f32>, ptr<f32>, i64, i64, i64)> : (i32) -> !llvm.ptr<struct<\"\", (ptr<i64>, ptr<i64>, i64, i64, i64, ptr<f32>, ptr<f32>, i64, ptr<f32>, ptr<f32>, i64, i64, i64)>>\n    %45 = llvm.mlir.constant(13 : i32) : i32\n    %46 = llvm.alloca %45 x !llvm.ptr<i8> : (i32) -> !llvm.ptr<ptr<i8>>\n    %47 = llvm.getelementptr %44[0, 0] : (!llvm.ptr<struct<\"\", (ptr<i64>, ptr<i64>, i64, i64, i64, ptr<f32>, ptr<f32>, i64, ptr<f32>, ptr<f32>, i64, i64, i64)>>) -> !llvm.ptr<ptr<i64>>\n    llvm.store %30, %47 : !llvm.ptr<ptr<i64>>\n    %48 = llvm.getelementptr %46[0] : (!llvm.ptr<ptr<i8>>) -> !llvm.ptr<ptr<i8>>\n    %49 = llvm.bitcast %47 : 
!llvm.ptr<ptr<i64>> to !llvm.ptr<i8>\n    llvm.store %49, %48 : !llvm.ptr<ptr<i8>>\n    %50 = llvm.getelementptr %44[0, 1] : (!llvm.ptr<struct<\"\", (ptr<i64>, ptr<i64>, i64, i64, i64, ptr<f32>, ptr<f32>, i64, ptr<f32>, ptr<f32>, i64, i64, i64)>>) -> !llvm.ptr<ptr<i64>>\n    llvm.store %31, %50 : !llvm.ptr<ptr<i64>>\n    %51 = llvm.getelementptr %46[1] : (!llvm.ptr<ptr<i8>>) -> !llvm.ptr<ptr<i8>>\n    %52 = llvm.bitcast %50 : !llvm.ptr<ptr<i64>> to !llvm.ptr<i8>\n    llvm.store %52, %51 : !llvm.ptr<ptr<i8>>\n    %53 = llvm.getelementptr %44[0, 2] : (!llvm.ptr<struct<\"\", (ptr<i64>, ptr<i64>, i64, i64, i64, ptr<f32>, ptr<f32>, i64, ptr<f32>, ptr<f32>, i64, i64, i64)>>) -> !llvm.ptr<i64>\n    llvm.store %32, %53 : !llvm.ptr<i64>\n    %54 = llvm.getelementptr %46[2] : (!llvm.ptr<ptr<i8>>) -> !llvm.ptr<ptr<i8>>\n    %55 = llvm.bitcast %53 : !llvm.ptr<i64> to !llvm.ptr<i8>\n    llvm.store %55, %54 : !llvm.ptr<ptr<i8>>\n    %56 = llvm.getelementptr %44[0, 3] : (!llvm.ptr<struct<\"\", (ptr<i64>, ptr<i64>, i64, i64, i64, ptr<f32>, ptr<f32>, i64, ptr<f32>, ptr<f32>, i64, i64, i64)>>) -> !llvm.ptr<i64>\n    llvm.store %33, %56 : !llvm.ptr<i64>\n    %57 = llvm.getelementptr %46[3] : (!llvm.ptr<ptr<i8>>) -> !llvm.ptr<ptr<i8>>\n    %58 = llvm.bitcast %56 : !llvm.ptr<i64> to !llvm.ptr<i8>\n    llvm.store %58, %57 : !llvm.ptr<ptr<i8>>\n    %59 = llvm.getelementptr %44[0, 4] : (!llvm.ptr<struct<\"\", (ptr<i64>, ptr<i64>, i64, i64, i64, ptr<f32>, ptr<f32>, i64, ptr<f32>, ptr<f32>, i64, i64, i64)>>) -> !llvm.ptr<i64>\n    llvm.store %34, %59 : !llvm.ptr<i64>\n    %60 = llvm.getelementptr %46[4] : (!llvm.ptr<ptr<i8>>) -> !llvm.ptr<ptr<i8>>\n    %61 = llvm.bitcast %59 : !llvm.ptr<i64> to !llvm.ptr<i8>\n    llvm.store %61, %60 : !llvm.ptr<ptr<i8>>\n    %62 = llvm.getelementptr %44[0, 5] : (!llvm.ptr<struct<\"\", (ptr<i64>, ptr<i64>, i64, i64, i64, ptr<f32>, ptr<f32>, i64, ptr<f32>, ptr<f32>, i64, i64, i64)>>) -> !llvm.ptr<ptr<f32>>\n    llvm.store %35, %62 : !llvm.ptr<ptr<f32>>\n    %63 = llvm.getelementptr %46[5] : (!llvm.ptr<ptr<i8>>) -> !llvm.ptr<ptr<i8>>\n    %64 = llvm.bitcast %62 : !llvm.ptr<ptr<f32>> to !llvm.ptr<i8>\n    llvm.store %64, %63 : !llvm.ptr<ptr<i8>>\n    %65 = llvm.getelementptr %44[0, 6] : (!llvm.ptr<struct<\"\", (ptr<i64>, ptr<i64>, i64, i64, i64, ptr<f32>, ptr<f32>, i64, ptr<f32>, ptr<f32>, i64, i64, i64)>>) -> !llvm.ptr<ptr<f32>>\n    llvm.store %36, %65 : !llvm.ptr<ptr<f32>>\n    %66 = llvm.getelementptr %46[6] : (!llvm.ptr<ptr<i8>>) -> !llvm.ptr<ptr<i8>>\n    %67 = llvm.bitcast %65 : !llvm.ptr<ptr<f32>> to !llvm.ptr<i8>\n    llvm.store %67, %66 : !llvm.ptr<ptr<i8>>\n    %68 = llvm.getelementptr %44[0, 7] : (!llvm.ptr<struct<\"\", (ptr<i64>, ptr<i64>, i64, i64, i64, ptr<f32>, ptr<f32>, i64, ptr<f32>, ptr<f32>, i64, i64, i64)>>) -> !llvm.ptr<i64>\n    llvm.store %37, %68 : !llvm.ptr<i64>\n    %69 = llvm.getelementptr %46[7] : (!llvm.ptr<ptr<i8>>) -> !llvm.ptr<ptr<i8>>\n    %70 = llvm.bitcast %68 : !llvm.ptr<i64> to !llvm.ptr<i8>\n    llvm.store %70, %69 : !llvm.ptr<ptr<i8>>\n    %71 = llvm.getelementptr %44[0, 8] : (!llvm.ptr<struct<\"\", (ptr<i64>, ptr<i64>, i64, i64, i64, ptr<f32>, ptr<f32>, i64, ptr<f32>, ptr<f32>, i64, i64, i64)>>) -> !llvm.ptr<ptr<f32>>\n    llvm.store %38, %71 : !llvm.ptr<ptr<f32>>\n    %72 = llvm.getelementptr %46[8] : (!llvm.ptr<ptr<i8>>) -> !llvm.ptr<ptr<i8>>\n    %73 = llvm.bitcast %71 : !llvm.ptr<ptr<f32>> to !llvm.ptr<i8>\n    llvm.store %73, %72 : !llvm.ptr<ptr<i8>>\n    %74 = llvm.getelementptr %44[0, 9] : (!llvm.ptr<struct<\"\", (ptr<i64>, ptr<i64>, i64, 
i64, i64, ptr<f32>, ptr<f32>, i64, ptr<f32>, ptr<f32>, i64, i64, i64)>>) -> !llvm.ptr<ptr<f32>>\n    llvm.store %39, %74 : !llvm.ptr<ptr<f32>>\n    %75 = llvm.getelementptr %46[9] : (!llvm.ptr<ptr<i8>>) -> !llvm.ptr<ptr<i8>>\n    %76 = llvm.bitcast %74 : !llvm.ptr<ptr<f32>> to !llvm.ptr<i8>\n    llvm.store %76, %75 : !llvm.ptr<ptr<i8>>\n    %77 = llvm.getelementptr %44[0, 10] : (!llvm.ptr<struct<\"\", (ptr<i64>, ptr<i64>, i64, i64, i64, ptr<f32>, ptr<f32>, i64, ptr<f32>, ptr<f32>, i64, i64, i64)>>) -> !llvm.ptr<i64>\n    llvm.store %40, %77 : !llvm.ptr<i64>\n    %78 = llvm.getelementptr %46[10] : (!llvm.ptr<ptr<i8>>) -> !llvm.ptr<ptr<i8>>\n    %79 = llvm.bitcast %77 : !llvm.ptr<i64> to !llvm.ptr<i8>\n    llvm.store %79, %78 : !llvm.ptr<ptr<i8>>\n    %80 = llvm.getelementptr %44[0, 11] : (!llvm.ptr<struct<\"\", (ptr<i64>, ptr<i64>, i64, i64, i64, ptr<f32>, ptr<f32>, i64, ptr<f32>, ptr<f32>, i64, i64, i64)>>) -> !llvm.ptr<i64>\n    llvm.store %41, %80 : !llvm.ptr<i64>\n    %81 = llvm.getelementptr %46[11] : (!llvm.ptr<ptr<i8>>) -> !llvm.ptr<ptr<i8>>\n    %82 = llvm.bitcast %80 : !llvm.ptr<i64> to !llvm.ptr<i8>\n    llvm.store %82, %81 : !llvm.ptr<ptr<i8>>\n    %83 = llvm.getelementptr %44[0, 12] : (!llvm.ptr<struct<\"\", (ptr<i64>, ptr<i64>, i64, i64, i64, ptr<f32>, ptr<f32>, i64, ptr<f32>, ptr<f32>, i64, i64, i64)>>) -> !llvm.ptr<i64>\n    llvm.store %42, %83 : !llvm.ptr<i64>\n    %84 = llvm.getelementptr %46[12] : (!llvm.ptr<ptr<i8>>) -> !llvm.ptr<ptr<i8>>\n    %85 = llvm.bitcast %83 : !llvm.ptr<i64> to !llvm.ptr<i8>\n    llvm.store %85, %84 : !llvm.ptr<ptr<i8>>\n    %86 = llvm.mlir.null : !llvm.ptr<ptr<i8>>\n    // CHECK-NOT: mgpuLaunchKernel(%18, %4, %3, %3, %3, %3, %3, %2, %arg15, %23, %62)\n    llvm.call @mgpuLaunchKernel(%27, %19, %20, %20, %20, %20, %20, %28, %29, %46, %86) : (!llvm.ptr<i8>, i64, i64, i64, i64, i64, i64, i32, !llvm.ptr<i8>, !llvm.ptr<ptr<i8>>, !llvm.ptr<ptr<i8>>) -> ()\n    llvm.call @mgpuStreamSynchronize(%29) : (!llvm.ptr<i8>) -> ()\n    // CHECK-NOT: mgpuStreamDestroy\n    llvm.call @mgpuStreamDestroy(%29) : (!llvm.ptr<i8>) -> ()\n    llvm.call @mgpuModuleUnload(%24) : (!llvm.ptr<i8>) -> ()\n    llvm.return\n  }\n  llvm.func @_mlir_ciface_JITOpGenerated0(%arg0: !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>, %arg1: !llvm.ptr<struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>>, %arg2: !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>, %arg3: !llvm.ptr<i8>) attributes {llvm.emit_c_interface} {\n    %0 = llvm.load %arg0 : !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>\n    %1 = llvm.extractvalue %0[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %2 = llvm.extractvalue %0[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %3 = llvm.extractvalue %0[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %4 = llvm.extractvalue %0[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %5 = llvm.extractvalue %0[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %6 = llvm.load %arg1 : !llvm.ptr<struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>>\n    %7 = llvm.extractvalue %6[0] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n    %8 = llvm.extractvalue %6[1] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n    %9 = llvm.extractvalue %6[2] : 
!llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n    %10 = llvm.extractvalue %6[3, 0] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n    %11 = llvm.extractvalue %6[4, 0] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)> \n    %12 = llvm.load %arg2 : !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>\n    %13 = llvm.extractvalue %12[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %14 = llvm.extractvalue %12[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %15 = llvm.extractvalue %12[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %16 = llvm.extractvalue %12[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    %17 = llvm.extractvalue %12[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> \n    llvm.call @JITOpGenerated0(%1, %2, %3, %4, %5, %7, %8, %9, %10, %11, %13, %14, %15, %16, %17, %arg3) : (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64, !llvm.ptr<i64>, !llvm.ptr<i64>, i64, i64, i64, !llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64, !llvm.ptr<i8>) -> ()\n    llvm.return\n  }\n  llvm.func @mgpuModuleLoad(!llvm.ptr<i8>) -> !llvm.ptr<i8>\n  llvm.func @mgpuModuleGetFunction(!llvm.ptr<i8>, !llvm.ptr<i8>) -> !llvm.ptr<i8>\n  llvm.func @mgpuStreamCreate() -> !llvm.ptr<i8>\n  llvm.func @mgpuLaunchKernel(!llvm.ptr<i8>, i64, i64, i64, i64, i64, i64, i32, !llvm.ptr<i8>, !llvm.ptr<ptr<i8>>, !llvm.ptr<ptr<i8>>)\n  llvm.func @mgpuStreamSynchronize(!llvm.ptr<i8>)\n  llvm.func @mgpuStreamDestroy(!llvm.ptr<i8>)\n  llvm.func @mgpuModuleUnload(!llvm.ptr<i8>)\n}"
  },
  {
    "path": "oneflow/ir/test/OneFlow/cuda_code_gen/tosa_to_linalg.mlir",
    "content": "// RUN: oneflow-opt %s -ofjob-to-func --tosa-make-broadcastable \\\n// RUN: | oneflow-opt -pass-pipeline=\"builtin.module(oneflow.job(tosa-to-linalg))\" \\\n// RUN: | oneflow-opt -func-to-ofjob\n\noneflow.job @GraphToRun_1(%arg0: tensor<2x5xi64>, %arg1: tensor<1xf32>) -> tensor<2x5xf32> {\n    %2 = \"tosa.cast\"(%arg0) : (tensor<2x5xi64>) -> tensor<2x5xf32>\n    %3 = \"tosa.mul\"(%2, %arg1) {shift = 0 : i32} : (tensor<2x5xf32>, tensor<1xf32>) -> tensor<2x5xf32>\n    oneflow.return %3 : tensor<2x5xf32>\n}\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/folding/test_conv_bn.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK-NOT: oneflow.normalization\n\nimport os\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.nn as nn\nfrom flowvision.models.resnet import resnet50\n\n\ndef _test_fuse_conv_bn(test_case):\n    data = flow.randn(1, 3, 224, 224)\n\n    model = resnet50(pretrained=False, progress=True)\n    model.eval()\n    eager_res = model(data)\n\n    class Resnet50Graph(nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.model = model\n\n        def build(self, *input):\n            return self.model(*input)\n\n    graph = Resnet50Graph()\n    lazy_res = graph(data)\n\n    test_case.assertTrue(\n        np.allclose(eager_res.numpy(), lazy_res.numpy(), rtol=1e-2, atol=1e-2)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFuseConvBn(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION\"] = \"1\"\n\n    def test_fuse_conv_bn(test_case):\n        _test_fuse_conv_bn(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/folding/test_simple_multiply.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK-NOT: oneflow.broadcast_mul\n\nimport os\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.nn as nn\n\n\nclass MultiplyModel(nn.Module):\n    def __init__(self, dtype=flow.float32):\n        super().__init__()\n        self.dtype = dtype\n        self.x = nn.Parameter(flow.tensor([2, 2], dtype=self.dtype), False)\n        self.y = nn.Parameter(flow.tensor([3, 3], dtype=self.dtype), False)\n\n    def forward(self):\n        return self.x * self.y\n\n\nclass MultiplyModelComplex(MultiplyModel):\n    def __init__(self, dtype=flow.float32):\n        super().__init__(dtype)\n        self.z = nn.Parameter(flow.tensor([4, 5], dtype=self.dtype), False)\n\n    def forward(self):\n        return self.x * self.y * self.z\n\n\nclass MultiplyModelWithInput(MultiplyModel):\n    def __init__(self, dtype=flow.float32):\n        super().__init__(dtype)\n\n    def forward(self, a: flow.Tensor, b: flow.Tensor):\n        z = self.x * self.y\n        return a + b + z\n\n\ndef _test_fold_multiply(test_case, module, with_cuda, *args, dtype=oneflow.float32):\n    model = module(dtype)\n\n    if with_cuda:\n        model.to(\"cuda\")\n    model.eval()\n    eager_res = model(*args)\n\n    class MultiplyGraph(nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.model = model\n\n        def build(self, *args):\n            return self.model(*args)\n\n    graph = MultiplyGraph()\n    lazy_res = graph(*args)\n\n    test_case.assertTrue(\n        np.allclose(eager_res.numpy(), lazy_res.numpy(), rtol=1e-5, atol=1e-5)\n    )\n    test_case.assertTrue(eager_res.dtype == dtype and lazy_res.dtype == dtype)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFoldMultiply(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION\"] = \"1\"\n\n    def test_fold_multiply(test_case):\n        _test_fold_multiply(test_case, MultiplyModel, with_cuda=False)\n        _test_fold_multiply(\n            test_case, MultiplyModel, with_cuda=False, dtype=flow.float16\n        )\n\n    @unittest.skipUnless(oneflow.sysconfig.with_cuda(), \"only test cpu cases\")\n    def test_fold_multiply_cuda(test_case):\n        _test_fold_multiply(test_case, MultiplyModel, with_cuda=True)\n        _test_fold_multiply(\n            test_case, MultiplyModel, with_cuda=True, dtype=flow.float16\n        )\n\n    def test_fold_multiply_complex(test_case):\n        _test_fold_multiply(test_case, MultiplyModelComplex, with_cuda=False)\n        _test_fold_multiply(\n            test_case, MultiplyModelComplex, with_cuda=False, dtype=flow.float16\n        )\n\n    @unittest.skipUnless(oneflow.sysconfig.with_cuda(), \"only test cpu 
cases\")\n    def test_fold_multiply_complex_cuda(test_case):\n        _test_fold_multiply(test_case, MultiplyModelComplex, with_cuda=True)\n        _test_fold_multiply(\n            test_case, MultiplyModelComplex, with_cuda=True, dtype=flow.float16\n        )\n\n    def test_fold_multiply_with_input(test_case):\n        a = flow.tensor([3, 7], dtype=flow.float32)\n        b = flow.tensor([9, -1], dtype=flow.float32)\n        a_fp16 = flow.tensor([3, 7], dtype=flow.float16)\n        b_fp16 = flow.tensor([9, -1], dtype=flow.float16)\n        _test_fold_multiply(test_case, MultiplyModelWithInput, False, a, b)\n        _test_fold_multiply(\n            test_case, MultiplyModelWithInput, False, a_fp16, b_fp16, dtype=flow.float16\n        )\n\n    @unittest.skipUnless(oneflow.sysconfig.with_cuda(), \"only test cpu cases\")\n    def test_fold_multiply_with_input_cuda(test_case):\n        a = flow.tensor([3, 7], dtype=flow.float32, device=\"cuda\")\n        b = flow.tensor([9, -1], dtype=flow.float32, device=\"cuda\")\n        a_fp16 = flow.tensor([3, 7], dtype=flow.float16, device=\"cuda\")\n        b_fp16 = flow.tensor([9, -1], dtype=flow.float16, device=\"cuda\")\n        _test_fold_multiply(test_case, MultiplyModelWithInput, True, a, b)\n        _test_fold_multiply(\n            test_case, MultiplyModelWithInput, True, a_fp16, b_fp16, dtype=flow.float16\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/fuse/fuse_forward_ops.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -fuse-forward-only-ops -fuse-into-existing-op -fuse-normalization-ops -convert-inference-op -fuse-ops-with-backward-impl -canonicalize | FileCheck %s\n\nmodule  {\n  func.func @Cast_1__FUSE__ScalarMulByTensor_2(%685: tensor<2x64x64x320xf16>, %output_574: tensor<320xf16>, %output_573: tensor<320xf16>) -> tensor<2x64x64x320xf16> {\n    %y_958, %mean_959, %inv_variance_960 = \"oneflow.group_norm\"(%685, %output_574, %output_573) {activation = \"none\", affine = true, data_format = \"channels_last\", device_name = [\"@0:0\"], device_tag = \"cuda\", epsilon = 1.000000e-05 : f64, hierarchy = [1], num_groups = 32 : si32, op_name = \"unet.up_blocks.3.resnets.0.norm2-group_norm-877\", operand_segment_sizes = array<i32: 1, 1, 1>, scope_symbol_id = 5517 : i64} : (tensor<2x64x64x320xf16>, tensor<320xf16>, tensor<320xf16>) -> (tensor<2x64x64x320xf16>, tensor<2x32xf32>, tensor<2x32xf32>)\n    %686 = \"oneflow.silu\"(%y_958) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.up_blocks.3.resnets.0.nonlinearity-silu-878\", scope_symbol_id = 5466 : i64} : (tensor<2x64x64x320xf16>) -> tensor<2x64x64x320xf16>\n    // CHECK: activation = \"silu\"\n    // CHECK-NOT: oneflow.silu\n    return %686 : tensor<2x64x64x320xf16>\n  }\n\n  func.func @GraphToRun_bias_add_and_dropout_0(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<5xf32>) -> (tensor<2x3x4x5xf32>, tensor<2x3x4x5xi8>) {\n    %0 = \"oneflow.bias_add\"(%arg0, %arg1) {axis = 3 : si32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"bias_add-0\", scope_symbol_id = 12 : i64} : (tensor<2x3x4x5xf32>, tensor<5xf32>) -> tensor<2x3x4x5xf32>\n    %out, %mask = \"oneflow.dropout\"(%0) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"dropout-dropout-1\", rate = 0.750000e+00 : f32, scope_symbol_id = 22 : i64} : (tensor<2x3x4x5xf32>) -> (tensor<2x3x4x5xf32>, tensor<2x3x4x5xi8>)\n    // CHECK: func.func @GraphToRun_bias_add_and_dropout_0(%[[A:[a-zA-Z0-9_]+]]: tensor<2x3x4x5xf32>, %[[B:[a-zA-Z0-9_]+]]: tensor<5xf32>) -> (tensor<2x3x4x5xf32>, tensor<2x3x4x5xi8>)\n    // CHECK: %[[MASK:[a-zA-Z0-9_]+]] = \"oneflow.random_mask_like\"(%[[A]])\n    // CHECK: \"oneflow.fused_bias_add_mask_scale\"(%[[A]], %[[B]], %[[MASK]])\n    // CHECK: scale = 4.000000e+00\n    return %out, %mask : tensor<2x3x4x5xf32>, tensor<2x3x4x5xi8>\n  }\n\n  func.func @GraphToRun_bias_add_and_gelu_0(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<5xf32>) -> tensor<2x3x4x5xf32> {\n    %0 = \"oneflow.bias_add\"(%arg0, %arg1) {axis = 3 : si32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"bias_add-0\", scope_symbol_id = 12 : i64} : (tensor<2x3x4x5xf32>, tensor<5xf32>) -> tensor<2x3x4x5xf32>\n    %out = \"oneflow.gelu\"(%0) {axis = 3 : si32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"gelu-gelu-1\", scope_symbol_id = 22 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>\n    // CHECK: func.func @GraphToRun_bias_add_and_gelu_0(%[[A:[a-zA-Z0-9_]+]]: tensor<2x3x4x5xf32>, %[[B:[a-zA-Z0-9_]+]]: tensor<5xf32>) -> tensor<2x3x4x5xf32>\n    // CHECK: %[[OUT0:[a-zA-Z0-9_]+]] = \"oneflow.fused_bias_add_gelu\"(%[[A]], %[[B]]) {axis = 3 : si32\n    // CHECK： return %[[OUT0]]\n    return %out : tensor<2x3x4x5xf32>\n  }\n\n  func.func @fuse_mha(%query: tensor<2x4096x320xf16>, %key: tensor<2x4096x320xf16>, %value: tensor<2x4096x320xf16>) -> tensor<2x4096x320xf16> {\n    %query_reshape = \"oneflow.reshape\"(%query) {device_name 
= [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"reshape-1\", scope_symbol_id = 12 : i64, shape = [2 : si64, 4096 : si64, 8 : si64, 40 : si64]} : (tensor<2x4096x320xf16>) -> tensor<2x4096x8x40xf16>\n    %key_reshape = \"oneflow.reshape\"(%key) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"reshape-3\", scope_symbol_id = 12 : i64, shape = [2 : si64, 4096 : si64, 8 : si64, 40 : si64]} : (tensor<2x4096x320xf16>) -> tensor<2x4096x8x40xf16>\n    %value_reshape = \"oneflow.reshape\"(%value) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"reshape-5\", scope_symbol_id = 12 : i64, shape = [2 : si64, 4096 : si64, 8 : si64, 40 : si64]} : (tensor<2x4096x320xf16>) -> tensor<2x4096x8x40xf16>\n    %query_transpose = \"oneflow.transpose\"(%query_reshape) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"transpose-2\", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 12 : i64} : (tensor<2x4096x8x40xf16>) -> tensor<2x8x4096x40xf16>\n    %key_transpose = \"oneflow.transpose\"(%key_reshape) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"transpose-4\", perm = [0 : si32, 2 : si32, 3 : si32, 1 : si32], scope_symbol_id = 12 : i64} : (tensor<2x4096x8x40xf16>) -> tensor<2x8x40x4096xf16>\n    %value_transpose = \"oneflow.transpose\"(%value_reshape) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"transpose-6\", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 12 : i64} : (tensor<2x4096x8x40xf16>) -> tensor<2x8x4096x40xf16>\n    %scores = \"oneflow.batch_matmul\"(%query_transpose, %key_transpose) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"batch_matmul-7\", scope_symbol_id = 12 : i64, transpose_a = false, transpose_b = false} : (tensor<2x8x4096x40xf16>, tensor<2x8x40x4096xf16>) -> tensor<2x8x4096x4096xf16>\n    %scores_scaled = \"oneflow.scalar_div\"(%scores) {device_name = [\"@0:0\"], device_tag = \"cuda\", float_operand = 6.324555320336759 : f64, has_float_operand = true, has_int_operand = false, hierarchy = [1], int_operand = 0 : si64, op_name = \"scalar_div-8\", scope_symbol_id = 12 : i64} : (tensor<2x8x4096x4096xf16>) -> tensor<2x8x4096x4096xf16>\n    %attn = \"oneflow.softmax\"(%scores_scaled) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"softmax-9\", scope_symbol_id = 12 : i64} : (tensor<2x8x4096x4096xf16>) -> tensor<2x8x4096x4096xf16>\n    %out = \"oneflow.batch_matmul\"(%attn, %value_transpose) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"batch_matmul-10\", scope_symbol_id = 12 : i64, transpose_a = false, transpose_b = false} : (tensor<2x8x4096x4096xf16>, tensor<2x8x4096x40xf16>) -> tensor<2x8x4096x40xf16>\n    %out_transpose = \"oneflow.transpose\"(%out) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"transpose-11\", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 12 : i64} : (tensor<2x8x4096x40xf16>) -> tensor<2x4096x8x40xf16>\n    %out_reshape = \"oneflow.reshape\"(%out_transpose) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"reshape-12\", scope_symbol_id = 12 : i64, shape = [2 : si64, 4096 : si64, 320 : si64]} : (tensor<2x4096x8x40xf16>) -> tensor<2x4096x320xf16>\n    // CHECK: func.func @fuse_mha(%[[QUERY:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>, 
%[[KEY:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>, %[[VALUE:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>)\n    // CHECK: \"oneflow.fused_multi_head_attention_inference\"(%[[QUERY]], %[[KEY]], %[[VALUE]]) {attn_mask_type = \"none\", causal_diagonal_offset = 0 : si64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], key_layout = \"BM(HK)\", key_max_seq_len = 0 : si64, op_name = [[OP_NAME:\".*\"]], operand_segment_sizes = array<i32: 1, 1, 1, 0, 0, 0, 0>, output_layout = \"BM(HK)\", query_head_size = 40 : si64, query_layout = \"BM(HK)\", query_max_seq_len = 0 : si64, scale = 0.15811388300841897 : f64, scope_symbol_id = 12 : i64, value_layout = \"BM(HK)\"} : (tensor<2x4096x320xf16>, tensor<2x4096x320xf16>, tensor<2x4096x320xf16>) -> tensor<2x4096x320xf16>\n    return %out_reshape : tensor<2x4096x320xf16>\n  }\n\n  func.func @fuse_mha2(%query: tensor<2x4096x320xf16>, %key: tensor<2x4096x320xf16>, %value: tensor<2x4096x320xf16>) -> tensor<2x4096x320xf16> {\n    %value_reshape = \"oneflow.reshape\"(%value) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-124\", scope_symbol_id = 661 : i64, shape = [2 : si64, 4096 : si64, 8 : si64, 40 : si64]} : (tensor<2x4096x320xf16>) -> tensor<2x4096x8x40xf16>\n    %key_reshape = \"oneflow.reshape\"(%key) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-121\", scope_symbol_id = 661 : i64, shape = [2 : si64, 4096 : si64, 8 : si64, 40 : si64]} : (tensor<2x4096x320xf16>) -> tensor<2x4096x8x40xf16>\n    %query_reshape = \"oneflow.reshape\"(%query) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-116\", scope_symbol_id = 661 : i64, shape = [2 : si64, 4096 : si64, 8 : si64, 40 : si64]} : (tensor<2x4096x320xf16>) -> tensor<2x4096x8x40xf16>\n    %value_permute = \"oneflow.transpose\"(%value_reshape) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-transpose-125\", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 661 : i64} : (tensor<2x4096x8x40xf16>) -> tensor<2x8x4096x40xf16>\n    %key_permute = \"oneflow.transpose\"(%key_reshape) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-transpose-122\", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 661 : i64} : (tensor<2x4096x8x40xf16>) -> tensor<2x8x4096x40xf16>\n    %query_permute = \"oneflow.transpose\"(%query_reshape) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-transpose-117\", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 661 : i64} : (tensor<2x4096x8x40xf16>) -> tensor<2x8x4096x40xf16>\n    %value_reshape_to_batch = \"oneflow.reshape\"(%value_permute) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-126\", scope_symbol_id = 661 : i64, shape = [16 : si64, 4096 : si64, 40 : si64]} : (tensor<2x8x4096x40xf16>) -> tensor<16x4096x40xf16>\n    %key_reshape_to_batch = \"oneflow.reshape\"(%key_permute) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = 
\"unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-123\", scope_symbol_id = 661 : i64, shape = [16 : si64, 4096 : si64, 40 : si64]} : (tensor<2x8x4096x40xf16>) -> tensor<16x4096x40xf16>\n    %query_reshape_to_batch = \"oneflow.reshape\"(%query_permute) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-118\", scope_symbol_id = 661 : i64, shape = [16 : si64, 4096 : si64, 40 : si64]} : (tensor<2x8x4096x40xf16>) -> tensor<16x4096x40xf16>\n    %key_transpose = \"oneflow.transpose\"(%key_reshape_to_batch) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-transpose-128\", perm = [0 : si32, 2 : si32, 1 : si32], scope_symbol_id = 661 : i64} : (tensor<16x4096x40xf16>) -> tensor<16x40x4096xf16>\n    %scores_scaled = \"oneflow.batch_matmul\"(%query_reshape_to_batch, %key_transpose) {alpha = 0.15811388300841897 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-batch_matmul-129\", scope_symbol_id = 661 : i64, transpose_a = false, transpose_b = false} : (tensor<16x4096x40xf16>, tensor<16x40x4096xf16>) -> tensor<16x4096x4096xf16>\n    %attn = \"oneflow.softmax\"(%scores_scaled) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-softmax-130\", scope_symbol_id = 661 : i64} : (tensor<16x4096x4096xf16>) -> tensor<16x4096x4096xf16>\n    %309 = \"oneflow.batch_matmul\"(%attn, %value_reshape_to_batch) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-batch_matmul-131\", scope_symbol_id = 661 : i64, transpose_a = false, transpose_b = false} : (tensor<16x4096x4096xf16>, tensor<16x4096x40xf16>) -> tensor<16x4096x40xf16>\n    %310 = \"oneflow.reshape\"(%309) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-132\", scope_symbol_id = 661 : i64, shape = [2 : si64, 8 : si64, 4096 : si64, 40 : si64]} : (tensor<16x4096x40xf16>) -> tensor<2x8x4096x40xf16>\n    %311 = \"oneflow.transpose\"(%310) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-transpose-133\", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 661 : i64} : (tensor<2x8x4096x40xf16>) -> tensor<2x4096x8x40xf16>\n    %out_reshape_to_heads = \"oneflow.reshape\"(%311) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-134\", scope_symbol_id = 661 : i64, shape = [2 : si64, 4096 : si64, 320 : si64]} : (tensor<2x4096x8x40xf16>) -> tensor<2x4096x320xf16>\n    // CHECK: func.func @fuse_mha2(%[[QUERY:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>, %[[KEY:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>, %[[VALUE:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>)\n    // CHECK: oneflow.fused_multi_head_attention_inference\"(%[[QUERY]], %[[KEY]], %[[VALUE]]) {attn_mask_type = \"none\", causal_diagonal_offset = 0 : si64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], key_layout = \"BM(HK)\", key_max_seq_len = 0 : si64, op_name = [[OP_NAME:\".*\"]], operand_segment_sizes = array<i32: 1, 1, 1, 0, 0, 
0, 0>, output_layout = \"BM(HK)\", query_head_size = 40 : si64, query_layout = \"BM(HK)\", query_max_seq_len = 0 : si64, scale = 0.15811388300841897 : f64, scope_symbol_id = 661 : i64, value_layout = \"BM(HK)\"} : (tensor<2x4096x320xf16>, tensor<2x4096x320xf16>, tensor<2x4096x320xf16>) -> tensor<2x4096x320xf16>\n    return %out_reshape_to_heads : tensor<2x4096x320xf16>\n  }\n\n  func.func @GraphToRun_pad_and_conv2d_0(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x5x6xf32> {\n    %output = \"oneflow.variable\"() {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"conv.weight\", output_lbns = [\"conv.weight/out\"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 73 : i64, shape = [3 : si64, 3 : si64, 2 : si64, 2 : si64]} : () -> tensor<3x3x2x2xf32>\n    %output_0 = \"oneflow.input\"(%arg0) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_2_input.0.0_2\", output_lbns = [\"_GraphToRun_2_input.0.0_2/out\"], scope_symbol_id = 65 : i64, shape = [2 : si64, 3 : si64, 4 : si64, 5 : si64]} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>\n    %0 = \"oneflow.pad\"(%output_0) {device_name = [\"@0:0\"], device_tag = \"cpu\", floating_constant_value = 0.000000e+00 : f64, hierarchy = [1], integral_constant_value = 0 : si64, op_name = \"pad-0\", padding = [1 : si64, 1 : si64, 1 : si64, 1 : si64], padding_after = [0 : si64, 0 : si64, 1 : si64, 1 : si64], padding_before = [0 : si64, 0 : si64, 1 : si64, 1 : si64], scope_symbol_id = 65 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x6x7xf32>\n    %1 = \"oneflow.conv2d\"(%0, %output) {data_format = \"channels_first\", device_name = [\"@0:0\"], device_tag = \"cpu\", dilation_rate = [1 : si32, 1 : si32], filters = 3 : si32, groups = 1 : si32, hierarchy = [1], kernel_size = [2 : si32, 2 : si32], op_name = \"conv-conv2d-1\", operand_segment_sizes = array<i32: 1, 1, 0, 0>, padding_before = [0 : si32, 0 : si32], scope_symbol_id = 76 : i64, strides = [1 : si32, 1 : si32]} : (tensor<2x3x6x7xf32>, tensor<3x3x2x2xf32>) -> tensor<2x3x5x6xf32>\n    %output_1 = \"oneflow.output\"(%1) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_2_output.0.0_2\", output_lbns = [\"_GraphToRun_2_output.0.0_2/out\"], scope_symbol_id = 65 : i64, shape = [2 : si64, 3 : si64, 5 : si64, 6 : si64]} : (tensor<2x3x5x6xf32>) -> tensor<2x3x5x6xf32>\n    // CHECK: func.func @GraphToRun_pad_and_conv2d_0(%[[A:[a-zA-Z0-9_]+]]: tensor<2x3x4x5xf32>) -> tensor<2x3x5x6xf32> {\n    // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = \"oneflow.variable\"()\n    // CHECK: %[[OUT0:[a-zA-Z0-9_]+]] = \"oneflow.input\"(%[[A]])\n    // CHECK-NOT: oneflow.pad\n    // CHECK: %[[OUT1:[a-zA-Z0-9_]+]] = \"oneflow.conv2d\"(%[[OUT0]], %[[OUT]])\n    // CHECK: %[[OUT2:[a-zA-Z0-9_]+]] = \"oneflow.output\"\n    return %output_1 : tensor<2x3x5x6xf32>\n  }\n\n  func.func @GraphToRun_same_dtype_cast_0(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> {\n    %output_0 = \"oneflow.input\"(%arg0) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_3_input.0.0_2\", output_lbns = [\"_GraphToRun_3_input.0.0_2/out\"], scope_symbol_id = 65 : i64, shape = [2 : si64, 3 : si64, 4 : si64, 5 : si64]} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>\n    %0 = \"oneflow.cast\"(%output_0) {device_name = [\"0:0\"], 
device_tag = \"cpu\", dtype = 2 : i32, hierarchy = [1], op_name = \"Cast_1\", op_type_name = \"cast\", scope_symbol_id = 65 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>\n    %output_1 = \"oneflow.output\"(%0) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_3_output.0.0_2\", output_lbns = [\"_GraphToRun_3_output.0.0_2/out\"], scope_symbol_id = 65 : i64, shape = [2 : si64, 3 : si64, 4 : si64, 5 : si64]} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>\n    // CHECK: func.func @GraphToRun_same_dtype_cast_0(%[[A:[a-zA-Z0-9_]+]]: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> {\n    // CHECK: %[[OUT0:[a-zA-Z0-9_]+]] = \"oneflow.input\"(%[[A]])\n    // CHECK-NOT: oneflow.cast\n    // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = \"oneflow.output\"(%[[OUT0]])\n    // CHECK：return %[[OUT]] : tensor<2x3x4x5xf32>\n    return %output_1 : tensor<2x3x4x5xf32>\n  }\n\n  func.func @GraphToRun_same_dtype_cast_1(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xi32> {\n    %output_0 = \"oneflow.input\"(%arg0) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_4_input.0.0_2\", output_lbns = [\"_GraphToRun_4_input.0.0_2/out\"], scope_symbol_id = 65 : i64, shape = [2 : si64, 3 : si64, 4 : si64, 5 : si64]} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>\n    %0 = \"oneflow.cast\"(%output_0) {device_name = [\"0:0\"], device_tag = \"cpu\", dtype = 5 : i32, hierarchy = [1], op_name = \"Cast_1\", op_type_name = \"cast\", scope_symbol_id = 65 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xi32>\n    %output_1 = \"oneflow.output\"(%0) {data_type = 5 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_4_output.0.0_2\", output_lbns = [\"_GraphToRun_4_output.0.0_2/out\"], scope_symbol_id = 65 : i64, shape = [2 : si64, 3 : si64, 4 : si64, 5 : si64]} : (tensor<2x3x4x5xi32>) -> tensor<2x3x4x5xi32>\n    // CHECK: func.func @GraphToRun_same_dtype_cast_1(%[[A:[a-zA-Z0-9_]+]]: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xi32> {\n    // CHECK: %[[OUT0:[a-zA-Z0-9_]+]] = \"oneflow.input\"(%[[A]])\n    // CHECK: %[[OUT1:[a-zA-Z0-9_]+]] = \"oneflow.cast\"(%[[OUT0]])\n    // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = \"oneflow.output\"(%[[OUT1]])\n    // CHECK：return %[[OUT]] : tensor<2x3x4x5xi32>\n    return %output_1 : tensor<2x3x4x5xi32>\n  }\n\n  func.func @GraphToRun_scale_tril_0() -> tensor<5x5xf32> {\n    %output = \"oneflow.variable\"() {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"FreeEagerTensor-1\", output_lbns = [\"FreeEagerTensor-1/out\"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 12 : i64, shape = [5 : si64, 5 : si64], trainable = false} : () -> tensor<5x5xf32>\n    %0 = \"oneflow.scalar_mul\"(%output) {device_name = [\"@0:0\"], device_tag = \"cuda\", float_operand = -2.300000e+00 : f64, has_float_operand = true, has_int_operand = false, hierarchy = [1], int_operand = 0 : si64, op_name = \"scalar_mul-0\", scope_symbol_id = 12 : i64} : (tensor<5x5xf32>) -> tensor<5x5xf32>\n    %1 = \"oneflow.tril\"(%0) {device_name = [\"@0:0\"], device_tag = \"cuda\", diagonal = -1 : si64, floating_fill_value = 0.000000e+00 : f64, hierarchy = [1], integer_fill_value = 0 : si64, is_floating_fill_value = false, op_name = \"tril-2\", scope_symbol_id = 12 : i64} : (tensor<5x5xf32>) -> tensor<5x5xf32>\n    %output_0 = 
\"oneflow.output\"(%1) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_TestFuseScaleTril_0_output.0.0_2\", output_lbns = [\"_TestFuseScaleTril_0_output.0.0_2/out\"], scope_symbol_id = 12 : i64, shape = [5 : si64, 5 : si64]} : (tensor<5x5xf32>) -> tensor<5x5xf32>\n    // CHECK: func.func @GraphToRun_scale_tril_0() -> tensor<5x5xf32> {\n    // CHECK: %[[OUT0:[a-zA-Z0-9_]+]] = \"oneflow.variable\"()\n    // CHECK: %[[OUT1:[a-zA-Z0-9_]+]] = \"oneflow.fused_scale_tril\"(%[[OUT0]])\n    // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = \"oneflow.output\"(%[[OUT1]])\n    // CHECK：return %[[OUT]]\n    return %output_0 : tensor<5x5xf32>\n  }\n\n  func.func @GraphToRun_scale_tril_1() -> tensor<5x5xf32> {\n    %output = \"oneflow.variable\"() {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"FreeEagerTensor-1\", output_lbns = [\"FreeEagerTensor-1/out\"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 66 : i64, shape = [5 : si64, 5 : si64], trainable = false} : () -> tensor<5x5xf32>\n    %0 = \"oneflow.tril\"(%output) {device_name = [\"@0:0\"], device_tag = \"cuda\", diagonal = -1 : si64, floating_fill_value = 0.000000e+00 : f64, hierarchy = [1], integer_fill_value = 0 : si64, is_floating_fill_value = false, op_name = \"tril-0\", scope_symbol_id = 66 : i64} : (tensor<5x5xf32>) -> tensor<5x5xf32>\n    %1 = \"oneflow.scalar_mul\"(%0) {device_name = [\"@0:0\"], device_tag = \"cuda\", float_operand = 2.000000e+00 : f64, has_float_operand = true, has_int_operand = false, hierarchy = [1], int_operand = 0 : si64, op_name = \"scalar_mul-2\", scope_symbol_id = 66 : i64} : (tensor<5x5xf32>) -> tensor<5x5xf32>\n    %output_0 = \"oneflow.output\"(%1) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_TestFuseTrilScale_1_output.0.0_2\", output_lbns = [\"_TestFuseTrilScale_1_output.0.0_2/out\"], scope_symbol_id = 66 : i64, shape = [5 : si64, 5 : si64]} : (tensor<5x5xf32>) -> tensor<5x5xf32>\n    // CHECK: func.func @GraphToRun_scale_tril_1() -> tensor<5x5xf32> {\n    // CHECK: %[[OUT0:[a-zA-Z0-9_]+]] = \"oneflow.variable\"()\n    // CHECK: %[[OUT1:[a-zA-Z0-9_]+]] = \"oneflow.fused_scale_tril\"(%[[OUT0]])\n    // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = \"oneflow.output\"(%[[OUT1]])\n    // CHECK：return %[[OUT]]\n    return %output_0 : tensor<5x5xf32>\n  }\n\n  func.func @GraphToRun_normalization_1(%x: tensor<2x3x224x224xf32>, %moving_mean: tensor<3xf32>, %moving_variance: tensor<3xf32>, %gamma: tensor<3xf32>, %beta: tensor<3xf32>, %addend: tensor<2x3x224x224xf32>) -> tensor<2x3x224x224xf32> {\n    %y, %mean, %inv_variance = \"oneflow.normalization\"(%x, %moving_mean, %moving_variance, %gamma, %beta) {axis = 1 : si32, device_name = [\"@0:0\"], device_tag = \"cpu\", epsilon = 9.99999974E-6 : f32, hierarchy = [1], momentum = 0.899999976 : f32, op_name = \"normalization-2\", operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>, result_segment_sizes = array<i32: 1, 1, 1>, scope_symbol_id = 12 : i64, training = true} : (tensor<2x3x224x224xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<2x3x224x224xf32>, tensor<3xf32>, tensor<3xf32>)\n    %0 = \"oneflow.add_n2\"(%y, %addend) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"add_n-7\", op_type_name = \"add_n\", scope_symbol_id = 12 : i64} : (tensor<2x3x224x224xf32>, tensor<2x3x224x224xf32>) -> 
tensor<2x3x224x224xf32>\n    %1 = \"oneflow.relu\"(%0) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"relu-8\", scope_symbol_id = 12 : i64} : (tensor<2x3x224x224xf32>) -> tensor<2x3x224x224xf32>\n    // CHECK: func.func @GraphToRun_normalization_1(%[[X:[a-zA-Z0-9_]+]]: tensor<2x3x224x224xf32>, %[[MOVING_MEAN:[a-zA-Z0-9_]+]]: tensor<3xf32>, %[[MOVING_VARIANCE:[a-zA-Z0-9_]+]]: tensor<3xf32>, %[[GAMMA:[a-zA-Z0-9_]+]]: tensor<3xf32>, %[[BETA:[a-zA-Z0-9_]+]]: tensor<3xf32>, %[[ADDEND:[a-zA-Z0-9_]+]]: tensor<2x3x224x224xf32>)\n    // CHECK: %[[Y:[a-zA-Z0-9_]+]], %[[reserve_space:[a-zA-Z0-9_]+]], %[[mean:[a-zA-Z0-9_]+]], %[[inv_variance:[a-zA-Z0-9_]+]] = \"oneflow.normalization_add_relu\"(%[[X]], %[[ADDEND]], %[[MOVING_MEAN]], %[[MOVING_VARIANCE]], %[[GAMMA]], %[[BETA]])\n    // CHECK: return %[[Y]]\n    return %1 : tensor<2x3x224x224xf32>\n  }\n\n  func.func @GraphToRun_normalization_2(%x: tensor<2x3x224x224xf32>, %moving_mean: tensor<3xf32>, %moving_variance: tensor<3xf32>, %gamma: tensor<3xf32>, %beta: tensor<3xf32>, %addend: tensor<2x3x224x224xf32>) -> tensor<2x3x224x224xf32> {\n    %y = \"oneflow.normalization_infer\"(%x, %moving_mean, %moving_variance, %gamma, %beta) {axis = 1 : si32, device_name = [\"@0:0\"], device_tag = \"cpu\", epsilon = 9.99999974E-6 : f32, hierarchy = [1], momentum = 0.899999976 : f32, op_name = \"normalization-2\", operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>, result_segment_sizes = array<i32: 1>, scope_symbol_id = 12 : i64, training = true} : (tensor<2x3x224x224xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<2x3x224x224xf32>)\n    %0 = \"oneflow.add_n2\"(%y, %addend) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"add_n-7\", op_type_name = \"add_n\", scope_symbol_id = 12 : i64} : (tensor<2x3x224x224xf32>, tensor<2x3x224x224xf32>) -> tensor<2x3x224x224xf32>\n    %1 = \"oneflow.relu\"(%0) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"relu-8\", scope_symbol_id = 12 : i64} : (tensor<2x3x224x224xf32>) -> tensor<2x3x224x224xf32>\n    // CHECK: func.func @GraphToRun_normalization_2(%[[X:[a-zA-Z0-9_]+]]: tensor<2x3x224x224xf32>, %[[MOVING_MEAN:[a-zA-Z0-9_]+]]: tensor<3xf32>, %[[MOVING_VARIANCE:[a-zA-Z0-9_]+]]: tensor<3xf32>, %[[GAMMA:[a-zA-Z0-9_]+]]: tensor<3xf32>, %[[BETA:[a-zA-Z0-9_]+]]: tensor<3xf32>, %[[ADDEND:[a-zA-Z0-9_]+]]: tensor<2x3x224x224xf32>)\n    // CHECK: %[[Y:[a-zA-Z0-9_]+]], %[[reserve_space:[a-zA-Z0-9_]+]], %[[mean:[a-zA-Z0-9_]+]], %[[inv_variance:[a-zA-Z0-9_]+]] = \"oneflow.normalization_add_relu\"(%[[X]], %[[ADDEND]], %[[MOVING_MEAN]], %[[MOVING_VARIANCE]], %[[GAMMA]], %[[BETA]])\n    // CHECK: return %[[Y]]\n    return %1 : tensor<2x3x224x224xf32>\n  }\n\n  func.func @GraphToRun_conv_bn_1(%arg0: tensor<1x3x224x224xf32>, %moving_mean: tensor<64xf32>, %moving_variance: tensor<64xf32>, %beta: tensor<64xf32>) -> tensor<1x64x112x112xf32> {\n    %output = \"oneflow.input\"(%arg0) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_conv_bn_1_input.0.0_2\", output_lbns = [\"_conv_bn_1_input.0.0_2/out\"], scope_symbol_id = 12 : i64, shape = [1 : si64, 3 : si64, 224 : si64, 224 : si64]} : (tensor<1x3x224x224xf32>) -> tensor<1x3x224x224xf32>\n    %0 = \"oneflow.variable_ir\"() {value = dense<1.0> : tensor<64x3x7x7xf32> ,data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"model.conv1.weight\", output_lbns 
= [\"model.conv1.weight/out\"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 18 : i64, shape = [64 : si64, 3 : si64, 7 : si64, 7 : si64], nd_sbp = [\"B\"]} : () -> tensor<64x3x7x7xf32>\n    %gamma = \"oneflow.variable_ir\"() {value = dense<1.0> : tensor<64xf32> ,data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"model.bn.gamma\", output_lbns = [\"model.bn.gamma/out\"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 18 : i64, shape = [64 : si64], nd_sbp = [\"B\"]} : () -> tensor<64xf32>\n    %1 = \"oneflow.conv2d\"(%output, %0) {data_format = \"channels_first\", device_name = [\"@0:0\"], device_tag = \"cuda\", dilation_rate = [1 : si32, 1 : si32], filters = 64 : si32, groups = 1 : si32, hierarchy = [1], kernel_size = [7 : si32, 7 : si32], op_name = \"model.conv1-conv2d-0\", operand_segment_sizes = array<i32: 1, 1, 0, 0>, padding_before = [3 : si32, 3 : si32], scope_symbol_id = 21 : i64, strides = [2 : si32, 2 : si32]} : (tensor<1x3x224x224xf32>, tensor<64x3x7x7xf32>) -> tensor<1x64x112x112xf32>\n    %2 = \"oneflow.normalization_infer\"(%1,  %moving_mean, %moving_variance, %gamma, %beta) {axis = 1 : si32, device_name = [\"@0:0\"], device_tag = \"cuda\", epsilon = 9.99999974E-6 : f32, hierarchy = [1], momentum = 0.899999976 : f32, op_name = \"model.bn1-normalization-1\", operand_segment_sizes = array<i32: 1, 1, 1, 1, 1, 0>, result_segment_sizes = array<i32: 1, 0, 0>, scope_symbol_id = 41 : i64, training = false} : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>\n    // CHECK: func.func @GraphToRun_conv_bn_1(%[[ARG_0:[a-zA-Z0-9_]+]]: tensor<1x3x224x224xf32>, %[[MOVING_MEAN:[a-zA-Z0-9_]+]]: tensor<64xf32>, %[[MOVING_VARIANCE:[a-zA-Z0-9_]+]]: tensor<64xf32>, %[[BETA:[a-zA-Z0-9_]+]]: tensor<64xf32>)\n    // CHECK:  %[[GAMMA:[a-zA-Z0-9_]+]] = \"oneflow.variable_ir\"()\n    // CHECK:  %[[WEIGHT:[a-zA-Z0-9_]+]] = \"oneflow.variable_ir\"()\n    // CHECK:  %[[OUT:[a-zA-Z0-9_]+]] = \"oneflow.input\"(%[[ARG_0]])\n    // CHECK:  %[[OUT2:[a-zA-Z0-9_]+]] = \"oneflow.scalar_add\"(%[[MOVING_VARIANCE]])\n    // CHECK:  %[[OUT3:[a-zA-Z0-9_]+]] = \"oneflow.sqrt\"(%[[OUT2]])\n    // CHECK:  %[[OUT4:[a-zA-Z0-9_]+]] = \"oneflow.broadcast_div\"(%[[GAMMA]], %[[OUT3]])\n    // CHECK:  %[[OUT5:[a-zA-Z0-9_]+]] = \"oneflow.reshape\"(%[[OUT4]])\n    // CHECK:  %[[OUT6:[a-zA-Z0-9_]+]] = \"oneflow.broadcast_mul\"(%[[WEIGHT]], %[[OUT5]])\n    // CHECK:  %[[OUT7:[a-zA-Z0-9_]+]] = \"oneflow.broadcast_mul\"(%[[MOVING_MEAN]], %[[OUT4]])\n    // CHECK:  %[[OUT8:[a-zA-Z0-9_]+]] = \"oneflow.broadcast_sub\"(%[[BETA]], %[[OUT7]])\n    // CHECK:  %[[OUT9:[a-zA-Z0-9_]+]] = \"oneflow.conv2d\"(%[[OUT]], %[[OUT6]], %[[OUT8]])\n    // CHECK： return %[[OUT9]]\n    return %2 : tensor<1x64x112x112xf32>\n  }\n\n\n  func.func @GraphToRun_broadcastmul_to_scalarmul_1(%arg0: tensor<64x3x7x7xf32>, %arg1: tensor<1xf32>) -> tensor<64x3x7x7xf32> {\n    %output = \"oneflow.broadcast_mul\"(%arg0, %arg1) {device_name = [\"@0:0\"], device_tag = \"cuda\", op_name = \"multiply\"} : (tensor<64x3x7x7xf32>, tensor<1xf32>) -> tensor<64x3x7x7xf32>\n    // CHECK: func.func @GraphToRun_broadcastmul_to_scalarmul_1(%[[ARG_0:[a-zA-Z0-9_]+]]: tensor<64x3x7x7xf32>, %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<1xf32>)\n    // CHECK: %[[OUT:[a-zA-Z0-9_]+]] = \"oneflow.scalar_mul_by_tensor\"(%[[ARG_0]], %[[ARG_1]]\n    return %output : tensor<64x3x7x7xf32>\n    }\n\n  func.func @GraphToRun_fused_gelu_1(%arg0: tensor<2x2304x640xf32>) 
-> tensor<2x2304x5120xf32> {\n    %output = \"oneflow.variable\"() {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"gelu_mod.proj.weight\", output_lbns = [\"gelu_mod.proj.weight/out\"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 18 : i64, shape = [10240 : si64, 640 : si64]} : () -> tensor<10240x640xf32>\n    %output_0 = \"oneflow.variable\"() {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"gelu_mod.proj.bias\", output_lbns = [\"gelu_mod.proj.bias/out\"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 25 : i64, shape = [10240 : si64]} : () -> tensor<10240xf32>\n    %output_1 = \"oneflow.input\"(%arg0) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_0_input.0.0_2\", output_lbns = [\"_GraphToRun_0_input.0.0_2/out\"], scope_symbol_id = 12 : i64, shape = [2 : si64, 2304 : si64, 640 : si64]} : (tensor<2x2304x640xf32>) -> tensor<2x2304x640xf32>\n    %matmul_wx = \"oneflow.broadcast_matmul\"(%output_1, %output) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"gelu_mod.proj-broadcast_matmul-0\", scope_symbol_id = 21 : i64, transpose_a = false, transpose_b = true} : (tensor<2x2304x640xf32>, tensor<10240x640xf32>) -> tensor<2x2304x10240xf32>\n    %matmul_wx_add = \"oneflow.broadcast_add\"(%matmul_wx, %output_0) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"gelu_mod.proj-broadcast_add-1\", scope_symbol_id = 21 : i64} : (tensor<2x2304x10240xf32>, tensor<10240xf32>) -> tensor<2x2304x10240xf32>\n    %hidden_states = \"oneflow.narrow\"(%matmul_wx_add) {device_name = [\"@0:0\"], device_tag = \"cuda\", dim = 2 : si64, hierarchy = [1], length = 5120 : si64, op_name = \"gelu_mod-narrow-2\", scope_symbol_id = 31 : i64, start = 0 : si64} : (tensor<2x2304x10240xf32>) -> tensor<2x2304x5120xf32>\n    %gate = \"oneflow.narrow\"(%matmul_wx_add) {device_name = [\"@0:0\"], device_tag = \"cuda\", dim = 2 : si64, hierarchy = [1], length = 5120 : si64, op_name = \"gelu_mod-narrow-3\", scope_symbol_id = 31 : i64, start = 5120 : si64} : (tensor<2x2304x10240xf32>) -> tensor<2x2304x5120xf32>\n    %gate_activate = \"oneflow.gelu\"(%gate) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"gelu_mod-gelu-4\", scope_symbol_id = 31 : i64} : (tensor<2x2304x5120xf32>) -> tensor<2x2304x5120xf32>\n    %y = \"oneflow.broadcast_mul\"(%hidden_states, %gate_activate) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"gelu_mod-broadcast_mul-5\", scope_symbol_id = 31 : i64} : (tensor<2x2304x5120xf32>, tensor<2x2304x5120xf32>) -> tensor<2x2304x5120xf32>\n    %output_2 = \"oneflow.output\"(%y) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_0_output.0.0_2\", output_lbns = [\"_GraphToRun_0_output.0.0_2/out\"], scope_symbol_id = 12 : i64, shape = [2 : si64, 2304 : si64, 5120 : si64]} : (tensor<2x2304x5120xf32>) -> tensor<2x2304x5120xf32>\n    // CHECK: func.func @GraphToRun_fused_gelu_1(%[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x2304x640xf32>) -> tensor<2x2304x5120xf32> {\n    // CHECK:  %[[OUT:[a-zA-Z0-9_]+]] = \"oneflow.variable\"()\n    // CHECK:  %[[OUT0:[a-zA-Z0-9_]+]] = \"oneflow.variable\"()\n    // CHECK:  %[[OUT1:[a-zA-Z0-9_]+]] = \"oneflow.input\"(%[[ARG_0]])\n    
// CHECK:  %[[Y:[a-zA-Z0-9_]+]], %[[MATMUL:[a-zA-Z0-9_]+]] = \"oneflow.fused_glu\"(%[[OUT1]], %[[OUT]], %[[OUT0]])\n    // CHECK:  %[[OUT2:[a-zA-Z0-9_]+]] = \"oneflow.output\"(%[[Y]])\n    // CHECK: return %[[OUT2]]\n    return %output_2 : tensor<2x2304x5120xf32>\n  }\n\n  func.func @GraphToRun_fused_gelu_2(%arg0: tensor<2x2304x640xf32>) -> tensor<2x2304x5120xf32> {\n    %output = \"oneflow.variable\"() {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"gelu_mod.proj.weight\", output_lbns = [\"gelu_mod.proj.weight/out\"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 18 : i64, shape = [10240 : si64, 640 : si64]} : () -> tensor<10240x640xf32>\n    %output_0 = \"oneflow.variable\"() {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"gelu_mod.proj.bias\", output_lbns = [\"gelu_mod.proj.bias/out\"], parallel = #sbp.parallel<[] -> [[#sbp.B]]>, scope_symbol_id = 25 : i64, shape = [10240 : si64]} : () -> tensor<10240xf32>\n    %output_1 = \"oneflow.input\"(%arg0) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_0_input.0.0_2\", output_lbns = [\"_GraphToRun_0_input.0.0_2/out\"], scope_symbol_id = 12 : i64, shape = [2 : si64, 2304 : si64, 640 : si64]} : (tensor<2x2304x640xf32>) -> tensor<2x2304x640xf32>\n    %matmul_wx_add = \"oneflow.fused_matmul_bias\"(%output_1, %output, %output_0) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"gelu_mod.proj-broadcast_add-1\", scope_symbol_id = 21 : i64} : (tensor<2x2304x640xf32>, tensor<10240x640xf32>, tensor<10240xf32>) -> tensor<2x2304x10240xf32>\n    %hidden_states = \"oneflow.narrow\"(%matmul_wx_add) {device_name = [\"@0:0\"], device_tag = \"cuda\", dim = 2 : si64, hierarchy = [1], length = 5120 : si64, op_name = \"gelu_mod-narrow-2\", scope_symbol_id = 31 : i64, start = 0 : si64} : (tensor<2x2304x10240xf32>) -> tensor<2x2304x5120xf32>\n    %gate = \"oneflow.narrow\"(%matmul_wx_add) {device_name = [\"@0:0\"], device_tag = \"cuda\", dim = 2 : si64, hierarchy = [1], length = 5120 : si64, op_name = \"gelu_mod-narrow-3\", scope_symbol_id = 31 : i64, start = 5120 : si64} : (tensor<2x2304x10240xf32>) -> tensor<2x2304x5120xf32>\n    %gate_activate = \"oneflow.gelu\"(%gate) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"gelu_mod-gelu-4\", scope_symbol_id = 31 : i64} : (tensor<2x2304x5120xf32>) -> tensor<2x2304x5120xf32>\n    %y = \"oneflow.broadcast_mul\"(%hidden_states, %gate_activate) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"gelu_mod-broadcast_mul-5\", scope_symbol_id = 31 : i64} : (tensor<2x2304x5120xf32>, tensor<2x2304x5120xf32>) -> tensor<2x2304x5120xf32>\n    %output_2 = \"oneflow.output\"(%y) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_0_output.0.0_2\", output_lbns = [\"_GraphToRun_0_output.0.0_2/out\"], scope_symbol_id = 12 : i64, shape = [2 : si64, 2304 : si64, 5120 : si64]} : (tensor<2x2304x5120xf32>) -> tensor<2x2304x5120xf32>\n    // CHECK: func.func @GraphToRun_fused_gelu_2(%[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x2304x640xf32>) -> tensor<2x2304x5120xf32> {\n    // CHECK:  %[[OUT:[a-zA-Z0-9_]+]] = \"oneflow.variable\"()\n    // CHECK:  %[[OUT0:[a-zA-Z0-9_]+]] = \"oneflow.variable\"()\n    // CHECK:  
%[[OUT1:[a-zA-Z0-9_]+]] = \"oneflow.input\"(%[[ARG_0]])\n    // CHECK:  %[[Y:[a-zA-Z0-9_]+]], %[[MATMUL:[a-zA-Z0-9_]+]] = \"oneflow.fused_glu\"(%[[OUT1]], %[[OUT]], %[[OUT0]])\n    // CHECK:  %[[OUT2:[a-zA-Z0-9_]+]] = \"oneflow.output\"(%[[Y]])\n    // CHECK: return %[[OUT2]]\n    return %output_2 : tensor<2x2304x5120xf32>\n  }\n}\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/fuse/test_cast_optimal_pass.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK-NOT: oneflow.cast\n\nimport os\nimport unittest\nimport numpy as np\n\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _cast_optimal_pass(test_case, dtype):\n    a = flow.tensor([2, 3], dtype=dtype)\n    eager_b = flow.cast(a, dtype=dtype)\n\n    class CastOpOptimalPass(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.cast = flow.cast\n\n        def build(self, x):\n            return self.cast(x, dtype=dtype)\n\n    lazy_b = CastOpOptimalPass()(a)\n    test_case.assertEqual(eager_b.dtype, lazy_b.dtype)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCastOpOptimalPass(flow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_FUSE_FORWARD_OPS\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_STDOUT\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_ENABLE_TIMING\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_PRINT_STATS\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_ENABLE_IR_PRINTING\"] = \"1\"\n\n    def test_case_optimal_pass(test_case):\n        for dtype in [flow.float32, flow.float64, flow.int32, flow.int64]:\n            _cast_optimal_pass(test_case, dtype)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/fuse/test_fuse_pad_conv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK-NOT: oneflow.pad\n\nimport unittest\nimport numpy as np\n\nimport os\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.sysconfig\n\n\ndef do_pad_conv_graph(test_case, with_cuda, with_bias, with_nchw=True):\n    if with_nchw:\n        x = flow.randn(2, 3, 4, 5)\n    else:\n        x = flow.randn(2, 4, 5, 3)\n    conv = flow.nn.Conv2d(3, 3, 2, 1, bias=with_bias)\n    if with_cuda:\n        x = x.cuda()\n        conv.to(\"cuda\")\n\n    if with_nchw:\n        pad_x = flow.nn.functional.pad(x, (1, 1, 1, 1))\n    else:\n        pad_x = flow.nn.functional.pad(x, (0, 0, 1, 1, 1, 1))\n    eager_conv_x = conv(pad_x)\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.conv = conv\n\n        def build(self, x):\n            if with_nchw:\n                pad_x = flow.nn.functional.pad(x, (1, 1, 1, 1))\n            else:\n                pad_x = flow.nn.functional.pad(x, (0, 0, 1, 1, 1, 1))\n            return self.conv(pad_x)\n\n    graph_to_run = GraphToRun()\n    lazy_conv_x = graph_to_run(x)\n    test_case.assertTrue(np.array_equal(eager_conv_x.numpy(), lazy_conv_x.numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFusePadConv(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_FUSE_FORWARD_OPS\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_STDOUT\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_ENABLE_TIMING\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_PRINT_STATS\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_ENABLE_IR_PRINTING\"] = \"1\"\n\n    @unittest.skipUnless(oneflow.sysconfig.with_cuda(), \"needs -DBUILD_CUDA=ON\")\n    def test_pad_conv_graph_cuda(test_case):\n        do_pad_conv_graph(test_case, True, True)\n        do_pad_conv_graph(test_case, True, False)\n        do_pad_conv_graph(test_case, True, False, True)\n\n    def test_pad_conv_graph_cpu(test_case):\n        do_pad_conv_graph(test_case, False, True)\n        do_pad_conv_graph(test_case, False, False)\n        do_pad_conv_graph(test_case, False, False, True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/group_matmul.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -group-matmul | FileCheck %s\nmodule  {\n  // CHECK-LABEL: func.func\n  func.func @no_bias(%x: tensor<2x320xf16>, %weight1: tensor<1280x320xf16>, %weight2: tensor<1280x320xf16>) -> (tensor<2x1280xf16>, tensor<2x1280xf16>) {\n     %1 = \"oneflow.matmul\"(%x, %weight1) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-matmul-20\", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16>\n     %2 = \"oneflow.matmul\"(%x, %weight2) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-matmul-20\", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16>\n    return %1, %2 : tensor<2x1280xf16>, tensor<2x1280xf16>\n    // CHECK: @no_bias(%[[X:[a-zA-Z0-9_]+]]: tensor<2x320xf16>, %[[WEIGHT1:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[WEIGHT2:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>) -> (tensor<2x1280xf16>, tensor<2x1280xf16>)\n    // CHECK: %[[OUT:[a-zA-Z0-9_]+]]:2 = \"oneflow.grouped_matmul_bias\"(%[[X]], %[[X]], %[[WEIGHT2]], %[[WEIGHT1]])\n    // CHECK: return %[[OUT]]#1, %[[OUT]]#0\n  }\n\n  // CHECK-LABEL: func.func\n  func.func @with_bias(%x: tensor<2x320xf16>, %weight1: tensor<1280x320xf16>, %weight2: tensor<1280x320xf16>, %bias1: tensor<1280xf16>, %bias2: tensor<1280xf16>) -> (tensor<2x1280xf16>, tensor<2x1280xf16>) {\n     %1 = \"oneflow.matmul\"(%x, %weight1) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-matmul-20\", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16>\n     %r1 = \"oneflow.bias_add\"(%1, %bias1) {axis = 1 : si32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-bias_add-21\", scope_symbol_id = 90 : i64} : (tensor<2x1280xf16>, tensor<1280xf16>) -> tensor<2x1280xf16>\n     %2 = \"oneflow.matmul\"(%x, %weight2) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-matmul-20\", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16>\n     %r2 = \"oneflow.bias_add\"(%2, %bias2) {axis = 1 : si32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-bias_add-21\", scope_symbol_id = 90 : i64} : (tensor<2x1280xf16>, tensor<1280xf16>) -> tensor<2x1280xf16>\n    return %r1, %r2 : tensor<2x1280xf16>, tensor<2x1280xf16>\n    // CHECK: @with_bias(%[[X:[a-zA-Z0-9_]+]]: tensor<2x320xf16>, %[[WEIGHT1:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[WEIGHT2:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[BIAS1:[a-zA-Z0-9_]+]]: tensor<1280xf16>, %[[BIAS2:[a-zA-Z0-9_]+]]: tensor<1280xf16>)\n    // CHECK: %[[OUT:[a-zA-Z0-9_]+]]:2 = \"oneflow.grouped_matmul_bias\"(%[[X]], %[[X]], %[[WEIGHT2]], %[[WEIGHT1:[a-zA-Z0-9_]+]], %[[BIAS2:[a-zA-Z0-9_]+]], %[[BIAS1:[a-zA-Z0-9_]+]])\n    // CHECK: return %[[OUT]]#1, %[[OUT]]#0\n  }\n\n  // CHECK-LABEL: func.func\n  func.func @with_broadcast_add(%x: tensor<2x320xf16>, %weight1: tensor<1280x320xf16>, %weight2: tensor<1280x320xf16>, %bias1: tensor<1280xf16>, %bias2: 
tensor<1280xf16>) -> (tensor<2x1280xf16>, tensor<2x1280xf16>) {\n     %1 = \"oneflow.matmul\"(%x, %weight1) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-matmul-20\", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16>\n     %r1 = \"oneflow.broadcast_add\"(%1, %bias1) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-bias_add-21\", scope_symbol_id = 90 : i64} : (tensor<2x1280xf16>, tensor<1280xf16>) -> tensor<2x1280xf16>\n     %2 = \"oneflow.matmul\"(%x, %weight2) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-matmul-20\", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16>\n     %r2 = \"oneflow.broadcast_add\"(%2, %bias2) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-bias_add-21\", scope_symbol_id = 90 : i64} : (tensor<2x1280xf16>, tensor<1280xf16>) -> tensor<2x1280xf16>\n    return %r1, %r2 : tensor<2x1280xf16>, tensor<2x1280xf16>\n    // CHECK: @with_broadcast_add(%[[X:[a-zA-Z0-9_]+]]: tensor<2x320xf16>, %[[WEIGHT1:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[WEIGHT2:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[BIAS1:[a-zA-Z0-9_]+]]: tensor<1280xf16>, %[[BIAS2:[a-zA-Z0-9_]+]]: tensor<1280xf16>)\n    // CHECK: %[[OUT:[a-zA-Z0-9_]+]]:2 = \"oneflow.grouped_matmul_bias\"(%[[X]], %[[X]], %[[WEIGHT2]], %[[WEIGHT1:[a-zA-Z0-9_]+]], %[[BIAS2:[a-zA-Z0-9_]+]], %[[BIAS1:[a-zA-Z0-9_]+]])\n    // CHECK: return %[[OUT]]#1, %[[OUT]]#0\n  }\n\n  // CHECK-LABEL: func.func\n  func.func @mixed(%x: tensor<2x320xf16>, %weight1: tensor<1280x320xf16>, %weight2: tensor<1280x320xf16>, %bias1: tensor<1280xf16>, %bias2: tensor<1280xf16>) -> (tensor<2x1280xf16>, tensor<2x1280xf16>, tensor<2x1280xf16>, tensor<2x1280xf16>) {\n     %1 = \"oneflow.matmul\"(%x, %weight1) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-matmul-20\", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16>\n     %r1 = \"oneflow.bias_add\"(%1, %bias1) {axis = 1 : si32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-bias_add-21\", scope_symbol_id = 90 : i64} : (tensor<2x1280xf16>, tensor<1280xf16>) -> tensor<2x1280xf16>\n     %2 = \"oneflow.matmul\"(%x, %weight2) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-matmul-20\", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16>\n     %r2 = \"oneflow.bias_add\"(%2, %bias2) {axis = 1 : si32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-bias_add-21\", scope_symbol_id = 90 : i64} : (tensor<2x1280xf16>, tensor<1280xf16>) -> tensor<2x1280xf16>\n     %m1 = \"oneflow.matmul\"(%x, %weight1) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-matmul-20\", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} 
: (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16>\n     %m2 = \"oneflow.matmul\"(%x, %weight2) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-matmul-20\", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16>\n    return %r1, %r2, %m1, %m2: tensor<2x1280xf16>, tensor<2x1280xf16>, tensor<2x1280xf16>, tensor<2x1280xf16>\n    // CHECK: @mixed(%[[X:[a-zA-Z0-9_]+]]: tensor<2x320xf16>, %[[WEIGHT1:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[WEIGHT2:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[BIAS1:[a-zA-Z0-9_]+]]: tensor<1280xf16>, %[[BIAS2:[a-zA-Z0-9_]+]]: tensor<1280xf16>)\n    // CHECK: %[[OUT0:[a-zA-Z0-9_]+]]:2 = \"oneflow.grouped_matmul_bias\"(%[[X]], %[[X]], %[[WEIGHT2]], %[[WEIGHT1:[a-zA-Z0-9_]+]], %[[BIAS2:[a-zA-Z0-9_]+]], %[[BIAS1:[a-zA-Z0-9_]+]])\n    // CHECK: %[[OUT1:[a-zA-Z0-9_]+]]:2 = \"oneflow.grouped_matmul_bias\"(%[[X]], %[[X]], %[[WEIGHT2]], %[[WEIGHT1]])\n    // CHECK: return %[[OUT0]]#1, %[[OUT0]]#0, %[[OUT1]]#1, %[[OUT1]]#0\n  }\n\n  // CHECK-LABEL: func.func\n  func.func @left_alone(%x: tensor<2x320xf16>, %weight1: tensor<1280x320xf16>, %weight2: tensor<1280x320xf16>, %bias1: tensor<1280xf16>, %bias2: tensor<1280xf16>) -> (tensor<2x1280xf16>, tensor<2x1280xf16>, tensor<2x1280xf16>) {\n     %1 = \"oneflow.matmul\"(%x, %weight1) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-matmul-20\", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16>\n     %r1 = \"oneflow.bias_add\"(%1, %bias1) {axis = 1 : si32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-bias_add-21\", scope_symbol_id = 90 : i64} : (tensor<2x1280xf16>, tensor<1280xf16>) -> tensor<2x1280xf16>\n     %2 = \"oneflow.matmul\"(%x, %weight2) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-matmul-20\", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16>\n     %r2 = \"oneflow.bias_add\"(%2, %bias2) {axis = 1 : si32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-bias_add-21\", scope_symbol_id = 90 : i64} : (tensor<2x1280xf16>, tensor<1280xf16>) -> tensor<2x1280xf16>\n     %m1 = \"oneflow.matmul\"(%x, %weight1) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-matmul-20\", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<2x320xf16>, tensor<1280x320xf16>) -> tensor<2x1280xf16>\n    return %r1, %r2, %m1: tensor<2x1280xf16>, tensor<2x1280xf16>, tensor<2x1280xf16>\n    // CHECK: @left_alone(%[[X:[a-zA-Z0-9_]+]]: tensor<2x320xf16>, %[[WEIGHT1:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[WEIGHT2:[a-zA-Z0-9_]+]]: tensor<1280x320xf16>, %[[BIAS1:[a-zA-Z0-9_]+]]: tensor<1280xf16>, %[[BIAS2:[a-zA-Z0-9_]+]]: tensor<1280xf16>)\n    // CHECK: %[[OUT0:[a-zA-Z0-9_]+]]:2 = \"oneflow.grouped_matmul_bias\"(%[[X]], %[[X]], %[[WEIGHT2]], %[[WEIGHT1:[a-zA-Z0-9_]+]], %[[BIAS2:[a-zA-Z0-9_]+]], %[[BIAS1:[a-zA-Z0-9_]+]])\n    // CHECK: %[[OUT1:[a-zA-Z0-9_]+]] = \"oneflow.matmul\"(%arg0, %arg1)\n    // CHECK: 
return %[[OUT0]]#1, %[[OUT0]]#0, %[[OUT1]]\n  }\n  func.func @f_broadcast_matmul(%x: tensor<2x4096x320xf16>, %w1: tensor<320x320xf16>, %w2: tensor<320x320xf16>, %w3: tensor<320x320xf16>) -> (tensor<2x4096x320xf16>, tensor<2x4096x320xf16>, tensor<2x4096x320xf16>) {\n    %matmul1 = \"oneflow.broadcast_matmul\"(%x, %w1) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_q-broadcast_matmul-16315\", scope_symbol_id = 5497 : i64, transpose_a = false, transpose_b = true} : (tensor<2x4096x320xf16>, tensor<320x320xf16>) -> tensor<2x4096x320xf16>\n    %matmul2 = \"oneflow.broadcast_matmul\"(%x, %w2) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_k-broadcast_matmul-16316\", scope_symbol_id = 5505 : i64, transpose_a = false, transpose_b = true} : (tensor<2x4096x320xf16>, tensor<320x320xf16>) -> tensor<2x4096x320xf16>\n    %matmul3 = \"oneflow.broadcast_matmul\"(%x, %w3) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_v-broadcast_matmul-16317\", scope_symbol_id = 5513 : i64, transpose_a = false, transpose_b = true} : (tensor<2x4096x320xf16>, tensor<320x320xf16>) -> tensor<2x4096x320xf16>\n    return %matmul1, %matmul2, %matmul3 : tensor<2x4096x320xf16>, tensor<2x4096x320xf16>, tensor<2x4096x320xf16>\n    // CHECK: @f_broadcast_matmul(%[[X:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>, %[[WEIGHT1:[a-zA-Z0-9_]+]]: tensor<320x320xf16>, %[[WEIGHT2:[a-zA-Z0-9_]+]]: tensor<320x320xf16>, %[[WEIGHT3:[a-zA-Z0-9_]+]]: tensor<320x320xf16>)\n    // CHECK: %[[OUT0:[a-zA-Z0-9_]+]]:3 = \"oneflow.grouped_matmul_bias\"(%[[X]], %[[X]], %[[X]], %[[WEIGHT3]], %[[WEIGHT2]], %[[WEIGHT1]])\n    // CHECK: return %[[OUT0]]#2, %[[OUT0]]#1, %[[OUT0]]#0\n  }\n\n  func.func @test_fused_matmul_bias_graph(%x: tensor<8x9xf64>, %w: tensor<10x9xf64>, %bias: tensor<10xf64>) -> (tensor<8x10xf64>, tensor<8x10xf64>) {\n    %y0 = \"oneflow.fused_matmul_bias\"(%x, %w, %bias) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"fused_matmul_bias-0\", scope_symbol_id = 12 : i64} : (tensor<8x9xf64>, tensor<10x9xf64>, tensor<10xf64>) -> tensor<8x10xf64>\n    %y1 = \"oneflow.fused_matmul_bias\"(%x, %w, %bias) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"fused_matmul_bias-0\", scope_symbol_id = 12 : i64} : (tensor<8x9xf64>, tensor<10x9xf64>, tensor<10xf64>) -> tensor<8x10xf64>\n    return %y0, %y1 : tensor<8x10xf64>, tensor<8x10xf64>\n    // CHECK: @test_fused_matmul_bias_graph(%[[X:[a-zA-Z0-9_]+]]: tensor<8x9xf64>, %[[W:[a-zA-Z0-9_]+]]: tensor<10x9xf64>, %[[BIAS:[a-zA-Z0-9_]+]]: tensor<10xf64>)\n    // CHECK: %[[OUT0:[a-zA-Z0-9_]+]]:2 = \"oneflow.grouped_matmul_bias\"(%[[X]], %[[X]], %[[W]], %[[W]], %[[BIAS]], %[[BIAS]])\n    // CHECK: return %[[OUT0]]#1, %[[OUT0]]#0\n  }\n\n  func.func @test_fused_matmul_bias_graph_mixed(%x: tensor<8x9xf64>, %w: tensor<10x9xf64>, %bias: tensor<10xf64>, %w1: tensor<10x9xf64>, %bias1: tensor<10xf64>) -> (tensor<8x10xf64>, tensor<8x10xf64>, tensor<8x10xf64>) {\n    %y0 = \"oneflow.fused_matmul_bias\"(%x, %w, %bias) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"fused_matmul_bias-0\", scope_symbol_id = 12 : i64} : (tensor<8x9xf64>, tensor<10x9xf64>, tensor<10xf64>) -> 
tensor<8x10xf64>\n    %y1 = \"oneflow.fused_matmul_bias\"(%x, %w, %bias) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"fused_matmul_bias-0\", scope_symbol_id = 12 : i64} : (tensor<8x9xf64>, tensor<10x9xf64>, tensor<10xf64>) -> tensor<8x10xf64>\n    %matmul = \"oneflow.matmul\"(%x, %w1) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-matmul-20\", scope_symbol_id = 90 : i64, transpose_a = false, transpose_b = true} : (tensor<8x9xf64>, tensor<10x9xf64>) ->  tensor<8x10xf64>\n    %bias_add = \"oneflow.bias_add\"(%matmul, %bias1) {axis = 1 : si32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"unet.time_embedding.linear_1-bias_add-21\", scope_symbol_id = 90 : i64} : (tensor<8x10xf64>, tensor<10xf64>) -> tensor<8x10xf64>\n    return %y0, %y1, %bias_add : tensor<8x10xf64>, tensor<8x10xf64>, tensor<8x10xf64>\n    // CHECK: @test_fused_matmul_bias_graph_mixed(%[[X:[a-zA-Z0-9_]+]]: tensor<8x9xf64>, %[[W:[a-zA-Z0-9_]+]]: tensor<10x9xf64>, %[[BIAS:[a-zA-Z0-9_]+]]: tensor<10xf64>, %[[W1:[a-zA-Z0-9_]+]]: tensor<10x9xf64>, %[[BIAS1:[a-zA-Z0-9_]+]]: tensor<10xf64>)\n    // CHECK: %[[OUT0:[a-zA-Z0-9_]+]]:3 = \"oneflow.grouped_matmul_bias\"(%[[X]], %[[X]], %[[X]], %[[W1]], %[[W]], %[[W]], %[[BIAS1]], %[[BIAS]], %[[BIAS]])\n    // CHECK: return %[[OUT0]]#2, %[[OUT0]]#1, %[[OUT0]]#0\n  }\n}\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/jit_outline_func.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -ofjob-to-func \\\n// RUN: -convert-to-signless-for-tosa \\\n// RUN: -lower-oneflow-to-tosa=\"full=0 lower-job=0\" \\\n// RUN: --tosa-make-broadcastable \\\n// RUN: -lower-oneflow-to-linalg \\\n// RUN: -tosa-to-tensor \\\n// RUN: | oneflow-opt -pass-pipeline=\"builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg))\" \\\n// RUN: | oneflow-opt -linalg-fuse-elementwise-ops \\\n// RUN: -func-to-ofjob \\\n// RUN: | oneflow-opt -pass-pipeline=\"builtin.module(oneflow.job(outline-jit-function{compile-to-llvm=0}))\" \\\n// RUN: | oneflow-opt -canonicalize \\\n// RUN: | FileCheck --dump-input=always %s\n\n// CHECK: linalg.generic\n// CHECK: oneflow.mlir_jit\n// CHECK-NOT: oneflow.softmax\n\noneflow.job @GraphToRun_11(%arg0: tensor<2x256x1280xf16>, %arg1: tensor<2x77x1280xf16>, %arg2: tensor<2x77x1280xf16>) -> tensor<2x256x1280xf16> {\n  %output = \"oneflow.input\"(%arg0) {data_type = 9 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_11_input.0.0_2\", output_lbns = [\"_GraphToRun_11_input.0.0_2/out\"], scope_symbol_id = 681 : i64, shape = [2 : si64, 256 : si64, 1280 : si64]} : (tensor<2x256x1280xf16>) -> tensor<2x256x1280xf16>\n  %output_0 = \"oneflow.input\"(%arg1) {data_type = 9 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_11_input.0.1_3\", output_lbns = [\"_GraphToRun_11_input.0.1_3/out\"], scope_symbol_id = 681 : i64, shape = [2 : si64, 77 : si64, 1280 : si64]} : (tensor<2x77x1280xf16>) -> tensor<2x77x1280xf16>\n  %output_1 = \"oneflow.input\"(%arg2) {data_type = 9 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_11_input.0.2_4\", output_lbns = [\"_GraphToRun_11_input.0.2_4/out\"], scope_symbol_id = 681 : i64, shape = [2 : si64, 77 : si64, 1280 : si64]} : (tensor<2x77x1280xf16>) -> tensor<2x77x1280xf16>\n  %0 = \"oneflow.reshape\"(%output) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"reshape-0\", scope_symbol_id = 681 : i64, shape = [2 : si64, 256 : si64, 8 : si64, 160 : si64]} : (tensor<2x256x1280xf16>) -> tensor<2x256x8x160xf16>\n  %1 = \"oneflow.reshape\"(%output_0) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"reshape-2\", scope_symbol_id = 681 : i64, shape = [2 : si64, 77 : si64, 8 : si64, 160 : si64]} : (tensor<2x77x1280xf16>) -> tensor<2x77x8x160xf16>\n  %2 = \"oneflow.reshape\"(%output_1) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"reshape-4\", scope_symbol_id = 681 : i64, shape = [2 : si64, 77 : si64, 8 : si64, 160 : si64]} : (tensor<2x77x1280xf16>) -> tensor<2x77x8x160xf16>\n  %3 = \"oneflow.transpose\"(%0) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"transpose-1\", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 681 : i64} : (tensor<2x256x8x160xf16>) -> tensor<2x8x256x160xf16>\n  %4 = \"oneflow.transpose\"(%1) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"transpose-3\", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 681 : i64} : (tensor<2x77x8x160xf16>) -> tensor<2x8x77x160xf16>\n  %5 = \"oneflow.transpose\"(%2) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"transpose-5\", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], 
scope_symbol_id = 681 : i64} : (tensor<2x77x8x160xf16>) -> tensor<2x8x77x160xf16>\n  %6 = \"oneflow.reshape\"(%3) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"reshape-6\", scope_symbol_id = 681 : i64, shape = [16 : si64, 256 : si64, 160 : si64]} : (tensor<2x8x256x160xf16>) -> tensor<16x256x160xf16>\n  %7 = \"oneflow.reshape\"(%4) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"reshape-7\", scope_symbol_id = 681 : i64, shape = [16 : si64, 77 : si64, 160 : si64]} : (tensor<2x8x77x160xf16>) -> tensor<16x77x160xf16>\n  %8 = \"oneflow.reshape\"(%5) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"reshape-9\", scope_symbol_id = 681 : i64, shape = [16 : si64, 77 : si64, 160 : si64]} : (tensor<2x8x77x160xf16>) -> tensor<16x77x160xf16>\n  %9 = \"oneflow.transpose\"(%7) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"transpose-8\", perm = [0 : si32, 2 : si32, 1 : si32], scope_symbol_id = 681 : i64} : (tensor<16x77x160xf16>) -> tensor<16x160x77xf16>\n  %10 = \"oneflow.batch_matmul\"(%6, %9) {alpha = 0.079056941504209485 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"batch_matmul-11\", scope_symbol_id = 681 : i64, transpose_a = false, transpose_b = false} : (tensor<16x256x160xf16>, tensor<16x160x77xf16>) -> tensor<16x256x77xf16>\n  %11 = \"oneflow.softmax\"(%10) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"softmax-12\", scope_symbol_id = 681 : i64} : (tensor<16x256x77xf16>) -> tensor<16x256x77xf16>\n  %12 = \"oneflow.batch_matmul\"(%11, %8) {alpha = 1.000000e+00 : f64, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"batch_matmul-13\", scope_symbol_id = 681 : i64, transpose_a = false, transpose_b = false} : (tensor<16x256x77xf16>, tensor<16x77x160xf16>) -> tensor<16x256x160xf16>\n  %13 = \"oneflow.reshape\"(%12) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"reshape-14\", scope_symbol_id = 681 : i64, shape = [2 : si64, 8 : si64, 256 : si64, 160 : si64]} : (tensor<16x256x160xf16>) -> tensor<2x8x256x160xf16>\n  %14 = \"oneflow.transpose\"(%13) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"transpose-15\", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 681 : i64} : (tensor<2x8x256x160xf16>) -> tensor<2x256x8x160xf16>\n  %15 = \"oneflow.reshape\"(%14) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"reshape-16\", scope_symbol_id = 681 : i64, shape = [2 : si64, 256 : si64, 1280 : si64]} : (tensor<2x256x8x160xf16>) -> tensor<2x256x1280xf16>\n  %output_2 = \"oneflow.output\"(%15) {data_type = 9 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_11_output.0.0_2\", output_lbns = [\"_GraphToRun_11_output.0.0_2/out\"], scope_symbol_id = 681 : i64, shape = [2 : si64, 256 : si64, 1280 : si64]} : (tensor<2x256x1280xf16>) -> tensor<2x256x1280xf16>\n  oneflow.return %output_2 : tensor<2x256x1280xf16>\n}\n\n// CHECK: oneflow.mlir_jit\n// CHECK-NOT: oneflow.cast\noneflow.job @GraphToRun_1(%arg0: tensor<2x5xsi64>, %arg1: tensor<1xf32>) -> tensor<2x5xf32> {\n  %output = \"oneflow.input\"(%arg0) {data_type = 6 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_1_input.0.0_2\", output_lbns = 
[\"_GraphToRun_1_input.0.0_2/out\"], scope_symbol_id = 34 : i64, shape = [2 : si64, 5 : si64]} : (tensor<2x5xsi64>) -> tensor<2x5xsi64>\n  %output_0 = \"oneflow.input\"(%arg1) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_1_input.0.1_3\", output_lbns = [\"_GraphToRun_1_input.0.1_3/out\"], scope_symbol_id = 34 : i64, shape = [1 : si64]} : (tensor<1xf32>) -> tensor<1xf32>\n  %0 = \"oneflow.cast\"(%output) {device_name = [\"@0:0\"], device_tag = \"cpu\", dtype = 2 : i32, hierarchy = [1], op_name = \"fw-cast-0\", pin_memory = false, scope_symbol_id = 41 : i64} : (tensor<2x5xsi64>) -> tensor<2x5xf32>\n  %1 = \"oneflow.scalar_mul_by_tensor\"(%0, %output_0) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"fw-broadcast_mul-1-mlir-gen-2\", scope_symbol_id = 41 : i64} : (tensor<2x5xf32>, tensor<1xf32>) -> tensor<2x5xf32>\n  %output_1 = \"oneflow.output\"(%1) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_1_output.0.0_2\", output_lbns = [\"_GraphToRun_1_output.0.0_2/out\"], scope_symbol_id = 34 : i64, shape = [2 : si64, 5 : si64]} : (tensor<2x5xf32>) -> tensor<2x5xf32>\n  oneflow.return %output_1 : tensor<2x5xf32>\n}\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/kernel_launch/OKLPass/lower_launcher_to_llvm_ptr.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -lower-launcher-to-llvm-ptr \\\n// RUN: | FileCheck %s\n\n// CHECK:  func.func @okl_subgraph(%[[ARG:[a-zA-Z0-9_]+]]: !llvm.ptr<i8>) attributes {llvm.emit_c_interface} {\n// CHECK:  %[[ARG0:[a-zA-Z0-9_]+]] = builtin.unrealized_conversion_cast %[[ARG]] : !llvm.ptr<i8> to !okl.launcher_ctx\n// CHECK:  \"okl.get_tensor_from_arg\"(%[[ARG0]]) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32>\n// CHECK:  \"okl.get_tensor_as_ret\"(%[[ARG0]], %[[ARG3:[a-zA-Z0-9_]+]]) {index = 1 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n\nmodule {\n  func.func @okl_subgraph(%arg0: !okl.launcher_ctx) attributes {cuda_graph_support = false, pool_size = 1024 : i64} {\n    \"okl.wrapper_kernel\"() ({\n      %0 = \"okl.get_tensor_from_arg\"(%arg0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32>\n      %1 = \"oneflow.relu\"(%0) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"relu-0\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n      %2 = \"okl.tensor_to_pool\"(%arg0, %1) {offset = 0 : i64} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n      okl.return\n    }) {index = 0 : i32} : () -> ()\n    \"okl.wrapper_kernel\"() ({\n      %0 = \"okl.pool_to_tensor\"(%arg0) {offset = 0 : i64} : (!okl.launcher_ctx) -> tensor<2xf32>\n      %1 = \"oneflow.tanh\"(%0) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"tanh-1\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n      %2 = \"okl.tensor_to_pool\"(%arg0, %1) {offset = 512 : i64} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n      okl.return\n    }) {index = 1 : i32} : () -> ()\n    \"okl.wrapper_kernel\"() ({\n      %0 = \"okl.pool_to_tensor\"(%arg0) {offset = 512 : i64} : (!okl.launcher_ctx) -> tensor<2xf32>\n      %1 = \"oneflow.arg_sort\"(%0) {device_name = [\"@0:0\"], device_tag = \"cpu\", direction = \"ASCENDING\", hierarchy = [1], op_name = \"arg_sort-2\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32>\n      %2 = \"okl.get_tensor_as_ret\"(%arg0, %1) {index = 0 : i32} : (!okl.launcher_ctx, tensor<2xsi32>) -> tensor<2xsi32>\n      okl.return\n    }) {index = 2 : i32} : () -> ()\n    \"okl.wrapper_kernel\"() ({\n      %0 = \"okl.pool_to_tensor\"(%arg0) {offset = 512 : i64} : (!okl.launcher_ctx) -> tensor<2xf32>\n      %1 = \"okl.get_tensor_from_ret\"(%arg0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xsi32>\n      %2 = \"oneflow.dim_gather\"(%0, %1) {device_name = [\"@0:0\"], device_tag = \"cpu\", dim = 0 : si32, hierarchy = [1], op_name = \"dim_gather-3\", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32>\n      %3 = \"okl.get_tensor_as_ret\"(%arg0, %2) {index = 1 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n      okl.return\n    }) {index = 3 : i32} : () -> ()\n    return\n  }\n}\n\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/kernel_launch/OKLPass/lower_okl_to_llvm_call.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -lower-okl-to-llvm-call \\\n// RUN: | FileCheck %s\n\n// CHECK-COUNT-4: llvm.call @okl_llvm_func\nmodule {\n  func.func @okl_subgraph(%arg0: !llvm.ptr<i8>) attributes {llvm.emit_c_interface} {\n    %0 = builtin.unrealized_conversion_cast %arg0 : !llvm.ptr<i8> to !okl.launcher_ctx\n    \"okl.wrapper_kernel\"() ({\n      %1 = \"okl.get_tensor_from_arg\"(%0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32>\n      %2 = \"oneflow.relu\"(%1) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"relu-0\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n      %3 = \"okl.tensor_to_pool\"(%0, %2) {offset = 0 : i64} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n      okl.return\n    }) {index = 0 : i32} : () -> ()\n    \"okl.wrapper_kernel\"() ({\n      %1 = \"okl.pool_to_tensor\"(%0) {offset = 0 : i64} : (!okl.launcher_ctx) -> tensor<2xf32>\n      %2 = \"oneflow.tanh\"(%1) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"tanh-1\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n      %3 = \"okl.tensor_to_pool\"(%0, %2) {offset = 512 : i64} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n      okl.return\n    }) {index = 1 : i32} : () -> ()\n    \"okl.wrapper_kernel\"() ({\n      %1 = \"okl.pool_to_tensor\"(%0) {offset = 512 : i64} : (!okl.launcher_ctx) -> tensor<2xf32>\n      %2 = \"oneflow.arg_sort\"(%1) {device_name = [\"@0:0\"], device_tag = \"cpu\", direction = \"ASCENDING\", hierarchy = [1], op_name = \"arg_sort-2\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32>\n      %3 = \"okl.get_tensor_as_ret\"(%0, %2) {index = 0 : i32} : (!okl.launcher_ctx, tensor<2xsi32>) -> tensor<2xsi32>\n      okl.return\n    }) {index = 2 : i32} : () -> ()\n    \"okl.wrapper_kernel\"() ({\n      %1 = \"okl.pool_to_tensor\"(%0) {offset = 512 : i64} : (!okl.launcher_ctx) -> tensor<2xf32>\n      %2 = \"okl.get_tensor_from_ret\"(%0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xsi32>\n      %3 = \"oneflow.dim_gather\"(%1, %2) {device_name = [\"@0:0\"], device_tag = \"cpu\", dim = 0 : si32, hierarchy = [1], op_name = \"dim_gather-3\", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32>\n      %4 = \"okl.get_tensor_as_ret\"(%0, %3) {index = 1 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n      okl.return\n    }) {index = 3 : i32} : () -> ()\n    return\n  }\n}\n\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/kernel_launch/OKLPass/tag_cuda_graph_support.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -tag-cuda-graph-support \\\n// RUN: | FileCheck %s\n\n// CHECK:  func.func @okl_subgraph(%[[ARG0:[a-zA-Z0-9_]+]]: !okl.launcher_ctx) attributes {cuda_graph_support = false, pool_size = 1024 : i64}\n\nmodule {\n  func.func @okl_subgraph(%arg0: !okl.launcher_ctx) attributes {pool_size = 1024 : i64} {\n    \"okl.wrapper_kernel\"() ({\n      %0 = \"okl.get_tensor_from_arg\"(%arg0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xf32>\n      %1 = \"oneflow.relu\"(%0) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"relu-0\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n      %2 = \"okl.tensor_to_pool\"(%arg0, %1) {offset = 0 : i64} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n      okl.return\n    }) {index = 0 : i32} : () -> ()\n    \"okl.wrapper_kernel\"() ({\n      %0 = \"okl.pool_to_tensor\"(%arg0) {offset = 0 : i64} : (!okl.launcher_ctx) -> tensor<2xf32>\n      %1 = \"oneflow.tanh\"(%0) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"tanh-1\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n      %2 = \"okl.tensor_to_pool\"(%arg0, %1) {offset = 512 : i64} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n      okl.return\n    }) {index = 1 : i32} : () -> ()\n    \"okl.wrapper_kernel\"() ({\n      %0 = \"okl.pool_to_tensor\"(%arg0) {offset = 512 : i64} : (!okl.launcher_ctx) -> tensor<2xf32>\n      %1 = \"oneflow.arg_sort\"(%0) {device_name = [\"@0:0\"], device_tag = \"cpu\", direction = \"ASCENDING\", hierarchy = [1], op_name = \"arg_sort-2\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32>\n      %2 = \"okl.get_tensor_as_ret\"(%arg0, %1) {index = 0 : i32} : (!okl.launcher_ctx, tensor<2xsi32>) -> tensor<2xsi32>\n      okl.return\n    }) {index = 2 : i32} : () -> ()\n    \"okl.wrapper_kernel\"() ({\n      %0 = \"okl.pool_to_tensor\"(%arg0) {offset = 512 : i64} : (!okl.launcher_ctx) -> tensor<2xf32>\n      %1 = \"okl.get_tensor_from_ret\"(%arg0) {index = 0 : i32} : (!okl.launcher_ctx) -> tensor<2xsi32>\n      %2 = \"oneflow.dim_gather\"(%0, %1) {device_name = [\"@0:0\"], device_tag = \"cpu\", dim = 0 : si32, hierarchy = [1], op_name = \"dim_gather-3\", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32>\n      %3 = \"okl.get_tensor_as_ret\"(%arg0, %2) {index = 1 : i32} : (!okl.launcher_ctx, tensor<2xf32>) -> tensor<2xf32>\n      okl.return\n    }) {index = 3 : i32} : () -> ()\n    return\n  }\n}\n\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/kernel_launch/OKMPass/extract_okm_tensor.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -extract-okm-tensor \\\n// RUN: | FileCheck %s\n\n// CHECK: \"okm.arg_to_tensor\"() {index = 0 : i32} : () -> tensor<2xf32>\n// CHECK: \"okm.tensor_to_ret\"(%[[ARG0:[a-zA-Z0-9_]+]]) {index = 0 : i32} : (tensor<2xsi32>) -> tensor<2xsi32>\n// CHECK: \"okm.tensor_to_ret\"(%[[ARG1:[a-zA-Z0-9_]+]]) {index = 1 : i32} : (tensor<2xf32>) -> tensor<2xf32>\n\nmodule {\n  func.func @_mlir_oneflow_subgraph0(%arg0: tensor<2xf32>) -> (tensor<2xsi32>, tensor<2xf32>) attributes {llvm.emit_c_interface} {\n    %0 = \"oneflow.relu\"(%arg0) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"relu-0\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n    %1 = \"oneflow.tanh\"(%0) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"tanh-1\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n    %2 = \"oneflow.arg_sort\"(%1) {device_name = [\"@0:0\"], device_tag = \"cpu\", direction = \"ASCENDING\", hierarchy = [1], op_name = \"arg_sort-2\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32>\n    %3 = \"oneflow.dim_gather\"(%1, %2) {device_name = [\"@0:0\"], device_tag = \"cpu\", dim = 0 : si32, hierarchy = [1], op_name = \"dim_gather-3\", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32>\n    return %2, %3 : tensor<2xsi32>, tensor<2xf32>\n  }\n}\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/kernel_launch/OKMPass/okm_to_okl.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -convert-okm-to-okl \\\n// RUN: | FileCheck %s\n\n// CHECK:  func.func @okl_subgraph(%arg0: !okl.launcher_ctx) attributes {pool_size = 1024 : i64} {\n// CHECK-COUNT-4:    \"okl.wrapper_kernel\"()\n\nmodule {\n  func.func @okm_alloc_subgraph0() {\n    %c512 = arith.constant 512 : index\n    %c0 = arith.constant 0 : index\n    %0 = \"okm.alloc_memref\"() : () -> memref<1024xi8>\n    %1 = \"okm.arg_to_memref\"() {index = 0 : i32} : () -> memref<2xf32>\n    %2 = memref.view %0[%c0][] : memref<1024xi8> to memref<2xf32>\n    %3 = \"okm.wrapper_kernel\"(%1, %2) ({\n      %12 = bufferization.to_tensor %1 : memref<2xf32>\n      %13 = \"oneflow.relu\"(%12) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"relu-0\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n      %14 = bufferization.to_memref %13 : memref<2xf32>\n      okm.return %14 : memref<2xf32>\n    }) : (memref<2xf32>, memref<2xf32>) -> memref<2xf32>\n    %4 = memref.view %0[%c512][] : memref<1024xi8> to memref<2xf32>\n    %5 = \"okm.wrapper_kernel\"(%2, %4) ({\n      %12 = bufferization.to_tensor %2 : memref<2xf32>\n      %13 = \"oneflow.tanh\"(%12) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"tanh-1\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n      %14 = bufferization.to_memref %13 : memref<2xf32>\n      okm.return %14 : memref<2xf32>\n    }) : (memref<2xf32>, memref<2xf32>) -> memref<2xf32>\n    %6 = \"okm.ret_to_memref\"() {index = 0 : i32} : () -> memref<2xsi32>\n    %7 = \"okm.wrapper_kernel\"(%4, %6) ({\n      %12 = bufferization.to_tensor %4 : memref<2xf32>\n      %13 = \"oneflow.arg_sort\"(%12) {device_name = [\"@0:0\"], device_tag = \"cpu\", direction = \"ASCENDING\", hierarchy = [1], op_name = \"arg_sort-2\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32>\n      %14 = bufferization.to_memref %13 : memref<2xsi32>\n      okm.return %14 : memref<2xsi32>\n    }) : (memref<2xf32>, memref<2xsi32>) -> memref<2xsi32>\n    %8 = \"okm.ret_to_memref\"() {index = 1 : i32} : () -> memref<2xf32>\n    %9 = \"okm.wrapper_kernel\"(%4, %6, %8) ({\n      %12 = bufferization.to_tensor %4 : memref<2xf32>\n      %13 = bufferization.to_tensor %6 : memref<2xsi32>\n      %14 = \"oneflow.dim_gather\"(%12, %13) {device_name = [\"@0:0\"], device_tag = \"cpu\", dim = 0 : si32, hierarchy = [1], op_name = \"dim_gather-3\", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32>\n      %15 = bufferization.to_memref %14 : memref<2xf32>\n      okm.return %15 : memref<2xf32>\n    }) : (memref<2xf32>, memref<2xsi32>, memref<2xf32>) -> memref<2xf32>\n    %10 = \"okm.memref_to_ret\"(%7) {index = 0 : i32} : (memref<2xsi32>) -> memref<2xsi32>\n    %11 = \"okm.memref_to_ret\"(%9) {index = 1 : i32} : (memref<2xf32>) -> memref<2xf32>\n    return\n  }\n}\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/kernel_launch/OKMPass/opt_okm_memref.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -opt-okm-memref \\\n// RUN: | FileCheck %s\n\n// CHECK: func.func @okm_alloc_subgraph\n// CHECK: \"okm.alloc_memref\"()\n// CHECK: memref.view\n\nmodule {\n  func.func @okm_wrap_subgraph0() {\n    %0 = \"okm.arg_to_memref\"() {index = 0 : i32} : () -> memref<2xf32>\n    %1 = \"okm.plan_memref\"() : () -> memref<2xf32>\n    %2 = \"okm.wrapper_kernel\"(%0, %1) ({\n      %11 = bufferization.to_tensor %0 : memref<2xf32>\n      %12 = \"oneflow.relu\"(%11) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"relu-0\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n      %13 = bufferization.to_memref %12 : memref<2xf32>\n      okm.return %13 : memref<2xf32>\n    }) : (memref<2xf32>, memref<2xf32>) -> memref<2xf32>\n    %3 = \"okm.plan_memref\"() : () -> memref<2xf32>\n    %4 = \"okm.wrapper_kernel\"(%1, %3) ({\n      %11 = bufferization.to_tensor %1 : memref<2xf32>\n      %12 = \"oneflow.tanh\"(%11) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"tanh-1\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n      %13 = bufferization.to_memref %12 : memref<2xf32>\n      okm.return %13 : memref<2xf32>\n    }) : (memref<2xf32>, memref<2xf32>) -> memref<2xf32>\n    %5 = \"okm.ret_to_memref\"() {index = 0 : i32} : () -> memref<2xsi32>\n    %6 = \"okm.wrapper_kernel\"(%3, %5) ({\n      %11 = bufferization.to_tensor %3 : memref<2xf32>\n      %12 = \"oneflow.arg_sort\"(%11) {device_name = [\"@0:0\"], device_tag = \"cpu\", direction = \"ASCENDING\", hierarchy = [1], op_name = \"arg_sort-2\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32>\n      %13 = bufferization.to_memref %12 : memref<2xsi32>\n      okm.return %13 : memref<2xsi32>\n    }) : (memref<2xf32>, memref<2xsi32>) -> memref<2xsi32>\n    %7 = \"okm.ret_to_memref\"() {index = 1 : i32} : () -> memref<2xf32>\n    %8 = \"okm.wrapper_kernel\"(%3, %5, %7) ({\n      %11 = bufferization.to_tensor %3 : memref<2xf32>\n      %12 = bufferization.to_tensor %5 : memref<2xsi32>\n      %13 = \"oneflow.dim_gather\"(%11, %12) {device_name = [\"@0:0\"], device_tag = \"cpu\", dim = 0 : si32, hierarchy = [1], op_name = \"dim_gather-3\", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32>\n      %14 = bufferization.to_memref %13 : memref<2xf32>\n      okm.return %14 : memref<2xf32>\n    }) : (memref<2xf32>, memref<2xsi32>, memref<2xf32>) -> memref<2xf32>\n    %9 = \"okm.memref_to_ret\"(%6) {index = 0 : i32} : (memref<2xsi32>) -> memref<2xsi32>\n    %10 = \"okm.memref_to_ret\"(%8) {index = 1 : i32} : (memref<2xf32>) -> memref<2xf32>\n    return\n  }\n}\n\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/kernel_launch/OKMPass/wrap_okm_kernel.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -wrap-okm-kernel \\\n// RUN: | FileCheck %s\n\n// CHECK: module {\n// CHECK:   func.func @okm_wrap_subgraph0() {\n// CHECK:     %[[ARG0:[a-zA-Z0-9_]+]] = \"okm.arg_to_memref\"() {index = 0 : i32} : () -> memref<2xf32>\n// CHECK:     %[[ARG1:[a-zA-Z0-9_]+]] = \"okm.plan_memref\"() : () -> memref<2xf32>\n// CHECK:     %[[ARG2:[a-zA-Z0-9_]+]] = \"okm.wrapper_kernel\"(%[[ARG0]], %[[ARG1]]) ({\n// CHECK:       %[[ARG11:[a-zA-Z0-9_]+]] = bufferization.to_tensor %[[ARG0]] : memref<2xf32>\n// CHECK:       %[[ARG12:[a-zA-Z0-9_]+]] = \"oneflow.relu\"(%[[ARG11]]) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"relu-0\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n// CHECK:       %[[ARG13:[a-zA-Z0-9_]+]] = bufferization.to_memref %[[ARG12]] : memref<2xf32>\n// CHECK:       okm.return %[[ARG13:[a-zA-Z0-9_]+]] : memref<2xf32>\n// CHECK:     }) : (memref<2xf32>, memref<2xf32>) -> memref<2xf32>\n// CHECK:     %[[ARG3:[a-zA-Z0-9_]+]] = \"okm.plan_memref\"() : () -> memref<2xf32>\n// CHECK:     %[[ARG4:[a-zA-Z0-9_]+]] = \"okm.wrapper_kernel\"(%[[ARG1]], %[[ARG3]]) ({\n// CHECK:       %[[ARG11:[a-zA-Z0-9_]+]] = bufferization.to_tensor %[[ARG1]] : memref<2xf32>\n// CHECK:       %[[ARG12:[a-zA-Z0-9_]+]] = \"oneflow.tanh\"(%[[ARG11]]) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"tanh-1\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n// CHECK:       %[[ARG13:[a-zA-Z0-9_]+]] = bufferization.to_memref %[[ARG12:[a-zA-Z0-9_]+]] : memref<2xf32>\n// CHECK:       okm.return %[[ARG13:[a-zA-Z0-9_]+]] : memref<2xf32>\n// CHECK:     }) : (memref<2xf32>, memref<2xf32>) -> memref<2xf32>\n// CHECK:     %[[ARG5:[a-zA-Z0-9_]+]] = \"okm.ret_to_memref\"() {index = 0 : i32} : () -> memref<2xsi32>\n// CHECK:     %[[ARG6:[a-zA-Z0-9_]+]] = \"okm.wrapper_kernel\"(%[[ARG3]], %[[ARG5]]) ({\n// CHECK:       %[[ARG11:[a-zA-Z0-9_]+]] = bufferization.to_tensor %[[ARG3]] : memref<2xf32>\n// CHECK:       %[[ARG12:[a-zA-Z0-9_]+]] = \"oneflow.arg_sort\"(%[[ARG11]]) {device_name = [\"@0:0\"], device_tag = \"cpu\", direction = \"ASCENDING\", hierarchy = [1], op_name = \"arg_sort-2\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32>\n// CHECK:       %[[ARG13:[a-zA-Z0-9_]+]] = bufferization.to_memref %[[ARG12]] : memref<2xsi32>\n// CHECK:       okm.return %[[ARG13:[a-zA-Z0-9_]+]] : memref<2xsi32>\n// CHECK:     }) : (memref<2xf32>, memref<2xsi32>) -> memref<2xsi32>\n// CHECK:     %[[ARG7:[a-zA-Z0-9_]+]] = \"okm.ret_to_memref\"() {index = 1 : i32} : () -> memref<2xf32>\n// CHECK:     %[[ARG8:[a-zA-Z0-9_]+]] = \"okm.wrapper_kernel\"(%[[ARG3]], %[[ARG5]], %7) ({\n// CHECK:       %[[ARG11:[a-zA-Z0-9_]+]] = bufferization.to_tensor %[[ARG3]] : memref<2xf32>\n// CHECK:       %[[ARG12:[a-zA-Z0-9_]+]] = bufferization.to_tensor %[[ARG5]] : memref<2xsi32>\n// CHECK:       %[[ARG13:[a-zA-Z0-9_]+]] = \"oneflow.dim_gather\"(%[[ARG11]], %[[ARG12]]) {device_name = [\"@0:0\"], device_tag = \"cpu\", dim = 0 : si32, hierarchy = [1], op_name = \"dim_gather-3\", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32>\n// CHECK:       %[[ARG14:[a-zA-Z0-9_]+]] = bufferization.to_memref %[[ARG13]] : memref<2xf32>\n// CHECK:       okm.return %[[ARG14]] : memref<2xf32>\n// CHECK:     }) : (memref<2xf32>, memref<2xsi32>, memref<2xf32>) -> memref<2xf32>\n// CHECK:     %[[ARG9:[a-zA-Z0-9_]+]] = \"okm.memref_to_ret\"(%[[ARG6]]) {index = 0 : i32} : (memref<2xsi32>) -> memref<2xsi32>\n// 
CHECK:     %[[ARG10:[a-zA-Z0-9_]+]] = \"okm.memref_to_ret\"(%[[ARG8]]) {index = 1 : i32} : (memref<2xf32>) -> memref<2xf32>\n// CHECK:     return\n// CHECK:   }\n// CHECK: }\n\nmodule {\n  func.func @okm_subgraph0(%arg0: tensor<2xf32>) -> (tensor<2xsi32>, tensor<2xf32>) attributes {llvm.emit_c_interface} {\n    %0 = \"okm.arg_to_tensor\"() {index = 0 : i32} : () -> tensor<2xf32>\n    %1 = \"oneflow.relu\"(%0) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"relu-0\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n    %2 = \"oneflow.tanh\"(%1) {device_name = [\"@0:0\"], device_tag = \"cpu\", hierarchy = [1], op_name = \"tanh-1\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n    %3 = \"oneflow.arg_sort\"(%2) {device_name = [\"@0:0\"], device_tag = \"cpu\", direction = \"ASCENDING\", hierarchy = [1], op_name = \"arg_sort-2\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32>\n    %4 = \"oneflow.dim_gather\"(%2, %3) {device_name = [\"@0:0\"], device_tag = \"cpu\", dim = 0 : si32, hierarchy = [1], op_name = \"dim_gather-3\", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32>\n    %5 = \"okm.tensor_to_ret\"(%3) {index = 0 : i32} : (tensor<2xsi32>) -> tensor<2xsi32>\n    %6 = \"okm.tensor_to_ret\"(%4) {index = 1 : i32} : (tensor<2xf32>) -> tensor<2xf32>\n    return %5, %6 : tensor<2xsi32>, tensor<2xf32>\n  }\n}\n\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/kernel_launch/OneFlowPass/aggregate_compute_ops.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -aggregate-compute-ops \\\n// RUN: | FileCheck %s\n\n// CHECK: %[[ARG0:[a-zA-Z0-9_]+]] = \"oneflow.arg_sort\"\n// CHECK: %[[ARG1:[a-zA-Z0-9_]+]] = \"oneflow.dim_gather\"\n// CHECK: \"oneflow.output\"(%[[ARG0]])\n\nmodule {\n  oneflow.job @GraphToRun_1(%arg0: tensor<2xf32>) -> (tensor<2xsi32>, tensor<2xf32>) {\n    %output = \"oneflow.input\"(%arg0) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_1_input.0.0_2\", output_lbns = [\"_GraphToRun_1_input.0.0_2/out\"], scope_symbol_id = 30 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32>\n    %0 = \"oneflow.relu\"(%output) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"relu-0\", scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n    %1 = \"oneflow.tanh\"(%0) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"tanh-1\", scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n    %2 = \"oneflow.arg_sort\"(%1) {device_name = [\"@0:0\"], device_tag = \"cuda\", direction = \"ASCENDING\", hierarchy = [1], op_name = \"arg_sort-2\", scope_symbol_id = 30 : i64} : (tensor<2xf32>) -> tensor<2xsi32>\n    %output_0 = \"oneflow.output\"(%2) {data_type = 5 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_1_output.0.0.1_4\", output_lbns = [\"_GraphToRun_1_output.0.0.1_4/out\"], scope_symbol_id = 30 : i64, shape = [2 : si64]} : (tensor<2xsi32>) -> tensor<2xsi32>\n    %3 = \"oneflow.dim_gather\"(%1, %2) {device_name = [\"@0:0\"], device_tag = \"cuda\", dim = 0 : si32, hierarchy = [1], op_name = \"dim_gather-3\", scope_symbol_id = 30 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32>\n    %output_1 = \"oneflow.output\"(%3) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_1_output.0.0.0_3\", output_lbns = [\"_GraphToRun_1_output.0.0.0_3/out\"], scope_symbol_id = 30 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32>\n    oneflow.return %output_0, %output_1 : tensor<2xsi32>, tensor<2xf32>\n  }\n}\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/kernel_launch/OneFlowPass/wrap_ops_to_kernel_launch/cuda_graph.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -wrap-ops-to-kernel-launch=\"mode=cuda_graph\" \\\n// RUN: | FileCheck %s\n\n// CHECK:  func.func @_mlir_oneflow_subgraph1\n// CHECK:  func.func @_mlir_oneflow_subgraph0\n\nmodule {\n  oneflow.job @GraphToRun_0(%arg0: tensor<2xf32>) -> (tensor<2xsi32>, tensor<2xf32>) {\n    %output = \"oneflow.input\"(%arg0) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_0_input.0.0_2\", output_lbns = [\"_GraphToRun_0_input.0.0_2/out\"], scope_symbol_id = 12 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32>\n    %0 = \"oneflow.relu\"(%output) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"relu-0\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n    %1 = \"oneflow.tanh\"(%0) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"tanh-1\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n    %2 = \"oneflow.arg_sort\"(%1) {device_name = [\"@0:0\"], device_tag = \"cuda\", direction = \"ASCENDING\", hierarchy = [1], op_name = \"arg_sort-2\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32>\n    %3 = \"oneflow.dim_gather\"(%1, %2) {device_name = [\"@0:0\"], device_tag = \"cuda\", dim = 0 : si32, hierarchy = [1], op_name = \"dim_gather-3\", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32>\n    %output_0 = \"oneflow.output\"(%3) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_0_output.0.0.0_3\", output_lbns = [\"_GraphToRun_0_output.0.0.0_3/out\"], scope_symbol_id = 12 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32>\n    %output_1 = \"oneflow.output\"(%2) {data_type = 5 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_0_output.0.0.1_4\", output_lbns = [\"_GraphToRun_0_output.0.0.1_4/out\"], scope_symbol_id = 12 : i64, shape = [2 : si64]} : (tensor<2xsi32>) -> tensor<2xsi32>\n    oneflow.return %output_1, %output_0 : tensor<2xsi32>, tensor<2xf32>\n  }\n}"
  },
  {
    "path": "oneflow/ir/test/OneFlow/kernel_launch/OneFlowPass/wrap_ops_to_kernel_launch/lit.local.cfg",
    "content": "if not config.BUILD_CUDA:\n  config.unsupported = True"
  },
  {
    "path": "oneflow/ir/test/OneFlow/kernel_launch/OneFlowPass/wrap_ops_to_kernel_launch/simple.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -wrap-ops-to-kernel-launch=\"mode=simple\" \\\n// RUN: | FileCheck %s\n\n// CHECK-NOT:  func.func @_mlir_oneflow_subgraph1\n// CHECK:  func.func @_mlir_oneflow_subgraph0\n\nmodule {\n  oneflow.job @GraphToRun_0(%arg0: tensor<2xf32>) -> (tensor<2xsi32>, tensor<2xf32>) {\n    %output = \"oneflow.input\"(%arg0) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_0_input.0.0_2\", output_lbns = [\"_GraphToRun_0_input.0.0_2/out\"], scope_symbol_id = 12 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32>\n    %0 = \"oneflow.relu\"(%output) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"relu-0\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n    %1 = \"oneflow.tanh\"(%0) {device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], op_name = \"tanh-1\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xf32>\n    %2 = \"oneflow.arg_sort\"(%1) {device_name = [\"@0:0\"], device_tag = \"cuda\", direction = \"ASCENDING\", hierarchy = [1], op_name = \"arg_sort-2\", scope_symbol_id = 12 : i64} : (tensor<2xf32>) -> tensor<2xsi32>\n    %3 = \"oneflow.dim_gather\"(%1, %2) {device_name = [\"@0:0\"], device_tag = \"cuda\", dim = 0 : si32, hierarchy = [1], op_name = \"dim_gather-3\", scope_symbol_id = 12 : i64} : (tensor<2xf32>, tensor<2xsi32>) -> tensor<2xf32>\n    %output_0 = \"oneflow.output\"(%3) {data_type = 2 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_0_output.0.0.0_3\", output_lbns = [\"_GraphToRun_0_output.0.0.0_3/out\"], scope_symbol_id = 12 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32>\n    %output_1 = \"oneflow.output\"(%2) {data_type = 5 : i32, device_name = [\"@0:0\"], device_tag = \"cuda\", hierarchy = [1], is_dynamic = false, nd_sbp = [\"B\"], op_name = \"_GraphToRun_0_output.0.0.1_4\", output_lbns = [\"_GraphToRun_0_output.0.0.1_4/out\"], scope_symbol_id = 12 : i64, shape = [2 : si64]} : (tensor<2xsi32>) -> tensor<2xsi32>\n    oneflow.return %output_1, %output_0 : tensor<2xsi32>, tensor<2xf32>\n  }\n}"
  },
  {
    "path": "oneflow/ir/test/OneFlow/kernel_launch/test_resnet.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s\n\nimport os\nimport sys\n\nsys.path.append(os.path.abspath(os.path.dirname(__file__)))\nsys.path.append(os.path.abspath(os.path.dirname(__file__)) + \"/..\")\n\n\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom networks.resnet50 import resnet50\n\n\ndef _test_okl_resnet(test_case):\n    x = flow.randn(2, 3, 224, 224)\n    resnet = resnet50()\n    x = x.cuda()\n    resnet.to(\"cuda\")\n\n    eager_res = resnet(x)\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.resnet = resnet\n\n        def build(self, x):\n            return self.resnet(x)\n\n    graph_to_run = GraphToRun()\n    lazy_res = graph_to_run(x)\n    test_case.assertTrue(\n        np.allclose(eager_res.numpy(), lazy_res.numpy(), rtol=1e-4, atol=1e-4)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestOKLResNet(flow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_STDOUT\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_CSE\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_FUSE_FORWARD_OPS\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_FUSE_KERNEL_LAUNCH\"] = \"1\"\n        os.environ[\"ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH\"] = \"1\"\n\n    @unittest.skipUnless(flow.sysconfig.with_cuda(), \"only test cpu cases\")\n    def test_okl_resnet(test_case):\n        _test_okl_resnet(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/networks/__init__.py",
    "content": ""
  },
  {
    "path": "oneflow/ir/test/OneFlow/networks/resnet50.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nimport oneflow.nn as nn\nfrom oneflow import Tensor\nfrom typing import Type, Any, Callable, Union, List, Optional\n\n\ndef conv3x3(\n    in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1\n) -> nn.Conv2d:\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(\n        in_planes,\n        out_planes,\n        kernel_size=3,\n        stride=stride,\n        padding=dilation,\n        groups=groups,\n        bias=False,\n        dilation=dilation,\n    )\n\n\ndef conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion: int = 1\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        if groups != 1 or base_width != 64:\n            raise ValueError(\"BasicBlock only supports groups=1 and base_width=64\")\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.relu = nn.ReLU()\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = norm_layer(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion: int = 4\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super(Bottleneck, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.0)) * groups\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(inplanes, 
width)\n        self.bn1 = norm_layer(width)\n        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n        self.bn2 = norm_layer(width)\n        self.conv3 = conv1x1(width, planes * self.expansion)\n        self.bn3 = norm_layer(planes * self.expansion)\n        self.relu = nn.ReLU()\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(nn.Module):\n    def __init__(\n        self,\n        block: Type[Union[BasicBlock, Bottleneck]],\n        layers: List[int],\n        num_classes: int = 1000,\n        zero_init_residual: bool = False,\n        groups: int = 1,\n        width_per_group: int = 64,\n        replace_stride_with_dilation: Optional[List[bool]] = None,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super(ResNet, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\n                \"replace_stride_with_dilation should be None \"\n                \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation)\n            )\n        self.groups = groups\n        self.base_width = width_per_group\n        self.conv1 = nn.Conv2d(\n            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False\n        )\n        self.bn1 = norm_layer(self.inplanes)\n        self.relu = nn.ReLU()\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(\n            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]\n        )\n        self.layer3 = self._make_layer(\n            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]\n        )\n        self.layer4 = self._make_layer(\n            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]\n        )\n        self.avgpool = nn.AvgPool2d((7, 7))\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in 
self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]\n                elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]\n\n    def _make_layer(\n        self,\n        block: Type[Union[BasicBlock, Bottleneck]],\n        planes: int,\n        blocks: int,\n        stride: int = 1,\n        dilate: bool = False,\n    ) -> nn.Sequential:\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(\n            block(\n                self.inplanes,\n                planes,\n                stride,\n                downsample,\n                self.groups,\n                self.base_width,\n                previous_dilation,\n                norm_layer,\n            )\n        )\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(\n                block(\n                    self.inplanes,\n                    planes,\n                    groups=self.groups,\n                    base_width=self.base_width,\n                    dilation=self.dilation,\n                    norm_layer=norm_layer,\n                )\n            )\n\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, x: Tensor) -> Tensor:\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = flow.flatten(x, 1)\n        x = self.fc(x)\n\n        return x\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self._forward_impl(x)\n\n\ndef _resnet(\n    arch: str,\n    block: Type[Union[BasicBlock, Bottleneck]],\n    layers: List[int],\n    **kwargs: Any\n) -> ResNet:\n    model = ResNet(block, layers, **kwargs)\n    return model\n\n\ndef resnet50(**kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-5\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n    \"\"\"\n    return _resnet(\"resnet50\", Bottleneck, [3, 4, 6, 3], **kwargs)\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/oneflow-opt.mlir",
    "content": "// RUN: oneflow-opt --show-dialects | FileCheck %s\n// CHECK: Available Dialects:\n// CHECK: oneflow\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/oneflow-translate.mlir",
    "content": "// RUN: oneflow-translate --help | FileCheck %s\n// CHECK: --import-oneflow-job\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/psig/error_parse.mlir",
    "content": "// RUN: not oneflow-opt %s \\\n// RUN: -split-input-file \\\n// RUN: -verify-diagnostics -o -  2>&1 | FileCheck  --check-prefix=CHECK_ERROR_1  %s\n\n// CHECK_ERROR_1: unexpected error: failed to parse a sbp attribute here\nmodule {\n  oneflow.job @test_err(){\n    %output_0 = \"oneflow.variable\"() {data_type = 2 : i32, device_name = [\"@0:0\", \"@1:1\"], device_tag = \"cuda\", hierarchy = [2, 1], parallel = #sbp.parallel<[] -> [[[]], \"S(0)\", #sbp.P]>, op_name = \"net-FreeEagerTensor-2\", output_lbns = [\"net-FreeEagerTensor-2/out\"], scope_symbol_id = 14 : i64, shape = [5 : si64, 8 : si64], trainable = false} : () -> tensor<5x8xf32>\n    oneflow.return\n  }\n}\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/psig/sbp_parse.mlir",
    "content": "// RUN: oneflow-opt %s \\\n// RUN: -split-input-file \\\n// RUN: -verify-diagnostics -o - | FileCheck %s\n\n// CHECK-LABEL: test_single\nmodule {\n  oneflow.job @test_single(){\n// CHECK: parallel = #sbp.parallel<[] -> [#sbp.B, #sbp.S<0>]>\n    %output = \"oneflow.variable\"() {data_type = 2 : i32, device_name = [\"@0:0\", \"@1:1\"], device_tag = \"cuda\", hierarchy = [2, 1], parallel = #sbp.parallel<[] -> [#sbp.B, #sbp.S<0>]>, op_name = \"net-FreeEagerTensor-1\", output_lbns = [\"net-FreeEagerTensor-1/out\"], scope_symbol_id = 14 : i64, shape = [4 : si64, 5 : si64], trainable = false} : () -> tensor<4x5xf32>\n// CHECK: parallel = #sbp.parallel<[] -> [#sbp.B, #sbp.P]>\n    %output_0 = \"oneflow.variable\"() {data_type = 2 : i32, device_name = [\"@0:0\", \"@1:1\"], device_tag = \"cuda\", hierarchy = [2, 1], parallel = #sbp.parallel<[] -> [#sbp.B, #sbp.P]>, op_name = \"net-FreeEagerTensor-2\", output_lbns = [\"net-FreeEagerTensor-2/out\"], scope_symbol_id = 14 : i64, shape = [5 : si64, 8 : si64], trainable = false} : () -> tensor<5x8xf32>\n    oneflow.return\n  }\n}\n\n// CHECK-LABEL: test_nd\nmodule {\n  oneflow.job @test_nd(){\n    // CHECK: #sbp.B, #sbp.S<0>\n    %output = \"oneflow.variable\"() {data_type = 2 : i32, device_name = [\"@0:0\", \"@1:1\"], device_tag = \"cuda\", hierarchy = [2, 1], parallel = #sbp.parallel<[] -> [[#sbp.B, #sbp.S<0>]]>, op_name = \"net-FreeEagerTensor-1\", output_lbns = [\"net-FreeEagerTensor-1/out\"], scope_symbol_id = 14 : i64, shape = [4 : si64, 5 : si64], trainable = false} : () -> tensor<4x5xf32>\n    // CHECK: [#sbp.B, #sbp.P]\n    %output_0 = \"oneflow.variable\"() {data_type = 2 : i32, device_name = [\"@0:0\", \"@1:1\"], device_tag = \"cuda\", hierarchy = [2, 1], parallel = #sbp.parallel<[] -> [[#sbp.B, #sbp.P]]>, op_name = \"net-FreeEagerTensor-2\", output_lbns = [\"net-FreeEagerTensor-2/out\"], scope_symbol_id = 14 : i64, shape = [5 : si64, 8 : si64], trainable = false} : () -> tensor<5x8xf32>\n    oneflow.return\n  }\n}\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/psig/test_2nd_basic_parse.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.distributed.launch --nproc_per_node 2 %s | FileCheck %s\n# CHECK: [#sbp.B, #sbp.S<0>]\n# CHECK: [#sbp.B, #sbp.S<0>]\n\nimport oneflow as flow\nimport unittest\nimport oneflow.unittest\nimport os\nfrom google.protobuf import text_format\n\n\ndef _test_nd_basic_parse(test_case):\n    class ModuleToRun(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            P0 = flow.placement(\"cpu\", ranks=[[0], [1]])\n            a0_sbp = (flow.sbp.broadcast, flow.sbp.split(0))\n            b0_sbp = (flow.sbp.broadcast, flow.sbp.split(0))\n\n            self.A0 = flow.randn(4, 5, placement=P0, sbp=a0_sbp)\n            self.B0 = flow.randn(5, 8, placement=P0, sbp=b0_sbp)\n\n        def forward(self):\n            return flow.matmul(self.A0, self.B0)\n\n    net = ModuleToRun()\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.net = net\n\n        def build(self):\n            return self.net()\n\n    graph_to_run = GraphToRun()\n    lazy_output = graph_to_run()\n\n    serialized_job = graph_to_run._forward_job_proto.SerializeToString()\n    mlir = flow._oneflow_internal.nn.graph.ConvertJobToIR(serialized_job)\n    print(mlir)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBasicParse(flow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n\n    def test_nd_basic_parse(test_case):\n        _test_nd_basic_parse(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/psig/test_basic_parse.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 %s | FileCheck %s\n# CHECK: [#sbp.B]\n# CHECK: [#sbp.S<0>]\n\nimport oneflow as flow\nimport unittest\nimport oneflow.unittest\nimport os\nfrom google.protobuf import text_format\n\n\ndef _test_1nd_basic_parse(test_case):\n    class ModuleToRun(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            P0 = flow.placement(\"cpu\", ranks=[0])\n            a0_sbp = flow.sbp.broadcast\n            b0_sbp = flow.sbp.split(0)\n            self.A0 = flow.randn(4, 5, placement=P0, sbp=a0_sbp)\n            self.B0 = flow.randn(5, 8, placement=P0, sbp=b0_sbp)\n\n        def forward(self):\n            return flow.matmul(self.A0, self.B0)\n\n    net = ModuleToRun()\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.net = net\n\n        def build(self):\n            return self.net()\n\n    graph_to_run = GraphToRun()\n    lazy_output = graph_to_run()\n\n    serialized_job = graph_to_run._forward_job_proto.SerializeToString()\n    mlir = flow._oneflow_internal.nn.graph.ConvertJobToIR(serialized_job)\n    print(mlir)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBasicParse(flow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n\n    def test_1nd_basic_parse(test_case):\n        _test_1nd_basic_parse(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/traits.mlir",
    "content": "// RUN: oneflow-opt -test-oneflow-trait-folder %s | FileCheck %s\n\n// CHECK-LABEL: func.func @testSingleIdempotent\n// CHECK-SAME:  ([[ARG0:%.+]]: tensor<f32>)\nfunc.func @testSingleIdempotent(%arg0 : tensor<f32>) -> tensor<f32> {\n  // CHECK: [[IDEMPOTENT:%.+]] = \"oneflow.relu\"([[ARG0]])\n  %0 = \"oneflow.relu\"(%arg0) {device_tag = \"cuda\", op_name = \"Relu_1\", op_type_name = \"relu\", device_name = [\"0:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  // CHECK: return [[IDEMPOTENT]]\n  return %0: tensor<f32>\n}\n\n// CHECK-LABEL: func.func @testDoubleIdempotent\n// CHECK-SAME:  ([[ARG0:%.+]]: tensor<f32>)\nfunc.func @testDoubleIdempotent(%arg0: tensor<f32>) -> tensor<f32> {\n  // CHECK: [[IDEMPOTENT:%.+]] = \"oneflow.relu\"([[ARG0]])\n  %0 = \"oneflow.relu\"(%arg0) {device_tag = \"cuda\", op_name = \"Relu_1\", op_type_name = \"relu\", device_name = [\"0:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  %1 = \"oneflow.relu\"(%0) {device_tag = \"cuda\", op_name = \"Relu_2\", op_type_name = \"relu\", device_name = [\"0:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  // CHECK: return [[IDEMPOTENT]]\n  return %1: tensor<f32>\n}\n\n// CHECK-LABEL: func.func @testTripleIdempotent\n// CHECK-SAME:  ([[ARG0:%.+]]: tensor<f32>)\nfunc.func @testTripleIdempotent(%arg0: tensor<f32>) -> tensor<f32> {\n  // CHECK: [[IDEMPOTENT:%.+]] = \"oneflow.relu\"([[ARG0]])\n  %0 = \"oneflow.relu\"(%arg0) {device_tag = \"cuda\", op_name = \"Relu_1\", op_type_name = \"relu\", device_name = [\"0:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  %1 = \"oneflow.relu\"(%0) {device_tag = \"cuda\", op_name = \"Relu_2\", op_type_name = \"relu\", device_name = [\"0:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  %2 = \"oneflow.relu\"(%1) {device_tag = \"cuda\", op_name = \"Relu_3\", op_type_name = \"relu\", device_name = [\"0:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  // CHECK: return [[IDEMPOTENT]]\n  return %2: tensor<f32>\n}\n\n// CHECK-LABEL: func.func @testDoubleInvolution\n// CHECK-SAME:  ([[ARG0:%.+]]: tensor<f32>)\nfunc.func @testDoubleInvolution(%arg0: tensor<f32>) -> tensor<f32> {\n  %0 = \"oneflow.negative\"(%arg0) {device_tag = \"cuda\", op_name = \"Relu_1\", op_type_name = \"relu\", device_name = [\"0:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  %1 = \"oneflow.negative\"(%0) {device_tag = \"cuda\", op_name = \"Relu_2\", op_type_name = \"relu\", device_name = [\"0:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  // CHECK: return [[ARG0]]\n  return %1: tensor<f32>\n}\n\n// CHECK-LABEL: func.func @testTripleInvolution\n// CHECK-SAME:  ([[ARG0:%.+]]: tensor<f32>)\nfunc.func @testTripleInvolution(%arg0: tensor<f32>) -> tensor<f32> {\n  // CHECK: [[INVOLUTION:%.+]] = \"oneflow.negative\"([[ARG0]])\n  %0 = \"oneflow.negative\"(%arg0) {device_tag = \"cuda\", op_name = \"Relu_1\", op_type_name = \"relu\", device_name = [\"0:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  %1 = \"oneflow.negative\"(%0) {device_tag = \"cuda\", op_name = \"Relu_2\", op_type_name = \"relu\", device_name = [\"0:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  %2 = \"oneflow.negative\"(%1) {device_tag = \"cuda\", op_name = \"Relu_3\", 
op_type_name = \"relu\", device_name = [\"0:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  // CHECK: return [[INVOLUTION]]\n  return %2: tensor<f32>\n}\n\n// CHECK-LABEL: func.func @testFailedInvolutionFoldDueToDifferentPlacement\n// CHECK-SAME:  ([[ARG0:%.+]]: tensor<f32>)\nfunc.func @testFailedInvolutionFoldDueToDifferentPlacement(%arg0: tensor<f32>) -> tensor<f32> {\n  %0 = \"oneflow.negative\"(%arg0) {device_tag = \"cuda\", op_name = \"Relu_1\", op_type_name = \"relu\", device_name = [\"0:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  %1 = \"oneflow.negative\"(%0) {device_tag = \"cuda\", op_name = \"Relu_2\", op_type_name = \"relu\", device_name = [\"1:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  // CHECK: [[INVOLUTION:%.+]] = \"oneflow.negative\"(%1)\n  %2 = \"oneflow.negative\"(%1) {device_tag = \"cuda\", op_name = \"Relu_3\", op_type_name = \"relu\", device_name = [\"0:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  // CHECK: return [[INVOLUTION]]\n  return %2: tensor<f32>\n}\n\n// CHECK-LABEL: func.func @testFailedInvolutionFoldDueToDifferentDevice\n// CHECK-SAME:  ([[ARG0:%.+]]: tensor<f32>)\nfunc.func @testFailedInvolutionFoldDueToDifferentDevice(%arg0: tensor<f32>) -> tensor<f32> {\n  %0 = \"oneflow.negative\"(%arg0) {device_tag = \"cuda\", op_name = \"Relu_1\", op_type_name = \"relu\", device_name = [\"0:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  %1 = \"oneflow.negative\"(%0) {device_tag = \"cpu\", op_name = \"Relu_2\", op_type_name = \"relu\", device_name = [\"0:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  // CHECK: [[INVOLUTION:%.+]] = \"oneflow.negative\"(%1)\n  %2 = \"oneflow.negative\"(%1) {device_tag = \"cuda\", op_name = \"Relu_3\", op_type_name = \"relu\", device_name = [\"0:0-0\"], scope_symbol_id = 4611686018427420670 : i64} : (tensor<f32>) -> tensor<f32>\n  // CHECK: return [[INVOLUTION]]\n  return %2: tensor<f32>\n}\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/with_cuda/lit.local.cfg",
    "content": "if not config.BUILD_CUDA:\n  config.unsupported = True\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/with_cuda/test_conv_bn_auto_nhwc.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK: oneflow.transpose\n\nimport os\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.nn as nn\nfrom flowvision.models.resnet import resnet50\n\n\ndef _test_fuse_conv_bn(test_case, with_cuda):\n    data = flow.randn(1, 3, 224, 224)\n    if with_cuda:\n        data = data.to(\"cuda\")\n\n    model = resnet50(pretrained=False, progress=True)\n    if with_cuda:\n        model.to(\"cuda\")\n    model.eval()\n    eager_res = model(data)\n\n    class Resnet50Graph(nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.model = model\n\n        def build(self, *input):\n            return self.model(*input)\n\n    graph = Resnet50Graph()\n    lazy_res = graph(data)\n\n    test_case.assertTrue(\n        np.allclose(eager_res.numpy(), lazy_res.numpy(), rtol=1e-2, atol=1e-2)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFuseConvBn(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_PREFER_NHWC\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_FUSE_FORWARD_OPS\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_FUSE_NORMALIZATION_OPS\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_PRINT_STATS\"] = \"1\"\n\n    @unittest.skipUnless(oneflow.sysconfig.with_cuda(), \"only test cpu cases\")\n    def test_fuse_conv_bn_cuda(test_case):\n        _test_fuse_conv_bn(test_case, True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/with_cuda/test_fuse_bias_add_dropout.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK-NOT: oneflow.bias_add\n\nimport unittest\nimport numpy as np\n\nimport os\n\n\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.sysconfig\n\n\ndef do_bias_add_dropout_graph(test_case, with_cuda, prob):\n    x = flow.randn(2, 3, 4, 5)\n    bias = flow.randn(5)\n    dropout = flow.nn.Dropout(p=prob)\n    if with_cuda:\n        x = x.cuda()\n        bias = bias.to(\"cuda\")\n        dropout.to(\"cuda\")\n\n    eager_res = dropout(flow._C.bias_add(x, bias, axis=3))\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.dropout = dropout\n\n        def build(self, x, bias):\n            return self.dropout(flow._C.bias_add(x, bias, axis=3))\n\n    graph_to_run = GraphToRun()\n    lazy_res = graph_to_run(x, bias)\n    if prob == 1.0:\n        test_case.assertTrue(np.array_equal(eager_res.numpy(), lazy_res.numpy()))\n    else:\n        test_case.assertTrue(lazy_res.sum().item() != 0.0)\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipUnless(oneflow.sysconfig.with_cuda(), \"needs -DBUILD_CUDA=ON\")\nclass TestBiasAddDropout(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_FUSE_FORWARD_OPS\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_STDOUT\"] = \"1\"\n\n    def test_bias_add_dropout_graph(test_case):\n        do_bias_add_dropout_graph(test_case, True, 1.0)\n        do_bias_add_dropout_graph(test_case, True, 0.5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/with_cuda/test_fuse_bias_add_gelu.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK-NOT: oneflow.bias_add\n\nimport unittest\nimport numpy as np\n\nimport os\n\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.sysconfig\n\n\ndef do_bias_add_gelu_graph(test_case, with_cuda):\n    x = flow.randn(2, 3, 4, 5)\n    bias = flow.randn(5)\n    gelu = flow.nn.GELU()\n    if with_cuda:\n        x = x.cuda()\n        bias = bias.to(\"cuda\")\n        gelu.to(\"cuda\")\n\n    eager_res = gelu(flow._C.bias_add(x, bias, axis=3))\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.gelu = gelu\n\n        def build(self, x, bias):\n            return self.gelu(flow._C.bias_add(x, bias, axis=3))\n\n    graph_to_run = GraphToRun()\n    lazy_res = graph_to_run(x, bias)\n    test_case.assertTrue(np.array_equal(eager_res.numpy(), lazy_res.numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipUnless(oneflow.sysconfig.with_cuda(), \"needs -DBUILD_CUDA=ON\")\nclass TestBiasAddGelu(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_FUSE_FORWARD_OPS\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_STDOUT\"] = \"1\"\n\n    def test_bias_add_gelu_graph(test_case):\n        do_bias_add_gelu_graph(test_case, True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/with_cuda/test_fuse_bn_add_relu.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK-NOT: \"oneflow.normalization\"\n\nimport unittest\nimport numpy as np\n\nimport os\n\n\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.sysconfig\n\n\ndef do_normalization_add_relu_graph(test_case, with_cuda):\n    def get_bn(fused=True):\n        if fused:\n            return flow.nn.FusedBatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(\n                \"cuda\"\n            )\n        else:\n            return flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(\n                \"cuda\"\n            )\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = get_bn()\n\n        def build(self, x, addend):\n            return self.m(x, addend=addend)\n\n    class GraphToRunWithOpt(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = get_bn(fused=False)\n\n        def build(self, x, addend):\n            return flow.relu(self.m(x) + addend)\n\n    graph_to_run = GraphToRun()\n    graph_to_run_opt = GraphToRunWithOpt()\n    x = flow.Tensor(np.random.randn(4, 2, 8, 3)).to(\"cuda\")\n    addend = flow.Tensor(np.random.randn(4, 2, 8, 3)).to(\"cuda\")\n\n    eager_res = flow.relu(get_bn(fused=False)(x) + addend)\n    eager_res_fuse = get_bn()(x, addend=addend)\n    lazy_res = graph_to_run(x, addend)\n    lazy_res_opt = graph_to_run_opt(x, addend)\n    test_case.assertTrue(np.array_equal(eager_res.numpy(), eager_res_fuse.numpy()))\n    test_case.assertTrue(np.array_equal(eager_res.numpy(), lazy_res.numpy()))\n    test_case.assertTrue(np.array_equal(eager_res.numpy(), lazy_res_opt.numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipUnless(oneflow.sysconfig.with_cuda(), \"needs -DBUILD_CUDA=ON\")\nclass TestNormalizationAddRelu(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_FUSE_NORMALIZATION_OPS\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_PRINT_STATS\"] = \"1\"\n\n    def test_normalization_add_relu_graph(test_case):\n        do_normalization_add_relu_graph(test_case, True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/with_cuda/test_fuse_gelu.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK-NOT: oneflow.broadcast_matmul\n# CHECK-NOT: oneflow.fused_matmul_bias\n# CHECK-NOT: oneflow.narrow\n# CHECK: \"oneflow.fused_glu\"\n\nimport unittest\nimport numpy as np\n\nimport os\n\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.nn.functional as F\nimport oneflow.unittest\nimport oneflow.sysconfig\n\n\nclass GEGLU(nn.Module):\n    r\"\"\"\n    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.\n\n    Parameters:\n        dim_in (`int`): The number of channels in the input.\n        dim_out (`int`): The number of channels in the output.\n    \"\"\"\n\n    def __init__(\n        self, dim_in: int, dim_out: int,\n    ):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out * 2)\n\n    def forward(self, hidden_states):\n        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)\n        return hidden_states * F.gelu(gate)\n\n\nclass GraphToRun(flow.nn.Graph):\n    def __init__(self, gelu_mod):\n        super().__init__()\n        self.gelu_mod = gelu_mod\n\n    def build(self, hidden_states):\n        return self.gelu_mod(hidden_states)\n\n\ndef do_fused_gelu_graph(test_case, dev, fuse_linear=False):\n    if fuse_linear:\n        os.environ[\"ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR\"] = \"1\"\n    else:\n        os.environ[\"ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR\"] = \"0\"\n    gelu_mod = GEGLU(640, 5120).to(dev)\n    hidden_states = flow.randn(2, 2304, 640).to(dev)\n    eager_res = gelu_mod(hidden_states)\n    graph_to_run = GraphToRun(gelu_mod)\n    lazy_res = graph_to_run(hidden_states)\n    test_case.assertTrue(np.allclose(eager_res.numpy(), lazy_res.numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipUnless(oneflow.sysconfig.with_cuda(), \"needs -DBUILD_CUDA=ON\")\nclass TestFusedGelu(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_STDOUT\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL\"] = \"1\"\n\n    def test_fused_gelu_graph(test_case):\n        do_fused_gelu_graph(test_case, \"cuda\", fuse_linear=True)\n        do_fused_gelu_graph(test_case, \"cuda\", fuse_linear=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/with_cuda/test_fuse_scale_tril.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK-NOT: oneflow.tril\n\nimport os\nimport unittest\nimport numpy as np\n\n\nimport oneflow as flow\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_fused_scale_tril(\n    test_case, shape, diagonal=0, scale=1.0,\n):\n    x = np.random.rand(*shape)\n    # Different dtype will result in insert of cast op causing pass to fail.\n    tensor_x = flow.tensor(x, device=\"cuda\", dtype=flow.float32)\n    eager_out = flow.tril(tensor_x, diagonal) * scale\n\n    class TestFuseScaleTril(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self):\n            return flow.tril(tensor_x * scale, diagonal)\n\n    lazy_out_0 = TestFuseScaleTril()()\n    test_case.assertTrue(np.allclose(eager_out.numpy(), lazy_out_0.numpy()))\n\n    class TestFuseTrilScale(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self):\n            return flow.tril(tensor_x, diagonal) * scale\n\n    lazy_out_1 = TestFuseTrilScale()()\n    test_case.assertTrue(np.allclose(eager_out.numpy(), lazy_out_1.numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass FusedScaleTrilTestCase(flow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_FUSE_FORWARD_OPS\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_STDOUT\"] = \"1\"\n\n    def test_fused_scale_tril(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(5, 5), (4, 6)]\n        arg_dict[\"diagonal\"] = [-1, 0]\n        arg_dict[\"scale\"] = [-2.3, 2.0]\n        for kwargs in GenArgDict(arg_dict):\n            _test_fused_scale_tril(test_case, **kwargs)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/with_cuda/test_fused_matmul_bias.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK-NOT: oneflow.bias_add\n# CHECK: %[[OUT0:[a-zA-Z0-9_]+]]:5 = \"oneflow.grouped_matmul_bias\"\n\nimport unittest\nimport numpy as np\nimport os\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.sysconfig\n\n\ndef _matmul_bias0(x, weight, bias):\n    return flow._C.bias_add(\n        flow._C.matmul(x, weight, transpose_b=True), bias, axis=len(x.shape) - 1\n    )\n\n\ndef _matmul_bias1(x, w, bias):\n    return flow._C.fused_matmul_bias(x, w, bias)\n\n\ndef do_fused_matmul_bias_graph(test_case, dev):\n    x = np.random.uniform(low=-1, high=1, size=(8, 9))\n    w = np.random.uniform(low=-1, high=1, size=(10, 9))\n    bias = np.random.uniform(low=-1, high=1, size=(10))\n    x = flow.from_numpy(x).to(dev).to(flow.float32)\n    w = flow.from_numpy(w).to(dev).to(flow.float32)\n    bias = flow.from_numpy(bias).to(dev).to(flow.float32)\n    eager_res = _matmul_bias0(x, w, bias) * 5\n\n    class GraphToRun(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, x, w, bias):\n            return (\n                _matmul_bias0(x, w, bias)\n                + _matmul_bias1(x, w, bias)\n                + _matmul_bias0(x, w, bias)\n                + _matmul_bias1(x, w, bias)\n                + _matmul_bias0(x, w, bias)\n            )\n\n    graph_to_run = GraphToRun()\n    lazy_res = graph_to_run(x, w, bias)\n    test_case.assertTrue(np.allclose(eager_res.numpy(), lazy_res.numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipUnless(oneflow.sysconfig.with_cuda(), \"needs -DBUILD_CUDA=ON\")\nclass TestGroupMatMulBias(oneflow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_GROUP_MATMUL\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_STDOUT\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_CSE\"] = \"0\"\n\n    def test_fused_matmul_bias_graph(test_case):\n        do_fused_matmul_bias_graph(test_case, \"cuda\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/with_cuda/test_fused_multi_head_attention_inference.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s\n# CHECK-NOT: oneflow.softmax\n# CHECK-NOT: oneflow.batch_matmul\n\n\nimport unittest\nimport numpy as np\n\nimport math\nimport os\n\n\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.sysconfig\n\n\ndef _ref(query, key, value, num_heads, causal=False):\n    query = query.view(query.shape[0], query.shape[1], num_heads, -1).permute(\n        0, 2, 1, 3\n    )\n    key = key.view(key.shape[0], key.shape[1], num_heads, -1).permute(0, 2, 3, 1)\n    value = value.view(value.shape[0], value.shape[1], num_heads, -1).permute(\n        0, 2, 1, 3\n    )\n    scores = flow.matmul(query, key) / math.sqrt(query.shape[-1])\n    if causal:\n        causal_mask = flow.triu(\n            flow.ones(\n                scores.shape[-2], scores.shape[-1], dtype=flow.bool, device=\"cuda\"\n            ),\n            1,\n        )\n        scores = flow.masked_fill(scores, causal_mask, float(\"-inf\"))\n    attn = flow.softmax(scores, dim=-1)\n    out = flow.matmul(attn, value)\n    out = out.permute(0, 2, 1, 3)\n    out = out.reshape(out.shape[0], out.shape[1], -1)\n    return out\n\n\ndef _ref2(query, key, value, num_heads, causal=False):\n    query = query.view(query.shape[0], query.shape[1], num_heads, -1).permute(\n        0, 2, 1, 3\n    )\n    key = key.view(key.shape[0], key.shape[1], num_heads, -1).permute(0, 2, 1, 3)\n    value = value.view(value.shape[0], value.shape[1], num_heads, -1).permute(\n        0, 2, 1, 3\n    )\n    query = query.reshape(-1, query.shape[2], query.shape[3])\n    key = key.reshape(-1, key.shape[2], key.shape[3]).permute(0, 2, 1)\n    value = value.reshape(-1, value.shape[2], value.shape[3])\n\n    scale = 1 / math.sqrt(query.shape[-1])\n\n    scores = flow.baddbmm(\n        flow.empty(\n            query.shape[0],\n            query.shape[1],\n            key.shape[1],\n            dtype=query.dtype,\n            device=query.device,\n        ),\n        query,\n        key,\n        beta=0,\n        alpha=scale,\n    )\n\n    if causal:\n        causal_mask = flow.triu(\n            flow.ones(\n                scores.shape[-2], scores.shape[-1], dtype=flow.bool, device=\"cuda\"\n            ),\n            1,\n        )\n        scores = flow.masked_fill(scores, causal_mask, float(\"-inf\"))\n    attn = flow.softmax(scores, dim=-1)\n    out = flow.matmul(attn, value)\n    out = out.reshape(-1, num_heads, out.shape[1], out.shape[2])\n    out = out.permute(0, 2, 1, 3)\n    out = out.reshape(out.shape[0], out.shape[1], -1)\n\n    return out\n\n\ndef _fused_mha(query, key, value, num_heads, causal=False):\n    return flow._C.fused_multi_head_attention_inference(\n        query, key, value, num_heads, causal=causal\n    )\n\n\nclass GraphToRun(flow.nn.Graph):\n    def __init__(self, ref=None, num_heads=None, causal=False):\n        super().__init__()\n 
       self.ref = ref\n        self.causal = causal\n        self.num_heads = num_heads\n\n    def build(self, query, key, value):\n        return self.ref(query, key, value, self.num_heads, self.causal)\n\n\ndef _test_fused_multi_head_attention_inference(\n    test_case,\n    batch_size,\n    num_heads,\n    query_seq_len,\n    kv_seq_len,\n    query_head_size,\n    value_head_size,\n    dtype,\n    graph_builder,\n    ref,\n    causal=False,\n):\n\n    query = flow.randn(\n        (batch_size, query_seq_len, num_heads * query_head_size),\n        device=\"cuda\",\n        dtype=flow.float,\n    ).to(dtype)\n    key = flow.randn(\n        (batch_size, kv_seq_len, num_heads * query_head_size),\n        device=\"cuda\",\n        dtype=flow.float,\n    ).to(dtype)\n    value = flow.randn(\n        (batch_size, kv_seq_len, num_heads * value_head_size),\n        device=\"cuda\",\n        dtype=flow.float,\n    ).to(dtype)\n\n    g = graph_builder(ref=ref, num_heads=num_heads, causal=causal)\n    ref_out = ref(query, key, value, num_heads, causal).numpy()\n    fused_out = _fused_mha(query, key, value, num_heads, causal).numpy()\n    g_out = g(query, key, value).numpy()\n    test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2))\n    test_case.assertTrue(np.allclose(ref_out, g_out, atol=1e-2, rtol=1e-2))\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipUnless(oneflow.sysconfig.with_cuda(), \"needs -DBUILD_CUDA=ON\")\n# TODO: skip for GTX1080 in CI\n@unittest.skipUnless(\n    flow.cuda.get_device_capability()[0] >= 7, \"needs CUDA compatibility >= 7\"\n)\nclass TestFusedMultiHeadAttentionInference(flow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_FUSE_FORWARD_OPS\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_STDOUT\"] = \"1\"\n        os.environ[\"ONEFLOW_MLIR_CSE\"] = \"0\"\n\n    def test_multi_head_attention_inference(test_case):\n        # test_case,batch_size, num_heads,query_seq_len, kv_seq_len,query_head_size,value_head_size,dtype\n        for ref in [_ref, _ref2]:\n            _test_fused_multi_head_attention_inference(\n                test_case, 2, 8, 4096, 4096, 40, 40, flow.float16, GraphToRun, ref\n            )\n            _test_fused_multi_head_attention_inference(\n                test_case, 2, 8, 4096, 77, 40, 40, flow.float16, GraphToRun, ref\n            )\n            _test_fused_multi_head_attention_inference(\n                test_case, 2, 8, 1024, 1024, 80, 80, flow.float16, GraphToRun, ref\n            )\n            _test_fused_multi_head_attention_inference(\n                test_case, 2, 8, 1024, 77, 80, 80, flow.float16, GraphToRun, ref\n            )\n            _test_fused_multi_head_attention_inference(\n                test_case, 2, 8, 256, 256, 160, 160, flow.float16, GraphToRun, ref\n            )\n            _test_fused_multi_head_attention_inference(\n                test_case, 2, 8, 256, 77, 160, 160, flow.float16, GraphToRun, ref\n            )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/OneFlow/with_cuda/test_graph_save_and_load.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s\n\nimport os\nimport sys\n\nsys.path.append(os.path.abspath(os.path.dirname(__file__)))\nsys.path.append(os.path.abspath(os.path.dirname(__file__)) + \"/..\")\n\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.core.job import job_pb2 as job_pb\n\nfrom networks.resnet50 import resnet50\n\n\nclass InferGraph(flow.nn.Graph):\n    def __init__(self, placement_arg=None):\n        super().__init__()\n        model = resnet50()\n        if placement_arg is not None:\n            if \"placement\" in placement_arg:\n                model.to_global(**placement_arg)\n            else:\n                model.to(**placement_arg)\n        self.model = model\n\n    def build(self, image):\n        logits = self.model(image.to(\"cuda\"))\n        pred = logits.softmax()\n        return pred\n\n\n@unittest.skipIf(not flow.sysconfig.with_mlir(), \"only test with mlir\")\n@flow.unittest.skip_unless_1n1d()\nclass GraphSaveTestCase(flow.unittest.MLIRTestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n\n    def test_save_and_load(self):\n        placement_arg = {\n            \"placement\": flow.placement(\"cuda\", ranks=[0]),\n            \"sbp\": flow.sbp.broadcast,\n        }\n        graph = InferGraph(placement_arg)\n        image_placeholder = flow.empty(\n            (1, 3, 224, 224),\n            dtype=flow.float32,\n            placement=flow.placement(\"cpu\", ranks=[0]),\n            sbp=flow.sbp.broadcast,\n        )\n        graph._compile(image_placeholder)\n        saved_path = os.path.join(\"saved_model\", graph.name)\n        if not os.path.exists(saved_path):\n            os.makedirs(saved_path)\n        flow.save(graph, saved_path)\n\n        saved_ir_path = os.path.join(saved_path, \"model.mlir\")\n        serialized_job = oneflow._oneflow_internal.nn.graph.LoadSerializedJobFromIR(\n            saved_ir_path\n        )\n        job = job_pb.Job()\n        job.ParseFromString(serialized_job)\n\n        # TODO: run loaded job as graph and original graph, compare the result\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "oneflow/ir/test/Transform/lit.local.cfg",
    "content": "if not config.WITH_MLIR_CUDA_CODEGEN:\n  config.unsupported = True\n"
  },
  {
    "path": "oneflow/ir/test/Transform/matmul.mlir",
    "content": "// RUN: oneflow-opt %s  --insert-ofmempool  --convert-linalg-to-loops --convert-scf-to-cf --canonicalize --cse --memref-expand  --gpu-kernel-outlining \\\n// RUN: | oneflow-opt --pass-pipeline='builtin.module(gpu.module(expand-strided-metadata,lower-affine,strip-debuginfo,convert-gpu-to-nvvm,nvvm-to-cubin))'\n\nmodule {\n  func.func @JITOpGenerated0(%arg0: memref<5x10xf32, strided<[?, ?], offset: ?>>, %arg1: memref<2x5xf32, strided<[?, ?], offset: ?>>, %arg2: memref<2x10xf32>) attributes {llvm.emit_c_interface} {\n    %alloc = memref.alloc() : memref<512xi8>\n    %c0 = arith.constant 0 : index\n    %view = memref.view %alloc[%c0][] : memref<512xi8> to memref<1x2x10xf32>\n    %c10 = arith.constant 10 : index\n    %c2 = arith.constant 2 : index\n    %c1 = arith.constant 1 : index\n    %c0_0 = arith.constant 0 : index\n    %c5 = arith.constant 5 : index\n    %cst = arith.constant 0.000000e+00 : f32\n    %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] : memref<5x10xf32, strided<[?, ?], offset: ?>> into memref<1x5x10xf32, strided<[?, ?, ?], offset: ?>>\n    %expand_shape_1 = memref.expand_shape %arg1 [[0, 1], [2]] : memref<2x5xf32, strided<[?, ?], offset: ?>> into memref<1x2x5xf32, strided<[?, ?, ?], offset: ?>>\n    gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c2, %arg11 = %c10) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {\n      memref.store %cst, %view[%c0_0, %arg4, %arg5] : memref<1x2x10xf32>\n      gpu.terminator\n    } {SCFToGPU_visited}\n    gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c2, %arg11 = %c10) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {\n      scf.for %arg15 = %c0_0 to %c5 step %c1 {\n        %0 = memref.load %expand_shape_1[%c0_0, %arg4, %arg15] : memref<1x2x5xf32, strided<[?, ?, ?], offset: ?>>\n        %1 = memref.load %expand_shape[%c0_0, %arg15, %arg5] : memref<1x5x10xf32, strided<[?, ?, ?], offset: ?>>\n        %2 = memref.load %view[%c0_0, %arg4, %arg5] : memref<1x2x10xf32>\n        %3 = arith.mulf %0, %1 : f32\n        %4 = arith.addf %2, %3 : f32\n        memref.store %4, %view[%c0_0, %arg4, %arg5] : memref<1x2x10xf32>\n      }\n      gpu.terminator\n    } {SCFToGPU_visited}\n    %collapse_shape = memref.collapse_shape %view [[0, 1], [2]] : memref<1x2x10xf32> into memref<2x10xf32>\n    memref.copy %collapse_shape, %arg2 : memref<2x10xf32> to memref<2x10xf32>\n    return\n  }\n}"
  },
  {
    "path": "oneflow/ir/test/Transform/softmax.mlir",
    "content": "// RUN: oneflow-opt %s --pass-pipeline=\"builtin.module(oneflow-transform-dialect-interpreter{transform-file-name=%p/softmax_codegen_spec_no_vectorize.mlir})\" \\\n// RUN: | oneflow-opt  --insert-ofmempool  --convert-linalg-to-loops --convert-scf-to-cf --canonicalize --cse --memref-expand  --gpu-kernel-outlining \\\n// RUN: | oneflow-opt --pass-pipeline='builtin.module(gpu.module(expand-strided-metadata,lower-affine,strip-debuginfo,convert-gpu-to-nvvm,nvvm-to-cubin))'\n\n\n!tmp_tensor_t = tensor<16x128xf32>\n!in_tensor_t = tensor<16x128x128xf32>\n!out_tensor_t = tensor<16x128x128xf32>\n\nfunc.func @softmax() -> !out_tensor_t {\n  %cst_0 = arith.constant 0.0 : f32\n  %cst_1 = arith.constant 1.0 : f32\n  %cst_min = arith.constant -3.40282347E+38 : f32\n  %input = arith.constant dense<5.000000e+00> : !out_tensor_t\n\n  %input_max_empty = tensor.empty() : !tmp_tensor_t\n  %input_max_filled = linalg.fill ins(%cst_min : f32)\n    outs(%input_max_empty : !tmp_tensor_t) -> !tmp_tensor_t\n  %input_max = linalg.generic\n    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,\n                      affine_map<(d0, d1, d2) -> (d0, d1)>],\n                      iterator_types = [\"parallel\", \"parallel\", \"reduction\"]}\n     ins(%input : !in_tensor_t)\n    outs(%input_max_filled : !tmp_tensor_t) {\n      ^bb0(%arg0: f32, %arg1: f32):\n        %max = arith.maxf %arg0, %arg1 : f32\n        linalg.yield %max : f32\n      } -> !tmp_tensor_t\n\n  // This has been fused manually to avoid the fusion on tensors pass and reduce noise atm.\n  %exps_empty = tensor.empty() : !out_tensor_t\n  %exps_sum_empty = tensor.empty() : !tmp_tensor_t\n  %exps_sum_filled = linalg.fill ins(%cst_0 : f32)\n    outs(%exps_sum_empty : !tmp_tensor_t) -> !tmp_tensor_t\n  %exps, %exps_sum = linalg.generic\n    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,\n                      affine_map<(d0, d1, d2) -> (d0, d1)>,\n                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,\n                      affine_map<(d0, d1, d2) -> (d0, d1)>],\n                      iterator_types = [\"parallel\", \"parallel\", \"reduction\"]}\n     ins(%input, %input_max : !in_tensor_t, !tmp_tensor_t)\n    outs(%exps_empty, %exps_sum_filled : !out_tensor_t, !tmp_tensor_t) {\n      ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32):\n        %sub = arith.subf %arg0, %arg1 : f32\n        %exp = math.exp %sub : f32\n        %add = arith.addf %exp, %arg3 : f32\n        linalg.yield %exp, %add : f32, f32\n      } -> (!out_tensor_t, !tmp_tensor_t)\n\n  %res_empty = tensor.empty() : !out_tensor_t\n  %res = linalg.generic\n    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,\n                      affine_map<(d0, d1, d2) -> (d0, d1)>,\n                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],\n                      iterator_types = [\"parallel\", \"parallel\", \"parallel\"]}\n     ins(%exps, %exps_sum : !out_tensor_t, !tmp_tensor_t)\n    outs(%res_empty : !out_tensor_t) {\n      ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):\n        // %10 = arith.divf %cst_1, %arg1 : f32\n        // %11 = arith.mulf %arg0, %10 : f32\n        %div = arith.divf %arg0, %arg1 : f32\n        linalg.yield %div : f32\n      } -> !out_tensor_t\n\n  return %res: !out_tensor_t\n}\n"
  },
  {
    "path": "oneflow/ir/test/Transform/softmax_codegen_spec.mlir",
    "content": "// RUN: oneflow-opt %s\n\ntransform.sequence failures(propagate) {\n^bb1(%module_op: !pdl.operation):\n  // Note: step 1, tiling and fusing linalg ops in block level.\n  %ops = transform.structured.match ops{[\"linalg.fill\", \"linalg.generic\"]}\n    in %module_op : (!pdl.operation) -> !pdl.operation\n\n  %match_0, %match_1, %match_2, %match_3, %match_end = transform.split_handle %ops\n    : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,\n                           !pdl.operation, !pdl.operation)\n\n  %forall, %_ =\n    transform.structured.tile_to_forall_op %match_end tile_sizes [1, 4]\n      ( mapping = [#gpu.block<x>, #gpu.block<y>] )\n\n  transform.structured.fuse_into_containing_op %match_3 into %forall\n  transform.structured.fuse_into_containing_op %match_2 into %forall\n  transform.structured.fuse_into_containing_op %match_1 into %forall\n  transform.structured.fuse_into_containing_op %match_0 into %forall\n\n  transform.oneflow.canonicalization %module_op : (!pdl.operation) -> ()\n  transform.oneflow.cse %module_op : (!pdl.operation) -> ()\n\n\n  // Note: step 2, tiling and fusing linalg ops in thread level.\n  %ops_1 = transform.structured.match ops{[\"linalg.fill\", \"linalg.generic\"]}\n    in %module_op : (!pdl.operation) -> !pdl.operation\n  %match_0_0,\n  %match_0_1,\n  %match_0_2,\n  %match_0_3,\n  %match_0_end = transform.split_handle %ops_1\n    : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,\n                           !pdl.operation, !pdl.operation)\n\n  %reduction_linalg_ops = transform.merge_handles %match_0_1,\n                                                  %match_0_3\n    : !pdl.operation\n  transform.structured.tile_to_forall_op %reduction_linalg_ops tile_sizes [1, 1]\n    ( mapping = [#gpu.thread<z>, #gpu.thread<y>] )\n\n  %parallel_linalg_ops = transform.merge_handles %match_0_0,\n                                                 %match_0_2,\n                                                 %match_0_end\n    : !pdl.operation\n  transform.structured.tile_to_forall_op %parallel_linalg_ops num_threads [1, 4, 32]\n    ( mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>] )\n\n  // Note: step 3,vectorize \n  transform.oneflow.canonicalization %module_op : (!pdl.operation) -> ()\n  transform.oneflow.cse %module_op : (!pdl.operation) -> ()\n  %to_vectorize = transform.structured.match ops{[\"func.func\"]} in %module_op : (!pdl.operation) -> !pdl.operation\n  transform.structured.vectorize %to_vectorize\n\n  // Note: step 4, bufferize\n  transform.oneflow.explicit_linalg_outcome %module_op : (!pdl.operation) -> ()\n\n  transform.bufferization.eliminate_empty_tensors %module_op\n\n  %empty = transform.structured.match ops{[\"tensor.empty\"]} in %module_op : (!pdl.operation) -> !pdl.operation\n  %empty_id = transform.cast %empty : !pdl.operation to !transform.op<\"tensor.empty\">\n  transform.bufferization.empty_tensor_to_alloc_tensor %empty_id : (!transform.op<\"tensor.empty\">) -> !transform.op<\"bufferization.alloc_tensor\">\n\n  %bufferized_module_op = transform.bufferization.one_shot_bufferize %module_op\n      {create_deallocs = false, bufferize_function_boundaries = true,  allow_return_allocs = true} : (!pdl.operation) -> !pdl.operation\n      \n  // Note: step 5, post bufferize function-type-related transform\n  transform.oneflow.canonicalization %bufferized_module_op : (!pdl.operation) -> ()\n  transform.oneflow.cse %bufferized_module_op : (!pdl.operation) -> ()\n  
transform.oneflow.eliminate_copy %bufferized_module_op : (!pdl.operation) -> ()\n\n  %func = transform.structured.match ops{[\"func.func\"]} in %bufferized_module_op : (!pdl.operation) -> !pdl.operation\n  transform.structured.hoist_redundant_tensor_subsets %func\n    : (!pdl.operation) -> ()\n\n  // Note: step 6, post bufferize memory-buffer-pool transform\n  transform.oneflow.results_to_out_params %bufferized_module_op : (!pdl.operation) -> ()\n  transform.oneflow.eliminate_copy %bufferized_module_op : (!pdl.operation) -> ()\n  transform.oneflow.fold_alloc %func : (!pdl.operation) -> ()\n\n  // Note: step 7, mapping scf to gpu\n  %gpu_launch_op = transform.gpu.map_forall_to_blocks %bufferized_module_op { generate_gpu_launch }\n  transform.gpu.map_nested_forall_to_threads %gpu_launch_op block_dims = [32, 4, 1]\n}\n\n"
  },
  {
    "path": "oneflow/ir/test/Transform/softmax_codegen_spec_no_vectorize.mlir",
    "content": "// RUN: oneflow-opt %s\n\ntransform.sequence failures(propagate) {\n^bb1(%module_op: !pdl.operation):\n  // Note: step 1, tiling and fusing linalg ops in block level.\n  %ops = transform.structured.match ops{[\"linalg.fill\", \"linalg.generic\"]}\n    in %module_op : (!pdl.operation) -> !pdl.operation\n\n  %match_0, %match_1, %match_2, %match_3, %match_end = transform.split_handle %ops\n    : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,\n                           !pdl.operation, !pdl.operation)\n\n  %forall, %_ =\n    transform.structured.tile_to_forall_op %match_end tile_sizes [1, 4]\n      ( mapping = [#gpu.block<x>, #gpu.block<y>] )\n\n  transform.structured.fuse_into_containing_op %match_3 into %forall\n  transform.structured.fuse_into_containing_op %match_2 into %forall\n  transform.structured.fuse_into_containing_op %match_1 into %forall\n  transform.structured.fuse_into_containing_op %match_0 into %forall\n\n  transform.oneflow.canonicalization %module_op : (!pdl.operation) -> ()\n  transform.oneflow.cse %module_op : (!pdl.operation) -> ()\n\n\n  // Note: step 2, tiling and fusing linalg ops in thread level.\n  %ops_1 = transform.structured.match ops{[\"linalg.fill\", \"linalg.generic\"]}\n    in %module_op : (!pdl.operation) -> !pdl.operation\n  %match_0_0,\n  %match_0_1,\n  %match_0_2,\n  %match_0_3,\n  %match_0_end = transform.split_handle %ops_1\n    : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,\n                           !pdl.operation, !pdl.operation)\n\n  %reduction_linalg_ops = transform.merge_handles %match_0_1,\n                                                  %match_0_3\n    : !pdl.operation\n  transform.structured.tile_to_forall_op %reduction_linalg_ops tile_sizes [1, 1]\n    ( mapping = [#gpu.thread<z>, #gpu.thread<y>] )\n\n  %parallel_linalg_ops = transform.merge_handles %match_0_0,\n                                                 %match_0_2,\n                                                 %match_0_end\n    : !pdl.operation\n  transform.structured.tile_to_forall_op %parallel_linalg_ops num_threads [1, 4, 32]\n    ( mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>] )\n  transform.oneflow.canonicalization %module_op : (!pdl.operation) -> ()\n  transform.oneflow.cse %module_op : (!pdl.operation) -> ()\n\n  // Note: step 3, bufferize\n  transform.oneflow.explicit_linalg_outcome %module_op : (!pdl.operation) -> ()\n\n  transform.bufferization.eliminate_empty_tensors %module_op\n\n  %empty = transform.structured.match ops{[\"tensor.empty\"]} in %module_op : (!pdl.operation) -> !pdl.operation\n  %empty_id = transform.cast %empty : !pdl.operation to !transform.op<\"tensor.empty\">\n  transform.bufferization.empty_tensor_to_alloc_tensor %empty_id : (!transform.op<\"tensor.empty\">) -> !transform.op<\"bufferization.alloc_tensor\">\n\n  %bufferized_module_op = transform.bufferization.one_shot_bufferize %module_op\n      {create_deallocs = false, bufferize_function_boundaries = true,  allow_return_allocs = true} : (!pdl.operation) -> !pdl.operation\n      \n  // Note: step 4, post bufferize function-type-related transform\n  transform.oneflow.canonicalization %bufferized_module_op : (!pdl.operation) -> ()\n  transform.oneflow.cse %bufferized_module_op : (!pdl.operation) -> ()\n  transform.oneflow.eliminate_copy %bufferized_module_op : (!pdl.operation) -> ()\n\n  %func = transform.structured.match ops{[\"func.func\"]} in %bufferized_module_op : (!pdl.operation) -> !pdl.operation\n  
transform.structured.hoist_redundant_tensor_subsets %func\n    : (!pdl.operation) -> ()\n\n  // Note: step 5, post bufferize memory-buffer-pool transform\n  transform.oneflow.results_to_out_params %bufferized_module_op : (!pdl.operation) -> ()\n  transform.oneflow.eliminate_copy %bufferized_module_op : (!pdl.operation) -> ()\n  transform.oneflow.fold_alloc %func : (!pdl.operation) -> ()\n\n  // Note: step 6, mapping scf to gpu\n  %gpu_launch_op = transform.gpu.map_forall_to_blocks %bufferized_module_op { generate_gpu_launch }\n  transform.gpu.map_nested_forall_to_threads %gpu_launch_op block_dims = [32, 4, 1]\n}\n\n"
  },
  {
    "path": "oneflow/ir/test/Transform/test_dialect.mlir",
    "content": "// RUN: oneflow-opt --oneflow-transform-dialect-interpreter %s -split-input-file -verify-diagnostics | FileCheck %s\n\n// Test One-Shot Bufferize.\n\ntransform.sequence failures(propagate) {\n^bb0(%arg1: !pdl.operation):\n  %0 = transform.structured.match ops{[\"func.func\"]} in %arg1 : (!pdl.operation) -> !pdl.operation\n  %1 = transform.bufferization.one_shot_bufferize %0 : (!pdl.operation) -> !pdl.operation\n}\n\n// CHECK-LABEL: func @test_function(\n//  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>\nfunc.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {\n  %c0 = arith.constant 0 : index\n\n  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]\n  // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]\n  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])\n  // CHECK: memref.copy %[[A_memref]], %[[alloc]]\n  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]\n  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]\n  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>\n\n  // CHECK: memref.dealloc %[[alloc]]\n  // CHECK: return %[[res_tensor]]\n  return %0 : tensor<?xf32>\n}"
  },
  {
    "path": "oneflow/ir/test/lit.cfg.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# -*- Python -*-\n\nimport os\nimport platform\nimport re\nimport subprocess\nimport tempfile\n\nimport lit.formats\nimport lit.util\n\nfrom lit.llvm import llvm_config\nfrom lit.llvm.subst import ToolSubst\nfrom lit.llvm.subst import FindTool\n\n# Configuration file for the 'lit' test runner.\n\n# name: The name of this test suite.\nconfig.name = \"ONEFLOW\"\n\nconfig.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)\n\n# suffixes: A list of file extensions to treat as test files.\nconfig.suffixes = [\".mlir\", \".py\"]\n\n# test_source_root: The root path where tests are located.\nconfig.test_source_root = os.path.dirname(__file__)\n\n# test_exec_root: The root path where tests should be run.\nconfig.test_exec_root = os.path.join(config.oneflow_obj_root, \"test\")\n\nconfig.substitutions.append((\"%PATH%\", config.environment[\"PATH\"]))\nconfig.substitutions.append((\"%shlibext\", config.llvm_shlib_ext))\n\nllvm_config.with_system_environment([\"HOME\", \"INCLUDE\", \"LIB\", \"TMP\", \"TEMP\"])\n\nllvm_config.use_default_substitutions()\n\n# excludes: A list of directories to exclude from the testsuite. The 'Inputs'\n# subdirectories contain auxiliary inputs for various tests in their parent\n# directories.\nconfig.excludes = [\n    \"Inputs\",\n    \"Examples\",\n    \"CMakeLists.txt\",\n    \"README.txt\",\n    \"LICENSE.txt\",\n    \"networks\",\n    \"test_fuse_cast_scale.mlir.py\",\n    \"test_util.py\",\n    \"test_mlir_opt.mlir.py\",\n    \"lit.cfg.py\",\n    \"saved_model\",\n]\n\n# test_source_root: The root path where tests are located.\nconfig.test_source_root = os.path.dirname(__file__)\n\n# test_exec_root: The root path where tests should be run.\nconfig.test_exec_root = os.path.join(config.oneflow_obj_root, \"test\")\nconfig.oneflow_tools_dir = os.path.join(config.oneflow_ir_obj_root, \"bin\")\n\n# Tweak the PATH to include the tools dir.\nllvm_config.with_environment(\"PATH\", config.llvm_tools_dir, append_path=True)\n\n# TODO: these two should be unnecessary\nllvm_config.with_environment(\n    \"LD_LIBRARY_PATH\",\n    os.path.join(config.oneflow_obj_root, \"third_party_install/protobuf/lib\"),\n    append_path=True,\n)\nllvm_config.with_environment(\n    \"LD_LIBRARY_PATH\",\n    os.path.join(config.oneflow_obj_root, \"_deps/glog-build\"),\n    append_path=True,\n)\n\nllvm_config.with_environment(\"ONEFLOW_MLIR_STDOUT\", \"1\")\nllvm_config.with_environment(\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\", \"1\")\nllvm_config.with_environment(\"ONEFLOW_MLIR_CSE\", \"1\")\nllvm_config.with_environment(\"ONEFLOW_MLIR_FUSE_FORWARD_OPS\", \"1\")\nllvm_config.with_environment(\n    \"PYTHONPATH\", os.path.join(config.oneflow_src_root, \"python\"), append_path=True,\n)\n# Searches for a runtime library with the given name and returns a tool\n# substitution of the same name and the found path.\n# Correctly handles the platforms shared library directory and 
naming conventions.\ndef add_runtime(name):\n    path = \"\"\n    for prefix in [\"\", \"lib\"]:\n        path = os.path.join(\n            config.llvm_shlib_dir, f\"{prefix}{name}{config.llvm_shlib_ext}\"\n        )\n        if os.path.isfile(path):\n            break\n    return ToolSubst(f\"%{name}\", path)\n\n\ntool_dirs = [config.oneflow_tools_dir, config.llvm_tools_dir]\ntools = [\n    \"oneflow-opt\",\n    \"oneflow-translate\",\n    \"oneflow-runner\",\n    add_runtime(\"mlir_runner_utils\"),\n]\n\nif config.WITH_MLIR_CUDA_CODEGEN:\n    tools.extend([add_runtime(\"mlir_cuda_runtime\")])\n\ntools.extend(\n    [\n        ToolSubst(\"%with_cuda\", config.BUILD_CUDA, unresolved=\"ignore\"),\n        ToolSubst(\"%linalg_test_lib_dir\", config.llvm_lib_dir, unresolved=\"ignore\"),\n        ToolSubst(\"%test_exec_root\", config.test_exec_root, unresolved=\"ignore\"),\n    ]\n)\nllvm_config.add_tool_substitutions(tools, tool_dirs)\n\ntry:\n    from iree import runtime as ireert\n    from iree.compiler import compile_str\n\n    config.WITH_ONEFLOW_IREE = True\nexcept ImportError:\n    config.WITH_ONEFLOW_IREE = False\n"
  },
  {
    "path": "oneflow/ir/test/lit.site.cfg.py.in",
    "content": "@LIT_SITE_CFG_IN_HEADER@\n\nimport sys\n\nconfig.host_triple = \"@LLVM_HOST_TRIPLE@\"\nconfig.target_triple = \"@TARGET_TRIPLE@\"\nconfig.llvm_src_root = \"@LLVM_SOURCE_DIR@\"\nconfig.llvm_obj_root = \"@LLVM_BINARY_DIR@\"\nconfig.llvm_tools_dir = \"@LLVM_TOOLS_DIR@\"\nconfig.llvm_lib_dir = \"@LLVM_LIBRARY_DIR@\"\nconfig.llvm_shlib_dir = \"@LLVM_LIBRARY_DIR@\"\nconfig.llvm_shlib_ext = \"@SHLIBEXT@\"\nconfig.llvm_exe_ext = \"@EXEEXT@\"\nconfig.lit_tools_dir = \"@LLVM_LIT_TOOLS_DIR@\"\nconfig.python_executable = \"@PYTHON_EXECUTABLE@\"\nconfig.gold_executable = \"@GOLD_EXECUTABLE@\"\nconfig.ld64_executable = \"@LD64_EXECUTABLE@\"\nconfig.enable_shared = @ENABLE_SHARED@\nconfig.enable_assertions = @ENABLE_ASSERTIONS@\nconfig.targets_to_build = \"@TARGETS_TO_BUILD@\"\nconfig.native_target = \"@LLVM_NATIVE_ARCH@\"\nconfig.llvm_bindings = \"@LLVM_BINDINGS@\".split(' ')\nconfig.host_os = \"@HOST_OS@\"\nconfig.host_cc = \"@HOST_CC@\"\nconfig.host_cxx = \"@HOST_CXX@\"\n# Note: ldflags can contain double-quoted paths, so must use single quotes here.\nconfig.host_ldflags = '@HOST_LDFLAGS@'\nconfig.llvm_use_sanitizer = \"@LLVM_USE_SANITIZER@\"\nconfig.llvm_host_triple = '@LLVM_HOST_TRIPLE@'\nconfig.host_arch = \"@HOST_ARCH@\"\nconfig.oneflow_src_root = \"@CMAKE_SOURCE_DIR@\"\nconfig.oneflow_obj_root = \"@CMAKE_BINARY_DIR@\"\nconfig.oneflow_ir_obj_root = \"@PROJECT_BINARY_DIR@\"\nconfig.WITH_MLIR_CUDA_CODEGEN = @WITH_MLIR_CUDA_CODEGEN@\nconfig.BUILD_CUDA = @BUILD_CUDA@\n\n# Support substitution of the tools_dir with user parameters. This is\n# used when we can't determine the tool dir at configuration time.\ntry:\n    config.llvm_tools_dir = config.llvm_tools_dir % lit_config.params\n    config.llvm_shlib_dir = config.llvm_shlib_dir % lit_config.params\nexcept KeyError:\n    e = sys.exc_info()[1]\n    key, = e.args\n    lit_config.fatal(\"unable to find %r parameter, use '--param=%s=VALUE'\" % (key,key))\n\n\nimport lit.llvm\nlit.llvm.initialize(lit_config, config)\n\n# Let the main config do the real work.\nlit_config.load_config(config, \"@CMAKE_SOURCE_DIR@/oneflow/ir/test/lit.cfg.py\")\n"
  },
  {
    "path": "oneflow/maybe/config.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_MAYBE_CONFIG_H_\n#define ONEFLOW_MAYBE_CONFIG_H_\n\n#include <cassert>\n\n// pre-define it if you use a logging library like glog\n#ifndef OF_MAYBE_ASSERT\n#define OF_MAYBE_ASSERT(_cond_) assert(_cond_)\n#endif\n\n// ASSERT_EQ is different from ASSERT in logging / testing framework\n// pre-define it if you use a logging library like glog\n#ifndef OF_MAYBE_ASSERT_EQ\n#define OF_MAYBE_ASSERT_EQ(_lhs_, _rhs_) OF_MAYBE_ASSERT(_lhs_ == _rhs_)\n#endif\n\n#if __GNUC__ >= 7\n#define OF_MAYBE_HAS_IS_AGGREGATE\n// in old versions of clang, __has_builtin(__is_aggregate) returns false\n#elif __clang__\n#if !__is_identifier(__is_aggregate)\n#define OF_MAYBE_HAS_IS_AGGREGATE\n#endif\n#else\n#if __has_builtin(__is_aggregate)\n#define OF_MAYBE_HAS_IS_AGGREGATE\n#endif\n#endif\n\n#ifdef OF_MAYBE_HAS_IS_AGGREGATE\n#define OF_MAYBE_IS_AGGREGATE(...) (__is_aggregate(__VA_ARGS__))\n#else\n// decay to POD checking if no such builtin (because implementing __is_aggregate need reflection)\n#define OF_MAYBE_IS_AGGREGATE(...) \\\n  (std::is_standard_layout<__VA_ARGS__>::value && std::is_trivial<__VA_ARGS__>::value)\n#endif\n\n// `__builtin_expect` exists at least since GCC 4 / Clang 3\n#define OF_MAYBE_EXPECT_FALSE(x) (__builtin_expect((x), 0))\n\n#if __has_cpp_attribute(nodiscard)\n#define OF_MAYBE_NODISCARD_FUNC [[nodiscard]]\n#define OF_MAYBE_NODISCARD_TYPE [[nodiscard]]\n#elif __has_attribute(warn_unused_result)\n#define OF_MAYBE_NODISCARD_FUNC \\\n  __attribute__((warn_unused_result))  // or [[gnu::warn_unused_result]]\n#define OF_MAYBE_NODISCARD_TYPE\n#endif\n\n#endif  // ONEFLOW_MAYBE_CONFIG_H_\n"
  },
  {
    "path": "oneflow/maybe/error.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_MAYBE_ERROR_H_\n#define ONEFLOW_MAYBE_ERROR_H_\n\n#include <cstddef>\n#include <cstdlib>\n#include <memory>\n#include <sstream>\n#include <string>\n#include <type_traits>\n#include <vector>\n#include <iostream>\n#include <string_view>\n\n#include \"utility.h\"\n#include \"type_traits.h\"\n\nnamespace oneflow {\n\nnamespace maybe {\n\nnamespace details {\n\ntemplate<typename D>\nstruct ErrorStackFromContainerBase {\n private:\n  using Derived = D;\n\n  auto& Stack() { return static_cast<Derived*>(this)->GetStack(); }\n\n  const auto& Stack() const { return static_cast<const Derived*>(this)->GetStack(); }\n\n public:\n  std::size_t StackSize() const { return Stack().size(); }\n\n  template<typename... Args>\n  void PushStack(Args&&... args) {\n    auto& s = Stack();\n    s.emplace(s.end(), std::forward<Args>(args)...);\n  }\n\n  template<typename T = Derived>\n  const typename T::StackType::value_type& StackElem(std::size_t index) const {\n    return Stack()[index];\n  }\n\n  auto StackBegin() const { return Stack().begin(); }\n  auto StackEnd() const { return Stack().end(); }\n};\n\n}  // namespace details\n\ntemplate<typename T>\nstruct StackedErrorTraits {\n  StackedErrorTraits() = delete;\n\n  using ErrorType = typename T::ErrorType;\n  using StackEntryType = typename T::StackEntryType;\n\n  template<\n      typename U,\n      std::enable_if_t<\n          std::is_same<T, RemoveCVRef<U>>::value\n              && std::is_same<ErrorType, RemoveCVRef<decltype(std::declval<U>().Error())>>::value,\n          int> = 0>\n  static decltype(auto) Error(U&& se) {\n    return se.Error();\n  }\n\n  static std::size_t StackSize(const T& se) { return se.StackSize(); }\n\n  static ConstRefExceptVoid<StackEntryType> StackElem(const T& se, std::size_t index) {\n    return se.StackElem(index);\n  }\n\n  template<typename U, typename... Args,\n           std::enable_if_t<std::is_same<T, RemoveCVRef<U>>::value, int> = 0>\n  static void PushStack(U&& se, Args&&... 
args) {\n    se.PushStack(std::forward<Args>(args)...);\n  }\n\n  template<typename U, std::enable_if_t<std::is_same<T, RemoveCVRef<U>>::value, int> = 0>\n  static std::string Dump(U&& se) {\n    return se.Dump();\n  }\n\n  template<typename U, std::enable_if_t<std::is_same<T, RemoveCVRef<U>>::value, int> = 0>\n  [[noreturn]] static void Abort(U&& se) {\n    se.Abort();\n  }\n};\n\ntemplate<typename T>\nstruct StackedErrorTraits<std::unique_ptr<T>> {\n  StackedErrorTraits() = delete;\n\n  using PointedTraits = StackedErrorTraits<T>;\n\n  using ValueType = std::unique_ptr<T>;\n\n  using ErrorType = typename PointedTraits::ErrorType;\n  using StackEntryType = typename PointedTraits::StackEntryType;\n\n  template<typename U, std::enable_if_t<std::is_same<ValueType, RemoveCVRef<U>>::value, int> = 0>\n  static decltype(auto) Error(U&& se) {\n    return PointedTraits::Error(*se);\n  }\n\n  static std::size_t StackSize(const ValueType& se) { return PointedTraits::StackSize(*se); }\n\n  static ConstRefExceptVoid<StackEntryType> StackElem(const T& se, std::size_t index) {\n    return PointedTraits::StackElem(*se, index);\n  }\n\n  template<typename U, typename... Args,\n           std::enable_if_t<std::is_same<ValueType, RemoveCVRef<U>>::value, int> = 0>\n  static void PushStack(U&& se, Args&&... args) {\n    PointedTraits::PushStack(*se, std::forward<Args>(args)...);\n  }\n\n  template<typename U, std::enable_if_t<std::is_same<ValueType, RemoveCVRef<U>>::value, int> = 0>\n  static std::string Dump(U&& se) {\n    return PointedTraits::Dump(*se);\n  }\n\n  template<typename U, std::enable_if_t<std::is_same<ValueType, RemoveCVRef<U>>::value, int> = 0>\n  [[noreturn]] static void Abort(U&& se) {\n    PointedTraits::Abort(*se);\n  }\n};\n\n// simple implementation for some customization points\nnamespace simple {\n\ntemplate<typename T>\nstruct MessageFormatTrait;\n\ntemplate<>\nstruct MessageFormatTrait<std::string> {\n  template<typename Code, typename... Args>\n  static std::string Format(Code&& code, Args&&... args) {\n    if (sizeof...(args) > 0) {\n      std::stringstream res;\n\n      res << code << \": \";\n      ((res << args), ...);\n\n      return res.str();\n    } else {\n      return code;\n    }\n  }\n};\n\ntemplate<>\nstruct MessageFormatTrait<std::string_view> {\n  template<typename Code>\n  static std::string_view Format(Code&& code) {\n    return code;\n  }\n};\n\ntemplate<typename Message, typename MessageFormatTraits = MessageFormatTrait<Message>>\nstruct ErrorStackEntry {\n  std::string_view filename;\n  std::size_t lineno;\n  std::string_view function;\n  Message message;\n\n  template<typename... Args>\n  ErrorStackEntry(std::string_view filename, std::size_t lineno, std::string_view function,\n                  Args&&... 
args)\n      : filename(filename),\n        lineno(lineno),\n        function(function),\n        message(MessageFormatTraits::Format(std::forward<Args>(args)...)) {}\n};\n\ntemplate<typename E, typename M = std::string>\nstruct StackedError : details::ErrorStackFromContainerBase<StackedError<E, M>> {\n public:\n  using ErrorType = E;\n  using StackMessage = M;\n  using StackEntryType = ErrorStackEntry<StackMessage>;\n  using StackType = std::vector<StackEntryType>;\n  using BaseType = details::ErrorStackFromContainerBase<StackedError<E, M>>;\n\n  static_assert(!std::is_reference<E>::value, \"the underlying value type cannot be reference\");\n\n  StackedError(ErrorType error)  // NOLINT(google-explicit-constructor)\n      : error_(std::move(error)) {}\n\n  ErrorType& Error() { return error_; }\n  const ErrorType& Error() const { return error_; }\n\n  std::string Dump() {\n    std::stringstream res;\n    res << \"error occurred: \" << error_ << std::endl;\n    for (const auto& elem : stack_) {\n      res << \"from \" << elem.function << \" in \" << elem.filename << \":\" << elem.lineno << \": \"\n          << elem.message << std::endl;\n    }\n\n    return res.str();\n  }\n\n  [[noreturn]] void Abort() {\n    std::cerr << \"error occurred: \" << error_ << std::endl;\n    for (const auto& elem : stack_) {\n      std::cerr << \"from \" << elem.function << \" in \" << elem.filename << \":\" << elem.lineno << \": \"\n                << elem.message << std::endl;\n    }\n    std::abort();\n  }\n\n private:\n  ErrorType error_;\n  StackType stack_;\n\n  StackType& GetStack() { return stack_; }\n\n  const StackType& GetStack() const { return stack_; }\n\n  friend BaseType;\n};\n\ntemplate<typename E>\nstruct NoStackError {\n  using ErrorType = E;\n  using StackEntryType = void;\n\n  static_assert(!std::is_reference<E>::value, \"the underlying value type cannot be reference\");\n\n  NoStackError(ErrorType error)  // NOLINT(google-explicit-constructor)\n      : error_(std::move(error)) {}\n\n  ErrorType& Error() { return error_; }\n  const ErrorType& Error() const { return error_; }\n\n  std::size_t StackSize() const { return 0; }\n\n  void StackElem(std::size_t) const {}\n\n  template<typename... Args>\n  void PushStack(Args&&... args) {}\n\n  std::string Dump() {\n    std::stringstream res;\n    res << error_ << std::endl;\n\n    return res.str();\n  }\n\n  [[noreturn]] void Abort() {\n    std::cerr << error_ << std::endl;\n    std::abort();\n  }\n\n private:\n  ErrorType error_;\n};\n\n}  // namespace simple\n\n}  // namespace maybe\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_MAYBE_ERROR_H_\n"
  },
  {
    "path": "oneflow/maybe/error_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <gtest/gtest-death-test.h>\n#include <gtest/gtest.h>\n#include <system_error>\n#include \"oneflow/maybe/error.h\"\n\nusing namespace oneflow::maybe;\nusing namespace oneflow::maybe::simple;\nusing namespace std::string_literals;\n\nnamespace oneflow {\nnamespace maybe {\n\n// test if StackedErrorTraits can be applied to some simple types\ntemplate struct StackedErrorTraits<StackedError<std::error_code>>;\ntemplate struct StackedErrorTraits<NoStackError<std::error_code>>;\n\n}  // namespace maybe\n}  // namespace oneflow\n\nTEST(StackedError, SimpleStackedError) {\n  StackedError<std::error_code, std::string_view> a(std::make_error_code(std::errc::timed_out));\n\n  ASSERT_EQ(a.Error(), std::errc::timed_out);\n  ASSERT_EQ(a.StackSize(), 0);\n\n  const auto& ec = a.Error();\n  ASSERT_DEATH(a.Abort(),  // NOLINT(cppcoreguidelines-avoid-goto)\n               ec.category().name() + \":\"s + std::to_string(ec.value()));\n\n  [&a] { a.PushStack(__FILE__, __LINE__, __PRETTY_FUNCTION__, \"hello\"); }();\n\n  struct SomeType {\n    explicit SomeType(decltype(a)& a) {\n      a.PushStack(__FILE__, __LINE__, __PRETTY_FUNCTION__, \"hi\");\n    }\n  } x(a);\n\n  ASSERT_EQ(a.StackSize(), 2);\n  ASSERT_DEATH(a.Abort(),  // NOLINT(cppcoreguidelines-avoid-goto)\n               \"(lambda|operator\\\\(\\\\)).*hello.*\\n.*SomeType::SomeType.*hi\");\n\n  ASSERT_EQ(a.StackElem(0).message, \"hello\");\n  ASSERT_EQ(a.StackElem(1).message, \"hi\");\n}\n\nTEST(StackedError, SimpleNoStackError) {\n  NoStackError<std::error_code> a(std::make_error_code(std::errc::address_in_use));\n\n  ASSERT_EQ(a.Error(), std::errc::address_in_use);\n  ASSERT_EQ(a.StackSize(), 0);\n\n  const auto& ec = a.Error();\n  ASSERT_DEATH(a.Abort(),  // NOLINT(cppcoreguidelines-avoid-goto)\n               ec.category().name() + \":\"s + std::to_string(ec.value()));\n\n  a.PushStack(__FILE__, __LINE__, __PRETTY_FUNCTION__, \"hello\");\n  ASSERT_EQ(a.StackSize(), 0);\n}\n"
  },
  {
    "path": "oneflow/maybe/just.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_MAYBE_JUST_H_\n#define ONEFLOW_MAYBE_JUST_H_\n\n#include <type_traits>\n#include <utility>\n\n#include \"oneflow/maybe/error.h\"\n#include \"oneflow/maybe/type_traits.h\"\n\nnamespace oneflow {\n\nnamespace maybe {\n\ntemplate<typename T, typename E>\nstruct Maybe;\n\ntemplate<typename T>\nstruct IsMaybe : std::false_type {};\n\ntemplate<typename T, typename E>\nstruct IsMaybe<Maybe<T, E>> : std::true_type {};\n\ntemplate<typename T>\nstruct Optional;\n\ntemplate<typename T>\nstruct IsOptional : std::false_type {};\n\ntemplate<typename T>\nstruct IsOptional<Optional<T>> : std::true_type {};\n\n// user should provide which error will be returned while an optional has no value\n// and is used in JUST or CHECK_JUST;\n// if not provided, then JUST(_MSG) and CHECK_JUST(_MSG) cannot be used for Optional\n// i.e. ```c++\n// template <typename T> struct JustConfig<Optional<T>> {\n//   static SomeError ValueNotFoundError(auto&&) { ... }\n// };\n// ```\n// or some other optional types, i.e. std::shared_ptr ```c++\n// template <typename T> struct JustConfig<std::shared_ptr<T>> {\n//   // define which error will be returned while it is empty\n//   static SomeError ValueNotFoundError(auto&&) { ... }\n//   // define how to get the underlying value\n//   static decltype(auto) Value(auto&&) { ... }\n// };\n// ```\ntemplate<typename T>\nstruct JustTraits;\n\nnamespace details {\n\nstruct JustPrivateScope {\n  template<typename T>\n  static decltype(auto) Value(T&& v) {\n    return std::forward<T>(v).Value();\n  }\n\n  template<typename T, std::enable_if_t<IsMaybe<RemoveCVRef<T>>::value, int> = 0>\n  static decltype(auto) StackedError(T&& v) {\n    return std::forward<T>(v).StackedError();\n  }\n\n  template<typename T, std::enable_if_t<!IsMaybe<RemoveCVRef<T>>::value, int> = 0>\n  static decltype(auto) StackedError(T&& v) {\n    return JustTraits<RemoveCVRef<T>>::ValueNotFoundError(std::forward<T>(v));\n  }\n};\n\ntemplate<typename T>\ntypename std::remove_const<typename std::remove_reference<T>::type>::type&& RemoveRValConst(\n    T&& v) noexcept {\n  static_assert(std::is_rvalue_reference<T&&>::value, \"rvalue is expected here\");\n  return const_cast<typename std::remove_const<typename std::remove_reference<T>::type>::type&&>(v);\n}\n\ntemplate<typename T, typename... Args>\ndecltype(auto) JustPushStackAndReturn(T&& v, Args&&... args) {\n  StackedErrorTraits<RemoveCVRef<T>>::PushStack(std::forward<T>(v), std::forward<Args>(args)...);\n  return std::forward<T>(v);\n}\n\ntemplate<typename T, typename... Args>\n[[noreturn]] void JustPushStackAndAbort(T&& v, Args&&... 
args) {\n  using Traits = StackedErrorTraits<RemoveCVRef<T>>;\n\n  Traits::PushStack(std::forward<T>(v), std::forward<Args>(args)...);\n  Traits::Abort(std::forward<T>(v));\n}\n\ntemplate<typename T, std::enable_if_t<IsMaybe<T>::value || IsOptional<T>::value, int> = 0>\nauto JustGetValue(T&& v) -> RemoveRValRef<decltype(JustPrivateScope::Value(std::forward<T>(v)))> {\n  return JustPrivateScope::Value(std::forward<T>(v));\n}\n\ntemplate<typename T, std::enable_if_t<!IsMaybe<T>::value && !IsOptional<T>::value, int> = 0>\nauto JustGetValue(T&& v)\n    -> RemoveRValRef<decltype(JustTraits<RemoveCVRef<T>>::Value(std::forward<T>(v)))> {\n  return JustTraits<RemoveCVRef<T>>::Value(std::forward<T>(v));\n}\n\n}  // namespace details\n\n}  // namespace maybe\n\n}  // namespace oneflow\n\n// macros begin\n\n#define JUST_STACK_CHECK_I(...) __VA_ARGS__\n\n#define JUST_TO_STR_I(...) #__VA_ARGS__\n\n#if defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__)\n\n#define JUST(...)                                                                       \\\n  ::oneflow::maybe::details::JustGetValue(::oneflow::maybe::details::RemoveRValConst(({ \\\n    auto&& _just_value_to_check_ = JUST_STACK_CHECK_I(__VA_ARGS__);                     \\\n    if (OF_MAYBE_EXPECT_FALSE(!_just_value_to_check_)) {                                \\\n      return ::oneflow::maybe::details::JustPushStackAndReturn(                         \\\n          ::oneflow::maybe::details::JustPrivateScope::StackedError(                    \\\n              std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_)),    \\\n          __FILE__, __LINE__, __PRETTY_FUNCTION__, JUST_TO_STR_I(__VA_ARGS__));         \\\n    }                                                                                   \\\n    std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_);               \\\n  })))\n\n#define CHECK_JUST(...)                                                              \\\n  ::oneflow::maybe::details::JustGetValue([&](const auto& _just_function_name_) {    \\\n    auto&& _just_value_to_check_ = JUST_STACK_CHECK_I(__VA_ARGS__);                  \\\n    if (OF_MAYBE_EXPECT_FALSE(!_just_value_to_check_)) {                             \\\n      ::oneflow::maybe::details::JustPushStackAndAbort(                              \\\n          ::oneflow::maybe::details::JustPrivateScope::StackedError(                 \\\n              std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_)), \\\n          __FILE__, __LINE__, _just_function_name_, JUST_TO_STR_I(__VA_ARGS__));     \\\n    }                                                                                \\\n    return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_);     \\\n  }(__PRETTY_FUNCTION__))\n\n#define JUST_MSG(_just_expr_, ...)                                                           
\\\n  ::oneflow::maybe::details::JustGetValue(::oneflow::maybe::details::RemoveRValConst(({      \\\n    auto&& _just_value_to_check_ = (_just_expr_);                                            \\\n    if (OF_MAYBE_EXPECT_FALSE(!_just_value_to_check_)) {                                     \\\n      return ::oneflow::maybe::details::JustPushStackAndReturn(                              \\\n          ::oneflow::maybe::details::JustPrivateScope::StackedError(                         \\\n              std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_)),         \\\n          __FILE__, __LINE__, __PRETTY_FUNCTION__, JUST_TO_STR_I(_just_expr_), __VA_ARGS__); \\\n    }                                                                                        \\\n    std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_);                    \\\n  })))\n\n#define CHECK_JUST_MSG(_just_expr_, ...)                                                      \\\n  ::oneflow::maybe::details::JustGetValue([&](const auto& _just_function_name_) {             \\\n    auto&& _just_value_to_check_ = (_just_expr_);                                             \\\n    if (OF_MAYBE_EXPECT_FALSE(!_just_value_to_check_)) {                                      \\\n      ::oneflow::maybe::details::JustPushStackAndAbort(                                       \\\n          ::oneflow::maybe::details::JustPrivateScope::StackedError(                          \\\n              std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_)),          \\\n          __FILE__, __LINE__, _just_function_name_, JUST_TO_STR_I(_just_expr_), __VA_ARGS__); \\\n    }                                                                                         \\\n    return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_);              \\\n  }(__PRETTY_FUNCTION__))\n\n#define OPT_JUST(...)                                                                   \\\n  ::oneflow::maybe::details::JustGetValue(::oneflow::maybe::details::RemoveRValConst(({ \\\n    auto&& _just_value_to_check_ = JUST_STACK_CHECK_I(__VA_ARGS__);                     \\\n    if (OF_MAYBE_EXPECT_FALSE(!_just_value_to_check_)) { return NullOpt; }              \\\n    std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_);               \\\n  })))\n\n#else\n#error \"statement expression is not supported, please implement try-catch version of JUST\"\n#endif  // defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__)\n\n// macros end\n\n#endif  // ONEFLOW_MAYBE_JUST_H_\n"
  },
  {
    "path": "oneflow/maybe/just_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <gtest/gtest-death-test.h>\n#include <gtest/gtest.h>\n#include <memory>\n\n#include \"oneflow/maybe/maybe.h\"\n#include \"oneflow/maybe/optional.h\"\n\nusing namespace oneflow::maybe;\n\nTEST(Just, MaybeBasic) {\n  using Error = simple::StackedError<std::string>;\n  using MaybeInt = Maybe<int, Error>;\n\n  auto f = [](int x) -> MaybeInt {\n    if (x > 10 || x < 0) { return Error{\"not in range\"}; }\n\n    return x + 10;\n  };\n\n  auto g = [&f](int x) -> MaybeInt {\n    if (x == 15) { return Error{\"invalid value\"}; }\n\n    return JUST(f(x)) * 2;\n  };\n\n  auto h = [&g](int x) -> MaybeInt { return JUST(g(x)) + 2; };\n\n  ASSERT_EQ(CHECK_JUST(h(0)), 22);\n\n  ASSERT_DEATH(  // NOLINT(cppcoreguidelines-avoid-goto)\n      CHECK_JUST(h(11)),\n      R\"(not in range.*(lambda|operator\\(\\)).*f\\(x\\).*(lambda|operator\\(\\)).*g\\(x\\).*TestBody.*h\\(11\\))\");\n\n  ASSERT_DEATH(  // NOLINT(cppcoreguidelines-avoid-goto)\n      CHECK_JUST(h(15)), R\"(invalid value.*(lambda|operator\\(\\)).*g\\(x\\).*TestBody.*h\\(15\\))\");\n\n  ASSERT_EQ(details::JustPrivateScope::StackedError(h(12)).StackSize(), 2);\n  ASSERT_EQ(details::JustPrivateScope::StackedError(h(15)).StackSize(), 1);\n}\n\nTEST(Just, MaybeVoid) {\n  using Error = simple::StackedError<std::string>;\n  using MaybeVoid = Maybe<void, Error>;\n\n  auto f = [](int& x) -> MaybeVoid {\n    if (x > 10 || x < 0) { return Error{\"not in range\"}; }\n\n    x = x + 5;\n    return Ok;\n  };\n\n  auto g = [&f](int& x) -> MaybeVoid {\n    if (x == 15) { return Error{\"invalid value\"}; }\n\n    JUST(f(x));\n    JUST(f(x));\n    return Ok;\n  };\n\n  auto h = [&g](int& x) -> MaybeVoid {\n    JUST(g(x));\n    x = x + 2;\n    return Ok;\n  };\n\n  int x = 0;\n  CHECK_JUST(h(x));\n  ASSERT_EQ(x, 12);\n\n  x = 11;\n  ASSERT_DEATH(  // NOLINT(cppcoreguidelines-avoid-goto)\n      CHECK_JUST(h(x)),\n      R\"(not in range.*(lambda|operator\\(\\)).*f\\(x\\).*(lambda|operator\\(\\)).*g\\(x\\).*TestBody.*h\\(x\\))\");\n  ASSERT_EQ(x, 11);\n\n  x = 8;\n  ASSERT_DEATH(  // NOLINT(cppcoreguidelines-avoid-goto)\n      CHECK_JUST(h(x)),\n      R\"(not in range.*(lambda|operator\\(\\)).*f\\(x\\).*(lambda|operator\\(\\)).*g\\(x\\).*TestBody.*h\\(x\\))\");\n\n  [[maybe_unused]] auto _ = h(x);  // NOLINT\n  ASSERT_EQ(x, 13);\n}\n\nTEST(Just, MaybeRef) {\n  using Error = simple::StackedError<std::string>;\n  using MaybeRef = Maybe<const int&, Error>;\n\n  int k = 100;\n\n  auto f = [&k](const int& x) -> MaybeRef {\n    if (x > 10 || x < 0) { return Error{\"not in range\"}; }\n\n    if (x < 5) return x;\n    return k;\n  };\n\n  auto g = [&f](const int& x) -> MaybeRef {\n    if (x == 2) { return Error{\"invalid value\"}; }\n    return JUST(f(x));\n  };\n\n  int x = 1;\n  ASSERT_EQ(CHECK_JUST(g(x)), 1);\n\n  const int& y = CHECK_JUST(g(5));\n  ASSERT_EQ(y, 100);\n  k = 200;\n  ASSERT_EQ(y, 200);\n\n  ASSERT_DEATH(  // 
NOLINT(cppcoreguidelines-avoid-goto)\n      CHECK_JUST(g(11)), R\"(not in range.*(lambda|operator\\(\\)).*f\\(x\\).*TestBody.*g\\(11\\))\");\n\n  ASSERT_DEATH(  // NOLINT(cppcoreguidelines-avoid-goto)\n      CHECK_JUST(g(2)), R\"(invalid value.*TestBody.*g\\(2\\))\");\n}\n\nTEST(Just, MaybeErrorPtr) {\n  using E = simple::StackedError<std::string>;\n  using Error = std::unique_ptr<E>;\n  using MaybeInt = Maybe<int, Error>;\n\n  auto f = [](int x) -> MaybeInt {\n    if (x > 10 || x < 0) { return std::make_unique<E>(\"not in range\"); }\n\n    return x + 10;\n  };\n\n  auto g = [&f](int x) -> MaybeInt {\n    if (x == 15) { return std::make_unique<E>(\"invalid value\"); }\n\n    return JUST(f(x)) * 2;\n  };\n\n  auto h = [&g](int x) -> MaybeInt { return JUST(g(x)) + 2; };\n\n  ASSERT_EQ(CHECK_JUST(h(0)), 22);\n\n  ASSERT_DEATH(  // NOLINT(cppcoreguidelines-avoid-goto)\n      CHECK_JUST(h(11)),\n      R\"(not in range.*(lambda|operator\\(\\)).*f\\(x\\).*(lambda|operator\\(\\)).*g\\(x\\).*TestBody.*h\\(11\\))\");\n\n  ASSERT_DEATH(  // NOLINT(cppcoreguidelines-avoid-goto)\n      CHECK_JUST(h(15)), R\"(invalid value.*(lambda|operator\\(\\)).*g\\(x\\).*TestBody.*h\\(15\\))\");\n\n  ASSERT_EQ(details::JustPrivateScope::StackedError(h(12))->StackSize(), 2);\n  ASSERT_EQ(details::JustPrivateScope::StackedError(h(15))->StackSize(), 1);\n}\n\nnamespace oneflow {\nnamespace maybe {\n\ntemplate<typename T>\nstruct JustTraits {\n  template<typename U>\n  static simple::StackedError<std::string> ValueNotFoundError(U&&) {\n    return {\"not found\"};\n  }\n\n  template<typename U>\n  static decltype(auto) Value(U&& v) {\n    return *v;\n  }\n};\n\n}  // namespace maybe\n}  // namespace oneflow\n\nTEST(Just, Optional) {\n  using Error = simple::StackedError<std::string>;\n  using MaybeInt = Maybe<int, Error>;\n\n  Optional<int> a, b(1), c(2);\n\n  auto f = [](const Optional<int>& x) -> MaybeInt {\n    if (x == 1) return Error(\"hello\");\n\n    return JUST(x) + 1;\n  };\n\n  ASSERT_DEATH(  // NOLINT(cppcoreguidelines-avoid-goto)\n      CHECK_JUST(f(a)), R\"(not found.*(lambda|operator\\(\\)).*x.*TestBody.*f\\(a\\))\");\n  ASSERT_DEATH(  // NOLINT(cppcoreguidelines-avoid-goto)\n      CHECK_JUST(f(b)), R\"(hello.*TestBody.*f\\(b\\))\");\n\n  ASSERT_EQ(CHECK_JUST(f(c)), 3);\n}\n\nTEST(Just, Ptr) {\n  using Error = simple::StackedError<std::string>;\n  using MaybeInt = Maybe<int, Error>;\n\n  std::shared_ptr<int> a, b(std::make_shared<int>(1)), c(std::make_shared<int>(2));\n\n  auto f = [](const std::shared_ptr<int>& x) -> MaybeInt {\n    if (JUST(x) == 1) return Error(\"hello\");\n\n    return JUST(x) + 1;\n  };\n\n  ASSERT_DEATH(  // NOLINT(cppcoreguidelines-avoid-goto)\n      CHECK_JUST(f(a)), R\"(not found.*(lambda|operator\\(\\)).*x.*TestBody.*f\\(a\\))\");\n  ASSERT_DEATH(  // NOLINT(cppcoreguidelines-avoid-goto)\n      CHECK_JUST(f(b)), R\"(hello.*TestBody.*f\\(b\\))\");\n\n  ASSERT_EQ(CHECK_JUST(f(c)), 3);\n}\n\nTEST(Just, WithMsg) {\n  struct UniqueInt {\n    int x;\n\n    void drop() { x = -333; }\n\n    explicit UniqueInt(int x) : x{x} {}\n    UniqueInt(const UniqueInt& i) = delete;\n    UniqueInt(UniqueInt&& i) noexcept : x{i.x} { i.drop(); }  // NOLINT\n    UniqueInt& operator=(const UniqueInt& i) = delete;\n    UniqueInt& operator=(UniqueInt&& i) noexcept {\n      x = i.x;\n      i.drop();\n      return *this;\n    }\n    ~UniqueInt() { drop(); }\n  };\n\n  using Error = simple::StackedError<std::string>;\n  using MaybeInt = Maybe<UniqueInt, Error>;\n\n  auto f = [](UniqueInt x) -> MaybeInt {\n    
if (x.x > 10) { return Error{\"input value \" + std::to_string(x.x)}; }\n\n    return UniqueInt{233};\n  };\n\n  auto g = [](UniqueInt x) {\n    int y = x.x;\n    return UniqueInt{y * y - 5 * y + 3};\n  };\n\n  auto h = [&](UniqueInt x) -> MaybeInt {\n    int n = x.x;\n    auto y = g(std::move(x));\n    return JUST_MSG(f(std::move(y)), \"input value g(\", n, \")\");\n  };\n\n  auto i = [&](float x) -> MaybeInt {\n    UniqueInt y{int(x)};\n    return JUST_MSG(h(std::move(y)), \"input value int(\", x, \")\");\n  };\n\n  auto data = CHECK_JUST(i(1));\n  ASSERT_EQ(data.x, 233);\n\n  auto err = details::JustPrivateScope::StackedError(i(10.123));\n  ASSERT_EQ(err.Error(), \"input value 53\");\n  ASSERT_EQ(err.StackElem(0).message, \"f(std::move(y)): input value g(10)\");\n  ASSERT_EQ(err.StackElem(1).message, \"h(std::move(y)): input value int(10.123)\");\n\n  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto)\n  ASSERT_EXIT(CHECK_JUST(i(10.234)), testing::KilledBySignal(SIGABRT), R\"(input value 53)\");\n}\n\nTEST(Just, JustOpt) {\n  auto f = [](int x) -> Optional<int> {\n    if (x > 10) return NullOpt;\n\n    return x + 1;\n  };\n\n  auto g = [&f](int x) -> Optional<int> { return OPT_JUST(f(x)) * 2; };\n\n  ASSERT_EQ(CHECK_JUST(g(2)), 6);\n  ASSERT_FALSE(g(11));\n\n  auto h = [&](int x) -> Optional<int> {\n    if (x == 10) return NullOpt;\n\n    return OPT_JUST(g(x)) + OPT_JUST(f(x + 2));\n  };\n\n  ASSERT_FALSE(h(10));\n  ASSERT_FALSE(h(9));\n  ASSERT_EQ(h(8), 29);\n}\n\nTEST(Just, NoStack) {\n  using Error = simple::NoStackError<std::string>;\n  using MaybeInt = Maybe<int, Error>;\n\n  auto f = [](int x) -> MaybeInt {\n    if (x > 10 || x < 0) { return Error{\"not in range\"}; }\n\n    return x + 10;\n  };\n\n  auto g = [&f](int x) -> MaybeInt {\n    if (x == 15) { return Error{\"invalid value\"}; }\n\n    return JUST(f(x)) * 2;\n  };\n\n  auto h = [&g](int x) -> MaybeInt { return JUST(g(x)) + 2; };\n\n  ASSERT_EQ(CHECK_JUST(h(0)), 22);\n\n  ASSERT_DEATH(  // NOLINT(cppcoreguidelines-avoid-goto)\n      CHECK_JUST(h(11)), R\"(not in range)\");\n\n  ASSERT_DEATH(  // NOLINT(cppcoreguidelines-avoid-goto)\n      CHECK_JUST(h(15)), R\"(invalid value)\");\n\n  ASSERT_EQ(details::JustPrivateScope::StackedError(h(12)).StackSize(), 0);\n  ASSERT_EQ(details::JustPrivateScope::StackedError(h(15)).StackSize(), 0);\n}\n"
  },
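The tests above all rely on one pattern: inside a function that returns Maybe (or Optional), JUST either yields the wrapped value or early-returns the error to the caller, and the generic JustTraits installed before TEST(Just, Optional) teaches JUST how to unwrap pointer-like types. A condensed sketch of that pattern, assuming the same JustTraits definition and usings are visible in the translation unit (SumBoth is a hypothetical helper, not library code):

```cpp
using Error = simple::StackedError<std::string>;

// Hypothetical helper: each JUST either yields the contained int or
// short-circuits with a stacked "not found" error recording this call site.
Maybe<int, Error> SumBoth(const Optional<int>& a, const std::shared_ptr<int>& b) {
  return JUST(a) + JUST(b);
}
```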
  {
    "path": "oneflow/maybe/maybe.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_MAYBE_MAYBE_H_\n#define ONEFLOW_MAYBE_MAYBE_H_\n\n#include <cstddef>\n#include <type_traits>\n\n#include \"oneflow/maybe/just.h\"\n#include \"oneflow/maybe/variant.h\"\n#include \"oneflow/maybe/optional.h\"\n#include \"oneflow/maybe/error.h\"\n#include \"oneflow/maybe/config.h\"\n\nnamespace oneflow {\n\nnamespace maybe {\n\nstruct InPlaceOkType {\n  explicit constexpr InPlaceOkType() = default;\n};\n\nconstexpr InPlaceOkType Ok{};\n\nstruct InPlaceErrorType {\n  explicit constexpr InPlaceErrorType() = default;\n};\n\nconstexpr InPlaceErrorType InPlaceError{};\n\nnamespace details {\n\ntemplate<typename T, typename E, typename = void>\nstruct MaybeStorage : Variant<T, E> {\n  using Base = Variant<T, E>;\n\n  MaybeStorage(const T& v) : Base(v) {}        // NOLINT(google-explicit-constructor)\n  MaybeStorage(T&& v) : Base(std::move(v)) {}  // NOLINT(google-explicit-constructor)\n\n  template<typename... Args>\n  explicit MaybeStorage(InPlaceOkType, Args&&... args)\n      : Base(InPlaceType<T>, std::forward<Args>(args)...) {}\n\n  template<typename... Args>\n  explicit MaybeStorage(InPlaceErrorType, Args&&... args)\n      : Base(InPlaceType<E>, std::forward<Args>(args)...) {}\n\n  MaybeStorage(const E& err) : Base(err) {}        // NOLINT(google-explicit-constructor)\n  MaybeStorage(E&& err) : Base(std::move(err)) {}  // NOLINT(google-explicit-constructor)\n\n  decltype(auto) Value() & { return this->Base::template Value<T>(); }\n  decltype(auto) Value() const& { return this->Base::template Value<T>(); }\n  decltype(auto) Value() && { return std::move(*this).Base::template Value<T>(); }\n\n  decltype(auto) Error() & { return this->Base::template Value<E>(); }\n  decltype(auto) Error() const& { return this->Base::template Value<E>(); }\n  decltype(auto) Error() && { return std::move(*this).Base::template Value<E>(); }\n\n  bool IsOk() const { return this->template Is<T>(); }\n};\n\ntemplate<typename T, typename E>\nstruct MaybeStorage<T, E, std::enable_if_t<std::is_reference<T>::value>>\n    : Variant<std::remove_reference_t<T>*, E> {\n  static_assert(std::is_lvalue_reference<T>::value, \"rvalue reference is not allowed here\");\n\n  using PointedType = std::remove_reference_t<T>;\n  using UnderlyingType = PointedType*;\n  using Base = Variant<UnderlyingType, E>;\n\n  MaybeStorage(T v) : Base(&v) {}  // NOLINT(google-explicit-constructor)\n\n  MaybeStorage(const E& err) : Base(err) {}        // NOLINT(google-explicit-constructor)\n  MaybeStorage(E&& err) : Base(std::move(err)) {}  // NOLINT(google-explicit-constructor)\n\n  template<typename... Args>\n  explicit MaybeStorage(InPlaceErrorType, Args&&... args)\n      : Base(InPlaceType<E>, std::forward<Args>(args)...) 
{}\n\n  PointedType& Value() { return *this->Base::template Value<UnderlyingType>(); }\n\n  const PointedType& Value() const { return *this->Base::template Value<UnderlyingType>(); }\n\n  decltype(auto) Error() & { return this->Base::template Value<E>(); }\n  decltype(auto) Error() const& { return this->Base::template Value<E>(); }\n  decltype(auto) Error() && { return std::move(*this).Base::template Value<E>(); }\n\n  bool IsOk() const { return this->template Is<UnderlyingType>(); }\n};\n\ntemplate<typename E>\nstruct MaybeStorage<void, E> : Optional<E> {\n  using Base = Optional<E>;\n\n  MaybeStorage(InPlaceOkType) : Base(NullOpt) {}  // NOLINT(google-explicit-constructor)\n\n  MaybeStorage(const E& err) : Base(err) {}        // NOLINT(google-explicit-constructor)\n  MaybeStorage(E&& err) : Base(std::move(err)) {}  // NOLINT(google-explicit-constructor)\n\n  template<typename... Args>\n  explicit MaybeStorage(InPlaceErrorType, Args&&... args)\n      : Base(InPlace, std::forward<Args>(args)...) {}\n\n  void Value() const {}\n\n  decltype(auto) Error() & { return this->Base::Value(); }\n  decltype(auto) Error() const& { return this->Base::Value(); }\n  decltype(auto) Error() && { return std::move(*this).Base::Value(); }\n\n  bool IsOk() const { return !this->HasValue(); }\n};\n\nstruct MaybePrivateScope {\n  template<typename T>\n  static decltype(auto) Value(T&& m) {\n    return std::forward<T>(m).Value();\n  }\n\n  template<typename T>\n  static decltype(auto) StackedError(T&& m) {\n    return std::forward<T>(m).StackedError();\n  }\n\n  template<typename T, typename F>\n  static auto Map(T&& maybe, F&& f)\n      -> Maybe<decltype(std::forward<F>(f)(std::forward<T>(maybe).Value())),\n               typename RemoveCVRef<T>::StackedErrorType> {\n    if (maybe) { return std::forward<F>(f)(std::forward<T>(maybe).Value()); }\n\n    return std::forward<T>(maybe).StackedError();\n  }\n\n  template<typename T, typename F,\n           typename U = std::decay_t<decltype(std::declval<F>()(std::declval<T>().Value()))>>\n  static auto Bind(T&& maybe, F&& f) -> std::enable_if_t<IsMaybe<U>::value, U> {\n    if (maybe) { return std::forward<F>(f)(std::forward<T>(maybe).Value()); }\n\n    return std::forward<T>(maybe).StackedError();\n  }\n};\n\n}  // namespace details\n\n// A type which can be either a value typed T, or a stacked error typed E\ntemplate<typename T, typename E>\nstruct OF_MAYBE_NODISCARD_TYPE Maybe : private details::MaybeStorage<T, E> {\n  static_assert(!std::is_reference<E>::value, \"error type cannot be reference\");\n  static_assert(!(std::is_const<E>::value || std::is_volatile<E>::value),\n                \"error type cannot be cv-qualified\");\n\n  // E must be a stacked error, which implies StackedErrorTraits<E> must exist\n  using ErrorTraits = StackedErrorTraits<E>;\n  using StackedErrorType = E;\n  using ValueType = T;\n  using ErrorType = typename ErrorTraits::ErrorType;\n\n private:\n  using Base = details::MaybeStorage<T, E>;\n\n  friend struct details::MaybePrivateScope;\n  friend struct details::JustPrivateScope;\n\n protected:\n  decltype(auto) Value() & { return Base::Value(); }\n  decltype(auto) Value() const& { return Base::Value(); }\n  decltype(auto) Value() && { return std::move(*this).Base::Value(); }\n\n  decltype(auto) StackedError() & { return Base::Error(); }\n  decltype(auto) StackedError() const& { return Base::Error(); }\n  decltype(auto) StackedError() && { return std::move(*this).Base::Error(); }\n\n  decltype(auto) Error() & { return 
ErrorTraits::Error(StackedError()); }\n  decltype(auto) Error() const& { return ErrorTraits::Error(StackedError()); }\n  decltype(auto) Error() && { return ErrorTraits::Error(std::move(*this).StackedError()); }\n\n public:\n  using Base::Base;\n\n  OF_MAYBE_NODISCARD_FUNC bool IsOk() const { return Base::IsOk(); }\n  OF_MAYBE_NODISCARD_FUNC bool IsErr() const { return !Base::IsOk(); }\n  explicit operator bool() const { return IsOk(); }\n\n  OF_MAYBE_NODISCARD_FUNC decltype(auto) GetStackedError() & {\n    OF_MAYBE_ASSERT(IsErr());\n    return StackedError();\n  }\n\n  OF_MAYBE_NODISCARD_FUNC decltype(auto) GetStackedError() const& {\n    OF_MAYBE_ASSERT(IsErr());\n    return StackedError();\n  }\n\n  OF_MAYBE_NODISCARD_FUNC decltype(auto) GetStackedError() && {\n    OF_MAYBE_ASSERT(IsErr());\n    return std::move(*this).StackedError();\n  }\n\n  OF_MAYBE_NODISCARD_FUNC decltype(auto) GetError() & {\n    OF_MAYBE_ASSERT(IsErr());\n    return Error();\n  }\n\n  OF_MAYBE_NODISCARD_FUNC decltype(auto) GetError() const& {\n    OF_MAYBE_ASSERT(IsErr());\n    return Error();\n  }\n\n  OF_MAYBE_NODISCARD_FUNC decltype(auto) GetError() && {\n    OF_MAYBE_ASSERT(IsErr());\n    return std::move(*this).Error();\n  }\n\n  template<typename F>\n  OF_MAYBE_NODISCARD_FUNC auto Map(F&& f) const& {\n    return details::MaybePrivateScope::Map(*this, std::forward<F>(f));\n  }\n\n  template<typename F>\n  OF_MAYBE_NODISCARD_FUNC auto Map(F&& f) && {\n    return details::MaybePrivateScope::Map(std::move(*this), std::forward<F>(f));\n  }\n\n  template<typename F>\n  OF_MAYBE_NODISCARD_FUNC auto Bind(F&& f) const& {\n    return details::MaybePrivateScope::Bind(*this, std::forward<F>(f));\n  }\n\n  template<typename F>\n  OF_MAYBE_NODISCARD_FUNC auto Bind(F&& f) && {\n    return details::MaybePrivateScope::Bind(std::move(*this), std::forward<F>(f));\n  }\n};\n\n}  // namespace maybe\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_MAYBE_MAYBE_H_\n"
  },
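Since this header only declares the container, a small orientation sketch may help: construction from a value or an error, the IsOk/IsErr queries, and the Map/Bind combinators that route through MaybePrivateScope. ParseEven and Demo are hypothetical names; only members declared above are used:

```cpp
#include <string>
#include "oneflow/maybe/maybe.h"

using namespace oneflow::maybe;
using Error = simple::NoStackError<std::string>;

Maybe<int, Error> ParseEven(int raw) {
  if (raw % 2 != 0) { return Error("odd input"); }  // error state
  return raw;                                       // value state
}

void Demo() {
  auto halved = ParseEven(42).Map([](int v) { return v / 2; });  // holds 21
  auto failed = ParseEven(7).Bind([](int) -> Maybe<int, Error> {
    return Error("unreachable");  // skipped: the "odd input" error flows through Bind
  });
  if (failed.IsErr()) { (void)failed.GetError(); }  // yields "odd input"
}
```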
  {
    "path": "oneflow/maybe/maybe_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <gtest/gtest.h>\n#include \"oneflow/maybe/error.h\"\n#include \"oneflow/maybe/maybe.h\"\n\nusing namespace oneflow::maybe;\n\nTEST(Maybe, Basic) {\n  using Error = simple::StackedError<int>;\n  Maybe<int, Error> a{1}, b{a}, c{Error(2)}, d{c};\n\n  ASSERT_TRUE(a);\n  ASSERT_TRUE(b);\n  ASSERT_FALSE(c);\n  ASSERT_FALSE(d);\n\n  ASSERT_EQ(details::MaybePrivateScope::Value(a), 1);\n  ASSERT_EQ(details::MaybePrivateScope::Value(b), 1);\n\n  a = 2;\n  ASSERT_EQ(details::MaybePrivateScope::Value(a), 2);\n  ASSERT_EQ(details::MaybePrivateScope::Value(b), 1);\n\n  ASSERT_EQ(details::MaybePrivateScope::StackedError(c).Error(), 2);\n  ASSERT_EQ(details::MaybePrivateScope::StackedError(d).Error(), 2);\n\n  a = c;\n  ASSERT_EQ(details::MaybePrivateScope::StackedError(a).Error(), 2);\n}\n\nTEST(Maybe, NonPOD) {\n  using Error = simple::StackedError<std::string>;\n  Maybe<std::shared_ptr<int>, Error> a{Ok, new int{1}}, b{a}, c{Error(\"test\")}, d{c};\n\n  ASSERT_TRUE(a);\n  ASSERT_TRUE(b);\n  ASSERT_FALSE(c);\n  ASSERT_FALSE(d);\n\n  ASSERT_EQ(details::MaybePrivateScope::Value(a).use_count(), 2);\n\n  {\n    Maybe<std::shared_ptr<int>, Error> x(a);\n\n    ASSERT_EQ(details::MaybePrivateScope::Value(x).use_count(), 3);\n\n    x = c;\n    ASSERT_FALSE(x);\n\n    x = a;\n    ASSERT_EQ(details::MaybePrivateScope::Value(x).use_count(), 3);\n  }\n\n  ASSERT_EQ(details::MaybePrivateScope::Value(a).use_count(), 2);\n\n  ASSERT_EQ(*details::MaybePrivateScope::Value(a), 1);\n  *details::MaybePrivateScope::Value(a) = 2;\n  ASSERT_EQ(*details::MaybePrivateScope::Value(a), 2);\n\n  ASSERT_EQ(details::MaybePrivateScope::StackedError(c).Error(), \"test\");\n  ASSERT_EQ(details::MaybePrivateScope::StackedError(c).StackSize(), 0);\n}\n\nTEST(Maybe, Reference) {\n  using Error = simple::StackedError<std::string>;\n\n  const int& n = 1;\n  Maybe<const int&, Error> a{n}, b{a}, c{Error(\"test\")}, d{c};\n\n  ASSERT_TRUE(a);\n  ASSERT_TRUE(b);\n  ASSERT_FALSE(c);\n  ASSERT_FALSE(d);\n\n  ASSERT_EQ(details::MaybePrivateScope::Value(a), 1);\n\n  int k = 2;\n\n  a = k;\n  ASSERT_EQ(details::MaybePrivateScope::Value(a), 2);\n\n  k = 3;\n  ASSERT_EQ(details::MaybePrivateScope::Value(a), 3);\n\n  int x = 1;\n  Maybe<int&, Error> e{x}, f{e}, g{Error(\"test\")}, h{g};\n\n  ASSERT_TRUE(a);\n  ASSERT_TRUE(b);\n  ASSERT_FALSE(c);\n  ASSERT_FALSE(d);\n\n  ASSERT_EQ(details::MaybePrivateScope::Value(e), 1);\n\n  e = k;\n  ASSERT_EQ(details::MaybePrivateScope::Value(e), 3);\n\n  details::MaybePrivateScope::Value(e) = 4;\n  ASSERT_EQ(k, 4);\n}\n\nTEST(Maybe, Void) {\n  using Error = simple::StackedError<std::string>;\n  Maybe<void, Error> a{Ok}, b{a}, c{Error(\"test\")}, d{c};\n\n  ASSERT_TRUE(a);\n  ASSERT_TRUE(b);\n  ASSERT_FALSE(c);\n  ASSERT_FALSE(d);\n\n  ASSERT_EQ(details::MaybePrivateScope::StackedError(c).Error(), \"test\");\n\n  c = Error(\"hello\");\n  
ASSERT_EQ(details::MaybePrivateScope::StackedError(c).Error(), \"hello\");\n\n  a = c;\n  ASSERT_EQ(details::MaybePrivateScope::StackedError(a).Error(), \"hello\");\n}\n\nTEST(Maybe, PtrError) {\n  using PointedError = simple::StackedError<std::string>;\n  using Error = std::unique_ptr<PointedError>;\n  Maybe<int, Error> a{1}, c{InPlaceError, new PointedError(\"test\")};\n\n  ASSERT_TRUE(a);\n  ASSERT_FALSE(c);\n\n  ASSERT_EQ(details::MaybePrivateScope::StackedError(c)->Error(), \"test\");\n}\n\nTEST(Maybe, NoStack) {\n  using Error = simple::NoStackError<std::string>;\n  Maybe<int, Error> a{1}, b{a}, c{InPlaceError, \"hello\"}, d{c};\n\n  ASSERT_TRUE(a);\n  ASSERT_TRUE(b);\n  ASSERT_FALSE(c);\n  ASSERT_FALSE(d);\n\n  a = c;\n  ASSERT_FALSE(a);\n}\n\nTEST(Maybe, Monadic) {\n  using Error = simple::NoStackError<std::string>;\n  Maybe<int, Error> a{1}, b{InPlaceError, \"hello\"};\n\n  auto x2 = [](int x) { return x * 2; };\n\n  auto x2e2 = [](int x) -> Maybe<int, Error> {\n    if (x == 4) return Error(\"test\");\n    return x * 2;\n  };\n\n  ASSERT_EQ(CHECK_JUST(a.Map(x2).Map(x2)), 4);\n  ASSERT_FALSE(b.Map(x2).Map(x2));\n\n  a = 1;\n  ASSERT_EQ(CHECK_JUST(a.Bind(x2e2).Bind(x2e2)), 4);\n\n  a = 2;\n  ASSERT_EQ(CHECK_JUST(a.Bind(x2e2)), 4);\n  ASSERT_EQ(a.Bind(x2e2).Bind(x2e2).GetError(), \"test\");\n\n  a = 4;\n  ASSERT_EQ(a.Bind(x2e2).GetError(), \"test\");\n  ASSERT_EQ(a.Bind(x2e2).GetError(), \"test\");\n}\n"
  },
  {
    "path": "oneflow/maybe/optional.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_MAYBE_OPTIONAL_H_\n#define ONEFLOW_MAYBE_OPTIONAL_H_\n\n#include <type_traits>\n#include <utility>\n\n#include \"oneflow/maybe/just.h\"\n#include \"oneflow/maybe/utility.h\"\n#include \"oneflow/maybe/type_traits.h\"\n\nnamespace oneflow {\n\nnamespace maybe {\n\ntemplate<typename T>\nstruct Optional;\n\nnamespace details {\n\n// OptionalStorage is specialized for 2 cases:\n// 1. for scalar types, we optimize all construction, destruction and value check\n// 2. for reference types, we store a pointer to the referenced value\ntemplate<typename T, typename = void>\nstruct OptionalStorage {\n private:\n  bool has_;\n  alignas(T) unsigned char value_[sizeof(T)];\n\n  using Type = std::remove_const_t<T>;\n\n public:\n  OptionalStorage() = default;\n  ~OptionalStorage() = default;\n\n  OptionalStorage(const OptionalStorage&) = delete;\n  OptionalStorage& operator=(const OptionalStorage&) = delete;\n\n  void Init() { has_ = false; }\n\n  T& Value() & { return *reinterpret_cast<T*>(value_); }\n\n  Type&& Value() && { return std::move(*const_cast<Type*>(reinterpret_cast<T*>(value_))); }\n\n  const T& Value() const& { return *reinterpret_cast<const T*>(value_); }\n\n  bool HasValue() const { return has_; }\n\n  void Reset() {\n    if (has_) {\n      has_ = false;\n      Value().~T();\n    }\n  }\n\n  void Destory() {\n    if (has_) { Value().~T(); }\n  }\n\n  template<typename... Args, typename U = Type, std::enable_if_t<IsAggregate<U>, int> = 0>\n  void Construct(Args&&... args) {\n    new (value_) Type{std::forward<Args>(args)...};\n    has_ = true;\n  }\n\n  template<typename... Args, typename U = Type, std::enable_if_t<!IsAggregate<U>, int> = 0>\n  void Construct(Args&&... args) {\n    new (value_) Type(std::forward<Args>(args)...);\n    has_ = true;\n  }\n\n  template<typename... Args, typename U = T, std::enable_if_t<!std::is_const<U>::value, int> = 0>\n  T& Emplace(Args&&... args) {\n    if (!has_) {\n      Construct(std::forward<Args>(args)...);\n      return Value();\n    } else {\n      return Value() = Type(std::forward<Args>(args)...);\n    }\n  }\n\n  template<typename... Args, typename U = T, std::enable_if_t<std::is_const<U>::value, int> = 0>\n  T& Emplace(Args&&... 
args) {\n    Destroy();\n    Construct(std::forward<Args>(args)...);\n    return Value();\n  }\n\n  template<typename OS>\n  void CopyConstruct(OS&& s) {\n    has_ = s.has_;\n\n    if (has_) { new (value_) Type(std::forward<OS>(s).Value()); }\n  }\n\n  template<typename OS>\n  void Copy(OS&& s) {\n    if (s.has_) {\n      Emplace(std::forward<OS>(s).Value());\n    } else {\n      Reset();\n    }\n  }\n};\n\ntemplate<typename T>  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)\nstruct OptionalStorage<T, std::enable_if_t<std::is_scalar<T>::value>> {\n private:\n  using Type = std::remove_const_t<T>;\n\n  bool has_;\n  Type value_;\n\n public:\n  OptionalStorage() = default;\n  ~OptionalStorage() = default;\n\n  OptionalStorage(const OptionalStorage&) = delete;\n  OptionalStorage& operator=(const OptionalStorage&) = delete;\n\n  void Init() {\n    has_ = false;\n    value_ = {};\n  }\n\n  T& Value() & { return value_; }\n\n  Type&& Value() && { return std::move(const_cast<Type&>(value_)); }\n\n  const T& Value() const& { return value_; }\n\n  bool HasValue() const { return has_; }\n\n  void Reset() { has_ = false; }\n\n  void Destroy() {}\n\n  template<typename U>\n  void Construct(const U& v) {\n    value_ = v;\n    has_ = true;\n  }\n\n  template<typename U>\n  T& Emplace(const U& v) {\n    Construct(v);\n    return Value();\n  }\n\n  void CopyConstruct(const OptionalStorage& s) {\n    has_ = s.has_;\n    value_ = s.value_;\n  }\n\n  void Copy(const OptionalStorage& s) { CopyConstruct(s); }\n};\n\ntemplate<typename T>\nstruct OptionalStorage<T, std::enable_if_t<std::is_reference<T>::value>> {\n  static_assert(std::is_lvalue_reference<T>::value, \"rvalue reference is not allowed here\");\n\n  using Type = std::remove_reference_t<T>;\n\n private:\n  Type* value_;\n\n public:\n  OptionalStorage() = default;\n  ~OptionalStorage() = default;\n\n  OptionalStorage(const OptionalStorage&) = delete;\n  OptionalStorage& operator=(const OptionalStorage&) = delete;\n\n  void Init() { value_ = nullptr; }\n\n  T Value() { return *value_; }\n\n  const Type& Value() const { return *value_; }\n\n  bool HasValue() const { return value_ != nullptr; }\n\n  void Reset() { value_ = nullptr; }\n\n  void Destroy() {}\n\n  void Construct(T v) { value_ = &v; }\n\n  T Emplace(T v) {\n    Construct(v);\n    return Value();\n  }\n\n  void CopyConstruct(const OptionalStorage& s) { value_ = s.value_; }\n\n  void Copy(const OptionalStorage& s) { CopyConstruct(s); }\n};\n\nstruct OptionalPrivateScope {\n  template<typename T>\n  static decltype(auto) Value(T&& opt) {\n    return std::forward<T>(opt).Value();\n  }\n\n  template<typename T, typename F>\n  static auto Map(T&& opt, F&& f)\n      -> Optional<decltype(std::forward<F>(f)(std::forward<T>(opt).Value()))> {\n    if (opt.HasValue()) { return std::forward<F>(f)(std::forward<T>(opt).Value()); }\n\n    return NullOpt;\n  }\n\n  template<typename T, typename F,\n           typename U = std::decay_t<decltype(std::declval<F>()(std::declval<T>().Value()))>>\n  static auto Bind(T&& opt, F&& f) -> std::enable_if_t<IsOptional<U>::value, U> {\n    if (opt.HasValue()) { return std::forward<F>(f)(std::forward<T>(opt).Value()); }\n\n    return NullOpt;\n  }\n\n  template<typename T, typename F,\n           std::enable_if_t<std::is_same<decltype(std::declval<F>()()), void>::value, int> = 0>\n  static auto OrElse(T&& opt, F&& f) -> std::decay_t<T> {\n    if (!opt.HasValue()) {\n      std::forward<F>(f)();\n      return NullOpt;\n    }\n\n    return std::forward<T>(opt);\n  
}\n\n  template<typename T, typename F,\n           std::enable_if_t<\n               std::is_convertible<decltype(std::declval<F>()()), std::decay_t<T>>::value, int> = 0>\n  static auto OrElse(T&& opt, F&& f) -> std::decay_t<T> {\n    if (!opt.HasValue()) { return std::forward<F>(f)(); }\n\n    return std::forward<T>(opt);\n  }\n};\n\n}  // namespace details\n\n// unlike Variant, type arguments can be cv-qualified or lvalue references\n// this Optional DOES NOT guarantee exception safety\ntemplate<typename T>\nstruct OF_MAYBE_NODISCARD_TYPE Optional {\n protected:\n  details::OptionalStorage<T> storage_;\n\n  using Type = std::remove_const_t<T>;\n\n  decltype(auto) Value() & { return storage_.Value(); }\n\n  decltype(auto) Value() && { return std::move(storage_).Value(); }\n\n  decltype(auto) Value() const& { return storage_.Value(); }\n\n  // we DO NOT export the Value methods; they are left accessible only to the JUST macro\n  friend struct details::OptionalPrivateScope;\n  friend struct details::JustPrivateScope;\n\n public:\n  static_assert(!std::is_same<std::remove_reference_t<Type>, NullOptType>::value,\n                \"NullOptType is not allowed in Optional\");\n\n  using ValueType = T;\n\n  explicit Optional() { storage_.Init(); }\n\n  Optional(NullOptType) { storage_.Init(); }  // NOLINT(google-explicit-constructor)\n\n  Optional(const T& v) { storage_.Construct(v); }  // NOLINT(google-explicit-constructor)\n\n  template<typename U = T, std::enable_if_t<!std::is_reference<U>::value, int> = 0>\n  Optional(Type&& v) {  // NOLINT(google-explicit-constructor)\n    storage_.Construct(std::move(v));\n  }\n\n  Optional(const Optional& opt) { storage_.CopyConstruct(opt.storage_); }\n  Optional(Optional&& opt) noexcept { storage_.CopyConstruct(std::move(opt.storage_)); }\n\n  template<typename... Args>\n  explicit Optional(InPlaceT, Args&&... args) {\n    storage_.Construct(std::forward<Args>(args)...);\n  }\n\n  ~Optional() { storage_.Destroy(); }\n\n  Optional& operator=(NullOptType) {\n    storage_.Reset();\n    return *this;\n  }\n\n  Optional& operator=(const T& v) {\n    storage_.Emplace(v);\n    return *this;\n  }\n\n  template<typename U = T, std::enable_if_t<!std::is_reference<U>::value, int> = 0>\n  Optional& operator=(Type&& v) {\n    storage_.Emplace(std::move(v));\n    return *this;\n  }\n\n  template<typename... Args>\n  decltype(auto) Emplace(Args&&... 
args) {\n    return storage_.Emplace(std::forward<Args>(args)...);\n  }\n\n  Optional& operator=(const Optional& opt) {\n    storage_.Copy(opt.storage_);\n    return *this;\n  }\n\n  Optional& operator=(Optional&& opt) noexcept {\n    storage_.Copy(std::move(opt.storage_));\n    return *this;\n  }\n\n  OF_MAYBE_NODISCARD_FUNC bool HasValue() const { return storage_.HasValue(); }\n  explicit operator bool() const { return HasValue(); }\n\n  bool operator==(const Optional& opt) const {\n    if (HasValue()) {\n      if (opt.HasValue()) {\n        return Value() == opt.Value();\n      } else {\n        return false;\n      }\n    } else {\n      return !opt.HasValue();\n    }\n  }\n\n  bool operator!=(const Optional& opt) const { return !operator==(opt); }\n\n  bool operator<(const Optional& opt) const {\n    if (HasValue()) {\n      if (opt.HasValue()) {\n        return Value() < opt.Value();\n      } else {\n        return false;\n      }\n    } else {\n      return opt.HasValue();\n    }\n  }\n\n  bool operator>=(const Optional& opt) const { return !operator<(opt); }\n\n  bool operator>(const Optional& opt) const {\n    if (HasValue()) {\n      if (opt.HasValue()) {\n        return Value() > opt.Value();\n      } else {\n        return true;\n      }\n    } else {\n      return false;\n    }\n  }\n\n  bool operator<=(const Optional& opt) const { return !operator>(opt); }\n\n  friend bool operator==(const Optional& opt, NullOptType) { return !opt.HasValue(); }\n  friend bool operator!=(const Optional& opt, NullOptType) { return opt.HasValue(); }\n  friend bool operator==(NullOptType, const Optional& opt) { return !opt.HasValue(); }\n  friend bool operator!=(NullOptType, const Optional& opt) { return opt.HasValue(); }\n\n  friend bool operator<(const Optional& opt, NullOptType) { return false; }\n  friend bool operator>(const Optional& opt, NullOptType) { return opt.HasValue(); }\n  friend bool operator<=(const Optional& opt, NullOptType) { return !opt.HasValue(); }\n  friend bool operator>=(const Optional& opt, NullOptType) { return true; }\n\n  friend bool operator<(NullOptType, const Optional& opt) { return opt > NullOpt; }\n  friend bool operator>(NullOptType, const Optional& opt) { return opt < NullOpt; }\n  friend bool operator<=(NullOptType, const Optional& opt) { return opt >= NullOpt; }\n  friend bool operator>=(NullOptType, const Optional& opt) { return opt <= NullOpt; }\n\n  friend bool operator==(const Optional& opt, const T& v) {\n    if (opt.HasValue()) {\n      return opt.Value() == v;\n    } else {\n      return false;\n    }\n  }\n\n  friend bool operator!=(const Optional& opt, const T& v) { return !(opt == v); }\n\n  friend bool operator==(const T& v, const Optional& opt) { return opt == v; }\n\n  friend bool operator!=(const T& v, const Optional& opt) { return !(opt == v); }\n\n  friend bool operator<(const Optional& opt, const T& v) {\n    if (opt.HasValue()) {\n      return opt.Value() < v;\n    } else {\n      return true;\n    }\n  }\n\n  friend bool operator>=(const Optional& opt, const T& v) { return !(opt < v); }\n\n  friend bool operator>(const T& v, const Optional& opt) { return opt < v; }\n\n  friend bool operator<=(const T& v, const Optional& opt) { return !(opt < v); }\n\n  friend bool operator>(const Optional& opt, const T& v) {\n    if (opt.HasValue()) {\n      return opt.Value() > v;\n    } else {\n      return false;\n    }\n  }\n\n  friend bool operator<=(const Optional& opt, const T& v) { return !(opt > v); }\n\n  friend bool operator<(const T& v, const 
Optional& opt) { return opt > v; }\n\n  friend bool operator>=(const T& v, const Optional& opt) { return !(opt > v); }\n\n  decltype(auto) ValueOr(const T& v) const& {\n    if (HasValue()) {\n      return Value();\n    } else {\n      return v;\n    }\n  }\n\n  template<typename U = T, std::enable_if_t<!std::is_reference<U>::value, int> = 0>\n  auto ValueOr(T&& v) const& {\n    if (HasValue()) {\n      return Value();\n    } else {\n      return std::move(v);\n    }\n  }\n\n  template<typename U = T, std::enable_if_t<!std::is_reference<U>::value, int> = 0>\n  auto ValueOr(const T& v) && {\n    if (HasValue()) {\n      return std::move(*this).Value();\n    } else {\n      return v;\n    }\n  }\n\n  template<typename U = T, std::enable_if_t<!std::is_reference<U>::value, int> = 0>\n  decltype(auto) ValueOr(T&& v) && {\n    if (HasValue()) {\n      return std::move(*this).Value();\n    } else {\n      return std::move(v);\n    }\n  }\n\n  void Reset() { storage_.Reset(); }\n\n  template<typename F>\n  OF_MAYBE_NODISCARD_FUNC auto Map(F&& f) const& {\n    return details::OptionalPrivateScope::Map(*this, std::forward<F>(f));\n  }\n\n  template<typename F>\n  OF_MAYBE_NODISCARD_FUNC auto Map(F&& f) && {\n    return details::OptionalPrivateScope::Map(std::move(*this), std::forward<F>(f));\n  }\n\n  template<typename F>\n  OF_MAYBE_NODISCARD_FUNC auto Bind(F&& f) const& {\n    return details::OptionalPrivateScope::Bind(*this, std::forward<F>(f));\n  }\n\n  template<typename F>\n  OF_MAYBE_NODISCARD_FUNC auto Bind(F&& f) && {\n    return details::OptionalPrivateScope::Bind(std::move(*this), std::forward<F>(f));\n  }\n\n  template<typename F>\n  OF_MAYBE_NODISCARD_FUNC auto OrElse(F&& f) const& {\n    return details::OptionalPrivateScope::OrElse(*this, std::forward<F>(f));\n  }\n\n  template<typename F>\n  OF_MAYBE_NODISCARD_FUNC auto OrElse(F&& f) && {\n    return details::OptionalPrivateScope::OrElse(std::move(*this), std::forward<F>(f));\n  }\n};\n\n}  // namespace maybe\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<typename T>\nstruct hash<oneflow::maybe::Optional<T>> {\n  size_t operator()(const oneflow::maybe::Optional<T>& v) const noexcept {\n    if (v.HasValue()) {\n      return hashImpl(oneflow::maybe::details::OptionalPrivateScope::Value(v));\n    } else {\n      return oneflow::maybe::NullOptHash;\n    }\n  }\n\n  template<typename U = T, std::enable_if_t<!std::is_reference<U>::value, int> = 0>\n  static std::size_t hashImpl(const T& v) {\n    return std::hash<std::remove_cv_t<T>>()(v);\n  }\n\n  template<typename U = T, std::enable_if_t<std::is_reference<U>::value, int> = 0>\n  static std::size_t hashImpl(const std::remove_reference_t<T>& v) {\n    return std::hash<const std::remove_reference_t<T>*>()(&v);\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_MAYBE_OPTIONAL_H_\n"
  },
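As a quick usage orientation for the header above (FindUserId and Demo are hypothetical; only the public members declared above are used). Value() is deliberately non-public, so consuming code goes through ValueOr, the combinators, or the JUST macro:

```cpp
#include <string>
#include "oneflow/maybe/optional.h"

using namespace oneflow::maybe;

Optional<int> FindUserId(const std::string& name) {
  if (name.empty()) { return NullOpt; }
  return static_cast<int>(name.size());  // placeholder lookup
}

void Demo() {
  // Map transforms the contained value; an empty Optional stays empty.
  Optional<int> doubled = FindUserId("alice").Map([](int id) { return id * 2; });

  // ValueOr unwraps with a fallback; OrElse substitutes a replacement
  // Optional (or runs a side effect, matching the two overloads above).
  int id = doubled.ValueOr(-1);
  Optional<int> recovered = FindUserId("").OrElse([] { return Optional<int>(0); });
  (void)id;
  (void)recovered;
}
```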
  {
    "path": "oneflow/maybe/optional_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <gtest/gtest.h>\n#include <memory>\n#include \"oneflow/maybe/optional.h\"\n\nusing namespace oneflow::maybe;\n\nusing Private = details::OptionalPrivateScope;\n\nTEST(Optional, Scalar) {\n  Optional<int> a, b(1), c(a), d(b), e(NullOpt), bb(InPlace, 1);\n\n  static_assert(std::is_same<decltype(Private::Value(a)), int&>::value, \"\");\n\n  ASSERT_TRUE(!a.HasValue());\n  ASSERT_TRUE(b.HasValue());\n  ASSERT_EQ(b.ValueOr(0), 1);\n  ASSERT_TRUE(!c.HasValue());\n  ASSERT_EQ(c.ValueOr(233), 233);\n  ASSERT_TRUE(d.HasValue());\n  ASSERT_EQ(d.ValueOr(0), 1);\n  ASSERT_TRUE(!e.HasValue());\n\n  a = b;\n  ASSERT_TRUE(a.HasValue());\n  ASSERT_EQ(a.ValueOr(0), 1);\n\n  a = NullOpt;\n  ASSERT_TRUE(!a.HasValue());\n\n  a = 222;\n  ASSERT_TRUE(a.HasValue());\n  ASSERT_TRUE(a);\n  ASSERT_EQ(a.ValueOr(1), 222);\n\n  Private::Value(a) = 2333;\n  ASSERT_EQ(a.ValueOr(1), 2333);\n\n  Optional<const int> f, g(1);\n  ASSERT_TRUE(!f.HasValue());\n  ASSERT_TRUE(g.HasValue());\n  ASSERT_EQ(g.ValueOr(2), 1);\n\n  static_assert(std::is_same<decltype(Private::Value(f)), const int&>::value, \"\");\n\n  f = 1;\n  ASSERT_TRUE(f.HasValue());\n  ASSERT_EQ(f.ValueOr(2), 1);\n  ASSERT_EQ(Private::Value(f), 1);\n\n  int x = 2;\n  ASSERT_EQ(f.ValueOr(x), 1);\n\n  ASSERT_EQ(f.Emplace(2), 2);\n  ASSERT_EQ(Private::Value(f), 2);\n\n  f.Reset();\n  ASSERT_TRUE(!f);\n}\n\nTEST(Optional, NonScalar) {\n  auto x = std::make_shared<int>(233);\n  ASSERT_EQ(x.use_count(), 1);\n\n  Optional<std::shared_ptr<int>> a, b(x), aa(a), aaa(InPlace, std::make_shared<int>(244));\n  ASSERT_EQ(x.use_count(), 2);\n  ASSERT_EQ(*Private::Value(b), 233);\n  static_assert(std::is_same<decltype(Private::Value(b)), std::shared_ptr<int>&>::value, \"\");\n\n  ASSERT_TRUE(!a.HasValue());\n  ASSERT_TRUE(!aa.HasValue());\n\n  Optional<std::shared_ptr<int>> c(a), d(b);\n  ASSERT_TRUE(!c.HasValue());\n\n  ASSERT_EQ(x.use_count(), 3);\n  ASSERT_EQ(b, d);\n\n  a = x;\n  ASSERT_EQ(x.use_count(), 4);\n\n  a = NullOpt;\n  ASSERT_EQ(x.use_count(), 3);\n\n  a = b;\n  ASSERT_EQ(x.use_count(), 4);\n  ASSERT_EQ(a, b);\n\n  {\n    Optional<std::shared_ptr<int>> e(a);  // NOLINT\n    ASSERT_EQ(x.use_count(), 5);\n\n    Optional<std::shared_ptr<int>> f;\n    f = e;\n    ASSERT_EQ(x.use_count(), 6);\n  }\n\n  ASSERT_EQ(x.use_count(), 4);\n  *Private::Value(a) = 234;\n  ASSERT_EQ(*x, 234);\n\n  Optional<std::shared_ptr<int>> g(std::move(a));\n  ASSERT_EQ(x.use_count(), 4);\n\n  {\n    Optional<std::shared_ptr<int>> h;\n    ASSERT_TRUE(!h.HasValue());\n\n    h = std::move(b);\n    ASSERT_EQ(x.use_count(), 4);\n  }\n\n  ASSERT_EQ(x.use_count(), 3);\n\n  Optional<const std::shared_ptr<int>> i(x);\n  ASSERT_EQ(x.use_count(), 4);\n  static_assert(std::is_same<decltype(Private::Value(i)), const std::shared_ptr<int>&>::value, \"\");\n\n  i = NullOpt;\n  ASSERT_EQ(x.use_count(), 3);\n\n  i.Emplace(x);\n  ASSERT_EQ(x.use_count(), 4);\n\n  i.Reset();\n  
ASSERT_EQ(x.use_count(), 3);\n\n  i.Emplace(std::move(x));\n  ASSERT_EQ(Private::Value(i).use_count(), 3);\n\n  struct A {\n    int id;\n    std::string name;\n  };\n\n  Optional<A> a1, a2{InPlace, 233, \"oneflow\"};\n\n  ASSERT_FALSE(a1);\n  ASSERT_TRUE(a2);\n\n  ASSERT_EQ(a1, NullOpt);\n  ASSERT_EQ(Private::Value(a2).id, 233);\n  ASSERT_EQ(Private::Value(a2).name, \"oneflow\");\n}\n\nTEST(Optional, Reference) {\n  int x = 233;\n\n  Optional<int&> a, b(x), c(a), d(b);\n\n  ASSERT_TRUE(!a);\n  ASSERT_TRUE(b);\n  ASSERT_TRUE(!c);\n  ASSERT_TRUE(d);\n\n  ASSERT_EQ(Private::Value(b), 233);\n  ASSERT_EQ(Private::Value(d), 233);\n\n  static_assert(std::is_same<decltype(Private::Value(b)), int&>::value, \"\");\n\n  a = x;\n  ASSERT_TRUE(a);\n  ASSERT_EQ(Private::Value(a), 233);\n\n  a = NullOpt;\n  ASSERT_TRUE(!a);\n\n  a = b;\n  ASSERT_TRUE(a);\n  ASSERT_EQ(Private::Value(a), 233);\n\n  Private::Value(a) = 234;\n  ASSERT_EQ(x, 234);\n\n  Optional<const int&> e, f(x), g(e), h(f);\n\n  ASSERT_TRUE(!e);\n  ASSERT_TRUE(f);\n  ASSERT_TRUE(!g);\n  ASSERT_TRUE(h);\n  ASSERT_NE(NullOpt, h);\n\n  ASSERT_EQ(Private::Value(f), 234);\n  ASSERT_EQ(Private::Value(h), 234);\n\n  static_assert(std::is_same<decltype(Private::Value(h)), const int&>::value, \"\");\n\n  e = x;\n  ASSERT_TRUE(e);\n  ASSERT_EQ(e, x);\n  ASSERT_EQ(e, 234);\n  ASSERT_EQ(Private::Value(e), 234);\n\n  e = NullOpt;\n  ASSERT_TRUE(!e);\n  ASSERT_EQ(e, NullOpt);\n}\n\nTEST(Optional, Hash) {\n  Optional<int> a, b(123);\n\n  ASSERT_EQ(std::hash<decltype(a)>()(a), NullOptHash);\n  ASSERT_EQ(std::hash<decltype(a)>()(b), std::hash<int>()(123));\n\n  auto si = std::make_shared<int>(123);\n  Optional<std::shared_ptr<int>> c, d(si);\n\n  ASSERT_EQ(std::hash<decltype(c)>()(c), NullOptHash);\n  ASSERT_EQ(std::hash<decltype(c)>()(d), std::hash<decltype(si)>()(si));\n\n  int x = 233;\n  Optional<int&> e, f(x);\n\n  ASSERT_EQ(std::hash<decltype(e)>()(e), NullOptHash);\n  ASSERT_EQ(std::hash<decltype(e)>()(f), std::hash<int*>()(&x));\n\n  Optional<const int&> g;\n  ASSERT_EQ(std::hash<decltype(g)>()(g), NullOptHash);\n}\n\nTEST(Optional, Compare) {\n  Optional<int> a, b, c(-1), d(0), e(1), f(1);\n\n  ASSERT_EQ(a, b);\n  ASSERT_EQ(e, f);\n  ASSERT_NE(a, d);\n  ASSERT_LT(b, c);\n  ASSERT_LE(b, c);\n  ASSERT_LE(c, c);\n  ASSERT_LT(c, d);\n  ASSERT_LT(d, e);\n  ASSERT_GT(e, d);\n  ASSERT_GT(d, c);\n  ASSERT_GT(c, b);\n  ASSERT_GE(c, b);\n  ASSERT_GE(a, b);\n\n  int x = 0, y = 1, z = -1;\n  ASSERT_NE(a, x);\n  ASSERT_EQ(d, x);\n  ASSERT_NE(x, c);\n  ASSERT_EQ(z, c);\n  ASSERT_LT(a, x);\n  ASSERT_LT(c, x);\n  ASSERT_LT(d, y);\n  ASSERT_LT(z, f);\n  ASSERT_LE(a, x);\n  ASSERT_LE(d, x);\n  ASSERT_GT(x, a);\n  ASSERT_GT(x, c);\n  ASSERT_GT(y, d);\n  ASSERT_GT(f, z);\n  ASSERT_GE(x, a);\n  ASSERT_GE(x, d);\n\n  std::set<Optional<int>> s{2, NullOpt, -1, 3, NullOpt, 2};\n\n  ASSERT_EQ(s.size(), 4);\n\n  auto iter = s.begin();\n  ASSERT_EQ(*(iter++), NullOpt);\n  ASSERT_EQ(*(iter++), -1);\n  ASSERT_EQ(*(iter++), 2);\n  ASSERT_EQ(*(iter++), 3);\n}\n\nTEST(Optional, Monadic) {\n  Optional<int> a(1), b, c(2);\n  ASSERT_EQ(a.Map([](int x) { return x + 1; }), c);\n  ASSERT_EQ(b.Map([](int x) { return x + 1; }), b);\n  ASSERT_EQ(a.Map([](int x) { return std::string(x + 1, 'a'); }).Map([](const auto& x) {\n    return (int)x.size();\n  }),\n            c);\n  ASSERT_EQ(a.Bind([](int x) -> Optional<float> {\n               if (x < 10) {\n                 return x * 1.1;\n               } else {\n                 return NullOpt;\n               }\n             })\n            
    .Map([](float x) { return x - 1; })\n                .Map([](float x) { return std::abs(x - 0.1) < 0.001; }),\n            Optional<bool>(true));\n\n  int x = 0;\n  [[maybe_unused]] auto _ = b.OrElse([&] { x++; }).OrElse([&] { x *= 2; });\n  ASSERT_EQ(x, 2);\n  ASSERT_EQ(b.OrElse([] { return Optional<int>(3); }).Map([](int x) { return x - 1; }), c);\n}\n"
  },
  {
    "path": "oneflow/maybe/type_traits.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_MAYBE_TYPE_TRAITS_H_\n#define ONEFLOW_MAYBE_TYPE_TRAITS_H_\n\n#include <cstddef>\n#include <type_traits>\n#include <tuple>\n#include <utility>\n#include \"config.h\"\n\nnamespace oneflow {\n\nnamespace maybe {\n\n// in this file, xxxS represents struct of xxx\n// for implementant aspect, xxx is an alias of xxxS::type or xxxS::value\n\ntemplate<bool B>\nusing BoolConstant = std::integral_constant<bool, B>;\n\ntemplate<std::size_t I>\nusing IndexConstant = std::integral_constant<std::size_t, I>;\n\nconstexpr std::size_t NPos = -1;\n\ntemplate<typename...>\nstruct ConjS : std::true_type {};\ntemplate<typename B1>\nstruct ConjS<B1> : B1 {};\ntemplate<typename B1, typename... Bn>\nstruct ConjS<B1, Bn...> : std::conditional_t<bool(B1::value), ConjS<Bn...>, B1> {};\n\ntemplate<typename... B>\nconstexpr bool Conj = ConjS<B...>::value;\n\ntemplate<typename...>\nstruct DisjS : std::false_type {};\ntemplate<typename B1>\nstruct DisjS<B1> : B1 {};\ntemplate<typename B1, typename... Bn>\nstruct DisjS<B1, Bn...> : std::conditional_t<bool(B1::value), B1, DisjS<Bn...>> {};\n\ntemplate<typename... B>\nconstexpr bool Disj = DisjS<B...>::value;\n\ntemplate<typename B>\nstruct NegS : BoolConstant<!bool(B::value)> {};\n\ntemplate<typename B>\nconstexpr bool Neg = NegS<B>::value;\n\nstruct TypeNotFound;\n\n// return TypeNotFound while out of range\ntemplate<std::size_t I, typename... Tn>\nstruct TypeGetS;\n\ntemplate<std::size_t I, typename T1, typename... Tn>\nstruct TypeGetS<I, T1, Tn...> : TypeGetS<I - 1, Tn...> {};\n\ntemplate<typename T1, typename... Tn>\nstruct TypeGetS<0, T1, Tn...> {\n  using type = T1;\n};\n\ntemplate<std::size_t N>\nstruct TypeGetS<N> {\n  using type = TypeNotFound;\n};\n\ntemplate<std::size_t I, typename... Ts>\nusing TypeGet = typename TypeGetS<I, Ts...>::type;\n\n// return NPos (-1) while not found\ntemplate<std::size_t I, typename T, typename... Tn>\nstruct IndexGetFromS;\n\ntemplate<std::size_t I, typename T, typename T1, typename... Tn>\nstruct IndexGetFromS<I, T, T1, Tn...> : IndexGetFromS<I + 1, T, Tn...> {};\n\ntemplate<std::size_t I, typename T1, typename... Tn>\nstruct IndexGetFromS<I, T1, T1, Tn...> : IndexConstant<I> {};\n\ntemplate<std::size_t I, typename T>\nstruct IndexGetFromS<I, T> : IndexConstant<NPos> {};\n\ntemplate<typename T, typename... Ts>\nconstexpr auto IndexGet = IndexGetFromS<0, T, Ts...>::value;\n\ntemplate<typename T, typename... Ts>\nconstexpr auto TypeIn = IndexGet<T, Ts...> != NPos;\n\ntemplate<typename T, typename... Ts>\nusing TypeInS = BoolConstant<TypeIn<T, Ts...>>;\n\ntemplate<typename T>\nstruct RemoveCVRefS {\n  using type = std::remove_cv_t<std::remove_reference_t<T>>;\n};\n\ntemplate<typename T>\nusing RemoveCVRef = typename RemoveCVRefS<T>::type;\n\ntemplate<typename T, typename... 
Ts>\nstruct IsDifferentTypesS : BoolConstant<!TypeIn<T, Ts...> && IsDifferentTypesS<Ts...>::value> {};\n\ntemplate<typename T>\nstruct IsDifferentTypesS<T> : std::true_type {};\n\ntemplate<typename T, typename... Ts>\nconstexpr auto IsDifferentTypes = IsDifferentTypesS<T, Ts...>::value;\n\ntemplate<typename T>\nstruct ConstRefExceptVoidS {\n  using type = const T&;\n};\n\ntemplate<>\nstruct ConstRefExceptVoidS<void> {\n  using type = void;\n};\n\ntemplate<typename T>\nusing ConstRefExceptVoid = typename ConstRefExceptVoidS<T>::type;\n\ntemplate<typename T>\nusing RemoveRValRef =\n    std::conditional_t<std::is_rvalue_reference<T>::value, std::remove_reference_t<T>, T>;\n\ntemplate<typename T>\nconstexpr bool IsAggregate = OF_MAYBE_IS_AGGREGATE(T);\n\n}  // namespace maybe\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_MAYBE_TYPE_TRAITS_H_\n"
  },
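These traits exist to support Variant and Maybe; a short sketch of how they compose to constrain a template, mirroring the static_asserts Variant itself performs below (TaggedUnion is a hypothetical illustration, not library code):

```cpp
#include <type_traits>
#include "oneflow/maybe/type_traits.h"

using namespace oneflow::maybe;

template<typename... Ts>
struct TaggedUnion {
  // Reject duplicate or reference alternatives at compile time.
  static_assert(IsDifferentTypes<Ts...>, "expected all types to be different");
  static_assert(Conj<NegS<std::is_reference<Ts>>...>, "reference types are not allowed here");

  // TypeGet / IndexGet map between an index and its type in the pack.
  template<std::size_t I>
  using TypeAt = TypeGet<I, Ts...>;

  template<typename T>
  static constexpr std::size_t IndexOf = IndexGet<T, Ts...>;
};

static_assert(TaggedUnion<int, float>::IndexOf<float> == 1, "");
static_assert(std::is_same<TaggedUnion<int, float>::TypeAt<0>, int>::value, "");
```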
  {
    "path": "oneflow/maybe/type_traits_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <gtest/gtest.h>\n#include <type_traits>\n#include \"oneflow/maybe/type_traits.h\"\n\nusing namespace oneflow::maybe;\n\nTEST(TypeTraits, Basics) {\n  static_assert(Conj<std::true_type, std::true_type>, \"\");\n  static_assert(!Conj<std::false_type, std::true_type>, \"\");\n  static_assert(!Conj<std::true_type, std::false_type>, \"\");\n  static_assert(!Conj<std::false_type, std::false_type>, \"\");\n  static_assert(!Conj<std::true_type, std::true_type, std::false_type>, \"\");\n\n  static_assert(Disj<std::true_type, std::true_type>, \"\");\n  static_assert(Disj<std::false_type, std::true_type>, \"\");\n  static_assert(Disj<std::true_type, std::false_type>, \"\");\n  static_assert(!Disj<std::false_type, std::false_type>, \"\");\n  static_assert(Disj<std::true_type, std::true_type, std::false_type>, \"\");\n  static_assert(!Disj<std::false_type, std::false_type, std::false_type>, \"\");\n\n  static_assert(std::is_same<TypeGet<0, int>, int>::value, \"\");\n  static_assert(std::is_same<TypeGet<0, int, float>, int>::value, \"\");\n  static_assert(std::is_same<TypeGet<1, int, float>, float>::value, \"\");\n  static_assert(std::is_same<TypeGet<2, int, float, bool, int>, bool>::value, \"\");\n  static_assert(std::is_same<TypeGet<2, int, int>, TypeNotFound>::value, \"\");\n  static_assert(std::is_same<TypeGet<2, int, int, float>, float>::value, \"\");\n  static_assert(std::is_same<TypeGet<2, int, int, float, int>, float>::value, \"\");\n  static_assert(std::is_same<TypeGet<2>, TypeNotFound>::value, \"\");\n\n  static_assert(IndexGet<int, int> == 0, \"\");\n  static_assert(IndexGet<int, float> == NPos, \"\");\n  static_assert(IndexGet<int, int, int> == 0, \"\");\n  static_assert(IndexGet<int, float, int> == 1, \"\");\n  static_assert(IndexGet<bool, int, float, int, bool, bool, int> == 3, \"\");\n  static_assert(IndexGet<int> == NPos, \"\");\n\n  static_assert(!TypeIn<int>, \"\");\n  static_assert(TypeIn<int, int>, \"\");\n  static_assert(TypeIn<int, float, int>, \"\");\n  static_assert(!TypeIn<int, float, float, bool>, \"\");\n  static_assert(TypeIn<int, float, bool, int, float>, \"\");\n  static_assert(TypeIn<bool, float, float, bool>, \"\");\n\n  static_assert(IsDifferentTypes<int, float, bool>, \"\");\n  static_assert(!IsDifferentTypes<int, float, int, bool>, \"\");\n  static_assert(IsDifferentTypes<int>, \"\");\n  static_assert(!IsDifferentTypes<int, int>, \"\");\n}\n"
  },
  {
    "path": "oneflow/maybe/utility.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_MAYBE_UTILITY_H_\n#define ONEFLOW_MAYBE_UTILITY_H_\n\n#include <cstddef>\n#include <functional>\n\nnamespace oneflow {\n\nnamespace maybe {\n\n// unlike std::nullopt in c++17, the NullOptType is used in both Variant and Optional,\n// so it is more like both std::nullopt and std::monostate (in c++17),\n// the advantage of this unification is a more unifed experience,\n// i.e. `return NullOpt` can be used in both Variant and Optional context\nstruct NullOptType {\n  explicit constexpr NullOptType() = default;\n\n  bool operator==(NullOptType) const { return true; }\n  bool operator!=(NullOptType) const { return false; }\n  bool operator<(NullOptType) const { return false; }\n  bool operator>(NullOptType) const { return false; }\n  bool operator<=(NullOptType) const { return true; }\n  bool operator>=(NullOptType) const { return true; }\n};\n\nconstexpr const std::size_t NullOptHash = -3333;\n\nconstexpr NullOptType NullOpt{};\n\nstruct InPlaceT {\n  explicit constexpr InPlaceT() = default;\n};\n\nconstexpr InPlaceT InPlace;\n\ntemplate<typename T>\nstruct InPlaceTypeT {\n  explicit constexpr InPlaceTypeT() = default;\n};\n\ntemplate<typename T>\nconstexpr InPlaceTypeT<T> InPlaceType;\n\ntemplate<std::size_t I>\nstruct InPlaceIndexT {\n  explicit constexpr InPlaceIndexT() = default;\n};\n\ntemplate<std::size_t I>\nconstexpr InPlaceIndexT<I> InPlaceIndex;\n\ntemplate<class T>\nconstexpr void HashCombine(std::size_t& seed, const T& v) {\n  std::hash<T> hasher;\n  seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);\n}\n\n}  // namespace maybe\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::maybe::NullOptType> {\n  size_t operator()(oneflow::maybe::NullOptType) const noexcept {\n    return oneflow::maybe::NullOptHash;\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_MAYBE_UTILITY_H_\n"
  },
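HashCombine above is the classic boost-style seed-mixing recipe; a sketch of its intended use, folding several fields into one hash when specializing std::hash for an aggregate (Endpoint is a hypothetical type used only for illustration):

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include "oneflow/maybe/utility.h"

// Hypothetical aggregate, not part of the library.
struct Endpoint {
  std::string host;
  int port;
};

namespace std {
template<>
struct hash<Endpoint> {
  size_t operator()(const Endpoint& e) const noexcept {
    std::size_t seed = 0;
    oneflow::maybe::HashCombine(seed, e.host);  // mix each field into the seed
    oneflow::maybe::HashCombine(seed, e.port);
    return seed;
  }
};
}  // namespace std
```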
  {
    "path": "oneflow/maybe/utility_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <gtest/gtest.h>\n#include \"oneflow/maybe/utility.h\"\n\nusing namespace oneflow::maybe;\n\nTEST(Utility, NullOpt) {\n  NullOptType a, b(NullOpt), c(a);  // NOLINT\n\n  a = NullOpt;\n\n  a = b;\n\n  ASSERT_EQ(a, NullOptType{});\n  ASSERT_EQ(std::hash<NullOptType>()(a), std::hash<NullOptType>()(NullOpt));\n  ASSERT_EQ(NullOpt, a);\n  ASSERT_GE(NullOpt, a);\n  ASSERT_LE(NullOpt, a);\n  ASSERT_FALSE(NullOpt < a);\n  ASSERT_FALSE(NullOpt > a);\n}\n"
  },
  {
    "path": "oneflow/maybe/variant.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_MAYBE_VARIANT_H_\n#define ONEFLOW_MAYBE_VARIANT_H_\n\n#include <algorithm>\n#include <cstddef>\n#include <cstdint>\n#include <type_traits>\n#include <utility>\n#include <functional>\n#include <iostream>\n\n#include \"oneflow/maybe/utility.h\"\n#include \"oneflow/maybe/type_traits.h\"\n\nnamespace oneflow {\n\nnamespace maybe {\n\ntemplate<typename... Ts>\nstruct Variant;\n\nnamespace details {\n\n// there are generally two ways to implement visit (like std::visit in c++17)\n// 1. O(N) or O(log N), to iterate for all types or do a binary search on type index recursively\n// 2. O(1), to store an static (storage duration) array of function pointers for every (Variant, F)\n// where N = Variant<T...>::Num, and normally (in most cases) within the range [2, 5]\n// the 2nd method is required in std::visit(f, x...) while sizeof...(x) == 1\n// but weakness of the 2nd method is that compilers usually cannot efficiently optimize these\n// function pointers (compared to trivial recursion, which is easy to do optimization, and also\n// friendly to CPU cache) here we implement visit via the first method:\n// 1. for 2 <= N < 4, we use the O(N) algorithm (TrivialRecursiveVisitImpl) for better optimization\n// 2. for N >= 4, we use the O(log N) algorithm (BinarySearchVisitImpl) for less recursion rounds\n\nstruct VariantPrivateScope {\n  template<typename R, typename F, typename V>\n  static R TrivialRecursiveVisitImpl(F&& f, V&& v, InPlaceIndexT<RemoveCVRef<V>::Num - 1>) {\n    // assume v.Index() == N - 1 now\n    return static_cast<R>(\n        std::forward<F>(f)(std::forward<V>(v).template Value<RemoveCVRef<V>::Num - 1>()));\n  }\n\n  template<typename R, std::size_t I, typename F, typename V,\n           std::enable_if_t<(I < RemoveCVRef<V>::Num - 1), int> = 0>\n  static R TrivialRecursiveVisitImpl(F&& f, V&& v, InPlaceIndexT<I>) {\n    if (v.Index() == I) {\n      return static_cast<R>(std::forward<F>(f)(std::forward<V>(v).template Value<I>()));\n    }\n\n    return TrivialRecursiveVisitImpl<R>(std::forward<F>(f), std::forward<V>(v),\n                                        InPlaceIndex<I + 1>);\n  }\n\n  template<typename R, std::size_t I, typename F, typename V,\n           std::enable_if_t<(I < RemoveCVRef<V>::Num), int> = 0>\n  static R BinarySearchVisitImpl(F&& f, V&& v, InPlaceIndexT<I>, InPlaceIndexT<I>) {\n    return static_cast<R>(std::forward<F>(f)(std::forward<V>(v).template Value<I>()));\n  }\n\n  template<typename R, std::size_t I, typename F, typename V,\n           std::enable_if_t<(I + 1 < RemoveCVRef<V>::Num), int> = 0>\n  static R BinarySearchVisitImpl(F&& f, V&& v, InPlaceIndexT<I>, InPlaceIndexT<I + 1>) {\n    constexpr std::size_t M = (I + I + 1) / 2;\n    constexpr std::size_t N = (M == I) ? 
I + 1 : I;\n\n    if (v.Index() == M) {\n      return static_cast<R>(std::forward<F>(f)(std::forward<V>(v).template Value<M>()));\n    } else {\n      return static_cast<R>(std::forward<F>(f)(std::forward<V>(v).template Value<N>()));\n    }\n  }\n\n  template<typename R, std::size_t L, std::size_t U, typename F, typename V,\n           std::enable_if_t<(L + 1 < U) && (U < RemoveCVRef<V>::Num), int> = 0>\n  static R BinarySearchVisitImpl(F&& f, V&& v, InPlaceIndexT<L>, InPlaceIndexT<U>) {\n    constexpr std::size_t M = (L + U) / 2;\n\n    if (v.Index() < M) {\n      return BinarySearchVisitImpl<R>(std::forward<F>(f), std::forward<V>(v), InPlaceIndex<L>,\n                                      InPlaceIndex<M - 1>);\n    } else if (v.Index() > M) {\n      return BinarySearchVisitImpl<R>(std::forward<F>(f), std::forward<V>(v), InPlaceIndex<M + 1>,\n                                      InPlaceIndex<U>);\n    } else {\n      return static_cast<R>(std::forward<F>(f)(std::forward<V>(v).template Value<M>()));\n    }\n  }\n\n  template<typename R, typename F, typename V,\n           std::enable_if_t<RemoveCVRef<V>::Num<4, int> = 0> static R VisitImpl(F&& f, V&& v) {\n    return TrivialRecursiveVisitImpl<R>(std::forward<F>(f), std::forward<V>(v), InPlaceIndex<0>);\n  }\n\n  template<typename R, typename F, typename V, std::enable_if_t<RemoveCVRef<V>::Num >= 4, int> = 0>\n  static R VisitImpl(F&& f, V&& v) {\n    return BinarySearchVisitImpl<R>(std::forward<F>(f), std::forward<V>(v), InPlaceIndex<0>,\n                                    InPlaceIndex<RemoveCVRef<V>::Num - 1>);\n  }\n};\n\nstruct AutoDeducedResultType;\n\ntemplate<typename R, typename F, typename... Ts>\nstruct VisitResultS {\n  using type = R;\n};\n\ntemplate<typename F, typename... Ts>\nstruct VisitResultS<AutoDeducedResultType, F, Ts...> {\n  using type = std::common_type_t<decltype(std::declval<F>()(std::declval<Ts>()))...>;\n};\n\ntemplate<typename R, typename F, typename... Ts>\nusing VisitResult = typename VisitResultS<R, F, Ts...>::type;\n\n}  // namespace details\n\n// preconditions: the template type arguments must be at least 2 different types,\n// none of which may have reference or cv qualifiers\n// this Variant DOES NOT guarantee exception safety\ntemplate<typename... 
Ts>\nstruct Variant {  // NOLINT(cppcoreguidelines-pro-type-member-init)\n public:\n  static_assert(sizeof...(Ts) > 1, \"expected at least two types\");\n  static_assert(Conj<NegS<std::is_reference<Ts>>...>, \"reference types are not allowed here\");\n  static_assert(Conj<NegS<DisjS<std::is_const<Ts>, std::is_volatile<Ts>>>...>,\n                \"cv qualifiers are not allowed here\");\n  // important precondition to optimize Visit via binary search\n  static_assert(IsDifferentTypes<Ts...>, \"expected all types to be different\");\n\n  static constexpr std::size_t Num = sizeof...(Ts);\n\n  template<typename T>\n  static constexpr std::size_t IndexOfType = IndexGet<T, Ts...>;\n\n  template<typename T>\n  static constexpr bool HasType = TypeIn<T, Ts...>;\n\n  template<std::size_t I>\n  using TypeByIndex = TypeGet<I, Ts...>;\n\n  template<typename T = TypeByIndex<0>,\n           std::enable_if_t<std::is_default_constructible<T>::value, int> = 0>\n  Variant() {  // NOLINT(cppcoreguidelines-pro-type-member-init)\n    Construct<0>();\n  }\n\n  // unlike std::variant, we only accept exact types to avoid wrong construction\n  template<typename T, std::enable_if_t<HasType<RemoveCVRef<T>>, int> = 0>\n  Variant(T&& v) {  // NOLINT(cppcoreguidelines-pro-type-member-init, google-explicit-constructor)\n    Construct<RemoveCVRef<T>>(std::forward<T>(v));\n  }\n\n  template<typename T, typename... Args, std::enable_if_t<HasType<RemoveCVRef<T>>, int> = 0>\n  explicit Variant(InPlaceTypeT<T>,  // NOLINT(cppcoreguidelines-pro-type-member-init)\n                   Args&&... args) {\n    Construct<RemoveCVRef<T>>(std::forward<Args>(args)...);\n  }\n\n  template<std::size_t I, typename... Args, std::enable_if_t<(I < Num), int> = 0>\n  explicit Variant(InPlaceIndexT<I>,  // NOLINT(cppcoreguidelines-pro-type-member-init)\n                   Args&&... 
args) {\n    Construct<I>(std::forward<Args>(args)...);\n  }\n\n  template<typename R = details::AutoDeducedResultType, typename F>\n  decltype(auto) Visit(F&& f) & {\n    using Result = details::VisitResult<R, F, Ts&...>;\n    return details::VariantPrivateScope::VisitImpl<Result>(std::forward<F>(f), *this);\n  }\n\n  template<typename R = details::AutoDeducedResultType, typename F>\n  decltype(auto) Visit(F&& f) && {\n    using Result = details::VisitResult<R, F, Ts&&...>;\n    return details::VariantPrivateScope::VisitImpl<Result>(std::forward<F>(f), std::move(*this));\n  }\n\n  template<typename R = details::AutoDeducedResultType, typename F>\n  decltype(auto) Visit(F&& f) const& {\n    using Result = details::VisitResult<R, F, const Ts&...>;\n    return details::VariantPrivateScope::VisitImpl<Result>(std::forward<F>(f), *this);\n  }\n\n  Variant(const Variant& v) {  // NOLINT(cppcoreguidelines-pro-type-member-init)\n    CopyConstruct(v);\n  }\n\n  Variant(Variant&& v) noexcept {  // NOLINT(cppcoreguidelines-pro-type-member-init)\n    CopyConstruct(std::move(v));\n  }\n\n  template<typename T, std::enable_if_t<HasType<RemoveCVRef<T>>, int> = 0>\n  Variant& operator=(T&& v) {\n    using Type = RemoveCVRef<T>;\n\n    Emplace<Type>(std::forward<T>(v));\n\n    return *this;\n  }\n\n  Variant& operator=(const Variant& v) {\n    Copy(v);\n    return *this;\n  }\n\n  Variant& operator=(Variant&& v) noexcept {\n    Copy(std::move(v));\n    return *this;\n  }\n\n  std::size_t Index() const { return type_index_; }\n\n  template<typename T, std::enable_if_t<HasType<T>, int> = 0>\n  bool Is() const {\n    return type_index_ == IndexOfType<T>;\n  }\n\n  ~Variant() { Destroy(); }\n\n  bool operator==(const Variant& v) const {\n    if (type_index_ != v.type_index_) return false;\n\n    return v.Visit(\n        [this](const auto& elem) { return elem == Value<RemoveCVRef<decltype(elem)>>(); });\n  }\n\n  bool operator!=(const Variant& v) const { return !operator==(v); }\n\n  bool operator<(const Variant& v) const {\n    if (type_index_ < v.type_index_) return true;\n    if (type_index_ > v.type_index_) return false;\n\n    return v.Visit(\n        [this](const auto& elem) { return Value<RemoveCVRef<decltype(elem)>>() < elem; });\n  }\n\n  bool operator>=(const Variant& v) const { return !(*this < v); }\n\n  bool operator>(const Variant& v) const {\n    if (type_index_ > v.type_index_) return true;\n    if (type_index_ < v.type_index_) return false;\n\n    return v.Visit(\n        [this](const auto& elem) { return Value<RemoveCVRef<decltype(elem)>>() > elem; });\n  }\n\n  bool operator<=(const Variant& v) const { return !(*this > v); }\n\n  template<typename T, std::enable_if_t<HasType<T>, int> = 0>\n  friend bool operator==(const Variant& v, const T& x) {\n    if (v.type_index_ != IndexOfType<T>) return false;\n\n    return v.Value<T>() == x;\n  }\n\n  template<typename T, std::enable_if_t<HasType<T>, int> = 0>\n  friend bool operator!=(const Variant& v, const T& x) {\n    return !(v == x);\n  }\n\n  template<typename T, std::enable_if_t<HasType<T>, int> = 0>\n  friend bool operator==(const T& x, const Variant& v) {\n    return v == x;\n  }\n\n  template<typename T, std::enable_if_t<HasType<T>, int> = 0>\n  friend bool operator!=(const T& x, const Variant& v) {\n    return !(v == x);\n  }\n\n  template<typename T, typename... Args>\n  T& Emplace(Args&&... 
args) {\n    if (Is<T>()) {\n      return Value<T>() = T(std::forward<Args>(args)...);\n    } else {\n      Destroy();\n      Construct<T>(std::forward<Args>(args)...);\n      return Value<T>();\n    }\n  }\n\n  template<std::size_t I, typename... Args>\n  decltype(auto) Emplace(Args&&... args) {\n    return Emplace<TypeByIndex<I>>(std::forward<Args>(args)...);\n  }\n\n  template<typename T, std::enable_if_t<HasType<T>, int> = 0>\n  T& Get() & {\n    OF_MAYBE_ASSERT_EQ(Index(), IndexOfType<T>);\n    return Value<T>();\n  }\n\n  template<typename T, std::enable_if_t<HasType<T>, int> = 0>\n  T&& Get() && {\n    OF_MAYBE_ASSERT_EQ(Index(), IndexOfType<T>);\n    return std::move(*this).template Value<T>();\n  }\n\n  template<typename T, std::enable_if_t<HasType<T>, int> = 0>\n  const T& Get() const& {\n    OF_MAYBE_ASSERT_EQ(Index(), IndexOfType<T>);\n    return Value<T>();\n  }\n\n  template<std::size_t I, std::enable_if_t<(I < Num), int> = 0>\n  TypeByIndex<I>& Get() & {\n    OF_MAYBE_ASSERT_EQ(Index(), I);\n    return Value<I>();\n  }\n\n  template<std::size_t I, std::enable_if_t<(I < Num), int> = 0>\n  TypeByIndex<I>&& Get() && {\n    OF_MAYBE_ASSERT_EQ(Index(), I);\n    return std::move(*this).template Value<I>();\n  }\n\n  template<std::size_t I, std::enable_if_t<(I < Num), int> = 0>\n  const TypeByIndex<I>& Get() const& {\n    OF_MAYBE_ASSERT_EQ(Index(), I);\n    return Value<I>();\n  }\n\n protected:\n  // use std::launder when upgrading to c++17\n  template<typename T, std::enable_if_t<HasType<T>, int> = 0>\n  T& Value() & {\n    return *reinterpret_cast<T*>(storage_);\n  }\n\n  template<typename T, std::enable_if_t<HasType<T>, int> = 0>\n  T&& Value() && {\n    return std::move(*reinterpret_cast<T*>(storage_));\n  }\n\n  template<typename T, std::enable_if_t<HasType<T>, int> = 0>\n  const T& Value() const& {\n    return *reinterpret_cast<const T*>(storage_);\n  }\n\n  template<std::size_t I, std::enable_if_t<(I < Num), int> = 0>\n  TypeByIndex<I>& Value() & {\n    return *reinterpret_cast<TypeByIndex<I>*>(storage_);\n  }\n\n  template<std::size_t I, std::enable_if_t<(I < Num), int> = 0>\n  TypeByIndex<I>&& Value() && {\n    return std::move(*reinterpret_cast<TypeByIndex<I>*>(storage_));\n  }\n\n  template<std::size_t I, std::enable_if_t<(I < Num), int> = 0>\n  const TypeByIndex<I>& Value() const& {\n    return *reinterpret_cast<const TypeByIndex<I>*>(storage_);\n  }\n\n private:\n  static constexpr const std::size_t size = std::max({sizeof(Ts)...});\n\n  alignas(Ts...) unsigned char storage_[size];\n  std::uint8_t type_index_;\n\n  friend struct details::VariantPrivateScope;\n\n  template<typename T, typename... Args, std::enable_if_t<HasType<T> && IsAggregate<T>, int> = 0>\n  void Construct(Args&&... args) {\n    new (storage_) T{std::forward<Args>(args)...};\n    type_index_ = IndexOfType<T>;\n  }\n\n  template<typename T, typename... Args, std::enable_if_t<HasType<T> && !IsAggregate<T>, int> = 0>\n  void Construct(Args&&... args) {\n    new (storage_) T(std::forward<Args>(args)...);\n    type_index_ = IndexOfType<T>;\n  }\n\n  template<std::size_t I, typename... Args, std::enable_if_t<(I < Num), int> = 0>\n  void Construct(Args&&... 
args) {\n    Construct<TypeByIndex<I>>(std::forward<Args>(args)...);\n  }\n\n  template<typename V>\n  void CopyConstruct(V&& v) {\n    std::forward<V>(v).Visit([this](auto&& elem) {\n      using T = RemoveCVRef<decltype(elem)>;\n\n      new (storage_) T(std::forward<decltype(elem)>(elem));\n      type_index_ = IndexOfType<T>;\n    });\n  }\n\n  template<typename V>\n  void Copy(V&& v) {\n    std::forward<V>(v).Visit([this](auto&& elem) {\n      using T = RemoveCVRef<decltype(elem)>;\n\n      if (Is<T>()) {\n        Value<T>() = std::forward<decltype(elem)>(elem);\n      } else {\n        Destroy();\n        Construct<T>(std::forward<decltype(elem)>(elem));\n      }\n    });\n  }\n\n  void Destroy() {\n    Visit([this](auto& elem) {\n      using T = RemoveCVRef<decltype(elem)>;\n\n      Value<T>().~T();\n    });\n  }\n};\n\ntemplate<typename... Ts>\nusing OptionalVariant = Variant<NullOptType, Ts...>;\n\n}  // namespace maybe\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<typename... Ts>\nstruct hash<oneflow::maybe::Variant<Ts...>> {\n  size_t operator()(const oneflow::maybe::Variant<Ts...>& v) const noexcept {\n    size_t seed = hash<size_t>()(v.Index());\n\n    v.Visit([&seed](const auto& x) {\n      using type = oneflow::maybe::RemoveCVRef<decltype(x)>;\n      oneflow::maybe::HashCombine<type>(seed, x);\n    });\n\n    return seed;\n  }\n};\n\n}  // namespace std\n\n#endif  // ONEFLOW_MAYBE_VARIANT_H_\n"
  },
  {
    "path": "oneflow/maybe/variant_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <gtest/gtest.h>\n#include <memory>\n#include \"oneflow/maybe/variant.h\"\n\nusing namespace oneflow::maybe;\nusing namespace std::string_literals;\n\nTEST(Variant, Basics) {\n  Variant<int, float> a, b(1), c(1.2f), d(InPlaceType<int>, 'a'), e(InPlaceType<float>, 6.66);\n  ASSERT_TRUE(a.Is<int>());\n  ASSERT_EQ(a.Get<int>(), 0);\n  ASSERT_TRUE(b.Is<int>());\n  ASSERT_EQ(b.Get<int>(), 1);\n  ASSERT_TRUE(c.Is<float>());\n  ASSERT_EQ(c.Get<float>(), 1.2f);\n  ASSERT_TRUE(d.Is<int>());\n  ASSERT_EQ(d.Get<int>(), 'a');\n  ASSERT_TRUE(e.Is<float>());\n  ASSERT_FLOAT_EQ(e.Get<float>(), 6.66);\n\n  Variant<int, float> f(b), g(c), h(InPlaceIndex<1>, 2.33), i(InPlaceIndex<0>, 2.33);\n  ASSERT_TRUE(f.Is<int>());\n  ASSERT_EQ(f.Get<int>(), 1);\n  ASSERT_TRUE(g.Is<float>());\n  ASSERT_EQ(g.Get<float>(), 1.2f);\n  ASSERT_TRUE(h.Is<float>());\n  ASSERT_FLOAT_EQ(h.Get<float>(), 2.33);\n  ASSERT_TRUE(i.Is<int>());\n  ASSERT_EQ(i.Get<int>(), 2);\n\n  a = 1;\n  ASSERT_TRUE(a.Is<int>());\n  ASSERT_EQ(a.Get<int>(), 1);\n\n  a = 1.3f;\n  ASSERT_TRUE(a.Is<float>());\n  ASSERT_EQ(a.Get<float>(), 1.3f);\n\n  a = b;\n  ASSERT_TRUE(a.Is<int>());\n  ASSERT_EQ(a.Get<int>(), 1);\n\n  a = c;\n  ASSERT_TRUE(a.Is<float>());\n  ASSERT_EQ(a.Get<float>(), 1.2f);\n\n  ASSERT_EQ((b.Visit<Variant<int, float>>([](auto&& x) { return x + 1; })),\n            (Variant<int, float>(2)));\n  ASSERT_EQ((c.Visit<Variant<int, float>>([](auto&& x) { return x + 1; })),\n            (Variant<int, float>(2.2f)));\n\n  ASSERT_EQ(a.Emplace<1>(1.3f), 1.3f);\n  ASSERT_TRUE(a.Is<float>());\n  ASSERT_EQ(a.Get<1>(), 1.3f);\n\n  ASSERT_EQ(a.Emplace<0>(233), 233);\n  ASSERT_TRUE(a.Is<int>());\n  ASSERT_EQ(a.Get<0>(), 233);\n}\n\nTEST(Variant, NonPOD) {\n  Variant<bool, std::shared_ptr<int>> a;\n  ASSERT_TRUE(a.Is<bool>());\n  ASSERT_EQ(a.Get<bool>(), false);\n\n  a = true;\n  ASSERT_TRUE(a.Is<bool>());\n  ASSERT_EQ(a.Get<bool>(), true);\n\n  a = std::make_shared<int>(233);\n  ASSERT_EQ(a.Index(), 1);\n  ASSERT_EQ(*a.Get<1>(), 233);\n  ASSERT_EQ(a.Get<1>().use_count(), 1);\n\n  {\n    Variant<bool, std::shared_ptr<int>> b = a;\n    ASSERT_EQ(b.Index(), 1);\n    ASSERT_EQ(*b.Get<1>(), 233);\n    ASSERT_EQ(a.Get<1>().use_count(), 2);\n    *b.Get<1>() = 234;\n  }\n\n  ASSERT_EQ(a.Get<1>().use_count(), 1);\n  ASSERT_EQ(*a.Get<1>(), 234);\n\n  Variant<bool, std::shared_ptr<int>> b = std::move(a);\n  ASSERT_EQ(b.Get<1>().use_count(), 1);\n  ASSERT_EQ(*b.Get<1>(), 234);\n\n  Variant<bool, std::shared_ptr<int>> c = b;\n  ASSERT_EQ(c.Get<1>().use_count(), 2);\n  ASSERT_EQ(b, c);\n\n  b = true;\n  ASSERT_EQ(c.Get<1>().use_count(), 1);\n\n  ASSERT_NE(b, c);\n}\n\nTEST(Variant, Optional) {\n  OptionalVariant<int, const char*> a, b(NullOpt), c(a);\n\n  const char* hello = \"hello\";\n\n  std::size_t hash = 0, hash2 = 1, hash3 = 2;\n  HashCombine(hash, NullOpt);\n  HashCombine(hash2, 1);\n  HashCombine(hash3, hello);\n\n  ASSERT_TRUE(a == 
NullOpt);\n  ASSERT_EQ(std::hash<decltype(a)>()(a), hash);\n\n  a = 1;\n  ASSERT_EQ(a, 1);\n  ASSERT_EQ(std::hash<decltype(a)>()(a), hash2);\n\n  a = NullOpt;\n  ASSERT_EQ(a, NullOpt);\n  ASSERT_EQ(std::hash<decltype(a)>()(a), hash);\n\n  a = hello;\n  ASSERT_EQ(a, hello);\n  ASSERT_EQ(std::hash<decltype(a)>()(a), hash3);\n\n  ASSERT_EQ(b, NullOpt);\n  ASSERT_EQ(c, NullOpt);\n  ASSERT_NE(a, b);\n}\n\nTEST(Variant, BinarySearchVisit) {\n  const char* hello = \"hello\";\n\n  OptionalVariant<int, float, bool> x, y(123), z(1.2f), w(true);\n  OptionalVariant<int, float, bool, const char*> a, b(123), c(1.2f), d(true), e(hello);\n\n  ASSERT_EQ(x, NullOpt);\n  ASSERT_EQ(y, 123);\n  ASSERT_EQ(z, 1.2f);\n  ASSERT_EQ(w, true);\n  ASSERT_EQ(a, NullOpt);\n  ASSERT_EQ(b, 123);\n  ASSERT_EQ(c, 1.2f);\n  ASSERT_EQ(d, true);\n  ASSERT_EQ(e, hello);\n\n  OptionalVariant<int, float, bool, const char*> a1(a), b1(b), c1(c), d1(d), e1(e);\n\n  ASSERT_EQ(a1, NullOpt);\n  ASSERT_EQ(b1, 123);\n  ASSERT_EQ(c1, 1.2f);\n  ASSERT_EQ(d1, true);\n  ASSERT_EQ(e1, hello);\n\n  a = 233;\n  ASSERT_EQ(a, 233);\n\n  a = hello;\n  ASSERT_EQ(a, hello);\n\n  a = c;\n  ASSERT_EQ(a, 1.2f);\n  ASSERT_EQ(1.2f, a);\n  ASSERT_EQ(a, c);\n  ASSERT_NE(a, b);\n}\n\nTEST(Variant, Compare) {\n  OptionalVariant<int, float, bool> a, b, c(0), d(5), dd(5), e(-1.2f), f(2.3f), g(false), h(true);\n\n  ASSERT_EQ(a, b);\n  ASSERT_EQ(d, dd);\n  ASSERT_NE(a, c);\n  ASSERT_NE(c, d);\n  ASSERT_NE(d, e);\n  ASSERT_NE(e, f);\n  ASSERT_NE(f, g);\n  ASSERT_NE(g, h);\n  ASSERT_LT(a, c);\n  ASSERT_LT(c, d);\n  ASSERT_LT(d, e);\n  ASSERT_LT(e, f);\n  ASSERT_LT(f, g);\n  ASSERT_LT(g, h);\n  ASSERT_GT(c, a);\n  ASSERT_GT(d, c);\n  ASSERT_GT(e, d);\n  ASSERT_GT(f, e);\n  ASSERT_GT(g, f);\n  ASSERT_GT(h, g);\n  ASSERT_LE(a, b);\n  ASSERT_LE(b, c);\n  ASSERT_LE(c, d);\n  ASSERT_LE(d, dd);\n\n  std::set<OptionalVariant<int, float, bool>> s{100, 2.3f,  true, 3.3f, NullOpt,\n                                                0,   false, 22,   true, NullOpt};\n  ASSERT_EQ(s.size(), 8);\n\n  auto iter = s.begin();\n  ASSERT_EQ(*(iter++), NullOpt);\n  ASSERT_EQ(*(iter++), 0);\n  ASSERT_EQ(*(iter++), 22);\n  ASSERT_EQ(*(iter++), 100);\n  ASSERT_EQ(*(iter++), 2.3f);\n  ASSERT_EQ(*(iter++), 3.3f);\n  ASSERT_EQ(*(iter++), false);\n  ASSERT_EQ(*(iter++), true);\n}\n\nTEST(Variant, UniquePtr) {\n  Variant<std::string, std::unique_ptr<int>> a(\"hello\"s), b(std::make_unique<int>(1));\n\n  ASSERT_EQ(a, \"hello\"s);\n  ASSERT_EQ(*b.Get<1>(), 1);\n\n  Variant<std::string, std::unique_ptr<int>> c(std::move(a)), d(std::move(b));\n\n  ASSERT_EQ(c, \"hello\"s);\n  ASSERT_EQ(*d.Get<1>(), 1);\n}\n"
  },
  {
    "path": "oneflow/user/data/batch_dataset.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_BATCH_DATASET_H_\n#define ONEFLOW_USER_DATA_BATCH_DATASET_H_\n\n#include \"oneflow/user/data/dataset.h\"\n\nnamespace oneflow {\nnamespace data {\n\ntemplate<typename LoadTarget>\nclass BatchDataset final : public Dataset<LoadTarget> {\n public:\n  using Base = Dataset<LoadTarget>;\n  using SampleType = typename Base::SampleType;\n  using BatchType = typename Base::BatchType;\n\n  BatchDataset(int32_t batch_size, std::unique_ptr<Dataset<LoadTarget>>&& dataset)\n      : batch_size_(batch_size), nested_ds_(std::move(dataset)) {}\n  ~BatchDataset() = default;\n\n  BatchType Next() override {\n    BatchType batch;\n    batch.reserve(batch_size_);\n    for (size_t i = 0; i < batch_size_; ++i) {\n      BatchType tmp = nested_ds_->Next();\n      CHECK_EQ(tmp.size(), 1);\n      batch.push_back(std::move(tmp[0]));\n    }\n    return batch;\n  }\n\n private:\n  int32_t batch_size_;\n  std::unique_ptr<Dataset<LoadTarget>> nested_ds_;\n};\n\n}  // namespace data\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_BATCH_DATASET_H_\n"
  },
  {
    "path": "oneflow/user/data/batch_random_shuffle_dataset.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_BATCH_RANDOM_SHUFFLE_DATASET_H_\n#define ONEFLOW_USER_DATA_BATCH_RANDOM_SHUFFLE_DATASET_H_\n\n#include \"oneflow/user/data/dataset.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n\nnamespace oneflow {\nnamespace data {\n\ntemplate<typename LoadTarget>\nclass BatchRandomShuffleDataset final : public Dataset<LoadTarget> {\n public:\n  using Base = Dataset<LoadTarget>;\n  using SampleType = typename Base::SampleType;\n  using BatchType = typename Base::BatchType;\n\n  BatchRandomShuffleDataset(user_op::KernelInitContext* ctx,\n                            std::unique_ptr<Dataset<LoadTarget>>&& data_set)\n      : loader_(std::move(data_set)) {\n    // random\n    seed_ = ctx->Attr<int64_t>(\"seed\");\n    if (seed_ == -1) { seed_ = NewRandomSeed(); }\n    std::seed_seq seq({seed_});\n    rand_engine_ = std::default_random_engine(seq);\n\n    // fill buffer\n    initial_buffer_fill_ = ctx->Attr<int32_t>(\"shuffle_buffer_size\");\n    for (int32_t i = 0; i < initial_buffer_fill_; ++i) {\n      BatchType batch = loader_->Next();\n      batch_buffer_.push_back(std::move(batch));\n    }\n  }\n  ~BatchRandomShuffleDataset() = default;\n\n  BatchType Next() override {\n    BatchType batch = loader_->Next();\n    std::uniform_int_distribution<> dis(0, batch_buffer_.size() - 1);\n    const int offset = dis(rand_engine_);\n    std::swap(batch_buffer_.at(offset), batch);\n    return batch;\n  }\n\n private:\n  std::unique_ptr<Dataset<LoadTarget>> loader_;\n  std::vector<BatchType> batch_buffer_;\n\n  int32_t initial_buffer_fill_;\n\n  std::default_random_engine rand_engine_;\n  int64_t seed_;\n};\n\n}  // namespace data\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_BATCH_RANDOM_SHUFFLE_DATASET_H_\n"
  },
  {
    "path": "oneflow/user/data/coco_data_reader.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/data/coco_data_reader.h\"\n#include \"oneflow/user/data/coco_dataset.h\"\n#include \"oneflow/user/data/distributed_training_dataset.h\"\n#include \"oneflow/user/data/group_batch_dataset.h\"\n#include \"oneflow/user/data/batch_dataset.h\"\n#include \"oneflow/user/data/distributed_util.h\"\n#include \"oneflow/core/persistence/file_system.h\"\n#include \"oneflow/core/persistence/persistent_in_stream.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\nnamespace data {\n\nCOCODataReader::COCODataReader(user_op::KernelInitContext* ctx) : DataReader<COCOImage>(ctx) {\n  batch_size_ = ctx->TensorDesc4ArgNameAndIndex(\"image\", 0)->shape().elem_cnt();\n  if (auto* pool = TensorBufferPool::TryGet()) { pool->IncreasePoolSizeByBase(batch_size_); }\n\n  std::shared_ptr<const COCOMeta> meta(new COCOMeta(\n      ctx->Attr<int64_t>(\"session_id\"), ctx->Attr<std::string>(\"annotation_file\"),\n      ctx->Attr<std::string>(\"image_dir\"), ctx->Attr<bool>(\"remove_images_without_annotations\")));\n  std::unique_ptr<RandomAccessDataset<COCOImage>> coco_dataset_ptr(new COCODataset(ctx, meta));\n\n  size_t world_size = 1;\n  int64_t rank = 0;\n  CHECK_JUST(InitDataSourceDistributedInfo(ctx, world_size, rank));\n  loader_.reset(new DistributedTrainingDataset<COCOImage>(\n      world_size, rank, ctx->Attr<bool>(\"stride_partition\"), ctx->Attr<bool>(\"shuffle_after_epoch\"),\n      ctx->Attr<int64_t>(\"random_seed\"), std::move(coco_dataset_ptr)));\n\n  if (ctx->Attr<bool>(\"group_by_ratio\")) {\n    auto GetGroupId = [](const COCOImage& sample) {\n      return static_cast<int64_t>(sample.height / sample.width);\n    };\n    loader_.reset(new GroupBatchDataset<COCOImage>(batch_size_, GetGroupId, std::move(loader_)));\n  } else {\n    loader_.reset(new BatchDataset<COCOImage>(batch_size_, std::move(loader_)));\n  }\n\n  parser_.reset(new COCOParser(meta));\n  StartLoadThread();\n}\n\nCOCODataReader::~COCODataReader() {\n  if (auto* pool = TensorBufferPool::TryGet()) { pool->DecreasePoolSizeByBase(batch_size_); }\n}\n\nCOCOMeta::COCOMeta(int64_t session_id, const std::string& annotation_file,\n                   const std::string& image_dir, bool remove_images_without_annotations)\n    : image_dir_(image_dir) {\n  // Read content of annotation file (json format) to json obj\n  PersistentInStream in_stream(session_id, DataFS(), annotation_file);\n  std::string json_str;\n  std::string line;\n  while (in_stream.ReadLine(&line) == 0) { json_str += line; }\n  std::istringstream in_str_stream(json_str);\n  in_str_stream >> annotation_json_;\n  // initialize image_ids_, image_id2image_ and image_id2anno_ids_\n  for (const auto& image : annotation_json_[\"images\"]) {\n    int64_t id = image[\"id\"].get<int64_t>();\n    image_ids_.emplace_back(id);\n    CHECK(image_id2image_.emplace(id, image).second);\n    CHECK(image_id2anno_ids_.emplace(id, 
std::vector<int64_t>()).second);\n  }\n  // build anno map\n  for (const auto& anno : annotation_json_[\"annotations\"]) {\n    int64_t id = anno[\"id\"].get<int64_t>();\n    int64_t image_id = anno[\"image_id\"].get<int64_t>();\n    // ignore crowd objects for now\n    if (anno[\"iscrowd\"].get<int>() == 1) { continue; }\n    // check for invalid segmentations\n    if (anno[\"segmentation\"].is_array()) {\n      for (const auto& poly : anno[\"segmentation\"]) {\n        // at least 3 points can compose a polygon\n        // every point needs 2 elements (x, y) to be represented\n        CHECK_GT(poly.size(), 6);\n      }\n    }\n    CHECK(anno_id2anno_.emplace(id, anno).second);\n    image_id2anno_ids_.at(image_id).emplace_back(id);\n  }\n  // remove images without annotations if necessary\n  if (remove_images_without_annotations) {\n    HashSet<int64_t> to_remove_image_ids;\n    for (int64_t image_id : image_ids_) {\n      if (!ImageHasValidAnnotations(image_id)) { to_remove_image_ids.insert(image_id); }\n    }\n    image_ids_.erase(std::remove_if(image_ids_.begin(), image_ids_.end(),\n                                    [&to_remove_image_ids](int64_t image_id) {\n                                      return to_remove_image_ids.find(image_id)\n                                             != to_remove_image_ids.end();\n                                    }),\n                     image_ids_.end());\n  }\n  // sort image ids for reproducible results\n  std::sort(image_ids_.begin(), image_ids_.end());\n  // build categories map\n  std::vector<int32_t> category_ids;\n  for (const auto& cat : annotation_json_[\"categories\"]) {\n    category_ids.emplace_back(cat[\"id\"].get<int32_t>());\n  }\n  std::sort(category_ids.begin(), category_ids.end());\n  int32_t contiguous_id = 1;\n  for (int32_t category_id : category_ids) {\n    CHECK(category_id2contiguous_id_.emplace(category_id, contiguous_id++).second);\n  }\n}\n\nbool COCOMeta::ImageHasValidAnnotations(int64_t image_id) const {\n  const std::vector<int64_t>& anno_id_vec = image_id2anno_ids_.at(image_id);\n  if (anno_id_vec.empty()) { return false; }\n\n  bool bbox_area_all_close_to_zero = true;\n  size_t visible_keypoints_count = 0;\n  for (int64_t anno_id : anno_id_vec) {\n    const auto& anno = anno_id2anno_.at(anno_id);\n    if (anno[\"bbox\"][2] > 1 && anno[\"bbox\"][3] > 1) { bbox_area_all_close_to_zero = false; }\n    if (anno.contains(\"keypoints\")) {\n      const auto& keypoints = anno[\"keypoints\"];\n      CHECK_EQ(keypoints.size() % 3, 0);\n      FOR_RANGE(size_t, i, 0, keypoints.size() / 3) {\n        int32_t keypoints_label = keypoints[i * 3 + 2].get<int32_t>();\n        if (keypoints_label > 0) { visible_keypoints_count += 1; }\n      }\n    }\n  }\n  // check if all boxes are close to zero area\n  if (bbox_area_all_close_to_zero) { return false; }\n  // keypoint tasks have a slightly different criterion for considering\n  // whether an annotation is valid\n  if (!anno_id2anno_.at(anno_id_vec.at(0)).contains(\"keypoints\")) { return true; }\n  // for keypoint detection tasks, only consider images valid when they\n  // contain at least kMinKeypointsPerImage visible keypoints\n  if (visible_keypoints_count >= kMinKeypointsPerImage) { return true; }\n  return false;\n}\n\n}  // namespace data\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/data/coco_data_reader.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_COCO_DATA_READER_H_\n#define ONEFLOW_USER_DATA_COCO_DATA_READER_H_\n\n#include \"oneflow/user/data/data_reader.h\"\n#include \"oneflow/user/data/coco_parser.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"nlohmann/json.hpp\"\n\nnamespace oneflow {\nnamespace data {\n\nclass COCODataReader final : public DataReader<COCOImage> {\n public:\n  COCODataReader(user_op::KernelInitContext* ctx);\n  ~COCODataReader() override;\n\n protected:\n  using DataReader<COCOImage>::loader_;\n  using DataReader<COCOImage>::parser_;\n\n private:\n  size_t batch_size_;\n};\n\nclass COCOMeta final {\n public:\n  COCOMeta(int64_t session_id, const std::string& annotation_file, const std::string& image_dir,\n           bool remove_images_without_annotations);\n  ~COCOMeta() = default;\n\n  int64_t Size() const { return image_ids_.size(); }\n  int64_t GetImageId(int64_t index) const { return image_ids_.at(index); }\n  int32_t GetImageHeight(int64_t index) const {\n    int64_t image_id = image_ids_.at(index);\n    return image_id2image_.at(image_id)[\"height\"].get<int32_t>();\n  }\n  int32_t GetImageWidth(int64_t index) const {\n    int64_t image_id = image_ids_.at(index);\n    return image_id2image_.at(image_id)[\"width\"].get<int32_t>();\n  }\n  std::string GetImageFilePath(int64_t index) const {\n    int64_t image_id = image_ids_.at(index);\n    const auto& image_json = image_id2image_.at(image_id);\n    return JoinPath(image_dir_, image_json[\"file_name\"].get<std::string>());\n  }\n  template<typename T>\n  std::vector<T> GetBboxVec(int64_t index) const;\n  template<typename T>\n  std::vector<T> GetLabelVec(int64_t index) const;\n  template<typename T>\n  void ReadSegmentationsToTensorBuffer(int64_t index, TensorBuffer* segm,\n                                       TensorBuffer* segm_offset_mat) const;\n\n private:\n  bool ImageHasValidAnnotations(int64_t image_id) const;\n\n  static constexpr int kMinKeypointsPerImage = 10;\n  nlohmann::json annotation_json_;\n  std::string image_dir_;\n  std::vector<int64_t> image_ids_;\n  HashMap<int64_t, const nlohmann::json&> image_id2image_;\n  HashMap<int64_t, const nlohmann::json&> anno_id2anno_;\n  HashMap<int64_t, std::vector<int64_t>> image_id2anno_ids_;\n  HashMap<int32_t, int32_t> category_id2contiguous_id_;\n};\n\ntemplate<typename T>\nstd::vector<T> COCOMeta::GetBboxVec(int64_t index) const {\n  std::vector<T> bbox_vec;\n  int64_t image_id = image_ids_.at(index);\n  const auto& anno_ids = image_id2anno_ids_.at(image_id);\n  for (int64_t anno_id : anno_ids) {\n    const auto& bbox_json = anno_id2anno_.at(anno_id)[\"bbox\"];\n    CHECK(bbox_json.is_array());\n    CHECK_EQ(bbox_json.size(), 4);\n    // COCO bounding box format is [left, top, width, height]\n    // we need format xyxy\n    const T alginment = static_cast<T>(1);\n    const T min_size = static_cast<T>(0);\n    T left = bbox_json[0].get<T>();\n    T 
top = bbox_json[1].get<T>();\n    T width = bbox_json[2].get<T>();\n    T height = bbox_json[3].get<T>();\n    T right = left + std::max(width - alignment, min_size);\n    T bottom = top + std::max(height - alignment, min_size);\n    // clip to image\n    int32_t image_height = GetImageHeight(index);\n    int32_t image_width = GetImageWidth(index);\n    left = std::min(std::max(left, min_size), image_width - alignment);\n    top = std::min(std::max(top, min_size), image_height - alignment);\n    right = std::min(std::max(right, min_size), image_width - alignment);\n    bottom = std::min(std::max(bottom, min_size), image_height - alignment);\n    // ensure bbox is not empty\n    if (right > left && bottom > top) {\n      bbox_vec.insert(bbox_vec.end(), {left, top, right, bottom});\n    }\n  }\n  return bbox_vec;\n}\n\ntemplate<typename T>\nstd::vector<T> COCOMeta::GetLabelVec(int64_t index) const {\n  std::vector<T> label_vec;\n  int64_t image_id = image_ids_.at(index);\n  const auto& anno_ids = image_id2anno_ids_.at(image_id);\n  for (int64_t anno_id : anno_ids) {\n    int32_t category_id = anno_id2anno_.at(anno_id)[\"category_id\"].get<int32_t>();\n    label_vec.emplace_back(category_id2contiguous_id_.at(category_id));\n  }\n  return label_vec;\n}\n\ntemplate<typename T>\nvoid COCOMeta::ReadSegmentationsToTensorBuffer(int64_t index, TensorBuffer* segm,\n                                               TensorBuffer* segm_index) const {\n  if (segm == nullptr || segm_index == nullptr) { return; }\n  int64_t image_id = image_ids_.at(index);\n  const auto& anno_ids = image_id2anno_ids_.at(image_id);\n  std::vector<T> segm_vec;\n  for (int64_t anno_id : anno_ids) {\n    const auto& segm_json = anno_id2anno_.at(anno_id)[\"segmentation\"];\n    if (!segm_json.is_array()) { continue; }\n    for (const auto& poly_json : segm_json) {\n      CHECK(poly_json.is_array());\n      for (const auto& elem : poly_json) { segm_vec.emplace_back(elem.get<T>()); }\n    }\n  }\n  CHECK_EQ(segm_vec.size() % 2, 0);\n  int64_t num_pts = segm_vec.size() / 2;\n  segm->Resize(Shape({num_pts, 2}), GetDataType<T>::value);\n  std::copy(segm_vec.begin(), segm_vec.end(), segm->mut_data<T>());\n\n  segm_index->Resize(Shape({num_pts, 3}), DataType::kInt32);\n  int32_t* index_ptr = segm_index->mut_data<int32_t>();\n  int i = 0;\n  int32_t segm_idx = 0;\n  for (int64_t anno_id : anno_ids) {\n    const auto& segm_json = anno_id2anno_.at(anno_id)[\"segmentation\"];\n    CHECK(segm_json.is_array());\n    FOR_RANGE(int32_t, poly_idx, 0, segm_json.size()) {\n      const auto& poly_json = segm_json[poly_idx];\n      CHECK(poly_json.is_array());\n      CHECK_EQ(poly_json.size() % 2, 0);\n      FOR_RANGE(int32_t, pt_idx, 0, poly_json.size() / 2) {\n        index_ptr[i * 3 + 0] = pt_idx;\n        index_ptr[i * 3 + 1] = poly_idx;\n        index_ptr[i * 3 + 2] = segm_idx;\n        i += 1;\n      }\n    }\n    segm_idx += 1;\n  }\n  CHECK_EQ(i, num_pts);\n}\n\n}  // namespace data\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_COCO_DATA_READER_H_\n"
  },
  {
    "path": "oneflow/user/data/coco_dataset.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/data/coco_dataset.h\"\n#include \"oneflow/user/data/coco_data_reader.h\"\n#include \"oneflow/core/persistence/file_system.h\"\n#include \"oneflow/core/persistence/persistent_in_stream.h\"\n\nnamespace oneflow {\nnamespace data {\n\nCOCODataset::BatchType COCODataset::At(int64_t index) const {\n  BatchType batch;\n  batch.push_back(COCOImage());\n  auto& sample = batch.back();\n  sample.index = index;\n  sample.id = meta_->GetImageId(index);\n  sample.height = meta_->GetImageHeight(index);\n  sample.width = meta_->GetImageWidth(index);\n  const std::string& image_file_path = meta_->GetImageFilePath(index);\n  PersistentInStream in_stream(session_id_, DataFS(), image_file_path);\n  int64_t file_size = DataFS()->GetFileSize(image_file_path);\n  sample.data.Resize(Shape({file_size}), DataType::kChar);\n  CHECK_EQ(in_stream.ReadFully(sample.data.mut_data<char>(), sample.data.nbytes()), 0);\n  return batch;\n}\n\nsize_t COCODataset::Size() const { return meta_->Size(); }\n\n}  // namespace data\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/data/coco_dataset.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_COCO_DATASET_H_\n#define ONEFLOW_USER_DATA_COCO_DATASET_H_\n\n#include \"oneflow/user/data/dataset.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n\nnamespace oneflow {\nnamespace data {\n\nstruct COCOImage {\n  TensorBuffer data;\n  int64_t index;\n  int64_t id;\n  int32_t height;\n  int32_t width;\n};\n\nclass COCOMeta;\n\nclass COCODataset final : public RandomAccessDataset<COCOImage> {\n public:\n  using Base = RandomAccessDataset<COCOImage>;\n  using SampleType = typename Base::SampleType;\n  using BatchType = typename Base::BatchType;\n\n  COCODataset(user_op::KernelInitContext* ctx, const std::shared_ptr<const COCOMeta>& meta)\n      : meta_(meta), session_id_(ctx->Attr<int64_t>(\"session_id\")) {}\n  ~COCODataset() = default;\n\n  BatchType At(int64_t index) const override;\n  size_t Size() const override;\n\n private:\n  std::shared_ptr<const COCOMeta> meta_;\n  int64_t session_id_;\n};\n\n}  // namespace data\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_COCO_DATASET_H_\n"
  },
  {
    "path": "oneflow/user/data/coco_parser.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/data/coco_parser.h\"\n#include \"oneflow/user/data/coco_data_reader.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n\nnamespace oneflow {\nnamespace data {\n\nvoid COCOParser::Parse(BatchType& batch_data, user_op::KernelComputeContext* ctx) {\n  user_op::Tensor* image_tensor = ctx->Tensor4ArgNameAndIndex(\"image\", 0);\n  CHECK_NOTNULL(image_tensor);\n  user_op::Tensor* image_id_tensor = ctx->Tensor4ArgNameAndIndex(\"image_id\", 0);\n  user_op::Tensor* image_size_tensor = ctx->Tensor4ArgNameAndIndex(\"image_size\", 0);\n  user_op::Tensor* bbox_tensor = ctx->Tensor4ArgNameAndIndex(\"gt_bbox\", 0);\n  user_op::Tensor* label_tensor = ctx->Tensor4ArgNameAndIndex(\"gt_label\", 0);\n  user_op::Tensor* segm_tensor = ctx->Tensor4ArgNameAndIndex(\"gt_segm\", 0);\n  user_op::Tensor* segm_index_tensor = ctx->Tensor4ArgNameAndIndex(\"gt_segm_index\", 0);\n\n  MultiThreadLoop(batch_data.size(), [&](size_t i) {\n    TensorBuffer* image_buffer = image_tensor->mut_dptr<TensorBuffer>() + i;\n    COCOImage& image = batch_data[i];\n    image_buffer->Swap(image.data);\n    if (image_size_tensor) {\n      auto* image_size_ptr = image_size_tensor->mut_dptr<int32_t>() + i * 2;\n      image_size_ptr[0] = meta_->GetImageHeight(image.index);\n      image_size_ptr[1] = meta_->GetImageWidth(image.index);\n    }\n    if (image_id_tensor) {\n      auto* image_id_ptr = image_id_tensor->mut_dptr<int64_t>();\n      image_id_ptr[i] = image.id;\n    }\n    if (bbox_tensor) {\n      TensorBuffer* bbox_buffer = bbox_tensor->mut_dptr<TensorBuffer>() + i;\n      const auto& bbox_vec = meta_->GetBboxVec<float>(image.index);\n      CHECK_EQ(bbox_vec.size() % 4, 0);\n      int64_t num_bboxes = bbox_vec.size() / 4;\n      bbox_buffer->Resize(Shape({num_bboxes, 4}), DataType::kFloat);\n      std::copy(bbox_vec.begin(), bbox_vec.end(), bbox_buffer->mut_data<float>());\n    }\n    if (label_tensor) {\n      TensorBuffer* label_buffer = label_tensor->mut_dptr<TensorBuffer>() + i;\n      const auto& label_vec = meta_->GetLabelVec<int32_t>(image.index);\n      label_buffer->Resize(Shape({static_cast<int64_t>(label_vec.size())}), DataType::kInt32);\n      std::copy(label_vec.begin(), label_vec.end(), label_buffer->mut_data<int32_t>());\n    }\n    if (segm_tensor && segm_index_tensor) {\n      TensorBuffer* segm_buffer = segm_tensor->mut_dptr<TensorBuffer>() + i;\n      TensorBuffer* segm_index_buffer = segm_index_tensor->mut_dptr<TensorBuffer>() + i;\n      meta_->ReadSegmentationsToTensorBuffer<float>(image.index, segm_buffer, segm_index_buffer);\n    }\n  });\n  // dynamic batch size\n  if (image_tensor->shape_view().elem_cnt() != batch_data.size()) {\n    CHECK_EQ(image_tensor->shape_view().NumAxes(), 1);\n    image_tensor->mut_shape_view().Set(0, batch_data.size());\n  }\n  if (image_id_tensor && image_id_tensor->shape_view().At(0) != batch_data.size()) {\n    
image_id_tensor->mut_shape_view().Set(0, batch_data.size());\n  }\n  if (image_size_tensor && image_size_tensor->shape_view().At(0) != batch_data.size()) {\n    image_size_tensor->mut_shape_view().Set(0, batch_data.size());\n  }\n  if (bbox_tensor && bbox_tensor->shape_view().elem_cnt() != batch_data.size()) {\n    CHECK_EQ(bbox_tensor->shape_view().NumAxes(), 1);\n    bbox_tensor->mut_shape_view().Set(0, batch_data.size());\n  }\n  if (label_tensor && label_tensor->shape_view().elem_cnt() != batch_data.size()) {\n    CHECK_EQ(label_tensor->shape_view().NumAxes(), 1);\n    label_tensor->mut_shape_view().Set(0, batch_data.size());\n  }\n  if (segm_tensor && segm_index_tensor\n      && segm_tensor->shape_view().elem_cnt() != batch_data.size()) {\n    CHECK_EQ(segm_tensor->shape_view().NumAxes(), 1);\n    CHECK_EQ(segm_index_tensor->shape_view().NumAxes(), 1);\n    CHECK_EQ(segm_tensor->shape_view().elem_cnt(), segm_index_tensor->shape_view().elem_cnt());\n    segm_tensor->mut_shape_view().Set(0, batch_data.size());\n    segm_index_tensor->mut_shape_view().Set(0, batch_data.size());\n  }\n}\n\n}  // namespace data\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/data/coco_parser.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_COCO_PARSER_H_\n#define ONEFLOW_USER_DATA_COCO_PARSER_H_\n\n#include \"oneflow/user/data/parser.h\"\n#include \"oneflow/user/data/coco_dataset.h\"\n\nnamespace oneflow {\nnamespace data {\n\nclass COCOMeta;\n\nclass COCOParser final : public Parser<COCOImage> {\n public:\n  using Base = Parser<COCOImage>;\n  using SampleType = typename Base::SampleType;\n  using BatchType = typename Base::BatchType;\n\n  COCOParser(const std::shared_ptr<const COCOMeta>& meta) : meta_(meta){};\n  ~COCOParser() = default;\n\n  void Parse(BatchType& batch_data, user_op::KernelComputeContext* ctx) override;\n\n private:\n  std::shared_ptr<const COCOMeta> meta_;\n};\n\n}  // namespace data\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_COCO_PARSER_H_\n"
  },
  {
    "path": "oneflow/user/data/data_reader.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_DATA_READER_H_\n#define ONEFLOW_USER_DATA_DATA_READER_H_\n\n#include \"oneflow/user/data/dataset.h\"\n#include \"oneflow/user/data/parser.h\"\n#include \"oneflow/core/common/buffer.h\"\n\nnamespace oneflow {\n\nnamespace data {\n\nstatic const int32_t kDataReaderBatchBufferSize = 4;\n\ntemplate<typename LoadTarget>\nclass DataReader {\n public:\n  using SampleType = LoadTarget;\n  using BatchType = std::vector<SampleType>;\n\n  DataReader(user_op::KernelInitContext* ctx)\n      : is_closed_(false), batch_buffer_(kDataReaderBatchBufferSize) {}\n\n  virtual ~DataReader() {\n    Close();\n    if (load_thrd_.joinable()) { load_thrd_.join(); }\n  }\n\n  void Read(user_op::KernelComputeContext* ctx) {\n    CHECK(load_thrd_.joinable()) << \"You should call StartLoadThread before read data\";\n    auto batch = FetchBatchData();\n    parser_->Parse(batch, ctx);\n  }\n\n  void Close() {\n    if (!is_closed_.load()) {\n      is_closed_.store(true);\n      batch_buffer_.Close();\n    }\n  }\n\n protected:\n  void StartLoadThread() {\n    if (load_thrd_.joinable()) { return; }\n    load_thrd_ = std::thread([this] {\n      while (!is_closed_.load() && LoadBatch()) {}\n    });\n  }\n\n  std::unique_ptr<Dataset<LoadTarget>> loader_;\n  std::unique_ptr<Parser<LoadTarget>> parser_;\n\n private:\n  BatchType FetchBatchData() {\n    BatchType batch;\n    CHECK_EQ(batch_buffer_.Pull(&batch), BufferStatus::kBufferStatusSuccess);\n    return batch;\n  }\n\n  bool LoadBatch() {\n    BatchType batch = loader_->Next();\n    return batch_buffer_.Push(std::move(batch)) == BufferStatus::kBufferStatusSuccess;\n  }\n\n  std::atomic<bool> is_closed_;\n  Buffer<BatchType> batch_buffer_;\n  std::thread load_thrd_;\n};\n\n}  // namespace data\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_DATA_READER_H_\n"
  },
  {
    "path": "oneflow/user/data/dataset.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_DATASET_H_\n#define ONEFLOW_USER_DATA_DATASET_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n\nnamespace oneflow {\nnamespace data {\n\nstatic constexpr int kOneflowDatasetSeed = 524287;\n\ntemplate<typename LoadTarget>\nclass Dataset {\n public:\n  using SampleType = LoadTarget;\n  using BatchType = std::vector<SampleType>;\n\n  Dataset() = default;\n  virtual ~Dataset() = default;\n\n  virtual BatchType Next() = 0;\n};\n\ntemplate<typename LoadTarget>\nclass RandomAccessDataset : public Dataset<LoadTarget> {\n public:\n  using Base = Dataset<LoadTarget>;\n  using SampleType = typename Base::SampleType;\n  using BatchType = typename Base::BatchType;\n\n  RandomAccessDataset() : cur_idx_(0) {}\n  virtual ~RandomAccessDataset() = default;\n\n  virtual BatchType At(int64_t index) const = 0;\n  virtual size_t Size() const = 0;\n\n  BatchType Next() final {\n    BatchType ret = this->At(cur_idx_);\n    cur_idx_ += 1;\n    if (cur_idx_ >= this->Size()) { cur_idx_ %= this->Size(); }\n    return ret;\n  }\n\n private:\n  int64_t cur_idx_;\n};\n\n}  // namespace data\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_DATASET_H_\n"
  },
  {
    "path": "oneflow/user/data/distributed_training_dataset.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_DISTRIBUTED_TRAINING_DATASET_H_\n#define ONEFLOW_USER_DATA_DISTRIBUTED_TRAINING_DATASET_H_\n\n#include \"oneflow/user/data/dataset.h\"\n\nnamespace oneflow {\nnamespace data {\n\ntemplate<typename LoadTarget>\nclass DistributedTrainingDataset final : public Dataset<LoadTarget> {\n public:\n  using Base = Dataset<LoadTarget>;\n  using SampleType = typename Base::SampleType;\n  using BatchType = typename Base::BatchType;\n  using NestedDS = RandomAccessDataset<LoadTarget>;\n\n  DistributedTrainingDataset(int64_t parallel_num, int64_t parallel_id, bool stride_partition,\n                             bool shuffle, int64_t random_seed, std::unique_ptr<NestedDS>&& dataset)\n      : nested_ds_(std::move(dataset)),\n        shuffle_(shuffle),\n        stride_partition_(stride_partition),\n        rnd_seed_(random_seed),\n        num_shards_(parallel_num),\n        pos_(0),\n        pos_in_shard_(0),\n        epoch_cnt_(0) {\n    shard_size_ = std::ceil(static_cast<float>(nested_ds_->Size()) / num_shards_);\n    if (stride_partition) {\n      pos_ = parallel_id;\n    } else {\n      pos_ = parallel_id * shard_size_;\n    }\n    index_seq_.resize(nested_ds_->Size());\n    std::iota(index_seq_.begin(), index_seq_.end(), 0);\n    GenNewIndexSequence();\n  }\n  virtual ~DistributedTrainingDataset() = default;\n\n  virtual BatchType Next() override {\n    // There are 2 partition strategies\n    // assume epoch size is 10, index seq don't shuffle and there are 4 parts\n    // stride partition strategy (when stride_partition is true):\n    //       |  part1   |  part2   |  part3   |  part4   |\n    // iter0 | 0, 4, 8, | 1, 5, 9, | 2, 6, 0, | 3, 7, 1, |\n    // iter1 | 2, 6, 0, | 3, 7, 1, | 4, 8, 2, | 5, 9, 3, |\n    // contiguous partition strategy (when stride_partition is false):\n    //       |  part1   |  part2   |  part3   |  part4   |\n    // iter0 | 0, 1, 2, | 3, 4, 5, | 6, 7, 8, | 9, 0, 1, |\n    // iter1 | 2, 3, 4, | 5, 6, 7, | 8, 9, 0, | 1, 2, 3, |\n    BatchType batch = nested_ds_->At(index_seq_.at(pos_));\n    if (stride_partition_) {\n      pos_ += num_shards_;\n    } else {\n      pos_ += 1;\n      pos_in_shard_ += 1;\n      if (pos_in_shard_ == shard_size_) {\n        pos_ += (num_shards_ - 1) * shard_size_;\n        pos_in_shard_ = 0;\n      }\n    }\n    CheckRanOutOfSize();\n    return batch;\n  }\n\n private:\n  void CheckRanOutOfSize() {\n    if (pos_ >= index_seq_.size()) {\n      GenNewIndexSequence();\n      pos_ %= index_seq_.size();\n    }\n  }\n\n  void GenNewIndexSequence() {\n    if (shuffle_) {\n      std::mt19937 engine(rnd_seed_ + epoch_cnt_);\n      std::shuffle(index_seq_.begin(), index_seq_.end(), engine);\n    }\n    epoch_cnt_ += 1;\n  }\n\n  std::unique_ptr<NestedDS> nested_ds_;\n  bool shuffle_;\n  bool stride_partition_;\n  int64_t rnd_seed_;\n  int64_t num_shards_;\n  int64_t shard_size_;\n  int64_t pos_;\n  int64_t 
pos_in_shard_;\n  int64_t epoch_cnt_;\n  std::vector<int64_t> index_seq_;\n};\n\n}  // namespace data\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_DISTRIBUTED_TRAINING_DATASET_H_\n"
  },
  {
    "path": "oneflow/user/data/distributed_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_DISTRIBUTED_UTIL_H_\n#define ONEFLOW_USER_DATA_DISTRIBUTED_UTIL_H_\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\n\nnamespace data {\n\ninline Maybe<void> InitDataSourceDistributedInfo(user_op::KernelInitContext* ctx,\n                                                 size_t& world_size, int64_t& rank) {\n  auto nd_sbp_str_vec = ctx->Attr<std::vector<std::string>>(\"nd_sbp\");\n  if (nd_sbp_str_vec.empty()) {\n    world_size = GlobalProcessCtx::WorldSize();\n    rank = GlobalProcessCtx::Rank();\n  } else {\n    const Shape& hierarchy = *ctx->parallel_desc().hierarchy();\n    CHECK_EQ_OR_RETURN(hierarchy.NumAxes(), nd_sbp_str_vec.size());\n    rank = 0;\n    world_size = 1;\n\n    using index_helper_t = NdIndexOffsetHelper<int64_t, SHAPE_MAX_AXIS_SIZE>;\n    index_helper_t index_helper(hierarchy.dim_vec().data(), hierarchy.NumAxes());\n    int64_t nd_index[SHAPE_MAX_AXIS_SIZE] = {0};\n    index_helper.OffsetToNdIndex(ctx->parallel_ctx().parallel_id(), nd_index);\n\n    for (int i = hierarchy.NumAxes() - 1; i >= 0; --i) {\n      SbpParallel sbp;\n      CHECK_OR_RETURN(ParseSbpParallelFromString(nd_sbp_str_vec[i], &sbp));\n      if (sbp.has_split_parallel()) {\n        rank += nd_index[i] * world_size;\n        world_size *= hierarchy.At(i);\n      }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace data\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CUSTOMIZED_DATA_ONEREC_DATA_READER_H_\n"
  },
  {
    "path": "oneflow/user/data/gpt_dataset.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/data/gpt_dataset.h\"\n\n#ifdef __linux__\n#include <fcntl.h>\n#include <stdio.h>\n#include <errno.h>\n#include <sys/stat.h>\n#include <sys/mman.h>\n#endif\n\nnamespace oneflow {\n\nnamespace data {\n\nnamespace {\n\nvoid GetSplitDocIndices(std::vector<size_t>* doc_indices, const std::vector<int64_t>& split_sizes,\n                        size_t split_index, size_t num_docs) {\n  CHECK_LT(split_index, split_sizes.size());\n  size_t total_size = 0;\n  FOR_RANGE(size_t, i, 0, split_sizes.size()) { total_size += split_sizes[i]; }\n\n  size_t split_offset = 0;\n  RoundModeGuard round_guard(FE_TONEAREST);\n  FOR_RANGE(size_t, i, 0, split_index) {\n    float ratio = static_cast<float>(split_sizes[i]) / total_size;\n    size_t split_size = static_cast<size_t>(std::nearbyint(ratio * num_docs));\n    split_offset += split_size;\n  }\n\n  float ratio = static_cast<float>(split_sizes[split_index]) / total_size;\n  size_t split_size = static_cast<size_t>(std::nearbyint(ratio * num_docs));\n  doc_indices->resize(split_size);\n  std::iota(doc_indices->begin(), doc_indices->end(), split_offset);\n}\n\nsize_t GetNumEpochs(size_t num_samples, size_t seq_length, size_t tokens_per_epoch) {\n  // num_epochs * tokens_per_epoch >= num_samples * seq_length + 1\n  // +1 is because we need to retrieve seq_length + 1 token each time\n  // but the last token will overlap with the first token of the next\n  // sample except for the last sample.\n  return static_cast<size_t>(\n      std::ceil(static_cast<double>(num_samples * seq_length + 1) / tokens_per_epoch));\n}\n\nsize_t GetNumCompleteEpochs(size_t num_samples, size_t seq_length, size_t tokens_per_epoch) {\n  size_t num_epochs = GetNumEpochs(num_samples, seq_length, tokens_per_epoch);\n  if (num_epochs == 1) { return 1; }\n  size_t num_samples_per_epoch =\n      static_cast<size_t>(std::floor(static_cast<double>(tokens_per_epoch - 1) / seq_length));\n  size_t num_samples_exclude_last_epoch = static_cast<size_t>(\n      std::floor(static_cast<double>((num_epochs - 1) * tokens_per_epoch - 1) / seq_length));\n  CHECK_LE(num_samples_exclude_last_epoch, num_samples);\n  size_t last_epoch_num_samples = num_samples - num_samples_exclude_last_epoch;\n  CHECK_LT(last_epoch_num_samples, num_samples_per_epoch);\n\n  bool separate_last_epoch =\n      last_epoch_num_samples < static_cast<size_t>(0.8f * num_samples_per_epoch);\n  return separate_last_epoch ? 
(num_epochs - 1) : num_epochs;\n}\n\n}  // namespace\n\nconstexpr char MegatronGPTIndex::kMagicCode[];\n\nMegatronGPTIndex::MegatronGPTIndex(const std::string& index_file_path) {\n  auto start = std::chrono::system_clock::now();\n  std::ifstream stream(index_file_path, std::ios::binary);\n  CHECK(stream.is_open()) << \"can't open dataset index file \" << index_file_path;\n  // verify magic code\n  char magic_code[kMagicCodeLen];\n  stream.read(magic_code, kMagicCodeLen);\n  CHECK_EQ(std::memcmp(magic_code, kMagicCode, kMagicCodeLen), 0);\n  // read version\n  stream.read(reinterpret_cast<char*>(&version_), sizeof(version_));\n  // read dtype\n  stream.read(&dtype_code_, sizeof(dtype_code_));\n  // read size of sizes and doc_offsets\n  uint64_t sizes_size = 0;\n  stream.read(reinterpret_cast<char*>(&sizes_size), sizeof(sizes_size));\n  uint64_t doc_offsets_size = 0;\n  stream.read(reinterpret_cast<char*>(&doc_offsets_size), sizeof(doc_offsets_size));\n  // NOTE: this check is not necessary\n  CHECK_EQ(sizes_size + 1, doc_offsets_size);\n  // read sizes\n  sizes_.resize(sizes_size);\n  stream.read(reinterpret_cast<char*>(sizes_.data()),\n              sizeof(decltype(sizes_)::value_type) * sizes_.size());\n  // read addresses\n  addresses_.resize(sizes_size);\n  stream.read(reinterpret_cast<char*>(addresses_.data()),\n              sizeof(decltype(addresses_)::value_type) * addresses_.size());\n  // read doc_offsets\n  doc_offsets_.resize(doc_offsets_size);\n  stream.read(reinterpret_cast<char*>(doc_offsets_.data()),\n              sizeof(decltype(doc_offsets_)::value_type) * doc_offsets_.size());\n  // check eof\n  int pos = stream.tellg();\n  stream.seekg(0, std::ios_base::end);\n  CHECK_EQ(pos, stream.tellg());\n  // log\n  std::chrono::duration<double, std::milli> elapse = std::chrono::system_clock::now() - start;\n  VLOG(2) << \"Load GPT Dataset index file succeeded, file_path: \" << index_file_path\n          << \", number of documents: \" << this->num_docs() << \", elapsed time: \" << elapse.count()\n          << \" ms\";\n}\n\nMappedBuffer::MappedBuffer(const std::string& filename) : mapped_(nullptr), size_(0) {\n#ifdef __linux__\n  int fd = open(filename.c_str(), O_RDONLY);\n  CHECK(fd != -1) << \"open \" << filename << \" failed: \" << strerror(errno);\n\n  struct stat s;\n  CHECK(fstat(fd, &s) != -1) << \"stat \" << filename << \" failed: \" << strerror(errno);\n  size_ = s.st_size;\n\n  mapped_ = mmap(nullptr, size_, PROT_READ, MAP_PRIVATE, fd, 0);\n  CHECK(mapped_ != MAP_FAILED) << \"mmap \" << filename << \" failed: \" << strerror(errno);\n\n  close(fd);\n#endif\n}\n\nMappedBuffer::~MappedBuffer() {\n#ifdef __linux__\n  CHECK(munmap(mapped_, size_) == 0) << \"munmap failed\";\n#endif\n}\n\nMegatronGPTMMapDataset::MegatronGPTMMapDataset(const std::string& data_file_prefix, size_t seq_len,\n                                               size_t label_len, size_t num_samples,\n                                               const std::vector<int64_t>& split_sizes,\n                                               size_t split_index, bool shuffle, uint32_t seed)\n    : seq_len_(seq_len),\n      sample_len_(seq_len + label_len),\n      num_samples_(num_samples),\n      shuffle_(shuffle),\n      seed_(seed),\n      gen_(seed) {\n  auto start = std::chrono::system_clock::now();\n  index_ = std::make_unique<const MegatronGPTIndex>(data_file_prefix + \".idx\");\n  data_ = std::make_unique<const MappedBuffer>(data_file_prefix + \".bin\");\n  dtype_size_ = 
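/* bytes per token, looked up from the dtype code stored in the .idx header */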
kDTypeCode2Size.at(index_->dtype_code());\n  std::vector<size_t> epoch_doc_indices;\n  GetSplitDocIndices(&epoch_doc_indices, split_sizes, split_index, index_->num_docs());\n  tokens_per_epoch_ = GetEpochNumTokens(epoch_doc_indices);\n  num_epochs_ = GetNumEpochs(num_samples_, seq_len_, tokens_per_epoch_);\n  num_complete_epochs_ = GetNumCompleteEpochs(num_samples_, seq_len_, tokens_per_epoch_);\n  InitDocIndices(epoch_doc_indices, num_epochs_, num_complete_epochs_);\n  size_t total_num_samples = static_cast<size_t>(\n      std::floor(static_cast<double>(num_epochs_ * tokens_per_epoch_ - 1) / seq_len_));\n  InitSampleIndices(total_num_samples);\n  InitShuffleIndices(sample_indices_.size());\n  std::chrono::duration<double, std::milli> elapse = std::chrono::system_clock::now() - start;\n  VLOG(2) << \"Create GPT Dataset succeeded, sequence length: \" << seq_len_\n          << \", number of samples: \" << num_samples_\n          << \", total number of samples: \" << shuffle_indices_.size()\n          << \", total number of documents: \" << doc_indices_.size()\n          << \", number of epochs: \" << num_epochs_\n          << \", number of complete epochs: \" << num_complete_epochs_\n          << \", shuffle: \" << std::boolalpha << shuffle_ << \", random_seed: \" << seed_\n          << \", elapsed time: \" << elapse.count() << \" ms\";\n}\n\nsize_t MegatronGPTMMapDataset::GetEpochNumTokens(const std::vector<size_t>& doc_indices) const {\n  size_t num_tokens = 0;\n  for (auto doc_index : doc_indices) { num_tokens += index_->doc_length(doc_index); }\n  return num_tokens;\n}\n\nvoid MegatronGPTMMapDataset::InitDocIndices(const std::vector<size_t>& epoch_doc_indices,\n                                            size_t num_epochs, size_t num_complete_epochs) {\n  doc_indices_.reserve(epoch_doc_indices.size() * num_epochs);\n  InitDocIndices(epoch_doc_indices, num_complete_epochs);\n  if (num_epochs != num_complete_epochs) {\n    CHECK_EQ(num_complete_epochs + 1, num_epochs);\n    InitDocIndices(epoch_doc_indices, 1);\n  }\n}\n\nvoid MegatronGPTMMapDataset::InitDocIndices(const std::vector<size_t>& epoch_doc_indices,\n                                            size_t num_epochs) {\n  auto start = std::distance(doc_indices_.cbegin(), doc_indices_.cend());\n  FOR_RANGE(size_t, i, 0, num_epochs) {\n    doc_indices_.insert(doc_indices_.end(), epoch_doc_indices.cbegin(), epoch_doc_indices.cend());\n  }\n  if (shuffle_) { std::shuffle(doc_indices_.begin() + start, doc_indices_.end(), gen_); }\n}\n\nvoid MegatronGPTMMapDataset::InitSampleIndices(size_t total_num_samples) {\n  sample_indices_.reserve(total_num_samples);\n  size_t doc_indices_idx = 0;\n  size_t doc_offset = 0;\n  FOR_RANGE(size_t, i, 0, total_num_samples) {\n    if (doc_indices_idx >= doc_indices_.size()) { break; }\n    sample_indices_.emplace_back(doc_indices_idx, doc_offset);\n    int remaining_tokens = seq_len_;\n    while (remaining_tokens > 0) {\n      CHECK_LT(doc_indices_idx, doc_indices_.size());\n      size_t doc_len = index_->doc_length(doc_indices_[doc_indices_idx]);\n      CHECK_LT(doc_offset, doc_len);\n      doc_len -= doc_offset;\n      if (remaining_tokens < doc_len) {\n        // move offset inside doc\n        doc_offset += remaining_tokens;\n      } else {\n        // move to next doc\n        doc_indices_idx += 1;\n        doc_offset = 0;\n      }\n      remaining_tokens -= doc_len;\n    }\n  }\n  CHECK_EQ(sample_indices_.size(), total_num_samples);\n  CHECK_GE(sample_indices_.size(), num_samples_);\n}\n\nvoid 
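// Samples from the complete epochs are shuffled as one block; the trailing\n// incomplete epoch (if any) is shuffled separately, mirroring Megatron-LM.\n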
MegatronGPTMMapDataset::InitShuffleIndices(size_t total_num_samples) {\n  shuffle_indices_.resize(total_num_samples);\n  std::iota(shuffle_indices_.begin(), shuffle_indices_.end(), 0);\n  if (shuffle_) {\n    size_t num_samples = static_cast<size_t>(\n        std::floor(static_cast<double>(num_complete_epochs_ * tokens_per_epoch_ - 1) / seq_len_));\n    CHECK_LE(num_samples, shuffle_indices_.size());\n    std::shuffle(shuffle_indices_.begin(), shuffle_indices_.begin() + num_samples, gen_);\n    if (num_complete_epochs_ != num_epochs_) {\n      std::shuffle(shuffle_indices_.begin() + num_samples, shuffle_indices_.end(), gen_);\n    }\n  }\n}\n\nconst HashMap<char, size_t> MegatronGPTMMapDataset::kDTypeCode2Size = {\n    {1, 1},  // DataType::kUInt8\n    {2, 1},  // DataType::kInt8\n    {3, 2},  // DataType::kInt16\n    {4, 4},  // DataType::kInt32\n    {5, 8},  // DataType::kInt64\n    {6, 4},  // DataType::kFloat\n    {7, 8},  // DataType::kDouble\n    {8, 2},  // DataType::kUInt16\n};\n\n}  // namespace data\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/data/gpt_dataset.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_GPT_DATASET_H_\n#define ONEFLOW_USER_DATA_GPT_DATASET_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace data {\n\nclass MegatronGPTIndex final {\n public:\n  MegatronGPTIndex(const std::string& index_file);\n  ~MegatronGPTIndex() = default;\n\n  static constexpr char kMagicCode[] = \"MMIDIDX\\x00\\x00\";\n  static constexpr size_t kMagicCodeLen = sizeof(kMagicCode) - 1;\n\n  uint64_t version() const { return version_; }\n  char dtype_code() const { return dtype_code_; }\n  size_t num_docs() const { return sizes_.size(); }\n  size_t doc_length(size_t doc_index) const { return sizes_.at(doc_index); }\n  size_t doc_offset(size_t doc_index) const { return doc_offsets_.at(doc_index); }\n  size_t address(size_t doc_index) const { return addresses_.at(doc_index); }\n\n private:\n  uint64_t version_;\n  char dtype_code_;\n  std::vector<int32_t> sizes_;\n  std::vector<int64_t> addresses_;\n  std::vector<int64_t> doc_offsets_;\n};\n\nclass MappedBuffer final {\n public:\n  MappedBuffer(const std::string& filename);\n  ~MappedBuffer();\n\n  const void* ptr() const { return mapped_; }\n  size_t size() const { return size_; }\n\n private:\n  void* mapped_;\n  size_t size_;\n};\n\nclass MegatronGPTMMapDataset final {\n public:\n  MegatronGPTMMapDataset(const std::string& data_file_prefix, size_t seq_len, size_t label_len,\n                         size_t num_samples, const std::vector<int64_t>& split_sizes,\n                         size_t split_index, bool shuffle, uint32_t seed);\n  OF_DISALLOW_COPY_AND_MOVE(MegatronGPTMMapDataset);\n  ~MegatronGPTMMapDataset() = default;\n\n  template<typename T>\n  void GetSample(size_t index, T* data) const;\n\n private:\n  static const HashMap<char, size_t> kDTypeCode2Size;\n\n  size_t GetEpochNumTokens(const std::vector<size_t>& doc_indices) const;\n  void InitDocIndices(const std::vector<size_t>& epoch_doc_indices, size_t num_epochs,\n                      size_t num_complete_epochs);\n  void InitDocIndices(const std::vector<size_t>& doc_indices, size_t num_epochs);\n  void InitSampleIndices(size_t total_num_samples);\n  void InitShuffleIndices(size_t total_num_samples);\n  template<typename T>\n  void ReadTokens(const void* src, size_t offset, T* dst, size_t size) const;\n\n  // initializer list\n  size_t seq_len_;\n  size_t sample_len_;\n  size_t num_samples_;\n  bool shuffle_;\n  uint32_t seed_;\n  std::mt19937 gen_;\n\n  // initializing in constructor (in order as below)\n  std::unique_ptr<const MegatronGPTIndex> index_;\n  std::unique_ptr<const MappedBuffer> data_;\n  size_t dtype_size_;\n  size_t tokens_per_epoch_;\n  size_t num_epochs_;\n  size_t num_complete_epochs_;\n  std::vector<size_t> doc_indices_;\n  std::vector<std::pair<size_t, size_t>> sample_indices_;\n  std::vector<size_t> shuffle_indices_;\n};\n\ntemplate<typename T>\nvoid MegatronGPTMMapDataset::GetSample(size_t index, T* 
data) const {\n  CHECK_LT(index, shuffle_indices_.size());\n  const size_t sample_index = shuffle_indices_[index];\n  CHECK_LT(sample_index, sample_indices_.size());\n  size_t doc_indices_idx = sample_indices_[sample_index].first;\n  size_t doc_offset = sample_indices_[sample_index].second;\n  int remaining_tokens = sample_len_;\n  while (remaining_tokens > 0) {\n    CHECK_LT(doc_indices_idx, doc_indices_.size());\n    const size_t doc_index = doc_indices_[doc_indices_idx];\n    size_t offset = index_->address(doc_index) + doc_offset * dtype_size_;\n    size_t num_tokens = index_->doc_length(doc_index);\n    CHECK_LT(doc_offset, num_tokens);\n    num_tokens -= doc_offset;\n    if (num_tokens > remaining_tokens) {\n      num_tokens = remaining_tokens;\n    } else {\n      doc_indices_idx += 1;\n      doc_offset = 0;\n    }\n    ReadTokens(data_->ptr(), offset, data, num_tokens);\n    data += num_tokens;\n    remaining_tokens -= num_tokens;\n  }\n  CHECK_EQ(remaining_tokens, 0);\n}\n\ntemplate<typename T>\nvoid MegatronGPTMMapDataset::ReadTokens(const void* src, size_t bytes_offset, T* dst,\n                                        size_t size) const {\n  CHECK_NOTNULL(src);\n  switch (index_->dtype_code()) {\n#define SWITCH_CASE_ENTRY(type_code, type)                                           \\\n  case type_code: {                                                                  \\\n    const auto* src_ptr =                                                            \\\n        reinterpret_cast<const type*>(static_cast<const char*>(src) + bytes_offset); \\\n    std::copy(src_ptr, src_ptr + size, dst);                                         \\\n    break;                                                                           \\\n  }\n\n    SWITCH_CASE_ENTRY(1, uint8_t)\n    SWITCH_CASE_ENTRY(2, int8_t)\n    SWITCH_CASE_ENTRY(3, int16_t)\n    SWITCH_CASE_ENTRY(4, int32_t)\n    SWITCH_CASE_ENTRY(5, int64_t)\n    SWITCH_CASE_ENTRY(6, float)\n    SWITCH_CASE_ENTRY(7, double)\n    SWITCH_CASE_ENTRY(8, uint16_t)\n#undef SWITCH_CASE_ENTRY\n    default: {\n      UNIMPLEMENTED();\n    }\n  }\n}\n\n}  // namespace data\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_GPT_DATASET_H_\n"
  },
  {
    "path": "oneflow/user/data/group_batch_dataset.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_GROUP_BATCH_DATASET_H_\n#define ONEFLOW_USER_DATA_GROUP_BATCH_DATASET_H_\n\n#include \"oneflow/user/data/dataset.h\"\n\nnamespace oneflow {\nnamespace data {\n\ntemplate<typename LoadTarget>\nclass GroupBatchDataset final : public Dataset<LoadTarget> {\n public:\n  using Base = Dataset<LoadTarget>;\n  using SampleType = typename Base::SampleType;\n  using BatchType = typename Base::BatchType;\n  using NestedDS = Dataset<LoadTarget>;\n\n  GroupBatchDataset(size_t batch_size,\n                    const std::function<int64_t(const SampleType&)>& GroupId4Sample,\n                    std::unique_ptr<NestedDS>&& dataset)\n      : nested_ds_(std::move(dataset)),\n        batch_size_(batch_size),\n        group_fn_(GroupId4Sample),\n        order_count_(0) {}\n  ~GroupBatchDataset() = default;\n\n  BatchType Next() override {\n    BatchType batch;\n    int64_t group_id = FindEarliestBatchGroupId();\n    auto group_it = group_id2buffered_samples_.find(group_id);\n    if (group_it != group_id2buffered_samples_.end()) {\n      auto& batch_sample_list = group_it->second;\n      if (!batch_sample_list.empty()) {\n        std::swap(batch, batch_sample_list.front().data);\n        batch_sample_list.pop_front();\n      }\n    }\n    while (batch.size() < batch_size_) {\n      auto next_batch = nested_ds_->Next();\n      CHECK_EQ(next_batch.size(), 1);\n      int64_t next_group_id = group_fn_(next_batch[0]);\n      if (group_id == -1) { group_id = next_group_id; }\n      if (group_id == next_group_id) {\n        batch.emplace_back(std::move(next_batch[0]));\n      } else {\n        auto group_it = group_id2buffered_samples_.find(next_group_id);\n        if (group_it == group_id2buffered_samples_.end()) {\n          group_it =\n              group_id2buffered_samples_.emplace(next_group_id, std::list<BatchSample>()).first;\n        }\n        auto& batch_sample_list = group_it->second;\n        if (batch_sample_list.empty() || batch_sample_list.back().data.size() == batch_size_) {\n          BatchSample batch_sample;\n          std::swap(batch_sample.data, next_batch);\n          batch_sample.data.reserve(batch_size_);\n          batch_sample.order = order_count_++;\n          batch_sample_list.emplace_back(std::move(batch_sample));\n        } else {\n          batch_sample_list.back().data.emplace_back(std::move(next_batch[0]));\n        }\n      }\n    }\n    return batch;\n  }\n\n private:\n  int64_t FindEarliestBatchGroupId() const {\n    int64_t group_id = -1;\n    int64_t min_order = -1;\n    for (const auto& pair : group_id2buffered_samples_) {\n      if (pair.second.size() > 0) {\n        if (min_order == -1 || pair.second.front().order < min_order) {\n          min_order = pair.second.front().order;\n          group_id = pair.first;\n        }\n      }\n    }\n    return group_id;\n  }\n\n  struct BatchSample {\n    BatchType data;\n    int64_t order;\n  };\n\n  
std::unique_ptr<NestedDS> nested_ds_;\n  size_t batch_size_;\n  std::function<int64_t(const SampleType&)> group_fn_;\n  std::map<int64_t, std::list<BatchSample>> group_id2buffered_samples_;\n  int64_t order_count_;\n};\n\n}  // namespace data\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_GROUP_BATCH_DATASET_H_\n"
  },
  {
    "path": "oneflow/user/data/ofrecord_data_reader.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_OFRECORD_DATA_READER_H_\n#define ONEFLOW_USER_DATA_OFRECORD_DATA_READER_H_\n\n#include \"oneflow/user/data/data_reader.h\"\n#include \"oneflow/user/data/ofrecord_dataset.h\"\n#include \"oneflow/user/data/ofrecord_parser.h\"\n#include \"oneflow/user/data/random_shuffle_dataset.h\"\n#include \"oneflow/user/data/batch_dataset.h\"\n\nnamespace oneflow {\nnamespace data {\n\nclass OFRecordDataReader final : public DataReader<TensorBuffer> {\n public:\n  OFRecordDataReader(user_op::KernelInitContext* ctx) : DataReader<TensorBuffer>(ctx) {\n    batch_size_ = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->shape().elem_cnt();\n    if (auto* pool = TensorBufferPool::TryGet()) { pool->IncreasePoolSizeByBase(batch_size_); }\n    loader_.reset(new OFRecordDataset(ctx));\n    if (ctx->Attr<bool>(\"random_shuffle\")) {\n      loader_.reset(new RandomShuffleDataset<TensorBuffer>(ctx, std::move(loader_)));\n    }\n    loader_.reset(new BatchDataset<TensorBuffer>(batch_size_, std::move(loader_)));\n    parser_.reset(new OFRecordParser());\n    StartLoadThread();\n  }\n\n  ~OFRecordDataReader() override {\n    if (auto* pool = TensorBufferPool::TryGet()) { pool->DecreasePoolSizeByBase(batch_size_); }\n  }\n\n protected:\n  using DataReader<TensorBuffer>::loader_;\n  using DataReader<TensorBuffer>::parser_;\n\n private:\n  size_t batch_size_;\n};\n\n}  // namespace data\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_OFRECORD_DATA_READER_H_\n"
  },
  {
    "path": "oneflow/user/data/ofrecord_dataset.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_OFRECORD_DATASET_H_\n#define ONEFLOW_USER_DATA_OFRECORD_DATASET_H_\n\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/persistence/persistent_in_stream.h\"\n#include \"oneflow/core/job/job_set.pb.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/job/env_desc.h\"\n#include \"oneflow/user/data/dataset.h\"\n\nnamespace oneflow {\nnamespace data {\n\nclass OFRecordDataset final : public Dataset<TensorBuffer> {\n public:\n  using Base = Dataset<TensorBuffer>;\n  using SampleType = typename Base::SampleType;\n  using BatchType = typename Base::BatchType;\n\n  OF_DISALLOW_COPY_AND_MOVE(OFRecordDataset);\n\n  OFRecordDataset(user_op::KernelInitContext* ctx) {\n    current_epoch_ = 0;\n    shuffle_after_epoch_ = ctx->Attr<bool>(\"shuffle_after_epoch\");\n\n    // in stream\n    data_part_num_ = ctx->Attr<int32_t>(\"data_part_num\");\n    std::string data_dir = ctx->Attr<std::string>(\"data_dir\");\n    std::string part_name_prefix = ctx->Attr<std::string>(\"part_name_prefix\");\n    int32_t part_name_suffix_length = ctx->Attr<int32_t>(\"part_name_suffix_length\");\n\n    for (int i = 0; i < data_part_num_; ++i) {\n      std::string num = std::to_string(i);\n      int32_t zero_count =\n          std::max(part_name_suffix_length - static_cast<int32_t>(num.length()), 0);\n      data_file_paths_.emplace_back(\n          JoinPath(data_dir, part_name_prefix + std::string(zero_count, '0') + num));\n    }\n\n    bool is_local = false;\n    // NOTE(zwx): OFRecordDataset is used by OFRecordDataReader and\n    // OFRecordImageClassificationDataReader both, the latter has no attr nd_sbp,\n    // so it couldn't work in DDP for now. 
The If condition here could be removed when\n    // OFRecordImageClassificationDataReader had supported DDP (add attr nd_sbp)\n    // or been deprecated.\n    if (ctx->op_type_name() == \"OFRecordReader\") {\n      auto nd_sbp_str_vec = ctx->Attr<std::vector<std::string>>(\"nd_sbp\");\n      // NOTE(zwx): OFRecordDataset is not global since attr nd_sbp is empty,\n      // we assume that it works in DDP\n      if (nd_sbp_str_vec.empty()) { is_local = true; }\n    }\n    if (is_local) {\n      parallel_id_ = GlobalProcessCtx::Rank();\n      parallel_num_ = GlobalProcessCtx::WorldSize();\n    } else {\n      parallel_id_ = ctx->parallel_ctx().parallel_id();\n      parallel_num_ = ctx->parallel_ctx().parallel_num();\n    }\n    CHECK_LE(parallel_num_, data_part_num_);\n    BalancedSplitter bs(data_part_num_, parallel_num_);\n    range_ = bs.At(parallel_id_);\n    std::vector<std::string> local_file_paths = GetLocalFilePaths();\n    in_stream_.reset(\n        new PersistentInStream(DataFS(), local_file_paths, !shuffle_after_epoch_, false));\n  }\n  ~OFRecordDataset() = default;\n\n  BatchType Next() override {\n    BatchType batch;\n    batch.push_back(TensorBuffer());\n    ReadSample(batch.back());\n    return batch;\n  }\n\n private:\n  void ReadSample(TensorBuffer& tensor) {\n    int64_t OFRecord_size = -1;\n    char* size_ptr = reinterpret_cast<char*>(&OFRecord_size);\n    if (in_stream_->ReadFully(size_ptr, sizeof(int64_t)) != 0) {\n      ShuffleAfterEpoch();\n      CHECK_EQ(in_stream_->ReadFully(size_ptr, sizeof(int64_t)), 0);\n    }\n    CHECK_GT(OFRecord_size, 0);\n    tensor.Resize(Shape({OFRecord_size}), DataType::kChar);\n    CHECK_EQ(in_stream_->ReadFully(tensor.mut_data<char>(), OFRecord_size), 0);\n  }\n\n  void ShuffleAfterEpoch() {\n    CHECK(shuffle_after_epoch_);\n    current_epoch_++;  // move to next epoch\n    std::mt19937 g(kOneflowDatasetSeed + current_epoch_);\n    std::shuffle(data_file_paths_.begin(), data_file_paths_.end(), g);\n    std::vector<std::string> local_file_paths = GetLocalFilePaths();\n    in_stream_.reset(new PersistentInStream(DataFS(), local_file_paths, false, false));\n  }\n\n  std::vector<std::string> GetLocalFilePaths() {\n    std::vector<std::string> ret;\n    for (int i = range_.begin(); i < range_.end(); ++i) {\n      ret.emplace_back(data_file_paths_.at(i));\n    }\n    return ret;\n  }\n\n  int32_t current_epoch_;\n  bool shuffle_after_epoch_;\n\n  int32_t data_part_num_;\n  int32_t parallel_id_;\n  int32_t parallel_num_;\n  Range range_;\n  std::vector<std::string> data_file_paths_;\n  std::unique_ptr<PersistentInStream> in_stream_;\n};\n\n}  // namespace data\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_OFRECORD_DATASET_H_\n"
  },
  {
    "path": "oneflow/user/data/ofrecord_image_classification_data_reader.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATA_READER_H_\n#define ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATA_READER_H_\n\n#include \"oneflow/user/data/data_reader.h\"\n#include \"oneflow/user/data/ofrecord_dataset.h\"\n#include \"oneflow/user/data/ofrecord_parser.h\"\n#include \"oneflow/user/data/random_shuffle_dataset.h\"\n#include \"oneflow/user/data/batch_dataset.h\"\n#include \"oneflow/user/data/ofrecord_image_classification_dataset.h\"\n#include \"oneflow/user/data/ofrecord_image_classification_parser.h\"\n\nnamespace oneflow {\n\nnamespace data {\n\nclass OFRecordImageClassificationDataReader final\n    : public DataReader<ImageClassificationDataInstance> {\n public:\n  explicit OFRecordImageClassificationDataReader(user_op::KernelInitContext* ctx)\n      : DataReader<ImageClassificationDataInstance>(ctx) {\n    batch_size_ = ctx->TensorDesc4ArgNameAndIndex(\"image\", 0)->shape().elem_cnt();\n    if (auto* pool = TensorBufferPool::TryGet()) { pool->IncreasePoolSizeByBase(batch_size_); }\n    std::unique_ptr<Dataset<TensorBuffer>> base(new OFRecordDataset(ctx));\n    if (ctx->Attr<bool>(\"random_shuffle\")) {\n      base.reset(new RandomShuffleDataset<TensorBuffer>(ctx, std::move(base)));\n    }\n    loader_.reset(new OFRecordImageClassificationDataset(ctx, std::move(base)));\n\n    loader_.reset(\n        new BatchDataset<ImageClassificationDataInstance>(batch_size_, std::move(loader_)));\n    parser_.reset(new OFRecordImageClassificationParser());\n    StartLoadThread();\n  }\n\n  ~OFRecordImageClassificationDataReader() override {\n    if (auto* pool = TensorBufferPool::TryGet()) { pool->DecreasePoolSizeByBase(batch_size_); }\n  }\n\n protected:\n  using DataReader<ImageClassificationDataInstance>::loader_;\n  using DataReader<ImageClassificationDataInstance>::parser_;\n\n private:\n  size_t batch_size_;\n};\n\n}  // namespace data\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATA_READER_H_\n"
  },
  {
    "path": "oneflow/user/data/ofrecord_image_classification_dataset.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/data/ofrecord_image_classification_dataset.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n#include \"oneflow/user/image/image_util.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n\n#include <opencv2/opencv.hpp>\n\nnamespace oneflow {\n\nnamespace data {\n\nnamespace {\n\nusing DS = OFRecordImageClassificationDataset;\n\nvoid DecodeImageFromOFRecord(const OFRecord& record, const std::string& feature_name,\n                             const std::string& color_space, TensorBuffer* out) {\n  auto image_feature_it = record.feature().find(feature_name);\n  CHECK(image_feature_it != record.feature().end());\n  const Feature& image_feature = image_feature_it->second;\n  CHECK(image_feature.has_bytes_list());\n  CHECK(image_feature.bytes_list().value_size() == 1);\n  const std::string& src_data = image_feature.bytes_list().value(0);\n  cv::Mat image = cv::imdecode(cv::Mat(1, src_data.size(), CV_8UC1, (void*)(src_data.data())),\n                               cv::IMREAD_COLOR);\n  int W = image.cols;\n  int H = image.rows;\n\n  // convert color space\n  if (ImageUtil::IsColor(color_space) && color_space != \"BGR\") {\n    ImageUtil::ConvertColor(\"BGR\", image, color_space, image);\n  }\n\n  CHECK(image.isContinuous());\n  const int c = ImageUtil::IsColor(color_space) ? 
3 : 1;\n  CHECK_EQ(c, image.channels());\n  Shape image_shape({H, W, c});\n  out->Resize(image_shape, DataType::kUInt8);\n  CHECK_EQ(image_shape.elem_cnt(), out->nbytes());\n  CHECK_EQ(image_shape.elem_cnt(), image.total() * image.elemSize());\n  memcpy(out->mut_data<uint8_t>(), image.ptr(), image_shape.elem_cnt());\n}\n\nvoid DecodeLabelFromOFRecord(const OFRecord& record, const std::string& feature_name,\n                             TensorBuffer* out) {\n  auto label_feature_it = record.feature().find(feature_name);\n  CHECK(label_feature_it != record.feature().end());\n  const Feature& label_feature = label_feature_it->second;\n  out->Resize(Shape({1}), DataType::kInt32);\n  if (label_feature.has_int32_list()) {\n    CHECK_EQ(label_feature.int32_list().value_size(), 1);\n    *out->mut_data<int32_t>() = label_feature.int32_list().value(0);\n  } else if (label_feature.has_int64_list()) {\n    CHECK_EQ(label_feature.int64_list().value_size(), 1);\n    *out->mut_data<int32_t>() = label_feature.int64_list().value(0);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nvoid LoadWorker(Dataset<TensorBuffer>* record_dataset,\n                std::vector<std::unique_ptr<Buffer<TensorBuffer>>>* decode_in_buffers) {\n  int64_t thread_idx = 0;\n  bool shutdown = false;\n  while (!shutdown) {\n    auto records = record_dataset->Next();\n    for (auto& record : records) {\n      auto& current_in_buffer = decode_in_buffers->at(thread_idx++);\n      if (thread_idx >= decode_in_buffers->size()) { thread_idx = 0; }\n      auto status = current_in_buffer->Push(std::move(record));\n      if (status == kBufferStatusErrorClosed) {\n        shutdown = true;\n        break;\n      }\n      CHECK(status == kBufferStatusSuccess);\n    }\n  }\n}\n\nvoid DecodeWorker(const std::string& image_feature_name, const std::string& label_feature_name,\n                  const std::string& color_space, Buffer<TensorBuffer>* in_buffer,\n                  Buffer<ImageClassificationDataInstance>* out_buffer) {\n  while (true) {\n    TensorBuffer serialized_record;\n    auto receive_status = in_buffer->Pull(&serialized_record);\n    if (receive_status == kBufferStatusErrorClosed) { break; }\n    CHECK(receive_status == kBufferStatusSuccess);\n    OFRecord record;\n    CHECK(record.ParseFromArray(serialized_record.data<char>(),\n                                serialized_record.shape_view().elem_cnt()));\n    ImageClassificationDataInstance instance;\n    DecodeImageFromOFRecord(record, image_feature_name, color_space, &instance.image);\n    DecodeLabelFromOFRecord(record, label_feature_name, &instance.label);\n    auto send_status = out_buffer->Push(std::move(instance));\n    if (send_status == kBufferStatusErrorClosed) { break; }\n    CHECK(send_status == kBufferStatusSuccess);\n  }\n}\n\nint32_t GetNumLocalDecodeThreads(int32_t num_decode_threads_per_machine,\n                                 const ParallelDesc& parallel_desc,\n                                 const ParallelContext& parallel_ctx) {\n  if (num_decode_threads_per_machine == 0) {\n    num_decode_threads_per_machine =\n        Singleton<ResourceDesc, ForSession>::Get()->ComputeThreadPoolSize();\n  }\n  int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_ctx.parallel_id()));\n  int64_t parallel_num_on_this_machine = parallel_desc.sorted_dev_phy_ids(machine_id).size();\n  return std::max<int32_t>(num_decode_threads_per_machine / parallel_num_on_this_machine, 1);\n}\n\n}  // 
namespace\n\nOFRecordImageClassificationDataset::OFRecordImageClassificationDataset(\n    user_op::KernelInitContext* ctx, std::unique_ptr<NestedDS>&& dataset)\n    : nested_ds_(std::move(dataset)), out_thread_idx_(0) {\n  const std::string& color_space = ctx->Attr<std::string>(\"color_space\");\n  const std::string& image_feature_name = ctx->Attr<std::string>(\"image_feature_name\");\n  const std::string& label_feature_name = ctx->Attr<std::string>(\"label_feature_name\");\n  auto num_decode_threads_per_machine = ctx->Attr<int32_t>(\"num_decode_threads_per_machine\");\n  auto decode_buffer_size_per_thread = ctx->Attr<int32_t>(\"decode_buffer_size_per_thread\");\n  auto num_local_decode_threads = GetNumLocalDecodeThreads(\n      num_decode_threads_per_machine, ctx->parallel_desc(), ctx->parallel_ctx());\n  decode_in_buffers_.reserve(num_local_decode_threads);\n  decode_out_buffers_.reserve(num_local_decode_threads);\n  for (int64_t i = 0; i < num_local_decode_threads; ++i) {\n    decode_in_buffers_.emplace_back(\n        std::make_unique<Buffer<NestedSampleType>>(decode_buffer_size_per_thread));\n    decode_out_buffers_.emplace_back(\n        std::make_unique<Buffer<SampleType>>(decode_buffer_size_per_thread));\n    decode_threads_.emplace_back(DecodeWorker, image_feature_name, label_feature_name, color_space,\n                                 decode_in_buffers_.back().get(), decode_out_buffers_.back().get());\n  }\n  load_thread_ = std::thread(LoadWorker, nested_ds_.get(), &decode_in_buffers_);\n}\n\nOFRecordImageClassificationDataset::~OFRecordImageClassificationDataset() {\n  for (auto& out_buffer : decode_out_buffers_) { out_buffer->Close(); }\n  for (auto& in_buffer : decode_in_buffers_) { in_buffer->Close(); }\n  load_thread_.join();\n  for (auto& decode_thread : decode_threads_) { decode_thread.join(); }\n}\n\n}  // namespace data\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/data/ofrecord_image_classification_dataset.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATASET_H_\n#define ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATASET_H_\n\n#include \"oneflow/user/data/dataset.h\"\n#include \"oneflow/core/common/buffer.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n\nnamespace oneflow {\n\nnamespace data {\n\nstruct ImageClassificationDataInstance {\n  TensorBuffer label;\n  TensorBuffer image;\n};\n\nclass OFRecordImageClassificationDataset final : public Dataset<ImageClassificationDataInstance> {\n public:\n  using Base = Dataset<ImageClassificationDataInstance>;\n  using SampleType = Base::SampleType;\n  using BatchType = Base::BatchType;\n  using NestedDS = Dataset<TensorBuffer>;\n  using NestedSampleType = NestedDS::SampleType;\n\n  OF_DISALLOW_COPY_AND_MOVE(OFRecordImageClassificationDataset);\n\n  OFRecordImageClassificationDataset(user_op::KernelInitContext* ctx,\n                                     std::unique_ptr<NestedDS>&& dataset);\n  ~OFRecordImageClassificationDataset() override;\n\n  BatchType Next() override {\n    size_t thread_idx =\n        out_thread_idx_.fetch_add(1, std::memory_order_relaxed) % decode_out_buffers_.size();\n    CHECK_LT(thread_idx, decode_out_buffers_.size());\n\n    BatchType batch;\n    SampleType sample;\n    auto status = decode_out_buffers_[thread_idx]->Pull(&sample);\n    CHECK_EQ(status, kBufferStatusSuccess);\n    batch.push_back(std::move(sample));\n    return batch;\n  }\n\n private:\n  std::unique_ptr<NestedDS> nested_ds_;\n  std::thread load_thread_;\n  std::vector<std::thread> decode_threads_;\n  std::vector<std::unique_ptr<Buffer<NestedSampleType>>> decode_in_buffers_;\n  std::vector<std::unique_ptr<Buffer<SampleType>>> decode_out_buffers_;\n  std::atomic<size_t> out_thread_idx_;\n};\n\n}  // namespace data\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATASET_H_\n"
  },
  {
    "path": "oneflow/user/data/ofrecord_image_classification_parser.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_PARSER_H_\n#define ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_PARSER_H_\n\n#include \"oneflow/user/data/parser.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n#include \"oneflow/core/record/record.pb.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/user/data/ofrecord_image_classification_dataset.h\"\n\nnamespace oneflow {\n\nnamespace data {\n\nclass OFRecordImageClassificationParser final : public Parser<ImageClassificationDataInstance> {\n public:\n  using Base = Parser<ImageClassificationDataInstance>;\n  using SampleType = typename Base::SampleType;\n  using BatchType = typename Base::BatchType;\n\n  OFRecordImageClassificationParser() = default;\n  ~OFRecordImageClassificationParser() override = default;\n\n  void Parse(BatchType& batch_data, user_op::KernelComputeContext* ctx) override {\n    const int64_t batch_size = batch_data.size();\n    user_op::Tensor* image_tensor = ctx->Tensor4ArgNameAndIndex(\"image\", 0);\n    CHECK_EQ(image_tensor->shape_view().NumAxes(), 1);\n    CHECK_EQ(image_tensor->shape_view().At(0), batch_size);\n    auto* image_buffers = image_tensor->mut_dptr<TensorBuffer>();\n    user_op::Tensor* label_tensor = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    CHECK_EQ(label_tensor->shape_view().NumAxes(), 1);\n    CHECK_EQ(label_tensor->shape_view().At(0), batch_size);\n    auto* label_buffers = label_tensor->mut_dptr<TensorBuffer>();\n    for (size_t i = 0; i < batch_data.size(); ++i) {\n      auto& instance = batch_data[i];\n      image_buffers[i].Swap(instance.image);\n      label_buffers[i].Swap(instance.label);\n    }\n  }\n};\n\n}  // namespace data\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_PARSER_H_\n"
  },
  {
    "path": "oneflow/user/data/ofrecord_parser.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_OFRECORD_PARSER_H_\n#define ONEFLOW_USER_DATA_OFRECORD_PARSER_H_\n\n#include \"oneflow/user/data/parser.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n#include \"oneflow/core/record/record.pb.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n\nnamespace oneflow {\nnamespace data {\n\nclass OFRecordParser final : public Parser<TensorBuffer> {\n public:\n  using Base = Parser<TensorBuffer>;\n  using SampleType = typename Base::SampleType;\n  using BatchType = typename Base::BatchType;\n\n  OFRecordParser() = default;\n  ~OFRecordParser() = default;\n\n  void Parse(BatchType& batch_data, user_op::KernelComputeContext* ctx) override {\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    OFRecord* dptr = out_tensor->mut_dptr<OFRecord>();\n    MultiThreadLoop(batch_data.size(), [&](size_t i) {\n      auto& sample = batch_data[i];\n      CHECK(dptr[i].ParseFromArray(sample.data(), sample.nbytes()));\n    });\n    if (batch_data.size() != out_tensor->shape_view().elem_cnt()) {\n      CHECK_EQ(out_tensor->mut_shape_view().NumAxes(), 1);\n      out_tensor->mut_shape_view().Set(0, batch_data.size());\n    }\n  }\n};\n\n}  // namespace data\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_OFRECORD_PARSER_H_\n"
  },
  {
    "path": "oneflow/user/data/parser.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_PARSER_H_\n#define ONEFLOW_USER_DATA_PARSER_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n\nnamespace oneflow {\nnamespace data {\n\ntemplate<typename LoadTarget>\nclass Parser {\n public:\n  using SampleType = LoadTarget;\n  using BatchType = std::vector<SampleType>;\n\n  Parser() = default;\n  virtual ~Parser() = default;\n\n  virtual void Parse(BatchType& batch_data, user_op::KernelComputeContext* ctx) = 0;\n};\n\n}  // namespace data\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_PARSER_H_\n"
  },
  {
    "path": "oneflow/user/data/random_shuffle_dataset.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_DATA_RANDOM_SHUFFLE_DATASET_H_\n#define ONEFLOW_USER_DATA_RANDOM_SHUFFLE_DATASET_H_\n\n#include \"oneflow/user/data/dataset.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n\nnamespace oneflow {\nnamespace data {\n\ntemplate<typename LoadTarget>\nclass RandomShuffleDataset final : public Dataset<LoadTarget> {\n public:\n  using Base = Dataset<LoadTarget>;\n  using SampleType = typename Base::SampleType;\n  using BatchType = typename Base::BatchType;\n\n  RandomShuffleDataset(user_op::KernelInitContext* ctx,\n                       std::unique_ptr<Dataset<LoadTarget>>&& dataset)\n      : nested_ds_(std::move(dataset)) {\n    // random\n    seed_ = ctx->Attr<int64_t>(\"seed\");\n    if (seed_ == -1) { seed_ = NewRandomSeed(); }\n    std::seed_seq seq({seed_});\n    rand_engine_ = std::default_random_engine(seq);\n\n    // fill buffer\n    initial_buffer_fill_ = ctx->Attr<int32_t>(\"shuffle_buffer_size\");\n    int32_t remain_cnt = initial_buffer_fill_;\n    while (remain_cnt > 0) {\n      BatchType batch = nested_ds_->Next();\n      for (auto& sample : batch) {\n        sample_buffer_.push_back(std::move(sample));\n        remain_cnt--;\n      }\n    }\n  }\n  ~RandomShuffleDataset() = default;\n\n  BatchType Next() override {\n    BatchType batch = nested_ds_->Next();\n    for (auto& sample : batch) {\n      std::uniform_int_distribution<> dis(0, sample_buffer_.size() - 1);\n      int offset = dis(rand_engine_);\n      std::swap(sample_buffer_[offset], sample);\n    }\n    return batch;\n  }\n\n private:\n  std::unique_ptr<Dataset<LoadTarget>> nested_ds_;\n  std::vector<SampleType> sample_buffer_;\n\n  int32_t initial_buffer_fill_;\n\n  std::default_random_engine rand_engine_;\n  int64_t seed_;\n};\n\n}  // namespace data\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_DATA_RANDOM_SHUFFLE_DATASET_H_\n"
  },
  {
    "path": "oneflow/user/image/crop_window.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_IMAGE_CROP_WINDOW_H_\n#define ONEFLOW_USER_IMAGE_CROP_WINDOW_H_\n\n#include \"oneflow/core/common/shape.h\"\n\nnamespace oneflow {\n\nstruct CropWindow {\n  Shape anchor;\n  Shape shape;\n\n  CropWindow() : anchor{0, 0}, shape{0, 0} {}\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_IMAGE_CROP_WINDOW_H_\n"
  },
  {
    "path": "oneflow/user/image/image_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/image/image_util.h\"\n#include <opencv2/opencv.hpp>\n\nnamespace oneflow {\n\nbool ImageUtil::IsColor(const std::string& color_space) {\n  if (color_space == \"RGB\" || color_space == \"BGR\") {\n    return true;\n  } else if (color_space == \"GRAY\") {\n    return false;\n  } else {\n    UNIMPLEMENTED();\n    return false;\n  }\n}\n\nvoid ImageUtil::ConvertColor(const std::string& input_color, const cv::Mat& input_img,\n                             const std::string& output_color, cv::Mat& output_img) {\n  if (input_color == \"BGR\" && output_color == \"RGB\") {\n    cv::cvtColor(input_img, output_img, cv::COLOR_BGR2RGB);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ncv::Mat GenCvMat4ImageBuffer(const TensorBuffer& image_buffer) {\n  CHECK_EQ(image_buffer.shape_view().NumAxes(), 3);\n  int h = image_buffer.shape_view().At(0);\n  int w = image_buffer.shape_view().At(1);\n  int channels = image_buffer.shape_view().At(2);\n  DataType data_type = image_buffer.data_type();\n  if (channels == 1 && data_type == DataType::kUInt8) {\n    return CreateMatWithPtr(h, w, CV_8UC1, image_buffer.data<uint8_t>());\n  } else if (channels == 1 && data_type == DataType::kFloat) {\n    return CreateMatWithPtr(h, w, CV_32FC1, image_buffer.data<float>());\n  } else if (channels == 3 && data_type == DataType::kUInt8) {\n    return CreateMatWithPtr(h, w, CV_8UC3, image_buffer.data<uint8_t>());\n  } else if (channels == 3 && data_type == DataType::kFloat) {\n    return CreateMatWithPtr(h, w, CV_32FC3, image_buffer.data<float>());\n  } else {\n    UNIMPLEMENTED();\n  }\n  return cv::Mat();\n}\n\ncv::Mat GenCvMat4ImageTensor(const user_op::Tensor* image_tensor, int image_offset) {\n  int has_batch_dim = 0;\n  if (image_tensor->shape_view().NumAxes() == 3) {\n    has_batch_dim = 0;\n    image_offset = 0;\n  } else if (image_tensor->shape_view().NumAxes() == 4) {\n    has_batch_dim = 1;\n    CHECK_GE(image_offset, 0);\n    CHECK_LT(image_offset, image_tensor->shape_view().At(0));\n  } else {\n    UNIMPLEMENTED();\n  }\n  int h = image_tensor->shape_view().At(0 + has_batch_dim);\n  int w = image_tensor->shape_view().At(1 + has_batch_dim);\n  int c = image_tensor->shape_view().At(2 + has_batch_dim);\n  int elem_offset = image_offset * h * w * c;\n  DataType data_type = image_tensor->data_type();\n  if (c == 1 && data_type == DataType::kUInt8) {\n    return CreateMatWithPtr(h, w, CV_8UC1, image_tensor->dptr<uint8_t>() + elem_offset);\n  } else if (c == 1 && data_type == DataType::kFloat) {\n    return CreateMatWithPtr(h, w, CV_32FC1, image_tensor->dptr<float>() + elem_offset);\n  } else if (c == 3 && data_type == DataType::kUInt8) {\n    return CreateMatWithPtr(h, w, CV_8UC3, image_tensor->dptr<uint8_t>() + elem_offset);\n  } else if (c == 3 && data_type == DataType::kFloat) {\n    return CreateMatWithPtr(h, w, CV_32FC3, image_tensor->dptr<float>() + elem_offset);\n  } else {\n    
UNIMPLEMENTED();\n  }\n  return cv::Mat();\n}\n\nvoid CvMatConvertToDataType(const cv::Mat& src, cv::Mat* dst, DataType dtype) {\n  if (dtype == DataType::kUInt8) {\n    src.convertTo(*dst, CV_8U);\n  } else if (dtype == DataType::kFloat) {\n    src.convertTo(*dst, CV_32F);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nint GetCvInterpolationFlag(const std::string& interp_type, int org_w, int org_h, int res_w,\n                           int res_h) {\n  if (interp_type == \"bilinear\") {\n    return cv::INTER_LINEAR;\n  } else if (interp_type == \"nearest_neighbor\" || interp_type == \"nn\") {\n    return cv::INTER_NEAREST;\n  } else if (interp_type == \"bicubic\") {\n    return cv::INTER_CUBIC;\n  } else if (interp_type == \"area\") {\n    return cv::INTER_AREA;\n  } else if (interp_type == \"auto\") {\n    if (res_w * res_h >= org_w * org_h) {\n      return cv::INTER_LINEAR;\n    } else {\n      return cv::INTER_AREA;\n    }\n  } else {\n    UNIMPLEMENTED();\n    return 0;  // unreachable\n  }\n}\n\nbool CheckInterpolationValid(const std::string& interp_type, std::ostringstream& err) {\n  if (interp_type != \"bilinear\" && interp_type != \"nearest_neighbor\" && interp_type != \"nn\"\n      && interp_type != \"bicubic\" && interp_type != \"area\" && interp_type != \"auto\") {\n    err << \", interpolation_type: \" << interp_type\n        << \" (interpolation_type must be one of bilinear, nearest_neighbor(nn), bicubic, area and \"\n           \"auto)\";\n    return false;\n  }\n  return true;\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/image/image_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_IMAGE_IMAGE_UTIL_H_\n#define ONEFLOW_USER_IMAGE_IMAGE_UTIL_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n#include \"oneflow/core/framework/user_op_tensor.h\"\n#include <opencv2/opencv.hpp>\n\nnamespace oneflow {\n\nstruct ImageUtil {\n  static bool IsColor(const std::string& color_space);\n\n  static void ConvertColor(const std::string& input_color, const cv::Mat& input_img,\n                           const std::string& output_color, cv::Mat& output_img);\n};\n\ntemplate<typename T>\ninline cv::Mat CreateMatWithPtr(int H, int W, int type, const T* ptr,\n                                size_t step = cv::Mat::AUTO_STEP) {\n  return cv::Mat(H, W, type, const_cast<T*>(ptr), step);\n}\n\ncv::Mat GenCvMat4ImageBuffer(const TensorBuffer& image_buffer);\n\ncv::Mat GenCvMat4ImageTensor(const user_op::Tensor* image_tensor, int image_offset);\n\nvoid CvMatConvertToDataType(const cv::Mat& src, cv::Mat* dst, DataType dtype);\n\nint GetCvInterpolationFlag(const std::string& inter_type, int org_w, int org_h, int res_w,\n                           int res_h);\nbool CheckInterpolationValid(const std::string& interp_type, std::ostringstream& ss);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_IMAGE_IMAGE_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/image/jpeg_decoder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cstddef>\n#include <iostream>\n\n#include \"oneflow/user/image/jpeg_decoder.h\"\n#include \"oneflow/user/image/image_util.h\"\n\nnamespace oneflow {\n\nclass LibjpegCtx {\n public:\n  explicit LibjpegCtx(struct jpeg_decompress_struct* compress_info)\n      : compress_info_(compress_info) {}\n  ~LibjpegCtx() { jpeg_destroy_decompress(compress_info_); }\n  OF_DISALLOW_COPY_AND_MOVE(LibjpegCtx);\n  struct jpeg_decompress_struct* compress_info() {\n    return compress_info_;\n  }\n\n private:\n  struct jpeg_decompress_struct* compress_info_;\n};\n\nbool JpegPartialDecodeRandomCropImage(const unsigned char* data, size_t length,\n                                      RandomCropGenerator* random_crop_gen,\n                                      unsigned char* workspace, size_t workspace_size,\n                                      cv::Mat* out_mat) {\n  struct jpeg_decompress_struct compress_info {};\n  struct jpeg_error_mgr jpeg_err {};\n  compress_info.err = jpeg_std_error(&jpeg_err);\n  jpeg_create_decompress(&compress_info);\n  if (compress_info.err->msg_code != 0) { return false; }\n\n  LibjpegCtx ctx_guard(&compress_info);\n\n  jpeg_mem_src(ctx_guard.compress_info(), data, length);\n  if (ctx_guard.compress_info()->err->msg_code != 0) { return false; }\n\n  int rc = jpeg_read_header(ctx_guard.compress_info(), TRUE);\n  if (rc != JPEG_HEADER_OK) { return false; }\n\n  jpeg_start_decompress(ctx_guard.compress_info());\n  int width = ctx_guard.compress_info()->output_width;\n  int height = ctx_guard.compress_info()->output_height;\n  int pixel_size = ctx_guard.compress_info()->output_components;\n\n  unsigned int u_crop_x = 0, u_crop_y = 0, u_crop_w = width, u_crop_h = height;\n  if (random_crop_gen) {\n    CropWindow crop;\n    random_crop_gen->GenerateCropWindow({height, width}, &crop);\n    u_crop_y = crop.anchor.At(0);\n    u_crop_x = crop.anchor.At(1);\n    u_crop_h = crop.shape.At(0);\n    u_crop_w = crop.shape.At(1);\n  }\n\n  unsigned int tmp_w = u_crop_w;\n  jpeg_crop_scanline(ctx_guard.compress_info(), &u_crop_x, &tmp_w);\n  if (jpeg_skip_scanlines(ctx_guard.compress_info(), u_crop_y) != u_crop_y) { return false; }\n\n  int row_offset = (tmp_w - u_crop_w) * pixel_size;\n  int out_row_stride = u_crop_w * pixel_size;\n  std::vector<unsigned char> decode_output_buf;\n  unsigned char* decode_output_pointer = nullptr;\n  size_t image_space_size = width * pixel_size;\n\n  if (image_space_size > workspace_size) {\n    decode_output_buf.resize(image_space_size);\n    decode_output_pointer = decode_output_buf.data();\n  } else {\n    decode_output_pointer = workspace;\n  }\n  out_mat->create(u_crop_h, u_crop_w, CV_8UC3);\n\n  while (ctx_guard.compress_info()->output_scanline < u_crop_y + u_crop_h) {\n    unsigned char* buffer_array[1];\n    buffer_array[0] = decode_output_pointer;\n    unsigned int read_line_index = ctx_guard.compress_info()->output_scanline;\n    
if (jpeg_read_scanlines(ctx_guard.compress_info(), buffer_array, 1) != 1) { return false; }\n    memcpy(out_mat->data + (read_line_index - u_crop_y) * out_row_stride,\n           decode_output_pointer + row_offset, out_row_stride);\n  }\n\n  jpeg_skip_scanlines(ctx_guard.compress_info(), height - u_crop_y - u_crop_h);\n  jpeg_finish_decompress(ctx_guard.compress_info());\n\n  return true;\n}\n\nvoid OpenCvPartialDecodeRandomCropImage(const unsigned char* data, size_t length,\n                                        RandomCropGenerator* random_crop_gen,\n                                        const std::string& color_space, cv::Mat& out_mat) {\n  cv::Mat image =\n      cv::imdecode(cv::Mat(1, static_cast<int>(length), CV_8UC1, const_cast<unsigned char*>(data)),\n                   ImageUtil::IsColor(color_space) ? cv::IMREAD_COLOR : cv::IMREAD_GRAYSCALE);\n  CHECK(image.data != nullptr);\n  int W = image.cols;\n  int H = image.rows;\n\n  // random crop\n  if (random_crop_gen != nullptr) {\n    CropWindow crop;\n    random_crop_gen->GenerateCropWindow({H, W}, &crop);\n    const int y = crop.anchor.At(0);\n    const int x = crop.anchor.At(1);\n    const int newH = crop.shape.At(0);\n    const int newW = crop.shape.At(1);\n    CHECK(newW > 0 && newW <= W);\n    CHECK(newH > 0 && newH <= H);\n    cv::Rect roi(x, y, newW, newH);\n    image(roi).copyTo(out_mat);\n    W = out_mat.cols;\n    H = out_mat.rows;\n    CHECK(W == newW);\n    CHECK(H == newH);\n  } else {\n    image.copyTo(out_mat);\n  }\n}\n\n}  // namespace oneflow\n"
  },
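// (Added commentary, not part of the original sources.) JpegPartialDecodeRandomCropImage
// leans on libjpeg-turbo's jpeg_crop_scanline(), which may snap the crop's left edge down to
// the nearest iMCU boundary and widen the decoded row to compensate; row_offset then skips
// the extra leading pixels when each scanline is copied into out_mat. A standalone sketch of
// that arithmetic, assuming a hypothetical iMCU width of 16 (typical for 4:2:0 subsampling):
#include <cstdio>

int main() {
  unsigned int requested_x = 13, requested_w = 50, pixel_size = 3;
  unsigned int aligned_x = (requested_x / 16) * 16;                  // 0: snapped left edge
  unsigned int widened_w = requested_w + (requested_x - aligned_x);  // 63: what tmp_w becomes
  unsigned int row_offset = (widened_w - requested_w) * pixel_size;  // 39 bytes to skip per row
  std::printf("each memcpy starts %u bytes into the decoded row\n", row_offset);
  return 0;
}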
  {
    "path": "oneflow/user/image/jpeg_decoder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_IMAGE_JPEG_DECODER_H_\n#define ONEFLOW_USER_IMAGE_JPEG_DECODER_H_\n#include <jpeglib.h>\n#include <opencv2/core/mat.hpp>\n#include <opencv2/opencv.hpp>\n#include \"oneflow/user/image/random_crop_generator.h\"\n\nnamespace oneflow {\n\nbool JpegPartialDecodeRandomCropImage(const unsigned char* data, size_t length,\n                                      RandomCropGenerator* random_crop_gen,\n                                      unsigned char* workspace, size_t workspace_size,\n                                      cv::Mat* out_mat);\n\nvoid OpenCvPartialDecodeRandomCropImage(const unsigned char* data, size_t length,\n                                        RandomCropGenerator* random_crop_gen,\n                                        const std::string& color_space, cv::Mat& out_mat);\n\n}  // namespace oneflow\n#endif  // ONEFLOW_USER_IMAGE_JPEG_DECODER_H_\n"
  },
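// (Added commentary, not part of the original sources.) The workspace/workspace_size pair in
// JpegPartialDecodeRandomCropImage is an optional caller-owned scanline buffer: whenever it is
// smaller than one decoded row (width * components bytes), the implementation falls back to an
// internal std::vector allocation, so passing nullptr/0 is always safe. A hedged sketch of
// reusing one buffer across calls to avoid per-image allocations (helper name and the sizing
// choice are hypothetical):
#include <cstddef>
#include <vector>
#include "oneflow/user/image/jpeg_decoder.h"

bool DecodeReusingRowBuffer(const unsigned char* data, size_t length,
                            oneflow::RandomCropGenerator* gen, cv::Mat* out,
                            std::vector<unsigned char>* row_buf) {
  // Large enough for RGB rows up to 4096 pixels wide; wider images trigger the internal path.
  if (row_buf->size() < 4096 * 3) { row_buf->resize(4096 * 3); }
  return oneflow::JpegPartialDecodeRandomCropImage(data, length, gen, row_buf->data(),
                                                   row_buf->size(), out);
}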
  {
    "path": "oneflow/user/image/jpeg_decoder_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <gtest/gtest.h>\n#include <cstddef>\n#include <iostream>\n#include <sys/stat.h>\n#include <unistd.h>\n#include <fcntl.h>\n#include <opencv2/opencv.hpp>\n#include \"oneflow/user/image/jpeg_decoder.h\"\n#include \"oneflow/user/image/image_util.h\"\n\nnamespace oneflow {\n\n// generate image\nvoid GenerateImage(std::vector<uint8_t>& jpg, int w, int h) {\n  std::vector<uint8_t> raw_data(w * h * 3);\n\n  for (int i = 0; i < w; i++) {\n    for (int j = 0; j < h; j++) {\n      uint8_t r = 0, g = 0, b = 0;\n      if (i < w / 2 && j < h / 2) {\n        r = 255;\n        g = 0;\n        b = 0;\n      } else if ((i >= w / 2 && j < h / 2) || (i < w / 2 && j >= h / 2)) {\n        r = 0;\n        g = 255;\n        b = 0;\n      } else if ((i >= w / 2) && (j >= h / 2)) {\n        r = 0;\n        g = 0;\n        b = 255;\n      }\n\n      raw_data[3 * (i * w + j)] = b;\n      raw_data[3 * (i * w + j) + 1] = g;\n      raw_data[3 * (i * w + j) + 2] = r;\n    }\n  }\n\n  std::vector<int> compression_params;\n  compression_params.push_back(cv::IMWRITE_JPEG_QUALITY);\n  compression_params.push_back(100);\n\n  cv::Mat raw(h, w, CV_8UC3, (void*)raw_data.data(), cv::Mat::AUTO_STEP);\n  cv::imencode(\".jpg\", raw, jpg);\n}\n\nTEST(JPEG, decoder) {\n  constexpr size_t test_num = 3;\n  std::vector<unsigned char> jpg;\n  GenerateImage(jpg, 192, 192);\n  std::seed_seq seq{1, 2, 3};\n  std::vector<int64_t> seeds(test_num);\n  seq.generate(seeds.begin(), seeds.end());\n\n  for (int i = 0; i < test_num; i++) {\n    cv::Mat libjpeg_image_mat;\n\n    RandomCropGenerator libjpeg_random_crop_gen({0.1, 0.9}, {0.4, 0.6}, seeds[i], 1);\n    RandomCropGenerator opencv_random_crop_gen({0.1, 0.9}, {0.4, 0.6}, seeds[i], 1);\n    auto status = JpegPartialDecodeRandomCropImage(jpg.data(), jpg.size(), &libjpeg_random_crop_gen,\n                                                   nullptr, 0, &libjpeg_image_mat);\n    ASSERT_EQ(status, true);\n\n    cv::Mat opencv_image_mat;\n    std::string color_space(\"RGB\");\n\n    OpenCvPartialDecodeRandomCropImage(jpg.data(), jpg.size(), &opencv_random_crop_gen, color_space,\n                                       opencv_image_mat);\n    ImageUtil::ConvertColor(\"BGR\", opencv_image_mat, color_space, opencv_image_mat);\n\n    cv::Mat checkout = libjpeg_image_mat - opencv_image_mat;\n    auto sum = cv::sum(cv::sum(checkout));\n    ASSERT_EQ(sum[0], 0);\n    // cv::imwrite(\"jpeg.ppm\", libjpeg_image_mat);\n    // cv::imwrite(\"opencv.ppm\", opencv_image_mat);\n  }\n}\n\n}  // namespace oneflow\n"
  },
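// (Added commentary, not part of the original sources.) The pixel-for-pixel comparison above
// is only valid because both RandomCropGenerators are constructed with the same seed and so
// hand the two decode paths an identical crop window. That invariant could be pinned down
// directly with an extra test along these lines (hypothetical, not part of the original suite):
#include <gtest/gtest.h>
#include "oneflow/user/image/random_crop_generator.h"

TEST(RandomCropGenerator, SameSeedSameWindow) {
  oneflow::RandomCropGenerator a({0.1, 0.9}, {0.4, 0.6}, /*seed=*/42, /*num_attempts=*/1);
  oneflow::RandomCropGenerator b({0.1, 0.9}, {0.4, 0.6}, /*seed=*/42, /*num_attempts=*/1);
  oneflow::CropWindow wa, wb;
  a.GenerateCropWindow({192, 192}, &wa);
  b.GenerateCropWindow({192, 192}, &wb);
  // Identical seeds consume an identical random stream, so the windows must match.
  EXPECT_EQ(wa.anchor.At(0), wb.anchor.At(0));
  EXPECT_EQ(wa.anchor.At(1), wb.anchor.At(1));
  EXPECT_EQ(wa.shape, wb.shape);
}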
  {
    "path": "oneflow/user/image/random_crop_generator.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/image/random_crop_generator.h\"\n\nnamespace oneflow {\n\nRandomCropGenerator::RandomCropGenerator(AspectRatioRange aspect_ratio_range, AreaRange area_range,\n                                         int64_t seed, int32_t num_attempts)\n    : aspect_ratio_range_(aspect_ratio_range),\n      aspect_ratio_log_dis_(std::log(aspect_ratio_range.first),\n                            std::log(aspect_ratio_range.second)),\n      area_dis_(area_range.first, area_range.second),\n      rand_gen_(seed),\n      seed_(seed),\n      num_attempts_(num_attempts) {}\n\nvoid RandomCropGenerator::GenerateCropWindow(const Shape& shape, CropWindow* crop_window) {\n  CHECK_EQ(shape.NumAxes(), 2);\n  CHECK(crop_window != nullptr);\n\n  int H = shape.At(0);\n  int W = shape.At(1);\n  if (H <= 0 || W <= 0) { return; }\n\n  float min_wh_ratio = aspect_ratio_range_.first;\n  float max_wh_ratio = aspect_ratio_range_.second;\n  float max_hw_ratio = 1 / aspect_ratio_range_.first;\n  float min_area = W * H * area_dis_.a();\n  int maxW = std::max<int>(1, static_cast<int>(H * max_wh_ratio));\n  int maxH = std::max<int>(1, static_cast<int>(W * max_hw_ratio));\n\n  if (H * maxW < min_area) {\n    crop_window->shape = Shape({H, maxW});\n  } else if (W * maxH < min_area) {\n    crop_window->shape = Shape({maxH, W});\n  } else {\n    int attempts_left = num_attempts_;\n    for (; attempts_left > 0; attempts_left--) {\n      float scale = area_dis_(rand_gen_);\n\n      size_t original_area = H * W;\n      float target_area = scale * original_area;\n\n      float ratio = std::exp(aspect_ratio_log_dis_(rand_gen_));\n      int w = static_cast<int>(std::roundf(sqrtf(target_area * ratio)));\n      int h = static_cast<int>(std::roundf(sqrtf(target_area / ratio)));\n\n      w = std::max(w, 1);\n      h = std::max(h, 1);\n\n      crop_window->shape = Shape({h, w});\n\n      ratio = static_cast<float>(w) / h;\n\n      if (w <= W && h <= H && ratio >= min_wh_ratio && ratio <= max_wh_ratio) { break; }\n    }\n\n    if (attempts_left <= 0) {\n      float max_area = area_dis_.b() * W * H;\n      float ratio = static_cast<float>(W) / H;\n      if (ratio > max_wh_ratio) {\n        crop_window->shape = Shape({H, maxW});\n      } else if (ratio < min_wh_ratio) {\n        crop_window->shape = Shape({maxH, W});\n      } else {\n        crop_window->shape = Shape({H, W});\n      }\n      float scale =\n          std::min(1.0f, max_area / (crop_window->shape.At(0) * crop_window->shape.At(1)));\n      crop_window->shape.Set(0, std::max<int64_t>(1, crop_window->shape.At(0) * std::sqrt(scale)));\n      crop_window->shape.Set(1, std::max<int64_t>(1, crop_window->shape.At(1) * std::sqrt(scale)));\n    }\n  }\n\n  crop_window->anchor.Set(\n      0, std::uniform_int_distribution<int>(0, H - crop_window->shape.At(0))(rand_gen_));\n  crop_window->anchor.Set(\n      1, std::uniform_int_distribution<int>(0, W - 
crop_window->shape.At(1))(rand_gen_));\n}\n\nvoid RandomCropGenerator::GenerateCropWindows(const Shape& shape, size_t n,\n                                              std::vector<CropWindow>* crop_windows) {\n  std::seed_seq seq{seed_};\n  std::vector<int64_t> seeds(n);\n  seq.generate(seeds.begin(), seeds.end());\n  crop_windows->resize(n);\n\n  for (std::size_t i = 0; i < n; i++) {\n    rand_gen_.seed(seeds.at(i));\n    GenerateCropWindow(shape, &(crop_windows->at(i)));\n  }\n}\n\n}  // namespace oneflow\n"
  },
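// (Added commentary, not part of the original sources.) One accepted attempt of the sampling
// loop above, spelled out with hypothetical draws: H = W = 192, scale = 0.5 from area_dis_,
// and ratio = exp(u) = 1.5 from the log-uniform aspect distribution. The attempt is kept only
// if the recomputed w/h also falls inside the configured aspect_ratio_range:
#include <cmath>
#include <cstdio>

int main() {
  const int H = 192, W = 192;
  const float scale = 0.5f, ratio = 1.5f;  // hypothetical random draws
  const float target_area = scale * H * W;                                      // 18432
  const int w = static_cast<int>(std::roundf(std::sqrt(target_area * ratio)));  // 166
  const int h = static_cast<int>(std::roundf(std::sqrt(target_area / ratio)));  // 111
  std::printf("w=%d h=%d w/h=%.3f\n", w, h, static_cast<float>(w) / h);         // w/h = 1.495
  return 0;
}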
  {
    "path": "oneflow/user/image/random_crop_generator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_IMAGE_RANDOM_CROP_GENERATOR_H_\n#define ONEFLOW_USER_IMAGE_RANDOM_CROP_GENERATOR_H_\n\n#include \"oneflow/user/image/crop_window.h\"\n\nnamespace oneflow {\n\nusing AspectRatioRange = std::pair<float, float>;\nusing AreaRange = std::pair<float, float>;\n\nclass RandomCropGenerator {\n public:\n  RandomCropGenerator(AspectRatioRange aspect_ratio_range, AreaRange area_range, int64_t seed,\n                      int32_t num_attempts);\n\n  void GenerateCropWindow(const Shape& shape, CropWindow* crop_window);\n  void GenerateCropWindows(const Shape& shape, size_t n, std::vector<CropWindow>* crop_windows);\n\n private:\n  AspectRatioRange aspect_ratio_range_;\n  std::uniform_real_distribution<float> aspect_ratio_log_dis_;\n  std::uniform_real_distribution<float> area_dis_;\n  std::mt19937 rand_gen_;\n  int64_t seed_;\n  int32_t num_attempts_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_IMAGE_RANDOM_CROP_GENERATOR_H_\n"
  },
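// (Added commentary, not part of the original sources.) A hedged usage sketch of the batch
// API: GenerateCropWindows derives one sub-seed per window from the generator's own seed via
// std::seed_seq and reseeds before each window, so the n windows are reproducible no matter
// how often the generator was used beforehand. Shapes and parameters here are illustrative:
#include <vector>
#include "oneflow/user/image/random_crop_generator.h"

void BatchCropExample() {
  oneflow::RandomCropGenerator gen({3.0f / 4, 4.0f / 3}, {0.08f, 1.0f}, /*seed=*/7,
                                   /*num_attempts=*/10);
  std::vector<oneflow::CropWindow> windows;
  gen.GenerateCropWindows({224, 224}, /*n=*/4, &windows);  // four repeatable crop windows
}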
  {
    "path": "oneflow/user/kernels/acc_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/include/primitive/add.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass AccKernel final : public user_op::OpKernel {\n public:\n  AccKernel() = default;\n  ~AccKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt());\n    CHECK_EQ(in->data_type(), out->data_type());\n    std::unique_ptr<ep::primitive::Add> primitive =\n        ep::primitive::NewPrimitive<ep::primitive::AddFactory>(ctx->device_type(), in->data_type());\n    CHECK(primitive);\n    primitive->Launch(ctx->stream(), out->dptr(), in->dptr(), out->mut_dptr(),\n                      in->shape_view().elem_cnt());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"acc\").SetCreateFn<AccKernel>();\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
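// (Added commentary, not part of the original sources.) AccKernel passes out as both the
// second source and the destination of the Add primitive, i.e. it accumulates in place.
// Reference semantics, as a plain loop over a hypothetical float buffer:
#include <cstddef>

void AccReference(const float* in, float* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) { out[i] = out[i] + in[i]; }  // out += in, element-wise
}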
  {
    "path": "oneflow/user/kernels/activation_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/binary_op.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h\"\n#include \"oneflow/user/kernels/elementwise_primitive_kernel.h\"\n\nnamespace oneflow {\n\nREGISTER_USER_KERNEL(\"elu\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kElu, src->data_type(),\n                dst->data_type(), ctx->Attr<double>(\"alpha\"));\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kElu, \"out\", \"in\"));\n\nREGISTER_USER_KERNEL(\"elu_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kEluBackwardWithDyX, src->data_type(),\n                dst->data_type(), 1 /*max_num_dims*/, ctx->Attr<double>(\"alpha\"));\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kEluBackwardWithDyX, \"dx\",\n                                           \"dy\"));\n\nREGISTER_USER_KERNEL(\"celu\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kCelu, src->data_type(),\n                dst->data_type(), ctx->Attr<double>(\"alpha\"));\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kCelu, \"out\", \"in\"));\n\nREGISTER_USER_KERNEL(\"celu_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"y\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return 
ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kCeluBackwardWithDyY, src->data_type(),\n                dst->data_type(), 1 /*max_num_dims*/, ctx->Attr<double>(\"alpha\"));\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kCeluBackwardWithDyY, \"dx\",\n                                           \"dy\"));\n\nREGISTER_USER_KERNEL(\"hardswish\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kHardSwish, src->data_type(),\n                dst->data_type());\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kHardSwish, \"out\", \"in\"));\n\nREGISTER_USER_KERNEL(\"hardswish_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kHardswishBackwardWithDyX,\n                src->data_type(), dst->data_type(), 1 /*max_num_dims*/);\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kHardswishBackwardWithDyX, \"dx\",\n                                           \"dy\"));\n\nREGISTER_USER_KERNEL(\"hardsigmoid\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kHardSigmoid, src->data_type(),\n                dst->data_type());\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kHardSigmoid, \"out\", \"in\"));\n\nREGISTER_USER_KERNEL(\"hardsigmoid_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kHardsigmoidBackwardWithDyX,\n                src->data_type(), dst->data_type(), 1 /*max_num_dims*/);\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kHardsigmoidBackwardWithDyX,\n                                           \"dx\", 
\"dy\"));\n\nREGISTER_USER_KERNEL(\"hardshrink\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kHardShrink, src->data_type(),\n                dst->data_type(), ctx->Attr<double>(\"lambd\"));\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kHardShrink, \"out\", \"in\"))\n    .SetInplaceProposalFn([](const user_op::InferContext&,\n                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"in\", 0, true));\n      return Maybe<void>::Ok();\n    });\n\nREGISTER_USER_KERNEL(\"hardshrink_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"y\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kHardshrinkBackwardWithDyY,\n                src->data_type(), dst->data_type(), 1 /*max_num_dims*/, ctx->Attr<double>(\"lambd\"));\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kHardshrinkBackwardWithDyY,\n                                           \"dx\", \"dy\"))\n    .SetInplaceProposalFn([](const user_op::InferContext&,\n                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"dx\", 0, \"dy\", 0, true));\n      return Maybe<void>::Ok();\n    });\n\nREGISTER_USER_KERNEL(\"hardtanh\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kHardTanh, src->data_type(),\n                dst->data_type(), ctx->Attr<double>(\"min_val\"), ctx->Attr<double>(\"max_val\"));\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kHardTanh, \"out\", \"in\"))\n    .SetInplaceProposalFn([](const user_op::InferContext&,\n                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"in\", 0, true));\n      return Maybe<void>::Ok();\n    });\n\nREGISTER_USER_KERNEL(\"hardtanh_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"y\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* 
dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kHardtanhBackwardWithDyY,\n                src->data_type(), dst->data_type(), 1 /*max_num_dims*/,\n                ctx->Attr<double>(\"min_val\"), ctx->Attr<double>(\"max_val\"));\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kHardtanhBackwardWithDyY, \"dx\",\n                                           \"dy\"))\n    .SetInplaceProposalFn([](const user_op::InferContext&,\n                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"dx\", 0, \"dy\", 0, true));\n      return Maybe<void>::Ok();\n    });\n\nREGISTER_USER_KERNEL(\"gelu\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kGelu, src->data_type(),\n                dst->data_type());\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kGelu, \"out\", \"in\"));\n\nREGISTER_USER_KERNEL(\"gelu_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kGeluBackwardWithDyX, src->data_type(),\n                dst->data_type(), 1 /*max_num_dims*/);\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kGeluBackwardWithDyX, \"dx\",\n                                           \"dy\"));\n\nREGISTER_USER_KERNEL(\"fast_gelu\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kFastGelu, src->data_type(),\n                dst->data_type());\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kFastGelu, \"out\", \"in\"));\n\nREGISTER_USER_KERNEL(\"fast_gelu_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n         
       ctx->device_type(), ep::primitive::BinaryOp::kFastGeluBackwardWithDyX,\n                src->data_type(), dst->data_type(), 1 /*max_num_dims*/);\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kFastGeluBackwardWithDyX, \"dx\",\n                                           \"dy\"));\n\nREGISTER_USER_KERNEL(\"quick_gelu\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"y\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kQuickGelu, src->data_type(),\n                dst->data_type());\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kQuickGelu, \"y\", \"x\"));\n\nREGISTER_USER_KERNEL(\"quick_gelu_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kQuickGeluBackwardWithDyX,\n                src->data_type(), dst->data_type(), 1 /*max_num_dims*/);\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kQuickGeluBackwardWithDyX, \"dx\",\n                                           \"dy\"));\nREGISTER_USER_KERNEL(\"square_relu\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"y\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kSquareReLU, src->data_type(),\n                dst->data_type());\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kSquareReLU, \"y\", \"x\"));\n\nREGISTER_USER_KERNEL(\"square_relu_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kSquareReLUBackwardWithDyX,\n                src->data_type(), dst->data_type(), 1 /*max_num_dims*/);\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kSquareReLUBackwardWithDyX,\n                                           \"dx\", \"dy\"));\n\nREGISTER_USER_KERNEL(\"leaky_relu\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"y\", \"x\", 
[](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kLeakyRelu, src->data_type(),\n                dst->data_type(), ctx->Attr<float>(\"alpha\"));\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kLeakyRelu, \"y\", \"x\"));\n\nREGISTER_USER_KERNEL(\"leaky_relu_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kLeakyReluBackwardWithDyX,\n                src->data_type(), dst->data_type(), 1 /*max_num_dims*/, ctx->Attr<float>(\"alpha\"));\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kLeakyReluBackwardWithDyX, \"dx\",\n                                           \"dy\"));\n\nREGISTER_USER_KERNEL(\"mish\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kMish, src->data_type(),\n                dst->data_type());\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kMish, \"out\", \"in\"));\n\nREGISTER_USER_KERNEL(\"mish_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kMishBackwardWithDyX, src->data_type(),\n                dst->data_type(), 1 /*max_num_dims*/);\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kMishBackwardWithDyX, \"dx\",\n                                           \"dy\"));\n\nREGISTER_USER_KERNEL(\"relu\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"y\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kRelu, src->data_type(),\n                dst->data_type());\n          });\n    })\n    
.SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kRelu, \"y\", \"x\"))\n    .SetInplaceProposalFn([](const user_op::InferContext&,\n                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"y\", 0, \"x\", 0, true));\n      return Maybe<void>::Ok();\n    });\n\nREGISTER_USER_KERNEL(\"relu_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"y\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kReluBackwardWithDyY, src->data_type(),\n                dst->data_type(), 1 /*max_num_dims*/);\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kReluBackwardWithDyY, \"dx\",\n                                           \"dy\"))\n    .SetInplaceProposalFn([](const user_op::InferContext&,\n                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"dx\", 0, \"dy\", 0, true));\n      return Maybe<void>::Ok();\n    });\n\nREGISTER_USER_KERNEL(\"silu\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kSilu, src->data_type(),\n                dst->data_type());\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kSilu, \"out\", \"in\"));\n\nREGISTER_USER_KERNEL(\"silu_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kSiluBackwardWithDyX, src->data_type(),\n                dst->data_type(), 1 /*max_num_dims*/);\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kSiluBackwardWithDyX, \"dx\",\n                                           \"dy\"));\nREGISTER_USER_KERNEL(\"trunc\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kTrunc, src->data_type(),\n                dst->data_type());\n          });\n    
})\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kTrunc, \"out\", \"in\"));\n\nREGISTER_USER_KERNEL(\"selu\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kSelu, src->data_type(),\n                dst->data_type());\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kSelu, \"out\", \"in\"));\n\nREGISTER_USER_KERNEL(\"selu_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kSeluBackwardWithDyX, src->data_type(),\n                dst->data_type(), 1 /*max_num_dims*/);\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kSeluBackwardWithDyX, \"dx\",\n                                           \"dy\"));\n\nREGISTER_USER_KERNEL(\"softshrink\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kSoftShrink, src->data_type(),\n                dst->data_type(), ctx->Attr<double>(\"alpha\"));\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kSoftShrink, \"out\", \"in\"));\n\nREGISTER_USER_KERNEL(\"softshrink_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"y\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kSoftshrinkBackwardWithDyY,\n                src->data_type(), dst->data_type(), 1 /*max_num_dims*/, ctx->Attr<double>(\"alpha\"));\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kSoftshrinkBackwardWithDyY,\n                                           \"dx\", \"dy\"));\n\nREGISTER_USER_KERNEL(\"softsign\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n            const user_op::TensorDesc* dst = 
ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kSoftSign, src->data_type(),\n                dst->data_type());\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kSoftSign, \"out\", \"in\"));\n\nREGISTER_USER_KERNEL(\"softsign_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kSoftsignBackwardWithDyX,\n                src->data_type(), dst->data_type(), 1 /*max_num_dims*/);\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kSoftsignBackwardWithDyX, \"dx\",\n                                           \"dy\"));\n\nREGISTER_USER_KERNEL(\"softplus\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kSoftPlus, src->data_type(),\n                dst->data_type(), ctx->Attr<double>(\"beta\"), ctx->Attr<double>(\"threshold\"));\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kSoftPlus, \"out\", \"in\"));\n\nREGISTER_USER_KERNEL(\"softplus_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kSoftplusBackwardWithDyX,\n                src->data_type(), dst->data_type(), 1 /*max_num_dims*/, ctx->Attr<double>(\"beta\"),\n                ctx->Attr<double>(\"threshold\"));\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kSoftplusBackwardWithDyX, \"dx\",\n                                           \"dy\"));\n\nREGISTER_USER_KERNEL(\"tanh\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"y\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kTanh, src->data_type(),\n                dst->data_type());\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kTanh, \"y\", 
\"x\"));\n\nREGISTER_USER_KERNEL(\"tanh_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"y\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kTanhBackwardWithDyY, src->data_type(),\n                dst->data_type(), 1 /*max_num_dims*/);\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kTanhBackwardWithDyY, \"dx\",\n                                           \"dy\"));\n\nREGISTER_USER_KERNEL(\"threshold\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<UnaryPrimitiveKernel>(\n          \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n                ctx->device_type(), ep::primitive::UnaryOp::kThreshold, src->data_type(),\n                dst->data_type(), ctx->Attr<double>(\"threshold_val\"), ctx->Attr<double>(\"value\"));\n          });\n    })\n    .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kThreshold, \"out\", \"in\"));\n\nREGISTER_USER_KERNEL(\"threshold_grad\")\n    .SetCreateFn([]() {\n      return user_op::NewOpKernel<BinaryPrimitiveKernel>(\n          \"dx\", \"dy\", \"x\", [](user_op::KernelComputeContext* ctx) {\n            const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n            const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n            return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx->device_type(), ep::primitive::BinaryOp::kThresholdBackwardWithDyX,\n                src->data_type(), dst->data_type(), 1 /*max_num_dims*/,\n                ctx->Attr<double>(\"threshold_val\"));\n          });\n    })\n    .SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kThresholdBackwardWithDyX, \"dx\",\n                                           \"dy\"));\n\n}  // namespace oneflow\n"
  },
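// (Added commentary, not part of the original sources.) A naming convention worth spelling
// out: each *_grad registration binds "dy" plus either "x" (the saved input) or "y" (the
// saved output), depending on which makes the derivative cheapest to evaluate. tanh_grad is
// a *WithDyY op because d/dx tanh(x) = 1 - tanh(x)^2 = 1 - y^2, so only the forward output
// is needed. The per-element math that primitive is expected to implement:
float TanhBackwardReference(float dy, float y) {
  return dy * (1.0f - y * y);  // dx = dy * (1 - y^2)
}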
  {
    "path": "oneflow/user/kernels/adaptive_avg_pool_cpu_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/user/kernels/adaptive_pool_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename accT>\nvoid AvgForwardCompute(user_op::KernelComputeContext* ctx, const int32_t& dim) {\n  user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n  user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n  const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0)->shape();\n  const Shape& y_shape = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0)->shape();\n\n  // TODO (Tianyu): Support 'channels_last'\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  CHECK_OR_THROW(data_format == \"channels_first\")\n      << \"adaptive_avg_pool on cpu only supports NCHW data format\";\n  const Shape& in = GetShape5D(x_shape, data_format, dim);\n  const Shape& out = GetShape5D(y_shape, data_format, dim);\n\n  const T* in_ptr = in_tensor->dptr<T>();\n  T* out_ptr = out_tensor->mut_dptr<T>();\n\n  const int64_t input_width = in.Count(4);\n  const int64_t output_width = out.Count(4);\n  const int64_t input_image_size = in.Count(3);\n  const int64_t output_image_size = out.Count(3);\n  const int64_t input_size = in.Count(2);\n  const int64_t output_size = out.Count(2);\n\n  FOR_RANGE(int64_t, n, 0, in.At(0)) {\n    FOR_RANGE(int64_t, c, 0, in.At(1)) {\n      FOR_RANGE(int64_t, od, 0, out.At(2)) {\n        int64_t id0 = start_index(od, out.At(2), in.At(2));\n        int64_t id1 = end_index(od, out.At(2), in.At(2));\n        int64_t kd = id1 - id0;\n        FOR_RANGE(int64_t, oh, 0, out.At(3)) {\n          int64_t ih0 = start_index(oh, out.At(3), in.At(3));\n          int64_t ih1 = end_index(oh, out.At(3), in.At(3));\n          int64_t kh = ih1 - ih0;\n          FOR_RANGE(int64_t, ow, 0, out.At(4)) {\n            int64_t iw0 = start_index(ow, out.At(4), in.At(4));\n            int64_t iw1 = end_index(ow, out.At(4), in.At(4));\n            int64_t kw = iw1 - iw0;\n\n            // Compute local average\n            accT sum = static_cast<accT>(0);\n            FOR_RANGE(int64_t, id, id0, id1) {\n              FOR_RANGE(int64_t, ih, ih0, ih1) {\n                FOR_RANGE(int64_t, iw, iw0, iw1) {\n                  sum += static_cast<accT>(in_ptr[id * input_image_size + ih * input_width + iw]);\n                }\n              }\n            }\n            out_ptr[od * output_image_size + oh * output_width + ow] =\n                static_cast<T>(sum / kd / kh / kw);\n          }\n        }\n      }\n      in_ptr += input_size;\n      out_ptr += output_size;\n    }\n  }\n}\n\ntemplate<typename T>\nvoid AvgBackwardCompute(user_op::KernelComputeContext* ctx, const int32_t& dim) {\n  user_op::Tensor* grad_input = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n  const user_op::Tensor* grad_output = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n  const Shape& dx_shape = 
ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0)->shape();\n  const Shape& dy_shape = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->shape();\n\n  // TODO (Tianyu): Support 'channels_last'\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  CHECK_OR_THROW(data_format == \"channels_first\")\n      << \"adaptive_avg_pool backward on cpu only supports NCHW data format\";\n  const Shape& in = GetShape5D(dx_shape, data_format, dim);\n  const Shape& out = GetShape5D(dy_shape, data_format, dim);\n\n  const T* out_ptr = grad_output->dptr<T>();\n  T* in_ptr = grad_input->mut_dptr<T>();\n\n  std::fill(in_ptr, in_ptr + grad_input->shape_view().elem_cnt(), static_cast<T>(0));\n\n  const int64_t input_width = in.Count(4);\n  const int64_t output_width = out.Count(4);\n  const int64_t input_image_size = in.Count(3);\n  const int64_t output_image_size = out.Count(3);\n  const int64_t input_size = in.Count(2);\n  const int64_t output_size = out.Count(2);\n\n  FOR_RANGE(int64_t, n, 0, in.At(0)) {\n    FOR_RANGE(int64_t, c, 0, in.At(1)) {\n      FOR_RANGE(int64_t, od, 0, out.At(2)) {\n        int64_t id0 = start_index(od, out.At(2), in.At(2));\n        int64_t id1 = end_index(od, out.At(2), in.At(2));\n        int64_t kd = id1 - id0;\n        FOR_RANGE(int64_t, oh, 0, out.At(3)) {\n          int64_t ih0 = start_index(oh, out.At(3), in.At(3));\n          int64_t ih1 = end_index(oh, out.At(3), in.At(3));\n          int64_t kh = ih1 - ih0;\n          FOR_RANGE(int64_t, ow, 0, out.At(4)) {\n            int64_t iw0 = start_index(ow, out.At(4), in.At(4));\n            int64_t iw1 = end_index(ow, out.At(4), in.At(4));\n            int64_t kw = iw1 - iw0;\n            T grad_delta = static_cast<T>(out_ptr[od * output_image_size + oh * output_width + ow]\n                                          / kd / kh / kw);\n            FOR_RANGE(int64_t, id, id0, id1) {\n              FOR_RANGE(int64_t, ih, ih0, ih1) {\n                FOR_RANGE(int64_t, iw, iw0, iw1) {\n                  in_ptr[id * input_image_size + ih * input_width + iw] += grad_delta;\n                }\n              }\n            }\n          }\n        }\n      }\n      in_ptr += input_size;\n      out_ptr += output_size;\n    }\n  }\n}\n}  // namespace\n\ntemplate<DeviceType device_type, typename T>\nclass AdaptivePool1DCpuKernel final : public user_op::OpKernel {\n public:\n  AdaptivePool1DCpuKernel() = default;\n  ~AdaptivePool1DCpuKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    if (GetDataType<T>::value == kFloat16) {\n      AvgForwardCompute<T, float>(ctx, 1);\n    } else {\n      AvgForwardCompute<T, T>(ctx, 1);\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass AdaptivePool2DCpuKernel final : public user_op::OpKernel {\n public:\n  AdaptivePool2DCpuKernel() = default;\n  ~AdaptivePool2DCpuKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    if (GetDataType<T>::value == kFloat16) {\n      AvgForwardCompute<T, float>(ctx, 2);\n    } else {\n      AvgForwardCompute<T, T>(ctx, 2);\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass AdaptivePool3DCpuKernel final : public user_op::OpKernel {\n public:\n  AdaptivePool3DCpuKernel() = default;\n  ~AdaptivePool3DCpuKernel() = default;\n\n private:\n  void 
Compute(user_op::KernelComputeContext* ctx) const override {\n    if (GetDataType<T>::value == kFloat16) {\n      AvgForwardCompute<T, float>(ctx, 3);\n    } else {\n      AvgForwardCompute<T, T>(ctx, 3);\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\ntemplate<DeviceType device_type, typename T>\nclass AdaptivePool1DCpuGradKernel final : public user_op::OpKernel {\n public:\n  AdaptivePool1DCpuGradKernel() = default;\n  ~AdaptivePool1DCpuGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override { AvgBackwardCompute<T>(ctx, 1); }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass AdaptivePool2DCpuGradKernel final : public user_op::OpKernel {\n public:\n  AdaptivePool2DCpuGradKernel() = default;\n  ~AdaptivePool2DCpuGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override { AvgBackwardCompute<T>(ctx, 2); }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\ntemplate<DeviceType device_type, typename T>\nclass AdaptivePool3DCpuGradKernel final : public user_op::OpKernel {\n public:\n  AdaptivePool3DCpuGradKernel() = default;\n  ~AdaptivePool3DCpuGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override { AvgBackwardCompute<T>(ctx, 3); }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_ADAPTIVE_POOL_KERNEL(device, dtype)                                    \\\n  REGISTER_USER_KERNEL(\"adaptive_avg_pool1d\")                                           \\\n      .SetCreateFn<AdaptivePool1DCpuKernel<device, dtype>>()                            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"adaptive_avg_pool2d\")                                           \\\n      .SetCreateFn<AdaptivePool2DCpuKernel<device, dtype>>()                            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"adaptive_avg_pool3d\")                                           \\\n      .SetCreateFn<AdaptivePool3DCpuKernel<device, dtype>>()                            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\n#define REGISTER_ADAPTIVE_POOL_KERNEL_WITH_DEVICE(device) \\\n  REGISTER_ADAPTIVE_POOL_KERNEL(device, float16)          \\\n  REGISTER_ADAPTIVE_POOL_KERNEL(device, float)            \\\n  REGISTER_ADAPTIVE_POOL_KERNEL(device, double)           \\\n  REGISTER_ADAPTIVE_POOL_KERNEL(device, int)\n\nREGISTER_ADAPTIVE_POOL_KERNEL_WITH_DEVICE(DeviceType::kCPU)\n\n#define REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL(device, dtype)                            \\\n  REGISTER_USER_KERNEL(\"adaptive_avg_pool1d_grad\")                                       \\\n      .SetCreateFn<AdaptivePool1DCpuGradKernel<device, dtype>>()                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                              \\\n                       && (user_op::HobDataType(\"dx\", 0) == 
GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"adaptive_avg_pool2d_grad\")                                       \\\n      .SetCreateFn<AdaptivePool2DCpuGradKernel<device, dtype>>()                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                              \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"adaptive_avg_pool3d_grad\")                                       \\\n      .SetCreateFn<AdaptivePool3DCpuGradKernel<device, dtype>>()                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                              \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\n#define REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL_WITH_DEVICE(device) \\\n  REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL(device, float16)          \\\n  REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL(device, float)            \\\n  REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL(device, double)           \\\n  REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL(device, int)\n\nREGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL_WITH_DEVICE(DeviceType::kCPU)\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/adaptive_avg_pool_gpu_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.cuh\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/kernel/util/cuda_half_util.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/operator/operator_util.h\"\n#include \"oneflow/user/utils/pool_util.h\"\n#include \"oneflow/user/kernels/adaptive_pool_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\ntemplate<typename T>\n__global__ void InitPtr(int elements, T* ptr) {\n  int gid = (blockDim.x * blockIdx.x) + threadIdx.x;\n  int step = gridDim.x * blockDim.x;\n  while (gid < elements) {\n    ptr[gid] = static_cast<T>(0);\n    gid += step;\n  }\n}\n\ninline Shape GetShape5D(const Shape& shape, const std::string& data_format, int32_t dim) {\n  FixedDimVector shape_3d = {GetInDim(shape, data_format, 0, dim),\n                             GetInDim(shape, data_format, 1, dim),\n                             GetInDim(shape, data_format, 2, dim)};\n  return Shape({shape.At(0), shape.At(1), shape_3d.at(0), shape_3d.at(1), shape_3d.at(2)});\n}\n\ntemplate<typename T>\n__global__ void AdaptiveAvgPoolCudaKernel(const T* input, T* output, int num_elems, int in_d,\n                                          int in_h, int in_w, int out_d, int out_h, int out_w) {\n  const int out_panel_size = out_d * out_h * out_w;\n  const int in_panel_size = in_d * in_h * in_w;\n\n  CUDA_1D_KERNEL_LOOP(idx, num_elems) {\n    // TODO (Tianyu): Replace following codes with 'NdIndexOffsetHelper'\n    int bc_idx = idx / out_panel_size;\n    int out_d_idx = (idx % out_panel_size) / out_w / out_h;\n    int out_h_idx = (idx % out_panel_size) % (out_h * out_w) / out_w;\n    int out_w_idx = (idx % out_panel_size) % (out_h * out_w) % out_w;\n\n    int in_start_d = START_IND(out_d_idx, out_d, in_d);\n    int in_end_d = END_IND(out_d_idx, out_d, in_d);\n    int k_d = in_end_d - in_start_d;\n\n    int in_start_h = START_IND(out_h_idx, out_h, in_h);\n    int in_end_h = END_IND(out_h_idx, out_h, in_h);\n    int k_h = in_end_h - in_start_h;\n\n    int in_start_w = START_IND(out_w_idx, out_w, in_w);\n    int in_end_w = END_IND(out_w_idx, out_w, in_w);\n    int k_w = in_end_w - in_start_w;\n\n    const T* in_ptr =\n        input + bc_idx * in_panel_size + in_start_d * in_h * in_w + in_start_h * in_w + in_start_w;\n    T sum = static_cast<T>(0);\n    for (int id = 0; id < k_d; ++id) {\n      for (int ih = 0; ih < k_h; ++ih) {\n        for (int iw = 0; iw < k_w; ++iw) {\n          T val = *(in_ptr + ih * in_w + iw);\n          sum += val;\n        }\n      }\n      in_ptr += in_h * in_w;  // next input depth\n    }\n    // Update output\n    output[idx] = sum / static_cast<T>(k_d) / static_cast<T>(k_h) / static_cast<T>(k_w);\n  }\n}\n\ntemplate<typename T>\n__global__ void AdaptiveAvgPoolGradCudaKernel(T* input, const T* output, int 
num_elems, int in_d,\n                                              int in_h, int in_w, int out_d, int out_h, int out_w) {\n  const int out_panel_size = out_d * out_h * out_w;\n  const int in_panel_size = in_d * in_h * in_w;\n\n  CUDA_1D_KERNEL_LOOP(idx, num_elems) {\n    // TODO (Tianyu): Replace the following code with 'NdIndexOffsetHelper'\n    int bc_idx = idx / out_panel_size;\n    int out_d_idx = (idx % out_panel_size) / out_w / out_h;\n    int out_h_idx = (idx % out_panel_size) % (out_h * out_w) / out_w;\n    int out_w_idx = (idx % out_panel_size) % (out_h * out_w) % out_w;\n\n    int in_start_d = START_IND(out_d_idx, out_d, in_d);\n    int in_end_d = END_IND(out_d_idx, out_d, in_d);\n    int k_d = in_end_d - in_start_d;\n\n    int in_start_h = START_IND(out_h_idx, out_h, in_h);\n    int in_end_h = END_IND(out_h_idx, out_h, in_h);\n    int k_h = in_end_h - in_start_h;\n\n    int in_start_w = START_IND(out_w_idx, out_w, in_w);\n    int in_end_w = END_IND(out_w_idx, out_w, in_w);\n    int k_w = in_end_w - in_start_w;\n\n    const T grad_delta =\n        output[idx] / static_cast<T>(k_d) / static_cast<T>(k_h) / static_cast<T>(k_w);\n    T* input_ptr =\n        input + bc_idx * in_panel_size + in_start_d * in_h * in_w + in_start_h * in_w + in_start_w;\n    for (int id = 0; id < k_d; ++id) {\n      for (int ih = 0; ih < k_h; ++ih) {\n        for (int iw = 0; iw < k_w; ++iw) {\n          // TODO (Tianyu): Use 'atomic::Add' only when necessary\n          cuda::atomic::Add(input_ptr + ih * in_w + iw, grad_delta);\n        }\n      }\n      input_ptr += in_h * in_w;  // next input depth\n    }\n  }\n}\n\ntemplate<typename T>\nvoid AvgForwardCompute(KernelComputeContext* ctx, const int32_t& dim) {\n  const Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n  Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n  const T* in_ptr = in_tensor->dptr<T>();\n  T* out_ptr = out_tensor->mut_dptr<T>();\n\n  const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0)->shape();\n  const Shape& y_shape = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0)->shape();\n\n  // TODO (Tianyu): Support 'channels_last'\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  CHECK_OR_THROW(data_format == \"channels_first\")\n      << \"adaptive_avg_pool on CUDA only supports NCHW data format\";\n  const Shape& in = GetShape5D(x_shape, data_format, dim);\n  const Shape& out = GetShape5D(y_shape, data_format, dim);\n\n  const int out_elems = out_tensor->shape_view().elem_cnt();\n\n  RUN_CUDA_KERNEL((AdaptiveAvgPoolCudaKernel<T>), ctx->stream(), out_elems, in_ptr, out_ptr,\n                  out_elems, in.At(2), in.At(3), in.At(4), out.At(2), out.At(3), out.At(4));\n}\n\ntemplate<typename T>\nvoid AvgBackwardCompute(KernelComputeContext* ctx, const int32_t& dim) {\n  const Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n  Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n  const T* out_ptr = out_tensor->dptr<T>();\n  T* in_ptr = in_tensor->mut_dptr<T>();\n\n  const Shape& dx_shape = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0)->shape();\n  const Shape& dy_shape = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->shape();\n\n  // TODO (Tianyu): Support 'channels_last'\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  CHECK_OR_THROW(data_format == \"channels_first\")\n      << \"adaptive_avg_pool backward on CUDA only supports NCHW data format\";\n  const Shape& in = GetShape5D(dx_shape, data_format, dim);\n  const Shape& out = 
GetShape5D(dy_shape, data_format, dim);\n\n  const int in_elems = in_tensor->shape_view().elem_cnt();\n  const int out_elems = out_tensor->shape_view().elem_cnt();\n\n  RUN_CUDA_KERNEL((InitPtr<T>), ctx->stream(), in_elems, in_elems, in_ptr);\n  RUN_CUDA_KERNEL((AdaptiveAvgPoolGradCudaKernel<T>), ctx->stream(), out_elems, in_ptr, out_ptr,\n                  out_elems, in.At(2), in.At(3), in.At(4), out.At(2), out.At(3), out.At(4));\n}\n\ntemplate<DeviceType device_type, typename T>\nclass GpuAdaptiveAvgPool1dKernel final : public OpKernel {\n public:\n  GpuAdaptiveAvgPool1dKernel() = default;\n  ~GpuAdaptiveAvgPool1dKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(KernelComputeContext* ctx) const override { AvgForwardCompute<T>(ctx, 1); }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass GpuAdaptiveAvgPool2dKernel final : public OpKernel {\n public:\n  GpuAdaptiveAvgPool2dKernel() = default;\n  ~GpuAdaptiveAvgPool2dKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(KernelComputeContext* ctx) const override { AvgForwardCompute<T>(ctx, 2); }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass GpuAdaptiveAvgPool3dKernel final : public OpKernel {\n public:\n  GpuAdaptiveAvgPool3dKernel() = default;\n  ~GpuAdaptiveAvgPool3dKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(KernelComputeContext* ctx) const override { AvgForwardCompute<T>(ctx, 3); }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass GpuAdaptiveAvgPool1dGradKernel final : public OpKernel {\n public:\n  GpuAdaptiveAvgPool1dGradKernel() = default;\n  ~GpuAdaptiveAvgPool1dGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(KernelComputeContext* ctx) const override { AvgBackwardCompute<T>(ctx, 1); }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass GpuAdaptiveAvgPool2dGradKernel final : public OpKernel {\n public:\n  GpuAdaptiveAvgPool2dGradKernel() = default;\n  ~GpuAdaptiveAvgPool2dGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(KernelComputeContext* ctx) const override { AvgBackwardCompute<T>(ctx, 2); }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass GpuAdaptiveAvgPool3dGradKernel final : public OpKernel {\n public:\n  GpuAdaptiveAvgPool3dGradKernel() = default;\n  ~GpuAdaptiveAvgPool3dGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(KernelComputeContext* ctx) const override { AvgBackwardCompute<T>(ctx, 3); }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_ADAPTIVE_AVGPOOL_KERNEL(device, dtype)                   \\\n  REGISTER_USER_KERNEL(\"adaptive_avg_pool1d\")                                  \\\n      .SetCreateFn<GpuAdaptiveAvgPool1dKernel<device, dtype>>()                \\\n      .SetIsMatchedHob((HobDeviceType() == device)                             \\\n                       && (HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"adaptive_avg_pool2d\")                                  \\\n   
   .SetCreateFn<GpuAdaptiveAvgPool2dKernel<device, dtype>>()                \\\n      .SetIsMatchedHob((HobDeviceType() == device)                             \\\n                       && (HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"adaptive_avg_pool3d\")                                  \\\n      .SetCreateFn<GpuAdaptiveAvgPool3dKernel<device, dtype>>()                \\\n      .SetIsMatchedHob((HobDeviceType() == device)                             \\\n                       && (HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CUDA_ADAPTIVE_AVGPOOL_KERNEL(DeviceType::kCUDA, half);\nREGISTER_CUDA_ADAPTIVE_AVGPOOL_KERNEL(DeviceType::kCUDA, float);\nREGISTER_CUDA_ADAPTIVE_AVGPOOL_KERNEL(DeviceType::kCUDA, double);\nREGISTER_CUDA_ADAPTIVE_AVGPOOL_KERNEL(DeviceType::kCUDA, int);\n\n#define REGISTER_CUDA_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(device, dtype)           \\\n  REGISTER_USER_KERNEL(\"adaptive_avg_pool1d_grad\")                              \\\n      .SetCreateFn<GpuAdaptiveAvgPool1dGradKernel<device, dtype>>()             \\\n      .SetIsMatchedHob((HobDeviceType() == device)                              \\\n                       && (HobDataType(\"dx\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"adaptive_avg_pool2d_grad\")                              \\\n      .SetCreateFn<GpuAdaptiveAvgPool2dGradKernel<device, dtype>>()             \\\n      .SetIsMatchedHob((HobDeviceType() == device)                              \\\n                       && (HobDataType(\"dx\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"adaptive_avg_pool3d_grad\")                              \\\n      .SetCreateFn<GpuAdaptiveAvgPool3dGradKernel<device, dtype>>()             \\\n      .SetIsMatchedHob((HobDeviceType() == device)                              \\\n                       && (HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CUDA_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(DeviceType::kCUDA, half);\nREGISTER_CUDA_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(DeviceType::kCUDA, float);\nREGISTER_CUDA_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(DeviceType::kCUDA, double);\nREGISTER_CUDA_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(DeviceType::kCUDA, int);\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/adaptive_max_pool_cpu_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/adaptive_pool_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\ntemplate<typename T, int32_t dim>\nvoid AdapativeMaxPoolForward(user_op::KernelComputeContext* ctx) {\n  user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n  user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n  user_op::Tensor* index_tensor = ctx->Tensor4ArgNameAndIndex(\"index\", 0);\n  const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0)->shape();\n  const Shape& y_shape = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0)->shape();\n\n  // TODO : Support 'channels_last'\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  CHECK_OR_THROW(data_format == \"channels_first\")\n      << \"adaptive_max_pool on cpu only supports NCHW data format\";\n  const Shape& in = GetShape5D(x_shape, data_format, dim);\n  const Shape& out = GetShape5D(y_shape, data_format, dim);\n\n  const T* in_ptr = in_tensor->dptr<T>();\n  T* out_ptr = out_tensor->mut_dptr<T>();\n  int64_t* index_ptr = index_tensor->mut_dptr<int64_t>();\n\n  const int64_t input_width = in.Count(4);\n  const int64_t output_width = out.Count(4);\n  const int64_t input_image_size = in.Count(3);\n  const int64_t output_image_size = out.Count(3);\n  const int64_t input_size = in.Count(2);\n  const int64_t output_size = out.Count(2);\n\n  FOR_RANGE(int64_t, n, 0, in.At(0)) {\n    FOR_RANGE(int64_t, c, 0, in.At(1)) {\n      FOR_RANGE(int64_t, od, 0, out.At(2)) {\n        int64_t id0 = start_index(od, out.At(2), in.At(2));\n        int64_t id1 = end_index(od, out.At(2), in.At(2));\n        FOR_RANGE(int64_t, oh, 0, out.At(3)) {\n          int64_t ih0 = start_index(oh, out.At(3), in.At(3));\n          int64_t ih1 = end_index(oh, out.At(3), in.At(3));\n          FOR_RANGE(int64_t, ow, 0, out.At(4)) {\n            int64_t iw0 = start_index(ow, out.At(4), in.At(4));\n            int64_t iw1 = end_index(ow, out.At(4), in.At(4));\n\n            // Find out local max\n            auto start_offset = id0 * input_image_size + ih0 * input_width + iw0;\n            T local_max = in_ptr[start_offset];\n            int64_t local_max_index = start_offset;\n            FOR_RANGE(int64_t, id, id0, id1) {\n              FOR_RANGE(int64_t, ih, ih0, ih1) {\n                FOR_RANGE(int64_t, iw, iw0, iw1) {\n                  auto cur_index = id * input_image_size + ih * input_width + iw;\n                  if (in_ptr[cur_index] > local_max) {\n                    local_max_index = cur_index;\n                    local_max = in_ptr[cur_index];\n                  }\n                }\n              }\n            }\n            auto i = od * output_image_size + oh * output_width + ow;\n            out_ptr[i] = local_max;\n            index_ptr[i] = local_max_index;\n          }\n        }\n      }\n      in_ptr += input_size;\n      index_ptr += output_size;\n      out_ptr += 
output_size;\n    }\n  }\n}\n\ntemplate<typename T, int32_t dim>\nvoid AdaptiveMaxPoolBackward(user_op::KernelComputeContext* ctx) {\n  user_op::Tensor* grad_input = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n  const user_op::Tensor* grad_output = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n  const user_op::Tensor* return_indices = ctx->Tensor4ArgNameAndIndex(\"index\", 0);\n  const Shape& dx_shape = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0)->shape();\n  const Shape& dy_shape = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->shape();\n\n  // TODO: Support 'channels_last'\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  CHECK_OR_THROW(data_format == \"channels_first\")\n      << \"adaptive_max_pool backward on cpu only supports NCHW data format\";\n  const Shape& in = GetShape5D(dx_shape, data_format, dim);\n  const Shape& out = GetShape5D(dy_shape, data_format, dim);\n\n  const T* dy_ptr = grad_output->dptr<T>();\n  const int64_t* indices_ptr = return_indices->dptr<int64_t>();\n  T* dx_ptr = grad_input->mut_dptr<T>();\n\n  std::fill(dx_ptr, dx_ptr + grad_input->shape_view().elem_cnt(), static_cast<T>(0));\n\n  const int64_t output_width = out.Count(4);\n  const int64_t output_image_size = out.Count(3);\n  const int64_t input_size = in.Count(2);\n  const int64_t output_size = out.Count(2);\n\n  // Scatter each dy element back to the input position recorded in 'index'.\n  FOR_RANGE(int64_t, n, 0, in.At(0)) {\n    FOR_RANGE(int64_t, c, 0, in.At(1)) {\n      FOR_RANGE(int64_t, od, 0, out.At(2)) {\n        FOR_RANGE(int64_t, oh, 0, out.At(3)) {\n          FOR_RANGE(int64_t, ow, 0, out.At(4)) {\n            auto i = od * output_image_size + oh * output_width + ow;\n            dx_ptr[indices_ptr[i]] += dy_ptr[i];\n          }\n        }\n      }\n      dx_ptr += input_size;\n      dy_ptr += output_size;\n      indices_ptr += output_size;\n    }\n  }\n}\n}  // namespace\n\ntemplate<typename T, int32_t dim>\nclass AdaptiveMaxPoolNDCpuKernel final : public user_op::OpKernel {\n public:\n  AdaptiveMaxPoolNDCpuKernel() = default;\n  ~AdaptiveMaxPoolNDCpuKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    AdaptiveMaxPoolForward<T, dim>(ctx);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T, int32_t dim>\nclass AdaptiveMaxPoolNDGradCpuKernel final : public user_op::OpKernel {\n public:\n  AdaptiveMaxPoolNDGradCpuKernel() = default;\n  ~AdaptiveMaxPoolNDGradCpuKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    AdaptiveMaxPoolBackward<T, dim>(ctx);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_ADAPTIVE_MAX_POOLND_CPU(op_type_name, dtype, dim)                      \\\n  REGISTER_USER_KERNEL(op_type_name)                                                    \\\n      .SetCreateFn<AdaptiveMaxPoolNDCpuKernel<dtype, dim>>()                            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n                                                                                        \\\n  REGISTER_USER_KERNEL(op_type_name \"_grad\")                                            \\\n      .SetCreateFn<AdaptiveMaxPoolNDGradCpuKernel<dtype, dim>>()                        \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && 
(user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\n#define REGISTER_ADAPTIVE_MAX_POOL_CPU(op_type_name, dim)      \\\n  REGISTER_ADAPTIVE_MAX_POOLND_CPU(op_type_name, double, dim); \\\n  REGISTER_ADAPTIVE_MAX_POOLND_CPU(op_type_name, float, dim);  \\\n  REGISTER_ADAPTIVE_MAX_POOLND_CPU(op_type_name, int, dim);\n\nREGISTER_ADAPTIVE_MAX_POOL_CPU(\"adaptive_max_pool1d\", 1);\nREGISTER_ADAPTIVE_MAX_POOL_CPU(\"adaptive_max_pool2d\", 2);\nREGISTER_ADAPTIVE_MAX_POOL_CPU(\"adaptive_max_pool3d\", 3);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/adaptive_max_pool_gpu_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.cuh\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/kernel/util/cuda_half_util.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/operator/operator_util.h\"\n#include \"oneflow/user/utils/pool_util.h\"\n#include \"oneflow/user/kernels/adaptive_pool_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\ntemplate<typename T>\n__global__ void AdaptiveMaxPoolCudaKernel(const T* input, T* output, int64_t* return_index,\n                                          int num_elems, int in_d, int in_h, int in_w, int out_d,\n                                          int out_h, int out_w) {\n  const int out_panel_size = out_d * out_h * out_w;\n  const int in_panel_size = in_d * in_h * in_w;\n  const int out_hw = out_w * out_h;\n\n  CUDA_1D_KERNEL_LOOP(idx, num_elems) {\n    int bc_idx = idx / out_panel_size;\n    int out_d_idx = (idx % out_panel_size) / out_hw;\n    int out_h_idx = (idx % out_panel_size) % (out_h * out_w) / out_w;\n    int out_w_idx = (idx % out_panel_size) % (out_h * out_w) % out_w;\n\n    int in_start_d = START_IND(out_d_idx, out_d, in_d);\n    int in_end_d = END_IND(out_d_idx, out_d, in_d);\n    int k_d = in_end_d - in_start_d;\n\n    int in_start_h = START_IND(out_h_idx, out_h, in_h);\n    int in_end_h = END_IND(out_h_idx, out_h, in_h);\n    int k_h = in_end_h - in_start_h;\n\n    int in_start_w = START_IND(out_w_idx, out_w, in_w);\n    int in_end_w = END_IND(out_w_idx, out_w, in_w);\n    int k_w = in_end_w - in_start_w;\n\n    int64_t batch_idx_base = bc_idx * in_panel_size;\n    const T* in_ptr =\n        input + batch_idx_base + in_start_d * in_h * in_w + in_start_h * in_w + in_start_w;\n    T local_max = in_ptr[0];\n    int64_t local_max_index = static_cast<int64_t>(in_ptr - input) - batch_idx_base;\n    for (int id = 0; id < k_d; ++id) {\n      for (int ih = 0; ih < k_h; ++ih) {\n        for (int iw = 0; iw < k_w; ++iw) {\n          T val = *(in_ptr + ih * in_w + iw);\n          if (val > local_max) {\n            local_max = val;\n            local_max_index = in_ptr - input - batch_idx_base + ih * in_w + iw;\n          }\n        }\n      }\n      in_ptr += in_h * in_w;  // next input depth\n    }\n\n    output[idx] = local_max;\n    return_index[idx] = local_max_index;\n  }\n}\n\ntemplate<typename T>\n__global__ void AdaptiveMaxPoolGradCudaKernel(T* input, const T* output, const int64_t* index,\n                                              int dy_elems, int in_panel_size, int out_panel_size) {\n  CUDA_1D_KERNEL_LOOP(idx, dy_elems) {\n    int bc_idx = idx / out_panel_size;\n    T* input_ptr = input + bc_idx * in_panel_size;\n    cuda::atomic::Add(input_ptr + index[idx], output[idx]);\n  }\n}\n\ntemplate<typename T, int32_t dim>\nvoid MaxForwardCompute(KernelComputeContext* 
ctx) {\n  const Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n  Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n  Tensor* return_indices = ctx->Tensor4ArgNameAndIndex(\"index\", 0);\n\n  const T* in_ptr = in_tensor->dptr<T>();\n  T* out_ptr = out_tensor->mut_dptr<T>();\n  int64_t* index_ptr = return_indices->mut_dptr<int64_t>();\n\n  const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0)->shape();\n  const Shape& y_shape = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0)->shape();\n\n  // TODO: Support 'channels_last'\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  CHECK_OR_THROW(data_format == \"channels_first\")\n      << \"adaptive_max_pool on CUDA only supports NCHW data format\";\n  const Shape& in = GetShape5D(x_shape, data_format, dim);\n  const Shape& out = GetShape5D(y_shape, data_format, dim);\n\n  const int out_elems = out_tensor->shape_view().elem_cnt();\n\n  RUN_CUDA_KERNEL((AdaptiveMaxPoolCudaKernel<T>), ctx->stream(), out_elems, in_ptr, out_ptr,\n                  index_ptr, out_elems, in.At(2), in.At(3), in.At(4), out.At(2), out.At(3),\n                  out.At(4));\n}\n\ntemplate<typename T, int32_t dim>\nvoid MaxBackwardCompute(KernelComputeContext* ctx) {\n  const Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n  Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n  const user_op::Tensor* return_indices = ctx->Tensor4ArgNameAndIndex(\"index\", 0);\n\n  const T* out_ptr = out_tensor->dptr<T>();\n  T* in_ptr = in_tensor->mut_dptr<T>();\n  const int64_t* index_ptr = return_indices->dptr<int64_t>();\n\n  const Shape& dx_shape = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0)->shape();\n  const Shape& dy_shape = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->shape();\n\n  // TODO (Tianyu): Support 'channels_last'\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  CHECK_OR_THROW(data_format == \"channels_first\")\n      << \"adaptive_max_pool backward on CUDA only supports NCHW data format\";\n  const Shape& in = GetShape5D(dx_shape, data_format, dim);\n  const Shape& out = GetShape5D(dy_shape, data_format, dim);\n\n  const int in_elems = in_tensor->shape_view().elem_cnt();\n  const int out_elems = out_tensor->shape_view().elem_cnt();\n\n  std::unique_ptr<ep::primitive::Memset> memset_primitive =\n      ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->device_type());\n  CHECK(memset_primitive);\n  memset_primitive->Launch(ctx->stream(), in_ptr, 0, in_elems * sizeof(T));\n  RUN_CUDA_KERNEL((AdaptiveMaxPoolGradCudaKernel<T>), ctx->stream(), out_elems, in_ptr, out_ptr,\n                  index_ptr, out_elems, in.At(2) * in.At(3) * in.At(4),\n                  out.At(2) * out.At(3) * out.At(4));\n}\n\ntemplate<DeviceType device_type, typename T, int32_t dim>\nclass GpuAdaptiveMaxPoolNdKernel final : public OpKernel {\n public:\n  GpuAdaptiveMaxPoolNdKernel() = default;\n  ~GpuAdaptiveMaxPoolNdKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(KernelComputeContext* ctx) const override { MaxForwardCompute<T, dim>(ctx); }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T, int32_t dim>\nclass GpuAdaptiveMaxPoolNdGradKernel final : public OpKernel {\n public:\n  GpuAdaptiveMaxPoolNdGradKernel() = default;\n  ~GpuAdaptiveMaxPoolNdGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(KernelComputeContext* ctx) const 
override { MaxBackwardCompute<T, dim>(ctx); }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_ADAPTIVE_MAXPOOL_KERNEL(device, dtype)                   \\\n  REGISTER_USER_KERNEL(\"adaptive_max_pool1d\")                                  \\\n      .SetCreateFn<GpuAdaptiveMaxPoolNdKernel<device, dtype, 1>>()             \\\n      .SetIsMatchedHob((HobDeviceType() == device)                             \\\n                       && (HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"adaptive_max_pool2d\")                                  \\\n      .SetCreateFn<GpuAdaptiveMaxPoolNdKernel<device, dtype, 2>>()             \\\n      .SetIsMatchedHob((HobDeviceType() == device)                             \\\n                       && (HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"adaptive_max_pool3d\")                                  \\\n      .SetCreateFn<GpuAdaptiveMaxPoolNdKernel<device, dtype, 3>>()             \\\n      .SetIsMatchedHob((HobDeviceType() == device)                             \\\n                       && (HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CUDA_ADAPTIVE_MAXPOOL_KERNEL(DeviceType::kCUDA, float);\nREGISTER_CUDA_ADAPTIVE_MAXPOOL_KERNEL(DeviceType::kCUDA, double);\nREGISTER_CUDA_ADAPTIVE_MAXPOOL_KERNEL(DeviceType::kCUDA, int);\n\n#define REGISTER_CUDA_ADAPTIVE_MAXPOOL_BACKWARD_KERNEL(device, dtype)           \\\n  REGISTER_USER_KERNEL(\"adaptive_max_pool1d_grad\")                              \\\n      .SetCreateFn<GpuAdaptiveMaxPoolNdGradKernel<device, dtype, 1>>()          \\\n      .SetIsMatchedHob((HobDeviceType() == device)                              \\\n                       && (HobDataType(\"dx\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"adaptive_max_pool2d_grad\")                              \\\n      .SetCreateFn<GpuAdaptiveMaxPoolNdGradKernel<device, dtype, 2>>()          \\\n      .SetIsMatchedHob((HobDeviceType() == device)                              \\\n                       && (HobDataType(\"dx\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"adaptive_max_pool3d_grad\")                              \\\n      .SetCreateFn<GpuAdaptiveMaxPoolNdGradKernel<device, dtype, 3>>()          \\\n      .SetIsMatchedHob((HobDeviceType() == device)                              \\\n                       && (HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CUDA_ADAPTIVE_MAXPOOL_BACKWARD_KERNEL(DeviceType::kCUDA, float);\nREGISTER_CUDA_ADAPTIVE_MAXPOOL_BACKWARD_KERNEL(DeviceType::kCUDA, double);\nREGISTER_CUDA_ADAPTIVE_MAXPOOL_BACKWARD_KERNEL(DeviceType::kCUDA, int);\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/adaptive_pool_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef _ONEFLOW_USER_KERNELS_ADAPTIVE_POOL_UTIL_H_\n#define _ONEFLOW_USER_KERNELS_ADAPTIVE_POOL_UTIL_H_\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/operator/operator_util.h\"\n#include \"oneflow/user/utils/pool_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ninline int64_t start_index(int64_t a, int64_t b, int64_t c) {\n  return (int64_t)std::floor((float)(a * c) / b);\n}\n\ninline int64_t end_index(int64_t a, int64_t b, int64_t c) {\n  return (int64_t)std::ceil((float)((a + 1) * c) / b);\n}\n\n#define START_IND(a, b, c) (int)std::floor((float)(a * c) / b)\n#define END_IND(a, b, c) (int)std::ceil((float)((a + 1) * c) / b)\n\n#define START_IND_INT(a, b, c) ((a * c) / b)\n#define END_IND_INT(a, b, c) (((a + 1) * c + b - 1) / b)\n\ninline Shape GetShape5D(const Shape& shape, const std::string& data_format, int32_t dim) {\n  FixedDimVector shape_3d = {GetInDim(shape, data_format, 0, dim),\n                             GetInDim(shape, data_format, 1, dim),\n                             GetInDim(shape, data_format, 2, dim)};\n  return Shape({shape.At(0), shape.At(1), shape_3d.at(0), shape_3d.at(1), shape_3d.at(2)});\n}\n\n}  // namespace\n}  // namespace oneflow\n\n#endif  // _ONEFLOW_USER_KERNELS_ADAPTIVE_POOL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/add_n_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/include/primitive/add.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Add> NewAddPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::AddFactory>(ctx->device_type(), data_type);\n}\n\nclass AddNKernel : public OpKernel, public CudaGraphSupport {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AddNKernel);\n  AddNKernel() = default;\n  ~AddNKernel() override = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    auto primitive = NewAddPrimitive(ctx);\n    CHECK(primitive);\n    Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const DataType data_type = out->data_type();\n    const size_t count = out->shape_view().elem_cnt();\n    if (count == 0) { return; }\n    size_t in_num = ctx->inputs().size();\n    std::vector<const void*> srcs(in_num);\n    for (size_t i = 0; i < in_num; ++i) {\n      const Tensor* in_i = ctx->Tensor4ArgNameAndIndex(\"in\", i);\n      CHECK_EQ(in_i->shape_view().elem_cnt(), count);\n      CHECK_EQ(in_i->data_type(), data_type);\n      srcs[i] = in_i->template dptr();\n    }\n    primitive->Launch(ctx->stream(), srcs.data(), in_num, out->mut_dptr(), count);\n  }\n};\n\nauto AddPrimitiveExists() {\n  return hob::make_custom(\"AddPrimitiveExists\", [](const KernelRegContext& ctx) {\n    return NewAddPrimitive(&ctx).operator bool();\n  });\n}\n\nREGISTER_USER_KERNEL(\"add_n\")\n    .SetCreateFn<AddNKernel>()\n    .SetIsMatchedHob(AddPrimitiveExists() == true)\n    .SetInplaceProposalFn([](const InferContext&,\n                             const AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"in\", 0, true));\n      return Maybe<void>::Ok();\n    });\n\n}  // namespace\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/affine_grid_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/config_def.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include \"affine_grid_kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) {\n  return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N;\n}\n\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitive(DeviceType device_type,\n                                                          DataType data_type, bool transpose_a,\n                                                          bool transpose_b) {\n  const auto trans_a = GetBlasTransposeType(transpose_a);\n  const auto trans_b = GetBlasTransposeType(transpose_b);\n  return ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(device_type, data_type, trans_a,\n                                                                   trans_b);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewAffineGridMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"theta\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false,\n                            /*transpose_b=*/true);\n}\n\nauto AffineGridMatmulPrimitiveExists() {\n  return hob::make_custom(\"AffineGridMatmulPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewAffineGridMatmulPrimitive(&ctx).operator bool();\n                          });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewAffineGridGradMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dgrid\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true,\n                            /*transpose_b=*/false);\n}\n\nauto AffineGridGradMatmulPrimitiveExists() {\n  return hob::make_custom(\"AffineGridGradMatmulPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewAffineGridGradMatmulPrimitive(&ctx).operator bool();\n                          });\n}\n\n}  // namespace\n\ntemplate<DeviceType device_type, typename data_type>\nclass AffineGridKernel final : public user_op::OpKernel {\n public:\n  AffineGridKernel() = default;\n  ~AffineGridKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* theta = ctx->Tensor4ArgNameAndIndex(\"theta\", 0);\n    user_op::Tensor* grid = ctx->Tensor4ArgNameAndIndex(\"grid\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const Shape& size = ctx->Attr<Shape>(\"size\");\n    const bool& align_corners = 
ctx->Attr<bool>(\"align_corners\");\n    bool is_2d_grid = true;\n    if (size.NumAxes() == 5) { is_2d_grid = false; }\n\n    int64_t N = theta->shape_view().At(0);\n    int64_t theta_h = theta->shape_view().At(1);\n    int64_t theta_w = theta->shape_view().At(2);\n\n    auto matmul = NewAffineGridMatmulPrimitive(ctx);\n    CHECK(matmul);\n\n    if (is_2d_grid) {\n      int64_t H = size.At(2);\n      int64_t W = size.At(3);\n      // generate base grid\n      GenerateBaseGridImp<device_type>::Generate2D(ctx, tmp_buffer->mut_dptr<data_type>(), H, W,\n                                                   align_corners);\n\n      // Compute each batch\n      for (int n = 0; n < N; n++) {\n        matmul->Launch(ctx->stream(), H * W, theta_h, theta_w, /*alpha=*/1.0,\n                       tmp_buffer->dptr<data_type>(),\n                       theta->dptr<data_type>() + n * theta_h * theta_w, /*beta=*/0.0,\n                       grid->mut_dptr<data_type>() + n * theta_h * H * W);\n      }\n    } else {\n      int64_t D = size.At(2);\n      int64_t H = size.At(3);\n      int64_t W = size.At(4);\n      // generate base grid\n      GenerateBaseGridImp<device_type>::Generate3D(ctx, tmp_buffer->mut_dptr<data_type>(), D, H, W,\n                                                   align_corners);\n      // Compute each batch\n      for (int n = 0; n < N; n++) {\n        matmul->Launch(ctx->stream(), D * H * W, theta_h, theta_w, /*alpha=*/1.0,\n                       tmp_buffer->dptr<data_type>(),\n                       theta->dptr<data_type>() + n * theta_h * theta_w, /*beta=*/0.0,\n                       grid->mut_dptr<data_type>() + n * theta_h * D * H * W);\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_AFFINE_GRID_KERNEL(device, dtype)                                        \\\n  REGISTER_USER_KERNEL(\"affine_grid\")                                                     \\\n      .SetCreateFn<AffineGridKernel<device, dtype>>()                                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                               \\\n                       && (user_op::HobDataType(\"theta\", 0) == GetDataType<dtype>::value) \\\n                       && AffineGridMatmulPrimitiveExists())                              \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                       \\\n        const Shape& size = ctx->Attr<Shape>(\"size\");                                     \\\n        size_t tmp_buffer_size = size.Count(2) * (size.NumAxes() - 1) * sizeof(dtype);    \\\n        return tmp_buffer_size;                                                           \\\n      })\n\nREGISTER_AFFINE_GRID_KERNEL(DeviceType::kCPU, float);\nREGISTER_AFFINE_GRID_KERNEL(DeviceType::kCPU, double);\n#ifdef WITH_CUDA\nREGISTER_AFFINE_GRID_KERNEL(DeviceType::kCUDA, float);\nREGISTER_AFFINE_GRID_KERNEL(DeviceType::kCUDA, double);\n#endif\n\ntemplate<DeviceType device_type, typename data_type>\nclass AffineGridGradKernel final : public user_op::OpKernel {\n public:\n  AffineGridGradKernel() = default;\n  ~AffineGridGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dgrid = ctx->Tensor4ArgNameAndIndex(\"dgrid\", 0);\n    user_op::Tensor* dtheta = ctx->Tensor4ArgNameAndIndex(\"dtheta\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const Shape& size = 
ctx->Attr<Shape>(\"size\");\n    const bool& align_corners = ctx->Attr<bool>(\"align_corners\");\n    bool is_2d_grid = true;\n    if (size.NumAxes() == 5) { is_2d_grid = false; }\n\n    int64_t N = dtheta->shape_view().At(0);\n    int64_t dtheta_h = dtheta->shape_view().At(1);\n    int64_t dtheta_w = dtheta->shape_view().At(2);\n\n    auto matmul = NewAffineGridGradMatmulPrimitive(ctx);\n    CHECK(matmul);\n\n    if (is_2d_grid) {\n      int64_t H = size.At(2);\n      int64_t W = size.At(3);\n      // generate base grid\n      GenerateBaseGridImp<device_type>::Generate2D(ctx, tmp_buffer->mut_dptr<data_type>(), H, W,\n                                                   align_corners);\n      // Compute each batch\n      for (int n = 0; n < N; n++) {\n        matmul->Launch(ctx->stream(), dtheta_h, dtheta_w, H * W, /*alpha=*/1.0,\n                       dgrid->dptr<data_type>() + n * dtheta_h * H * W,\n                       tmp_buffer->dptr<data_type>(), /*beta=*/0.0,\n                       dtheta->mut_dptr<data_type>() + n * dtheta_h * dtheta_w);\n      }\n    } else {\n      int64_t D = size.At(2);\n      int64_t H = size.At(3);\n      int64_t W = size.At(4);\n      GenerateBaseGridImp<device_type>::Generate3D(ctx, tmp_buffer->mut_dptr<data_type>(), D, H, W,\n                                                   align_corners);\n      // Compute each batch\n      for (int n = 0; n < N; n++) {\n        matmul->Launch(ctx->stream(), dtheta_h, dtheta_w, D * H * W, /*alpha=*/1.0,\n                       dgrid->dptr<data_type>() + n * dtheta_h * D * H * W,\n                       tmp_buffer->dptr<data_type>(), /*beta=*/0.0,\n                       dtheta->mut_dptr<data_type>() + n * dtheta_h * dtheta_w);\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_AFFINE_GRID_GRAD_KERNEL(device, dtype)                                   \\\n  REGISTER_USER_KERNEL(\"affine_grid_grad\")                                                \\\n      .SetCreateFn<AffineGridGradKernel<device, dtype>>()                                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                               \\\n                       && (user_op::HobDataType(\"dgrid\", 0) == GetDataType<dtype>::value) \\\n                       && AffineGridGradMatmulPrimitiveExists())                          \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                       \\\n        const Shape& size = ctx->Attr<Shape>(\"size\");                                     \\\n        size_t tmp_buffer_size = size.Count(2) * (size.NumAxes() - 1) * sizeof(dtype);    \\\n        return tmp_buffer_size;                                                           \\\n      })\n\nREGISTER_AFFINE_GRID_GRAD_KERNEL(DeviceType::kCPU, float);\nREGISTER_AFFINE_GRID_GRAD_KERNEL(DeviceType::kCPU, double);\n#ifdef WITH_CUDA\nREGISTER_AFFINE_GRID_GRAD_KERNEL(DeviceType::kCUDA, float);\nREGISTER_AFFINE_GRID_GRAD_KERNEL(DeviceType::kCUDA, double);\n#endif\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/affine_grid_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"affine_grid_kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename data_type, bool align_corners>\nOF_DEVICE_FUNC data_type LinspaceGPU(int32_t index, int32_t num_steps) {\n  if (num_steps <= 1) { return static_cast<data_type>(0.0); }\n\n  if (align_corners) {\n    return static_cast<data_type>(-1.0 + 2.0 / (num_steps - 1) * index);\n  } else {\n    return static_cast<data_type>((-1.0 + 2.0 / (num_steps - 1) * index) * (num_steps - 1)\n                                  / num_steps);\n  }\n}\n\ntemplate<typename data_type, bool align_corners>\n__global__ void Generate2DBaseGridGPUKernel(const int32_t nthreads, data_type* grid_ptr, int32_t H,\n                                            int32_t W) {\n  CUDA_1D_KERNEL_LOOP(index, nthreads) {\n    const int32_t h = index / W;\n    const int32_t w = index % W;\n    const int32_t pixel_length = 3;\n    data_type* row_ptr = grid_ptr + h * W * pixel_length;\n    data_type* pixel_ptr = row_ptr + w * pixel_length;\n    data_type h_value = LinspaceGPU<data_type, align_corners>(h, H);\n    data_type w_value = LinspaceGPU<data_type, align_corners>(w, W);\n\n    pixel_ptr[0] = w_value;\n    pixel_ptr[1] = h_value;\n    pixel_ptr[2] = static_cast<data_type>(1.0);\n  }\n}\n\ntemplate<typename data_type, bool align_corners>\n__global__ void Generate3DBaseGridGPUKernel(const int32_t nthreads, data_type* grid_ptr, int32_t D,\n                                            int32_t H, int32_t W) {\n  CUDA_1D_KERNEL_LOOP(index, nthreads) {\n    const int32_t d = index / H;\n    const int32_t h = index % H;\n    const int32_t pixel_length = 4;\n    data_type* image_ptr = grid_ptr + d * H * W * pixel_length;\n    data_type* row_ptr = image_ptr + h * W * pixel_length;\n    data_type d_value = LinspaceGPU<data_type, align_corners>(d, D);\n    data_type h_value = LinspaceGPU<data_type, align_corners>(h, H);\n\n    for (int32_t w = 0; w < W; ++w) {\n      data_type* pixel_ptr = row_ptr + w * pixel_length;\n      data_type w_value = LinspaceGPU<data_type, align_corners>(w, W);\n      pixel_ptr[0] = w_value;\n      pixel_ptr[1] = h_value;\n      pixel_ptr[2] = d_value;\n      pixel_ptr[3] = static_cast<data_type>(1.0);\n    }\n  }\n}\n\n}  // namespace\n\nvoid GenerateBaseGridImp<DeviceType::kCUDA>::Generate2D(user_op::KernelComputeContext* ctx,\n                                                        float* grid_ptr, int64_t H, int64_t W,\n                                                        bool align_corners) {\n  int count = H * W;\n  if (align_corners) {\n    RUN_CUDA_KERNEL((Generate2DBaseGridGPUKernel<float, true>), ctx->stream(), count, count,\n                    grid_ptr, H, W);\n  } else {\n    RUN_CUDA_KERNEL((Generate2DBaseGridGPUKernel<float, false>), ctx->stream(), count, 
count,\n                    grid_ptr, H, W);\n  }\n}\n\nvoid GenerateBaseGridImp<DeviceType::kCUDA>::Generate2D(user_op::KernelComputeContext* ctx,\n                                                        double* grid_ptr, int64_t H, int64_t W,\n                                                        bool align_corners) {\n  int count = H * W;\n  if (align_corners) {\n    RUN_CUDA_KERNEL((Generate2DBaseGridGPUKernel<double, true>), ctx->stream(), count, count,\n                    grid_ptr, H, W);\n  } else {\n    RUN_CUDA_KERNEL((Generate2DBaseGridGPUKernel<double, false>), ctx->stream(), count, count,\n                    grid_ptr, H, W);\n  }\n}\n\nvoid GenerateBaseGridImp<DeviceType::kCUDA>::Generate3D(user_op::KernelComputeContext* ctx,\n                                                        float* grid_ptr, int64_t D, int64_t H,\n                                                        int64_t W, bool align_corners) {\n  int count = D * H;\n  if (align_corners) {\n    RUN_CUDA_KERNEL((Generate3DBaseGridGPUKernel<float, true>), ctx->stream(), count, count,\n                    grid_ptr, D, H, W);\n  } else {\n    RUN_CUDA_KERNEL((Generate3DBaseGridGPUKernel<float, false>), ctx->stream(), count, count,\n                    grid_ptr, D, H, W);\n  }\n}\n\nvoid GenerateBaseGridImp<DeviceType::kCUDA>::Generate3D(user_op::KernelComputeContext* ctx,\n                                                        double* grid_ptr, int64_t D, int64_t H,\n                                                        int64_t W, bool align_corners) {\n  int count = D * H;\n  if (align_corners) {\n    RUN_CUDA_KERNEL((Generate3DBaseGridGPUKernel<double, true>), ctx->stream(), count, count,\n                    grid_ptr, D, H, W);\n  } else {\n    RUN_CUDA_KERNEL((Generate3DBaseGridGPUKernel<double, false>), ctx->stream(), count, count,\n                    grid_ptr, D, H, W);\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/affine_grid_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef _ONEFLOW_USER_KERNELS_ACTIVATION_KERNELS_H_\n#define _ONEFLOW_USER_KERNELS_ACTIVATION_KERNELS_H_\n\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/common/device_type.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type>\nstruct GenerateBaseGridImp {};\n\ntemplate<>\nstruct GenerateBaseGridImp<DeviceType::kCPU> {\n  template<typename data_type>\n  static void Linspace(std::vector<data_type>& grid, int64_t num_steps, bool align_corners) {\n    if (num_steps <= 1) {\n      for (auto& it : grid) { it = static_cast<data_type>(0.0); }\n      return;\n    }\n\n    if (align_corners) {\n      for (int i = 0; i < num_steps; i++) {\n        grid[i] = static_cast<data_type>(-1.0 + 2.0 / (num_steps - 1) * i);\n      }\n    } else {\n      for (int i = 0; i < num_steps; i++) {\n        grid[i] = static_cast<data_type>((-1.0 + 2.0 / (num_steps - 1) * i) * (num_steps - 1)\n                                         / num_steps);\n      }\n    }\n  }\n\n  template<typename data_type>\n  static void Generate2D(user_op::KernelComputeContext*, data_type* grid_ptr, int64_t H, int64_t W,\n                         bool align_corners) {\n    std::vector<data_type> w_step(W);\n    std::vector<data_type> h_step(H);\n    Linspace(w_step, W, align_corners);\n    Linspace(h_step, H, align_corners);\n\n    for (int h = 0; h < H; h++) {\n      data_type* row_ptr = grid_ptr + h * W * 3;\n      for (int w = 0; w < W; w++) {\n        data_type* pixel_ptr = row_ptr + w * 3;\n        pixel_ptr[0] = w_step[w];\n        pixel_ptr[1] = h_step[h];\n        pixel_ptr[2] = static_cast<data_type>(1.0);\n      }\n    }\n  }\n\n  template<typename data_type>\n  static void Generate3D(user_op::KernelComputeContext*, data_type* grid_ptr, int64_t D, int64_t H,\n                         int64_t W, bool align_corners) {\n    std::vector<data_type> w_step(W);\n    std::vector<data_type> h_step(H);\n    std::vector<data_type> d_step(D);\n    Linspace(w_step, W, align_corners);\n    Linspace(h_step, H, align_corners);\n    Linspace(d_step, D, align_corners);\n\n    for (int d = 0; d < D; d++) {\n      data_type* image_ptr = grid_ptr + d * H * W * 4;\n      for (int h = 0; h < H; h++) {\n        data_type* row_ptr = image_ptr + h * W * 4;\n        for (int w = 0; w < W; w++) {\n          data_type* pixel_ptr = row_ptr + w * 4;\n          pixel_ptr[0] = w_step[w];\n          pixel_ptr[1] = h_step[h];\n          pixel_ptr[2] = d_step[d];\n          pixel_ptr[3] = static_cast<data_type>(1.0);\n        }\n      }\n    }\n  }\n};\n\ntemplate<>\nstruct GenerateBaseGridImp<DeviceType::kCUDA> {\n  static void Generate2D(user_op::KernelComputeContext* ctx, float* grid_ptr, int64_t H, int64_t W,\n                         bool align_corners);\n  static void Generate2D(user_op::KernelComputeContext* ctx, double* grid_ptr, int64_t H, int64_t W,\n                         bool align_corners);\n\n  static 
void Generate3D(user_op::KernelComputeContext* ctx, float* grid_ptr, int64_t D, int64_t H,\n                         int64_t W, bool align_corners);\n  static void Generate3D(user_op::KernelComputeContext* ctx, double* grid_ptr, int64_t D, int64_t H,\n                         int64_t W, bool align_corners);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_AFFINE_GRID_KERNEL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/arange_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/arange_kernel_util.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\nnamespace user_op {\nclass ArangeOpKernelCache final : public user_op::OpKernelCache {\n public:\n  ArangeOpKernelCache(int32_t lower, int32_t upper) : lower_(lower), upper_(upper) {}\n  ~ArangeOpKernelCache() override = default;\n\n  int32_t lower() const { return lower_; }\n  int32_t upper() const { return upper_; }\n\n private:\n  const int32_t lower_;\n  const int32_t upper_;\n};\ntemplate<DeviceType device_type, typename T>\nclass ArangeKernel final : public OpKernel, public CudaGraphSupport {\n public:\n  ArangeKernel() = default;\n  ~ArangeKernel() = default;\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    DataType dtype = ctx->Attr<DataType>(\"dtype\");\n    int64_t range_elem_cnt = 0;\n    int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n    if (parallel_num > 1) {\n      if (IsIntegralDataType(dtype)) {\n        int64_t integer_delta = ctx->Attr<int64_t>(\"integer_delta\");\n        int64_t integer_start = ctx->Attr<int64_t>(\"integer_start\");\n        int64_t integer_limit = ctx->Attr<int64_t>(\"integer_limit\");\n        range_elem_cnt =\n            std::ceil(static_cast<double>(integer_limit - integer_start) / integer_delta);\n      } else {\n        double float_delta = ctx->Attr<double>(\"float_delta\");\n        double float_start = ctx->Attr<double>(\"float_start\");\n        double float_limit = ctx->Attr<double>(\"float_limit\");\n        range_elem_cnt = std::ceil(static_cast<double>(float_limit - float_start) / float_delta);\n      }\n      const Shape& logical_shape = Shape({range_elem_cnt});\n      const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n      const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();\n      const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n      TensorSliceView view =\n          GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id);\n      std::shared_ptr<ArangeOpKernelCache> cache(\n          new ArangeOpKernelCache(view.At(0).begin(), view.At(0).end()));\n      return cache;\n    } else {\n      return nullptr;\n    }\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    T* output = out->mut_dptr<T>();\n    const DataType dtype = ctx->Attr<DataType>(\"dtype\");\n    int64_t arange_elem_cnt = 0;\n    T start = static_cast<T>(0.0);\n    T delta = static_cast<T>(0.0);\n    T limit = static_cast<T>(0.0);\n    if 
(IsIntegralDataType(dtype)) {\n      start = static_cast<T>(static_cast<double>(ctx->Attr<int64_t>(\"integer_start\")));\n      delta = static_cast<T>(static_cast<double>(ctx->Attr<int64_t>(\"integer_delta\")));\n      limit = static_cast<T>(static_cast<double>(ctx->Attr<int64_t>(\"integer_limit\")));\n      arange_elem_cnt = std::ceil(static_cast<double>(limit - start) / static_cast<double>(delta));\n    } else {\n      // Compute arange_elem_cnt from the double attrs directly; casting start/delta/limit to T\n      // first and then applying std::ceil would introduce rounding error.\n      double float_start = ctx->Attr<double>(\"float_start\");\n      double float_delta = ctx->Attr<double>(\"float_delta\");\n      double float_limit = ctx->Attr<double>(\"float_limit\");\n      arange_elem_cnt = std::ceil(static_cast<double>(float_limit - float_start) / float_delta);\n      start = static_cast<T>(float_start);\n      delta = static_cast<T>(float_delta);\n      limit = static_cast<T>(float_limit);\n    }\n    if (arange_elem_cnt == 0) { return; }\n    if (cache == nullptr) {\n      ArangeFunctor<device_type, T>()(ctx->stream(), start, delta, arange_elem_cnt, output);\n    } else {\n      const auto* arange_cache = dynamic_cast<const ArangeOpKernelCache*>(cache);\n      CHECK_NOTNULL(arange_cache);\n      auto arange_len = arange_cache->upper() - arange_cache->lower();\n      auto lower = static_cast<T>(static_cast<float>(arange_cache->lower()));\n      ArangeFunctor<device_type, T>()(ctx->stream(), static_cast<T>(start + delta * lower), delta,\n                                      arange_len, output);\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_ARANGE_KERNEL(device, dtype)                                                \\\n  REGISTER_USER_KERNEL(\"arange\").SetCreateFn<ArangeKernel<device, dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == device)                                                   \\\n      && (user_op::HobAttr<DataType>(\"dtype\") == GetDataType<dtype>::value));\n\n#define REGISTER_ARANGE_KERNELS_WITH_DEVICE(device) \\\n  REGISTER_ARANGE_KERNEL(device, uint8_t)           \\\n  REGISTER_ARANGE_KERNEL(device, int8_t)            \\\n  REGISTER_ARANGE_KERNEL(device, int32_t)           \\\n  REGISTER_ARANGE_KERNEL(device, int64_t)           \\\n  REGISTER_ARANGE_KERNEL(device, float)             \\\n  REGISTER_ARANGE_KERNEL(device, double)\n\n#define REGISTER_ARANGE_KERNELS_WITH_CUDA_HALF(device) REGISTER_ARANGE_KERNEL(device, half)\n\n// Register CPU version\nREGISTER_ARANGE_KERNELS_WITH_DEVICE(DeviceType::kCPU);\nREGISTER_ARANGE_KERNEL(DeviceType::kCPU, float16);\n// Register GPU version\n#ifdef WITH_CUDA\nREGISTER_ARANGE_KERNELS_WITH_DEVICE(DeviceType::kCUDA);\nREGISTER_ARANGE_KERNELS_WITH_CUDA_HALF(DeviceType::kCUDA);\n#endif\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/arange_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/arange_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\ntemplate<typename T>\nstruct ArangeFunctor<DeviceType::kCPU, T> final {\n  void operator()(ep::Stream* stream, const T start, const T delta, const int64_t arange_elem_cnt,\n                  T* out) {\n    DoArange<T>(start, delta, arange_elem_cnt, out);\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ARANGE_FUNCTOR, (DeviceType::kCPU),\n                                 ARANGE_DATA_TYPE_SEQ);\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ARANGE_FUNCTOR, (DeviceType::kCPU),\n                                 FLOAT16_DATA_TYPE_SEQ);\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/arange_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/arange_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\ntemplate<typename T>\n__global__ void ArangeForwardGpuKernel(const T start, const T delta, const int64_t arange_elem_cnt,\n                                       T* out) {\n  // Use Loop to set the value\n  DoArange<T>(start, delta, arange_elem_cnt, out);\n}\n\ntemplate<>\n__global__ void ArangeForwardGpuKernel(const half start, const half delta,\n                                       const int64_t arange_elem_cnt, half* out) {\n  // Use Loop to set the value\n  XPU_1D_KERNEL_LOOP(i, arange_elem_cnt) {\n    out[i] = start + static_cast<half>(static_cast<float>(i)) * delta;\n  }\n}\n\ntemplate<typename T>\nstruct ArangeFunctor<DeviceType::kCUDA, T> final {\n  void operator()(ep::Stream* stream, const T start, const T delta, const int64_t arange_elem_cnt,\n                  T* out) {\n    // The thread num is set as arange_elem_cnt\n    RUN_CUDA_KERNEL((ArangeForwardGpuKernel<T>), stream, arange_elem_cnt, start, delta,\n                    arange_elem_cnt, out);\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ARANGE_FUNCTOR, (DeviceType::kCUDA),\n                                 ARANGE_DATA_TYPE_SEQ);\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ARANGE_FUNCTOR, (DeviceType::kCUDA),\n                                 HALF_DATA_TYPE_SEQ);\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // End WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/arange_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_ARANGE_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_ARANGE_KERNEL_UTIL_H_\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/ndarray/xpu_util.h\"\n\nnamespace oneflow {\n\n#define ARANGE_DATA_TYPE_SEQ \\\n  FLOATING_DATA_TYPE_SEQ     \\\n  INT_DATA_TYPE_SEQ          \\\n  UNSIGNED_INT_DATA_TYPE_SEQ\n\nnamespace user_op {\ntemplate<DeviceType device_type, typename T>\nstruct ArangeFunctor final {\n  void operator()(ep::Stream* stream, const T start, const T delta, const int64_t arange_elem_cnt,\n                  T* out);\n};\n\ntemplate<typename T>\nOF_DEVICE_FUNC void DoArange(const T start, const T delta, const int64_t arange_elem_cnt, T* out) {\n  XPU_1D_KERNEL_LOOP(i, arange_elem_cnt) { out[i] = start + i * delta; }\n}\n\n#define INSTANTIATE_ARANGE_FUNCTOR(device_type_v, dtype_pair) \\\n  template struct ArangeFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;\n\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_ARANGE_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/arg_sort_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass CpuArgSortKernel final : public user_op::OpKernel {\n public:\n  CpuArgSortKernel() = default;\n  ~CpuArgSortKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const int32_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1);\n    const int32_t instance_num = in->shape_view().elem_cnt() / instance_size;\n    const std::string& direction = ctx->Attr<std::string>(\"direction\");\n    const bool is_ascending = direction == \"ASCENDING\";\n    const bool is_descending = direction == \"DESCENDING\";\n    FOR_RANGE(int32_t, i, 0, instance_num) {\n      const T* in_ptr_i = in->dptr<T>() + i * instance_size;\n      int32_t* out_ptr_i = out->mut_dptr<int32_t>() + i * instance_size;\n      std::iota(out_ptr_i, out_ptr_i + instance_size, 0);\n      auto comp = [&](const int32_t lhs, const int32_t rhs) {\n        const T l = in_ptr_i[lhs];\n        const T r = in_ptr_i[rhs];\n        if (l == r) {\n          return lhs < rhs;\n        } else {\n          if (is_ascending) {\n            return l < r;\n          } else if (is_descending) {\n            return l > r;\n          } else {\n            LOG(FATAL) << \"expected the input direction parameter value is \\\"ASCENDING\\\" or \"\n                          \"\\\"DESCENDING\\\", \"\n                       << \"but found the value is \"\n                       << \"\\\"\" << direction << \"\\\"\";\n          }\n        }\n      };\n      std::sort(out_ptr_i, out_ptr_i + instance_size, comp);\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_ARG_SORT_KERNEL(dtype)                           \\\n  REGISTER_USER_KERNEL(\"arg_sort\")                                    \\\n      .SetCreateFn<CpuArgSortKernel<dtype>>()                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CPU_ARG_SORT_KERNEL(float)\nREGISTER_CPU_ARG_SORT_KERNEL(double)\nREGISTER_CPU_ARG_SORT_KERNEL(bool)\nREGISTER_CPU_ARG_SORT_KERNEL(int8_t)\nREGISTER_CPU_ARG_SORT_KERNEL(uint8_t)\nREGISTER_CPU_ARG_SORT_KERNEL(int32_t)\nREGISTER_CPU_ARG_SORT_KERNEL(int64_t)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/arg_sort_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/user/kernels/radix_sort.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nclass TmpBufferManager final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TmpBufferManager);\n  TmpBufferManager(int32_t capacity, void* ptr, const ShapeView& in_shape)\n      : capacity_{capacity},\n        sorted_in_elem_cnt_{in_shape.elem_cnt()},\n        indices_elem_cnt_{sorted_in_elem_cnt_} {\n    const int32_t sorted_in_aligned_bytes = GetCudaAlignedSize(sorted_in_elem_cnt_ * sizeof(T));\n    const int32_t indices_aligned_bytes = GetCudaAlignedSize(indices_elem_cnt_ * sizeof(int32_t));\n    sorted_in_ptr_ = reinterpret_cast<T*>(ptr);\n    indices_ptr_ = reinterpret_cast<int32_t*>(reinterpret_cast<char*>(sorted_in_ptr_)\n                                              + sorted_in_aligned_bytes);\n    temp_storage_ptr_ =\n        reinterpret_cast<void*>(reinterpret_cast<char*>(indices_ptr_) + indices_aligned_bytes);\n    temp_storage_bytes_ = capacity_ - sorted_in_aligned_bytes - indices_aligned_bytes;\n    CHECK_GE(temp_storage_bytes_, 0);\n  }\n  ~TmpBufferManager() = default;\n\n  T* SortedInPtr() const { return sorted_in_ptr_; }\n  int32_t* IndicesPtr() const { return indices_ptr_; }\n  void* TempStoragePtr() const { return temp_storage_ptr_; }\n\n  int32_t TempStorageBytes() const { return temp_storage_bytes_; }\n\n private:\n  int32_t capacity_;\n\n  T* sorted_in_ptr_;\n  int32_t* indices_ptr_;\n  void* temp_storage_ptr_;\n\n  int64_t sorted_in_elem_cnt_;\n  int64_t indices_elem_cnt_;\n  int32_t temp_storage_bytes_;\n};\n\n__global__ void InitializeIndices(int32_t elem_cnt, int32_t* indices_ptr, int32_t instance_size) {\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) { indices_ptr[i] = i % instance_size; };\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuArgSortKernel final : public user_op::OpKernel {\n public:\n  GpuArgSortKernel() = default;\n  ~GpuArgSortKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    TmpBufferManager<T> buf_manager(static_cast<int32_t>(tmp_buffer->shape_view().elem_cnt()),\n                                    tmp_buffer->mut_dptr<void>(), in->shape_view());\n\n    const int32_t elem_cnt = in->shape_view().elem_cnt();\n    const int32_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1);\n    const int32_t instance_num = elem_cnt / instance_size;\n    const std::string& direction = ctx->Attr<std::string>(\"direction\");\n    
InitializeIndices<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                        ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        elem_cnt, buf_manager.IndicesPtr(), instance_size);\n    if (direction == \"ASCENDING\") {\n      SortPairsAscending(in->dptr<T>(), buf_manager.IndicesPtr(), instance_num, instance_size,\n                         buf_manager.TempStoragePtr(), buf_manager.TempStorageBytes(),\n                         buf_manager.SortedInPtr(), out->mut_dptr<int32_t>(),\n                         ctx->stream()->As<ep::CudaStream>()->cuda_stream());\n    } else if (direction == \"DESCENDING\") {\n      SortPairsDescending(in->dptr<T>(), buf_manager.IndicesPtr(), instance_num, instance_size,\n                          buf_manager.TempStoragePtr(), buf_manager.TempStorageBytes(),\n                          buf_manager.SortedInPtr(), out->mut_dptr<int32_t>(),\n                          ctx->stream()->As<ep::CudaStream>()->cuda_stream());\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_ARG_SORT_KERNEL(dtype)                                                       \\\n  REGISTER_USER_KERNEL(\"arg_sort\")                                                                 \\\n      .SetCreateFn<GpuArgSortKernel<dtype>>()                                                      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                             \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value))            \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                          \\\n        const Shape& in_shape = ctx->InputShape(\"in\", 0);                                          \\\n        const int32_t elem_cnt = in_shape.elem_cnt();                                              \\\n        const int32_t instance_size = in_shape.dim_vec().back();                                   \\\n        const int32_t instance_num = elem_cnt / instance_size;                                     \\\n                                                                                                   \\\n        /* Sorted In */                                                                            \\\n        const int32_t sorted_in_aligned_bytes = GetCudaAlignedSize(elem_cnt * sizeof(dtype));      \\\n        /* Indices */                                                                              \\\n        const int32_t indices_aligned_bytes = GetCudaAlignedSize(elem_cnt * sizeof(int32_t));      \\\n        /* CUB Temp Storage */                                                                     \\\n        int32_t temp_storage_bytes = -1;                                                           \\\n        const std::string& direction = ctx->Attr<std::string>(\"direction\");                        \\\n        if (direction == \"ASCENDING\") {                                                            \\\n          temp_storage_bytes =                                                                     \\\n              InferTempStorageForSortPairsAscending<dtype, int32_t>(instance_num, instance_size);  \\\n        } else if (direction == \"DESCENDING\") {                                                    \\\n          temp_storage_bytes =                                                                     \\\n              
InferTempStorageForSortPairsDescending<dtype, int32_t>(instance_num, instance_size); \\\n        } else {                                                                                   \\\n          UNIMPLEMENTED();                                                                         \\\n        }                                                                                          \\\n                                                                                                   \\\n        return sorted_in_aligned_bytes + indices_aligned_bytes + temp_storage_bytes;               \\\n      });\n\nREGISTER_CUDA_ARG_SORT_KERNEL(float)\nREGISTER_CUDA_ARG_SORT_KERNEL(double)\nREGISTER_CUDA_ARG_SORT_KERNEL(bool)\nREGISTER_CUDA_ARG_SORT_KERNEL(int8_t)\nREGISTER_CUDA_ARG_SORT_KERNEL(uint8_t)\nREGISTER_CUDA_ARG_SORT_KERNEL(int32_t)\nREGISTER_CUDA_ARG_SORT_KERNEL(int64_t)\nREGISTER_CUDA_ARG_SORT_KERNEL(half)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/arg_where_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/data_type_seq.h\"\n#include \"oneflow/core/common/switch_func.h\"\n#include \"oneflow/user/kernels/arg_where_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<DeviceType device_type, typename IN_T, typename OUT_T>\nclass ArgWhereKernel final : public user_op::OpKernel {\n public:\n  ArgWhereKernel() = default;\n  ~ArgWhereKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    int64_t ndims = ctx->Tensor4ArgNameAndIndex(\"input\", 0)->shape_view().NumAxes();\n    if (ndims == 0) {\n      // 0-dim tensor, elem_cnt of input is 1\n      CHECK_EQ(ctx->Tensor4ArgNameAndIndex(\"input\", 0)->shape_view().elem_cnt(), 1);\n      SetOutputSize<device_type, IN_T, OUT_T>(\n          ctx->stream(), ctx->Tensor4ArgNameAndIndex(\"input\", 0)->dptr<IN_T>(),\n          ctx->Tensor4ArgNameAndIndex(\"output_size\", 0)->mut_dptr<OUT_T>());\n      return;\n    }\n    SwitchNdimCompute(SwitchCase(ndims), ctx);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n#define COMPUTE_SWITCH_ENTRY(func_name, ndim) func_name<ndim>\n  DEFINE_STATIC_SWITCH_FUNC(void, NdimCompute, COMPUTE_SWITCH_ENTRY, MAKE_NDIM_CTRV_SEQ(DIM_SEQ));\n#undef COMPUTE_SWITCH_ENTRY\n\n  template<int NDIM>\n  static void NdimCompute(user_op::KernelComputeContext* ctx) {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    user_op::Tensor* output_size = ctx->Tensor4ArgNameAndIndex(\"output_size\", 0);\n    user_op::Tensor* tmp = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    void* tmp_ptr = tmp ? tmp->mut_dptr() : nullptr;\n    size_t tmp_size = tmp ? 
tmp->shape_view().elem_cnt() * GetSizeOfDataType(tmp->data_type()) : 0;\n    ArgWhereKernelUtil<device_type, IN_T, OUT_T, NDIM>::ArgWhere(\n        ctx->stream(), input->shape_view(), input->dptr<IN_T>(), tmp_ptr, tmp_size,\n        output->mut_dptr<OUT_T>(), output_size->mut_dptr<OUT_T>());\n  }\n};\n\ntemplate<DeviceType device_type, typename IN_T, typename OUT_T, int NDIM>\nsize_t GetWorkspaceBytesSize(int64_t elem_cnt) {\n  return ArgWhereKernelUtil<device_type, IN_T, OUT_T, NDIM>::GetWorkspaceBytesSize(nullptr,\n                                                                                   elem_cnt);\n}\n\ntemplate<DeviceType device_type>\nstruct SwitchUtil;\n\ntemplate<>\nstruct SwitchUtil<DeviceType::kCPU> {\n#define SWITCH_ENTRY(func_name, device, itype, otype, ndim) func_name<device, itype, otype, ndim>\n\n  DEFINE_STATIC_SWITCH_FUNC(\n      size_t, GetWorkspaceBytesSize, SWITCH_ENTRY,\n      MAKE_DEVICE_TYPE_CTRV_SEQ(OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU)),\n      MAKE_DATA_TYPE_CTRV_SEQ(ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ\n                                  FLOAT16_DATA_TYPE_SEQ),\n      MAKE_DATA_TYPE_CTRV_SEQ(INDEX_DATA_TYPE_SEQ), MAKE_NDIM_CTRV_SEQ(DIM_SEQ));\n#undef SWITCH_ENTRY\n};\n\n#ifdef WITH_CUDA\n\ntemplate<>\nstruct SwitchUtil<DeviceType::kCUDA> {\n#define SWITCH_ENTRY(func_name, device, itype, otype, ndim) func_name<device, itype, otype, ndim>\n\n  DEFINE_STATIC_SWITCH_FUNC(\n      size_t, GetWorkspaceBytesSize, SWITCH_ENTRY,\n      MAKE_DEVICE_TYPE_CTRV_SEQ(OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA)),\n      MAKE_DATA_TYPE_CTRV_SEQ(ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ\n                                  HALF_DATA_TYPE_SEQ),\n      MAKE_DATA_TYPE_CTRV_SEQ(INDEX_DATA_TYPE_SEQ), MAKE_NDIM_CTRV_SEQ(DIM_SEQ));\n#undef SWITCH_ENTRY\n};\n\n#endif  // WITH_CUDA\n\ntemplate<DeviceType device_type>\nsize_t InferTempStorageBytesSize(user_op::InferContext* ctx) {\n  const Shape& input_shape = ctx->InputShape(\"input\", 0);\n  if (input_shape.NumAxes() == 0) { return 0; }\n  DataType input_dtype = ctx->InputDType(\"input\", 0);\n  DataType output_dtype = ctx->OutputDType(\"output\", 0);\n  return SwitchUtil<device_type>::SwitchGetWorkspaceBytesSize(\n      SwitchCase(device_type, input_dtype, output_dtype, input_shape.NumAxes()),\n      input_shape.elem_cnt());\n}\n\n}  // namespace\n\n#define REGISTER_ARG_WHERE_KERNEL(device, itype, otype)                                          \\\n  REGISTER_USER_KERNEL(\"argwhere\")                                                               \\\n      .SetCreateFn<ArgWhereKernel<device, itype, otype>>()                                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                      \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<itype>::value)        \\\n                       && (user_op::HobDataType(\"output\", 0) == GetDataType<otype>::value)       \\\n                       && (user_op::HobDataType(\"output_size\", 0) == GetDataType<otype>::value)) \\\n      .SetInferTmpSizeFn(InferTempStorageBytesSize<device>);\n\n#define REGISTER_ARG_WHERE_KERNEL_WITH_DTYPE_PAIR(device, itype_pair, otype_pair) \\\n  REGISTER_ARG_WHERE_KERNEL(device, OF_PP_PAIR_FIRST(itype_pair), OF_PP_PAIR_FIRST(otype_pair))\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n    REGISTER_ARG_WHERE_KERNEL_WITH_DTYPE_PAIR, DEVICE_TYPE_SEQ,\n    ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, 
INDEX_DATA_TYPE_SEQ)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_ARG_WHERE_KERNEL_WITH_DTYPE_PAIR, (DeviceType::kCPU),\n                                 FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n#ifdef WITH_CUDA\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_ARG_WHERE_KERNEL_WITH_DTYPE_PAIR, (DeviceType::kCUDA),\n                                 HALF_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n#endif  // WITH_CUDA\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/arg_where_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/arg_where_kernel_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/common/small_vector.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename IN_T, typename OUT_T, int NDIM>\nstruct ArgWhereKernelUtil<DeviceType::kCPU, IN_T, OUT_T, NDIM> {\n  static void ArgWhere(ep::Stream* stream, const ShapeView& input_shape, const IN_T* input_ptr,\n                       void* temp_storage, size_t temp_storage_bytes, OUT_T* output_ptr,\n                       OUT_T* output_size_ptr) {\n    // deal with empty blob\n    if (input_shape.elem_cnt() == 0) {\n      Memset<DeviceType::kCPU>(stream, output_size_ptr, 0, sizeof(OUT_T));\n      return;\n    }\n\n    const int64_t elem_cnt = input_shape.elem_cnt();\n    CHECK_LE(elem_cnt, std::numeric_limits<OUT_T>::max());\n    OUT_T true_cnt = 0;\n    OUT_T dims[NDIM] = {0};\n    std::transform(input_shape.ptr(), input_shape.ptr() + input_shape.NumAxes(), dims,\n                   [](int64_t dim) { return static_cast<OUT_T>(dim); });\n    NdIndexOffsetHelper<OUT_T, NDIM> index_converter(dims);\n    FOR_RANGE(int64_t, i, 0, elem_cnt) {\n      if (static_cast<bool>(input_ptr[i])) {\n        index_converter.OffsetToNdIndex(i, output_ptr + true_cnt * NDIM);\n        true_cnt += 1;\n      }\n    }\n    *output_size_ptr = true_cnt;\n  }\n\n  static size_t GetWorkspaceBytesSize(ep::Stream* stream, int64_t elem_cnt) { return 0; }\n};\n\nINSTANTIATE_ARG_WHERE_KERNEL_UTIL_FOR_DEVICE(DeviceType::kCPU)\n\n#define INSTANTIATE_CPU_FLOAT16_ARG_WHERE_KERNEL_UTIL                                              \\\n  OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ARG_WHERE_KERNEL_UTIL_WITH_DTYPE_PAIR,              \\\n                                   (DeviceType::kCPU), FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, \\\n                                   DIM_SEQ)\n\nINSTANTIATE_CPU_FLOAT16_ARG_WHERE_KERNEL_UTIL\n\ntemplate<DeviceType device_type, typename IN_T, typename OUT_T>\nvoid SetOutputSize(ep::Stream* stream, const IN_T* input_ptr, OUT_T* output_size_ptr) {\n  if (*input_ptr == GetZeroVal<IN_T>()) {\n    *output_size_ptr = GetZeroVal<OUT_T>();\n  } else {\n    *output_size_ptr = GetOneVal<OUT_T>();\n  }\n}\n\nINSTANTIATE_SET_OUTPUT_SIZE_FOR_DEVICE(DeviceType::kCPU)\n\n#define INSTANTIATE_CPU_FLOAT16_SET_OUTPUT_SIZE                                 \\\n  OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SET_OUTPUT_SIZE_WITH_DTYPE_PAIR, \\\n                                   (DeviceType::kCPU), FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\nINSTANTIATE_CPU_FLOAT16_SET_OUTPUT_SIZE\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/arg_where_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/arg_where_kernel_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/common/small_vector.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <cub/cub.cuh>\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr int kBlockSize = cuda::elementwise::kBlockSize;\n\nint GetNumBlocks(int64_t elem_cnt) {\n  int num_blocks = 0;\n  OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(elem_cnt, &num_blocks));\n  return num_blocks;\n}\n\ntemplate<typename T, int NDIM>\nstruct StrideIterator {\n  typedef StrideIterator self_type;\n  typedef std::ptrdiff_t difference_type;\n  typedef T value_type;\n  typedef T* pointer;\n  typedef T& reference;\n  typedef std::random_access_iterator_tag iterator_category;\n\n  explicit StrideIterator(T* ptr, size_t max_iters) : ptr_(ptr), max_iters_(max_iters) {}\n\n  OF_DEVICE_FUNC reference operator[](int i) {\n    assert(0 <= i && i < max_iters_);\n    return *(ptr_ + (i * NDIM));\n  }\n\n private:\n  T* ptr_;\n  size_t max_iters_;\n};\n\ntemplate<typename T, int NDIM>\n__global__ void __launch_bounds__(kBlockSize)\n    CudaOffsetToNdIndexInplace(NdIndexOffsetHelper<T, NDIM> index_converter,\n                               const T* output_size_ptr, T* output_ptr) {\n  CUDA_1D_KERNEL_LOOP_T(T, i, *output_size_ptr) {\n    T* index_ptr = output_ptr + i * NDIM;\n    index_converter.OffsetToNdIndex(*index_ptr, index_ptr);\n  }\n}\n\ntemplate<typename T>\nstruct IsTrue {\n  __device__ __forceinline__ bool operator()(const T& val) const { return static_cast<bool>(val); }\n};\n\ntemplate<typename IN_T, typename OUT_T, typename OUT_ITER>\ncudaError_t SelectTrue(cudaStream_t stream, int num_items, void* temp_storage,\n                       size_t& temp_storage_bytes, const IN_T* input, OUT_ITER output_iter,\n                       OUT_T* num_selected) {\n  IsTrue<IN_T> is_true;\n  cub::TransformInputIterator<bool, IsTrue<IN_T>, const IN_T*> flag_iter(input, is_true);\n  cub::CountingInputIterator<OUT_T> offset_counter(0);\n  return cub::DeviceSelect::Flagged(temp_storage, temp_storage_bytes, offset_counter, flag_iter,\n                                    output_iter, num_selected, num_items, stream);\n}\n\ntemplate<typename IN_T, typename OUT_T>\n__global__ void SetOutputSizeKernel(const IN_T* input_ptr, OUT_T* output_size_ptr) {\n  if (*input_ptr == GetZeroVal<IN_T>()) {\n    *output_size_ptr = GetZeroVal<OUT_T>();\n  } else {\n    *output_size_ptr = GetOneVal<OUT_T>();\n  }\n}\n\n}  // namespace\n\ntemplate<typename IN_T, typename OUT_T, int NDIM>\nstruct ArgWhereKernelUtil<DeviceType::kCUDA, IN_T, OUT_T, NDIM> {\n  static void ArgWhere(ep::Stream* stream, const ShapeView& input_shape, const IN_T* input_ptr,\n                       void* temp_storage, size_t temp_storage_bytes, OUT_T* 
output_ptr,\n                       OUT_T* output_size_ptr) {\n    const int64_t elem_cnt = input_shape.elem_cnt();\n    // deal with empty blob\n    if (elem_cnt == 0) {\n      Memset<DeviceType::kCUDA>(stream, output_size_ptr, 0, sizeof(OUT_T));\n      return;\n    }\n\n    CHECK_NOTNULL(stream);\n    CHECK_LE(elem_cnt, std::numeric_limits<OUT_T>::max());\n    size_t workspace = GetWorkspaceBytesSize(stream, elem_cnt);\n    CHECK_LE(workspace, temp_storage_bytes);\n\n    if (NDIM == 1) {\n      OF_CUDA_CHECK((SelectTrue<IN_T, OUT_T, OUT_T*>(\n          stream->As<ep::CudaStream>()->cuda_stream(), input_shape.elem_cnt(), temp_storage,\n          workspace, input_ptr, output_ptr, output_size_ptr)));\n    } else {\n      using OutputIterator = StrideIterator<OUT_T, NDIM>;\n      OutputIterator output_iter(output_ptr, elem_cnt);\n      OF_CUDA_CHECK((SelectTrue<IN_T, OUT_T, OutputIterator>(\n          stream->As<ep::CudaStream>()->cuda_stream(), elem_cnt, temp_storage, workspace, input_ptr,\n          output_iter, output_size_ptr)));\n\n      OUT_T dims[NDIM] = {0};\n      std::transform(input_shape.ptr(), input_shape.ptr() + input_shape.NumAxes(), dims,\n                     [](int64_t dim) { return static_cast<OUT_T>(dim); });\n      NdIndexOffsetHelper<OUT_T, NDIM> index_converter(dims);\n      CudaOffsetToNdIndexInplace<OUT_T, NDIM>\n          <<<GetNumBlocks(elem_cnt), kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              index_converter, output_size_ptr, output_ptr);\n    }\n  }\n\n  static size_t GetWorkspaceBytesSize(ep::Stream* stream, int64_t elem_cnt) {\n    cudaStream_t cuda_stream = stream ? stream->As<ep::CudaStream>()->cuda_stream() : 0;\n    size_t workspace = 0;\n    if (NDIM == 1) {\n      OF_CUDA_CHECK((SelectTrue<IN_T, OUT_T, OUT_T*>(cuda_stream, elem_cnt, nullptr, workspace,\n                                                     nullptr, nullptr, nullptr)));\n    } else {\n      using OutputIterator = StrideIterator<OUT_T, NDIM>;\n      OutputIterator output_iter(nullptr, elem_cnt);\n      OF_CUDA_CHECK((SelectTrue<IN_T, OUT_T, OutputIterator>(\n          cuda_stream, elem_cnt, nullptr, workspace, nullptr, output_iter, nullptr)));\n    }\n    return workspace;\n  }\n};\n\nINSTANTIATE_ARG_WHERE_KERNEL_UTIL_FOR_DEVICE(DeviceType::kCUDA)\n\n#define INSTANTIATE_CUDA_HALF_ARG_WHERE_KERNEL_UTIL                                              \\\n  OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ARG_WHERE_KERNEL_UTIL_WITH_DTYPE_PAIR,            \\\n                                   (DeviceType::kCUDA), HALF_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, \\\n                                   DIM_SEQ)\n\nINSTANTIATE_CUDA_HALF_ARG_WHERE_KERNEL_UTIL\n\ntemplate<DeviceType device_type, typename IN_T, typename OUT_T>\nvoid SetOutputSize(ep::Stream* stream, const IN_T* input_ptr, OUT_T* output_size_ptr) {\n  SetOutputSizeKernel<IN_T, OUT_T>\n      <<<1, 1, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(input_ptr, output_size_ptr);\n}\n\nINSTANTIATE_SET_OUTPUT_SIZE_FOR_DEVICE(DeviceType::kCUDA)\n\n#define INSTANTIATE_CUDA_HALF_SET_OUTPUT_SIZE                                   \\\n  OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SET_OUTPUT_SIZE_WITH_DTYPE_PAIR, \\\n                                   (DeviceType::kCUDA), HALF_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\nINSTANTIATE_CUDA_HALF_SET_OUTPUT_SIZE\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/arg_where_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_ARG_WHERE_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_ARG_WHERE_KERNEL_UTIL_H_\n\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/common/shape_view.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename IN_T, typename OUT_T, int NDIM>\nstruct ArgWhereKernelUtil {\n  static void ArgWhere(ep::Stream* stream, const ShapeView& input_shape, const IN_T* input_ptr,\n                       void* temp_storage, size_t temp_storage_bytes, OUT_T* output_ptr,\n                       OUT_T* output_size_ptr);\n  static size_t GetWorkspaceBytesSize(ep::Stream* stream, int64_t elem_cnt);\n};\n\n#define INSTANTIATE_ARG_WHERE_KERNEL_UTIL(device, itype, otype, ndim) \\\n  template struct ArgWhereKernelUtil<device, itype, otype, ndim>;\n\n#define INSTANTIATE_ARG_WHERE_KERNEL_UTIL_WITH_DTYPE_PAIR(device, itype_pair, otype_pair, ndim) \\\n  INSTANTIATE_ARG_WHERE_KERNEL_UTIL(device, OF_PP_PAIR_FIRST(itype_pair),                       \\\n                                    OF_PP_PAIR_FIRST(otype_pair), ndim)\n\n#define INSTANTIATE_ARG_WHERE_KERNEL_UTIL_FOR_DEVICE(device)                                       \\\n  OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(                                                                \\\n      INSTANTIATE_ARG_WHERE_KERNEL_UTIL_WITH_DTYPE_PAIR, (device),                                 \\\n      ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, \\\n      DIM_SEQ)\n\ntemplate<DeviceType device_type, typename IN_T, typename OUT_T>\nvoid SetOutputSize(ep::Stream* stream, const IN_T* input_ptr, OUT_T* output_size_ptr);\n\n#define INSTANTIATE_SET_OUTPUT_SIZE(device, itype, otype)                                        \\\n  template void SetOutputSize<device, itype, otype>(ep::Stream * stream, const itype* input_ptr, \\\n                                                    otype* output_size_ptr);\n\n#define INSTANTIATE_SET_OUTPUT_SIZE_WITH_DTYPE_PAIR(device, itype_pair, otype_pair) \\\n  INSTANTIATE_SET_OUTPUT_SIZE(device, OF_PP_PAIR_FIRST(itype_pair), OF_PP_PAIR_FIRST(otype_pair))\n\n#define INSTANTIATE_SET_OUTPUT_SIZE_FOR_DEVICE(device)       \\\n  OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(                          \\\n      INSTANTIATE_SET_OUTPUT_SIZE_WITH_DTYPE_PAIR, (device), \\\n      ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_ARG_WHERE_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/argmax_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass CpuArgMaxKernel final : public user_op::OpKernel {\n public:\n  CpuArgMaxKernel() = default;\n  ~CpuArgMaxKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const int32_t elem_cnt = in->shape_view().elem_cnt();\n    CHECK_GE(elem_cnt, 0);\n    if (elem_cnt == 0) { return; }\n\n    const T* in_ptr = in->dptr<T>();\n    int64_t* out_ptr = out->mut_dptr<int64_t>();\n\n    const int64_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1);\n    const int64_t instance_num = elem_cnt / instance_size;\n    const int64_t num_thread =\n        std::min(instance_num, (int64_t)Singleton<ThreadPool>::Get()->thread_num());\n    const BalancedSplitter bs(instance_num, num_thread);\n    BlockingCounter bc(num_thread);\n    FOR_RANGE(int64_t, thread_id, 0, num_thread) {\n      const Range range = bs.At(thread_id);\n      Singleton<ThreadPool>::Get()->AddWork([=, &bc]() {\n        FOR_RANGE(int64_t, i, range.begin(), range.end()) {\n          const T* in_ptr_i = in_ptr + i * instance_size;\n          out_ptr[i] =\n              std::distance(in_ptr_i, std::max_element(in_ptr_i, in_ptr_i + instance_size));\n        }\n        bc.Decrease();\n      });\n    }\n    bc.WaitForeverUntilCntEqualZero();\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_ARGMAX_KERNEL(dtype)                                               \\\n  REGISTER_USER_KERNEL(\"argmax\").SetCreateFn<CpuArgMaxKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCPU)                                    \\\n      && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CPU_ARGMAX_KERNEL(bool)\nREGISTER_CPU_ARGMAX_KERNEL(float)\nREGISTER_CPU_ARGMAX_KERNEL(float16)\nREGISTER_CPU_ARGMAX_KERNEL(double)\nREGISTER_CPU_ARGMAX_KERNEL(uint8_t)\nREGISTER_CPU_ARGMAX_KERNEL(int8_t)\nREGISTER_CPU_ARGMAX_KERNEL(int32_t)\nREGISTER_CPU_ARGMAX_KERNEL(int64_t)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/argmax_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include <cub/cub.cuh>\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nclass TmpBufferManager final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TmpBufferManager);\n  TmpBufferManager(int32_t capacity, void* ptr, int32_t instance_num)\n      : capacity_{capacity}, key_value_out_elem_cnt_{instance_num} {\n    const int32_t key_value_out_aligned_bytes =\n        GetCudaAlignedSize(key_value_out_elem_cnt_ * sizeof(cub::KeyValuePair<int32_t, T>));\n    key_value_out_ptr_ = reinterpret_cast<cub::KeyValuePair<int32_t, T>*>(ptr);\n    temp_storage_ptr_ = reinterpret_cast<void*>(reinterpret_cast<char*>(key_value_out_ptr_)\n                                                + key_value_out_aligned_bytes);\n    temp_storage_bytes_ = capacity_ - key_value_out_aligned_bytes;\n    CHECK_GE(temp_storage_bytes_, 0);\n  }\n  ~TmpBufferManager() = default;\n\n  cub::KeyValuePair<int32_t, T>* KeyValueOutPtr() const { return key_value_out_ptr_; }\n  void* TempStoragePtr() const { return temp_storage_ptr_; }\n\n  int32_t TempStorageBytes() const { return temp_storage_bytes_; }\n\n private:\n  int32_t capacity_;\n\n  cub::KeyValuePair<int32_t, T>* key_value_out_ptr_;\n  void* temp_storage_ptr_;\n\n  int32_t key_value_out_elem_cnt_;\n  int32_t temp_storage_bytes_;\n};\n\nclass MultiplyFunctor final {\n public:\n  MultiplyFunctor(int32_t num_col) : num_col_(num_col) {}\n  __host__ __device__ __forceinline__ int32_t operator()(int32_t idx) const {\n    return idx * num_col_;\n  }\n\n private:\n  int32_t num_col_;\n};\n\ntemplate<typename T>\nsize_t InferTempStorageForArgMax(int32_t num_row, int32_t num_col) {\n  using SegmentOffsetIter =\n      cub::TransformInputIterator<int32_t, MultiplyFunctor, cub::CountingInputIterator<int32_t>>;\n  cub::CountingInputIterator<int32_t> counting_iter(0);\n  MultiplyFunctor multiply_functor(num_col);\n  SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor);\n\n  size_t temp_storage_bytes = 0;\n  auto err =\n      cub::DeviceSegmentedReduce::ArgMax<T*, cub::KeyValuePair<int32_t, T>*, SegmentOffsetIter>(\n          /* d_temp_storage */ nullptr, /* temp_storage_bytes */ temp_storage_bytes,\n          /* d_in */ nullptr, /* d_out */ nullptr, /* num_segments */ num_row,\n          /* d_begin_offsets */ segment_offset_iter, /* d_end_offsets */ segment_offset_iter + 1,\n          /* stream */ 0);\n  OF_CUDA_CHECK(err);\n\n  return temp_storage_bytes;\n}\n\ntemplate<typename T>\nvoid ArgMax(const T* in_ptr, int32_t num_row, int32_t num_col, void* temp_storage_ptr,\n            int32_t temp_storage_bytes, cub::KeyValuePair<int32_t, T>* out_ptr,\n            cudaStream_t stream) {\n  size_t rt_inferred_temp_storage_bytes = InferTempStorageForArgMax<T>(num_row, num_col);\n  CHECK_LE(rt_inferred_temp_storage_bytes, temp_storage_bytes);\n\n  using 
SegmentOffsetIter =\n      cub::TransformInputIterator<int32_t, MultiplyFunctor, cub::CountingInputIterator<int32_t>>;\n  cub::CountingInputIterator<int32_t> counting_iter(0);\n  MultiplyFunctor multiply_functor(num_col);\n  SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor);\n\n  auto err = cub::DeviceSegmentedReduce::ArgMax(\n      /* d_temp_storage */ temp_storage_ptr,\n      /* temp_storage_bytes */ rt_inferred_temp_storage_bytes,\n      /* d_in */ in_ptr,\n      /* d_out */ out_ptr,\n      /* num_segments */ num_row,\n      /* d_begin_offsets */ segment_offset_iter,\n      /* d_end_offsets */ segment_offset_iter + 1,\n      /* stream */ stream);\n  OF_CUDA_CHECK(err);\n}\n\ntemplate<typename T>\n__global__ void WriteKeysToOutput(const int32_t instance_num,\n                                  const cub::KeyValuePair<int32_t, T>* key_value_out_ptr,\n                                  int64_t* out_ptr) {\n  CUDA_1D_KERNEL_LOOP(i, instance_num) { out_ptr[i] = key_value_out_ptr[i].key; }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuArgMaxKernel final : public user_op::OpKernel {\n public:\n  GpuArgMaxKernel() = default;\n  ~GpuArgMaxKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    const int32_t elem_cnt = in->shape_view().elem_cnt();\n    CHECK_GE(elem_cnt, 0);\n    if (elem_cnt == 0) { return; }\n\n    const int32_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1);\n    const int32_t instance_num = elem_cnt / instance_size;\n    TmpBufferManager<T> buffer_manager(tmp_buffer->shape_view().elem_cnt(),\n                                       tmp_buffer->mut_dptr<void>(), instance_num);\n\n    ArgMax(in->dptr<T>(), instance_num, instance_size, buffer_manager.TempStoragePtr(),\n           buffer_manager.TempStorageBytes(), buffer_manager.KeyValueOutPtr(),\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream());\n    WriteKeysToOutput<T><<<BlocksNum4ThreadsNum(instance_num), kCudaThreadsNumPerBlock, 0,\n                           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        instance_num, buffer_manager.KeyValueOutPtr(), out->mut_dptr<int64_t>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_ARGMAX_KERNEL(dtype)                                                         \\\n  REGISTER_USER_KERNEL(\"argmax\")                                                                   \\\n      .SetCreateFn<GpuArgMaxKernel<dtype>>()                                                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                             \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value))            \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                          \\\n        const Shape& in_shape = ctx->InputShape(\"in\", 0);                                          \\\n        const int32_t instance_size = in_shape.dim_vec().back();                                   \\\n        const int32_t instance_num = in_shape.elem_cnt() / instance_size;                          \\\n                                                           
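/* tmp_buffer layout: [aligned cub::KeyValuePair<int32_t, dtype> array][cub temp storage] */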
                                        \\\n        /* Key-Value Out */                                                                        \\\n        int32_t key_value_out_bytes =                                                              \\\n            GetCudaAlignedSize(instance_num * sizeof(cub::KeyValuePair<int32_t, dtype>));          \\\n                                                                                                   \\\n        /* CUB Temp Storage */                                                                     \\\n        size_t temp_storage_bytes = InferTempStorageForArgMax<dtype>(instance_num, instance_size); \\\n                                                                                                   \\\n        return key_value_out_bytes + temp_storage_bytes;                                           \\\n      });\n\nREGISTER_CUDA_ARGMAX_KERNEL(bool)\nREGISTER_CUDA_ARGMAX_KERNEL(float)\nREGISTER_CUDA_ARGMAX_KERNEL(double)\nREGISTER_CUDA_ARGMAX_KERNEL(uint8_t)\nREGISTER_CUDA_ARGMAX_KERNEL(int8_t)\nREGISTER_CUDA_ARGMAX_KERNEL(int32_t)\nREGISTER_CUDA_ARGMAX_KERNEL(int64_t)\nREGISTER_CUDA_ARGMAX_KERNEL(half)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/as_strided_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr size_t NUM_DIM = 20;\n\ntemplate<typename T>\nstruct AsStridedFunctor final {\n  void operator()(ep::Stream* stream, const T* input_buf, T* output_buf, const int64_t* dest_dims,\n                  const int64_t* stride, const int64_t dest_num_dims, const int64_t storage_offset,\n                  const int64_t input_num, const int64_t output_num) {\n    NdIndexOffsetHelper<int64_t, NUM_DIM> destIndexOffsetHelper(dest_dims, dest_num_dims);\n    FOR_RANGE(int64_t, i, 0, output_num) {\n      int64_t dst_index[NUM_DIM];\n      destIndexOffsetHelper.OffsetToNdIndex(i, dst_index, dest_num_dims);\n      int64_t index_in_input = storage_offset;\n      FOR_RANGE(int64_t, j, 0, dest_num_dims) { index_in_input += dst_index[j] * stride[j]; }\n      output_buf[i] = input_buf[index_in_input];\n    }\n  }\n};\n\ntemplate<typename T>\nstruct AsStridedGradFunctor final {\n  void operator()(ep::Stream* stream, const T* dy_buf, T* dx_buf, const int64_t* dy_dims,\n                  const int64_t* stride, const int64_t dy_num_dims, const int64_t storage_offset,\n                  const int64_t dx_num, const int64_t dy_num) {\n    NdIndexOffsetHelper<int64_t, NUM_DIM> destIndexOffsetHelper(dy_dims, dy_num_dims);\n    FOR_RANGE(int64_t, i, 0, dy_num) {\n      int64_t dy_index[NUM_DIM];\n      destIndexOffsetHelper.OffsetToNdIndex(i, dy_index, dy_num_dims);\n      int64_t index_in_dx = storage_offset;\n      FOR_RANGE(int64_t, j, 0, dy_num_dims) { index_in_dx += dy_index[j] * stride[j]; }\n      dx_buf[index_in_dx] += dy_buf[i];\n    }\n  }\n};\n\n}  // namespace\n\ntemplate<typename T>\nclass CpuAsStridedKernel final : public user_op::OpKernel {\n public:\n  CpuAsStridedKernel() = default;\n  ~CpuAsStridedKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    const auto size = ctx->Attr<std::vector<int64_t>>(\"size\");\n    const auto stride = ctx->Attr<std::vector<int64_t>>(\"stride\");\n    const int64_t storage_offset = ctx->Attr<int64_t>(\"storage_offset\");\n\n    size_t dest_num_dims = output->shape_view().NumAxes();\n    const int64_t* dest_dims = output->shape_view().ptr();\n    const size_t input_num = input->shape_view().Count(0);\n    const size_t output_num = output->shape_view().Count(0);\n\n    AsStridedFunctor<T>()(ctx->stream(), input->dptr<T>(), output->mut_dptr<T>(), dest_dims,\n                        
  stride.data(), dest_num_dims, storage_offset, input_num, output_num);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass CpuAsStridedGradKernel final : public user_op::OpKernel {\n public:\n  CpuAsStridedGradKernel() = default;\n  ~CpuAsStridedGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const auto size = ctx->Attr<std::vector<int64_t>>(\"size\");\n    const auto stride = ctx->Attr<std::vector<int64_t>>(\"stride\");\n    const int64_t storage_offset = ctx->Attr<int64_t>(\"storage_offset\");\n\n    size_t dy_num_dims = dy->shape_view().NumAxes();\n    const int64_t* dy_dims = dy->shape_view().ptr();\n    const size_t dx_num = dx->shape_view().Count(0);\n    const size_t dy_num = dy->shape_view().Count(0);\n\n    Memset<DeviceType::kCPU>(ctx->stream(), dx->mut_dptr(), 0,\n                             dx->shape_view().Count(0) * sizeof(T));\n\n    AsStridedGradFunctor<T>()(ctx->stream(), dy->dptr<T>(), dx->mut_dptr<T>(), dy_dims,\n                              stride.data(), dy_num_dims, storage_offset, dx_num, dy_num);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_ASSTRIDED_KERNEL(in_type)                        \\\n  REGISTER_USER_KERNEL(\"as_strided\")                                  \\\n      .SetCreateFn<CpuAsStridedKernel<in_type>>()                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<in_type>::value));\n\nREGISTER_CPU_ASSTRIDED_KERNEL(float);\nREGISTER_CPU_ASSTRIDED_KERNEL(double);\nREGISTER_CPU_ASSTRIDED_KERNEL(int8_t);\nREGISTER_CPU_ASSTRIDED_KERNEL(uint8_t);\nREGISTER_CPU_ASSTRIDED_KERNEL(int32_t);\nREGISTER_CPU_ASSTRIDED_KERNEL(int64_t);\n\n#undef REGISTER_CPU_ASSTRIDED_KERNEL\n\n#define REGISTER_CPU_ASSTRIDED_GRAD_KERNEL(in_type)                   \\\n  REGISTER_USER_KERNEL(\"as_strided_grad\")                             \\\n      .SetCreateFn<CpuAsStridedGradKernel<in_type>>()                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<in_type>::value));\n\nREGISTER_CPU_ASSTRIDED_GRAD_KERNEL(float);\nREGISTER_CPU_ASSTRIDED_GRAD_KERNEL(double);\n\n#undef REGISTER_CPU_ASSTRIDED_GRAD_KERNEL\n\nREGISTER_USER_KERNEL(\"as_strided\")\n    .SetCreateFn<CpuAsStridedKernel<bool>>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"input\", 0) == GetDataType<bool>::value));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/as_strided_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr size_t NUM_DIM = 8;\n\ntemplate<size_t num_dims, typename IndexType>\nstruct AsStridedParams {\n  NdIndexOffsetHelper<IndexType, num_dims> destIndexOffsetHelper;\n  int64_t dest_dims[num_dims];\n  int32_t stride[num_dims];\n  int32_t dest_num_dims;\n  int32_t storage_offset;\n  int32_t input_num;\n  int32_t output_num;\n};\n\ntemplate<typename T>\n__global__ void AsStrided_kernel(const T* input_buf, T* output_buf,\n                                 AsStridedParams<NUM_DIM, int64_t> params) {\n  const int64_t* dest_dims = reinterpret_cast<const int64_t*>(params.dest_dims);\n  const int32_t* stride = reinterpret_cast<const int32_t*>(params.stride);\n\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, params.output_num) {\n    int64_t dst_index[NUM_DIM];\n    params.destIndexOffsetHelper.OffsetToNdIndex(i, dst_index, params.dest_num_dims);\n    int32_t index_in_input = params.storage_offset;\n    FOR_RANGE(int64_t, j, 0, params.dest_num_dims) { index_in_input += dst_index[j] * stride[j]; }\n    output_buf[i] = input_buf[index_in_input];\n  }\n}\n\ntemplate<typename T>\n__global__ void AsStridedGrad_kernel(const T* dy_buf, T* dx_buf,\n                                     AsStridedParams<NUM_DIM, int64_t> params) {\n  const int64_t* dest_dims = reinterpret_cast<const int64_t*>(params.dest_dims);\n  const int32_t* stride = reinterpret_cast<const int32_t*>(params.stride);\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, params.output_num) {\n    int64_t dy_index[NUM_DIM];\n    params.destIndexOffsetHelper.OffsetToNdIndex(i, dy_index, params.dest_num_dims);\n    int32_t index_in_dx = params.storage_offset;\n    FOR_RANGE(int64_t, j, 0, params.dest_num_dims) { index_in_dx += dy_index[j] * stride[j]; }\n    cuda::atomic::Add(dx_buf + index_in_dx, dy_buf[i]);\n  }\n}\n\ntemplate<typename T>\nstruct AsStridedFunctor final {\n  void operator()(ep::Stream* stream, const T* input_buf, T* output_buf, const int64_t* dest_dims,\n                  const int64_t* stride, const int64_t dest_num_dims, const int64_t storage_offset,\n                  const int64_t input_num, const int64_t output_num) {\n    NdIndexOffsetHelper<int64_t, NUM_DIM> destIndexOffsetHelper(dest_dims, dest_num_dims);\n    AsStridedParams<NUM_DIM, int64_t> params;\n    params.destIndexOffsetHelper = destIndexOffsetHelper;\n    FOR_RANGE(size_t, i, 0, dest_num_dims) {\n      params.dest_dims[i] = dest_dims[i];\n      params.stride[i] = stride[i];\n    }\n    params.dest_num_dims = dest_num_dims;\n    params.storage_offset = storage_offset;\n    params.input_num = 
input_num;\n    params.output_num = output_num;\n\n    // Skip the launch for 0-size outputs: a zero grid dimension is not a valid CUDA launch\n    // configuration.\n    if (output_num == 0) { return; }\n    AsStrided_kernel<T>\n        <<<BlocksNum4ThreadsNum(output_num), kCudaThreadsNumPerBlock, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(input_buf, output_buf, params);\n  }\n};\n\ntemplate<typename T>\nstruct AsStridedGradFunctor final {\n  void operator()(ep::Stream* stream, const T* dy_buf, T* dx_buf, const int64_t* dy_dims,\n                  const int64_t* stride, const int64_t dy_num_dims, const int64_t storage_offset,\n                  const int64_t dx_num, const int64_t dy_num) {\n    NdIndexOffsetHelper<int64_t, NUM_DIM> dyIndexOffsetHelper(dy_dims, dy_num_dims);\n    AsStridedParams<NUM_DIM, int64_t> params;\n    params.destIndexOffsetHelper = dyIndexOffsetHelper;\n    FOR_RANGE(size_t, i, 0, dy_num_dims) {\n      params.dest_dims[i] = dy_dims[i];\n      params.stride[i] = stride[i];\n    }\n    params.dest_num_dims = dy_num_dims;\n    params.storage_offset = storage_offset;\n    params.input_num = dx_num;\n    params.output_num = dy_num;\n\n    // Likewise skip the launch when dy is 0-size; dx has already been zeroed by the caller.\n    if (dy_num == 0) { return; }\n    AsStridedGrad_kernel<T>\n        <<<BlocksNum4ThreadsNum(dy_num), kCudaThreadsNumPerBlock, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(dy_buf, dx_buf, params);\n  }\n};\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuAsStridedKernel final : public user_op::OpKernel {\n public:\n  GpuAsStridedKernel() = default;\n  ~GpuAsStridedKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    const auto size = ctx->Attr<std::vector<int64_t>>(\"size\");\n    const auto stride = ctx->Attr<std::vector<int64_t>>(\"stride\");\n    const int64_t storage_offset = ctx->Attr<int64_t>(\"storage_offset\");\n\n    size_t dest_num_dims = output->shape_view().NumAxes();\n    const int64_t* dest_dims = output->shape_view().ptr();\n    const size_t input_num = input->shape_view().Count(0);\n    const size_t output_num = output->shape_view().Count(0);\n    if (input_num == 0) {\n      // 0-size tensor\n      return;\n    }\n\n    AsStridedFunctor<T>()(ctx->stream(), input->dptr<T>(), output->mut_dptr<T>(), dest_dims,\n                          stride.data(), dest_num_dims, storage_offset, input_num, output_num);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass GpuAsStridedGradKernel final : public user_op::OpKernel {\n public:\n  GpuAsStridedGradKernel() = default;\n  ~GpuAsStridedGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const auto size = ctx->Attr<std::vector<int64_t>>(\"size\");\n    const auto stride = ctx->Attr<std::vector<int64_t>>(\"stride\");\n    const int64_t storage_offset = ctx->Attr<int64_t>(\"storage_offset\");\n\n    size_t dy_num_dims = dy->shape_view().NumAxes();\n    const int64_t* dy_dims = dy->shape_view().ptr();\n    const size_t dx_num = dx->shape_view().Count(0);\n    const size_t dy_num = dy->shape_view().Count(0);\n\n    Memset<DeviceType::kCUDA>(ctx->stream(), dx->mut_dptr(), 0,\n                              dx->shape_view().Count(0) * sizeof(T));\n\n    AsStridedGradFunctor<T>()(ctx->stream(), 
dy->dptr<T>(), dx->mut_dptr<T>(), dy_dims,\n                              stride.data(), dy_num_dims, storage_offset, dx_num, dy_num);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_GPU_ASSTRIDED_KERNEL(in_type)                         \\\n  REGISTER_USER_KERNEL(\"as_strided\")                                   \\\n      .SetCreateFn<GpuAsStridedKernel<in_type>>()                      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<in_type>::value));\n\nREGISTER_GPU_ASSTRIDED_KERNEL(half);\nREGISTER_GPU_ASSTRIDED_KERNEL(float);\nREGISTER_GPU_ASSTRIDED_KERNEL(double);\nREGISTER_GPU_ASSTRIDED_KERNEL(int8_t);\nREGISTER_GPU_ASSTRIDED_KERNEL(uint8_t);\nREGISTER_GPU_ASSTRIDED_KERNEL(int32_t);\nREGISTER_GPU_ASSTRIDED_KERNEL(int64_t);\n\n#undef REGISTER_GPU_ASSTRIDED_KERNEL\n\n#define REGISTER_GPU_ASSTRIDED_GRAD_KERNEL(in_type)                    \\\n  REGISTER_USER_KERNEL(\"as_strided_grad\")                              \\\n      .SetCreateFn<GpuAsStridedGradKernel<in_type>>()                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<in_type>::value));\n\nREGISTER_GPU_ASSTRIDED_GRAD_KERNEL(half);\nREGISTER_GPU_ASSTRIDED_GRAD_KERNEL(float);\nREGISTER_GPU_ASSTRIDED_GRAD_KERNEL(double);\n\n#undef REGISTER_GPU_ASSTRIDED_GRAD_KERNEL\n\nREGISTER_USER_KERNEL(\"as_strided\")\n    .SetCreateFn<GpuAsStridedKernel<bool>>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)\n                     && (user_op::HobDataType(\"input\", 0) == GetDataType<bool>::value));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/assign_if_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<bool assign_if, typename C>\nclass AssignIfCPUKernel final : public user_op::OpKernel {\n public:\n  AssignIfCPUKernel() = default;\n  ~AssignIfCPUKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* condition = ctx->Tensor4ArgNameAndIndex(\"condition\", 0);\n    if ((assign_if == (*condition->dptr<C>() == 0))) { return; }\n    const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex(\"value\", 0);\n    user_op::Tensor* ref = ctx->Tensor4ArgNameAndIndex(\"ref\", 0);\n    if (value->dptr() == ref->dptr()) { return; }\n    CHECK_EQ(value->shape_view(), ref->shape_view());\n    CHECK_EQ(value->data_type(), ref->data_type());\n    const size_t tensor_bytes_size =\n        ref->shape_view().elem_cnt() * GetSizeOfDataType(ref->data_type());\n    AutoMemcpy(ctx->stream(), ref->mut_dptr(), value->dptr(), tensor_bytes_size, ref->mem_case(),\n               value->mem_case());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n}  // namespace\n\n#define REGISTER_ASSIGN_WITH_CONDITION_CPU_KERNEL(op_type_name, assign_if, condition_type) \\\n  REGISTER_USER_KERNEL(op_type_name)                                                       \\\n      .SetCreateFn<AssignIfCPUKernel<assign_if, condition_type>>()                         \\\n      .SetIsMatchedHob(                                                                    \\\n          (user_op::HobDeviceType() == DeviceType::kCPU)                                   \\\n          && (user_op::HobDataType(\"condition\", 0) == GetDataType<condition_type>::value));\n\n#define REGISTER_ASSIGN_IF_CPU_KERNEL(condition_cpp_type, condition_data_type)      \\\n  REGISTER_ASSIGN_WITH_CONDITION_CPU_KERNEL(\"assign_if\", true, condition_cpp_type); \\\n  REGISTER_ASSIGN_WITH_CONDITION_CPU_KERNEL(\"assign_if_not\", false, condition_cpp_type)\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_ASSIGN_IF_CPU_KERNEL, INT_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/assign_if_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<bool assign_if, typename C, typename T>\n__global__ void AssignGpu(int64_t elem_cnt, const C* condition, const T* value, T* ref) {\n  if (assign_if == (*condition == 0)) { return; }\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) { ref[i] = value[i]; }\n}\n\ntemplate<bool assign_if, typename C, typename T>\nclass AssignIfGPUKernel final : public user_op::OpKernel {\n public:\n  AssignIfGPUKernel() = default;\n  ~AssignIfGPUKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* condition = ctx->Tensor4ArgNameAndIndex(\"condition\", 0);\n    CHECK_EQ(condition->shape_view().NumAxes(), 1);\n    CHECK_EQ(condition->shape_view().At(0), 1);\n    const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex(\"value\", 0);\n    user_op::Tensor* ref = ctx->Tensor4ArgNameAndIndex(\"ref\", 0);\n    if (value->dptr() == ref->dptr()) { return; }\n    CHECK_EQ(value->shape_view(), ref->shape_view());\n    CHECK_EQ(value->data_type(), ref->data_type());\n    const size_t elem_cnt = ref->shape_view().elem_cnt();\n    AssignGpu<assign_if, C, T><<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                                 ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        elem_cnt, condition->dptr<C>(), value->dptr<T>(), ref->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n}  // namespace\n\n#define REGISTER_ASSIGN_WITH_CONDITION_VALUE_CUDA_KERNEL(op_type_name, assign_if, condition_type, \\\n                                                         value_type)                              \\\n  REGISTER_USER_KERNEL(op_type_name)                                                              \\\n      .SetCreateFn<AssignIfGPUKernel<assign_if, condition_type, value_type>>()                    \\\n      .SetIsMatchedHob(                                                                           \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                         \\\n          && (user_op::HobDataType(\"condition\", 0) == GetDataType<condition_type>::value)         \\\n          && (user_op::HobDataType(\"value\", 0) == GetDataType<value_type>::value));\n\n#define REGISTER_ASSIGN_IF_CUDA_KERNEL(condition_type, value_type)                        \\\n  REGISTER_ASSIGN_WITH_CONDITION_VALUE_CUDA_KERNEL(                                       \\\n      \"assign_if\", true, OF_PP_PAIR_FIRST(condition_type), OF_PP_PAIR_FIRST(value_type)); \\\n  REGISTER_ASSIGN_WITH_CONDITION_VALUE_CUDA_KERNEL(                                       \\\n      \"assign_if_not\", false, OF_PP_PAIR_FIRST(condition_type), 
OF_PP_PAIR_FIRST(value_type))\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_ASSIGN_IF_CUDA_KERNEL, INT_DATA_TYPE_SEQ,\n                                 POD_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/assign_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass AssignKernel final : public user_op::OpKernel {\n public:\n  AssignKernel() = default;\n  ~AssignKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* value_tensor = ctx->Tensor4ArgNameAndIndex(\"value\", 0);\n    user_op::Tensor* ref_tensor = ctx->Tensor4ArgNameAndIndex(\"ref\", 0);\n    if (value_tensor->dptr() == ref_tensor->dptr()) { return; }\n    size_t tensor_bytes_size =\n        ref_tensor->shape_view().elem_cnt() * GetSizeOfDataType(ref_tensor->data_type());\n    size_t val_tensor_bytes_size =\n        value_tensor->shape_view().elem_cnt() * GetSizeOfDataType(value_tensor->data_type());\n    CHECK_EQ(tensor_bytes_size, val_tensor_bytes_size);\n    AutoMemcpy(ctx->stream(), ref_tensor->mut_dptr(), value_tensor->dptr(), tensor_bytes_size,\n               ref_tensor->mem_case(), value_tensor->mem_case());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n}  // namespace\n\nREGISTER_USER_KERNEL(\"assign\").SetCreateFn<AssignKernel>();\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/avg_pool_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/avg_pool_kernel_util.h\"\n\nnamespace oneflow {\n\nstruct AvgPoolOpKernelCache final : public user_op::OpKernelCache {\n  AvgPoolParams3D params_3d;\n  explicit AvgPoolOpKernelCache(const AvgPoolParams3D& params_3d) : params_3d(params_3d) {}\n  const AvgPoolParams3D& GetParams3D() const { return params_3d; }\n};\n\nstd::shared_ptr<AvgPoolOpKernelCache> CreateAvgOpKernelCache(user_op::KernelCacheContext* ctx,\n                                                             const int32_t& dim) {\n  const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0)->shape();\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>(\"padding\");\n  const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n  const std::vector<int32_t>& stride = ctx->Attr<std::vector<int32_t>>(\"stride\");\n  const bool ceil_mode = ctx->Attr<bool>(\"ceil_mode\");\n  const bool count_include_pad = ctx->Attr<bool>(\"count_include_pad\");\n  const int32_t divisor_override = ctx->Attr<int32_t>(\"divisor_override\");\n\n  AvgPoolParams3D params_3d =\n      AvgPoolParams3D(dim, x_shape, data_format, padding, kernel_size, stride, ceil_mode,\n                      count_include_pad, divisor_override);\n  std::shared_ptr<AvgPoolOpKernelCache> cache(new AvgPoolOpKernelCache(params_3d));\n  return cache;\n}\n\ntemplate<typename T, typename IDX>\nstruct AvgPoolKernelUtil<DeviceType::kCPU, T, IDX> {\n  static void Avgpool1dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                               const IDX elem_num, const T* src, T* dest,\n                               const AvgPoolParams3D& params_3d) {\n    Avgpool1dForwardCompute<T, IDX>(\n        index_helper, elem_num, src, dest, params_3d.padding()[2], params_3d.num_batch(),\n        params_3d.num_channel(), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override());\n  }\n\n  static void Avgpool1dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                                const IDX elem_num, const T* src, T* dest,\n                                const AvgPoolParams3D& params_3d) {\n    Avgpool1dBackwardCompute<T, IDX>(\n        index_helper, elem_num, src, dest, params_3d.padding()[2], params_3d.num_batch(),\n        params_3d.num_channel(), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override());\n  }\n\n  static void Avgpool2dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 3>& index_helper,\n                               const IDX elem_num, const T* src, T* dest,\n                               const AvgPoolParams3D& params_3d) 
{\n    Avgpool2dForwardCompute<T, IDX>(\n        index_helper, elem_num, src, dest, params_3d.padding()[1], params_3d.padding()[2],\n        params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(3),\n        params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(),\n        params_3d.divisor_override());\n  }\n\n  static void Avgpool2dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 3>& index_helper,\n                                const IDX elem_num, const T* src, T* dest,\n                                const AvgPoolParams3D& params_3d) {\n    Avgpool2dBackwardCompute<T, IDX>(\n        index_helper, elem_num, src, dest, params_3d.padding()[1], params_3d.padding()[2],\n        params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(3),\n        params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(),\n        params_3d.divisor_override());\n  }\n\n  static void Avgpool3dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                               const IDX elem_num, const T* src, T* dest,\n                               const AvgPoolParams3D& params_3d) {\n    Avgpool3dForwardCompute<T, IDX>(\n        index_helper, elem_num, src, dest, params_3d.padding()[0], params_3d.padding()[1],\n        params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),\n        params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4),\n        params_3d.pool_size_3d()[0], params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2],\n        params_3d.count_include_pad(), params_3d.divisor_override());\n  }\n\n  static void Avgpool3dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                                const int64_t elem_num, const T* src, T* dest,\n                                const AvgPoolParams3D& params_3d) {\n    Avgpool3dBackwardCompute<T, IDX>(\n        index_helper, elem_num, src, dest, params_3d.padding()[0], params_3d.padding()[1],\n        params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),\n        params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4),\n        params_3d.pool_size_3d()[0], params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2],\n        params_3d.count_include_pad(), params_3d.divisor_override());\n  }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass AvgPool1dKernel final : public user_op::OpKernel {\n public:\n  AvgPool1dKernel() = default;\n  ~AvgPool1dKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateAvgOpKernelCache(ctx, 1);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n\n    const auto* 
pool_cache = dynamic_cast<const AvgPoolOpKernelCache*>(cache);\n    const AvgPoolParams3D& params_3d = pool_cache->GetParams3D();\n\n    const int64_t elem_num = y->shape_view().elem_cnt();\n    const T* src = x->dptr<T>();\n    T* dest = y->mut_dptr<T>();\n\n    DimVector y_vector(2);\n    y_vector.at(0) = y->shape_view().At(0) * y->shape_view().At(1);\n    y_vector.at(1) = y->shape_view().At(2);\n    if (elem_num < GetMaxVal<int32_t>()) {\n      NdIndexOffsetHelper<int32_t, 2> index_helper(y_vector.data());\n      AvgPoolKernelUtil<device_type, T, int32_t>::Avgpool1dForward(ctx->stream(), index_helper,\n                                                                   elem_num, src, dest, params_3d);\n    } else {\n      NdIndexOffsetHelper<int64_t, 2> index_helper(y_vector.data());\n      AvgPoolKernelUtil<device_type, T, int64_t>::Avgpool1dForward(ctx->stream(), index_helper,\n                                                                   elem_num, src, dest, params_3d);\n    }\n  };\n};\n\ntemplate<DeviceType device_type, typename T>\nclass AvgPool1dGradKernel final : public user_op::OpKernel {\n public:\n  AvgPool1dGradKernel() = default;\n  ~AvgPool1dGradKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateAvgOpKernelCache(ctx, 1);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const auto* pool_cache = dynamic_cast<const AvgPoolOpKernelCache*>(cache);\n    const AvgPoolParams3D& params_3d = pool_cache->GetParams3D();\n\n    const int64_t elem_num = dy->shape_view().elem_cnt();\n    const T* src = dy->dptr<T>();\n    T* dest = dx->mut_dptr<T>();\n    size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type());\n    Memset<device_type>(ctx->stream(), dest, 0, out_bytes_size);\n\n    DimVector dy_vector(2);\n    dy_vector.at(0) = dy->shape_view().At(0) * dy->shape_view().At(1);\n    dy_vector.at(1) = dy->shape_view().At(2);\n    if (elem_num < GetMaxVal<int32_t>()) {\n      NdIndexOffsetHelper<int32_t, 2> index_helper(dy_vector.data());\n      AvgPoolKernelUtil<device_type, T, int32_t>::Avgpool1dBackward(ctx->stream(), index_helper,\n                                                                    elem_num, src, dest, params_3d);\n    } else {\n      NdIndexOffsetHelper<int64_t, 2> index_helper(dy_vector.data());\n      AvgPoolKernelUtil<device_type, T, int64_t>::Avgpool1dBackward(ctx->stream(), index_helper,\n                                                                    elem_num, src, dest, params_3d);\n    }\n  };\n};\n\ntemplate<DeviceType device_type, typename T>\nclass AvgPool2dKernel final : public user_op::OpKernel {\n public:\n  AvgPool2dKernel() = default;\n  ~AvgPool2dKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateAvgOpKernelCache(ctx, 2);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    
const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n\n    const auto* pool_cache = dynamic_cast<const AvgPoolOpKernelCache*>(cache);\n    const AvgPoolParams3D& params_3d = pool_cache->GetParams3D();\n\n    const int64_t elem_num = y->shape_view().elem_cnt();\n    const T* src = x->dptr<T>();\n    T* dest = y->mut_dptr<T>();\n\n    DimVector y_vector(3);\n    y_vector.at(0) = y->shape_view().At(0) * y->shape_view().At(1);\n    y_vector.at(1) = y->shape_view().At(2);\n    y_vector.at(2) = y->shape_view().At(3);\n    if (elem_num < GetMaxVal<int32_t>()) {\n      NdIndexOffsetHelper<int32_t, 3> index_helper(y_vector.data());\n      AvgPoolKernelUtil<device_type, T, int32_t>::Avgpool2dForward(ctx->stream(), index_helper,\n                                                                   elem_num, src, dest, params_3d);\n    } else {\n      NdIndexOffsetHelper<int64_t, 3> index_helper(y_vector.data());\n      AvgPoolKernelUtil<device_type, T, int64_t>::Avgpool2dForward(ctx->stream(), index_helper,\n                                                                   elem_num, src, dest, params_3d);\n    }\n  };\n};\n\ntemplate<DeviceType device_type, typename T>\nclass AvgPool2dGradKernel final : public user_op::OpKernel {\n public:\n  AvgPool2dGradKernel() = default;\n  ~AvgPool2dGradKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateAvgOpKernelCache(ctx, 2);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const auto* pool_cache = dynamic_cast<const AvgPoolOpKernelCache*>(cache);\n    const AvgPoolParams3D& params_3d = pool_cache->GetParams3D();\n\n    const int64_t elem_num = dy->shape_view().elem_cnt();\n    const T* src = dy->dptr<T>();\n    T* dest = dx->mut_dptr<T>();\n\n    size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type());\n    Memset<device_type>(ctx->stream(), dest, 0, out_bytes_size);\n\n    DimVector dy_vector(3);\n    dy_vector.at(0) = dy->shape_view().At(0) * dy->shape_view().At(1);\n    dy_vector.at(1) = dy->shape_view().At(2);\n    dy_vector.at(2) = dy->shape_view().At(3);\n    if (elem_num < GetMaxVal<int32_t>()) {\n      NdIndexOffsetHelper<int32_t, 3> index_helper(dy_vector.data());\n      AvgPoolKernelUtil<device_type, T, int32_t>::Avgpool2dBackward(ctx->stream(), index_helper,\n                                                                    elem_num, src, dest, params_3d);\n    } else {\n      NdIndexOffsetHelper<int64_t, 3> index_helper(dy_vector.data());\n      AvgPoolKernelUtil<device_type, T, int64_t>::Avgpool2dBackward(ctx->stream(), index_helper,\n                                                                    elem_num, src, dest, params_3d);\n    }\n  };\n};\n\ntemplate<DeviceType device_type, typename T>\nclass AvgPool3dKernel final : public user_op::OpKernel {\n public:\n  AvgPool3dKernel() = default;\n  ~AvgPool3dKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      
user_op::KernelCacheContext* ctx) const override {\n    return CreateAvgOpKernelCache(ctx, 3);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n\n    const auto* pool_cache = dynamic_cast<const AvgPoolOpKernelCache*>(cache);\n    const AvgPoolParams3D& params_3d = pool_cache->GetParams3D();\n\n    const int64_t elem_num = y->shape_view().elem_cnt();\n    const T* src = x->dptr<T>();\n    T* dest = y->mut_dptr<T>();\n\n    DimVector y_vector(4);\n    y_vector.at(0) = y->shape_view().At(0) * y->shape_view().At(1);\n    y_vector.at(1) = y->shape_view().At(2);\n    y_vector.at(2) = y->shape_view().At(3);\n    y_vector.at(3) = y->shape_view().At(4);\n    if (elem_num < GetMaxVal<int32_t>()) {\n      NdIndexOffsetHelper<int32_t, 4> index_helper(y_vector.data());\n      AvgPoolKernelUtil<device_type, T, int32_t>::Avgpool3dForward(ctx->stream(), index_helper,\n                                                                   elem_num, src, dest, params_3d);\n    } else {\n      NdIndexOffsetHelper<int64_t, 4> index_helper(y_vector.data());\n      AvgPoolKernelUtil<device_type, T, int64_t>::Avgpool3dForward(ctx->stream(), index_helper,\n                                                                   elem_num, src, dest, params_3d);\n    }\n  };\n};\n\ntemplate<DeviceType device_type, typename T>\nclass AvgPool3dGradKernel final : public user_op::OpKernel {\n public:\n  AvgPool3dGradKernel() = default;\n  ~AvgPool3dGradKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateAvgOpKernelCache(ctx, 3);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const auto* pool_cache = dynamic_cast<const AvgPoolOpKernelCache*>(cache);\n    const AvgPoolParams3D& params_3d = pool_cache->GetParams3D();\n\n    const int64_t elem_num = dy->shape_view().elem_cnt();\n    const T* src = dy->dptr<T>();\n    T* dest = dx->mut_dptr<T>();\n\n    size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type());\n    Memset<device_type>(ctx->stream(), dest, 0, out_bytes_size);\n\n    DimVector dy_vector(4);\n    dy_vector.at(0) = dy->shape_view().At(0) * dy->shape_view().At(1);\n    dy_vector.at(1) = dy->shape_view().At(2);\n    dy_vector.at(2) = dy->shape_view().At(3);\n    dy_vector.at(3) = dy->shape_view().At(4);\n    if (elem_num < GetMaxVal<int32_t>()) {\n      NdIndexOffsetHelper<int32_t, 4> index_helper(dy_vector.data());\n      AvgPoolKernelUtil<device_type, T, int32_t>::Avgpool3dBackward(ctx->stream(), index_helper,\n                                                                    elem_num, src, dest, params_3d);\n    } else {\n      NdIndexOffsetHelper<int64_t, 4> index_helper(dy_vector.data());\n      AvgPoolKernelUtil<device_type, T, int64_t>::Avgpool3dBackward(ctx->stream(), index_helper,\n                                                                    elem_num, src, dest, params_3d);\n    }\n  
};\n};\n\n#define REGISTER_AVG_POOL_KERNELS(device, dtype)                                        \\\n  REGISTER_USER_KERNEL(\"avg_pool_1d\")                                                   \\\n      .SetCreateFn<AvgPool1dKernel<device, dtype>>()                                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"avg_pool_1d_grad\")                                              \\\n      .SetCreateFn<AvgPool1dGradKernel<device, dtype>>()                                \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"avg_pool_2d\")                                                   \\\n      .SetCreateFn<AvgPool2dKernel<device, dtype>>()                                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"avg_pool_2d_grad\")                                              \\\n      .SetCreateFn<AvgPool2dGradKernel<device, dtype>>()                                \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"avg_pool_3d\")                                                   \\\n      .SetCreateFn<AvgPool3dKernel<device, dtype>>()                                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"avg_pool_3d_grad\")                                              \\\n      .SetCreateFn<AvgPool3dGradKernel<device, dtype>>()                                \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\n#define REGISTER_AVG_POOL_WITH_DEVICE(device) \\\n  REGISTER_AVG_POOL_KERNELS(device, float)    \\\n  REGISTER_AVG_POOL_KERNELS(device, double)\n\nREGISTER_AVG_POOL_WITH_DEVICE(DeviceType::kCPU)\n\n#ifdef WITH_CUDA\nREGISTER_AVG_POOL_WITH_DEVICE(DeviceType::kCUDA)\nREGISTER_AVG_POOL_KERNELS(DeviceType::kCUDA, half)\n#endif\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_AVG_POOL_KERNEL_UTIL, (DeviceType::kCPU),\n                                 AVG_POOL_DATA_TYPE_CPU_SEQ, AVG_POOL_IDX_DATA_TYPE_SEQ);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/avg_pool_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cstdint>\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/user/kernels/avg_pool_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr int kBlockSize = cuda::elementwise::kBlockSize;\n\nint GetMinThreadNum(const int64_t elem_num) { return std::min<int64_t>(elem_num, kBlockSize); }\n\nint GetNumBlocks(int32_t elem_cnt) {\n  int num_blocks = 0;\n  OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(elem_cnt, &num_blocks));\n  return num_blocks;\n}\n\n}  // namespace\n\ntemplate<typename T, typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoCUDAAvgPool1dForward(const NdIndexOffsetHelper<IDX, 2> index_helper, IDX elem_num,\n                                const T* src, T* dest, int32_t padding_l, const int32_t n_batch,\n                                const int32_t n_channel, const int32_t x_length,\n                                const int32_t kernel_size_l, const int32_t stride_l,\n                                const bool count_include_pad, const int32_t divisor_override) {\n  Avgpool1dForwardCompute<T>(index_helper, elem_num, src, dest, padding_l, n_batch, n_channel,\n                             x_length, kernel_size_l, stride_l, count_include_pad,\n                             divisor_override);\n};\n\ntemplate<typename T, typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoCUDAAvgPool2dForward(const NdIndexOffsetHelper<IDX, 3> index_helper, IDX elem_num,\n                                const T* src, T* dest, const int32_t padding_h,\n                                const int32_t padding_w, const int32_t n_batch,\n                                const int32_t n_channel, const int32_t x_height,\n                                const int32_t x_width, const int32_t kernel_size_h,\n                                const int32_t kernel_size_w, const int32_t stride_h,\n                                const int32_t stride_w, const bool count_include_pad,\n                                const int32_t divisor_override) {\n  Avgpool2dForwardCompute<T>(index_helper, elem_num, src, dest, padding_h, padding_w, n_batch,\n                             n_channel, x_height, x_width, kernel_size_h, kernel_size_w, stride_h,\n                             stride_w, count_include_pad, divisor_override);\n};\n\ntemplate<typename T, typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoCUDAAvgPool3dForward(const NdIndexOffsetHelper<IDX, 4> index_helper, IDX elem_num,\n                                const T* src, T* dest, int32_t padding_t, const int32_t padding_h,\n                                const int32_t padding_w, const int32_t n_batch,\n                                const int32_t n_channel, const int32_t x_time,\n                                const int32_t x_height, const int32_t x_width,\n                                const int32_t kernel_size_t, int32_t 
kernel_size_h,\n                                const int32_t kernel_size_w, const int32_t stride_t,\n                                const int32_t stride_h, const int32_t stride_w,\n                                const bool count_include_pad, const int32_t divisor_override) {\n  Avgpool3dForwardCompute<T>(index_helper, elem_num, src, dest, padding_t, padding_h, padding_w,\n                             n_batch, n_channel, x_time, x_height, x_width, kernel_size_t,\n                             kernel_size_h, kernel_size_w, stride_t, stride_h, stride_w,\n                             count_include_pad, divisor_override);\n};\n\ntemplate<typename T, typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoCUDAAvgPool1dBackward(const NdIndexOffsetHelper<IDX, 2> index_helper, IDX elem_num,\n                                 const T* src, T* dest, const int32_t padding_l,\n                                 const int32_t n_batch, const int32_t n_channel,\n                                 const int32_t input_length, const int32_t kernel_size_l,\n                                 const int32_t stride_l, const bool count_include_pad,\n                                 const int32_t divisor_override) {\n  Avgpool1dBackwardCompute<T>(index_helper, elem_num, src, dest, padding_l, n_batch, n_channel,\n                              input_length, kernel_size_l, stride_l, count_include_pad,\n                              divisor_override);\n};\n\ntemplate<typename T, typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoCUDAAvgPool2dBackward(const NdIndexOffsetHelper<IDX, 3> index_helper, IDX elem_num,\n                                 const T* src, T* dest, const int32_t padding_h,\n                                 const int32_t padding_w, const int32_t n_batch,\n                                 const int32_t n_channel, const int32_t input_height,\n                                 const int32_t input_width, const int32_t kernel_size_h,\n                                 const int32_t kernel_size_w, const int32_t stride_h,\n                                 const int32_t stride_w, const bool count_include_pad,\n                                 int32_t divisor_override) {\n  Avgpool2dBackwardCompute<T>(index_helper, elem_num, src, dest, padding_h, padding_w, n_batch,\n                              n_channel, input_height, input_width, kernel_size_h, kernel_size_w,\n                              stride_h, stride_w, count_include_pad, divisor_override);\n};\n\ntemplate<typename T, typename IDX>\n__launch_bounds__(kBlockSize) __global__ void DoCUDAAvgPool3dBackward(\n    const NdIndexOffsetHelper<IDX, 4> index_helper, IDX elem_num, const T* src, T* dest,\n    const int32_t padding_t, const int32_t padding_h, const int32_t padding_w,\n    const int32_t n_batch, const int32_t n_channel, const int32_t x_time, const int32_t x_height,\n    const int32_t x_width, const int32_t kernel_size_t, const int32_t kernel_size_h,\n    const int32_t kernel_size_w, const int32_t stride_t, const int32_t stride_h,\n    const int32_t stride_w, const bool count_include_pad, const int32_t divisor_override) {\n  Avgpool3dBackwardCompute<T>(index_helper, elem_num, src, dest, padding_t, padding_h, padding_w,\n                              n_batch, n_channel, x_time, x_height, x_width, kernel_size_t,\n                              kernel_size_h, kernel_size_w, stride_t, stride_h, stride_w,\n                              count_include_pad, divisor_override);\n};\n\ntemplate<typename IDX>\n__launch_bounds__(kBlockSize) 
__global__\n    void DoHalfAvgPool1dForward(const NdIndexOffsetHelper<IDX, 2> index_helper, IDX elem_num,\n                                const half* src, half* dest, int32_t padding_l,\n                                const int32_t n_batch, const int32_t n_channel,\n                                const int32_t x_length, const int32_t kernel_size_l,\n                                const int32_t stride_l, const bool count_include_pad,\n                                const int32_t divisor_override) {\n  HalfAvgpool1dForwardCompute<IDX>(index_helper, elem_num, src, dest, padding_l, n_batch, n_channel,\n                                   x_length, kernel_size_l, stride_l, count_include_pad,\n                                   divisor_override);\n};\n\ntemplate<typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoHalfAvgPool2dForward(const NdIndexOffsetHelper<IDX, 3> index_helper, IDX elem_num,\n                                const half* src, half* dest, const int32_t padding_h,\n                                const int32_t padding_w, const int32_t n_batch,\n                                const int32_t n_channel, const int32_t x_height,\n                                const int32_t x_width, const int32_t kernel_size_h,\n                                const int32_t kernel_size_w, const int32_t stride_h,\n                                const int32_t stride_w, const bool count_include_pad,\n                                const int32_t divisor_override) {\n  HalfAvgpool2dForwardCompute<IDX>(index_helper, elem_num, src, dest, padding_h, padding_w, n_batch,\n                                   n_channel, x_height, x_width, kernel_size_h, kernel_size_w,\n                                   stride_h, stride_w, count_include_pad, divisor_override);\n};\n\ntemplate<typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoHalfAvgPool3dForward(const NdIndexOffsetHelper<IDX, 4> index_helper, IDX elem_num,\n                                const half* src, half* dest, int32_t padding_t,\n                                const int32_t padding_h, const int32_t padding_w,\n                                const int32_t n_batch, const int32_t n_channel,\n                                const int32_t x_time, const int32_t x_height, const int32_t x_width,\n                                const int32_t kernel_size_t, int32_t kernel_size_h,\n                                const int32_t kernel_size_w, const int32_t stride_t,\n                                const int32_t stride_h, const int32_t stride_w,\n                                const bool count_include_pad, const int32_t divisor_override) {\n  HalfAvgpool3dForwardCompute<IDX>(index_helper, elem_num, src, dest, padding_t, padding_h,\n                                   padding_w, n_batch, n_channel, x_time, x_height, x_width,\n                                   kernel_size_t, kernel_size_h, kernel_size_w, stride_t, stride_h,\n                                   stride_w, count_include_pad, divisor_override);\n};\n\ntemplate<typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoHalfAvgPool1dBackward(const NdIndexOffsetHelper<IDX, 2> index_helper, IDX elem_num,\n                                 const half* src, half* dest, const int32_t padding_l,\n                                 const int32_t n_batch, const int32_t n_channel,\n                                 const int32_t input_length, const int32_t kernel_size_l,\n                                 const int32_t stride_l, const bool count_include_pad,\n             
                    const int32_t divisor_override) {\n  HalfAvgpool1dBackwardCompute<IDX>(index_helper, elem_num, src, dest, padding_l, n_batch,\n                                    n_channel, input_length, kernel_size_l, stride_l,\n                                    count_include_pad, divisor_override);\n};\n\ntemplate<typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoHalfAvgPool2dBackward(const NdIndexOffsetHelper<IDX, 3> index_helper, IDX elem_num,\n                                 const half* src, half* dest, const int32_t padding_h,\n                                 const int32_t padding_w, const int32_t n_batch,\n                                 const int32_t n_channel, const int32_t input_height,\n                                 const int32_t input_width, const int32_t kernel_size_h,\n                                 const int32_t kernel_size_w, const int32_t stride_h,\n                                 const int32_t stride_w, const bool count_include_pad,\n                                 int32_t divisor_override) {\n  HalfAvgpool2dBackwardCompute<IDX>(index_helper, elem_num, src, dest, padding_h, padding_w,\n                                    n_batch, n_channel, input_height, input_width, kernel_size_h,\n                                    kernel_size_w, stride_h, stride_w, count_include_pad,\n                                    divisor_override);\n};\n\ntemplate<typename IDX>\n__launch_bounds__(kBlockSize) __global__ void DoHalfAvgPool3dBackward(\n    const NdIndexOffsetHelper<IDX, 4> index_helper, IDX elem_num, const half* src, half* dest,\n    const int32_t padding_t, const int32_t padding_h, const int32_t padding_w,\n    const int32_t n_batch, const int32_t n_channel, const int32_t x_time, const int32_t x_height,\n    const int32_t x_width, const int32_t kernel_size_t, const int32_t kernel_size_h,\n    const int32_t kernel_size_w, const int32_t stride_t, const int32_t stride_h,\n    const int32_t stride_w, const bool count_include_pad, const int32_t divisor_override) {\n  HalfAvgpool3dBackwardCompute<IDX>(index_helper, elem_num, src, dest, padding_t, padding_h,\n                                    padding_w, n_batch, n_channel, x_time, x_height, x_width,\n                                    kernel_size_t, kernel_size_h, kernel_size_w, stride_t, stride_h,\n                                    stride_w, count_include_pad, divisor_override);\n};\n\ntemplate<typename T, typename IDX>\nstruct AvgPoolKernelUtil<DeviceType::kCUDA, T, IDX> {\n  static void Avgpool1dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                               const IDX elem_num, const T* src, T* dest,\n                               const AvgPoolParams3D& params_3d) {\n    DoCUDAAvgPool1dForward<T, IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                     stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, params_3d.padding()[2], params_3d.num_batch(),\n        params_3d.num_channel(), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override());\n  }\n\n  static void Avgpool1dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                                const IDX elem_num, const T* src, T* dest,\n                                const AvgPoolParams3D& params_3d) {\n    DoCUDAAvgPool1dBackward<T, IDX><<<GetNumBlocks(elem_num), 
GetMinThreadNum(elem_num), 0,\n                                      stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, params_3d.padding()[2], params_3d.num_batch(),\n        params_3d.num_channel(), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override());\n  }\n\n  static void Avgpool2dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 3>& index_helper,\n                               const IDX elem_num, const T* src, T* dest,\n                               const AvgPoolParams3D& params_3d) {\n    DoCUDAAvgPool2dForward<T, IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                     stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, params_3d.padding()[1], params_3d.padding()[2],\n        params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(3),\n        params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(),\n        params_3d.divisor_override());\n  }\n\n  static void Avgpool2dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 3>& index_helper,\n                                const IDX elem_num, const T* src, T* dest,\n                                const AvgPoolParams3D& params_3d) {\n    DoCUDAAvgPool2dBackward<T, IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                      stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, params_3d.padding()[1], params_3d.padding()[2],\n        params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(3),\n        params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(),\n        params_3d.divisor_override());\n  }\n\n  static void Avgpool3dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                               const IDX elem_num, const T* src, T* dest,\n                               const AvgPoolParams3D& params_3d) {\n    DoCUDAAvgPool3dForward<T, IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                     stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, params_3d.padding()[0], params_3d.padding()[1],\n        params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),\n        params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4),\n        params_3d.pool_size_3d()[0], params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2],\n        params_3d.count_include_pad(), params_3d.divisor_override());\n  }\n\n  static void Avgpool3dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                                const IDX elem_num, const T* src, T* dest,\n                                const AvgPoolParams3D& params_3d) {\n    DoCUDAAvgPool3dBackward<T, IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                      stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, params_3d.padding()[0], 
params_3d.padding()[1],\n        params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),\n        params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4),\n        params_3d.pool_size_3d()[0], params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2],\n        params_3d.count_include_pad(), params_3d.divisor_override());\n  }\n};\n\ntemplate<typename IDX>\nstruct AvgPoolKernelUtil<DeviceType::kCUDA, half, IDX> {\n  static void Avgpool1dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                               const IDX elem_num, const half* src, half* dest,\n                               const AvgPoolParams3D& params_3d) {\n    DoHalfAvgPool1dForward<IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                  stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, params_3d.padding()[2], params_3d.num_batch(),\n        params_3d.num_channel(), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override());\n  }\n\n  static void Avgpool1dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                                const IDX elem_num, const half* src, half* dest,\n                                const AvgPoolParams3D& params_3d) {\n    DoHalfAvgPool1dBackward<IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                   stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, params_3d.padding()[2], params_3d.num_batch(),\n        params_3d.num_channel(), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[2], params_3d.count_include_pad(), params_3d.divisor_override());\n  }\n\n  static void Avgpool2dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 3>& index_helper,\n                               const IDX elem_num, const half* src, half* dest,\n                               const AvgPoolParams3D& params_3d) {\n    DoHalfAvgPool2dForward<IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                  stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, params_3d.padding()[1], params_3d.padding()[2],\n        params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(3),\n        params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(),\n        params_3d.divisor_override());\n  }\n\n  static void Avgpool2dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 3>& index_helper,\n                                const IDX elem_num, const half* src, half* dest,\n                                const AvgPoolParams3D& params_3d) {\n    DoHalfAvgPool2dBackward<IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                   stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, params_3d.padding()[1], params_3d.padding()[2],\n        params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(3),\n        params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2],\n        
params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.count_include_pad(),\n        params_3d.divisor_override());\n  }\n\n  static void Avgpool3dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                               const IDX elem_num, const half* src, half* dest,\n                               const AvgPoolParams3D& params_3d) {\n    DoHalfAvgPool3dForward<IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                  stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, params_3d.padding()[0], params_3d.padding()[1],\n        params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),\n        params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4),\n        params_3d.pool_size_3d()[0], params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2],\n        params_3d.count_include_pad(), params_3d.divisor_override());\n  }\n\n  static void Avgpool3dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                                const IDX elem_num, const half* src, half* dest,\n                                const AvgPoolParams3D& params_3d) {\n    DoHalfAvgPool3dBackward<IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                   stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, params_3d.padding()[0], params_3d.padding()[1],\n        params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),\n        params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4),\n        params_3d.pool_size_3d()[0], params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2],\n        params_3d.count_include_pad(), params_3d.divisor_override());\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_AVG_POOL_KERNEL_UTIL, (DeviceType::kCUDA),\n                                 AVG_POOL_DATA_TYPE_CUDA_SEQ, AVG_POOL_IDX_DATA_TYPE_SEQ);\ntemplate struct AvgPoolKernelUtil<DeviceType::kCUDA, half, int32_t>;\ntemplate struct AvgPoolKernelUtil<DeviceType::kCUDA, half, int64_t>;\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/avg_pool_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/avg_pool_kernel_util.h\"\n\nnamespace oneflow {\n\nstd::vector<int32_t> GetAvg3DVec(const std::vector<int32_t>& original_vec, int32_t NDims) {\n  std::vector<int32_t> vec;\n  FOR_RANGE(uint8_t, dim, 0, 3) {\n    int64_t index = static_cast<int64_t>(dim) - (3 - NDims);\n    if (index < 0) {\n      vec.emplace_back(1);\n    } else {\n      vec.emplace_back(original_vec.at(index));\n    }\n  }\n  return vec;\n}\n\nstd::vector<int32_t> GetAvg3DPadVec(const std::vector<int32_t>& original_vec, int32_t NDims) {\n  std::vector<int32_t> vec;\n  FOR_RANGE(uint8_t, dim, 0, 3) {\n    int64_t index = static_cast<int64_t>(dim) - (3 - NDims);\n    if (index < 0) {\n      vec.emplace_back(0);\n    } else {\n      vec.emplace_back(original_vec.at(index));\n    }\n  }\n  return vec;\n}\n\nconst int64_t GetNoDilationWindowedOutputShape(int64_t input_size, int32_t filter_size,\n                                               int32_t stride, int32_t padding, bool ceil_mode) {\n  int64_t output_size =\n      (input_size + 2 * padding - (filter_size - 1) - 1 + stride + (ceil_mode ? stride - 1 : 0))\n      / stride;\n\n  if (ceil_mode) {\n    // ensure that the last pooling starts inside the image\n    // needed to avoid problems in ceil mode\n    if ((output_size - 1) * stride >= input_size + padding) { --output_size; }\n  }\n  return output_size;\n}\n\nvoid GetNoDilation3DOutputShape(const DimVector& in, const std::vector<int32_t>& pool_size,\n                                const std::vector<int32_t>& strides,\n                                const std::vector<int32_t>& padding, const bool ceil_mode,\n                                DimVector* out) {\n  out->clear();\n  out->resize(3);\n  FOR_RANGE(size_t, i, 0, 3) {\n    out->at(i) = GetNoDilationWindowedOutputShape(in.at(i), pool_size.at(i), strides.at(i),\n                                                  padding.at(i), ceil_mode);\n  }\n}\n\nAvgPoolParams3D::AvgPoolParams3D(const int32_t dim, const ShapeView& x_shape,\n                                 const std::string& data_format,\n                                 const std::vector<int32_t>& padding,\n                                 const std::vector<int32_t>& kernel_size,\n                                 const std::vector<int32_t>& stride, const bool ceil_mode,\n                                 const bool count_include_pad, const int32_t divisor_override)\n    : dim_(dim),\n      data_format_(data_format),\n      padding_(GetAvg3DPadVec(padding, dim)),\n      pool_size_3d_(GetAvg3DVec(kernel_size, dim)),\n      stride_3d_(GetAvg3DVec(stride, dim)),\n      ceil_mode_(ceil_mode),\n      count_include_pad_(count_include_pad),\n      divisor_override_(divisor_override) {\n  x_3d_ = {GetInDim(x_shape, data_format, 0, dim), GetInDim(x_shape, data_format, 1, dim),\n           GetInDim(x_shape, data_format, 2, dim)};\n  GetNoDilation3DOutputShape(x_3d_, pool_size_3d_, 
stride_3d_, padding_, ceil_mode_, &y_3d_);\n  if (data_format == \"channels_first\") {\n    channel_num_ = x_shape.At(1);\n  } else {\n    CHECK_EQ(data_format_, \"channels_last\")\n        << \"data_format must be 'channels_first' or 'channels_last'\";\n    channel_num_ = x_shape.At(x_shape.NumAxes() - 1);\n  }\n  batch_num_ = x_shape.At(0);\n}\n\nvoid AvgPoolParams3D::Reset(const ShapeView& x_shape) {\n  x_3d_ = {GetInDim(x_shape, data_format_, 0, dim_), GetInDim(x_shape, data_format_, 1, dim_),\n           GetInDim(x_shape, data_format_, 2, dim_)};\n  GetNoDilation3DOutputShape(x_3d_, pool_size_3d_, stride_3d_, padding_, ceil_mode_, &y_3d_);\n}\n\nShape AvgPoolParams3D::GetYShape() const {\n  DimVector y_dim_vec;\n  if (dim_ == 1) {\n    y_dim_vec = {y_3d_.at(2)};\n  } else if (dim_ == 2) {\n    y_dim_vec = {y_3d_.at(1), y_3d_.at(2)};\n  } else if (dim_ == 3) {\n    y_dim_vec = {y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)};\n  } else {\n    UNIMPLEMENTED();\n  }\n  if (data_format_ == \"channels_first\") {\n    y_dim_vec.insert(y_dim_vec.begin(), channel_num_);\n  } else {\n    CHECK_EQ(data_format_, \"channels_last\")\n        << \"data_format must be 'channels_first' or 'channels_last'\";\n    y_dim_vec.insert(y_dim_vec.end(), channel_num_);\n  }\n  y_dim_vec.insert(y_dim_vec.begin(), batch_num_);\n  return Shape(y_dim_vec);\n}\n\nShape AvgPoolParams3D::GetXShape5D() const {\n  return Shape({batch_num_, channel_num_, x_3d_.at(0), x_3d_.at(1), x_3d_.at(2)});\n}\n\nShape AvgPoolParams3D::GetYShape5D() const {\n  return Shape({batch_num_, channel_num_, y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)});\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/avg_pool_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_AVG_POOL_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_AVG_POOL_KERNEL_UTIL_H_\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/ndarray/xpu_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/operator/operator_util.h\"\n#include \"oneflow/core/kernel/util/numerics.cuh\"\n#include \"oneflow/core/kernel/util/numeric_limits.cuh\"\n#ifdef WITH_CUDA\n#include \"oneflow/core/cuda/atomic.cuh\"\n#endif  // WITH_CUDA\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nOF_DEVICE_FUNC T XPU_INT_MIN(T a, T b) {\n  return a <= b ? a : b;\n}\n\ntemplate<typename T>\nOF_DEVICE_FUNC T XPU_INT_MAX(T a, T b) {\n  return a >= b ? a : b;\n}\n\ntemplate<typename T>\nstruct XPUAdd {\n  OF_DEVICE_FUNC static void Invoke(const T* x, T* y) {\n#if defined(__CUDA_ARCH__)\n    cuda::atomic::Add(y, *x);\n#else\n    *y += *x;\n#endif\n  };\n};\n\n}  // namespace\n\n#define AVG_POOL_DATA_TYPE_SEQ                  \\\n  OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) \\\n  OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)\n\n#define AVG_POOL_DATA_TYPE_CPU_SEQ AVG_POOL_DATA_TYPE_SEQ\n\n#define AVG_POOL_DATA_TYPE_CUDA_SEQ AVG_POOL_DATA_TYPE_SEQ\n\n#define AVG_POOL_IDX_DATA_TYPE_SEQ                \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)\n\ntypedef small_vector<int64_t, SHAPE_MAX_AXIS_SIZE> FixedDimVector;\n\nclass AvgPoolParams3D {\n public:\n  AvgPoolParams3D(const int32_t dim, const ShapeView& x_shape, const std::string& data_format,\n                  const std::vector<int32_t>& padding, const std::vector<int32_t>& kernel_size,\n                  const std::vector<int32_t>& stride, const bool ceil_mode,\n                  const bool count_include_pad, const int32_t divisor_override);\n  ~AvgPoolParams3D() = default;\n\n  const std::string& data_format() const { return data_format_; }\n  const std::vector<int32_t>& padding() const { return padding_; }\n  const std::vector<int32_t>& pool_size_3d() const { return pool_size_3d_; }\n  const std::vector<int32_t>& stride_3d() const { return stride_3d_; }\n  const bool& ceil_mode() const { return ceil_mode_; }\n  const bool& count_include_pad() const { return count_include_pad_; }\n  const int32_t& divisor_override() const { return divisor_override_; }\n  const int32_t& num_batch() const { return batch_num_; }\n  const int32_t& num_channel() const { return channel_num_; }\n\n  void Reset(const ShapeView& x_shape);\n  Shape GetYShape() const;\n  Shape GetXShape5D() const;\n  Shape GetYShape5D() const;\n\n private:\n  int32_t dim_;\n  FixedDimVector x_3d_;\n  FixedDimVector y_3d_;\n  std::string data_format_;\n  std::vector<int32_t> padding_;\n  std::vector<int32_t> pool_size_3d_;\n  std::vector<int32_t> stride_3d_;\n  bool ceil_mode_;\n  bool 
count_include_pad_;\n  int32_t divisor_override_;\n  int32_t batch_num_;\n  int32_t channel_num_;\n};\n\ntemplate<DeviceType device_type, typename T, typename IDX>\nstruct AvgPoolKernelUtil {\n  static void Avgpool1dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                               const IDX elem_num, const T* src, T* dest,\n                               const AvgPoolParams3D& params_3d);\n\n  static void Avgpool1dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                                const IDX elem_num, const T* src, T* dest,\n                                const AvgPoolParams3D& params_3d);\n\n  static void Avgpool2dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 3>& index_helper,\n                               const IDX elem_num, const T* src, T* dest,\n                               const AvgPoolParams3D& params_3d);\n\n  static void Avgpool2dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 3>& index_helper,\n                                const IDX elem_num, const T* src, T* dest,\n                                const AvgPoolParams3D& params_3d);\n\n  static void Avgpool3dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                               const IDX elem_num, const T* src, T* dest,\n                               const AvgPoolParams3D& params_3d);\n\n  static void Avgpool3dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                                const IDX elem_num, const T* src, T* dest,\n                                const AvgPoolParams3D& params_3d);\n};\n\ntemplate<typename T, typename IDX>\nOF_DEVICE_FUNC void Avgpool1dForwardCompute(const NdIndexOffsetHelper<IDX, 2> index_helper,\n                                            IDX elem_num, const T* src, T* dest,\n                                            const int32_t padding_l, const int32_t n_batch,\n                                            const int32_t n_channel, const int32_t x_length,\n                                            const int32_t kernel_size_l, const int32_t stride_l,\n                                            const bool count_include_pad,\n                                            const int32_t divisor_override) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, l;\n    index_helper.OffsetToNdIndex(num, n_c, l);\n\n    const IDX start_idx = n_c * x_length;\n    IDX lstart = l * stride_l - padding_l;\n    IDX lend = XPU_INT_MIN<IDX>(lstart + kernel_size_l, x_length + padding_l);\n    const IDX pool_size = (lend - lstart);\n\n    lstart = XPU_INT_MAX<IDX>(0, lstart);\n    lend = XPU_INT_MIN<IDX>(lend, x_length);\n\n    IDX divide_factor;\n    if (divisor_override != static_cast<int32_t>(0)) {\n      divide_factor = divisor_override;\n    } else {\n      if (count_include_pad) {\n        divide_factor = pool_size;\n      } else {\n        divide_factor = (lend - lstart);\n      }\n    }\n    T sum = 0;\n\n    const T* data = src + start_idx;\n    for (IDX idx = lstart; idx < lend; idx += 1) { sum += data[idx]; }\n    dest[num] = static_cast<T>(sum / divide_factor);\n  }\n}\n\ntemplate<typename T, typename IDX>\nOF_DEVICE_FUNC void Avgpool1dBackwardCompute(const NdIndexOffsetHelper<IDX, 2> index_helper,\n                                             IDX elem_num, const T* src, T* dest,\n                                             const int32_t padding_l, const int32_t n_batch,\n                                           
  const int32_t n_channel, const int32_t input_length,\n                                             const int32_t kernel_size_l, const int32_t stride_l,\n                                             const bool count_include_pad,\n                                             const int32_t divisor_override) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, l;\n    index_helper.OffsetToNdIndex(num, n_c, l);\n\n    const IDX start_idx = n_c * input_length;\n    IDX lstart = l * stride_l - padding_l;\n    IDX lend = XPU_INT_MIN<IDX>(lstart + kernel_size_l, input_length + padding_l);\n    const IDX pool_size = (lend - lstart);\n\n    lstart = XPU_INT_MAX<IDX>(IDX(0), lstart);\n    lend = XPU_INT_MIN<IDX>(lend, input_length);\n\n    IDX divide_factor;\n    if (divisor_override != static_cast<int32_t>(0)) {\n      divide_factor = divisor_override;\n    } else {\n      if (count_include_pad) {\n        divide_factor = pool_size;\n      } else {\n        divide_factor = (lend - lstart);\n      }\n    }\n    T grad_delta = src[num] / divide_factor;\n    T* data = dest + start_idx;\n    for (IDX idx = lstart; idx < lend; idx += 1) {\n      XPUAdd<T>::Invoke(&grad_delta, &data[idx]);  // dest[search_idx] += grad_delta\n    }\n  }\n}\n\ntemplate<typename T, typename IDX>\nOF_DEVICE_FUNC void Avgpool2dForwardCompute(\n    const NdIndexOffsetHelper<IDX, 3> index_helper, int64_t elem_num, const T* src, T* dest,\n    const int32_t padding_h, const int32_t padding_w, const int32_t n_batch,\n    const int32_t n_channel, const int32_t x_height, const int32_t x_width,\n    const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_h,\n    const int32_t stride_w, const bool count_include_pad, int32_t divisor_override) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, h, w;\n    index_helper.OffsetToNdIndex(num, n_c, h, w);\n\n    const IDX start_idx = n_c * x_width * x_height;\n    IDX hstart = h * stride_h - padding_h;\n    IDX wstart = w * stride_w - padding_w;\n\n    IDX hend = XPU_INT_MIN<IDX>(hstart + kernel_size_h, x_height + padding_h);\n    IDX wend = XPU_INT_MIN<IDX>(wstart + kernel_size_w, x_width + padding_w);\n    const IDX pool_size = (hend - hstart) * (wend - wstart);\n\n    hstart = XPU_INT_MAX<IDX>(0, hstart);\n    wstart = XPU_INT_MAX<IDX>(0, wstart);\n    hend = XPU_INT_MIN<IDX>(hend, x_height);\n    wend = XPU_INT_MIN<IDX>(wend, x_width);\n\n    IDX divide_factor;\n    if (divisor_override != static_cast<int32_t>(0)) {\n      divide_factor = divisor_override;\n    } else {\n      if (count_include_pad) {\n        divide_factor = pool_size;\n      } else {\n        divide_factor = (hend - hstart) * (wend - wstart);\n      }\n    }\n    T sum = 0;\n\n    const T* data = src + start_idx;\n    for (int64_t i = hstart; i < hend; i += 1) {\n      for (int64_t j = wstart; j < wend; j += 1) {\n        const IDX window_idx = i * x_width + j;\n        sum += data[window_idx];\n      }\n    }\n    dest[num] = sum / divide_factor;\n  }\n}\n\ntemplate<typename T, typename IDX>\nOF_DEVICE_FUNC void Avgpool2dBackwardCompute(\n    const NdIndexOffsetHelper<IDX, 3> index_helper, IDX elem_num, const T* src, T* dest,\n    const int32_t padding_h, const int32_t padding_w, const int32_t n_batch,\n    const int32_t n_channel, const int32_t input_height, const int32_t input_width,\n    const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_h,\n    const int32_t stride_w, const bool count_include_pad, int32_t divisor_override) {\n  XPU_1D_KERNEL_LOOP(num, 
elem_num) {\n    IDX n_c, h, w;\n    index_helper.OffsetToNdIndex(num, n_c, h, w);\n\n    const IDX start_idx = n_c * input_width * input_height;\n    IDX hstart = h * stride_h - padding_h;\n    IDX wstart = w * stride_w - padding_w;\n    IDX hend = XPU_INT_MIN<IDX>(hstart + kernel_size_h, input_height + padding_h);\n    IDX wend = XPU_INT_MIN<IDX>(wstart + kernel_size_w, input_width + padding_w);\n    const IDX pool_size = (hend - hstart) * (wend - wstart);\n\n    hstart = XPU_INT_MAX<IDX>(IDX(0), hstart);\n    wstart = XPU_INT_MAX<IDX>(IDX(0), wstart);\n    hend = XPU_INT_MIN<IDX>(hend, input_height);\n    wend = XPU_INT_MIN<IDX>(wend, input_width);\n\n    IDX divide_factor;\n    if (divisor_override != static_cast<int32_t>(0)) {\n      divide_factor = divisor_override;\n    } else {\n      if (count_include_pad) {\n        divide_factor = pool_size;\n      } else {\n        divide_factor = (hend - hstart) * (wend - wstart);\n      }\n    }\n    T grad_delta = src[num] / divide_factor;\n    T* data = dest + start_idx;\n    for (IDX i = hstart; i < hend; i += 1) {\n      for (IDX j = wstart; j < wend; j += 1) {\n        const IDX window_idx = i * input_width + j;\n        XPUAdd<T>::Invoke(&grad_delta, &data[window_idx]);  // dest[search_idx] += grad_delta\n      }\n    }\n  }\n}\n\ntemplate<typename T, typename IDX>\nOF_DEVICE_FUNC void Avgpool3dForwardCompute(\n    const NdIndexOffsetHelper<IDX, 4> index_helper, IDX elem_num, const T* src, T* dest,\n    const int32_t padding_t, const int32_t padding_h, const int32_t padding_w,\n    const int32_t n_batch, const int32_t n_channel, const int32_t x_time, const int32_t x_height,\n    const int32_t x_width, const int32_t kernel_size_t, const int32_t kernel_size_h,\n    const int32_t kernel_size_w, const int32_t stride_t, const int32_t stride_h,\n    const int32_t stride_w, const bool count_include_pad, int32_t divisor_override) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, t, h, w;\n    index_helper.OffsetToNdIndex(num, n_c, t, h, w);\n\n    const IDX start_idx = n_c * x_time * x_height * x_width;\n    IDX tstart = t * stride_t - padding_t;\n    IDX hstart = h * stride_h - padding_h;\n    IDX wstart = w * stride_w - padding_w;\n    IDX tend = XPU_INT_MIN<IDX>(tstart + kernel_size_t, x_time + padding_t);\n    IDX hend = XPU_INT_MIN<IDX>(hstart + kernel_size_h, x_height + padding_h);\n    IDX wend = XPU_INT_MIN<IDX>(wstart + kernel_size_w, x_width + padding_w);\n    const IDX pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);\n\n    tstart = XPU_INT_MAX<IDX>(IDX(0), tstart);\n    hstart = XPU_INT_MAX<IDX>(IDX(0), hstart);\n    wstart = XPU_INT_MAX<IDX>(IDX(0), wstart);\n    tend = XPU_INT_MIN<IDX>(tend, x_time);\n    hend = XPU_INT_MIN<IDX>(hend, x_height);\n    wend = XPU_INT_MIN<IDX>(wend, x_width);\n\n    IDX divide_factor;\n    if (divisor_override != static_cast<int32_t>(0)) {\n      divide_factor = divisor_override;\n    } else {\n      if (count_include_pad) {\n        divide_factor = pool_size;\n      } else {\n        divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);\n      }\n    }\n    T sum = 0;\n    const T* data = src + start_idx;\n    for (IDX i = tstart; i < tend; i += 1) {\n      for (IDX j = hstart; j < hend; j += 1) {\n        for (IDX k = wstart; k < wend; k += 1) {\n          const IDX window_idx = i * x_height * x_width + j * x_width + k;\n          sum += data[window_idx];\n        }\n      }\n    }\n    dest[num] = sum / divide_factor;\n  }\n}\n\ntemplate<typename T, typename 
IDX>\nOF_DEVICE_FUNC void Avgpool3dBackwardCompute(\n    const NdIndexOffsetHelper<IDX, 4> index_helper, IDX elem_num, const T* src, T* dest,\n    const int32_t padding_t, const int32_t padding_h, const int32_t padding_w,\n    const int32_t n_batch, const int32_t n_channel, const int32_t x_time, const int32_t x_height,\n    const int32_t x_width, const int32_t kernel_size_t, const int32_t kernel_size_h,\n    const int32_t kernel_size_w, const int32_t stride_t, const int32_t stride_h,\n    const int32_t stride_w, const bool count_include_pad, const int32_t divisor_override) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, t, h, w;\n    index_helper.OffsetToNdIndex(num, n_c, t, h, w);\n\n    const IDX start_idx = n_c * x_time * x_width * x_height;\n    IDX tstart = t * stride_t - padding_t;\n    IDX hstart = h * stride_h - padding_h;\n    IDX wstart = w * stride_w - padding_w;\n    IDX tend = XPU_INT_MIN<IDX>(tstart + kernel_size_t, x_time + padding_t);\n    IDX hend = XPU_INT_MIN<IDX>(hstart + kernel_size_h, x_height + padding_h);\n    IDX wend = XPU_INT_MIN<IDX>(wstart + kernel_size_w, x_width + padding_w);\n    const IDX pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);\n\n    tstart = XPU_INT_MAX<IDX>(IDX(0), tstart);\n    hstart = XPU_INT_MAX<IDX>(IDX(0), hstart);\n    wstart = XPU_INT_MAX<IDX>(IDX(0), wstart);\n    tend = XPU_INT_MIN<IDX>(tend, x_time);\n    hend = XPU_INT_MIN<IDX>(hend, x_height);\n    wend = XPU_INT_MIN<IDX>(wend, x_width);\n\n    IDX divide_factor;\n    if (divisor_override != static_cast<int32_t>(0)) {\n      divide_factor = divisor_override;\n    } else {\n      if (count_include_pad) {\n        divide_factor = pool_size;\n      } else {\n        divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);\n      }\n    }\n    T grad_delta = src[num] / divide_factor;\n    T* data = dest + start_idx;\n    for (IDX i = tstart; i < tend; i += 1) {\n      for (IDX j = hstart; j < hend; j += 1) {\n        for (IDX k = wstart; k < wend; k += 1) {\n          const IDX window_idx = i * x_height * x_width + j * x_width + k;\n          XPUAdd<T>::Invoke(&grad_delta, &data[window_idx]);  // dest[search_idx] += grad_delta\n        }\n      }\n    }\n  }\n}\n\n#ifdef WITH_CUDA\ntemplate<DeviceType device_type, typename IDX>\nstruct AvgPoolKernelUtil<device_type, half, IDX> {\n  static void Avgpool1dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                               const IDX elem_num, const half* src, half* dest,\n                               const AvgPoolParams3D& params_3d);\n\n  static void Avgpool1dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                                const IDX elem_num, const half* src, half* dest,\n                                const AvgPoolParams3D& params_3d);\n\n  static void Avgpool2dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 3>& index_helper,\n                               const IDX elem_num, const half* src, half* dest,\n                               const AvgPoolParams3D& params_3d);\n\n  static void Avgpool2dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 3>& index_helper,\n                                const IDX elem_num, const half* src, half* dest,\n                                const AvgPoolParams3D& params_3d);\n\n  static void Avgpool3dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                               const IDX elem_num, const half* src, half* dest,\n    
                           const AvgPoolParams3D& params_3d);\n\n  static void Avgpool3dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                                const IDX elem_num, const half* src, half* dest,\n                                const AvgPoolParams3D& params_3d);\n};\n\ntemplate<typename IDX>\nOF_DEVICE_FUNC void HalfAvgpool1dForwardCompute(const NdIndexOffsetHelper<IDX, 2> index_helper,\n                                                IDX elem_num, const half* src, half* dest,\n                                                const int32_t padding_l, const int32_t n_batch,\n                                                const int32_t n_channel, const int32_t x_length,\n                                                const int32_t kernel_size_l, const int32_t stride_l,\n                                                const bool count_include_pad,\n                                                const int32_t divisor_override) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, l;\n    index_helper.OffsetToNdIndex(num, n_c, l);\n\n    const IDX start_idx = n_c * x_length;\n    IDX lstart = l * stride_l - padding_l;\n    IDX lend = XPU_INT_MIN<IDX>(lstart + kernel_size_l, x_length + padding_l);\n    const IDX pool_size = (lend - lstart);\n\n    lstart = XPU_INT_MAX<IDX>(0, lstart);\n    lend = XPU_INT_MIN<IDX>(lend, x_length);\n\n    IDX divide_factor;\n    if (divisor_override != static_cast<int32_t>(0)) {\n      divide_factor = divisor_override;\n    } else {\n      if (count_include_pad) {\n        divide_factor = pool_size;\n      } else {\n        divide_factor = (lend - lstart);\n      }\n    }\n    float sum = 0;\n\n    const half* data = src + start_idx;\n    for (IDX idx = lstart; idx < lend; idx += 1) { sum += __half2float(data[idx]); }\n    dest[num] = __float2half(sum / divide_factor);\n  }\n}\n\ntemplate<typename IDX>\nOF_DEVICE_FUNC void HalfAvgpool1dBackwardCompute(\n    const NdIndexOffsetHelper<IDX, 2> index_helper, IDX elem_num, const half* src, half* dest,\n    const int32_t padding_l, const int32_t n_batch, const int32_t n_channel,\n    const int32_t input_length, const int32_t kernel_size_l, const int32_t stride_l,\n    const bool count_include_pad, const int32_t divisor_override) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, l;\n    index_helper.OffsetToNdIndex(num, n_c, l);\n\n    const IDX start_idx = n_c * input_length;\n    IDX lstart = l * stride_l - padding_l;\n    IDX lend = XPU_INT_MIN<IDX>(lstart + kernel_size_l, input_length + padding_l);\n    const IDX pool_size = (lend - lstart);\n\n    lstart = XPU_INT_MAX<IDX>(IDX(0), lstart);\n    lend = XPU_INT_MIN<IDX>(lend, input_length);\n\n    IDX divide_factor;\n    if (divisor_override != static_cast<int32_t>(0)) {\n      divide_factor = divisor_override;\n    } else {\n      if (count_include_pad) {\n        divide_factor = pool_size;\n      } else {\n        divide_factor = (lend - lstart);\n      }\n    }\n    half grad_delta = static_cast<half>(__half2float(src[num]) / divide_factor);\n    half* data = dest + start_idx;\n    for (IDX idx = lstart; idx < lend; idx += 1) { XPUAdd<half>::Invoke(&grad_delta, &data[idx]); }\n  }\n}\n\ntemplate<typename IDX>\nOF_DEVICE_FUNC void HalfAvgpool2dForwardCompute(\n    const NdIndexOffsetHelper<IDX, 3> index_helper, int64_t elem_num, const half* src, half* dest,\n    const int32_t padding_h, const int32_t padding_w, const int32_t n_batch,\n    const int32_t n_channel, const int32_t x_height, const int32_t 
x_width,\n    const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_h,\n    const int32_t stride_w, const bool count_include_pad, int32_t divisor_override) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, h, w;\n    index_helper.OffsetToNdIndex(num, n_c, h, w);\n\n    const IDX start_idx = n_c * x_width * x_height;\n    IDX hstart = h * stride_h - padding_h;\n    IDX wstart = w * stride_w - padding_w;\n\n    IDX hend = XPU_INT_MIN<IDX>(hstart + kernel_size_h, x_height + padding_h);\n    IDX wend = XPU_INT_MIN<IDX>(wstart + kernel_size_w, x_width + padding_w);\n    const IDX pool_size = (hend - hstart) * (wend - wstart);\n\n    hstart = XPU_INT_MAX<IDX>(0, hstart);\n    wstart = XPU_INT_MAX<IDX>(0, wstart);\n    hend = XPU_INT_MIN<IDX>(hend, x_height);\n    wend = XPU_INT_MIN<IDX>(wend, x_width);\n\n    IDX divide_factor;\n    if (divisor_override != static_cast<int32_t>(0)) {\n      divide_factor = divisor_override;\n    } else {\n      if (count_include_pad) {\n        divide_factor = pool_size;\n      } else {\n        divide_factor = (hend - hstart) * (wend - wstart);\n      }\n    }\n    float sum = 0;\n    const half* data = src + start_idx;\n    for (int64_t i = hstart; i < hend; i += 1) {\n      for (int64_t j = wstart; j < wend; j += 1) {\n        const IDX window_idx = i * x_width + j;\n        sum += __half2float(data[window_idx]);\n      }\n    }\n    dest[num] = __float2half(sum / divide_factor);\n  }\n}\n\ntemplate<typename IDX>\nOF_DEVICE_FUNC void HalfAvgpool2dBackwardCompute(\n    const NdIndexOffsetHelper<IDX, 3> index_helper, IDX elem_num, const half* src, half* dest,\n    const int32_t padding_h, const int32_t padding_w, const int32_t n_batch,\n    const int32_t n_channel, const int32_t input_height, const int32_t input_width,\n    const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_h,\n    const int32_t stride_w, const bool count_include_pad, int32_t divisor_override) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, h, w;\n    index_helper.OffsetToNdIndex(num, n_c, h, w);\n\n    const IDX start_idx = n_c * input_width * input_height;\n    IDX hstart = h * stride_h - padding_h;\n    IDX wstart = w * stride_w - padding_w;\n    IDX hend = XPU_INT_MIN<IDX>(hstart + kernel_size_h, input_height + padding_h);\n    IDX wend = XPU_INT_MIN<IDX>(wstart + kernel_size_w, input_width + padding_w);\n    const IDX pool_size = (hend - hstart) * (wend - wstart);\n\n    hstart = XPU_INT_MAX<IDX>(IDX(0), hstart);\n    wstart = XPU_INT_MAX<IDX>(IDX(0), wstart);\n    hend = XPU_INT_MIN<IDX>(hend, input_height);\n    wend = XPU_INT_MIN<IDX>(wend, input_width);\n\n    IDX divide_factor;\n    if (divisor_override != static_cast<int32_t>(0)) {\n      divide_factor = divisor_override;\n    } else {\n      if (count_include_pad) {\n        divide_factor = pool_size;\n      } else {\n        divide_factor = (hend - hstart) * (wend - wstart);\n      }\n    }\n    half grad_delta = static_cast<half>(__half2float(src[num]) / divide_factor);\n    half* data = dest + start_idx;\n    for (IDX i = hstart; i < hend; i += 1) {\n      for (IDX j = wstart; j < wend; j += 1) {\n        const IDX window_idx = i * input_width + j;\n        XPUAdd<half>::Invoke(&grad_delta, &data[window_idx]);\n      }\n    }\n  }\n}\n\ntemplate<typename IDX>\nOF_DEVICE_FUNC void HalfAvgpool3dForwardCompute(\n    const NdIndexOffsetHelper<IDX, 4> index_helper, IDX elem_num, const half* src, half* dest,\n    const int32_t padding_t, const int32_t padding_h, const 
int32_t padding_w,\n    const int32_t n_batch, const int32_t n_channel, const int32_t x_time, const int32_t x_height,\n    const int32_t x_width, const int32_t kernel_size_t, const int32_t kernel_size_h,\n    const int32_t kernel_size_w, const int32_t stride_t, const int32_t stride_h,\n    const int32_t stride_w, const bool count_include_pad, int32_t divisor_override) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, t, h, w;\n    index_helper.OffsetToNdIndex(num, n_c, t, h, w);\n\n    const IDX start_idx = n_c * x_time * x_height * x_width;\n    IDX tstart = t * stride_t - padding_t;\n    IDX hstart = h * stride_h - padding_h;\n    IDX wstart = w * stride_w - padding_w;\n    IDX tend = XPU_INT_MIN<IDX>(tstart + kernel_size_t, x_time + padding_t);\n    IDX hend = XPU_INT_MIN<IDX>(hstart + kernel_size_h, x_height + padding_h);\n    IDX wend = XPU_INT_MIN<IDX>(wstart + kernel_size_w, x_width + padding_w);\n    const IDX pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);\n\n    tstart = XPU_INT_MAX<IDX>(IDX(0), tstart);\n    hstart = XPU_INT_MAX<IDX>(IDX(0), hstart);\n    wstart = XPU_INT_MAX<IDX>(IDX(0), wstart);\n    tend = XPU_INT_MIN<IDX>(tend, x_time);\n    hend = XPU_INT_MIN<IDX>(hend, x_height);\n    wend = XPU_INT_MIN<IDX>(wend, x_width);\n\n    IDX divide_factor;\n    if (divisor_override != static_cast<int32_t>(0)) {\n      divide_factor = divisor_override;\n    } else {\n      if (count_include_pad) {\n        divide_factor = pool_size;\n      } else {\n        divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);\n      }\n    }\n    float sum = 0;\n    const half* data = src + start_idx;\n    for (IDX i = tstart; i < tend; i += 1) {\n      for (IDX j = hstart; j < hend; j += 1) {\n        for (IDX k = wstart; k < wend; k += 1) {\n          const IDX window_idx = i * x_height * x_width + j * x_width + k;\n          sum += __half2float(data[window_idx]);\n        }\n      }\n    }\n    dest[num] = __float2half(sum / divide_factor);\n  }\n}\n\ntemplate<typename IDX>\nOF_DEVICE_FUNC void HalfAvgpool3dBackwardCompute(\n    const NdIndexOffsetHelper<IDX, 4> index_helper, IDX elem_num, const half* src, half* dest,\n    const int32_t padding_t, const int32_t padding_h, const int32_t padding_w,\n    const int32_t n_batch, const int32_t n_channel, const int32_t x_time, const int32_t x_height,\n    const int32_t x_width, const int32_t kernel_size_t, const int32_t kernel_size_h,\n    const int32_t kernel_size_w, const int32_t stride_t, const int32_t stride_h,\n    const int32_t stride_w, const bool count_include_pad, const int32_t divisor_override) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, t, h, w;\n    index_helper.OffsetToNdIndex(num, n_c, t, h, w);\n\n    const IDX start_idx = n_c * x_time * x_width * x_height;\n    IDX tstart = t * stride_t - padding_t;\n    IDX hstart = h * stride_h - padding_h;\n    IDX wstart = w * stride_w - padding_w;\n    IDX tend = XPU_INT_MIN<IDX>(tstart + kernel_size_t, x_time + padding_t);\n    IDX hend = XPU_INT_MIN<IDX>(hstart + kernel_size_h, x_height + padding_h);\n    IDX wend = XPU_INT_MIN<IDX>(wstart + kernel_size_w, x_width + padding_w);\n    const IDX pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);\n\n    tstart = XPU_INT_MAX<IDX>(IDX(0), tstart);\n    hstart = XPU_INT_MAX<IDX>(IDX(0), hstart);\n    wstart = XPU_INT_MAX<IDX>(IDX(0), wstart);\n    tend = XPU_INT_MIN<IDX>(tend, x_time);\n    hend = XPU_INT_MIN<IDX>(hend, x_height);\n    wend = XPU_INT_MIN<IDX>(wend, x_width);\n\n    IDX 
divide_factor;\n    if (divisor_override != static_cast<int32_t>(0)) {\n      divide_factor = divisor_override;\n    } else {\n      if (count_include_pad) {\n        divide_factor = pool_size;\n      } else {\n        divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);\n      }\n    }\n    half grad_delta = static_cast<half>(__half2float(src[num]) / divide_factor);\n    half* data = dest + start_idx;\n    for (IDX i = tstart; i < tend; i += 1) {\n      for (IDX j = hstart; j < hend; j += 1) {\n        for (IDX k = wstart; k < wend; k += 1) {\n          const IDX window_idx = i * x_height * x_width + j * x_width + k;\n          XPUAdd<half>::Invoke(&grad_delta, &data[window_idx]);\n        }\n      }\n    }\n  }\n}\n\n#endif  // WITH_CUDA\n\n#define INSTANTIATE_AVG_POOL_KERNEL_UTIL(device_type_v, dtype_pair, index_dtype_pair) \\\n  template struct AvgPoolKernelUtil<device_type_v, OF_PP_PAIR_FIRST(dtype_pair),      \\\n                                    OF_PP_PAIR_FIRST(index_dtype_pair)>;\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_AVG_POOL_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/batch_gather_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/batch_gather_kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass BatchGatherKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  BatchGatherKernel() = default;\n  ~BatchGatherKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t axis = indices->shape_view().NumAxes() - 1;\n    const Shape flat_out_shape =\n        Shape({out->shape_view().Count(0, axis), out->shape_view().At(axis),\n               out->shape_view().Count(axis + 1)});\n    BatchGatherKernelUtilImpl<device_type, T, K>::Forward(\n        ctx->stream(), in->dptr<T>(), indices->dptr<K>(), flat_out_shape, in->shape_view().At(axis),\n        out->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_BATCH_GATHER_KERNEL(device, out_dtype, indices_dtype)        \\\n  REGISTER_USER_KERNEL(\"batch_gather\")                                        \\\n      .SetCreateFn<BatchGatherKernel<device, OF_PP_PAIR_FIRST(out_dtype),     \\\n                                     OF_PP_PAIR_FIRST(indices_dtype)>>()      \\\n      .SetIsMatchedHob(                                                       \\\n          (user_op::HobDeviceType() == device)                                \\\n          && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(out_dtype)) \\\n          && (user_op::HobDataType(\"indices\", 0) == OF_PP_PAIR_SECOND(indices_dtype)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_BATCH_GATHER_KERNEL, DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/batch_gather_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/batch_gather_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nShape GetFlatShape(const ShapeView& shape, const int64_t axis) {\n  CHECK_GT(shape.NumAxes(), 0);\n  CHECK_GE(axis, 0);\n  CHECK_LT(axis, shape.NumAxes());\n  return Shape({shape.Count(0, axis), shape.At(axis), shape.Count(axis + 1)});\n}\n\ntemplate<DeviceType device_type, typename T, typename K>\nvoid BatchGatherForward(ep::Stream* stream, const Blob* in, const Blob* indices, Blob* out) {\n  const int64_t axis = indices->shape_view().NumAxes() - 1;\n  const Shape flat_out_shape = GetFlatShape(out->shape_view(), axis);\n  BatchGatherKernelUtilImpl<device_type, T, K>::Forward(stream, in->dptr<T>(), indices->dptr<K>(),\n                                                        flat_out_shape, in->shape_view().At(axis),\n                                                        out->mut_dptr<T>());\n}\n\ntemplate<DeviceType device_type, typename T, typename K>\nvoid BatchGatherBackward(ep::Stream* stream, const Blob* out_diff, const Blob* indices,\n                         Blob* in_diff) {\n  Memset<device_type>(stream, in_diff->mut_dptr<T>(), 0, in_diff->ByteSizeOfBlobBody());\n  const int64_t axis = indices->shape_view().NumAxes() - 1;\n  const Shape flat_out_diff_shape = GetFlatShape(out_diff->shape_view(), axis);\n  BatchGatherKernelUtilImpl<device_type, T, K>::Backward(\n      stream, out_diff->dptr<T>(), indices->dptr<K>(), flat_out_diff_shape,\n      in_diff->shape_view().At(axis), in_diff->mut_dptr<T>());\n}\n\ntemplate<DeviceType device_type, typename T>\nstruct BatchGatherSwitchUtil final {\n#define MAKE_BATCH_GATHER_SWITCH_ENTRY(func_name, K) func_name<device_type, T, K>\n#define DEFINE_BATCH_GATHER_STATIC_SWITCH_FUNC(func_name)                    \\\n  DEFINE_STATIC_SWITCH_FUNC(void, func_name, MAKE_BATCH_GATHER_SWITCH_ENTRY, \\\n                            MAKE_DATA_TYPE_CTRV_SEQ(INT_DATA_TYPE_SEQ));\n  DEFINE_BATCH_GATHER_STATIC_SWITCH_FUNC(BatchGatherForward);\n  DEFINE_BATCH_GATHER_STATIC_SWITCH_FUNC(BatchGatherBackward);\n#undef DEFINE_BATCH_GATHER_STATIC_SWITCH_FUNC\n#undef MAKE_BATCH_GATHER_SWITCH_ENTRY\n};\n\n}  // namespace\n\ntemplate<DeviceType device_type, typename T>\nvoid BatchGatherKernelUtil<device_type, T>::Forward(ep::Stream* stream, const Blob* in,\n                                                    const Blob* indices, Blob* out) {\n  BatchGatherSwitchUtil<device_type, T>::SwitchBatchGatherForward(SwitchCase(indices->data_type()),\n                                                                  stream, in, indices, out);\n}\n\ntemplate<DeviceType device_type, typename T>\nvoid BatchGatherKernelUtil<device_type, T>::Backward(ep::Stream* stream, const Blob* out_diff,\n                                                     const Blob* indices, Blob* in_diff) {\n  BatchGatherSwitchUtil<device_type, T>::SwitchBatchGatherBackward(\n      
SwitchCase(indices->data_type()), stream, out_diff, indices, in_diff);\n}\n\ntemplate<typename T, typename K>\nstruct BatchGatherKernelUtilImpl<DeviceType::kCPU, T, K> final {\n  static void Forward(ep::Stream* stream, const T* in, const K* indices,\n                      const Shape& flat_out_shape, int64_t gather_dim_size, T* out);\n  static void Backward(ep::Stream* stream, const T* out_diff, const K* indices,\n                       const Shape& flat_out_diff_shape, int64_t gather_dim_size, T* in_diff);\n};\n\ntemplate<typename T, typename K>\nvoid BatchGatherKernelUtilImpl<DeviceType::kCPU, T, K>::Forward(ep::Stream* stream, const T* in,\n                                                                const K* indices,\n                                                                const Shape& flat_out_shape,\n                                                                const int64_t gather_dim_size,\n                                                                T* out) {\n  const int64_t batch_num = flat_out_shape.At(0);\n  const int64_t indices_num = flat_out_shape.At(1);\n  const int64_t instance_size = flat_out_shape.At(2);\n  FOR_RANGE(int64_t, batch_idx, 0, batch_num) {\n    FOR_RANGE(int64_t, i, 0, indices_num) {\n      const K idx = indices[batch_idx * indices_num + i];\n      CHECK(idx >= 0 && idx < gather_dim_size);\n      const T* from = in + batch_idx * gather_dim_size * instance_size + idx * instance_size;\n      T* to = out + batch_idx * indices_num * instance_size + i * instance_size;\n      std::copy(from, from + instance_size, to);\n    }\n  }\n}\n\ntemplate<typename T, typename K>\nvoid BatchGatherKernelUtilImpl<DeviceType::kCPU, T, K>::Backward(\n    ep::Stream* stream, const T* out_diff, const K* indices, const Shape& flat_out_diff_shape,\n    const int64_t gather_dim_size, T* in_diff) {\n  const int64_t batch_num = flat_out_diff_shape.At(0);\n  const int64_t indices_num = flat_out_diff_shape.At(1);\n  const int64_t instance_size = flat_out_diff_shape.At(2);\n  FOR_RANGE(int64_t, batch_idx, 0, batch_num) {\n    FOR_RANGE(int64_t, i, 0, indices_num) {\n      const int64_t idx = indices[batch_idx * indices_num + i];\n      CHECK(idx >= 0 && idx < gather_dim_size);\n      const T* from = out_diff + batch_idx * indices_num * instance_size + i * instance_size;\n      T* to = in_diff + batch_idx * gather_dim_size * instance_size + idx * instance_size;\n      std::transform(from, from + instance_size, to, to, std::plus<T>());\n    }\n  }\n}\n\n#define INSTANTIATE_BATCH_GATHER_KERNEL_UTIL_IMPL_CPU(in_type_pair, index_type_pair)          \\\n  template struct BatchGatherKernelUtilImpl<DeviceType::kCPU, OF_PP_PAIR_FIRST(in_type_pair), \\\n                                            OF_PP_PAIR_FIRST(index_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BATCH_GATHER_KERNEL_UTIL_IMPL_CPU,\n                                 FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ);\n#undef INSTANTIATE_BATCH_GATHER_KERNEL_UTIL_IMPL_CPU\n\n#define INSTANTIATE_BATCH_GATHER_KERNEL_UTIL(device_type, in_type_pair) \\\n  template struct BatchGatherKernelUtil<device_type, OF_PP_PAIR_FIRST(in_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BATCH_GATHER_KERNEL_UTIL, DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ);\n#undef INSTANTIATE_BATCH_GATHER_KERNEL_UTIL\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/batch_gather_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/batch_gather_kernel_util.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <assert.h>\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename K>\n__device__ int64_t GetInOffset(const int64_t out_offset, const K* indices,\n                               const int64_t indices_num, const int64_t instance_size,\n                               const int64_t gather_dim_size) {\n  const int64_t batch_idx = out_offset / (indices_num * instance_size);\n  const int64_t indices_idx = out_offset % (indices_num * instance_size) / instance_size;\n  const int64_t inner_idx = out_offset % instance_size;\n  const int64_t idx = indices[batch_idx * indices_num + indices_idx];\n  assert(idx >= 0 && idx < gather_dim_size);\n  return batch_idx * gather_dim_size * instance_size + idx * instance_size + inner_idx;\n}\n\ntemplate<typename T, typename K>\n__global__ void BatchGatherForwardGpu(const int64_t elem_cnt, const T* in, const K* indices,\n                                      const int64_t indices_num, const int64_t instance_size,\n                                      const int64_t gather_dim_size, T* out) {\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {\n    out[i] = in[GetInOffset<K>(i, indices, indices_num, instance_size, gather_dim_size)];\n  }\n}\n\ntemplate<typename T, typename K>\n__global__ void BatchGatherBackwardGpu(const int64_t elem_cnt, const T* out_diff, const K* indices,\n                                       const int64_t indices_num, const int64_t instance_size,\n                                       const int64_t gather_dim_size, T* in_diff) {\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {\n    cuda::atomic::Add(\n        in_diff + GetInOffset<K>(i, indices, indices_num, instance_size, gather_dim_size),\n        out_diff[i]);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename K>\nstruct BatchGatherKernelUtilImpl<DeviceType::kCUDA, T, K> final {\n  static void Forward(ep::Stream* stream, const T* in, const K* indices,\n                      const Shape& flat_out_shape, const int64_t gather_dim_size, T* out);\n  static void Backward(ep::Stream* stream, const T* out_diff, const K* indices,\n                       const Shape& flat_out_diff_shape, const int64_t gather_dim_size, T* in_diff);\n};\n\ntemplate<typename T, typename K>\nvoid BatchGatherKernelUtilImpl<DeviceType::kCUDA, T, K>::Forward(ep::Stream* stream, const T* in,\n                                                                 const K* indices,\n                                                                 const Shape& flat_out_shape,\n                                                                 const int64_t gather_dim_size,\n                                                                 T* out) {\n  const int64_t batch_num = flat_out_shape.At(0);\n  const int64_t indices_num = flat_out_shape.At(1);\n  const int64_t 
instance_size = flat_out_shape.At(2);\n  const int64_t elem_cnt = batch_num * indices_num * instance_size;\n  BatchGatherForwardGpu<T, K><<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                                stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      elem_cnt, in, indices, indices_num, instance_size, gather_dim_size, out);\n}\n\ntemplate<typename T, typename K>\nvoid BatchGatherKernelUtilImpl<DeviceType::kCUDA, T, K>::Backward(\n    ep::Stream* stream, const T* out_diff, const K* indices, const Shape& flat_out_diff_shape,\n    const int64_t gather_dim_size, T* in_diff) {\n  const int64_t batch_num = flat_out_diff_shape.At(0);\n  const int64_t indices_num = flat_out_diff_shape.At(1);\n  const int64_t instance_size = flat_out_diff_shape.At(2);\n  const int64_t elem_cnt = batch_num * indices_num * instance_size;\n  BatchGatherBackwardGpu<T, K><<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                                 stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      elem_cnt, out_diff, indices, indices_num, instance_size, gather_dim_size, in_diff);\n}\n\n#define INSTANTIATE_BATCH_GATHER_KERNEL_UTIL_IMPL_CUDA(in_type_pair, index_type_pair)          \\\n  template struct BatchGatherKernelUtilImpl<DeviceType::kCUDA, OF_PP_PAIR_FIRST(in_type_pair), \\\n                                            OF_PP_PAIR_FIRST(index_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BATCH_GATHER_KERNEL_UTIL_IMPL_CUDA,\n                                 FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ);\n#undef INSTANTIATE_BATCH_GATHER_KERNEL_UTIL_IMPL_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/batch_gather_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_BATCH_GATHER_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_BATCH_GATHER_KERNEL_UTIL_H_\n\n#include \"oneflow/core/kernel/kernel.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, typename K>\nstruct BatchGatherKernelUtilImpl final {\n  static void Forward(ep::Stream* stream, const T* in, const K* indices,\n                      const Shape& flat_out_shape, int64_t gather_dim_size, T* out);\n  static void Backward(ep::Stream* stream, const T* out_diff, const K* indices,\n                       const Shape& flat_out_diff_shape, int64_t gather_dim_size, T* in_diff);\n};\n\ntemplate<DeviceType device_type, typename T>\nstruct BatchGatherKernelUtil final {\n  static void Forward(ep::Stream* stream, const Blob* in, const Blob* indices, Blob* out);\n  static void Backward(ep::Stream* stream, const Blob* out_diff, const Blob* indices,\n                       Blob* in_diff);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_BATCH_GATHER_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/batch_norm_backward_elemt_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <limits>\n#include <algorithm>\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/ep/cuda/cuda_device.h\"\n#include \"oneflow/user/kernels/batch_norm_kernel_utils.h\"\n\n// NOTE(Liang Depeng):\n// The implementation of batch_norm_backward_elemt kernel is modified from\n// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Normalization.cuh\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename ACC_T, typename IDX_TYPE>\n__global__ void batch_norm_backward_elemt_kernel(\n    const IDX_TYPE batch_size, const IDX_TYPE channel_size, const IDX_TYPE spatial_size,\n    const T* grad_out_ptr, const T* input_ptr, const T* mean_ptr, const T* invstd_ptr,\n    const T* weight_ptr, const T* sum_dy_ptr, const T* sum_dy_xmu_ptr, T* grad_in_ptr,\n    const int32_t* count_ptr, const int64_t world_size) {\n  int64_t total_numel = 0;\n  for (int i = 0; i < world_size; i++) { total_numel += count_ptr[i]; }\n\n  const ACC_T norm_fct = static_cast<ACC_T>(1) / static_cast<ACC_T>(total_numel);\n\n  IDX_TYPE channel = blockIdx.x;\n\n  if (channel >= channel_size) { return; }\n\n  ACC_T m_c = mean_ptr[channel];\n  ACC_T m_dy_c = sum_dy_ptr[channel] * norm_fct;\n  ACC_T factor_1_c = invstd_ptr[channel];\n  ACC_T factor_2_c = static_cast<ACC_T>(weight_ptr[channel]);\n  factor_2_c *= factor_1_c;\n  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu_ptr[channel] * norm_fct;\n\n  IDX_TYPE batch_offset = spatial_size * channel_size;\n  IDX_TYPE channel_offset = channel * spatial_size;\n\n  IDX_TYPE bstep = blockDim.y * gridDim.y;\n  for (IDX_TYPE batch = threadIdx.y + blockIdx.y * blockDim.y; batch < batch_size; batch += bstep) {\n    IDX_TYPE offset = batch * batch_offset;\n    for (IDX_TYPE feature = threadIdx.x; feature < spatial_size; feature += blockDim.x) {\n      grad_in_ptr[offset + channel_offset + feature] =\n          static_cast<T>((grad_out_ptr[offset + channel_offset + feature] - m_dy_c\n                          - (input_ptr[offset + channel_offset + feature] - m_c) * factor_1_c)\n                         * factor_2_c);\n    }\n  }\n}\n\ntemplate<typename T, typename ACC_T, typename IDX_TYPE, int PARALLEL_LOADS>\n__global__ void batch_norm_backward_elemt_channels_last_kernel(\n    const T* grad_out_ptr, const T* input_ptr, const ACC_T* mean_ptr, const ACC_T* invstd_ptr,\n    const T* weight_ptr, const ACC_T* sum_dy_ptr, const ACC_T* sum_dy_xmu_ptr,\n    const int32_t* count_ptr, T* grad_in_ptr, const IDX_TYPE world_size, const IDX_TYPE stride,\n    const IDX_TYPE 
reduction_size) {\n  IDX_TYPE total_numel = 0;\n  for (IDX_TYPE i = 0; i < world_size; i++) { total_numel += count_ptr[i]; }\n\n  auto norm_fct = static_cast<ACC_T>(1) / static_cast<ACC_T>(total_numel);\n\n  // tensor dimension (m,c)\n  // loop along m dimension\n  IDX_TYPE inner_loop_stride = blockDim.y * gridDim.y;\n\n  // offset along m dimension\n  IDX_TYPE m_offset = blockIdx.y * blockDim.y + threadIdx.y;\n  IDX_TYPE c_offset = blockIdx.x * blockDim.x + threadIdx.x;\n\n  if (c_offset >= stride || m_offset >= reduction_size) { return; }\n\n  auto m_c = mean_ptr[c_offset];\n  auto m_dy_c = sum_dy_ptr[c_offset] * norm_fct;\n  auto factor_1_c = invstd_ptr[c_offset];\n  auto factor_2_c =\n      (weight_ptr == nullptr ? ACC_T(1.0) : static_cast<ACC_T>(weight_ptr[c_offset])) * factor_1_c;\n  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu_ptr[c_offset] * norm_fct;\n\n  // use IDX_TYPE (not int) so the int64 dispatch path cannot overflow the address arithmetic\n  IDX_TYPE loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);\n  IDX_TYPE address_base = m_offset * stride + c_offset;\n  IDX_TYPE address_increment = inner_loop_stride * stride;\n\n  for (IDX_TYPE i = 0; i < loop_count; i++) {\n#pragma unroll\n    for (int j = 0; j < PARALLEL_LOADS; j++) {\n      if (c_offset < stride && m_offset < reduction_size) {\n        grad_in_ptr[address_base] =\n            static_cast<T>((static_cast<ACC_T>(grad_out_ptr[address_base]) - m_dy_c\n                            - (static_cast<ACC_T>(input_ptr[address_base]) - m_c) * factor_1_c)\n                           * factor_2_c);\n      }\n      m_offset += inner_loop_stride;\n      address_base += address_increment;\n    }\n  }\n}\n\ntemplate<typename T>\nstruct BatchNormBackwardElemtFunctor final {\n  void operator()(ep::Stream* stream, const int64_t batch_size, const int64_t channel_size,\n                  const int64_t spatial_size, const T* grad_out_ptr, const T* input_ptr,\n                  const T* mean_ptr, const T* invstd_ptr, const T* weight_ptr, const T* sum_dy_ptr,\n                  const T* sum_dy_xmu_ptr, T* grad_in_ptr, const int32_t* count_ptr,\n                  const int64_t world_size) {\n    using ACC_T = acc_type<T>;\n\n    // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,\n    // weight/bias) - which we only do once and have a for loop afterwards - with having many\n    // threads and blocks and good occupancy. Quite likely, we could go with even more blocks than\n    // 1024. 
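As a rough worked example (illustrative numbers, not from the original comment): for\n    // spatial_size = 4096, tf = max(getNumThreads(1024), min(getNumThreads(4096), 64))\n    // = max(512, 64) = 512, and tb = max(64 / 512, 1) = 1, i.e. a 512x1 thread block.\n    // 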
The various planes are independent, so we use blocks for them.\n    int tf = std::max<int>(getNumThreads(spatial_size / 4),\n                           std::min<int>(getNumThreads(spatial_size), 64));\n    int tb = std::max<int>(64 / tf, 1);\n    dim3 blocks_trans(channel_size, std::max<int>(1, std::min<int>((256 * 1024) / channel_size,\n                                                                   (batch_size + tb - 1) / tb)));\n    blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);\n    dim3 threads_trans(tf, tb);\n\n    if (batch_size * channel_size * spatial_size < std::numeric_limits<int32_t>::max()) {\n      batch_norm_backward_elemt_kernel<T, ACC_T, int32_t>\n          <<<blocks_trans, threads_trans, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              static_cast<int32_t>(batch_size), static_cast<int32_t>(channel_size),\n              static_cast<int32_t>(spatial_size), grad_out_ptr, input_ptr, mean_ptr, invstd_ptr,\n              weight_ptr, sum_dy_ptr, sum_dy_xmu_ptr, grad_in_ptr, count_ptr, world_size);\n    } else {\n      batch_norm_backward_elemt_kernel<T, ACC_T, int64_t>\n          <<<blocks_trans, threads_trans, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              batch_size, channel_size, spatial_size, grad_out_ptr, input_ptr, mean_ptr, invstd_ptr,\n              weight_ptr, sum_dy_ptr, sum_dy_xmu_ptr, grad_in_ptr, count_ptr, world_size);\n    }\n  }\n};\n\ntemplate<typename T>\nstruct BatchNormBackwardElemtChannelLastFunctor final {\n  void operator()(ep::Stream* stream, const int64_t stride, const int64_t reduction_size,\n                  const T* grad_out_ptr, const T* input_ptr, const T* mean_ptr, const T* invstd_ptr,\n                  const T* weight_ptr, const T* sum_dy_ptr, const T* sum_dy_xmu_ptr, T* grad_in_ptr,\n                  const int32_t* count_ptr, const int64_t world_size) {\n    using ACC_T = acc_type<T>;\n    dim3 block;\n    dim3 grid;\n    flexible_launch_configs(reduction_size, stride, block, grid);\n\n    if (stride * reduction_size < std::numeric_limits<int32_t>::max()) {\n      batch_norm_backward_elemt_channels_last_kernel<T, ACC_T, int32_t, ELEMENTS_PER_ITER>\n          <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, weight_ptr, sum_dy_ptr, sum_dy_xmu_ptr,\n              count_ptr, grad_in_ptr, world_size, static_cast<int32_t>(stride),\n              static_cast<int32_t>(reduction_size));\n    } else {\n      batch_norm_backward_elemt_channels_last_kernel<T, ACC_T, int64_t, ELEMENTS_PER_ITER>\n          <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, weight_ptr, sum_dy_ptr, sum_dy_xmu_ptr,\n              count_ptr, grad_in_ptr, world_size, stride, reduction_size);\n    }\n  }\n};\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuBatchNormBackwardElemtKernel final : public user_op::OpKernel {\n public:\n  GpuBatchNormBackwardElemtKernel() = default;\n  ~GpuBatchNormBackwardElemtKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* grad_out = ctx->Tensor4ArgNameAndIndex(\"grad_out\", 0);\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    const user_op::Tensor* invstd = ctx->Tensor4ArgNameAndIndex(\"invstd\", 0);\n    const 
user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const user_op::Tensor* sum_dy = ctx->Tensor4ArgNameAndIndex(\"sum_dy\", 0);\n    const user_op::Tensor* sum_dy_xmu = ctx->Tensor4ArgNameAndIndex(\"sum_dy_xmu\", 0);\n    const user_op::Tensor* count = ctx->Tensor4ArgNameAndIndex(\"count\", 0);\n\n    user_op::Tensor* grad_in = ctx->Tensor4ArgNameAndIndex(\"grad_in\", 0);\n\n    const T* grad_out_ptr = grad_out->dptr<T>();\n    const T* input_ptr = input->dptr<T>();\n    const T* mean_ptr = mean->dptr<T>();\n    const T* invstd_ptr = invstd->dptr<T>();\n    const T* weight_ptr = weight->dptr<T>();\n    const T* sum_dy_ptr = sum_dy->dptr<T>();\n    const T* sum_dy_xmu_ptr = sum_dy_xmu->dptr<T>();\n    const int32_t* count_ptr = count->dptr<int32_t>();\n\n    T* grad_in_ptr = grad_in->mut_dptr<T>();\n    const int32_t axis = ctx->Attr<int32_t>(\"axis\");\n\n    bool use_channels_last_kernel = axis == 1 ? false : true;\n    const int64_t world_size = count->shape_view().elem_cnt();\n    if (use_channels_last_kernel) {  // NHWC format\n      const int64_t stride = input->shape_view().At(axis);\n      const int64_t reduction_size = input->shape_view().elem_cnt() / stride;\n      BatchNormBackwardElemtChannelLastFunctor<T>()(\n          ctx->stream(), stride, reduction_size, grad_out_ptr, input_ptr, mean_ptr, invstd_ptr,\n          weight_ptr, sum_dy_ptr, sum_dy_xmu_ptr, grad_in_ptr, count_ptr, world_size);\n    } else {  // NCHW format\n      const int64_t batch_size = input->shape_view().At(0);\n      const int64_t channel_size = input->shape_view().At(1);\n      const int64_t spatial_size = input->shape_view().Count(2);\n\n      BatchNormBackwardElemtFunctor<T>()(\n          ctx->stream(), batch_size, channel_size, spatial_size, grad_out_ptr, input_ptr, mean_ptr,\n          invstd_ptr, weight_ptr, sum_dy_ptr, sum_dy_xmu_ptr, grad_in_ptr, count_ptr, world_size);\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_BATCH_NORM_BACKWARD_ELEMT_KERNEL(dtype)                                       \\\n  REGISTER_USER_KERNEL(\"batch_norm_backward_elemt\")                                            \\\n      .SetCreateFn<GpuBatchNormBackwardElemtKernel<dtype>>()                                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                         \\\n                       && (user_op::HobDataType(\"grad_out\", 0) == GetDataType<dtype>::value)   \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)      \\\n                       && (user_op::HobDataType(\"mean\", 0) == GetDataType<dtype>::value)       \\\n                       && (user_op::HobDataType(\"invstd\", 0) == GetDataType<dtype>::value)     \\\n                       && (user_op::HobDataType(\"weight\", 0) == GetDataType<dtype>::value)     \\\n                       && (user_op::HobDataType(\"sum_dy\", 0) == GetDataType<dtype>::value)     \\\n                       && (user_op::HobDataType(\"sum_dy_xmu\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"count\", 0) == GetDataType<int32_t>::value))\n\nREGISTER_BATCH_NORM_BACKWARD_ELEMT_KERNEL(float);\nREGISTER_BATCH_NORM_BACKWARD_ELEMT_KERNEL(double);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/batch_norm_backward_reduce_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <limits>\n#include <algorithm>\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/ep/cuda/cuda_device.h\"\n#include \"oneflow/user/kernels/batch_norm_kernel_utils.h\"\n\n// NOTE(Liang Depeng):\n// The implementation of batch_norm_backward_reduce kernel is modified from\n// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Normalization.cuh\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstatic size_t InferTmpSizeForChannelLastKernel(user_op::InferContext* ctx) {\n  const int32_t axis = ctx->Attr<int32_t>(\"axis\");\n  const Shape& in_shape = ctx->InputTensorDesc(\"input\", 0).shape();\n  const int64_t stride = in_shape.At(axis);\n  const int64_t reduction_size = in_shape.elem_cnt() / stride;\n  dim3 block;\n  dim3 grid;\n  flexible_launch_configs(reduction_size, stride, block, grid, true);\n  size_t tmp_size = 0;\n  if (grid.y > 1) {\n    tmp_size += 2 * stride * grid.y * sizeof(T);\n    tmp_size += grid.x * sizeof(int32_t);\n  }\n  return tmp_size;\n}\n\ntemplate<typename T, typename ACC_T>\nstruct Float2 {\n  ACC_T v1, v2;\n  __device__ Float2() {}\n  __device__ Float2(T v1, T v2) : v1(static_cast<ACC_T>(v1)), v2(static_cast<ACC_T>(v2)) {}\n  __device__ Float2(int v) : v1(static_cast<ACC_T>(v)), v2(static_cast<ACC_T>(v)) {}\n  __device__ Float2& operator+=(const Float2& a) {\n    v1 += a.v1;\n    v2 += a.v2;\n    return *this;\n  }\n};\n\n// Sum across all threads within a warp\ntemplate<typename T>\nstatic __device__ __forceinline__ T warpSum_(T val) {\n  for (int i = 0; i < getMSB(WARP_SIZE); ++i) { val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); }\n  return val;\n}\n\ntemplate<typename RES_T>\nstatic __device__ __forceinline__ RES_T warpSum(RES_T value) {\n  value.v1 = warpSum_(value.v1);\n  value.v2 = warpSum_(value.v2);\n  return value;\n}\n\ntemplate<typename RES_T, typename T, typename ACC_T, typename IDX_TYPE>\n__device__ RES_T reduce(const T* input_ptr, const T* grad_out_ptr, ACC_T r_mean, IDX_TYPE channel,\n                        IDX_TYPE batch_size, IDX_TYPE channel_size, IDX_TYPE spatial_size) {\n  IDX_TYPE batch_offset = spatial_size * channel_size;\n  IDX_TYPE channel_offset = channel * spatial_size;\n  // first the reductions each thread does separately\n  RES_T sum = static_cast<RES_T>(0);\n  for (int batch = threadIdx.y; batch < batch_size; batch += blockDim.y) {\n    IDX_TYPE offset = batch * batch_offset;\n    for (int x = threadIdx.x; x < spatial_size; x += blockDim.x) {\n      //   sum += op(batch, plane, x);\n      ACC_T g = 
grad_out_ptr[offset + channel_offset + x];\n      ACC_T c = static_cast<ACC_T>(input_ptr[offset + channel_offset + x]) - r_mean;\n      sum.v1 += g;\n      sum.v2 += g * c;\n    }\n  }\n\n  // first warpSum to get one value per thread to\n  // one value per warp\n  sum = warpSum(sum);\n\n  // this writes each warp's item into shared memory\n  // there are at most WARP_SIZE items left because\n  // there are at most WARP_SIZE**2 threads at the beginning\n  __shared__ RES_T shared[WARP_SIZE];\n  __syncthreads();\n  int tid = threadIdx.x + threadIdx.y * blockDim.x;\n  if (tid % WARP_SIZE == 0) { shared[tid / WARP_SIZE] = sum; }\n  if (tid >= blockDim.x * blockDim.y / WARP_SIZE && tid < WARP_SIZE) {\n    // zero out the other entries in shared\n    shared[tid] = (RES_T)0;\n  }\n  __syncthreads();\n  // now have a second warpSum to reduce the intermediate values\n  // from shared memory to a single number. The very first\n  // thread writes it to shared memory.\n\n  if (tid / WARP_SIZE == 0) {\n    sum = warpSum(shared[tid]);\n    if (tid == 0) { shared[0] = sum; }\n  }\n  __syncthreads();\n\n  // Everyone picks it up, should be broadcast into the whole grad_input\n  return shared[0];\n}\n\ntemplate<typename T, typename ACC_T, typename IDX_TYPE>\n__global__ void batch_norm_backward_reduce_kernel(\n    const IDX_TYPE batch_size, const IDX_TYPE channel_size, const IDX_TYPE spatial_size,\n    const T* grad_out_ptr, const T* input_ptr, const T* mean_ptr, const T* invstd_ptr,\n    T* sum_dy_ptr, T* sum_dy_xmu_ptr, T* grad_weight_ptr, T* grad_bias_ptr) {\n  IDX_TYPE channel = blockIdx.x;\n  ACC_T r_mean = mean_ptr[channel];\n  ACC_T factor = invstd_ptr[channel];\n\n  auto res = reduce<Float2<T, ACC_T>, T, ACC_T, IDX_TYPE>(input_ptr, grad_out_ptr, r_mean, channel,\n                                                          batch_size, channel_size, spatial_size);\n\n  if (threadIdx.x == 0) {\n    if (grad_weight_ptr != nullptr) { grad_weight_ptr[channel] = static_cast<T>(res.v2 * factor); }\n    if (grad_bias_ptr != nullptr) { grad_bias_ptr[channel] = static_cast<T>(res.v1); }\n    if (sum_dy_ptr != nullptr) { sum_dy_ptr[channel] = static_cast<ACC_T>(res.v1); }\n    if (sum_dy_xmu_ptr != nullptr) { sum_dy_xmu_ptr[channel] = static_cast<ACC_T>(res.v2); }\n  }\n}\n\ntemplate<typename T>\n__device__ __forceinline__ void merge_block_vertical_backward(T& sum_dy, T& sum_dy_xmu,\n                                                              T* shmem_sum_dy,\n                                                              T* shmem_sum_dy_xmu) {\n  // write to shared memory\n  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;\n\n#pragma unroll\n  for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {\n    if (threadIdx.y < offset * 2) {\n      shmem_sum_dy[address_base] = sum_dy;\n      shmem_sum_dy_xmu[address_base] = sum_dy_xmu;\n    }\n    __syncthreads();\n    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {\n      auto address = address_base + offset * blockDim.x;\n\n      sum_dy += shmem_sum_dy[address];\n      sum_dy_xmu += shmem_sum_dy_xmu[address];\n    }\n  }\n}\n\ntemplate<typename T, typename ACC_T, typename IDX_TYPE, int PARALLEL_LOADS>\n__global__ void batch_norm_backward_reduce_channels_last_kernel(\n    const T* __restrict__ grad_output_ptr, const T* __restrict__ input_ptr,\n    const ACC_T* __restrict__ mean_ptr, const ACC_T* __restrict__ inv_std_ptr,\n    ACC_T* __restrict__ sum_dy_o_ptr, ACC_T* __restrict__ sum_dy_xmu_o_ptr,\n    T* __restrict__ grad_weight_ptr, 
T* __restrict__ grad_bias_ptr,\n    volatile ACC_T* staging_data_ptr, int32_t* semaphores_ptr, const IDX_TYPE reduction_size,\n    const IDX_TYPE stride) {\n  // hide latency with concurrency\n  ACC_T sum_dy[PARALLEL_LOADS];\n  ACC_T sum_dy_xmu[PARALLEL_LOADS];\n\n#pragma unroll\n  for (int i = 0; i < PARALLEL_LOADS; i++) {\n    sum_dy[i] = ACC_T(0);\n    sum_dy_xmu[i] = ACC_T(0);\n  }\n  // tensor dimension (m,c)\n\n  // loop along m dimension; use IDX_TYPE (not int) so the int64 dispatch path cannot overflow\n  IDX_TYPE inner_loop_stride = blockDim.y * gridDim.y;\n\n  // offset along m dimension\n  IDX_TYPE m_offset = blockIdx.y * blockDim.y + threadIdx.y;\n  IDX_TYPE c_offset = blockIdx.x * blockDim.x + threadIdx.x;\n\n  if (c_offset >= stride || m_offset >= reduction_size) { return; }\n\n  IDX_TYPE loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);\n  IDX_TYPE address_base = m_offset * stride + c_offset;\n  IDX_TYPE address_increment = inner_loop_stride * stride;\n\n  auto r_mean = mean_ptr[c_offset];\n  auto factor = inv_std_ptr[c_offset];\n\n  for (IDX_TYPE i = 0; i < loop_count; i++) {\n    ACC_T x_input[PARALLEL_LOADS];\n    ACC_T x_grad_output[PARALLEL_LOADS];\n\n    // load multiple data in\n#pragma unroll\n    for (int j = 0; j < PARALLEL_LOADS; j++) {\n      if (c_offset < stride && m_offset < reduction_size) {\n        x_input[j] = input_ptr[address_base];\n        x_grad_output[j] = grad_output_ptr[address_base];\n      } else {\n        x_input[j] = ACC_T(0);\n        x_grad_output[j] = ACC_T(0);\n      }\n      m_offset += inner_loop_stride;\n      address_base += address_increment;\n    }\n\n    // calculate sum_dy / sum_dy_xmu\n#pragma unroll\n    for (int j = 0; j < PARALLEL_LOADS; j++) {\n      sum_dy[j] += x_grad_output[j];\n      sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);\n    }\n  }\n\n  // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS\n#pragma unroll\n  for (int j = 1; j < PARALLEL_LOADS; j++) {\n    sum_dy[0] += sum_dy[j];\n    sum_dy_xmu[0] += sum_dy_xmu[j];\n  }\n\n  // release array of registers\n  auto sum_dy_th = sum_dy[0];\n  auto sum_dy_xmu_th = sum_dy_xmu[0];\n\n  // block-wise reduction with shared memory (since reduction cannot be done within a warp)\n  static __shared__ ACC_T shmem_sum_dy[MAX_BLOCK_SIZE];\n  static __shared__ ACC_T shmem_sum_dy_xmu[MAX_BLOCK_SIZE];\n\n  merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);\n\n  if (gridDim.y > 1) {\n    volatile ACC_T* staging_sum_dy = staging_data_ptr;\n    volatile ACC_T* staging_sum_dy_xmu = &staging_data_ptr[stride * gridDim.y];\n\n    address_base = c_offset + blockIdx.y * stride;\n    // write data to staging_data;\n    if (threadIdx.y == 0 && c_offset < stride) {\n      staging_sum_dy[address_base] = sum_dy_th;\n      staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;\n    }\n\n    __threadfence();\n    __syncthreads();  // ensuring writes to staging_ are visible to all blocks\n\n    __shared__ bool is_last_block_done;\n    // mark block done\n    if (threadIdx.x == 0 && threadIdx.y == 0) {\n      int old = atomicAdd(&semaphores_ptr[blockIdx.x], 1);\n      is_last_block_done = (old == (gridDim.y - 1));\n    }\n\n    __syncthreads();\n\n    // check that all data is now available in global memory\n    if (is_last_block_done) {\n      sum_dy_th = ACC_T(0.0);\n      sum_dy_xmu_th = ACC_T(0.0);\n\n      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {\n        address_base = c_offset + y * stride;\n        sum_dy_th += (c_offset < stride ? 
staging_sum_dy[address_base] : ACC_T(0.0));\n        sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : ACC_T(0.0));\n      }\n\n      merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);\n      if (threadIdx.y == 0 && c_offset < stride) {\n        if (grad_bias_ptr != nullptr) { grad_bias_ptr[c_offset] = static_cast<T>(sum_dy_th); }\n        if (grad_weight_ptr != nullptr) {\n          grad_weight_ptr[c_offset] = static_cast<T>(sum_dy_xmu_th * factor);\n        }\n        sum_dy_o_ptr[c_offset] = sum_dy_th;\n        sum_dy_xmu_o_ptr[c_offset] = sum_dy_xmu_th;\n      }\n    }\n  } else {\n    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {\n      if (grad_bias_ptr != nullptr) { grad_bias_ptr[c_offset] = static_cast<T>(sum_dy_th); }\n      if (grad_weight_ptr != nullptr) {\n        grad_weight_ptr[c_offset] = static_cast<T>(sum_dy_xmu_th * factor);\n      }\n      sum_dy_o_ptr[c_offset] = sum_dy_th;\n      sum_dy_xmu_o_ptr[c_offset] = sum_dy_xmu_th;\n    }\n  }\n}\n\ntemplate<typename T>\nstruct BatchNormBackwardReduceFunctor final {\n  void operator()(ep::Stream* stream, const int64_t batch_size, const int64_t channel_size,\n                  const int64_t spatial_size, const T* grad_out_ptr, const T* input_ptr,\n                  const T* mean_ptr, const T* invstd_ptr, T* sum_dy_ptr, T* sum_dy_xmu_ptr,\n                  T* grad_weight_ptr, T* grad_bias_ptr) {\n    using ACC_T = acc_type<T>;\n    int block_y = std::min<int>(lastPow2(batch_size), MAX_BLOCK_SIZE / WARP_SIZE);\n    // We want block_x to be at least a warp width\n    int block_x = std::min<int>(std::max<int>(getNumThreads(spatial_size), WARP_SIZE),\n                                MAX_BLOCK_SIZE / block_y);\n    const dim3 block(block_x, block_y);\n    const dim3 grid(channel_size);\n\n    if (batch_size * channel_size * spatial_size < std::numeric_limits<int32_t>::max()) {\n      batch_norm_backward_reduce_kernel<T, ACC_T, int32_t>\n          <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              static_cast<int32_t>(batch_size), static_cast<int32_t>(channel_size),\n              static_cast<int32_t>(spatial_size), grad_out_ptr, input_ptr, mean_ptr, invstd_ptr,\n              sum_dy_ptr, sum_dy_xmu_ptr, grad_weight_ptr, grad_bias_ptr);\n    } else {\n      batch_norm_backward_reduce_kernel<T, ACC_T, int64_t>\n          <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              batch_size, channel_size, spatial_size, grad_out_ptr, input_ptr, mean_ptr, invstd_ptr,\n              sum_dy_ptr, sum_dy_xmu_ptr, grad_weight_ptr, grad_bias_ptr);\n    }\n  }\n};\n\ntemplate<typename T>\nstruct BatchNormBackwardReduceChannelLastFunctor final {\n  void operator()(ep::Stream* stream, const int64_t stride, const int64_t reduction_size,\n                  const T* grad_out_ptr, const T* input_ptr, const T* mean_ptr, const T* invstd_ptr,\n                  T* sum_dy_ptr, T* sum_dy_xmu_ptr, T* grad_weight_ptr, T* grad_bias_ptr,\n                  user_op::Tensor* tmp_buffer) {\n    using ACC_T = acc_type<T>;\n\n    dim3 block;\n    dim3 grid;\n    flexible_launch_configs(reduction_size, stride, block, grid, true);\n\n    T* staging_data_ptr = nullptr;\n    int32_t* semaphores_ptr = nullptr;\n    if (grid.y > 1) {\n      staging_data_ptr = tmp_buffer->mut_dptr<T>();\n      semaphores_ptr = reinterpret_cast<int32_t*>(tmp_buffer->mut_dptr<char>()\n                                                  + 2 * stride * grid.y * 
sizeof(T));\n    }\n\n    if (stride * reduction_size < std::numeric_limits<int32_t>::max()) {\n      batch_norm_backward_reduce_channels_last_kernel<T, ACC_T, int32_t, ELEMENTS_PER_ITER>\n          <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, sum_dy_ptr, sum_dy_xmu_ptr,\n              grad_weight_ptr, grad_bias_ptr, staging_data_ptr, semaphores_ptr,\n              static_cast<int32_t>(reduction_size), static_cast<int32_t>(stride));\n    } else {\n      batch_norm_backward_reduce_channels_last_kernel<T, ACC_T, int64_t, ELEMENTS_PER_ITER>\n          <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, sum_dy_ptr, sum_dy_xmu_ptr,\n              grad_weight_ptr, grad_bias_ptr, staging_data_ptr, semaphores_ptr, reduction_size,\n              stride);\n    }\n  }\n};\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuBatchNormBackwardReduceKernel final : public user_op::OpKernel {\n public:\n  GpuBatchNormBackwardReduceKernel() = default;\n  ~GpuBatchNormBackwardReduceKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* grad_out = ctx->Tensor4ArgNameAndIndex(\"grad_out\", 0);\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    const user_op::Tensor* invstd = ctx->Tensor4ArgNameAndIndex(\"invstd\", 0);\n\n    user_op::Tensor* sum_dy = ctx->Tensor4ArgNameAndIndex(\"sum_dy\", 0);\n    user_op::Tensor* sum_dy_xmu = ctx->Tensor4ArgNameAndIndex(\"sum_dy_xmu\", 0);\n    user_op::Tensor* grad_weight = ctx->Tensor4ArgNameAndIndex(\"grad_weight\", 0);\n    user_op::Tensor* grad_bias = ctx->Tensor4ArgNameAndIndex(\"grad_bias\", 0);\n\n    const T* grad_out_ptr = grad_out->dptr<T>();\n    const T* input_ptr = input->dptr<T>();\n    const T* mean_ptr = mean->dptr<T>();\n    const T* invstd_ptr = invstd->dptr<T>();\n\n    T* sum_dy_ptr = sum_dy->mut_dptr<T>();\n    T* sum_dy_xmu_ptr = sum_dy_xmu->mut_dptr<T>();\n    T* grad_weight_ptr = grad_weight->mut_dptr<T>();\n    T* grad_bias_ptr = grad_bias->mut_dptr<T>();\n\n    const int32_t axis = ctx->Attr<int32_t>(\"axis\");\n\n    bool use_channels_last_kernel = axis == 1 ? 
false : true;\n    if (use_channels_last_kernel) {  // NHWC format\n      const int64_t stride = input->shape_view().At(axis);\n      const int64_t reduction_size = input->shape_view().elem_cnt() / stride;\n      user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n      BatchNormBackwardReduceChannelLastFunctor<T>()(\n          ctx->stream(), stride, reduction_size, grad_out_ptr, input_ptr, mean_ptr, invstd_ptr,\n          sum_dy_ptr, sum_dy_xmu_ptr, grad_weight_ptr, grad_bias_ptr, tmp_buffer);\n    } else {  // NCHW format\n      const int64_t batch_size = input->shape_view().At(0);\n      const int64_t channel_size = input->shape_view().At(1);\n      const int64_t spatial_size = input->shape_view().Count(2);\n\n      BatchNormBackwardReduceFunctor<T>()(ctx->stream(), batch_size, channel_size, spatial_size,\n                                          grad_out_ptr, input_ptr, mean_ptr, invstd_ptr, sum_dy_ptr,\n                                          sum_dy_xmu_ptr, grad_weight_ptr, grad_bias_ptr);\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_BATCH_NORM_BACKWARD_REDUCE_KERNEL(dtype)                                    \\\n  REGISTER_USER_KERNEL(\"batch_norm_backward_reduce\")                                         \\\n      .SetCreateFn<GpuBatchNormBackwardReduceKernel<dtype>>()                                \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                       \\\n                       && (user_op::HobDataType(\"grad_out\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)    \\\n                       && (user_op::HobDataType(\"mean\", 0) == GetDataType<dtype>::value)     \\\n                       && (user_op::HobDataType(\"invstd\", 0) == GetDataType<dtype>::value))  \\\n      .SetInferTmpSizeFn(InferTmpSizeForChannelLastKernel<dtype>)\n\nREGISTER_BATCH_NORM_BACKWARD_REDUCE_KERNEL(float);\nREGISTER_BATCH_NORM_BACKWARD_REDUCE_KERNEL(double);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/batch_norm_elemt_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <limits>\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/ep/cuda/cuda_device.h\"\n#include \"oneflow/user/kernels/batch_norm_kernel_utils.h\"\n\n// NOTE(Liang Depeng):\n// The implementation of batch_norm_elemt kernel is modified from\n// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Normalization.cuh\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename IDX_TYPE, int PARALLEL_LOADS>\n__global__ void batch_norm_transform_input_channels_last_kernel(\n    const T* __restrict__ input_ptr, const T* __restrict__ mean_ptr,\n    const T* __restrict__ inv_std_ptr, const T* __restrict__ weight_ptr,\n    const T* __restrict__ bias_ptr, T* __restrict__ out_ptr, const IDX_TYPE reduction_size,\n    const IDX_TYPE stride) {\n  // tensor dimension (m,c)\n  // loop along m dimension\n  IDX_TYPE inner_loop_stride = blockDim.y * gridDim.y;\n\n  // offset along m dimension\n  IDX_TYPE m_offset = blockIdx.y * blockDim.y + threadIdx.y;\n  IDX_TYPE c_offset = blockIdx.x * blockDim.x + threadIdx.x;\n\n  if (c_offset >= stride || m_offset >= reduction_size) { return; }\n\n  auto m_c = mean_ptr[c_offset];\n  auto inv_std_c = static_cast<T>(inv_std_ptr[c_offset]);\n  auto w_c = weight_ptr == nullptr ? T(1.0) : static_cast<T>(weight_ptr[c_offset]);\n  auto b_c = bias_ptr == nullptr ? 
T(0.0) : static_cast<T>(bias_ptr[c_offset]);\n\n  IDX_TYPE loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);\n  IDX_TYPE address_base = m_offset * stride + c_offset;\n  IDX_TYPE address_increment = inner_loop_stride * stride;\n\n  for (IDX_TYPE i = 0; i < loop_count; i++) {\n#pragma unroll\n    for (int j = 0; j < PARALLEL_LOADS; j++) {\n      if (c_offset < stride && m_offset < reduction_size) {\n        out_ptr[address_base] =\n            static_cast<T>(w_c * (static_cast<T>(input_ptr[address_base]) - m_c) * inv_std_c + b_c);\n      }\n      m_offset += inner_loop_stride;\n      address_base += address_increment;\n    }\n  }\n}\n\ntemplate<typename T, typename IDX_TYPE>\n__global__ void batch_norm_transform_input_kernel(const IDX_TYPE batch_size,\n                                                  const IDX_TYPE channel_size,\n                                                  const IDX_TYPE spatial_size, const T* input_ptr,\n                                                  const T* mean_ptr, const T* invstd_ptr,\n                                                  const T* weight_ptr, const T* bias_ptr,\n                                                  T* output_ptr) {\n  IDX_TYPE channel = blockIdx.x;\n  IDX_TYPE channel_offset = channel * spatial_size;\n  IDX_TYPE batch_step = channel_size * spatial_size;\n\n  if (channel >= channel_size) { return; }\n\n  T gamma = static_cast<T>(weight_ptr[channel]);\n  T beta = static_cast<T>(bias_ptr[channel]);\n  T mean = static_cast<T>(mean_ptr[channel]);\n  T invstd = invstd_ptr[channel];\n\n  IDX_TYPE bstep = blockDim.y * gridDim.y;\n  for (IDX_TYPE batch = threadIdx.y + blockIdx.y * blockDim.y; batch < batch_size; batch += bstep) {\n    IDX_TYPE offset = batch * batch_step + channel_offset;\n    for (IDX_TYPE feature = threadIdx.x; feature < spatial_size; feature += blockDim.x) {\n      output_ptr[offset + feature] =\n          static_cast<T>(gamma * (input_ptr[offset + feature] - mean) * invstd + beta);\n    }\n  }\n}\n\ntemplate<typename T>\nstruct BatchNormElemtFunctor final {\n  void operator()(ep::Stream* stream, const int64_t batch_size, const int64_t channel_size,\n                  const int64_t spatial_size, const T* input_ptr, const T* mean_ptr,\n                  const T* invstd_ptr, const T* weight_ptr, const T* bias_ptr, T* output_ptr) {\n    // The input_transform kernel is pointwise, but we need to balance reading parameters\n    // (save_var/mean, weight/bias) - which we only do once and have a for loop afterwards - with\n    // having many threads and blocks and good occupancy. Quite likely, we could go with even more\n    // blocks than 1024. 
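For instance (illustrative numbers only): channel_size = 64, batch_size = 32 and tb = 1 give\n    // blocks_trans = (64, min(256 * 1024 / 64, 32)) = (64, 32) before the MAX_GRID_SIZE clamp.\n    // 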
The various planes are independent, so we use blocks for them.\n    int tf = std::max<int>(getNumThreads(spatial_size / 4),\n                           std::min<int>(getNumThreads(spatial_size), 64));\n    int tb = std::max<int>(64 / tf, 1);\n    dim3 blocks_trans(channel_size, std::max<int>(1, std::min<int>((256 * 1024) / channel_size,\n                                                                   (batch_size + tb - 1) / tb)));\n    blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);\n    dim3 threads_trans(tf, tb);\n\n    if (batch_size * channel_size * spatial_size < std::numeric_limits<int32_t>::max()) {\n      batch_norm_transform_input_kernel<T, int32_t>\n          <<<blocks_trans, threads_trans, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              static_cast<int32_t>(batch_size), static_cast<int32_t>(channel_size),\n              static_cast<int32_t>(spatial_size), input_ptr, mean_ptr, invstd_ptr, weight_ptr,\n              bias_ptr, output_ptr);\n    } else {\n      batch_norm_transform_input_kernel<T, int64_t>\n          <<<blocks_trans, threads_trans, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              batch_size, channel_size, spatial_size, input_ptr, mean_ptr, invstd_ptr, weight_ptr,\n              bias_ptr, output_ptr);\n    }\n  }\n};\n\ntemplate<typename T>\nstruct BatchNormElemtChannelLastFunctor final {\n  void operator()(ep::Stream* stream, const int64_t stride, const int64_t reduction_size,\n                  const T* input_ptr, const T* mean_ptr, const T* invstd_ptr, const T* weight_ptr,\n                  const T* bias_ptr, T* output_ptr) {\n    dim3 block;\n    dim3 grid;\n    flexible_launch_configs(reduction_size, stride, block, grid);\n\n    if (reduction_size * stride < std::numeric_limits<int32_t>::max()) {\n      batch_norm_transform_input_channels_last_kernel<T, int32_t, ELEMENTS_PER_ITER>\n          <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              input_ptr, mean_ptr, invstd_ptr, weight_ptr, bias_ptr, output_ptr,\n              static_cast<int32_t>(reduction_size), static_cast<int32_t>(stride));\n    } else {\n      batch_norm_transform_input_channels_last_kernel<T, int64_t, ELEMENTS_PER_ITER>\n          <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              input_ptr, mean_ptr, invstd_ptr, weight_ptr, bias_ptr, output_ptr, reduction_size,\n              stride);\n    }\n  }\n};\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuBatchNormElemtKernel final : public user_op::OpKernel {\n public:\n  GpuBatchNormElemtKernel() = default;\n  ~GpuBatchNormElemtKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    const user_op::Tensor* invstd = ctx->Tensor4ArgNameAndIndex(\"invstd\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex(\"bias\", 0);\n    user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n\n    const T* input_ptr = input->dptr<T>();\n    const T* mean_ptr = mean->dptr<T>();\n    const T* invstd_ptr = invstd->dptr<T>();\n    const T* weight_ptr = weight->dptr<T>();\n    const T* bias_ptr = bias->dptr<T>();\n    T* output_ptr = output->mut_dptr<T>();\n    const int32_t axis = 
ctx->Attr<int32_t>(\"axis\");\n\n    bool use_channels_last_kernel = axis == 1 ? false : true;\n    if (use_channels_last_kernel) {  // NHWC format\n      const int64_t stride = input->shape_view().At(axis);\n      const int64_t reduction_size = input->shape_view().elem_cnt() / stride;\n      BatchNormElemtChannelLastFunctor<T>()(ctx->stream(), stride, reduction_size, input_ptr,\n                                            mean_ptr, invstd_ptr, weight_ptr, bias_ptr, output_ptr);\n    } else {  // NCHW format\n      const int64_t batch_size = input->shape_view().At(0);\n      const int64_t channel_size = input->shape_view().At(1);\n      const int64_t spatial_size = input->shape_view().Count(2);\n\n      BatchNormElemtFunctor<T>()(ctx->stream(), batch_size, channel_size, spatial_size, input_ptr,\n                                 mean_ptr, invstd_ptr, weight_ptr, bias_ptr, output_ptr);\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_BATCH_NORM_ELEMT_KERNEL(dtype)                                            \\\n  REGISTER_USER_KERNEL(\"batch_norm_elemt\")                                                 \\\n      .SetCreateFn<GpuBatchNormElemtKernel<dtype>>()                                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                     \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)  \\\n                       && (user_op::HobDataType(\"mean\", 0) == GetDataType<dtype>::value)   \\\n                       && (user_op::HobDataType(\"invstd\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"weight\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"bias\", 0) == GetDataType<dtype>::value))\n\nREGISTER_BATCH_NORM_ELEMT_KERNEL(float);\nREGISTER_BATCH_NORM_ELEMT_KERNEL(double);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/batch_norm_gather_stats_with_counts_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <limits>\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/ep/cuda/cuda_device.h\"\n#include \"oneflow/user/kernels/batch_norm_kernel_utils.h\"\n\n// NOTE(Liang Depeng):\n// The implementation of batch_norm_gather_stats_with_counts kernel is modified from\n// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Normalization.cuh\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename ACC_T, typename IDX_TYPE>\n__global__ void batch_norm_reduce_statistics_kernel(const int64_t world_size,\n                                                    const int64_t channel_size, const T* mean_ptr,\n                                                    const T* invstd_ptr, const T* counts_ptr,\n                                                    T* global_mean_ptr, T* global_invstd_ptr,\n                                                    T* running_mean_ptr, T* running_var_ptr,\n                                                    const float eps, const float momentum) {\n  IDX_TYPE bid = blockIdx.x;\n  IDX_TYPE tid = threadIdx.x;\n\n  // first the reductions each thread does separately\n  for (IDX_TYPE i = bid * blockDim.x + tid; i < channel_size; i += gridDim.x * blockDim.x) {\n    ACC_T avg = 0;\n    ACC_T var_n = 0;\n    IDX_TYPE n = 0;\n    for (IDX_TYPE j = 0; j < world_size; j++) {\n      T count = counts_ptr[j];\n      ACC_T m = mean_ptr[j * channel_size + i];\n      ACC_T v = ACC_T(1.0) / (invstd_ptr[j * channel_size + i]);\n      v = (v * v - eps) * count;\n      ACC_T factor = 1.0 / (n + count);\n      var_n += v + (avg - m) * (avg - m) * n * count * factor;\n      avg = n * factor * avg + count * factor * m;\n      n += count;\n    }\n    global_mean_ptr[i] = avg;\n    global_invstd_ptr[i] = static_cast<ACC_T>(1) / device_sqrt(var_n / n + eps);\n    if (running_mean_ptr != nullptr) {\n      running_mean_ptr[i] = static_cast<T>((1 - momentum) * running_mean_ptr[i] + momentum * avg);\n    }\n    ACC_T unbiasedVar = var_n / (n - 1);\n    if (running_var_ptr != nullptr) {\n      running_var_ptr[i] =\n          static_cast<T>((1 - momentum) * running_var_ptr[i] + momentum * unbiasedVar);\n    }\n  }\n}\n\ntemplate<typename T>\nstruct BatchNormGatherStatsWithCountsFunctor final {\n  void operator()(ep::Stream* stream, const int64_t world_size, const int64_t channel_size,\n                  const T* mean_ptr, const T* invstd_ptr, const T* counts_ptr, T* global_mean_ptr,\n                  T* global_invstd_ptr, T* running_mean_ptr, T* running_var_ptr, const float 
eps,\n                  const float momentum) {\n    using ACC_T = acc_type<T>;\n    int32_t block = getNumThreads(channel_size);\n    int32_t grid = std::max<int32_t>(1, channel_size / block);\n\n    if (world_size * channel_size < std::numeric_limits<int32_t>::max()) {\n      batch_norm_reduce_statistics_kernel<T, ACC_T, int32_t>\n          <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              static_cast<int32_t>(world_size), static_cast<int32_t>(channel_size), mean_ptr,\n              invstd_ptr, counts_ptr, global_mean_ptr, global_invstd_ptr, running_mean_ptr,\n              running_var_ptr, eps, momentum);\n    } else {\n      batch_norm_reduce_statistics_kernel<T, ACC_T, int64_t>\n          <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              world_size, channel_size, mean_ptr, invstd_ptr, counts_ptr, global_mean_ptr,\n              global_invstd_ptr, running_mean_ptr, running_var_ptr, eps, momentum);\n    }\n  }\n};\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuBatchNormGatherStatsWithCountsKernel final : public user_op::OpKernel {\n public:\n  GpuBatchNormGatherStatsWithCountsKernel() = default;\n  ~GpuBatchNormGatherStatsWithCountsKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    const user_op::Tensor* invstd = ctx->Tensor4ArgNameAndIndex(\"invstd\", 0);\n    const user_op::Tensor* counts = ctx->Tensor4ArgNameAndIndex(\"counts\", 0);\n    user_op::Tensor* global_mean = ctx->Tensor4ArgNameAndIndex(\"global_mean\", 0);\n    user_op::Tensor* global_invstd = ctx->Tensor4ArgNameAndIndex(\"global_invstd\", 0);\n\n    const T* mean_ptr = mean->dptr<T>();\n    const T* invstd_ptr = invstd->dptr<T>();\n    const T* counts_ptr = counts->dptr<T>();\n    T* global_mean_ptr = global_mean->mut_dptr<T>();\n    T* global_invstd_ptr = global_invstd->mut_dptr<T>();\n    T* running_mean_ptr = nullptr;\n    T* running_var_ptr = nullptr;\n    if (ctx->has_input(\"running_mean\", 0)) {\n      CHECK(ctx->has_input(\"running_var\", 0));\n      running_mean_ptr = ctx->Tensor4ArgNameAndIndex(\"running_mean\", 0)->mut_dptr<T>();\n      running_var_ptr = ctx->Tensor4ArgNameAndIndex(\"running_var\", 0)->mut_dptr<T>();\n    }\n\n    const float eps = ctx->Attr<float>(\"eps\");\n    const float momentum = ctx->Attr<float>(\"momentum\");\n\n    const int64_t world_size = mean->shape_view().At(0);\n    const int64_t channel_size = mean->shape_view().At(1);\n\n    BatchNormGatherStatsWithCountsFunctor<T>()(\n        ctx->stream(), world_size, channel_size, mean_ptr, invstd_ptr, counts_ptr, global_mean_ptr,\n        global_invstd_ptr, running_mean_ptr, running_var_ptr, eps, momentum);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_BATCH_NORM_GATHER_STATS_WITH_COUNTS_KERNEL(dtype)                         \\\n  REGISTER_USER_KERNEL(\"batch_norm_gather_stats_with_counts\")                              \\\n      .SetCreateFn<GpuBatchNormGatherStatsWithCountsKernel<dtype>>()                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                     \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)  \\\n                       && (user_op::HobDataType(\"mean\", 0) == 
GetDataType<dtype>::value)   \\\n                       && (user_op::HobDataType(\"invstd\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"counts\", 0) == GetDataType<dtype>::value))\n\nREGISTER_BATCH_NORM_GATHER_STATS_WITH_COUNTS_KERNEL(float);\nREGISTER_BATCH_NORM_GATHER_STATS_WITH_COUNTS_KERNEL(double);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/batch_norm_kernel_utils.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_BATCH_NORM_UTILS_H_\n#define ONEFLOW_USER_KERNELS_BATCH_NORM_UTILS_H_\n// NOTE(Liang Depeng):\n// Modified from\n// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Normalization.cuh\n\n#if defined(__CUDACC__)\n\nconstexpr int ELEMENTS_PER_ITER = 4;  // enables concurrency within each thread to hide latency\nconstexpr int ELEMENTS_PER_THREAD = 16;\nconstexpr int OPTIMAL_TILE_W = 32;\nconstexpr int MAX_H_BLOCK = 128;\nconstexpr int32_t MAX_BLOCK_SIZE = 512;\nconstexpr unsigned MAX_GRID_SIZE = 65535u;\n#define WARP_SIZE 32\n\n// returns 2**floor(log2(n))\nstatic int lastPow2(unsigned int n) {\n  n |= (n >> 1);\n  n |= (n >> 2);\n  n |= (n >> 4);\n  n |= (n >> 8);\n  n |= (n >> 16);\n  return std::max<int>(1, n - (n >> 1));\n}\n\n/**\n   Computes ceil(a / b)\n*/\ntemplate<typename T, typename = std::enable_if_t<std::is_integral<T>::value>>\nstatic T ceil_div(T a, T b) {\n  return (a + b - 1) / b;\n}\n\nstatic void flexible_launch_configs(const int reduction, const int stride, dim3& block, dim3& grid,\n                                    const bool coop_flag = false) {\n  int block_x = std::min(lastPow2(stride), OPTIMAL_TILE_W);\n  int block_y =\n      std::min(lastPow2(ceil_div(reduction, ELEMENTS_PER_THREAD)), MAX_BLOCK_SIZE / block_x);\n  if (block_x * block_y != MAX_BLOCK_SIZE) {\n    block_x = std::min(lastPow2(stride), MAX_BLOCK_SIZE / block_y);\n  }\n\n  int grid_x = ceil_div(stride, block_x);\n  int grid_y = std::min(ceil_div(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);\n  if (coop_flag) {\n    // it's not worth having a grid reduction if the reduction dimension is not big enough\n    grid_y = grid_y < 8 ? 
1 : grid_y;\n  }\n\n  block.x = block_x;\n  block.y = block_y;\n  block.z = 1;\n  grid.x = grid_x;\n  grid.y = grid_y;\n  grid.z = 1;\n}\n\ntemplate<typename T>\nstruct AccumulateType {};\ntemplate<>\nstruct AccumulateType<float> {\n  using type = float;\n};\ntemplate<>\nstruct AccumulateType<double> {\n  using type = double;\n};\n\ntemplate<typename T>\nusing acc_type = typename AccumulateType<T>::type;\n\n// Number of threads in a block given an input size up to MAX_BLOCK_SIZE\nstatic int32_t getNumThreads(int64_t nElem) {\n  int32_t threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};\n  for (int32_t i = 0; i != 5; ++i) {\n    if (nElem <= threadSizes[i]) { return threadSizes[i]; }\n  }\n  return MAX_BLOCK_SIZE;\n}\n\ntemplate<typename T>\nstatic __forceinline__ __device__ T device_sqrt(T val);\n\ntemplate<>\n__forceinline__ __device__ float device_sqrt(float val) {\n  return ::sqrtf(val);\n}\n\ntemplate<>\n__forceinline__ __device__ double device_sqrt(double val) {\n  return ::sqrt(val);\n}\n\ntemplate<typename T>\n__device__ __forceinline__ T inv_std(T var, double eps) {\n  T invstd = 0;\n  if (var != static_cast<T>(0) || eps != static_cast<T>(0)) {\n    invstd = static_cast<T>(1) / device_sqrt(var + eps);\n  }\n  return invstd;\n}\n\n// Returns the index of the most significant 1 bit in `val`.\n__device__ __forceinline__ int32_t getMSB(int32_t val) { return 31 - __clz(val); }\n\ntemplate<typename T>\n__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,\n                                           unsigned int mask = 0xffffffff) {\n  return __shfl_xor_sync(mask, value, laneMask, width);\n}\n\n#endif\n\n#endif  // ONEFLOW_USER_KERNELS_BATCH_NORM_UTILS_H_\n"
  },
  {
    "path": "oneflow/user/kernels/batch_norm_stats_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <limits>\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/ep/cuda/cuda_device.h\"\n#include \"oneflow/user/kernels/batch_norm_kernel_utils.h\"\n\n// NOTE(Liang Depeng):\n// The implementation of batch_norm_stats kernel is modified from\n// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Normalization.cuh\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstatic size_t InferTmpSizeForChannelLastKernel(user_op::InferContext* ctx) {\n  const int32_t axis = ctx->Attr<int32_t>(\"axis\");\n  const Shape& in_shape = ctx->InputTensorDesc(\"input\", 0).shape();\n  const int64_t stride = in_shape.At(axis);\n  const int64_t reduction_size = in_shape.elem_cnt() / stride;\n  dim3 block;\n  dim3 grid;\n  flexible_launch_configs(reduction_size, stride, block, grid, true);\n  size_t tmp_size = 0;\n  if (grid.y > 1) {\n    tmp_size += 4 * stride * grid.y * sizeof(T);\n    tmp_size += grid.x * sizeof(int32_t);\n  }\n  return tmp_size;\n}\n\ntemplate<typename T, typename C>\n__device__ __forceinline__ void welford_merge_element(C& count, T& mean, T& m2n, const C& count_new,\n                                                      const T& mean_new, const T& m2n_new) {\n  T factor = T(1.0) / ::max(C(1), (count + count_new));\n  T delta0 = mean - mean_new;\n  mean = (mean_new * count_new + mean * count) * factor;\n  m2n += m2n_new + delta0 * delta0 * count_new * count * factor;\n  count += count_new;\n}\n\n// merge mean/m2n among threadIdx.y within block\ntemplate<typename T, typename C>\n__device__ __forceinline__ void welford_merge_block_vertical(C& count, T& mean, T& m2n,\n                                                             C* shmem_count, T* shmem_mean,\n                                                             T* shmem_m2n) {\n  // write to shared memory\n  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;\n\n#pragma unroll\n  for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {\n    if (threadIdx.y < offset * 2) {\n      shmem_mean[address_base] = mean;\n      shmem_m2n[address_base] = m2n;\n      shmem_count[address_base] = count;\n    }\n    __syncthreads();\n    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {\n      auto address = address_base + offset * blockDim.x;\n      // read shared memory back to register for reduction\n      auto count_new = shmem_count[address];\n      auto mean_new = shmem_mean[address];\n      auto m2n_new = shmem_m2n[address];\n\n      welford_merge_element(count, mean, m2n, count_new, 
mean_new, m2n_new);\n    }\n  }\n}\n\ntemplate<typename T, typename ACC_T, typename IDX_TYPE, int PARALLEL_LOADS>\n__global__ void batch_norm_collect_statistics_channels_last_kernel(\n    const T* __restrict__ input_ptr, ACC_T* __restrict__ out_mean_ptr,\n    ACC_T* __restrict__ out_invstd_ptr, volatile ACC_T* staging_data_ptr, int32_t* semaphores_ptr,\n    const IDX_TYPE reduction_size, const IDX_TYPE stride, ACC_T epsilon) {\n  // hide latency with concurrency\n  ACC_T x_mean[PARALLEL_LOADS];\n  ACC_T m_2_n[PARALLEL_LOADS];\n  IDX_TYPE count[PARALLEL_LOADS];\n\n#pragma unroll\n  for (IDX_TYPE i = 0; i < PARALLEL_LOADS; i++) {\n    x_mean[i] = ACC_T(0);\n    m_2_n[i] = ACC_T(0);\n    count[i] = IDX_TYPE(0);\n  }\n  // tensor dimension (m,c)\n\n  // loop along m dimension\n  IDX_TYPE inner_loop_stride = blockDim.y * gridDim.y;\n\n  // offset along m dimension\n  IDX_TYPE m_offset = blockIdx.y * blockDim.y + threadIdx.y;\n  IDX_TYPE c_offset = blockIdx.x * blockDim.x + threadIdx.x;\n\n  IDX_TYPE loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);\n  IDX_TYPE address_base = m_offset * stride + c_offset;\n  IDX_TYPE address_increment = inner_loop_stride * stride;\n\n  for (IDX_TYPE i = 0; i < loop_count; i++) {\n    ACC_T x_math[PARALLEL_LOADS];\n    ACC_T x_count_inv[PARALLEL_LOADS];\n    ACC_T is_valid[PARALLEL_LOADS];\n\n    // load multiple data in\n#pragma unroll\n    for (IDX_TYPE j = 0; j < PARALLEL_LOADS; j++) {\n      if (c_offset < stride && m_offset < reduction_size) {\n        x_math[j] = input_ptr[address_base];\n        count[j]++;\n        x_count_inv[j] = ACC_T(1) / count[j];\n        is_valid[j] = ACC_T(1);\n      } else {\n        x_math[j] = ACC_T(0);\n        x_count_inv[j] = ACC_T(0);\n        is_valid[j] = ACC_T(0);\n      }\n      m_offset += inner_loop_stride;\n      address_base += address_increment;\n    }\n\n    // calculate mean/m2n with welford\n#pragma unroll\n    for (IDX_TYPE j = 0; j < PARALLEL_LOADS; j++) {\n      ACC_T delta0 = x_math[j] - x_mean[j];\n      x_mean[j] += delta0 * x_count_inv[j];\n      ACC_T delta1 = x_math[j] - x_mean[j];\n      m_2_n[j] += delta0 * delta1 * is_valid[j];\n    }\n  }\n\n  // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS\n#pragma unroll\n  for (IDX_TYPE j = 1; j < PARALLEL_LOADS; j++) {\n    welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);\n  }\n\n  // release x_mean / m_2_n\n  auto mean_th = x_mean[0];\n  auto m2_th = m_2_n[0];\n  auto count_th = count[0];\n\n  // block-wise reduction with shared memory (since reduction cannot be done within a warp)\n  static __shared__ ACC_T shmem_mean[MAX_BLOCK_SIZE];\n  static __shared__ ACC_T shmem_m2n[MAX_BLOCK_SIZE];\n  static __shared__ IDX_TYPE shmem_count[MAX_BLOCK_SIZE];\n\n  welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);\n\n  if (gridDim.y > 1) {\n    volatile ACC_T* staging_mean = staging_data_ptr;\n    volatile ACC_T* staging_m2n = &staging_data_ptr[stride * gridDim.y];\n    volatile IDX_TYPE* staging_count =\n        reinterpret_cast<volatile IDX_TYPE*>(&staging_m2n[stride * gridDim.y]);\n\n    address_base = c_offset + blockIdx.y * stride;\n    // write data to staging_data_ptr;\n    if (threadIdx.y == 0 && c_offset < stride) {\n      staging_mean[address_base] = mean_th;\n      staging_m2n[address_base] = m2_th;\n      staging_count[address_base] = count_th;\n    }\n\n    __threadfence();\n    __syncthreads();  // ensuring writes to staging_ are 
visible to all blocks\n\n    __shared__ bool is_last_block_done;\n    // mark block done\n    if (threadIdx.x == 0 && threadIdx.y == 0) {\n      IDX_TYPE old = atomicAdd(&semaphores_ptr[blockIdx.x], 1);\n      is_last_block_done = (old == (gridDim.y - 1));\n    }\n\n    __syncthreads();\n\n    // check that all data is now available in global memory\n    if (is_last_block_done) {\n      count_th = 0;\n      mean_th = ACC_T(0.0);\n      m2_th = ACC_T(0.0);\n\n      for (IDX_TYPE y = threadIdx.y; y < gridDim.y; y += blockDim.y) {\n        address_base = c_offset + y * stride;\n        IDX_TYPE count_new = c_offset < stride ? staging_count[address_base] : 0;\n        ACC_T mean_new = c_offset < stride ? staging_mean[address_base] : ACC_T(0.0);\n        ACC_T m2n_new = c_offset < stride ? staging_m2n[address_base] : ACC_T(0.0);\n\n        welford_merge_element(count_th, mean_th, m2_th, count_new, mean_new, m2n_new);\n      }\n\n      welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);\n      if (threadIdx.y == 0 && c_offset < stride) {\n        out_mean_ptr[c_offset] = static_cast<ACC_T>(mean_th);\n        out_invstd_ptr[c_offset] = inv_std(m2_th / count_th, epsilon);\n      }\n    }\n  } else {\n    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {\n      out_mean_ptr[c_offset] = static_cast<ACC_T>(mean_th);\n      out_invstd_ptr[c_offset] = inv_std(m2_th / count_th, epsilon);\n    }\n  }\n}\n\ntemplate<typename T, typename ACC_T, typename IDX_TYPE>\n__global__ void batch_norm_collect_statistics_kernel(const T* input_ptr, const IDX_TYPE batch_size,\n                                                     const IDX_TYPE channel_size,\n                                                     const IDX_TYPE spatial_size, const ACC_T eps,\n                                                     T* mean_ptr, T* invstd_ptr) {\n  __shared__ IDX_TYPE shared_n[2 * 2 * WARP_SIZE + WARP_SIZE];\n\n  IDX_TYPE channel_idx = blockIdx.x;\n  IDX_TYPE N = batch_size * spatial_size;\n  IDX_TYPE tid = threadIdx.x + threadIdx.y * blockDim.x;\n\n  // Compute the mean and variance across (batch, x/y/z)\n  // this uses the Welford (in the for loop)/parallel algorithm (to sum across the block)\n  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm\n  // and the parallel algorithm on the same page.\n  // We use two shuffles to reduce across the entire block.\n  // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description.\n  ACC_T* shared_avg_var = (ACC_T*)&shared_n[WARP_SIZE];\n\n  // first the reductions each thread does separately\n  ACC_T avg = 0;\n  ACC_T var_n = 0;\n  IDX_TYPE n = 0;\n  const IDX_TYPE channel_offset = channel_idx * spatial_size;\n  const IDX_TYPE batch_offset = channel_size * spatial_size;\n  for (IDX_TYPE batch = threadIdx.y; batch < batch_size; batch += blockDim.y) {\n    IDX_TYPE offset = batch * batch_offset + channel_offset;\n    for (IDX_TYPE x = threadIdx.x; x < spatial_size; x += blockDim.x) {\n      ACC_T v = input_ptr[offset + x];\n      ACC_T d1 = v - avg;\n      n++;\n      avg += d1 / n;\n      var_n += d1 * (v - avg);\n    }\n  }\n\n  // summing the result of all the threads within a warp\n  // refer to: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm\n  // first warpSum to get one value per thread to one value per warp\n  for (IDX_TYPE i = 0; i < getMSB(WARP_SIZE); ++i) {\n    ACC_T o_avg = WARP_SHFL_XOR(avg, 1 << i, WARP_SIZE);\n    
IDX_TYPE o_n = WARP_SHFL_XOR(n, 1 << i, WARP_SIZE);\n    ACC_T factor = 1.0 / fmaxf(1.0, n + o_n);\n    var_n +=\n        WARP_SHFL_XOR(var_n, 1 << i, WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;\n    avg = (n * avg + o_n * o_avg) * factor;\n    n += o_n;\n  }\n\n  // this writes each warp's final sum result into shared memory\n  // there are at most (thread_number_of_a_block / WARP_SIZE) results\n  __syncthreads();\n  if (tid % WARP_SIZE == 0) {\n    shared_n[tid / WARP_SIZE] = n;\n    shared_avg_var[tid / WARP_SIZE * 2] = avg;\n    shared_avg_var[tid / WARP_SIZE * 2 + 1] = var_n;\n  }\n  __syncthreads();\n\n  // now have a second warpSum to reduce the intermediate values\n  // from shared memory to a single number. The very first\n  // thread writes it to shared memory.\n  if (tid < WARP_SIZE) {\n    // initialize n, avg and var_n of each thread within the first warp\n    n = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_n[tid] : 0);\n    avg = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_avg_var[2 * tid] : ACC_T(0));\n    var_n = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_avg_var[2 * tid + 1] : ACC_T(0));\n\n    for (IDX_TYPE i = 0; i < getMSB(WARP_SIZE); ++i) {\n      ACC_T o_avg = WARP_SHFL_XOR(avg, 1 << i, WARP_SIZE);\n      IDX_TYPE o_n = WARP_SHFL_XOR(n, 1 << i, WARP_SIZE);\n      ACC_T factor = 1.0 / fmaxf(1.0, n + o_n);\n      var_n += WARP_SHFL_XOR(var_n, 1 << i, WARP_SIZE)\n               + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;\n      avg = (n * avg + o_n * o_avg) * factor;\n      n += o_n;\n    }\n  }\n\n  // save the mean and inverse standard deviation\n  if (tid == 0) {\n    mean_ptr[channel_idx] = avg;\n    invstd_ptr[channel_idx] = inv_std(var_n / N, eps);\n  }\n}\n\ntemplate<typename T>\nstruct BatchNormStatsFunctor final {\n  void operator()(ep::Stream* stream, const user_op::Tensor* input, user_op::Tensor* mean,\n                  user_op::Tensor* invstd, const float eps) {\n    using ACC_T = acc_type<T>;\n    const ShapeView& input_shape = input->shape_view();\n    const int64_t input_numel = input_shape.elem_cnt();\n    const int64_t spatial_size = input_shape.Count(2);\n\n    dim3 blocks(input_shape.At(1));\n    int32_t tf = getNumThreads(spatial_size);\n    dim3 threads(tf, std::max<int32_t>(1, MAX_BLOCK_SIZE / tf));\n\n    const T* input_ptr = input->dptr<T>();\n    T* mean_ptr = mean->mut_dptr<T>();\n    T* invstd_ptr = invstd->mut_dptr<T>();\n\n    if (input_numel < std::numeric_limits<int32_t>::max()) {\n      batch_norm_collect_statistics_kernel<T, ACC_T, int32_t>\n          <<<blocks, threads, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              input_ptr, static_cast<int32_t>(input_shape.At(0)),\n              static_cast<int32_t>(input_shape.At(1)), static_cast<int32_t>(spatial_size), eps,\n              mean_ptr, invstd_ptr);\n    } else {\n      batch_norm_collect_statistics_kernel<T, ACC_T, int64_t>\n          <<<blocks, threads, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              input_ptr, input_shape.At(0), input_shape.At(1), spatial_size, eps, mean_ptr,\n              invstd_ptr);\n    }\n  }\n};\n\ntemplate<typename T>\nstruct BatchNormStatsChannelLastFunctor final {\n  void operator()(ep::Stream* stream, const user_op::Tensor* input, user_op::Tensor* mean,\n                  user_op::Tensor* invstd, user_op::Tensor* tmp_buffer, const float eps,\n                  const int32_t axis) {\n    using ACC_T = acc_type<T>;\n    const ShapeView& input_shape = 
input->shape_view();\n    const int64_t stride = input_shape.At(axis);\n    const int64_t reduction_size = input_shape.elem_cnt() / stride;\n\n    dim3 block;\n    dim3 grid;\n    flexible_launch_configs(reduction_size, stride, block, grid, true);\n\n    T* staging_data_ptr = nullptr;\n    int32_t* semaphores_ptr = nullptr;\n    if (grid.y > 1) {\n      staging_data_ptr = tmp_buffer->mut_dptr<T>();\n      semaphores_ptr = reinterpret_cast<int32_t*>(tmp_buffer->mut_dptr<char>()\n                                                  + 4 * stride * grid.y * sizeof(T));\n    }\n\n    const T* input_ptr = input->dptr<T>();\n    T* mean_ptr = mean->mut_dptr<T>();\n    T* invstd_ptr = invstd->mut_dptr<T>();\n\n    if (input_shape.elem_cnt() < std::numeric_limits<int32_t>::max()) {\n      batch_norm_collect_statistics_channels_last_kernel<T, ACC_T, int32_t, ELEMENTS_PER_ITER>\n          <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              input_ptr, mean_ptr, invstd_ptr, staging_data_ptr, semaphores_ptr,\n              static_cast<int32_t>(reduction_size), static_cast<int32_t>(stride), eps);\n    } else {\n      batch_norm_collect_statistics_channels_last_kernel<T, ACC_T, int64_t, ELEMENTS_PER_ITER>\n          <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              input_ptr, mean_ptr, invstd_ptr, staging_data_ptr, semaphores_ptr, reduction_size,\n              stride, eps);\n    }\n  }\n};\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuBatchNormStatsKernel final : public user_op::OpKernel {\n public:\n  GpuBatchNormStatsKernel() = default;\n  ~GpuBatchNormStatsKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    user_op::Tensor* invstd = ctx->Tensor4ArgNameAndIndex(\"invstd\", 0);\n\n    const int32_t axis = ctx->Attr<int32_t>(\"axis\");\n    const float eps = ctx->Attr<float>(\"eps\");\n\n    const bool use_channels_last_kernel = axis != 1;\n    if (use_channels_last_kernel) {  // NHWC format\n      user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n      BatchNormStatsChannelLastFunctor<T>()(ctx->stream(), input, mean, invstd, tmp_buffer, eps,\n                                            axis);\n    } else {  // NCHW format\n      BatchNormStatsFunctor<T>()(ctx->stream(), input, mean, invstd, eps);\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_BATCH_NORM_STATS_KERNEL(dtype)                                            \\\n  REGISTER_USER_KERNEL(\"batch_norm_stats\")                                                 \\\n      .SetCreateFn<GpuBatchNormStatsKernel<dtype>>()                                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                     \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn(InferTmpSizeForChannelLastKernel<dtype>)\n\nREGISTER_BATCH_NORM_STATS_KERNEL(float);\nREGISTER_BATCH_NORM_STATS_KERNEL(double);\n\n}  // namespace oneflow\n"
  },
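  {
    "path": "oneflow/user/kernels/examples/welford_merge_sketch.cpp",
    "content": "// NOTE(editor): hypothetical standalone sketch, NOT part of the OneFlow build;\n// the path and every name below are invented for illustration. It shows the\n// pairwise Welford merge that welford_merge_element applies in\n// batch_norm_stats_kernel.cu across PARALLEL_LOADS slots, across threadIdx.y\n// within a block, and across blocks via the staging buffer: two partial\n// (count, mean, m2n) triples combine into one without revisiting the data,\n// and variance is m2n / count at the end.\n#include <algorithm>\n#include <cstdio>\n#include <vector>\n\n// Merge (count_b, mean_b, m2n_b) into (count_a, mean_a, m2n_a), mirroring the\n// arithmetic of welford_merge_element.\nvoid WelfordMerge(long long& count_a, double& mean_a, double& m2n_a, long long count_b,\n                  double mean_b, double m2n_b) {\n  double factor = 1.0 / std::max(1LL, count_a + count_b);\n  double delta = mean_a - mean_b;\n  mean_a = (mean_b * count_b + mean_a * count_a) * factor;\n  m2n_a += m2n_b + delta * delta * count_b * count_a * factor;\n  count_a += count_b;\n}\n\nint main() {\n  std::vector<double> xs = {1.0, 2.0, 4.0, 8.0, 16.0, 32.0};\n  // Accumulate each half with the usual serial Welford update, then merge.\n  long long n[2] = {0, 0};\n  double mean[2] = {0.0, 0.0};\n  double m2n[2] = {0.0, 0.0};\n  for (size_t i = 0; i < xs.size(); ++i) {\n    int slot = i < xs.size() / 2 ? 0 : 1;\n    n[slot] += 1;\n    double delta = xs[i] - mean[slot];\n    mean[slot] += delta / n[slot];\n    m2n[slot] += delta * (xs[i] - mean[slot]);\n  }\n  WelfordMerge(n[0], mean[0], m2n[0], n[1], mean[1], m2n[1]);\n  // Prints count=6 mean=10.5 var=117.25, identical to a direct two-pass computation.\n  std::printf(\"count=%lld mean=%f var=%f\\n\", n[0], mean[0], m2n[0] / n[0]);\n  return 0;\n}\n"
  },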
  {
    "path": "oneflow/user/kernels/bernoulli_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/distributions/common.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n#include \"oneflow/user/kernels/random_mask_generator.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename K>\nclass BernoulliKerenl final : public user_op::OpKernel {\n public:\n  BernoulliKerenl() = default;\n  ~BernoulliKerenl() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCPU));\n    generator->set_current_seed(\n        CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>(\"seed\"))));\n    return std::make_shared<DistributionKernelState>(generator);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const T* in_dptr = in_blob->dptr<T>();\n    K* out_dptr = out_blob->mut_dptr<K>();\n    CHECK_EQ(GetDataType<T>(), in_blob->data_type());\n    CHECK_EQ(GetDataType<K>(), out_blob->data_type());\n    CHECK_EQ(in_blob->shape_view().elem_cnt(), out_blob->shape_view().elem_cnt());\n\n    auto* kernel_state = dynamic_cast<DistributionKernelState*>(state);\n    CHECK_NOTNULL(kernel_state);\n    const auto& generator = kernel_state->generator();\n    CHECK_NOTNULL(generator);\n    const auto& cpu_generator = CHECK_JUST(generator->Get<ep::CPUGenerator>());\n\n    double p = ctx->Attr<double>(\"p\");\n    // prob != -1 means use prob instead of tensor to generate random number\n    if (p != static_cast<double>(-1.0)) {\n      for (int32_t i = 0; i < out_blob->shape_view().elem_cnt(); ++i) {\n        std::bernoulli_distribution dis(p);\n        *(out_dptr + i) = dis(cpu_generator->engine()) ? GetOneVal<K>() : GetZeroVal<K>();\n      }\n    } else {\n      for (int32_t i = 0; i < out_blob->shape_view().elem_cnt(); ++i) {\n        double prob = static_cast<double>(*(in_dptr + i));\n        CHECK(prob >= 0.0 && prob <= 1.0);\n        std::bernoulli_distribution dis(prob);\n        *(out_dptr + i) = dis(cpu_generator->engine()) ? 
GetOneVal<K>() : GetZeroVal<K>();\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_BERNOULLI_KERNEL(in_dtype_pair, out_dtype_pair)                                \\\n  REGISTER_USER_KERNEL(\"bernoulli\")                                                             \\\n      .SetCreateFn<                                                                             \\\n          BernoulliKerenl<OF_PP_PAIR_FIRST(in_dtype_pair), OF_PP_PAIR_FIRST(out_dtype_pair)>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                           \\\n                       && (user_op::HobDataType(\"in\", 0) == OF_PP_PAIR_SECOND(in_dtype_pair))   \\\n                       && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(out_dtype_pair)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_BERNOULLI_KERNEL, FLOATING_DATA_TYPE_SEQ,\n                                 ARITHMETIC_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
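  {
    "path": "oneflow/user/kernels/examples/bernoulli_sampling_sketch.cpp",
    "content": "// NOTE(editor): hypothetical standalone sketch, NOT part of the OneFlow build;\n// the path and all names are invented. It mirrors the two sampling paths of the\n// bernoulli kernel above: either a fixed probability p, or (when p == -1)\n// per-element probabilities read from the input tensor. std::mt19937_64 stands\n// in for OneFlow's seeded ep::CPUGenerator engine.\n#include <cassert>\n#include <cstdio>\n#include <random>\n#include <vector>\n\nstd::vector<int> SampleBernoulli(const std::vector<double>& in, double p,\n                                 std::mt19937_64& engine) {\n  std::vector<int> out(in.size(), 0);\n  if (p != -1.0) {\n    // Fixed-probability path: every element is an independent draw with the same p.\n    std::bernoulli_distribution dis(p);\n    for (size_t i = 0; i < out.size(); ++i) { out[i] = dis(engine) ? 1 : 0; }\n  } else {\n    // Tensor path: element i is drawn with probability in[i].\n    for (size_t i = 0; i < out.size(); ++i) {\n      assert(in[i] >= 0.0 && in[i] <= 1.0);\n      std::bernoulli_distribution dis(in[i]);\n      out[i] = dis(engine) ? 1 : 0;\n    }\n  }\n  return out;\n}\n\nint main() {\n  std::mt19937_64 engine(42);  // fixed seed, as the kernel seeds its generator\n  std::vector<double> probs = {0.0, 0.25, 0.5, 0.75, 1.0};\n  std::vector<int> fixed = SampleBernoulli(probs, 0.5, engine);\n  std::vector<int> per_elem = SampleBernoulli(probs, -1.0, engine);\n  // per_elem[0] is always 0 and per_elem[4] is always 1.\n  for (size_t i = 0; i < probs.size(); ++i) { std::printf(\"%d %d\\n\", fixed[i], per_elem[i]); }\n  return 0;\n}\n"
  },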
  {
    "path": "oneflow/user/kernels/bias_add_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::BroadcastElementwiseBinary> NewPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"a\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n      ctx->device_type(), ep::primitive::BinaryOp::kAdd, data_type, data_type, 3);\n}\n\nclass BiasAddUserKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  BiasAddUserKernel() = default;\n  ~BiasAddUserKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* a_tensor = ctx->Tensor4ArgNameAndIndex(\"a\", 0);\n    const auto* b_tensor = ctx->Tensor4ArgNameAndIndex(\"b\", 0);\n    if (a_tensor->shape_view().elem_cnt() == 0 || b_tensor->shape_view().elem_cnt() == 0) {\n      return;\n    }\n    auto* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int32_t bias_add_axis = ctx->Attr<int32_t>(\"axis\");\n    const int64_t outer_size = a_tensor->shape_view().Count(0, bias_add_axis);\n    const int64_t bias_size = a_tensor->shape_view().At(bias_add_axis);\n    const int64_t inner_size = a_tensor->shape_view().Count(bias_add_axis + 1);\n    auto primitive = NewPrimitive(ctx);\n    const int64_t src0_dims[3] = {outer_size, bias_size, inner_size};\n    const int64_t src1_dims[3] = {1, bias_size, 1};\n    primitive->Launch(ctx->stream(), 3, src0_dims, a_tensor->dptr(), 3, src1_dims, b_tensor->dptr(),\n                      out_tensor->mut_dptr());\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nauto PrimitiveExists() {\n  return hob::make_custom(\"PrimitiveExists\", [](const user_op::KernelRegContext& ctx) -> bool {\n    return NewPrimitive(&ctx).operator bool();\n  });\n}\n\nREGISTER_USER_KERNEL(\"bias_add\")\n    .SetCreateFn<BiasAddUserKernel>()\n    .SetIsMatchedHob(PrimitiveExists() == true)\n    .SetInplaceProposalFn([](const user_op::InferContext& ctx,\n                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"a\", 0, true));\n      return Maybe<void>::Ok();\n    });\n}  // namespace\n\n}  // namespace oneflow\n"
  },
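  {
    "path": "oneflow/user/kernels/examples/bias_add_reference_sketch.cpp",
    "content": "// NOTE(editor): hypothetical standalone sketch, NOT part of the OneFlow build;\n// the path and all names are invented. bias_add_kernel.cpp folds any bias_add\n// into a 3-D broadcast add of shapes (outer, bias, inner) + (1, bias, 1), where\n// outer/bias/inner come from splitting the input shape at the \"axis\" attribute.\n// This file spells out the same indexing with plain loops.\n#include <cstdio>\n#include <vector>\n\n// out[o][c][i] = a[o][c][i] + b[c]: the broadcast the binary primitive performs.\nvoid BiasAddReference(const std::vector<float>& a, const std::vector<float>& b,\n                      std::vector<float>& out, long outer, long bias, long inner) {\n  for (long o = 0; o < outer; ++o) {\n    for (long c = 0; c < bias; ++c) {\n      for (long i = 0; i < inner; ++i) {\n        long idx = (o * bias + c) * inner + i;\n        out[idx] = a[idx] + b[c];\n      }\n    }\n  }\n}\n\nint main() {\n  // A (2, 3, 2) input with axis = 1 gives outer = 2, bias = 3, inner = 2.\n  const long outer = 2, bias = 3, inner = 2;\n  std::vector<float> a(outer * bias * inner, 1.f);\n  std::vector<float> b = {10.f, 20.f, 30.f};\n  std::vector<float> out(a.size());\n  BiasAddReference(a, b, out, outer, bias, inner);\n  for (float v : out) { std::printf(\"%g \", v); }  // 11 11 21 21 31 31 11 11 21 21 31 31\n  std::printf(\"\\n\");\n  return 0;\n}\n"
  },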
  {
    "path": "oneflow/user/kernels/binary_concat_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename IDX>\n__global__ void BinaryConcatKernel(const IDX out_elems, const IDX out_cols, const IDX in0_cols,\n                                   const IDX in1_cols, const T* src0, const T* src1, T* dst) {\n  CUDA_1D_KERNEL_LOOP_T(IDX, i, out_elems) {\n    const IDX row = i / out_cols;\n    const IDX col = i - row * out_cols;\n    const T* src_ptr = nullptr;\n    if (col < in0_cols) {\n      src_ptr = src0 + row * in0_cols + col;\n    } else {\n      src_ptr = src1 + row * in1_cols + (col - in0_cols);\n    }\n    dst[i] = *src_ptr;\n  }\n}\n\ntemplate<typename T, typename IDX>\nvoid LaunchBinaryConcatKernel(ep::Stream* stream, const IDX rows, const IDX in0_cols,\n                              const IDX in1_cols, const void* src0, const void* src1, void* dst) {\n  const IDX out_cols = in0_cols + in1_cols;\n  const IDX out_elems = rows * out_cols;\n  RUN_CUDA_KERNEL((BinaryConcatKernel<T, IDX>), stream, out_elems, out_elems, out_cols, in0_cols,\n                  in1_cols, reinterpret_cast<const T*>(src0), reinterpret_cast<const T*>(src1),\n                  reinterpret_cast<T*>(dst));\n}\n\ntemplate<typename T>\nvoid DispatchIndexType(ep::Stream* stream, const int64_t rows, const int64_t in0_cols,\n                       const int64_t in1_cols, const void* src0, const void* src1, void* dst) {\n  if (rows * (in0_cols + in1_cols) >= (1 >> 30)) {\n    LaunchBinaryConcatKernel<T, int64_t>(stream, rows, in0_cols, in1_cols, src0, src1, dst);\n  } else {\n    LaunchBinaryConcatKernel<T, int32_t>(stream, rows, in0_cols, in1_cols, src0, src1, dst);\n  }\n}\n\nvoid DispatchDataType(ep::Stream* stream, const int64_t rows, const int64_t in0_cols,\n                      const int64_t in1_cols, const void* src0, const void* src1, void* dst) {\n  const uintptr_t src0_ptr = reinterpret_cast<uintptr_t>(src0);\n  const uintptr_t src1_ptr = reinterpret_cast<uintptr_t>(src1);\n  const uintptr_t dst_ptr = reinterpret_cast<uintptr_t>(dst);\n  const auto IsAligned = [&](const size_t alignment) {\n    return src0_ptr % alignment == 0 && src1_ptr % alignment == 0 && dst_ptr % alignment == 0\n           && in0_cols % alignment == 0 && in1_cols % alignment == 0;\n  };\n  if (IsAligned(16)) {\n    DispatchIndexType<uint4>(stream, rows, in0_cols / 16, in1_cols / 16, src0, src1, dst);\n  } else if (IsAligned(8)) {\n    DispatchIndexType<uint2>(stream, rows, in0_cols / 8, in1_cols / 8, src0, src1, dst);\n  } else if (IsAligned(4)) {\n    DispatchIndexType<uint32_t>(stream, rows, in0_cols / 4, in1_cols / 4, src0, src1, dst);\n  } else if (IsAligned(2)) {\n    DispatchIndexType<uint16_t>(stream, rows, in0_cols / 2, in1_cols / 2, src0, src1, dst);\n  } 
else {\n    DispatchIndexType<uint8_t>(stream, rows, in0_cols, in1_cols, src0, src1, dst);\n  }\n}\n\nvoid DispatchBinaryConcat(ep::Stream* stream, const int64_t elem_size, const int64_t rows,\n                          const int64_t in0_cols, const int64_t in1_cols, const void* src0,\n                          const void* src1, void* dst) {\n  DispatchDataType(stream, rows, in0_cols * elem_size, in1_cols * elem_size, src0, src1, dst);\n}\n\nclass ConcatKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ConcatKernel() = default;\n  ~ConcatKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const DataType data_type = out_tensor->data_type();\n    if (out_tensor->shape_view().elem_cnt() == 0) { return; }\n    const int64_t axis = ctx->Attr<int64_t>(\"axis\");\n    CHECK_GE(axis, 0);\n    const int64_t num_axes = out_tensor->shape_view().NumAxes();\n    CHECK_LT(axis, num_axes);\n    const int64_t out_cols = out_tensor->shape_view().Count(axis);\n    const int64_t rows = out_tensor->shape_view().elem_cnt() / out_cols;\n    CHECK_GT(rows, 0);\n\n    CHECK_EQ(ctx->input_size(\"in\"), 2);\n    const user_op::Tensor* in0_tensor = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* in1_tensor = ctx->Tensor4ArgNameAndIndex(\"in\", 1);\n    CHECK_EQ(in0_tensor->data_type(), data_type);\n    CHECK_EQ(in1_tensor->data_type(), data_type);\n    if (in0_tensor->shape_view().elem_cnt() == 0) {\n      CHECK_EQ(in1_tensor->shape_view(), out_tensor->shape_view());\n      Memcpy<DeviceType::kCUDA>(ctx->stream(), out_tensor->mut_dptr(), in1_tensor->dptr(),\n                                out_tensor->shape_view().elem_cnt() * GetSizeOfDataType(data_type));\n      return;\n    }\n    if (in1_tensor->shape_view().elem_cnt() == 0) {\n      CHECK_EQ(in0_tensor->shape_view(), out_tensor->shape_view());\n      Memcpy<DeviceType::kCUDA>(ctx->stream(), out_tensor->mut_dptr(), in0_tensor->dptr(),\n                                out_tensor->shape_view().elem_cnt() * GetSizeOfDataType(data_type));\n      return;\n    }\n    CHECK_EQ(in0_tensor->shape_view().NumAxes(), num_axes);\n    CHECK_EQ(in1_tensor->shape_view().NumAxes(), num_axes);\n    for (int64_t i = 0; i < num_axes; ++i) {\n      if (i != axis) {\n        CHECK_EQ(in0_tensor->shape_view().At(i), out_tensor->shape_view().At(i));\n        CHECK_EQ(in1_tensor->shape_view().At(i), out_tensor->shape_view().At(i));\n      }\n    }\n    CHECK_EQ(in0_tensor->shape_view().At(axis) + in1_tensor->shape_view().At(axis),\n             out_tensor->shape_view().At(axis));\n    const int64_t in0_cols = in0_tensor->shape_view().Count(axis);\n    const int64_t in1_cols = in1_tensor->shape_view().Count(axis);\n\n    DispatchBinaryConcat(ctx->stream(), GetSizeOfDataType(data_type), rows, in0_cols, in1_cols,\n                         in0_tensor->dptr(), in1_tensor->dptr(), out_tensor->mut_dptr());\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\nREGISTER_USER_KERNEL(\"cat\")\n    .SetCreateFn<ConcatKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)\n                     && (user_op::HobInputSize(\"in\") == 2))\n    .SetPriority(user_op::kKernelPriorityOptimized);\n\n}  // namespace oneflow\n"
  },
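  {
    "path": "oneflow/user/kernels/examples/binary_concat_reference_sketch.cpp",
    "content": "// NOTE(editor): hypothetical standalone sketch, NOT part of the OneFlow build;\n// the path and all names are invented. binary_concat_kernel.cu views the concat\n// as rows of (in0_cols + in1_cols) elements: each flat output index i is split\n// into (row, col), and col decides whether the element comes from src0 or src1.\n// The CUDA kernel additionally widens the element type (uint4/uint2/...) when\n// pointers and column counts are suitably aligned; the indexing is unchanged.\n#include <cstdio>\n#include <vector>\n\ntemplate<typename T>\nvoid BinaryConcatReference(long rows, long in0_cols, long in1_cols, const T* src0, const T* src1,\n                           T* dst) {\n  const long out_cols = in0_cols + in1_cols;\n  for (long i = 0; i < rows * out_cols; ++i) {\n    const long row = i / out_cols;\n    const long col = i - row * out_cols;\n    // Columns left of in0_cols come from src0, the rest from src1.\n    dst[i] = col < in0_cols ? src0[row * in0_cols + col] : src1[row * in1_cols + (col - in0_cols)];\n  }\n}\n\nint main() {\n  // Two 2x2 blocks concatenated along the last axis -> one 2x4 block.\n  std::vector<int> src0 = {1, 2, 3, 4};\n  std::vector<int> src1 = {5, 6, 7, 8};\n  std::vector<int> dst(8, 0);\n  BinaryConcatReference<int>(2, 2, 2, src0.data(), src1.data(), dst.data());\n  for (int v : dst) { std::printf(\"%d \", v); }  // 1 2 5 6 3 4 7 8\n  std::printf(\"\\n\");\n  return 0;\n}\n"
  },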
  {
    "path": "oneflow/user/kernels/binary_cross_entropy_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/loss_kernel_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\nnamespace {\n\nusing namespace loss;\n\ntemplate<typename T>\nvoid ComputeBinaryCrossEntropyOut(int64_t elem_cnt, const T* input, const T* target, T* out,\n                                  const T* weight) {\n  T negative_100 = static_cast<T>(-100);\n  FOR_RANGE(int64_t, i, 0, elem_cnt) {\n    T input_val = input[i];\n    T target_val = target[i];\n    CHECK_LE(input_val, 1.0);\n    CHECK_GE(input_val, 0.0);\n    out[i] = (target_val - 1) * std::max(static_cast<T>(std::log(1.0 - input_val)), negative_100)\n             - target_val * std::max(static_cast<T>(std::log(input_val)), negative_100);\n    if (weight != nullptr) { out[i] *= weight[i]; }\n  }\n}\n\ntemplate<typename T>\nvoid ComputeBinaryCrossEntropyGradOut(int64_t elem_cnt, const T* input, const T* target,\n                                      const T* dy, T* dx, const T* weight) {\n  const T eps = static_cast<T>(1e-12);\n  FOR_RANGE(int64_t, i, 0, elem_cnt) {\n    T input_val = input[i];\n    T target_val = target[i];\n    T dy_val = dy[i];\n    dx[i] = dy_val * (input_val - target_val)\n            / (std::max((static_cast<T>(1.0) - input_val) * input_val, eps));\n    if (weight != nullptr) { dx[i] *= weight[i]; }\n  }\n}\ntemplate<typename T>\nclass BinaryCrossEntropyKernel final : public user_op::OpKernel {\n public:\n  BinaryCrossEntropyKernel() = default;\n  ~BinaryCrossEntropyKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* input_blob = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* target_blob = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    auto* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const int64_t elem_cnt = input_blob->shape_view().elem_cnt();\n\n    const T* input = input_blob->dptr<T>();\n    const T* target = target_blob->dptr<T>();\n    T* out = out_blob->mut_dptr<T>();\n    const T* weight =\n        ctx->has_input(\"weight\", 0) ? 
ctx->Tensor4ArgNameAndIndex(\"weight\", 0)->dptr<T>() : nullptr;\n\n    ComputeBinaryCrossEntropyOut(elem_cnt, input, target, out, weight);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass BinaryCrossEntropyGradKernel final : public user_op::OpKernel {\n public:\n  BinaryCrossEntropyGradKernel() = default;\n  ~BinaryCrossEntropyGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* input_blob = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* target_blob = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    const auto* dy_blob = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    auto* dx_blob = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const int64_t elem_cnt = input_blob->shape_view().elem_cnt();\n\n    const T* dy = dy_blob->dptr<T>();\n    const T* input = input_blob->dptr<T>();\n    const T* target = target_blob->dptr<T>();\n    T* dx = dx_blob->mut_dptr<T>();\n    const T* weight =\n        ctx->has_input(\"weight\", 0) ? ctx->Tensor4ArgNameAndIndex(\"weight\", 0)->dptr<T>() : nullptr;\n    ComputeBinaryCrossEntropyGradOut(elem_cnt, input, target, dy, dx, weight);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n#define REGISTER_BINARY_CROSS_ENTROPY_KERNEL(dtype)                                        \\\n  REGISTER_USER_KERNEL(\"binary_cross_entropy\")                                             \\\n      .SetCreateFn<BinaryCrossEntropyKernel<dtype>>()                                      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                      \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)  \\\n                       && (user_op::HobDataType(\"target\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\n#define REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(dtype)                                   \\\n  REGISTER_USER_KERNEL(\"binary_cross_entropy_grad\")                                        \\\n      .SetCreateFn<BinaryCrossEntropyGradKernel<dtype>>()                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                      \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)  \\\n                       && (user_op::HobDataType(\"target\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value)     \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_BINARY_CROSS_ENTROPY_KERNEL(float)\nREGISTER_BINARY_CROSS_ENTROPY_KERNEL(double)\nREGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(float)\nREGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(double)\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
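  {
    "path": "oneflow/user/kernels/examples/bce_clamped_log_sketch.cpp",
    "content": "// NOTE(editor): hypothetical standalone sketch, NOT part of the OneFlow build;\n// the path and all names are invented. ComputeBinaryCrossEntropyOut clamps both\n// log(x) and log(1 - x) at -100, so inputs of exactly 0 or 1 produce a large but\n// finite loss instead of infinity. Away from the endpoints the expression is\n// exactly -[t*log(x) + (1 - t)*log(1 - x)].\n#include <algorithm>\n#include <cmath>\n#include <cstdio>\n\ndouble BceLoss(double input, double target) {\n  const double kNeg100 = -100.0;\n  return (target - 1.0) * std::max(std::log(1.0 - input), kNeg100)\n         - target * std::max(std::log(input), kNeg100);\n}\n\nint main() {\n  // Interior point: matches the textbook formula, -log(0.8) ~= 0.223144.\n  std::printf(\"%f\\n\", BceLoss(0.8, 1.0));\n  // Endpoints: the clamp keeps the loss finite (100 instead of inf, 0 for a perfect hit).\n  std::printf(\"%f\\n\", BceLoss(0.0, 1.0));\n  std::printf(\"%f\\n\", BceLoss(1.0, 1.0));\n  return 0;\n}\n"
  },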
  {
    "path": "oneflow/user/kernels/binary_cross_entropy_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/user/kernels/loss_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\nnamespace user_op {\nnamespace {\n\nusing namespace loss;\n\ntemplate<typename T>\nstruct BinaryCrossEntropyFunctor {\n  T zero_;\n  T one_;\n  T negative_hundred_;\n  BinaryCrossEntropyFunctor()\n      : zero_(GetZeroVal<T>()), one_(GetOneVal<T>()), negative_hundred_(static_cast<T>(-100)) {}\n  __device__ __forceinline__ T operator()(T input_val, T target_val) const {\n    assert(input_val >= zero_);\n    assert(input_val <= one_);\n    return (target_val - one_) * max(static_cast<T>(log(one_ - input_val)), negative_hundred_)\n           - target_val * max(static_cast<T>(log(input_val)), negative_hundred_);\n  }\n\n  __device__ __forceinline__ T operator()(T input_val, T target_val, T weight_val) const {\n    return (*this)(input_val, target_val) * weight_val;\n  }\n};\n\ntemplate<>\nstruct BinaryCrossEntropyFunctor<float> {\n  float zero_;\n  float one_;\n  float negative_hundred_;\n  BinaryCrossEntropyFunctor() : zero_(0.f), one_(1.f), negative_hundred_(-100.f) {}\n  __device__ __forceinline__ float operator()(float input_val, float target_val) const {\n    assert(input_val >= zero_);\n    assert(input_val <= one_);\n    return (target_val - one_) * max(logf(one_ - input_val), negative_hundred_)\n           - target_val * max(logf(input_val), negative_hundred_);\n  }\n\n  __device__ __forceinline__ float operator()(float input_val, float target_val,\n                                              float weight_val) const {\n    return (*this)(input_val, target_val) * weight_val;\n  }\n};\n\ntemplate<>\nstruct BinaryCrossEntropyFunctor<half> {\n  BinaryCrossEntropyFunctor<float> float_functor;\n  __device__ __forceinline__ half operator()(half input_val, half target_val) const {\n    return __float2half(float_functor(__half2float(input_val), __half2float(target_val)));\n  }\n\n  __device__ __forceinline__ half operator()(half input_val, half target_val,\n                                             half weight_val) const {\n    return (*this)(input_val, target_val) * weight_val;\n  }\n};\n\ntemplate<typename T>\nstruct BinaryCrossEntropyGradFunctor {\n  T eps_;\n  T one_;\n  BinaryCrossEntropyGradFunctor() : eps_(static_cast<T>(1e-12)), one_(GetOneVal<T>()) {}\n  __device__ __forceinline__ T operator()(T input_val, T target_val, T dy_val) const {\n    return dy_val * (input_val - target_val) / max((one_ - input_val) * input_val, eps_);\n  }\n  __device__ __forceinline__ T operator()(T input_val, T target_val, T dy_val, T weight_val) const {\n    return (*this)(input_val, target_val, dy_val) * weight_val;\n  }\n};\n\ntemplate<>\nstruct BinaryCrossEntropyGradFunctor<half> {\n  BinaryCrossEntropyGradFunctor<float> float_functor;\n  BinaryCrossEntropyGradFunctor() 
{}\n  __device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val) const {\n    return __float2half(\n        float_functor(__half2float(input_val), __half2float(target_val), __half2float(dy_val)));\n  }\n  __device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val,\n                                             half weight_val) const {\n    return __float2half(float_functor(__half2float(input_val), __half2float(target_val),\n                                      __half2float(dy_val), __half2float(weight_val)));\n  }\n};\n\ntemplate<typename T>\nclass BinaryCrossEntropyKernel final : public user_op::OpKernel {\n public:\n  BinaryCrossEntropyKernel() = default;\n  ~BinaryCrossEntropyKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* input_blob = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* target_blob = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    auto* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const int64_t elem_cnt = input_blob->shape_view().elem_cnt();\n\n    const T* input = input_blob->dptr<T>();\n    const T* target = target_blob->dptr<T>();\n    T* out = out_blob->mut_dptr<T>();\n\n    if (ctx->has_input(\"weight\", 0)) {\n      const T* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0)->dptr<T>();\n      OF_CUDA_CHECK(\n          (cuda::elementwise::Ternary(BinaryCrossEntropyFunctor<T>(), elem_cnt, out, input, target,\n                                      weight, ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n    } else {\n      OF_CUDA_CHECK(\n          (cuda::elementwise::Binary(BinaryCrossEntropyFunctor<T>(), elem_cnt, out, input, target,\n                                     ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass BinaryCrossEntropyGradKernel final : public user_op::OpKernel {\n public:\n  BinaryCrossEntropyGradKernel() = default;\n  ~BinaryCrossEntropyGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* input_blob = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* target_blob = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    const auto* dy_blob = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    auto* dx_blob = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const int64_t elem_cnt = input_blob->shape_view().elem_cnt();\n\n    const T* dy = dy_blob->dptr<T>();\n    const T* input = input_blob->dptr<T>();\n    const T* target = target_blob->dptr<T>();\n    T* dx = dx_blob->mut_dptr<T>();\n\n    if (ctx->has_input(\"weight\", 0)) {\n      const T* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0)->dptr<T>();\n      using FunctorT = BinaryCrossEntropyGradFunctor<T>;\n      using FactoryT = cuda::elementwise::SimpleFactory<FunctorT>;\n      OF_CUDA_CHECK((cuda::elementwise::GenericLauncher<FactoryT, T, T, T, T, T>::Launch(\n          FactoryT(FunctorT()), elem_cnt, dx, input, target, dy, weight,\n          ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n    } else {\n      OF_CUDA_CHECK((cuda::elementwise::Ternary(\n          BinaryCrossEntropyGradFunctor<T>(), elem_cnt, dx, input, target, dy,\n          ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const 
override { return false; }\n};\n\n}  // namespace\n\n#define REGISTER_BINARY_CROSS_ENTROPY_KERNEL(dtype)                                        \\\n  REGISTER_USER_KERNEL(\"binary_cross_entropy\")                                             \\\n      .SetCreateFn<BinaryCrossEntropyKernel<dtype>>()                                      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                     \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)  \\\n                       && (user_op::HobDataType(\"target\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\n#define REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(dtype)                                   \\\n  REGISTER_USER_KERNEL(\"binary_cross_entropy_grad\")                                        \\\n      .SetCreateFn<BinaryCrossEntropyGradKernel<dtype>>()                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                     \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)  \\\n                       && (user_op::HobDataType(\"target\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value)     \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_BINARY_CROSS_ENTROPY_KERNEL(half)\nREGISTER_BINARY_CROSS_ENTROPY_KERNEL(float)\nREGISTER_BINARY_CROSS_ENTROPY_KERNEL(double)\n\nREGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(half)\nREGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(float)\nREGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(double)\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
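  {
    "path": "oneflow/user/kernels/examples/bce_grad_sketch.cpp",
    "content": "// NOTE(editor): hypothetical standalone sketch, NOT part of the OneFlow build;\n// the path and all names are invented. BinaryCrossEntropyGradFunctor computes\n// dx = dy * (x - t) / max((1 - x) * x, eps) with eps = 1e-12: the analytic\n// derivative of the BCE loss, with the denominator clamped so the x -> 0 and\n// x -> 1 endpoints give a huge but finite gradient. The half specialization in\n// the kernel above promotes to float, applies this, and converts back.\n#include <algorithm>\n#include <cstdio>\n\ndouble BceGrad(double input, double target, double dy) {\n  const double kEps = 1e-12;\n  return dy * (input - target) / std::max((1.0 - input) * input, kEps);\n}\n\nint main() {\n  // d/dx of -[t*log(x) + (1 - t)*log(1 - x)] at x = 0.8, t = 1 is -1/0.8 = -1.25.\n  std::printf(\"%f\\n\", BceGrad(0.8, 1.0, 1.0));\n  // At x = 0 the eps clamp yields -1e12 rather than a division by zero.\n  std::printf(\"%e\\n\", BceGrad(0.0, 1.0, 1.0));\n  return 0;\n}\n"
  },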
  {
    "path": "oneflow/user/kernels/binary_cross_entropy_with_logits_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/user/kernels/loss_kernel_util.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h\"\n\nnamespace oneflow {\nnamespace user_op {\nnamespace {\n\nusing namespace loss;\n\ntemplate<typename T>\ninline T ComputeMaxVal(const T x) {\n  T y = -x;\n  return y < 0 ? 0 : y;\n}\n\ntemplate<typename T>\ninline T CalSigmoid(const T x) {\n  const T half_of_one = static_cast<T>(0.5);\n  return half_of_one * std::tanh(half_of_one * x) + half_of_one;\n}\n\ntemplate<typename INPUT_T, typename TARGET_T>\nvoid ComputeBinaryCrossEntropyWithLogitsOut(int64_t elem_cnt, const INPUT_T* input,\n                                            const TARGET_T* target, TARGET_T* out,\n                                            const TARGET_T* weight,\n                                            const TARGET_T* pos_weight_processed) {\n  FOR_RANGE(int64_t, i, 0, elem_cnt) {\n    TARGET_T input_val = static_cast<TARGET_T>(input[i]);\n    TARGET_T target_val = target[i];\n    TARGET_T max_val = ComputeMaxVal(input_val);\n    if (out != nullptr) {\n      if (pos_weight_processed == nullptr) {\n        out[i] = (1 - target_val) * input_val + max_val\n                 + (std::log(std::exp(-max_val) + std::exp(-input_val - max_val)));\n      } else {\n        TARGET_T pos_weight_processed_val = pos_weight_processed[i] - target_val + 1;\n        out[i] = (1 - target_val) * input_val\n                 + (pos_weight_processed_val\n                    * (std::log(std::exp(-max_val) + std::exp(-input_val - max_val)) + max_val));\n      }\n    }\n    if (weight != nullptr && out != nullptr) { out[i] *= weight[i]; }\n  }\n}\n\ntemplate<typename INPUT_T, typename TARGET_T>\nvoid ComputeBinaryCrossEntropyWithLogitsGradOut(int64_t elem_cnt, const INPUT_T* input,\n                                                const TARGET_T* target, const TARGET_T* dy,\n                                                INPUT_T* dx, const TARGET_T* weight,\n                                                const TARGET_T* pos_weight_processed) {\n  FOR_RANGE(int64_t, i, 0, elem_cnt) {\n    INPUT_T input_val = input[i];\n    TARGET_T target_val = target[i];\n    TARGET_T dy_val = dy[i];\n    TARGET_T input_sigmoid = static_cast<TARGET_T>(CalSigmoid(input_val));\n    TARGET_T dx_i_buffer = 0.0;\n    if (pos_weight_processed == nullptr) {\n      dx_i_buffer = (input_sigmoid - target_val) * dy_val;\n    } else {\n      dx_i_buffer =\n          dy_val\n          * ((pos_weight_processed[i] + 1 - target_val) * input_sigmoid - pos_weight_processed[i]);\n    }\n\n    if (weight != nullptr) { dx_i_buffer *= weight[i]; }\n    dx[i] = static_cast<INPUT_T>(dx_i_buffer);\n  }\n}\n\ntemplate<typename INPUT_T, typename TARGET_T>\nclass 
BinaryCrossEntropyWithLogitsKernel final : public user_op::OpKernel {\n public:\n  BinaryCrossEntropyWithLogitsKernel() = default;\n  ~BinaryCrossEntropyWithLogitsKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* input_blob = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* target_blob = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    auto* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    auto* tmp_buffer_blob = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    const int64_t elem_cnt = input_blob->shape_view().elem_cnt();\n\n    const INPUT_T* input = input_blob->dptr<INPUT_T>();\n    const TARGET_T* target = target_blob->dptr<TARGET_T>();\n    TARGET_T* out = out_blob->mut_dptr<TARGET_T>();\n\n    const TARGET_T* weight = ctx->has_input(\"weight\", 0)\n                                 ? ctx->Tensor4ArgNameAndIndex(\"weight\", 0)->dptr<TARGET_T>()\n                                 : nullptr;\n\n    TARGET_T* pos_weight_processed = nullptr;\n\n    if (ctx->Attr<bool>(\"has_pos_weight\")) {\n      pos_weight_processed = tmp_buffer_blob->mut_dptr<TARGET_T>();\n      const TARGET_T* pos_weight = ctx->Tensor4ArgNameAndIndex(\"pos_weight\", 0)->dptr<TARGET_T>();\n\n      Shape pos_weight_shape = Shape::Ones(target_blob->shape_view().NumAxes());\n      pos_weight_shape.Set(pos_weight_shape.NumAxes() - 1,\n                           ctx->Tensor4ArgNameAndIndex(\"pos_weight\", 0)->shape_view().elem_cnt());\n      auto bcast_mul =\n          ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n              ctx->device_type(), ep::primitive::BinaryOp::kMul, target_blob->data_type(),\n              target_blob->data_type(), target_blob->shape_view().NumAxes());\n      CHECK(bcast_mul);\n      bcast_mul->Launch(ctx->stream(), target_blob->shape_view().NumAxes(),\n                        target_blob->shape_view().ptr(), target, pos_weight_shape.NumAxes(),\n                        pos_weight_shape.dim_vec().data(), pos_weight, pos_weight_processed);\n    }\n    ComputeBinaryCrossEntropyWithLogitsOut(elem_cnt, input, target, out, weight,\n                                           pos_weight_processed);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename INPUT_T, typename TARGET_T>\nclass BinaryCrossEntropyWithLogitsGradKernel final : public user_op::OpKernel {\n public:\n  BinaryCrossEntropyWithLogitsGradKernel() = default;\n  ~BinaryCrossEntropyWithLogitsGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* input_blob = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* target_blob = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    const auto* dy_blob = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    auto* dx_blob = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    auto* tmp_buffer_blob = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    const int64_t elem_cnt = input_blob->shape_view().elem_cnt();\n\n    const TARGET_T* dy = dy_blob->dptr<TARGET_T>();\n    const INPUT_T* input = input_blob->dptr<INPUT_T>();\n    const TARGET_T* target = target_blob->dptr<TARGET_T>();\n    INPUT_T* dx = dx_blob->mut_dptr<INPUT_T>();\n    const TARGET_T* weight = ctx->has_input(\"weight\", 0)\n                                 ? 
ctx->Tensor4ArgNameAndIndex(\"weight\", 0)->dptr<TARGET_T>()\n                                 : nullptr;\n\n    TARGET_T* pos_weight_processed = nullptr;\n\n    if (ctx->Attr<bool>(\"has_pos_weight\")) {\n      pos_weight_processed = tmp_buffer_blob->mut_dptr<TARGET_T>();\n      const TARGET_T* pos_weight = ctx->Tensor4ArgNameAndIndex(\"pos_weight\", 0)->dptr<TARGET_T>();\n\n      Shape pos_weight_shape = Shape::Ones(target_blob->shape_view().NumAxes());\n      pos_weight_shape.Set(pos_weight_shape.NumAxes() - 1,\n                           ctx->Tensor4ArgNameAndIndex(\"pos_weight\", 0)->shape_view().elem_cnt());\n      auto bcast_mul =\n          ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n              ctx->device_type(), ep::primitive::BinaryOp::kMul, target_blob->data_type(),\n              target_blob->data_type(), target_blob->shape_view().NumAxes());\n      CHECK(bcast_mul);\n      bcast_mul->Launch(ctx->stream(), target_blob->shape_view().NumAxes(),\n                        target_blob->shape_view().ptr(), target, pos_weight_shape.NumAxes(),\n                        pos_weight_shape.dim_vec().data(), pos_weight, pos_weight_processed);\n    }\n    ComputeBinaryCrossEntropyWithLogitsGradOut(elem_cnt, input, target, dy, dx, weight,\n                                               pos_weight_processed);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nuser_op::InferTmpSizeFn GenFwInferTmpSizeFn() {\n  return [](user_op::InferContext* ctx) {\n    const int64_t n = ctx->InputShape(\"target\", 0).elem_cnt();\n    size_t tmp_buffer_size = 0;\n    if (ctx->Attr<bool>(\"has_pos_weight\")) { tmp_buffer_size += GetCudaAlignedSize(n * sizeof(T)); }\n    return tmp_buffer_size;\n  };\n}\n\ntemplate<typename T>\nuser_op::InferTmpSizeFn GenBwInferTmpSizeFn() {\n  return [](user_op::InferContext* ctx) {\n    const int64_t n = ctx->InputShape(\"target\", 0).elem_cnt();\n    size_t tmp_buffer_size = 0;\n    if (ctx->Attr<bool>(\"has_pos_weight\")) { tmp_buffer_size += GetCudaAlignedSize(n * sizeof(T)); }\n    return tmp_buffer_size;\n  };\n}\n\n}  // namespace\n\n#define REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_KERNEL(input_dtype, target_dtype)               \\\n  REGISTER_USER_KERNEL(\"binary_cross_entropy_with_logits\")                                        \\\n      .SetCreateFn<BinaryCrossEntropyWithLogitsKernel<input_dtype, target_dtype>>()               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                             \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<input_dtype>::value)   \\\n                       && (user_op::HobDataType(\"target\", 0) == GetDataType<target_dtype>::value) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<target_dtype>::value))   \\\n      .SetInferTmpSizeFn(GenFwInferTmpSizeFn<target_dtype>());\n\n#define REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_GRAD_KERNEL(input_dtype, target_dtype)          \\\n  REGISTER_USER_KERNEL(\"binary_cross_entropy_with_logits_grad\")                                   \\\n      .SetCreateFn<BinaryCrossEntropyWithLogitsGradKernel<input_dtype, target_dtype>>()           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                             \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<input_dtype>::value)   \\\n                       && (user_op::HobDataType(\"target\", 0) == 
GetDataType<target_dtype>::value) \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<target_dtype>::value)     \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<input_dtype>::value))     \\\n      .SetInferTmpSizeFn(GenBwInferTmpSizeFn<target_dtype>());\n\nREGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_KERNEL(float, float)\nREGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_KERNEL(float, double)\nREGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_KERNEL(double, float)\nREGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_KERNEL(double, double)\n\nREGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_GRAD_KERNEL(float, float)\nREGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_GRAD_KERNEL(float, double)\nREGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_GRAD_KERNEL(double, float)\nREGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_GRAD_KERNEL(double, double)\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
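  {
    "path": "oneflow/user/kernels/examples/bce_with_logits_stable_sketch.cpp",
    "content": "// NOTE(editor): hypothetical standalone sketch, NOT part of the OneFlow build;\n// the path and all names are invented. The with-logits kernels never evaluate\n// log(1 + exp(-x)) directly; with max_val = max(-x, 0) they use the identity\n// log(1 + exp(-x)) = max_val + log(exp(-max_val) + exp(-x - max_val)), which\n// cannot overflow. This file compares the stabilized form against the naive one.\n#include <algorithm>\n#include <cmath>\n#include <cstdio>\n\n// Stabilized loss, as in ComputeBinaryCrossEntropyWithLogitsOut (no pos_weight).\ndouble BceWithLogits(double x, double t) {\n  const double max_val = std::max(-x, 0.0);\n  return (1.0 - t) * x + max_val + std::log(std::exp(-max_val) + std::exp(-x - max_val));\n}\n\ndouble BceWithLogitsNaive(double x, double t) {\n  return (1.0 - t) * x + std::log(1.0 + std::exp(-x));\n}\n\nint main() {\n  // Moderate logits: both forms agree.\n  std::printf(\"%f %f\\n\", BceWithLogits(2.0, 1.0), BceWithLogitsNaive(2.0, 1.0));\n  // x = -1000, t = 1: the true loss is ~1000; exp(1000) overflows the naive form to inf.\n  std::printf(\"%f %f\\n\", BceWithLogits(-1000.0, 1.0), BceWithLogitsNaive(-1000.0, 1.0));\n  return 0;\n}\n"
  },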
  {
    "path": "oneflow/user/kernels/binary_cross_entropy_with_logits_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/user/kernels/loss_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h\"\n\nnamespace oneflow {\nnamespace user_op {\nnamespace {\n\nusing namespace loss;\n\nenum class WeightType {\n  kNone,\n  kWeight,\n  kPosWeight,\n  kBoth,\n};\n\ntemplate<typename INPUT_T, typename TARGET_T, WeightType WEIGHT_TYPE>\nstruct BinaryCrossEntropyWithLogitsFunctor;\n\ntemplate<typename INPUT_T, typename TARGET_T>\nstruct BinaryCrossEntropyWithLogitsFunctor<INPUT_T, TARGET_T, WeightType::kNone> {\n  TARGET_T zero_;\n  TARGET_T one_;\n  BinaryCrossEntropyWithLogitsFunctor()\n      : zero_(GetZeroVal<TARGET_T>()), one_(GetOneVal<TARGET_T>()) {}\n  __device__ __forceinline__ TARGET_T operator()(INPUT_T input_val, TARGET_T target_val) const {\n    const TARGET_T input_val_ = static_cast<TARGET_T>(input_val);\n    const TARGET_T max_val = -input_val_ < zero_ ? zero_ : -input_val_;\n    return (one_ - target_val) * input_val_ + max_val\n           + (log(exp(-max_val) + exp(-input_val_ - max_val)));\n  }\n};\n\ntemplate<typename INPUT_T, typename TARGET_T>\nstruct BinaryCrossEntropyWithLogitsFunctor<INPUT_T, TARGET_T, WeightType::kPosWeight> {\n  TARGET_T zero_;\n  TARGET_T one_;\n  BinaryCrossEntropyWithLogitsFunctor()\n      : zero_(GetZeroVal<TARGET_T>()), one_(GetOneVal<TARGET_T>()) {}\n  __device__ __forceinline__ TARGET_T operator()(INPUT_T input_val, TARGET_T target_val,\n                                                 TARGET_T weight_val) const {\n    const TARGET_T input_val_ = static_cast<TARGET_T>(input_val);\n    const TARGET_T max_val = -input_val_ < zero_ ? zero_ : -input_val_;\n    const TARGET_T pos_weight_processed_val = weight_val - target_val + one_;\n    return (one_ - target_val) * input_val_\n           + (pos_weight_processed_val\n              * (log(exp(-max_val) + exp(-input_val_ - max_val)) + max_val));\n  }\n};\n\ntemplate<typename INPUT_T>\nstruct BinaryCrossEntropyWithLogitsFunctor<INPUT_T, float, WeightType::kNone> {\n  float zero_;\n  float one_;\n  BinaryCrossEntropyWithLogitsFunctor() : zero_(0.f), one_(1.f) {}\n  __device__ __forceinline__ float operator()(INPUT_T input_val, float target_val) const {\n    const float input_val_ = static_cast<float>(input_val);\n    const float max_val = -input_val_ < zero_ ? 
zero_ : -input_val_;\n    return (one_ - target_val) * input_val_ + max_val\n           + (logf(expf(-max_val) + expf(-input_val_ - max_val)));\n  }\n};\n\ntemplate<typename INPUT_T>\nstruct BinaryCrossEntropyWithLogitsFunctor<INPUT_T, float, WeightType::kPosWeight> {\n  float zero_;\n  float one_;\n  BinaryCrossEntropyWithLogitsFunctor() : zero_(0.f), one_(1.f) {}\n  __device__ __forceinline__ float operator()(INPUT_T input_val, float target_val,\n                                              float weight_val) const {\n    const float input_val_ = static_cast<float>(input_val);\n    const float max_val = -input_val_ < zero_ ? zero_ : -input_val_;\n    const float pos_weight_processed_val = weight_val - target_val + one_;\n    return (one_ - target_val) * input_val_\n           + (pos_weight_processed_val\n              * (logf(expf(-max_val) + expf(-input_val_ - max_val)) + max_val));\n  }\n};\n\ntemplate<typename INPUT_T, typename TARGET_T>\nstruct BinaryCrossEntropyWithLogitsFunctor<INPUT_T, TARGET_T, WeightType::kWeight> {\n  BinaryCrossEntropyWithLogitsFunctor<INPUT_T, TARGET_T, WeightType::kNone> f;\n  __device__ __forceinline__ TARGET_T operator()(INPUT_T input_val, TARGET_T target_val,\n                                                 TARGET_T weight_val) const {\n    return f(input_val, target_val) * weight_val;\n  }\n};\n\ntemplate<typename INPUT_T, typename TARGET_T>\nstruct BinaryCrossEntropyWithLogitsFunctor<INPUT_T, TARGET_T, WeightType::kBoth> {\n  BinaryCrossEntropyWithLogitsFunctor<INPUT_T, TARGET_T, WeightType::kPosWeight> f;\n  __device__ __forceinline__ TARGET_T operator()(INPUT_T input_val, TARGET_T target_val,\n                                                 TARGET_T weight_val,\n                                                 TARGET_T pos_weight_val) const {\n    return f(input_val, target_val, pos_weight_val) * weight_val;\n  }\n};\n\ntemplate<typename INPUT_T>\nstruct BinaryCrossEntropyWithLogitsFunctor<INPUT_T, half, WeightType::kNone> {\n  BinaryCrossEntropyWithLogitsFunctor<INPUT_T, float, WeightType::kNone> f;\n  __device__ __forceinline__ half operator()(INPUT_T input_val, half target_val) const {\n    return __float2half(f(input_val, __half2float(target_val)));\n  }\n};\ntemplate<typename INPUT_T>\nstruct BinaryCrossEntropyWithLogitsFunctor<INPUT_T, half, WeightType::kPosWeight> {\n  BinaryCrossEntropyWithLogitsFunctor<INPUT_T, float, WeightType::kPosWeight> f;\n  __device__ __forceinline__ half operator()(INPUT_T input_val, half target_val,\n                                             half weight_val) const {\n    return __float2half(f(input_val, __half2float(target_val), __half2float(weight_val)));\n  }\n};\ntemplate<typename INPUT_T>\nstruct BinaryCrossEntropyWithLogitsFunctor<INPUT_T, half, WeightType::kWeight> {\n  BinaryCrossEntropyWithLogitsFunctor<INPUT_T, float, WeightType::kWeight> f;\n  __device__ __forceinline__ half operator()(INPUT_T input_val, half target_val,\n                                             half weight_val) const {\n    return __float2half(f(input_val, __half2float(target_val), __half2float(weight_val)));\n  }\n};\ntemplate<typename INPUT_T>\nstruct BinaryCrossEntropyWithLogitsFunctor<INPUT_T, half, WeightType::kBoth> {\n  BinaryCrossEntropyWithLogitsFunctor<INPUT_T, float, WeightType::kBoth> f;\n  __device__ __forceinline__ half operator()(INPUT_T input_val, half target_val, half weight_val,\n                                             half pos_weight_val) const {\n    return __float2half(f(input_val, 
__half2float(target_val), __half2float(weight_val),\n                          __half2float(pos_weight_val)));\n  }\n};\n\ntemplate<>\nstruct BinaryCrossEntropyWithLogitsFunctor<half, half, WeightType::kNone> {\n  BinaryCrossEntropyWithLogitsFunctor<float, float, WeightType::kNone> f;\n  __device__ __forceinline__ half operator()(half input_val, half target_val) const {\n    return __float2half(f(__half2float(input_val), __half2float(target_val)));\n  }\n};\ntemplate<>\nstruct BinaryCrossEntropyWithLogitsFunctor<half, half, WeightType::kPosWeight> {\n  BinaryCrossEntropyWithLogitsFunctor<float, float, WeightType::kPosWeight> f;\n  __device__ __forceinline__ half operator()(half input_val, half target_val,\n                                             half weight_val) const {\n    return __float2half(\n        f(__half2float(input_val), __half2float(target_val), __half2float(weight_val)));\n  }\n};\ntemplate<>\nstruct BinaryCrossEntropyWithLogitsFunctor<half, half, WeightType::kWeight> {\n  BinaryCrossEntropyWithLogitsFunctor<float, float, WeightType::kWeight> f;\n  __device__ __forceinline__ half operator()(half input_val, half target_val,\n                                             half weight_val) const {\n    return __float2half(\n        f(__half2float(input_val), __half2float(target_val), __half2float(weight_val)));\n  }\n};\ntemplate<>\nstruct BinaryCrossEntropyWithLogitsFunctor<half, half, WeightType::kBoth> {\n  BinaryCrossEntropyWithLogitsFunctor<float, float, WeightType::kBoth> f;\n  __device__ __forceinline__ half operator()(half input_val, half target_val, half weight_val,\n                                             half pos_weight_val) const {\n    return __float2half(f(__half2float(input_val), __half2float(target_val),\n                          __half2float(weight_val), __half2float(pos_weight_val)));\n  }\n};\n\ntemplate<typename T>\n__device__ __forceinline__ T CalSigmoid(const T x) {\n  const T half_of_one = static_cast<T>(0.5);\n  return half_of_one * tanh(half_of_one * x) + half_of_one;\n}\n\ntemplate<>\n__device__ __forceinline__ float CalSigmoid(const float x) {\n  const float half_of_one = static_cast<float>(0.5);\n  return half_of_one * tanhf(half_of_one * x) + half_of_one;\n}\n\ntemplate<>\n__device__ __forceinline__ half CalSigmoid(const half x) {\n  return __float2half(CalSigmoid(__half2float(x)));\n}\n\ntemplate<typename INPUT_T, typename TARGET_T, WeightType WEIGHT_TYPE>\nstruct BinaryCrossEntropyWithLogitsGradFunctor;\n\ntemplate<typename INPUT_T, typename TARGET_T>\nstruct BinaryCrossEntropyWithLogitsGradFunctor<INPUT_T, TARGET_T, WeightType::kNone> {\n  __device__ __forceinline__ INPUT_T operator()(INPUT_T input_val, TARGET_T target_val,\n                                                TARGET_T dy_val) const {\n    return (CalSigmoid(input_val) - static_cast<INPUT_T>(target_val))\n           * static_cast<INPUT_T>(dy_val);\n  }\n};\ntemplate<typename INPUT_T, typename TARGET_T>\nstruct BinaryCrossEntropyWithLogitsGradFunctor<INPUT_T, TARGET_T, WeightType::kPosWeight> {\n  INPUT_T one_;\n  BinaryCrossEntropyWithLogitsGradFunctor() : one_(GetOneVal<INPUT_T>()) {}\n  __device__ __forceinline__ INPUT_T operator()(INPUT_T input_val, TARGET_T target_val,\n                                                TARGET_T dy_val, TARGET_T weight_val) const {\n    TARGET_T dx_tmp =\n        dy_val\n        * ((weight_val + one_ - target_val) * static_cast<TARGET_T>(CalSigmoid(input_val))\n           - weight_val);\n    return static_cast<INPUT_T>(dx_tmp);\n  
}\n};\ntemplate<typename INPUT_T, typename TARGET_T>\nstruct BinaryCrossEntropyWithLogitsGradFunctor<INPUT_T, TARGET_T, WeightType::kWeight> {\n  BinaryCrossEntropyWithLogitsGradFunctor<INPUT_T, TARGET_T, WeightType::kNone> f;\n  __device__ __forceinline__ INPUT_T operator()(INPUT_T input_val, TARGET_T target_val,\n                                                TARGET_T dy_val, TARGET_T weight_val) const {\n    return f(input_val, target_val, dy_val) * static_cast<INPUT_T>(weight_val);\n  }\n};\ntemplate<typename INPUT_T, typename TARGET_T>\nstruct BinaryCrossEntropyWithLogitsGradFunctor<INPUT_T, TARGET_T, WeightType::kBoth> {\n  BinaryCrossEntropyWithLogitsGradFunctor<INPUT_T, TARGET_T, WeightType::kPosWeight> f;\n  __device__ __forceinline__ INPUT_T operator()(INPUT_T input_val, TARGET_T target_val,\n                                                TARGET_T dy_val, TARGET_T weight_val,\n                                                TARGET_T pos_weight_val) const {\n    return f(input_val, target_val, dy_val, pos_weight_val) * static_cast<INPUT_T>(weight_val);\n  }\n};\n\ntemplate<>\nstruct BinaryCrossEntropyWithLogitsGradFunctor<half, half, WeightType::kNone> {\n  __device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val) const {\n    return (CalSigmoid(input_val) - target_val) * dy_val;\n  }\n};\ntemplate<>\nstruct BinaryCrossEntropyWithLogitsGradFunctor<half, half, WeightType::kPosWeight> {\n  half one_;\n  BinaryCrossEntropyWithLogitsGradFunctor() : one_(GetOneVal<half>()) {}\n  __device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val,\n                                             half weight_val) const {\n    return dy_val * ((weight_val + one_ - target_val) * CalSigmoid(input_val) - weight_val);\n  }\n};\ntemplate<>\nstruct BinaryCrossEntropyWithLogitsGradFunctor<half, half, WeightType::kWeight> {\n  BinaryCrossEntropyWithLogitsGradFunctor<half, half, WeightType::kNone> f;\n  __device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val,\n                                             half weight_val) const {\n    return f(input_val, target_val, dy_val) * weight_val;\n  }\n};\ntemplate<>\nstruct BinaryCrossEntropyWithLogitsGradFunctor<half, half, WeightType::kBoth> {\n  BinaryCrossEntropyWithLogitsGradFunctor<half, half, WeightType::kPosWeight> f;\n  __device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val,\n                                             half weight_val, half pos_weight_val) const {\n    return f(input_val, target_val, dy_val, pos_weight_val) * weight_val;\n  }\n};\n\ntemplate<typename INPUT_T>\nstruct BinaryCrossEntropyWithLogitsGradFunctor<INPUT_T, half, WeightType::kNone> {\n  __device__ __forceinline__ INPUT_T operator()(INPUT_T input_val, half target_val,\n                                                half dy_val) const {\n    return (CalSigmoid(input_val) - static_cast<INPUT_T>(__half2float(target_val)))\n           * static_cast<INPUT_T>(__half2float(dy_val));\n  }\n};\ntemplate<typename INPUT_T>\nstruct BinaryCrossEntropyWithLogitsGradFunctor<INPUT_T, half, WeightType::kPosWeight> {\n  INPUT_T one_;\n  BinaryCrossEntropyWithLogitsGradFunctor() : one_(GetOneVal<INPUT_T>()) {}\n  __device__ __forceinline__ INPUT_T operator()(INPUT_T input_val, half target_val, half dy_val,\n                                                half weight_val) const {\n    const INPUT_T dy_val_f = static_cast<INPUT_T>(__half2float(dy_val));\n    const INPUT_T 
target_val_f = static_cast<INPUT_T>(__half2float(target_val));\n    const INPUT_T weight_val_f = static_cast<INPUT_T>(__half2float(weight_val));\n    return dy_val_f * ((weight_val_f + one_ - target_val_f) * CalSigmoid(input_val) - weight_val_f);\n  }\n};\ntemplate<typename INPUT_T>\nstruct BinaryCrossEntropyWithLogitsGradFunctor<INPUT_T, half, WeightType::kWeight> {\n  BinaryCrossEntropyWithLogitsGradFunctor<INPUT_T, half, WeightType::kNone> f;\n  __device__ __forceinline__ INPUT_T operator()(INPUT_T input_val, half target_val, half dy_val,\n                                                half weight_val) const {\n    return f(input_val, target_val, dy_val) * static_cast<INPUT_T>(__half2float(weight_val));\n  }\n};\ntemplate<typename INPUT_T>\nstruct BinaryCrossEntropyWithLogitsGradFunctor<INPUT_T, half, WeightType::kBoth> {\n  BinaryCrossEntropyWithLogitsGradFunctor<INPUT_T, half, WeightType::kPosWeight> f;\n  __device__ __forceinline__ INPUT_T operator()(INPUT_T input_val, half target_val, half dy_val,\n                                                half weight_val, half pos_weight_val) const {\n    return f(input_val, target_val, dy_val, pos_weight_val)\n           * static_cast<INPUT_T>(__half2float(weight_val));\n  }\n};\n\ntemplate<typename TARGET_T>\nstruct BinaryCrossEntropyWithLogitsGradFunctor<half, TARGET_T, WeightType::kNone> {\n  __device__ __forceinline__ half operator()(half input_val, TARGET_T target_val,\n                                             TARGET_T dy_val) const {\n    const half dy_val_h = __float2half(static_cast<float>(dy_val));\n    const half target_val_h = __float2half(static_cast<float>(target_val));\n    return (CalSigmoid(input_val) - target_val_h) * dy_val_h;\n  }\n};\ntemplate<typename TARGET_T>\nstruct BinaryCrossEntropyWithLogitsGradFunctor<half, TARGET_T, WeightType::kPosWeight> {\n  half one_;\n  BinaryCrossEntropyWithLogitsGradFunctor() : one_(GetOneVal<half>()) {}\n  __device__ __forceinline__ half operator()(half input_val, TARGET_T target_val, TARGET_T dy_val,\n                                             TARGET_T weight_val) const {\n    const half dy_val_h = __float2half(static_cast<float>(dy_val));\n    const half target_val_h = __float2half(static_cast<float>(target_val));\n    const half weight_val_h = __float2half(static_cast<float>(weight_val));\n    return dy_val_h * ((weight_val_h + one_ - target_val_h) * CalSigmoid(input_val) - weight_val_h);\n  }\n};\ntemplate<typename TARGET_T>\nstruct BinaryCrossEntropyWithLogitsGradFunctor<half, TARGET_T, WeightType::kWeight> {\n  BinaryCrossEntropyWithLogitsGradFunctor<half, TARGET_T, WeightType::kNone> f;\n  __device__ __forceinline__ half operator()(half input_val, TARGET_T target_val, TARGET_T dy_val,\n                                             TARGET_T weight_val) const {\n    return f(input_val, target_val, dy_val) * __float2half(static_cast<float>(weight_val));\n  }\n};\ntemplate<typename TARGET_T>\nstruct BinaryCrossEntropyWithLogitsGradFunctor<half, TARGET_T, WeightType::kBoth> {\n  BinaryCrossEntropyWithLogitsGradFunctor<half, TARGET_T, WeightType::kPosWeight> f;\n  __device__ __forceinline__ half operator()(half input_val, TARGET_T target_val, TARGET_T dy_val,\n                                             TARGET_T weight_val, TARGET_T pos_weight_val) const {\n    return f(input_val, target_val, dy_val, pos_weight_val)\n           * __float2half(static_cast<float>(weight_val));\n  }\n};\n\ntemplate<typename INPUT_T, typename TARGET_T>\nclass BinaryCrossEntropyWithLogitsKernel final : 
public user_op::OpKernel {\n public:\n  BinaryCrossEntropyWithLogitsKernel() = default;\n  ~BinaryCrossEntropyWithLogitsKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* input_blob = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* target_blob = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    auto* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    auto* tmp_buffer_blob = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    const int64_t elem_cnt = input_blob->shape_view().elem_cnt();\n\n    const INPUT_T* input = input_blob->dptr<INPUT_T>();\n    const TARGET_T* target = target_blob->dptr<TARGET_T>();\n    TARGET_T* out = out_blob->mut_dptr<TARGET_T>();\n\n    if (ctx->Attr<bool>(\"has_pos_weight\")) {\n      TARGET_T* pos_weight_processed = tmp_buffer_blob->mut_dptr<TARGET_T>();\n      const TARGET_T* pos_weight = ctx->Tensor4ArgNameAndIndex(\"pos_weight\", 0)->dptr<TARGET_T>();\n\n      Shape pos_weight_shape = Shape::Ones(target_blob->shape_view().NumAxes());\n      pos_weight_shape.Set(pos_weight_shape.NumAxes() - 1,\n                           ctx->Tensor4ArgNameAndIndex(\"pos_weight\", 0)->shape_view().elem_cnt());\n      auto bcast_mul =\n          ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n              ctx->device_type(), ep::primitive::BinaryOp::kMul, target_blob->data_type(),\n              target_blob->data_type(), target_blob->shape_view().NumAxes());\n      CHECK(bcast_mul);\n      bcast_mul->Launch(ctx->stream(), target_blob->shape_view().NumAxes(),\n                        target_blob->shape_view().ptr(), target, pos_weight_shape.NumAxes(),\n                        pos_weight_shape.dim_vec().data(), pos_weight, pos_weight_processed);\n      if (ctx->has_input(\"weight\", 0)) {\n        const TARGET_T* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0)->dptr<TARGET_T>();\n        using FunctorT = BinaryCrossEntropyWithLogitsFunctor<INPUT_T, TARGET_T, WeightType::kBoth>;\n        using FactoryT = cuda::elementwise::SimpleFactory<FunctorT>;\n        OF_CUDA_CHECK(\n            (cuda::elementwise::\n                 GenericLauncher<FactoryT, TARGET_T, INPUT_T, TARGET_T, TARGET_T, TARGET_T>::Launch(\n                     FactoryT(FunctorT()), elem_cnt, out, input, target, weight,\n                     pos_weight_processed, ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n      } else {\n        OF_CUDA_CHECK((cuda::elementwise::Ternary(\n            BinaryCrossEntropyWithLogitsFunctor<INPUT_T, TARGET_T, WeightType::kPosWeight>(),\n            elem_cnt, out, input, target, pos_weight_processed,\n            ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n      }\n    } else {\n      if (ctx->has_input(\"weight\", 0)) {\n        const TARGET_T* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0)->dptr<TARGET_T>();\n        OF_CUDA_CHECK((cuda::elementwise::Ternary(\n            BinaryCrossEntropyWithLogitsFunctor<INPUT_T, TARGET_T, WeightType::kWeight>(), elem_cnt,\n            out, input, target, weight, ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n      } else {\n        OF_CUDA_CHECK((cuda::elementwise::Binary(\n            BinaryCrossEntropyWithLogitsFunctor<INPUT_T, TARGET_T, WeightType::kNone>(), elem_cnt,\n            out, input, target, ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return 
false; }\n};\n\ntemplate<typename INPUT_T, typename TARGET_T>\nclass BinaryCrossEntropyWithLogitsGradKernel final : public user_op::OpKernel {\n public:\n  BinaryCrossEntropyWithLogitsGradKernel() = default;\n  ~BinaryCrossEntropyWithLogitsGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* input_blob = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* target_blob = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    const auto* dy_blob = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    auto* dx_blob = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    auto* tmp_buffer_blob = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    const int64_t elem_cnt = input_blob->shape_view().elem_cnt();\n\n    const TARGET_T* dy = dy_blob->dptr<TARGET_T>();\n    const INPUT_T* input = input_blob->dptr<INPUT_T>();\n    const TARGET_T* target = target_blob->dptr<TARGET_T>();\n    INPUT_T* dx = dx_blob->mut_dptr<INPUT_T>();\n\n    if (ctx->Attr<bool>(\"has_pos_weight\")) {\n      TARGET_T* pos_weight_processed = tmp_buffer_blob->mut_dptr<TARGET_T>();\n      const TARGET_T* pos_weight = ctx->Tensor4ArgNameAndIndex(\"pos_weight\", 0)->dptr<TARGET_T>();\n\n      Shape pos_weight_shape = Shape::Ones(target_blob->shape_view().NumAxes());\n      pos_weight_shape.Set(pos_weight_shape.NumAxes() - 1,\n                           ctx->Tensor4ArgNameAndIndex(\"pos_weight\", 0)->shape_view().elem_cnt());\n      auto bcast_mul =\n          ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n              ctx->device_type(), ep::primitive::BinaryOp::kMul, target_blob->data_type(),\n              target_blob->data_type(), target_blob->shape_view().NumAxes());\n      CHECK(bcast_mul);\n      bcast_mul->Launch(ctx->stream(), target_blob->shape_view().NumAxes(),\n                        target_blob->shape_view().ptr(), target, pos_weight_shape.NumAxes(),\n                        pos_weight_shape.dim_vec().data(), pos_weight, pos_weight_processed);\n      if (ctx->has_input(\"weight\", 0)) {\n        const TARGET_T* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0)->dptr<TARGET_T>();\n        using FunctorT =\n            BinaryCrossEntropyWithLogitsGradFunctor<INPUT_T, TARGET_T, WeightType::kBoth>;\n        using FactoryT = cuda::elementwise::SimpleFactory<FunctorT>;\n        OF_CUDA_CHECK((cuda::elementwise::GenericLauncher<\n                       FactoryT, INPUT_T, INPUT_T, TARGET_T, TARGET_T, TARGET_T,\n                       TARGET_T>::Launch(FactoryT(FunctorT()), elem_cnt, dx, input, target, dy,\n                                         weight, pos_weight_processed,\n                                         ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n\n      } else {\n        using FunctorT =\n            BinaryCrossEntropyWithLogitsGradFunctor<INPUT_T, TARGET_T, WeightType::kPosWeight>;\n        using FactoryT = cuda::elementwise::SimpleFactory<FunctorT>;\n        OF_CUDA_CHECK(\n            (cuda::elementwise::\n                 GenericLauncher<FactoryT, INPUT_T, INPUT_T, TARGET_T, TARGET_T, TARGET_T>::Launch(\n                     FactoryT(FunctorT()), elem_cnt, dx, input, target, dy, pos_weight_processed,\n                     ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n      }\n    } else {\n      if (ctx->has_input(\"weight\", 0)) {\n        const TARGET_T* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0)->dptr<TARGET_T>();\n        using FunctorT =\n   
         BinaryCrossEntropyWithLogitsGradFunctor<INPUT_T, TARGET_T, WeightType::kWeight>;\n        using FactoryT = cuda::elementwise::SimpleFactory<FunctorT>;\n        OF_CUDA_CHECK(\n            (cuda::elementwise::\n                 GenericLauncher<FactoryT, INPUT_T, INPUT_T, TARGET_T, TARGET_T, TARGET_T>::Launch(\n                     FactoryT(FunctorT()), elem_cnt, dx, input, target, dy, weight,\n                     ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n      } else {\n        OF_CUDA_CHECK((cuda::elementwise::Ternary(\n            BinaryCrossEntropyWithLogitsGradFunctor<INPUT_T, TARGET_T, WeightType::kNone>(),\n            elem_cnt, dx, input, target, dy, ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nuser_op::InferTmpSizeFn GenFwInferTmpSizeFn() {\n  return [](user_op::InferContext* ctx) {\n    const int64_t n = ctx->InputShape(\"input\", 0).elem_cnt();\n    size_t tmp_buffer_size = 0;\n    if (ctx->Attr<bool>(\"has_pos_weight\")) { tmp_buffer_size += GetCudaAlignedSize(n * sizeof(T)); }\n    return tmp_buffer_size;\n  };\n}\n\ntemplate<typename T>\nuser_op::InferTmpSizeFn GenBwInferTmpSizeFn() {\n  return [](user_op::InferContext* ctx) {\n    const int64_t n = ctx->InputShape(\"target\", 0).elem_cnt();\n    size_t tmp_buffer_size = 0;\n    if (ctx->Attr<bool>(\"has_pos_weight\")) { tmp_buffer_size += GetCudaAlignedSize(n * sizeof(T)); }\n    return tmp_buffer_size;\n  };\n}\n\n}  // namespace\n\n#define REGISTER_BINARY_CROSS_ENTROPY_KERNEL(input_dtype, target_dtype)                           \\\n  REGISTER_USER_KERNEL(\"binary_cross_entropy_with_logits\")                                        \\\n      .SetCreateFn<BinaryCrossEntropyWithLogitsKernel<input_dtype, target_dtype>>()               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                            \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<input_dtype>::value)   \\\n                       && (user_op::HobDataType(\"target\", 0) == GetDataType<target_dtype>::value) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<target_dtype>::value))   \\\n      .SetInferTmpSizeFn(GenFwInferTmpSizeFn<target_dtype>());\n\n#define REGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(input_dtype, target_dtype)                      \\\n  REGISTER_USER_KERNEL(\"binary_cross_entropy_with_logits_grad\")                                   \\\n      .SetCreateFn<BinaryCrossEntropyWithLogitsGradKernel<input_dtype, target_dtype>>()           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                            \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<input_dtype>::value)   \\\n                       && (user_op::HobDataType(\"target\", 0) == GetDataType<target_dtype>::value) \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<target_dtype>::value)     \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<input_dtype>::value))     \\\n      .SetInferTmpSizeFn(GenBwInferTmpSizeFn<target_dtype>());\n\nREGISTER_BINARY_CROSS_ENTROPY_KERNEL(half, half)\nREGISTER_BINARY_CROSS_ENTROPY_KERNEL(half, float)\nREGISTER_BINARY_CROSS_ENTROPY_KERNEL(float, half)\nREGISTER_BINARY_CROSS_ENTROPY_KERNEL(half, double)\nREGISTER_BINARY_CROSS_ENTROPY_KERNEL(double, 
half)\nREGISTER_BINARY_CROSS_ENTROPY_KERNEL(float, float)\nREGISTER_BINARY_CROSS_ENTROPY_KERNEL(float, double)\nREGISTER_BINARY_CROSS_ENTROPY_KERNEL(double, float)\nREGISTER_BINARY_CROSS_ENTROPY_KERNEL(double, double)\n\nREGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(half, half)\nREGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(half, float)\nREGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(float, half)\nREGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(half, double)\nREGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(double, half)\nREGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(float, float)\nREGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(float, double)\nREGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(double, float)\nREGISTER_BINARY_CROSS_ENTROPY_GRAD_KERNEL(double, double)\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/binary_cross_entropy_with_logits_mean_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/binary_cross_entropy_with_logits_mean_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include <cub/cub.cuh>\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\nconstexpr int32_t kBlockSize = 1024;\nconstexpr int32_t kReduceLocalSumBlockSize = 1024;\nconstexpr int32_t kSingleBlockProcessNumThreshold = 1024;\n\ntemplate<typename T>\nstruct DefaultComputeType {\n  using type = T;\n};\n\ntemplate<>\nstruct DefaultComputeType<half> {\n  using type = float;\n};\n\ntemplate<class Func>\ninline cudaError_t GetNumBlocks(Func func, int64_t block_size, size_t dynamic_smem_size,\n                                int64_t max_blocks, int64_t waves, int* num_blocks) {\n  int dev;\n  {\n    cudaError_t err = cudaGetDevice(&dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  int sm_count;\n  {\n    cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  int max_active_blocks;\n  {\n    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, func,\n                                                                    block_size, dynamic_smem_size);\n  }\n  *num_blocks =\n      std::max<int>(1, std::min<int64_t>(max_blocks, sm_count * max_active_blocks * waves));\n  return cudaSuccess;\n}\n\ntemplate<typename T>\n__device__ __forceinline__ T Sigmoid(const T x) {\n  const T half_of_one = static_cast<T>(0.5);\n  return half_of_one * tanh(half_of_one * x) + half_of_one;\n}\n\ntemplate<>\n__device__ __forceinline__ half Sigmoid(const half x) {\n  return __float2half(Sigmoid(__half2float(x)));\n}\n\ntemplate<typename INPUT_T, typename TARGET_T, typename OUTPUT_T, typename ComputeType>\n__global__ void FusedBinaryCrossEntropyWithLogitsReduceMeanKernel(const INPUT_T* input,\n                                                                  const TARGET_T* target,\n                                                                  OUTPUT_T* out,\n                                                                  const int64_t local_elem_cnt,\n                                                                  const int64_t reduce_elem_cnt) {\n  ComputeType zero = static_cast<ComputeType>(0.0);\n  ComputeType one = static_cast<ComputeType>(1.0);\n  using BlockReduce = cub::BlockReduce<ComputeType, kBlockSize>;\n  __shared__ typename BlockReduce::TempStorage temp_storage;\n  ComputeType reduce_sum = static_cast<ComputeType>(0.0);\n  CUDA_1D_KERNEL_LOOP(i, local_elem_cnt) {\n    const ComputeType input_val = static_cast<ComputeType>(input[i]);\n    const ComputeType target_val = static_cast<ComputeType>(target[i]);\n    const ComputeType max_val = -input_val < zero ? 
const ComputeType max_val = -input_val < zero ? zero : -input_val;\n    const ComputeType result =\n        (one - target_val) * input_val + max_val + (log(exp(-max_val) + exp(-input_val - max_val)));\n    reduce_sum += result;\n  }\n\n  const ComputeType block_reduce_sum = BlockReduce(temp_storage).Sum(reduce_sum);\n  if (threadIdx.x == 0) {\n    out[blockIdx.x] = static_cast<OUTPUT_T>(block_reduce_sum / reduce_elem_cnt);\n  }\n}\n\ntemplate<typename TARGET_T, typename INPUT_T>\n__global__ void ReduceLocalSumKernel(INPUT_T* block_local_sum_buf, TARGET_T* out,\n                                     int64_t elem_cnt) {\n  using BlockReduce = cub::BlockReduce<INPUT_T, kReduceLocalSumBlockSize>;\n  __shared__ typename BlockReduce::TempStorage temp_storage;\n  INPUT_T reduce_sum = 0.0;\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) { reduce_sum += block_local_sum_buf[i]; }\n  const INPUT_T block_reduce_sum = BlockReduce(temp_storage).Sum(reduce_sum);\n  if (threadIdx.x == 0) { out[0] = block_reduce_sum; }\n}\n\ntemplate<typename INPUT_T, typename TARGET_T, typename ComputeType>\nstruct BinaryCrossEntropyWithLogitsReduceMeanGradFunctor {\n  OF_DEVICE_FUNC explicit BinaryCrossEntropyWithLogitsReduceMeanGradFunctor(\n      const INPUT_T elem_cnt_reciprocal, const TARGET_T dy)\n      : elem_cnt_reciprocal(elem_cnt_reciprocal), dy(dy) {}\n  __device__ ComputeType operator()(const INPUT_T input_val, const TARGET_T target_val) const {\n    const ComputeType input_val_ = static_cast<ComputeType>(input_val);\n    const ComputeType target_val_ = static_cast<ComputeType>(target_val);\n    const ComputeType dy_ = static_cast<ComputeType>(dy);\n    const ComputeType elem_cnt_reciprocal_ = static_cast<ComputeType>(elem_cnt_reciprocal);\n    return (Sigmoid(input_val_) - target_val_) * dy_ * elem_cnt_reciprocal_;\n  }\n  const INPUT_T elem_cnt_reciprocal;\n  const TARGET_T dy;\n};\n\ntemplate<typename INPUT_T, typename TARGET_T, typename ComputeType>\nstruct BinaryCrossEntropyWithLogitsReduceMeanGradDyptrFunctor {\n  OF_DEVICE_FUNC explicit BinaryCrossEntropyWithLogitsReduceMeanGradDyptrFunctor(\n      const int64_t elem_cnt, const TARGET_T* dy_ptr)\n      : elem_cnt_reciprocal(1.0f / elem_cnt), dy_ptr(dy_ptr) {}\n  __device__ BinaryCrossEntropyWithLogitsReduceMeanGradFunctor<INPUT_T, TARGET_T, ComputeType>\n  operator()() const {\n    return BinaryCrossEntropyWithLogitsReduceMeanGradFunctor<INPUT_T, TARGET_T, ComputeType>(\n        elem_cnt_reciprocal, *dy_ptr);\n  }\n  const INPUT_T elem_cnt_reciprocal;\n  const TARGET_T* dy_ptr;\n};\n\ntemplate<typename INPUT_T, typename TARGET_T, typename OUTPUT_T, typename ComputeType>\n__global__ void FusedBCEReduceMeanFwBwKernel(const INPUT_T* input, const TARGET_T* target,\n                                             OUTPUT_T* out, INPUT_T* input_grad,\n                                             const ComputeType constant_output_grad,\n                                             const ComputeType elem_cnt_reciprocal,\n                                             const int64_t local_elem_cnt,\n                                             const int64_t reduce_elem_cnt) {\n  ComputeType zero = static_cast<ComputeType>(0.0);\n  ComputeType one = static_cast<ComputeType>(1.0);\n  BinaryCrossEntropyWithLogitsReduceMeanGradFunctor<INPUT_T, TARGET_T, ComputeType> grad_functor(\n      elem_cnt_reciprocal, constant_output_grad);\n  using BlockReduce = cub::BlockReduce<ComputeType, kBlockSize>;\n  __shared__ typename BlockReduce::TempStorage temp_storage;\n  ComputeType reduce_sum = static_cast<ComputeType>(0.0);\n  CUDA_1D_KERNEL_LOOP(i, local_elem_cnt) {\n    
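// Fused pass: write the gradient and accumulate the forward loss in a single read of the data.\n    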
const INPUT_T input_val = input[i];\n    const TARGET_T target_val = target[i];\n    input_grad[i] = grad_functor(input_val, target_val);\n    const ComputeType input_val_ = static_cast<ComputeType>(input_val);\n    const ComputeType target_val_ = static_cast<ComputeType>(target_val);\n    const ComputeType max_val = -input_val_ < zero ? zero : -input_val_;\n    const ComputeType result = (one - target_val_) * input_val_ + max_val\n                               + (log(exp(-max_val) + exp(-input_val_ - max_val)));\n    reduce_sum += result;\n  }\n  const ComputeType block_reduce_sum = BlockReduce(temp_storage).Sum(reduce_sum);\n  if (threadIdx.x == 0) {\n    out[blockIdx.x] = static_cast<OUTPUT_T>(block_reduce_sum / reduce_elem_cnt);\n  }\n}\n\ntemplate<typename INPUT_T, typename TARGET_T>\nclass FusedBCEMeanFwBwKernel final : public user_op::OpKernel, public CudaGraphSupport {\n public:\n  FusedBCEMeanFwBwKernel() = default;\n  ~FusedBCEMeanFwBwKernel() override = default;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateBCEWithLogitsReduceMeanKernelCache(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache* cache) const override {\n    const auto* input_blob = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* target_blob = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    auto* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    auto* dx_blob = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    int64_t local_elem_cnt = input_blob->shape_view().elem_cnt();\n    int64_t reduce_elem_cnt = local_elem_cnt;\n\n    if (cache != nullptr) {\n      // Because `out`'s SBP may be P or B, use reduce_elem_cnt as the reduce_mean factor.\n      const auto* bce_cache = dynamic_cast<const BCEWithLogitsReduceMeanKernelCache*>(cache);\n      CHECK_NOTNULL(bce_cache);\n      reduce_elem_cnt = bce_cache->reduce_elem_cnt();\n    }\n\n    const INPUT_T* input = input_blob->dptr<INPUT_T>();\n    const TARGET_T* target = target_blob->dptr<TARGET_T>();\n    using ComputeType = typename DefaultComputeType<TARGET_T>::type;\n    ComputeType constant_output_grad = ctx->Attr<double>(\"constant_value\");\n    ComputeType elem_cnt_reciprocal = static_cast<ComputeType>(1) / reduce_elem_cnt;\n\n    if (local_elem_cnt <= kSingleBlockProcessNumThreshold) {\n      FusedBCEReduceMeanFwBwKernel<INPUT_T, TARGET_T, TARGET_T, ComputeType>\n          <<<1, kBlockSize, 0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n              input_blob->dptr<INPUT_T>(), target_blob->dptr<TARGET_T>(),\n              out_blob->mut_dptr<TARGET_T>(), dx_blob->mut_dptr<INPUT_T>(), constant_output_grad,\n              elem_cnt_reciprocal, local_elem_cnt, reduce_elem_cnt);\n    } else {\n      auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n      const int64_t tmp_buffer_elem_cnt = tmp_buffer->shape_view().elem_cnt() / sizeof(ComputeType);\n      const int64_t block_num = (local_elem_cnt + kBlockSize - 1) / kBlockSize;\n      int launch_block = block_num;\n      OF_CUDA_CHECK(\n          GetNumBlocks(FusedBCEReduceMeanFwBwKernel<INPUT_T, TARGET_T, ComputeType, ComputeType>,\n                       kBlockSize, 0, block_num, 32, &launch_block));\n      launch_block = std::min<int32_t>(tmp_buffer_elem_cnt, launch_block);\n      
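// Blocks write partial means to tmp_buffer; ReduceLocalSumKernel then sums them into out.\n      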
FusedBCEReduceMeanFwBwKernel<INPUT_T, TARGET_T, ComputeType, ComputeType>\n          <<<launch_block, kBlockSize, 0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n              input_blob->dptr<INPUT_T>(), target_blob->dptr<TARGET_T>(),\n              tmp_buffer->mut_dptr<ComputeType>(), dx_blob->mut_dptr<INPUT_T>(), constant_output_grad,\n              elem_cnt_reciprocal, local_elem_cnt, reduce_elem_cnt);\n      ReduceLocalSumKernel<TARGET_T, ComputeType>\n          <<<1, kReduceLocalSumBlockSize, 0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n              tmp_buffer->mut_dptr<ComputeType>(), out_blob->mut_dptr<TARGET_T>(), block_num);\n    }\n  }\n};\n\ntemplate<typename INPUT_T, typename TARGET_T>\nclass BinaryCrossEntropyWithLogitsMeanKernel final : public user_op::OpKernel,\n                                                     public CudaGraphSupport {\n public:\n  BinaryCrossEntropyWithLogitsMeanKernel() = default;\n  ~BinaryCrossEntropyWithLogitsMeanKernel() override = default;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateBCEWithLogitsReduceMeanKernelCache(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache* cache) const override {\n    const auto* input_blob = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* target_blob = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    auto* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    int64_t local_elem_cnt = input_blob->shape_view().elem_cnt();\n    int64_t reduce_elem_cnt = local_elem_cnt;\n\n    if (cache != nullptr) {\n      // Because `out`'s SBP may be P or B, use reduce_elem_cnt as the reduce_mean factor.\n      const auto* bce_cache = dynamic_cast<const BCEWithLogitsReduceMeanKernelCache*>(cache);\n      CHECK_NOTNULL(bce_cache);\n      reduce_elem_cnt = bce_cache->reduce_elem_cnt();\n    }\n\n    const INPUT_T* input = input_blob->dptr<INPUT_T>();\n    const TARGET_T* target = target_blob->dptr<TARGET_T>();\n    TARGET_T* out = out_blob->mut_dptr<TARGET_T>();\n    using ComputeType = typename DefaultComputeType<TARGET_T>::type;\n\n    if (local_elem_cnt <= kSingleBlockProcessNumThreshold) {\n      FusedBinaryCrossEntropyWithLogitsReduceMeanKernel<INPUT_T, TARGET_T, TARGET_T, ComputeType>\n          <<<1, kBlockSize, 0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n              input_blob->dptr<INPUT_T>(), target_blob->dptr<TARGET_T>(),\n              out_blob->mut_dptr<TARGET_T>(), local_elem_cnt, reduce_elem_cnt);\n    } else {\n      auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n      const int64_t tmp_buffer_elem_cnt = tmp_buffer->shape_view().elem_cnt() / sizeof(ComputeType);\n      const int64_t block_num = (local_elem_cnt + kBlockSize - 1) / kBlockSize;\n      int launch_block = block_num;\n      OF_CUDA_CHECK(\n          GetNumBlocks(FusedBinaryCrossEntropyWithLogitsReduceMeanKernel<INPUT_T, TARGET_T,\n                                                                         ComputeType, ComputeType>,\n                       kBlockSize, 0, block_num, 32, &launch_block));\n      launch_block = std::min<int64_t>(tmp_buffer_elem_cnt, launch_block);\n      FusedBinaryCrossEntropyWithLogitsReduceMeanKernel<INPUT_T, TARGET_T, ComputeType, ComputeType>\n          
<<<launch_block, kBlockSize, 0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n              input_blob->dptr<INPUT_T>(), target_blob->dptr<TARGET_T>(),\n              tmp_buffer->mut_dptr<ComputeType>(), local_elem_cnt, reduce_elem_cnt);\n      ReduceLocalSumKernel<TARGET_T, ComputeType>\n          <<<1, kReduceLocalSumBlockSize, 0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n              tmp_buffer->mut_dptr<ComputeType>(), out_blob->mut_dptr<TARGET_T>(), block_num);\n    }\n  }\n};\n\ntemplate<typename INPUT_T, typename TARGET_T>\nclass BinaryCrossEntropyWithLogitsReduceMeanGradKernel final : public user_op::OpKernel {\n public:\n  BinaryCrossEntropyWithLogitsReduceMeanGradKernel() = default;\n  ~BinaryCrossEntropyWithLogitsReduceMeanGradKernel() override = default;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateBCEWithLogitsReduceMeanKernelCache(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache* cache) const override {\n    const auto* input_blob = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* target_blob = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    const auto* dy_blob = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    auto* dx_blob = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    int64_t local_elem_cnt = input_blob->shape_view().elem_cnt();\n    int64_t reduce_elem_cnt = local_elem_cnt;\n    if (cache != nullptr) {\n      // Because `out`'s SBP may be P or B, use reduce_elem_cnt as the reduce_mean factor.\n      const auto* bce_cache = dynamic_cast<const BCEWithLogitsReduceMeanKernelCache*>(cache);\n      CHECK_NOTNULL(bce_cache);\n      reduce_elem_cnt = bce_cache->reduce_elem_cnt();\n    }\n\n    const TARGET_T* dy = dy_blob->dptr<TARGET_T>();\n    const INPUT_T* input = input_blob->dptr<INPUT_T>();\n    const TARGET_T* target = target_blob->dptr<TARGET_T>();\n    INPUT_T* dx = dx_blob->mut_dptr<INPUT_T>();\n    using ComputeType = typename DefaultComputeType<TARGET_T>::type;\n\n    OF_CUDA_CHECK((cuda::elementwise::BinaryWithFactory(\n        BinaryCrossEntropyWithLogitsReduceMeanGradDyptrFunctor<INPUT_T, TARGET_T, ComputeType>(\n            reduce_elem_cnt, dy),\n        local_elem_cnt, dx, input, target, ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n  }\n};\n\n}  // namespace\n\n#define REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(input_dtype, target_dtype)               \\\n  REGISTER_USER_KERNEL(\"binary_cross_entropy_with_logits_reduce_mean\")                            \\\n      .SetCreateFn<BinaryCrossEntropyWithLogitsMeanKernel<input_dtype, target_dtype>>()           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                            \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<input_dtype>::value)   \\\n                       && (user_op::HobDataType(\"target\", 0) == GetDataType<target_dtype>::value) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<target_dtype>::value))   \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                         \\\n        const int64_t elem_cnt = ctx->InputShape(\"input\", 0).elem_cnt();                          \\\n        const int64_t block_num = (elem_cnt + 
kBlockSize - 1) / kBlockSize;                       \\\n        int launch_block = block_num;                                                             \\\n        using compute_dtype = typename DefaultComputeType<target_dtype>::type;                    \\\n        OF_CUDA_CHECK(GetNumBlocks(                                                               \\\n            FusedBinaryCrossEntropyWithLogitsReduceMeanKernel<input_dtype, target_dtype,          \\\n                                                              compute_dtype, compute_dtype>,      \\\n            kBlockSize, 0, block_num, 32, &launch_block));                                        \\\n        const int64_t tmp_buffer_size = GetCudaAlignedSize(launch_block * sizeof(compute_dtype)); \\\n        return tmp_buffer_size;                                                                   \\\n      });\n\n#define REGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(input_dtype, target_dtype)          \\\n  REGISTER_USER_KERNEL(\"binary_cross_entropy_with_logits_reduce_mean_grad\")                       \\\n      .SetCreateFn<BinaryCrossEntropyWithLogitsReduceMeanGradKernel<input_dtype, target_dtype>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                            \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<input_dtype>::value)   \\\n                       && (user_op::HobDataType(\"target\", 0) == GetDataType<target_dtype>::value) \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<target_dtype>::value)     \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<input_dtype>::value));\n\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(half, half)\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(half, float)\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(float, half)\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(half, double)\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(double, half)\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(float, float)\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(float, double)\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(double, float)\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_KERNEL(double, double)\n\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(half, half)\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(half, float)\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(float, half)\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(half, double)\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(double, half)\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(float, float)\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(float, double)\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(double, float)\nREGISTER_BINARY_CROSS_ENTROPY_REDUCE_MEAN_GRAD_KERNEL(double, double)\n\n#define REGISTER_FUSED_BCE_REDUCE_MEAN_FW_BW_KERNEL(input_dtype, target_dtype)                   \\\n  REGISTER_USER_KERNEL(\"fused_bce_reduce_mean_fw_bw\")                                            \\\n      .SetCreateFn<FusedBCEMeanFwBwKernel<input_dtype, target_dtype>>()                          \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                           \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<target_dtype>::value))  \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                        \\\n        const int64_t elem_cnt = 
ctx->InputShape(\"input\", 0).elem_cnt();                         \\\n        const int64_t block_num = (elem_cnt + kBlockSize - 1) / kBlockSize;                      \\\n        int launch_block = block_num;                                                            \\\n        using compute_dtype = typename DefaultComputeType<target_dtype>::type;                   \\\n        OF_CUDA_CHECK(GetNumBlocks(                                                              \\\n            FusedBinaryCrossEntropyWithLogitsReduceMeanKernel<input_dtype, target_dtype,         \\\n                                                              compute_dtype, compute_dtype>,     \\\n            kBlockSize, 0, block_num, 32, &launch_block));                                       \\\n        const int64_t tmp_buffer_size = GetCudaAlignedSize(launch_block * sizeof(target_dtype)); \\\n        return tmp_buffer_size;                                                                  \\\n      });\n\nREGISTER_FUSED_BCE_REDUCE_MEAN_FW_BW_KERNEL(half, half)\nREGISTER_FUSED_BCE_REDUCE_MEAN_FW_BW_KERNEL(float, float)\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/binary_cross_entropy_with_logits_mean_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\nclass BCEWithLogitsReduceMeanKernelCache final : public user_op::OpKernelCache {\n public:\n  BCEWithLogitsReduceMeanKernelCache(int64_t reduce_elem_cnt) : reduce_elem_cnt_(reduce_elem_cnt) {}\n  ~BCEWithLogitsReduceMeanKernelCache() override = default;\n\n  int64_t reduce_elem_cnt() const { return reduce_elem_cnt_; }\n\n private:\n  const int64_t reduce_elem_cnt_;\n};\n\nstd::shared_ptr<user_op::OpKernelCache> CreateBCEWithLogitsReduceMeanKernelCache(\n    user_op::KernelCacheContext* ctx) {\n  if (ctx->parallel_ctx().parallel_num() == 1) { return nullptr; }\n  const int64_t reduce_elem_cnt =\n      ctx->LogicalTensorDesc4ArgNameAndIndex(\"input\", 0)->shape().elem_cnt();\n  return std::make_shared<BCEWithLogitsReduceMeanKernelCache>(reduce_elem_cnt);\n}\n\n}  // namespace\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/binary_cross_entropy_with_logits_reduce_mean.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/user/kernels/binary_cross_entropy_with_logits_mean_kernel_util.h\"\n#include \"oneflow/user/kernels/loss_kernel_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\nnamespace {\n\nusing namespace loss;\n\ntemplate<typename T>\ninline T ComputeMaxVal(const T x) {\n  T y = -x;\n  return y < 0 ? 0 : y;\n}\n\ntemplate<typename T>\ninline T CalSigmoid(const T x) {\n  const T half_of_one = static_cast<T>(0.5);\n  return half_of_one * std::tanh(half_of_one * x) + half_of_one;\n}\n\ntemplate<typename INPUT_T, typename TARGET_T, typename ComputeType>\nstruct ComputeBinaryCrossEntropyWithLogitsReduceMeanOutFunctor {\n  inline ComputeType Compute(int64_t elem_cnt, const INPUT_T* input, const TARGET_T* target,\n                             int64_t reduce_elem_cnt) {\n    ComputeType result = 0.0;\n    FOR_RANGE(int64_t, i, 0, elem_cnt) {\n      ComputeType input_val = static_cast<ComputeType>(input[i]);\n      ComputeType target_val = static_cast<ComputeType>(target[i]);\n      ComputeType max_val = ComputeMaxVal(input_val);\n      result += (1 - target_val) * input_val + max_val\n                + (std::log(std::exp(-max_val) + std::exp(-input_val - max_val)));\n    }\n    return static_cast<TARGET_T>(result) / reduce_elem_cnt;\n  }\n};\n\ntemplate<typename INPUT_T, typename TARGET_T>\nvoid ComputeBinaryCrossEntropyWithLogitsReduceMeanOut(int64_t elem_cnt, const INPUT_T* input,\n                                                      const TARGET_T* target, TARGET_T* out,\n                                                      int64_t reduce_elem_cnt) {\n  if (sizeof(INPUT_T) > sizeof(TARGET_T)) {\n    ComputeBinaryCrossEntropyWithLogitsReduceMeanOutFunctor<INPUT_T, TARGET_T, INPUT_T> f;\n    out[0] = f.Compute(elem_cnt, input, target, reduce_elem_cnt);\n  } else {\n    ComputeBinaryCrossEntropyWithLogitsReduceMeanOutFunctor<INPUT_T, TARGET_T, TARGET_T> f;\n    out[0] = f.Compute(elem_cnt, input, target, reduce_elem_cnt);\n  }\n}\n\ntemplate<typename INPUT_T, typename TARGET_T>\nvoid ComputeBinaryCrossEntropyWithLogitsReduceMeanGradOut(int64_t elem_cnt, const INPUT_T* input,\n                                                          const TARGET_T* target,\n                                                          const TARGET_T* dy, INPUT_T* dx,\n                                                          int64_t reduce_elem_cnt) {\n  INPUT_T dy_val = static_cast<INPUT_T>(dy[0]) / reduce_elem_cnt;\n  FOR_RANGE(int64_t, i, 0, elem_cnt) {\n    INPUT_T input_val = input[i];\n    INPUT_T target_val = static_cast<TARGET_T>(target[i]);\n    INPUT_T input_sigmoid = CalSigmoid(input_val);\n    dx[i] = (input_sigmoid - target_val) * dy_val;\n  }\n}\n\ntemplate<typename INPUT_T, typename TARGET_T>\nclass BinaryCrossEntropyWithLogitsReduceMeanKernel final : public user_op::OpKernel {\n public:\n  
BinaryCrossEntropyWithLogitsReduceMeanKernel() = default;\n  ~BinaryCrossEntropyWithLogitsReduceMeanKernel() override = default;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateBCEWithLogitsReduceMeanKernelCache(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache* cache) const override {\n    const auto* input_blob = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* target_blob = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    auto* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    int64_t local_elem_cnt = input_blob->shape_view().elem_cnt();\n    int64_t reduce_elem_cnt = local_elem_cnt;\n    if (cache != nullptr) {\n      // Because `out`'s SBP may be P or B, use reduce_elem_cnt as the reduce_mean factor.\n      const auto* bce_cache = dynamic_cast<const BCEWithLogitsReduceMeanKernelCache*>(cache);\n      CHECK_NOTNULL(bce_cache);\n      reduce_elem_cnt = bce_cache->reduce_elem_cnt();\n    }\n\n    const INPUT_T* input = input_blob->dptr<INPUT_T>();\n    const TARGET_T* target = target_blob->dptr<TARGET_T>();\n    TARGET_T* out = out_blob->mut_dptr<TARGET_T>();\n\n    ComputeBinaryCrossEntropyWithLogitsReduceMeanOut(local_elem_cnt, input, target, out,\n                                                     reduce_elem_cnt);\n  }\n};\n\ntemplate<typename INPUT_T, typename TARGET_T>\nclass BinaryCrossEntropyWithLogitsReduceMeanGradKernel final : public user_op::OpKernel {\n public:\n  BinaryCrossEntropyWithLogitsReduceMeanGradKernel() = default;\n  ~BinaryCrossEntropyWithLogitsReduceMeanGradKernel() override = default;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateBCEWithLogitsReduceMeanKernelCache(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache* cache) const override {\n    const auto* input_blob = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* target_blob = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    const auto* dy_blob = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    auto* dx_blob = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    int64_t local_elem_cnt = input_blob->shape_view().elem_cnt();\n    int64_t reduce_elem_cnt = local_elem_cnt;\n    if (cache != nullptr) {\n      // Because `out`'s SBP may be P or B, use reduce_elem_cnt as the reduce_mean factor.\n      const auto* bce_cache = dynamic_cast<const BCEWithLogitsReduceMeanKernelCache*>(cache);\n      CHECK_NOTNULL(bce_cache);\n      reduce_elem_cnt = bce_cache->reduce_elem_cnt();\n    }\n\n    const TARGET_T* dy = dy_blob->dptr<TARGET_T>();\n    const INPUT_T* input = input_blob->dptr<INPUT_T>();\n    const TARGET_T* target = target_blob->dptr<TARGET_T>();\n    INPUT_T* dx = dx_blob->mut_dptr<INPUT_T>();\n    ComputeBinaryCrossEntropyWithLogitsReduceMeanGradOut(local_elem_cnt, input, target, dy, dx,\n                                                         reduce_elem_cnt);\n  }\n};\n\n}  // namespace\n\n#define REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_KERNEL(input_dtype, target_dtype)   
\\\n  REGISTER_USER_KERNEL(\"binary_cross_entropy_with_logits_reduce_mean\")                            \\\n      .SetCreateFn<BinaryCrossEntropyWithLogitsReduceMeanKernel<input_dtype, target_dtype>>()     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                             \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<input_dtype>::value)   \\\n                       && (user_op::HobDataType(\"target\", 0) == GetDataType<target_dtype>::value) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<target_dtype>::value));\n\n#define REGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_GRAD_KERNEL(input_dtype,            \\\n                                                                          target_dtype)           \\\n  REGISTER_USER_KERNEL(\"binary_cross_entropy_with_logits_reduce_mean_grad\")                       \\\n      .SetCreateFn<BinaryCrossEntropyWithLogitsReduceMeanGradKernel<input_dtype, target_dtype>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                             \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<input_dtype>::value)   \\\n                       && (user_op::HobDataType(\"target\", 0) == GetDataType<target_dtype>::value) \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<target_dtype>::value)     \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<input_dtype>::value));\n\nREGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_KERNEL(float, float)\nREGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_KERNEL(float, double)\nREGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_KERNEL(double, float)\nREGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_KERNEL(double, double)\nREGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_GRAD_KERNEL(float, float)\nREGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_GRAD_KERNEL(float, double)\nREGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_GRAD_KERNEL(double, float)\nREGISTER_BINARY_CROSS_ENTROPY_WITH_LOGITS_REDUCE_MEAN_GRAD_KERNEL(double, double)\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/bincount_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/framework/user_op_hob.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n\nnamespace oneflow {\n\ntemplate<typename IDX, typename T>\nvoid BinCountComputeWeight(const IDX* in_ptr, const T* weight, T* out_ptr, int64_t size) {\n  FOR_RANGE(int64_t, i, 0, size) {\n    IDX idx = *(in_ptr + i);\n    out_ptr[idx] += weight[i];\n  }\n}\n\ntemplate<typename IDX, typename T>\nvoid BinCountCompute(const IDX* in_ptr, T* out_ptr, int64_t size) {\n  FOR_RANGE(int64_t, i, 0, size) {\n    IDX idx = *(in_ptr + i);\n    out_ptr[idx] += 1L;\n  }\n}\n\ntemplate<typename IDX, typename T>\nclass CpuBinCountKernel final : public user_op::OpKernel {\n public:\n  CpuBinCountKernel() = default;\n  ~CpuBinCountKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    size_t out_size = ctx->Attr<int64_t>(\"size\") * sizeof(T);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const IDX* in_ptr = in->dptr<IDX>();\n    T* out_ptr = out->mut_dptr<T>();\n    std::unique_ptr<ep::primitive::Memset> memset_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->device_type());\n    CHECK(memset_primitive);\n    memset_primitive->Launch(ctx->stream(), out_ptr, 0, out_size);\n    int64_t in_size = in->shape_view().elem_cnt();\n    if (ctx->has_input(\"weight\", 0)) {\n      const T* weight_ptr = ctx->Tensor4ArgNameAndIndex(\"weight\", 0)->dptr<T>();\n      BinCountComputeWeight<IDX, T>(in_ptr, weight_ptr, out_ptr, in_size);\n    } else {\n      BinCountCompute<IDX, T>(in_ptr, out_ptr, in_size);\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_BINCOUNT_KERNEL(idx_type, dtype)                                     \\\n  REGISTER_USER_KERNEL(\"bincount\")                                                        \\\n      .SetCreateFn<CpuBinCountKernel<idx_type, dtype>>()                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                     \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<idx_type>::value) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CPU_BINCOUNT_KERNEL(int64_t, int64_t)\nREGISTER_CPU_BINCOUNT_KERNEL(int64_t, float16)\nREGISTER_CPU_BINCOUNT_KERNEL(int64_t, float)\nREGISTER_CPU_BINCOUNT_KERNEL(int64_t, double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/bincount_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/framework/user_op_hob.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n\nnamespace oneflow {\nnamespace user_op {\nnamespace {\n\ntemplate<typename IDX, typename T, bool UseGlobalMem>\n__global__ static void BinCountCompute(const IDX* in_ptr, const T* weight, T* out_ptr,\n                                       int64_t in_size, int64_t out_size) {\n  if constexpr (UseGlobalMem) {\n    CUDA_1D_KERNEL_LOOP(i, in_size) {\n      IDX idx = *(in_ptr + i);\n      cuda::atomic::Add(out_ptr + idx, weight[i]);\n    }\n  } else {\n    __shared__ T shm[kCudaThreadsNumPerBlock];\n    T zero = GetZeroVal<T>();\n    shm[threadIdx.x] = zero;\n    __syncthreads();\n    CUDA_1D_KERNEL_LOOP(i, in_size) {\n      IDX idx = *(in_ptr + i);\n      cuda::atomic::Add(shm + idx, weight[i]);\n    }\n    __syncthreads();\n    if (threadIdx.x < out_size) { cuda::atomic::Add(out_ptr + threadIdx.x, shm[threadIdx.x]); }\n  }\n};\n\ntemplate<typename IDX, typename T, bool UseGlobalMem>\n__global__ static void BinCountCompute(const IDX* in_ptr, T* out_ptr, int64_t in_size,\n                                       int64_t out_size) {\n  T one = GetOneVal<T>();\n  if constexpr (UseGlobalMem) {\n    CUDA_1D_KERNEL_LOOP(i, in_size) {\n      IDX idx = *(in_ptr + i);\n      cuda::atomic::Add(out_ptr + idx, one);\n    }\n  } else {\n    __shared__ T shm[kCudaThreadsNumPerBlock];\n    T zero = GetZeroVal<T>();\n    shm[threadIdx.x] = zero;\n    __syncthreads();\n    CUDA_1D_KERNEL_LOOP(i, in_size) {\n      IDX idx = *(in_ptr + i);\n      cuda::atomic::Add(shm + idx, one);\n    }\n    __syncthreads();\n    if (threadIdx.x < out_size) { cuda::atomic::Add(out_ptr + threadIdx.x, shm[threadIdx.x]); }\n  }\n};\n\ntemplate<typename IDX, typename T, bool UseGlobalMem>\nstatic void BinCountDispatch(user_op::KernelComputeContext* ctx, const IDX* in_ptr,\n                             const T* weight_ptr, T* out_ptr, int64_t in_size, int64_t out_size) {\n  if (weight_ptr) {\n    BinCountCompute<IDX, T, UseGlobalMem>\n        <<<BlocksNum4ThreadsNum(in_size), kCudaThreadsNumPerBlock, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(in_ptr, weight_ptr, out_ptr,\n                                                                 in_size, out_size);\n  } else {\n    BinCountCompute<IDX, T, UseGlobalMem>\n        <<<BlocksNum4ThreadsNum(in_size), kCudaThreadsNumPerBlock, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(in_ptr, out_ptr, in_size,\n                                                                 out_size);\n  }\n}\n\ntemplate<typename IDX, typename T>\nclass CUDABinCountKernel final : public user_op::OpKernel {\n public:\n  CUDABinCountKernel() = default;\n  ~CUDABinCountKernel() override = default;\n\n 
private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    size_t out_size = ctx->Attr<int64_t>(\"size\");\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const IDX* in_ptr = in->dptr<IDX>();\n    T* out_ptr = out->mut_dptr<T>();\n\n    std::unique_ptr<ep::primitive::Memset> memset_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->device_type());\n    CHECK(memset_primitive);\n    memset_primitive->Launch(ctx->stream(), out_ptr, 0, out_size * sizeof(T));\n\n    const int64_t in_size = in->shape_view().elem_cnt();\n    if (in_size == 0) { return; }\n\n    const T* weight_ptr = nullptr;\n    if (ctx->has_input(\"weight\", 0)) {\n      weight_ptr = ctx->Tensor4ArgNameAndIndex(\"weight\", 0)->dptr<T>();\n    };\n\n    if (out_size > kCudaThreadsNumPerBlock) {\n      BinCountDispatch<IDX, T, true>(ctx, in_ptr, weight_ptr, out_ptr, in_size, out_size);\n    } else {\n      BinCountDispatch<IDX, T, false>(ctx, in_ptr, weight_ptr, out_ptr, in_size, out_size);\n    }\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n#define REGISTER_CUDA_BINCOUNT_KERNEL(idx_type, dtype)                                    \\\n  REGISTER_USER_KERNEL(\"bincount\")                                                        \\\n      .SetCreateFn<CUDABinCountKernel<idx_type, dtype>>()                                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                    \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<idx_type>::value) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CUDA_BINCOUNT_KERNEL(int64_t, int64_t)\nREGISTER_CUDA_BINCOUNT_KERNEL(int64_t, half)\nREGISTER_CUDA_BINCOUNT_KERNEL(int64_t, float)\nREGISTER_CUDA_BINCOUNT_KERNEL(int64_t, double)\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
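A note on the dispatch in `bincount_kernel.cu` above: `CUDABinCountKernel` switches on `out_size > kCudaThreadsNumPerBlock` because a histogram that fits within one block's `__shared__` array can be privatized per block and merged into global memory once at the end, cutting contention on global atomics. Below is a minimal CPU sketch of that privatize-then-merge idea, with `std::thread` standing in for thread blocks; `BinCountPrivatized` and `num_workers` are hypothetical names, not part of OneFlow.

```cpp
// Illustrative sketch only (not part of the repo): privatized histograms merged once
// at the end, the same strategy as the shared-memory path of the CUDA kernel above.
#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>

void BinCountPrivatized(const std::vector<int64_t>& in, std::vector<int64_t>* out,
                        int num_workers) {
  const int64_t num_bins = static_cast<int64_t>(out->size());
  // One private histogram per worker, analogous to one __shared__ array per block.
  std::vector<std::vector<int64_t>> priv(num_workers, std::vector<int64_t>(num_bins, 0));
  std::vector<std::thread> workers;
  for (int w = 0; w < num_workers; ++w) {
    workers.emplace_back([&, w] {
      // Grid-stride style partition of the input, like CUDA_1D_KERNEL_LOOP.
      for (size_t i = w; i < in.size(); i += num_workers) { priv[w][in[i]] += 1; }
    });
  }
  for (auto& t : workers) { t.join(); }
  // Single merge pass, analogous to the per-block atomic flush into global memory.
  for (int w = 0; w < num_workers; ++w) {
    for (int64_t b = 0; b < num_bins; ++b) { (*out)[b] += priv[w][b]; }
  }
}

int main() {
  std::vector<int64_t> in{0, 1, 1, 3, 3, 3, 2};
  std::vector<int64_t> out(4, 0);
  BinCountPrivatized(in, &out, 2);
  for (int64_t c : out) { std::cout << c << ' '; }  // prints: 1 2 1 3
  std::cout << '\n';
}
```

The merge pass is the only place where workers touch shared state, which is exactly the traffic reduction the shared-memory path buys on the GPU.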
  {
    "path": "oneflow/user/kernels/broadcast_div_grad_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<DeviceType device, typename T>\nclass BroadcastDivGradKernel final : public user_op::OpKernel {\n public:\n  BroadcastDivGradKernel() = default;\n  ~BroadcastDivGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const user_op::Tensor* z_tensor = ctx->Tensor4ArgNameAndIndex(\"z\", 0);\n    const user_op::Tensor* dz_tensor = ctx->Tensor4ArgNameAndIndex(\"dz\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n\n    const int64_t num_axes = dz_tensor->shape_view().NumAxes();\n    XpuVarNdarray<const T> dz(dz_tensor->shape_view(), dz_tensor->dptr<T>(), num_axes);\n    XpuVarNdarray<const T> const_tmp(dz.shape(), tmp_buffer->dptr<T>());\n    XpuVarNdarray<T> tmp(dz.shape(), tmp_buffer->mut_dptr<T>());\n\n    auto bcast_div = ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n        ctx->device_type(), ep::primitive::BinaryOp::kDiv, z_tensor->data_type(),\n        z_tensor->data_type(), z_tensor->shape_view().NumAxes());\n    CHECK(bcast_div);\n    bcast_div->Launch(ctx->stream(), z_tensor->shape_view().NumAxes(), z_tensor->shape_view().ptr(),\n                      z_tensor->dptr(), y_tensor->shape_view().NumAxes(),\n                      y_tensor->shape_view().ptr(), y_tensor->dptr<T>(), tmp_buffer->mut_dptr<T>());\n\n    if (IsComplexDataType(z_tensor->data_type())) {\n      auto conj = ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n          ctx->device_type(), ep::primitive::UnaryOp::kConj, z_tensor->data_type(),\n          z_tensor->data_type());\n      CHECK(conj);\n      const int64_t elem_cnt = dz_tensor->shape_view().elem_cnt();\n      conj->Launch(ctx->stream(), tmp_buffer->dptr<T>(), tmp_buffer->mut_dptr<T>(), elem_cnt);\n    }\n\n    auto bcast_mul = ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n        ctx->device_type(), ep::primitive::BinaryOp::kMul, dz_tensor->data_type(),\n        dz_tensor->data_type(), dz_tensor->shape_view().NumAxes());\n    CHECK(bcast_mul);\n    bcast_mul->Launch(ctx->stream(), dz_tensor->shape_view().NumAxes(),\n                      dz_tensor->shape_view().ptr(), tmp_buffer->dptr(),\n                      dz_tensor->shape_view().NumAxes(), dz_tensor->shape_view().ptr(),\n                      dz_tensor->dptr<T>(), tmp_buffer->mut_dptr<T>());\n\n    NdarrayUtil<device, T>::ReduceSum(\n        ctx->stream(),\n        
XpuVarNdarray<T>(dy_tensor->shape_view(), dy_tensor->mut_dptr<T>(), num_axes), const_tmp,\n        tmp);\n\n    auto negative = ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n        ctx->device_type(), ep::primitive::UnaryOp::kNegative, dy_tensor->data_type(),\n        dy_tensor->data_type());\n    CHECK(negative);\n    negative->Launch(ctx->stream(), dy_tensor->dptr(), dy_tensor->mut_dptr(),\n                     dy_tensor->shape_view().elem_cnt());\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n#define REGISTER_BROADCAST_DIV_GRAD_KERNEL(device, dtype_pair)                             \\\n  REGISTER_USER_KERNEL(\"broadcast_div_grad\")                                               \\\n      .SetCreateFn<BroadcastDivGradKernel<device, OF_PP_PAIR_FIRST(dtype_pair)>>()         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                \\\n                       && (user_op::HobDataType(\"y\", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \\\n      .SetInferTmpSizeFn([](oneflow::user_op::InferContext* ctx) {                         \\\n        const user_op::TensorDesc& z = ctx->InputTensorDesc(\"z\", 0);                       \\\n        DataType data_type = z.data_type();                                                \\\n        const int64_t elem_cnt = z.shape().elem_cnt();                                     \\\n        return GetCudaAlignedSize(elem_cnt * GetSizeOfDataType(data_type));                \\\n      });\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_BROADCAST_DIV_GRAD_KERNEL, DEVICE_TYPE_SEQ,\n                                 ARITHMETIC_DATA_TYPE_SEQ)\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_BROADCAST_DIV_GRAD_KERNEL,\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), COMPLEX_DATA_TYPE_SEQ)\n#ifdef WITH_CUDA\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_BROADCAST_DIV_GRAD_KERNEL, (DeviceType::kCUDA),\n                                 FLOAT16_DATA_TYPE_SEQ)\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_BROADCAST_DIV_GRAD_KERNEL, (DeviceType::kCUDA),\n                                 OF_PP_MAKE_TUPLE_SEQ(cuComplex, DataType::kComplex64))\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_BROADCAST_DIV_GRAD_KERNEL, (DeviceType::kCUDA),\n                                 OF_PP_MAKE_TUPLE_SEQ(cuDoubleComplex, DataType::kComplex128))\n#endif\n\n}  // namespace oneflow\n"
  },
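For reference, the arithmetic implemented by `BroadcastDivGradKernel` above is the quotient rule; with $z = x / y$ (where $y$ is broadcast against $x$):

$$
z = \frac{x}{y}
\quad\Rightarrow\quad
\frac{\partial z}{\partial y} = -\frac{x}{y^2} = -\frac{z}{y},
\qquad
dy = -\operatorname{reduce\_sum}\!\Big(dz \cdot \frac{z}{y}\Big),
$$

where the reduction sums over the axes along which $y$ was broadcast. For complex dtypes the kernel conjugates $z/y$ before the multiply, so what it actually computes is $dy = -\operatorname{reduce\_sum}(dz \cdot \overline{z/y})$.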
  {
    "path": "oneflow/user/kernels/broadcast_like_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<DeviceType device_type, typename T>\nclass BroadcastLikeKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  BroadcastLikeKernel() = default;\n  ~BroadcastLikeKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* like_tensor = ctx->Tensor4ArgNameAndIndex(\"like\", 0);\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const auto& axis = ctx->Attr<std::vector<int32_t>>(\"broadcast_axes\");\n    const Shape& reduced_shape =\n        CreateReducedShapeOrOnesShape(like_tensor->shape_view(), {axis.begin(), axis.end()});\n    NdarrayUtil<device_type, T>::BroadcastTo(\n        ctx->stream(), XpuVarNdarray<T>(out_tensor->shape_view(), out_tensor->mut_dptr<T>()),\n        XpuVarNdarray<const T>(reduced_shape, in_tensor->dptr<T>()));\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n#define REGISTER_BROADCAST_LIKE_XPU_KERNEL(device, dtype)   \\\n  REGISTER_USER_KERNEL(\"broadcast_like\")                    \\\n      .SetCreateFn<BroadcastLikeKernel<device, dtype>>()    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device) \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\n#ifdef WITH_CUDA\n#define REGISTER_BROADCAST_LIKE_KERNEL(dtype)                 \\\n  REGISTER_BROADCAST_LIKE_XPU_KERNEL(DeviceType::kCPU, dtype) \\\n  REGISTER_BROADCAST_LIKE_XPU_KERNEL(DeviceType::kCUDA, dtype)\n#else\n#define REGISTER_BROADCAST_LIKE_KERNEL(dtype) \\\n  REGISTER_BROADCAST_LIKE_XPU_KERNEL(DeviceType::kCPU, dtype)\n#endif\n\nREGISTER_BROADCAST_LIKE_KERNEL(float)\nREGISTER_BROADCAST_LIKE_KERNEL(float16)\nREGISTER_BROADCAST_LIKE_KERNEL(double)\nREGISTER_BROADCAST_LIKE_KERNEL(bool)\nREGISTER_BROADCAST_LIKE_KERNEL(int8_t)\nREGISTER_BROADCAST_LIKE_KERNEL(int32_t)\nREGISTER_BROADCAST_LIKE_KERNEL(int64_t)\nREGISTER_BROADCAST_LIKE_XPU_KERNEL(DeviceType::kCPU, std::complex<float>)\nREGISTER_BROADCAST_LIKE_XPU_KERNEL(DeviceType::kCPU, std::complex<double>)\n#ifdef WITH_CUDA\nREGISTER_BROADCAST_LIKE_XPU_KERNEL(DeviceType::kCUDA, cuComplex)\nREGISTER_BROADCAST_LIKE_XPU_KERNEL(DeviceType::kCUDA, cuDoubleComplex)\n#endif\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/cast_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/shape_vec.h\"\n#include \"oneflow/core/common/tensor_meta.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_unary.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::BroadcastElementwiseUnary> NewBroadcastPrimitive(Context* ctx) {\n  const DataType in_data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n  const DataType out_data_type = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->data_type();\n  const size_t max_ndim = std::max(ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->shape().NumAxes(),\n                                   ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->shape().NumAxes());\n  return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseUnaryFactory>(\n      ctx->device_type(), ep::primitive::UnaryOp::kCast, in_data_type, out_data_type, max_ndim);\n}\n\nclass CastKernel final : public OpKernel, public user_op::CudaGraphSupport {\n public:\n  CastKernel() = default;\n  ~CastKernel() = default;\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    const Tensor* input = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    Tensor* output = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t elem_cnt = input->shape_view().elem_cnt();\n    // 0-size tensor\n    CHECK_EQ(output->shape_view().elem_cnt(), elem_cnt)\n        << \"The number of cast op's input and output elements should be equal.\";\n    if (elem_cnt == 0) { return; }\n    if (input->data_type() == output->data_type() && input->dptr() == output->dptr()) { return; }\n    const size_t ndim = input->shape_view().NumAxes();\n    auto broadcast_primitive = NewBroadcastPrimitive(ctx);\n    CHECK(broadcast_primitive);\n    if (ndim == 0 && elem_cnt == 1) {\n      // 0-dim tensor\n      // TODO: remove these when BroadcastElementwiseUnary primitive support 0-dim(scalar) tensor\n      Shape input_shape(DimVector{1});\n      Shape output_shape(DimVector{1});\n      Stride input_stride(DimVector{1});\n      Stride output_stride(DimVector{1});\n      const size_t scalar_ndim = 1;\n      broadcast_primitive->Launch(ctx->stream(), scalar_ndim, input_shape.data(),\n                                  input_stride.data(), input->dptr(), scalar_ndim,\n                                  output_shape.data(), output_stride.data(), output->mut_dptr());\n    } else {\n      broadcast_primitive->Launch(\n          ctx->stream(), ndim, input->shape_view().data(), input->stride().data(), input->dptr(),\n          ndim, output->shape_view().data(), output->stride().data(), output->mut_dptr());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() 
const override { return false; }\n};\n\nauto BroadcastPrimitiveExists() {\n  return hob::make_custom(\"BroadcastElementwiseUnaryPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) -> bool {\n                            return NewBroadcastPrimitive(&ctx).operator bool();\n                          });\n}\n\nREGISTER_USER_KERNEL(\"cast\")\n    .SetCreateFn<CastKernel>()\n    .SetIsMatchedHob(BroadcastPrimitiveExists() == true)\n    .SetInplaceProposalFn([](const user_op::InferContext& ctx,\n                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      if (ctx.InputDType(\"in\", 0) == ctx.Attr<DataType>(\"dtype\")) {\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"in\", 0, false));\n      }\n      return Maybe<void>::Ok();\n    });\n\nREGISTER_USER_KERNEL(\"cast_like\")\n    .SetCreateFn<CastKernel>()\n    .SetIsMatchedHob(BroadcastPrimitiveExists() == true)\n    .SetInplaceProposalFn([](const user_op::InferContext& ctx,\n                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      if (ctx.InputDType(\"in\", 0) == ctx.InputDType(\"like\", 0)) {\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"in\", 0, false));\n      }\n      return Maybe<void>::Ok();\n    });\n\n}  // namespace\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
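The `BroadcastPrimitiveExists()` hob in `cast_kernel.cpp` above is worth calling out: instead of enumerating supported dtype pairs in the registration, the match predicate simply asks the primitive factory whether construction succeeds. A minimal standalone sketch of the pattern follows; `DType`, `CastPrimitive`, `NewCastPrimitive`, and `CastPrimitiveExists` are hypothetical stand-ins for `ep::primitive::NewPrimitive` and `hob::make_custom`.

```cpp
// Illustrative sketch only (not part of the repo): register-if-constructible.
#include <iostream>
#include <memory>

enum class DType { kFloat, kInt32, kComplex64 };

struct CastPrimitive {
  void Launch() const { std::cout << "cast launched\n"; }
};

// Returns nullptr for unsupported dtype combinations, like a real primitive factory.
std::unique_ptr<CastPrimitive> NewCastPrimitive(DType from, DType to) {
  if (from == DType::kComplex64 || to == DType::kComplex64) { return nullptr; }
  return std::make_unique<CastPrimitive>();
}

// The match predicate just asks the factory; no dtype whitelist is duplicated here.
bool CastPrimitiveExists(DType from, DType to) {
  return static_cast<bool>(NewCastPrimitive(from, to));
}

int main() {
  std::cout << CastPrimitiveExists(DType::kFloat, DType::kInt32) << '\n';      // 1
  std::cout << CastPrimitiveExists(DType::kFloat, DType::kComplex64) << '\n';  // 0
}
```

The design choice keeps the knowledge of what is supported in exactly one place, the factory, so kernel registration can never drift out of sync with it.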
  {
    "path": "oneflow/user/kernels/cast_to_static_shape_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass CastToStaticShapeKernel final : public user_op::OpKernel {\n public:\n  CastToStaticShapeKernel() = default;\n  ~CastToStaticShapeKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const Shape& input_static_shape = ctx->TensorDesc4ArgNameAndIndex(\"input\", 0)->shape();\n    user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    CHECK(input_tensor->shape_view() == ShapeView(input_static_shape));\n    CHECK_EQ(output_tensor->shape_view(), input_tensor->shape_view());\n    size_t output_tensor_size =\n        output_tensor->shape_view().elem_cnt() * GetSizeOfDataType(output_tensor->data_type());\n    std::unique_ptr<ep::primitive::Memcpy> primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemcpyFactory>(ctx->stream()->device_type(),\n                                                                  ep::primitive::MemcpyKind::kDtoD);\n    CHECK(primitive) << \"Can not create Memcpy primitive for device type \"\n                     << ctx->stream()->device_type();\n    primitive->Launch(ctx->stream(), output_tensor->mut_dptr(), input_tensor->dptr(),\n                      output_tensor_size);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\nREGISTER_USER_KERNEL(\"cast_to_static_shape\")\n    .SetCreateFn<CastToStaticShapeKernel>()\n    .SetInplaceProposalFn([](const user_op::InferContext&,\n                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"output\", 0, \"input\", 0, false));\n      return Maybe<void>::Ok();\n    });\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/categorical_ordinal_encode_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/categorical_ordinal_encode_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nclass CategoricalOrdinalEncodeKernel final : public user_op::OpKernel {\n public:\n  CategoricalOrdinalEncodeKernel() = default;\n  ~CategoricalOrdinalEncodeKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    bool hash_precomputed = ctx->Attr<bool>(\"hash_precomputed\");\n    CHECK(hash_precomputed);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* table = ctx->Tensor4ArgNameAndIndex(\"table\", 0);\n    user_op::Tensor* size = ctx->Tensor4ArgNameAndIndex(\"size\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t table_elem_cnt = table->shape_view().elem_cnt();\n    CHECK_EQ(table_elem_cnt % 2, 0);\n    const int64_t capacity = table_elem_cnt / 2;\n    CategoricalOrdinalEncodeKernelUtil<device_type, T>::Encode(\n        ctx->stream(), capacity, table->mut_dptr<T>(), size->mut_dptr<T>(),\n        in->shape_view().elem_cnt(), in->dptr<T>(), out->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_CATEGORICAL_ORDINAL_ENCODE_KERNEL(device, proto_type, cpp_type) \\\n  REGISTER_USER_KERNEL(\"CategoricalOrdinalEncode\")                               \\\n      .SetCreateFn<CategoricalOrdinalEncodeKernel<device, cpp_type>>()           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                      \\\n                       && (user_op::HobDataType(\"in\", 0) == proto_type));\n\nREGISTER_CATEGORICAL_ORDINAL_ENCODE_KERNEL(DeviceType::kCPU, DataType::kInt32, int32_t);\nREGISTER_CATEGORICAL_ORDINAL_ENCODE_KERNEL(DeviceType::kCPU, DataType::kInt64, int64_t);\n#ifdef WITH_CUDA\nREGISTER_CATEGORICAL_ORDINAL_ENCODE_KERNEL(DeviceType::kCUDA, DataType::kInt32, int32_t);\nREGISTER_CATEGORICAL_ORDINAL_ENCODE_KERNEL(DeviceType::kCUDA, DataType::kInt64, int64_t);\n#endif\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/categorical_ordinal_encode_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/categorical_ordinal_encode_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstruct CategoricalOrdinalEncodeKernelUtil<DeviceType::kCPU, T> {\n  static void Encode(ep::Stream* stream, int64_t capacity, T* table, T* size, int64_t n,\n                     const T* hash, T* out) {\n    for (int64_t i = 0; i < n; ++i) {\n      const T h = hash[i];\n      bool success = false;\n      for (int64_t count = 0; count < capacity; ++count) {\n        size_t idx =\n            (static_cast<size_t>(h) + static_cast<size_t>(count)) % static_cast<size_t>(capacity);\n        T* k_ptr = table + idx * 2;\n        T* v_ptr = k_ptr + 1;\n        if (*k_ptr == h) {\n          out[i] = *v_ptr;\n          success = true;\n          break;\n        } else if (*k_ptr == 0) {\n          T new_size = *size + 1;\n          *k_ptr = h;\n          *v_ptr = new_size;\n          out[i] = new_size;\n          *size = new_size;\n          success = true;\n          break;\n        } else {\n          continue;\n        }\n      }\n      CHECK(success);\n    }\n  }\n};\n\n#define INSTANTIATE_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_CPU(type_cpp, type_proto) \\\n  template struct CategoricalOrdinalEncodeKernelUtil<DeviceType::kCPU, type_cpp>;\nOF_PP_FOR_EACH_TUPLE(INSTANTIATE_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_CPU, INDEX_DATA_TYPE_SEQ);\n#undef INSTANTIATE_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_CPU\n\n}  // namespace oneflow\n"
  },
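The CPU `Encode` above is a linear-probing open-addressing table with keys and values interleaved in one flat array and key `0` reserved as the empty-slot marker (the GPU variant special-cases a zero hash; this CPU path assumes nonzero hashes, which the `hash_precomputed` attribute is expected to guarantee). A minimal standalone sketch under those same assumptions; `Encode` here is a hypothetical free function, not the OneFlow API.

```cpp
// Illustrative sketch only (not part of the repo): interleaved key/value table with
// linear probing, mirroring CategoricalOrdinalEncodeKernelUtil<DeviceType::kCPU, T>.
#include <cstdint>
#include <iostream>
#include <vector>

// table layout: [k0, v0, k1, v1, ...]; capacity = number of (k, v) slots; key 0 = empty.
int64_t Encode(std::vector<int64_t>* table, int64_t* size, int64_t hash) {
  const size_t capacity = table->size() / 2;
  for (size_t count = 0; count < capacity; ++count) {
    size_t idx = (static_cast<size_t>(hash) + count) % capacity;
    int64_t* k = table->data() + idx * 2;
    int64_t* v = k + 1;
    if (*k == hash) { return *v; }  // already encoded: reuse the ordinal
    if (*k == 0) {                  // empty slot: assign the next ordinal
      *k = hash;
      *v = ++(*size);
      return *v;
    }
    // otherwise: collision, probe the next slot
  }
  return -1;  // table full (the kernel CHECK-fails in this case)
}

int main() {
  std::vector<int64_t> table(16, 0);  // 8 slots
  int64_t size = 0;
  for (int64_t h : {42, 7, 42, 99, 7}) { std::cout << Encode(&table, &size, h) << ' '; }
  std::cout << '\n';  // prints: 1 2 1 3 2
}
```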
  {
    "path": "oneflow/user/kernels/categorical_ordinal_encode_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef NDEBUG\n#undef NDEBUG\n#endif\n#include <assert.h>\n#include \"oneflow/user/kernels/categorical_ordinal_encode_kernel_util.h\"\n#include \"oneflow/core/kernel/kernel_util.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nusing CuInt64T = unsigned long long int;\n\n__device__ __inline__ int32_t AtomicCAS(int32_t* address, int32_t compare, int32_t val) {\n  return atomicCAS(address, compare, val);\n}\n\n__device__ __inline__ int64_t AtomicCAS(int64_t* address, int64_t compare, int64_t val) {\n  static_assert(sizeof(int64_t) == sizeof(CuInt64T), \"size error\");\n  return static_cast<int64_t>(atomicCAS(reinterpret_cast<CuInt64T*>(address),\n                                        static_cast<CuInt64T>(compare),\n                                        static_cast<CuInt64T>(val)));\n}\n\n__device__ __inline__ int32_t AtomicAdd(int32_t* address, int32_t val) {\n  return atomicAdd(address, val);\n}\n\n__device__ __inline__ int64_t AtomicAdd(int64_t* address, int64_t val) {\n  static_assert(sizeof(int64_t) == sizeof(CuInt64T), \"size error\");\n  return static_cast<int64_t>(\n      atomicAdd(reinterpret_cast<CuInt64T*>(address), static_cast<CuInt64T>(val)));\n}\n\ntemplate<typename K, typename V>\n__device__ bool TryGetOrInsert(K* key, volatile V* value, V* size, const K hash, V* out) {\n  K old_key = AtomicCAS(key, static_cast<K>(0), hash);\n  if (old_key == 0) {\n    V v = AtomicAdd(size, 1) + 1;\n    *value = v;\n    *out = v;\n    return true;\n  } else if (old_key == hash) {\n    while (true) {\n      V v = *value;\n      if (v != 0) {\n        *out = v;\n        break;\n      }\n    }\n    return true;\n  } else {\n    return false;\n  }\n}\n\ntemplate<typename T>\n__device__ bool GetOrInsertOne(const size_t capacity, T* table, T* size, const T hash, T* out) {\n  if (hash == 0) {\n    *out = 0;\n    return true;\n  }\n  const size_t start_idx = static_cast<size_t>(hash) % capacity;\n  // fast path\n  {\n    T* key = table + start_idx * 2;\n    T* value = key + 1;\n    if (*key == hash && *value != 0) {\n      *out = *value;\n      return true;\n    }\n  }\n  for (size_t count = 0; count < capacity; ++count) {\n    const size_t idx = (start_idx + count) % capacity;\n    T* key = table + idx * 2;\n    T* value = key + 1;\n    if (TryGetOrInsert<T, T>(key, value, size, hash, out)) { return true; }\n  }\n  return false;\n}\n\ntemplate<typename T>\n__global__ void EncodeGpu(const size_t capacity, T* table, T* size, const int64_t n, const T* hash,\n                          T* out) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    bool success = GetOrInsertOne<T>(capacity, table, size, hash[i], out + i);\n    assert(success);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nstruct CategoricalOrdinalEncodeKernelUtil<DeviceType::kCUDA, T> {\n  static void Encode(ep::Stream* stream, int64_t capacity, T* table, T* size, int64_t n,\n               
      const T* hash, T* out) {\n    EncodeGpu<T>\n        <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(capacity, table, size, n, hash, out);\n  }\n};\n\n#define INSTANTIATE_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_CUDA(type_cpp, type_proto) \\\n  template struct CategoricalOrdinalEncodeKernelUtil<DeviceType::kCUDA, type_cpp>;\nOF_PP_FOR_EACH_TUPLE(INSTANTIATE_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_CUDA, INDEX_DATA_TYPE_SEQ);\n#undef INSTANTIATE_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_CUDA\n\n}  // namespace oneflow\n"
  },
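The CUDA `TryGetOrInsert` above is a claim-then-publish protocol: a thread claims an empty slot by CAS-ing the key from 0 to its hash, and threads that lose the race spin on the (volatile) value until the winner publishes a nonzero ordinal. A sketch of the same protocol in portable C++ with `std::atomic`; `Slot` and `GetOrInsert` are hypothetical names.

```cpp
// Illustrative sketch only (not part of the repo): lock-free claim-then-publish,
// restating the TryGetOrInsert protocol of the CUDA kernel util above.
#include <atomic>
#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>

struct Slot {
  std::atomic<int64_t> key{0};    // 0 means "empty"
  std::atomic<int64_t> value{0};  // 0 means "not yet published"
};

int64_t GetOrInsert(Slot* slot, std::atomic<int64_t>* size, int64_t hash) {
  int64_t expected = 0;
  if (slot->key.compare_exchange_strong(expected, hash)) {
    // We claimed the slot: assign the next ordinal, then publish it.
    int64_t v = size->fetch_add(1) + 1;
    slot->value.store(v);
    return v;
  }
  if (expected == hash) {
    // Someone else claimed the same key: spin until the value is published.
    int64_t v = 0;
    while ((v = slot->value.load()) == 0) {}
    return v;
  }
  return -1;  // slot holds a different key; the real code probes the next slot
}

int main() {
  Slot slot;
  std::atomic<int64_t> size{0};
  std::vector<std::thread> threads;
  for (int i = 0; i < 4; ++i) {
    threads.emplace_back([&] { std::cout << GetOrInsert(&slot, &size, 42) << '\n'; });
  }
  for (auto& t : threads) { t.join(); }  // all four threads obtain ordinal 1
}
```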
  {
    "path": "oneflow/user/kernels/categorical_ordinal_encode_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_H_\n\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nstruct CategoricalOrdinalEncodeKernelUtil {\n  static void Encode(ep::Stream* stream, int64_t capacity, T* table, T* size, int64_t n,\n                     const T* hash, T* out);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_CATEGORICAL_ORDINAL_ENCODE_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/clip_by_value_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/clip_by_value_kernel.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#ifdef WITH_CUDA\n#include <cuda_fp16.h>\n#endif\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nT GetDtypeMatchedValue(double floating, int64_t integral);\n\ntemplate<>\nfloat GetDtypeMatchedValue(double floating, int64_t integral) {\n  return static_cast<float>(floating);\n}\n\ntemplate<>\ndouble GetDtypeMatchedValue(double floating, int64_t integral) {\n  return floating;\n}\n\ntemplate<>\nint8_t GetDtypeMatchedValue(double floating, int64_t integral) {\n  return static_cast<int8_t>(integral);\n}\n\ntemplate<>\nint32_t GetDtypeMatchedValue(double floating, int64_t integral) {\n  return static_cast<int32_t>(integral);\n}\n\ntemplate<>\nint64_t GetDtypeMatchedValue(double floating, int64_t integral) {\n  return integral;\n}\n\n#ifdef WITH_CUDA\ntemplate<>\nhalf GetDtypeMatchedValue(double floating, int64_t integral) {\n#if CUDA_VERSION >= 11000\n  return __double2half(floating);\n#else\n  return __float2half(static_cast<float>(floating));\n#endif\n}\n#endif\n\ntemplate<>\nfloat16 GetDtypeMatchedValue(double floating, int64_t integral) {\n  return static_cast<float16>(floating);\n}\n\n}  // namespace\n\ntemplate<typename T>\nstruct ClipKernelUtil<DeviceType::kCPU, T> {\n  template<typename F>\n  static void Forward(ep::Stream* stream, F clip_func, const int64_t n, const T* x, T* y) {\n    FOR_RANGE(int64_t, i, 0, n) { y[i] = clip_func(x[i]); }\n  }\n\n  template<typename F>\n  static void Backward(ep::Stream* stream, F clip_func, const int64_t n, const T* x, const T* dy,\n                       T* dx) {\n    FOR_RANGE(int64_t, i, 0, n) { dx[i] = clip_func(x[i], dy[i]); }\n  }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass ClipByScalarKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ClipByScalarKernel() = default;\n  ~ClipByScalarKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    double floating_min = ctx->Attr<double>(\"floating_min\");\n    int64_t integral_min = ctx->Attr<int64_t>(\"integral_min\");\n    double floating_max = ctx->Attr<double>(\"floating_max\");\n    int64_t integral_max = ctx->Attr<int64_t>(\"integral_max\");\n    ClipByMinMaxFunctor<T> clip_func(GetDtypeMatchedValue<T>(floating_min, integral_min),\n                                     GetDtypeMatchedValue<T>(floating_max, integral_max));\n    ClipKernelUtil<device_type, T>::Forward(ctx->stream(), clip_func, y->shape_view().elem_cnt(),\n                                            x->dptr<T>(), y->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; 
}\n};\n\ntemplate<DeviceType device_type, typename T>\nclass ClipByScalarMinKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ClipByScalarMinKernel() = default;\n  ~ClipByScalarMinKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    double floating_min = ctx->Attr<double>(\"floating_min\");\n    int64_t integral_min = ctx->Attr<int64_t>(\"integral_min\");\n    ClipByMinFunctor<T> clip_func(GetDtypeMatchedValue<T>(floating_min, integral_min));\n    ClipKernelUtil<device_type, T>::Forward(ctx->stream(), clip_func, y->shape_view().elem_cnt(),\n                                            x->dptr<T>(), y->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass ClipByScalarMaxKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ClipByScalarMaxKernel() = default;\n  ~ClipByScalarMaxKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    double floating_max = ctx->Attr<double>(\"floating_max\");\n    int64_t integral_max = ctx->Attr<int64_t>(\"integral_max\");\n    ClipByMaxFunctor<T> clip_func(GetDtypeMatchedValue<T>(floating_max, integral_max));\n    ClipKernelUtil<device_type, T>::Forward(ctx->stream(), clip_func, y->shape_view().elem_cnt(),\n                                            x->dptr<T>(), y->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass ClipByScalarGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ClipByScalarGradKernel() = default;\n  ~ClipByScalarGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    double floating_min = ctx->Attr<double>(\"floating_min\");\n    int64_t integral_min = ctx->Attr<int64_t>(\"integral_min\");\n    double floating_max = ctx->Attr<double>(\"floating_max\");\n    int64_t integral_max = ctx->Attr<int64_t>(\"integral_max\");\n    ClipByMinMaxGradFunctor<T> clip_func(GetDtypeMatchedValue<T>(floating_min, integral_min),\n                                         GetDtypeMatchedValue<T>(floating_max, integral_max));\n    ClipKernelUtil<device_type, T>::Backward(ctx->stream(), clip_func, dx->shape_view().elem_cnt(),\n                                             x->dptr<T>(), dy->dptr<T>(), dx->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass ClipByScalarMinGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ClipByScalarMinGradKernel() = default;\n  ~ClipByScalarMinGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const 
user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    double floating_min = ctx->Attr<double>(\"floating_min\");\n    int64_t integral_min = ctx->Attr<int64_t>(\"integral_min\");\n    ClipByMinGradFunctor<T> clip_func(GetDtypeMatchedValue<T>(floating_min, integral_min));\n    ClipKernelUtil<device_type, T>::Backward(ctx->stream(), clip_func, dx->shape_view().elem_cnt(),\n                                             x->dptr<T>(), dy->dptr<T>(), dx->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass ClipByScalarMaxGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ClipByScalarMaxGradKernel() = default;\n  ~ClipByScalarMaxGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    double floating_max = ctx->Attr<double>(\"floating_max\");\n    int64_t integral_max = ctx->Attr<int64_t>(\"integral_max\");\n    ClipByMaxGradFunctor<T> clip_func(GetDtypeMatchedValue<T>(floating_max, integral_max));\n    ClipKernelUtil<device_type, T>::Backward(ctx->stream(), clip_func, dx->shape_view().elem_cnt(),\n                                             x->dptr<T>(), dy->dptr<T>(), dx->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CLIP_KERNEL(op_type_name, kernel_name, device_type_v, dtype)          \\\n  REGISTER_USER_KERNEL(#op_type_name)                                                  \\\n      .SetCreateFn<kernel_name##Kernel<device_type_v, dtype>>()                        \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type_v)                     \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)) \\\n      .SetInplaceProposalFn(                                                           \\\n          [](const user_op::InferContext&,                                             \\\n             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {   \\\n            OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"y\", 0, \"x\", 0, true));             \\\n            return Maybe<void>::Ok();                                                  \\\n          });\n\n#define REGISTER_CLIP_GRAD_KERNEL(op_type_name, kernel_name, device_type_v, dtype)      \\\n  REGISTER_USER_KERNEL(#op_type_name)                                                   \\\n      .SetCreateFn<kernel_name##GradKernel<device_type_v, dtype>>()                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type_v)                      \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value)) \\\n      .SetInplaceProposalFn(                                                            \\\n          [](const user_op::InferContext&,                                              \\\n             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {    \\\n            OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"dx\", 0, \"dy\", 0, true));            \\\n            return Maybe<void>::Ok();                                                   \\\n          
});\n\n#define REGISTER_CLIP_KERNELS(device_type_v, dtype_pair)                                          \\\n  REGISTER_CLIP_KERNEL(clip_by_scalar, ClipByScalar, device_type_v, OF_PP_PAIR_FIRST(dtype_pair)) \\\n  REGISTER_CLIP_KERNEL(clip_by_scalar_min, ClipByScalarMin, device_type_v,                        \\\n                       OF_PP_PAIR_FIRST(dtype_pair))                                              \\\n  REGISTER_CLIP_KERNEL(clip_by_scalar_max, ClipByScalarMax, device_type_v,                        \\\n                       OF_PP_PAIR_FIRST(dtype_pair))                                              \\\n  REGISTER_CLIP_GRAD_KERNEL(clip_by_scalar_grad, ClipByScalar, device_type_v,                     \\\n                            OF_PP_PAIR_FIRST(dtype_pair))                                         \\\n  REGISTER_CLIP_GRAD_KERNEL(clip_by_scalar_min_grad, ClipByScalarMin, device_type_v,              \\\n                            OF_PP_PAIR_FIRST(dtype_pair))                                         \\\n  REGISTER_CLIP_GRAD_KERNEL(clip_by_scalar_max_grad, ClipByScalarMax, device_type_v,              \\\n                            OF_PP_PAIR_FIRST(dtype_pair))\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CLIP_KERNELS, DEVICE_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ)\nREGISTER_CLIP_KERNELS(DeviceType::kCPU, (float16, DataType::kFloat16))\n#ifdef WITH_CUDA\nREGISTER_CLIP_KERNELS(DeviceType::kCUDA, (half, DataType::kFloat16))\n#endif  // WITH_CUDA\n\n}  // namespace oneflow\n"
  },
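The `OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CLIP_KERNELS, DEVICE_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ)` line above stamps out one registration per (device, dtype) pair. A toy version of that cartesian-product X-macro idea, with hypothetical two-element sequences standing in for the OneFlow preprocessor sequences:

```cpp
// Illustrative sketch only (not part of the repo): a miniature seq-product macro.
#include <iostream>

// One macro application per (device, dtype) pair.
#define FOR_EACH_DEVICE(M, dtype) M(CPU, dtype) M(CUDA, dtype)
#define FOR_EACH_DTYPE(M) FOR_EACH_DEVICE(M, float) FOR_EACH_DEVICE(M, double)

#define REGISTER_PAIR(dev, dtype) \
  void Register_##dev##_##dtype() { std::cout << #dev " " #dtype "\n"; }
FOR_EACH_DTYPE(REGISTER_PAIR)  // expands to four Register_* definitions
#undef REGISTER_PAIR

int main() {
  Register_CPU_float();
  Register_CUDA_float();
  Register_CPU_double();
  Register_CUDA_double();
}
```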
  {
    "path": "oneflow/user/kernels/clip_by_value_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/clip_by_value_kernel.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename F>\n__global__ void CudaClipForward(F clip_func, int64_t n, const T* x, T* y) {\n  CUDA_1D_KERNEL_LOOP(i, n) { y[i] = clip_func(x[i]); }\n}\n\ntemplate<typename T, typename F>\n__global__ void CudaClipBackward(F clip_func, int64_t n, const T* x, const T* dy, T* dx) {\n  CUDA_1D_KERNEL_LOOP(i, n) { dx[i] = clip_func(x[i], dy[i]); }\n}\n\n}  // namespace\n\ntemplate<typename T>\nstruct ClipKernelUtil<DeviceType::kCUDA, T> {\n  template<typename F>\n  static void Forward(ep::Stream* stream, F clip_func, const int64_t n, const T* x, T* y) {\n    if (n == 0) { return; }\n    RUN_CUDA_KERNEL((CudaClipForward<T, F>), stream, n, clip_func, n, x, y);\n  }\n\n  template<typename F>\n  static void Backward(ep::Stream* stream, F clip_func, const int64_t n, const T* x, const T* dy,\n                       T* dx) {\n    if (n == 0) { return; }\n    RUN_CUDA_KERNEL((CudaClipBackward<T, F>), stream, n, clip_func, n, x, dy, dx);\n  }\n};\n\n#define INITIATE_CLIP_KERNEL_UTIL_CUDA(dtype, dtype_v)                                          \\\n  template struct ClipKernelUtil<DeviceType::kCUDA, dtype>;                                     \\\n  template void ClipKernelUtil<DeviceType::kCUDA, dtype>::Forward(                              \\\n      ep::Stream*, ClipByMinFunctor<dtype>, const int64_t n, const dtype*, dtype*);             \\\n  template void ClipKernelUtil<DeviceType::kCUDA, dtype>::Forward(                              \\\n      ep::Stream*, ClipByMaxFunctor<dtype>, const int64_t n, const dtype*, dtype*);             \\\n  template void ClipKernelUtil<DeviceType::kCUDA, dtype>::Forward(                              \\\n      ep::Stream*, ClipByMinMaxFunctor<dtype>, const int64_t n, const dtype*, dtype*);          \\\n  template void ClipKernelUtil<DeviceType::kCUDA, dtype>::Backward(                             \\\n      ep::Stream*, ClipByMinGradFunctor<dtype>, const int64_t n, const dtype*, const dtype*,    \\\n      dtype*);                                                                                  \\\n  template void ClipKernelUtil<DeviceType::kCUDA, dtype>::Backward(                             \\\n      ep::Stream*, ClipByMaxGradFunctor<dtype>, const int64_t n, const dtype*, const dtype*,    \\\n      dtype*);                                                                                  \\\n  template void ClipKernelUtil<DeviceType::kCUDA, dtype>::Backward(                             \\\n      ep::Stream*, ClipByMinMaxGradFunctor<dtype>, const int64_t n, const dtype*, const dtype*, \\\n      dtype*);\n\nOF_PP_FOR_EACH_TUPLE(INITIATE_CLIP_KERNEL_UTIL_CUDA, ARITHMETIC_DATA_TYPE_SEQ)\nINITIATE_CLIP_KERNEL_UTIL_CUDA(half, DataType::kFloat16)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/clip_by_value_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_CLIP_BY_VALUE_KERNEL_H_\n#define ONEFLOW_USER_KERNELS_CLIP_BY_VALUE_KERNEL_H_\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nOF_DEVICE_FUNC T DeviceMin(T a, T b) {\n#if defined(__CUDA_ARCH__)\n  return a < b ? a : b;\n#else\n  return std::min(a, b);\n#endif\n}\n\ntemplate<typename T>\nOF_DEVICE_FUNC T DeviceMax(T a, T b) {\n#if defined(__CUDA_ARCH__)\n  return a > b ? a : b;\n#else\n  return std::max(a, b);\n#endif\n}\n\ntemplate<typename T>\nstruct ClipByMinFunctor {\n  ClipByMinFunctor(T min) : min_value(min) {}\n  OF_DEVICE_FUNC T operator()(T value) { return DeviceMax(value, min_value); }\n  T min_value;\n};\n\ntemplate<typename T>\nstruct ClipByMaxFunctor {\n  ClipByMaxFunctor(T max) : max_value(max) {}\n  OF_DEVICE_FUNC T operator()(T value) { return DeviceMin(value, max_value); }\n  T max_value;\n};\n\ntemplate<typename T>\nstruct ClipByMinMaxFunctor {\n  ClipByMinMaxFunctor(T min, T max) : min_value(min), max_value(max) {}\n  OF_DEVICE_FUNC T operator()(T value) { return DeviceMin(DeviceMax(value, min_value), max_value); }\n  T min_value;\n  T max_value;\n};\n\ntemplate<typename T>\nstruct ClipByMinGradFunctor {\n  ClipByMinGradFunctor(T min) : min_value(min) {}\n  OF_DEVICE_FUNC T operator()(T value, T grad) {\n    return value < min_value ? static_cast<T>(0) : grad;\n  }\n  T min_value;\n};\n\ntemplate<typename T>\nstruct ClipByMaxGradFunctor {\n  ClipByMaxGradFunctor(T max) : max_value(max) {}\n  OF_DEVICE_FUNC T operator()(T value, T grad) {\n    return value > max_value ? static_cast<T>(0) : grad;\n  }\n  T max_value;\n};\n\ntemplate<typename T>\nstruct ClipByMinMaxGradFunctor {\n  ClipByMinMaxGradFunctor(T min, T max) : min_value(min), max_value(max) {}\n  OF_DEVICE_FUNC T operator()(T value, T grad) {\n    return (value < min_value || value > max_value) ? static_cast<T>(0) : grad;\n  }\n  T min_value;\n  T max_value;\n};\n\ntemplate<DeviceType device_type, typename T>\nstruct ClipKernelUtil {\n  template<typename F>\n  static void Forward(ep::Stream* stream, F clip_func, const int64_t n, const T* x, T* y);\n  template<typename F>\n  static void Backward(ep::Stream* stream, F clip_func, const int64_t n, const T* x, const T* dy,\n                       T* dx);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_CLIP_BY_VALUE_KERNEL_H_\n"
  },
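One subtlety in the grad functors of `clip_by_value_kernel.h` above: the comparisons are strict (`<`, `>`), so an input exactly equal to `min_value` or `max_value` still passes the incoming gradient through; only strictly clipped values get a zero gradient. A standalone restatement of the forward/backward pairing; the free functions here are illustrative, not the OneFlow API.

```cpp
// Illustrative sketch only (not part of the repo): clip forward and its gradient.
#include <algorithm>
#include <iostream>

template<typename T>
T ClipForward(T x, T lo, T hi) { return std::min(std::max(x, lo), hi); }

template<typename T>
T ClipBackward(T x, T dy, T lo, T hi) {
  return (x < lo || x > hi) ? static_cast<T>(0) : dy;  // zero only where the clamp was active
}

int main() {
  const float lo = 0.f, hi = 1.f;
  for (float x : {-0.5f, 0.f, 0.5f, 1.f, 1.5f}) {
    std::cout << "x=" << x << " y=" << ClipForward(x, lo, hi)
              << " dx=" << ClipBackward(x, 1.f, lo, hi) << '\n';
  }
}
```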
  {
    "path": "oneflow/user/kernels/coco_reader_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/data/coco_data_reader.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass COCOReaderWrapper final : public user_op::OpKernelState {\n public:\n  explicit COCOReaderWrapper(user_op::KernelInitContext* ctx) : reader_(ctx) {}\n  ~COCOReaderWrapper() = default;\n\n  void Read(user_op::KernelComputeContext* ctx) { reader_.Read(ctx); }\n\n private:\n  data::COCODataReader reader_;\n};\n\nclass COCOReaderKernel final : public user_op::OpKernel {\n public:\n  COCOReaderKernel() = default;\n  ~COCOReaderKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    std::shared_ptr<user_op::OpKernelState> reader(new COCOReaderWrapper(ctx));\n    return reader;\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* reader = dynamic_cast<COCOReaderWrapper*>(state);\n    reader->Read(ctx);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\nREGISTER_USER_KERNEL(\"COCOReader\")\n    .SetCreateFn<COCOReaderKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"image\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"image_id\", 0) == DataType::kInt64)\n                     && (user_op::HobDataType(\"image_size\", 0) == DataType::kInt32)\n                     && (user_op::HobDataType(\"gt_bbox\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"gt_label\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"gt_segm\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"gt_segm_index\", 0) == DataType::kTensorBuffer));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cpu/cpu_all_gather.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/user/kernels/collective_communication/cpu/cpu_communication_context.h\"\n#include \"oneflow/user/kernels/collective_communication/include/all_gather.h\"\n#include \"oneflow/user/kernels/collective_communication/cpu/cpu_collective_communication_util.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nnamespace {\n\nMaybe<void> AllGatherImpl(const void* in, void* out, size_t elem_cnt, DataType dtype,\n                          Symbol<ParallelDesc> parallel_desc) {\n  int64_t parallel_num = parallel_desc->parallel_num();\n  if (parallel_num == 1) {\n    if (in != out) { std::memcpy(out, in, elem_cnt * GetSizeOfDataType(dtype)); }\n    return Maybe<void>::Ok();\n  }\n  char* char_out = reinterpret_cast<char*>(out);\n  size_t chunk_size = elem_cnt * GetSizeOfDataType(dtype);\n  BalancedSplitter bs(chunk_size * parallel_num, parallel_num);\n  const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc));\n  CHECK_OR_RETURN(opt_parallel_id->has_value()) << kOfBugIssueUploadPrompt;\n  const auto& rank_group = JUST(RankGroup::New(parallel_desc));\n  TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));\n  int64_t parallel_id = JUST(*opt_parallel_id);\n  // In-place operation will happen if in == out + parallel_id * chunk_size\n  if (in != &char_out[parallel_id * chunk_size]) {\n    memcpy(&char_out[parallel_id * chunk_size], in, chunk_size);\n  }\n  for (int64_t i = 0, part_id = parallel_id; i < parallel_num - 1;\n       ++i, part_id = RingDecrease(part_id, parallel_num)) {\n    int64_t send_part_id = part_id;\n    const void* send_ptr = &char_out[bs.At(send_part_id).begin()];\n    size_t send_size = bs.At(send_part_id).size();\n    int64_t recv_part_id = RingDecrease(part_id, parallel_num);\n    void* recv_ptr = &char_out[bs.At(recv_part_id).begin()];\n    size_t recv_size = bs.At(recv_part_id).size();\n    NaiveAsyncTransportCtx ctx(\n        transport_token,\n        [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n          *buffer = const_cast<void*>(send_ptr);\n          *size = send_size;\n          *Cb = [] {};\n          return Maybe<void>::Ok();\n        },\n        [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n          *buffer = recv_ptr;\n          *size = recv_size;\n          *Cb = [] {};\n          return Maybe<void>::Ok();\n        });\n    if (send_size > 0) {\n      JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));\n    }\n    if (recv_size > 0) {\n      JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));\n    }\n    JUST(ctx.WaitDone());\n  }\n  return Maybe<void>::Ok();\n}\n}  // namespace\n\nclass CpuAllGather final 
: public AllGather {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuAllGather);\n  CpuAllGather() : datatype_(kInvalidDataType) {}\n  ~CpuAllGather() = default;\n\n  void Init(DataType datatype) override { this->datatype_ = datatype; }\n\n  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n              const std::shared_ptr<CommunicationContext>& communication_ctx) const override {\n    const auto& cpu_communication_ctx =\n        std::dynamic_pointer_cast<CpuCommunicationContext>(communication_ctx);\n    CHECK(cpu_communication_ctx);\n    CHECK_JUST(AllGatherImpl(in, out, elem_cnt, datatype_, cpu_communication_ctx->parallel_desc()));\n  }\n\n  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n              const ccl::CclComm& ccl_comm) const override {\n    UNIMPLEMENTED();\n  }\n\n private:\n  DataType datatype_;\n};\n\nREGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, AllGather, CpuAllGather);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cpu/cpu_all_reduce.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/user/kernels/collective_communication/cpu/cpu_communication_context.h\"\n#include \"oneflow/user/kernels/collective_communication/include/all_reduce.h\"\n#include \"oneflow/user/kernels/collective_communication/cpu/cpu_collective_communication_util.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nnamespace {\n\ntemplate<typename T, ReduceType reduce_type>\nstruct AllReduceImpl final {\n  static Maybe<void> Call(const void* void_in, void* void_out, size_t elem_cnt,\n                          Symbol<ParallelDesc> parallel_desc) {\n    int64_t parallel_num = parallel_desc->parallel_num();\n    if (parallel_num == 1) {\n      if (void_in != void_out) { std::memcpy(void_out, void_in, elem_cnt * sizeof(T)); }\n      return Maybe<void>::Ok();\n    }\n    const T* in = reinterpret_cast<const T*>(void_in);\n    T* out = reinterpret_cast<T*>(void_out);\n    BalancedSplitter bs(elem_cnt, parallel_num);\n    auto recv_buffer = std::make_unique<T[]>(bs.At(0).size());\n    Optional<int64_t> parallel_id;\n    JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, &parallel_id));\n    const auto& rank_group = JUST(RankGroup::New(parallel_desc));\n    TransportToken transport_token =\n        JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));\n    for (int64_t i = 0, part_id = JUST(parallel_id); i < parallel_num - 1;\n         ++i, part_id = RingDecrease(part_id, parallel_num)) {\n      int64_t send_part_id = part_id;\n      const T* send_ptr = nullptr;\n      if (i == 0) {\n        send_ptr = &in[bs.At(send_part_id).begin()];\n      } else {\n        send_ptr = &out[bs.At(send_part_id).begin()];\n      }\n      size_t send_size = bs.At(send_part_id).size();\n      int64_t recv_part_id = RingDecrease(part_id, parallel_num);\n      T* recv_ptr = recv_buffer.get();\n      size_t recv_size = bs.At(recv_part_id).size();\n      NaiveAsyncTransportCtx ctx(\n          transport_token,\n          [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n            *buffer = const_cast<T*>(send_ptr);\n            *size = send_size * sizeof(T);\n            *Cb = [] {};\n            return Maybe<void>::Ok();\n          },\n          [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n            *buffer = recv_ptr;\n            *size = recv_size * sizeof(T);\n            *Cb = [] {};\n            return Maybe<void>::Ok();\n          });\n      if (send_size > 0) {\n        JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));\n      }\n      if (recv_size > 0) {\n        JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));\n      }\n      JUST(ctx.WaitDone());\n      const T* cur_in = &in[bs.At(recv_part_id).begin()];\n    
  T* cur_out = &out[bs.At(recv_part_id).begin()];\n      if (recv_size > 0) {\n        ReduceFunctor<T, reduce_type>::Call(recv_size, cur_out, cur_in, recv_ptr);\n      }\n    }\n    for (int64_t i = 0, part_id = RingIncrease(JUST(parallel_id), parallel_num);\n         i < parallel_num - 1; ++i, part_id = RingDecrease(part_id, parallel_num)) {\n      int64_t send_part_id = part_id;\n      const T* send_ptr = &out[bs.At(send_part_id).begin()];\n      size_t send_size = bs.At(send_part_id).size();\n      int64_t recv_part_id = RingDecrease(part_id, parallel_num);\n      T* recv_ptr = &out[bs.At(recv_part_id).begin()];\n      size_t recv_size = bs.At(recv_part_id).size();\n      NaiveAsyncTransportCtx ctx(\n          transport_token,\n          [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n            *buffer = const_cast<T*>(send_ptr);\n            *size = send_size * sizeof(T);\n            *Cb = [] {};\n            return Maybe<void>::Ok();\n          },\n          [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n            *buffer = recv_ptr;\n            *size = recv_size * sizeof(T);\n            *Cb = [] {};\n            return Maybe<void>::Ok();\n          });\n      if (send_size > 0) {\n        JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));\n      }\n      if (recv_size > 0) {\n        JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));\n      }\n      JUST(ctx.WaitDone());\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\n#define MAKE_ALL_REDUCE_ENTRY(func_name, T, reduce_type) func_name<T, reduce_type>::Call\n\nDEFINE_STATIC_SWITCH_FUNC(Maybe<void>, AllReduceImpl, MAKE_ALL_REDUCE_ENTRY,  // NOLINT\n                          MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ),         // NOLINT\n                          REDUCE_TYPE_CTRV_SEQ);                              // NOLINT\n\n#undef MAKE_ALL_REDUCE_ENTRY\n\n}  // namespace\n\nclass CpuAllReduce final : public AllReduce {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuAllReduce);\n  CpuAllReduce() : datatype_(kInvalidDataType), reduce_type_(kInvalidReduceFunctorType) {}\n  ~CpuAllReduce() = default;\n\n  void Init(DataType datatype, ReduceType reduce_type) override {\n    this->datatype_ = datatype;\n    this->reduce_type_ = reduce_type;\n  }\n\n  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n              const std::shared_ptr<CommunicationContext>& communication_ctx) const override {\n    const auto& cpu_communication_ctx =\n        std::dynamic_pointer_cast<CpuCommunicationContext>(communication_ctx);\n    CHECK(cpu_communication_ctx) << kOfBugIssueUploadPrompt;\n    CHECK_JUST(SwitchAllReduceImpl(SwitchCase(datatype_, reduce_type_), in, out, elem_cnt,\n                                   cpu_communication_ctx->parallel_desc()));\n  }\n\n  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n              const ccl::CclComm& ccl_comm) const override {\n    UNIMPLEMENTED();\n  }\n\n private:\n  DataType datatype_;\n  ReduceType reduce_type_;\n};\n\nREGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, AllReduce, CpuAllReduce);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cpu/cpu_broadcast.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/ccl/ccl.h\"\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/user/kernels/collective_communication/cpu/cpu_communication_context.h\"\n#include \"oneflow/user/kernels/collective_communication/include/broadcast.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\n// Use CpuBroadcastImpl to avoid name conflict\nclass CpuBroadcastImpl final : public Broadcast {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuBroadcastImpl);\n  CpuBroadcastImpl() : size_of_dtype_(0) {}\n  ~CpuBroadcastImpl() = default;\n\n  void Init(DataType datatype) override {\n    CHECK(IsTriviallyCopyableDataType(datatype));\n    this->size_of_dtype_ = GetSizeOfDataType(datatype);\n  }\n\n  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, int64_t root,\n              const std::shared_ptr<CommunicationContext>& communication_ctx) const override {\n    const auto& cpu_communication_ctx =\n        std::dynamic_pointer_cast<CpuCommunicationContext>(communication_ctx);\n    CHECK(cpu_communication_ctx);\n    size_t buffer_size = elem_cnt * size_of_dtype_;\n    const auto& transport_token =\n        CHECK_JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));\n    CHECK_JUST(CpuBroadcast(in, out, buffer_size, root, cpu_communication_ctx->parallel_desc(),\n                            transport_token));\n  }\n\n private:\n  size_t size_of_dtype_;\n};\n\nREGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, Broadcast, CpuBroadcastImpl);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cpu/cpu_collective_communication_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CPU_CPU_COLLECTIVE_COMMUNICATION_UTIL_H_\n#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CPU_CPU_COLLECTIVE_COMMUNICATION_UTIL_H_\n\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\ninline int64_t RingDecrease(int64_t n, int64_t size) { return (n - 1 + size) % size; }\n\ninline int64_t RingIncrease(int64_t n, int64_t size) { return (n + 1 + size) % size; }\n\ntemplate<typename T, ReduceType reduce_type>\nstruct ReduceFunctor;\n\ntemplate<typename T>\nstruct ReduceFunctor<T, kSum> {\n  static void Call(size_t size, T* out, const T* in0, const T* in1) {\n    size_t thread_num = Singleton<ThreadPool>::Get()->thread_num();\n    BalancedSplitter bs(size, thread_num);\n    MultiThreadLoop(thread_num, [&](size_t thread_idx) {\n      size_t end = bs.At(thread_idx).end();\n      for (size_t i = bs.At(thread_idx).begin(); i < end; ++i) { out[i] = in0[i] + in1[i]; }\n    });\n  }\n};\n\ntemplate<typename T>\nstruct ReduceFunctor<T, kMax> {\n  static void Call(size_t size, T* out, const T* in0, const T* in1) {\n    size_t thread_num = Singleton<ThreadPool>::Get()->thread_num();\n    BalancedSplitter bs(size, thread_num);\n    MultiThreadLoop(thread_num, [&](size_t thread_idx) {\n      size_t end = bs.At(thread_idx).end();\n      for (size_t i = bs.At(thread_idx).begin(); i < end; ++i) {\n        out[i] = std::max(in0[i], in1[i]);\n      }\n    });\n  }\n};\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CPU_CPU_COLLECTIVE_COMMUNICATION_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cpu/cpu_communication_context.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/collective_communication/cpu/cpu_communication_context.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nvoid CpuCommunicationContext::Init(Symbol<ParallelDesc> parallel_desc) {\n  parallel_desc_ = parallel_desc;\n}\n\nREGISTER_COLLECTIVE_COMMUNICATION_COMMUNICATOR(DeviceType::kCPU, CpuCommunicationContext);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cpu/cpu_communication_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CPU_CPU_COMMUNICATION_CONTEXT_H_\n#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CPU_CPU_COMMUNICATION_CONTEXT_H_\n\n#include \"oneflow/user/kernels/collective_communication/include/communication_context.h\"\n#include \"oneflow/core/common/symbol.h\"\n\nnamespace oneflow {\n\nclass ParallelDesc;\n\nnamespace ccl {\n\nclass CpuCommunicationContext : public CommunicationContext {\n public:\n  explicit CpuCommunicationContext() = default;\n  ~CpuCommunicationContext() override = default;\n\n  void Init(Symbol<ParallelDesc>) override;\n\n  Symbol<ParallelDesc> parallel_desc() const { return parallel_desc_; }\n\n private:\n  Symbol<ParallelDesc> parallel_desc_;\n};\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CPU_CPU_COMMUNICATION_CONTEXT_H_\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/ccl/ccl.h\"\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/user/kernels/collective_communication/include/recv.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\n// Use CpuRecvImpl to avoid name conflict\nclass CpuRecvImpl final : public Recv {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuRecvImpl);\n  CpuRecvImpl() : size_of_dtype_(0) {}\n  ~CpuRecvImpl() = default;\n\n  void Init(DataType datatype) override {\n    CHECK(IsTriviallyCopyableDataType(datatype));\n    this->size_of_dtype_ = GetSizeOfDataType(datatype);\n  }\n\n  void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const override {\n    size_t buffer_size = elem_cnt * size_of_dtype_;\n    CHECK_JUST(CpuRecv(out, buffer_size, src));\n  }\n\n  void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src,\n              const ccl::CclComm& ccl_comm) const override {\n    Launch(stream, out, elem_cnt, src);\n  }\n\n private:\n  size_t size_of_dtype_;\n};\n\nREGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, Recv, CpuRecvImpl);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cpu/cpu_reduce.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/user/kernels/collective_communication/cpu/cpu_communication_context.h\"\n#include \"oneflow/user/kernels/collective_communication/include/reduce.h\"\n#include \"oneflow/user/kernels/collective_communication/cpu/cpu_collective_communication_util.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nnamespace {\n\ntemplate<typename T, ReduceType reduce_type>\nstruct ReduceImpl final {\n  static Maybe<void> Call(const void* void_in, void* void_out, size_t elem_cnt, int64_t root,\n                          Symbol<ParallelDesc> parallel_desc) {\n    const T* in = reinterpret_cast<const T*>(void_in);\n    T* out = reinterpret_cast<T*>(void_out);\n\n    int64_t parallel_num = parallel_desc->parallel_num();\n    BalancedSplitter bs(elem_cnt, parallel_num);\n\n    size_t size = root == GlobalProcessCtx::Rank() && void_in != void_out ? 0 : bs.At(0).size();\n    T* tmp_out = nullptr;\n    // void_out is only used on rank root and ignored for other ranks.\n    auto tmp_out_buffer = std::make_unique<T[]>(size);\n    int64_t parallel_id_of_root =\n        JUST(parallel_desc->ParallelId4MachineDeviceId(root, GlobalProcessCtx::LocalRank(root)));\n    if (root == GlobalProcessCtx::Rank() && void_in != void_out) {\n      tmp_out = &reinterpret_cast<T*>(void_out)[bs.At(parallel_id_of_root).begin()];\n    } else {\n      tmp_out = tmp_out_buffer.get();\n    }\n\n    auto recv_buffer = std::make_unique<T[]>(bs.At(0).size());\n    Optional<int64_t> parallel_id;\n    JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, &parallel_id));\n    const auto& rank_group = JUST(RankGroup::New(parallel_desc));\n    TransportToken transport_token =\n        JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));\n    for (int64_t i = 0, part_id = RingDecrease(JUST(parallel_id), parallel_num);\n         i < parallel_num - 1; ++i, part_id = RingDecrease(part_id, parallel_num)) {\n      int64_t send_part_id = part_id;\n      const T* send_ptr = nullptr;\n      if (i == 0) {\n        send_ptr = &in[bs.At(send_part_id).begin()];\n      } else {\n        send_ptr = tmp_out;\n      }\n      size_t send_size = bs.At(send_part_id).size();\n      int64_t recv_part_id = RingDecrease(part_id, parallel_num);\n      T* recv_ptr = recv_buffer.get();\n      size_t recv_size = bs.At(recv_part_id).size();\n      NaiveAsyncTransportCtx ctx(\n          transport_token,\n          [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n            *buffer = const_cast<T*>(send_ptr);\n            *size = send_size * sizeof(T);\n            *Cb = [] {};\n            return Maybe<void>::Ok();\n          },\n          [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> 
Maybe<void> {\n            *buffer = recv_ptr;\n            *size = recv_size * sizeof(T);\n            *Cb = [] {};\n            return Maybe<void>::Ok();\n          });\n      if (send_size > 0) {\n        JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));\n      }\n      if (recv_size > 0) {\n        JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));\n      }\n      JUST(ctx.WaitDone());\n      const T* cur_in = &in[bs.At(recv_part_id).begin()];\n      if (recv_size > 0) {\n        ReduceFunctor<T, reduce_type>::Call(recv_size, tmp_out, cur_in, recv_ptr);\n      }\n    }\n\n    if (root == GlobalProcessCtx::Rank() && void_in == void_out) {\n      memcpy(&out[bs.At(parallel_id_of_root).begin()], tmp_out,\n             bs.At(parallel_id_of_root).size() * sizeof(T));\n    }\n\n    for (int64_t i = 0, part_id = RingIncrease(parallel_id_of_root, parallel_num);\n         i < parallel_num - 1; ++i, part_id = RingIncrease(part_id, parallel_num)) {\n      int64_t send_part_id = part_id;\n      int64_t src_rank = JUST(parallel_desc->MachineId4ParallelId(send_part_id));\n      const T* send_ptr = tmp_out;\n      size_t send_size = bs.At(send_part_id).size();\n      int64_t recv_part_id = part_id;\n      T* recv_ptr = &out[bs.At(recv_part_id).begin()];\n      size_t recv_size = bs.At(recv_part_id).size();\n\n      if (send_size > 0 && src_rank == GlobalProcessCtx::Rank()) {\n        NaiveAsyncTransportCtx ctx(\n            transport_token,\n            [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n              *buffer = const_cast<T*>(send_ptr);\n              *size = send_size * sizeof(T);\n              *Cb = [] {};\n              return Maybe<void>::Ok();\n            },\n            [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n              UNIMPLEMENTED_THEN_RETURN();\n            });\n        JUST(TransportUtil::SendDataToRank(root, transport_token, &ctx));\n        JUST(ctx.WaitDone());\n      }\n      if (recv_size > 0 && root == GlobalProcessCtx::Rank()) {\n        NaiveAsyncTransportCtx ctx(\n            transport_token,\n            [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n              UNIMPLEMENTED_THEN_RETURN();\n            },\n            [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n              *buffer = recv_ptr;\n              *size = recv_size * sizeof(T);\n              *Cb = [] {};\n              return Maybe<void>::Ok();\n            });\n        JUST(TransportUtil::ReceiveDataFromRank(src_rank, transport_token, &ctx));\n        JUST(ctx.WaitDone());\n      }\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\n#define MAKE_ALL_REDUCE_ENTRY(func_name, T, reduce_type) func_name<T, reduce_type>::Call\n\nDEFINE_STATIC_SWITCH_FUNC(Maybe<void>, ReduceImpl, MAKE_ALL_REDUCE_ENTRY,  // NOLINT\n                          MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ),      // NOLINT\n                          REDUCE_TYPE_CTRV_SEQ);                           // NOLINT\n\n#undef MAKE_ALL_REDUCE_ENTRY\n\n}  // namespace\n\nclass CpuReduce final : public Reduce {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuReduce);\n  CpuReduce() : datatype_(kInvalidDataType), reduce_type_(kInvalidReduceFunctorType) {}\n  ~CpuReduce() = default;\n\n  void Init(DataType datatype, ReduceType reduce_type) override {\n    this->datatype_ = datatype;\n    this->reduce_type_ = reduce_type;\n  }\n\n  void 
Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, int64_t root,\n              const std::shared_ptr<CommunicationContext>& communication_ctx) const override {\n    const auto& cpu_communication_ctx =\n        std::dynamic_pointer_cast<CpuCommunicationContext>(communication_ctx);\n    CHECK(cpu_communication_ctx) << kOfBugIssueUploadPrompt;\n    CHECK_JUST(SwitchReduceImpl(SwitchCase(datatype_, reduce_type_), in, out, elem_cnt, root,\n                                cpu_communication_ctx->parallel_desc()));\n  }\n\n private:\n  DataType datatype_;\n  ReduceType reduce_type_;\n};\n\nREGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, Reduce, CpuReduce);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cpu/cpu_reduce_scatter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/user/kernels/collective_communication/cpu/cpu_communication_context.h\"\n#include \"oneflow/user/kernels/collective_communication/include/reduce_scatter.h\"\n#include \"oneflow/user/kernels/collective_communication/cpu/cpu_collective_communication_util.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nnamespace {\n\ntemplate<typename T, ReduceType reduce_type>\nstruct ReduceScatterImpl final {\n  static Maybe<void> Call(const void* void_in, void* void_out, size_t elem_cnt,\n                          Symbol<ParallelDesc> parallel_desc) {\n    int64_t parallel_num = parallel_desc->parallel_num();\n    if (parallel_num == 1) {\n      if (void_in != void_out) { std::memcpy(void_out, void_in, elem_cnt * sizeof(T)); }\n      return Maybe<void>::Ok();\n    }\n\n    const T* in = reinterpret_cast<const T*>(void_in);\n    T* out = reinterpret_cast<T*>(void_out);\n\n    BalancedSplitter bs(elem_cnt * parallel_num, parallel_num);\n    const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc));\n    CHECK_OR_RETURN(opt_parallel_id->has_value()) << kOfBugIssueUploadPrompt;\n    int64_t parallel_id = JUST(*opt_parallel_id);\n\n    auto recv_buffer = std::make_unique<T[]>(bs.At(0).size());\n    const auto& rank_group = JUST(RankGroup::New(parallel_desc));\n\n    TransportToken transport_token =\n        JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));\n    for (int64_t i = 0, part_id = RingDecrease(parallel_id, parallel_num); i < parallel_num - 1;\n         ++i, part_id = RingDecrease(part_id, parallel_num)) {\n      int64_t send_part_id = part_id;\n      const T* send_ptr = nullptr;\n      if (i == 0) {\n        send_ptr = &in[bs.At(send_part_id).begin()];\n      } else {\n        send_ptr = out;\n      }\n      size_t send_size = bs.At(send_part_id).size();\n      int64_t recv_part_id = RingDecrease(part_id, parallel_num);\n      T* recv_ptr = recv_buffer.get();\n      size_t recv_size = bs.At(recv_part_id).size();\n      NaiveAsyncTransportCtx ctx(\n          transport_token,\n          [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n            *buffer = const_cast<T*>(send_ptr);\n            *size = send_size * sizeof(T);\n            *Cb = [] {};\n            return Maybe<void>::Ok();\n          },\n          [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {\n            *buffer = recv_ptr;\n            *size = recv_size * sizeof(T);\n            *Cb = [] {};\n            return Maybe<void>::Ok();\n          });\n      if (send_size > 0) {\n        JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));\n      }\n      if (recv_size > 0) {\n        
JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));\n      }\n      JUST(ctx.WaitDone());\n      const T* cur_in = &in[bs.At(recv_part_id).begin()];\n      if (recv_size > 0) { ReduceFunctor<T, reduce_type>::Call(recv_size, out, cur_in, recv_ptr); }\n    }\n    return Maybe<void>::Ok();\n  }\n};\n\n#define MAKE_ALL_REDUCE_ENTRY(func_name, T, reduce_type) func_name<T, reduce_type>::Call\n\nDEFINE_STATIC_SWITCH_FUNC(Maybe<void>, ReduceScatterImpl, MAKE_ALL_REDUCE_ENTRY,  // NOLINT\n                          MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ),             // NOLINT\n                          REDUCE_TYPE_CTRV_SEQ);                                  // NOLINT\n\n#undef MAKE_ALL_REDUCE_ENTRY\n\n}  // namespace\n\nclass CpuReduceScatter final : public ReduceScatter {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuReduceScatter);\n  CpuReduceScatter() : datatype_(kInvalidDataType), reduce_type_(kInvalidReduceFunctorType) {}\n  ~CpuReduceScatter() = default;\n\n  void Init(DataType datatype, ReduceType reduce_type) override {\n    this->datatype_ = datatype;\n    this->reduce_type_ = reduce_type;\n  }\n\n  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n              const std::shared_ptr<CommunicationContext>& communication_ctx) const override {\n    const auto& cpu_communication_ctx =\n        std::dynamic_pointer_cast<CpuCommunicationContext>(communication_ctx);\n    CHECK(cpu_communication_ctx) << kOfBugIssueUploadPrompt;\n    CHECK_JUST(SwitchReduceScatterImpl(SwitchCase(datatype_, reduce_type_), in, out, elem_cnt,\n                                       cpu_communication_ctx->parallel_desc()));\n  }\n\n  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n              const ccl::CclComm& ccl_comm) const override {\n    UNIMPLEMENTED();\n  }\n\n private:\n  DataType datatype_;\n  ReduceType reduce_type_;\n};\n\nREGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, ReduceScatter, CpuReduceScatter);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/ccl/ccl.h\"\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/framework/transport_util.h\"\n#include \"oneflow/user/kernels/collective_communication/include/send.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\n// Use CpuSendImpl to avoid name conflict\nclass CpuSendImpl final : public Send {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CpuSendImpl);\n  CpuSendImpl() : size_of_dtype_(0) {}\n  ~CpuSendImpl() = default;\n\n  void Init(DataType datatype) override {\n    CHECK(IsTriviallyCopyableDataType(datatype));\n    this->size_of_dtype_ = GetSizeOfDataType(datatype);\n  }\n\n  void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const override {\n    size_t buffer_size = elem_cnt * size_of_dtype_;\n    CHECK_JUST(CpuSend(in, buffer_size, dst));\n  }\n\n  void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst,\n              const ccl::CclComm& comm) const override {\n    Launch(stream, in, elem_cnt, dst);\n  }\n\n private:\n  size_t size_of_dtype_;\n};\n\nREGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCPU, Send, CpuSendImpl);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cuda/cuda_all_gather.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/user/kernels/collective_communication/include/all_gather.h\"\n#include \"oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nclass CudaAllGather final : public AllGather {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaAllGather);\n  CudaAllGather() : nccl_datatype_() {}\n  ~CudaAllGather() = default;\n\n  void Init(DataType datatype) override { this->nccl_datatype_ = GetNcclDataType(datatype); }\n\n  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n              const std::shared_ptr<CommunicationContext>& communication_ctx) const override {\n    const auto& cuda_communication_ctx =\n        std::dynamic_pointer_cast<CudaCommunicationContext>(communication_ctx);\n    CHECK(cuda_communication_ctx) << kOfBugIssueUploadPrompt;\n    OF_NCCL_CHECK(ncclAllGather(in, out, elem_cnt, nccl_datatype_,\n                                cuda_communication_ctx->nccl_comm(),\n                                stream->As<ep::CudaStream>()->cuda_stream()));\n  }\n\n  virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n                      const ccl::CclComm& ccl_comm) const override {\n    ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());\n    OF_NCCL_CHECK(ncclAllGather(in, out, elem_cnt, nccl_datatype_, *nccl_comm,\n                                stream->As<ep::CudaStream>()->cuda_stream()));\n  }\n\n private:\n  ncclDataType_t nccl_datatype_;\n};\n\nREGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, AllGather, CudaAllGather);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cuda/cuda_all_reduce.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/user/kernels/collective_communication/include/all_reduce.h\"\n#include \"oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nnamespace {\n\ninline ncclRedOp_t GetNcclReduceType(ReduceType reduce_type) {\n  switch (reduce_type) {\n#define NCCL_REDUCE_TYPE_CASE(dtype) \\\n  case ReduceType::k##dtype: return ncclRedOp_t::nccl##dtype\n    NCCL_REDUCE_TYPE_CASE(Sum);\n    NCCL_REDUCE_TYPE_CASE(Max);\n    default: PRINT_BUG_PROMPT_AND_ABORT();\n  }\n}\n\n}  // namespace\n\nclass CudaAllReduce final : public AllReduce {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaAllReduce);\n  CudaAllReduce() : nccl_datatype_(), nccl_reduce_op_() {}\n  ~CudaAllReduce() = default;\n\n  void Init(DataType datatype, ReduceType reduce_type) override {\n    this->nccl_datatype_ = GetNcclDataType(datatype);\n    this->nccl_reduce_op_ = GetNcclReduceType(reduce_type);\n  }\n\n  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n              const std::shared_ptr<CommunicationContext>& communication_ctx) const override {\n    const auto& cuda_communication_ctx =\n        std::dynamic_pointer_cast<CudaCommunicationContext>(communication_ctx);\n    CHECK(cuda_communication_ctx);\n    OF_NCCL_CHECK(ncclAllReduce(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_,\n                                cuda_communication_ctx->nccl_comm(),\n                                stream->As<ep::CudaStream>()->cuda_stream()));\n  }\n\n  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n              const ccl::CclComm& ccl_comm) const override {\n    ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());\n    OF_NCCL_CHECK(ncclAllReduce(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, *nccl_comm,\n                                stream->As<ep::CudaStream>()->cuda_stream()));\n  }\n\n private:\n  ncclDataType_t nccl_datatype_;\n  ncclRedOp_t nccl_reduce_op_;\n};\n\nREGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, AllReduce, CudaAllReduce);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/user/kernels/collective_communication/include/all_to_all.h\"\n#include \"oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/common/device_type.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nclass CudaAllToAll final : public AllToAll {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaAllToAll);\n  CudaAllToAll()\n      : send_dtype_(), recv_dtype_(), nccl_send_dtype_(), nccl_recv_dtype_(), rank_count_(0) {}\n  ~CudaAllToAll() = default;\n\n  void Init(DataType send_dtype, DataType recv_dtype, size_t parallel_num) override {\n    this->send_dtype_ = send_dtype;\n    this->recv_dtype_ = recv_dtype;\n    this->nccl_send_dtype_ = GetNcclDataType(send_dtype);\n    this->nccl_recv_dtype_ = GetNcclDataType(recv_dtype);\n    this->rank_count_ = parallel_num;\n  }\n\n  void Launch(ep::Stream* stream, void* send, int64_t send_count, void* recv, int64_t recv_count,\n              const ccl::CclComm& ccl_comm) const override {\n    ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());\n    int64_t send_offset = 0;\n    int64_t recv_offset = 0;\n    OF_NCCL_CHECK(ncclGroupStart());\n    for (int64_t i = 0; i < this->rank_count_; ++i) {\n      if (send_count > 0) {\n        char* send_ptr = static_cast<char*>(send) + send_offset;\n        OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm,\n                               stream->As<ep::CudaStream>()->cuda_stream()));\n      }\n      send_offset += send_count * GetSizeOfDataType(this->send_dtype_);\n      if (recv_count) {\n        char* recv_ptr = static_cast<char*>(recv) + recv_offset;\n        OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm,\n                               stream->As<ep::CudaStream>()->cuda_stream()));\n      }\n      recv_offset += recv_count * GetSizeOfDataType(this->recv_dtype_);\n    }\n    OF_NCCL_CHECK(ncclGroupEnd());\n  }\n\n  void Launch(ep::Stream* stream, void* send, const void* send_counts, const void* send_offsets,\n              void* recv, const void* recv_counts, const void* recv_offsets,\n              const ccl::CclComm& ccl_comm, const bool has_input,\n              const bool has_output) const override {\n    ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());\n    int64_t* send_counts_ptr = static_cast<int64_t*>(const_cast<void*>(send_counts));\n    int64_t* recv_counts_ptr = static_cast<int64_t*>(const_cast<void*>(recv_counts));\n    int64_t* send_offsets_ptr = static_cast<int64_t*>(const_cast<void*>(send_offsets));\n    int64_t* recv_offsets_ptr = static_cast<int64_t*>(const_cast<void*>(recv_offsets));\n    if (has_input || has_output) {\n      OF_NCCL_CHECK(ncclGroupStart());\n      for (int64_t i = 0; i < this->rank_count_; ++i) {\n        if (has_input) {\n  
        const uint64_t send_count = static_cast<uint64_t>(send_counts_ptr[i]);\n          if (send_count > 0) {\n            uint64_t send_offset = static_cast<uint64_t>(send_offsets_ptr[i]);\n            char* send_ptr = static_cast<char*>(send) + send_offset;\n            OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm,\n                                   stream->As<ep::CudaStream>()->cuda_stream()));\n          }\n        }\n        if (has_output) {\n          const uint64_t recv_count = static_cast<uint64_t>(recv_counts_ptr[i]);\n          if (recv_count > 0) {\n            uint64_t recv_offset = static_cast<uint64_t>(recv_offsets_ptr[i]);\n            char* recv_ptr = static_cast<char*>(recv) + recv_offset;\n            OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm,\n                                   stream->As<ep::CudaStream>()->cuda_stream()));\n          }\n        }\n      }\n      OF_NCCL_CHECK(ncclGroupEnd());\n    }\n  }\n\n private:\n  DataType send_dtype_;\n  DataType recv_dtype_;\n  ncclDataType_t nccl_send_dtype_;\n  ncclDataType_t nccl_recv_dtype_;\n  size_t rank_count_;\n};\n\nREGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, AllToAll, CudaAllToAll);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cuda/cuda_broadcast.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/user/kernels/collective_communication/include/broadcast.h\"\n#include \"oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nclass CudaBroadcast final : public Broadcast {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaBroadcast);\n  CudaBroadcast() : nccl_datatype_() {}\n  ~CudaBroadcast() = default;\n\n  void Init(DataType datatype) override { this->nccl_datatype_ = GetNcclDataType(datatype); }\n\n  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, int64_t root,\n              const std::shared_ptr<CommunicationContext>& communication_ctx) const override {\n    const auto& cuda_communication_ctx =\n        std::dynamic_pointer_cast<CudaCommunicationContext>(communication_ctx);\n    CHECK(cuda_communication_ctx);\n    OF_NCCL_CHECK(ncclBroadcast(\n        in, out, elem_cnt, nccl_datatype_, cuda_communication_ctx->nccl_index4rank(root),\n        cuda_communication_ctx->nccl_comm(), stream->As<ep::CudaStream>()->cuda_stream()));\n  }\n\n private:\n  ncclDataType_t nccl_datatype_;\n};\n\nREGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, Broadcast, CudaBroadcast);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h\"\n#include \"oneflow/core/job/eager_nccl_comm_manager.h\"\n\n#ifdef WITH_CUDA\n\nnamespace oneflow {\n\nnamespace ccl {\n\nvoid CudaCommunicationContext::Init(Symbol<ParallelDesc> parallel_desc) {\n  std::set<std::pair<int64_t, int64_t>> device_set;\n  FOR_RANGE(int64_t, parallel_id, 0, parallel_desc->parallel_num()) {\n    int64_t machine_id = CHECK_JUST(parallel_desc->MachineId4ParallelId(parallel_id));\n    int64_t device_id = CHECK_JUST(parallel_desc->DeviceId4ParallelId(parallel_id));\n    device_set.emplace(std::make_pair(machine_id, device_id));\n    rank2nccl_index_.emplace(machine_id, parallel_id);\n  }\n  nccl_comm_ = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get())\n                   ->As<EagerNcclCommMgr>()\n                   ->GetCommForDevice(device_set);\n}\n\nREGISTER_COLLECTIVE_COMMUNICATION_COMMUNICATOR(DeviceType::kCUDA, CudaCommunicationContext);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CUDA_CUDA_COMMUNICATION_CONTEXT_H_\n#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CUDA_CUDA_COMMUNICATION_CONTEXT_H_\n\n#include \"oneflow/user/kernels/collective_communication/include/communication_context.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/device/cuda_util.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nclass CudaCommunicationContext : public CommunicationContext {\n public:\n  explicit CudaCommunicationContext() = default;\n  ~CudaCommunicationContext() override = default;\n\n  void Init(Symbol<ParallelDesc>) override;\n\n  ncclComm_t nccl_comm() const { return nccl_comm_; }\n  int64_t nccl_index4rank(int rank) const { return rank2nccl_index_.at(rank); }\n\n private:\n  ncclComm_t nccl_comm_;\n  HashMap<int64_t, int64_t> rank2nccl_index_;\n};\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CUDA_CUDA_COMMUNICATION_CONTEXT_H_\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/user/kernels/collective_communication/include/recv.h\"\n#include \"oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h\"\n#include \"oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nclass CudaRecv final : public Recv {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaRecv);\n  CudaRecv() : nccl_datatype_() {}\n  ~CudaRecv() = default;\n\n  void Init(DataType datatype) override { this->nccl_datatype_ = GetNcclDataType(datatype); }\n\n  void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const override {\n#if HAS_NCCL_SEND_RECV\n    const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(src);\n    OF_NCCL_CHECK(ncclRecv(out, elem_cnt, nccl_datatype_, comm_and_peer_rank.second,\n                           comm_and_peer_rank.first, stream->As<ep::CudaStream>()->cuda_stream()));\n#else\n    UNIMPLEMENTED() << \"GPU recv is only supported when nccl version >= 2.7\"\n#endif  // HAS_NCCL_SEND_RECV\n  }\n\n  void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src,\n              const ccl::CclComm& ccl_comm) const override {\n#if HAS_NCCL_SEND_RECV\n    ncclComm_t* comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());\n    OF_NCCL_CHECK(ncclRecv(out, elem_cnt, nccl_datatype_, src, *comm,\n                           stream->As<ep::CudaStream>()->cuda_stream()));\n#else\n    UNIMPLEMENTED() << \"GPU recv is only supported when nccl version >= 2.7\"\n#endif  // HAS_NCCL_SEND_RECV\n  }\n\n private:\n  ncclDataType_t nccl_datatype_;\n};\n\nREGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, Recv, CudaRecv);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cuda/cuda_reduce.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/user/kernels/collective_communication/include/reduce.h\"\n#include \"oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nnamespace {\n\ninline ncclRedOp_t GetNcclReduceType(ReduceType reduce_type) {\n  switch (reduce_type) {\n#define NCCL_REDUCE_TYPE_CASE(dtype) \\\n  case ReduceType::k##dtype: return ncclRedOp_t::nccl##dtype\n    NCCL_REDUCE_TYPE_CASE(Sum);\n    NCCL_REDUCE_TYPE_CASE(Max);\n    default: PRINT_BUG_PROMPT_AND_ABORT();\n  }\n}\n\n}  // namespace\n\nclass CudaReduce final : public Reduce {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaReduce);\n  CudaReduce() : nccl_datatype_(), nccl_reduce_op_() {}\n  ~CudaReduce() = default;\n\n  void Init(DataType datatype, ReduceType reduce_type) override {\n    this->nccl_datatype_ = GetNcclDataType(datatype);\n    this->nccl_reduce_op_ = GetNcclReduceType(reduce_type);\n  }\n\n  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, int64_t root,\n              const std::shared_ptr<CommunicationContext>& communication_ctx) const override {\n    const auto& cuda_communication_ctx =\n        std::dynamic_pointer_cast<CudaCommunicationContext>(communication_ctx);\n    CHECK(cuda_communication_ctx) << kOfBugIssueUploadPrompt;\n    OF_NCCL_CHECK(ncclReduce(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_,\n                             cuda_communication_ctx->nccl_index4rank(root),\n                             cuda_communication_ctx->nccl_comm(),\n                             stream->As<ep::CudaStream>()->cuda_stream()));\n  }\n\n private:\n  ncclDataType_t nccl_datatype_;\n  ncclRedOp_t nccl_reduce_op_;\n};\n\nREGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, Reduce, CudaReduce);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cuda/cuda_reduce_scatter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/user/kernels/collective_communication/include/reduce_scatter.h\"\n#include \"oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nnamespace {\n\ninline ncclRedOp_t GetNcclReduceType(ReduceType reduce_type) {\n  switch (reduce_type) {\n#define NCCL_REDUCE_TYPE_CASE(dtype) \\\n  case ReduceType::k##dtype: return ncclRedOp_t::nccl##dtype\n    NCCL_REDUCE_TYPE_CASE(Sum);\n    NCCL_REDUCE_TYPE_CASE(Max);\n    default: PRINT_BUG_PROMPT_AND_ABORT();\n  }\n}\n\n}  // namespace\n\nclass CudaReduceScatter final : public ReduceScatter {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaReduceScatter);\n  CudaReduceScatter() : nccl_datatype_(), nccl_reduce_op_() {}\n  ~CudaReduceScatter() = default;\n\n  void Init(DataType datatype, ReduceType reduce_type) override {\n    this->nccl_datatype_ = GetNcclDataType(datatype);\n    this->nccl_reduce_op_ = GetNcclReduceType(reduce_type);\n  }\n\n  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n              const std::shared_ptr<CommunicationContext>& communication_ctx) const override {\n    const auto& cuda_communication_ctx =\n        std::dynamic_pointer_cast<CudaCommunicationContext>(communication_ctx);\n    CHECK(cuda_communication_ctx) << kOfBugIssueUploadPrompt;\n    OF_NCCL_CHECK(ncclReduceScatter(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_,\n                                    cuda_communication_ctx->nccl_comm(),\n                                    stream->As<ep::CudaStream>()->cuda_stream()));\n  }\n\n  virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n                      const ccl::CclComm& ccl_comm) const override {\n    ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());\n    OF_NCCL_CHECK(ncclReduceScatter(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, *nccl_comm,\n                                    stream->As<ep::CudaStream>()->cuda_stream()));\n  }\n\n private:\n  ncclDataType_t nccl_datatype_;\n  ncclRedOp_t nccl_reduce_op_;\n};\n\nREGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, ReduceScatter, CudaReduceScatter);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/user/kernels/collective_communication/include/send.h\"\n#include \"oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h\"\n#include \"oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nclass CudaSend final : public Send {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudaSend);\n  CudaSend() : nccl_datatype_() {}\n  ~CudaSend() = default;\n\n  void Init(DataType datatype) override { this->nccl_datatype_ = GetNcclDataType(datatype); }\n\n  void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const override {\n#if HAS_NCCL_SEND_RECV\n    const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(dst);\n    OF_NCCL_CHECK(ncclSend(in, elem_cnt, nccl_datatype_, comm_and_peer_rank.second,\n                           comm_and_peer_rank.first, stream->As<ep::CudaStream>()->cuda_stream()));\n#else\n    UNIMPLEMENTED() << \"GPU send is only supported when nccl version >= 2.7\"\n#endif  // HAS_NCCL_SEND_RECV\n  }\n\n  void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst,\n              const ccl::CclComm& ccl_comm) const override {\n#if HAS_NCCL_SEND_RECV\n    ncclComm_t* comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());\n    OF_NCCL_CHECK(ncclSend(in, elem_cnt, nccl_datatype_, dst, *comm,\n                           stream->As<ep::CudaStream>()->cuda_stream()));\n#else\n    UNIMPLEMENTED() << \"GPU send is only supported when nccl version >= 2.7\"\n#endif  // HAS_NCCL_SEND_RECV\n  }\n\n private:\n  ncclDataType_t nccl_datatype_;\n};\n\nREGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, Send, CudaSend);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/common/decorator.h\"\n#ifdef WITH_CUDA\n#include \"oneflow/core/job/eager_nccl_comm_manager.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nstd::pair<ncclComm_t, int64_t> RawGetNcclCommAndPeerNcclRank(int64_t peer_process_id) {\n  std::set<std::pair<int64_t, int64_t>> device_set;\n  const int64_t& rank = GlobalProcessCtx::Rank();\n  const int64_t peer_nccl_rank = (peer_process_id > rank) ? 1 : 0;\n  device_set.emplace(rank, GlobalProcessCtx::LocalRank());\n  device_set.emplace(peer_process_id, GlobalProcessCtx::LocalRank(peer_process_id));\n  return {CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get())\n              ->As<EagerNcclCommMgr>()\n              ->GetCommForDevice(device_set),\n          peer_nccl_rank};\n}\n\ndecltype(GetNcclCommAndPeerNcclRank) GetNcclCommAndPeerNcclRank =\n    DECORATE(&RawGetNcclCommAndPeerNcclRank, ThreadLocal);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CUDA_CUDA_SEND_RECV_UTIL_H_\n#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CUDA_CUDA_SEND_RECV_UTIL_H_\n\n#ifdef WITH_CUDA\n#include \"oneflow/core/device/nccl_util.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nextern std::pair<ncclComm_t, int64_t> (*GetNcclCommAndPeerNcclRank)(int64_t peer_process_i);\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n\n#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_CUDA_CUDA_SEND_RECV_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/include/all_gather.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_GATHER_H_\n#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_GATHER_H_\n\n#include \"oneflow/user/kernels/collective_communication/include/collective_communication.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nclass AllGather : public CollectiveCommunication {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AllGather);\n  AllGather() = default;\n  ~AllGather() override = default;\n\n  virtual void Init(DataType dtype) = 0;\n\n  virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n                      const std::shared_ptr<CommunicationContext>& communicator) const = 0;\n\n  virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n                      const ccl::CclComm& ccl_comm) const = 0;\n};\n\ninline bool IsAllGatherRegistered(DeviceType device_type) {\n  return IsClassRegistered<DeviceType, AllGather>(device_type);\n}\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_GATHER_H_\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/include/all_reduce.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_REDUCE_H_\n#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_REDUCE_H_\n\n#include \"oneflow/user/kernels/collective_communication/include/collective_communication.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nclass AllReduce : public CollectiveCommunication {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AllReduce);\n  AllReduce() = default;\n  ~AllReduce() override = default;\n\n  virtual void Init(DataType dtype, ReduceType reduce_type) = 0;\n\n  virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n                      const std::shared_ptr<CommunicationContext>& communicator) const = 0;\n\n  virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n                      const ccl::CclComm& ccl_comm) const = 0;\n};\n\ninline bool IsAllReduceRegistered(DeviceType device_type) {\n  return IsClassRegistered<DeviceType, AllReduce>(device_type);\n}\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_REDUCE_H_\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/include/all_to_all.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_TO_ALL_H_\n#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_TO_ALL_H_\n\n#include \"oneflow/user/kernels/collective_communication/include/collective_communication.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nclass AllToAll : public CollectiveCommunication {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(AllToAll);\n  AllToAll() = default;\n  ~AllToAll() override = default;\n\n  virtual void Init(DataType send_dtype, DataType recv_dtype, size_t rank_count) = 0;\n\n  // for normal alltoall（balanced send/resv count)\n  virtual void Launch(ep::Stream* stream, void* send, int64_t send_count, void* recv,\n                      int64_t recv_count, const ccl::CclComm& ccl_comm) const = 0;\n\n  // for unbalanced all to all(e.g. nccl all2all using send/recv; hccl HcclAlltoAllV)\n  virtual void Launch(ep::Stream* stream, void* send, const void* send_counts,\n                      const void* send_offsets, void* recv, const void* recv_counts,\n                      const void* recv_offsets, const ccl::CclComm& ccl_comm, const bool has_input,\n                      const bool has_output) const = 0;\n};\n\ninline bool IsAllToAllRegistered(DeviceType device_type) {\n  return IsClassRegistered<DeviceType, AllToAll>(device_type);\n}\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_TO_ALL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/include/broadcast.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_BROADCAST_H_\n#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_BROADCAST_H_\n\n#include \"oneflow/user/kernels/collective_communication/include/collective_communication.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nclass Broadcast : public CollectiveCommunication {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Broadcast);\n  Broadcast() = default;\n  ~Broadcast() override = default;\n\n  virtual void Init(DataType dtype) = 0;\n\n  virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, int64_t root,\n                      const std::shared_ptr<CommunicationContext>& communicator) const = 0;\n};\n\ninline bool IsBroadcastRegistered(DeviceType device_type) {\n  return IsClassRegistered<DeviceType, Broadcast>(device_type);\n}\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_BROADCAST_H_\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/include/collective_communication.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COLLECTIVE_COMMUNICATION_H_\n#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COLLECTIVE_COMMUNICATION_H_\n\n#include \"oneflow/core/common/auto_registration_factory.h\"\n#include \"oneflow/core/common/switch_func.h\"\n#include \"oneflow/user/kernels/collective_communication/include/communication_context.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\n#define REDUCE_TYPE_SEQ      \\\n  OF_PP_MAKE_TUPLE_SEQ(kSum) \\\n  OF_PP_MAKE_TUPLE_SEQ(kMax)\n\nenum ReduceType {\n  kInvalidReduceFunctorType = 0,\n#define DEFINE_REDUCE_TYPE_ENUM_VALUE(enum_value) enum_value,\n  OF_PP_FOR_EACH_TUPLE(DEFINE_REDUCE_TYPE_ENUM_VALUE, REDUCE_TYPE_SEQ)\n#undef DEFINE_REDUCE_TYPE_ENUM_VALUE\n      kReduceTypeSize\n};\n\n#define REDUCE_TYPE_CTRV_SEQ      \\\n  MAKE_TYPED_CTRV_SEQ(ReduceType, \\\n                      OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ, REDUCE_TYPE_SEQ))\n\n// abstruct base class for comm\nclass CommBase {\n public:\n  virtual ~CommBase() = default;\n\n  // return impl of comm\n  virtual void* getComm() const = 0;\n};\n\nclass CclComm {\n public:\n  CclComm() {}\n  explicit CclComm(std::shared_ptr<CommBase> comm) : comm_(std::move(comm)) {}\n\n  void* getComm() const { return comm_->getComm(); }\n\n private:\n  std::shared_ptr<CommBase> comm_{};\n};\n\nclass CollectiveCommunication {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CollectiveCommunication);\n  CollectiveCommunication() = default;\n  virtual ~CollectiveCommunication() = default;\n};\n\ntemplate<typename CollectiveCommunicationType, typename... Args>\nstatic std::unique_ptr<CollectiveCommunicationType> NewCollectiveCommunication(\n    DeviceType device_type, Args&&... args) {\n  std::unique_ptr<CollectiveCommunicationType> collective_communication_entry =\n      NewObjUniquePtr<DeviceType, CollectiveCommunicationType>(device_type);\n  if (!collective_communication_entry) { return nullptr; }\n  collective_communication_entry->Init(std::forward<Args>(args)...);\n  return collective_communication_entry;\n}\n\n#define REGISTER_COLLECTIVE_COMMUNICATION(device, Base, Derived) \\\n  REGISTER_CLASS(DeviceType, device, Base, Derived)\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COLLECTIVE_COMMUNICATION_H_\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/include/communication_context.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COMMUNICATION_CONTEXT_H_\n#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COMMUNICATION_CONTEXT_H_\n\n#include \"collective_communication.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/common/auto_registration_factory.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nclass CommunicationContext {\n public:\n  CommunicationContext() = default;\n  virtual ~CommunicationContext() = default;\n\n  virtual void Init(Symbol<ParallelDesc>) = 0;\n};\n\ninline std::shared_ptr<CommunicationContext> NewCommunicationContext(\n    DeviceType device_type, Symbol<ParallelDesc> parallel_desc) {\n  CHECK_EQ(device_type, parallel_desc->device_type())\n      << \"device_type not match placement (\" << DeviceType_Name(device_type) << \" vs. \"\n      << DeviceType_Name(parallel_desc->device_type()) << \". \" << kOfBugIssueUploadPrompt;\n  std::shared_ptr<CommunicationContext> communication_ctx =\n      std::shared_ptr<CommunicationContext>(NewObj<DeviceType, CommunicationContext>(device_type));\n  communication_ctx->Init(parallel_desc);\n  return communication_ctx;\n}\n\ninline bool IsCommunicationContextRegistered(DeviceType device_type) {\n  return IsClassRegistered<DeviceType, CommunicationContext>(device_type);\n}\n\n#define REGISTER_COLLECTIVE_COMMUNICATION_COMMUNICATOR(device, Derived) \\\n  REGISTER_CLASS(DeviceType, device, CommunicationContext, Derived)\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COMMUNICATION_CONTEXT_H_\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/include/recv.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_RECVH_\n#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_RECVH_\n\n#include \"oneflow/user/kernels/collective_communication/include/collective_communication.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nclass Recv : public CollectiveCommunication {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Recv);\n  Recv() = default;\n  ~Recv() override = default;\n\n  virtual void Init(DataType dtype) = 0;\n\n  virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const = 0;\n\n  virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src,\n                      const ccl::CclComm& ccl_comm) const = 0;\n};\n\ninline bool IsRecvRegistered(DeviceType device_type) {\n  return IsClassRegistered<DeviceType, Recv>(device_type);\n}\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_RECVH_\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/include/reduce.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_REDUCE_H_\n#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_REDUCE_H_\n\n#include \"oneflow/user/kernels/collective_communication/include/collective_communication.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nclass Reduce : public CollectiveCommunication {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Reduce);\n  Reduce() = default;\n  ~Reduce() override = default;\n\n  virtual void Init(DataType dtype, ReduceType reduce_type) = 0;\n\n  virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, int64_t root,\n                      const std::shared_ptr<CommunicationContext>& communicator) const = 0;\n};\n\ninline bool IsReduceRegistered(DeviceType device_type) {\n  return IsClassRegistered<DeviceType, Reduce>(device_type);\n}\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_REDUCE_H_\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/include/reduce_scatter.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_REDUCE_SCATTER_H_\n#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_REDUCE_SCATTER_H_\n\n#include \"oneflow/user/kernels/collective_communication/include/collective_communication.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nclass ReduceScatter : public CollectiveCommunication {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ReduceScatter);\n  ReduceScatter() = default;\n  ~ReduceScatter() override = default;\n\n  virtual void Init(DataType dtype, ReduceType reduce_type) = 0;\n\n  virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n                      const std::shared_ptr<CommunicationContext>& communicator) const = 0;\n\n  virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,\n                      const ccl::CclComm& ccl_comm) const = 0;\n};\n\ninline bool IsReduceScatterRegistered(DeviceType device_type) {\n  return IsClassRegistered<DeviceType, ReduceScatter>(device_type);\n}\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_REDUCE_SCATTER_H_\n"
  },
  {
    "path": "oneflow/user/kernels/collective_communication/include/send.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_SEND_H_\n#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_SEND_H_\n\n#include \"oneflow/user/kernels/collective_communication/include/collective_communication.h\"\n\nnamespace oneflow {\n\nnamespace ccl {\n\nclass Send : public CollectiveCommunication {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(Send);\n  Send() = default;\n  ~Send() override = default;\n\n  virtual void Init(DataType dtype) = 0;\n\n  virtual void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const = 0;\n\n  virtual void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst,\n                      const ccl::CclComm& ccl_comm) const = 0;\n};\n\ninline bool IsSendRegistered(DeviceType device_type) {\n  return IsClassRegistered<DeviceType, Send>(device_type);\n}\n\n}  // namespace ccl\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_SEND_H_\n"
  },
  {
    "path": "oneflow/user/kernels/combined_margin_loss_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/user/kernels/math_unary_elementwise_func.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass CombinedMarginLossOpKernelCache final : public user_op::OpKernelCache {\n public:\n  CombinedMarginLossOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {}\n  ~CombinedMarginLossOpKernelCache() override = default;\n\n  int64_t lower() const { return lower_; }\n  int64_t upper() const { return upper_; }\n\n private:\n  const int64_t lower_;\n  const int64_t upper_;\n};\n\nstd::shared_ptr<user_op::OpKernelCache> CreateCombinedMarginLossOpKernelCache(\n    user_op::KernelCacheContext* ctx, const std::string& in_arg_name) {\n  if (ctx->parallel_ctx().parallel_num() == 1) { return nullptr; }\n\n  const SbpParallel& in_sbp = ctx->SbpParallel4ArgNameAndIndex(in_arg_name, 0);\n  if (in_sbp.has_split_parallel() && in_sbp.split_parallel().axis() == 1\n      && ctx->parallel_ctx().parallel_num() > 1) {\n    CHECK(ctx->SbpParallel4ArgNameAndIndex(\"label\", 0).has_broadcast_parallel());\n    const user_op::TensorDesc* in_logical_desc =\n        ctx->LogicalTensorDesc4ArgNameAndIndex(in_arg_name, 0);\n    const auto depth = ctx->Attr<int64_t>(\"depth\");\n    CHECK_EQ(depth, in_logical_desc->shape().At(1));\n    BalancedSplitter bs(depth, ctx->parallel_ctx().parallel_num());\n    return std::make_shared<CombinedMarginLossOpKernelCache>(\n        bs.At(ctx->parallel_ctx().parallel_id()).begin(),\n        bs.At(ctx->parallel_ctx().parallel_id()).end());\n  } else {\n    return nullptr;\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename K>\nclass CombinedMarginLossCpuKernel final : public user_op::OpKernel {\n public:\n  CombinedMarginLossCpuKernel() = default;\n  ~CombinedMarginLossCpuKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateCombinedMarginLossOpKernelCache(ctx, \"x\");\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const T* x_ptr = x->dptr<T>();\n    const K* label_ptr = ctx->Tensor4ArgNameAndIndex(\"label\", 0)->dptr<K>();\n    T* y_ptr = ctx->Tensor4ArgNameAndIndex(\"y\", 0)->mut_dptr<T>();\n    T* theta_ptr = ctx->Tensor4ArgNameAndIndex(\"theta\", 0)->mut_dptr<T>();\n    const float m1 = ctx->Attr<float>(\"m1\");\n    const float m2 = ctx->Attr<float>(\"m2\");\n    const float m3 = ctx->Attr<float>(\"m3\");\n    int64_t lower_bound = 0;\n    if (cache != nullptr) {\n      auto* kernel_cache = dynamic_cast<const CombinedMarginLossOpKernelCache*>(cache);\n      
CHECK_NOTNULL(kernel_cache);\n      CHECK_EQ(x->shape_view().Count(1), kernel_cache->upper() - kernel_cache->lower());\n      lower_bound = kernel_cache->lower();\n    }\n    const int64_t num_classes = x->shape_view().Count(1);\n    FOR_RANGE(int32_t, i, 0, x->shape_view().elem_cnt()) {\n      const int32_t row_id = i / num_classes;\n      const int32_t col_id = i - row_id * num_classes;\n      const T in_data = x_ptr[i];\n      T out_data = in_data;\n      K label = label_ptr[row_id] - lower_bound;\n      if (label == col_id) {\n        const T theta_data = AcosFunctor<T>::Forward(in_data);\n        out_data = CosFunctor<T>::Forward(theta_data * static_cast<T>(m1) + static_cast<T>(m2))\n                   - static_cast<T>(m3);\n        theta_ptr[row_id] = theta_data;\n      } else if ((label < 0 || label >= num_classes) && col_id == 0) {\n        theta_ptr[row_id] = 0;\n      }\n      y_ptr[i] = out_data;\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_COMBINED_MARGIN_LOSS_CPU_KERNEL(in_type, indices_type)                \\\n  REGISTER_USER_KERNEL(\"combined_margin_loss\")                                         \\\n      .SetCreateFn<CombinedMarginLossCpuKernel<OF_PP_PAIR_FIRST(in_type),              \\\n                                               OF_PP_PAIR_FIRST(indices_type)>>()      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                  \\\n                       && (user_op::HobDataType(\"x\", 0) == OF_PP_PAIR_SECOND(in_type)) \\\n                       && (user_op::HobDataType(\"label\", 0) == OF_PP_PAIR_SECOND(indices_type)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_COMBINED_MARGIN_LOSS_CPU_KERNEL, FLOATING_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n\ntemplate<typename T, typename K>\nclass CombinedMarginLossGradCpuKernel final : public user_op::OpKernel {\n public:\n  CombinedMarginLossGradCpuKernel() = default;\n  ~CombinedMarginLossGradCpuKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateCombinedMarginLossOpKernelCache(ctx, \"dy\");\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const T* dy_ptr = dy->dptr<T>();\n    const K* label_ptr = ctx->Tensor4ArgNameAndIndex(\"label\", 0)->dptr<K>();\n    const T* theta_ptr = ctx->Tensor4ArgNameAndIndex(\"theta\", 0)->dptr<T>();\n    T* dx_ptr = ctx->Tensor4ArgNameAndIndex(\"dx\", 0)->mut_dptr<T>();\n    const float m1 = ctx->Attr<float>(\"m1\");\n    const float m2 = ctx->Attr<float>(\"m2\");\n    int64_t lower_bound = 0;\n    if (cache != nullptr) {\n      auto* kernel_cache = dynamic_cast<const CombinedMarginLossOpKernelCache*>(cache);\n      CHECK_NOTNULL(kernel_cache);\n      CHECK_EQ(dy->shape_view().Count(1), kernel_cache->upper() - kernel_cache->lower());\n      lower_bound = kernel_cache->lower();\n    }\n\n    const int64_t num_classes = dy->shape_view().Count(1);\n    FOR_RANGE(int32_t, i, 0, dy->shape_view().elem_cnt()) {\n      const int32_t row_id = i / num_classes;\n      const int32_t col_id = i - row_id * num_classes;\n      K label = label_ptr[row_id] - lower_bound;\n      const T dy_data = dy_ptr[i];\n      const T theta_data = theta_ptr[row_id];\n      T dx_data = dy_data;\n  
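    // Angular-margin gradient: with theta = acos(x) and\n      // y = cos(m1 * theta + m2) - m3, the chain rule gives\n      // dy/dx = m1 * sin(m1 * theta + m2) / sin(theta), applied below.\n  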
    if (label == col_id) {\n        dx_data = dy_data\n                  * SinFunctor<T>::Forward(theta_data * static_cast<T>(m1) + static_cast<T>(m2))\n                  * static_cast<T>(m1) / SinFunctor<T>::Forward(theta_data);\n      }\n      dx_ptr[i] = dx_data;\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_COMBINED_MARGIN_LOSS_GRAD_CPU_KERNEL(dy_type, indices_type)            \\\n  REGISTER_USER_KERNEL(\"combined_margin_loss_grad\")                                     \\\n      .SetCreateFn<CombinedMarginLossGradCpuKernel<OF_PP_PAIR_FIRST(dy_type),           \\\n                                                   OF_PP_PAIR_FIRST(indices_type)>>()   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"dy\", 0) == OF_PP_PAIR_SECOND(dy_type)) \\\n                       && (user_op::HobDataType(\"label\", 0) == OF_PP_PAIR_SECOND(indices_type)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_COMBINED_MARGIN_LOSS_GRAD_CPU_KERNEL,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/combined_margin_loss_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/user/kernels/math_unary_elementwise_func.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename K, bool is_cosine_loss>\n__global__ void GpuForward(const int64_t n, const int64_t num_classes, const int64_t lower_bound,\n                           const T m1, const T m2, const T m3, const T* in, const K* labels, T* out,\n                           T* theta) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const int32_t row_id = i / num_classes;\n    const int32_t col_id = i - row_id * num_classes;\n    const T in_data = in[i];\n    T out_data = in_data;\n    K label = labels[row_id] - lower_bound;\n    if (is_cosine_loss) {\n      if (label == col_id) { out_data = in_data - m3; }\n    } else {\n      if (label == col_id) {\n        const T theta_data = AcosFunctor<T>::Forward(in_data);\n        out_data = CosFunctor<T>::Forward(theta_data * m1 + m2) - m3;\n        theta[row_id] = theta_data;\n      } else if ((label < 0 || label >= num_classes) && col_id == 0) {\n        theta[row_id] = 0;\n      }\n    }\n    out[i] = out_data;\n  }\n}\n\ntemplate<typename T, typename K, bool is_cosine_loss>\n__global__ void GpuBackward(const int64_t n, const int64_t num_classes, const int64_t lower_bound,\n                            const T m1, const T m2, const T m3, const T* dy, const K* labels,\n                            const T* theta, T* dx) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const int32_t row_id = i / num_classes;\n    const int32_t col_id = i - row_id * num_classes;\n    K label = labels[row_id] - lower_bound;\n    const T dy_data = dy[i];\n    const T theta_data = theta[row_id];\n    T dx_data = dy_data;\n    if (label == col_id && !is_cosine_loss) {\n      dx_data = dy_data * SinFunctor<T>::Forward(theta_data * m1 + m2) * m1\n                / SinFunctor<T>::Forward(theta_data);\n    }\n    dx[i] = dx_data;\n  }\n}\n\nclass CombinedMarginLossOpKernelCache final : public user_op::OpKernelCache {\n public:\n  CombinedMarginLossOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {}\n  ~CombinedMarginLossOpKernelCache() override = default;\n\n  int64_t lower() const { return lower_; }\n  int64_t upper() const { return upper_; }\n\n private:\n  const int64_t lower_;\n  const int64_t upper_;\n};\n\nstd::shared_ptr<user_op::OpKernelCache> CreateCombinedMarginLossOpKernelCache(\n    user_op::KernelCacheContext* ctx, const std::string& in_arg_name) {\n  if (ctx->parallel_ctx().parallel_num() == 1) { return nullptr; }\n\n  const SbpParallel& in_sbp = ctx->SbpParallel4ArgNameAndIndex(in_arg_name, 0);\n  if (in_sbp.has_split_parallel() && in_sbp.split_parallel().axis() == 1\n      && ctx->parallel_ctx().parallel_num() > 1) {\n    
CHECK(ctx->SbpParallel4ArgNameAndIndex(\"label\", 0).has_broadcast_parallel());\n    const user_op::TensorDesc* in_logical_desc =\n        ctx->LogicalTensorDesc4ArgNameAndIndex(in_arg_name, 0);\n    const auto depth = ctx->Attr<int64_t>(\"depth\");\n    CHECK_EQ(depth, in_logical_desc->shape().At(1));\n    BalancedSplitter bs(depth, ctx->parallel_ctx().parallel_num());\n    return std::make_shared<CombinedMarginLossOpKernelCache>(\n        bs.At(ctx->parallel_ctx().parallel_id()).begin(),\n        bs.At(ctx->parallel_ctx().parallel_id()).end());\n  } else {\n    return nullptr;\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename K>\nclass CombinedMarginLossGpuKernel final : public user_op::OpKernel {\n public:\n  CombinedMarginLossGpuKernel() = default;\n  ~CombinedMarginLossGpuKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateCombinedMarginLossOpKernelCache(ctx, \"x\");\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* theta = ctx->Tensor4ArgNameAndIndex(\"theta\", 0);\n    const float m1 = ctx->Attr<float>(\"m1\");\n    const float m2 = ctx->Attr<float>(\"m2\");\n    const float m3 = ctx->Attr<float>(\"m3\");\n    int64_t lower_bound = 0;\n    if (cache != nullptr) {\n      auto* kernel_cache = dynamic_cast<const CombinedMarginLossOpKernelCache*>(cache);\n      CHECK_NOTNULL(kernel_cache);\n      CHECK_EQ(x->shape_view().Count(1), kernel_cache->upper() - kernel_cache->lower());\n      lower_bound = kernel_cache->lower();\n    }\n    if (m1 == 1.0 && m2 == 0.0) {\n      GpuForward<T, K, true>\n          <<<BlocksNum4ThreadsNum(x->shape_view().elem_cnt()), kCudaThreadsNumPerBlock, 0,\n             ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n              x->shape_view().elem_cnt(), x->shape_view().Count(1), lower_bound, static_cast<T>(m1),\n              static_cast<T>(m2), static_cast<T>(m3), x->dptr<T>(), label->dptr<K>(),\n              y->mut_dptr<T>(), theta->mut_dptr<T>());\n    } else {\n      GpuForward<T, K, false>\n          <<<BlocksNum4ThreadsNum(x->shape_view().elem_cnt()), kCudaThreadsNumPerBlock, 0,\n             ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n              x->shape_view().elem_cnt(), x->shape_view().Count(1), lower_bound, static_cast<T>(m1),\n              static_cast<T>(m2), static_cast<T>(m3), x->dptr<T>(), label->dptr<K>(),\n              y->mut_dptr<T>(), theta->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_COMBINED_MARGIN_LOSS_CUDA_KERNEL(in_type, indices_type)               \\\n  REGISTER_USER_KERNEL(\"combined_margin_loss\")                                         \\\n      .SetCreateFn<CombinedMarginLossGpuKernel<OF_PP_PAIR_FIRST(in_type),              \\\n                                               OF_PP_PAIR_FIRST(indices_type)>>()      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                 \\\n                       && (user_op::HobDataType(\"x\", 0) == OF_PP_PAIR_SECOND(in_type)) \\\n                 
      && (user_op::HobDataType(\"label\", 0) == OF_PP_PAIR_SECOND(indices_type)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_COMBINED_MARGIN_LOSS_CUDA_KERNEL, FLOATING_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n\ntemplate<typename T, typename K>\nclass CombinedMarginLossGradGpuKernel final : public user_op::OpKernel {\n public:\n  CombinedMarginLossGradGpuKernel() = default;\n  ~CombinedMarginLossGradGpuKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateCombinedMarginLossOpKernelCache(ctx, \"dy\");\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    const user_op::Tensor* theta = ctx->Tensor4ArgNameAndIndex(\"theta\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const float m1 = ctx->Attr<float>(\"m1\");\n    const float m2 = ctx->Attr<float>(\"m2\");\n    const float m3 = ctx->Attr<float>(\"m3\");\n    int64_t lower_bound = 0;\n    if (cache != nullptr) {\n      auto* kernel_cache = dynamic_cast<const CombinedMarginLossOpKernelCache*>(cache);\n      CHECK_NOTNULL(kernel_cache);\n      CHECK_EQ(dy->shape_view().Count(1), kernel_cache->upper() - kernel_cache->lower());\n      lower_bound = kernel_cache->lower();\n    }\n    if (m1 == 1.0 && m2 == 0.0) {\n      GpuBackward<T, K, true>\n          <<<BlocksNum4ThreadsNum(dy->shape_view().elem_cnt()), kCudaThreadsNumPerBlock, 0,\n             ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n              dy->shape_view().elem_cnt(), dy->shape_view().Count(1), lower_bound,\n              static_cast<T>(m1), static_cast<T>(m2), static_cast<T>(m3), dy->dptr<T>(),\n              label->dptr<K>(), theta->dptr<T>(), dx->mut_dptr<T>());\n    } else {\n      GpuBackward<T, K, false>\n          <<<BlocksNum4ThreadsNum(dy->shape_view().elem_cnt()), kCudaThreadsNumPerBlock, 0,\n             ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n              dy->shape_view().elem_cnt(), dy->shape_view().Count(1), lower_bound,\n              static_cast<T>(m1), static_cast<T>(m2), static_cast<T>(m3), dy->dptr<T>(),\n              label->dptr<K>(), theta->dptr<T>(), dx->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_COMBINED_MARGIN_LOSS_GRAD_CUDA_KERNEL(dy_type, indices_type)           \\\n  REGISTER_USER_KERNEL(\"combined_margin_loss_grad\")                                     \\\n      .SetCreateFn<CombinedMarginLossGradGpuKernel<OF_PP_PAIR_FIRST(dy_type),           \\\n                                                   OF_PP_PAIR_FIRST(indices_type)>>()   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"dy\", 0) == OF_PP_PAIR_SECOND(dy_type)) \\\n                       && (user_op::HobDataType(\"label\", 0) == OF_PP_PAIR_SECOND(indices_type)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_COMBINED_MARGIN_LOSS_GRAD_CUDA_KERNEL,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/communicate_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/communicate_util.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/user/kernels/collective_communication/include/send.h\"\n#include \"oneflow/user/kernels/collective_communication/include/recv.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconst void** ThreadLocalSrcDataPtr() {\n  static thread_local const void* data_ptr = nullptr;\n  return &data_ptr;\n}\n\n}  // namespace\n\nbool IsSendAndRecvRegistered(DeviceType device_type) {\n  return ccl::IsSendRegistered(device_type) && ccl::IsRecvRegistered(device_type);\n}\n\nMaybe<void> Send(const void* in, size_t elem_cnt, DataType dtype, int64_t dst,\n                 DeviceType device_type, ep::Stream* stream) {\n  if (GlobalProcessCtx::Rank() == dst) {\n    auto** src_data_ptr = ThreadLocalSrcDataPtr();\n    CHECK_OR_RETURN(*src_data_ptr == nullptr);\n    *src_data_ptr = in;\n  } else {\n    std::unique_ptr<ccl::Send> send =\n        ccl::NewCollectiveCommunication<ccl::Send>(device_type, dtype);\n    send->Launch(stream, in, elem_cnt, dst);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> Recv(void* out, size_t elem_cnt, DataType dtype, int64_t src, DeviceType device_type,\n                 ep::Stream* stream) {\n  if (GlobalProcessCtx::Rank() == src) {\n    size_t buffer_size = elem_cnt * GetSizeOfDataType(dtype);\n    auto** src_data_ptr = ThreadLocalSrcDataPtr();\n    const void* in = *src_data_ptr;\n    CHECK_OR_RETURN(*src_data_ptr != nullptr);\n    std::unique_ptr<ep::primitive::Memcpy> memcpy_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemcpyFactory>(device_type,\n                                                                  ep::primitive::MemcpyKind::kDtoD);\n    CHECK(memcpy_primitive) << \"Can not create Memcpy primitive for device type \" << device_type;\n    memcpy_primitive->Launch(stream, out, in, buffer_size);\n    *src_data_ptr = nullptr;\n  } else {\n    std::unique_ptr<ccl::Recv> recv =\n        ccl::NewCollectiveCommunication<ccl::Recv>(device_type, dtype);\n    recv->Launch(stream, out, elem_cnt, src);\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/communicate_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_COMMUNICATE_UTIL_H_\n#define ONEFLOW_USER_KERNELS_COMMUNICATE_UTIL_H_\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/framework/user_op_kernel_registry.h\"\n\nnamespace oneflow {\n\nbool IsSendAndRecvRegistered(DeviceType device_type);\n\nALWAYS_INLINE inline auto HobIsSendAndRecvRegistered() {\n  return hob::make_custom(\"HobIsSendAndRecvRegistered\", [](const user_op::KernelRegContext& ctx) {\n    return IsSendAndRecvRegistered(ctx.device_type());\n  });\n}\n\n// Send data from in to rank dst, if cur rank equal dst, memcopy will happen.\n// Rank dst needs to call Recv with the same datatype and the same count from this rank.\nMaybe<void> Send(const void* in, size_t elem_cnt, DataType dtype, int64_t dst,\n                 DeviceType device_type, ep::Stream* stream);\n\n// Receive data from rank src into out, if cur rank equal src, memcopy will happen.\n// Rank src needs to call Send with the same datatype and the same count to this rank.\nMaybe<void> Recv(void* out, size_t elem_cnt, DataType dtype, int64_t src, DeviceType device_type,\n                 ep::Stream* stream);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_COMMUNICATE_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/complex_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/include/primitive/elementwise_unary.h\"\n#include \"oneflow/core/ep/include/primitive/primitive.h\"\n#include \"oneflow/core/ep/include/primitive/unary_op.h\"\n#include \"oneflow/user/kernels/elementwise_primitive_kernel.h\"\n#include <complex>\n#ifdef WITH_CUDA\n#include <cuComplex.h>\n#endif  // WITH_CUDA\n\nnamespace oneflow {\nnamespace user_op {\n\n#define COMPLEX_UNARY_ELEMENTWISE_PRIMITIVE_SEQ                        \\\n  OF_PP_MAKE_TUPLE_SEQ(\"conj_physical\", ep::primitive::UnaryOp::kConj) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"real\", ep::primitive::UnaryOp::kReal)          \\\n  OF_PP_MAKE_TUPLE_SEQ(\"imag\", ep::primitive::UnaryOp::kImag)\n\n#define COMPLEX_UNARY_GRAD_ELEMENTWISE_PRIMITIVE_SEQ                   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"real_grad\", ep::primitive::UnaryOp::kRealGrad) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"imag_grad\", ep::primitive::UnaryOp::kImagGrad)\n\n#define REGISTER_COMPLEX_KERNEL(name, UnaryOp)                                            \\\n  REGISTER_USER_KERNEL(name)                                                              \\\n      .SetCreateFn([]() {                                                                 \\\n        return user_op::NewOpKernel<UnaryPrimitiveKernel>(                                \\\n            \"out\", \"x\", [](user_op::KernelComputeContext* ctx) {                          \\\n              const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0); \\\n              const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0);   \\\n              return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>( \\\n                  ctx->device_type(), UnaryOp, src->data_type(), dst->data_type());       \\\n            });                                                                           \\\n      })                                                                                  \\\n      .SetIsMatchedHob(UnaryPrimitiveExists(UnaryOp, \"out\", \"x\"));\nOF_PP_FOR_EACH_TUPLE(REGISTER_COMPLEX_KERNEL, COMPLEX_UNARY_ELEMENTWISE_PRIMITIVE_SEQ)\n\n#define REGISTER_COMPLEX_GRAD_KERNEL(name, UnaryOp)                                        \\\n  REGISTER_USER_KERNEL(name)                                                               \\\n      .SetCreateFn([]() {                                                                  \\\n        return user_op::NewOpKernel<UnaryPrimitiveKernel>(                                 \\\n            \"dx\", \"dout\", [](user_op::KernelComputeContext* ctx) {                         \\\n              const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);   \\\n              const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dout\", 0); \\\n              return 
ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(  \\\n                  ctx->device_type(), UnaryOp, src->data_type(), dst->data_type());        \\\n            });                                                                            \\\n      })                                                                                   \\\n      .SetIsMatchedHob(UnaryPrimitiveExists(UnaryOp, \"dx\", \"dout\"));\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_COMPLEX_GRAD_KERNEL, COMPLEX_UNARY_GRAD_ELEMENTWISE_PRIMITIVE_SEQ)\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/concat_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/include/primitive/copy_nd.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::CopyNd> NewCopyNdPrimitive(Context* ctx) {\n  return ep::primitive::NewPrimitive<ep::primitive::CopyNdFactory>(ctx->device_type(), 2);\n}\n\nclass ConcatKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ConcatKernel() = default;\n  ~ConcatKernel() = default;\n\n private:\n  void InferShape(user_op::KernelInferContext* ctx) const override {\n    const int64_t axis = ctx->Attr<int64_t>(\"axis\");\n    DimVector dim_vec;\n    for (const auto& in_arg_pair : ctx->inputs()) {\n      const ShapeView& input_shape_view =\n          ctx->ShapeView4ArgNameAndIndex(in_arg_pair.first, in_arg_pair.second);\n      if (dim_vec.size() == 0) {\n        input_shape_view.ToDimVector(&dim_vec);\n      } else {\n        CHECK_EQ(input_shape_view.NumAxes(), dim_vec.size());\n        FOR_RANGE(int64_t, i, 0, input_shape_view.NumAxes()) {\n          if (i == axis) {\n            dim_vec.at(i) += input_shape_view.At(i);\n          } else {\n            CHECK_EQ(input_shape_view.At(i), dim_vec.at(i));\n          }\n        }\n      }\n    }\n    ctx->MutShapeView4ArgNameAndIndex(\"out\", 0).set_shape(Shape(dim_vec));\n  }\n\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    if (out_tensor->shape_view().elem_cnt() == 0) { return; }\n    const int64_t axis = ctx->Attr<int64_t>(\"axis\");\n    const int64_t out_cols = out_tensor->shape_view().Count(axis);\n    const int64_t rows = out_tensor->shape_view().elem_cnt() / out_cols;\n    CHECK_GT(rows, 0);\n\n    auto primitive = NewCopyNdPrimitive(ctx);\n    CHECK(primitive);\n    int64_t out_col_offset = 0;\n    for (const auto& in_arg_pair : ctx->inputs()) {\n      const user_op::Tensor* in_tensor =\n          ctx->Tensor4ArgNameAndIndex(in_arg_pair.first, in_arg_pair.second);\n      if (in_tensor->shape_view().elem_cnt() == 0) { continue; }\n      const int64_t in_cols = in_tensor->shape_view().Count(axis);\n      CHECK_EQ(in_tensor->shape_view().elem_cnt(), rows * in_cols);\n      if (in_cols > 0) {\n        DimVector dst_shape = {rows, out_cols};\n        DimVector dst_pos_vec = {0, out_col_offset};\n        DimVector src_shape = {rows, in_cols};\n        DimVector src_pos_vec = {0, 0};\n        DimVector extent_vec = {rows, in_cols};\n        primitive->Launch(ctx->stream(), out_tensor->data_type(), 2, out_tensor->mut_dptr(),\n                          dst_shape.data(), dst_pos_vec.data(), in_tensor->dptr(), src_shape.data(),\n                          src_pos_vec.data(), extent_vec.data());\n      }\n      out_col_offset += in_cols;\n    }\n    
CHECK_EQ(out_col_offset, out_cols);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nauto CopyNdPrimitiveExists() {\n  return hob::make_custom(\"CopyNdPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) -> bool {\n                            return NewCopyNdPrimitive(&ctx).operator bool();\n                          });\n}\n\n}  // namespace\n\nREGISTER_USER_KERNEL(\"cat\").SetCreateFn<ConcatKernel>().SetIsMatchedHob(CopyNdPrimitiveExists()\n                                                                        == true);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/constant_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <complex>\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Fill> NewFillPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->device_type(), data_type);\n}\n\nclass ConstantKernel final : public OpKernel {\n public:\n  ConstantKernel() = default;\n  ~ConstantKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    bool is_complex_value = ctx->Attr<bool>(\"is_complex_value\");\n    bool is_floating_value = ctx->Attr<bool>(\"is_floating_value\");\n\n    const Scalar value = is_complex_value\n                             ? Scalar(ctx->Attr<std::complex<double>>(\"complex_value\"))\n                             : (is_floating_value ? Scalar(ctx->Attr<double>(\"floating_value\"))\n                                                  : Scalar(ctx->Attr<int64_t>(\"integer_value\")));\n    const int64_t elem_cnt = out_tensor->shape_view().elem_cnt();\n    CHECK_GE(elem_cnt, 0);\n    if (elem_cnt == 0) { return; }\n    std::unique_ptr<ep::primitive::Fill> fill = NewFillPrimitive(ctx);\n    CHECK(fill);\n    fill->Launch(ctx->stream(), out_tensor->mut_dptr(), value, elem_cnt);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nauto FillPrimitiveExists() {\n  return hob::make_custom(\"FillPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewFillPrimitive(&ctx).operator bool();\n  });\n}\n\nREGISTER_USER_KERNEL(\"constant\")\n    .SetCreateFn<ConstantKernel>()\n    .SetIsMatchedHob(FillPrimitiveExists() == true);\n\n}  // namespace\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/conv_cudnn_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/device/cudnn_conv_util.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename PerfT>\nstruct CudnnConvArgsAndAlgo final {\n  using AlgoT = decltype(std::declval<PerfT>().algo);\n\n  CudnnConvArgs args;\n  PerfT algo_perf;\n\n  CudnnConvArgsAndAlgo(const user_op::Tensor* x, const user_op::Tensor* w, const user_op::Tensor* y,\n                       user_op::Tensor* buf, const user_op::KernelComputeContext* ctx,\n                       ep::Stream* stream, bool has_forced_algo, int32_t forced_algo)\n      : args(*ctx, x->data_type(), x->shape_view(), w->data_type(), w->shape_view(), y->data_type(),\n             y->shape_view(), ctx->Attr<std::string>(\"data_format\"), buf->shape_view().elem_cnt(),\n             Singleton<ResourceDesc, ForSession>::Get()\n                     ->resource()\n                     .cudnn_conf()\n                     .cudnn_conv_heuristic_search_algo()\n                 || (!LazyMode::is_enabled()),\n             Singleton<ResourceDesc, ForSession>::Get()\n                 ->resource()\n                 .cudnn_conf()\n                 .cudnn_conv_use_deterministic_algo_only(),\n             Singleton<ResourceDesc, ForSession>::Get()\n                     ->resource()\n                     .cudnn_conf()\n                     .cudnn_conv_enable_pseudo_half()\n                 || (ctx->Attr<std::string>(\"data_format\") == \"channels_last\"\n                     && std::is_same<PerfT, cudnnConvolutionBwdFilterAlgoPerf_t>::value)) {\n    size_t byte_size_of_buf = buf->shape_view().elem_cnt();\n    AllocatedCudnnConvResource res(stream->As<ep::CudaStream>()->cudnn_handle(),\n                                   const_cast<void*>(x->dptr()), const_cast<void*>(w->dptr()),\n                                   const_cast<void*>(y->dptr()), buf->mut_dptr());\n    if (has_forced_algo) {\n      algo_perf = GetCudnnConvAlgorithmPerferenceWithResource<PerfT>(\n          &args, &res, static_cast<AlgoT>(forced_algo));\n    } else {\n      algo_perf = FindCudnnConvAlgorithmWithResource<PerfT>(&args, &res);\n    }\n    CHECK_EQ(algo_perf.status, CUDNN_STATUS_SUCCESS)\n        << \"op (\" << ctx->op_name()\n        << \") find algorithm perference failed. 
algo: \" << algo_perf.algo;\n    CHECK_LE(algo_perf.memory, byte_size_of_buf)\n        << \"op (\" << ctx->op_name() << \") find algorithm \" << algo_perf.algo << \", need memory \"\n        << algo_perf.memory << \", but cudnn_buf_limit_byte is \" << byte_size_of_buf;\n    OF_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.Get(), algo_perf.mathType));\n  }\n  CudnnConvArgsAndAlgo() = delete;\n  OF_DISALLOW_COPY_AND_MOVE(CudnnConvArgsAndAlgo);\n};\n\ntemplate<typename PerfT>\nsize_t InferTmpSizeWithCudnn(const user_op::TensorDesc* x, const user_op::TensorDesc* w,\n                             const user_op::TensorDesc* y, const user_op::InferContext& ctx,\n                             bool has_forced_algo, int32_t forced_algo) {\n  using AlgoT = decltype(std::declval<PerfT>().algo);\n\n  const auto& cudnn_conf = Singleton<ResourceDesc, ForSession>::Get()->resource().cudnn_conf();\n  size_t workspace_size = cudnn_conf.cudnn_buf_limit_mbyte() * 1024 * 1024;\n  if (!x->is_dynamic()) {\n    CudnnConvArgs args(ctx, x->data_type(), ShapeView(x->shape()), w->data_type(),\n                       ShapeView(w->shape()), y->data_type(), ShapeView(y->shape()),\n                       ctx.Attr<std::string>(\"data_format\"), workspace_size,\n                       cudnn_conf.cudnn_conv_heuristic_search_algo() || (!LazyMode::is_enabled()),\n                       cudnn_conf.cudnn_conv_use_deterministic_algo_only(),\n                       cudnn_conf.cudnn_conv_enable_pseudo_half()\n                           || (ctx.Attr<std::string>(\"data_format\") == \"channels_last\"\n                               && std::is_same<PerfT, cudnnConvolutionBwdFilterAlgoPerf_t>::value));\n    PerfT algo_perf{};\n    if (has_forced_algo) {\n      algo_perf = GetCudnnConvAlgorithmPerference<PerfT>(&args, static_cast<AlgoT>(forced_algo));\n    } else {\n      algo_perf = FindCudnnConvAlgorithm<PerfT>(&args);\n    }\n    CHECK_EQ(algo_perf.status, CUDNN_STATUS_SUCCESS)\n        << \"op (\" << ctx.op_name()\n        << \") find algorithm perference failed. 
algo: \" << algo_perf.algo;\n    CHECK_LE(algo_perf.memory, workspace_size)\n        << \"op (\" << ctx.op_name() << \") find algorithm \" << algo_perf.algo << \", need memory \"\n        << algo_perf.memory << \", but cudnn_buf_limit_byte is \" << workspace_size;\n    workspace_size = algo_perf.memory;\n  }\n  workspace_size = std::max(size_t(1), workspace_size);\n  return workspace_size;\n}\n\n// for 1d and 2d\ntemplate<size_t NDims>\nCudnnTensorDesc* GetBiasCudnnTensorDesc(const std::string& data_format, int32_t filters,\n                                        DataType data_type) {\n  if (data_format == \"channels_first\") {\n    return new CudnnTensorDesc(CUDNN_TENSOR_NCHW, data_type, 1, filters, 1, 1);\n  } else {\n    CHECK_EQ(\"channels_last\", data_format);\n    return new CudnnTensorDesc(CUDNN_TENSOR_NHWC, data_type, 1, filters, 1, 1);\n  }\n}\n\n// for 3d and Nd\ntemplate<>\nCudnnTensorDesc* GetBiasCudnnTensorDesc<3>(const std::string& data_format, int32_t filters,\n                                           DataType data_type) {\n  constexpr int NDims = 3 + 2;\n  CHECK_EQ(\"channels_first\", data_format) << \"CUDNN Nd API only support channels first\";\n  std::vector<int32_t> bias_dim(NDims, 1);\n  std::vector<int32_t> stride_of_bias_tensor(NDims, 1);\n  bias_dim[1] = filters;\n  stride_of_bias_tensor[0] = filters;\n  return new CudnnTensorDesc(data_type, NDims, bias_dim.data(), stride_of_bias_tensor.data());\n}\n\nstruct ConvCudnnOpKernelCache final : public user_op::OpKernelCache {\n  std::unique_ptr<CudnnTensorDesc> bias_desc;\n};\n\ntemplate<size_t NDims>\nclass ConvGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ConvGpuKernel() = default;\n  ~ConvGpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  std::shared_ptr<ConvCudnnOpKernelCache> CreateConvCudnnOpKernelCache(\n      user_op::KernelCacheContext* ctx) const {\n    const auto& data_format = ctx->Attr<std::string>(\"data_format\");\n    int32_t filters = ctx->Attr<int32_t>(\"filters\");\n\n    std::shared_ptr<ConvCudnnOpKernelCache> state(new ConvCudnnOpKernelCache());\n\n    const user_op::TensorDesc* bias = ctx->TensorDesc4ArgNameAndIndex(\"bias\", 0);\n    if (bias != nullptr) {\n      state->bias_desc.reset(\n          GetBiasCudnnTensorDesc<NDims>(data_format, filters, bias->data_type()));\n    }\n\n    return state;\n  }\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateConvCudnnOpKernelCache(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    if (in->shape_view().elem_cnt() == 0) return;\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    user_op::Tensor* buf = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const auto& cudnn_conf = Singleton<ResourceDesc, ForSession>::Get()->resource().cudnn_conf();\n    CudnnConvArgsAndAlgo<cudnnConvolutionFwdAlgoPerf_t> args_and_algo(\n        in, weight, out, buf, ctx, ctx->stream(), cudnn_conf.has_cudnn_conv_force_fwd_algo(),\n        cudnn_conf.cudnn_conv_force_fwd_algo());\n    const CudnnConvArgs& args = args_and_algo.args;\n    const cudnnConvolutionFwdAlgoPerf_t& algo_perf = 
args_and_algo.algo_perf;\n    const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex(\"bias\", 0);\n\n    const void* beta = nullptr;\n    if (ctx->has_input(\"_add_to_output\", 0)) {\n      const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n      CHECK_EQ(add_to_output->data_type(), out->data_type());\n      CHECK_EQ(add_to_output->shape_view(), out->shape_view());\n      Memcpy<DeviceType::kCUDA>(\n          ctx->stream(), out->mut_dptr(), add_to_output->dptr(),\n          add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));\n      beta = CudnnSPOnePtr(in->data_type());\n    } else {\n      beta = CudnnSPZeroPtr(in->data_type());\n    }\n\n    OF_CUDNN_CHECK(cudnnConvolutionForward(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), CudnnSPOnePtr(in->data_type()),\n        args.xdesc.Get(), in->dptr(), args.wdesc.Get(), weight->dptr(), args.cdesc.Get(),\n        algo_perf.algo, buf->mut_dptr(), args.params.max_ws_size, beta, args.ydesc.Get(),\n        out->mut_dptr()));\n\n    if (bias != nullptr) {\n      const auto* conv_cache = dynamic_cast<const ConvCudnnOpKernelCache*>(cache);\n      CHECK_NOTNULL(conv_cache);\n      OF_CUDNN_CHECK(cudnnAddTensor(ctx->stream()->As<ep::CudaStream>()->cudnn_handle(),\n                                    CudnnSPOnePtr(in->data_type()), conv_cache->bias_desc->Get(),\n                                    bias->dptr(), CudnnSPOnePtr(in->data_type()), args.ydesc.Get(),\n                                    out->mut_dptr()));\n    }\n  }\n\n  bool IsCudaGraphSupported(user_op::KernelInitContext* ctx,\n                            user_op::OpKernelState* state) const override {\n    return Singleton<ResourceDesc, ForSession>::Get()\n        ->resource()\n        .cudnn_conf()\n        .cudnn_conv_heuristic_search_algo();\n  }\n};\n\n#define REGISTER_CONV_KERNEL(op_name, ndims)                                                \\\n  REGISTER_USER_KERNEL(#op_name)                                                            \\\n      .SetCreateFn<ConvGpuKernel<ndims>>()                                                  \\\n      .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA)                       \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                         \\\n        const auto& in = ctx->InputTensorDesc(\"in\", 0);                                     \\\n        if (in.shape().elem_cnt() == 0) return 0;                                           \\\n        const auto& weight = ctx->InputTensorDesc(\"weight\", 0);                             \\\n        const auto& out = ctx->OutputTensorDesc(\"out\", 0);                                  \\\n        const auto& cudnn_conf =                                                            \\\n            Singleton<ResourceDesc, ForSession>::Get()->resource().cudnn_conf();            \\\n        return InferTmpSizeWithCudnn<cudnnConvolutionFwdAlgoPerf_t>(                        \\\n            &in, &weight, &out, *ctx, cudnn_conf.has_cudnn_conv_force_fwd_algo(),           \\\n            cudnn_conf.cudnn_conv_force_fwd_algo());                                        \\\n      })                                                                                    \\\n      .SetInplaceProposalFn(                                                                \\\n          [](const user_op::InferContext& ctx,                                              \\\n             const 
user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {        \\\n            if (ctx.has_input(\"_add_to_output\", 0)) {                                       \\\n              OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"_add_to_output\", 0, true)); \\\n            }                                                                               \\\n            return Maybe<void>::Ok();                                                       \\\n          });\n\nREGISTER_CONV_KERNEL(conv1d, 1);\nREGISTER_CONV_KERNEL(conv2d, 2);\nREGISTER_CONV_KERNEL(conv3d, 3);\n\nclass ConvDataGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ConvDataGradGpuKernel);\n  ConvDataGradGpuKernel() = default;\n  ~ConvDataGradGpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* filter = ctx->Tensor4ArgNameAndIndex(\"filter\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    if (dx->shape_view().elem_cnt() == 0) return;\n    user_op::Tensor* buf = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const auto& cudnn_conf = Singleton<ResourceDesc, ForSession>::Get()->resource().cudnn_conf();\n\n    CudnnConvArgsAndAlgo<cudnnConvolutionBwdDataAlgoPerf_t> args_and_algo(\n        dx, filter, dy, buf, ctx, ctx->stream(), cudnn_conf.has_cudnn_conv_force_bwd_data_algo(),\n        cudnn_conf.cudnn_conv_force_bwd_data_algo());\n    const CudnnConvArgs& args = args_and_algo.args;\n    const cudnnConvolutionBwdDataAlgoPerf_t& algo_perf = args_and_algo.algo_perf;\n\n    const void* alpha = CudnnSPOnePtr(dy->data_type());\n    const void* beta = nullptr;\n    if (ctx->has_input(\"_add_to_output\", 0)) {\n      const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n      CHECK_EQ(add_to_output->data_type(), dx->data_type());\n      CHECK_EQ(add_to_output->shape_view(), dx->shape_view());\n      Memcpy<DeviceType::kCUDA>(\n          ctx->stream(), dx->mut_dptr<void>(), add_to_output->dptr<void>(),\n          add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));\n      beta = CudnnSPOnePtr(dy->data_type());\n    } else {\n      beta = CudnnSPZeroPtr(dy->data_type());\n    }\n\n    OF_CUDNN_CHECK(cudnnConvolutionBackwardData(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), alpha, args.wdesc.Get(),\n        filter->dptr(), args.ydesc.Get(), dy->dptr(), args.cdesc.Get(), algo_perf.algo,\n        buf->mut_dptr(), args.params.max_ws_size, beta, args.xdesc.Get(), dx->mut_dptr()));\n  }\n\n  bool IsCudaGraphSupported(user_op::KernelInitContext* ctx,\n                            user_op::OpKernelState* state) const override {\n    return Singleton<ResourceDesc, ForSession>::Get()\n        ->resource()\n        .cudnn_conf()\n        .cudnn_conv_heuristic_search_algo();\n  }\n};\n\nREGISTER_USER_KERNEL(\"conv_data_grad\")\n    .SetCreateFn<ConvDataGradGpuKernel>()\n    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA)\n    .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {\n      const auto& dy = ctx->InputTensorDesc(\"dy\", 0);\n      const auto& filter = ctx->InputTensorDesc(\"filter\", 0);\n      const auto& dx = ctx->OutputTensorDesc(\"dx\", 0);\n      if (dx.shape().elem_cnt() == 0) 
return 0;\n      const auto& cudnn_conf = Singleton<ResourceDesc, ForSession>::Get()->resource().cudnn_conf();\n      return InferTmpSizeWithCudnn<cudnnConvolutionBwdDataAlgoPerf_t>(\n          &dx, &filter, &dy, *ctx, cudnn_conf.has_cudnn_conv_force_bwd_data_algo(),\n          cudnn_conf.cudnn_conv_force_bwd_data_algo());\n    })\n    .SetInplaceProposalFn([](const user_op::InferContext& ctx,\n                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      if (ctx.has_input(\"_add_to_output\", 0)) {\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"dx\", 0, \"_add_to_output\", 0, true));\n      }\n      return Maybe<void>::Ok();\n    });\n\nclass ConvFilterGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ConvFilterGradGpuKernel);\n  ConvFilterGradGpuKernel() = default;\n  ~ConvFilterGradGpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* filter_diff = ctx->Tensor4ArgNameAndIndex(\"filter_diff\", 0);\n    if (x->shape_view().elem_cnt() == 0) {\n      Memset<DeviceType::kCUDA>(\n          ctx->stream(), filter_diff->mut_dptr(), 0,\n          filter_diff->shape_view().elem_cnt() * GetSizeOfDataType(filter_diff->data_type()));\n      return;\n    }\n    user_op::Tensor* buf = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const auto& cudnn_conf = Singleton<ResourceDesc, ForSession>::Get()->resource().cudnn_conf();\n\n    CudnnConvArgsAndAlgo<cudnnConvolutionBwdFilterAlgoPerf_t> args_and_algo(\n        x, filter_diff, dy, buf, ctx, ctx->stream(),\n        cudnn_conf.has_cudnn_conv_force_bwd_filter_algo(),\n        cudnn_conf.cudnn_conv_force_bwd_filter_algo());\n    const CudnnConvArgs& args = args_and_algo.args;\n    const cudnnConvolutionBwdFilterAlgoPerf_t& algo_perf = args_and_algo.algo_perf;\n\n    OF_CUDNN_CHECK(cudnnConvolutionBackwardFilter(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), CudnnSPOnePtr(dy->data_type()),\n        args.xdesc.Get(), x->dptr(), args.ydesc.Get(), dy->dptr(), args.cdesc.Get(), algo_perf.algo,\n        buf->mut_dptr(), args.params.max_ws_size, CudnnSPZeroPtr(dy->data_type()), args.wdesc.Get(),\n        filter_diff->mut_dptr()));\n  }\n\n  bool IsCudaGraphSupported(user_op::KernelInitContext* ctx,\n                            user_op::OpKernelState* state) const override {\n    return Singleton<ResourceDesc, ForSession>::Get()\n        ->resource()\n        .cudnn_conf()\n        .cudnn_conv_heuristic_search_algo();\n  }\n};\n\nREGISTER_USER_KERNEL(\"conv_filter_grad\")\n    .SetCreateFn<ConvFilterGradGpuKernel>()\n    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA)\n    .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {\n      const auto& dy = ctx->InputTensorDesc(\"dy\", 0);\n      const auto& x = ctx->InputTensorDesc(\"x\", 0);\n      if (x.shape().elem_cnt() == 0) return 0;\n      const auto& filter_diff = ctx->OutputTensorDesc(\"filter_diff\", 0);\n      const auto& cudnn_conf = Singleton<ResourceDesc, ForSession>::Get()->resource().cudnn_conf();\n      return InferTmpSizeWithCudnn<cudnnConvolutionBwdFilterAlgoPerf_t>(\n          &x, &filter_diff, &dy, *ctx, 
cudnn_conf.has_cudnn_conv_force_bwd_filter_algo(),\n          cudnn_conf.cudnn_conv_force_bwd_filter_algo());\n    });\n\nstruct ConvBiasGradState final : public user_op::OpKernelState {\n  std::unique_ptr<CudnnTensorDesc> bias_diff_desc;\n};\n\nclass ConvBiasGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ConvBiasGradGpuKernel() = default;\n  ~ConvBiasGradGpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  std::shared_ptr<ConvBiasGradState> CreateConvBiasGradState(\n      user_op::KernelComputeContext* ctx) const {\n    const auto* bias_diff = ctx->TensorDesc4ArgNameAndIndex(\"bias_diff\", 0);\n    const auto* dy = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n    const auto& data_format = ctx->Attr<std::string>(\"data_format\");\n\n    std::shared_ptr<ConvBiasGradState> state(new ConvBiasGradState());\n    if (data_format == \"channels_first\") {\n      CHECK_EQ(dy->shape().At(1), bias_diff->shape().At(0));\n      state->bias_diff_desc.reset(\n          new CudnnTensorDesc(CUDNN_TENSOR_NCHW, bias_diff->data_type(), 1,\n                              static_cast<int32_t>(bias_diff->shape().At(0)), 1, 1));\n    } else {\n      CHECK(data_format == \"channels_last\") << \"Illegal data_format: \" << data_format;\n      CHECK_EQ(dy->shape().At(dy->shape().NumAxes() - 1), bias_diff->shape().At(0));\n      state->bias_diff_desc.reset(\n          new CudnnTensorDesc(CUDNN_TENSOR_NHWC, bias_diff->data_type(), 1,\n                              static_cast<int32_t>(bias_diff->shape().At(0)), 1, 1));\n    }\n    return state;\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* bias_diff = ctx->Tensor4ArgNameAndIndex(\"bias_diff\", 0);\n    CHECK_EQ(bias_diff->shape_view().NumAxes(), 1);\n    CHECK_GE(dy->shape_view().NumAxes(), 3);\n    CHECK_LE(dy->shape_view().NumAxes(), 5);\n\n    const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n\n    std::unique_ptr<CudnnTensorDesc> dy_desc;\n    dy_desc.reset(new CudnnTensorDesc(dy->data_type(), dy->shape_view(), data_format));\n    const auto& bias_grad_state = CreateConvBiasGradState(ctx);\n    CHECK_NOTNULL(bias_grad_state.get());\n    OF_CUDNN_CHECK(cudnnConvolutionBackwardBias(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), CudnnSPOnePtr(dy->data_type()),\n        dy_desc->Get(), dy->dptr(), CudnnSPZeroPtr(dy->data_type()),\n        bias_grad_state->bias_diff_desc->Get(), bias_diff->mut_dptr()));\n  }\n};\n\nREGISTER_USER_KERNEL(\"conv_bias_grad\")\n    .SetCreateFn<ConvBiasGradGpuKernel>()\n    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA);\n\n}  // namespace\n\n}  // namespace oneflow\n\n#endif\n"
  },
  {
    "path": "oneflow/user/kernels/conv_cutlass_kernels.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifdef WITH_CUTLASS\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/user/kernels/cutlass_conv_tuner.h\"\n#include <cutlass/library/handle.h>\n#include <cutlass/library/library.h>\n#include <cutlass/library/singleton.h>\n#include <nlohmann/json.hpp>\n\nnamespace oneflow {\n\nnamespace {\n\nclass Conv2dCutlassKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  Conv2dCutlassKernel() = default;\n  ~Conv2dCutlassKernel() override = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex(\"bias\", 0);\n    const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n    CHECK(add_to_output == nullptr);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    const auto& padding_before = ctx->Attr<std::vector<int32_t>>(\"padding_before\");\n    const auto& dilation_rate = ctx->Attr<std::vector<int32_t>>(\"dilation_rate\");\n    const auto& strides = ctx->Attr<std::vector<int32_t>>(\"strides\");\n\n    const int n = in->shape_view().At(0);\n    const int h = in->shape_view().At(1);\n    const int w = in->shape_view().At(2);\n    const int c = in->shape_view().At(3);\n\n    const int k = weight->shape_view().At(0);\n    const int r = weight->shape_view().At(1);\n    const int s = weight->shape_view().At(2);\n    CHECK_EQ(weight->shape_view().At(3), c);\n\n    const int p = out->shape_view().At(1);\n    const int q = out->shape_view().At(2);\n\n    auto* stream = ctx->stream()->As<ep::CudaStream>();\n\n    cutlass::library::ConvFunctionalKey key(\n        cutlass::library::Provider::kCUTLASS, cutlass::library::ConvKind::kFprop,\n        cutlass::library::NumericTypeID::kF16, cutlass::library::LayoutTypeID::kTensorNHWC,\n        cutlass::library::NumericTypeID::kF16, cutlass::library::LayoutTypeID::kTensorNHWC,\n        cutlass::library::NumericTypeID::kF16, cutlass::library::LayoutTypeID::kTensorNHWC,\n        cutlass::library::NumericTypeID::kF32, cutlass::library::NumericTypeID::kF32);\n\n    const bool allow_half_accumulation =\n        ParseBooleanFromEnv(\"ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION\", false);\n\n    if (allow_half_accumulation) {\n      
key.element_accumulator = cutlass::library::NumericTypeID::kF16;\n      key.element_compute = cutlass::library::NumericTypeID::kF16;\n    }\n\n    cutlass::conv::Conv2dProblemSize problem_size(\n        n, h, w, c, k, r, s, p, q, padding_before.at(0), padding_before.at(1), strides.at(0),\n        strides.at(1), dilation_rate.at(0), dilation_rate.at(1),\n        cutlass::conv::Mode::kCrossCorrelation);\n    cutlass::library::Conv2dConfiguration configuration;\n    configuration.split_k_mode = cutlass::conv::SplitKMode::kSerial;\n    configuration.problem_size = problem_size;\n    configuration.stride_a = {c, w * c, h * w * c};\n    configuration.stride_b = {c, s * c, r * s * c};\n    configuration.stride_c = {0, 0, 0};\n\n    cutlass::library::ConvArguments arguments;\n    arguments.A = in->dptr();\n    arguments.B = weight->dptr();\n    arguments.reordered_B = nullptr;\n    if (bias == nullptr) {\n      arguments.C = nullptr;\n    } else {\n      arguments.C = bias->dptr();\n    }\n    arguments.D = out->mut_dptr();\n\n    union SP {\n      float f;\n      half h;\n    };\n\n    SP alpha;\n    SP beta;\n\n    if (allow_half_accumulation) {\n      alpha.h = static_cast<half>(1.0F);\n      if (bias == nullptr) {\n        beta.h = static_cast<half>(0.0F);\n      } else {\n        beta.h = static_cast<half>(1.0F);\n      }\n    } else {\n      alpha.f = 1.0F;\n      if (bias == nullptr) {\n        beta.f = 0.0F;\n      } else {\n        beta.f = 1.0F;\n      }\n    }\n    arguments.alpha = &alpha;\n    arguments.beta = &beta;\n    arguments.pointer_mode = cutlass::library::ScalarPointerMode::kHost;\n    const cutlass::library::Operation* operation = nullptr;\n    operation = [&]() -> const cutlass::library::Operation* {\n      const std::string& tuning_cache = ctx->Attr<std::string>(\"tuning_cache\");\n      if (tuning_cache.empty()) { return nullptr; }\n      auto tuning_cache_object = nlohmann::json::parse(tuning_cache);\n      if (!tuning_cache_object.is_object()) { return nullptr; }\n      auto it = tuning_cache_object.find(\"cutlass\");\n      if (it == tuning_cache_object.end()) { return nullptr; }\n      if (!it->is_string()) { return nullptr; }\n      const std::string name = *it;\n      return CutlassConvTuner::Get().GetConv2dOperation(name, stream, key, configuration, arguments,\n                                                        tmp_buffer->mut_dptr(),\n                                                        tmp_buffer->shape_view().elem_cnt());\n    }();\n    if (!operation) {\n      operation = CutlassConvTuner::Get().FindConv2dOperation(stream, key, configuration, arguments,\n                                                              tmp_buffer->mut_dptr(),\n                                                              tmp_buffer->shape_view().elem_cnt());\n    }\n\n    CHECK(operation != nullptr);\n    const size_t host_workspace_size = operation->get_host_workspace_size(&configuration);\n    std::vector<uint8_t> host_workspace(host_workspace_size, 0);\n    auto init_status = operation->initialize(&configuration, host_workspace.data(),\n                                             tmp_buffer->mut_dptr(), stream->cuda_stream());\n    CHECK(init_status == cutlass::Status::kSuccess);\n    auto run_status = operation->run(&arguments, host_workspace.data(), tmp_buffer->mut_dptr(),\n                                     stream->cuda_stream());\n    CHECK(run_status == cutlass::Status::kSuccess);\n  }\n};\n\nREGISTER_USER_KERNEL(\"conv2d\")\n    .SetCreateFn<Conv2dCutlassKernel>()\n    
.SetIsMatchedHob(\n        (user_op::HobDeviceType() == DeviceType::kCUDA)\n        && (user_op::HobAttr<std::string>(\"data_format\") == \"channels_last\")\n        && (user_op::HobAttr<int32_t>(\"groups\") == 1)\n        && (user_op::HobDataType(\"in\", 0) == DataType::kFloat16)\n        // Compatible with typo `KERENL`\n        && ((user_op::HobEnvBool(\"ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL\", false) == true)\n            || (user_op::HobEnvBool(\"ONEFLOW_KERENL_CONV_ENABLE_CUTLASS_IMPL\", false) == true)))\n    .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {\n      // use static workspace size\n      return 128 * 1024 * 1024;\n    })\n    .SetPriority(user_op::kKernelPriorityOptimized);\n\n}  // namespace\n\n}  // namespace oneflow\n\n#endif  // WITH_CUTLASS\n"
  },
  {
    "path": "oneflow/user/kernels/conv_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/include/primitive/add.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) {\n  return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N;\n}\n\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitive(DeviceType device_type,\n                                                          DataType data_type, bool transpose_a,\n                                                          bool transpose_b) {\n  const auto trans_a = GetBlasTransposeType(transpose_a);\n  const auto trans_b = GetBlasTransposeType(transpose_b);\n  return ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(device_type, data_type, trans_a,\n                                                                   trans_b);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewChannelsFirstMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false,\n                            /*transpose_b=*/false);\n}\n\nauto ChannelsFirstMatmulPrimitiveExists() {\n  return hob::make_custom(\"ChannelsFirstMatmulPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewChannelsFirstMatmulPrimitive(&ctx).operator bool();\n                          });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewChannelsLastMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true,\n                            /*transpose_b=*/true);\n}\n\nauto ChannelsLastMatmulPrimitiveExists() {\n  return hob::make_custom(\"ChannelsLastMatmulPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewChannelsLastMatmulPrimitive(&ctx).operator bool();\n                          });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewConvDataGradTransATransBMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true,\n                            /*transpose_b=*/true);\n}\n\nauto ConvDataGradTransATransBMatmulPrimitiveExists() {\n  return hob::make_custom(\"ConvDataGradTransATransBMatmulPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return 
NewConvDataGradTransATransBMatmulPrimitive(&ctx).operator bool();\n                          });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewConvDataGradTransANoTransBMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true,\n                            /*transpose_b=*/false);\n}\n\nauto ConvDataGradTransANoTransBMatmulPrimitiveExists() {\n  return hob::make_custom(\n      \"ConvDataGradTransANoTransBMatmulPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n        return NewConvDataGradTransANoTransBMatmulPrimitive(&ctx).operator bool();\n      });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewConvWeightGradTransATransBMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true,\n                            /*transpose_b=*/true);\n}\n\nauto ConvWeightGradTransATransBMatmulPrimitiveExists() {\n  return hob::make_custom(\n      \"ConvWeightGradTransATransBMatmulPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n        return NewConvWeightGradTransATransBMatmulPrimitive(&ctx).operator bool();\n      });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewConvWeightGradNoTransATransBMatmulPrimitive(\n    Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false,\n                            /*transpose_b=*/true);\n}\n\nauto ConvWeightGradNoTransATransBMatmulPrimitiveExists() {\n  return hob::make_custom(\n      \"ConvWeightGradNoTransATransBMatmulPrimitiveExists\",\n      [](const user_op::KernelRegContext& ctx) {\n        return NewConvWeightGradNoTransATransBMatmulPrimitive(&ctx).operator bool();\n      });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewConvBiasGradNoTransANoTransBMatmulPrimitive(\n    Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false,\n                            /*transpose_b=*/false);\n}\n\nauto ConvBiasGradNoTransANoTransBMatmulPrimitiveExists() {\n  return hob::make_custom(\n      \"ConvBiasGradNoTransANoTransBMatmulPrimitiveExists\",\n      [](const user_op::KernelRegContext& ctx) {\n        return NewConvBiasGradNoTransANoTransBMatmulPrimitive(&ctx).operator bool();\n      });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewConvBiasGradTransANoTransBMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true,\n                            /*transpose_b=*/false);\n}\n\nauto ConvBiasGradTransANoTransBMatmulPrimitiveExists() {\n  return hob::make_custom(\n      \"ConvBiasGradTransANoTransBMatmulPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n        return NewConvBiasGradTransANoTransBMatmulPrimitive(&ctx).operator bool();\n      });\n}\n\ntemplate<typename T>\nusing Im2ColFunc = void (*)(const T* in_dptr, const ShapeView& in_shape,\n                            const ShapeView& weight_shape, const ShapeView& out_shape,\n 
                           const int32_t* strides, const int32_t* dilation_rate,\n                            const int32_t* padding_before, T* col_buf);\n\ntemplate<typename T>\nusing Col2ImFunc = void (*)(const T* col_buf, const ShapeView& in_shape,\n                            const ShapeView& weight_shape, const ShapeView& out_shape,\n                            const int32_t* strides, const int32_t* dilation_rate,\n                            const int32_t* padding_before, T* in_diff_ptr);\n\ntemplate<typename T>\nT* GetImgMutDptr(user_op::Tensor* tensor, int64_t idx) {\n  return tensor->mut_dptr<T>() + tensor->shape_view().Count(1) * idx;\n}\n\ntemplate<typename T>\nconst T* GetImgDptr(const user_op::Tensor* tensor, int64_t idx) {\n  return tensor->dptr<T>() + tensor->shape_view().Count(1) * idx;\n}\n\nsize_t CalcElemNumOfColBuf(const ShapeView& out_shape, const ShapeView& weight_shape,\n                           const int32_t idx_offset) {\n  int64_t col_buf_elem_cnt = 1;\n  int64_t ndims = out_shape.NumAxes() - 2;\n\n  for (size_t i = 0; i != ndims + 1; ++i) { col_buf_elem_cnt *= weight_shape.At(i + 1); }\n  for (size_t i = 0; i != ndims; ++i) { col_buf_elem_cnt *= out_shape.At(idx_offset + i); }\n  return col_buf_elem_cnt;\n}\n\ntemplate<typename T>\nclass ColBufWriter {\n public:\n  ColBufWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size,\n               int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size)\n      : src_ptr_(src_ptr),\n        dst_ptr_(dst_ptr),\n        c_size_(c_size),\n        id_size_(id_size),\n        ih_size_(ih_size),\n        iw_size_(iw_size),\n        od_size_(od_size),\n        oh_size_(oh_size),\n        ow_size_(ow_size) {}\n  virtual ~ColBufWriter() = default;\n  virtual void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0;\n  virtual void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0;\n  virtual void InvalidDFunc() = 0;\n  virtual void InvalidHFunc() = 0;\n  virtual void InvalidWFunc() = 0;\n  virtual void NextImCSize() = 0;\n\n protected:\n  const T* src_ptr_;\n  T* dst_ptr_;\n  int64_t c_size_;\n  int64_t id_size_;\n  int64_t ih_size_;\n  int64_t iw_size_;\n  int64_t od_size_;\n  int64_t oh_size_;\n  int64_t ow_size_;\n};\n\ntemplate<typename T>\nclass Im2ColWriter final : public ColBufWriter<T> {\n public:\n  Im2ColWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size,\n               int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size)\n      : ColBufWriter<T>::ColBufWriter(src_ptr, dst_ptr, c_size, id_size, ih_size, iw_size, od_size,\n                                      oh_size, ow_size) {}\n  ~Im2ColWriter() = default;\n  void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override {\n    *(this->dst_ptr_++) =\n        this->src_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw * this->iw_size_ + c];\n  }\n  void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override {\n    *(this->dst_ptr_++) = this->src_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw];\n  }\n  void InvalidDFunc() override {\n    FOR_RANGE(int64_t, i, 0, this->od_size_) { *(this->dst_ptr_++) = 0; }\n  }\n  void InvalidHFunc() override {\n    FOR_RANGE(int64_t, i, 0, this->oh_size_) { *(this->dst_ptr_++) = 0; }\n  }\n  void InvalidWFunc() override {\n    FOR_RANGE(int64_t, i, 0, this->ow_size_) { *(this->dst_ptr_++) = 0; }\n  }\n  void NextImCSize() override { this->src_ptr_ += this->c_size_; }\n};\n\ntemplate<typename 
T>\nclass Col2ImWriter final : public ColBufWriter<T> {\n public:\n  Col2ImWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size,\n               int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size)\n      : ColBufWriter<T>::ColBufWriter(src_ptr, dst_ptr, c_size, id_size, ih_size, iw_size, od_size,\n                                      oh_size, ow_size) {}\n  ~Col2ImWriter() = default;\n  void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override {\n    this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw * this->iw_size_ + c] +=\n        *(this->src_ptr_++);\n  }\n  void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override {\n    this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw] += *(this->src_ptr_++);\n  }\n  void InvalidDFunc() override { this->src_ptr_ += this->od_size_; }\n  void InvalidHFunc() override { this->src_ptr_ += this->oh_size_; }\n  void InvalidWFunc() override { this->src_ptr_ += this->ow_size_; }\n  void NextImCSize() override { this->dst_ptr_ += this->c_size_; }\n};\n\ntemplate<typename T>\nusing DHWValidFunc = void (ColBufWriter<T>::*)(int64_t c, int64_t kd, int64_t kh, int64_t kw);\n\ntemplate<typename T>\nclass ColBufUtil final {\n public:\n  ColBufUtil(const ShapeView& in_shape, const ShapeView& out_shape, int32_t dhw_offset,\n             const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before)\n      : strides_(strides), dilation_rate_(dilation_rate), padding_before_(padding_before) {\n    id_num_ = in_shape.At(dhw_offset);\n    ih_num_ = in_shape.At(dhw_offset + 1);\n    iw_num_ = in_shape.At(dhw_offset + 2);\n    od_num_ = out_shape.At(dhw_offset);\n    oh_num_ = out_shape.At(dhw_offset + 1);\n    ow_num_ = out_shape.At(dhw_offset + 2);\n    if (dhw_offset == 2) {\n      dhw_valid_func_ = &ColBufWriter<T>::CDHWWrite;\n    } else {\n      dhw_valid_func_ = &ColBufWriter<T>::DHWCWrite;\n    }\n  }\n  void operator()(ColBufWriter<T>* col_buf_writer, int64_t c, int64_t kd, int64_t kh, int64_t kw) {\n    int64_t id = kd * dilation_rate_[0] - padding_before_[0];\n    FOR_RANGE(int64_t, od, 0, od_num_) {\n      if (id < 0 || id >= id_num_) {\n        col_buf_writer->InvalidDFunc();\n      } else {\n        int64_t ih = kh * dilation_rate_[1] - padding_before_[1];\n        FOR_RANGE(int64_t, oh, 0, oh_num_) {\n          if (ih < 0 || ih >= ih_num_) {\n            col_buf_writer->InvalidHFunc();\n          } else {\n            int64_t iw = kw * dilation_rate_[2] - padding_before_[2];\n            FOR_RANGE(int64_t, ow, 0, ow_num_) {\n              if (iw < 0 || iw >= iw_num_) {\n                col_buf_writer->InvalidWFunc();\n              } else {\n                (col_buf_writer->*dhw_valid_func_)(c, id, ih, iw);\n              }\n              iw += strides_[2];\n            }\n          }\n          ih += strides_[1];\n        }\n      }\n      id += strides_[0];\n    }\n  }\n\n private:\n  int64_t id_num_;\n  int64_t ih_num_;\n  int64_t iw_num_;\n  int64_t od_num_;\n  int64_t oh_num_;\n  int64_t ow_num_;\n  const int32_t* strides_;\n  const int32_t* dilation_rate_;\n  const int32_t* padding_before_;\n  DHWValidFunc<T> dhw_valid_func_;\n};\n\ntemplate<typename T>\nstruct ConvKernelUtil final {\n public:\n  static void NCDHWIm2Col(const T* in_dptr, const ShapeView& in_shape,\n                          const ShapeView& weight_shape, const ShapeView& out_shape,\n                          const int32_t* strides, const int32_t* dilation_rate,\n       
                   const int32_t* padding_before, T* col_buf_ptr) {\n    ColBufUtil<T> col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before);\n    Im2ColWriter<T> col_buf_writer(in_dptr, col_buf_ptr, in_shape.Count(2), in_shape.Count(3),\n                                   in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1);\n    DoNCDWHFunc(weight_shape, col_buf_util, &col_buf_writer);\n  }\n\n  static void NDHWCIm2Col(const T* in_dptr, const ShapeView& in_shape,\n                          const ShapeView& weight_shape, const ShapeView& out_shape,\n                          const int32_t* strides, const int32_t* dilation_rate,\n                          const int32_t* padding_before, T* col_buf_ptr) {\n    ColBufUtil<T> col_buf_util(in_shape, out_shape, 1, strides, dilation_rate, padding_before);\n    Im2ColWriter<T> col_buf_writer(in_dptr, col_buf_ptr, in_shape.Count(2), in_shape.Count(2),\n                                   in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4),\n                                   out_shape.Count(3, 4), 1);\n    DoNDWHCFunc(weight_shape, col_buf_util, &col_buf_writer);\n  }\n\n  static void NCDHWCol2Im(const T* col_buf_ptr, const ShapeView& in_shape,\n                          const ShapeView& weight_shape, const ShapeView& out_shape,\n                          const int32_t* strides, const int32_t* dilation_rate,\n                          const int32_t* padding_before, T* in_diff_ptr) {\n    ColBufUtil<T> col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before);\n    Col2ImWriter<T> col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(3),\n                                   in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1);\n    DoNCDWHFunc(weight_shape, col_buf_util, &col_buf_writer);\n  }\n\n  static void NDHWCCol2Im(const T* col_buf_ptr, const ShapeView& in_shape,\n                          const ShapeView& weight_shape, const ShapeView& out_shape,\n                          const int32_t* strides, const int32_t* dilation_rate,\n                          const int32_t* padding_before, T* in_diff_ptr) {\n    ColBufUtil<T> col_buf_util(in_shape, out_shape, 1, strides, dilation_rate, padding_before);\n    Col2ImWriter<T> col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(2),\n                                   in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4),\n                                   out_shape.Count(3, 4), 1);\n    DoNDWHCFunc(weight_shape, col_buf_util, &col_buf_writer);\n  }\n\n private:\n  static void DoNCDWHFunc(const ShapeView& weight_shape, ColBufUtil<T>& col_buf_util,\n                          ColBufWriter<T>* col_buf_writer) {\n    for (int64_t c = 0; c != weight_shape.At(1); col_buf_writer->NextImCSize(), ++c) {\n      for (int64_t kd = 0; kd != weight_shape.At(2); ++kd) {\n        for (int64_t kh = 0; kh != weight_shape.At(3); ++kh) {\n          for (int64_t kw = 0; kw != weight_shape.At(4); ++kw) {\n            col_buf_util(col_buf_writer, c, kd, kh, kw);\n          }\n        }\n      }\n    }\n  }\n\n  static void DoNDWHCFunc(const ShapeView& weight_shape, ColBufUtil<T>& col_buf_util,\n                          ColBufWriter<T>* col_buf_writer) {\n    for (int64_t kd = 0; kd != weight_shape.At(1); ++kd) {\n      for (int64_t kh = 0; kh != weight_shape.At(2); ++kh) {\n        for (int64_t kw = 0; kw != weight_shape.At(3); ++kw) {\n          for (int64_t c = 0; c != weight_shape.At(4); ++c) {\n      
      col_buf_util(col_buf_writer, c, kd, kh, kw);\n          }\n        }\n      }\n    }\n  }\n};\n\ntemplate<typename T>\nstruct ConvOpKernelCache final : public user_op::OpKernelCache {\n  Im2ColFunc<T> im2col_func_ = nullptr;\n  Col2ImFunc<T> col2im_func_ = nullptr;\n\n  Shape in_5d_shape_;\n  Shape out_5d_shape_;\n  Shape weight_5d_shape_;\n\n  std::vector<int32_t> strides_3d_;\n  std::vector<int32_t> dilation_rate_3d_;\n  std::vector<int32_t> padding_before_3d_;\n\n  bool is_out_diff_need_trans_ = false;\n\n  int32_t idx_offset_{};\n  bool is_dynamic_{};\n};\n\ntemplate<typename T>\nstd::shared_ptr<ConvOpKernelCache<T>> CreateConvOpKernelCache(user_op::KernelCacheContext* ctx,\n                                                              const std::string& in_name,\n                                                              const std::string& out_name,\n                                                              const std::string& weight_name) {\n  const auto& data_format = ctx->Attr<std::string>(\"data_format\");\n\n  std::shared_ptr<ConvOpKernelCache<T>> cache(new ConvOpKernelCache<T>());\n  if (data_format == \"channels_first\") {\n    cache->im2col_func_ = ConvKernelUtil<T>::NCDHWIm2Col;\n    cache->col2im_func_ = ConvKernelUtil<T>::NCDHWCol2Im;\n    cache->is_out_diff_need_trans_ = false;\n    cache->idx_offset_ = 2;\n  } else {\n    cache->im2col_func_ = ConvKernelUtil<T>::NDHWCIm2Col;\n    cache->col2im_func_ = ConvKernelUtil<T>::NDHWCCol2Im;\n    cache->is_out_diff_need_trans_ = true;\n    cache->idx_offset_ = 1;\n  }\n\n  auto Gen5DShape = [](const Shape& shape, int32_t idx_offset) -> Shape {\n    DimVector ret_vec(shape.dim_vec());\n    int32_t ndims = ret_vec.size() - 2;\n    ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1);\n    return Shape(ret_vec);\n  };\n  const auto* in_tensor = ctx->TensorDesc4ArgNameAndIndex(in_name, 0);\n  const auto& in_shape = in_tensor->shape();\n  cache->in_5d_shape_ = Gen5DShape(in_shape, cache->idx_offset_);\n  cache->out_5d_shape_ =\n      Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(out_name, 0)->shape(), cache->idx_offset_);\n  cache->weight_5d_shape_ =\n      Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(weight_name, 0)->shape(), cache->idx_offset_);\n\n  auto Gen3DVec = [](const std::vector<int32_t>& origin_vec) -> std::vector<int32_t> {\n    std::vector<int32_t> ret_vec = origin_vec;\n    ret_vec.insert(ret_vec.begin(), 3 - ret_vec.size(), 1);\n    return ret_vec;\n  };\n  cache->strides_3d_ = Gen3DVec(ctx->Attr<std::vector<int32_t>>(\"strides\"));\n  cache->dilation_rate_3d_ = Gen3DVec(ctx->Attr<std::vector<int32_t>>(\"dilation_rate\"));\n  cache->is_dynamic_ = ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->is_dynamic();\n  const auto& padding_before = ctx->Attr<std::vector<int32_t>>(\"padding_before\");\n  FOR_RANGE(uint8_t, dim, 0, 3) {\n    int64_t index = static_cast<int64_t>(dim) - (3 - padding_before.size());\n    if (index < 0) {\n      cache->padding_before_3d_.emplace_back(0);\n    } else {\n      cache->padding_before_3d_.emplace_back(padding_before.at(index));\n    }\n  }\n\n  return cache;\n}\n\ntemplate<typename T>\nvoid InitBiasMulBuf(T* dptr, int64_t num) {\n  for (int64_t i = 0; i < num; ++i) { dptr[i] = 1; }\n}\n\ntemplate<typename T, size_t NDims>\nclass ConvCpuKernel final : public user_op::OpKernel {\n public:\n  ConvCpuKernel() = default;\n  ~ConvCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  std::shared_ptr<user_op::OpKernelCache> 
InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateConvOpKernelCache<T>(ctx, \"in\", \"out\", \"weight\");\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const auto* conv_cache = dynamic_cast<const ConvOpKernelCache<T>*>(cache);\n    CHECK_NOTNULL(conv_cache);\n\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    T* col_buf_dptr = tmp_buffer->mut_dptr<T>();\n\n    bool is_bias_mul_inited = false;\n\n    const auto& data_format = ctx->Attr<std::string>(\"data_format\");\n    std::unique_ptr<ep::primitive::Matmul> matmul;\n    if (data_format == \"channels_first\") {\n      matmul = NewChannelsFirstMatmulPrimitive(ctx);\n    } else {\n      matmul = NewChannelsLastMatmulPrimitive(ctx);\n    }\n    CHECK(matmul);\n\n    float beta = 0;\n    if (ctx->has_input(\"_add_to_output\", 0)) {\n      const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n      CHECK_EQ(add_to_output->data_type(), out->data_type());\n      CHECK_EQ(add_to_output->shape_view(), out->shape_view());\n      Memcpy<DeviceType::kCPU>(\n          ctx->stream(), out->mut_dptr(), add_to_output->dptr(),\n          add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));\n      beta = 1;\n    }\n\n    for (int64_t i = 0; i < in->shape_view().At(0); ++i) {\n      conv_cache->im2col_func_(GetImgDptr<T>(in, i), ShapeView(conv_cache->in_5d_shape_),\n                               ShapeView(conv_cache->weight_5d_shape_),\n                               ShapeView(conv_cache->out_5d_shape_), conv_cache->strides_3d_.data(),\n                               conv_cache->dilation_rate_3d_.data(),\n                               conv_cache->padding_before_3d_.data(), col_buf_dptr);\n\n      // channels first: out = weight * col_buf\n      // channels last:  out = (weight * col_buf)(T)\n      int32_t idx_offset = conv_cache->idx_offset_;\n      matmul->Launch(ctx->stream(),\n                     conv_cache->weight_5d_shape_.At(0),                           // filter\n                     conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3),  // od * oh * ow\n                     conv_cache->weight_5d_shape_.Count(1),  // ci * kd * kh * kw\n                     static_cast<T>(1), weight->dptr<T>(), col_buf_dptr, beta,\n                     GetImgMutDptr<T>(out, i));\n\n      const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex(\"bias\", 0);\n      if (bias != nullptr) {\n        int64_t num_of_col_buf =\n            CalcElemNumOfColBuf(out->shape_view(), weight->shape_view(), idx_offset);\n        int64_t num_of_bias_mul =\n            (tmp_buffer->shape_view().elem_cnt() - num_of_col_buf * sizeof(T)) / sizeof(T);\n        CHECK_GT(num_of_bias_mul, 0);\n        T* bias_mul_dptr = col_buf_dptr + num_of_col_buf;\n        if (!is_bias_mul_inited) {\n          InitBiasMulBuf(bias_mul_dptr, num_of_bias_mul);\n          is_bias_mul_inited = true;\n        }\n\n        // channels first:  out += bias * bias_mul\n        // channels last:   out += (bias * bias_mul)(T)\n        matmul->Launch(ctx->stream(),\n                       
conv_cache->weight_5d_shape_.At(0),                           // filter\n                       conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3),  // od * oh * ow\n                       1,                                                            // 1\n                       static_cast<T>(1), bias->dptr<T>(), bias_mul_dptr, static_cast<T>(1),\n                       GetImgMutDptr<T>(out, i));\n      }\n    }\n  }\n};\n\n#define REGISTER_CONV_KERNEL(op_name, dtype, ndims)                                         \\\n  REGISTER_USER_KERNEL(#op_name)                                                            \\\n      .SetCreateFn<ConvCpuKernel<dtype, ndims>>()                                           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                       \\\n                       && (user_op::HobAttr<int32_t>(\"groups\") == 1)                        \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value)      \\\n                       && ChannelsFirstMatmulPrimitiveExists()                              \\\n                       && ChannelsLastMatmulPrimitiveExists())                              \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                         \\\n        size_t tmp_buffer_size = 0;                                                         \\\n        const auto& out_shape = ctx->OutputTensorDesc(\"out\", 0).shape();                    \\\n        const auto& weight_shape = ctx->InputTensorDesc(\"weight\", 0).shape();               \\\n                                                                                            \\\n        int64_t idx_offset = IdxOffset(ctx->Attr<std::string>(\"data_format\"));              \\\n        tmp_buffer_size +=                                                                  \\\n            CalcElemNumOfColBuf(out_shape, weight_shape, idx_offset) * sizeof(dtype);       \\\n        bool has_bias = ctx->has_input(\"bias\", 0);                                          \\\n        if (has_bias) {                                                                     \\\n          int64_t bias_mul_cnt = 1;                                                         \\\n          for (int i = 0; i < ndims; ++i) { bias_mul_cnt *= out_shape.At(idx_offset + i); } \\\n          tmp_buffer_size += bias_mul_cnt * sizeof(dtype);                                  \\\n        }                                                                                   \\\n        return tmp_buffer_size;                                                             \\\n      })                                                                                    \\\n      .SetInplaceProposalFn(                                                                \\\n          [](const user_op::InferContext& ctx,                                              \\\n             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {        \\\n            if (ctx.has_input(\"_add_to_output\", 0)) {                                       \\\n              OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"_add_to_output\", 0, true)); \\\n            }                                                                               \\\n            return Maybe<void>::Ok();                                                       \\\n          });\n\nREGISTER_CONV_KERNEL(conv1d, float, 1);\nREGISTER_CONV_KERNEL(conv2d, float, 
2);\nREGISTER_CONV_KERNEL(conv3d, float, 3);\nREGISTER_CONV_KERNEL(conv1d, double, 1);\nREGISTER_CONV_KERNEL(conv2d, double, 2);\nREGISTER_CONV_KERNEL(conv3d, double, 3);\n\ntemplate<typename T>\nclass ConvDataGradCpuKernel final : public user_op::OpKernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ConvDataGradCpuKernel);\n  ConvDataGradCpuKernel() = default;\n  ~ConvDataGradCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateConvOpKernelCache<T>(ctx, \"dx\", \"dy\", \"filter\");\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const auto* conv_cache = dynamic_cast<const ConvOpKernelCache<T>*>(cache);\n    CHECK_NOTNULL(conv_cache);\n\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* filter = ctx->Tensor4ArgNameAndIndex(\"filter\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    Memset<DeviceType::kCPU>(ctx->stream(), dx->mut_dptr<T>(), 0,\n                             dx->shape_view().elem_cnt() * sizeof(T));\n\n    std::unique_ptr<ep::primitive::Matmul> matmul;\n    if (conv_cache->is_out_diff_need_trans_) {\n      matmul = NewConvDataGradTransATransBMatmulPrimitive(ctx);\n    } else {\n      matmul = NewConvDataGradTransANoTransBMatmulPrimitive(ctx);\n    }\n    CHECK(matmul);\n\n    int32_t idx_offset = conv_cache->idx_offset_;\n    FOR_RANGE(int64_t, i, 0, dy->shape_view().At(0)) {\n      // channels first:  col_buf' = weight(T) * out[i]'\n      // channels last :  col_buf' = weight(T) * out[i]'(T)\n      matmul->Launch(ctx->stream(),\n                     conv_cache->weight_5d_shape_.Count(1),  //  ci * kd * kh * kw\n                     conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3),  //  od * oh * ow\n                     conv_cache->weight_5d_shape_.At(0),                           //  filter\n                     static_cast<T>(1), filter->dptr<T>(), GetImgDptr<T>(dy, i), static_cast<T>(0),\n                     col_buf->mut_dptr<T>());\n\n      // in' = col2im(col_buf')\n      conv_cache->col2im_func_(col_buf->dptr<T>(), ShapeView(conv_cache->in_5d_shape_),\n                               ShapeView(conv_cache->weight_5d_shape_),\n                               ShapeView(conv_cache->out_5d_shape_), conv_cache->strides_3d_.data(),\n                               conv_cache->dilation_rate_3d_.data(),\n                               conv_cache->padding_before_3d_.data(), GetImgMutDptr<T>(dx, i));\n    }\n    if (ctx->has_input(\"_add_to_output\", 0)) {\n      const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n      CHECK_EQ(add_to_output->data_type(), dx->data_type());\n      CHECK_EQ(add_to_output->shape_view(), dx->shape_view());\n      std::unique_ptr<ep::primitive::Add> primitive =\n          ep::primitive::NewPrimitive<ep::primitive::AddFactory>(DeviceType::kCPU,\n                                                                 add_to_output->data_type());\n      CHECK(primitive);\n      primitive->Launch(ctx->stream(), dx->dptr<T>(), add_to_output->dptr<T>(), dx->mut_dptr<T>(),\n                        add_to_output->shape_view().elem_cnt());\n    }\n  
}\n};\n\n#define REGISTER_CONV_DATA_GRAD_KERNEL(op_name, dtype)                                     \\\n  REGISTER_USER_KERNEL(#op_name)                                                           \\\n      .SetCreateFn<ConvDataGradCpuKernel<dtype>>()                                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                      \\\n                       && (user_op::HobAttr<int32_t>(\"groups\") == 1)                       \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value)     \\\n                       && ConvDataGradTransATransBMatmulPrimitiveExists()                  \\\n                       && ConvDataGradTransANoTransBMatmulPrimitiveExists())               \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                        \\\n        size_t tmp_buffer_size = 0;                                                        \\\n        const auto& out_diff_shape = ctx->InputTensorDesc(\"dy\", 0).shape();                \\\n        const auto& weight_shape = ctx->InputTensorDesc(\"filter\", 0).shape();              \\\n                                                                                           \\\n        int64_t idx_offset = IdxOffset(ctx->Attr<std::string>(\"data_format\"));             \\\n        tmp_buffer_size +=                                                                 \\\n            CalcElemNumOfColBuf(out_diff_shape, weight_shape, idx_offset) * sizeof(dtype); \\\n        return tmp_buffer_size;                                                            \\\n      })\n\nREGISTER_CONV_DATA_GRAD_KERNEL(conv_data_grad, float);\nREGISTER_CONV_DATA_GRAD_KERNEL(conv_data_grad, double);\n\ntemplate<typename T>\nclass ConvFilterGradCpuKernel final : public user_op::OpKernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ConvFilterGradCpuKernel);\n  ConvFilterGradCpuKernel() = default;\n  ~ConvFilterGradCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateConvOpKernelCache<T>(ctx, \"x\", \"dy\", \"filter_diff\");\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const auto* conv_cache = dynamic_cast<const ConvOpKernelCache<T>*>(cache);\n    CHECK_NOTNULL(conv_cache);\n\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* filter_diff = ctx->Tensor4ArgNameAndIndex(\"filter_diff\", 0);\n    user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    Memset<DeviceType::kCPU>(ctx->stream(), filter_diff->mut_dptr<T>(), 0,\n                             filter_diff->shape_view().elem_cnt() * sizeof(T));\n    std::unique_ptr<ep::primitive::Matmul> matmul;\n    if (conv_cache->is_out_diff_need_trans_) {\n      matmul = NewConvWeightGradTransATransBMatmulPrimitive(ctx);\n    } else {\n      matmul = NewConvWeightGradNoTransATransBMatmulPrimitive(ctx);\n    }\n    CHECK(matmul);\n\n    int32_t idx_offset = conv_cache->idx_offset_;\n    FOR_RANGE(int64_t, i, 0, dy->shape_view().At(0)) {\n      conv_cache->im2col_func_(GetImgDptr<T>(x, i), ShapeView(conv_cache->in_5d_shape_),\n                               
ShapeView(conv_cache->weight_5d_shape_),\n                               ShapeView(conv_cache->out_5d_shape_), conv_cache->strides_3d_.data(),\n                               conv_cache->dilation_rate_3d_.data(),\n                               conv_cache->padding_before_3d_.data(), col_buf->mut_dptr<T>());\n\n      // channels first:  weight' += out[i]' * col_buf(T)\n      // channels last :  weight' += out[i]'(T) * col_buf(T)\n      matmul->Launch(ctx->stream(),\n                     conv_cache->weight_5d_shape_.At(0),     //  filter\n                     conv_cache->weight_5d_shape_.Count(1),  //  ci * kd * kh * kw\n                     conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3),  //  od * oh * ow\n                     static_cast<T>(1), GetImgDptr<T>(dy, i), col_buf->dptr<T>(), static_cast<T>(1),\n                     filter_diff->mut_dptr<T>());\n    }\n  }\n};\n\n#define REGISTER_CONV_FILTER_GRAD_KERNEL(op_name, dtype)                                        \\\n  REGISTER_USER_KERNEL(#op_name)                                                                \\\n      .SetCreateFn<ConvFilterGradCpuKernel<dtype>>()                                            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                           \\\n                       && (user_op::HobAttr<int32_t>(\"groups\") == 1)                            \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value)          \\\n                       && ConvWeightGradTransATransBMatmulPrimitiveExists()                     \\\n                       && ConvWeightGradNoTransATransBMatmulPrimitiveExists())                  \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                             \\\n        size_t tmp_buffer_size = 0;                                                             \\\n        const auto& out_diff_shape = ctx->InputTensorDesc(\"dy\", 0).shape();                     \\\n        const auto& weight_diff_shape = ctx->OutputTensorDesc(\"filter_diff\", 0).shape();        \\\n                                                                                                \\\n        int64_t idx_offset = IdxOffset(ctx->Attr<std::string>(\"data_format\"));                  \\\n        tmp_buffer_size +=                                                                      \\\n            CalcElemNumOfColBuf(out_diff_shape, weight_diff_shape, idx_offset) * sizeof(dtype); \\\n        return tmp_buffer_size;                                                                 \\\n      })\n\nREGISTER_CONV_FILTER_GRAD_KERNEL(conv_filter_grad, float);\nREGISTER_CONV_FILTER_GRAD_KERNEL(conv_filter_grad, double);\n\ntemplate<typename T>\nclass ConvBiasGradCpuKernel final : public user_op::OpKernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ConvBiasGradCpuKernel);\n  ConvBiasGradCpuKernel() = default;\n  ~ConvBiasGradCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* bias_diff = ctx->Tensor4ArgNameAndIndex(\"bias_diff\", 0);\n    user_op::Tensor* bias_mul_buf = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    InitBiasMulBuf(bias_mul_buf->mut_dptr<T>(), bias_mul_buf->shape_view().elem_cnt() / sizeof(T));\n    Memset<DeviceType::kCPU>(ctx->stream(), bias_diff->mut_dptr<T>(), 0,\n         
                    bias_diff->shape_view().elem_cnt() * sizeof(T));\n\n    const auto& data_format = ctx->Attr<std::string>(\"data_format\");\n    int32_t idx_offset;\n    bool is_out_diff_need_trans = false;\n    int32_t filter;\n    if (data_format == \"channels_first\") {\n      idx_offset = 2;\n      is_out_diff_need_trans = false;\n      filter = dy->shape_view().At(1);\n    } else {\n      idx_offset = 1;\n      is_out_diff_need_trans = true;\n      filter = dy->shape_view().At(dy->shape_view().NumAxes() - 1);\n    }\n    std::unique_ptr<ep::primitive::Matmul> matmul;\n    if (is_out_diff_need_trans) {\n      matmul = NewConvBiasGradTransANoTransBMatmulPrimitive(ctx);\n    } else {\n      matmul = NewConvBiasGradNoTransANoTransBMatmulPrimitive(ctx);\n    }\n    CHECK(matmul);\n\n    int ndims = dy->shape_view().NumAxes() - 2;\n    FOR_RANGE(int64_t, i, 0, dy->shape_view().At(0)) {\n      // channels first:  bias' += out' * bias_mul\n      // channels last:   bias' += out'(T) * bias_mul\n      matmul->Launch(ctx->stream(),\n                     filter,                                                  //  filter\n                     1,                                                       //  1\n                     dy->shape_view().Count(idx_offset, idx_offset + ndims),  //  od * oh * ow\n                     static_cast<T>(1), GetImgDptr<T>(dy, i), bias_mul_buf->dptr<T>(),\n                     static_cast<T>(1), bias_diff->mut_dptr<T>());\n    }\n  }\n};\n\n#define REGISTER_CONV_BIAS_GRAD_KERNEL(op_name, dtype)                                         \\\n  REGISTER_USER_KERNEL(#op_name)                                                               \\\n      .SetCreateFn<ConvBiasGradCpuKernel<dtype>>()                                             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                          \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value)         \\\n                       && ConvBiasGradTransANoTransBMatmulPrimitiveExists()                    \\\n                       && ConvBiasGradNoTransANoTransBMatmulPrimitiveExists())                 \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                            \\\n        const auto& out_diff_shape = ctx->InputTensorDesc(\"dy\", 0).shape();                    \\\n        const int ndims = out_diff_shape.NumAxes() - 2;                                        \\\n        int64_t idx_offset = IdxOffset(ctx->Attr<std::string>(\"data_format\"));                 \\\n        int64_t bias_mul_cnt = 1;                                                              \\\n        for (int i = 0; i < ndims; ++i) { bias_mul_cnt *= out_diff_shape.At(idx_offset + i); } \\\n        return bias_mul_cnt * sizeof(dtype);                                                   \\\n      })\n\nREGISTER_CONV_BIAS_GRAD_KERNEL(conv_bias_grad, float);\nREGISTER_CONV_BIAS_GRAD_KERNEL(conv_bias_grad, double);\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/convert_memory_format_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/convert_memory_format_util.h\"\n\nnamespace oneflow {\n\nclass ConvertMemoryFormatKernel final : public user_op::OpKernel {\n public:\n  ConvertMemoryFormatKernel() = default;\n  ~ConvertMemoryFormatKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    ConvertMemoryFormat(ctx->stream(), in->shape_view().NumAxes(), in->shape_view().data(),\n                        in->data_type(), in->dptr(), out->mut_dptr(), in->memory_format(),\n                        out->memory_format());\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"convert_memory_format\").SetCreateFn<ConvertMemoryFormatKernel>();\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/convert_memory_format_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/convert_memory_format_util.h\"\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\nstd::unique_ptr<ep::primitive::Permute> NewPermutePrimitive(DeviceType device_type,\n                                                            const int& num_dims) {\n  return ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(device_type, num_dims);\n}\n\nstd::unique_ptr<ep::primitive::Memcpy> NewMemcpyPrimitive(DeviceType device_type) {\n  return ep::primitive::NewPrimitive<ep::primitive::MemcpyFactory>(\n      device_type, ep::primitive::MemcpyKind::kDtoD);\n}\n\nvoid ComputeIdentity(ep::Stream* stream, int ndim, const int64_t* shape, DataType data_type,\n                     const void* in, void* out) {\n  size_t count = 1;\n  for (int i = 0; i < ndim; ++i) { count *= shape[i]; }\n  auto memcpy_primitive = NewMemcpyPrimitive(stream->device_type());\n  CHECK(memcpy_primitive) << \"Can not create Memcpy primitive for device type \"\n                          << stream->device_type();\n  memcpy_primitive->Launch(stream, out, in, count * GetSizeOfDataType(data_type));\n}\n\nvoid ComputeContiguousToChannelsLast(ep::Stream* stream, int ndim, const int64_t* shape,\n                                     DataType data_type, const void* in, void* out) {\n  if (ndim <= 2) { return ComputeIdentity(stream, ndim, shape, data_type, in, out); }\n\n  std::vector<int32_t> permute(ndim);\n  permute[0] = 0;\n  permute[ndim - 1] = 1;\n  for (int i = 0; i < ndim - 2; ++i) { permute[i + 1] = i + 2; }\n  auto primitive = NewPermutePrimitive(stream->device_type(), ndim);\n  CHECK_NOTNULL_OR_THROW(primitive);\n  primitive->Launch(stream, data_type, ndim, shape, in, permute.data(), out);\n}\n\nvoid ComputeChannelsLastToContiguous(ep::Stream* stream, int ndim, const int64_t* shape,\n                                     DataType data_type, const void* in, void* out) {\n  if (ndim <= 2) { return ComputeIdentity(stream, ndim, shape, data_type, in, out); }\n\n  std::vector<int32_t> permute(ndim);\n  permute[0] = 0;\n  permute[1] = ndim - 1;\n  for (int i = 0; i < ndim - 2; ++i) { permute[i + 2] = i + 1; }\n  auto primitive = NewPermutePrimitive(stream->device_type(), ndim);\n  CHECK_NOTNULL_OR_THROW(primitive);\n  primitive->Launch(stream, data_type, ndim, shape, in, permute.data(), out);\n}\n\nusing ConvertMemoryFormatFunc =\n    std::function<void(ep::Stream*, int, const int64_t*, DataType, const void*, void*)>;\n\nConvertMemoryFormatFunc convert_funcs[kMemoryFormatCount][kMemoryFormatCount] = {\n    /*kContiguous->other*/ {ComputeIdentity, ComputeContiguousToChannelsLast},\n    /*kChannelsLast->other*/ {ComputeChannelsLastToContiguous, 
ComputeIdentity},\n};\n\nvoid ConvertMemoryFormat(ep::Stream* stream, const user_op::Tensor* in, user_op::Tensor* out,\n                         MemoryFormat in_memory_format, MemoryFormat out_memory_format) {\n  auto convert_func = convert_funcs[in_memory_format][out_memory_format];\n  convert_func(stream, in->shape_view().size(), in->shape_view().data(), in->data_type(),\n               in->dptr(), out->mut_dptr());\n}\n\nvoid ConvertMemoryFormat(ep::Stream* stream, int ndim, const int64_t* shape, DataType data_type,\n                         const void* in, void* out, MemoryFormat in_memory_format,\n                         MemoryFormat out_memory_format) {\n  auto convert_func = convert_funcs[in_memory_format][out_memory_format];\n  convert_func(stream, ndim, shape, data_type, in, out);\n}\n\nvoid ConvertMemoryFormat(ep::Stream* stream, const ShapeView& shape, DataType data_type,\n                         const void* in, void* out, MemoryFormat in_memory_format,\n                         MemoryFormat out_memory_format) {\n  ConvertMemoryFormat(stream, shape.size(), shape.data(), data_type, in, out, in_memory_format,\n                      out_memory_format);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/convert_memory_format_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nvoid ConvertMemoryFormat(ep::Stream* stream, const user_op::Tensor* in, user_op::Tensor* out,\n                         MemoryFormat in_memory_format, MemoryFormat out_memory_format);\n\nvoid ConvertMemoryFormat(ep::Stream* stream, int ndim, const int64_t* shape, DataType data_type,\n                         const void* in, void* out, MemoryFormat in_memory_format,\n                         MemoryFormat out_memory_format);\n\nvoid ConvertMemoryFormat(ep::Stream* stream, const ShapeView& shape, DataType data_type,\n                         const void* in, void* out, MemoryFormat in_memory_format,\n                         MemoryFormat out_memory_format);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/copy_data_content_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass CopyDataContentKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  CopyDataContentKernel() = default;\n  ~CopyDataContentKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t elem_cnt = in->shape_view().elem_cnt();\n    // For 0-size tensor, we don't need to copy data, but we must\n    // fill output tensor with Scalar(0) because during the backward propogation, this kernel will\n    // also be used.\n    if (elem_cnt == 0) {\n      const int64_t out_elem_cnt = out->shape_view().elem_cnt();\n      CHECK_GE(out_elem_cnt, 0);\n      if (out_elem_cnt == 0) { return; }\n      std::unique_ptr<ep::primitive::Fill> fill =\n          ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->device_type(),\n                                                                  out->data_type());\n      CHECK(fill);\n      fill->Launch(ctx->stream(), out->mut_dptr(), Scalar(0), out_elem_cnt);\n      return;\n    }\n    CHECK_EQ(out->shape_view().elem_cnt(), elem_cnt);\n    CHECK_EQ(in->data_type(), out->data_type());\n    if (elem_cnt > 0) {\n      std::unique_ptr<ep::primitive::Memcpy> primitive =\n          ep::primitive::NewPrimitive<ep::primitive::MemcpyFactory>(\n              ctx->stream()->device_type(), ep::primitive::MemcpyKind::kDtoD);\n      CHECK(primitive) << \"Can not create Memcpy primitive for device type \"\n                       << ctx->stream()->device_type();\n      primitive->Launch(ctx->stream(), out->mut_dptr(), in->dptr(),\n                        elem_cnt * GetSizeOfDataType(in->data_type()));\n    }\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_COPY_DATA_CONTENT_KERNEL(op_type_name)                              \\\n  REGISTER_USER_KERNEL(op_type_name)                                                 \\\n      .SetCreateFn<CopyDataContentKernel>()                                          \\\n      .SetInplaceProposalFn(                                                         \\\n          [](const user_op::InferContext&,                                           \\\n             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> { \\\n            OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"in\", 0, false));       \\\n            return Maybe<void>::Ok();                                                \\\n          
});\n\nREGISTER_COPY_DATA_CONTENT_KERNEL(\"squeeze\");\nREGISTER_COPY_DATA_CONTENT_KERNEL(\"reshape_like\");\nREGISTER_COPY_DATA_CONTENT_KERNEL(\"expand_dims\");\nREGISTER_COPY_DATA_CONTENT_KERNEL(\"reshape\");\nREGISTER_COPY_DATA_CONTENT_KERNEL(\"amp_white_identity\");\nREGISTER_COPY_DATA_CONTENT_KERNEL(\"amp_black_identity\");\nREGISTER_COPY_DATA_CONTENT_KERNEL(\"identity\");\nREGISTER_COPY_DATA_CONTENT_KERNEL(\"identity_buffer\");\nREGISTER_COPY_DATA_CONTENT_KERNEL(\"parallel_cast\");\nREGISTER_COPY_DATA_CONTENT_KERNEL(\"hierarchical_parallel_cast\");\nREGISTER_COPY_DATA_CONTENT_KERNEL(\"hierarchical_parallel_cast_like\");\nREGISTER_COPY_DATA_CONTENT_KERNEL(\"pinned_identity\");\nREGISTER_COPY_DATA_CONTENT_KERNEL(\"depend\");\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/copy_hd_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass CopyHdKernel final : public user_op::OpKernel {\n public:\n  CopyHdKernel() = default;\n  ~CopyHdKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    CHECK(in) << \"input of copy not found\";\n    const ShapeView& in_shape = in->shape_view();\n    if (in_shape.elem_cnt() == 0) {\n      // 0 shape tensor do not need copy\n    } else {\n      const DataType in_data_type = in->data_type();\n      user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n      CHECK(out) << \"output of copy not found, op: \" << ctx->op_name();\n      CHECK_EQ(out->shape_view(), in_shape);\n      CHECK_EQ(out->data_type(), in_data_type);\n\n      ep::primitive::MemcpyKind kind{};\n      if (ctx->op_type_name() == \"copy_h2d\") {\n        kind = ep::primitive::MemcpyKind::kHtoD;\n      } else if (ctx->op_type_name() == \"copy_d2h\") {\n        kind = ep::primitive::MemcpyKind::kDtoH;\n      } else {\n        UNIMPLEMENTED();\n      }\n      std::unique_ptr<ep::primitive::Memcpy> primitive =\n          ep::primitive::NewPrimitive<ep::primitive::MemcpyFactory>(ctx->stream()->device_type(),\n                                                                    kind);\n      primitive->Launch(ctx->stream(), out->mut_raw_dptr(), in->raw_dptr(),\n                        in_shape.elem_cnt() * GetSizeOfDataType(in_data_type));\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"copy_h2d\").SetCreateFn<CopyHdKernel>();\nREGISTER_USER_KERNEL(\"copy_d2h\").SetCreateFn<CopyHdKernel>();\n\n}  // namespace\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/copy_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass CopyKernel final : public user_op::OpKernel {\n public:\n  CopyKernel() = default;\n  ~CopyKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const ShapeView& in_shape = in->shape_view();\n    CHECK_EQ(out->shape_view(), in_shape);\n    const DataType in_data_type = in->data_type();\n    CHECK_EQ(out->data_type(), in_data_type);\n    if (in_shape.elem_cnt() == 0) {\n      // 0 shape tensor do not need copy\n      return;\n    } else {\n      AutoMemcpy(ctx->stream(), out->mut_raw_dptr(), in->raw_dptr(),\n                 in_shape.elem_cnt() * GetSizeOfDataType(in_data_type), out->mem_case(),\n                 in->mem_case());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"copy\").SetCreateFn<CopyKernel>();\n\n}  // namespace\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/count_not_finite_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass MultiCountNotFiniteCpuKernel final : public user_op::OpKernel {\n public:\n  MultiCountNotFiniteCpuKernel() = default;\n  ~MultiCountNotFiniteCpuKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    int64_t* y_ptr = y->mut_dptr<int64_t>();\n    int64_t count = 0;\n    FOR_RANGE(int32_t, i, 0, ctx->inputs().size()) {\n      user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", i);\n      const T* x_ptr = x->dptr<T>();\n      FOR_RANGE(int32_t, j, 0, x->shape_view().elem_cnt()) {\n        if (!std::isfinite(x_ptr[j])) { count++; }\n      }\n    }\n    y_ptr[0] = count;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_COUNT_NOT_FINITE_CPU_KERNEL(dtype)                   \\\n  REGISTER_USER_KERNEL(\"count_not_finite\")                            \\\n      .SetCreateFn<MultiCountNotFiniteCpuKernel<dtype>>()             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nREGISTER_COUNT_NOT_FINITE_CPU_KERNEL(float)\nREGISTER_COUNT_NOT_FINITE_CPU_KERNEL(double)\n\n#define REGISTER_MULTI_COUNT_NOT_FINITE_CPU_KERNEL(dtype)             \\\n  REGISTER_USER_KERNEL(\"multi_count_not_finite\")                      \\\n      .SetCreateFn<MultiCountNotFiniteCpuKernel<dtype>>()             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nREGISTER_MULTI_COUNT_NOT_FINITE_CPU_KERNEL(float)\nREGISTER_MULTI_COUNT_NOT_FINITE_CPU_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/count_not_finite_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include <cub/cub.cuh>\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, int32_t N>\nstruct Param {\n  const T* x[N];\n  int64_t x_elem_cnt[N];\n  int64_t* y;\n  int64_t num_x;\n};\n\nusing CuInt64T = unsigned long long int;\n\n__device__ __inline__ int64_t AtomicAdd(int64_t* address, int64_t val) {\n  static_assert(sizeof(int64_t) == sizeof(CuInt64T), \"size error\");\n  return static_cast<int64_t>(\n      atomicAdd(reinterpret_cast<CuInt64T*>(address), static_cast<CuInt64T>(val)));\n}\n\ntemplate<typename T>\n__inline__ __device__ bool IsFinite(T x) {\n  return isfinite(x);\n}\n\ntemplate<>\n__inline__ __device__ bool IsFinite<half>(half x) {\n  return IsFinite(static_cast<float>(x));\n}\n\ntemplate<typename T>\n__global__ void CountNotFiniteGpu(const int64_t n, const T* x, int64_t* y) {\n  typedef cub::BlockReduce<int64_t, kCudaThreadsNumPerBlock> BlockReduce;\n  __shared__ typename BlockReduce::TempStorage cub_reduce_tmp_storage;\n  int64_t thread_count = 0;\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    if (!IsFinite(x[i])) { thread_count += 1; }\n  }\n  __syncthreads();\n  int64_t block_count_sum = BlockReduce(cub_reduce_tmp_storage).Reduce(thread_count, cub::Sum());\n  if (threadIdx.x == 0) { AtomicAdd(y, block_count_sum); }\n}\n\ntemplate<typename T, int32_t N>\n__global__ void MultiCountNotFiniteGpu(Param<T, N> param) {\n  typedef cub::BlockReduce<int64_t, kCudaThreadsNumPerBlock> BlockReduce;\n  __shared__ typename BlockReduce::TempStorage cub_reduce_tmp_storage;\n  int64_t thread_count = 0;\n  for (int32_t k = 0; k < param.num_x; ++k) {\n    CUDA_1D_KERNEL_LOOP(i, param.x_elem_cnt[k]) {\n      if (!IsFinite(param.x[k][i])) { thread_count += 1; }\n    }\n  }\n  __syncthreads();\n  int64_t block_count_sum = BlockReduce(cub_reduce_tmp_storage).Reduce(thread_count, cub::Sum());\n  if (threadIdx.x == 0) { AtomicAdd(param.y, block_count_sum); }\n}\n\nconstexpr int64_t kCountNotFiniteNumBlocks = 512;\n\nint GetCountNotFiniteNumBlocks(const int64_t elem_cnt) {\n  return std::min((elem_cnt + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock,\n                  kCountNotFiniteNumBlocks);\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass CountNotFiniteGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  CountNotFiniteGpuKernel() = default;\n  ~CountNotFiniteGpuKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const int64_t elem_cnt = x->shape_view().elem_cnt();\n    
Memset<DeviceType::kCUDA>(ctx->stream(), y->mut_dptr<int64_t>(), 0,\n                              y->shape_view().elem_cnt() * sizeof(int64_t));\n    CountNotFiniteGpu<T><<<GetCountNotFiniteNumBlocks(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        elem_cnt, x->dptr<T>(), y->mut_dptr<int64_t>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_COUNT_NOT_FINITE_CUDA_KERNEL(dtype)                   \\\n  REGISTER_USER_KERNEL(\"count_not_finite\")                             \\\n      .SetCreateFn<CountNotFiniteGpuKernel<dtype>>()                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nREGISTER_COUNT_NOT_FINITE_CUDA_KERNEL(half)\nREGISTER_COUNT_NOT_FINITE_CUDA_KERNEL(float)\nREGISTER_COUNT_NOT_FINITE_CUDA_KERNEL(double)\n\ntemplate<typename T>\nclass MultiCountNotFiniteGpuKernel final : public user_op::OpKernel,\n                                           public user_op::CudaGraphSupport {\n public:\n  MultiCountNotFiniteGpuKernel() = default;\n  ~MultiCountNotFiniteGpuKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    Param<T, 128> para;\n    Memset<DeviceType::kCUDA>(ctx->stream(), y->mut_dptr<int64_t>(), 0,\n                              y->shape_view().elem_cnt() * sizeof(int64_t));\n    para.y = y->mut_dptr<int64_t>();\n\n    int64_t remain_size = ctx->inputs().size();\n    int64_t input_id = 0;\n    while (remain_size > 0) {\n      if (remain_size > 128) {\n        remain_size -= 128;\n        para.num_x = 128;\n      } else {\n        para.num_x = remain_size;\n        remain_size = 0;\n      }\n      int64_t max_elem_cnt = 0;\n      for (int32_t i = 0; i < para.num_x; ++i) {\n        const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", input_id);\n        input_id++;\n        para.x[i] = x->dptr<T>();\n        para.x_elem_cnt[i] = x->shape_view().elem_cnt();\n        max_elem_cnt = std::max(max_elem_cnt, x->shape_view().elem_cnt());\n      }\n      MultiCountNotFiniteGpu<T, 128>\n          <<<GetCountNotFiniteNumBlocks(max_elem_cnt), kCudaThreadsNumPerBlock, 0,\n             ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(para);\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_MULTI_COUNT_NOT_FINITE_CUDA_KERNEL(dtype)             \\\n  REGISTER_USER_KERNEL(\"multi_count_not_finite\")                       \\\n      .SetCreateFn<MultiCountNotFiniteGpuKernel<dtype>>()              \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nREGISTER_MULTI_COUNT_NOT_FINITE_CUDA_KERNEL(half)\nREGISTER_MULTI_COUNT_NOT_FINITE_CUDA_KERNEL(float)\nREGISTER_MULTI_COUNT_NOT_FINITE_CUDA_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/ctc_greedy_decoder.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/user/kernels/ctc_greedy_decoder.h\"\n\nnamespace oneflow {\nnamespace {\n\ntemplate<typename T>\nstruct CTCGreedyDecoderFunctor<DeviceType::kCPU, T> final {\n  void operator()(ep::Stream* stream, int64_t* decoded_ptr, T* neg_sum_logits_ptr,\n                  const T* log_probs_ptr, const int64_t* input_lengths_ptr,\n                  const bool merge_repeated, const int64_t max_input_length,\n                  const int64_t batch_size, const int64_t num_labels) {\n    FOR_RANGE(int64_t, b, 0, batch_size) { CHECK_GE(max_input_length, input_lengths_ptr[b]); }\n    NdIndexOffsetHelper<int64_t, 3> input_helper(max_input_length, batch_size, num_labels);\n\n    FOR_RANGE(int64_t, b, 0, batch_size) {\n      int64_t prev_indices = -1, t_dec = 0;\n      neg_sum_logits_ptr[b] = 0;\n      FOR_RANGE(int64_t, t, 0, input_lengths_ptr[b]) {\n        const T* prob_data_t = &log_probs_ptr[input_helper.NdIndexToOffset(t, b, 0)];\n        int64_t max_indice = std::max_element(prob_data_t, prob_data_t + num_labels) - prob_data_t;\n        neg_sum_logits_ptr[b] -= prob_data_t[max_indice];\n        if (max_indice != num_labels - 1 && !(merge_repeated && (prev_indices == max_indice))) {\n          decoded_ptr[b * max_input_length + t_dec] = max_indice;\n          t_dec++;\n        }\n        prev_indices = max_indice;\n      }\n      FOR_RANGE(int64_t, t, t_dec, max_input_length) { decoded_ptr[b * max_input_length + t] = 0; }\n    }\n  }\n};\n\n}  // namespace\n\nREGISTER_CTC_GREEDY_DECODER_KERNELS(DeviceType::kCPU, float);\nREGISTER_CTC_GREEDY_DECODER_KERNELS(DeviceType::kCPU, double);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/ctc_greedy_decoder.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/user/kernels/ctc_greedy_decoder.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\nnamespace {\n\ntemplate<typename T>\n__global__ void CtcGreedyDecodeGpuMultiThread(int64_t* decoded_ptr, T* neg_sum_logits_ptr,\n                                              const T* log_probs_ptr,\n                                              const int64_t* input_lengths_ptr,\n                                              const bool merge_repeated,\n                                              const int64_t max_input_length,\n                                              const int64_t batch_size, const int64_t num_labels) {\n  const int64_t bid = blockIdx.x;\n  const int64_t tid = threadIdx.x;\n\n  for (int64_t b = bid; b < batch_size; b += gridDim.x) {\n    if (tid == 0) {\n      if (input_lengths_ptr[b] > max_input_length) __trap();\n    }\n  }\n\n  for (int64_t b = bid; b < batch_size; b += gridDim.x) {\n    extern __shared__ int64_t shared_max_indices_memory[];\n    int64_t* shared_max_indices = (int64_t*)shared_max_indices_memory;\n    NdIndexOffsetHelper<int64_t, 3> input_helper(max_input_length, batch_size, num_labels);\n    for (int64_t t = tid; t < max_input_length; t += blockDim.x) {\n      const T* prob_data_t = &log_probs_ptr[input_helper.NdIndexToOffset(t, b, 0)];\n      int64_t max_indice = 0;\n      T max_value = -FLT_MAX;\n      FOR_RANGE(int64_t, c, 0, num_labels) {\n        const T prob = prob_data_t[c];\n        if (prob > max_value) {\n          max_indice = c;\n          max_value = prob;\n        }\n      }\n      shared_max_indices[t] = max_indice;\n    }\n\n    __syncthreads();\n\n    if (tid == 0) {\n      int64_t prev_indices = -1, t_dec = 0;\n      FOR_RANGE(int64_t, t, 0, input_lengths_ptr[b]) {\n        const T* prob_data_t = &log_probs_ptr[input_helper.NdIndexToOffset(t, b, 0)];\n        const int64_t indice_t = shared_max_indices[t];\n        neg_sum_logits_ptr[b] -= prob_data_t[indice_t];\n        if (indice_t != num_labels - 1 && !(merge_repeated && (prev_indices == indice_t))) {\n          decoded_ptr[b * max_input_length + t_dec] = indice_t;\n          t_dec++;\n        }\n        prev_indices = indice_t;\n      }\n      FOR_RANGE(int64_t, t, t_dec, max_input_length) { decoded_ptr[b * max_input_length + t] = 0; }\n    }\n  }\n}\n\ntemplate<typename T>\n__global__ void CtcGreedyDecodeGpu(int64_t* decoded_ptr, T* neg_sum_logits_ptr,\n                                   const T* log_probs_ptr, const int64_t* input_lengths_ptr,\n                                   const bool merge_repeated, const int64_t max_input_length,\n                                   const int64_t batch_size, const int64_t num_labels) {\n  for (int64_t b = 0; b < batch_size; b++) {\n    if 
(input_lengths_ptr[b] > max_input_length) __trap();\n  }\n  NdIndexOffsetHelper<int64_t, 3> input_helper(max_input_length, batch_size, num_labels);\n\n  CUDA_1D_KERNEL_LOOP(b, batch_size) {\n    int prev_indices = -1, t_dec = 0;\n    neg_sum_logits_ptr[b] = 0;\n    FOR_RANGE(int64_t, t, 0, input_lengths_ptr[b]) {\n      const T* prob_data_t = &log_probs_ptr[input_helper.NdIndexToOffset(t, b, 0)];\n      int64_t max_indice = -1;\n      T max_value = -FLT_MAX;\n      FOR_RANGE(int64_t, c, 0, num_labels) {\n        if (prob_data_t[c] > max_value) {\n          max_indice = c;\n          max_value = prob_data_t[c];\n        }\n      }\n      neg_sum_logits_ptr[b] -= max_value;\n      if (max_indice != num_labels - 1 && !(merge_repeated && (prev_indices == max_indice))) {\n        decoded_ptr[b * max_input_length + t_dec] = max_indice;\n        t_dec++;\n      }\n      prev_indices = max_indice;\n    }\n    FOR_RANGE(int64_t, t, t_dec, max_input_length) { decoded_ptr[b * max_input_length + t] = 0; }\n  }\n}\n\ntemplate<typename T>\nstruct CTCGreedyDecoderFunctor<DeviceType::kCUDA, T> final {\n  void operator()(ep::Stream* stream, int64_t* decoded_ptr, T* neg_sum_logits_ptr,\n                  const T* log_probs_ptr, const int64_t* input_lengths_ptr,\n                  const bool merge_repeated, const int64_t max_input_length,\n                  const int64_t batch_size, const int64_t num_labels) {\n    int32_t thread_num = batch_size * kCudaThreadsNumPerBlock;\n    int64_t shared_mem_size = max_input_length * sizeof(int64_t);\n\n    // Query occupancy for the shared-memory kernel that is actually launched below; fall back\n    // to the plain kernel if the requested dynamic shared memory leaves no active block.\n    int max_active_blocks = 0;\n    OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks, CtcGreedyDecodeGpuMultiThread<T>, kCudaThreadsNumPerBlock,\n        shared_mem_size));\n    if (max_active_blocks > 0) {\n      CtcGreedyDecodeGpuMultiThread<<<BlocksNum4ThreadsNum(thread_num), kCudaThreadsNumPerBlock,\n                                      shared_mem_size,\n                                      stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          decoded_ptr, neg_sum_logits_ptr, log_probs_ptr, input_lengths_ptr, merge_repeated,\n          max_input_length, batch_size, num_labels);\n\n    } else {\n      CtcGreedyDecodeGpu<<<BlocksNum4ThreadsNum(thread_num), kCudaThreadsNumPerBlock, 0,\n                           stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          decoded_ptr, neg_sum_logits_ptr, log_probs_ptr, input_lengths_ptr, merge_repeated,\n          max_input_length, batch_size, num_labels);\n    }\n  }\n};\n\n}  // namespace\n\nREGISTER_CTC_GREEDY_DECODER_KERNELS(DeviceType::kCUDA, float);\nREGISTER_CTC_GREEDY_DECODER_KERNELS(DeviceType::kCUDA, double);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/ctc_greedy_decoder.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef _ONEFLOW_USER_KERNELS_CTC_GREEDY_DECODER_KERNEL_H_\n#define _ONEFLOW_USER_KERNELS_CTC_GREEDY_DECODER_KERNEL_H_\n#include \"oneflow/core/ndarray/xpu_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\n\nnamespace {\ntemplate<DeviceType device_type, typename T>\nstruct CTCGreedyDecoderFunctor final {\n  void operator()(ep::Stream* stream, int64_t* decoded_ptr, T* neg_sum_logits_ptr,\n                  const T* log_probs_ptr, const int64_t* input_lengths_ptr,\n                  const bool merge_repeated, const int64_t max_input_length,\n                  const int64_t batch_size, const int64_t num_labels);\n};\n\n}  // namespace\n\ntemplate<DeviceType device_type, typename T>\nclass CTCGreedyDecoderKernel final : public user_op::OpKernel {\n public:\n  CTCGreedyDecoderKernel() = default;\n  ~CTCGreedyDecoderKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* log_probs = ctx->Tensor4ArgNameAndIndex(\"log_probs\", 0);\n    const user_op::Tensor* input_lengths = ctx->Tensor4ArgNameAndIndex(\"input_lengths\", 0);\n    user_op::Tensor* decoded = ctx->Tensor4ArgNameAndIndex(\"decoded\", 0);\n    user_op::Tensor* neg_sum_logits = ctx->Tensor4ArgNameAndIndex(\"neg_sum_logits\", 0);\n    const T* log_probs_ptr = log_probs->dptr<T>();\n    const int64_t* input_lengths_ptr = input_lengths->dptr<int64_t>();\n    const bool merge_repeated = ctx->Attr<bool>(\"merge_repeated\");\n    const int64_t max_input_length = log_probs->shape_view().At(0);\n    const int64_t batch_size = log_probs->shape_view().At(1);\n    const int64_t num_labels = log_probs->shape_view().At(2);\n    CHECK_EQ(batch_size, input_lengths->shape_view().At(0));\n    int64_t* decoded_ptr = decoded->mut_dptr<int64_t>();\n    T* neg_sum_logits_ptr = neg_sum_logits->mut_dptr<T>();\n\n    CTCGreedyDecoderFunctor<device_type, T>()(ctx->stream(), decoded_ptr, neg_sum_logits_ptr,\n                                              log_probs_ptr, input_lengths_ptr, merge_repeated,\n                                              max_input_length, batch_size, num_labels);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CTC_GREEDY_DECODER_KERNELS(device, dtype)  \\\n  REGISTER_USER_KERNEL(\"ctc_greedy_decoder\")                \\\n      .SetCreateFn<CTCGreedyDecoderKernel<device, dtype>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device) \\\n                       && (user_op::HobDataType(\"log_probs\", 0) == GetDataType<dtype>::value));\n\n}  // namespace oneflow\n\n#endif  // _ONEFLOW_USER_KERNELS_CTC_GREEDY_DECODER_KERNEL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/ctc_loss_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/ctc_loss_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, typename TARGET, typename IDX>\nclass CtcLossKernel final : public user_op::OpKernel {\n public:\n  CtcLossKernel() = default;\n  ~CtcLossKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* log_probs = ctx->Tensor4ArgNameAndIndex(\"log_probs\", 0);\n    const user_op::Tensor* targets = ctx->Tensor4ArgNameAndIndex(\"targets\", 0);\n    const user_op::Tensor* input_lengths = ctx->Tensor4ArgNameAndIndex(\"input_lengths\", 0);\n    const user_op::Tensor* target_lengths = ctx->Tensor4ArgNameAndIndex(\"target_lengths\", 0);\n    user_op::Tensor* loss = ctx->Tensor4ArgNameAndIndex(\"loss\", 0);\n    user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex(\"alpha\", 0);\n\n    const T* log_probs_ptr = log_probs->dptr<T>();\n    const TARGET* targets_ptr = targets->dptr<TARGET>();\n    const IDX* input_lengths_ptr = input_lengths->dptr<IDX>();\n    const IDX* target_lengths_ptr = target_lengths->dptr<IDX>();\n    const int64_t blank = ctx->Attr<int64_t>(\"blank\");\n    const int64_t max_input_length = log_probs->shape_view().At(0);\n    const int64_t batch_size = log_probs->shape_view().At(1);\n    const int64_t num_labels = log_probs->shape_view().At(2);\n    const int64_t max_target_length = ctx->Attr<int64_t>(\"max_target_length\");\n    const int32_t targets_ndim = targets->shape_view().NumAxes();\n\n    NdIndexOffsetHelper<int64_t, 3> input_helper(max_input_length, batch_size, num_labels);\n    NdIndexOffsetHelper<int64_t, 3> alpha_helper(batch_size, max_input_length,\n                                                 2 * max_target_length + 1);\n    T* loss_ptr = loss->mut_dptr<T>();\n    T* alpha_ptr = alpha->mut_dptr<T>();\n    CtcLossKernelUtil<device_type, T, TARGET, IDX>::CtcLossForward(\n        ctx->stream(), log_probs_ptr, targets_ptr, input_lengths_ptr, target_lengths_ptr, alpha_ptr,\n        loss_ptr, input_helper, alpha_helper, batch_size, max_input_length, max_target_length,\n        blank, targets_ndim);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CTC_LOSS_KERNEL(device, dtype, target_type, idx_dtype)                          \\\n  REGISTER_USER_KERNEL(\"ctc_loss\")                                                               \\\n      .SetCreateFn<CtcLossKernel<device, OF_PP_PAIR_FIRST(dtype), OF_PP_PAIR_FIRST(target_type), \\\n                                 OF_PP_PAIR_FIRST(idx_dtype)>>()                                 \\\n      .SetIsMatchedHob(                                                                          \\\n          (user_op::HobDeviceType() == device)                                                   \\\n          && 
(user_op::HobDataType(\"log_probs\", 0) == OF_PP_PAIR_SECOND(dtype))                  \\\n          && (user_op::HobDataType(\"targets\", 0) == OF_PP_PAIR_SECOND(target_type))              \\\n          && (user_op::HobDataType(\"input_lengths\", 0) == OF_PP_PAIR_SECOND(idx_dtype)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CTC_LOSS_KERNEL, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\ntemplate<DeviceType device_type, typename T, typename TARGET, typename IDX>\nclass CtcLossGradKernel final : public user_op::OpKernel {\n public:\n  CtcLossGradKernel() = default;\n  ~CtcLossGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* grad_out = ctx->Tensor4ArgNameAndIndex(\"grad_out\", 0);\n    const user_op::Tensor* loss = ctx->Tensor4ArgNameAndIndex(\"loss\", 0);\n    const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex(\"alpha\", 0);\n    const user_op::Tensor* log_probs = ctx->Tensor4ArgNameAndIndex(\"log_probs\", 0);\n    const user_op::Tensor* targets = ctx->Tensor4ArgNameAndIndex(\"targets\", 0);\n    const user_op::Tensor* input_lengths = ctx->Tensor4ArgNameAndIndex(\"input_lengths\", 0);\n    const user_op::Tensor* target_lengths = ctx->Tensor4ArgNameAndIndex(\"target_lengths\", 0);\n    user_op::Tensor* grad = ctx->Tensor4ArgNameAndIndex(\"grad\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    const T* grad_out_ptr = grad_out->dptr<T>();\n    const T* loss_ptr = loss->dptr<T>();\n    const T* alpha_ptr = alpha->dptr<T>();\n    const T* log_probs_ptr = log_probs->dptr<T>();\n    const TARGET* targets_ptr = targets->dptr<TARGET>();\n    const IDX* input_lengths_ptr = input_lengths->dptr<IDX>();\n    const IDX* target_lengths_ptr = target_lengths->dptr<IDX>();\n    const int64_t blank = ctx->Attr<int64_t>(\"blank\");\n    const bool zero_infinity = ctx->Attr<bool>(\"zero_infinity\");\n    const int64_t batch_size = log_probs->shape_view().At(1);\n    const int64_t num_labels = log_probs->shape_view().At(2);\n    const int64_t max_input_length = log_probs->shape_view().At(0);\n    const int64_t max_target_length = ctx->Attr<int64_t>(\"max_target_length\");\n    const int32_t targets_ndim = targets->shape_view().NumAxes();\n\n    NdIndexOffsetHelper<int64_t, 3> input_helper(max_input_length, batch_size, num_labels);\n    NdIndexOffsetHelper<int64_t, 3> beta_helper(batch_size, max_input_length,\n                                                2 * max_target_length + 1);\n    T* grad_ptr = grad->mut_dptr<T>();\n    T* beta_ptr = tmp_buffer->mut_dptr<T>();\n    CtcLossKernelUtil<device_type, T, TARGET, IDX>::CtcLossBackward(\n        ctx->stream(), grad_out_ptr, loss_ptr, alpha_ptr, log_probs_ptr, targets_ptr,\n        input_lengths_ptr, target_lengths_ptr, beta_ptr, grad_ptr, input_helper, beta_helper,\n        batch_size, max_input_length, max_target_length, num_labels, blank, zero_infinity,\n        targets_ndim);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CTC_LOSS_BACKWARD_KERNEL(device, dtype, target_type, idx_dtype)            \\\n  REGISTER_USER_KERNEL(\"ctc_loss_grad\")                                                     \\\n      .SetCreateFn<                                                                         \\\n          CtcLossGradKernel<device, OF_PP_PAIR_FIRST(dtype), OF_PP_PAIR_FIRST(target_type), \\\n       
                     OF_PP_PAIR_FIRST(idx_dtype)>>()                                 \\\n      .SetIsMatchedHob(                                                                     \\\n          (user_op::HobDeviceType() == device)                                              \\\n          && (user_op::HobDataType(\"log_probs\", 0) == OF_PP_PAIR_SECOND(dtype))             \\\n          && (user_op::HobDataType(\"targets\", 0) == OF_PP_PAIR_SECOND(target_type))         \\\n          && (user_op::HobDataType(\"input_lengths\", 0) == OF_PP_PAIR_SECOND(idx_dtype)))    \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                   \\\n        const Shape& log_probs_shape = ctx->InputShape(\"log_probs\", 0);                     \\\n        const int64_t max_target_length = ctx->Attr<int64_t>(\"max_target_length\");          \\\n        int64_t elem_cnt =                                                                  \\\n            log_probs_shape.At(1) * log_probs_shape.At(0) * (2 * max_target_length + 1);    \\\n        return elem_cnt * sizeof(OF_PP_PAIR_FIRST(dtype));                                  \\\n      });\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CTC_LOSS_BACKWARD_KERNEL, DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
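The forward kernel above flattens `log_probs` as a (max_input_length, batch_size, num_labels) volume and `alpha` as (batch_size, max_input_length, 2 * max_target_length + 1), with every pointer access routed through `NdIndexOffsetHelper::NdIndexToOffset`. Below is a minimal standalone sketch of the row-major offset arithmetic such a 3-D helper performs; `Offset3D` and the concrete sizes are invented for illustration and are not OneFlow's implementation.

```cpp
// Sketch: row-major 3-D index -> flat offset, mirroring how the kernel
// addresses log_probs(t, b, c) and alpha(b, t, s).
#include <cassert>
#include <cstdint>

struct Offset3D {
  int64_t d0, d1, d2;  // extents of the three dimensions
  int64_t NdIndexToOffset(int64_t i, int64_t j, int64_t k) const {
    assert(i < d0 && j < d1 && k < d2);
    return (i * d1 + j) * d2 + k;  // last dimension is contiguous
  }
};

int main() {
  // Hypothetical sizes for illustration only.
  const int64_t max_input_length = 8, batch_size = 4, num_labels = 29;
  Offset3D input_helper{max_input_length, batch_size, num_labels};
  // Offset of log_prob(t = 2, b = 1, c = 0), i.e. the kernel's
  // input_helper.NdIndexToOffset(t, b, blank) with blank == 0.
  int64_t off = input_helper.NdIndexToOffset(2, 1, 0);
  assert(off == (2 * batch_size + 1) * num_labels);
  return 0;
}
```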
  {
    "path": "oneflow/user/kernels/ctc_loss_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/ctc_loss_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename TARGET, typename IDX>\nint64_t get_target_prime(const TARGET* targets_ptr, const IDX* target_lengths_ptr,\n                         int64_t max_target_length, int64_t b, int64_t s, int64_t blank,\n                         const int32_t targets_ndim) {\n  if (s % 2 == 0) {\n    return blank;\n  } else {\n    int64_t idx = 0;\n    if (targets_ndim == 1) {\n      FOR_RANGE(int64_t, i, 0, b) { idx += target_lengths_ptr[i]; }\n    } else {  // targets_ndim == 2\n      idx = b * max_target_length;\n    }\n    idx += s / 2;\n    return static_cast<int64_t>(targets_ptr[idx]);\n  }\n}\n\ntemplate<typename T, typename TARGET, typename IDX>\nstruct CtcLossKernelUtil<DeviceType::kCPU, T, TARGET, IDX> final {\n  static void CtcLossForward(ep::Stream* stream, const T* log_probs_ptr, const TARGET* targets_ptr,\n                             const IDX* input_lengths_ptr, const IDX* target_lengths_ptr,\n                             T* alpha_ptr, T* loss_ptr,\n                             NdIndexOffsetHelper<int64_t, 3>& input_helper,\n                             NdIndexOffsetHelper<int64_t, 3>& alpha_helper,\n                             const int64_t batch_size, const int64_t max_input_length,\n                             const int64_t max_target_length, const int64_t blank,\n                             const int32_t targets_ndim);\n\n  static void CtcLossBackward(ep::Stream* stream, const T* grad_out_ptr, const T* loss_ptr,\n                              const T* alpha_ptr, const T* log_probs_ptr, const TARGET* targets_ptr,\n                              const IDX* input_lengths_ptr, const IDX* target_lengths_ptr,\n                              T* beta_ptr, T* grad_ptr,\n                              NdIndexOffsetHelper<int64_t, 3>& input_helper,\n                              NdIndexOffsetHelper<int64_t, 3>& beta_helper,\n                              const int64_t batch_size, const int64_t max_input_length,\n                              const int64_t max_target_length, const int64_t num_labels,\n                              const int64_t blank, const bool zero_infinity,\n                              const int32_t targets_ndim);\n};\n\ntemplate<typename T, typename TARGET, typename IDX>\nvoid CtcLossKernelUtil<DeviceType::kCPU, T, TARGET, IDX>::CtcLossForward(\n    ep::Stream* stream, const T* log_probs_ptr, const TARGET* targets_ptr,\n    const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, T* alpha_ptr, T* loss_ptr,\n    NdIndexOffsetHelper<int64_t, 3>& input_helper, NdIndexOffsetHelper<int64_t, 3>& alpha_helper,\n    const int64_t batch_size, const int64_t max_input_length, const int64_t max_target_length,\n    const int64_t blank, const int32_t targets_ndim) {\n  constexpr T neginf = -std::numeric_limits<T>::infinity();\n  FOR_RANGE(int64_t, b, 0, batch_size) {\n    
CHECK_GE(max_input_length, input_lengths_ptr[b]);\n    CHECK_GE(max_target_length, target_lengths_ptr[b]);\n  }\n  FOR_RANGE(int32_t, b, 0, batch_size) {\n    IDX input_length = input_lengths_ptr[b];\n    IDX target_length = target_lengths_ptr[b];\n\n    int64_t alpha_idx = alpha_helper.NdIndexToOffset(b, 0, 0);\n    for (IDX s = 0; s < 2 * target_length + 1; s++) { alpha_ptr[alpha_idx + s] = neginf; }\n    alpha_ptr[alpha_idx] = log_probs_ptr[input_helper.NdIndexToOffset(0, b, blank)];\n    if (target_length > 0) {\n      TARGET target = get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b, 1,\n                                       blank, targets_ndim);\n      alpha_ptr[alpha_idx + 1] = log_probs_ptr[input_helper.NdIndexToOffset(0, b, target)];\n    }\n\n    for (IDX t = 1; t < input_length; t++) {\n      for (IDX s = 0; s < 2 * target_length + 1; s++) {\n        TARGET current_target_prime = get_target_prime(\n            targets_ptr, target_lengths_ptr, max_target_length, b, s, blank, targets_ndim);\n        T la1 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s)];\n        T la2, la3, lamax = la1;\n        if (s > 0) {\n          la2 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s - 1)];\n          if (la2 > lamax) lamax = la2;\n        } else {\n          la2 = neginf;\n        }\n        if ((s > 1)\n            && (get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b, s - 2,\n                                 blank, targets_ndim)\n                != current_target_prime)) {\n          la3 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s - 2)];\n          if (la3 > lamax) lamax = la3;\n        } else {\n          la3 = neginf;\n        }\n        if (lamax == neginf) lamax = 0;\n\n        int64_t idx_t_s = alpha_helper.NdIndexToOffset(b, t, s);\n        alpha_ptr[idx_t_s] =\n            std::log(std::exp(la1 - lamax) + std::exp(la2 - lamax) + std::exp(la3 - lamax)) + lamax\n            + log_probs_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)];\n      }\n    }\n\n    if (target_length == 0) {\n      int64_t idx = alpha_helper.NdIndexToOffset(b, input_length - 1, 0);\n      loss_ptr[b] = -alpha_ptr[idx];\n    } else {\n      int64_t idx1 = alpha_helper.NdIndexToOffset(b, input_length - 1, target_length * 2);\n      int64_t idx2 = alpha_helper.NdIndexToOffset(b, input_length - 1, target_length * 2 - 1);\n      T l1 = alpha_ptr[idx1];\n      T l2 = alpha_ptr[idx2];\n      T m = std::max(l1, l2);\n      m = ((m == neginf) ? 
0 : m);\n      T log_likelihood = std::log(std::exp(l1 - m) + std::exp(l2 - m)) + m;\n      loss_ptr[b] = -log_likelihood;\n    }\n  }\n}\n\ntemplate<typename T, typename TARGET, typename IDX>\nvoid CtcLossKernelUtil<DeviceType::kCPU, T, TARGET, IDX>::CtcLossBackward(\n    ep::Stream* stream, const T* grad_out_ptr, const T* loss_ptr, const T* alpha_ptr,\n    const T* log_probs_ptr, const TARGET* targets_ptr, const IDX* input_lengths_ptr,\n    const IDX* target_lengths_ptr, T* beta_ptr, T* grad_ptr,\n    NdIndexOffsetHelper<int64_t, 3>& input_helper, NdIndexOffsetHelper<int64_t, 3>& beta_helper,\n    const int64_t batch_size, const int64_t max_input_length, const int64_t max_target_length,\n    const int64_t num_labels, const int64_t blank, const bool zero_infinity,\n    const int32_t targets_ndim) {\n  constexpr T neginf = -std::numeric_limits<T>::infinity();\n  int64_t elem_cnt = max_input_length * batch_size * num_labels;\n  FOR_RANGE(int64_t, i, 0, elem_cnt) { grad_ptr[i] = neginf; }\n\n  FOR_RANGE(int64_t, b, 0, batch_size) {\n    IDX input_length = input_lengths_ptr[b];\n    IDX target_length = target_lengths_ptr[b];\n    T nll = loss_ptr[b];\n    if (zero_infinity && nll == std::numeric_limits<T>::infinity()) {\n      for (IDX t = 0; t < max_input_length; t++) {\n        for (IDX c = 0; c < num_labels; c++) {\n          grad_ptr[input_helper.NdIndexToOffset(t, b, c)] = 0;\n        }\n      }\n      continue;\n    }\n\n    if (input_length > 0) {\n      int64_t beta_idx = beta_helper.NdIndexToOffset(b, input_length - 1, 0);\n      for (IDX s = 0; s < 2 * target_length + 1; s++) { beta_ptr[beta_idx + s] = neginf; }\n      beta_ptr[beta_idx + 2 * target_length] =\n          log_probs_ptr[input_helper.NdIndexToOffset(input_length - 1, b, blank)];\n      grad_ptr[input_helper.NdIndexToOffset(input_length - 1, b, blank)] =\n          alpha_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length)]\n          + beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length)];\n\n      if (target_length > 0) {\n        TARGET target = get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b,\n                                         2 * target_length - 1, blank, targets_ndim);\n        beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)] =\n            log_probs_ptr[input_helper.NdIndexToOffset(input_length - 1, b, target)];\n        grad_ptr[input_helper.NdIndexToOffset(input_length - 1, b, target)] =\n            alpha_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)]\n            + beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)];\n      }\n    }\n\n    for (IDX t = input_length - 2; t >= 0; t--) {\n      for (IDX s = 2 * target_length; s >= 0; s--) {\n        TARGET current_target_prime = get_target_prime(\n            targets_ptr, target_lengths_ptr, max_target_length, b, s, blank, targets_ndim);\n        T lb1 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s)];\n        T lb2, lb3, lbmax = lb1;\n\n        if (s < 2 * target_length) {\n          lb2 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s + 1)];\n          if (lb2 > lbmax) lbmax = lb2;\n        } else {\n          lb2 = neginf;\n        }\n\n        if ((s < 2 * target_length - 1)\n            && (get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b, s + 2,\n                                 blank, targets_ndim)\n                != current_target_prime)) {\n          lb3 = 
beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s + 2)];\n          if (lb3 > lbmax) lbmax = lb3;\n        } else {\n          lb3 = neginf;\n        }\n        if (lbmax == neginf) lbmax = 0;\n\n        int64_t idx_t_s = beta_helper.NdIndexToOffset(b, t, s);\n        beta_ptr[idx_t_s] =\n            std::log(std::exp(lb1 - lbmax) + std::exp(lb2 - lbmax) + std::exp(lb3 - lbmax)) + lbmax\n            + log_probs_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)];\n\n        T log_alpha_beta = alpha_ptr[idx_t_s] + beta_ptr[idx_t_s];\n        T& lcab = grad_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)];\n        if (lcab == neginf) {\n          lcab = log_alpha_beta;\n        } else {\n          T m = std::max(lcab, log_alpha_beta);\n          lcab = std::log(std::exp(lcab - m) + std::exp(log_alpha_beta - m)) + m;\n        }\n      }\n    }\n\n    for (int32_t t = 0; t < input_length; t++) {\n      for (int64_t c = 0; c < num_labels; c++) {\n        T& res = grad_ptr[input_helper.NdIndexToOffset(t, b, c)];\n        T lp = log_probs_ptr[input_helper.NdIndexToOffset(t, b, c)];\n        res = (std::exp(lp) - std::exp(res + nll - lp)) * grad_out_ptr[b];\n      }\n    }\n\n    // zero the remainder\n    if (input_length < max_input_length) {\n      for (int64_t t = input_length; t < max_input_length; t++) {\n        for (int64_t c = 0; c < num_labels; c++) {\n          int64_t grad_idx = input_helper.NdIndexToOffset(t, b, c);\n          grad_ptr[grad_idx] = 0;\n        }\n      }\n    }\n  }\n}\n\n#define INSTANTIATE_CTC_LOSS_KERNEL_UTIL_CPU(device_type_v, log_probs_dtype_pair,          \\\n                                             targets_dtype_pair, input_lengths_dtype_pair) \\\n  template struct CtcLossKernelUtil<device_type_v, OF_PP_PAIR_FIRST(log_probs_dtype_pair), \\\n                                    OF_PP_PAIR_FIRST(targets_dtype_pair),                  \\\n                                    OF_PP_PAIR_FIRST(input_lengths_dtype_pair)>;\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CTC_LOSS_KERNEL_UTIL_CPU, (DeviceType::kCPU),\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n#undef INSTANTIATE_CTC_LOSS_KERNEL_UTIL_CPU\n\n}  // namespace oneflow\n"
  },
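Both the alpha and beta recurrences above repeatedly evaluate `std::log(std::exp(x1 - max) + std::exp(x2 - max) + std::exp(x3 - max)) + max`, resetting the maximum to 0 when all three inputs are -inf. Here is that numerically stable log-sum-exp step in isolation, as a minimal sketch; `LogSumExp3` is a hypothetical helper, not a function from the file above.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>

template<typename T>
T LogSumExp3(T a, T b, T c) {
  const T neginf = -std::numeric_limits<T>::infinity();
  T m = std::max({a, b, c});
  // Mirrors the kernel's `if (lamax == neginf) lamax = 0;`: when every input
  // is -inf, shifting by 0 yields log(0 + 0 + 0) == -inf instead of the NaN
  // that (-inf) - (-inf) would produce.
  if (m == neginf) { m = 0; }
  return std::log(std::exp(a - m) + std::exp(b - m) + std::exp(c - m)) + m;
}

int main() {
  const double neginf = -std::numeric_limits<double>::infinity();
  assert(LogSumExp3(neginf, neginf, neginf) == neginf);  // stays -inf, no NaN
  assert(std::abs(LogSumExp3(0.0, 0.0, neginf) - std::log(2.0)) < 1e-12);
  return 0;
}
```

Subtracting the running maximum keeps every `exp` argument at or below zero, so the sum cannot overflow even for large-magnitude log probabilities.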
  {
    "path": "oneflow/user/kernels/ctc_loss_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/ctc_loss_kernel_util.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename TARGET, typename IDX>\n__device__ __inline__ static int64_t get_target_prime(const TARGET* targets_ptr,\n                                                      const IDX* target_lengths_ptr,\n                                                      int64_t max_target_length, int64_t b,\n                                                      int64_t s, int64_t blank,\n                                                      const int32_t targets_ndim) {\n  if (s % 2 == 0) {\n    return blank;\n  } else {\n    int64_t idx = 0;\n    if (targets_ndim == 1) {\n      FOR_RANGE(int64_t, i, 0, b) { idx += target_lengths_ptr[i]; }\n    } else {  // targets_ndim == 2\n      idx = b * max_target_length;\n    }\n    idx += s / 2;\n    return static_cast<int64_t>(targets_ptr[idx]);\n  }\n}\n\ntemplate<typename T, typename TARGET, typename IDX>\n__global__ void CtcLossGpu(const T* log_probs_ptr, const TARGET* targets_ptr,\n                           const IDX* input_lengths_ptr, const IDX* target_lengths_ptr,\n                           T* alpha_ptr, T* loss_ptr, NdIndexOffsetHelper<int64_t, 3> input_helper,\n                           NdIndexOffsetHelper<int64_t, 3> alpha_helper, const int64_t batch_size,\n                           const int64_t max_input_length, const int64_t max_target_length,\n                           const int64_t blank, const int32_t targets_ndim) {\n  constexpr T neginf = -INFINITY;\n  const int32_t bid = blockIdx.x;\n  const int32_t tid = threadIdx.x;\n  for (int64_t b = bid; b < batch_size; b += gridDim.x) {\n    if (tid == 0) {\n      if (input_lengths_ptr[b] > max_input_length) __trap();\n      if (target_lengths_ptr[b] > max_target_length) __trap();\n    }\n  }\n  for (int64_t b = bid; b < batch_size; b += gridDim.x) {\n    IDX input_length = input_lengths_ptr[b];\n    IDX target_length = target_lengths_ptr[b];\n\n    for (IDX s = tid; s < 2 * target_length + 1; s += blockDim.x) {\n      alpha_ptr[alpha_helper.NdIndexToOffset(b, 0, s)] = neginf;\n    }\n    if (tid == 0) {\n      alpha_ptr[alpha_helper.NdIndexToOffset(b, 0, 0)] =\n          log_probs_ptr[input_helper.NdIndexToOffset(0, b, blank)];\n      if (target_length > 0) {\n        TARGET target = get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b, 1,\n                                         blank, targets_ndim);\n        alpha_ptr[alpha_helper.NdIndexToOffset(b, 0, 1)] =\n            log_probs_ptr[input_helper.NdIndexToOffset(0, b, target)];\n      }\n    }\n    __syncthreads();\n    for (IDX t = 1; t < input_length; t++) {\n      for (IDX s = tid; s < 2 * target_length + 1; s += blockDim.x) {\n        TARGET current_target_prime = get_target_prime(\n            targets_ptr, target_lengths_ptr, max_target_length, b, s, blank, 
targets_ndim);\n        T la1 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s)];\n        T la2, la3, lamax = la1;\n        if (s > 0) {\n          la2 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s - 1)];\n          if (la2 > lamax) lamax = la2;\n        } else {\n          la2 = neginf;\n        }\n        if ((s > 1)\n            && (get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b, s - 2,\n                                 blank, targets_ndim)\n                != current_target_prime)) {\n          la3 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s - 2)];\n          if (la3 > lamax) lamax = la3;\n        } else {\n          la3 = neginf;\n        }\n        if (lamax == neginf) lamax = 0;\n\n        int64_t idx_t_s = alpha_helper.NdIndexToOffset(b, t, s);\n        alpha_ptr[idx_t_s] =\n            log(exp(la1 - lamax) + exp(la2 - lamax) + exp(la3 - lamax)) + lamax\n            + log_probs_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)];\n      }\n      __syncthreads();\n    }\n    if (tid == 0) {\n      if (target_length == 0) {\n        int64_t idx = alpha_helper.NdIndexToOffset(b, input_length - 1, 0);\n        loss_ptr[b] = -alpha_ptr[idx];\n      } else {\n        int64_t idx1 = alpha_helper.NdIndexToOffset(b, input_length - 1, target_length * 2);\n        int64_t idx2 = alpha_helper.NdIndexToOffset(b, input_length - 1, target_length * 2 - 1);\n        T l1 = alpha_ptr[idx1];\n        T l2 = alpha_ptr[idx2];\n        T m = max(l1, l2);\n        m = ((m == neginf) ? 0 : m);\n        T log_likelihood = log(exp(l1 - m) + exp(l2 - m)) + m;\n        loss_ptr[b] = -log_likelihood;\n      }\n    }\n  }\n}\n\ntemplate<typename T, typename TARGET, typename IDX>\n__global__ void CtcLossGradGpu(\n    const T* grad_out_ptr, const T* loss_ptr, const T* alpha_ptr, const T* log_probs_ptr,\n    const TARGET* targets_ptr, const IDX* input_lengths_ptr, const IDX* target_lengths_ptr,\n    T* beta_ptr, T* grad_ptr, NdIndexOffsetHelper<int64_t, 3> input_helper,\n    NdIndexOffsetHelper<int64_t, 3> beta_helper, const int64_t batch_size,\n    const int64_t max_input_length, const int64_t max_target_length, const int64_t num_labels,\n    const int64_t blank, const bool zero_infinity, const int32_t targets_ndim) {\n  constexpr T neginf = -INFINITY;\n  const int32_t bid = blockIdx.x;\n  const int32_t tid = threadIdx.x;\n\n  for (int64_t b = bid; b < batch_size; b += gridDim.x) {\n    IDX input_length = input_lengths_ptr[b];\n    IDX target_length = target_lengths_ptr[b];\n    T nll = loss_ptr[b];\n    if (zero_infinity && nll == INFINITY) {\n      for (IDX t = tid; t < max_input_length; t += blockDim.x) {\n        for (IDX c = 0; c < num_labels; c++) {\n          grad_ptr[input_helper.NdIndexToOffset(t, b, c)] = 0;\n        }\n      }\n      __syncthreads();\n      continue;\n    }\n\n    if (input_length > 0) {\n      for (IDX s = tid; s < 2 * target_length + 1; s += blockDim.x) {\n        beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, s)] = neginf;\n      }\n      if (tid == 0) {\n        beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length)] =\n            log_probs_ptr[input_helper.NdIndexToOffset(input_length - 1, b, blank)];\n        if (target_length > 0) {\n          TARGET target = get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b,\n                                           2 * target_length - 1, blank, targets_ndim);\n          beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * 
target_length - 1)] =\n              log_probs_ptr[input_helper.NdIndexToOffset(input_length - 1, b, target)];\n        }\n      }\n      __syncthreads();\n    }\n    for (IDX t = input_length - 2; t >= 0; t--) {\n      for (IDX s = tid; s < 2 * target_length + 1; s += blockDim.x) {\n        TARGET current_target_prime = get_target_prime(\n            targets_ptr, target_lengths_ptr, max_target_length, b, s, blank, targets_ndim);\n        T lb1 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s)];\n        T lb2, lb3, lbmax = lb1;\n        if (s < 2 * target_length) {\n          lb2 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s + 1)];\n          if (lb2 > lbmax) lbmax = lb2;\n        } else {\n          lb2 = neginf;\n        }\n        if ((s < 2 * target_length - 1)\n            && (get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b, s + 2,\n                                 blank, targets_ndim)\n                != current_target_prime)) {\n          lb3 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s + 2)];\n          if (lb3 > lbmax) lbmax = lb3;\n        } else {\n          lb3 = neginf;\n        }\n        if (lbmax == neginf) lbmax = 0;\n\n        int64_t idx_t_s = beta_helper.NdIndexToOffset(b, t, s);\n        beta_ptr[idx_t_s] =\n            log(exp(lb1 - lbmax) + exp(lb2 - lbmax) + exp(lb3 - lbmax)) + lbmax\n            + log_probs_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)];\n      }\n      __syncthreads();\n    }\n    for (IDX t = tid; t < max_input_length; t += blockDim.x) {\n      for (IDX c = 0; c < num_labels; c++) {\n        grad_ptr[input_helper.NdIndexToOffset(t, b, c)] = t < input_length ? neginf : 0;\n      }\n    }\n    __syncthreads();\n    if (tid == 0) {\n      grad_ptr[input_helper.NdIndexToOffset(input_length - 1, b, blank)] =\n          alpha_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length)]\n          + beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length)];\n      if (target_length > 0) {\n        TARGET target = get_target_prime(targets_ptr, target_lengths_ptr, max_target_length, b,\n                                         2 * target_length - 1, blank, targets_ndim);\n        grad_ptr[input_helper.NdIndexToOffset(input_length - 1, b, target)] =\n            alpha_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)]\n            + beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)];\n      }\n    }\n    __syncthreads();\n    for (IDX t = tid; t < input_length; t += blockDim.x) {\n      for (IDX s = 0; (t < input_length - 1) && (s < 2 * target_length + 1); s += 1) {\n        TARGET current_target_prime = get_target_prime(\n            targets_ptr, target_lengths_ptr, max_target_length, b, s, blank, targets_ndim);\n        int64_t idx_t_s = beta_helper.NdIndexToOffset(b, t, s);\n        T log_alpha_beta = alpha_ptr[idx_t_s] + beta_ptr[idx_t_s];\n        T& lcab = grad_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)];\n        if (lcab == neginf) {\n          lcab = log_alpha_beta;\n        } else {\n          T m = max(lcab, log_alpha_beta);\n          lcab = log(exp(lcab - m) + exp(log_alpha_beta - m)) + m;\n        }\n      }\n      for (int32_t c = 0; c < num_labels; c++) {\n        T& res = grad_ptr[input_helper.NdIndexToOffset(t, b, c)];\n        T lp = log_probs_ptr[input_helper.NdIndexToOffset(t, b, c)];\n        res = (exp(lp) - exp(res + nll - lp)) * grad_out_ptr[b];\n      }\n    }\n  }\n}\n\n}  
// namespace\n\ntemplate<typename T, typename TARGET, typename IDX>\nstruct CtcLossKernelUtil<DeviceType::kCUDA, T, TARGET, IDX> {\n  static void CtcLossForward(ep::Stream* stream, const T* log_probs_ptr, const TARGET* targets_ptr,\n                             const IDX* input_lengths_ptr, const IDX* target_lengths_ptr,\n                             T* alpha_ptr, T* loss_ptr,\n                             NdIndexOffsetHelper<int64_t, 3>& input_helper,\n                             NdIndexOffsetHelper<int64_t, 3>& alpha_helper,\n                             const int64_t batch_size, const int64_t max_input_length,\n                             const int64_t max_target_length, const int64_t blank,\n                             const int32_t targets_ndim) {\n    int32_t thread_num = batch_size * kCudaThreadsNumPerBlock;\n    RUN_CUDA_KERNEL((CtcLossGpu<T, TARGET, IDX>), stream, thread_num, log_probs_ptr, targets_ptr,\n                    input_lengths_ptr, target_lengths_ptr, alpha_ptr, loss_ptr, input_helper,\n                    alpha_helper, batch_size, max_input_length, max_target_length, blank,\n                    targets_ndim);\n  }\n\n  static void CtcLossBackward(ep::Stream* stream, const T* grad_out_ptr, const T* loss_ptr,\n                              const T* alpha_ptr, const T* log_probs_ptr, const TARGET* targets_ptr,\n                              const IDX* input_lengths_ptr, const IDX* target_lengths_ptr,\n                              T* beta_ptr, T* grad_ptr,\n                              NdIndexOffsetHelper<int64_t, 3>& input_helper,\n                              NdIndexOffsetHelper<int64_t, 3>& beta_helper,\n                              const int64_t batch_size, const int64_t max_input_length,\n                              const int64_t max_target_length, const int64_t num_labels,\n                              const int64_t blank, const bool zero_infinity,\n                              const int32_t targets_ndim) {\n    int32_t thread_num = batch_size * kCudaThreadsNumPerBlock;\n    RUN_CUDA_KERNEL((CtcLossGradGpu<T, TARGET, IDX>), stream, thread_num, grad_out_ptr, loss_ptr,\n                    alpha_ptr, log_probs_ptr, targets_ptr, input_lengths_ptr, target_lengths_ptr,\n                    beta_ptr, grad_ptr, input_helper, beta_helper, batch_size, max_input_length,\n                    max_target_length, num_labels, blank, zero_infinity, targets_ndim);\n  }\n};\n\n#define INSTANTIATE_CTC_LOSS_KERNEL_UTIL_CUDA(device_type_v, log_probs_dtype_pair,          \\\n                                              targets_dtype_pair, input_lengths_dtype_pair) \\\n  template struct CtcLossKernelUtil<device_type_v, OF_PP_PAIR_FIRST(log_probs_dtype_pair),  \\\n                                    OF_PP_PAIR_FIRST(targets_dtype_pair),                   \\\n                                    OF_PP_PAIR_FIRST(input_lengths_dtype_pair)>;\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CTC_LOSS_KERNEL_UTIL_CUDA, (DeviceType::kCUDA),\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n#undef INSTANTIATE_CTC_LOSS_KERNEL_UTIL_CUDA\n\n}  // namespace oneflow\n"
  },
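Both launches above pass `thread_num = batch_size * kCudaThreadsNumPerBlock`. Assuming `RUN_CUDA_KERNEL` derives the grid size as ceil(thread_num / block size), the conventional pattern (possibly capped at a maximum block count), this dedicates one block per batch element: `blockIdx.x` strides over `b`, `threadIdx.x` strides over the `2 * target_length + 1` lattice positions, and `__syncthreads()` separates consecutive time steps so row `t` only reads a fully written row `t - 1` (or `t + 1` in the backward pass). A toy computation of that configuration follows; the block-size constant is an illustrative value, not necessarily OneFlow's.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int32_t kCudaThreadsNumPerBlock = 512;  // illustrative value
  const int64_t batch_size = 16;
  const int64_t thread_num = batch_size * kCudaThreadsNumPerBlock;
  // Ceil-divide, as a launch helper typically does.
  const int64_t block_num =
      (thread_num + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock;
  std::printf("blocks=%lld threads/block=%d\n", (long long)block_num,
              kCudaThreadsNumPerBlock);
  // block_num == batch_size: each block owns exactly one CTC lattice.
  return 0;
}
```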
  {
    "path": "oneflow/user/kernels/ctc_loss_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_CTC_LOSS_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_CTC_LOSS_KERNEL_UTIL_H_\n\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, typename TARGET, typename IDX>\nstruct CtcLossKernelUtil final {\n  static void CtcLossForward(ep::Stream* stream, const T* log_probs_ptr, const TARGET* targets_ptr,\n                             const IDX* input_lengths_ptr, const IDX* target_lengths_ptr,\n                             T* alpha_ptr, T* loss_ptr,\n                             NdIndexOffsetHelper<int64_t, 3>& input_helper,\n                             NdIndexOffsetHelper<int64_t, 3>& alpha_helper,\n                             const int64_t batch_size, const int64_t max_input_length,\n                             const int64_t max_target_length, const int64_t blank,\n                             const int32_t targets_ndim);\n\n  static void CtcLossBackward(ep::Stream* stream, const T* grad_out_ptr, const T* loss_ptr,\n                              const T* alpha_ptr, const T* log_probs_ptr, const TARGET* targets_ptr,\n                              const IDX* input_lengths_ptr, const IDX* target_lengths_ptr,\n                              T* beta_ptr, T* grad_ptr,\n                              NdIndexOffsetHelper<int64_t, 3>& input_helper,\n                              NdIndexOffsetHelper<int64_t, 3>& beta_helper,\n                              const int64_t batch_size, const int64_t max_input_length,\n                              const int64_t max_target_length, const int64_t num_labels,\n                              const int64_t blank, const bool zero_infinity,\n                              const int32_t targets_ndim);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_CTC_LOSS_KERNEL_UTIL_H_\n"
  },
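This header only declares the device-templated utility struct; ctc_loss_kernel_util.cpp and ctc_loss_kernel_util.cu each define a specialization (kCPU and kCUDA respectively) and explicitly instantiate it for every supported dtype combination via `OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE`, so ctc_loss_kernel.cpp can call `CtcLossKernelUtil<...>::CtcLossForward` against the declaration alone. A toy sketch of that declare-here, specialize-per-device pattern; all names and types below are invented for illustration.

```cpp
#include <cstdio>

enum class DeviceType { kCPU, kCUDA };

// Primary template: declared in the "header", never defined.
template<DeviceType device, typename T>
struct LossUtil;

// CPU specialization, as a .cpp translation unit would define it.
template<typename T>
struct LossUtil<DeviceType::kCPU, T> {
  static void Forward(const T* in, T* out, int n) {
    for (int i = 0; i < n; ++i) { out[i] = -in[i]; }  // stand-in for the real DP
  }
};

// Explicit instantiations, analogous to the OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
// expansion at the bottom of the .cpp/.cu files.
template struct LossUtil<DeviceType::kCPU, float>;
template struct LossUtil<DeviceType::kCPU, double>;

int main() {
  float in[2] = {1.f, 2.f}, out[2];
  LossUtil<DeviceType::kCPU, float>::Forward(in, out, 2);
  std::printf("%f %f\n", out[0], out[1]);
  return 0;
}
```

In the real files, registering a kernel for a (device, dtype) combination that no translation unit explicitly instantiated fails at link time, which is exactly the safety the instantiation macros provide.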
  {
    "path": "oneflow/user/kernels/cublas_bias_add_relu_matmul_grad_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/user/kernels/cublas_fused_mlp_util.cuh\"\n// CUBLAS_AUX_EPILOGUE only support in cuda11.4 or higher version, in cuda11.4 it need static link.\n#if CUDA_VERSION >= 11060\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nclass CublasBiasAddReluMatmulGradKernel final : public user_op::OpKernel,\n                                                public user_op::CudaGraphSupport {\n public:\n  CublasBiasAddReluMatmulGradKernel() = default;\n  ~CublasBiasAddReluMatmulGradKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateCublasFusedMLPKernelCache();\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const user_op::Tensor* aux = ctx->Tensor4ArgNameAndIndex(\"aux\", 0);\n    user_op::Tensor* d_bias = ctx->Tensor4ArgNameAndIndex(\"d_bias\", 0);\n    user_op::Tensor* d_grad = ctx->Tensor4ArgNameAndIndex(\"d_grad\", 0);\n    const auto* matmul_grad_cache =\n        CHECK_NOTNULL(dynamic_cast<const CublasFusedMLPKernelCache*>(cache));\n    auto* cuda_stream = ctx->stream()->As<ep::CudaStream>();\n\n    const DataType data_type = dy->data_type();\n    const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type);\n    const cudaDataType_t cuda_data_type = GetCudaDataType(data_type);\n    size_t cublas_m = 0, cublas_n = 0, cublas_k = 0;\n    int64_t cublas_lda = 0, cublas_ldb = 0, cublas_ldc = 0;\n\n    const double alpha = ctx->Attr<double>(\"alpha\");\n    const auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype);\n    const double beta = 0.0;\n    const auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype);\n\n    // currently only support 2D matmul.\n    DimVector dy_shape(2);\n    dy->shape_view().ToDimVector(&dy_shape);\n    DimVector weight_shape(2);\n    weight->shape_view().ToDimVector(&weight_shape);\n    cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DRELU_BGRAD;\n\n    InferMatmulCublasMNK(dy_shape, weight_shape,\n                         /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                         /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, &cublas_n,\n                         &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc);\n\n    SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/true,\n                  /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                  /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, d_bias->dptr(),\n     
             aux->dptr(), cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc);\n    /*\n    a = dy, b = weight\n    cublas_a=weight, cublas_b=dy\n    */\n    OF_CUBLAS_CHECK(\n        cublasLtMatmul(cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc,\n                       &sp_alpha, weight->dptr(), matmul_grad_cache->cublas_a_desc, dy->dptr(),\n                       matmul_grad_cache->cublas_b_desc, &sp_beta, d_grad->mut_dptr(),\n                       matmul_grad_cache->cublas_c_desc, d_grad->mut_dptr(),\n                       matmul_grad_cache->cublas_c_desc, nullptr, cuda_stream->cublas_workspace(),\n                       cuda_stream->cublas_workspace_size(), cuda_stream->cuda_stream()));\n  };\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUBLAS_BIAS_ADD_RELU_MATMUL_GRAD_KERNEL(dtype)        \\\n  REGISTER_USER_KERNEL(\"cublas_bias_add_relu_matmul_grad\")             \\\n      .SetCreateFn<CublasBiasAddReluMatmulGradKernel<dtype>>()         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"weight\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CUBLAS_BIAS_ADD_RELU_MATMUL_GRAD_KERNEL(float)\nREGISTER_CUBLAS_BIAS_ADD_RELU_MATMUL_GRAD_KERNEL(double)\nREGISTER_CUBLAS_BIAS_ADD_RELU_MATMUL_GRAD_KERNEL(half)\n\n}  // namespace\n\n}  // namespace oneflow\n\n#endif  // CUDA_VERSION >= 11060\n"
  },
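The comment in the kernel above ("a = dy, b = weight; cublas_a=weight, cublas_b=dy") reflects the standard trick for driving a column-major BLAS with row-major tensors: a row-major matrix reinterpreted as column-major is its transpose, so asking the library for C^T = B^T * A^T (by swapping the operand order) writes exactly the row-major C = A * B with no physical transpose. A naive sketch demonstrating the swap; no cuBLAS is involved and `GemmColMajor` is a stand-in reference implementation.

```cpp
#include <cassert>
#include <vector>

// Column-major GEMM reference: C(m x n) = A(m x k) * B(k x n), leading
// dimensions equal to the row counts.
void GemmColMajor(int m, int n, int k, const float* A, const float* B, float* C) {
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) { acc += A[p * m + i] * B[j * k + p]; }
      C[j * m + i] = acc;
    }
  }
}

int main() {
  // Row-major A (2x3) and B (3x2); the row-major product C = A * B is 2x2.
  std::vector<float> A = {1, 2, 3, 4, 5, 6};
  std::vector<float> B = {7, 8, 9, 10, 11, 12};
  std::vector<float> C(4);
  // Swapped operands: the column-major routine sees B^T (2x3) and A^T (3x2)
  // and computes C^T column-major, which is bit-identical to row-major C.
  GemmColMajor(/*m=*/2, /*n=*/2, /*k=*/3, B.data(), A.data(), C.data());
  assert(C[0] == 58 && C[1] == 64 && C[2] == 139 && C[3] == 154);
  return 0;
}
```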
  {
    "path": "oneflow/user/kernels/cublas_fused_matmul_bias_add_grad.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/user/kernels/cublas_fused_mlp_util.cuh\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/cuda/cuda_device.h\"\n// CUBLASLT_EPILOGUE_BGRADB only support in cuda11.4.2 or higher version.\n// TODO(zhengzekang): In cuda11.6 version, CUBLASLT_EPILOGUE_BGRADB may occur illegal memory access\n// error in some shapes.\n#if CUDA_VERSION >= 11060\n\nnamespace oneflow {\n\nnamespace {\n\ncudaDataType_t GetGemmComputeType(cudaDataType_t data_type) {\n  switch (data_type) {\n    case CUDA_R_32F: return CUDA_R_32F;\n    case CUDA_R_64F: return CUDA_R_64F;\n    case CUDA_R_16F: return CUDA_R_32F;\n#if CUDA_VERSION >= 11000\n    case CUDA_R_16BF: return CUDA_R_32F;\n#endif  // CUDA_VERSION >= 11000\n    default: UNIMPLEMENTED(); return CUDA_R_32F;\n  }\n}\n\ntemplate<typename T>\nclass CublasMatmulBiasAddGradKernel final : public user_op::OpKernel,\n                                            public user_op::CudaGraphSupport {\n public:\n  CublasMatmulBiasAddGradKernel() = default;\n  ~CublasMatmulBiasAddGradKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateCublasFusedMLPKernelCache();\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* w_grad = ctx->Tensor4ArgNameAndIndex(\"w_grad\", 0);\n    user_op::Tensor* b_grad = ctx->Tensor4ArgNameAndIndex(\"b_grad\", 0);\n    const auto* matmul_grad_cache =\n        CHECK_NOTNULL(dynamic_cast<const CublasFusedMLPKernelCache*>(cache));\n    auto* cuda_stream = ctx->stream()->As<ep::CudaStream>();\n\n    const DataType data_type = dy->data_type();\n    const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type);\n    const cudaDataType_t cuda_data_type = GetCudaDataType(data_type);\n    size_t cublas_m = 0, cublas_n = 0, cublas_k = 0;\n    int64_t cublas_lda = 0, cublas_ldb = 0, cublas_ldc = 0;\n    const double alpha = 1.0;\n    const auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype);\n    const double beta = 0.0;\n    const auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype);\n\n    // currently only support 2D matmul.\n    DimVector dy_shape(2);\n    dy->shape_view().ToDimVector(&dy_shape);\n    DimVector x_shape(2);\n    x->shape_view().ToDimVector(&x_shape);\n    cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BGRADB;\n\n    InferMatmulCublasMNK(dy_shape, x_shape,\n                         /*transpose_a=*/ep::primitive::BlasTransposeType::T,\n                       
  /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, &cublas_n,\n                         &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc);\n    if (cublas_k != 1) {\n      SetCublasAttr(\n          matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false,\n          /*transpose_a=*/ep::primitive::BlasTransposeType::T,\n          /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, b_grad->mut_dptr(),\n          /*aux_ptr=*/nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc);\n\n      /*\n      a = dy, b = x\n      cublas_a=x, cublas_b=dy\n      */\n      OF_CUBLAS_CHECK(cublasLtMatmul(\n          cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, x->dptr(),\n          matmul_grad_cache->cublas_a_desc, dy->dptr(), matmul_grad_cache->cublas_b_desc, &sp_beta,\n          w_grad->mut_dptr(), matmul_grad_cache->cublas_c_desc, w_grad->mut_dptr(),\n          matmul_grad_cache->cublas_c_desc, nullptr, cuda_stream->cublas_workspace(),\n          cuda_stream->cublas_workspace_size(), cuda_stream->cuda_stream()));\n    } else {\n// Cause cublasLtmatmul get wrong bias grad in cublas_k == 1.\n#if CUDA_VERSION >= 11000\n      cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;\n#else\n      cublasGemmAlgo_t algo =\n          (data_type == DataType::kFloat16) ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;\n#endif\n\n      cudaDataType_t gemm_compute_type = GetGemmComputeType(cuda_data_type);\n      std::unique_ptr<ep::primitive::Memcpy> memcpy_primitive =\n          ep::primitive::NewPrimitive<ep::primitive::MemcpyFactory>(\n              ctx->stream()->device_type(), ep::primitive::MemcpyKind::kDtoD);\n      CHECK(memcpy_primitive);\n      memcpy_primitive->Launch(ctx->stream(), b_grad->mut_dptr(), dy->dptr(), cublas_n * sizeof(T));\n      OF_CUBLAS_CHECK(cublasGemmEx(\n          cuda_stream->cublas_handle(), CUBLAS_OP_N, CUBLAS_OP_T, cublas_m, cublas_n, cublas_k,\n          &sp_alpha, x->dptr(), cuda_data_type, cublas_lda, dy->dptr(), cuda_data_type, cublas_ldb,\n          &sp_beta, w_grad->mut_dptr(), cuda_data_type, cublas_ldc, gemm_compute_type, algo));\n    }\n  };\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUBLAS_MATMUL_BIAS_ADD_GRAD_KERNEL(dtype)             \\\n  REGISTER_USER_KERNEL(\"cublas_matmul_bias_add_grad\")                  \\\n      .SetCreateFn<CublasMatmulBiasAddGradKernel<dtype>>()             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CUBLAS_MATMUL_BIAS_ADD_GRAD_KERNEL(float)\nREGISTER_CUBLAS_MATMUL_BIAS_ADD_GRAD_KERNEL(double)\nREGISTER_CUBLAS_MATMUL_BIAS_ADD_GRAD_KERNEL(half)\n\n}  // namespace\n\n}  // namespace oneflow\n\n#endif  // CUDA_VERSION >= 11060\n"
  },
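The CUBLASLT_EPILOGUE_BGRADB epilogue used above fuses the bias gradient into the weight-gradient GEMM. Mathematically the bias gradient is just dy reduced over the batch rows, and assuming cublas_k maps to the batch dimension of dy (as the transpose_a=T call to InferMatmulCublasMNK suggests), a batch of one makes the reduction an identity, which is why the cublas_k == 1 branch can fall back to a memcpy of dy plus a plain cublasGemmEx. A naive reference sketch; `BiasGrad` is illustrative, not OneFlow code.

```cpp
#include <cassert>
#include <vector>

// dy is row-major (batch x out); the bias gradient is its column sum.
std::vector<float> BiasGrad(const std::vector<float>& dy, int batch, int out) {
  std::vector<float> d_bias(out, 0.f);
  for (int b = 0; b < batch; ++b) {
    for (int j = 0; j < out; ++j) { d_bias[j] += dy[b * out + j]; }
  }
  return d_bias;
}

int main() {
  // batch == 1: the reduction degenerates to a copy, i.e. the memcpy fallback.
  std::vector<float> dy = {1.f, 2.f, 3.f};
  assert(BiasGrad(dy, /*batch=*/1, /*out=*/3) == dy);
  return 0;
}
```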
  {
    "path": "oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/user/kernels/cublas_fused_mlp_util.cuh\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/job/eager_nccl_comm_manager.h\"\n// CUBLAS_AUX_EPILOGUE only support in cuda11.4 or higher version, in cuda11.4 it need static link.\n#if CUDA_VERSION >= 11060\n\nnamespace oneflow {\n\nnamespace {\n\nstruct Comm {\n  Comm(ncclComm_t comm) : comm(comm) {}\n  ncclComm_t comm;\n};\n\nclass MatmulGradKernelState final : public user_op::OpKernelState {\n public:\n  MatmulGradKernelState(user_op::KernelInitContext* ctx)\n      : if_need_comm_(false), stream_name_(EagerNcclCommMgr::kDefaultStreamName) {\n    OF_CUDA_CHECK(cudaStreamCreate(&cuda_stream_));\n    OF_CUDA_CHECK(cudaStreamCreate(&allreduce_stream_));\n    OF_CUBLAS_CHECK(cublasLtCreate(&cublas_lt_handle_));\n    workspace_size_ =\n        ParseIntegerFromEnv(\"ONEFLOW_EP_CUDA_CUBLAS_WORKSPACE_SIZE_MB\", kDefaultWorkspaceSizeMb)\n        * 1024 * 1024;\n    OF_CUDA_CHECK(cudaMalloc(&workspace_, workspace_size_));\n    if (ctx->parallel_ctx().parallel_num() > 1) {\n      parallel_conf_ = ctx->parallel_desc().parallel_conf();\n    }\n  }\n  ~MatmulGradKernelState() {\n    OF_CUDA_CHECK(cudaStreamSynchronize(cuda_stream_));\n    OF_CUBLAS_CHECK(cublasLtDestroy(cublas_lt_handle_));\n    OF_CUDA_CHECK(cudaStreamDestroy(cuda_stream_));\n    OF_CUDA_CHECK(cudaStreamSynchronize(allreduce_stream_));\n    OF_CUDA_CHECK(cudaStreamDestroy(allreduce_stream_));\n    OF_CUDA_CHECK(cudaFree(workspace_));\n  }\n  cudaStream_t grad_cuda_stream() const { return cuda_stream_; }\n  cudaStream_t allreduce_stream() const { return allreduce_stream_; }\n  cublasLtHandle_t cublas_lt_handle() const { return cublas_lt_handle_; }\n  size_t cublas_workspace_size() const { return workspace_size_; }\n  void* cublas_workspace() const { return workspace_; }\n\n  bool IfCommCreate() const {\n    if (!comm_) { return false; }\n    return true;\n  }\n\n  bool IfNeedComm() const { return if_need_comm_; }\n\n  ncclComm_t comm() { return GetOrCreate().comm; }\n\n  const Comm& GetOrCreate() {\n    if (!comm_) { InitCommMgr(); }\n    return *comm_;\n  }\n\n  void InitNeedComm(user_op::KernelInitContext* ctx) {\n    if_need_comm_ = false;\n    if (ctx->parallel_ctx().parallel_num() > 1) {\n      const int64_t d_weights_size = ctx->output_size(\"d_weights\");\n      if (ctx->SbpParallel4ArgNameAndIndex(\"d_weights\", 0).has_broadcast_parallel()) {\n        for (int i = 0; i < d_weights_size; i++) {\n          CHECK(ctx->SbpParallel4ArgNameAndIndex(\"d_weights\", i).has_broadcast_parallel())\n              << \"All d_weight's SBP should be Broadcast. 
\";\n          CHECK(ctx->SbpParallel4ArgNameAndIndex(\"d_biases\", i).has_broadcast_parallel())\n              << \"All d_bias's SBP should be Broadcast. \";\n        }\n        if (ctx->SbpParallel4ArgNameAndIndex(\"dy\", 0).has_split_parallel()) {\n          if_need_comm_ = true;\n        }\n      }\n    }\n  }\n\n  void InitCommMgr() {\n    std::set<std::pair<int64_t, int64_t>> device_set;\n    const ParallelDesc parallel_desc(parallel_conf_);\n    for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) {\n      int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));\n      int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));\n      device_set.emplace(std::make_pair(machine_id, device_id));\n    }\n    EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    ncclComm_t comm;\n    comm =\n        comm_mgr->As<EagerNcclCommMgr>()->GetCommForDeviceAndStreamName(device_set, stream_name_);\n    comm_.reset(new Comm(comm));\n  }\n\n private:\n  cudaStream_t cuda_stream_{};\n  cudaStream_t allreduce_stream_{};\n  cublasLtHandle_t cublas_lt_handle_{};\n  void* workspace_{};\n  size_t workspace_size_;\n  std::string stream_name_;\n  std::unique_ptr<Comm> comm_;\n  bool if_need_comm_;\n  ParallelConf parallel_conf_;\n};\n\ntemplate<typename T>\nclass CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  CublasFusedMLPGradKernel() {\n    OF_CUDA_CHECK(cudaEventCreate(&main_stream_event_));\n    OF_CUDA_CHECK(cudaEventCreate(&async_weight_grad_event_));\n    OF_CUDA_CHECK(cudaEventCreate(&dweight_event_));\n    OF_CUDA_CHECK(cudaEventCreate(&allreduce_event_));\n  };\n  ~CublasFusedMLPGradKernel() override {\n    OF_CUDA_CHECK(cudaEventDestroy(main_stream_event_));\n    OF_CUDA_CHECK(cudaEventDestroy(async_weight_grad_event_));\n    OF_CUDA_CHECK(cudaEventDestroy(dweight_event_));\n    OF_CUDA_CHECK(cudaEventDestroy(allreduce_event_));\n  };\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateCublasFusedMLPKernelCache();\n  }\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    std::shared_ptr<MatmulGradKernelState> kernel_state =\n        std::make_shared<MatmulGradKernelState>(ctx);\n    kernel_state->InitNeedComm(ctx);\n    return kernel_state;\n  }\n\n private:\n  cudaEvent_t main_stream_event_;\n  cudaEvent_t async_weight_grad_event_;\n  cudaEvent_t dweight_event_;\n  cudaEvent_t allreduce_event_;\n\n  bool IsReadyForCapture(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n                         const user_op::OpKernelCache* cache) const override {\n    auto* kernel_state = dynamic_cast<MatmulGradKernelState*>(state);\n    if (kernel_state->IfNeedComm()) {\n      return kernel_state->IfCommCreate();\n    } else {\n      return true;\n    }\n  }\n\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    int64_t tmp_buf_elem_cnt = tmp_buffer->shape_view().elem_cnt();\n    const int64_t 
weight_num = ctx->input_size(\"weights\");\n    user_op::Tensor* d_x = ctx->Tensor4ArgNameAndIndex(\"d_x\", 0);\n    const std::vector<float> alpha_list = ctx->Attr<std::vector<float>>(\"alpha_list\");\n\n    auto* kernel_state = dynamic_cast<MatmulGradKernelState*>(state);\n    const auto* matmul_grad_cache =\n        CHECK_NOTNULL(dynamic_cast<const CublasFusedMLPKernelCache*>(cache));\n\n    ncclComm_t comm{};\n    bool if_need_comm = kernel_state->IfNeedComm();\n\n    if (if_need_comm) { comm = kernel_state->comm(); }\n\n    void* dy_tmp_buf = tmp_buffer->mut_dptr();\n    size_t tmp_buf_offset = 0;\n    auto* cuda_stream = ctx->stream()->As<ep::CudaStream>();\n\n    const DataType data_type = dy->data_type();\n    const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type);\n    const cudaDataType_t cuda_data_type = GetCudaDataType(data_type);\n    size_t cublas_m = 0, cublas_n = 0, cublas_k = 0;\n    int64_t cublas_lda = 0, cublas_ldb = 0, cublas_ldc = 0;\n\n    const double alpha_one = 1.0;\n    auto sp_alpha_one = GetCublasScalarParameter(alpha_one, cublas_compute_dtype);\n    double alpha = 1.0;\n    auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype);\n    double beta = 0.0;\n    auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype);\n\n    cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;\n\n    // currently only support 2D matmul.\n    DimVector weight_shape(2);\n    DimVector hidden_shape(2);\n    DimVector dy_shape(2);\n    dy->shape_view().ToDimVector(&dy_shape);\n    const void* dgrad_buf = dy->dptr();\n\n    const int64_t batch_size = dy->shape_view().At(0);\n    const void* ones = nullptr;\n    ep::CudaDevice* cuda_device = dynamic_cast<ep::CudaDevice*>(ctx->stream()->device());\n    CHECK_NOTNULL(cuda_device);\n    ones = cuda_device->GetConstOnes(dy->data_type(), batch_size);\n    if (ones == nullptr) {\n      std::unique_ptr<ep::primitive::Fill> fill =\n          ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->stream()->device_type(),\n                                                                  data_type);\n      CHECK(fill);\n      fill->Launch(ctx->stream(), tmp_buffer->mut_dptr(), 1.0, batch_size);\n      ones = tmp_buffer->mut_dptr();\n      tmp_buf_offset += GetCudaAlignedSize(batch_size * sizeof(T));\n      dy_tmp_buf = reinterpret_cast<void*>(tmp_buffer->mut_dptr<char>() + tmp_buf_offset);\n    }\n\n    for (int idx = weight_num - 1; idx >= 0; idx--) {\n      const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weights\", idx);\n      weight->shape_view().ToDimVector(&weight_shape);\n      InferMatmulCublasMNK(dy_shape, weight_shape,\n                           /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                           /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m,\n                           &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc);\n      if (idx != 0) {\n        alpha = alpha_list.at(idx - 1);\n        sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype);\n        const user_op::Tensor* aux = ctx->Tensor4ArgNameAndIndex(\"cublas_aux\", idx - 1);\n        user_op::Tensor* d_bias = ctx->Tensor4ArgNameAndIndex(\"d_biases\", idx - 1);\n        epilogue = CUBLASLT_EPILOGUE_DRELU_BGRAD;\n        SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/true,\n                      /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                      
/*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue,\n                      d_bias->mut_dptr(), aux->dptr(), cublas_m, cublas_n, cublas_k, cublas_lda,\n                      cublas_ldb, cublas_ldc);\n        /*\n        a = dy, b = weight\n        cublas_a=weight, cublas_b=dy\n        */\n        OF_CUDA_CHECK(cudaEventRecord(main_stream_event_, cuda_stream->cuda_stream()));\n        OF_CUBLAS_CHECK(cublasLtMatmul(\n            cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha,\n            weight->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf,\n            matmul_grad_cache->cublas_b_desc, &sp_beta, dy_tmp_buf,\n            matmul_grad_cache->cublas_c_desc, dy_tmp_buf, matmul_grad_cache->cublas_c_desc, nullptr,\n            cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(),\n            cuda_stream->cuda_stream()));\n      } else {\n        epilogue = CUBLASLT_EPILOGUE_DEFAULT;\n        SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false,\n                      /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                      /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr,\n                      nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc);\n        /*\n        a = dy, b = weight\n        cublas_a=weight, cublas_b=dy\n        */\n        OF_CUDA_CHECK(cudaEventRecord(main_stream_event_, cuda_stream->cuda_stream()));\n        OF_CUBLAS_CHECK(cublasLtMatmul(\n            cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha_one,\n            weight->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf,\n            matmul_grad_cache->cublas_b_desc, &sp_beta, d_x->mut_dptr(),\n            matmul_grad_cache->cublas_c_desc, d_x->mut_dptr(), matmul_grad_cache->cublas_c_desc,\n            nullptr, cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(),\n            cuda_stream->cuda_stream()));\n      }\n\n      // step1: Get last layer's dbias.\n      if (idx == weight_num - 1) {\n        user_op::Tensor* d_last_bias = ctx->Tensor4ArgNameAndIndex(\"d_biases\", weight_num - 1);\n        DimVector ones_buf_shape(2);\n        ones_buf_shape.at(0) = 1;\n        ones_buf_shape.at(1) = batch_size;\n        epilogue = CUBLASLT_EPILOGUE_DEFAULT;\n        InferMatmulCublasMNK(ones_buf_shape, dy_shape,\n                             /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                             /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m,\n                             &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc);\n        SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false,\n                      /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                      /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr,\n                      nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc);\n        OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->grad_cuda_stream(), main_stream_event_));\n        OF_CUBLAS_CHECK(cublasLtMatmul(\n            kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha_one,\n            dgrad_buf, matmul_grad_cache->cublas_a_desc, ones, matmul_grad_cache->cublas_b_desc,\n            &sp_beta, d_last_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc,\n            d_last_bias->mut_dptr(), 
matmul_grad_cache->cublas_c_desc, nullptr,\n            kernel_state->cublas_workspace(), kernel_state->cublas_workspace_size(),\n            kernel_state->grad_cuda_stream()));\n      }\n\n      user_op::Tensor* d_weight = ctx->Tensor4ArgNameAndIndex(\"d_weights\", idx);\n      epilogue = CUBLASLT_EPILOGUE_DEFAULT;\n      if (idx != 0) {\n        const user_op::Tensor* hidden = ctx->Tensor4ArgNameAndIndex(\"hidden\", idx - 1);\n        hidden->shape_view().ToDimVector(&hidden_shape);\n        InferMatmulCublasMNK(dy_shape, hidden_shape,\n                             /*transpose_a=*/ep::primitive::BlasTransposeType::T,\n                             /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m,\n                             &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc);\n\n        SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false,\n                      /*transpose_a=*/ep::primitive::BlasTransposeType::T,\n                      /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr,\n                      nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc);\n        if (idx != weight_num - 1) {\n          // If idx == weight_num - 1, the grad stream has already waited on main_stream_event_\n          // in the d_bias step above.\n          OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->grad_cuda_stream(), main_stream_event_));\n        }\n        OF_CUBLAS_CHECK(cublasLtMatmul(\n            kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha_one,\n            hidden->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf,\n            matmul_grad_cache->cublas_b_desc, &sp_beta, d_weight->mut_dptr(),\n            matmul_grad_cache->cublas_c_desc, d_weight->mut_dptr(),\n            matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(),\n            kernel_state->cublas_workspace_size(), kernel_state->grad_cuda_stream()));\n        OF_CUDA_CHECK(cudaEventRecord(dweight_event_, kernel_state->grad_cuda_stream()));\n        // Update dy shape for the next (previous-layer) iteration.\n        dy_shape.at(1) = weight_shape.at(1);\n        // The dgrad just written to dy_tmp_buf becomes the next iteration's dy.\n        dgrad_buf = dy_tmp_buf;\n        tmp_buf_offset += GetCudaAlignedSize(dy_shape.at(0) * dy_shape.at(1) * sizeof(T));\n        CHECK_LE(tmp_buf_offset, tmp_buf_elem_cnt)\n            << \"Tmp buffer offset should be <= tmp buffer elem_cnt.
\";\n        dy_tmp_buf = reinterpret_cast<void*>(tmp_buffer->mut_dptr<char>() + tmp_buf_offset);\n      } else {\n        x->shape_view().ToDimVector(&hidden_shape);\n        InferMatmulCublasMNK(dy_shape, hidden_shape,\n                             /*transpose_a=*/ep::primitive::BlasTransposeType::T,\n                             /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m,\n                             &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc);\n        SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false,\n                      /*transpose_a=*/ep::primitive::BlasTransposeType::T,\n                      /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr,\n                      nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc);\n        OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->grad_cuda_stream(), main_stream_event_));\n        OF_CUBLAS_CHECK(cublasLtMatmul(\n            kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha_one,\n            x->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf,\n            matmul_grad_cache->cublas_b_desc, &sp_beta, d_weight->mut_dptr(),\n            matmul_grad_cache->cublas_c_desc, d_weight->mut_dptr(),\n            matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(),\n            kernel_state->cublas_workspace_size(), kernel_state->grad_cuda_stream()));\n        OF_CUDA_CHECK(cudaEventRecord(dweight_event_, kernel_state->grad_cuda_stream()));\n      }\n\n      if (if_need_comm) {\n        // Do Allreduce for d_bias and d_weight.\n        // Here we wait on the wgrad event, then open an ncclGroup to allreduce d_bias and\n        // d_weight.\n        OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event_));\n        OF_NCCL_CHECK(ncclGroupStart());\n        user_op::Tensor* allreduce_d_bias = ctx->Tensor4ArgNameAndIndex(\"d_biases\", idx);\n        OF_NCCL_CHECK(ncclAllReduce(allreduce_d_bias->mut_dptr(), allreduce_d_bias->mut_dptr(),\n                                    allreduce_d_bias->shape_view().elem_cnt(),\n                                    GetNcclDataType(allreduce_d_bias->data_type()),\n                                    ncclRedOp_t::ncclSum, comm, kernel_state->allreduce_stream()));\n        OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(),\n                                    d_weight->shape_view().elem_cnt(),\n                                    GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum,\n                                    comm, kernel_state->allreduce_stream()));\n        OF_NCCL_CHECK(ncclGroupEnd());\n        if (idx == 0) {\n          // Record the allreduce event so the kernel can sync on it before it finishes.\n          OF_CUDA_CHECK(cudaEventRecord(allreduce_event_, kernel_state->allreduce_stream()));\n        }\n      }\n    }\n\n    if (if_need_comm) {\n      OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), allreduce_event_));\n    } else {\n      OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), dweight_event_));\n    }\n  };\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(dtype)                                             \\\n  REGISTER_USER_KERNEL(\"cublas_fused_mlp_grad\")                                                  \\\n      .SetCreateFn<CublasFusedMLPGradKernel<dtype>>()                                            
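/* tmp buffer (see SetInferTmpSizeFn below): ones(batch) for the last bias grad, plus one intermediate dy buffer per layer n-1..1 */ 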
\\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                           \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value))           \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                        \\\n        const int64_t weight_num = ctx->input_size(\"weights\");                                   \\\n        const Shape& dy_shape = ctx->InputShape(\"dy\", 0);                                        \\\n        int64_t m = dy_shape.At(0);                                                              \\\n        int64_t k = dy_shape.At(1);                                                              \\\n        int64_t tmp_buffer_size = 0;                                                             \\\n        tmp_buffer_size += GetCudaAlignedSize(m * sizeof(dtype)); /*For last layer's bias grad*/ \\\n        for (int idx = weight_num - 1; idx > 0; idx--) {                                         \\\n          const Shape& weight_shape = ctx->InputShape(\"weights\", idx);                           \\\n          k = weight_shape.At(1);                                                                \\\n          tmp_buffer_size += GetCudaAlignedSize(m * k * sizeof(dtype));                          \\\n        }                                                                                        \\\n        return tmp_buffer_size;                                                                  \\\n      });\n\nREGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(float)\nREGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(double)\nREGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(half)\n\nREGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT(\"cublas_fused_mlp_grad\");\n\n}  // namespace\n\n}  // namespace oneflow\n\n#endif  // CUDA_VERSION >= 11060\n"
  },
  {
    "path": "oneflow/user/kernels/cublas_fused_mlp_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/user/kernels/cublas_fused_mlp_util.cuh\"\n// CUBLAS_AUX_EPILOGUE is only supported on CUDA 11.4 or higher; on CUDA 11.4 it needs static linking.\n#if CUDA_VERSION >= 11060\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nclass CublasFusedMLPKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  CublasFusedMLPKernel() = default;\n  ~CublasFusedMLPKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateCublasFusedMLPKernelCache();\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    /*\n    Fused Dense + Activation layer. Assume we have two layers:\n    A: (m, k)\n    B: (n, k) needs transpose\n    C: (j, n) needs transpose\n    tmp: A matmul B(transpose), its shape is (m, n)\n    out: tmp matmul C(transpose), its shape is (m, j)\n    */\n    const int32_t weight_size = ctx->input_size(\"weights\");\n    const int32_t bias_size = ctx->input_size(\"biases\");\n    CHECK_EQ(weight_size, bias_size) << \"The number of weights and biases should be equal.
\";\n    auto* cuda_stream = ctx->stream()->As<ep::CudaStream>();\n    const auto* matmul_cache = CHECK_NOTNULL(dynamic_cast<const CublasFusedMLPKernelCache*>(cache));\n\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    bool skip_final_activation = ctx->Attr<bool>(\"skip_final_activation\");\n\n    const DataType data_type = out->data_type();\n    const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type);\n    const cudaDataType_t cuda_data_type = GetCudaDataType(data_type);\n    size_t cublas_m = 0, cublas_n = 0, cublas_k = 0;\n    int64_t cublas_lda = 0, cublas_ldb = 0, cublas_ldc = 0;\n\n    const double alpha = 1.0;\n    const auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype);\n    const double beta = 0.0;\n    const auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype);\n\n    // Currently only support 2D matmul.\n    DimVector in_shape(2);\n    x->shape_view().ToDimVector(&in_shape);\n\n    DimVector weight_shape(2);\n\n    const void* in_buf_ptr = x->dptr();\n    for (int idx = 0; idx < weight_size; idx++) {\n      const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weights\", idx);\n      const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex(\"biases\", idx);\n      user_op::Tensor* cublas_aux = ctx->Tensor4ArgNameAndIndex(\"cublas_aux\", idx);\n\n      int64_t out_feature = weight->shape_view().At(0);\n      weight->shape_view().ToDimVector(&weight_shape);\n\n      InferMatmulCublasMNK(in_shape, weight_shape,\n                           /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                           /*transpose_b=*/ep::primitive::BlasTransposeType::T, &cublas_m,\n                           &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc);\n\n      cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_RELU_AUX_BIAS;\n      bool need_aux = true;\n      void* y_ptr = nullptr;\n\n      if (idx == weight_size - 1) {\n        y_ptr = ctx->Tensor4ArgNameAndIndex(\"out\", 0)->mut_dptr();\n        if (skip_final_activation) {\n          epilogue = CUBLASLT_EPILOGUE_BIAS;\n          need_aux = false;\n        }\n      } else {\n        y_ptr = ctx->Tensor4ArgNameAndIndex(\"hidden\", idx)->mut_dptr();\n      }\n      SetCublasAttr(matmul_cache, cublas_compute_dtype, cuda_data_type, need_aux,\n                    /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                    /*transpose_b=*/ep::primitive::BlasTransposeType::T, epilogue, bias->dptr(),\n                    cublas_aux->dptr(), cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb,\n                    cublas_ldc);\n\n      OF_CUBLAS_CHECK(cublasLtMatmul(\n          cuda_stream->cublas_lt_handle(), matmul_cache->operation_desc, &sp_alpha, weight->dptr(),\n          matmul_cache->cublas_a_desc, in_buf_ptr, matmul_cache->cublas_b_desc, &sp_beta, y_ptr,\n          matmul_cache->cublas_c_desc, y_ptr, matmul_cache->cublas_c_desc, nullptr,\n          cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(),\n          cuda_stream->cuda_stream()));\n\n      // Set hidden_layer ptr as next layer's input.\n      in_buf_ptr = y_ptr;\n      // Set hidden_layer shape as next layer's input shape.\n      in_shape.at(1) = out_feature;\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUBLAS_FUSED_MLP_KERNEL_GPU(cpp_type, data_type)      \\\n  
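/* unlike the grad kernel above, the forward kernel also registers nv_bfloat16 */ 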
REGISTER_USER_KERNEL(\"cublas_fused_mlp\")                             \\\n      .SetCreateFn<CublasFusedMLPKernel<cpp_type>>()                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"out\", 0) == data_type));\n\nREGISTER_CUBLAS_FUSED_MLP_KERNEL_GPU(double, DataType::kDouble);\nREGISTER_CUBLAS_FUSED_MLP_KERNEL_GPU(float, DataType::kFloat);\nREGISTER_CUBLAS_FUSED_MLP_KERNEL_GPU(half, DataType::kFloat16);\nREGISTER_CUBLAS_FUSED_MLP_KERNEL_GPU(nv_bfloat16, DataType::kBFloat16);\n\n}  // namespace\n\n}  // namespace oneflow\n\n#endif  // CUDA_VERSION >= 11060\n"
  },
  {
    "path": "oneflow/user/kernels/cublas_fused_mlp_util.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#if defined(__CUDACC__)\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <cuda.h>\n// CUBLAS_AUX_EPILOGUE is only supported on CUDA 11.4 or higher; on CUDA 11.4 it needs static linking.\n#if CUDA_VERSION >= 11020\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr int32_t kAuxReluLdAlignRequirement = 128;\nconstexpr size_t kDefaultWorkspaceSizeMb = 4;  // 4 MB\n\nlong AlignReluAuxLd(long aux_ld) {\n  /*\n  ReLU bit-mask matrix leading dimension in elements.\n  Must be divisible by 128 and be no less than the number of rows in the output matrix.\n  */\n  long old_aux_ld = aux_ld;\n  return ((old_aux_ld + kAuxReluLdAlignRequirement - 1) / kAuxReluLdAlignRequirement)\n         * kAuxReluLdAlignRequirement;\n}\n\nclass CublasFusedMLPKernelCache final : public user_op::OpKernelCache {\n public:\n  CublasFusedMLPKernelCache() {\n    // Placeholder values; the real attributes are set per call via SetCublasAttr.\n    OF_CUBLAS_CHECK(cublasLtMatmulDescCreate(&operation_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F));\n    OF_CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&cublas_a_desc, CUDA_R_32F, 1, 1, 1));\n    OF_CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&cublas_b_desc, CUDA_R_32F, 1, 1, 1));\n    OF_CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&cublas_c_desc, CUDA_R_32F, 1, 1, 1));\n    OF_CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&cublas_preference));\n  }\n  ~CublasFusedMLPKernelCache() override {\n    OF_CUBLAS_CHECK(cublasLtMatmulDescDestroy(operation_desc));\n    OF_CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(cublas_a_desc));\n    OF_CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(cublas_b_desc));\n    OF_CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(cublas_c_desc));\n    OF_CUBLAS_CHECK(cublasLtMatmulPreferenceDestroy(cublas_preference));\n  }\n  cublasLtMatmulDesc_t operation_desc;\n  cublasLtMatrixLayout_t cublas_a_desc;\n  cublasLtMatrixLayout_t cublas_b_desc;\n  cublasLtMatrixLayout_t cublas_c_desc;\n  cublasLtMatmulPreference_t cublas_preference;\n};\n\nstd::shared_ptr<CublasFusedMLPKernelCache> CreateCublasFusedMLPKernelCache() {\n  std::shared_ptr<CublasFusedMLPKernelCache> cache(new CublasFusedMLPKernelCache());\n  return cache;\n}\n\nOptional<cudaDataType_t> OptCudaDataType(DataType data_type) {\n  switch (data_type) {\n    case kFloat: return CUDA_R_32F;\n    case kDouble: return CUDA_R_64F;\n    case kFloat16: return CUDA_R_16F;\n    case kBFloat16: return CUDA_R_16BF;\n    default: return NullOpt;\n  }\n}\n\ncudaDataType_t GetCudaDataType(DataType data_type) {\n  auto cuda_data_type = OptCudaDataType(data_type);\n  CHECK(cuda_data_type.has_value());\n  return cuda_data_type.value_or(CUDA_R_32F);\n}\n\ncublasComputeType_t GetComputeType(DataType data_type) {\n  switch (data_type) {\n    case kFloat:\n      // TF32 is enabled by default for float GEMMs; set\n      // ONEFLOW_EP_CUDA_ENABLE_TF32_EXECUTION=0 to force strict FP32.\n      if 
(ParseBooleanFromEnv(\"ONEFLOW_EP_CUDA_ENABLE_TF32_EXECUTION\", true)) {\n        return CUBLAS_COMPUTE_32F_FAST_TF32;\n      } else {\n        return CUBLAS_COMPUTE_32F;\n      }\n    case kDouble: return CUBLAS_COMPUTE_64F;\n    case kFloat16: {\n      const bool allow_half_accumulation =\n          ParseBooleanFromEnv(\"ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION\", false);\n      if (allow_half_accumulation) {\n        return CUBLAS_COMPUTE_16F;\n      } else {\n        return CUBLAS_COMPUTE_32F;\n      }\n    }\n    case kBFloat16: return CUBLAS_COMPUTE_32F;\n    default: UNIMPLEMENTED(); return CUBLAS_COMPUTE_32F;\n  }\n}\n\nunion CublasScalarParameter {\n  double d;\n  float s;\n  half h;\n};\n\nCublasScalarParameter GetCublasScalarParameter(Scalar scalar, cublasComputeType_t compute_type) {\n  CublasScalarParameter sp{};\n  if (compute_type == CUBLAS_COMPUTE_64F) {\n    sp.d = scalar.Value<double>();\n  } else if (compute_type == CUBLAS_COMPUTE_32F || compute_type == CUBLAS_COMPUTE_32F_FAST_TF32) {\n    sp.s = scalar.Value<float>();\n  } else if (compute_type == CUBLAS_COMPUTE_16F) {\n    sp.h = static_cast<half>(scalar.Value<float>());\n  } else {\n    UNIMPLEMENTED();\n  }\n  return sp;\n}\n\nvoid InferMatmulCublasMNK(const DimVector& a_shape, const DimVector& b_shape,\n                          ep::primitive::BlasTransposeType transpose_a,\n                          ep::primitive::BlasTransposeType transpose_b, size_t* cublas_m,\n                          size_t* cublas_n, size_t* cublas_k, int64_t* cublas_lda,\n                          int64_t* cublas_ldb, int64_t* cublas_ldc) {\n  const int64_t num_a_axes = a_shape.size();\n  CHECK_GE(num_a_axes, 2);\n  const int64_t num_b_axes = b_shape.size();\n  CHECK_GE(num_b_axes, 2);\n  size_t m = 0, n = 0, k = 0;\n  if (transpose_a == ep::primitive::BlasTransposeType::N) {\n    m = a_shape.at(num_a_axes - 2);\n    k = a_shape.at(num_a_axes - 1);\n    *cublas_ldb = k;\n  } else if (transpose_a == ep::primitive::BlasTransposeType::T) {\n    m = a_shape.at(num_a_axes - 1);\n    k = a_shape.at(num_a_axes - 2);\n    *cublas_ldb = m;\n  } else {\n    UNIMPLEMENTED();\n  }\n  if (transpose_b == ep::primitive::BlasTransposeType::N) {\n    CHECK_EQ(b_shape.at(num_b_axes - 2), k);\n    n = b_shape.at(num_b_axes - 1);\n    *cublas_lda = n;\n  } else if (transpose_b == ep::primitive::BlasTransposeType::T) {\n    CHECK_EQ(b_shape.at(num_b_axes - 1), k);\n    n = b_shape.at(num_b_axes - 2);\n    *cublas_lda = k;\n  } else {\n    UNIMPLEMENTED();\n  }\n  *cublas_m = n;\n  *cublas_n = m;\n  *cublas_k = k;\n  *cublas_ldc = n;\n}\n\nvoid SetCublasMatrixLayout(cublasLtMatrixLayout_t layout_desc, cudaDataType_t cuda_data_type,\n                           cublasOperation_t cublas_trans, const size_t cublas_m1,\n                           const size_t cublas_n1, int64_t cublas_ld) {\n  OF_CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(layout_desc, CUBLASLT_MATRIX_LAYOUT_TYPE,\n                                                   &cuda_data_type, sizeof(cuda_data_type)));\n  OF_CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(\n      layout_desc, CUBLASLT_MATRIX_LAYOUT_ROWS,\n      cublas_trans == CUBLAS_OP_N ? &cublas_m1 : &cublas_n1, sizeof(cublas_m1)));\n  OF_CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(\n      layout_desc, CUBLASLT_MATRIX_LAYOUT_COLS,\n      cublas_trans == CUBLAS_OP_N ? 
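/* rows/cols describe the matrix before the transpose op, so they swap under CUBLAS_OP_T */ 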
&cublas_n1 : &cublas_m1, sizeof(cublas_m1)));\n  OF_CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(layout_desc, CUBLASLT_MATRIX_LAYOUT_LD,\n                                                   &cublas_ld, sizeof(cublas_ld)));\n}\n\nvoid SetCublasEpilogue(const CublasFusedMLPKernelCache* matmul_cache, cublasLtEpilogue_t epilogue,\n                       const void* bias_ptr, const void* aux_ptr) {\n  // Set epilogue\n  OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(\n      matmul_cache->operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));\n#if CUDA_VERSION >= 11060\n  const bool has_bias =\n      (epilogue == CUBLASLT_EPILOGUE_RELU_BIAS || epilogue == CUBLASLT_EPILOGUE_BIAS\n       || epilogue == CUBLASLT_EPILOGUE_RELU_AUX_BIAS || epilogue == CUBLASLT_EPILOGUE_DRELU_BGRAD\n       || epilogue == CUBLASLT_EPILOGUE_BGRADB);\n#else\n  const bool has_bias =\n      (epilogue == CUBLASLT_EPILOGUE_RELU_BIAS || epilogue == CUBLASLT_EPILOGUE_BIAS);\n#endif  // CUDA_VERSION >= 11060\n  if (has_bias) {\n    // Set bias ptr\n    OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_cache->operation_desc,\n                                                   CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr,\n                                                   sizeof(bias_ptr)));\n  } else {\n    // unset\n    bias_ptr = nullptr;\n    OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_cache->operation_desc,\n                                                   CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr,\n                                                   sizeof(bias_ptr)));\n  }\n\n#if CUDA_VERSION >= 11060\n  if (epilogue == CUBLASLT_EPILOGUE_RELU_AUX_BIAS || epilogue == CUBLASLT_EPILOGUE_DRELU_BGRAD) {\n    // Set aux ptr for backward.\n    OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_cache->operation_desc,\n                                                   CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,\n                                                   &aux_ptr, sizeof(aux_ptr)));\n  } else {\n    // Clear Aux ptr.\n    aux_ptr = nullptr;\n    OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_cache->operation_desc,\n                                                   CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,\n                                                   &aux_ptr, sizeof(aux_ptr)));\n  }\n#endif  // CUDA_VERSION >= 11060\n}\n\nvoid SetCublasAttr(const CublasFusedMLPKernelCache* matmul_grad_cache,\n                   const cublasComputeType_t cublas_compute_dtype,\n                   const cudaDataType_t cuda_data_type, bool need_aux,\n                   ep::primitive::BlasTransposeType transpose_a,\n                   ep::primitive::BlasTransposeType transpose_b, cublasLtEpilogue_t epilogue,\n                   const void* d_bias_ptr, const void* aux_ptr, size_t cublas_m, size_t cublas_n,\n                   size_t cublas_k, int64_t cublas_lda, int64_t cublas_ldb, int64_t cublas_ldc) {\n  OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(\n      matmul_grad_cache->operation_desc, CUBLASLT_MATMUL_DESC_COMPUTE_TYPE, &cublas_compute_dtype,\n      sizeof(cublas_compute_dtype)));\n\n  size_t workspace_size =\n      ParseIntegerFromEnv(\"ONEFLOW_EP_CUDA_CUBLAS_WORKSPACE_SIZE_MB\", kDefaultWorkspaceSizeMb)\n      * 1024 * 1024;\n  OF_CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(matmul_grad_cache->cublas_preference,\n                                                       CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,\n                                                       
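/* defaults to 4 MB; override via ONEFLOW_EP_CUDA_CUBLAS_WORKSPACE_SIZE_MB */ 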
&workspace_size, sizeof(workspace_size)));\n\n#if CUDA_VERSION < 12000\n  uint32_t pointer_mode = CUBLASLT_POINTER_MODE_MASK_HOST;\n  OF_CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(matmul_grad_cache->cublas_preference,\n                                                       CUBLASLT_MATMUL_PREF_POINTER_MODE_MASK,\n                                                       &pointer_mode, sizeof(pointer_mode)));\n#endif  // CUDA_VERSION < 12000\n\n  // transpose_a/transpose_b map to the opposite cublas operands: cublasLt assumes column-major\n  // storage, so the row-major product C = A * B is computed as C^T = B^T * A^T.\n  const cublasOperation_t cublas_trans_a =\n      transpose_b == ep::primitive::BlasTransposeType::T ? CUBLAS_OP_T : CUBLAS_OP_N;\n  const cublasOperation_t cublas_trans_b =\n      transpose_a == ep::primitive::BlasTransposeType::T ? CUBLAS_OP_T : CUBLAS_OP_N;\n  OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_grad_cache->operation_desc,\n                                                 CUBLASLT_MATMUL_DESC_TRANSA, &cublas_trans_a,\n                                                 sizeof(cublas_trans_a)));\n  OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_grad_cache->operation_desc,\n                                                 CUBLASLT_MATMUL_DESC_TRANSB, &cublas_trans_b,\n                                                 sizeof(cublas_trans_b)));\n\n  // Set epilogue\n  SetCublasEpilogue(matmul_grad_cache, epilogue, d_bias_ptr, aux_ptr);\n/*\nSet the AUX pointer LD.\nWhen used for CUBLASLT_EPILOGUE_DRELU_BGRAD, the AUX_LD needs to be 128-aligned.\nWhen used for CUBLASLT_EPILOGUE_DGELU_BGRAD, the AUX_LD needs to be 8-aligned.\nFor more details refer to the cuBLAS docs:\nhttps://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulDescAttributes_t\n`CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD`.\n*/\n#if CUDA_VERSION >= 11060\n  if (need_aux) {\n    long aligned_aux_ld = AlignReluAuxLd(cublas_ldc);\n    OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_grad_cache->operation_desc,\n                                                   CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,\n                                                   &aligned_aux_ld, sizeof(aligned_aux_ld)));\n  } else {\n    long no_need_aligned_aux_ld = 0;\n    OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(\n        matmul_grad_cache->operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,\n        &no_need_aligned_aux_ld, sizeof(no_need_aligned_aux_ld)));\n  }\n#endif  // CUDA_VERSION >= 11060\n  // Set matrix layout\n  SetCublasMatrixLayout(matmul_grad_cache->cublas_a_desc, cuda_data_type, cublas_trans_a, cublas_m,\n                        cublas_k, cublas_lda);\n  SetCublasMatrixLayout(matmul_grad_cache->cublas_b_desc, cuda_data_type, cublas_trans_b, cublas_k,\n                        cublas_n, cublas_ldb);\n  SetCublasMatrixLayout(matmul_grad_cache->cublas_c_desc, cuda_data_type, CUBLAS_OP_N, cublas_m,\n                        cublas_n, cublas_ldc);\n}\n\n}  // namespace\n\n}  // namespace oneflow\n\n#endif  // CUDA_VERSION >= 11020\n\n#endif  // defined(__CUDACC__)\n"
  },
  {
    "path": "oneflow/user/kernels/cufft_plan_cache.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_USER_KERNELS_CUFFT_PLAN_CACHE_H_\n#define ONEFLOW_USER_KERNELS_CUFFT_PLAN_CACHE_H_\n\n#include <cufft.h>\n#include <cufftXt.h>\n#include <cuda_fp16.h>\n#include <cstdint>\n#include <functional>\n#include <numeric>\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/shape_vec.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr int max_rank = 3;\n\nenum class CUFFT_EXCUTETYPE { R2C, C2C, C2R };\n\nstruct CuFFTDataTypeDesc {\n  cudaDataType inputtype;\n  cudaDataType outputtype;\n  cudaDataType executiontype;\n};\n\n}  // namespace\n\nclass CuFFTHandle {\n  cufftHandle handle;\n\n public:\n  CuFFTHandle() { OF_CUFFT_CHECK(cufftCreate(&handle)); }\n\n  cufftHandle& get() { return handle; }\n  const cufftHandle& get() const { return handle; }\n\n  ~CuFFTHandle() { cufftDestroy(handle); }\n};\n\n// NOTE: The implementation of `CuFFTDataLayout`, `cufft_simple_embed` and `as_cufft_embed` are\n// mostly taken from pytorch. For more details pls refer to `CuFFTPlanCache.h` in PyTorch.\ntypedef long long cufft_size_type;\ntypedef small_vector<cufft_size_type, max_rank + 1> cufft_dim_vector;\nstruct CuFFTDataLayout {\n  small_vector<cufft_size_type, 5> embed;\n  cufft_size_type stride, dist;\n  bool must_clone, simple;\n};\n\n// Returns a cufft embedding for a contiguous signal of the given size.\n// e.g. if the input is cloned, this will be the resulting data layout\ninline CuFFTDataLayout cufft_simple_embed(const cufft_dim_vector& sizes, bool onesided) {\n  CuFFTDataLayout layout;\n  layout.simple = true;\n  layout.must_clone = false;\n  layout.embed.assign(sizes.cbegin() + 1, sizes.cend());\n  if (onesided) { layout.embed.back() = sizes.back() / 2 + 1; }\n  layout.stride = 1;\n  layout.dist = 1;\n  for (const auto& len : layout.embed) { layout.dist *= len; }\n  return layout;\n}\n\n// Convert strides to a CuFFT embedded representation.\n// If strides cannot be embedded, returns a simple layout and sets must_clone flag\ninline CuFFTDataLayout as_cufft_embed(const cufft_dim_vector& strides,\n                                      const cufft_dim_vector& sizes, bool onesided) {\n  const auto signal_ndim = strides.size() - 1;\n  CuFFTDataLayout layout;\n  auto last_stride = strides[signal_ndim];\n  layout.must_clone = (last_stride <= 0);\n\n  const auto last_dim_size = onesided ? 
/* one-sided transforms keep only the non-redundant half: n/2 + 1 complex values (Hermitian symmetry) */ sizes[signal_ndim] / 2 + 1 : sizes[signal_ndim];\n\n  const auto signal_numel = std::accumulate(sizes.begin() + 1, sizes.end() - 1, (cufft_size_type)1,\n                                            std::multiplies<cufft_size_type>())\n                            * last_dim_size;\n\n  // Zero strides are not allowed, even if the batch size is one.\n  // If that happens, just use a dummy value.\n  if (sizes[0] == 1) {\n    layout.dist = signal_numel;\n  } else if (strides[0] == 0) {\n    layout.must_clone = true;\n  } else {\n    layout.dist = strides[0];\n  }\n\n  // Calculate the embedding shape, or set must_clone if the strides cannot be embedded\n  layout.embed.resize(signal_ndim);\n  for (auto i = signal_ndim - 1; !layout.must_clone && i > 0; i--) {\n    auto stride = strides[i];\n    if (sizes[i] == 1) {\n      layout.embed[i] = 1;\n    } else if (stride > 0 && stride % last_stride == 0) {\n      layout.embed[i] = stride / last_stride;\n      last_stride = stride;\n    } else {\n      layout.must_clone = true;\n    }\n  }\n  // Fall back to a simple (contiguous) layout when the strides could not be embedded.\n  if (layout.must_clone) {\n    // If the input needs to be cloned, assume it will be contiguous\n    layout = cufft_simple_embed(sizes, onesided);\n    layout.must_clone = true;\n  } else {\n    layout.embed[0] = sizes[1];\n    layout.stride = strides[signal_ndim];\n\n    // Determine if layout represents a simple embedding (contiguous data)\n    layout.simple = [&] {\n      FOR_RANGE(int, i, 1, signal_ndim - 1) {\n        if (layout.embed[i] != sizes[i + 1]) { return false; }\n      }\n      return (layout.stride == 1 && layout.dist == signal_numel\n              && layout.embed.back() == last_dim_size);\n    }();\n  }\n  return layout;\n}\n\nstruct CuFFTParams {\n  int64_t ndim;\n  cufft_dim_vector input_shape;\n  cufft_dim_vector input_strides;\n  cufft_dim_vector output_shape;\n  cufft_dim_vector output_strides;\n  cufft_dim_vector data_shape;\n  CUFFT_EXCUTETYPE excute_type;\n  DataType real_data_type;\n\n  CuFFTParams() = default;\n  CuFFTParams(const Shape& in_shape, const Shape& out_shape, const Stride& in_strides,\n              const Stride& out_strides, int64_t dims, CUFFT_EXCUTETYPE type, DataType real)\n      : ndim(dims), excute_type(type), real_data_type(real) {\n    CHECK_OR_THROW(ndim >= 1 && ndim <= max_rank);\n    CHECK_OR_THROW(in_shape.size() == ndim + 1);\n    CHECK_OR_THROW(out_shape.size() == ndim + 1);\n    CHECK_OR_THROW(in_shape.size() == in_strides.size());\n    CHECK_OR_THROW(out_shape.size() == out_strides.size());\n    data_shape.resize(ndim + 1);\n    input_shape.resize(in_shape.size());\n    input_strides.resize(in_strides.size());\n    output_shape.resize(out_shape.size());\n    output_strides.resize(out_strides.size());\n\n    std::copy(in_strides.begin(), in_strides.end(), input_strides.begin());\n    std::copy(out_strides.begin(), out_strides.end(), output_strides.begin());\n    std::copy(in_shape.begin(), in_shape.end(), input_shape.begin());\n    std::copy(out_shape.begin(), out_shape.end(), output_shape.begin());\n\n    data_shape[0] = input_shape[0];  // batch size\n    FOR_RANGE(int64_t, i, 0, ndim) {\n      auto in_size = input_shape[i + 1];\n      auto out_size = output_shape[i + 1];\n      data_shape[i + 1] = std::max(in_size, out_size);\n      CHECK_OR_THROW(in_size == data_shape[i + 1] || in_size == (data_shape[i + 1] / 2) + 1);\n      CHECK_OR_THROW(out_size == data_shape[i + 1] || out_size == (data_shape[i + 1] / 2) + 1);\n    }\n  }\n};\n\nclass CuFFTConfig {\n public:\n  CuFFTConfig(const 
CuFFTConfig&) = delete;\n  CuFFTConfig& operator=(CuFFTConfig const&) = delete;\n  ~CuFFTConfig() = default;\n\n  explicit CuFFTConfig(CuFFTParams& params) {  // NOLINT\n\n    if (params.real_data_type == kBFloat16 || params.real_data_type == kFloat16) {\n      // cuFFT supports half data types, but with some limits:\n      //  https://docs.nvidia.com/cuda/cufft/#half-precision-cufft-transforms\n      CHECK_OR_THROW(false) << \"Unsupported data types kBFloat16 and kFloat16.\";\n    }\n\n    CuFFTDataLayout input_layout = as_cufft_embed(params.input_strides, params.data_shape,\n                                                  params.excute_type == CUFFT_EXCUTETYPE::C2R);\n    CuFFTDataLayout output_layout = as_cufft_embed(params.output_strides, params.data_shape,\n                                                   params.excute_type == CUFFT_EXCUTETYPE::R2C);\n\n    bool clone_input = input_layout.must_clone;  // i.e. the input must be made contiguous because\n                                                 // its original strides cannot be embedded\n    const bool is_layout_simple = input_layout.simple && output_layout.simple;\n\n    // Disable cuFFT's default behavior of allocating the work area at plan-generation time.\n    OF_CUFFT_CHECK(cufftSetAutoAllocation(plan_handle_.get(), 0));\n    infer_cufft_type_(params.excute_type, params.real_data_type);\n\n    // Exclude input_shape[0], which is the batch dim.\n    cufft_dim_vector fft_shape(params.data_shape.begin() + 1, params.data_shape.end());\n    cufft_size_type batch = params.data_shape[0];\n    if (is_layout_simple) {\n      OF_CUFFT_CHECK(cufftXtMakePlanMany(plan_handle_.get(), params.ndim, fft_shape.data(),\n                                         /*inembed=*/nullptr, /*istride=*/1, /*idist=*/1,\n                                         /*inputtype=*/data_type_desc_.inputtype,\n                                         /*onembed=*/nullptr, /*ostride=*/1, /*odist=*/1,\n                                         /*outputtype=*/data_type_desc_.outputtype,\n                                         /*batch=*/batch, /*workSize=*/&work_size_,\n                                         /*executiontype=*/data_type_desc_.executiontype));\n    } else {\n      OF_CUFFT_CHECK(cufftXtMakePlanMany(\n          plan_handle_.get(), params.ndim, fft_shape.data(),\n          /*inembed=*/input_layout.embed.data(), /*istride=*/input_layout.stride,\n          /*idist=*/input_layout.dist, /*inputtype=*/data_type_desc_.inputtype,\n          /*onembed=*/output_layout.embed.data(), /*ostride=*/output_layout.stride,\n          /*odist=*/output_layout.dist, /*outputtype=*/data_type_desc_.outputtype,\n          /*batch=*/batch, /*workSize=*/&work_size_,\n          /*executiontype=*/data_type_desc_.executiontype));\n    }\n  }\n\n  size_t workspace_size() const { return work_size_; }\n  const cufftHandle& plan() const { return plan_handle_.get(); }\n\n  void excute(void* input, void* output, bool forward) {\n    OF_CUFFT_CHECK(\n        cufftXtExec(plan_handle_.get(), input, output, forward ? CUFFT_FORWARD : CUFFT_INVERSE));\n  }\n\n private:\n  void infer_cufft_type_(CUFFT_EXCUTETYPE excute_type, DataType real_data_type) {\n    if (real_data_type == kFloat) {\n      data_type_desc_.executiontype = CUDA_C_32F;\n      data_type_desc_.inputtype = excute_type == CUFFT_EXCUTETYPE::R2C ? CUDA_R_32F : CUDA_C_32F;\n      data_type_desc_.outputtype = excute_type == CUFFT_EXCUTETYPE::C2R ? 
CUDA_R_32F : CUDA_C_32F;\n    } else if (real_data_type == kDouble) {\n      data_type_desc_.executiontype = CUDA_C_64F;\n      data_type_desc_.inputtype = excute_type == CUFFT_EXCUTETYPE::R2C ? CUDA_R_64F : CUDA_C_64F;\n      data_type_desc_.outputtype = excute_type == CUFFT_EXCUTETYPE::C2R ? CUDA_R_64F : CUDA_C_64F;\n    } else {\n      CHECK_OR_THROW(false) << \"cuFFT doesn't support type \" << real_data_type;\n    }\n  }\n\n  CuFFTHandle plan_handle_;\n  CuFFTDataTypeDesc data_type_desc_;\n  size_t work_size_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_CUFFT_PLAN_CACHE_H_\n"
  },
  {
    "path": "oneflow/user/kernels/cum_backward_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\nnamespace {\n// CumProd backward, formula: flip(cumsum(flip(dY * Y))) / X.\ntemplate<typename T>\nvoid CumProdBackward(const T* dy_ptr, T* dx_ptr, const T* output_ptr, const T* input_ptr,\n                     const int64_t up_space, const int64_t space, const int64_t down_space,\n                     const int64_t elem_cnt) {\n  const auto step = space * down_space;\n  for (size_t i = 0; i < up_space; i++) {\n    const size_t base_ptr_offset = step * i;\n    const T* input_ptr_base = input_ptr + base_ptr_offset;\n    const T* output_ptr_base = output_ptr + base_ptr_offset;\n    const T* dy_ptr_base = dy_ptr + base_ptr_offset;\n    T* dx_ptr_base = dx_ptr + base_ptr_offset;\n\n    // Use dx as a tmp buffer for counting zero elements in the input.\n    for (size_t j = 0; j < space; j++) {\n      const size_t ptr_offset = j * down_space;\n      auto* cur_input_ptr = input_ptr_base + ptr_offset;\n\n      auto* cumsum_zeros_number_ptr = dx_ptr_base + ptr_offset;\n      auto* last_cumsum_zeros_number_ptr = cumsum_zeros_number_ptr - down_space;\n      for (size_t k = 0; k < down_space; k++) {\n        int is_zero = cur_input_ptr[k] == 0 ? 1 : 0;\n        cumsum_zeros_number_ptr[k] = is_zero + (j == 0 ? 0 : last_cumsum_zeros_number_ptr[k]);\n      }\n    }\n\n    for (size_t j = 0; j < down_space; j++) {\n      const auto* cur_output_ptr = output_ptr_base + j;\n      const auto* cur_input_ptr = input_ptr_base + j;\n      const auto* cur_dy_ptr = dy_ptr_base + j;\n      auto* cur_dx_ptr = dx_ptr_base + j;\n      const auto* cumsum_zeros_number_ptr = dx_ptr_base + j;\n\n      size_t first_zero_index = space;\n      // Find index of first zero in input.\n      for (size_t k = 0; k < space; k++) {\n        if (cumsum_zeros_number_ptr[k * down_space] == 1) {\n          first_zero_index = k;\n          break;\n        }\n      }\n      // Suppose z is the index of the first zero element in the input;\n      // for elements whose index is less than z, the grad is computed as below:\n      T reverse_cumsum = 0;\n      for (size_t k = 0; k < first_zero_index; k++) {\n        const size_t data_offset = (first_zero_index - k - 1) * down_space;\n        reverse_cumsum += cur_output_ptr[data_offset] * cur_dy_ptr[data_offset];\n        cur_dx_ptr[data_offset] = reverse_cumsum / cur_input_ptr[data_offset];\n      }\n      // For the element at index z, the grad is computed as below:\n      if (first_zero_index == space) { continue; }\n      T cumprod = 1;\n      T cumsum = 0;\n      T cumprod_before_first_zero =\n          first_zero_index == 0 ? /* empty product: no elements precede the first zero */ 
1 : cur_output_ptr[(first_zero_index - 1) * down_space];\n      for (size_t k = first_zero_index; k < space; k++) {\n        const size_t data_offset = k * down_space;\n        // Recover dx_ptr default value\n        if (cur_dx_ptr[data_offset] >= 1) { cur_dx_ptr[data_offset] = 0; }\n        if (k != first_zero_index) { cumprod *= cur_input_ptr[data_offset]; }\n        cumsum += cumprod_before_first_zero * cumprod * cur_dy_ptr[data_offset];\n      }\n      cur_dx_ptr[first_zero_index * down_space] = cumsum;\n    }\n  }\n}\n}  // namespace\n\ntemplate<typename T>\nclass CpuCumProdGradKernel final : public user_op::OpKernel {\n public:\n  CpuCumProdGradKernel() = default;\n  ~CpuCumProdGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* output = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    const auto* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    auto* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const int64_t elem_cnt = dy->shape_view().elem_cnt();\n    if (elem_cnt == 0) { return; }\n\n    const auto* output_ptr = output->dptr<T>();\n    const auto* input_ptr = input->dptr<T>();\n    const auto* dy_ptr = dy->dptr<T>();\n    auto* dx_ptr = dx->mut_dptr<T>();\n\n    // data partition: up_space|space|down_space\n    auto dim = ctx->Attr<int64_t>(\"dim\");\n    auto up_space = elem_cnt / dx->shape_view().Count(dim);\n    auto space = dx->shape_view().At(dim);\n    auto down_space = dx->shape_view().Count(dim + 1);\n    if (space == 1) {\n      Memcpy<DeviceType::kCPU>(ctx->stream(), dx_ptr, dy_ptr, elem_cnt * sizeof(T));\n      return;\n    }\n    CumProdBackward(dy_ptr, dx_ptr, output_ptr, input_ptr, up_space, space, down_space, elem_cnt);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_CUMPROD_GRAD_KERNEL(dtype)                       \\\n  REGISTER_USER_KERNEL(\"cumprod_grad\")                                \\\n      .SetCreateFn<CpuCumProdGradKernel<dtype>>()                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CPU_CUMPROD_GRAD_KERNEL(float)\nREGISTER_CPU_CUMPROD_GRAD_KERNEL(double)\n#undef REGISTER_CPU_CUMPROD_GRAD_KERNEL\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/cum_backward_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n#ifdef WITH_CUDA\nnamespace {\ntemplate<typename T>\n__global__ void CumProdBackward(const T* dy_ptr, T* dx_ptr, const T* output_ptr, const T* input_ptr,\n                                const int64_t up_space, const int64_t space,\n                                const int64_t down_space, const int64_t thread_num) {\n  // Each thread is responsible for one row along the scanned dimension.\n  const size_t up_space_step = space * down_space;\n  CUDA_1D_KERNEL_LOOP_T(size_t, i, thread_num) {\n    const size_t up_space_id = i / down_space;\n    const size_t down_space_id = i % down_space;\n    const size_t ptr_offset = up_space_id * up_space_step + down_space_id;\n    auto* dy_ptr_base = dy_ptr + ptr_offset;\n    auto* dx_ptr_base = dx_ptr + ptr_offset;\n    auto* input_ptr_base = input_ptr + ptr_offset;\n    auto* output_ptr_base = output_ptr + ptr_offset;\n\n    // Buffer storing the number of zero elements along the scanned dimension.\n    // Use dx as a tmp buffer.\n    for (size_t j = 0; j < space; j++) {\n      const size_t data_offset = j * down_space;\n      int is_zero = input_ptr_base[data_offset] == 0 ? 1 : 0;\n      dx_ptr_base[data_offset] = is_zero + (j == 0 ? 0 : dx_ptr_base[data_offset - down_space]);\n    }\n\n    // Find index of first zero in input.\n    size_t first_zero_index = space;\n    for (size_t j = 0; j < space; j++) {\n      const size_t data_offset = j * down_space;\n      if (dx_ptr_base[data_offset] == 1) {\n        first_zero_index = j;\n        break;\n      }\n    }\n\n    // Suppose z is the index of the first zero element in the input;\n    // for elements whose index is less than z, the grad is computed as below:\n    T reverse_cumsum = 0;\n    for (size_t j = 0; j < first_zero_index; j++) {\n      const size_t cur_index = first_zero_index - j - 1;\n      const size_t data_offset = cur_index * down_space;\n      reverse_cumsum += output_ptr_base[data_offset] * dy_ptr_base[data_offset];\n      dx_ptr_base[data_offset] = reverse_cumsum / input_ptr_base[data_offset];\n    }\n\n    // For the element at index z, the grad is computed as below:\n    if (first_zero_index == space) { return; }\n    T cumprod = 1;\n    T cumsum = 0;\n    T cumprod_before_first_zero =\n        first_zero_index == 0 ? /* empty product: no elements precede the first zero */ 
1 : output_ptr_base[(first_zero_index - 1) * down_space];\n    for (size_t j = first_zero_index; j < space; j++) {\n      const size_t down_space_offset = j * down_space;\n      // Recover dx_ptr default value\n      if (dx_ptr_base[down_space_offset] >= 1) { dx_ptr_base[down_space_offset] = 0; }\n      if (j != first_zero_index) { cumprod *= input_ptr_base[down_space_offset]; }\n      cumsum += cumprod_before_first_zero * dy_ptr_base[down_space_offset] * cumprod;\n    }\n    dx_ptr_base[first_zero_index * down_space] = cumsum;\n  }\n}\n}  // namespace\n\ntemplate<typename T>\nclass GpuCumProdGradKernel final : public user_op::OpKernel {\n public:\n  GpuCumProdGradKernel() = default;\n  ~GpuCumProdGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* output = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    const auto* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    auto* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const auto elem_cnt = dy->shape_view().elem_cnt();\n    if (!elem_cnt) { return; }\n\n    const auto* output_ptr = output->dptr<T>();\n    const auto* input_ptr = input->dptr<T>();\n    const auto* dy_ptr = dy->dptr<T>();\n    auto* dx_ptr = dx->mut_dptr<T>();\n\n    // Data partition: up_space|space|down_space\n    auto dim = ctx->Attr<int64_t>(\"dim\");\n    const auto up_space = elem_cnt / dx->shape_view().Count(dim);\n    const auto space = dx->shape_view().At(dim);\n    const auto down_space = dx->shape_view().Count(dim + 1);\n    const size_t thread_num = up_space * down_space;\n\n    if (space == 1) {\n      Memcpy<DeviceType::kCUDA>(ctx->stream(), dx_ptr, dy_ptr, elem_cnt * sizeof(T));\n      return;\n    }\n    ep::CudaLaunchConfig config{};\n    ctx->stream()->As<ep::CudaStream>()->InitLaunchConfigWithWaves(\n        &config, thread_num, /*DefaultBlockSize*/ 256, /*max_wave*/ 1);\n    CumProdBackward<<<config.grid_dim, config.block_dim, /*shared memory*/ 0,\n                      ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        dy_ptr, dx_ptr, output_ptr, input_ptr, up_space, space, down_space, thread_num);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_CUMPROD_GRAD_KERNEL(dtype)                       \\\n  REGISTER_USER_KERNEL(\"cumprod_grad\")                                 \\\n      .SetCreateFn<GpuCumProdGradKernel<dtype>>()                      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CUDA_CUMPROD_GRAD_KERNEL(float)\nREGISTER_CUDA_CUMPROD_GRAD_KERNEL(double)\n#undef REGISTER_CUDA_CUMPROD_GRAD_KERNEL\n#endif\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/cum_forward_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\nnamespace oneflow {\n\nnamespace {\ntemplate<typename T, template<typename> class BinaryFunc>\nvoid CumForward(const T* in_ptr, T* out_ptr, int64_t up_space, int64_t space, int64_t down_space,\n                int64_t elem_cnt) {\n  std::copy_n(in_ptr, elem_cnt, out_ptr);\n  auto* tmp_out_ptr_base = out_ptr;\n  auto step = space * down_space;\n  for (auto i = 0; i < up_space; i++) {\n    for (auto j = 1; j < space; j++) {\n      auto* tmp_out_ptr = tmp_out_ptr_base + j * down_space;\n      auto* last_tmp_out_ptr = tmp_out_ptr - down_space;\n      for (auto k = 0; k < down_space; k++) {\n        tmp_out_ptr[k] = BinaryFunc<T>::Invoke(tmp_out_ptr[k], last_tmp_out_ptr[k]);\n      }\n    }\n    tmp_out_ptr_base += step;\n  }\n}\n}  // namespace\n\ntemplate<typename T, template<typename> class BinaryFunc>\nclass CpuCumKernel : public user_op::OpKernel {\n public:\n  CpuCumKernel() = default;\n  ~CpuCumKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* in = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    auto elem_cnt = in->shape_view().elem_cnt();\n    // judge whether tensor has 0 size dimension first\n    if (!elem_cnt) { return; }\n\n    auto* out = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    auto dim = ctx->Attr<int64_t>(\"dim\");\n    const auto* in_ptr = in->dptr<T>();\n    auto* out_ptr = out->mut_dptr<T>();\n\n    // data partition: up_space|space|down_space\n    auto up_space = elem_cnt / in->shape_view().Count(dim);\n    auto space = in->shape_view().At(dim);\n    auto down_space = in->shape_view().Count(dim + 1);\n\n    CumForward<T, BinaryFunc>(in_ptr, out_ptr, up_space, space, down_space, elem_cnt);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define CUMOP_SEQ                                \\\n  OF_PP_MAKE_TUPLE_SEQ(\"cumprod\", BinaryFuncMul) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"cumsum\", BinaryFuncAdd)\n\n#define REGISTER_CUMOP_KERNEL(dtype, op_name, op_functor)                                       \\\n  REGISTER_USER_KERNEL(op_name).SetCreateFn<CpuCumKernel<dtype, op_functor>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCPU)                                            \\\n      && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\n#define REGISTER_CUMOP_KERNEL_WITH_DTYPE(op_name, op_functor) \\\n  REGISTER_CUMOP_KERNEL(int32_t, op_name, op_functor)         \\\n  REGISTER_CUMOP_KERNEL(int64_t, op_name, op_functor)         \\\n  REGISTER_CUMOP_KERNEL(float, op_name, op_functor)           \\\n  REGISTER_CUMOP_KERNEL(double, op_name, op_functor)\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_CUMOP_KERNEL_WITH_DTYPE, CUMOP_SEQ);\n\n#undef REGISTER_CUMOP_KERNEL\n#undef 
REGISTER_CUMOP_KERNEL_WITH_DTYPE\n#undef CUMOP_SEQ\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/cum_forward_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cub/cub.cuh>\n#include <type_traits>\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\n\nnamespace oneflow {\n#ifdef WITH_CUDA\nnamespace {\n\ntemplate<typename T>\ninline T CeilDiv(T n, T m) {\n  return (n + m - 1) / m;\n}\n\ntemplate<typename T>\nstruct SumFunctor {\n  __device__ __forceinline__ T operator()(const T a, const T b) const { return a + b; }\n};\ntemplate<typename T>\nstruct ProdFunctor {\n  __device__ __forceinline__ T operator()(const T a, const T b) const { return a * b; }\n};\n\ntemplate<typename T, template<typename> class BinaryFunc>\nsize_t InferTmpBufferSize(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"x\", 0);\n  const int64_t dim = ctx->Attr<int64_t>(\"dim\");\n  const size_t dim_size = in_shape.At(dim);\n  if (in_shape.elem_cnt() == dim_size) {\n    size_t temp_storage_bytes = 0;\n    OF_CUDA_CHECK(cub::DeviceScan::InclusiveScan(nullptr, temp_storage_bytes,\n                                                 static_cast<T*>(nullptr), static_cast<T*>(nullptr),\n                                                 BinaryFunc<T>(), dim_size));\n    return GetCudaAlignedSize(temp_storage_bytes);\n  }\n  return 0;\n}\n\n// Total thread number: cs_up_space * cs_down_space. Each thread handles one\n// (up, down) pair and scans its cs_space elements sequentially, since every\n// prefix result depends on the previous one along the scanned dimension.\ntemplate<typename T, template<typename> class BinaryFunc>\n__global__ void CumForwardGpu(const T* in_ptr, T* out_ptr, int64_t cs_up_space, int64_t cs_space,\n                              int64_t cs_down_space) {\n  CUDA_1D_KERNEL_LOOP(i, cs_up_space * cs_down_space) {\n    auto cs_up_space_id = i / cs_down_space;\n    auto cs_down_space_id = i - (i / cs_down_space) * cs_down_space;\n\n    auto* in_ptr_base = in_ptr + cs_up_space_id * cs_space * cs_down_space + cs_down_space_id;\n    auto* out_ptr_base = out_ptr + cs_up_space_id * cs_space * cs_down_space + cs_down_space_id;\n\n    // calculate cs_space data in one thread\n    for (auto j = 0; j < cs_space; j++) {\n      auto idx = j * cs_down_space;\n      out_ptr_base[idx] = in_ptr_base[idx];\n      if (j != 0) {\n        out_ptr_base[idx] = BinaryFunc<T>()(out_ptr_base[idx], out_ptr_base[idx - cs_down_space]);\n      }\n    }\n  }\n}\n\ntemplate<typename T, template<typename> class BinaryFunc>\nvoid ScanOuterDim(ep::Stream* ep_stream, const ShapeView& in_shape, int64_t dim, const T* in_ptr,\n                  T* out_ptr) {\n  // data partition: up_space|space|down_space\n  auto up_space = in_shape.elem_cnt() / in_shape.Count(dim);\n  auto space = in_shape.At(dim);\n  auto down_space = in_shape.Count(dim + 1);\n  auto thread_num = up_space * down_space;\n  // Worked example (shapes assumed for illustration): for in_shape = (2, 3, 4) and dim = 1,\n  // up_space = 2, space = 3, down_space = 4, so thread_num = 8; each thread scans the 3\n  // elements of one column in[i, :, k] at stride down_space.\n  
RUN_CUDA_KERNEL((CumForwardGpu<T, BinaryFunc>), ep_stream, thread_num, in_ptr, out_ptr, up_space,\n                  space, down_space);\n}\n\n// Adapted from\n// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/ScanKernels.cu\ntemplate<typename T, int num_threads_x, int num_threads_y, template<typename> class BinaryFunc>\n__device__ void ScanInnerMostDimKernelImpl(T* row_buf, T* src_, T* tgt_, const uint32_t num_rows,\n                                           const uint32_t row_size, T init) {\n  for (uint32_t block_row = blockIdx.x * blockDim.y; block_row < num_rows;\n       block_row += blockDim.y * gridDim.x) {\n    uint32_t row = block_row + threadIdx.y;\n    T block_total = init;\n\n    T* row_src = src_ + row * row_size;\n    T* row_tgt = tgt_ + row * row_size;\n\n    // Perform scan on one block at a time, keeping track of the total value of\n    // all blocks processed so far.\n    for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {\n      // Load data into shared memory (two values per thread).\n      uint32_t col1 = block_col + threadIdx.x;\n      uint32_t col2 = block_col + num_threads_x + threadIdx.x;\n      if (row < num_rows) {\n        if (col1 < row_size) {\n          row_buf[threadIdx.x] = row_src[col1];\n        } else {\n          row_buf[threadIdx.x] = init;\n        }\n\n        if (col2 < row_size) {\n          row_buf[num_threads_x + threadIdx.x] = row_src[col2];\n        } else {\n          row_buf[num_threads_x + threadIdx.x] = init;\n        }\n\n        // Add the total value of all previous blocks to the first value of this block.\n        if (threadIdx.x == 0) { row_buf[0] = BinaryFunc<T>()(row_buf[0], block_total); }\n      }\n      __syncthreads();\n\n      // Up-sweep (reduce) phase of the work-efficient scan over the\n      // 2 * num_threads_x elements staged in shared memory.\n      for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {\n        if (row < num_rows && threadIdx.x < s) {\n          uint32_t offset = (2 * threadIdx.x + 1) * d - 1;\n          row_buf[offset + d] = BinaryFunc<T>()(row_buf[offset], row_buf[offset + d]);\n        }\n        __syncthreads();\n      }\n\n      // Down-sweep phase: propagate the partial results to the remaining positions.\n      for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {\n        if (row < num_rows && threadIdx.x < s - 1) {\n          uint32_t offset = 2 * (threadIdx.x + 1) * d - 1;\n          row_buf[offset + d] = BinaryFunc<T>()(row_buf[offset], row_buf[offset + d]);\n        }\n        __syncthreads();\n      }\n      // Write back to output.\n      if (row < num_rows) {\n        if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x];\n        if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x];\n      }\n      block_total = row_buf[2 * num_threads_x - 1];\n      __syncthreads();\n    }\n  }\n}\n\ntemplate<typename T, int num_threads_x, int num_threads_y, template<typename> class BinaryFunc>\n__global__ void ScanInnerMostDimKernel(const T* in_ptr, T* out_ptr, const int64_t num_rows,\n                                       const int64_t row_size, T init) {\n  __shared__ T sbuf[num_threads_y][2 * num_threads_x];\n  T* row_buf = sbuf[threadIdx.y];\n  ScanInnerMostDimKernelImpl<T, num_threads_x, num_threads_y, BinaryFunc>(\n      row_buf, const_cast<T*>(in_ptr), out_ptr, num_rows, row_size, init);\n}\n\ntemplate<typename T, template<typename> class BinaryFunctor>\nvoid ScanInnerMostDim(const T* in_ptr, T* out_ptr, const int64_t num_rows, const int64_t row_size,\n                      const ep::CudaStream* cuda_stream) {\n  dim3 block(16, 32);\n  const int64_t max_grid_dim = 
cuda_stream->device()->properties().maxGridSize[0];\n  dim3 grid(std::min(max_grid_dim, CeilDiv(num_rows, (int64_t)block.y)));\n  if (std::is_same<BinaryFunctor<T>, SumFunctor<T>>::value) {\n    ScanInnerMostDimKernel<T, 16, 32, SumFunctor>\n        <<<grid, block, 0, cuda_stream->cuda_stream()>>>(in_ptr, out_ptr, num_rows, row_size,\n                                                         /*init*/ 0);\n  } else if (std::is_same<BinaryFunctor<T>, ProdFunctor<T>>::value) {\n    ScanInnerMostDimKernel<T, 16, 32, ProdFunctor>\n        <<<grid, block, 0, cuda_stream->cuda_stream()>>>(in_ptr, out_ptr, num_rows, row_size,\n                                                         /*init*/ 1);\n  } else {\n    UNIMPLEMENTED() << \"Only cumsum and cumprod are supported for now.\";\n  }\n}\n\ntemplate<typename T, template<typename> class BinaryFunc>\nvoid CubInclusiveScan(user_op::Tensor* temp_buffer, const T* in_ptr, T* out_ptr, int64_t elem_cnt,\n                      const ep::CudaStream* cuda_stream) {\n  // Use the untyped pointer since the tmp buffer's data type need not match T.\n  auto* temp_storage = temp_buffer->mut_dptr();\n  size_t temp_storage_bytes = temp_buffer->shape_view().elem_cnt();\n  OF_CUDA_CHECK(cub::DeviceScan::InclusiveScan(temp_storage, temp_storage_bytes, in_ptr, out_ptr,\n                                               BinaryFunc<T>(), elem_cnt,\n                                               cuda_stream->cuda_stream()));\n}\n}  // namespace\n\ntemplate<typename T, template<typename> class BinaryFunc>\nclass GpuCumKernel : public user_op::OpKernel {\n public:\n  GpuCumKernel() = default;\n  ~GpuCumKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* in = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    auto* out = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const ShapeView& in_shape = in->shape_view();\n    const int64_t dim = ctx->Attr<int64_t>(\"dim\");\n    const int64_t dim_size = in_shape.At(dim);\n\n    // Return early if the tensor has a zero-size dimension.\n    auto elem_cnt = in_shape.elem_cnt();\n    if (!elem_cnt) { return; }\n\n    const auto* in_ptr = in->dptr<T>();\n    auto* out_ptr = out->mut_dptr<T>();\n\n    const auto* cuda_stream = ctx->stream()->As<ep::CudaStream>();\n\n    if (elem_cnt == dim_size) {\n      auto* temp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n      CubInclusiveScan<T, BinaryFunc>(temp_buffer, in_ptr, out_ptr, elem_cnt, cuda_stream);\n    } else if (dim == in_shape.NumAxes() - 1) {\n      // Treat all outer dimensions as a single dimension.\n      const int64_t num_rows = elem_cnt / dim_size;\n      ScanInnerMostDim<T, BinaryFunc>(in_ptr, out_ptr, num_rows, dim_size, cuda_stream);\n    } else {\n      ScanOuterDim<T, BinaryFunc>(ctx->stream(), in_shape, dim, in_ptr, out_ptr);\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define CUMOP_SEQ                              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"cumprod\", ProdFunctor) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"cumsum\", SumFunctor)\n\n#define REGISTER_CUMOP_KERNEL(dtype, op_name, op_functor)                              \\\n  REGISTER_USER_KERNEL(op_name)                                                        \\\n      .SetCreateFn<GpuCumKernel<dtype, op_functor>>()                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                 \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)) \\\n      
.SetInferTmpSizeFn(InferTmpBufferSize<dtype, op_functor>);\n\n#define REGISTER_CUMOP_KERNEL_WITH_DTYPE(op_name, op_functor) \\\n  REGISTER_CUMOP_KERNEL(int32_t, op_name, op_functor)         \\\n  REGISTER_CUMOP_KERNEL(int64_t, op_name, op_functor)         \\\n  REGISTER_CUMOP_KERNEL(float, op_name, op_functor)           \\\n  REGISTER_CUMOP_KERNEL(double, op_name, op_functor)          \\\n  REGISTER_CUMOP_KERNEL(half, op_name, op_functor)\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_CUMOP_KERNEL_WITH_DTYPE, CUMOP_SEQ);\n\n#undef REGISTER_CUMOP_KERNEL\n#undef REGISTER_CUMOP_KERNEL_WITH_DTYPE\n#undef CUMOP_SEQ\n\n#endif\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/cutlass_conv_tuner.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifdef WITH_CUTLASS\n\n#include \"oneflow/user/kernels/cutlass_conv_tuner.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include <cutlass/library/handle.h>\n#include <cutlass/library/library.h>\n#include <cutlass/library/singleton.h>\n\nnamespace oneflow {\n\nnamespace {\n\nbool IsWeakerAlginOperation(const cutlass::library::Operation* lhs,\n                            const cutlass::library::Operation* rhs) {\n  const char* lhs_name = lhs->description().name;\n  const char* rhs_name = rhs->description().name;\n  const size_t len = std::strlen(lhs_name);\n  const size_t suffix_len = std::strlen(\"align8\");\n  if (std::strlen(rhs_name) != len) { return false; }\n  if (len < suffix_len) { return false; }\n  const size_t prefix_len = len - suffix_len;\n  if (std::strncmp(lhs_name, rhs_name, prefix_len) != 0) { return false; }\n  const auto& HasLegalSuffix = [&](const char* str) {\n    if (std::strncmp(str + prefix_len, \"align\", std::strlen(\"align\")) != 0) { return false; }\n    const char align = str[len - 1];\n    return align == '8' || align == '4' || align == '2' || align == '1';\n  };\n  if ((!HasLegalSuffix(lhs_name)) || (!HasLegalSuffix(rhs_name))) { return false; }\n  return lhs_name[len - 1] < rhs_name[len - 1];\n}\n\nstruct Conv2dOperationCacheKey {\n  cutlass::library::ConvFunctionalKey functional_key;\n  cutlass::library::Conv2dConfiguration configuraion;\n  size_t alignment;\n  Conv2dOperationCacheKey(cutlass::library::ConvFunctionalKey functional_key,\n                          cutlass::library::Conv2dConfiguration configuraion,\n                          cutlass::library::ConvArguments arguments)\n      : functional_key(functional_key), configuraion(configuraion) {\n    const auto IsStrideAligned = [&](const std::vector<int64_t>& stride, size_t n) {\n      return std::all_of(stride.cbegin(), stride.cend(),\n                         [&](const int64_t& s) { return s % n == 0; });\n    };\n    CHECK_EQ(reinterpret_cast<uintptr_t>(arguments.A) % kCudaAlignSize, 0);\n    CHECK_EQ(reinterpret_cast<uintptr_t>(arguments.B) % kCudaAlignSize, 0);\n    CHECK_EQ(reinterpret_cast<uintptr_t>(arguments.C) % kCudaAlignSize, 0);\n    CHECK_EQ(reinterpret_cast<uintptr_t>(arguments.D) % kCudaAlignSize, 0);\n    const auto IsAligned = [&](size_t n) {\n      return IsStrideAligned(configuraion.stride_a, n) && IsStrideAligned(configuraion.stride_b, n)\n             && IsStrideAligned(configuraion.stride_c, n);\n    };\n    if (IsAligned(8)) {\n      alignment = 8;\n    } else if (IsAligned(4)) {\n      alignment = 4;\n    } else if (IsAligned(2)) {\n      alignment = 2;\n    } else {\n      alignment = 1;\n    }\n  }\n};\n\nstruct Conv2dProblemSizeHasher {\n  size_t operator()(const cutlass::conv::Conv2dProblemSize& problem_size) 
const {\n    size_t hash = 0;\n    hash = HashCombine(hash, std::hash<int>()(problem_size.N));\n    hash = HashCombine(hash, std::hash<int>()(problem_size.H));\n    hash = HashCombine(hash, std::hash<int>()(problem_size.W));\n    hash = HashCombine(hash, std::hash<int>()(problem_size.C));\n    hash = HashCombine(hash, std::hash<int>()(problem_size.P));\n    hash = HashCombine(hash, std::hash<int>()(problem_size.Q));\n    hash = HashCombine(hash, std::hash<int>()(problem_size.K));\n    hash = HashCombine(hash, std::hash<int>()(problem_size.R));\n    hash = HashCombine(hash, std::hash<int>()(problem_size.S));\n    hash = HashCombine(hash, std::hash<int>()(problem_size.pad_h));\n    hash = HashCombine(hash, std::hash<int>()(problem_size.pad_w));\n    hash = HashCombine(hash, std::hash<int>()(problem_size.stride_h));\n    hash = HashCombine(hash, std::hash<int>()(problem_size.stride_w));\n    hash = HashCombine(hash, std::hash<int>()(problem_size.dilation_h));\n    hash = HashCombine(hash, std::hash<int>()(problem_size.dilation_w));\n    hash = HashCombine(hash, std::hash<int>()(static_cast<int>(problem_size.mode)));\n    hash = HashCombine(hash, std::hash<int>()(problem_size.split_k_slices));\n    hash = HashCombine(hash, std::hash<int>()(problem_size.groups));\n    return hash;\n  }\n};\n\nstruct Conv2dConfigurationHasher {\n  size_t operator()(const cutlass::library::Conv2dConfiguration& configuraion) const {\n    size_t hash = std::hash<int>()(static_cast<int>(configuraion.split_k_mode));\n    hash = HashCombine(hash, Conv2dProblemSizeHasher()(configuraion.problem_size));\n    for (const int64_t v : configuraion.stride_a) {\n      hash = HashCombine(hash, std::hash<int64_t>()(v));\n    }\n    for (const int64_t v : configuraion.stride_b) {\n      hash = HashCombine(hash, std::hash<int64_t>()(v));\n    }\n    for (const int64_t v : configuraion.stride_c) {\n      hash = HashCombine(hash, std::hash<int64_t>()(v));\n    }\n    return hash;\n  }\n};\n\nstruct Conv2dOperationCacheKeyHasher {\n  size_t operator()(const Conv2dOperationCacheKey& key) const {\n    size_t hash = cutlass::library::ConvFunctionalKeyHasher()(key.functional_key);\n    hash = HashCombine(hash, Conv2dConfigurationHasher()(key.configuraion));\n    hash = HashCombine(hash, std::hash<size_t>()(key.alignment));\n    return hash;\n  }\n};\n\ninline bool operator==(const cutlass::library::Conv2dConfiguration& lhs,\n                       const cutlass::library::Conv2dConfiguration& rhs) {\n  return lhs.split_k_mode == rhs.split_k_mode && lhs.problem_size == rhs.problem_size\n         && lhs.stride_a == rhs.stride_a && lhs.stride_b == rhs.stride_b\n         && lhs.stride_c == rhs.stride_c;\n}\n\ninline bool operator==(const Conv2dOperationCacheKey& lhs, const Conv2dOperationCacheKey& rhs) {\n  return lhs.functional_key == rhs.functional_key && lhs.configuraion == rhs.configuraion\n         && lhs.alignment == rhs.alignment;\n}\n\nsize_t GetTensorSize(cutlass::library::NumericTypeID element, cutlass::library::LayoutTypeID layout,\n                     const cutlass::Tensor4DCoord& extent, const std::vector<int64_t>& stride) {\n  const size_t element_size = cutlass::library::sizeof_bits(element) / 8;\n  size_t capacity = 0;\n  if (layout == cutlass::library::LayoutTypeID::kTensorNHWC) {\n    CHECK_EQ(stride.size(), 3);\n    capacity =\n        cutlass::layout::TensorNHWC(stride.at(0), stride.at(1), stride.at(2)).capacity(extent);\n  } else {\n    UNIMPLEMENTED();\n  }\n  return capacity * element_size;\n}\n\n};  // 
namespace\n\nusing CacheMap = std::unordered_map<Conv2dOperationCacheKey, const cutlass::library::Operation*,\n                                    Conv2dOperationCacheKeyHasher>;\nstruct CutlassConvTuner::Impl {\n  std::mutex mutex;\n  std::unordered_map<int, CacheMap> cache;\n\n  const cutlass::library::Operation* FindConv2dOperation(\n      ep::CudaStream* stream, cutlass::library::ConvFunctionalKey functional_key,\n      const cutlass::library::Conv2dConfiguration& configuraion,\n      const cutlass::library::ConvArguments& arguments, void* workspace, size_t workspace_size);\n\n  const cutlass::library::Operation* GetConv2dOperation(\n      const std::string& name, ep::CudaStream* stream,\n      cutlass::library::ConvFunctionalKey functional_key,\n      const cutlass::library::Conv2dConfiguration& configuraion,\n      const cutlass::library::ConvArguments& arguments, void* workspace, size_t workspace_size);\n};\n\nconst cutlass::library::Operation* CutlassConvTuner::Impl::FindConv2dOperation(\n    ep::CudaStream* stream, cutlass::library::ConvFunctionalKey functional_key,\n    const cutlass::library::Conv2dConfiguration& configuraion,\n    const cutlass::library::ConvArguments& arguments, void* workspace, size_t workspace_size) {\n  int dev = 0;\n  OF_CUDA_CHECK(cudaGetDevice(&dev));\n  Conv2dOperationCacheKey cache_key(functional_key, configuraion, arguments);\n  {\n    std::lock_guard<std::mutex> lock(mutex);\n    const auto& device_cache = cache[dev];\n    const auto& it = device_cache.find(cache_key);\n    if (it != device_cache.end()) { return it->second; }\n  }\n\n  cutlass::library::ConvArguments benchmark_arguments = arguments;\n  void* benchmark_workspace = workspace;\n  cudaStream_t benchmark_stream = stream->cuda_stream();\n#ifdef WITH_CUDA_GRAPHS\n  cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;\n  if (stream->IsGraphCapturing()) {\n    OF_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));\n    OF_CUDA_CHECK(cudaStreamCreate(&benchmark_stream));\n    OF_CUDA_CHECK(cudaMalloc(&benchmark_workspace, workspace_size));\n    const size_t a_size =\n        GetTensorSize(functional_key.element_A, functional_key.layout_A,\n                      configuraion.problem_size.activation_extent(), configuraion.stride_a);\n    OF_CUDA_CHECK(cudaMalloc(&benchmark_arguments.A, a_size));\n    const size_t b_size =\n        GetTensorSize(functional_key.element_B, functional_key.layout_B,\n                      configuraion.problem_size.filter_extent(), configuraion.stride_b);\n    OF_CUDA_CHECK(cudaMalloc(&benchmark_arguments.B, b_size));\n    if (benchmark_arguments.C != nullptr) {\n      const size_t c_size =\n          GetTensorSize(functional_key.element_C, functional_key.layout_C,\n                        configuraion.problem_size.output_extent(), configuraion.stride_c);\n      OF_CUDA_CHECK(cudaMalloc(&benchmark_arguments.C, c_size));\n    }\n\n    const size_t d_size = GetTensorSize(\n        functional_key.element_C, functional_key.layout_C,\n        configuraion.problem_size.output_extent(),\n        {configuraion.problem_size.K, configuraion.problem_size.K * configuraion.problem_size.Q,\n         configuraion.problem_size.K * configuraion.problem_size.Q * configuraion.problem_size.P});\n    OF_CUDA_CHECK(cudaMalloc(&benchmark_arguments.D, d_size));\n  }\n#endif  // WITH_CUDA_GRAPHS\n\n  constexpr int turing_warmup_iters = 2;\n  constexpr int turing_iters = 5;\n  cudaEvent_t start{};\n  cudaEvent_t end{};\n  OF_CUDA_CHECK(cudaEventCreate(&start));\n  
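// Tuning protocol for the candidate loop below: each candidate operation first\n  // runs turing_warmup_iters untimed warm-up iterations, then turing_iters\n  // iterations timed between the start/end events; the fastest candidate wins\n  // and is cached per device.\n  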
OF_CUDA_CHECK(cudaEventCreate(&end));\n  const cutlass::library::Operation* fastest_operation = nullptr;\n  float fastest_time = 0;\n  const auto& operations_map_it =\n      cutlass::library::Singleton::get().operation_table.conv2d_operations.find(functional_key);\n  CHECK(operations_map_it\n        != cutlass::library::Singleton::get().operation_table.conv2d_operations.cend());\n  const cutlass::library::ConvOperationVectorMap& operations_map = operations_map_it->second;\n\n  for (const auto& pair : operations_map) {\n    std::map<std::string, const cutlass::library::Operation*, std::greater<std::string>> operations;\n    for (auto operation : pair.second) {\n      operations.emplace(operation->description().name, operation);\n    }\n    const cutlass::library::Operation* prev_operation = nullptr;\n    for (const auto& name_operation : operations) {\n      const cutlass::library::Operation* operation = name_operation.second;\n      if (prev_operation != nullptr && IsWeakerAlginOperation(operation, prev_operation)) {\n        continue;\n      }\n      if (operation->description().tile_description.minimum_compute_capability * 10\n              > stream->cuda_arch()\n          || operation->description().tile_description.maximum_compute_capability * 10\n                 < stream->cuda_arch()) {\n        continue;\n      }\n      auto status = operation->can_implement(&configuraion, &benchmark_arguments);\n      if (status != cutlass::Status::kSuccess) { continue; }\n      const size_t host_workspace_size = operation->get_host_workspace_size(&configuraion);\n      const size_t device_workspace_size = operation->get_device_workspace_size(&configuraion);\n      if (device_workspace_size > workspace_size) { continue; }\n      std::vector<uint8_t> host_workspace(host_workspace_size, 0);\n      if (operation->initialize(&configuraion, host_workspace.data(), benchmark_workspace,\n                                benchmark_stream)\n          != cutlass::Status::kSuccess) {\n        continue;\n      }\n\n      const auto Run = [&]() {\n        auto init_status = operation->initialize(&configuraion, host_workspace.data(),\n                                                 benchmark_workspace, benchmark_stream);\n        CHECK(init_status == cutlass::Status::kSuccess);\n        auto run_status = operation->run(&benchmark_arguments, host_workspace.data(),\n                                         benchmark_workspace, benchmark_stream);\n        CHECK(run_status == cutlass::Status::kSuccess);\n      };\n      OF_CUDA_CHECK(cudaStreamSynchronize(benchmark_stream));\n      for (int i = 0; i < turing_warmup_iters; ++i) { Run(); }\n      OF_CUDA_CHECK(cudaEventRecord(start, benchmark_stream));\n      for (int i = 0; i < turing_iters; ++i) { Run(); }\n      OF_CUDA_CHECK(cudaEventRecord(end, benchmark_stream));\n      OF_CUDA_CHECK(cudaEventSynchronize(end));\n      float time = 0;\n      OF_CUDA_CHECK(cudaEventElapsedTime(&time, start, end));\n      VLOG(3) << operation->description().name << \" \" << time;\n      prev_operation = operation;\n      if (fastest_operation == nullptr || time < fastest_time) {\n        fastest_operation = operation;\n        fastest_time = time;\n      }\n    }\n  }\n  OF_CUDA_CHECK(cudaEventDestroy(start));\n  OF_CUDA_CHECK(cudaEventDestroy(end));\n#ifdef WITH_CUDA_GRAPHS\n  if (stream->IsGraphCapturing()) {\n    OF_CUDA_CHECK(cudaStreamSynchronize(benchmark_stream));\n    OF_CUDA_CHECK(cudaStreamDestroy(benchmark_stream));\n    
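// Free the staging tensors and workspace allocated above for capture-safe\n    // benchmarking; the const_casts below only strip constness for cudaFree.\n    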
OF_CUDA_CHECK(cudaFree(const_cast<void*>(benchmark_arguments.A)));\n    OF_CUDA_CHECK(cudaFree(const_cast<void*>(benchmark_arguments.B)));\n    OF_CUDA_CHECK(cudaFree(const_cast<void*>(benchmark_arguments.C)));\n    OF_CUDA_CHECK(cudaFree(benchmark_arguments.D));\n    OF_CUDA_CHECK(cudaFree(benchmark_workspace));\n    OF_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));\n  }\n#endif  // WITH_CUDA_GRAPHS\n  if (fastest_operation != nullptr) {\n    VLOG(3) << \"Fastest: \" << fastest_operation->description().name << \" \" << fastest_time;\n    {\n      std::lock_guard<std::mutex> lock(mutex);\n      cache[dev][cache_key] = fastest_operation;\n    }\n  }\n  return fastest_operation;\n}\n\nconst cutlass::library::Operation* CutlassConvTuner::Impl::GetConv2dOperation(\n    const std::string& name, ep::CudaStream* stream,\n    cutlass::library::ConvFunctionalKey functional_key,\n    const cutlass::library::Conv2dConfiguration& configuraion,\n    const cutlass::library::ConvArguments& arguments, void* workspace, size_t workspace_size) {\n  int dev = 0;\n  OF_CUDA_CHECK(cudaGetDevice(&dev));\n  const auto& operations_map_it =\n      cutlass::library::Singleton::get().operation_table.conv2d_operations.find(functional_key);\n  if (operations_map_it\n      == cutlass::library::Singleton::get().operation_table.conv2d_operations.cend()) {\n    return nullptr;\n  }\n  const cutlass::library::ConvOperationVectorMap& operations_map = operations_map_it->second;\n  for (const auto& pair : operations_map) {\n    for (auto operation : pair.second) {\n      if (name != operation->description().name) { continue; }\n      if (operation->description().tile_description.minimum_compute_capability * 10\n              > stream->cuda_arch()\n          || operation->description().tile_description.maximum_compute_capability * 10\n                 < stream->cuda_arch()) {\n        continue;\n      }\n      auto status = operation->can_implement(&configuraion, &arguments);\n      if (status != cutlass::Status::kSuccess) { continue; }\n      const size_t host_workspace_size = operation->get_host_workspace_size(&configuraion);\n      const size_t device_workspace_size = operation->get_device_workspace_size(&configuraion);\n      if (device_workspace_size > workspace_size) { continue; }\n      std::vector<uint8_t> host_workspace(host_workspace_size, 0);\n      if (operation->initialize(&configuraion, host_workspace.data(), workspace,\n                                stream->cuda_stream())\n          != cutlass::Status::kSuccess) {\n        continue;\n      }\n      return operation;\n    }\n  }\n  return nullptr;\n}\n\nCutlassConvTuner::CutlassConvTuner() { impl_.reset(new Impl()); }\n\nconst CutlassConvTuner& CutlassConvTuner::Get() {\n  static CutlassConvTuner instance;\n  return instance;\n}\n\nconst cutlass::library::Operation* CutlassConvTuner::FindConv2dOperation(\n    ep::CudaStream* stream, cutlass::library::ConvFunctionalKey functional_key,\n    const cutlass::library::Conv2dConfiguration& configuraion,\n    const cutlass::library::ConvArguments& arguments, void* workspace,\n    size_t workspace_size) const {\n  return impl_->FindConv2dOperation(stream, functional_key, configuraion, arguments, workspace,\n                                    workspace_size);\n}\n\nconst cutlass::library::Operation* CutlassConvTuner::GetConv2dOperation(\n    const std::string& name, ep::CudaStream* stream,\n    cutlass::library::ConvFunctionalKey functional_key,\n    const cutlass::library::Conv2dConfiguration& configuraion,\n   
 const cutlass::library::ConvArguments& arguments, void* workspace,\n    size_t workspace_size) const {\n  return impl_->GetConv2dOperation(name, stream, functional_key, configuraion, arguments, workspace,\n                                   workspace_size);\n}\n\n}  // namespace oneflow\n\n#endif  // WITH_CUTLASS\n"
  },
  {
    "path": "oneflow/user/kernels/cutlass_conv_tuner.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_USER_KERNELS_CUTLASS_CONV_TUNER_H_\n#define ONEFLOW_USER_KERNELS_CUTLASS_CONV_TUNER_H_\n\n#ifdef WITH_CUTLASS\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include <cutlass/library/handle.h>\n#include <cutlass/library/library.h>\n#include <cutlass/library/singleton.h>\n\nnamespace oneflow {\n\nclass CutlassConvTuner {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CutlassConvTuner);\n  ~CutlassConvTuner() = default;\n\n  const cutlass::library::Operation* FindConv2dOperation(\n      ep::CudaStream* stream, cutlass::library::ConvFunctionalKey functional_key,\n      const cutlass::library::Conv2dConfiguration& configuraion,\n      const cutlass::library::ConvArguments& arguments, void* workspace,\n      size_t workspace_size) const;\n\n  const cutlass::library::Operation* GetConv2dOperation(\n      const std::string& name, ep::CudaStream* stream,\n      cutlass::library::ConvFunctionalKey functional_key,\n      const cutlass::library::Conv2dConfiguration& configuraion,\n      const cutlass::library::ConvArguments& arguments, void* workspace,\n      size_t workspace_size) const;\n\n  static const CutlassConvTuner& Get();\n\n private:\n  CutlassConvTuner();\n  struct Impl;\n  std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace oneflow\n\n#endif  // WITH_CUTLASS\n#endif  // ONEFLOW_USER_KERNELS_CUTLASS_CONV_TUNER_H_\n"
  },
  {
    "path": "oneflow/user/kernels/data_shuffle_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/job/eager_nccl_comm_manager.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/user/kernels/gather_kernel_util.h\"\n#include \"oneflow/user/kernels/unsorted_segment_sum_kernel_util.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/embedding/hash_functions.cuh\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/ep/include/primitive/copy_nd.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/embedding/embedding_manager.h\"\n#include \"oneflow/user/kernels/one_embedding_data_shuffle.cuh\"\n\nnamespace oneflow {\n\nnamespace {\n\nenum class IdShuffleBufferType {\n  kNumPartitionedUnique = 0,\n  kPartitionedUniqueIds,\n  kReceivedIds,\n  kTableIds,\n  kPartitionedUniqueTableIds,\n  kReceivedTableIds,\n  kWorkspace,\n  kMaxType\n};\n\ntemplate<typename K, typename U, typename IDX>\nclass IdShuffleTmpBufferManager final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(IdShuffleTmpBufferManager);\n  IdShuffleTmpBufferManager(void* ptr, const int64_t num_ids, const int64_t parallel_num,\n                            bool need_table_ids, bool need_process_table_ids)\n      : offset_(0),\n        offsets_(static_cast<size_t>(IdShuffleBufferType::kMaxType), -1),\n        sizes_(static_cast<size_t>(IdShuffleBufferType::kMaxType)),\n        ptr_(ptr) {\n    const int64_t num_table_ids = need_process_table_ids ? num_ids : 0;\n    const size_t table_ids_bytes = need_table_ids ? 
num_ids * sizeof(U) : 0;\n    AllocBuffer(IdShuffleBufferType::kNumPartitionedUnique, parallel_num * sizeof(IDX));\n    size_t partitioned_ids_bytes = parallel_num * num_ids * sizeof(K);\n    AllocBuffer(IdShuffleBufferType::kPartitionedUniqueIds, partitioned_ids_bytes);\n    AllocBuffer(IdShuffleBufferType::kReceivedIds, partitioned_ids_bytes);\n    AllocBuffer(IdShuffleBufferType::kTableIds, table_ids_bytes);\n    size_t partitioned_table_ids_bytes = parallel_num * num_table_ids * sizeof(U);\n    AllocBuffer(IdShuffleBufferType::kPartitionedUniqueTableIds, partitioned_table_ids_bytes);\n    AllocBuffer(IdShuffleBufferType::kReceivedTableIds, partitioned_table_ids_bytes);\n    const size_t hash_table_capacity = parallel_num * num_ids;\n    AllocBuffer(IdShuffleBufferType::kWorkspace,\n                hash_table_capacity * sizeof(data_shuffle::TableEntry<K>));\n  }\n\n  template<typename T = void>\n  T* Ptr(IdShuffleBufferType type) {\n    CHECK(ptr_ != nullptr);\n    int64_t offset = offsets_.at(static_cast<size_t>(type));\n    CHECK_NE(offset, -1);\n    return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr_) + offset);\n  }\n\n  int64_t Size(IdShuffleBufferType type) { return sizes_.at(static_cast<size_t>(type)); }\n\n  size_t TotalBufferSize() const { return offset_; }\n\n private:\n  void AllocBuffer(IdShuffleBufferType type, size_t size) {\n    const size_t type_id = static_cast<size_t>(type);\n    CHECK_EQ(offsets_.at(type_id), -1);\n    offsets_.at(type_id) = offset_;\n    sizes_.at(type_id) = size;\n    offset_ += GetCudaAlignedSize(size);\n  }\n  size_t offset_;\n  std::vector<int64_t> offsets_;\n  std::vector<int64_t> sizes_;\n  void* ptr_;\n};\n\ntemplate<typename IDX>\nclass DataShuffleKernelState final : public user_op::OpKernelState {\n public:\n  explicit DataShuffleKernelState(user_op::KernelInitContext* ctx)\n      : device_index_(-1),\n        stream_name_(EagerNcclCommMgr::kDefaultStreamName),\n        parallel_desc_(ctx->parallel_desc()) {\n    OF_CUDA_CHECK(cudaGetDevice(&device_index_));\n    if (ctx->op_conf().has_stream_name_hint()) { stream_name_ = ctx->op_conf().stream_name_hint(); }\n    OF_CUDA_CHECK(cudaMallocHost(&host_num_keys_, sizeof(IDX)));\n    OF_CUDA_CHECK(cudaMallocHost(\n        &host_num_unique_matrix_,\n        parallel_desc_.parallel_num() * parallel_desc_.parallel_num() * sizeof(IDX)));\n    const std::string& embedding_name = ctx->Attr<std::string>(\"embedding_name\");\n    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n    embedding_state_ = Singleton<embedding::EmbeddingManager>::Get()->GetEmbeddingState(\n        embedding_name, parallel_id);\n  }\n  ~DataShuffleKernelState() {\n    CudaCurrentDeviceGuard guard(device_index_);\n    OF_CUDA_CHECK(cudaFreeHost(host_num_unique_matrix_));\n    // host_num_keys_ is pinned memory allocated in the constructor; free it as well.\n    OF_CUDA_CHECK(cudaFreeHost(host_num_keys_));\n  }\n\n  ncclComm_t comm() { return GetOrCreate().comm; }\n\n  IDX* HostNumUniqueMatrix() { return host_num_unique_matrix_; }\n  IDX* HostNumKeys() { return host_num_keys_; }\n\n  embedding::EmbeddingState* EmbeddingState() { return embedding_state_; }\n\n private:\n  struct Comm {\n    explicit Comm(ncclComm_t comm) : comm(comm) {}\n    ncclComm_t comm;\n  };\n\n  const Comm& GetOrCreate() {\n    if (!comm_) { Init(); }\n    return *comm_;\n  }\n\n  void Init() {\n    std::set<std::pair<int64_t, int64_t>> device_set;\n    for (int64_t parallel_id = 0; parallel_id < parallel_desc_.parallel_num(); ++parallel_id) {\n      int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id));\n      int64_t device_id = 
CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id));\n      device_set.emplace(std::make_pair(machine_id, device_id));\n    }\n    EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    ncclComm_t comm =\n        comm_mgr->As<EagerNcclCommMgr>()->GetCommForDeviceAndStreamName(device_set, stream_name_);\n    comm_.reset(new Comm(comm));\n  }\n\n  int device_index_;\n  std::string stream_name_;\n  ParallelDesc parallel_desc_;\n  std::unique_ptr<Comm> comm_;\n  IDX* host_num_unique_matrix_;\n  IDX* host_num_keys_;\n  embedding::EmbeddingState* embedding_state_;\n};\n\n}  // namespace\n\ntemplate<typename K, typename U, typename IDX>\nclass IdShuffleKernel final : public user_op::OpKernel {\n public:\n  IdShuffleKernel() : current_iter_(0) {}\n  ~IdShuffleKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<DataShuffleKernelState<IDX>>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<DataShuffleKernelState<IDX>*>(state);\n    CHECK(kernel_state != nullptr);\n    const user_op::Tensor* ids = ctx->Tensor4ArgNameAndIndex(\"ids\", 0);\n    user_op::Tensor* num_unique_matrix = ctx->Tensor4ArgNameAndIndex(\"num_unique_matrix\", 0);\n    user_op::Tensor* inverse_unique_partition_indices =\n        ctx->Tensor4ArgNameAndIndex(\"inverse_unique_partition_indices\", 0);\n    user_op::Tensor* cur_rank_num_unique = ctx->Tensor4ArgNameAndIndex(\"cur_rank_num_unique\", 0);\n    user_op::Tensor* cur_rank_unique_ids = ctx->Tensor4ArgNameAndIndex(\"cur_rank_unique_ids\", 0);\n    user_op::Tensor* cur_rank_unique_table_ids =\n        ctx->Tensor4ArgNameAndIndex(\"cur_rank_unique_table_ids\", 0);\n    user_op::Tensor* cur_rank_inverse_indices =\n        ctx->Tensor4ArgNameAndIndex(\"cur_rank_inverse_indices\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const int32_t num_tables = ctx->Attr<int32_t>(\"num_tables\");\n    const int64_t padding_idx = ctx->Attr<int64_t>(\"padding_idx\");\n    const bool has_padding_idx = ctx->Attr<bool>(\"has_padding_idx\");\n    const bool has_table_ids = ctx->has_input(\"table_ids\", 0);\n    const bool need_gen_table_ids = (!has_table_ids && num_tables > 1);\n    const bool need_process_table_ids = (has_table_ids || num_tables > 1);\n    const int64_t num_ids = ids->shape_view().elem_cnt();\n    const int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n    cudaStream_t cuda_stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();\n    IdShuffleTmpBufferManager<K, U, IDX> buffer_manager(\n        tmp_buffer->mut_dptr(), num_ids, parallel_num, need_gen_table_ids, need_process_table_ids);\n    CHECK_GE(tmp_buffer->shape_view().elem_cnt(), buffer_manager.TotalBufferSize());\n\n    ncclComm_t comm = kernel_state->comm();\n    IDX* host_num_unique_matrix = kernel_state->HostNumUniqueMatrix();\n    IDX* host_num_keys = kernel_state->HostNumKeys();\n    data_shuffle::IdShuffleDataPtrs<K, U, IDX> data_ptrs;\n    data_ptrs.ids_ptr = reinterpret_cast<const K*>(ids->dptr());\n    if (has_table_ids) {\n      const user_op::Tensor* table_ids = 
ctx->Tensor4ArgNameAndIndex(\"table_ids\", 0);\n      data_ptrs.table_ids_ptr = reinterpret_cast<const U*>(table_ids->dptr());\n    } else if (need_gen_table_ids) {\n      data_shuffle::GenerateTableIds<<<BlocksNum4ThreadsNum(num_ids), kCudaThreadsNumPerBlock, 0,\n                                       cuda_stream>>>(\n          num_ids, num_tables, buffer_manager.template Ptr<U>(IdShuffleBufferType::kTableIds));\n      data_ptrs.table_ids_ptr = buffer_manager.template Ptr<U>(IdShuffleBufferType::kTableIds);\n    } else {\n      data_ptrs.table_ids_ptr = nullptr;\n    }\n    data_ptrs.num_partitioned_unique =\n        buffer_manager.template Ptr<IDX>(IdShuffleBufferType::kNumPartitionedUnique);\n    data_ptrs.partitioned_unique_ids =\n        buffer_manager.template Ptr<K>(IdShuffleBufferType::kPartitionedUniqueIds);\n    data_ptrs.partitioned_unique_table_ids =\n        buffer_manager.template Ptr<U>(IdShuffleBufferType::kPartitionedUniqueTableIds);\n    data_ptrs.workspace_ptr = buffer_manager.Ptr(IdShuffleBufferType::kWorkspace);\n    data_ptrs.workspace_size = buffer_manager.Size(IdShuffleBufferType::kWorkspace);\n    data_ptrs.received_ids = buffer_manager.template Ptr<K>(IdShuffleBufferType::kReceivedIds);\n    data_ptrs.received_table_ids =\n        buffer_manager.template Ptr<U>(IdShuffleBufferType::kReceivedTableIds);\n    data_ptrs.num_unique_matrix_ptr = reinterpret_cast<IDX*>(num_unique_matrix->mut_dptr());\n    data_ptrs.inverse_unique_partition_indices_ptr =\n        reinterpret_cast<IDX*>(inverse_unique_partition_indices->mut_dptr());\n    data_ptrs.cur_rank_num_unique_ptr = reinterpret_cast<IDX*>(cur_rank_num_unique->mut_dptr());\n    data_ptrs.cur_rank_unique_ids_ptr = reinterpret_cast<K*>(cur_rank_unique_ids->mut_dptr());\n    data_ptrs.cur_rank_unique_table_ids_ptr =\n        reinterpret_cast<U*>(cur_rank_unique_table_ids->mut_dptr());\n    data_ptrs.cur_rank_inverse_indices_ptr =\n        reinterpret_cast<IDX*>(cur_rank_inverse_indices->mut_dptr());\n\n    data_shuffle::IdShuffle(ctx->stream(), comm, data_ptrs, num_ids, parallel_id, parallel_num,\n                            num_unique_matrix->data_type(), ids->data_type(),\n                            cur_rank_unique_table_ids->data_type(), need_process_table_ids,\n                            has_padding_idx, padding_idx, host_num_unique_matrix, host_num_keys);\n\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    std::vector<uint32_t> num_unique_matrix_vec(parallel_num * parallel_num);\n    std::memcpy(num_unique_matrix_vec.data(), host_num_unique_matrix,\n                parallel_num * parallel_num * sizeof(IDX));\n    CHECK_EQ(sizeof(IDX), sizeof(uint32_t)) << \"assume sizeof(IDX) equals to sizeof(uint32_t)\";\n    embedding_state->SetIdNumUniqueMatrix(num_unique_matrix_vec, current_iter_);\n\n    uint32_t final_num_unique = *host_num_keys;\n    embedding_state->SetIdFinalNumUnique(final_num_unique, current_iter_);\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\n#define ID_DATA_TYPE_SEQ                            \\\n  OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)   \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)\n\n#define TABLE_ID_DATA_TYPE_SEQ                      \\\n  OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8)   \\\n  OF_PP_MAKE_TUPLE_SEQ(uint32_t, 
DataType::kUInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) \\\n  OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8)     \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)   \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)\n\n#define IDX_DATA_TYPE_SEQ                           \\\n  OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)\n\n#define REGISTER_CUDA_ID_SHUFFLE_KERNEL(k_dtype_pair, table_id_dtype_pair, idx_dtype_pair)       \\\n  REGISTER_USER_KERNEL(\"id_shuffle\")                                                             \\\n      .SetCreateFn<                                                                              \\\n          IdShuffleKernel<OF_PP_PAIR_FIRST(k_dtype_pair), OF_PP_PAIR_FIRST(table_id_dtype_pair), \\\n                          OF_PP_PAIR_FIRST(idx_dtype_pair)>>()                                   \\\n      .SetIsMatchedHob(                                                                          \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                        \\\n          && (user_op::HobDataType(\"ids\", 0) == OF_PP_PAIR_SECOND(k_dtype_pair))                 \\\n          && (user_op::HobDataType(\"cur_rank_unique_table_ids\", 0)                               \\\n              == OF_PP_PAIR_SECOND(table_id_dtype_pair))                                         \\\n          && (user_op::HobDataType(\"num_unique_matrix\", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair)) \\\n          && (!ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_ID_SHUFFLE_USE_P2P\", false)))          \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                        \\\n        const user_op::TensorDesc& ids = ctx->InputTensorDesc(\"ids\", 0);                         \\\n        const bool has_table_ids = ctx->has_input(\"table_ids\", 0);                               \\\n        const int32_t num_tables = ctx->Attr<int32_t>(\"num_tables\");                             \\\n        const bool need_gen_table_ids = (!has_table_ids && num_tables > 1);                      \\\n        const bool need_process_table_ids = (has_table_ids || num_tables > 1);                   \\\n        IdShuffleTmpBufferManager<OF_PP_PAIR_FIRST(k_dtype_pair),                                \\\n                                  OF_PP_PAIR_FIRST(table_id_dtype_pair),                         \\\n                                  OF_PP_PAIR_FIRST(idx_dtype_pair)>                              \\\n            buffer_manager(nullptr, ids.shape().elem_cnt(), ctx->parallel_desc().parallel_num(), \\\n                           need_gen_table_ids, need_process_table_ids);                          \\\n        return buffer_manager.TotalBufferSize();                                                 \\\n      });\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ID_SHUFFLE_KERNEL, ID_DATA_TYPE_SEQ,\n                                 TABLE_ID_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ)\n\n__device__ float RoundHalfAwayFromZero(const float x) {\n  float abs_val = abs(x);\n  float floor_val = floor(abs_val + static_cast<float>(0.5));\n  return copysignf(floor_val, x);\n}\n\n// warp reduce version.\nconstexpr int32_t kWarpSize = 32;\nconstexpr int32_t kMaxColSize = 1024;\n\ntemplate<typename T, int thread_group_width = kWarpSize>\n__inline__ __device__ T WarpMaxAllReduce(T val) {\n  for (int32_t lane_mask = thread_group_width / 2; lane_mask > 0; lane_mask /= 2) {\n    val = max(val, __shfl_xor_sync(0xffffffff, 
val, lane_mask, thread_group_width));\n  }\n  return val;\n}\n\ninline cudaError_t GetWarpImplNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves,\n                                        int* num_blocks) {\n  int dev;\n  {\n    cudaError_t err = cudaGetDevice(&dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  int sm_count;\n  {\n    cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  int tpm;\n  {\n    cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  *num_blocks =\n      std::max<int>(1, std::min<int64_t>(max_blocks, sm_count * tpm / block_size * waves));\n  return cudaSuccess;\n}\n\ntemplate<typename T, typename ComputeType, int pack_size, int cols_per_thread,\n         int thread_group_width, int rows_per_access, bool padding>\n__global__ void QuantizeWarpImplKernel(const T* src, int8_t* dst, T* quantize_factor,\n                                       const int64_t rows, const int64_t cols) {\n  static_assert(cols_per_thread % pack_size == 0, \"\");\n  static_assert(thread_group_width <= kWarpSize, \"\");\n  static_assert(kWarpSize % thread_group_width == 0, \"\");\n  constexpr int num_packs = cols_per_thread / pack_size;\n  assert(cols <= cols_per_thread * thread_group_width);\n  ComputeType buf[rows_per_access][cols_per_thread];\n  const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;\n  const int num_global_thread_group = gridDim.x * blockDim.y;\n  const int lane_id = threadIdx.x;\n  const int64_t step = num_global_thread_group * rows_per_access;\n  using LoadType = cuda::elementwise::PackType<T, pack_size>;\n  using LoadPack = cuda::elementwise::Pack<T, pack_size>;\n  using StoreType = cuda::elementwise::PackType<int8_t, pack_size>;\n  using StorePack = cuda::elementwise::Pack<int8_t, pack_size>;\n\n  for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) {\n    ComputeType thread_abs_max[rows_per_access];\n#pragma unroll\n    for (int row_id = 0; row_id < rows_per_access; row_id++) {\n      ComputeType* row_buf = buf[row_id];\n      thread_abs_max[row_id] = 0.0;\n#pragma unroll\n      for (int pack_id = 0; pack_id < num_packs; pack_id++) {\n        const int pack_offset = pack_id * pack_size;\n        const int col = (pack_id * thread_group_width + lane_id) * pack_size;\n        LoadPack load_pack;\n        if (!padding || col < cols) {\n          const int64_t load_offset = ((row + row_id) * cols + col) / pack_size;\n          load_pack.storage = *(reinterpret_cast<const LoadType*>(src) + load_offset);\n#pragma unroll\n          for (int i = 0; i < pack_size; i++) {\n            row_buf[pack_offset + i] = static_cast<ComputeType>(load_pack.elem[i]);\n            thread_abs_max[row_id] = max(thread_abs_max[row_id], abs(row_buf[pack_offset + i]));\n          }\n        } else {\n#pragma unroll\n          for (int i = 0; i < pack_size; i++) { row_buf[pack_offset + i] = 0.0; }\n        }\n      }\n    }\n    ComputeType warp_max[rows_per_access];\n#pragma unroll\n    for (int row_id = 0; row_id < rows_per_access; row_id++) {\n      warp_max[row_id] = WarpMaxAllReduce<ComputeType, thread_group_width>(thread_abs_max[row_id]);\n      if (threadIdx.x == 0) { quantize_factor[row + row_id] = static_cast<T>(warp_max[row_id]); }\n      ComputeType* row_buf = buf[row_id];\n      ComputeType quantize_factor_val = 
static_cast<ComputeType>(127.0) / warp_max[row_id];\n#pragma unroll\n      for (int col = 0; col < cols_per_thread; col++) {\n        row_buf[col] = RoundHalfAwayFromZero(row_buf[col] * quantize_factor_val);\n      }\n#pragma unroll\n      for (int pack_id = 0; pack_id < num_packs; pack_id++) {\n        const int pack_offset = pack_id * pack_size;\n        const int col = (pack_id * thread_group_width + lane_id) * pack_size;\n        StorePack store_pack;\n        if (!padding || col < cols) {\n          const int64_t store_offset = ((row + row_id) * cols + col) / pack_size;\n          for (int i = 0; i < pack_size; i++) {\n            store_pack.elem[i] = static_cast<int8_t>(row_buf[pack_id * pack_size + i]);\n          }\n          *(reinterpret_cast<StoreType*>(dst) + store_offset) = store_pack.storage;\n        }\n      }\n    }\n  }\n}\n\ntemplate<typename T, typename ComputeType, int pack_size, int cols_per_thread,\n         int thread_group_width, int rows_per_access, bool padding>\ninline cudaError_t LaunchQuantizeWarpImpl(cudaStream_t stream, const T* src, int8_t* dst,\n                                          T* quantize_factor, const int64_t rows,\n                                          const int64_t cols) {\n  constexpr int block_size = 128;\n  constexpr int waves = 32;\n  static_assert(block_size % thread_group_width == 0, \"\");\n  constexpr int thread_groups_per_block = block_size / thread_group_width;\n  dim3 block_dim(thread_group_width, thread_groups_per_block);\n  const int64_t num_blocks =\n      (rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;\n  int grid_dim_x = 0;\n\n  cudaError_t err = GetWarpImplNumBlocks(block_size, num_blocks, waves, &grid_dim_x);\n  if (err != cudaSuccess) { return err; }\n\n  QuantizeWarpImplKernel<T, ComputeType, pack_size, cols_per_thread, thread_group_width,\n                         rows_per_access, padding>\n      <<<grid_dim_x, block_dim, 0, stream>>>(src, dst, quantize_factor, rows, cols);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename T, typename ComputeType, int pack_size, int cols_per_thread,\n         int thread_group_width, int rows_per_access>\ninline cudaError_t DispatchQuantizeWarpImplPadding(cudaStream_t stream, const T* src, int8_t* dst,\n                                                   T* quantize_factor, const int64_t rows,\n                                                   const int64_t cols) {\n  if (cols == cols_per_thread * thread_group_width) {\n    return LaunchQuantizeWarpImpl<T, ComputeType, pack_size, cols_per_thread, thread_group_width,\n                                  rows_per_access, false>(stream, src, dst, quantize_factor, rows,\n                                                          cols);\n  } else {\n    return LaunchQuantizeWarpImpl<T, ComputeType, pack_size, cols_per_thread, thread_group_width,\n                                  rows_per_access, true>(stream, src, dst, quantize_factor, rows,\n                                                         cols);\n  }\n}\n\ntemplate<typename T, typename ComputeType, int pack_size>\ntypename std::enable_if<pack_size == 1, cudaError_t>::type DispatchQuantizeWarpImplCols(\n    cudaStream_t stream, const T* src, int8_t* dst, T* quantize_factor, const int64_t rows,\n    const int64_t cols) {\n  if (cols <= 0) { return cudaErrorInvalidValue; }\n#define DEFINE_ONE_ELIF(thread_group_width)                                                       \\\n  else if (cols <= (thread_group_width)*pack_size) {                      
                        \\\n    if (rows % 2 == 0) {                                                                          \\\n      return DispatchQuantizeWarpImplPadding<T, ComputeType, pack_size, pack_size,                \\\n                                             thread_group_width, 2>(stream, src, dst,             \\\n                                                                    quantize_factor, rows, cols); \\\n    } else {                                                                                      \\\n      return DispatchQuantizeWarpImplPadding<T, ComputeType, pack_size, pack_size,                \\\n                                             thread_group_width, 1>(stream, src, dst,             \\\n                                                                    quantize_factor, rows, cols); \\\n    }                                                                                             \\\n  }\n  DEFINE_ONE_ELIF(1)\n  DEFINE_ONE_ELIF(2)\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n#define DEFINE_ONE_ELIF(col)                                                              \\\n  else if (cols <= (col)*kWarpSize) {                                                     \\\n    return DispatchQuantizeWarpImplPadding<T, ComputeType, pack_size, col, kWarpSize, 1>( \\\n        stream, src, dst, quantize_factor, rows, cols);                                   \\\n  }\n  DEFINE_ONE_ELIF(2)\n  DEFINE_ONE_ELIF(3)\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(5)\n  DEFINE_ONE_ELIF(6)\n  DEFINE_ONE_ELIF(7)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(9)\n  DEFINE_ONE_ELIF(10)\n  DEFINE_ONE_ELIF(11)\n  DEFINE_ONE_ELIF(12)\n  DEFINE_ONE_ELIF(13)\n  DEFINE_ONE_ELIF(14)\n  DEFINE_ONE_ELIF(15)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(17)\n  DEFINE_ONE_ELIF(18)\n  DEFINE_ONE_ELIF(19)\n  DEFINE_ONE_ELIF(20)\n  DEFINE_ONE_ELIF(21)\n  DEFINE_ONE_ELIF(22)\n  DEFINE_ONE_ELIF(23)\n  DEFINE_ONE_ELIF(24)\n  DEFINE_ONE_ELIF(25)\n  DEFINE_ONE_ELIF(26)\n  DEFINE_ONE_ELIF(27)\n  DEFINE_ONE_ELIF(28)\n  DEFINE_ONE_ELIF(29)\n  DEFINE_ONE_ELIF(30)\n  DEFINE_ONE_ELIF(31)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n  else {\n    return cudaErrorInvalidValue;\n  }\n}\n\ntemplate<typename T, typename ComputeType, int pack_size>\ntypename std::enable_if<pack_size == 2, cudaError_t>::type DispatchQuantizeWarpImplCols(\n    cudaStream_t stream, const T* src, int8_t* dst, T* quantize_factor, const int64_t rows,\n    const int64_t cols) {\n  if (cols <= 0) { return cudaErrorInvalidValue; }\n#define DEFINE_ONE_ELIF(thread_group_width)                                                       \\\n  else if (cols <= (thread_group_width)*pack_size) {                                              \\\n    if (rows % 2 == 0) {                                                                          \\\n      return DispatchQuantizeWarpImplPadding<T, ComputeType, pack_size, pack_size,                \\\n                                             thread_group_width, 2>(stream, src, dst,             \\\n                                                                    quantize_factor, rows, cols); \\\n    } else {                                                                                      \\\n      return DispatchQuantizeWarpImplPadding<T, ComputeType, pack_size, pack_size,                \\\n                                             thread_group_width, 1>(stream, src, dst,             \\\n                                          
                          quantize_factor, rows, cols); \\\n    }                                                                                             \\\n  }\n  DEFINE_ONE_ELIF(1)\n  DEFINE_ONE_ELIF(2)\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n#define DEFINE_ONE_ELIF(col)                                                              \\\n  else if (cols <= (col)*kWarpSize) {                                                     \\\n    return DispatchQuantizeWarpImplPadding<T, ComputeType, pack_size, col, kWarpSize, 1>( \\\n        stream, src, dst, quantize_factor, rows, cols);                                   \\\n  }\n  DEFINE_ONE_ELIF(4)\n  DEFINE_ONE_ELIF(6)\n  DEFINE_ONE_ELIF(8)\n  DEFINE_ONE_ELIF(10)\n  DEFINE_ONE_ELIF(12)\n  DEFINE_ONE_ELIF(14)\n  DEFINE_ONE_ELIF(16)\n  DEFINE_ONE_ELIF(18)\n  DEFINE_ONE_ELIF(20)\n  DEFINE_ONE_ELIF(22)\n  DEFINE_ONE_ELIF(24)\n  DEFINE_ONE_ELIF(26)\n  DEFINE_ONE_ELIF(28)\n  DEFINE_ONE_ELIF(30)\n  DEFINE_ONE_ELIF(32)\n#undef DEFINE_ONE_ELIF\n  else {\n    return cudaErrorInvalidValue;\n  }\n}\n\ntemplate<typename T, typename ComputeType>\nstruct DispatchQuantizeWarpImplPackSize {\n  cudaError_t operator()(cudaStream_t stream, const T* src, int8_t* dst, T* quantize_factor,\n                         const int64_t rows, const int64_t cols) {\n    if (cols % 2 == 0) {\n      return DispatchQuantizeWarpImplCols<T, ComputeType, 2>(stream, src, dst, quantize_factor,\n                                                             rows, cols);\n    } else {\n      return DispatchQuantizeWarpImplCols<T, ComputeType, 1>(stream, src, dst, quantize_factor,\n                                                             rows, cols);\n    }\n  }\n};\n\ntemplate<typename T, typename ComputeType, typename IDX, int pack_size>\n__global__ void DequantizeKernel(const int8_t* x, T* quantize_factor, T* out, IDX col_size,\n                                 IDX elem_cnt);\n\ntemplate<typename T, typename ComputeType, typename IDX, int pack_size>\n__global__ void DequantizeKernel(const int8_t* x, T* quantize_factor, T* out, IDX col_size,\n                                 IDX elem_cnt) {\n  IDX global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;\n\n  for (int index = global_thread_id * pack_size; index < elem_cnt;\n       index += gridDim.x * blockDim.x * pack_size) {\n    IDX quantize_factor_idx = index / col_size;\n    ComputeType quantize_factor_val = static_cast<ComputeType>(quantize_factor[quantize_factor_idx])\n                                      / static_cast<ComputeType>(127.0);\n    using LoadPackType = cuda::elementwise::PackType<int8_t, pack_size>;\n    using LoadPack = cuda::elementwise::Pack<int8_t, pack_size>;\n    using StorePackType = cuda::elementwise::PackType<T, pack_size>;\n    using StorePack = cuda::elementwise::Pack<T, pack_size>;\n    LoadPack load_pack{};\n    StorePack store_pack{};\n    load_pack.storage = *(reinterpret_cast<const LoadPackType*>(x) + index / pack_size);\n#pragma unroll\n    for (int i = 0; i < pack_size; i++) {\n      store_pack.elem[i] =\n          static_cast<T>(static_cast<ComputeType>(load_pack.elem[i]) * quantize_factor_val);\n    }\n    *(reinterpret_cast<StorePackType*>(out) + index / pack_size) = store_pack.storage;\n  }\n}\n\ntemplate<typename T, typename ComputeType, typename IDX, int pack_size>\ncudaError_t DispatchDequantizeKernelPackSize(cudaStream_t stream, const int8_t* src,\n                                             T* quantize_factor, 
T* dst, const int64_t col_size,\n                                             const int64_t elem_cnt) {\n  const int64_t pack_num = elem_cnt / pack_size;\n  int grid_size = 0;\n  cudaError_t err = cuda::elementwise::GetNumBlocks(pack_num, &grid_size);\n  if (err != cudaSuccess) { return err; }\n  DequantizeKernel<T, ComputeType, IDX, pack_size>\n      <<<grid_size, cuda::elementwise::kBlockSize, 0, stream>>>(src, quantize_factor, dst, col_size,\n                                                                elem_cnt);\n  return cudaSuccess;\n}\n\ntemplate<typename T, typename ComputeType, typename IDX>\ninline cudaError_t LaunchDequantizeKernel(cudaStream_t stream, const int8_t* src,\n                                          T* quantize_factor, T* dst, const int64_t col_size,\n                                          const int64_t elem_cnt) {\n  constexpr int quantized_src_pack_size = cuda::elementwise::PackSize<int8_t>();\n  constexpr int dst_pack_size = cuda::elementwise::PackSize<T>();\n  int launch_pack_size = std::min(quantized_src_pack_size, dst_pack_size);\n  if (launch_pack_size == 8 && col_size % 8 == 0) {\n    cudaError_t err = DispatchDequantizeKernelPackSize<T, ComputeType, IDX, 8>(\n        stream, src, quantize_factor, dst, col_size, elem_cnt);\n    if (err != cudaSuccess) { return err; }\n  } else if (launch_pack_size == 4 && col_size % 4 == 0) {\n    cudaError_t err = DispatchDequantizeKernelPackSize<T, ComputeType, IDX, 4>(\n        stream, src, quantize_factor, dst, col_size, elem_cnt);\n    if (err != cudaSuccess) { return err; }\n  } else if (launch_pack_size == 2 && col_size % 2 == 0) {\n    cudaError_t err = DispatchDequantizeKernelPackSize<T, ComputeType, IDX, 2>(\n        stream, src, quantize_factor, dst, col_size, elem_cnt);\n    if (err != cudaSuccess) { return err; }\n  } else {\n    cudaError_t err = DispatchDequantizeKernelPackSize<T, ComputeType, IDX, 1>(\n        stream, src, quantize_factor, dst, col_size, elem_cnt);\n    if (err != cudaSuccess) { return err; }\n  }\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename T>\nstruct DefaultComputeType {\n  using type = T;\n};\n\ntemplate<>\nstruct DefaultComputeType<half> {\n  using type = float;\n};\n\ntemplate<typename T, typename IDX>\nclass EmbeddingShuffleKernel final : public user_op::OpKernel {\n public:\n  EmbeddingShuffleKernel() : current_iter_(0) {}\n  ~EmbeddingShuffleKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<DataShuffleKernelState<IDX>>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<DataShuffleKernelState<IDX>*>(state);\n    CHECK(kernel_state != nullptr);\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    std::unique_ptr<embedding::TmpBufferAllocator> allocator =\n        embedding_state->NewTmpBufferAllocator(ctx);\n    embedding_state->OnEmbeddingShuffleStart(ctx, current_iter_);\n    const user_op::Tensor* num_unique_matrix = ctx->Tensor4ArgNameAndIndex(\"num_unique_matrix\", 0);\n    const user_op::Tensor* cur_rank_inverse_indices =\n        ctx->Tensor4ArgNameAndIndex(\"cur_rank_inverse_indices\", 0);\n    const user_op::Tensor* inverse_unique_partition_indices =\n        
ctx->Tensor4ArgNameAndIndex(\"inverse_unique_partition_indices\", 0);\n    user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex(\"embeddings\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    ncclComm_t comm = kernel_state->comm();\n    using ComputeType = typename DefaultComputeType<T>::type;\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    IDX* host_num_unique_matrix = kernel_state->HostNumUniqueMatrix();\n    DataType data_type = embeddings->data_type();\n    const int64_t num_ids = inverse_unique_partition_indices->shape_view().elem_cnt();\n    const int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n    const bool skip_last_gather = ctx->Attr<bool>(\"skip_last_gather\");\n    bool enable_quantized_comm_env_var =\n        ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\", false);\n    bool enable_quantized_comm = enable_quantized_comm_env_var && (embedding_size < kMaxColSize);\n    if (enable_quantized_comm_env_var && !enable_quantized_comm) {\n      LOG(WARNING) << \"Only envrionment variable ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM=1 and \"\n                      \"embedding_size less equal than 1024 can use quantized communication. \";\n    }\n    cudaStream_t cuda_stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();\n    const std::vector<uint32_t>& num_unique_matrix_vec =\n        embedding_state->GetIdNumUniqueMatrix(current_iter_);\n    CHECK_EQ(sizeof(IDX), sizeof(uint32_t)) << \"assume sizeof(IDX) equals to sizeof(uint32_t)\";\n    ;\n    std::memcpy(host_num_unique_matrix, num_unique_matrix_vec.data(),\n                parallel_num * parallel_num * sizeof(IDX));\n    uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_);\n\n    int64_t cur_rank_num_ids = 0;\n    for (int64_t i = 0; i < parallel_num; ++i) {\n      cur_rank_num_ids += host_num_unique_matrix[i * parallel_num + parallel_id];\n    }\n    int64_t unique_partitioned_num_ids = 0;\n    for (int64_t i = 0; i < parallel_num; ++i) {\n      unique_partitioned_num_ids += host_num_unique_matrix[parallel_id * parallel_num + i];\n    }\n    const T* cur_rank_embeddings_ptr = reinterpret_cast<const T*>(\n        embedding_state->EmbeddingShuffleCurRankEmbeddings(current_iter_));\n    if (!enable_quantized_comm) {\n      // 1. reverse cur_rank unique, from (num_unique, embedding_size) to (cur_rank_num_ids,\n      // embedding_size)\n      void* reverse_unique_cur_rank_embeddings;\n      allocator->Allocate(&reverse_unique_cur_rank_embeddings,\n                          cur_rank_num_ids * embedding_size * sizeof(T));\n      GatherKernelUtilImpl<DeviceType::kCUDA, T, IDX>::Forward(\n          ctx->stream(), reinterpret_cast<const IDX*>(cur_rank_inverse_indices->dptr()),\n          cur_rank_num_ids, cur_rank_embeddings_ptr, Shape({1, num_unique, embedding_size}),\n          reinterpret_cast<T*>(reverse_unique_cur_rank_embeddings), 0);\n\n      // 2. 
send recv embedding, from (cur_rank_num_ids, embedding_size) to\n      // (unique_partitioned_num_ids, embedding_size)\n      if (skip_last_gather) {\n        data_shuffle::ShuffleEmbeddings(cuda_stream, comm, parallel_id, parallel_num, num_ids,\n                                        embedding_size, data_type, host_num_unique_matrix,\n                                        reinterpret_cast<T*>(reverse_unique_cur_rank_embeddings),\n                                        embeddings->mut_dptr<T>());\n        allocator->Free(reverse_unique_cur_rank_embeddings);\n      } else {\n        void* received_embeddings;  // T\n        allocator->Allocate(&received_embeddings, GetCudaAlignedSize(unique_partitioned_num_ids\n                                                                     * embedding_size * sizeof(T)));\n\n        data_shuffle::ShuffleEmbeddings(cuda_stream, comm, parallel_id, parallel_num, num_ids,\n                                        embedding_size, data_type, host_num_unique_matrix,\n                                        reinterpret_cast<T*>(reverse_unique_cur_rank_embeddings),\n                                        reinterpret_cast<T*>(received_embeddings));\n        allocator->Free(reverse_unique_cur_rank_embeddings);\n\n        // 3. reverse unique_partition, from (unique_partitioned_num_ids, embedding_size) to\n        // (num_ids, embedding_size)\n        GatherKernelUtilImpl<DeviceType::kCUDA, T, IDX>::Forward(\n            ctx->stream(), reinterpret_cast<const IDX*>(inverse_unique_partition_indices->dptr()),\n            num_ids, reinterpret_cast<T*>(received_embeddings),\n            Shape({1, unique_partitioned_num_ids, embedding_size}), embeddings->mut_dptr<T>(), 0);\n        allocator->Free(received_embeddings);\n      }\n    } else {\n      CHECK(!skip_last_gather) << \"when enable_quantized_comm, should not use fused kernel.\";\n      // 1. quantize cur_rank_embeddings, from (num_unique, embedding_size) T to (num_unique,\n      // embedding_size) int8_t, and get (num_unique,) T factor\n      void* quantize_cur_rank_embeddings;  // int8_t\n      allocator->Allocate(&quantize_cur_rank_embeddings,\n                          num_unique * embedding_size * sizeof(int8_t));\n      void* cur_rank_quantize_factor;  // T\n      allocator->Allocate(&cur_rank_quantize_factor, num_unique * sizeof(T));\n      DispatchQuantizeWarpImplPackSize<T, ComputeType>()(\n          cuda_stream, cur_rank_embeddings_ptr,\n          reinterpret_cast<int8_t*>(quantize_cur_rank_embeddings),\n          reinterpret_cast<T*>(cur_rank_quantize_factor), num_unique, embedding_size);\n      // 2. reverse cur_rank unique, from (num_unique, embedding_size) to (cur_rank_num_ids,\n      // embedding_size)\n      void* reverse_unique_cur_rank_embeddings;  // int8_t\n\n      allocator->Allocate(&reverse_unique_cur_rank_embeddings,\n                          cur_rank_num_ids * embedding_size * sizeof(int8_t));\n\n      GatherKernelUtilImpl<DeviceType::kCUDA, int8_t, IDX>::Forward(\n          ctx->stream(), reinterpret_cast<const IDX*>(cur_rank_inverse_indices->dptr()),\n          cur_rank_num_ids, reinterpret_cast<int8_t*>(quantize_cur_rank_embeddings),\n          Shape({1, num_unique, embedding_size}),\n          reinterpret_cast<int8_t*>(reverse_unique_cur_rank_embeddings), 0);\n      allocator->Free(quantize_cur_rank_embeddings);\n
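\n      // Note (assumption, inferred from DequantizeKernel and the DispatchQuantize* helpers\n      // above): the scheme appears to be symmetric per-row int8 quantization, roughly\n      //   factor[r] = max_j |row_r[j]|                    (one T scale per row)\n      //   q[r][j]   = round(row_r[j] * 127 / factor[r])   (int8 payload)\n      //   x_hat     = q[r][j] * factor[r] / 127           (what DequantizeKernel computes)\n      // so the int8 tensor plus one factor per row is enough to reconstruct the embeddings.\n\n      // 3. 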
reverse cur_rank quantize factor unique, from (num_unique) to (cur_rank_num_ids)\n      void* reverse_cur_rank_quantize_factor;  // T\n      allocator->Allocate(&reverse_cur_rank_quantize_factor, cur_rank_num_ids * sizeof(T));\n\n      GatherKernelUtilImpl<DeviceType::kCUDA, T, IDX>::Forward(\n          ctx->stream(), reinterpret_cast<const IDX*>(cur_rank_inverse_indices->dptr()),\n          cur_rank_num_ids, reinterpret_cast<T*>(cur_rank_quantize_factor),\n          Shape({1, num_unique, 1}), reinterpret_cast<T*>(reverse_cur_rank_quantize_factor), 0);\n      allocator->Free(cur_rank_quantize_factor);\n      // 4. send recv embedding and factor, from (cur_rank_num_ids, embedding_size) to\n      // (unique_partitioned_num_ids, embedding_size)\n      void* received_embeddings;   // int8_t\n      void* recv_quantize_factor;  // T\n      allocator->Allocate(&received_embeddings,\n                          unique_partitioned_num_ids * embedding_size * sizeof(int8_t));\n      allocator->Allocate(&recv_quantize_factor, unique_partitioned_num_ids * sizeof(T));\n\n      data_shuffle::ShuffleEmbeddings(cuda_stream, comm, parallel_id, parallel_num, num_ids,\n                                      embedding_size, data_type, host_num_unique_matrix,\n                                      reinterpret_cast<int8_t*>(reverse_unique_cur_rank_embeddings),\n                                      reinterpret_cast<int8_t*>(received_embeddings),\n                                      reinterpret_cast<T*>(reverse_cur_rank_quantize_factor),\n                                      reinterpret_cast<T*>(recv_quantize_factor));\n      allocator->Free(reverse_unique_cur_rank_embeddings);\n      allocator->Free(reverse_cur_rank_quantize_factor);\n\n      // 5. reverse unique_partition, from (unique_partitioned_num_ids, embedding_size) to (num_ids,\n      // embedding_size)\n      void* reverse_recv_quantize_cur_rank_embeddings;  // int8_t\n      allocator->Allocate(&reverse_recv_quantize_cur_rank_embeddings,\n                          num_ids * embedding_size * sizeof(int8_t));\n\n      GatherKernelUtilImpl<DeviceType::kCUDA, int8_t, IDX>::Forward(\n          ctx->stream(), reinterpret_cast<const IDX*>(inverse_unique_partition_indices->dptr()),\n          num_ids, reinterpret_cast<int8_t*>(received_embeddings),\n          Shape({1, unique_partitioned_num_ids, embedding_size}),\n          reinterpret_cast<int8_t*>(reverse_recv_quantize_cur_rank_embeddings), 0);\n      allocator->Free(received_embeddings);\n      // 6. reverse unique_partition_factor, from (unique_partitioned_num_ids) to (num_ids)\n      void* reverse_recv_quantize_factor;  // T\n      allocator->Allocate(&reverse_recv_quantize_factor, num_ids * sizeof(T));\n\n      GatherKernelUtilImpl<DeviceType::kCUDA, T, IDX>::Forward(\n          ctx->stream(), reinterpret_cast<const IDX*>(inverse_unique_partition_indices->dptr()),\n          num_ids, reinterpret_cast<T*>(recv_quantize_factor),\n          Shape({1, unique_partitioned_num_ids, 1}),\n          reinterpret_cast<T*>(reverse_recv_quantize_factor), 0);\n      allocator->Free(recv_quantize_factor);\n\n      // 7. 
dequantize embeddings, from (num_ids, embedding_size) int8_t to (num_ids,\n      // embedding_size) T\n      int32_t dequantize_row_size = num_ids;\n      IDX dequantize_elem_cnt = dequantize_row_size * embedding_size;\n      OF_CUDA_CHECK((LaunchDequantizeKernel<T, ComputeType, IDX>(\n          cuda_stream, reinterpret_cast<int8_t*>(reverse_recv_quantize_cur_rank_embeddings),\n          reinterpret_cast<T*>(reverse_recv_quantize_factor), embeddings->mut_dptr<T>(),\n          embedding_size, dequantize_elem_cnt)));\n      allocator->Free(reverse_recv_quantize_cur_rank_embeddings);\n      allocator->Free(reverse_recv_quantize_factor);\n    }\n    embedding_state->OnEmbeddingShuffleEnd(ctx, current_iter_);\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\n#define REGISTER_CUDA_EMBEDDING_SHUFFLE_KERNEL(t_dtype_pair, idx_dtype_pair)                      \\\n  REGISTER_USER_KERNEL(\"embedding_shuffle\")                                                       \\\n      .SetCreateFn<EmbeddingShuffleKernel<OF_PP_PAIR_FIRST(t_dtype_pair),                         \\\n                                          OF_PP_PAIR_FIRST(idx_dtype_pair)>>()                    \\\n      .SetIsMatchedHob(                                                                           \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                         \\\n          && (user_op::HobDataType(\"cur_rank_embeddings\", 0) == OF_PP_PAIR_SECOND(t_dtype_pair))  \\\n          && ((user_op::HobAttr<bool>(\"skip_last_gather\") == false)                               \\\n              || (!embedding::UseEmbeddingShuffleP2PKernel(OF_PP_PAIR_SECOND(t_dtype_pair),       \\\n                                                           OF_PP_PAIR_SECOND(idx_dtype_pair))))   \\\n          && (user_op::HobDataType(\"num_unique_matrix\", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair))) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                         \\\n        const user_op::TensorDesc& inverse_unique_partition_indices =                             \\\n            ctx->InputTensorDesc(\"inverse_unique_partition_indices\", 0);                          \\\n        const int64_t num_ids = inverse_unique_partition_indices.shape().elem_cnt();              \\\n        const int64_t parallel_num = ctx->parallel_ctx().parallel_num();                          \\\n        const int64_t cur_rank_max_num_ids = parallel_num * num_ids;                              \\\n        const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");                      \\\n        bool enable_quantized_comm =                                                              \\\n            ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\", false)             \\\n            && (embedding_size < kMaxColSize);                                                    \\\n        size_t tmp_size = 0;                                                                      \\\n        if (embedding::UseDynamicMemoryAllocation()) { return tmp_size; }                         \\\n        if (!enable_quantized_comm) {                                                             \\\n          size_t reverse_cur_rank_embeddings_size = GetCudaAlignedSize(                           \\\n              cur_rank_max_num_ids * embedding_size * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair)));    \\\n          size_t 
recv_unique_embeddings_size = reverse_cur_rank_embeddings_size;                  \\\n          tmp_size = reverse_cur_rank_embeddings_size + recv_unique_embeddings_size;              \\\n        } else {                                                                                  \\\n          size_t total_elem_cnt = cur_rank_max_num_ids * embedding_size;                          \\\n          size_t reverse_cur_rank_embeddings_size =                                               \\\n              GetCudaAlignedSize(total_elem_cnt * sizeof(int8_t));                                \\\n          size_t recv_unique_embeddings = reverse_cur_rank_embeddings_size;                       \\\n          size_t quantize_cur_rank_embeddings_size = reverse_cur_rank_embeddings_size;            \\\n          size_t reverse_recv_quantize_cur_rank_embeddings_size =                                 \\\n              reverse_cur_rank_embeddings_size;                                                   \\\n          size_t cur_rank_quantize_factor_size =                                                  \\\n              GetCudaAlignedSize(cur_rank_max_num_ids * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair)));  \\\n          size_t reverse_cur_rank_quantize_factor_size = cur_rank_quantize_factor_size;           \\\n          size_t recv_quantize_factor_size = cur_rank_quantize_factor_size;                       \\\n          size_t reverse_recv_quantize_factor_size = cur_rank_quantize_factor_size;               \\\n          tmp_size = reverse_cur_rank_embeddings_size + recv_unique_embeddings                    \\\n                     + quantize_cur_rank_embeddings_size                                          \\\n                     + reverse_recv_quantize_cur_rank_embeddings_size                             \\\n                     + cur_rank_quantize_factor_size + reverse_cur_rank_quantize_factor_size      \\\n                     + recv_quantize_factor_size + reverse_recv_quantize_factor_size;             \\\n        }                                                                                         \\\n        return tmp_size;                                                                          \\\n      });\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_EMBEDDING_SHUFFLE_KERNEL,\n                                 FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ)\n\ntemplate<typename T, typename IDX>\nclass EmbeddingGradientShuffleKernel final : public user_op::OpKernel {\n public:\n  EmbeddingGradientShuffleKernel() : current_iter_(0){};\n  ~EmbeddingGradientShuffleKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<DataShuffleKernelState<IDX>>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<DataShuffleKernelState<IDX>*>(state);\n    CHECK(kernel_state != nullptr);\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    std::unique_ptr<embedding::TmpBufferAllocator> allocator =\n        embedding_state->NewTmpBufferAllocator(ctx);\n    const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex(\"embedding_grad\", 0);\n\n    const user_op::Tensor* num_unique_matrix = ctx->Tensor4ArgNameAndIndex(\"num_unique_matrix\", 0);\n    
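// Overview (see the numbered step comments below): the gradient shuffle runs the forward\n    // embedding shuffle in reverse -- sum grads per destination rank, exchange them between ranks\n    // with ShuffleEmbeddingsGrad, then sum the received rows into this rank's unique rows.\n    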
const user_op::Tensor* cur_rank_inverse_indices =\n        ctx->Tensor4ArgNameAndIndex(\"cur_rank_inverse_indices\", 0);\n    const user_op::Tensor* inverse_unique_partition_indices =\n        ctx->Tensor4ArgNameAndIndex(\"inverse_unique_partition_indices\", 0);\n    user_op::Tensor* cur_rank_unique_embedding_grad =\n        ctx->Tensor4ArgNameAndIndex(\"cur_rank_unique_embedding_grad\", 0);\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    const bool only_zero_valid_grad = ctx->Attr<bool>(\"only_zero_valid_grad\");\n    IDX* host_num_unique_matrix = kernel_state->HostNumUniqueMatrix();\n    DataType data_type = embedding_grad->data_type();\n    const int64_t num_ids = inverse_unique_partition_indices->shape_view().elem_cnt();\n    const int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n    const int64_t padded_embedding_size =\n        data_shuffle::GetPaddedEmbeddingSize(data_type, embedding_size);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    ncclComm_t comm = kernel_state->comm();\n    using ComputeType = typename DefaultComputeType<T>::type;\n    bool enable_quantized_comm_env_var =\n        ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\", false);\n    bool enable_quantized_comm =\n        enable_quantized_comm_env_var && (padded_embedding_size < kMaxColSize);\n    if (enable_quantized_comm_env_var && !enable_quantized_comm) {\n      LOG(WARNING) << \"Quantized communication is only used when environment variable \"\n                      \"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM=1 is set and embedding_size is \"\n                      \"less than 1024. \";\n    }\n    const bool skip_first_scatter = ctx->Attr<bool>(\"skip_first_scatter\");\n    cudaStream_t cuda_stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();\n    const std::vector<uint32_t>& num_unique_matrix_vec =\n        embedding_state->GetIdNumUniqueMatrix(current_iter_);\n    CHECK_EQ(sizeof(IDX), sizeof(uint32_t)) << \"assume sizeof(IDX) equals sizeof(uint32_t)\";\n    std::memcpy(host_num_unique_matrix, num_unique_matrix_vec.data(),\n                parallel_num * parallel_num * sizeof(IDX));\n    uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_);\n\n    int64_t cur_rank_num_ids = 0;\n    for (int64_t i = 0; i < parallel_num; ++i) {\n      cur_rank_num_ids += host_num_unique_matrix[i * parallel_num + parallel_id];\n    }\n    int64_t unique_partitioned_num_ids = 0;\n    for (int64_t i = 0; i < parallel_num; ++i) {\n      unique_partitioned_num_ids += host_num_unique_matrix[parallel_id * parallel_num + i];\n    }\n    if (!enable_quantized_comm) {\n      // 1. 
sum to unique grad, from (num_ids, embedding_size) to (unique_partitioned_num_ids,\n      // padded_embedding_size)\n      void* unique_partition_embedding_grad;  // T\n      allocator->Allocate(&unique_partition_embedding_grad,\n                          unique_partitioned_num_ids * padded_embedding_size * sizeof(T));\n\n      const T* unique_embedding_grad_ptr;\n      if (skip_first_scatter) {\n        unique_embedding_grad_ptr = embedding_grad->dptr<T>();\n      } else {\n        data_shuffle::UniquePartitionEmbeddingGrad(\n            ctx->stream(), unique_partitioned_num_ids, num_ids, embedding_size,\n            padded_embedding_size, host_num_unique_matrix, embedding_grad->dptr<T>(),\n            reinterpret_cast<const IDX*>(inverse_unique_partition_indices->dptr()),\n            reinterpret_cast<T*>(unique_partition_embedding_grad));\n        unique_embedding_grad_ptr = reinterpret_cast<T*>(unique_partition_embedding_grad);\n      }\n      // 2. send recv grad, from (unique_partitioned_num_ids, padded_embedding_size) to\n      // (cur_rank_num_ids, padded_embedding_size)\n      void* received_embedding_grad;  // T\n      allocator->Allocate(&received_embedding_grad,\n                          cur_rank_num_ids * padded_embedding_size * sizeof(T));\n\n      data_shuffle::ShuffleEmbeddingsGrad(cuda_stream, comm, parallel_id, parallel_num, num_ids,\n                                          padded_embedding_size, data_type, host_num_unique_matrix,\n                                          unique_embedding_grad_ptr,\n                                          reinterpret_cast<T*>(received_embedding_grad));\n\n      // 3. sum to unique grad, from (cur_rank_num_ids, padded_embedding_size) to (num_unique,\n      // padded_embedding_size), then slice to out from (num_unique, padded_embedding_size) to\n      // (num_unique, embedding_size). The whole cur_rank_unique_embedding_grad tensor should be\n      // memset beforehand for amp count_not_finite.\n      // use unique_partition_embedding_grad as UniqueCurRankEmbeddingGrad buffer.\n      T* buffer_ptr = reinterpret_cast<T*>(unique_partition_embedding_grad);\n      data_shuffle::UniqueCurRankEmbeddingGrad<T, IDX>(\n          ctx->stream(), data_type, cur_rank_num_ids, num_unique, embedding_size,\n          padded_embedding_size, only_zero_valid_grad,\n          cur_rank_unique_embedding_grad->shape_view().elem_cnt(),\n          reinterpret_cast<T*>(received_embedding_grad),\n          reinterpret_cast<const IDX*>(cur_rank_inverse_indices->dptr()),\n          cur_rank_unique_embedding_grad->mut_dptr<T>(), buffer_ptr);\n      allocator->Free(unique_partition_embedding_grad);\n      allocator->Free(received_embedding_grad);\n    } else {\n      CHECK(!skip_first_scatter) << \"when enable_quantized_comm, should not use fused kernel.\";\n      // 1. sum to unique grad, from (num_ids, embedding_size) to (unique_partitioned_num_ids,\n      // padded_embedding_size)\n      void* unique_partition_embedding_grad;  // T\n      allocator->Allocate(&unique_partition_embedding_grad,\n                          unique_partitioned_num_ids * padded_embedding_size * sizeof(T));\n\n      data_shuffle::UniquePartitionEmbeddingGrad(\n          ctx->stream(), unique_partitioned_num_ids, num_ids, embedding_size, padded_embedding_size,\n          host_num_unique_matrix, embedding_grad->dptr<T>(),\n          reinterpret_cast<const IDX*>(inverse_unique_partition_indices->dptr()),\n          reinterpret_cast<T*>(unique_partition_embedding_grad));\n
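\n      // Note: this mirrors steps 1-2 of the quantized forward path in EmbeddingShuffleKernel; the\n      // summed gradient gets a fresh per-row factor before the int8 exchange, presumably because\n      // summation changes each row's dynamic range.\n\n      // 2. 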
Quantize unique_partition_embedding_grad, get\n      // quantize_cur_rank_embedding_grad(unique_partitioned_num_ids, padded_embedding_size) int8_t\n      // and cur_rank_quantize_factor(unique_partitioned_num_ids) T\n      void* quantize_cur_rank_embedding_grad;  // int8_t\n      allocator->Allocate(&quantize_cur_rank_embedding_grad,\n                          unique_partitioned_num_ids * padded_embedding_size * sizeof(int8_t));\n      void* cur_rank_quantize_factor;  // T\n      allocator->Allocate(&cur_rank_quantize_factor, unique_partitioned_num_ids * sizeof(T));\n\n      DispatchQuantizeWarpImplPackSize<T, ComputeType>()(\n          cuda_stream, reinterpret_cast<T*>(unique_partition_embedding_grad),\n          reinterpret_cast<int8_t*>(quantize_cur_rank_embedding_grad),\n          reinterpret_cast<T*>(cur_rank_quantize_factor), unique_partitioned_num_ids,\n          padded_embedding_size);\n\n      // 3. send recv grad, from (unique_partitioned_num_ids, padded_embedding_size) int8_t to\n      // (cur_rank_num_ids, padded_embedding_size) int8_t send recv quantize_factor, from\n      // (unique_partitioned_num_ids) T to (cur_rank_num_ids) T\n      void* received_embedding_grad;  // int8_t\n      allocator->Allocate(&received_embedding_grad,\n                          cur_rank_num_ids * padded_embedding_size * sizeof(int8_t));\n      void* received_cur_rank_quantize_factor;  // T\n      allocator->Allocate(&received_cur_rank_quantize_factor, cur_rank_num_ids * sizeof(T));\n\n      data_shuffle::ShuffleEmbeddingsGrad(\n          cuda_stream, comm, parallel_id, parallel_num, num_ids, padded_embedding_size, data_type,\n          host_num_unique_matrix, reinterpret_cast<int8_t*>(quantize_cur_rank_embedding_grad),\n          reinterpret_cast<int8_t*>(received_embedding_grad),\n          reinterpret_cast<T*>(cur_rank_quantize_factor),\n          reinterpret_cast<T*>(received_cur_rank_quantize_factor));\n      allocator->Free(quantize_cur_rank_embedding_grad);\n      allocator->Free(cur_rank_quantize_factor);\n\n      /*\n      Host num unique matrix:\n              |  Partition0  |  Partition1  |\n      | Rank0 |      2       |       4      |\n      | Rank1 |      3       |       3      |\n      After ShuffleEmbeddingGrads, each rank will exchange partition.\n      For example:\n      Rank0 will have (matrix[rank0][part0] + matrix[rank1][part0]) grad tensor.\n      Rank1 will have (matrix[rank0][part1] + matrix[rank1][part1]) grad tensor.\n      */\n      // 4. dequantize grad, from (cur_rank_num_ids, padded_embedding_size) int8_t to\n      // (cur_rank_num_ids, padded_embedding_size) T\n      void* dequantize_cur_rank_embedding_grad;  // T\n      allocator->Allocate(&dequantize_cur_rank_embedding_grad,\n                          cur_rank_num_ids * padded_embedding_size * sizeof(T));\n\n      OF_CUDA_CHECK((LaunchDequantizeKernel<T, ComputeType, IDX>(\n          cuda_stream, reinterpret_cast<int8_t*>(received_embedding_grad),\n          reinterpret_cast<T*>(received_cur_rank_quantize_factor),\n          reinterpret_cast<T*>(dequantize_cur_rank_embedding_grad), padded_embedding_size,\n          cur_rank_num_ids * padded_embedding_size)));\n      allocator->Free(received_embedding_grad);\n      allocator->Free(received_cur_rank_quantize_factor);\n\n      // use unique_partition_embedding_grad as UniqueCurRankEmbeddingGrad buffer.\n      T* buffer_ptr = reinterpret_cast<T*>(unique_partition_embedding_grad);\n      // 5. 
sum to unique grad, from (cur_rank_num_ids, padded_embedding_size) to (num_unique,\n      // padded_embedding_size), then slice to out from (num_unique, padded_embedding_size) to\n      // (num_unique, embedding_size). The whole cur_rank_unique_embedding_grad tensor should be\n      // memset beforehand for amp count_not_finite.\n      data_shuffle::UniqueCurRankEmbeddingGrad<T, IDX>(\n          ctx->stream(), data_type, cur_rank_num_ids, num_unique, embedding_size,\n          padded_embedding_size, only_zero_valid_grad,\n          cur_rank_unique_embedding_grad->shape_view().elem_cnt(),\n          reinterpret_cast<T*>(dequantize_cur_rank_embedding_grad),\n          reinterpret_cast<const IDX*>(cur_rank_inverse_indices->dptr()),\n          cur_rank_unique_embedding_grad->mut_dptr<T>(), buffer_ptr);\n      allocator->Free(unique_partition_embedding_grad);\n      allocator->Free(dequantize_cur_rank_embedding_grad);\n    }\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\n#define REGISTER_CUDA_EMBEDDING_GRADIENT_SHUFFLE_KERNEL(t_dtype_pair, idx_dtype_pair)             \\\n  REGISTER_USER_KERNEL(\"embedding_gradient_shuffle\")                                              \\\n      .SetCreateFn<EmbeddingGradientShuffleKernel<OF_PP_PAIR_FIRST(t_dtype_pair),                 \\\n                                                  OF_PP_PAIR_FIRST(idx_dtype_pair)>>()            \\\n      .SetIsMatchedHob(                                                                           \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                         \\\n          && (user_op::HobDataType(\"embedding_grad\", 0) == OF_PP_PAIR_SECOND(t_dtype_pair))       \\\n          && ((user_op::HobAttr<bool>(\"skip_first_scatter\") == false)                             \\\n              || (!embedding::UseEmbeddingGradientShuffleP2PKernel(                               \\\n                  OF_PP_PAIR_SECOND(t_dtype_pair), OF_PP_PAIR_SECOND(idx_dtype_pair))))           \\\n          && (user_op::HobDataType(\"num_unique_matrix\", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair))) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                         \\\n        const user_op::TensorDesc& cur_rank_unique_embedding_grad =                               \\\n            ctx->InputTensorDesc(\"cur_rank_unique_embedding_grad\", 0);                            \\\n        size_t cur_rank_embedding_grad_num = cur_rank_unique_embedding_grad.shape().At(0);        \\\n        size_t embedding_size = cur_rank_unique_embedding_grad.shape().At(1);                     \\\n        size_t padded_embedding_size = data_shuffle::GetPaddedEmbeddingSize(                      \\\n            cur_rank_unique_embedding_grad.data_type(), embedding_size);                          \\\n        size_t cur_rank_embedding_grad_elem_cnt =                                                 \\\n            cur_rank_embedding_grad_num * padded_embedding_size;                                  \\\n        bool enable_quantized_comm =                                                              \\\n            ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\", false)             \\\n            && (padded_embedding_size < kMaxColSize);                                             \\\n        size_t tmp_size = 0;                                                                      \\\n        if 
(embedding::UseDynamicMemoryAllocation()) { return tmp_size; }                         \\\n        if (!enable_quantized_comm) {                                                             \\\n          size_t cur_rank_embedding_grad_size = GetCudaAlignedSize(                               \\\n              cur_rank_embedding_grad_elem_cnt * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair)));         \\\n          tmp_size = 2 * cur_rank_embedding_grad_size;                                            \\\n        } else {                                                                                  \\\n          size_t unique_partition_embedding_grad_size = GetCudaAlignedSize(                       \\\n              cur_rank_embedding_grad_elem_cnt * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair)));         \\\n          size_t received_embedding_grad_size =                                                   \\\n              GetCudaAlignedSize(cur_rank_embedding_grad_elem_cnt * sizeof(int8_t));              \\\n          size_t quantize_cur_rank_embedding_grad_size = received_embedding_grad_size;            \\\n          size_t cur_rank_quantize_factor_size = GetCudaAlignedSize(                              \\\n              cur_rank_embedding_grad_num * sizeof(OF_PP_PAIR_FIRST(t_dtype_pair)));              \\\n          size_t received_cur_rank_quantize_factor_size = cur_rank_quantize_factor_size;          \\\n          size_t dequantize_cur_rank_embedding_grad_size = unique_partition_embedding_grad_size;  \\\n          tmp_size = unique_partition_embedding_grad_size + received_embedding_grad_size          \\\n                     + quantize_cur_rank_embedding_grad_size + cur_rank_quantize_factor_size      \\\n                     + received_cur_rank_quantize_factor_size                                     \\\n                     + dequantize_cur_rank_embedding_grad_size;                                   \\\n        }                                                                                         \\\n        return tmp_size;                                                                          \\\n      });\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_EMBEDDING_GRADIENT_SHUFFLE_KERNEL,\n                                 FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ)\n\ntemplate<typename IDX>\nclass EmbeddingUniqueKeyValuePairKernelState final : public user_op::OpKernelState {\n public:\n  explicit EmbeddingUniqueKeyValuePairKernelState(user_op::KernelInitContext* ctx)\n      : device_index_(-1) {\n    OF_CUDA_CHECK(cudaGetDevice(&device_index_));\n    OF_CUDA_CHECK(cudaMallocHost(&host_num_keys_, sizeof(IDX)));\n    const std::string& embedding_name = ctx->Attr<std::string>(\"embedding_name\");\n    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n    embedding_state_ = Singleton<embedding::EmbeddingManager>::Get()->GetEmbeddingState(\n        embedding_name, parallel_id);\n  }\n  ~EmbeddingUniqueKeyValuePairKernelState() {\n    CudaCurrentDeviceGuard guard(device_index_);\n    OF_CUDA_CHECK(cudaFreeHost(host_num_keys_));\n  }\n\n  embedding::EmbeddingState* EmbeddingState() { return embedding_state_; }\n\n  IDX* HostNumKeys() { return host_num_keys_; }\n\n private:\n  int device_index_;\n  embedding::EmbeddingState* embedding_state_;\n  IDX* host_num_keys_;\n};\n\ntemplate<typename K, typename V, typename IDX>\nclass UniqueKeyValuePairKernel final : public user_op::OpKernel {\n public:\n  UniqueKeyValuePairKernel() : current_iter_(0){};\n  
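// Note: as in the kernels above, per-iteration results are stashed in EmbeddingState keyed by\n  // current_iter_, which Compute() increments once per call.\n  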
~UniqueKeyValuePairKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<EmbeddingUniqueKeyValuePairKernelState<IDX>>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<EmbeddingUniqueKeyValuePairKernelState<IDX>*>(state);\n    CHECK(kernel_state != nullptr);\n    const user_op::Tensor* keys = ctx->Tensor4ArgNameAndIndex(\"keys\", 0);\n    user_op::Tensor* num_unique = ctx->Tensor4ArgNameAndIndex(\"num_unique\", 0);\n    user_op::Tensor* unique_keys = ctx->Tensor4ArgNameAndIndex(\"unique_keys\", 0);\n    user_op::Tensor* unique_values = ctx->Tensor4ArgNameAndIndex(\"unique_values\", 0);\n    user_op::Tensor* inverse_indices = ctx->Tensor4ArgNameAndIndex(\"inverse_indices\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const int32_t num_tables = ctx->Attr<int32_t>(\"num_tables\");\n    const int64_t padding_idx = ctx->Attr<int64_t>(\"padding_idx\");\n    const bool has_padding_idx = ctx->Attr<bool>(\"has_padding_idx\");\n    const bool has_values = ctx->has_input(\"values\", 0);\n    const bool need_values_buffer = (!has_values && num_tables > 1);\n    size_t values_buffer_bytes =\n        need_values_buffer ? GetCudaAlignedSize(keys->shape_view().elem_cnt() * sizeof(V)) : 0;\n    const int64_t num_keys = keys->shape_view().elem_cnt();\n    const int64_t hash_capacity = num_keys;\n    const size_t workspace_bytes =\n        GetCudaAlignedSize(hash_capacity * sizeof(data_shuffle::TableEntry<K>));\n    CHECK_LE(values_buffer_bytes + workspace_bytes, tmp_buffer->shape_view().elem_cnt());\n    cudaStream_t cuda_stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();\n    const V* values_ptr;\n    if (has_values) {\n      const user_op::Tensor* values = ctx->Tensor4ArgNameAndIndex(\"values\", 0);\n      values_ptr = reinterpret_cast<const V*>(values->dptr());\n    } else if (need_values_buffer) {\n      V* values_buffer_ptr = reinterpret_cast<V*>(tmp_buffer->mut_dptr());\n      data_shuffle::GenerateTableIds<<<BlocksNum4ThreadsNum(num_keys), kCudaThreadsNumPerBlock, 0,\n                                       cuda_stream>>>(num_keys, num_tables, values_buffer_ptr);\n      values_ptr = values_buffer_ptr;\n    } else {\n      values_ptr = nullptr;\n    }\n    const bool need_process_table_ids = (has_values || num_tables > 1);\n    data_shuffle::TableEntry<K>* workspace_ptr = reinterpret_cast<data_shuffle::TableEntry<K>*>(\n        tmp_buffer->mut_dptr<char>() + values_buffer_bytes);\n    data_shuffle::UniqueAndPartition<K, V, IDX, embedding::GlobalUniqueHash>(\n        cuda_stream, num_keys, hash_capacity, 1, reinterpret_cast<const K*>(keys->dptr()),\n        values_ptr, reinterpret_cast<IDX*>(num_unique->mut_dptr()),\n        reinterpret_cast<K*>(unique_keys->mut_dptr()),\n        reinterpret_cast<V*>(unique_values->mut_dptr()),\n        reinterpret_cast<IDX*>(inverse_indices->mut_dptr()), workspace_ptr, workspace_bytes,\n        need_process_table_ids, has_padding_idx, padding_idx);\n\n    IDX* host_num_keys = kernel_state->HostNumKeys();\n    OF_CUDA_CHECK(cudaMemcpyAsync(host_num_keys, num_unique->mut_dptr(), sizeof(IDX),\n                                  cudaMemcpyDefault, cuda_stream));\n    
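// host_num_keys points to pinned host memory (cudaMallocHost in the kernel state), so the\n    // asynchronous copy above must be synchronized with the stream before the value is read on\n    // the host.\n    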
CHECK_JUST(ctx->stream()->Sync());\n    uint32_t num_unique_ids = *host_num_keys;\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    std::vector<uint32_t> num_unique_matrix_vec({num_unique_ids});\n    embedding_state->SetIdNumUniqueMatrix(num_unique_matrix_vec, current_iter_);\n    embedding_state->SetIdFinalNumUnique(num_unique_ids, current_iter_);\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\n#define REGISTER_CUDA_UNIQUE_KEY_VALUE_PAIR_KERNEL(k_dtype_pair, value_dtype_pair, idx_dtype_pair) \\\n  REGISTER_USER_KERNEL(\"unique_key_value_pair\")                                                    \\\n      .SetCreateFn<UniqueKeyValuePairKernel<OF_PP_PAIR_FIRST(k_dtype_pair),                        \\\n                                            OF_PP_PAIR_FIRST(value_dtype_pair),                    \\\n                                            OF_PP_PAIR_FIRST(idx_dtype_pair)>>()                   \\\n      .SetIsMatchedHob(                                                                            \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                          \\\n          && (user_op::HobDataType(\"keys\", 0) == OF_PP_PAIR_SECOND(k_dtype_pair))                  \\\n          && (user_op::HobDataType(\"inverse_indices\", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair))     \\\n          && (user_op::HobDataType(\"unique_values\", 0) == OF_PP_PAIR_SECOND(value_dtype_pair)))    \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                          \\\n        const user_op::TensorDesc& keys = ctx->InputTensorDesc(\"keys\", 0);                         \\\n        const int64_t num_keys = keys.shape().elem_cnt();                                          \\\n        const int64_t hash_capacity = num_keys;                                                    \\\n        const size_t workspace_bytes = GetCudaAlignedSize(                                         \\\n            hash_capacity * sizeof(data_shuffle::TableEntry<OF_PP_PAIR_FIRST(k_dtype_pair)>));     \\\n        const int32_t num_tables = ctx->Attr<int32_t>(\"num_tables\");                               \\\n        const bool has_values = ctx->has_input(\"values\", 0);                                       \\\n        const bool need_values_buffer = (!has_values && num_tables > 1);                           \\\n        size_t values_buffer_bytes =                                                               \\\n            need_values_buffer                                                                     \\\n                ? 
GetCudaAlignedSize(num_keys * sizeof(OF_PP_PAIR_FIRST(value_dtype_pair)))        \\\n                : 0;                                                                               \\\n        return workspace_bytes + values_buffer_bytes;                                              \\\n      });\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_UNIQUE_KEY_VALUE_PAIR_KERNEL, ID_DATA_TYPE_SEQ,\n                                 TABLE_ID_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ)\n\ntemplate<typename T, typename IDX>\nclass OneEmbeddingGatherKernel final : public user_op::OpKernel {\n public:\n  OneEmbeddingGatherKernel() : current_iter_(0) {}\n  ~OneEmbeddingGatherKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<EmbeddingUniqueKeyValuePairKernelState<IDX>>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<EmbeddingUniqueKeyValuePairKernelState<IDX>*>(state);\n    CHECK(kernel_state != nullptr);\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    embedding_state->OnEmbeddingGatherStart(ctx, current_iter_);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n    const int64_t num_indices = indices->shape_view().elem_cnt();\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_);\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    const T* in_ptr = reinterpret_cast<const T*>(embedding_state->EmbeddingGatherIn(current_iter_));\n    GatherKernelUtilImpl<DeviceType::kCUDA, T, IDX>::Forward(\n        ctx->stream(), reinterpret_cast<const IDX*>(indices->dptr()), num_indices, in_ptr,\n        Shape({1, num_unique, embedding_size}), out->mut_dptr<T>(), 0);\n    embedding_state->OnEmbeddingGatherEnd(ctx, current_iter_);\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\n#define REGISTER_ONE_EMBEDDING_GATHER_KERNEL(in_type, indices_type)                              \\\n  REGISTER_USER_KERNEL(\"one_embedding_gather\")                                                   \\\n      .SetCreateFn<                                                                              \\\n          OneEmbeddingGatherKernel<OF_PP_PAIR_FIRST(in_type), OF_PP_PAIR_FIRST(indices_type)>>() \\\n      .SetIsMatchedHob(                                                                          \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                        \\\n          && (user_op::HobDataType(\"in\", 0) == OF_PP_PAIR_SECOND(in_type))                       \\\n          && (user_op::HobDataType(\"indices\", 0) == OF_PP_PAIR_SECOND(indices_type)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_ONE_EMBEDDING_GATHER_KERNEL,\n                                 FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ)\n\nREGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT(\"id_shuffle\");\nREGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT(\"embedding_shuffle\");\nREGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT(\"embedding_gradient_shuffle\");\n\n}  // 
namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/deconv_cpu_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) {\n  return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N;\n}\n\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitive(DeviceType device_type,\n                                                          DataType data_type, bool transpose_a,\n                                                          bool transpose_b) {\n  const auto trans_a = GetBlasTransposeType(transpose_a);\n  const auto trans_b = GetBlasTransposeType(transpose_b);\n  return ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(device_type, data_type, trans_a,\n                                                                   trans_b);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewDeconvTransATransBMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, true, true);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewDeconvTransANoTransBMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, true, false);\n}\n\nauto DeconvTransATransBMatmulPrimitiveExists() {\n  return hob::make_custom(\"DeconvTransATransBMatmulPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewDeconvTransATransBMatmulPrimitive(&ctx).operator bool();\n                          });\n}\n\nauto DeconvTransANoTransBMatmulPrimitiveExists() {\n  return hob::make_custom(\"DeconvTransANoTransBMatmulPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewDeconvTransANoTransBMatmulPrimitive(&ctx).operator bool();\n                          });\n}\n\ntemplate<typename T>\nusing Col2ImFunc = void (*)(const T* col_buf, const ShapeView& in_shape,\n                            const ShapeView& weight_shape, const ShapeView& out_shape,\n                            const int32_t* strides, const int32_t* dilation_rate,\n                            const int32_t* padding_before, T* in_diff_ptr);\n\ntemplate<typename T>\nT* GetImgMutDptr(user_op::Tensor* tensor, int64_t idx) {\n  return tensor->mut_dptr<T>() + tensor->shape_view().Count(1) * idx;\n}\n\ntemplate<typename T>\nconst T* GetImgDptr(const user_op::Tensor* tensor, int64_t idx) {\n  return tensor->dptr<T>() + tensor->shape_view().Count(1) * idx;\n}\n\nsize_t 
CalcElemNumOfColBuf(const ShapeView& out_shape, const ShapeView& weight_shape,\n                           const int32_t idx_offset) {\n  int64_t col_buf_elem_cnt = 1;\n  int64_t ndims = out_shape.NumAxes() - 2;\n  for (size_t i = 0; i != ndims + 1; ++i) { col_buf_elem_cnt *= weight_shape.At(i + 1); }\n  for (size_t i = 0; i != ndims; ++i) { col_buf_elem_cnt *= out_shape.At(idx_offset + i); }\n  return col_buf_elem_cnt;\n}\n\ntemplate<typename T>\nclass ColBufWriter {\n public:\n  ColBufWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size,\n               int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size)\n      : src_ptr_(src_ptr),\n        dst_ptr_(dst_ptr),\n        c_size_(c_size),\n        id_size_(id_size),\n        ih_size_(ih_size),\n        iw_size_(iw_size),\n        od_size_(od_size),\n        oh_size_(oh_size),\n        ow_size_(ow_size) {}\n  virtual ~ColBufWriter() = default;\n  virtual void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0;\n  virtual void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0;\n  virtual void InvalidDFunc() = 0;\n  virtual void InvalidHFunc() = 0;\n  virtual void InvalidWFunc() = 0;\n  virtual void NextImCSize() = 0;\n\n protected:\n  const T* src_ptr_;\n  T* dst_ptr_;\n  int64_t c_size_;\n  int64_t id_size_;\n  int64_t ih_size_;\n  int64_t iw_size_;\n  int64_t od_size_;\n  int64_t oh_size_;\n  int64_t ow_size_;\n};\n\ntemplate<typename T>\nclass Col2ImWriter final : public ColBufWriter<T> {\n public:\n  Col2ImWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size,\n               int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size)\n      : ColBufWriter<T>::ColBufWriter(src_ptr, dst_ptr, c_size, id_size, ih_size, iw_size, od_size,\n                                      oh_size, ow_size) {}\n  ~Col2ImWriter() = default;\n  void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override {\n    this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw * this->iw_size_ + c] +=\n        *(this->src_ptr_++);\n  }\n  void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override {\n    this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw] += *(this->src_ptr_++);\n  }\n  void InvalidDFunc() override { this->src_ptr_ += this->od_size_; }\n  void InvalidHFunc() override { this->src_ptr_ += this->oh_size_; }\n  void InvalidWFunc() override { this->src_ptr_ += this->ow_size_; }\n  void NextImCSize() override { this->dst_ptr_ += this->c_size_; }\n};\n\ntemplate<typename T>\nusing DHWValidFunc = void (ColBufWriter<T>::*)(int64_t c, int64_t kd, int64_t kh, int64_t kw);\n\ntemplate<typename T>\nclass ColBufUtil final {\n public:\n  ColBufUtil(const ShapeView& in_shape, const ShapeView& out_shape, int32_t dhw_offset,\n             const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before)\n      : strides_(strides), dilation_rate_(dilation_rate), padding_before_(padding_before) {\n    id_num_ = in_shape.At(dhw_offset);\n    ih_num_ = in_shape.At(dhw_offset + 1);\n    iw_num_ = in_shape.At(dhw_offset + 2);\n    od_num_ = out_shape.At(dhw_offset);\n    oh_num_ = out_shape.At(dhw_offset + 1);\n    ow_num_ = out_shape.At(dhw_offset + 2);\n    if (dhw_offset == 2) {\n      dhw_valid_func_ = &ColBufWriter<T>::CDHWWrite;\n    } else {\n      dhw_valid_func_ = &ColBufWriter<T>::DHWCWrite;\n    }\n  }\n  void operator()(ColBufWriter<T>* col_buf_writer, int64_t c, int64_t kd, int64_t kh, int64_t kw) 
{\n    int64_t id = kd * dilation_rate_[0] - padding_before_[0];\n    FOR_RANGE(int64_t, od, 0, od_num_) {\n      if (id < 0 || id >= id_num_) {\n        col_buf_writer->InvalidDFunc();\n      } else {\n        int64_t ih = kh * dilation_rate_[1] - padding_before_[1];\n        FOR_RANGE(int64_t, oh, 0, oh_num_) {\n          if (ih < 0 || ih >= ih_num_) {\n            col_buf_writer->InvalidHFunc();\n          } else {\n            int64_t iw = kw * dilation_rate_[2] - padding_before_[2];\n            FOR_RANGE(int64_t, ow, 0, ow_num_) {\n              if (iw < 0 || iw >= iw_num_) {\n                col_buf_writer->InvalidWFunc();\n              } else {\n                (col_buf_writer->*dhw_valid_func_)(c, id, ih, iw);\n              }\n              iw += strides_[2];\n            }\n          }\n          ih += strides_[1];\n        }\n      }\n      id += strides_[0];\n    }\n  }\n\n private:\n  int64_t id_num_;\n  int64_t ih_num_;\n  int64_t iw_num_;\n  int64_t od_num_;\n  int64_t oh_num_;\n  int64_t ow_num_;\n  const int32_t* strides_;\n  const int32_t* dilation_rate_;\n  const int32_t* padding_before_;\n  DHWValidFunc<T> dhw_valid_func_;\n};\n\ntemplate<typename T>\nstruct DeconvKernelUtil final {\n public:\n  static void NCDHWCol2Im(const T* col_buf_ptr, const ShapeView& in_shape,\n                          const ShapeView& weight_shape, const ShapeView& out_shape,\n                          const int32_t* strides, const int32_t* dilation_rate,\n                          const int32_t* padding_before, T* in_diff_ptr) {\n    ColBufUtil<T> col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before);\n    Col2ImWriter<T> col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(3),\n                                   in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1);\n    DoNCDWHFunc(weight_shape, col_buf_util, &col_buf_writer);\n  }\n\n  static void NDHWCCol2Im(const T* col_buf_ptr, const ShapeView& in_shape,\n                          const ShapeView& weight_shape, const ShapeView& out_shape,\n                          const int32_t* strides, const int32_t* dilation_rate,\n                          const int32_t* padding_before, T* in_diff_ptr) {\n    ColBufUtil<T> col_buf_util(in_shape, out_shape, 1, strides, dilation_rate, padding_before);\n    Col2ImWriter<T> col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(2),\n                                   in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4),\n                                   out_shape.Count(3, 4), 1);\n    DoNDWHCFunc(weight_shape, col_buf_util, &col_buf_writer);\n  }\n\n private:\n  static void DoNCDWHFunc(const ShapeView& weight_shape, ColBufUtil<T>& col_buf_util,\n                          ColBufWriter<T>* col_buf_writer) {\n    for (int64_t c = 0; c != weight_shape.At(1); col_buf_writer->NextImCSize(), ++c) {\n      for (int64_t kd = 0; kd != weight_shape.At(2); ++kd) {\n        for (int64_t kh = 0; kh != weight_shape.At(3); ++kh) {\n          for (int64_t kw = 0; kw != weight_shape.At(4); ++kw) {\n            col_buf_util(col_buf_writer, c, kd, kh, kw);\n          }\n        }\n      }\n    }\n  }\n\n  static void DoNDWHCFunc(const ShapeView& weight_shape, ColBufUtil<T>& col_buf_util,\n                          ColBufWriter<T>* col_buf_writer) {\n    for (int64_t kd = 0; kd != weight_shape.At(1); ++kd) {\n      for (int64_t kh = 0; kh != weight_shape.At(2); ++kh) {\n        for (int64_t kw = 0; kw != weight_shape.At(3); ++kw) {\n      
    for (int64_t c = 0; c != weight_shape.At(4); ++c) {\n            col_buf_util(col_buf_writer, c, kd, kh, kw);\n          }\n        }\n      }\n    }\n  }\n};\n\ntemplate<typename T>\nstruct DeconvOpKernelCache final : public user_op::OpKernelCache {\n  Col2ImFunc<T> col2im_func_ = nullptr;\n\n  Shape in_5d_shape_;\n  Shape out_5d_shape_;\n  Shape weight_5d_shape_;\n\n  std::vector<int32_t> strides_3d_;\n  std::vector<int32_t> dilation_rate_3d_;\n  std::vector<int32_t> padding_before_3d_;\n\n  bool is_out_diff_need_trans_ = false;\n\n  int32_t idx_offset_ = 0;\n  bool is_dynamic_ = false;\n\n  void Update(const ShapeView& x_shape, const ShapeView& out_shape) {\n    auto Gen5DShape = [](const ShapeView& shape, int32_t idx_offset) -> Shape {\n      DimVector ret_vec;\n      shape.ToDimVector(&ret_vec);\n      int32_t ndims = ret_vec.size() - 2;\n      ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1);\n      return Shape(ret_vec);\n    };\n    if (is_dynamic_) {\n      Shape in_shape;\n      in_5d_shape_ = Gen5DShape(x_shape, idx_offset_);\n      out_5d_shape_ = Gen5DShape(out_shape, idx_offset_);\n    }\n  }\n};\n\ntemplate<typename T>\nstd::shared_ptr<DeconvOpKernelCache<T>> CreateDeconvOpKernelCache(user_op::KernelCacheContext* ctx,\n                                                                  const std::string& in_name,\n                                                                  const std::string& out_name,\n                                                                  const std::string& weight_name) {\n  const auto& data_format = ctx->Attr<std::string>(\"data_format\");\n\n  std::shared_ptr<DeconvOpKernelCache<T>> cache(new DeconvOpKernelCache<T>());\n  if (data_format == \"channels_first\") {\n    cache->col2im_func_ = DeconvKernelUtil<T>::NCDHWCol2Im;\n    cache->is_out_diff_need_trans_ = false;\n    cache->idx_offset_ = 2;\n  } else {\n    cache->col2im_func_ = DeconvKernelUtil<T>::NDHWCCol2Im;\n    cache->is_out_diff_need_trans_ = true;\n    cache->idx_offset_ = 1;\n  }\n\n  auto Gen5DShape = [](const Shape& shape, int32_t idx_offset) -> Shape {\n    DimVector ret_vec(shape.dim_vec());\n    int32_t ndims = ret_vec.size() - 2;\n    ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1);\n    return Shape(ret_vec);\n  };\n  cache->in_5d_shape_ =\n      Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->shape(), cache->idx_offset_);\n  cache->out_5d_shape_ =\n      Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(out_name, 0)->shape(), cache->idx_offset_);\n  cache->weight_5d_shape_ =\n      Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(weight_name, 0)->shape(), cache->idx_offset_);\n\n  auto Gen3DVec = [](const std::vector<int32_t>& origin_vec) -> std::vector<int32_t> {\n    std::vector<int32_t> ret_vec = origin_vec;\n    ret_vec.insert(ret_vec.begin(), 3 - ret_vec.size(), 1);\n    return ret_vec;\n  };\n  cache->strides_3d_ = Gen3DVec(ctx->Attr<std::vector<int32_t>>(\"strides\"));\n  cache->dilation_rate_3d_ = Gen3DVec(ctx->Attr<std::vector<int32_t>>(\"dilation_rate\"));\n  cache->is_dynamic_ = ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->is_dynamic();\n  const auto& padding_before = ctx->Attr<std::vector<int32_t>>(\"padding_before\");\n  FOR_RANGE(uint8_t, dim, 0, 3) {\n    int64_t index = static_cast<int64_t>(dim) - (3 - padding_before.size());\n    if (index < 0) {\n      cache->padding_before_3d_.push_back(0);\n    } else {\n      cache->padding_before_3d_.push_back(padding_before.at(index));\n    }\n  }\n\n  return cache;\n}\n\ntemplate<typename 
T>\nclass DeconvCpuKernel final : public user_op::OpKernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DeconvCpuKernel);\n  DeconvCpuKernel() = default;\n  ~DeconvCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  void InitOpKernelCacheWithFlags(\n      user_op::KernelCacheContext* ctx, int8_t flag,\n      std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {\n    if (*cache_ptr != nullptr && (flag & user_op::OpKernelCache::kAttrNotChanged)) {\n      auto deconv_cache = std::dynamic_pointer_cast<DeconvOpKernelCache<T>>(*cache_ptr);\n      deconv_cache->Update(ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->shape(),\n                           ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->shape());\n      return;\n    }\n    *cache_ptr = CreateDeconvOpKernelCache<T>(ctx, \"out\", \"in\", \"weight\");\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    auto deconv_cache = dynamic_cast<const DeconvOpKernelCache<T>*>(cache);\n    CHECK_NOTNULL(deconv_cache);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    Memset<DeviceType::kCPU>(ctx->stream(), out->mut_dptr<T>(), 0,\n                             out->shape_view().elem_cnt() * sizeof(T));\n\n    std::unique_ptr<ep::primitive::Matmul> matmul;\n    if (deconv_cache->is_out_diff_need_trans_) {\n      matmul = NewDeconvTransATransBMatmulPrimitive(ctx);\n    } else {\n      matmul = NewDeconvTransANoTransBMatmulPrimitive(ctx);\n    }\n    CHECK(matmul);\n\n    FOR_RANGE(int64_t, i, 0, in->shape_view().At(0)) {\n      // channels first:  col_buf' = weight(T) * in[i]'\n      // channels last :  col_buf' = weight(T) * in[i]'(T)\n      // m, n, k\n      int32_t idx_offset = deconv_cache->idx_offset_;\n\n      matmul->Launch(ctx->stream(), deconv_cache->weight_5d_shape_.Count(1),\n                     deconv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3),\n                     deconv_cache->weight_5d_shape_.At(0), static_cast<T>(1), weight->dptr<T>(),\n                     GetImgDptr<T>(in, i), static_cast<T>(0), col_buf->mut_dptr<T>());\n\n      // out = col2im(col_buf')\n      deconv_cache->col2im_func_(\n          col_buf->dptr<T>(), ShapeView(deconv_cache->in_5d_shape_),\n          ShapeView(deconv_cache->weight_5d_shape_), ShapeView(deconv_cache->out_5d_shape_),\n          deconv_cache->strides_3d_.data(), deconv_cache->dilation_rate_3d_.data(),\n          deconv_cache->padding_before_3d_.data(), GetImgMutDptr<T>(out, i));\n    }\n  }\n};\n\n#define REGISTER_DECONV_DATA_KERNEL(op_name, dtype)                                     \\\n  REGISTER_USER_KERNEL(#op_name)                                                        \\\n      .SetCreateFn<DeconvCpuKernel<dtype>>()                                            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobAttr<int32_t>(\"groups\") == 1)                    \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value) \\\n                       && DeconvTransATransBMatmulPrimitiveExists()                     \\\n                       && 
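/* Compute() picks one of the two matmul flavors at runtime based on data_format */ 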
DeconvTransANoTransBMatmulPrimitiveExists())                  \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                     \\\n        size_t tmp_buffer_size = 0;                                                     \\\n        const auto& in_shape = ctx->InputTensorDesc(\"in\", 0).shape();                   \\\n        const auto& weight_shape = ctx->InputTensorDesc(\"weight\", 0).shape();           \\\n                                                                                        \\\n        int64_t idx_offset = IdxOffset(ctx->Attr<std::string>(\"data_format\"));          \\\n        tmp_buffer_size +=                                                              \\\n            CalcElemNumOfColBuf(in_shape, weight_shape, idx_offset) * sizeof(dtype);    \\\n        return tmp_buffer_size;                                                         \\\n      })\n\nREGISTER_DECONV_DATA_KERNEL(deconv1d, float);\nREGISTER_DECONV_DATA_KERNEL(deconv1d, double);\nREGISTER_DECONV_DATA_KERNEL(deconv2d, float);\nREGISTER_DECONV_DATA_KERNEL(deconv2d, double);\nREGISTER_DECONV_DATA_KERNEL(deconv3d, float);\nREGISTER_DECONV_DATA_KERNEL(deconv3d, double);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/deconv_cudnn_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/device/cudnn_conv_util.h\"\n#include \"oneflow/core/job/resource_desc.h\"\n#include \"oneflow/core/job/global_for.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\nnamespace {\n\ntemplate<typename PerfT>\nstruct CudnnDeConvArgsAndAlgo final {\n  using AlgoT = decltype(std::declval<PerfT>().algo);\n\n  CudnnConvArgs args;\n  PerfT algo_perf;\n\n  // CudnnDeConvArgsAndAlgo\n  CudnnDeConvArgsAndAlgo(const user_op::Tensor* x, const user_op::Tensor* w,\n                         const user_op::Tensor* y, user_op::Tensor* buf,\n                         const user_op::KernelComputeContext* ctx, ep::Stream* stream,\n                         bool has_forced_algo, int32_t forced_algo)\n      : args(*ctx, x->data_type(), x->shape_view(), w->data_type(), w->shape_view(), y->data_type(),\n             y->shape_view(), ctx->Attr<std::string>(\"data_format\"), buf->shape_view().elem_cnt(),\n             Singleton<ResourceDesc, ForSession>::Get()\n                 ->resource()\n                 .cudnn_conf()\n                 .cudnn_conv_heuristic_search_algo(),\n             Singleton<ResourceDesc, ForSession>::Get()\n                 ->resource()\n                 .cudnn_conf()\n                 .cudnn_conv_use_deterministic_algo_only(),\n             Singleton<ResourceDesc, ForSession>::Get()\n                 ->resource()\n                 .cudnn_conf()\n                 .cudnn_conv_enable_pseudo_half()) {\n    size_t byte_size_of_buf = buf->shape_view().elem_cnt();\n    AllocatedCudnnConvResource res(stream->As<ep::CudaStream>()->cudnn_handle(),\n                                   const_cast<void*>(x->dptr()), const_cast<void*>(w->dptr()),\n                                   const_cast<void*>(y->dptr()), buf->mut_dptr());\n    if (has_forced_algo) {\n      algo_perf = GetCudnnConvAlgorithmPerferenceWithResource<PerfT>(\n          &args, &res, static_cast<AlgoT>(forced_algo));\n    } else {\n      algo_perf = FindCudnnConvAlgorithmWithResource<PerfT>(&args, &res);\n    }\n    CHECK_EQ(algo_perf.status, CUDNN_STATUS_SUCCESS)\n        << \"op (\" << ctx->op_name()\n        << \") find algorithm perference failed. 
algo: \" << algo_perf.algo;\n    CHECK_LE(algo_perf.memory, byte_size_of_buf)\n        << \"op (\" << ctx->op_name() << \") find algorithm \" << algo_perf.algo << \", need memory \"\n        << algo_perf.memory << \", but cudnn_buf_limit_byte is \" << byte_size_of_buf;\n  }\n  CudnnDeConvArgsAndAlgo() = delete;\n  OF_DISALLOW_COPY_AND_MOVE(CudnnDeConvArgsAndAlgo);\n};\n\ntemplate<typename PerfT>\nsize_t InferTmpSizeWithCudnn(const user_op::TensorDesc* x, const user_op::TensorDesc* w,\n                             const user_op::TensorDesc* y, const user_op::InferContext& ctx,\n                             bool has_forced_algo, int32_t forced_algo) {\n  using AlgoT = decltype(std::declval<PerfT>().algo);\n  const auto& cudnn_conf = Singleton<ResourceDesc, ForSession>::Get()->resource().cudnn_conf();\n  size_t workspace_size = cudnn_conf.cudnn_buf_limit_mbyte() * 1024 * 1024;\n  if (!x->is_dynamic()) {\n    CudnnConvArgs args(ctx, x->data_type(), ShapeView(x->shape()), w->data_type(),\n                       ShapeView(w->shape()), y->data_type(), ShapeView(y->shape()),\n                       ctx.Attr<std::string>(\"data_format\"), workspace_size,\n                       cudnn_conf.cudnn_conv_heuristic_search_algo(),\n                       cudnn_conf.cudnn_conv_use_deterministic_algo_only(),\n                       cudnn_conf.cudnn_conv_enable_pseudo_half());\n    PerfT algo_perf;\n    if (has_forced_algo) {\n      algo_perf = GetCudnnConvAlgorithmPerference<PerfT>(&args, static_cast<AlgoT>(forced_algo));\n    } else {\n      algo_perf = FindCudnnConvAlgorithm<PerfT>(&args);\n    }\n    CHECK_EQ(algo_perf.status, CUDNN_STATUS_SUCCESS)\n        << \"op (\" << ctx.op_name()\n        << \") find algorithm perference failed. algo: \" << algo_perf.algo;\n    CHECK_LE(algo_perf.memory, workspace_size)\n        << \"op (\" << ctx.op_name() << \") find algorithm \" << algo_perf.algo << \", need memory \"\n        << algo_perf.memory << \", but cudnn_buf_limit_byte is \" << workspace_size;\n    workspace_size = algo_perf.memory;\n  }\n  workspace_size = std::max(size_t(1), workspace_size);\n  return workspace_size;\n}\n\n}  // namespace\n\ntemplate<typename T, size_t NDims>\nclass DeConvGpuKernel final : public user_op::OpKernel {\n public:\n  DeConvGpuKernel() = default;\n  ~DeConvGpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    user_op::Tensor* buf = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    if (in->shape_view().elem_cnt() == 0) return;\n    const auto& cudnn_conf = Singleton<ResourceDesc, ForSession>::Get()->resource().cudnn_conf();\n\n    CudnnDeConvArgsAndAlgo<cudnnConvolutionBwdDataAlgoPerf_t> args_and_algo(\n        out, weight, in, buf, ctx, ctx->stream(), cudnn_conf.has_cudnn_conv_force_bwd_data_algo(),\n        cudnn_conf.cudnn_conv_force_bwd_data_algo());\n    const CudnnConvArgs& args = args_and_algo.args;\n    const cudnnConvolutionBwdDataAlgoPerf_t& algo_perf = args_and_algo.algo_perf;\n\n    OF_CUDNN_CHECK(cudnnConvolutionBackwardData(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), CudnnSPOnePtr<T>(), args.wdesc.Get(),\n        weight->dptr(), args.ydesc.Get(), in->dptr(), args.cdesc.Get(), 
algo_perf.algo,\n        buf->mut_dptr(), args.params.max_ws_size, CudnnSPZeroPtr<T>(), args.xdesc.Get(),\n        out->mut_dptr()));\n  }\n};\n\n#define REGISTER_DECONV_KERNEL(op_name, dtype, ndims)                                   \\\n  REGISTER_USER_KERNEL(#op_name)                                                        \\\n      .SetCreateFn<DeConvGpuKernel<dtype, ndims>>()                                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                     \\\n        const auto& in = ctx->InputTensorDesc(\"in\", 0);                                 \\\n        if (in.shape().elem_cnt() == 0) return 0;                                       \\\n        const auto& weight = ctx->InputTensorDesc(\"weight\", 0);                         \\\n        const auto& out = ctx->OutputTensorDesc(\"out\", 0);                              \\\n        const auto& cudnn_conf =                                                        \\\n            Singleton<ResourceDesc, ForSession>::Get()->resource().cudnn_conf();        \\\n        return InferTmpSizeWithCudnn<cudnnConvolutionBwdDataAlgoPerf_t>(                \\\n            &out, &weight, &in, *ctx, cudnn_conf.has_cudnn_conv_force_bwd_data_algo(),  \\\n            cudnn_conf.cudnn_conv_force_bwd_data_algo());                               \\\n      })\n\nREGISTER_DECONV_KERNEL(deconv1d, float, 1);\nREGISTER_DECONV_KERNEL(deconv2d, float, 2);\nREGISTER_DECONV_KERNEL(deconv3d, float, 3);\nREGISTER_DECONV_KERNEL(deconv1d, double, 1);\nREGISTER_DECONV_KERNEL(deconv2d, double, 2);\nREGISTER_DECONV_KERNEL(deconv3d, double, 3);\n\n}  // namespace oneflow\n\n#endif\n"
  },
  {
    "path": "oneflow/user/kernels/deform_conv_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/user_op_hob.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nT get_coordinate_weight(const T* im_data, int height, int width, T y, T x, bool is_y_direction) {\n  int y_l = floor(y);\n  int x_l = floor(x);\n  int y_h = y_l + 1;\n  int x_h = x_l + 1;\n\n  bool valid_y_l = 0 <= y_l && y_l < height;\n  bool valid_y_h = 0 <= y_h && y_h < height;\n  bool valid_x_l = 0 <= x_l && x_l < width;\n  bool valid_x_h = 0 <= x_h && x_h < width;\n\n  T zero = 0;\n  T v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero;\n  T v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero;\n  T v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero;\n  T v_YX = (valid_y_h && valid_x_h) ? im_data[y_h * width + x_h] : zero;\n\n  if (is_y_direction) {\n    T dx = x - x_l;\n    return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx);\n  } else {\n    T dy = y - y_l;\n    return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx);\n  }\n}\n\ntemplate<typename T>\nT DeformableIm2ColBilinear(const T* bottom_data, const int data_width, const int height,\n                           const int width, T h, T w) {\n  int h_low = floor(h);\n  int w_low = floor(w);\n  int h_high = h_low + 1;\n  int w_high = w_low + 1;\n\n  T lh = h - h_low;\n  T lw = w - w_low;\n  T hh = 1 - lh, hw = 1 - lw;\n\n  T v1 = 0;\n  if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low];\n  T v2 = 0;\n  if (h_low >= 0 && w_high <= width - 1) v2 = bottom_data[h_low * data_width + w_high];\n  T v3 = 0;\n  if (h_high <= height - 1 && w_low >= 0) v3 = bottom_data[h_high * data_width + w_low];\n  T v4 = 0;\n  if (h_high <= height - 1 && w_high <= width - 1) v4 = bottom_data[h_high * data_width + w_high];\n\n  T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n\n  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n  return val;\n}\n\ntemplate<typename T>\nT GetGradientWeight(T argmax_h, T argmax_w, const int h, const int w, const int height,\n                    const int width) {\n  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) {\n    // empty\n    return 0;\n  }\n\n  int argmax_h_low = floor(argmax_h);\n  int argmax_w_low = floor(argmax_w);\n  int argmax_h_high = argmax_h_low + 1;\n  int argmax_w_high = argmax_w_low + 1;\n\n  T weight = 0;\n  if (h == argmax_h_low && w == argmax_w_low) weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);\n  if (h == argmax_h_low && w == argmax_w_high) weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);\n  if (h == argmax_h_high && w == argmax_w_low) weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);\n  if (h 
== argmax_h_high && w == argmax_w_high) weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);\n  return weight;\n}\n\ntemplate<typename T>\nT GetCoordinateWeight(T argmax_h, T argmax_w, const int height, const int width, const T* im_data,\n                      const int data_width, const int bp_dir) {\n  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) {\n    // empty\n    return static_cast<T>(0);\n  }\n\n  int argmax_h_low = floor(argmax_h);\n  int argmax_w_low = floor(argmax_w);\n  int argmax_h_high = argmax_h_low + 1;\n  int argmax_w_high = argmax_w_low + 1;\n\n  T weight = 0;\n\n  if (bp_dir == 0) {\n    if (argmax_h_low >= 0 && argmax_w_low >= 0)\n      weight +=\n          -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];\n    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)\n      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];\n    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)\n      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];\n    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)\n      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];\n  } else if (bp_dir == 1) {\n    if (argmax_h_low >= 0 && argmax_w_low >= 0)\n      weight +=\n          -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];\n    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)\n      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];\n    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)\n      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];\n    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)\n      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];\n  }\n\n  return weight;\n}\n\ntemplate<typename T>\nT bilinear_interpolate(const T* in, int height, int width, T h, T w) {\n  if (h <= -1 || height <= h || w <= -1 || width <= w) { return 0; }\n\n  int h_low = floor(h);\n  int w_low = floor(w);\n  int h_high = h_low + 1;\n  int w_high = w_low + 1;\n\n  T lh = h - h_low;\n  T lw = w - w_low;\n  T hh = 1 - lh, hw = 1 - lw;\n\n  T v1 = 0;\n  if (h_low >= 0 && w_low >= 0) v1 = in[h_low * width + w_low];\n  T v2 = 0;\n  if (h_low >= 0 && w_high <= width - 1) v2 = in[h_low * width + w_high];\n  T v3 = 0;\n  if (h_high <= height - 1 && w_low >= 0) v3 = in[h_high * width + w_low];\n  T v4 = 0;\n  if (h_high <= height - 1 && w_high <= width - 1) v4 = in[h_high * width + w_high];\n\n  T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n\n  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n  return val;\n}\n\ntemplate<typename T>\nvoid DeformableIm2Col(int n, const T* input, const T* offset, const T* mask, int height, int width,\n                      int weight_h, int weight_w, int pad_h, int pad_w, int stride_h, int stride_w,\n                      int dilation_h, int dilation_w, int batch_sz, int n_in_channels,\n                      int n_offset_grps, int out_h, int out_w, bool use_mask, T* columns) {\n  for (int index = 0; index != n; ++index) {\n    const int out_x = index % out_w;\n    const int out_y = (index / out_w) % out_h;\n    const int out_b = (index / (out_w * out_h)) % batch_sz;\n    const int in_c = index / (out_w * out_h * batch_sz);\n    const int out_c = in_c * weight_h * weight_w;\n\n    
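// Offset groups partition the input channels; each group carries its own\n    // offset (and optional mask) field, selected via grp_idx below.\n    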
int c_per_offset_grp = n_in_channels / n_offset_grps;\n    const int grp_idx = in_c / c_per_offset_grp;\n\n    auto columns_ptr =\n        columns\n        + (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + out_y * out_w + out_x);\n    auto input_ptr = input + (out_b * (n_in_channels * height * width) + in_c * (height * width));\n\n    auto offset_ptr =\n        offset + (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w;\n\n    auto mask_ptr = mask;\n    if (use_mask) {\n      mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * out_h * out_w;\n    }\n\n    for (int i = 0; i < weight_h; ++i) {\n      for (int j = 0; j < weight_w; ++j) {\n        const int mask_idx = i * weight_w + j;\n        const int offset_idx = 2 * mask_idx;\n\n        T mask_value = 1;\n        if (use_mask) { mask_value = mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; }\n\n        const T offset_h = offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x];\n        const T offset_w = offset_ptr[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x];\n        const T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;\n        const T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;\n        *columns_ptr = mask_value * bilinear_interpolate(input_ptr, height, width, y, x);\n        columns_ptr += batch_sz * out_h * out_w;\n      }\n    }\n  }\n}\n\ntemplate<typename T>\nvoid DeformableCol2Im(int n, const T* col, const T* offset_data, const T* mask_data, int channels,\n                      int height, int width, int kernel_h, int kernel_w, int pad_h, int pad_w,\n                      int stride_h, int stride_w, int dilation_h, int dilation_w, int batch_sz,\n                      int n_offset_grps, int out_h, int out_w, bool use_mask, T* grad_im) {\n  for (int index = 0; index != n; ++index) {\n    const int out_x = index % out_w;\n    const int out_y = (index / out_w) % out_h;\n    const int b = (index / (out_w * out_h)) % batch_sz;\n    const int j = (index / (out_w * out_h * batch_sz)) % kernel_w;\n    const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h;\n    const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h);\n\n    int c_per_offset_grp = channels / n_offset_grps;\n    const int offset_grp = c / c_per_offset_grp;\n    auto offset_ptr = offset_data;\n    offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h * out_w;\n    auto mask_ptr = mask_data;\n    if (use_mask) {\n      mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * out_h * out_w;\n    }\n\n    const int mask_idx = i * kernel_w + j;\n    const int offset_idx = 2 * mask_idx;\n\n    const int offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x;\n    const int offset_w_ptr = ((offset_idx + 1) * out_h + out_y) * out_w + out_x;\n\n    const T offset_h = offset_ptr[offset_h_ptr];\n    const T offset_w = offset_ptr[offset_w_ptr];\n\n    T mask_value = 1;\n    if (use_mask) { mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; }\n\n    const T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;\n    const T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;\n\n    for (int dy = -1; dy <= 1; dy++) {\n      for (int dx = -1; dx <= 1; dx++) {\n        int yp = (int)y + dy;\n        int xp = (int)x + dx;\n        if (0 <= yp && yp < height && 0 <= xp && xp < width && std::abs(y - yp) < 1\n            && std::abs(x - xp) < 1) {\n          int 
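/* linear index into NCHW grad_im */ 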
grad_pos = ((b * channels + c) * height + yp) * width + xp;\n          T weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));\n          grad_im[grad_pos] += mask_value * weight * col[index];\n        }\n      }\n    }\n  }\n}\n\ntemplate<typename T>\nvoid DeformableCol2ImCoord(int n, const T* col_data, const T* im_data, const T* offset_data,\n                           const T* mask_data, int channels, int height, int width, int weight_h,\n                           int weight_w, int pad_h, int pad_w, int stride_h, int stride_w,\n                           int dilation_h, int dilation_w, int batch_sz, int offset_channels,\n                           int n_offset_grps, int out_h, int out_w, const bool use_mask,\n                           T* grad_offset, T* grad_mask) {\n  for (int index = 0; index != n; ++index) {\n    T grad_offset_val = 0;\n    T grad_mask_val = 0;\n    int w = index % out_w;\n    int h = (index / out_w) % out_h;\n    int w_w = (index / (out_w * out_h * 2)) % weight_w;\n    int w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h;\n    int c = (index / (out_w * out_h)) % offset_channels;\n    int b = index / (out_w * out_h * offset_channels);\n\n    const int offset_grp = c / (2 * weight_h * weight_w);\n    const int col_step = weight_h * weight_w;\n\n    int c_per_offset_grp = channels / n_offset_grps;\n    auto col_ptr = col_data;\n    col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * out_w * out_h;\n    auto im_ptr = im_data;\n    im_ptr += (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width;\n    auto offset_ptr = offset_data;\n    offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h * out_w;\n\n    auto mask_ptr = mask_data;\n    if (use_mask) {\n      mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w * out_h * out_w;\n    }\n\n    const int offset_c = c - offset_grp * 2 * weight_h * weight_w;\n    const bool is_y_direction = offset_c % 2 == 0;\n\n    const int c_bound = c_per_offset_grp * weight_h * weight_w;\n    for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {\n      const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w;\n\n      int out_x = col_pos % out_w;\n      int out_y = (col_pos / out_w) % out_h;\n      int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w;\n      int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h;\n\n      const int mask_idx = i * weight_w + j;\n\n      const int offset_h_ptr = (((2 * mask_idx) * out_h + out_y) * out_w + out_x);\n      const int offset_w_ptr = (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x);\n      const T offset_h = offset_ptr[offset_h_ptr];\n      const T offset_w = offset_ptr[offset_w_ptr];\n\n      T mask_value = 1;\n      if (use_mask) { mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; }\n\n      T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;\n      T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;\n\n      const T weight = get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction);\n      grad_offset_val += mask_value * weight * col_ptr[col_pos];\n\n      if (use_mask && is_y_direction) {\n        grad_mask_val += col_ptr[col_pos] * bilinear_interpolate(im_ptr, height, width, y, x);\n      }\n\n      im_ptr += height * width;\n    }\n\n    grad_offset[index] = grad_offset_val;\n\n    if (use_mask && is_y_direction) {\n      const int idx =\n          ((((b * n_offset_grps + offset_grp) * weight_h + 
w_h) * weight_w + w_w) * out_h + h)\n              * out_w\n          + w;\n      grad_mask[idx] = grad_mask_val;\n    }\n  }\n}\n\nep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) {\n  return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N;\n}\n\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitive(DeviceType device_type,\n                                                          DataType data_type, bool transpose_a,\n                                                          bool transpose_b) {\n  const auto trans_a = GetBlasTransposeType(transpose_a);\n  const auto trans_b = GetBlasTransposeType(transpose_b);\n  return ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(device_type, data_type, trans_a,\n                                                                   trans_b);\n}\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Permute> NewPermutePrimitive(Context* ctx, const int& num_dims) {\n  return ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(ctx->device_type(), num_dims);\n}\n\ntemplate<typename T>\nclass DeformableConv2dCpuKernel final : public user_op::OpKernel {\n public:\n  DeformableConv2dCpuKernel() = default;\n  ~DeformableConv2dCpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const user_op::Tensor* offset = ctx->Tensor4ArgNameAndIndex(\"offset\", 0);\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const ShapeView& input_shape = input->shape_view();\n    const ShapeView& output_shape = output->shape_view();\n    const ShapeView& weight_shape = weight->shape_view();\n    const int64_t out_elem_cnt = output_shape.elem_cnt();\n    const int64_t output_bytes = (out_elem_cnt * sizeof(T));\n\n    T* column_tmp_buffer = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + output_bytes);\n    const int32_t kW = weight->shape_view().At(2);\n    const int32_t kH = weight->shape_view().At(3);\n    const int32_t dW = ctx->Attr<int32_t>(\"stride_w\");\n    const int32_t dH = ctx->Attr<int32_t>(\"stride_h\");\n    const int32_t padW = ctx->Attr<int32_t>(\"pad_w\");\n    const int32_t padH = ctx->Attr<int32_t>(\"pad_h\");\n    const int32_t dilationW = ctx->Attr<int32_t>(\"dilation_w\");\n    const int32_t dilationH = ctx->Attr<int32_t>(\"dilation_h\");\n    const int32_t group = ctx->Attr<int32_t>(\"groups\");\n    const int32_t deformable_group = ctx->Attr<int32_t>(\"offset_groups\");\n    const bool use_mask = ctx->Attr<bool>(\"use_mask\");\n\n    const int64_t outputWidth =\n        ((input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW) + 1;\n    const int64_t outputHeight =\n        ((input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH) + 1;\n    const int64_t column_nums = input_shape.At(1) * input_shape.At(0) * outputHeight * outputWidth;\n\n    if (column_nums > 0) {\n      DeformableIm2Col<T>(column_nums, input->dptr<T>(), offset->dptr<T>(), mask->dptr<T>(),\n                          input_shape.At(2), input_shape.At(3), kH, kW, padH, padW, dH, dW,\n                          dilationH, dilationW, input_shape.At(0), 
input_shape.At(1),\n                          deformable_group, output_shape.At(2), output_shape.At(3), use_mask,\n                          column_tmp_buffer);\n\n      const int64_t weight_group_offset = weight->shape_view().elem_cnt() / group;\n      const int64_t column_group_offset =\n          input_shape.At(1) * kW * kH * input_shape.At(0) * outputHeight * outputWidth / group;\n      const int64_t output_group_offset = out_elem_cnt / group;\n\n      auto matmul = NewMatmulPrimitive(ctx->device_type(), output->data_type(), false, false);\n      CHECK(matmul);\n      FOR_RANGE(int, g, 0, group) {\n        matmul->Launch(ctx->stream(), weight_shape.At(0) / group,\n                       input_shape.At(0) * outputHeight * outputWidth,\n                       input_shape.At(1) * kW * kH / group, static_cast<T>(1),\n                       weight->dptr<T>() + g * weight_group_offset,\n                       column_tmp_buffer + g * column_group_offset, static_cast<T>(0),\n                       tmp_buffer->mut_dptr<T>() + g * output_group_offset);\n      }\n\n      std::vector<int64_t> out_shapevec(\n          {output_shape.At(1), output_shape.At(0), output_shape.At(2), output_shape.At(3)});\n      auto transpose = NewPermutePrimitive(ctx, output_shape.NumAxes());\n      CHECK(transpose);\n      transpose->Launch(ctx->stream(), output->data_type(), output_shape.NumAxes(),\n                        out_shapevec.data(), tmp_buffer->dptr<T>(),\n                        std::vector<int>({1, 0, 2, 3}).data(), output->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass DeformableConv2dInputGradCpuKernel final : public user_op::OpKernel {\n public:\n  DeformableConv2dInputGradCpuKernel() = default;\n  ~DeformableConv2dInputGradCpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* output_grad = ctx->Tensor4ArgNameAndIndex(\"output_grad\", 0);\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const user_op::Tensor* offset = ctx->Tensor4ArgNameAndIndex(\"offset\", 0);\n    user_op::Tensor* input_grad = ctx->Tensor4ArgNameAndIndex(\"input_grad\", 0);\n    user_op::Tensor* offset_grad = ctx->Tensor4ArgNameAndIndex(\"offset_grad\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const ShapeView& output_grad_shape = output_grad->shape_view();\n    const ShapeView& input_shape = input->shape_view();\n    const ShapeView& weight_shape = weight->shape_view();\n    const int32_t kW = weight->shape_view().At(2);\n    const int32_t kH = weight->shape_view().At(3);\n    const int32_t dW = ctx->Attr<int32_t>(\"stride_w\");\n    const int32_t dH = ctx->Attr<int32_t>(\"stride_h\");\n    const int32_t padW = ctx->Attr<int32_t>(\"pad_w\");\n    const int32_t padH = ctx->Attr<int32_t>(\"pad_h\");\n    const int32_t dilationW = ctx->Attr<int32_t>(\"dilation_w\");\n    const int32_t dilationH = ctx->Attr<int32_t>(\"dilation_h\");\n    const int32_t group = ctx->Attr<int32_t>(\"groups\");\n    const int32_t deformable_group = ctx->Attr<int32_t>(\"offset_groups\");\n    const bool use_mask = ctx->Attr<bool>(\"use_mask\");\n    const T* data_mask = nullptr;\n    T* data_mask_grad = nullptr;\n    if (use_mask) {\n      data_mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 
0)->dptr<T>();\n      data_mask_grad = ctx->Tensor4ArgNameAndIndex(\"mask_grad\", 0)->mut_dptr<T>();\n    }\n\n    const int64_t outputWidth =\n        (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;\n    const int64_t outputHeight =\n        (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;\n\n    std::unique_ptr<ep::primitive::Memset> primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->stream()->device_type());\n\n    primitive->Launch(ctx->stream(), input_grad->mut_dptr<T>(), 0,\n                      input_grad->shape_view().elem_cnt() * sizeof(T));\n    if (use_mask) {\n      primitive->Launch(\n          ctx->stream(), data_mask_grad, 0,\n          ctx->Tensor4ArgNameAndIndex(\"mask_grad\", 0)->shape_view().elem_cnt() * sizeof(T));\n    }\n    const int64_t nthreads_coord =\n        outputHeight * outputWidth * 2 * kH * kW * deformable_group * input_shape.At(0);\n    const int64_t nthreads_feat =\n        outputHeight * outputWidth * input_shape.At(0) * kH * kW * input_shape.At(1);\n    if (nthreads_coord > 0 && nthreads_feat > 0) {\n      const int64_t weight_group_offset = weight_shape.elem_cnt() / group;\n      const int64_t output_grad_group_offset = output_grad_shape.Count(1) / group;\n      const int64_t column_group_offset =\n          input_shape.At(1) * kW * kH * input_shape.At(0) * outputHeight * outputWidth / group;\n\n      auto matmul = NewMatmulPrimitive(ctx->device_type(), input_grad->data_type(), true, true);\n      CHECK(matmul);\n      FOR_RANGE(int, g, 0, group) {\n        matmul->Launch(ctx->stream(), weight_shape.Count(1),\n                       input_shape.At(0) * outputHeight * outputWidth, weight_shape.At(0) / group,\n                       static_cast<T>(1), weight->dptr<T>() + g * weight_group_offset,\n                       output_grad->dptr<T>() + g * output_grad_group_offset, static_cast<T>(0),\n                       tmp_buffer->mut_dptr<T>() + g * column_group_offset);\n      }\n      DeformableCol2ImCoord<T>(\n          nthreads_coord, tmp_buffer->dptr<T>(), input->dptr<T>(), offset->dptr<T>(), data_mask,\n          input_shape.At(1), input_shape.At(2), input_shape.At(3), kH, kW, padH, padW, dH, dW,\n          dilationH, dilationW, input_shape.At(0), 2 * kH * kW * deformable_group, deformable_group,\n          outputHeight, outputWidth, use_mask, offset_grad->mut_dptr<T>(), data_mask_grad);\n\n      DeformableCol2Im<T>(nthreads_feat, tmp_buffer->dptr<T>(), offset->dptr<T>(), data_mask,\n                          input_shape.At(1), input_shape.At(2), input_shape.At(3), kH, kW, padH,\n                          padW, dH, dW, dilationH, dilationW, input_shape.At(0), deformable_group,\n                          outputHeight, outputWidth, use_mask, input_grad->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass DeformableConv2dParamGradCpuKernel final : public user_op::OpKernel {\n public:\n  DeformableConv2dParamGradCpuKernel() = default;\n  ~DeformableConv2dParamGradCpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* output_grad = ctx->Tensor4ArgNameAndIndex(\"output_grad\", 0);\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const user_op::Tensor* offset = ctx->Tensor4ArgNameAndIndex(\"offset\", 0);\n    const user_op::Tensor* weight = 
ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    user_op::Tensor* weight_grad = ctx->Tensor4ArgNameAndIndex(\"weight_grad\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const ShapeView& output_grad_shape = output_grad->shape_view();\n    const ShapeView& weight_grad_shape = weight_grad->shape_view();\n    const ShapeView& input_shape = input->shape_view();\n    const int64_t out_elem_cnt = output_grad_shape.elem_cnt();\n    const int64_t output_bytes = (out_elem_cnt * sizeof(T));\n\n    T* column_tmp_buffer = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + output_bytes);\n    const int32_t kW = weight->shape_view().At(2);\n    const int32_t kH = weight->shape_view().At(3);\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    const int32_t dW = ctx->Attr<int32_t>(\"stride_w\");\n    const int32_t dH = ctx->Attr<int32_t>(\"stride_h\");\n    const int32_t padW = ctx->Attr<int32_t>(\"pad_w\");\n    const int32_t padH = ctx->Attr<int32_t>(\"pad_h\");\n    const int32_t dilationW = ctx->Attr<int32_t>(\"dilation_w\");\n    const int32_t dilationH = ctx->Attr<int32_t>(\"dilation_h\");\n    const int32_t group = ctx->Attr<int32_t>(\"groups\");\n    const int32_t deformable_group = ctx->Attr<int32_t>(\"offset_groups\");\n    const bool use_mask = ctx->Attr<bool>(\"use_mask\");\n    const int64_t outputWidth =\n        (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;\n    const int64_t outputHeight =\n        (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;\n\n    const int64_t column_nums = input_shape.At(1) * input_shape.At(0) * outputHeight * outputWidth;\n    if (column_nums > 0) {\n      DeformableIm2Col<T>(column_nums, input->dptr<T>(), offset->dptr<T>(), mask->dptr<T>(),\n                          input_shape.At(2), input_shape.At(3), kH, kW, padH, padW, dH, dW,\n                          dilationH, dilationW, input_shape.At(0), input_shape.At(1),\n                          deformable_group, output_grad_shape.At(2), output_grad_shape.At(3),\n                          use_mask, column_tmp_buffer);\n\n      std::unique_ptr<ep::primitive::Memset> primitive =\n          ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->stream()->device_type());\n      primitive->Launch(ctx->stream(), weight_grad->mut_dptr<T>(), 0,\n                        weight_grad->shape_view().elem_cnt());\n\n      std::vector<int64_t> output_grad_buffer_vec({output_grad_shape.At(1), output_grad_shape.At(0),\n                                                   output_grad_shape.At(2),\n                                                   output_grad_shape.At(3)});\n\n      auto transpose = NewPermutePrimitive(ctx, output_grad_shape.NumAxes());\n      CHECK(transpose);\n      transpose->Launch(ctx->stream(), output_grad->data_type(), output_grad_shape.NumAxes(),\n                        output_grad_buffer_vec.data(), output_grad->dptr<T>(),\n                        std::vector<int>({1, 0, 2, 3}).data(), tmp_buffer->mut_dptr<T>());\n\n      const int64_t output_grad_group_offset = output_grad_shape.elem_cnt() / group;\n      const int64_t column_group_offset =\n          input_shape.At(1) * kW * kW * input_shape.At(0) * outputHeight * outputWidth / group;\n      const int64_t weight_grad_group_offset = weight_grad->shape_view().elem_cnt() / group;\n      FOR_RANGE(int, g, 0, group) {\n        auto matmul = NewMatmulPrimitive(ctx->device_type(), weight_grad->data_type(), false, true);\n        
CHECK(matmul);\n\n        matmul->Launch(ctx->stream(), weight_grad_shape.At(0) / group,\n                       input_shape.At(1) * kW * kH / group,\n                       input_shape.At(0) * outputHeight * outputWidth, static_cast<T>(1),\n                       tmp_buffer->dptr<T>() + g * output_grad_group_offset,\n                       column_tmp_buffer + g * column_group_offset, static_cast<T>(0),\n                       weight_grad->mut_dptr<T>() + g * weight_grad_group_offset);\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_DEFORM_CONV2D_CPU_KERNEL(dtype)                                           \\\n  REGISTER_USER_KERNEL(\"deform_conv2d\")                                                    \\\n      .SetCreateFn<DeformableConv2dCpuKernel<dtype>>()                                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                      \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                  \\\n        const Shape& input_shape = ctx->InputShape(\"input\", 0);                            \\\n        const Shape& output_shape = ctx->OutputShape(\"output\", 0);                         \\\n        const Shape& weight_shape = ctx->InputShape(\"weight\", 0);                          \\\n        const int32_t kW = weight_shape.At(2);                                             \\\n        const int32_t kH = weight_shape.At(3);                                             \\\n        const int32_t dW = ctx->Attr<int32_t>(\"stride_w\");                                 \\\n        const int32_t dH = ctx->Attr<int32_t>(\"stride_h\");                                 \\\n        const int32_t padW = ctx->Attr<int32_t>(\"pad_w\");                                  \\\n        const int32_t padH = ctx->Attr<int32_t>(\"pad_h\");                                  \\\n        const int32_t dilationW = ctx->Attr<int32_t>(\"dilation_w\");                        \\\n        const int32_t dilationH = ctx->Attr<int32_t>(\"dilation_h\");                        \\\n        const int64_t outputWidth =                                                        \\\n            ((input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW) + 1;        \\\n        const int64_t outputHeight =                                                       \\\n            ((input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH) + 1;        \\\n        const int64_t column_bytes = (input_shape.At(1) * kW * kH * input_shape.At(0)      \\\n                                      * outputHeight * outputWidth * sizeof(dtype));       \\\n        const int64_t output_bytes = (output_shape.elem_cnt() * sizeof(dtype));            \\\n        return column_bytes + output_bytes;                                                \\\n      });\nREGISTER_DEFORM_CONV2D_CPU_KERNEL(float)\nREGISTER_DEFORM_CONV2D_CPU_KERNEL(double)\n\n#define REGISTER_DEFORM_CONV2D_INPUT_GRAD_CPU_KERNEL(dtype)                                 \\\n  REGISTER_USER_KERNEL(\"deform_conv2d_input_grad\")                                          \\\n      .SetCreateFn<DeformableConv2dInputGradCpuKernel<dtype>>()                             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                       \\\n                       && (user_op::HobDataType(\"input\", 0) == 
GetDataType<dtype>::value)   \\\n                       && (user_op::HobDataType(\"weight\", 0) == GetDataType<dtype>::value)  \\\n                       && (user_op::HobDataType(\"offset\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                   \\\n        const Shape& input_shape = ctx->InputShape(\"input\", 0);                             \\\n        const Shape& weight_shape = ctx->InputShape(\"weight\", 0);                           \\\n        const int32_t kW = weight_shape.At(2);                                              \\\n        const int32_t kH = weight_shape.At(3);                                              \\\n        const int32_t dW = ctx->Attr<int32_t>(\"stride_w\");                                  \\\n        const int32_t dH = ctx->Attr<int32_t>(\"stride_h\");                                  \\\n        const int32_t padW = ctx->Attr<int32_t>(\"pad_w\");                                   \\\n        const int32_t padH = ctx->Attr<int32_t>(\"pad_h\");                                   \\\n        const int32_t dilationW = ctx->Attr<int32_t>(\"dilation_w\");                         \\\n        const int32_t dilationH = ctx->Attr<int32_t>(\"dilation_h\");                         \\\n        const int64_t outputWidth =                                                         \\\n            (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;           \\\n        const int64_t outputHeight =                                                        \\\n            (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;           \\\n        const int64_t column_bytes = input_shape.At(1) * kW * kH * input_shape.At(0)        \\\n                                     * outputHeight * outputWidth * sizeof(dtype);          \\\n        return column_bytes;                                                                \\\n      });\nREGISTER_DEFORM_CONV2D_INPUT_GRAD_CPU_KERNEL(float)\nREGISTER_DEFORM_CONV2D_INPUT_GRAD_CPU_KERNEL(double)\n\n#define REGISTER_DEFORM_CONV2D_PARAM_GRAD_CPU_KERNEL(dtype)                                 \\\n  REGISTER_USER_KERNEL(\"deform_conv2d_param_grad\")                                          \\\n      .SetCreateFn<DeformableConv2dParamGradCpuKernel<dtype>>()                             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                       \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)   \\\n                       && (user_op::HobDataType(\"offset\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                   \\\n        const Shape& input_shape = ctx->InputShape(\"input\", 0);                             \\\n        const Shape& output_grad_shape = ctx->InputShape(\"output_grad\", 0);                 \\\n        const Shape& weight_shape = ctx->InputShape(\"weight\", 0);                           \\\n        const int32_t kW = weight_shape.At(2);                                              \\\n        const int32_t kH = weight_shape.At(3);                                              \\\n        const int32_t dW = ctx->Attr<int32_t>(\"stride_w\");                                  \\\n        const int32_t dH = ctx->Attr<int32_t>(\"stride_h\");                                  \\\n        const int32_t padW = ctx->Attr<int32_t>(\"pad_w\");                                   \\\n    
    const int32_t padH = ctx->Attr<int32_t>(\"pad_h\");                                   \\\n        const int32_t dilationW = ctx->Attr<int32_t>(\"dilation_w\");                         \\\n        const int32_t dilationH = ctx->Attr<int32_t>(\"dilation_h\");                         \\\n        const int64_t outputWidth =                                                         \\\n            (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;           \\\n        const int64_t outputHeight =                                                        \\\n            (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;           \\\n        const int64_t column_bytes = (input_shape.At(1) * kW * kH * input_shape.At(0)       \\\n                                      * outputHeight * outputWidth * sizeof(dtype));        \\\n        const int64_t output_bytes = (output_grad_shape.elem_cnt() * sizeof(dtype));        \\\n        return column_bytes + output_bytes;                                                 \\\n      });\nREGISTER_DEFORM_CONV2D_PARAM_GRAD_CPU_KERNEL(float)\nREGISTER_DEFORM_CONV2D_PARAM_GRAD_CPU_KERNEL(double)\n\n}  // namespace\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/deform_conv_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/user_op_hob.h\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n__device__ __forceinline__ float Add(float* address, float val) { return atomicAdd(address, val); }\n\n__device__ __forceinline__ double Add(double* address, double val) {\n#if __CUDA_ARCH__ >= 600\n  return atomicAdd(address, val);\n#else\n  auto address_as_ull = reinterpret_cast<unsigned long long int*>(address);\n  unsigned long long int old = *address_as_ull;\n  unsigned long long int assumed = 0;\n  do {\n    assumed = old;\n    old = atomicCAS(address_as_ull, assumed,\n                    __double_as_longlong(val + __longlong_as_double(assumed)));\n  } while (assumed != old);\n  return __longlong_as_double(old);\n#endif\n}\n\ntemplate<typename T>\n__device__ T bilinear_interpolate(const T* in, int height, int width, T h, T w) {\n  if (h <= -1 || height <= h || w <= -1 || width <= w) { return 0; }\n\n  int h_low = floor(h);\n  int w_low = floor(w);\n  int h_high = h_low + 1;\n  int w_high = w_low + 1;\n\n  T lh = h - h_low;\n  T lw = w - w_low;\n  T hh = 1 - lh, hw = 1 - lw;\n\n  T v1 = 0;\n  if (h_low >= 0 && w_low >= 0) v1 = in[h_low * width + w_low];\n  T v2 = 0;\n  if (h_low >= 0 && w_high <= width - 1) v2 = in[h_low * width + w_high];\n  T v3 = 0;\n  if (h_high <= height - 1 && w_low >= 0) v3 = in[h_high * width + w_low];\n  T v4 = 0;\n  if (h_high <= height - 1 && w_high <= width - 1) v4 = in[h_high * width + w_high];\n\n  T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n\n  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n  return val;\n}\n\ntemplate<typename T>\n__device__ T DeformableIm2ColBilinear(const T* bottom_data, const int data_width, const int height,\n                                      const int width, T h, T w) {\n  int h_low = floor(h);\n  int w_low = floor(w);\n  int h_high = h_low + 1;\n  int w_high = w_low + 1;\n\n  T lh = h - h_low;\n  T lw = w - w_low;\n  T hh = 1 - lh, hw = 1 - lw;\n\n  T v1 = 0;\n  if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low];\n  T v2 = 0;\n  if (h_low >= 0 && w_high <= width - 1) v2 = bottom_data[h_low * data_width + w_high];\n  T v3 = 0;\n  if (h_high <= height - 1 && w_low >= 0) v3 = bottom_data[h_high * data_width + w_low];\n  T v4 = 0;\n  if (h_high <= height - 1 && w_high <= width - 1) v4 = bottom_data[h_high * data_width + w_high];\n\n  T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n\n  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n  return val;\n}\n\ntemplate<typename T>\n__device__ T get_coordinate_weight(const T* im_data, int height, int width, T y, T x,\n                                   bool is_y_direction) {\n  int y_l = floor(y);\n  int x_l = 
floor(x);\n  int y_h = y_l + 1;\n  int x_h = x_l + 1;\n\n  bool valid_y_l = 0 <= y_l && y_l < height;\n  bool valid_y_h = 0 <= y_h && y_h < height;\n  bool valid_x_l = 0 <= x_l && x_l < width;\n  bool valid_x_h = 0 <= x_h && x_h < width;\n\n  T zero = 0;\n  T v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero;\n  T v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero;\n  T v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero;\n  T v_YX = (valid_y_h && valid_x_h) ? im_data[y_h * width + x_h] : zero;\n\n  if (is_y_direction) {\n    T dx = x - x_l;\n    return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx);\n  } else {\n    T dy = y - y_l;\n    return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx);\n  }\n}\n\ntemplate<typename T>\n__device__ T GetGradientWeight(T argmax_h, T argmax_w, const int h, const int w, const int height,\n                               const int width) {\n  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) {\n    // empty\n    return static_cast<T>(0);\n  }\n\n  int argmax_h_low = floor(argmax_h);\n  int argmax_w_low = floor(argmax_w);\n  int argmax_h_high = argmax_h_low + 1;\n  int argmax_w_high = argmax_w_low + 1;\n\n  T weight = 0;\n  if (h == argmax_h_low && w == argmax_w_low) weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);\n  if (h == argmax_h_low && w == argmax_w_high) weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);\n  if (h == argmax_h_high && w == argmax_w_low) weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);\n  if (h == argmax_h_high && w == argmax_w_high) weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);\n  return weight;\n}\n\ntemplate<typename T>\n__device__ T GetCoordinateWeight(T argmax_h, T argmax_w, const int height, const int width,\n                                 const T* im_data, const int data_width, const int bp_dir) {\n  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) {\n    // empty\n    return 0;\n  }\n\n  int argmax_h_low = floor(argmax_h);\n  int argmax_w_low = floor(argmax_w);\n  int argmax_h_high = argmax_h_low + 1;\n  int argmax_w_high = argmax_w_low + 1;\n\n  T weight = 0;\n\n  if (bp_dir == 0) {\n    if (argmax_h_low >= 0 && argmax_w_low >= 0)\n      weight +=\n          -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];\n    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)\n      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];\n    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)\n      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];\n    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)\n      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];\n  } else if (bp_dir == 1) {\n    if (argmax_h_low >= 0 && argmax_w_low >= 0)\n      weight +=\n          -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];\n    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)\n      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];\n    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)\n      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];\n    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)\n      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];\n  }\n\n  return 
weight;\n}\n\ntemplate<typename T>\n__global__ void DeformableCol2Im(int n, const T* col, const T* offset_data, const T* mask_data,\n                                 int channels, int height, int width, int kernel_h, int kernel_w,\n                                 int pad_h, int pad_w, int stride_h, int stride_w, int dilation_h,\n                                 int dilation_w, int batch_sz, int n_offset_grps, int out_h,\n                                 int out_w, bool use_mask, T* grad_im) {\n  CUDA_1D_KERNEL_LOOP(index, n) {\n    const int out_x = index % out_w;\n    const int out_y = (index / out_w) % out_h;\n    const int b = (index / (out_w * out_h)) % batch_sz;\n    const int j = (index / (out_w * out_h * batch_sz)) % kernel_w;\n    const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h;\n    const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h);\n\n    int c_per_offset_grp = channels / n_offset_grps;\n    const int offset_grp = c / c_per_offset_grp;\n    auto offset_ptr = offset_data;\n\n    offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h * out_w;\n    auto mask_ptr = mask_data;\n    if (use_mask) {\n      mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * out_h * out_w;\n    }\n\n    const int mask_idx = i * kernel_w + j;\n    const int offset_idx = 2 * mask_idx;\n\n    const int offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x;\n    const int offset_w_ptr = ((offset_idx + 1) * out_h + out_y) * out_w + out_x;\n\n    const T offset_h = offset_ptr[offset_h_ptr];\n    const T offset_w = offset_ptr[offset_w_ptr];\n\n    T mask_value = 1;\n    if (use_mask) { mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; }\n\n    const T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;\n    const T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;\n\n    for (int dy = -1; dy <= 1; dy++) {\n      for (int dx = -1; dx <= 1; dx++) {\n        int yp = (int)y + dy;\n        int xp = (int)x + dx;\n        if (0 <= yp && yp < height && 0 <= xp && xp < width && abs(y - yp) < 1 && abs(x - xp) < 1) {\n          int grad_pos = ((b * channels + c) * height + yp) * width + xp;\n          T weight = (1 - abs(y - yp)) * (1 - abs(x - xp));\n          Add(grad_im + grad_pos, mask_value * weight * col[index]);\n        }\n      }\n    }\n  }\n}\n\ntemplate<typename T>\n__global__ void DeformableIm2Col(int n, const T* input, const T* offset, const T* mask, int height,\n                                 int width, int weight_h, int weight_w, int pad_h, int pad_w,\n                                 int stride_h, int stride_w, int dilation_h, int dilation_w,\n                                 int batch_sz, int n_in_channels, int n_offset_grps, int out_h,\n                                 int out_w, bool use_mask, T* columns) {\n  CUDA_1D_KERNEL_LOOP(index, n) {\n    const int out_x = index % out_w;\n    const int out_y = (index / out_w) % out_h;\n    const int out_b = (index / (out_w * out_h)) % batch_sz;\n    const int in_c = index / (out_w * out_h * batch_sz);\n    const int out_c = in_c * weight_h * weight_w;\n\n    int c_per_offset_grp = n_in_channels / n_offset_grps;\n    const int grp_idx = in_c / c_per_offset_grp;\n    auto columns_ptr = columns;\n    columns_ptr +=\n        (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + out_y * out_w + out_x);\n    auto input_ptr = input;\n    input_ptr += (out_b * (n_in_channels * height * width) + in_c * (height * width));\n    
auto offset_ptr = offset;\n    offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w;\n    auto mask_ptr = mask;\n    if (use_mask) {\n      mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * out_h * out_w;\n    }\n\n    for (int i = 0; i < weight_h; ++i) {\n      for (int j = 0; j < weight_w; ++j) {\n        const int mask_idx = i * weight_w + j;\n        const int offset_idx = 2 * mask_idx;\n\n        T mask_value = 1;\n        if (use_mask) { mask_value = mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; }\n\n        const T offset_h = offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x];\n        const T offset_w = offset_ptr[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x];\n        const T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;\n        const T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;\n        *columns_ptr = mask_value * bilinear_interpolate(input_ptr, height, width, y, x);\n        columns_ptr += batch_sz * out_h * out_w;\n      }\n    }\n  }\n}\n\ntemplate<typename T>\n__global__ void DeformableCol2imCoord(int n, const T* col_data, const T* im_data,\n                                      const T* offset_data, const T* mask_data, int channels,\n                                      int height, int width, int weight_h, int weight_w, int pad_h,\n                                      int pad_w, int stride_h, int stride_w, int dilation_h,\n                                      int dilation_w, int batch_sz, int offset_channels,\n                                      int n_offset_grps, int out_h, int out_w, const bool use_mask,\n                                      T* grad_offset, T* grad_mask) {\n  CUDA_1D_KERNEL_LOOP(index, n) {\n    T grad_offset_val = 0;\n    T grad_mask_val = 0;\n\n    int w = index % out_w;\n    int h = (index / out_w) % out_h;\n    int w_w = (index / (out_w * out_h * 2)) % weight_w;\n    int w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h;\n    int c = (index / (out_w * out_h)) % offset_channels;\n    int b = index / (out_w * out_h * offset_channels);\n\n    const int offset_grp = c / (2 * weight_h * weight_w);\n    const int col_step = weight_h * weight_w;\n\n    int c_per_offset_grp = channels / n_offset_grps;\n    auto col_ptr = col_data;\n    col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * out_w * out_h;\n    auto im_ptr = im_data;\n    im_ptr += (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width;\n    auto offset_ptr = offset_data;\n    offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h * out_w;\n    auto mask_ptr = mask_data;\n    if (use_mask) {\n      mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w * out_h * out_w;\n    }\n\n    const int offset_c = c - offset_grp * 2 * weight_h * weight_w;\n    const bool is_y_direction = offset_c % 2 == 0;\n\n    const int c_bound = c_per_offset_grp * weight_h * weight_w;\n    for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {\n      const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w;\n\n      int out_x = col_pos % out_w;\n      int out_y = (col_pos / out_w) % out_h;\n      int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w;\n      int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h;\n\n      const int mask_idx = i * weight_w + j;\n\n      const int offset_h_ptr = (((2 * mask_idx) * out_h + out_y) * out_w + out_x);\n      const 
int offset_w_ptr = (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x);\n      const T offset_h = offset_ptr[offset_h_ptr];\n      const T offset_w = offset_ptr[offset_w_ptr];\n\n      T mask_value = 1;\n      if (use_mask) { mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; }\n\n      T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;\n      T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;\n\n      const T weight = get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction);\n      grad_offset_val += mask_value * weight * col_ptr[col_pos];\n\n      if (use_mask && is_y_direction) {\n        grad_mask_val += col_ptr[col_pos] * bilinear_interpolate(im_ptr, height, width, y, x);\n      }\n\n      im_ptr += height * width;\n    }\n\n    grad_offset[index] = grad_offset_val;\n\n    if (use_mask && is_y_direction) {\n      const int idx =\n          ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w + w_w) * out_h + h)\n              * out_w\n          + w;\n      grad_mask[idx] = grad_mask_val;\n    }\n  }\n}\n\n}  // namespace\n\nep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) {\n  return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N;\n}\n\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitive(DeviceType device_type,\n                                                          DataType data_type, bool transpose_a,\n                                                          bool transpose_b) {\n  const auto trans_a = GetBlasTransposeType(transpose_a);\n  const auto trans_b = GetBlasTransposeType(transpose_b);\n  return ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(device_type, data_type, trans_a,\n                                                                   trans_b);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Permute> NewPermutePrimitive(Context* ctx, const int& num_dims) {\n  return ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(ctx->device_type(), num_dims);\n}\n\ntemplate<typename T>\nclass DeformableConv2dCudaKernel final : public user_op::OpKernel {\n public:\n  DeformableConv2dCudaKernel() = default;\n  ~DeformableConv2dCudaKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const user_op::Tensor* offset = ctx->Tensor4ArgNameAndIndex(\"offset\", 0);\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const ShapeView& input_shape = input->shape_view();\n    const ShapeView& output_shape = output->shape_view();\n    const int64_t out_elem_cnt = output_shape.elem_cnt();\n    const int64_t output_bytes = GetCudaAlignedSize(out_elem_cnt * sizeof(T));\n\n    T* column_tmp_buffer = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + output_bytes);\n    const int32_t kW = weight->shape_view().At(2);\n    const int32_t kH = weight->shape_view().At(3);\n    const int32_t dW = ctx->Attr<int32_t>(\"stride_w\");\n    const int32_t dH = ctx->Attr<int32_t>(\"stride_h\");\n    const int32_t padW = ctx->Attr<int32_t>(\"pad_w\");\n    const int32_t padH = ctx->Attr<int32_t>(\"pad_h\");\n    const 
int32_t dilationW = ctx->Attr<int32_t>(\"dilation_w\");\n    const int32_t dilationH = ctx->Attr<int32_t>(\"dilation_h\");\n    const int32_t group = ctx->Attr<int32_t>(\"groups\");\n    const int32_t deformable_group = ctx->Attr<int32_t>(\"offset_groups\");\n    const bool use_mask = ctx->Attr<bool>(\"use_mask\");\n\n    const int32_t channel_per_deformable_group = input_shape.At(1) / deformable_group;\n    const int64_t outputWidth =\n        (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;\n    const int64_t outputHeight =\n        (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;\n    const int64_t column_nums = input_shape.At(1) * input_shape.At(0) * outputWidth * outputHeight;\n    if (column_nums > 0) {\n      DeformableIm2Col<T><<<BlocksNum4ThreadsNum(column_nums), kCudaThreadsNumPerBlock, 0,\n                            ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          column_nums, input->dptr<T>(), offset->dptr<T>(), mask->dptr<T>(), input_shape.At(2),\n          input_shape.At(3), kH, kW, padH, padW, dH, dW, dilationH, dilationW, input_shape.At(0),\n          input_shape.At(1), deformable_group, output_shape.At(2), output_shape.At(3), use_mask,\n          column_tmp_buffer);\n\n      const int64_t weight_group_offset = weight->shape_view().elem_cnt() / group;\n      const int64_t column_group_offset =\n          input_shape.At(1) * kW * kH * input_shape.At(0) * outputHeight * outputWidth / group;\n      const int64_t output_group_offset = out_elem_cnt / group;\n\n      auto matmul = NewMatmulPrimitive(ctx->device_type(), output->data_type(), false, false);\n      CHECK(matmul);\n\n      FOR_RANGE(int, g, 0, group) {\n        matmul->Launch(ctx->stream(), weight->shape_view().At(0) / group,\n                       input_shape.At(0) * outputHeight * outputWidth,\n                       input_shape.At(1) * kW * kH / group, static_cast<T>(1),\n                       weight->dptr<T>() + g * weight_group_offset,\n                       column_tmp_buffer + g * column_group_offset, static_cast<T>(0),\n                       tmp_buffer->mut_dptr<T>() + g * output_group_offset);\n      }\n\n      std::vector<int64_t> out_shapevec(\n          {output_shape.At(1), output_shape.At(0), output_shape.At(2), output_shape.At(3)});\n\n      auto transpose = NewPermutePrimitive(ctx, output_shape.NumAxes());\n      CHECK(transpose);\n      transpose->Launch(ctx->stream(), output->data_type(), output_shape.NumAxes(),\n                        out_shapevec.data(), tmp_buffer->dptr<T>(),\n                        std::vector<int>({1, 0, 2, 3}).data(), output->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass DeformableConv2dInputGradKernel final : public user_op::OpKernel {\n public:\n  DeformableConv2dInputGradKernel() = default;\n  ~DeformableConv2dInputGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* output_grad = ctx->Tensor4ArgNameAndIndex(\"output_grad\", 0);\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const user_op::Tensor* offset = ctx->Tensor4ArgNameAndIndex(\"offset\", 0);\n    user_op::Tensor* input_grad = ctx->Tensor4ArgNameAndIndex(\"input_grad\", 0);\n    user_op::Tensor* offset_grad = 
ctx->Tensor4ArgNameAndIndex(\"offset_grad\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const ShapeView& output_grad_shape = output_grad->shape_view();\n    const ShapeView& input_shape = input->shape_view();\n    const ShapeView& weight_shape = weight->shape_view();\n    const int32_t kW = weight->shape_view().At(2);\n    const int32_t kH = weight->shape_view().At(3);\n    const int32_t dW = ctx->Attr<int32_t>(\"stride_w\");\n    const int32_t dH = ctx->Attr<int32_t>(\"stride_h\");\n    const int32_t padW = ctx->Attr<int32_t>(\"pad_w\");\n    const int32_t padH = ctx->Attr<int32_t>(\"pad_h\");\n    const int32_t dilationW = ctx->Attr<int32_t>(\"dilation_w\");\n    const int32_t dilationH = ctx->Attr<int32_t>(\"dilation_h\");\n    const int32_t group = ctx->Attr<int32_t>(\"groups\");\n    const int32_t deformable_group = ctx->Attr<int32_t>(\"offset_groups\");\n    const bool use_mask = ctx->Attr<bool>(\"use_mask\");\n    const T* data_mask = nullptr;\n    T* data_mask_grad = nullptr;\n    if (use_mask) {\n      data_mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0)->dptr<T>();\n      data_mask_grad = ctx->Tensor4ArgNameAndIndex(\"mask_grad\", 0)->mut_dptr<T>();\n    }\n\n    const int64_t outputWidth =\n        (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;\n    const int64_t outputHeight =\n        (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;\n\n    std::unique_ptr<ep::primitive::Memset> primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->stream()->device_type());\n\n    primitive->Launch(ctx->stream(), input_grad->mut_dptr<T>(), 0,\n                      input_grad->shape_view().elem_cnt() * sizeof(T));\n\n    const int64_t nthreads_coord =\n        outputHeight * outputWidth * 2 * deformable_group * input_shape.At(0) * kW * kH;\n    const int64_t nthreads_feat =\n        outputHeight * outputWidth * input_shape.At(0) * input_shape.At(1) * kW * kH;\n    if (nthreads_coord > 0 && nthreads_feat > 0) {\n      const int64_t weight_group_offset = weight_shape.elem_cnt() / group;\n      const int64_t output_grad_group_offset = output_grad_shape.Count(1) / group;\n      const int64_t column_group_offset =\n          input_shape.At(1) * kW * kH * input_shape.At(0) * outputHeight * outputWidth / group;\n\n      auto matmul = NewMatmulPrimitive(ctx->device_type(), input_grad->data_type(), true, true);\n      CHECK(matmul);\n      FOR_RANGE(int, g, 0, group) {\n        matmul->Launch(ctx->stream(), weight_shape.Count(1),\n                       input_shape.At(0) * outputHeight * outputWidth, weight_shape.At(0) / group,\n                       static_cast<T>(1), weight->dptr<T>() + g * weight_group_offset,\n                       output_grad->dptr<T>() + g * output_grad_group_offset, static_cast<T>(0),\n                       tmp_buffer->mut_dptr<T>() + g * column_group_offset);\n      }\n      DeformableCol2imCoord<T><<<BlocksNum4ThreadsNum(nthreads_coord), kCudaThreadsNumPerBlock, 0,\n                                 ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          nthreads_coord, tmp_buffer->dptr<T>(), input->dptr<T>(), offset->dptr<T>(), data_mask,\n          input_shape.At(1), input_shape.At(2), input_shape.At(3), kH, kW, padH, padW, dH, dW,\n          dilationH, dilationW, input_shape.At(0), 2 * kH * kW * deformable_group, deformable_group,\n          outputHeight, outputWidth, use_mask, offset_grad->mut_dptr<T>(), data_mask_grad);\n      
DeformableCol2Im<T><<<BlocksNum4ThreadsNum(nthreads_feat), kCudaThreadsNumPerBlock, 0,\n                            ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          nthreads_feat, tmp_buffer->dptr<T>(), offset->dptr<T>(), data_mask, input_shape.At(1),\n          input_shape.At(2), input_shape.At(3), kH, kW, padH, padW, dH, dW, dilationH, dilationW,\n          input_shape.At(0), deformable_group, outputHeight, outputWidth, use_mask,\n          input_grad->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\ntemplate<typename T>\nclass DeformableConv2dParamGradKernel final : public user_op::OpKernel {\n public:\n  DeformableConv2dParamGradKernel() = default;\n  ~DeformableConv2dParamGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* output_grad = ctx->Tensor4ArgNameAndIndex(\"output_grad\", 0);\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const user_op::Tensor* offset = ctx->Tensor4ArgNameAndIndex(\"offset\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    user_op::Tensor* weight_grad = ctx->Tensor4ArgNameAndIndex(\"weight_grad\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const ShapeView& output_grad_shape = output_grad->shape_view();\n    const ShapeView& weight_grad_shape = weight_grad->shape_view();\n    const ShapeView& input_shape = input->shape_view();\n    const int64_t out_elem_cnt = output_grad_shape.elem_cnt();\n    const int64_t output_bytes = GetCudaAlignedSize(out_elem_cnt * sizeof(T));\n    T* column_tmp_buffer = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + output_bytes);\n    const int32_t kW = weight->shape_view().At(2);\n    const int32_t kH = weight->shape_view().At(3);\n    const int32_t dW = ctx->Attr<int32_t>(\"stride_w\");\n    const int32_t dH = ctx->Attr<int32_t>(\"stride_h\");\n    const int32_t padW = ctx->Attr<int32_t>(\"pad_w\");\n    const int32_t padH = ctx->Attr<int32_t>(\"pad_h\");\n    const int32_t dilationW = ctx->Attr<int32_t>(\"dilation_w\");\n    const int32_t dilationH = ctx->Attr<int32_t>(\"dilation_h\");\n    const int32_t group = ctx->Attr<int32_t>(\"groups\");\n    const int32_t deformable_group = ctx->Attr<int32_t>(\"offset_groups\");\n    const bool use_mask = ctx->Attr<bool>(\"use_mask\");\n    const int32_t channel_per_deformable_group = input_shape.At(1) / deformable_group;\n    const int64_t outputWidth =\n        (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;\n    const int64_t outputHeight =\n        (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;\n\n    const T* data_mask = nullptr;\n    if (use_mask) { data_mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0)->dptr<T>(); }\n    const int64_t column_nums = input_shape.At(1) * input_shape.At(0) * outputHeight * outputWidth;\n    if (column_nums > 0) {\n      DeformableIm2Col<T><<<BlocksNum4ThreadsNum(column_nums), kCudaThreadsNumPerBlock, 0,\n                            ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          column_nums, input->dptr<T>(), offset->dptr<T>(), data_mask, input_shape.At(2),\n          input_shape.At(3), kH, kW, padH, padW, dH, dW, dilationH, dilationW, input_shape.At(0),\n          input_shape.At(1), deformable_group, output_grad_shape.At(2), output_grad_shape.At(3),\n          
use_mask, column_tmp_buffer);\n\n      std::unique_ptr<ep::primitive::Memset> primitive =\n          ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->stream()->device_type());\n      primitive->Launch(ctx->stream(), weight_grad->mut_dptr<T>(), 0,\n                        weight_grad->shape_view().elem_cnt() * sizeof(T));\n\n      std::vector<int64_t> output_grad_buffer_vec({output_grad_shape.At(1), output_grad_shape.At(0),\n                                                   output_grad_shape.At(2),\n                                                   output_grad_shape.At(3)});\n\n      auto transpose = NewPermutePrimitive(ctx, output_grad_shape.NumAxes());\n      CHECK(transpose);\n      transpose->Launch(ctx->stream(), output_grad->data_type(), output_grad_shape.NumAxes(),\n                        output_grad_buffer_vec.data(), output_grad->dptr<T>(),\n                        std::vector<int>({1, 0, 2, 3}).data(), tmp_buffer->mut_dptr<T>());\n\n      const int64_t output_grad_group_offset = output_grad_shape.elem_cnt() / group;\n      const int64_t column_group_offset =\n          input_shape.At(1) * kW * kH * input_shape.At(0) * outputHeight * outputWidth / group;\n      const int64_t weight_grad_group_offset = weight_grad->shape_view().elem_cnt() / group;\n      FOR_RANGE(int, g, 0, group) {\n        auto matmul = NewMatmulPrimitive(ctx->device_type(), weight_grad->data_type(), false, true);\n        CHECK(matmul);\n\n        matmul->Launch(ctx->stream(), weight_grad_shape.At(0) / group,\n                       input_shape.At(1) * kW * kH / group,\n                       input_shape.At(0) * outputHeight * outputWidth, static_cast<T>(1),\n                       tmp_buffer->dptr<T>() + g * output_grad_group_offset,\n                       column_tmp_buffer + g * column_group_offset, static_cast<T>(0),\n                       weight_grad->mut_dptr<T>() + g * weight_grad_group_offset);\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_DEFORM_CONV2D_GPU_KERNEL(dtype)                                                  \\\n  REGISTER_USER_KERNEL(\"deform_conv2d\")                                                           \\\n      .SetCreateFn<DeformableConv2dCudaKernel<dtype>>()                                           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                            \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value))        \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                         \\\n        const Shape& input_shape = ctx->InputShape(\"input\", 0);                                   \\\n        const Shape& output_shape = ctx->OutputShape(\"output\", 0);                                \\\n        const Shape& weight_shape = ctx->InputShape(\"weight\", 0);                                 \\\n        const int32_t kW = weight_shape.At(2);                                                    \\\n        const int32_t kH = weight_shape.At(3);                                                    \\\n        const int32_t dW = ctx->Attr<int32_t>(\"stride_w\");                                        \\\n        const int32_t dH = ctx->Attr<int32_t>(\"stride_h\");                                        \\\n        const int32_t padW = ctx->Attr<int32_t>(\"pad_w\");                                         \\\n        const int32_t padH = ctx->Attr<int32_t>(\"pad_h\");                 
                        \\\n        const int32_t dilationW = ctx->Attr<int32_t>(\"dilation_w\");                               \\\n        const int32_t dilationH = ctx->Attr<int32_t>(\"dilation_h\");                               \\\n        const int64_t outputWidth =                                                               \\\n            (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;                 \\\n        const int64_t outputHeight =                                                              \\\n            (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;                 \\\n        const int64_t column_bytes =                                                              \\\n            GetCudaAlignedSize(input_shape.At(1) * kW * kH * input_shape.At(0) * outputHeight     \\\n                               * outputWidth * sizeof(dtype));                                    \\\n        const int64_t output_bytes = GetCudaAlignedSize(output_shape.elem_cnt() * sizeof(dtype)); \\\n        return column_bytes + output_bytes;                                                       \\\n      });\nREGISTER_DEFORM_CONV2D_GPU_KERNEL(float)\nREGISTER_DEFORM_CONV2D_GPU_KERNEL(double)\n\n#define REGISTER_DEFORM_CONV2D_INPUT_GRAD_GPU_KERNEL(dtype)                                   \\\n  REGISTER_USER_KERNEL(\"deform_conv2d_input_grad\")                                            \\\n      .SetCreateFn<DeformableConv2dInputGradKernel<dtype>>()                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                        \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)     \\\n                       && (user_op::HobDataType(\"weight\", 0) == GetDataType<dtype>::value)    \\\n                       && (user_op::HobDataType(\"offset\", 0) == GetDataType<dtype>::value))   \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                     \\\n        const Shape& input_shape = ctx->InputShape(\"input\", 0);                               \\\n        const Shape& weight_shape = ctx->InputShape(\"weight\", 0);                             \\\n        const int32_t kW = weight_shape.At(2);                                                \\\n        const int32_t kH = weight_shape.At(3);                                                \\\n        const int32_t dW = ctx->Attr<int32_t>(\"stride_w\");                                    \\\n        const int32_t dH = ctx->Attr<int32_t>(\"stride_h\");                                    \\\n        const int32_t padW = ctx->Attr<int32_t>(\"pad_w\");                                     \\\n        const int32_t padH = ctx->Attr<int32_t>(\"pad_h\");                                     \\\n        const int32_t dilationW = ctx->Attr<int32_t>(\"dilation_w\");                           \\\n        const int32_t dilationH = ctx->Attr<int32_t>(\"dilation_h\");                           \\\n        const int64_t outputWidth =                                                           \\\n            (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;             \\\n        const int64_t outputHeight =                                                          \\\n            (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;             \\\n        const int64_t column_bytes =                                                          \\\n            
GetCudaAlignedSize(input_shape.At(1) * kW * kH * input_shape.At(0) * outputHeight \\\n                               * outputWidth * sizeof(dtype));                                \\\n        return column_bytes;                                                                  \\\n      });\nREGISTER_DEFORM_CONV2D_INPUT_GRAD_GPU_KERNEL(float)\nREGISTER_DEFORM_CONV2D_INPUT_GRAD_GPU_KERNEL(double)\n\n#define REGISTER_DEFORM_CONV2D_PARAM_GRAD_GPU_KERNEL(dtype)                                   \\\n  REGISTER_USER_KERNEL(\"deform_conv2d_param_grad\")                                            \\\n      .SetCreateFn<DeformableConv2dParamGradKernel<dtype>>()                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                        \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)     \\\n                       && (user_op::HobDataType(\"offset\", 0) == GetDataType<dtype>::value))   \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                     \\\n        const Shape& input_shape = ctx->InputShape(\"input\", 0);                               \\\n        const Shape& output_grad_shape = ctx->InputShape(\"output_grad\", 0);                   \\\n        const Shape& weight_shape = ctx->InputShape(\"weight\", 0);                             \\\n        const int32_t kW = weight_shape.At(2);                                                \\\n        const int32_t kH = weight_shape.At(3);                                                \\\n        const int32_t dW = ctx->Attr<int32_t>(\"stride_w\");                                    \\\n        const int32_t dH = ctx->Attr<int32_t>(\"stride_h\");                                    \\\n        const int32_t padW = ctx->Attr<int32_t>(\"pad_w\");                                     \\\n        const int32_t padH = ctx->Attr<int32_t>(\"pad_h\");                                     \\\n        const int32_t dilationW = ctx->Attr<int32_t>(\"dilation_w\");                           \\\n        const int32_t dilationH = ctx->Attr<int32_t>(\"dilation_h\");                           \\\n        const int64_t outputWidth =                                                           \\\n            (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;             \\\n        const int64_t outputHeight =                                                          \\\n            (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;             \\\n        const int64_t column_bytes =                                                          \\\n            GetCudaAlignedSize(input_shape.At(1) * kW * kH * input_shape.At(0) * outputHeight \\\n                               * outputWidth * sizeof(dtype));                                \\\n        const int64_t output_bytes =                                                          \\\n            GetCudaAlignedSize(output_grad_shape.elem_cnt() * sizeof(dtype));                 \\\n        return column_bytes + output_bytes;                                                   \\\n      });\nREGISTER_DEFORM_CONV2D_PARAM_GRAD_GPU_KERNEL(float)\nREGISTER_DEFORM_CONV2D_PARAM_GRAD_GPU_KERNEL(double)\n\n}  // namespace oneflow"
  },
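The forward kernel above is the classic im2col-plus-GEMM decomposition of deformable convolution: `DeformableIm2Col` gathers bilinearly sampled input values (at positions shifted by the learned offsets, optionally modulated by a mask) into a column buffer, and a grouped matmul against the weights produces the output. The sampling rule lives in the device function `bilinear_interpolate`; the following is a minimal host-side sketch of that same rule (standalone, illustrative names, not part of the file):

```cpp
#include <cmath>

// Sample a row-major height x width image at fractional coordinates (h, w);
// corners falling outside the image contribute zero, mirroring the boundary
// handling of the device function above.
template<typename T>
T BilinearSampleCpu(const T* in, int height, int width, T h, T w) {
  if (h <= -1 || height <= h || w <= -1 || width <= w) { return 0; }
  const int h_low = static_cast<int>(std::floor(h));
  const int w_low = static_cast<int>(std::floor(w));
  const T lh = h - h_low, lw = w - w_low;  // fractional parts
  auto at = [&](int y, int x) -> T {
    return (0 <= y && y < height && 0 <= x && x < width) ? in[y * width + x] : T(0);
  };
  // Convex combination of the four neighboring pixels.
  return (1 - lh) * (1 - lw) * at(h_low, w_low) + (1 - lh) * lw * at(h_low, w_low + 1)
         + lh * (1 - lw) * at(h_low + 1, w_low) + lh * lw * at(h_low + 1, w_low + 1);
}
```

The backward kernels invert this gather into scatters: `DeformableCol2Im` accumulates into `grad_im` with `atomicAdd`, which is why the file carries a CAS-based `Add` fallback for `double` — native `atomicAdd(double*, double)` only exists on compute capability 6.0 and newer.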
  {
    "path": "oneflow/user/kernels/det_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/eigen_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstatic inline size_t BatchCount(const user_op::Tensor* batched_matrices) {\n  size_t result = 1;\n  for (size_t i = 0; i < batched_matrices->shape_view().NumAxes() - 2; i++) {\n    result *= batched_matrices->shape_view().At(i);\n  }\n  return result;\n}\n\nstatic inline size_t MatrixStride(const user_op::Tensor* batched_matrices) {\n  const int64_t num_axes = batched_matrices->shape_view().NumAxes();\n  return batched_matrices->shape_view().At(num_axes - 2)\n         * batched_matrices->shape_view().At(num_axes - 1);\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass DetKernel final : public user_op::OpKernel {\n public:\n  DetKernel() = default;\n  ~DetKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    auto batch_count = BatchCount(x);\n    auto matrix_stride = MatrixStride(x);\n    auto matrix_size = x->shape_view().At(x->shape_view().NumAxes() - 2);\n    const T* x_ptr = x->dptr<T>();\n    T* y_ptr = y->mut_dptr<T>();\n\n    FOR_RANGE(int64_t, i, 0, batch_count) {\n      ConstEigenMatrixMap<T> x_mat(x_ptr + i * matrix_stride, matrix_size, matrix_size);\n      if (x_mat.determinant() == 0) {\n        LOG(FATAL)\n            << \"(Batch element \" << i\n            << \"): the inversion could not be completed because the input matrix is singular.\";\n      }\n      T y = x_mat.determinant();\n      *(y_ptr + i) = y;\n    };\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_DET_KERNEL(dtype)                                             \\\n  REGISTER_USER_KERNEL(\"det\").SetCreateFn<DetKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCPU)                           \\\n      && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nREGISTER_DET_KERNEL(float)\nREGISTER_DET_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
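`DetKernel` views each batch element as a square Eigen matrix and evaluates the determinant on the CPU. A minimal standalone sketch of the same mapping pattern, assuming Eigen is available (the kernel's `ConstEigenMatrixMap` is a comparable `Eigen::Map` alias):

```cpp
#include <Eigen/Dense>

// Treat a flat buffer as `batch` contiguous n x n matrices and write each
// determinant to y. Eigen::Map creates a view, so no data is copied; the
// storage order does not matter here because det(A) == det(A^T).
template<typename T>
void BatchedDet(const T* x, T* y, int batch, int n) {
  using Mat = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
  for (int i = 0; i < batch; ++i) {
    Eigen::Map<const Mat> m(x + i * n * n, n, n);
    y[i] = m.determinant();  // 0 for singular input, which is a valid result
  }
}
```

Note that, unlike matrix inversion, a singular input is not an error case for `det`: the determinant of a singular matrix is exactly 0 and should simply be returned.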
  {
    "path": "oneflow/user/kernels/diag_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/user/kernels/diag_kernel.h\"\n\nnamespace oneflow {\nnamespace {\n\ntemplate<typename T>\nstruct DiagFunctor<DeviceType::kCPU, T> final {\n  void operator()(ep::Stream* stream, T* out_buf, const T* in_buf, int32_t size, int32_t stride,\n                  int32_t in_dim) {\n    if (in_dim == 1) {\n      FOR_RANGE(int32_t, i, 0, size) { out_buf[i * stride] = in_buf[i]; }\n    } else {\n      FOR_RANGE(int32_t, i, 0, size) { out_buf[i] = in_buf[i * stride]; }\n    }\n  }\n};\n\ntemplate<typename T>\nstruct DiagGradFunctor<DeviceType::kCPU, T> final {\n  void operator()(ep::Stream* stream, T* dx_buf, const T* dy_buf, int32_t dx_cnt, int32_t dy_cnt,\n                  int32_t stride, int32_t in_dim) {\n    if (in_dim == 1) {\n      FOR_RANGE(int32_t, i, 0, dx_cnt) { dx_buf[i] = dy_buf[i * stride]; }\n    } else {\n      FOR_RANGE(int32_t, i, 0, dy_cnt) { dx_buf[i * stride] = dy_buf[i]; }\n    }\n  }\n};\n\n}  // namespace\n\nREGISTER_DIAG_KERNELS(DeviceType::kCPU, float);\nREGISTER_DIAG_KERNELS(DeviceType::kCPU, double);\nREGISTER_DIAG_KERNELS(DeviceType::kCPU, bool);\nREGISTER_DIAG_KERNELS(DeviceType::kCPU, uint8_t);\nREGISTER_DIAG_KERNELS(DeviceType::kCPU, int8_t);\nREGISTER_DIAG_KERNELS(DeviceType::kCPU, int32_t);\nREGISTER_DIAG_KERNELS(DeviceType::kCPU, int64_t);\n\n}  // namespace oneflow\n"
  },
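Both directions of `DiagFunctor` reduce to one stride trick: in a row-major n x n matrix, consecutive main-diagonal elements are n + 1 apart in the flat buffer. A tiny standalone illustration (hypothetical code, not part of the kernel file):

```cpp
#include <vector>

int main() {
  const int n = 3;
  std::vector<int> v = {1, 2, 3};
  std::vector<int> m(n * n, 0);
  // Vector -> matrix diagonal: write every (n + 1)-th element.
  for (int i = 0; i < n; ++i) { m[i * (n + 1)] = v[i]; }
  // Matrix diagonal -> vector: read back with the same stride.
  std::vector<int> d(n);
  for (int i = 0; i < n; ++i) { d[i] = m[i * (n + 1)]; }
  // m is now {1,0,0, 0,2,0, 0,0,3} and d == v.
  return 0;
}
```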
  {
    "path": "oneflow/user/kernels/diag_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/user/kernels/diag_kernel.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\nnamespace {\n\ntemplate<typename T>\n__global__ void vector_diagonal_kernel(T* out_buf, const T* in_buf, int32_t size, int32_t stride) {\n  CUDA_1D_KERNEL_LOOP(i, size) { out_buf[i * stride] = in_buf[i]; }\n}\n\ntemplate<typename T>\n__global__ void matrix_diagonal_kernel(T* out_buf, const T* in_buf, int32_t size, int32_t stride) {\n  CUDA_1D_KERNEL_LOOP(i, size) { out_buf[i] = in_buf[i * stride]; }\n}\n\ntemplate<typename T>\nstruct DiagFunctor<DeviceType::kCUDA, T> final {\n  void operator()(ep::Stream* stream, T* out_buf, const T* in_buf, int32_t size, int32_t stride,\n                  int32_t in_dim) {\n    if (in_dim == 1) {\n      vector_diagonal_kernel<<<BlocksNum4ThreadsNum(size * size), kCudaThreadsNumPerBlock, 0,\n                               stream->As<ep::CudaStream>()->cuda_stream()>>>(out_buf, in_buf, size,\n                                                                              stride);\n    } else {\n      matrix_diagonal_kernel<<<BlocksNum4ThreadsNum(size * size), kCudaThreadsNumPerBlock, 0,\n                               stream->As<ep::CudaStream>()->cuda_stream()>>>(out_buf, in_buf, size,\n                                                                              stride);\n    }\n  }\n};\n\ntemplate<typename T>\nstruct DiagGradFunctor<DeviceType::kCUDA, T> final {\n  void operator()(ep::Stream* stream, T* dx_buf, const T* dy_buf, int32_t dx_cnt, int32_t dy_cnt,\n                  int32_t stride, int32_t in_dim) {\n    if (in_dim == 1) {\n      matrix_diagonal_kernel<<<BlocksNum4ThreadsNum(dx_cnt), kCudaThreadsNumPerBlock, 0,\n                               stream->As<ep::CudaStream>()->cuda_stream()>>>(dx_buf, dy_buf,\n                                                                              dx_cnt, stride);\n    } else {\n      vector_diagonal_kernel<<<BlocksNum4ThreadsNum(dy_cnt), kCudaThreadsNumPerBlock, 0,\n                               stream->As<ep::CudaStream>()->cuda_stream()>>>(dx_buf, dy_buf,\n                                                                              dy_cnt, stride);\n    }\n  }\n};\n\n}  // namespace\n\nREGISTER_DIAG_KERNELS(DeviceType::kCUDA, half);\nREGISTER_DIAG_KERNELS(DeviceType::kCUDA, float);\nREGISTER_DIAG_KERNELS(DeviceType::kCUDA, double);\nREGISTER_DIAG_KERNELS(DeviceType::kCUDA, bool);\nREGISTER_DIAG_KERNELS(DeviceType::kCUDA, uint8_t);\nREGISTER_DIAG_KERNELS(DeviceType::kCUDA, int8_t);\nREGISTER_DIAG_KERNELS(DeviceType::kCUDA, int32_t);\nREGISTER_DIAG_KERNELS(DeviceType::kCUDA, int64_t);\n\n}  // namespace oneflow\n"
  },
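The CUDA variant swaps the `FOR_RANGE` loops for `CUDA_1D_KERNEL_LOOP`, OneFlow's grid-stride loop macro: a fixed-size grid covers an arbitrary element count because each thread advances by the total number of launched threads. In spirit (a sketch, not the macro's actual definition; compile with nvcc):

```cpp
// Grid-stride loop: thread t handles elements t, t + stride, t + 2 * stride, ...
// where stride is the total thread count of the launch, so any n is covered.
__global__ void ScaleKernel(float* out, const float* in, int n, float alpha) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    out[i] = alpha * in[i];
  }
}
```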
  {
    "path": "oneflow/user/kernels/diag_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef _ONEFLOW_USER_KERNELS_DIAG_KERNEL_H_\n#define _ONEFLOW_USER_KERNELS_DIAG_KERNEL_H_\n#include \"oneflow/core/ndarray/xpu_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\ntemplate<DeviceType device_type, typename T>\nstruct DiagFunctor final {\n  void operator()(ep::Stream* stream, T* out_buf, const T* in_buf, int32_t size, int32_t stride,\n                  int32_t in_dim);\n};\n\ntemplate<DeviceType device_type, typename T>\nstruct DiagGradFunctor final {\n  void operator()(ep::Stream* stream, T* dx_buf, const T* dy_buf, int32_t dx_cnt, int32_t dy_cnt,\n                  int32_t stride, int32_t in_dim);\n};\n}  // namespace\n\ntemplate<DeviceType device_type, typename T>\nclass DiagKernel final : public user_op::OpKernel {\n public:\n  DiagKernel() = default;\n  ~DiagKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const int32_t diagonal = ctx->Attr<int32_t>(\"diagonal\");\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const ShapeView& out_shape = out->shape_view();\n    const ShapeView& in_shape = in->shape_view();\n    int32_t in_dim = in_shape.NumAxes();\n    const T* in_buf = in->dptr<T>();\n    T* out_buf = out->mut_dptr<T>();\n\n    Memset<device_type>(ctx->stream(), out->mut_dptr(), 0, out_shape.elem_cnt() * sizeof(T));\n\n    if (in_dim == 1) {\n      int32_t size = in_shape.elem_cnt();\n      out_buf += (diagonal >= 0 ? diagonal : -diagonal * out_shape.At(1));\n      DiagFunctor<device_type, T>()(ctx->stream(), out_buf, in_buf, size, out_shape.At(1) + 1,\n                                    in_dim);\n    } else {\n      int32_t size = 0;\n      in_buf += (diagonal >= 0 ? 
diagonal : -diagonal * in_shape.At(1));\n      if (diagonal >= 0) {\n        size = std::min(in_shape.At(0), in_shape.At(1) - diagonal);\n      } else {\n        size = std::min(in_shape.At(0) + diagonal, in_shape.At(1));\n      }\n      DiagFunctor<device_type, T>()(ctx->stream(), out_buf, in_buf, size, in_shape.At(1) + 1,\n                                    in_dim);\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass DiagBackwardKernel final : public user_op::OpKernel {\n public:\n  DiagBackwardKernel() = default;\n  ~DiagBackwardKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    int32_t diagonal = ctx->Attr<int32_t>(\"diagonal\");\n    const ShapeView& dx_shape = dx->shape_view();\n    const ShapeView& dy_shape = dy->shape_view();\n    int32_t in_dim = dx_shape.NumAxes();\n    int32_t dy_cnt = dy_shape.Count(0);\n    int32_t dx_cnt = dx_shape.Count(0);\n    T* dx_buf = dx->mut_dptr<T>();\n    const T* dy_buf = dy->dptr<T>();\n\n    Memset<device_type>(ctx->stream(), dx->mut_dptr<T>(), 0, dx_shape.elem_cnt() * sizeof(T));\n\n    if (in_dim == 1) {\n      dy_buf += (diagonal >= 0 ? diagonal : -diagonal * dy_shape.At(1));\n      DiagGradFunctor<device_type, T>()(ctx->stream(), dx_buf, dy_buf, dx_cnt, dy_cnt,\n                                        dy_shape.At(1) + 1, in_dim);\n    } else {\n      dx_buf += (diagonal >= 0 ? diagonal : -diagonal * dx_shape.At(1));\n      DiagGradFunctor<device_type, T>()(ctx->stream(), dx_buf, dy_buf, dx_cnt, dy_cnt,\n                                        dx_shape.At(1) + 1, in_dim);\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_DIAG_KERNELS(device, dtype)                                             \\\n  REGISTER_USER_KERNEL(\"diag\").SetCreateFn<DiagKernel<device, dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == device)                                               \\\n      && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value));                  \\\n  REGISTER_USER_KERNEL(\"diag_grad\")                                                      \\\n      .SetCreateFn<DiagBackwardKernel<device, dtype>>()                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                              \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value));\n\n}  // namespace oneflow\n\n#endif  // _ONEFLOW_USER_KERNELS_DIAG_KERNEL_H_\n"
  },
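The header's kernels fold the `diagonal` attribute into a base-pointer shift before delegating to the functors: a non-negative offset starts the walk `diagonal` columns to the right, a negative one starts it `-diagonal` rows down, and consecutive diagonal elements then sit `width + 1` apart. A hypothetical helper spelling out that arithmetic:

```cpp
// Flat index of the first element of the k-th diagonal of a row-major
// height x width matrix. For width = 5: k = 2 -> index 2 (row 0, col 2);
// k = -2 -> index 10 (row 2, col 0). Subsequent elements are width + 1 apart.
inline int DiagBaseIndex(int k, int width) {
  return k >= 0 ? k : -k * width;
}
```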
  {
    "path": "oneflow/user/kernels/diagonal_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\nnamespace {\n\ntemplate<typename T>\nstruct DiagonalFunctor final {\n  void operator()(ep::Stream* stream, T* out_buf, const T* in_buf, int32_t size, int32_t dim1,\n                  int32_t dim2) {\n    int32_t offset_index = (dim1 + 1) * dim2;\n    FOR_RANGE(int32_t, index, 0, size * dim2) {\n      int32_t i = index / dim2;\n      int32_t j = index - i * dim2;\n      out_buf[j * size + i] = in_buf[i * offset_index + j];\n    }\n  }\n};\n\ntemplate<typename T>\nstruct DiagonalGradFunctor final {\n  void operator()(ep::Stream* stream, T* dx_buf, const T* dy_buf, int32_t size, int32_t dim1,\n                  int32_t dim2) {\n    int32_t offset_index = (dim1 + 1) * dim2;\n    FOR_RANGE(int32_t, index, 0, size * dim2) {\n      int32_t i = index / dim2;\n      int32_t j = index - i * dim2;\n      dx_buf[i * offset_index + j] = dy_buf[j * size + i];\n    }\n  }\n};\n\n}  // namespace\n\ntemplate<typename T>\nclass CpuDiagonalKernel final : public user_op::OpKernel {\n public:\n  CpuDiagonalKernel() = default;\n  ~CpuDiagonalKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const int32_t offset = ctx->Attr<int32_t>(\"offset\");\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const ShapeView& out_shape = out->shape_view();\n    const ShapeView& in_shape = in->shape_view();\n    const T* in_buf = in->dptr<T>();\n    T* out_buf = out->mut_dptr<T>();\n\n    int32_t size = out_shape.At(out_shape.NumAxes() - 1);\n    int32_t dim1 = in_shape.At(1);\n    int32_t dim2 = 0;\n    if (in_shape.NumAxes() <= 2) {\n      dim2 = 1;\n    } else {\n      dim2 = in_shape.Count(2, in_shape.NumAxes());\n    }\n\n    int32_t offset_in_bufer = (offset >= 0 ? 
offset * dim2 : -offset * dim1 * dim2);\n    in_buf += offset_in_bufer;\n    DiagonalFunctor<T>()(ctx->stream(), out_buf, in_buf, size, dim1, dim2);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass CpuDiagonalBackwardKernel final : public user_op::OpKernel {\n public:\n  CpuDiagonalBackwardKernel() = default;\n  ~CpuDiagonalBackwardKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    int32_t offset = ctx->Attr<int32_t>(\"offset\");\n    const ShapeView& dx_shape = dx->shape_view();\n    const ShapeView& dy_shape = dy->shape_view();\n    T* dx_buf = dx->mut_dptr<T>();\n    const T* dy_buf = dy->dptr<T>();\n\n    Memset<DeviceType::kCPU>(ctx->stream(), dx->mut_dptr<T>(), 0, dx_shape.elem_cnt() * sizeof(T));\n\n    int32_t dim1 = dx_shape.At(1);\n    int32_t dim2 = 0;\n    if (dx_shape.NumAxes() <= 2) {\n      dim2 = 1;\n    } else {\n      dim2 = dx_shape.Count(2, dx_shape.NumAxes());\n    }\n    int32_t size = dy_shape.At(dy_shape.NumAxes() - 1);\n    int32_t offset_in_bufer = (offset >= 0 ? offset * dim2 : -offset * dim1 * dim2);\n    dx_buf += offset_in_bufer;\n\n    DiagonalGradFunctor<T>()(ctx->stream(), dx_buf, dy_buf, size, dim1, dim2);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_DIAGONAL_KERNELS(dtype)                                                 \\\n  REGISTER_USER_KERNEL(\"diagonal\")                                                       \\\n      .SetCreateFn<CpuDiagonalKernel<dtype>>()                                           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                    \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"diagonal_grad\")                                                  \\\n      .SetCreateFn<CpuDiagonalBackwardKernel<dtype>>()                                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                    \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value));\n\nREGISTER_DIAGONAL_KERNELS(bool);\nREGISTER_DIAGONAL_KERNELS(float);\nREGISTER_DIAGONAL_KERNELS(double);\nREGISTER_DIAGONAL_KERNELS(int8_t);\nREGISTER_DIAGONAL_KERNELS(int32_t);\nREGISTER_DIAGONAL_KERNELS(int64_t);\n\n#undef REGISTER_DIAGONAL_KERNELS\n\n}  // namespace oneflow\n"
  },
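`diagonal` generalizes the diag stride trick to batched inputs by flattening all trailing axes into one `dim2` factor: with the input viewed as (dim0, dim1, dim2), element (i, i, j) lies at flat index i * (dim1 + 1) * dim2 + j, and the output stores the diagonal as its last axis, i.e. at j * size + i. A standalone sketch that prints the resulting gather map (illustrative values):

```cpp
#include <cstdio>

int main() {
  // Input viewed as (dim0, dim1, dim2) = (3, 4, 2), offset 0, so size = 3.
  const int dim1 = 4, dim2 = 2, size = 3;
  const int step = (dim1 + 1) * dim2;  // flat distance between diagonal steps
  for (int i = 0; i < size; ++i) {
    for (int j = 0; j < dim2; ++j) {
      std::printf("out[%d] <- in[%d]\n", j * size + i, i * step + j);
    }
  }
  return 0;
}
```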
  {
    "path": "oneflow/user/kernels/diagonal_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\nnamespace {\n\ntemplate<typename T>\n__global__ void forward_diagonal_kernel(T* out_buf, const T* in_buf, int32_t size, int32_t dim1,\n                                        int32_t dim2) {\n  int32_t offset_index = (dim1 + 1) * dim2;\n  CUDA_1D_KERNEL_LOOP(index, size * dim2) {\n    int32_t i = index / dim2;\n    int32_t j = index - i * dim2;\n    out_buf[j * size + i] = in_buf[i * offset_index + j];\n  }\n}\n\ntemplate<typename T>\n__global__ void backward_diagonal_kernel(T* dx_buf, const T* dy_buf, int32_t size, int32_t dim1,\n                                         int32_t dim2) {\n  int32_t offset_index = (dim1 + 1) * dim2;\n  CUDA_1D_KERNEL_LOOP(index, size * dim2) {\n    int32_t i = index / dim2;\n    int32_t j = index - i * dim2;\n    dx_buf[i * offset_index + j] = dy_buf[j * size + i];\n  }\n}\n\ntemplate<typename T>\nstruct DiagonalFunctor final {\n  void operator()(ep::Stream* stream, T* out_buf, const T* in_buf, int32_t size, int32_t dim1,\n                  int32_t dim2) {\n    if (size * dim2 > 0) {\n      forward_diagonal_kernel<T>\n          <<<BlocksNum4ThreadsNum(size * dim2), kCudaThreadsNumPerBlock, 0,\n             stream->As<ep::CudaStream>()->cuda_stream()>>>(out_buf, in_buf, size, dim1, dim2);\n    }\n  }\n};\n\ntemplate<typename T>\nstruct DiagonalGradFunctor final {\n  void operator()(ep::Stream* stream, T* dx_buf, const T* dy_buf, int32_t size, int32_t dim1,\n                  int32_t dim2) {\n    if (size * dim2 > 0) {\n      backward_diagonal_kernel<T>\n          <<<BlocksNum4ThreadsNum(size * dim2), kCudaThreadsNumPerBlock, 0,\n             stream->As<ep::CudaStream>()->cuda_stream()>>>(dx_buf, dy_buf, size, dim1, dim2);\n    }\n  }\n};\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuDiagonalKernel final : public user_op::OpKernel {\n public:\n  GpuDiagonalKernel() = default;\n  ~GpuDiagonalKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const int32_t offset = ctx->Attr<int32_t>(\"offset\");\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const ShapeView& out_shape = out->shape_view();\n    const ShapeView& in_shape = in->shape_view();\n    const T* in_buf = in->dptr<T>();\n    T* out_buf = out->mut_dptr<T>();\n\n    int32_t size = out_shape.At(out_shape.NumAxes() - 1);\n    int32_t dim1 = in_shape.At(1);\n    int32_t dim2 = 0;\n    if (in_shape.NumAxes() <= 2) {\n      dim2 = 1;\n    } else {\n      dim2 = in_shape.Count(2, in_shape.NumAxes());\n    }\n\n    int32_t offset_in_bufer = (offset 
>= 0 ? offset * dim2 : -offset * dim1 * dim2);\n    in_buf += offset_in_bufer;\n\n    DiagonalFunctor<T>()(ctx->stream(), out_buf, in_buf, size, dim1, dim2);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass GpuDiagonalBackwardKernel final : public user_op::OpKernel {\n public:\n  GpuDiagonalBackwardKernel() = default;\n  ~GpuDiagonalBackwardKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    int32_t offset = ctx->Attr<int32_t>(\"offset\");\n    const ShapeView& dx_shape = dx->shape_view();\n    const ShapeView& dy_shape = dy->shape_view();\n    T* dx_buf = dx->mut_dptr<T>();\n    const T* dy_buf = dy->dptr<T>();\n\n    Memset<DeviceType::kCUDA>(ctx->stream(), dx->mut_dptr<T>(), 0, dx_shape.elem_cnt() * sizeof(T));\n\n    int32_t dim1 = dx_shape.At(1);\n    int32_t dim2 = 0;\n    if (dx_shape.NumAxes() <= 2) {\n      dim2 = 1;\n    } else {\n      dim2 = dx_shape.Count(2, dx_shape.NumAxes());\n    }\n    int32_t size = dy_shape.At(dy_shape.NumAxes() - 1);\n    int32_t offset_in_bufer = (offset >= 0 ? offset * dim2 : -offset * dim1 * dim2);\n    dx_buf += offset_in_bufer;\n\n    DiagonalGradFunctor<T>()(ctx->stream(), dx_buf, dy_buf, size, dim1, dim2);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_DIAGONAL_KERNELS(dtype)                                                 \\\n  REGISTER_USER_KERNEL(\"diagonal\")                                                       \\\n      .SetCreateFn<GpuDiagonalKernel<dtype>>()                                           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                   \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"diagonal_grad\")                                                  \\\n      .SetCreateFn<GpuDiagonalBackwardKernel<dtype>>()                                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                   \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value));\n\nREGISTER_DIAGONAL_KERNELS(bool);\nREGISTER_DIAGONAL_KERNELS(half);\nREGISTER_DIAGONAL_KERNELS(float);\nREGISTER_DIAGONAL_KERNELS(double);\nREGISTER_DIAGONAL_KERNELS(int8_t);\nREGISTER_DIAGONAL_KERNELS(int32_t);\nREGISTER_DIAGONAL_KERNELS(int64_t);\n\n#undef REGISTER_DIAGONAL_KERNELS\n\n}  // namespace oneflow\n"
  },
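The index math in the diagonal kernels above is compact: each step along the diagonal advances the flattened input by (dim1 + 1) * dim2 elements, and a nonzero offset only shifts the starting pointer. Below is a minimal host-side sketch of the same addressing, assuming a row-major [dim0, dim1, dim2] input; the function name is illustrative, not OneFlow API.

```cpp
#include <cstdint>
#include <vector>

// Host-side mirror of forward_diagonal_kernel's addressing. 'size' diagonal
// positions are copied, each with its 'dim2' trailing elements, into an
// output laid out as [dim2, size].
std::vector<float> DiagonalForwardCpu(const std::vector<float>& in, int32_t offset, int32_t size,
                                      int32_t dim1, int32_t dim2) {
  std::vector<float> out(static_cast<size_t>(size) * dim2);
  const float* in_buf = in.data();
  // Same start shift as the kernel: a positive offset walks along dim1,
  // a negative offset walks along dim0.
  in_buf += (offset >= 0 ? offset * dim2 : -offset * dim1 * dim2);
  const int32_t stride = (dim1 + 1) * dim2;  // one step down the diagonal
  for (int32_t i = 0; i < size; ++i) {
    for (int32_t j = 0; j < dim2; ++j) { out[j * size + i] = in_buf[i * stride + j]; }
  }
  return out;
}
```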
  {
    "path": "oneflow/user/kernels/dim_gather_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/dim_gather_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\ntemplate<typename IN_T, typename IDX_T>\nstruct DimGatherFunctor<DeviceType::kCPU, IN_T, IDX_T> final {\n  void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& input_nd_helper,\n                  const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim, int64_t elem_cnt,\n                  int32_t dim_length, int32_t dim, const IDX_T* index, const IN_T* input,\n                  IN_T* output) {\n    DoDimGather<IN_T, IDX_T>(input_nd_helper, index_nd_helper, ndim, elem_cnt, dim_length, dim,\n                             index, input, output);\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_DIM_GATHER_FUNCTOR, (DeviceType::kCPU),\n                                 DIM_GATHER_SCATTER_DATA_TYPE_CPU_SEQ, INDEX_DATA_TYPE_SEQ);\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/dim_gather_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/dim_gather_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\ntemplate<typename IN_T, typename IDX_T>\n__global__ void DoCUDADimGather(const DimOpIndexNdHelper<IDX_T> input_nd_helper,\n                                const DimOpIndexNdHelper<IDX_T> index_nd_helper, int ndim,\n                                int64_t elem_cnt, int32_t dim_length, int32_t dim,\n                                const IDX_T* index, const IN_T* input, IN_T* output) {\n  DoDimGather<IN_T, IDX_T>(input_nd_helper, index_nd_helper, ndim, elem_cnt, dim_length, dim, index,\n                           input, output);\n}\n\ntemplate<typename IDX_T, typename IN_T>\nstruct DimGatherFunctor<DeviceType::kCUDA, IN_T, IDX_T> final {\n  void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& input_nd_helper,\n                  const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim, int64_t elem_cnt,\n                  int32_t dim_length, int32_t dim, const IDX_T* index, const IN_T* input,\n                  IN_T* output) {\n    RUN_CUDA_KERNEL((DoCUDADimGather<IN_T, IDX_T>), stream, BlocksNum4ThreadsNum(elem_cnt),\n                    input_nd_helper, index_nd_helper, ndim, elem_cnt, dim_length, dim, index, input,\n                    output);\n  }\n};\n\n// float16 special case of DimGatherFunctor template\ntemplate<typename IDX_T>\nstruct DimGatherFunctor<DeviceType::kCUDA, float16, IDX_T> final {\n  void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& input_nd_helper,\n                  const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim, int64_t elem_cnt,\n                  int32_t dim_length, int32_t dim, const IDX_T* index, const float16* input,\n                  float16* output) {\n    RUN_CUDA_KERNEL((DoCUDADimGather<half, IDX_T>), stream, BlocksNum4ThreadsNum(elem_cnt),\n                    input_nd_helper, index_nd_helper, ndim, elem_cnt, dim_length, dim, index,\n                    reinterpret_cast<const half*>(input), reinterpret_cast<half*>(output));\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_DIM_GATHER_FUNCTOR, (DeviceType::kCUDA),\n                                 DIM_GATHER_SCATTER_DATA_TYPE_CUDA_SEQ, INDEX_DATA_TYPE_SEQ);\n\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/dim_gather_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_DIM_GATHER_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_DIM_GATHER_KERNEL_UTIL_H_\n#ifdef WITH_CUDA\n#include \"oneflow/core/cuda/atomic.cuh\"\n#endif  // WITH_CUDA\n#include \"oneflow/core/ndarray/xpu_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\n\n#define DIM_GATHER_SCATTER_DATA_TYPE_CPU_SEQ \\\n  ARITHMETIC_DATA_TYPE_SEQ                   \\\n  UNSIGNED_INT_DATA_TYPE_SEQ                 \\\n  BOOL_DATA_TYPE_SEQ\n\n#define DIM_GATHER_SCATTER_DATA_TYPE_CUDA_SEQ \\\n  DIM_GATHER_SCATTER_DATA_TYPE_CPU_SEQ        \\\n  FLOAT16_DATA_TYPE_SEQ\n\nconstexpr int kDimGatherMaxDimCount = 8;\n\ntemplate<typename T>\nusing DimOpIndexNdHelper = NdIndexOffsetHelper<T, kDimGatherMaxDimCount>;\n\nnamespace user_op {\n\ntemplate<DeviceType device_type, typename IN_T, typename IDX_T>\nstruct DimGatherFunctor final {\n  void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& input_nd_helper,\n                  const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim, int64_t elem_cnt,\n                  int32_t dim_length, int32_t dim, const IDX_T* index, const IN_T* input,\n                  IN_T* output);\n};\n\ntemplate<typename IN_T, typename IDX_T>\nOF_DEVICE_FUNC void DoDimGather(const DimOpIndexNdHelper<IDX_T>& input_nd_helper,\n                                const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim,\n                                int64_t elem_cnt, int32_t dim_length, int32_t dim,\n                                const IDX_T* index, const IN_T* input, IN_T* output) {\n  XPU_1D_KERNEL_LOOP(index_offset, elem_cnt) {\n    IDX_T coordinate[kDimGatherMaxDimCount] = {0};\n    const IDX_T x = index[index_offset];\n#ifdef __CUDA_ARCH__\n    assert(x < dim_length && \"gather index is out of bounds\");\n#else\n    CHECK_LE(x, dim_length) << \"RuntimeError: index \" << x << \" is out of bounds for dimension \"\n                            << dim << \" with size \" << dim_length;\n#endif\n    index_nd_helper.OffsetToNdIndex(index_offset, coordinate, ndim);\n    coordinate[dim] = x;\n\n    IDX_T input_offset = input_nd_helper.NdIndexToOffset(coordinate, ndim);\n    output[index_offset] = input[input_offset];\n  }\n}\n\ntemplate<typename T>\nstruct DeviceAdd {\n  OF_DEVICE_FUNC static void Invoke(const T* x, T* y) {\n#ifdef __CUDA_ARCH__\n    cuda::atomic::Add(y, *x);  // TODO:(YaoChi), refine add using float16 -> half -> float -> half\n#else\n    *y += *x;\n#endif\n  };\n};\n\n// macros for functors instantiate(used by dim_gather_kernel_util.cu and dim_gather_kernel_uti.cpp)\n#define INSTANTIATE_DIM_GATHER_FUNCTOR(device_type_v, dtype_pair, itype_pair)   \\\n  template struct DimGatherFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair), \\\n                                   OF_PP_PAIR_FIRST(itype_pair)>;\n\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // 
ONEFLOW_USER_KERNELS_DIM_GATHER_KERNEL_UTIL_H_\n"
  },
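DoDimGather above is the shared CPU/CUDA body: for each flat offset into the index tensor it recovers the ND coordinate, substitutes the gathered index along `dim`, and re-linearizes against the input shape. A self-contained sketch of that arithmetic, assuming row-major layouts and hand-rolled coordinate helpers rather than OneFlow's NdIndexOffsetHelper:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Row-major offset of an ND coordinate (stand-in for NdIndexToOffset).
int64_t OffsetOf(const std::vector<int64_t>& coord, const std::vector<int64_t>& shape) {
  int64_t offset = 0;
  for (size_t d = 0; d < shape.size(); ++d) { offset = offset * shape[d] + coord[d]; }
  return offset;
}

// CPU mirror of DoDimGather: the output has the index tensor's shape, and
// output[i0,...,idim,...] = input[i0,...,index[i0,...,idim,...],...].
// `output` must be pre-sized to index.size().
void DimGatherCpu(const std::vector<float>& input, const std::vector<int64_t>& input_shape,
                  const std::vector<int64_t>& index, const std::vector<int64_t>& index_shape,
                  int dim, std::vector<float>& output) {
  const size_t ndim = index_shape.size();
  for (int64_t flat = 0; flat < static_cast<int64_t>(index.size()); ++flat) {
    std::vector<int64_t> coord(ndim);  // flat offset -> ND coordinate (OffsetToNdIndex)
    int64_t rem = flat;
    for (size_t d = ndim; d-- > 0;) {
      coord[d] = rem % index_shape[d];
      rem /= index_shape[d];
    }
    const int64_t x = index[flat];
    assert(x >= 0 && x < input_shape[dim] && "gather index is out of bounds");
    coord[dim] = x;  // substitute the gathered index along `dim`
    output[flat] = input[OffsetOf(coord, input_shape)];
  }
}
```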
  {
    "path": "oneflow/user/kernels/dim_gather_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/dim_gather_kernel_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\nnamespace {\n\ntemplate<typename IDX_T>\nvoid ConvertShape2Array(const ShapeView& shape_view, IDX_T* array, int64_t num_axis) {\n  FOR_RANGE(int64_t, i, 0, num_axis) { array[i] = shape_view.At(i); }\n}\n\n}  // namespace\n\ntemplate<DeviceType device_type, typename IN_T, typename IDX_T>\nclass DimGatherKernel final : public user_op::OpKernel {\n public:\n  DimGatherKernel() = default;\n  ~DimGatherKernel() override = default;\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    const Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    if (input_tensor->shape_view().elem_cnt() == 0) { return; }\n    const Tensor* index_tensor = ctx->Tensor4ArgNameAndIndex(\"index\", 0);\n    Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    const int32_t dim = ctx->Attr<int32_t>(\"dim\");\n\n    const IN_T* input = input_tensor->dptr<IN_T>();\n    const IDX_T* index = index_tensor->dptr<IDX_T>();\n    IN_T* output = out_tensor->mut_dptr<IN_T>();\n\n    const Shape in_shape = ExpandDimIf0D(input_tensor->shape_view());\n    const auto ndim = in_shape.NumAxes();\n    const auto dim_length = in_shape.At(dim);\n\n    DimOpIndexNdHelper<IDX_T> input_nd_helper(in_shape.data(), ndim);\n    DimOpIndexNdHelper<IDX_T> index_nd_helper(index_tensor->shape_view().data(), ndim);\n    DimGatherFunctor<device_type, IN_T, IDX_T>()(ctx->stream(), input_nd_helper, index_nd_helper,\n                                                 ndim, index_tensor->shape_view().elem_cnt(),\n                                                 dim_length, dim, index, input, output);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_DIM_GATHER_KERNEL(device, dtype_pair, itype_pair)                               \\\n  REGISTER_USER_KERNEL(\"dim_gather\")                                                             \\\n      .SetCreateFn<                                                                              \\\n          DimGatherKernel<device, OF_PP_PAIR_FIRST(dtype_pair), OF_PP_PAIR_FIRST(itype_pair)>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                      \\\n                       && (user_op::HobDataType(\"input\", 0) == OF_PP_PAIR_SECOND(dtype_pair))    \\\n                       && (user_op::HobDataType(\"index\", 0) == OF_PP_PAIR_SECOND(itype_pair)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n    REGISTER_DIM_GATHER_KERNEL, (DeviceType::kCPU),\n    ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n#ifdef WITH_CUDA\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_DIM_GATHER_KERNEL, 
(DeviceType::kCUDA),\n                                 ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ\n                                     FLOAT16_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n#endif  // WITH_CUDA\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
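The OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE calls above stamp out one registration per (value type, index type) pair, with OF_PP_PAIR_SECOND supplying the DataType enum half of each pair. For the (float, int32_t) CPU combination, the REGISTER_DIM_GATHER_KERNEL expansion comes out roughly as:

```cpp
REGISTER_USER_KERNEL("dim_gather")
    .SetCreateFn<DimGatherKernel<DeviceType::kCPU, float, int32_t>>()
    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)
                     && (user_op::HobDataType("input", 0) == DataType::kFloat)
                     && (user_op::HobDataType("index", 0) == DataType::kInt32));
```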
  {
    "path": "oneflow/user/kernels/dim_scatter_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/dim_scatter_kernel_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\ntemplate<typename IN_T, typename IDX_T, template<typename T> class Opt>\nstruct DimScatterFunctor<DeviceType::kCPU, IN_T, IDX_T, Opt> final {\n  void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& src_nd_helper,\n                  const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,\n                  const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,\n                  const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound,\n                  const IDX_T* index, const IN_T* src, IN_T* output) {\n    DoDimScatter<IN_T, IDX_T, Opt>(src_nd_helper, idx_nd_helper, output_nd_helper, ndim, elem_cnt,\n                                   dim, upper_bound, index, src, output);\n  }\n};\n\nINSTANTIATE_DIM_SCATTER_CPU_FUNCTORS(DeviceType::kCPU, BinOpAddFunctor);\nINSTANTIATE_DIM_SCATTER_CPU_FUNCTORS(DeviceType::kCPU, BinOpMulFunctor);\nINSTANTIATE_DIM_SCATTER_CPU_FUNCTORS(DeviceType::kCPU, BinOpUpdateFunctor);\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/dim_scatter_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/user/kernels/dim_scatter_kernel_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\ntemplate<typename IN_T, typename IDX_T, template<typename T> class Opt>\n__global__ void DoCUDADimScatter(const DimOpIndexNdHelper<IDX_T> src_nd_helper,\n                                 const DimOpIndexNdHelper<IDX_T> idx_nd_helper,\n                                 const DimOpIndexNdHelper<IDX_T> output_nd_helper, const int ndim,\n                                 const int64_t elem_cnt, const int32_t dim,\n                                 const int64_t upper_bound, const IDX_T* index, const IN_T* src,\n                                 IN_T* output) {\n  DoDimScatter<IN_T, IDX_T, Opt>(src_nd_helper, idx_nd_helper, output_nd_helper, ndim, elem_cnt,\n                                 dim, upper_bound, index, src, output);\n}\n\ntemplate<typename IN_T, typename IDX_T, template<typename T> class Opt>\nstruct DimScatterFunctor<DeviceType::kCUDA, IN_T, IDX_T, Opt> final {\n  void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& src_nd_helper,\n                  const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,\n                  const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,\n                  const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound,\n                  const IDX_T* index, const IN_T* src, IN_T* output) {\n    RUN_CUDA_KERNEL((DoCUDADimScatter<IN_T, IDX_T, Opt>), stream, BlocksNum4ThreadsNum(elem_cnt),\n                    src_nd_helper, idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim,\n                    upper_bound, index, src, output);\n  }\n};\n\ntemplate<typename IDX_T, template<typename T> class Opt>\nstruct DimScatterFunctor<DeviceType::kCUDA, float16, IDX_T, Opt> final {\n  void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& src_nd_helper,\n                  const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,\n                  const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,\n                  const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound,\n                  const IDX_T* index, const float16* src, float16* output) {\n    RUN_CUDA_KERNEL((DoCUDADimScatter<half, IDX_T, Opt>), stream, BlocksNum4ThreadsNum(elem_cnt),\n                    src_nd_helper, idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim,\n                    upper_bound, index, reinterpret_cast<const half*>(src),\n                    reinterpret_cast<half*>(output));\n  }\n};\n\nINSTANTIATE_DIM_SCATTER_CUDA_FUNCTORS(DeviceType::kCUDA, BinOpAddFunctor);\nINSTANTIATE_DIM_SCATTER_CUDA_FUNCTORS(DeviceType::kCUDA, BinOpMulFunctor);\nINSTANTIATE_DIM_SCATTER_CUDA_FUNCTORS(DeviceType::kCUDA, BinOpUpdateFunctor);\n\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/dim_scatter_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_DIM_SCATTER_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_DIM_SCATTER_KERNEL_UTIL_H_\n#ifdef WITH_CUDA\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include <cuda_fp16.h>\n#endif  // WITH_CUDA\n\n#include \"oneflow/core/ndarray/xpu_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/common/error.pb.h\"\n\nnamespace oneflow {\n\n#define NO_HALF_UTIL_FOUND         \\\n  printf(\"cuda arch must >= 530\"); \\\n  assert(false)\n\nnamespace user_op {\n\nconstexpr int kDimGatherMaxDimCount = 8;\n\ntemplate<typename T>\nusing DimOpIndexNdHelper = NdIndexOffsetHelper<T, kDimGatherMaxDimCount>;\n\n#define INSTANTIATE_DIM_SCATTER_CPU_FUNCTORS(device_type, opt)           \\\n  template struct DimScatterFunctor<device_type, bool, int32_t, opt>;    \\\n  template struct DimScatterFunctor<device_type, uint8_t, int32_t, opt>; \\\n  template struct DimScatterFunctor<device_type, int8_t, int32_t, opt>;  \\\n  template struct DimScatterFunctor<device_type, int32_t, int32_t, opt>; \\\n  template struct DimScatterFunctor<device_type, int64_t, int32_t, opt>; \\\n  template struct DimScatterFunctor<device_type, float, int32_t, opt>;   \\\n  template struct DimScatterFunctor<device_type, double, int32_t, opt>;  \\\n  template struct DimScatterFunctor<device_type, float16, int32_t, opt>; \\\n  template struct DimScatterFunctor<device_type, bool, int64_t, opt>;    \\\n  template struct DimScatterFunctor<device_type, uint8_t, int64_t, opt>; \\\n  template struct DimScatterFunctor<device_type, int8_t, int64_t, opt>;  \\\n  template struct DimScatterFunctor<device_type, int32_t, int64_t, opt>; \\\n  template struct DimScatterFunctor<device_type, int64_t, int64_t, opt>; \\\n  template struct DimScatterFunctor<device_type, float, int64_t, opt>;   \\\n  template struct DimScatterFunctor<device_type, double, int64_t, opt>;  \\\n  template struct DimScatterFunctor<device_type, float16, int64_t, opt>;\n\n#define INSTANTIATE_DIM_SCATTER_CUDA_FUNCTORS(device_type, opt)          \\\n  template struct DimScatterFunctor<device_type, bool, int32_t, opt>;    \\\n  template struct DimScatterFunctor<device_type, uint8_t, int32_t, opt>; \\\n  template struct DimScatterFunctor<device_type, int8_t, int32_t, opt>;  \\\n  template struct DimScatterFunctor<device_type, int32_t, int32_t, opt>; \\\n  template struct DimScatterFunctor<device_type, int64_t, int32_t, opt>; \\\n  template struct DimScatterFunctor<device_type, float, int32_t, opt>;   \\\n  template struct DimScatterFunctor<device_type, double, int32_t, opt>;  \\\n  template struct DimScatterFunctor<device_type, half, int32_t, opt>;    \\\n  template struct DimScatterFunctor<device_type, bool, int64_t, opt>;    \\\n  template struct DimScatterFunctor<device_type, 
uint8_t, int64_t, opt>; \\\n  template struct DimScatterFunctor<device_type, int8_t, int64_t, opt>;  \\\n  template struct DimScatterFunctor<device_type, int32_t, int64_t, opt>; \\\n  template struct DimScatterFunctor<device_type, int64_t, int64_t, opt>; \\\n  template struct DimScatterFunctor<device_type, float, int64_t, opt>;   \\\n  template struct DimScatterFunctor<device_type, double, int64_t, opt>;  \\\n  template struct DimScatterFunctor<device_type, half, int64_t, opt>;\n\ntemplate<typename T>\nstruct BinOpAddFunctor {\n  OF_DEVICE_FUNC static void apply(const T* x, T* y) {\n#ifdef __CUDA_ARCH__\n    cuda::atomic::Add(y, *x);\n#else\n    *y += *x;\n#endif\n  }\n};\n\n#ifdef WITH_CUDA\ntemplate<>\nstruct BinOpAddFunctor<half> {\n  OF_DEVICE_FUNC static void apply(const half* x, half* y) {\n#ifdef __CUDA_ARCH__\n    *y = __float2half(__half2float(*x) + __half2float(*y));\n#else\n    NO_HALF_UTIL_FOUND;\n#endif\n  }\n};\n#endif\n\n#define SPECIALIZE_BIN_OP_ADD_FUNCTOR(name, dtype)                           \\\n  template<>                                                                 \\\n  struct name<dtype> {                                                       \\\n    OF_DEVICE_FUNC static void apply(const dtype* x, dtype* y) { *y += *x; } \\\n  };\n\nSPECIALIZE_BIN_OP_ADD_FUNCTOR(BinOpAddFunctor, bool)\nSPECIALIZE_BIN_OP_ADD_FUNCTOR(BinOpAddFunctor, int8_t)\nSPECIALIZE_BIN_OP_ADD_FUNCTOR(BinOpAddFunctor, uint8_t)\nSPECIALIZE_BIN_OP_ADD_FUNCTOR(BinOpAddFunctor, int64_t)\n\ntemplate<typename T>\nstruct BinOpMulFunctor {\n  OF_DEVICE_FUNC static void apply(const T* x, T* y) {\n#ifdef __CUDA_ARCH__\n    cuda::atomic::Mul(y, *x);\n#else\n    *y *= *x;\n#endif\n  }\n};\n\n#ifdef WITH_CUDA\ntemplate<>\nstruct BinOpMulFunctor<half> {\n  OF_DEVICE_FUNC static void apply(const half* x, half* y) {\n#ifdef __CUDA_ARCH__\n    *y = __float2half(__half2float(*x) * __half2float(*y));\n#else\n    NO_HALF_UTIL_FOUND;\n#endif\n  }\n};\n#endif\n\n#define SPECIALIZE_BIN_OP_MUL_FUNCTOR(name, dtype)                           \\\n  template<>                                                                 \\\n  struct name<dtype> {                                                       \\\n    OF_DEVICE_FUNC static void apply(const dtype* x, dtype* y) { *y *= *x; } \\\n  };\n\nSPECIALIZE_BIN_OP_MUL_FUNCTOR(BinOpMulFunctor, int8_t)\nSPECIALIZE_BIN_OP_MUL_FUNCTOR(BinOpMulFunctor, uint8_t)\nSPECIALIZE_BIN_OP_MUL_FUNCTOR(BinOpMulFunctor, int64_t)\n\ntemplate<>\nstruct BinOpMulFunctor<bool> {\n  OF_DEVICE_FUNC static void apply(const bool* x, bool* y) { *y &= *x; }\n};\n\ntemplate<typename T>\nstruct BinOpUpdateFunctor {\n  OF_DEVICE_FUNC static void apply(const T* x, T* y) { *y = *x; }\n};\n\ntemplate<DeviceType device_type, typename IN_T, typename IDX_T, template<typename T> class Opt>\nstruct DimScatterFunctor final {\n  void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& src_nd_helper,\n                  const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,\n                  const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,\n                  const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound,\n                  const IDX_T* index, const IN_T* src, IN_T* output);\n};\n\ntemplate<typename IN_T, typename IDX_T, template<typename T> class Opt>\nOF_DEVICE_FUNC void DoDimScatter(const DimOpIndexNdHelper<IDX_T>& src_nd_helper,\n                                 const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,\n                                 const 
DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,\n                                 const int64_t elem_cnt, const int32_t dim, int64_t upper_bound,\n                                 const IDX_T* index, const IN_T* src, IN_T* output) {\n  XPU_1D_KERNEL_LOOP(idx_offset, elem_cnt) {\n    IDX_T coordinate[kDimGatherMaxDimCount] = {0};\n    idx_nd_helper.OffsetToNdIndex(idx_offset, coordinate, ndim);  // idx_offset -> ijk\n    IDX_T idx_elem = index[idx_offset];\n    if (upper_bound != 0 && idx_elem >= upper_bound) {\n#if __CUDA_ARCH__\n      __trap();\n#else\n      UNIMPLEMENTED() << \"The index element \" << idx_elem << \" is out of bounds for dimension \"\n                      << dim << \" with size \" << upper_bound << \".\";\n#endif\n    }\n    IDX_T src_offset = src_nd_helper.NdIndexToOffset(coordinate, ndim);\n    coordinate[dim] = idx_elem;\n    IDX_T output_offset = output_nd_helper.NdIndexToOffset(coordinate, ndim);\n    Opt<IN_T>::apply(src + src_offset, output + output_offset);\n  }\n}\n\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_DIM_SCATTER_KERNEL_UTIL_H_\n"
  },
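DoDimScatter above reuses the pre-substitution coordinate to address src, swaps in the index value along `dim`, and only then addresses output; the Opt functor decides whether the write is an add, a multiply, or a plain update. Below is a host-side sketch under the simplifying assumption that src and index share one shape (the real code only requires index to fit within src), with an exception standing in for __trap()/UNIMPLEMENTED():

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// CPU mirror of DoDimScatter. The functor parameter plays the role of
// BinOpAddFunctor / BinOpMulFunctor / BinOpUpdateFunctor.
template<typename BinOp>
void DimScatterCpu(const std::vector<float>& src, const std::vector<int64_t>& src_shape,
                   const std::vector<int64_t>& index, int dim, int64_t upper_bound,
                   const std::vector<int64_t>& out_shape, std::vector<float>& output, BinOp op) {
  const size_t ndim = src_shape.size();
  for (int64_t flat = 0; flat < static_cast<int64_t>(index.size()); ++flat) {
    std::vector<int64_t> coord(ndim);  // flat offset -> ND coordinate
    int64_t rem = flat;
    for (size_t d = ndim; d-- > 0;) {
      coord[d] = rem % src_shape[d];
      rem /= src_shape[d];
    }
    const int64_t idx_elem = index[flat];
    if (idx_elem < 0 || idx_elem >= upper_bound) {
      throw std::out_of_range("scatter index out of bounds");
    }
    int64_t src_off = 0;  // src is read at the *original* coordinate
    for (size_t d = 0; d < ndim; ++d) { src_off = src_off * src_shape[d] + coord[d]; }
    coord[dim] = idx_elem;  // redirect the write along `dim`
    int64_t out_off = 0;
    for (size_t d = 0; d < ndim; ++d) { out_off = out_off * out_shape[d] + coord[d]; }
    op(src[src_off], &output[out_off]);  // e.g. add: *y += x; update: *y = x
  }
}
```

For example, `DimScatterCpu(src, shape, index, 0, out_shape[0], out_shape, out, [](float x, float* y) { *y += x; });` reproduces the dim_scatter_add behavior.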
  {
    "path": "oneflow/user/kernels/dim_scatter_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/error.pb.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/user/kernels/dim_scatter_kernel_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\ntemplate<DeviceType device_type, typename IN_T, typename IDX_T, template<typename T> class Opt>\nclass DimScatterKernel final : public user_op::OpKernel {\n public:\n  DimScatterKernel() = default;\n  ~DimScatterKernel() override = default;\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    const Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const Tensor* index_tensor = ctx->Tensor4ArgNameAndIndex(\"index\", 0);\n    Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    const Tensor* src_tensor = ctx->Tensor4ArgNameAndIndex(\"src\", 0);\n    const int32_t dim = ctx->Attr<int32_t>(\"dim\");\n\n    const IDX_T* index = index_tensor->dptr<IDX_T>();\n    IN_T* output = out_tensor->mut_dptr<IN_T>();\n    size_t out_bytes_size =\n        out_tensor->shape_view().elem_cnt() * GetSizeOfDataType(out_tensor->data_type());\n\n    Tensor* like_tensor = ctx->Tensor4ArgNameAndIndex(\"like\", 0);\n    const IN_T* src = src_tensor->dptr<IN_T>();\n\n    if (input_tensor) {\n      Memcpy<device_type>(ctx->stream(), output, input_tensor->dptr<IN_T>(), out_bytes_size);\n    } else if (like_tensor) {\n      Memset<device_type>(ctx->stream(), output, 0, out_bytes_size);\n    } else {\n      UNIMPLEMENTED() << \"Input tensor and like tensor cannot be empty simultaneously.\";\n    }\n\n    const Shape src_shape = ExpandDimIf0D(src_tensor->shape_view());\n    const Shape index_shape = ExpandDimIf0D(index_tensor->shape_view());\n    const int ndim = src_shape.NumAxes();\n    DimOpIndexNdHelper<IDX_T> src_nd_helper(src_shape.data(), ndim);\n    DimOpIndexNdHelper<IDX_T> idx_nd_helper(index_shape.data(), ndim);\n    DimOpIndexNdHelper<IDX_T> output_nd_helper(out_tensor->shape_view().data(), ndim);\n\n    const int64_t upper_bound = [&]() {\n      if (input_tensor) {\n        const Shape input_shape = ExpandDimIf0D(input_tensor->shape_view());\n        return input_shape.At(dim);\n      } else {\n        const Shape like_shape = ExpandDimIf0D(like_tensor->shape_view());\n        return like_shape.At(dim);\n      }\n    }();\n\n    DimScatterFunctor<device_type, IN_T, IDX_T, Opt>()(\n        ctx->stream(), src_nd_helper, idx_nd_helper, output_nd_helper, ndim, index_shape.elem_cnt(),\n        dim, upper_bound, index, src, output);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, device, dtype, itype, opt)             \\\n  REGISTER_USER_KERNEL(op_type)                                                          \\\n      .SetCreateFn<DimScatterKernel<device, dtype, itype, opt>>()                        \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)     
                         \\\n                       && (user_op::HobDataType(\"like\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"index\", 0) == GetDataType<itype>::value));\n\n#define REGISTER_DIM_SCATTER_LIKE_CPU_KERNELS(op_type, opt)                           \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, bool, int32_t, opt);    \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, float, int32_t, opt);   \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, double, int32_t, opt);  \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, float16, int32_t, opt); \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, int32_t, int32_t, opt); \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, bool, int64_t, opt);    \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, float, int64_t, opt);   \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, double, int64_t, opt);  \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, float16, int64_t, opt); \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, int32_t, int64_t, opt);\n\n#define REGISTER_DIM_SCATTER_LIKE_CUDA_KERNELS(op_type, opt)                           \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, bool, int32_t, opt);    \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, float, int32_t, opt);   \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, double, int32_t, opt);  \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, half, int32_t, opt);    \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, int32_t, int32_t, opt); \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, bool, int64_t, opt);    \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, float, int64_t, opt);   \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, double, int64_t, opt);  \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, half, int64_t, opt);    \\\n  REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCUDA, int32_t, int64_t, opt);\n\n#define REGISTER_DIM_SCATTER_KERNEL(op_type, device, dtype_pair, itype_pair, opt)             \\\n  REGISTER_USER_KERNEL(#op_type)                                                              \\\n      .SetCreateFn<DimScatterKernel<device, OF_PP_PAIR_FIRST(dtype_pair),                     \\\n                                    OF_PP_PAIR_FIRST(itype_pair), opt>>()                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                   \\\n                       && (user_op::HobDataType(\"input\", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \\\n                       && (user_op::HobDataType(\"index\", 0) == OF_PP_PAIR_SECOND(itype_pair)));\n\n#define REGISTER_DIM_SCATTER_CPU_KERNELS(dtype_pair, itype_pair)                            \\\n  REGISTER_DIM_SCATTER_KERNEL(dim_scatter_add, DeviceType::kCPU, dtype_pair, itype_pair,    \\\n                              BinOpAddFunctor);                                             \\\n  REGISTER_DIM_SCATTER_KERNEL(dim_scatter_mul, DeviceType::kCPU, dtype_pair, itype_pair,    \\\n                              BinOpMulFunctor);                                             \\\n  REGISTER_DIM_SCATTER_KERNEL(dim_scatter_update, DeviceType::kCPU, dtype_pair, itype_pair, \\\n                              BinOpUpdateFunctor);\n\n#define 
REGISTER_DIM_SCATTER_CUDA_KERNELS(dtype_pair, itype_pair)                            \\\n  REGISTER_DIM_SCATTER_KERNEL(dim_scatter_add, DeviceType::kCUDA, dtype_pair, itype_pair,    \\\n                              BinOpAddFunctor);                                              \\\n  REGISTER_DIM_SCATTER_KERNEL(dim_scatter_mul, DeviceType::kCUDA, dtype_pair, itype_pair,    \\\n                              BinOpMulFunctor);                                              \\\n  REGISTER_DIM_SCATTER_KERNEL(dim_scatter_update, DeviceType::kCUDA, dtype_pair, itype_pair, \\\n                              BinOpUpdateFunctor);\n\nREGISTER_DIM_SCATTER_LIKE_CPU_KERNELS(\"dim_scatter_add_like\", BinOpAddFunctor);\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_DIM_SCATTER_CPU_KERNELS,\n                                 ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ\n                                     BOOL_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n\n#ifdef WITH_CUDA\nREGISTER_DIM_SCATTER_LIKE_CUDA_KERNELS(\"dim_scatter_add_like\", BinOpAddFunctor);\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_DIM_SCATTER_CUDA_KERNELS,\n                                 ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ\n                                     BOOL_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n#endif  // WITH_CUDA\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/dim_scatter_scalar_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/dim_scatter_scalar_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\ntemplate<typename IN_T, typename IDX_T, template<typename T> class Opt>\nstruct DimScatterScalarFunctor<DeviceType::kCPU, IN_T, IDX_T, Opt> final {\n  void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,\n                  const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,\n                  const int64_t elem_cnt, const int32_t dim, int64_t upper_bound,\n                  const IDX_T* index, const IN_T src, IN_T* output) {\n    DoScatterScalarFunctor<IN_T, IDX_T, Opt>(idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim,\n                                             upper_bound, index, src, output);\n  }\n};\n\nINSTANTIATE_DIM_SCATTER_SCARLAR_CPU_FUNCTORS(DeviceType::kCPU, UpdateScalarFunctor);\nINSTANTIATE_DIM_SCATTER_SCARLAR_CPU_FUNCTORS(DeviceType::kCPU, AddScalarFunctor);\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/dim_scatter_scalar_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/user/kernels/dim_scatter_scalar_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\ntemplate<typename IN_T, typename IDX_T, template<typename T> class Opt>\n__global__ void DoCUDADimScatterScalar(const DimOpIndexNdHelper<IDX_T> idx_nd_helper,\n                                       const DimOpIndexNdHelper<IDX_T> output_nd_helper,\n                                       const int ndim, const int64_t elem_cnt, const int32_t dim,\n                                       const int64_t upper_bound, const IDX_T* index,\n                                       const IN_T src_scalar, IN_T* output) {\n  DoScatterScalarFunctor<IN_T, IDX_T, Opt>(idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim,\n                                           upper_bound, index, src_scalar, output);\n}\n\ntemplate<typename IN_T, typename IDX_T, template<typename T> class Opt>\nstruct DimScatterScalarFunctor<DeviceType::kCUDA, IN_T, IDX_T, Opt> final {\n  void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,\n                  const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,\n                  const int64_t elem_cnt, const int32_t dim, int64_t upper_bound,\n                  const IDX_T* index, const IN_T src, IN_T* output) {\n    RUN_CUDA_KERNEL((DoCUDADimScatterScalar<IN_T, IDX_T, Opt>), stream,\n                    BlocksNum4ThreadsNum(elem_cnt), idx_nd_helper, output_nd_helper, ndim, elem_cnt,\n                    dim, upper_bound, index, src, output);\n  }\n};\n\ntemplate<typename IDX_T, template<typename T> class Opt>\nstruct DimScatterScalarFunctor<DeviceType::kCUDA, float16, IDX_T, Opt> final {\n  void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,\n                  const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,\n                  const int64_t elem_cnt, const int32_t dim, int64_t upper_bound,\n                  const IDX_T* index, const float16 src, float16* output) {\n    RUN_CUDA_KERNEL((DoCUDADimScatterScalar<half, IDX_T, Opt>), stream,\n                    BlocksNum4ThreadsNum(elem_cnt), idx_nd_helper, output_nd_helper, ndim, elem_cnt,\n                    dim, upper_bound, index, src, reinterpret_cast<half*>(output));\n  }\n};\n\nINSTANTIATE_DIM_SCATTER_SCARLAR_CUDA_FUNCTORS(DeviceType::kCUDA, UpdateScalarFunctor);\nINSTANTIATE_DIM_SCATTER_SCARLAR_CUDA_FUNCTORS(DeviceType::kCUDA, AddScalarFunctor);\n\n}  // namespace user_op\n}  // namespace oneflow\n#endif\n"
  },
  {
    "path": "oneflow/user/kernels/dim_scatter_scalar_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_DIM_SCATTER_SCALAR_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_DIM_SCATTER_SCALAR_KERNEL_UTIL_H_\n#ifdef WITH_CUDA\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include <cuda_fp16.h>\n#endif  // WITH_CUDA\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/ndarray/xpu_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/data_type.h\"\n\nnamespace oneflow {\n\n#define NO_HALF_UTIL_FOUND         \\\n  printf(\"cuda arch must >= 530\"); \\\n  assert(false)\n\nnamespace user_op {\n\nconstexpr int kDimGatherMaxDimCount = 8;\n\ntemplate<typename T>\nstruct AddScalarFunctor {\n  OF_DEVICE_FUNC static void apply(const T x, T* y) {\n#ifdef __CUDA_ARCH__\n    cuda::atomic::Add(y, x);\n#else\n    *y += x;\n#endif\n  }\n};\n\n#ifdef WITH_CUDA\ntemplate<>\nstruct AddScalarFunctor<half> {\n  OF_DEVICE_FUNC static void apply(const half x, half* y) {\n#if __CUDA_ARCH__\n    *y = __float2half(__half2float(*y) + __half2float(x));\n#else\n    NO_HALF_UTIL_FOUND;\n#endif\n  }\n};\n#endif\n\ntemplate<>\nstruct AddScalarFunctor<int8_t> {\n  OF_DEVICE_FUNC static void apply(const int8_t x, int8_t* y) { *y += x; }\n};\n\ntemplate<>\nstruct AddScalarFunctor<uint8_t> {\n  OF_DEVICE_FUNC static void apply(const uint8_t x, uint8_t* y) { *y += x; }\n};\n\ntemplate<>\nstruct AddScalarFunctor<int64_t> {\n  OF_DEVICE_FUNC static void apply(const int64_t x, int64_t* y) { *y += x; }\n};\n\ntemplate<typename T>\nstruct UpdateScalarFunctor {\n  OF_DEVICE_FUNC static void apply(const T x, T* y) { *y = x; }\n};\n\n#define INSTANTIATE_DIM_SCATTER_SCARLAR_CPU_FUNCTORS(device_type, opt)         \\\n  template struct DimScatterScalarFunctor<device_type, uint8_t, int32_t, opt>; \\\n  template struct DimScatterScalarFunctor<device_type, int8_t, int32_t, opt>;  \\\n  template struct DimScatterScalarFunctor<device_type, int32_t, int32_t, opt>; \\\n  template struct DimScatterScalarFunctor<device_type, int64_t, int32_t, opt>; \\\n  template struct DimScatterScalarFunctor<device_type, float, int32_t, opt>;   \\\n  template struct DimScatterScalarFunctor<device_type, double, int32_t, opt>;  \\\n  template struct DimScatterScalarFunctor<device_type, float16, int32_t, opt>; \\\n  template struct DimScatterScalarFunctor<device_type, uint8_t, int64_t, opt>; \\\n  template struct DimScatterScalarFunctor<device_type, int8_t, int64_t, opt>;  \\\n  template struct DimScatterScalarFunctor<device_type, int32_t, int64_t, opt>; \\\n  template struct DimScatterScalarFunctor<device_type, int64_t, int64_t, opt>; \\\n  template struct DimScatterScalarFunctor<device_type, float, int64_t, opt>;   \\\n  template struct DimScatterScalarFunctor<device_type, double, int64_t, opt>;  \\\n  template struct DimScatterScalarFunctor<device_type, float16, int64_t, opt>;\n\n#define 
INSTANTIATE_DIM_SCATTER_SCARLAR_CUDA_FUNCTORS(device_type, opt)        \\\n  template struct DimScatterScalarFunctor<device_type, uint8_t, int32_t, opt>; \\\n  template struct DimScatterScalarFunctor<device_type, int8_t, int32_t, opt>;  \\\n  template struct DimScatterScalarFunctor<device_type, int32_t, int32_t, opt>; \\\n  template struct DimScatterScalarFunctor<device_type, int64_t, int32_t, opt>; \\\n  template struct DimScatterScalarFunctor<device_type, float, int32_t, opt>;   \\\n  template struct DimScatterScalarFunctor<device_type, double, int32_t, opt>;  \\\n  template struct DimScatterScalarFunctor<device_type, half, int32_t, opt>;    \\\n  template struct DimScatterScalarFunctor<device_type, uint8_t, int64_t, opt>; \\\n  template struct DimScatterScalarFunctor<device_type, int8_t, int64_t, opt>;  \\\n  template struct DimScatterScalarFunctor<device_type, int32_t, int64_t, opt>; \\\n  template struct DimScatterScalarFunctor<device_type, int64_t, int64_t, opt>; \\\n  template struct DimScatterScalarFunctor<device_type, float, int64_t, opt>;   \\\n  template struct DimScatterScalarFunctor<device_type, double, int64_t, opt>;  \\\n  template struct DimScatterScalarFunctor<device_type, half, int64_t, opt>;\n\ntemplate<typename T>\nusing DimOpIndexNdHelper = NdIndexOffsetHelper<T, kDimGatherMaxDimCount>;\n\ntemplate<DeviceType device_type, typename IN_T, typename IDX_T, template<typename T> class Opt>\nstruct DimScatterScalarFunctor final {\n  void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,\n                  const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,\n                  const int64_t elem_cnt, const int32_t dim, int64_t upper_bound,\n                  const IDX_T* index, const IN_T src, IN_T* output);\n};\n\ntemplate<typename IN_T, typename IDX_T, template<typename T> class Opt>\nOF_DEVICE_FUNC void DoScatterScalarFunctor(const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,\n                                           const DimOpIndexNdHelper<IDX_T>& output_nd_helper,\n                                           const int ndim, const int64_t elem_cnt,\n                                           const int32_t dim, int64_t upper_bound,\n                                           const IDX_T* index, const IN_T src, IN_T* output) {\n  XPU_1D_KERNEL_LOOP(idx_offset, elem_cnt) {\n    IDX_T coordinate[kDimGatherMaxDimCount] = {0};\n\n    idx_nd_helper.OffsetToNdIndex(idx_offset, coordinate, ndim);  // idx_offset -> ijk\n    IDX_T idx_elem = index[idx_offset];\n    if (idx_elem >= upper_bound) {\n#if __CUDA_ARCH__\n      __trap();\n#else\n      UNIMPLEMENTED() << \"The index element \" << idx_elem << \" is out of bounds for dimension \"\n                      << dim << \" with size \" << upper_bound << \".\";\n#endif\n    }\n    coordinate[dim] = idx_elem;\n    IDX_T output_offset = output_nd_helper.NdIndexToOffset(coordinate, ndim);\n    Opt<IN_T>::apply(src, output + output_offset);\n  }\n}\n\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_DIM_SCATTER_SCALAR_KERNEL_UTIL_H_\n"
  },
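The scalar variant above differs from tensor scatter only in that every selected slot combines with the same scalar, so there is no src offset to compute. A deliberately minimal 1-D illustration (the real code handles arbitrary rank and traps on out-of-range indices rather than skipping them; the function name is illustrative):

```cpp
#include <cstdint>
#include <vector>

// 1-D sketch of dim_scatter_add_scalar semantics.
void DimScatterScalarAdd1D(const std::vector<int64_t>& index, float src, int64_t upper_bound,
                           std::vector<float>& output) {
  for (size_t i = 0; i < index.size(); ++i) {
    const int64_t idx = index[i];
    if (idx < 0 || idx >= upper_bound) { continue; }  // kernel __trap()s / UNIMPLEMENTED()s here
    output[idx] += src;  // AddScalarFunctor; UpdateScalarFunctor would assign instead
  }
}
```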
  {
    "path": "oneflow/user/kernels/dim_scatter_scalar_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/dim_scatter_scalar_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\ntemplate<DeviceType device_type, typename IN_T, typename IDX_T, template<typename T> class Opt>\nclass DimScatterScalarKernel final : public user_op::OpKernel {\n public:\n  DimScatterScalarKernel() = default;\n  ~DimScatterScalarKernel() = default;\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    const Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const Tensor* index_tensor = ctx->Tensor4ArgNameAndIndex(\"index\", 0);\n    Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    const int32_t dim = ctx->Attr<int32_t>(\"dim\");\n\n    const IDX_T* index = index_tensor->dptr<IDX_T>();\n    IN_T* output = out_tensor->mut_dptr<IN_T>();\n    size_t out_bytes_size =\n        out_tensor->shape_view().elem_cnt() * GetSizeOfDataType(out_tensor->data_type());\n\n    Tensor* like_tensor = ctx->Tensor4ArgNameAndIndex(\"like\", 0);\n    const IN_T src_scalar = static_cast<IN_T>(ctx->Attr<float>(\"src_scalar\"));\n\n    if (input_tensor) {\n      Memcpy<device_type>(ctx->stream(), output, input_tensor->dptr<IN_T>(), out_bytes_size);\n    } else if (like_tensor) {\n      Memset<device_type>(ctx->stream(), output, 0, out_bytes_size);\n    } else {\n      UNIMPLEMENTED() << \"Input tensor and like tensor cannot be empty simultaneously.\";\n    }\n\n    const int ndim = out_tensor->shape_view().NumAxes();\n    small_vector<IDX_T, kDimGatherMaxDimCount> shape_vec(ndim);\n    auto shape2dims = [&shape_vec, &ndim](const ShapeView& tensor_shape) -> void {\n      std::transform(tensor_shape.ptr(), tensor_shape.ptr() + ndim, shape_vec.begin(),\n                     [](int32_t dim) -> IDX_T { return static_cast<IDX_T>(dim); });\n    };\n    shape2dims(index_tensor->shape_view());\n    DimOpIndexNdHelper<IDX_T> idx_nd_helper(shape_vec.data(), ndim);\n    shape2dims(out_tensor->shape_view());\n    DimOpIndexNdHelper<IDX_T> output_nd_helper(shape_vec.data(), ndim);\n\n    int64_t upper_bound = 0;\n    if (input_tensor) {\n      upper_bound =\n          input_tensor->shape_view().At(dim);  // ensure the idx is smaller than upperbound\n    } else {\n      upper_bound = like_tensor->shape_view().At(dim);  // ensure the idx is smaller than upperbound\n    }\n\n    DimScatterScalarFunctor<device_type, IN_T, IDX_T, Opt>()(\n        ctx->stream(), idx_nd_helper, output_nd_helper, ndim, index_tensor->shape_view().elem_cnt(),\n        dim, upper_bound, index, src_scalar, output);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SCATTERSCALAR_KERNEL(op_type_name, device, dtype_pair, itype_pair, opt)      \\\n  REGISTER_USER_KERNEL(#op_type_name)                                                         \\\n      .SetCreateFn<DimScatterScalarKernel<device, 
OF_PP_PAIR_FIRST(dtype_pair),               \\\n                                          OF_PP_PAIR_FIRST(itype_pair), opt>>()               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                   \\\n                       && (user_op::HobDataType(\"input\", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \\\n                       && (user_op::HobDataType(\"index\", 0) == OF_PP_PAIR_SECOND(itype_pair)));\n\n#define REGISTER_SCATTER_SCALAR_CPU_KERNELS(dtype_pair, itype_pair)                               \\\n  REGISTER_SCATTERSCALAR_KERNEL(dim_scatter_update_scalar, DeviceType::kCPU, dtype_pair,          \\\n                                itype_pair, UpdateScalarFunctor);                                 \\\n  REGISTER_SCATTERSCALAR_KERNEL(dim_scatter_add_scalar, DeviceType::kCPU, dtype_pair, itype_pair, \\\n                                AddScalarFunctor);\n\n#define REGISTER_SCATTER_SCALAR_CUDA_KERNELS(dtype_pair, itype_pair)                               \\\n  REGISTER_SCATTERSCALAR_KERNEL(dim_scatter_update_scalar, DeviceType::kCUDA, dtype_pair,          \\\n                                itype_pair, UpdateScalarFunctor);                                  \\\n  REGISTER_SCATTERSCALAR_KERNEL(dim_scatter_add_scalar, DeviceType::kCUDA, dtype_pair, itype_pair, \\\n                                AddScalarFunctor);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n    REGISTER_SCATTER_SCALAR_CPU_KERNELS,\n    ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n#ifdef WITH_CUDA\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n    REGISTER_SCATTER_SCALAR_CUDA_KERNELS,\n    ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n#endif  // WITH_CUDA\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/common.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_COMMON_H_\n#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_COMMON_H_\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n\nnamespace oneflow {\n\nclass DistributionKernelState : public user_op::OpKernelState {\n public:\n  explicit DistributionKernelState(const std::shared_ptr<one::Generator>& generator)\n      : generator_(generator) {}\n\n  const std::shared_ptr<one::Generator>& generator() const { return generator_; }\n\n private:\n  std::shared_ptr<one::Generator> generator_;\n};\n\n// FIXME: refine warning message\n#define CHECK_OUT_OF_BOUNDS(var, name, min, max, dtype) \\\n  CHECK(var >= min && var <= max) << name << \" is out of bounds for \" << dtype;\n\n#define WARN_OUT_OF_BOUNDS(var, name, digits, dtype)                                          \\\n  if (var < -(1LL << digits) || var > (1LL << digits)) {                                      \\\n    LOG(WARNING) << name << \" is out of bounds [-(2^\" << digits << \"), 2^\" << digits << \"]. \" \\\n                 << \"Due to precision limitations \" << dtype                                  \\\n                 << \" can support discrete uniform distribution only within this range. \"     \\\n                 << \"This warning will become an error in later version release.\";            \\\n  }\n\ntemplate<typename scalar_t>\nvoid check_from_to_in_range(int64_t from, int64_t to_inc) {\n  if (IsFloating<scalar_t>::value) {\n    const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());\n    const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());\n    CHECK_OUT_OF_BOUNDS(from, \"from\", min, max, GetDataType<scalar_t>::value);\n    CHECK_OUT_OF_BOUNDS(to_inc, \"to - 1\", min, max, GetDataType<scalar_t>::value);\n\n    constexpr auto digits = std::numeric_limits<scalar_t>::digits;\n    WARN_OUT_OF_BOUNDS(from, \"from\", digits, GetDataType<scalar_t>::value);\n    WARN_OUT_OF_BOUNDS(to_inc, \"to - 1\", digits, GetDataType<scalar_t>::value);\n  } else if (IsIntegral<scalar_t>::value || IsUnsignedIntegral<scalar_t>::value) {\n    const auto min = static_cast<int64_t>(std::numeric_limits<scalar_t>::lowest());\n    const auto max = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());\n    CHECK_OUT_OF_BOUNDS(from, \"from\", min, max, GetDataType<scalar_t>::value);\n    CHECK_OUT_OF_BOUNDS(to_inc, \"to - 1\", min, max, GetDataType<scalar_t>::value);\n  } else {\n    UNIMPLEMENTED()\n        << \"check_random_bounds handles only integral, floating-point and boolean types\";\n  }\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_KERNEL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/distribution_template_util.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_DISTRIBUTIONS_TEMPLATE_UTIL_H_\n#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_DISTRIBUTIONS_TEMPLATE_UTIL_H_\n\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/user/kernels/fused_rnn_cell_kernel_util.h\"\n#include \"oneflow/core/common/scalar.h\"\n#ifdef WITH_CUDA\n#include <curand.h>\n#include <curand_kernel.h>\n#endif\n\nnamespace oneflow {\n\nnamespace distribution {\n\ntemplate<typename T>\nstruct DefaultComputeType {\n  using type = T;\n};\n\n#define OF_DEINFE_SPECIAL_DEFAULT_COMPUTE_TYPE(T, typeproto) \\\n  template<>                                                 \\\n  struct DefaultComputeType<T> {                             \\\n    using type = float;                                      \\\n  };\n\nOF_PP_FOR_EACH_TUPLE(OF_DEINFE_SPECIAL_DEFAULT_COMPUTE_TYPE,\n                     INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ\n                         HALF_DATA_TYPE_SEQ)\n\n#undef OF_DEINFE_SPECIAL_DEFAULT_COMPUTE_TYPE\n\n}  // namespace distribution\n\nnamespace {\n\n// launch bounds used for kernels\nconst uint32_t block_size_bound = 256;\nconst uint32_t grid_size_bound = 4;\n\n}  // namespace\n\n#ifdef WITH_CUDA\n\nenum class DistributionOp {\n  kNormal4,\n  kNormal2Double,\n  kUniform4,\n  kUniform2Double,\n};\n\ntemplate<DistributionOp distribution_op>\nstruct DistributionFunctor;\n\ntemplate<>\nstruct DistributionFunctor<DistributionOp::kNormal4> {\n  DistributionFunctor() {}\n\n  __device__ float4 operator()(curandStatePhilox4_32_10_t* state) const {\n    return curand_normal4(state);\n  }\n};\n\ntemplate<>\nstruct DistributionFunctor<DistributionOp::kNormal2Double> {\n  DistributionFunctor() {}\n\n  __device__ double2 operator()(curandStatePhilox4_32_10_t* state) const {\n    return curand_normal2_double(state);\n  }\n};\n\ntemplate<>\nstruct DistributionFunctor<DistributionOp::kUniform4> {\n  DistributionFunctor() {}\n\n  __device__ float4 operator()(curandStatePhilox4_32_10_t* state) const {\n    return curand_uniform4(state);\n  }\n};\n\ntemplate<>\nstruct DistributionFunctor<DistributionOp::kUniform2Double> {\n  DistributionFunctor() {}\n\n  __device__ double2 operator()(curandStatePhilox4_32_10_t* state) const {\n    return curand_uniform2_double(state);\n  }\n};\n\ntemplate<typename T, typename ComputeType, int unroll_factor, typename Distribution,\n         typename Transform>\nOF_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound)\n__global__\n    void DistributionElementwiseGridStrideKernel(int64_t numel, uint64_t seed, uint64_t offset,\n                                                 T* out_ptr, Distribution dist_func,\n                                                 Transform transform_func) {\n  int idx = blockIdx.x * blockDim.x + threadIdx.x;\n  curandStatePhilox4_32_10_t state;\n  curand_init(seed, idx, offset, &state);\n\n  int rounded_size = ((numel - 1) / 
(blockDim.x * gridDim.x * unroll_factor) + 1) * blockDim.x\n                     * gridDim.x * unroll_factor;\n  for (int32_t linear_index = idx; linear_index < rounded_size;\n       linear_index += blockDim.x * gridDim.x * unroll_factor) {\n    auto rand = dist_func(&state);\n#pragma unroll\n    for (int ii = 0; ii < unroll_factor; ii++) {\n      int li = linear_index + blockDim.x * gridDim.x * ii;\n      if (li < numel) { out_ptr[li] = transform_func(static_cast<ComputeType>((&rand.x)[ii])); }\n    }\n  }\n}\n\n#endif  // WITH_CUDA\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_DISTRIBUTIONS_TEMPLATE_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/exponential_distribution.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <math.h>\n#include <array>\n#include <cmath>\n#include <cstdint>\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/distributions/exponential_distribution.h\"\n\nnamespace oneflow {\n\nstatic uint64_t make64BitsFrom32Bits(uint32_t hi, uint32_t lo) {\n  return (static_cast<uint64_t>(hi) << 32) | lo;\n}\n\ntemplate<typename T, typename V>\nstatic T uniform_real(V val, T from, T to) {\n  constexpr auto MASK =\n      static_cast<V>((static_cast<uint64_t>(1) << std::numeric_limits<T>::digits) - 1);\n  constexpr auto DIVISOR =\n      static_cast<T>(1) / (static_cast<uint64_t>(1) << std::numeric_limits<T>::digits);\n  T x = (val & MASK) * DIVISOR;\n  return (x * (to - from) + from);\n}\n\ntemplate<typename T>\nvoid ExponentialDistribution<DeviceType::kCPU, T>::operator()(\n    ep::Stream* stream, const int64_t elem_cnt, T* dptr,\n    const std::shared_ptr<one::Generator>& generator) const {\n  CHECK_GE(elem_cnt, 0);\n  auto gen = CHECK_JUST(generator->Get<ep::CPUGenerator>());\n  ep::pytorch_mt19937_engine& engine = gen->torch_engine();\n  for (int64_t i = 0; i < elem_cnt; ++i) {\n    uint32_t random1 = engine();\n    uint32_t random2 = engine();\n    uint64_t rand_unit = make64BitsFrom32Bits(random1, random2);\n    T random_val = uniform_real(rand_unit, 0.0, 1.0);\n    dptr[i] = static_cast<T>(-1.0) / lambd_ * std::log(static_cast<T>(1.0) - random_val);\n  }\n}\n\n#define INITIATE_CPU_UNIFORM_DISTRIBUTION(T, typeproto)                   \\\n  template void ExponentialDistribution<DeviceType::kCPU, T>::operator()( \\\n      ep::Stream* stream, const int64_t elem_cnt, T* dptr,                \\\n      const std::shared_ptr<one::Generator>& generator) const;\n\nOF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/exponential_distribution.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/ep/include/device.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/user/kernels/distributions/distribution_template_util.cuh\"\n#include \"oneflow/user/kernels/distributions/exponential_distribution.h\"\n#include \"oneflow/user/kernels/fused_rnn_cell_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename ComputeType>\nstruct ExponentialTransformFunctor;\n\ntemplate<>\nstruct ExponentialTransformFunctor<float, float> {\n  ExponentialTransformFunctor(float epsilon, float lambd) : epsilon(epsilon), lambd(lambd) {}\n  __device__ float operator()(float random_val) const {\n    float log_rand = __logf(static_cast<float>(random_val));\n    // curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.\n    // we need log to be not 0, and not underflow when converted to half\n    // fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1\n    // args\n    float log = static_cast<float>(random_val) >= static_cast<float>(1.) - epsilon / 2\n                    ? -epsilon / 2\n                    : log_rand;\n    return static_cast<float>(-1.0) / lambd * log;\n  }\n  float epsilon;\n  float lambd;\n};\n\ntemplate<>\nstruct ExponentialTransformFunctor<double, double> {\n  ExponentialTransformFunctor(double epsilon, double lambd) : epsilon(epsilon), lambd(lambd) {}\n  __device__ double operator()(double random_val) const {\n    double log_rand = ::log(static_cast<double>(random_val));\n    // curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.\n    // we need log to be not 0, and not underflow when converted to half\n    // fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1\n    // args\n    double log = static_cast<double>(random_val) >= static_cast<double>(1.) - epsilon / 2\n                     ? 
-epsilon / 2\n                     : log_rand;\n    return static_cast<double>(-1.0) / lambd * log;\n  }\n  double epsilon;\n  double lambd;\n};\n\ntemplate<>\nstruct ExponentialTransformFunctor<half, float> {\n  ExponentialTransformFunctor(float epsilon, float lambd) : float_functor(epsilon, lambd) {}\n  __device__ half operator()(float random_val) const {\n    return static_cast<half>(float_functor(random_val));\n  }\n  ExponentialTransformFunctor<float, float> float_functor;\n};\n\ntemplate<>\nvoid ExponentialDistribution<DeviceType::kCUDA, double>::operator()(\n    ep::Stream* stream, const int64_t elem_cnt, double* dptr,\n    const std::shared_ptr<one::Generator>& generator) const {\n  CHECK_GT(elem_cnt, 0);\n  const auto device_index = stream->device()->device_index();\n  auto gen = CHECK_JUST(generator->Get<ep::CUDAGenerator>(device_index));\n  ep::CudaStream* cuda_stream = stream->As<ep::CudaStream>();\n  auto execution_policy = gen->CalcExecutionPolicy(elem_cnt, cuda_stream);\n\n  auto counter_offset = std::get<0>(execution_policy);\n  auto grid = std::get<1>(execution_policy);\n  auto block = std::get<2>(execution_policy);\n\n  uint64_t seed = gen->current_seed();\n  uint64_t offset = gen->get_philox_offset(counter_offset);\n\n  ExponentialTransformFunctor<double, double> transform_functor(\n      std::numeric_limits<double>::epsilon(), static_cast<double>(lambd_));\n  DistributionFunctor<DistributionOp::kUniform2Double> dist_functor;\n\n  DistributionElementwiseGridStrideKernel<double, double, 2, decltype(dist_functor),\n                                          decltype(transform_functor)>\n      <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          elem_cnt, seed, offset, dptr, dist_functor, transform_functor);\n}\n\ntemplate<>\nvoid ExponentialDistribution<DeviceType::kCUDA, float>::operator()(\n    ep::Stream* stream, const int64_t elem_cnt, float* dptr,\n    const std::shared_ptr<one::Generator>& generator) const {\n  CHECK_GT(elem_cnt, 0);\n  const auto device_index = stream->device()->device_index();\n  auto gen = CHECK_JUST(generator->Get<ep::CUDAGenerator>(device_index));\n  ep::CudaStream* cuda_stream = stream->As<ep::CudaStream>();\n  auto execution_policy = gen->CalcExecutionPolicy(elem_cnt, cuda_stream);\n\n  auto counter_offset = std::get<0>(execution_policy);\n  auto grid = std::get<1>(execution_policy);\n  auto block = std::get<2>(execution_policy);\n\n  uint64_t seed = gen->current_seed();\n  uint64_t offset = gen->get_philox_offset(counter_offset);\n\n  ExponentialTransformFunctor<float, float> transform_functor(std::numeric_limits<float>::epsilon(),\n                                                              static_cast<float>(lambd_));\n  DistributionFunctor<DistributionOp::kUniform4> dist_functor;\n\n  DistributionElementwiseGridStrideKernel<float, float, 4, decltype(dist_functor),\n                                          decltype(transform_functor)>\n      <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          elem_cnt, seed, offset, dptr, dist_functor, transform_functor);\n}\n\ntemplate<>\nvoid ExponentialDistribution<DeviceType::kCUDA, half>::operator()(\n    ep::Stream* stream, const int64_t elem_cnt, half* dptr,\n    const std::shared_ptr<one::Generator>& generator) const {\n  CHECK_GT(elem_cnt, 0);\n  const auto device_index = stream->device()->device_index();\n  auto gen = CHECK_JUST(generator->Get<ep::CUDAGenerator>(device_index));\n  ep::CudaStream* cuda_stream = stream->As<ep::CudaStream>();\n  
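// CalcExecutionPolicy returns a (counter_offset, grid, block) tuple: how far to advance\n  // the Philox counter per thread plus the launch configuration for the kernel below.\n  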
auto execution_policy = gen->CalcExecutionPolicy(elem_cnt, cuda_stream);\n\n  auto counter_offset = std::get<0>(execution_policy);\n  auto grid = std::get<1>(execution_policy);\n  auto block = std::get<2>(execution_policy);\n\n  uint64_t seed = gen->current_seed();\n  uint64_t offset = gen->get_philox_offset(counter_offset);\n\n  ExponentialTransformFunctor<half, float> transform_functor(std::numeric_limits<float>::epsilon(),\n                                                             static_cast<float>(lambd_));\n  DistributionFunctor<DistributionOp::kUniform4> dist_functor;\n\n  DistributionElementwiseGridStrideKernel<half, float, 4, decltype(dist_functor),\n                                          decltype(transform_functor)>\n      <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          elem_cnt, seed, offset, dptr, dist_functor, transform_functor);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/exponential_distribution.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_EXPONENTIAL_DISTRIBUTION_H_\n#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_EXPONENTIAL_DISTRIBUTION_H_\n\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#ifdef WITH_CUDA\n#include <curand.h>\n#include <curand_kernel.h>\n#endif\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nclass ExponentialDistribution;\n\ntemplate<typename T>\nclass ExponentialDistribution<DeviceType::kCPU, T> final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ExponentialDistribution);\n  ExponentialDistribution(T lambd) : lambd_(lambd) {}\n  ~ExponentialDistribution() = default;\n\n  void operator()(ep::Stream* stream, const int64_t elem_cnt, T* dptr,\n                  const std::shared_ptr<one::Generator>& generator) const;\n\n private:\n  const T lambd_;\n};\n\n#ifdef WITH_CUDA\ntemplate<typename T>\nclass ExponentialDistribution<DeviceType::kCUDA, T> final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ExponentialDistribution);\n  ExponentialDistribution(T lambd) : lambd_(lambd) {}\n  ~ExponentialDistribution() = default;\n\n  void operator()(ep::Stream* stream, const int64_t elem_cnt, T* dptr,\n                  const std::shared_ptr<one::Generator>& generator) const;\n\n private:\n  const T lambd_;\n};\n#endif  // WITH_CUDA\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_EXPONENTIAL_DISTRIBUTION_H_\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/exponential_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/distributions/exponential_kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n#define REGISTER_EXPONENTIAL_KERNEL(device, dtype)          \\\n  REGISTER_USER_KERNEL(\"exponential\")                       \\\n      .SetCreateFn<ExponentialKernel<device, dtype>>()      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device) \\\n                       && (user_op::HobAttr<DataType>(\"dtype\") == GetDataType<dtype>::value));\n\nREGISTER_EXPONENTIAL_KERNEL(DeviceType::kCPU, float)\nREGISTER_EXPONENTIAL_KERNEL(DeviceType::kCPU, double)\n#ifdef WITH_CUDA\nREGISTER_EXPONENTIAL_KERNEL(DeviceType::kCUDA, float)\nREGISTER_EXPONENTIAL_KERNEL(DeviceType::kCUDA, double)\nREGISTER_EXPONENTIAL_KERNEL(DeviceType::kCUDA, half)\n#endif  // WITH_CUDA\n\n}  // namespace\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/exponential_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_EXPONENTIAL_KERNEL_H_\n#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_EXPONENTIAL_KERNEL_H_\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/distributions/common.h\"\n#include \"oneflow/user/kernels/distributions/exponential_distribution.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<DeviceType device_type, typename T>\nclass ExponentialKernel final : public user_op::OpKernel {\n public:\n  ExponentialKernel() = default;\n  ~ExponentialKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const auto& generator = CHECK_JUST(one::MakeGenerator(device_type));\n    // When SBP is Split, each rank uses a different seeds, otherwise, ranks use the same seed\n    generator->set_current_seed(\n        CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>(\"seed\"))));\n    return std::make_shared<DistributionKernelState>(generator);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const float lambd = ctx->Attr<float>(\"lambd\");\n    int64_t elem_cnt = out->shape_view().elem_cnt();\n    T* out_dptr = out->mut_dptr<T>();\n    auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);\n    CHECK_NOTNULL(distribution_state);\n    const auto& generator = distribution_state->generator();\n    CHECK_NOTNULL(generator);\n    ExponentialDistribution<device_type, T> distribution(static_cast<T>(lambd));\n    distribution(ctx->stream(), elem_cnt, out_dptr, generator);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_EXPONENTIAL_KERNEL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/multinomial_with_replacement_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/distributions/common.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n\n// NOTE(Liang Depeng): The implementation of MultinomialWithReplacementCpuKernel is modified from\n//                    https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/MultinomialKernel.cpp#L23\nnamespace oneflow {\n\nnamespace {\n\nstatic size_t InferTmpSizeForCpuKernel(user_op::InferContext* ctx) {\n  const auto& x = ctx->InputTensorDesc(\"x\", 0);\n  int64_t n_categories = x.shape().At(x.shape().NumAxes() - 1);\n  return n_categories * GetSizeOfDataType(x.data_type());\n}\n\ntemplate<typename T, typename V>\nstatic T uniform_real(V val, T from, T to) {\n  constexpr auto MASK =\n      static_cast<V>((static_cast<uint64_t>(1) << std::numeric_limits<T>::digits) - 1);\n  constexpr auto DIVISOR =\n      static_cast<T>(1) / (static_cast<uint64_t>(1) << std::numeric_limits<T>::digits);\n  T x = (val & MASK) * DIVISOR;\n  return (x * (to - from) + from);\n}\n\nstatic uint64_t make64BitsFrom32Bits(uint32_t hi, uint32_t lo) {\n  return (static_cast<uint64_t>(hi) << 32) | lo;\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass MultinomialWithReplacementCpuKernel final : public user_op::OpKernel {\n public:\n  MultinomialWithReplacementCpuKernel() = default;\n  ~MultinomialWithReplacementCpuKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCPU));\n    // When SBP is Split, each rank uses a different seeds, otherwise, ranks use the same seed\n    generator->set_current_seed(\n        CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>(\"seed\"))));\n    return std::make_shared<DistributionKernelState>(generator);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);\n    CHECK_NOTNULL(distribution_state);\n    const auto& generator = distribution_state->generator();\n    CHECK_NOTNULL(generator);\n    auto cpu_gen = CHECK_JUST(generator->Get<ep::CPUGenerator>());\n    std::lock_guard<std::mutex> lock(cpu_gen->mutex_);\n\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    const T* self_ptr = x->dptr<T>();\n    int64_t* result_ptr = out->mut_dptr<int64_t>();\n    /* cumulative probability distribution vector */\n    T* cum_dist_ptr = tmp_buffer->mut_dptr<T>();\n\n    int64_t n_categories = x->shape_view().At(x->shape_view().NumAxes() - 1);\n    int64_t n_dist = 
x->shape_view().NumAxes() > 1 ? x->shape_view().At(0) : 1;\n    const int32_t num_samples = ctx->Attr<int32_t>(\"num_samples\");\n\n    int64_t self_stride_0 = x->shape_view().NumAxes() > 1 ? x->stride().at(0) : 0;\n    int64_t self_stride_1 = x->stride().at(x->shape_view().NumAxes() - 1);\n    int64_t result_dist_stride_0 = out->shape_view().NumAxes() > 1 ? out->stride().at(0) : 0;\n    int64_t result_dist_stride_1 = out->stride().at(out->shape_view().NumAxes() - 1);\n\n    ep::pytorch_mt19937_engine& engine = cpu_gen->torch_engine();\n\n    for (int i = 0; i < n_dist; ++i) {\n      /* Get normalized cumulative distribution from prob distribution */\n      T sum = 0;\n      T val;\n      for (int j = 0; j < n_categories; ++j) {\n        val = self_ptr[i * self_stride_0 + j * self_stride_1];\n        CHECK(val >= 0) << \"invalid multinomial distribution (encountering probability entry < 0)\";\n        CHECK(std::isfinite(val)) << \"invalid multinomial distribution (encountering probability \"\n                                     \"entry = infinity or NaN)\";\n        sum += val;\n        cum_dist_ptr[j] = sum;\n      }\n\n      CHECK(sum > 0) << \"invalid multinomial distribution (sum of probabilities <= 0)\";\n\n      /* normalize cumulative probability distribution so that last val is 1\n      i.e. doesn't assume original self row sums to one */\n      if ((sum > 0) || ((sum < 1.00001) && (sum > 0.99999))) {\n        for (int j = 0; j < n_categories; ++j) { cum_dist_ptr[j] /= sum; }\n      }\n\n      for (int j = 0; j < num_samples; ++j) {\n        /* sample a probability mass from a uniform distribution */\n        // at::uniform_real_distribution<double> uniform(0, 1);\n        // double uniform_sample = uniform(gen);\n        uint32_t random1 = engine();\n        uint32_t random2 = engine();\n        uint64_t rand_unit = make64BitsFrom32Bits(random1, random2);\n        double uniform_sample = uniform_real(rand_unit, 0.0, 1.0);\n\n        // Do a binary search for the slot in which the prob falls\n        // ie cum_dist[row][slot-1] < uniform_prob < cum_dist[row][slot]\n        int left_pointer = 0;\n        int right_pointer = n_categories;\n        int mid_pointer = 0;\n        T cum_prob;\n        int sample_idx = 0;\n        // Make sure the last cumulative distribution bucket sums to 1\n        cum_dist_ptr[(n_categories - 1)] = 1;\n\n        while (right_pointer - left_pointer > 0) {\n          mid_pointer = left_pointer + (right_pointer - left_pointer) / 2;\n          cum_prob = cum_dist_ptr[mid_pointer];\n          if (cum_prob < uniform_sample) {\n            left_pointer = mid_pointer + 1;\n          } else {\n            right_pointer = mid_pointer;\n          }\n        }\n        sample_idx = left_pointer;\n\n        // store the sampled category index in the result tensor\n        result_ptr[i * result_dist_stride_0 + j * result_dist_stride_1] = sample_idx;\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_MULTINOMIAL_WITH_REPLACEMENT_CPU_KERNEL(dtype)                        \\\n  REGISTER_USER_KERNEL(\"multinomial_with_replacement\")                                 \\\n      .SetCreateFn<MultinomialWithReplacementCpuKernel<dtype>>()                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                  \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)) \\\n      
.SetInferTmpSizeFn(InferTmpSizeForCpuKernel);\n\nREGISTER_MULTINOMIAL_WITH_REPLACEMENT_CPU_KERNEL(float)\nREGISTER_MULTINOMIAL_WITH_REPLACEMENT_CPU_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/multinomial_with_replacement_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/ep/include/device.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/distributions/common.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n\n// NOTE(Liang Depeng): The implementation of MultinomialWithReplacementGpuKernel is modified from\n//                    https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/MultinomialKernel.cu#L324\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__device__ int binarySearchForMultinomial(const T* cumdist, const T* dist, int32_t size, T val) {\n  int start = 0;\n  int end = size;\n\n  while (end - start > 0) {\n    int mid = start + (end - start) / 2;\n    T midVal = cumdist[mid];\n    if (midVal < val) {\n      start = mid + 1;\n    } else {\n      end = mid;\n    }\n  }\n\n  if (start == size) {\n    // No probability mass or precision problems; just return the\n    // first non-zero element by setting start to size-1 here,\n    // the code below will move it to the last non-zero probability\n    // this actually can happen when the random number is 1\n    // (github pytorch issue #4858).\n    start = size - 1;\n  }\n\n  while (start >= 1 && dist[start] == 0) start--;\n\n  return start;\n}\n\ntemplate<typename T>\n__global__ void sampleMultinomialWithReplacement(uint64_t seed, uint64_t offset,\n                                                 int32_t totalSamples, int64_t* dest,\n                                                 int64_t distributions, int64_t categories,\n                                                 const T* normDistPrefixSum, const T* normDist) {\n  // At the moment, each warp computes one sample value in the binary\n  // search due to divergence. 
It seems possible to compute multiple\n  // values per pass and limit divergence, though, later on.\n\n  // global index formula for 2D grid of 1D blocks\n  int idx = blockIdx.y * gridDim.x * blockDim.x + blockIdx.x * blockDim.x + threadIdx.x;\n  curandStatePhilox4_32_10_t state;\n  curand_init(seed, idx, offset, &state);\n\n  // The block determines the distribution for which we generate a point\n  for (int64_t curDist = blockIdx.y; curDist < distributions; curDist += gridDim.y) {\n    for (int sample = blockIdx.x * blockDim.x + threadIdx.x; sample < totalSamples;\n         sample += blockDim.x * gridDim.x) {\n      // we are losing 3 out of 4 generated numbers but it's ok\n      // this kernel is not very efficient anyway\n      auto rand = curand_uniform4(&state);\n      T r = static_cast<T>(rand.x);\n\n      // Find the bucket that a uniform sample lies in\n      int choice = binarySearchForMultinomial<T>(normDistPrefixSum + curDist * categories,\n                                                 normDist + curDist * categories, categories, r);\n\n      dest[curDist * totalSamples + sample] = choice;\n    }\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass MultinomialWithReplacementGpuKernel final : public user_op::OpKernel {\n public:\n  MultinomialWithReplacementGpuKernel() = default;\n  ~MultinomialWithReplacementGpuKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCUDA));\n    // When SBP is Split, each rank uses a different seed; otherwise, all ranks use the same seed\n    generator->set_current_seed(\n        CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>(\"seed\"))));\n    return std::make_shared<DistributionKernelState>(generator);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);\n    CHECK_NOTNULL(distribution_state);\n    const auto& generator = distribution_state->generator();\n    CHECK_NOTNULL(generator);\n    auto gpu_gen = CHECK_JUST(generator->Get<ep::CUDAGenerator>());\n\n    const user_op::Tensor* norm_dist = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* prefix_sum = ctx->Tensor4ArgNameAndIndex(\"prefix_sum\", 0);\n    CHECK_NOTNULL(prefix_sum);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const T* norm_dist_ptr = norm_dist->dptr<T>();\n    const T* prefix_sum_ptr = prefix_sum->dptr<T>();\n    int64_t* result_ptr = out->mut_dptr<int64_t>();\n\n    int64_t numCategories = norm_dist->shape_view().At(norm_dist->shape_view().NumAxes() - 1);\n    int64_t numDist = norm_dist->shape_view().NumAxes() > 1 ? 
norm_dist->shape_view().At(0) : 1;\n    const int32_t n_sample = ctx->Attr<int32_t>(\"num_samples\");\n\n    // Binary search is warp divergent (so effectively we're running\n    // with just a single thread), but for better utilization,\n    // we need each block to have at least 4 warps.\n    dim3 block(128);\n\n    ep::CudaStream* stream = ctx->stream()->As<ep::CudaStream>();\n    // Each block will generate a sample from one\n    // distribution concurrently.\n    int grid_y = std::min<int>(numDist, stream->device_properties().maxGridSize[1]);\n    dim3 grid((n_sample - 1) / block.x + 1, grid_y);\n    uint64_t seed = gpu_gen->current_seed();\n    uint64_t offset = gpu_gen->get_philox_offset(((numDist - 1) / grid.y + 1) * 4);\n\n    // Sample with replacement\n    sampleMultinomialWithReplacement<<<grid, block, 0, stream->cuda_stream()>>>(\n        seed, offset, n_sample, result_ptr, numDist, numCategories, prefix_sum_ptr, norm_dist_ptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_MULTINOMIAL_WITH_REPLACEMENT_GPU_KERNEL(dtype)                       \\\n  REGISTER_USER_KERNEL(\"multinomial_with_replacement\")                                \\\n      .SetCreateFn<MultinomialWithReplacementGpuKernel<dtype>>()                      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"prefix_sum\", 0) == GetDataType<dtype>::value));\n\nREGISTER_MULTINOMIAL_WITH_REPLACEMENT_GPU_KERNEL(float)\nREGISTER_MULTINOMIAL_WITH_REPLACEMENT_GPU_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/normal_distribution.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/user/kernels/distributions/normal_distribution.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nvoid NormalDistribution<DeviceType::kCPU, T>::operator()(\n    ep::Stream* stream, const int64_t elem_cnt, T* dptr,\n    const std::shared_ptr<one::Generator>& generator) const {\n  CHECK_GE(elem_cnt, 0) << \"elem_cnt must be non-negative, but got \" << elem_cnt;\n  auto gen = CHECK_JUST(generator->Get<ep::CPUGenerator>());\n  std::normal_distribution<T> random_distribution(mean_, std_);\n  for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = random_distribution(gen->engine()); }\n}\n\n#define INITIATE_CPU_NORMAL_DISTRIBUTION(T, typeproto)               \\\n  template void NormalDistribution<DeviceType::kCPU, T>::operator()( \\\n      ep::Stream* stream, const int64_t elem_cnt, T* dptr,           \\\n      const std::shared_ptr<one::Generator>& generator) const;\n\nOF_PP_FOR_EACH_TUPLE(INITIATE_CPU_NORMAL_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ)\n\n// specialization for half\ntemplate<>\nvoid NormalDistribution<DeviceType::kCPU, float16>::operator()(\n    ep::Stream* stream, const int64_t elem_cnt, float16* dptr,\n    const std::shared_ptr<one::Generator>& generator) const {\n  CHECK_GE(elem_cnt, 0) << \"elem_cnt must be non-negative, but got \" << elem_cnt;\n  auto gen = CHECK_JUST(generator->Get<ep::CPUGenerator>());\n  std::normal_distribution<float> random_distribution(mean_, std_);\n  for (int64_t i = 0; i < elem_cnt; ++i) {\n    dptr[i] = static_cast<float16>(random_distribution(gen->engine()));\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/normal_distribution.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/user/kernels/distributions/distribution_template_util.cuh\"\n#include \"oneflow/user/kernels/distributions/normal_distribution.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/ep/include/device.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename ComputeType>\nstruct NormalTransformFunctor {\n  NormalTransformFunctor(ComputeType mean, ComputeType std) : mean(mean), std(std) {}\n  __device__ T operator()(ComputeType random_val) const {\n    return static_cast<T>(random_val * std + mean);\n  }\n  ComputeType mean;\n  ComputeType std;\n};\n\ntemplate<typename T>\nvoid NormalDistribution<DeviceType::kCUDA, T>::operator()(\n    ep::Stream* stream, const int64_t elem_cnt, T* dptr,\n    const std::shared_ptr<one::Generator>& generator) const {\n  CHECK_GE(elem_cnt, 0);\n  if (elem_cnt == 0) return;\n  const auto device_index = stream->device()->device_index();\n  auto gen = CHECK_JUST(generator->Get<ep::CUDAGenerator>(device_index));\n\n  ep::CudaStream* cuda_stream = stream->As<ep::CudaStream>();\n  auto execution_policy = gen->CalcExecutionPolicy(elem_cnt, cuda_stream);\n\n  auto counter_offset = std::get<0>(execution_policy);\n  auto grid = std::get<1>(execution_policy);\n  auto block = std::get<2>(execution_policy);\n\n  uint64_t seed = gen->current_seed();\n  uint64_t offset = gen->get_philox_offset(counter_offset);\n\n  using ComputeType = typename distribution::DefaultComputeType<T>::type;\n\n  NormalTransformFunctor<T, ComputeType> transform_functor(static_cast<ComputeType>(mean_),\n                                                           static_cast<ComputeType>(std_));\n\n  if (std::is_same<T, double>::value) {\n    DistributionFunctor<DistributionOp::kNormal2Double> dist_functor;\n    DistributionElementwiseGridStrideKernel<T, ComputeType, 2, decltype(dist_functor),\n                                            decltype(transform_functor)>\n        <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            elem_cnt, seed, offset, dptr, dist_functor, transform_functor);\n  } else {\n    DistributionFunctor<DistributionOp::kNormal4> dist_functor;\n    DistributionElementwiseGridStrideKernel<T, ComputeType, 4, decltype(dist_functor),\n                                            decltype(transform_functor)>\n        <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            elem_cnt, seed, offset, dptr, dist_functor, transform_functor);\n  }\n}\n\n#define INITIATE_CUDA_NORMAL_DISTRIBUTION(T, typeproto)               \\\n  template void NormalDistribution<DeviceType::kCUDA, T>::operator()( \\\n      ep::Stream* stream, const int64_t elem_cnt, T* dptr,            \\\n      const std::shared_ptr<one::Generator>& generator) const;\n\nOF_PP_FOR_EACH_TUPLE(INITIATE_CUDA_NORMAL_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ)\nINITIATE_CUDA_NORMAL_DISTRIBUTION(half, DataType::kFloat16)\n\n}  // 
namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/normal_distribution.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_DISTRIBUTION_H_\n#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_DISTRIBUTION_H_\n\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#ifdef WITH_CUDA\n#include <curand.h>\n#include <curand_kernel.h>\n#endif\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nclass NormalDistribution;\n\ntemplate<typename T>\nclass NormalDistribution<DeviceType::kCPU, T> final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NormalDistribution);\n  NormalDistribution(T mean, T std) : mean_(mean), std_(std) {}\n  ~NormalDistribution() = default;\n\n  void operator()(ep::Stream* stream, const int64_t elem_cnt, T* dptr,\n                  const std::shared_ptr<one::Generator>& generator) const;\n\n private:\n  const T mean_;\n  const T std_;\n};\n\n#ifdef WITH_CUDA\ntemplate<typename T>\nclass NormalDistribution<DeviceType::kCUDA, T> final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(NormalDistribution);\n  NormalDistribution(T mean, T std) : mean_(mean), std_(std) {}\n  ~NormalDistribution() = default;\n\n  void operator()(ep::Stream* stream, const int64_t elem_cnt, T* dptr,\n                  const std::shared_ptr<one::Generator>& generator) const;\n\n private:\n  const T mean_;\n  const T std_;\n};\n#endif  // WITH_CUDA\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_DISTRIBUTION_H_"
  },
  {
    "path": "oneflow/user/kernels/distributions/normal_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/user/kernels/distributions/normal_kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n#define REGISTER_UNIFORM_KERNEL(device, dtype)                                               \\\n  REGISTER_USER_KERNEL(\"normal\").SetCreateFn<NormalKernel<device, dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == device)                                                   \\\n      && (user_op::HobAttr<DataType>(\"dtype\") == GetDataType<dtype>::value));\n\nREGISTER_UNIFORM_KERNEL(DeviceType::kCPU, float16)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCPU, float)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCPU, double)\n#ifdef WITH_CUDA\nREGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, half)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, float)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, double)\n#endif  // WITH_CUDA\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/normal_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_KERNEL_H_\n#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_KERNEL_H_\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/distributions/common.h\"\n#include \"oneflow/user/kernels/distributions/normal_distribution.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<DeviceType device_type, typename T>\nclass NormalKernel final : public user_op::OpKernel {\n public:\n  NormalKernel() = default;\n  ~NormalKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const auto& generator = CHECK_JUST(one::MakeGenerator(device_type));\n    // When SBP is Split, each rank uses a different seeds, otherwise, ranks use the same seed\n    generator->set_current_seed(\n        CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>(\"seed\"))));\n    return std::make_shared<DistributionKernelState>(generator);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const double mean = ctx->Attr<double>(\"mean\");\n    const double std = ctx->Attr<double>(\"std\");\n    int64_t elem_cnt = out->shape_view().elem_cnt();\n    T* out_dptr = out->mut_dptr<T>();\n    auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);\n    CHECK_NOTNULL(distribution_state);\n    const auto& generator = distribution_state->generator();\n    CHECK_NOTNULL(generator);\n    NormalDistribution<device_type, T> distribution(static_cast<T>(mean), static_cast<T>(std));\n    distribution(ctx->stream(), elem_cnt, out_dptr, generator);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_KERNEL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/uniform_distribution.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/distributions/uniform_distribution.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename E = void>\nclass CPUUniformDistributionImpl;\n\ntemplate<typename T>\nclass CPUUniformDistributionImpl<T,\n                                 typename std::enable_if<std::is_floating_point<T>::value>::type> {\n public:\n  CPUUniformDistributionImpl(T low, T high) : random_distribution_(low, high) {}\n\n  T operator()(std::mt19937& engine) { return random_distribution_(engine); }\n\n private:\n  std::uniform_real_distribution<T> random_distribution_;\n};\n\ntemplate<typename T>\nvoid UniformDistribution<DeviceType::kCPU, T>::operator()(\n    ep::Stream* stream, const int64_t elem_cnt, T* dptr,\n    const std::shared_ptr<one::Generator>& generator) const {\n  CHECK_GE(elem_cnt, 0) << \"elem_cnt must be non-negative, but got \" << elem_cnt;\n  auto gen = CHECK_JUST(generator->Get<ep::CPUGenerator>());\n  CPUUniformDistributionImpl<T> impl(low_, high_);\n  for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = impl(gen->engine()); }\n}\n\n#define INITIATE_CPU_UNIFORM_DISTRIBUTION(T, typeproto)               \\\n  template void UniformDistribution<DeviceType::kCPU, T>::operator()( \\\n      ep::Stream* stream, const int64_t elem_cnt, T* dptr,            \\\n      const std::shared_ptr<one::Generator>& generator) const;\n\nOF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ)\n\n// specialization for half\ntemplate<>\nvoid UniformDistribution<DeviceType::kCPU, float16>::operator()(\n    ep::Stream* stream, const int64_t elem_cnt, float16* dptr,\n    const std::shared_ptr<one::Generator>& generator) const {\n  CHECK_GE(elem_cnt, 0) << \"elem_cnt must be non-negative, but got \" << elem_cnt;\n  auto gen = CHECK_JUST(generator->Get<ep::CPUGenerator>());\n  CPUUniformDistributionImpl<float> impl(low_, high_);\n  for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = static_cast<float16>(impl(gen->engine())); }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/uniform_distribution.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/user/kernels/distributions/uniform_distribution.h\"\n#include \"oneflow/user/kernels/distributions/distribution_template_util.cuh\"\n#include \"oneflow/core/ep/include/device.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename ComputeType>\nstruct UniformTransformFunctor {\n  UniformTransformFunctor(ComputeType low, ComputeType high) : low(low), high(high) {}\n  __device__ T operator()(ComputeType rand_num) const {\n    if (rand_num == static_cast<ComputeType>(1.0)) { rand_num = static_cast<ComputeType>(0.0); }\n    return static_cast<T>(rand_num * (high - low) + low);\n  }\n  ComputeType low;\n  ComputeType high;\n};\n\ntemplate<typename T>\nvoid UniformDistribution<DeviceType::kCUDA, T>::operator()(\n    ep::Stream* stream, const int64_t elem_cnt, T* dptr,\n    const std::shared_ptr<one::Generator>& generator) const {\n  CHECK_GE(elem_cnt, 0);\n  if (elem_cnt == 0) return;\n  const auto device_index = stream->device()->device_index();\n  auto gen = CHECK_JUST(generator->Get<ep::CUDAGenerator>(device_index));\n\n  ep::CudaStream* cuda_stream = stream->As<ep::CudaStream>();\n  auto execution_policy = gen->CalcExecutionPolicy(elem_cnt, cuda_stream);\n\n  auto counter_offset = std::get<0>(execution_policy);\n  auto grid = std::get<1>(execution_policy);\n  auto block = std::get<2>(execution_policy);\n\n  uint64_t seed = gen->current_seed();\n  uint64_t offset = gen->get_philox_offset(counter_offset);\n\n  using ComputeType = typename distribution::DefaultComputeType<T>::type;\n\n  UniformTransformFunctor<T, ComputeType> transform_functor(static_cast<ComputeType>(low_),\n                                                            static_cast<ComputeType>(high_));\n\n  if (std::is_same<T, double>::value) {\n    DistributionFunctor<DistributionOp::kUniform2Double> dist_functor;\n    DistributionElementwiseGridStrideKernel<T, ComputeType, 2, decltype(dist_functor),\n                                            decltype(transform_functor)>\n        <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            elem_cnt, seed, offset, dptr, dist_functor, transform_functor);\n  } else {\n    DistributionFunctor<DistributionOp::kUniform4> dist_functor;\n    DistributionElementwiseGridStrideKernel<T, ComputeType, 4, decltype(dist_functor),\n                                            decltype(transform_functor)>\n        <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            elem_cnt, seed, offset, dptr, dist_functor, transform_functor);\n  }\n}\n\n#define INITIATE_CUDA_UNIFORM_DISTRIBUTION(T, typeproto)               \\\n  template void UniformDistribution<DeviceType::kCUDA, T>::operator()( \\\n      ep::Stream* stream, const int64_t elem_cnt, T* dptr,             \\\n      const std::shared_ptr<one::Generator>& generator) 
const;\n\nOF_PP_FOR_EACH_TUPLE(INITIATE_CUDA_UNIFORM_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ)\nINITIATE_CUDA_UNIFORM_DISTRIBUTION(half, DataType::kFloat16)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/uniform_distribution.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_DISTRIBUTION_H_\n#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_DISTRIBUTION_H_\n\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#ifdef WITH_CUDA\n#include <curand.h>\n#include <curand_kernel.h>\n#endif\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nclass UniformDistribution;\n\ntemplate<typename T>\nclass UniformDistribution<DeviceType::kCPU, T> final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(UniformDistribution);\n  UniformDistribution(T low, T high) : low_(low), high_(high) {}\n  ~UniformDistribution() = default;\n\n  void operator()(ep::Stream* stream, const int64_t elem_cnt, T* dptr,\n                  const std::shared_ptr<one::Generator>& generator) const;\n\n private:\n  const T low_;\n  const T high_;\n};\n\n#ifdef WITH_CUDA\ntemplate<typename T>\nclass UniformDistribution<DeviceType::kCUDA, T> final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(UniformDistribution);\n  UniformDistribution(T low, T high) : low_(low), high_(high) {}\n  ~UniformDistribution() = default;\n\n  void operator()(ep::Stream* stream, const int64_t elem_cnt, T* dptr,\n                  const std::shared_ptr<one::Generator>& generator) const;\n\n private:\n  const T low_;\n  const T high_;\n};\n#endif  // WITH_CUDA\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_DISTRIBUTION_H_\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/uniform_int_distribution.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/user/kernels/distributions/uniform_int_distribution.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass CPUUniformIntDistributionImpl {\n public:\n  CPUUniformIntDistributionImpl(int64_t low, int64_t high) : random_distribution_(low, high) {}\n\n  T operator()(std::mt19937& engine) { return static_cast<T>(random_distribution_(engine)); }\n\n private:\n  std::uniform_int_distribution<int64_t> random_distribution_;\n};\n\ntemplate<typename T>\nvoid UniformIntDistribution<DeviceType::kCPU, T>::operator()(\n    ep::Stream* stream, const int64_t elem_cnt, T* dptr,\n    const std::shared_ptr<one::Generator>& generator) const {\n  CHECK_GE(elem_cnt, 0);\n  auto gen = CHECK_JUST(generator->Get<ep::CPUGenerator>());\n  // std::uniform_int_distribution generates [low, high], but we want [low, high) here\n  CPUUniformIntDistributionImpl<T> impl(low_, high_ - 1);\n  for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = impl(gen->engine()); }\n}\n\n#define INITIATE_CPU_UNIFORM_INT_DISTRIBUTION(T, typeproto)              \\\n  template void UniformIntDistribution<DeviceType::kCPU, T>::operator()( \\\n      ep::Stream* stream, const int64_t elem_cnt, T* dptr,               \\\n      const std::shared_ptr<one::Generator>& generator) const;\n\nOF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_INT_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ)\nOF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_INT_DISTRIBUTION, INT_DATA_TYPE_SEQ)\nOF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_INT_DISTRIBUTION, UNSIGNED_INT_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/uniform_int_distribution.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/user/kernels/distributions/uniform_int_distribution.h\"\n#include \"oneflow/user/kernels/distributions/distribution_template_util.cuh\"\n#include \"oneflow/core/ep/include/device.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename ComputeType>\nstruct UniformIntTransformFunctor {\n  UniformIntTransformFunctor(ComputeType low, ComputeType high) : low(low), high(high) {}\n  __device__ T operator()(ComputeType rand_num) const {\n    if (rand_num == 1.0) { rand_num = 0.0; }\n    return static_cast<T>(static_cast<int64_t>(rand_num * (high - low) + low));\n  }\n  ComputeType low;\n  ComputeType high;\n};\n\ntemplate<typename T>\nvoid UniformIntDistribution<DeviceType::kCUDA, T>::operator()(\n    ep::Stream* stream, const int64_t elem_cnt, T* dptr,\n    const std::shared_ptr<one::Generator>& generator) const {\n  CHECK_GE(elem_cnt, 0);\n  if (elem_cnt == 0) return;\n  const auto device_index = stream->device()->device_index();\n  auto gen = CHECK_JUST(generator->Get<ep::CUDAGenerator>(device_index));\n\n  ep::CudaStream* cuda_stream = stream->As<ep::CudaStream>();\n  auto execution_policy = gen->CalcExecutionPolicy(elem_cnt, cuda_stream);\n\n  auto counter_offset = std::get<0>(execution_policy);\n  auto grid = std::get<1>(execution_policy);\n  auto block = std::get<2>(execution_policy);\n\n  uint64_t seed = gen->current_seed();\n  uint64_t offset = gen->get_philox_offset(counter_offset);\n\n  using ComputeType = typename distribution::DefaultComputeType<T>::type;\n\n  UniformIntTransformFunctor<T, ComputeType> transform_functor(low_, high_);\n\n  if (std::is_same<T, double>::value) {\n    DistributionFunctor<DistributionOp::kUniform2Double> dist_functor;\n    DistributionElementwiseGridStrideKernel<T, ComputeType, 2, decltype(dist_functor),\n                                            decltype(transform_functor)>\n        <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            elem_cnt, seed, offset, dptr, dist_functor, transform_functor);\n  } else {\n    DistributionFunctor<DistributionOp::kUniform4> dist_functor;\n    DistributionElementwiseGridStrideKernel<T, ComputeType, 4, decltype(dist_functor),\n                                            decltype(transform_functor)>\n        <<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            elem_cnt, seed, offset, dptr, dist_functor, transform_functor);\n  }\n}\n\n#define INITIATE_CUDA_UNIFORM_INT_DISTRIBUTION(T, typeproto)              \\\n  template void UniformIntDistribution<DeviceType::kCUDA, T>::operator()( \\\n      ep::Stream* stream, const int64_t elem_cnt, T* dptr,                \\\n      const 
std::shared_ptr<one::Generator>& generator) const;\n\nOF_PP_FOR_EACH_TUPLE(INITIATE_CUDA_UNIFORM_INT_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ)\nOF_PP_FOR_EACH_TUPLE(INITIATE_CUDA_UNIFORM_INT_DISTRIBUTION, INT_DATA_TYPE_SEQ)\nOF_PP_FOR_EACH_TUPLE(INITIATE_CUDA_UNIFORM_INT_DISTRIBUTION, UNSIGNED_INT_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/uniform_int_distribution.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_DISTRIBUTION_H_\n#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_DISTRIBUTION_H_\n\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#ifdef WITH_CUDA\n#include <curand.h>\n#include <curand_kernel.h>\n#endif\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nclass UniformIntDistribution;\n\ntemplate<typename T>\nclass UniformIntDistribution<DeviceType::kCPU, T> final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(UniformIntDistribution);\n  UniformIntDistribution(int64_t low, int64_t high) : low_(low), high_(high) {}\n  ~UniformIntDistribution() = default;\n\n  void operator()(ep::Stream* stream, const int64_t elem_cnt, T* dptr,\n                  const std::shared_ptr<one::Generator>& generator) const;\n\n private:\n  const int64_t low_;\n  const int64_t high_;\n};\n\n#ifdef WITH_CUDA\ntemplate<typename T>\nclass UniformIntDistribution<DeviceType::kCUDA, T> final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(UniformIntDistribution);\n  UniformIntDistribution(int64_t low, int64_t high) : low_(low), high_(high) {}\n  ~UniformIntDistribution() = default;\n\n  void operator()(ep::Stream* stream, const int64_t elem_cnt, T* dptr,\n                  const std::shared_ptr<one::Generator>& generator) const;\n\n private:\n  const int64_t low_;\n  const int64_t high_;\n};\n#endif  // WITH_CUDA\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_DISTRIBUTION_H_\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/uniform_int_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/distributions/uniform_int_kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n#define REGISTER_UNIFORM_KERNEL(device, dtype)              \\\n  REGISTER_USER_KERNEL(\"uniform_int\")                       \\\n      .SetCreateFn<UniformIntKernel<device, dtype>>()       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device) \\\n                       && (user_op::HobAttr<DataType>(\"dtype\") == GetDataType<dtype>::value));\n\nREGISTER_UNIFORM_KERNEL(DeviceType::kCPU, float)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCPU, double)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCPU, uint8_t)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCPU, int8_t)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCPU, int32_t)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCPU, int64_t)\n#ifdef WITH_CUDA\nREGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, float)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, double)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, uint8_t)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, int8_t)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, int32_t)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, int64_t)\n#endif  // WITH_CUDA\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/uniform_int_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_KERNEL_H_\n#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_KERNEL_H_\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#include \"oneflow/user/kernels/distributions/common.h\"\n#include \"oneflow/user/kernels/distributions/uniform_int_distribution.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n// The following algorithm is adopted from pytorch:\n// The purpose of `update_from` and `update_to` is to find the closest valid int64_t number that can\n// be used as actual `from`. The current implementation of `random_` uses uint64_t arithmetics and\n// casts the result to the target dtype(scalar_t). This casting can result in generating numbers\n// that happen to be greater or equal to `to` value. For instance:\n//\n//    auto actual = torch::empty({3, 3}, torch::half);\n//    actual.random_(0, 65504);\n//\n// If random's uint64_t arithmetics produces 65503 as a random value after casting to torch::half it\n// becomes 65504 and violates the requirement that random value must be less than `to`. To resolve\n// this issue `update_from` and `update_to` moves `from` to the right and `to` to the left to the\n// next closest value that won't go outside [from, to) after casting to the target dtype. 
For `to` =\n// 65504 it moves left for (1 << (log2(to) - 11 + 1)) = 32 and becomes 65472, which is the previous\n// available number for torch::half dtype.\ntemplate<typename scalar_t>\nint64_t update_from(int64_t from) {\n  const auto from_plus_1 = static_cast<int64_t>(static_cast<scalar_t>(from + 1));\n  if (from_plus_1 < from) {\n    int64_t from_ = std::abs(from + 1);\n    int n = 0;\n    while (from_ >>= 1) ++n;\n    // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)\n    from = from_plus_1 + (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));\n  }\n  return from;\n}\n\ntemplate<typename scalar_t>\nint64_t update_to(int64_t to) {\n  const auto to_minus_1 = static_cast<int64_t>(static_cast<scalar_t>(to - 1));\n  if (to_minus_1 >= to) {\n    int64_t to_ = std::abs(to - 1);\n    int n = 0;\n    while (to_ >>= 1) ++n;\n    // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)\n    to = to_minus_1 - (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));\n  }\n  return to;\n}\n\ntemplate<DeviceType device_type, typename T>\nclass UniformIntKernel final : public user_op::OpKernel {\n public:\n  UniformIntKernel() = default;\n  ~UniformIntKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const auto& generator = CHECK_JUST(one::MakeAutoGenerator());\n    // When SBP is Split, each rank uses a different seed; otherwise, all ranks use the same seed\n    generator->set_current_seed(\n        CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>(\"seed\"))));\n    return std::make_shared<DistributionKernelState>(generator);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    int64_t from = ctx->Attr<int64_t>(\"from\");\n    int64_t to = ctx->Attr<int64_t>(\"to\");\n    CHECK_LT(from, to) << \"uniform kernel expects 'from' to be less than 'to'\";\n\n    if (IsFloating<T>::value) {\n      from = update_from<T>(from);\n      to = update_to<T>(to);\n      CHECK_LT(from, to) << \"uniform kernel expects 'from' casted to dtype to be less than 'to'\"\n                            \" casted to dtype\";\n    }\n    check_from_to_in_range<T>(from, to - 1);\n    int64_t elem_cnt = out->shape_view().elem_cnt();\n    T* out_dptr = out->mut_dptr<T>();\n    auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);\n    CHECK_NOTNULL(distribution_state);\n    const auto& generator = distribution_state->generator();\n    CHECK_NOTNULL(generator);\n    UniformIntDistribution<device_type, T> distribution(from, to);\n    distribution(ctx->stream(), elem_cnt, out_dptr, generator);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_KERNEL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/uniform_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/distributions/uniform_kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n#define REGISTER_UNIFORM_KERNEL(device, dtype)                                                 \\\n  REGISTER_USER_KERNEL(\"uniform\").SetCreateFn<UniformKernel<device, dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == device)                                                     \\\n      && (user_op::HobAttr<DataType>(\"dtype\") == GetDataType<dtype>::value));\n\nREGISTER_UNIFORM_KERNEL(DeviceType::kCPU, float16)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCPU, float)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCPU, double)\n#ifdef WITH_CUDA\nREGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, half)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, float)\nREGISTER_UNIFORM_KERNEL(DeviceType::kCUDA, double)\n#endif  // WITH_CUDA\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/distributions/uniform_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_KERNEL_H_\n#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_KERNEL_H_\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/distributions/common.h\"\n#include \"oneflow/user/kernels/distributions/uniform_distribution.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<DeviceType device_type, typename T>\nclass UniformKernel final : public user_op::OpKernel {\n public:\n  UniformKernel() = default;\n  ~UniformKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const auto& generator = CHECK_JUST(one::MakeGenerator(device_type));\n    // When SBP is Split, each rank uses a different seeds, otherwise, ranks use the same seed\n    generator->set_current_seed(\n        CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>(\"seed\"))));\n    return std::make_shared<DistributionKernelState>(generator);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const double from = ctx->Attr<double>(\"from\");\n    const double to = ctx->Attr<double>(\"to\");\n    check_from_to_in_range<T>(from, to);\n    int64_t elem_cnt = out->shape_view().elem_cnt();\n    T* out_dptr = out->mut_dptr<T>();\n    auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);\n    CHECK_NOTNULL(distribution_state);\n    const auto& generator = distribution_state->generator();\n    CHECK_NOTNULL(generator);\n    UniformDistribution<device_type, T> distribution(static_cast<T>(from), static_cast<T>(to));\n    distribution(ctx->stream(), elem_cnt, out_dptr, generator);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_KERNEL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/dot_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nusing namespace ep::primitive;\n\ntemplate<typename Context>\nstd::unique_ptr<Matmul> NewMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->data_type();\n  return ep::primitive::NewPrimitive<MatmulFactory>(ctx->device_type(), data_type,\n                                                    BlasTransposeType::N, BlasTransposeType::N);\n}\n\nauto MatmulPrimitiveExists() {\n  return hob::make_custom(\"MatmulPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewMatmulPrimitive(&ctx).operator bool();\n  });\n}\n\nclass DotKernel final : public user_op::OpKernel {\n public:\n  DotKernel() = default;\n  ~DotKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    int64_t n = x->shape_view().elem_cnt();\n    auto primitive = NewMatmulPrimitive(ctx);\n\n    primitive->Launch(ctx->stream(), 1, 1, n, 1, x->dptr(), y->dptr(), 0, out->mut_dptr());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"dot\").SetCreateFn<DotKernel>().SetIsMatchedHob(MatmulPrimitiveExists()\n                                                                     == true);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/dropout_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n#include \"oneflow/user/kernels/dropout_kernel.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n#include \"oneflow/core/ep/include/primitive/add.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nvoid MaskAndScale(ep::Stream* stream, const int64_t n, float scale, const T* x, const bool* mask,\n                  T* y) {\n  for (int64_t i = 0; i < n; ++i) { y[i] = x[i] * static_cast<T>(mask[i]) * scale; }\n}\n\ntemplate<typename T>\nvoid FusedDropoutKernel(ep::Stream* stream, const int64_t elem_cnt,\n                        const std::shared_ptr<ep::CPUGenerator>& cpu_gen, const float rate,\n                        float scale, const T* x, bool* mask, T* y) {\n  /*\n  `uniform_real_distribution` interval is [a, b).\n  And `curand_uniform4` interval is (0, 1.0], so we use > in CUDA and use >= in CPU.\n  */\n  std::uniform_real_distribution<float> random_distribution(GetZeroVal<float>(),\n                                                            GetOneVal<float>());\n  for (int64_t i = 0; i < elem_cnt; ++i) {\n    mask[i] = random_distribution(cpu_gen->engine()) >= rate;\n    y[i] = x[i] * static_cast<T>(mask[i]) * scale;\n  }\n}\n\ntemplate<typename T>\nclass DropoutKernelCPU final : public user_op::OpKernel {\n public:\n  DropoutKernelCPU() = default;\n  ~DropoutKernelCPU() = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const auto& generator = CHECK_JUST(one::MakeGenerator(kCPU));\n    generator->set_current_seed(\n        CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>(\"seed\"))));\n    return std::make_shared<FusedDropoutKernelState>(generator);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const float rate = ctx->Attr<float>(\"rate\");\n    float scale = 0.0f;\n    if (rate < 1.0f) { scale = 1.0f / (1.0f - rate); }\n\n    auto* fused_dropout_kernel_state = dynamic_cast<FusedDropoutKernelState*>(state);\n    CHECK_NOTNULL(fused_dropout_kernel_state);\n    const auto& generator = fused_dropout_kernel_state->generator();\n    CHECK_NOTNULL(generator);\n    std::shared_ptr<ep::CPUGenerator> cpu_generator =\n        CHECK_JUST(generator->Get<ep::CPUGenerator>());\n\n    FusedDropoutKernel<T>(ctx->stream(), in->shape_view().elem_cnt(), cpu_generator, rate, scale,\n                          in->dptr<T>(), mask->mut_dptr<bool>(), out->mut_dptr<T>());\n\n    if 
(ctx->has_input(\"_add_to_output\", 0)) {\n      const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n      CHECK_EQ(add_to_output->data_type(), out->data_type());\n      CHECK_EQ(add_to_output->shape_view(), out->shape_view());\n      std::unique_ptr<ep::primitive::Add> primitive =\n          ep::primitive::NewPrimitive<ep::primitive::AddFactory>(DeviceType::kCPU,\n                                                                 add_to_output->data_type());\n      CHECK(primitive);\n      primitive->Launch(ctx->stream(), out->dptr<T>(), add_to_output->dptr<T>(), out->mut_dptr<T>(),\n                        add_to_output->shape_view().elem_cnt());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_DROPOUT_KERNEL_CPU(dtype)                                                      \\\n  REGISTER_USER_KERNEL(\"dropout\")                                                               \\\n      .SetCreateFn<DropoutKernelCPU<dtype>>()                                                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                           \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value)         \\\n                       && (user_op::HobDataType(\"mask\", 0) == GetDataType<bool>::value))        \\\n      .SetInplaceProposalFn([](const user_op::InferContext&,                                    \\\n                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \\\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"in\", 0, true));                       \\\n        return Maybe<void>::Ok();                                                               \\\n      });\n\nREGISTER_DROPOUT_KERNEL_CPU(float)\nREGISTER_DROPOUT_KERNEL_CPU(double)\n\ntemplate<typename T>\nclass DropoutGradKernelCPU final : public user_op::OpKernel {\n public:\n  DropoutGradKernelCPU() = default;\n  ~DropoutGradKernelCPU() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const float scale = ctx->Attr<float>(\"scale\");\n    MaskAndScale<T>(ctx->stream(), dy->shape_view().elem_cnt(), scale, dy->dptr<T>(),\n                    mask->dptr<bool>(), dx->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_DROPOUT_GRAD_KERNEL_CPU(dtype)                                                 \\\n  REGISTER_USER_KERNEL(\"dropout_grad\")                                                          \\\n      .SetCreateFn<DropoutGradKernelCPU<dtype>>()                                               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                           \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value))         \\\n      .SetInplaceProposalFn([](const user_op::InferContext&,                                    \\\n                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \\\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"dx\", 0, \"dy\", 0, true));                        \\\n        return Maybe<void>::Ok();                                                               \\\n 
     });\n\nREGISTER_DROPOUT_GRAD_KERNEL_CPU(float)\nREGISTER_DROPOUT_GRAD_KERNEL_CPU(double)\n\n}  // namespace\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/dropout_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/device/cuda_pseudo_bfloat16.h\"\n#include \"oneflow/core/ep/include/device.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n#include \"oneflow/user/kernels/dropout_kernel.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr int32_t kVecSize = 4;\nconstexpr int32_t kBlockSize = 256;\n\ntemplate<typename T>\nconstexpr int32_t GetDropoutPackSize() {\n  // For float, bfloat16, half.\n  return 4;\n};\n\ntemplate<>\nconstexpr int32_t GetDropoutPackSize<half2>() {\n  return 2;\n};\n\ntemplate<>\nconstexpr int32_t GetDropoutPackSize<double>() {\n  return 2;\n};\n\nunion RandPack4 {\n  float4 storage;\n  float elem[4];\n};\n\ntemplate<typename T>\nstruct GetPack2Type {\n  using T2 = typename std::aligned_storage<2 * sizeof(T), 2 * sizeof(T)>::type;\n};\n\ntemplate<>\nstruct GetPack2Type<half> {\n  using T2 = half2;\n};\n\n#if CUDA_VERSION >= 11000\ntemplate<>\nstruct GetPack2Type<nv_bfloat16> {\n  using T2 = nv_bfloat162;\n};\n#endif\n\ntemplate<typename T>\nusing Pack2Type = typename GetPack2Type<T>::T2;\n\nusing H2PackType = typename std::aligned_storage<4 * sizeof(half), 4 * sizeof(half)>::type;\n\ntemplate<typename T>\nunion H2Pack {\n  cuda::elementwise::Pack<T, 4> pack_storage;\n  Pack2Type<T> h2[2];\n  __device__ H2Pack() {\n    // do nothing\n  }\n};\n\ntemplate<>\nunion H2Pack<half> {\n  cuda::elementwise::Pack<half, 4> pack_storage;\n  half2 h2[2];\n  __device__ H2Pack() {\n    // do nothing\n  }\n};\n\n#if CUDA_VERSION >= 11000\ntemplate<>\nunion H2Pack<nv_bfloat16> {\n  cuda::elementwise::Pack<nv_bfloat16, 4> pack_storage;\n  nv_bfloat162 h2[2];\n  __device__ H2Pack() {\n    // do nothing\n  }\n};\n#endif\n\ntemplate<typename T>\n__device__ Pack2Type<T> Make2(float v);\n\ntemplate<>\n__device__ Pack2Type<half> Make2<half>(float v) {\n  return __float2half2_rn(v);\n}\n\n#if CUDA_VERSION >= 11000\ntemplate<>\n__device__ Pack2Type<nv_bfloat16> Make2<nv_bfloat16>(float v) {\n  return __float2bfloat162_rn(v);\n}\n#endif\n\n#if CUDA_VERSION >= 11000\n#define RETURN_VOID_IF_HALF                                                                        \\\n  typename std::enable_if_t<(std::is_same<T, half>::value || std::is_same<T, nv_bfloat16>::value), \\\n                            void>\n#else\n#define RETURN_VOID_IF_HALF typename std::enable_if_t<std::is_same<T, half>::value, void>\n#endif\n#define RETURN_VOID_IF_FLOAT typename std::enable_if_t<std::is_same<T, float>::value, void>\n#define RETURN_VOID_IF_DOUBLE typename std::enable_if_t<std::is_same<T, double>::value, void>\n\ntemplate<typename T, int pack_size, bool tail, bool has_addend>\n__global__ 
RETURN_VOID_IF_FLOAT FusedDropoutAddGpu(uint64_t seed, uint64_t offset,\n                                                   const int64_t elem_cnt, float rate, float scale,\n                                                   int64_t n_tail, const T* x, bool* mask,\n                                                   const T* addend, T* y, const T* tail_x,\n                                                   bool* tail_mask, const T* tail_addend,\n                                                   T* tail_y) {\n  int32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;\n  curandStatePhilox4_32_10_t state;\n  curand_init(seed, global_thread_id, offset, &state);\n  using LoadType = cuda::elementwise::PackType<T, pack_size>;\n  using LoadPack = cuda::elementwise::Pack<T, pack_size>;\n  using MaskType = cuda::elementwise::PackType<bool, pack_size>;\n  using MaskPack = cuda::elementwise::Pack<bool, pack_size>;\n\n  T t_scale = static_cast<T>(scale);\n  RandPack4 rand_uniform_pack4;\n  for (int64_t linear_index = global_thread_id * pack_size; linear_index < elem_cnt;\n       linear_index += gridDim.x * blockDim.x * pack_size) {\n    rand_uniform_pack4.storage = curand_uniform4(&state);\n\n    const LoadType* x_load = reinterpret_cast<const LoadType*>(x + linear_index);\n    LoadPack x_vec;\n    x_vec.storage = *x_load;\n\n    LoadPack addend_vec;\n    if (has_addend) {\n      const LoadType* addend_load = reinterpret_cast<const LoadType*>(addend + linear_index);\n      addend_vec.storage = *addend_load;\n    }\n\n    MaskPack mask_vec;\n    LoadPack y_vec;\n#pragma unroll\n    for (int i = 0; i < pack_size; i++) {\n      mask_vec.elem[i] = rand_uniform_pack4.elem[i] > rate;\n      T tmp_float_mask = static_cast<float>(mask_vec.elem[i]);\n      y_vec.elem[i] = x_vec.elem[i] * tmp_float_mask * t_scale;\n      if (has_addend) { y_vec.elem[i] += addend_vec.elem[i]; }\n    }\n\n    *(reinterpret_cast<LoadType*>(y + linear_index)) = y_vec.storage;\n    *(reinterpret_cast<MaskType*>(mask + linear_index)) = mask_vec.storage;\n  }\n\n  if (tail && global_thread_id < n_tail) {\n    const float rand_uniform = curand_uniform(&state);\n    const bool mask_val = rand_uniform > rate;\n    tail_mask[global_thread_id] = mask_val;\n    T tmp_float_mask = static_cast<float>(mask_val);\n    T tmp_tail_out = tail_x[global_thread_id] * tmp_float_mask * t_scale;\n    if (has_addend) { tmp_tail_out += tail_addend[global_thread_id]; }\n    tail_y[global_thread_id] = tmp_tail_out;\n  }\n}\n\ntemplate<typename T, int pack_size, bool tail, bool has_addend>\n__global__ RETURN_VOID_IF_HALF FusedDropoutAddGpu(uint64_t seed, uint64_t offset,\n                                                  const int64_t elem_cnt, float rate, float scale,\n                                                  int64_t n_tail, const T* x, bool* mask,\n                                                  const T* addend, T* y, const T* tail_x,\n                                                  bool* tail_mask, const T* tail_addend,\n                                                  T* tail_y) {\n  int32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;\n  curandStatePhilox4_32_10_t state;\n  curand_init(seed, global_thread_id, offset, &state);\n  using LoadType = cuda::elementwise::PackType<T, pack_size>;\n  using LoadPack = cuda::elementwise::Pack<T, pack_size>;\n  using StoreType = cuda::elementwise::PackType<Pack2Type<T>, pack_size / 2>;\n  using StorePack = cuda::elementwise::Pack<Pack2Type<T>, pack_size / 2>;\n  using MaskType = 
cuda::elementwise::PackType<bool, pack_size>;\n  using MaskPack = cuda::elementwise::Pack<bool, pack_size>;\n\n  RandPack4 rand_uniform_pack4;\n  Pack2Type<T> h2_scale = Make2<T>(scale);\n\n  for (int64_t linear_index = global_thread_id * pack_size; linear_index < elem_cnt;\n       linear_index += gridDim.x * blockDim.x * pack_size) {\n    rand_uniform_pack4.storage = curand_uniform4(&state);\n    const LoadType* x_load = reinterpret_cast<const LoadType*>(x + linear_index);\n    H2Pack<T> x_vec{};\n    x_vec.pack_storage.storage = *x_load;\n\n    H2Pack<T> addend_vec{};\n    if (has_addend) {\n      const LoadType* addend_load = reinterpret_cast<const LoadType*>(addend + linear_index);\n      addend_vec.pack_storage.storage = *addend_load;\n    }\n\n    MaskPack mask_vec;\n    StorePack y_vec;\n    StorePack one_or_zero_h2;\n\n    mask_vec.elem[0] = rand_uniform_pack4.elem[0] > rate;\n    float tmp_float_mask = static_cast<float>(mask_vec.elem[0]);\n    one_or_zero_h2.elem[0].x = tmp_float_mask;\n    mask_vec.elem[1] = rand_uniform_pack4.elem[1] > rate;\n    tmp_float_mask = static_cast<float>(mask_vec.elem[1]);\n    one_or_zero_h2.elem[0].y = tmp_float_mask;\n    y_vec.elem[0] = __hmul2(__hmul2(x_vec.h2[0], one_or_zero_h2.elem[0]), h2_scale);\n\n    mask_vec.elem[2] = rand_uniform_pack4.elem[2] > rate;\n    tmp_float_mask = static_cast<float>(mask_vec.elem[2]);\n    one_or_zero_h2.elem[1].x = tmp_float_mask;\n    mask_vec.elem[3] = rand_uniform_pack4.elem[3] > rate;\n    tmp_float_mask = static_cast<float>(mask_vec.elem[3]);\n    one_or_zero_h2.elem[1].y = tmp_float_mask;\n    y_vec.elem[1] = __hmul2(__hmul2(x_vec.h2[1], one_or_zero_h2.elem[1]), h2_scale);\n\n    if (has_addend) {\n      y_vec.elem[0] = __hadd2(y_vec.elem[0], addend_vec.h2[0]);\n      y_vec.elem[1] = __hadd2(y_vec.elem[1], addend_vec.h2[1]);\n    }\n\n    *(reinterpret_cast<StoreType*>(y + linear_index)) = y_vec.storage;\n    *(reinterpret_cast<MaskType*>(mask + linear_index)) = mask_vec.storage;\n  }\n\n  if (tail && global_thread_id < n_tail) {\n    const float rand_uniform = curand_uniform(&state);\n    const bool mask_val = rand_uniform > rate;\n    tail_mask[global_thread_id] = mask_val;\n    float tmp_half_mask = static_cast<float>(mask_val);\n    T tmp_tail_out = tail_x[global_thread_id] * static_cast<T>(tmp_half_mask) * h2_scale.x;\n    if (has_addend) { tmp_tail_out += tail_addend[global_thread_id]; }\n    tail_y[global_thread_id] = tmp_tail_out;\n  }\n}\n\ntemplate<typename T, int pack_size, bool tail, bool has_addend>\n__global__ RETURN_VOID_IF_DOUBLE FusedDropoutAddGpu(uint64_t seed, uint64_t offset,\n                                                    const int64_t elem_cnt, float rate, float scale,\n                                                    int64_t n_tail, const T* x, bool* mask,\n                                                    const T* addend, T* y, const T* tail_x,\n                                                    bool* tail_mask, const T* tail_addend,\n                                                    T* tail_y) {\n  int32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;\n  curandStatePhilox4_32_10_t state;\n  curand_init(seed, global_thread_id, offset, &state);\n  using LoadType = cuda::elementwise::PackType<T, pack_size>;\n  using LoadPack = cuda::elementwise::Pack<T, pack_size>;\n  using MaskType = cuda::elementwise::PackType<bool, pack_size>;\n  using MaskPack = cuda::elementwise::Pack<bool, pack_size>;\n\n  RandPack4 rand_uniform_pack4;\n  bool grid_loop_rand_state = 
0;\n\n  for (int64_t linear_index = global_thread_id * pack_size; linear_index < elem_cnt;\n       linear_index += gridDim.x * blockDim.x * pack_size) {\n    if (grid_loop_rand_state == 0) {\n      rand_uniform_pack4.storage = curand_uniform4(&state);\n      grid_loop_rand_state ^= 1;\n    } else {\n      // Use the last two random numbers we generated in previous iteration.\n      rand_uniform_pack4.elem[0] = rand_uniform_pack4.elem[2];\n      rand_uniform_pack4.elem[1] = rand_uniform_pack4.elem[3];\n      grid_loop_rand_state ^= 1;\n    }\n    const LoadType* x_load = reinterpret_cast<const LoadType*>(x + linear_index);\n    LoadPack x_vec;\n    x_vec.storage = *x_load;\n\n    LoadPack addend_vec;\n    if (has_addend) {\n      const LoadType* addend_load = reinterpret_cast<const LoadType*>(addend + linear_index);\n      addend_vec.storage = *addend_load;\n    }\n\n    MaskPack mask_vec;\n    LoadPack y_vec;\n#pragma unroll\n    for (int i = 0; i < pack_size; i++) {\n      mask_vec.elem[i] = rand_uniform_pack4.elem[i] > rate;\n      y_vec.elem[i] = x_vec.elem[i] * mask_vec.elem[i] * scale;\n      if (has_addend) { y_vec.elem[i] += addend_vec.elem[i]; }\n    }\n    *(reinterpret_cast<LoadType*>(y + linear_index)) = y_vec.storage;\n    *(reinterpret_cast<MaskType*>(mask + linear_index)) = mask_vec.storage;\n  }\n\n  if (tail && global_thread_id < n_tail) {\n    const float rand_uniform = curand_uniform(&state);\n    const bool mask_val = rand_uniform > rate;\n    tail_mask[global_thread_id] = mask_val;\n    double tmp_tail_out = tail_x[global_thread_id] * mask_val * scale;\n    if (has_addend) { tmp_tail_out += tail_addend[global_thread_id]; }\n    tail_y[global_thread_id] = tmp_tail_out;\n  }\n}\n\nunsigned int ComputeGridSize(ep::Stream* stream, const int32_t block_size, const int64_t elem_cnt) {\n  auto* cuda_stream = stream->As<ep::CudaStream>();\n  const int32_t max_threads_multi_process =\n      cuda_stream->device_properties().maxThreadsPerMultiProcessor;\n  const int32_t multi_processor_count = cuda_stream->device_properties().multiProcessorCount;\n  unsigned int blocks_per_sm = max_threads_multi_process / block_size;\n  unsigned int grid_size = std::max((int64_t)1, ((elem_cnt + block_size - 1) / block_size));\n  grid_size = std::min((unsigned int)multi_processor_count * blocks_per_sm, grid_size);\n  return grid_size;\n}\n\ntemplate<typename T, bool has_addend>\nvoid DispatchTail(ep::Stream* stream, const std::shared_ptr<ep::CUDAGenerator>& cuda_generator,\n                  const int64_t elem_cnt, float rate, float scale, const T* x, bool* mask,\n                  const T* addend, T* y) {\n  constexpr int pack_size = GetDropoutPackSize<T>();\n  const int64_t pack_num = elem_cnt / pack_size;\n  unsigned int grid_size = ComputeGridSize(stream, kBlockSize, pack_num);\n  const int64_t tail_offset = pack_num * pack_size;\n  const int64_t n_tail = elem_cnt - tail_offset;\n  const bool tail = n_tail > 0 ? 
true : false;\n  uint64_t offset = 0;\n  uint64_t seed = cuda_generator->current_seed();\n\n  if (tail) {\n    // If there is a tail, we need to generate random numbers one more time, so we add another `1` here.\n    uint64_t inc_offset = ((elem_cnt - 1) / (kBlockSize * grid_size * kVecSize) + 1) * kVecSize + 1;\n    offset = cuda_generator->get_philox_offset(inc_offset);\n    FusedDropoutAddGpu<T, pack_size, true, has_addend>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            seed, offset, elem_cnt, rate, scale, n_tail, x, mask, addend, y, (x + tail_offset),\n            (mask + tail_offset), (addend + tail_offset), (y + tail_offset));\n  } else {\n    uint64_t inc_offset = ((elem_cnt - 1) / (kBlockSize * grid_size * kVecSize) + 1) * kVecSize;\n    offset = cuda_generator->get_philox_offset(inc_offset);\n    FusedDropoutAddGpu<T, pack_size, false, has_addend>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            seed, offset, elem_cnt, rate, scale, n_tail, x, mask, addend, y, nullptr, nullptr,\n            nullptr, nullptr);\n  }\n}\n\ntemplate<typename T>\nstruct MaskAndScaleFunctor {\n  OF_DEVICE_FUNC explicit MaskAndScaleFunctor(float scale) : scale(scale) {}\n  __device__ T operator()(T x, bool mask) const {\n    return x * static_cast<T>(mask) * static_cast<T>(scale);\n  }\n  float scale;\n};\n\n#if CUDA_VERSION >= 11000\ntemplate<>\nstruct MaskAndScaleFunctor<nv_bfloat16> {\n  OF_DEVICE_FUNC explicit MaskAndScaleFunctor(float scale) : scale(scale) {}\n  __device__ nv_bfloat16 operator()(nv_bfloat16 x, bool mask) const {\n    float float_mask = static_cast<float>(mask);\n    return x * static_cast<nv_bfloat16>(float_mask) * static_cast<nv_bfloat16>(scale);\n  }\n  float scale;\n};\n#endif\n\ntemplate<typename T>\nclass DropoutKernelGPU final : public user_op::OpKernel {\n public:\n  DropoutKernelGPU() = default;\n  ~DropoutKernelGPU() = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCUDA));\n    generator->set_current_seed(\n        CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>(\"seed\"))));\n    return std::make_shared<FusedDropoutKernelState>(generator);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    auto* fused_dropout_kernel_state = dynamic_cast<FusedDropoutKernelState*>(state);\n    CHECK_NOTNULL(fused_dropout_kernel_state);\n    const auto& generator = fused_dropout_kernel_state->generator();\n    CHECK_NOTNULL(generator);\n    auto* stream = ctx->stream();\n    const auto device_index = stream->device()->device_index();\n    std::shared_ptr<ep::CUDAGenerator> cuda_generator =\n        CHECK_JUST(generator->Get<ep::CUDAGenerator>(device_index));\n\n    const float rate = ctx->Attr<float>(\"rate\");\n    float scale = 0.0;\n    if (rate < 1.0f) { scale = 1.0f / (1.0f - rate); }\n\n    if (ctx->has_input(\"_add_to_output\", 0)) {\n      const user_op::Tensor* addend = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n      DispatchTail<T, true>(\n          
stream, cuda_generator, in->shape_view().elem_cnt(), rate, scale,\n          reinterpret_cast<const T*>(in->dptr()), reinterpret_cast<bool*>(mask->mut_dptr()),\n          reinterpret_cast<const T*>(addend->dptr()), reinterpret_cast<T*>(out->mut_dptr()));\n    } else {\n      DispatchTail<T, false>(stream, cuda_generator, in->shape_view().elem_cnt(), rate, scale,\n                             reinterpret_cast<const T*>(in->dptr()),\n                             reinterpret_cast<bool*>(mask->mut_dptr()), nullptr,\n                             reinterpret_cast<T*>(out->mut_dptr()));\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_DROPOUT_KERNEL_GPU(cpp_type, data_type)                                     \\\n  REGISTER_USER_KERNEL(\"dropout\").SetCreateFn<DropoutKernelGPU<cpp_type>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCUDA)                                        \\\n      && (user_op::HobDataType(\"out\", 0) == data_type)                                       \\\n      && (user_op::HobDataType(\"mask\", 0) == GetDataType<bool>::value))\n\nREGISTER_DROPOUT_KERNEL_GPU(half, DataType::kFloat16);\nREGISTER_DROPOUT_KERNEL_GPU(float, DataType::kFloat);\nREGISTER_DROPOUT_KERNEL_GPU(double, DataType::kDouble);\n#if CUDA_VERSION >= 11000\nREGISTER_DROPOUT_KERNEL_GPU(nv_bfloat16, DataType::kBFloat16);\n#endif\n\ntemplate<typename T>\nclass DropoutGradKernelGPU final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  DropoutGradKernelGPU() = default;\n  ~DropoutGradKernelGPU() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const float scale = ctx->Attr<float>(\"scale\");\n    const int64_t elem_cnt = dy->shape_view().elem_cnt();\n    OF_CUDA_CHECK((cuda::elementwise::Binary(\n        MaskAndScaleFunctor<T>(scale), elem_cnt, reinterpret_cast<T*>(dx->mut_dptr()),\n        reinterpret_cast<const T*>(dy->dptr()), reinterpret_cast<const bool*>(mask->dptr()),\n        ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_DROPOUT_GRAD_KERNEL_GPU(cpp_type, data_type)                                   \\\n  REGISTER_USER_KERNEL(\"dropout_grad\")                                                          \\\n      .SetCreateFn<DropoutGradKernelGPU<cpp_type>>()                                            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                          \\\n                       && (user_op::HobDataType(\"dx\", 0) == data_type))                         \\\n      .SetInplaceProposalFn([](const user_op::InferContext&,                                    \\\n                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \\\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"dx\", 0, \"dy\", 0, true));                        \\\n        return Maybe<void>::Ok();                                                               \\\n      })\n\nREGISTER_DROPOUT_GRAD_KERNEL_GPU(half, DataType::kFloat16);\nREGISTER_DROPOUT_GRAD_KERNEL_GPU(float, DataType::kFloat);\nREGISTER_DROPOUT_GRAD_KERNEL_GPU(double, 
DataType::kDouble);\n#if CUDA_VERSION >= 11000\nREGISTER_DROPOUT_GRAD_KERNEL_GPU(nv_bfloat16, DataType::kBFloat16);\n#endif\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/dropout_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_DROPOUT_KERNEL_H_\n#define ONEFLOW_USER_KERNELS_DROPOUT_KERNEL_H_\n\n#include \"oneflow/user/kernels/random_mask_generator.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nclass FusedDropoutKernelState : public user_op::OpKernelState {\n public:\n  explicit FusedDropoutKernelState(const std::shared_ptr<one::Generator>& generator)\n      : generator_(generator) {}\n\n  const std::shared_ptr<one::Generator>& generator() const { return generator_; }\n\n private:\n  std::shared_ptr<one::Generator> generator_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_DROPOUT_KERNEL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/dynamic_loss_scale_schedule_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nclass DynamicLossScaleScheduleCpuKernel final : public user_op::OpKernel {\n public:\n  DynamicLossScaleScheduleCpuKernel() = default;\n  ~DynamicLossScaleScheduleCpuKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* count_not_finite =\n        ctx->Tensor4ArgNameAndIndex(\"count_not_finite\", 0)->dptr<int64_t>();\n    auto* loss_scale = ctx->Tensor4ArgNameAndIndex(\"loss_scale\", 0)->mut_dptr<float>();\n    auto* good_step_counter =\n        ctx->Tensor4ArgNameAndIndex(\"good_step_counter\", 0)->mut_dptr<int64_t>();\n    const auto increment_period = ctx->Attr<int64_t>(\"increment_period\");\n    const auto multiplier = ctx->Attr<float>(\"multiplier\");\n    if (*count_not_finite == 0) {\n      int64_t cur_good_step_counter = *good_step_counter + 1;\n      if (cur_good_step_counter >= increment_period) {\n        const double old_loss_scale = *loss_scale;\n        const double new_loss_scale =\n            std::min(old_loss_scale * multiplier, static_cast<double>(FLT_MAX));\n        *loss_scale = static_cast<float>(new_loss_scale);\n        cur_good_step_counter = 0;\n        LOG(INFO) << \"In past \" << increment_period\n                  << \" steps, there are no nan or inf in gradients, so we increase loss_scale from \"\n                  << old_loss_scale << \" to \" << new_loss_scale;\n      }\n      *good_step_counter = cur_good_step_counter;\n    } else {\n      *good_step_counter = 0;\n      const double old_loss_scale = *loss_scale;\n      const double new_loss_scale = std::max(old_loss_scale / multiplier, 1.0);\n      *loss_scale = static_cast<float>(new_loss_scale);\n      LOG(INFO) << \"There are nan or inf in gradients, so we decrease loss_scale from \"\n                << old_loss_scale << \" to \" << new_loss_scale;\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\nREGISTER_USER_KERNEL(\"dynamic_loss_scale_schedule\")\n    .SetCreateFn<DynamicLossScaleScheduleCpuKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/dynamic_loss_scale_schedule_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n__global__ void DynamicLossScaleScheduleGpu(const int64_t increment_period, const float multiplier,\n                                            const int64_t* count_not_finite, float* loss_scale,\n                                            int64_t* good_step_counter) {\n  if (*count_not_finite == 0) {\n    int64_t cur_good_step_counter = *good_step_counter + 1;\n    if (cur_good_step_counter >= increment_period) {\n      *loss_scale = static_cast<float>(\n          min(static_cast<double>(*loss_scale) * multiplier, static_cast<double>(FLT_MAX)));\n      cur_good_step_counter = 0;\n    }\n    *good_step_counter = cur_good_step_counter;\n  } else {\n    *good_step_counter = 0;\n    *loss_scale = static_cast<float>(max(static_cast<double>(*loss_scale) / multiplier, 1.0));\n  }\n}\n\n}  // namespace\n\nclass DynamicLossScaleScheduleGpuKernel final : public user_op::OpKernel {\n public:\n  DynamicLossScaleScheduleGpuKernel() = default;\n  ~DynamicLossScaleScheduleGpuKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* count_not_finite = ctx->Tensor4ArgNameAndIndex(\"count_not_finite\", 0);\n    user_op::Tensor* loss_scale = ctx->Tensor4ArgNameAndIndex(\"loss_scale\", 0);\n    user_op::Tensor* good_step_counter = ctx->Tensor4ArgNameAndIndex(\"good_step_counter\", 0);\n    const auto increment_period = ctx->Attr<int64_t>(\"increment_period\");\n    const auto multiplier = ctx->Attr<float>(\"multiplier\");\n    DynamicLossScaleScheduleGpu<<<1, 1, 0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        increment_period, multiplier, count_not_finite->dptr<int64_t>(),\n        loss_scale->mut_dptr<float>(), good_step_counter->mut_dptr<int64_t>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\nREGISTER_USER_KERNEL(\"dynamic_loss_scale_schedule\")\n    .SetCreateFn<DynamicLossScaleScheduleGpuKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/eager_b_to_s_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/communicate_util.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/register/tensor_slice_copier.h\"\n#include \"oneflow/core/framework/placement_sbp_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<Symbol<NdSbp>> GetAllSplitNdSbp(int64_t axis, int64_t ndim) {\n  NdSbp split_nd_sbp;\n  for (int64_t i = 0; i < ndim; ++i) {\n    split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_split_parallel()->set_axis(axis);\n  }\n  return SymbolOf(split_nd_sbp);\n}\n\nauto* CachedGetAllSplitNdSbp = DECORATE(&GetAllSplitNdSbp, ThreadLocal);\n\nMaybe<Symbol<NdSbp>> GetAllBroadcastNdSbp(int64_t ndim) {\n  NdSbp split_nd_sbp;\n  for (int64_t i = 0; i < ndim; ++i) {\n    split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();\n  }\n  return SymbolOf(split_nd_sbp);\n}\n\nauto* CachedGetAllBroadcastNdSbp = DECORATE(&GetAllBroadcastNdSbp, ThreadLocal);\n\nclass EagerBToSOpKernelCache final : public user_op::OpKernelCache {\n public:\n  explicit EagerBToSOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); }\n  ~EagerBToSOpKernelCache() override = default;\n\n  const std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>&\n  sorted_elem_cnt2in_tensor_slice_copier_pair() const {\n    return sorted_elem_cnt2in_tensor_slice_copier_pair_;\n  }\n\n  const std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>&\n  sorted_elem_cnt2out_tensor_slice_copier_pair() const {\n    return sorted_elem_cnt2out_tensor_slice_copier_pair_;\n  }\n\n  const std::vector<std::pair<int64_t, int64_t>>& sorted_p2p_pair() const {\n    return sorted_p2p_pair_;\n  }\n\n private:\n  void Init(user_op::KernelCacheContext* ctx) {\n    const std::string& in_parallel_conf_txt = ctx->Attr<std::string>(\"in_parallel_conf\");\n    const std::string& out_parallel_conf_txt = ctx->Attr<std::string>(\"out_parallel_conf\");\n    const int64_t out_split_axis = ctx->Attr<int64_t>(\"out_split_axis\");\n    const Shape& shape = ctx->Attr<Shape>(\"shape\");\n    DeviceType device_type = ctx->device_type();\n    DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n    Symbol<ParallelDesc> in_parallel_desc = CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt));\n    Symbol<ParallelDesc> out_parallel_desc =\n        CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt));\n    int64_t out_parallel_num = out_parallel_desc->parallel_num();\n\n    for (int64_t out_parallel_id = 0; out_parallel_id < out_parallel_num; ++out_parallel_id) {\n      int64_t dst = 
CHECK_JUST(out_parallel_desc->MachineId4ParallelId(out_parallel_id));\n      int64_t src = -1;\n      const TensorSliceView& out_slice = GetTensorSliceView4ParallelId(\n          *out_parallel_desc->hierarchy(),\n          *CHECK_JUST(\n              CachedGetAllSplitNdSbp(out_split_axis, out_parallel_desc->hierarchy()->NumAxes())),\n          shape, out_parallel_id);\n      CHECK(!out_slice.IsEmpty());\n      TensorSliceView in_slice;\n      TensorSliceView intersection;\n      {\n        if (in_parallel_desc->ContainingMachineId(dst)) {\n          src = dst;\n          int64_t src_device_id = GlobalProcessCtx::LocalRank(src);\n          int64_t in_parallel_id =\n              CHECK_JUST(in_parallel_desc->ParallelId4MachineDeviceId(src, src_device_id));\n          in_slice = GetTensorSliceView4ParallelId(\n              *in_parallel_desc->hierarchy(),\n              *CHECK_JUST(CachedGetAllBroadcastNdSbp(in_parallel_desc->hierarchy()->NumAxes())),\n              shape, in_parallel_id);\n          // copy to out_slice from in_slice if src == dst\n          intersection = out_slice;\n        } else {\n          int64_t in_parallel_num = in_parallel_desc->parallel_num();\n          int64_t in_parallel_id = out_parallel_id % in_parallel_num;\n          src = CHECK_JUST(in_parallel_desc->MachineId4ParallelId(in_parallel_id));\n          in_slice = GetTensorSliceView4ParallelId(\n              *in_parallel_desc->hierarchy(),\n              *CHECK_JUST(CachedGetAllBroadcastNdSbp(in_parallel_desc->hierarchy()->NumAxes())),\n              shape, in_parallel_id);\n          intersection = out_slice.Intersect(in_slice);\n        }\n      }\n      CHECK_NE(src, -1);\n      CHECK(!in_slice.IsEmpty());\n      CHECK(!intersection.IsEmpty());\n      sorted_p2p_pair_.emplace_back(std::make_pair(src, dst));\n      sorted_elem_cnt2in_tensor_slice_copier_pair_.emplace_back(std::make_pair(\n          intersection.shape().elem_cnt(),\n          std::make_shared<TensorSliceCopier>(intersection, in_slice, data_type, device_type)));\n      sorted_elem_cnt2out_tensor_slice_copier_pair_.emplace_back(std::make_pair(\n          intersection.shape().elem_cnt(),\n          std::make_shared<TensorSliceCopier>(out_slice, intersection, data_type, device_type)));\n    }\n  }\n\n  std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>\n      sorted_elem_cnt2in_tensor_slice_copier_pair_;\n  std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>\n      sorted_elem_cnt2out_tensor_slice_copier_pair_;\n  std::vector<std::pair<int64_t, int64_t>> sorted_p2p_pair_;\n};\n\nsize_t InferEagerBToSKernelTmpBufferSize(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  Shape shape = ctx->Attr<Shape>(\"shape\");\n  const int64_t out_split_axis = ctx->Attr<int64_t>(\"out_split_axis\");\n  const std::string& out_parallel_conf_txt = ctx->Attr<std::string>(\"out_parallel_conf\");\n  Symbol<ParallelDesc> out_parallel_desc = CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt));\n  int64_t out_parallel_num = out_parallel_desc->parallel_num();\n  if (out_parallel_num > 1) {\n    CHECK_LT(out_split_axis, shape.NumAxes());\n    BalancedSplitter bs(shape.At(out_split_axis), out_parallel_num);\n    shape.Set(out_split_axis, bs.At(0).size());\n  }\n  size_t tensor_byte_size = shape.elem_cnt() * GetSizeOfDataType(in_tensor.data_type());\n  return tensor_byte_size;\n}\n\n}  // namespace\n\nclass EagerBToSKernel final : public user_op::OpKernel {\n public:\n  EagerBToSKernel() 
= default;\n  ~EagerBToSKernel() override = default;\n\n  void InitOpKernelCacheWithFlags(\n      user_op::KernelCacheContext* ctx, int8_t flag,\n      std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {\n    if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared<EagerBToSOpKernelCache>(ctx); }\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    auto* kernel_cache = dynamic_cast<const EagerBToSOpKernelCache*>(cache);\n    CHECK(kernel_cache != nullptr);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const void* in_ptr = in->dptr();\n    void* out_ptr = out->mut_dptr();\n    void* tmp_buffer_ptr = tmp_buffer->mut_dptr();\n\n    const auto& sorted_elem_cnt2in_tensor_slice_copier_pair =\n        kernel_cache->sorted_elem_cnt2in_tensor_slice_copier_pair();\n    const auto& sorted_elem_cnt2out_tensor_slice_copier_pair =\n        kernel_cache->sorted_elem_cnt2out_tensor_slice_copier_pair();\n    const auto& sorted_p2p_pair = kernel_cache->sorted_p2p_pair();\n    CHECK_EQ(sorted_elem_cnt2in_tensor_slice_copier_pair.size(), sorted_p2p_pair.size());\n    CHECK_EQ(sorted_elem_cnt2out_tensor_slice_copier_pair.size(), sorted_p2p_pair.size());\n\n    DeviceType device_type = ctx->device_type();\n\n    for (int64_t i = 0; i < sorted_p2p_pair.size(); ++i) {\n      const auto& p2p_pair = sorted_p2p_pair.at(i);\n      int64_t src = p2p_pair.first;\n      int64_t dst = p2p_pair.second;\n      if (src == dst && src == GlobalProcessCtx::Rank()) {\n        const auto& elem_cnt2tensor_slice_copier_pair =\n            sorted_elem_cnt2in_tensor_slice_copier_pair.at(i);\n        const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second;\n        tensor_slice_copier->Copy(ctx->stream(), out_ptr, in_ptr);\n        continue;\n      }\n      if (GlobalProcessCtx::Rank() == src) {\n        const auto& elem_cnt2tensor_slice_copier_pair =\n            sorted_elem_cnt2in_tensor_slice_copier_pair.at(i);\n        const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first;\n        const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second;\n        tensor_slice_copier->Copy(ctx->stream(), tmp_buffer_ptr, in_ptr);\n        CHECK_JUST(Send(reinterpret_cast<const void*>(tmp_buffer_ptr), elem_cnt, in->data_type(),\n                        dst, device_type, ctx->stream()));\n      }\n      if (GlobalProcessCtx::Rank() == dst) {\n        const auto& elem_cnt2tensor_slice_copier_pair =\n            sorted_elem_cnt2out_tensor_slice_copier_pair.at(i);\n        const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first;\n        const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second;\n        CHECK_JUST(\n            Recv(tmp_buffer_ptr, elem_cnt, out->data_type(), src, device_type, ctx->stream()));\n        tensor_slice_copier->Copy(ctx->stream(), out_ptr,\n                                  reinterpret_cast<const void*>(tmp_buffer_ptr));\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"eager_b_to_s\")\n    .SetCreateFn<EagerBToSKernel>()\n    .SetIsMatchedHob(HobIsSendAndRecvRegistered())\n    .SetInferTmpSizeFn(InferEagerBToSKernelTmpBufferSize);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/eager_ccl_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/communicate_util.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/user/kernels/collective_communication/include/communication_context.h\"\n#include \"oneflow/user/kernels/collective_communication/include/all_reduce.h\"\n#include \"oneflow/user/kernels/collective_communication/include/reduce_scatter.h\"\n#include \"oneflow/user/kernels/collective_communication/include/all_gather.h\"\n#include \"oneflow/user/kernels/collective_communication/include/reduce.h\"\n#include \"oneflow/user/kernels/collective_communication/include/broadcast.h\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nauto AllReduceCollectiveCommunicationExists() {\n  return hob::make_custom(\"AllReduceCollectiveCommunicationExists\",\n                          [=](const user_op::KernelRegContext& ctx) {\n                            DeviceType device_type = ctx.device_type();\n                            return ccl::IsCommunicationContextRegistered(device_type)\n                                   && ccl::IsAllReduceRegistered(device_type);\n                          });\n}\n\nauto ReduceScatterCollectiveCommunicationExists() {\n  return hob::make_custom(\"ReduceScatterCollectiveCommunicationExists\",\n                          [=](const user_op::KernelRegContext& ctx) {\n                            DeviceType device_type = ctx.device_type();\n                            return ccl::IsCommunicationContextRegistered(device_type)\n                                   && ccl::IsReduceScatterRegistered(device_type);\n                          });\n}\n\nauto AllGatherCollectiveCommunicationExists() {\n  return hob::make_custom(\"AllGatherCollectiveCommunicationExists\",\n                          [=](const user_op::KernelRegContext& ctx) {\n                            DeviceType device_type = ctx.device_type();\n                            return ccl::IsCommunicationContextRegistered(device_type)\n                                   && ccl::IsAllGatherRegistered(device_type);\n                          });\n}\n\nauto ReduceCollectiveCommunicationExists() {\n  return hob::make_custom(\"ReduceCollectiveCommunicationExists\",\n                          [=](const user_op::KernelRegContext& ctx) {\n                            DeviceType device_type = ctx.device_type();\n                            return ccl::IsCommunicationContextRegistered(device_type)\n                                   && ccl::IsReduceRegistered(device_type);\n                          });\n}\n\nauto BroadcastCollectiveCommunicationExists() {\n  return hob::make_custom(\"BroadcastCollectiveCommunicationExists\",\n                          [=](const user_op::KernelRegContext& ctx) {\n                            DeviceType device_type = ctx.device_type();\n                            return 
ccl::IsCommunicationContextRegistered(device_type)\n                                   && ccl::IsBroadcastRegistered(device_type);\n                          });\n}\n\nclass EagerCclOpKernelCache final : public user_op::OpKernelCache {\n public:\n  explicit EagerCclOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); }\n  ~EagerCclOpKernelCache() override = default;\n\n  const std::shared_ptr<ccl::CommunicationContext>& communication_ctx() const {\n    return communication_ctx_;\n  }\n\n private:\n  void Init(user_op::KernelCacheContext* ctx) {\n    const std::string& parallel_conf_txt = ctx->Attr<std::string>(\"parallel_conf\");\n    ParallelConf parallel_conf;\n    CHECK(TxtString2PbMessage(parallel_conf_txt, &parallel_conf));\n    Symbol<ParallelDesc> parallel_desc = SymbolOf(ParallelDesc(parallel_conf));\n    communication_ctx_ = ccl::NewCommunicationContext(parallel_desc->device_type(), parallel_desc);\n  }\n\n  std::shared_ptr<ccl::CommunicationContext> communication_ctx_;\n};\n\nvoid InitEagerCclOpKernelCache(user_op::KernelCacheContext* ctx,\n                               std::shared_ptr<user_op::OpKernelCache>* cache_ptr) {\n  // NOTE(jianhao): the cache only depends on parallel_conf, and the kernel is singleton\n  // once parallel_conf is determined, so only init the cache at the first time.\n  if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared<EagerCclOpKernelCache>(ctx); }\n}\n\n}  // namespace\n\nclass EagerCclAllReduceKernel final : public user_op::OpKernel {\n public:\n  EagerCclAllReduceKernel() = default;\n  ~EagerCclAllReduceKernel() override = default;\n\n  void InitOpKernelCacheWithFlags(\n      user_op::KernelCacheContext* ctx, int8_t flag,\n      std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {\n    InitEagerCclOpKernelCache(ctx, cache_ptr);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    auto* kernel_cache = dynamic_cast<const EagerCclOpKernelCache*>(cache);\n    CHECK(kernel_cache != nullptr);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(in->shape_view(), out->shape_view()) << kOfBugIssueUploadPrompt;\n    CHECK_EQ(in->data_type(), out->data_type()) << kOfBugIssueUploadPrompt;\n\n    ccl::ReduceType reduce_type = ccl::kSum;\n    if (in->data_type() == kBool) { reduce_type = ccl::kMax; }\n\n    std::unique_ptr<ccl::AllReduce> all_reduce = ccl::NewCollectiveCommunication<ccl::AllReduce>(\n        ctx->device_type(), in->data_type(), reduce_type);\n    all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), out->shape_view().elem_cnt(),\n                       kernel_cache->communication_ctx());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"eager_ccl_all_reduce\")\n    .SetCreateFn<EagerCclAllReduceKernel>()\n    .SetIsMatchedHob(AllReduceCollectiveCommunicationExists());\n\nclass EagerCclReduceScatterKernel final : public user_op::OpKernel {\n public:\n  EagerCclReduceScatterKernel() = default;\n  ~EagerCclReduceScatterKernel() override = default;\n\n  void InitOpKernelCacheWithFlags(\n      user_op::KernelCacheContext* ctx, int8_t flag,\n      std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {\n    InitEagerCclOpKernelCache(ctx, cache_ptr);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void 
Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    auto* kernel_cache = dynamic_cast<const EagerCclOpKernelCache*>(cache);\n    CHECK(kernel_cache != nullptr) << kOfBugIssueUploadPrompt;\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(in->data_type(), out->data_type()) << kOfBugIssueUploadPrompt;\n    const auto& op_type = ctx->Attr<std::string>(\"op_type\");\n    CHECK_EQ(op_type, \"sum\") << kOfBugIssueUploadPrompt;\n    ccl::ReduceType reduce_type = ccl::kSum;\n    if (in->data_type() == kBool) { reduce_type = ccl::kMax; }\n    std::unique_ptr<ccl::ReduceScatter> reduce_scatter =\n        ccl::NewCollectiveCommunication<ccl::ReduceScatter>(ctx->device_type(), in->data_type(),\n                                                            reduce_type);\n    reduce_scatter->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), out->shape_view().elem_cnt(),\n                           kernel_cache->communication_ctx());\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"eager_ccl_reduce_scatter\")\n    .SetCreateFn<EagerCclReduceScatterKernel>()\n    .SetIsMatchedHob(ReduceScatterCollectiveCommunicationExists());\n\nclass EagerCclAllGatherKernel final : public user_op::OpKernel {\n public:\n  EagerCclAllGatherKernel() = default;\n  ~EagerCclAllGatherKernel() override = default;\n\n  void InitOpKernelCacheWithFlags(\n      user_op::KernelCacheContext* ctx, int8_t flag,\n      std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {\n    InitEagerCclOpKernelCache(ctx, cache_ptr);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    auto* kernel_cache = dynamic_cast<const EagerCclOpKernelCache*>(cache);\n    CHECK(kernel_cache != nullptr) << kOfBugIssueUploadPrompt;\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(in->data_type(), out->data_type()) << kOfBugIssueUploadPrompt;\n    std::unique_ptr<ccl::AllGather> all_gather =\n        ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->device_type(), in->data_type());\n    all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),\n                       kernel_cache->communication_ctx());\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"eager_ccl_all_gather\")\n    .SetCreateFn<EagerCclAllGatherKernel>()\n    .SetIsMatchedHob(AllGatherCollectiveCommunicationExists());\n\nclass EagerCclReduceKernel final : public user_op::OpKernel {\n public:\n  EagerCclReduceKernel() = default;\n  ~EagerCclReduceKernel() override = default;\n\n  void InitOpKernelCacheWithFlags(\n      user_op::KernelCacheContext* ctx, int8_t flag,\n      std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {\n    InitEagerCclOpKernelCache(ctx, cache_ptr);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    auto* kernel_cache = dynamic_cast<const EagerCclOpKernelCache*>(cache);\n    
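// The kernel cache is built once per parallel_conf in InitEagerCclOpKernelCache, so a null\n    // kernel_cache here indicates a cached object of an unexpected type.\n    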
CHECK(kernel_cache != nullptr);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    int64_t root = ctx->Attr<int64_t>(\"root\");\n    void* out_ptr = out->mut_dptr();\n    if (GlobalProcessCtx::Rank() == root) {\n      CHECK_EQ(in->shape_view(), out->shape_view());\n      CHECK_EQ(in->data_type(), out->data_type());\n    }\n    if (out_ptr != nullptr) {\n      CHECK_EQ(in->shape_view(), out->shape_view());\n      CHECK_EQ(in->data_type(), out->data_type());\n    }\n\n    ccl::ReduceType reduce_type = ccl::kSum;\n    if (in->data_type() == kBool) { reduce_type = ccl::kMax; }\n\n    std::unique_ptr<ccl::Reduce> reduce = ccl::NewCollectiveCommunication<ccl::Reduce>(\n        ctx->device_type(), in->data_type(), reduce_type);\n    reduce->Launch(ctx->stream(), in->dptr(), out_ptr, in->shape_view().elem_cnt(), root,\n                   kernel_cache->communication_ctx());\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"eager_ccl_reduce\")\n    .SetCreateFn<EagerCclReduceKernel>()\n    .SetIsMatchedHob(ReduceCollectiveCommunicationExists());\n\nclass EagerCclBroadcastKernel final : public user_op::OpKernel {\n public:\n  EagerCclBroadcastKernel() = default;\n  ~EagerCclBroadcastKernel() override = default;\n\n  void InitOpKernelCacheWithFlags(\n      user_op::KernelCacheContext* ctx, int8_t flag,\n      std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {\n    InitEagerCclOpKernelCache(ctx, cache_ptr);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache* cache) const override {\n    size_t size = ctx->input_size(\"in\");\n    CHECK_EQ(size, ctx->output_size(\"out\"));\n    for (int i = 0; i < size; ++i) { ComputeForOneInput(ctx, cache, i); }\n  }\n  void ComputeForOneInput(user_op::KernelComputeContext* ctx, const user_op::OpKernelCache* cache,\n                          int index) const {\n    auto* kernel_cache = dynamic_cast<const EagerCclOpKernelCache*>(cache);\n    CHECK(kernel_cache != nullptr);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", index);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", index);\n    int64_t root = ctx->Attr<int64_t>(\"root\");\n    const void* in_ptr = in->dptr();\n    if (GlobalProcessCtx::Rank() == root) {\n      CHECK_EQ(in->shape_view(), out->shape_view());\n      CHECK_EQ(in->data_type(), out->data_type());\n    }\n    if (in_ptr != nullptr) {\n      CHECK_EQ(in->shape_view(), out->shape_view());\n      CHECK_EQ(in->data_type(), out->data_type());\n    }\n\n    std::unique_ptr<ccl::Broadcast> broadcast =\n        ccl::NewCollectiveCommunication<ccl::Broadcast>(ctx->device_type(), out->data_type());\n    broadcast->Launch(ctx->stream(), in_ptr, out->mut_dptr(), out->shape_view().elem_cnt(), root,\n                      kernel_cache->communication_ctx());\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"eager_ccl_broadcast\")\n    .SetCreateFn<EagerCclBroadcastKernel>()\n    .SetIsMatchedHob(BroadcastCollectiveCommunicationExists());\n\nclass EagerCclTouchKernel final : public user_op::OpKernel {\n public:\n  EagerCclTouchKernel() = default;\n  ~EagerCclTouchKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, 
user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    // Do nothing.\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\nREGISTER_USER_KERNEL(\"eager_ccl_touch\")\n    .SetCreateFn<EagerCclTouchKernel>()\n    .SetIsMatchedHob(!(user_op::HobDeviceType() == DeviceType::kInvalidDevice)\n                     && !(user_op::HobDeviceType() == DeviceType::kMockDevice));\n\nnamespace {\n\nclass EagerCclS2SCpuOpKernelCache final : public user_op::OpKernelCache {\n public:\n  explicit EagerCclS2SCpuOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); }\n  ~EagerCclS2SCpuOpKernelCache() override = default;\n\n  Symbol<ParallelDesc> parallel_desc() const { return parallel_desc_; }\n\n private:\n  void Init(user_op::KernelCacheContext* ctx) {\n    const std::string& parallel_conf_txt = ctx->Attr<std::string>(\"parallel_conf\");\n    ParallelConf parallel_conf;\n    CHECK(TxtString2PbMessage(parallel_conf_txt, &parallel_conf));\n    parallel_desc_ = SymbolOf(ParallelDesc(parallel_conf));\n  }\n\n  Symbol<ParallelDesc> parallel_desc_;\n};\n\nsize_t InferEagerCclS2SCpuKernelTmpBufferSize(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  size_t tensor_byte_size = in_tensor.shape().elem_cnt() * GetSizeOfDataType(in_tensor.data_type());\n  // NOTE(hanbinbin): Set tmp_buffer_size to twice tensor_byte_size because the\n  // SbpParallel4ArgNameAndIndex function of LocalUserOpInferContext is unimplemented\n  return tensor_byte_size * 2;\n}\n\nMaybe<std::vector<std::pair<int64_t, int64_t>>> RawGroupP2PPair(\n    Symbol<ParallelDesc> parallel_desc) {\n  std::shared_ptr<std::vector<std::pair<int64_t, int64_t>>> p2p_pairs =\n      std::make_shared<std::vector<std::pair<int64_t, int64_t>>>();\n  for (int64_t src : parallel_desc->sorted_machine_ids()) {\n    for (int64_t dst : parallel_desc->sorted_machine_ids()) {\n      p2p_pairs->emplace_back(std::make_pair(src, dst));\n    }\n  }\n  return p2p_pairs;\n}\n\nstatic constexpr auto* GroupP2PPair = DECORATE(&RawGroupP2PPair, ThreadLocal);\n\n}  // namespace\n\ntemplate<typename T>\nclass EagerCclS2SCPUKernel final : public user_op::OpKernel {\n public:\n  EagerCclS2SCPUKernel() = default;\n  ~EagerCclS2SCPUKernel() override = default;\n\n  void InitOpKernelCacheWithFlags(\n      user_op::KernelCacheContext* ctx, int8_t flag,\n      std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {\n    // NOTE(jianhao): the cache only depends on parallel_conf, and the kernel is singleton\n    // once parallel_conf is determined, so only init the cache at the first time.\n    if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared<EagerCclS2SCpuOpKernelCache>(ctx); }\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    auto* kernel_cache = dynamic_cast<const EagerCclS2SCpuOpKernelCache*>(cache);\n    CHECK(kernel_cache != nullptr);\n    // NOTE(hanbinbin): Compute logic copy from _nccl_logical_s2s\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const int64_t dtype_size = GetSizeOfDataType(in->data_type());\n    int64_t data_size = in->shape_view().elem_cnt() * dtype_size;\n    // NOTE: in 
(transpose)-> pack_to_ptr (all2all)-> unpack_from_ptr (transpose)-> out\n    const char* pack_to_ptr = in->dptr<char>();\n    char* unpack_from_ptr = out->mut_dptr<char>();\n    int64_t tmp_size = tmp_buffer->shape_view().elem_cnt();\n    CHECK_EQ(tmp_size, data_size * 2);\n\n    CHECK_EQ(in->data_type(), out->data_type());\n    const int64_t num_ranks = kernel_cache->parallel_desc()->parallel_num();\n    CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt())\n        << in->shape_view().ToString() << \" vs \" << out->shape_view().ToString();\n    const int64_t elem_cnt = in->shape_view().elem_cnt();\n    const int64_t in_split_axis = ctx->Attr<int64_t>(\"in_split_axis\");\n    const int64_t out_split_axis = ctx->Attr<int64_t>(\"out_split_axis\");\n\n    DimVector logical_shape_dim_vec;\n    in->shape_view().ToDimVector(&logical_shape_dim_vec);\n    logical_shape_dim_vec[in_split_axis] = logical_shape_dim_vec.at(in_split_axis) * num_ranks;\n\n    if (out_split_axis != 0) {\n      // Do pack. Need transpose in -> pack_to\n      // pack use temp buffer offset: [0, data_size]\n      pack_to_ptr = tmp_buffer->dptr<char>();\n      DimVector transpose_in_dim_vec = logical_shape_dim_vec;\n      CHECK_EQ(transpose_in_dim_vec.at(in_split_axis) % num_ranks, 0);\n      transpose_in_dim_vec[in_split_axis] = transpose_in_dim_vec.at(in_split_axis) / num_ranks;\n      CHECK_EQ(transpose_in_dim_vec.at(out_split_axis) % num_ranks, 0);\n      transpose_in_dim_vec[out_split_axis] = transpose_in_dim_vec.at(out_split_axis) / num_ranks;\n      transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + out_split_axis, num_ranks);\n      std::vector<int32_t> perm;\n      perm.emplace_back(out_split_axis);\n      FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) {\n        if (i != out_split_axis) { perm.emplace_back(i); }\n      }\n      auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n          ctx->stream()->device_type(), transpose_in_dim_vec.size());\n      CHECK(transpose);\n      transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(),\n                        transpose_in_dim_vec.data(), in->dptr(), perm.data(),\n                        tmp_buffer->mut_dptr());\n    }\n\n    if (in_split_axis != 0) {\n      // Do unpack. 
Need transpose unpack_from -> out\n      // unpack use temp buffer offset: [tmp_size - data_size, tmp_size]\n      unpack_from_ptr = tmp_buffer->mut_dptr<char>() + (tmp_size - data_size);\n    }\n\n    {\n      // NOTE: Do S2S\n      const int64_t elem_per_chunk = elem_cnt / num_ranks;\n      const int64_t chunk_size = elem_per_chunk * dtype_size;\n      const auto& p2p_pairs = CHECK_JUST(GroupP2PPair(kernel_cache->parallel_desc()));\n      for (const auto& pair : *p2p_pairs) {\n        int64_t src = pair.first;\n        int64_t dst = pair.second;\n\n        if (GlobalProcessCtx::Rank() == src) {\n          Symbol<ParallelDesc> parallel_desc = kernel_cache->parallel_desc();\n          int64_t device_id = GlobalProcessCtx::LocalRank(dst);\n          int64_t parallel_id =\n              CHECK_JUST(parallel_desc->ParallelId4MachineDeviceId(dst, device_id));\n\n          CHECK_JUST(Send(reinterpret_cast<const void*>(reinterpret_cast<const char*>(pack_to_ptr)\n                                                        + parallel_id * chunk_size),\n                          elem_per_chunk, in->data_type(), dst, DeviceType::kCPU, ctx->stream()));\n        }\n        if (GlobalProcessCtx::Rank() == dst) {\n          Symbol<ParallelDesc> parallel_desc = kernel_cache->parallel_desc();\n          int64_t device_id = GlobalProcessCtx::LocalRank(src);\n          int64_t parallel_id =\n              CHECK_JUST(parallel_desc->ParallelId4MachineDeviceId(src, device_id));\n\n          CHECK_JUST(Recv(reinterpret_cast<void*>(reinterpret_cast<char*>(unpack_from_ptr)\n                                                  + parallel_id * chunk_size),\n                          elem_per_chunk, out->data_type(), src, DeviceType::kCPU, ctx->stream()));\n        }\n      }\n    }\n\n    if (in_split_axis != 0) {\n      // Do unpack.\n      CHECK(unpack_from_ptr != out->mut_dptr<char>());\n      DimVector unpack_from_dim_vec = logical_shape_dim_vec;\n      CHECK_EQ(unpack_from_dim_vec.at(in_split_axis) % num_ranks, 0);\n      unpack_from_dim_vec[in_split_axis] = unpack_from_dim_vec.at(in_split_axis) / num_ranks;\n      CHECK_EQ(unpack_from_dim_vec.at(out_split_axis) % num_ranks, 0);\n      unpack_from_dim_vec[out_split_axis] = unpack_from_dim_vec.at(out_split_axis) / num_ranks;\n      unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks);\n      std::vector<int32_t> perm;\n      FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); }\n      perm.insert(perm.begin() + in_split_axis, 0);\n      auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n          ctx->stream()->device_type(), unpack_from_dim_vec.size());\n      CHECK(transpose);\n      transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(),\n                        unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr());\n    }\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_EAGER_CCL_S2S_CPU_KERNEL(dtype)                                         \\\n  REGISTER_USER_KERNEL(\"eager_ccl_s2s\")                                                  \\\n      .SetCreateFn<EagerCclS2SCPUKernel<dtype>>()                                        \\\n      .SetIsMatchedHob(!(user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && HobIsSendAndRecvRegistered()                                   \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value)   
\\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn(InferEagerCclS2SCpuKernelTmpBufferSize);\n\nREGISTER_EAGER_CCL_S2S_CPU_KERNEL(int8_t)\nREGISTER_EAGER_CCL_S2S_CPU_KERNEL(int32_t)\nREGISTER_EAGER_CCL_S2S_CPU_KERNEL(int64_t)\nREGISTER_EAGER_CCL_S2S_CPU_KERNEL(bool)\nREGISTER_EAGER_CCL_S2S_CPU_KERNEL(float)\nREGISTER_EAGER_CCL_S2S_CPU_KERNEL(double)\n\n#undef REGISTER_EAGER_CCL_S2S_CPU_KERNEL\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/eager_nccl_s2s_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/job/eager_nccl_comm_manager.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/user/kernels/collective_communication/include/all_to_all.h\"\n\n#if (defined(WITH_CUDA) && (NCCL_VERSION_CODE > 2700)) || defined(WITH_NPU) || defined(WITH_MLU)\n\nnamespace oneflow {\n\nnamespace {\n\nclass EagerCclS2SOpKernelCache final : public user_op::OpKernelCache {\n public:\n  explicit EagerCclS2SOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); }\n  ~EagerCclS2SOpKernelCache() override = default;\n\n  Symbol<ParallelDesc> parallel_desc() const { return parallel_desc_; }\n  const ccl::CclComm& ccl_comm() const { return ccl_comm_; }\n\n private:\n  void Init(user_op::KernelCacheContext* ctx) {\n    const std::string& parallel_conf_txt = ctx->Attr<std::string>(\"parallel_conf\");\n    ParallelConf parallel_conf;\n    CHECK(TxtString2PbMessage(parallel_conf_txt, &parallel_conf));\n    parallel_desc_ = SymbolOf(ParallelDesc(parallel_conf));\n    EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    ccl_comm_ = comm_mgr->GetCclCommForParallelDesc(parallel_conf);\n  }\n\n  Symbol<ParallelDesc> parallel_desc_;\n  ccl::CclComm ccl_comm_{};\n};\n\nsize_t InferEagerCclS2SKernelTmpBufferSize(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  size_t tensor_byte_size =\n      GetCudaAlignedSize(in_tensor.shape().elem_cnt() * GetSizeOfDataType(in_tensor.data_type()));\n  // NOTE(hanbinbin): Set tmp_buffer_size to twice tensor_byte_size because the\n  // SbpParallel4ArgNameAndIndex function of LocalUserOpInferContext is unimplemented\n  return tensor_byte_size * 2;\n}\n\nvoid InitEagerCclS2SOpKernelCache(user_op::KernelCacheContext* ctx,\n                                  std::shared_ptr<user_op::OpKernelCache>* cache_ptr) {\n  // NOTE(jianhao): the cache only depends on parallel_conf, and the kernel is singleton\n  // once parallel_conf is determined, so only init the cache at the first time.\n  if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared<EagerCclS2SOpKernelCache>(ctx); }\n}\n}  // namespace\n\ntemplate<typename T>\nclass EagerCclS2SKernel final : public user_op::OpKernel {\n public:\n  EagerCclS2SKernel() = default;\n  ~EagerCclS2SKernel() override = default;\n\n  void InitOpKernelCacheWithFlags(\n      user_op::KernelCacheContext* ctx, int8_t flag,\n      std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {\n    InitEagerCclS2SOpKernelCache(ctx, cache_ptr);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void 
Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    auto* kernel_cache = dynamic_cast<const EagerCclS2SOpKernelCache*>(cache);\n    CHECK(kernel_cache != nullptr);\n    // NOTE(hanbinbin): Compute logic copy from _nccl_logical_s2s\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    int64_t tmp_size = 0;\n    const int64_t dtype_size = GetSizeOfDataType(in->data_type());\n    int64_t data_size = GetCudaAlignedSize(in->shape_view().elem_cnt() * dtype_size);\n    // NOTE(chengcheng): in (transpose)-> pack_to_ptr (all2all)-> unpack_from_ptr (transpose)-> out\n    const char* pack_to_ptr = in->dptr<char>();\n    char* unpack_from_ptr = out->mut_dptr<char>();\n    if (tmp_buffer) { tmp_size = tmp_buffer->shape_view().elem_cnt(); }\n    CHECK(tmp_size == 0 || tmp_size == data_size || tmp_size == data_size * 2);\n\n    CHECK_EQ(in->data_type(), out->data_type());\n    const int64_t num_ranks = kernel_cache->parallel_desc()->parallel_num();\n    CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt())\n        << in->shape_view().ToString() << \" vs \" << out->shape_view().ToString();\n    const int64_t elem_cnt = in->shape_view().elem_cnt();\n    const int64_t in_split_axis = ctx->Attr<int64_t>(\"in_split_axis\");\n    const int64_t out_split_axis = ctx->Attr<int64_t>(\"out_split_axis\");\n\n    DimVector logical_shape_dim_vec;\n    in->shape_view().ToDimVector(&logical_shape_dim_vec);\n    logical_shape_dim_vec[in_split_axis] = logical_shape_dim_vec.at(in_split_axis) * num_ranks;\n\n    if (out_split_axis != 0) {\n      // NOTE(chengcheng): Do pack. Need transpose in -> pack_to\n      // pack use temp buffer offset: [0, data_size]\n      pack_to_ptr = tmp_buffer->dptr<char>();\n      DimVector transpose_in_dim_vec = logical_shape_dim_vec;\n      CHECK_EQ(transpose_in_dim_vec.at(in_split_axis) % num_ranks, 0);\n      transpose_in_dim_vec[in_split_axis] = transpose_in_dim_vec.at(in_split_axis) / num_ranks;\n      CHECK_EQ(transpose_in_dim_vec.at(out_split_axis) % num_ranks, 0);\n      transpose_in_dim_vec[out_split_axis] = transpose_in_dim_vec.at(out_split_axis) / num_ranks;\n      transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + out_split_axis, num_ranks);\n      std::vector<int32_t> perm;\n      perm.emplace_back(out_split_axis);\n      FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) {\n        if (i != out_split_axis) { perm.emplace_back(i); }\n      }\n      auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n          ctx->stream()->device_type(), transpose_in_dim_vec.size());\n      CHECK(transpose);\n      transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(),\n                        transpose_in_dim_vec.data(), in->dptr(), perm.data(),\n                        tmp_buffer->mut_dptr());\n    }\n\n    if (in_split_axis != 0) {\n      // NOTE(chengcheng): Do unpack. 
Need transpose unpack_from -> out\n      // unpack use temp buffer offset: [tmp_size - data_size, tmp_size]\n      unpack_from_ptr = tmp_buffer->mut_dptr<char>() + (tmp_size - data_size);\n    }\n\n    {\n      // NOTE: Do S2S\n      const int64_t elem_per_chunk = elem_cnt / num_ranks;\n      std::unique_ptr<ccl::AllToAll> all_to_all = ccl::NewCollectiveCommunication<ccl::AllToAll>(\n          ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks);\n      const auto& ccl_comm = kernel_cache->ccl_comm();\n      all_to_all->Launch(ctx->stream(), const_cast<char*>(pack_to_ptr), elem_per_chunk,\n                         unpack_from_ptr, elem_per_chunk, ccl_comm);\n    }\n\n    if (in_split_axis != 0) {\n      // Do unpack.\n      CHECK(unpack_from_ptr != out->mut_dptr<char>());\n      DimVector unpack_from_dim_vec = logical_shape_dim_vec;\n      CHECK_EQ(unpack_from_dim_vec.at(in_split_axis) % num_ranks, 0);\n      unpack_from_dim_vec[in_split_axis] = unpack_from_dim_vec.at(in_split_axis) / num_ranks;\n      CHECK_EQ(unpack_from_dim_vec.at(out_split_axis) % num_ranks, 0);\n      unpack_from_dim_vec[out_split_axis] = unpack_from_dim_vec.at(out_split_axis) / num_ranks;\n      unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks);\n      std::vector<int32_t> perm;\n      FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); }\n      perm.insert(perm.begin() + in_split_axis, 0);\n      auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n          ctx->stream()->device_type(), unpack_from_dim_vec.size());\n      CHECK(transpose);\n      transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(),\n                        unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr());\n    }\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_EAGER_CCL_S2S_KERNEL(dtype)                                             \\\n  REGISTER_USER_KERNEL(\"eager_ccl_s2s\")                                                  \\\n      .SetCreateFn<EagerCclS2SKernel<dtype>>()                                           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                   \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value)   \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn(InferEagerCclS2SKernelTmpBufferSize);\n\nREGISTER_EAGER_CCL_S2S_KERNEL(int8_t)\nREGISTER_EAGER_CCL_S2S_KERNEL(int32_t)\nREGISTER_EAGER_CCL_S2S_KERNEL(int64_t)\nREGISTER_EAGER_CCL_S2S_KERNEL(bool)\nREGISTER_EAGER_CCL_S2S_KERNEL(float)\nREGISTER_EAGER_CCL_S2S_KERNEL(double)\nREGISTER_EAGER_CCL_S2S_KERNEL(float16)\n#undef REGISTER_EAGER_CCL_S2S_KERNEL\n\n}  // namespace oneflow\n\n#endif  // (WITH_CUDA && NCCL_VERSION_CODE > 2700) || WITH_NPU || WITH_MLU\n"
  },
  {
    "path": "oneflow/user/kernels/eager_p_to_b_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/communicate_util.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/framework/placement_sbp_util.h\"\n#include \"oneflow/core/ep/include/primitive/add.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass EagerPToBOpKernelCache final : public user_op::OpKernelCache {\n public:\n  explicit EagerPToBOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); }\n  ~EagerPToBOpKernelCache() override = default;\n\n  const std::vector<std::pair<int64_t, int64_t>>& p2p_pair() const { return p2p_pair_; }\n\n private:\n  void Init(user_op::KernelCacheContext* ctx) {\n    const std::string& in_parallel_conf_txt = ctx->Attr<std::string>(\"in_parallel_conf\");\n    const std::string& out_parallel_conf_txt = ctx->Attr<std::string>(\"out_parallel_conf\");\n    Symbol<ParallelDesc> in_parallel_desc = CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt));\n    Symbol<ParallelDesc> out_parallel_desc =\n        CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt));\n    int64_t out_parallel_num = out_parallel_desc->parallel_num();\n    int64_t in_parallel_num = in_parallel_desc->parallel_num();\n\n    for (int64_t out_parallel_id = 0; out_parallel_id < out_parallel_num; ++out_parallel_id) {\n      int64_t dst = CHECK_JUST(out_parallel_desc->MachineId4ParallelId(out_parallel_id));\n      for (int64_t in_parallel_id = 0; in_parallel_id < in_parallel_num; ++in_parallel_id) {\n        int64_t src = CHECK_JUST(in_parallel_desc->MachineId4ParallelId(in_parallel_id));\n        p2p_pair_.emplace_back(std::make_pair(src, dst));\n      }\n    }\n  }\n\n  std::vector<std::pair<int64_t, int64_t>> p2p_pair_;\n};\n\nsize_t InferEagerPToBKernelTmpBufferSize(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  const Shape& shape = ctx->Attr<Shape>(\"shape\");\n  size_t tensor_byte_size = shape.elem_cnt() * GetSizeOfDataType(in_tensor.data_type());\n  return tensor_byte_size;\n}\n\n}  // namespace\n\nclass EagerPToBKernel final : public user_op::OpKernel {\n public:\n  EagerPToBKernel() = default;\n  ~EagerPToBKernel() override = default;\n\n  void InitOpKernelCacheWithFlags(\n      user_op::KernelCacheContext* ctx, int8_t flag,\n      std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {\n    if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared<EagerPToBOpKernelCache>(ctx); }\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    auto* kernel_cache = 
dynamic_cast<const EagerPToBOpKernelCache*>(cache);\n    CHECK(kernel_cache != nullptr);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const void* in_ptr = in->dptr();\n    void* tmp_buffer_ptr = tmp_buffer->mut_dptr();\n\n    const int64_t total_elem_cnt = ctx->Attr<Shape>(\"shape\").elem_cnt();\n    const auto& p2p_pair = kernel_cache->p2p_pair();\n\n    DeviceType device_type = ctx->device_type();\n\n    std::unique_ptr<ep::primitive::Memset> memset_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(device_type);\n    CHECK(memset_primitive) << \"Can not create Memset primitive for device type \" << device_type;\n    memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0,\n                             total_elem_cnt * GetSizeOfDataType(out->data_type()));\n\n    std::unique_ptr<ep::primitive::Add> add_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::AddFactory>(ctx->device_type(), in->data_type());\n    CHECK(add_primitive);\n    for (const auto& pair : p2p_pair) {\n      int64_t src = pair.first;\n      int64_t dst = pair.second;\n\n      if (GlobalProcessCtx::Rank() == src) {\n        CHECK_JUST(Send(in_ptr, total_elem_cnt, in->data_type(), dst, device_type, ctx->stream()));\n      }\n      if (GlobalProcessCtx::Rank() == dst) {\n        CHECK_JUST(Recv(tmp_buffer_ptr, total_elem_cnt, out->data_type(), src, device_type,\n                        ctx->stream()));\n        add_primitive->Launch(ctx->stream(), out->dptr(), tmp_buffer_ptr, out->mut_dptr(),\n                              total_elem_cnt);\n      }\n    }\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"eager_p_to_b\")\n    .SetCreateFn<EagerPToBKernel>()\n    .SetIsMatchedHob(HobIsSendAndRecvRegistered())\n    .SetInferTmpSizeFn(InferEagerPToBKernelTmpBufferSize);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/eager_p_to_s_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/communicate_util.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/framework/placement_sbp_util.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/register/tensor_slice_copier.h\"\n#include \"oneflow/core/ep/include/primitive/add.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<Symbol<NdSbp>> GetAllSplitNdSbp(int64_t axis, int64_t ndim) {\n  NdSbp split_nd_sbp;\n  for (int64_t i = 0; i < ndim; ++i) {\n    split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_split_parallel()->set_axis(axis);\n  }\n  return SymbolOf(split_nd_sbp);\n}\n\nauto* CachedGetAllSplitNdSbp = DECORATE(&GetAllSplitNdSbp, ThreadLocal);\n\nMaybe<Symbol<NdSbp>> GetAllPartialSumNdSbp(int64_t ndim) {\n  NdSbp split_nd_sbp;\n  for (int64_t i = 0; i < ndim; ++i) {\n    split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_partial_sum_parallel();\n  }\n  return SymbolOf(split_nd_sbp);\n}\n\nauto* CachedGetAllPartialSumNdSbp = DECORATE(&GetAllPartialSumNdSbp, ThreadLocal);\n\nclass EagerPToSOpKernelCache final : public user_op::OpKernelCache {\n public:\n  explicit EagerPToSOpKernelCache(user_op::KernelCacheContext* ctx) : elem_cnt_of_this_chunk_(0) {\n    Init(ctx);\n  }\n  ~EagerPToSOpKernelCache() override = default;\n\n  int64_t elem_cnt_of_this_chunk() const { return elem_cnt_of_this_chunk_; }\n\n  const std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>&\n  sorted_elem_cnt2_in_tensor_slice_copier() const {\n    return sorted_elem_cnt2_in_tensor_slice_copier_;\n  }\n\n  const std::vector<std::pair<int64_t, int64_t>>& sorted_p2p_pair() const {\n    return sorted_p2p_pair_;\n  }\n\n private:\n  void Init(user_op::KernelCacheContext* ctx) {\n    const std::string& in_parallel_conf_txt = ctx->Attr<std::string>(\"in_parallel_conf\");\n    const std::string& out_parallel_conf_txt = ctx->Attr<std::string>(\"out_parallel_conf\");\n    const int64_t out_split_axis = ctx->Attr<int64_t>(\"out_split_axis\");\n    const Shape& shape = ctx->Attr<Shape>(\"shape\");\n    DeviceType device_type = ctx->device_type();\n    DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n    Symbol<ParallelDesc> in_parallel_desc = CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt));\n    Symbol<ParallelDesc> out_parallel_desc =\n        CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt));\n    int64_t out_parallel_num = out_parallel_desc->parallel_num();\n    int64_t in_parallel_num = in_parallel_desc->parallel_num();\n    elem_cnt_of_this_chunk_ = 0;\n    
for (int64_t out_parallel_id = 0; out_parallel_id < out_parallel_num; ++out_parallel_id) {\n      int64_t dst = CHECK_JUST(out_parallel_desc->MachineId4ParallelId(out_parallel_id));\n      const TensorSliceView& out_slice = GetTensorSliceView4ParallelId(\n          *out_parallel_desc->hierarchy(),\n          *CHECK_JUST(\n              CachedGetAllSplitNdSbp(out_split_axis, out_parallel_desc->hierarchy()->NumAxes())),\n          shape, out_parallel_id);\n      CHECK(!out_slice.IsEmpty());\n      for (int64_t in_parallel_id = 0; in_parallel_id < in_parallel_num; ++in_parallel_id) {\n        int64_t src = CHECK_JUST(in_parallel_desc->MachineId4ParallelId(in_parallel_id));\n        const TensorSliceView& in_slice = GetTensorSliceView4ParallelId(\n            *in_parallel_desc->hierarchy(),\n            *CHECK_JUST(CachedGetAllPartialSumNdSbp(in_parallel_desc->hierarchy()->NumAxes())),\n            shape, in_parallel_id);\n        CHECK(!in_slice.IsEmpty());\n        const TensorSliceView& intersection = out_slice.Intersect(in_slice);\n        CHECK(!intersection.IsEmpty());\n        sorted_p2p_pair_.emplace_back(std::make_pair(src, dst));\n        sorted_elem_cnt2_in_tensor_slice_copier_.emplace_back(std::make_pair(\n            intersection.shape().elem_cnt(),\n            std::make_shared<TensorSliceCopier>(intersection, in_slice, data_type, device_type)));\n      }\n      if (GlobalProcessCtx::Rank() == dst) {\n        elem_cnt_of_this_chunk_ = sorted_elem_cnt2_in_tensor_slice_copier_.back().first;\n      }\n    }\n  }\n\n  int64_t elem_cnt_of_this_chunk_;\n  std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>\n      sorted_elem_cnt2_in_tensor_slice_copier_;\n  std::vector<std::pair<int64_t, int64_t>> sorted_p2p_pair_;\n};\n\nsize_t InferEagerPToSKernelTmpBufferSize(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  Shape shape = ctx->Attr<Shape>(\"shape\");\n  const int64_t out_split_axis = ctx->Attr<int64_t>(\"out_split_axis\");\n  const std::string& out_parallel_conf_txt = ctx->Attr<std::string>(\"out_parallel_conf\");\n  Symbol<ParallelDesc> out_parallel_desc = CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt));\n  int64_t out_parallel_num = out_parallel_desc->parallel_num();\n  if (out_parallel_num > 1) {\n    CHECK_LT(out_split_axis, shape.NumAxes());\n    BalancedSplitter bs(shape.At(out_split_axis), out_parallel_num);\n    shape.Set(out_split_axis, bs.At(0).size());\n  }\n  size_t tensor_byte_size = shape.elem_cnt() * GetSizeOfDataType(in_tensor.data_type());\n  return tensor_byte_size;\n}\n\n}  // namespace\n\nclass EagerPToSKernel final : public user_op::OpKernel {\n public:\n  EagerPToSKernel() = default;\n  ~EagerPToSKernel() override = default;\n\n  void InitOpKernelCacheWithFlags(\n      user_op::KernelCacheContext* ctx, int8_t flag,\n      std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {\n    if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared<EagerPToSOpKernelCache>(ctx); }\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    auto* kernel_cache = dynamic_cast<const EagerPToSOpKernelCache*>(cache);\n    CHECK(kernel_cache != nullptr);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n 
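   // Every rank walks the same sorted (src, dst) list, so each Send below is matched by the\n    // corresponding Recv on the destination rank.\n 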
   const void* in_ptr = in->dptr();\n    void* tmp_buffer_ptr = tmp_buffer->mut_dptr();\n\n    int64_t elem_cnt_of_this_chunk = kernel_cache->elem_cnt_of_this_chunk();\n    const auto& sorted_elem_cnt2_in_tensor_slice_copier =\n        kernel_cache->sorted_elem_cnt2_in_tensor_slice_copier();\n    const auto& sorted_p2p_pair = kernel_cache->sorted_p2p_pair();\n    CHECK_EQ(sorted_elem_cnt2_in_tensor_slice_copier.size(), sorted_p2p_pair.size());\n\n    DeviceType device_type = ctx->device_type();\n\n    std::unique_ptr<ep::primitive::Memset> memset_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(device_type);\n    CHECK(memset_primitive) << \"Can not create Memset primitive for device type \" << device_type;\n    memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0,\n                             elem_cnt_of_this_chunk * GetSizeOfDataType(out->data_type()));\n\n    std::unique_ptr<ep::primitive::Add> add_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::AddFactory>(ctx->device_type(), in->data_type());\n    CHECK(add_primitive);\n    for (int64_t i = 0; i < sorted_p2p_pair.size(); ++i) {\n      const auto& p2p_pair = sorted_p2p_pair.at(i);\n      int64_t src = p2p_pair.first;\n      int64_t dst = p2p_pair.second;\n      if (GlobalProcessCtx::Rank() == src) {\n        const auto& tensor_slice_copier = sorted_elem_cnt2_in_tensor_slice_copier.at(i).second;\n        int64_t send_elem_cnt = sorted_elem_cnt2_in_tensor_slice_copier.at(i).first;\n        tensor_slice_copier->Copy(ctx->stream(), tmp_buffer_ptr, in_ptr);\n        CHECK_JUST(Send(reinterpret_cast<const void*>(tmp_buffer_ptr), send_elem_cnt,\n                        in->data_type(), dst, device_type, ctx->stream()));\n      }\n      if (GlobalProcessCtx::Rank() == dst) {\n        CHECK_JUST(Recv(tmp_buffer_ptr, elem_cnt_of_this_chunk, out->data_type(), src, device_type,\n                        ctx->stream()));\n        add_primitive->Launch(ctx->stream(), out->dptr(), tmp_buffer_ptr, out->mut_dptr(),\n                              elem_cnt_of_this_chunk);\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"eager_p_to_s\")\n    .SetCreateFn<EagerPToSKernel>()\n    .SetIsMatchedHob(HobIsSendAndRecvRegistered())\n    .SetInferTmpSizeFn(InferEagerPToSKernelTmpBufferSize);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/eager_s_to_b_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/communicate_util.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/framework/placement_sbp_util.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/register/tensor_slice_copier.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<Symbol<NdSbp>> GetAllSplitNdSbp(int64_t axis, int64_t ndim) {\n  NdSbp split_nd_sbp;\n  for (int64_t i = 0; i < ndim; ++i) {\n    split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_split_parallel()->set_axis(axis);\n  }\n  return SymbolOf(split_nd_sbp);\n}\n\nauto* CachedGetAllSplitNdSbp = DECORATE(&GetAllSplitNdSbp, ThreadLocal);\n\nMaybe<Symbol<NdSbp>> GetAllBroadcastNdSbp(int64_t ndim) {\n  NdSbp split_nd_sbp;\n  for (int64_t i = 0; i < ndim; ++i) {\n    split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();\n  }\n  return SymbolOf(split_nd_sbp);\n}\n\nauto* CachedGetAllBroadcastNdSbp = DECORATE(&GetAllBroadcastNdSbp, ThreadLocal);\n\nclass EagerSToBOpKernelCache final : public user_op::OpKernelCache {\n public:\n  explicit EagerSToBOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); }\n  ~EagerSToBOpKernelCache() override = default;\n\n  const std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>&\n  sorted_elem_cnt2in_tensor_slice_copier_pair() const {\n    return sorted_elem_cnt2in_tensor_slice_copier_pair_;\n  }\n\n  const std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>&\n  sorted_elem_cnt2out_tensor_slice_copier_pair() const {\n    return sorted_elem_cnt2out_tensor_slice_copier_pair_;\n  }\n\n  const std::vector<std::pair<int64_t, int64_t>>& sorted_p2p_pair() const {\n    return sorted_p2p_pair_;\n  }\n\n private:\n  void Init(user_op::KernelCacheContext* ctx) {\n    const std::string& in_parallel_conf_txt = ctx->Attr<std::string>(\"in_parallel_conf\");\n    const std::string& out_parallel_conf_txt = ctx->Attr<std::string>(\"out_parallel_conf\");\n    const int64_t in_split_axis = ctx->Attr<int64_t>(\"in_split_axis\");\n    const Shape& shape = ctx->Attr<Shape>(\"shape\");\n    DeviceType device_type = ctx->device_type();\n    DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n    Symbol<ParallelDesc> in_parallel_desc = CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt));\n    Symbol<ParallelDesc> out_parallel_desc =\n        CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt));\n    int64_t out_parallel_num = out_parallel_desc->parallel_num();\n    int64_t in_parallel_num = in_parallel_desc->parallel_num();\n\n    for (int64_t out_parallel_id = 0; out_parallel_id < 
out_parallel_num; ++out_parallel_id) {\n      int64_t dst = CHECK_JUST(out_parallel_desc->MachineId4ParallelId(out_parallel_id));\n      const TensorSliceView& out_slice = GetTensorSliceView4ParallelId(\n          *out_parallel_desc->hierarchy(),\n          *CHECK_JUST(CachedGetAllBroadcastNdSbp(out_parallel_desc->hierarchy()->NumAxes())), shape,\n          out_parallel_id);\n      CHECK(!out_slice.IsEmpty());\n      for (int64_t in_parallel_id = 0; in_parallel_id < in_parallel_num; ++in_parallel_id) {\n        int64_t src = CHECK_JUST(in_parallel_desc->MachineId4ParallelId(in_parallel_id));\n        const TensorSliceView& in_slice = GetTensorSliceView4ParallelId(\n            *in_parallel_desc->hierarchy(),\n            *CHECK_JUST(\n                CachedGetAllSplitNdSbp(in_split_axis, in_parallel_desc->hierarchy()->NumAxes())),\n            shape, in_parallel_id);\n        CHECK(!in_slice.IsEmpty());\n        const TensorSliceView& intersection = out_slice.Intersect(in_slice);\n        CHECK(!intersection.IsEmpty());\n        sorted_p2p_pair_.emplace_back(src, dst);\n        sorted_elem_cnt2in_tensor_slice_copier_pair_.emplace_back(\n            intersection.shape().elem_cnt(),\n            std::make_shared<TensorSliceCopier>(intersection, in_slice, data_type, device_type));\n        sorted_elem_cnt2out_tensor_slice_copier_pair_.emplace_back(\n            intersection.shape().elem_cnt(),\n            std::make_shared<TensorSliceCopier>(out_slice, intersection, data_type, device_type));\n      }\n    }\n  }\n\n  std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>\n      sorted_elem_cnt2in_tensor_slice_copier_pair_;\n  std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>\n      sorted_elem_cnt2out_tensor_slice_copier_pair_;\n  std::vector<std::pair<int64_t, int64_t>> sorted_p2p_pair_;\n};\n\nsize_t InferEagerSToBKernelTmpBufferSize(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  Shape shape = ctx->Attr<Shape>(\"shape\");\n  const int64_t in_split_axis = ctx->Attr<int64_t>(\"in_split_axis\");\n  const std::string& in_parallel_conf_txt = ctx->Attr<std::string>(\"in_parallel_conf\");\n  Symbol<ParallelDesc> in_parallel_desc = CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt));\n  int64_t in_parallel_num = in_parallel_desc->parallel_num();\n  if (in_parallel_num > 1) {\n    CHECK_LT(in_split_axis, shape.NumAxes());\n    BalancedSplitter bs(shape.At(in_split_axis), in_parallel_num);\n    shape.Set(in_split_axis, bs.At(0).size());\n  }\n  size_t tensor_byte_size = shape.elem_cnt() * GetSizeOfDataType(in_tensor.data_type());\n  return tensor_byte_size;\n}\n\n}  // namespace\n\nclass EagerSToBKernel final : public user_op::OpKernel {\n public:\n  EagerSToBKernel() = default;\n  ~EagerSToBKernel() override = default;\n\n  void InitOpKernelCacheWithFlags(\n      user_op::KernelCacheContext* ctx, int8_t flag,\n      std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {\n    if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared<EagerSToBOpKernelCache>(ctx); }\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    auto* kernel_cache = dynamic_cast<const EagerSToBOpKernelCache*>(cache);\n    CHECK(kernel_cache != nullptr);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = 
ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const void* in_ptr = in->dptr();\n    void* out_ptr = out->mut_dptr();\n    void* tmp_buffer_ptr = tmp_buffer->mut_dptr();\n\n    const auto& sorted_elem_cnt2in_tensor_slice_copier_pair =\n        kernel_cache->sorted_elem_cnt2in_tensor_slice_copier_pair();\n    const auto& sorted_elem_cnt2out_tensor_slice_copier_pair =\n        kernel_cache->sorted_elem_cnt2out_tensor_slice_copier_pair();\n    const auto& sorted_p2p_pair = kernel_cache->sorted_p2p_pair();\n    CHECK_EQ(sorted_elem_cnt2in_tensor_slice_copier_pair.size(), sorted_p2p_pair.size());\n    CHECK_EQ(sorted_elem_cnt2out_tensor_slice_copier_pair.size(), sorted_p2p_pair.size());\n\n    DeviceType device_type = ctx->device_type();\n\n    for (int64_t i = 0; i < sorted_p2p_pair.size(); ++i) {\n      const auto& p2p_pair = sorted_p2p_pair.at(i);\n      int64_t src = p2p_pair.first;\n      int64_t dst = p2p_pair.second;\n      if (GlobalProcessCtx::Rank() == src) {\n        const auto& elem_cnt2tensor_slice_copier_pair =\n            sorted_elem_cnt2in_tensor_slice_copier_pair.at(i);\n        const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first;\n        const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second;\n        tensor_slice_copier->Copy(ctx->stream(), tmp_buffer_ptr, in_ptr);\n        CHECK_JUST(Send(reinterpret_cast<const void*>(tmp_buffer_ptr), elem_cnt, in->data_type(),\n                        dst, device_type, ctx->stream()));\n      }\n      if (GlobalProcessCtx::Rank() == dst) {\n        const auto& elem_cnt2tensor_slice_copier_pair =\n            sorted_elem_cnt2out_tensor_slice_copier_pair.at(i);\n        const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first;\n        const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second;\n        CHECK_JUST(\n            Recv(tmp_buffer_ptr, elem_cnt, out->data_type(), src, device_type, ctx->stream()));\n        tensor_slice_copier->Copy(ctx->stream(), out_ptr,\n                                  reinterpret_cast<const void*>(tmp_buffer_ptr));\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"eager_s_to_b\")\n    .SetCreateFn<EagerSToBKernel>()\n    .SetIsMatchedHob(HobIsSendAndRecvRegistered())\n    .SetInferTmpSizeFn(InferEagerSToBKernelTmpBufferSize);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/eager_s_to_p_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/communicate_util.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/framework/placement_sbp_util.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/register/tensor_slice_copier.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<Symbol<NdSbp>> GetAllSplitNdSbp(int64_t axis, int64_t ndim) {\n  NdSbp split_nd_sbp;\n  for (int64_t i = 0; i < ndim; ++i) {\n    split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_split_parallel()->set_axis(axis);\n  }\n  return SymbolOf(split_nd_sbp);\n}\n\nauto* CachedGetAllSplitNdSbp = DECORATE(&GetAllSplitNdSbp, ThreadLocal);\n\nMaybe<Symbol<NdSbp>> GetAllPartialSumNdSbp(int64_t ndim) {\n  NdSbp split_nd_sbp;\n  for (int64_t i = 0; i < ndim; ++i) {\n    split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_partial_sum_parallel();\n  }\n  return SymbolOf(split_nd_sbp);\n}\n\nauto* CachedGetAllPartialSumNdSbp = DECORATE(&GetAllPartialSumNdSbp, ThreadLocal);\n\nclass EagerSToPOpKernelCache final : public user_op::OpKernelCache {\n public:\n  explicit EagerSToPOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); }\n  ~EagerSToPOpKernelCache() override = default;\n\n  const std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>&\n  sorted_elem_cnt2in_tensor_slice_copier_pair() const {\n    return sorted_elem_cnt2in_tensor_slice_copier_pair_;\n  }\n\n  const std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>&\n  sorted_elem_cnt2out_tensor_slice_copier_pair() const {\n    return sorted_elem_cnt2out_tensor_slice_copier_pair_;\n  }\n\n  const std::vector<std::pair<int64_t, int64_t>>& sorted_p2p_pair() const {\n    return sorted_p2p_pair_;\n  }\n\n private:\n  void Init(user_op::KernelCacheContext* ctx) {\n    const std::string& in_parallel_conf_txt = ctx->Attr<std::string>(\"in_parallel_conf\");\n    const std::string& out_parallel_conf_txt = ctx->Attr<std::string>(\"out_parallel_conf\");\n    const int64_t in_split_axis = ctx->Attr<int64_t>(\"in_split_axis\");\n    const Shape& shape = ctx->Attr<Shape>(\"shape\");\n    DeviceType device_type = ctx->device_type();\n    DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n    Symbol<ParallelDesc> in_parallel_desc = CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt));\n    Symbol<ParallelDesc> out_parallel_desc =\n        CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt));\n    int64_t in_parallel_num = in_parallel_desc->parallel_num();\n\n    for (int64_t in_parallel_id = 0; in_parallel_id < in_parallel_num; ++in_parallel_id) {\n      int64_t src = 
CHECK_JUST(in_parallel_desc->MachineId4ParallelId(in_parallel_id));\n      int64_t dst = -1;\n      const TensorSliceView& in_slice = GetTensorSliceView4ParallelId(\n          *in_parallel_desc->hierarchy(),\n          *CHECK_JUST(\n              CachedGetAllSplitNdSbp(in_split_axis, in_parallel_desc->hierarchy()->NumAxes())),\n          shape, in_parallel_id);\n      CHECK(!in_slice.IsEmpty());\n      TensorSliceView out_slice;\n      TensorSliceView intersection;\n      {\n        if (out_parallel_desc->ContainingMachineId(src)) {\n          dst = src;\n          int64_t dst_device_id = GlobalProcessCtx::LocalRank(dst);\n          int64_t out_parallel_id =\n              CHECK_JUST(out_parallel_desc->ParallelId4MachineDeviceId(dst, dst_device_id));\n          out_slice = GetTensorSliceView4ParallelId(\n              *out_parallel_desc->hierarchy(),\n              *CHECK_JUST(CachedGetAllPartialSumNdSbp(out_parallel_desc->hierarchy()->NumAxes())),\n              shape, out_parallel_id);\n          // copy to out_slice from in_slice if src == dst\n          intersection = out_slice;\n        } else {\n          int64_t out_parallel_num = out_parallel_desc->parallel_num();\n          int64_t out_parallel_id = in_parallel_id % out_parallel_num;\n          dst = CHECK_JUST(out_parallel_desc->MachineId4ParallelId(out_parallel_id));\n          out_slice = GetTensorSliceView4ParallelId(\n              *out_parallel_desc->hierarchy(),\n              *CHECK_JUST(CachedGetAllPartialSumNdSbp(out_parallel_desc->hierarchy()->NumAxes())),\n              shape, out_parallel_id);\n          intersection = out_slice.Intersect(in_slice);\n        }\n      }\n      CHECK_NE(dst, -1);\n      CHECK(!out_slice.IsEmpty());\n      CHECK(!intersection.IsEmpty());\n      sorted_p2p_pair_.emplace_back(src, dst);\n      sorted_elem_cnt2in_tensor_slice_copier_pair_.emplace_back(\n          intersection.shape().elem_cnt(),\n          std::make_shared<TensorSliceCopier>(intersection, in_slice, data_type, device_type));\n      sorted_elem_cnt2out_tensor_slice_copier_pair_.emplace_back(\n          intersection.shape().elem_cnt(),\n          std::make_shared<TensorSliceCopier>(out_slice, intersection, data_type, device_type));\n    }\n  }\n\n  std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>\n      sorted_elem_cnt2in_tensor_slice_copier_pair_;\n  std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>\n      sorted_elem_cnt2out_tensor_slice_copier_pair_;\n  std::vector<std::pair<int64_t, int64_t>> sorted_p2p_pair_;\n};\n\nsize_t InferEagerSToPKernelTmpBufferSize(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  Shape shape = ctx->Attr<Shape>(\"shape\");\n  const int64_t in_split_axis = ctx->Attr<int64_t>(\"in_split_axis\");\n  const std::string& in_parallel_conf_txt = ctx->Attr<std::string>(\"in_parallel_conf\");\n  Symbol<ParallelDesc> in_parallel_desc = CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt));\n  int64_t in_parallel_num = in_parallel_desc->parallel_num();\n  if (in_parallel_num > 1) {\n    CHECK_LT(in_split_axis, shape.NumAxes());\n    BalancedSplitter bs(shape.At(in_split_axis), in_parallel_num);\n    shape.Set(in_split_axis, bs.At(0).size());\n  }\n  return shape.elem_cnt() * GetSizeOfDataType(in_tensor.data_type());\n}\n\n}  // namespace\n\nclass EagerSToPKernel final : public user_op::OpKernel {\n public:\n  EagerSToPKernel() = default;\n  ~EagerSToPKernel() override = 
default;\n\n  void InitOpKernelCacheWithFlags(\n      user_op::KernelCacheContext* ctx, int8_t flag,\n      std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {\n    if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared<EagerSToPOpKernelCache>(ctx); }\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    auto* kernel_cache = dynamic_cast<const EagerSToPOpKernelCache*>(cache);\n    CHECK(kernel_cache != nullptr);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const void* in_ptr = in->dptr();\n    void* out_ptr = out->mut_dptr();\n    void* tmp_buffer_ptr = tmp_buffer->mut_dptr();\n\n    const int64_t total_elem_cnt = ctx->Attr<Shape>(\"shape\").elem_cnt();\n\n    DeviceType device_type = ctx->device_type();\n\n    std::unique_ptr<ep::primitive::Memset> memset_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(device_type);\n    CHECK(memset_primitive) << \"Can not create Memset primitive for device type \" << device_type;\n    memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0,\n                             total_elem_cnt * GetSizeOfDataType(out->data_type()));\n\n    const auto& sorted_elem_cnt2in_tensor_slice_copier_pair =\n        kernel_cache->sorted_elem_cnt2in_tensor_slice_copier_pair();\n    const auto& sorted_elem_cnt2out_tensor_slice_copier_pair =\n        kernel_cache->sorted_elem_cnt2out_tensor_slice_copier_pair();\n    const auto& sorted_p2p_pair = kernel_cache->sorted_p2p_pair();\n    CHECK_EQ(sorted_elem_cnt2in_tensor_slice_copier_pair.size(), sorted_p2p_pair.size());\n    CHECK_EQ(sorted_elem_cnt2out_tensor_slice_copier_pair.size(), sorted_p2p_pair.size());\n\n    for (int64_t i = 0; i < sorted_p2p_pair.size(); ++i) {\n      const auto& p2p_pair = sorted_p2p_pair.at(i);\n      int64_t src = p2p_pair.first;\n      int64_t dst = p2p_pair.second;\n      if (src == dst && src == GlobalProcessCtx::Rank()) {\n        const auto& elem_cnt2tensor_slice_copier_pair =\n            sorted_elem_cnt2in_tensor_slice_copier_pair.at(i);\n        const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second;\n        tensor_slice_copier->Copy(ctx->stream(), out_ptr, in_ptr);\n        continue;\n      }\n      if (GlobalProcessCtx::Rank() == src) {\n        const auto& elem_cnt2tensor_slice_copier_pair =\n            sorted_elem_cnt2in_tensor_slice_copier_pair.at(i);\n        const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first;\n        const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second;\n        tensor_slice_copier->Copy(ctx->stream(), tmp_buffer_ptr, in_ptr);\n        CHECK_JUST(Send(reinterpret_cast<const void*>(tmp_buffer_ptr), elem_cnt, in->data_type(),\n                        dst, device_type, ctx->stream()));\n      }\n      if (GlobalProcessCtx::Rank() == dst) {\n        const auto& elem_cnt2tensor_slice_copier_pair =\n            sorted_elem_cnt2out_tensor_slice_copier_pair.at(i);\n        const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first;\n        const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second;\n        CHECK_JUST(\n            Recv(tmp_buffer_ptr, elem_cnt, out->data_type(), src, device_type, ctx->stream()));\n        
tensor_slice_copier->Copy(ctx->stream(), out_ptr,\n                                  reinterpret_cast<const void*>(tmp_buffer_ptr));\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"eager_s_to_p\")\n    .SetCreateFn<EagerSToPKernel>()\n    .SetIsMatchedHob(HobIsSendAndRecvRegistered())\n    .SetInferTmpSizeFn(InferEagerSToPKernelTmpBufferSize);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/eager_s_to_s_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/communicate_util.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/register/tensor_slice_copier.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool ContainsEmptySlice(const std::vector<TensorSliceView>& slices) {\n  return std::any_of(slices.cbegin(), slices.cend(),\n                     [](const TensorSliceView& slice) { return slice.IsEmpty(); });\n}\n\nMaybe<Symbol<NdSbp>> GetAllSplitNdSbp(int64_t axis, int64_t ndim) {\n  NdSbp split_nd_sbp;\n  for (int64_t i = 0; i < ndim; ++i) {\n    split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_split_parallel()->set_axis(axis);\n  }\n  return SymbolOf(split_nd_sbp);\n}\n\nauto* CachedGetAllSplitNdSbp = DECORATE(&GetAllSplitNdSbp, ThreadLocal);\n\nclass EagerNaiveSToSOpKernelCache final : public user_op::OpKernelCache {\n public:\n  explicit EagerNaiveSToSOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); }\n  ~EagerNaiveSToSOpKernelCache() override = default;\n\n  const std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>&\n  sorted_elem_cnt2in_tensor_slice_copier_pair() const {\n    return sorted_elem_cnt2in_tensor_slice_copier_pair_;\n  }\n\n  const std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>&\n  sorted_elem_cnt2out_tensor_slice_copier_pair() const {\n    return sorted_elem_cnt2out_tensor_slice_copier_pair_;\n  }\n\n  const std::vector<std::pair<int64_t, int64_t>>& sorted_p2p_pair() const {\n    return sorted_p2p_pair_;\n  }\n\n private:\n  void Init(user_op::KernelCacheContext* ctx) {\n    const std::string& in_parallel_conf_txt = ctx->Attr<std::string>(\"in_parallel_conf\");\n    const std::string& out_parallel_conf_txt = ctx->Attr<std::string>(\"out_parallel_conf\");\n    const int64_t in_split_axis = ctx->Attr<int64_t>(\"in_split_axis\");\n    const int64_t out_split_axis = ctx->Attr<int64_t>(\"out_split_axis\");\n    const Shape& shape = ctx->Attr<Shape>(\"shape\");\n    DeviceType device_type = ctx->device_type();\n    DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n    Symbol<ParallelDesc> in_parallel_desc = CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt));\n    Symbol<ParallelDesc> out_parallel_desc =\n        CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt));\n    int64_t in_parallel_num = in_parallel_desc->parallel_num();\n    int64_t out_parallel_num = out_parallel_desc->parallel_num();\n\n    const std::vector<TensorSliceView> in_slices =\n        GetTensorSliceView(*in_parallel_desc->hierarchy(),\n                           
*CHECK_JUST(CachedGetAllSplitNdSbp(\n                               in_split_axis, in_parallel_desc->hierarchy()->NumAxes())),\n                           shape);\n    CHECK(!ContainsEmptySlice(in_slices));\n    const std::vector<TensorSliceView> out_slices =\n        GetTensorSliceView(*out_parallel_desc->hierarchy(),\n                           *CHECK_JUST(CachedGetAllSplitNdSbp(\n                               out_split_axis, out_parallel_desc->hierarchy()->NumAxes())),\n                           shape);\n    CHECK(!ContainsEmptySlice(out_slices));\n\n    for (int64_t i = 0; i < out_parallel_num; ++i) {\n      const TensorSliceView& out_slice = out_slices.at(i);\n      for (int64_t j = 0; j < in_parallel_num; ++j) {\n        const TensorSliceView& in_slice = in_slices.at(j);\n        const TensorSliceView& intersection = out_slice.Intersect(in_slice);\n        if (intersection.IsEmpty()) { continue; }\n        int64_t src = CHECK_JUST(in_parallel_desc->MachineId4ParallelId(j));\n        int64_t dst = CHECK_JUST(out_parallel_desc->MachineId4ParallelId(i));\n        sorted_p2p_pair_.emplace_back(src, dst);\n        sorted_elem_cnt2in_tensor_slice_copier_pair_.emplace_back(\n            intersection.shape().elem_cnt(),\n            std::make_shared<TensorSliceCopier>(intersection, in_slice, data_type, device_type));\n        sorted_elem_cnt2out_tensor_slice_copier_pair_.emplace_back(\n            intersection.shape().elem_cnt(),\n            std::make_shared<TensorSliceCopier>(out_slice, intersection, data_type, device_type));\n      }\n    }\n  }\n\n  std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>\n      sorted_elem_cnt2in_tensor_slice_copier_pair_;\n  std::vector<std::pair<int64_t, std::shared_ptr<TensorSliceCopier>>>\n      sorted_elem_cnt2out_tensor_slice_copier_pair_;\n  std::vector<std::pair<int64_t, int64_t>> sorted_p2p_pair_;\n};\n\nsize_t InferNaiveSToSKernelTmpBufferSize(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  Shape shape = ctx->Attr<Shape>(\"shape\");\n  const int64_t out_split_axis = ctx->Attr<int64_t>(\"out_split_axis\");\n  const std::string& out_parallel_conf_txt = ctx->Attr<std::string>(\"out_parallel_conf\");\n  Symbol<ParallelDesc> out_parallel_desc = CHECK_JUST(TxtStringToPlacement(out_parallel_conf_txt));\n\n  int64_t out_parallel_num = out_parallel_desc->parallel_num();\n  if (out_parallel_num > 1) {\n    CHECK_LT(out_split_axis, shape.NumAxes());\n    BalancedSplitter bs(shape.At(out_split_axis), out_parallel_num);\n    shape.Set(out_split_axis, bs.At(0).size());\n  }\n  size_t tensor_byte_size = shape.elem_cnt() * GetSizeOfDataType(in_tensor.data_type());\n  return tensor_byte_size;\n}\n\n}  // namespace\n\nclass EagerNaiveSToSKernel final : public user_op::OpKernel {\n public:\n  EagerNaiveSToSKernel() = default;\n  ~EagerNaiveSToSKernel() override = default;\n\n  void InitOpKernelCacheWithFlags(\n      user_op::KernelCacheContext* ctx, int8_t flag,\n      std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {\n    if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared<EagerNaiveSToSOpKernelCache>(ctx); }\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    auto* kernel_cache = dynamic_cast<const EagerNaiveSToSOpKernelCache*>(cache);\n    CHECK(kernel_cache != nullptr);\n    const user_op::Tensor* 
in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const void* in_ptr = in->dptr();\n    void* out_ptr = out->mut_dptr();\n    void* tmp_buffer_ptr = tmp_buffer->mut_dptr();\n\n    const auto& sorted_elem_cnt2in_tensor_slice_copier_pair =\n        kernel_cache->sorted_elem_cnt2in_tensor_slice_copier_pair();\n    const auto& sorted_elem_cnt2out_tensor_slice_copier_pair =\n        kernel_cache->sorted_elem_cnt2out_tensor_slice_copier_pair();\n    const auto& sorted_p2p_pair = kernel_cache->sorted_p2p_pair();\n    CHECK_EQ(sorted_elem_cnt2in_tensor_slice_copier_pair.size(), sorted_p2p_pair.size());\n    CHECK_EQ(sorted_elem_cnt2out_tensor_slice_copier_pair.size(), sorted_p2p_pair.size());\n\n    DeviceType device_type = ctx->device_type();\n\n    for (int64_t i = 0; i < sorted_p2p_pair.size(); ++i) {\n      const auto& p2p_pair = sorted_p2p_pair.at(i);\n      int64_t src = p2p_pair.first;\n      int64_t dst = p2p_pair.second;\n      if (GlobalProcessCtx::Rank() == src) {\n        const auto& elem_cnt2tensor_slice_copier_pair =\n            sorted_elem_cnt2in_tensor_slice_copier_pair.at(i);\n        const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first;\n        const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second;\n        tensor_slice_copier->Copy(ctx->stream(), tmp_buffer_ptr, in_ptr);\n        CHECK_JUST(Send(reinterpret_cast<const void*>(tmp_buffer_ptr), elem_cnt, in->data_type(),\n                        dst, device_type, ctx->stream()));\n      }\n      if (GlobalProcessCtx::Rank() == dst) {\n        const auto& elem_cnt2tensor_slice_copier_pair =\n            sorted_elem_cnt2out_tensor_slice_copier_pair.at(i);\n        const auto& elem_cnt = elem_cnt2tensor_slice_copier_pair.first;\n        const auto& tensor_slice_copier = elem_cnt2tensor_slice_copier_pair.second;\n        CHECK_JUST(\n            Recv(tmp_buffer_ptr, elem_cnt, out->data_type(), src, device_type, ctx->stream()));\n        tensor_slice_copier->Copy(ctx->stream(), out_ptr,\n                                  reinterpret_cast<const void*>(tmp_buffer_ptr));\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"eager_naive_s_to_s\")\n    .SetCreateFn<EagerNaiveSToSKernel>()\n    .SetIsMatchedHob(HobIsSendAndRecvRegistered())\n    .SetInferTmpSizeFn(InferNaiveSToSKernelTmpBufferSize);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/eager_symmetric_s_to_p_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/register/tensor_slice_copier.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Memset> NewMemsetPrimitive(Context* ctx) {\n  return ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->device_type());\n}\n\nauto MemsetPrimitiveExists() {\n  return hob::make_custom(\"MemsetPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewMemsetPrimitive(&ctx).operator bool();\n  });\n}\n\nMaybe<Symbol<NdSbp>> GetAllSplitNdSbp(int64_t axis, int64_t ndim) {\n  NdSbp split_nd_sbp;\n  for (int64_t i = 0; i < ndim; ++i) {\n    split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_split_parallel()->set_axis(axis);\n  }\n  return SymbolOf(split_nd_sbp);\n}\n\nauto* CachedGetAllSplitNdSbp = DECORATE(&GetAllSplitNdSbp, ThreadLocal);\n\nMaybe<Symbol<NdSbp>> GetAllPartialSumNdSbp(int64_t ndim) {\n  NdSbp split_nd_sbp;\n  for (int64_t i = 0; i < ndim; ++i) {\n    split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_partial_sum_parallel();\n  }\n  return SymbolOf(split_nd_sbp);\n}\n\nauto* CachedGetAllPartialSumNdSbp = DECORATE(&GetAllPartialSumNdSbp, ThreadLocal);\n\nclass EagerSymmetricSToPOpKernelCache final : public user_op::OpKernelCache {\n public:\n  explicit EagerSymmetricSToPOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); }\n  ~EagerSymmetricSToPOpKernelCache() override = default;\n\n  const std::shared_ptr<TensorSliceCopier>& tensor_slice_copier() const {\n    return tensor_slice_copier_;\n  }\n\n private:\n  void Init(user_op::KernelCacheContext* ctx) {\n    const std::string& parallel_conf_txt = ctx->Attr<std::string>(\"parallel_conf\");\n    const int64_t in_split_axis = ctx->Attr<int64_t>(\"in_split_axis\");\n    const user_op::TensorDesc* in_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex(\"in\", 0);\n    const Shape& shape = in_logical_desc->shape();\n    DeviceType device_type = ctx->device_type();\n    DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n    ParallelConf parallel_conf;\n    CHECK(TxtString2PbMessage(parallel_conf_txt, &parallel_conf));\n    Symbol<ParallelDesc> parallel_desc = SymbolOf(ParallelDesc(parallel_conf));\n\n    const TensorSliceView& in_slice = GetTensorSliceView4ParallelId(\n        *parallel_desc->hierarchy(),\n        *CHECK_JUST(CachedGetAllSplitNdSbp(in_split_axis, parallel_desc->hierarchy()->NumAxes())),\n        shape, ctx->parallel_ctx().parallel_id());\n    CHECK(!in_slice.IsEmpty());\n    const TensorSliceView& out_slice = GetTensorSliceView4ParallelId(\n        *parallel_desc->hierarchy(),\n        
*CHECK_JUST(CachedGetAllPartialSumNdSbp(parallel_desc->hierarchy()->NumAxes())), shape,\n        ctx->parallel_ctx().parallel_id());\n    CHECK(!out_slice.IsEmpty());\n    const TensorSliceView& intersection = out_slice.Intersect(in_slice);\n    CHECK(!intersection.IsEmpty());\n    tensor_slice_copier_ =\n        std::make_shared<TensorSliceCopier>(out_slice, in_slice, data_type, device_type);\n  }\n\n  std::shared_ptr<TensorSliceCopier> tensor_slice_copier_;\n};\n\n}  // namespace\n\nclass EagerSymmetricSToPKernel final : public user_op::OpKernel {\n public:\n  EagerSymmetricSToPKernel() = default;\n  ~EagerSymmetricSToPKernel() override = default;\n\n  void InitOpKernelCacheWithFlags(\n      user_op::KernelCacheContext* ctx, int8_t flag,\n      std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {\n    if (*cache_ptr == nullptr) {\n      *cache_ptr = std::make_shared<EagerSymmetricSToPOpKernelCache>(ctx);\n    }\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    auto* kernel_cache = dynamic_cast<const EagerSymmetricSToPOpKernelCache*>(cache);\n    CHECK(kernel_cache != nullptr);\n    auto primitive = NewMemsetPrimitive(ctx);\n    CHECK(primitive);  // NOLINT\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const auto& out_shape_view = out->shape_view();\n\n    const void* in_ptr = in->dptr();\n    void* out_ptr = out->mut_dptr();\n\n    primitive->Launch(ctx->stream(), out->mut_dptr(), 0,\n                      out_shape_view.elem_cnt() * GetSizeOfDataType(out->data_type()));\n    const auto& tensor_slice_copier = kernel_cache->tensor_slice_copier();\n    tensor_slice_copier->Copy(ctx->stream(), out_ptr, in_ptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"eager_symmetric_s_to_p\")\n    .SetCreateFn<EagerSymmetricSToPKernel>()\n    .SetIsMatchedHob(MemsetPrimitiveExists() == true);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/elementwise_maximum_minimum_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/elementwise_maximum_minimum_kernel.h\"\n\nnamespace oneflow {\nnamespace {\ntemplate<template<typename> class Opt, typename T>\nstruct ElemwiseXimumGradFunctor<DeviceType::kCPU, Opt, T> final {\n  void operator()(ep::Stream* stream, int64_t elem_cnt, const T* dz, const T* x, const T* y, T* dx,\n                  T* dy) {\n    XPU_1D_KERNEL_LOOP(idx, elem_cnt) {\n      Opt<T>()(dz[idx], x[idx], y[idx], dx ? &dx[idx] : nullptr, dy ? &dy[idx] : nullptr);\n    }\n  }\n};\n\ntemplate<template<typename> class Opt, typename T>\nstruct ElemwiseXimumFunctor<DeviceType::kCPU, Opt, T> final {\n  void operator()(ep::Stream* stream, int64_t elem_cnt, T* z, const T* x, const T* y) {\n    FOR_RANGE(int64_t, idx, 0, elem_cnt) { z[idx] = Opt<T>()(x[idx], y[idx]); }\n  }\n};\n}  // namespace\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MAXIMUM_KERNELS, (DeviceType::kCPU),\n                                 ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ)\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MINIMUM_KERNELS, (DeviceType::kCPU),\n                                 ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ)\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/elementwise_maximum_minimum_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/user/kernels/elementwise_maximum_minimum_kernel.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\ntemplate<template<typename> class Opt, typename T>\n__global__ void ElementwiseXimumGradGpuKernel(int64_t elem_cnt, const T* dz, const T* x, const T* y,\n                                              T* dx, T* dy) {\n  XPU_1D_KERNEL_LOOP(idx, elem_cnt) {\n    Opt<T>()(dz[idx], x[idx], y[idx], dx ? &dx[idx] : nullptr, dy ? &dy[idx] : nullptr);\n  }\n}\n\ntemplate<template<typename> class Opt, typename T>\nstruct ElemwiseXimumGradFunctor<DeviceType::kCUDA, Opt, T> final {\n  void operator()(ep::Stream* stream, int64_t elem_cnt, const T* dz, const T* x, const T* y, T* dx,\n                  T* dy) {\n    ElementwiseXimumGradGpuKernel<Opt, T>\n        <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(elem_cnt, dz, x, y, dx, dy);\n  }\n};\n\ntemplate<template<typename> class Opt, typename T>\nstruct ElemwiseXimumFunctor<DeviceType::kCUDA, Opt, T> final {\n  void operator()(ep::Stream* stream, int64_t elem_cnt, T* z, const T* x, const T* y) {\n    OF_CUDA_CHECK(cuda::elementwise::Binary(Opt<T>(), elem_cnt, z, x, y,\n                                            stream->As<ep::CudaStream>()->cuda_stream()));\n  }\n};\n}  // namespace\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MAXIMUM_KERNELS, (DeviceType::kCUDA),\n                                 ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ)\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MINIMUM_KERNELS, (DeviceType::kCUDA),\n                                 ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ)\n}  // namespace oneflow\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/elementwise_maximum_minimum_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef _ONEFLOW_USER_KERNELS_ELEMENTWISE_MAXIMUM_MINIMUM_KERNEL_H_\n#define _ONEFLOW_USER_KERNELS_ELEMENTWISE_MAXIMUM_MINIMUM_KERNEL_H_\n#include \"oneflow/core/ndarray/xpu_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstruct MaximumFunctor {\n  OF_DEVICE_FUNC T operator()(const T x, const T y) const { return x > y ? x : y; }\n};\n\ntemplate<typename T>\nstruct MaximumGradFunctor {\n  OF_DEVICE_FUNC void operator()(const T dz, const T x, const T y, T* dx, T* dy) {\n    T dx_val = 0;\n    T dy_val = 0;\n    if (x > y) {\n      dx_val = dz;\n    } else if (x == y) {\n      dx_val = dz / 2;\n      dy_val = dz / 2;\n    } else {\n      dy_val = dz;\n    }\n    if (dx) { *dx = dx_val; }\n    if (dy) { *dy = dy_val; }\n  }\n};\n\ntemplate<typename T>\nstruct MinimumFunctor {\n  OF_DEVICE_FUNC T operator()(const T x, const T y) const { return x < y ? x : y; }\n};\n\ntemplate<typename T>\nstruct MinimumGradFunctor {\n  OF_DEVICE_FUNC void operator()(const T dz, const T x, const T y, T* dx, T* dy) {\n    T dx_val = 0;\n    T dy_val = 0;\n    if (x < y) {\n      dx_val = dz;\n    } else if (x == y) {\n      dx_val = dz / 2;\n      dy_val = dz / 2;\n    } else {\n      dy_val = dz;\n    }\n    if (dx) { *dx = dx_val; }\n    if (dy) { *dy = dy_val; }\n  }\n};\n\nnamespace {\ntemplate<DeviceType device_type, template<typename> class Opt, typename T>\nstruct ElemwiseXimumGradFunctor final {\n  void operator()(ep::Stream* stream, int64_t elem_cnt, const T* dz, const T* x, const T* y, T* dx,\n                  T* dy);\n};\n\ntemplate<DeviceType device_type, template<typename> class Opt, typename T>\nstruct ElemwiseXimumFunctor final {\n  void operator()(ep::Stream* stream, int64_t elem_cnt, T* z, const T* x, const T* y);\n};\n}  // namespace\n\ntemplate<DeviceType device_type, template<typename> class Opt, typename T>\nclass ElemwiseXimumKernel final : public user_op::OpKernel {\n public:\n  ElemwiseXimumKernel() = default;\n  ~ElemwiseXimumKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* tensor_z = ctx->Tensor4ArgNameAndIndex(\"z\", 0);\n    int64_t n = tensor_x->shape_view().elem_cnt();\n\n    ElemwiseXimumFunctor<device_type, Opt, T>()(ctx->stream(), n, tensor_z->mut_dptr<T>(),\n                                                tensor_x->dptr<T>(), tensor_y->dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, template<typename> class Opt, typename T>\nclass ElemwiseXimumBackwardKernel final : public user_op::OpKernel {\n public:\n  ElemwiseXimumBackwardKernel() = default;\n  
~ElemwiseXimumBackwardKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* tensor_dz = ctx->Tensor4ArgNameAndIndex(\"dz\", 0);\n    user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* tensor_dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    user_op::Tensor* tensor_dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n\n    const T* dptr_dz = tensor_dz->dptr<T>();\n    const T* dptr_x = tensor_x->dptr<T>();\n    const T* dptr_y = tensor_y->dptr<T>();\n\n    T* dptr_dx = tensor_dx ? tensor_dx->mut_dptr<T>() : nullptr;\n    T* dptr_dy = tensor_dy ? tensor_dy->mut_dptr<T>() : nullptr;\n\n    ElemwiseXimumGradFunctor<device_type, Opt, T>()(ctx->stream(),\n                                                    tensor_dz->shape_view().elem_cnt(), dptr_dz,\n                                                    dptr_x, dptr_y, dptr_dx, dptr_dy);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_MAXIMUM_KERNELS(device, dtype_pair)                                               \\\n  REGISTER_USER_KERNEL(\"elementwise_maximum\")                                                      \\\n      .SetCreateFn<ElemwiseXimumKernel<device, MaximumFunctor, OF_PP_PAIR_FIRST(dtype_pair)>>()    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                        \\\n                       && (user_op::HobDataType(\"x\", 0) == OF_PP_PAIR_SECOND(dtype_pair))          \\\n                       && (user_op::HobDataType(\"y\", 0) == OF_PP_PAIR_SECOND(dtype_pair)));        \\\n  REGISTER_USER_KERNEL(\"elementwise_maximum_backward\")                                             \\\n      .SetCreateFn<                                                                                \\\n          ElemwiseXimumBackwardKernel<device, MaximumGradFunctor, OF_PP_PAIR_FIRST(dtype_pair)>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                        \\\n                       && (user_op::HobDataType(\"x\", 0) == OF_PP_PAIR_SECOND(dtype_pair))          \\\n                       && (user_op::HobDataType(\"y\", 0) == OF_PP_PAIR_SECOND(dtype_pair)));\n\n#define REGISTER_MINIMUM_KERNELS(device, dtype_pair)                                               \\\n  REGISTER_USER_KERNEL(\"elementwise_minimum\")                                                      \\\n      .SetCreateFn<ElemwiseXimumKernel<device, MinimumFunctor, OF_PP_PAIR_FIRST(dtype_pair)>>()    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                        \\\n                       && (user_op::HobDataType(\"x\", 0) == OF_PP_PAIR_SECOND(dtype_pair))          \\\n                       && (user_op::HobDataType(\"y\", 0) == OF_PP_PAIR_SECOND(dtype_pair)));        \\\n  REGISTER_USER_KERNEL(\"elementwise_minimum_backward\")                                             \\\n      .SetCreateFn<                                                                                \\\n          ElemwiseXimumBackwardKernel<device, MinimumGradFunctor, OF_PP_PAIR_FIRST(dtype_pair)>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                        \\\n                       && (user_op::HobDataType(\"x\", 0) == OF_PP_PAIR_SECOND(dtype_pair))          \\\n                       && 
(user_op::HobDataType(\"y\", 0) == OF_PP_PAIR_SECOND(dtype_pair)));\n\n}  // namespace oneflow\n\n#endif  // _ONEFLOW_USER_KERNELS_ELEMENTWISE_MAXIMUM_MINIMUM_KERNEL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/elementwise_primitive_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef _ONEFLOW_USER_KERNELS_ELEMENTWISE_XPU_KERNEL_H_\n#define _ONEFLOW_USER_KERNELS_ELEMENTWISE_XPU_KERNEL_H_\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h\"\n#include \"oneflow/core/ep/include/primitive/elementwise_unary.h\"\n#include \"oneflow/core/ep/include/primitive/unary_op.h\"\n#include \"oneflow/core/ep/include/primitive/binary_op.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nclass UnaryPrimitiveKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(UnaryPrimitiveKernel);\n  UnaryPrimitiveKernel() = default;\n  ~UnaryPrimitiveKernel() = default;\n\n  using PrimitiveFactoryFuncType = std::function<std::unique_ptr<ep::primitive::ElementwiseUnary>(\n      user_op::KernelComputeContext*)>;\n\n  UnaryPrimitiveKernel(const std::string& output_name, const std::string& input_name,\n                       PrimitiveFactoryFuncType fn)\n      : output_name_(output_name),\n        input_name_(input_name),\n        primitive_factory_func_(std::move(fn)) {}\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    auto primitive = primitive_factory_func_(ctx);\n    CHECK(primitive);\n\n    const user_op::Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex(input_name_, 0);\n    user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex(output_name_, 0);\n\n    const ShapeView& input_shape = input_tensor->shape_view();\n    const ShapeView& output_shape = output_tensor->shape_view();\n    CHECK_EQ(input_shape, output_shape) << \"Input shape should be equal to Output shape.\";\n    const int64_t elem_cnt = input_shape.elem_cnt();\n\n    if (elem_cnt != 0) {\n      primitive->Launch(ctx->stream(), input_tensor->dptr(), output_tensor->mut_dptr(), elem_cnt);\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  std::string output_name_;\n  std::string input_name_;\n  PrimitiveFactoryFuncType primitive_factory_func_;\n};\n\nclass BinaryPrimitiveKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BinaryPrimitiveKernel);\n  BinaryPrimitiveKernel() = default;\n  ~BinaryPrimitiveKernel() = default;\n\n  using PrimitiveFactoryFuncType =\n      std::function<std::unique_ptr<ep::primitive::BroadcastElementwiseBinary>(\n          user_op::KernelComputeContext*)>;\n\n  BinaryPrimitiveKernel(const std::string& output_name, const std::string& input_a_name,\n                        const std::string& input_b_name, PrimitiveFactoryFuncType fn)\n      : output_name_(output_name),\n        input_a_name_(input_a_name),\n        input_b_name_(input_b_name),\n        
primitive_factory_func_(std::move(fn)) {}\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    auto primitive = primitive_factory_func_(ctx);\n    CHECK(primitive);\n\n    const user_op::Tensor* input_a_tensor = ctx->Tensor4ArgNameAndIndex(input_a_name_, 0);\n    const user_op::Tensor* input_b_tensor = ctx->Tensor4ArgNameAndIndex(input_b_name_, 0);\n    user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex(output_name_, 0);\n\n    const ShapeView& input_a_shape = input_a_tensor->shape_view();\n    const ShapeView& input_b_shape = input_b_tensor->shape_view();\n    const ShapeView& output_shape = output_tensor->shape_view();\n    CHECK_EQ(input_a_shape, input_b_shape) << \"InputA shape should be equal to InputB shape.\";\n    CHECK_EQ(input_a_shape, output_shape) << \"Input shape should be equal to Output shape.\";\n    const int64_t elem_cnt = input_a_shape.elem_cnt();\n\n    if (elem_cnt != 0) {\n      primitive->Launch(ctx->stream(), input_a_shape.NumAxes(), input_a_shape.ptr(),\n                        input_a_tensor->dptr(), input_b_shape.NumAxes(), input_b_shape.ptr(),\n                        input_b_tensor->dptr(), output_tensor->mut_dptr());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  std::string output_name_;\n  std::string input_a_name_;\n  std::string input_b_name_;\n  PrimitiveFactoryFuncType primitive_factory_func_;\n};\n\nnamespace {\nauto UnaryPrimitiveExists(ep::primitive::UnaryOp op, const std::string& output_name,\n                          const std::string& input_name) {\n  return hob::make_custom(\n      \"ElementwiseUnaryPrimitiveExists\", [=](const user_op::KernelRegContext& ctx) {\n        const user_op::TensorDesc* src = ctx.TensorDesc4ArgNameAndIndex(input_name, 0);\n        const user_op::TensorDesc* dst = ctx.TensorDesc4ArgNameAndIndex(output_name, 0);\n        auto primitive = ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n            ctx.device_type(), op, src->data_type(), dst->data_type());\n        return primitive.operator bool();\n      });\n}\n\nauto BinaryPrimitiveExists(ep::primitive::BinaryOp op, const std::string& output_name,\n                           const std::string& input_a_name) {\n  return hob::make_custom(\n      \"BroadcastElementwiseBinaryPrimitiveExists\", [=](const user_op::KernelRegContext& ctx) {\n        const user_op::TensorDesc* src0 = ctx.TensorDesc4ArgNameAndIndex(input_a_name, 0);\n        const user_op::TensorDesc* dst = ctx.TensorDesc4ArgNameAndIndex(output_name, 0);\n        auto primitive =\n            ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n                ctx.device_type(), op, src0->data_type(), dst->data_type(), 1 /*max_num_dims*/);\n        return primitive.operator bool();\n      });\n}\n}  // namespace\n\n}  // namespace oneflow\n\n#endif  // _ONEFLOW_USER_KERNELS_ELEMENTWISE_XPU_KERNEL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/embedding_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/user/kernels/embedding_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename IndexType>\nclass CpuEmbeddingRenormKernel final : public user_op::OpKernel {\n public:\n  CpuEmbeddingRenormKernel() = default;\n  ~CpuEmbeddingRenormKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const double max_norm = ctx->Attr<double>(\"max_norm\");\n    const double norm_type = ctx->Attr<double>(\"norm_type\");\n\n    const ShapeView& in_shape = in->shape_view();\n    const int64_t emb_size = in_shape.At(0);\n    const int64_t emb_dim = in_shape.At(1);\n    const T* in_buf = in->dptr<T>();\n    const IndexType* indices_buf = indices->dptr<IndexType>();\n    T* out_buf = out->mut_dptr<T>();\n    const int64_t num_indices = indices->shape_view().elem_cnt();\n    EmbeddingReNormFunctor<DeviceType::kCPU, T, IndexType>()(\n        ctx->stream(), in_buf, indices_buf, out_buf, max_norm, norm_type, num_indices, emb_size,\n        emb_dim, nullptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T, typename IndexType>\nclass CpuEmbeddingKernel final : public user_op::OpKernel {\n public:\n  CpuEmbeddingKernel() = default;\n  ~CpuEmbeddingKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t padding_idx = ctx->Attr<int64_t>(\"padding_idx\");\n    const bool scale_grad_by_freq = ctx->Attr<bool>(\"scale_grad_by_freq\");\n\n    const ShapeView& out_shape = out->shape_view();\n    const int64_t num_indices = out_shape.Count(0, out_shape.NumAxes() - 1);\n    const int64_t emb_size = weight->shape_view().At(0);\n    const int64_t emb_dim = out_shape.At(out_shape.NumAxes() - 1);\n    const T* weight_buf = weight->dptr<T>();\n    const IndexType* indices_buf = indices->dptr<IndexType>();\n    T* out_buf = out->mut_dptr<T>();\n\n    EmbeddingFunctor<DeviceType::kCPU, T, IndexType>()(ctx->stream(), weight_buf, indices_buf,\n                                                       out_buf, padding_idx, scale_grad_by_freq,\n                                                       num_indices, emb_size, emb_dim);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T, typename IndexType>\nclass 
CpuEmbeddingGradKernel final : public user_op::OpKernel {\n public:\n  CpuEmbeddingGradKernel() = default;\n  ~CpuEmbeddingGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const int64_t padding_idx = ctx->Attr<int64_t>(\"padding_idx\");\n    const bool scale_grad_by_freq = ctx->Attr<bool>(\"scale_grad_by_freq\");\n\n    const ShapeView& dy_shape = dy->shape_view();\n    const int64_t num_indices = dy_shape.Count(0, dy_shape.NumAxes() - 1);\n    const int64_t emb_size = weight->shape_view().At(0);\n    const int64_t emb_dim = dy_shape.At(dy_shape.NumAxes() - 1);\n\n    const T* dy_buf = dy->dptr<T>();\n    const IndexType* indices_buf = indices->dptr<IndexType>();\n    T* dx_buf = dx->mut_dptr<T>();\n\n    std::unique_ptr<ep::primitive::Memset> memset_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->device_type());\n    CHECK(memset_primitive);\n    memset_primitive->Launch(ctx->stream(), dx_buf, 0, dx->shape_view().Count(0) * sizeof(T));\n    EmbeddingGradFunctor<DeviceType::kCPU, T, IndexType>()(ctx->stream(), dy_buf, indices_buf,\n                                                           dx_buf, padding_idx, scale_grad_by_freq,\n                                                           num_indices, emb_size, emb_dim, nullptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_EMBEDDING_KERNEL(in_type, indices_type)                                     \\\n  REGISTER_USER_KERNEL(\"embedding_renorm\")                                                       \\\n      .SetCreateFn<                                                                              \\\n          CpuEmbeddingRenormKernel<OF_PP_PAIR_FIRST(in_type), OF_PP_PAIR_FIRST(indices_type)>>() \\\n      .SetIsMatchedHob(                                                                          \\\n          (user_op::HobDeviceType() == DeviceType::kCPU)                                         \\\n          && (user_op::HobDataType(\"in\", 0) == OF_PP_PAIR_SECOND(in_type))                       \\\n          && (user_op::HobDataType(\"indices\", 0) == OF_PP_PAIR_SECOND(indices_type)));           \\\n  REGISTER_USER_KERNEL(\"embedding\")                                                              \\\n      .SetCreateFn<                                                                              \\\n          CpuEmbeddingKernel<OF_PP_PAIR_FIRST(in_type), OF_PP_PAIR_FIRST(indices_type)>>()       \\\n      .SetIsMatchedHob(                                                                          \\\n          (user_op::HobDeviceType() == DeviceType::kCPU)                                         \\\n          && (user_op::HobDataType(\"weight\", 0) == OF_PP_PAIR_SECOND(in_type))                   \\\n          && (user_op::HobDataType(\"indices\", 0) == OF_PP_PAIR_SECOND(indices_type)));           \\\n  REGISTER_USER_KERNEL(\"embedding_grad\")                                                         \\\n      .SetCreateFn<                                                                              \\\n          CpuEmbeddingGradKernel<OF_PP_PAIR_FIRST(in_type), 
OF_PP_PAIR_FIRST(indices_type)>>()   \\\n      .SetIsMatchedHob(                                                                          \\\n          (user_op::HobDeviceType() == DeviceType::kCPU)                                         \\\n          && (user_op::HobDataType(\"weight\", 0) == OF_PP_PAIR_SECOND(in_type))                   \\\n          && (user_op::HobDataType(\"indices\", 0) == OF_PP_PAIR_SECOND(indices_type)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CPU_EMBEDDING_KERNEL, EMBEDDING_DATA_TYPE_SEQ_CPU,\n                                 INDEX_DATA_TYPE_SEQ)\n#undef REGISTER_CPU_EMBEDDING_KERNEL\n\n}  // namespace oneflow\n"
  },
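The lookup these CPU kernels implement is a plain row gather: for each index i, out[i, :] = weight[indices[i], :]. Below is a minimal standalone sketch of the same semantics (a hypothetical toy helper for illustration, not part of the repo; the kernel above additionally range-checks each index with CHECK):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Toy version of the gather done by EmbeddingFunctor<DeviceType::kCPU, ...>:
// each index selects one row of the (emb_size x emb_dim) weight matrix.
std::vector<float> EmbeddingLookup(const std::vector<float>& weight,
                                   const std::vector<int64_t>& indices,
                                   int64_t emb_size, int64_t emb_dim) {
  std::vector<float> out(indices.size() * emb_dim);
  for (size_t i = 0; i < indices.size(); ++i) {
    const int64_t idx = indices[i];
    assert(idx >= 0 && idx < emb_size);  // mirrors the CHECK in the kernel
    std::copy(weight.begin() + idx * emb_dim, weight.begin() + (idx + 1) * emb_dim,
              out.begin() + i * emb_dim);
  }
  return out;
}
```

Note that padding_idx only matters in the backward pass: the forward gather still returns the padding row, while EmbeddingGradFunctor skips gradient accumulation for indices equal to padding_idx.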
  {
    "path": "oneflow/user/kernels/embedding_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/user/kernels/embedding_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename IndexType>\nclass GpuEmbeddingRenormKernel final : public user_op::OpKernel {\n public:\n  GpuEmbeddingRenormKernel() = default;\n  ~GpuEmbeddingRenormKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const double max_norm = ctx->Attr<double>(\"max_norm\");\n    const double norm_type = ctx->Attr<double>(\"norm_type\");\n\n    const ShapeView& in_shape = in->shape_view();\n    const int64_t emb_size = in_shape.At(0);\n    const int64_t emb_dim = in_shape.At(1);\n    const T* in_buf = in->dptr<T>();\n    const IndexType* indices_buf = indices->dptr<IndexType>();\n    T* out_buf = out->mut_dptr<T>();\n    const int64_t num_indices = indices->shape_view().elem_cnt();\n    int32_t* tmp_buf = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0)->mut_dptr<int32_t>();\n    std::unique_ptr<ep::primitive::Memset> memset_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->device_type());\n    CHECK(memset_primitive);\n    memset_primitive->Launch(ctx->stream(), tmp_buf, 0,\n                             GetCudaAlignedSize(sizeof(int32_t) * emb_size));\n    EmbeddingReNormFunctor<DeviceType::kCUDA, T, IndexType>()(\n        ctx->stream(), in_buf, indices_buf, out_buf, max_norm, norm_type, num_indices, emb_size,\n        emb_dim, tmp_buf);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T, typename IndexType>\nclass GpuEmbeddingKernel final : public user_op::OpKernel {\n public:\n  GpuEmbeddingKernel() = default;\n  ~GpuEmbeddingKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t padding_idx = ctx->Attr<int64_t>(\"padding_idx\");\n    const bool scale_grad_by_freq = ctx->Attr<bool>(\"scale_grad_by_freq\");\n\n    const int64_t num_indices = indices->shape_view().elem_cnt();\n    const int64_t emb_size = weight->shape_view().At(0);\n    const int64_t emb_dim = weight->shape_view().At(1);\n    const T* weight_buf = weight->dptr<T>();\n    const IndexType* indices_buf = indices->dptr<IndexType>();\n   
 T* out_buf = out->mut_dptr<T>();\n\n    EmbeddingFunctor<DeviceType::kCUDA, T, IndexType>()(ctx->stream(), weight_buf, indices_buf,\n                                                        out_buf, padding_idx, scale_grad_by_freq,\n                                                        num_indices, emb_size, emb_dim);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T, typename IndexType>\nclass GpuEmbeddingGradKernel final : public user_op::OpKernel {\n public:\n  GpuEmbeddingGradKernel() = default;\n  ~GpuEmbeddingGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const int64_t padding_idx = ctx->Attr<int64_t>(\"padding_idx\");\n    const bool scale_grad_by_freq = ctx->Attr<bool>(\"scale_grad_by_freq\");\n\n    const int64_t num_indices = indices->shape_view().elem_cnt();\n    const int64_t emb_size = weight->shape_view().At(0);\n    const int64_t emb_dim = weight->shape_view().At(1);\n\n    const T* dy_buf = dy->dptr<T>();\n    const IndexType* indices_buf = indices->dptr<IndexType>();\n    T* dx_buf = dx->mut_dptr<T>();\n    int32_t* tmp_buf = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0)->mut_dptr<int32_t>();\n    std::unique_ptr<ep::primitive::Memset> memset_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->device_type());\n    CHECK(memset_primitive);\n    memset_primitive->Launch(ctx->stream(), dx_buf, 0, dx->shape_view().elem_cnt() * sizeof(T));\n    memset_primitive->Launch(ctx->stream(), tmp_buf, 0,\n                             GetCudaAlignedSize(sizeof(int32_t) * emb_size));\n    EmbeddingGradFunctor<DeviceType::kCUDA, T, IndexType>()(\n        ctx->stream(), dy_buf, indices_buf, dx_buf, padding_idx, scale_grad_by_freq, num_indices,\n        emb_size, emb_dim, tmp_buf);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_EMBEDDING_KERNEL(in_type, indices_type)                                      \\\n  REGISTER_USER_KERNEL(\"embedding_renorm\")                                                         \\\n      .SetCreateFn<                                                                                \\\n          GpuEmbeddingRenormKernel<OF_PP_PAIR_FIRST(in_type), OF_PP_PAIR_FIRST(indices_type)>>()   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                             \\\n                       && (user_op::HobDataType(\"in\", 0) == OF_PP_PAIR_SECOND(in_type))            \\\n                       && (user_op::HobDataType(\"indices\", 0) == OF_PP_PAIR_SECOND(indices_type))) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                                \\\n        const Shape& in_shape = ctx->InputShape(\"in\", 0);                                          \\\n        const int64_t emb_size = in_shape.At(0);                                                   \\\n        return GetCudaAlignedSize(sizeof(int32_t) * emb_size);                                     \\\n      });                                                                                          \\\n  
REGISTER_USER_KERNEL(\"embedding\")                                                                \\\n      .SetCreateFn<                                                                                \\\n          GpuEmbeddingKernel<OF_PP_PAIR_FIRST(in_type), OF_PP_PAIR_FIRST(indices_type)>>()         \\\n      .SetIsMatchedHob(                                                                            \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                          \\\n          && (user_op::HobDataType(\"weight\", 0) == OF_PP_PAIR_SECOND(in_type))                     \\\n          && (user_op::HobDataType(\"indices\", 0) == OF_PP_PAIR_SECOND(indices_type)));             \\\n  REGISTER_USER_KERNEL(\"embedding_grad\")                                                           \\\n      .SetCreateFn<                                                                                \\\n          GpuEmbeddingGradKernel<OF_PP_PAIR_FIRST(in_type), OF_PP_PAIR_FIRST(indices_type)>>()     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                             \\\n                       && (user_op::HobDataType(\"weight\", 0) == OF_PP_PAIR_SECOND(in_type))        \\\n                       && (user_op::HobDataType(\"indices\", 0) == OF_PP_PAIR_SECOND(indices_type))) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                                \\\n        const Shape& in_shape = ctx->InputShape(\"weight\", 0);                                      \\\n        const int64_t emb_size = in_shape.At(0);                                                   \\\n        return GetCudaAlignedSize(sizeof(int32_t) * emb_size);                                     \\\n      });\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_EMBEDDING_KERNEL, EMBEDDING_DATA_TYPE_SEQ_CUDA,\n                                 INDEX_DATA_TYPE_SEQ)\n#undef REGISTER_CUDA_EMBEDDING_KERNEL\n\n}  // namespace oneflow\n"
  },
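On CUDA, tmp_buffer is an emb_size-long int32 scratch array (its byte size padded by GetCudaAlignedSize in the SetInferTmpSizeFn above) that records how many times each embedding row is referenced. A host-side sketch of what indices_freq_kernel and emb_scale_kernel compute together when scale_grad_by_freq is set (toy code, single-threaded where the GPU uses cuda::atomic::Add):

```cpp
#include <cstdint>
#include <vector>

// Host-side sketch of indices_freq_kernel + emb_scale_kernel: count how often
// each embedding row is referenced, then divide the accumulated gradient of
// every row that was hit more than once by its visit count.
void ScaleGradByFreq(std::vector<float>& dx, const std::vector<int64_t>& indices,
                     int64_t emb_size, int64_t emb_dim) {
  std::vector<int32_t> freq(emb_size, 0);       // plays the role of tmp_buffer
  for (int64_t idx : indices) { ++freq[idx]; }  // cuda::atomic::Add on the GPU
  for (int64_t row = 0; row < emb_size; ++row) {
    if (freq[row] <= 1) { continue; }
    for (int64_t j = 0; j < emb_dim; ++j) { dx[row * emb_dim + j] /= freq[row]; }
  }
}
```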
  {
    "path": "oneflow/user/kernels/embedding_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/user/kernels/embedding_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename IndexType>\nstruct EmbeddingReNormFunctor<DeviceType::kCPU, T, IndexType> final {\n  void operator()(ep::Stream* stream, const T* in_buf, const IndexType* indices_buf, T* out_buf,\n                  const double max_norm, const double norm_type, const int64_t num_indices,\n                  const int64_t emb_size, const int64_t emb_dim, int32_t* tmp_buf) {\n    auto sorted_indices = std::vector<IndexType>(indices_buf, indices_buf + num_indices);\n    std::sort(sorted_indices.begin(), sorted_indices.end());\n\n    for (int64_t i = 0; i < num_indices; i++) {\n      if (i > 0 && sorted_indices[i] == sorted_indices[i - 1]) { continue; }\n      CHECK(sorted_indices[i] >= 0 && sorted_indices[i] < emb_size);\n      double norm = 0;\n      for (int64_t j = emb_dim * sorted_indices[i]; j < emb_dim * (sorted_indices[i] + 1); j++) {\n        norm += std::pow(std::abs(in_buf[j]), norm_type);\n      }\n      norm = std::pow(norm, (1.0 / norm_type));\n      if (norm > max_norm) {\n        double scale = max_norm / (norm + 1e-7);\n        for (int64_t j = emb_dim * sorted_indices[i]; j < emb_dim * (sorted_indices[i] + 1); j++) {\n          out_buf[j] = in_buf[j] * scale;\n        }\n      }\n    }\n  }\n};\n\ntemplate<typename T, typename IndexType>\nstruct EmbeddingFunctor<DeviceType::kCPU, T, IndexType> final {\n  void operator()(ep::Stream* stream, const T* weight_buf, const IndexType* indices_buf, T* out_buf,\n                  const int64_t padding_idx, const bool scale_grad_by_freq,\n                  const int64_t num_indices, const int64_t emb_size, const int64_t emb_dim) {\n    for (int64_t i = 0; i < num_indices; i++) {\n      IndexType indice = indices_buf[i];\n      CHECK(indice >= 0 && indice < emb_size);\n      const T* from = weight_buf + indice * emb_dim;\n      T* to = out_buf + i * emb_dim;\n      std::copy(from, from + emb_dim, to);\n    }\n  }\n};\n\ntemplate<typename T, typename IndexType>\nstruct EmbeddingGradFunctor<DeviceType::kCPU, T, IndexType> final {\n  void operator()(ep::Stream* stream, const T* dy_buf, const IndexType* indices_buf, T* dx_buf,\n                  const int64_t padding_idx, const bool scale_grad_by_freq,\n                  const int64_t num_indices, const int64_t emb_size, const int64_t emb_dim,\n                  int32_t* tmp_buf) {\n    for (int64_t i = 0; i < num_indices; i++) {\n      IndexType indice = indices_buf[i];\n      if (indice != padding_idx) {\n        const T* from = dy_buf + i * emb_dim;\n        T* to = dx_buf + indice * emb_dim;\n        std::transform(from, from + emb_dim, to, to, std::plus<T>());\n      }\n    }\n\n    if (scale_grad_by_freq) {\n      std::vector<IndexType> indice_freq(emb_size, 0);\n      for (int64_t i = 0; i < num_indices; i++) { indice_freq[indices_buf[i]]++; }\n\n      for (int64_t i = 0; i < 
emb_size; i++) {\n        if (indice_freq[i] > 1) {\n          T* from = dx_buf + i * emb_dim;\n          for (int64_t j = 0; j < emb_dim; j++) { from[j] /= indice_freq[i]; }\n        }\n      }\n    }\n  }\n};\n\n#define INITIATE_EMBEDDING_KERNEL_UTIL_CPU_IMPL(in_type_pair, index_type_pair)             \\\n  template struct EmbeddingReNormFunctor<DeviceType::kCPU, OF_PP_PAIR_FIRST(in_type_pair), \\\n                                         OF_PP_PAIR_FIRST(index_type_pair)>;               \\\n  template struct EmbeddingFunctor<DeviceType::kCPU, OF_PP_PAIR_FIRST(in_type_pair),       \\\n                                   OF_PP_PAIR_FIRST(index_type_pair)>;                     \\\n  template struct EmbeddingGradFunctor<DeviceType::kCPU, OF_PP_PAIR_FIRST(in_type_pair),   \\\n                                       OF_PP_PAIR_FIRST(index_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_EMBEDDING_KERNEL_UTIL_CPU_IMPL,\n                                 EMBEDDING_DATA_TYPE_SEQ_CPU, INDEX_DATA_TYPE_SEQ);\n#undef INITIATE_EMBEDDING_KERNEL_UTIL_CPU_IMPL\n\n}  // namespace oneflow\n"
  },
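Both backends apply the usual max-norm rule per referenced row x: compute the p-norm ||x||_p = (sum_j |x_j|^p)^(1/p) with p = norm_type and, only if it exceeds max_norm, rescale the row by max_norm / (||x||_p + 1e-7). A toy single-row version (illustrative only; the functor above writes the scaled values to out_buf and deduplicates repeated indices first):

```cpp
#include <cmath>
#include <vector>

// Toy version of the per-row max-norm rule in EmbeddingReNormFunctor (in-place
// here for brevity; the functor writes the scaled row to out_buf instead).
void RenormRow(std::vector<float>& row, double max_norm, double norm_type) {
  double norm = 0.0;
  for (float v : row) { norm += std::pow(std::abs(v), norm_type); }
  norm = std::pow(norm, 1.0 / norm_type);
  if (norm > max_norm) {
    const double scale = max_norm / (norm + 1e-7);  // epsilon avoids overshooting
    for (float& v : row) { v = static_cast<float>(v * scale); }
  }
}
```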
  {
    "path": "oneflow/user/kernels/embedding_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <cub/cub.cuh>\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/user/kernels/embedding_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstruct AccumulateType {\n  using type = T;\n};\n\ntemplate<>\nstruct AccumulateType<half> {\n  using type = float;\n};\n\ntemplate<typename T, typename IndexType>\n__global__ void embedding_kernel(const T* weight_buf, const IndexType* indices_buf, T* out_buf,\n                                 const int64_t num_indices, const int64_t emb_size,\n                                 const int64_t emb_dim) {\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_indices * emb_dim) {\n    IndexType indices_index = i / emb_dim;\n    IndexType emb_dim_index = i - indices_index * emb_dim;\n    IndexType emb_size_index = indices_buf[indices_index];\n    assert(emb_size_index >= 0 && emb_size_index < emb_size);\n    IndexType from_index = emb_size_index * emb_dim + emb_dim_index;\n    out_buf[i] = weight_buf[from_index];\n  }\n}\n\ntemplate<typename T, typename IndexType>\n__global__ void embedding_grad_kernel(const T* dy_buf, const IndexType* indices_buf, T* dx_buf,\n                                      const int64_t padding_idx, const int64_t num_indices,\n                                      const int64_t emb_dim) {\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_indices * emb_dim) {\n    IndexType indices_index = i / emb_dim;\n    IndexType emb_dim_index = i - indices_index * emb_dim;\n    IndexType emb_size_index = indices_buf[indices_index];\n    if (emb_size_index != padding_idx) {\n      IndexType from_index = emb_size_index * emb_dim + emb_dim_index;\n      cuda::atomic::Add(dx_buf + from_index, dy_buf[i]);\n    }\n  }\n}\n\ntemplate<typename IndexType>\n__global__ void indices_freq_kernel(const IndexType* indices_buf, const int64_t num_indices,\n                                    int32_t* indices_freq, const int64_t emb_size) {\n  CUDA_1D_KERNEL_LOOP_T(IndexType, i, num_indices) {\n    IndexType index = indices_buf[i];\n    assert(index >= 0 && index < emb_size);\n    cuda::atomic::Add(indices_freq + index, 1);\n  }\n}\n\ntemplate<typename T, typename IndexType>\n__global__ void emb_scale_kernel(T* dx_buf, const int64_t emb_size, const int64_t emb_dim,\n                                 int32_t* indices_freq) {\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, emb_size * emb_dim) {\n    IndexType emb_size_index = i / emb_dim;\n    if (indices_freq[emb_size_index] > 1) {\n      dx_buf[i] /= static_cast<T>(indices_freq[emb_size_index]);\n    }\n  }\n}\n\ntemplate<typename T, typename IndexType, typename AccumType>\n__global__ void embedding_renorm_kernel(const T* in_buf, T* out_buf, int32_t* indices_freq,\n                                        const AccumType max_norm, const AccumType norm_type,\n                                        const int64_t emb_size, const int64_t 
emb_dim) {\n  int64_t tid = threadIdx.x;\n  for (int64_t emb_idx = blockIdx.x; emb_idx < emb_size; emb_idx += gridDim.x) {\n    if (indices_freq[emb_idx] == 0) { continue; }\n    int64_t base_index = emb_idx * emb_dim;\n\n    AccumType v = 0;\n    for (int64_t i = tid; i < emb_dim; i += blockDim.x) {\n      v += pow(abs(static_cast<AccumType>(in_buf[base_index + i])), norm_type);\n    }\n\n    using BlockReduce = cub::BlockReduce<AccumType, kCudaThreadsNumPerBlock>;\n    __shared__ typename BlockReduce::TempStorage temp_storage;\n    __shared__ AccumType norm;\n    v = BlockReduce(temp_storage).Sum(v);\n\n    if (tid == 0) { norm = pow(v, static_cast<AccumType>(1.0 / norm_type)); }\n    __syncthreads();\n\n    if (norm > max_norm) {\n      auto scale = static_cast<T>(max_norm / (norm + 1e-7));\n      for (int64_t i = tid; i < emb_dim; i += blockDim.x) {\n        out_buf[base_index + i] = in_buf[base_index + i] * scale;\n      }\n    }\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename IndexType>\nstruct EmbeddingReNormFunctor<DeviceType::kCUDA, T, IndexType> final {\n  void operator()(ep::Stream* stream, const T* in_buf, const IndexType* indices_buf, T* out_buf,\n                  const double max_norm, const double norm_type, const int64_t num_indices,\n                  const int64_t emb_size, const int64_t emb_dim, int32_t* tmp_buf) {\n    indices_freq_kernel<IndexType><<<BlocksNum4ThreadsNum(num_indices), kCudaThreadsNumPerBlock, 0,\n                                     stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        indices_buf, num_indices, tmp_buf, emb_size);\n\n    using AccumType = typename AccumulateType<T>::type;\n    embedding_renorm_kernel<T, IndexType, AccumType>\n        <<<BlocksNum4ThreadsNum(emb_size), kCudaThreadsNumPerBlock, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            in_buf, out_buf, tmp_buf, static_cast<AccumType>(max_norm),\n            static_cast<AccumType>(norm_type), emb_size, emb_dim);\n  }\n};\n\ntemplate<typename T, typename IndexType>\nstruct EmbeddingFunctor<DeviceType::kCUDA, T, IndexType> final {\n  void operator()(ep::Stream* stream, const T* weight_buf, const IndexType* indices_buf, T* out_buf,\n                  const int64_t padding_idx, const bool scale_grad_by_freq,\n                  const int64_t num_indices, const int64_t emb_size, const int64_t emb_dim) {\n    embedding_kernel<T, IndexType>\n        <<<BlocksNum4ThreadsNum(num_indices * emb_dim), kCudaThreadsNumPerBlock, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(weight_buf, indices_buf, out_buf,\n                                                          num_indices, emb_size, emb_dim);\n  }\n};\n\ntemplate<typename T, typename IndexType>\nstruct EmbeddingGradFunctor<DeviceType::kCUDA, T, IndexType> final {\n  void operator()(ep::Stream* stream, const T* dy_buf, const IndexType* indices_buf, T* dx_buf,\n                  const int64_t padding_idx, const bool scale_grad_by_freq,\n                  const int64_t num_indices, const int64_t emb_size, const int64_t emb_dim,\n                  int32_t* tmp_buf) {\n    embedding_grad_kernel<T, IndexType>\n        <<<BlocksNum4ThreadsNum(num_indices * emb_dim), kCudaThreadsNumPerBlock, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(dy_buf, indices_buf, dx_buf, padding_idx,\n                                                          num_indices, emb_dim);\n    if (scale_grad_by_freq) {\n      indices_freq_kernel<IndexType><<<BlocksNum4ThreadsNum(num_indices), 
kCudaThreadsNumPerBlock,\n                                       0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          indices_buf, num_indices, tmp_buf, emb_size);\n      emb_scale_kernel<T, IndexType>\n          <<<BlocksNum4ThreadsNum(emb_size * emb_dim), kCudaThreadsNumPerBlock, 0,\n             stream->As<ep::CudaStream>()->cuda_stream()>>>(dx_buf, emb_size, emb_dim, tmp_buf);\n    }\n  }\n};\n\n#define INITIATE_EMBEDDING_KERNEL_UTIL_CUDA_IMPL(in_type_pair, index_type_pair)             \\\n  template struct EmbeddingReNormFunctor<DeviceType::kCUDA, OF_PP_PAIR_FIRST(in_type_pair), \\\n                                         OF_PP_PAIR_FIRST(index_type_pair)>;                \\\n  template struct EmbeddingFunctor<DeviceType::kCUDA, OF_PP_PAIR_FIRST(in_type_pair),       \\\n                                   OF_PP_PAIR_FIRST(index_type_pair)>;                      \\\n  template struct EmbeddingGradFunctor<DeviceType::kCUDA, OF_PP_PAIR_FIRST(in_type_pair),   \\\n                                       OF_PP_PAIR_FIRST(index_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_EMBEDDING_KERNEL_UTIL_CUDA_IMPL,\n                                 EMBEDDING_DATA_TYPE_SEQ_CUDA, INDEX_DATA_TYPE_SEQ);\n\n#undef INITIATE_EMBEDDING_KERNEL_UTIL_CUDA_IMPL\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/embedding_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_USER_KERNELS_EMBEDDING_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_EMBEDDING_KERNEL_UTIL_H_\n\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, typename IndexType>\nstruct EmbeddingReNormFunctor final {\n  void operator()(ep::Stream* stream, const T* in_buf, const IndexType* indices_buf, T* out_buf,\n                  const double max_norm, const double norm_type, const int64_t num_indices,\n                  const int64_t emb_size, const int64_t emb_dim, int32_t* tmp_buf);\n};\n\ntemplate<DeviceType device_type, typename T, typename IndexType>\nstruct EmbeddingFunctor final {\n  void operator()(ep::Stream* stream, const T* weight_buf, const IndexType* indices_buf, T* out_buf,\n                  const int64_t padding_idx, const bool scale_grad_by_freq,\n                  const int64_t num_indices, const int64_t emb_size, const int64_t emb_dim);\n};\n\ntemplate<DeviceType device_type, typename T, typename IndexType>\nstruct EmbeddingGradFunctor final {\n  void operator()(ep::Stream* stream, const T* dy_buf, const IndexType* indices_buf, T* dx_buf,\n                  const int64_t padding_idx, const bool scale_grad_by_freq,\n                  const int64_t num_indices, const int64_t emb_size, const int64_t emb_dim,\n                  int32_t* tmp_buf);\n};\n\n#define EMBEDDING_DATA_TYPE_SEQ_CPU FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ\n#define EMBEDDING_DATA_TYPE_SEQ_CUDA FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_EMBEDDING_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/empty_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\nclass EmptyKernel final : public OpKernel {\n public:\n  EmptyKernel() = default;\n  ~EmptyKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    // None POD type need check\n    if (!IsTriviallyCopyableDataType(out->data_type())) {\n      CHECK(out->shape_view().NumAxes() > 0 && out->shape_view().elem_cnt() == 0)\n          << \"None POD Tensor created by empty op must be 0-Size tensor.\";\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"empty\").SetCreateFn<EmptyKernel>();\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/erfinv_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include <cmath>\nnamespace oneflow {\n\ntemplate<typename T>\nclass CpuErfinvKernel final : public user_op::OpKernel {\n public:\n  CpuErfinvKernel() = default;\n  ~CpuErfinvKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const int32_t elem_cnt = x->shape_view().elem_cnt();\n    const T* x_ptr = x->dptr<T>();\n    T* y_ptr = y->mut_dptr<T>();\n    constexpr float central_range = 0.7;\n    const T temp = static_cast<T>(2.0) / static_cast<T>(std::sqrt(M_PI));\n    T a[4] = {T(0.886226899), T(-1.645349621), T(0.914624893), T(-0.140543331)};\n    T b[4] = {T(-2.118377725), T(1.442710462), T(-0.329097515), T(0.012229801)};\n    T c[4] = {T(-1.970840454), T(-1.624906493), T(3.429567803), T(1.641345311)};\n    T d[2] = {T(3.543889200), T(1.637067800)};\n    FOR_RANGE(int32_t, i, 0, elem_cnt) {\n      T z, num, dem;\n      T x = x_ptr[i];  // Promise the correctness of inplace version.\n      T x_abs = std::abs(x);\n      if (x_abs > 1.0) {\n        y_ptr[i] = std::numeric_limits<T>::quiet_NaN();\n        continue;\n      }\n      if (x_abs == 1.0) {\n        y_ptr[i] = std::copysign(std::numeric_limits<T>::infinity(), x);\n        continue;\n      }\n      if (x_abs <= static_cast<T>(central_range)) {\n        z = x * x;\n        num = (((a[3] * z + a[2]) * z + a[1]) * z + a[0]);\n        dem = ((((b[3] * z + b[2]) * z + b[1]) * z + b[0]) * z + static_cast<T>(1.0));\n        y_ptr[i] = x * num / dem;\n      } else {\n        z = std::sqrt(-std::log((static_cast<T>(1.0) - x_abs) / static_cast<T>(2.0)));\n        num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0];\n        dem = (d[1] * z + d[0]) * z + static_cast<T>(1.0);\n        y_ptr[i] = std::copysign(num, x) / dem;\n      }\n      y_ptr[i] = y_ptr[i] - (std::erf(y_ptr[i]) - x) / (temp * std::exp(-y_ptr[i] * y_ptr[i]));\n      y_ptr[i] = y_ptr[i] - (std::erf(y_ptr[i]) - x) / (temp * std::exp(-y_ptr[i] * y_ptr[i]));\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_ERFINV_KERNEL(dtype)                                              \\\n  REGISTER_USER_KERNEL(\"erfinv\")                                                       \\\n      .SetCreateFn<CpuErfinvKernel<dtype>>()                                           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                  \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)) \\\n      .SetInplaceProposalFn(                                                           \\\n          [](const user_op::InferContext&,                                             \\\n             const user_op::AddInplaceArgPair& 
AddInplaceArgPairFn) -> Maybe<void> {   \\\n            OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"y\", 0, \"x\", 0, true));             \\\n            return Maybe<void>::Ok();                                                  \\\n          });\n\nREGISTER_CPU_ERFINV_KERNEL(float)\nREGISTER_CPU_ERFINV_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
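The CPU kernel above uses the classic two-stage recipe: a rational approximation supplies an initial guess, then two Newton iterations on f(y) = erf(y) - x polish it, with f'(y) = (2/sqrt(pi)) * exp(-y^2). A standalone sketch of one refinement step (a hypothetical helper, not repo code):

```cpp
#include <cmath>

// One Newton-Raphson step for inverting erf: f(y) = erf(y) - x has derivative
// (2 / sqrt(pi)) * exp(-y * y), so the update is y - f(y) / f'(y).
double ErfinvNewtonStep(double x, double y) {
  const double two_over_sqrt_pi = 2.0 / std::sqrt(std::acos(-1.0));  // 2 / sqrt(pi)
  return y - (std::erf(y) - x) / (two_over_sqrt_pi * std::exp(-y * y));
}
```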
  {
    "path": "oneflow/user/kernels/erfinv_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstruct ErfInvFunctor {\n  OF_DEVICE_FUNC ErfInvFunctor() {}\n  OF_DEVICE_FUNC T operator()(T x) const { return erfinv(x); }\n};\n\ntemplate<typename T>\nclass GpuErfinvKernel final : public user_op::OpKernel {\n public:\n  GpuErfinvKernel() = default;\n  ~GpuErfinvKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const int32_t elem_cnt = x->shape_view().elem_cnt();\n    OF_CUDA_CHECK(cuda::elementwise::Unary(ErfInvFunctor<T>(), elem_cnt, y->mut_dptr<T>(),\n                                           x->dptr<T>(),\n                                           ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_ERFINV_KERNEL(dtype)                                                      \\\n  REGISTER_USER_KERNEL(\"erfinv\")                                                                \\\n      .SetCreateFn<GpuErfinvKernel<dtype>>()                                                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                          \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value))          \\\n      .SetInplaceProposalFn([](const user_op::InferContext&,                                    \\\n                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \\\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"y\", 0, \"x\", 0, true));                          \\\n        return Maybe<void>::Ok();                                                               \\\n      });\n\nREGISTER_CUDA_ERFINV_KERNEL(float)\nREGISTER_CUDA_ERFINV_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/expand_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_unary.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::BroadcastElementwiseUnary> NewPrimitive(Context* ctx) {\n  const auto* in_desc = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n  const auto* out_desc = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n  size_t max_ndim = std::max(in_desc->shape().size(), out_desc->shape().size());\n  return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseUnaryFactory>(\n      ctx->device_type(), ep::primitive::UnaryOp::kIdentity, in_desc->data_type(),\n      out_desc->data_type(), max_ndim);\n}\n\nauto PrimitiveExists() {\n  return hob::make_custom(\"BroadcastElementwiseUnaryPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) -> bool {\n                            return NewPrimitive(&ctx).operator bool();\n                          });\n}\n\n}  // namespace\n\nclass ExpandKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ExpandKernel() = default;\n  ~ExpandKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    auto in_shape = in->shape_view();\n    auto out_shape = out->shape_view();\n\n    // handle 0-size tensor\n    if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t dim) { return dim <= 0; })) {\n      return;\n    }\n\n    auto prim = NewPrimitive(ctx);\n    CHECK(prim);\n    if (in_shape.size() == 0 && in_shape.elem_cnt() == 1) {\n      // handle 0-dim tensor\n      // NOTE: this handle will be remove when BroadcastElementwiseUnary primitive support 0-dim\n      // tensor\n      int64_t scalar_ndim = 1;\n      Shape scalar_shape(DimVector{scalar_ndim});\n      Shape scalar_stride(DimVector{scalar_ndim});\n      prim->Launch(ctx->stream(), scalar_ndim, scalar_shape.data(), scalar_stride.data(),\n                   in->dptr(), out_shape.size(), out_shape.data(), out->stride().data(),\n                   out->mut_dptr());\n    } else {\n      prim->Launch(ctx->stream(), in_shape.size(), in_shape.data(), in->stride().data(), in->dptr(),\n                   out_shape.size(), out_shape.data(), out->stride().data(), out->mut_dptr());\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"expand\").SetCreateFn<ExpandKernel>().SetIsMatchedHob(PrimitiveExists()\n                                                                           == true);\n\n}  // namespace oneflow\n"
  },
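The 0-dim branch above simply presents the scalar to the primitive as a rank-1, one-element tensor (shape {1}, stride {1}); the broadcast identity copy then behaves as if the input had been reshaped. For intuition, here is a host-side toy of a 2-D broadcast copy (hypothetical helper, not repo code), where a size-1 dimension is re-read rather than materialized:

```cpp
#include <cstdint>
#include <vector>

// Host-side toy of a broadcast "identity" copy: a dimension whose input extent
// is 1 re-reads the same element (effectively stride 0), which is all that
// expand does; nothing is materialized up front.
void BroadcastCopy2D(const std::vector<float>& in, int64_t in_rows, int64_t in_cols,
                     std::vector<float>& out, int64_t out_rows, int64_t out_cols) {
  for (int64_t r = 0; r < out_rows; ++r) {
    for (int64_t c = 0; c < out_cols; ++c) {
      const int64_t ir = in_rows == 1 ? 0 : r;  // broadcast along rows
      const int64_t ic = in_cols == 1 ? 0 : c;  // broadcast along cols
      out[r * out_cols + c] = in[ir * in_cols + ic];
    }
  }
}
```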
  {
    "path": "oneflow/user/kernels/eye_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/eye_kernel_util.h\"\n#include \"oneflow/core/common/data_type.h\"\n\nnamespace oneflow {\nnamespace user_op {\ntemplate<DeviceType device_type, typename T>\nclass EyeKernel final : public OpKernel {\n public:\n  EyeKernel() = default;\n  ~EyeKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    int64_t rows = ctx->Attr<int64_t>(\"rows\");\n    int64_t cols = ctx->Attr<int64_t>(\"cols\");\n    if (rows == 0 || cols == 0) { return; }\n    Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    T* out = out_tensor->mut_dptr<T>();\n    Memset<device_type>(\n        ctx->stream(), out_tensor->mut_dptr<T>(), 0,\n        out_tensor->shape_view().elem_cnt() * GetSizeOfDataType(out_tensor->data_type()));\n    EyeFunctor<device_type, T>()(ctx->stream(), cols, std::min(cols, rows), out);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_EYE_KERNEL(device, dtype)                                             \\\n  REGISTER_USER_KERNEL(\"eye\").SetCreateFn<EyeKernel<device, dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == device)                                             \\\n      && (user_op::HobAttr<DataType>(\"dtype\") == GetDataType<dtype>::value));\n\n#define REGISTER_EYE_KERNELS_WITH_DEVICE(device) \\\n  REGISTER_EYE_KERNEL(device, bool)              \\\n  REGISTER_EYE_KERNEL(device, uint8_t)           \\\n  REGISTER_EYE_KERNEL(device, int8_t)            \\\n  REGISTER_EYE_KERNEL(device, int32_t)           \\\n  REGISTER_EYE_KERNEL(device, int64_t)           \\\n  REGISTER_EYE_KERNEL(device, float)             \\\n  REGISTER_EYE_KERNEL(device, double)\n\n// Register CPU version\nREGISTER_EYE_KERNELS_WITH_DEVICE(DeviceType::kCPU);\n\n// Register CUDA version\n#ifdef WITH_CUDA\nREGISTER_EYE_KERNELS_WITH_DEVICE(DeviceType::kCUDA);\n#endif\n#undef REGISTER_EYE_KERNELS_WITH_DEVICE\n#undef REGISTER_EYE_KERNEL\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
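After the Memset zero-fills the output, EyeFunctor only has to write ones at the flat row-major offsets i * cols + i for i < min(rows, cols), which is why the kernel passes std::min(cols, rows) as the loop bound. A toy equivalent (illustrative helper, not repo code):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Toy equivalent of Memset + EyeFunctor: zero-fill, then a one per diagonal slot.
std::vector<float> Eye(int64_t rows, int64_t cols) {
  std::vector<float> out(rows * cols, 0.0f);
  for (int64_t i = 0; i < std::min(rows, cols); ++i) { out[i * cols + i] = 1.0f; }
  return out;
}
```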
  {
    "path": "oneflow/user/kernels/eye_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/eye_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\ntemplate<typename T>\nstruct EyeFunctor<DeviceType::kCPU, T> final {\n  void operator()(ep::Stream* stream, const int64_t& cols, const int64_t& rows, T* out) {\n    SetOneInDiag(cols, rows, out);\n  }\n};\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_EYE_FUNCTOR, (DeviceType::kCPU), EYE_DATA_TYPE_SEQ);\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/eye_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/eye_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\ntemplate<typename T>\n__global__ void EyeForwardGpuKernel(const int64_t cols, const int64_t rows, T* out) {\n  SetOneInDiag(cols, rows, out);\n}\n\ntemplate<typename T>\nstruct EyeFunctor<DeviceType::kCUDA, T> final {\n  void operator()(ep::Stream* stream, const int64_t& cols, const int64_t& rows, T* out) {\n    RUN_CUDA_KERNEL((EyeForwardGpuKernel<T>), stream, rows, cols, rows, out);\n  }\n};\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_EYE_FUNCTOR, (DeviceType::kCUDA), EYE_DATA_TYPE_SEQ);\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // End WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/eye_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_EYE_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_EYE_KERNEL_UTIL_H_\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/ndarray/xpu_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\n#define EYE_DATA_TYPE_SEQ    \\\n  FLOATING_DATA_TYPE_SEQ     \\\n  INT_DATA_TYPE_SEQ          \\\n  UNSIGNED_INT_DATA_TYPE_SEQ \\\n  BOOL_DATA_TYPE_SEQ\n\ntemplate<DeviceType device_type, typename T>\nstruct EyeFunctor final {\n  void operator()(ep::Stream* stream, const int64_t& cols, const int64_t& rows, T* out);\n};\n\ntemplate<typename T>\nOF_DEVICE_FUNC void SetOneInDiag(const int64_t cols, const int64_t rows, T* out) {\n  const T one = static_cast<T>(1);\n  XPU_1D_KERNEL_LOOP(i, rows) {\n    const int64_t index = i * cols + i;\n    out[index] = one;\n  }\n}\n\n#define INSTANTIATE_EYE_FUNCTOR(device_type_v, dtype_pair) \\\n  template struct EyeFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;\n\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_EYE_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/fake_quantization_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n\n#include <algorithm>\n\nnamespace oneflow {\n\ntemplate<typename T>\nvoid FakeQuantizationPerLayerSymmetric(const T* in_ptr, const T scale,\n                                       const int32_t quantization_bit, const int64_t num_elements,\n                                       T* out_ptr) {\n  T upper_bound = static_cast<T>(pow(2.0, quantization_bit - 1)) - 1;\n  T lower_bound = -upper_bound - 1;\n  FOR_RANGE(int64_t, i, 0, num_elements) {\n    T out = std::nearbyint(in_ptr[i] / scale);\n    out = out > upper_bound ? upper_bound : out;\n    out = out < lower_bound ? lower_bound : out;\n    out_ptr[i] = out * scale;\n  }\n}\n\ntemplate<typename T>\nvoid FakeQuantizationPerLayerAffine(const T* in_ptr, const T scale, const T zero_point,\n                                    const int32_t quantization_bit, const int64_t num_elements,\n                                    T* out_ptr) {\n  T upper_bound = static_cast<T>(pow(2.0, quantization_bit)) - 1;\n  T lower_bound = 0;\n  uint8_t zero_point_uint8 = static_cast<uint8_t>(std::round(zero_point));\n  FOR_RANGE(int64_t, i, 0, num_elements) {\n    T out = std::nearbyint(in_ptr[i] / scale + zero_point_uint8);\n    out = out > upper_bound ? upper_bound : out;\n    out = out < lower_bound ? lower_bound : out;\n    out_ptr[i] = (out - zero_point_uint8) * scale;\n  }\n}\n\ntemplate<typename T>\nvoid FakeQuantizationPerLayerCambricon(const T* in_ptr, const T shift,\n                                       const int32_t quantization_bit, const int64_t num_elements,\n                                       T* out_ptr) {\n  T upper_bound = static_cast<T>(pow(2.0, quantization_bit - 1)) - 1;\n  T lower_bound = -upper_bound - 1;\n  T scale = static_cast<T>(pow(2.0, static_cast<int32_t>(shift)));\n  FOR_RANGE(int64_t, i, 0, num_elements) {\n    T out = std::nearbyint(in_ptr[i] / scale);\n    out = out > upper_bound ? upper_bound : out;\n    out = out < lower_bound ? 
lower_bound : out;\n    out_ptr[i] = out * scale;\n  }\n}\n\ntemplate<typename T>\nclass CpuFakeQuantizationKernel final : public user_op::OpKernel {\n public:\n  CpuFakeQuantizationKernel() = default;\n  ~CpuFakeQuantizationKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex(\"scale\", 0);\n    const user_op::Tensor* zero_point = ctx->Tensor4ArgNameAndIndex(\"zero_point\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const std::string quantization_scheme = ctx->Attr<std::string>(\"quantization_scheme\");\n    const int32_t quantization_bit = ctx->Attr<int32_t>(\"quantization_bit\");\n    const std::string quantization_formula = ctx->Attr<std::string>(\"quantization_formula\");\n\n    const T* in_ptr = in->dptr<T>();\n    const T* scale_ptr = scale->dptr<T>();\n    T* out_ptr = out->mut_dptr<T>();\n\n    // round to even\n    auto origin_round_mode = std::fegetround();\n    std::fesetround(FE_TONEAREST);\n\n    if (quantization_formula == \"google\") {\n      int64_t outer_num = 1;\n      int64_t inner_num = in->shape_view().elem_cnt();\n      if (scale->shape_view().elem_cnt() > 1) {  // per-channel quantization\n        outer_num = in->shape_view().At(0);\n        inner_num = in->shape_view().Count(1);\n      }\n\n      if (quantization_scheme == \"symmetric\") {\n        FOR_RANGE(int64_t, c, 0, outer_num) {\n          FakeQuantizationPerLayerSymmetric(in_ptr, scale_ptr[c], quantization_bit, inner_num,\n                                            out_ptr);\n          in_ptr += inner_num;\n          out_ptr += inner_num;\n        }\n      } else {  // quantization_scheme == \"affine\"\n        const T* zero_point_ptr = zero_point->dptr<T>();\n        FOR_RANGE(int64_t, c, 0, outer_num) {\n          FakeQuantizationPerLayerAffine(in_ptr, scale_ptr[c], zero_point_ptr[c], quantization_bit,\n                                         inner_num, out_ptr);\n          in_ptr += inner_num;\n          out_ptr += inner_num;\n        }\n      }\n    } else if (quantization_formula == \"cambricon\") {\n      FakeQuantizationPerLayerCambricon(in_ptr, scale_ptr[0], quantization_bit,\n                                        in->shape_view().elem_cnt(), out_ptr);\n    } else {\n      UNIMPLEMENTED();\n    }\n\n    std::fesetround(origin_round_mode);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FAKE_QUANTIZATION_KERNEL(dtype)                      \\\n  REGISTER_USER_KERNEL(\"fake_quantization\")                           \\\n      .SetCreateFn<CpuFakeQuantizationKernel<dtype>>()                \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value))\n\nREGISTER_FAKE_QUANTIZATION_KERNEL(float);\nREGISTER_FAKE_QUANTIZATION_KERNEL(double);\n\n}  // namespace oneflow\n"
  },
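Per-layer symmetric fake quantization above reduces to q = clamp(round_to_nearest_even(x / scale), -2^(b-1), 2^(b-1) - 1) followed by dequantization q * scale, with b = quantization_bit. A toy scalar version (illustrative; the kernels set and restore the FP rounding mode once around the whole loop rather than per element):

```cpp
#include <cfenv>
#include <cmath>

// Toy per-element symmetric fake quantization for quantization_bit = b:
// round x / scale to the nearest integer (ties to even), clamp to the signed
// b-bit range, then map back to float space by multiplying with scale.
float FakeQuantSymmetric(float x, float scale, int b) {
  std::fesetround(FE_TONEAREST);  // the kernels set this once around the loop
  const float upper = static_cast<float>((1 << (b - 1)) - 1);
  const float lower = -upper - 1.0f;
  float q = std::nearbyint(x / scale);
  q = q > upper ? upper : q;
  q = q < lower ? lower : q;
  return q * scale;
}
```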
  {
    "path": "oneflow/user/kernels/fake_quantization_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.cuh\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__global__ void FakeQuantizationSymmetric(const T* in_ptr, const T* scale_ptr,\n                                          const int64_t scale_size, const int64_t elements,\n                                          const int64_t panel_size, const double quantization_bit,\n                                          T* out_ptr) {\n  int64_t gid = (blockDim.x * blockIdx.x) + threadIdx.x;\n  int64_t step = gridDim.x * blockDim.x;\n\n  T upper_bound = static_cast<T>(pow(2.0, quantization_bit - 1)) - 1;\n  T lower_bound = -upper_bound - 1;\n\n  while (gid < elements) {\n    int64_t channel_index = gid / panel_size;\n    int64_t scale_idx = min(scale_size - 1, channel_index);\n\n    T scale = scale_ptr[scale_idx];\n\n    T out = nearbyint(in_ptr[gid] / scale);\n    out = out > upper_bound ? upper_bound : out;\n    out = out < lower_bound ? lower_bound : out;\n    out_ptr[gid] = out * scale;\n\n    gid += step;\n  }\n}\n\ntemplate<typename T>\n__global__ void FakeQuantizationAffine(const T* in_ptr, const T* scale_ptr, const T* zero_point_ptr,\n                                       const int64_t scale_size, const int64_t elements,\n                                       const int64_t panel_size, const double quantization_bit,\n                                       T* out_ptr) {\n  int64_t gid = (blockDim.x * blockIdx.x) + threadIdx.x;\n  int64_t step = gridDim.x * blockDim.x;\n\n  T upper_bound = static_cast<T>(pow(2.0, quantization_bit)) - 1;\n  T lower_bound = 0;\n\n  while (gid < elements) {\n    int64_t channel_index = gid / panel_size;\n    int64_t scale_idx = min(scale_size - 1, channel_index);\n\n    T scale = scale_ptr[scale_idx];\n    T zero_point = zero_point_ptr[scale_idx];\n\n    T out = nearbyint(in_ptr[gid] / scale + zero_point);\n    out = out > upper_bound ? upper_bound : out;\n    out = out < lower_bound ? lower_bound : out;\n    out_ptr[gid] = (out - zero_point) * scale;\n\n    gid += step;\n  }\n}\n\ntemplate<typename T>\n__global__ void FakeQuantizationCambricon(const T* in_ptr, const T* shift, const int64_t scale_size,\n                                          const int64_t elements, const int64_t panel_size,\n                                          const double quantization_bit, T* out_ptr) {\n  int64_t gid = (blockDim.x * blockIdx.x) + threadIdx.x;\n  int64_t step = gridDim.x * blockDim.x;\n\n  T upper_bound = static_cast<T>(pow(2.0, quantization_bit - 1)) - 1;\n  T lower_bound = -upper_bound - 1;\n\n  T scale = static_cast<T>(pow(2.0, static_cast<int32_t>(shift[0])));\n\n  while (gid < elements) {\n    T out = nearbyint(in_ptr[gid] / scale);\n    out = out > upper_bound ? upper_bound : out;\n    out = out < lower_bound ? 
lower_bound : out;\n    out_ptr[gid] = out * scale;\n    gid += step;\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuFakeQuantizationKernel final : public user_op::OpKernel {\n public:\n  GpuFakeQuantizationKernel() = default;\n  ~GpuFakeQuantizationKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex(\"scale\", 0);\n    const user_op::Tensor* zero_point = ctx->Tensor4ArgNameAndIndex(\"zero_point\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const std::string quantization_scheme = ctx->Attr<std::string>(\"quantization_scheme\");\n    const int32_t quantization_bit = ctx->Attr<int32_t>(\"quantization_bit\");\n    const std::string quantization_formula = ctx->Attr<std::string>(\"quantization_formula\");\n\n    const int64_t elements = in->shape_view().elem_cnt();\n    const int64_t panel_size = in->shape_view().Count(1);\n    const int64_t scale_size = scale->shape_view().elem_cnt();\n\n    // round to even\n    auto origin_round_mode = std::fegetround();\n    std::fesetround(FE_TONEAREST);\n\n    if (quantization_formula == \"google\") {\n      if (quantization_scheme == \"symmetric\") {\n        RUN_CUDA_KERNEL((FakeQuantizationSymmetric<T>), ctx->stream(), elements, in->dptr<T>(),\n                        scale->dptr<T>(), scale_size, elements, panel_size, quantization_bit,\n                        out->mut_dptr<T>());\n      } else {  // quantization_scheme == \"affine\"\n        RUN_CUDA_KERNEL((FakeQuantizationAffine<T>), ctx->stream(), elements, in->dptr<T>(),\n                        scale->dptr<T>(), zero_point->dptr<T>(), scale_size, elements, panel_size,\n                        quantization_bit, out->mut_dptr<T>());\n      }\n    } else if (quantization_formula == \"cambricon\") {\n      RUN_CUDA_KERNEL((FakeQuantizationCambricon<T>), ctx->stream(), elements, in->dptr<T>(),\n                      scale->dptr<T>(), scale_size, elements, panel_size, quantization_bit,\n                      out->mut_dptr<T>());\n    } else {\n      UNIMPLEMENTED();\n    }\n\n    std::fesetround(origin_round_mode);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FAKE_QUANTIZATION_KERNEL(dtype)                       \\\n  REGISTER_USER_KERNEL(\"fake_quantization\")                            \\\n      .SetCreateFn<GpuFakeQuantizationKernel<dtype>>()                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value))\n\nREGISTER_FAKE_QUANTIZATION_KERNEL(float);\nREGISTER_FAKE_QUANTIZATION_KERNEL(double);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fft_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/fft_kernel_util.h\"\n#include <type_traits>\n#include \"pocketfftplan.h\"\n#include \"oneflow/core/common/device_type.pb.h\"\n#include \"oneflow/core/common/preprocessor.h\"\n#include \"oneflow/core/framework/user_op_tensor.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstatic void _conj_symmetry_cpu(T* data_out, const Shape& shape, const std::vector<int64_t>& strides,\n                               const int64_t last_dim, int64_t elem_count) {\n  const oneflow::NdIndexStrideOffsetHelper<int64_t, SHAPE_MAX_AXIS_SIZE> helper(strides.data(),\n                                                                                shape.size());\n  // NOTE: dims must be sorted\n  int64_t last_dim_size = shape[last_dim];\n  int64_t last_dim_half = last_dim_size / 2;\n\n  int64_t ndim = shape.size();\n  std::vector<int64_t> indices(ndim);\n  for (int offset = 0; offset < elem_count; offset++) {\n    helper.OffsetToNdIndex(offset, indices.data(), ndim);\n    if (indices[last_dim] <= last_dim_half) { continue; }\n\n    int64_t cur_last_dim_index = indices[last_dim];\n    // get symmetric\n    indices[last_dim] = last_dim_size - cur_last_dim_index;\n    int64_t symmetric_offset = helper.NdIndexToOffset(indices.data(), ndim);\n\n    // conj\n    data_out[offset] = std::conj(data_out[symmetric_offset]);\n  }\n}\n\ntemplate<typename T>\nstruct FillConjSymmetryUtil<DeviceType::kCPU, T> {\n  static void FillConjSymmetryForward(ep::Stream* stream, T* data_out, const Shape& shape,\n                                      const Stride& strides, const int64_t last_dim,\n                                      int64_t elem_count) {\n    std::vector<int64_t> strides_vec(strides.begin(), strides.end());\n    _conj_symmetry_cpu(/*data_out*/ data_out, /*shape*/ shape, /*strides*/ strides_vec,\n                       /*last_dim*/ last_dim, /*elem_count*/ elem_count);\n  }\n};\n\ntemplate<typename real_type, typename complex_type>\nstruct ComplexConvertUtil<DeviceType::kCPU, real_type, complex_type> {\n  static void ConvertToDoubleSized(ep::Stream* stream, const complex_type* in, complex_type* dst,\n                                   size_t len, size_t n) {\n    size_t fact_len = 2 * len - 2;  // input_shape.back()\n    for (int i = 0; i < n; i++) {\n      int index_x = i / fact_len;\n      int index_y = i % fact_len;\n      if (index_y == 0) {\n        dst[i] = in[index_x * len];\n      } else if (index_y == len - 1) {\n        dst[i] = in[(index_x + 1) * len - 1];\n      } else if (index_y < len - 1 && index_y > 0) {\n        dst[i] = in[index_x * len + index_y];\n      } else {\n        auto index = (index_x + 2) * len - index_y - 2;\n        auto realvalue = in[index].real();\n        dst[i].real(realvalue);\n        auto imagvalue = -in[index].imag();\n        dst[i].imag(imagvalue);\n      }\n    }\n  }\n  static void ConvertComplexToReal(ep::Stream* stream, const 
complex_type* in, real_type* out,\n                                   size_t n) {\n    for (int i = 0; i < n; i++) {\n      out[2 * i] = in[i].real();\n      out[2 * i + 1] = in[i].imag();\n    }\n  }\n};\n\ntemplate<typename T, typename FCT_TYPE>\nstruct FftC2CKernelUtil<DeviceType::kCPU, T, FCT_TYPE> {\n  static void FftC2CForward(ep::Stream* stream, const T* data_in, T* data_out,\n                            const Shape& input_shape, const Shape& output_shape,\n                            const Stride& input_stride, const Stride& output_stride, bool forward,\n                            const std::vector<int64_t>& dims, FCT_TYPE norm_fct,\n                            DataType real_type) {\n    PocketFFtParams<FCT_TYPE> params(input_shape, output_shape, input_stride, output_stride, dims,\n                                     forward, norm_fct /*1.f*/, FFT_EXCUTETYPE::C2C);\n    PocketFFtConfig<FCT_TYPE> config(params);\n    config.excute(data_in, data_out);\n  }\n};\n\ntemplate<typename IN, typename OUT>\nstruct FftR2CKernelUtil<DeviceType::kCPU, IN, OUT> {\n  static void FftR2CForward(ep::Stream* stream, const IN* data_in, OUT* data_out,\n                            const Shape& input_shape, const Shape& output_shape,\n                            const Stride& input_stride, const Stride& output_stride, bool forward,\n                            const std::vector<int64_t>& dims, IN norm_fct, DataType real_type) {\n    PocketFFtParams<IN> params(input_shape, output_shape, input_stride, output_stride, dims,\n                               forward, norm_fct /*1.f*/, FFT_EXCUTETYPE::R2C);\n    PocketFFtConfig<IN> config(params);\n    config.excute(data_in, data_out);\n  }\n};\n\ntemplate<typename IN, typename OUT>\nstruct FftC2RKernelUtil<DeviceType::kCPU, IN, OUT> {\n  static void FftC2RForward(ep::Stream* stream, const IN* data_in, OUT* data_out,\n                            const Shape& input_shape, const Shape& output_shape,\n                            const Stride& input_stride, const Stride& output_stride, bool forward,\n                            int64_t last_dim_size, const std::vector<int64_t>& dims, OUT norm_fct,\n                            DataType real_type) {\n    PocketFFtParams<OUT> params(input_shape, output_shape, input_stride, output_stride, dims,\n                                /*is_forward=*/false, norm_fct /*1.f*/, FFT_EXCUTETYPE::C2R);\n    PocketFFtConfig<OUT> config(params);\n    config.excute(data_in, data_out);\n  }\n};\n\ntemplate<typename IN, typename OUT>\nstruct FftStftKernelUtil<DeviceType::kCPU, IN, OUT> {\n  static void FftStftForward(ep::Stream* stream, const IN* data_in, OUT* data_out,\n                             const Shape& input_shape, const Shape& output_shape,\n                             const Stride& input_stride, const Stride& output_stride, bool forward,\n                             const std::vector<int64_t>& axes, IN norm_fct, int64_t len,\n                             int64_t dims, int64_t batch) {\n    PocketFFtParams<IN> params(input_shape, output_shape, input_stride, output_stride, axes,\n                               forward, norm_fct /*1.f*/, FFT_EXCUTETYPE::R2C);\n    PocketFFtConfig<IN> config(params);\n    int64_t in_offset = len;\n    int64_t out_offset = len / 2 + 1;\n    for (int j = 0; j < dims; j++) {\n      for (int i = 0; i < batch; i++) {\n        const IN* in = data_in + j * batch * in_offset + i * in_offset;\n        OUT* out = data_out + j * batch * out_offset + i * out_offset;\n        config.excute(in, out);\n      
}\n    }\n  }\n};\ntemplate struct FillConjSymmetryUtil<DeviceType::kCPU, std::complex<float>>;\ntemplate struct FillConjSymmetryUtil<DeviceType::kCPU, std::complex<double>>;\n\ntemplate struct ComplexConvertUtil<DeviceType::kCPU, float, std::complex<float>>;\ntemplate struct ComplexConvertUtil<DeviceType::kCPU, double, std::complex<double>>;\n\ntemplate struct FftC2CKernelUtil<DeviceType::kCPU, std::complex<float>, float>;\ntemplate struct FftC2CKernelUtil<DeviceType::kCPU, std::complex<double>, double>;\n\ntemplate struct FftR2CKernelUtil<DeviceType::kCPU, float, std::complex<float>>;\ntemplate struct FftR2CKernelUtil<DeviceType::kCPU, double, std::complex<double>>;\n\ntemplate struct FftC2RKernelUtil<DeviceType::kCPU, std::complex<float>, float>;\ntemplate struct FftC2RKernelUtil<DeviceType::kCPU, std::complex<double>, double>;\n\ntemplate struct FftStftKernelUtil<DeviceType::kCPU, float, std::complex<float>>;\ntemplate struct FftStftKernelUtil<DeviceType::kCPU, double, std::complex<double>>;\n}  // namespace oneflow"
  },
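The `_conj_symmetry_cpu` helper above exploits the Hermitian symmetry of a real signal's spectrum: bin `k` and bin `n - k` are complex conjugates, so an R2C transform only needs to produce bins `0..n/2` and the rest can be mirrored in. A minimal standalone sketch of that per-row mapping (plain C++ without the OneFlow helpers; the kernel generalizes this to n-d layouts via `NdIndexStrideOffsetHelper`):

```cpp
// Standalone illustration of the Hermitian fill used by _conj_symmetry_cpu.
#include <complex>
#include <iostream>
#include <vector>

int main() {
  const int n = 8;
  std::vector<std::complex<float>> spec(n);
  // Pretend bins [0, n/2] were produced by an R2C transform.
  for (int k = 0; k <= n / 2; ++k) { spec[k] = {float(k), k / 2.0f}; }
  // Fill the redundant upper half: bin k mirrors bin (n - k), conjugated.
  for (int k = n / 2 + 1; k < n; ++k) { spec[k] = std::conj(spec[n - k]); }
  for (int k = 0; k < n; ++k) {
    std::cout << k << ": (" << spec[k].real() << ", " << spec[k].imag() << ")\n";
  }
  return 0;
}
```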
  {
    "path": "oneflow/user/kernels/fft_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cuda.h>\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/user_op_tensor.h\"\n#include \"oneflow/user/kernels/to_contiguous_kernel.h\"\n\n#if CUDA_VERSION >= 11000\n#include \"oneflow/user/kernels/fft_kernel_util.h\"\n#include \"cufft_plan_cache.h\"\n\nnamespace oneflow {\n\nnamespace {\ntemplate<typename FFTTYPE>\n__global__ void fft_apply_normalization(FFTTYPE* dst, const double normalization_scale, size_t n,\n                                        bool IsNormalized) {\n  if (!IsNormalized) { return; }\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    dst[i].x *= normalization_scale;\n    dst[i].y *= normalization_scale;\n  };\n}\n\nstruct FillConjSymmetricParams {\n  int64_t last_dim;\n  int64_t elem_count;\n  int64_t ndim;\n  oneflow::NdIndexStrideOffsetHelper<int64_t, SHAPE_MAX_AXIS_SIZE> helper;\n  int64_t last_dim_size;\n  int64_t last_dim_half;\n\n  FillConjSymmetricParams() = default;\n  FillConjSymmetricParams(const Shape& shape, const Stride& strides, int64_t last_dim_,\n                          int64_t elemcnt)\n      : last_dim(last_dim_),\n        elem_count(elemcnt),\n        ndim(strides.size()),\n        helper(strides.data(), ndim) {\n    CHECK_OR_THROW(strides.size() == shape.size());\n    last_dim_size = shape[last_dim];\n    last_dim_half = last_dim_size / 2;\n  }\n};\n\n}  // namespace\n\ntemplate<typename T>\n__global__ void _conj_symmetry_cuda(T* data_out, FillConjSymmetricParams param) {\n  CUDA_1D_KERNEL_LOOP_T(int64_t, offset, param.elem_count) {\n    int64_t ndim = param.ndim;\n    int64_t indices[SHAPE_MAX_AXIS_SIZE];\n    param.helper.OffsetToNdIndex(offset, indices, ndim);\n    if (indices[param.last_dim] <= param.last_dim_half) { continue; }\n    int64_t cur_last_dim_index = indices[param.last_dim];\n    // get symmetric\n    indices[param.last_dim] = param.last_dim_size - cur_last_dim_index;\n    int64_t symmetric_offset = param.helper.NdIndexToOffset(indices, ndim);\n\n    // conj\n    data_out[offset] = T{data_out[symmetric_offset].x, -data_out[symmetric_offset].y};\n  }\n}\n\ntemplate<typename T>\nstruct FillConjSymmetryUtil<DeviceType::kCUDA, T> {\n  static void FillConjSymmetryForward(ep::Stream* stream, T* data_out, const Shape& shape,\n                                      const Stride& strides, const int64_t last_dim,\n                                      int64_t elem_count) {\n    FillConjSymmetricParams param(shape, strides, last_dim, elem_count);\n    _conj_symmetry_cuda<T><<<BlocksNum4ThreadsNum(elem_count), kCudaThreadsNumPerBlock, 0,\n                             stream->As<ep::CudaStream>()->cuda_stream()>>>(data_out, param);\n  }\n};\n\ntemplate<typename IN, typename OUT>\n__global__ void _convert_to_double_sized(const IN* in, OUT* dst, size_t len, size_t n) {\n  size_t fact_len = 2 * len - 2;\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    int index_x = i / fact_len;\n    int index_y = i % fact_len;\n    if 
(index_y == 0) {\n      dst[i] = in[index_x * len];\n    } else if (index_y == len - 1) {\n      dst[i] = in[(index_x + 1) * len - 1];\n    } else if (index_y < len - 1 && index_y > 0) {\n      dst[i] = in[index_x * len + index_y];\n    } else {\n      auto index = (index_x + 2) * len - index_y - 2;\n      dst[i].x = in[index].x;\n      dst[i].y = -in[index].y;\n    }\n  }\n}\n\ntemplate<typename IN, typename OUT>\n__global__ void _convert_complex_to_real(const IN* in, OUT* out, size_t n) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    out[2 * i] = in[i].x;\n    out[2 * i + 1] = in[i].y;\n  };\n}\n\ntemplate<typename real_type, typename complex_type>\nstruct ComplexConvertUtil<DeviceType::kCUDA, real_type, complex_type> {\n  static void ConvertToDoubleSized(ep::Stream* stream, const complex_type* in, complex_type* dst,\n                                   size_t len, size_t n) {\n    _convert_to_double_sized<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                               stream->As<ep::CudaStream>()->cuda_stream()>>>(in, dst, len, n);\n  }\n  static void ConvertComplexToReal(ep::Stream* stream, const complex_type* in, real_type* out,\n                                   size_t n) {\n    _convert_complex_to_real<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                               stream->As<ep::CudaStream>()->cuda_stream()>>>(in, out, n);\n  }\n};\n\ntemplate<typename dtype_in, typename dtype_out>\nclass StftGpuKernel final : public user_op::OpKernel {\n public:\n  StftGpuKernel() = default;\n  ~StftGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const bool normalized = ctx->Attr<bool>(\"normalized\");\n    const bool onesided = ctx->Attr<bool>(\"onesided\");\n    const bool return_complex = ctx->Attr<bool>(\"return_complex\");\n\n    const ShapeView& input_shape = input->shape_view();\n    const ShapeView& output_shape = output->shape_view();\n\n    const Stride& input_stride = input->stride();\n    const int out_elem_cnt =\n        return_complex ? 
output->shape_view().elem_cnt() : output->shape_view().elem_cnt() / 2;\n\n    const dtype_in* data_in = input->dptr<dtype_in>();\n    dtype_in* data_out = output->mut_dptr<dtype_in>();\n    dtype_out* out_tmp_buffer = reinterpret_cast<dtype_out*>(tmp_buffer->mut_dptr<char>());\n\n    int64_t ndim = 1;\n    int64_t batch = static_cast<int64_t>(input_shape.At(1));\n    int64_t fft_size = static_cast<int64_t>(input_shape.At(2));\n    const Stride& in_stride = {input_stride.at(1), input_stride.at(2)};\n    const Shape& in_shape = {batch, fft_size};\n    const Shape& out_shape = {batch, fft_size / 2 + 1};\n    Stride out_stride = Stride(out_shape);\n    CuFFTParams params(in_shape, out_shape, in_stride, out_stride, ndim, CUFFT_EXCUTETYPE::R2C,\n                       input->data_type());\n    CuFFTConfig config(params);\n    auto& plan = config.plan();\n    OF_CUFFT_CHECK(cufftSetStream(plan, ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n    void* workspace{};\n    OF_CUDA_CHECK(cudaMalloc(&workspace, config.workspace_size()));\n    OF_CUFFT_CHECK(cufftSetWorkArea(plan, workspace));\n\n    int64_t in_offset = input_stride.at(0);\n    // product of out_shape's dims; the running product must start from 1, not 0\n    int64_t out_offset = std::accumulate(out_shape.begin(), out_shape.end(),\n                                         static_cast<int64_t>(1), std::multiplies<int64_t>());\n    int64_t signal_groups_count = static_cast<int64_t>(input_shape.At(0));\n    for (int64_t i = 0; i < signal_groups_count; i++) {\n      config.excute((void*)(data_in + i * in_offset), (void*)(out_tmp_buffer + i * out_offset),\n                    /*forward=*/true);\n    }\n    OF_CUDA_CHECK(cudaFree(workspace));\n\n    if (!onesided) {\n      size_t last_dim_length = fft_size / 2 + 1;\n      dtype_out* doublesided_tmp_buffer =\n          reinterpret_cast<dtype_out*>(tmp_buffer->mut_dptr<char>()) + out_elem_cnt;\n      ComplexConvertUtil<DeviceType::kCUDA, dtype_in, dtype_out>::ConvertToDoubleSized(\n          ctx->stream(), out_tmp_buffer, doublesided_tmp_buffer, last_dim_length, out_elem_cnt);\n      out_tmp_buffer = doublesided_tmp_buffer;\n    }\n\n    const double normalization_scale =\n        _fft_normalization_scale<double>(input_shape.back(), normalized);\n    fft_apply_normalization<<<BlocksNum4ThreadsNum(out_elem_cnt), kCudaThreadsNumPerBlock, 0,\n                              ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        out_tmp_buffer, normalization_scale, out_elem_cnt, normalized);\n\n    if (!return_complex) {\n      ComplexConvertUtil<DeviceType::kCUDA, dtype_in, dtype_out>::ConvertComplexToReal(\n          ctx->stream(), out_tmp_buffer, data_out, out_elem_cnt);\n    } else {\n      // TODO(yzm): support return_complex after oneflow supports complex numbers\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_STFT_GPU_KERNEL(intype, outtype)                                           \\\n  REGISTER_USER_KERNEL(\"stft\")                                                              \\\n      .SetCreateFn<StftGpuKernel<intype, outtype>>()                                        \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                      \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<intype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                   \\\n        const Shape& output_shape = ctx->InputShape(\"output\", 0);                           \\\n        const bool return_complex = ctx->Attr<bool>(\"return_complex\");                      \\\n        const bool onesided = ctx->Attr<bool>(\"onesided\");                                  \\\n        int64_t output_elem_cnt =                                                           \\\n            return_complex ? output_shape.elem_cnt() : output_shape.elem_cnt() / 2;         \\\n        const int64_t output_bytes = GetCudaAlignedSize(output_elem_cnt * sizeof(outtype)); \\\n        return onesided ? output_bytes : 2 * output_bytes;                                  \\\n      });\n\nREGISTER_STFT_GPU_KERNEL(float, cufftComplex)\nREGISTER_STFT_GPU_KERNEL(double, cufftDoubleComplex)\n\ntemplate<typename T, typename FCT_TYPE>\nstruct FftC2CKernelUtil<DeviceType::kCUDA, T, FCT_TYPE> {\n  static void FftC2CForward(ep::Stream* stream, const T* data_in, T* data_out,\n                            const Shape& input_shape, const Shape& output_shape,\n                            const Stride& input_stride, const Stride& output_stride, bool forward,\n                            const std::vector<int64_t>& dims, FCT_TYPE normalization,\n                            DataType real_type) {\n    // NOTE: before calling `FftC2CKernelUtil<DeviceType::kCUDA, T, FCT_TYPE>`, input must be\n    // batched out already\n    CuFFTParams params(input_shape, output_shape, input_stride, output_stride, dims.size(),\n                       CUFFT_EXCUTETYPE::C2C, real_type);\n    CuFFTConfig config(params);\n    auto& plan = config.plan();\n    OF_CUFFT_CHECK(cufftSetStream(plan, stream->As<ep::CudaStream>()->cuda_stream()));\n    void* workspace{};\n    OF_CUDA_CHECK(cudaMalloc(&workspace, config.workspace_size()));\n    OF_CUFFT_CHECK(cufftSetWorkArea(plan, workspace));\n\n    config.excute((void*)data_in, (void*)data_out, forward);\n    OF_CUDA_CHECK(cudaFree(workspace));\n  }\n};\n\ntemplate<typename IN, typename OUT>\nstruct FftR2CKernelUtil<DeviceType::kCUDA, IN, OUT> {\n  static void FftR2CForward(ep::Stream* stream, const IN* data_in, OUT* data_out,\n                            const Shape& input_shape, const Shape& output_shape,\n                            const Stride& input_stride, const Stride& output_stride, bool forward,\n                            const std::vector<int64_t>& dims, IN normalization,\n                            DataType real_type) {\n    // NOTE: before calling `FftR2CKernelUtil<DeviceType::kCUDA, IN, OUT>`, input must be batched\n    // out already\n    CuFFTParams params(input_shape, output_shape, input_stride, output_stride, dims.size(),\n                       CUFFT_EXCUTETYPE::R2C, real_type);\n    CuFFTConfig config(params);\n    auto& plan = config.plan();\n    OF_CUFFT_CHECK(cufftSetStream(plan, stream->As<ep::CudaStream>()->cuda_stream()));\n    void* workspace{};\n    OF_CUDA_CHECK(cudaMalloc(&workspace, config.workspace_size()));\n    OF_CUFFT_CHECK(cufftSetWorkArea(plan, workspace));\n\n    config.excute((void*)data_in, (void*)data_out, forward);\n    OF_CUDA_CHECK(cudaFree(workspace));\n  }\n};\n\ntemplate<typename IN, typename OUT>\nstruct FftC2RKernelUtil<DeviceType::kCUDA, IN, OUT> {\n  static void FftC2RForward(ep::Stream* stream, const IN* data_in, OUT* data_out,\n                            const Shape& input_shape, const Shape& output_shape,\n                            const Stride& input_stride, const Stride& output_stride, bool forward,\n                            int64_t last_dim_size, const std::vector<int64_t>& dims,\n                            OUT normalization, DataType real_type) {\n    // NOTE: before calling `FftC2RKernelUtil<DeviceType::kCUDA, IN, OUT>`, input must be batched\n    // out already\n    CuFFTParams params(input_shape, output_shape, input_stride, output_stride, dims.size(),\n                       CUFFT_EXCUTETYPE::C2R, real_type);\n    CuFFTConfig config(params);\n    auto& plan = config.plan();\n    OF_CUFFT_CHECK(cufftSetStream(plan, stream->As<ep::CudaStream>()->cuda_stream()));\n    void* workspace{};\n    OF_CUDA_CHECK(cudaMalloc(&workspace, config.workspace_size()));\n    OF_CUFFT_CHECK(cufftSetWorkArea(plan, workspace));\n\n    config.excute((void*)data_in, (void*)data_out, forward);\n    OF_CUDA_CHECK(cudaFree(workspace));\n  }\n};\n\ntemplate struct FillConjSymmetryUtil<DeviceType::kCUDA, cuComplex>;\ntemplate struct FillConjSymmetryUtil<DeviceType::kCUDA, cuDoubleComplex>;\n\ntemplate struct ComplexConvertUtil<DeviceType::kCUDA, float, cuComplex>;\ntemplate struct ComplexConvertUtil<DeviceType::kCUDA, double, cuDoubleComplex>;\n\ntemplate struct FftC2CKernelUtil<DeviceType::kCUDA, cuComplex, /*FCT_TYPE=*/float>;\ntemplate struct FftC2CKernelUtil<DeviceType::kCUDA, cuDoubleComplex, /*FCT_TYPE=*/double>;\n\ntemplate struct FftR2CKernelUtil<DeviceType::kCUDA, float, cuComplex>;\ntemplate struct FftR2CKernelUtil<DeviceType::kCUDA, double, cuDoubleComplex>;\n\ntemplate struct FftC2RKernelUtil<DeviceType::kCUDA, cuComplex, float>;\ntemplate struct FftC2RKernelUtil<DeviceType::kCUDA, cuDoubleComplex, double>;\n}  // namespace oneflow\n\n#endif  // CUDA_VERSION >= 11000\n"
  },
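The branch ladder in `_convert_to_double_sized` reduces, per row, to "copy the onesided bins, then mirror-and-conjugate the rest": with onesided width `len`, the full width is `fact_len = 2*len - 2`, and column `i >= len` reads column `fact_len - i` (the `(index_x + 2) * len - index_y - 2` expression is exactly `row_base + (fact_len - index_y)`). A host-side sketch of the same mapping for a single row, for illustration only:

```cpp
// Host-side sketch of the index arithmetic in _convert_to_double_sized.
#include <complex>
#include <cstdio>
#include <vector>

int main() {
  const size_t len = 5;                 // onesided columns, n/2 + 1 for n = 8
  const size_t fact_len = 2 * len - 2;  // full (double-sided) columns
  std::vector<std::complex<float>> in(len), dst(fact_len);
  for (size_t k = 0; k < len; ++k) { in[k] = {float(k), float(k)}; }
  for (size_t i = 0; i < fact_len; ++i) {
    if (i < len) {
      dst[i] = in[i];                        // directly copied bins
    } else {
      dst[i] = std::conj(in[fact_len - i]);  // mirrored, conjugated bins
    }
  }
  for (size_t i = 0; i < fact_len; ++i) {
    std::printf("%zu: (%g, %g)\n", i, dst[i].real(), dst[i].imag());
  }
  return 0;
}
```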
  {
    "path": "oneflow/user/kernels/fft_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_FFT_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_FFT_KERNEL_UTIL_H_\n\n#include <cstdint>\n#include <type_traits>\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\ninline T _fft_normalization_scale(const int32_t frame_length, bool normalized) {\n  if (!normalized) { return static_cast<T>(1.0); }\n  return static_cast<T>(1.0 / std::sqrt(frame_length));\n}\n\ntemplate<DeviceType device_type, typename T>\nstruct FillConjSymmetryUtil {\n  static void FillConjSymmetryForward(ep::Stream* stream, T* data_out, const Shape& shape,\n                                      const Stride& strides, const int64_t last_dim,\n                                      int64_t elem_count);\n};\n\ntemplate<DeviceType device_type, typename real_type, typename complex_type>\nstruct ComplexConvertUtil {\n  static void ConvertToDoubleSized(ep::Stream* stream, const complex_type* in, complex_type* dst,\n                                   size_t len, size_t n);\n  static void ConvertComplexToReal(ep::Stream* stream, const complex_type* in, real_type* out,\n                                   size_t n);\n};\n\ntemplate<DeviceType device_type, typename T, typename FCT_TYPE>\nstruct FftC2CKernelUtil {\n  static void FftC2CForward(ep::Stream* stream, const T* data_in, T* data_out,\n                            const Shape& input_shape, const Shape& output_shape,\n                            const Stride& input_stride, const Stride& output_stride, bool forward,\n                            const std::vector<int64_t>& dims, FCT_TYPE norm_fct,\n                            DataType real_type);\n};\n\ntemplate<DeviceType device_type, typename IN, typename OUT>\nstruct FftR2CKernelUtil {\n  static void FftR2CForward(ep::Stream* stream, const IN* data_in, OUT* data_out,\n                            const Shape& input_shape, const Shape& output_shape,\n                            const Stride& input_stride, const Stride& output_stride, bool forward,\n                            const std::vector<int64_t>& dims, IN norm_fct, DataType real_type);\n};\n\ntemplate<DeviceType device_type, typename IN, typename OUT>\nstruct FftC2RKernelUtil {\n  static void FftC2RForward(ep::Stream* stream, const IN* data_in, OUT* data_out,\n                            const Shape& input_shape, const Shape& output_shape,\n                            const Stride& input_stride, const Stride& output_stride, bool forward,\n                            int64_t last_dim_size, const std::vector<int64_t>& dims, OUT norm_fct,\n                            DataType real_type);\n};\n\ntemplate<DeviceType device_type, typename IN, typename 
OUT>\nstruct FftStftKernelUtil {\n  static void FftStftForward(ep::Stream* stream, const IN* data_in, OUT* data_out,\n                             const Shape& input_shape, const Shape& output_shape,\n                             const Stride& input_stride, const Stride& output_stride, bool forward,\n                             const std::vector<int64_t>& axes, IN norm_fct, int64_t len,\n                             int64_t dims, int64_t batch);\n};\n\n}  // namespace oneflow\n#endif  // ONEFLOW_USER_KERNELS_FFT_KERNEL_UTIL_H_"
  },
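`_fft_normalization_scale` encodes the ortho-style STFT scaling: identity when `normalized` is false, otherwise `1/sqrt(frame_length)`. A quick sanity check, with the template replicated locally so the snippet compiles on its own:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Local copy of the header's helper, for a self-contained check.
template<typename T>
T fft_normalization_scale(const int32_t frame_length, bool normalized) {
  if (!normalized) { return static_cast<T>(1.0); }
  return static_cast<T>(1.0 / std::sqrt(frame_length));
}

int main() {
  assert(fft_normalization_scale<double>(256, false) == 1.0);
  // A normalized 256-point frame is scaled by 1/sqrt(256) = 1/16 = 0.0625.
  assert(std::abs(fft_normalization_scale<double>(256, true) - 0.0625) < 1e-12);
  return 0;
}
```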
  {
    "path": "oneflow/user/kernels/fft_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <complex>\n#include <cstdint>\n#include \"pocketfftplan.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/user/kernels/fft_kernel_util.h\"\n\nusing namespace pocketfft;\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, typename FCT_TYPE>\nclass FftC2CKernel final : public user_op::OpKernel {\n public:\n  FftC2CKernel() = default;\n  ~FftC2CKernel() = default;\n\n private:\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    bool forward = ctx->Attr<bool>(\"forward\");\n    double norm_fct = ctx->Attr<double>(\"norm_fct\");\n\n    const std::vector<int64_t>& dims = ctx->Attr<std::vector<int64_t>>(\"dims\");\n\n    const T* input_ptr = input->dptr<T>();\n    T* out_ptr = out->mut_dptr<T>();\n\n    Shape input_shape(input->shape_view());\n    Shape out_shape(out->shape_view());\n\n    if (input->data_type() == kComplex64) {\n      FftC2CKernelUtil<device_type, T, FCT_TYPE>::FftC2CForward(\n          ctx->stream(), input_ptr, out_ptr, input_shape, out_shape, input->stride(), out->stride(),\n          forward, dims, static_cast<FCT_TYPE>(norm_fct), DataType::kFloat);\n    } else if (input->data_type() == kComplex128) {\n      FftC2CKernelUtil<device_type, T, FCT_TYPE>::FftC2CForward(\n          ctx->stream(), input_ptr, out_ptr, input_shape, out_shape, input->stride(), out->stride(),\n          forward, dims, static_cast<FCT_TYPE>(norm_fct), DataType::kDouble);\n    } else {\n      CHECK_OR_THROW(false) << \"expects kComplex64 or kComplex128, but got \" << input->data_type();\n    }\n  }\n};\n\ntemplate<DeviceType device_type, typename dtype_in, typename dtype_out>\nclass FftR2CKernel final : public user_op::OpKernel {\n public:\n  FftR2CKernel() = default;\n  ~FftR2CKernel() = default;\n\n private:\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    bool onesided = ctx->Attr<bool>(\"onesided\");\n    double norm_fct = ctx->Attr<double>(\"norm_fct\");\n    const std::vector<int64_t>& dims = ctx->Attr<std::vector<int64_t>>(\"dims\");\n    const dtype_in* input_ptr = input->dptr<dtype_in>();\n    dtype_out* out_ptr = out->mut_dptr<dtype_out>();\n\n    Shape input_shape(input->shape_view());\n    Shape out_shape(out->shape_view());\n\n    if (input->data_type() == kFloat || input->data_type() == kDouble) {\n      FftR2CKernelUtil<device_type, dtype_in, dtype_out>::FftR2CForward(\n          ctx->stream(), input_ptr, out_ptr, input_shape, out_shape, 
input->stride(), out->stride(),\n          /*forward=*/true, dims, norm_fct, /*real_type=*/input->data_type());\n    } else {\n      CHECK_OR_THROW(false) << \"expects kFloat or kDouble, but got \" << input->data_type();\n    }\n\n    if (!onesided) {\n      FillConjSymmetryUtil<device_type, dtype_out>::FillConjSymmetryForward(\n          ctx->stream(), out_ptr, out_shape, out->stride(), dims.back(), out_shape.elem_cnt());\n    }\n  }\n};\n\ntemplate<DeviceType device_type, typename dtype_in, typename dtype_out>\nclass FftC2RKernel final : public user_op::OpKernel {\n public:\n  FftC2RKernel() = default;\n  ~FftC2RKernel() = default;\n\n private:\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    int64_t last_dim_size = ctx->Attr<int64_t>(\"last_dim_size\");\n    double norm_fct = ctx->Attr<double>(\"norm_fct\");\n    const std::vector<int64_t>& dims = ctx->Attr<std::vector<int64_t>>(\"dims\");\n\n    const dtype_in* input_ptr = input->dptr<dtype_in>();\n    dtype_out* out_ptr = out->mut_dptr<dtype_out>();\n\n    Shape input_shape(input->shape_view());\n    Shape out_shape(out->shape_view());\n\n    out_shape[dims.back()] = last_dim_size;\n\n    if (input->data_type() == kComplex64 || input->data_type() == kComplex128) {\n      FftC2RKernelUtil<device_type, dtype_in, dtype_out>::FftC2RForward(\n          ctx->stream(), input_ptr, out_ptr, input_shape, out_shape, input->stride(), out->stride(),\n          /*forward=*/false,\n          /*last_dim_size=*/last_dim_size, dims, norm_fct, /*real_type=*/out->data_type());\n    } else {\n      CHECK_OR_THROW(false) << \"expects kComplex64 or kComplex128, but got \" << input->data_type();\n    }\n  }\n};\n\ntemplate<DeviceType device_type, typename dtype_in, typename dtype_out>\nclass StftCpuKernel final : public user_op::OpKernel {\n public:\n  StftCpuKernel() = default;\n  ~StftCpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const auto normalized = ctx->Attr<bool>(\"normalized\");\n    const auto return_complex = ctx->Attr<bool>(\"return_complex\");\n    const bool onesided = ctx->Attr<bool>(\"onesided\");\n\n    const ShapeView input_shape = input->shape_view();\n    const ShapeView output_shape = output->shape_view();\n    const auto output_elem_cnt = output_shape.elem_cnt() / 2;\n\n    int64_t dims = input_shape.At(0);\n    int64_t batch = input_shape.At(1);\n    int64_t len = input_shape.back();\n    const dtype_in* data_in = input->dptr<dtype_in>();\n    dtype_in* data_out = output->mut_dptr<dtype_in>();\n\n    dtype_out* out_tmp_buffer = reinterpret_cast<dtype_out*>(tmp_buffer->mut_dptr<char>());\n    Shape out_tmp_shape = Shape{len};\n    Stride out_tmp_stride = Stride(out_tmp_shape);\n    std::vector<int64_t> axes(out_tmp_shape.size());\n    std::iota(axes.begin(), axes.end(), 0);\n    auto norm_fct = _fft_normalization_scale<dtype_in>(len, normalized);\n    FftStftKernelUtil<device_type, dtype_in, dtype_out>::FftStftForward(\n        ctx->stream(), data_in, out_tmp_buffer, out_tmp_shape, out_tmp_shape, out_tmp_stride,\n        out_tmp_stride, true, /*axes=*/axes, /*norm_fct=*/norm_fct,\n        /*len=*/len, /*dims=*/dims, /*batch=*/batch);\n\n    if (!onesided) {\n      dtype_out* doublesided_tmp_buffer =\n          reinterpret_cast<dtype_out*>(tmp_buffer->mut_dptr<char>()) + output_elem_cnt;\n      size_t last_dim_length = len / 2 + 1;\n      size_t elem_count = output_elem_cnt;\n      ComplexConvertUtil<DeviceType::kCPU, dtype_in, dtype_out>::ConvertToDoubleSized(\n          ctx->stream(), out_tmp_buffer, doublesided_tmp_buffer, last_dim_length, elem_count);\n      out_tmp_buffer = doublesided_tmp_buffer;\n    }\n\n    if (!return_complex) {\n      ComplexConvertUtil<DeviceType::kCPU, dtype_in, dtype_out>::ConvertComplexToReal(\n          ctx->stream(), out_tmp_buffer, data_out, output_elem_cnt);\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_STFT_CPU_KERNEL(dtype_in, dtype_out)                                         \\\n  REGISTER_USER_KERNEL(\"stft\")                                                                \\\n      .SetCreateFn<StftCpuKernel<DeviceType::kCPU, dtype_in, dtype_out>>()                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == kCPU)                                     \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype_in>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                     \\\n        const Shape& output_shape = ctx->InputShape(\"output\", 0);                             \\\n        const bool return_complex = ctx->Attr<bool>(\"return_complex\");                        \\\n        const bool onesided = ctx->Attr<bool>(\"onesided\");                                    \\\n        int64_t output_elem_cnt =                                                             \\\n            return_complex ? output_shape.elem_cnt() : output_shape.elem_cnt() / 2;           \\\n        const int64_t output_bytes = (output_elem_cnt * sizeof(dtype_out));                   \\\n        return onesided ? output_bytes : 2 * output_bytes;                                    \\\n      });\n\nREGISTER_STFT_CPU_KERNEL(double, std::complex<double>)\nREGISTER_STFT_CPU_KERNEL(float, std::complex<float>)\n\n#define REGISTER_FFTC2C_KERNELS(device_type, dtype, fct_type)                             \\\n  REGISTER_USER_KERNEL(\"fft_c2c\")                                                         \\\n      .SetCreateFn<FftC2CKernel<device_type, dtype, fct_type>>()                          \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type)                          \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value))\n\nREGISTER_FFTC2C_KERNELS(DeviceType::kCPU, std::complex<float>, float);\nREGISTER_FFTC2C_KERNELS(DeviceType::kCPU, std::complex<double>, double);\n#ifdef WITH_CUDA\nREGISTER_FFTC2C_KERNELS(DeviceType::kCUDA, cuComplex, float);\nREGISTER_FFTC2C_KERNELS(DeviceType::kCUDA, cuDoubleComplex, double);\n#endif\n\n#define REGISTER_FFTR2C_KERNELS(device_type, dtype_in, dtype_out)                            \\\n  REGISTER_USER_KERNEL(\"fft_r2c\")                                                            \\\n      .SetCreateFn<FftR2CKernel<device_type, dtype_in, dtype_out>>()                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type)                             \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype_in>::value) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype_out>::value))\n\nREGISTER_FFTR2C_KERNELS(DeviceType::kCPU, float, std::complex<float>);\nREGISTER_FFTR2C_KERNELS(DeviceType::kCPU, double, std::complex<double>);\n#ifdef WITH_CUDA\nREGISTER_FFTR2C_KERNELS(DeviceType::kCUDA, float, cuComplex);\nREGISTER_FFTR2C_KERNELS(DeviceType::kCUDA, double, cuDoubleComplex);\n#endif\n\n#define REGISTER_FFTC2R_KERNELS(device_type, dtype_in, dtype_out)                            \\\n  REGISTER_USER_KERNEL(\"fft_c2r\")                                                            \\\n      .SetCreateFn<FftC2RKernel<device_type, dtype_in, dtype_out>>()                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type)                             \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype_in>::value) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype_out>::value))\n\nREGISTER_FFTC2R_KERNELS(DeviceType::kCPU, std::complex<float>, float);\nREGISTER_FFTC2R_KERNELS(DeviceType::kCPU, std::complex<double>, double);\n#ifdef WITH_CUDA\nREGISTER_FFTC2R_KERNELS(DeviceType::kCUDA, cuComplex, float);\nREGISTER_FFTC2R_KERNELS(DeviceType::kCUDA, cuDoubleComplex, double);\n#endif\n}  // namespace oneflow"
  },
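The `SetInferTmpSizeFn` lambdas above size the STFT scratch buffer from the output shape: with `return_complex == false` the output interleaves real/imag pairs, so the complex element count is half the raw element count, and the double-sided path (`onesided == false`) needs a second staging buffer of the same size. A worked version of that arithmetic (`InferStftTmpBytes` is a hypothetical helper mirroring the CPU macro, which does not CUDA-align the result):

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical helper reproducing the CPU registration's tmp-size logic.
int64_t InferStftTmpBytes(int64_t raw_output_elem_cnt, bool return_complex, bool onesided,
                          int64_t bytes_per_complex) {
  // Interleaved real output stores two scalars per complex value.
  const int64_t output_elem_cnt = return_complex ? raw_output_elem_cnt : raw_output_elem_cnt / 2;
  const int64_t output_bytes = output_elem_cnt * bytes_per_complex;
  // The double-sided path stages a second buffer of the same size.
  return onesided ? output_bytes : 2 * output_bytes;
}

int main() {
  // Output shape (1, 4, 129, 2): 1032 raw elements -> 516 std::complex<float>.
  std::printf("%lld\n", (long long)InferStftTmpBytes(1032, false, true, 8));   // 4128
  std::printf("%lld\n", (long long)InferStftTmpBytes(1032, false, false, 8));  // 8256
  return 0;
}
```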
  {
    "path": "oneflow/user/kernels/fill_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Fill> NewFillPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->device_type(), data_type);\n}\n\n}  // namespace\n\nclass FillKernel final : public user_op::OpKernel {\n public:\n  FillKernel() = default;\n  ~FillKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const bool is_floating_value = ctx->Attr<bool>(\"is_floating_value\");\n    const Scalar value = is_floating_value ? Scalar(ctx->Attr<double>(\"floating_value\"))\n                                           : Scalar(ctx->Attr<int64_t>(\"integral_value\"));\n    const int32_t elem_cnt = in->shape_view().elem_cnt();\n    CHECK_GE(elem_cnt, 0);\n    if (elem_cnt == 0) { return; }\n    std::unique_ptr<ep::primitive::Fill> fill = NewFillPrimitive(ctx);\n    CHECK(fill);\n    fill->Launch(ctx->stream(), out->mut_dptr(), value, elem_cnt);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nauto FillPrimitiveExists() {\n  return hob::make_custom(\"FillPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewFillPrimitive(&ctx).operator bool();\n  });\n}\n\ntemplate<typename T>\nclass FillTensorCpuKernel final : public user_op::OpKernel {\n public:\n  FillTensorCpuKernel() = default;\n  ~FillTensorCpuKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex(\"value\", 0);\n    const T value_ = value->dptr<T>()[0];\n    const int32_t elem_cnt = in->shape_view().elem_cnt();\n    T* out_ptr = out->mut_dptr<T>();\n    FOR_RANGE(int32_t, i, 0, elem_cnt) { out_ptr[i] = value_; }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FILL_CPU_KERNEL(dtype)                               \\\n  REGISTER_USER_KERNEL(\"fill_tensor_\")                                \\\n      .SetCreateFn<FillTensorCpuKernel<dtype>>()                      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"out\", 0) == 
GetDataType<dtype>::value));\n\nREGISTER_FILL_CPU_KERNEL(float)\nREGISTER_FILL_CPU_KERNEL(float16)\nREGISTER_FILL_CPU_KERNEL(double)\nREGISTER_FILL_CPU_KERNEL(int8_t)\nREGISTER_FILL_CPU_KERNEL(int32_t)\nREGISTER_FILL_CPU_KERNEL(int64_t)\nREGISTER_USER_KERNEL(\"fill_\").SetCreateFn<FillKernel>().SetIsMatchedHob(FillPrimitiveExists()\n                                                                        == true);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fill_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\ntemplate<typename T>\n__global__ void FillTensorGpuForward(const int n, const T* value, T* y) {\n  CUDA_1D_KERNEL_LOOP(i, n) { y[i] = value[0]; }\n}\n};  // namespace\n\ntemplate<typename T>\nclass FillTensorGpuKernel final : public user_op::OpKernel {\n public:\n  FillTensorGpuKernel() = default;\n  ~FillTensorGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex(\"value\", 0);\n    const int32_t elem_cnt = in->shape_view().elem_cnt();\n    RUN_CUDA_KERNEL((FillTensorGpuForward<T>), ctx->stream(), elem_cnt, elem_cnt, value->dptr<T>(),\n                    out->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FILL_CUDA_KERNEL(dtype)                               \\\n  REGISTER_USER_KERNEL(\"fill_tensor_\")                                 \\\n      .SetCreateFn<FillTensorGpuKernel<dtype>>()                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FILL_CUDA_KERNEL(float)\nREGISTER_FILL_CUDA_KERNEL(half)\nREGISTER_FILL_CUDA_KERNEL(double)\nREGISTER_FILL_CUDA_KERNEL(int8_t)\nREGISTER_FILL_CUDA_KERNEL(int32_t)\nREGISTER_FILL_CUDA_KERNEL(int64_t)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/flip_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconst int32_t NDIMS = 16;\n\nstruct SIZE_V {\n  int32_t val[NDIMS];\n};\n\nstruct VIS {\n  bool val[NDIMS] = {false};\n};\n\ntemplate<typename T>\nvoid FlipCpuForward(const int32_t element, const int64_t total_dims, const SIZE_V sizes_v,\n                    const VIS vis, SIZE_V strides_v, const T* in_dptr, T* out_dptr) {\n  for (int i = 0; i < element; i++) {\n    int32_t cur_indices = i;\n    int32_t rem = 0;\n    int32_t dst_offset = 0;\n    for (int32_t d = 0; d < total_dims; d++) {\n      int32_t temp = cur_indices;\n      cur_indices = cur_indices / strides_v.val[d];\n      rem = temp - cur_indices * strides_v.val[d];\n      dst_offset += vis.val[d] ? (sizes_v.val[d] - 1 - cur_indices) * strides_v.val[d]\n                               : cur_indices * strides_v.val[d];\n      cur_indices = rem;\n    }\n    out_dptr[i] = in_dptr[dst_offset];\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass FlipCpuKernel final : public user_op::OpKernel {\n public:\n  FlipCpuKernel() = default;\n  ~FlipCpuKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const int32_t elem_cnt = y_tensor->shape_view().elem_cnt();\n    if (elem_cnt == 0) { return; }\n    const int32_t total_dims = y_tensor->shape_view().NumAxes();\n\n    std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>(\"dims\");\n    VIS vis;\n    for (auto x : dims) { vis.val[x] = true; }\n\n    SIZE_V sizes_v;\n    for (int32_t i = 0; i < total_dims; i++) { sizes_v.val[i] = y_tensor->shape_view().At(i); }\n\n    // TODO(bbuf) delete strides caluculate, after tensor strides supported\n    SIZE_V strides_v;\n    strides_v.val[total_dims - 1] = 1;\n    for (int32_t i = total_dims - 2; i >= 0; i--) {\n      strides_v.val[i] = strides_v.val[i + 1] * y_tensor->shape_view().At(i + 1);\n    }\n\n    FlipCpuForward(elem_cnt, total_dims, sizes_v, vis, strides_v, x_tensor->dptr<T>(),\n                   y_tensor->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FLIP_CPU_KERNEL(dtype)                                             \\\n  REGISTER_USER_KERNEL(\"flip\").SetCreateFn<FlipCpuKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCPU)                                \\\n      && (user_op::HobDataType(\"y\", 0) == 
GetDataType<dtype>::value));\n\nREGISTER_FLIP_CPU_KERNEL(bool)\nREGISTER_FLIP_CPU_KERNEL(float)\nREGISTER_FLIP_CPU_KERNEL(double)\nREGISTER_FLIP_CPU_KERNEL(uint8_t)\nREGISTER_FLIP_CPU_KERNEL(int8_t)\nREGISTER_FLIP_CPU_KERNEL(int32_t)\nREGISTER_FLIP_CPU_KERNEL(int64_t)\n\n}  // namespace oneflow\n"
  },
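The inner loop of `FlipCpuForward` is a mixed-radix digit decomposition: dividing the linear index by each (contiguous) stride peels off that dimension's coordinate, and coordinates of flipped dimensions are mirrored before being re-accumulated into the source offset. A self-contained sketch flipping dim 1 of a row-major 2x4 tensor:

```cpp
// Standalone walk-through of the flip index transform.
#include <cstdio>

int main() {
  const int sizes[2] = {2, 4};            // tensor shape
  const int strides[2] = {4, 1};          // contiguous row-major strides
  const bool flipped[2] = {false, true};  // flip only dim 1
  const int in[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  int out[8];
  for (int i = 0; i < 8; ++i) {
    int cur = i;
    int src = 0;
    for (int d = 0; d < 2; ++d) {
      const int coord = cur / strides[d];  // peel off this dim's coordinate
      cur -= coord * strides[d];           // remainder for the inner dims
      src += (flipped[d] ? sizes[d] - 1 - coord : coord) * strides[d];
    }
    out[i] = in[src];
  }
  for (int i = 0; i < 8; ++i) { std::printf("%d ", out[i]); }  // 3 2 1 0 7 6 5 4
  std::printf("\n");
  return 0;
}
```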
  {
    "path": "oneflow/user/kernels/flip_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconst int32_t NDIMS = 16;\nstruct SIZE_V {\n  int32_t val[NDIMS];\n};\n\nstruct VIS {\n  bool val[NDIMS] = {false};\n};\n\ntemplate<typename T>\n__global__ void FlipGpuForward(const int32_t element, const int64_t total_dims,\n                               const SIZE_V sizes_v, const VIS vis, SIZE_V strides_v,\n                               const T* in_dptr, T* out_dptr) {\n  CUDA_1D_KERNEL_LOOP(i, element) {\n    int32_t cur_indices = i;\n    int32_t rem = 0;\n    int32_t dst_offset = 0;\n    for (int32_t d = 0; d < total_dims; d++) {\n      int32_t temp = cur_indices;\n      cur_indices = cur_indices / strides_v.val[d];\n      rem = temp - cur_indices * strides_v.val[d];\n      dst_offset += vis.val[d] ? (sizes_v.val[d] - 1 - cur_indices) * strides_v.val[d]\n                               : cur_indices * strides_v.val[d];\n      cur_indices = rem;\n    }\n    out_dptr[i] = in_dptr[dst_offset];\n  }\n}\n\n/*\nExample tensor:\n[[0, 1, 2, 3, 4, 5, 6, 7],\n [8, 9, 10, 11, 12, 13, 14]]\n\nGiven parameters: BlockSize=4, GridSize=4\nFor each block_i, `block_begin_idx` is calculated as (i - 1) * BlockSize = (i - 1) * 4,\nand `thread_end_idx` is set to 4 for all blocks except the final block.\nIn the final block, `thread_end_idx` is 2, representing the border index of the active thread.\n\n`i_ori` is an index referring to the original position of data stored in shm[threadIdx.x] before\nflipping. For instance, consider block 1 and thread 2 (element 6). The element is located at row 0,\ncolumn 7 in the tensor. 
Its original index `i_ori` is 7, and after flipping, it is mapped to row 0,\ncolumn 0.\n\n                    ┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐\nglobal mem before:  │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │ 7 │ 8 │ 9 │ A │ B │ C │ D │ x │ x │\n                    └───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘\n\n                         block0     │    block1     │    block2     │    block3\n                    ┌───┬───┬───┬───┼───┬───┬───┬───┼───┬───┬───┬───┼───┬───┬───┬───┐\nshm after loading:  │ 3 │ 2 │ 1 │ 0 │ 7 │ 6 │ 5 │ 4 │ B │ A │ 9 │ 8 │ D │ C │ x │ x │\n                    └───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘\n\n                    ┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐\nglobal mem after:   │ 6 │ 5 │ 4 │ 3 │ 2 │ 1 │ 0 │ D │ C │ B │ A │ 9 │ 8 │ 7 │ x │ x │\n                    └───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘\n*/\ntemplate<typename T>\n__global__ void FlipLastDimGpuForward(const int32_t element, const int64_t last_dim_size,\n                                      const T* in_dptr, T* out_dptr) {\n  __shared__ T shm[ep::CudaStream::kDefaultBlockSize];\n  CUDA_1D_KERNEL_LOOP(i, element) {\n    int32_t block_begin_idx = blockDim.x * blockIdx.x;\n    int32_t thread_end_idx = min(block_begin_idx + blockDim.x, element) - block_begin_idx;\n    int32_t i_ori = block_begin_idx + (thread_end_idx - threadIdx.x - 1);\n    shm[threadIdx.x] = in_dptr[i_ori];\n    __syncthreads();\n    int32_t row = i_ori / last_dim_size;\n    int32_t col = last_dim_size - (i_ori - row * last_dim_size) - 1;\n    out_dptr[row * last_dim_size + col] = shm[threadIdx.x];\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass FlipGpuKernel final : public user_op::OpKernel {\n public:\n  FlipGpuKernel() = default;\n  ~FlipGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const int32_t elem_cnt = y_tensor->shape_view().elem_cnt();\n    if (elem_cnt == 0) { return; }\n    const int32_t total_dims = y_tensor->shape_view().NumAxes();\n\n    std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>(\"dims\");\n    VIS vis;\n    for (auto x : dims) { vis.val[x] = true; }\n\n    if (dims.size() == 1 && dims[0] == x_tensor->shape_view().NumAxes() - 1) {\n      RUN_CUDA_KERNEL((FlipLastDimGpuForward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                      x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr<T>(),\n                      y_tensor->mut_dptr<T>());\n      return;\n    }\n\n    SIZE_V sizes_v;\n    for (int32_t i = 0; i < total_dims; i++) { sizes_v.val[i] = y_tensor->shape_view().At(i); }\n\n    SIZE_V strides_v;\n    for (int32_t i = 0; i < total_dims; i++) {\n      strides_v.val[i] = CHECK_JUST(VectorAt(y_tensor->stride(), i));\n    }\n    RUN_CUDA_KERNEL((FlipGpuForward<T>), ctx->stream(), elem_cnt, elem_cnt, total_dims, sizes_v,\n                    vis, strides_v, x_tensor->dptr<T>(), y_tensor->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FLIP_CUDA_KERNEL(dtype)                                            \\\n  REGISTER_USER_KERNEL(\"flip\").SetCreateFn<FlipGpuKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCUDA)      
                         \\\n      && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FLIP_CUDA_KERNEL(bool)\nREGISTER_FLIP_CUDA_KERNEL(float)\nREGISTER_FLIP_CUDA_KERNEL(half)\nREGISTER_FLIP_CUDA_KERNEL(double)\nREGISTER_FLIP_CUDA_KERNEL(uint8_t)\nREGISTER_FLIP_CUDA_KERNEL(int8_t)\nREGISTER_FLIP_CUDA_KERNEL(int32_t)\nREGISTER_FLIP_CUDA_KERNEL(int64_t)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fold_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/operator/operator_util.h\"\n#include \"oneflow/user/kernels/fold_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\n// NDIM range: (1, 2, 3)\n// SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first\ntemplate<typename INDEX_T, int NDIM, int SDIM>\nclass FoldOpKernelState : public OpKernelState {\n public:\n  using ParamType = FoldParams<INDEX_T, NDIM, SDIM>;\n  FoldOpKernelState(const ShapeView& input_shape, const std::vector<int32_t>& output_size,\n                    const std::vector<int32_t>& kernel_size, const std::vector<int32_t>& padding,\n                    const std::vector<int32_t>& stride, const std::vector<int32_t>& dilation)\n      : params_(input_shape.At(0), input_shape.At(ParamType::kInputChannelDim), output_size.data(),\n                input_shape.ptr() + SDIM, kernel_size.data(), padding.data(), stride.data(),\n                dilation.data()) {}\n  const ParamType& params() const { return params_; }\n\n private:\n  ParamType params_;\n};\n\ntemplate<typename INDEX_T, int NDIM, int SDIM>\nstd::shared_ptr<FoldOpKernelState<INDEX_T, NDIM, SDIM>> CreateFoldOpKernelState(\n    const ShapeView& input_shape, const std::vector<int32_t>& output_size,\n    const std::vector<int32_t>& kernel_size, const std::vector<int32_t>& padding,\n    const std::vector<int32_t>& stride, const std::vector<int32_t>& dilation) {\n  std::shared_ptr<FoldOpKernelState<INDEX_T, NDIM, SDIM>> state(\n      new FoldOpKernelState<INDEX_T, NDIM, SDIM>(input_shape, output_size, kernel_size, padding,\n                                                 stride, dilation));\n  return state;\n}\n\ntemplate<DeviceType device_type, typename T, typename INDEX_T, int NDIM, int SDIM>\nclass FoldKernel final : public OpKernel {\n public:\n  FoldKernel() = default;\n  ~FoldKernel() = default;\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    const Tensor* input = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    Tensor* output = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n\n    const std::vector<int32_t> output_size = ctx->Attr<std::vector<int32_t>>(\"output_size\");\n    const std::vector<int32_t> kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n    const std::vector<int32_t> dilation = ctx->Attr<std::vector<int32_t>>(\"dilation_rate\");\n    const std::vector<int32_t> padding = ctx->Attr<std::vector<int32_t>>(\"padding\");\n    const std::vector<int32_t> stride = ctx->Attr<std::vector<int32_t>>(\"strides\");\n\n    const auto& state_ptr = CreateFoldOpKernelState<INDEX_T, NDIM, SDIM>(\n        input->shape_view(), output_size, kernel_size, padding, stride, dilation);\n    const FoldParams<INDEX_T, NDIM, SDIM> params = state_ptr->params();\n    size_t out_bytes_size =\n        output->shape_view().elem_cnt() * 
GetSizeOfDataType(output->data_type());\n    Memset<device_type>(ctx->stream(), output->mut_dptr<T>(), 0, out_bytes_size);\n    FoldKernelUtil<device_type, T, INDEX_T, NDIM, SDIM>::Forward(\n        ctx->stream(), &params, input->dptr<T>(), output->mut_dptr<T>());\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n// Currently supports only 4-D tensors in NCHW format\n#define REGISTER_FOLD_KERNEL(device, dtype)                    \\\n  REGISTER_USER_KERNEL(\"fold\")                                 \\\n      .SetCreateFn<FoldKernel<device, dtype, int32_t, 2, 2>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)    \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FOLD_KERNEL(DeviceType::kCPU, float)\nREGISTER_FOLD_KERNEL(DeviceType::kCPU, double)\n\n#ifdef WITH_CUDA\nREGISTER_FOLD_KERNEL(DeviceType::kCUDA, float)\nREGISTER_FOLD_KERNEL(DeviceType::kCUDA, double)\n#endif  // WITH_CUDA\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
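The fold kernel above consumes an unfolded tensor of shape `(N, C*kH*kW, L)` and scatters it back onto `(N, C, H, W)`, where `L` is the number of sliding-window positions. That count follows the standard sliding-window formula (the same geometry as convolution output size); a quick shape check for the 4-D NCHW case the registration supports, with illustrative values:

```cpp
#include <cstdio>

// Number of window positions along one axis (standard fold/unfold geometry).
int NumBlocks1D(int out_size, int kernel, int pad, int stride, int dilation) {
  return (out_size + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
}

int main() {
  // output_size = (4, 4), kernel = (2, 2), no padding, stride 1, dilation 1.
  const int n = 1, c = 3, kh = 2, kw = 2;
  const int l = NumBlocks1D(4, kh, 0, 1, 1) * NumBlocks1D(4, kw, 0, 1, 1);  // 3 * 3 = 9
  std::printf("fold input shape: (%d, %d, %d)\n", n, c * kh * kw, l);       // (1, 12, 9)
  return 0;
}
```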
  {
    "path": "oneflow/user/kernels/fold_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/fold_kernel_util.h\"\nnamespace oneflow {\n\nnamespace user_op {\n\n// NDIM range: (1, 2, 3)\n// SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first\ntemplate<typename T, typename INDEX_T, int NDIM, int SDIM>\nstruct FoldKernelUtil<DeviceType::kCPU, T, INDEX_T, NDIM, SDIM> {\n  using ParamType = FoldParams<INDEX_T, NDIM, SDIM>;\n  static void Forward(ep::Stream* stream, const void* raw_params, const T* input_ptr,\n                      T* output_ptr) {\n    const auto* params = static_cast<const ParamType*>(raw_params);\n    for (INDEX_T in_offset = 0; in_offset < params->in_elem_cnt; ++in_offset) {\n      using ParamType = FoldParams<INDEX_T, NDIM, SDIM>;\n      INDEX_T in_index[ParamType::kInputNDim] = {0};\n      INDEX_T out_index[ParamType::kOutputNDim] = {0};\n      params->in_index_helper.OffsetToNdIndex(in_offset, in_index);\n      if (!FoldIndexTransform<INDEX_T, NDIM, SDIM>(*params, in_index, out_index)) {\n        INDEX_T out_offset = params->out_index_helper.NdIndexToOffset(out_index);\n        XPUAdd<T>::Invoke(&input_ptr[in_offset], &output_ptr[out_offset]);\n      } else {\n        continue;\n      }\n    }\n  }\n};\n\nINSTANTIATE_FOLD_KERNEL_UTIL_FOR_DEVICE(DeviceType::kCPU)\n\n}  // namespace user_op\n\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/user/kernels/fold_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/user/kernels/fold_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\nconstexpr int kBlockSize = cuda::elementwise::kBlockSize;\n\nint GetNumBlocks(int64_t elem_cnt) {\n  int num_blocks = 0;\n  OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(elem_cnt, &num_blocks));\n  return num_blocks;\n}\n\n// NDIM range: (1, 2, 3)\n// SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first\ntemplate<typename T, typename INDEX_T, int NDIM, int SDIM>\n__global__ void CudaFoldForward(FoldParams<INDEX_T, NDIM, SDIM> params, const T* input_ptr,\n                                T* output_ptr) {\n  CUDA_1D_KERNEL_LOOP_T(INDEX_T, in_offset, params.in_elem_cnt) {\n    using ParamType = FoldParams<INDEX_T, NDIM, SDIM>;\n    INDEX_T in_index[ParamType::kInputNDim] = {0};\n    INDEX_T out_index[ParamType::kOutputNDim] = {0};\n    params.in_index_helper.OffsetToNdIndex(in_offset, in_index);\n    if (!FoldIndexTransform<INDEX_T, NDIM, SDIM>(params, in_index, out_index)) {\n      INDEX_T out_offset = params.out_index_helper.NdIndexToOffset(out_index);\n      XPUAdd<T>::Invoke(&input_ptr[in_offset], &output_ptr[out_offset]);\n    } else {\n      continue;\n    }\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename INDEX_T, int NDIM, int SDIM>\nstruct FoldKernelUtil<DeviceType::kCUDA, T, INDEX_T, NDIM, SDIM> {\n  using ParamType = FoldParams<INDEX_T, NDIM, SDIM>;\n  static void Forward(ep::Stream* stream, const void* raw_params, const T* input_ptr,\n                      T* output_ptr) {\n    const auto* fold_params = static_cast<const ParamType*>(raw_params);\n    CudaFoldForward<T, INDEX_T, NDIM, SDIM>\n        <<<GetNumBlocks(fold_params->in_elem_cnt), kBlockSize, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(*fold_params, input_ptr, output_ptr);\n  }\n};\n\nINSTANTIATE_FOLD_KERNEL_UTIL_FOR_DEVICE(DeviceType::kCUDA)\n\n}  // namespace user_op\n}  // namespace oneflow\n#endif  // WITH_CUDA"
  },
  {
    "path": "oneflow/user/kernels/fold_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_FOLD_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_FOLD_KERNEL_UTIL_H_\n\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/common/switch_func.h\"\n#include \"oneflow/core/ndarray/xpu_util.h\"\n#ifdef WITH_CUDA\n#include \"oneflow/core/cuda/atomic.cuh\"\n#endif  // WITH_CUDA\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\ntemplate<typename T>\nstruct XPUAdd {\n  OF_DEVICE_FUNC static void Invoke(const T* x, T* y) {\n#if defined(__CUDA_ARCH__)\n    cuda::atomic::Add(y, *x);\n#else\n    *y += *x;\n#endif\n  };\n};\n\n}  // namespace\n\n// NDIM range: (1, 2, 3)\n// SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first\ntemplate<typename INDEX_T, int NDIM, int SDIM>\nstruct FoldParams {\n  static constexpr int kInputNDim = NDIM * 2 + 2;\n  static constexpr int kOutputNDim = NDIM + 2;\n  static constexpr int kInputChannelDim = (2 - SDIM) * NDIM * 2 + 1;\n  static constexpr int kOutputChannelDim = (2 - SDIM) * NDIM + 1;\n  static_assert(kInputChannelDim < kInputNDim, \"\");\n  static_assert(kOutputChannelDim < kOutputNDim, \"\");\n  FoldParams(const int64_t batch_size, const int64_t channels, const int32_t* output_size,\n             const int64_t* spatial_dims, const int32_t* kernel_size, const int32_t* padding,\n             const int32_t* stride, const int32_t* dilation);\n  INDEX_T in_elem_cnt;\n  INDEX_T out_elem_cnt;\n  INDEX_T dims[NDIM];\n  int padding[NDIM];\n  int stride[NDIM];\n  int dilation[NDIM];\n  NdIndexOffsetHelper<INDEX_T, kInputNDim> in_index_helper;\n  NdIndexOffsetHelper<INDEX_T, kOutputNDim> out_index_helper;\n};\n\ntemplate<typename INDEX_T, int NDIM, int SDIM>\nFoldParams<INDEX_T, NDIM, SDIM>::FoldParams(const int64_t batch_size,\n                                            const int64_t channels_columns,\n                                            const int32_t* output_size, const int64_t* spatial_dims,\n                                            const int32_t* kernel_size, const int32_t* padding,\n                                            const int32_t* stride, const int32_t* dilation)\n    : in_elem_cnt(0), out_elem_cnt(0), in_index_helper(0), out_index_helper(0) {\n  INDEX_T input_dims[kInputNDim] = {0};\n  INDEX_T output_dims[kOutputNDim] = {0};\n  const int32_t channels =\n      channels_columns / (kernel_size[0] * kernel_size[1]);  // channels_columns = C*K*K\n  this->in_elem_cnt = batch_size * channels;\n  this->out_elem_cnt = batch_size * channels;\n  input_dims[0] = batch_size;\n  output_dims[0] = batch_size;\n  input_dims[kInputChannelDim] = channels;\n  output_dims[kOutputChannelDim] = channels;\n  for (int d = 0; d < NDIM; ++d) {\n    this->dims[d] = output_size[d];\n    this->padding[d] = padding[d];\n    this->stride[d] = stride[d];\n    this->dilation[d] = 
dilation[d];\n    input_dims[SDIM + NDIM + d] =\n        (output_size[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1) - 1) / stride[d] + 1;\n    input_dims[SDIM + d] = kernel_size[d];\n    this->in_elem_cnt *= input_dims[SDIM + d] * input_dims[SDIM + NDIM + d];  // N,C*Kh*Kw, H*W\n    output_dims[SDIM + d] = output_size[d];\n    this->out_elem_cnt *= output_dims[SDIM + d];\n  }\n\n  in_index_helper = NdIndexOffsetHelper<INDEX_T, kInputNDim>(input_dims);\n  out_index_helper = NdIndexOffsetHelper<INDEX_T, kOutputNDim>(output_dims);\n}\n\n// index_a format: (N, C, D, H, W) or (N, D, H, W, C)\n// index_b format: (N, C, di, hi, wi, db, hb, wb) or (N, di, hi, wi, db, hb, wb, C)\n// return: true indicates out-of-bound, otherwise in-bound\ntemplate<typename INDEX_T, int NDIM, int SDIM>\nOF_DEVICE_FUNC bool FoldIndexTransform(const FoldParams<INDEX_T, NDIM, SDIM>& params,\n                                       const INDEX_T* index_a, INDEX_T* index_b) {\n  // batch dim index transform\n  index_b[0] = index_a[0];\n  // channel dim index transform\n  using ParamType = FoldParams<INDEX_T, NDIM, SDIM>;\n  index_b[ParamType::kOutputChannelDim] = index_a[ParamType::kInputChannelDim];\n// spatial dim index transform\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n  // D,H,W spatial dim index transform\n  for (int64_t d = 0; d < NDIM; ++d) {\n    INDEX_T idx = index_a[SDIM + NDIM + d] * params.stride[d]\n                  + index_a[SDIM + d] * params.dilation[d] - params.padding[d];\n    if (idx < 0 || idx >= params.dims[d]) return true;\n    index_b[SDIM + d] = idx;\n  }\n  return false;\n}\n\ntemplate<DeviceType device_type, typename T, typename INDEX_T, int NDIM, int SDIM>\nstruct FoldKernelUtil {\n  static void Forward(ep::Stream* stream, const void* params, const T* input_ptr, T* output_ptr);\n};\n\n#define SPATIAL_NDIM_SEQ OF_PP_MAKE_TUPLE_SEQ(1) OF_PP_MAKE_TUPLE_SEQ(2) OF_PP_MAKE_TUPLE_SEQ(3)\n#define SPATIAL_DIM_SEQ OF_PP_MAKE_TUPLE_SEQ(1) OF_PP_MAKE_TUPLE_SEQ(2)\n\n#define INSTANTIATE_FOLD_KERNEL_UTIL(device, dtype, itype, ndim, sdim) \\\n  template struct FoldKernelUtil<device, dtype, itype, ndim, sdim>;\n\n#define INSTANTIATE_FOLD_KERNEL_UTIL_WITH_TYPE_PAIR(device, dtype_pair, itype_pair, ndim, sdim)    \\\n  INSTANTIATE_FOLD_KERNEL_UTIL(device, OF_PP_PAIR_FIRST(dtype_pair), OF_PP_PAIR_FIRST(itype_pair), \\\n                               ndim, sdim)\n\n#define INSTANTIATE_FOLD_KERNEL_UTIL_FOR_DEVICE(device)                                           \\\n  OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_FOLD_KERNEL_UTIL_WITH_TYPE_PAIR, (device),         \\\n                                   FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, SPATIAL_NDIM_SEQ, \\\n                                   SPATIAL_DIM_SEQ)\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_FOLD_KERNEL_UTIL_H_"
  },
  {
    "path": "oneflow/user/kernels/frac_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass CpuFracKernel final : public user_op::OpKernel {\n public:\n  CpuFracKernel() = default;\n  ~CpuFracKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const int32_t elem_cnt = x->shape_view().elem_cnt();\n    const T* x_ptr = x->dptr<T>();\n    T* y_ptr = y->mut_dptr<T>();\n    FOR_RANGE(int32_t, i, 0, elem_cnt) { y_ptr[i] = x_ptr[i] - std::trunc(x_ptr[i]); }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_FRAC_KERNEL(dtype)                                             \\\n  REGISTER_USER_KERNEL(\"frac\").SetCreateFn<CpuFracKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCPU)                                \\\n      && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CPU_FRAC_KERNEL(float)\nREGISTER_CPU_FRAC_KERNEL(double)\n\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/user/kernels/frac_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/util/cuda_half_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\nnamespace oneflow {\n\nnamespace {\n\n// Write ReLU Functor.\ntemplate<typename T>\nstruct FracForwardGpu {\n  OF_DEVICE_FUNC T operator()(T x) const { return x - std::trunc(x); }\n};\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuFracKernel final : public user_op::OpKernel {\n public:\n  GpuFracKernel() = default;\n  ~GpuFracKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const int32_t elem_cnt = x->shape_view().elem_cnt();\n    // Use CUDA Elementwise Template.\n    OF_CUDA_CHECK(\n        (cuda::elementwise::Unary(FracForwardGpu<T>(), elem_cnt, y->mut_dptr<T>(), x->dptr<T>(),\n                                  ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_GPU_FRAC_KERNEL(dtype)                                             \\\n  REGISTER_USER_KERNEL(\"frac\").SetCreateFn<GpuFracKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCUDA)                               \\\n      && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\nREGISTER_GPU_FRAC_KERNEL(float)\nREGISTER_GPU_FRAC_KERNEL(double)\n\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/user/kernels/fused_attention_kernels.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifdef WITH_CUTLASS\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"cutlass/arch/mma.h\"\n#include \"cutlass/gemm/warp/mma.h\"\n#include \"kernel_forward.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"trt_flash_attention/fmha.h\"\n#include \"trt_flash_attention/fmha_flash_attention.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\nvoid ParseDims(const ShapeView& shape, const std::string& layout,\n               const Optional<int64_t>& batch_size, const Optional<int64_t>& seq_len,\n               const Optional<int64_t>& num_heads, const Optional<int64_t>& head_size,\n               int64_t tensor_index, int64_t* b, int64_t* m, int64_t* h, int64_t* k,\n               int64_t* b_stride, int64_t* m_stride, int64_t* h_stride, int64_t* offset,\n               bool* bm_packed) {\n  if (shape.NumAxes() == 2) {\n    if (layout == \"(BM)(HK)\" || layout == \"(BM)(H2K)\" || layout == \"(BM)(H3K)\") {\n      *bm_packed = true;\n      CHECK(batch_size);\n      CHECK(seq_len);\n      *b = CHECK_JUST(batch_size);\n      *m = CHECK_JUST(seq_len);\n      int64_t packed_n = 0;\n      if (layout == \"(BM)(HK)\") {\n        packed_n = 1;\n      } else if (layout == \"(BM)(H2K)\") {\n        packed_n = 2;\n      } else if (layout == \"(BM)(H3K)\") {\n        packed_n = 3;\n      } else {\n        UNIMPLEMENTED();\n      }\n      const int64_t hidden_size = shape.At(1);\n      if (num_heads) {\n        const int64_t expected_h = CHECK_JUST(num_heads);\n        const int64_t packed_h = packed_n * expected_h;\n        CHECK_EQ(hidden_size % packed_h, 0);\n        *h = expected_h;\n        *k = hidden_size / packed_h;\n      } else if (head_size) {\n        const int64_t expected_k = CHECK_JUST(head_size);\n        const int64_t packed_k = packed_n * expected_k;\n        CHECK_EQ(hidden_size % packed_k, 0);\n        *h = hidden_size / packed_k;\n        *k = expected_k;\n      } else {\n        UNIMPLEMENTED();\n      }\n      *h_stride = *k * packed_n;\n      *m_stride = *h_stride * *h;\n      *b_stride = 0;\n      if (packed_n == 1) {\n        *offset = 0;\n      } else if (packed_n == 2) {\n        CHECK_GE(tensor_index, 1);\n        *offset = (tensor_index - 1) * *k;\n      } else if (packed_n == 3) {\n        *offset = tensor_index * *k;\n      } else {\n        UNIMPLEMENTED();\n      }\n    } else {\n      UNIMPLEMENTED();\n    }\n  } else if (shape.NumAxes() == 3) {\n    if (layout == \"BM(HK)\" || layout == \"BM(H2K)\" || layout == \"BM(H3K)\" || layout == \"MB(HK)\"\n        || layout == \"MB(H2K)\" || layout == \"MB(H3K)\") {\n      *bm_packed = false;\n      bool batch_first = false;\n      int64_t packed_n = 0;\n      const std::string layout_bm = layout.substr(0, 2);\n      
const std::string layout_hk = layout.substr(2);\n      if (layout_bm == \"BM\") {\n        *b = shape.At(0);\n        *m = shape.At(1);\n        batch_first = true;\n      } else if (layout_bm == \"MB\") {\n        *b = shape.At(1);\n        *m = shape.At(0);\n        batch_first = false;\n      } else {\n        UNIMPLEMENTED();\n      }\n      if (layout_hk == \"(HK)\") {\n        packed_n = 1;\n      } else if (layout_hk == \"(H2K)\") {\n        packed_n = 2;\n      } else if (layout_hk == \"(H3K)\") {\n        packed_n = 3;\n      } else {\n        UNIMPLEMENTED();\n      }\n      const int64_t hidden_size = shape.At(2);\n      if (num_heads) {\n        const int64_t expected_h = CHECK_JUST(num_heads);\n        const int64_t packed_h = packed_n * expected_h;\n        CHECK_EQ(hidden_size % packed_h, 0);\n        *h = expected_h;\n        *k = hidden_size / packed_h;\n      } else if (head_size) {\n        const int64_t expected_k = CHECK_JUST(head_size);\n        const int64_t packed_k = packed_n * expected_k;\n        CHECK_EQ(hidden_size % packed_k, 0);\n        *h = hidden_size / packed_k;\n        *k = expected_k;\n      } else {\n        UNIMPLEMENTED();\n      }\n      *h_stride = *k * packed_n;\n      if (batch_first) {\n        *m_stride = *h_stride * *h;\n        *b_stride = *m_stride * *m;\n      } else {\n        *b_stride = *h_stride * *h;\n        *m_stride = *b_stride * *b;\n      }\n      if (packed_n == 1) {\n        *offset = 0;\n      } else if (packed_n == 2) {\n        CHECK_GE(tensor_index, 1);\n        *offset = (tensor_index - 1) * *k;\n      } else if (packed_n == 3) {\n        *offset = tensor_index * *k;\n      } else {\n        UNIMPLEMENTED();\n      }\n    } else if (layout == \"(BM)HK\") {\n      *bm_packed = true;\n      CHECK(batch_size);\n      CHECK(seq_len);\n      *b = CHECK_JUST(batch_size);\n      *m = CHECK_JUST(seq_len);\n      *h = shape.At(1);\n      *k = shape.At(2);\n      *h_stride = *k;\n      *m_stride = *h_stride * *h;\n      *b_stride = 0;\n    } else {\n      UNIMPLEMENTED();\n    }\n  } else if (shape.NumAxes() == 4) {\n    *bm_packed = false;\n    if (layout == \"BMHK\") {\n      *b = shape.At(0);\n      *m = shape.At(1);\n      *h = shape.At(2);\n      *k = shape.At(3);\n      *h_stride = *k;\n      *m_stride = *h_stride * *h;\n      *b_stride = *m_stride * *m;\n    } else if (layout == \"BHMK\") {\n      *b = shape.At(0);\n      *m = shape.At(2);\n      *h = shape.At(1);\n      *k = shape.At(3);\n      *m_stride = *k;\n      *h_stride = *m_stride * *m;\n      *b_stride = *h_stride * *h;\n    } else if (layout == \"MBHK\") {\n      *b = shape.At(1);\n      *m = shape.At(0);\n      *h = shape.At(2);\n      *k = shape.At(3);\n      *h_stride = *k;\n      *b_stride = *h_stride * *h;\n      *m_stride = *b_stride * *b;\n    } else {\n      UNIMPLEMENTED();\n    }\n    *offset = 0;\n  } else {\n    UNIMPLEMENTED();\n  };\n  if (batch_size) {\n    const int64_t expected_b = CHECK_JUST(batch_size);\n    CHECK_EQ(*b, expected_b);\n  }\n  if (seq_len) {\n    const int64_t expected_m = CHECK_JUST(seq_len);\n    CHECK_EQ(*m, expected_m);\n  }\n  if (num_heads) {\n    const int64_t expected_h = CHECK_JUST(num_heads);\n    CHECK_EQ(*h, expected_h);\n  }\n  if (head_size) {\n    const int64_t expected_k = CHECK_JUST(head_size);\n    CHECK_EQ(*k, expected_k);\n  }\n}\n\nvoid ParseDims(const ShapeView& shape, const std::string& layout,\n               const Optional<int64_t>& num_heads, const Optional<int64_t>& head_size,\n               int64_t 
tensor_index, int64_t* b, int64_t* m, int64_t* h, int64_t* k,\n               int64_t* b_stride, int64_t* m_stride, int64_t* h_stride, int64_t* offset) {\n  bool bm_packed{};\n  ParseDims(shape, layout, Optional<int64_t>(), Optional<int64_t>(), num_heads, head_size,\n            tensor_index, b, m, h, k, b_stride, m_stride, h_stride, offset, &bm_packed);\n}\n\ntemplate<typename T, int pack_size>\nstruct alignas(pack_size * sizeof(T)) Pack {\n  T elem[pack_size];\n};\n\ntemplate<typename T>\n__global__ void PackQkv(int b, int s, int nh, int d, const T* q, const T* k, const T* v, T* o,\n                        int32_t* seq_len) {\n  int count = b * s * nh * d * 3;\n  for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {\n    int row = i / (d * 3);\n    int out_col = i - row * (d * 3);\n    T out;\n    if (out_col < d) {\n      out = q[row * d + out_col];\n    } else if (out_col < 2 * d) {\n      out = k[row * d + out_col - d];\n    } else {\n      out = v[row * d + out_col - d * 2];\n    }\n    o[i] = out;\n  }\n  for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < b + 1; i += blockDim.x * gridDim.x) {\n    seq_len[i] = i * s;\n  }\n}\n\nstruct Params {\n  DataType data_type;\n  int64_t num_batches;\n  int64_t num_heads;\n  int64_t query_seq_len;\n  int64_t kv_seq_len;\n  int64_t head_size;\n  int64_t value_head_size;\n  int64_t q_stride_b;\n  int64_t q_stride_m;\n  int64_t q_stride_h;\n  int64_t k_stride_b;\n  int64_t k_stride_m;\n  int64_t k_stride_h;\n  int64_t v_stride_b;\n  int64_t v_stride_m;\n  int64_t v_stride_h;\n  std::string attn_mask_type;\n  int64_t causal_diagonal_offset;\n  const void* query_ptr;\n  const void* key_ptr;\n  const void* value_ptr;\n  const void* attn_bias_ptr;\n  const void* query_seq_start_ptr;\n  const void* key_seq_start_ptr;\n  const void* key_seq_len_ptr;\n  int64_t attn_bias_stride_b;\n  int64_t attn_bias_stride_h;\n  int64_t attn_bias_stride_m;\n  void* out_ptr;\n  void* workspace;\n  int64_t workspace_size;\n  float scale;\n};\n\ntemplate<typename T, typename ArchTag, bool is_aligned, int queries_per_block, int keys_per_block,\n         bool single_value_iteration, bool with_attn_bias>\nvoid LaunchCutlassFmha(const Params& params, ep::CudaStream* stream) {\n  // The fmha implementation below is based on xformers's fmha\n  // implementation at:\n  // https://github.com/facebookresearch/xformers/tree/main/xformers/csrc/attention/cuda/fmha\n  using Attention = AttentionKernel<T, ArchTag, is_aligned, queries_per_block, keys_per_block,\n                                    single_value_iteration, false, with_attn_bias>;\n  typename Attention::Params p{};\n  p.query_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.query_ptr));\n  p.key_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.key_ptr));\n  p.value_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.value_ptr));\n  p.attn_bias_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.attn_bias_ptr));\n  p.seqstart_q_ptr =\n      const_cast<int32_t*>(reinterpret_cast<const int32_t*>(params.query_seq_start_ptr));\n  p.seqstart_k_ptr =\n      const_cast<int32_t*>(reinterpret_cast<const int32_t*>(params.key_seq_start_ptr));\n  p.seqlen_k_ptr = const_cast<int32_t*>(reinterpret_cast<const int32_t*>(params.key_seq_len_ptr));\n  p.logsumexp_ptr = nullptr;\n  p.output_ptr = reinterpret_cast<T*>(params.out_ptr);\n  if (Attention::kNeedsOutputAccumulatorBuffer) {\n    using Acc = typename Attention::accum_t;\n    CHECK_GE(params.workspace_size, 
params.num_batches * params.query_seq_len * params.num_heads\n                                        * params.value_head_size * sizeof(Acc));\n    p.output_accum_ptr = reinterpret_cast<Acc*>(params.workspace);\n  } else {\n    p.output_accum_ptr = nullptr;\n  }\n  p.num_heads = params.num_heads;\n  p.num_batches = params.num_batches;\n  p.head_dim = params.head_size;\n  p.head_dim_value = params.value_head_size;\n  p.num_queries = params.query_seq_len;\n  p.num_keys = params.kv_seq_len;\n  p.q_strideM = params.q_stride_m;\n  p.k_strideM = params.k_stride_m;\n  p.v_strideM = params.v_stride_m;\n  p.o_strideM = p.head_dim_value * p.num_heads;\n  p.bias_strideM = params.attn_bias_stride_m;\n\n  p.q_strideH = params.q_stride_h;\n  p.k_strideH = params.k_stride_h;\n  p.v_strideH = params.v_stride_h;\n  p.bias_strideH = params.attn_bias_stride_h;\n\n  p.q_strideB = params.q_stride_b;\n  p.k_strideB = params.k_stride_b;\n  p.v_strideB = params.v_stride_b;\n  p.bias_strideB = params.attn_bias_stride_b;\n\n  p.scale = params.scale;\n\n  if (params.attn_mask_type == \"none\") {\n    p.custom_mask_type = Attention::NoCustomMask;\n  } else if (params.attn_mask_type == \"causal_from_top_left\") {\n    p.custom_mask_type = Attention::CausalFromTopLeft;\n  } else if (params.attn_mask_type == \"causal_from_bottom_right\") {\n    p.custom_mask_type = Attention::CausalFromBottomRight;\n  } else {\n    UNIMPLEMENTED();\n  }\n  p.causal_diagonal_offset = params.causal_diagonal_offset;\n  p.use_dropout = false;\n\n  constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;\n  int smem_bytes = sizeof(typename Attention::SharedStorage);\n  if (smem_bytes > 0xc000) {\n    static bool once = [&]() {\n      cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);\n      return true;\n    }();\n  }\n  CHECK(Attention::check_supported(p));\n  kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream->cuda_stream()>>>(p);\n}\n\ntemplate<typename T, typename ArchTag, bool is_aligned, int queries_per_block, int keys_per_block,\n         bool single_value_iteration>\nvoid DispatchWithAttnBias(const Params& params, ep::CudaStream* stream) {\n  if (params.attn_bias_ptr != nullptr) {\n    LaunchCutlassFmha<T, ArchTag, is_aligned, queries_per_block, keys_per_block,\n                      single_value_iteration, true>(params, stream);\n  } else {\n    LaunchCutlassFmha<T, ArchTag, is_aligned, queries_per_block, keys_per_block,\n                      single_value_iteration, false>(params, stream);\n  }\n}\n\ntemplate<typename T, typename ArchTag, bool is_aligned, int queries_per_block, int keys_per_block>\nvoid DispatchSingleValueIteration(const Params& params, ep::CudaStream* stream) {\n  if (params.value_head_size <= keys_per_block) {\n    DispatchWithAttnBias<T, ArchTag, is_aligned, queries_per_block, keys_per_block, true>(params,\n                                                                                          stream);\n  } else {\n    DispatchWithAttnBias<T, ArchTag, is_aligned, queries_per_block, keys_per_block, false>(params,\n                                                                                           stream);\n  }\n}\n\ntemplate<typename T, typename ArchTag, bool is_aligned>\nvoid DispatchKeysPerBlock(const Params& params, ep::CudaStream* stream) {\n  if (params.value_head_size <= 64) {\n    DispatchSingleValueIteration<T, ArchTag, is_aligned, 64, 64>(params, stream);\n  } else {\n    DispatchSingleValueIteration<T, ArchTag, is_aligned, 
32, 128>(params, stream);\n  }\n}\n\ntemplate<typename T, typename ArchTag>\nvoid DispatchIsAligned(const Params& params, ep::CudaStream* stream) {\n  if (reinterpret_cast<uintptr_t>(params.query_ptr) % 16 == 0\n      && reinterpret_cast<uintptr_t>(params.key_ptr) % 16 == 0\n      && reinterpret_cast<uintptr_t>(params.value_ptr) % 16 == 0\n      && params.attn_bias_stride_m % (16 / sizeof(T)) == 0\n      && params.head_size % (16 / sizeof(T)) == 0\n      && params.value_head_size % (16 / sizeof(T)) == 0) {\n    DispatchKeysPerBlock<T, ArchTag, true>(params, stream);\n  } else {\n    DispatchKeysPerBlock<T, ArchTag, false>(params, stream);\n  }\n}\n\ntemplate<typename T>\nvoid DispatchArchTag(const Params& params, ep::CudaStream* stream) {\n  const int major = stream->device_properties().major;\n  const int minor = stream->device_properties().minor;\n\n  if (major == 8) {\n    DispatchIsAligned<T, cutlass::arch::Sm80>(params, stream);\n  } else if (major == 7) {\n    if (minor == 5) {\n      DispatchIsAligned<T, cutlass::arch::Sm75>(params, stream);\n    } else {\n      DispatchIsAligned<T, cutlass::arch::Sm70>(params, stream);\n    }\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nvoid DispatchCutlassFmha(const Params& params, ep::CudaStream* stream) {\n  if (params.data_type == DataType::kFloat16) {\n    DispatchArchTag<cutlass::half_t>(params, stream);\n  } else if (params.data_type == DataType::kFloat) {\n    DispatchArchTag<float>(params, stream);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nclass FusedMultiHeadAttentionInferenceKernel final : public user_op::OpKernel,\n                                                     public user_op::CudaGraphSupport {\n public:\n  FusedMultiHeadAttentionInferenceKernel() = default;\n  ~FusedMultiHeadAttentionInferenceKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const Tensor* query = ctx->Tensor4ArgNameAndIndex(\"query\", 0);\n    const Tensor* key = ctx->Tensor4ArgNameAndIndex(\"key\", 0);\n    const Tensor* value = ctx->Tensor4ArgNameAndIndex(\"value\", 0);\n    const Tensor* attn_bias = nullptr;\n    if (ctx->has_input(\"attn_bias\", 0)) { attn_bias = ctx->Tensor4ArgNameAndIndex(\"attn_bias\", 0); }\n    const Tensor* query_seq_start = nullptr;\n    const Tensor* key_seq_start = nullptr;\n    const Tensor* key_seq_len = nullptr;\n    const float scale = ctx->Attr<double>(\"scale\");\n    if (ctx->has_input(\"query_seq_start\", 0)) {\n      CHECK(ctx->has_input(\"key_seq_start\", 0));\n      query_seq_start = ctx->Tensor4ArgNameAndIndex(\"query_seq_start\", 0);\n      key_seq_start = ctx->Tensor4ArgNameAndIndex(\"key_seq_start\", 0);\n      CHECK(query_seq_start->data_type() == DataType::kInt32);\n      CHECK(key_seq_start->data_type() == DataType::kInt32);\n      CHECK_EQ(query_seq_start->shape_view().NumAxes(), 1);\n      CHECK_GT(query_seq_start->shape_view().At(0), 1);\n      CHECK(query_seq_start->shape_view() == key_seq_start->shape_view());\n      if (ctx->has_input(\"key_seq_len\", 0)) {\n        key_seq_len = ctx->Tensor4ArgNameAndIndex(\"key_seq_len\", 0);\n        CHECK(key_seq_len->data_type() == DataType::kInt32);\n        CHECK_EQ(key_seq_len->shape_view().NumAxes(), 1);\n        CHECK_EQ(key_seq_len->shape_view().At(0), query_seq_start->shape_view().At(0) - 1);\n      }\n    } else {\n      CHECK(!ctx->has_input(\"key_seq_start\", 0));\n      CHECK(!ctx->has_input(\"key_seq_len\", 0));\n    }\n    Tensor* out = 
ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    Tensor* tmp = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const DataType data_type = query->data_type();\n    CHECK_EQ(key->data_type(), data_type);\n    CHECK_EQ(value->data_type(), data_type);\n    CHECK_EQ(out->data_type(), data_type);\n    const int64_t query_head_size = ctx->Attr<int64_t>(\"query_head_size\");\n    const std::string& attn_mask_type = ctx->Attr<std::string>(\"attn_mask_type\");\n    const int64_t causal_diagonal_offset = ctx->Attr<int64_t>(\"causal_diagonal_offset\");\n    CHECK_GE(causal_diagonal_offset, 0);\n    const std::string& query_layout = ctx->Attr<std::string>(\"query_layout\");\n    const std::string& key_layout = ctx->Attr<std::string>(\"key_layout\");\n    const std::string& value_layout = ctx->Attr<std::string>(\"value_layout\");\n    const std::string& output_layout = ctx->Attr<std::string>(\"output_layout\");\n\n    Optional<int64_t> batch_size;\n    if (query_seq_start != nullptr) { batch_size = query_seq_start->shape_view().At(0) - 1; }\n    Optional<int64_t> query_max_seq_len;\n    const int64_t attr_query_max_seq_len = ctx->Attr<int64_t>(\"query_max_seq_len\");\n    if (attr_query_max_seq_len != 0) { query_max_seq_len = attr_query_max_seq_len; }\n    Optional<int64_t> key_max_seq_len;\n    const int64_t attr_key_max_seq_len = ctx->Attr<int64_t>(\"key_max_seq_len\");\n    if (attr_key_max_seq_len != 0) { key_max_seq_len = attr_key_max_seq_len; }\n\n    int64_t q_b = 0;\n    int64_t q_m = 0;\n    int64_t q_h = 0;\n    int64_t q_k = 0;\n    int64_t q_b_stride = 0;\n    int64_t q_m_stride = 0;\n    int64_t q_h_stride = 0;\n    int64_t q_offset = 0;\n    bool q_bm_packed = false;\n    ParseDims(query->shape_view(), query_layout, batch_size, query_max_seq_len, Optional<int64_t>(),\n              query_head_size, 0, &q_b, &q_m, &q_h, &q_k, &q_b_stride, &q_m_stride, &q_h_stride,\n              &q_offset, &q_bm_packed);\n    if (q_bm_packed) { CHECK(query_seq_start != nullptr); }\n\n    int64_t k_b = 0;\n    int64_t k_m = 0;\n    int64_t k_h = 0;\n    int64_t k_k = 0;\n    int64_t k_b_stride = 0;\n    int64_t k_m_stride = 0;\n    int64_t k_h_stride = 0;\n    int64_t k_offset = 0;\n    bool k_bm_packed = false;\n    ParseDims(key->shape_view(), key_layout, q_b, key_max_seq_len, Optional<int64_t>(),\n              query_head_size, 1, &k_b, &k_m, &k_h, &k_k, &k_b_stride, &k_m_stride, &k_h_stride,\n              &k_offset, &k_bm_packed);\n    CHECK_EQ(k_b, q_b);\n    CHECK_EQ(k_h, q_h);\n    CHECK_EQ(k_bm_packed, q_bm_packed);\n\n    int64_t v_b = 0;\n    int64_t v_m = 0;\n    int64_t v_h = 0;\n    int64_t v_k = 0;\n    int64_t v_b_stride = 0;\n    int64_t v_m_stride = 0;\n    int64_t v_h_stride = 0;\n    int64_t v_offset = 0;\n    bool v_bm_packed = false;\n    ParseDims(value->shape_view(), value_layout, q_b, k_m, q_h, Optional<int64_t>(), 2, &v_b, &v_m,\n              &v_h, &v_k, &v_b_stride, &v_m_stride, &v_h_stride, &v_offset, &v_bm_packed);\n    CHECK_EQ(v_b, q_b);\n    CHECK_EQ(v_m, k_m);\n    CHECK_EQ(v_bm_packed, k_bm_packed);\n    if (output_layout == \"BM(HK)\") {\n      CHECK(!q_bm_packed);\n      CHECK_EQ(out->shape_view().NumAxes(), 3);\n      CHECK_EQ(out->shape_view().At(0), q_b);\n      CHECK_EQ(out->shape_view().At(1), q_m);\n      CHECK_EQ(out->shape_view().At(2), q_h * v_k);\n    } else if (output_layout == \"MB(HK)\") {\n      CHECK(!q_bm_packed);\n      CHECK_EQ(out->shape_view().NumAxes(), 3);\n      CHECK_EQ(q_b, 1);\n      CHECK_EQ(out->shape_view().At(0), q_m);\n      
CHECK_EQ(out->shape_view().At(1), q_b);\n      CHECK_EQ(out->shape_view().At(2), q_h * v_k);\n    } else if (output_layout == \"(BM)(HK)\") {\n      CHECK(q_bm_packed);\n      CHECK_EQ(out->shape_view().NumAxes(), 2);\n      CHECK_EQ(out->shape_view().At(0), query->shape_view().At(0));\n      CHECK_EQ(out->shape_view().At(1), q_h * v_k);\n    } else {\n      UNIMPLEMENTED();\n    }\n\n    auto* cuda_stream = ctx->stream()->As<ep::CudaStream>();\n\n    // Compatible with typo `KERENL`\n    const bool enable_trt_flash_attn =\n        ParseBooleanFromEnv(\n            \"ONEFLOW_KERNEL_FMHA_ENABLE_TRT_FLASH_ATTN_IMPL\",\n            ParseBooleanFromEnv(\"ONEFLOW_KERENL_FMHA_ENABLE_TRT_FLASH_ATTN_IMPL\", true))\n        && ParseBooleanFromEnv(\"ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION\", false);\n    const bool is_default_scale =\n        std::abs(scale - 1.0 / std::sqrt(static_cast<float>(q_k))) <= 1e-5;\n    const int arch = cuda_stream->cuda_arch() / 10;\n    const bool is_trt_supported_arch = (arch == 75 || arch == 80 || arch == 86 || arch == 89);\n    const bool is_trt_supported_head_size = ((q_k == 40) || (q_k == 64));\n    // Avoid PackQKV overhead when seq_len is small.\n    const bool is_long_seq_len = q_m >= 512;\n    const bool is_trt_supported_layout = (query_layout == \"BMHK\" || query_layout == \"BM(HK)\")\n                                         && (key_layout == \"BMHK\" || key_layout == \"BM(HK)\")\n                                         && (value_layout == \"BMHK\" || value_layout == \"BM(HK)\")\n                                         && (output_layout == \"BMHK\" || output_layout == \"BM(HK)\");\n    if (is_default_scale && query_seq_start == nullptr && enable_trt_flash_attn\n        && data_type == DataType::kFloat16 && q_m == k_m && q_k == v_k && is_trt_supported_head_size\n        && is_long_seq_len && is_trt_supported_arch && attn_mask_type == \"none\"\n        && attn_bias == nullptr && is_trt_supported_layout) {\n      // The fmha implementation below is based on TensorRT's multiHeadFlashAttentionPlugin\n      // implementation at:\n      // https://github.com/NVIDIA/TensorRT/tree/main/plugin/multiHeadFlashAttentionPlugin\n      int32_t cu_seqlens_d_size = (q_b + 1) * sizeof(int32_t);\n      int32_t* cu_seqlens_d = reinterpret_cast<int32_t*>(tmp->mut_dptr());\n      half* packed_qkv =\n          reinterpret_cast<half*>(tmp->mut_dptr<char>() + GetCudaAlignedSize(cu_seqlens_d_size));\n      constexpr int pack_size = 4;\n      using PackType = Pack<half, pack_size>;\n      const int64_t count = q_b * q_m * q_h * q_k * 3 / pack_size;\n      PackQkv<PackType><<<(count - 1 + 256) / 256, 256, 0, cuda_stream->cuda_stream()>>>(\n          q_b, q_m, q_h, q_k / pack_size, reinterpret_cast<const PackType*>(query->dptr()),\n          reinterpret_cast<const PackType*>(key->dptr()),\n          reinterpret_cast<const PackType*>(value->dptr()), reinterpret_cast<PackType*>(packed_qkv),\n          cu_seqlens_d);\n\n#ifdef WITH_CUDA_GRAPHS\n      cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;\n      if (cuda_stream->IsGraphCapturing()) {\n        OF_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));\n      }\n#endif  // WITH_CUDA_GRAPHS\n      nvinfer1::plugin::FusedMultiHeadFlashAttentionKernel const* kernels =\n          nvinfer1::plugin::getFMHAFlashCubinKernels(nvinfer1::plugin::DATA_TYPE_FP16, arch);\n#ifdef WITH_CUDA_GRAPHS\n      if (cuda_stream->IsGraphCapturing()) {\n        OF_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));\n      }\n#endif  
// WITH_CUDA_GRAPHS\n      nvinfer1::plugin::runFMHFAKernel(packed_qkv, cu_seqlens_d, out->mut_dptr(), q_b * q_m, arch,\n                                       kernels, q_b, q_h, q_k, q_m, cuda_stream->cuda_stream());\n      return;\n    }\n\n    Params params{};\n    params.data_type = data_type;\n    params.num_batches = q_b;\n    params.num_heads = q_h;\n    params.query_seq_len = q_m;\n    params.kv_seq_len = k_m;\n    params.head_size = q_k;\n    params.value_head_size = v_k;\n    params.scale = scale;\n    params.q_stride_b = q_b_stride;\n    params.q_stride_m = q_m_stride;\n    params.q_stride_h = q_h_stride;\n    params.k_stride_b = k_b_stride;\n    params.k_stride_m = k_m_stride;\n    params.k_stride_h = k_h_stride;\n    params.v_stride_b = v_b_stride;\n    params.v_stride_m = v_m_stride;\n    params.v_stride_h = v_h_stride;\n    params.query_ptr = query->dptr<char>() + q_offset * GetSizeOfDataType(data_type);\n    params.key_ptr = key->dptr<char>() + k_offset * GetSizeOfDataType(data_type);\n    params.value_ptr = value->dptr<char>() + v_offset * GetSizeOfDataType(data_type);\n    params.query_seq_start_ptr =\n        query_seq_start == nullptr ? nullptr : query_seq_start->dptr<int32_t>();\n    params.key_seq_start_ptr = key_seq_start == nullptr ? nullptr : key_seq_start->dptr<int32_t>();\n    params.key_seq_len_ptr = key_seq_len == nullptr ? nullptr : key_seq_len->dptr<int32_t>();\n    params.out_ptr = out->mut_dptr();\n    const int64_t tmp_buffer_size = tmp->shape_view().elem_cnt();\n    params.workspace = tmp->mut_dptr();\n    params.workspace_size = tmp_buffer_size;\n    params.attn_mask_type = attn_mask_type;\n    params.causal_diagonal_offset = causal_diagonal_offset;\n    if (attn_bias != nullptr) {\n      const int64_t num_attn_bias_axes = attn_bias->shape_view().NumAxes();\n      CHECK_GE(num_attn_bias_axes, 1);\n      CHECK_LE(num_attn_bias_axes, 4);\n      DimVector padded_attn_bias_shape;\n      for (int i = 0; i < 4 - num_attn_bias_axes; ++i) { padded_attn_bias_shape.push_back(1); }\n      for (int i = 0; i < num_attn_bias_axes; ++i) {\n        padded_attn_bias_shape.push_back(attn_bias->shape_view().At(i));\n      }\n      CHECK_GE(padded_attn_bias_shape.at(3), k_m);\n      int64_t bias_stride = padded_attn_bias_shape.at(3);\n      if (padded_attn_bias_shape.at(2) == 1) {\n        params.attn_bias_stride_m = 0;\n      } else {\n        CHECK_GE(padded_attn_bias_shape.at(2), q_m);\n        params.attn_bias_stride_m = bias_stride;\n        bias_stride *= padded_attn_bias_shape.at(2);\n      }\n      if (padded_attn_bias_shape.at(1) == 1) {\n        params.attn_bias_stride_h = 0;\n      } else {\n        CHECK_EQ(padded_attn_bias_shape.at(1), q_h);\n        params.attn_bias_stride_h = bias_stride;\n        bias_stride *= q_h;\n      }\n      if (padded_attn_bias_shape.at(0) == 1) {\n        params.attn_bias_stride_b = 0;\n      } else {\n        CHECK_EQ(padded_attn_bias_shape.at(0), q_b);\n        params.attn_bias_stride_b = bias_stride;\n      }\n      params.attn_bias_ptr = attn_bias->dptr();\n    } else {\n      params.attn_bias_ptr = nullptr;\n      params.attn_bias_stride_m = 0;\n      params.attn_bias_stride_h = 0;\n      params.attn_bias_stride_b = 0;\n    }\n    DispatchCutlassFmha(params, cuda_stream);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nsize_t InferTmpBufferSize(InferContext* ctx) {\n  const auto& out_desc = ctx->OutputTensorDesc(\"out\", 0);\n  size_t buffer_size = 0;\n  buffer_size +=\n      
GetCudaAlignedSize(out_desc.shape().elem_cnt() * GetSizeOfDataType(DataType::kFloat));\n  buffer_size +=\n      GetCudaAlignedSize(out_desc.shape().elem_cnt() * GetSizeOfDataType(out_desc.data_type())) * 3;\n  buffer_size +=\n      GetCudaAlignedSize((out_desc.shape().At(0) + 1) * GetSizeOfDataType(DataType::kInt32));\n  return buffer_size;\n}\n\n#define REGISTER_FUSED_MULTI_HEAD_ATTENTION_INFERENCE_KERNEL(dtype)    \\\n  REGISTER_USER_KERNEL(\"fused_multi_head_attention_inference\")         \\\n      .SetCreateFn<FusedMultiHeadAttentionInferenceKernel>()           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"out\", 0) == dtype))   \\\n      .SetInferTmpSizeFn(InferTmpBufferSize);\n\nREGISTER_FUSED_MULTI_HEAD_ATTENTION_INFERENCE_KERNEL(DataType::kFloat16)\nREGISTER_FUSED_MULTI_HEAD_ATTENTION_INFERENCE_KERNEL(DataType::kFloat)\n\ntemplate<typename Index>\nstruct ConcatParam {\n  const void* past_ptr;\n  const void* ptr;\n  void* output_ptr;\n  Index past_offset;\n  Index offset;\n  Index output_offset;\n  Index past_m;\n  Index past_stride_b;\n  Index past_stride_m;\n  Index past_stride_h;\n  Index stride_b;\n  Index stride_m;\n  Index stride_h;\n  Index output_stride_b;\n  Index output_stride_m;\n  Index output_stride_h;\n  Index count;\n  Index output_khm;\n  Index output_kh;\n  Index output_k;\n};\n\ntemplate<typename Index>\nstruct BatchConcatParam {\n  ConcatParam<Index> params[2];\n};\n\ntemplate<typename T, typename Index>\n__device__ void ConcatPastKeyValue(ConcatParam<Index> p) {\n  for (Index i = blockIdx.x * blockDim.x + threadIdx.x; i < p.count; i += blockDim.x * gridDim.x) {\n    Index b_idx = i / p.output_khm;\n    Index b_off = i - b_idx * p.output_khm;\n    Index m_idx = b_off / p.output_kh;\n    Index m_off = b_off - m_idx * p.output_kh;\n    Index h_idx = m_off / p.output_k;\n    Index k_idx = m_off - h_idx * p.output_k;\n    T v;\n    if (m_idx < p.past_m) {\n      v = reinterpret_cast<const T*>(\n          p.past_ptr)[p.past_offset + b_idx * p.past_stride_b + m_idx * p.past_stride_m\n                      + h_idx * p.past_stride_h + k_idx];\n    } else {\n      v = reinterpret_cast<const T*>(\n          p.ptr)[p.offset + b_idx * p.stride_b + (m_idx - p.past_m) * p.stride_m\n                 + h_idx * p.stride_h + k_idx];\n    }\n    reinterpret_cast<T*>(\n        p.output_ptr)[p.output_offset + b_idx * p.output_stride_b + m_idx * p.output_stride_m\n                      + h_idx * p.output_stride_h + k_idx] = v;\n  }\n}\n\ntemplate<size_t elem_size, typename Index>\n__global__ void BatchConcatPastKeyValue(BatchConcatParam<Index> params) {\n  if (blockIdx.y == 0) {\n    ConcatPastKeyValue<typename std::aligned_storage<elem_size, elem_size>::type, Index>(\n        params.params[0]);\n  } else if (blockIdx.y == 1) {\n    ConcatPastKeyValue<typename std::aligned_storage<elem_size, elem_size>::type, Index>(\n        params.params[1]);\n  } else {\n    // do nothing\n  }\n}\n\nclass FusedAttentionConcatPastKeyValueKernel final : public user_op::OpKernel,\n                                                     public user_op::CudaGraphSupport {\n public:\n  FusedAttentionConcatPastKeyValueKernel() = default;\n  ~FusedAttentionConcatPastKeyValueKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const Tensor* key = ctx->Tensor4ArgNameAndIndex(\"key\", 0);\n    const Tensor* value = ctx->Tensor4ArgNameAndIndex(\"value\", 0);\n    Tensor* 
output_key = ctx->Tensor4ArgNameAndIndex(\"output_key\", 0);\n    Tensor* output_value = ctx->Tensor4ArgNameAndIndex(\"output_value\", 0);\n    const DataType data_type = key->data_type();\n    const Tensor* past_key = nullptr;\n    const Tensor* past_value = nullptr;\n    if (ctx->has_input(\"past_key\", 0)) {\n      CHECK(ctx->has_input(\"past_value\", 0));\n      past_key = ctx->Tensor4ArgNameAndIndex(\"past_key\", 0);\n      past_value = ctx->Tensor4ArgNameAndIndex(\"past_value\", 0);\n      CHECK_EQ(past_key->data_type(), data_type);\n      CHECK_EQ(past_value->data_type(), data_type);\n    } else {\n      CHECK(!ctx->has_input(\"past_value\", 0));\n    }\n    CHECK_EQ(value->data_type(), data_type);\n    CHECK_EQ(output_key->data_type(), data_type);\n    CHECK_EQ(output_value->data_type(), data_type);\n    const int64_t size_of_data_type = GetSizeOfDataType(data_type);\n    const int64_t key_head_size = ctx->Attr<int64_t>(\"key_head_size\");\n    const std::string& past_key_layout = ctx->Attr<std::string>(\"past_key_layout\");\n    const std::string& past_value_layout = ctx->Attr<std::string>(\"past_value_layout\");\n    const std::string& key_layout = ctx->Attr<std::string>(\"key_layout\");\n    const std::string& value_layout = ctx->Attr<std::string>(\"value_layout\");\n\n    int64_t pack_size = 16 / size_of_data_type;\n    while (key_head_size % pack_size != 0) { pack_size /= 2; }\n\n    auto ParsePackedDims =\n        [](const ShapeView& shape, const std::string& layout, const Optional<int64_t>& num_heads,\n           const Optional<int64_t>& head_size, int64_t tensor_index, int64_t* b, int64_t* m,\n           int64_t* h, int64_t* k, int64_t* b_stride, int64_t* m_stride, int64_t* h_stride,\n           int64_t* offset, int64_t pack_size) {\n          ParseDims(shape, layout, num_heads, head_size, tensor_index, b, m, h, k, b_stride,\n                    m_stride, h_stride, offset);\n          *k /= pack_size;\n          *b_stride /= pack_size;\n          *m_stride /= pack_size;\n          *h_stride /= pack_size;\n          *offset /= pack_size;\n        };\n\n    int64_t key_b = 0;\n    int64_t key_m = 0;\n    int64_t key_h = 0;\n    int64_t key_k = 0;\n    int64_t key_b_stride = 0;\n    int64_t key_m_stride = 0;\n    int64_t key_h_stride = 0;\n    int64_t key_offset = 0;\n    ParsePackedDims(key->shape_view(), key_layout, Optional<int64_t>(), key_head_size, 1, &key_b,\n                    &key_m, &key_h, &key_k, &key_b_stride, &key_m_stride, &key_h_stride,\n                    &key_offset, pack_size);\n\n    int64_t value_b = 0;\n    int64_t value_m = 0;\n    int64_t value_h = 0;\n    int64_t value_k = 0;\n    int64_t value_b_stride = 0;\n    int64_t value_m_stride = 0;\n    int64_t value_h_stride = 0;\n    int64_t value_offset = 0;\n    ParsePackedDims(value->shape_view(), value_layout, key_h, key_head_size, 2, &value_b, &value_m,\n                    &value_h, &value_k, &value_b_stride, &value_m_stride, &value_h_stride,\n                    &value_offset, pack_size);\n    CHECK_EQ(value_b, key_b);\n    CHECK_EQ(value_m, key_m);\n\n    int64_t past_key_b = 0;\n    int64_t past_key_m = 0;\n    int64_t past_key_h = 0;\n    int64_t past_key_k = 0;\n    int64_t past_key_b_stride = 0;\n    int64_t past_key_m_stride = 0;\n    int64_t past_key_h_stride = 0;\n    int64_t past_key_offset = 0;\n    if (past_key != nullptr) {\n      ParsePackedDims(past_key->shape_view(), past_key_layout, key_h, key_head_size, 1, &past_key_b,\n                      &past_key_m, &past_key_h, &past_key_k, 
&past_key_b_stride, &past_key_m_stride,\n                      &past_key_h_stride, &past_key_offset, pack_size);\n    }\n\n    int64_t past_value_b = 0;\n    int64_t past_value_m = 0;\n    int64_t past_value_h = 0;\n    int64_t past_value_k = 0;\n    int64_t past_value_b_stride = 0;\n    int64_t past_value_m_stride = 0;\n    int64_t past_value_h_stride = 0;\n    int64_t past_value_offset = 0;\n    if (past_value != nullptr) {\n      ParsePackedDims(past_value->shape_view(), past_value_layout, key_h, key_head_size, 2,\n                      &past_value_b, &past_value_m, &past_value_h, &past_value_k,\n                      &past_value_b_stride, &past_value_m_stride, &past_value_h_stride,\n                      &past_value_offset, pack_size);\n    }\n    CHECK_EQ(past_value_b, past_key_b);\n    CHECK_EQ(past_value_m, past_key_m);\n\n    int64_t output_key_b = 0;\n    int64_t output_key_m = 0;\n    int64_t output_key_h = 0;\n    int64_t output_key_k = 0;\n    int64_t output_key_b_stride = 0;\n    int64_t output_key_m_stride = 0;\n    int64_t output_key_h_stride = 0;\n    int64_t output_key_offset = 0;\n    ParsePackedDims(output_key->shape_view(), past_key_layout, key_h, key_head_size, 1,\n                    &output_key_b, &output_key_m, &output_key_h, &output_key_k,\n                    &output_key_b_stride, &output_key_m_stride, &output_key_h_stride,\n                    &output_key_offset, pack_size);\n    CHECK_EQ(output_key_b, key_b);\n    CHECK_EQ(output_key_m, past_key_m + key_m);\n\n    int64_t output_value_b = 0;\n    int64_t output_value_m = 0;\n    int64_t output_value_h = 0;\n    int64_t output_value_k = 0;\n    int64_t output_value_b_stride = 0;\n    int64_t output_value_m_stride = 0;\n    int64_t output_value_h_stride = 0;\n    int64_t output_value_offset = 0;\n    ParsePackedDims(output_value->shape_view(), past_value_layout, key_h, key_head_size, 2,\n                    &output_value_b, &output_value_m, &output_value_h, &output_value_k,\n                    &output_value_b_stride, &output_value_m_stride, &output_value_h_stride,\n                    &output_value_offset, pack_size);\n    CHECK_EQ(output_value_b, key_b);\n    CHECK_EQ(output_value_m, past_value_m + value_m);\n\n    int64_t max_tensor_elem = (1 << 30) * pack_size;\n    CHECK((past_key == nullptr || past_key->shape_view().elem_cnt() <= max_tensor_elem)\n          && (past_value == nullptr || past_value->shape_view().elem_cnt() <= max_tensor_elem)\n          && key->shape_view().elem_cnt() <= max_tensor_elem\n          && value->shape_view().elem_cnt() <= max_tensor_elem\n          && output_key->shape_view().elem_cnt() <= max_tensor_elem\n          && output_value->shape_view().elem_cnt() <= max_tensor_elem);\n\n    int64_t count = output_key_b * output_key_m * output_key_h * output_key_k;\n    BatchConcatParam<int32_t> kv;\n\n    kv.params[0].past_ptr = past_key == nullptr ? 
nullptr : past_key->dptr();\n    kv.params[0].ptr = key->dptr();\n    kv.params[0].output_ptr = output_key->mut_dptr();\n    kv.params[0].past_offset = past_key_offset;\n    kv.params[0].offset = key_offset;\n    kv.params[0].output_offset = output_key_offset;\n    kv.params[0].past_m = past_key_m;\n    kv.params[0].past_stride_b = past_key_b_stride;\n    kv.params[0].past_stride_m = past_key_m_stride;\n    kv.params[0].past_stride_h = past_key_h_stride;\n    kv.params[0].stride_b = key_b_stride;\n    kv.params[0].stride_m = key_m_stride;\n    kv.params[0].stride_h = key_h_stride;\n    kv.params[0].output_stride_b = output_key_b_stride;\n    kv.params[0].output_stride_m = output_key_m_stride;\n    kv.params[0].output_stride_h = output_key_h_stride;\n    kv.params[0].count = count;\n    kv.params[0].output_khm = output_key_k * output_key_h * output_key_m;\n    kv.params[0].output_kh = output_key_k * output_key_h;\n    kv.params[0].output_k = output_key_k;\n\n    kv.params[1].past_ptr = past_value == nullptr ? nullptr : past_value->dptr();\n    kv.params[1].ptr = value->dptr();\n    kv.params[1].output_ptr = output_value->mut_dptr();\n    kv.params[1].past_offset = past_value_offset;\n    kv.params[1].offset = value_offset;\n    kv.params[1].output_offset = output_value_offset;\n    kv.params[1].past_m = past_value_m;\n    kv.params[1].past_stride_b = past_value_b_stride;\n    kv.params[1].past_stride_m = past_value_m_stride;\n    kv.params[1].past_stride_h = past_value_h_stride;\n    kv.params[1].stride_b = value_b_stride;\n    kv.params[1].stride_m = value_m_stride;\n    kv.params[1].stride_h = value_h_stride;\n    kv.params[1].output_stride_b = output_value_b_stride;\n    kv.params[1].output_stride_m = output_value_m_stride;\n    kv.params[1].output_stride_h = output_value_h_stride;\n    kv.params[1].count = count;\n    kv.params[1].output_khm = output_value_k * output_value_h * output_value_m;\n    kv.params[1].output_kh = output_value_k * output_value_h;\n    kv.params[1].output_k = output_value_k;\n\n    constexpr uint32_t block_size = 256;\n    const dim3 grid_size((count - 1 + block_size) / block_size, 2);\n\n    const int64_t elem_size = size_of_data_type * pack_size;\n    cudaStream_t cuda_stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();\n    if (elem_size == 16) {\n      BatchConcatPastKeyValue<16, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(kv);\n    } else if (elem_size == 8) {\n      BatchConcatPastKeyValue<8, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(kv);\n    } else if (elem_size == 4) {\n      BatchConcatPastKeyValue<4, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(kv);\n    } else if (elem_size == 2) {\n      BatchConcatPastKeyValue<2, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(kv);\n    } else if (elem_size == 1) {\n      BatchConcatPastKeyValue<1, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(kv);\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"fused_attention_concat_past_key_value\")\n    .SetCreateFn<FusedAttentionConcatPastKeyValueKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA));\n\ntemplate<typename T, typename PositionType, typename IndexType, size_t num_dims,\n         size_t rotary_emb_dim>\nstruct FusedApplyRotaryEmbParam {\n  const T* x;\n  const T* cos;\n  const T* sin;\n  const PositionType* position_ids;\n  T* out;\n  const T theta;\n  const float inv_actual_rotary_size;  
// 1.0 / (rotary_size per rotary dimension)\n  const IndexType actual_rotary_size;  // rotary_size per rotary dimension\n  const IndexType rotary_size;\n  const IndexType rotate_stride;\n  const IndexType k0;\n  const IndexType k1;\n  IndexType num_elements;\n  const IndexType k;\n  const IndexType x_offset;\n\n  IndexType ref_stride[num_dims];  // b, m, h, k\n  IndexType out_stride[num_dims];  // ordered descendingly by stride\n  IndexType x_stride[num_dims];\n\n  IndexType position_b_stride;\n  IndexType position_rotate_stride;\n\n  IndexType sinuous_m_stride;\n\n  FusedApplyRotaryEmbParam(const T* x, const T* cos, const T* sin, const PositionType* position_ids,\n                           T* out, const T theta, const float inv_actual_rotary_size,\n                           const IndexType actual_rotary_size, const IndexType rotary_size,\n                           const IndexType rotate_stride, const IndexType num_elements,\n                           const IndexType k, const IndexType k0, const IndexType k1,\n                           const IndexType x_offset)\n      : x(x),\n        cos(cos),\n        sin(sin),\n        position_ids(position_ids),\n        out(out),\n        theta(theta),\n        inv_actual_rotary_size(inv_actual_rotary_size),\n        actual_rotary_size(actual_rotary_size),\n        rotary_size(rotary_size),\n        rotate_stride(rotate_stride),\n        num_elements(num_elements),\n        k(k),\n        k0(k0),\n        k1(k1),\n        x_offset(x_offset) {}\n};\n\ntemplate<typename T, typename PositionType, typename IndexType, size_t PackSize, size_t num_dims,\n         size_t rotary_emb_dim>\n__global__ void IntervalKernel(\n    FusedApplyRotaryEmbParam<T, PositionType, IndexType, num_dims, rotary_emb_dim> param) {\n  for (IndexType packed_offset = threadIdx.x + blockIdx.x * blockDim.x;\n       packed_offset < param.num_elements; packed_offset += blockDim.x * gridDim.x) {\n    using LoadPack = cuda::elementwise::Packed<T, PackSize>;\n    IndexType offset = packed_offset * PackSize;\n    IndexType index[num_dims];  // b, m, h, k\n\n    IndexType temp_offset = offset;\n\n    for (int i = 0; i < num_dims - 1; i++) {\n      IndexType ref_stride = param.ref_stride[i];\n      IndexType idx = temp_offset / ref_stride;\n      index[i] = idx;\n      temp_offset = temp_offset - idx * ref_stride;\n    }\n    index[num_dims - 1] = temp_offset;\n\n    IndexType x_offset = param.x_offset;\n    IndexType out_offset = 0;\n#pragma unroll\n    for (int i = 0; i < num_dims; i++) {\n      x_offset = x_offset + param.x_stride[i] * index[i];\n      out_offset = out_offset + param.out_stride[i] * index[i];\n    }\n    const LoadPack x_vec = *reinterpret_cast<const LoadPack*>(param.x + x_offset);\n\n    const IndexType k_index = index[num_dims - 1];\n    if (k_index < param.rotary_size) {\n      const IndexType position_rotate_index = (k_index >= param.k0) ? 1 : 0;\n      const IndexType b_index = index[0], m_index = index[1];\n      const IndexType position_id_offset = b_index * param.position_b_stride\n                                           + position_rotate_index * param.position_rotate_stride\n                                           + m_index;\n\n      const PositionType position =\n          param.position_ids ? 
param.position_ids[position_id_offset] : m_index;\n      const IndexType actual_k_index = k_index % param.actual_rotary_size;\n      const IndexType sinuous_offset = position * param.sinuous_m_stride + actual_k_index;\n\n      LoadPack cos_vec, sin_vec, out_vec;\n\n      if (param.cos && param.sin) {\n        cos_vec = *reinterpret_cast<const LoadPack*>(param.cos + sinuous_offset);\n        sin_vec = *reinterpret_cast<const LoadPack*>(param.sin + sinuous_offset);\n      } else {\n        const IndexType actual_ndim = param.rotary_size / rotary_emb_dim;\n#pragma unroll\n        for (int i = 0; i < PackSize / 2; i++) {\n          T val = position\n                  * expf(2.0f * static_cast<float>(((actual_k_index >> 1) + i))\n                         * param.inv_actual_rotary_size * logf(param.theta));\n          T cos_val = cosf(val);\n          T sin_val = sinf(val);\n          cos_vec.elem[i * 2] = cos_val;\n          cos_vec.elem[i * 2 + 1] = cos_val;\n          sin_vec.elem[i * 2] = sin_val;\n          sin_vec.elem[i * 2 + 1] = sin_val;\n        }\n      }\n\n#pragma unroll\n      for (int i = 0; i < PackSize / 2; i++) {\n        out_vec.elem[i * 2] =\n            x_vec.elem[i * 2] * cos_vec.elem[i * 2] - x_vec.elem[i * 2 + 1] * sin_vec.elem[i * 2];\n        out_vec.elem[i * 2 + 1] = x_vec.elem[i * 2 + 1] * cos_vec.elem[i * 2 + 1]\n                                  + x_vec.elem[i * 2] * sin_vec.elem[i * 2 + 1];\n      }\n\n      *(reinterpret_cast<LoadPack*>(param.out + out_offset)) = out_vec;\n    } else {\n      *(reinterpret_cast<LoadPack*>(param.out + out_offset)) = x_vec;\n    }\n  }\n}\n\ntemplate<typename T, typename PositionType, typename IndexType, size_t num_dims,\n         size_t rotary_emb_dim>\n__global__ void PlaneKernel(\n    FusedApplyRotaryEmbParam<T, PositionType, IndexType, num_dims, rotary_emb_dim> param) {\n  for (IndexType offset = threadIdx.x + blockIdx.x * blockDim.x; offset < param.num_elements;\n       offset += blockDim.x * gridDim.x) {\n    using LoadPack = cuda::elementwise::Packed<T, 2>;\n    IndexType temp_offset = offset;\n    IndexType index[num_dims];\n#pragma unroll\n    for (int i = 0; i < num_dims - 1; i++) {\n      IndexType ref_stride = param.ref_stride[i];\n      IndexType idx = temp_offset / ref_stride;\n      index[i] = idx;\n      temp_offset = temp_offset - idx * ref_stride;\n    }\n    index[num_dims - 1] = temp_offset;\n\n    const IndexType b_index = index[0], m_index = index[1], k_index = index[num_dims - 1];\n    const IndexType position_rotate_index = (k_index >= param.k0) ? 1 : 0;\n    const IndexType position_id_offset = b_index * param.position_b_stride\n                                         + position_rotate_index * param.position_rotate_stride\n                                         + m_index;\n\n    const PositionType position =\n        param.position_ids ? 
param.position_ids[position_id_offset] : m_index;\n    const IndexType actual_k_index = k_index % param.actual_rotary_size;\n    const IndexType sinuous_offset = position * param.k + actual_k_index;\n\n    T cos_val, sin_val, out_val;\n\n    if (param.cos && param.sin) {\n      cos_val = *(param.cos + sinuous_offset);\n      sin_val = *(param.sin + sinuous_offset);\n    } else {\n      T val = position\n              * expf(2.0f * static_cast<float>(k_index % (param.actual_rotary_size >> 1))\n                     * param.inv_actual_rotary_size * logf(param.theta));\n      cos_val = cosf(val);\n      sin_val = sinf(val);\n    }\n\n    LoadPack x_vec;\n    IndexType x_offset = param.x_offset;\n    IndexType out_offset = 0;\n#pragma unroll\n    for (int i = 0; i < num_dims; i++) {\n      x_offset = x_offset + param.x_stride[i] * index[i];\n      out_offset = out_offset + param.out_stride[i] * index[i];\n    }\n\n    if (k_index < param.k0) {\n      x_vec.elem[0] = *(param.x + x_offset);\n      x_vec.elem[1] = (param.k0 - k_index > param.rotate_stride)\n                          ? static_cast<T>(-*(param.x + x_offset + param.rotate_stride))\n                          : *(param.x + x_offset - param.rotate_stride);\n      out_val = cos_val * x_vec.elem[0] + sin_val * x_vec.elem[1];\n    } else if (k_index < param.k1) {\n      x_vec.elem[0] = *(param.x + x_offset);\n      x_vec.elem[1] = (param.k1 - k_index > param.rotate_stride)\n                          ? static_cast<T>(-*(param.x + x_offset + param.rotate_stride))\n                          : *(param.x + x_offset - param.rotate_stride);\n      out_val = cos_val * x_vec.elem[0] + sin_val * x_vec.elem[1];\n    } else {\n      out_val = *(param.x + x_offset);\n    }\n\n    *(param.out + out_offset) = out_val;\n  }\n}\n\ntemplate<typename T, typename PositionType, typename IndexType, size_t PackSize, size_t num_dims,\n         size_t rotary_emb_dim>\nvoid LaunchKernel(ep::CudaStream* stream, const T* x, const T* cos, const T* sin,\n                  const PositionType* position_ids, T* out, const int64_t* position_shape,\n                  const std::string& x_layout, const std::string& output_layout,\n                  const std::string& mode, const T theta, const IndexType rotary_size,\n                  const IndexType b, const IndexType m, const IndexType h, const IndexType k,\n                  const IndexType x_b_stride, const IndexType x_m_stride,\n                  const IndexType x_h_stride, const IndexType x_offset,\n                  const IndexType out_b_stride, const IndexType out_m_stride,\n                  const IndexType out_h_stride, IndexType num_elements) {\n  const IndexType k0 = rotary_size / rotary_emb_dim,\n                  k1 = rotary_size;  // TODO: this only supports 1d/2d rotary positional encoding\n\n  const IndexType rotate_stride = rotary_size / (2 * rotary_emb_dim);\n\n  const IndexType actual_rotary_size = rotary_size / rotary_emb_dim;\n  const float inv_actual_rotary_size = 1.0 / actual_rotary_size;\n\n  struct FusedApplyRotaryEmbParam<T, PositionType, IndexType, num_dims, rotary_emb_dim> param(\n      x, cos, sin, position_ids, out, theta, inv_actual_rotary_size, actual_rotary_size,\n      rotary_size, rotate_stride, num_elements, k, k0, k1, x_offset);\n\n  const IndexType ref_strides[num_dims] = {m * h * k, h * k, k, 1};\n  const IndexType out_strides[num_dims] = {out_b_stride, out_m_stride, out_h_stride, 1};\n  const IndexType x_strides[num_dims] = {x_b_stride, x_m_stride, x_h_stride, 1};\n\n  
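// cos/sin tables hold one row of actual_rotary_size values per position id\n  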
param.sinuous_m_stride = actual_rotary_size;\n\n  const IndexType position_m = position_shape ? static_cast<IndexType>(position_shape[2]) : m;\n  param.position_rotate_stride = position_m;\n  param.position_b_stride = position_m * rotary_emb_dim;\n\n// K has to be the last dimension; only k & m matter, so strides other than k & m do not\n// really need to be computed\n#pragma unroll\n  for (int i = 0; i < num_dims; i++) {\n    param.ref_stride[i] = ref_strides[i];\n    param.out_stride[i] = out_strides[i];\n    param.x_stride[i] = x_strides[i];\n  }\n\n  constexpr size_t blk_size = 128;\n\n  if (mode == \"plane\") {\n    param.num_elements = param.num_elements * PackSize;\n    PlaneKernel<T, PositionType, IndexType, num_dims, rotary_emb_dim>\n        <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>(\n            param);\n  } else {\n    IntervalKernel<T, PositionType, IndexType, PackSize, num_dims, rotary_emb_dim>\n        <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>(\n            param);\n  }\n}\n\ntemplate<typename T, typename PositionType, typename IndexType, size_t num_dims,\n         size_t rotary_emb_dim>\nvoid DispatchPackSize(ep::CudaStream* stream, const T* x, const T* cos, const T* sin,\n                      const PositionType* position_ids, T* out, const int64_t* position_shape,\n                      const std::string& x_layout, const std::string& output_layout,\n                      const std::string& mode, const T theta, const IndexType rotary_size,\n                      const IndexType b, const IndexType m, const IndexType h, const IndexType k,\n                      const IndexType x_b_stride, const IndexType x_m_stride,\n                      const IndexType x_h_stride, const IndexType x_offset,\n                      const IndexType out_b_stride, const IndexType out_m_stride,\n                      const IndexType out_h_stride, IndexType num_elements) {\n  const auto CheckPackSize = [&](const size_t PackSize) {\n    bool r = (((reinterpret_cast<uintptr_t>(x) % (sizeof(T) * PackSize)) == 0)\n              && (((rotary_size / rotary_emb_dim) % PackSize) == 0)\n              && (((k - rotary_size) % PackSize) == 0) && ((16 / sizeof(T)) >= PackSize));\n    return r;\n  };\n\n  if (CheckPackSize(8)) {\n    num_elements /= 8;\n    LaunchKernel<T, PositionType, IndexType, 8, num_dims, rotary_emb_dim>(\n        stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode,\n        theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride,\n        out_m_stride, out_h_stride, num_elements);\n  } else if (CheckPackSize(4)) {\n    num_elements /= 4;\n    LaunchKernel<T, PositionType, IndexType, 4, num_dims, rotary_emb_dim>(\n        stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode,\n        theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride,\n        out_m_stride, out_h_stride, num_elements);\n  } else {\n    num_elements /= 2;\n    LaunchKernel<T, PositionType, IndexType, 2, num_dims, rotary_emb_dim>(\n        stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode,\n        theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride,\n        out_m_stride, out_h_stride, num_elements);\n  }\n}\n\ntemplate<typename T, typename PositionType, size_t num_dims, size_t rotary_emb_dim>\nvoid 
DispatchIndex(ep::CudaStream* stream, const T* x, const T* cos, const T* sin,\n                   const PositionType* position_ids, T* out, const int64_t* position_shape,\n                   const std::string& x_layout, const std::string& output_layout,\n                   const std::string& mode, const T theta, const int64_t rotary_size,\n                   const int64_t b, const int64_t m, const int64_t h, const int64_t k,\n                   const int64_t x_b_stride, const int64_t x_m_stride, const int64_t x_h_stride,\n                   const int64_t x_offset, const int64_t out_b_stride, const int64_t out_m_stride,\n                   const int64_t out_h_stride) {\n  int64_t num_elements = b * m * h * k;\n  if (num_elements < (1 << 30)) {\n    DispatchPackSize<T, PositionType, int32_t, num_dims, rotary_emb_dim>(\n        stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode,\n        theta, static_cast<int32_t>(rotary_size), static_cast<int32_t>(b), static_cast<int32_t>(m),\n        static_cast<int32_t>(h), static_cast<int32_t>(k), static_cast<int32_t>(x_b_stride),\n        static_cast<int32_t>(x_m_stride), static_cast<int32_t>(x_h_stride),\n        static_cast<int32_t>(x_offset), static_cast<int32_t>(out_b_stride),\n        static_cast<int32_t>(out_m_stride), static_cast<int32_t>(out_h_stride),\n        static_cast<int32_t>(num_elements));\n  } else {\n    DispatchPackSize<T, PositionType, int64_t, num_dims, rotary_emb_dim>(\n        stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode,\n        theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride,\n        out_m_stride, out_h_stride, num_elements);\n  }\n}\n\ntemplate<typename T, typename PositionType, size_t num_dims>\nvoid DispatchRotaryEmbeddingDimension(ep::CudaStream* stream, const T* x, const T* cos,\n                                      const T* sin, const PositionType* position_ids, T* out,\n                                      const int64_t* position_shape, const std::string& x_layout,\n                                      const std::string& output_layout, const std::string& mode,\n                                      const T theta, const int64_t rotary_size,\n                                      const int rotary_emb_dim, const int64_t b, const int64_t m,\n                                      const int64_t h, const int64_t k, const int64_t x_b_stride,\n                                      const int64_t x_m_stride, const int64_t x_h_stride,\n                                      const int64_t x_offset, const int64_t out_b_stride,\n                                      const int64_t out_m_stride, const int64_t out_h_stride) {\n  if (rotary_emb_dim == 1) {\n    DispatchIndex<T, PositionType, num_dims, 1>(\n        stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode,\n        theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride,\n        out_m_stride, out_h_stride);\n  } else if (rotary_emb_dim == 2) {\n    DispatchIndex<T, PositionType, num_dims, 2>(\n        stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode,\n        theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride,\n        out_m_stride, out_h_stride);\n  }\n}\n\ntemplate<typename T, typename PositionType>\nclass FusedApplyRotaryEmbKernel final : public user_op::OpKernel {\n public:\n  FusedApplyRotaryEmbKernel() = 
default;\n  ~FusedApplyRotaryEmbKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* cos = nullptr;\n    user_op::Tensor* sin = nullptr;\n    user_op::Tensor* position_ids = nullptr;\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const std::string& x_layout = ctx->Attr<std::string>(\"x_layout\");\n    const std::string& output_layout = ctx->Attr<std::string>(\"output_layout\");\n    const std::string& mode = ctx->Attr<std::string>(\"mode\");\n    const int64_t tensor_index = ctx->Attr<int64_t>(\"tensor_index\");\n    const int64_t k_size = ctx->Attr<int64_t>(\"k_size\");\n    const int64_t rotary_size = ctx->Attr<int64_t>(\"rotary_size\");\n    const float theta = 1.0f / ctx->Attr<float>(\"base\");\n    int rotary_emb_dim = 1;\n\n    if (ctx->has_input(\"cos\", 0)) { cos = ctx->Tensor4ArgNameAndIndex(\"cos\", 0); }\n\n    if (ctx->has_input(\"sin\", 0)) { sin = ctx->Tensor4ArgNameAndIndex(\"sin\", 0); }\n\n    if (ctx->has_input(\"position_ids\", 0)) {\n      position_ids = ctx->Tensor4ArgNameAndIndex(\"position_ids\", 0);\n      rotary_emb_dim = position_ids->shape_view().At(1);\n    }\n\n    constexpr size_t ndims = 4;\n    int64_t b = 0;\n    int64_t m = 0;\n    int64_t h = 0;\n    int64_t k = 0;\n    int64_t out_b_stride = 0, out_m_stride = 0, out_h_stride = 0, out_offset = 0;\n    int64_t x_b_stride = 0, x_m_stride = 0, x_h_stride = 0, x_offset = 0;\n\n    ParseDims(out->shape_view(), output_layout, Optional<int64_t>(), k_size, 0, &b, &m, &h, &k,\n              &out_b_stride, &out_m_stride, &out_h_stride, &out_offset);\n    ParseDims(x->shape_view(), x_layout, Optional<int64_t>(), k_size, tensor_index, &b, &m, &h, &k,\n              &x_b_stride, &x_m_stride, &x_h_stride, &x_offset);\n\n    // TODO: hard code num_dims & seems redundant template problem...\n    DispatchRotaryEmbeddingDimension<T, PositionType, ndims>(\n        ctx->stream()->As<ep::CudaStream>(), reinterpret_cast<const T*>(x->dptr()),\n        cos ? reinterpret_cast<const T*>(cos->dptr()) : nullptr,\n        sin ? reinterpret_cast<const T*>(sin->dptr()) : nullptr,\n        position_ids ? reinterpret_cast<const PositionType*>(position_ids->dptr()) : nullptr,\n        reinterpret_cast<T*>(out->mut_dptr()),\n        position_ids ? 
position_ids->shape_view().data() : nullptr, x_layout, output_layout, mode,\n        static_cast<T>(theta), rotary_size, rotary_emb_dim, b, m, h, k, x_b_stride, x_m_stride,\n        x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_APPLY_ROTARY_EMB_GPU(dtype, position_type)          \\\n  REGISTER_USER_KERNEL(\"fused_apply_rotary_emb\")                           \\\n      .SetCreateFn<FusedApplyRotaryEmbKernel<dtype, position_type>>()      \\\n      .SetIsMatchedHob(                                                    \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n          && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value) \\\n          && (user_op::HobInputSize(\"position_ids\") == 1)                  \\\n          && (user_op::HobDataType(\"position_ids\", 0) == GetDataType<position_type>::value));\n\n#define REGISTER_FUSED_APPLY_ROTARY_EMB_GPU_DTYPE(dtype)                                \\\n  REGISTER_FUSED_APPLY_ROTARY_EMB_GPU(dtype, int64_t);                                  \\\n  REGISTER_FUSED_APPLY_ROTARY_EMB_GPU(dtype, int32_t);                                  \\\n  REGISTER_USER_KERNEL(\"fused_apply_rotary_emb\")                                        \\\n      .SetCreateFn<FusedApplyRotaryEmbKernel<dtype, int64_t>>()                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobInputSize(\"position_ids\") == 0));\n\nREGISTER_FUSED_APPLY_ROTARY_EMB_GPU_DTYPE(float);\nREGISTER_FUSED_APPLY_ROTARY_EMB_GPU_DTYPE(half);\n#if CUDA_VERSION >= 11000\nREGISTER_FUSED_APPLY_ROTARY_EMB_GPU_DTYPE(nv_bfloat16);\n#endif  // CUDA_VERSION >= 11000\n\n}  // namespace\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // WITH_CUTLASS\n"
  },
  {
    "path": "oneflow/user/kernels/fused_bias_add_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\n#include \"oneflow/core/device/cuda_pseudo_bfloat16.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstruct GeluFunctor {\n  __device__ T Compute(T x, int64_t i) const {\n    return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + erf(static_cast<T>(M_SQRT1_2) * x));\n  }\n};\n\ntemplate<>\nstruct GeluFunctor<half> {\n  GeluFunctor<float> float_functor;\n  __device__ half Compute(half x, int64_t i) const {\n    return __float2half(float_functor.Compute(__half2float(x), i));\n  }\n  __device__ half2 ComputeHalf2(half2 x, int64_t i) const {\n    half2 y;\n    y.x = __float2half(float_functor.Compute(__half2float(x.x), 2 * i));\n    y.y = __float2half(float_functor.Compute(__half2float(x.y), 2 * i + 1));\n    return y;\n  }\n};\n\n#if CUDA_VERSION >= 11000\ntemplate<>\nstruct GeluFunctor<nv_bfloat16> {\n  GeluFunctor<float> float_functor;\n  __device__ nv_bfloat16 Compute(nv_bfloat16 x, int64_t i) const {\n    return static_cast<nv_bfloat16>(float_functor.Compute(static_cast<float>(x), i));\n  }\n};\n#endif\n\ntemplate<typename T>\nstruct MaskAndScaleFunctor {\n  MaskAndScaleFunctor(const bool* mask, float scale) : mask(mask), scale(scale) {}\n  __device__ T Compute(T x, int64_t i) const { return x * static_cast<T>(mask[i] * scale); }\n  const bool* mask;\n  float scale;\n};\n\ntemplate<>\nstruct MaskAndScaleFunctor<half> {\n  MaskAndScaleFunctor(const bool* mask, float scale) : mask(mask), scale(scale) {}\n  __device__ half Compute(half x, int64_t i) const {\n    return x * static_cast<half>(mask[i] * scale);\n  }\n  __device__ half2 ComputeHalf2(half2 x, int64_t i) const {\n    const char2* mask_c2 = reinterpret_cast<const char2*>(mask);\n    char2 mask_val = mask_c2[i];\n    half2 one_or_zero_h2;\n    half2 h2_scale = __float2half2_rn(scale);\n    one_or_zero_h2.x = mask_val.x;\n    one_or_zero_h2.y = mask_val.y;\n    return __hmul2(__hmul2(x, one_or_zero_h2), h2_scale);\n  }\n  const bool* mask;\n  float scale;\n};\n\ntemplate<typename T>\nstruct MaskAndScaleAddFunctor {\n  MaskAndScaleAddFunctor(const bool* mask, const T* addend, float scale)\n      : mask(mask), addend(addend), scale(scale) {}\n  __device__ T Compute(T x, int64_t i) const {\n    return x * static_cast<T>(mask[i] * scale) + addend[i];\n  }\n  const bool* mask;\n  const T* addend;\n  float scale;\n};\n\ntemplate<>\nstruct MaskAndScaleAddFunctor<half> {\n  MaskAndScaleAddFunctor(const bool* mask, const half* addend, float scale)\n      : mask(mask), addend(addend), scale(scale) {}\n  __device__ half Compute(half x, int64_t i) const {\n    return x * static_cast<half>(mask[i] * scale) + addend[i];\n  }\n  __device__ half2 ComputeHalf2(half2 x, int64_t i) const {\n    
const char2* mask_c2 = reinterpret_cast<const char2*>(mask);\n    const half2* addend_h2 = reinterpret_cast<const half2*>(addend);\n    char2 mask_val = mask_c2[i];\n    half2 one_or_zero_h2;\n    half2 h2_scale = __float2half2_rn(scale);\n    one_or_zero_h2.x = mask_val.x;\n    one_or_zero_h2.y = mask_val.y;\n    return __hadd2(__hmul2(__hmul2(x, one_or_zero_h2), h2_scale), addend_h2[i]);\n  }\n  const bool* mask;\n  const half* addend;\n  float scale;\n};\n\ntemplate<typename T>\nstruct GeluGradFunctor {\n  const T coef = std::sqrt(static_cast<T>(2.0) / std::acos(static_cast<T>(-1.0)));\n  __device__ T Compute(T x, T dy, int64_t i) const {\n    return static_cast<T>(0.5)\n           * (static_cast<T>(1.0) + erf(static_cast<T>(M_SQRT1_2) * x)\n              + x * coef * exp(static_cast<T>(-0.5) * x * x))\n           * dy;\n  }\n};\n\ntemplate<>\nstruct GeluGradFunctor<half> {\n  GeluGradFunctor<float> float_functor;\n  __device__ half Compute(half x, half dy, int64_t i) const {\n    return __float2half(float_functor.Compute(__half2float(x), __half2float(dy), i));\n  }\n};\n\n#if CUDA_VERSION >= 11000\ntemplate<>\nstruct GeluGradFunctor<nv_bfloat16> {\n  GeluGradFunctor<float> float_functor;\n  __device__ nv_bfloat16 Compute(nv_bfloat16 x, nv_bfloat16 dy, int64_t i) const {\n    return static_cast<nv_bfloat16>(\n        float_functor.Compute(static_cast<float>(x), static_cast<float>(dy), i));\n  }\n};\n#endif\n\ntemplate<typename FUNCTOR, typename T, typename Index>\n__global__ void FusedBiasAddGpu(FUNCTOR functor, const Index elem_cnt, const Index bias_size,\n                                const Index inner_size, const T* x, const T* bias, T* y) {\n  const Index block_size = bias_size * inner_size;\n  CUDA_1D_KERNEL_LOOP_T(Index, i, elem_cnt) {\n    T x_i = x[i] + bias[(i % block_size) / inner_size];\n    y[i] = functor.Compute(x_i, i);\n  }\n}\n\ntemplate<typename FUNCTOR, typename T, typename Index>\n__global__ void FusedBiasAddGradGpu(FUNCTOR grad_functor, const Index elem_cnt,\n                                    const Index bias_size, const Index inner_size, const T* x,\n                                    const T* bias, const T* dy, T* dx) {\n  const Index block_size = bias_size * inner_size;\n  CUDA_1D_KERNEL_LOOP_T(Index, i, elem_cnt) {\n    T x_i = x[i] + bias[(i % block_size) / inner_size];\n    dx[i] = grad_functor.Compute(x_i, dy[i], i);\n  }\n}\n\ntemplate<typename FUNCTOR, typename T, typename Index>\n__global__ void FusedBiasAddRowGpu(FUNCTOR functor, const Index elem_cnt, const Index bias_size,\n                                   const T* x, const T* bias, T* y) {\n  CUDA_1D_KERNEL_LOOP_T(Index, i, elem_cnt) {\n    T x_i = x[i] + bias[i % bias_size];\n    y[i] = functor.Compute(x_i, i);\n  }\n}\n\ntemplate<typename FUNCTOR, typename T, typename Index>\n__global__ void FusedBiasAddGradRowGpu(FUNCTOR grad_functor, const Index elem_cnt,\n                                       const Index bias_size, const T* x, const T* bias,\n                                       const T* dy, T* dx) {\n  CUDA_1D_KERNEL_LOOP_T(Index, i, elem_cnt) {\n    T x_i = x[i] + bias[i % bias_size];\n    dx[i] = grad_functor.Compute(x_i, dy[i], i);\n  }\n}\n\ntemplate<typename FUNCTOR, typename Index>\n__global__ void FusedBiasAddRowGpuHalf2(FUNCTOR functor, const Index elem_cnt,\n                                        const Index bias_size, const half* x, const half* bias,\n                                        half* y) {\n  const Index h2_elem_cnt = elem_cnt / 2;\n  const Index h2_bias_size = 
bias_size / 2;\n  const auto* x_h2 = reinterpret_cast<const half2*>(x);\n  const auto* bias_h2 = reinterpret_cast<const half2*>(bias);\n  auto* y_h2 = reinterpret_cast<half2*>(y);\n  CUDA_1D_KERNEL_LOOP_T(Index, i, h2_elem_cnt) {\n    half2 x_i = __hadd2(x_h2[i], bias_h2[i % h2_bias_size]);\n    y_h2[i] = functor.ComputeHalf2(x_i, i);\n  }\n}\n\ntemplate<typename FUNCTOR, typename Index>\n__global__ void FusedBiasAddGradRowGpuHalf2(FUNCTOR grad_functor, const Index elem_cnt,\n                                            const Index bias_size, const half* x, const half* bias,\n                                            const half* dy, half* dx) {\n  const Index h2_elem_cnt = elem_cnt / 2;\n  const Index h2_bias_size = bias_size / 2;\n  const auto* x_h2 = reinterpret_cast<const half2*>(x);\n  const auto* bias_h2 = reinterpret_cast<const half2*>(bias);\n  const auto* dy_h2 = reinterpret_cast<const half2*>(dy);\n  auto* dx_h2 = reinterpret_cast<half2*>(dx);\n  CUDA_1D_KERNEL_LOOP_T(Index, i, h2_elem_cnt) {\n    half2 x_i = __hadd2(x_h2[i], bias_h2[i % h2_bias_size]);\n    half2 dy_i = dy_h2[i];\n    half2 dx_i;\n    dx_i.x = grad_functor.Compute(x_i.x, dy_i.x, 2 * i);\n    dx_i.y = grad_functor.Compute(x_i.y, dy_i.y, 2 * i + 1);\n    dx_h2[i] = dx_i;\n  }\n}\n\ntemplate<typename FUNCTOR, typename T, typename Index>\n__global__ void FusedBiasAddColGpu(FUNCTOR functor, const Index elem_cnt, const Index inner_size,\n                                   const T* x, const T* bias, T* y) {\n  CUDA_1D_KERNEL_LOOP_T(Index, i, elem_cnt) {\n    T x_i = x[i] + bias[i / inner_size];\n    y[i] = functor.Compute(x_i, i);\n  }\n}\n\ntemplate<typename FUNCTOR, typename T, typename Index>\n__global__ void FusedBiasAddGradColGpu(FUNCTOR grad_functor, const Index elem_cnt,\n                                       const Index inner_size, const T* x, const T* bias,\n                                       const T* dy, T* dx) {\n  CUDA_1D_KERNEL_LOOP_T(Index, i, elem_cnt) {\n    T x_i = x[i] + bias[i / inner_size];\n    dx[i] = grad_functor.Compute(x_i, dy[i], i);\n  }\n}\n\ntemplate<typename FUNCTOR, typename T, typename Index>\nstruct FusedBiasAddRow {\n  static void Invoke(ep::Stream* stream, FUNCTOR functor, Index elem_cnt, Index bias_size,\n                     const T* x, const T* bias, T* y) {\n    FusedBiasAddRowGpu<FUNCTOR, T, Index>\n        <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(functor, elem_cnt, bias_size, x, bias, y);\n  }\n};\n\ntemplate<typename FUNCTOR, typename Index>\nstruct FusedBiasAddRow<FUNCTOR, half, Index> {\n  static void Invoke(ep::Stream* stream, FUNCTOR functor, Index elem_cnt, Index bias_size,\n                     const half* x, const half* bias, half* y) {\n    if (bias_size % 2 == 0) {\n      FusedBiasAddRowGpuHalf2<FUNCTOR, Index>\n          <<<BlocksNum4ThreadsNum(elem_cnt / 2), kCudaThreadsNumPerBlock, 0,\n             stream->As<ep::CudaStream>()->cuda_stream()>>>(functor, elem_cnt, bias_size, x, bias,\n                                                            y);\n    } else {\n      FusedBiasAddRowGpu<FUNCTOR, half, Index>\n          <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n             stream->As<ep::CudaStream>()->cuda_stream()>>>(functor, elem_cnt, bias_size, x, bias,\n                                                            y);\n    }\n  }\n};\n\ntemplate<typename FUNCTOR, typename T, typename Index>\nvoid FusedBiasAddForwardImpl(ep::Stream* stream, FUNCTOR functor, 
Index outer_size, Index bias_size,\n                             Index inner_size, const T* x, const T* bias, T* y) {\n  const Index elem_cnt = outer_size * bias_size * inner_size;\n  if (inner_size == 1) {\n    FusedBiasAddRow<FUNCTOR, T, Index>::Invoke(stream, functor, elem_cnt, bias_size, x, bias, y);\n  } else if (outer_size == 1) {\n    FusedBiasAddColGpu<FUNCTOR, T, Index><<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock,\n                                            0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        functor, elem_cnt, inner_size, x, bias, y);\n  } else {\n    FusedBiasAddGpu<FUNCTOR, T, Index><<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                                         stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        functor, elem_cnt, bias_size, inner_size, x, bias, y);\n  }\n}\n\ntemplate<typename FUNCTOR, typename T, typename Index>\nstruct FusedBiasAddGradRow {\n  static void Invoke(ep::Stream* stream, FUNCTOR grad_functor, Index elem_cnt, Index bias_size,\n                     const T* x, const T* bias, const T* dy, T* dx) {\n    FusedBiasAddGradRowGpu<FUNCTOR, T, Index>\n        <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(grad_functor, elem_cnt, bias_size, x,\n                                                          bias, dy, dx);\n  }\n};\n\ntemplate<typename FUNCTOR, typename Index>\nstruct FusedBiasAddGradRow<FUNCTOR, half, Index> {\n  static void Invoke(ep::Stream* stream, FUNCTOR grad_functor, Index elem_cnt, Index bias_size,\n                     const half* x, const half* bias, const half* dy, half* dx) {\n    if (bias_size % 2 == 0) {\n      FusedBiasAddGradRowGpuHalf2<FUNCTOR, Index>\n          <<<BlocksNum4ThreadsNum(elem_cnt / 2), kCudaThreadsNumPerBlock, 0,\n             stream->As<ep::CudaStream>()->cuda_stream()>>>(grad_functor, elem_cnt, bias_size, x,\n                                                            bias, dy, dx);\n    } else {\n      FusedBiasAddGradRowGpu<FUNCTOR, half, Index>\n          <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n             stream->As<ep::CudaStream>()->cuda_stream()>>>(grad_functor, elem_cnt, bias_size, x,\n                                                            bias, dy, dx);\n    }\n  }\n};\n\ntemplate<typename FUNCTOR, typename T, typename Index>\nvoid FusedBiasAddGradImpl(ep::Stream* stream, FUNCTOR grad_functor, Index outer_size,\n                          Index bias_size, Index inner_size, const T* x, const T* bias, const T* dy,\n                          T* dx) {\n  const Index elem_cnt = outer_size * bias_size * inner_size;\n  if (inner_size == 1) {\n    FusedBiasAddGradRow<FUNCTOR, T, Index>::Invoke(stream, grad_functor, elem_cnt, bias_size, x,\n                                                   bias, dy, dx);\n  } else if (outer_size == 1) {\n    FusedBiasAddGradColGpu<FUNCTOR, T, Index>\n        <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(grad_functor, elem_cnt, inner_size, x,\n                                                          bias, dy, dx);\n  } else {\n    FusedBiasAddGradGpu<FUNCTOR, T, Index>\n        <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(grad_functor, elem_cnt, bias_size,\n                                                          inner_size, x, bias, dy, dx);\n  
}\n}\n\ntemplate<typename FUNCTOR, typename T>\nvoid DispatchFusedBiasAddForwardImpl(ep::Stream* stream, FUNCTOR functor, int64_t n,\n                                     int64_t outer_size, int64_t bias_size, int64_t inner_size,\n                                     const T* x, const T* bias, T* y) {\n  if (IsKernelSafeInt32(n)) {\n    FusedBiasAddForwardImpl<FUNCTOR, T, int32_t>(stream, functor, outer_size, bias_size, inner_size,\n                                                 x, bias, y);\n  } else {\n    FusedBiasAddForwardImpl<FUNCTOR, T, int64_t>(stream, functor, outer_size, bias_size, inner_size,\n                                                 x, bias, y);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass FusedFusedBiasAddKernel final : public user_op::OpKernel {\n public:\n  FusedFusedBiasAddKernel() = default;\n  ~FusedFusedBiasAddKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* a_tensor = ctx->Tensor4ArgNameAndIndex(\"a\", 0);\n    const auto* b_tensor = ctx->Tensor4ArgNameAndIndex(\"b\", 0);\n    auto* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int32_t bias_add_axis = ctx->Attr<int32_t>(\"axis\");\n    const int64_t outer_size = a_tensor->shape_view().Count(0, bias_add_axis);\n    const int64_t bias_size = a_tensor->shape_view().At(bias_add_axis);\n    const int64_t inner_size = a_tensor->shape_view().Count(bias_add_axis + 1);\n    const auto n = a_tensor->shape_view().elem_cnt();\n    GeluFunctor<T> gelu_functor{};\n    DispatchFusedBiasAddForwardImpl<decltype(gelu_functor), T>(\n        ctx->stream(), gelu_functor, n, outer_size, bias_size, inner_size, a_tensor->dptr<T>(),\n        b_tensor->dptr<T>(), out_tensor->mut_dptr<T>());\n  };\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_BIAS_ADD_GELU_KERNEL(dtype)                     \\\n  REGISTER_USER_KERNEL(\"fused_bias_add_gelu\")                          \\\n      .SetCreateFn<FusedFusedBiasAddKernel<dtype>>()                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_BIAS_ADD_GELU_KERNEL(float)\nREGISTER_FUSED_BIAS_ADD_GELU_KERNEL(double)\nREGISTER_FUSED_BIAS_ADD_GELU_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_FUSED_BIAS_ADD_GELU_KERNEL(nv_bfloat16)\n#endif\n\ntemplate<typename T>\nclass FusedBiasAddMaskScaleKernel final : public user_op::OpKernel {\n public:\n  FusedBiasAddMaskScaleKernel() = default;\n  ~FusedBiasAddMaskScaleKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* a_tensor = ctx->Tensor4ArgNameAndIndex(\"a\", 0);\n    const auto* b_tensor = ctx->Tensor4ArgNameAndIndex(\"b\", 0);\n    const auto* mask_tensor = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    auto* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int32_t bias_add_axis = ctx->Attr<int32_t>(\"axis\");\n    const float scale = ctx->Attr<float>(\"scale\");\n    const int64_t outer_size = a_tensor->shape_view().Count(0, bias_add_axis);\n    const int64_t bias_size = a_tensor->shape_view().At(bias_add_axis);\n    const int64_t inner_size = a_tensor->shape_view().Count(bias_add_axis + 1);\n    const auto n = a_tensor->shape_view().elem_cnt();\n    if 
(ctx->has_input(\"_add_to_output\", 0)) {\n      const user_op::Tensor* addend = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n      MaskAndScaleAddFunctor<T> mask_and_scale_add_functor(mask_tensor->dptr<bool>(),\n                                                           addend->dptr<T>(), scale);\n      DispatchFusedBiasAddForwardImpl<decltype(mask_and_scale_add_functor), T>(\n          ctx->stream(), mask_and_scale_add_functor, n, outer_size, bias_size, inner_size,\n          a_tensor->dptr<T>(), b_tensor->dptr<T>(), out_tensor->mut_dptr<T>());\n    } else {\n      MaskAndScaleFunctor<T> mask_and_scale_functor(mask_tensor->dptr<bool>(), scale);\n      DispatchFusedBiasAddForwardImpl<decltype(mask_and_scale_functor), T>(\n          ctx->stream(), mask_and_scale_functor, n, outer_size, bias_size, inner_size,\n          a_tensor->dptr<T>(), b_tensor->dptr<T>(), out_tensor->mut_dptr<T>());\n    }\n  };\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_BIAS_ADD_MASK_SCALE_KERNEL(dtype)               \\\n  REGISTER_USER_KERNEL(\"fused_bias_add_mask_scale\")                    \\\n      .SetCreateFn<FusedBiasAddMaskScaleKernel<dtype>>()               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_BIAS_ADD_MASK_SCALE_KERNEL(float)\nREGISTER_FUSED_BIAS_ADD_MASK_SCALE_KERNEL(double)\nREGISTER_FUSED_BIAS_ADD_MASK_SCALE_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_FUSED_BIAS_ADD_MASK_SCALE_KERNEL(nv_bfloat16)\n#endif\n\ntemplate<typename T>\nclass FusedFusedBiasAddGradKernel final : public user_op::OpKernel {\n public:\n  FusedFusedBiasAddGradKernel() = default;\n  ~FusedFusedBiasAddGradKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* a_tensor = ctx->Tensor4ArgNameAndIndex(\"a\", 0);\n    const auto* b_tensor = ctx->Tensor4ArgNameAndIndex(\"b\", 0);\n    const auto* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    auto* dx_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const int32_t bias_add_axis = ctx->Attr<int32_t>(\"axis\");\n    const int64_t outer_size = a_tensor->shape_view().Count(0, bias_add_axis);\n    const int64_t bias_size = a_tensor->shape_view().At(bias_add_axis);\n    const int64_t inner_size = a_tensor->shape_view().Count(bias_add_axis + 1);\n    const auto n = a_tensor->shape_view().elem_cnt();\n    GeluGradFunctor<T> gelu_grad_functor;\n    if (IsKernelSafeInt32(n)) {\n      FusedBiasAddGradImpl<decltype(gelu_grad_functor), T, int32_t>(\n          ctx->stream(), gelu_grad_functor, outer_size, bias_size, inner_size, a_tensor->dptr<T>(),\n          b_tensor->dptr<T>(), dy_tensor->dptr<T>(), dx_tensor->mut_dptr<T>());\n    } else {\n      FusedBiasAddGradImpl<decltype(gelu_grad_functor), T, int64_t>(\n          ctx->stream(), gelu_grad_functor, outer_size, bias_size, inner_size, a_tensor->dptr<T>(),\n          b_tensor->dptr<T>(), dy_tensor->dptr<T>(), dx_tensor->mut_dptr<T>());\n    }\n  };\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_BIAS_ADD_GELU_GRAD_KERNEL(dtype)                \\\n  REGISTER_USER_KERNEL(\"fused_bias_add_gelu_grad\")                     \\\n      .SetCreateFn<FusedFusedBiasAddGradKernel<dtype>>()               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() 
== DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_BIAS_ADD_GELU_GRAD_KERNEL(float)\nREGISTER_FUSED_BIAS_ADD_GELU_GRAD_KERNEL(double)\nREGISTER_FUSED_BIAS_ADD_GELU_GRAD_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_FUSED_BIAS_ADD_GELU_GRAD_KERNEL(nv_bfloat16)\n#endif\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_bias_add_scale_mask_softmax_dropout.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/cuda/softmax.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/user/kernels/fused_softmax.cuh\"\n\nnamespace oneflow {\nnamespace cuda {\n\nnamespace {\n\ntemplate<typename IndexType, size_t NDIM>\nstruct BroadcastMapper {\n  using index_type = IndexType;\n  IndexType src_dims[NDIM] = {0};\n  IndexType dst_dims[NDIM] = {0};\n\n  template<typename DimType>\n  BroadcastMapper(const DimType* arg_src_dims, const DimType* arg_dst_dims) {\n    for (size_t i = 0; i < NDIM; ++i) { src_dims[i] = arg_src_dims[i]; }\n    for (size_t i = 0; i < NDIM; ++i) { dst_dims[i] = arg_dst_dims[i]; }\n  }\n\n  __device__ IndexType map(IndexType src) const {\n    NdIndexOffsetHelper<IndexType, NDIM> src_index_helper(src_dims);\n    NdIndexOffsetHelper<IndexType, NDIM> dst_index_helper(dst_dims);\n    IndexType src_index[NDIM];\n    IndexType dst_index[NDIM];\n    src_index_helper.OffsetToNdIndex(src, src_index);\n#pragma unroll\n    for (int dim = 0; dim < NDIM; ++dim) {\n      if (dst_dims[dim] == 1) {\n        dst_index[dim] = 0;\n      } else {\n        dst_index[dim] = src_index[dim];\n      }\n    }\n    return dst_index_helper.NdIndexToOffset(dst_index);\n  }\n};\n\ntemplate<typename IndexType>\nstruct ElementwiseMapper {\n  using index_type = IndexType;\n  ElementwiseMapper() {}\n  __device__ IndexType map(IndexType index) const { return index; }\n};\n\ntemplate<typename SRC, typename DST, typename MASK, typename BiasMapper, typename MaskMapper>\nstruct BiasAddScaleMaskLoad {\n  static_assert(\n      std::is_same<typename BiasMapper::index_type, typename MaskMapper::index_type>::value, \"\");\n  using IndexType = typename BiasMapper::index_type;\n  const SRC* src;\n  const SRC* bias;\n  const MASK* mask;\n  const DST fill;\n  const DST scale;\n  const IndexType row_size;\n  const BiasMapper bias_mapper;\n  const MaskMapper mask_mapper;\n\n  BiasAddScaleMaskLoad(const SRC* src, const SRC* bias, const MASK* mask, const DST fill,\n                       const DST scale, const IndexType row_size, const BiasMapper bias_mapper,\n                       const MaskMapper mask_mapper)\n      : src(src),\n        bias(bias),\n        mask(mask),\n        fill(fill),\n        scale(scale),\n        row_size(row_size),\n        bias_mapper(bias_mapper),\n        mask_mapper(mask_mapper) {}\n\n  template<int N>\n  __device__ void load(DST* dst, IndexType row, IndexType col) {\n    softmax::Pack<SRC, N> src_pack;\n    softmax::Pack<SRC, N> bias_pack;\n    softmax::Pack<MASK, N> mask_pack;\n    const IndexType offset = row * row_size + col;\n    const IndexType bias_offset = bias_mapper.map(offset);\n    const IndexType mask_offset = mask_mapper.map(offset);\n    src_pack.storage = *(reinterpret_cast<const softmax::PackType<SRC, N>*>(src) + offset / N);\n    bias_pack.storage =\n        *(reinterpret_cast<const 
softmax::PackType<SRC, N>*>(bias) + bias_offset / N);\n    mask_pack.storage =\n        *(reinterpret_cast<const softmax::PackType<MASK, N>*>(mask) + mask_offset / N);\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      if (mask_pack.elem[i] == 0) {\n        dst[i] = fill;\n      } else {\n        dst[i] = static_cast<DST>(src_pack.elem[i] + bias_pack.elem[i]) * scale;\n      }\n    }\n  }\n};\n\ntemplate<typename T, typename MASK>\nvoid DispatchForward(cudaStream_t stream, const user_op::Tensor* x, const user_op::Tensor* bias,\n                     const user_op::Tensor* mask, const user_op::Tensor* dropout_mask,\n                     const float mask_fill, const float scale, const float dropout_scale,\n                     user_op::Tensor* y, user_op::Tensor* softmax_y) {\n  using ComputeType = typename softmax::DefaultComputeType<T>::type;\n  using IndexType = int32_t;\n  constexpr int kMaxNDim = 5;\n\n  const auto& x_shape = x->shape_view();\n  CHECK_GE(x_shape.size(), 2);\n  // the last dim is softmax dim which is considered as col\n  int64_t ncol = x_shape[x_shape.size() - 1];\n  int64_t nrow = x_shape.elem_cnt() / ncol;\n  fused_softmax::DropoutStore<ComputeType, T> store(\n      y->mut_dptr<T>(), softmax_y->mut_dptr<T>(), dropout_mask->dptr<bool>(), ncol, dropout_scale);\n\n  size_t bias_sndim = 0;\n  int64_t bias_x_sdims[kMaxNDim];\n  int64_t bias_sdims[kMaxNDim];\n  const auto& bias_shape = bias->shape_view();\n  fused_softmax::SimplifyBroadcastDims(x_shape.size(), x_shape.ptr(), bias_shape.size(),\n                                       bias_shape.ptr(), &bias_sndim, bias_x_sdims, bias_sdims);\n  size_t mask_sndim = 0;\n  int64_t mask_x_sdims[kMaxNDim];\n  int64_t mask_sdims[kMaxNDim];\n  const auto& mask_shape = mask->shape_view();\n  fused_softmax::SimplifyBroadcastDims(x_shape.size(), x_shape.ptr(), mask_shape.size(),\n                                       mask_shape.ptr(), &mask_sndim, mask_x_sdims, mask_sdims);\n\n#define DISPATCH_BIAS_ADD_SCALE_MASK_SOFTMAX(bias_mapper, mask_mapper)                           \\\n  BiasAddScaleMaskLoad<T, ComputeType, MASK, decltype(bias_mapper), decltype(mask_mapper)> load( \\\n      x->dptr<T>(), bias->dptr<T>(), mask->dptr<MASK>(), mask_fill, scale, ncol, bias_mapper,    \\\n      mask_mapper);                                                                              \\\n  OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(   \\\n      stream, load, store, nrow, ncol)))\n\n  if (bias_sndim == 1 && mask_sndim == 1) {\n    // bias elementwise\n    // mask elementwise\n    ElementwiseMapper<IndexType> bias_mapper;\n    ElementwiseMapper<IndexType> mask_mapper;\n    DISPATCH_BIAS_ADD_SCALE_MASK_SOFTMAX(bias_mapper, mask_mapper);\n  } else if (bias_sndim == 1 && mask_sndim == 2) {\n    // bias elementwise\n    // mask broadcast: (M, 1) -> (M, N) or (1, N) -> (M, N)\n    ElementwiseMapper<IndexType> bias_mapper;\n    BroadcastMapper<IndexType, 2> mask_mapper(mask_x_sdims, mask_sdims);\n    DISPATCH_BIAS_ADD_SCALE_MASK_SOFTMAX(bias_mapper, mask_mapper);\n  } else if (bias_sndim == 1 && mask_sndim == 3) {\n    // bias elementwise\n    // mask broadcast: (M, 1, N) -> (M, K, N)\n    ElementwiseMapper<IndexType> bias_mapper;\n    BroadcastMapper<IndexType, 3> mask_mapper(mask_x_sdims, mask_sdims);\n    DISPATCH_BIAS_ADD_SCALE_MASK_SOFTMAX(bias_mapper, mask_mapper);\n  } else if (bias_sndim == 2 && mask_sndim == 1) {\n    // bias broadcast: (M, 1) -> (M, N) or (1, N) -> (M, N)\n    // mask 
elementwise\n    BroadcastMapper<IndexType, 2> bias_mapper(bias_x_sdims, bias_sdims);\n    ElementwiseMapper<IndexType> mask_mapper;\n    DISPATCH_BIAS_ADD_SCALE_MASK_SOFTMAX(bias_mapper, mask_mapper);\n  } else if (bias_sndim == 2 && mask_sndim == 2) {\n    // bias broadcast: (M, 1) -> (M, N) or (1, N) -> (M, N)\n    // mask broadcast: (M, 1) -> (M, N) or (1, N) -> (M, N)\n    BroadcastMapper<IndexType, 2> bias_mapper(bias_x_sdims, bias_sdims);\n    BroadcastMapper<IndexType, 2> mask_mapper(mask_x_sdims, mask_sdims);\n    DISPATCH_BIAS_ADD_SCALE_MASK_SOFTMAX(bias_mapper, mask_mapper);\n  } else if (bias_sndim == 2 && mask_sndim == 3) {\n    // bias broadcast: (M, 1) -> (M, N) or (1, N) -> (M, N)\n    // mask broadcast: (M, 1, N) -> (M, K, N)\n    BroadcastMapper<IndexType, 2> bias_mapper(bias_x_sdims, bias_sdims);\n    BroadcastMapper<IndexType, 3> mask_mapper(mask_x_sdims, mask_sdims);\n    DISPATCH_BIAS_ADD_SCALE_MASK_SOFTMAX(bias_mapper, mask_mapper);\n    // not support for now\n    // } else if (bias_sndim == 3 && mask_sndim == 1) {\n    // } else if (bias_sndim == 3 && mask_sndim == 2) {\n    // } else if (bias_sndim == 3 && mask_sndim == 3) {\n  } else {\n    UNIMPLEMENTED() << \", bias_sndim=\" << bias_sndim << \", mask_sndim=\" << mask_sndim;\n  }\n\n#undef DISPATCH_BIAS_ADD_SCALE_MASK_SOFTMAX\n}\n\ntemplate<typename T, typename MASK>\nclass FusedBiasAddScaleMaskSoftmaxDropoutKernel final : public user_op::OpKernel {\n public:\n  FusedBiasAddScaleMaskSoftmaxDropoutKernel() = default;\n  ~FusedBiasAddScaleMaskSoftmaxDropoutKernel() override = default;\n\n private:\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex(\"bias\", 0);\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    const user_op::Tensor* dropout_mask = ctx->Tensor4ArgNameAndIndex(\"dropout_mask\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* softmax_y = ctx->Tensor4ArgNameAndIndex(\"softmax_y\", 0);\n\n    const float mask_fill = ctx->Attr<float>(\"mask_fill_value\");\n    const float scale = ctx->Attr<float>(\"scale_value\");\n    const float dropout_scale = ctx->Attr<float>(\"dropout_scale_value\");\n\n    const ShapeView& x_shape = x->shape_view();\n    // int32 index computing is much faster than int64\n    // TODO: consider using multiple int32 computing to substitute int64 computing\n    CHECK_LT(x_shape.elem_cnt(), INT_MAX) << \"only support int32 max limits size of elements\";\n    DispatchForward<T, MASK>(ctx->stream()->As<ep::CudaStream>()->cuda_stream(), x, bias, mask,\n                             dropout_mask, mask_fill, scale, dropout_scale, y, softmax_y);\n  }\n};\n\n}  // namespace\n\n#define REGISTER_FUSED_BIAS_ADD_SCALE_MASK_SOFTMAX_DROPOUT_CUDA_KERNEL(dtype, mask_dtype) \\\n  REGISTER_USER_KERNEL(\"fused_bias_add_scale_mask_softmax_dropout\")                       \\\n      .SetCreateFn<FusedBiasAddScaleMaskSoftmaxDropoutKernel<dtype, mask_dtype>>()        \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                    \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)     \\\n                       && (user_op::HobDataType(\"mask\", 0) == 
GetDataType<mask_dtype>::value));\n\nREGISTER_FUSED_BIAS_ADD_SCALE_MASK_SOFTMAX_DROPOUT_CUDA_KERNEL(float, bool)\nREGISTER_FUSED_BIAS_ADD_SCALE_MASK_SOFTMAX_DROPOUT_CUDA_KERNEL(half, bool)\n\n#undef REGISTER_FUSED_BIAS_ADD_SCALE_MASK_SOFTMAX_DROPOUT_CUDA_KERNEL\n\n}  // namespace cuda\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_cast_scale_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename U>\nclass FusedCastScaleCpuKernel final : public user_op::OpKernel {\n public:\n  FusedCastScaleCpuKernel() = default;\n  ~FusedCastScaleCpuKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const double scale_val = ctx->Attr<double>(\"scale\");\n    const int64_t n = x->shape_view().elem_cnt();\n    const T scale = *(scale_by_tensor->dptr<T>()) * scale_val;\n    const U* x_ptr = x->dptr<U>();\n    T* y_ptr = y->mut_dptr<T>();\n    FOR_RANGE(int64_t, i, 0, n) { y_ptr[i] = static_cast<T>(x_ptr[i]) * scale; }\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_CAST_SCALE_CPU_KERNEL(x_type, y_type)                           \\\n  REGISTER_USER_KERNEL(\"fused_cast_scale\")                                             \\\n      .SetCreateFn<FusedCastScaleCpuKernel<y_type, x_type>>()                          \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                  \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<y_type>::value) \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<x_type>::value));\n\nREGISTER_FUSED_CAST_SCALE_CPU_KERNEL(float, double);\nREGISTER_FUSED_CAST_SCALE_CPU_KERNEL(double, float);\n#undef REGISTER_FUSED_CAST_SCALE_CPU_KERNEL\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_cast_scale_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <cuda.h>\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\n#include \"oneflow/core/device/cuda_pseudo_bfloat16.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename U>\n__global__ void FusedCastScaleGpu(const int64_t n, const T scale_val, const U* in,\n                                  const T* scale_by_ptr, T* out) {\n  const T scale = *scale_by_ptr * scale_val;\n  CUDA_1D_KERNEL_LOOP(i, n) { out[i] = static_cast<T>(in[i]) * scale; }\n}\n\ntemplate<>\n__global__ void FusedCastScaleGpu<float, half>(const int64_t n, const float scale_val,\n                                               const half* in, const float* scale_by_ptr,\n                                               float* out) {\n  const float scale = *scale_by_ptr * scale_val;\n  const int64_t n_2 = n / 2;\n  const auto* in_2 = reinterpret_cast<const half2*>(in);\n  auto* out_2 = reinterpret_cast<float2*>(out);\n  CUDA_1D_KERNEL_LOOP(i, n_2) {\n    float2 f2 = __half22float2(in_2[i]);\n    f2.x *= scale;\n    f2.y *= scale;\n    out_2[i] = f2;\n  }\n  if (n % 2 == 1 && blockIdx.x == 0 && threadIdx.x == 0) {\n    out[n - 1] = __half2float(in[n - 1]) * scale;\n  }\n}\n\ntemplate<>\n__global__ void FusedCastScaleGpu<half, float>(const int64_t n, const half scale_val,\n                                               const float* in, const half* scale_by_ptr,\n                                               half* out) {\n  const half scale = *scale_by_ptr * scale_val;\n  const half2 scale_h2 = __half2half2(scale);\n  const int64_t n_2 = n / 2;\n  const auto* in_2 = reinterpret_cast<const float2*>(in);\n  auto* out_h2 = reinterpret_cast<half2*>(out);\n  CUDA_1D_KERNEL_LOOP(i, n_2) {\n    half2 in_h2 = __float22half2_rn(in_2[i]);\n    out_h2[i] = __hmul2(in_h2, scale_h2);\n  }\n  if (n % 2 == 1 && blockIdx.x == 0 && threadIdx.x == 0) {\n    out[n - 1] = __float2half(in[n - 1]) * scale;\n  }\n}\n\n#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 800\ntemplate<>\n__global__ void FusedCastScaleGpu<float, nv_bfloat16>(const int64_t n, const float scale_val,\n                                                      const nv_bfloat16* in,\n                                                      const float* scale_by_ptr, float* out) {\n  const float scale = *scale_by_ptr * scale_val;\n  const int64_t n_2 = n / 2;\n  const auto* in_2 = reinterpret_cast<const nv_bfloat162*>(in);\n  auto* out_2 = reinterpret_cast<float2*>(out);\n  CUDA_1D_KERNEL_LOOP(i, n_2) {\n    float2 f2 = __bfloat1622float2(in_2[i]);\n    f2.x *= scale;\n    f2.y *= scale;\n    out_2[i] = f2;\n  }\n  if (n % 2 == 1 && blockIdx.x == 0 && threadIdx.x == 0) {\n    out[n - 1] = __bfloat162float(in[n - 1]) * 
scale;\n  }\n}\n\ntemplate<>\n__global__ void FusedCastScaleGpu<nv_bfloat16, float>(const int64_t n, const nv_bfloat16 scale_val,\n                                                      const float* in,\n                                                      const nv_bfloat16* scale_by_ptr,\n                                                      nv_bfloat16* out) {\n  const nv_bfloat16 scale = *scale_by_ptr * scale_val;\n  const nv_bfloat162 scale_h2 = __bfloat162bfloat162(scale);\n  const int64_t n_2 = n / 2;\n  const auto* in_2 = reinterpret_cast<const float2*>(in);\n  auto* out_h2 = reinterpret_cast<nv_bfloat162*>(out);\n  CUDA_1D_KERNEL_LOOP(i, n_2) {\n    nv_bfloat162 in_h2 = __float22bfloat162_rn(in_2[i]);\n    out_h2[i] = __hmul2(in_h2, scale_h2);\n  }\n  if (n % 2 == 1 && blockIdx.x == 0 && threadIdx.x == 0) {\n    out[n - 1] = __float2bfloat16(in[n - 1]) * scale;\n  }\n}\n#endif\n\ntemplate<typename T, typename U>\nclass FusedCastScaleGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  FusedCastScaleGpuKernel() = default;\n  ~FusedCastScaleGpuKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const int64_t n = x->shape_view().elem_cnt();\n    const double scale = ctx->Attr<double>(\"scale\");\n    const bool use_pack =\n        (x->data_type() == DataType::kFloat\n         && (y->data_type() == DataType::kFloat16 || y->data_type() == DataType::kBFloat16))\n        || (y->data_type() == DataType::kFloat\n            && (x->data_type() == DataType::kFloat16 || x->data_type() == DataType::kBFloat16));\n    const int64_t launch_n = use_pack ? RoundUp(n, 2) / 2 : n;\n    FusedCastScaleGpu<T, U><<<BlocksNum4ThreadsNum(launch_n), kCudaThreadsNumPerBlock, 0,\n                              ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        n, static_cast<T>(scale), x->dptr<U>(), scale_by_tensor->dptr<T>(), y->mut_dptr<T>());\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n#define REGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(x_type, y_type)                          \\\n  REGISTER_USER_KERNEL(\"fused_cast_scale\")                                             \\\n      .SetCreateFn<FusedCastScaleGpuKernel<y_type, x_type>>()                          \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                 \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<y_type>::value) \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<x_type>::value));\n\nREGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(half, float);\nREGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(half, double);\nREGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(float, half);\nREGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(float, double);\nREGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(double, half);\nREGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(double, float);\n#if CUDA_VERSION >= 11000\nREGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(nv_bfloat16, float);\nREGISTER_FUSED_CAST_SCALE_CUDA_KERNEL(float, nv_bfloat16);\n#endif\n#undef REGISTER_FUSED_CAST_SCALE_CUDA_KERNEL\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_center_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstruct FusedCenterForwardFunctor {\n  __device__ T Compute(T b_x_delta, T b_y_delta) const {\n    return (b_x_delta * b_x_delta + b_y_delta * b_y_delta) / static_cast<T>(4.0);\n  }\n};\n\ntemplate<>\nstruct FusedCenterForwardFunctor<half> {\n  FusedCenterForwardFunctor<float> float_functor;\n  __device__ half Compute(half b_x_delta, half b_y_delta) const {\n    return __float2half(float_functor.Compute(__half2float(b_x_delta), __half2float(b_y_delta)));\n  }\n};\n\ntemplate<typename FUNCTOR, typename T>\n__global__ void FusedCenterForward(FUNCTOR functor, const int n, const T* b1_x1, const T* b1_x2,\n                                   const T* b2_x1, const T* b2_x2, const T* b1_y1, const T* b1_y2,\n                                   const T* b2_y1, const T* b2_y2, T* rho) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T b_x_delta = (b2_x1[i] + b2_x2[i] - b1_x1[i] - b1_x2[i]);\n    const T b_y_delta = (b2_y1[i] + b2_y2[i] - b1_y1[i] - b1_y2[i]);\n    rho[i] = functor.Compute(b_x_delta, b_y_delta);\n  }\n}\n\ntemplate<typename T>\n__global__ void FusedCenterBackward(const int n, const T* b1_x1, const T* b1_x2, const T* b2_x1,\n                                    const T* b2_x2, const T* b1_y1, const T* b1_y2, const T* b2_y1,\n                                    const T* b2_y2, const T* rho2_diff, T* b1_x1_diff,\n                                    T* b1_x2_diff, T* b2_x1_diff, T* b2_x2_diff, T* b1_y1_diff,\n                                    T* b1_y2_diff, T* b2_y1_diff, T* b2_y2_diff) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T rho2_diff_i_2 = rho2_diff[i] / static_cast<T>(2.0);\n\n    const T b_x_diff = rho2_diff_i_2 * (b1_x1[i] + b1_x2[i] - b2_x1[i] - b2_x2[i]);\n    const T b_y_diff = rho2_diff_i_2 * (b1_y1[i] + b1_y2[i] - b2_y1[i] - b2_y2[i]);\n\n    b1_x1_diff[i] = b_x_diff;\n    b1_x2_diff[i] = b_x_diff;\n    b2_x1_diff[i] = b_x_diff * static_cast<T>(-1.0);\n    b2_x2_diff[i] = b_x_diff * static_cast<T>(-1.0);\n\n    b1_y1_diff[i] = b_y_diff;\n    b1_y2_diff[i] = b_y_diff;\n    b2_y1_diff[i] = b_y_diff * static_cast<T>(-1.0);\n    b2_y2_diff[i] = b_y_diff * static_cast<T>(-1.0);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass FusedCenterKernel final : public user_op::OpKernel {\n public:\n  FusedCenterKernel() = default;\n  ~FusedCenterKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* b1_x1 = ctx->Tensor4ArgNameAndIndex(\"b1_x1\", 0);\n    const user_op::Tensor* b1_x2 = ctx->Tensor4ArgNameAndIndex(\"b1_x2\", 0);\n    const user_op::Tensor* b2_x1 = ctx->Tensor4ArgNameAndIndex(\"b2_x1\", 0);\n    const user_op::Tensor* b2_x2 = 
ctx->Tensor4ArgNameAndIndex(\"b2_x2\", 0);\n    const user_op::Tensor* b1_y1 = ctx->Tensor4ArgNameAndIndex(\"b1_y1\", 0);\n    const user_op::Tensor* b1_y2 = ctx->Tensor4ArgNameAndIndex(\"b1_y2\", 0);\n    const user_op::Tensor* b2_y1 = ctx->Tensor4ArgNameAndIndex(\"b2_y1\", 0);\n    const user_op::Tensor* b2_y2 = ctx->Tensor4ArgNameAndIndex(\"b2_y2\", 0);\n\n    user_op::Tensor* rho = ctx->Tensor4ArgNameAndIndex(\"rho2\", 0);\n\n    const int64_t elem_cnt = b1_x1->shape_view().elem_cnt();\n\n    FusedCenterForwardFunctor<T> fused_center_forward_functor{};\n\n    RUN_CUDA_KERNEL((FusedCenterForward<decltype(fused_center_forward_functor), T>), ctx->stream(),\n                    elem_cnt, fused_center_forward_functor, elem_cnt, b1_x1->dptr<T>(),\n                    b1_x2->dptr<T>(), b2_x1->dptr<T>(), b2_x2->dptr<T>(), b1_y1->dptr<T>(),\n                    b1_y2->dptr<T>(), b2_y1->dptr<T>(), b2_y2->dptr<T>(), rho->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_GET_CENTER_DIST_CUDA_KERNEL(dtype)              \\\n  REGISTER_USER_KERNEL(\"fused_get_center_dist\")                        \\\n      .SetCreateFn<FusedCenterKernel<dtype>>()                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"rho2\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_GET_CENTER_DIST_CUDA_KERNEL(float)\nREGISTER_FUSED_GET_CENTER_DIST_CUDA_KERNEL(double)\nREGISTER_FUSED_GET_CENTER_DIST_CUDA_KERNEL(half)\n\ntemplate<typename T>\nclass FusedCenterGradKernel final : public user_op::OpKernel {\n public:\n  FusedCenterGradKernel() = default;\n  ~FusedCenterGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* b1_x1 = ctx->Tensor4ArgNameAndIndex(\"b1_x1\", 0);\n    const user_op::Tensor* b1_x2 = ctx->Tensor4ArgNameAndIndex(\"b1_x2\", 0);\n    const user_op::Tensor* b2_x1 = ctx->Tensor4ArgNameAndIndex(\"b2_x1\", 0);\n    const user_op::Tensor* b2_x2 = ctx->Tensor4ArgNameAndIndex(\"b2_x2\", 0);\n    const user_op::Tensor* b1_y1 = ctx->Tensor4ArgNameAndIndex(\"b1_y1\", 0);\n    const user_op::Tensor* b1_y2 = ctx->Tensor4ArgNameAndIndex(\"b1_y2\", 0);\n    const user_op::Tensor* b2_y1 = ctx->Tensor4ArgNameAndIndex(\"b2_y1\", 0);\n    const user_op::Tensor* b2_y2 = ctx->Tensor4ArgNameAndIndex(\"b2_y2\", 0);\n    const user_op::Tensor* rho2_diff = ctx->Tensor4ArgNameAndIndex(\"rho2_diff\", 0);\n\n    user_op::Tensor* b1_x1_diff = ctx->Tensor4ArgNameAndIndex(\"b1_x1_diff\", 0);\n    user_op::Tensor* b1_x2_diff = ctx->Tensor4ArgNameAndIndex(\"b1_x2_diff\", 0);\n    user_op::Tensor* b2_x1_diff = ctx->Tensor4ArgNameAndIndex(\"b2_x1_diff\", 0);\n    user_op::Tensor* b2_x2_diff = ctx->Tensor4ArgNameAndIndex(\"b2_x2_diff\", 0);\n    user_op::Tensor* b1_y1_diff = ctx->Tensor4ArgNameAndIndex(\"b1_y1_diff\", 0);\n    user_op::Tensor* b1_y2_diff = ctx->Tensor4ArgNameAndIndex(\"b1_y2_diff\", 0);\n    user_op::Tensor* b2_y1_diff = ctx->Tensor4ArgNameAndIndex(\"b2_y1_diff\", 0);\n    user_op::Tensor* b2_y2_diff = ctx->Tensor4ArgNameAndIndex(\"b2_y2_diff\", 0);\n\n    const int64_t elem_cnt = b1_x1_diff->shape_view().elem_cnt();\n\n    RUN_CUDA_KERNEL((FusedCenterBackward<T>), ctx->stream(), elem_cnt, elem_cnt, b1_x1->dptr<T>(),\n                    b1_x2->dptr<T>(), b2_x1->dptr<T>(), b2_x2->dptr<T>(), b1_y1->dptr<T>(),\n                    
b1_y2->dptr<T>(), b2_y1->dptr<T>(), b2_y2->dptr<T>(), rho2_diff->dptr<T>(),\n                    b1_x1_diff->mut_dptr<T>(), b1_x2_diff->mut_dptr<T>(), b2_x1_diff->mut_dptr<T>(),\n                    b2_x2_diff->mut_dptr<T>(), b1_y1_diff->mut_dptr<T>(), b1_y2_diff->mut_dptr<T>(),\n                    b2_y1_diff->mut_dptr<T>(), b2_y2_diff->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_GET_CENTER_DIST_GRAD_CUDA_KERNEL(dtype)         \\\n  REGISTER_USER_KERNEL(\"fused_get_center_dist_grad\")                   \\\n      .SetCreateFn<FusedCenterGradKernel<dtype>>()                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"b1_x1\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_GET_CENTER_DIST_GRAD_CUDA_KERNEL(float)\nREGISTER_FUSED_GET_CENTER_DIST_GRAD_CUDA_KERNEL(double)\nREGISTER_FUSED_GET_CENTER_DIST_GRAD_CUDA_KERNEL(half)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_clip_grad.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/user/kernels/fused_clip_grad.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr int64_t kMultiReduceScaleMulPackSize = 64;\n\ntemplate<typename T>\nstruct MultiClipGradParamPack {\n  MultiClipGradParam<T> params[kMultiReduceScaleMulPackSize];\n  size_t size;\n};\n\nsize_t InferFusedClipGradTempStorageSize(user_op::InferContext* ctx) {\n  auto input_size = ctx->input_size(\"model_diff\");\n  if (input_size == 0) { return 0; }\n  int64_t max_elem_cnt = 0;\n  int64_t pack_size = 0;\n  int32_t num_blocks = 0;\n  for (size_t i = 0; i < input_size; ++i) {\n    int64_t elem_cnt = ctx->InputShape(\"model_diff\", i).elem_cnt();\n    max_elem_cnt = std::max(max_elem_cnt, elem_cnt);\n    pack_size++;\n    if (pack_size == kMultiReduceScaleMulPackSize || i == input_size - 1) {\n      CHECK_LT(max_elem_cnt, std::numeric_limits<int32_t>::max());\n      num_blocks += BlocksNum4ThreadsNum(static_cast<int32_t>(max_elem_cnt));\n      max_elem_cnt = 0;\n      pack_size = 0;\n    }\n  }\n  CHECK_LT(num_blocks, kCudaThreadsNumPerBlock * kCudaThreadsNumPerBlock * kCudaThreadsNumPerBlock)\n      << \"Too much blocks needed for computing \" << ctx->op_name() << \", should be less than \"\n      << kCudaThreadsNumPerBlock << \"*\" << kCudaThreadsNumPerBlock << \"*\" << kCudaThreadsNumPerBlock\n      << \", but got \" << num_blocks;\n  size_t elem_size = GetSizeOfDataType(ctx->InputDType(\"model_diff\", 0));\n  return GetCudaAlignedSize(num_blocks * elem_size * 2);\n}\n\ntemplate<typename T>\n__global__ void MultiBlockClipGradGpu(MultiClipGradParamPack<T> pack_params, T* scale,\n                                      const float norm_type, const float max_norm,\n                                      const ClipGradType clip_grad_type,\n                                      const bool scale_writable) {\n  T t = *scale;\n  if (clip_grad_type == ClipGradType::ZeroType) {\n    t = static_cast<T>(t > 0);\n  } else if (clip_grad_type == ClipGradType::PowerType) {\n    t = std::pow(t, 1.f / norm_type);\n  }\n  if (scale_writable && blockDim.x * blockIdx.x + threadIdx.x == 0) { *scale = t; }\n  t = max_norm / (t + 1e-6);\n  if (t >= 1.) 
{ return; }\n  for (int i = 0; i < pack_params.size; ++i) {\n    auto& param = pack_params.params[i];\n    CUDA_1D_KERNEL_LOOP(j, param.size) { param.data[j] *= t; }\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nstruct MultiClipGrad<DeviceType::kCUDA, T> {\n  void operator()(ep::Stream* stream, std::vector<MultiClipGradParam<T>>& params, T* scale,\n                  const float norm_type, const float max_norm, const ClipGradType clip_grad_type) {\n    int32_t total_num_blocks = 0;\n    for (size_t i = 0; i < params.size(); i += kMultiReduceScaleMulPackSize) {\n      MultiClipGradParamPack<T> pack_params{};\n      size_t max_elem_cnt = 0;\n      pack_params.size = std::min<size_t>(kMultiReduceScaleMulPackSize, params.size() - i);\n      for (size_t j = 0; j < pack_params.size; ++j) {\n        pack_params.params[j] = params[i + j];\n        max_elem_cnt = std::max<size_t>(max_elem_cnt, pack_params.params[j].size);\n      }\n      int32_t num_blocks = BlocksNum4ThreadsNum(max_elem_cnt);\n      bool scale_writable = static_cast<bool>(i + kMultiReduceScaleMulPackSize >= params.size());\n      MultiBlockClipGradGpu<T>\n          <<<num_blocks, kCudaThreadsNumPerBlock, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              pack_params, scale, norm_type, max_norm, clip_grad_type, scale_writable);\n      total_num_blocks += num_blocks;\n    }\n  }\n};\n\n#define REGISTER_FUSED_CLIP_GRAD_KERNEL(device, dtype)                                         \\\n  REGISTER_USER_KERNEL(\"fused_clip_grad\")                                                      \\\n      .SetCreateFn<FusedClipGradKernel<device, dtype>>()                                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                    \\\n                       && (user_op::HobDataType(\"model_diff\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value))       \\\n      .SetInferTmpSizeFn(InferFusedClipGradTempStorageSize);\n\nREGISTER_FUSED_CLIP_GRAD_KERNEL(DeviceType::kCUDA, float);\nREGISTER_FUSED_CLIP_GRAD_KERNEL(DeviceType::kCUDA, double);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_clip_grad.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_FUSED_CLIP_GRAD_H_\n#define ONEFLOW_USER_KERNELS_FUSED_CLIP_GRAD_H_\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/user/kernels/multi_reduce_kernel_util.h\"\n#include \"oneflow/user/kernels/fused_clip_grad_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nclass FusedClipGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  FusedClipGradKernel() = default;\n  ~FusedClipGradKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    T* out_ptr = out->mut_dptr<T>();\n    T* temp = (ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0))->mut_dptr<T>();\n    const int32_t input_size = ctx->input_size(\"model_diff\");\n    const float max_norm = ctx->Attr<float>(\"max_norm\");\n    const float norm_type = ctx->Attr<float>(\"norm_type\");\n\n    std::vector<MultiReduceParam<T>> params;\n    params.resize(input_size);\n    for (size_t i = 0; i < input_size; ++i) {\n      const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"model_diff\", i);\n      params[i].size = x->shape_view().elem_cnt();\n      params[i].data = x->dptr<T>();\n    }\n    if (norm_type == 0) {\n      PowByZero<T> func{};\n      MultiReduce<device_type, T, decltype(func), BinaryAdd<T>> reduce_add{};\n      reduce_add(ctx->stream(), func, params, GetZeroVal<T>(), out_ptr, temp);\n    } else if (norm_type == INFINITY) {\n      Abs<T> func{};\n      MultiReduce<device_type, T, decltype(func), BinaryMax<T>> reduce_max{};\n      reduce_max(ctx->stream(), func, params, GetZeroVal<T>(), out_ptr, temp);\n    } else if (norm_type == -INFINITY) {\n      Abs<T> func{};\n      MultiReduce<device_type, T, decltype(func), BinaryMin<T>> reduce_min{};\n      reduce_min(ctx->stream(), func, params, std::numeric_limits<T>::max(), out_ptr, temp);\n    } else if (norm_type == 1) {\n      Abs<T> func{};\n      MultiReduce<device_type, T, decltype(func), BinaryAdd<T>> reduce_sum{};\n      reduce_sum(ctx->stream(), func, params, GetZeroVal<T>(), out_ptr, temp);\n    } else if (norm_type == 2) {\n      Square<T> func{};\n      MultiReduce<device_type, T, decltype(func), BinaryAdd<T>> reduce_sum{};\n      reduce_sum(ctx->stream(), func, params, GetZeroVal<T>(), out_ptr, temp);\n    } else {\n      AbsPow<T> func{norm_type};\n      MultiReduce<device_type, T, decltype(func), BinaryAdd<T>> reduce_sum{};\n      reduce_sum(ctx->stream(), func, params, GetZeroVal<T>(), out_ptr, temp);\n    }\n\n    std::vector<MultiClipGradParam<T>> mut_params;\n    mut_params.resize(input_size);\n    for (size_t i = 0; i < input_size; ++i) {\n      user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"model_diff\", i);\n      
mut_params[i].size = x->shape_view().elem_cnt();\n      mut_params[i].data = x->mut_dptr<T>();\n    }\n    MultiClipGrad<device_type, T> multi_clip_grad{};\n    if (norm_type == 0) {\n      multi_clip_grad(ctx->stream(), mut_params, out_ptr, norm_type, max_norm,\n                      ClipGradType::ZeroType);\n    } else if (std::abs(norm_type) == INFINITY || norm_type == 1) {\n      multi_clip_grad(ctx->stream(), mut_params, out_ptr, norm_type, max_norm,\n                      ClipGradType::OtherType);\n    } else {\n      multi_clip_grad(ctx->stream(), mut_params, out_ptr, norm_type, max_norm,\n                      ClipGradType::PowerType);\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_FUSED_CLIP_GRAD_H_\n"
  },
  {
    "path": "oneflow/user/kernels/fused_clip_grad_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_FUSED_CLIP_GRAD_UTIL_H_\n#define ONEFLOW_USER_KERNELS_FUSED_CLIP_GRAD_UTIL_H_\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/common/device_type.pb.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstruct MultiClipGradParam {\n  T* data;\n  size_t size;\n};\n\nenum ClipGradType : int {\n  ZeroType,\n  PowerType,\n  OtherType,\n};\n\ntemplate<DeviceType device_type, typename T>\nstruct MultiClipGrad {\n  void operator()(ep::Stream* stream, std::vector<MultiClipGradParam<T>>& params, T* scale,\n                  const float norm_type, const float max_norm, const ClipGradType clip_grad_type);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_FUSED_CLIP_GRAD_UTIL_H_"
  },
  {
    "path": "oneflow/user/kernels/fused_codegeex_qkv_reshape_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cassert>\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, int pack_size>\nstruct alignas(sizeof(T) * pack_size) Packed {\n  __device__ Packed(T val) {\n#pragma unroll\n    for (int i = 0; i < pack_size; i++) { elem[i] = val; }\n  }\n  __device__ Packed() {\n    // do nothing\n  }\n  union {\n    T elem[pack_size];\n  };\n  __device__ void operator=(Packed<T, pack_size> packA) {\n#pragma unroll\n    for (int i = 0; i < pack_size; i++) { elem[i] = packA.elem[i]; }\n  }\n};\n\n// [seq_length, batch_size, hidden_size] -> [seq_length, batch_size, head_num, size_per_head]\ntemplate<typename T, int pack_size>\n__global__ void batch_reshape_for_qkv(const int n, const T* query, const T* key, const T* value,\n                                      T* new_query, T* new_key, T* new_value) {\n  const auto* query_pack_ptr = reinterpret_cast<const Packed<T, pack_size>*>(query);\n  const auto* key_pack_ptr = reinterpret_cast<const Packed<T, pack_size>*>(key);\n  const auto* value_pack_ptr = reinterpret_cast<const Packed<T, pack_size>*>(value);\n  auto* new_query_pack_ptr = reinterpret_cast<Packed<T, pack_size>*>(new_query);\n  auto* new_key_pack_ptr = reinterpret_cast<Packed<T, pack_size>*>(new_key);\n  auto* new_value_pack_ptr = reinterpret_cast<Packed<T, pack_size>*>(new_value);\n  assert(n % pack_size == 0);\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    Packed<T, pack_size> query_pack = query_pack_ptr[i];\n    Packed<T, pack_size> key_pack = key_pack_ptr[i];\n    Packed<T, pack_size> value_pack = value_pack_ptr[i];\n    new_query_pack_ptr[i] = query_pack;\n    new_key_pack_ptr[i] = key_pack;\n    new_value_pack_ptr[i] = value_pack;\n  }\n}\n\n};  // namespace\n\ntemplate<typename T>\nclass FusedCodegeexQkvReshapeGpuKernel final : public user_op::OpKernel {\n public:\n  FusedCodegeexQkvReshapeGpuKernel() = default;\n  ~FusedCodegeexQkvReshapeGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    // [seq_length, batch_size, hidden_size] -> [seq_length, batch_size, head_num, size_per_head]\n    const user_op::Tensor* query = ctx->Tensor4ArgNameAndIndex(\"query\", 0);\n    const user_op::Tensor* key = ctx->Tensor4ArgNameAndIndex(\"key\", 0);\n    const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex(\"value\", 0);\n\n    user_op::Tensor* new_query = ctx->Tensor4ArgNameAndIndex(\"new_query\", 0);\n    user_op::Tensor* new_key = ctx->Tensor4ArgNameAndIndex(\"new_key\", 0);\n    user_op::Tensor* new_value = ctx->Tensor4ArgNameAndIndex(\"new_value\", 0);\n\n    const int32_t n = query->shape_view().elem_cnt();\n    if (n % 4 == 0) {\n      RUN_CUDA_KERNEL((batch_reshape_for_qkv<T, 4>), ctx->stream(), n / 4, n / 4, query->dptr<T>(),\n                      key->dptr<T>(), value->dptr<T>(), 
new_query->mut_dptr<T>(),\n                      new_key->mut_dptr<T>(), new_value->mut_dptr<T>());\n    } else if (n % 2 == 0) {\n      RUN_CUDA_KERNEL((batch_reshape_for_qkv<T, 2>), ctx->stream(), n / 2, n / 2, query->dptr<T>(),\n                      key->dptr<T>(), value->dptr<T>(), new_query->mut_dptr<T>(),\n                      new_key->mut_dptr<T>(), new_value->mut_dptr<T>());\n    } else {\n      RUN_CUDA_KERNEL((batch_reshape_for_qkv<T, 1>), ctx->stream(), n, n, query->dptr<T>(),\n                      key->dptr<T>(), value->dptr<T>(), new_query->mut_dptr<T>(),\n                      new_key->mut_dptr<T>(), new_value->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_CODEGEEX_QKV_RESHAPE_CUDA_KERNEL(dtype)         \\\n  REGISTER_USER_KERNEL(\"fused_codegeex_qkv_reshape\")                   \\\n      .SetCreateFn<FusedCodegeexQkvReshapeGpuKernel<dtype>>()          \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"query\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_CODEGEEX_QKV_RESHAPE_CUDA_KERNEL(float)\nREGISTER_FUSED_CODEGEEX_QKV_RESHAPE_CUDA_KERNEL(half)\nREGISTER_FUSED_CODEGEEX_QKV_RESHAPE_CUDA_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_cross_feature_interaction.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nenum InteractionMode { kVector = 0, kMatrix };\n\nconstexpr int kBlockSize = 256;\n\nvoid InferMatmulMNK(const ShapeView& a_shape, const ShapeView& b_shape, bool transpose_a,\n                    bool transpose_b, size_t* m, size_t* n, size_t* k) {\n  const int64_t num_a_axes = a_shape.NumAxes();\n  CHECK_GE(num_a_axes, 2);\n  const int64_t num_b_axes = b_shape.NumAxes();\n  CHECK_GE(num_b_axes, 2);\n  if (!transpose_a) {\n    *m = a_shape.At(num_a_axes - 2);\n    *k = a_shape.At(num_a_axes - 1);\n  } else {\n    *m = a_shape.At(num_a_axes - 1);\n    *k = a_shape.At(num_a_axes - 2);\n  }\n  if (!transpose_b) {\n    CHECK_EQ(b_shape.At(num_b_axes - 2), *k);\n    *n = b_shape.At(num_b_axes - 1);\n  } else {\n    CHECK_EQ(b_shape.At(num_b_axes - 1), *k);\n    *n = b_shape.At(num_b_axes - 2);\n  }\n}\n\nep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) {\n  return transpose ? 
ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N;\n}\n\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitive(DeviceType device_type,\n                                                          DataType data_type, bool transpose_a,\n                                                          bool transpose_b) {\n  const auto trans_a = GetBlasTransposeType(transpose_a);\n  const auto trans_b = GetBlasTransposeType(transpose_b);\n  return ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(device_type, data_type, trans_a,\n                                                                   trans_b);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false,\n                            /*transpose_b=*/true);\n}\n\nauto MatmulPrimitiveExists() {\n  return hob::make_custom(\"MatmulPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewMatmulPrimitive(&ctx).operator bool();\n  });\n}\n\ntemplate<typename T, typename IndexType, int pack_size, InteractionMode mode>\n__global__ void FusedBiasAddMulAddResidualKernel(const T* in, const T* x, const T* x0,\n                                                 const T* bias, T* out, const IndexType cols,\n                                                 const IndexType elem_cnt) {\n  const IndexType global_thread_id = blockDim.x * blockIdx.x + threadIdx.x;\n  using LoadPack = cuda::elementwise::Packed<T, pack_size>;\n  for (IndexType linear_index = global_thread_id * pack_size,\n                 step = gridDim.x * blockDim.x * pack_size;\n       linear_index < elem_cnt; linear_index += step) {\n    const IndexType row_idx = linear_index / cols;\n    const IndexType col_idx = linear_index - row_idx * cols;\n\n    const LoadPack* x0_load = reinterpret_cast<const LoadPack*>(x0 + linear_index);\n    const LoadPack* x_load = reinterpret_cast<const LoadPack*>(x + linear_index);\n    const LoadPack* bias_load = reinterpret_cast<const LoadPack*>(bias + col_idx);\n\n    LoadPack x0_vec = *x0_load;\n    LoadPack x_vec = *x_load;\n    LoadPack bias_vec = *bias_load;\n\n    LoadPack out_store;\n    if (mode == InteractionMode::kVector) {\n      T in_val = in[row_idx];\n#pragma unroll\n      for (int i = 0; i < pack_size; i++) {\n        out_store.elem[i] = x0_vec.elem[i] * in_val + bias_vec.elem[i] + x_vec.elem[i];\n      }\n    } else if (mode == InteractionMode::kMatrix) {\n      const LoadPack* in_load = reinterpret_cast<const LoadPack*>(in + linear_index);\n      LoadPack in_vec = *in_load;\n#pragma unroll\n      for (int i = 0; i < pack_size; i++) {\n        out_store.elem[i] = (in_vec.elem[i] + bias_vec.elem[i]) * x0_vec.elem[i] + x_vec.elem[i];\n      }\n    } else {\n      __trap();\n    }\n    *(reinterpret_cast<LoadPack*>(out + linear_index)) = out_store;\n  }\n}\n\ntemplate<typename T>\nint GetLaunchPackSize(const int64_t cols) {\n  constexpr int type_pack_size = cuda::elementwise::PackSize<T>();\n  for (int launch_pack_size = 8; launch_pack_size > 0; launch_pack_size /= 2) {\n    if (type_pack_size >= launch_pack_size && cols % launch_pack_size == 0) {\n      return launch_pack_size;\n    }\n  }\n  return 1;\n}\n\ntemplate<typename T, typename IndexType, InteractionMode mode>\nvoid DispatchFusedBiasAddMulAddResidualPackSize(ep::Stream* stream, const T* in, const T* x,\n                      
                          const T* x0, const T* bias, T* out,\n                                                const IndexType cols, const IndexType elem_cnt) {\n  int grid_size;\n  const int pack_size = GetLaunchPackSize<T>(cols);\n  const int64_t pack_num = elem_cnt / pack_size;\n  OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(pack_num, &grid_size));\n  if (pack_size == 8) {\n    FusedBiasAddMulAddResidualKernel<T, IndexType, 8, mode>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            in, x, x0, bias, out, cols, elem_cnt);\n  } else if (pack_size == 4) {\n    FusedBiasAddMulAddResidualKernel<T, IndexType, 4, mode>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            in, x, x0, bias, out, cols, elem_cnt);\n  } else if (pack_size == 2) {\n    FusedBiasAddMulAddResidualKernel<T, IndexType, 2, mode>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            in, x, x0, bias, out, cols, elem_cnt);\n  } else {\n    FusedBiasAddMulAddResidualKernel<T, IndexType, 1, mode>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            in, x, x0, bias, out, cols, elem_cnt);\n  }\n}\n\ntemplate<typename T, InteractionMode mode>\nvoid DispatchFusedBiasAddMulAddResidualIndexType(ep::Stream* stream, const T* in, const T* x,\n                                                 const T* x0, const T* bias, T* out,\n                                                 const int64_t cols, const int64_t elem_cnt) {\n  if (elem_cnt < GetMaxVal<int32_t>()) {\n    DispatchFusedBiasAddMulAddResidualPackSize<T, int32_t, mode>(stream, in, x, x0, bias, out, cols,\n                                                                 elem_cnt);\n  } else {\n    DispatchFusedBiasAddMulAddResidualPackSize<T, int64_t, mode>(stream, in, x, x0, bias, out, cols,\n                                                                 elem_cnt);\n  }\n}\n\ntemplate<typename T>\nclass FusedCrossFeatureInteractionKernel final : public user_op::OpKernel,\n                                                 public user_op::CudaGraphSupport {\n public:\n  FusedCrossFeatureInteractionKernel() = default;\n  ~FusedCrossFeatureInteractionKernel() override = default;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    /*\n    Cross Interaction v1:\n    1. x matmul weight. matmul_result0 -> (B, E) matmul (1, E) -> (B, 1)\n       dx = dmatmul_result0 matmul weight\n       dw = x matmul dmatmul_result0\n\n    2. matmul_result0 broadcast_mul x0. matmul_result1 -> (B, 1) broadcast_mul (B, E) -> (B, E)\n       dmatmul_result0 = reduce_sum(dmatmul_result1 * x0, axis=1)\n       dx0 = dmatmul_result1 broadcast_mul matmul_result0\n\n    3. matmul_result1 broadcast_add bias. matmul_result2 -> (B, E) broadcast_add (1, E) -> (B, E)\n       dmatmul_result1 = dout\n       dbias = reduce_sum(dmatmul_result2, axis=0)\n\n    4. matmul_result2 add x. out -> (B, E) elementwise_add (B, E) -> (B, E)\n       dmatmul_result2 = dout, dx = dout.\n\n    Cross Interaction Grad:\n    dw = x matmul dmatmul_result0\n    dx0 = dmatmul_result1 broadcast_mul matmul_result0\n    dbias = reduce_sum(dmatmul_result2, axis=0)\n    dx = (dmatmul_result0 matmul weight) + dout.\n\n    Cross Interaction v2:\n    1. x matmul weight. 
matmul_result0 -> (B, E) matmul (E, E) -> (B, E)\n\n    2. matmul_result0 add bias. matmul_result1 -> (B, E) bias_add (1, E) -> (B, E)\n\n    3. matmul_result1 multiply x0. matmul_result2 -> (B, E) elementwise_mul (B, E) -> (B, E)\n\n    4. matmul_result2 add x. out -> (B, E) elementwise_add (B, E) -> (B, E)\n\n    */\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const user_op::Tensor* x0 = ctx->Tensor4ArgNameAndIndex(\"x0\", 0);\n    const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex(\"bias\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* matmul_result = ctx->Tensor4ArgNameAndIndex(\"matmul_result\", 0);\n    const std::string interaction_mode = ctx->Attr<std::string>(\"interaction_mode\");\n\n    CHECK_EQ(out->shape_view().NumAxes(), 2);\n    size_t m = 0, n = 0, k = 0;\n    InferMatmulMNK(x->shape_view(), weight->shape_view(), /*trans_a=*/false, /*trans_b=*/true, &m,\n                   &n, &k);\n    const double alpha = 1.0;\n    double beta = 0.0;\n    auto matmul = NewMatmulPrimitive(ctx);\n    CHECK(matmul);\n    matmul->Launch(ctx->stream(), m, n, k, alpha, x->dptr(), weight->dptr(), beta,\n                   matmul_result->mut_dptr());\n    const int64_t elem_cnt = out->shape_view().elem_cnt();\n    const int64_t cols = out->shape_view().At(1);\n    if (interaction_mode == \"vector\") {\n      DispatchFusedBiasAddMulAddResidualIndexType<T, InteractionMode::kVector>(\n          ctx->stream(), matmul_result->mut_dptr<T>(), x->dptr<T>(), x0->dptr<T>(), bias->dptr<T>(),\n          out->mut_dptr<T>(), cols, elem_cnt);\n    } else {\n      DispatchFusedBiasAddMulAddResidualIndexType<T, InteractionMode::kMatrix>(\n          ctx->stream(), matmul_result->mut_dptr<T>(), x->dptr<T>(), x0->dptr<T>(), bias->dptr<T>(),\n          out->mut_dptr<T>(), cols, elem_cnt);\n    }\n  }\n};\n\n#define REGISTER_FUSED_CROSS_FEATURE_INTERACTION_KERNEL(dtype)                        \\\n  REGISTER_USER_KERNEL(\"fused_cross_feature_interaction\")                             \\\n      .SetCreateFn<FusedCrossFeatureInteractionKernel<dtype>>()                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value) \\\n                       && MatmulPrimitiveExists());\n\nREGISTER_FUSED_CROSS_FEATURE_INTERACTION_KERNEL(float)\nREGISTER_FUSED_CROSS_FEATURE_INTERACTION_KERNEL(half)\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_cross_feature_interaction_grad.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr int kBlockSize = 256;\n\nvoid InferMatmulMNK(const DimVector& a_shape, const DimVector& b_shape, bool transpose_a,\n                    bool transpose_b, size_t* m, size_t* n, size_t* k) {\n  const int64_t num_a_axes = a_shape.size();\n  CHECK_GE(num_a_axes, 2);\n  const int64_t num_b_axes = b_shape.size();\n  CHECK_GE(num_b_axes, 2);\n  if (!transpose_a) {\n    *m = a_shape.at(num_a_axes - 2);\n    *k = a_shape.at(num_a_axes - 1);\n  } else {\n    *m = a_shape.at(num_a_axes - 1);\n    *k = a_shape.at(num_a_axes - 2);\n  }\n  if (!transpose_b) {\n    CHECK_EQ(b_shape.at(num_b_axes - 2), *k);\n    *n = b_shape.at(num_b_axes - 1);\n  } else {\n    CHECK_EQ(b_shape.at(num_b_axes - 1), *k);\n    *n = b_shape.at(num_b_axes - 2);\n  }\n}\n\nep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) {\n  return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N;\n}\n\ntemplate<typename T>\nstruct MulOp {\n  __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a * b; }\n};\n\ntemplate<typename T>\nstruct AddOp {\n  __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a + b; }\n};\n\ntemplate<typename T>\nint GetLaunchPackSize(const int64_t cols) {\n  constexpr int type_pack_size = cuda::elementwise::PackSize<T>();\n  for (int launch_pack_size = 8; launch_pack_size > 0; launch_pack_size /= 2) {\n    if (type_pack_size >= launch_pack_size && cols % launch_pack_size == 0) {\n      return launch_pack_size;\n    }\n  }\n  return 1;\n}\n\ntemplate<typename T, typename IndexType, int pack_size>\n__global__ void BroadcastMulKernel(const T* x, const T* y, T* out, const IndexType cols,\n                                   const IndexType elem_cnt) {\n  const IndexType global_thread_id = blockDim.x * blockIdx.x + threadIdx.x;\n  using LoadPack = cuda::elementwise::Packed<T, pack_size>;\n  for (IndexType linear_index = global_thread_id * pack_size,\n                 step = gridDim.x * blockDim.x * pack_size;\n       linear_index < elem_cnt; linear_index += step) {\n    const IndexType row_idx = linear_index / cols;\n    const LoadPack* x_load = reinterpret_cast<const LoadPack*>(x + linear_index);\n    LoadPack x_vec = *x_load;\n    LoadPack out_store;\n    const T y_val = y[row_idx];\n#pragma unroll\n    for (int i = 0; i < pack_size; i++) { out_store.elem[i] = x_vec.elem[i] * y_val; }\n    *(reinterpret_cast<LoadPack*>(out + linear_index)) = out_store;\n  }\n}\n\ntemplate<typename T, typename IndexType>\nvoid DispatchBroadcastMulPackSize(ep::Stream* stream, const T* x, const T* y, T* out,\n           
                        const IndexType cols, const IndexType elem_cnt) {\n  int grid_size;\n  const int pack_size = GetLaunchPackSize<T>(cols);\n  const int64_t pack_num = elem_cnt / pack_size;\n  OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(pack_num, &grid_size));\n  if (pack_size == 8) {\n    BroadcastMulKernel<T, IndexType, 8>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(x, y, out, cols,\n                                                                                    elem_cnt);\n  } else if (pack_size == 4) {\n    BroadcastMulKernel<T, IndexType, 4>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(x, y, out, cols,\n                                                                                    elem_cnt);\n  } else if (pack_size == 2) {\n    BroadcastMulKernel<T, IndexType, 2>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(x, y, out, cols,\n                                                                                    elem_cnt);\n  } else {\n    BroadcastMulKernel<T, IndexType, 1>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(x, y, out, cols,\n                                                                                    elem_cnt);\n  }\n}\n\ntemplate<typename T>\nvoid DispatchBroadcastMulIndexType(ep::Stream* stream, const T* x, const T* y, T* out,\n                                   const int64_t cols, const int64_t elem_cnt) {\n  if (elem_cnt < GetMaxVal<int32_t>()) {\n    DispatchBroadcastMulPackSize<T, int32_t>(stream, x, y, out, cols, elem_cnt);\n  } else {\n    DispatchBroadcastMulPackSize<T, int64_t>(stream, x, y, out, cols, elem_cnt);\n  }\n}\n\ntemplate<typename T, typename IndexType, int pack_size>\n__global__ void BroadcastAddElementwiseMulKernel(const T* x, const T* y, const T* z, T* out,\n                                                 const IndexType cols, const IndexType elem_cnt) {\n  const IndexType global_thread_id = blockDim.x * blockIdx.x + threadIdx.x;\n  using LoadPack = cuda::elementwise::Packed<T, pack_size>;\n  for (IndexType linear_index = global_thread_id * pack_size,\n                 step = gridDim.x * blockDim.x * pack_size;\n       linear_index < elem_cnt; linear_index += step) {\n    const IndexType row_idx = linear_index / cols;\n    const IndexType col_idx = linear_index - row_idx * cols;\n    const LoadPack* x_load = reinterpret_cast<const LoadPack*>(x + linear_index);\n    const LoadPack* y_load = reinterpret_cast<const LoadPack*>(y + col_idx);\n    const LoadPack* z_load = reinterpret_cast<const LoadPack*>(z + linear_index);\n\n    LoadPack x_vec = *x_load;\n    LoadPack y_vec = *y_load;\n    LoadPack z_vec = *z_load;\n    LoadPack out_store;\n\n#pragma unroll\n    for (int i = 0; i < pack_size; i++) {\n      out_store.elem[i] = (x_vec.elem[i] + y_vec.elem[i]) * z_vec.elem[i];\n    }\n    *(reinterpret_cast<LoadPack*>(out + linear_index)) = out_store;\n  }\n}\n\ntemplate<typename T, typename IndexType>\nvoid DispatchBroadcastAddElementwiseMulPackSize(ep::Stream* stream, const T* x, const T* y,\n                                                const T* z, T* out, const IndexType cols,\n                                                const IndexType elem_cnt) {\n  int grid_size;\n  const int pack_size = GetLaunchPackSize<T>(cols);\n  const int64_t pack_num = elem_cnt / pack_size;\n  OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(pack_num, &grid_size));\n  if (pack_size == 8) {\n 
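   // 8-wide path: for 16-bit element types this yields 128-bit global loads\n    // and stores.\n 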
   BroadcastAddElementwiseMulKernel<T, IndexType, 8>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(x, y, z, out,\n                                                                                    cols, elem_cnt);\n  } else if (pack_size == 4) {\n    BroadcastAddElementwiseMulKernel<T, IndexType, 4>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(x, y, z, out,\n                                                                                    cols, elem_cnt);\n  } else if (pack_size == 2) {\n    BroadcastAddElementwiseMulKernel<T, IndexType, 2>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(x, y, z, out,\n                                                                                    cols, elem_cnt);\n  } else {\n    BroadcastAddElementwiseMulKernel<T, IndexType, 1>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(x, y, z, out,\n                                                                                    cols, elem_cnt);\n  }\n}\n\ntemplate<typename T>\nvoid DispatchBroadcastAddElementwiseMulIndexType(ep::Stream* stream, const T* x, const T* y,\n                                                 const T* z, T* out, const int64_t cols,\n                                                 const int64_t elem_cnt) {\n  if (elem_cnt < GetMaxVal<int32_t>()) {\n    DispatchBroadcastAddElementwiseMulPackSize<T, int32_t>(stream, x, y, z, out, cols, elem_cnt);\n  } else {\n    DispatchBroadcastAddElementwiseMulPackSize<T, int64_t>(stream, x, y, z, out, cols, elem_cnt);\n  }\n}\n\n}  // namespace\n\nnamespace user_op {\n\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitive(DeviceType device_type,\n                                                          DataType data_type, bool transpose_a,\n                                                          bool transpose_b) {\n  const auto trans_a = GetBlasTransposeType(transpose_a);\n  const auto trans_b = GetBlasTransposeType(transpose_b);\n  return ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(device_type, data_type, trans_a,\n                                                                   trans_b);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewReduceMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false,\n                            /*transpose_b=*/false);\n}\n\nauto ReduceMatmulPrimitiveExists() {\n  return hob::make_custom(\"MatmulPrimitiveExists\", [](const KernelRegContext& ctx) {\n    return NewReduceMatmulPrimitive(&ctx).operator bool();\n  });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewWeightGradMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true,\n                            /*transpose_b=*/false);\n}\n\nauto WeightGradMatmulPrimitiveExists() {\n  return hob::make_custom(\"MatmulPrimitiveExists\", [](const KernelRegContext& ctx) {\n    return NewWeightGradMatmulPrimitive(&ctx).operator bool();\n  });\n}\n\ntemplate<typename T>\nclass FusedCrossFeatureInteractionGradKernel final : public OpKernel, public CudaGraphSupport {\n public:\n  FusedCrossFeatureInteractionGradKernel() = default;\n  
~FusedCrossFeatureInteractionGradKernel() override = default;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(KernelComputeContext* ctx) const override {\n    const Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const Tensor* x0 = ctx->Tensor4ArgNameAndIndex(\"x0\", 0);\n    const Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const Tensor* matmul_result = ctx->Tensor4ArgNameAndIndex(\"matmul_result\", 0);\n\n    const int64_t batch_size = dy->shape_view().At(0);\n    const int64_t hidden_size = dy->shape_view().At(1);\n    const int64_t out_size = weight->shape_view().At(0);\n    const int64_t dy_elem_cnt = dy->shape_view().elem_cnt();\n\n    Tensor* dx0 = ctx->Tensor4ArgNameAndIndex(\"dx0\", 0);\n    Tensor* dw = ctx->Tensor4ArgNameAndIndex(\"dw\", 0);\n    Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    Tensor* dbias = ctx->Tensor4ArgNameAndIndex(\"dbias\", 0);\n    Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    // step1: Get dbias.\n    const T* ones = nullptr;\n    auto* cuda_device = dynamic_cast<ep::CudaDevice*>(ctx->stream()->device());\n    if (cuda_device != nullptr) {\n      ones = static_cast<const T*>(cuda_device->GetConstOnes(dy->data_type(), batch_size));\n    }\n    size_t m = 0, n = 0, k = 0;\n    DimVector dy_shape(2);\n    dy->shape_view().ToDimVector(&dy_shape);\n    DimVector ones_buf_shape(2);\n    ones_buf_shape.at(0) = 1;\n    ones_buf_shape.at(1) = batch_size;\n    InferMatmulMNK(ones_buf_shape, dy_shape, /*trans_a=*/false, /*trans_b=*/false, &m, &n, &k);\n    auto reduce_matmul = NewReduceMatmulPrimitive(ctx);\n    CHECK(reduce_matmul);\n    reduce_matmul->Launch(ctx->stream(), m, n, k, 1.0, ones, dy->dptr(), 0.0, dbias->mut_dptr());\n\n    // step2: Get dmatmul_result0.\n    T* dy_mul_x0 = reinterpret_cast<T*>(tmp_buffer->mut_dptr());\n    T* dmatmul_result0 = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>()\n                                              + GetCudaAlignedSize(dy_elem_cnt * sizeof(T)));\n    OF_CUDA_CHECK(cuda::elementwise::Binary(MulOp<T>(), dy_elem_cnt, dy_mul_x0, dy->dptr<T>(),\n                                            x0->dptr<T>(),\n                                            ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n\n    ones = static_cast<const T*>(cuda_device->GetConstOnes(dy->data_type(), hidden_size));\n    DimVector dy_mul_x0_shape(2);\n    dy->shape_view().ToDimVector(&dy_mul_x0_shape);\n    ones_buf_shape.at(0) = hidden_size;\n    ones_buf_shape.at(1) = 1;\n    InferMatmulMNK(dy_mul_x0_shape, ones_buf_shape, /*trans_a=*/false, /*trans_b=*/false, &m, &n,\n                   &k);\n    reduce_matmul->Launch(ctx->stream(), m, n, k, 1.0, dy_mul_x0, ones, 0.0, dmatmul_result0);\n\n    // step3: Get dx\n    T* dx_buf = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>()\n                                     + GetCudaAlignedSize(dy_elem_cnt * sizeof(T))\n                                     + GetCudaAlignedSize(batch_size * sizeof(T)));\n    DimVector dmatmul_result_shape(2);\n    dmatmul_result_shape.at(0) = batch_size;\n    dmatmul_result_shape.at(1) = 1;  // todo change to hidden size\n    DimVector weight_shape(2);\n    weight->shape_view().ToDimVector(&weight_shape);\n    InferMatmulMNK(dmatmul_result_shape, weight_shape, /*trans_a=*/false, /*trans_b=*/false, &m, &n,\n                   &k);\n    
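// dx_buf = dmatmul_result0 matmul weight: (B, 1) x (1, E) -> (B, E); dx is\n    // formed below by adding dy elementwise.\n    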
reduce_matmul->Launch(ctx->stream(), m, n, k, 1.0, dmatmul_result0, weight->dptr(), 0.0,\n                          reinterpret_cast<void*>(dx_buf));\n    OF_CUDA_CHECK(cuda::elementwise::Binary(AddOp<T>(), dy_elem_cnt, dx->mut_dptr<T>(), dx_buf,\n                                            dy->dptr<T>(),\n                                            ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n\n    // step4: Get dw.\n    DimVector x_shape(2);\n    x->shape_view().ToDimVector(&x_shape);\n\n    InferMatmulMNK(dmatmul_result_shape, x_shape, /*trans_a=*/true, /*trans_b=*/false, &m, &n, &k);\n    auto weight_grad_matmul = NewWeightGradMatmulPrimitive(ctx);\n    CHECK(weight_grad_matmul);\n    weight_grad_matmul->Launch(ctx->stream(), m, n, k, 1.0, dmatmul_result0, x->dptr(), 0.0,\n                               dw->mut_dptr());\n\n    // step5: Get dx0.\n    DispatchBroadcastMulIndexType<T>(ctx->stream(), dy->dptr<T>(), matmul_result->dptr<T>(),\n                                     dx0->mut_dptr<T>(), hidden_size, dy_elem_cnt);\n  }\n};\n\n#define REGISTER_FUSED_CROSS_FEATURE_INTERACTION_V1_GRAD_KERNEL(dtype)                        \\\n  REGISTER_USER_KERNEL(\"fused_cross_feature_interaction_v1_grad\")                             \\\n      .SetCreateFn<FusedCrossFeatureInteractionGradKernel<dtype>>()                           \\\n      .SetIsMatchedHob((HobDeviceType() == DeviceType::kCUDA)                                 \\\n                       && (HobDataType(\"dy\", 0) == GetDataType<dtype>::value)                 \\\n                       && ReduceMatmulPrimitiveExists() && WeightGradMatmulPrimitiveExists()) \\\n      .SetInferTmpSizeFn([](InferContext* ctx) {                                              \\\n        size_t tmp_size = 0;                                                                  \\\n        const TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);                                 \\\n        const int64_t dy_elem_cnt = dy.shape().elem_cnt();                                    \\\n        const int64_t batch_size = dy.shape().At(0);                                          \\\n        size_t dy_mul_x0_size = GetCudaAlignedSize(dy_elem_cnt * sizeof(dtype));              \\\n        size_t dmatmul_result_size = GetCudaAlignedSize(batch_size * sizeof(dtype));          \\\n        size_t dx_buf_size = dy_mul_x0_size;                                                  \\\n        tmp_size = dy_mul_x0_size + dmatmul_result_size + dx_buf_size;                        \\\n        return tmp_size;                                                                      \\\n      });\n\nREGISTER_FUSED_CROSS_FEATURE_INTERACTION_V1_GRAD_KERNEL(float)\nREGISTER_FUSED_CROSS_FEATURE_INTERACTION_V1_GRAD_KERNEL(half)\n\ntemplate<typename T>\nclass FusedCrossFeatureInteractionV2GradKernel final : public OpKernel, public CudaGraphSupport {\n public:\n  FusedCrossFeatureInteractionV2GradKernel() = default;\n  ~FusedCrossFeatureInteractionV2GradKernel() = default;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(KernelComputeContext* ctx) const override {\n    const Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const Tensor* bias = ctx->Tensor4ArgNameAndIndex(\"bias\", 0);\n    const Tensor* x0 = ctx->Tensor4ArgNameAndIndex(\"x0\", 0);\n    const Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const 
Tensor* matmul_result = ctx->Tensor4ArgNameAndIndex(\"matmul_result\", 0);\n\n    const int64_t batch_size = dy->shape_view().At(0);\n    const int64_t in_size = weight->shape_view().At(1);\n    const int64_t hidden_size = weight->shape_view().At(0);\n    const int64_t dy_elem_cnt = dy->shape_view().elem_cnt();\n\n    Tensor* dx0 = ctx->Tensor4ArgNameAndIndex(\"dx0\", 0);\n    Tensor* dw = ctx->Tensor4ArgNameAndIndex(\"dw\", 0);\n    Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    Tensor* dbias = ctx->Tensor4ArgNameAndIndex(\"dbias\", 0);\n    Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    // step1: Get dx0.\n    DispatchBroadcastAddElementwiseMulIndexType<T>(ctx->stream(), matmul_result->dptr<T>(),\n                                                   bias->dptr<T>(), dy->dptr<T>(),\n                                                   dx0->mut_dptr<T>(), hidden_size, dy_elem_cnt);\n\n    // step2: Get dmatmul_result0.\n    T* dmatmul_result0 = reinterpret_cast<T*>(tmp_buffer->mut_dptr());\n    OF_CUDA_CHECK(cuda::elementwise::Binary(MulOp<T>(), dy_elem_cnt, dmatmul_result0, dy->dptr<T>(),\n                                            x0->dptr<T>(),\n                                            ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n    // step3: Get dx\n    T* dx_buf = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>()\n                                     + GetCudaAlignedSize(dy_elem_cnt * sizeof(T)));\n    DimVector dmatmul_result_shape(2);\n    dmatmul_result_shape.at(0) = batch_size;\n    dmatmul_result_shape.at(1) = hidden_size;\n    DimVector weight_shape(2);\n    weight->shape_view().ToDimVector(&weight_shape);\n    size_t m = 0, n = 0, k = 0;\n    InferMatmulMNK(dmatmul_result_shape, weight_shape, /*trans_a=*/false, /*trans_b=*/false, &m, &n,\n                   &k);\n    auto reduce_matmul = NewReduceMatmulPrimitive(ctx);\n    CHECK(reduce_matmul);\n    reduce_matmul->Launch(ctx->stream(), m, n, k, 1.0, dmatmul_result0, weight->dptr(), 0.0,\n                          reinterpret_cast<void*>(dx_buf));\n    OF_CUDA_CHECK(cuda::elementwise::Binary(AddOp<T>(), dy_elem_cnt, dx->mut_dptr<T>(), dx_buf,\n                                            dy->dptr<T>(),\n                                            ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n\n    // step4: Get dw.\n    DimVector x_shape(2);\n    x->shape_view().ToDimVector(&x_shape);\n\n    InferMatmulMNK(dmatmul_result_shape, x_shape, /*trans_a=*/true, /*trans_b=*/false, &m, &n, &k);\n    auto weight_grad_matmul = NewWeightGradMatmulPrimitive(ctx);\n    CHECK(weight_grad_matmul);\n    weight_grad_matmul->Launch(ctx->stream(), m, n, k, 1.0, dmatmul_result0, x->dptr(), 0.0,\n                               dw->mut_dptr());\n\n    // step5: Get dbias.\n    const T* ones = nullptr;\n    auto* cuda_device = dynamic_cast<ep::CudaDevice*>(ctx->stream()->device());\n    if (cuda_device != nullptr) {\n      ones = static_cast<const T*>(cuda_device->GetConstOnes(dy->data_type(), batch_size));\n    }\n    DimVector dy_shape(2);\n    dy->shape_view().ToDimVector(&dy_shape);\n    DimVector ones_buf_shape(2);\n    ones_buf_shape.at(0) = 1;\n    ones_buf_shape.at(1) = batch_size;\n    InferMatmulMNK(ones_buf_shape, dy_shape, /*trans_a=*/false, /*trans_b=*/false, &m, &n, &k);\n    reduce_matmul->Launch(ctx->stream(), m, n, k, 1.0, ones,\n                          reinterpret_cast<void*>(dmatmul_result0), 0.0, dbias->mut_dptr());\n  }\n};\n\n#define 
REGISTER_FUSED_CROSS_FEATURE_INTERACTION_V2_GRAD_KERNEL(dtype)                        \\\n  REGISTER_USER_KERNEL(\"fused_cross_feature_interaction_v2_grad\")                             \\\n      .SetCreateFn<FusedCrossFeatureInteractionV2GradKernel<dtype>>()                         \\\n      .SetIsMatchedHob((HobDeviceType() == DeviceType::kCUDA)                                 \\\n                       && (HobDataType(\"dy\", 0) == GetDataType<dtype>::value)                 \\\n                       && ReduceMatmulPrimitiveExists() && WeightGradMatmulPrimitiveExists()) \\\n      .SetInferTmpSizeFn([](InferContext* ctx) {                                              \\\n        size_t tmp_size = 0;                                                                  \\\n        const TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);                                 \\\n        const int64_t dy_elem_cnt = dy.shape().elem_cnt();                                    \\\n        size_t dmatmul_result_size = GetCudaAlignedSize(dy_elem_cnt * sizeof(dtype));         \\\n        size_t dx_buf_size = dmatmul_result_size;                                             \\\n        tmp_size = dmatmul_result_size + dx_buf_size;                                         \\\n        return tmp_size;                                                                      \\\n      });\n\nREGISTER_FUSED_CROSS_FEATURE_INTERACTION_V2_GRAD_KERNEL(float)\nREGISTER_FUSED_CROSS_FEATURE_INTERACTION_V2_GRAD_KERNEL(half)\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_dot_feature_interaction_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/ep/include/primitive/copy_nd.h\"\n#include \"oneflow/core/ep/include/primitive/batch_matmul.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include <mma.h>\n\nnamespace oneflow {\n\nnamespace {\n\n__global__ void GenerateGatherIndicesGpu(const int32_t elem_cnt, const int32_t stride,\n                                         const int32_t in_cols, const int32_t offset,\n                                         int32_t* gather_indices) {\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {\n    const int32_t row = i / stride;\n    const int32_t col = i - row * stride;\n    if (col < row + offset) {\n      int32_t in_index = row * in_cols + col;\n      int32_t idx = row * (offset + row - 1 + offset) / 2 + col;\n      gather_indices[idx] = in_index;\n    }\n  }\n}\n\ntemplate<typename T>\n__global__ void GatherConcatGpu(int32_t elem_cnt, int32_t out_cols, int32_t valid_out_cols,\n                                int32_t in_cols, int32_t output_concat_end_dim,\n                                const int32_t* gather_indices, const T* in,\n                                const T* output_concat_ptr, T* out_ptr) {\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {\n    const int32_t row = i / out_cols;\n    const int32_t col = i - row * out_cols;\n    T out_val;\n    if (col < output_concat_end_dim) {\n      const int32_t output_concat_idx = row * output_concat_end_dim + col;\n      out_val = output_concat_ptr[output_concat_idx];\n    } else if (col < valid_out_cols) {\n      const int32_t gather_col_idx = gather_indices[col - output_concat_end_dim];\n      const int32_t in_offset = row * in_cols + gather_col_idx;\n      out_val = in[in_offset];\n    } else {\n      out_val = 0;\n    }\n    out_ptr[i] = out_val;\n  }\n}\n\ntemplate<typename T>\n__global__ void ScatterSplitAddTransposeGpu(int32_t elem_cnt, int32_t stride_dim, int32_t out_dim,\n                                            int32_t in_grad_stride, int32_t in_grad_matrix_dim,\n                                            int32_t in_grad_matrix_valid_dim,\n                                            int32_t output_concat_end_dim, const int32_t offset,\n                                            const T* dy, T* output_concat_grad, T* in_grad) {\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {\n    const int32_t row = i / stride_dim;\n    const int32_t col = i - row * stride_dim;\n    if (col < output_concat_end_dim) {\n      output_concat_grad[row * output_concat_end_dim + col] = dy[row * out_dim + col];\n    } else {\n      int32_t in_col_id = col - output_concat_end_dim;\n      const int32_t matrix_row = in_col_id / in_grad_matrix_dim;\n      const int32_t matrix_col = in_col_id - matrix_row * in_grad_matrix_dim;\n      T grad_val = 0;\n      const T* row_dy = dy + row * out_dim + 
output_concat_end_dim;\n      if (matrix_row < in_grad_matrix_valid_dim && matrix_col < in_grad_matrix_valid_dim) {\n        if (matrix_col < matrix_row) {\n          int32_t dy_col_idx = matrix_row * (offset + matrix_row - 1 + offset) / 2 + matrix_col;\n          grad_val = row_dy[dy_col_idx];\n        } else if (matrix_row < matrix_col) {\n          // transpose add\n          int32_t trans_row_id = matrix_col;\n          int32_t trans_col_id = matrix_row;\n          int32_t dy_col_idx =\n              trans_row_id * (offset + trans_row_id - 1 + offset) / 2 + trans_col_id;\n          grad_val = row_dy[dy_col_idx];\n        } else if ((matrix_row == matrix_col) && (offset == 1)) {\n          int32_t dy_col_idx = matrix_row * (offset + matrix_row - 1 + offset) / 2 + matrix_col;\n          grad_val = row_dy[dy_col_idx] * static_cast<T>(2);\n        }\n      }\n      int32_t in_grad_offset = row * in_grad_stride + in_col_id;\n      in_grad[in_grad_offset] = grad_val;\n    }\n  }\n}\n\ntemplate<typename T>\nvoid ConcatFeatures(user_op::KernelComputeContext* ctx, int64_t dst_rows, int64_t dst_cols,\n                    void* dst_ptr) {\n  const int64_t feature_input_size = ctx->input_size(\"features\");\n  auto primitive = ep::primitive::NewPrimitive<ep::primitive::CopyNdFactory>(DeviceType::kCUDA, 2);\n  DimVector dst_shape = {dst_rows, dst_cols};\n  int64_t out_col_offset = 0;\n  for (int64_t i = 0; i < feature_input_size; ++i) {\n    const user_op::Tensor* feature = ctx->Tensor4ArgNameAndIndex(\"features\", i);\n    const int64_t feature_rows = feature->shape_view().At(0);\n    const int64_t feature_cols = feature->shape_view().Count(1);\n    DimVector dst_pos_vec = {0, out_col_offset};\n    DimVector src_shape = {feature_rows, feature_cols};\n    DimVector src_pos_vec = {0, 0};\n    DimVector extent_vec = {feature_rows, feature_cols};\n    primitive->Launch(ctx->stream(), feature->data_type(), 2, dst_ptr, dst_shape.data(),\n                      dst_pos_vec.data(), feature->dptr<T>(), src_shape.data(), src_pos_vec.data(),\n                      extent_vec.data());\n    out_col_offset += feature_cols;\n  }\n  int64_t pad_dim = dst_cols - out_col_offset;\n  if (pad_dim > 0) {\n    char* out_ptr = reinterpret_cast<char*>(dst_ptr) + out_col_offset * sizeof(T);\n    OF_CUDA_CHECK(cudaMemset2DAsync(out_ptr, dst_cols * sizeof(T), 0, pad_dim * sizeof(T), dst_rows,\n                                    ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n  }\n}\n\ntemplate<typename T>\nvoid GatherConcatKernel(ep::Stream* stream, int32_t elem_cnt, int32_t out_dim,\n                        int32_t valid_out_dim, int32_t features_concated_dim,\n                        int32_t concated_padded_dim, int32_t output_concat_end_dim,\n                        bool self_interaction, const T* matmul_out, const T* output_concat_ptr,\n                        int32_t* gather_indices_ptr, T* out_ptr) {\n  cudaStream_t cuda_stream = stream->As<ep::CudaStream>()->cuda_stream();\n  const int32_t gen_indices_elem_cnt = features_concated_dim * features_concated_dim;\n  int32_t offset = self_interaction ? 
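/* keep the diagonal: with offset=1, row r stores cols 0..r at idx r*(r+1)/2 + col */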
1 : 0;\n  GenerateGatherIndicesGpu<<<BlocksNum4ThreadsNum(gen_indices_elem_cnt), kCudaThreadsNumPerBlock, 0,\n                             cuda_stream>>>(gen_indices_elem_cnt, features_concated_dim,\n                                            concated_padded_dim, offset, gather_indices_ptr);\n\n  int32_t matmul_stride = concated_padded_dim * concated_padded_dim;\n  GatherConcatGpu<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(\n      elem_cnt, out_dim, valid_out_dim, matmul_stride, output_concat_end_dim, gather_indices_ptr,\n      matmul_out, output_concat_ptr, out_ptr);\n}\n\ntemplate<typename T>\nvoid ScatterSplitAddTranspose(ep::Stream* stream, int32_t batch_size, int32_t out_dim,\n                              int32_t concated_padded_dim, int32_t features_concated_dim,\n                              int32_t output_concat_end_dim, const bool self_interaction,\n                              const T* dy, T* output_concat_grad, T* matmul_out_grad_ptr) {\n  int32_t stride_dim = output_concat_end_dim + concated_padded_dim * concated_padded_dim;\n  int32_t matmul_stride = concated_padded_dim * concated_padded_dim;\n  const int32_t elem_cnt = batch_size * stride_dim;\n  int32_t offset = self_interaction ? 1 : 0;\n  ScatterSplitAddTransposeGpu<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                                stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      elem_cnt, stride_dim, out_dim, matmul_stride, concated_padded_dim, features_concated_dim,\n      output_concat_end_dim, offset, dy, output_concat_grad, matmul_out_grad_ptr);\n}\n\ntemplate<typename T>\nvoid ConcatFeaturesGrad(user_op::KernelComputeContext* ctx, const int64_t batch_size,\n                        const int64_t concated_padded_dim, const int64_t vector_size,\n                        const T* concated_features_grad) {\n  auto primitive = ep::primitive::NewPrimitive<ep::primitive::CopyNdFactory>(DeviceType::kCUDA, 2);\n  DimVector src_shape = {batch_size, concated_padded_dim * vector_size};\n  int64_t in_col_offset = 0;\n  for (int64_t i = 0; i < ctx->output_size(\"features_grad\"); ++i) {\n    user_op::Tensor* feature_grad = ctx->Tensor4ArgNameAndIndex(\"features_grad\", i);\n    const int64_t feature_grad_rows = feature_grad->shape_view().At(0);\n    const int64_t feature_grad_cols = feature_grad->shape_view().Count(1);\n    DimVector dst_shape = {feature_grad_rows, feature_grad_cols};\n    DimVector dst_pos_vec = {0, 0};\n    DimVector src_pos_vec = {0, in_col_offset};\n    DimVector extent_vec = {feature_grad_rows, feature_grad_cols};\n    in_col_offset += feature_grad_cols;\n    primitive->Launch(ctx->stream(), feature_grad->data_type(), 2, feature_grad->mut_dptr(),\n                      dst_shape.data(), dst_pos_vec.data(), concated_features_grad,\n                      src_shape.data(), src_pos_vec.data(), extent_vec.data());\n  }\n}\n\ntemplate<typename T>\nstruct DefaultComputeType {\n  using type = T;\n};\n\ntemplate<>\nstruct DefaultComputeType<half> {\n  using type = float;\n};\n\ntemplate<typename T, size_t pack_size>\nstruct alignas(sizeof(T) * pack_size) Pack {\n  T elem[pack_size];\n};\n\nint64_t GetPaddedDim(int64_t dim) {\n  const int64_t align_dim = 16;\n  const int64_t padded_dim = (dim + align_dim - 1) / align_dim * align_dim;\n  return padded_dim;\n}\n\ntemplate<typename T, int32_t max_in>\nstruct DotFwdParam {\n  const T* in[max_in];\n  int32_t in_feature_dim[max_in];\n  int32_t dim_start_offset[max_in];\n  const T* sparse_feature;\n  const 
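/* per-sample row ids into sparse_feature */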
uint32_t* sparse_indices;\n  int32_t sparse_dim;\n  int32_t sparse_dim_start;\n  int32_t features_dim;\n  const T* output_concat;\n  int32_t output_concat_size;\n  T* out;\n  int32_t num_in;\n};\n\n#if __CUDA_ARCH__ >= 700\ntemplate<typename T, typename AccType, int m, int n, int k, class ALayout, class BLayout>\nclass Wmma {\n public:\n  __device__ void LoadA(const T* ptr, int ldm) { nvcuda::wmma::load_matrix_sync(a_, ptr, ldm); }\n  __device__ void LoadB(const T* ptr, int ldm) { nvcuda::wmma::load_matrix_sync(b_, ptr, ldm); }\n  __device__ void Store(AccType* ptr, int ldm) {\n    nvcuda::wmma::store_matrix_sync(ptr, acc_, ldm, nvcuda::wmma::mem_row_major);\n  }\n  __device__ void Mma() { nvcuda::wmma::mma_sync(acc_, a_, b_, acc_); }\n  __device__ void InitAcc() { nvcuda::wmma::fill_fragment(acc_, 0.0f); }\n  __device__ __forceinline__ T Convert(T src) { return src; }\n\n private:\n  nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, m, n, k, T, ALayout> a_;\n  nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, m, n, k, T, BLayout> b_;\n  nvcuda::wmma::fragment<nvcuda::wmma::accumulator, m, n, k, AccType> acc_;\n};\n\ntemplate<typename AccType, int m, int n, int k, class ALayout, class BLayout>\nclass Wmma<float, AccType, m, n, k, ALayout, BLayout> {\n public:\n#if __CUDA_ARCH__ >= 800\n  __device__ void LoadA(const float* ptr, int ldm) { nvcuda::wmma::load_matrix_sync(a_, ptr, ldm); }\n  __device__ void LoadB(const float* ptr, int ldm) { nvcuda::wmma::load_matrix_sync(b_, ptr, ldm); }\n  __device__ void Mma() { nvcuda::wmma::mma_sync(acc_, a_, b_, acc_); }\n  __device__ __forceinline__ float Convert(float src) { return nvcuda::wmma::__float_to_tf32(src); }\n  __device__ void Store(AccType* ptr, int ldm) {\n    nvcuda::wmma::store_matrix_sync(ptr, acc_, ldm, nvcuda::wmma::mem_row_major);\n  }\n  __device__ void InitAcc() { nvcuda::wmma::fill_fragment(acc_, 0.0f); }\n#else\n  __device__ void LoadA(const float* ptr, int ldm) { __trap(); }\n  __device__ void LoadB(const float* ptr, int ldm) { __trap(); }\n  __device__ void Mma() { __trap(); }\n  __device__ __forceinline__ float Convert(float src) { return src; }\n  __device__ void Store(AccType* ptr, int ldm) { __trap(); }\n  __device__ void InitAcc() { __trap(); }\n#endif\n\n private:\n#if __CUDA_ARCH__ >= 800\n  nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, m, n, k, nvcuda::wmma::precision::tf32, ALayout>\n      a_;\n  nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, m, n, k, nvcuda::wmma::precision::tf32, BLayout>\n      b_;\n  nvcuda::wmma::fragment<nvcuda::wmma::accumulator, m, n, k, AccType> acc_;\n#endif\n};\n#endif  //__CUDA_ARCH__ >= 700\n\nconstexpr int kUnrollDim = 2;\ntemplate<typename T, typename ComputeType, int32_t max_in, int32_t pack_size, int mn_tile_dim,\n         int k_tile_dim>\n__global__ void DotFeatureInteractionWmmaImpl(\n    int m_num_tiles, int k_num_tiles, int64_t batch_size, int padded_num_rows, int vector_num_pack,\n    int padded_vector_num_pack, int out_num_cols, int out_num_cols_num_pack, int in_shared_mem_cols,\n    int in_shared_mem_cols_num_pack, int acc_shared_mem_cols, int acc_shared_mem_cols_num_pack,\n    int offset, int output_padding, DotFwdParam<T, max_in> param) {\n#if __CUDA_ARCH__ >= 700\n  Wmma<T, ComputeType, mn_tile_dim, mn_tile_dim, k_tile_dim, nvcuda::wmma::row_major,\n       nvcuda::wmma::col_major>\n      wmma;\n  extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];\n  int warp_id = threadIdx.y;\n  T* buf = reinterpret_cast<T*>(shared_buf);\n  Pack<T, pack_size>* buf_pack = 
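/* pack-typed view of the same shared buffer, for vectorized loads */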
reinterpret_cast<Pack<T, pack_size>*>(shared_buf);\n  ComputeType* acc_buf =\n      reinterpret_cast<ComputeType*>(shared_buf + padded_num_rows * in_shared_mem_cols * sizeof(T));\n  int batch_idx = blockIdx.x;\n  T* batch_out = param.out + batch_idx * out_num_cols;\n  Pack<T, pack_size>* batch_out_pack =\n      reinterpret_cast<Pack<T, pack_size>*>(param.out) + batch_idx * out_num_cols_num_pack;\n  const int output_concat_size = param.output_concat_size;\n  const T* batch_output_concat =\n      (param.output_concat) ? (param.output_concat + batch_idx * output_concat_size) : nullptr;\n  const uint32_t* batch_sparse_indices =\n      (param.sparse_indices) ? (param.sparse_indices + batch_idx * param.sparse_dim) : nullptr;\n  const Pack<T, pack_size>* sparse_feature_pack =\n      (param.sparse_feature) ? reinterpret_cast<const Pack<T, pack_size>*>(param.sparse_feature)\n                             : nullptr;\n  for (int col = threadIdx.x; col < vector_num_pack; col += blockDim.x) {\n// load dense feature to shared_mem\n#pragma unroll\n    for (int i = 0; i < max_in; ++i) {\n      if (i >= param.num_in) { break; }\n      const Pack<T, pack_size>* batch_in = reinterpret_cast<const Pack<T, pack_size>*>(param.in[i])\n                                           + batch_idx * param.in_feature_dim[i] * vector_num_pack;\n      for (int j = threadIdx.y * kUnrollDim; j < param.in_feature_dim[i];\n           j += blockDim.y * kUnrollDim) {\n#pragma unroll\n        for (int k = 0; k < kUnrollDim; ++k) {\n          int in_row = j + k;\n          if (in_row >= param.in_feature_dim[i]) { break; }\n          int buf_row = param.dim_start_offset[i] + in_row;\n          Pack<T, pack_size> pack_in_val = batch_in[in_row * vector_num_pack + col];\n#pragma unroll\n          for (int t = 0; t < pack_size; ++t) {\n            pack_in_val.elem[t] = wmma.Convert(pack_in_val.elem[t]);\n          }\n          buf_pack[buf_row * in_shared_mem_cols_num_pack + col] = pack_in_val;\n        }\n      }\n    }\n    // load sparse feature to shared_mem\n    for (int j = threadIdx.y * kUnrollDim; j < param.sparse_dim; j += blockDim.y * kUnrollDim) {\n#pragma unroll\n      for (int k = 0; k < kUnrollDim; ++k) {\n        int in_row = j + k;\n        if (in_row >= param.sparse_dim) { break; }\n        int buf_row = param.sparse_dim_start + in_row;\n        int sparse_in_row = batch_sparse_indices[in_row];\n        Pack<T, pack_size> pack_in_val = sparse_feature_pack[sparse_in_row * vector_num_pack + col];\n#pragma unroll\n        for (int t = 0; t < pack_size; ++t) {\n          pack_in_val.elem[t] = wmma.Convert(pack_in_val.elem[t]);\n        }\n        buf_pack[buf_row * in_shared_mem_cols_num_pack + col] = pack_in_val;\n      }\n    }\n  }\n  Pack<T, pack_size> zero;\n#pragma unroll\n  for (int k = 0; k < pack_size; ++k) { zero.elem[k] = wmma.Convert(0); }\n  for (int row = threadIdx.y; row < param.features_dim; row += blockDim.y) {\n    for (int col = vector_num_pack + threadIdx.x; col < padded_vector_num_pack; col += blockDim.x) {\n      buf_pack[row * in_shared_mem_cols_num_pack + col] = zero;\n    }\n  }\n  __syncthreads();\n  for (int blocks_id = warp_id; blocks_id < m_num_tiles * m_num_tiles; blocks_id += blockDim.y) {\n    int blocks_row = blocks_id / m_num_tiles;\n    int blocks_col = blocks_id - blocks_row * m_num_tiles;\n    if (blocks_row >= blocks_col) {\n      wmma.InitAcc();\n      for (int step = 0; step < k_num_tiles; ++step) {\n        T* tile_a_ptr = buf + blocks_row * mn_tile_dim * in_shared_mem_cols + step * 
k_tile_dim;\n        T* tile_b_ptr = buf + blocks_col * mn_tile_dim * in_shared_mem_cols + step * k_tile_dim;\n        wmma.LoadA(tile_a_ptr, in_shared_mem_cols);\n        wmma.LoadB(tile_b_ptr, in_shared_mem_cols);\n        wmma.Mma();\n      }\n      ComputeType* tile_ptr =\n          acc_buf + blocks_row * mn_tile_dim * acc_shared_mem_cols + blocks_col * mn_tile_dim;\n      wmma.Store(tile_ptr, acc_shared_mem_cols);\n    }\n  }\n  __syncthreads();\n  T* emb_out = batch_out + output_concat_size;\n  for (int base_row = threadIdx.y * kUnrollDim; base_row < param.features_dim;\n       base_row += kUnrollDim * blockDim.y) {\n#pragma unroll\n    for (int k = 0; k < kUnrollDim; ++k) {\n      int row = base_row + k;\n      if (row >= param.features_dim) { break; }\n      for (int col = threadIdx.x; col < param.features_dim; col += blockDim.x) {\n        if (col < row + offset) {\n          int64_t idx = row * (offset + row - 1 + offset) / 2 + col;\n          emb_out[idx] = static_cast<T>(acc_buf[row * acc_shared_mem_cols + col]);\n        }\n      }\n    }\n  }\n  int thread_id = threadIdx.y * blockDim.x + threadIdx.x;\n  for (int i = thread_id; i < output_concat_size; i += blockDim.x * blockDim.y) {\n    batch_out[i] = batch_output_concat[i];\n  }\n  for (int i = thread_id; i < output_padding; i += blockDim.x * blockDim.y) {\n    batch_out[out_num_cols - 1 - i] = 0;\n  }\n#else\n  __trap();\n#endif  // __CUDA_ARCH__ >= 700\n}\n\ntemplate<typename T>\nstruct KTileDim {\n  static const int val = 16;\n};\n\ntemplate<>\nstruct KTileDim<float> {\n  static const int val = 8;\n};\n\ntemplate<typename T, int max_in, int32_t pack_size>\nstruct DotFeatureInteractionKernel {\n  static bool Launch(ep::Stream* stream, int64_t batch_size, int concated_padded_dim,\n                     int vector_size, int out_num_cols, bool self_interaction, int output_padding,\n                     const DotFwdParam<T, max_in>& param) {\n    const int block_size = 128;\n    const int block_dim_x = 32;\n    const int block_dim_y = block_size / block_dim_x;\n    const int num_blocks = batch_size;\n    const int mn_tile_dim = 16;\n    const int k_tile_dim = KTileDim<T>::val;\n    const int64_t padded_vector_size = GetPaddedDim(vector_size);\n    const int m_num_tiles = concated_padded_dim / mn_tile_dim;\n    const int k_num_tiles = padded_vector_size / k_tile_dim;\n    const int skew_in = 8;\n    const int skew_acc = 8;\n    const int in_shared_mem_num_cols = padded_vector_size + skew_in;\n    const int acc_shared_mem_num_cols = concated_padded_dim + skew_acc;\n    const size_t in_shared_mem_bytes = concated_padded_dim * in_shared_mem_num_cols * sizeof(T);\n    using ComputeType = typename DefaultComputeType<T>::type;\n    const size_t acc_shared_mem_bytes =\n        concated_padded_dim * acc_shared_mem_num_cols * sizeof(ComputeType);\n    const size_t total_shared_mem_bytes = in_shared_mem_bytes + acc_shared_mem_bytes;\n    const int32_t offset = self_interaction ? 
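/* also emit the x_i*x_i diagonal terms */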
1 : 0;\n    const int out_num_cols_num_pack = out_num_cols / pack_size;\n    const int vector_num_pack = vector_size / pack_size;\n    const int padded_vector_num_pack = padded_vector_size / pack_size;\n    const int in_shared_mem_cols_num_pack = in_shared_mem_num_cols / pack_size;\n    const int acc_shared_mem_cols_num_pack = acc_shared_mem_num_cols / pack_size;\n    int max_active_blocks;\n    OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks,\n        DotFeatureInteractionWmmaImpl<T, ComputeType, max_in, pack_size, mn_tile_dim, k_tile_dim>,\n        block_size, total_shared_mem_bytes));\n    if (max_active_blocks <= 0) { return false; }\n    cudaStream_t cuda_stream = stream->As<ep::CudaStream>()->cuda_stream();\n    DotFeatureInteractionWmmaImpl<T, ComputeType, max_in, pack_size, mn_tile_dim, k_tile_dim>\n        <<<num_blocks, dim3(block_dim_x, block_dim_y), total_shared_mem_bytes, cuda_stream>>>(\n            m_num_tiles, k_num_tiles, batch_size, concated_padded_dim, vector_num_pack,\n            padded_vector_num_pack, out_num_cols, out_num_cols_num_pack, in_shared_mem_num_cols,\n            in_shared_mem_cols_num_pack, acc_shared_mem_num_cols, acc_shared_mem_cols_num_pack,\n            offset, output_padding, param);\n    return true;\n  }\n};\n\ntemplate<typename T, int32_t max_in>\nstruct DotBwdParam {\n  const T* out_grad;\n  const T* in[max_in];\n  T* in_grad[max_in];\n  T* output_concat_grad;\n  const T* sparse_feature;\n  const uint32_t* sparse_indices;\n  int32_t sparse_dim;\n  int32_t sparse_dim_start;\n  T* sparse_feature_grad;\n  int32_t output_concat_size;\n  int32_t in_feature_dim[max_in];\n  int32_t dim_start_offset[max_in];\n  int32_t features_dim;\n  int32_t num_in;\n};\n\ntemplate<typename T, typename ComputeType, int32_t pack_size>\n__device__ __inline__ void AtomicAdd(Pack<T, pack_size>* address,\n                                     Pack<ComputeType, pack_size> val) {\n#pragma unroll\n  for (int i = 0; i < pack_size; ++i) {\n    cuda::atomic::Add(reinterpret_cast<T*>(address) + i, static_cast<T>(val.elem[i]));\n  }\n}\n\ntemplate<>\n__device__ __inline__ void AtomicAdd<half, float, 2>(Pack<half, 2>* address, Pack<float, 2> val) {\n  half2 h2_val;\n  h2_val.x = static_cast<half>(val.elem[0]);\n  h2_val.y = static_cast<half>(val.elem[1]);\n  cuda::atomic::Add(reinterpret_cast<half2*>(address), h2_val);\n}\n\ntemplate<typename T, typename ComputeType, int32_t max_in, int32_t pack_size,\n         int32_t sparse_grad_pack_size, int mn_tile_dim, int k_tile_dim>\n__global__ void DotFeatureInteractionBackwardWmmaImpl(\n    int m_num_tiles, int n_num_tiles, int k_num_tiles, int64_t batch_size, int padded_num_rows,\n    int vector_num_pack, int vector_num_sparse_grad_pack, int padded_vector_num_pack,\n    int out_num_cols, int in_shared_mem_cols, int in_shared_mem_cols_num_pack,\n    int in_shared_mem_cols_num_sparse_grad_pack, int matrix_out_grad_shared_mem_cols, int offset,\n    DotBwdParam<T, max_in> param) {\n#if __CUDA_ARCH__ >= 700\n  Wmma<T, ComputeType, mn_tile_dim, mn_tile_dim, k_tile_dim, nvcuda::wmma::row_major,\n       nvcuda::wmma::row_major>\n      wmma;\n  extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];\n  int warp_id = threadIdx.y;\n  T* in_buf = reinterpret_cast<T*>(shared_buf);\n  Pack<T, pack_size>* in_buf_pack = reinterpret_cast<Pack<T, pack_size>*>(shared_buf);\n  T* matrix_out_grad_buf = in_buf + padded_num_rows * in_shared_mem_cols;\n  ComputeType* in_grad_buf = 
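/* third region of shared_buf, after in_buf and matrix_out_grad_buf */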
reinterpret_cast<ComputeType*>(\n      matrix_out_grad_buf + padded_num_rows * matrix_out_grad_shared_mem_cols);\n  Pack<ComputeType, pack_size>* in_grad_buf_pack =\n      reinterpret_cast<Pack<ComputeType, pack_size>*>(in_grad_buf);\n\n  int batch_idx = blockIdx.x;\n  const T* batch_out_grad = param.out_grad + batch_idx * out_num_cols;\n  const int output_concat_size = param.output_concat_size;\n  T* batch_output_concat_grad = (param.output_concat_grad)\n                                    ? (param.output_concat_grad + batch_idx * output_concat_size)\n                                    : nullptr;\n  const uint32_t* batch_sparse_indices =\n      (param.sparse_indices) ? (param.sparse_indices + batch_idx * param.sparse_dim) : nullptr;\n  const Pack<T, pack_size>* sparse_feature_pack =\n      (param.sparse_feature) ? reinterpret_cast<const Pack<T, pack_size>*>(param.sparse_feature)\n                             : nullptr;\n\n  int features_dim = param.features_dim;\n  // 1.split out_grad to concat_out_grad and matrix_out_grad buf\n  int thread_id = threadIdx.x + threadIdx.y * blockDim.x;\n  for (int i = thread_id; i < output_concat_size; i += blockDim.x * blockDim.y) {\n    batch_output_concat_grad[i] = batch_out_grad[i];\n  }\n  const T* batch_interaction_out_grad = batch_out_grad + output_concat_size;\n  for (int matrix_row = threadIdx.y; matrix_row < padded_num_rows; matrix_row += blockDim.y) {\n    for (int matrix_col = threadIdx.x; matrix_col < padded_num_rows; matrix_col += blockDim.x) {\n      const int64_t i = matrix_row * matrix_out_grad_shared_mem_cols + matrix_col;\n      T grad_val = 0;\n      if (matrix_row < features_dim && matrix_col < features_dim) {\n        if (matrix_col < matrix_row) {\n          int32_t out_grad_col = matrix_row * (offset + matrix_row - 1 + offset) / 2 + matrix_col;\n          grad_val = batch_interaction_out_grad[out_grad_col];\n        } else if (matrix_row < matrix_col) {\n          // transpose add\n          int32_t trans_row_id = matrix_col;\n          int32_t trans_col_id = matrix_row;\n          int32_t out_grad_col =\n              trans_row_id * (offset + trans_row_id - 1 + offset) / 2 + trans_col_id;\n          grad_val = batch_interaction_out_grad[out_grad_col];\n        } else if ((matrix_row == matrix_col) && (offset == 1)) {\n          int32_t out_grad_col = matrix_row * (offset + matrix_row - 1 + offset) / 2 + matrix_col;\n          grad_val = batch_interaction_out_grad[out_grad_col] * static_cast<T>(2);\n        }\n      }\n      matrix_out_grad_buf[i] = wmma.Convert(grad_val);\n    }\n  }\n\n  // 2.load the inputs into in_buf\n  for (int col = threadIdx.x; col < vector_num_pack; col += blockDim.x) {\n#pragma unroll\n    for (int i = 0; i < max_in; ++i) {\n      if (i >= param.num_in) { break; }\n      const Pack<T, pack_size>* batch_in = reinterpret_cast<const Pack<T, pack_size>*>(param.in[i])\n                                           + batch_idx * param.in_feature_dim[i] * vector_num_pack;\n      for (int j = threadIdx.y * kUnrollDim; j < param.in_feature_dim[i];\n           j += blockDim.y * kUnrollDim) {\n#pragma unroll\n        for (int k = 0; k < kUnrollDim; ++k) {\n          int in_row = j + k;\n          if (in_row >= param.in_feature_dim[i]) { break; }\n          int buf_row = param.dim_start_offset[i] + in_row;\n          Pack<T, pack_size> pack_in_val = batch_in[in_row * vector_num_pack + col];\n#pragma unroll\n          for (int t = 0; t < pack_size; ++t) {\n            pack_in_val.elem[t] = wmma.Convert(pack_in_val.elem[t]);\n   
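// note: wmma.Convert is an identity except for float on sm_80+, where it rounds to tf32.\n   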
       }\n          in_buf_pack[buf_row * in_shared_mem_cols_num_pack + col] = pack_in_val;\n        }\n      }\n    }\n    // load sparse feature to shared_mem\n    for (int j = threadIdx.y * kUnrollDim; j < param.sparse_dim; j += blockDim.y * kUnrollDim) {\n#pragma unroll\n      for (int k = 0; k < kUnrollDim; ++k) {\n        int in_row = j + k;\n        if (in_row >= param.sparse_dim) { break; }\n        int buf_row = param.sparse_dim_start + in_row;\n        int sparse_in_row = batch_sparse_indices[in_row];\n        Pack<T, pack_size> pack_in_val = sparse_feature_pack[sparse_in_row * vector_num_pack + col];\n#pragma unroll\n        for (int t = 0; t < pack_size; ++t) {\n          pack_in_val.elem[t] = wmma.Convert(pack_in_val.elem[t]);\n        }\n        in_buf_pack[buf_row * in_shared_mem_cols_num_pack + col] = pack_in_val;\n      }\n    }\n  }\n  Pack<T, pack_size> zero;\n#pragma unroll\n  for (int k = 0; k < pack_size; ++k) { zero.elem[k] = wmma.Convert(0); }\n#pragma unroll\n  for (int row = features_dim + threadIdx.y; row < padded_num_rows; row += blockDim.y) {\n    for (int col = threadIdx.x; col < padded_vector_num_pack; col += blockDim.x) {\n      in_buf_pack[row * in_shared_mem_cols_num_pack + col] = zero;\n    }\n  }\n  for (int row = threadIdx.y; row < features_dim; row += blockDim.y) {\n    for (int col = vector_num_pack + threadIdx.x; col < padded_vector_num_pack; col += blockDim.x) {\n      in_buf_pack[row * in_shared_mem_cols_num_pack + col] = zero;\n    }\n  }\n  __syncthreads();\n\n  for (int blocks_id = warp_id; blocks_id < m_num_tiles * n_num_tiles; blocks_id += blockDim.y) {\n    int blocks_row = blocks_id / n_num_tiles;\n    int blocks_col = blocks_id - blocks_row * n_num_tiles;\n    wmma.InitAcc();\n    for (int step = 0; step < k_num_tiles; ++step) {\n      // blocks_row is a row_id, step is a col_id. 
blocks_col is b col_id,\n      // step is b row_id.\n      T* tile_a_ptr = matrix_out_grad_buf\n                      + blocks_row * mn_tile_dim * matrix_out_grad_shared_mem_cols\n                      + step * k_tile_dim;\n      T* tile_b_ptr = in_buf + step * k_tile_dim * in_shared_mem_cols + blocks_col * mn_tile_dim;\n      wmma.LoadA(tile_a_ptr, matrix_out_grad_shared_mem_cols);\n      wmma.LoadB(tile_b_ptr, in_shared_mem_cols);\n      wmma.Mma();\n    }\n    ComputeType* tile_ptr =\n        in_grad_buf + blocks_row * mn_tile_dim * in_shared_mem_cols + blocks_col * mn_tile_dim;\n    wmma.Store(tile_ptr, in_shared_mem_cols);\n  }\n  __syncthreads();\n\n  // 4.split in_grad buf to dx\n  // shared_mem to dense dx\n  for (int col = threadIdx.x; col < vector_num_pack; col += blockDim.x) {\n#pragma unroll\n    for (int i = 0; i < max_in; ++i) {\n      if (i >= param.num_in) { break; }\n      Pack<T, pack_size>* batch_in_grad = reinterpret_cast<Pack<T, pack_size>*>(param.in_grad[i])\n                                          + batch_idx * param.in_feature_dim[i] * vector_num_pack;\n      for (int j = threadIdx.y * kUnrollDim; j < param.in_feature_dim[i];\n           j += blockDim.y * kUnrollDim) {\n#pragma unroll\n        for (int k = 0; k < kUnrollDim; ++k) {\n          int in_row = j + k;\n          if (in_row >= param.in_feature_dim[i]) { break; }\n          int buf_row = param.dim_start_offset[i] + in_row;\n          Pack<T, pack_size> grad_val;\n          Pack<ComputeType, pack_size> buf_grad_val =\n              in_grad_buf_pack[buf_row * in_shared_mem_cols_num_pack + col];\n#pragma unroll\n          for (int t = 0; t < pack_size; ++t) {\n            grad_val.elem[t] = static_cast<T>(buf_grad_val.elem[t]);\n          }\n          batch_in_grad[in_row * vector_num_pack + col] = grad_val;\n        }\n      }\n    }\n  }\n  // shared_mem to sparse dx, sparse in grad use sparse_grad_pack_size\n  Pack<ComputeType, sparse_grad_pack_size>* in_grad_buf_sparse_grad_pack =\n      reinterpret_cast<Pack<ComputeType, sparse_grad_pack_size>*>(in_grad_buf);\n  Pack<T, sparse_grad_pack_size>* sparse_feature_grad_pack =\n      reinterpret_cast<Pack<T, sparse_grad_pack_size>*>(param.sparse_feature_grad);\n  for (int col = threadIdx.x; col < vector_num_sparse_grad_pack; col += blockDim.x) {\n    for (int j = threadIdx.y * kUnrollDim; j < param.sparse_dim; j += blockDim.y * kUnrollDim) {\n#pragma unroll\n      for (int k = 0; k < kUnrollDim; ++k) {\n        int in_row = j + k;\n        if (in_row >= param.sparse_dim) { break; }\n        int buf_row = param.sparse_dim_start + in_row;\n        int sparse_in_row = batch_sparse_indices[in_row];\n        Pack<ComputeType, sparse_grad_pack_size> buf_grad_val =\n            in_grad_buf_sparse_grad_pack[buf_row * in_shared_mem_cols_num_sparse_grad_pack + col];\n        AtomicAdd<T, ComputeType, sparse_grad_pack_size>(\n            sparse_feature_grad_pack + sparse_in_row * vector_num_sparse_grad_pack + col,\n            buf_grad_val);\n      }\n    }\n  }\n\n#else\n  __trap();\n#endif  // __CUDA_ARCH__ >= 700\n}\n\ntemplate<typename T, int max_in, int32_t pack_size, int32_t sparse_grad_pack_size>\nstruct DotFeatureInteractionBackwardKernel {\n  static bool Launch(ep::Stream* stream, int64_t batch_size, int concated_padded_dim,\n                     int vector_size, int out_num_cols, bool self_interaction,\n                     const DotBwdParam<T, max_in>& param) {\n    const int block_size = 256;\n    const int block_dim_x = 32;\n    const int block_dim_y = 
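/* 8 warps per block */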
block_size / block_dim_x;\n    const int num_blocks = batch_size;\n    const int mn_tile_dim = 16;\n    const int k_tile_dim = KTileDim<T>::val;\n    const int64_t padded_vector_size = GetPaddedDim(vector_size);\n    const int m_num_tiles = concated_padded_dim / mn_tile_dim;\n    const int k_num_tiles = concated_padded_dim / k_tile_dim;\n    const int n_num_tiles = padded_vector_size / mn_tile_dim;\n    const int skew_in = 8;\n    const int in_shared_mem_num_cols = padded_vector_size + skew_in;\n    const int matrix_out_grad_shared_mem_cols = concated_padded_dim + skew_in;\n    const size_t in_shared_mem_bytes = concated_padded_dim * in_shared_mem_num_cols * sizeof(T);\n    const size_t matrix_out_grad_shared_mem_bytes =\n        concated_padded_dim * matrix_out_grad_shared_mem_cols * sizeof(T);\n    using ComputeType = typename DefaultComputeType<T>::type;\n    const size_t in_grad_shared_mem_bytes =\n        concated_padded_dim * in_shared_mem_num_cols * sizeof(ComputeType);\n    const size_t total_shared_mem_bytes =\n        in_shared_mem_bytes + matrix_out_grad_shared_mem_bytes + in_grad_shared_mem_bytes;\n    const int32_t offset = self_interaction ? 1 : 0;\n    const int vector_num_pack = vector_size / pack_size;\n    const int padded_vector_num_pack = padded_vector_size / pack_size;\n    const int in_shared_mem_cols_num_pack = in_shared_mem_num_cols / pack_size;\n    const int vector_num_sparse_grad_pack = vector_size / sparse_grad_pack_size;\n    const int in_shared_mem_cols_num_sparse_grad_pack =\n        in_shared_mem_num_cols / sparse_grad_pack_size;\n\n    int max_active_blocks;\n    OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &max_active_blocks,\n        DotFeatureInteractionBackwardWmmaImpl<T, ComputeType, max_in, pack_size,\n                                              sparse_grad_pack_size, mn_tile_dim, k_tile_dim>,\n        block_size, total_shared_mem_bytes));\n    if (max_active_blocks <= 0) { return false; }\n    cudaStream_t cuda_stream = stream->As<ep::CudaStream>()->cuda_stream();\n    DotFeatureInteractionBackwardWmmaImpl<T, ComputeType, max_in, pack_size, sparse_grad_pack_size,\n                                          mn_tile_dim, k_tile_dim>\n        <<<num_blocks, dim3(block_dim_x, block_dim_y), total_shared_mem_bytes, cuda_stream>>>(\n            m_num_tiles, n_num_tiles, k_num_tiles, batch_size, concated_padded_dim, vector_num_pack,\n            vector_num_sparse_grad_pack, padded_vector_num_pack, out_num_cols,\n            in_shared_mem_num_cols, in_shared_mem_cols_num_pack,\n            in_shared_mem_cols_num_sparse_grad_pack, matrix_out_grad_shared_mem_cols, offset,\n            param);\n\n    return true;\n  }\n};\n\ntemplate<typename T, size_t pack>\n__global__ void MemsetGpu(int64_t parallel_num, int64_t vector_size, const uint32_t* num_valid,\n                          T* dst) {\n  size_t count = 0;\n  for (int i = 0; i < parallel_num; ++i) { count += num_valid[i] * vector_size; }\n  const size_t pack_count = count / pack;\n  Pack<T, pack> pack_value;\n  for (int i = 0; i < pack; ++i) { pack_value.elem[i] = static_cast<T>(0); }\n  auto* pack_dst = reinterpret_cast<Pack<T, pack>*>(dst);\n  CUDA_1D_KERNEL_LOOP_T(size_t, i, pack_count) { pack_dst[i] = pack_value; }\n  T* tail_dst = dst + pack_count * pack;\n  const size_t tail_count = count - pack_count * pack;\n  CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = static_cast<T>(0); }\n}\n\ntemplate<typename T, size_t pack>\ntypename std::enable_if<(pack != 0), 
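/* SFINAE: selected when the pointer alignment admits a nonzero pack */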
void>::type LaunchPackMemsetGpu(cudaStream_t stream,\n                                                                     const uint32_t* num_valid,\n                                                                     T* ptr, size_t sm_count,\n                                                                     int64_t vector_size,\n                                                                     int64_t parallel_num) {\n  MemsetGpu<T, pack><<<2 * sm_count, 1024, 0, stream>>>(parallel_num, vector_size, num_valid, ptr);\n}\n\ntemplate<typename T, size_t pack>\ntypename std::enable_if<(pack == 0), void>::type LaunchPackMemsetGpu(cudaStream_t stream,\n                                                                     const uint32_t* num_valid,\n                                                                     T* ptr, size_t sm_count,\n                                                                     int64_t vector_size,\n                                                                     int64_t parallel_num) {\n  LOG(FATAL) << \"wrong alignment\";\n}\n\ntemplate<typename T>\nvoid LaunchMemset(cudaStream_t stream, size_t sm_count, int64_t vector_size, int64_t parallel_num,\n                  const uint32_t* num_valid, T* ptr) {\n  auto uintptr = reinterpret_cast<std::uintptr_t>(ptr);\n  if (uintptr % 16 == 0) {\n    LaunchPackMemsetGpu<T, 16 / sizeof(T)>(stream, num_valid, ptr, sm_count, vector_size,\n                                           parallel_num);\n  } else if (uintptr % 8 == 0) {\n    LaunchPackMemsetGpu<T, 8 / sizeof(T)>(stream, num_valid, ptr, sm_count, vector_size,\n                                          parallel_num);\n  } else if (uintptr % 4 == 0) {\n    LaunchPackMemsetGpu<T, 4 / sizeof(T)>(stream, num_valid, ptr, sm_count, vector_size,\n                                          parallel_num);\n  } else if (uintptr % 2 == 0) {\n    LaunchPackMemsetGpu<T, 2 / sizeof(T)>(stream, num_valid, ptr, sm_count, vector_size,\n                                          parallel_num);\n  } else {\n    LaunchPackMemsetGpu<T, 1 / sizeof(T)>(stream, num_valid, ptr, sm_count, vector_size,\n                                          parallel_num);\n  }\n}\n\ntemplate<typename T, int max_in>\nbool DispatchFeatureInteractionDotPackSize(user_op::KernelComputeContext* ctx,\n                                           const int32_t input_size) {\n  CHECK_LE(input_size, max_in) << input_size;\n  user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n  const int64_t batch_size = out->shape_view().At(0);\n  const int64_t out_num_cols = out->shape_view().At(1);\n  const int64_t vector_size = ctx->TensorDesc4ArgNameAndIndex(\"features\", 0)->shape().At(2);\n  DotFwdParam<T, max_in> param;\n  param.num_in = input_size;\n  param.out = out->mut_dptr<T>();\n  int64_t features_concated_dim = 0;\n  for (int i = 0; i < input_size; ++i) {\n    param.in[i] = ctx->Tensor4ArgNameAndIndex(\"features\", i)->dptr<T>();\n    param.in_feature_dim[i] = ctx->TensorDesc4ArgNameAndIndex(\"features\", i)->shape().At(1);\n    param.dim_start_offset[i] = features_concated_dim;\n    features_concated_dim += param.in_feature_dim[i];\n  }\n  if (ctx->has_input(\"sparse_feature\", 0)) {\n    CHECK(ctx->has_input(\"sparse_indices\", 0));\n    const user_op::Tensor* sparse_feature = ctx->Tensor4ArgNameAndIndex(\"sparse_feature\", 0);\n    const user_op::Tensor* sparse_indices = ctx->Tensor4ArgNameAndIndex(\"sparse_indices\", 0);\n    param.sparse_feature = sparse_feature->dptr<T>();\n    
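// note: each sample gathers its sparse_dim rows of sparse_feature via these uint32 indices.\n    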
CHECK_EQ(sparse_indices->data_type(), DataType::kUInt32);\n    param.sparse_indices = reinterpret_cast<const uint32_t*>(sparse_indices->dptr());\n    param.sparse_dim = ctx->TensorDesc4ArgNameAndIndex(\"sparse_indices\", 0)->shape().At(1);\n    param.sparse_dim_start = features_concated_dim;\n    features_concated_dim += param.sparse_dim;\n  } else {\n    param.sparse_feature = nullptr;\n    param.sparse_indices = nullptr;\n    param.sparse_dim = 0;\n    param.sparse_dim_start = 0;\n  }\n  const int64_t concated_padded_dim = GetPaddedDim(features_concated_dim);\n  param.features_dim = features_concated_dim;\n  if (ctx->has_input(\"output_concat\", 0)) {\n    const user_op::Tensor* output_concat = ctx->Tensor4ArgNameAndIndex(\"output_concat\", 0);\n    param.output_concat = output_concat->dptr<T>();\n    param.output_concat_size = output_concat->shape_view().At(1);\n  } else {\n    param.output_concat = nullptr;\n    param.output_concat_size = 0;\n  }\n  const bool self_interaction = ctx->Attr<bool>(\"self_interaction\");\n  const int32_t output_padding = ctx->Attr<int32_t>(\"output_padding\");\n  if (vector_size % 4 == 0 && out_num_cols % 4 == 0) {\n    return DotFeatureInteractionKernel<T, max_in, 4>::Launch(\n        ctx->stream(), batch_size, concated_padded_dim, vector_size, out_num_cols, self_interaction,\n        output_padding, param);\n  } else if (vector_size % 2 == 0 && out_num_cols % 2 == 0) {\n    return DotFeatureInteractionKernel<T, max_in, 2>::Launch(\n        ctx->stream(), batch_size, concated_padded_dim, vector_size, out_num_cols, self_interaction,\n        output_padding, param);\n  } else {\n    return DotFeatureInteractionKernel<T, max_in, 1>::Launch(\n        ctx->stream(), batch_size, concated_padded_dim, vector_size, out_num_cols, self_interaction,\n        output_padding, param);\n  }\n}\n\ntemplate<typename T, int max_in>\nbool DispatchFeatureInteractionDotBackwardPackSize(user_op::KernelComputeContext* ctx,\n                                                   const int32_t input_size) {\n  CHECK_LE(input_size, max_in) << input_size;\n  user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n  const int64_t batch_size = dy->shape_view().At(0);\n  const int64_t out_num_cols = dy->shape_view().At(1);\n  const int64_t vector_size = ctx->TensorDesc4ArgNameAndIndex(\"features\", 0)->shape().At(2);\n  DotBwdParam<T, max_in> param;\n  param.num_in = input_size;\n  param.out_grad = dy->dptr<T>();\n  int64_t features_concated_dim = 0;\n  for (int i = 0; i < input_size; ++i) {\n    param.in[i] = ctx->Tensor4ArgNameAndIndex(\"features\", i)->dptr<T>();\n    param.in_grad[i] = ctx->Tensor4ArgNameAndIndex(\"features_grad\", i)->mut_dptr<T>();\n    param.in_feature_dim[i] = ctx->TensorDesc4ArgNameAndIndex(\"features\", i)->shape().At(1);\n    param.dim_start_offset[i] = features_concated_dim;\n    features_concated_dim += param.in_feature_dim[i];\n  }\n  if (ctx->has_input(\"sparse_feature\", 0)) {\n    CHECK(ctx->has_input(\"sparse_indices\", 0));\n    CHECK(ctx->has_input(\"num_valid_sparse_feature\", 0));\n    CHECK(ctx->has_output(\"sparse_feature_grad\", 0));\n    const user_op::Tensor* sparse_feature = ctx->Tensor4ArgNameAndIndex(\"sparse_feature\", 0);\n    const user_op::Tensor* sparse_indices = ctx->Tensor4ArgNameAndIndex(\"sparse_indices\", 0);\n    const user_op::Tensor* num_valid_sparse_feature =\n        ctx->Tensor4ArgNameAndIndex(\"num_valid_sparse_feature\", 0);\n    param.sparse_feature = sparse_feature->dptr<T>();\n    
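// note: the kernel scatter-adds sparse grads, so sparse_feature_grad is zeroed first via LaunchMemset below.\n    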
CHECK_EQ(sparse_indices->data_type(), DataType::kUInt32);\n    param.sparse_indices = reinterpret_cast<const uint32_t*>(sparse_indices->dptr());\n    param.sparse_dim = ctx->TensorDesc4ArgNameAndIndex(\"sparse_indices\", 0)->shape().At(1);\n    param.sparse_dim_start = features_concated_dim;\n    features_concated_dim += param.sparse_dim;\n    param.sparse_feature_grad =\n        ctx->Tensor4ArgNameAndIndex(\"sparse_feature_grad\", 0)->mut_dptr<T>();\n    const int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n    CHECK_EQ(num_valid_sparse_feature->data_type(), DataType::kUInt32);\n    LaunchMemset<T>(ctx->stream()->As<ep::CudaStream>()->cuda_stream(),\n                    ctx->stream()->As<ep::CudaStream>()->device_properties().multiProcessorCount,\n                    vector_size, parallel_num,\n                    reinterpret_cast<const uint32_t*>(num_valid_sparse_feature->dptr())\n                        + parallel_id * parallel_num,\n                    param.sparse_feature_grad);\n  } else {\n    param.sparse_feature = nullptr;\n    param.sparse_indices = nullptr;\n    param.sparse_feature_grad = nullptr;\n    param.sparse_dim = 0;\n    param.sparse_dim_start = 0;\n  }\n  const int64_t concated_padded_dim = GetPaddedDim(features_concated_dim);\n  param.features_dim = features_concated_dim;\n  if (ctx->has_output(\"output_concat_grad\", 0)) {\n    user_op::Tensor* output_concat_grad = ctx->Tensor4ArgNameAndIndex(\"output_concat_grad\", 0);\n    param.output_concat_grad = output_concat_grad->mut_dptr<T>();\n    param.output_concat_size = output_concat_grad->shape_view().At(1);\n  } else {\n    param.output_concat_grad = nullptr;\n    param.output_concat_size = 0;\n  }\n  const bool self_interaction = ctx->Attr<bool>(\"self_interaction\");\n  if (vector_size % 4 == 0) {\n    return DotFeatureInteractionBackwardKernel<T, max_in, 4, 2>::Launch(\n        ctx->stream(), batch_size, concated_padded_dim, vector_size, out_num_cols, self_interaction,\n        param);\n  } else if (vector_size % 2 == 0) {\n    return DotFeatureInteractionBackwardKernel<T, max_in, 2, 2>::Launch(\n        ctx->stream(), batch_size, concated_padded_dim, vector_size, out_num_cols, self_interaction,\n        param);\n  } else {\n    if (ctx->has_input(\"sparse_feature\", 0) && dy->data_type() == DataType::kFloat16) {\n      UNIMPLEMENTED()\n          << \"fused dot interaction backward kernel does not support sparse_feature with pack_size 1, \"\n             \"because atomicAdd(half) is too slow\";\n      return false;\n    }\n    return DotFeatureInteractionBackwardKernel<T, max_in, 1, 1>::Launch(\n        ctx->stream(), batch_size, concated_padded_dim, vector_size, out_num_cols, self_interaction,\n        param);\n  }\n}\n\ntemplate<typename T, int32_t max_in>\nstruct Param {\n  const T* in[max_in];\n  int32_t in_feature_dim[max_in];\n  T* out;\n  int32_t num_in;\n};\n\ntemplate<typename T, int32_t max_in, int32_t pack_size>\n__global__ void FeatureInteractionSum(int64_t batch_size, int64_t vector_num_pack,\n                                      Param<T, max_in> param) {\n  using ComputeType = typename DefaultComputeType<T>::type;\n  Pack<T, pack_size>* dst_pack = reinterpret_cast<Pack<T, pack_size>*>(param.out);\n  for (int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; batch_idx < batch_size;\n       batch_idx += gridDim.x * blockDim.y) {\n    Pack<T, pack_size>* batch_out = dst_pack + batch_idx * vector_num_pack;\n    for (int col_id = 
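/* threads stride the packed vector dim */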
threadIdx.x; col_id < vector_num_pack; col_id += blockDim.x) {\n      Pack<ComputeType, pack_size> sum;\n      Pack<ComputeType, pack_size> square_sum;\n#pragma unroll\n      for (int k = 0; k < pack_size; ++k) {\n        sum.elem[k] = static_cast<ComputeType>(0);\n        square_sum.elem[k] = static_cast<ComputeType>(0);\n      }\n      for (int i = 0; i < max_in; ++i) {\n        if (i >= param.num_in) { break; }\n        const Pack<T, pack_size>* batch_in =\n            reinterpret_cast<const Pack<T, pack_size>*>(param.in[i])\n            + batch_idx * param.in_feature_dim[i] * vector_num_pack;\n#pragma unroll\n        for (int j = 0; j < param.in_feature_dim[i]; ++j) {\n          Pack<T, pack_size> val = batch_in[j * vector_num_pack + col_id];\n#pragma unroll\n          for (int k = 0; k < pack_size; ++k) {\n            const ComputeType compute_val = static_cast<ComputeType>(val.elem[k]);\n            sum.elem[k] += compute_val;\n            square_sum.elem[k] += compute_val * compute_val;\n          }\n        }\n      }\n      Pack<T, pack_size> out;\n#pragma unroll\n      for (int k = 0; k < pack_size; ++k) {\n        out.elem[k] = static_cast<T>((sum.elem[k] * sum.elem[k] - square_sum.elem[k])\n                                     * static_cast<ComputeType>(0.5));\n      }\n      batch_out[col_id] = out;\n    }\n  }\n}\n\ntemplate<typename T, int32_t max_in>\nstruct GradParam {\n  const T* out_grad;\n  const T* in[max_in];\n  int32_t in_feature_dim[max_in];\n  T* in_grad[max_in];\n  int32_t num_in;\n};\n\ntemplate<typename T, int32_t max_in>\n__global__ void FeatureInteractionSumGrad(int64_t batch_size, int64_t vector_size,\n                                          GradParam<T, max_in> param) {\n  using ComputeType = typename DefaultComputeType<T>::type;\n  for (int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; batch_idx < batch_size;\n       batch_idx += gridDim.x * blockDim.y) {\n    const T* batch_out_grad = param.out_grad + batch_idx * vector_size;\n    for (int col_id = threadIdx.x; col_id < vector_size; col_id += blockDim.x) {\n      ComputeType sum = 0;\n      for (int i = 0; i < max_in; ++i) {\n        if (i >= param.num_in) { break; }\n        const T* batch_in = param.in[i] + batch_idx * param.in_feature_dim[i] * vector_size;\n        for (int j = 0; j < param.in_feature_dim[i]; ++j) {\n          sum += static_cast<ComputeType>(batch_in[j * vector_size + col_id]);\n        }\n      }\n      for (int i = 0; i < max_in; ++i) {\n        if (i >= param.num_in) { break; }\n        const int64_t in_batch_offset = batch_idx * param.in_feature_dim[i] * vector_size;\n        const T* batch_in = param.in[i] + in_batch_offset;\n        T* batch_in_grad = param.in_grad[i] + in_batch_offset;\n        for (int j = 0; j < param.in_feature_dim[i]; ++j) {\n          const int64_t offset = j * vector_size + col_id;\n          batch_in_grad[offset] =\n              static_cast<T>(static_cast<ComputeType>(batch_out_grad[col_id])\n                             * (sum - static_cast<ComputeType>(batch_in[offset])));\n        }\n      }\n    }\n  }\n}\n\nvoid GetBlockDims(const int64_t vector_size, int* block_dim_x, int* block_dim_y) {\n  const int block_size = 256;\n  if (vector_size < block_size) {\n    *block_dim_x = std::ceil(static_cast<float>(vector_size) / 8) * 8;\n    *block_dim_y = (block_size + *block_dim_x - 1) / *block_dim_x;\n  } else {\n    *block_dim_x = block_size;\n    *block_dim_y = 1;\n  }\n}\n\nint GetNumBlocks(const int64_t num_instances, const int64_t 
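/* batch rows per block */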
instance_per_block) {\n  int max_blocks = (num_instances + instance_per_block - 1) / instance_per_block;\n  return std::min(max_blocks, kCudaMaxBlocksNum);\n}\n\ntemplate<typename T, int32_t max_in>\nvoid DispatchFeatureInteractionSumPackSize(ep::Stream* stream, const int64_t batch_size,\n                                           const int64_t vector_size,\n                                           const Param<T, max_in>& param) {\n  int block_dim_x;\n  int block_dim_y;\n  const int pack_size = (vector_size % 2 == 0) ? 2 : 1;\n  const int64_t vector_num_pack = vector_size / pack_size;\n  GetBlockDims(vector_num_pack, &block_dim_x, &block_dim_y);\n  const int num_blocks = GetNumBlocks(batch_size, block_dim_y);\n  dim3 block_dims = dim3(block_dim_x, block_dim_y);\n  cudaStream_t cuda_stream = stream->As<ep::CudaStream>()->cuda_stream();\n  if (pack_size == 2) {\n    FeatureInteractionSum<T, max_in, 2>\n        <<<num_blocks, block_dims, 0, cuda_stream>>>(batch_size, vector_num_pack, param);\n  } else {\n    FeatureInteractionSum<T, max_in, 1>\n        <<<num_blocks, block_dims, 0, cuda_stream>>>(batch_size, vector_num_pack, param);\n  }\n}\n\ntemplate<typename T, int max_in>\nvoid DispatchFeatureInteractionSumInputSize(user_op::KernelComputeContext* ctx,\n                                            const int32_t input_size) {\n  CHECK_LE(input_size, max_in) << input_size;\n  user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n  const int64_t batch_size = out->shape_view().At(0);\n  const int64_t vector_size = out->shape_view().At(1);\n  Param<T, max_in> param;\n  param.num_in = input_size;\n  param.out = out->mut_dptr<T>();\n  for (int i = 0; i < input_size; ++i) {\n    param.in[i] = ctx->Tensor4ArgNameAndIndex(\"features\", i)->dptr<T>();\n    param.in_feature_dim[i] = ctx->TensorDesc4ArgNameAndIndex(\"features\", i)->shape().At(1);\n  }\n  DispatchFeatureInteractionSumPackSize<T, max_in>(ctx->stream(), batch_size, vector_size, param);\n}\n\ntemplate<typename T, int max_in>\nvoid DispatchFeatureInteractionSumGradInputSize(user_op::KernelComputeContext* ctx,\n                                                const int32_t input_size) {\n  CHECK_LE(input_size, max_in) << input_size;\n  const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n  const int64_t batch_size = dy->shape_view().At(0);\n  const int64_t vector_size = dy->shape_view().At(1);\n  int block_dim_x;\n  int block_dim_y;\n  GetBlockDims(vector_size, &block_dim_x, &block_dim_y);\n  const int num_blocks = GetNumBlocks(batch_size, block_dim_y);\n  dim3 block_dims = dim3(block_dim_x, block_dim_y);\n  GradParam<T, max_in> param;\n  param.num_in = input_size;\n  param.out_grad = dy->dptr<T>();\n  for (int i = 0; i < input_size; ++i) {\n    param.in[i] = ctx->Tensor4ArgNameAndIndex(\"features\", i)->dptr<T>();\n    param.in_grad[i] = ctx->Tensor4ArgNameAndIndex(\"features_grad\", i)->mut_dptr<T>();\n    param.in_feature_dim[i] = ctx->TensorDesc4ArgNameAndIndex(\"features_grad\", i)->shape().At(1);\n  }\n  FeatureInteractionSumGrad<T, max_in>\n      <<<num_blocks, block_dims, 0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          batch_size, vector_size, param);\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass FusedDotFeatureInteractionPoolingSumKernel final : public user_op::OpKernel,\n                                                         public user_op::CudaGraphSupport {\n public:\n  FusedDotFeatureInteractionPoolingSumKernel() = default;\n  ~FusedDotFeatureInteractionPoolingSumKernel() 
override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    CHECK(!ctx->has_input(\"sparse_feature\", 0)) << \"sparse_feature is not supported when pooling is sum. \";\n    const int input_size = ctx->input_size(\"features\");\n    if (input_size == 1) {\n      DispatchFeatureInteractionSumInputSize<T, 1>(ctx, input_size);\n    } else if (input_size == 2) {\n      DispatchFeatureInteractionSumInputSize<T, 2>(ctx, input_size);\n    } else if (input_size <= 8) {\n      DispatchFeatureInteractionSumInputSize<T, 8>(ctx, input_size);\n    } else {\n      CHECK_LE(input_size, 128) << \"input_size must not be greater than 128. \";\n      DispatchFeatureInteractionSumInputSize<T, 128>(ctx, input_size);\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_DOT_FEATURE_INTERACTION_POOLING_SUM_KERNEL(dtype)                \\\n  REGISTER_USER_KERNEL(\"fused_dot_feature_interaction\")                                 \\\n      .SetCreateFn<FusedDotFeatureInteractionPoolingSumKernel<dtype>>()                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobAttr<std::string>(\"pooling\") == \"sum\"));\n\nREGISTER_FUSED_DOT_FEATURE_INTERACTION_POOLING_SUM_KERNEL(float)\nREGISTER_FUSED_DOT_FEATURE_INTERACTION_POOLING_SUM_KERNEL(half)\n\ntemplate<typename T>\nbool TryLaunchTensorCoreDotKernel(user_op::KernelComputeContext* ctx) {\n  const int input_size = ctx->input_size(\"features\");\n  if (input_size == 1) {\n    return DispatchFeatureInteractionDotPackSize<T, 1>(ctx, input_size);\n  } else if (input_size == 2) {\n    return DispatchFeatureInteractionDotPackSize<T, 2>(ctx, input_size);\n  } else if (input_size <= 8) {\n    return DispatchFeatureInteractionDotPackSize<T, 8>(ctx, input_size);\n  } else {\n    CHECK_LE(input_size, 128) << \"input_size must not be greater than 128. \";\n    return DispatchFeatureInteractionDotPackSize<T, 128>(ctx, input_size);\n  }\n}\n\ntemplate<typename T>\nbool TryLaunchTensorCoreDotBackwardKernel(user_op::KernelComputeContext* ctx) {\n  const int input_size = ctx->input_size(\"features\");\n  if (input_size == 1) {\n    return DispatchFeatureInteractionDotBackwardPackSize<T, 1>(ctx, input_size);\n  } else if (input_size == 2) {\n    return DispatchFeatureInteractionDotBackwardPackSize<T, 2>(ctx, input_size);\n  } else if (input_size <= 8) {\n    return DispatchFeatureInteractionDotBackwardPackSize<T, 8>(ctx, input_size);\n  } else {\n    CHECK_LE(input_size, 128) << \"input_size must not be greater than 128. 
\";\n    return DispatchFeatureInteractionDotBackwardPackSize<T, 128>(ctx, input_size);\n  }\n}\ntemplate<typename T>\nclass FusedDotFeatureInteractionKernel final : public user_op::OpKernel,\n                                               public user_op::CudaGraphSupport {\n public:\n  FusedDotFeatureInteractionKernel() = default;\n  ~FusedDotFeatureInteractionKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const DataType data_type = out->data_type();\n    CHECK_LT(out->shape_view().elem_cnt(), GetMaxVal<int32_t>());\n    auto* cuda_stream = ctx->stream()->As<ep::CudaStream>();\n    if ((cuda_stream->device_properties().major >= 7 && data_type == DataType::kFloat16)\n        || (cuda_stream->device_properties().major >= 8 && data_type == DataType::kFloat)) {\n      bool success = TryLaunchTensorCoreDotKernel<T>(ctx);\n      if (success == true) { return; }\n    }\n    CHECK(!ctx->has_input(\"sparse_feature\", 0)) << \"sparse_feature is not supported. \";\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const int64_t batch_size = out->shape_view().At(0);\n    int64_t features_concated_dim = 0;\n    for (int64_t i = 0; i < ctx->input_size(\"features\"); ++i) {\n      features_concated_dim += ctx->TensorDesc4ArgNameAndIndex(\"features\", i)->shape().At(1);\n    }\n    const int64_t concated_padded_dim = GetPaddedDim(features_concated_dim);\n    const int64_t vector_size = ctx->TensorDesc4ArgNameAndIndex(\"features\", 0)->shape().At(2);\n    const int64_t out_dim = out->shape_view().At(1);\n    const int32_t output_padding = ctx->Attr<int32_t>(\"output_padding\");\n    const int64_t valid_out_dim = out_dim - output_padding;\n    const bool self_interaction = ctx->Attr<bool>(\"self_interaction\");\n\n    T* matmul_out = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>());\n    size_t matmul_out_size =\n        GetCudaAlignedSize(batch_size * concated_padded_dim * concated_padded_dim * sizeof(T));\n    const int64_t interaction_dim = self_interaction\n                                        ? 
features_concated_dim * (features_concated_dim + 1) / 2\n                                        : features_concated_dim * (features_concated_dim - 1) / 2;\n    int32_t* gather_indices_ptr =\n        reinterpret_cast<int32_t*>(tmp_buffer->mut_dptr<char>() + matmul_out_size);\n    size_t gather_indices_size = GetCudaAlignedSize(interaction_dim * sizeof(int32_t));\n    T* padded_concated_features_ptr =\n        reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + matmul_out_size + gather_indices_size);\n    size_t padded_concated_features_size =\n        GetCudaAlignedSize(batch_size * concated_padded_dim * vector_size * sizeof(T));\n    CHECK_GE(tmp_buffer->shape_view().elem_cnt(),\n             matmul_out_size + gather_indices_size + padded_concated_features_size);\n    ConcatFeatures<T>(ctx, batch_size, concated_padded_dim * vector_size,\n                      padded_concated_features_ptr);\n    auto batch_matmul = ep::primitive::NewPrimitive<ep::primitive::BatchMatmulFactory>(\n        ctx->device_type(), data_type, ep::primitive::BlasTransposeType::N,\n        ep::primitive::BlasTransposeType::T);\n    batch_matmul->Launch(ctx->stream(), batch_size, concated_padded_dim, concated_padded_dim,\n                         vector_size, 1.0, padded_concated_features_ptr,\n                         padded_concated_features_ptr, 0.0, matmul_out);\n\n    int64_t output_concat_end_dim = 0;\n    const T* output_concat_ptr = nullptr;\n    if (ctx->has_input(\"output_concat\", 0)) {\n      user_op::Tensor* output_concat = ctx->Tensor4ArgNameAndIndex(\"output_concat\", 0);\n      output_concat_end_dim = output_concat->shape_view().At(1);\n      output_concat_ptr = output_concat->dptr<T>();\n    }\n    CHECK_EQ(valid_out_dim, output_concat_end_dim + interaction_dim);\n    GatherConcatKernel<T>(ctx->stream(), out->shape_view().elem_cnt(), out_dim, valid_out_dim,\n                          features_concated_dim, concated_padded_dim, output_concat_end_dim,\n                          self_interaction, matmul_out, output_concat_ptr, gather_indices_ptr,\n                          out->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nuser_op::InferTmpSizeFn GenFusedDotFeatureInteractionInferTmpSizeFn() {\n  return [](user_op::InferContext* ctx) {\n    const Shape& first_feature_shape = ctx->InputShape(\"features\", 0);\n    const int64_t batch_size = first_feature_shape.At(0);\n    const int64_t vector_size = first_feature_shape.At(2);\n    int64_t features_concated_dim = 0;\n    for (int32_t i = 0; i < ctx->input_size(\"features\"); ++i) {\n      features_concated_dim += ctx->InputShape(\"features\", i).At(1);\n    }\n    const int64_t concated_padded_dim = GetPaddedDim(features_concated_dim);\n    size_t matmul_out_size =\n        GetCudaAlignedSize(batch_size * concated_padded_dim * concated_padded_dim * sizeof(T));\n    const bool self_interaction = ctx->Attr<bool>(\"self_interaction\");\n    const int64_t interaction_dim = self_interaction\n                                        ? 
features_concated_dim * (features_concated_dim + 1) / 2\n                                        : features_concated_dim * (features_concated_dim - 1) / 2;\n    size_t gather_indices_size = GetCudaAlignedSize(interaction_dim * sizeof(int32_t));\n    size_t padded_concated_features_size =\n        GetCudaAlignedSize(batch_size * concated_padded_dim * vector_size * sizeof(T));\n    return matmul_out_size + gather_indices_size + padded_concated_features_size;\n  };\n}\n\n#define REGISTER_FUSED_DOT_FEATURE_INTERACTION_KERNEL(dtype)                            \\\n  REGISTER_USER_KERNEL(\"fused_dot_feature_interaction\")                                 \\\n      .SetCreateFn<FusedDotFeatureInteractionKernel<dtype>>()                           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobAttr<std::string>(\"pooling\") == \"none\"))         \\\n      .SetInferTmpSizeFn(GenFusedDotFeatureInteractionInferTmpSizeFn<dtype>());\n\nREGISTER_FUSED_DOT_FEATURE_INTERACTION_KERNEL(float)\nREGISTER_FUSED_DOT_FEATURE_INTERACTION_KERNEL(half)\n\ntemplate<typename T>\nclass FusedDotFeatureInteractionGradKernel final : public user_op::OpKernel,\n                                                   public user_op::CudaGraphSupport {\n public:\n  FusedDotFeatureInteractionGradKernel() = default;\n  ~FusedDotFeatureInteractionGradKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const DataType data_type = dy->data_type();\n    auto* cuda_stream = ctx->stream()->As<ep::CudaStream>();\n    if ((cuda_stream->device_properties().major >= 7 && data_type == DataType::kFloat16)\n        || (cuda_stream->device_properties().major >= 8 && data_type == DataType::kFloat)) {\n      bool success = TryLaunchTensorCoreDotBackwardKernel<T>(ctx);\n      if (success == true) { return; }\n    }\n    CHECK(!ctx->has_input(\"sparse_feature\", 0)) << \"sparse_feature is not supported. 
\";\n    const int64_t batch_size = dy->shape_view().At(0);\n    int64_t features_concated_dim = 0;\n    for (int32_t i = 0; i < ctx->output_size(\"features_grad\"); ++i) {\n      features_concated_dim += ctx->TensorDesc4ArgNameAndIndex(\"features_grad\", i)->shape().At(1);\n    }\n    const int64_t concated_padded_dim = GetPaddedDim(features_concated_dim);\n    const int64_t vector_size = ctx->TensorDesc4ArgNameAndIndex(\"features_grad\", 0)->shape().At(2);\n    const int64_t out_dim = dy->shape_view().At(1);\n    const bool self_interaction = ctx->Attr<bool>(\"self_interaction\");\n    T* matmul_out_grad_ptr = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>());\n    size_t matmul_out_grad_size =\n        GetCudaAlignedSize(batch_size * concated_padded_dim * concated_padded_dim * sizeof(T));\n    T* padded_concated_features_grad_ptr =\n        reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + matmul_out_grad_size);\n    size_t padded_concated_features_grad_size =\n        GetCudaAlignedSize(batch_size * concated_padded_dim * vector_size * sizeof(T));\n    T* padded_concated_features_ptr = reinterpret_cast<T*>(\n        tmp_buffer->mut_dptr<char>() + matmul_out_grad_size + padded_concated_features_grad_size);\n    size_t padded_concated_features_size = padded_concated_features_grad_size;\n    CHECK_LE(\n        matmul_out_grad_size + padded_concated_features_grad_size + padded_concated_features_size,\n        tmp_buffer->shape_view().elem_cnt());\n    ConcatFeatures<T>(ctx, batch_size, concated_padded_dim * vector_size,\n                      padded_concated_features_ptr);\n\n    T* output_concat_grad_ptr = nullptr;\n    int64_t output_concat_end_dim = 0;\n    if (ctx->has_output(\"output_concat_grad\", 0)) {\n      user_op::Tensor* output_concat_grad = ctx->Tensor4ArgNameAndIndex(\"output_concat_grad\", 0);\n      output_concat_grad_ptr = output_concat_grad->mut_dptr<T>();\n      output_concat_end_dim = output_concat_grad->shape_view().At(1);\n    }\n    ScatterSplitAddTranspose(ctx->stream(), batch_size, out_dim, concated_padded_dim,\n                             features_concated_dim, output_concat_end_dim, self_interaction,\n                             dy->dptr<T>(), output_concat_grad_ptr, matmul_out_grad_ptr);\n\n    auto batch_matmul = ep::primitive::NewPrimitive<ep::primitive::BatchMatmulFactory>(\n        ctx->device_type(), data_type, ep::primitive::BlasTransposeType::N,\n        ep::primitive::BlasTransposeType::N);\n    batch_matmul->Launch(ctx->stream(), batch_size, concated_padded_dim, vector_size,\n                         concated_padded_dim, 1.0, matmul_out_grad_ptr,\n                         padded_concated_features_ptr, 0.0, padded_concated_features_grad_ptr);\n\n    ConcatFeaturesGrad(ctx, batch_size, concated_padded_dim, vector_size,\n                       padded_concated_features_grad_ptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nuser_op::InferTmpSizeFn GenFusedDotFeatureInteractionGradInferTmpSizeFn() {\n  return [](user_op::InferContext* ctx) {\n    int64_t features_concated_dim = 0;\n    for (int32_t i = 0; i < ctx->output_size(\"features_grad\"); ++i) {\n      features_concated_dim += ctx->InputShape(\"features_grad\", i).At(1);\n    }\n    const int64_t concated_padded_dim = GetPaddedDim(features_concated_dim);\n    const int64_t batch_size = ctx->InputShape(\"features_grad\", 0).At(0);\n    const int64_t vector_size = ctx->InputShape(\"features_grad\", 0).At(2);\n    size_t 
matmul_out_grad_size =\n        GetCudaAlignedSize(batch_size * concated_padded_dim * concated_padded_dim * sizeof(T));\n    size_t padded_concated_features_grad_size =\n        GetCudaAlignedSize(batch_size * concated_padded_dim * vector_size * sizeof(T));\n    size_t padded_concated_features_size = padded_concated_features_grad_size;\n    return matmul_out_grad_size + padded_concated_features_grad_size\n           + padded_concated_features_size;\n  };\n}\n\n#define REGISTER_FUSED_DOT_FEATURE_INTERACTION_GRAD_KERNEL(dtype)                      \\\n  REGISTER_USER_KERNEL(\"fused_dot_feature_interaction_grad\")                           \\\n      .SetCreateFn<FusedDotFeatureInteractionGradKernel<dtype>>()                      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                 \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobAttr<std::string>(\"pooling\") == \"none\"))        \\\n      .SetInferTmpSizeFn(GenFusedDotFeatureInteractionGradInferTmpSizeFn<dtype>());\n\nREGISTER_FUSED_DOT_FEATURE_INTERACTION_GRAD_KERNEL(float)\nREGISTER_FUSED_DOT_FEATURE_INTERACTION_GRAD_KERNEL(half)\n\ntemplate<typename T>\nclass FusedDotFeatureInteractionPoolingSumGradKernel final : public user_op::OpKernel,\n                                                             public user_op::CudaGraphSupport {\n public:\n  FusedDotFeatureInteractionPoolingSumGradKernel() = default;\n  ~FusedDotFeatureInteractionPoolingSumGradKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const int input_size = ctx->input_size(\"features\");\n    if (input_size == 1) {\n      DispatchFeatureInteractionSumGradInputSize<T, 1>(ctx, input_size);\n    } else if (input_size == 2) {\n      DispatchFeatureInteractionSumGradInputSize<T, 2>(ctx, input_size);\n    } else if (input_size <= 8) {\n      DispatchFeatureInteractionSumGradInputSize<T, 8>(ctx, input_size);\n    } else {\n      CHECK_LE(input_size, 128) << \"input_size must not be greater than 128. \";\n      DispatchFeatureInteractionSumGradInputSize<T, 128>(ctx, input_size);\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_DOT_FEATURE_INTERACTION_POOLING_SUM_GRAD_KERNEL(dtype)          \\\n  REGISTER_USER_KERNEL(\"fused_dot_feature_interaction_grad\")                           \\\n      .SetCreateFn<FusedDotFeatureInteractionPoolingSumGradKernel<dtype>>()            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                 \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobAttr<std::string>(\"pooling\") == \"sum\"));\n\nREGISTER_FUSED_DOT_FEATURE_INTERACTION_POOLING_SUM_GRAD_KERNEL(float)\nREGISTER_FUSED_DOT_FEATURE_INTERACTION_POOLING_SUM_GRAD_KERNEL(half)\n\n}  // namespace oneflow\n"
  },
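The `interaction_dim` used in the kernel above counts the entries kept from the batch Gram matrix X·Xᵀ: the lower triangle including the diagonal when `self_interaction` is true (n(n+1)/2 entries), otherwise the strictly lower triangle (n(n-1)/2 entries). Below is a minimal host-side sketch of gather indices consistent with that count, assuming a row-major n×n matmul output and ignoring the padding to `concated_padded_dim` for brevity; `MakeGatherIndices` is a hypothetical helper name, not part of the file above.

#include <cstdint>
#include <vector>

// One offset into a row-major n x n matmul output per kept interaction pair.
std::vector<int32_t> MakeGatherIndices(int64_t n, bool self_interaction) {
  std::vector<int32_t> indices;
  for (int64_t row = 0; row < n; ++row) {
    const int64_t col_end = self_interaction ? row + 1 : row;  // include the diagonal or not
    for (int64_t col = 0; col < col_end; ++col) {
      indices.push_back(static_cast<int32_t>(row * n + col));
    }
  }
  // indices.size() equals n*(n+1)/2 or n*(n-1)/2, matching interaction_dim.
  return indices;
}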
  {
    "path": "oneflow/user/kernels/fused_gelu_mul_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\nnamespace cuda {\n\nnamespace fused_gelu {\n\nOF_DEVICE_FUNC float TanhApprox(float x) {\n#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)\n  float r;\n  asm(\"tanh.approx.f32 %0,%1; \\n\\t\" : \"=f\"(r) : \"f\"(x));\n  return r;\n#else\n  return tanhf(x);\n#endif  // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)\n}\n\ntemplate<typename T>\nstruct FusedFastGeluMulFunctor {\n  static constexpr T alpha = static_cast<T>(0.7978845608028654);\n  static constexpr T beta = static_cast<T>(0.044714998453855515);\n\n  OF_DEVICE_FUNC FusedFastGeluMulFunctor() {}\n\n  OF_DEVICE_FUNC T operator()(T x, T m) const {\n    // ref to UnaryFunctor of kFastGelu\n    const T half = static_cast<T>(0.5);\n    const T one = static_cast<T>(1);\n    const T tanh_in = alpha * (x + beta * x * x * x);\n    return half * x * (one + tanh(tanh_in)) * m;\n  }\n};\n\ntemplate<>\nstruct FusedFastGeluMulFunctor<half> {\n  static constexpr float alpha = FusedFastGeluMulFunctor<float>::alpha;\n  static constexpr float beta = FusedFastGeluMulFunctor<float>::beta;\n  FusedFastGeluMulFunctor<float> float_functor;\n\n  OF_DEVICE_FUNC FusedFastGeluMulFunctor() {}\n\n  OF_DEVICE_FUNC half operator()(const half x, const half m) const {\n#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)\n    const float tanh_in =\n        __half2float(__float2half_rn(alpha) * (x + __float2half_rn(beta) * x * x * x));\n    const float tanh_out = TanhApprox(tanh_in);\n    return __float2half_rn(0.5F) * x * (__float2half_rn(1.0F) + __float2half_rn(tanh_out)) * m;\n#else\n    return static_cast<half>(float_functor(static_cast<float>(x), static_cast<float>(m)));\n#endif  // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)\n  }\n\n#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)\n  __device__ void Apply2(half* y, const half* x, const half* m) const {\n    const half2 x2 = *(reinterpret_cast<const half2*>(x));\n    const float2 tanh_in = __half22float2(\n        __hmul2(__float2half2_rn(alpha),\n                __hadd2(x2, __hmul2(__hmul2(__hmul2(__float2half2_rn(beta), x2), x2), x2))));\n    float2 tanh_out;\n    tanh_out.x = TanhApprox(tanh_in.x);\n    tanh_out.y = TanhApprox(tanh_in.y);\n    const half2 m2 = *(reinterpret_cast<const half2*>(m));\n    const half2 y2 = __hmul2(__hmul2(__hmul2(__float2half2_rn(0.5F), x2),\n                                     __hadd2(__float2half2_rn(1.0F), __float22half2_rn(tanh_out))),\n                             m2);\n    *reinterpret_cast<half2*>(y) = y2;\n  }\n#endif  // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)\n};\n\n#if CUDA_VERSION >= 11000\n\ntemplate<>\nstruct FusedFastGeluMulFunctor<nv_bfloat16> {\n  FusedFastGeluMulFunctor<float> float_functor;\n\n  OF_DEVICE_FUNC FusedFastGeluMulFunctor() {}\n\n  OF_DEVICE_FUNC 
nv_bfloat16 operator()(const nv_bfloat16 x, const nv_bfloat16 m) const {\n    return __float2bfloat16(float_functor(__bfloat162float(x), __bfloat162float(m)));\n  }\n};\n\n#endif  // CUDA_VERSION >= 11000\n\n}  // namespace fused_gelu\n\ntemplate<typename T>\nclass FusedFastGeluMulKernel final : public user_op::OpKernel {\n public:\n  FusedFastGeluMulKernel() = default;\n  ~FusedFastGeluMulKernel() override = default;\n\n private:\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const auto* multiplier = ctx->Tensor4ArgNameAndIndex(\"multiplier\", 0);\n    auto* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    int64_t elem_cnt = in->shape_view().elem_cnt();\n    OF_CUDA_CHECK((elementwise::Binary(fused_gelu::FusedFastGeluMulFunctor<T>(), elem_cnt,\n                                       out->mut_dptr<T>(), in->dptr<T>(), multiplier->dptr<T>(),\n                                       ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n  };\n};\n\n#define REGISTER_FUSED_FAST_GELU_MUL_CUDA_KERNEL(dtype)                \\\n  REGISTER_USER_KERNEL(\"fused_fast_gelu_mul\")                          \\\n      .SetCreateFn<FusedFastGeluMulKernel<dtype>>()                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_FAST_GELU_MUL_CUDA_KERNEL(float)\nREGISTER_FUSED_FAST_GELU_MUL_CUDA_KERNEL(double)\nREGISTER_FUSED_FAST_GELU_MUL_CUDA_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_FUSED_FAST_GELU_MUL_CUDA_KERNEL(nv_bfloat16)\n#endif\n\nnamespace fused_gelu {\n\ntemplate<typename T>\nstruct FusedFastGeluMulGradFunctor {\n  static constexpr T alpha = static_cast<T>(0.7978845608028654);\n  static constexpr T beta = static_cast<T>(0.044714998453855515);\n\n  __device__ FusedFastGeluMulGradFunctor() {}\n\n  __device__ void operator()(T& x_diff, T& m_diff, const T& dy, const T& x, const T& m) const {\n    const T one = static_cast<T>(1);\n    const T half = static_cast<T>(0.5);\n    const T pow3 = x * x * x;\n    const T tanh_in = alpha * (x + beta * pow3);\n    const T tanh_out = tanh(alpha * (x + beta * pow3));\n    // calc m_diff ref to UnaryFunctor of kFastGelu\n    m_diff = half * x * (one + tanh(tanh_in)) * dy;\n    // calc x_diff ref to BinaryOp::kFastGeluBackwardWithDyX\n    const T dtanh = alpha * (half * x + beta * static_cast<T>(1.5) * pow3);\n    x_diff = (half + half * tanh_out + dtanh * (one - tanh_out * tanh_out)) * m * dy;\n  }\n};\n\ntemplate<>\nstruct FusedFastGeluMulGradFunctor<half> {\n  static constexpr float alpha = FusedFastGeluMulGradFunctor<float>::alpha;\n  static constexpr float beta = FusedFastGeluMulGradFunctor<float>::beta;\n  FusedFastGeluMulGradFunctor<float> float_functor;\n\n  __device__ FusedFastGeluMulGradFunctor() {}\n\n  __device__ void operator()(half& x_diff, half& m_diff, const half& dy, const half& x,\n                             const half& m) const {\n#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)\n    const half halpha = __float2half_rn(alpha);\n    const half hbeta = __float2half_rn(beta);\n    const half hone = __float2half_rn(1.0F);\n    const half hhalf = __float2half_rn(0.5F);\n    const half pow3 = x * x * x;\n    const float tanh_in = __half2float(halpha * (x + hbeta * pow3));\n    const half 
tanh_out = __float2half_rn(TanhApprox(tanh_in));\n    // m_diff\n    m_diff = hhalf * x * (hone + tanh_out) * dy;\n    // x_diff\n    const half dtanh = halpha * (hhalf * x + hbeta * __float2half_rn(1.5F) * pow3);\n    x_diff = (hhalf + hhalf * tanh_out + dtanh * (hone - tanh_out * tanh_out)) * m * dy;\n#else\n    float x_diff_float;\n    float m_diff_float;\n    float_functor(x_diff_float, m_diff_float, static_cast<float>(dy), static_cast<float>(x),\n                  static_cast<float>(m));\n    x_diff = static_cast<half>(x_diff_float);\n    m_diff = static_cast<half>(m_diff_float);\n#endif  // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)\n  }\n\n#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)\n  __device__ void Apply2(half* x_diff, half* m_diff, const half* dy, const half* x,\n                         const half* m) const {\n    const half2 dy2 = *(reinterpret_cast<const half2*>(dy));\n    const half2 x2 = *(reinterpret_cast<const half2*>(x));\n    const half2 m2 = *(reinterpret_cast<const half2*>(m));\n    const half2 alpha2 = __float2half2_rn(alpha);\n    const half2 beta2 = __float2half2_rn(beta);\n    const half2 one2 = __float2half2_rn(1.0F);\n    const half2 hhalf2 = __float2half2_rn(0.5F);\n    const half2 pow3 = __hmul2(__hmul2(x2, x2), x2);\n    const float2 tanh_in = __half22float2(__hmul2(alpha2, __hadd2(x2, __hmul2(beta2, pow3))));\n    float2 tanh_out;\n    tanh_out.x = TanhApprox(tanh_in.x);\n    tanh_out.y = TanhApprox(tanh_in.y);\n    const half2 tanh_out2 = __float22half2_rn(tanh_out);\n    // m_diff\n    const half2 m_diff2 = __hmul2(__hmul2(hhalf2, __hmul2(x2, __hadd2(one2, tanh_out2))), dy2);\n    // x_diff\n    const half2 dtanh = __hmul2(\n        alpha2,\n        __hadd2(__hmul2(hhalf2, x2), __hmul2(beta2, __hmul2(pow3, __float2half2_rn(1.5F)))));\n    const half2 x_diff2 =\n        __hmul2(__hmul2(__hadd2(__hadd2(hhalf2, __hmul2(hhalf2, tanh_out2)),\n                                __hmul2(dtanh, __hsub2(one2, __hmul2(tanh_out2, tanh_out2)))),\n                        m2),\n                dy2);\n    *reinterpret_cast<half2*>(x_diff) = x_diff2;\n    *reinterpret_cast<half2*>(m_diff) = m_diff2;\n  }\n#endif  // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)\n};\n\n#if CUDA_VERSION >= 11000\n\ntemplate<>\nstruct FusedFastGeluMulGradFunctor<nv_bfloat16> {\n  FusedFastGeluMulGradFunctor<float> float_functor;\n\n  __device__ FusedFastGeluMulGradFunctor() {}\n\n  __device__ void operator()(nv_bfloat16& x_diff, nv_bfloat16& m_diff, const nv_bfloat16& dy,\n                             const nv_bfloat16& x, const nv_bfloat16& m) const {\n    float x_diff_float;\n    float m_diff_float;\n    float_functor(x_diff_float, m_diff_float, __bfloat162float(dy), __bfloat162float(x),\n                  __bfloat162float(m));\n    x_diff = __float2bfloat16(x_diff_float);\n    m_diff = __float2bfloat16(m_diff_float);\n  }\n};\n\n#endif  // CUDA_VERSION >= 11000\n\ntemplate<int pack_size, typename FunctorT, typename T>\n__device__ __forceinline__\n    typename std::enable_if<elementwise::HasApply2<FunctorT>::value == true && pack_size % 2 == 0,\n                            void>::type\n    FusedFastGeluMulGradFunctorApplyPack(const FunctorT& functor,\n                                         elementwise::Packed<T, pack_size>& x_diff_pack,\n                                         elementwise::Packed<T, pack_size>& m_diff_pack,\n                                         const elementwise::Packed<T, pack_size>& dy_pack,\n                                         const 
elementwise::Packed<T, pack_size>& x_pack,\n                                         const elementwise::Packed<T, pack_size>& m_pack) {\n#pragma unroll\n  for (int j = 0; j < pack_size; j += 2) {\n    functor.Apply2(x_diff_pack.elem + j, m_diff_pack.elem + j, dy_pack.elem + j, x_pack.elem + j,\n                   m_pack.elem + j);\n  }\n}\n\ntemplate<int pack_size, typename FunctorT, typename T>\n__device__ __forceinline__\n    typename std::enable_if<elementwise::HasApply2<FunctorT>::value == false || pack_size % 2 != 0,\n                            void>::type\n    FusedFastGeluMulGradFunctorApplyPack(const FunctorT& functor,\n                                         elementwise::Packed<T, pack_size>& x_diff_pack,\n                                         elementwise::Packed<T, pack_size>& m_diff_pack,\n                                         const elementwise::Packed<T, pack_size>& dy_pack,\n                                         const elementwise::Packed<T, pack_size>& x_pack,\n                                         const elementwise::Packed<T, pack_size>& m_pack) {\n#pragma unroll\n  for (int j = 0; j < pack_size; ++j) {\n    functor(x_diff_pack.elem[j], m_diff_pack.elem[j], dy_pack.elem[j], x_pack.elem[j],\n            m_pack.elem[j]);\n  }\n}\n\ntemplate<int pack_size, typename T>\n__global__ void __launch_bounds__(elementwise::kBlockSize)\n    FusedFastGeluMulGradCudaKernel(int64_t n_pack, elementwise::Packed<T, pack_size>* x_diff_pack,\n                                   elementwise::Packed<T, pack_size>* m_diff_pack,\n                                   const elementwise::Packed<T, pack_size>* dy_pack,\n                                   const elementwise::Packed<T, pack_size>* x_pack,\n                                   const elementwise::Packed<T, pack_size>* m_pack, int64_t n_tail,\n                                   T* x_diff_tail, T* m_diff_tail, const T* dy_tail,\n                                   const T* x_tail, const T* m_tail) {\n  FusedFastGeluMulGradFunctor<T> functor;\n  const int global_tid = blockIdx.x * elementwise::kBlockSize + threadIdx.x;\n  for (int64_t i = global_tid; i < n_pack; i += blockDim.x * gridDim.x) {\n    FusedFastGeluMulGradFunctorApplyPack<pack_size>(functor, x_diff_pack[i], m_diff_pack[i],\n                                                    dy_pack[i], x_pack[i], m_pack[i]);\n  }\n  if (global_tid < n_tail) {\n    functor(x_diff_tail[global_tid], m_diff_tail[global_tid], dy_tail[global_tid],\n            x_tail[global_tid], m_tail[global_tid]);\n  }\n}\n\ntemplate<size_t pack_size, typename T>\ncudaError_t LaunchFusedFastGeluMulGradCudaKernelByPack(cudaStream_t stream, int64_t n, T* x_diff,\n                                                       T* m_diff, const T* dy, const T* x,\n                                                       const T* m) {\n  const int64_t n_pack = n / pack_size;\n  const int64_t tail_offset = n_pack * pack_size;\n  const int64_t n_tail = n - tail_offset;\n  int num_blocks;\n  {\n    cudaError_t err = elementwise::GetNumBlocks(n_pack, &num_blocks);\n    if (err != cudaSuccess) { return err; }\n  }\n  FusedFastGeluMulGradCudaKernel<pack_size><<<num_blocks, elementwise::kBlockSize, 0, stream>>>(\n      n_pack, reinterpret_cast<elementwise::Packed<T, pack_size>*>(x_diff),\n      reinterpret_cast<elementwise::Packed<T, pack_size>*>(m_diff),\n      reinterpret_cast<const elementwise::Packed<T, pack_size>*>(dy),\n      reinterpret_cast<const elementwise::Packed<T, pack_size>*>(x),\n      reinterpret_cast<const 
elementwise::Packed<T, pack_size>*>(m), n_tail, x_diff + tail_offset,\n      m_diff + tail_offset, dy + tail_offset, x + tail_offset, m + tail_offset);\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename T>\nstatic cudaError_t LaunchFusedFastGeluMulGradCudaKernel(cudaStream_t stream, int64_t n, T* x_diff,\n                                                        T* m_diff, const T* dy, const T* x,\n                                                        const T* m) {\n  constexpr int max_pack_size = elementwise::PackSize<T>();\n  if (elementwise::IsAlignedForPack<max_pack_size>(x_diff, m_diff, dy, x, m)) {\n    return LaunchFusedFastGeluMulGradCudaKernelByPack<max_pack_size>(stream, n, x_diff, m_diff, dy,\n                                                                     x, m);\n  } else {\n    return LaunchFusedFastGeluMulGradCudaKernelByPack<1>(stream, n, x_diff, m_diff, dy, x, m);\n  }\n}\n\n}  // namespace fused_gelu\n\ntemplate<typename T>\nclass FusedFastGeluMulGradKernel final : public user_op::OpKernel {\n public:\n  FusedFastGeluMulGradKernel() = default;\n  ~FusedFastGeluMulGradKernel() override = default;\n\n private:\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* out_diff = ctx->Tensor4ArgNameAndIndex(\"out_diff\", 0);\n    const auto* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const auto* multiplier = ctx->Tensor4ArgNameAndIndex(\"multiplier\", 0);\n    auto* in_diff = ctx->Tensor4ArgNameAndIndex(\"in_diff\", 0);\n    auto* multiplier_diff = ctx->Tensor4ArgNameAndIndex(\"multiplier_diff\", 0);\n\n    int64_t elem_cnt = in->shape_view().elem_cnt();\n    OF_CUDA_CHECK((fused_gelu::LaunchFusedFastGeluMulGradCudaKernel(\n        ctx->stream()->As<ep::CudaStream>()->cuda_stream(), elem_cnt, in_diff->mut_dptr<T>(),\n        multiplier_diff->mut_dptr<T>(), out_diff->dptr<T>(), in->dptr<T>(),\n        multiplier->dptr<T>())));\n  };\n};\n\n#define REGISTER_FUSED_FAST_GELU_MUL_GRAD_CUDA_KERNEL(dtype)           \\\n  REGISTER_USER_KERNEL(\"fused_fast_gelu_mul_grad\")                     \\\n      .SetCreateFn<FusedFastGeluMulGradKernel<dtype>>()                \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"out_diff\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_FAST_GELU_MUL_GRAD_CUDA_KERNEL(float)\nREGISTER_FUSED_FAST_GELU_MUL_GRAD_CUDA_KERNEL(double)\nREGISTER_FUSED_FAST_GELU_MUL_GRAD_CUDA_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_FUSED_FAST_GELU_MUL_GRAD_CUDA_KERNEL(nv_bfloat16)\n#endif\n\n}  // namespace cuda\n}  // namespace oneflow\n"
  },
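The half2 arithmetic in `fused_gelu_mul_kernel.cu` is easiest to check against a scalar reference. The following is a minimal sketch of the math the functors implement, using the same alpha = sqrt(2/pi) and beta constants as the float path above; `FastGeluMulRef` is a hypothetical test helper, not part of the file.

#include <cmath>

// Scalar reference for fused fast-GELU * multiplier and its gradient:
// y = 0.5 * x * (1 + tanh(alpha * (x + beta * x^3))) * m.
void FastGeluMulRef(float x, float m, float dy, float* y, float* x_diff, float* m_diff) {
  const float alpha = 0.7978845608028654f;
  const float beta = 0.044714998453855515f;
  const float pow3 = x * x * x;
  const float t = std::tanh(alpha * (x + beta * pow3));
  *y = 0.5f * x * (1.0f + t) * m;
  // d y / d m: the GELU value itself, scaled by dy.
  *m_diff = 0.5f * x * (1.0f + t) * dy;
  // d y / d x: product rule over x and the tanh term, scaled by m * dy.
  const float dtanh = alpha * (0.5f * x + 1.5f * beta * pow3);
  *x_diff = (0.5f + 0.5f * t + dtanh * (1.0f - t * t)) * m * dy;
}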
  {
    "path": "oneflow/user/kernels/fused_get_bounding_boxes_coord_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\ntemplate<typename T>\n__global__ void FusedGetBounddingBoxesCoordForward(const int n, const T* x1, const T* y1,\n                                                   const T* w1, const T* h1, const T* x2,\n                                                   const T* y2, const T* w2, const T* h2, T* b1_x1,\n                                                   T* b1_x2, T* b1_y1, T* b1_y2, T* b2_x1, T* b2_x2,\n                                                   T* b2_y1, T* b2_y2) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T w1_ = w1[i] / static_cast<T>(2.0);\n    const T h1_ = h1[i] / static_cast<T>(2.0);\n    const T w2_ = w2[i] / static_cast<T>(2.0);\n    const T h2_ = h2[i] / static_cast<T>(2.0);\n    const T x1_i = x1[i], y1_i = y1[i], x2_i = x2[i], y2_i = y2[i];\n    b1_x1[i] = x1_i - w1_;\n    b1_x2[i] = x1_i + w1_;\n    b1_y1[i] = y1_i - h1_;\n    b1_y2[i] = y1_i + h1_;\n    b2_x1[i] = x2_i - w2_;\n    b2_x2[i] = x2_i + w2_;\n    b2_y1[i] = y2_i - h2_;\n    b2_y2[i] = y2_i + h2_;\n  }\n}\n\ntemplate<typename T>\n__global__ void FusedGetBounddingBoxesCoordBackward(\n    const int n, const T* b1_x1_diff, const T* b1_x2_diff, const T* b1_y1_diff, const T* b1_y2_diff,\n    const T* b2_x1_diff, const T* b2_x2_diff, const T* b2_y1_diff, const T* b2_y2_diff, T* x1_diff,\n    T* y1_diff, T* w1_diff, T* h1_diff, T* x2_diff, T* y2_diff, T* w2_diff, T* h2_diff) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T b1_x1_diff_i = b1_x1_diff[i];\n    const T b1_x2_diff_i = b1_x2_diff[i];\n    const T b1_y1_diff_i = b1_y1_diff[i];\n    const T b1_y2_diff_i = b1_y2_diff[i];\n    const T b2_x1_diff_i = b2_x1_diff[i];\n    const T b2_x2_diff_i = b2_x2_diff[i];\n    const T b2_y2_diff_i = b2_y2_diff[i];\n    const T b2_y1_diff_i = b2_y1_diff[i];\n    x1_diff[i] = b1_x1_diff_i + b1_x2_diff_i;\n    y1_diff[i] = b1_y1_diff_i + b1_y2_diff_i;\n    w1_diff[i] = (b1_x2_diff_i - b1_x1_diff_i) / static_cast<T>(2.0);\n    h1_diff[i] = (b1_y2_diff_i - b1_y1_diff_i) / static_cast<T>(2.0);\n    x2_diff[i] = b2_x1_diff_i + b2_x2_diff_i;\n    y2_diff[i] = b2_y1_diff_i + b2_y2_diff_i;\n    w2_diff[i] = (b2_x2_diff_i - b2_x1_diff_i) / static_cast<T>(2.0);\n    h2_diff[i] = (b2_y2_diff_i - b2_y1_diff_i) / static_cast<T>(2.0);\n  }\n}\n};  // namespace\n\ntemplate<typename T>\nclass FusedGetBounddingBoxesCoordGpuKernel final : public user_op::OpKernel {\n public:\n  FusedGetBounddingBoxesCoordGpuKernel() = default;\n  ~FusedGetBounddingBoxesCoordGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x1 = ctx->Tensor4ArgNameAndIndex(\"x1\", 0);\n    const user_op::Tensor* y1 = ctx->Tensor4ArgNameAndIndex(\"y1\", 0);\n    const user_op::Tensor* w1 = 
ctx->Tensor4ArgNameAndIndex(\"w1\", 0);\n    const user_op::Tensor* h1 = ctx->Tensor4ArgNameAndIndex(\"h1\", 0);\n    const user_op::Tensor* x2 = ctx->Tensor4ArgNameAndIndex(\"x2\", 0);\n    const user_op::Tensor* y2 = ctx->Tensor4ArgNameAndIndex(\"y2\", 0);\n    const user_op::Tensor* w2 = ctx->Tensor4ArgNameAndIndex(\"w2\", 0);\n    const user_op::Tensor* h2 = ctx->Tensor4ArgNameAndIndex(\"h2\", 0);\n\n    user_op::Tensor* b1_x1 = ctx->Tensor4ArgNameAndIndex(\"b1_x1\", 0);\n    user_op::Tensor* b1_x2 = ctx->Tensor4ArgNameAndIndex(\"b1_x2\", 0);\n    user_op::Tensor* b1_y1 = ctx->Tensor4ArgNameAndIndex(\"b1_y1\", 0);\n    user_op::Tensor* b1_y2 = ctx->Tensor4ArgNameAndIndex(\"b1_y2\", 0);\n    user_op::Tensor* b2_x1 = ctx->Tensor4ArgNameAndIndex(\"b2_x1\", 0);\n    user_op::Tensor* b2_x2 = ctx->Tensor4ArgNameAndIndex(\"b2_x2\", 0);\n    user_op::Tensor* b2_y1 = ctx->Tensor4ArgNameAndIndex(\"b2_y1\", 0);\n    user_op::Tensor* b2_y2 = ctx->Tensor4ArgNameAndIndex(\"b2_y2\", 0);\n\n    const int32_t elem_cnt = x1->shape_view().elem_cnt();\n    RUN_CUDA_KERNEL((FusedGetBounddingBoxesCoordForward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                    x1->dptr<T>(), y1->dptr<T>(), w1->dptr<T>(), h1->dptr<T>(), x2->dptr<T>(),\n                    y2->dptr<T>(), w2->dptr<T>(), h2->dptr<T>(), b1_x1->mut_dptr<T>(),\n                    b1_x2->mut_dptr<T>(), b1_y1->mut_dptr<T>(), b1_y2->mut_dptr<T>(),\n                    b2_x1->mut_dptr<T>(), b2_x2->mut_dptr<T>(), b2_y1->mut_dptr<T>(),\n                    b2_y2->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_GET_BOUNDDING_BOXES_COORD_CUDA_KERNEL(dtype)    \\\n  REGISTER_USER_KERNEL(\"fused_get_boundding_boxes_coord\")              \\\n      .SetCreateFn<FusedGetBounddingBoxesCoordGpuKernel<dtype>>()      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"b1_x1\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_GET_BOUNDDING_BOXES_COORD_CUDA_KERNEL(float)\nREGISTER_FUSED_GET_BOUNDDING_BOXES_COORD_CUDA_KERNEL(half)\nREGISTER_FUSED_GET_BOUNDDING_BOXES_COORD_CUDA_KERNEL(double)\n\ntemplate<typename T>\nclass FusedGetBounddingBoxesCoordGradGpuKernel final : public user_op::OpKernel {\n public:\n  FusedGetBounddingBoxesCoordGradGpuKernel() = default;\n  ~FusedGetBounddingBoxesCoordGradGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* b1_x1_diff = ctx->Tensor4ArgNameAndIndex(\"b1_x1_diff\", 0);\n    const user_op::Tensor* b1_x2_diff = ctx->Tensor4ArgNameAndIndex(\"b1_x2_diff\", 0);\n    const user_op::Tensor* b1_y1_diff = ctx->Tensor4ArgNameAndIndex(\"b1_y1_diff\", 0);\n    const user_op::Tensor* b1_y2_diff = ctx->Tensor4ArgNameAndIndex(\"b1_y2_diff\", 0);\n    const user_op::Tensor* b2_x1_diff = ctx->Tensor4ArgNameAndIndex(\"b2_x1_diff\", 0);\n    const user_op::Tensor* b2_x2_diff = ctx->Tensor4ArgNameAndIndex(\"b2_x2_diff\", 0);\n    const user_op::Tensor* b2_y1_diff = ctx->Tensor4ArgNameAndIndex(\"b2_y1_diff\", 0);\n    const user_op::Tensor* b2_y2_diff = ctx->Tensor4ArgNameAndIndex(\"b2_y2_diff\", 0);\n\n    user_op::Tensor* x1_diff = ctx->Tensor4ArgNameAndIndex(\"x1_diff\", 0);\n    user_op::Tensor* y1_diff = ctx->Tensor4ArgNameAndIndex(\"y1_diff\", 0);\n    user_op::Tensor* w1_diff = ctx->Tensor4ArgNameAndIndex(\"w1_diff\", 0);\n    user_op::Tensor* h1_diff = 
ctx->Tensor4ArgNameAndIndex(\"h1_diff\", 0);\n    user_op::Tensor* x2_diff = ctx->Tensor4ArgNameAndIndex(\"x2_diff\", 0);\n    user_op::Tensor* y2_diff = ctx->Tensor4ArgNameAndIndex(\"y2_diff\", 0);\n    user_op::Tensor* w2_diff = ctx->Tensor4ArgNameAndIndex(\"w2_diff\", 0);\n    user_op::Tensor* h2_diff = ctx->Tensor4ArgNameAndIndex(\"h2_diff\", 0);\n\n    const int32_t elem_cnt = b1_x1_diff->shape_view().elem_cnt();\n    RUN_CUDA_KERNEL((FusedGetBounddingBoxesCoordBackward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                    b1_x1_diff->dptr<T>(), b1_x2_diff->dptr<T>(), b1_y1_diff->dptr<T>(),\n                    b1_y2_diff->dptr<T>(), b2_x1_diff->dptr<T>(), b2_x2_diff->dptr<T>(),\n                    b2_y1_diff->dptr<T>(), b2_y2_diff->dptr<T>(), x1_diff->mut_dptr<T>(),\n                    y1_diff->mut_dptr<T>(), w1_diff->mut_dptr<T>(), h1_diff->mut_dptr<T>(),\n                    x2_diff->mut_dptr<T>(), y2_diff->mut_dptr<T>(), w2_diff->mut_dptr<T>(),\n                    h2_diff->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_GET_BOUNDDING_BOXES_COORD_GRAD_CUDA_KERNEL(dtype) \\\n  REGISTER_USER_KERNEL(\"fused_get_boundding_boxes_coord_grad\")           \\\n      .SetCreateFn<FusedGetBounddingBoxesCoordGradGpuKernel<dtype>>()    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)   \\\n                       && (user_op::HobDataType(\"b1_x1_diff\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_GET_BOUNDDING_BOXES_COORD_GRAD_CUDA_KERNEL(float)\nREGISTER_FUSED_GET_BOUNDDING_BOXES_COORD_GRAD_CUDA_KERNEL(half)\nREGISTER_FUSED_GET_BOUNDDING_BOXES_COORD_GRAD_CUDA_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
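The forward kernel in `fused_get_bounding_boxes_coord_kernel.cu` converts center/size boxes to corner form, and the backward is the transpose of that linear map, which is why the gradients come back as sums (for the centers) and halved differences (for the sizes). A scalar sketch of the forward, with `BoxToCorners` as a hypothetical helper name:

// (cx, cy, w, h) -> (x1, y1, x2, y2): each corner is center +/- half the size.
// Transposed map for gradients: cx_diff = dx1 + dx2, w_diff = (dx2 - dx1) / 2.
template<typename T>
void BoxToCorners(T cx, T cy, T w, T h, T* x1, T* y1, T* x2, T* y2) {
  const T half_w = w / static_cast<T>(2);
  const T half_h = h / static_cast<T>(2);
  *x1 = cx - half_w;
  *y1 = cy - half_h;
  *x2 = cx + half_w;
  *y2 = cy + half_h;
}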
  {
    "path": "oneflow/user/kernels/fused_get_ciou_diagonal_angle_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cmath>\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstruct FusedCiouAngleForwardFunctor {\n  __device__ T Compute(T w1, T h1, T w2, T h2, float eps) const {\n    T angle = (atan(w2 / (h2 + eps)) - atan(w1 / (h1 + eps)))\n              * (atan(w2 / (h2 + eps)) - atan(w1 / (h1 + eps)));\n    return static_cast<T>(4.0 / (M_PI * M_PI)) * angle;\n  }\n};\n\ntemplate<>\nstruct FusedCiouAngleForwardFunctor<half> {\n  __device__ half Compute(half w1, half h1, half w2, half h2, float eps) const {\n    float w1f = __half2float(w1);\n    float h1f = __half2float(h1);\n    float w2f = __half2float(w2);\n    float h2f = __half2float(h2);\n    float angle = (atan(w2f / (h2f + eps)) - atan(w1f / (h1f + eps)))\n                  * (atan(w2f / (h2f + eps)) - atan(w1f / (h1f + eps)));\n    return __float2half(static_cast<float>(4.0 / (M_PI * M_PI)) * angle);\n  }\n};\n\ntemplate<typename FUNCTOR, typename T>\n__global__ void FusedCiouAngleForward(FUNCTOR functor, const int n, const T* w1, const T* h1,\n                                      const T* w2, const T* h2, const float eps, T* v) {\n  CUDA_1D_KERNEL_LOOP(i, n) { v[i] = functor.Compute(w1[i], h1[i], w2[i], h2[i], eps); }\n}\n\ntemplate<typename T>\nstruct FusedCiouAngleBackwardFunctor {\n  __device__ T ComputeW1(T h1, T angle_delta, T angle1, float eps) const {\n    return static_cast<T>(-1.0) * angle_delta / ((h1 + eps) * angle1);\n  }\n\n  __device__ T ComputeW2(T h2, T angle_delta, T angle2, float eps) const {\n    return angle_delta / ((h2 + eps) * angle2);\n  }\n\n  __device__ T ComputeH1(T w1, T h1, T angle_delta, T angle1, float eps) const {\n    return w1 * angle_delta / ((h1 + eps) * (h1 + eps) * angle1);\n  }\n\n  __device__ T ComputeH2(T w2, T h2, T angle_delta, T angle2, float eps) const {\n    return static_cast<T>(-1.0) * w2 * angle_delta / ((h2 + eps) * (h2 + eps) * angle2);\n  }\n};\n\ntemplate<>\nstruct FusedCiouAngleBackwardFunctor<half> {\n  __device__ half ComputeW1(half h1, half angle_delta, half angle1, float eps) const {\n    float h1f = __half2float(h1);\n    float angle_delta_f = __half2float(angle_delta);\n    float angle1f = __half2float(angle1);\n    return __float2half(-1.0 * angle_delta_f / ((h1f + eps) * angle1f));\n  }\n\n  __device__ half ComputeW2(half h2, half angle_delta, half angle2, float eps) const {\n    float h2f = __half2float(h2);\n    float angle_delta_f = __half2float(angle_delta);\n    float angle2f = __half2float(angle2);\n    return __float2half(angle_delta_f / ((h2f + eps) * angle2f));\n  }\n\n  __device__ half ComputeH1(half w1, half h1, half angle_delta, half angle1, float eps) const {\n    float w1f = __half2float(w1);\n    float h1f = __half2float(h1);\n    float 
angle_delta_f = __half2float(angle_delta);\n    float angle1f = __half2float(angle1);\n    return __float2half(w1f * angle_delta_f / ((h1f + eps) * (h1f + eps) * angle1f));\n  }\n\n  __device__ half ComputeH2(half w2, half h2, half angle_delta, half angle2, float eps) const {\n    float w2f = __half2float(w2);\n    float h2f = __half2float(h2);\n    float angle_delta_f = __half2float(angle_delta);\n    float angle2f = __half2float(angle2);\n    return __float2half(-1.0 * w2f * angle_delta_f / ((h2f + eps) * (h2f + eps) * angle2f));\n  }\n};\n\ntemplate<typename T>\nstruct CalcAngleFunctor {\n  __device__ T ComputeDelta(T w1, T h1, T w2, T h2, float eps) const {\n    return static_cast<T>(8.0) * (atan(w2 / (h2 + eps)) - atan(w1 / (h1 + eps)))\n           / static_cast<T>((M_PI * M_PI));\n  }\n\n  __device__ T Compute1(T w1, T h1, float eps) const {\n    return static_cast<T>(1.0) + (w1 * w1 / ((h1 + eps) * (h1 + eps)));\n  }\n\n  __device__ T Compute2(T w2, T h2, float eps) const {\n    return static_cast<T>(1.0) + (w2 * w2 / ((h2 + eps) * (h2 + eps)));\n  }\n};\n\ntemplate<>\nstruct CalcAngleFunctor<half> {\n  __device__ half ComputeDelta(half w1, half h1, half w2, half h2, float eps) const {\n    float w1f = __half2float(w1);\n    float h1f = __half2float(h1);\n    float w2f = __half2float(w2);\n    float h2f = __half2float(h2);\n    return __float2half(8.0 * (atan(w2f / (h2f + eps)) - atan(w1f / (h1f + eps)))\n                        / static_cast<float>((M_PI * M_PI)));\n  }\n\n  __device__ half Compute1(half w1, half h1, float eps) const {\n    float w1f = __half2float(w1);\n    float h1f = __half2float(h1);\n    return __float2half(1.0 + (w1f * w1f / ((h1f + eps) * (h1f + eps))));\n  }\n\n  __device__ half Compute2(half w2, half h2, float eps) const {\n    float w2f = __half2float(w2);\n    float h2f = __half2float(h2);\n    return __float2half(1.0 + (w2f * w2f / ((h2f + eps) * (h2f + eps))));\n  }\n};\n\ntemplate<typename FUNCTOR_BACKWARD, typename FUNCTOR_ANGLE, typename T>\n__global__ void FusedCiouAngleBackward(FUNCTOR_BACKWARD functor_backward,\n                                       FUNCTOR_ANGLE functor_angle, const int n, const T* w1,\n                                       const T* h1, const T* w2, const T* h2, const T* v_diff,\n                                       const float eps, T* w1_diff, T* h1_diff, T* w2_diff,\n                                       T* h2_diff) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T w1_i = w1[i];\n    const T h1_i = h1[i];\n    const T w2_i = w2[i];\n    const T h2_i = h2[i];\n    const T v_diff_i = v_diff[i];\n    const T angle_delta_i = functor_angle.ComputeDelta(w1_i, h1_i, w2_i, h2_i, eps);\n    const T angle1_i = functor_angle.Compute1(w1_i, h1_i, eps);\n    const T angle2_i = functor_angle.Compute2(w2_i, h2_i, eps);\n    w1_diff[i] = functor_backward.ComputeW1(h1_i, angle_delta_i, angle1_i, eps) * v_diff_i;\n    w2_diff[i] = functor_backward.ComputeW2(h2_i, angle_delta_i, angle2_i, eps) * v_diff_i;\n    h1_diff[i] = functor_backward.ComputeH1(w1_i, h1_i, angle_delta_i, angle1_i, eps) * v_diff_i;\n    h2_diff[i] = functor_backward.ComputeH2(w2_i, h2_i, angle_delta_i, angle2_i, eps) * v_diff_i;\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass FusedGetCiouDiagonalAngleKernel final : public user_op::OpKernel {\n public:\n  FusedGetCiouDiagonalAngleKernel() = default;\n  ~FusedGetCiouDiagonalAngleKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override 
{\n    const user_op::Tensor* w1 = ctx->Tensor4ArgNameAndIndex(\"w1\", 0);\n    const user_op::Tensor* h1 = ctx->Tensor4ArgNameAndIndex(\"h1\", 0);\n    const user_op::Tensor* w2 = ctx->Tensor4ArgNameAndIndex(\"w2\", 0);\n    const user_op::Tensor* h2 = ctx->Tensor4ArgNameAndIndex(\"h2\", 0);\n    const auto eps = ctx->Attr<float>(\"eps\");\n\n    user_op::Tensor* v = ctx->Tensor4ArgNameAndIndex(\"v\", 0);\n\n    const int64_t elem_cnt = w1->shape_view().elem_cnt();\n\n    FusedCiouAngleForwardFunctor<T> fused_get_ciou_diagonal_angle_functor{};\n\n    RUN_CUDA_KERNEL((FusedCiouAngleForward<decltype(fused_get_ciou_diagonal_angle_functor), T>),\n                    ctx->stream(), elem_cnt, fused_get_ciou_diagonal_angle_functor, elem_cnt,\n                    w1->dptr<T>(), h1->dptr<T>(), w2->dptr<T>(), h2->dptr<T>(), eps,\n                    v->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_GET_CIOU_DIAGONAL_ANGLE_CUDA_KERNEL(dtype)      \\\n  REGISTER_USER_KERNEL(\"fused_get_ciou_diagonal_angle\")                \\\n      .SetCreateFn<FusedGetCiouDiagonalAngleKernel<dtype>>()           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"v\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_GET_CIOU_DIAGONAL_ANGLE_CUDA_KERNEL(float)\nREGISTER_FUSED_GET_CIOU_DIAGONAL_ANGLE_CUDA_KERNEL(double)\nREGISTER_FUSED_GET_CIOU_DIAGONAL_ANGLE_CUDA_KERNEL(half)\n\ntemplate<typename T>\nclass FusedGetCiouDiagonalAngleGradKernel final : public user_op::OpKernel {\n public:\n  FusedGetCiouDiagonalAngleGradKernel() = default;\n  ~FusedGetCiouDiagonalAngleGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* w1 = ctx->Tensor4ArgNameAndIndex(\"w1\", 0);\n    const user_op::Tensor* h1 = ctx->Tensor4ArgNameAndIndex(\"h1\", 0);\n    const user_op::Tensor* w2 = ctx->Tensor4ArgNameAndIndex(\"w2\", 0);\n    const user_op::Tensor* h2 = ctx->Tensor4ArgNameAndIndex(\"h2\", 0);\n    const user_op::Tensor* v_diff = ctx->Tensor4ArgNameAndIndex(\"v_diff\", 0);\n    const auto eps = ctx->Attr<float>(\"eps\");\n\n    user_op::Tensor* w1_diff = ctx->Tensor4ArgNameAndIndex(\"w1_diff\", 0);\n    user_op::Tensor* h1_diff = ctx->Tensor4ArgNameAndIndex(\"h1_diff\", 0);\n    user_op::Tensor* w2_diff = ctx->Tensor4ArgNameAndIndex(\"w2_diff\", 0);\n    user_op::Tensor* h2_diff = ctx->Tensor4ArgNameAndIndex(\"h2_diff\", 0);\n\n    const int64_t elem_cnt = w1->shape_view().elem_cnt();\n\n    FusedCiouAngleBackwardFunctor<T> fused_get_ciou_diagonal_angle_grad_functor{};\n    CalcAngleFunctor<T> calc_angle_functor{};\n\n    RUN_CUDA_KERNEL((FusedCiouAngleBackward<decltype(fused_get_ciou_diagonal_angle_grad_functor),\n                                            decltype(calc_angle_functor), T>),\n                    ctx->stream(), elem_cnt, fused_get_ciou_diagonal_angle_grad_functor,\n                    calc_angle_functor, elem_cnt, w1->dptr<T>(), h1->dptr<T>(), w2->dptr<T>(),\n                    h2->dptr<T>(), v_diff->dptr<T>(), eps, w1_diff->mut_dptr<T>(),\n                    h1_diff->mut_dptr<T>(), w2_diff->mut_dptr<T>(), h2_diff->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_GET_CIOU_DIAGONAL_ANGLE_GRAD_CUDA_KERNEL(dtype) \\\n  
REGISTER_USER_KERNEL(\"fused_get_ciou_diagonal_angle_grad\")           \\\n      .SetCreateFn<FusedGetCiouDiagonalAngleGradKernel<dtype>>()       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"w1_diff\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_GET_CIOU_DIAGONAL_ANGLE_GRAD_CUDA_KERNEL(float)\nREGISTER_FUSED_GET_CIOU_DIAGONAL_ANGLE_GRAD_CUDA_KERNEL(double)\nREGISTER_FUSED_GET_CIOU_DIAGONAL_ANGLE_GRAD_CUDA_KERNEL(half)\n\n}  // namespace oneflow\n"
  },
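The functors in `fused_get_ciou_diagonal_angle_kernel.cu` compute the CIoU aspect-ratio consistency term v = 4/pi^2 * (atan(w2/(h2+eps)) - atan(w1/(h1+eps)))^2; `ComputeDelta`, `Compute1`, and `Compute2` are the shared subexpressions of its partial derivatives (e.g. dv/dw1 = -(8*delta/pi^2) / ((h1+eps) * (1 + (w1/(h1+eps))^2))). A scalar sketch of the forward, with `CiouAngleRef` as a hypothetical name:

#include <cmath>

// v measures the aspect-ratio mismatch of two boxes; eps guards against h == 0.
float CiouAngleRef(float w1, float h1, float w2, float h2, float eps) {
  const float delta = std::atan(w2 / (h2 + eps)) - std::atan(w1 / (h1 + eps));
  return 4.0f / static_cast<float>(M_PI * M_PI) * delta * delta;
}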
  {
    "path": "oneflow/user/kernels/fused_get_ciou_result_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\ntemplate<typename T>\n__global__ void FusedGetCiouResultForward(const int n, const T* v, const T* iou, const T* rho2,\n                                          const T* c2, T* y, T* alpha, float eps) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T v_i = v[i];\n    const T iou_i = iou[i];\n    const T alpha_i = v_i / (v_i - iou_i + static_cast<T>(1.0 + eps));\n    y[i] = iou_i - (rho2[i] / c2[i] + v_i * alpha_i);\n    alpha[i] = alpha_i;\n  }\n}\n\ntemplate<typename T>\n__global__ void FusedGetCiouResultBackward(const int n, const T* dy, const T* alpha, const T* rho2,\n                                           const T* c2, T* dv, T* diou, T* drho2, T* dc2) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T c2_i = c2[i];\n    const T dy_i = dy[i];\n    dv[i] = -alpha[i] * dy_i;\n    diou[i] = dy_i;\n    drho2[i] = -dy_i / c2[i];\n    dc2[i] = rho2[i] / (c2_i * c2_i) * dy_i;\n  }\n}\n};  // namespace\n\ntemplate<typename T>\nclass FusedGetCiouResultGpuKernel final : public user_op::OpKernel {\n public:\n  FusedGetCiouResultGpuKernel() = default;\n  ~FusedGetCiouResultGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* v = ctx->Tensor4ArgNameAndIndex(\"v\", 0);\n    const user_op::Tensor* iou = ctx->Tensor4ArgNameAndIndex(\"iou\", 0);\n    const user_op::Tensor* rho2 = ctx->Tensor4ArgNameAndIndex(\"rho2\", 0);\n    const user_op::Tensor* c2 = ctx->Tensor4ArgNameAndIndex(\"c2\", 0);\n\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex(\"alpha\", 0);\n\n    float eps = ctx->Attr<float>(\"eps\");\n\n    const int32_t elem_cnt = v->shape_view().elem_cnt();\n    RUN_CUDA_KERNEL((FusedGetCiouResultForward<T>), ctx->stream(), elem_cnt, elem_cnt, v->dptr<T>(),\n                    iou->dptr<T>(), rho2->dptr<T>(), c2->dptr<T>(), y->mut_dptr<T>(),\n                    alpha->mut_dptr<T>(), eps);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_GET_CIOU_RESULT_CUDA_KERNEL(dtype)              \\\n  REGISTER_USER_KERNEL(\"fused_get_ciou_result\")                        \\\n      .SetCreateFn<FusedGetCiouResultGpuKernel<dtype>>()               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"v\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_GET_CIOU_RESULT_CUDA_KERNEL(float)\nREGISTER_FUSED_GET_CIOU_RESULT_CUDA_KERNEL(half)\nREGISTER_FUSED_GET_CIOU_RESULT_CUDA_KERNEL(double)\n\ntemplate<typename T>\nclass FusedGetCiouResultGradGpuKernel final : public user_op::OpKernel {\n public:\n  FusedGetCiouResultGradGpuKernel() = default;\n  
~FusedGetCiouResultGradGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex(\"alpha\", 0);\n    const user_op::Tensor* rho2 = ctx->Tensor4ArgNameAndIndex(\"rho2\", 0);\n    const user_op::Tensor* c2 = ctx->Tensor4ArgNameAndIndex(\"c2\", 0);\n\n    user_op::Tensor* dv = ctx->Tensor4ArgNameAndIndex(\"dv\", 0);\n    user_op::Tensor* diou = ctx->Tensor4ArgNameAndIndex(\"diou\", 0);\n    user_op::Tensor* drho2 = ctx->Tensor4ArgNameAndIndex(\"drho2\", 0);\n    user_op::Tensor* dc2 = ctx->Tensor4ArgNameAndIndex(\"dc2\", 0);\n\n    const int32_t elem_cnt = dy->shape_view().elem_cnt();\n    RUN_CUDA_KERNEL((FusedGetCiouResultBackward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                    dy->dptr<T>(), alpha->dptr<T>(), rho2->dptr<T>(), c2->dptr<T>(),\n                    dv->mut_dptr<T>(), diou->mut_dptr<T>(), drho2->mut_dptr<T>(),\n                    dc2->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_GET_CIOU_RESULT_GRAD_CUDA_KERNEL(dtype)         \\\n  REGISTER_USER_KERNEL(\"fused_get_ciou_result_grad\")                   \\\n      .SetCreateFn<FusedGetCiouResultGradGpuKernel<dtype>>()           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_GET_CIOU_RESULT_GRAD_CUDA_KERNEL(float)\nREGISTER_FUSED_GET_CIOU_RESULT_GRAD_CUDA_KERNEL(half)\nREGISTER_FUSED_GET_CIOU_RESULT_GRAD_CUDA_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
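The forward in `fused_get_ciou_result_kernel.cu` assembles the final CIoU value: alpha = v / (v - iou + 1 + eps) is the trade-off weight, and y = iou - (rho2/c2 + v*alpha). Note that the backward kernel treats alpha as a constant rather than differentiating through it, which is why dv is simply -alpha * dy. A scalar sketch, with `CiouResultRef` as a hypothetical name:

// Returns y and writes out the trade-off weight alpha saved for the backward pass.
float CiouResultRef(float v, float iou, float rho2, float c2, float eps, float* alpha) {
  *alpha = v / (v - iou + 1.0f + eps);
  return iou - (rho2 / c2 + v * (*alpha));
}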
  {
    "path": "oneflow/user/kernels/fused_get_convex_diagonal_squared_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/common/math_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__global__ void FusedGetConvexDiagonalSquaredForward(const int n, const T* b1_x1, const T* b1_x2,\n                                                     const T* b2_x1, const T* b2_x2, const T* b1_y1,\n                                                     const T* b1_y2, const T* b2_y1, const T* b2_y2,\n                                                     T* c2, const float eps) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T cw = DeviceMax(b1_x2[i], b2_x2[i]) - DeviceMin(b1_x1[i], b2_x1[i]);\n    const T ch = DeviceMax(b1_y2[i], b2_y2[i]) - DeviceMin(b1_y1[i], b2_y1[i]);\n    c2[i] = cw * cw + ch * ch + static_cast<T>(eps);\n  }\n}\n\ntemplate<typename T>\n__global__ void FusedGetConvexDiagonalSquaredBackward(\n    const int n, const T* b1_x1, const T* b1_x2, const T* b2_x1, const T* b2_x2, const T* b1_y1,\n    const T* b1_y2, const T* b2_y1, const T* b2_y2, const T* c2_diff, T* b1_x1_diff, T* b1_x2_diff,\n    T* b2_x1_diff, T* b2_x2_diff, T* b1_y1_diff, T* b1_y2_diff, T* b2_y1_diff, T* b2_y2_diff,\n    const float eps) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T zero = static_cast<T>(0), one = static_cast<T>(1);\n    const T cw = DeviceMax(b1_x2[i], b2_x2[i]) - DeviceMin(b1_x1[i], b2_x1[i]);\n    const T ch = DeviceMax(b1_y2[i], b2_y2[i]) - DeviceMin(b1_y1[i], b2_y1[i]);\n    const T c2_diff_cw = static_cast<T>(2) * cw * c2_diff[i];\n    const T c2_diff_ch = static_cast<T>(2) * ch * c2_diff[i];\n    b1_x2_diff[i] = c2_diff_cw * (b1_x2[i] > b2_x2[i] ? one : zero);\n    b2_x2_diff[i] = c2_diff_cw * (b1_x2[i] > b2_x2[i] ? zero : one);\n    b1_x1_diff[i] = -c2_diff_cw * (b1_x1[i] < b2_x1[i] ? one : zero);\n    b2_x1_diff[i] = -c2_diff_cw * (b1_x1[i] < b2_x1[i] ? zero : one);\n    b1_y2_diff[i] = c2_diff_ch * (b1_y2[i] > b2_y2[i] ? one : zero);\n    b2_y2_diff[i] = c2_diff_ch * (b1_y2[i] > b2_y2[i] ? zero : one);\n    b1_y1_diff[i] = -c2_diff_ch * (b1_y1[i] < b2_y1[i] ? one : zero);\n    b2_y1_diff[i] = -c2_diff_ch * (b1_y1[i] < b2_y1[i] ? 
zero : one);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass FusedGetConvexDiagonalSquaredKernel final : public user_op::OpKernel {\n public:\n  FusedGetConvexDiagonalSquaredKernel() = default;\n  ~FusedGetConvexDiagonalSquaredKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* b1_x1 = ctx->Tensor4ArgNameAndIndex(\"b1_x1\", 0);\n    const user_op::Tensor* b1_x2 = ctx->Tensor4ArgNameAndIndex(\"b1_x2\", 0);\n    const user_op::Tensor* b2_x1 = ctx->Tensor4ArgNameAndIndex(\"b2_x1\", 0);\n    const user_op::Tensor* b2_x2 = ctx->Tensor4ArgNameAndIndex(\"b2_x2\", 0);\n    const user_op::Tensor* b1_y1 = ctx->Tensor4ArgNameAndIndex(\"b1_y1\", 0);\n    const user_op::Tensor* b1_y2 = ctx->Tensor4ArgNameAndIndex(\"b1_y2\", 0);\n    const user_op::Tensor* b2_y1 = ctx->Tensor4ArgNameAndIndex(\"b2_y1\", 0);\n    const user_op::Tensor* b2_y2 = ctx->Tensor4ArgNameAndIndex(\"b2_y2\", 0);\n\n    user_op::Tensor* c2 = ctx->Tensor4ArgNameAndIndex(\"c2\", 0);\n    const float eps = ctx->Attr<float>(\"eps\");\n\n    const int64_t elem_cnt = b1_x1->shape_view().elem_cnt();\n\n    RUN_CUDA_KERNEL((FusedGetConvexDiagonalSquaredForward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                    b1_x1->dptr<T>(), b1_x2->dptr<T>(), b2_x1->dptr<T>(), b2_x2->dptr<T>(),\n                    b1_y1->dptr<T>(), b1_y2->dptr<T>(), b2_y1->dptr<T>(), b2_y2->dptr<T>(),\n                    c2->mut_dptr<T>(), eps);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_GET_CONVEX_DIAGONAL_SQUARED_CUDA_KERNEL(dtype)  \\\n  REGISTER_USER_KERNEL(\"fused_get_convex_diagonal_squared\")            \\\n      .SetCreateFn<FusedGetConvexDiagonalSquaredKernel<dtype>>()       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"b1_x1\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_GET_CONVEX_DIAGONAL_SQUARED_CUDA_KERNEL(float)\nREGISTER_FUSED_GET_CONVEX_DIAGONAL_SQUARED_CUDA_KERNEL(double)\nREGISTER_FUSED_GET_CONVEX_DIAGONAL_SQUARED_CUDA_KERNEL(half)\n\ntemplate<typename T>\nclass FusedGetConvexDiagonalSquaredGradKernel final : public user_op::OpKernel {\n public:\n  FusedGetConvexDiagonalSquaredGradKernel() = default;\n  ~FusedGetConvexDiagonalSquaredGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* c2_diff = ctx->Tensor4ArgNameAndIndex(\"c2_diff\", 0);\n    const user_op::Tensor* b1_x1 = ctx->Tensor4ArgNameAndIndex(\"b1_x1\", 0);\n    const user_op::Tensor* b1_x2 = ctx->Tensor4ArgNameAndIndex(\"b1_x2\", 0);\n    const user_op::Tensor* b2_x1 = ctx->Tensor4ArgNameAndIndex(\"b2_x1\", 0);\n    const user_op::Tensor* b2_x2 = ctx->Tensor4ArgNameAndIndex(\"b2_x2\", 0);\n    const user_op::Tensor* b1_y1 = ctx->Tensor4ArgNameAndIndex(\"b1_y1\", 0);\n    const user_op::Tensor* b1_y2 = ctx->Tensor4ArgNameAndIndex(\"b1_y2\", 0);\n    const user_op::Tensor* b2_y1 = ctx->Tensor4ArgNameAndIndex(\"b2_y1\", 0);\n    const user_op::Tensor* b2_y2 = ctx->Tensor4ArgNameAndIndex(\"b2_y2\", 0);\n\n    user_op::Tensor* b1_x1_diff = ctx->Tensor4ArgNameAndIndex(\"b1_x1_diff\", 0);\n    user_op::Tensor* b1_x2_diff = ctx->Tensor4ArgNameAndIndex(\"b1_x2_diff\", 0);\n    user_op::Tensor* b2_x1_diff = ctx->Tensor4ArgNameAndIndex(\"b2_x1_diff\", 0);\n    user_op::Tensor* b2_x2_diff = ctx->Tensor4ArgNameAndIndex(\"b2_x2_diff\", 0);\n    user_op::Tensor* b1_y1_diff = ctx->Tensor4ArgNameAndIndex(\"b1_y1_diff\", 0);\n    user_op::Tensor* b1_y2_diff = ctx->Tensor4ArgNameAndIndex(\"b1_y2_diff\", 0);\n    user_op::Tensor* b2_y1_diff = ctx->Tensor4ArgNameAndIndex(\"b2_y1_diff\", 0);\n    user_op::Tensor* b2_y2_diff = ctx->Tensor4ArgNameAndIndex(\"b2_y2_diff\", 0);\n\n    const float eps = ctx->Attr<float>(\"eps\");\n    const int64_t elem_cnt = b1_x1_diff->shape_view().elem_cnt();\n\n    
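// c2 = cw * cw + ch * ch + eps, so dc2/dcw = 2 * cw and dc2/dch = 2 * ch. The backward kernel\n    // assigns each factor to whichever box endpoint produced the max/min; since its comparisons\n    // are strict, ties route the gradient to the b2_* inputs.\n    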
ctx->Tensor4ArgNameAndIndex(\"b2_x2_diff\", 0);\n    user_op::Tensor* b1_y1_diff = ctx->Tensor4ArgNameAndIndex(\"b1_y1_diff\", 0);\n    user_op::Tensor* b1_y2_diff = ctx->Tensor4ArgNameAndIndex(\"b1_y2_diff\", 0);\n    user_op::Tensor* b2_y1_diff = ctx->Tensor4ArgNameAndIndex(\"b2_y1_diff\", 0);\n    user_op::Tensor* b2_y2_diff = ctx->Tensor4ArgNameAndIndex(\"b2_y2_diff\", 0);\n\n    const float eps = ctx->Attr<float>(\"eps\");\n    const int64_t elem_cnt = b1_x1_diff->shape_view().elem_cnt();\n\n    RUN_CUDA_KERNEL((FusedGetConvexDiagonalSquaredBackward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                    b1_x1->dptr<T>(), b1_x2->dptr<T>(), b2_x1->dptr<T>(), b2_x2->dptr<T>(),\n                    b1_y1->dptr<T>(), b1_y2->dptr<T>(), b2_y1->dptr<T>(), b2_y2->dptr<T>(),\n                    c2_diff->dptr<T>(), b1_x1_diff->mut_dptr<T>(), b1_x2_diff->mut_dptr<T>(),\n                    b2_x1_diff->mut_dptr<T>(), b2_x2_diff->mut_dptr<T>(), b1_y1_diff->mut_dptr<T>(),\n                    b1_y2_diff->mut_dptr<T>(), b2_y1_diff->mut_dptr<T>(), b2_y2_diff->mut_dptr<T>(),\n                    eps);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_GET_CONVEX_DIAGOAL_SQUARED_GRAD_CUDA_KERNEL(dtype) \\\n  REGISTER_USER_KERNEL(\"fused_get_convex_diagonal_squared_grad\")          \\\n      .SetCreateFn<FusedGetConvexDiagonalSquaredGradKernel<dtype>>()      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)    \\\n                       && (user_op::HobDataType(\"b1_x1\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_GET_CONVEX_DIAGOAL_SQUARED_GRAD_CUDA_KERNEL(float)\nREGISTER_FUSED_GET_CONVEX_DIAGOAL_SQUARED_GRAD_CUDA_KERNEL(double)\nREGISTER_FUSED_GET_CONVEX_DIAGOAL_SQUARED_GRAD_CUDA_KERNEL(half)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_get_intersection_area_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstruct MinMaxDeltaFunctor {\n  __device__ T Compute(T b1_x2_i, T b2_x2_i, T b1_x1_i, T b2_x1_i) const {\n    return min(b1_x2_i, b2_x2_i) - max(b1_x1_i, b2_x1_i);\n  }\n};\n\ntemplate<>\nstruct MinMaxDeltaFunctor<half> {\n  __device__ half Compute(half b1_x2_i, half b2_x2_i, half b1_x1_i, half b2_x1_i) const {\n    const half b_x2_min = b1_x2_i < b2_x2_i ? b1_x2_i : b2_x2_i;\n    const half b_x1_max = b1_x1_i > b2_x1_i ? b1_x1_i : b2_x1_i;\n    return b_x2_min - b_x1_max;\n  }\n};\n\ntemplate<typename FUNCTOR, typename T>\n__global__ void FusedGetIntersectionAreaBackward(FUNCTOR functor, const int n, const T* b1_x1,\n                                                 const T* b1_x2, const T* b2_x1, const T* b2_x2,\n                                                 const T* b1_y1, const T* b1_y2, const T* b2_y1,\n                                                 const T* b2_y2, const T* inter_diff, T* b1_x1_diff,\n                                                 T* b1_x2_diff, T* b2_x1_diff, T* b2_x2_diff,\n                                                 T* b1_y1_diff, T* b1_y2_diff, T* b2_y1_diff,\n                                                 T* b2_y2_diff) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T inter_diff_i = inter_diff[i];\n    const T b_x_min_max = functor.Compute(b1_x2[i], b2_x2[i], b1_x1[i], b2_x1[i]);\n    const T b_y_min_max = functor.Compute(b1_y2[i], b2_y2[i], b1_y1[i], b2_y1[i]);\n    const T b_x_min_max_inter = b_x_min_max * inter_diff_i;\n    const T b_y_min_max_inter = b_y_min_max * inter_diff_i;\n\n    b1_x1_diff[i] = static_cast<T>(0.0);\n    b1_x2_diff[i] = static_cast<T>(0.0);\n    b2_x1_diff[i] = static_cast<T>(0.0);\n    b2_x2_diff[i] = static_cast<T>(0.0);\n    b1_y1_diff[i] = static_cast<T>(0.0);\n    b1_y2_diff[i] = static_cast<T>(0.0);\n    b2_y1_diff[i] = static_cast<T>(0.0);\n    b2_y2_diff[i] = static_cast<T>(0.0);\n\n    if (b_x_min_max > static_cast<T>(0.0) && b_y_min_max > static_cast<T>(0.0)) {\n      if (b1_x1[i] >= b2_x1[i]) { b1_x1_diff[i] = static_cast<T>(-1.0) * b_y_min_max_inter; }\n      if (b1_x1[i] <= b2_x1[i]) { b2_x1_diff[i] = static_cast<T>(-1.0) * b_y_min_max_inter; }\n      if (b1_x2[i] <= b2_x2[i]) { b1_x2_diff[i] = b_y_min_max_inter; }\n      if (b1_x2[i] >= b2_x2[i]) { b2_x2_diff[i] = b_y_min_max_inter; }\n\n      if (b1_y1[i] >= b2_y1[i]) { b1_y1_diff[i] = static_cast<T>(-1.0) * b_x_min_max_inter; }\n      if (b1_y1[i] <= b2_y1[i]) { b2_y1_diff[i] = static_cast<T>(-1.0) * b_x_min_max_inter; }\n      if (b1_y2[i] <= b2_y2[i]) { b1_y2_diff[i] = b_x_min_max_inter; }\n      if (b1_y2[i] >= b2_y2[i]) { b2_y2_diff[i] = b_x_min_max_inter; }\n    }\n  }\n}\n\ntemplate<typename FUNCTOR, typename T>\n__global__ void 
FusedGetIntersectionAreaForward(FUNCTOR functor, const int n, const T* b1_x1,\n                                                const T* b1_x2, const T* b2_x1, const T* b2_x2,\n                                                const T* b1_y1, const T* b1_y2, const T* b2_y1,\n                                                const T* b2_y2, T* inter) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T b_x_min_max = functor.Compute(b1_x2[i], b2_x2[i], b1_x1[i], b2_x1[i]);\n    const T b_y_min_max = functor.Compute(b1_y2[i], b2_y2[i], b1_y1[i], b2_y1[i]);\n    inter[i] = static_cast<T>(0.0);\n    if (b_x_min_max > static_cast<T>(0.0) && b_y_min_max > static_cast<T>(0.0)) {\n      inter[i] = b_x_min_max * b_y_min_max;\n    }\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass FusedGetIntersectionAreaKernel final : public user_op::OpKernel {\n public:\n  FusedGetIntersectionAreaKernel() = default;\n  ~FusedGetIntersectionAreaKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* b1_x1 = ctx->Tensor4ArgNameAndIndex(\"b1_x1\", 0);\n    const user_op::Tensor* b1_x2 = ctx->Tensor4ArgNameAndIndex(\"b1_x2\", 0);\n    const user_op::Tensor* b2_x1 = ctx->Tensor4ArgNameAndIndex(\"b2_x1\", 0);\n    const user_op::Tensor* b2_x2 = ctx->Tensor4ArgNameAndIndex(\"b2_x2\", 0);\n    const user_op::Tensor* b1_y1 = ctx->Tensor4ArgNameAndIndex(\"b1_y1\", 0);\n    const user_op::Tensor* b1_y2 = ctx->Tensor4ArgNameAndIndex(\"b1_y2\", 0);\n    const user_op::Tensor* b2_y1 = ctx->Tensor4ArgNameAndIndex(\"b2_y1\", 0);\n    const user_op::Tensor* b2_y2 = ctx->Tensor4ArgNameAndIndex(\"b2_y2\", 0);\n\n    user_op::Tensor* inter = ctx->Tensor4ArgNameAndIndex(\"inter\", 0);\n\n    const int64_t elem_cnt = b1_x2->shape_view().elem_cnt();\n\n    MinMaxDeltaFunctor<T> min_max_delta_functor{};\n\n    RUN_CUDA_KERNEL((FusedGetIntersectionAreaForward<decltype(min_max_delta_functor), T>),\n                    ctx->stream(), elem_cnt, min_max_delta_functor, elem_cnt, b1_x1->dptr<T>(),\n                    b1_x2->dptr<T>(), b2_x1->dptr<T>(), b2_x2->dptr<T>(), b1_y1->dptr<T>(),\n                    b1_y2->dptr<T>(), b2_y1->dptr<T>(), b2_y2->dptr<T>(), inter->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_GET_INTERSECTION_AREA_CUDA_KERNEL(dtype)        \\\n  REGISTER_USER_KERNEL(\"fused_get_intersection_area\")                  \\\n      .SetCreateFn<FusedGetIntersectionAreaKernel<dtype>>()            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"inter\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_GET_INTERSECTION_AREA_CUDA_KERNEL(float)\nREGISTER_FUSED_GET_INTERSECTION_AREA_CUDA_KERNEL(double)\nREGISTER_FUSED_GET_INTERSECTION_AREA_CUDA_KERNEL(half)\n\ntemplate<typename T>\nclass FusedGetIntersectionAreaGradKernel final : public user_op::OpKernel {\n public:\n  FusedGetIntersectionAreaGradKernel() = default;\n  ~FusedGetIntersectionAreaGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* b1_x1 = ctx->Tensor4ArgNameAndIndex(\"b1_x1\", 0);\n    const user_op::Tensor* b1_x2 = ctx->Tensor4ArgNameAndIndex(\"b1_x2\", 0);\n    const user_op::Tensor* b2_x1 = ctx->Tensor4ArgNameAndIndex(\"b2_x1\", 0);\n    const user_op::Tensor* b2_x2 = 
ctx->Tensor4ArgNameAndIndex(\"b2_x2\", 0);\n    const user_op::Tensor* b1_y1 = ctx->Tensor4ArgNameAndIndex(\"b1_y1\", 0);\n    const user_op::Tensor* b1_y2 = ctx->Tensor4ArgNameAndIndex(\"b1_y2\", 0);\n    const user_op::Tensor* b2_y1 = ctx->Tensor4ArgNameAndIndex(\"b2_y1\", 0);\n    const user_op::Tensor* b2_y2 = ctx->Tensor4ArgNameAndIndex(\"b2_y2\", 0);\n\n    user_op::Tensor* inter_diff = ctx->Tensor4ArgNameAndIndex(\"inter_diff\", 0);\n\n    user_op::Tensor* b1_x1_diff = ctx->Tensor4ArgNameAndIndex(\"b1_x1_diff\", 0);\n    user_op::Tensor* b1_x2_diff = ctx->Tensor4ArgNameAndIndex(\"b1_x2_diff\", 0);\n    user_op::Tensor* b2_x1_diff = ctx->Tensor4ArgNameAndIndex(\"b2_x1_diff\", 0);\n    user_op::Tensor* b2_x2_diff = ctx->Tensor4ArgNameAndIndex(\"b2_x2_diff\", 0);\n    user_op::Tensor* b1_y1_diff = ctx->Tensor4ArgNameAndIndex(\"b1_y1_diff\", 0);\n    user_op::Tensor* b1_y2_diff = ctx->Tensor4ArgNameAndIndex(\"b1_y2_diff\", 0);\n    user_op::Tensor* b2_y1_diff = ctx->Tensor4ArgNameAndIndex(\"b2_y1_diff\", 0);\n    user_op::Tensor* b2_y2_diff = ctx->Tensor4ArgNameAndIndex(\"b2_y2_diff\", 0);\n\n    const int64_t elem_cnt = b1_x1->shape_view().elem_cnt();\n\n    MinMaxDeltaFunctor<T> min_max_delta_functor{};\n\n    RUN_CUDA_KERNEL((FusedGetIntersectionAreaBackward<decltype(min_max_delta_functor), T>),\n                    ctx->stream(), elem_cnt, min_max_delta_functor, elem_cnt, b1_x1->dptr<T>(),\n                    b1_x2->dptr<T>(), b2_x1->dptr<T>(), b2_x2->dptr<T>(), b1_y1->dptr<T>(),\n                    b1_y2->dptr<T>(), b2_y1->dptr<T>(), b2_y2->dptr<T>(), inter_diff->dptr<T>(),\n                    b1_x1_diff->mut_dptr<T>(), b1_x2_diff->mut_dptr<T>(), b2_x1_diff->mut_dptr<T>(),\n                    b2_x2_diff->mut_dptr<T>(), b1_y1_diff->mut_dptr<T>(), b1_y2_diff->mut_dptr<T>(),\n                    b2_y1_diff->mut_dptr<T>(), b2_y2_diff->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_GET_INTERSECTION_AREA_GRAD_CUDA_KERNEL(dtype)   \\\n  REGISTER_USER_KERNEL(\"fused_get_intersection_area_grad\")             \\\n      .SetCreateFn<FusedGetIntersectionAreaGradKernel<dtype>>()        \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"b1_x1_diff\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_GET_INTERSECTION_AREA_GRAD_CUDA_KERNEL(float)\nREGISTER_FUSED_GET_INTERSECTION_AREA_GRAD_CUDA_KERNEL(double)\nREGISTER_FUSED_GET_INTERSECTION_AREA_GRAD_CUDA_KERNEL(half)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_get_iou_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\ntemplate<typename T>\n__global__ void FusedGetIouForward(const int n, const T* w1, const T* h1, const T* w2, const T* h2,\n                                   const T* inter, T* iou, const float eps) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T inter_i = inter[i];\n    iou[i] = inter_i / (w1[i] * h1[i] + w2[i] * h2[i] - inter_i + static_cast<T>(eps));\n  }\n}\n\ntemplate<typename T>\n__global__ void FusedGetIouBackward(const int n, const T* diou, const T* w1, const T* h1,\n                                    const T* w2, const T* h2, const T* inter, T* dw1, T* dh1,\n                                    T* dinter, const float eps) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T w1_i = w1[i], h1_i = h1[i], w2_i = w2[i], h2_i = h2[i], inter_i = inter[i],\n            diou_i = diou[i];\n    const T w_h_eps = w1_i * h1_i + w2_i * h2_i + static_cast<T>(eps);\n    const T w_h_eps_inter_diff = w_h_eps - inter_i;\n    const T w_h_eps_inter_diff_square = w_h_eps_inter_diff * w_h_eps_inter_diff;\n    const T common_for_dwh = -inter_i * diou_i / w_h_eps_inter_diff_square;\n    dinter[i] = w_h_eps * diou_i / w_h_eps_inter_diff_square;\n    dw1[i] = h1_i * common_for_dwh;\n    dh1[i] = w1_i * common_for_dwh;\n  }\n}\n};  // namespace\n\ntemplate<typename T>\nclass FusedGetIouGpuKernel final : public user_op::OpKernel {\n public:\n  FusedGetIouGpuKernel() = default;\n  ~FusedGetIouGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* w1 = ctx->Tensor4ArgNameAndIndex(\"w1\", 0);\n    const user_op::Tensor* h1 = ctx->Tensor4ArgNameAndIndex(\"h1\", 0);\n    const user_op::Tensor* w2 = ctx->Tensor4ArgNameAndIndex(\"w2\", 0);\n    const user_op::Tensor* h2 = ctx->Tensor4ArgNameAndIndex(\"h2\", 0);\n    const user_op::Tensor* inter = ctx->Tensor4ArgNameAndIndex(\"inter\", 0);\n\n    user_op::Tensor* iou = ctx->Tensor4ArgNameAndIndex(\"iou\", 0);\n\n    float eps = ctx->Attr<float>(\"eps\");\n\n    const int32_t elem_cnt = w1->shape_view().elem_cnt();\n    RUN_CUDA_KERNEL((FusedGetIouForward<T>), ctx->stream(), elem_cnt, elem_cnt, w1->dptr<T>(),\n                    h1->dptr<T>(), w2->dptr<T>(), h2->dptr<T>(), inter->dptr<T>(),\n                    iou->mut_dptr<T>(), eps);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_GET_IOU_CUDA_KERNEL(dtype)                      \\\n  REGISTER_USER_KERNEL(\"fused_get_iou\")                                \\\n      .SetCreateFn<FusedGetIouGpuKernel<dtype>>()                      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"iou\", 0) == 
GetDataType<dtype>::value));\n\nREGISTER_FUSED_GET_IOU_CUDA_KERNEL(float)\nREGISTER_FUSED_GET_IOU_CUDA_KERNEL(half)\nREGISTER_FUSED_GET_IOU_CUDA_KERNEL(double)\n\ntemplate<typename T>\nclass FusedGetIouGradGpuKernel final : public user_op::OpKernel {\n public:\n  FusedGetIouGradGpuKernel() = default;\n  ~FusedGetIouGradGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* diou = ctx->Tensor4ArgNameAndIndex(\"diou\", 0);\n    const user_op::Tensor* w1 = ctx->Tensor4ArgNameAndIndex(\"w1\", 0);\n    const user_op::Tensor* h1 = ctx->Tensor4ArgNameAndIndex(\"h1\", 0);\n    const user_op::Tensor* w2 = ctx->Tensor4ArgNameAndIndex(\"w2\", 0);\n    const user_op::Tensor* h2 = ctx->Tensor4ArgNameAndIndex(\"h2\", 0);\n    const user_op::Tensor* inter = ctx->Tensor4ArgNameAndIndex(\"inter\", 0);\n\n    user_op::Tensor* dw1 = ctx->Tensor4ArgNameAndIndex(\"dw1\", 0);\n    user_op::Tensor* dh1 = ctx->Tensor4ArgNameAndIndex(\"dh1\", 0);\n    user_op::Tensor* dinter = ctx->Tensor4ArgNameAndIndex(\"dinter\", 0);\n\n    float eps = ctx->Attr<float>(\"eps\");\n\n    const int32_t elem_cnt = diou->shape_view().elem_cnt();\n\n    RUN_CUDA_KERNEL((FusedGetIouBackward<T>), ctx->stream(), elem_cnt, elem_cnt, diou->dptr<T>(),\n                    w1->dptr<T>(), h1->dptr<T>(), w2->dptr<T>(), h2->dptr<T>(), inter->dptr<T>(),\n                    dw1->mut_dptr<T>(), dh1->mut_dptr<T>(), dinter->mut_dptr<T>(), eps);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_GET_IOU_GRAD_CUDA_KERNEL(dtype)                 \\\n  REGISTER_USER_KERNEL(\"fused_get_iou_grad\")                           \\\n      .SetCreateFn<FusedGetIouGradGpuKernel<dtype>>()                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"diou\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_GET_IOU_GRAD_CUDA_KERNEL(float)\nREGISTER_FUSED_GET_IOU_GRAD_CUDA_KERNEL(half)\nREGISTER_FUSED_GET_IOU_GRAD_CUDA_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_glu_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include \"oneflow/core/ep/include/primitive/unary_op.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/ep/common/primitive/unary_functor.h\"\n#include \"oneflow/core/ep/cuda/primitive/unary_functor.cuh\"\n#include \"oneflow/core/kernel/util/cuda_half_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/user/kernels/cublas_fused_mlp_util.cuh\"\n\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\n#include \"oneflow/core/device/cuda_pseudo_bfloat16.h\"\n\n#if CUDA_VERSION >= 11020\n\n#ifdef WITH_CUTLASS\n\n#include \"device/dual_gemm.h\"\n#include \"thread/left_silu_and_mul.h\"\n\nnamespace cutlass {\nnamespace epilogue {\nnamespace thread {\n\ntemplate<typename ElementOutput_, int Count, template<typename> typename Activation,\n         typename ElementAccumulator_ = ElementOutput_, typename ElementCompute_ = ElementOutput_,\n         FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>\nclass RightActivationAndMul {\n public:\n  using ElementOutput = ElementOutput_;\n  using ElementAccumulator = ElementAccumulator_;\n  using ElementCompute = ElementCompute_;\n\n  static int const kCount = Count;\n  using FragmentOutput = Array<ElementOutput, kCount>;\n  using FragmentAccumulator = Array<ElementAccumulator, kCount>;\n  using ComputeFragment = Array<ElementCompute, kCount>;\n\n  static FloatRoundStyle const kRound = Round;\n\n  struct Params {};\n\n private:\n  ElementCompute alpha_;\n  ElementCompute beta_;\n\n public:\n  CUTLASS_HOST_DEVICE\n  RightActivationAndMul(Params const& /*params*/) {}\n\n  CUTLASS_HOST_DEVICE\n  bool is_source_needed() const { return true; }\n\n  CUTLASS_HOST_DEVICE\n  void set_k_partition(int k_partition, int k_partition_count) { assert(false); }\n\n  CUTLASS_HOST_DEVICE\n  FragmentOutput operator()(FragmentAccumulator const& lhs, FragmentAccumulator const& rhs) const {\n    NumericArrayConverter<ElementOutput, ElementAccumulator, kCount, Round> accumulator_to_output;\n\n    FragmentOutput converted_lhs = accumulator_to_output(lhs);\n    FragmentOutput converted_rhs = accumulator_to_output(rhs);\n\n    Activation<FragmentOutput> act;\n    cutlass::multiplies<FragmentOutput> mul;\n    auto act_rhs = act(converted_rhs);\n    return mul(act_rhs, converted_lhs);\n  }\n\n  CUTLASS_HOST_DEVICE\n  ElementOutput operator()(ElementAccumulator const& lhs, ElementAccumulator const& rhs) const {\n    ElementOutput convert_lhs(lhs);\n    ElementOutput convert_rhs(rhs);\n    Activation<ElementOutput> act;\n    cutlass::multiplies<ElementOutput> mul;\n    auto act_rhs = 
act(convert_rhs);\n    return mul(act_rhs, convert_lhs);\n  }\n};\n}  // namespace thread\n}  // namespace epilogue\n}  // namespace cutlass\n\n#endif  // WITH_CUTLASS\n\nnamespace oneflow {\n\nnamespace {\n\n#ifdef WITH_CUTLASS\n\ntemplate<typename T>\nstruct GetCutlassType {\n  using type = T;\n};\n\ntemplate<>\nstruct GetCutlassType<half> {\n  using type = cutlass::half_t;\n};\n\n#if CUDA_VERSION >= 11000\n\ntemplate<>\nstruct GetCutlassType<nv_bfloat16> {\n  using type = cutlass::bfloat16_t;\n};\n\n#endif\n\ntemplate<typename Acc, typename Arch, template<typename> typename Activation>\nvoid DualGemmGegluHalf(ep::CudaStream* stream, int32_t m, int32_t n, int32_t k, const void* x,\n                       const void* w, const void* v, const void* b, const void* c, void* wx,\n                       int32_t wx_stride, void* vx, int32_t vx_stride, void* y) {\n  constexpr int kStages = 5;\n  constexpr bool kSplitKSerial = false;\n  constexpr bool kUseBias = true;\n  using ElementOperandA = cutlass::half_t;\n  using ElementOperandB = cutlass::half_t;\n  using ElementOutput = cutlass::half_t;\n  using ElementAccumulator = Acc;\n  using ElementCompute = Acc;\n  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;\n  using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;\n  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;\n\n  constexpr auto kScaleType =\n      kUseBias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling\n               : (\n                   // No bias\n                   kSplitKSerial ? cutlass::epilogue::thread::ScaleType::Default\n                                 : cutlass::epilogue::thread::ScaleType::Nothing);\n  using EpilogueOutputOp0 =\n      cutlass::epilogue::thread::LinearCombination<ElementOutput,\n                                                   128 / cutlass::sizeof_bits<ElementOutput>::value,\n                                                   ElementAccumulator, ElementCompute, kScaleType>;\n  using EpilogueOutputOp1 =\n      cutlass::epilogue::thread::LinearCombination<ElementOutput,\n                                                   128 / cutlass::sizeof_bits<ElementOutput>::value,\n                                                   ElementAccumulator, ElementCompute, kScaleType>;\n  using EpilogueOutputOp2 = cutlass::epilogue::thread::RightActivationAndMul<\n      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, Activation, ElementOutput,\n      ElementCompute>;\n\n  const ElementCompute alpha0 = ElementCompute(1);\n  const ElementCompute beta0 = ElementCompute(kUseBias ? 1 : 0);\n  const ElementCompute alpha1 = ElementCompute(1);\n  const ElementCompute beta1 = ElementCompute(kUseBias ? 1 : 0);\n\n  // Optionally, we might not need intermediate GEMM outputs\n  constexpr bool kStoreD0 = true;\n  constexpr bool kStoreD1 = true;\n  using DualGemm = cutlass::gemm::device::DualGemm<\n      ElementOperandA, cutlass::layout::RowMajor, ElementOperandB, cutlass::layout::ColumnMajor,\n      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp,\n      Arch, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,\n      EpilogueOutputOp2, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, kStages,\n      kStoreD0, kStoreD1, kSplitKSerial>;\n\n  int split_k_slices = DualGemm::kSplitKSerial ? 
2 : 1;\n\n  typename cutlass::TensorRef<const ElementOperandA, cutlass::layout::RowMajor> tensor_a0(\n      reinterpret_cast<const cutlass::half_t*>(x), k);\n  typename cutlass::TensorRef<const ElementOperandA, cutlass::layout::ColumnMajor> tensor_b0(\n      reinterpret_cast<const cutlass::half_t*>(w), k);\n  typename cutlass::TensorRef<const ElementOperandA, cutlass::layout::ColumnMajor> tensor_b1(\n      reinterpret_cast<const cutlass::half_t*>(v), k);\n  typename cutlass::TensorRef<const ElementOperandA, cutlass::layout::RowMajor> tensor_bias0(\n      reinterpret_cast<const cutlass::half_t*>(b), {0});\n  typename cutlass::TensorRef<const ElementOperandA, cutlass::layout::RowMajor> tensor_bias1(\n      reinterpret_cast<const cutlass::half_t*>(c), {0});\n  typename cutlass::TensorRef<typename DualGemm::ElementC, typename DualGemm::LayoutC> tensor_d0(\n      reinterpret_cast<cutlass::half_t*>(wx), wx_stride);\n  typename cutlass::TensorRef<typename DualGemm::ElementC, typename DualGemm::LayoutC> tensor_d1(\n      reinterpret_cast<cutlass::half_t*>(vx), vx_stride);\n  typename cutlass::TensorRef<ElementOperandA, cutlass::layout::RowMajor> tensor_out(\n      reinterpret_cast<cutlass::half_t*>(y), n);\n\n  cutlass::gemm::GemmCoord problem_size(m, n, k);\n  typename DualGemm::Arguments arguments{\n      problem_size,    tensor_a0,    tensor_b0,     tensor_bias0, tensor_d0,\n      tensor_b1,       tensor_bias1, tensor_d1,     tensor_out,   {alpha0, beta0},\n      {alpha1, beta1}, {},           split_k_slices};\n\n  DualGemm dual_gemm_op;\n  dual_gemm_op.initialize(arguments, stream->cublas_workspace(), stream->cuda_stream());\n  dual_gemm_op(stream->cuda_stream());\n}\n\ntemplate<typename Acc, typename Arch>\nbool TryDispatchDualGemmImplActivation(ep::CudaStream* stream, const std::string& activation,\n                                       int32_t m, int32_t n, int32_t k, const void* x,\n                                       const void* w, const void* v, const void* b, const void* c,\n                                       void* wx, int32_t wx_stride, void* vx, int32_t vx_stride,\n                                       void* y) {\n  if (activation == \"fast_gelu\") {\n    DualGemmGegluHalf<Acc, Arch, cutlass::epilogue::thread::GELU_taylor>(\n        stream, m, n, k, x, w, v, b, c, wx, wx_stride, vx, vx_stride, y);\n    return true;\n  } else if (activation == \"gelu\") {\n    DualGemmGegluHalf<Acc, Arch, cutlass::epilogue::thread::GELU>(stream, m, n, k, x, w, v, b, c,\n                                                                  wx, wx_stride, vx, vx_stride, y);\n    return true;\n  } else {\n    return false;\n  }\n}\n\ntemplate<typename T, typename Arch>\nbool TryDispatchDualGemmImplAccType(ep::CudaStream* stream, const std::string& activation,\n                                    int32_t m, int32_t n, int32_t k, const T* x, const T* w,\n                                    const T* v, const T* b, const T* c, T* wx, int32_t wx_stride,\n                                    T* vx, int32_t vx_stride, T* y) {\n  const bool allow_half_precision =\n      ParseBooleanFromEnv(\"ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION\", false);\n  if (std::is_same<T, half>::value) {\n    if (allow_half_precision) {\n      return TryDispatchDualGemmImplActivation<cutlass::half_t, Arch>(\n          stream, activation, m, n, k, x, w, v, b, c, wx, wx_stride, vx, vx_stride, y);\n    } else {\n      return TryDispatchDualGemmImplActivation<float, Arch>(stream, activation, m, n, k, x, w, v, b,\n                
                                            c, wx, wx_stride, vx, vx_stride, y);\n    }\n  } else {\n    return false;\n  }\n}\n\ntemplate<typename T, typename Arch>\nbool TryDispatchDualGemmImplAlignment(ep::CudaStream* stream, const std::string& activation,\n                                      int32_t m, int32_t n, int32_t k, const T* x, const T* w,\n                                      const T* v, const T* b, const T* c, T* wx, int32_t wx_stride,\n                                      T* vx, int32_t vx_stride, T* y) {\n  if (m % 8 == 0 && n % 8 == 0 && k % 8 == 0\n      && reinterpret_cast<uintptr_t>(x) % (8 * sizeof(T)) == 0\n      && reinterpret_cast<uintptr_t>(w) % (8 * sizeof(T)) == 0\n      && reinterpret_cast<uintptr_t>(v) % (8 * sizeof(T)) == 0\n      && reinterpret_cast<uintptr_t>(b) % (8 * sizeof(T)) == 0\n      && reinterpret_cast<uintptr_t>(c) % (8 * sizeof(T)) == 0\n      && reinterpret_cast<uintptr_t>(wx) % (8 * sizeof(T)) == 0 && wx_stride % 8 == 0\n      && reinterpret_cast<uintptr_t>(vx) % (8 * sizeof(T)) == 0\n      && reinterpret_cast<uintptr_t>(y) % (8 * sizeof(T)) == 0 && vx_stride % 8 == 0) {\n    return TryDispatchDualGemmImplAccType<T, Arch>(stream, activation, m, n, k, x, w, v, b, c, wx,\n                                                   wx_stride, vx, vx_stride, y);\n  } else {\n    return false;\n  }\n}\n\ntemplate<typename T>\nbool TryDispatchDualGemmImplArchTag(ep::CudaStream* stream, const std::string& activation,\n                                    int32_t m, int32_t n, int32_t k, const T* x, const T* w,\n                                    const T* v, const T* b, const T* c, T* wx, int32_t wx_stride,\n                                    T* vx, int32_t vx_stride, T* y) {\n  const int arch = stream->cuda_arch();\n  if (arch == 800) {\n    return TryDispatchDualGemmImplAlignment<T, cutlass::arch::Sm80>(\n        stream, activation, m, n, k, x, w, v, b, c, wx, wx_stride, vx, vx_stride, y);\n  } else {\n    return false;\n  }\n}\n\n#endif  // WITH_CUTLASS\n\ntemplate<typename T>\nbool TryDispatchDualGemmImpl(ep::CudaStream* stream, const std::string& activation, int32_t m,\n                             int32_t n, int32_t k, const T* x, const T* w, const T* v, const T* b,\n                             const T* c, T* wx, int32_t wx_stride, T* vx, int32_t vx_stride, T* y) {\n#ifdef WITH_CUTLASS\n  const bool enabled = ParseBooleanFromEnv(\"ONEFLOW_KERNEL_GLU_ENABLE_DUAL_GEMM_IMPL\", true);\n  if (enabled) {\n    return TryDispatchDualGemmImplArchTag<T>(stream, activation, m, n, k, x, w, v, b, c, wx,\n                                             wx_stride, vx, vx_stride, y);\n  } else {\n    return false;\n  }\n#else\n  return false;\n#endif  // WITH_CUTLASS\n}\n\ntemplate<typename T, typename IndexType, ep::primitive::UnaryOp act_type, int32_t pack_size>\n__global__ void FusedGluForwardGpu(\n    const IndexType m, const IndexType packed_n, const IndexType packed_num,\n    const IndexType packed_stride,\n    ep::primitive::UnaryFunctor<DeviceType::kCUDA, act_type, T, T> act, T* matmul_wx, T* matmul_vx,\n    T* y) {\n  // obtain global thread index\n  IndexType global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;\n\n  // define type of Pack\n  using LoadPack = cuda::elementwise::Packed<T, pack_size>;\n\n  // workload of current thread\n  for (IndexType packed_index = global_thread_id, step = gridDim.x * blockDim.x;\n       packed_index < packed_num; packed_index += step) {\n    // obtain the row and col index in output tensor \"y\"\n    const IndexType 
y_packed_row = packed_index / packed_n;\n    const IndexType y_packed_col = packed_index - y_packed_row * packed_n;\n\n    // cast type to load type\n    const LoadPack* matmul_wx_load =\n        reinterpret_cast<LoadPack*>(matmul_wx) + (y_packed_row * packed_stride + y_packed_col);\n    const LoadPack* matmul_vx_load =\n        reinterpret_cast<LoadPack*>(matmul_vx) + (y_packed_row * packed_stride + y_packed_col);\n\n    // init vectors\n    LoadPack matmul_wx_vec = *matmul_wx_load;\n    LoadPack matmul_vx_vec = *matmul_vx_load;\n    LoadPack y_vec;\n\n#pragma unroll\n    for (int i = 0; i < pack_size; i++) {\n      // obtain the hidden_state and gate\n      T hidden_state = matmul_wx_vec.elem[i];\n      T gate = matmul_vx_vec.elem[i];\n\n      // calculate activation\n      T act_gate = act(gate);\n\n      // calculate element-wise product\n      y_vec.elem[i] = hidden_state * act_gate;\n    }\n    *(reinterpret_cast<LoadPack*>(y + packed_index * pack_size)) = y_vec;\n  }\n}\n\ntemplate<typename T, typename IndexType, ep::primitive::UnaryOp act_type, int32_t pack_size>\nvoid LaunchFusedGluForwardGpu(ep::Stream* stream, const IndexType m, const IndexType packed_n,\n                              const IndexType pack_num, const IndexType packed_stride, T* matmul_wx,\n                              T* matmul_vx, T* y) {\n  constexpr int32_t block_size = 128;\n  unsigned int grid_size = (pack_num + block_size - 1) / block_size;\n  ep::primitive::UnaryFunctor<DeviceType::kCUDA, act_type, T, T> act(0, 0);\n  FusedGluForwardGpu<T, IndexType, act_type, pack_size>\n      <<<grid_size, block_size, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          m, packed_n, pack_num, packed_stride, act, matmul_wx, matmul_vx, y);\n}\n\ntemplate<typename T, ep::primitive::UnaryOp act_type, int32_t pack_size>\nvoid DispatchIndexType(ep::Stream* stream, const int64_t m, const int64_t packed_n,\n                       const int64_t pack_num, const int64_t packed_stride, T* matmul_wx,\n                       T* matmul_vx, T* y) {\n  // dispatch index type\n  if (pack_num < (1 << 30)) {\n    LaunchFusedGluForwardGpu<T, int32_t, act_type, pack_size>(\n        stream, m, packed_n, pack_num, packed_stride, matmul_wx, matmul_vx, y);\n  } else {\n    LaunchFusedGluForwardGpu<T, int64_t, act_type, pack_size>(\n        stream, m, packed_n, pack_num, packed_stride, matmul_wx, matmul_vx, y);\n  }\n}\n\ntemplate<typename T, ep::primitive::UnaryOp act_type, int32_t alignment,\n         typename std::enable_if<alignment / sizeof(T) == 0, int>::type = 0>\nvoid DispatchPackSize(ep::Stream* stream, const int64_t m, const int64_t n, const int64_t stride,\n                      T* matmul_wx, T* matmul_vx, T* y) {\n  DispatchIndexType<T, act_type, 1>(stream, m, n, m * n, stride, matmul_wx, matmul_vx, y);\n}\n\ntemplate<typename T, ep::primitive::UnaryOp act_type, int32_t alignment,\n         typename std::enable_if<alignment / sizeof(T) != 0, int>::type = 0>\nvoid DispatchPackSize(ep::Stream* stream, const int64_t m, const int64_t n, const int64_t stride,\n                      T* matmul_wx, T* matmul_vx, T* y) {\n  const int64_t pack_size = alignment / sizeof(T);\n  const int64_t packed_n = n / pack_size;\n  const int64_t pack_num = m * packed_n;\n  const int64_t packed_stride = stride / pack_size;\n  DispatchIndexType<T, act_type, alignment / sizeof(T)>(stream, m, packed_n, pack_num,\n                                                        packed_stride, matmul_wx, matmul_vx, y);\n}\n\ntemplate<typename T, 
ep::primitive::UnaryOp act_type>\nvoid DispatchAlignment(ep::Stream* stream, const int64_t m, const int64_t n, const int64_t stride,\n                       T* matmul_wx, T* matmul_vx, T* y) {\n  const auto IsAligned = [&](const size_t alignment) {\n    const uintptr_t matmul_wx_ptr = reinterpret_cast<uintptr_t>(matmul_wx);\n    const uintptr_t matmul_vx_ptr = reinterpret_cast<uintptr_t>(matmul_vx);\n    const uintptr_t y_ptr = reinterpret_cast<uintptr_t>(y);\n\n    return (/* memory address alignment */\n            matmul_wx_ptr % alignment == 0 && matmul_vx_ptr % alignment == 0\n            && y_ptr % alignment == 0\n            /* #element per row alignment */\n            && n % (alignment / sizeof(T)) == 0);\n  };\n\n  if (IsAligned(16)) {\n    DispatchPackSize<T, act_type, 16>(stream, m, n, stride, matmul_wx, matmul_vx, y);\n  } else if (IsAligned(8)) {\n    DispatchPackSize<T, act_type, 8>(stream, m, n, stride, matmul_wx, matmul_vx, y);\n  } else if (IsAligned(4)) {\n    DispatchPackSize<T, act_type, 4>(stream, m, n, stride, matmul_wx, matmul_vx, y);\n  } else if (IsAligned(2)) {\n    DispatchPackSize<T, act_type, 2>(stream, m, n, stride, matmul_wx, matmul_vx, y);\n  } else {\n    DispatchPackSize<T, act_type, 1>(stream, m, n, stride, matmul_wx, matmul_vx, y);\n  }\n}\n\ntemplate<typename T>\nvoid DispatchActivationType(ep::Stream* stream, const int64_t m, const int64_t n,\n                            const int64_t stride, T* matmul_wx, T* matmul_vx, T* y,\n                            const std::string& activation) {\n  if (activation == \"none\") {\n    DispatchAlignment<T, ep::primitive::UnaryOp::kIdentity>(stream, m, n, stride, matmul_wx,\n                                                            matmul_vx, y);\n  } else if (activation == \"sigmoid\") {\n    DispatchAlignment<T, ep::primitive::UnaryOp::kSigmoid>(stream, m, n, stride, matmul_wx,\n                                                           matmul_vx, y);\n  } else if (activation == \"relu\") {\n    DispatchAlignment<T, ep::primitive::UnaryOp::kRelu>(stream, m, n, stride, matmul_wx, matmul_vx,\n                                                        y);\n  } else if (activation == \"gelu\") {\n    DispatchAlignment<T, ep::primitive::UnaryOp::kGelu>(stream, m, n, stride, matmul_wx, matmul_vx,\n                                                        y);\n  } else if (activation == \"fast_gelu\") {\n    DispatchAlignment<T, ep::primitive::UnaryOp::kFastGelu>(stream, m, n, stride, matmul_wx,\n                                                            matmul_vx, y);\n  } else if (activation == \"silu\") {\n    DispatchAlignment<T, ep::primitive::UnaryOp::kSilu>(stream, m, n, stride, matmul_wx, matmul_vx,\n                                                        y);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T>\nclass GpuFusedGluKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  GpuFusedGluKernel() = default;\n  ~GpuFusedGluKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateCublasFusedMLPKernelCache();\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    // obtain tensors from context\n    const user_op::Tensor* input_tensor_x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* 
input_tensor_w = ctx->Tensor4ArgNameAndIndex(\"w\", 0);\n    user_op::Tensor* out_tensor_y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* out_tensor_matmul_wx = ctx->Tensor4ArgNameAndIndex(\"matmul_wx\", 0);\n\n    // obtain optional tensors from context\n    bool is_split_mode = false;\n    user_op::Tensor* input_tensor_b = nullptr;\n    user_op::Tensor* input_tensor_v = nullptr;\n    user_op::Tensor* input_tensor_c = nullptr;\n    user_op::Tensor* out_tensor_matmul_vx = nullptr;\n\n    auto* cuda_stream = ctx->stream()->As<ep::CudaStream>();\n    const auto* fused_glu_cache =\n        CHECK_NOTNULL(dynamic_cast<const CublasFusedMLPKernelCache*>(cache));\n\n    // check whether the user provides weight tensor v\n    if (ctx->has_input(\"v\", 0)) {\n      input_tensor_v = ctx->Tensor4ArgNameAndIndex(\"v\", 0);\n      out_tensor_matmul_vx = ctx->Tensor4ArgNameAndIndex(\"matmul_vx\", 0);\n      is_split_mode = true;\n    }\n\n    bool has_b = ctx->has_input(\"b\", 0);\n    bool has_c = ctx->has_input(\"c\", 0);\n\n    // check whether the user provides bias tensors\n    CHECK(!(has_b && (is_split_mode && !has_c)))\n        << \"expected existence of c when tensors w, v and b are provided\";\n    bool has_bias = false;\n    if (has_b && (is_split_mode && has_c)) {\n      input_tensor_b = ctx->Tensor4ArgNameAndIndex(\"b\", 0);\n      input_tensor_c = ctx->Tensor4ArgNameAndIndex(\"c\", 0);\n      has_bias = true;\n    } else if (has_b && (!is_split_mode)) {\n      input_tensor_b = ctx->Tensor4ArgNameAndIndex(\"b\", 0);\n      has_bias = true;\n    } else {\n      has_bias = false;\n    }\n\n    cublasLtEpilogue_t epilogue;\n    if (has_bias) {\n      epilogue = CUBLASLT_EPILOGUE_BIAS;\n    } else {\n      epilogue = CUBLASLT_EPILOGUE_DEFAULT;\n    }\n\n    // obtain tensor shapes\n    const ShapeView& x_shape = input_tensor_x->shape_view();\n    const ShapeView& w_shape = input_tensor_w->shape_view();\n    ShapeView b_shape;\n    if (has_bias) {\n      Shape _b_shape;\n      input_tensor_b->shape_view().ToShape(&_b_shape);\n      b_shape = ShapeView(_b_shape);\n    }\n    const ShapeView& y_shape = out_tensor_y->shape_view();\n\n    // validate dimension and number of axes\n    CHECK_GT(x_shape.NumAxes(), 1)\n        << \"number of axes of \\'x\\' should be greater than 1, but got \" << x_shape.NumAxes();\n    CHECK_EQ(w_shape.NumAxes(), 2)\n        << \"number of axes of \\'w\\' should be equal to 2, but got \" << w_shape.NumAxes();\n    if (has_bias) {\n      CHECK_EQ(b_shape.NumAxes(), 1)\n          << \"number of axes of \\'b\\' should be equal to 1, but got \" << b_shape.NumAxes();\n    }\n\n    // check input tensor shapes\n    size_t x_num_axes = x_shape.NumAxes();\n    CHECK_EQ(w_shape.At(1), x_shape.At(x_num_axes - 1))\n        << \"dimension 1 of \\'w\\'(\" << w_shape.At(1)\n        << \") is not consistent with the last dimension of \\'x\\'(\" << x_shape.At(x_num_axes - 1)\n        << \")\";\n    if (has_bias) {\n      CHECK_EQ(b_shape.At(0), w_shape.At(0))\n          << \"dimension 0 of \\'b\\'(\" << b_shape.At(0)\n          << \") is not consistent with dimension 0 of \\'w\\'(\" << w_shape.At(0) << \")\";\n    }\n    if (!is_split_mode) {\n      // in non-split mode w stacks both projections along dimension 0, so that is the axis\n      // that must be even\n      CHECK_EQ(w_shape.At(0) % 2, 0) << \"dimension 0 of \\'w\\' is not divisible by 2\";\n    }\n\n    // check optional input tensor shapes\n    if (is_split_mode) {\n      const ShapeView& v_shape = input_tensor_v->shape_view();\n      CHECK_EQ(v_shape.NumAxes(), 2)\n          << \"number of axes of \\'v\\' should be equal to 2, but got \" << v_shape.NumAxes();\n      CHECK_EQ(v_shape, w_shape) << \"the shape of \\'v\\' is not consistent with \\'w\\'\";\n      if (has_bias) {\n        const ShapeView& c_shape = input_tensor_c->shape_view();\n        CHECK_EQ(c_shape.NumAxes(), 1)\n            << \"number of axes of \\'c\\' should be equal to 1, but got \" << c_shape.NumAxes();\n        CHECK_EQ(c_shape, b_shape) << \"the shape of \\'c\\' is not consistent with \\'b\\'\";\n      }\n    }\n\n    // obtain data type for cublaslt computation\n    const DataType data_type = out_tensor_matmul_wx->data_type();\n    const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type);\n    const cudaDataType_t cuda_data_type = GetCudaDataType(data_type);\n\n    // infer m, n, k\n    const int64_t m = x_shape.Count(0, x_num_axes - 1);\n    const int64_t n = y_shape.At(x_num_axes - 1);\n    const int64_t k = x_shape.At(x_num_axes - 1);\n\n    
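// Fast path: with biases present, TryDispatchDualGemmImpl can fuse both projections and the\n    // gated activation into a single CUTLASS dual-GEMM launch (half data on sm_80 with 8-element\n    // alignment and a gelu/fast_gelu activation, gated by ONEFLOW_KERNEL_GLU_ENABLE_DUAL_GEMM_IMPL).\n    // On success nothing else remains to do; otherwise fall through to the cublasLt matmul(s)\n    // plus the elementwise epilogue below.\n    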
if (has_bias) {\n      if (TryDispatchDualGemmImpl(\n              ctx->stream()->As<ep::CudaStream>(), ctx->Attr<std::string>(\"activation\"), m, n, k,\n              input_tensor_x->dptr<T>(), input_tensor_w->dptr<T>(),\n              is_split_mode ? input_tensor_v->dptr<T>() : input_tensor_w->dptr<T>() + n * k,\n              input_tensor_b->dptr<T>(),\n              is_split_mode ? input_tensor_c->dptr<T>() : input_tensor_b->dptr<T>() + n,\n              out_tensor_matmul_wx->mut_dptr<T>(), is_split_mode ? n : 2 * n,\n              is_split_mode ? out_tensor_matmul_vx->mut_dptr<T>()\n                            : out_tensor_matmul_wx->mut_dptr<T>() + n,\n              is_split_mode ? n : 2 * n, out_tensor_y->mut_dptr<T>())) {\n        return;\n      }\n    }\n\n    // init scalar parameters for cublaslt\n    const double alpha = 1.0;\n    const double beta = 0.0;\n    const auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype);\n    const auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype);\n\n    // calculate matmul_wx (and matmul_vx) through cublaslt\n    if (is_split_mode) {\n      // define shape parameters to be inferred\n      size_t cublas_wx_m = 0, cublas_wx_n = 0, cublas_wx_k = 0;\n      int64_t cublas_wx_lda = 0, cublas_wx_ldb = 0, cublas_wx_ldc = 0;\n      size_t cublas_vx_m = 0, cublas_vx_n = 0, cublas_vx_k = 0;\n      int64_t cublas_vx_lda = 0, cublas_vx_ldb = 0, cublas_vx_ldc = 0;\n\n      // init dim vector\n      DimVector x_dim_vec({m, k});\n      DimVector w_dim_vec({n, k});\n      DimVector v_dim_vec({n, k});\n\n      // setup cublaslt matmul attributes\n      InferMatmulCublasMNK(x_dim_vec, w_dim_vec,\n                           /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                           /*transpose_b=*/ep::primitive::BlasTransposeType::T, &cublas_wx_m,\n                           &cublas_wx_n, &cublas_wx_k, &cublas_wx_lda, &cublas_wx_ldb,\n                           &cublas_wx_ldc);\n      SetCublasAttr(fused_glu_cache, cublas_compute_dtype, cuda_data_type, false,\n                    /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                    /*transpose_b=*/ep::primitive::BlasTransposeType::T, epilogue,\n                    has_bias ? 
input_tensor_b->dptr() : nullptr, nullptr, cublas_wx_m, cublas_wx_n,\n                    cublas_wx_k, cublas_wx_lda, cublas_wx_ldb, cublas_wx_ldc);\n\n      // setup algorithms\n      cublasLtMatmulPreference_t preference = nullptr;\n      size_t workspace_size = cuda_stream->cublas_workspace_size();\n      OF_CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&preference));\n      OF_CUBLAS_CHECK(\n          cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,\n                                               &workspace_size, sizeof(workspace_size)));\n      int wx_returned_result = 0;\n      cublasLtMatmulHeuristicResult_t wx_heuristic_result;\n      OF_CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(\n          cuda_stream->cublas_lt_handle(), fused_glu_cache->operation_desc,\n          fused_glu_cache->cublas_a_desc, fused_glu_cache->cublas_b_desc,\n          fused_glu_cache->cublas_c_desc, fused_glu_cache->cublas_c_desc, preference, 1,\n          &wx_heuristic_result, &wx_returned_result));\n      CHECK_EQ(wx_returned_result, 1);\n\n      // launch cublaslt matmul\n      // out_tensor_matmul_wx = 1.0 * (input_tensor_w * input_tensor_x) + 1.0 * input_tensor_b\n      OF_CUBLAS_CHECK(cublasLtMatmul(\n          /*lightHandle*/ cuda_stream->cublas_lt_handle(),\n          /*computeDesc*/ fused_glu_cache->operation_desc,\n          /*alpha*/ &sp_alpha,\n          /*A*/ input_tensor_w->dptr(),\n          /*Adesc*/ fused_glu_cache->cublas_a_desc,\n          /*B*/ input_tensor_x->dptr(),\n          /*Bdesc*/ fused_glu_cache->cublas_b_desc,\n          /*beta*/ &sp_beta,\n          /*C*/ has_bias ? input_tensor_b->dptr() : nullptr,\n          /*Cdesc*/ fused_glu_cache->cublas_c_desc,\n          /*D*/ out_tensor_matmul_wx->mut_dptr(),\n          /*Ddesc*/ fused_glu_cache->cublas_c_desc,\n          /*algo*/ &wx_heuristic_result.algo,\n          /*workspace*/ cuda_stream->cublas_workspace(),\n          /*workspaceSizeInBytes*/ cuda_stream->cublas_workspace_size(),\n          /*stream*/ cuda_stream->cuda_stream()));\n\n      // setup cublaslt attributes\n      InferMatmulCublasMNK(x_dim_vec, v_dim_vec,\n                           /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                           /*transpose_b=*/ep::primitive::BlasTransposeType::T, &cublas_vx_m,\n                           &cublas_vx_n, &cublas_vx_k, &cublas_vx_lda, &cublas_vx_ldb,\n                           &cublas_vx_ldc);\n      SetCublasAttr(fused_glu_cache, cublas_compute_dtype, cuda_data_type, false,\n                    /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                    /*transpose_b=*/ep::primitive::BlasTransposeType::T, epilogue,\n                    has_bias ? 
input_tensor_c->dptr() : nullptr, nullptr, cublas_vx_m, cublas_vx_n,\n                    cublas_vx_k, cublas_vx_lda, cublas_vx_ldb, cublas_vx_ldc);\n\n      // setup algorithm\n      int vx_returned_result = 0;\n      cublasLtMatmulHeuristicResult_t vx_heuristic_result;\n      OF_CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(\n          cuda_stream->cublas_lt_handle(), fused_glu_cache->operation_desc,\n          fused_glu_cache->cublas_a_desc, fused_glu_cache->cublas_b_desc,\n          fused_glu_cache->cublas_c_desc, fused_glu_cache->cublas_c_desc, preference, 1,\n          &vx_heuristic_result, &vx_returned_result));\n      CHECK_EQ(vx_returned_result, 1);\n      cublasLtMatmulPreferenceDestroy(preference);\n\n      // launch cublaslt matmul\n      // out_tensor_matmul_vx = 1.0 * (input_tensor_v * input_tensor_x) + 1.0 * input_tensor_c\n      OF_CUBLAS_CHECK(cublasLtMatmul(\n          /*lightHandle*/ cuda_stream->cublas_lt_handle(),\n          /*computeDesc*/ fused_glu_cache->operation_desc,\n          /*alpha*/ &sp_alpha,\n          /*A*/ input_tensor_v->dptr(),\n          /*Adesc*/ fused_glu_cache->cublas_a_desc,\n          /*B*/ input_tensor_x->dptr(),\n          /*Bdesc*/ fused_glu_cache->cublas_b_desc,\n          /*beta*/ &sp_beta,\n          /*C*/ has_bias ? input_tensor_c->dptr() : nullptr,\n          /*Cdesc*/ fused_glu_cache->cublas_c_desc,\n          /*D*/ out_tensor_matmul_vx->mut_dptr(),\n          /*Ddesc*/ fused_glu_cache->cublas_c_desc,\n          /*algo*/ &vx_heuristic_result.algo,\n          /*workspace*/ cuda_stream->cublas_workspace(),\n          /*workspaceSizeInBytes*/ cuda_stream->cublas_workspace_size(),\n          /*stream*/ cuda_stream->cuda_stream()));\n    } else {\n      
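// In non-split mode the packed weight holds both projections: rows [0, n) of w produce\n      // matmul_wx and rows [n, 2 * n) produce matmul_vx, so a single (2 * n) x k matmul computes\n      // both halves, written into the wx buffer with row stride 2 * n.\n      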
input_tensor_b->dptr() : nullptr, nullptr, cublas_m, cublas_n,\n                    cublas_k, cublas_lda, cublas_ldb, cublas_ldc);\n\n      // setup algorithm\n      cublasLtMatmulPreference_t preference = nullptr;\n      size_t workspace_size = cuda_stream->cublas_workspace_size();\n      OF_CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&preference));\n      OF_CUBLAS_CHECK(\n          cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,\n                                               &workspace_size, sizeof(workspace_size)));\n      int wx_returned_result = 0;\n      cublasLtMatmulHeuristicResult_t wx_heuristic_result;\n      OF_CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(\n          cuda_stream->cublas_lt_handle(), fused_glu_cache->operation_desc,\n          fused_glu_cache->cublas_a_desc, fused_glu_cache->cublas_b_desc,\n          fused_glu_cache->cublas_c_desc, fused_glu_cache->cublas_c_desc, preference, 1,\n          &wx_heuristic_result, &wx_returned_result));\n      CHECK_EQ(wx_returned_result, 1);\n      cublasLtMatmulPreferenceDestroy(preference);\n\n      // launch cublaslt matmul\n      // out_tensor_matmul_wx = 1.0 * (input_tensor_w * input_tensor_x) + 1.0 * input_tensor_b\n      OF_CUBLAS_CHECK(cublasLtMatmul(\n          /*lightHandle*/ cuda_stream->cublas_lt_handle(),\n          /*computeDesc*/ fused_glu_cache->operation_desc,\n          /*alpha*/ &sp_alpha,\n          /*A*/ input_tensor_w->dptr(),\n          /*Adesc*/ fused_glu_cache->cublas_a_desc,\n          /*B*/ input_tensor_x->dptr(),\n          /*Bdesc*/ fused_glu_cache->cublas_b_desc,\n          /*beta*/ &sp_beta,\n          /*C*/ has_bias ? input_tensor_b->dptr() : nullptr,\n          /*Cdesc*/ fused_glu_cache->cublas_c_desc,\n          /*D*/ out_tensor_matmul_wx->mut_dptr(),\n          /*Ddesc*/ fused_glu_cache->cublas_c_desc,\n          /*algo*/ &wx_heuristic_result.algo,\n          /*workspace*/ cuda_stream->cublas_workspace(),\n          /*workspaceSizeInBytes*/ cuda_stream->cublas_workspace_size(),\n          /*stream*/ cuda_stream->cuda_stream()));\n    }\n\n    // dispatch according to activation type\n    DispatchActivationType<T>(ctx->stream(),\n                              /*m, n=*/m, n,\n                              /*stride=*/is_split_mode ? n : 2 * n,\n                              /*matmul_wx=*/out_tensor_matmul_wx->mut_dptr<T>(),\n                              /*matmul_vx=*/\n                              is_split_mode ? out_tensor_matmul_vx->mut_dptr<T>()\n                                            : out_tensor_matmul_wx->mut_dptr<T>() + n,\n                              /*y=*/out_tensor_y->mut_dptr<T>(),\n                              /*activation=*/ctx->Attr<std::string>(\"activation\"));\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n#define REGISTER_GPU_FUSED_GLU_KERNEL(dtype)                           \\\n  REGISTER_USER_KERNEL(\"fused_glu\")                                    \\\n      .SetCreateFn<GpuFusedGluKernel<dtype>>()                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\nREGISTER_GPU_FUSED_GLU_KERNEL(double)\nREGISTER_GPU_FUSED_GLU_KERNEL(float)\nREGISTER_GPU_FUSED_GLU_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_GPU_FUSED_GLU_KERNEL(nv_bfloat16)\n#endif\n\n}  // namespace oneflow\n\n#endif  // CUDA_VERSION >= 11020\n"
  },
  {
    "path": "oneflow/user/kernels/fused_glu_without_linear_grad_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include \"oneflow/core/ep/include/primitive/binary_op.h\"\n#include \"oneflow/core/ep/common/primitive/binary_functor.h\"\n#include \"oneflow/core/ep/cuda/primitive/binary_functor.cuh\"\n#include \"oneflow/core/ep/include/primitive/unary_op.h\"\n#include \"oneflow/core/ep/common/primitive/unary_functor.h\"\n#include \"oneflow/core/ep/cuda/primitive/unary_functor.cuh\"\n\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\n#include \"oneflow/core/device/cuda_pseudo_bfloat16.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n// declear using \"BinaryFunctor\" from namespace \"ep::primitive::broadcast_elementwise_binary\"\ntemplate<DeviceType device, ep::primitive::BinaryOp binary_op, typename Src, typename Dst>\nusing BinaryFunctor =\n    ep::primitive::broadcast_elementwise_binary::BinaryFunctor<device, binary_op, Src, Dst>;\n\ntemplate<typename T, typename IndexType, ep::primitive::BinaryOp d_act_type,\n         ep::primitive::UnaryOp act_type, int32_t pack_size>\n__global__ void FusedGluWithoutLinearGradGpu(\n    const IndexType m, const IndexType packed_n, const IndexType pack_num,\n    const IndexType packed_stride, BinaryFunctor<DeviceType::kCUDA, d_act_type, T, T> dact,\n    ep::primitive::UnaryFunctor<DeviceType::kCUDA, act_type, T, T> act, const T* dy,\n    const T* matmul_wx, const T* matmul_vx, T* d_matmul_wx, T* d_matmul_vx) {\n  // define type of Pack\n  using LoadPack = cuda::elementwise::Packed<T, pack_size>;\n\n  // obtain global thread index\n  IndexType global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;\n\n  // workload of current thread\n  for (IndexType packed_index = global_thread_id, step = gridDim.x * blockDim.x;\n       packed_index < pack_num; packed_index += step) {\n    // obtain the row and col index in output tensor \"d_matmul_wx\" and \"d_matmul_vx\"\n    const IndexType packed_row = packed_index / packed_n;\n    const IndexType packed_col = packed_index - packed_row * packed_n;\n\n    // cast type to load type\n    const LoadPack* dy_load =\n        reinterpret_cast<const LoadPack*>(dy) + (packed_row * packed_n + packed_col);\n    const LoadPack* matmul_wx_load =\n        reinterpret_cast<const LoadPack*>(matmul_wx) + (packed_row * packed_stride + packed_col);\n    const LoadPack* matmul_vx_load =\n        reinterpret_cast<const LoadPack*>(matmul_vx) + (packed_row * packed_stride + packed_col);\n\n    // init vectors\n    LoadPack dy_vec = *dy_load;\n    LoadPack matmul_wx_vec = *matmul_wx_load;\n    LoadPack matmul_vx_vec = *matmul_vx_load;\n    LoadPack d_matmul_wx_vec;\n    LoadPack d_matmul_vx_vec;\n#pragma unroll\n    for 
(int i = 0; i < pack_size; i++) {\n      // calculate the gradient of activated gate\n      T d_act_gate = matmul_wx_vec.elem[i] * dy_vec.elem[i];\n\n      // calculate the gradient of hidden_state\n      T gate = matmul_vx_vec.elem[i];\n      T act_gate = act(gate);\n      d_matmul_wx_vec.elem[i] = act_gate * dy_vec.elem[i];  // d_hidden_state\n\n      // calculate the gradient of gate\n      d_matmul_vx_vec.elem[i] = dact(d_act_gate, gate);  // d_gate\n    }\n    *(reinterpret_cast<LoadPack*>(d_matmul_wx) + (packed_row * packed_stride + packed_col)) =\n        d_matmul_wx_vec;\n    *(reinterpret_cast<LoadPack*>(d_matmul_vx) + (packed_row * packed_stride + packed_col)) =\n        d_matmul_vx_vec;\n  }\n}\n\ntemplate<typename T, typename IndexType, ep::primitive::UnaryOp act_type,\n         ep::primitive::BinaryOp d_act_type, int32_t pack_size>\nvoid LaunchFusedGluWithoutLinearGradGpu(ep::Stream* stream, const IndexType m,\n                                        const IndexType packed_n, const IndexType pack_num,\n                                        const IndexType packed_stride, const T* dy,\n                                        const T* matmul_wx, const T* matmul_vx, T* d_matmul_wx,\n                                        T* d_matmul_vx) {\n  constexpr int32_t block_size = 128;\n  unsigned int grid_size = (pack_num + block_size - 1) / block_size;\n  ep::primitive::UnaryFunctor<DeviceType::kCUDA, act_type, T, T> act(0, 0);\n  BinaryFunctor<DeviceType::kCUDA, d_act_type, T, T> dact(0, 0);\n  FusedGluWithoutLinearGradGpu<T, IndexType, d_act_type, act_type, pack_size>\n      <<<grid_size, block_size, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          m, packed_n, pack_num, packed_stride, dact, act, dy, matmul_wx, matmul_vx, d_matmul_wx,\n          d_matmul_vx);\n}\n\ntemplate<typename T, ep::primitive::UnaryOp act_type, ep::primitive::BinaryOp d_act_type,\n         int32_t pack_size>\nvoid DispatchIndexType(ep::Stream* stream, const int64_t m, const int64_t packed_n,\n                       const int64_t pack_num, const int64_t packed_stride, const T* dy,\n                       const T* matmul_wx, const T* matmul_vx, T* d_matmul_wx, T* d_matmul_vx) {\n  if (pack_num < (1 << 30)) {\n    LaunchFusedGluWithoutLinearGradGpu<T, int32_t, act_type, d_act_type, pack_size>(\n        stream, m, packed_n, pack_num, packed_stride, dy, matmul_wx, matmul_vx, d_matmul_wx,\n        d_matmul_vx);\n  } else {\n    LaunchFusedGluWithoutLinearGradGpu<T, int64_t, act_type, d_act_type, pack_size>(\n        stream, m, packed_n, pack_num, packed_stride, dy, matmul_wx, matmul_vx, d_matmul_wx,\n        d_matmul_vx);\n  }\n}\n\ntemplate<typename T, ep::primitive::UnaryOp act_type, ep::primitive::BinaryOp d_act_type,\n         int32_t alignment, typename std::enable_if<alignment / sizeof(T) == 0, int>::type = 0>\nvoid DispatchPackSize(ep::Stream* stream, const int64_t m, const int64_t n, const int64_t stride,\n                      const T* dy, const T* matmul_wx, const T* matmul_vx, T* d_matmul_wx,\n                      T* d_matmul_vx) {\n  DispatchIndexType<T, act_type, d_act_type, 1>(stream, m, n, m * n, stride, dy, matmul_wx,\n                                                matmul_vx, d_matmul_wx, d_matmul_vx);\n}\n\ntemplate<typename T, ep::primitive::UnaryOp act_type, ep::primitive::BinaryOp d_act_type,\n         int32_t alignment, typename std::enable_if<alignment / sizeof(T) != 0, int>::type = 0>\nvoid DispatchPackSize(ep::Stream* stream, const int64_t m, const int64_t n, const int64_t 
stride,\n                      const T* dy, const T* matmul_wx, const T* matmul_vx, T* d_matmul_wx,\n                      T* d_matmul_vx) {\n  const int64_t pack_size = alignment / sizeof(T);\n  const int64_t packed_n = n / pack_size;\n  const int64_t pack_num = m * packed_n;\n  const int64_t packed_stride = stride / pack_size;\n  DispatchIndexType<T, act_type, d_act_type, alignment / sizeof(T)>(\n      stream, m, packed_n, pack_num, packed_stride, dy, matmul_wx, matmul_vx, d_matmul_wx,\n      d_matmul_vx);\n}\n\ntemplate<typename T, ep::primitive::UnaryOp act_type, ep::primitive::BinaryOp d_act_type>\nvoid DispatchAlignment(ep::Stream* stream, const int64_t m, const int64_t n, const int64_t stride,\n                       const T* dy, const T* matmul_wx, const T* matmul_vx, T* d_matmul_wx,\n                       T* d_matmul_vx) {\n  const auto IsAligned = [&](const size_t alignment) {\n    const uintptr_t dy_ptr = reinterpret_cast<uintptr_t>(dy);\n    const uintptr_t matmul_wx_ptr = reinterpret_cast<uintptr_t>(matmul_wx);\n    const uintptr_t matmul_vx_ptr = reinterpret_cast<uintptr_t>(matmul_vx);\n    const uintptr_t d_matmul_wx_ptr = reinterpret_cast<uintptr_t>(d_matmul_wx);\n    const uintptr_t d_matmul_vx_ptr = reinterpret_cast<uintptr_t>(d_matmul_vx);\n    const int64_t pack_size = alignment / sizeof(T);\n    return pack_size != 0 ? (/* memory address alignment */\n                             dy_ptr % alignment == 0 && matmul_vx_ptr % alignment == 0\n                             && matmul_wx_ptr % alignment == 0 && d_matmul_wx_ptr % alignment == 0\n                             && d_matmul_vx_ptr % alignment == 0\n                             /* #element per row alignment */\n                             && n % (pack_size) == 0)\n                          : false;\n  };\n\n  // dispatch alignment\n  if (IsAligned(16)) {\n    DispatchPackSize<T, act_type, d_act_type, 16>(stream, m, n, stride, dy, matmul_wx, matmul_vx,\n                                                  d_matmul_wx, d_matmul_vx);\n  } else if (IsAligned(8)) {\n    DispatchPackSize<T, act_type, d_act_type, 8>(stream, m, n, stride, dy, matmul_wx, matmul_vx,\n                                                 d_matmul_wx, d_matmul_vx);\n  } else if (IsAligned(4)) {\n    DispatchPackSize<T, act_type, d_act_type, 4>(stream, m, n, stride, dy, matmul_wx, matmul_vx,\n                                                 d_matmul_wx, d_matmul_vx);\n  } else if (IsAligned(2)) {\n    DispatchPackSize<T, act_type, d_act_type, 2>(stream, m, n, stride, dy, matmul_wx, matmul_vx,\n                                                 d_matmul_wx, d_matmul_vx);\n  } else {\n    DispatchPackSize<T, act_type, d_act_type, 1>(stream, m, n, stride, dy, matmul_wx, matmul_vx,\n                                                 d_matmul_wx, d_matmul_vx);\n  }\n}\n\ntemplate<typename T>\nvoid DispatchActivationType(ep::Stream* stream, const int64_t m, const int64_t n,\n                            const std::string& activation, const int64_t stride, const T* dy,\n                            const T* matmul_wx, const T* matmul_vx, T* d_matmul_wx,\n                            T* d_matmul_vx) {\n  if (activation == \"none\") {\n    DispatchAlignment<T, ep::primitive::UnaryOp::kIdentity,\n                      ep::primitive::BinaryOp::kIdentityBackwardWithDyX>(\n        stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx);\n  } else if (activation == \"sigmoid\") {\n    DispatchAlignment<T, ep::primitive::UnaryOp::kSigmoid,\n                 
     ep::primitive::BinaryOp::kSigmoidBackwardWithDyX>(\n        stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx);\n  } else if (activation == \"relu\") {\n    DispatchAlignment<T, ep::primitive::UnaryOp::kRelu,\n                      ep::primitive::BinaryOp::kReluBackwardWithDyX>(\n        stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx);\n  } else if (activation == \"gelu\") {\n    DispatchAlignment<T, ep::primitive::UnaryOp::kGelu,\n                      ep::primitive::BinaryOp::kGeluBackwardWithDyX>(\n        stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx);\n  } else if (activation == \"fast_gelu\") {\n    DispatchAlignment<T, ep::primitive::UnaryOp::kFastGelu,\n                      ep::primitive::BinaryOp::kFastGeluBackwardWithDyX>(\n        stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx);\n  } else if (activation == \"silu\") {\n    DispatchAlignment<T, ep::primitive::UnaryOp::kSilu,\n                      ep::primitive::BinaryOp::kSiluBackwardWithDyX>(\n        stream, m, n, stride, dy, matmul_wx, matmul_vx, d_matmul_wx, d_matmul_vx);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T>\nclass GpuFusedGluWithoutLinearGradKernel final : public user_op::OpKernel {\n public:\n  GpuFusedGluWithoutLinearGradKernel() = default;\n  ~GpuFusedGluWithoutLinearGradKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    // obtain tensors from context\n    const user_op::Tensor* input_tensor_dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* input_tensor_matmul_wx = ctx->Tensor4ArgNameAndIndex(\"matmul_wx\", 0);\n    user_op::Tensor* out_tensor_d_matmul_wx = ctx->Tensor4ArgNameAndIndex(\"d_matmul_wx\", 0);\n\n    // obtain optional tensors from context\n    bool is_split_mode = false;\n    user_op::Tensor* input_tensor_matmul_vx = nullptr;\n    user_op::Tensor* out_tensor_d_matmul_vx = nullptr;\n    if (ctx->has_input(\"matmul_vx\", 0)) {\n      input_tensor_matmul_vx = ctx->Tensor4ArgNameAndIndex(\"matmul_vx\", 0);\n      out_tensor_d_matmul_vx = ctx->Tensor4ArgNameAndIndex(\"d_matmul_vx\", 0);\n      is_split_mode = true;\n    }\n\n    // obtain tensor shapes and number of axes\n    const ShapeView& dy_shape = input_tensor_dy->shape_view();\n    const ShapeView& matmul_wx_shape = input_tensor_matmul_wx->shape_view();\n    const ShapeView& d_matmul_wx_shape = out_tensor_d_matmul_wx->shape_view();\n    const size_t dy_num_axes = dy_shape.NumAxes();\n    const size_t matmul_wx_num_axes = matmul_wx_shape.NumAxes();\n\n    // validate dimension and number of axes\n    CHECK_GE(dy_num_axes, 2) << \"number of axes of \\'dy\\' should be greater than 1, but got \"\n                             << dy_num_axes;\n    CHECK_GE(matmul_wx_num_axes, 2)\n        << \"number of axes of \\'matmul_wx\\' should be greater than 1, but got \"\n        << matmul_wx_num_axes;\n    CHECK_EQ(dy_num_axes, matmul_wx_num_axes)\n        << \"number of axes of \\'dy\\'(\" << dy_num_axes\n        << \") is not consistent with that of \\'matmul_wx\\'(\" << matmul_wx_num_axes << \")\";\n\n    // check input shape\n    if (is_split_mode) {\n      CHECK_EQ(dy_shape.At(dy_num_axes - 1), matmul_wx_shape.At(matmul_wx_num_axes - 1))\n          << \"the last dimension of \\'dy\\'(\" << dy_shape.At(dy_num_axes - 1)\n          << \") is not consistent with the last dimension of 
\\'matmul_wx\\'(\"\n          << matmul_wx_shape.At(matmul_wx_num_axes - 1) << \")\";\n    } else {\n      CHECK_EQ(2 * dy_shape.At(dy_num_axes - 1), matmul_wx_shape.At(matmul_wx_num_axes - 1))\n          << \"two times the last dimension of \\'dy\\'(\" << 2 * dy_shape.At(dy_num_axes - 1)\n          << \") is not consistent with the last dimension of \\'matmul_wx\\'(\"\n          << matmul_wx_shape.At(matmul_wx_num_axes - 1) << \")\";\n    }\n\n    // check optional input tensor shapes\n    if (is_split_mode) {\n      const user_op::Tensor* input_tensor_matmul_vx = ctx->Tensor4ArgNameAndIndex(\"matmul_vx\", 0);\n      const ShapeView& matmul_vx_shape = input_tensor_matmul_vx->shape_view();\n      const size_t matmul_vx_num_axes = matmul_vx_shape.NumAxes();\n      CHECK_GE(matmul_vx_num_axes, 2)\n          << \"number of axes of \\'matmul_vx\\' should be greater than 1, but got \"\n          << matmul_vx_num_axes;\n      CHECK_EQ(matmul_vx_num_axes, dy_num_axes)\n          << \"number of axes of \\'dy\\'(\" << dy_num_axes\n          << \") is not consistent with that of \\'matmul_vx\\'(\" << matmul_vx_num_axes << \")\";\n      CHECK_EQ(matmul_vx_shape.At(matmul_vx_num_axes - 1), dy_shape.At(dy_num_axes - 1))\n          << \"the last dimension of \\'dy\\'(\" << dy_shape.At(dy_num_axes - 1)\n          << \") is not consistent with the last dimension of \\'matmul_vx\\'(\"\n          << matmul_vx_shape.At(matmul_vx_num_axes - 1) << \")\";\n    }\n\n    // infer m, n\n    const int64_t m = dy_shape.Count(0, dy_num_axes - 1);\n    const int64_t n = dy_shape.At(dy_num_axes - 1);\n\n    // start dispatch process\n    DispatchActivationType<T>(\n        ctx->stream(),\n        /*m, n=*/m, n,\n        /*activation=*/ctx->Attr<std::string>(\"activation\"),\n        /*stride=*/is_split_mode ? n : n * 2,\n        /*dy=*/input_tensor_dy->dptr<T>(),\n        /*matmul_wx=*/input_tensor_matmul_wx->dptr<T>(),\n        /*matmul_vx=*/\n        is_split_mode ? input_tensor_matmul_vx->dptr<T>() : input_tensor_matmul_wx->dptr<T>() + n,\n        /*d_matmul_wx=*/out_tensor_d_matmul_wx->mut_dptr<T>(),\n        /*d_matmul_vx=*/\n        is_split_mode ? out_tensor_d_matmul_vx->mut_dptr<T>()\n                      : out_tensor_d_matmul_wx->mut_dptr<T>() + n);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n#define REGISTER_GPU_FUSED_GLU_WITHOUT_LINEAR_GRAD_KERNEL(dtype)       \\\n  REGISTER_USER_KERNEL(\"fused_glu_without_linear_grad\")                \\\n      .SetCreateFn<GpuFusedGluWithoutLinearGradKernel<dtype>>()        \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"d_matmul_wx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_GPU_FUSED_GLU_WITHOUT_LINEAR_GRAD_KERNEL(double)\nREGISTER_GPU_FUSED_GLU_WITHOUT_LINEAR_GRAD_KERNEL(float)\nREGISTER_GPU_FUSED_GLU_WITHOUT_LINEAR_GRAD_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_GPU_FUSED_GLU_WITHOUT_LINEAR_GRAD_KERNEL(nv_bfloat16)\n#endif\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_gru_cell_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <limits>\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/ep/cuda/cuda_device.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include \"oneflow/user/kernels/fused_rnn_cell_kernel_util.h\"\n\n// NOTE(Liang Depeng): The implementation of fused_gru_cell is modified from\n//                     https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/RNN.cu\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstruct AccumulateType {};\ntemplate<>\nstruct AccumulateType<float> {\n  using type = float;\n};\ntemplate<>\nstruct AccumulateType<double> {\n  using type = double;\n};\n\ntemplate<typename T>\nusing acc_type = typename AccumulateType<T>::type;\n\n#define H2F(input) static_cast<ACC_T>(input)\n#define F2H(input) static_cast<T>(input)\n\ntemplate<typename T>\n__device__ __forceinline__ T sigmoid(T in) {\n  T one = static_cast<T>(1.0);\n  return one / (one + ::exp(-in));\n}\n\ntemplate<typename T, typename ACC_T, typename IDX_TYPE>\n#if __CUDA_ARCH__ >= 350\nOF_LAUNCH_BOUNDS_2(512, 4)\n#endif\n__global__ void gru_cell_forward(const IDX_TYPE numel, const IDX_TYPE hidden_size,\n                                 const T* input_gates_ptr, const T* hidden_gates_ptr,\n                                 const T* hx_ptr, const T* input_bias_ptr, const T* hidden_bias_ptr,\n                                 T* hy_ptr, T* workspace_ptr) {\n  bool has_bias = input_bias_ptr != nullptr;\n  for (IDX_TYPE linearIndex = blockIdx.x * blockDim.x + threadIdx.x; linearIndex < numel;\n       linearIndex += gridDim.x * blockDim.x) {\n    IDX_TYPE offset = (linearIndex / hidden_size) * 3 * hidden_size + linearIndex % hidden_size;\n\n    T ir = input_gates_ptr[offset + 0 * hidden_size];\n    T ii = input_gates_ptr[offset + 1 * hidden_size];\n    T in = input_gates_ptr[offset + 2 * hidden_size];\n    T hr = hidden_gates_ptr[offset + 0 * hidden_size];\n    T hi = hidden_gates_ptr[offset + 1 * hidden_size];\n    T hn = hidden_gates_ptr[offset + 2 * hidden_size];\n\n    T hx = hx_ptr[linearIndex];\n    T* hy = &(hy_ptr[linearIndex]);\n\n    T b1r, b1i, b1n, b2r, b2i, b2n;\n\n    if (has_bias) {\n      b1r = input_bias_ptr[linearIndex % hidden_size + 0 * hidden_size];\n      b1i = input_bias_ptr[linearIndex % hidden_size + 1 * hidden_size];\n      b1n = input_bias_ptr[linearIndex % hidden_size + 2 * hidden_size];\n\n      b2r = hidden_bias_ptr[linearIndex % hidden_size + 0 * hidden_size];\n      b2i = hidden_bias_ptr[linearIndex % hidden_size + 1 * hidden_size];\n      b2n = 
hidden_bias_ptr[linearIndex % hidden_size + 2 * hidden_size];\n    } else {\n      b1r = F2H(0.0);\n      b1i = F2H(0.0);\n      b1n = F2H(0.0);\n      b2r = F2H(0.0);\n      b2i = F2H(0.0);\n      b2n = F2H(0.0);\n    }\n\n    offset = (linearIndex / hidden_size) * 5 * hidden_size + linearIndex % hidden_size;\n    ACC_T rg, ig, ng;\n    rg = sigmoid(H2F(ir) + H2F(hr) + H2F(b1r) + H2F(b2r));\n    ig = sigmoid(H2F(ii) + H2F(hi) + H2F(b1i) + H2F(b2i));\n\n    ng = H2F(in) + H2F(b1n) + rg * (H2F(hn) + H2F(b2n));\n    ng = ::tanh(ng);\n    *hy = F2H(ng + ig * (H2F(hx) - ng));\n\n    // SAVE FOR BACKWARDS\n    workspace_ptr[offset + 0 * hidden_size] = F2H(rg);\n    workspace_ptr[offset + 1 * hidden_size] = F2H(ig);\n    workspace_ptr[offset + 2 * hidden_size] = F2H(ng);\n    workspace_ptr[offset + 3 * hidden_size] = hx;\n    workspace_ptr[offset + 4 * hidden_size] = F2H(H2F(hn) + H2F(b2n));\n  }\n}\n\ntemplate<typename T, typename ACC_T, typename IDX_TYPE>\n#if __CUDA_ARCH__ >= 350\nOF_LAUNCH_BOUNDS_2(512, 4)\n#endif\n__global__\n    void gru_cell_backward(const IDX_TYPE numel, const IDX_TYPE hidden_size, const T* grad_hy_ptr,\n                           const T* workspace_ptr, T* grad_input_gates_ptr,\n                           T* grad_hidden_gates_ptr, T* grad_hx_ptr) {\n  for (IDX_TYPE linearIndex = blockIdx.x * blockDim.x + threadIdx.x; linearIndex < numel;\n       linearIndex += gridDim.x * blockDim.x) {\n    IDX_TYPE offset = (linearIndex / hidden_size) * 5 * hidden_size + linearIndex % hidden_size;\n\n    T rg = workspace_ptr[offset + 0 * hidden_size];\n    T ig = workspace_ptr[offset + 1 * hidden_size];\n    T ng = workspace_ptr[offset + 2 * hidden_size];\n    T hx = workspace_ptr[offset + 3 * hidden_size];\n    T hn = workspace_ptr[offset + 4 * hidden_size];\n\n    T go = grad_hy_ptr[linearIndex];\n\n    offset = (linearIndex / hidden_size) * 3 * hidden_size + linearIndex % hidden_size;\n\n    ACC_T gig = H2F(go) * (H2F(hx) - H2F(ng)) * (1 - H2F(ig)) * H2F(ig);\n    ACC_T ghx = H2F(go) * H2F(ig);\n    ACC_T gin = H2F(go) * (1 - H2F(ig)) * (1 - H2F(ng) * H2F(ng));\n    ACC_T ghn = gin * H2F(rg);\n    ACC_T grg = gin * H2F(hn) * (1 - H2F(rg)) * H2F(rg);\n\n    grad_input_gates_ptr[offset + 0 * hidden_size] = F2H(grg);\n    grad_input_gates_ptr[offset + 1 * hidden_size] = F2H(gig);\n    grad_input_gates_ptr[offset + 2 * hidden_size] = F2H(gin);\n\n    grad_hidden_gates_ptr[offset + 0 * hidden_size] = F2H(grg);\n    grad_hidden_gates_ptr[offset + 1 * hidden_size] = F2H(gig);\n    grad_hidden_gates_ptr[offset + 2 * hidden_size] = F2H(ghn);\n    if (grad_hx_ptr != nullptr) { grad_hx_ptr[linearIndex] = F2H(ghx); }\n  }\n}\n\ntemplate<typename T>\nstruct FusedGruCellGradFunctor final {\n  void operator()(ep::Stream* stream, const int64_t hx_numel, const int64_t workspace_numel,\n                  const int64_t hidden_size, const T* grad_hy_ptr, const T* workspace_ptr,\n                  T* grad_input_gates_ptr, T* grad_hidden_gates_ptr, T* grad_hx_ptr) {\n    using ACC_T = acc_type<T>;\n    if (workspace_numel < std::numeric_limits<int32_t>::max()) {\n      RUN_CUDA_KERNEL((gru_cell_backward<T, ACC_T, int32_t>), stream, hx_numel,\n                      static_cast<int32_t>(hx_numel), static_cast<int32_t>(hidden_size),\n                      grad_hy_ptr, workspace_ptr, grad_input_gates_ptr, grad_hidden_gates_ptr,\n                      grad_hx_ptr);\n    } else {\n      RUN_CUDA_KERNEL((gru_cell_backward<T, ACC_T, int64_t>), stream, hx_numel, hx_numel,\n                      hidden_size, 
grad_hy_ptr, workspace_ptr, grad_input_gates_ptr,\n                      grad_hidden_gates_ptr, grad_hx_ptr);\n    }\n  }\n};\n\ntemplate<>\nvoid FusedGruCellGradFunctor<float16>::operator()(\n    ep::Stream* stream, const int64_t hx_numel, const int64_t workspace_numel,\n    const int64_t hidden_size, const float16* grad_hy_ptr, const float16* workspace_ptr,\n    float16* grad_input_gates_ptr, float16* grad_hidden_gates_ptr, float16* grad_hx_ptr) {\n  if (workspace_numel < std::numeric_limits<int32_t>::max()) {\n    RUN_CUDA_KERNEL(\n        (gru_cell_backward<half, float, int32_t>), stream, hx_numel, static_cast<int32_t>(hx_numel),\n        static_cast<int32_t>(hidden_size), reinterpret_cast<const half*>(grad_hy_ptr),\n        reinterpret_cast<const half*>(workspace_ptr), reinterpret_cast<half*>(grad_input_gates_ptr),\n        reinterpret_cast<half*>(grad_hidden_gates_ptr), reinterpret_cast<half*>(grad_hx_ptr));\n  } else {\n    RUN_CUDA_KERNEL(\n        (gru_cell_backward<half, float, int64_t>), stream, hx_numel, hx_numel, hidden_size,\n        reinterpret_cast<const half*>(grad_hy_ptr), reinterpret_cast<const half*>(workspace_ptr),\n        reinterpret_cast<half*>(grad_input_gates_ptr),\n        reinterpret_cast<half*>(grad_hidden_gates_ptr), reinterpret_cast<half*>(grad_hx_ptr));\n  }\n}\n\ntemplate<typename T>\nstruct FusedGruCellFunctor final {\n  void operator()(ep::Stream* stream, const int64_t hx_numel, const int64_t workspace_numel,\n                  const int64_t hidden_size, const T* input_gates_ptr, const T* hidden_gates_ptr,\n                  const T* hx_ptr, const T* input_bias_ptr, const T* hidden_bias_ptr, T* hy_ptr,\n                  T* workspace_ptr) {\n    using ACC_T = acc_type<T>;\n    if (workspace_numel < std::numeric_limits<int32_t>::max()) {\n      RUN_CUDA_KERNEL((gru_cell_forward<T, ACC_T, int32_t>), stream, hx_numel,\n                      static_cast<int32_t>(hx_numel), static_cast<int32_t>(hidden_size),\n                      input_gates_ptr, hidden_gates_ptr, hx_ptr, input_bias_ptr, hidden_bias_ptr,\n                      hy_ptr, workspace_ptr);\n    } else {\n      RUN_CUDA_KERNEL((gru_cell_forward<T, ACC_T, int64_t>), stream, hx_numel, hx_numel,\n                      hidden_size, input_gates_ptr, hidden_gates_ptr, hx_ptr, input_bias_ptr,\n                      hidden_bias_ptr, hy_ptr, workspace_ptr);\n    }\n  }\n};\n\ntemplate<>\nvoid FusedGruCellFunctor<float16>::operator()(\n    ep::Stream* stream, const int64_t hx_numel, const int64_t workspace_numel,\n    const int64_t hidden_size, const float16* input_gates_ptr, const float16* hidden_gates_ptr,\n    const float16* hx_ptr, const float16* input_bias_ptr, const float16* hidden_bias_ptr,\n    float16* hy_ptr, float16* workspace_ptr) {\n  if (workspace_numel < std::numeric_limits<int32_t>::max()) {\n    RUN_CUDA_KERNEL(\n        (gru_cell_forward<half, float, int32_t>), stream, hx_numel, static_cast<int32_t>(hx_numel),\n        static_cast<int32_t>(hidden_size), reinterpret_cast<const half*>(input_gates_ptr),\n        reinterpret_cast<const half*>(hidden_gates_ptr), reinterpret_cast<const half*>(hx_ptr),\n        reinterpret_cast<const half*>(input_bias_ptr),\n        reinterpret_cast<const half*>(hidden_bias_ptr), reinterpret_cast<half*>(hy_ptr),\n        reinterpret_cast<half*>(workspace_ptr));\n  } else {\n    RUN_CUDA_KERNEL((gru_cell_forward<half, float, int64_t>), stream, hx_numel, hx_numel,\n                    hidden_size, reinterpret_cast<const half*>(input_gates_ptr),\n                    
reinterpret_cast<const half*>(hidden_gates_ptr),\n                    reinterpret_cast<const half*>(hx_ptr),\n                    reinterpret_cast<const half*>(input_bias_ptr),\n                    reinterpret_cast<const half*>(hidden_bias_ptr), reinterpret_cast<half*>(hy_ptr),\n                    reinterpret_cast<half*>(workspace_ptr));\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuFusedGruCellKernel final : public user_op::OpKernel {\n public:\n  GpuFusedGruCellKernel() = default;\n  ~GpuFusedGruCellKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input_gates = ctx->Tensor4ArgNameAndIndex(\"input_gates\", 0);\n    const user_op::Tensor* hidden_gates = ctx->Tensor4ArgNameAndIndex(\"hidden_gates\", 0);\n    const user_op::Tensor* hx = ctx->Tensor4ArgNameAndIndex(\"hx\", 0);\n    user_op::Tensor* hy = ctx->Tensor4ArgNameAndIndex(\"hy\", 0);\n    user_op::Tensor* workspace = ctx->Tensor4ArgNameAndIndex(\"workspace\", 0);\n\n    const T* input_bias_ptr = nullptr;\n    const T* hidden_bias_ptr = nullptr;\n    if (ctx->has_input(\"input_bias\", 0)) {\n      CHECK(ctx->has_input(\"hidden_bias\", 0));\n      input_bias_ptr = ctx->Tensor4ArgNameAndIndex(\"input_bias\", 0)->dptr<T>();\n      hidden_bias_ptr = ctx->Tensor4ArgNameAndIndex(\"hidden_bias\", 0)->dptr<T>();\n    }\n    const T* input_gates_ptr = input_gates->dptr<T>();\n    const T* hidden_gates_ptr = hidden_gates->dptr<T>();\n    const T* hx_ptr = hx->dptr<T>();\n\n    T* hy_ptr = hy->mut_dptr<T>();\n    T* workspace_ptr = workspace->mut_dptr<T>();\n    const int64_t hx_numel = hx->shape_view().elem_cnt();\n    const int64_t workspace_numel = workspace->shape_view().elem_cnt();\n    const int64_t hidden_size = hx->shape_view().At(hx->shape_view().NumAxes() - 1);\n    FusedGruCellFunctor<T>()(ctx->stream(), hx_numel, workspace_numel, hidden_size, input_gates_ptr,\n                             hidden_gates_ptr, hx_ptr, input_bias_ptr, hidden_bias_ptr, hy_ptr,\n                             workspace_ptr);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_GRU_CELL_KERNEL(dtype)                                                   \\\n  REGISTER_USER_KERNEL(\"fused_gru_cell\")                                                        \\\n      .SetCreateFn<GpuFusedGruCellKernel<dtype>>()                                              \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                          \\\n                       && (user_op::HobDataType(\"hx\", 0) == GetDataType<dtype>::value)          \\\n                       && (user_op::HobDataType(\"input_gates\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"hidden_gates\", 0) == GetDataType<dtype>::value))\n\nREGISTER_FUSED_GRU_CELL_KERNEL(float);\nREGISTER_FUSED_GRU_CELL_KERNEL(float16);\n\nclass GpuFusedGruCellGradFloatKernel final : public user_op::OpKernel {\n public:\n  GpuFusedGruCellGradFloatKernel() = default;\n  ~GpuFusedGruCellGradFloatKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* grad_hy = ctx->Tensor4ArgNameAndIndex(\"grad_hy\", 0);\n    const user_op::Tensor* workspace = ctx->Tensor4ArgNameAndIndex(\"workspace\", 0);\n    user_op::Tensor* grad_input_gates = 
ctx->Tensor4ArgNameAndIndex(\"grad_input_gates\", 0);\n    user_op::Tensor* grad_hidden_gates = ctx->Tensor4ArgNameAndIndex(\"grad_hidden_gates\", 0);\n\n    const float* grad_hy_ptr = grad_hy->dptr<float>();\n    const float* workspace_ptr = workspace->dptr<float>();\n\n    float* grad_input_gates_ptr = grad_input_gates->mut_dptr<float>();\n    float* grad_hidden_gates_ptr = grad_hidden_gates->mut_dptr<float>();\n\n    float* grad_hx_ptr = nullptr;\n    if (ctx->has_output(\"grad_hx\", 0)) {\n      user_op::Tensor* grad_hx = ctx->Tensor4ArgNameAndIndex(\"grad_hx\", 0);\n      grad_hx_ptr = grad_hx->mut_dptr<float>();\n    }\n\n    const int64_t hx_numel = grad_hy->shape_view().elem_cnt();\n    const int64_t workspace_numel = workspace->shape_view().elem_cnt();\n    const int64_t hidden_size = grad_hy->shape_view().At(grad_hy->shape_view().NumAxes() - 1);\n    FusedGruCellGradFunctor<float>()(ctx->stream(), hx_numel, workspace_numel, hidden_size,\n                                     grad_hy_ptr, workspace_ptr, grad_input_gates_ptr,\n                                     grad_hidden_gates_ptr, grad_hx_ptr);\n\n    if (ctx->has_output(\"grad_input_bias\", 0) && ctx->has_output(\"grad_hidden_bias\", 0)) {\n      float* grad_input_bias_ptr =\n          ctx->Tensor4ArgNameAndIndex(\"grad_input_bias\", 0)->mut_dptr<float>();\n      std::vector<int32_t> axis;\n      axis.push_back(0);\n      const Shape& reduced_shape =\n          CreateReducedShape(grad_input_gates->shape_view(), {axis.begin(), axis.end()});\n      user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n      NdarrayReduce<DeviceType::kCUDA, float, BinaryFuncSum>::Reduce(\n          ctx->stream(), XpuVarNdarray<float>(reduced_shape, grad_input_bias_ptr),\n          XpuVarNdarray<const float>(grad_input_gates->shape_view(),\n                                     grad_input_gates->dptr<float>()),\n          XpuVarNdarray<float>(tmp_buffer->shape_view(), tmp_buffer->mut_dptr<float>()));\n\n      float* grad_hidden_bias_ptr =\n          ctx->Tensor4ArgNameAndIndex(\"grad_hidden_bias\", 0)->mut_dptr<float>();\n      NdarrayReduce<DeviceType::kCUDA, float, BinaryFuncSum>::Reduce(\n          ctx->stream(), XpuVarNdarray<float>(reduced_shape, grad_hidden_bias_ptr),\n          XpuVarNdarray<const float>(grad_hidden_gates->shape_view(),\n                                     grad_hidden_gates->dptr<float>()),\n          XpuVarNdarray<float>(tmp_buffer->shape_view(), tmp_buffer->mut_dptr<float>()));\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"fused_gru_cell_grad\")\n    .SetCreateFn<GpuFusedGruCellGradFloatKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)\n                     && (user_op::HobDataType(\"grad_hy\", 0) == GetDataType<float>::value)\n                     && (user_op::HobDataType(\"workspace\", 0) == GetDataType<float>::value))\n    .SetInferTmpSizeFn([](user_op::InferContext* ctx) {\n      size_t tmp_bytes = 0;\n      if (ctx->has_output(\"grad_input_bias\", 0) && ctx->has_output(\"grad_hidden_bias\", 0)) {\n        const Shape& in_shape = ctx->InputTensorDesc(\"grad_hy\", 0).shape();\n        tmp_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * 3 * sizeof(float));\n      } else {\n        tmp_bytes = 0;\n      }\n      return tmp_bytes;\n    });\n\nclass GpuFusedGruCellGradHalfKernel final : public user_op::OpKernel {\n public:\n  GpuFusedGruCellGradHalfKernel() = default;\n  
~GpuFusedGruCellGradHalfKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* grad_hy = ctx->Tensor4ArgNameAndIndex(\"grad_hy\", 0);\n    const user_op::Tensor* workspace = ctx->Tensor4ArgNameAndIndex(\"workspace\", 0);\n    user_op::Tensor* grad_input_gates = ctx->Tensor4ArgNameAndIndex(\"grad_input_gates\", 0);\n    user_op::Tensor* grad_hidden_gates = ctx->Tensor4ArgNameAndIndex(\"grad_hidden_gates\", 0);\n\n    const float16* grad_hy_ptr = grad_hy->dptr<float16>();\n    const float16* workspace_ptr = workspace->dptr<float16>();\n\n    float16* grad_input_gates_ptr = grad_input_gates->mut_dptr<float16>();\n    float16* grad_hidden_gates_ptr = grad_hidden_gates->mut_dptr<float16>();\n\n    float16* grad_hx_ptr = nullptr;\n    if (ctx->has_output(\"grad_hx\", 0)) {\n      user_op::Tensor* grad_hx = ctx->Tensor4ArgNameAndIndex(\"grad_hx\", 0);\n      grad_hx_ptr = grad_hx->mut_dptr<float16>();\n    }\n\n    const int64_t hx_numel = grad_hy->shape_view().elem_cnt();\n    const int64_t workspace_numel = workspace->shape_view().elem_cnt();\n    const int64_t hidden_size = grad_hy->shape_view().At(grad_hy->shape_view().NumAxes() - 1);\n    FusedGruCellGradFunctor<float16>()(ctx->stream(), hx_numel, workspace_numel, hidden_size,\n                                       grad_hy_ptr, workspace_ptr, grad_input_gates_ptr,\n                                       grad_hidden_gates_ptr, grad_hx_ptr);\n\n    if (ctx->has_output(\"grad_input_bias\", 0) && ctx->has_output(\"grad_hidden_bias\", 0)) {\n      std::vector<int32_t> axis;\n      axis.push_back(0);\n      user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n      const ShapeView& in_shape = grad_input_gates->shape_view();\n      const Shape& reduced_shape = CreateReducedShape(in_shape, {axis.begin(), axis.end()});\n      float* in_tmp_buffer = tmp_buffer->mut_dptr<float>();\n      const size_t in_tmp_buffer_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float));\n      float* out_tmp_buffer =\n          reinterpret_cast<float*>(tmp_buffer->mut_dptr<char>() + in_tmp_buffer_bytes);\n      const size_t out_tmp_buffer_bytes =\n          GetCudaAlignedSize(reduced_shape.elem_cnt() * sizeof(float));\n      float* reduce_tmp_buffer = reinterpret_cast<float*>(\n          tmp_buffer->mut_dptr<char>() + in_tmp_buffer_bytes + out_tmp_buffer_bytes);\n      const size_t reduce_tmp_buffer_bytes =\n          GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float));\n      CHECK_LE(in_tmp_buffer_bytes + out_tmp_buffer_bytes + reduce_tmp_buffer_bytes,\n               tmp_buffer->shape_view().elem_cnt());\n      auto h2f = ep::primitive::NewPrimitive<ep::primitive::CastFactory>(\n          ctx->device_type(), DataType::kFloat16, DataType::kFloat);\n      CHECK(h2f);\n      auto f2h = ep::primitive::NewPrimitive<ep::primitive::CastFactory>(\n          ctx->device_type(), DataType::kFloat, DataType::kFloat16);\n      CHECK(f2h);\n      h2f->Launch(ctx->stream(), grad_input_gates->dptr<float16>(), in_tmp_buffer,\n                  in_shape.elem_cnt());\n\n      NdarrayReduce<DeviceType::kCUDA, float, BinaryFuncSum>::Reduce(\n          ctx->stream(), XpuVarNdarray<float>(reduced_shape, out_tmp_buffer),\n          XpuVarNdarray<const float>(in_shape, in_tmp_buffer),\n          XpuVarNdarray<float>(in_shape, reduce_tmp_buffer));\n\n      user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex(\"grad_input_bias\", 0);\n  
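    // cast the reduced float32 sums back to float16 for \"grad_input_bias\"\n  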
    f2h->Launch(ctx->stream(), out_tmp_buffer, output_tensor->mut_dptr<float16>(),\n                  output_tensor->shape_view().elem_cnt());\n\n      h2f->Launch(ctx->stream(), grad_hidden_gates->dptr<float16>(), in_tmp_buffer,\n                  in_shape.elem_cnt());\n      NdarrayReduce<DeviceType::kCUDA, float, BinaryFuncSum>::Reduce(\n          ctx->stream(), XpuVarNdarray<float>(reduced_shape, out_tmp_buffer),\n          XpuVarNdarray<const float>(in_shape, in_tmp_buffer),\n          XpuVarNdarray<float>(in_shape, reduce_tmp_buffer));\n\n      output_tensor = ctx->Tensor4ArgNameAndIndex(\"grad_hidden_bias\", 0);\n      f2h->Launch(ctx->stream(), out_tmp_buffer, output_tensor->mut_dptr<float16>(),\n                  output_tensor->shape_view().elem_cnt());\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"fused_gru_cell_grad\")\n    .SetCreateFn<GpuFusedGruCellGradHalfKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)\n                     && (user_op::HobDataType(\"grad_hy\", 0) == GetDataType<float16>::value)\n                     && (user_op::HobDataType(\"workspace\", 0) == GetDataType<float16>::value))\n    .SetInferTmpSizeFn([](user_op::InferContext* ctx) {\n      size_t tmp_bytes = 0;\n      if (ctx->has_output(\"grad_input_bias\", 0) && ctx->has_output(\"grad_hidden_bias\", 0)) {\n        const Shape& in_shape = ctx->InputTensorDesc(\"grad_hy\", 0).shape();\n        const Shape& out_shape = ctx->OutputTensorDesc(\"grad_input_bias\", 0).shape();\n        tmp_bytes = (2 * GetCudaAlignedSize(in_shape.elem_cnt() * 3 * sizeof(float))\n                     + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(float)));\n      } else {\n        tmp_bytes = 0;\n      }\n      return tmp_bytes;\n    });\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_lstm_cell_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <limits>\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/ep/cuda/cuda_device.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include \"oneflow/user/kernels/fused_rnn_cell_kernel_util.h\"\n\n// NOTE(Liang Depeng): The implementation of fused_lstm_cell is modified from\n//                     https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/RNN.cu\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstruct AccumulateType {};\ntemplate<>\nstruct AccumulateType<float> {\n  using type = float;\n};\ntemplate<>\nstruct AccumulateType<double> {\n  using type = double;\n};\n\ntemplate<typename T>\nusing acc_type = typename AccumulateType<T>::type;\n\n#define H2F(input) static_cast<ACC_T>(input)\n#define F2H(input) static_cast<T>(input)\n\ntemplate<typename T>\n__device__ __forceinline__ T sigmoid(T in) {\n  T one = static_cast<T>(1.0);\n  return one / (one + ::exp(-in));\n}\n\ntemplate<typename T, typename ACC_T, typename IDX_TYPE>\n#if __CUDA_ARCH__ >= 350\nOF_LAUNCH_BOUNDS_2(512, 4)\n#endif\n__global__\n    void lstm_cell_forward(const IDX_TYPE numel, const IDX_TYPE hidden_size,\n                           const T* input_gates_ptr, const T* hidden_gates_ptr, const T* cx_ptr,\n                           const T* input_bias_ptr, const T* hidden_bias_ptr, T* hy_ptr, T* cy_ptr,\n                           T* workspace_ptr) {\n  bool has_bias = input_bias_ptr != nullptr;\n  for (IDX_TYPE linearIndex = blockIdx.x * blockDim.x + threadIdx.x; linearIndex < numel;\n       linearIndex += gridDim.x * blockDim.x) {\n    IDX_TYPE offset = (linearIndex / hidden_size) * 4 * hidden_size + linearIndex % hidden_size;\n\n    T iig = input_gates_ptr[offset + 0 * hidden_size];\n    T ifg = input_gates_ptr[offset + 1 * hidden_size];\n    T icg = input_gates_ptr[offset + 2 * hidden_size];\n    T iog = input_gates_ptr[offset + 3 * hidden_size];\n\n    T hig = hidden_gates_ptr[offset + 0 * hidden_size];\n    T hfg = hidden_gates_ptr[offset + 1 * hidden_size];\n    T hcg = hidden_gates_ptr[offset + 2 * hidden_size];\n    T hog = hidden_gates_ptr[offset + 3 * hidden_size];\n\n    T* wig = &(workspace_ptr[offset + 0 * hidden_size]);\n    T* wfg = &(workspace_ptr[offset + 1 * hidden_size]);\n    T* wcg = &(workspace_ptr[offset + 2 * hidden_size]);\n    T* wog = &(workspace_ptr[offset + 3 * hidden_size]);\n\n    T cx = cx_ptr[linearIndex];\n\n    T* hy = &(hy_ptr[linearIndex]);\n    T* cy = &(cy_ptr[linearIndex]);\n\n    T b1i, b1f, b1c, b1o;\n    T b2i, b2f, b2c, b2o;\n\n    
if (has_bias) {\n      b1i = input_bias_ptr[linearIndex % hidden_size + 0 * hidden_size];\n      b1f = input_bias_ptr[linearIndex % hidden_size + 1 * hidden_size];\n      b1c = input_bias_ptr[linearIndex % hidden_size + 2 * hidden_size];\n      b1o = input_bias_ptr[linearIndex % hidden_size + 3 * hidden_size];\n\n      b2i = hidden_bias_ptr[linearIndex % hidden_size + 0 * hidden_size];\n      b2f = hidden_bias_ptr[linearIndex % hidden_size + 1 * hidden_size];\n      b2c = hidden_bias_ptr[linearIndex % hidden_size + 2 * hidden_size];\n      b2o = hidden_bias_ptr[linearIndex % hidden_size + 3 * hidden_size];\n    } else {\n      b1i = F2H(0.0);\n      b1f = F2H(0.0);\n      b1c = F2H(0.0);\n      b1o = F2H(0.0);\n      b2i = F2H(0.0);\n      b2f = F2H(0.0);\n      b2c = F2H(0.0);\n      b2o = F2H(0.0);\n    }\n\n    ACC_T ig, fg, cg, og;\n    ACC_T f_hy, f_cy;\n\n    ig = sigmoid(H2F(iig) + H2F(hig) + H2F(b1i) + H2F(b2i));\n    fg = sigmoid(H2F(ifg) + H2F(hfg) + H2F(b1f) + H2F(b2f));\n    cg = ::tanh(H2F(icg) + H2F(hcg) + H2F(b1c) + H2F(b2c));\n    og = sigmoid(H2F(iog) + H2F(hog) + H2F(b1o) + H2F(b2o));\n\n    f_cy = (fg * H2F(cx)) + (ig * cg);\n    f_hy = og * ::tanh(f_cy);\n\n    *hy = F2H(f_hy);\n    *cy = F2H(f_cy);\n\n    // SAVE FOR BACKWARDS\n    // Also need cy and cx but can be saved easily in python\n    *wig = F2H(ig);\n    *wfg = F2H(fg);\n    *wcg = F2H(cg);\n    *wog = F2H(og);\n  }\n}\n\ntemplate<typename T, typename ACC_T, typename IDX_TYPE>\n#if __CUDA_ARCH__ >= 350\nOF_LAUNCH_BOUNDS_2(512, 4)\n#endif\n__global__\n    void lstm_cell_backward(const IDX_TYPE numel, const IDX_TYPE hidden_size, const T* grad_hy_ptr,\n                            const T* grad_cy_ptr, const T* cx_ptr, const T* cy_ptr,\n                            const T* workspace_ptr, T* grad_gates_ptr, T* grad_cx_ptr) {\n  for (IDX_TYPE linearIndex = blockIdx.x * blockDim.x + threadIdx.x; linearIndex < numel;\n       linearIndex += gridDim.x * blockDim.x) {\n    IDX_TYPE offset = (linearIndex / hidden_size) * 4 * hidden_size + linearIndex % hidden_size;\n\n    T ig = workspace_ptr[offset + 0 * hidden_size];\n    T fg = workspace_ptr[offset + 1 * hidden_size];\n    T cg = workspace_ptr[offset + 2 * hidden_size];\n    T og = workspace_ptr[offset + 3 * hidden_size];\n\n    T* ih = &(grad_gates_ptr[offset + 0 * hidden_size]);\n    T* fh = &(grad_gates_ptr[offset + 1 * hidden_size]);\n    T* ch = &(grad_gates_ptr[offset + 2 * hidden_size]);\n    T* oh = &(grad_gates_ptr[offset + 3 * hidden_size]);\n\n    // will return hidden grads here\n    T cx = cx_ptr[linearIndex];\n    T cy = cy_ptr[linearIndex];\n\n    ACC_T go = H2F(grad_hy_ptr[linearIndex]);\n    ACC_T goc = H2F(grad_cy_ptr[linearIndex]);\n\n    ACC_T gcx = ::tanh(H2F(cy));\n\n    ACC_T gog = go * gcx;\n    gcx = go * H2F(og) * (1 - gcx * gcx) + goc;\n\n    ACC_T gig = gcx * H2F(cg);\n    ACC_T gfg = gcx * H2F(cx);\n    ACC_T gcg = gcx * H2F(ig);\n\n    gig = gig * (1 - H2F(ig)) * H2F(ig);\n    gfg = gfg * (1 - H2F(fg)) * H2F(fg);\n    gcg = gcg * (1 - H2F(cg) * H2F(cg));\n    gog = gog * (1 - H2F(og)) * H2F(og);\n\n    *ih = F2H(gig);\n    *fh = F2H(gfg);\n    *ch = F2H(gcg);\n    *oh = F2H(gog);\n\n    if (grad_cx_ptr != nullptr) {\n      gcx = gcx * H2F(fg);\n      T* gi = &(grad_cx_ptr[linearIndex]);\n      *gi = F2H(gcx);\n    }\n  }\n}\n\ntemplate<typename T>\nstruct FusedLstmCellFunctor final {\n  void operator()(ep::Stream* stream, const int64_t cx_numel, const int64_t workspace_numel,\n                  const int64_t hidden_size, const T* 
input_gates_ptr, const T* hidden_gates_ptr,\n                  const T* cx_ptr, const T* input_bias_ptr, const T* hidden_bias_ptr, T* hy_ptr,\n                  T* cy_ptr, T* workspace_ptr) {\n    using ACC_T = acc_type<T>;\n    if (workspace_numel < std::numeric_limits<int32_t>::max()) {\n      RUN_CUDA_KERNEL((lstm_cell_forward<T, ACC_T, int32_t>), stream, cx_numel,\n                      static_cast<int32_t>(cx_numel), static_cast<int32_t>(hidden_size),\n                      input_gates_ptr, hidden_gates_ptr, cx_ptr, input_bias_ptr, hidden_bias_ptr,\n                      hy_ptr, cy_ptr, workspace_ptr);\n    } else {\n      RUN_CUDA_KERNEL((lstm_cell_forward<T, ACC_T, int64_t>), stream, cx_numel, cx_numel,\n                      hidden_size, input_gates_ptr, hidden_gates_ptr, cx_ptr, input_bias_ptr,\n                      hidden_bias_ptr, hy_ptr, cy_ptr, workspace_ptr);\n    }\n  }\n};\n\ntemplate<>\nvoid FusedLstmCellFunctor<float16>::operator()(\n    ep::Stream* stream, const int64_t cx_numel, const int64_t workspace_numel,\n    const int64_t hidden_size, const float16* input_gates_ptr, const float16* hidden_gates_ptr,\n    const float16* cx_ptr, const float16* input_bias_ptr, const float16* hidden_bias_ptr,\n    float16* hy_ptr, float16* cy_ptr, float16* workspace_ptr) {\n  if (workspace_numel < std::numeric_limits<int32_t>::max()) {\n    RUN_CUDA_KERNEL(\n        (lstm_cell_forward<half, float, int32_t>), stream, cx_numel, static_cast<int32_t>(cx_numel),\n        static_cast<int32_t>(hidden_size), reinterpret_cast<const half*>(input_gates_ptr),\n        reinterpret_cast<const half*>(hidden_gates_ptr), reinterpret_cast<const half*>(cx_ptr),\n        reinterpret_cast<const half*>(input_bias_ptr),\n        reinterpret_cast<const half*>(hidden_bias_ptr), reinterpret_cast<half*>(hy_ptr),\n        reinterpret_cast<half*>(cy_ptr), reinterpret_cast<half*>(workspace_ptr));\n  } else {\n    RUN_CUDA_KERNEL((lstm_cell_forward<half, float, int64_t>), stream, cx_numel, cx_numel,\n                    hidden_size, reinterpret_cast<const half*>(input_gates_ptr),\n                    reinterpret_cast<const half*>(hidden_gates_ptr),\n                    reinterpret_cast<const half*>(cx_ptr),\n                    reinterpret_cast<const half*>(input_bias_ptr),\n                    reinterpret_cast<const half*>(hidden_bias_ptr), reinterpret_cast<half*>(hy_ptr),\n                    reinterpret_cast<half*>(cy_ptr), reinterpret_cast<half*>(workspace_ptr));\n  }\n}\n\ntemplate<typename T>\nstruct FusedLstmCellGradFunctor final {\n  void operator()(ep::Stream* stream, const int64_t cx_numel, const int64_t workspace_numel,\n                  const int64_t hidden_size, const T* grad_hy_ptr, const T* grad_cy_ptr,\n                  const T* cx_ptr, const T* cy_ptr, const T* workspace_ptr, T* grad_gates_ptr,\n                  T* grad_cx_ptr) {\n    using ACC_T = acc_type<T>;\n    if (workspace_numel < std::numeric_limits<int32_t>::max()) {\n      RUN_CUDA_KERNEL((lstm_cell_backward<T, ACC_T, int32_t>), stream, cx_numel,\n                      static_cast<int32_t>(cx_numel), static_cast<int32_t>(hidden_size),\n                      grad_hy_ptr, grad_cy_ptr, cx_ptr, cy_ptr, workspace_ptr, grad_gates_ptr,\n                      grad_cx_ptr);\n    } else {\n      RUN_CUDA_KERNEL((lstm_cell_backward<T, ACC_T, int64_t>), stream, cx_numel, cx_numel,\n                      hidden_size, grad_hy_ptr, grad_cy_ptr, cx_ptr, cy_ptr, workspace_ptr,\n                      grad_gates_ptr, grad_cx_ptr);\n    }\n  
}\n};\n\ntemplate<>\nvoid FusedLstmCellGradFunctor<float16>::operator()(\n    ep::Stream* stream, const int64_t cx_numel, const int64_t workspace_numel,\n    const int64_t hidden_size, const float16* grad_hy_ptr, const float16* grad_cy_ptr,\n    const float16* cx_ptr, const float16* cy_ptr, const float16* workspace_ptr,\n    float16* grad_gates_ptr, float16* grad_cx_ptr) {\n  if (workspace_numel < std::numeric_limits<int32_t>::max()) {\n    RUN_CUDA_KERNEL((lstm_cell_backward<half, float, int32_t>), stream, cx_numel,\n                    static_cast<int32_t>(cx_numel), static_cast<int32_t>(hidden_size),\n                    reinterpret_cast<const half*>(grad_hy_ptr),\n                    reinterpret_cast<const half*>(grad_cy_ptr),\n                    reinterpret_cast<const half*>(cx_ptr), reinterpret_cast<const half*>(cy_ptr),\n                    reinterpret_cast<const half*>(workspace_ptr),\n                    reinterpret_cast<half*>(grad_gates_ptr), reinterpret_cast<half*>(grad_cx_ptr));\n  } else {\n    RUN_CUDA_KERNEL((lstm_cell_backward<half, float, int64_t>), stream, cx_numel, cx_numel,\n                    hidden_size, reinterpret_cast<const half*>(grad_hy_ptr),\n                    reinterpret_cast<const half*>(grad_cy_ptr),\n                    reinterpret_cast<const half*>(cx_ptr), reinterpret_cast<const half*>(cy_ptr),\n                    reinterpret_cast<const half*>(workspace_ptr),\n                    reinterpret_cast<half*>(grad_gates_ptr), reinterpret_cast<half*>(grad_cx_ptr));\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuFusedLstmCellKernel final : public user_op::OpKernel {\n public:\n  GpuFusedLstmCellKernel() = default;\n  ~GpuFusedLstmCellKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input_gates = ctx->Tensor4ArgNameAndIndex(\"input_gates\", 0);\n    const user_op::Tensor* hidden_gates = ctx->Tensor4ArgNameAndIndex(\"hidden_gates\", 0);\n    const user_op::Tensor* cx = ctx->Tensor4ArgNameAndIndex(\"cx\", 0);\n    user_op::Tensor* hy = ctx->Tensor4ArgNameAndIndex(\"hy\", 0);\n    user_op::Tensor* cy = ctx->Tensor4ArgNameAndIndex(\"cy\", 0);\n    user_op::Tensor* workspace = ctx->Tensor4ArgNameAndIndex(\"workspace\", 0);\n\n    const T* input_bias_ptr = nullptr;\n    const T* hidden_bias_ptr = nullptr;\n    if (ctx->has_input(\"input_bias\", 0)) {\n      CHECK(ctx->has_input(\"hidden_bias\", 0));\n      input_bias_ptr = ctx->Tensor4ArgNameAndIndex(\"input_bias\", 0)->dptr<T>();\n      hidden_bias_ptr = ctx->Tensor4ArgNameAndIndex(\"hidden_bias\", 0)->dptr<T>();\n    }\n    const T* input_gates_ptr = input_gates->dptr<T>();\n    const T* hidden_gates_ptr = hidden_gates->dptr<T>();\n    const T* cx_ptr = cx->dptr<T>();\n\n    T* hy_ptr = hy->mut_dptr<T>();\n    T* cy_ptr = cy->mut_dptr<T>();\n    T* workspace_ptr = workspace->mut_dptr<T>();\n    const int64_t cx_numel = cx->shape_view().elem_cnt();\n    const int64_t workspace_numel = workspace->shape_view().elem_cnt();\n    const int64_t hidden_size = cx->shape_view().At(cx->shape_view().NumAxes() - 1);\n    FusedLstmCellFunctor<T>()(ctx->stream(), cx_numel, workspace_numel, hidden_size,\n                              input_gates_ptr, hidden_gates_ptr, cx_ptr, input_bias_ptr,\n                              hidden_bias_ptr, hy_ptr, cy_ptr, workspace_ptr);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define 
REGISTER_FUSED_LSTM_CELL_KERNEL(dtype)                                                  \\\n  REGISTER_USER_KERNEL(\"fused_lstm_cell\")                                                       \\\n      .SetCreateFn<GpuFusedLstmCellKernel<dtype>>()                                             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                          \\\n                       && (user_op::HobDataType(\"cx\", 0) == GetDataType<dtype>::value)          \\\n                       && (user_op::HobDataType(\"input_gates\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"hidden_gates\", 0) == GetDataType<dtype>::value))\n\nREGISTER_FUSED_LSTM_CELL_KERNEL(float);\nREGISTER_FUSED_LSTM_CELL_KERNEL(float16);\n\nclass GpuFusedLstmCellGradFloatKernel final : public user_op::OpKernel {\n public:\n  GpuFusedLstmCellGradFloatKernel() = default;\n  ~GpuFusedLstmCellGradFloatKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* grad_hy = ctx->Tensor4ArgNameAndIndex(\"grad_hy\", 0);\n    const user_op::Tensor* grad_cy = ctx->Tensor4ArgNameAndIndex(\"grad_cy\", 0);\n    const user_op::Tensor* cx = ctx->Tensor4ArgNameAndIndex(\"cx\", 0);\n    const user_op::Tensor* cy = ctx->Tensor4ArgNameAndIndex(\"cy\", 0);\n    const user_op::Tensor* workspace = ctx->Tensor4ArgNameAndIndex(\"workspace\", 0);\n    user_op::Tensor* grad_gates = ctx->Tensor4ArgNameAndIndex(\"grad_gates\", 0);\n    user_op::Tensor* grad_cx = ctx->Tensor4ArgNameAndIndex(\"grad_cx\", 0);\n\n    const float* grad_hy_ptr = grad_hy->dptr<float>();\n    const float* grad_cy_ptr = grad_cy->dptr<float>();\n    const float* cx_ptr = cx->dptr<float>();\n    const float* cy_ptr = cy->dptr<float>();\n    const float* workspace_ptr = workspace->dptr<float>();\n\n    float* grad_gates_ptr = grad_gates->mut_dptr<float>();\n    float* grad_cx_ptr = nullptr;\n\n    if (ctx->has_output(\"grad_cx\", 0)) { grad_cx_ptr = grad_cx->mut_dptr<float>(); }\n\n    const int64_t cx_numel = cx->shape_view().elem_cnt();\n    const int64_t workspace_numel = workspace->shape_view().elem_cnt();\n    const int64_t hidden_size = cx->shape_view().At(cx->shape_view().NumAxes() - 1);\n    FusedLstmCellGradFunctor<float>()(ctx->stream(), cx_numel, workspace_numel, hidden_size,\n                                      grad_hy_ptr, grad_cy_ptr, cx_ptr, cy_ptr, workspace_ptr,\n                                      grad_gates_ptr, grad_cx_ptr);\n\n    if (ctx->has_output(\"grad_bias\", 0)) {\n      float* grad_bias_ptr = ctx->Tensor4ArgNameAndIndex(\"grad_bias\", 0)->mut_dptr<float>();\n      std::vector<int32_t> axis;\n      axis.push_back(0);\n      const Shape& reduced_shape =\n          CreateReducedShape(workspace->shape_view(), {axis.begin(), axis.end()});\n      user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n      NdarrayReduce<DeviceType::kCUDA, float, BinaryFuncSum>::Reduce(\n          ctx->stream(), XpuVarNdarray<float>(reduced_shape, grad_bias_ptr),\n          XpuVarNdarray<const float>(grad_gates->shape_view(), grad_gates->dptr<float>()),\n          XpuVarNdarray<float>(tmp_buffer->shape_view(), tmp_buffer->mut_dptr<float>()));\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"fused_lstm_cell_grad\")\n    .SetCreateFn<GpuFusedLstmCellGradFloatKernel>()\n    
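// This float kernel is only matched when all listed inputs are float; the\n    // float16 variant (GpuFusedLstmCellGradHalfKernel) is registered below.\n    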
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)\n                     && (user_op::HobDataType(\"grad_hy\", 0) == GetDataType<float>::value)\n                     && (user_op::HobDataType(\"grad_cy\", 0) == GetDataType<float>::value)\n                     && (user_op::HobDataType(\"cx\", 0) == GetDataType<float>::value)\n                     && (user_op::HobDataType(\"cy\", 0) == GetDataType<float>::value)\n                     && (user_op::HobDataType(\"workspace\", 0) == GetDataType<float>::value))\n    .SetInferTmpSizeFn([](user_op::InferContext* ctx) {\n      size_t tmp_bytes = 0;\n      if (ctx->has_output(\"grad_bias\", 0)) {\n        const Shape& in_shape = ctx->InputTensorDesc(\"workspace\", 0).shape();\n        tmp_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float));\n      } else {\n        tmp_bytes = 0;\n      }\n      return tmp_bytes;\n    });\n\nclass GpuFusedLstmCellGradHalfKernel final : public user_op::OpKernel {\n public:\n  GpuFusedLstmCellGradHalfKernel() = default;\n  ~GpuFusedLstmCellGradHalfKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* grad_hy = ctx->Tensor4ArgNameAndIndex(\"grad_hy\", 0);\n    const user_op::Tensor* grad_cy = ctx->Tensor4ArgNameAndIndex(\"grad_cy\", 0);\n    const user_op::Tensor* cx = ctx->Tensor4ArgNameAndIndex(\"cx\", 0);\n    const user_op::Tensor* cy = ctx->Tensor4ArgNameAndIndex(\"cy\", 0);\n    const user_op::Tensor* workspace = ctx->Tensor4ArgNameAndIndex(\"workspace\", 0);\n    user_op::Tensor* grad_gates = ctx->Tensor4ArgNameAndIndex(\"grad_gates\", 0);\n    user_op::Tensor* grad_cx = ctx->Tensor4ArgNameAndIndex(\"grad_cx\", 0);\n\n    const float16* grad_hy_ptr = grad_hy->dptr<float16>();\n    const float16* grad_cy_ptr = grad_cy->dptr<float16>();\n    const float16* cx_ptr = cx->dptr<float16>();\n    const float16* cy_ptr = cy->dptr<float16>();\n    const float16* workspace_ptr = workspace->dptr<float16>();\n\n    float16* grad_gates_ptr = grad_gates->mut_dptr<float16>();\n    float16* grad_cx_ptr = nullptr;\n\n    if (ctx->has_output(\"grad_cx\", 0)) { grad_cx_ptr = grad_cx->mut_dptr<float16>(); }\n\n    const int64_t cx_numel = cx->shape_view().elem_cnt();\n    const int64_t workspace_numel = workspace->shape_view().elem_cnt();\n    const int64_t hidden_size = cx->shape_view().At(cx->shape_view().NumAxes() - 1);\n    FusedLstmCellGradFunctor<float16>()(ctx->stream(), cx_numel, workspace_numel, hidden_size,\n                                        grad_hy_ptr, grad_cy_ptr, cx_ptr, cy_ptr, workspace_ptr,\n                                        grad_gates_ptr, grad_cx_ptr);\n\n    if (ctx->has_output(\"grad_bias\", 0)) {\n      std::vector<int32_t> axis;\n      axis.push_back(0);\n      user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n      const ShapeView& in_shape = grad_gates->shape_view();\n      const Shape& reduced_shape = CreateReducedShape(in_shape, {axis.begin(), axis.end()});\n      float* in_tmp_buffer = tmp_buffer->mut_dptr<float>();\n      const size_t in_tmp_buffer_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float));\n      float* out_tmp_buffer =\n          reinterpret_cast<float*>(tmp_buffer->mut_dptr<char>() + in_tmp_buffer_bytes);\n      const size_t out_tmp_buffer_bytes =\n          GetCudaAlignedSize(reduced_shape.elem_cnt() * sizeof(float));\n      float* reduce_tmp_buffer = reinterpret_cast<float*>(\n          
tmp_buffer->mut_dptr<char>() + in_tmp_buffer_bytes + out_tmp_buffer_bytes);\n      const size_t reduce_tmp_buffer_bytes =\n          GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float));\n      CHECK_LE(in_tmp_buffer_bytes + out_tmp_buffer_bytes + reduce_tmp_buffer_bytes,\n               tmp_buffer->shape_view().elem_cnt());\n      auto h2f = ep::primitive::NewPrimitive<ep::primitive::CastFactory>(\n          ctx->device_type(), DataType::kFloat16, DataType::kFloat);\n      CHECK(h2f);\n      auto f2h = ep::primitive::NewPrimitive<ep::primitive::CastFactory>(\n          ctx->device_type(), DataType::kFloat, DataType::kFloat16);\n      CHECK(f2h);\n      h2f->Launch(ctx->stream(), grad_gates->dptr<float16>(), in_tmp_buffer, in_shape.elem_cnt());\n\n      NdarrayReduce<DeviceType::kCUDA, float, BinaryFuncSum>::Reduce(\n          ctx->stream(), XpuVarNdarray<float>(reduced_shape, out_tmp_buffer),\n          XpuVarNdarray<const float>(in_shape, in_tmp_buffer),\n          XpuVarNdarray<float>(in_shape, reduce_tmp_buffer));\n\n      user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex(\"grad_bias\", 0);\n      f2h->Launch(ctx->stream(), out_tmp_buffer, output_tensor->mut_dptr<float16>(),\n                  output_tensor->shape_view().elem_cnt());\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"fused_lstm_cell_grad\")\n    .SetCreateFn<GpuFusedLstmCellGradHalfKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)\n                     && (user_op::HobDataType(\"grad_hy\", 0) == GetDataType<float16>::value)\n                     && (user_op::HobDataType(\"grad_cy\", 0) == GetDataType<float16>::value)\n                     && (user_op::HobDataType(\"cx\", 0) == GetDataType<float16>::value)\n                     && (user_op::HobDataType(\"cy\", 0) == GetDataType<float16>::value)\n                     && (user_op::HobDataType(\"workspace\", 0) == GetDataType<float16>::value))\n    .SetInferTmpSizeFn([](user_op::InferContext* ctx) {\n      size_t tmp_bytes = 0;\n      if (ctx->has_output(\"grad_bias\", 0)) {\n        const Shape& in_shape = ctx->InputTensorDesc(\"workspace\", 0).shape();\n        const Shape& out_shape = ctx->OutputTensorDesc(\"grad_bias\", 0).shape();\n        tmp_bytes = (2 * GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float))\n                     + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(float)));\n      } else {\n        tmp_bytes = 0;\n      }\n      return tmp_bytes;\n    });\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_matmul_bias_add_relu_dropout.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/user/kernels/cublas_fused_mlp_util.cuh\"\n#include \"oneflow/user/kernels/dropout_kernel.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n\n// CUBLAS_AUX_EPILOGUE only support in cuda11.4 or higher version, in cuda11.4 it need static link.\n#if CUDA_VERSION >= 11060\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr int32_t kVecSize = 4;\nconstexpr int32_t kBlockSize = 256;\nconstexpr int32_t kWarpSize = 32;\n\nunion RandPack4 {\n  uint4 storage;\n  uint32_t elem[4];  // store curand4 return val.\n};\n\ntemplate<int32_t pack_size, typename IndexType>\n__device__ void SetCublasBitMask(const IndexType aux_ld, const IndexType row, const IndexType col,\n                                 int32_t thread_bitmask, int32_t* mask) {\n  IndexType linear_index = row * aux_ld + col;\n  IndexType mask_index = linear_index / kWarpSize;\n  IndexType mask_offset = linear_index - mask_index * kWarpSize;\n\n  int32_t bitmask = thread_bitmask << mask_offset;\n  for (int stride = kWarpSize / (pack_size * 2); stride > 0; stride /= 2) {\n    bitmask |= __shfl_down_sync(__activemask(), bitmask, stride, kWarpSize);\n  }\n  if (mask_offset == 0) { mask[mask_index] = bitmask; }\n}\n\ntemplate<typename T, bool relu, typename IndexType>\n__global__ void FusedVectorizedReluDropoutKernel(uint64_t seed, uint64_t offset,\n                                                 const IndexType elem_cnt, const int32_t aux_ld,\n                                                 const IndexType cols, const uint32_t rate,\n                                                 float scale, T* x, int32_t* mask) {\n  IndexType global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;\n  curandStatePhilox4_32_10_t state;\n  curand_init(seed, global_thread_id, offset, &state);\n  using LoadType = cuda::elementwise::PackType<T, kVecSize>;\n  using LoadPack = cuda::elementwise::Pack<T, kVecSize>;\n\n  T t_scale = static_cast<T>(scale);\n  RandPack4 rand_uniform_pack4;\n  T zero_val = static_cast<T>(0.0);\n  for (IndexType linear_index = global_thread_id * kVecSize,\n                 step = gridDim.x * blockDim.x * kVecSize;\n       linear_index < elem_cnt; linear_index += step) {\n    const IndexType row = linear_index / cols;\n    const IndexType col = linear_index - row * cols;\n    int32_t thread_bitmask = 0;\n\n    rand_uniform_pack4.storage = curand4(&state);\n\n    LoadType* x_load = reinterpret_cast<LoadType*>(x + linear_index);\n    LoadPack x_vec;\n    x_vec.storage = *x_load;\n    LoadPack out_vec;\n#pragma unroll\n    for (int i = 0; i < kVecSize; i++) {\n      bool relu_mask = true;\n      if (relu) {\n        // Relu\n        relu_mask = x_vec.elem[i] >= zero_val;\n      }\n      // dropout\n      bool mask_val = rand_uniform_pack4.elem[i] > rate;\n     
 // Combine relu_mask and dropout_mask together.\n      bool combined_mask = relu_mask && mask_val;\n      // Because half/bfloat16 cannot be directly converted from bool, we cast to float type first\n      T t_combined_mask = static_cast<T>(static_cast<float>(combined_mask));\n      thread_bitmask |= (combined_mask << i);\n      out_vec.elem[i] = x_vec.elem[i] * t_combined_mask * t_scale;\n    }\n    *(reinterpret_cast<LoadType*>(x + linear_index)) = out_vec.storage;\n    SetCublasBitMask<kVecSize, IndexType>(aux_ld, row, col, thread_bitmask, mask);\n  }\n}\n\ntemplate<typename T, bool relu, typename IndexType>\n__global__ void FusedPaddedVectorizedReluDropoutKernel(uint64_t seed, uint64_t offset,\n                                                       const IndexType aligned32_elem_cnt,\n                                                       const int32_t aux_ld,\n                                                       const IndexType aligned32_cols,\n                                                       const IndexType cols, const uint32_t rate,\n                                                       float scale, T* x, int32_t* mask) {\n  IndexType global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;\n  curandStatePhilox4_32_10_t state;\n  curand_init(seed, global_thread_id, offset, &state);\n  using LoadType = cuda::elementwise::PackType<T, kVecSize>;\n  using LoadPack = cuda::elementwise::Pack<T, kVecSize>;\n\n  T t_scale = static_cast<T>(scale);\n  RandPack4 rand_uniform_pack4;\n  T zero_val = static_cast<T>(0.0);\n  for (IndexType linear_index = global_thread_id * kVecSize,\n                 step = gridDim.x * blockDim.x * kVecSize;\n       linear_index < aligned32_elem_cnt; linear_index += step) {\n    const IndexType row = linear_index / aligned32_cols;\n    const IndexType col = linear_index - row * aligned32_cols;\n    int32_t thread_bitmask = 0;\n\n    if (col < cols) {\n      const IndexType actual_index = row * cols + col;\n      rand_uniform_pack4.storage = curand4(&state);\n\n      LoadType* x_load = reinterpret_cast<LoadType*>(x + actual_index);\n      LoadPack x_vec;\n      x_vec.storage = *x_load;\n      LoadPack out_vec;\n#pragma unroll\n      for (int i = 0; i < kVecSize; i++) {\n        bool relu_mask = true;\n        if (relu) {\n          // Relu\n          relu_mask = x_vec.elem[i] >= zero_val;\n        }\n        // dropout\n        bool mask_val = rand_uniform_pack4.elem[i] > rate;\n        // Combine relu_mask and dropout_mask together.\n        bool combined_mask = relu_mask && mask_val;\n        // Because half/bfloat16 cannot be directly converted from bool, we cast to float type first\n        T t_combined_mask = static_cast<T>(static_cast<float>(combined_mask));\n        thread_bitmask |= (combined_mask << i);\n        out_vec.elem[i] = x_vec.elem[i] * t_combined_mask * t_scale;\n      }\n      *(reinterpret_cast<LoadType*>(x + actual_index)) = out_vec.storage;\n    }\n    SetCublasBitMask<kVecSize, IndexType>(aux_ld, row, col, thread_bitmask, mask);\n  }\n}\n\ntemplate<typename T, bool relu, typename IndexType>\n__global__ void FusedWarpReluDropoutKernel(uint64_t seed, uint64_t offset, const IndexType elem_cnt,\n                                           const IndexType aux_ld, const IndexType rows,\n                                           const IndexType cols, const uint32_t rate, float scale,\n                                           T* x, int32_t* mask) {\n  const int32_t lane_id = threadIdx.x;\n  const IndexType global_warp_id = blockIdx.x * 
blockDim.y + threadIdx.y;\n  const IndexType step = gridDim.x * blockDim.y;\n  const IndexType global_thread_id = global_warp_id * kWarpSize + lane_id;\n\n  curandStatePhilox4_32_10_t state;\n  curand_init(seed, global_thread_id, offset, &state);\n\n  T t_scale = static_cast<T>(scale);\n  T zero_val = static_cast<T>(0.0);\n  RandPack4 rand_uniform_pack4;\n\n  for (IndexType row = global_warp_id; row < rows; row += step) {\n    for (IndexType col = lane_id; col < cols; col += kWarpSize * kVecSize) {\n      const IndexType linear_index = row * cols + col;\n      rand_uniform_pack4.storage = curand4(&state);\n#pragma unroll\n      for (int i = 0; i < kVecSize; i++) {\n        int32_t thread_bitmask = 0;\n        int32_t cur_col = col + i * kWarpSize;\n        int32_t cur_linear_index = linear_index + i * kWarpSize;\n        if (cur_col < cols) {\n          T x_val = x[cur_linear_index];\n          const uint32_t rand_uniform_val = rand_uniform_pack4.elem[i];\n          bool relu_mask = true;\n          if (relu) {\n            // relu\n            relu_mask = x_val >= zero_val;\n          }\n          // dropout\n          bool mask_val = rand_uniform_val > rate;\n          // Combine relu_mask and dropout_mask together.\n          bool combined_mask = relu_mask && mask_val;\n          thread_bitmask = combined_mask;\n          // Because half/bfloat16 cannot be directly converted from bool, we cast to float type\n          // first\n          T t_combined_mask = static_cast<T>(static_cast<float>(combined_mask));\n          T out_val = x_val * t_combined_mask * t_scale;\n          x[cur_linear_index] = out_val;\n        }\n        int32_t warp_mask = __ballot_sync(__activemask(), thread_bitmask);\n        if (lane_id == 0) { mask[(row * aux_ld + cur_col) / kWarpSize] = warp_mask; }\n      }\n    }\n  }\n}\n\ntemplate<typename Func>\nunsigned int ComputeGridSize(ep::Stream* stream, Func func, const int64_t elem_cnt,\n                             const int32_t block_size) {\n  auto* cuda_stream = stream->As<ep::CudaStream>();\n  const int64_t pack_num = elem_cnt / kVecSize;\n  const int32_t num_blocks = std::max<int64_t>(1, (pack_num + block_size - 1) / block_size);\n  const int32_t multi_processor_count = cuda_stream->device_properties().multiProcessorCount;\n  int max_active_blocks = 0;\n  OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, func, block_size,\n                                                              /*shared_memory*/ 0));\n  return std::min(num_blocks, max_active_blocks * multi_processor_count);\n}\n\nuint64_t RoundUp(uint64_t x, uint64_t y) { return (x + y - 1) / y * y; }\n\ntemplate<typename T, bool relu>\ncudaError_t LaunchFusedReluDropoutKernel(ep::CudaStream* stream,\n                                         const std::shared_ptr<ep::CUDAGenerator>& cuda_generator,\n                                         const int64_t elem_cnt, const int32_t aux_ld,\n                                         const int64_t rows, const int64_t cols, float rate,\n                                         float scale, T* x, int32_t* mask) {\n  uint64_t offset = 0;\n  uint64_t seed = cuda_generator->current_seed();\n  const uint32_t uint_rate = UINT_MAX * rate;\n  unsigned int grid_size = 0;\n  if (cols % 32 == 0) {\n    // Launch Elementwise Vectorized Kernel.\n    if (elem_cnt < GetMaxVal<int32_t>()) {\n      grid_size = ComputeGridSize(stream, FusedVectorizedReluDropoutKernel<T, relu, int32_t>,\n                                  elem_cnt, kBlockSize);\n      
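// Advance the shared Philox offset past the randoms this launch will consume so\n      // subsequent dropout launches draw fresh numbers instead of reusing the sequence.\n      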
uint64_t inc_offset = RoundUp((elem_cnt / (kBlockSize * grid_size)), kVecSize);\n      offset = cuda_generator->get_philox_offset(inc_offset);\n      FusedVectorizedReluDropoutKernel<T, relu, int32_t>\n          <<<grid_size, kBlockSize, 0, stream->cuda_stream()>>>(seed, offset, elem_cnt, aux_ld,\n                                                                cols, uint_rate, scale, x, mask);\n    } else {\n      grid_size = ComputeGridSize(stream, FusedVectorizedReluDropoutKernel<T, relu, int64_t>,\n                                  elem_cnt, kBlockSize);\n      uint64_t inc_offset = RoundUp((elem_cnt / (kBlockSize * grid_size)), kVecSize);\n      offset = cuda_generator->get_philox_offset(inc_offset);\n      FusedVectorizedReluDropoutKernel<T, relu, int64_t>\n          <<<grid_size, kBlockSize, 0, stream->cuda_stream()>>>(seed, offset, elem_cnt, aux_ld,\n                                                                cols, uint_rate, scale, x, mask);\n    }\n  } else {\n    if (cols % 4 == 0) {\n      // Pad cols up to a multiple of kWarpSize.\n      const int64_t align32_cols = (cols + kWarpSize - 1) / kWarpSize * kWarpSize;\n      const int64_t align32_elem_cnt = rows * align32_cols;\n      if (align32_elem_cnt < GetMaxVal<int32_t>()) {\n        grid_size =\n            ComputeGridSize(stream, FusedPaddedVectorizedReluDropoutKernel<T, relu, int32_t>,\n                            align32_elem_cnt, kBlockSize);\n        uint64_t inc_offset = RoundUp((elem_cnt / (kBlockSize * grid_size)), kVecSize);\n        offset = cuda_generator->get_philox_offset(inc_offset);\n        FusedPaddedVectorizedReluDropoutKernel<T, relu, int32_t>\n            <<<grid_size, kBlockSize, 0, stream->cuda_stream()>>>(seed, offset, align32_elem_cnt,\n                                                                  aux_ld, align32_cols, cols,\n                                                                  uint_rate, scale, x, mask);\n      } else {\n        grid_size =\n            ComputeGridSize(stream, FusedPaddedVectorizedReluDropoutKernel<T, relu, int64_t>,\n                            align32_elem_cnt, kBlockSize);\n        uint64_t inc_offset = RoundUp((elem_cnt / (kBlockSize * grid_size)), kVecSize);\n        offset = cuda_generator->get_philox_offset(inc_offset);\n        FusedPaddedVectorizedReluDropoutKernel<T, relu, int64_t>\n            <<<grid_size, kBlockSize, 0, stream->cuda_stream()>>>(seed, offset, align32_elem_cnt,\n                                                                  aux_ld, align32_cols, cols,\n                                                                  uint_rate, scale, x, mask);\n      }\n    } else {\n      // Process each row with a single warp.\n      dim3 block_dim(kWarpSize, kBlockSize / kWarpSize);\n      if (elem_cnt < GetMaxVal<int32_t>()) {\n        grid_size = ComputeGridSize(stream, FusedWarpReluDropoutKernel<T, relu, int32_t>, elem_cnt,\n                                    kBlockSize);\n        uint64_t inc_offset = RoundUp((elem_cnt / (kBlockSize * grid_size)), kVecSize);\n        offset = cuda_generator->get_philox_offset(inc_offset);\n        FusedWarpReluDropoutKernel<T, relu, int32_t>\n            <<<grid_size, block_dim, 0, stream->cuda_stream()>>>(\n                seed, offset, elem_cnt, aux_ld, rows, cols, uint_rate, scale, x, mask);\n      } else {\n        grid_size = ComputeGridSize(stream, FusedWarpReluDropoutKernel<T, relu, int64_t>, elem_cnt,\n                                    kBlockSize);\n        uint64_t inc_offset = RoundUp((elem_cnt / (kBlockSize 
* grid_size)), kVecSize);\n        offset = cuda_generator->get_philox_offset(inc_offset);\n        FusedWarpReluDropoutKernel<T, relu, int64_t>\n            <<<grid_size, block_dim, 0, stream->cuda_stream()>>>(\n                seed, offset, elem_cnt, aux_ld, rows, cols, uint_rate, scale, x, mask);\n      }\n    }\n  }\n  return cudaPeekAtLastError();\n}\n\ntemplate<typename T>\nclass FusedMatmulBiasAddReluDropoutKernel final : public user_op::OpKernel {\n public:\n  FusedMatmulBiasAddReluDropoutKernel() = default;\n  ~FusedMatmulBiasAddReluDropoutKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateCublasFusedMLPKernelCache();\n  }\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCUDA));\n    generator->set_current_seed(\n        CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>(\"seed\"))));\n    return std::make_shared<FusedDropoutKernelState>(generator);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache* cache) const override {\n    /*\n    Fused Dense-Activation layer. Assume we have two layers:\n    A: (m, k)\n    B: (n, k), needs transpose\n    C: (j, n), needs transpose\n    tmp: A matmul B(transposed), its shape is (m, n)\n    out: tmp matmul C(transposed), its shape is (m, j)\n    */\n    const int32_t weight_size = ctx->input_size(\"weights\");\n    const int32_t bias_size = ctx->input_size(\"biases\");\n    CHECK_EQ(weight_size, bias_size) << \"The number of weights and biases is not equal!
\";\n    auto* cuda_stream = ctx->stream()->As<ep::CudaStream>();\n    const auto* matmul_cache = CHECK_NOTNULL(dynamic_cast<const CublasFusedMLPKernelCache*>(cache));\n\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    bool skip_final_activation = ctx->Attr<bool>(\"skip_final_activation\");\n\n    auto* fused_dropout_kernel_state = dynamic_cast<FusedDropoutKernelState*>(state);\n    CHECK_NOTNULL(fused_dropout_kernel_state);\n    const auto& generator = fused_dropout_kernel_state->generator();\n    CHECK_NOTNULL(generator);\n    const auto device_index = ctx->stream()->device()->device_index();\n    std::shared_ptr<ep::CUDAGenerator> cuda_generator =\n        CHECK_JUST(generator->Get<ep::CUDAGenerator>(device_index));\n    const std::vector<float> dropout_rate_list = ctx->Attr<std::vector<float>>(\"dropout_rate_list\");\n\n    const DataType data_type = out->data_type();\n    const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type);\n    const cudaDataType_t cuda_data_type = GetCudaDataType(data_type);\n    size_t cublas_m = 0, cublas_n = 0, cublas_k = 0;\n    int64_t cublas_lda = 0, cublas_ldb = 0, cublas_ldc = 0;\n\n    const double alpha = 1.0;\n    const auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype);\n    const double beta = 0.0;\n    const auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype);\n\n    // Currently only support 2D matmul.\n    DimVector in_shape(2);\n    x->shape_view().ToDimVector(&in_shape);\n    DimVector weight_shape(2);\n\n    const void* in_buf_ptr = x->dptr();\n    size_t offset = 0;\n    for (int idx = 0; idx < weight_size; idx++) {\n      const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weights\", idx);\n      const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex(\"biases\", idx);\n      user_op::Tensor* cublas_aux = ctx->Tensor4ArgNameAndIndex(\"cublas_aux\", idx);\n\n      const int64_t batchsize = in_shape.at(0);\n      const int64_t out_feature = weight->shape_view().At(0);\n      weight->shape_view().ToDimVector(&weight_shape);\n      size_t matmul_out_elem_cnt = batchsize * out_feature;\n\n      InferMatmulCublasMNK(in_shape, weight_shape,\n                           /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                           /*transpose_b=*/ep::primitive::BlasTransposeType::T, &cublas_m,\n                           &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc);\n\n      cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;\n      void* matmul_out_ptr;\n\n      float rate = dropout_rate_list.at(idx);\n      float scale = 0.0;\n      const int32_t aux_ld = AlignReluAuxLd(out_feature);\n      if (rate < 1.0f) { scale = 1.0f / (1.0f - rate); }\n\n      if (idx == weight_size - 1) {\n        matmul_out_ptr = ctx->Tensor4ArgNameAndIndex(\"out\", 0)->mut_dptr();\n      } else {\n        matmul_out_ptr = ctx->Tensor4ArgNameAndIndex(\"hidden\", idx)->mut_dptr();\n      }\n      SetCublasAttr(matmul_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false,\n                    /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                    /*transpose_b=*/ep::primitive::BlasTransposeType::T, epilogue, bias->dptr(),\n                    /*aux_ptr=*/nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb,\n                    cublas_ldc);\n\n      OF_CUBLAS_CHECK(cublasLtMatmul(\n          cuda_stream->cublas_lt_handle(), 
matmul_cache->operation_desc, &sp_alpha, weight->dptr(),\n          matmul_cache->cublas_a_desc, in_buf_ptr, matmul_cache->cublas_b_desc, &sp_beta,\n          matmul_out_ptr, matmul_cache->cublas_c_desc, matmul_out_ptr, matmul_cache->cublas_c_desc,\n          nullptr, cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(),\n          cuda_stream->cuda_stream()));\n\n      if (idx != weight_size - 1 || !skip_final_activation || rate != 0.0f) {\n        OF_CUDA_CHECK(cudaMemsetAsync(cublas_aux->mut_dptr<int32_t>(), 0,\n                                      cublas_aux->shape_view().elem_cnt() * sizeof(int32_t),\n                                      cuda_stream->cuda_stream()));\n      }\n\n      if (idx != weight_size - 1 || !skip_final_activation) {\n        // If it's not the last layer, or it is the last layer but still needs relu.\n        OF_CUDA_CHECK((LaunchFusedReluDropoutKernel<T, true>(\n            cuda_stream, cuda_generator, matmul_out_elem_cnt, aux_ld, batchsize, out_feature, rate,\n            scale, reinterpret_cast<T*>(matmul_out_ptr),\n            reinterpret_cast<int32_t*>(cublas_aux->mut_dptr()))));\n        // Set relu_dropout_out ptr as the next layer's input.\n        in_buf_ptr = matmul_out_ptr;\n        // Set hidden_layer shape as the next layer's input shape.\n        in_shape.at(1) = out_feature;\n      } else {\n        if (rate == 0.0f) {\n          // It's the last layer and dropout_rate is 0.0f, so we do not launch FusedReluDropoutKernel.\n          break;\n        } else {\n          // skip_final_activation is set but dropout is still needed.\n          OF_CUDA_CHECK((LaunchFusedReluDropoutKernel<T, false>(\n              cuda_stream, cuda_generator, matmul_out_elem_cnt, aux_ld, batchsize, out_feature,\n              rate, scale, reinterpret_cast<T*>(matmul_out_ptr),\n              reinterpret_cast<int32_t*>(cublas_aux->mut_dptr()))));\n        }\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_MATMUL_BIAS_ADD_RELU_DROPOUT_KERNEL_GPU(cpp_type, data_type) \\\n  REGISTER_USER_KERNEL(\"fused_matmul_bias_add_relu_dropout\")                        \\\n      .SetCreateFn<FusedMatmulBiasAddReluDropoutKernel<cpp_type>>()                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)              \\\n                       && (user_op::HobDataType(\"out\", 0) == data_type));\n\nREGISTER_FUSED_MATMUL_BIAS_ADD_RELU_DROPOUT_KERNEL_GPU(float, DataType::kFloat)\nREGISTER_FUSED_MATMUL_BIAS_ADD_RELU_DROPOUT_KERNEL_GPU(half, DataType::kFloat16)\n#if CUDA_VERSION >= 11000\nREGISTER_FUSED_MATMUL_BIAS_ADD_RELU_DROPOUT_KERNEL_GPU(nv_bfloat16, DataType::kBFloat16)\n#endif\n\n}  // namespace\n\n}  // namespace oneflow\n\n#endif  // CUDA_VERSION >= 11060\n"
  },
  {
    "path": "oneflow/user/kernels/fused_matmul_bias_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/user/kernels/cublas_fused_mlp_util.cuh\"\n\n// same with cublas_fused_mlp_util.cuh\n#if CUDA_VERSION >= 11020\n\nnamespace oneflow {\n\nnamespace {\n\nclass FusedMatmulBiasKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  FusedMatmulBiasKernel() = default;\n  ~FusedMatmulBiasKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateCublasFusedMLPKernelCache();\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    auto* cuda_stream = ctx->stream()->As<ep::CudaStream>();\n    const auto* matmul_cache = CHECK_NOTNULL(dynamic_cast<const CublasFusedMLPKernelCache*>(cache));\n\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const user_op::Tensor* _add_to_output = (ctx->has_input(\"_add_to_output\", 0))\n                                                ? ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0)\n                                                : nullptr;\n\n    const DataType data_type = out->data_type();\n    const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type);\n    const cudaDataType_t cuda_data_type = GetCudaDataType(data_type);\n    size_t cublas_m = 0, cublas_n = 0, cublas_k = 0;\n    int64_t cublas_lda = 0, cublas_ldb = 0, cublas_ldc = 0;\n\n    const double alpha = ctx->Attr<double>(\"alpha\");\n    const double beta = (ctx->has_input(\"_add_to_output\", 0)) ? 
ctx->Attr<double>(\"beta\") : 0.0;\n\n    const auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype);\n    const auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype);\n\n    DimVector in_shape({x->shape_view().Count(0, x->shape_view().NumAxes() - 1),\n                        x->shape_view().At(x->shape_view().NumAxes() - 1)});\n\n    DimVector weight_shape(2);\n\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex(\"bias\", 0);\n\n    weight->shape_view().ToDimVector(&weight_shape);\n\n    InferMatmulCublasMNK(in_shape, weight_shape,\n                         /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                         /*transpose_b=*/ep::primitive::BlasTransposeType::T, &cublas_m, &cublas_n,\n                         &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc);\n\n    cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;\n    void* y_ptr = ctx->Tensor4ArgNameAndIndex(\"out\", 0)->mut_dptr();\n\n    SetCublasAttr(matmul_cache, cublas_compute_dtype, cuda_data_type, false,\n                  /*transpose_a=*/ep::primitive::BlasTransposeType::N,\n                  /*transpose_b=*/ep::primitive::BlasTransposeType::T, epilogue, bias->dptr(),\n                  nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc);\n\n    cublasLtMatmulPreference_t preference = nullptr;\n    size_t workspace_size = cuda_stream->cublas_workspace_size();\n    OF_CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&preference));\n    OF_CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference,\n                                                         CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,\n                                                         &workspace_size, sizeof(workspace_size)));\n    int returned_results = 0;\n    cublasLtMatmulHeuristicResult_t heuristic_result;\n    OF_CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(\n        cuda_stream->cublas_lt_handle(), matmul_cache->operation_desc, matmul_cache->cublas_a_desc,\n        matmul_cache->cublas_b_desc, matmul_cache->cublas_c_desc, matmul_cache->cublas_c_desc,\n        preference, 1, &heuristic_result, &returned_results));\n    CHECK_EQ(returned_results, 1);\n    cublasLtMatmulPreferenceDestroy(preference);\n    OF_CUBLAS_CHECK(cublasLtMatmul(\n        cuda_stream->cublas_lt_handle(), matmul_cache->operation_desc, &sp_alpha, weight->dptr(),\n        matmul_cache->cublas_a_desc, x->dptr(), matmul_cache->cublas_b_desc, &sp_beta,\n        (_add_to_output == nullptr) ? 
y_ptr : _add_to_output->dptr(), matmul_cache->cublas_c_desc,\n        y_ptr, matmul_cache->cublas_c_desc, &heuristic_result.algo, cuda_stream->cublas_workspace(),\n        cuda_stream->cublas_workspace_size(), cuda_stream->cuda_stream()));\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_MATMUL_BIAS_KERNEL_GPU(data_type)               \\\n  REGISTER_USER_KERNEL(\"fused_matmul_bias\")                            \\\n      .SetCreateFn<FusedMatmulBiasKernel>()                            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"out\", 0) == data_type));\n\nREGISTER_FUSED_MATMUL_BIAS_KERNEL_GPU(DataType::kDouble);\nREGISTER_FUSED_MATMUL_BIAS_KERNEL_GPU(DataType::kFloat);\nREGISTER_FUSED_MATMUL_BIAS_KERNEL_GPU(DataType::kFloat16);\n#if CUDA_VERSION >= 11000\nREGISTER_FUSED_MATMUL_BIAS_KERNEL_GPU(DataType::kBFloat16);\n#endif  // CUDA_VERSION >= 11000\n\n}  // namespace\n\n}  // namespace oneflow\n\n#endif  // CUDA_VERSION >= 11020\n"
  },
  {
    "path": "oneflow/user/kernels/fused_relu_dropout_grad_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <cuda.h>\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr int32_t kWarpSize = 32;\n\ntemplate<typename T, typename IndexType, int pack_size, bool tail>\n__global__ void VectorizedReluDropoutBitmaskBackwardKernel(\n    const IndexType elem_cnt, const IndexType cols, const IndexType aux_ld, const float scale,\n    const IndexType n_tail, const IndexType tail_offset, const T* dy, const int32_t* mask, T* dx) {\n  int32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;\n  using LoadStoreType = cuda::elementwise::PackType<T, pack_size>;\n  using LoadStorePack = cuda::elementwise::Pack<T, pack_size>;\n\n  T t_scale = static_cast<T>(scale);\n  for (IndexType linear_pack_index = global_thread_id * pack_size; linear_pack_index < elem_cnt;\n       linear_pack_index += gridDim.x * blockDim.x * pack_size) {\n    const LoadStoreType* dy_load = reinterpret_cast<const LoadStoreType*>(dy + linear_pack_index);\n    LoadStorePack dy_vec;\n    dy_vec.storage = *dy_load;\n\n    LoadStorePack dx_vec;\n#pragma unroll\n    for (int i = 0; i < pack_size; i++) {\n      const IndexType linear_index = (linear_pack_index + i);\n      const IndexType row = linear_index / cols;\n      const IndexType col = linear_index - row * cols;\n      const int32_t col_mod_warpsize = col % kWarpSize;\n      const IndexType aux_idx = ((row * aux_ld) + col) / kWarpSize;\n      bool is_positive = mask[aux_idx] & (1 << col_mod_warpsize);\n      dx_vec.elem[i] =\n          dy_vec.elem[i] * static_cast<T>(static_cast<float>(is_positive)) * static_cast<T>(scale);\n    }\n    *(reinterpret_cast<LoadStoreType*>(dx + linear_pack_index)) = dx_vec.storage;\n  }\n\n  if (tail && global_thread_id < n_tail) {\n    const IndexType tail_index = tail_offset + global_thread_id;\n    const IndexType tail_row = tail_index / cols;\n    const IndexType tail_col = tail_index - tail_row * cols;\n    const IndexType tail_col_mod_warpsize = tail_col % kWarpSize;\n    const IndexType tail_aux_idx = ((tail_row * aux_ld) + tail_col) / kWarpSize;\n    bool is_positive = mask[tail_aux_idx] & (1 << tail_col_mod_warpsize);\n    dx[tail_index] =\n        dy[tail_index] * static_cast<T>(static_cast<float>(is_positive)) * static_cast<T>(scale);\n  }\n}\n\ntemplate<typename T>\nvoid LaunchVectorizedReluDropoutBackwardKernel(ep::Stream* stream, const int64_t elem_cnt,\n                                               const int64_t cols, const int64_t aux_ld,\n                                               float scale, const T* dy, const int32_t* mask,\n                                               T* dx) {\n  constexpr int pack_size = cuda::elementwise::PackSize<T>();\n  const int64_t pack_num = 
elem_cnt / pack_size;\n  const int64_t tail_offset = pack_num * pack_size;\n  const int64_t n_tail = elem_cnt - tail_offset;\n  const bool tail = n_tail > 0 ? true : false;\n  if (tail) {\n    if (elem_cnt < GetMaxVal<int32_t>()) {\n      stream->As<ep::CudaStream>()->LaunchKernelDefaultWaves(\n          (VectorizedReluDropoutBitmaskBackwardKernel<T, int32_t, pack_size, true>),\n          std::max<int64_t>(1, pack_num), elem_cnt, cols, aux_ld, scale, n_tail, tail_offset, dy,\n          mask, dx);\n    } else {\n      stream->As<ep::CudaStream>()->LaunchKernelDefaultWaves(\n          (VectorizedReluDropoutBitmaskBackwardKernel<T, int64_t, pack_size, true>),\n          std::max<int64_t>(1, pack_num), elem_cnt, cols, aux_ld, scale, n_tail, tail_offset, dy,\n          mask, dx);\n    }\n  } else {\n    if (elem_cnt < GetMaxVal<int32_t>()) {\n      stream->As<ep::CudaStream>()->LaunchKernelDefaultWaves(\n          (VectorizedReluDropoutBitmaskBackwardKernel<T, int32_t, pack_size, false>),\n          std::max<int64_t>(1, pack_num), elem_cnt, cols, aux_ld, scale, /*n_tail=*/0, tail_offset,\n          dy, mask, dx);\n    } else {\n      stream->As<ep::CudaStream>()->LaunchKernelDefaultWaves(\n          (VectorizedReluDropoutBitmaskBackwardKernel<T, int64_t, pack_size, false>),\n          std::max<int64_t>(1, pack_num), elem_cnt, cols, aux_ld, scale, /*n_tail=*/0, tail_offset,\n          dy, mask, dx);\n    }\n  }\n}\n\ntemplate<typename T>\nclass FusedReluDropoutGradKernel final : public user_op::OpKernel,\n                                         public user_op::CudaGraphSupport {\n public:\n  FusedReluDropoutGradKernel() = default;\n  ~FusedReluDropoutGradKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const float scale = ctx->Attr<float>(\"scale\");\n\n    const int64_t cols = dy->shape_view().At(1);\n    const int64_t aux_ld = mask->shape_view().At(1) * 32;\n    const int64_t elem_cnt = dy->shape_view().elem_cnt();\n    LaunchVectorizedReluDropoutBackwardKernel<T>(\n        ctx->stream(), elem_cnt, cols, aux_ld, scale, reinterpret_cast<const T*>(dy->dptr()),\n        mask->dptr<int32_t>(), reinterpret_cast<T*>(dx->mut_dptr()));\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_RELU_DROPOUT_GRAD_KERNEL_GPU(cpp_type, data_type) \\\n  REGISTER_USER_KERNEL(\"fused_relu_dropout_grad\")                        \\\n      .SetCreateFn<FusedReluDropoutGradKernel<cpp_type>>()               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)   \\\n                       && (user_op::HobDataType(\"dx\", 0) == data_type));\n\nREGISTER_FUSED_RELU_DROPOUT_GRAD_KERNEL_GPU(float, DataType::kFloat)\nREGISTER_FUSED_RELU_DROPOUT_GRAD_KERNEL_GPU(half, DataType::kFloat16)\n#if CUDA_VERSION >= 11000\nREGISTER_FUSED_RELU_DROPOUT_GRAD_KERNEL_GPU(nv_bfloat16, DataType::kBFloat16)\n#endif\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_rnn_cell_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_FUSED_RNN_CELL_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_FUSED_RNN_CELL_KERNEL_UTIL_H_\n\n// NOTE(Liang Depeng): Modified from\n// https://github.com/pytorch/pytorch/blob/master/c10/macros/Macros.h#L256\n#if defined(__CUDACC__)\n// constants from\n// (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications)\n// The maximum number of threads per multiprocessor is 1024 for Turing\n// architecture (7.5), 1536 for Geforce Ampere (8.6), and 2048 for all other\n// architectures. You'll get warnings if you exceed these constants. Hence, the\n// following macros adjust the input values from the user to resolve potential\n// warnings.\n#if __CUDA_ARCH__ == 750\nconstexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024;\n#elif __CUDA_ARCH__ == 860\nconstexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1536;\n#else\nconstexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048;\n#endif\n// CUDA_MAX_THREADS_PER_BLOCK is same for all architectures currently\nconstexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024;\n// CUDA_THREADS_PER_BLOCK_FALLBACK is the \"canonical fallback\" choice of block\n// size. 256 is a good number for this fallback and should give good occupancy\n// and versatility across all architectures.\nconstexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;\n// NOTE: if you are thinking of constexpr-ify the inputs to launch bounds, it\n//       turns out that although __launch_bounds__ can take constexpr, it\n//       can't take a constexpr that has anything to do with templates.\n//       Currently we use launch_bounds that depend on template arguments in\n//       Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, OF_MAX_THREADS_PER_BLOCK\n//       and OF_MIN_BLOCKS_PER_SM are kept as macros.\n// Suppose you were planning to write __launch_bounds__(a, b), based on your\n// performance tuning on a modern GPU. Instead, you should write\n// __launch_bounds__(OF_MAX_THREADS_PER_BLOCK(a), OF_MIN_BLOCKS_PER_SM(a, b)),\n// which will also properly respect limits on old architectures.\n#define OF_MAX_THREADS_PER_BLOCK(val) \\\n  (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) : CUDA_THREADS_PER_BLOCK_FALLBACK)\n#define OF_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm)         \\\n  ((((threads_per_block) * (blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) \\\n        ? 
(blocks_per_sm)                                              \\\n        : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block)-1) / (threads_per_block))))\n// OF_LAUNCH_BOUNDS is analogous to __launch_bounds__\n#define OF_LAUNCH_BOUNDS_0 \\\n  __launch_bounds__(256, 4)  // default launch bounds that should give good occupancy and\n                             // versatility across all architectures.\n#define OF_LAUNCH_BOUNDS_1(max_threads_per_block) \\\n  __launch_bounds__((OF_MAX_THREADS_PER_BLOCK((max_threads_per_block))))\n#define OF_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm)     \\\n  __launch_bounds__((OF_MAX_THREADS_PER_BLOCK((max_threads_per_block))), \\\n                    (OF_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm))))\n#endif\n\n#endif  // ONEFLOW_USER_KERNELS_FUSED_RNN_CELL_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/fused_scale_mask_bias_softmax.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/cuda/softmax.cuh\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/framework/user_op_tensor.h\"\n\nnamespace oneflow {\nnamespace {\ntemplate<typename SRC, typename DST>\nstruct LoadWithBias {\n  LoadWithBias(const SRC* x_ptr, const SRC* mask_ptr, const SRC* bias_ptr, const SRC scale,\n               int64_t row_stride, int64_t bias_stride, int64_t row_size)\n      : x_ptr_(x_ptr),\n        mask_ptr_(mask_ptr),\n        bias_ptr_(bias_ptr),\n        scale_(scale),\n        row_stride_(row_stride),\n        bias_stride_(bias_stride),\n        row_size_(row_size) {}\n  template<int N>\n  __device__ void load(DST* dst, int64_t row, int64_t col) const {\n    cuda::softmax::Pack<SRC, N> x;\n    const int64_t offset = (row * row_size_ + col) / N;\n    x.storage = *(reinterpret_cast<const cuda::softmax::PackType<SRC, N>*>(x_ptr_) + offset);\n    cuda::softmax::Pack<SRC, N> mask;\n    const int64_t m_offset = (row / row_stride_ * row_size_ + col) / N;\n    mask.storage =\n        *(reinterpret_cast<const cuda::softmax::PackType<SRC, N>*>(mask_ptr_) + m_offset);\n    cuda::softmax::Pack<SRC, N> bias;\n    /*\n    1). bias_stride_ = 0 for bias: [1, num_heads, seqlen_q, seqlen_kv]\n                             x:    [batch_size, num_heads, seqlen_q, seqlen_kv]\n    2). bias_stride_ > 0 for bias: [ensemble_batch, 1, num_heads, seqlen_q, seqlen_kv]\n                             x:    [ensemble_batch, batch_size, num_heads, seqlen_q, seqlen_kv]\n        here, bias_stride_ = batch_size, row_stride_ = num_heads * seqlen_q\n        x could be viewed as [B1, B2, B3] and bias could be viewed as [B1, 1, B3] where\n        B1 = ensemble_batch, B2 = batch_size = bias_stride_, B3 = num_heads * seqlen_q = row_stride_\n        For row in range [0, B1 * B2 * B3) {[0, ensemble_batch * batch_size * num_heads * seqlen_q]}\n        b1 = row/(B2*B3), b2=(row%(B2*B3)/B3), b3 = row%B3, after broadcast b2 will be 0 for bias.\n        And finally the correspoding (broadcast) row of bias will be:\n        `b1 * B3 + b3 = row/(B2*B3) * B3 + row%B3\n        = row / (bias_stride_ * row_stride_) * row_stride_ + row % row_stride_`\n    */\n    int64_t bias_offset =\n        (bias_stride_ > 0)\n            ? 
((row / (bias_stride_ * row_stride_) * row_stride_ + row % row_stride_) * row_size_\n               + col)\n                  / N\n            : (row % row_stride_ * row_size_ + col) / N;\n    bias.storage =\n        *(reinterpret_cast<const cuda::softmax::PackType<SRC, N>*>(bias_ptr_) + bias_offset);\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      dst[i] = static_cast<DST>(x.elem[i]) * static_cast<DST>(scale_)\n               + static_cast<DST>(mask.elem[i]) + static_cast<DST>(bias.elem[i]);\n    }\n  }\n  const SRC* x_ptr_;\n  const SRC* mask_ptr_;\n  const SRC* bias_ptr_;\n  const SRC scale_;\n  int64_t row_stride_;\n  int64_t bias_stride_;\n  int64_t row_size_;\n};\n\ntemplate<typename SRC, typename DST>\nstruct LoadWithoutBias {\n  LoadWithoutBias(const SRC* x_ptr, const SRC* mask_ptr, const SRC scale, int64_t row_stride,\n                  int64_t row_size)\n      : x_ptr_(x_ptr),\n        mask_ptr_(mask_ptr),\n        scale_(scale),\n        row_stride_(row_stride),\n        row_size_(row_size) {}\n  template<int N>\n  __device__ void load(DST* dst, int64_t row, int64_t col) const {\n    cuda::softmax::Pack<SRC, N> x;\n    const int64_t offset = (row * row_size_ + col) / N;\n    x.storage = *(reinterpret_cast<const cuda::softmax::PackType<SRC, N>*>(x_ptr_) + offset);\n    cuda::softmax::Pack<SRC, N> mask;\n    const int64_t m_offset = (row / row_stride_ * row_size_ + col) / N;\n    mask.storage =\n        *(reinterpret_cast<const cuda::softmax::PackType<SRC, N>*>(mask_ptr_) + m_offset);\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      dst[i] =\n          static_cast<DST>(x.elem[i]) * static_cast<DST>(scale_) + static_cast<DST>(mask.elem[i]);\n    }\n  }\n  const SRC* x_ptr_;\n  const SRC* mask_ptr_;\n  const SRC scale_;\n  int64_t row_stride_;\n  int64_t row_size_;\n};\n\ntemplate<typename T, typename ComputeType = typename cuda::softmax::DefaultComputeType<T>::type>\nvoid LaunchFusedSoftmaxForwardKernel(cudaStream_t stream, T* out, const T* x, const T* mask,\n                                     const T* bias, T scale, const int64_t row_stride,\n                                     const int64_t bias_stride, const int64_t rows,\n                                     const int64_t row_size) {\n  cuda::softmax::DirectStore<ComputeType, T> store(out, row_size);\n  if (bias != nullptr) {\n    LoadWithBias<T, ComputeType> load(x, mask, bias, scale, row_stride, bias_stride, row_size);\n    OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(\n        stream, load, store, rows, row_size)));\n  } else {\n    LoadWithoutBias<T, ComputeType> load(x, mask, scale, row_stride, row_size);\n    OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(\n        stream, load, store, rows, row_size)));\n  }\n};\n\ntemplate<typename SRC, typename DST>\nstruct GradStore {\n  GradStore(DST* dx, const SRC scale, int64_t row_size)\n      : dx(dx), scale(scale), row_size(row_size) {}\n  template<int N>\n  __device__ void store(const SRC* dout, int64_t row, int64_t col) const {\n    cuda::softmax::Pack<DST, N> x;\n    const int64_t offset = (row * row_size + col) / N;\n#pragma unroll\n    for (int i = 0; i < N; ++i) { x.elem[i] = static_cast<DST>(dout[i]) * static_cast<DST>(scale); }\n    *(reinterpret_cast<cuda::softmax::PackType<DST, N>*>(dx) + offset) = x.storage;\n  }\n  DST* dx;\n  const SRC scale;\n  int64_t row_size;\n};\n\ntemplate<typename T, typename ComputeType = typename 
cuda::softmax::DefaultComputeType<T>::type>\nvoid LaunchSoftmaxBackwardKernel(cudaStream_t stream, T* dx, const T* y, const T* dy, T scale,\n                                 const int64_t rows, const int64_t row_size) {\n  GradStore<ComputeType, T> store(dx, scale, row_size);\n  cuda::softmax::DirectLoad<T, ComputeType> load_y(y, row_size);\n  cuda::softmax::DirectLoad<T, ComputeType> load_dy(dy, row_size);\n  OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad<decltype(load_y), decltype(load_dy),\n                                                    decltype(store), ComputeType>(\n      stream, load_y, load_dy, store, rows, row_size)));\n};\n\n}  // namespace\n\ntemplate<typename T>\nclass FusedScaleMaskBiasSoftmaxKernel final : public user_op::OpKernel {\n public:\n  FusedScaleMaskBiasSoftmaxKernel() = default;\n  ~FusedScaleMaskBiasSoftmaxKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    const T scale = ctx->Attr<float>(\"scale\");\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    auto x_shape = x->shape_view();\n    auto axes = x_shape.NumAxes();\n    /*\n     * axes=3 for x: [batch_size, num_heads, seq], mask: [batch_size, 1, seq], no bias here\n     * axes=4 for x: [batch_size, num_heads, seq_len_q, seq_len_kv]\n     *            mask: [batch_size, 1, 1, seq_len_kv]\n     *            bias: [1, num_heads, seq_len_q, seq_len_kv]\n     * axes=5 for x: [ensemble_batch, batch_size, num_heads, seq_len_q, seq_len_kv]\n     *            mask: [ensemble_batch, batch_size, 1, 1, seq_len_kv]\n     *            bias: [ensemble_batch, 1, num_heads, seq_len_q, seq_len_kv]\n     * `axes=5` is equivalent to `axes=4` when ensemble_batch = 1.\n     *\n     * row_stride is used for computing `mask` stride and\n     * bias_stride for computing `bias` stride\n     * row_stride is num_heads (for `axes=3`) or num_heads * seq_len_q (for `axes=4` & `axes=5`)\n     * bias_stride is 0 (for `axes=4`) or batch_size (for `axes=5`)\n     * row_size = seq_len_kv (the last dimension of `x`)\n     */\n    CHECK(axes == 3 || axes == 4 || axes == 5);\n    auto mask_shape = mask->shape_view();\n    CHECK(mask_shape.NumAxes() == axes);\n    const int row_size = x_shape.At(axes - 1);\n    const int rows = x_shape.elem_cnt() / row_size;\n    int row_stride = 1;\n    for (int i = axes - 2; i >= 0; i--) {\n      if (mask_shape.At(i) == 1)\n        row_stride *= x_shape.At(i);\n      else\n        break;\n    }\n\n    user_op::Tensor* bias = nullptr;\n    int64_t bias_stride = 0;\n    if (ctx->has_input(\"bias\", 0)) {\n      bias = ctx->Tensor4ArgNameAndIndex(\"bias\", 0);\n      if (axes == 5 && x_shape.At(0) != 1) bias_stride = x_shape.At(1);\n    }\n    LaunchFusedSoftmaxForwardKernel<T>(ctx->stream()->As<ep::CudaStream>()->cuda_stream(),\n                                       out->mut_dptr<T>(), x->dptr<T>(), mask->dptr<T>(),\n                                       ctx->has_input(\"bias\", 0) ? 
bias->dptr<T>() : nullptr, scale,\n                                       row_stride, bias_stride, rows, row_size);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_SCALE_MASK_BIAS_SOFTMAX_KERNEL_GPU(dtype)       \\\n  REGISTER_USER_KERNEL(\"fused_scale_mask_bias_softmax\")                \\\n      .SetCreateFn<FusedScaleMaskBiasSoftmaxKernel<dtype>>()           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_SCALE_MASK_BIAS_SOFTMAX_KERNEL_GPU(half)\n#if CUDA_VERSION >= 11000\nREGISTER_FUSED_SCALE_MASK_BIAS_SOFTMAX_KERNEL_GPU(nv_bfloat16)\n#endif\nREGISTER_FUSED_SCALE_MASK_BIAS_SOFTMAX_KERNEL_GPU(float)\n\ntemplate<typename T>\nclass FusedScaleMaskBiasSoftmaxGradKernel final : public user_op::OpKernel {\n public:\n  FusedScaleMaskBiasSoftmaxGradKernel() = default;\n  ~FusedScaleMaskBiasSoftmaxGradKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const T scale = ctx->Attr<float>(\"scale\");\n    auto y_shape = y->shape_view();\n\n    const int64_t axes = y_shape.NumAxes();\n    int64_t row_size = y_shape.At(axes - 1);\n    int64_t rows = y_shape.elem_cnt() / row_size;\n\n    LaunchSoftmaxBackwardKernel<T>(ctx->stream()->As<ep::CudaStream>()->cuda_stream(),\n                                   dx->mut_dptr<T>(), y->dptr<T>(), dy->dptr<T>(), scale, rows,\n                                   row_size);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_SCALE_MASK_BIAS_SOFTMAX_GRAD_KERNEL_GPU(dtype)  \\\n  REGISTER_USER_KERNEL(\"fused_scale_mask_bias_softmax_grad\")           \\\n      .SetCreateFn<FusedScaleMaskBiasSoftmaxGradKernel<dtype>>()       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_SCALE_MASK_BIAS_SOFTMAX_GRAD_KERNEL_GPU(half)\n#if CUDA_VERSION >= 11000\nREGISTER_FUSED_SCALE_MASK_BIAS_SOFTMAX_GRAD_KERNEL_GPU(nv_bfloat16)\n#endif\nREGISTER_FUSED_SCALE_MASK_BIAS_SOFTMAX_GRAD_KERNEL_GPU(float)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_scale_mask_softmax.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/cuda/softmax.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/user/kernels/fused_softmax.cuh\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename ComputeType, typename MASK, size_t num_dims>\nvoid LaunchBroadcastForwardKernel(cudaStream_t stream, const T* x, T* y, const MASK* mask,\n                                  const int64_t elem_cnt, const int64_t rows, const int64_t cols,\n                                  const float fill, const float scale, const int64_t* input_dims,\n                                  const int64_t* mask_dims) {\n  NdIndexOffsetHelper<int32_t, num_dims> input_index_helper(input_dims);\n  NdIndexOffsetHelper<int32_t, num_dims> mask_index_helper(mask_dims);\n  cuda::fused_softmax::BroadcastMaskSoftmaxParams<num_dims, int32_t> params;\n  params.src_index_helper = input_index_helper;\n  params.mask_index_helper = mask_index_helper;\n  params.mask_dims = mask_dims;\n  params.row_size = cols;\n  params.fill = fill;\n  params.scale = scale;\n  cuda::fused_softmax::BroadcastScaleMaskLoad<T, ComputeType, MASK, num_dims, int32_t> load(x, mask,\n                                                                                            params);\n  cuda::softmax::DirectStore<ComputeType, T> store(y, cols);\n  OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(\n      stream, load, store, rows, cols)));\n}\n\ntemplate<typename T, typename ComputeType, typename MASK>\nvoid LaunchElementwiseForwardKernel(cudaStream_t stream, const T* x, T* y, const MASK* mask,\n                                    const int64_t rows, const int64_t cols, const float fill,\n                                    const float scale) {\n  cuda::fused_softmax::ElementwiseMaskSoftmaxParams params;\n  params.row_size = cols;\n  params.fill = fill;\n  params.scale = scale;\n  cuda::fused_softmax::ElementwiseScaleMaskLoad<T, ComputeType, MASK> load(x, mask, params);\n  cuda::softmax::DirectStore<ComputeType, T> store(y, cols);\n  OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(\n      stream, load, store, rows, cols)));\n}\n\ntemplate<typename T, typename ComputeType, typename MASK, size_t num_dims>\nvoid LaunchBroadcastBackwardKernel(cudaStream_t stream, const T* y, const T* dy, T* dx,\n                                   const MASK* mask, const int64_t elem_cnt, const int64_t rows,\n                                   const int64_t cols, const float fill, const float scale,\n                                   const int64_t* input_dims, const int64_t* mask_dims) {\n  NdIndexOffsetHelper<int32_t, num_dims> input_index_helper(input_dims);\n  NdIndexOffsetHelper<int32_t, num_dims> mask_index_helper(mask_dims);\n  cuda::fused_softmax::BroadcastMaskSoftmaxParams<num_dims, int32_t> params;\n  
params.src_index_helper = input_index_helper;\n  params.mask_index_helper = mask_index_helper;\n  params.mask_dims = mask_dims;\n  params.row_size = cols;\n  params.fill = fill;\n  params.scale = scale;\n  cuda::softmax::DirectLoad<T, ComputeType> load_y(y, cols);\n  cuda::softmax::DirectLoad<T, ComputeType> load_dy(dy, cols);\n  cuda::fused_softmax::BroadcastScaleMaskStore<ComputeType, T, MASK, num_dims, int32_t> store(\n      dx, mask, params);\n  OF_CUDA_CHECK((\n      cuda::softmax::DispatchSoftmaxGrad<decltype(load_y), decltype(load_dy), decltype(store),\n                                         ComputeType>(stream, load_y, load_dy, store, rows, cols)));\n}\n\ntemplate<typename T, typename ComputeType, typename MASK>\nvoid LaunchElementwiseBackwardKernel(cudaStream_t stream, const T* y, const T* dy, T* dx,\n                                     const MASK* mask, const int64_t rows, const int64_t cols,\n                                     const float fill, const float scale) {\n  cuda::fused_softmax::ElementwiseMaskSoftmaxParams params;\n  params.row_size = cols;\n  params.fill = fill;\n  params.scale = scale;\n  cuda::softmax::DirectLoad<T, ComputeType> load_y(y, cols);\n  cuda::softmax::DirectLoad<T, ComputeType> load_dy(dy, cols);\n  cuda::fused_softmax::ElementwiseScaleMaskStore<ComputeType, T, MASK> store(dx, mask, params);\n  OF_CUDA_CHECK((\n      cuda::softmax::DispatchSoftmaxGrad<decltype(load_y), decltype(load_dy), decltype(store),\n                                         ComputeType>(stream, load_y, load_dy, store, rows, cols)));\n}\n\nconstexpr int32_t kMaxNumDims = 5;\n\ntemplate<typename T, typename MASK>\nclass FusedScaleMaskSoftmaxKernel final : public user_op::OpKernel {\n public:\n  FusedScaleMaskSoftmaxKernel() = default;\n  ~FusedScaleMaskSoftmaxKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const float mask_fill_value = ctx->Attr<float>(\"mask_fill_value\");\n    const float scale_value = ctx->Attr<float>(\"scale_value\");\n    const ShapeView& x_shape = x->shape_view();\n    const ShapeView& mask_shape = mask->shape_view();\n    CHECK_GE(x_shape.NumAxes(), 2);\n    const int64_t elem_cnt = x_shape.elem_cnt();\n    const int64_t cols = x_shape.At(x_shape.NumAxes() - 1);\n    const int64_t rows = x_shape.Count(0, x_shape.NumAxes() - 1);\n    const size_t num_input_dims = x_shape.NumAxes();\n    const int64_t* input_dims = x_shape.ptr();\n    const size_t num_mask_dims = mask_shape.NumAxes();\n    const int64_t* mask_dims = mask_shape.ptr();\n    using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;\n\n    size_t simplified_num_dims = 0;\n    int64_t simplified_input_dims[kMaxNumDims];\n    int64_t simplified_mask_dims[kMaxNumDims];\n    cuda::fused_softmax::SimplifyBroadcastDims(num_input_dims, input_dims, num_mask_dims, mask_dims,\n                                               &simplified_num_dims, simplified_input_dims,\n                                               simplified_mask_dims);\n    if (simplified_num_dims == 1) {\n      LaunchElementwiseForwardKernel<T, ComputeType, MASK>(\n          ctx->stream()->As<ep::CudaStream>()->cuda_stream(), x->dptr<T>(), y->mut_dptr<T>(),\n          mask->dptr<MASK>(), rows, cols, 
mask_fill_value, scale_value);\n    }\n#define DEFINE_ONE_ELIF(dims)                                                               \\\n  else if (simplified_num_dims == dims) {                                                   \\\n    LaunchBroadcastForwardKernel<T, ComputeType, MASK, dims>(                               \\\n        ctx->stream()->As<ep::CudaStream>()->cuda_stream(), x->dptr<T>(), y->mut_dptr<T>(), \\\n        mask->dptr<MASK>(), elem_cnt, rows, cols, mask_fill_value, scale_value,             \\\n        simplified_input_dims, simplified_mask_dims);                                       \\\n  }\n    DEFINE_ONE_ELIF(2)\n    DEFINE_ONE_ELIF(3)\n    DEFINE_ONE_ELIF(4)\n#undef DEFINE_ONE_ELIF\n    else {\n      UNIMPLEMENTED();\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T, typename MASK>\nclass FusedScaleMaskSoftmaxGradKernel final : public user_op::OpKernel {\n public:\n  FusedScaleMaskSoftmaxGradKernel() = default;\n  ~FusedScaleMaskSoftmaxGradKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const float scale_value = ctx->Attr<float>(\"scale_value\");\n    const float mask_fill_value = static_cast<float>(0.0);\n    const ShapeView& dy_shape = dy->shape_view();\n    const ShapeView& mask_shape = mask->shape_view();\n    CHECK_GE(dy_shape.NumAxes(), 2);\n    const int64_t elem_cnt = dy_shape.elem_cnt();\n    const int64_t cols = dy_shape.At(dy_shape.NumAxes() - 1);\n    const int64_t rows = dy_shape.Count(0, dy_shape.NumAxes() - 1);\n    const int64_t* input_dims = dy_shape.ptr();\n    const size_t num_input_dims = dy_shape.NumAxes();\n    const int64_t* mask_dims = mask_shape.ptr();\n    const size_t num_mask_dims = mask_shape.NumAxes();\n\n    using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;\n\n    size_t simplified_num_dims = 0;\n    int64_t simplified_input_dims[kMaxNumDims];\n    int64_t simplified_mask_dims[kMaxNumDims];\n    cuda::fused_softmax::SimplifyBroadcastDims(num_input_dims, input_dims, num_mask_dims, mask_dims,\n                                               &simplified_num_dims, simplified_input_dims,\n                                               simplified_mask_dims);\n    if (simplified_num_dims == 1) {\n      LaunchElementwiseBackwardKernel<T, ComputeType, MASK>(\n          ctx->stream()->As<ep::CudaStream>()->cuda_stream(), y->dptr<T>(), dy->dptr<T>(),\n          dx->mut_dptr<T>(), mask->dptr<MASK>(), rows, cols, mask_fill_value, scale_value);\n    }\n#define DEFINE_ONE_ELIF(dims)                                                                      \\\n  else if (simplified_num_dims == dims) {                                                          \\\n    LaunchBroadcastBackwardKernel<T, ComputeType, MASK, dims>(                                     \\\n        ctx->stream()->As<ep::CudaStream>()->cuda_stream(), y->dptr<T>(), dy->dptr<T>(),           \\\n        dx->mut_dptr<T>(), mask->dptr<MASK>(), elem_cnt, rows, cols, mask_fill_value, scale_value, \\\n        simplified_input_dims, simplified_mask_dims);                                              \\\n  }\n    
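// e.g. x of shape (b, h, sq, sk) with mask (b, 1, 1, sk) is collapsed by\n    // SimplifyBroadcastDims into x (b, h * sq, sk), mask (b, 1, sk), which then\n    // dispatches to the num_dims == 3 case.\n    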
DEFINE_ONE_ELIF(2)\n    DEFINE_ONE_ELIF(3)\n    DEFINE_ONE_ELIF(4)\n#undef DEFINE_ONE_ELIF\n    else {\n      UNIMPLEMENTED();\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n#define REGISTER_FUSED_SCALE_MASK_SOFTMAX_CUDA_KERNEL(dtype, mask_dtype)              \\\n  REGISTER_USER_KERNEL(\"fused_scale_mask_softmax\")                                    \\\n      .SetCreateFn<FusedScaleMaskSoftmaxKernel<dtype, mask_dtype>>()                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"mask\", 0) == GetDataType<mask_dtype>::value));\n\nREGISTER_FUSED_SCALE_MASK_SOFTMAX_CUDA_KERNEL(half, bool)\nREGISTER_FUSED_SCALE_MASK_SOFTMAX_CUDA_KERNEL(float, bool)\n#undef REGISTER_FUSED_SCALE_MASK_SOFTMAX_CUDA_KERNEL\n\n#define REGISTER_FUSED_SCALE_MASK_SOFTMAX_GRAD_KERNEL(dtype, mask_dtype)               \\\n  REGISTER_USER_KERNEL(\"fused_scale_mask_softmax_grad\")                                \\\n      .SetCreateFn<FusedScaleMaskSoftmaxGradKernel<dtype, mask_dtype>>()               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                 \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"mask\", 0) == GetDataType<mask_dtype>::value));\n\nREGISTER_FUSED_SCALE_MASK_SOFTMAX_GRAD_KERNEL(half, bool)\nREGISTER_FUSED_SCALE_MASK_SOFTMAX_GRAD_KERNEL(float, bool)\n#undef REGISTER_FUSED_SCALE_MASK_SOFTMAX_GRAD_KERNEL\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_scale_mask_softmax_dropout.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/cuda/softmax.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/user/kernels/fused_softmax.cuh\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename ComputeType, typename MASK, int num_dims>\nvoid LaunchBroadcastForwardKernel(cudaStream_t stream, const T* x, T* y, T* softmax_y,\n                                  const MASK* mask, const bool* dropout_mask,\n                                  const int64_t elem_cnt, const int64_t rows, const int64_t cols,\n                                  const float fill, const float scale, const float dropout_scale,\n                                  const int64_t* input_dims, const int64_t* mask_dims) {\n  cuda::fused_softmax::DropoutStore<ComputeType, T> store(y, softmax_y, dropout_mask, cols,\n                                                          dropout_scale);\n  NdIndexOffsetHelper<int32_t, num_dims> input_index_helper(input_dims);\n  NdIndexOffsetHelper<int32_t, num_dims> mask_index_helper(mask_dims);\n  cuda::fused_softmax::BroadcastMaskSoftmaxParams<num_dims, int32_t> params;\n  params.src_index_helper = input_index_helper;\n  params.mask_index_helper = mask_index_helper;\n  params.mask_dims = mask_dims;\n  params.row_size = cols;\n  params.fill = fill;\n  params.scale = scale;\n  cuda::fused_softmax::BroadcastScaleMaskLoad<T, ComputeType, MASK, num_dims, int32_t> load(x, mask,\n                                                                                            params);\n  OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(\n      stream, load, store, rows, cols)));\n}\n\ntemplate<typename T, typename ComputeType, typename MASK>\nvoid LaunchElementwiseForwardKernel(cudaStream_t stream, const T* x, T* y, T* softmax_y,\n                                    const MASK* mask, const bool* dropout_mask, const int64_t rows,\n                                    const int64_t cols, const float fill, const float scale,\n                                    const float dropout_scale) {\n  cuda::fused_softmax::ElementwiseMaskSoftmaxParams params;\n  params.row_size = cols;\n  params.fill = fill;\n  params.scale = scale;\n  cuda::fused_softmax::ElementwiseScaleMaskLoad<T, ComputeType, MASK> load(x, mask, params);\n  cuda::fused_softmax::DropoutStore<ComputeType, T> store(y, softmax_y, dropout_mask, cols,\n                                                          dropout_scale);\n  OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(\n      stream, load, store, rows, cols)));\n}\n\ntemplate<typename T, typename ComputeType, typename MASK, int num_dims>\nvoid LaunchBroadcastBackwardKernel(cudaStream_t stream, const T* softmax_y, const T* dy, T* dx,\n                                   const MASK* mask, const bool* dropout_mask,\n                     
              const int64_t elem_cnt, const int64_t rows, const int64_t cols,\n                                   const float fill, const float scale, const float dropout_scale,\n                                   const int64_t* input_dims, const int64_t* mask_dims) {\n  cuda::fused_softmax::MaskScaleLoad<T, ComputeType> load_dy(dy, dropout_mask, cols, dropout_scale);\n  NdIndexOffsetHelper<int32_t, num_dims> input_index_helper(input_dims, num_dims);\n  NdIndexOffsetHelper<int32_t, num_dims> mask_index_helper(mask_dims, num_dims);\n  cuda::fused_softmax::BroadcastMaskSoftmaxParams<num_dims, int32_t> params;\n  params.src_index_helper = input_index_helper;\n  params.mask_index_helper = mask_index_helper;\n  params.mask_dims = mask_dims;\n  params.row_size = cols;\n  params.fill = fill;\n  params.scale = scale;\n  cuda::softmax::DirectLoad<T, ComputeType> load_softmax_y(softmax_y, cols);\n  cuda::fused_softmax::BroadcastScaleMaskStore<ComputeType, T, MASK, num_dims, int32_t> store(\n      dx, mask, params);\n  OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad<decltype(load_softmax_y), decltype(load_dy),\n                                                    decltype(store), ComputeType>(\n      stream, load_softmax_y, load_dy, store, rows, cols)));\n}\n\ntemplate<typename T, typename ComputeType, typename MASK>\nvoid LaunchElementwiseBackwardKernel(cudaStream_t stream, const T* softmax_y, const T* dy, T* dx,\n                                     const MASK* mask, const bool* dropout_mask, const int64_t rows,\n                                     const int64_t cols, const float fill, const float scale,\n                                     const float dropout_scale) {\n  cuda::fused_softmax::ElementwiseMaskSoftmaxParams params;\n  params.row_size = cols;\n  params.fill = fill;\n  params.scale = scale;\n  cuda::softmax::DirectLoad<T, ComputeType> load_softmax_y(softmax_y, cols);\n  cuda::fused_softmax::MaskScaleLoad<T, ComputeType> load_dy(dy, dropout_mask, cols, dropout_scale);\n  cuda::fused_softmax::ElementwiseScaleMaskStore<ComputeType, T, MASK> store(dx, mask, params);\n  OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad<decltype(load_softmax_y), decltype(load_dy),\n                                                    decltype(store), ComputeType>(\n      stream, load_softmax_y, load_dy, store, rows, cols)));\n}\n\nconstexpr int32_t kMaxNumDims = 5;\n\ntemplate<typename T, typename MASK>\nclass FusedScaleMaskSoftmaxDropoutKernel final : public user_op::OpKernel {\n public:\n  FusedScaleMaskSoftmaxDropoutKernel() = default;\n  ~FusedScaleMaskSoftmaxDropoutKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    const user_op::Tensor* dropout_mask = ctx->Tensor4ArgNameAndIndex(\"dropout_mask\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const float mask_fill_value = ctx->Attr<float>(\"mask_fill_value\");\n    const float scale_value = ctx->Attr<float>(\"scale_value\");\n    const float dropout_scale_value = ctx->Attr<float>(\"dropout_scale_value\");\n    user_op::Tensor* softmax_y = ctx->Tensor4ArgNameAndIndex(\"softmax_y\", 0);\n    const ShapeView& x_shape = x->shape_view();\n    const ShapeView& mask_shape = mask->shape_view();\n    CHECK_GE(x_shape.NumAxes(), 2);\n    const int64_t elem_cnt = x_shape.elem_cnt();\n    
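// Softmax reduces over the last axis: cols is the row length, rows the\n    // product of all leading dims; the mask may broadcast against any of them.\n    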
const int64_t cols = x_shape.At(x_shape.NumAxes() - 1);\n    const int64_t rows = x_shape.Count(0, x_shape.NumAxes() - 1);\n    const size_t num_input_dims = x_shape.NumAxes();\n    const int64_t* input_dims = x_shape.ptr();\n    const size_t num_mask_dims = mask_shape.NumAxes();\n    const int64_t* mask_dims = mask_shape.ptr();\n    using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;\n\n    size_t simplified_num_dims = 0;\n    int64_t simplified_input_dims[kMaxNumDims];\n    int64_t simplified_mask_dims[kMaxNumDims];\n    cuda::fused_softmax::SimplifyBroadcastDims(num_input_dims, input_dims, num_mask_dims, mask_dims,\n                                               &simplified_num_dims, simplified_input_dims,\n                                               simplified_mask_dims);\n    if (simplified_num_dims == 1) {\n      LaunchElementwiseForwardKernel<T, ComputeType, MASK>(\n          ctx->stream()->As<ep::CudaStream>()->cuda_stream(), x->dptr<T>(), y->mut_dptr<T>(),\n          softmax_y->mut_dptr<T>(), mask->dptr<MASK>(), dropout_mask->dptr<bool>(), rows, cols,\n          mask_fill_value, scale_value, dropout_scale_value);\n    }\n\n#define DEFINE_ONE_ELIF(dims)                                                                     \\\n  else if (simplified_num_dims == dims) {                                                         \\\n    LaunchBroadcastForwardKernel<T, ComputeType, MASK, dims>(                                     \\\n        ctx->stream()->As<ep::CudaStream>()->cuda_stream(), x->dptr<T>(), y->mut_dptr<T>(),       \\\n        softmax_y->mut_dptr<T>(), mask->dptr<MASK>(), dropout_mask->dptr<bool>(), elem_cnt, rows, \\\n        cols, mask_fill_value, scale_value, dropout_scale_value, simplified_input_dims,           \\\n        simplified_mask_dims);                                                                    \\\n  }\n    DEFINE_ONE_ELIF(2)\n    DEFINE_ONE_ELIF(3)\n    DEFINE_ONE_ELIF(4)\n#undef DEFINE_ONE_ELIF\n    else {\n      UNIMPLEMENTED();\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T, typename MASK>\nclass FusedScaleMaskSoftmaxDropoutGradKernel final : public user_op::OpKernel {\n public:\n  FusedScaleMaskSoftmaxDropoutGradKernel() = default;\n  ~FusedScaleMaskSoftmaxDropoutGradKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* softmax_y = ctx->Tensor4ArgNameAndIndex(\"softmax_y\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    const user_op::Tensor* dropout_mask = ctx->Tensor4ArgNameAndIndex(\"dropout_mask\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const float mask_fill_value = static_cast<float>(0.0);\n    const float scale_value = ctx->Attr<float>(\"scale_value\");\n    const float dropout_scale_value = ctx->Attr<float>(\"dropout_scale_value\");\n    const ShapeView& dy_shape = dy->shape_view();\n    const int64_t elem_cnt = dy_shape.elem_cnt();\n    const ShapeView& mask_shape = mask->shape_view();\n    CHECK_GE(dy_shape.NumAxes(), 2);\n    const int64_t cols = dy_shape.At(dy_shape.NumAxes() - 1);\n    const int64_t rows = dy_shape.Count(0, dy_shape.NumAxes() - 1);\n    const int64_t* input_dims = dy_shape.ptr();\n    const size_t num_input_dims = dy_shape.NumAxes();\n    const int64_t* mask_dims = 
mask_shape.ptr();\n    const size_t num_mask_dims = mask_shape.NumAxes();\n\n    using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;\n\n    size_t simplified_num_dims = 0;\n    int64_t simplified_input_dims[kMaxNumDims];\n    int64_t simplified_mask_dims[kMaxNumDims];\n    cuda::fused_softmax::SimplifyBroadcastDims(num_input_dims, input_dims, num_mask_dims, mask_dims,\n                                               &simplified_num_dims, simplified_input_dims,\n                                               simplified_mask_dims);\n    if (simplified_num_dims == 1) {\n      LaunchElementwiseBackwardKernel<T, ComputeType, MASK>(\n          ctx->stream()->As<ep::CudaStream>()->cuda_stream(), softmax_y->dptr<T>(), dy->dptr<T>(),\n          dx->mut_dptr<T>(), mask->dptr<MASK>(), dropout_mask->dptr<bool>(), rows, cols,\n          mask_fill_value, scale_value, dropout_scale_value);\n    }\n#define DEFINE_ONE_ELIF(dims)                                                                    \\\n  else if (simplified_num_dims == dims) {                                                        \\\n    LaunchBroadcastBackwardKernel<T, ComputeType, MASK, dims>(                                   \\\n        ctx->stream()->As<ep::CudaStream>()->cuda_stream(), softmax_y->dptr<T>(), dy->dptr<T>(), \\\n        dx->mut_dptr<T>(), mask->dptr<MASK>(), dropout_mask->dptr<bool>(), elem_cnt, rows, cols, \\\n        static_cast<float>(0.0), ctx->Attr<float>(\"scale_value\"),                                \\\n        ctx->Attr<float>(\"dropout_scale_value\"), simplified_input_dims, simplified_mask_dims);   \\\n  }\n    DEFINE_ONE_ELIF(2)\n    DEFINE_ONE_ELIF(3)\n    DEFINE_ONE_ELIF(4)\n#undef DEFINE_ONE_ELIF\n    else {\n      UNIMPLEMENTED();\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n#define REGISTER_FUSED_SCALE_MASK_SOFTMAX_DROPOUT_CUDA_KERNEL(dtype, mask_dtype)      \\\n  REGISTER_USER_KERNEL(\"fused_scale_mask_softmax_dropout\")                            \\\n      .SetCreateFn<FusedScaleMaskSoftmaxDropoutKernel<dtype, mask_dtype>>()           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"mask\", 0) == GetDataType<mask_dtype>::value));\n\nREGISTER_FUSED_SCALE_MASK_SOFTMAX_DROPOUT_CUDA_KERNEL(half, bool)\nREGISTER_FUSED_SCALE_MASK_SOFTMAX_DROPOUT_CUDA_KERNEL(float, bool)\n#undef REGISTER_FUSED_SCALE_MASK_SOFTMAX_DROPOUT_CUDA_KERNEL\n\n#define REGISTER_FUSED_SCALE_MASK_SOFTMAX_DROPOUT_GRAD_KERNEL(dtype, mask_dtype)       \\\n  REGISTER_USER_KERNEL(\"fused_scale_mask_softmax_dropout_grad\")                        \\\n      .SetCreateFn<FusedScaleMaskSoftmaxDropoutGradKernel<dtype, mask_dtype>>()        \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                 \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"mask\", 0) == GetDataType<mask_dtype>::value));\n\nREGISTER_FUSED_SCALE_MASK_SOFTMAX_DROPOUT_GRAD_KERNEL(half, bool)\nREGISTER_FUSED_SCALE_MASK_SOFTMAX_DROPOUT_GRAD_KERNEL(float, bool)\n#undef REGISTER_FUSED_SCALE_MASK_SOFTMAX_DROPOUT_GRAD_KERNEL\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_self_attention_query_mul_key_and_value_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/slice_util.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ninline cublasOperation_t GetCublasOp(char op) {\n  switch (op) {\n    case 'n':\n    case 'N': {\n      return CUBLAS_OP_N;\n    }\n    case 't':\n    case 'T': {\n      return CUBLAS_OP_T;\n    }\n    case 'c':\n    case 'C': {\n      return CUBLAS_OP_C;\n    }\n    default: {\n      UNIMPLEMENTED();\n    }\n  }\n  return CUBLAS_OP_N;\n}\n\ntemplate<typename T>\nstruct CudaDataTypeTrait;\n\ntemplate<>\nstruct CudaDataTypeTrait<float> {\n  const static cudaDataType_t value = CUDA_R_32F;\n};\n\ntemplate<>\nstruct CudaDataTypeTrait<half> {\n  const static cudaDataType_t value = CUDA_R_16F;\n};\n\ntemplate<typename T>\nvoid CublasBatchGemm(ep::CudaStream* stream, char transa, char transb, int64_t m, int64_t n,\n                     int64_t k, T alpha, const T* a, int64_t lda, int64_t stridea, const T* b,\n                     int64_t ldb, int64_t strideb, T beta, T* c, int64_t ldc, int64_t stridec,\n                     int64_t batch_size) {\n  cublasOperation_t opa = GetCublasOp(transa);\n  cublasOperation_t opb = GetCublasOp(transb);\n  if (CUDA_VERSION >= 9010 && stream->cuda_arch() >= 500) {\n#if CUDA_VERSION >= 9010\n    cudaDataType_t data_type = CudaDataTypeTrait<T>::value;\n    OF_CUBLAS_CHECK(cublasGemmStridedBatchedEx(\n        stream->cublas_handle(), opa, opb, m, n, k, reinterpret_cast<const void*>(&alpha),\n        reinterpret_cast<const void*>(a), data_type, lda, stridea, reinterpret_cast<const void*>(b),\n        data_type, ldb, strideb, reinterpret_cast<const void*>(&beta), reinterpret_cast<void*>(c),\n        data_type, ldc, stridec, batch_size, data_type, CUBLAS_GEMM_DEFAULT));\n#else\n    UNIMPLEMENTED();\n#endif\n  }\n}\n\n#if CUDA_VERSION >= 9010\n\ntemplate<>\nvoid CublasBatchGemm<half>(ep::CudaStream* stream, char transa, char transb, int64_t m, int64_t n,\n                           int64_t k, half alpha, const half* a, int64_t lda, int64_t stridea,\n                           const half* b, int64_t ldb, int64_t strideb, half beta, half* c,\n                           int64_t ldc, int64_t stridec, int64_t batch_size) {\n  using comp_t = float;\n  cublasOperation_t opa = GetCublasOp(transa);\n  cublasOperation_t opb = GetCublasOp(transb);\n\n  if (stream->cuda_arch() >= 500) {\n    float alpha_f = static_cast<comp_t>(alpha);\n    float beta_f = static_cast<comp_t>(beta);\n#if CUDA_VERSION >= 11000\n    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;\n#else\n    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;\n#endif\n    cudaDataType_t data_type = CudaDataTypeTrait<half>::value;\n    cudaDataType_t comp_type = CudaDataTypeTrait<comp_t>::value;\n    
OF_CUBLAS_CHECK(cublasGemmStridedBatchedEx(\n        stream->cublas_handle(), opa, opb, m, n, k, &alpha_f, reinterpret_cast<const void*>(a),\n        data_type, lda, stridea, reinterpret_cast<const void*>(b), data_type, ldb, strideb, &beta_f,\n        reinterpret_cast<void*>(c), data_type, ldc, stridec, batch_size, comp_type, algo));\n  }\n}\n\ntemplate<>\nvoid CublasBatchGemm<float16>(ep::CudaStream* stream, char transa, char transb, int64_t m,\n                              int64_t n, int64_t k, float16 alpha, const float16* a, int64_t lda,\n                              int64_t stridea, const float16* b, int64_t ldb, int64_t strideb,\n                              float16 beta, float16* c, int64_t ldc, int64_t stridec,\n                              int64_t batch_size) {\n  CublasBatchGemm<half>(stream, transa, transb, m, n, k, static_cast<half>(alpha),\n                        reinterpret_cast<const half*>(a), lda, stridea,\n                        reinterpret_cast<const half*>(b), ldb, strideb, static_cast<half>(beta),\n                        reinterpret_cast<half*>(c), ldc, stridec, batch_size);\n}\n\n#endif  // CUDA_VERSION >= 9010\n\ntemplate<typename T>\nvoid BatchedGemm(ep::Stream* stream, char opa, char opb, int64_t m, int64_t n, int64_t k,\n                 float alpha, const T* a, int64_t lda, int64_t stridea, const T* b, int64_t ldb,\n                 int64_t strideb, float beta, T* c, int64_t ldc, int64_t stridec,\n                 int64_t batch_size) {\n  // swap m and n, a and b to convert from row-major to col-major\n  CublasBatchGemm<T>(stream->As<ep::CudaStream>(), opb, opa, n, m, k, static_cast<T>(alpha), b, ldb,\n                     strideb, a, lda, stridea, static_cast<T>(beta), c, ldc, stridec, batch_size);\n}\n\nSliceParams ConstructSliceParams4Value(int64_t seq_len, int64_t batch_size, int64_t num_heads,\n                                       int64_t head_size) {\n  // slice (s, b, n, 3, h) to (s, b, n, 1, h)\n  SliceParams params;\n  params.ndim = 4;\n  params.dims[0] = seq_len;\n  params.dims[1] = batch_size;\n  params.dims[2] = num_heads;\n  params.dims[3] = 3 * head_size;\n  params.start[0] = 0;\n  params.start[1] = 0;\n  params.start[2] = 0;\n  params.start[3] = 2 * head_size;\n  params.step[0] = 1;\n  params.step[1] = 1;\n  params.step[2] = 1;\n  params.step[3] = 1;\n  params.size[0] = seq_len;\n  params.size[1] = batch_size;\n  params.size[2] = num_heads;\n  params.size[3] = head_size;\n  return params;\n}\n\ntemplate<typename T>\nvoid TransposeGpu(ep::Stream* stream, DataType data_type, const ShapeView& in_shape,\n                  const ShapeView& out_shape, const std::vector<int32_t>& perm, const T* in,\n                  T* out) {\n  CHECK_EQ(in_shape.NumAxes(), out_shape.NumAxes());\n  int32_t num_axes = in_shape.NumAxes();\n  CHECK_EQ(num_axes, perm.size());\n  for (int i = 0; i < perm.size(); ++i) { CHECK_EQ(in_shape.At(perm[i]), out_shape.At(i)); }\n  auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(stream->device_type(),\n                                                                              in_shape.NumAxes());\n  CHECK(transpose);\n  transpose->Launch(stream, data_type, in_shape.NumAxes(), in_shape.ptr(), in, perm.data(), out);\n}\n\ntemplate<typename T>\nclass FusedSelfAttentionQueryMulKeyAndValueGpuKernel final : public user_op::OpKernel {\n public:\n  FusedSelfAttentionQueryMulKeyAndValueGpuKernel() = default;\n  ~FusedSelfAttentionQueryMulKeyAndValueGpuKernel() override = default;\n\n private:\n  using 
user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* h_tensor = ctx->Tensor4ArgNameAndIndex(\"hidden_states\", 0);\n    int64_t seq_len = h_tensor->shape_view().At(0);\n    int64_t batch_size = h_tensor->shape_view().At(1);\n    int64_t hidden_size = h_tensor->shape_view().At(2);\n    int64_t head_size = ctx->Attr<int64_t>(\"head_size\");\n    int64_t num_heads = hidden_size / (3 * head_size);\n    int64_t ld = batch_size * hidden_size;\n    int64_t stride = 3 * head_size;\n    int64_t k_offset = head_size;\n\n    // q * k: (sq, b, n, h) x (sk, b, n, h) => (b, n, sq, h) x (b, n, sk, h)\n    // => (b, n, sq, h) x (b, n, h, sk) -> (b, n, sq, sk)\n    float alpha = ctx->Attr<float>(\"alpha\");\n    user_op::Tensor* qmk_tensor = ctx->Tensor4ArgNameAndIndex(\"query_mul_key\", 0);\n    const T* q_dptr = h_tensor->dptr<T>();\n    const T* k_dptr = h_tensor->dptr<T>() + k_offset;\n    BatchedGemm<T>(ctx->stream(), 'N', 'T', seq_len, seq_len, head_size, alpha, q_dptr, ld, stride,\n                   k_dptr, ld, stride, 0.0f, qmk_tensor->mut_dptr<T>(), seq_len, seq_len * seq_len,\n                   batch_size * num_heads);\n\n    // slice v\n    user_op::Tensor* tmp_v_tensor = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    user_op::Tensor* v_tensor = ctx->Tensor4ArgNameAndIndex(\"value\", 0);\n    SliceParams params = ConstructSliceParams4Value(seq_len, batch_size, num_heads, head_size);\n    SliceKernelUtil<DeviceType::kCUDA, T>::Forward(ctx->stream(), params, h_tensor->dptr<T>(),\n                                                   tmp_v_tensor->mut_dptr<T>());\n    // v from (s, b, n, h) transpose to (b, n, s, h)\n    Shape value_shape({seq_len, batch_size, num_heads, head_size});\n    TransposeGpu<T>(ctx->stream(), h_tensor->data_type(), value_shape, v_tensor->shape_view(),\n                    {1, 2, 0, 3}, tmp_v_tensor->dptr<T>(), v_tensor->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass FusedSelfAttentionQueryMulKeyAndValueGradGpuKernel final : public user_op::OpKernel {\n public:\n  FusedSelfAttentionQueryMulKeyAndValueGradGpuKernel() = default;\n  ~FusedSelfAttentionQueryMulKeyAndValueGradGpuKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* v_grad_tensor = ctx->Tensor4ArgNameAndIndex(\"value_grad\", 0);\n    const user_op::Tensor* qmk_grad_tensor = ctx->Tensor4ArgNameAndIndex(\"query_mul_key_grad\", 0);\n    const user_op::Tensor* h_tensor = ctx->Tensor4ArgNameAndIndex(\"hidden_states\", 0);\n    user_op::Tensor* tmp_v_tensor = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    user_op::Tensor* h_grad_tensor = ctx->Tensor4ArgNameAndIndex(\"hidden_states_grad\", 0);\n\n    float alpha = ctx->Attr<float>(\"alpha\");\n    int64_t seq_len = h_grad_tensor->shape_view().At(0);\n    int64_t batch_size = h_grad_tensor->shape_view().At(1);\n    int64_t hidden_size = h_grad_tensor->shape_view().At(2);\n    int64_t num_heads = v_grad_tensor->shape_view().At(1);\n    int64_t head_size = v_grad_tensor->shape_view().At(3);\n    int64_t ld = batch_size * hidden_size;\n    int64_t stride = 3 * head_size;\n    CHECK_EQ(hidden_size, num_heads * stride);\n\n    // transpose from (b, n, s, h) to (s, b, n, h)\n    Shape value_shape({seq_len, batch_size, num_heads, head_size});\n    TransposeGpu<T>(ctx->stream(), 
v_grad_tensor->data_type(), v_grad_tensor->shape_view(),\n                    value_shape, {2, 0, 1, 3}, v_grad_tensor->dptr<T>(),\n                    tmp_v_tensor->mut_dptr<T>());\n    // slice v grad\n    SliceParams params = ConstructSliceParams4Value(seq_len, batch_size, num_heads, head_size);\n    SliceKernelUtil<DeviceType::kCUDA, T>::Backward(ctx->stream(), params, tmp_v_tensor->dptr<T>(),\n                                                    h_grad_tensor->mut_dptr<T>());\n\n    // grad_q = grad_qmk * k\n    // (b, n, sq, sk) x (b, n, sk, h) -> (b, n, s, h) <= (s, b, n, h) <= (s, b, n, 3, h)\n    const T* qmk_grad_dptr = qmk_grad_tensor->dptr<T>();\n    const T* k_dptr = h_tensor->dptr<T>() + head_size;\n    T* grad_q_dptr = h_grad_tensor->mut_dptr<T>();\n    BatchedGemm<T>(ctx->stream(), 'N', 'N', seq_len, head_size, seq_len, alpha, qmk_grad_dptr,\n                   seq_len, seq_len * seq_len, k_dptr, ld, stride, 0.0f, grad_q_dptr, ld, stride,\n                   batch_size * num_heads);\n    // grad_k = grad_qmk * q\n    // (b, n, sk, sq) x (b, n, sq, h) -> (b, n, sk, h) <= (s, b, n, h) <= (s, b, n, 3, h)\n    const T* q_dptr = h_tensor->dptr<T>();\n    T* grad_k_dptr = h_grad_tensor->mut_dptr<T>() + head_size;\n    BatchedGemm<T>(ctx->stream(), 'T', 'N', seq_len, head_size, seq_len, alpha, qmk_grad_dptr,\n                   seq_len, seq_len * seq_len, q_dptr, ld, stride, 0.0f, grad_k_dptr, ld, stride,\n                   batch_size * num_heads);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nsize_t InferTmpBufferSize(user_op::InferContext* ctx) {\n  const Shape& value_shape = ctx->OutputShape(\"value\", 0);\n  DataType value_dtype = ctx->OutputDType(\"value\", 0);\n  return value_shape.elem_cnt() * GetSizeOfDataType(value_dtype);\n}\n\nsize_t InferGradTmpBufferSize(user_op::InferContext* ctx) {\n  const Shape& value_shape = ctx->InputShape(\"value_grad\", 0);\n  DataType value_dtype = ctx->InputDType(\"value_grad\", 0);\n  return value_shape.elem_cnt() * GetSizeOfDataType(value_dtype);\n}\n\n}  // namespace\n\n#define REGISTER_FUSED_SELF_ATTENTION_QUERY_MUL_KEY_AND_VALUE_CUDA_KERNEL(dtype)                   \\\n  REGISTER_USER_KERNEL(\"fused_self_attention_query_mul_key_and_value\")                             \\\n      .SetCreateFn<FusedSelfAttentionQueryMulKeyAndValueGpuKernel<dtype>>()                        \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                             \\\n                       && (user_op::HobDataType(\"hidden_states\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn(InferTmpBufferSize);\n\n#define REGISTER_FUSED_SELF_ATTENTION_QUERY_MUL_KEY_AND_VALUE_GRAD_CUDA_KERNEL(dtype)              \\\n  REGISTER_USER_KERNEL(\"fused_self_attention_query_mul_key_and_value_grad\")                        \\\n      .SetCreateFn<FusedSelfAttentionQueryMulKeyAndValueGradGpuKernel<dtype>>()                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                             \\\n                       && (user_op::HobDataType(\"hidden_states\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn(InferGradTmpBufferSize);\n\nREGISTER_FUSED_SELF_ATTENTION_QUERY_MUL_KEY_AND_VALUE_CUDA_KERNEL(float)\nREGISTER_FUSED_SELF_ATTENTION_QUERY_MUL_KEY_AND_VALUE_CUDA_KERNEL(float16)\nREGISTER_FUSED_SELF_ATTENTION_QUERY_MUL_KEY_AND_VALUE_GRAD_CUDA_KERNEL(float)\nREGISTER_FUSED_SELF_ATTENTION_QUERY_MUL_KEY_AND_VALUE_GRAD_CUDA_KERNEL(float16)\n\n}  
// namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_softmax.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_USER_KERNELS_FUSED_SOFTMAX_H_\n#define ONEFLOW_USER_KERNELS_FUSED_SOFTMAX_H_\n\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\nnamespace cuda {\nnamespace fused_softmax {\n\ninline void SimplifyBroadcastDims(size_t num_a_dims, const int64_t* a_dims, size_t num_b_dims,\n                                  const int64_t* b_dims, size_t* simplified_num_dims,\n                                  int64_t* simplified_a_dims, int64_t* simplified_b_dims) {\n  const size_t num_max_dims = std::max(num_a_dims, num_b_dims);\n  auto MakeGetDim = [num_max_dims](size_t num_dims, const int64_t* dims) {\n    const int64_t num_padding_dims = num_max_dims - num_dims;\n    return [num_padding_dims, dims](size_t index) {\n      return index < num_padding_dims ? 1 : dims[index - num_padding_dims];\n    };\n  };\n  auto GetADim = MakeGetDim(num_a_dims, a_dims);\n  auto GetBDim = MakeGetDim(num_b_dims, b_dims);\n  *simplified_num_dims = 0;\n  bool prev_broadcast_a = false;\n  bool prev_broadcast_b = false;\n  for (int64_t i = 0; i < num_max_dims; ++i) {\n    const int64_t a_dim = GetADim(i);\n    const int64_t b_dim = GetBDim(i);\n    const int64_t broadcast_dim = std::max(a_dim, b_dim);\n    CHECK_GT(broadcast_dim, 0);\n    const bool broadcast_a = (a_dim == 1);\n    const bool broadcast_b = (b_dim == 1);\n    CHECK((a_dim == broadcast_dim) || broadcast_a);\n    CHECK((b_dim == broadcast_dim) || broadcast_b);\n    if (broadcast_dim == 1) {\n      continue;\n    } else if (*simplified_num_dims != 0\n               && (prev_broadcast_a == broadcast_a && prev_broadcast_b == broadcast_b)) {\n      simplified_a_dims[*simplified_num_dims - 1] *= a_dim;\n      simplified_b_dims[*simplified_num_dims - 1] *= b_dim;\n    } else {\n      simplified_a_dims[*simplified_num_dims] = a_dim;\n      simplified_b_dims[*simplified_num_dims] = b_dim;\n      *simplified_num_dims += 1;\n      prev_broadcast_a = broadcast_a;\n      prev_broadcast_b = broadcast_b;\n    }\n  }\n}\n\ntemplate<size_t num_dims, typename IndexType>\nstruct BroadcastMaskSoftmaxParams {\n  NdIndexOffsetHelper<IndexType, num_dims> src_index_helper;\n  NdIndexOffsetHelper<IndexType, num_dims> mask_index_helper;\n  const int64_t* mask_dims{};\n  int64_t row_size;\n  float fill;\n  float scale;\n};\n\nstruct ElementwiseMaskSoftmaxParams {\n  int64_t row_size;\n  float fill;\n  float scale;\n};\n\ntemplate<typename SRC, typename DST, typename MASK, size_t num_dims, typename IndexType>\nstruct BroadcastScaleMaskLoad {\n  BroadcastScaleMaskLoad(const SRC* src, const MASK* mask,\n                         BroadcastMaskSoftmaxParams<num_dims, IndexType> params)\n      : src(src), mask(mask), params(params) {\n    for (int i = 0; i < num_dims; i++) { mask_dims[i] = params.mask_dims[i]; }\n  }\n  template<int N>\n  __device__ void load(DST* dst, int64_t row, int64_t col) {\n    cuda::softmax::Pack<SRC, N> pack;\n 
   cuda::softmax::Pack<MASK, N> mask_pack;\n    const IndexType offset = row * params.row_size + col;\n    IndexType input_index[num_dims];\n    IndexType mask_index[num_dims];\n    params.src_index_helper.OffsetToNdIndex(offset, input_index);\n    for (int dim = 0; dim < num_dims; ++dim) {\n      if (mask_dims[dim] == 1) {\n        mask_index[dim] = 0;\n      } else {\n        mask_index[dim] = input_index[dim];\n      }\n    }\n    const IndexType mask_offset = params.mask_index_helper.NdIndexToOffset(mask_index);\n    pack.storage = *(reinterpret_cast<const cuda::softmax::PackType<SRC, N>*>(src) + offset / N);\n    mask_pack.storage =\n        *(reinterpret_cast<const cuda::softmax::PackType<MASK, N>*>(mask) + mask_offset / N);\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      if (mask_pack.elem[i] == 0) {\n        dst[i] = static_cast<DST>(params.fill);\n      } else {\n        dst[i] = static_cast<DST>(pack.elem[i]) * static_cast<DST>(params.scale);\n      }\n    }\n  }\n  const SRC* src;\n  const MASK* mask;\n  int64_t mask_dims[num_dims];\n  BroadcastMaskSoftmaxParams<num_dims, IndexType> params;\n};\n\ntemplate<typename SRC, typename DST, typename MASK>\nstruct ElementwiseScaleMaskLoad {\n  ElementwiseScaleMaskLoad(const SRC* src, const MASK* mask, ElementwiseMaskSoftmaxParams param)\n      : src(src), mask(mask), param(param) {}\n  template<int N>\n  __device__ void load(DST* dst, int64_t row, int64_t col) {\n    cuda::softmax::Pack<SRC, N> pack;\n    const int64_t offset = (row * param.row_size + col) / N;\n    pack.storage = *(reinterpret_cast<const cuda::softmax::PackType<SRC, N>*>(src) + offset);\n    cuda::softmax::Pack<MASK, N> mask_pack;\n    mask_pack.storage = *(reinterpret_cast<const cuda::softmax::PackType<MASK, N>*>(mask) + offset);\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      if (mask_pack.elem[i] == 0) {\n        dst[i] = static_cast<DST>(param.fill);\n      } else {\n        dst[i] = static_cast<DST>(pack.elem[i]) * static_cast<DST>(param.scale);\n      }\n    }\n  }\n  const SRC* src;\n  const MASK* mask;\n  ElementwiseMaskSoftmaxParams param;\n};\n\ntemplate<typename SRC, typename DST, typename MASK, size_t num_dims, typename IndexType>\nstruct BroadcastScaleMaskStore {\n  BroadcastScaleMaskStore(DST* dst, const MASK* mask,\n                          BroadcastMaskSoftmaxParams<num_dims, IndexType> params)\n      : dst(dst), mask(mask), params(params) {\n    for (int i = 0; i < num_dims; ++i) { mask_dims[i] = params.mask_dims[i]; }\n  }\n  template<int N>\n  __device__ void store(const SRC* src, int64_t row, int64_t col) {\n    cuda::softmax::Pack<DST, N> pack;\n    cuda::softmax::Pack<MASK, N> mask_pack;\n    const IndexType offset = row * params.row_size + col;\n    IndexType input_index[num_dims];\n    IndexType mask_index[num_dims];\n    params.src_index_helper.OffsetToNdIndex(offset, input_index);\n    for (int dim = 0; dim < num_dims; ++dim) {\n      if (mask_dims[dim] == 1) {\n        mask_index[dim] = 0;\n      } else {\n        mask_index[dim] = input_index[dim];\n      }\n    }\n    const IndexType mask_offset = params.mask_index_helper.NdIndexToOffset(mask_index);\n    mask_pack.storage =\n        *(reinterpret_cast<const cuda::softmax::PackType<MASK, N>*>(mask) + mask_offset / N);\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      if (mask_pack.elem[i] == 0) {\n        pack.elem[i] = static_cast<DST>(params.fill);\n      } else {\n        pack.elem[i] = static_cast<DST>(src[i]) * static_cast<DST>(params.scale);\n      }\n   
 }\n    *(reinterpret_cast<cuda::softmax::PackType<DST, N>*>(dst) + offset / N) = pack.storage;\n  }\n  DST* dst;\n  const MASK* mask;\n  int64_t mask_dims[num_dims];\n  BroadcastMaskSoftmaxParams<num_dims, IndexType> params;\n};\n\ntemplate<typename SRC, typename DST, typename MASK>\nstruct ElementwiseScaleMaskStore {\n  ElementwiseScaleMaskStore(DST* dst, const MASK* mask, ElementwiseMaskSoftmaxParams params)\n      : dst(dst), mask(mask), params(params) {}\n  template<int N>\n  __device__ void store(const SRC* src, int64_t row, int64_t col) {\n    cuda::softmax::Pack<DST, N> pack;\n    const int64_t offset = (row * params.row_size + col) / N;\n    cuda::softmax::Pack<MASK, N> mask_pack;\n    mask_pack.storage = *(reinterpret_cast<const cuda::softmax::PackType<MASK, N>*>(mask) + offset);\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      if (mask_pack.elem[i] == 0) {\n        pack.elem[i] = params.fill;\n      } else {\n        pack.elem[i] = static_cast<DST>(src[i]) * static_cast<DST>(params.scale);\n      }\n    }\n    *(reinterpret_cast<cuda::softmax::PackType<DST, N>*>(dst) + offset) = pack.storage;\n  }\n  DST* dst;\n  const MASK* mask;\n  ElementwiseMaskSoftmaxParams params;\n};\n\ntemplate<typename SRC, typename DST>\nstruct MaskScaleLoad {\n  MaskScaleLoad(const SRC* src, const bool* mask, int64_t row_size, SRC scale)\n      : src(src), mask(mask), row_size(row_size), scale(scale) {}\n  template<int N>\n  __device__ void load(DST* dst, int64_t row, int64_t col) const {\n    cuda::softmax::Pack<SRC, N> pack;\n    const int64_t offset = (row * row_size + col) / N;\n    pack.storage = *(reinterpret_cast<const cuda::softmax::PackType<SRC, N>*>(src) + offset);\n    cuda::softmax::Pack<bool, N> mask_pack;\n    mask_pack.storage = *(reinterpret_cast<const cuda::softmax::PackType<bool, N>*>(mask) + offset);\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      dst[i] = static_cast<DST>(pack.elem[i]) * static_cast<DST>(mask_pack.elem[i])\n               * static_cast<DST>(scale);\n    }\n  }\n  const SRC* src;\n  const bool* mask;\n  int64_t row_size;\n  SRC scale;\n};\n\ntemplate<typename SRC, typename DST>\nstruct DropoutStore {\n  DropoutStore(DST* dst, DST* softmax_y, const bool* mask, int64_t row_size, DST scale)\n      : dst(dst), softmax_y(softmax_y), mask(mask), row_size(row_size), scale(scale) {}\n  template<int N>\n  __device__ void store(const SRC* src, int64_t row, int64_t col) {\n    cuda::softmax::Pack<DST, N> softmax_y_pack;\n    cuda::softmax::Pack<DST, N> dst_pack;\n    const int64_t offset = (row * row_size + col) / N;\n    cuda::softmax::Pack<bool, N> mask_pack;\n    mask_pack.storage = *(reinterpret_cast<const cuda::softmax::PackType<bool, N>*>(mask) + offset);\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      softmax_y_pack.elem[i] = static_cast<DST>(src[i]);\n      dst_pack.elem[i] =\n          static_cast<DST>(src[i]) * static_cast<DST>(mask_pack.elem[i]) * static_cast<DST>(scale);\n    }\n    *(reinterpret_cast<cuda::softmax::PackType<DST, N>*>(softmax_y) + offset) =\n        softmax_y_pack.storage;\n    *(reinterpret_cast<cuda::softmax::PackType<DST, N>*>(dst) + offset) = dst_pack.storage;\n  }\n  DST* dst;\n  DST* softmax_y;\n  const bool* mask;\n  int64_t row_size;\n  DST scale;\n};\n\n}  // namespace fused_softmax\n}  // namespace cuda\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_FUSED_SOFTMAX_H_\n"
  },
  {
    "path": "oneflow/user/kernels/fused_tril_scale_softmax_mask_scale_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/cuda/softmax.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\ntemplate<typename SRC, typename DST>\nstruct TrilScaleLoad {\n  TrilScaleLoad(const SRC* src, int64_t tril_num_rows, int64_t row_size, int64_t diagonal, SRC fill,\n                SRC scale)\n      : src(src),\n        tril_num_rows(tril_num_rows),\n        row_size(row_size),\n        diagonal(diagonal),\n        fill(fill),\n        scale(scale) {}\n  template<int N>\n  __device__ void load(DST* dst, int64_t row, int64_t col) {\n    int64_t tril_row = row % tril_num_rows;\n    int64_t diagonal_col_id = tril_row + diagonal;\n    bool need_load = (col <= diagonal_col_id);\n    cuda::softmax::Pack<SRC, N> pack;\n    if (need_load) {\n      const int64_t offset = (row * row_size + col) / N;\n      pack.storage = *(reinterpret_cast<const cuda::softmax::PackType<SRC, N>*>(src) + offset);\n    }\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      if (col + i > diagonal_col_id) {\n        dst[i] = static_cast<DST>(fill);\n      } else {\n        dst[i] = static_cast<DST>(pack.elem[i]) * static_cast<DST>(scale);\n      }\n    }\n  }\n  const SRC* src;\n  int64_t tril_num_rows;\n  int64_t row_size;\n  int64_t diagonal;\n  SRC fill;\n  SRC scale;\n};\n\ntemplate<typename SRC, typename DST>\nstruct MaskAndScaleStore {\n  MaskAndScaleStore(DST* dst, DST* softmax_y, const bool* mask, int64_t row_size, DST scale)\n      : dst(dst), softmax_y(softmax_y), mask(mask), row_size(row_size), scale(scale) {}\n  template<int N>\n  __device__ void store(const SRC* src, int64_t row, int64_t col) {\n    cuda::softmax::Pack<DST, N> softmax_y_pack;\n    cuda::softmax::Pack<DST, N> dst_pack;\n    const int64_t offset = (row * row_size + col) / N;\n    cuda::softmax::Pack<bool, N> mask_pack;\n    mask_pack.storage = *(reinterpret_cast<const cuda::softmax::PackType<bool, N>*>(mask) + offset);\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      softmax_y_pack.elem[i] = static_cast<DST>(src[i]);\n      dst_pack.elem[i] =\n          static_cast<DST>(src[i]) * static_cast<DST>(mask_pack.elem[i]) * static_cast<DST>(scale);\n    }\n    *(reinterpret_cast<cuda::softmax::PackType<DST, N>*>(softmax_y) + offset) =\n        softmax_y_pack.storage;\n    *(reinterpret_cast<cuda::softmax::PackType<DST, N>*>(dst) + offset) = dst_pack.storage;\n  }\n  DST* dst;\n  DST* softmax_y;\n  const bool* mask;\n  int64_t row_size;\n  DST scale;\n};\n\ntemplate<typename SRC, typename DST>\nstruct MaskAndScaleLoad {\n  MaskAndScaleLoad(const SRC* src, const bool* mask, int64_t row_size, SRC scale)\n      : src(src), mask(mask), row_size(row_size), scale(scale) {}\n  template<int N>\n  __device__ void load(DST* dst, int64_t row, int64_t col) const {\n    cuda::softmax::Pack<SRC, N> pack;\n    const int64_t offset = (row * row_size + col) / N;\n    pack.storage 
= *(reinterpret_cast<const cuda::softmax::PackType<SRC, N>*>(src) + offset);\n    cuda::softmax::Pack<bool, N> mask_pack;\n    mask_pack.storage = *(reinterpret_cast<const cuda::softmax::PackType<bool, N>*>(mask) + offset);\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      dst[i] = static_cast<DST>(pack.elem[i]) * static_cast<DST>(mask_pack.elem[i])\n               * static_cast<DST>(scale);\n    }\n  }\n  const SRC* src;\n  const bool* mask;\n  int64_t row_size;\n  SRC scale;\n};\n\ntemplate<typename SRC, typename DST>\nstruct TrilScaleStore {\n  TrilScaleStore(DST* dst, int64_t tril_num_rows, int64_t row_size, int64_t diagonal, DST fill,\n                 DST scale)\n      : dst(dst),\n        tril_num_rows(tril_num_rows),\n        row_size(row_size),\n        diagonal(diagonal),\n        fill(fill),\n        scale(scale) {}\n  template<int N>\n  __device__ void store(const SRC* src, int64_t row, int64_t col) {\n    cuda::softmax::Pack<DST, N> pack;\n    const int64_t offset = (row * row_size + col) / N;\n    int64_t tril_row = row % tril_num_rows;\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      if (col + i > tril_row + diagonal) {\n        pack.elem[i] = fill;\n      } else {\n        pack.elem[i] = static_cast<DST>(src[i]) * static_cast<DST>(scale);\n      }\n    }\n    *(reinterpret_cast<cuda::softmax::PackType<DST, N>*>(dst) + offset) = pack.storage;\n  }\n  DST* dst;\n  int64_t tril_num_rows;\n  int64_t row_size;\n  int64_t diagonal;\n  DST fill;\n  DST scale;\n};\n\ntemplate<typename T>\nclass FusedTrilScaleSoftmaxMaskScaleKernel final : public user_op::OpKernel {\n public:\n  FusedTrilScaleSoftmaxMaskScaleKernel() = default;\n  ~FusedTrilScaleSoftmaxMaskScaleKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* softmax_y = ctx->Tensor4ArgNameAndIndex(\"softmax_y\", 0);\n    const ShapeView& x_shape = x->shape_view();\n    CHECK_GE(x_shape.NumAxes(), 2);\n    const int64_t cols = x_shape.At(x_shape.NumAxes() - 1);\n    const int64_t rows = x_shape.Count(0, x_shape.NumAxes() - 1);\n    const int64_t tril_num_rows = x_shape.At(x_shape.NumAxes() - 2);\n    using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;\n    TrilScaleLoad<T, ComputeType> load(\n        x->dptr<T>(), tril_num_rows, cols, ctx->Attr<int64_t>(\"diagonal\"),\n        ctx->Attr<float>(\"tril_fill_value\"), ctx->Attr<float>(\"tril_scale_value\"));\n    MaskAndScaleStore<ComputeType, T> store(y->mut_dptr<T>(), softmax_y->mut_dptr<T>(),\n                                            mask->dptr<bool>(), cols,\n                                            ctx->Attr<float>(\"mask_scale_value\"));\n    OF_CUDA_CHECK((cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(\n        ctx->stream()->As<ep::CudaStream>()->cuda_stream(), load, store, rows, cols)));\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_CUDA_KERNEL(dtype) \\\n  REGISTER_USER_KERNEL(\"fused_tril_scale_softmax_mask_scale\")           \\\n      .SetCreateFn<FusedTrilScaleSoftmaxMaskScaleKernel<dtype>>()       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)  \\\n  
                     && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_CUDA_KERNEL(half)\nREGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_CUDA_KERNEL(float)\nREGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_CUDA_KERNEL(double)\n#undef REGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_CUDA_KERNEL\n\ntemplate<typename T>\nclass FusedTrilScaleSoftmaxMaskScaleGradKernel final : public user_op::OpKernel {\n public:\n  FusedTrilScaleSoftmaxMaskScaleGradKernel() = default;\n  ~FusedTrilScaleSoftmaxMaskScaleGradKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* softmax_y = ctx->Tensor4ArgNameAndIndex(\"softmax_y\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const ShapeView& dy_shape = dy->shape_view();\n    CHECK_GE(dy_shape.NumAxes(), 2);\n    const int64_t cols = dy_shape.At(dy_shape.NumAxes() - 1);\n    const int64_t rows = dy_shape.Count(0, dy_shape.NumAxes() - 1);\n    const int64_t tril_num_rows = dy_shape.At(dy_shape.NumAxes() - 2);\n    using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;\n    cuda::softmax::DirectLoad<T, ComputeType> load_softmax_y(softmax_y->dptr<T>(), cols);\n    MaskAndScaleLoad<T, ComputeType> load_dy(dy->dptr<T>(), mask->dptr<bool>(), cols,\n                                             ctx->Attr<float>(\"mask_scale_value\"));\n    TrilScaleStore<ComputeType, T> store(dx->mut_dptr<T>(), tril_num_rows, cols,\n                                         ctx->Attr<int64_t>(\"diagonal\"), static_cast<T>(0.0),\n                                         ctx->Attr<float>(\"tril_scale_value\"));\n    OF_CUDA_CHECK((cuda::softmax::DispatchSoftmaxGrad<decltype(load_softmax_y), decltype(load_dy),\n                                                      decltype(store), ComputeType>(\n        ctx->stream()->As<ep::CudaStream>()->cuda_stream(), load_softmax_y, load_dy, store, rows,\n        cols)));\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_GRAD_KERNEL(dtype) \\\n  REGISTER_USER_KERNEL(\"fused_tril_scale_softmax_mask_scale_grad\")      \\\n      .SetCreateFn<FusedTrilScaleSoftmaxMaskScaleGradKernel<dtype>>()   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)  \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_GRAD_KERNEL(half)\nREGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_GRAD_KERNEL(float)\nREGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_GRAD_KERNEL(double)\n#undef REGISTER_FUSED_TRIL_SCALE_SOFTMAX_MASK_SCALE_GRAD_KERNEL\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_weighted_sum_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/ep/cpu/cpu_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nclass FusedWeightedSumKernel final : public user_op::OpKernel {\n public:\n  FusedWeightedSumKernel() = default;\n  ~FusedWeightedSumKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t arity = ctx->input_size(\"in\");\n    CHECK_GE(arity, 1);\n    const std::vector<float>& weights = ctx->Attr<std::vector<float>>(\"weights\");\n    CHECK_EQ(weights.size(), arity);\n    const float alpha = ctx->Attr<float>(\"alpha\");\n    const DataType data_type = out->data_type();\n    const ShapeView& shape = out->shape_view();\n    std::vector<const T*> inputs(arity);\n    for (int i = 0; i < arity; ++i) {\n      const user_op::Tensor* in_i = ctx->Tensor4ArgNameAndIndex(\"in\", i);\n      CHECK(in_i->shape_view() == shape);\n      CHECK_EQ(in_i->data_type(), data_type);\n      inputs[i] = in_i->dptr<T>();\n    }\n    T* out_ptr = out->mut_dptr<T>();\n    auto* cpu_stream = ctx->stream()->As<ep::CpuStream>();\n    cpu_stream->ParallelFor(0, shape.elem_cnt(), [&](int64_t s, int64_t e) {\n      for (int64_t i = s; i < e; ++i) {\n        T out = static_cast<T>(0.0);\n        for (int j = 0; j < arity; ++j) { out += inputs[j][i] * static_cast<T>(weights[j]); }\n        out_ptr[i] = out * static_cast<T>(alpha);\n      }\n    });\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n#define REGISTER_FUSED_WEIGHT_SUM_KERNEL(data_type, cpp_type)         \\\n  REGISTER_USER_KERNEL(\"fused_weighted_sum\")                          \\\n      .SetCreateFn<FusedWeightedSumKernel<cpp_type>>()                \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"out\", 0) == data_type))\n\nREGISTER_FUSED_WEIGHT_SUM_KERNEL(DataType::kDouble, double);\nREGISTER_FUSED_WEIGHT_SUM_KERNEL(DataType::kFloat, float);\nREGISTER_FUSED_WEIGHT_SUM_KERNEL(DataType::kFloat16, float16);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/fused_weighted_sum_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, int arity>\nstruct Params {\n  const T* inputs[arity];\n  float weights[arity];\n  float alpha{};\n  T* output;\n  int64_t n;\n};\n\ntemplate<typename T, int arity, bool acc>\n__global__ void WeightedSumKernel(Params<T, arity> params) {\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, params.n) {\n    T out = 0;\n    if (acc) { out = params.output[i]; }\n#pragma unroll\n    for (int j = 0; j < arity; ++j) {\n      out += params.inputs[j][i] * static_cast<T>(params.weights[j]);\n    }\n    params.output[i] = out * static_cast<T>(params.alpha);\n  }\n}\n\ntemplate<typename T, int arity, bool acc>\nvoid LaunchWeightedSum(ep::Stream* stream, int n, const T** inputs, const float* weights,\n                       float alpha, T* output) {\n  Params<T, arity> params{};\n  for (int i = 0; i < arity; ++i) {\n    params.inputs[i] = *(inputs + i);\n    params.weights[i] = *(weights + i);\n  }\n  params.alpha = alpha;\n  params.output = output;\n  params.n = n;\n  RUN_CUDA_KERNEL((WeightedSumKernel<T, arity, acc>), stream, n, params);\n}\n\ntemplate<typename T, bool acc>\nvoid DispatchWeightedSum(ep::Stream* stream, int arity, int64_t n, const T** inputs,\n                         const float* weights, float alpha, T* output) {\n  if (arity == 1) {\n    LaunchWeightedSum<T, 1, acc>(stream, n, inputs, weights, alpha, output);\n  } else if (arity == 2) {\n    LaunchWeightedSum<T, 2, acc>(stream, n, inputs, weights, alpha, output);\n  } else if (arity == 3) {\n    LaunchWeightedSum<T, 3, acc>(stream, n, inputs, weights, alpha, output);\n  } else if (arity == 4) {\n    LaunchWeightedSum<T, 4, acc>(stream, n, inputs, weights, alpha, output);\n  } else if (arity == 5) {\n    LaunchWeightedSum<T, 5, acc>(stream, n, inputs, weights, alpha, output);\n  } else if (arity == 6) {\n    LaunchWeightedSum<T, 6, acc>(stream, n, inputs, weights, alpha, output);\n  } else if (arity == 7) {\n    LaunchWeightedSum<T, 7, acc>(stream, n, inputs, weights, alpha, output);\n  } else if (arity == 8) {\n    LaunchWeightedSum<T, 8, acc>(stream, n, inputs, weights, alpha, output);\n  } else if (arity > 8) {\n    LaunchWeightedSum<T, 8, acc>(stream, n, inputs, weights, 1.0F, output);\n    DispatchWeightedSum<T, true>(stream, arity - 8, n, inputs + 8, weights + 8, alpha, output);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T>\nclass FusedWeightedSumKernel final : public user_op::OpKernel {\n public:\n  FusedWeightedSumKernel() = default;\n  ~FusedWeightedSumKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t arity = 
ctx->input_size(\"in\");\n    CHECK_GE(arity, 1) << \"input_size should be greater than 0.\";\n    const std::vector<float>& weights = ctx->Attr<std::vector<float>>(\"weights\");\n    CHECK_EQ(weights.size(), arity);\n    const float alpha = ctx->Attr<float>(\"alpha\");\n    const DataType data_type = out->data_type();\n    const ShapeView& shape = out->shape_view();\n    std::vector<const T*> inputs(arity);\n    for (int i = 0; i < arity; ++i) {\n      const user_op::Tensor* in_i = ctx->Tensor4ArgNameAndIndex(\"in\", i);\n      CHECK(in_i->shape_view() == shape);\n      CHECK_EQ(in_i->data_type(), data_type);\n      inputs[i] = in_i->dptr<T>();\n    }\n    DispatchWeightedSum<T, false>(ctx->stream(), arity, shape.elem_cnt(), inputs.data(),\n                                  weights.data(), alpha, out->mut_dptr<T>());\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n#define REGISTER_FUSED_WEIGHT_SUM_KERNEL(data_type, cpp_type)          \\\n  REGISTER_USER_KERNEL(\"fused_weighted_sum\")                           \\\n      .SetCreateFn<FusedWeightedSumKernel<cpp_type>>()                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"out\", 0) == data_type))\n\nREGISTER_FUSED_WEIGHT_SUM_KERNEL(DataType::kDouble, double);\nREGISTER_FUSED_WEIGHT_SUM_KERNEL(DataType::kFloat, float);\nREGISTER_FUSED_WEIGHT_SUM_KERNEL(DataType::kFloat16, half);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/gather_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/gather_kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\nShape GetFlatShape(ShapeView shape, int64_t axis) {\n  return Shape({shape.Count(0, axis), shape.At(axis), shape.Count(axis + 1)});\n}\n\nclass GatherOpKernelCache final : public user_op::OpKernelCache {\n public:\n  GatherOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {}\n  ~GatherOpKernelCache() override = default;\n\n  int64_t lower() const { return lower_; }\n  int64_t upper() const { return upper_; }\n\n private:\n  const int64_t lower_;\n  const int64_t upper_;\n};\n\nvoid CheckNdSbp(const Shape& hierarchy, int64_t gather_axis, const NdSbp& in_nd_sbp,\n                const NdSbp& indices_nd_sbp, const NdSbp& out_nd_sbp) {\n  CHECK_EQ(hierarchy.NumAxes(), in_nd_sbp.sbp_parallel_size());\n  CHECK_EQ(hierarchy.NumAxes(), indices_nd_sbp.sbp_parallel_size());\n  CHECK_EQ(hierarchy.NumAxes(), out_nd_sbp.sbp_parallel_size());\n  if (hierarchy.elem_cnt() == 1) { return; }\n  FOR_RANGE(int64_t, i, 0, hierarchy.NumAxes()) {\n    const auto& in_sbp = in_nd_sbp.sbp_parallel(i);\n    if (in_sbp.has_split_parallel() && in_sbp.split_parallel().axis() == gather_axis) {\n      CHECK(indices_nd_sbp.sbp_parallel(i).has_broadcast_parallel());\n      CHECK(out_nd_sbp.sbp_parallel(i).has_partial_sum_parallel());\n    }\n  }\n}\n\n}  // namespace\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass GatherKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  GatherKernel() = default;\n  ~GatherKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    if (ctx->parallel_ctx().parallel_num() > 1) {\n      const auto axis = ctx->Attr<int64_t>(\"axis\");\n      const NdSbp& in_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n      const Shape& hierarchy = *ctx->parallel_desc().hierarchy();\n      CheckNdSbp(hierarchy, axis, in_nd_sbp, ctx->NdSbp4ArgNameAndIndex(\"indices\", 0),\n                 ctx->NdSbp4ArgNameAndIndex(\"out\", 0));\n      const Shape in_logical_shape =\n          ExpandDimIf0D(ctx->LogicalTensorDesc4ArgNameAndIndex(\"in\", 0)->shape());\n      TensorSliceView view = GetTensorSliceView4ParallelId(hierarchy, in_nd_sbp, in_logical_shape,\n                                                           ctx->parallel_ctx().parallel_id());\n      return std::make_shared<GatherOpKernelCache>(view.At(axis).begin(), view.At(axis).end());\n    } else {\n      return nullptr;\n    }\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const 
user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n    const int64_t axis = ctx->Attr<int64_t>(\"axis\");\n    const int64_t num_indices = indices->shape_view().elem_cnt();\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    if (out->shape_view().elem_cnt() == 0) { return; }\n\n    const Shape in_shape = ExpandDimIf0D(in->shape_view());\n\n    int64_t offset = 0;\n    if (cache != nullptr) {\n      auto* gather_cache = dynamic_cast<const GatherOpKernelCache*>(cache);\n      CHECK_NOTNULL(gather_cache);\n      CHECK_EQ(in_shape.At(axis), gather_cache->upper() - gather_cache->lower());\n      offset = gather_cache->lower();\n    }\n\n    GatherKernelUtilImpl<device_type, T, K>::Forward(ctx->stream(), indices->dptr<K>(), num_indices,\n                                                     in->dptr<T>(), GetFlatShape(in_shape, axis),\n                                                     out->mut_dptr<T>(), offset);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_GATHER_KERNEL(device, in_type, indices_type)                                \\\n  REGISTER_USER_KERNEL(\"gather\")                                                             \\\n      .SetCreateFn<                                                                          \\\n          GatherKernel<device, OF_PP_PAIR_FIRST(in_type), OF_PP_PAIR_FIRST(indices_type)>>() \\\n      .SetIsMatchedHob(                                                                      \\\n          (user_op::HobDeviceType() == device)                                               \\\n          && (user_op::HobDataType(\"in\", 0) == OF_PP_PAIR_SECOND(in_type))                   \\\n          && (user_op::HobDataType(\"indices\", 0) == OF_PP_PAIR_SECOND(indices_type)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_GATHER_KERNEL, DEVICE_TYPE_SEQ, GATHER_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n// For cpu float16/bfloat16\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_GATHER_KERNEL, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU),\n                                 FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n#ifdef WITH_CUDA\n// For cuda half\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_GATHER_KERNEL, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                                 HALF_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n#if CUDA_VERSION >= 11000\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_GATHER_KERNEL, OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                                 OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16),\n                                 INDEX_DATA_TYPE_SEQ)\n#endif\n\n#endif\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/gather_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/gather_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nShape GetFlatShape(const ShapeView& shape, int64_t axis) {\n  CHECK_GT(shape.NumAxes(), 0);\n  CHECK_GE(axis, 0);\n  CHECK_LT(axis, shape.NumAxes());\n  return Shape({shape.Count(0, axis), shape.At(axis), shape.Count(axis + 1)});\n}\n\ntemplate<DeviceType device_type, typename T, typename K>\nvoid GatherForward(ep::Stream* stream, const Blob* indices, const Blob* in, int64_t axis, Blob* out,\n                   const int64_t offset) {\n  const Shape& flat_in_shape = GetFlatShape(in->shape_view(), axis);\n  GatherKernelUtilImpl<device_type, T, K>::Forward(stream, indices->dptr<K>(),\n                                                   indices->shape_view().elem_cnt(), in->dptr<T>(),\n                                                   flat_in_shape, out->mut_dptr<T>(), offset);\n}\n\ntemplate<DeviceType device_type, typename T>\nstruct GatherSwitchUtil final {\n#define MAKE_GATHER_SWITCH_ENTRY(func_name, K) func_name<device_type, T, K>\n#define DEFINE_GATHER_STATIC_SWITCH_FUNC(func_name)                    \\\n  DEFINE_STATIC_SWITCH_FUNC(void, func_name, MAKE_GATHER_SWITCH_ENTRY, \\\n                            MAKE_DATA_TYPE_CTRV_SEQ(INDEX_DATA_TYPE_SEQ));\n  DEFINE_GATHER_STATIC_SWITCH_FUNC(GatherForward);\n#undef DEFINE_GATHER_STATIC_SWITCH_FUNC\n#undef MAKE_GATHER_SWITCH_ENTRY\n};\n\n}  // namespace\n\ntemplate<DeviceType device_type, typename T>\nvoid GatherKernelUtil<device_type, T>::Forward(ep::Stream* stream, const Blob* indices,\n                                               const Blob* in, const int64_t axis, Blob* out) {\n  GatherKernelUtil<device_type, T>::Forward(stream, indices, in, axis, out, 0);\n}\n\ntemplate<DeviceType device_type, typename T>\nvoid GatherKernelUtil<device_type, T>::Forward(ep::Stream* stream, const Blob* indices,\n                                               const Blob* in, const int64_t axis, Blob* out,\n                                               const int64_t offset) {\n  GatherSwitchUtil<device_type, T>::SwitchGatherForward(SwitchCase(indices->data_type()), stream,\n                                                        indices, in, axis, out, offset);\n}\n\ntemplate<typename T, typename K>\nstruct GatherKernelUtilImpl<DeviceType::kCPU, T, K> final {\n  static void Forward(ep::Stream* stream, const K* indices, int64_t num_indices, const T* in,\n                      const Shape& flat_in_shape, T* out, const int64_t offset);\n};\n\ntemplate<typename T, typename K>\nvoid GatherKernelUtilImpl<DeviceType::kCPU, T, K>::Forward(ep::Stream* stream, const K* indices,\n                                                           int64_t num_indices, const T* in,\n                                                           const Shape& flat_in_shape, T* out,\n                                                           const int64_t offset) {\n  const 
int64_t outer_dim_size = flat_in_shape.At(0);\n  const int64_t gather_dim_size = flat_in_shape.At(1);\n  const int64_t inner_dim_size = flat_in_shape.At(2);\n  FOR_RANGE(int64_t, outer_idx, 0, outer_dim_size) {\n    FOR_RANGE(int64_t, i, 0, num_indices) {\n      CHECK_GE(indices[i], 0);\n      const int64_t idx = indices[i] - offset;\n      T* to = out + outer_idx * num_indices * inner_dim_size + i * inner_dim_size;\n      if (idx >= 0 && idx < gather_dim_size) {\n        const T* from = in + outer_idx * gather_dim_size * inner_dim_size + idx * inner_dim_size;\n        std::copy(from, from + inner_dim_size, to);\n      } else {\n        std::memset(reinterpret_cast<void*>(to), 0, inner_dim_size * sizeof(T));\n      }\n    }\n  }\n}\n\n#define INITIATE_GATHER_KERNEL_UTIL_CPU_IMPL(in_type_pair, index_type_pair)              \\\n  template struct GatherKernelUtilImpl<DeviceType::kCPU, OF_PP_PAIR_FIRST(in_type_pair), \\\n                                       OF_PP_PAIR_FIRST(index_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL_CPU_IMPL,\n                                 GATHER_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ,\n                                 GATHER_INDEX_TYPE_SEQ);\n#undef INITIATE_GATHER_KERNEL_UTIL_CPU_IMPL\n\n#define INITIATE_GATHER_KERNEL_UTIL(device_type, in_type_pair) \\\n  template struct GatherKernelUtil<device_type, OF_PP_PAIR_FIRST(in_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL, DEVICE_TYPE_SEQ,\n                                 GATHER_DATA_TYPE_SEQ);\n// For cpu float16/bfloat16\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL,\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU),\n                                 FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ);\n#undef INITIATE_GATHER_KERNEL_UTIL\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/gather_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/gather_kernel_util.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include <assert.h>\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename K, typename IDX>\n__global__ void GatherForwardGpu(const IDX elem_cnt, NdIndexOffsetHelper<IDX, 3> in_helper,\n                                 NdIndexOffsetHelper<IDX, 3> out_helper, const K* indices,\n                                 const T* in, const IDX gather_dim_size, T* out, const IDX offset) {\n  IDX index[3];\n  CUDA_1D_KERNEL_LOOP_T(IDX, i, elem_cnt) {\n    out_helper.OffsetToNdIndex(i, index);\n    index[1] = indices[index[1]] - offset;\n    T v{};\n    if (index[1] >= 0 && index[1] < gather_dim_size) { v = in[in_helper.NdIndexToOffset(index)]; }\n    out[i] = v;\n  }\n}\n\nbool IsSafeUseIndex32(int64_t outer_dim_size, int64_t gather_dim_size, int64_t inner_dim_size,\n                      int64_t num_indices) {\n  const int64_t in_elem_cnt = outer_dim_size * gather_dim_size * inner_dim_size;\n  const int64_t out_elem_cnt = outer_dim_size * num_indices * inner_dim_size;\n  return std::max(out_elem_cnt, in_elem_cnt) < GetMaxVal<int32_t>() / 2;\n}\n\ntemplate<typename T, typename K>\nvoid DispatchIndexSize(ep::Stream* stream, int64_t outer_dim_size, int64_t gather_dim_size,\n                       int64_t inner_dim_size, int64_t num_indices, int64_t offset,\n                       const K* indices, const T* in, T* out) {\n  const int64_t out_elem_cnt = outer_dim_size * num_indices * inner_dim_size;\n  if (IsSafeUseIndex32(outer_dim_size, gather_dim_size, inner_dim_size, num_indices)) {\n    NdIndexOffsetHelper<int32_t, 3> in_helper(outer_dim_size, gather_dim_size, inner_dim_size);\n    NdIndexOffsetHelper<int32_t, 3> out_helper(outer_dim_size, num_indices, inner_dim_size);\n    GatherForwardGpu<T, K, int32_t><<<BlocksNum4ThreadsNum(out_elem_cnt), kCudaThreadsNumPerBlock,\n                                      0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        out_elem_cnt, in_helper, out_helper, indices, in, gather_dim_size, out, offset);\n  } else {\n    NdIndexOffsetHelper<int64_t, 3> in_helper(outer_dim_size, gather_dim_size, inner_dim_size);\n    NdIndexOffsetHelper<int64_t, 3> out_helper(outer_dim_size, num_indices, inner_dim_size);\n    GatherForwardGpu<T, K, int64_t><<<BlocksNum4ThreadsNum(out_elem_cnt), kCudaThreadsNumPerBlock,\n                                      0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        out_elem_cnt, in_helper, out_helper, indices, in, gather_dim_size, out, offset);\n  }\n}\n\ntemplate<typename K, typename T>\nbool TryDispatchMovementType(ep::Stream* stream, int64_t outer_dim_size, int64_t gather_dim_size,\n                             
int64_t inner_dim_size, int64_t num_indices, int64_t offset,\n                             const K* indices, const void* in, void* out) {\n  if (reinterpret_cast<uintptr_t>(in) % sizeof(T) == 0\n      && reinterpret_cast<uintptr_t>(out) % sizeof(T) == 0 && inner_dim_size % sizeof(T) == 0) {\n    DispatchIndexSize<T, K>(stream, outer_dim_size, gather_dim_size, inner_dim_size / sizeof(T),\n                            num_indices, offset, indices, static_cast<const T*>(in),\n                            static_cast<T*>(out));\n    return true;\n  } else {\n    return false;\n  }\n}\n\ntemplate<typename K>\nvoid DispatchMovementSize(ep::Stream* stream, int64_t outer_dim_size, int64_t gather_dim_size,\n                          int64_t inner_dim_size, int64_t num_indices, int64_t offset,\n                          const K* indices, const void* in, void* out) {\n  using Func = bool (*)(ep::Stream * stream, int64_t outer_dim_size, int64_t gather_dim_size,\n                        int64_t inner_dim_size, int64_t num_indices, int64_t offset,\n                        const K* indices, const void* in, void* out);\n  Func funcs[] = {\n      TryDispatchMovementType<K, ulonglong2>,  // 16B\n      TryDispatchMovementType<K, uint64_t>,    // 8B\n      TryDispatchMovementType<K, uint32_t>,    // 4B\n      TryDispatchMovementType<K, uint16_t>,    // 2B\n      TryDispatchMovementType<K, uint8_t>,     // 1B\n  };\n  for (size_t i = 0; i < sizeof(funcs) / sizeof(funcs[0]); ++i) {\n    if (funcs[i](stream, outer_dim_size, gather_dim_size, inner_dim_size, num_indices, offset,\n                 indices, in, out)) {\n      break;\n    }\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename K>\nstruct GatherKernelUtilImpl<DeviceType::kCUDA, T, K> final {\n  static void Forward(ep::Stream* stream, const K* indices, int64_t num_indices, const T* in,\n                      const Shape& flat_in_shape, T* out, const int64_t offset) {\n    DispatchMovementSize(stream, flat_in_shape.At(0), flat_in_shape.At(1),\n                         flat_in_shape.At(2) * sizeof(T), num_indices, offset, indices, in, out);\n  }\n};\n\n#define INITIATE_GATHER_KERNEL_UTIL_CUDA_IMPL(in_type_pair, index_type_pair)              \\\n  template struct GatherKernelUtilImpl<DeviceType::kCUDA, OF_PP_PAIR_FIRST(in_type_pair), \\\n                                       OF_PP_PAIR_FIRST(index_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL_CUDA_IMPL,\n                                 GATHER_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, GATHER_INDEX_TYPE_SEQ);\n#if CUDA_VERSION >= 11000\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_GATHER_KERNEL_UTIL_CUDA_IMPL,\n                                 OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16),\n                                 GATHER_INDEX_TYPE_SEQ);\n#endif\n#undef INITIATE_GATHER_KERNEL_UTIL_CUDA_IMPL\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/gather_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_GATHER_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_GATHER_KERNEL_UTIL_H_\n\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nstruct GatherKernelUtil final {\n  static void Forward(ep::Stream* stream, const Blob* indices, const Blob* in, int64_t axis,\n                      Blob* out);\n  static void Forward(ep::Stream* stream, const Blob* indices, const Blob* in, int64_t axis,\n                      Blob* out, int64_t offset);\n};\n\ntemplate<DeviceType device_type, typename T, typename K>\nstruct GatherKernelUtilImpl final {\n  static void Forward(ep::Stream* stream, const K* indices, int64_t num_indices, const T* in,\n                      const Shape& flat_in_shape, T* out, int64_t offset);\n};\n\n#define GATHER_DATA_TYPE_SEQ ARITHMETIC_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool)\n#define GATHER_INDEX_TYPE_SEQ INDEX_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_GATHER_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cstdint>\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n\nnamespace oneflow {\n\nclass GenerateRandomBatchPermutationIndicesCPUKernel final : public user_op::OpKernel {\n public:\n  GenerateRandomBatchPermutationIndicesCPUKernel() = default;\n  ~GenerateRandomBatchPermutationIndicesCPUKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    int64_t seed = ctx->Attr<int64_t>(\"seed\");\n    return std::make_shared<OpKernelStateWrapper<std::mt19937>>(seed);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* random_generator = dynamic_cast<OpKernelStateWrapper<std::mt19937>*>(state);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    std::iota(y->mut_dptr<int32_t>(), y->mut_dptr<int32_t>() + y->shape_view().elem_cnt(), 0);\n    std::shuffle(y->mut_dptr<int32_t>(), y->mut_dptr<int32_t>() + y->shape_view().elem_cnt(),\n                 *random_generator->Mutable());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"generate_random_batch_permutation_indices\")\n    .SetCreateFn<GenerateRandomBatchPermutationIndicesCPUKernel>()\n    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCPU);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/random_generator.h\"\n#include \"oneflow/user/kernels/radix_sort.cuh\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass TmpBufferManager final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TmpBufferManager);\n  TmpBufferManager(const int32_t& batch_size, const int32_t& capacity, void* ptr)\n      : capacity_{capacity},\n        random_value_elem_cnt_{batch_size},\n        sorted_value_elem_cnt_{batch_size},\n        indices_elem_cnt_{batch_size} {\n    const int32_t random_value_aligned_bytes =\n        GetCudaAlignedSize(random_value_elem_cnt_ * sizeof(float));\n    const int32_t sorted_value_aligned_bytes =\n        GetCudaAlignedSize(sorted_value_elem_cnt_ * sizeof(float));\n    const int32_t indices_aligned_bytes = GetCudaAlignedSize(indices_elem_cnt_ * sizeof(int32_t));\n    random_value_ptr_ = reinterpret_cast<float*>(ptr);\n    sorted_value_ptr_ = reinterpret_cast<float*>(reinterpret_cast<char*>(random_value_ptr_)\n                                                 + random_value_aligned_bytes);\n    indices_ptr_ = reinterpret_cast<int32_t*>(reinterpret_cast<char*>(sorted_value_ptr_)\n                                              + sorted_value_aligned_bytes);\n    temp_storage_ptr_ =\n        reinterpret_cast<void*>(reinterpret_cast<char*>(indices_ptr_) + indices_aligned_bytes);\n    temp_storage_bytes_ =\n        capacity_ - random_value_aligned_bytes - sorted_value_aligned_bytes - indices_aligned_bytes;\n    CHECK_GE(temp_storage_bytes_, 0);\n  }\n  ~TmpBufferManager() = default;\n\n  float* RandomValuePtr() const { return random_value_ptr_; }\n  float* SortedValuePtr() const { return sorted_value_ptr_; }\n  int32_t* IndicesPtr() const { return indices_ptr_; }\n  void* TempStoragePtr() const { return temp_storage_ptr_; }\n\n  int32_t TempStorageBytes() const { return temp_storage_bytes_; }\n\n private:\n  int32_t capacity_;\n\n  float* random_value_ptr_;\n  float* sorted_value_ptr_;\n  int32_t* indices_ptr_;\n  void* temp_storage_ptr_;\n\n  int32_t random_value_elem_cnt_;\n  int32_t sorted_value_elem_cnt_;\n  int32_t indices_elem_cnt_;\n  int32_t temp_storage_bytes_;\n};\n\n__global__ void InitializeIndices(int32_t elem_cnt, int32_t* indices_ptr) {\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) { indices_ptr[i] = i; };\n}\n\n}  // namespace\n\nclass GenerateRandomBatchPermutationIndicesGPUKernel final : public user_op::OpKernel {\n public:\n  GenerateRandomBatchPermutationIndicesGPUKernel() = default;\n  ~GenerateRandomBatchPermutationIndicesGPUKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    int64_t seed = ctx->Attr<int64_t>(\"seed\");\n    return 
std::make_shared<OpKernelStateWrapper<RandomGenerator<DeviceType::kCUDA>>>(\n        seed, ctx->stream());\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* random_generator =\n        dynamic_cast<OpKernelStateWrapper<RandomGenerator<DeviceType::kCUDA>>*>(state);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const int32_t batch_size = y->shape_view().At(0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    TmpBufferManager buf_manager(batch_size,\n                                 static_cast<int32_t>(tmp_buffer->shape_view().elem_cnt()),\n                                 tmp_buffer->mut_dptr<void>());\n    random_generator->Mutable()->Uniform(batch_size, buf_manager.RandomValuePtr());\n    InitializeIndices<<<BlocksNum4ThreadsNum(batch_size), kCudaThreadsNumPerBlock, 0,\n                        ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        batch_size, buf_manager.IndicesPtr());\n    const int32_t argsort_instance_num = 1;\n    const int32_t argsort_instance_size = batch_size;\n    SortPairsAscending(buf_manager.RandomValuePtr(), buf_manager.IndicesPtr(), argsort_instance_num,\n                       argsort_instance_size, buf_manager.TempStoragePtr(),\n                       buf_manager.TempStorageBytes(), buf_manager.SortedValuePtr(),\n                       y->mut_dptr<int32_t>(), ctx->stream()->As<ep::CudaStream>()->cuda_stream());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"generate_random_batch_permutation_indices\")\n    .SetCreateFn<GenerateRandomBatchPermutationIndicesGPUKernel>()\n    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA)\n    .SetInferTmpSizeFn([](oneflow::user_op::InferContext* ctx) {\n      const Shape& y_shape = ctx->OutputShape(\"y\", 0);\n      const int32_t batch_size = y_shape.At(0);\n\n      const int32_t random_value_aligned_bytes = GetCudaAlignedSize(batch_size * sizeof(float));\n      const int32_t sorted_value_aligned_bytes = GetCudaAlignedSize(batch_size * sizeof(float));\n      const int32_t indices_aligned_bytes = GetCudaAlignedSize(batch_size * sizeof(int32_t));\n      const int32_t argsort_instance_num = 1;\n      const int32_t argsort_instance_size = batch_size;\n      const int32_t temp_storage_bytes = InferTempStorageForSortPairsAscending<float, int32_t>(\n          argsort_instance_num, argsort_instance_size);\n\n      return random_value_aligned_bytes + sorted_value_aligned_bytes + indices_aligned_bytes\n             + temp_storage_bytes;\n    });\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/gpt_data_loader_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/data/gpt_dataset.h\"\n#include \"oneflow/user/data/distributed_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nusing namespace user_op;\nusing namespace data;\n\nsize_t GetNumShards(const Shape& hierarchy, const NdSbp& nd_sbp) {\n  size_t num_shards = 1;\n  FOR_RANGE(size_t, i, 0, nd_sbp.sbp_parallel_size()) {\n    const auto& sbp_parallel = nd_sbp.sbp_parallel(i);\n    if (sbp_parallel.has_split_parallel()) {\n      num_shards *= hierarchy.At(sbp_parallel.split_parallel().axis());\n    }\n  }\n  return num_shards;\n}\n\nsize_t GetShardIndex(const Shape& hierarchy, const NdSbp& nd_sbp, size_t rank) {\n  using index_helper_t = NdIndexOffsetHelper<int64_t, SHAPE_MAX_AXIS_SIZE>;\n  size_t ndim = hierarchy.NumAxes();\n  CHECK_GT(ndim, 0);\n  CHECK_LE(ndim, SHAPE_MAX_AXIS_SIZE);\n  index_helper_t index_helper(hierarchy.dim_vec().data(), ndim);\n  int64_t nd_index[SHAPE_MAX_AXIS_SIZE] = {0};\n  index_helper.OffsetToNdIndex(rank, nd_index);\n  size_t stride = 1;\n  size_t index = 0;\n  for (int i = ndim - 1; i >= 0; --i) {\n    const auto& sbp_parallel = nd_sbp.sbp_parallel(i);\n    if (sbp_parallel.has_split_parallel()) {\n      index += nd_index[i] * stride;\n      stride *= hierarchy.At(i);\n    }\n  }\n  return index;\n}\n\nclass GPTDataLoader final : public OpKernelState {\n public:\n  GPTDataLoader(KernelInitContext* ctx) : batch_cnt_(0) {\n    seq_len_ = ctx->Attr<int64_t>(\"seq_length\");\n    label_len_ = 1;\n    int64_t num_samples = ctx->Attr<int64_t>(\"num_samples\");\n\n    dataset_ = std::make_unique<const MegatronGPTMMapDataset>(\n        ctx->Attr<std::string>(\"data_file_prefix\"), seq_len_, label_len_, num_samples,\n        ctx->Attr<std::vector<int64_t>>(\"split_sizes\"), ctx->Attr<int64_t>(\"split_index\"),\n        ctx->Attr<bool>(\"shuffle\"), ctx->Attr<int64_t>(\"random_seed\"));\n\n    batch_size_ = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->shape().At(0);\n    CHECK_JUST(InitDataSourceDistributedInfo(ctx, num_shards_, shard_index_));\n  }\n  ~GPTDataLoader() = default;\n\n  template<typename T>\n  void GetBatch(size_t iter, user_op::Tensor* tokens) const {\n    const size_t sample_len = seq_len_ + label_len_;\n    CHECK_EQ(tokens->shape_view().NumAxes(), 2);\n    CHECK_EQ(tokens->shape_view().At(0), batch_size_);\n    CHECK_EQ(tokens->shape_view().At(1), sample_len);\n    T* dptr = tokens->mut_dptr<T>();\n    for (size_t i = 0; i < batch_size_; ++i) {\n      size_t sample_iter = iter * batch_size_ * num_shards_ + shard_index_ * batch_size_ + i;\n      dataset_->GetSample(sample_iter, dptr + i * sample_len);\n    }\n  }\n\n  template<typename T>\n  void NextBatch(user_op::Tensor* tokens) {\n    GetBatch<T>(batch_cnt_, tokens);\n    batch_cnt_ += 1;\n  }\n\n private:\n  std::unique_ptr<const MegatronGPTMMapDataset> dataset_;\n  
size_t seq_len_;\n  size_t label_len_;\n  size_t batch_size_;\n  size_t num_shards_;\n  int64_t shard_index_;\n  size_t batch_cnt_;\n};\n\ntemplate<typename T>\nclass GPTDataLoaderKernel final : public OpKernel {\n public:\n  GPTDataLoaderKernel() = default;\n  ~GPTDataLoaderKernel() = default;\n\n  std::shared_ptr<OpKernelState> CreateOpKernelState(KernelInitContext* ctx) const override {\n    std::shared_ptr<OpKernelState> reader(new GPTDataLoader(ctx));\n    return reader;\n  }\n\n private:\n  void Compute(KernelComputeContext* ctx, OpKernelState* state,\n               const OpKernelCache*) const override {\n    auto* loader = dynamic_cast<GPTDataLoader*>(state);\n    user_op::Tensor* iteration_tensor = ctx->Tensor4ArgNameAndIndex(\"iteration\", 0);\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    if (iteration_tensor) {\n      CHECK_EQ(iteration_tensor->shape_view().elem_cnt(), 1);\n      CHECK_EQ(iteration_tensor->data_type(), DataType::kInt64);\n      int64_t* iter_ptr = iteration_tensor->mut_dptr<int64_t>();\n      loader->GetBatch<T>(*iter_ptr, out_tensor);\n      *iter_ptr += 1;\n    } else {\n      loader->NextBatch<T>(out_tensor);\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n#define REGISTER_GPT_DATA_LOADER_KERNEL(dtype)                        \\\n  REGISTER_USER_KERNEL(\"megatron_gpt_mmap_data_loader\")               \\\n      .SetCreateFn<GPTDataLoaderKernel<dtype>>()                      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value))\n\nREGISTER_GPT_DATA_LOADER_KERNEL(int32_t);\nREGISTER_GPT_DATA_LOADER_KERNEL(int64_t);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/greater_inplace_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/greater_inplace_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nclass GreaterInplaceKernel final : public user_op::OpKernel {\n public:\n  GreaterInplaceKernel() = default;\n  ~GreaterInplaceKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const int64_t elem_cnt = x->shape_view().elem_cnt();\n    if (elem_cnt == 0) { return; }\n    const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const T* x_ptr = x->dptr<T>();\n    const T* y_ptr = y->dptr<T>();\n    T* out_ptr = out->mut_dptr<T>();\n    T* broadcast_y_ptr = tmp_buffer->mut_dptr<T>();\n\n    if (x->shape_view() == y->shape_view()) {\n      GreaterInplaceKernelUtil<device_type, T>::Forward(ctx->stream(), elem_cnt, x_ptr, y_ptr,\n                                                        out_ptr);\n      return;\n    }\n    GreaterInplaceKernelUtil<device_type, T>::YBroadcastToX(\n        ctx->stream(), elem_cnt, x_ptr, y_ptr, broadcast_y_ptr, x->shape_view(), y->shape_view());\n    GreaterInplaceKernelUtil<device_type, T>::Forward(ctx->stream(), elem_cnt, x_ptr,\n                                                      broadcast_y_ptr, out_ptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_GREATER_INPLACE_KERNEL(device_type, dtype)                              \\\n  REGISTER_USER_KERNEL(\"broadcast_inplace_greater\")                                      \\\n      .SetCreateFn<GreaterInplaceKernel<device_type, dtype>>()                           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type)                         \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                \\\n        const Shape& x_shape = ctx->InputShape(\"x\", 0);                                  \\\n        return GetCudaAlignedSize(x_shape.elem_cnt() * sizeof(dtype));                   \\\n      });\n\nREGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCPU, float)\nREGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCPU, double)\nREGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCPU, int8_t)\nREGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCPU, int32_t)\nREGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCPU, int64_t)\n\n#ifdef WITH_CUDA\nREGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, half)\nREGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, float)\nREGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, double)\nREGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, 
int8_t)\nREGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, int32_t)\nREGISTER_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, int64_t)\n#endif  // WITH_CUDA\n\ntemplate<DeviceType device_type, typename T, typename ValueT>\nclass ScalarGreaterInplaceKernel final : public user_op::OpKernel {\n public:\n  ScalarGreaterInplaceKernel() = default;\n  ~ScalarGreaterInplaceKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const int64_t elem_cnt = in->shape_view().elem_cnt();\n    if (elem_cnt == 0) { return; }\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    Scalar scalar_operand;\n    if (ctx->Attr<bool>(\"has_int_operand\")) {\n      scalar_operand = ctx->Attr<int64_t>(\"int_operand\");\n    } else if (ctx->Attr<bool>(\"has_float_operand\")) {\n      scalar_operand = ctx->Attr<double>(\"float_operand\");\n    } else {\n      UNIMPLEMENTED();\n    }\n\n    const T* in_ptr = in->dptr<T>();\n    T* out_ptr = out->mut_dptr<T>();\n\n    ScalarGreaterInplaceKernelUtil<device_type, T, ValueT>::Forward(ctx->stream(), elem_cnt, in_ptr,\n                                                                    scalar_operand, out_ptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SCALAR_GREATER_INPLACE_KERNEL(device_type, dtype, value_type)   \\\n  REGISTER_USER_KERNEL(\"scalar_logical_inplace_greater\")                         \\\n      .SetCreateFn<ScalarGreaterInplaceKernel<device_type, dtype, value_type>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type)                 \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\nREGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCPU, float, double)\nREGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCPU, double, double)\nREGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCPU, int8_t, int64_t)\nREGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCPU, int32_t, int64_t)\nREGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCPU, int64_t, int64_t)\n\n#ifdef WITH_CUDA\nREGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, half, double)\nREGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, float, double)\nREGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, double, double)\nREGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, int8_t, int64_t)\nREGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, int32_t, int64_t)\nREGISTER_SCALAR_GREATER_INPLACE_KERNEL(DeviceType::kCUDA, int64_t, int64_t)\n#endif  // WITH_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/greater_inplace_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/greater_inplace_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstruct GreaterInplaceKernelUtil<DeviceType::kCPU, T> {\n  static void Forward(ep::Stream* stream, const int64_t n, const T* x, const T* y, T* out) {\n    FOR_RANGE(int64_t, i, 0, n) { out[i] = x[i] > y[i] ? static_cast<T>(1) : static_cast<T>(0); }\n  }\n};\n\ntemplate<typename T, typename ValueT>\nstruct ScalarGreaterInplaceKernelUtil<DeviceType::kCPU, T, ValueT> {\n  static void Forward(ep::Stream* stream, const int64_t n, const T* x, const Scalar operand,\n                      T* out) {\n    FOR_RANGE(int64_t, i, 0, n) {\n      out[i] =\n          x[i] > static_cast<T>(operand.Value<ValueT>()) ? static_cast<T>(1) : static_cast<T>(0);\n    }\n  }\n};\n\n#define INSTANTIATE_GREATER_INPLACE_KERNEL_UTIL_CPU(data_type, other) \\\n  template struct GreaterInplaceKernelUtil<DeviceType::kCPU, data_type>;\nOF_PP_FOR_EACH_TUPLE(INSTANTIATE_GREATER_INPLACE_KERNEL_UTIL_CPU, GREATER_INPLACE_DATA_TYPE_SEQ_CPU)\n#undef INSTANTIATE_GREATER_INPLACE_KERNEL_UTIL_CPU\n\n#define INSTANTIATE_SCALAR_GREATER_INPLACE_KERNEL_UTIL_CPU(data_type, value_data_type)          \\\n  template struct ScalarGreaterInplaceKernelUtil<DeviceType::kCPU, OF_PP_PAIR_FIRST(data_type), \\\n                                                 OF_PP_PAIR_FIRST(value_data_type)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SCALAR_GREATER_INPLACE_KERNEL_UTIL_CPU,\n                                 GREATER_INPLACE_DATA_TYPE_SEQ_CPU, SCALAR_VALUE_DATA_TYPE_SEQ)\n#undef INSTANTIATE_SCALAR_GREATER_INPLACE_KERNEL_UTIL_CPU\n\n}  // namespace oneflow\n"
  },
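  {
    "path": "editor_sketches/instantiation_macro_sketch.cpp",
    "content": "// Editor's sketch: a hypothetical standalone file, NOT part of the OneFlow tree, added for\n// illustration only. It shows in miniature the pattern behind OF_PP_FOR_EACH_TUPLE and the\n// INSTANTIATE_* macros in greater_inplace_kernel_util.cpp above: a type-list macro applies an\n// instantiation macro once per dtype, so every needed template specialization is explicitly\n// compiled into the .cpp without repeating boilerplate. DTYPE_SEQ_SKETCH and GreaterUtilSketch\n// are toy names, not the real OneFlow preprocessor machinery.\n#include <cstdint>\n\ntemplate<typename T>\nstruct GreaterUtilSketch {\n  static void Forward(int64_t n, const T* x, const T* y, T* out) {\n    for (int64_t i = 0; i < n; ++i) { out[i] = x[i] > y[i] ? static_cast<T>(1) : static_cast<T>(0); }\n  }\n};\n\n// Stand-in for GREATER_INPLACE_DATA_TYPE_SEQ_CPU (floating + signed integer types).\n#define DTYPE_SEQ_SKETCH(M) M(float) M(double) M(int8_t) M(int32_t) M(int64_t)\n\n// Stand-in for the INSTANTIATE_* macros: one explicit instantiation per dtype.\n#define INSTANTIATE_SKETCH(T) template struct GreaterUtilSketch<T>;\nDTYPE_SEQ_SKETCH(INSTANTIATE_SKETCH)\n#undef INSTANTIATE_SKETCH\n#undef DTYPE_SEQ_SKETCH\n\nint main() {\n  const int32_t x[2] = {3, 1};\n  const int32_t y[2] = {2, 2};\n  int32_t out[2];\n  GreaterUtilSketch<int32_t>::Forward(2, x, y, out);\n  return (out[0] == 1 && out[1] == 0) ? 0 : 1;\n}\n"
  },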
  {
    "path": "oneflow/user/kernels/greater_inplace_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/user/kernels/greater_inplace_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__global__ void GreaterInplacForwardGpu(const int64_t n, const T* x, const T* y, T* out) {\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) {\n    out[i] = x[i] > y[i] ? static_cast<T>(1) : static_cast<T>(0);\n  }\n}\n\ntemplate<typename T, typename ValueT>\n__global__ void ScalarGreaterInplacForwardGpu(const int64_t n, const T* x, const Scalar operand,\n                                              T* out) {\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) {\n    out[i] = x[i] > static_cast<T>(operand.Value<ValueT>()) ? static_cast<T>(1) : static_cast<T>(0);\n  }\n}\n\ntemplate<>\n__global__ void ScalarGreaterInplacForwardGpu<half, int64_t>(const int64_t n, const half* x,\n                                                             const Scalar operand, half* out) {\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) {\n    float operator_value = static_cast<float>(operand.Value<int64_t>());\n    out[i] = x[i] > __float2half(operator_value) ? static_cast<half>(1) : static_cast<half>(0);\n  }\n}\n\ntemplate<>\n__global__ void ScalarGreaterInplacForwardGpu<half, double>(const int64_t n, const half* x,\n                                                            const Scalar operand, half* out) {\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) {\n    float operator_value = static_cast<float>(operand.Value<double>());\n    out[i] = x[i] > __float2half(operator_value) ? 
static_cast<half>(1) : static_cast<half>(0);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nstruct GreaterInplaceKernelUtil<DeviceType::kCUDA, T> {\n  static void Forward(ep::Stream* stream, const int64_t n, const T* x, const T* y, T* out) {\n    RUN_CUDA_KERNEL((GreaterInplacForwardGpu<T>), stream, n, n, x, y, out);\n  }\n};\n\ntemplate<typename T, typename ValueT>\nstruct ScalarGreaterInplaceKernelUtil<DeviceType::kCUDA, T, ValueT> {\n  static void Forward(ep::Stream* stream, const int64_t n, const T* x, const Scalar operand,\n                      T* out) {\n    RUN_CUDA_KERNEL((ScalarGreaterInplacForwardGpu<T, ValueT>), stream, n, n, x, operand, out);\n  }\n};\n\n#define INSTANTIATE_GREATER_INPLACE_KERNEL_UTIL_CUDA(data_type, other) \\\n  template struct GreaterInplaceKernelUtil<DeviceType::kCUDA, data_type>;\nOF_PP_FOR_EACH_TUPLE(INSTANTIATE_GREATER_INPLACE_KERNEL_UTIL_CUDA,\n                     GREATER_INPLACE_DATA_TYPE_SEQ_CUDA)\n#undef INSTANTIATE_GREATER_INPLACE_KERNEL_UTIL_CUDA\n\n#define INSTANTIATE_SCALAR_GREATER_INPLACE_KERNEL_UTIL_CUDA(data_type, value_data_type)          \\\n  template struct ScalarGreaterInplaceKernelUtil<DeviceType::kCUDA, OF_PP_PAIR_FIRST(data_type), \\\n                                                 OF_PP_PAIR_FIRST(value_data_type)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SCALAR_GREATER_INPLACE_KERNEL_UTIL_CUDA,\n                                 GREATER_INPLACE_DATA_TYPE_SEQ_CUDA, SCALAR_VALUE_DATA_TYPE_SEQ)\n#undef INSTANTIATE_SCALAR_GREATER_INPLACE_KERNEL_UTIL_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/greater_inplace_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_GREATER_INPLACE_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_GREATER_INPLACE_KERNEL_UTIL_H_\n\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, typename ValueT>\nstruct ScalarGreaterInplaceKernelUtil {\n  static void Forward(ep::Stream* stream, const int64_t n, const T* x, const Scalar operand,\n                      T* out);\n};\n\ntemplate<DeviceType device_type, typename T>\nstruct GreaterInplaceKernelUtil {\n  static void Forward(ep::Stream* stream, const int64_t n, const T* x, const T* y, T* out);\n  static void YBroadcastToX(ep::Stream* stream, const int64_t n, const T* x, const T* y,\n                            T* broadcast_y, const ShapeView& x_shape, const ShapeView& y_shape) {\n    const int64_t x_ndim = x_shape.NumAxes();\n    const int64_t y_ndim = y_shape.NumAxes();\n    const int64_t num_prepend = x_ndim - y_ndim;\n    std::vector<int64_t> prepend_shape(num_prepend, 1);\n    std::vector<int32_t> broadcast_axes;\n    for (int i = 0; i < y_ndim; ++i) { prepend_shape.emplace_back(y_shape.At(i)); }\n    for (int i = 0; i < num_prepend; ++i) { broadcast_axes.emplace_back(i); }\n    for (int i = num_prepend; i < prepend_shape.size(); ++i) {\n      if (prepend_shape[i] != x_shape.At(i)) {\n        if (prepend_shape[i] == 1) { broadcast_axes.emplace_back(i); }\n      }\n    }\n    const Shape& reduced_shape =\n        CreateReducedShapeOrOnesShape(x_shape, {broadcast_axes.begin(), broadcast_axes.end()});\n    NdarrayUtil<device_type, T>::BroadcastTo(stream, XpuVarNdarray<T>(x_shape, broadcast_y),\n                                             XpuVarNdarray<const T>(reduced_shape, y));\n  }\n};\n\n#define SCALAR_VALUE_DATA_TYPE_SEQ                \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) \\\n  OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)\n\n#define GREATER_INPLACE_DATA_TYPE_SEQ_CPU \\\n  FLOATING_DATA_TYPE_SEQ                  \\\n  SIGNED_INT_DATA_TYPE_SEQ\n\n#ifdef WITH_CUDA\n#define GREATER_INPLACE_DATA_TYPE_SEQ_CUDA \\\n  FLOATING_DATA_TYPE_SEQ                   \\\n  SIGNED_INT_DATA_TYPE_SEQ                 \\\n  HALF_DATA_TYPE_SEQ\n#endif\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_GREATER_INPLACE_KERNEL_UTIL_H_\n"
  },
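  {
    "path": "editor_sketches/broadcast_axes_sketch.cpp",
    "content": "// Editor's sketch: a hypothetical standalone file, NOT part of the OneFlow tree, added for\n// illustration only. It extracts the shape logic of YBroadcastToX from\n// greater_inplace_kernel_util.h above into plain C++: y's shape is right-aligned against x's\n// shape by prepending 1s, and every prepended axis, plus every original y axis of size 1 that\n// must stretch to match x, becomes a broadcast axis. BroadcastAxes is an illustrative name.\n#include <cstdint>\n#include <cstdio>\n#include <vector>\n\nstd::vector<int32_t> BroadcastAxes(const std::vector<int64_t>& x_shape,\n                                   const std::vector<int64_t>& y_shape) {\n  const int64_t num_prepend = static_cast<int64_t>(x_shape.size() - y_shape.size());\n  std::vector<int64_t> prepend_shape(num_prepend, 1);  // right-align y under x\n  prepend_shape.insert(prepend_shape.end(), y_shape.begin(), y_shape.end());\n  std::vector<int32_t> axes;\n  for (int64_t i = 0; i < num_prepend; ++i) { axes.push_back(static_cast<int32_t>(i)); }\n  for (size_t i = static_cast<size_t>(num_prepend); i < prepend_shape.size(); ++i) {\n    if (prepend_shape[i] != x_shape[i] && prepend_shape[i] == 1) {\n      axes.push_back(static_cast<int32_t>(i));\n    }\n  }\n  return axes;\n}\n\nint main() {\n  // x: (2, 3, 4), y: (3, 1) -> prepended y: (1, 3, 1) -> broadcast axes {0, 2}.\n  for (int32_t a : BroadcastAxes({2, 3, 4}, {3, 1})) { std::printf(\"%d \", a); }\n  std::printf(\"\\n\");\n  return 0;\n}\n"
  },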
  {
    "path": "oneflow/user/kernels/grid_sample_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/framework/config_def.h\"\n#include \"grid_sample_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename data_type>\nclass GridSampleKernel final : public user_op::OpKernel {\n public:\n  GridSampleKernel() = default;\n  ~GridSampleKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const user_op::Tensor* grid = ctx->Tensor4ArgNameAndIndex(\"grid\", 0);\n    user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    const std::string interpolation_mode = ctx->Attr<std::string>(\"interpolation_mode\");\n    const std::string padding_mode = ctx->Attr<std::string>(\"padding_mode\");\n    GridSamplerInterpolation interpolation = StringToGridSamplerInterpolation(interpolation_mode);\n    GridSamplerPadding padding = StringToGridGridSamplerPadding(padding_mode);\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n\n    const ShapeView& input_shape = input->shape_view();\n    const ShapeView& grid_shape = grid->shape_view();\n    const ShapeView& output_shape = output->shape_view();\n    int64_t count = output_shape.elem_cnt() / input_shape.At(1);\n\n    if (input_shape.NumAxes() == 4) {\n      if (!CanUse32BitIndex({input_shape, grid_shape, output_shape})) {\n        GridSampleKernelUtil<device_type, data_type, int64_t>::Forward4D(\n            ctx, input, grid, output, interpolation, padding, align_corners, input_shape,\n            grid_shape, output_shape, count);\n      } else {\n        GridSampleKernelUtil<device_type, data_type, int32_t>::Forward4D(\n            ctx, input, grid, output, interpolation, padding, align_corners, input_shape,\n            grid_shape, output_shape, count);\n      }\n    } else {\n      if (!CanUse32BitIndex({input_shape, grid_shape, output_shape})) {\n        GridSampleKernelUtil<device_type, data_type, int64_t>::Forward5D(\n            ctx, input, grid, output, interpolation, padding, align_corners, input_shape,\n            grid_shape, output_shape, count);\n      } else {\n        GridSampleKernelUtil<device_type, data_type, int32_t>::Forward5D(\n            ctx, input, grid, output, interpolation, padding, align_corners, input_shape,\n            grid_shape, output_shape, count);\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_GRID_SAMPLE_KERNEL(device, dtype)          \\\n  REGISTER_USER_KERNEL(\"grid_sample\")                       \\\n      .SetCreateFn<GridSampleKernel<device, dtype>>()       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device) \\\n                       && 
(user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value))\n\nREGISTER_GRID_SAMPLE_KERNEL(DeviceType::kCPU, float);\nREGISTER_GRID_SAMPLE_KERNEL(DeviceType::kCPU, double);\n#ifdef WITH_CUDA\nREGISTER_GRID_SAMPLE_KERNEL(DeviceType::kCUDA, float);\nREGISTER_GRID_SAMPLE_KERNEL(DeviceType::kCUDA, double);\n#endif\n\ntemplate<DeviceType device_type, typename data_type>\nclass GridSampleGradKernel final : public user_op::OpKernel {\n public:\n  GridSampleGradKernel() = default;\n  ~GridSampleGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* doutput = ctx->Tensor4ArgNameAndIndex(\"doutput\", 0);\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const user_op::Tensor* grid = ctx->Tensor4ArgNameAndIndex(\"grid\", 0);\n    user_op::Tensor* dinput = ctx->Tensor4ArgNameAndIndex(\"dinput\", 0);\n    user_op::Tensor* dgrid = ctx->Tensor4ArgNameAndIndex(\"dgrid\", 0);\n    const std::string interpolation_mode = ctx->Attr<std::string>(\"interpolation_mode\");\n    const std::string padding_mode = ctx->Attr<std::string>(\"padding_mode\");\n    GridSamplerInterpolation interpolation = StringToGridSamplerInterpolation(interpolation_mode);\n    GridSamplerPadding padding = StringToGridSamplerPadding(padding_mode);\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n\n    const ShapeView& input_shape = input->shape_view();\n    const ShapeView& grid_shape = grid->shape_view();\n    const ShapeView& output_shape = doutput->shape_view();\n    int64_t count = output_shape.elem_cnt() / input_shape.At(1);\n\n    Memset<device_type>(ctx->stream(), dinput->mut_dptr<data_type>(), 0,\n                        input_shape.elem_cnt() * sizeof(data_type));\n\n    if (input_shape.NumAxes() == 4) {\n      if (!CanUse32BitIndex({input_shape, grid_shape, output_shape})) {\n        GridSampleKernelUtil<device_type, data_type, int64_t>::Backward4D(\n            ctx, doutput, input, grid, dinput, dgrid, interpolation, padding, align_corners,\n            input_shape, grid_shape, output_shape, count);\n      } else {\n        GridSampleKernelUtil<device_type, data_type, int32_t>::Backward4D(\n            ctx, doutput, input, grid, dinput, dgrid, interpolation, padding, align_corners,\n            input_shape, grid_shape, output_shape, count);\n      }\n    } else {\n      if (!CanUse32BitIndex({input_shape, grid_shape, output_shape})) {\n        GridSampleKernelUtil<device_type, data_type, int64_t>::Backward5D(\n            ctx, doutput, input, grid, dinput, dgrid, interpolation, padding, align_corners,\n            input_shape, grid_shape, output_shape, count);\n      } else {\n        GridSampleKernelUtil<device_type, data_type, int32_t>::Backward5D(\n            ctx, doutput, input, grid, dinput, dgrid, interpolation, padding, align_corners,\n            input_shape, grid_shape, output_shape, count);\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_GRID_SAMPLE_GRAD_KERNEL(device, dtype)     \\\n  REGISTER_USER_KERNEL(\"grid_sample_grad\")                  \\\n      .SetCreateFn<GridSampleGradKernel<device, dtype>>()   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device) \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value))\n\nREGISTER_GRID_SAMPLE_GRAD_KERNEL(DeviceType::kCPU, float);\nREGISTER_GRID_SAMPLE_GRAD_KERNEL(DeviceType::kCPU, double);\n#ifdef 
WITH_CUDA\nREGISTER_GRID_SAMPLE_GRAD_KERNEL(DeviceType::kCUDA, float);\nREGISTER_GRID_SAMPLE_GRAD_KERNEL(DeviceType::kCUDA, double);\n#endif\n\n}  // namespace oneflow\n"
  },
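  {
    "path": "editor_sketches/index_type_dispatch_sketch.cpp",
    "content": "// Editor's sketch: a hypothetical standalone file, NOT part of the OneFlow tree, added for\n// illustration only. It shows the index-type dispatch used by GridSampleKernel above: offsets\n// are computed in int32_t when every involved tensor has fewer than INT32_MAX elements (32-bit\n// arithmetic is cheaper, especially on GPU), with an int64_t fallback otherwise.\n// CanUse32BitIndexSketch / RunSketch are illustrative stand-ins.\n#include <cstdint>\n#include <cstdio>\n#include <initializer_list>\n#include <limits>\n\nbool CanUse32BitIndexSketch(std::initializer_list<int64_t> elem_cnts) {\n  for (int64_t n : elem_cnts) {\n    if (n >= std::numeric_limits<int32_t>::max()) { return false; }\n  }\n  return true;\n}\n\ntemplate<typename index_type>\nvoid RunSketch(int64_t count) {\n  std::printf(\"index width: %zu bytes, count=%lld\\n\", sizeof(index_type),\n              static_cast<long long>(count));\n}\n\nint main() {\n  const int64_t input_cnt = 1 << 20;\n  const int64_t grid_cnt = 1 << 18;\n  const int64_t output_cnt = 1 << 18;\n  if (CanUse32BitIndexSketch({input_cnt, grid_cnt, output_cnt})) {\n    RunSketch<int32_t>(output_cnt);  // every shape fits: use 32-bit offsets\n  } else {\n    RunSketch<int64_t>(output_cnt);\n  }\n  return 0;\n}\n"
  },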
  {
    "path": "oneflow/user/kernels/grid_sample_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"grid_sample_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename data_type, typename index_type>\nstruct GridSampleKernelUtil<DeviceType::kCPU, data_type, index_type> final {\n  static void Forward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input,\n                        const user_op::Tensor* grid, user_op::Tensor* output,\n                        GridSamplerInterpolation interpolation, GridSamplerPadding padding,\n                        const bool align_corners, const ShapeView& input_shape,\n                        const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count) {\n    GridSampler4DKernel<data_type, index_type>(\n        count, input->dptr<data_type>(), grid->dptr<data_type>(), output->mut_dptr<data_type>(),\n        input_shape.At(0), input_shape.At(1), input_shape.At(2), input_shape.At(3),\n        output_shape.At(2), output_shape.At(3), interpolation, padding, align_corners);\n  }\n\n  static void Forward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input,\n                        const user_op::Tensor* grid, user_op::Tensor* output,\n                        GridSamplerInterpolation interpolation, GridSamplerPadding padding,\n                        const bool align_corners, const ShapeView& input_shape,\n                        const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count) {\n    GridSampler5DKernel<data_type, index_type>(\n        count, input->dptr<data_type>(), grid->dptr<data_type>(), output->mut_dptr<data_type>(),\n        input_shape.At(0), input_shape.At(1), input_shape.At(2), input_shape.At(3),\n        input_shape.At(4), output_shape.At(2), output_shape.At(3), output_shape.At(4),\n        interpolation, padding, align_corners);\n  }\n\n  static void Backward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput,\n                         const user_op::Tensor* input, const user_op::Tensor* grid,\n                         user_op::Tensor* dinput, user_op::Tensor* dgrid,\n                         GridSamplerInterpolation interpolation, GridSamplerPadding padding,\n                         const bool align_corners, const ShapeView& input_shape,\n                         const ShapeView& grid_shape, const ShapeView& output_shape,\n                         int64_t count) {\n    GridSampler4DBackwardKernel<data_type, index_type>(\n        count, doutput->dptr<data_type>(), input->dptr<data_type>(), grid->dptr<data_type>(),\n        dinput->mut_dptr<data_type>(), dgrid->mut_dptr<data_type>(), input_shape.At(0),\n        input_shape.At(1), input_shape.At(2), input_shape.At(3), output_shape.At(2),\n        output_shape.At(3), interpolation, padding, align_corners, input_shape.elem_cnt());\n  }\n\n  static void Backward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput,\n                         const user_op::Tensor* input, 
const user_op::Tensor* grid,\n                         user_op::Tensor* dinput, user_op::Tensor* dgrid,\n                         GridSamplerInterpolation interpolation, GridSamplerPadding padding,\n                         const bool align_corners, const ShapeView& input_shape,\n                         const ShapeView& grid_shape, const ShapeView& output_shape,\n                         int64_t count) {\n    GridSampler5DBackwardKernel<data_type, index_type>(\n        count, doutput->dptr<data_type>(), input->dptr<data_type>(), grid->dptr<data_type>(),\n        dinput->mut_dptr<data_type>(), dgrid->mut_dptr<data_type>(), input_shape.At(0),\n        input_shape.At(1), input_shape.At(2), input_shape.At(3), input_shape.At(4),\n        output_shape.At(2), output_shape.At(3), output_shape.At(4), interpolation, padding,\n        align_corners, input_shape.elem_cnt());\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_GRID_SAMPLE_KERNEL_UTIL, (DeviceType::kCPU),\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ);\n\n}  // namespace oneflow\n"
  },
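  {
    "path": "editor_sketches/xpu_loop_sketch.cpp",
    "content": "// Editor's sketch: a hypothetical standalone file, NOT part of the OneFlow tree, added for\n// illustration only. The grid_sample kernels above are written once as OF_DEVICE_FUNC bodies\n// over an abstract 1-D index loop (XPU_1D_KERNEL_LOOP) and are then driven either by the plain\n// host calls in grid_sample_kernel_util.cpp or by the thin __global__ wrappers in the .cu file.\n// XPU_LOOP_SKETCH below is a toy stand-in: it expands to a plain for loop on the host; under a\n// CUDA compiler the same body would sit inside a __global__ function as a grid-stride loop.\n// Only the CPU branch is exercised here.\n#include <cstdint>\n#include <cstdio>\n\n#if defined(__CUDACC__)\n#define XPU_LOOP_SKETCH(i, n) \\\n  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += gridDim.x * blockDim.x)\n#else\n#define XPU_LOOP_SKETCH(i, n) for (int64_t i = 0; i < (n); ++i)\n#endif\n\ntemplate<typename T>\nvoid ScaleKernelSketch(int64_t n, const T* in, T scale, T* out) {\n  XPU_LOOP_SKETCH(i, n) { out[i] = in[i] * scale; }\n}\n\nint main() {\n  float in[4] = {1, 2, 3, 4};\n  float out[4];\n  ScaleKernelSketch<float>(4, in, 2.0f, out);\n  for (float v : out) { std::printf(\"%g \", v); }  // prints: 2 4 6 8\n  std::printf(\"\\n\");\n  return 0;\n}\n"
  },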
  {
    "path": "oneflow/user/kernels/grid_sample_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\n#include \"grid_sample_kernel_util.h\"\n\nnamespace oneflow {\n\nclass CudnnGridSampleDesc final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudnnGridSampleDesc);\n  CudnnGridSampleDesc(DataType data_type, const ShapeView& shape) {\n    std::vector<int> tensor_dim({shape.ptr(), shape.ptr() + shape.NumAxes()});\n    OF_CUDNN_CHECK(cudnnCreateSpatialTransformerDescriptor(&val_));\n    OF_CUDNN_CHECK(cudnnSetSpatialTransformerNdDescriptor(val_, CUDNN_SAMPLER_BILINEAR,\n                                                          GetCudnnDataType(data_type),\n                                                          shape.NumAxes(), tensor_dim.data()));\n  }\n\n  ~CudnnGridSampleDesc() { OF_CUDNN_CHECK(cudnnDestroySpatialTransformerDescriptor(val_)); }\n\n  const cudnnSpatialTransformerDescriptor_t& Get() const { return val_; }\n\n private:\n  cudnnSpatialTransformerDescriptor_t val_;\n};\n\ntemplate<typename T>\nstruct CudnnGridSampleKernelUtil {\n  static bool CanRunWithCudnn(user_op::KernelComputeContext* ctx) {\n    if (ctx->Attr<std::string>(\"interpolation_mode\") != \"bilinear\"\n        || ctx->Attr<std::string>(\"padding_mode\") != \"zeros\" || !ctx->Attr<bool>(\"align_corners\")) {\n      return false;\n    }\n    const ShapeView& input_shape = ctx->Tensor4ArgNameAndIndex(\"input\", 0)->shape_view();\n    if (input_shape.NumAxes() != 4 || input_shape.At(1) > 1024) { return false; }\n\n    return true;\n  }\n\n  static void ForwardCompute(user_op::KernelComputeContext* ctx) {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const user_op::Tensor* grid = ctx->Tensor4ArgNameAndIndex(\"grid\", 0);\n    user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    const ShapeView& input_shape = input->shape_view();\n    const ShapeView& output_shape = output->shape_view();\n    const DataType dtype = input->data_type();\n\n    CudnnTensorDesc input_desc(dtype, input_shape, \"channels_first\");\n    CudnnTensorDesc output_desc(dtype, output_shape, \"channels_first\");\n    CudnnGridSampleDesc transfomer_desc(dtype, output_shape);\n\n    OF_CUDNN_CHECK(cudnnSpatialTfSamplerForward(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), transfomer_desc.Get(),\n        CudnnSPOnePtr<T>(), input_desc.Get(), input->dptr(), grid->dptr(), CudnnSPZeroPtr<T>(),\n        output_desc.Get(), output->mut_dptr()));\n  }\n\n  static void BackwardCompute(user_op::KernelComputeContext* ctx) {\n    const user_op::Tensor* doutput = ctx->Tensor4ArgNameAndIndex(\"doutput\", 0);\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const user_op::Tensor* grid = ctx->Tensor4ArgNameAndIndex(\"grid\", 0);\n    user_op::Tensor* dinput = ctx->Tensor4ArgNameAndIndex(\"dinput\", 0);\n 
   user_op::Tensor* dgrid = ctx->Tensor4ArgNameAndIndex(\"dgrid\", 0);\n    const ShapeView& input_shape = input->shape_view();\n    const ShapeView& output_shape = doutput->shape_view();\n    const ShapeView& dinput_shape = dinput->shape_view();\n    const DataType dtype = input->data_type();\n\n    CudnnTensorDesc input_desc(dtype, input_shape, \"channels_first\");\n    CudnnTensorDesc output_desc(dtype, output_shape, \"channels_first\");\n    CudnnTensorDesc dinput_desc(dtype, dinput_shape, \"channels_first\");\n    CudnnGridSampleDesc transformer_desc(dtype, output_shape);\n\n    OF_CUDNN_CHECK(cudnnSpatialTfSamplerBackward(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), transformer_desc.Get(),\n        CudnnSPOnePtr<T>(), input_desc.Get(), input->dptr(), CudnnSPZeroPtr<T>(), dinput_desc.Get(),\n        dinput->mut_dptr(), CudnnSPOnePtr<T>(), output_desc.Get(), doutput->dptr(), grid->dptr(),\n        CudnnSPZeroPtr<T>(), dgrid->mut_dptr()));\n  }\n};\n\ntemplate<typename data_type, typename index_type>\n__launch_bounds__(256) __global__\n    void CUDAGridSampler4DKernel(const index_type nthreads, const data_type* input_ptr,\n                                 const data_type* grid_ptr, data_type* output_ptr, index_type N,\n                                 index_type C, index_type inp_H, index_type inp_W, index_type out_H,\n                                 index_type out_W,\n                                 const GridSamplerInterpolation interpolation_mode,\n                                 const GridSamplerPadding padding_mode, const bool align_corners) {\n  GridSampler4DKernel(nthreads, input_ptr, grid_ptr, output_ptr, N, C, inp_H, inp_W, out_H, out_W,\n                      interpolation_mode, padding_mode, align_corners);\n}\n\ntemplate<typename data_type, typename index_type>\n__launch_bounds__(512) __global__\n    void CUDAGridSampler5DKernel(const index_type nthreads, const data_type* input_ptr,\n                                 const data_type* grid_ptr, data_type* output_ptr, index_type N,\n                                 index_type C, index_type inp_D, index_type inp_H, index_type inp_W,\n                                 index_type out_D, index_type out_H, index_type out_W,\n                                 const GridSamplerInterpolation interpolation_mode,\n                                 const GridSamplerPadding padding_mode, const bool align_corners) {\n  GridSampler5DKernel(nthreads, input_ptr, grid_ptr, output_ptr, N, C, inp_D, inp_H, inp_W, out_D,\n                      out_H, out_W, interpolation_mode, padding_mode, align_corners);\n}\n\ntemplate<typename data_type, typename index_type>\n__launch_bounds__(256) __global__ void CUDAGridSampler4DBackwardKernel(\n    const index_type nthreads, const data_type* grad_output_ptr, const data_type* input_ptr,\n    const data_type* grid_ptr, data_type* grad_input_ptr, data_type* grad_grid_ptr, index_type N,\n    index_type C, index_type inp_H, index_type inp_W, index_type out_H, index_type out_W,\n    const GridSamplerInterpolation interpolation_mode, const GridSamplerPadding padding_mode,\n    const bool align_corners, const index_type grad_input_memory_span) {\n  GridSampler4DBackwardKernel(nthreads, grad_output_ptr, input_ptr, grid_ptr, grad_input_ptr,\n                              grad_grid_ptr, N, C, inp_H, inp_W, out_H, out_W, interpolation_mode,\n                              padding_mode, align_corners, grad_input_memory_span);\n}\n\ntemplate<typename data_type, typename index_type>\n__launch_bounds__(256) 
__global__ void CUDAGridSampler5DBackwardKernel(\n    const index_type nthreads, const data_type* grad_output_ptr, const data_type* input_ptr,\n    const data_type* grid_ptr, data_type* grad_input_ptr, data_type* grad_grid_ptr, index_type N,\n    index_type C, index_type inp_D, index_type inp_H, index_type inp_W, index_type out_D,\n    index_type out_H, index_type out_W, const GridSamplerInterpolation interpolation_mode,\n    const GridSamplerPadding padding_mode, const bool align_corners,\n    const index_type grad_input_memory_span) {\n  GridSampler5DBackwardKernel(nthreads, grad_output_ptr, input_ptr, grid_ptr, grad_input_ptr,\n                              grad_grid_ptr, N, C, inp_D, inp_H, inp_W, out_D, out_H, out_W,\n                              interpolation_mode, padding_mode, align_corners,\n                              grad_input_memory_span);\n}\n\ntemplate<typename data_type, typename index_type>\nstruct GridSampleKernelUtil<DeviceType::kCUDA, data_type, index_type> final {\n  static void Forward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input,\n                        const user_op::Tensor* grid, user_op::Tensor* output,\n                        GridSamplerInterpolation interpolation, GridSamplerPadding padding,\n                        const bool align_corners, const ShapeView& input_shape,\n                        const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count) {\n    if (CudnnGridSampleKernelUtil<data_type>::CanRunWithCudnn(ctx)\n        && CanUse32BitIndex({input_shape, grid_shape, output_shape})) {\n      return CudnnGridSampleKernelUtil<data_type>::ForwardCompute(ctx);\n    }\n\n    CUDAGridSampler4DKernel<data_type, index_type>\n        <<<GridSampleGetBlocks(count, 256), 256, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n            count, input->dptr<data_type>(), grid->dptr<data_type>(), output->mut_dptr<data_type>(),\n            input_shape.At(0), input_shape.At(1), input_shape.At(2), input_shape.At(3),\n            output_shape.At(2), output_shape.At(3), interpolation, padding, align_corners);\n  }\n  static void Forward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input,\n                        const user_op::Tensor* grid, user_op::Tensor* output,\n                        GridSamplerInterpolation interpolation, GridSamplerPadding padding,\n                        const bool align_corners, const ShapeView& input_shape,\n                        const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count) {\n    CUDAGridSampler5DKernel<data_type, index_type>\n        <<<GridSampleGetBlocks(count, 512), 512, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n            count, input->dptr<data_type>(), grid->dptr<data_type>(), output->mut_dptr<data_type>(),\n            input_shape.At(0), input_shape.At(1), input_shape.At(2), input_shape.At(3),\n            input_shape.At(4), output_shape.At(2), output_shape.At(3), output_shape.At(4),\n            interpolation, padding, align_corners);\n  }\n\n  static void Backward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput,\n                         const user_op::Tensor* input, const user_op::Tensor* grid,\n                         user_op::Tensor* dinput, user_op::Tensor* dgrid,\n                         GridSamplerInterpolation interpolation, GridSamplerPadding padding,\n                         const bool align_corners, const ShapeView& input_shape,\n                         
const ShapeView& grid_shape, const ShapeView& output_shape,\n                         int64_t count) {\n    if (CudnnGridSampleKernelUtil<data_type>::CanRunWithCudnn(ctx)\n        && CanUse32BitIndex({input_shape, grid_shape, output_shape})) {\n      return CudnnGridSampleKernelUtil<data_type>::BackwardCompute(ctx);\n    }\n\n    CUDAGridSampler4DBackwardKernel<data_type, index_type>\n        <<<GridSampleGetBlocks(count, 256), 256, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n            count, doutput->dptr<data_type>(), input->dptr<data_type>(), grid->dptr<data_type>(),\n            dinput->mut_dptr<data_type>(), dgrid->mut_dptr<data_type>(), input_shape.At(0),\n            input_shape.At(1), input_shape.At(2), input_shape.At(3), output_shape.At(2),\n            output_shape.At(3), interpolation, padding, align_corners, input_shape.elem_cnt());\n  }\n  static void Backward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput,\n                         const user_op::Tensor* input, const user_op::Tensor* grid,\n                         user_op::Tensor* dinput, user_op::Tensor* dgrid,\n                         GridSamplerInterpolation interpolation, GridSamplerPadding padding,\n                         const bool align_corners, const ShapeView& input_shape,\n                         const ShapeView& grid_shape, const ShapeView& output_shape,\n                         int64_t count) {\n    CUDAGridSampler5DBackwardKernel<data_type, index_type>\n        <<<GridSampleGetBlocks(count, 256), 256, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n            count, doutput->dptr<data_type>(), input->dptr<data_type>(), grid->dptr<data_type>(),\n            dinput->mut_dptr<data_type>(), dgrid->mut_dptr<data_type>(), input_shape.At(0),\n            input_shape.At(1), input_shape.At(2), input_shape.At(3), input_shape.At(4),\n            output_shape.At(2), output_shape.At(3), output_shape.At(4), interpolation, padding,\n            align_corners, input_shape.elem_cnt());\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_GRID_SAMPLE_KERNEL_UTIL, (DeviceType::kCUDA),\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ);\n\n}  // namespace oneflow\n"
  },
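  {
    "path": "editor_sketches/launch_blocks_sketch.cpp",
    "content": "// Editor's sketch: a hypothetical standalone file, NOT part of the OneFlow tree, added for\n// illustration only. It reproduces the launch-size arithmetic of GridSampleGetBlocks, which the\n// CUDA file above uses as <<<GridSampleGetBlocks(count, 256), 256, ...>>>. For n >= 1,\n// (n - 1) / threads + 1 equals ceil(n / threads), and unlike (n + threads - 1) / threads it\n// cannot overflow for n near INT64_MAX. GetBlocksSketch is an illustrative name.\n#include <cstdint>\n#include <cstdio>\n\nint GetBlocksSketch(int64_t n, int64_t threads_per_block) {\n  return static_cast<int>((n - 1) / threads_per_block + 1);\n}\n\nint main() {\n  std::printf(\"%d\\n\", GetBlocksSketch(1, 256));    // 1\n  std::printf(\"%d\\n\", GetBlocksSketch(256, 256));  // 1\n  std::printf(\"%d\\n\", GetBlocksSketch(257, 256));  // 2\n  return 0;\n}\n"
  },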
  {
    "path": "oneflow/user/kernels/grid_sample_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_GRID_SAMPLE_KERNEL_H_\n#define ONEFLOW_USER_KERNELS_GRID_SAMPLE_KERNEL_H_\n\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/ndarray/xpu_util.h\"\n#include \"oneflow/user/kernels/clip_by_value_kernel.h\"\n#ifdef WITH_CUDA\n#include \"oneflow/core/cuda/atomic.cuh\"\n#endif  // WITH_CUDA\n\nnamespace oneflow {\n\nenum class GridSamplerInterpolation { kBilinear = 0, kNearest, kBicubic };\n\nenum class GridSamplerPadding { kZeros = 0, kBorder, kReflection };\n\nstatic GridSamplerInterpolation StringToGridSamplerInterpolation(const std::string& mode) {\n  if (mode == \"bilinear\") {\n    return GridSamplerInterpolation::kBilinear;\n  } else if (mode == \"nearest\") {\n    return GridSamplerInterpolation::kNearest;\n  }\n  return GridSamplerInterpolation::kBicubic;\n}\nstatic GridSamplerPadding StringToGridGridSamplerPadding(const std::string& mode) {\n  if (mode == \"zeros\") {\n    return GridSamplerPadding::kZeros;\n  } else if (mode == \"border\") {\n    return GridSamplerPadding::kBorder;\n  }\n  return GridSamplerPadding::kReflection;\n}\nstatic bool CanUse32BitIndex(const std::initializer_list<ShapeView>& shapes) {\n  for (const auto& shape : shapes) {\n    if (shape.elem_cnt() >= std::numeric_limits<int32_t>::max()) { return false; }\n  }\n  return true;\n}\n\ninline int GridSampleGetBlocks(const int64_t number, const int64_t threads_per_block) {\n  // Round up division for positive number that cannot cause integer overflow\n  auto block_num = (number - 1) / threads_per_block + 1;\n  return static_cast<int>(block_num);\n}\n\n// This kernel implement is referenced from:\n// https://github.com/pytorch/pytorch with git commit id: e7724bb\n// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/GridSampler.cu\n// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/GridSampler.cuh\n\n// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,\n// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).\n// if align_corners: -1 and +1 get sent to the centers of the corner pixels\n//     -1 --> 0\n//     +1 --> (size - 1)\n//     scale_factor = (size - 1) / 2\n// if not align_corners: -1 and +1 get sent to the image edges\n//     -1 --> -0.5\n//     +1 --> (size - 1) + 0.5 == size - 0.5\n//     scale_factor = size / 2\ntemplate<typename scalar_t>\nstatic OF_DEVICE_FUNC scalar_t GridSamplerUnnormalize(scalar_t coord, int size,\n                                                      bool align_corners) {\n  if (align_corners) {\n    // unnormalize coord from [-1, 1] to [0, size - 1]\n    return ((coord + 1.f) / 2) * (size - 1);\n  } else {\n    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]\n    return ((coord + 1.f) * size - 1) / 2;\n  }\n}\n\n// 
GridSamplerUnnormalizeSetGrad works the same as GridSamplerUnnormalize\n// except that it also returns the `d output / d input` via pointer argument\n// `grad_in`.\n// This is useful in the backward pass of grid_sampler.\ntemplate<typename scalar_t>\nstatic OF_DEVICE_FUNC scalar_t GridSamplerUnnormalizeSetGrad(scalar_t coord, int size,\n                                                             bool align_corners,\n                                                             scalar_t* grad_in) {\n  if (align_corners) {\n    // unnormalize coord from [-1, 1] to [0, size - 1]\n    *grad_in = static_cast<scalar_t>(size - 1) / 2;\n    return ((coord + 1.f) / 2) * (size - 1);\n  } else {\n    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]\n    *grad_in = static_cast<scalar_t>(size) / 2;\n    return ((coord + 1.f) * size - 1) / 2;\n  }\n}\n\n// Clips coordinates to between 0 and clip_limit - 1\ntemplate<typename scalar_t>\nstatic OF_DEVICE_FUNC scalar_t ClipCoordinates(scalar_t in, int clip_limit) {\n  return DeviceMin(static_cast<scalar_t>(clip_limit - 1), DeviceMax(in, static_cast<scalar_t>(0)));\n}\n\n// ClipCoordinatesSetGrad works similarly to ClipCoordinates except that\n// it also returns the `d output / d input` via pointer argument `grad_in`.\n// This is useful in the backward pass of grid_sampler.\ntemplate<typename scalar_t>\nstatic OF_DEVICE_FUNC scalar_t ClipCoordinatesSetGrad(scalar_t in, int clip_limit,\n                                                      scalar_t* grad_in) {\n  // Note that it is important for the gradient calculation that borders\n  // are considered out of bounds.\n  if (in <= static_cast<scalar_t>(0)) {\n    *grad_in = static_cast<scalar_t>(0);\n    return static_cast<scalar_t>(0);\n  } else {\n    scalar_t max = static_cast<scalar_t>(clip_limit - 1);\n    if (in >= max) {\n      *grad_in = static_cast<scalar_t>(0);\n      return max;\n    } else {\n      *grad_in = static_cast<scalar_t>(1);\n      return in;\n    }\n  }\n}\n\n// Reflects coordinates until they fall between low and high (inclusive).\n// The bounds are passed as twice their value so that half-integer values\n// can be represented as ints.\ntemplate<typename scalar_t>\nstatic OF_DEVICE_FUNC scalar_t ReflectCoordinates(scalar_t in, int twice_low, int twice_high) {\n  if (twice_low == twice_high) { return static_cast<scalar_t>(0); }\n  scalar_t min = static_cast<scalar_t>(twice_low) / 2;\n  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;\n  in = fabs(in - min);\n  // `fmod` returns same sign as `in`, which is positive after the `fabs` above.\n  scalar_t extra = fmod(in, span);\n  int flips = static_cast<int>(floor(in / span));\n  if (flips % 2 == 0) {\n    return extra + min;\n  } else {\n    return span - extra + min;\n  }\n}\n\n// ReflectCoordinatesSetGrad works similarly to ReflectCoordinates except\n// that it also returns the `d output / d input` via pointer argument\n// `grad_in`.\n// This is useful in the backward pass of grid_sampler.\ntemplate<typename scalar_t>\nstatic OF_DEVICE_FUNC scalar_t ReflectCoordinatesSetGrad(scalar_t in, int twice_low, int twice_high,\n                                                         scalar_t* grad_in) {\n  if (twice_low == twice_high) {\n    *grad_in = static_cast<scalar_t>(0);\n    return static_cast<scalar_t>(0);\n  }\n  int grad_in_mult_ = 1;\n  scalar_t min = static_cast<scalar_t>(twice_low) / 2;\n  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;\n  in = in - min;\n  if (in < 
static_cast<scalar_t>(0)) {\n    grad_in_mult_ = -1;\n    in = -in;\n  } else {\n    grad_in_mult_ = 1;\n  }\n  // `fmod` returns same sign as `in`, which is positive after the `if` above.\n  scalar_t extra = fmod(in, span);\n  int flips = static_cast<int>(floor(in / span));\n  if (flips % 2 == 0) {\n    *grad_in = static_cast<scalar_t>(grad_in_mult_);\n    return extra + min;\n  } else {\n    *grad_in = static_cast<scalar_t>(-grad_in_mult_);\n    return span - extra + min;\n  }\n}\n\n#if defined(__CUDACC__)\ntemplate<typename scalar_t>\nstatic __device__ __forceinline__ scalar_t safe_downgrade_to_int_range(scalar_t x) {\n  // -100.0 does not have special meaning. This is just to make sure\n  // it's not WithinBounds2D or WithinBounds3D, and does not cause\n  // undefined behavior. See #35506.\n  // TODO(pei tingkuan): (explicit or implicit) type conversion from\n  // INT_MAX - 1 to float(INT_MAX - 1) indeed changes value from\n  // 2147483647 to 2147483648 and loses precision\n  // Reference: https://stackoverflow.com/q/526070\n  if (x > static_cast<scalar_t>(INT_MAX - 1) || x < INT_MIN || !isfinite(static_cast<double>(x)))\n    return static_cast<scalar_t>(-100.0);\n  return x;\n}\n#endif\n\ntemplate<typename scalar_t>\nstatic OF_DEVICE_FUNC scalar_t ComputeCoordinates(scalar_t coord, int size,\n                                                  GridSamplerPadding padding_mode,\n                                                  bool align_corners) {\n  if (padding_mode == GridSamplerPadding::kBorder) {\n    // clip coordinates to image borders\n    coord = ClipCoordinates(coord, size);\n  } else if (padding_mode == GridSamplerPadding::kReflection) {\n    // reflect coordinates by image borders\n    if (align_corners) {\n      coord = ReflectCoordinates(coord, 0, 2 * (size - 1));\n    } else {\n      coord = ReflectCoordinates(coord, -1, 2 * size - 1);\n    }\n    // clip coordinates to image borders\n    coord = ClipCoordinates(coord, size);\n  }\n#if defined(__CUDACC__)\n  coord = safe_downgrade_to_int_range(coord);\n#endif\n  return coord;\n}\n\n// Computes the pixel source index value for a grid coordinate\ntemplate<typename scalar_t>\nstatic OF_DEVICE_FUNC scalar_t GridSamplerComputeSourceIndex(scalar_t coord, int size,\n                                                             GridSamplerPadding padding_mode,\n                                                             bool align_corners) {\n  coord = GridSamplerUnnormalize(coord, size, align_corners);\n  coord = ComputeCoordinates(coord, size, padding_mode, align_corners);\n  return coord;\n}\n\n// GridSamplerComputeSourceIndexSetGrad works similarly to\n// GridSamplerComputeSourceIndex except that it also returns the\n// `d output / d input` via pointer argument `grad_in`.\n// This is useful in the backward pass of grid_sampler.\ntemplate<typename scalar_t>\nstatic OF_DEVICE_FUNC scalar_t GridSamplerComputeSourceIndexSetGrad(scalar_t coord, int size,\n                                                                    GridSamplerPadding padding_mode,\n                                                                    bool align_corners,\n                                                                    scalar_t* grad_in) {\n  scalar_t grad_clip, grad_refl;\n  coord = GridSamplerUnnormalizeSetGrad(coord, size, align_corners, grad_in);\n  if (padding_mode == GridSamplerPadding::kBorder) {\n    // clip coordinates to image borders\n    coord = ClipCoordinatesSetGrad(coord, size, &grad_clip);\n    *grad_in = (*grad_in) * 
grad_clip;\n  } else if (padding_mode == GridSamplerPadding::kReflection) {\n    // reflect coordinates by image borders\n    if (align_corners) {\n      coord = ReflectCoordinatesSetGrad(coord, 0, 2 * (size - 1), &grad_refl);\n    } else {\n      coord = ReflectCoordinatesSetGrad(coord, -1, 2 * size - 1, &grad_refl);\n    }\n    // clip coordinates to image borders\n    coord = ClipCoordinatesSetGrad(coord, size, &grad_clip);\n    *grad_in = (*grad_in) * grad_refl * grad_clip;\n  }\n\n#if defined(__CUDACC__)\n  coord = safe_downgrade_to_int_range(coord);\n#endif\n  return coord;\n}\n\nstatic OF_DEVICE_FUNC bool WithinBounds2D(int h, int w, int H, int W) {\n  return h >= 0 && h < H && w >= 0 && w < W;\n}\n\nstatic OF_DEVICE_FUNC bool WithinBounds3D(int d, int h, int w, int D, int H, int W) {\n  return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;\n}\n\ntemplate<typename scalar_t>\nstatic OF_DEVICE_FUNC scalar_t GetValueBounded(const scalar_t* data, scalar_t x, scalar_t y, int W,\n                                               int H, int sW, int sH,\n                                               GridSamplerPadding padding_mode,\n                                               bool align_corners) {\n  x = ComputeCoordinates(x, W, padding_mode, align_corners);\n  y = ComputeCoordinates(y, H, padding_mode, align_corners);\n\n  int ix = static_cast<int>(x);\n  int iy = static_cast<int>(y);\n\n  if (WithinBounds2D(iy, ix, H, W)) { return data[iy * sH + ix * sW]; }\n  return static_cast<scalar_t>(0);\n}\n\ntemplate<typename scalar_t, typename index_t>\nstatic OF_DEVICE_FUNC void SafeAdd2D(scalar_t* data, int h, int w, int sH, int sW, int H, int W,\n                                     scalar_t delta, const index_t NC_offset,\n                                     const index_t memory_span) {\n  if (WithinBounds2D(h, w, H, W)) {\n#if defined(__CUDACC__)\n    cuda::atomic::Add(data + NC_offset + h * sH + w * sW, delta);\n#else\n    data[NC_offset + h * sH + w * sW] += delta;\n#endif\n  }\n}\n\ntemplate<typename scalar_t, typename index_t>\nstatic OF_DEVICE_FUNC void SafeAdd3D(scalar_t* data, int d, int h, int w, int sD, int sH, int sW,\n                                     int D, int H, int W, scalar_t delta, const index_t NC_offset,\n                                     const index_t memory_span) {\n  if (WithinBounds3D(d, h, w, D, H, W)) {\n#if defined(__CUDACC__)\n    cuda::atomic::Add(data + NC_offset + d * sD + h * sH + w * sW, delta);\n#else\n    data[NC_offset + d * sD + h * sH + w * sW] += delta;\n#endif\n  }\n}\n\ntemplate<typename scalar_t, typename index_t>\nstatic OF_DEVICE_FUNC void AddValueBounded(scalar_t* data, scalar_t x, scalar_t y, int W, int H,\n                                           int sW, int sH, scalar_t delta,\n                                           GridSamplerPadding padding_mode, bool align_corners,\n                                           const index_t NC_offset, const index_t memory_span) {\n  x = ComputeCoordinates(x, W, padding_mode, align_corners);\n  y = ComputeCoordinates(y, H, padding_mode, align_corners);\n\n  int ix = static_cast<int>(x);\n  int iy = static_cast<int>(y);\n\n  SafeAdd2D(data, iy, ix, sH, sW, H, W, delta, NC_offset, memory_span);\n}\n\n// Calculate the differential of the cubic convolution, i.e. 
`d coeff / d x`\ntemplate<typename scalar_t>\nstatic OF_DEVICE_FUNC void GetCubicCoefficientsGrad(scalar_t coeffs[4], scalar_t t) {\n  // Must be the same as forward calculation in\n  // aten/src/ATen/native/cuda/UpSample.cuh:get_cubic_upsample_coefficients\n  scalar_t A = -0.75;\n\n  scalar_t x;\n  x = -1 - t;  // 1 < x = |-1 - tx| < 2\n  coeffs[0] = (-3 * A * x - 10 * A) * x - 8 * A;\n  x = -t;  // x = |0 - tx| <= 1\n  coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x;\n  x = 1 - t;  // x = |1 - tx| <= 1\n  coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x;\n  x = 2 - t;  // 1 < x = |2 - tx| < 2\n  coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A;\n}\n\n// Based on\n// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm\ntemplate<typename accscalar_t>\nOF_DEVICE_FUNC static accscalar_t CubicConvolution1(accscalar_t x, accscalar_t A) {\n  return ((A + 2) * x - (A + 3)) * x * x + 1;\n}\n\ntemplate<typename accscalar_t>\nOF_DEVICE_FUNC static accscalar_t CubicConvolution2(accscalar_t x, accscalar_t A) {\n  return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;\n}\n\ntemplate<typename accscalar_t>\nOF_DEVICE_FUNC static void GetCubicUpsamplingCoefficients(accscalar_t coeffs[4], accscalar_t t) {\n  accscalar_t A = -0.75;\n\n  accscalar_t x1 = t;\n  coeffs[0] = CubicConvolution2<accscalar_t>(x1 + 1.0, A);\n  coeffs[1] = CubicConvolution1<accscalar_t>(x1, A);\n\n  // opposite coefficients\n  accscalar_t x2 = 1.0 - t;\n  coeffs[2] = CubicConvolution1<accscalar_t>(x2, A);\n  coeffs[3] = CubicConvolution2<accscalar_t>(x2 + 1.0, A);\n}\n\ntemplate<typename scalar_t, typename accscalar_t>\nOF_DEVICE_FUNC static accscalar_t cubic_interp1d(scalar_t x0, scalar_t x1, scalar_t x2, scalar_t x3,\n                                                 accscalar_t t) {\n  accscalar_t coeffs[4];\n  GetCubicUpsamplingCoefficients<accscalar_t>(coeffs, t);\n\n  return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];\n}\n\ntemplate<typename data_type, typename index_type>\nOF_DEVICE_FUNC void GridSampler4DKernel(const index_type nthreads, const data_type* input_ptr,\n                                        const data_type* grid_ptr, data_type* output_ptr,\n                                        index_type N, index_type C, index_type inp_H,\n                                        index_type inp_W, index_type out_H, index_type out_W,\n                                        const GridSamplerInterpolation interpolation_mode,\n                                        const GridSamplerPadding padding_mode,\n                                        const bool align_corners) {\n  index_type inp_sN = C * inp_H * inp_W;\n  index_type inp_sC = inp_H * inp_W;\n  index_type inp_sH = inp_W;\n  index_type inp_sW = 1;\n  index_type grid_sN = out_H * out_W * 2;\n  index_type grid_sH = out_W * 2;\n  index_type grid_sW = 2;\n  index_type grid_sCoor = 1;\n  index_type out_sN = C * out_H * out_W;\n  index_type out_sC = out_H * out_W;\n  index_type out_sH = out_W;\n  index_type out_sW = 1;\n\n  XPU_1D_KERNEL_LOOP(index, nthreads) {\n    const index_type w = index % out_W;\n    const index_type h = (index / out_W) % out_H;\n    const index_type n = index / (out_H * out_W);\n    const index_type grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;\n\n    // get the corresponding input x, y co-ordinates from grid\n    data_type x = grid_ptr[grid_offset];\n    data_type y = grid_ptr[grid_offset + grid_sCoor];\n\n    data_type ix = GridSamplerComputeSourceIndex(x, inp_W, padding_mode, align_corners);\n    
data_type iy = GridSamplerComputeSourceIndex(y, inp_H, padding_mode, align_corners);\n\n    if (interpolation_mode == GridSamplerInterpolation::kBilinear) {\n      // get NE, NW, SE, SW pixel values from (x, y)\n      index_type ix_nw = static_cast<index_type>(::floor(ix));\n      index_type iy_nw = static_cast<index_type>(::floor(iy));\n      index_type ix_ne = ix_nw + 1;\n      index_type iy_ne = iy_nw;\n      index_type ix_sw = ix_nw;\n      index_type iy_sw = iy_nw + 1;\n      index_type ix_se = ix_nw + 1;\n      index_type iy_se = iy_nw + 1;\n\n      // get surfaces to each neighbor:\n      data_type nw = (ix_se - ix) * (iy_se - iy);\n      data_type ne = (ix - ix_sw) * (iy_sw - iy);\n      data_type sw = (ix_ne - ix) * (iy - iy_ne);\n      data_type se = (ix - ix_nw) * (iy - iy_nw);\n\n      // calculate bilinear weighted pixel value and set output pixel\n      auto inp_ptr_NC = input_ptr + n * inp_sN;\n      auto out_ptr_NCHW = output_ptr + n * out_sN + h * out_sH + w * out_sW;\n      for (index_type c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {\n        *out_ptr_NCHW = static_cast<data_type>(0);\n        if (WithinBounds2D(iy_nw, ix_nw, inp_H, inp_W)) {\n          *out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw;\n        }\n        if (WithinBounds2D(iy_ne, ix_ne, inp_H, inp_W)) {\n          *out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne;\n        }\n        if (WithinBounds2D(iy_sw, ix_sw, inp_H, inp_W)) {\n          *out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw;\n        }\n        if (WithinBounds2D(iy_se, ix_se, inp_H, inp_W)) {\n          *out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se;\n        }\n      }\n    } else if (interpolation_mode == GridSamplerInterpolation::kNearest) {\n      index_type ix_nearest = static_cast<index_type>(::round(ix));\n      index_type iy_nearest = static_cast<index_type>(::round(iy));\n\n      // assign nearest neighbor pixel value to output pixel\n      auto inp_ptr_NC = input_ptr + n * inp_sN;\n      auto out_ptr_NCHW = output_ptr + n * out_sN + h * out_sH + w * out_sW;\n      for (index_type c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {\n        if (WithinBounds2D(iy_nearest, ix_nearest, inp_H, inp_W)) {\n          *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW];\n        } else {\n          *out_ptr_NCHW = static_cast<data_type>(0);\n        }\n      }\n    } else if (interpolation_mode == GridSamplerInterpolation::kBicubic) {\n      ix = GridSamplerUnnormalize(x, inp_W, align_corners);\n      iy = GridSamplerUnnormalize(y, inp_H, align_corners);\n\n      data_type ix_nw = ::floor(ix);\n      data_type iy_nw = ::floor(iy);\n\n      const data_type tx = ix - ix_nw;\n      const data_type ty = iy - iy_nw;\n\n      auto inp_ptr_NC = input_ptr + n * inp_sN;\n      auto out_ptr_NCHW = output_ptr + n * out_sN + h * out_sH + w * out_sW;\n      for (index_type c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {\n        data_type coefficients[4];\n#ifdef __CUDA_ARCH__\n#pragma unroll 4\n#endif\n        for (index_type i = 0; i < 4; ++i) {\n          coefficients[i] = cubic_interp1d(\n              GetValueBounded<data_type>(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW,\n                                         inp_sH, padding_mode, align_corners),\n              GetValueBounded<data_type>(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, inp_sW,\n                                      
   inp_sH, padding_mode, align_corners),\n              GetValueBounded<data_type>(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW,\n                                         inp_sH, padding_mode, align_corners),\n              GetValueBounded<data_type>(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, inp_sW,\n                                         inp_sH, padding_mode, align_corners),\n              tx);\n        }\n\n        *out_ptr_NCHW =\n            cubic_interp1d(coefficients[0], coefficients[1], coefficients[2], coefficients[3], ty);\n      }\n    }\n  }\n}\n\ntemplate<typename data_type, typename index_type>\nOF_DEVICE_FUNC void GridSampler5DKernel(const index_type nthreads, const data_type* input_ptr,\n                                        const data_type* grid_ptr, data_type* output_ptr,\n                                        index_type N, index_type C, index_type inp_D,\n                                        index_type inp_H, index_type inp_W, index_type out_D,\n                                        index_type out_H, index_type out_W,\n                                        const GridSamplerInterpolation interpolation_mode,\n                                        const GridSamplerPadding padding_mode,\n                                        const bool align_corners) {\n  index_type inp_sN = C * inp_D * inp_H * inp_W;\n  index_type inp_sC = inp_D * inp_H * inp_W;\n  index_type inp_sD = inp_H * inp_W;\n  index_type inp_sH = inp_W;\n  index_type inp_sW = 1;\n  index_type grid_sN = out_D * out_H * out_W * 3;\n  index_type grid_sD = out_H * out_W * 3;\n  index_type grid_sH = out_W * 3;\n  index_type grid_sW = 3;\n  index_type grid_sCoor = 1;\n  index_type out_sN = C * out_D * out_H * out_W;\n  index_type out_sC = out_D * out_H * out_W;\n  index_type out_sD = out_H * out_W;\n  index_type out_sH = out_W;\n  index_type out_sW = 1;\n\n  XPU_1D_KERNEL_LOOP(index, nthreads) {\n    const index_type w = index % out_W;\n    const index_type h = (index / out_W) % out_H;\n    const index_type d = (index / (out_H * out_W)) % out_D;\n    const index_type n = index / (out_D * out_H * out_W);\n    const index_type grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;\n\n    // get the corresponding input x, y, z co-ordinates from grid\n    data_type ix = grid_ptr[grid_offset];\n    data_type iy = grid_ptr[grid_offset + grid_sCoor];\n    data_type iz = grid_ptr[grid_offset + 2 * grid_sCoor];\n\n    ix = GridSamplerComputeSourceIndex(ix, inp_W, padding_mode, align_corners);\n    iy = GridSamplerComputeSourceIndex(iy, inp_H, padding_mode, align_corners);\n    iz = GridSamplerComputeSourceIndex(iz, inp_D, padding_mode, align_corners);\n\n    if (interpolation_mode == GridSamplerInterpolation::kBilinear) {\n      // get corner pixel values from (x, y, z)\n      // for 4d, we used north-east-south-west\n      // for 5d, we add top-bottom\n      index_type ix_tnw = static_cast<index_type>(::floor(ix));\n      index_type iy_tnw = static_cast<index_type>(::floor(iy));\n      index_type iz_tnw = static_cast<index_type>(::floor(iz));\n\n      index_type ix_tne = ix_tnw + 1;\n      index_type iy_tne = iy_tnw;\n      index_type iz_tne = iz_tnw;\n\n      index_type ix_tsw = ix_tnw;\n      index_type iy_tsw = iy_tnw + 1;\n      index_type iz_tsw = iz_tnw;\n\n      index_type ix_tse = ix_tnw + 1;\n      index_type iy_tse = iy_tnw + 1;\n      index_type iz_tse = iz_tnw;\n\n      index_type ix_bnw = ix_tnw;\n      index_type iy_bnw = iy_tnw;\n      index_type iz_bnw = 
iz_tnw + 1;\n\n      index_type ix_bne = ix_tnw + 1;\n      index_type iy_bne = iy_tnw;\n      index_type iz_bne = iz_tnw + 1;\n\n      index_type ix_bsw = ix_tnw;\n      index_type iy_bsw = iy_tnw + 1;\n      index_type iz_bsw = iz_tnw + 1;\n\n      index_type ix_bse = ix_tnw + 1;\n      index_type iy_bse = iy_tnw + 1;\n      index_type iz_bse = iz_tnw + 1;\n\n      // get surfaces to each neighbor:\n      data_type tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);\n      data_type tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);\n      data_type tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);\n      data_type tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);\n      data_type bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);\n      data_type bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);\n      data_type bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);\n      data_type bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);\n\n      auto inp_ptr_NC = input_ptr + n * inp_sN;\n      auto out_ptr_NCDHW = output_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW;\n      for (index_type c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {\n        //   (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne\n        // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse\n        // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne\n        // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse\n        *out_ptr_NCDHW = static_cast<data_type>(0);\n        if (WithinBounds3D(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {\n          *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;\n        }\n        if (WithinBounds3D(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {\n          *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;\n        }\n        if (WithinBounds3D(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {\n          *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;\n        }\n        if (WithinBounds3D(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {\n          *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;\n        }\n        if (WithinBounds3D(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {\n          *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;\n        }\n        if (WithinBounds3D(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {\n          *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;\n        }\n        if (WithinBounds3D(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {\n          *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;\n        }\n        if (WithinBounds3D(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {\n          *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;\n        }\n      }\n    } else if (interpolation_mode == GridSamplerInterpolation::kNearest) {\n      index_type ix_nearest = static_cast<index_type>(::round(ix));\n      index_type iy_nearest = static_cast<index_type>(::round(iy));\n      index_type iz_nearest = static_cast<index_type>(::round(iz));\n\n      // assign nearest neighbor pixel value to output pixel\n      auto inp_ptr_NC = input_ptr + n * inp_sN;\n      auto out_ptr_NCDHW = output_ptr + n * out_sN + d * out_sD + 
h * out_sH + w * out_sW;\n      for (index_type c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {\n        if (WithinBounds3D(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) {\n          *out_ptr_NCDHW =\n              inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW];\n        } else {\n          *out_ptr_NCDHW = static_cast<data_type>(0);\n        }\n      }\n    }\n  }\n}\n\n// Note [Passing pointer and offset to fastAtomicAdd]\n// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n// For its internal bounds checking, fastAtomicAdd needs to know where the destination address\n// lies relative to the entire tensor, so we pass the base grad_input_ptr and full offset\n// information, including batch * channel offset (NC_offset).\n\ntemplate<typename data_type, typename index_type>\nOF_DEVICE_FUNC void GridSampler4DBackwardKernel(\n    const index_type nthreads, const data_type* grad_output_ptr, const data_type* input_ptr,\n    const data_type* grid_ptr, data_type* grad_input_ptr, data_type* grad_grid_ptr, index_type N,\n    index_type C, index_type inp_H, index_type inp_W, index_type out_H, index_type out_W,\n    const GridSamplerInterpolation interpolation_mode, const GridSamplerPadding padding_mode,\n    const bool align_corners, const index_type grad_input_memory_span) {\n  index_type inp_sN = C * inp_H * inp_W;\n  index_type inp_sC = inp_H * inp_W;\n  index_type inp_sH = inp_W;\n  index_type inp_sW = 1;\n  index_type grid_sN = out_H * out_W * 2;\n  index_type grid_sH = out_W * 2;\n  index_type grid_sW = 2;\n  index_type grid_sCoor = 1;\n  index_type gOut_sN = C * out_H * out_W;\n  index_type gOut_sC = out_H * out_W;\n  index_type gOut_sH = out_W;\n  index_type gOut_sW = 1;\n  index_type gInp_sN = inp_sN;\n  index_type gInp_sC = inp_sC;\n  index_type gInp_sH = inp_sH;\n  index_type gInp_sW = inp_sW;\n  index_type gGrid_sW = grid_sW;\n\n  XPU_1D_KERNEL_LOOP(index, nthreads) {\n    const index_type w = index % out_W;\n    const index_type h = (index / out_W) % out_H;\n    const index_type n = index / (out_H * out_W);\n    const auto grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;\n\n    // get the corresponding input x, y co-ordinates from grid\n    data_type x = grid_ptr[grid_offset];\n    data_type y = grid_ptr[grid_offset + grid_sCoor];\n\n    // multipliers for gradients on ix and iy\n    data_type gix_mult, giy_mult;\n    data_type ix =\n        GridSamplerComputeSourceIndexSetGrad(x, inp_W, padding_mode, align_corners, &gix_mult);\n    data_type iy =\n        GridSamplerComputeSourceIndexSetGrad(y, inp_H, padding_mode, align_corners, &giy_mult);\n\n    if (interpolation_mode == GridSamplerInterpolation::kBilinear) {\n      // get NE, NW, SE, SW pixel values from (x, y)\n      index_type ix_nw = static_cast<index_type>(::floor(ix));\n      index_type iy_nw = static_cast<index_type>(::floor(iy));\n      index_type ix_ne = ix_nw + 1;\n      index_type iy_ne = iy_nw;\n      index_type ix_sw = ix_nw;\n      index_type iy_sw = iy_nw + 1;\n      index_type ix_se = ix_nw + 1;\n      index_type iy_se = iy_nw + 1;\n\n      // get surfaces to each neighbor:\n      data_type nw = (ix_se - ix) * (iy_se - iy);\n      data_type ne = (ix - ix_sw) * (iy_sw - iy);\n      data_type sw = (ix_ne - ix) * (iy - iy_ne);\n      data_type se = (ix - ix_nw) * (iy - iy_nw);\n\n      data_type gix = static_cast<data_type>(0), giy = static_cast<data_type>(0);\n      const data_type* gOut_ptr_NCHW = grad_output_ptr + n * gOut_sN + h * gOut_sH + w * 
gOut_sW;\n      index_type NC_offset = n * gInp_sN;\n      const data_type* inp_ptr_NC = input_ptr + n * inp_sN;\n      for (index_type c = 0; c < C;\n           ++c, inp_ptr_NC += inp_sC, NC_offset += gInp_sC, gOut_ptr_NCHW += gOut_sC) {\n        data_type gOut = *gOut_ptr_NCHW;\n\n        // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd].\n        SafeAdd2D(grad_input_ptr, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw * gOut,\n                  NC_offset, grad_input_memory_span);\n        SafeAdd2D(grad_input_ptr, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne * gOut,\n                  NC_offset, grad_input_memory_span);\n        SafeAdd2D(grad_input_ptr, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw * gOut,\n                  NC_offset, grad_input_memory_span);\n        SafeAdd2D(grad_input_ptr, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se * gOut,\n                  NC_offset, grad_input_memory_span);\n\n        // calculate grad_grid\n        if (WithinBounds2D(iy_nw, ix_nw, inp_H, inp_W)) {\n          data_type nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW];\n          gix -= nw_val * (iy_se - iy) * gOut;\n          giy -= nw_val * (ix_se - ix) * gOut;\n        }\n        if (WithinBounds2D(iy_ne, ix_ne, inp_H, inp_W)) {\n          data_type ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW];\n          gix += ne_val * (iy_sw - iy) * gOut;\n          giy -= ne_val * (ix - ix_sw) * gOut;\n        }\n        if (WithinBounds2D(iy_sw, ix_sw, inp_H, inp_W)) {\n          data_type sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW];\n          gix -= sw_val * (iy - iy_ne) * gOut;\n          giy += sw_val * (ix_ne - ix) * gOut;\n        }\n        if (WithinBounds2D(iy_se, ix_se, inp_H, inp_W)) {\n          data_type se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW];\n          gix += se_val * (iy - iy_nw) * gOut;\n          giy += se_val * (ix - ix_nw) * gOut;\n        }\n      }\n\n      // assuming grad_grid is contiguous\n      // thus we can\n      //   1. use index with gGrid_sW to directly compute gGrid_ptr_NHW\n      //   2. directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1]\n      data_type* gGrid_ptr_NHW = grad_grid_ptr + index * gGrid_sW;\n      gGrid_ptr_NHW[0] = gix_mult * gix;\n      gGrid_ptr_NHW[1] = giy_mult * giy;\n    } else if (interpolation_mode == GridSamplerInterpolation::kNearest) {\n      index_type ix_nearest = static_cast<index_type>(::round(ix));\n      index_type iy_nearest = static_cast<index_type>(::round(iy));\n\n      // assign nearest neighbor pixel value to output pixel\n      const data_type* gOut_ptr_NCHW = grad_output_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW;\n      index_type NC_offset = n * gInp_sN;\n      for (index_type c = 0; c < C; ++c, NC_offset += gInp_sC, gOut_ptr_NCHW += gOut_sC) {\n        // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd].\n        SafeAdd2D(grad_input_ptr, iy_nearest, ix_nearest, gInp_sH, gInp_sW, inp_H, inp_W,\n                  *gOut_ptr_NCHW, NC_offset, grad_input_memory_span);\n      }\n\n      // assuming grad_grid is contiguous\n      // thus we can\n      //   1. use index with gGrid_sW to directly compute gGrid_ptr_NHW\n      //   2. 
directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1]\n      data_type* gGrid_ptr_NHW = grad_grid_ptr + index * gGrid_sW;\n      gGrid_ptr_NHW[0] = static_cast<data_type>(0);\n      gGrid_ptr_NHW[1] = static_cast<data_type>(0);\n    } else if (interpolation_mode == GridSamplerInterpolation::kBicubic) {\n      ix = GridSamplerUnnormalizeSetGrad(x, inp_W, align_corners, &gix_mult);\n      iy = GridSamplerUnnormalizeSetGrad(y, inp_H, align_corners, &giy_mult);\n\n      data_type ix_nw = ::floor(ix);\n      data_type iy_nw = ::floor(iy);\n\n      const data_type tx = ix - ix_nw;\n      const data_type ty = iy - iy_nw;\n\n      data_type x_coeffs[4];\n      data_type y_coeffs[4];\n      data_type x_coeffs_grad[4];\n      data_type y_coeffs_grad[4];\n\n      GetCubicUpsamplingCoefficients<data_type>(x_coeffs, tx);\n      GetCubicUpsamplingCoefficients<data_type>(y_coeffs, ty);\n      GetCubicCoefficientsGrad<data_type>(x_coeffs_grad, tx);\n      GetCubicCoefficientsGrad<data_type>(y_coeffs_grad, ty);\n\n      data_type gix = static_cast<data_type>(0);\n      data_type giy = static_cast<data_type>(0);\n\n      const data_type* gOut_ptr_NCHW = grad_output_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW;\n      index_type NC_offset = n * gInp_sN;\n      const data_type* inp_ptr_NC = input_ptr + n * inp_sN;\n\n      for (index_type c = 0; c < C;\n           ++c, gOut_ptr_NCHW += gOut_sC, NC_offset += gInp_sC, inp_ptr_NC += inp_sC) {\n        data_type gOut = *gOut_ptr_NCHW;\n\n#ifdef __CUDA_ARCH__\n#pragma unroll 4\n#endif\n        for (index_type i = 0; i < 4; ++i) {\n#ifdef __CUDA_ARCH__\n#pragma unroll 4\n#endif\n          for (index_type j = 0; j < 4; ++j) {\n            // set input gradient. See Note [Passing pointer and offset to fastAtomicAdd].\n            AddValueBounded<data_type>(grad_input_ptr, ix_nw - 1 + i, iy_nw - 1 + j, inp_W, inp_H,\n                                       gInp_sW, gInp_sH, gOut * x_coeffs[i] * y_coeffs[j],\n                                       padding_mode, align_corners, NC_offset,\n                                       grad_input_memory_span);\n\n            // set grid gradient\n            data_type val =\n                GetValueBounded<data_type>(inp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j, inp_W, inp_H,\n                                           inp_sW, inp_sH, padding_mode, align_corners);\n\n            gix -= val * x_coeffs_grad[i] * y_coeffs[j] * gOut;\n            giy -= val * y_coeffs_grad[j] * x_coeffs[i] * gOut;\n          }\n        }\n      }\n\n      data_type* gGrid_ptr_NHW = grad_grid_ptr + index * gGrid_sW;\n      gGrid_ptr_NHW[0] = gix_mult * gix;\n      gGrid_ptr_NHW[1] = giy_mult * giy;\n    }\n  }\n}\n\ntemplate<typename data_type, typename index_type>\nOF_DEVICE_FUNC void GridSampler5DBackwardKernel(\n    const index_type nthreads, const data_type* grad_output_ptr, const data_type* input_ptr,\n    const data_type* grid_ptr, data_type* grad_input_ptr, data_type* grad_grid_ptr, index_type N,\n    index_type C, index_type inp_D, index_type inp_H, index_type inp_W, index_type out_D,\n    index_type out_H, index_type out_W, const GridSamplerInterpolation interpolation_mode,\n    const GridSamplerPadding padding_mode, const bool align_corners,\n    const index_type grad_input_memory_span) {\n  index_type inp_sN = C * inp_D * inp_H * inp_W;\n  index_type inp_sC = inp_D * inp_H * inp_W;\n  index_type inp_sD = inp_H * inp_W;\n  index_type inp_sH = inp_W;\n  index_type inp_sW = 1;\n  index_type grid_sN = out_D * out_H * out_W * 3;\n  index_type 
grid_sD = out_H * out_W * 3;\n  index_type grid_sH = out_W * 3;\n  index_type grid_sW = 3;\n  index_type grid_sCoor = 1;\n  index_type gOut_sN = C * out_D * out_H * out_W;\n  index_type gOut_sC = out_D * out_H * out_W;\n  index_type gOut_sD = out_H * out_W;\n  index_type gOut_sH = out_W;\n  index_type gOut_sW = 1;\n  index_type gInp_sN = inp_sN;\n  index_type gInp_sC = inp_sC;\n  index_type gInp_sD = inp_sD;\n  index_type gInp_sH = inp_sH;\n  index_type gInp_sW = inp_sW;\n  index_type gGrid_sW = grid_sW;\n\n  XPU_1D_KERNEL_LOOP(index, nthreads) {\n    const index_type w = index % out_W;\n    const index_type h = (index / out_W) % out_H;\n    const index_type d = (index / (out_H * out_W)) % out_D;\n    const index_type n = index / (out_D * out_H * out_W);\n    const auto grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;\n\n    // get the corresponding input x, y, z co-ordinates from grid\n    data_type ix = grid_ptr[grid_offset];\n    data_type iy = grid_ptr[grid_offset + grid_sCoor];\n    data_type iz = grid_ptr[grid_offset + 2 * grid_sCoor];\n\n    // multipliers for gradients on ix, iy, and iz\n    data_type gix_mult, giy_mult, giz_mult;\n    ix = GridSamplerComputeSourceIndexSetGrad(ix, inp_W, padding_mode, align_corners, &gix_mult);\n    iy = GridSamplerComputeSourceIndexSetGrad(iy, inp_H, padding_mode, align_corners, &giy_mult);\n    iz = GridSamplerComputeSourceIndexSetGrad(iz, inp_D, padding_mode, align_corners, &giz_mult);\n\n    if (interpolation_mode == GridSamplerInterpolation::kBilinear) {\n      // get corner pixel values from (x, y, z)\n      // for 4d, we used north-east-south-west\n      // for 5d, we add top-bottom\n      index_type ix_tnw = static_cast<index_type>(::floor(ix));\n      index_type iy_tnw = static_cast<index_type>(::floor(iy));\n      index_type iz_tnw = static_cast<index_type>(::floor(iz));\n\n      index_type ix_tne = ix_tnw + 1;\n      index_type iy_tne = iy_tnw;\n      index_type iz_tne = iz_tnw;\n\n      index_type ix_tsw = ix_tnw;\n      index_type iy_tsw = iy_tnw + 1;\n      index_type iz_tsw = iz_tnw;\n\n      index_type ix_tse = ix_tnw + 1;\n      index_type iy_tse = iy_tnw + 1;\n      index_type iz_tse = iz_tnw;\n\n      index_type ix_bnw = ix_tnw;\n      index_type iy_bnw = iy_tnw;\n      index_type iz_bnw = iz_tnw + 1;\n\n      index_type ix_bne = ix_tnw + 1;\n      index_type iy_bne = iy_tnw;\n      index_type iz_bne = iz_tnw + 1;\n\n      index_type ix_bsw = ix_tnw;\n      index_type iy_bsw = iy_tnw + 1;\n      index_type iz_bsw = iz_tnw + 1;\n\n      index_type ix_bse = ix_tnw + 1;\n      index_type iy_bse = iy_tnw + 1;\n      index_type iz_bse = iz_tnw + 1;\n\n      // get surfaces to each neighbor:\n      data_type tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);\n      data_type tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);\n      data_type tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);\n      data_type tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);\n      data_type bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);\n      data_type bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);\n      data_type bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);\n      data_type bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);\n\n      data_type gix = static_cast<data_type>(0), giy = static_cast<data_type>(0),\n                giz = static_cast<data_type>(0);\n      const data_type* gOut_ptr_NCDHW =\n          grad_output_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;\n      index_type 
NC_offset = n * gInp_sN;\n      const data_type* inp_ptr_NC = input_ptr + n * inp_sN;\n      // accumulate grad_input and grad_grid contributions for each channel\n      for (index_type c = 0; c < C;\n           ++c, gOut_ptr_NCDHW += gOut_sC, NC_offset += gInp_sC, inp_ptr_NC += inp_sC) {\n        data_type gOut = *gOut_ptr_NCDHW;\n\n        // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd].\n        SafeAdd3D(grad_input_ptr, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H,\n                  inp_W, tnw * gOut, NC_offset, grad_input_memory_span);\n        SafeAdd3D(grad_input_ptr, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H,\n                  inp_W, tne * gOut, NC_offset, grad_input_memory_span);\n        SafeAdd3D(grad_input_ptr, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H,\n                  inp_W, tsw * gOut, NC_offset, grad_input_memory_span);\n        SafeAdd3D(grad_input_ptr, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H,\n                  inp_W, tse * gOut, NC_offset, grad_input_memory_span);\n        SafeAdd3D(grad_input_ptr, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H,\n                  inp_W, bnw * gOut, NC_offset, grad_input_memory_span);\n        SafeAdd3D(grad_input_ptr, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H,\n                  inp_W, bne * gOut, NC_offset, grad_input_memory_span);\n        SafeAdd3D(grad_input_ptr, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H,\n                  inp_W, bsw * gOut, NC_offset, grad_input_memory_span);\n        SafeAdd3D(grad_input_ptr, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H,\n                  inp_W, bse * gOut, NC_offset, grad_input_memory_span);\n\n        // calculate grad_grid\n        if (WithinBounds3D(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {\n          data_type tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];\n          gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut;\n          giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut;\n          giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut;\n        }\n        if (WithinBounds3D(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {\n          data_type tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];\n          gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut;\n          giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut;\n          giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut;\n        }\n        if (WithinBounds3D(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {\n          data_type tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];\n          gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut;\n          giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut;\n          giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut;\n        }\n        if (WithinBounds3D(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {\n          data_type tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];\n          gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut;\n          giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut;\n          giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut;\n        }\n        if (WithinBounds3D(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {\n          data_type bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];\n     
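     // chain rule through bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse):\n          // the ix and iy factors differentiate with a minus sign, the iz factor with a plus\n     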
     gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut;\n          giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut;\n          giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut;\n        }\n        if (WithinBounds3D(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {\n          data_type bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];\n          gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut;\n          giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut;\n          giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut;\n        }\n        if (WithinBounds3D(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {\n          data_type bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];\n          gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut;\n          giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut;\n          giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut;\n        }\n        if (WithinBounds3D(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {\n          data_type bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];\n          gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut;\n          giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut;\n          giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut;\n        }\n      }\n\n      // assuming grad_grid is contiguous\n      // thus we can\n      //   1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW\n      //   2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2]\n      data_type* gGrid_ptr_NDHW = grad_grid_ptr + index * gGrid_sW;\n      gGrid_ptr_NDHW[0] = gix_mult * gix;\n      gGrid_ptr_NDHW[1] = giy_mult * giy;\n      gGrid_ptr_NDHW[2] = giz_mult * giz;\n    } else if (interpolation_mode == GridSamplerInterpolation::kNearest) {\n      auto ix_nearest = static_cast<index_type>(::round(ix));\n      auto iy_nearest = static_cast<index_type>(::round(iy));\n      auto iz_nearest = static_cast<index_type>(::round(iz));\n\n      // assign nearest neighbor pixel value to output pixel\n      const data_type* gOut_ptr_NCDHW =\n          grad_output_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;\n      index_type NC_offset = n * gInp_sN;\n      for (index_type c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, NC_offset += gInp_sC) {\n        // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd].\n        SafeAdd3D(grad_input_ptr, iz_nearest, iy_nearest, ix_nearest, gInp_sD, gInp_sH, gInp_sW,\n                  inp_D, inp_H, inp_W, *gOut_ptr_NCDHW, NC_offset, grad_input_memory_span);\n      }\n\n      // assuming grad_grid is contiguous\n      // thus we can\n      //   1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW\n      //   2. 
directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2]\n      data_type* gGrid_ptr_NDHW = grad_grid_ptr + index * gGrid_sW;\n      gGrid_ptr_NDHW[0] = static_cast<data_type>(0);\n      gGrid_ptr_NDHW[1] = static_cast<data_type>(0);\n      gGrid_ptr_NDHW[2] = static_cast<data_type>(0);\n    }\n  }\n}\n\ntemplate<DeviceType device_type, typename data_type, typename index_type>\nstruct GridSampleKernelUtil final {\n  static void Forward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input,\n                        const user_op::Tensor* grid, user_op::Tensor* output,\n                        GridSamplerInterpolation interpolation, GridSamplerPadding padding,\n                        const bool align_corners, const ShapeView& input_shape,\n                        const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count);\n  static void Forward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input,\n                        const user_op::Tensor* grid, user_op::Tensor* output,\n                        GridSamplerInterpolation interpolation, GridSamplerPadding padding,\n                        const bool align_corners, const ShapeView& input_shape,\n                        const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count);\n\n  static void Backward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput,\n                         const user_op::Tensor* input, const user_op::Tensor* grid,\n                         user_op::Tensor* dinput, user_op::Tensor* dgrid,\n                         GridSamplerInterpolation interpolation, GridSamplerPadding padding,\n                         const bool align_corners, const ShapeView& input_shape,\n                         const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count);\n  static void Backward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput,\n                         const user_op::Tensor* input, const user_op::Tensor* grid,\n                         user_op::Tensor* dinput, user_op::Tensor* dgrid,\n                         GridSamplerInterpolation interpolation, GridSamplerPadding padding,\n                         const bool align_corners, const ShapeView& input_shape,\n                         const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count);\n};\n\n// macro for functor instantiation (used by grid_sample_kernel_util.cu, grid_sample_kernel_util.cpp)\n#define INSTANTIATE_GRID_SAMPLE_KERNEL_UTIL(device_type, dtype_pair, itype_pair)  \\\n  template struct GridSampleKernelUtil<device_type, OF_PP_PAIR_FIRST(dtype_pair), \\\n                                       OF_PP_PAIR_FIRST(itype_pair)>;\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_GRID_SAMPLE_KERNEL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/group_conv_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/include/primitive/add.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) {\n  return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N;\n}\n\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitive(DeviceType device_type,\n                                                          DataType data_type, bool transpose_a,\n                                                          bool transpose_b) {\n  const auto trans_a = GetBlasTransposeType(transpose_a);\n  const auto trans_b = GetBlasTransposeType(transpose_b);\n  return ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(device_type, data_type, trans_a,\n                                                                   trans_b);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewChannelsFirstMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false,\n                            /*transpose_b=*/false);\n}\n\nauto ChannelsFirstMatmulPrimitiveExists() {\n  return hob::make_custom(\"ChannelsFirstMatmulPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewChannelsFirstMatmulPrimitive(&ctx).operator bool();\n                          });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewChannelsLastMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true,\n                            /*transpose_b=*/true);\n}\n\nauto ChannelsLastMatmulPrimitiveExists() {\n  return hob::make_custom(\"ChannelsLastMatmulPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewChannelsLastMatmulPrimitive(&ctx).operator bool();\n                          });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewConvDataGradTransATransBMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true,\n                            /*transpose_b=*/true);\n}\n\nauto ConvDataGradTransATransBMatmulPrimitiveExists() {\n  return hob::make_custom(\"ConvDataGradTransATransBMatmulPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return 
NewConvDataGradTransATransBMatmulPrimitive(&ctx).operator bool();\n                          });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewConvDataGradTransANoTransBMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true,\n                            /*transpose_b=*/false);\n}\n\nauto ConvDataGradTransANoTransBMatmulPrimitiveExists() {\n  return hob::make_custom(\n      \"ConvDataGradTransANoTransBMatmulPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n        return NewConvDataGradTransANoTransBMatmulPrimitive(&ctx).operator bool();\n      });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewConvWeightGradTransATransBMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true,\n                            /*transpose_b=*/true);\n}\n\nauto ConvWeightGradTransATransBMatmulPrimitiveExists() {\n  return hob::make_custom(\n      \"ConvWeightGradTransATransBMatmulPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n        return NewConvWeightGradTransATransBMatmulPrimitive(&ctx).operator bool();\n      });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewConvWeightGradNoTransATransBMatmulPrimitive(\n    Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false,\n                            /*transpose_b=*/true);\n}\n\nauto ConvWeightGradNoTransATransBMatmulPrimitiveExists() {\n  return hob::make_custom(\n      \"ConvWeightGradNoTransATransBMatmulPrimitiveExists\",\n      [](const user_op::KernelRegContext& ctx) {\n        return NewConvWeightGradNoTransATransBMatmulPrimitive(&ctx).operator bool();\n      });\n}\n\ntemplate<typename T>\nusing Im2ColFunc = void (*)(const T* in_dptr, const ShapeView& in_shape,\n                            const ShapeView& weight_shape, const ShapeView& out_shape,\n                            const int32_t* strides, const int32_t* dilation_rate,\n                            const int32_t* padding_before, T* col_buf);\n\ntemplate<typename T>\nusing Col2ImFunc = void (*)(const T* col_buf, const ShapeView& in_shape,\n                            const ShapeView& weight_shape, const ShapeView& out_shape,\n                            const int32_t* strides, const int32_t* dilation_rate,\n                            const int32_t* padding_before, T* in_diff_ptr);\n\ntemplate<typename T>\nT* GetImgMutDptr(user_op::Tensor* tensor, int64_t idx) {\n  return tensor->mut_dptr<T>() + tensor->shape_view().Count(1) * idx;\n}\n\ntemplate<typename T>\nconst T* GetImgDptr(const user_op::Tensor* tensor, int64_t idx) {\n  return tensor->dptr<T>() + tensor->shape_view().Count(1) * idx;\n}\n\nsize_t CalcElemNumOfColBuf(const ShapeView& out_shape, const ShapeView& weight_shape,\n                           const int32_t idx_offset) {\n  int64_t col_buf_elem_cnt = 1;\n  int64_t ndims = out_shape.NumAxes() - 2;\n\n  for (size_t i = 0; i != ndims + 1; ++i) { col_buf_elem_cnt *= weight_shape.At(i + 1); }\n  for (size_t i = 0; i != ndims; ++i) { col_buf_elem_cnt *= out_shape.At(idx_offset + i); }\n  return col_buf_elem_cnt;\n}\n\ntemplate<typename T>\nclass ColBufWriter {\n 
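 // Base visitor for im2col/col2im: Im2ColWriter copies image values into the column buffer\n  // (zero-filling out-of-range positions), while Col2ImWriter accumulates column-buffer values\n  // back into the image (skipping out-of-range positions).\n 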
public:\n  ColBufWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size,\n               int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size)\n      : src_ptr_(src_ptr),\n        dst_ptr_(dst_ptr),\n        c_size_(c_size),\n        id_size_(id_size),\n        ih_size_(ih_size),\n        iw_size_(iw_size),\n        od_size_(od_size),\n        oh_size_(oh_size),\n        ow_size_(ow_size) {}\n  virtual ~ColBufWriter() = default;\n  virtual void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0;\n  virtual void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0;\n  virtual void InvalidDFunc() = 0;\n  virtual void InvalidHFunc() = 0;\n  virtual void InvalidWFunc() = 0;\n  virtual void NextImCSize() = 0;\n\n protected:\n  const T* src_ptr_;\n  T* dst_ptr_;\n  int64_t c_size_ = 0;\n  int64_t id_size_ = 0;\n  int64_t ih_size_ = 0;\n  int64_t iw_size_ = 0;\n  int64_t od_size_ = 0;\n  int64_t oh_size_ = 0;\n  int64_t ow_size_ = 0;\n};\n\ntemplate<typename T>\nclass Im2ColWriter final : public ColBufWriter<T> {\n public:\n  Im2ColWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size,\n               int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size)\n      : ColBufWriter<T>::ColBufWriter(src_ptr, dst_ptr, c_size, id_size, ih_size, iw_size, od_size,\n                                      oh_size, ow_size) {}\n  ~Im2ColWriter() = default;\n  void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override {\n    *(this->dst_ptr_++) =\n        this->src_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw * this->iw_size_ + c];\n  }\n  void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override {\n    *(this->dst_ptr_++) = this->src_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw];\n  }\n  void InvalidDFunc() override {\n    FOR_RANGE(int64_t, i, 0, this->od_size_) { *(this->dst_ptr_++) = 0; }\n  }\n  void InvalidHFunc() override {\n    FOR_RANGE(int64_t, i, 0, this->oh_size_) { *(this->dst_ptr_++) = 0; }\n  }\n  void InvalidWFunc() override {\n    FOR_RANGE(int64_t, i, 0, this->ow_size_) { *(this->dst_ptr_++) = 0; }\n  }\n  void NextImCSize() override { this->src_ptr_ += this->c_size_; }\n};\n\ntemplate<typename T>\nclass Col2ImWriter final : public ColBufWriter<T> {\n public:\n  Col2ImWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size,\n               int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size)\n      : ColBufWriter<T>::ColBufWriter(src_ptr, dst_ptr, c_size, id_size, ih_size, iw_size, od_size,\n                                      oh_size, ow_size) {}\n  ~Col2ImWriter() = default;\n  void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override {\n    this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw * this->iw_size_ + c] +=\n        *(this->src_ptr_++);\n  }\n  void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override {\n    this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw] += *(this->src_ptr_++);\n  }\n  void InvalidDFunc() override { this->src_ptr_ += this->od_size_; }\n  void InvalidHFunc() override { this->src_ptr_ += this->oh_size_; }\n  void InvalidWFunc() override { this->src_ptr_ += this->ow_size_; }\n  void NextImCSize() override { this->dst_ptr_ += this->c_size_; }\n};\n\ntemplate<typename T>\nusing DHWValidFunc = void (ColBufWriter<T>::*)(int64_t c, int64_t kd, int64_t kh, int64_t kw);\n\ntemplate<typename T>\nclass ColBufUtil final {\n public:\n  
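// Enumerates, for one kernel element (c, kd, kh, kw), every output position and the input\n  // position it samples; in-bounds positions are forwarded to the writer's CDHW/DHWC write\n  // hook, out-of-bounds positions to its Invalid{D,H,W}Func.\n  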
ColBufUtil(const ShapeView& in_shape, const ShapeView& out_shape, int32_t dhw_offset,\n             const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before,\n             const int32_t id_num, const int32_t ih_num, const int32_t iw_num, const int32_t od_num,\n             const int32_t oh_num, const int32_t ow_num)\n      : strides_(strides),\n        dilation_rate_(dilation_rate),\n        padding_before_(padding_before),\n        id_num_(id_num),\n        ih_num_(ih_num),\n        iw_num_(iw_num),\n        od_num_(od_num),\n        oh_num_(oh_num),\n        ow_num_(ow_num) {\n    if (dhw_offset == 2) {\n      dhw_valid_func_ = &ColBufWriter<T>::CDHWWrite;\n    } else {\n      dhw_valid_func_ = &ColBufWriter<T>::DHWCWrite;\n    }\n  }\n  void operator()(ColBufWriter<T>* col_buf_writer, int64_t c, int64_t kd, int64_t kh, int64_t kw) {\n    int64_t id = kd * dilation_rate_[0] - padding_before_[0];\n    FOR_RANGE(int64_t, od, 0, od_num_) {\n      if (id < 0 || id >= id_num_) {\n        col_buf_writer->InvalidDFunc();\n      } else {\n        int64_t ih = kh * dilation_rate_[1] - padding_before_[1];\n        FOR_RANGE(int64_t, oh, 0, oh_num_) {\n          if (ih < 0 || ih >= ih_num_) {\n            col_buf_writer->InvalidHFunc();\n          } else {\n            int64_t iw = kw * dilation_rate_[2] - padding_before_[2];\n            FOR_RANGE(int64_t, ow, 0, ow_num_) {\n              if (iw < 0 || iw >= iw_num_) {\n                col_buf_writer->InvalidWFunc();\n              } else {\n                (col_buf_writer->*dhw_valid_func_)(c, id, ih, iw);\n              }\n              iw += strides_[2];\n            }\n          }\n          ih += strides_[1];\n        }\n      }\n      id += strides_[0];\n    }\n  }\n\n private:\n  const int32_t* strides_;\n  const int32_t* dilation_rate_;\n  const int32_t* padding_before_;\n  DHWValidFunc<T> dhw_valid_func_;\n  int64_t id_num_;\n  int64_t ih_num_;\n  int64_t iw_num_;\n  int64_t od_num_;\n  int64_t oh_num_;\n  int64_t ow_num_;\n};\n\ntemplate<typename T>\nstruct ConvKernelUtil final {\n public:\n  static void NCDHWIm2Col(const T* in_dptr, const ShapeView& in_shape,\n                          const ShapeView& weight_shape, const ShapeView& out_shape,\n                          const int32_t* strides, const int32_t* dilation_rate,\n                          const int32_t* padding_before, T* col_buf_ptr) {\n    ColBufUtil<T> col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before,\n                               in_shape.At(2), in_shape.At(3), in_shape.At(4), out_shape.At(2),\n                               out_shape.At(3), out_shape.At(4));\n    Im2ColWriter<T> col_buf_writer(in_dptr, col_buf_ptr, in_shape.Count(2), in_shape.Count(3),\n                                   in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1);\n    DoNCDWHFunc(weight_shape, col_buf_util, &col_buf_writer);\n  }\n\n  static void NDHWCIm2Col(const T* in_dptr, const ShapeView& in_shape,\n                          const ShapeView& weight_shape, const ShapeView& out_shape,\n                          const int32_t* strides, const int32_t* dilation_rate,\n                          const int32_t* padding_before, T* col_buf_ptr) {\n    ColBufUtil<T> col_buf_util(in_shape, out_shape, 1, strides, dilation_rate, padding_before,\n                               in_shape.At(1), in_shape.At(2), in_shape.At(3), out_shape.At(1),\n                               out_shape.At(2), out_shape.At(3));\n    Im2ColWriter<T> 
col_buf_writer(in_dptr, col_buf_ptr, in_shape.Count(2), in_shape.Count(2),\n                                   in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4),\n                                   out_shape.Count(3, 4), 1);\n    DoNDWHCFunc(weight_shape, col_buf_util, &col_buf_writer);\n  }\n\n  static void NCDHWCol2Im(const T* col_buf_ptr, const ShapeView& in_shape,\n                          const ShapeView& weight_shape, const ShapeView& out_shape,\n                          const int32_t* strides, const int32_t* dilation_rate,\n                          const int32_t* padding_before, T* in_diff_ptr) {\n    ColBufUtil<T> col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before,\n                               in_shape.At(2), in_shape.At(3), in_shape.At(4), out_shape.At(2),\n                               out_shape.At(3), out_shape.At(4));\n    Col2ImWriter<T> col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(3),\n                                   in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1);\n    DoNCDWHFunc(weight_shape, col_buf_util, &col_buf_writer);\n  }\n\n  static void NDHWCCol2Im(const T* col_buf_ptr, const ShapeView& in_shape,\n                          const ShapeView& weight_shape, const ShapeView& out_shape,\n                          const int32_t* strides, const int32_t* dilation_rate,\n                          const int32_t* padding_before, T* in_diff_ptr) {\n    ColBufUtil<T> col_buf_util(in_shape, out_shape, 1, strides, dilation_rate, padding_before,\n                               in_shape.At(1), in_shape.At(2), in_shape.At(3), out_shape.At(1),\n                               out_shape.At(2), out_shape.At(3));\n    Col2ImWriter<T> col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(2),\n                                   in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4),\n                                   out_shape.Count(3, 4), 1);\n    DoNDWHCFunc(weight_shape, col_buf_util, &col_buf_writer);\n  }\n\n private:\n  static void DoNCDWHFunc(const ShapeView& weight_shape, ColBufUtil<T>& col_buf_util,\n                          ColBufWriter<T>* col_buf_writer) {\n    for (int64_t c = 0; c != weight_shape.At(1); col_buf_writer->NextImCSize(), ++c) {\n      for (int64_t kd = 0; kd != weight_shape.At(2); ++kd) {\n        for (int64_t kh = 0; kh != weight_shape.At(3); ++kh) {\n          for (int64_t kw = 0; kw != weight_shape.At(4); ++kw) {\n            col_buf_util(col_buf_writer, c, kd, kh, kw);\n          }\n        }\n      }\n    }\n  }\n\n  static void DoNDWHCFunc(const ShapeView& weight_shape, ColBufUtil<T>& col_buf_util,\n                          ColBufWriter<T>* col_buf_writer) {\n    for (int64_t kd = 0; kd != weight_shape.At(1); ++kd) {\n      for (int64_t kh = 0; kh != weight_shape.At(2); ++kh) {\n        for (int64_t kw = 0; kw != weight_shape.At(3); ++kw) {\n          for (int64_t c = 0; c != weight_shape.At(4); ++c) {\n            col_buf_util(col_buf_writer, c, kd, kh, kw);\n          }\n        }\n      }\n    }\n  }\n};\n\ntemplate<typename T>\nstruct ConvOpKernelCache final : public user_op::OpKernelCache {\n  Im2ColFunc<T> im2col_func_ = ConvKernelUtil<T>::NCDHWIm2Col;\n  Col2ImFunc<T> col2im_func_ = ConvKernelUtil<T>::NCDHWCol2Im;\n\n  Shape in_5d_shape_;\n  Shape out_5d_shape_;\n  Shape weight_5d_shape_;\n\n  std::vector<int32_t> strides_3d_;\n  std::vector<int32_t> dilation_rate_3d_;\n  std::vector<int32_t> padding_before_3d_;\n\n  bool 
is_out_diff_need_trans_ = false;\n  int32_t idx_offset_ = 0;\n  bool is_dynamic_ = false;\n  int32_t groups = 1;\n};\n\ntemplate<typename T>\nstd::shared_ptr<ConvOpKernelCache<T>> CreateConvOpKernelCache(user_op::KernelCacheContext* ctx,\n                                                              const std::string& in_name,\n                                                              const std::string& out_name,\n                                                              const std::string& weight_name) {\n  const auto& data_format = ctx->Attr<std::string>(\"data_format\");\n\n  std::shared_ptr<ConvOpKernelCache<T>> state(new ConvOpKernelCache<T>());\n  if (data_format == \"channels_first\") {\n    state->im2col_func_ = ConvKernelUtil<T>::NCDHWIm2Col;\n    state->col2im_func_ = ConvKernelUtil<T>::NCDHWCol2Im;\n    state->is_out_diff_need_trans_ = false;\n    state->idx_offset_ = 2;\n  } else {\n    state->im2col_func_ = ConvKernelUtil<T>::NDHWCIm2Col;\n    state->col2im_func_ = ConvKernelUtil<T>::NDHWCCol2Im;\n    state->is_out_diff_need_trans_ = true;\n    state->idx_offset_ = 1;\n  }\n  state->groups = ctx->Attr<int32_t>(\"groups\");\n\n  auto Gen5DShape = [](const Shape& shape, int32_t idx_offset) -> Shape {\n    DimVector ret_vec(shape.dim_vec());\n    int32_t ndims = ret_vec.size() - 2;\n    ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1);\n    return Shape(ret_vec);\n  };\n  state->in_5d_shape_ =\n      Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->shape(), state->idx_offset_);\n  state->out_5d_shape_ =\n      Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(out_name, 0)->shape(), state->idx_offset_);\n  state->weight_5d_shape_ =\n      Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(weight_name, 0)->shape(), state->idx_offset_);\n\n  auto Gen3DVec = [](const std::vector<int32_t>& origin_vec) -> std::vector<int32_t> {\n    std::vector<int32_t> ret_vec = origin_vec;\n    ret_vec.insert(ret_vec.begin(), 3 - ret_vec.size(), 1);\n    return ret_vec;\n  };\n  state->strides_3d_ = Gen3DVec(ctx->Attr<std::vector<int32_t>>(\"strides\"));\n  state->dilation_rate_3d_ = Gen3DVec(ctx->Attr<std::vector<int32_t>>(\"dilation_rate\"));\n  state->is_dynamic_ = ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->is_dynamic();\n  const auto& padding_before = ctx->Attr<std::vector<int32_t>>(\"padding_before\");\n  FOR_RANGE(uint8_t, dim, 0, 3) {\n    int64_t index = static_cast<int64_t>(dim) - (3 - padding_before.size());\n    if (index < 0) {\n      state->padding_before_3d_.push_back(0);\n    } else {\n      state->padding_before_3d_.push_back(padding_before.at(index));\n    }\n  }\n\n  return state;\n}\n\ntemplate<typename T>\nvoid InitBiasMulBuf(T* dptr, int64_t num) {\n  for (int64_t i = 0; i < num; ++i) { dptr[i] = 1; }\n}\ntemplate<typename T, size_t NDims>\nclass ConvCpuKernel final : public user_op::OpKernel {\n public:\n  ConvCpuKernel() = default;\n  ~ConvCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateConvOpKernelCache<T>(ctx, \"in\", \"out\", \"weight\");\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const auto* conv_cache = dynamic_cast<const ConvOpKernelCache<T>*>(cache);\n    CHECK_NOTNULL(conv_cache);\n\n    const user_op::Tensor* in = 
ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    T* col_buf_dptr = tmp_buffer->mut_dptr<T>();\n    int32_t idx_offset = conv_cache->idx_offset_;\n    const int32_t input_group_interval = in->shape_view().At(1) / conv_cache->groups;\n    const int32_t weight_group_interval = weight->shape_view().At(0) / conv_cache->groups;\n    const int32_t output_group_interval = out->shape_view().At(1) / conv_cache->groups;\n    const int32_t input_step = input_group_interval * in->shape_view().Count(2);\n    const int32_t weight_step = weight_group_interval * weight->shape_view().Count(1);\n    const int32_t output_step = output_group_interval * out->shape_view().Count(2);\n    const int32_t m = conv_cache->weight_5d_shape_.At(0) / conv_cache->groups;\n    const int32_t n = conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3);\n    const int32_t k = conv_cache->weight_5d_shape_.Count(1);\n    bool is_bias_mul_inited = false;\n\n    const auto& data_format = ctx->Attr<std::string>(\"data_format\");\n    std::unique_ptr<ep::primitive::Matmul> matmul;\n    if (data_format == \"channels_first\") {\n      matmul = NewChannelsFirstMatmulPrimitive(ctx);\n    } else {\n      matmul = NewChannelsLastMatmulPrimitive(ctx);\n    }\n    CHECK(matmul);\n\n    for (int64_t i = 0; i < in->shape_view().At(0); ++i) {\n      const T* input_ptr = GetImgDptr<T>(in, i);\n      const T* weight_ptr = weight->dptr<T>();\n      T* output_ptr = GetImgMutDptr<T>(out, i);\n      for (int64_t g = 0; g < conv_cache->groups; g++) {\n        conv_cache->im2col_func_(\n            input_ptr, ShapeView(conv_cache->in_5d_shape_), ShapeView(conv_cache->weight_5d_shape_),\n            ShapeView(conv_cache->out_5d_shape_), conv_cache->strides_3d_.data(),\n            conv_cache->dilation_rate_3d_.data(), conv_cache->padding_before_3d_.data(),\n            col_buf_dptr);\n\n        // channels first: out = weight * col_buf\n        // channels last:  out = (weight * col_buf)(T)\n        matmul->Launch(ctx->stream(),\n                       m,  // filter / groups\n                       n,  // od * oh * ow\n                       k,  // ci * kd * kh * kw / groups\n                       static_cast<T>(1), weight_ptr, col_buf_dptr, static_cast<T>(0), output_ptr);\n        input_ptr += input_step;\n        weight_ptr += weight_step;\n        output_ptr += output_step;\n      }\n\n      const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex(\"bias\", 0);\n      if (bias != nullptr) {\n        int64_t num_of_col_buf =\n            CalcElemNumOfColBuf(out->shape_view(), weight->shape_view(), idx_offset);\n        int64_t num_of_bias_mul =\n            (tmp_buffer->shape_view().elem_cnt() - num_of_col_buf * sizeof(T)) / sizeof(T);\n        CHECK_GT(num_of_bias_mul, 0);\n        T* bias_mul_dptr = col_buf_dptr + num_of_col_buf;\n        if (!is_bias_mul_inited) {\n          InitBiasMulBuf(bias_mul_dptr, num_of_bias_mul);\n          is_bias_mul_inited = true;\n        }\n\n        // channels first:  out += bias * bias_mul\n        // channels last:   out += (bias * bias_mul)(T)\n        matmul->Launch(ctx->stream(),\n                       conv_cache->weight_5d_shape_.At(0),                           // filter\n                       conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3),  // od * oh * ow\n   
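                    // k = 1: multiplying bias by a ones row broadcasts it over od * oh * ow\n   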
                    1,                                                            // 1\n                       static_cast<T>(1), bias->dptr<T>(), bias_mul_dptr, static_cast<T>(1),\n                       GetImgMutDptr<T>(out, i));\n      }\n    }\n  }\n};\n\n#define REGISTER_CONV_KERNEL(op_name, dtype, ndims)                                         \\\n  REGISTER_USER_KERNEL(#op_name)                                                            \\\n      .SetCreateFn<ConvCpuKernel<dtype, ndims>>()                                           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                       \\\n                       && (user_op::HobAttr<int32_t>(\"groups\") > 1)                         \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value)      \\\n                       && ChannelsFirstMatmulPrimitiveExists()                              \\\n                       && ChannelsLastMatmulPrimitiveExists())                              \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                         \\\n        size_t tmp_buffer_size = 0;                                                         \\\n        const auto& out_shape = ctx->OutputTensorDesc(\"out\", 0).shape();                    \\\n        const auto& weight_shape = ctx->InputTensorDesc(\"weight\", 0).shape();               \\\n                                                                                            \\\n        int64_t idx_offset = IdxOffset(ctx->Attr<std::string>(\"data_format\"));              \\\n        tmp_buffer_size +=                                                                  \\\n            CalcElemNumOfColBuf(out_shape, weight_shape, idx_offset) * sizeof(dtype);       \\\n        bool has_bias = ctx->has_input(\"bias\", 0);                                          \\\n        if (has_bias) {                                                                     \\\n          int64_t bias_mul_cnt = 1;                                                         \\\n          for (int i = 0; i < ndims; ++i) { bias_mul_cnt *= out_shape.At(idx_offset + i); } \\\n          tmp_buffer_size += bias_mul_cnt * sizeof(dtype);                                  \\\n        }                                                                                   \\\n        return tmp_buffer_size;                                                             \\\n      })\n\nREGISTER_CONV_KERNEL(conv1d, float, 1);\nREGISTER_CONV_KERNEL(conv2d, float, 2);\nREGISTER_CONV_KERNEL(conv3d, float, 3);\nREGISTER_CONV_KERNEL(conv1d, double, 1);\nREGISTER_CONV_KERNEL(conv2d, double, 2);\nREGISTER_CONV_KERNEL(conv3d, double, 3);\n\ntemplate<typename T>\nclass ConvDataGradCpuKernel final : public user_op::OpKernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ConvDataGradCpuKernel);\n  ConvDataGradCpuKernel() = default;\n  ~ConvDataGradCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateConvOpKernelCache<T>(ctx, \"dx\", \"dy\", \"filter\");\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const auto* conv_cache = dynamic_cast<const ConvOpKernelCache<T>*>(cache);\n    CHECK_NOTNULL(conv_cache);\n\n    const user_op::Tensor* dy = 
ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* filter = ctx->Tensor4ArgNameAndIndex(\"filter\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    int32_t idx_offset = conv_cache->idx_offset_;\n    const int32_t dy_group_interval = dy->shape_view().At(1) / conv_cache->groups;\n    const int32_t filter_group_interval = filter->shape_view().At(0) / conv_cache->groups;\n    const int32_t dx_group_interval = dx->shape_view().At(1) / conv_cache->groups;\n    const int32_t dx_step = dx_group_interval * dx->shape_view().Count(2);\n    const int32_t filter_step = filter_group_interval * filter->shape_view().Count(1);\n    const int32_t dy_step = dy_group_interval * dy->shape_view().Count(2);\n    const int32_t m = conv_cache->weight_5d_shape_.Count(1);\n    const int32_t n = conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3);\n    const int32_t k = conv_cache->weight_5d_shape_.At(0) / conv_cache->groups;\n\n    Memset<DeviceType::kCPU>(ctx->stream(), dx->mut_dptr<T>(), 0,\n                             dx->shape_view().elem_cnt() * sizeof(T));\n\n    std::unique_ptr<ep::primitive::Matmul> matmul;\n    if (conv_cache->is_out_diff_need_trans_) {\n      matmul = NewConvDataGradTransATransBMatmulPrimitive(ctx);\n    } else {\n      matmul = NewConvDataGradTransANoTransBMatmulPrimitive(ctx);\n    }\n    CHECK(matmul);\n\n    FOR_RANGE(int64_t, i, 0, dy->shape_view().At(0)) {\n      const T* filter_ptr = filter->dptr<T>();\n      const T* dy_ptr = GetImgDptr<T>(dy, i);\n      T* dx_ptr = GetImgMutDptr<T>(dx, i);\n      FOR_RANGE(int64_t, g, 0, conv_cache->groups) {\n        // channels first:  col_buf' = weight(T) * out[i]'\n        // channels last :  col_buf' = weight(T) * out[i]'(T)\n        matmul->Launch(ctx->stream(),\n                       m,  //  ci * kd * kh * kw / groups\n                       n,  //  od * oh * ow\n                       k,  //  filter / groups\n                       static_cast<T>(1), filter_ptr, dy_ptr, static_cast<T>(0),\n                       col_buf->mut_dptr<T>());\n\n        // in' = col2im(col_buf')\n        conv_cache->col2im_func_(\n            col_buf->dptr<T>(), ShapeView(conv_cache->in_5d_shape_),\n            ShapeView(conv_cache->weight_5d_shape_), ShapeView(conv_cache->out_5d_shape_),\n            conv_cache->strides_3d_.data(), conv_cache->dilation_rate_3d_.data(),\n            conv_cache->padding_before_3d_.data(), dx_ptr);\n        filter_ptr += filter_step;\n        dy_ptr += dy_step;\n        dx_ptr += dx_step;\n      }\n    }\n    if (ctx->has_input(\"_add_to_output\", 0)) {\n      const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n      CHECK_EQ(add_to_output->data_type(), dx->data_type());\n      CHECK_EQ(add_to_output->shape_view(), dx->shape_view());\n      std::unique_ptr<ep::primitive::Add> primitive =\n          ep::primitive::NewPrimitive<ep::primitive::AddFactory>(DeviceType::kCPU,\n                                                                 add_to_output->data_type());\n      CHECK(primitive);\n      primitive->Launch(ctx->stream(), dx->dptr<T>(), add_to_output->dptr<T>(), dx->mut_dptr<T>(),\n                        add_to_output->shape_view().elem_cnt());\n    }\n  }\n};\n\n#define REGISTER_CONV_DATA_GRAD_KERNEL(op_name, dtype)                                     \\\n  REGISTER_USER_KERNEL(#op_name)                                                           \\\n  
    .SetCreateFn<ConvDataGradCpuKernel<dtype>>()                                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                      \\\n                       && (user_op::HobAttr<int32_t>(\"groups\") > 1)                        \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value)     \\\n                       && ConvDataGradTransATransBMatmulPrimitiveExists()                  \\\n                       && ConvDataGradTransANoTransBMatmulPrimitiveExists())               \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                        \\\n        size_t tmp_buffer_size = 0;                                                        \\\n        const auto& out_diff_shape = ctx->InputTensorDesc(\"dy\", 0).shape();                \\\n        const auto& weight_shape = ctx->InputTensorDesc(\"filter\", 0).shape();              \\\n                                                                                           \\\n        int64_t idx_offset = IdxOffset(ctx->Attr<std::string>(\"data_format\"));             \\\n        tmp_buffer_size +=                                                                 \\\n            CalcElemNumOfColBuf(out_diff_shape, weight_shape, idx_offset) * sizeof(dtype); \\\n        return tmp_buffer_size;                                                            \\\n      })\n\nREGISTER_CONV_DATA_GRAD_KERNEL(conv_data_grad, float);\nREGISTER_CONV_DATA_GRAD_KERNEL(conv_data_grad, double);\n\ntemplate<typename T>\nclass ConvFilterGradCpuKernel final : public user_op::OpKernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(ConvFilterGradCpuKernel);\n  ConvFilterGradCpuKernel() = default;\n  ~ConvFilterGradCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateConvOpKernelCache<T>(ctx, \"x\", \"dy\", \"filter_diff\");\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const auto* conv_cache = dynamic_cast<const ConvOpKernelCache<T>*>(cache);\n    CHECK_NOTNULL(conv_cache);\n\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* filter_diff = ctx->Tensor4ArgNameAndIndex(\"filter_diff\", 0);\n    user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    int32_t idx_offset = conv_cache->idx_offset_;\n    const int32_t dy_group_interval = dy->shape_view().At(1) / conv_cache->groups;\n    const int32_t filter_diff_group_interval = filter_diff->shape_view().At(0) / conv_cache->groups;\n    const int32_t x_group_interval = x->shape_view().At(1) / conv_cache->groups;\n    const int32_t x_step = x_group_interval * x->shape_view().Count(2);\n    const int32_t dy_step = dy_group_interval * dy->shape_view().Count(2);\n    const int32_t filter_diff_step =\n        filter_diff_group_interval * filter_diff->shape_view().Count(1);\n    const int32_t m = conv_cache->weight_5d_shape_.At(0) / conv_cache->groups;\n    const int32_t n = conv_cache->weight_5d_shape_.Count(1);\n    const int32_t k = conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3);\n\n    Memset<DeviceType::kCPU>(ctx->stream(), filter_diff->mut_dptr<T>(), 0,\n            
                 filter_diff->shape_view().elem_cnt() * sizeof(T));\n\n    std::unique_ptr<ep::primitive::Matmul> matmul;\n    if (conv_cache->is_out_diff_need_trans_) {\n      matmul = NewConvWeightGradTransATransBMatmulPrimitive(ctx);\n    } else {\n      matmul = NewConvWeightGradNoTransATransBMatmulPrimitive(ctx);\n    }\n    CHECK(matmul);\n\n    FOR_RANGE(int64_t, i, 0, dy->shape_view().At(0)) {\n      const T* x_ptr = GetImgDptr<T>(x, i);\n      const T* dy_ptr = GetImgDptr<T>(dy, i);\n      T* filter_diff_ptr = filter_diff->mut_dptr<T>();\n      FOR_RANGE(int64_t, g, 0, conv_cache->groups) {\n        conv_cache->im2col_func_(\n            x_ptr, ShapeView(conv_cache->in_5d_shape_), ShapeView(conv_cache->weight_5d_shape_),\n            ShapeView(conv_cache->out_5d_shape_), conv_cache->strides_3d_.data(),\n            conv_cache->dilation_rate_3d_.data(), conv_cache->padding_before_3d_.data(),\n            col_buf->mut_dptr<T>());\n\n        // channels first:  weight' += out[i]' * col_buf(T)\n        // channels last :  weight' += out[i]'(T) * col_buf(T)\n        matmul->Launch(ctx->stream(),\n                       m,  //  filter / groups\n                       n,  //  ci * kd * kh * kw\n                       k,  //  od * oh * ow / groups\n                       static_cast<T>(1), dy_ptr, col_buf->dptr<T>(), static_cast<T>(1),\n                       filter_diff_ptr);\n        x_ptr += x_step;\n        dy_ptr += dy_step;\n        filter_diff_ptr += filter_diff_step;\n      }\n    }\n  }\n};\n\n#define REGISTER_CONV_FILTER_GRAD_KERNEL(op_name, dtype)                                        \\\n  REGISTER_USER_KERNEL(#op_name)                                                                \\\n      .SetCreateFn<ConvFilterGradCpuKernel<dtype>>()                                            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                           \\\n                       && (user_op::HobAttr<int32_t>(\"groups\") > 1)                             \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value)          \\\n                       && ConvWeightGradTransATransBMatmulPrimitiveExists()                     \\\n                       && ConvWeightGradNoTransATransBMatmulPrimitiveExists())                  \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                             \\\n        size_t tmp_buffer_size = 0;                                                             \\\n        const auto& out_diff_shape = ctx->InputTensorDesc(\"dy\", 0).shape();                     \\\n        const auto& weight_diff_shape = ctx->OutputTensorDesc(\"filter_diff\", 0).shape();        \\\n                                                                                                \\\n        int64_t idx_offset = IdxOffset(ctx->Attr<std::string>(\"data_format\"));                  \\\n        tmp_buffer_size +=                                                                      \\\n            CalcElemNumOfColBuf(out_diff_shape, weight_diff_shape, idx_offset) * sizeof(dtype); \\\n        return tmp_buffer_size;                                                                 \\\n      })\n\nREGISTER_CONV_FILTER_GRAD_KERNEL(conv_filter_grad, float);\nREGISTER_CONV_FILTER_GRAD_KERNEL(conv_filter_grad, double);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/group_deconv_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) {\n  return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N;\n}\n\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitive(DeviceType device_type,\n                                                          DataType data_type, bool transpose_a,\n                                                          bool transpose_b) {\n  const auto trans_a = GetBlasTransposeType(transpose_a);\n  const auto trans_b = GetBlasTransposeType(transpose_b);\n  return ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(device_type, data_type, trans_a,\n                                                                   trans_b);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewDeconvTransATransBMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, true, true);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewDeconvTransANoTransBMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, true, false);\n}\n\nauto DeconvTransATransBMatmulPrimitiveExists() {\n  return hob::make_custom(\"DeconvTransATransBMatmulPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewDeconvTransATransBMatmulPrimitive(&ctx).operator bool();\n                          });\n}\n\nauto DeconvTransANoTransBMatmulPrimitiveExists() {\n  return hob::make_custom(\"DeconvTransANoTransBMatmulPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewDeconvTransANoTransBMatmulPrimitive(&ctx).operator bool();\n                          });\n}\n\ntemplate<typename T>\nusing Col2ImFunc = void (*)(const T* col_buf, const ShapeView& in_shape,\n                            const ShapeView& weight_shape, const ShapeView& out_shape,\n                            const int32_t* strides, const int32_t* dilation_rate,\n                            const int32_t* padding_before, T* in_diff_ptr);\n\ntemplate<typename T>\nT* GetImgMutDptr(user_op::Tensor* tensor, int64_t idx) {\n  return tensor->mut_dptr<T>() + tensor->shape_view().Count(1) * idx;\n}\n\ntemplate<typename T>\nconst T* GetImgDptr(const user_op::Tensor* tensor, int64_t idx) {\n  return tensor->dptr<T>() + tensor->shape_view().Count(1) * idx;\n}\n\nsize_t 
CalcElemNumOfColBuf(const ShapeView& out_shape, const ShapeView& weight_shape,\n                           const int32_t idx_offset) {\n  int64_t col_buf_elem_cnt = 1;\n  int64_t ndims = out_shape.NumAxes() - 2;\n  for (size_t i = 0; i != ndims + 1; ++i) { col_buf_elem_cnt *= weight_shape.At(i + 1); }\n  for (size_t i = 0; i != ndims; ++i) { col_buf_elem_cnt *= out_shape.At(idx_offset + i); }\n  return col_buf_elem_cnt;\n}\n\ntemplate<typename T>\nclass ColBufWriter {\n public:\n  ColBufWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size,\n               int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size)\n      : src_ptr_(src_ptr),\n        dst_ptr_(dst_ptr),\n        c_size_(c_size),\n        id_size_(id_size),\n        ih_size_(ih_size),\n        iw_size_(iw_size),\n        od_size_(od_size),\n        oh_size_(oh_size),\n        ow_size_(ow_size) {}\n  virtual ~ColBufWriter() = default;\n  virtual void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0;\n  virtual void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0;\n  virtual void InvalidDFunc() = 0;\n  virtual void InvalidHFunc() = 0;\n  virtual void InvalidWFunc() = 0;\n  virtual void NextImCSize() = 0;\n\n protected:\n  const T* src_ptr_;\n  T* dst_ptr_;\n  int64_t c_size_ = 0;\n  int64_t id_size_ = 0;\n  int64_t ih_size_ = 0;\n  int64_t iw_size_ = 0;\n  int64_t od_size_ = 0;\n  int64_t oh_size_ = 0;\n  int64_t ow_size_ = 0;\n};\n\ntemplate<typename T>\nclass Col2ImWriter final : public ColBufWriter<T> {\n public:\n  Col2ImWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size,\n               int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size)\n      : ColBufWriter<T>::ColBufWriter(src_ptr, dst_ptr, c_size, id_size, ih_size, iw_size, od_size,\n                                      oh_size, ow_size) {}\n  ~Col2ImWriter() = default;\n  void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override {\n    this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw * this->iw_size_ + c] +=\n        *(this->src_ptr_++);\n  }\n  void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override {\n    this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw] += *(this->src_ptr_++);\n  }\n  void InvalidDFunc() override { this->src_ptr_ += this->od_size_; }\n  void InvalidHFunc() override { this->src_ptr_ += this->oh_size_; }\n  void InvalidWFunc() override { this->src_ptr_ += this->ow_size_; }\n  void NextImCSize() override { this->dst_ptr_ += this->c_size_; }\n};\n\ntemplate<typename T>\nusing DHWValidFunc = void (ColBufWriter<T>::*)(int64_t c, int64_t kd, int64_t kh, int64_t kw);\n\ntemplate<typename T>\nclass ColBufUtil final {\n public:\n  ColBufUtil(const ShapeView& in_shape, const ShapeView& out_shape, int32_t dhw_offset,\n             const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before,\n             const int32_t id_num, const int32_t ih_num, const int32_t iw_num, const int32_t od_num,\n             const int32_t oh_num, const int32_t ow_num)\n      : strides_(strides),\n        dilation_rate_(dilation_rate),\n        padding_before_(padding_before),\n        id_num_(id_num),\n        ih_num_(ih_num),\n        iw_num_(iw_num),\n        od_num_(od_num),\n        oh_num_(oh_num),\n        ow_num_(ow_num) {\n    if (dhw_offset == 2) {\n      dhw_valid_func_ = &ColBufWriter<T>::CDHWWrite;\n    } else {\n      dhw_valid_func_ = &ColBufWriter<T>::DHWCWrite;\n    }\n  
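  // Note: dhw_offset == 2 means channels-first data: CDHWWrite indexes without a\n    // channel term (NextImCSize advances it); DHWCWrite adds c as the innermost index.\n  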
}\n  void operator()(ColBufWriter<T>* col_buf_writer, int64_t c, int64_t kd, int64_t kh, int64_t kw) {\n    int64_t id = kd * dilation_rate_[0] - padding_before_[0];\n    FOR_RANGE(int64_t, od, 0, od_num_) {\n      if (id < 0 || id >= id_num_) {\n        col_buf_writer->InvalidDFunc();\n      } else {\n        int64_t ih = kh * dilation_rate_[1] - padding_before_[1];\n        FOR_RANGE(int64_t, oh, 0, oh_num_) {\n          if (ih < 0 || ih >= ih_num_) {\n            col_buf_writer->InvalidHFunc();\n          } else {\n            int64_t iw = kw * dilation_rate_[2] - padding_before_[2];\n            FOR_RANGE(int64_t, ow, 0, ow_num_) {\n              if (iw < 0 || iw >= iw_num_) {\n                col_buf_writer->InvalidWFunc();\n              } else {\n                (col_buf_writer->*dhw_valid_func_)(c, id, ih, iw);\n              }\n              iw += strides_[2];\n            }\n          }\n          ih += strides_[1];\n        }\n      }\n      id += strides_[0];\n    }\n  }\n\n private:\n  const int32_t* strides_;\n  const int32_t* dilation_rate_;\n  const int32_t* padding_before_;\n  DHWValidFunc<T> dhw_valid_func_;\n  int64_t id_num_ = 0;\n  int64_t ih_num_ = 0;\n  int64_t iw_num_ = 0;\n  int64_t od_num_ = 0;\n  int64_t oh_num_ = 0;\n  int64_t ow_num_ = 0;\n};\n\ntemplate<typename T>\nstruct DeconvKernelUtil final {\n public:\n  static void NCDHWCol2Im(const T* col_buf_ptr, const ShapeView& in_shape,\n                          const ShapeView& weight_shape, const ShapeView& out_shape,\n                          const int32_t* strides, const int32_t* dilation_rate,\n                          const int32_t* padding_before, T* in_diff_ptr) {\n    ColBufUtil<T> col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before,\n                               in_shape.At(2), in_shape.At(3), in_shape.At(4), out_shape.At(2),\n                               out_shape.At(3), out_shape.At(4));\n    Col2ImWriter<T> col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(3),\n                                   in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1);\n    DoNCDWHFunc(weight_shape, col_buf_util, &col_buf_writer);\n  }\n\n  static void NDHWCCol2Im(const T* col_buf_ptr, const ShapeView& in_shape,\n                          const ShapeView& weight_shape, const ShapeView& out_shape,\n                          const int32_t* strides, const int32_t* dilation_rate,\n                          const int32_t* padding_before, T* in_diff_ptr) {\n    ColBufUtil<T> col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before,\n                               in_shape.At(2), in_shape.At(3), in_shape.At(4), out_shape.At(2),\n                               out_shape.At(3), out_shape.At(4));\n    Col2ImWriter<T> col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(2),\n                                   in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4),\n                                   out_shape.Count(3, 4), 1);\n    DoNDWHCFunc(weight_shape, col_buf_util, &col_buf_writer);\n  }\n\n private:\n  static void DoNCDWHFunc(const ShapeView& weight_shape, ColBufUtil<T>& col_buf_util,\n                          ColBufWriter<T>* col_buf_writer) {\n    for (int64_t c = 0; c != weight_shape.At(1); col_buf_writer->NextImCSize(), ++c) {\n      for (int64_t kd = 0; kd != weight_shape.At(2); ++kd) {\n        for (int64_t kh = 0; kh != weight_shape.At(3); ++kh) {\n          for (int64_t kw = 0; kw != weight_shape.At(4); 
++kw) {\n            col_buf_util(col_buf_writer, c, kd, kh, kw);\n          }\n        }\n      }\n    }\n  }\n\n  static void DoNDWHCFunc(const ShapeView& weight_shape, ColBufUtil<T>& col_buf_util,\n                          ColBufWriter<T>* col_buf_writer) {\n    for (int64_t kd = 0; kd != weight_shape.At(1); ++kd) {\n      for (int64_t kh = 0; kh != weight_shape.At(2); ++kh) {\n        for (int64_t kw = 0; kw != weight_shape.At(3); ++kw) {\n          for (int64_t c = 0; c != weight_shape.At(4); ++c) {\n            col_buf_util(col_buf_writer, c, kd, kh, kw);\n          }\n        }\n      }\n    }\n  }\n};\n\ntemplate<typename T>\nstruct DeconvOpKernelCache final : public user_op::OpKernelCache {\n  Col2ImFunc<T> col2im_func_ = DeconvKernelUtil<T>::NCDHWCol2Im;\n\n  Shape in_5d_shape_;\n  Shape out_5d_shape_;\n  Shape weight_5d_shape_;\n\n  std::vector<int32_t> strides_3d_;\n  std::vector<int32_t> dilation_rate_3d_;\n  std::vector<int32_t> padding_before_3d_;\n\n  bool is_out_diff_need_trans_ = false;\n\n  int32_t idx_offset_ = 0;\n  bool is_dynamic_ = false;\n  int32_t groups = 1;\n\n  void Update(const ShapeView& x_shape, const ShapeView& out_shape) {\n    auto Gen5DShape = [](const ShapeView& shape, int32_t idx_offset) -> Shape {\n      DimVector ret_vec;\n      shape.ToDimVector(&ret_vec);\n      int32_t ndims = ret_vec.size() - 2;\n      ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1);\n      return Shape(ret_vec);\n    };\n    if (is_dynamic_) {\n      in_5d_shape_ = Gen5DShape(x_shape, idx_offset_);\n      out_5d_shape_ = Gen5DShape(out_shape, idx_offset_);\n    }\n  }\n};\n\ntemplate<typename T>\nstd::shared_ptr<DeconvOpKernelCache<T>> CreateDeconvOpKernelCache(user_op::KernelCacheContext* ctx,\n                                                                  const std::string& in_name,\n                                                                  const std::string& out_name,\n                                                                  const std::string& weight_name) {\n  const auto& data_format = ctx->Attr<std::string>(\"data_format\");\n\n  std::shared_ptr<DeconvOpKernelCache<T>> state(new DeconvOpKernelCache<T>());\n  if (data_format == \"channels_first\") {\n    state->col2im_func_ = DeconvKernelUtil<T>::NCDHWCol2Im;\n    state->is_out_diff_need_trans_ = false;\n    state->idx_offset_ = 2;\n  } else {\n    state->col2im_func_ = DeconvKernelUtil<T>::NDHWCCol2Im;\n    state->is_out_diff_need_trans_ = true;\n    state->idx_offset_ = 1;\n  }\n\n  auto Gen5DShape = [](const Shape& shape, int32_t idx_offset) -> Shape {\n    DimVector ret_vec(shape.dim_vec());\n    int32_t ndims = ret_vec.size() - 2;\n    ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1);\n    return Shape(ret_vec);\n  };\n  state->groups = ctx->Attr<int32_t>(\"groups\");\n\n  state->in_5d_shape_ =\n      Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->shape(), state->idx_offset_);\n  state->out_5d_shape_ =\n      Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(out_name, 0)->shape(), state->idx_offset_);\n  state->weight_5d_shape_ =\n      Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(weight_name, 0)->shape(), state->idx_offset_);\n\n  auto Gen3DVec = [](const std::vector<int32_t>& origin_vec) -> std::vector<int32_t> {\n    std::vector<int32_t> ret_vec = origin_vec;\n    ret_vec.insert(ret_vec.begin(), 3 - ret_vec.size(), 1);\n    return ret_vec;\n  };\n  state->strides_3d_ = Gen3DVec(ctx->Attr<std::vector<int32_t>>(\"strides\"));\n  
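// Like the strides above, dilation and padding are left-padded to 3-d (1 for\n  // dilation, 0 for padding) so that 1-d and 2-d deconv reuse the 3-d code path.\n  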
state->dilation_rate_3d_ = Gen3DVec(ctx->Attr<std::vector<int32_t>>(\"dilation_rate\"));\n  state->is_dynamic_ = ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->is_dynamic();\n  const auto& padding_before = ctx->Attr<std::vector<int32_t>>(\"padding_before\");\n  FOR_RANGE(uint8_t, dim, 0, 3) {\n    int64_t index = static_cast<int64_t>(dim) - (3 - padding_before.size());\n    if (index < 0) {\n      state->padding_before_3d_.emplace_back(0);\n    } else {\n      state->padding_before_3d_.emplace_back(padding_before.at(index));\n    }\n  }\n\n  return state;\n}\n\ntemplate<typename T>\nclass DeconvCpuKernel final : public user_op::OpKernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(DeconvCpuKernel);\n  DeconvCpuKernel() = default;\n  ~DeconvCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  void InitOpKernelCacheWithFlags(\n      user_op::KernelCacheContext* ctx, int8_t flag,\n      std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {\n    if (*cache_ptr != nullptr && (flag & user_op::OpKernelCache::kAttrNotChanged)) {\n      auto deconv_cache = std::dynamic_pointer_cast<DeconvOpKernelCache<T>>(*cache_ptr);\n      deconv_cache->Update(ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->shape(),\n                           ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->shape());\n      return;\n    }\n    *cache_ptr = CreateDeconvOpKernelCache<T>(ctx, \"out\", \"in\", \"weight\");\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    auto deconv_cache = dynamic_cast<const DeconvOpKernelCache<T>*>(cache);\n    CHECK_NOTNULL(deconv_cache);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    int32_t idx_offset = deconv_cache->idx_offset_;\n    const int32_t input_group_interval = in->shape_view().At(1) / deconv_cache->groups;\n    const int32_t weight_group_interval = weight->shape_view().At(0) / deconv_cache->groups;\n    const int32_t output_group_interval = out->shape_view().At(1) / deconv_cache->groups;\n    const int32_t input_step = input_group_interval * in->shape_view().Count(2);\n    const int32_t weight_step = weight_group_interval * weight->shape_view().Count(1);\n    const int32_t output_step = output_group_interval * out->shape_view().Count(2);\n    const int32_t m = deconv_cache->weight_5d_shape_.Count(1);\n    const int32_t n = deconv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3);\n    const int32_t k = deconv_cache->weight_5d_shape_.At(0) / deconv_cache->groups;\n\n    Memset<DeviceType::kCPU>(ctx->stream(), out->mut_dptr<T>(), 0,\n                             out->shape_view().elem_cnt() * sizeof(T));\n\n    std::unique_ptr<ep::primitive::Matmul> matmul;\n    if (deconv_cache->is_out_diff_need_trans_) {\n      matmul = NewDeconvTransATransBMatmulPrimitive(ctx);\n    } else {\n      matmul = NewDeconvTransANoTransBMatmulPrimitive(ctx);\n    }\n    CHECK(matmul);\n\n    FOR_RANGE(int64_t, i, 0, in->shape_view().At(0)) {\n      const T* input_ptr = GetImgDptr<T>(in, i);\n      const T* weight_ptr = weight->dptr<T>();\n      T* output_ptr = GetImgMutDptr<T>(out, i);\n\n      FOR_RANGE(int64_t, g, 0, deconv_cache->groups) {\n        
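// For each group: col_buf = weight^T * in (the input is transposed as well for\n        // channels_last); the col2im below scatters col_buf back into the output image.\n        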
matmul->Launch(ctx->stream(),\n                       m,  //  (co / groups) * kd * kh * kw\n                       n,  //  od * oh * ow\n                       k,  //  filter / groups\n                       static_cast<T>(1), weight_ptr, input_ptr, static_cast<T>(0),\n                       col_buf->mut_dptr<T>());\n\n        // out = col2im(col_buf')\n        deconv_cache->col2im_func_(\n            col_buf->mut_dptr<T>(), ShapeView(deconv_cache->in_5d_shape_),\n            ShapeView(deconv_cache->weight_5d_shape_), ShapeView(deconv_cache->out_5d_shape_),\n            deconv_cache->strides_3d_.data(), deconv_cache->dilation_rate_3d_.data(),\n            deconv_cache->padding_before_3d_.data(), output_ptr);\n        input_ptr += input_step;\n        weight_ptr += weight_step;\n        output_ptr += output_step;\n      }\n    }\n  }\n};\n\n#define REGISTER_DECONV_DATA_KERNEL(op_name, dtype)                                     \\\n  REGISTER_USER_KERNEL(#op_name)                                                        \\\n      .SetCreateFn<DeconvCpuKernel<dtype>>()                                            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobAttr<int32_t>(\"groups\") > 1)                     \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value) \\\n                       && DeconvTransATransBMatmulPrimitiveExists()                     \\\n                       && DeconvTransANoTransBMatmulPrimitiveExists())                  \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                     \\\n        size_t tmp_buffer_size = 0;                                                     \\\n        const auto& in_shape = ctx->InputTensorDesc(\"in\", 0).shape();                   \\\n        const auto& weight_shape = ctx->InputTensorDesc(\"weight\", 0).shape();           \\\n                                                                                        \\\n        int64_t idx_offset = IdxOffset(ctx->Attr<std::string>(\"data_format\"));          \\\n        tmp_buffer_size +=                                                              \\\n            CalcElemNumOfColBuf(in_shape, weight_shape, idx_offset) * sizeof(dtype);    \\\n        return tmp_buffer_size;                                                         \\\n      })\n\nREGISTER_DECONV_DATA_KERNEL(deconv1d, float);\nREGISTER_DECONV_DATA_KERNEL(deconv1d, double);\nREGISTER_DECONV_DATA_KERNEL(deconv2d, float);\nREGISTER_DECONV_DATA_KERNEL(deconv2d, double);\nREGISTER_DECONV_DATA_KERNEL(deconv3d, float);\nREGISTER_DECONV_DATA_KERNEL(deconv3d, double);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/group_norm_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cudnn_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/ep/cuda/primitive/unary_functor.cuh\"\n#include \"oneflow/core/cuda/layer_norm.cuh\"\n#include <cub/cub.cuh>\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\n#ifdef WITH_CUTLASS\n#include <cutlass/fast_math.h>\n#endif  // WITH_CUTLASS\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename SRC, typename DST, ep::primitive::UnaryOp activation, bool affine>\nstruct AffineStore {\n  AffineStore(DST* y, int64_t row_size, int64_t channel_size, int64_t spatial_size,\n              const DST* gamma, const DST* beta)\n      : y(y),\n        row_size(row_size),\n        channel_size(channel_size),\n        spatial_size(spatial_size),\n        gamma(gamma),\n        beta(beta),\n        act(0, 0) {}\n\n  template<int PackSize>\n  __device__ void store(const SRC* src, int64_t row, int64_t col) {\n    cuda::layer_norm::Pack<DST, PackSize> y_pack;\n    const int64_t offset = row * row_size + col;\n    const int64_t packed_offset = offset / PackSize;\n    const int64_t gamma_beta_offset = (offset / spatial_size) % channel_size;\n    DST gamma_val = 1.0;\n    DST beta_val = 0.0;\n    if (affine) {\n      gamma_val = gamma[gamma_beta_offset];\n      beta_val = beta[gamma_beta_offset];\n    }\n\n#pragma unroll\n    for (int i = 0; i < PackSize; ++i) {\n      DST normalized_i = static_cast<DST>(src[i]);\n      if (affine) {\n        y_pack.elem[i] = act(normalized_i * gamma_val + beta_val);\n      } else {\n        // Direct Store.\n        y_pack.elem[i] = act(normalized_i);\n      }\n    }\n    *(reinterpret_cast<cuda::layer_norm::PackType<DST, PackSize>*>(y) + packed_offset) =\n        y_pack.storage;\n  }\n  bool CanPackAs(size_t pack_size) { return (spatial_size % pack_size) == 0; }\n  DST* y;\n  int64_t row_size;\n  int64_t channel_size;\n  int64_t spatial_size;\n  const DST* gamma;\n  const DST* beta;\n  ep::primitive::UnaryFunctor<DeviceType::kCUDA, activation, DST, DST> act;\n};\n\ntemplate<typename SRC, typename DST, bool affine>\nstruct ScaleLoad {\n  using LoadType = DST;\n  ScaleLoad(const SRC* src, const SRC* gamma, int64_t row_size, int64_t channel_size,\n            int64_t spatial_size)\n      : src(src),\n        gamma(gamma),\n        row_size(row_size),\n        channel_size(channel_size),\n        spatial_size(spatial_size) {}\n  template<int PackSize>\n  __device__ void load(DST* dst, int64_t row, int64_t col) const {\n    cuda::layer_norm::Pack<SRC, PackSize> src_pack;\n    cuda::layer_norm::Pack<SRC, PackSize> gamma_pack;\n\n    const int64_t offset = row * row_size + col;\n    const int64_t packed_offset = offset / PackSize;\n    const int64_t gamma_offset = (offset / spatial_size) % channel_size;\n\n    src_pack.storage =\n        
*(reinterpret_cast<const cuda::layer_norm::PackType<SRC, PackSize>*>(src) + packed_offset);\n    SRC gamma_val = static_cast<SRC>(1.0);\n    if (affine) { gamma_val = gamma[gamma_offset]; }\n#pragma unroll\n    for (int i = 0; i < PackSize; ++i) { dst[i] = static_cast<DST>(src_pack.elem[i] * gamma_val); }\n  }\n  bool CanPackAs(size_t pack_size) { return (spatial_size % pack_size) == 0; }\n  const SRC* src;\n  const SRC* gamma;\n  int64_t row_size;\n  int64_t channel_size;\n  int64_t spatial_size;\n};\n\n#ifdef WITH_CUTLASS\n\ntemplate<typename SRC, typename DST, ep::primitive::UnaryOp activation, bool affine>\nstruct ChannelsLastStore {\n  ChannelsLastStore(DST* y, const DST* gamma, const DST* beta, int64_t spatial_size,\n                    int64_t channel_size, int64_t num_groups)\n      : y(y),\n        gamma(gamma),\n        beta(beta),\n        spatial_size(spatial_size),\n        c0(num_groups),\n        c1(channel_size / num_groups),\n        act(0, 0) {}\n\n  template<int PackSize>\n  __device__ void store(const SRC* src, int32_t row, int32_t col) {\n    cuda::layer_norm::Pack<DST, PackSize> y_pack;\n    cuda::layer_norm::Pack<DST, PackSize> gamma_pack;\n    cuda::layer_norm::Pack<DST, PackSize> beta_pack;\n    int32_t spatial_idx;\n    int32_t c1_idx;\n    c1(spatial_idx, c1_idx, col);\n    int32_t batch_idx;\n    int32_t c0_idx;\n    c0(batch_idx, c0_idx, row);\n    const int32_t y_offset =\n        (batch_idx * c0.divisor * c1.divisor * spatial_size + spatial_idx * c0.divisor * c1.divisor\n         + c0_idx * c1.divisor + c1_idx)\n        / PackSize;\n    const int32_t gamma_beta_offset = (c0_idx * c1.divisor + c1_idx) / PackSize;\n    if (affine) {\n      gamma_pack.storage =\n          *(reinterpret_cast<const cuda::layer_norm::PackType<DST, PackSize>*>(gamma)\n            + gamma_beta_offset);\n      beta_pack.storage = *(reinterpret_cast<const cuda::layer_norm::PackType<DST, PackSize>*>(beta)\n                            + gamma_beta_offset);\n    }\n\n#pragma unroll\n    for (int i = 0; i < PackSize; ++i) {\n      DST normalized_i = static_cast<DST>(src[i]);\n      if (affine) {\n        y_pack.elem[i] = act(normalized_i * gamma_pack.elem[i] + beta_pack.elem[i]);\n      } else {\n        // Direct Store.\n        y_pack.elem[i] = act(normalized_i);\n      }\n    }\n    *(reinterpret_cast<cuda::layer_norm::PackType<DST, PackSize>*>(y) + y_offset) = y_pack.storage;\n  }\n  bool CanPackAs(size_t pack_size) { return (c1.divisor % pack_size) == 0; }\n  DST* y;\n  const DST* gamma;\n  const DST* beta;\n  int32_t spatial_size;\n  cutlass::FastDivmod c0;\n  cutlass::FastDivmod c1;\n  ep::primitive::UnaryFunctor<DeviceType::kCUDA, activation, DST, DST> act;\n};\n\ntemplate<typename SRC, typename DST>\nstruct ChannelsLastLoad {\n  using LoadType = DST;\n  ChannelsLastLoad(const SRC* src, int64_t spatial_size, int64_t channel_size, int64_t num_groups)\n      : src(src), spatial_size(spatial_size), c0(num_groups), c1(channel_size / num_groups) {}\n  template<int N>\n  __device__ void load(DST* dst, int32_t row, int32_t col) const {\n    int32_t spatial_idx;\n    int32_t c1_idx;\n    c1(spatial_idx, c1_idx, col);\n    int32_t batch_idx;\n    int32_t c0_idx;\n    c0(batch_idx, c0_idx, row);\n    cuda::layer_norm::Pack<SRC, N> pack;\n    const int32_t offset = (batch_idx * c0.divisor * c1.divisor * spatial_size\n                            + spatial_idx * c0.divisor * c1.divisor + c0_idx * c1.divisor + c1_idx)\n                           / N;\n\n    pack.storage = 
*(reinterpret_cast<const cuda::layer_norm::PackType<SRC, N>*>(src) + offset);\n#pragma unroll\n    for (int i = 0; i < N; ++i) { dst[i] = static_cast<DST>(pack.elem[i]); }\n  }\n  bool CanPackAs(size_t pack_size) { return (c1.divisor % pack_size) == 0; }\n  const SRC* src;\n  int32_t spatial_size;\n  cutlass::FastDivmod c0;\n  cutlass::FastDivmod c1;\n};\n\n#else\n\ntemplate<typename SRC, typename DST, ep::primitive::UnaryOp activation, bool affine>\nstruct ChannelsLastStore {\n  ChannelsLastStore(DST* y, const DST* gamma, const DST* beta, int64_t spatial_size,\n                    int64_t channel_size, int64_t num_groups)\n      : y(y),\n        gamma(gamma),\n        beta(beta),\n        spatial_size(spatial_size),\n        c0(num_groups),\n        c1(channel_size / num_groups),\n        act(0, 0) {}\n\n  template<int PackSize>\n  __device__ void store(const SRC* src, int32_t row, int32_t col) {\n    cuda::layer_norm::Pack<DST, PackSize> y_pack;\n    cuda::layer_norm::Pack<DST, PackSize> gamma_pack;\n    cuda::layer_norm::Pack<DST, PackSize> beta_pack;\n    int32_t spatial_idx = col / c1;\n    int32_t c1_idx = col - spatial_idx * c1;\n    int32_t batch_idx = row / c0;\n    int32_t c0_idx = row - batch_idx * c0;\n    const int32_t y_offset =\n        (batch_idx * c0 * c1 * spatial_size + spatial_idx * c0 * c1 + c0_idx * c1 + c1_idx)\n        / PackSize;\n    const int32_t gamma_beta_offset = (c0_idx * c1 + c1_idx) / PackSize;\n    if (affine) {\n      gamma_pack.storage =\n          *(reinterpret_cast<const cuda::layer_norm::PackType<DST, PackSize>*>(gamma)\n            + gamma_beta_offset);\n      beta_pack.storage = *(reinterpret_cast<const cuda::layer_norm::PackType<DST, PackSize>*>(beta)\n                            + gamma_beta_offset);\n    }\n\n#pragma unroll\n    for (int i = 0; i < PackSize; ++i) {\n      DST normalized_i = static_cast<DST>(src[i]);\n      if (affine) {\n        y_pack.elem[i] = act(normalized_i * gamma_pack.elem[i] + beta_pack.elem[i]);\n      } else {\n        // Direct Store.\n        y_pack.elem[i] = act(normalized_i);\n      }\n    }\n    *(reinterpret_cast<cuda::layer_norm::PackType<DST, PackSize>*>(y) + y_offset) = y_pack.storage;\n  }\n  bool CanPackAs(size_t pack_size) { return (c1 % pack_size) == 0; }\n  DST* y;\n  const DST* gamma;\n  const DST* beta;\n  int32_t spatial_size;\n  int32_t c0;\n  int32_t c1;\n  ep::primitive::UnaryFunctor<DeviceType::kCUDA, activation, DST, DST> act;\n};\n\ntemplate<typename SRC, typename DST>\nstruct ChannelsLastLoad {\n  using LoadType = DST;\n  ChannelsLastLoad(const SRC* src, int64_t spatial_size, int64_t channel_size, int64_t num_groups)\n      : src(src), spatial_size(spatial_size), c0(num_groups), c1(channel_size / num_groups) {}\n  template<int N>\n  __device__ void load(DST* dst, int32_t row, int32_t col) const {\n    int32_t spatial_idx = col / c1;\n    int32_t c1_idx = col - spatial_idx * c1;\n    int32_t batch_idx = row / c0;\n    int32_t c0_idx = row - batch_idx * c0;\n    cuda::layer_norm::Pack<SRC, N> pack;\n    const int32_t offset =\n        (batch_idx * c0 * c1 * spatial_size + spatial_idx * c0 * c1 + c0_idx * c1 + c1_idx) / N;\n\n    pack.storage = *(reinterpret_cast<const cuda::layer_norm::PackType<SRC, N>*>(src) + offset);\n#pragma unroll\n    for (int i = 0; i < N; ++i) { dst[i] = static_cast<DST>(pack.elem[i]); }\n  }\n  bool CanPackAs(size_t pack_size) { return (c1 % pack_size) == 0; }\n  const SRC* src;\n  int32_t spatial_size;\n  int32_t c0;\n  int32_t c1;\n};\n\n#endif  // 
WITH_CUTLASS\n\ntemplate<typename T, ep::primitive::UnaryOp activation, bool affine>\nvoid GroupNormForwardGpu(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size,\n                         const int64_t channel_size, const int64_t spatial_size,\n                         const double epsilon, const T* x_ptr, const T* gamma_ptr,\n                         const T* beta_ptr, T* y_ptr, user_op::Tensor* mean,\n                         user_op::Tensor* inv_variance, bool channels_first) {\n  using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;\n  if (channels_first) {\n    cuda::layer_norm::DirectLoad<T, T> load(x_ptr, norm_size);\n    AffineStore<ComputeType, T, activation, affine> store(y_ptr, norm_size, channel_size,\n                                                          spatial_size, gamma_ptr, beta_ptr);\n\n    cuda::layer_norm::DispatchLayerNorm<decltype(load), decltype(store), ComputeType>(\n        stream->As<ep::CudaStream>()->cuda_stream(), load, store, num_instances, norm_size, epsilon,\n        mean->mut_dptr<ComputeType>(), inv_variance->mut_dptr<ComputeType>());\n  } else {\n    ChannelsLastLoad<T, T> load(x_ptr, spatial_size, channel_size,\n                                channel_size / (norm_size / spatial_size));\n    ChannelsLastStore<ComputeType, T, activation, affine> store(\n        y_ptr, gamma_ptr, beta_ptr, spatial_size, channel_size,\n        channel_size / (norm_size / spatial_size));\n\n    cuda::layer_norm::DispatchLayerNorm<decltype(load), decltype(store), ComputeType>(\n        stream->As<ep::CudaStream>()->cuda_stream(), load, store, num_instances, norm_size, epsilon,\n        mean->mut_dptr<ComputeType>(), inv_variance->mut_dptr<ComputeType>());\n  }\n}\n\ntemplate<typename T, ep::primitive::UnaryOp activation>\nvoid DispatchGroupNormAffine(ep::Stream* stream, const int64_t num_instances,\n                             const int64_t norm_size, const int64_t channel_size,\n                             const int64_t spatial_size, const double epsilon, const T* x_ptr,\n                             const T* gamma_ptr, const T* beta_ptr, T* y_ptr, user_op::Tensor* mean,\n                             user_op::Tensor* inv_variance, bool channels_first) {\n  if (gamma_ptr != nullptr && beta_ptr != nullptr) {\n    GroupNormForwardGpu<T, activation, true>(stream, num_instances, norm_size, channel_size,\n                                             spatial_size, epsilon, x_ptr, gamma_ptr, beta_ptr,\n                                             y_ptr, mean, inv_variance, channels_first);\n  } else {\n    GroupNormForwardGpu<T, activation, false>(stream, num_instances, norm_size, channel_size,\n                                              spatial_size, epsilon, x_ptr, gamma_ptr, beta_ptr,\n                                              y_ptr, mean, inv_variance, channels_first);\n  }\n}\n\ntemplate<typename T>\nvoid DispatchGroupNormForwardGpu(ep::Stream* stream, const int64_t num_instances,\n                                 const int64_t norm_size, const int64_t channel_size,\n                                 const int64_t spatial_size, const double epsilon, const T* x_ptr,\n                                 const T* gamma_ptr, const T* beta_ptr, T* y_ptr,\n                                 user_op::Tensor* mean, user_op::Tensor* inv_variance,\n                                 bool channels_first, const std::string& activation) {\n  if (activation == \"none\") {\n    DispatchGroupNormAffine<T, ep::primitive::UnaryOp::kIdentity>(\n  
      stream, num_instances, norm_size, channel_size, spatial_size, epsilon, x_ptr, gamma_ptr,\n        beta_ptr, y_ptr, mean, inv_variance, channels_first);\n  } else if (activation == \"silu\") {\n    DispatchGroupNormAffine<T, ep::primitive::UnaryOp::kSilu>(\n        stream, num_instances, norm_size, channel_size, spatial_size, epsilon, x_ptr, gamma_ptr,\n        beta_ptr, y_ptr, mean, inv_variance, channels_first);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T, bool affine>\nvoid GroupNormBackwardGpu(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size,\n                          const int64_t channel_size, const int64_t spatial_size, const T* dy_ptr,\n                          const T* x_ptr, const user_op::Tensor* mean,\n                          const user_op::Tensor* inv_variance, const T* gamma_ptr, T* dx_ptr) {\n  using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;\n  cuda::layer_norm::DirectLoad<T, T> load_x(x_ptr, norm_size);\n  ScaleLoad<T, T, affine> load_scaled_dy(dy_ptr, gamma_ptr, norm_size, channel_size, spatial_size);\n  cuda::layer_norm::DirectStore<ComputeType, T> store(dx_ptr, norm_size);\n  OF_CUDA_CHECK((cuda::layer_norm::DispatchLayerNormGrad<decltype(load_x), decltype(load_scaled_dy),\n                                                         decltype(store), ComputeType>(\n      stream->As<ep::CudaStream>()->cuda_stream(), load_x, load_scaled_dy, store,\n      mean->dptr<ComputeType>(), inv_variance->dptr<ComputeType>(), num_instances, norm_size)));\n}\n\ntemplate<typename T>\nvoid LaunchGroupNormBackward(ep::Stream* stream, const int64_t num_instances,\n                             const int64_t norm_size, const int64_t channel_size,\n                             const int64_t spatial_size, const T* dy_ptr, const T* x_ptr,\n                             const user_op::Tensor* mean, const user_op::Tensor* inv_variance,\n                             const T* gamma_ptr, T* dx_ptr) {\n  if (gamma_ptr != nullptr) {\n    GroupNormBackwardGpu<T, true>(stream, num_instances, norm_size, channel_size, spatial_size,\n                                  dy_ptr, x_ptr, mean, inv_variance, gamma_ptr, dx_ptr);\n  } else {\n    GroupNormBackwardGpu<T, false>(stream, num_instances, norm_size, channel_size, spatial_size,\n                                   dy_ptr, x_ptr, mean, inv_variance, gamma_ptr, dx_ptr);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass GroupNormGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  GroupNormGpuKernel() = default;\n  ~GroupNormGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex(\"inv_variance\", 0);\n    const double epsilon = ctx->Attr<double>(\"epsilon\");\n    const int32_t num_groups = ctx->Attr<int32_t>(\"num_groups\");\n    const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n    CHECK_GE(epsilon, CUDNN_BN_MIN_EPSILON);\n    const int64_t num_instances = mean->shape_view().elem_cnt();  // N*num_groups\n    const int64_t norm_size = x->shape_view().elem_cnt() / num_instances;\n    
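// norm_size is the reduction size of one (batch, group) instance,\n    // i.e. (channel_size / num_groups) * spatial_size.\n    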
const int64_t batch_size = x->shape_view().At(0);\n    int64_t channel_size = 0;\n    bool channels_first = false;\n    if (data_format == \"channels_first\") {\n      channel_size = x->shape_view().At(1);\n      channels_first = true;\n    } else if (data_format == \"channels_last\") {\n      channel_size = x->shape_view().At(x->shape_view().NumAxes() - 1);\n      channels_first = false;\n    } else {\n      UNIMPLEMENTED();\n    }\n    const int64_t spatial_size = x->shape_view().elem_cnt() / batch_size / channel_size;\n    const T* gamma_ptr = nullptr;\n    const T* beta_ptr = nullptr;\n    if (ctx->has_input(\"gamma\", 0) && ctx->has_input(\"beta\", 0)) {\n      const user_op::Tensor* gamma = ctx->Tensor4ArgNameAndIndex(\"gamma\", 0);\n      gamma_ptr = gamma->dptr<T>();\n      CHECK_EQ(gamma->shape_view().elem_cnt(), channel_size);\n      const user_op::Tensor* beta = ctx->Tensor4ArgNameAndIndex(\"beta\", 0);\n      beta_ptr = beta->dptr<T>();\n      CHECK_EQ(beta->shape_view().elem_cnt(), channel_size);\n    }\n    DispatchGroupNormForwardGpu<T>(ctx->stream(), num_instances, norm_size, channel_size,\n                                   spatial_size, epsilon, x->dptr<T>(), gamma_ptr, beta_ptr,\n                                   y->mut_dptr<T>(), mean, inv_variance, channels_first,\n                                   ctx->Attr<std::string>(\"activation\"));\n  }\n};\n\n#define REGISTER_GROUP_NORM_CUDA_KERNEL(dtype)                         \\\n  REGISTER_USER_KERNEL(\"group_norm\")                                   \\\n      .SetCreateFn<GroupNormGpuKernel<dtype>>()                        \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nREGISTER_GROUP_NORM_CUDA_KERNEL(half)\nREGISTER_GROUP_NORM_CUDA_KERNEL(float)\nREGISTER_GROUP_NORM_CUDA_KERNEL(double)\n#if CUDA_VERSION >= 11000\nREGISTER_GROUP_NORM_CUDA_KERNEL(nv_bfloat16)\n#endif\n\ntemplate<typename T>\nclass GroupNormGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  GroupNormGradGpuKernel() = default;\n  ~GroupNormGradGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    const user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex(\"inv_variance\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const int64_t num_instances = mean->shape_view().elem_cnt();\n    const int64_t norm_size = x->shape_view().elem_cnt() / num_instances;\n    const int64_t batch_size = x->shape_view().At(0);\n    const int64_t channel_size = x->shape_view().At(1);\n    const int64_t spatial_size = x->shape_view().elem_cnt() / batch_size / channel_size;\n    const T* gamma_ptr = nullptr;\n    if (ctx->has_input(\"gamma\", 0)) {\n      gamma_ptr = ctx->Tensor4ArgNameAndIndex(\"gamma\", 0)->dptr<T>();\n    }\n    LaunchGroupNormBackward<T>(ctx->stream(), num_instances, norm_size, channel_size, spatial_size,\n                               dy->dptr<T>(), x->dptr<T>(), mean, inv_variance, gamma_ptr,\n                               
dx->mut_dptr<T>());\n  }\n};\n\n#define REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(dtype)                    \\\n  REGISTER_USER_KERNEL(\"group_norm_grad\")                              \\\n      .SetCreateFn<GroupNormGradGpuKernel<dtype>>()                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value));\n\nREGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(half)\nREGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(float)\nREGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(double)\n#if CUDA_VERSION >= 11000\nREGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)\n#endif\n\nconstexpr int kReduceBlockSize = 512;\nconstexpr int kBlockSize = 128;\nconstexpr int kNumWaves = 32;\n\ninline cudaError_t GetReduceNumBlocks(int64_t n, int* num_blocks) {\n  int dev;\n  {\n    cudaError_t err = cudaGetDevice(&dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  int sm_count;\n  {\n    cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  int tpm;\n  {\n    cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  *num_blocks =\n      std::max<int>(1, std::min<int64_t>(n, sm_count * tpm / kReduceBlockSize * kNumWaves));\n  return cudaSuccess;\n}\n\ninline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) {\n  int dev;\n  {\n    cudaError_t err = cudaGetDevice(&dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  int sm_count;\n  {\n    cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  int tpm;\n  {\n    cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  *num_blocks = std::max<int>(1, std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,\n                                                   sm_count * tpm / kBlockSize * kNumWaves));\n  return cudaSuccess;\n}\n\ntemplate<typename T>\nstruct SumOp {\n  __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a + b; }\n};\n\ntemplate<typename T, int PackSize>\nstruct GetPackType {\n  using type = typename std::aligned_storage<sizeof(T) * PackSize, sizeof(T) * PackSize>::type;\n};\n\ntemplate<typename T, int PackSize>\nusing PackType = typename GetPackType<T, PackSize>::type;\n\ntemplate<typename T, int PackSize>\nunion Pack {\n  static_assert(sizeof(PackType<T, PackSize>) == sizeof(T) * PackSize, \"\");\n  __device__ Pack(T val) {\n    for (int i = 0; i < PackSize; i++) { elem[i] = val; }\n  }\n\n  T elem[PackSize];\n  PackType<T, PackSize> storage;\n};\n\nconstexpr int kMaxPackBytes = 128 / 8;\nconstexpr int kMaxPackSize = 8;\n\nconstexpr int Min(int a, int b) { return a < b ? 
a : b; }\n\ntemplate<typename T>\nconstexpr int GetPackSize() {\n  return Min(kMaxPackBytes / sizeof(T), kMaxPackSize);\n}\n\ntemplate<typename T, typename ComputeType, int PackSize>\n__global__ void GroupNormParamGradKernel(const T* dy, const T* x, const ComputeType* mean,\n                                         const ComputeType* inv_var,\n                                         ComputeType* dgamma_partial_sum,\n                                         ComputeType* dbeta_partial_sum, const int32_t batch_size,\n                                         const int32_t group_size, const int32_t channel_size,\n                                         const int32_t spatial_size) {\n  using LoadType = PackType<T, PackSize>;\n  const int32_t batch_channel_size = batch_size * channel_size;\n  for (int32_t batch_channel_id = blockIdx.x; batch_channel_id < batch_channel_size;\n       batch_channel_id += gridDim.x) {\n    const int32_t batch_id = batch_channel_id / channel_size;\n    const int32_t channel_id = batch_channel_id % channel_size;\n    const int32_t group_num = channel_size / group_size;\n    const int32_t batch_group_id = batch_id * group_size + channel_id / group_num;\n\n    ComputeType mean_val = mean[batch_group_id];\n    ComputeType inv_var_val = inv_var[batch_group_id];\n\n    Pack<ComputeType, PackSize> ds_sum_pack(0);\n    Pack<ComputeType, PackSize> db_sum_pack(0);\n\n    for (int32_t spatial = threadIdx.x * PackSize; spatial < spatial_size;\n         spatial += blockDim.x * PackSize) {\n      Pack<T, PackSize> dy_pack(0);\n      Pack<T, PackSize> x_pack(0);\n      const int32_t load_idx = batch_channel_id * spatial_size + spatial;\n      const LoadType* dy_load = reinterpret_cast<const LoadType*>(dy + load_idx);\n      dy_pack.storage = *dy_load;\n      const LoadType* x_load = reinterpret_cast<const LoadType*>(x + load_idx);\n      x_pack.storage = *x_load;\n#pragma unroll\n      for (int i = 0; i < PackSize; i++) {\n        ds_sum_pack.elem[i] += static_cast<ComputeType>(dy_pack.elem[i])\n                               * (static_cast<ComputeType>(x_pack.elem[i]) - mean_val)\n                               * inv_var_val;\n        db_sum_pack.elem[i] += static_cast<ComputeType>(dy_pack.elem[i]);\n      }\n    }\n\n    ComputeType ds_sum = 0.0;\n    ComputeType db_sum = 0.0;\n\n#pragma unroll\n    for (int i = 0; i < PackSize; i++) {\n      ds_sum += ds_sum_pack.elem[i];\n      db_sum += db_sum_pack.elem[i];\n    }\n\n    __syncthreads();\n    typedef cub::BlockReduce<ComputeType, kReduceBlockSize> BlockReduce;\n    __shared__ typename BlockReduce::TempStorage temp_storage1;\n    __shared__ typename BlockReduce::TempStorage temp_storage2;\n    ComputeType ds_sum_result = BlockReduce(temp_storage1).Reduce(ds_sum, SumOp<ComputeType>());\n    ComputeType db_sum_result = BlockReduce(temp_storage2).Reduce(db_sum, SumOp<ComputeType>());\n    if (threadIdx.x == 0) {\n      dgamma_partial_sum[batch_channel_id] = ds_sum_result;\n      dbeta_partial_sum[batch_channel_id] = db_sum_result;\n    }\n  }\n}\n\ntemplate<typename T, typename ComputeType>\n__global__ void BatchReduceGammaBetaGradKernel(ComputeType* ds_sum, ComputeType* db_sum, T* dgamma,\n                                               T* dbeta, const int32_t batch_size,\n                                               const int32_t group_size, const int32_t channel_size,\n                                               const int32_t spatial_size) {\n  const int32_t group_num = channel_size / group_size;\n  
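// Grid-stride loop over channels: each thread sums the per-(batch, channel)\n  // partials over the batch dimension to produce dgamma / dbeta for its channel.\n  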
CUDA_1D_KERNEL_LOOP(channel_idx, channel_size) {\n    ComputeType dgamma_sum = 0.0;\n    ComputeType dbeta_sum = 0.0;\n    for (int batch_id = 0; batch_id < batch_size; batch_id++) {\n      const int32_t batch_group_id = batch_id * group_size + channel_idx / group_num;\n      const int32_t batch_channel_id = batch_id * channel_size + channel_idx;\n      dgamma_sum += ds_sum[batch_channel_id];\n      dbeta_sum += db_sum[batch_channel_id];\n    }\n    dgamma[channel_idx] = dgamma_sum;\n    dbeta[channel_idx] = dbeta_sum;\n  }\n}\n\ntemplate<typename T>\nint32_t GetLaunchPackSize(const int32_t spatial_size) {\n  for (int pack_size = GetPackSize<T>(); pack_size > 0; pack_size /= 2) {\n    if (spatial_size % pack_size == 0) { return pack_size; }\n  }\n  return 1;\n}\n\ntemplate<typename T, typename ComputeType>\nvoid DispatchGroupNormParamGradKernel(ep::Stream* stream, const T* dy, const T* x,\n                                      const ComputeType* mean, const ComputeType* inv_var,\n                                      ComputeType* reduce_ds_buf, ComputeType* reduce_db_buf,\n                                      const int32_t batch_size, const int32_t group_size,\n                                      const int32_t channel_size, const int32_t spatial_size) {\n  const int launch_pack_size = GetLaunchPackSize<T>(spatial_size);\n  int num_blocks;\n  OF_CUDA_CHECK(GetReduceNumBlocks(batch_size * channel_size, &num_blocks));\n  if (launch_pack_size == 8) {\n    GroupNormParamGradKernel<T, ComputeType, 8>\n        <<<num_blocks, kReduceBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            dy, x, mean, inv_var, reduce_ds_buf, reduce_db_buf, batch_size, group_size,\n            channel_size, spatial_size);\n  } else if (launch_pack_size == 4) {\n    GroupNormParamGradKernel<T, ComputeType, 4>\n        <<<num_blocks, kReduceBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            dy, x, mean, inv_var, reduce_ds_buf, reduce_db_buf, batch_size, group_size,\n            channel_size, spatial_size);\n  } else if (launch_pack_size == 2) {\n    GroupNormParamGradKernel<T, ComputeType, 2>\n        <<<num_blocks, kReduceBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            dy, x, mean, inv_var, reduce_ds_buf, reduce_db_buf, batch_size, group_size,\n            channel_size, spatial_size);\n  } else {\n    GroupNormParamGradKernel<T, ComputeType, 1>\n        <<<num_blocks, kReduceBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            dy, x, mean, inv_var, reduce_ds_buf, reduce_db_buf, batch_size, group_size,\n            channel_size, spatial_size);\n  }\n}\n\ntemplate<typename T>\nclass GroupNormParamGradGpuKernel final : public user_op::OpKernel,\n                                          public user_op::CudaGraphSupport {\n public:\n  GroupNormParamGradGpuKernel() = default;\n  ~GroupNormParamGradGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    const user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex(\"inv_variance\", 0);\n    user_op::Tensor* dgamma = ctx->Tensor4ArgNameAndIndex(\"dgamma\", 0);\n    user_op::Tensor* dbeta = 
ctx->Tensor4ArgNameAndIndex(\"dbeta\", 0);\n    const int64_t num_instances = mean->shape_view().elem_cnt();\n    const int64_t norm_size = x->shape_view().elem_cnt() / num_instances;\n    const int64_t batch_size = x->shape_view().At(0);\n    const int64_t channel_size = x->shape_view().At(1);\n    const int64_t spatial_size = x->shape_view().elem_cnt() / batch_size / channel_size;\n    const int64_t group_size = num_instances / batch_size;\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;\n    ComputeType* reduce_ds_buf_ptr = reinterpret_cast<ComputeType*>(tmp_buffer->mut_dptr<char>());\n    ComputeType* reduce_db_buf_ptr = reinterpret_cast<ComputeType*>(\n        tmp_buffer->mut_dptr<char>() + batch_size * channel_size * sizeof(T));\n    DispatchGroupNormParamGradKernel<T, ComputeType>(\n        ctx->stream(), dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(),\n        inv_variance->dptr<ComputeType>(), reduce_ds_buf_ptr, reduce_db_buf_ptr, batch_size,\n        group_size, channel_size, spatial_size);\n    int num_blocks;\n    OF_CUDA_CHECK(GetNumBlocks(channel_size, &num_blocks));\n    // Note(zhengzekang): In large batchsize, it is recommend to use gemm to reduce. (1, N) matmul\n    // (N, C)\n    BatchReduceGammaBetaGradKernel<T, ComputeType>\n        <<<num_blocks, kBlockSize, 0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n            reduce_ds_buf_ptr, reduce_db_buf_ptr, dgamma->mut_dptr<T>(), dbeta->mut_dptr<T>(),\n            batch_size, group_size, channel_size, spatial_size);\n  };\n};\n\n#define REGISTER_GROUP_NORM_PARAM_GRAD_CUDA_KERNEL(dtype, compute_dtype)                  \\\n  REGISTER_USER_KERNEL(\"group_norm_param_grad\")                                           \\\n      .SetCreateFn<GroupNormParamGradGpuKernel<dtype>>()                                  \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                 \\\n        const auto& x = ctx->InputTensorDesc(\"x\", 0);                                     \\\n        const int64_t batch_size = x.shape().At(0);                                       \\\n        const int64_t channel_size = x.shape().At(1);                                     \\\n        size_t tmp_buffer_size = (2 * batch_size * channel_size) * sizeof(compute_dtype); \\\n        return tmp_buffer_size;                                                           \\\n      })                                                                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                    \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value));\n\nREGISTER_GROUP_NORM_PARAM_GRAD_CUDA_KERNEL(half, float)\nREGISTER_GROUP_NORM_PARAM_GRAD_CUDA_KERNEL(float, float)\nREGISTER_GROUP_NORM_PARAM_GRAD_CUDA_KERNEL(double, double)\n#if CUDA_VRSION >= 11000\nREGISTER_GROUP_NORM_PARAM_GRAD_CUDA_KERNEL(nv_bfloat16, float)\n#endif\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/grouped_matmul_bias.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/common/scalar.h\"\n\nnamespace oneflow {\n\nstruct Problem {\n  Problem(int64_t m, int64_t n, int64_t k) : m(m), n(n), k(k) {}\n  int64_t m;\n  int64_t n;\n  int64_t k;\n};\n\ninline bool operator==(const Problem& lhs, const Problem& rhs) {\n  return lhs.m == rhs.m && lhs.n == rhs.n && lhs.k == rhs.k;\n}\n\n}  // namespace oneflow\n\nnamespace std {\n\ntemplate<>\nstruct hash<oneflow::Problem> {\n  std::size_t operator()(const oneflow::Problem& p) const {\n    return oneflow::Hash<int64_t, int64_t, int64_t>(p.m, p.n, p.k);\n  }\n};\n\n}  // namespace std\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr int64_t kMaxProblemBatch = 64;\n\ntemplate<typename T>\nstruct Buffer {\n  const T* x;\n  const T* w;\n  const T* b;\n  T* y;\n};\n\ntemplate<typename T>\nstruct Param {\n  Param(const Problem& problem, std::vector<Buffer<T>> buffers)\n      : problem(problem), n(buffers.size()) {\n    std::copy(buffers.cbegin(), buffers.cend(), buffer);\n    elem_cnt = n * problem.m * problem.n;\n  }\n  Problem problem;\n  Buffer<T> buffer[kMaxProblemBatch];\n  int n;\n  int elem_cnt;\n};\n\ntemplate<typename T, bool has_biases>\n__global__ void InitPtrAndApplyBias(Param<T> p, void** ptr_arr) {\n  if (has_biases) {\n    CUDA_1D_KERNEL_LOOP(i, p.elem_cnt) {\n      const int32_t p_idx = i / (p.problem.m * p.problem.n);\n      const int32_t y_idx = i % (p.problem.m * p.problem.n);\n      const int32_t m_idx = y_idx / p.problem.n;\n      const int32_t n_idx = y_idx % p.problem.n;\n      p.buffer[p_idx].y[y_idx] = p.buffer[p_idx].b[n_idx];\n    }\n  }\n  CUDA_1D_KERNEL_LOOP(i, p.n) {\n    ptr_arr[i] = const_cast<T*>(p.buffer[i].x);\n    ptr_arr[i + kMaxProblemBatch] = const_cast<T*>(p.buffer[i].w);\n    ptr_arr[i + 2 * kMaxProblemBatch] = p.buffer[i].y;\n  }\n}\n\nunion CublasScalarParameter {\n  double d;\n  float s;\n  half h;\n};\n\nCublasScalarParameter GetCublasScalarParameter(Scalar scalar, cudaDataType_t compute_type) {\n  CublasScalarParameter sp{};\n  if (compute_type == CUDA_R_64F) {\n    sp.d = scalar.Value<double>();\n  } else if (compute_type == CUDA_R_32F) {\n    sp.s = scalar.Value<float>();\n  } else if (compute_type == CUDA_R_16F) {\n    sp.h = static_cast<half>(scalar.Value<float>());\n  } else {\n    UNIMPLEMENTED();\n  }\n  return sp;\n}\n\ntemplate<typename T>\nvoid ApplyGroup(const Problem& problem, std::vector<Buffer<T>> ptrs, bool has_biases,\n                void* workspace, ep::Stream* stream) {\n  Param<T> params(problem, ptrs);\n  void** ptr_arr = reinterpret_cast<void**>(workspace);\n  if (has_biases) {\n    RUN_CUDA_KERNEL((InitPtrAndApplyBias<T, true>), stream, params.elem_cnt, params, ptr_arr);\n  } else {\n 
   RUN_CUDA_KERNEL((InitPtrAndApplyBias<T, false>), stream, params.n, params, ptr_arr);\n  }\n  float alpha = 1.0;\n  float beta = has_biases ? 1.0 : 0.0;\n  cudaDataType_t data_type{};\n  cudaDataType_t compute_type{};\n  if (std::is_same<T, half>::value) {\n    data_type = CUDA_R_16F;\n    const bool allow_half_accumulation =\n        ParseBooleanFromEnv(\"ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION\", false);\n    if (allow_half_accumulation) {\n      compute_type = CUDA_R_16F;\n    } else {\n      compute_type = CUDA_R_32F;\n    }\n  } else if (std::is_same<T, float>::value) {\n    data_type = CUDA_R_32F;\n    compute_type = CUDA_R_32F;\n  } else {\n    UNIMPLEMENTED();\n  }\n  auto sp_alpha = GetCublasScalarParameter(alpha, compute_type);\n  auto sp_beta = GetCublasScalarParameter(beta, compute_type);\n  OF_CUBLAS_CHECK(cublasGemmBatchedEx(\n      stream->As<ep::CudaStream>()->cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, problem.n, problem.m,\n      problem.k, &sp_alpha, ptr_arr + kMaxProblemBatch, data_type, problem.k, ptr_arr, data_type,\n      problem.k, &sp_beta, ptr_arr + 2 * kMaxProblemBatch, data_type, problem.n, params.n,\n      compute_type, CUBLAS_GEMM_DEFAULT));\n}\n\ntemplate<typename T>\nclass GroupedMatmulBiasKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  GroupedMatmulBiasKernel() = default;\n  ~GroupedMatmulBiasKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache* cache) const override {\n    HashMap<Problem, std::vector<Buffer<T>>> groups;\n    const int32_t input_size = ctx->input_size(\"xs\");\n    CHECK_EQ(ctx->input_size(\"weights\"), input_size);\n    const bool has_biases = ctx->has_input(\"biases\", 0);\n    if (has_biases) { CHECK_EQ(ctx->input_size(\"biases\"), input_size); }\n    CHECK_EQ(ctx->output_size(\"ys\"), input_size);\n    for (int32_t i = 0; i < input_size; ++i) {\n      const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"xs\", i);\n      const user_op::Tensor* w = ctx->Tensor4ArgNameAndIndex(\"weights\", i);\n      const user_op::Tensor* b = has_biases ? ctx->Tensor4ArgNameAndIndex(\"biases\", i) : nullptr;\n      user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"ys\", i);\n      CHECK_GE(x->shape_view().NumAxes(), 2);\n      const int64_t k = x->shape_view().At(x->shape_view().NumAxes() - 1);\n      const int64_t m = x->shape_view().elem_cnt() / k;\n      CHECK_EQ(w->shape_view().NumAxes(), 2);\n      CHECK_EQ(w->shape_view().At(1), k);\n      const int64_t n = w->shape_view().At(0);\n      if (has_biases) {\n        CHECK_EQ(b->shape_view().NumAxes(), 1);\n        CHECK_EQ(b->shape_view().At(0), n);\n      }\n      CHECK_EQ(y->shape_view().NumAxes(), x->shape_view().NumAxes());\n      CHECK_EQ(y->shape_view().At(y->shape_view().NumAxes() - 1), n);\n      for (int32_t j = 0; j < y->shape_view().NumAxes() - 1; ++j) {\n        CHECK_EQ(y->shape_view().At(j), x->shape_view().At(j));\n      }\n      groups[Problem(m, n, k)].push_back(Buffer<T>{\n          x->dptr<T>(), w->dptr<T>(), has_biases ? 
b->dptr<T>() : nullptr, y->mut_dptr<T>()});\n    }\n    void* workspace = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0)->mut_dptr();\n    for (const auto& group : groups) {\n      for (size_t i = 0; i < group.second.size(); i += kMaxProblemBatch) {\n        std::vector<Buffer<T>> ptrs(\n            {group.second.begin() + i,\n             group.second.begin() + i\n                 + std::min<size_t>(group.second.size() - i, kMaxProblemBatch)});\n        ApplyGroup<T>(group.first, ptrs, has_biases, workspace, ctx->stream());\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_GROUPED_MATMUL_BIAS_KERNEL_GPU(cpp_type, data_type)    \\\n  REGISTER_USER_KERNEL(\"grouped_matmul_bias\")                           \\\n      .SetCreateFn<GroupedMatmulBiasKernel<cpp_type>>()                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)  \\\n                       && (user_op::HobDataType(\"ys\", 0) == data_type)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {     \\\n        return kMaxProblemBatch * 3 * sizeof(void*);                    \\\n      });                                                               \\\n  ;\n\nREGISTER_GROUPED_MATMUL_BIAS_KERNEL_GPU(float, DataType::kFloat)\nREGISTER_GROUPED_MATMUL_BIAS_KERNEL_GPU(half, DataType::kFloat16)\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/groupwise_quantization_kernels.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.cuh\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <cub/cub.cuh>\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, int pack_size>\nstruct alignas(sizeof(T) * pack_size) AlignedArray {\n  __device__ AlignedArray() {\n    // do nothing\n  }\n  union {\n    T elem[pack_size];\n  };\n};\n\ntemplate<typename Src, typename Dst, size_t pack_size>\nstruct Cast {\n  __device__ void operator()(const AlignedArray<Src, pack_size>& src,\n                             AlignedArray<Dst, pack_size>* dst) {\n#pragma unroll\n    for (int i = 0; i < pack_size; ++i) { dst->elem[i] = static_cast<Dst>(src.elem[i]); }\n  }\n};\n\ntemplate<typename Dst, size_t pack_size>\nstruct Cast<uint8_t, Dst, pack_size> {\n  __device__ void operator()(const AlignedArray<uint8_t, pack_size>& src,\n                             AlignedArray<Dst, pack_size>* dst) {\n#pragma unroll\n    for (int i = 0; i < pack_size; ++i) { dst->elem[i] = static_cast<Dst>(src.elem[i]); }\n  }\n\n  __device__ void operator()(const AlignedArray<uint8_t, pack_size>& src,\n                             AlignedArray<Dst, pack_size * 2>* dst) {\n#pragma unroll\n    for (int i = 0; i < pack_size; ++i) {\n      const uint8_t q = src.elem[i];\n      const uint8_t hi = (q >> 4);\n      const uint8_t lo = (q & 0xF);\n      dst->elem[i * 2 + 0] = static_cast<Dst>(hi);\n      dst->elem[i * 2 + 1] = static_cast<Dst>(lo);\n    }\n  }\n};\n\ntemplate<typename Dst, size_t pack_size>\nstruct Cast<int8_t, Dst, pack_size> {\n  __device__ void operator()(const AlignedArray<int8_t, pack_size>& src,\n                             AlignedArray<Dst, pack_size>* dst) {\n#pragma unroll\n    for (int i = 0; i < pack_size; ++i) { dst->elem[i] = static_cast<Dst>(src.elem[i]); }\n  }\n\n  __device__ void operator()(const AlignedArray<int8_t, pack_size>& src,\n                             AlignedArray<Dst, pack_size * 2>* dst) {\n#pragma unroll\n    for (int i = 0; i < pack_size; ++i) {\n      const int8_t q = src.elem[i];\n      const int8_t hi = (q >> 4);\n      int8_t lo = (q << 4);\n      lo = (lo >> 4);\n      dst->elem[i * 2 + 0] = static_cast<Dst>(hi);\n      dst->elem[i * 2 + 1] = static_cast<Dst>(lo);\n    }\n  }\n};\n\ntemplate<typename C, size_t pack_size>\nstruct InplaceAddScalar {\n  __device__ void operator()(AlignedArray<C, pack_size>* array, C scalar) {\n#pragma unroll\n    for (int i = 0; i < pack_size; ++i) { array->elem[i] += scalar; }\n  }\n};\n\ntemplate<typename T, size_t pack_size>\nstruct InplaceFmaScalar {\n  __device__ void operator()(AlignedArray<T, pack_size>* array, T m, T a) {\n#pragma unroll\n    for (int i = 0; i < pack_size; ++i) { array->elem[i] = array->elem[i] * m + a; }\n  }\n};\n\n#if __CUDA_ARCH_ >= 
template<size_t pack_size>\nstruct InplaceFmaScalar<half, pack_size> {\n  __device__ void operator()(AlignedArray<half, pack_size>* array, half m, half a) {\n    if (pack_size == 1) {\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) { array->elem[i] = array->elem[i] * m + a; }\n    } else {\n      const half2 m2 = __half2half2(m);\n      const half2 a2 = __half2half2(a);\n      half2* h2 = reinterpret_cast<half2*>(array->elem);\n#pragma unroll\n      for (int i = 0; i < pack_size / 2; ++i) { h2[i] = __hfma2(h2[i], m2, a2); }\n    }\n  }\n};\n#endif  // __CUDA_ARCH__ >= 530\n\ntemplate<typename T, size_t pack_size>\nstruct InplaceFma {\n  __device__ void operator()(AlignedArray<T, pack_size>* a, const AlignedArray<T, pack_size>& b,\n                             const AlignedArray<T, pack_size>& c) {\n#pragma unroll\n    for (int i = 0; i < pack_size; ++i) { a->elem[i] = a->elem[i] * b.elem[i] + c.elem[i]; }\n  }\n};\n\ntemplate<typename T, size_t pack_size>\nstruct InplaceMulScalar {\n  __device__ void operator()(AlignedArray<T, pack_size>* a, T b) {\n#pragma unroll\n    for (int i = 0; i < pack_size; ++i) { a->elem[i] = a->elem[i] * b; }\n  }\n};\n\ntemplate<typename T, typename C, size_t pack_size>\nstruct MultiplyAccumulate {\n  __device__ void operator()(const AlignedArray<T, pack_size>& a,\n                             const AlignedArray<T, pack_size>& b, C* sum) {\n#pragma unroll\n    for (int i = 0; i < pack_size; ++i) { *sum += static_cast<C>(a.elem[i] * b.elem[i]); }\n  }\n};\n\ntemplate<size_t pack_size>\nstruct MultiplyAccumulate<half, float, pack_size> {\n  __device__ void operator()(const AlignedArray<half, pack_size>& a,\n                             const AlignedArray<half, pack_size>& b, float* sum) {\n    if (pack_size == 1) {\n#pragma unroll\n      for (int i = 0; i < pack_size; ++i) { *sum += static_cast<float>(a.elem[i] * b.elem[i]); }\n    } else {\n      const half2* a2 = reinterpret_cast<const half2*>(a.elem);\n      const half2* b2 = reinterpret_cast<const half2*>(b.elem);\n\n      for (int i = 0; i < pack_size / 2; ++i) {\n        const half2 c2 = __hmul2(a2[i], b2[i]);\n        const float2 f2 = __half22float2(c2);\n        *sum += f2.x;\n        *sum += f2.y;\n      }\n    }\n  }\n};\n\ntemplate<typename T, typename U, typename Index, size_t d_pack_size, size_t q_pack_size, int bits,\n         bool symmetric, bool outer_size_1>\n__global__ void Dequantize3D(Index packed_elem_cnt, Index group_size, Index packed_inner_size,\n                             const AlignedArray<U, q_pack_size>* quantized,\n                             const AlignedArray<T, d_pack_size>* scale,\n                             const AlignedArray<T, d_pack_size>* zero,\n                             AlignedArray<T, d_pack_size>* out) {\n  const Index packed_group_inner_size = group_size * packed_inner_size;\n  CUDA_1D_KERNEL_LOOP_T(Index, i, packed_elem_cnt) {\n    const Index outer_id = outer_size_1 ? 
0 : i / packed_group_inner_size;\n    const Index group_inner_offset = i - outer_id * packed_group_inner_size;\n    const Index group_id = group_inner_offset / packed_inner_size;\n    const Index inner_id = group_inner_offset - group_id * packed_inner_size;\n    const Index scale_offset = outer_id * packed_inner_size + inner_id;\n    const AlignedArray<T, d_pack_size> group_scale = scale[scale_offset];\n    AlignedArray<T, d_pack_size> group_zero;\n    if (symmetric) {\n      if (std::is_same<U, uint8_t>::value) {\n        group_zero = group_scale;\n        InplaceMulScalar<T, d_pack_size>()(&group_zero, -static_cast<T>(((1 << (bits - 1)) - 1)));\n      } else {\n#pragma unroll\n        for (int i = 0; i < d_pack_size; ++i) { group_zero.elem[i] = 0; }\n      }\n    } else {\n      group_zero = zero[scale_offset];\n    }\n    AlignedArray<T, d_pack_size> values;\n    const AlignedArray<U, q_pack_size> q = quantized[i];\n    Cast<U, T, q_pack_size>()(q, &values);\n    InplaceFma<T, d_pack_size>()(&values, group_scale, group_zero);\n    out[i] = values;\n  }\n}\n\ntemplate<typename T, typename U, int num_bits, bool symmetric, size_t d_pack_size,\n         size_t q_pack_size, bool outer_size_1>\nvoid LaunchDequantize3D(ep::CudaStream* stream, int64_t outer_size, int64_t group_size,\n                        int64_t inner_size, const U* in, const T* scale, const T* zero, T* out) {\n  if constexpr (sizeof(T) * d_pack_size <= 16 && q_pack_size > 0) {\n    const int64_t packed_elem_cnt = outer_size * group_size * inner_size / d_pack_size;\n    const int64_t packed_inner_size = inner_size / d_pack_size;\n    if (packed_elem_cnt <= (1 << 30)) {\n      RUN_CUDA_KERNEL((Dequantize3D<T, U, int32_t, d_pack_size, q_pack_size, num_bits, symmetric,\n                                    outer_size_1>),\n                      stream, packed_elem_cnt, packed_elem_cnt, group_size, packed_inner_size,\n                      reinterpret_cast<const AlignedArray<U, q_pack_size>*>(in),\n                      reinterpret_cast<const AlignedArray<T, d_pack_size>*>(scale),\n                      reinterpret_cast<const AlignedArray<T, d_pack_size>*>(zero),\n                      reinterpret_cast<AlignedArray<T, d_pack_size>*>(out));\n    } else {\n      RUN_CUDA_KERNEL((Dequantize3D<T, U, int64_t, d_pack_size, q_pack_size, num_bits, symmetric,\n                                    outer_size_1>),\n                      stream, packed_elem_cnt, packed_elem_cnt, group_size, packed_inner_size,\n                      reinterpret_cast<const AlignedArray<U, q_pack_size>*>(in),\n                      reinterpret_cast<const AlignedArray<T, d_pack_size>*>(scale),\n                      reinterpret_cast<const AlignedArray<T, d_pack_size>*>(zero),\n                      reinterpret_cast<AlignedArray<T, d_pack_size>*>(out));\n    }\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T, typename U, int num_bits, bool symmetric, size_t d_pack_size,\n         size_t q_pack_size>\nvoid DispatchDequantize3DOuterSize1(ep::CudaStream* stream, int64_t outer_size, int64_t group_size,\n                                    int64_t inner_size, const U* in, const T* scale, const T* zero,\n                                    T* out) {\n  if (outer_size == 1) {\n    LaunchDequantize3D<T, U, num_bits, symmetric, d_pack_size, q_pack_size, true>(\n        stream, outer_size, group_size, inner_size, in, scale, zero, out);\n  } else {\n    LaunchDequantize3D<T, U, num_bits, symmetric, d_pack_size, q_pack_size, false>(\n        stream, 
outer_size, group_size, inner_size, in, scale, zero, out);\n  }\n}\n\ntemplate<typename T, typename U, int num_bits, bool symmetric>\nvoid DispatchDequantize3D(ep::CudaStream* stream, int64_t outer_size, int64_t group_size,\n                          int64_t inner_size, const U* in, const T* scale, const T* zero, T* out) {\n  constexpr int32_t max_pack_size = 16 / sizeof(T);\n  constexpr int32_t data_per_quant = 8 / num_bits;\n  int32_t pack_size = max_pack_size;\n  while (inner_size % pack_size != 0) { pack_size /= 2; }\n  if (pack_size == 16) {\n    DispatchDequantize3DOuterSize1<T, U, num_bits, symmetric, 16, 16 / data_per_quant>(\n        stream, outer_size, group_size, inner_size, in, scale, zero, out);\n  } else if (pack_size == 8) {\n    DispatchDequantize3DOuterSize1<T, U, num_bits, symmetric, 8, 8 / data_per_quant>(\n        stream, outer_size, group_size, inner_size, in, scale, zero, out);\n  } else if (pack_size == 4) {\n    DispatchDequantize3DOuterSize1<T, U, num_bits, symmetric, 4, 4 / data_per_quant>(\n        stream, outer_size, group_size, inner_size, in, scale, zero, out);\n  } else if (pack_size == 2) {\n    DispatchDequantize3DOuterSize1<T, U, num_bits, symmetric, 2, 2 / data_per_quant>(\n        stream, outer_size, group_size, inner_size, in, scale, zero, out);\n  } else if (pack_size == 1) {\n    DispatchDequantize3DOuterSize1<T, U, num_bits, symmetric, 1, 1 / data_per_quant>(\n        stream, outer_size, group_size, inner_size, in, scale, zero, out);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T, typename U, typename Index, size_t d_pack_size, size_t q_pack_size, int bits,\n         bool symmetric>\n__global__ void DequantizeInnerSize1(Index packed_elem_cnt, Index packed_group_size,\n                                     const AlignedArray<U, q_pack_size>* quantized, const T* scale,\n                                     const T* zero, AlignedArray<T, d_pack_size>* out) {\n  CUDA_1D_KERNEL_LOOP_T(Index, i, packed_elem_cnt) {\n    const Index group_id = i / packed_group_size;\n    const T group_scale = scale[group_id];\n    T group_zero;\n    if (symmetric) {\n      if (std::is_same<U, uint8_t>::value) {\n        group_zero = -static_cast<T>(((1 << (bits - 1)) - 1)) * group_scale;\n      } else {\n        group_zero = 0;\n      }\n    } else {\n      group_zero = zero[group_id];\n    }\n    AlignedArray<T, d_pack_size> values;\n    AlignedArray<U, q_pack_size> q = quantized[i];\n    Cast<U, T, q_pack_size>()(q, &values);\n    InplaceFmaScalar<T, d_pack_size>()(&values, group_scale, group_zero);\n    out[i] = values;\n  }\n}\n\ntemplate<typename T, typename U, int num_bits, bool symmetric, size_t d_pack_size,\n         size_t q_pack_size>\nvoid LaunchDequantizeInnerSize1(ep::CudaStream* stream, int64_t outer_size, int64_t group_size,\n                                const U* in, const T* scale, const T* zero, T* out) {\n  if constexpr (sizeof(T) * d_pack_size <= 16 && q_pack_size > 0) {\n    const int64_t packed_elem_cnt = outer_size * group_size / d_pack_size;\n    const int64_t packed_group_size = group_size / d_pack_size;\n    if (packed_elem_cnt <= (1 << 30)) {\n      RUN_CUDA_KERNEL(\n          (DequantizeInnerSize1<T, U, int32_t, d_pack_size, q_pack_size, num_bits, symmetric>),\n          stream, packed_elem_cnt, packed_elem_cnt, packed_group_size,\n          reinterpret_cast<const AlignedArray<U, q_pack_size>*>(in), scale, zero,\n          reinterpret_cast<AlignedArray<T, d_pack_size>*>(out));\n    } else {\n      RUN_CUDA_KERNEL(\n          
(DequantizeInnerSize1<T, U, int64_t, d_pack_size, q_pack_size, num_bits, symmetric>),\n          stream, packed_elem_cnt, packed_elem_cnt, packed_group_size,\n          reinterpret_cast<const AlignedArray<U, q_pack_size>*>(in), scale, zero,\n          reinterpret_cast<AlignedArray<T, d_pack_size>*>(out));\n    }\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T, typename U, int num_bits, bool symmetric>\nvoid DispatchDequantizeInnerSize1PackSize(ep::CudaStream* stream, int64_t outer_size,\n                                          int64_t group_size, const U* in, const T* scale,\n                                          const T* zero, T* out) {\n  constexpr int32_t max_pack_size = 16 / sizeof(T);\n  int32_t pack_size = max_pack_size;\n  while (group_size % pack_size != 0) { pack_size /= 2; }\n  constexpr int32_t data_per_quant = 8 / num_bits;\n  CHECK(group_size % data_per_quant == 0);\n  if (pack_size == 16) {\n    LaunchDequantizeInnerSize1<T, U, num_bits, symmetric, 16, 16 / data_per_quant>(\n        stream, outer_size, group_size, in, scale, zero, out);\n  } else if (pack_size == 8) {\n    LaunchDequantizeInnerSize1<T, U, num_bits, symmetric, 8, 8 / data_per_quant>(\n        stream, outer_size, group_size, in, scale, zero, out);\n  } else if (pack_size == 4) {\n    LaunchDequantizeInnerSize1<T, U, num_bits, symmetric, 4, 4 / data_per_quant>(\n        stream, outer_size, group_size, in, scale, zero, out);\n  } else if (pack_size == 2) {\n    LaunchDequantizeInnerSize1<T, U, num_bits, symmetric, 2, 2 / data_per_quant>(\n        stream, outer_size, group_size, in, scale, zero, out);\n  } else if (pack_size == 1) {\n    LaunchDequantizeInnerSize1<T, U, num_bits, symmetric, 1, 1 / data_per_quant>(\n        stream, outer_size, group_size, in, scale, zero, out);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T, typename U, int num_bits, bool symmetric>\nvoid DispatchDequantizeSize(ep::CudaStream* stream, int64_t outer_size, int64_t group_size,\n                            int64_t inner_size, const U* in, const T* scale, const T* zero,\n                            T* out) {\n  if (inner_size == 1) {\n    DispatchDequantizeInnerSize1PackSize<T, U, num_bits, symmetric>(stream, outer_size, group_size,\n                                                                    in, scale, zero, out);\n  } else {\n    DispatchDequantize3D<T, U, num_bits, symmetric>(stream, outer_size, group_size, inner_size, in,\n                                                    scale, zero, out);\n  }\n}\n\ntemplate<typename T, typename U>\nvoid DispatchDequantize(ep::CudaStream* stream, int32_t num_bits, bool symmetric,\n                        int64_t outer_size, int64_t group_size, int64_t inner_size, const U* in,\n                        const T* scale, const T* zero, T* out) {\n  if (num_bits == 4) {\n    if (symmetric) {\n      DispatchDequantizeSize<T, U, 4, true>(stream, outer_size, group_size, inner_size, in, scale,\n                                            zero, out);\n    } else {\n      DispatchDequantizeSize<T, U, 4, false>(stream, outer_size, group_size, inner_size, in, scale,\n                                             zero, out);\n    }\n  } else if (num_bits == 8) {\n    if (symmetric) {\n      DispatchDequantizeSize<T, U, 8, true>(stream, outer_size, group_size, inner_size, in, scale,\n                                            zero, out);\n    } else {\n      DispatchDequantizeSize<T, U, 8, false>(stream, outer_size, group_size, inner_size, in, scale,\n            
                                 zero, out);\n    }\n\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T>\nclass GroupwiseDequantizeKernel final : public user_op::OpKernel {\n public:\n  GroupwiseDequantizeKernel() = default;\n  ~GroupwiseDequantizeKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex(\"scale\", 0);\n    const user_op::Tensor* zero = nullptr;\n    if (ctx->has_input(\"zero\", 0)) { zero = ctx->Tensor4ArgNameAndIndex(\"zero\", 0); }\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t group_size = ctx->Attr<int64_t>(\"group_size\");\n    const int64_t group_dim = ctx->Attr<int64_t>(\"group_dim\");\n    const int32_t num_bits = ctx->Attr<int32_t>(\"num_bits\");\n    const bool symmetric = ctx->Attr<bool>(\"symmetric\");\n    const int64_t num_in_axes = in->shape_view().NumAxes();\n    CHECK_GE(num_in_axes, 1);\n    CHECK_EQ(scale->shape_view().NumAxes(), num_in_axes);\n    if (zero != nullptr) { CHECK_EQ(zero->shape_view().NumAxes(), num_in_axes); }\n    CHECK_EQ(out->shape_view().NumAxes(), num_in_axes);\n    CHECK_GE(group_dim, 0);\n    CHECK_LT(group_dim, num_in_axes);\n    for (int i = 0; i < num_in_axes; ++i) {\n      if (i == num_in_axes - 1) {\n        CHECK_EQ(out->shape_view().At(i), in->shape_view().At(i) * (8 / num_bits));\n      } else {\n        CHECK_EQ(out->shape_view().At(i), in->shape_view().At(i));\n      }\n    }\n    const int64_t group_dim_size = out->shape_view().At(group_dim);\n    CHECK_GT(group_size, 0);\n    CHECK_LE(group_size, group_dim_size);\n    CHECK_EQ(group_dim_size % group_size, 0);\n    const int64_t num_groups = group_dim_size / group_size;\n    for (int i = 0; i < num_in_axes; ++i) {\n      const int64_t expected_dim_size = i == group_dim ? num_groups : out->shape_view().At(i);\n      CHECK_EQ(scale->shape_view().At(i), expected_dim_size);\n      if (zero != nullptr) { CHECK_EQ(zero->shape_view().At(i), expected_dim_size); }\n    }\n    const int64_t outer_size = out->shape_view().Count(0, group_dim) * num_groups;\n    const int64_t inner_size = out->shape_view().Count(group_dim + 1);\n    if (in->data_type() == DataType::kUInt8) {\n      DispatchDequantize<T, uint8_t>(ctx->stream()->As<ep::CudaStream>(), num_bits, symmetric,\n                                     outer_size, group_size, inner_size, in->dptr<uint8_t>(),\n                                     scale->dptr<T>(), zero == nullptr ? nullptr : zero->dptr<T>(),\n                                     out->mut_dptr<T>());\n    } else if (in->data_type() == DataType::kInt8) {\n      DispatchDequantize<T, int8_t>(ctx->stream()->As<ep::CudaStream>(), num_bits, symmetric,\n                                    outer_size, group_size, inner_size, in->dptr<int8_t>(),\n                                    scale->dptr<T>(), zero == nullptr ? 
nullptr : zero->dptr<T>(),\n                                    out->mut_dptr<T>());\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_VECTOR_WISE_SYMMETRIC_DEQUANTIZE_KERNEL(dtype)        \\\n  REGISTER_USER_KERNEL(\"groupwise_dequantize\")                         \\\n      .SetCreateFn<GroupwiseDequantizeKernel<dtype>>()                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"scale\", 0) == GetDataType<dtype>::value))\n\nREGISTER_VECTOR_WISE_SYMMETRIC_DEQUANTIZE_KERNEL(half);\nREGISTER_VECTOR_WISE_SYMMETRIC_DEQUANTIZE_KERNEL(float);\n\ntemplate<typename T, typename C, typename U, int block_size, size_t d_pack_size, size_t q_pack_size,\n         int bits, bool symmetric, bool single_group>\n__global__ void QuantizedMatmulBiasGroupN(int32_t M, int32_t N, int32_t K, int32_t group_size,\n                                          const AlignedArray<T, d_pack_size>* __restrict__ x,\n                                          const AlignedArray<U, q_pack_size>* __restrict__ w,\n                                          const AlignedArray<T, d_pack_size>* __restrict__ scale,\n                                          const AlignedArray<T, d_pack_size>* __restrict__ zero,\n                                          const T* __restrict__ bias, T* __restrict__ out) {\n  for (int32_t m = blockIdx.x; m < M; m += gridDim.x) {\n    const auto* x_m = x + m * K;\n    for (int32_t n = blockIdx.y; n < N; n += gridDim.y) {\n      C t_sum = 0;\n      const auto* w_n = w + n * K;\n      const int64_t group_id = single_group ? 0 : n / group_size;\n      const auto* scale_n = scale + group_id * K;\n      const auto* zero_n = symmetric ? 
nullptr : zero + group_id * K;\n      for (int32_t k = threadIdx.x; k < K; k += block_size) {\n        auto xs = x_m[k];\n        auto ws = w_n[k];\n        auto scale_k = scale_n[k];\n        AlignedArray<T, d_pack_size> zero_k;\n        if (symmetric) {\n          if (std::is_same<U, uint8_t>::value) {\n            zero_k = scale_k;\n            InplaceMulScalar<T, d_pack_size>()(&zero_k, -static_cast<T>(((1 << (bits - 1)) - 1)));\n          } else {\n            for (int i = 0; i < d_pack_size; ++i) { zero_k.elem[i] = 0; }\n          }\n        } else {\n          zero_k = zero_n[k];\n        }\n        AlignedArray<T, d_pack_size> weights;\n        Cast<U, T, q_pack_size>()(ws, &weights);\n        InplaceFma<T, d_pack_size>()(&weights, scale_k, zero_k);\n        MultiplyAccumulate<T, C, d_pack_size>()(xs, weights, &t_sum);\n      }\n      using BlockReduce = cub::BlockReduce<C, block_size>;\n      __shared__ typename BlockReduce::TempStorage temp_storage;\n      C sum = BlockReduce(temp_storage).Sum(t_sum);\n      if (threadIdx.x == 0) {\n        if (bias != nullptr) { sum += static_cast<C>(bias[n]); }\n        out[m * N + n] = static_cast<T>(sum);\n      }\n      __syncthreads();\n    }\n  }\n}\n\ntemplate<typename T, typename C, typename U, int num_bits, bool symmetric, size_t d_pack_size,\n         size_t q_pack_size, bool single_group>\nvoid LaunchMatmulBiasGroupN(ep::CudaStream* stream, int64_t m, int64_t n, int64_t k,\n                            int64_t group_size, const T* x, const U* w, const T* scale,\n                            const T* zero, const T* bias, T* out) {\n  constexpr uint32_t max_grid_size = 8192;\n  constexpr uint32_t block_size = 128;\n  const int64_t int32_max = std::numeric_limits<int32_t>::max();\n  if (m * k > int32_max || n * k > int32_max || m * n > int32_max || m > int32_max - max_grid_size\n      || n > int32_max - max_grid_size || k > int32_max - block_size) {\n    UNIMPLEMENTED();\n  }\n  if constexpr (sizeof(T) * d_pack_size <= 16 && q_pack_size > 0) {\n    QuantizedMatmulBiasGroupN<T, C, U, block_size, d_pack_size, q_pack_size, num_bits, symmetric,\n                              single_group>\n        <<<dim3(std::min<int64_t>(m, max_grid_size), std::min<int64_t>(n, max_grid_size)),\n           block_size, 0, stream->cuda_stream()>>>(\n            m, n, k / d_pack_size, group_size,\n            reinterpret_cast<const AlignedArray<T, d_pack_size>*>(x),\n            reinterpret_cast<const AlignedArray<U, q_pack_size>*>(w),\n            reinterpret_cast<const AlignedArray<T, d_pack_size>*>(scale),\n            reinterpret_cast<const AlignedArray<T, d_pack_size>*>(zero), bias, out);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T, typename C, typename U, int num_bits, bool symmetric, size_t d_pack_size,\n         size_t q_pack_size>\nvoid DispatchMatmulBiasGroupNSingleGroup(ep::CudaStream* stream, int64_t m, int64_t n, int64_t k,\n                                         int64_t group_size, const T* x, const U* w, const T* scale,\n                                         const T* zero, const T* bias, T* out) {\n  if (n == group_size) {\n    LaunchMatmulBiasGroupN<T, C, U, num_bits, symmetric, d_pack_size, q_pack_size, true>(\n        stream, m, n, k, group_size, x, w, scale, zero, bias, out);\n  } else {\n    LaunchMatmulBiasGroupN<T, C, U, num_bits, symmetric, d_pack_size, q_pack_size, false>(\n        stream, m, n, k, group_size, x, w, scale, zero, bias, out);\n  }\n}\n\ntemplate<typename T, typename C, typename U, int num_bits, 
bool symmetric>\nvoid DispatchMatmulBiasGroupNPackSize(ep::CudaStream* stream, int64_t m, int64_t n, int64_t k,\n                                      int64_t group_size, const T* x, const U* w, const T* scale,\n                                      const T* zero, const T* bias, T* out) {\n  const int max_pack_size = 16 / sizeof(T);\n  int pack_size = max_pack_size;\n  while (k % pack_size != 0) { pack_size /= 2; }\n  constexpr int32_t data_per_quant = 8 / num_bits;\n  if (pack_size == 16) {\n    DispatchMatmulBiasGroupNSingleGroup<T, C, U, num_bits, symmetric, 16, 16 / data_per_quant>(\n        stream, m, n, k, group_size, x, w, scale, zero, bias, out);\n  } else if (pack_size == 8) {\n    DispatchMatmulBiasGroupNSingleGroup<T, C, U, num_bits, symmetric, 8, 8 / data_per_quant>(\n        stream, m, n, k, group_size, x, w, scale, zero, bias, out);\n  } else if (pack_size == 4) {\n    DispatchMatmulBiasGroupNSingleGroup<T, C, U, num_bits, symmetric, 4, 4 / data_per_quant>(\n        stream, m, n, k, group_size, x, w, scale, zero, bias, out);\n  } else if (pack_size == 2) {\n    DispatchMatmulBiasGroupNSingleGroup<T, C, U, num_bits, symmetric, 2, 2 / data_per_quant>(\n        stream, m, n, k, group_size, x, w, scale, zero, bias, out);\n  } else if (pack_size == 1) {\n    DispatchMatmulBiasGroupNSingleGroup<T, C, U, num_bits, symmetric, 1, 1 / data_per_quant>(\n        stream, m, n, k, group_size, x, w, scale, zero, bias, out);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T, typename C, typename U, int block_size, size_t d_pack_size, size_t q_pack_size,\n         int bits, bool symmetric, bool single_group>\n__global__ void QuantizedMatmulBiasGroupK(int32_t M, int32_t N, int32_t K, int32_t group_size,\n                                          int32_t num_groups_per_n,\n                                          const AlignedArray<T, d_pack_size>* __restrict__ x,\n                                          const AlignedArray<U, q_pack_size>* __restrict__ w,\n                                          const T* __restrict__ scale, const T* __restrict__ zero,\n                                          const T* __restrict__ bias, T* __restrict__ out) {\n  for (int32_t m = blockIdx.x; m < M; m += gridDim.x) {\n    const auto* x_m = x + m * K;\n    for (int32_t n = blockIdx.y; n < N; n += gridDim.y) {\n      C t_sum = 0;\n      const auto* w_n = w + n * K;\n      const auto* scale_n = scale + n * num_groups_per_n;\n      const T* zero_n = symmetric ? 
nullptr : zero + n * num_groups_per_n;\n      T group_scale;\n      T group_zero;\n      if (single_group) {\n        group_scale = static_cast<T>(scale_n[0]);\n        if (symmetric) {\n          if (std::is_same<U, uint8_t>::value) {\n            group_zero = -static_cast<T>(((1 << (bits - 1)) - 1)) * group_scale;\n          } else {\n            group_zero = 0;\n          }\n        } else {\n          group_zero = zero_n[0];\n        }\n      }\n      for (int32_t k = threadIdx.x; k < K; k += block_size) {\n        if (!single_group) {\n          auto group_id = k / group_size;\n          group_scale = static_cast<T>(scale_n[group_id]);\n          if (symmetric) {\n            if (std::is_same<U, uint8_t>::value) {\n              group_zero = -static_cast<T>(((1 << (bits - 1)) - 1)) * group_scale;\n            } else {\n              group_zero = 0;\n            }\n          } else {\n            group_zero = zero_n[group_id];\n          }\n        }\n        auto xs = x_m[k];\n        auto ws = w_n[k];\n        AlignedArray<T, d_pack_size> weights;\n        Cast<U, T, q_pack_size>()(ws, &weights);\n        InplaceFmaScalar<T, d_pack_size>()(&weights, group_scale, group_zero);\n        MultiplyAccumulate<T, C, d_pack_size>()(xs, weights, &t_sum);\n      }\n      using BlockReduce = cub::BlockReduce<C, block_size>;\n      __shared__ typename BlockReduce::TempStorage temp_storage;\n      C sum = BlockReduce(temp_storage).Sum(t_sum);\n      if (threadIdx.x == 0) {\n        if (bias != nullptr) { sum += static_cast<C>(bias[n]); }\n        out[m * N + n] = static_cast<T>(sum);\n      }\n      __syncthreads();\n    }\n  }\n}\n\ntemplate<typename T, typename C, typename U, int num_bits, bool symmetric, size_t d_pack_size,\n         size_t q_pack_size, bool single_group>\nvoid LaunchMatmulBiasGroupK(ep::CudaStream* stream, int64_t m, int64_t n, int64_t k,\n                            int64_t group_size, const T* x, const U* w, const T* scale,\n                            const T* zero, const T* bias, T* out) {\n  constexpr uint32_t max_grid_size = 8192;\n  constexpr uint32_t block_size = 128;\n  const int64_t int32_max = std::numeric_limits<int32_t>::max();\n  if (m * k > int32_max || n * k > int32_max || m * n > int32_max || m > int32_max - max_grid_size\n      || n > int32_max - max_grid_size || k > int32_max - block_size) {\n    UNIMPLEMENTED();\n  }\n  if constexpr (sizeof(T) * d_pack_size <= 16 && q_pack_size > 0) {\n    QuantizedMatmulBiasGroupK<T, C, U, block_size, d_pack_size, q_pack_size, num_bits, symmetric,\n                              single_group>\n        <<<dim3(std::min<int64_t>(m, max_grid_size), std::min<int64_t>(n, max_grid_size)),\n           block_size, 0, stream->cuda_stream()>>>(\n            m, n, k / d_pack_size, group_size / d_pack_size, k / group_size,\n            reinterpret_cast<const AlignedArray<T, d_pack_size>*>(x),\n            reinterpret_cast<const AlignedArray<U, q_pack_size>*>(w), scale, zero, bias, out);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T, typename C, typename U, int num_bits, bool symmetric, size_t d_pack_size,\n         size_t q_pack_size>\nvoid DispatchMatmulBiasGroupKSingleGroup(ep::CudaStream* stream, int64_t m, int64_t n, int64_t k,\n                                         int64_t group_size, const T* x, const U* w, const T* scale,\n                                         const T* zero, const T* bias, T* out) {\n  if (k == group_size) {\n    LaunchMatmulBiasGroupK<T, C, U, num_bits, symmetric, d_pack_size, 
q_pack_size, true>(\n        stream, m, n, k, group_size, x, w, scale, zero, bias, out);\n  } else {\n    LaunchMatmulBiasGroupK<T, C, U, num_bits, symmetric, d_pack_size, q_pack_size, false>(\n        stream, m, n, k, group_size, x, w, scale, zero, bias, out);\n  }\n}\n\ntemplate<typename T, typename C, typename U, int num_bits, bool symmetric>\nvoid DispatchMatmulBiasGroupKPackSize(ep::CudaStream* stream, int64_t m, int64_t n, int64_t k,\n                                      int64_t group_size, const T* x, const U* w, const T* scale,\n                                      const T* zero, const T* bias, T* out) {\n  const int max_pack_size = 16 / sizeof(T);\n  int pack_size = max_pack_size;\n  while (group_size % pack_size != 0) { pack_size /= 2; }\n  constexpr int32_t data_per_quant = 8 / num_bits;\n  if (pack_size == 16) {\n    DispatchMatmulBiasGroupKSingleGroup<T, C, U, num_bits, symmetric, 16, 16 / data_per_quant>(\n        stream, m, n, k, group_size, x, w, scale, zero, bias, out);\n  } else if (pack_size == 8) {\n    DispatchMatmulBiasGroupKSingleGroup<T, C, U, num_bits, symmetric, 8, 8 / data_per_quant>(\n        stream, m, n, k, group_size, x, w, scale, zero, bias, out);\n  } else if (pack_size == 4) {\n    DispatchMatmulBiasGroupKSingleGroup<T, C, U, num_bits, symmetric, 4, 4 / data_per_quant>(\n        stream, m, n, k, group_size, x, w, scale, zero, bias, out);\n  } else if (pack_size == 2) {\n    DispatchMatmulBiasGroupKSingleGroup<T, C, U, num_bits, symmetric, 2, 2 / data_per_quant>(\n        stream, m, n, k, group_size, x, w, scale, zero, bias, out);\n  } else if (pack_size == 1) {\n    DispatchMatmulBiasGroupKSingleGroup<T, C, U, num_bits, symmetric, 1, 1 / data_per_quant>(\n        stream, m, n, k, group_size, x, w, scale, zero, bias, out);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T, typename C, typename U, int num_bits, bool symmetric>\nvoid DispatchMatmulBiasGroupDim(ep::CudaStream* stream, int64_t m, int64_t n, int64_t k,\n                                int64_t group_dim, int64_t group_size, const T* x, const U* w,\n                                const T* scale, const T* zero, const T* bias, T* out) {\n  if (group_dim == 0) {\n    DispatchMatmulBiasGroupNPackSize<T, C, U, num_bits, symmetric>(stream, m, n, k, group_size, x,\n                                                                   w, scale, zero, bias, out);\n  } else if (group_dim == 1) {\n    DispatchMatmulBiasGroupKPackSize<T, C, U, num_bits, symmetric>(stream, m, n, k, group_size, x,\n                                                                   w, scale, zero, bias, out);\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T, typename C, typename U>\nvoid DispatchMatmulBias(ep::CudaStream* stream, int num_bits, bool symmetric, int64_t m, int64_t n,\n                        int64_t k, int64_t group_dim, int64_t group_size, const T* x, const U* w,\n                        const T* scale, const T* zero, const T* bias, T* out) {\n  if (num_bits == 4) {\n    if (symmetric) {\n      DispatchMatmulBiasGroupDim<T, C, U, 4, true>(stream, m, n, k, group_dim, group_size, x, w,\n                                                   scale, zero, bias, out);\n    } else {\n      DispatchMatmulBiasGroupDim<T, C, U, 4, false>(stream, m, n, k, group_dim, group_size, x, w,\n                                                    scale, zero, bias, out);\n    }\n  } else if (num_bits == 8) {\n    if (symmetric) {\n      DispatchMatmulBiasGroupDim<T, C, U, 8, true>(stream, m, n, k, group_dim, 
group_size, x, w,\n                                                   scale, zero, bias, out);\n    } else {\n      DispatchMatmulBiasGroupDim<T, C, U, 8, false>(stream, m, n, k, group_dim, group_size, x, w,\n                                                    scale, zero, bias, out);\n    }\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\ntemplate<typename T>\nclass FusedLinearWithGroupwiseQuantizedWeightKernel final : public user_op::OpKernel,\n                                                            public user_op::CudaGraphSupport {\n public:\n  FusedLinearWithGroupwiseQuantizedWeightKernel() = default;\n  ~FusedLinearWithGroupwiseQuantizedWeightKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* w = ctx->Tensor4ArgNameAndIndex(\"w\", 0);\n    const user_op::Tensor* w_scale = ctx->Tensor4ArgNameAndIndex(\"w_scale\", 0);\n    const user_op::Tensor* b =\n        (ctx->has_input(\"b\", 0)) ? ctx->Tensor4ArgNameAndIndex(\"b\", 0) : nullptr;\n    const user_op::Tensor* w_zero =\n        (ctx->has_input(\"w_zero\", 0)) ? ctx->Tensor4ArgNameAndIndex(\"w_zero\", 0) : nullptr;\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const DataType data_type = x->data_type();\n    CHECK_EQ(w_scale->data_type(), data_type);\n    CHECK_EQ(out->data_type(), data_type);\n    const int64_t group_size = ctx->Attr<int64_t>(\"group_size\");\n    const int64_t group_dim = ctx->Attr<int64_t>(\"group_dim\");\n    CHECK(group_dim == 0 || group_dim == 1);\n    const int32_t num_bits = ctx->Attr<int32_t>(\"num_bits\");\n    const bool symmetric = ctx->Attr<bool>(\"symmetric\");\n    CHECK_GE(x->shape_view().NumAxes(), 2);\n    const int64_t k = x->shape_view().At(x->shape_view().NumAxes() - 1);\n    const int64_t m = x->shape_view().elem_cnt() / k;\n    CHECK_EQ(w->shape_view().NumAxes(), 2);\n    if (num_bits == 4) {\n      CHECK_EQ(w->shape_view().At(1) * 2, k);\n    } else if (num_bits == 8) {\n      CHECK_EQ(w->shape_view().At(1), k);\n    } else {\n      UNIMPLEMENTED();\n    }\n    const int64_t n = w->shape_view().At(0);\n    const int64_t group_dim_size = group_dim == 0 ? 
n : k;\n    CHECK_GT(group_size, 0);\n    CHECK_LE(group_size, group_dim_size);\n    CHECK_EQ(group_dim_size % group_size, 0);\n    const int64_t num_groups = group_dim_size / group_size;\n    if (group_dim == 0) {\n      CHECK_EQ(w_scale->shape_view().At(0), num_groups);\n      CHECK_EQ(w_scale->shape_view().At(1), k);\n    } else if (group_dim == 1) {\n      CHECK_EQ(w_scale->shape_view().At(0), n);\n      CHECK_EQ(w_scale->shape_view().At(1), num_groups);\n    } else {\n      UNIMPLEMENTED();\n    }\n    if (w_zero != nullptr) {\n      CHECK_EQ(w_zero->data_type(), data_type);\n      CHECK(w_zero->shape_view() == w_scale->shape_view());\n    }\n    if (b != nullptr) {\n      CHECK_EQ(b->data_type(), data_type);\n      CHECK_EQ(b->shape_view().NumAxes(), 1);\n      CHECK_EQ(b->shape_view().At(0), n);\n    }\n    CHECK_EQ(x->shape_view().NumAxes(), out->shape_view().NumAxes());\n    for (int i = 0; i < x->shape_view().NumAxes() - 1; ++i) {\n      CHECK_EQ(out->shape_view().At(i), x->shape_view().At(i));\n    }\n    CHECK_EQ(out->shape_view().At(out->shape_view().NumAxes() - 1), n);\n    if (symmetric) {\n      CHECK(w_zero == nullptr);\n    } else {\n      CHECK(w_zero != nullptr);\n    }\n    const DataType quant_type = w->data_type();\n    if (quant_type == DataType::kUInt8) {\n      DispatchMatmulBias<T, float, uint8_t>(\n          ctx->stream()->As<ep::CudaStream>(), num_bits, symmetric, m, n, k, group_dim, group_size,\n          x->dptr<T>(), w->dptr<uint8_t>(), w_scale->dptr<T>(),\n          w_zero == nullptr ? nullptr : w_zero->dptr<T>(), b == nullptr ? nullptr : b->dptr<T>(),\n          out->mut_dptr<T>());\n    } else if (quant_type == DataType::kInt8) {\n      DispatchMatmulBias<T, float, int8_t>(\n          ctx->stream()->As<ep::CudaStream>(), num_bits, symmetric, m, n, k, group_dim, group_size,\n          x->dptr<T>(), w->dptr<int8_t>(), w_scale->dptr<T>(),\n          w_zero == nullptr ? nullptr : w_zero->dptr<T>(), b == nullptr ? nullptr : b->dptr<T>(),\n          out->mut_dptr<T>());\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_FUSED_MATMUL_BIAS_KERNEL_GPU(data_type, cpp_type)            \\\n  REGISTER_USER_KERNEL(\"fused_linear_with_groupwise_quantized_weight\")        \\\n      .SetCreateFn<FusedLinearWithGroupwiseQuantizedWeightKernel<cpp_type>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)        \\\n                       && (user_op::HobDataType(\"out\", 0) == data_type));\n\nREGISTER_FUSED_MATMUL_BIAS_KERNEL_GPU(DataType::kFloat, float);\nREGISTER_FUSED_MATMUL_BIAS_KERNEL_GPU(DataType::kFloat16, half);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/host_scalar_add_by_tensor_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/user/kernels/radix_sort.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__global__ void ScalarAdd(int32_t elem_cnt, const T* in, const T scalar, T* out) {\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) { out[i] = in[i] + scalar; };\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass HostScalarAddByTensorKernel final : public user_op::OpKernel {\n public:\n  HostScalarAddByTensorKernel() = default;\n  ~HostScalarAddByTensorKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* scalar = ctx->Tensor4ArgNameAndIndex(\"scalar\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n\n    const int32_t elem_cnt = x->shape_view().elem_cnt();\n\n    CHECK_EQ(scalar->shape_view().elem_cnt(), 1);\n\n    // val of scalar can be visited because it is host input.\n    const T scalar_val = *scalar->dptr<T>();\n\n    ScalarAdd<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(elem_cnt, x->dptr<T>(),\n                                                                      scalar_val, y->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_ARG_SORT_KERNEL(dtype)                                               \\\n  REGISTER_USER_KERNEL(\"host_scalar_add_by_tensor\")                                        \\\n      .SetCreateFn<HostScalarAddByTensorKernel<dtype>>()                                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                     \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)      \\\n                       && (user_op::HobDataType(\"scalar\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CUDA_ARG_SORT_KERNEL(float)\nREGISTER_CUDA_ARG_SORT_KERNEL(double)\nREGISTER_CUDA_ARG_SORT_KERNEL(bool)\nREGISTER_CUDA_ARG_SORT_KERNEL(int8_t)\nREGISTER_CUDA_ARG_SORT_KERNEL(uint8_t)\nREGISTER_CUDA_ARG_SORT_KERNEL(int32_t)\nREGISTER_CUDA_ARG_SORT_KERNEL(int64_t)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/image_batch_align_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/user/image/image_util.h\"\n#include <opencv2/opencv.hpp>\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename F>\nvoid CopyFromTensorBuffer(T* image_ptr, const TensorBuffer& image_buffer, const int batch_height,\n                          const int batch_width, const int channels) {\n  CHECK_EQ(image_buffer.shape_view().NumAxes(), 3);\n  const int h = image_buffer.shape_view().At(0);\n  const int w = image_buffer.shape_view().At(1);\n  const int c = image_buffer.shape_view().At(2);\n  CHECK_LE(h, batch_height);\n  CHECK_LE(w, batch_width);\n  CHECK_EQ(c, channels);\n  FOR_RANGE(int, i, 0, h) {\n    const F* from = image_buffer.data<F>() + i * w * c;\n    T* to = image_ptr + i * batch_width * channels;\n    std::transform(from, from + w * c, to, [](F v) { return static_cast<T>(v); });\n  }\n}\n\ntemplate<typename T>\nstruct ImageCopier final {\n#define MAKE_COPY_FROM_TENSOR_BUFFER_SWITCH_ENTRY(func_name, F) func_name<T, F>\n  DEFINE_STATIC_SWITCH_FUNC(void, CopyFromTensorBuffer, MAKE_COPY_FROM_TENSOR_BUFFER_SWITCH_ENTRY,\n                            MAKE_DATA_TYPE_CTRV_SEQ(IMAGE_DATA_TYPE_SEQ))\n#undef MAKE_COPY_FROM_TENSOR_BUFFER_SWITCH_ENTRY\n};\n\n}  // namespace\n\ntemplate<typename T>\nclass ImageBatchAlignKernel final : public user_op::OpKernel {\n public:\n  ImageBatchAlignKernel() = default;\n  ~ImageBatchAlignKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(in_tensor->shape_view().NumAxes(), 1);\n    CHECK_EQ(out_tensor->shape_view().NumAxes(), 4);\n    const int64_t num_images = in_tensor->shape_view().elem_cnt();\n    const bool dynamic_out = ctx->Attr<bool>(\"dynamic_out\");\n    CHECK_GT(num_images, 0);\n    int64_t max_height = 0;\n    int64_t max_width = 0;\n    const int64_t channels = out_tensor->shape_view().At(3);\n    FOR_RANGE(int, i, 0, num_images) {\n      const TensorBuffer& image_buffer = in_tensor->dptr<TensorBuffer>()[i];\n      max_height = std::max(max_height, image_buffer.shape_view().At(0));\n      max_width = std::max(max_width, image_buffer.shape_view().At(1));\n      CHECK_EQ(image_buffer.shape_view().At(2), channels);\n    }\n    int32_t alignment = ctx->Attr<int32_t>(\"alignment\");\n    max_height = RoundUp(max_height, alignment);\n    max_width = RoundUp(max_width, alignment);\n\n    if (dynamic_out) {\n      auto mut_shape_view = out_tensor->mut_shape_view();\n      mut_shape_view.Set(0, num_images);\n      mut_shape_view.Set(1, max_height);\n      mut_shape_view.Set(2, max_width);\n    }\n\n    memset(out_tensor->mut_dptr(), 0,\n           out_tensor->shape_view().elem_cnt() * 
GetSizeOfDataType(out_tensor->data_type()));\n    MultiThreadLoop(num_images, [&](size_t i) {\n      const TensorBuffer& image_buffer = in_tensor->dptr<TensorBuffer>()[i];\n      T* out_ptr = out_tensor->mut_dptr<T>() + i * max_height * max_width * channels;\n      ImageCopier<T>::SwitchCopyFromTensorBuffer(SwitchCase(image_buffer.data_type()), out_ptr,\n                                                 image_buffer, max_height, max_width, channels);\n    });\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_IMAGE_BATCH_ALIGN_KERNEL(dtype)                                     \\\n  REGISTER_USER_KERNEL(\"image_batch_align\")                                          \\\n      .SetCreateFn<ImageBatchAlignKernel<dtype>>()                                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                \\\n                       && (user_op::HobDataType(\"in\", 0) == DataType::kTensorBuffer) \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\nREGISTER_IMAGE_BATCH_ALIGN_KERNEL(uint8_t)\nREGISTER_IMAGE_BATCH_ALIGN_KERNEL(float)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/image_decode_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/user/image/image_util.h\"\n#include <opencv2/opencv.hpp>\n\nnamespace oneflow {\n\nnamespace {\n\nvoid DecodeImage(const TensorBuffer& raw_bytes, TensorBuffer* image_buffer,\n                 const std::string& color_space, DataType data_type) {\n  // should only support kChar, but numpy ndarray maybe cannot convert to char*\n  CHECK(raw_bytes.data_type() == DataType::kChar || raw_bytes.data_type() == DataType::kInt8\n        || raw_bytes.data_type() == DataType::kUInt8);\n  cv::_InputArray raw_bytes_arr(raw_bytes.data<char>(), raw_bytes.elem_cnt());\n  cv::Mat image_mat = cv::imdecode(\n      raw_bytes_arr, (ImageUtil::IsColor(color_space) ? cv::IMREAD_COLOR : cv::IMREAD_GRAYSCALE)\n                         | cv::IMREAD_ANYDEPTH);\n  if (ImageUtil::IsColor(color_space) && color_space != \"BGR\") {\n    ImageUtil::ConvertColor(\"BGR\", image_mat, color_space, image_mat);\n  }\n  if (data_type == DataType::kUInt8) {\n    image_mat.convertTo(image_mat, CV_8U);\n  } else if (data_type == DataType::kFloat) {\n    image_mat.convertTo(image_mat, CV_32F);\n  } else {\n    UNIMPLEMENTED();\n  }\n\n  int64_t h = image_mat.rows;\n  int64_t w = image_mat.cols;\n  int64_t c = image_mat.channels();\n  image_buffer->Resize(Shape({h, w, c}), data_type);\n\n  w *= c;\n  if (image_mat.isContinuous()) {\n    w *= h;\n    h = 1;\n  }\n  char* image_ptr = image_buffer->mut_data<char>();\n  FOR_RANGE(int64_t, i, 0, h) {\n    memcpy(image_ptr + i * w, image_mat.ptr(i), w * GetSizeOfDataType(data_type));\n  }\n}\n\n}  // namespace\n\nclass ImageDecodeKernel final : public user_op::OpKernel {\n public:\n  ImageDecodeKernel() = default;\n  ~ImageDecodeKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(in_tensor->shape_view().elem_cnt(), out_tensor->shape_view().elem_cnt());\n    CHECK_GT(in_tensor->shape_view().elem_cnt(), 0);\n\n    const TensorBuffer* in_img_buf = in_tensor->dptr<TensorBuffer>();\n    TensorBuffer* out_img_buf = out_tensor->mut_dptr<TensorBuffer>();\n    const std::string& color_space = ctx->Attr<std::string>(\"color_space\");\n    const DataType data_type = ctx->Attr<DataType>(\"data_type\");\n\n    MultiThreadLoop(in_tensor->shape_view().elem_cnt(), [&](size_t i) {\n      DecodeImage(in_img_buf[i], out_img_buf + i, color_space, data_type);\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"image_decode\")\n    .SetCreateFn<ImageDecodeKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"in\", 0) == 
DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kTensorBuffer));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/image_object_preprocess_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/user/image/image_util.h\"\n#include <opencv2/opencv.hpp>\n#include <cfenv>\n\nnamespace oneflow {\n\nnamespace {\n\nenum class FlipCode : int8_t {\n  kNonFlip = 0x00,\n  kHorizontalFlip = 0x01,\n  kVerticalFlip = 0x02,\n  kBothDirectionFlip = 0x03,\n};\n\nbool operator&(FlipCode lhs, FlipCode rhs) {\n  return static_cast<bool>(static_cast<std::underlying_type<FlipCode>::type>(lhs)\n                           & static_cast<std::underlying_type<FlipCode>::type>(rhs));\n}\n\nint CvFlipCode(FlipCode flip_code) {\n  if (flip_code == FlipCode::kHorizontalFlip) {\n    return 1;\n  } else if (flip_code == FlipCode::kVerticalFlip) {\n    return 0;\n  } else if (flip_code == FlipCode::kBothDirectionFlip) {\n    return -1;\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nvoid FlipImage(TensorBuffer* image_buffer, FlipCode flip_code) {\n  cv::Mat image_mat = GenCvMat4ImageBuffer(*image_buffer);\n  cv::flip(image_mat, image_mat, CvFlipCode(flip_code));\n}\n\ntemplate<typename T>\nvoid FlipBoxes(TensorBuffer* boxes_buffer, int32_t image_width, int32_t image_height,\n               FlipCode flip_code) {\n  int num_boxes = boxes_buffer->shape_view().At(0);\n  FOR_RANGE(int, i, 0, num_boxes) {\n    T* cur_box_ptr = boxes_buffer->mut_data<T>() + i * 4;\n    if (flip_code & FlipCode::kHorizontalFlip) {\n      T xmin = cur_box_ptr[0];\n      T xmax = cur_box_ptr[2];\n      cur_box_ptr[0] = image_width - xmax - static_cast<T>(1);\n      cur_box_ptr[2] = image_width - xmin - static_cast<T>(1);\n    }\n    if (flip_code & FlipCode::kVerticalFlip) {\n      T ymin = cur_box_ptr[1];\n      T ymax = cur_box_ptr[3];\n      cur_box_ptr[1] = image_height - ymax - static_cast<T>(1);\n      cur_box_ptr[3] = image_height - ymin - static_cast<T>(1);\n    }\n  }\n}\n\n#define MAKE_FLIP_BOXES_SWITCH_ENTRY(func_name, T) func_name<T>\nDEFINE_STATIC_SWITCH_FUNC(void, FlipBoxes, MAKE_FLIP_BOXES_SWITCH_ENTRY,\n                          MAKE_DATA_TYPE_CTRV_SEQ(FLOATING_DATA_TYPE_SEQ));\n\n#undef MAKE_FLIP_BOXES_SWITCH_ENTRY\n\ntemplate<typename T>\nvoid ScaleBoxes(TensorBuffer* boxes_buffer, T scale_w, T scale_h) {\n  int num_boxes = boxes_buffer->shape_view().At(0);\n  FOR_RANGE(int, i, 0, num_boxes) {\n    T* cur_box_ptr = boxes_buffer->mut_data<T>() + i * 4;\n    cur_box_ptr[0] *= scale_w;\n    cur_box_ptr[1] *= scale_h;\n    cur_box_ptr[2] *= scale_w;\n    cur_box_ptr[3] *= scale_h;\n  }\n}\n\n#define MAKE_SCALE_BOXES_SWITCH_ENTRY(func_name, T) func_name<T>\nDEFINE_STATIC_SWITCH_FUNC(void, ScaleBoxes, MAKE_SCALE_BOXES_SWITCH_ENTRY,\n                          MAKE_DATA_TYPE_CTRV_SEQ(FLOATING_DATA_TYPE_SEQ));\n\n#undef MAKE_SCALE_BOXES_SWITCH_ENTRY\n\ntemplate<typename T>\nvoid FlipPolygons(TensorBuffer* polygons_buffer, int32_t image_width, int32_t image_height,\n                  FlipCode flip_code) {\n 
 int num_points = polygons_buffer->shape_view().At(0);\n  FOR_RANGE(int, i, 0, num_points) {\n    T* cur_poly_ptr = polygons_buffer->mut_data<T>() + i * 2;\n    if (flip_code & FlipCode::kHorizontalFlip) { cur_poly_ptr[0] = image_width - cur_poly_ptr[0]; }\n    if (flip_code & FlipCode::kVerticalFlip) { cur_poly_ptr[1] = image_height - cur_poly_ptr[1]; }\n  }\n}\n\n#define MAKE_FLIP_POLYGONS_SWITCH_ENTRY(func_name, T) func_name<T>\nDEFINE_STATIC_SWITCH_FUNC(void, FlipPolygons, MAKE_FLIP_POLYGONS_SWITCH_ENTRY,\n                          MAKE_DATA_TYPE_CTRV_SEQ(FLOATING_DATA_TYPE_SEQ));\n\n#undef MAKE_FLIP_POLYGONS_SWITCH_ENTRY\n\ntemplate<typename T>\nvoid ScalePolygons(TensorBuffer* poly_buffer, T scale_w, T scale_h) {\n  int num_pts = poly_buffer->shape_view().At(0);\n  FOR_RANGE(int, i, 0, num_pts) {\n    T* cur_pt = poly_buffer->mut_data<T>() + i * 2;\n    cur_pt[0] *= scale_w;\n    cur_pt[1] *= scale_h;\n  }\n}\n\n#define MAKE_SCALE_POLYGONS_SWITCH_ENTRY(func_name, T) func_name<T>\nDEFINE_STATIC_SWITCH_FUNC(void, ScalePolygons, MAKE_SCALE_POLYGONS_SWITCH_ENTRY,\n                          MAKE_DATA_TYPE_CTRV_SEQ(FLOATING_DATA_TYPE_SEQ));\n\n#undef MAKE_SCALE_POLYGONS_SWITCH_ENTRY\n\ntemplate<typename T>\nvoid ImageNormalizeByChannel(TensorBuffer* image_buffer, const std::vector<float>& std_vec,\n                             const std::vector<float>& mean_vec) {\n  CHECK_EQ(image_buffer->shape_view().NumAxes(), 3);\n  int h = image_buffer->shape_view().At(0);\n  int w = image_buffer->shape_view().At(1);\n  int c = image_buffer->shape_view().At(2);\n  CHECK_EQ(std_vec.size(), c);\n  CHECK_EQ(mean_vec.size(), c);\n  FOR_RANGE(int, i, 0, (h * w)) {\n    T* image_data = image_buffer->mut_data<T>() + i * c;\n    FOR_RANGE(int, j, 0, c) { image_data[j] = (image_data[j] - mean_vec.at(j)) / std_vec.at(j); }\n  }\n}\n\n#define MAKE_IMAGE_NORMALIZE_SWITCH_ENTRY(func_name, T) func_name<T>\nDEFINE_STATIC_SWITCH_FUNC(void, ImageNormalizeByChannel, MAKE_IMAGE_NORMALIZE_SWITCH_ENTRY,\n                          MAKE_DATA_TYPE_CTRV_SEQ(FLOATING_DATA_TYPE_SEQ));\n\n#undef MAKE_IMAGE_NORMALIZE_SWITCH_ENTRY\n\ntemplate<typename T, typename I>\nvoid PolygonsToMask(const TensorBuffer& polys, const TensorBuffer& polys_nd_index,\n                    TensorBuffer* masks, int32_t im_w, int32_t im_h) {\n  CHECK_EQ(polys.shape_view().NumAxes(), 2);\n  CHECK_EQ(polys.shape_view().At(1), 2);\n  CHECK_EQ(polys_nd_index.shape_view().NumAxes(), 2);\n  CHECK_EQ(polys_nd_index.shape_view().At(1), 3);\n  int num_points = polys.shape_view().At(0);\n  CHECK_EQ(polys_nd_index.shape_view().At(0), num_points);\n\n  std::vector<std::vector<cv::Point>> poly_point_vec;\n  std::vector<cv::Mat> mask_mat_vec;\n  auto PolyToMask = [&]() {\n    CHECK_GT(poly_point_vec.size(), 0);\n    CHECK_GT(poly_point_vec.front().size(), 0);\n    cv::Mat mask_mat = cv::Mat(im_h, im_w, CV_8SC1, cv::Scalar(0));\n    cv::fillPoly(mask_mat, poly_point_vec, cv::Scalar(1), cv::LINE_8);\n    mask_mat_vec.emplace_back(std::move(mask_mat));\n    poly_point_vec.clear();\n  };\n\n  int origin_round_way = std::fegetround();\n  CHECK_EQ(std::fesetround(FE_TONEAREST), 0);\n  FOR_RANGE(int, i, 0, num_points) {\n    const I pt_idx = polys_nd_index.data<I>()[i * 3 + 0];\n    const I poly_idx = polys_nd_index.data<I>()[i * 3 + 1];\n    const I segm_idx = polys_nd_index.data<I>()[i * 3 + 2];\n    if (segm_idx != mask_mat_vec.size()) { PolyToMask(); }\n    if (poly_idx == poly_point_vec.size()) {\n      poly_point_vec.emplace_back(std::vector<cv::Point>());\n    }\n    
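// The (pt_idx, poly_idx, segm_idx) triplets are expected to arrive sorted, with points\n    // grouped by polygon and polygons grouped by segment; the checks below verify that\n    // invariant after the flush/new-polygon handling above.\n    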
CHECK_EQ(segm_idx, mask_mat_vec.size());\n    CHECK_EQ(poly_idx, poly_point_vec.size() - 1);\n    CHECK_EQ(pt_idx, poly_point_vec.back().size());\n    const T* pts_ptr = polys.data<T>() + i * 2;\n    cv::Point pt{static_cast<int>(std::nearbyint(pts_ptr[0])),\n                 static_cast<int>(std::nearbyint(pts_ptr[1]))};\n    poly_point_vec.back().emplace_back(std::move(pt));\n  }\n  PolyToMask();\n  CHECK_EQ(std::fesetround(origin_round_way), 0);\n\n  masks->Resize(Shape({static_cast<int64_t>(mask_mat_vec.size()), static_cast<int64_t>(im_h),\n                       static_cast<int64_t>(im_w)}),\n                DataType::kInt8);\n  int mask_idx = 0;\n  for (const auto& mask_mat : mask_mat_vec) {\n    CHECK(mask_mat.isContinuous());\n    CHECK_EQ(mask_mat.total(), im_h * im_w);\n    memcpy(masks->mut_data<int8_t>() + mask_idx * im_h * im_w, mask_mat.ptr<int8_t>(),\n           mask_mat.total() * sizeof(int8_t));\n    mask_idx += 1;\n  }\n}\n\n#define MAKE_POLYGONS_TO_MASK_SWITCH_ENTRY(func_name, T, I) func_name<T, I>\nDEFINE_STATIC_SWITCH_FUNC(void, PolygonsToMask, MAKE_POLYGONS_TO_MASK_SWITCH_ENTRY,\n                          MAKE_DATA_TYPE_CTRV_SEQ(FLOATING_DATA_TYPE_SEQ),\n                          MAKE_DATA_TYPE_CTRV_SEQ(INDEX_DATA_TYPE_SEQ));\n\n#undef MAKE_POLYGONS_TO_MASK_SWITCH_ENTRY\n\n}  // namespace\n\nclass ImageFlipKernel final : public user_op::OpKernel {\n public:\n  ImageFlipKernel() = default;\n  ~ImageFlipKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* flip_code_tensor = ctx->Tensor4ArgNameAndIndex(\"flip_code\", 0);\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    int num_images = in_tensor->shape_view().elem_cnt();\n    CHECK_EQ(out_tensor->shape_view().elem_cnt(), num_images);\n\n    MultiThreadLoop(num_images, [&](size_t i) {\n      const TensorBuffer& in_buffer = in_tensor->dptr<TensorBuffer>()[i];\n      CHECK_EQ(in_buffer.shape_view().NumAxes(), 3);\n      TensorBuffer* out_buffer = out_tensor->mut_dptr<TensorBuffer>() + i;\n      out_buffer->CopyFrom(in_buffer);\n      FlipCode flip_code = static_cast<FlipCode>(flip_code_tensor->dptr<int8_t>()[i]);\n      if (flip_code != FlipCode::kNonFlip) { FlipImage(out_buffer, flip_code); }\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nclass ObjectBboxFlipKernel final : public user_op::OpKernel {\n public:\n  ObjectBboxFlipKernel() = default;\n  ~ObjectBboxFlipKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* bbox_tensor = ctx->Tensor4ArgNameAndIndex(\"bbox\", 0);\n    const user_op::Tensor* image_size_tensor = ctx->Tensor4ArgNameAndIndex(\"image_size\", 0);\n    const user_op::Tensor* flip_code_tensor = ctx->Tensor4ArgNameAndIndex(\"flip_code\", 0);\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    int num_images = bbox_tensor->shape_view().elem_cnt();\n    CHECK_GT(num_images, 0);\n    CHECK_EQ(out_tensor->shape_view().elem_cnt(), num_images);\n    CHECK_EQ(image_size_tensor->shape_view().At(0), num_images);\n    CHECK_EQ(flip_code_tensor->shape_view().elem_cnt(), num_images);\n\n    MultiThreadLoop(num_images, [&](size_t i) {\n      const TensorBuffer& bbox_buffer = bbox_tensor->dptr<TensorBuffer>()[i];\n      CHECK_EQ(bbox_buffer.shape_view().NumAxes(), 2);\n      
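// Each box is stored as (xmin, ymin, xmax, ymax), matching FlipBoxes above, hence the\n      // trailing dimension of 4.\n      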
CHECK_EQ(bbox_buffer.shape_view().At(1), 4);\n      TensorBuffer* out_bbox_buffer = out_tensor->mut_dptr<TensorBuffer>() + i;\n      out_bbox_buffer->CopyFrom(bbox_buffer);\n      int32_t image_width = image_size_tensor->dptr<int32_t>()[i * 2 + 0];\n      int32_t image_height = image_size_tensor->dptr<int32_t>()[i * 2 + 1];\n      FlipCode flip_code = static_cast<FlipCode>(flip_code_tensor->dptr<int8_t>()[i]);\n      SwitchFlipBoxes(SwitchCase(out_bbox_buffer->data_type()), out_bbox_buffer, image_width,\n                      image_height, flip_code);\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nclass ObjectBboxScaleKernel final : public user_op::OpKernel {\n public:\n  ObjectBboxScaleKernel() = default;\n  ~ObjectBboxScaleKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* bbox_tensor = ctx->Tensor4ArgNameAndIndex(\"bbox\", 0);\n    const user_op::Tensor* scale_tensor = ctx->Tensor4ArgNameAndIndex(\"scale\", 0);\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    int num_images = bbox_tensor->shape_view().elem_cnt();\n    CHECK_GT(num_images, 0);\n    CHECK_EQ(scale_tensor->shape_view().At(0), num_images);\n    CHECK_EQ(out_tensor->shape_view().elem_cnt(), num_images);\n\n    MultiThreadLoop(num_images, [&](size_t i) {\n      const TensorBuffer& bbox_buffer = bbox_tensor->dptr<TensorBuffer>()[i];\n      CHECK_EQ(bbox_buffer.shape_view().NumAxes(), 2);\n      CHECK_EQ(bbox_buffer.shape_view().At(1), 4);\n      TensorBuffer* out_bbox_buffer = out_tensor->mut_dptr<TensorBuffer>() + i;\n      out_bbox_buffer->CopyFrom(bbox_buffer);\n      float scale_w = scale_tensor->dptr<float>()[i * 2 + 0];\n      float scale_h = scale_tensor->dptr<float>()[i * 2 + 1];\n      SwitchScaleBoxes(SwitchCase(out_bbox_buffer->data_type()), out_bbox_buffer, scale_w, scale_h);\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nclass ObjectSegmentationPolygonFlipKernel final : public user_op::OpKernel {\n public:\n  ObjectSegmentationPolygonFlipKernel() = default;\n  ~ObjectSegmentationPolygonFlipKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* polygon_tensor = ctx->Tensor4ArgNameAndIndex(\"poly\", 0);\n    const user_op::Tensor* image_size_tensor = ctx->Tensor4ArgNameAndIndex(\"image_size\", 0);\n    const user_op::Tensor* flip_code_tensor = ctx->Tensor4ArgNameAndIndex(\"flip_code\", 0);\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    int num_images = polygon_tensor->shape_view().elem_cnt();\n    CHECK_GT(num_images, 0);\n    CHECK_EQ(out_tensor->shape_view().elem_cnt(), num_images);\n    CHECK_EQ(image_size_tensor->shape_view().At(0), num_images);\n    CHECK_EQ(flip_code_tensor->shape_view().elem_cnt(), num_images);\n\n    MultiThreadLoop(num_images, [&](size_t i) {\n      const TensorBuffer& polygons_buffer = polygon_tensor->dptr<TensorBuffer>()[i];\n      CHECK_EQ(polygons_buffer.shape_view().NumAxes(), 2);\n      CHECK_EQ(polygons_buffer.shape_view().At(1), 2);\n      TensorBuffer* out_polygons_buffer = out_tensor->mut_dptr<TensorBuffer>() + i;\n      out_polygons_buffer->CopyFrom(polygons_buffer);\n      int32_t image_width = image_size_tensor->dptr<int32_t>()[i * 2 + 0];\n      int32_t image_height = image_size_tensor->dptr<int32_t>()[i * 2 + 1];\n      FlipCode flip_code = 
static_cast<FlipCode>(flip_code_tensor->dptr<int8_t>()[i]);\n      SwitchFlipPolygons(SwitchCase(out_polygons_buffer->data_type()), out_polygons_buffer,\n                         image_width, image_height, flip_code);\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nclass ObjectSegmentationPolygonScaleKernel final : public user_op::OpKernel {\n public:\n  ObjectSegmentationPolygonScaleKernel() = default;\n  ~ObjectSegmentationPolygonScaleKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* poly_tensor = ctx->Tensor4ArgNameAndIndex(\"poly\", 0);\n    const user_op::Tensor* scale_tensor = ctx->Tensor4ArgNameAndIndex(\"scale\", 0);\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    int num_images = poly_tensor->shape_view().elem_cnt();\n    CHECK_GT(num_images, 0);\n    CHECK_EQ(scale_tensor->shape_view().At(0), num_images);\n    CHECK_EQ(out_tensor->shape_view().elem_cnt(), num_images);\n\n    MultiThreadLoop(num_images, [&](size_t i) {\n      const TensorBuffer& poly_buffer = poly_tensor->dptr<TensorBuffer>()[i];\n      CHECK_EQ(poly_buffer.shape_view().NumAxes(), 2);\n      CHECK_EQ(poly_buffer.shape_view().At(1), 2);\n      TensorBuffer* out_poly_buffer = out_tensor->mut_dptr<TensorBuffer>() + i;\n      out_poly_buffer->CopyFrom(poly_buffer);\n      float scale_w = scale_tensor->dptr<float>()[i * 2 + 0];\n      float scale_h = scale_tensor->dptr<float>()[i * 2 + 1];\n      SwitchScalePolygons(SwitchCase(out_poly_buffer->data_type()), out_poly_buffer, scale_w,\n                          scale_h);\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nclass ImageNormalize final : public user_op::OpKernel {\n public:\n  ImageNormalize() = default;\n  ~ImageNormalize() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    int num_images = in_tensor->shape_view().elem_cnt();\n    CHECK_EQ(out_tensor->shape_view().elem_cnt(), num_images);\n    const auto& std_vec = ctx->Attr<std::vector<float>>(\"std\");\n    const auto& mean_vec = ctx->Attr<std::vector<float>>(\"mean\");\n\n    MultiThreadLoop(num_images, [&](size_t i) {\n      const TensorBuffer& in_buffer = in_tensor->dptr<TensorBuffer>()[i];\n      CHECK_EQ(in_buffer.shape_view().NumAxes(), 3);\n      TensorBuffer* out_buffer = out_tensor->mut_dptr<TensorBuffer>() + i;\n      out_buffer->CopyFrom(in_buffer);\n      SwitchImageNormalizeByChannel(SwitchCase(out_buffer->data_type()), out_buffer, std_vec,\n                                    mean_vec);\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nclass ObjectSegmentationPolygonToMask final : public user_op::OpKernel {\n public:\n  ObjectSegmentationPolygonToMask() = default;\n  ~ObjectSegmentationPolygonToMask() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* poly_tensor = ctx->Tensor4ArgNameAndIndex(\"poly\", 0);\n    const user_op::Tensor* poly_index_tensor = ctx->Tensor4ArgNameAndIndex(\"poly_index\", 0);\n    const user_op::Tensor* image_size_tensor = ctx->Tensor4ArgNameAndIndex(\"image_size\", 0);\n    user_op::Tensor* mask_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 
0);\n\n    int num_images = poly_tensor->shape_view().elem_cnt();\n    CHECK_GT(num_images, 0);\n    CHECK_EQ(poly_index_tensor->shape_view().elem_cnt(), num_images);\n    CHECK_EQ(image_size_tensor->shape_view().At(0), num_images);\n    CHECK_EQ(mask_tensor->shape_view().elem_cnt(), num_images);\n\n    MultiThreadLoop(num_images, [&](size_t i) {\n      const TensorBuffer& poly_buffer = poly_tensor->dptr<TensorBuffer>()[i];\n      const TensorBuffer& poly_index_buffer = poly_index_tensor->dptr<TensorBuffer>()[i];\n      int32_t image_width = image_size_tensor->dptr<int32_t>()[i * 2 + 0];\n      int32_t image_height = image_size_tensor->dptr<int32_t>()[i * 2 + 1];\n      TensorBuffer* mask_buffer = mask_tensor->mut_dptr<TensorBuffer>() + i;\n      SwitchPolygonsToMask(SwitchCase(poly_buffer.data_type(), poly_index_buffer.data_type()),\n                           poly_buffer, poly_index_buffer, mask_buffer, image_width, image_height);\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nnamespace {\n\nstd::function<Maybe<void>(const user_op::InferContext&, user_op::AddInplaceArgPair)>\nMakeInplaceProposalFn(const std::string& input_arg_name) {\n  return [input_arg_name](const user_op::InferContext& ctx,\n                          user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {\n    OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, input_arg_name, 0, true));\n    return Maybe<void>::Ok();\n  };\n}\n\n}  // namespace\n\nREGISTER_USER_KERNEL(\"image_flip\")\n    .SetCreateFn<ImageFlipKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"in\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"flip_code\", 0) == DataType::kInt8)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kTensorBuffer))\n    .SetInplaceProposalFn(MakeInplaceProposalFn(\"in\"));\n\nREGISTER_USER_KERNEL(\"object_bbox_flip\")\n    .SetCreateFn<ObjectBboxFlipKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"bbox\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"image_size\", 0) == DataType::kInt32)\n                     && (user_op::HobDataType(\"flip_code\", 0) == DataType::kInt8)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kTensorBuffer))\n    .SetInplaceProposalFn(MakeInplaceProposalFn(\"bbox\"));\n\nREGISTER_USER_KERNEL(\"object_bbox_scale\")\n    .SetCreateFn<ObjectBboxScaleKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"bbox\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"scale\", 0) == DataType::kFloat)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kTensorBuffer))\n    .SetInplaceProposalFn(MakeInplaceProposalFn(\"bbox\"));\n\nREGISTER_USER_KERNEL(\"object_segmentation_polygon_flip\")\n    .SetCreateFn<ObjectSegmentationPolygonFlipKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"poly\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"image_size\", 0) == DataType::kInt32)\n                     && (user_op::HobDataType(\"flip_code\", 0) == DataType::kInt8)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kTensorBuffer))\n    
.SetInplaceProposalFn(MakeInplaceProposalFn(\"poly\"));\n\nREGISTER_USER_KERNEL(\"object_segmentation_polygon_scale\")\n    .SetCreateFn<ObjectSegmentationPolygonScaleKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"poly\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"scale\", 0) == DataType::kFloat)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kTensorBuffer))\n    .SetInplaceProposalFn(MakeInplaceProposalFn(\"poly\"));\n\nREGISTER_USER_KERNEL(\"image_normalize\")\n    .SetCreateFn<ImageNormalize>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"in\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kTensorBuffer))\n    .SetInplaceProposalFn(MakeInplaceProposalFn(\"in\"));\n\nREGISTER_USER_KERNEL(\"object_segmentation_polygon_to_mask\")\n    .SetCreateFn<ObjectSegmentationPolygonToMask>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"poly\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"poly_index\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"image_size\", 0) == DataType::kInt32)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kTensorBuffer));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/image_preprocess_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/memory_format.pb.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/blocking_counter.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/user/image/image_util.h\"\n#include \"oneflow/user/kernels/random_crop_kernel_state.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<MemoryFormat layout>\ninline int64_t GetOffset(int64_t h, int64_t w, int64_t c, int64_t H, int64_t W, int64_t C);\n\ntemplate<>\ninline int64_t GetOffset<MemoryFormat::kContiguous>(int64_t h, int64_t w, int64_t c, int64_t H,\n                                                    int64_t W, int64_t C) {\n  return c * H * W + h * W + w;  // C, H, W\n}\n\ntemplate<>\ninline int64_t GetOffset<MemoryFormat::kChannelsLast>(int64_t h, int64_t w, int64_t c, int64_t H,\n                                                      int64_t W, int64_t C) {\n  return h * W * C + w * C + c;  // H, W, C\n}\n\ntemplate<bool mirror>\ninline int64_t GetInputW(int64_t out_w, int64_t out_W, int64_t in_W, float crop_pos_x);\n\ntemplate<>\ninline int64_t GetInputW<true>(int64_t out_w, int64_t out_W, int64_t in_W, float crop_pos_x) {\n  return (in_W - out_W) * crop_pos_x + (out_W - 1 - out_w);\n}\n\ntemplate<>\ninline int64_t GetInputW<false>(int64_t out_w, int64_t out_W, int64_t in_W, float crop_pos_x) {\n  return (in_W - out_W) * crop_pos_x + out_w;\n}\n\ntemplate<MemoryFormat output_layout, bool mirror>\nvoid CMN1Sample(int64_t C, int64_t in_H, int64_t in_W, int64_t out_H, int64_t out_W,\n                float crop_pos_y, float crop_pos_x, const uint8_t* in_dptr, float* out_dptr,\n                const std::vector<float>& mean_vec, const std::vector<float>& inv_std_vec) {\n  CHECK_LE(out_H, in_H);\n  CHECK_LE(out_W, in_W);\n  for (int64_t c = 0; c < C; ++c) {\n    float mean = mean_vec.at(c);\n    float inv_std = inv_std_vec.at(c);\n    for (int64_t out_h = 0; out_h < out_H; ++out_h) {\n      int64_t in_h = (in_H - out_H) * crop_pos_y + out_h;\n      for (int64_t out_w = 0; out_w < out_W; ++out_w) {\n        int64_t in_w = GetInputW<mirror>(out_w, out_W, in_W, crop_pos_x);\n        int64_t in_offset = GetOffset<MemoryFormat::kChannelsLast>(in_h, in_w, c, in_H, in_W, C);\n        int64_t out_offset = GetOffset<output_layout>(out_h, out_w, c, out_H, out_W, C);\n        out_dptr[out_offset] = (static_cast<float>(in_dptr[in_offset]) - mean) * inv_std;\n      }\n    }\n  }\n}\n\nstd::vector<int8_t> GetMirrorVec(user_op::KernelComputeContext* ctx) {\n  std::vector<int8_t> mirror;\n  user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n  user_op::Tensor* mirror_blob = ctx->Tensor4ArgNameAndIndex(\"mirror\", 0);\n  int64_t record_num = 
in_blob->shape_view().At(0);\n  if (mirror_blob) {\n    CHECK_EQ(record_num, mirror_blob->shape_view().elem_cnt());\n    mirror.insert(mirror.end(), mirror_blob->dptr<int8_t>(),\n                  mirror_blob->dptr<int8_t>() + record_num);\n  } else {\n    mirror.resize(record_num, 0);\n  }\n  return mirror;\n}\n\nclass CMNAttr final : public user_op::OpKernelState {\n public:\n  CMNAttr(user_op::KernelInitContext* ctx) {\n    mean_vec_ = ctx->Attr<std::vector<float>>(\"mean\");\n    const std::vector<float>& std_vec = ctx->Attr<std::vector<float>>(\"std\");\n    const std::string& color_space = ctx->Attr<std::string>(\"color_space\");\n    int64_t C = ImageUtil::IsColor(color_space) ? 3 : 1;\n    CHECK(mean_vec_.size() == 1 || mean_vec_.size() == C);\n    CHECK(std_vec.size() == 1 || std_vec.size() == C);\n    for (float elem : std_vec) { inv_std_vec_.emplace_back(1.0f / elem); }\n    if (mean_vec_.size() == 1) { mean_vec_.resize(C, mean_vec_.at(0)); }\n    if (inv_std_vec_.size() == 1) { inv_std_vec_.resize(C, inv_std_vec_.at(0)); }\n  }\n  ~CMNAttr() = default;\n\n  const std::vector<float>& mean_vec() const { return mean_vec_; }\n  const std::vector<float>& inv_std_vec() const { return inv_std_vec_; }\n\n private:\n  std::vector<float> mean_vec_;\n  std::vector<float> inv_std_vec_;\n};\n\n}  // namespace\n\nclass CropMirrorNormalizeFromStaticShapeToFloatKernel final : public user_op::OpKernel {\n public:\n  CropMirrorNormalizeFromStaticShapeToFloatKernel() = default;\n  ~CropMirrorNormalizeFromStaticShapeToFloatKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<CMNAttr>(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* cmn_attr = dynamic_cast<CMNAttr*>(state);\n    const std::vector<float>& mean_vec = cmn_attr->mean_vec();\n    const std::vector<float>& inv_std_vec = cmn_attr->inv_std_vec();\n    user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    std::vector<int8_t> mirror = GetMirrorVec(ctx);\n    int64_t record_num = in_blob->shape_view().At(0);\n    const std::string& color_space = ctx->Attr<std::string>(\"color_space\");\n    int64_t C = ImageUtil::IsColor(color_space) ? 
3 : 1;\n    float crop_pos_y = ctx->Attr<float>(\"crop_pos_y\");\n    float crop_pos_x = ctx->Attr<float>(\"crop_pos_x\");\n    const std::string& output_layout = ctx->Attr<std::string>(\"output_layout\");\n    float* out_dptr = out_blob->mut_dptr<float>();\n\n    const uint8_t* in_dptr = in_blob->dptr<uint8_t>();\n    const ShapeView& in_shape = in_blob->shape_view();\n    int64_t N = in_shape.At(0);\n    int64_t in_H = in_shape.At(1);\n    int64_t in_W = in_shape.At(2);\n    CHECK_EQ(C, in_shape.At(3));\n    int64_t in_image_elem_cnt = in_H * in_W * C;\n    const ShapeView& out_shape = out_blob->shape_view();\n    CHECK_EQ(out_shape.NumAxes(), 4);\n    CHECK_EQ(out_shape.At(0), N);\n    if (output_layout == \"NCHW\") {\n      CHECK_EQ(out_shape.At(1), C);\n      int64_t out_H = out_shape.At(2);\n      int64_t out_W = out_shape.At(3);\n      int64_t out_image_elem_cnt = C * out_H * out_W;\n      MultiThreadLoop(record_num, [&](size_t i) {\n        if (mirror.at(i)) {\n          CMN1Sample<MemoryFormat::kContiguous, true>(\n              C, in_H, in_W, out_H, out_W, crop_pos_y, crop_pos_x, in_dptr + in_image_elem_cnt * i,\n              out_dptr + out_image_elem_cnt * i, mean_vec, inv_std_vec);\n        } else {\n          CMN1Sample<MemoryFormat::kContiguous, false>(\n              C, in_H, in_W, out_H, out_W, crop_pos_y, crop_pos_x, in_dptr + in_image_elem_cnt * i,\n              out_dptr + out_image_elem_cnt * i, mean_vec, inv_std_vec);\n        }\n      });\n    } else if (output_layout == \"NHWC\") {\n      CHECK_EQ(out_shape.At(3), C);\n      int64_t out_H = out_shape.At(1);\n      int64_t out_W = out_shape.At(2);\n      int64_t out_image_elem_cnt = C * out_H * out_W;\n      MultiThreadLoop(record_num, [&](size_t i) {\n        if (mirror.at(i)) {\n          CMN1Sample<MemoryFormat::kChannelsLast, true>(\n              C, in_H, in_W, out_H, out_W, crop_pos_y, crop_pos_x, in_dptr + in_image_elem_cnt * i,\n              out_dptr + out_image_elem_cnt * i, mean_vec, inv_std_vec);\n        } else {\n          CMN1Sample<MemoryFormat::kChannelsLast, false>(\n              C, in_H, in_W, out_H, out_W, crop_pos_y, crop_pos_x, in_dptr + in_image_elem_cnt * i,\n              out_dptr + out_image_elem_cnt * i, mean_vec, inv_std_vec);\n        }\n      });\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"crop_mirror_normalize_from_uint8\")\n    .SetCreateFn<CropMirrorNormalizeFromStaticShapeToFloatKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"in\", 0) == DataType::kUInt8)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kFloat));\n\nclass CropMirrorNormalizeFromTensorBufferToFloatKernel final : public user_op::OpKernel {\n public:\n  CropMirrorNormalizeFromTensorBufferToFloatKernel() = default;\n  ~CropMirrorNormalizeFromTensorBufferToFloatKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<CMNAttr>(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* cmn_attr = dynamic_cast<CMNAttr*>(state);\n    const std::vector<float>& mean_vec = cmn_attr->mean_vec();\n    const std::vector<float>& inv_std_vec = cmn_attr->inv_std_vec();\n    
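// mean/inv_std were expanded to per-channel vectors once in CMNAttr, so the per-pixel\n    // work below is a single subtract-and-multiply with no division in the hot loop.\n    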
user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    std::vector<int8_t> mirror = GetMirrorVec(ctx);\n    int64_t record_num = in_blob->shape_view().At(0);\n    const std::string& color_space = ctx->Attr<std::string>(\"color_space\");\n    int64_t C = ImageUtil::IsColor(color_space) ? 3 : 1;\n    float crop_pos_y = ctx->Attr<float>(\"crop_pos_y\");\n    float crop_pos_x = ctx->Attr<float>(\"crop_pos_x\");\n    const std::string& output_layout = ctx->Attr<std::string>(\"output_layout\");\n    float* out_dptr = out_blob->mut_dptr<float>();\n\n    const TensorBuffer* in_buffers = in_blob->dptr<TensorBuffer>();\n    const ShapeView& in_shape = in_blob->shape_view();\n    int64_t N = in_shape.At(0);\n    CHECK_EQ(in_shape.NumAxes(), 1);\n    const ShapeView& out_shape = out_blob->shape_view();\n    CHECK_EQ(out_shape.NumAxes(), 4);\n    CHECK_EQ(out_shape.At(0), N);\n    if (output_layout == \"NCHW\") {\n      CHECK_EQ(out_shape.At(1), C);\n      int64_t out_H = out_shape.At(2);\n      int64_t out_W = out_shape.At(3);\n      int64_t out_image_elem_cnt = C * out_H * out_W;\n      MultiThreadLoop(record_num, [&](size_t i) {\n        const TensorBuffer* in_buffer = in_buffers + i;\n        const Shape& in_shape = in_buffer->shape();\n        CHECK_EQ(in_shape.NumAxes(), 3);  // H, W, C\n        int64_t in_H = in_shape.At(0);\n        int64_t in_W = in_shape.At(1);\n        CHECK_EQ(C, in_shape.At(2));\n        if (mirror.at(i)) {\n          CMN1Sample<MemoryFormat::kContiguous, true>(\n              C, in_H, in_W, out_H, out_W, crop_pos_y, crop_pos_x, in_buffer->data<uint8_t>(),\n              out_dptr + out_image_elem_cnt * i, mean_vec, inv_std_vec);\n        } else {\n          CMN1Sample<MemoryFormat::kContiguous, false>(\n              C, in_H, in_W, out_H, out_W, crop_pos_y, crop_pos_x, in_buffer->data<uint8_t>(),\n              out_dptr + out_image_elem_cnt * i, mean_vec, inv_std_vec);\n        }\n      });\n    } else if (output_layout == \"NHWC\") {\n      CHECK_EQ(out_shape.At(3), C);\n      int64_t out_H = out_shape.At(1);\n      int64_t out_W = out_shape.At(2);\n      int64_t out_image_elem_cnt = C * out_H * out_W;\n      MultiThreadLoop(record_num, [&](size_t i) {\n        const TensorBuffer* in_buffer = in_buffers + i;\n        const Shape& in_shape = in_buffer->shape();\n        CHECK_EQ(in_shape.NumAxes(), 3);  // H, W, C\n        int64_t in_H = in_shape.At(0);\n        int64_t in_W = in_shape.At(1);\n        CHECK_EQ(C, in_shape.At(2));\n        if (mirror.at(i)) {\n          CMN1Sample<MemoryFormat::kChannelsLast, true>(\n              C, in_H, in_W, out_H, out_W, crop_pos_y, crop_pos_x, in_buffer->data<uint8_t>(),\n              out_dptr + out_image_elem_cnt * i, mean_vec, inv_std_vec);\n        } else {\n          CMN1Sample<MemoryFormat::kChannelsLast, false>(\n              C, in_H, in_W, out_H, out_W, crop_pos_y, crop_pos_x, in_buffer->data<uint8_t>(),\n              out_dptr + out_image_elem_cnt * i, mean_vec, inv_std_vec);\n        }\n      });\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"crop_mirror_normalize_from_tensorbuffer\")\n    .SetCreateFn<CropMirrorNormalizeFromTensorBufferToFloatKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"in\", 0) == DataType::kTensorBuffer)\n                     
&& (user_op::HobDataType(\"out\", 0) == DataType::kFloat));\n\nnamespace {\n\nclass RandBoolGen final : public user_op::OpKernelState {\n public:\n  explicit RandBoolGen(float prob, int64_t seed) : dis_(prob), rng_(seed) {}\n  ~RandBoolGen() = default;\n\n  bool GetNextBool() { return dis_(rng_); }\n\n private:\n  std::bernoulli_distribution dis_;\n  std::mt19937 rng_;\n};\n\n}  // namespace\n\nclass CoinFlipKernel final : public user_op::OpKernel {\n public:\n  CoinFlipKernel() = default;\n  ~CoinFlipKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    float prob = ctx->Attr<float>(\"probability\");\n    int64_t seed = CHECK_JUST(GetOpKernelRandomSeed(ctx));\n    std::shared_ptr<RandBoolGen> rand_bool_gen(new RandBoolGen(prob, seed));\n    return rand_bool_gen;\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* rand_bool_gen = dynamic_cast<RandBoolGen*>(state);\n    user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    int8_t* dptr = out_blob->mut_dptr<int8_t>();\n    for (int32_t i = 0; i < out_blob->shape_view().elem_cnt(); ++i) {\n      *(dptr + i) = rand_bool_gen->GetNextBool() ? 1 : 0;\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"coin_flip\")\n    .SetCreateFn<CoinFlipKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kInt8));\n\nnamespace {\n\nvoid ImageRandomCropImpl(const TensorBuffer* in_buffer, TensorBuffer* out_buffer,\n                         RandomCropGenerator* random_crop_gen) {\n  cv::Mat image = GenCvMat4ImageBuffer(*in_buffer);\n  int W = image.cols;\n  int H = image.rows;\n  cv::Mat image_roi;\n  CropWindow crop;\n  random_crop_gen->GenerateCropWindow({H, W}, &crop);\n  const int y = crop.anchor.At(0);\n  const int x = crop.anchor.At(1);\n  const int new_h = crop.shape.At(0);\n  const int new_w = crop.shape.At(1);\n  CHECK(new_w > 0 && new_w <= W);\n  CHECK(new_h > 0 && new_h <= H);\n  cv::Rect roi(x, y, new_w, new_h);\n  image(roi).copyTo(image_roi);\n  image = image_roi;\n  W = image.cols;\n  H = image.rows;\n\n  CHECK(image.isContinuous());\n  const int c = in_buffer->shape_view().At(2);\n  CHECK_EQ(c, image.channels());\n  Shape image_shape({H, W, c});\n  out_buffer->Resize(image_shape, in_buffer->data_type());\n  memcpy(out_buffer->mut_data<>(), image.ptr(), out_buffer->nbytes());\n}\n\n}  // namespace\n\nclass ImageRandomCropKernel final : public user_op::OpKernel {\n public:\n  ImageRandomCropKernel() = default;\n  ~ImageRandomCropKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return CreateRandomCropKernelState(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* crop_window_generators = dynamic_cast<RandomCropKernelState*>(state);\n    CHECK_NOTNULL(crop_window_generators);\n    user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    int64_t record_num = out_blob->shape_view().elem_cnt();\n    CHECK(record_num > 0);\n    user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    
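// \"in\" and \"out\" are tensors of TensorBuffer handles, one per record, so their shapes\n    // must match exactly; each record is cropped with its own per-slot generator below.\n    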
CHECK_EQ(out_blob->shape_view(), in_blob->shape_view());\n    const TensorBuffer* in_buffers = in_blob->dptr<TensorBuffer>();\n    TensorBuffer* out_buffers = out_blob->mut_dptr<TensorBuffer>();\n    MultiThreadLoop(record_num, [&](size_t i) {\n      ImageRandomCropImpl(in_buffers + i, out_buffers + i, crop_window_generators->GetGenerator(i));\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"image_random_crop\")\n    .SetCreateFn<ImageRandomCropKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"in\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kTensorBuffer));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/image_preprocess_kernels.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/memory_format.pb.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/small_vector.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstruct NormalizeVal {\n  float val[3];\n};\n\nclass NormalizeAttr final : public user_op::OpKernelState {\n public:\n  NormalizeAttr(user_op::KernelInitContext* ctx) {\n    const std::vector<float>& mean_vec = ctx->Attr<std::vector<float>>(\"mean\");\n    if (mean_vec.size() == 1) {\n      for (int i = 0; i < 3; ++i) { mean_.val[i] = mean_vec.at(0); }\n    } else if (mean_vec.size() == 3) {\n      for (int i = 0; i < 3; ++i) { mean_.val[i] = mean_vec.at(i); }\n    } else {\n      UNIMPLEMENTED();\n    }\n\n    const std::vector<float>& std_vec = ctx->Attr<std::vector<float>>(\"std\");\n    if (std_vec.size() == 1) {\n      for (int i = 0; i < 3; ++i) { inv_std_.val[i] = 1.0f / std_vec.at(0); }\n    } else if (std_vec.size() == 3) {\n      for (int i = 0; i < 3; ++i) { inv_std_.val[i] = 1.0f / std_vec.at(i); }\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n  ~NormalizeAttr() = default;\n\n  const NormalizeVal& mean() const { return mean_; }\n  const NormalizeVal& inv_std() const { return inv_std_; }\n\n private:\n  NormalizeVal mean_;\n  NormalizeVal inv_std_;\n};\n\ntemplate<MemoryFormat layout>\n__device__ __forceinline__ void OutIdx2InIdx(int32_t* out_idx, int32_t* in_idx,\n                                             const int8_t* mirror_dptr, int32_t out_W,\n                                             int32_t H_offset, int32_t W_offset);\ntemplate<>\n__device__ __forceinline__ void OutIdx2InIdx<MemoryFormat::kContiguous>(\n    int32_t* out_idx, int32_t* in_idx, const int8_t* mirror_dptr, int32_t out_W, int32_t H_offset,\n    int32_t W_offset) {\n  if (mirror_dptr && mirror_dptr[out_idx[0]]) { out_idx[3] = out_W - 1 - out_idx[3]; }\n  in_idx[0] = out_idx[0];             // N\n  in_idx[1] = out_idx[2] + H_offset;  // H\n  in_idx[2] = out_idx[3] + W_offset;  // W\n  in_idx[3] = out_idx[1];             // C\n}\n\ntemplate<>\n__device__ __forceinline__ void OutIdx2InIdx<MemoryFormat::kChannelsLast>(\n    int32_t* out_idx, int32_t* in_idx, const int8_t* mirror_dptr, int32_t out_W, int32_t H_offset,\n    int32_t W_offset) {\n  if (mirror_dptr && mirror_dptr[out_idx[0]]) { out_idx[2] = out_W - 1 - out_idx[2]; }\n  in_idx[0] = out_idx[0];             // N\n  in_idx[1] = out_idx[1] + H_offset;  // H\n  in_idx[2] = out_idx[2] + W_offset;  // W\n  in_idx[3] = out_idx[3];             // C\n}\n\ntemplate<MemoryFormat layout>\n__global__ void CropMirrorNormalizeGpuImpl(int32_t elem_cnt, const uint8_t* in_dptr,\n                                           float* out_dptr, const int8_t* mirror_dptr,\n                                           int32_t out_W,\n                                     
      const NdIndexOffsetHelper<int32_t, 4> in_helper,\n                                           const NdIndexOffsetHelper<int32_t, 4> out_helper,\n                                           int32_t H_offset, int32_t W_offset,\n                                           const NormalizeVal mean, const NormalizeVal inv_std) {\n  CUDA_1D_KERNEL_LOOP(out_offset, elem_cnt) {\n    int32_t in_idx[4];\n    int32_t out_idx[4];\n    out_helper.OffsetToNdIndex(out_offset, out_idx);\n    OutIdx2InIdx<layout>(out_idx, in_idx, mirror_dptr, out_W, H_offset, W_offset);\n    float mean_val;\n    float inv_std_val;\n    const int32_t c = in_idx[3];\n    // When the compiler can't resolve array indices to constants it will put private arrays into\n    // GPU local memory. Using local memory is slower than keeping array elements directly in\n    // registers.\n    if (c == 0) {\n      mean_val = mean.val[0];\n      inv_std_val = inv_std.val[0];\n    } else if (c == 1) {\n      mean_val = mean.val[1];\n      inv_std_val = inv_std.val[1];\n    } else if (c == 2) {\n      mean_val = mean.val[2];\n      inv_std_val = inv_std.val[2];\n    } else {\n      // undefined behavior\n      assert(false);\n    }\n    int32_t in_offset = in_helper.NdIndexToOffset(in_idx);\n    out_dptr[out_offset] = (static_cast<float>(in_dptr[in_offset]) - mean_val) * inv_std_val;\n  }\n}\n\n}  // namespace\n\nclass CropMirrorNormalizeGpuKernel final : public user_op::OpKernel {\n public:\n  CropMirrorNormalizeGpuKernel() = default;\n  ~CropMirrorNormalizeGpuKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<NormalizeAttr>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* normalize_attr = dynamic_cast<NormalizeAttr*>(state);\n    const NormalizeVal& mean = normalize_attr->mean();\n    const NormalizeVal& inv_std = normalize_attr->inv_std();\n    user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const std::string& output_layout = ctx->Attr<std::string>(\"output_layout\");\n    float* out_dptr = out_blob->mut_dptr<float>();\n    const uint8_t* in_dptr = in_blob->dptr<uint8_t>();\n    const ShapeView& in_shape = in_blob->shape_view();\n    const ShapeView& out_shape = out_blob->shape_view();\n    CHECK_EQ(in_shape.NumAxes(), 4);\n    CHECK_EQ(out_shape.NumAxes(), 4);\n    int32_t elem_cnt = out_shape.elem_cnt();\n    CHECK_LE(elem_cnt, GetMaxVal<int32_t>());\n    float crop_pos_y = ctx->Attr<float>(\"crop_pos_y\");\n    float crop_pos_x = ctx->Attr<float>(\"crop_pos_x\");\n\n    int32_t N = in_shape.At(0);\n    int32_t in_H = in_shape.At(1);\n    int32_t in_W = in_shape.At(2);\n    int32_t C = in_shape.At(3);\n    const NdIndexOffsetHelper<int32_t, 4> in_helper(N, in_H, in_W, C);\n    const int8_t* mirror_dptr = nullptr;\n    user_op::Tensor* mirror_blob = ctx->Tensor4ArgNameAndIndex(\"mirror\", 0);\n    if (mirror_blob) { mirror_dptr = mirror_blob->dptr<int8_t>(); }\n\n    if (output_layout == \"NCHW\") {\n      CHECK_EQ(N, out_shape.At(0));\n      CHECK_EQ(C, out_shape.At(1));\n      int32_t out_H = out_shape.At(2);\n      int32_t out_W = out_shape.At(3);\n      CHECK_LE(out_H, in_H);\n      CHECK_LE(out_W, in_W);\n      int32_t H_offset = (in_H - 
out_H) * crop_pos_y;\n      int32_t W_offset = (in_W - out_W) * crop_pos_x;\n      const NdIndexOffsetHelper<int32_t, 4> out_helper(N, C, out_H, out_W);\n      CropMirrorNormalizeGpuImpl<MemoryFormat::kContiguous>\n          <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n             ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n              elem_cnt, in_dptr, out_dptr, mirror_dptr, out_W, in_helper, out_helper, H_offset,\n              W_offset, mean, inv_std);\n    } else if (output_layout == \"NHWC\") {\n      CHECK_EQ(N, out_shape.At(0));\n      int32_t out_H = out_shape.At(1);\n      int32_t out_W = out_shape.At(2);\n      CHECK_EQ(C, out_shape.At(3));\n      CHECK_LE(out_H, in_H);\n      CHECK_LE(out_W, in_W);\n      int32_t H_offset = (in_H - out_H) * crop_pos_y;\n      int32_t W_offset = (in_W - out_W) * crop_pos_x;\n      const NdIndexOffsetHelper<int32_t, 4> out_helper(N, out_H, out_W, C);\n      CropMirrorNormalizeGpuImpl<MemoryFormat::kChannelsLast>\n          <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n             ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n              elem_cnt, in_dptr, out_dptr, mirror_dptr, out_W, in_helper, out_helper, H_offset,\n              W_offset, mean, inv_std);\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"crop_mirror_normalize_from_uint8\")\n    .SetCreateFn<CropMirrorNormalizeGpuKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)\n                     && (user_op::HobDataType(\"in\", 0) == DataType::kUInt8)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kFloat));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/image_resize_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/user/image/image_util.h\"\n#include <opencv2/opencv.hpp>\n#include <cfenv>\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstd::pair<T, T> GetTargetResizedSize4ImageBuffer(const TensorBuffer& image_buffer,\n                                                 const bool resize_longer, const T target_size,\n                                                 const T min_size, const T max_size) {\n  CHECK_GT(target_size, 0);\n  if (min_size > 0) { CHECK_GE(target_size, min_size); }\n  if (max_size > 0) { CHECK_LE(target_size, max_size); }\n  CHECK_EQ(image_buffer.shape_view().NumAxes(), 3);\n  const T origin_height = image_buffer.shape_view().At(0);\n  const T origin_width = image_buffer.shape_view().At(1);\n\n  // set round to banker's rounding\n  int origin_round_way = std::fegetround();\n  CHECK_EQ(std::fesetround(FE_TONEAREST), 0);\n\n  double org_min_size = std::min<double>(origin_height, origin_width);\n  double org_max_size = std::max<double>(origin_height, origin_width);\n  double aspect_ratio = org_min_size / org_max_size;\n  double res_min_size = 0.0;\n  double res_max_size = 0.0;\n  if (resize_longer) {\n    res_max_size = static_cast<double>(target_size);\n    res_min_size = std::nearbyint(res_max_size * aspect_ratio);\n    if (min_size > 0 && res_min_size < min_size) {\n      res_min_size = static_cast<double>(min_size);\n      res_max_size = std::nearbyint(res_min_size / aspect_ratio);\n    }\n  } else {\n    res_min_size = static_cast<double>(target_size);\n    res_max_size = std::nearbyint(res_min_size / aspect_ratio);\n    if (max_size > 0 && res_max_size > max_size) {\n      res_max_size = static_cast<double>(max_size);\n      res_min_size = std::nearbyint(res_max_size * aspect_ratio);\n    }\n  }\n  std::fesetround(origin_round_way);\n\n  std::pair<T, T> width_and_height;\n  if (origin_width < origin_height) {\n    width_and_height.first = static_cast<T>(res_min_size);\n    width_and_height.second = static_cast<T>(res_max_size);\n  } else {\n    width_and_height.first = static_cast<T>(res_max_size);\n    width_and_height.second = static_cast<T>(res_min_size);\n  }\n  return width_and_height;\n}\n\nbool CheckMatSizeMatch(const cv::Mat& mat, const bool resize_longer, const int32_t target_size,\n                       const int32_t min_size, const int32_t max_size) {\n  bool is_size_match = true;\n  int mat_min_size = std::min(mat.rows, mat.cols);\n  int mat_max_size = std::max(mat.rows, mat.cols);\n  if (resize_longer) {\n    if (min_size > 0) {\n      is_size_match = (mat_max_size >= target_size) && (mat_min_size >= min_size)\n                      && (mat_min_size == min_size || mat_max_size == target_size);\n    } else {\n      is_size_match = (mat_max_size == target_size);\n    }\n  } else {\n    if (max_size > 0) {\n      
is_size_match = (mat_min_size <= target_size) && (mat_max_size <= max_size)\n                      && (mat_min_size == target_size || mat_max_size == max_size);\n    } else {\n      is_size_match = (mat_min_size == target_size);\n    }\n  }\n  return is_size_match;\n}\n\nvoid ImageTargetResize(const TensorBuffer& image_buffer, TensorBuffer* resized_image_buffer,\n                       const bool resize_longer, const int32_t target_size, const int32_t min_size,\n                       const int32_t max_size, const std::string& interp_type) {\n  const cv::Mat image_mat = GenCvMat4ImageBuffer(image_buffer);\n  int64_t res_w = 0;\n  int64_t res_h = 0;\n  int64_t channels = image_mat.channels();\n  std::tie(res_w, res_h) = GetTargetResizedSize4ImageBuffer<int64_t>(\n      image_buffer, resize_longer, target_size, min_size, max_size);\n  resized_image_buffer->Resize(Shape({res_h, res_w, channels}), image_buffer.data_type());\n  cv::Mat res_image_mat = GenCvMat4ImageBuffer(*resized_image_buffer);\n  int interp_flag =\n      GetCvInterpolationFlag(interp_type, image_mat.cols, image_mat.rows, res_w, res_h);\n  cv::resize(image_mat, res_image_mat, cv::Size(res_w, res_h), 0, 0, interp_flag);\n\n  CHECK_EQ(res_image_mat.ptr<void>(), resized_image_buffer->data());\n  CHECK(CheckMatSizeMatch(res_image_mat, resize_longer, target_size, min_size, max_size));\n}\n\nclass ImageResizeToFixedSizeKernel final : public user_op::OpKernel {\n public:\n  ImageResizeToFixedSizeKernel() = default;\n  ~ImageResizeToFixedSizeKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    CHECK_NOTNULL(in_tensor);\n    const int64_t batch_size = in_tensor->shape_view().elem_cnt();\n    CHECK_GT(batch_size, 0);\n\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(out_tensor->shape_view().NumAxes(), 4);\n    CHECK_EQ(out_tensor->shape_view().At(0), batch_size);\n    int64_t res_h = out_tensor->shape_view().At(1);\n    int64_t res_w = out_tensor->shape_view().At(2);\n    int64_t channels = out_tensor->shape_view().At(3);\n    int64_t elem_cnt_per_img = res_h * res_w * channels;\n\n    user_op::Tensor* scale_tensor = ctx->Tensor4ArgNameAndIndex(\"scale\", 0);\n    CHECK_EQ(scale_tensor->shape_view().NumAxes(), 2);\n    CHECK_EQ(scale_tensor->shape_view().At(0), batch_size);\n    CHECK_EQ(scale_tensor->shape_view().At(1), 2);\n\n    MultiThreadLoop(batch_size, [&](size_t i) {\n      const TensorBuffer& in_buffer = in_tensor->dptr<TensorBuffer>()[i];\n      CHECK_EQ(in_buffer.shape_view().NumAxes(), 3);\n      const int64_t origin_height = in_buffer.shape_view().At(0);\n      const int64_t origin_width = in_buffer.shape_view().At(1);\n      CHECK_EQ(in_buffer.shape_view().At(2), channels);\n      DataType dtype = ctx->Attr<DataType>(\"data_type\");\n      int interp_flag = GetCvInterpolationFlag(ctx->Attr<std::string>(\"interpolation_type\"),\n                                               origin_width, origin_height, res_w, res_h);\n\n      const cv::Mat in_img_mat = GenCvMat4ImageBuffer(in_buffer);\n      cv::Mat out_img_mat = GenCvMat4ImageTensor(out_tensor, i);\n      if (in_buffer.data_type() == dtype) {\n        cv::resize(in_img_mat, out_img_mat, cv::Size(res_w, res_h), 0, 0, interp_flag);\n      } else {\n        cv::Mat res_img_mat;\n        cv::resize(in_img_mat, res_img_mat, cv::Size(res_w, res_h), 0, 0, interp_flag);\n        
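// Convert the temporarily resized mat into the preallocated output mat, since the\n        // input dtype differs from the requested output dtype.\n        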
CvMatConvertToDataType(res_img_mat, &out_img_mat, dtype);\n      }\n\n      char* cur_out_dptr =\n          out_tensor->mut_dptr<char>() + i * elem_cnt_per_img * GetSizeOfDataType(dtype);\n      CHECK(out_img_mat.isContinuous());\n      CHECK_EQ(out_img_mat.ptr<void>(), static_cast<void*>(cur_out_dptr));\n      CHECK_EQ(out_img_mat.cols, res_w);\n      CHECK_EQ(out_img_mat.rows, res_h);\n      CHECK_EQ(out_img_mat.channels(), channels);\n\n      if (scale_tensor) {\n        float* scale_dptr = scale_tensor->mut_dptr<float>() + i * 2;\n        scale_dptr[0] = static_cast<float>(res_w) / static_cast<float>(origin_width);\n        scale_dptr[1] = static_cast<float>(res_h) / static_cast<float>(origin_height);\n      }\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nclass ImageResizeKeepAspectRatioKernel final : public user_op::OpKernel {\n public:\n  ImageResizeKeepAspectRatioKernel() = default;\n  ~ImageResizeKeepAspectRatioKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* size_tensor = ctx->Tensor4ArgNameAndIndex(\"size\", 0);\n    user_op::Tensor* scale_tensor = ctx->Tensor4ArgNameAndIndex(\"scale\", 0);\n    CHECK_NOTNULL(out_tensor);\n    CHECK_NOTNULL(size_tensor);\n    CHECK_NOTNULL(scale_tensor);\n    const TensorBuffer* in_img_buf = in_tensor->dptr<TensorBuffer>();\n    TensorBuffer* out_img_buf = out_tensor->mut_dptr<TensorBuffer>();\n    TensorBuffer* scale_buf = scale_tensor->mut_dptr<TensorBuffer>();\n    TensorBuffer* size_buf = size_tensor->mut_dptr<TensorBuffer>();\n\n    const int64_t num_images = in_tensor->shape_view().elem_cnt();\n    const bool resize_longer = ctx->Attr<bool>(\"resize_longer\");\n    const int32_t target_size = ctx->Attr<int32_t>(\"target_size\");\n    const int32_t min_size = ctx->Attr<int32_t>(\"min_size\");\n    const int32_t max_size = ctx->Attr<int32_t>(\"max_size\");\n    const std::string& interp_type = ctx->Attr<std::string>(\"interpolation_type\");\n\n    MultiThreadLoop(num_images, [&](size_t i) {\n      ImageTargetResize(in_img_buf[i], out_img_buf + i, resize_longer, target_size, min_size,\n                        max_size, interp_type);\n      const int64_t org_h = in_img_buf[i].shape_view().At(0);\n      const int64_t org_w = in_img_buf[i].shape_view().At(1);\n      const int64_t res_h = out_img_buf[i].shape_view().At(0);\n      const int64_t res_w = out_img_buf[i].shape_view().At(1);\n\n      scale_buf[i].Resize(Shape({2}), DataType::kFloat);\n      scale_buf[i].mut_data<float>()[0] = static_cast<float>(res_w) / static_cast<float>(org_w);\n      scale_buf[i].mut_data<float>()[1] = static_cast<float>(res_h) / static_cast<float>(org_h);\n\n      size_buf[i].Resize(Shape({2}), DataType::kInt32);\n      size_buf[i].mut_data<int32_t>()[0] = static_cast<int32_t>(res_w);\n      size_buf[i].mut_data<int32_t>()[1] = static_cast<int32_t>(res_h);\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n#define REGISTER_IMAGE_RESIZE_KERNEL(dtype)                                          \\\n  REGISTER_USER_KERNEL(\"image_resize_to_fixed\")                                      \\\n      .SetCreateFn<ImageResizeToFixedSizeKernel>()                                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == 
DeviceType::kCPU)                \\\n                       && (user_op::HobDataType(\"in\", 0) == DataType::kTensorBuffer) \\\n                       && (user_op::HobAttr<DataType>(\"data_type\") == GetDataType<dtype>::value));\n\nREGISTER_IMAGE_RESIZE_KERNEL(float)\nREGISTER_IMAGE_RESIZE_KERNEL(uint8_t)\n\nREGISTER_USER_KERNEL(\"image_resize_keep_aspect_ratio\")\n    .SetCreateFn<ImageResizeKeepAspectRatioKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"in\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"size\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"scale\", 0) == DataType::kTensorBuffer));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/image_target_resize_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/user/image/image_util.h\"\n#include <opencv2/opencv.hpp>\n#include <cfenv>\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstd::pair<T, T> GetTargetResizedSize4ImageBuffer(const TensorBuffer& image_buffer,\n                                                 const T target_size, const T max_size) {\n  CHECK_EQ(image_buffer.shape_view().NumAxes(), 3);\n  const T origin_height = image_buffer.shape_view().At(0);\n  const T origin_width = image_buffer.shape_view().At(1);\n\n  // set round to banker's rounding\n  int origin_round_way = std::fegetround();\n  CHECK_EQ(std::fesetround(FE_TONEAREST), 0);\n\n  double origin_min_size = std::min<double>(origin_height, origin_width);\n  double origin_max_size = std::max<double>(origin_height, origin_width);\n  double resized_min_size = static_cast<double>(target_size);\n  double resized_max_size = std::nearbyint((origin_max_size / origin_min_size) * resized_min_size);\n  if (resized_max_size > max_size) {\n    resized_max_size = static_cast<double>(max_size);\n    resized_min_size = std::nearbyint(resized_max_size * origin_min_size / origin_max_size);\n  }\n\n  std::pair<T, T> height_and_width;\n  if (origin_width < origin_height) {\n    height_and_width.second = resized_min_size;\n    height_and_width.first = resized_max_size;\n  } else {\n    height_and_width.first = resized_min_size;\n    height_and_width.second = resized_max_size;\n  }\n  std::fesetround(origin_round_way);\n  return height_and_width;\n}\n\nvoid ImageTargetResize(const TensorBuffer& image_buffer, TensorBuffer* resized_image_buffer,\n                       const int32_t target_size, const int32_t max_size) {\n  CHECK_EQ(image_buffer.shape_view().NumAxes(), 3);\n  CHECK_GT(target_size, 0);\n  CHECK_GE(max_size, target_size);\n\n  cv::Mat image_mat = GenCvMat4ImageBuffer(image_buffer);\n  int64_t res_h = 0;\n  int64_t res_w = 0;\n  int64_t channels = image_mat.channels();\n  std::tie(res_h, res_w) =\n      GetTargetResizedSize4ImageBuffer<int64_t>(image_buffer, target_size, max_size);\n  resized_image_buffer->Resize(Shape({res_h, res_w, channels}), image_buffer.data_type());\n  cv::Mat res_image_mat = GenCvMat4ImageBuffer(*resized_image_buffer);\n  cv::resize(image_mat, res_image_mat, cv::Size(res_w, res_h), 0, 0, cv::INTER_LINEAR);\n\n  CHECK_EQ(res_image_mat.ptr(), resized_image_buffer->data());\n  CHECK_LE(std::max(res_image_mat.rows, res_image_mat.cols), max_size);\n  CHECK(std::max(res_image_mat.rows, res_image_mat.cols) == max_size\n        || std::min(res_image_mat.rows, res_image_mat.cols) == target_size);\n}\n\n}  // namespace\n\nclass ImageTargetResizeKernel final : public user_op::OpKernel {\n public:\n  ImageTargetResizeKernel() = default;\n  ~ImageTargetResizeKernel() = default;\n\n private:\n  void 
Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* size_tensor = ctx->Tensor4ArgNameAndIndex(\"size\", 0);\n    user_op::Tensor* scale_tensor = ctx->Tensor4ArgNameAndIndex(\"scale\", 0);\n    CHECK_GT(in_tensor->shape_view().elem_cnt(), 0);\n    CHECK_EQ(in_tensor->shape_view().elem_cnt(), out_tensor->shape_view().elem_cnt());\n    CHECK_EQ(in_tensor->shape_view().elem_cnt(), size_tensor->shape_view().At(0));\n    CHECK_EQ(in_tensor->shape_view().elem_cnt(), scale_tensor->shape_view().At(0));\n\n    const TensorBuffer* in_img_buf = in_tensor->dptr<TensorBuffer>();\n    TensorBuffer* out_img_buf = out_tensor->mut_dptr<TensorBuffer>();\n    int32_t* size_ptr = size_tensor ? size_tensor->mut_dptr<int32_t>() : nullptr;\n    float* scale_ptr = scale_tensor ? scale_tensor->mut_dptr<float>() : nullptr;\n    const int32_t target_size = ctx->Attr<int32_t>(\"target_size\");\n    const int32_t max_size = ctx->Attr<int32_t>(\"max_size\");\n\n    MultiThreadLoop(in_tensor->shape_view().elem_cnt(), [&](size_t i) {\n      ImageTargetResize(in_img_buf[i], out_img_buf + i, target_size, max_size);\n      if (size_ptr != nullptr) {\n        size_ptr[i * 2 + 0] = out_img_buf[i].shape_view().At(0);\n        size_ptr[i * 2 + 1] = out_img_buf[i].shape_view().At(1);\n      }\n      if (scale_ptr != nullptr) {\n        scale_ptr[i * 2 + 0] = static_cast<float>(out_img_buf[i].shape_view().At(0))\n                               / static_cast<float>(in_img_buf[i].shape_view().At(0));\n        scale_ptr[i * 2 + 1] = static_cast<float>(out_img_buf[i].shape_view().At(1))\n                               / static_cast<float>(in_img_buf[i].shape_view().At(1));\n      }\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"image_target_resize\")\n    .SetCreateFn<ImageTargetResizeKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"in\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"size\", 0) == DataType::kInt32)\n                     && (user_op::HobDataType(\"scale\", 0) == DataType::kFloat));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/in_top_k_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/in_top_k_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nclass InTopkKernel final : public user_op::OpKernel {\n public:\n  InTopkKernel() = default;\n  ~InTopkKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* targets = ctx->Tensor4ArgNameAndIndex(\"targets\", 0);\n    const user_op::Tensor* predictions = ctx->Tensor4ArgNameAndIndex(\"predictions\", 0);\n    const int32_t k = ctx->Attr<int32_t>(\"k\");\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(targets->shape_view().At(0), predictions->shape_view().At(0));\n    CHECK_EQ(targets->shape_view().NumAxes(), 1);\n    CHECK_EQ(predictions->shape_view().NumAxes(), 2);\n    const int32_t instance_num = predictions->shape_view().At(0);\n    const int32_t classes_num = predictions->shape_view().At(1);\n    InTopkKernelUtil<device_type, T>::InTopk(ctx->stream(), instance_num, classes_num,\n                                             targets->dptr<T>(), predictions->dptr<float>(), k,\n                                             out->mut_dptr<bool>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_IN_TOP_K_KERNEL(device, target_dtype_pair)                     \\\n  REGISTER_USER_KERNEL(\"in_top_k\")                                              \\\n      .SetCreateFn<InTopkKernel<device, OF_PP_PAIR_FIRST(target_dtype_pair)>>() \\\n      .SetIsMatchedHob(                                                         \\\n          (user_op::HobDeviceType() == device)                                  \\\n          && (user_op::HobDataType(\"targets\", 0) == OF_PP_PAIR_SECOND(target_dtype_pair)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_IN_TOP_K_KERNEL, DEVICE_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n#undef REGISTER_IN_TOP_K_KERNEL\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/in_top_k_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/in_top_k_kernel_util.h\"\n#include \"oneflow/core/common/data_type_seq.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstruct InTopkKernelUtil<DeviceType::kCPU, T> {\n  static void InTopk(ep::Stream* stream, const int instance_num, const int classes_num,\n                     const T* targets, const float* predictions, const int k, bool* out) {\n    FOR_RANGE(int32_t, idx, 0, instance_num) {\n      T target = targets[idx];\n      bool cannot_say =\n          (target >= classes_num) || !std::isfinite(predictions[idx * classes_num + target]);\n      int32_t more_probable_classes = 0;\n      if (!cannot_say) {\n        const float target_prediction = predictions[idx * classes_num + target];\n        FOR_RANGE(int32_t, class_idx, 0, classes_num) {\n          float pred = predictions[idx * classes_num + class_idx];\n\n          if (!std::isfinite(pred)) {\n            cannot_say = true;\n            break;\n          } else if (pred > target_prediction) {\n            ++more_probable_classes;\n            if (more_probable_classes > k) break;\n          }\n        }\n      }\n      out[idx] = cannot_say ? false : (more_probable_classes < k);\n    }\n  }\n};\n\n#define INSTANTIATE_IN_TOP_K_KERNEL_UTIL_CPU(cpp_data_type, data_type) \\\n  template struct InTopkKernelUtil<DeviceType::kCPU, cpp_data_type>;\n\nOF_PP_FOR_EACH_TUPLE(INSTANTIATE_IN_TOP_K_KERNEL_UTIL_CPU, INDEX_DATA_TYPE_SEQ)\n\n#undef INSTANTIATE_IN_TOP_K_KERNEL_UTIL_CPU\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/in_top_k_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/in_top_k_kernel_util.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__global__ void InTopkGpu(const int instance_num, const int classes_num, const T* targets,\n                          const float* predictions, const int k, bool* out) {\n  CUDA_1D_KERNEL_LOOP(idx, instance_num) {\n    T target = targets[idx];\n    bool cannot_say = (target >= classes_num) || !isfinite(predictions[idx * classes_num + target]);\n\n    int32_t more_probable_classes = 0;\n    if (!cannot_say) {\n      const float target_prediction = predictions[idx * classes_num + target];\n      FOR_RANGE(int32_t, class_idx, 0, classes_num) {\n        float pred = predictions[idx * classes_num + class_idx];\n\n        if (!isfinite(pred)) {\n          cannot_say = true;\n          break;\n        } else if (pred > target_prediction) {\n          ++more_probable_classes;\n          if (more_probable_classes > k) break;\n        }\n      }\n    }\n    out[idx] = cannot_say ? false : (more_probable_classes < k);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nstruct InTopkKernelUtil<DeviceType::kCUDA, T> {\n  static void InTopk(ep::Stream* stream, const int instance_num, const int classes_num,\n                     const T* targets, const float* predictions, const int k, bool* out) {\n    RUN_CUDA_KERNEL((InTopkGpu<T>), stream, instance_num, instance_num, classes_num, targets,\n                    predictions, k, out);\n  }\n};\n\n#define INSTANTIATE_IN_TOP_K_KERNEL_UTIL_CUDA(cpp_data_type, data_type) \\\n  template struct InTopkKernelUtil<DeviceType::kCUDA, cpp_data_type>;\n\nOF_PP_FOR_EACH_TUPLE(INSTANTIATE_IN_TOP_K_KERNEL_UTIL_CUDA, INDEX_DATA_TYPE_SEQ)\n\n#undef INSTANTIATE_IN_TOP_K_KERNEL_UTIL_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/in_top_k_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_IN_TOP_K_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_IN_TOP_K_KERNEL_UTIL_H_\n\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nstruct InTopkKernelUtil {\n  static void InTopk(ep::Stream* stream, const int instance_num, const int classes_num,\n                     const T* targets, const float* predictions, const int k, bool* out);\n};\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_IN_TOP_K_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/index_add_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\ntemplate<typename T, typename IndexT>\nvoid index_add_cpu_kernel(const int64_t n, const T* input, const IndexT* index, const T* source,\n                          T* output, const int64_t stride, const int64_t source_dim,\n                          const int64_t delta, const float alpha) {\n  const int64_t stride_source_dim = stride * source_dim;\n  for (int i = 0; i < n; i++) {\n    int64_t pre_index = i / stride_source_dim;\n    int64_t dim_index = (i - pre_index * stride_source_dim) / stride;\n    IndexT source_dim_idx = index[dim_index];\n    int64_t output_index = i + (delta * pre_index + source_dim_idx - dim_index) * stride;\n    output[output_index] += static_cast<T>(alpha) * source[i];\n  }\n}\n};  // namespace\n\ntemplate<typename T>\nclass IndexAddCpuKernel final : public user_op::OpKernel {\n public:\n  IndexAddCpuKernel() = default;\n  ~IndexAddCpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const user_op::Tensor* index = ctx->Tensor4ArgNameAndIndex(\"index\", 0);\n    const user_op::Tensor* source = ctx->Tensor4ArgNameAndIndex(\"source\", 0);\n    user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    const int64_t dim = ctx->Attr<int64_t>(\"dim\");\n    const float alpha = ctx->Attr<float>(\"alpha\");\n    const ShapeView& input_shape = input->shape_view();\n    const ShapeView& source_shape = source->shape_view();\n    std::vector<int64_t> input_stride(input->stride().begin(), input->stride().end());\n    const int64_t stride = input_stride[dim];\n    const int64_t source_dim = source_shape.At(dim);\n    const int64_t delta = input_shape.At(dim) - source_dim;\n    DataType index_dtype = index->data_type();\n    const int32_t n = source->shape_view().elem_cnt();\n    Memcpy<DeviceType::kCPU>(\n        ctx->stream(), output->mut_dptr<void>(), input->dptr<void>(),\n        input->shape_view().elem_cnt() * GetSizeOfDataType(input->data_type()));\n    if (GetSizeOfDataType(index_dtype) == 4) {\n      index_add_cpu_kernel(n, input->dptr<T>(), index->dptr<int32_t>(), source->dptr<T>(),\n                           output->mut_dptr<T>(), stride, source_dim, delta, alpha);\n    } else {\n      index_add_cpu_kernel(n, input->dptr<T>(), index->dptr<int64_t>(), source->dptr<T>(),\n                           output->mut_dptr<T>(), stride, source_dim, delta, alpha);\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_INDEX_ADD_CPU_KERNEL(dtype)                          \\\n  REGISTER_USER_KERNEL(\"index_add\")                                   \\\n      
.SetCreateFn<IndexAddCpuKernel<dtype>>()                        \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"output\", 0) == GetDataType<dtype>::value));\n\nREGISTER_INDEX_ADD_CPU_KERNEL(int8_t)\nREGISTER_INDEX_ADD_CPU_KERNEL(int32_t)\nREGISTER_INDEX_ADD_CPU_KERNEL(float)\nREGISTER_INDEX_ADD_CPU_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/index_add_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\ntemplate<typename T, typename IndexT>\n__global__ void index_add_cuda_kernel(const int64_t n, const T* input, const IndexT* index,\n                                      const T* source, T* output, const int64_t stride,\n                                      const int64_t source_dim, const int64_t delta,\n                                      const float alpha) {\n  // For x = flow.ones(5, 3)\n  // source = flow.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=flow.float)\n  // index = flow.tensor([0, 4, 2])\n  // dim = 0\n  // We have:\n  // stride = 3\n  // source_dim = 3\n  // stride * source_dim = 9\n  // alpha = 1.0\n  // delta = 5 - 3 = 2\n\n  // For i = 8\n  // pre_index = i / stride_source_dim = 8 / 9 = 0\n  // dim_index = i % stride_source_dim / stride = 8 % 9 / 3 = 0\n  // source_dim_idx = index[dim_index] = index[0] = 0\n  // output_index = i + (delta * pre_index + source_dim_idx - dim_index) * stride = 9 + (2 * 0 + 0 -\n  // 0) * 3 = 9 cuda::atomic::Add(output + output_index, static_cast<T>(alpha) * source[i])=>\n  // output[9] += 1.0 * 9 = 10.0\n  const int64_t stride_source_dim = stride * source_dim;\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    int64_t pre_index = i / stride_source_dim;\n    int64_t dim_index = (i - pre_index * stride_source_dim) / stride;\n    IndexT source_dim_idx = index[dim_index];\n    int64_t output_index = i + (delta * pre_index + source_dim_idx - dim_index) * stride;\n    cuda::atomic::Add(output + output_index, static_cast<T>(alpha) * source[i]);\n  }\n}\n};  // namespace\n\ntemplate<typename T>\nclass IndexAddGpuKernel final : public user_op::OpKernel {\n public:\n  IndexAddGpuKernel() = default;\n  ~IndexAddGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const user_op::Tensor* index = ctx->Tensor4ArgNameAndIndex(\"index\", 0);\n    const user_op::Tensor* source = ctx->Tensor4ArgNameAndIndex(\"source\", 0);\n    user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    const int64_t dim = ctx->Attr<int64_t>(\"dim\");\n    const float alpha = ctx->Attr<float>(\"alpha\");\n    const ShapeView& input_shape = input->shape_view();\n    const ShapeView& source_shape = source->shape_view();\n    std::vector<int64_t> input_stride(input->stride().begin(), input->stride().end());\n    const int64_t stride = input_stride[dim];\n    const int64_t source_dim = source_shape.At(dim);\n    const int64_t delta = input_shape.At(dim) - source_dim;\n    DataType 
index_dtype = index->data_type();\n    const int32_t n = source->shape_view().elem_cnt();\n    Memcpy<DeviceType::kCUDA>(\n        ctx->stream(), output->mut_dptr<void>(), input->dptr<void>(),\n        input->shape_view().elem_cnt() * GetSizeOfDataType(input->data_type()));\n    if (GetSizeOfDataType(index_dtype) == 4) {\n      RUN_CUDA_KERNEL((index_add_cuda_kernel<T, int32_t>), ctx->stream(), n, n, input->dptr<T>(),\n                      index->dptr<int32_t>(), source->dptr<T>(), output->mut_dptr<T>(), stride,\n                      source_dim, delta, alpha);\n    } else {\n      RUN_CUDA_KERNEL((index_add_cuda_kernel<T, int64_t>), ctx->stream(), n, n, input->dptr<T>(),\n                      index->dptr<int64_t>(), source->dptr<T>(), output->mut_dptr<T>(), stride,\n                      source_dim, delta, alpha);\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_INDEX_ADD_CUDA_KERNEL(dtype)                          \\\n  REGISTER_USER_KERNEL(\"index_add\")                                    \\\n      .SetCreateFn<IndexAddGpuKernel<dtype>>()                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"output\", 0) == GetDataType<dtype>::value));\n\nREGISTER_INDEX_ADD_CUDA_KERNEL(float)\nREGISTER_INDEX_ADD_CUDA_KERNEL(half)\nREGISTER_INDEX_ADD_CUDA_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/indexed_slices_reduce_sum_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/indexed_slices_reduce_sum_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass IndexedSlicesReduceSumKernel final : public user_op::OpKernel {\n public:\n  IndexedSlicesReduceSumKernel() = default;\n  ~IndexedSlicesReduceSumKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_indices = ctx->Tensor4ArgNameAndIndex(\"x_indices\", 0);\n    const user_op::Tensor* x_values = ctx->Tensor4ArgNameAndIndex(\"x_values\", 0);\n    user_op::Tensor* y_indices = ctx->Tensor4ArgNameAndIndex(\"y_indices\", 0);\n    user_op::Tensor* y_values = ctx->Tensor4ArgNameAndIndex(\"y_values\", 0);\n    user_op::Tensor* num_unique = ctx->Tensor4ArgNameAndIndex(\"num_unique\", 0);\n    user_op::Tensor* tmp = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    void* tmp_ptr = tmp ? tmp->mut_dptr() : nullptr;\n    int64_t tmp_size = tmp ? tmp->shape_view().elem_cnt() * GetSizeOfDataType(tmp->data_type()) : 0;\n    const int64_t n = x_indices->shape_view().elem_cnt();\n    const int64_t m = x_values->shape_view().elem_cnt() / n;\n    IndexedSlicesReduceSumKernelUtil<device_type, K, T, int64_t>::ReduceSum(\n        ctx->stream(), n, m, x_indices->dptr<K>(), x_values->dptr<T>(),\n        num_unique->mut_dptr<int64_t>(), y_indices->mut_dptr<K>(), y_values->mut_dptr<T>(), tmp_ptr,\n        tmp_size);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T, typename K>\nuser_op::InferTmpSizeFn GenInferTmpSizeFn() {\n  return [](user_op::InferContext* ctx) {\n    const auto& x_indices = ctx->InputTensorDesc(\"x_indices\", 0);\n    const auto& x_values = ctx->InputTensorDesc(\"x_values\", 0);\n    const int64_t n = x_indices.shape().elem_cnt();\n    const int64_t m = x_values.shape().elem_cnt() / n;\n    int64_t workspace_size_in_bytes;\n    IndexedSlicesReduceSumKernelUtil<device_type, K, T, int64_t>::GetReduceSumWorkspaceSizeInBytes(\n        nullptr, n, m, &workspace_size_in_bytes);\n    return workspace_size_in_bytes;\n  };\n}\n\n#define REGISTER_INDEXED_SLICES_REDUCE_SUM_KERNEL(device_type_v, data_type_pair,                 \\\n                                                  indices_type_pair)                             \\\n  REGISTER_USER_KERNEL(\"indexed_slices_reduce_sum\")                                              \\\n      .SetCreateFn<IndexedSlicesReduceSumKernel<device_type_v, OF_PP_PAIR_FIRST(data_type_pair), \\\n                                                OF_PP_PAIR_FIRST(indices_type_pair)>>()          \\\n      .SetIsMatchedHob(                                                                          \\\n          (user_op::HobDeviceType() == device_type_v)                                    
        \\\n          && (user_op::HobDataType(\"x_values\", 0) == OF_PP_PAIR_SECOND(data_type_pair))          \\\n          && (user_op::HobDataType(\"x_indices\", 0) == OF_PP_PAIR_SECOND(indices_type_pair)))     \\\n      .SetInferTmpSizeFn(GenInferTmpSizeFn<device_type_v, OF_PP_PAIR_FIRST(data_type_pair),      \\\n                                           OF_PP_PAIR_FIRST(indices_type_pair)>());\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_INDEXED_SLICES_REDUCE_SUM_KERNEL, DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/indexed_slices_reduce_sum_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/indexed_slices_reduce_sum_kernel_util.h\"\n#include \"oneflow/user/kernels/unique_kernel_util.h\"\n#include \"oneflow/user/kernels/unsorted_segment_sum_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename IDX>\nint64_t GetUniqueIdxSize(int64_t n) {\n  return GetCudaAlignedSize(n * sizeof(IDX));\n}\n\ntemplate<DeviceType device_type, typename K, typename T, typename IDX>\nvoid IndexedSlicesReduceSumKernelUtil<device_type, K, T, IDX>::ReduceSum(\n    ep::Stream* stream, int64_t n, int64_t m, const K* indices, const T* values,\n    IDX* num_unique_indices, K* indices_out, T* values_out, void* workspace,\n    int64_t workspace_size_in_bytes) {\n  const int64_t unique_idx_size = GetUniqueIdxSize<IDX>(n);\n  CHECK_LE(unique_idx_size, workspace_size_in_bytes);\n  IDX* unique_idx_ptr = reinterpret_cast<IDX*>(workspace);\n  void* unique_workspace_ptr = reinterpret_cast<unsigned char*>(workspace) + unique_idx_size;\n  const int64_t unique_workspace_size = workspace_size_in_bytes - unique_idx_size;\n  UniqueKernelUtil<device_type, K, IDX>::Unique(stream, n, indices, num_unique_indices, indices_out,\n                                                unique_idx_ptr, unique_workspace_ptr,\n                                                unique_workspace_size, /*sorted*/ false);\n  const Shape flat_in_shape({1, n, m});\n  Memset<device_type>(stream, values_out, 0, n * m * sizeof(T));\n\n  UnsortedSegmentSumKernelUtil<device_type, T, IDX, T>::UnsortedSegmentSum(\n      stream, unique_idx_ptr, values, n, n, 1, m, 0, values_out);\n}\n\ntemplate<DeviceType device_type, typename K, typename T, typename IDX>\nvoid IndexedSlicesReduceSumKernelUtil<device_type, K, T, IDX>::GetReduceSumWorkspaceSizeInBytes(\n    ep::Stream* stream, int64_t n, int64_t m, int64_t* workspace_size_in_bytes) {\n  int64_t unique_workspace_size;\n  UniqueKernelUtil<device_type, K, int64_t>::GetUniqueWorkspaceSizeInBytes(stream, n,\n                                                                           &unique_workspace_size);\n  *workspace_size_in_bytes = GetUniqueIdxSize<IDX>(n) + unique_workspace_size;\n}\n\n#define INSTANTIATE_INDEXED_SLICES_REDUCE_SUM_KERNEL_UTIL(device_type, key_type_pair,            \\\n                                                          val_type_pair, idx_type_pair)          \\\n  template struct IndexedSlicesReduceSumKernelUtil<device_type, OF_PP_PAIR_FIRST(key_type_pair), \\\n                                                   OF_PP_PAIR_FIRST(val_type_pair),              \\\n                                                   OF_PP_PAIR_FIRST(idx_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_INDEXED_SLICES_REDUCE_SUM_KERNEL_UTIL, DEVICE_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ);\n#undef INSTANTIATE_INDEXED_SLICES_REDUCE_SUM_KERNEL_UTIL\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/indexed_slices_reduce_sum_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_INDEXED_SLICES_REDUCE_SUM_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_INDEXED_SLICES_REDUCE_SUM_KERNEL_UTIL_H_\n\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename K, typename T, typename IDX>\nstruct IndexedSlicesReduceSumKernelUtil {\n  static void ReduceSum(ep::Stream* stream, int64_t n, int64_t m, const K* indices, const T* values,\n                        IDX* num_unique_indices, K* indices_out, T* values_out, void* workspace,\n                        int64_t workspace_size_in_bytes);\n  static void GetReduceSumWorkspaceSizeInBytes(ep::Stream* stream, int64_t n, int64_t m,\n                                               int64_t* workspace_size_in_bytes);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_INDEXED_SLICES_REDUCE_SUM_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/inv_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/eigen_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstatic inline size_t BatchCount(const user_op::Tensor* batched_matrices) {\n  size_t result = 1;\n  for (size_t i = 0; i < batched_matrices->shape_view().NumAxes() - 2; i++) {\n    result *= batched_matrices->shape_view().At(i);\n  }\n  return result;\n}\n\nstatic inline size_t MatrixStride(const user_op::Tensor* batched_matrices) {\n  const int64_t num_axes = batched_matrices->shape_view().NumAxes();\n  return batched_matrices->shape_view().At(num_axes - 2)\n         * batched_matrices->shape_view().At(num_axes - 1);\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass CpuInvKernel final : public user_op::OpKernel {\n public:\n  CpuInvKernel() = default;\n  ~CpuInvKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    auto batch_count = BatchCount(x);\n    auto matrix_stride = MatrixStride(x);\n    auto matrix_size = x->shape_view().At(x->shape_view().NumAxes() - 2);\n    const T* x_ptr = x->dptr<T>();\n    T* y_ptr = y->mut_dptr<T>();\n    FOR_RANGE(int64_t, i, 0, batch_count) {\n      ConstEigenMatrixMap<T> x_mat(x_ptr + i * matrix_stride, matrix_size, matrix_size);\n      EigenMatrixMap<T> y_mat(y_ptr + i * matrix_stride, matrix_size, matrix_size);\n      if (x_mat.determinant() == 0) {\n        LOG(FATAL)\n            << \"(Batch element \" << i\n            << \"): the inversion could not be completed because the input matrix is singular.\";\n      }\n      y_mat = x_mat.inverse();\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_INV_KERNEL(dtype)                                            \\\n  REGISTER_USER_KERNEL(\"inv\").SetCreateFn<CpuInvKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCPU)                              \\\n      && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CPU_INV_KERNEL(float)\nREGISTER_CPU_INV_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/inv_kernels.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/user/kernels/arange_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstatic inline size_t BatchCount(const user_op::Tensor* batched_matrices) {\n  size_t result = 1;\n  for (size_t i = 0; i < batched_matrices->shape_view().NumAxes() - 2; i++) {\n    result *= batched_matrices->shape_view().At(i);\n  }\n  return result;\n}\n\nstatic inline size_t MatrixStride(const user_op::Tensor* batched_matrices) {\n  const int64_t num_axes = batched_matrices->shape_view().NumAxes();\n  return batched_matrices->shape_view().At(num_axes - 2)\n         * batched_matrices->shape_view().At(num_axes - 1);\n}\n\nvoid OFgetrfBatched(ep::Stream* stream, int n, float** dA_array, int ldda, int* ipiv_array,\n                    int* info_array, int batchsize) {\n  OF_CUBLAS_CHECK(cublasSgetrfBatched(stream->As<ep::CudaStream>()->cublas_handle(), n, dA_array,\n                                      ldda, ipiv_array, info_array, batchsize));\n}\nvoid OFgetrfBatched(ep::Stream* stream, int n, double** dA_array, int ldda, int* ipiv_array,\n                    int* info_array, int batchsize) {\n  OF_CUBLAS_CHECK(cublasDgetrfBatched(stream->As<ep::CudaStream>()->cublas_handle(), n, dA_array,\n                                      ldda, ipiv_array, info_array, batchsize));\n}\nvoid OFgetriBatched(ep::Stream* stream, int n, float** dA_array, int ldda, int* ipiv_array,\n                    float** dC_array, int lddc, int* info_array, int batchsize) {\n  OF_CUBLAS_CHECK(cublasSgetriBatched(stream->As<ep::CudaStream>()->cublas_handle(), n, dA_array,\n                                      ldda, ipiv_array, dC_array, lddc, info_array, batchsize));\n}\nvoid OFgetriBatched(ep::Stream* stream, int n, double** dA_array, int ldda, int* ipiv_array,\n                    double** dC_array, int lddc, int* info_array, int batchsize) {\n  OF_CUBLAS_CHECK(cublasDgetriBatched(stream->As<ep::CudaStream>()->cublas_handle(), n, dA_array,\n                                      ldda, ipiv_array, dC_array, lddc, info_array, batchsize));\n}\n\n}  // namespace\n\nnamespace user_op {\n\ntemplate<typename T>\nclass CudaInvKernel final : public user_op::OpKernel {\n public:\n  CudaInvKernel() = default;\n  ~CudaInvKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    auto batch_count = BatchCount(x);\n    auto matrix_stride = MatrixStride(x);\n    auto matrix_size = x->shape_view().At(x->shape_view().NumAxes() - 2);\n\n    const ShapeView& x_shape = x->shape_view();\n    const int64_t instance_num = 
x_shape.Count(0, x_shape.NumAxes() - 2);\n    const int64_t infos_bytes = GetCudaAlignedSize(instance_num * sizeof(int));\n    const int64_t ipiv_bytes =\n        GetCudaAlignedSize(batch_count * x_shape.At(x_shape.NumAxes() - 2) * sizeof(int));\n    const int64_t pptr_bytes = GetCudaAlignedSize(batch_count * sizeof(T*));\n    int* infos_getrf_ptr = tmp_buffer->mut_dptr<int>();\n    int* infos_getrs_ptr =\n        reinterpret_cast<int*>(reinterpret_cast<char*>(infos_getrf_ptr) + infos_bytes);\n    int* ipiv_ptr = reinterpret_cast<int*>(reinterpret_cast<char*>(infos_getrs_ptr) + infos_bytes);\n    T** x_pptr = reinterpret_cast<T**>(reinterpret_cast<char*>(ipiv_ptr) + ipiv_bytes);\n    T** y_pptr = reinterpret_cast<T**>(reinterpret_cast<char*>(x_pptr) + pptr_bytes);\n    T* x_copy_ptr = reinterpret_cast<T*>(reinterpret_cast<char*>(y_pptr) + pptr_bytes);\n    Memcpy<DeviceType::kCUDA>(ctx->stream(), x_copy_ptr, x->dptr<T>(),\n                              x_shape.elem_cnt() * sizeof(T));\n    ArangeFunctor<DeviceType::kCUDA, int64_t>()(ctx->stream(),\n                                                reinterpret_cast<int64_t>(x_copy_ptr),\n                                                static_cast<int64_t>(matrix_stride * sizeof(T)),\n                                                batch_count, reinterpret_cast<int64_t*>(x_pptr));\n    ArangeFunctor<DeviceType::kCUDA, int64_t>()(ctx->stream(),\n                                                reinterpret_cast<int64_t>(y->mut_dptr<T>()),\n                                                static_cast<int64_t>(matrix_stride * sizeof(T)),\n                                                batch_count, reinterpret_cast<int64_t*>(y_pptr));\n    Memset<DeviceType::kCUDA>(ctx->stream(), infos_getrf_ptr, 0, infos_bytes);\n    Memset<DeviceType::kCUDA>(ctx->stream(), infos_getrs_ptr, 0, infos_bytes);\n    Memset<DeviceType::kCUDA>(ctx->stream(), ipiv_ptr, 0, ipiv_bytes);\n    OFgetrfBatched(ctx->stream(), matrix_size, x_pptr, matrix_size, ipiv_ptr, infos_getrf_ptr,\n                   batch_count);\n    OFgetriBatched(ctx->stream(), matrix_size, x_pptr, matrix_size, ipiv_ptr, y_pptr, matrix_size,\n                   infos_getrs_ptr, batch_count);\n    std::vector<int> infos_getrf_vec_host(batch_count, 0);\n    std::vector<int> infos_getrs_vec_host(batch_count, 0);\n    OF_CUDA_CHECK(cudaMemcpyAsync(infos_getrf_vec_host.data(), infos_getrf_ptr,\n                                  batch_count * sizeof(int), cudaMemcpyDefault,\n                                  ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n    OF_CUDA_CHECK(cudaMemcpyAsync(infos_getrs_vec_host.data(), infos_getrs_ptr,\n                                  batch_count * sizeof(int), cudaMemcpyDefault,\n                                  ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n    CHECK_JUST(ctx->stream()->Sync());\n    FOR_RANGE(int64_t, i, 0, batch_count) {\n      if (infos_getrf_vec_host[i] > 0) {\n        LOG(FATAL) << \"(Batch element \" << i << \"): The diagonal element \"\n                   << infos_getrf_vec_host[i]\n                   << \" is zero, the inversion could not be completed because the input matrix is \"\n                      \"singular.\";\n      }\n    }\n    FOR_RANGE(int64_t, i, 0, batch_count) {\n      if (infos_getrs_vec_host[i] > 0) {\n        LOG(FATAL) << \"(Batch element \" << i << \"): The diagonal element \"\n                   << infos_getrs_vec_host[i]\n                   << \" is zero, the inversion could not be completed because the input 
matrix is \"\n                      \"singular.\";\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_INV_KERNEL(dtype)                                                       \\\n  REGISTER_USER_KERNEL(\"inv\")                                                                 \\\n      .SetCreateFn<CudaInvKernel<dtype>>()                                                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                        \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value))        \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                     \\\n        const Shape& x_shape = ctx->InputShape(\"x\", 0);                                       \\\n        auto batch_size = x_shape.Count(0, x_shape.NumAxes() - 2);                            \\\n        const int64_t instance_num = x_shape.Count(0, x_shape.NumAxes() - 2);                 \\\n        const int64_t infos_bytes = GetCudaAlignedSize(instance_num * sizeof(int));           \\\n        const int64_t ipiv_bytes =                                                            \\\n            GetCudaAlignedSize(batch_size * x_shape.At(x_shape.NumAxes() - 2) * sizeof(int)); \\\n        const int64_t pptr_bytes = GetCudaAlignedSize(batch_size * sizeof(dtype*));           \\\n        const int64_t x_copy_bytes = GetCudaAlignedSize(x_shape.elem_cnt() * sizeof(dtype));  \\\n        return infos_bytes * 2 + ipiv_bytes + pptr_bytes * 2 + x_copy_bytes;                  \\\n      });\n\nREGISTER_CUDA_INV_KERNEL(float)\nREGISTER_CUDA_INV_KERNEL(double)\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/kl_div_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.cuh\"\n#include \"oneflow/user/kernels/loss_kernel_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\nnamespace {\n\nusing namespace loss;\n\ntemplate<typename T>\nvoid ComputeKLDivOut(int64_t elem_cnt, const T* input, const T* target, T* out,\n                     const bool log_target) {\n  if (log_target) {\n    FOR_RANGE(int64_t, i, 0, elem_cnt) { out[i] = std::exp(target[i]) * (target[i] - input[i]); }\n  } else {\n    FOR_RANGE(int64_t, i, 0, elem_cnt) {\n      const auto out_val = target[i] * (SafeLog(target[i]) - input[i]);\n      out[i] = target[i] > 0 ? out_val : static_cast<T>(0);\n    }\n  }\n}\n\ntemplate<typename T>\nvoid ComputeKLDivGradOut(int64_t elem_cnt, const T* input, const T* target, const T* dy, T* dx,\n                         const bool log_target) {\n  FOR_RANGE(int64_t, i, 0, elem_cnt) {\n    const T dy_val = dy[i];\n    dx[i] =\n        log_target ? (-std::exp(target[i]) * dy_val) : (target[i] > 0 ? -target[i] * dy_val : 0);\n  }\n}\n\ntemplate<typename T>\nclass KLDivKernel : public SimpleLossKernel<DeviceType::kCPU, T, KLDivKernel<T>> {\n public:\n  void ComputeOut(user_op::KernelComputeContext* ctx, int64_t elem_cnt, const T* input,\n                  const T* target, T* out) const {\n    const bool log_target = ctx->Attr<bool>(\"log_target\");\n    ComputeKLDivOut(elem_cnt, input, target, out, log_target);\n  }\n};\n\ntemplate<typename T>\nclass KLDivGradKernel : public SimpleLossGradKernel<DeviceType::kCPU, T, KLDivGradKernel<T>> {\n public:\n  void ComputeOut(user_op::KernelComputeContext* ctx, int64_t elem_cnt, const T* input,\n                  const T* target, const T* dy, T* dx) const {\n    const bool log_target = ctx->Attr<bool>(\"log_target\");\n    ComputeKLDivGradOut(elem_cnt, input, target, dy, dx, log_target);\n  }\n};\n\n}  // namespace\n\nREGISTER_SIMPLE_LOSS_KERNEL_CPU(\"kl_div_loss\", KLDivKernel)\nREGISTER_SIMPLE_LOSS_GRAD_KERNEL_CPU(\"kl_div_loss_grad\", KLDivGradKernel)\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/kl_div_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.cuh\"\n#include \"oneflow/user/kernels/loss_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\nnamespace user_op {\nnamespace {\n\nusing namespace loss;\n\ntemplate<typename T, bool LOG_TARGET>\nstruct KLDivFunctor {\n  __device__ __forceinline__ T operator()(T input_val, T target_val) const {\n    if (LOG_TARGET) {\n      return exp(target_val) * (target_val - input_val);\n    } else {\n      const T zero_val = static_cast<T>(0);\n      const T out_val = target_val * (SafeLog(target_val) - input_val);\n      return target_val > zero_val ? out_val : zero_val;\n    }\n  }\n};\n\ntemplate<bool LOG_TARGET>\nstruct KLDivFunctor<half, LOG_TARGET> {\n  __device__ __forceinline__ half operator()(half input_val, half target_val) const {\n    if (LOG_TARGET) {\n      return hexp(target_val) * (target_val - input_val);\n    } else {\n      const half zero_val = __float2half(0.f);\n      const half out_val = target_val * (SafeLog(target_val) - input_val);\n      return target_val > zero_val ? out_val : zero_val;\n    }\n  }\n};\n\ntemplate<typename T, bool LOG_TARGET>\nstruct KLDivGradFunctor {\n  __device__ __forceinline__ T operator()(T target_val, T dy_val) const {\n    if (LOG_TARGET) {\n      return -exp(target_val) * dy_val;\n    } else {\n      const T zero_val = static_cast<T>(0);\n      return target_val > zero_val ? -target_val * dy_val : zero_val;\n    }\n  }\n};\n\ntemplate<bool LOG_TARGET>\nstruct KLDivGradFunctor<half, LOG_TARGET> {\n  __device__ __forceinline__ half operator()(half target_val, half dy_val) const {\n    if (LOG_TARGET) {\n      return __hneg(hexp(target_val) * dy_val);\n    } else {\n      const half zero_val = __float2half(0.f);\n      return target_val > zero_val ? 
__hneg(target_val * dy_val) : zero_val;\n    }\n  }\n};\n\ntemplate<typename T>\nclass KLDivKernel : public SimpleLossKernel<DeviceType::kCUDA, T, KLDivKernel<T>> {\n public:\n  void ComputeOut(user_op::KernelComputeContext* ctx, int64_t elem_cnt, const T* input,\n                  const T* target, T* out) const {\n    const bool log_target = ctx->Attr<bool>(\"log_target\");\n    if (log_target) {\n      OF_CUDA_CHECK(\n          (cuda::elementwise::Binary(KLDivFunctor<T, true>(), elem_cnt, out, input, target,\n                                     ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n    } else {\n      OF_CUDA_CHECK(\n          (cuda::elementwise::Binary(KLDivFunctor<T, false>(), elem_cnt, out, input, target,\n                                     ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n    }\n  }\n};\n\ntemplate<typename T>\nclass KLDivGradKernel : public SimpleLossGradKernel<DeviceType::kCUDA, T, KLDivGradKernel<T>> {\n public:\n  void ComputeOut(user_op::KernelComputeContext* ctx, int64_t elem_cnt, const T* input,\n                  const T* target, const T* dy, T* dx) const {\n    const bool log_target = ctx->Attr<bool>(\"log_target\");\n    if (log_target) {\n      OF_CUDA_CHECK((cuda::elementwise::Binary(\n          KLDivGradFunctor<T, /*LOG_TARGET*/ true>(), elem_cnt, dx, target, dy,\n          ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n    } else {\n      OF_CUDA_CHECK((cuda::elementwise::Binary(\n          KLDivGradFunctor<T, /*LOG_TARGET*/ false>(), elem_cnt, dx, target, dy,\n          ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n    }\n  }\n};\n\n}  // namespace\n\nREGISTER_SIMPLE_LOSS_KERNEL_CUDA(\"kl_div_loss\", KLDivKernel)\nREGISTER_SIMPLE_LOSS_GRAD_KERNEL_CUDA(\"kl_div_loss_grad\", KLDivGradKernel)\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/l1_l2_regularize_gradient_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/l1_l2_regularize_gradient_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<DeviceType device_type, typename T>\nclass L1L2RegularizeGradientKernel final : public user_op::OpKernel {\n public:\n  L1L2RegularizeGradientKernel() = default;\n  ~L1L2RegularizeGradientKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex(\"model\", 0);\n    const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex(\"model_diff\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const auto l1 = ctx->Attr<float>(\"l1\");\n    const auto l2 = ctx->Attr<float>(\"l2\");\n    L1L2RegularizeGradientKernelUtil<device_type, T>::RegularizeGradient(\n        ctx->stream(), out->shape_view().elem_cnt(), model->dptr<T>(), model_diff->dptr<T>(),\n        out->mut_dptr<T>(), l1, l2);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_L1_L2_REGULARIZE_GRADIENT_KERNEL(device, dtype)                                \\\n  REGISTER_USER_KERNEL(\"l1_l2_regularize_gradient\")                                             \\\n      .SetCreateFn<L1L2RegularizeGradientKernel<device, dtype>>()                               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                     \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value))        \\\n      .SetInplaceProposalFn([](const user_op::InferContext&,                                    \\\n                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \\\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"model_diff\", 0, true));               \\\n        return Maybe<void>::Ok();                                                               \\\n      });\n\nREGISTER_L1_L2_REGULARIZE_GRADIENT_KERNEL(DeviceType::kCPU, float)\nREGISTER_L1_L2_REGULARIZE_GRADIENT_KERNEL(DeviceType::kCPU, double)\n#ifdef WITH_CUDA\nREGISTER_L1_L2_REGULARIZE_GRADIENT_KERNEL(DeviceType::kCUDA, float)\nREGISTER_L1_L2_REGULARIZE_GRADIENT_KERNEL(DeviceType::kCUDA, double)\n#endif\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/l1_l2_regularize_gradient_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/l1_l2_regularize_gradient_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstruct L1L2RegularizeGradientKernelUtil<DeviceType::kCPU, T> {\n  static void RegularizeGradient(ep::Stream* stream, int64_t n, const T* model, const T* model_diff,\n                                 T* out, const T l1, const T l2) {\n    FOR_RANGE(int64_t, i, 0, n) {\n      const T model_val = model[i];\n      out[i] = model_diff[i] + l1 * (model_val >= 0 ? 1 : -1) + l2 * model_val;\n    }\n  }\n};\n\n#define INSTANTIATE_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_CPU(type_cpp, type_proto) \\\n  template struct L1L2RegularizeGradientKernelUtil<DeviceType::kCPU, type_cpp>;\nOF_PP_FOR_EACH_TUPLE(INSTANTIATE_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_CPU, FLOATING_DATA_TYPE_SEQ);\n#undef INSTANTIATE_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_CPU\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/l1_l2_regularize_gradient_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/l1_l2_regularize_gradient_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__global__ void L1L2RegularizeGradientGpu(int64_t n, const T* model, const T* model_diff, T* out,\n                                          const T l1, const T l2) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T model_val = model[i];\n    out[i] = model_diff[i] + l1 * ((model_val >= 0) - (model_val <= 0)) + l2 * model_val;\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nstruct L1L2RegularizeGradientKernelUtil<DeviceType::kCUDA, T> {\n  static void RegularizeGradient(ep::Stream* stream, int64_t n, const T* model, const T* model_diff,\n                                 T* out, const T l1, const T l2) {\n    L1L2RegularizeGradientGpu<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                                stream->As<ep::CudaStream>()->cuda_stream()>>>(n, model, model_diff,\n                                                                               out, l1, l2);\n  }\n};\n\n#define INSTANTIATE_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_CUDA(type_cpp, type_proto) \\\n  template struct L1L2RegularizeGradientKernelUtil<DeviceType::kCUDA, type_cpp>;\nOF_PP_FOR_EACH_TUPLE(INSTANTIATE_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_CUDA,\n                     FLOATING_DATA_TYPE_SEQ);\n#undef INSTANTIATE_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/l1_l2_regularize_gradient_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_H_\n\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nstruct L1L2RegularizeGradientKernelUtil {\n  static void RegularizeGradient(ep::Stream* stream, int64_t n, const T* model, const T* model_diff,\n                                 T* out, T l1, T l2);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_L1_L2_REGULARIZE_GRADIENT_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/l2_normalize_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstatic void L2NormalizeForward(const int32_t n, const int32_t c, const int32_t d, const T epsilon,\n                               const T* in, T* square_x_sum, T* out) {\n  for (int32_t i = 0; i < n; i++) {\n    const int32_t offset = (i / d) * d * c + (i % d);\n    for (int32_t j = 0; j < c; j++) {\n      const T x = in[offset + j * d];\n      square_x_sum[i] += x * x;\n    }\n    const T norm = std::sqrt(std::max(square_x_sum[i], epsilon));\n    for (int32_t j = 0; j < c; j++) {\n      const int32_t index = offset + j * d;\n      out[index] = in[index] / norm;\n    }\n  }\n}\n\ntemplate<typename T>\nstatic void L2NormalizeBackward(const int32_t n, const int32_t c, const int32_t d, const T epsilon,\n                                const T* out, const T* out_diff, const T* square_x_sum,\n                                T* in_diff) {\n  for (int32_t i = 0; i < n; i++) {\n    const T norm = std::sqrt(std::max(square_x_sum[i], epsilon));\n    const int32_t offset = (i / d) * d * c + (i % d);\n    if (square_x_sum[i] >= epsilon) {\n      T y_dy_inner_prod = GetZeroVal<T>();\n      for (int32_t j = 0; j < c; j++) {\n        const int32_t index = offset + j * d;\n        y_dy_inner_prod += out_diff[index] * out[index];\n      }\n      for (int32_t j = 0; j < c; j++) {\n        const int32_t index = offset + j * d;\n        in_diff[index] = (1 / norm) * (out_diff[index] - y_dy_inner_prod * out[index]);\n      }\n    } else {\n      for (int32_t j = 0; j < c; j++) {\n        const int32_t index = offset + j * d;\n        in_diff[index] = (1 / norm) * out_diff[index];\n      }\n    }\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass CpuL2NormalizeKernel final : public user_op::OpKernel {\n public:\n  CpuL2NormalizeKernel() = default;\n  ~CpuL2NormalizeKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* square_x_sum = ctx->Tensor4ArgNameAndIndex(\"square_x_sum\", 0);\n    const float epsilon = ctx->Attr<float>(\"epsilon\");\n    int32_t axis = ctx->Attr<int32_t>(\"axis\");\n    int32_t c = x->shape_view().At(axis);\n    int32_t n = x->shape_view().elem_cnt() / c;\n    int32_t d = x->shape_view().Count(axis + 1);\n\n    size_t square_x_sum_byte_size = square_x_sum->shape_view().elem_cnt() * sizeof(T);\n    Memset<DeviceType::kCPU>(ctx->stream(), square_x_sum->mut_dptr(), 0, square_x_sum_byte_size);\n    L2NormalizeForward<T>(n, c, d, static_cast<T>(epsilon), x->dptr<T>(),\n                          square_x_sum->mut_dptr<T>(), y->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return 
false; }\n};\n\n#define REGISTER_CPU_L2_NORMALIZE_KERNEL(dtype)                       \\\n  REGISTER_USER_KERNEL(\"l2_normalize\")                                \\\n      .SetCreateFn<CpuL2NormalizeKernel<dtype>>()                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CPU_L2_NORMALIZE_KERNEL(float)\nREGISTER_CPU_L2_NORMALIZE_KERNEL(double)\n\ntemplate<typename T>\nclass CpuL2NormalizeGradKernel final : public user_op::OpKernel {\n public:\n  CpuL2NormalizeGradKernel() = default;\n  ~CpuL2NormalizeGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* square_x_sum = ctx->Tensor4ArgNameAndIndex(\"square_x_sum\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const float epsilon = ctx->Attr<float>(\"epsilon\");\n    int32_t axis = ctx->Attr<int32_t>(\"axis\");\n    int32_t c = dy->shape_view().At(axis);\n    int32_t n = dy->shape_view().elem_cnt() / c;\n    int32_t d = dy->shape_view().Count(axis + 1);\n    L2NormalizeBackward<T>(n, c, d, static_cast<T>(epsilon), y->dptr<T>(), dy->dptr<T>(),\n                           square_x_sum->dptr<T>(), dx->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_L2_NORMALIZE_GRAD_KERNEL(dtype)                  \\\n  REGISTER_USER_KERNEL(\"l2_normalize_grad\")                           \\\n      .SetCreateFn<CpuL2NormalizeGradKernel<dtype>>()                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CPU_L2_NORMALIZE_GRAD_KERNEL(float)\nREGISTER_CPU_L2_NORMALIZE_GRAD_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/l2_normalize_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include <cub/cub.cuh>\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/cuda/layer_norm.cuh\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename ComputeType>\n__global__ void L2NormalizeForward(const int32_t n, const int32_t c, const int32_t d,\n                                   const ComputeType epsilon, const T* in,\n                                   ComputeType* square_x_sum, T* out) {\n  using BlockReduce = cub::BlockReduce<ComputeType, ep::CudaStream::kDefaultBlockSize>;\n  __shared__ typename BlockReduce::TempStorage temp_storage;\n\n  for (int32_t i = blockIdx.x; i < n; i += gridDim.x) {\n    ComputeType sum = GetZeroVal<ComputeType>();\n    const int32_t offset = (i / d) * d * c + (i % d);\n    for (int32_t j = threadIdx.x; j < c; j += blockDim.x) {\n      const ComputeType x = static_cast<ComputeType>(in[offset + j * d]);\n      sum += x * x;\n    }\n    const ComputeType reduce_sum = BlockReduce(temp_storage).Sum(sum);\n    if (threadIdx.x == 0) { square_x_sum[i] = reduce_sum; }\n    __syncthreads();\n\n    const ComputeType inv_norm = rsqrtf(fmaxf(square_x_sum[i], epsilon));\n    for (int32_t j = threadIdx.x; j < c; j += blockDim.x) {\n      const int32_t index = offset + j * d;\n      out[index] = static_cast<T>(inv_norm * static_cast<ComputeType>(in[index]));\n    }\n  }\n}\n\ntemplate<typename T>\n__global__ void L2NormalizeBackward(const int32_t n, const int32_t c, const int32_t d,\n                                    const float epsilon, const T* out, const T* out_diff,\n                                    const T* square_x_sum, T* in_diff) {\n  for (int32_t i = blockIdx.x; i < n; i += gridDim.x) {\n    const T inv_norm = rsqrt(fmaxf(square_x_sum[i], epsilon));\n    const int32_t offset = (i / d) * d * c + (i % d);\n    if (square_x_sum[i] >= epsilon) {\n      using BlockReduce = cub::BlockReduce<T, ep::CudaStream::kDefaultBlockSize>;\n      __shared__ typename BlockReduce::TempStorage temp_storage_prod_sum;\n\n      T y_dy_prod_sum = GetZeroVal<T>();\n      for (int32_t j = threadIdx.x; j < c; j += blockDim.x) {\n        const int32_t index = offset + j * d;\n        y_dy_prod_sum += out[index] * out_diff[index];\n      }\n\n      const T reduce_y_dy_prod_sum = BlockReduce(temp_storage_prod_sum).Sum(y_dy_prod_sum);\n      __shared__ T y_dy_inner_prod;\n      if (threadIdx.x == 0) { y_dy_inner_prod = reduce_y_dy_prod_sum; }\n      __syncthreads();\n\n      for (int32_t j = threadIdx.x; j < c; j += blockDim.x) {\n        const int32_t index = offset + j * d;\n        in_diff[index] = inv_norm * (out_diff[index] - y_dy_inner_prod * out[index]);\n      }\n    } else {\n      for (int32_t j = threadIdx.x; j < c; j += blockDim.x) {\n        const int32_t index = offset + j * d;\n        in_diff[index] = inv_norm * out_diff[index];\n      }\n    }\n  }\n}\n\n}  // 
namespace\n\ntemplate<typename T>\nclass GpuL2NormalizeKernel final : public user_op::OpKernel {\n public:\n  GpuL2NormalizeKernel() = default;\n  ~GpuL2NormalizeKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* square_x_sum = ctx->Tensor4ArgNameAndIndex(\"square_x_sum\", 0);\n    const float epsilon = ctx->Attr<float>(\"epsilon\");\n    int32_t axis = ctx->Attr<int32_t>(\"axis\");\n    int32_t c = x->shape_view().At(axis);\n    int32_t n = x->shape_view().elem_cnt() / c;\n    int32_t d = x->shape_view().Count(axis + 1);\n    using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;\n    RUN_CUDA_KERNEL((L2NormalizeForward<T, ComputeType>), ctx->stream(), n, n, c, d,\n                    static_cast<ComputeType>(epsilon), x->dptr<T>(),\n                    square_x_sum->mut_dptr<ComputeType>(), y->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_L2_NORMALIZE_KERNEL(dtype)                       \\\n  REGISTER_USER_KERNEL(\"l2_normalize\")                                 \\\n      .SetCreateFn<GpuL2NormalizeKernel<dtype>>()                      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CUDA_L2_NORMALIZE_KERNEL(half)\nREGISTER_CUDA_L2_NORMALIZE_KERNEL(float)\nREGISTER_CUDA_L2_NORMALIZE_KERNEL(double)\n\ntemplate<typename T>\nclass GpuL2NormalizeGradKernel final : public user_op::OpKernel {\n public:\n  GpuL2NormalizeGradKernel() = default;\n  ~GpuL2NormalizeGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* square_x_sum = ctx->Tensor4ArgNameAndIndex(\"square_x_sum\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const float epsilon = ctx->Attr<float>(\"epsilon\");\n    int32_t axis = ctx->Attr<int32_t>(\"axis\");\n    int32_t c = dy->shape_view().At(axis);\n    int32_t n = dy->shape_view().elem_cnt() / c;\n    int32_t d = dy->shape_view().Count(axis + 1);\n    RUN_CUDA_KERNEL((L2NormalizeBackward<T>), ctx->stream(), n, n, c, d, static_cast<T>(epsilon),\n                    y->dptr<T>(), dy->dptr<T>(), square_x_sum->dptr<T>(), dx->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_L2_NORMALIZE_GRAD_KERNEL(dtype)                  \\\n  REGISTER_USER_KERNEL(\"l2_normalize_grad\")                            \\\n      .SetCreateFn<GpuL2NormalizeGradKernel<dtype>>()                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CUDA_L2_NORMALIZE_GRAD_KERNEL(float)\nREGISTER_CUDA_L2_NORMALIZE_GRAD_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/layer_norm_cpu_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass LayerNormCpuKernel final : public user_op::OpKernel {\n public:\n  LayerNormCpuKernel() = default;\n  ~LayerNormCpuKernel() = default;\n\n private:\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override { TODO(); };\n};\n\n#define REGISTER_LAYER_NORM_CPU_KERNEL(dtype)                         \\\n  REGISTER_USER_KERNEL(\"layer_norm\")                                  \\\n      .SetCreateFn<LayerNormCpuKernel<dtype>>()                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nREGISTER_LAYER_NORM_CPU_KERNEL(float)\nREGISTER_LAYER_NORM_CPU_KERNEL(double)\n\ntemplate<typename T>\nclass LayerNormGradCpuKernel final : public user_op::OpKernel {\n public:\n  LayerNormGradCpuKernel() = default;\n  ~LayerNormGradCpuKernel() = default;\n\n private:\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override { TODO(); };\n};\n\n#define REGISTER_LAYER_NORM_GRAD_CPU_KERNEL(dtype)                    \\\n  REGISTER_USER_KERNEL(\"layer_norm_grad\")                             \\\n      .SetCreateFn<LayerNormGradCpuKernel<dtype>>()                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value));\n\nREGISTER_LAYER_NORM_GRAD_CPU_KERNEL(float)\nREGISTER_LAYER_NORM_GRAD_CPU_KERNEL(double)\n\ntemplate<typename T>\nclass FuseLayerNormGradCpuKernel final : public user_op::OpKernel {\n public:\n  FuseLayerNormGradCpuKernel() = default;\n  ~FuseLayerNormGradCpuKernel() = default;\n\n private:\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override { TODO(); };\n};\n\n#define REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(dtype)               \\\n  REGISTER_USER_KERNEL(\"fuse_layer_norm_grad\")                        \\\n      .SetCreateFn<LayerNormGradCpuKernel<dtype>>()                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value));\n\nREGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(float)\nREGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(double)\n\ntemplate<typename T>\nclass LayerNormParamGradCpuKernel final : public user_op::OpKernel {\n public:\n  LayerNormParamGradCpuKernel() = default;\n  ~LayerNormParamGradCpuKernel() = default;\n\n private:\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override { TODO(); 
};\n};\n\n#define REGISTER_LAYER_NORM_PARAM_GRAD_CPU_KERNEL(dtype)              \\\n  REGISTER_USER_KERNEL(\"layer_norm_param_grad\")                       \\\n      .SetCreateFn<LayerNormParamGradCpuKernel<dtype>>()              \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value));\n\nREGISTER_LAYER_NORM_PARAM_GRAD_CPU_KERNEL(float)\nREGISTER_LAYER_NORM_PARAM_GRAD_CPU_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/layer_norm_gpu_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/device/cudnn_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include <cub/cub.cuh>\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/cuda/layer_norm.cuh\"\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename SRC, typename DST, bool do_scale, bool do_center>\nstruct AffineStore {\n  AffineStore(DST* y, int64_t row_size, const DST* gamma, const DST* beta)\n      : y(y), row_size(row_size), gamma(gamma), beta(beta) {}\n  template<int N>\n  __device__ void store(const SRC* src, int64_t row, int64_t col) {\n    cuda::layer_norm::Pack<DST, N> y_pack;\n    cuda::layer_norm::Pack<DST, N> gamma_pack;\n    cuda::layer_norm::Pack<DST, N> beta_pack;\n    const int64_t offset = (row * row_size + col) / N;\n    const int64_t gamma_offset = col / N;\n    if (do_scale) {\n      gamma_pack.storage =\n          *(reinterpret_cast<const cuda::layer_norm::PackType<DST, N>*>(gamma) + gamma_offset);\n    } else {\n#pragma unroll\n      for (int i = 0; i < N; ++i) { gamma_pack.elem[i] = static_cast<DST>(1.f); }\n    }\n    if (do_center) {\n      beta_pack.storage =\n          *(reinterpret_cast<const cuda::layer_norm::PackType<DST, N>*>(beta) + gamma_offset);\n    } else {\n#pragma unroll\n      for (int i = 0; i < N; ++i) { beta_pack.elem[i] = static_cast<DST>(0.f); }\n    }\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      DST normalized_i = static_cast<DST>(src[i]);\n      if (do_scale || do_center) {\n        y_pack.elem[i] = normalized_i * gamma_pack.elem[i] + beta_pack.elem[i];\n      } else {\n        y_pack.elem[i] = normalized_i;\n      }\n    }\n    *(reinterpret_cast<cuda::layer_norm::PackType<DST, N>*>(y) + offset) = y_pack.storage;\n  }\n  DST* y;\n  int64_t row_size;\n  const DST* gamma;\n  const DST* beta;\n};\n\ntemplate<typename SRC, typename DST, bool do_scale>\nstruct ScaleLoad {\n  using LoadType = DST;\n  ScaleLoad(const SRC* src, const SRC* gamma, int64_t row_size)\n      : src(src), gamma(gamma), row_size(row_size) {}\n  template<int N>\n  __device__ void load(DST* dst, int64_t row, int64_t col) const {\n    cuda::layer_norm::Pack<SRC, N> src_pack;\n    cuda::layer_norm::Pack<SRC, N> gamma_pack;\n    const int64_t offset = (row * row_size + col) / N;\n    const int64_t gamma_offset = col / N;\n    src_pack.storage = *(reinterpret_cast<const cuda::layer_norm::PackType<SRC, N>*>(src) + offset);\n    if (do_scale) {\n      gamma_pack.storage =\n          *(reinterpret_cast<const cuda::layer_norm::PackType<SRC, N>*>(gamma) + gamma_offset);\n    } else {\n#pragma 
unroll\n      for (int i = 0; i < N; ++i) { gamma_pack.elem[i] = static_cast<SRC>(1.f); }\n    }\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      dst[i] = static_cast<DST>(src_pack.elem[i] * gamma_pack.elem[i]);\n    }\n  }\n  const SRC* src;\n  const SRC* gamma;\n  int64_t row_size;\n};\n\ntemplate<typename SRC, typename DST, bool do_add>\nstruct AddStore {\n  AddStore(const DST* add_to_output, DST* dst, int64_t row_size)\n      : add_to_output(add_to_output), dst(dst), row_size(row_size) {}\n  template<int N>\n  __device__ void store(const SRC* src, int64_t row, int64_t col) {\n    cuda::layer_norm::Pack<DST, N> add_to_output_pack;\n    cuda::layer_norm::Pack<DST, N> dst_pack;\n    const int64_t offset = (row * row_size + col) / N;\n    if (do_add) {\n      add_to_output_pack.storage =\n          *(reinterpret_cast<const cuda::layer_norm::PackType<DST, N>*>(add_to_output) + offset);\n    }\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      if (do_add) {\n        dst_pack.elem[i] = static_cast<DST>(src[i]) + add_to_output_pack.elem[i];\n      } else {\n        dst_pack.elem[i] = static_cast<DST>(src[i]);\n      }\n    }\n    *(reinterpret_cast<cuda::layer_norm::PackType<DST, N>*>(dst) + offset) = dst_pack.storage;\n  }\n  const DST* add_to_output;\n  DST* dst;\n  int64_t row_size;\n};\n\ntemplate<typename T>\n__inline__ __device__ T WarpReduce(T val) {\n  for (int mask = 16; mask > 0; mask /= 2) { val += __shfl_down_sync(0xffffffff, val, mask); }\n  return val;\n}\n\nconstexpr int tile_size = 32;\nconstexpr int num_per_block = 4;\nconstexpr int block_dim_x = 32;\nconstexpr int block_dim_y = 32 / num_per_block;\n\ntemplate<typename T, typename ComputeType>\n__global__ void LayerNormParamGrad(int rows, int cols, const T* __restrict__ dy,\n                                   const T* __restrict__ x, const ComputeType* __restrict__ mean,\n                                   const ComputeType* __restrict__ inv_var,\n                                   T* __restrict__ tmp_gamma_diff, T* __restrict__ tmp_beta_diff) {\n  __shared__ ComputeType dgamma[32][33];\n  __shared__ ComputeType dbeta[32][33];\n  ComputeType dgamma_sum[num_per_block];\n  ComputeType dbeta_sum[num_per_block];\n#pragma unroll\n  for (int index = 0; index < num_per_block; ++index) {\n    dgamma_sum[index] = 0;\n    dbeta_sum[index] = 0;\n  }\n  const int col_id = blockIdx.x * blockDim.x + threadIdx.x;\n  if (col_id < cols) {\n    for (int i = blockIdx.y * tile_size + threadIdx.y; i < rows; i += tile_size * gridDim.y) {\n#pragma unroll\n      for (int index = 0; index < num_per_block; ++index) {\n        int row_id = i + index * blockDim.y;\n        if (row_id < rows) {\n          int offset = row_id * cols + col_id;\n          const ComputeType dy_val = static_cast<ComputeType>(dy[offset]);\n          const ComputeType x_val = static_cast<ComputeType>(x[offset]);\n          const ComputeType mean_val = mean[row_id];\n          const ComputeType inv_var_val = inv_var[row_id];\n          dgamma_sum[index] += dy_val * (x_val - mean_val) * inv_var_val;\n          dbeta_sum[index] += dy_val;\n        }\n      }\n    }\n  }\n#pragma unroll\n  for (int index = 0; index < num_per_block; ++index) {\n    dgamma[index * blockDim.y + threadIdx.y][threadIdx.x] = dgamma_sum[index];\n    dbeta[index * blockDim.y + threadIdx.y][threadIdx.x] = dbeta_sum[index];\n  }\n  __syncthreads();\n#pragma unroll\n  for (int index = 0; index < num_per_block; ++index) {\n    const int col_id = blockIdx.x * blockDim.x + threadIdx.y + index * 
blockDim.y;\n    if (col_id < cols) {\n      ComputeType gamma_sum = dgamma[threadIdx.x][threadIdx.y + index * blockDim.y];\n      ComputeType beta_sum = dbeta[threadIdx.x][threadIdx.y + index * blockDim.y];\n      ComputeType global_dgamma = WarpReduce<ComputeType>(gamma_sum);\n      ComputeType global_dbeta = WarpReduce<ComputeType>(beta_sum);\n      if (threadIdx.x == 0) {\n        const int offset = blockIdx.y * cols + col_id;\n        tmp_gamma_diff[offset] = global_dgamma;\n        tmp_beta_diff[offset] = global_dbeta;\n      }\n    }\n  }\n}\n\ntemplate<typename T>\nint GetGridDimY(const int64_t num_instances, const int64_t norm_size) {\n  using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;\n  const int grid_dim_x = (norm_size + tile_size - 1) / tile_size;\n  const int max_grid_dim_y = (num_instances + tile_size - 1) / tile_size;\n  const int block_size = block_dim_x * block_dim_y;\n  int max_active_blocks = 0;\n  OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n      &max_active_blocks, LayerNormParamGrad<T, ComputeType>, block_size, 0));\n  int waves = 1;\n  int dev;\n  OF_CUDA_CHECK(cudaGetDevice(&dev));\n  int sm_count;\n  OF_CUDA_CHECK(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev));\n  int num_blocks = max_active_blocks * sm_count * waves;\n  int grid_dim_y = std::min(max_grid_dim_y, static_cast<int>(num_blocks / grid_dim_x));\n  return std::max(grid_dim_y, 1);\n}\n\ntemplate<typename T, bool do_scale, bool do_center>\nvoid LayerNormForwardGpu(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size,\n                         const double epsilon, const T* x_ptr, const T* gamma_ptr,\n                         const T* beta_ptr, T* y_ptr, user_op::Tensor* mean,\n                         user_op::Tensor* inv_variance) {\n  using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;\n  cuda::layer_norm::DirectLoad<T, T> load(x_ptr, norm_size);\n  AffineStore<ComputeType, T, do_scale, do_center> store(y_ptr, norm_size, gamma_ptr, beta_ptr);\n  cuda::layer_norm::DispatchLayerNorm<decltype(load), decltype(store), ComputeType>(\n      stream->As<ep::CudaStream>()->cuda_stream(), load, store, num_instances, norm_size, epsilon,\n      mean->mut_dptr<ComputeType>(), inv_variance->mut_dptr<ComputeType>());\n}\n\ntemplate<typename T>\nvoid DispatchLayerNormForwardGpu(ep::Stream* stream, const int64_t num_instances,\n                                 const int64_t norm_size, const double epsilon, const T* x_ptr,\n                                 const T* gamma_ptr, const T* beta_ptr, T* y_ptr,\n                                 user_op::Tensor* mean, user_op::Tensor* inv_variance) {\n  if (gamma_ptr != nullptr && beta_ptr != nullptr) {\n    LayerNormForwardGpu<T, true, true>(stream, num_instances, norm_size, epsilon, x_ptr, gamma_ptr,\n                                       beta_ptr, y_ptr, mean, inv_variance);\n  } else if (gamma_ptr != nullptr && beta_ptr == nullptr) {\n    LayerNormForwardGpu<T, true, false>(stream, num_instances, norm_size, epsilon, x_ptr, gamma_ptr,\n                                        beta_ptr, y_ptr, mean, inv_variance);\n  } else if (gamma_ptr == nullptr && beta_ptr != nullptr) {\n    LayerNormForwardGpu<T, false, true>(stream, num_instances, norm_size, epsilon, x_ptr, gamma_ptr,\n                                        beta_ptr, y_ptr, mean, inv_variance);\n  } else {\n    LayerNormForwardGpu<T, false, false>(stream, num_instances, norm_size, epsilon, x_ptr,\n             
                            gamma_ptr, beta_ptr, y_ptr, mean, inv_variance);\n  }\n}\n\ntemplate<typename T, bool do_scale, bool do_add>\nvoid LayerNormBackwardGpu(ep::Stream* stream, const int64_t num_instances, const int64_t norm_size,\n                          const T* dy_ptr, const T* x_ptr, const user_op::Tensor* mean,\n                          const user_op::Tensor* inv_variance, const T* gamma_ptr,\n                          const T* add_to_output_ptr, T* dx_ptr) {\n  using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;\n  cuda::layer_norm::DirectLoad<T, T> load_x(x_ptr, norm_size);\n  ScaleLoad<T, T, do_scale> load_scaled_dy(dy_ptr, gamma_ptr, norm_size);\n  AddStore<ComputeType, T, do_add> store(add_to_output_ptr, dx_ptr, norm_size);\n  OF_CUDA_CHECK((cuda::layer_norm::DispatchLayerNormGrad<decltype(load_x), decltype(load_scaled_dy),\n                                                         decltype(store), ComputeType>(\n      stream->As<ep::CudaStream>()->cuda_stream(), load_x, load_scaled_dy, store,\n      mean->dptr<ComputeType>(), inv_variance->dptr<ComputeType>(), num_instances, norm_size)));\n}\n\ntemplate<typename T, bool do_scale>\nvoid DispatchLayerNormBackwardDoAdd(ep::Stream* stream, const int64_t num_instances,\n                                    const int64_t norm_size, const T* dy_ptr, const T* x_ptr,\n                                    const user_op::Tensor* mean,\n                                    const user_op::Tensor* inv_variance, const T* gamma_ptr,\n                                    const T* add_to_output_ptr, T* dx_ptr) {\n  if (add_to_output_ptr != nullptr) {\n    LayerNormBackwardGpu<T, do_scale, true>(stream, num_instances, norm_size, dy_ptr, x_ptr, mean,\n                                            inv_variance, gamma_ptr, add_to_output_ptr, dx_ptr);\n  } else {\n    LayerNormBackwardGpu<T, do_scale, false>(stream, num_instances, norm_size, dy_ptr, x_ptr, mean,\n                                             inv_variance, gamma_ptr, add_to_output_ptr, dx_ptr);\n  }\n}\n\ntemplate<typename T>\nvoid LaunchLayerNormBackward(ep::Stream* stream, const int64_t num_instances,\n                             const int64_t norm_size, const T* dy_ptr, const T* x_ptr,\n                             const user_op::Tensor* mean, const user_op::Tensor* inv_variance,\n                             const T* gamma_ptr, const T* add_to_output_ptr, T* dx_ptr) {\n  if (gamma_ptr != nullptr) {\n    DispatchLayerNormBackwardDoAdd<T, true>(stream, num_instances, norm_size, dy_ptr, x_ptr, mean,\n                                            inv_variance, gamma_ptr, add_to_output_ptr, dx_ptr);\n  } else {\n    DispatchLayerNormBackwardDoAdd<T, false>(stream, num_instances, norm_size, dy_ptr, x_ptr, mean,\n                                             inv_variance, gamma_ptr, add_to_output_ptr, dx_ptr);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass LayerNormGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  LayerNormGpuKernel() = default;\n  ~LayerNormGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    user_op::Tensor* inv_variance = 
ctx->Tensor4ArgNameAndIndex(\"inv_variance\", 0);\n    const double epsilon = ctx->Attr<double>(\"epsilon\");\n    CHECK_GE(epsilon, CUDNN_BN_MIN_EPSILON);\n    const int64_t num_instances = mean->shape_view().elem_cnt();\n    const int64_t norm_size = x->shape_view().elem_cnt() / num_instances;\n    const T* gamma_ptr = nullptr;\n    const T* beta_ptr = nullptr;\n    if (ctx->has_input(\"gamma\", 0)) {\n      const user_op::Tensor* gamma = ctx->Tensor4ArgNameAndIndex(\"gamma\", 0);\n      gamma_ptr = gamma->dptr<T>();\n      CHECK_EQ(gamma->shape_view().elem_cnt(), norm_size);\n    }\n    if (ctx->has_input(\"beta\", 0)) { beta_ptr = ctx->Tensor4ArgNameAndIndex(\"beta\", 0)->dptr<T>(); }\n    DispatchLayerNormForwardGpu<T>(ctx->stream(), num_instances, norm_size, epsilon, x->dptr<T>(),\n                                   gamma_ptr, beta_ptr, y->mut_dptr<T>(), mean, inv_variance);\n  };\n};\n\n#define REGISTER_LAYER_NORM_CUDA_KERNEL(dtype)                         \\\n  REGISTER_USER_KERNEL(\"layer_norm\")                                   \\\n      .SetCreateFn<LayerNormGpuKernel<dtype>>()                        \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nREGISTER_LAYER_NORM_CUDA_KERNEL(float)\nREGISTER_LAYER_NORM_CUDA_KERNEL(double)\nREGISTER_LAYER_NORM_CUDA_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_LAYER_NORM_CUDA_KERNEL(nv_bfloat16)\n#endif\n\ntemplate<typename T>\nclass LayerNormGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  LayerNormGradGpuKernel() = default;\n  ~LayerNormGradGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    const user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex(\"inv_variance\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const int64_t num_instances = mean->shape_view().elem_cnt();\n    const int64_t norm_size = x->shape_view().elem_cnt() / num_instances;\n    const T* gamma_ptr = nullptr;\n    if (ctx->has_input(\"gamma\", 0)) {\n      gamma_ptr = ctx->Tensor4ArgNameAndIndex(\"gamma\", 0)->dptr<T>();\n    }\n    const T* add_to_output_ptr = nullptr;\n    if (ctx->has_input(\"_add_to_output\", 0)) {\n      const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n      CHECK_EQ(add_to_output->data_type(), dx->data_type());\n      CHECK_EQ(add_to_output->shape_view(), dx->shape_view());\n      add_to_output_ptr = add_to_output->dptr<T>();\n    }\n    LaunchLayerNormBackward<T>(ctx->stream(), num_instances, norm_size, dy->dptr<T>(), x->dptr<T>(),\n                               mean, inv_variance, gamma_ptr, add_to_output_ptr, dx->mut_dptr<T>());\n  };\n};\n\n#define REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(dtype)                                        \\\n  REGISTER_USER_KERNEL(\"layer_norm_grad\")                                                  \\\n      .SetCreateFn<LayerNormGradGpuKernel<dtype>>()                                        \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                 
    \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value))    \\\n      .SetInplaceProposalFn(                                                               \\\n          [](const user_op::InferContext& ctx,                                             \\\n             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {       \\\n            if (ctx.has_input(\"_add_to_output\", 0)) {                                      \\\n              OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"dx\", 0, \"_add_to_output\", 0, true)); \\\n            }                                                                              \\\n            return Maybe<void>::Ok();                                                      \\\n          });\n\nREGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(float)\nREGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(double)\nREGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)\n#endif\n\ntemplate<typename T>\nclass LayerNormParamGradGpuKernel final : public user_op::OpKernel,\n                                          public user_op::CudaGraphSupport {\n public:\n  LayerNormParamGradGpuKernel() = default;\n  ~LayerNormParamGradGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    const user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex(\"inv_variance\", 0);\n    const int64_t num_instances = mean->shape_view().elem_cnt();\n    const int64_t norm_size = x->shape_view().elem_cnt() / num_instances;\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const DataType data_type = dy->data_type();\n    const int grid_dim_x = (norm_size + tile_size - 1) / tile_size;\n    const int grid_dim_y = GetGridDimY<T>(num_instances, norm_size);\n    const size_t tmp_gamma_diff_size = grid_dim_y * norm_size * sizeof(T);\n    T* tmp_gamma_diff_ptr = reinterpret_cast<T*>(tmp_buffer->mut_dptr());\n    T* tmp_beta_diff_ptr = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + tmp_gamma_diff_size);\n    T* reduce_buf_ptr =\n        reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + 2 * tmp_gamma_diff_size);\n    using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;\n    LayerNormParamGrad<T, ComputeType><<<dim3(grid_dim_x, grid_dim_y), dim3(32, 32 / num_per_block),\n                                         0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        num_instances, norm_size, dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(),\n        inv_variance->dptr<ComputeType>(), tmp_gamma_diff_ptr, tmp_beta_diff_ptr);\n    const int32_t m = norm_size;\n    const int32_t n = 1;\n    const int32_t k = grid_dim_y;\n    std::unique_ptr<ep::primitive::Fill> fill =\n        ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->stream()->device_type(),\n                                                                data_type);\n    CHECK(fill);\n    fill->Launch(ctx->stream(), reduce_buf_ptr, 1.0, grid_dim_y);\n    std::unique_ptr<ep::primitive::Matmul> matmul =\n        ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(\n     
       ctx->stream()->device_type(), data_type, ep::primitive::BlasTransposeType::T,\n            ep::primitive::BlasTransposeType::N);\n    CHECK(matmul);\n    if (ctx->has_output(\"gamma_diff\", 0)) {\n      user_op::Tensor* gamma_diff = ctx->Tensor4ArgNameAndIndex(\"gamma_diff\", 0);\n      matmul->Launch(ctx->stream(), m, n, k, 1.0, tmp_gamma_diff_ptr, reduce_buf_ptr, 0.0,\n                     gamma_diff->mut_dptr());\n    }\n    if (ctx->has_output(\"beta_diff\", 0)) {\n      user_op::Tensor* beta_diff = ctx->Tensor4ArgNameAndIndex(\"beta_diff\", 0);\n      matmul->Launch(ctx->stream(), m, n, k, 1.0, tmp_beta_diff_ptr, reduce_buf_ptr, 0.0,\n                     beta_diff->mut_dptr());\n    }\n  };\n};\n\n#define REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(dtype)                                    \\\n  REGISTER_USER_KERNEL(\"layer_norm_param_grad\")                                             \\\n      .SetCreateFn<LayerNormParamGradGpuKernel<dtype>>()                                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                      \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value))     \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                   \\\n        const int64_t begin_params_axis = ctx->Attr<int64_t>(\"begin_params_axis\");          \\\n        const bool has_gamma_diff = ctx->has_output(\"gamma_diff\", 0);                       \\\n        const bool has_beta_diff = ctx->has_output(\"beta_diff\", 0);                         \\\n        const auto& dy = ctx->InputTensorDesc(\"dy\", 0);                                     \\\n        const int64_t num_instances = dy.shape().Count(0, begin_params_axis);               \\\n        const int64_t norm_size = dy.shape().Count(begin_params_axis);                      \\\n        const int grid_dim_y = GetGridDimY<dtype>(num_instances, norm_size);                \\\n        size_t tmp_buffer_size = (2 * grid_dim_y * norm_size + grid_dim_y) * sizeof(dtype); \\\n        return tmp_buffer_size;                                                             \\\n      });\n\nREGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(float)\nREGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(double)\nREGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(nv_bfloat16)\n#endif\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/lerp_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/lerp_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nclass LerpKernel final : public user_op::OpKernel {\n public:\n  LerpKernel() = default;\n  ~LerpKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* start = ctx->Tensor4ArgNameAndIndex(\"start\", 0);\n    const user_op::Tensor* end = ctx->Tensor4ArgNameAndIndex(\"end\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const ShapeView& start_shape = start->shape_view();\n    const ShapeView& end_shape = end->shape_view();\n    const ShapeView& weight_shape = weight->shape_view();\n    CHECK_EQ(start_shape, end_shape);\n    CHECK_EQ(start_shape, weight_shape);\n\n    const T* start_ptr = start->dptr<T>();\n    const T* end_ptr = end->dptr<T>();\n    const T* weight_ptr = weight->dptr<T>();\n    T* out_ptr = out->mut_dptr<T>();\n\n    LerpKernelUtil<device_type, T>::Forward(ctx->stream(), start_shape.elem_cnt(), start_ptr,\n                                            weight_ptr, end_ptr, out_ptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_LERP_KERNEL(device_type, dtype)                                              \\\n  REGISTER_USER_KERNEL(\"lerp\").SetCreateFn<LerpKernel<device_type, dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == device_type)                                               \\\n      && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\nREGISTER_LERP_KERNEL(DeviceType::kCPU, float)\nREGISTER_LERP_KERNEL(DeviceType::kCPU, double)\nREGISTER_LERP_KERNEL(DeviceType::kCPU, uint8_t)\nREGISTER_LERP_KERNEL(DeviceType::kCPU, int8_t)\nREGISTER_LERP_KERNEL(DeviceType::kCPU, int32_t)\nREGISTER_LERP_KERNEL(DeviceType::kCPU, int64_t)\n#ifdef WITH_CUDA\nREGISTER_LERP_KERNEL(DeviceType::kCUDA, half)\nREGISTER_LERP_KERNEL(DeviceType::kCUDA, float)\nREGISTER_LERP_KERNEL(DeviceType::kCUDA, double)\nREGISTER_LERP_KERNEL(DeviceType::kCUDA, uint8_t)\nREGISTER_LERP_KERNEL(DeviceType::kCUDA, int8_t)\nREGISTER_LERP_KERNEL(DeviceType::kCUDA, int32_t)\nREGISTER_LERP_KERNEL(DeviceType::kCUDA, int64_t)\n#endif  // WITH_CUDA\n\ntemplate<DeviceType device_type, typename T>\nclass LerpGradKernel final : public user_op::OpKernel {\n public:\n  LerpGradKernel() = default;\n  ~LerpGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* start = ctx->Tensor4ArgNameAndIndex(\"start\", 0);\n    const user_op::Tensor* end = ctx->Tensor4ArgNameAndIndex(\"end\", 0);\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const user_op::Tensor* 
out_diff = ctx->Tensor4ArgNameAndIndex(\"out_diff\", 0);\n    user_op::Tensor* start_diff = ctx->Tensor4ArgNameAndIndex(\"start_diff\", 0);\n    user_op::Tensor* end_diff = ctx->Tensor4ArgNameAndIndex(\"end_diff\", 0);\n    user_op::Tensor* weight_diff = ctx->Tensor4ArgNameAndIndex(\"weight_diff\", 0);\n\n    const ShapeView& start_shape = start->shape_view();\n    const ShapeView& end_shape = end->shape_view();\n    const ShapeView& weight_shape = weight->shape_view();\n    CHECK_EQ(start_shape, end_shape);\n    CHECK_EQ(start_shape, weight_shape);\n\n    const T* start_ptr = start->dptr<T>();\n    const T* end_ptr = end->dptr<T>();\n    const T* weight_ptr = weight->dptr<T>();\n    const T* out_diff_ptr = out_diff->dptr<T>();\n    T* start_diff_ptr = start_diff->mut_dptr<T>();\n    T* end_diff_ptr = end_diff->mut_dptr<T>();\n    T* weight_diff_ptr = weight_diff->mut_dptr<T>();\n\n    LerpKernelUtil<device_type, T>::Backward(ctx->stream(), start_shape.elem_cnt(), start_ptr,\n                                             weight_ptr, end_ptr, out_diff_ptr, start_diff_ptr,\n                                             weight_diff_ptr, end_diff_ptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_LERP_GRAD_KERNEL(device_type, dtype)                                           \\\n  REGISTER_USER_KERNEL(\"lerp_grad\")                                                             \\\n      .SetCreateFn<LerpGradKernel<device_type, dtype>>()                                        \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type)                                \\\n                       && (user_op::HobDataType(\"start_diff\", 0) == GetDataType<dtype>::value)  \\\n                       && (user_op::HobDataType(\"weight_diff\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"end_diff\", 0) == GetDataType<dtype>::value));\n\nREGISTER_LERP_GRAD_KERNEL(DeviceType::kCPU, float)\nREGISTER_LERP_GRAD_KERNEL(DeviceType::kCPU, double)\n#ifdef WITH_CUDA\nREGISTER_LERP_GRAD_KERNEL(DeviceType::kCUDA, half)\nREGISTER_LERP_GRAD_KERNEL(DeviceType::kCUDA, float)\nREGISTER_LERP_GRAD_KERNEL(DeviceType::kCUDA, double)\n#endif  // WITH_CUDA\n\ntemplate<DeviceType device_type, typename T, typename ValueT>\nclass ScalarLerpKernel final : public user_op::OpKernel {\n public:\n  ScalarLerpKernel() = default;\n  ~ScalarLerpKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* start = ctx->Tensor4ArgNameAndIndex(\"start\", 0);\n    const user_op::Tensor* end = ctx->Tensor4ArgNameAndIndex(\"end\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const ShapeView& start_shape = start->shape_view();\n    const ShapeView& end_shape = end->shape_view();\n    CHECK_EQ(start_shape, end_shape);\n\n    const T* start_ptr = start->dptr<T>();\n    const T* end_ptr = end->dptr<T>();\n    T* out_ptr = out->mut_dptr<T>();\n\n    Scalar scalar_operand;\n    if (ctx->Attr<bool>(\"has_int_operand\")) {\n      scalar_operand = ctx->Attr<int64_t>(\"int_operand\");\n    } else if (ctx->Attr<bool>(\"has_float_operand\")) {\n      scalar_operand = ctx->Attr<double>(\"float_operand\");\n    } else {\n      UNIMPLEMENTED();\n    }\n\n    ScalarLerpKernelUtil<device_type, T, ValueT>::Forward(\n        ctx->stream(), start_shape.elem_cnt(), start_ptr, end_ptr, scalar_operand, out_ptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() 
const override { return false; }\n};\n\n#define REGISTER_SCALAR_LERP_KERNEL(device_type, dtype, value_type)    \\\n  REGISTER_USER_KERNEL(\"scalar_lerp\")                                  \\\n      .SetCreateFn<ScalarLerpKernel<device_type, dtype, value_type>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type)       \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\nREGISTER_SCALAR_LERP_KERNEL(DeviceType::kCPU, float, double)\nREGISTER_SCALAR_LERP_KERNEL(DeviceType::kCPU, double, double)\nREGISTER_SCALAR_LERP_KERNEL(DeviceType::kCPU, int8_t, int64_t)\nREGISTER_SCALAR_LERP_KERNEL(DeviceType::kCPU, int32_t, int64_t)\nREGISTER_SCALAR_LERP_KERNEL(DeviceType::kCPU, int64_t, int64_t)\n#ifdef WITH_CUDA\nREGISTER_SCALAR_LERP_KERNEL(DeviceType::kCUDA, half, double)\nREGISTER_SCALAR_LERP_KERNEL(DeviceType::kCUDA, float, double)\nREGISTER_SCALAR_LERP_KERNEL(DeviceType::kCUDA, double, double)\nREGISTER_SCALAR_LERP_KERNEL(DeviceType::kCUDA, int8_t, int64_t)\nREGISTER_SCALAR_LERP_KERNEL(DeviceType::kCUDA, int32_t, int64_t)\nREGISTER_SCALAR_LERP_KERNEL(DeviceType::kCUDA, int64_t, int64_t)\n#endif  // WITH_CUDA\n\ntemplate<DeviceType device_type, typename T, typename ValueT>\nclass ScalarLerpGradKernel final : public user_op::OpKernel {\n public:\n  ScalarLerpGradKernel() = default;\n  ~ScalarLerpGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* start = ctx->Tensor4ArgNameAndIndex(\"start\", 0);\n    const user_op::Tensor* end = ctx->Tensor4ArgNameAndIndex(\"end\", 0);\n    const user_op::Tensor* out_diff = ctx->Tensor4ArgNameAndIndex(\"out_diff\", 0);\n    user_op::Tensor* start_diff = ctx->Tensor4ArgNameAndIndex(\"start_diff\", 0);\n    user_op::Tensor* end_diff = ctx->Tensor4ArgNameAndIndex(\"end_diff\", 0);\n\n    const ShapeView& start_shape = start->shape_view();\n    const ShapeView& end_shape = end->shape_view();\n    CHECK_EQ(start_shape, end_shape);\n\n    const T* start_ptr = start->dptr<T>();\n    const T* end_ptr = end->dptr<T>();\n    const T* out_diff_ptr = out_diff->dptr<T>();\n    T* start_diff_ptr = start_diff->mut_dptr<T>();\n    T* end_diff_ptr = end_diff->mut_dptr<T>();\n\n    Scalar scalar_operand;\n    if (ctx->Attr<bool>(\"has_int_operand\")) {\n      scalar_operand = ctx->Attr<int64_t>(\"int_operand\");\n    } else if (ctx->Attr<bool>(\"has_float_operand\")) {\n      scalar_operand = ctx->Attr<double>(\"float_operand\");\n    } else {\n      UNIMPLEMENTED();\n    }\n\n    ScalarLerpKernelUtil<device_type, T, ValueT>::Backward(\n        ctx->stream(), start_shape.elem_cnt(), start_ptr, end_ptr, out_diff_ptr, scalar_operand,\n        start_diff_ptr, end_diff_ptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SCALAR_LERP_GRAD_KERNEL(device_type, dtype, value_type)                       \\\n  REGISTER_USER_KERNEL(\"scalar_lerp_grad\")                                                     \\\n      .SetCreateFn<ScalarLerpGradKernel<device_type, dtype, value_type>>()                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type)                               \\\n                       && (user_op::HobDataType(\"start_diff\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"end_diff\", 0) == GetDataType<dtype>::value));\n\nREGISTER_SCALAR_LERP_GRAD_KERNEL(DeviceType::kCPU, float, 
double)\nREGISTER_SCALAR_LERP_GRAD_KERNEL(DeviceType::kCPU, double, double)\n#ifdef WITH_CUDA\nREGISTER_SCALAR_LERP_GRAD_KERNEL(DeviceType::kCUDA, half, double)\nREGISTER_SCALAR_LERP_GRAD_KERNEL(DeviceType::kCUDA, float, double)\nREGISTER_SCALAR_LERP_GRAD_KERNEL(DeviceType::kCUDA, double, double)\n#endif  // WITH_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/lerp_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/lerp_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstruct LerpKernelUtil<DeviceType::kCPU, T> {\n  static void Forward(ep::Stream* stream, const int64_t n, const T* start, const T* weight,\n                      const T* end, T* out) {\n    FOR_RANGE(int64_t, i, 0, n) { out[i] = start[i] + weight[i] * (end[i] - start[i]); }\n  }\n\n  static void Backward(ep::Stream* stream, const int64_t n, const T* start, const T* weight,\n                       const T* end, const T* out_diff, T* start_diff, T* weight_diff,\n                       T* end_diff) {\n    FOR_RANGE(int64_t, i, 0, n) {\n      T out_diff_i = out_diff[i];\n      start_diff[i] = (static_cast<T>(1.0) - weight[i]) * out_diff_i;\n      weight_diff[i] = (end[i] - start[i]) * out_diff_i;\n      end_diff[i] = weight[i] * out_diff_i;\n    }\n  }\n};\n\ntemplate<typename T, typename ValueT>\nstruct ScalarLerpKernelUtil<DeviceType::kCPU, T, ValueT> {\n  static void Forward(ep::Stream* stream, const int64_t n, const T* start, const T* end,\n                      const Scalar operand, T* out) {\n    T weight = static_cast<T>(operand.Value<ValueT>());\n    FOR_RANGE(int64_t, i, 0, n) { out[i] = start[i] + weight * (end[i] - start[i]); }\n  }\n\n  static void Backward(ep::Stream* stream, const int64_t n, const T* start, const T* end,\n                       const T* out_diff, const Scalar operand, T* start_diff, T* end_diff) {\n    T weight = static_cast<T>(operand.Value<ValueT>());\n    FOR_RANGE(int64_t, i, 0, n) {\n      T out_diff_i = out_diff[i];\n      start_diff[i] = (static_cast<T>(1.0) - weight) * out_diff_i;\n      end_diff[i] = out_diff_i - start_diff[i];\n    }\n  }\n};\n\n#define INSTANTIATE_LERP_KERNEL_UTIL_CPU(data_type, other) \\\n  template struct LerpKernelUtil<DeviceType::kCPU, data_type>;\nOF_PP_FOR_EACH_TUPLE(INSTANTIATE_LERP_KERNEL_UTIL_CPU, LERP_DATA_TYPE_SEQ_CPU)\n#undef INSTANTIATE_LERP_KERNEL_UTIL_CPU\n\n#define INSTANTIATE_SCALAR_LERP_KERNEL_UTIL_CPU(data_type, value_data_type)           \\\n  template struct ScalarLerpKernelUtil<DeviceType::kCPU, OF_PP_PAIR_FIRST(data_type), \\\n                                       OF_PP_PAIR_FIRST(value_data_type)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SCALAR_LERP_KERNEL_UTIL_CPU, LERP_DATA_TYPE_SEQ_CPU,\n                                 SCALAR_VALUE_DATA_TYPE_SEQ)\n#undef INSTANTIATE_SCALAR_LERP_KERNEL_UTIL_CPU\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/lerp_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/user/kernels/lerp_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__global__ void LerpForwardGpu(const int n, const T* start, const T* weight, const T* end, T* out) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T start_i = start[i];\n    out[i] = start_i + weight[i] * (end[i] - start_i);\n  }\n}\n\ntemplate<typename T, typename ValueT>\n__global__ void ScalarLerpForwardGpu(const int n, const T* start, const ValueT weight, const T* end,\n                                     T* out) {\n  T weight_calculate = 0.0;\n  if constexpr (std::is_same<T, half>::value) {\n    weight_calculate = __float2half(static_cast<float>(weight));\n  } else {\n    weight_calculate = static_cast<T>(weight);\n  }\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T start_i = start[i];\n    out[i] = start_i + weight_calculate * (end[i] - start_i);\n  }\n}\n\ntemplate<typename T>\n__global__ void LerpBackwardGpu(const int n, const T* start, const T* weight, const T* end,\n                                const T* out_diff, T* start_diff, T* weight_diff, T* end_diff) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T out_diff_i = out_diff[i];\n    const T start_diff_i = (static_cast<T>(1.0) - weight[i]) * out_diff_i;\n    start_diff[i] = start_diff_i;\n    weight_diff[i] = (end[i] - start[i]) * out_diff_i;\n    end_diff[i] = out_diff_i - start_diff_i;\n  }\n}\n\ntemplate<typename T, typename ValueT>\n__global__ void ScalarLerpBackwardGpu(const int n, const T* start, const ValueT weight,\n                                      const T* end, const T* out_diff, T* start_diff, T* end_diff) {\n  T weight_calculate = 0.0;\n  if constexpr (std::is_same<T, half>::value) {\n    weight_calculate = __float2half(static_cast<float>(weight));\n  } else {\n    weight_calculate = static_cast<T>(weight);\n  }\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    T out_diff_i = out_diff[i];\n    const T start_diff_i = (static_cast<T>(1.0) - weight_calculate) * out_diff_i;\n    start_diff[i] = start_diff_i;\n    end_diff[i] = out_diff_i - start_diff_i;\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nstruct LerpKernelUtil<DeviceType::kCUDA, T> {\n  static void Forward(ep::Stream* stream, const int64_t n, const T* start, const T* weight,\n                      const T* end, T* out) {\n    RUN_CUDA_KERNEL((LerpForwardGpu<T>), stream, n, n, start, weight, end, out);\n  }\n\n  static void Backward(ep::Stream* stream, const int64_t n, const T* start, const T* weight,\n                       const T* end, const T* out_diff, T* start_diff, T* weight_diff,\n                       T* end_diff) {\n    RUN_CUDA_KERNEL((LerpBackwardGpu<T>), stream, n, n, start, weight, end, out_diff, start_diff,\n                    weight_diff, end_diff);\n  }\n};\n\ntemplate<typename T, typename ValueT>\nstruct ScalarLerpKernelUtil<DeviceType::kCUDA, T, ValueT> {\n  static void 
Forward(ep::Stream* stream, const int64_t n, const T* start, const T* end,\n                      const Scalar operand, T* out) {\n    ValueT weight = operand.Value<ValueT>();\n    RUN_CUDA_KERNEL((ScalarLerpForwardGpu<T, ValueT>), stream, n, n, start, weight, end, out);\n  }\n\n  static void Backward(ep::Stream* stream, const int64_t n, const T* start, const T* end,\n                       const T* out_diff, const Scalar operand, T* start_diff, T* end_diff) {\n    ValueT weight = operand.Value<ValueT>();\n    RUN_CUDA_KERNEL((ScalarLerpBackwardGpu<T, ValueT>), stream, n, n, start, weight, end, out_diff,\n                    start_diff, end_diff);\n  }\n};\n\n#define INSTANTIATE_LERP_KERNEL_UTIL_CUDA(data_type, other) \\\n  template struct LerpKernelUtil<DeviceType::kCUDA, data_type>;\nOF_PP_FOR_EACH_TUPLE(INSTANTIATE_LERP_KERNEL_UTIL_CUDA, LERP_DATA_TYPE_SEQ_CUDA)\n#undef INSTANTIATE_LERP_KERNEL_UTIL_CUDA\n\n#define INSTANTIATE_SCALAR_LERP_KERNEL_UTIL_CUDA(data_type, value_data_type)           \\\n  template struct ScalarLerpKernelUtil<DeviceType::kCUDA, OF_PP_PAIR_FIRST(data_type), \\\n                                       OF_PP_PAIR_FIRST(value_data_type)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SCALAR_LERP_KERNEL_UTIL_CUDA, LERP_DATA_TYPE_SEQ_CUDA,\n                                 SCALAR_VALUE_DATA_TYPE_SEQ)\n#undef INSTANTIATE_SCALAR_LERP_KERNEL_UTIL_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/lerp_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_LERP_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_LERP_KERNEL_UTIL_H_\n\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nstruct LerpKernelUtil {\n  static void Forward(ep::Stream* stream, const int64_t n, const T* start, const T* weight,\n                      const T* end, T* out);\n  static void Backward(ep::Stream* stream, const int64_t n, const T* start, const T* weight,\n                       const T* end, const T* out_diff, T* start_diff, T* weight_diff, T* end_diff);\n};\n\ntemplate<DeviceType device_type, typename T, typename ValueT>\nstruct ScalarLerpKernelUtil {\n  static void Forward(ep::Stream* stream, const int64_t n, const T* start, const T* end,\n                      const Scalar operand, T* out);\n  static void Backward(ep::Stream* stream, const int64_t n, const T* start, const T* end,\n                       const T* out_diff, const Scalar operand, T* start_diff, T* end_diff);\n};\n\n#define SCALAR_VALUE_DATA_TYPE_SEQ                \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) \\\n  OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)\n\n#define LERP_DATA_TYPE_SEQ_CPU \\\n  FLOATING_DATA_TYPE_SEQ       \\\n  SIGNED_INT_DATA_TYPE_SEQ     \\\n  UNSIGNED_INT_DATA_TYPE_SEQ\n\n#ifdef WITH_CUDA\n#define LERP_DATA_TYPE_SEQ_CUDA \\\n  FLOATING_DATA_TYPE_SEQ        \\\n  SIGNED_INT_DATA_TYPE_SEQ      \\\n  UNSIGNED_INT_DATA_TYPE_SEQ    \\\n  HALF_DATA_TYPE_SEQ\n#endif\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_LERP_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/linalg_cross_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass CpuLinalgCrossKernel final : public user_op::OpKernel {\n public:\n  CpuLinalgCrossKernel() = default;\n  ~CpuLinalgCrossKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* input_tensor = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* other_tensor = ctx->Tensor4ArgNameAndIndex(\"other\", 0);\n    auto* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const auto shape = input_tensor->shape_view();\n    const auto num_axes = shape.NumAxes();\n\n    int64_t dim = ctx->Attr<int64_t>(\"dim\");\n\n    const auto strides = [&shape]() -> std::vector<int64_t> {\n      std::vector<int64_t> result(shape.NumAxes(), 1);\n      for (size_t i(0); i < result.size() - 1; ++i) { result[i] = shape.Count(i + 1); }\n      return result;\n    }();\n\n    const int64_t total = shape.elem_cnt() / 3;\n    int64_t stride = strides[dim];\n\n    const T* input_ptr = input_tensor->dptr<T>();\n    const T* other_ptr = other_tensor->dptr<T>();\n    T* out_dtr = out_tensor->mut_dptr<T>();\n\n    std::vector<int64_t> positions_in_dims(num_axes);\n\n    int64_t start = 0;\n\n    int64_t s = 0;\n    while (s < total) {\n      out_dtr[start + 0 * stride] = input_ptr[start + 1 * stride] * other_ptr[start + 2 * stride]\n                                    - input_ptr[start + 2 * stride] * other_ptr[start + 1 * stride];\n      out_dtr[start + 1 * stride] = input_ptr[start + 2 * stride] * other_ptr[start + 0 * stride]\n                                    - input_ptr[start + 0 * stride] * other_ptr[start + 2 * stride];\n      out_dtr[start + 2 * stride] = input_ptr[start + 0 * stride] * other_ptr[start + 1 * stride]\n                                    - input_ptr[start + 1 * stride] * other_ptr[start + 0 * stride];\n\n      ++s;\n\n      FOR_RANGE(int64_t, i, 0, num_axes) {\n        if (i == dim) continue;\n\n        ++positions_in_dims[i];\n        start += strides[i];\n\n        if (positions_in_dims[i] == shape.At(i) && i != num_axes - 1) {\n          start -= positions_in_dims[i] * strides[i];\n          positions_in_dims[i] = 0;\n        } else {\n          break;\n        }\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_LINALG_CROSS_KERNEL(dtype)                       \\\n  REGISTER_USER_KERNEL(\"linalg_cross\")                                \\\n      .SetCreateFn<CpuLinalgCrossKernel<dtype>>()                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"input\", 0) == 
GetDataType<dtype>::value));\n\nREGISTER_CPU_LINALG_CROSS_KERNEL(float)\nREGISTER_CPU_LINALG_CROSS_KERNEL(double)\n\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/user/kernels/linalg_cross_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n\nnamespace {\n\ntemplate<typename T>\n__global__ void LinalgCrossForward(const int64_t n, const T* input, const T* other, T* out) {\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) {\n    const int64_t index = i * 3;\n    out[index] = input[index + 1] * other[index + 2] - input[index + 2] * other[index + 1];\n    out[index + 1] = input[index + 2] * other[index] - input[index] * other[index + 2];\n    out[index + 2] = input[index] * other[index + 1] - input[index + 1] * other[index];\n  }\n}\n\n}  // namespace\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass CudaLinalgCrossKernel final : public user_op::OpKernel {\n public:\n  CudaLinalgCrossKernel() = default;\n  ~CudaLinalgCrossKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* input_tensor = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* other_tensor = ctx->Tensor4ArgNameAndIndex(\"other\", 0);\n    auto* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const int64_t n = input_tensor->shape_view().elem_cnt() / 3;\n\n    if (n == 0) { return; }\n    RUN_CUDA_KERNEL((LinalgCrossForward<T>), ctx->stream(), n, n, input_tensor->dptr<T>(),\n                    other_tensor->dptr<T>(), out_tensor->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_LINALG_CROSS_KERNEL(dtype)                       \\\n  REGISTER_USER_KERNEL(\"linalg_cross\")                                 \\\n      .SetCreateFn<CudaLinalgCrossKernel<dtype>>()                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CUDA_LINALG_CROSS_KERNEL(float)\nREGISTER_CUDA_LINALG_CROSS_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/log_softmax_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/include/primitive/log_softmax.h\"\n#include \"oneflow/core/ep/include/primitive/log_softmax_backward.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::LogSoftmax> NewLogSoftmaxPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::LogSoftmaxFactory>(ctx->device_type(),\n                                                                       data_type);\n}\n\nauto LogSoftmaxPrimitiveExists() {\n  return hob::make_custom(\"LogSoftmaxPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewLogSoftmaxPrimitive(&ctx).operator bool();\n  });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::LogSoftmaxBackward> NewLogSoftmaxBackwardPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::LogSoftmaxBackwardFactory>(ctx->device_type(),\n                                                                               data_type);\n}\n\nauto LogSoftmaxBackwardPrimitiveExists() {\n  return hob::make_custom(\"LogSoftmaxBackwardPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewLogSoftmaxBackwardPrimitive(&ctx).operator bool();\n                          });\n}\n\nclass LogSoftmaxKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  LogSoftmaxKernel() = default;\n  ~LogSoftmaxKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex(\"prob\", 0);\n    const ShapeView& in_shape = in->shape_view();\n    const int64_t num_classes = in_shape.At(in_shape.NumAxes() - 1);\n    const int64_t num_instances = in_shape.Count(0, in_shape.NumAxes() - 1);\n    std::unique_ptr<ep::primitive::LogSoftmax> primitive = NewLogSoftmaxPrimitive(ctx);\n    CHECK(primitive);\n    primitive->Launch(ctx->stream(), num_instances, num_classes, in->dptr(), prob->mut_dptr());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nclass LogSoftmaxGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  LogSoftmaxGradKernel() = default;\n  ~LogSoftmaxGradKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex(\"prob\", 0);\n    const 
user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const int64_t num_classes = prob->shape_view().At(prob->shape_view().NumAxes() - 1);\n    const int64_t num_instances = prob->shape_view().elem_cnt() / num_classes;\n\n    std::unique_ptr<ep::primitive::LogSoftmaxBackward> primitive =\n        NewLogSoftmaxBackwardPrimitive(ctx);\n    CHECK(primitive);\n    primitive->Launch(ctx->stream(), num_instances, num_classes, prob->dptr(), dy->dptr(),\n                      dx->mut_dptr());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\nREGISTER_USER_KERNEL(\"log_softmax\")\n    .SetCreateFn<LogSoftmaxKernel>()\n    .SetIsMatchedHob(LogSoftmaxPrimitiveExists() == true);\n\nREGISTER_USER_KERNEL(\"log_softmax_grad\")\n    .SetCreateFn<LogSoftmaxGradKernel>()\n    .SetIsMatchedHob(LogSoftmaxBackwardPrimitiveExists() == true);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/logical_not_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/elementwise_unary.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::ElementwiseUnary> NewLogicalNotPrimitive(Context* ctx) {\n  const DataType in_data_type = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0)->data_type();\n  const DataType out_data_type = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(\n      ctx->device_type(), ep::primitive::UnaryOp::kLogicalNot, in_data_type, out_data_type);\n}\n\nclass LogicalNotKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  LogicalNotKernel() = default;\n  ~LogicalNotKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    int64_t n = tensor_x->shape_view().elem_cnt();\n\n    if (n != 0) {\n      auto primitive = NewLogicalNotPrimitive(ctx);\n      CHECK(primitive);\n      primitive->Launch(ctx->stream(), tensor_x->dptr(), tensor_y->mut_dptr(), n);\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nauto LogicalNotPrimitiveExists() {\n  return hob::make_custom(\"LogicalNotPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) -> bool {\n                            return NewLogicalNotPrimitive(&ctx).operator bool();\n                          });\n}\n\nREGISTER_USER_KERNEL(\"logical_not\")\n    .SetCreateFn<LogicalNotKernel>()\n    .SetIsMatchedHob(LogicalNotPrimitiveExists());\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/loss_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_LOSS_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_LOSS_KERNEL_UTIL_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\nnamespace user_op {\nnamespace loss {\n\ntemplate<DeviceType device_type, typename T, typename R>\nclass SimpleLossKernel : public user_op::OpKernel {\n public:\n  SimpleLossKernel() = default;\n  ~SimpleLossKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* input_blob = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* target_blob = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    auto* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const int64_t elem_cnt = input_blob->shape_view().elem_cnt();\n\n    const T* input = input_blob->dptr<T>();\n    const T* target = target_blob->dptr<T>();\n    T* out = out_blob->mut_dptr<T>();\n\n    static_cast<const R*>(this)->ComputeOut(ctx, elem_cnt, input, target, out);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T, typename R>\nclass SimpleLossGradKernel : public user_op::OpKernel {\n public:\n  SimpleLossGradKernel() = default;\n  ~SimpleLossGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* input_blob = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* target_blob = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    const auto* dy_blob = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    auto* dx_blob = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const int64_t elem_cnt = input_blob->shape_view().elem_cnt();\n\n    const T* dy = dy_blob->dptr<T>();\n    const T* input = input_blob->dptr<T>();\n    const T* target = target_blob->dptr<T>();\n    T* dx = dx_blob->mut_dptr<T>();\n\n    static_cast<const R*>(this)->ComputeOut(ctx, elem_cnt, input, target, dy, dx);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\nnamespace {\n\n#define REGISTER_SIMPLE_LOSS_KERNEL(name, kernel, device, dtype)           \\\n  REGISTER_USER_KERNEL(name).SetCreateFn<kernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == device)                                 \\\n      && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)   \\\n      && (user_op::HobDataType(\"target\", 0) == GetDataType<dtype>::value)  \\\n      && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\n#define REGISTER_SIMPLE_LOSS_GRAD_KERNEL(name, kernel, device, dtype)      \\\n  REGISTER_USER_KERNEL(name).SetCreateFn<kernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == device)                                 \\\n 
     && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)   \\\n      && (user_op::HobDataType(\"target\", 0) == GetDataType<dtype>::value)  \\\n      && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value)      \\\n      && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\n}  // namespace\n\n#define REGISTER_SIMPLE_LOSS_KERNEL_CPU(name, kernel)                \\\n  REGISTER_SIMPLE_LOSS_KERNEL(name, kernel, DeviceType::kCPU, float) \\\n  REGISTER_SIMPLE_LOSS_KERNEL(name, kernel, DeviceType::kCPU, double)\n\n#define REGISTER_SIMPLE_LOSS_KERNEL_CUDA(name, kernel)                \\\n  REGISTER_SIMPLE_LOSS_KERNEL(name, kernel, DeviceType::kCUDA, half)  \\\n  REGISTER_SIMPLE_LOSS_KERNEL(name, kernel, DeviceType::kCUDA, float) \\\n  REGISTER_SIMPLE_LOSS_KERNEL(name, kernel, DeviceType::kCUDA, double)\n\n#define REGISTER_SIMPLE_LOSS_GRAD_KERNEL_CPU(name, kernel)                \\\n  REGISTER_SIMPLE_LOSS_GRAD_KERNEL(name, kernel, DeviceType::kCPU, float) \\\n  REGISTER_SIMPLE_LOSS_GRAD_KERNEL(name, kernel, DeviceType::kCPU, double)\n\n#define REGISTER_SIMPLE_LOSS_GRAD_KERNEL_CUDA(name, kernel)                \\\n  REGISTER_SIMPLE_LOSS_GRAD_KERNEL(name, kernel, DeviceType::kCUDA, half)  \\\n  REGISTER_SIMPLE_LOSS_GRAD_KERNEL(name, kernel, DeviceType::kCUDA, float) \\\n  REGISTER_SIMPLE_LOSS_GRAD_KERNEL(name, kernel, DeviceType::kCUDA, double)\n\n}  // namespace loss\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_LOSS_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/lu_decomposition_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n#if CUDA_VERSION >= 11000\nstatic inline size_t BatchCount(const user_op::Tensor* batched_matrices) {\n  size_t result = 1;\n  for (size_t i = 0; i < batched_matrices->shape_view().NumAxes() - 2; i++) {\n    result *= batched_matrices->shape_view().At(i);\n  }\n  return result;\n}\n\nstatic inline size_t MatrixStride(const user_op::Tensor* batched_matrices) {\n  const int64_t num_axes = batched_matrices->shape_view().NumAxes();\n  return batched_matrices->shape_view().At(num_axes - 2)\n         * batched_matrices->shape_view().At(num_axes - 1);\n}\n\nstatic inline size_t PivotStride(const user_op::Tensor* batched_pivot) {\n  const int64_t num_axes = batched_pivot->shape_view().NumAxes();\n  return batched_pivot->shape_view().At(num_axes - 1);\n}\n\nvoid OFgetrf_bufferSize(ep::Stream* stream, int32_t m, int32_t n, float* dA_array, int32_t lda,\n                        int32_t& lwork) {\n  OF_CUSOLVER_CHECK(cusolverDnSgetrf_bufferSize(stream->As<ep::CudaStream>()->cusolver_dn_handle(),\n                                                m, n, dA_array, m, &lwork));\n}\n\nvoid OFgetrf_bufferSize(ep::Stream* stream, int32_t m, int32_t n, double* dA_array, int32_t lda,\n                        int32_t& lwork) {\n  OF_CUSOLVER_CHECK(cusolverDnDgetrf_bufferSize(stream->As<ep::CudaStream>()->cusolver_dn_handle(),\n                                                m, n, dA_array, m, &lwork));\n}\n\nvoid OFgetrf(ep::Stream* stream, int32_t m, int32_t n, float* dA_array, int32_t lda, float* d_work,\n             int32_t* pivot_ptr, int32_t* d_info) {\n  OF_CUSOLVER_CHECK(cusolverDnSgetrf(stream->As<ep::CudaStream>()->cusolver_dn_handle(), m, m,\n                                     dA_array, lda, d_work, pivot_ptr, d_info));\n}\n\nvoid OFgetrf(ep::Stream* stream, int32_t m, int32_t n, double* dA_array, int32_t lda,\n             double* d_work, int32_t* pivot_ptr, int32_t* d_info) {\n  OF_CUSOLVER_CHECK(cusolverDnDgetrf(stream->As<ep::CudaStream>()->cusolver_dn_handle(), m, m,\n                                     dA_array, lda, d_work, pivot_ptr, d_info));\n}\n}  // namespace\n\nnamespace user_op {\n\ntemplate<typename T>\nclass LUDecompositionKernel final : public user_op::OpKernel {\n public:\n  LUDecompositionKernel() = default;\n  ~LUDecompositionKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* LU = ctx->Tensor4ArgNameAndIndex(\"LU\", 0);\n    user_op::Tensor* pivot = ctx->Tensor4ArgNameAndIndex(\"pivot\", 0);\n    auto stream = ctx->stream()->As<ep::CudaStream>();\n\n    // infer tmp buffer\n    const int32_t m = x->shape_view().At(x->shape_view().NumAxes() - 2);\n    const 
int32_t lda = m;\n    const T* x_ptr = x->dptr<T>();\n    T* LU_ptr = LU->mut_dptr<T>();\n    int32_t* pivot_ptr = pivot->mut_dptr<int32_t>();\n\n    size_t batch_count = BatchCount(x);\n    size_t matrix_stride = MatrixStride(x);\n    size_t pivot_stride = PivotStride(x);\n\n    std::unique_ptr<ep::primitive::Memcpy> memcpy_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemcpyFactory>(ctx->stream()->device_type(),\n                                                                  ep::primitive::MemcpyKind::kDtoD);\n    CHECK(memcpy_primitive) << \"Can not create Memcpy primitive for device type \"\n                            << ctx->stream()->device_type();\n    memcpy_primitive->Launch(stream, LU_ptr, x_ptr, sizeof(T) * x->shape_view().elem_cnt());\n\n    std::vector<int32_t> batched_info(batch_count, -1);\n    int32_t* batched_d_info = nullptr;\n    int32_t lwork = -1;\n    T* d_work = nullptr;\n\n    OF_CUDA_CHECK(\n        cudaMalloc(reinterpret_cast<void**>(&batched_d_info), batch_count * sizeof(int32_t)));\n\n    for (size_t batch = 0; batch < batch_count; batch++) {\n      OFgetrf_bufferSize(stream, m, m, LU_ptr, m, lwork);\n      OF_CUDA_CHECK(cudaMalloc(reinterpret_cast<void**>(&d_work), sizeof(T) * lwork));\n      OFgetrf(stream, m, m, LU_ptr + batch * matrix_stride, lda, d_work,\n              pivot_ptr + batch * pivot_stride, batched_d_info + batch);\n      OF_CUDA_CHECK(cudaFree(d_work));\n    }\n\n    OF_CUDA_CHECK(cudaMemcpyAsync(batched_info.data(), batched_d_info,\n                                  batch_count * sizeof(int32_t), cudaMemcpyDeviceToHost,\n                                  stream->cuda_stream()));\n    for (size_t i = 0; i < batched_info.size(); i++) {\n      int32_t info = batched_info[i];\n      CHECK(info >= 0) << \"LU decomposition: \" << -info << \"-th parameter of batch \" << i\n                       << \" is wrong\";\n    }\n    OF_CUDA_CHECK(cudaFree(batched_d_info));\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_LU_DECOMPOSITION_KERNEL(dtype)                   \\\n  REGISTER_USER_KERNEL(\"lu_decomposition\")                             \\\n      .SetCreateFn<LUDecompositionKernel<dtype>>()                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CUDA_LU_DECOMPOSITION_KERNEL(float)\nREGISTER_CUDA_LU_DECOMPOSITION_KERNEL(double)\n#endif\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/masked_fill_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/where_kernel_util.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, typename CondT>\nclass MaskedFillKernel final : public user_op::OpKernel {\n public:\n  MaskedFillKernel() = default;\n  ~MaskedFillKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    T scalar_operand = static_cast<T>(0);\n    if (ctx->Attr<bool>(\"has_int_operand\")) {\n      scalar_operand = static_cast<T>(ctx->Attr<int64_t>(\"int_operand\"));\n    } else if (ctx->Attr<bool>(\"has_float_operand\")) {\n      scalar_operand = static_cast<T>(ctx->Attr<double>(\"float_operand\"));\n    } else if (ctx->Attr<bool>(\"has_bool_operand\")) {\n      scalar_operand = static_cast<T>(ctx->Attr<bool>(\"bool_operand\"));\n    } else {\n      UNIMPLEMENTED() << \"The scalar in MaskedFill should be float or int.\";\n    }\n    WhereKernelUtil<device_type, T, CondT>::WhereXScalar(\n        ctx->stream(), out->shape_view().elem_cnt(), mask->dptr<CondT>(), scalar_operand,\n        x->dptr<T>(), out->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_MASKED_FILL_KERNEL(device_type_v, dtype_pair, ctype_pair)                   \\\n  REGISTER_USER_KERNEL(\"masked_fill\")                                                        \\\n      .SetCreateFn<MaskedFillKernel<device_type_v, OF_PP_PAIR_FIRST(dtype_pair),             \\\n                                    OF_PP_PAIR_FIRST(ctype_pair)>>()                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type_v)                           \\\n                       && (user_op::HobDataType(\"mask\", 0) == OF_PP_PAIR_SECOND(ctype_pair)) \\\n                       && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(dtype_pair)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MASKED_FILL_KERNEL, DEVICE_TYPE_SEQ,\n                                 ARITHMETIC_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ,\n                                 INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ)\n#ifdef WITH_CUDA\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MASKED_FILL_KERNEL, (DeviceType::kCUDA),\n                                 FLOAT16_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ)\n#endif\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/math_binary_broadcast_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/user/ops/math_binary_broadcast_seq.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h\"\nnamespace oneflow {\n\ntemplate<typename Context, ep::primitive::BinaryOp binary_op>\nstd::enable_if_t<binary_op == ep::primitive::BinaryOp::kIsCloseEqualNan\n                     or binary_op == ep::primitive::BinaryOp::kIsClose,\n                 std::unique_ptr<ep::primitive::BroadcastElementwiseBinary>>\nNewBroadcastElementwiseBinaryPrimitive(Context* ctx) {\n  const user_op::TensorDesc* x = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0);\n  const user_op::TensorDesc* z = ctx->TensorDesc4ArgNameAndIndex(\"z\", 0);\n  size_t num_axes = z->shape().NumAxes();\n  return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n      ctx->device_type(), binary_op, x->data_type(), z->data_type(), num_axes,\n      ctx->template Attr<float>(\"atol\"), ctx->template Attr<float>(\"rtol\"));\n}\n\ntemplate<typename Context, ep::primitive::BinaryOp binary_op>\nstd::enable_if_t<binary_op != ep::primitive::BinaryOp::kIsCloseEqualNan\n                     and binary_op != ep::primitive::BinaryOp::kIsClose,\n                 std::unique_ptr<ep::primitive::BroadcastElementwiseBinary>>\nNewBroadcastElementwiseBinaryPrimitive(Context* ctx) {\n  const user_op::TensorDesc* x = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0);\n  const user_op::TensorDesc* z = ctx->TensorDesc4ArgNameAndIndex(\"z\", 0);\n  size_t num_axes = z->shape().NumAxes();\n  return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n      ctx->device_type(), binary_op, x->data_type(), z->data_type(), num_axes);\n}\n\ntemplate<ep::primitive::BinaryOp binary_op>\nclass MathBinaryBroadcastEpKernel final : public user_op::OpKernel,\n                                          public user_op::CudaGraphSupport {\n public:\n  MathBinaryBroadcastEpKernel() = default;\n  ~MathBinaryBroadcastEpKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* z = ctx->Tensor4ArgNameAndIndex(\"z\", 0);\n\n    auto primitive =\n        NewBroadcastElementwiseBinaryPrimitive<user_op::KernelComputeContext, binary_op>(ctx);\n    CHECK(primitive.get() != nullptr) << \"Exceeds maximum supported dimensions\";\n\n    const int64_t x_elem_cnt = x->shape_view().elem_cnt();\n    const int64_t y_elem_cnt = y->shape_view().elem_cnt();\n    size_t num_src0_dims = x->shape_view().NumAxes();\n    
size_t num_src1_dims = y->shape_view().NumAxes();\n\n    // 0-size tensor: nothing to compute.\n    if (x_elem_cnt == 0 || y_elem_cnt == 0) { return; }\n\n    // A 0-dim (scalar) tensor is passed to the primitive as shape {1}.\n    int64_t zero_dim = 1;\n    const int64_t* src0_dims = x->shape_view().ptr();\n    const int64_t* src1_dims = y->shape_view().ptr();\n    if (num_src0_dims == 0) {\n      num_src0_dims = 1;\n      src0_dims = &zero_dim;\n    }\n    if (num_src1_dims == 0) {\n      num_src1_dims = 1;\n      src1_dims = &zero_dim;\n    }\n\n    primitive->Launch(ctx->stream(), num_src0_dims, src0_dims, x->dptr(), num_src1_dims,\n                      src1_dims, y->dptr(), z->mut_dptr());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<ep::primitive::BinaryOp binary_op>\nauto MathBinaryBroadcastPrimitiveExists() {\n  return hob::make_custom(\"MathBinaryBroadcastPrimitiveExists\", [](const user_op::KernelRegContext&\n                                                                       ctx) {\n    return NewBroadcastElementwiseBinaryPrimitive<const user_op::KernelRegContext, binary_op>(&ctx).\n    operator bool();\n  });\n}\n\n#define REGISTER_BINARY_BROADCAST_EP_KERNEL(math_type_pair, binary_op) \\\n  REGISTER_USER_KERNEL(math_type_pair)                                 \\\n      .SetCreateFn<MathBinaryBroadcastEpKernel<binary_op>>()           \\\n      .SetIsMatchedHob(MathBinaryBroadcastPrimitiveExists<binary_op>() == true);\n\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_add\", ep::primitive::BinaryOp::kAdd)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_sub\", ep::primitive::BinaryOp::kSub)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_mul\", ep::primitive::BinaryOp::kMul)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_div\", ep::primitive::BinaryOp::kDiv)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_minimum\", ep::primitive::BinaryOp::kMin)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_maximum\", ep::primitive::BinaryOp::kMax)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_pow\", ep::primitive::BinaryOp::kPow)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_equal\", ep::primitive::BinaryOp::kEqual)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_not_equal\", ep::primitive::BinaryOp::kNotEqual)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_greater\", ep::primitive::BinaryOp::kGreaterThan)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_greater_equal\",\n                                    ep::primitive::BinaryOp::kGreaterEqual)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_less\", ep::primitive::BinaryOp::kLessThan)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_less_equal\", ep::primitive::BinaryOp::kLessEqual)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_isclose_eq_nan\",\n                                    ep::primitive::BinaryOp::kIsCloseEqualNan)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_isclose_neq_nan\", ep::primitive::BinaryOp::kIsClose)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_logical_and\", ep::primitive::BinaryOp::kLogicalAnd)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_logical_or\", ep::primitive::BinaryOp::kLogicalOr)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_logical_xor\", ep::primitive::BinaryOp::kLogicalXor)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_bitwise_and\", ep::primitive::BinaryOp::kBitwiseAnd)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_bitwise_or\", ep::primitive::BinaryOp::kBitwiseOr)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_bitwise_xor\", ep::primitive::BinaryOp::kBitwiseXor)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_floor_mod\", ep::primitive::BinaryOp::kFloorMod)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_fmod\", ep::primitive::BinaryOp::kFmod)\nREGISTER_BINARY_BROADCAST_EP_KERNEL(\"broadcast_zeta\", ep::primitive::BinaryOp::kZeta)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/math_binary_elementwise_func.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_MATH_BINARY_ELEMENTWISE_FUNC_H_\n#define ONEFLOW_USER_KERNELS_MATH_BINARY_ELEMENTWISE_FUNC_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/user/ops/math_binary_elementwise_seq.h\"\n#include \"oneflow/core/device/cuda_pseudo_half.h\"\n\n#if defined(__CUDACC__)\n\n#include <cuda_fp16.h>\n#define MATH_FUNC(name) name\n\n#else\n\n#include <cmath>\n#define MATH_FUNC(name) std::name\n\n#endif\n\nnamespace oneflow {\n\n#define DECLARE_BINARY_FUNCTOR(math_binary_elementwise_type, func_prefix) \\\n  template<typename T>                                                    \\\n  struct func_prefix##Functor;\n\nOF_PP_FOR_EACH_TUPLE(DECLARE_BINARY_FUNCTOR, MATH_BINARY_ELEMENTWISE_FUNC_SEQ)\n\ntemplate<typename T>\nstruct PowFunctor {\n  static OF_DEVICE_FUNC const T Forward(const T x, const T y) { return MATH_FUNC(pow)(x, y); }\n\n  static OF_DEVICE_FUNC const T BackwardXGrad(const T x, const T y, const T dz) {\n    return dz * y * (MATH_FUNC(pow)(x, y - T(1)));\n  }\n\n  static OF_DEVICE_FUNC const T BackwardYGrad(const T x, const T y, const T dz) {\n    if (x > T(0)) {\n      return dz * MATH_FUNC(log)(x) * (MATH_FUNC(pow)(x, y));\n    } else {\n      return T(0);\n    }\n  }\n};\n\ntemplate<typename T>\nstruct Atan2Functor {\n  static OF_DEVICE_FUNC const T Forward(const T x, const T y) { return MATH_FUNC(atan2)(x, y); }\n\n  static OF_DEVICE_FUNC const T BackwardXGrad(const T x, const T y, const T dz) {\n    return dz * (y / (x * x + y * y));\n  }\n\n  static OF_DEVICE_FUNC const T BackwardYGrad(const T x, const T y, const T dz) {\n    return dz * -x / (y * y + x * x);\n  }\n};\n\ntemplate<typename T>\nstruct FloorDivFunctor {\n  static OF_DEVICE_FUNC const T Forward(const T x, const T y) {\n#if defined(__CUDACC__)\n    return floor(fdividef(x, y));\n#else\n    return std::floor(x / y);\n#endif\n  }\n\n  static OF_DEVICE_FUNC const T BackwardXGrad(const T x, const T y, const T dz) { return T(0); }\n\n  static OF_DEVICE_FUNC const T BackwardYGrad(const T x, const T y, const T dz) { return T(0); }\n};\n\ntemplate<typename T>\nstruct TruncDivFunctor {\n  static OF_DEVICE_FUNC const T Forward(const T x, const T y) {\n#if defined(__CUDACC__)\n    return trunc(fdividef(x, y));\n#else\n    return std::trunc(x / y);\n#endif\n  }\n\n  static OF_DEVICE_FUNC const T BackwardXGrad(const T x, const T y, const T dz) { return T(0); }\n\n  static OF_DEVICE_FUNC const T BackwardYGrad(const T x, const T y, const T dz) { return T(0); }\n};\n\ntemplate<typename T>\nstruct XdivyFunctor {\n  static OF_DEVICE_FUNC const T Forward(const T x, const T y) {\n    if (T(0) == x) {\n      return T(0);\n    } else {\n      return x / y;\n    }\n  }\n\n  static OF_DEVICE_FUNC const T BackwardXGrad(const T x, const T y, const T dz) {\n    if (T(0) == x || T(0) == dz) {\n      return T(0);\n    } else {\n      return dz / y;\n    
}\n  }\n\n  static OF_DEVICE_FUNC const T BackwardYGrad(const T x, const T y, const T dz) {\n    return dz * XdivyFunctor<T>::Forward((-x), (y * y));\n  }\n};\n\ntemplate<typename T>\nstruct XlogyFunctor {\n  static OF_DEVICE_FUNC const T Forward(const T x, const T y) {\n    if (T(0) == x) {\n      return T(0);\n    } else {\n      return x * MATH_FUNC(log)(y);\n    }\n  }\n\n  static OF_DEVICE_FUNC const T BackwardXGrad(const T x, const T y, const T dz) {\n    if (T(0) == x || T(0) == dz) {\n      return T(0);\n    } else {\n      return dz * MATH_FUNC(log)(y);\n    }\n  }\n\n  static OF_DEVICE_FUNC const T BackwardYGrad(const T x, const T y, const T dz) {\n    return dz * XdivyFunctor<T>::Forward(x, y);\n  }\n};\n\n#if defined(__CUDACC__)\n// half version\n\n#define OF_HALF_FUNC __device__ __forceinline__\n\n#define MATH_FUNC_H_FW(name) __float2half(name(__half2float(x), __half2float(y)))\n#define MATH_FUNC_H_BW(name) __float2half(name(__half2float(x), __half2float(y), __half2float(dz)))\n\ntemplate<>\nstruct PowFunctor<half> {\n  static OF_HALF_FUNC const half Forward(const half x, const half y) {\n    return MATH_FUNC_H_FW(PowFunctor<float>::Forward);\n  }\n\n  static OF_HALF_FUNC const half BackwardXGrad(const half x, const half y, const half dz) {\n    return MATH_FUNC_H_BW(PowFunctor<float>::BackwardXGrad);\n  }\n\n  static OF_HALF_FUNC const half BackwardYGrad(const half x, const half y, const half dz) {\n    return MATH_FUNC_H_BW(PowFunctor<float>::BackwardYGrad);\n  }\n};\n\ntemplate<>\nstruct Atan2Functor<half> {\n  static OF_HALF_FUNC const half Forward(const half x, const half y) {\n    return MATH_FUNC_H_FW(Atan2Functor<float>::Forward);\n  }\n\n  static OF_HALF_FUNC const half BackwardXGrad(const half x, const half y, const half dz) {\n    return __hmul(dz, __hdiv(y, __hadd(__hmul(y, y), __hmul(x, x))));\n  }\n\n  static OF_HALF_FUNC const half BackwardYGrad(const half x, const half y, const half dz) {\n    return __hmul(dz, __hdiv(__hneg(x), __hadd(__hmul(y, y), __hmul(x, x))));\n  }\n};\n\ntemplate<>\nstruct FloorDivFunctor<half> {\n  static OF_HALF_FUNC const half Forward(const half x, const half y) {\n    return hfloor(__hdiv(x, y));\n  }\n\n  static OF_HALF_FUNC const half BackwardXGrad(const half x, const half y, const half dz) {\n    return GetZeroVal<half>();\n  }\n\n  static OF_HALF_FUNC const half BackwardYGrad(const half x, const half y, const half dz) {\n    return GetZeroVal<half>();\n  }\n};\n\ntemplate<>\nstruct TruncDivFunctor<half> {\n  static OF_HALF_FUNC const half Forward(const half x, const half y) {\n    return htrunc(__hdiv(x, y));\n  }\n\n  static OF_HALF_FUNC const half BackwardXGrad(const half x, const half y, const half dz) {\n    return GetZeroVal<half>();\n  }\n\n  static OF_HALF_FUNC const half BackwardYGrad(const half x, const half y, const half dz) {\n    return GetZeroVal<half>();\n  }\n};\n\ntemplate<>\nstruct XdivyFunctor<half> {\n  static OF_HALF_FUNC const half Forward(const half x, const half y) {\n    if (__heq(GetZeroVal<half>(), x)) {\n      return GetZeroVal<half>();\n    } else {\n      return __hdiv(x, y);\n    }\n  }\n\n  static OF_HALF_FUNC const half BackwardXGrad(const half x, const half y, const half dz) {\n    if (__heq(GetZeroVal<half>(), x)) {\n      return GetZeroVal<half>();\n    } else {\n      return XdivyFunctor<half>::Forward(dz, y);\n    }\n  }\n\n  static OF_HALF_FUNC const half BackwardYGrad(const half x, const half y, const half dz) {\n    return __hmul(dz, XdivyFunctor<half>::Forward(__hneg(x), __hmul(y, y)));\n  
}\n};\n\ntemplate<>\nstruct XlogyFunctor<half> {\n  static OF_HALF_FUNC const half Forward(const half x, const half y) {\n    if (__heq(GetZeroVal<half>(), x)) {\n      return GetZeroVal<half>();\n    } else {\n      return __hmul(x, hlog(y));\n    }\n  }\n\n  static OF_HALF_FUNC const half BackwardXGrad(const half x, const half y, const half dz) {\n    if (__heq(GetZeroVal<half>(), x)) {\n      return GetZeroVal<half>();\n    } else {\n      return XlogyFunctor<half>::Forward(dz, y);\n    }\n  }\n\n  static OF_HALF_FUNC const half BackwardYGrad(const half x, const half y, const half dz) {\n    return __hmul(dz, XdivyFunctor<half>::Forward(x, y));\n  }\n};\n\n#endif\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_MATH_BINARY_ELEMENTWISE_FUNC_H_\n"
  },
  {
    "path": "oneflow/user/kernels/math_binary_elementwise_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/math_binary_elementwise_func.h\"\n#include \"oneflow/core/ep/cpu/cpu_stream.h\"\n#include \"oneflow/core/ep/cpu/cpu_device.h\"\n\nnamespace oneflow {\n\ntemplate<template<typename> class BinaryFunctor, typename T>\nclass MathBinaryElementwiseCpuKernel final : public user_op::OpKernel {\n public:\n  MathBinaryElementwiseCpuKernel() = default;\n  ~MathBinaryElementwiseCpuKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* tensor_z = ctx->Tensor4ArgNameAndIndex(\"z\", 0);\n    const T* x = tensor_x->dptr<T>();\n    const T* y = tensor_y->dptr<T>();\n    T* z = tensor_z->mut_dptr<T>();\n    int64_t n = tensor_x->shape_view().elem_cnt();\n    CHECK_LE(n, GetMaxVal<int32_t>() / 2);\n    ep::CpuStream* cpu_stream = ctx->stream()->As<ep::CpuStream>();\n\n    cpu_stream->ParallelFor(0, n, [x, y, z](int64_t begin, int64_t end) {\n      for (int64_t i = begin; i < end; i++) { z[i] = BinaryFunctor<T>::Forward(x[i], y[i]); }\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<template<typename> class BinaryFunctor, typename T>\nclass MathBinaryElementwiseXGradCpuKernel final : public user_op::OpKernel {\n public:\n  MathBinaryElementwiseXGradCpuKernel() = default;\n  ~MathBinaryElementwiseXGradCpuKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const user_op::Tensor* tensor_dz = ctx->Tensor4ArgNameAndIndex(\"dz\", 0);\n    user_op::Tensor* tensor_dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const T* x = tensor_x->dptr<T>();\n    const T* y = tensor_y->dptr<T>();\n    const T* dz = tensor_dz->dptr<T>();\n    T* dx = tensor_dx->mut_dptr<T>();\n    int64_t n = tensor_x->shape_view().elem_cnt();\n    CHECK_LE(n, GetMaxVal<int32_t>() / 2);\n    for (int32_t i = 0; i < n; ++i) { dx[i] = BinaryFunctor<T>::BackwardXGrad(x[i], y[i], dz[i]); }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<template<typename> class BinaryFunctor, typename T>\nclass MathBinaryElementwiseYGradCpuKernel final : public user_op::OpKernel {\n public:\n  MathBinaryElementwiseYGradCpuKernel() = default;\n  ~MathBinaryElementwiseYGradCpuKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const 
user_op::Tensor* tensor_dz = ctx->Tensor4ArgNameAndIndex(\"dz\", 0);\n    user_op::Tensor* tensor_dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n\n    const T* x = tensor_x->dptr<T>();\n    const T* y = tensor_y->dptr<T>();\n    const T* dz = tensor_dz->dptr<T>();\n    T* dy = tensor_dy->mut_dptr<T>();\n    int64_t n = tensor_x->shape_view().elem_cnt();\n    CHECK_LE(n, GetMaxVal<int32_t>() / 2);\n    for (int32_t i = 0; i < n; ++i) { dy[i] = BinaryFunctor<T>::BackwardYGrad(x[i], y[i], dz[i]); }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_MATH_BINARY_ELEMENTWISE_CPU_KERNEL_AND_GRAD(math_type_pair, data_type_pair)    \\\n  REGISTER_USER_KERNEL(OF_PP_PAIR_FIRST(math_type_pair))                                        \\\n      .SetCreateFn<                                                                             \\\n          MathBinaryElementwiseCpuKernel<OF_PP_CAT(OF_PP_PAIR_SECOND(math_type_pair), Functor), \\\n                                         OF_PP_PAIR_FIRST(data_type_pair)>>()                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                           \\\n                       && (user_op::HobDataType(\"x\", 0) == OF_PP_PAIR_SECOND(data_type_pair))); \\\n                                                                                                \\\n  REGISTER_USER_KERNEL((std::string(\"\") + OF_PP_PAIR_FIRST(math_type_pair) + \"_x_grad\"))        \\\n      .SetCreateFn<MathBinaryElementwiseXGradCpuKernel<                                         \\\n          OF_PP_CAT(OF_PP_PAIR_SECOND(math_type_pair), Functor),                                \\\n          OF_PP_PAIR_FIRST(data_type_pair)>>()                                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                           \\\n                       && (user_op::HobDataType(\"x\", 0) == OF_PP_PAIR_SECOND(data_type_pair))); \\\n  REGISTER_USER_KERNEL((std::string(\"\") + OF_PP_PAIR_FIRST(math_type_pair) + \"_y_grad\"))        \\\n      .SetCreateFn<MathBinaryElementwiseYGradCpuKernel<                                         \\\n          OF_PP_CAT(OF_PP_PAIR_SECOND(math_type_pair), Functor),                                \\\n          OF_PP_PAIR_FIRST(data_type_pair)>>()                                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                           \\\n                       && (user_op::HobDataType(\"x\", 0) == OF_PP_PAIR_SECOND(data_type_pair)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MATH_BINARY_ELEMENTWISE_CPU_KERNEL_AND_GRAD,\n                                 MATH_BINARY_ELEMENTWISE_FUNC_SEQ, FLOATING_DATA_TYPE_SEQ)\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MATH_BINARY_ELEMENTWISE_CPU_KERNEL_AND_GRAD,\n                                 OF_PP_MAKE_TUPLE_SEQ(\"floordiv\", FloorDiv)\n                                     OF_PP_MAKE_TUPLE_SEQ(\"truncdiv\", TruncDiv),\n                                 INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
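The three CPU kernels in math_binary_elementwise_kernel.cpp all consume the same `BinaryFunctor<T>` contract: `Forward(x, y)`, `BackwardXGrad(x, y, dz)`, and `BackwardYGrad(x, y, dz)`, applied elementwise over `n` items. Below is a minimal standalone sketch of that contract; `MulFunctor` is illustrative only (the real functors live in math_binary_elementwise_func.h), and the serial loop is the degenerate single-chunk case of the `ParallelFor(begin, end)` split used above.

```cpp
#include <cassert>
#include <vector>

// Hypothetical functor mirroring the Forward/BackwardXGrad/BackwardYGrad
// contract that MathBinaryElementwise*CpuKernel expects; not OneFlow code.
template<typename T>
struct MulFunctor {
  static T Forward(T x, T y) { return x * y; }
  static T BackwardXGrad(T x, T y, T dz) { return y * dz; }  // d(x*y)/dx = y
  static T BackwardYGrad(T x, T y, T dz) { return x * dz; }  // d(x*y)/dy = x
};

int main() {
  std::vector<double> x{1.0, 2.0, 3.0}, y{4.0, 5.0, 6.0}, z(3), dx(3);
  const std::vector<double> dz{1.0, 1.0, 1.0};
  // Same elementwise loops the kernels run over [begin, end) chunks.
  for (size_t i = 0; i < x.size(); ++i) { z[i] = MulFunctor<double>::Forward(x[i], y[i]); }
  for (size_t i = 0; i < x.size(); ++i) {
    dx[i] = MulFunctor<double>::BackwardXGrad(x[i], y[i], dz[i]);
  }
  assert(z[1] == 10.0 && dx[1] == 5.0);
  return 0;
}
```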
  {
    "path": "oneflow/user/kernels/math_binary_elementwise_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/math_binary_elementwise_func.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<template<typename> class BinaryFunctor, typename T>\n__global__ void MathBinaryElementwiseForwardGpu(const int64_t n, const T* x, const T* y, T* z) {\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) { z[i] = BinaryFunctor<T>::Forward(x[i], y[i]); }\n}\n\ntemplate<template<typename> class BinaryFunctor, typename T>\n__global__ void MathBinaryElementwiseBackwardXGradGpu(const int64_t n, const T* x, const T* y,\n                                                      const T* dz, T* dx) {\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) {\n    dx[i] = BinaryFunctor<T>::BackwardXGrad(x[i], y[i], dz[i]);\n  }\n}\n\ntemplate<template<typename> class BinaryFunctor, typename T>\n__global__ void MathBinaryElementwiseBackwardYGradGpu(const int64_t n, const T* x, const T* y,\n                                                      const T* dz, T* dy) {\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, n) {\n    dy[i] = BinaryFunctor<T>::BackwardYGrad(x[i], y[i], dz[i]);\n  }\n}\n\n}  // namespace\n\ntemplate<template<typename> class BinaryFunctor, typename T>\nclass MathBinaryElementwiseGpuKernel final : public user_op::OpKernel {\n public:\n  MathBinaryElementwiseGpuKernel() = default;\n  ~MathBinaryElementwiseGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* tensor_z = ctx->Tensor4ArgNameAndIndex(\"z\", 0);\n    int64_t n = tensor_x->shape_view().elem_cnt();\n    if (n == 0) { return; }\n    MathBinaryElementwiseForwardGpu<BinaryFunctor, T>\n        <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n            n, tensor_x->dptr<T>(), tensor_y->dptr<T>(), tensor_z->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<template<typename> class BinaryFunctor, typename T>\nclass MathBinaryElementwiseXGradGpuKernel final : public user_op::OpKernel {\n public:\n  MathBinaryElementwiseXGradGpuKernel() = default;\n  ~MathBinaryElementwiseXGradGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const user_op::Tensor* tensor_dz = ctx->Tensor4ArgNameAndIndex(\"dz\", 0);\n    user_op::Tensor* tensor_dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    int64_t n = 
tensor_x->shape_view().elem_cnt();\n    if (n == 0) { return; }\n    MathBinaryElementwiseBackwardXGradGpu<BinaryFunctor, T>\n        <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n            n, tensor_x->dptr<T>(), tensor_y->dptr<T>(), tensor_dz->dptr<T>(),\n            tensor_dx->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<template<typename> class BinaryFunctor, typename T>\nclass MathBinaryElementwiseYGradGpuKernel final : public user_op::OpKernel {\n public:\n  MathBinaryElementwiseYGradGpuKernel() = default;\n  ~MathBinaryElementwiseYGradGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const user_op::Tensor* tensor_dz = ctx->Tensor4ArgNameAndIndex(\"dz\", 0);\n    user_op::Tensor* tensor_dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    int64_t n = tensor_x->shape_view().elem_cnt();\n    if (n == 0) { return; }\n    MathBinaryElementwiseBackwardYGradGpu<BinaryFunctor, T>\n        <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n            n, tensor_x->dptr<T>(), tensor_y->dptr<T>(), tensor_dz->dptr<T>(),\n            tensor_dy->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_MATH_BINARY_ELEMENTWISE_CUDA_KERNEL_AND_GRAD(math_type_pair, data_type_pair)   \\\n  REGISTER_USER_KERNEL(OF_PP_PAIR_FIRST(math_type_pair))                                        \\\n      .SetCreateFn<                                                                             \\\n          MathBinaryElementwiseGpuKernel<OF_PP_CAT(OF_PP_PAIR_SECOND(math_type_pair), Functor), \\\n                                         OF_PP_PAIR_FIRST(data_type_pair)>>()                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                          \\\n                       && (user_op::HobDataType(\"x\", 0) == OF_PP_PAIR_SECOND(data_type_pair))); \\\n                                                                                                \\\n  REGISTER_USER_KERNEL((std::string(\"\") + OF_PP_PAIR_FIRST(math_type_pair) + \"_x_grad\"))        \\\n      .SetCreateFn<MathBinaryElementwiseXGradGpuKernel<                                         \\\n          OF_PP_CAT(OF_PP_PAIR_SECOND(math_type_pair), Functor),                                \\\n          OF_PP_PAIR_FIRST(data_type_pair)>>()                                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                          \\\n                       && (user_op::HobDataType(\"x\", 0) == OF_PP_PAIR_SECOND(data_type_pair))); \\\n  REGISTER_USER_KERNEL((std::string(\"\") + OF_PP_PAIR_FIRST(math_type_pair) + \"_y_grad\"))        \\\n      .SetCreateFn<MathBinaryElementwiseYGradGpuKernel<                                         \\\n          OF_PP_CAT(OF_PP_PAIR_SECOND(math_type_pair), Functor),                                \\\n          OF_PP_PAIR_FIRST(data_type_pair)>>()                                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                          \\\n             
          && (user_op::HobDataType(\"x\", 0) == OF_PP_PAIR_SECOND(data_type_pair)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MATH_BINARY_ELEMENTWISE_CUDA_KERNEL_AND_GRAD,\n                                 MATH_BINARY_ELEMENTWISE_FUNC_SEQ, FLOATING_DATA_TYPE_SEQ)\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MATH_BINARY_ELEMENTWISE_CUDA_KERNEL_AND_GRAD,\n                                 OF_PP_MAKE_TUPLE_SEQ(\"floordiv\", FloorDiv)\n                                     OF_PP_MAKE_TUPLE_SEQ(\"truncdiv\", TruncDiv),\n                                 INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ)\n\ntemplate<template<typename> class BinaryFunctor>\nclass MathBinaryElementwiseGpuHalfKernel final : public user_op::OpKernel {\n public:\n  MathBinaryElementwiseGpuHalfKernel() = default;\n  ~MathBinaryElementwiseGpuHalfKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* tensor_z = ctx->Tensor4ArgNameAndIndex(\"z\", 0);\n    const half* x = reinterpret_cast<const half*>(tensor_x->dptr<float16>());\n    const half* y = reinterpret_cast<const half*>(tensor_y->dptr<float16>());\n    half* z = reinterpret_cast<half*>(tensor_z->mut_dptr<float16>());\n    int64_t n = tensor_x->shape_view().elem_cnt();\n    if (n == 0) { return; }\n    MathBinaryElementwiseForwardGpu<BinaryFunctor, half>\n        <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(n, x, y, z);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<template<typename> class BinaryFunctor>\nclass MathBinaryElementwiseXGradGpuHalfKernel final : public user_op::OpKernel {\n public:\n  MathBinaryElementwiseXGradGpuHalfKernel() = default;\n  ~MathBinaryElementwiseXGradGpuHalfKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const user_op::Tensor* tensor_dz = ctx->Tensor4ArgNameAndIndex(\"dz\", 0);\n    user_op::Tensor* tensor_dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const half* x = reinterpret_cast<const half*>(tensor_x->dptr<float16>());\n    const half* y = reinterpret_cast<const half*>(tensor_y->dptr<float16>());\n    const half* dz = reinterpret_cast<const half*>(tensor_dz->dptr<float16>());\n    half* dx = reinterpret_cast<half*>(tensor_dx->mut_dptr<float16>());\n    int64_t n = tensor_x->shape_view().elem_cnt();\n    if (n == 0) { return; }\n    MathBinaryElementwiseBackwardXGradGpu<BinaryFunctor, half>\n        <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(n, x, y, dz, dx);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<template<typename> class BinaryFunctor>\nclass MathBinaryElementwiseYGradGpuHalfKernel final : public user_op::OpKernel {\n public:\n  MathBinaryElementwiseYGradGpuHalfKernel() = default;\n  ~MathBinaryElementwiseYGradGpuHalfKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const 
user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const user_op::Tensor* tensor_dz = ctx->Tensor4ArgNameAndIndex(\"dz\", 0);\n    user_op::Tensor* tensor_dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n\n    const half* x = reinterpret_cast<const half*>(tensor_x->dptr<float16>());\n    const half* y = reinterpret_cast<const half*>(tensor_y->dptr<float16>());\n    const half* dz = reinterpret_cast<const half*>(tensor_dz->dptr<float16>());\n    half* dy = reinterpret_cast<half*>(tensor_dy->mut_dptr<float16>());\n    int64_t n = tensor_x->shape_view().elem_cnt();\n    if (n == 0) { return; }\n    MathBinaryElementwiseBackwardYGradGpu<BinaryFunctor, half>\n        <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(n, x, y, dz, dy);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_MATH_BINARY_ELEMENTWISE_CUDA_HALF_KERNEL_AND_GRAD(math_type_str,              \\\n                                                                   math_func_prefix)           \\\n  REGISTER_USER_KERNEL(math_type_str)                                                          \\\n      .SetCreateFn<MathBinaryElementwiseGpuHalfKernel<OF_PP_CAT(math_func_prefix, Functor)>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                         \\\n                       && (user_op::HobDataType(\"x\", 0) == DataType::kFloat16));               \\\n                                                                                               \\\n  REGISTER_USER_KERNEL((std::string(\"\") + math_type_str + \"_x_grad\"))                          \\\n      .SetCreateFn<                                                                            \\\n          MathBinaryElementwiseXGradGpuHalfKernel<OF_PP_CAT(math_func_prefix, Functor)>>()     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                         \\\n                       && (user_op::HobDataType(\"x\", 0) == DataType::kFloat16));               \\\n  REGISTER_USER_KERNEL((std::string(\"\") + math_type_str + \"_y_grad\"))                          \\\n      .SetCreateFn<                                                                            \\\n          MathBinaryElementwiseYGradGpuHalfKernel<OF_PP_CAT(math_func_prefix, Functor)>>()     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                         \\\n                       && (user_op::HobDataType(\"x\", 0) == DataType::kFloat16));\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_MATH_BINARY_ELEMENTWISE_CUDA_HALF_KERNEL_AND_GRAD,\n                     MATH_BINARY_ELEMENTWISE_FUNC_SEQ)\n\n}  // namespace oneflow\n"
  },
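The .cu kernels above lean on `CUDA_1D_KERNEL_LOOP_T` and `BlocksNum4ThreadsNum`, both defined elsewhere in the tree. Assuming they follow the usual grid-stride idiom, a forward kernel boils down to something like the sketch below; `ForwardSketch`, `BlocksFor`, and the 512/4096 constants are placeholders for illustration, not OneFlow's actual definitions.

```cuda
#include <cstdint>

// Grid-stride sketch: each thread starts at its global index and hops by the
// total thread count, so any launch shape covers all n elements.
template<typename T>
__global__ void ForwardSketch(const int64_t n, const T* x, const T* y, T* z) {
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    z[i] = x[i] + y[i];  // stands in for BinaryFunctor<T>::Forward(x[i], y[i])
  }
}

// Host-side launch shape in the same spirit as BlocksNum4ThreadsNum:
// ceil-divide n by the block size, capped so the grid stays bounded
// (safe because the kernel strides until it has covered all n elements).
inline int BlocksFor(int64_t n, int threads = 512, int max_blocks = 4096) {
  const int64_t blocks = (n + threads - 1) / threads;
  return static_cast<int>(blocks < max_blocks ? blocks : int64_t{max_blocks});
}
```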
  {
    "path": "oneflow/user/kernels/math_unary_elementwise_func.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_\n#define ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_\n\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/user/ops/math_unary_elementwise_seq.h\"\n#include \"oneflow/core/device/cuda_pseudo_half.h\"\n\n#if defined(__CUDACC__)\n\n#include <cuda_fp16.h>\n#define MATH_FUNC_F(name, x) name##f(x)\n#define MATH_FUNC_D(name, x) name(x)\n\n#else\n\n#include <cmath>\n#define MATH_FUNC_F(name, x) std::name(x)\n#define MATH_FUNC_D(name, x) std::name(x)\n\n#endif\n\nnamespace oneflow {\n\n#define DECLARE_UNARY_FUNCTOR(math_unary_elementwise_type, func_prefix) \\\n  template<typename T>                                                  \\\n  struct func_prefix##Functor;\n\nOF_PP_FOR_EACH_TUPLE(DECLARE_UNARY_FUNCTOR, MATH_UNARY_ELEMENTWISE_FUNC_SEQ)\n\ntemplate<typename T>\nstruct AbsFunctor {\n  static OF_DEVICE_FUNC T Forward(const T x) {\n    if (x == T(0))\n      return T(0);\n    else\n      return x < T(0) ? -x : x;\n  }\n\n  static OF_DEVICE_FUNC T Backward(const T x, const T dy) {\n    if (x == T(0))\n      return T(0);\n    else\n      return x < T(0) ? 
-dy : dy;\n  }\n};\n\ntemplate<typename T>\nstruct SignFunctor {\n  static OF_DEVICE_FUNC T Forward(const T x) { return (T(0) < x) - (x < T(0)); }\n\n  static OF_DEVICE_FUNC T Backward(const T x, const T dy) { return T(0); }\n};\n\ntemplate<>\nstruct RsqrtFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) {\n#if defined(__CUDACC__)\n    return rsqrtf(x);\n#else\n    return 1.0f / std::sqrt(x);\n#endif\n  }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * (-1.0f / (2.0f * MATH_FUNC_F(sqrt, x * x * x)));\n  }\n};\n\ntemplate<>\nstruct RsqrtFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) {\n#if defined(__CUDACC__)\n    return rsqrt(x);\n#else\n    return 1.0 / std::sqrt(x);\n#endif\n  }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * (-1.0 / (2.0 * MATH_FUNC_D(sqrt, x * x * x)));\n  }\n};\n\n// float version\n\ntemplate<>\nstruct AcosFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(acos, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * -RsqrtFunctor<float>::Forward(1.0f - x * x);\n  }\n};\n\ntemplate<>\nstruct AcoshFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(acosh, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * RsqrtFunctor<float>::Forward(x * x - 1.0f);\n  }\n};\n\ntemplate<>\nstruct AsinFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(asin, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * RsqrtFunctor<float>::Forward(1.0f - x * x);\n  }\n};\n\ntemplate<>\nstruct AsinhFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(asinh, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * RsqrtFunctor<float>::Forward(1.0f + x * x);\n  }\n};\n\ntemplate<>\nstruct AtanFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(atan, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * (1.0f / (1.0f + x * x));\n  }\n};\n\ntemplate<>\nstruct AtanhFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(atanh, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * (1.0f / (1.0f - x * x));\n  }\n};\n\ntemplate<>\nstruct NotEqualZeroFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return x != 0; }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }\n};\n\ntemplate<>\nstruct CeilFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(ceil, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }\n};\n\ntemplate<>\nstruct CosFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(cos, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * (-MATH_FUNC_F(sin, x));\n  }\n};\n\ntemplate<>\nstruct CoshFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(cosh, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * MATH_FUNC_F(sinh, x);\n  }\n};\n\ntemplate<>\nstruct 
ErfFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(erf, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * 2.0f * RsqrtFunctor<float>::Forward(M_PI) * expf(-x * x);\n  }\n};\n\ntemplate<>\nstruct ErfcFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(erfc, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * -2.0f * RsqrtFunctor<float>::Forward(M_PI) * expf(-x * x);\n  }\n};\n\ntemplate<>\nstruct ExpFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(exp, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * MATH_FUNC_F(exp, x);\n  }\n};\n\ntemplate<>\nstruct Expm1Functor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(expm1, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * MATH_FUNC_F(exp, x);\n  }\n};\n\ntemplate<>\nstruct FloorFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(floor, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }\n};\n\ntemplate<>\nstruct LgammaFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(lgamma, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    // TODO(chengcheng): return: dy * digamma(x)\n    assert(false);\n    return 0.0f;\n  }\n};\n\ntemplate<>\nstruct LogFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (1.0f / x); }\n};\n\ntemplate<>\nstruct Log2Functor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log2, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * (1.0f / (x * MATH_FUNC_F(log, 2.0f)));\n  }\n};\n\ntemplate<>\nstruct Log1pFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log1p, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * (1.0f / (x + 1.0f));\n  }\n};\n\ntemplate<>\nstruct LogSigmoidFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) {\n    return -MATH_FUNC_F(log, (1.0f + MATH_FUNC_F(exp, -x)));\n  }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * (1.0f / (MATH_FUNC_F(exp, x) + 1.0f));\n  }\n};\n\ntemplate<>\nstruct NegativeFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return -x; }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return -dy; }\n};\n\ntemplate<>\nstruct ReciprocalFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return 1.0f / x; }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * (-1.0f / (x * x));\n  }\n};\n\ntemplate<>\nstruct ReciprocalNoNanFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) {\n    if (fabsf(x) <= 0.0f) { return 0.0f; }\n    return 1.0f / x;\n  }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    if (fabsf(x) <= 0.0f) { return 0.0f; }\n    return dy * (-1.0f / (x * x));\n  }\n};\n\ntemplate<>\nstruct RintFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) 
{ return MATH_FUNC_F(rint, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }\n};\n\ntemplate<>\nstruct RoundFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(nearbyint, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }\n};\n\ntemplate<>\nstruct SigmoidFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) {\n    return 1.0f / (1.0f + MATH_FUNC_F(exp, -x));\n  }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    float y = 1.0f / (1.0f + MATH_FUNC_F(exp, -x));\n    return dy * (y * (1.0f - y));\n  }\n};\n\ntemplate<>\nstruct SinFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sin, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * MATH_FUNC_F(cos, x);\n  }\n};\n\ntemplate<>\nstruct SinhFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sinh, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * MATH_FUNC_F(cosh, x);\n  }\n};\n\ntemplate<>\nstruct SqrtFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sqrt, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * 0.5f / MATH_FUNC_F(sqrt, x);\n  }\n};\n\ntemplate<>\nstruct SquareFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return x * x; }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * 2.0f * x; }\n};\n\ntemplate<>\nstruct TanFunctor<float> {\n  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(tan, x); }\n\n  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {\n    return dy * (1.0f / (MATH_FUNC_F(cos, x) * MATH_FUNC_F(cos, x)));\n  }\n};\n\n// double version\n\ntemplate<>\nstruct AcosFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acos, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * -RsqrtFunctor<double>::Forward(1.0 - x * x);\n  }\n};\n\ntemplate<>\nstruct AcoshFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acosh, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * RsqrtFunctor<double>::Forward(x * x - 1.0);\n  }\n};\n\ntemplate<>\nstruct AsinFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(asin, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * RsqrtFunctor<double>::Forward(1.0 - x * x);\n  }\n};\n\ntemplate<>\nstruct AsinhFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(asinh, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * RsqrtFunctor<double>::Forward(1.0 + x * x);\n  }\n};\n\ntemplate<>\nstruct AtanFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(atan, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * (1.0 / (1.0 + x * x));\n  }\n};\n\ntemplate<>\nstruct AtanhFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(atanh, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double 
x, const double dy) {\n    return dy * (1.0 / (1.0 - x * x));\n  }\n};\n\ntemplate<>\nstruct NotEqualZeroFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return x != 0; }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }\n};\n\ntemplate<>\nstruct CeilFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(ceil, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }\n};\n\ntemplate<>\nstruct CosFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(cos, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * (-MATH_FUNC_D(sin, x));\n  }\n};\n\ntemplate<>\nstruct CoshFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(cosh, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * MATH_FUNC_D(sinh, x);\n  }\n};\n\ntemplate<>\nstruct ErfFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erf, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * 2.0 * RsqrtFunctor<double>::Forward(M_PI) * MATH_FUNC_D(exp, -x * x);\n  }\n};\n\ntemplate<>\nstruct ErfcFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erfc, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * -2.0 * RsqrtFunctor<double>::Forward(M_PI) * MATH_FUNC_D(exp, -x * x);\n  }\n};\n\ntemplate<>\nstruct ExpFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(exp, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * MATH_FUNC_D(exp, x);\n  }\n};\n\ntemplate<>\nstruct Expm1Functor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(expm1, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * MATH_FUNC_D(exp, x);\n  }\n};\n\ntemplate<>\nstruct FloorFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(floor, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }\n};\n\ntemplate<>\nstruct LgammaFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(lgamma, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    // TODO(chengcheng): return: dy * digamma(x)\n    assert(false);\n    return 0.0;\n  }\n};\n\ntemplate<>\nstruct LogFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (1.0 / x); }\n};\n\ntemplate<>\nstruct Log2Functor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log2, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * (1.0 / (x * MATH_FUNC_D(log, 2.0)));\n  }\n};\n\ntemplate<>\nstruct Log1pFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log1p, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * (1.0 / (x + 1.0));\n  }\n};\n\ntemplate<>\nstruct LogSigmoidFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const 
double x) {\n    return -MATH_FUNC_D(log, (1.0 + MATH_FUNC_D(exp, -x)));\n  }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * (1.0 / (MATH_FUNC_D(exp, x) + 1.0));\n  }\n};\n\ntemplate<>\nstruct NegativeFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return -x; }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return -dy; }\n};\n\ntemplate<>\nstruct ReciprocalFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return 1.0 / x; }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * (-1.0 / (x * x));\n  }\n};\n\ntemplate<>\nstruct ReciprocalNoNanFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) {\n    if (fabs(x) <= 0.0) { return 0.0; }\n    return 1.0 / x;\n  }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    if (fabs(x) <= 0.0) { return 0.0; }\n    return dy * (-1.0 / (x * x));\n  }\n};\n\ntemplate<>\nstruct RintFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(rint, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }\n};\n\ntemplate<>\nstruct RoundFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(nearbyint, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }\n};\n\ntemplate<>\nstruct SigmoidFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) {\n    return 1.0 / (1.0 + MATH_FUNC_D(exp, -x));\n  }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    double y = 1.0 / (1.0 + MATH_FUNC_D(exp, -x));\n    return dy * (y * (1.0 - y));\n  }\n};\n\ntemplate<>\nstruct SinFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sin, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * MATH_FUNC_D(cos, x);\n  }\n};\n\ntemplate<>\nstruct SinhFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sinh, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * MATH_FUNC_D(cosh, x);\n  }\n};\n\ntemplate<>\nstruct SqrtFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sqrt, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * (double)0.5 / MATH_FUNC_D(sqrt, x);\n  }\n};\n\ntemplate<>\nstruct SquareFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return x * x; }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * 2.0 * x; }\n};\n\ntemplate<>\nstruct TanFunctor<double> {\n  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(tan, x); }\n\n  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {\n    return dy * (1.0 / (MATH_FUNC_D(cos, x) * MATH_FUNC_D(cos, x)));\n  }\n};\n\n#if defined(__CUDACC__)\n// half version\n\n#define OF_HALF_FUNC __device__ __forceinline__\n\n#define MATH_FUNC_H(name, x) __float2half(name##f(__half2float(x)))\n#define HALF_VAL_HALF __float2half(0.5f)\n#define HALF_VAL_TWO __float2half(2.0f)\n#define HALF_VAL_2RSQRT_PI __float2half(1.1283791671f)\n\ntemplate<>\nstruct AbsFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) {\n    return __hlt(x, 
GetZeroVal<half>()) ? __hneg(x) : x;\n  }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hlt(x, GetZeroVal<half>()) ? __hneg(dy) : dy;\n  }\n};\n\ntemplate<>\nstruct AcosFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(acos, x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, __hneg(hrsqrt(__hsub(GetOneVal<half>(), __hmul(x, x)))));\n  }\n};\n\ntemplate<>\nstruct AcoshFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(acosh, x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, hrsqrt(__hsub(__hmul(x, x), GetOneVal<half>())));\n  }\n};\n\ntemplate<>\nstruct AsinFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(asin, x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, hrsqrt(__hsub(GetOneVal<half>(), __hmul(x, x))));\n  }\n};\n\ntemplate<>\nstruct AsinhFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(asinh, x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, hrsqrt(__hadd(GetOneVal<half>(), __hmul(x, x))));\n  }\n};\n\ntemplate<>\nstruct AtanFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(atan, x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, __hdiv(GetOneVal<half>(), __hadd(GetOneVal<half>(), __hmul(x, x))));\n  }\n};\n\ntemplate<>\nstruct AtanhFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(atanh, x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, __hdiv(GetOneVal<half>(), __hsub(GetOneVal<half>(), __hmul(x, x))));\n  }\n};\n\ntemplate<>\nstruct CeilFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return hceil(x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }\n};\n\ntemplate<>\nstruct NotEqualZeroFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return x != static_cast<half>(0.0); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }\n};\n\ntemplate<>\nstruct CosFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return hcos(x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, __hneg(hsin(x)));\n  }\n};\n\ntemplate<>\nstruct CoshFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(cosh, x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, MATH_FUNC_H(sinh, x));\n  }\n};\n\ntemplate<>\nstruct ErfFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(erf, x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, __hmul(HALF_VAL_2RSQRT_PI, hexp(__hmul(__hneg(x), x))));\n  }\n};\n\ntemplate<>\nstruct ErfcFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(erfc, x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, __hneg(__hmul(HALF_VAL_2RSQRT_PI, hexp(__hmul(__hneg(x), x)))));\n  }\n};\n\ntemplate<>\nstruct ExpFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return hexp(x); }\n\n  static OF_HALF_FUNC half 
Backward(const half x, const half dy) { return __hmul(dy, hexp(x)); }\n};\n\ntemplate<>\nstruct Expm1Functor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(expm1, x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hexp(x)); }\n};\n\ntemplate<>\nstruct FloorFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return hfloor(x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }\n};\n\ntemplate<>\nstruct LgammaFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(lgamma, x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    // TODO(chengcheng): return: dy * digamma(x)\n    assert(false);\n    return GetZeroVal<half>();\n  }\n};\n\ntemplate<>\nstruct LogFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return hlog(x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hrcp(x)); }\n};\n\ntemplate<>\nstruct Log2Functor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return hlog2(x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, hrcp(__hmul(x, hlog(HALF_VAL_TWO))));\n  }\n};\n\ntemplate<>\nstruct Log1pFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(log1p, x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, hrcp(__hadd(x, GetOneVal<half>())));\n  }\n};\n\ntemplate<>\nstruct LogSigmoidFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) {\n    return __hneg(hlog(__hadd(GetOneVal<half>(), hexp(__hneg(x)))));\n  }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, hrcp(__hadd(hexp(x), GetOneVal<half>())));\n  }\n};\n\ntemplate<>\nstruct NegativeFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return __hneg(x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hneg(dy); }\n};\n\ntemplate<>\nstruct ReciprocalFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return hrcp(x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, __hneg(hrcp(__hmul(x, x))));\n  }\n};\n\ntemplate<>\nstruct ReciprocalNoNanFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) {\n    if (__heq(GetZeroVal<half>(), x)) { return GetZeroVal<half>(); }\n    return hrcp(x);\n  }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    if (__heq(GetZeroVal<half>(), x)) { return GetZeroVal<half>(); }\n    return __hmul(dy, __hneg(hrcp(__hmul(x, x))));\n  }\n};\n\ntemplate<>\nstruct RintFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return hrint(x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }\n};\n\ntemplate<>\nstruct RoundFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(nearbyint, x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }\n};\n\ntemplate<>\nstruct RsqrtFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return hrsqrt(x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, __hneg(hrcp(__hmul(HALF_VAL_TWO, hsqrt(__hmul(x, __hmul(x, x)))))));\n  }\n};\n\ntemplate<>\nstruct SigmoidFunctor<half> {\n  static OF_HALF_FUNC 
half Forward(const half x) {\n    return hrcp(__hadd(GetOneVal<half>(), hexp(__hneg(x))));\n  }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    half y = hrcp(__hadd(GetOneVal<half>(), hexp(__hneg(x))));\n    return __hmul(dy, __hmul(y, __hsub(GetOneVal<half>(), y)));\n  }\n};\n\ntemplate<>\nstruct SignFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) {\n    if (__hgt(x, GetZeroVal<half>())) { return GetOneVal<half>(); }\n    if (__hlt(x, GetZeroVal<half>())) { return __hneg(GetOneVal<half>()); }\n    return GetZeroVal<half>();\n  }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }\n};\n\ntemplate<>\nstruct SinFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return hsin(x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hcos(x)); }\n};\n\ntemplate<>\nstruct SinhFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(sinh, x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, MATH_FUNC_H(cosh, x));\n  }\n};\n\ntemplate<>\nstruct SqrtFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return hsqrt(x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, __hdiv(HALF_VAL_HALF, hsqrt(x)));\n  }\n};\n\ntemplate<>\nstruct SquareFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return __hmul(x, x); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, __hmul(HALF_VAL_TWO, x));\n  }\n};\n\ntemplate<>\nstruct TanFunctor<half> {\n  static OF_HALF_FUNC half Forward(const half x) { return __hdiv(hsin(x), hcos(x)); }\n\n  static OF_HALF_FUNC half Backward(const half x, const half dy) {\n    return __hmul(dy, hrcp(__hmul(hcos(x), hcos(x))));\n  }\n};\n\n#endif\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_\n"
  },
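Each specialization in math_unary_elementwise_func.h pairs a `Forward` with an analytic `Backward` that returns `dy` times f'(x). A cheap way to sanity-check any such pair is a central finite difference. The check below reimplements the `SigmoidFunctor<float>` math locally so it compiles without OneFlow headers; the tolerance is loose to absorb float rounding.

```cpp
#include <cassert>
#include <cmath>

// Mirrors SigmoidFunctor<float> above: Forward is 1/(1+e^-x), Backward is
// dy * y * (1 - y). Validated against a central finite difference.
static float SigmoidForward(float x) { return 1.0f / (1.0f + std::exp(-x)); }
static float SigmoidBackward(float x, float dy) {
  const float y = SigmoidForward(x);
  return dy * (y * (1.0f - y));
}

int main() {
  const float x = 0.3f, dy = 1.0f, h = 1e-3f;
  const float analytic = SigmoidBackward(x, dy);
  // Central difference: (f(x+h) - f(x-h)) / 2h approximates f'(x) to O(h^2).
  const float numeric = dy * (SigmoidForward(x + h) - SigmoidForward(x - h)) / (2.0f * h);
  assert(std::fabs(analytic - numeric) < 1e-3f);
  return 0;
}
```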
  {
    "path": "oneflow/user/kernels/math_unary_elementwise_primitive_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/ep/include/primitive/binary_op.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h\"\n#include \"oneflow/user/kernels/elementwise_primitive_kernel.h\"\n\nnamespace oneflow {\n\n#define MATH_UNARY_ELEMENTWISE_PRIMITIVE_SEQ                                          \\\n  OF_PP_MAKE_TUPLE_SEQ(\"abs\", ep::primitive::UnaryOp::kAbs)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"acos\", ep::primitive::UnaryOp::kAcos)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"acosh\", ep::primitive::UnaryOp::kAcosh)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"asin\", ep::primitive::UnaryOp::kAsin)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"asinh\", ep::primitive::UnaryOp::kAsinh)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"atan\", ep::primitive::UnaryOp::kAtan)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"atanh\", ep::primitive::UnaryOp::kAtanh)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"ceil\", ep::primitive::UnaryOp::kCeil)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"cos\", ep::primitive::UnaryOp::kCos)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"cosh\", ep::primitive::UnaryOp::kCosh)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"digamma\", ep::primitive::UnaryOp::kDigamma)                   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"trigamma\", ep::primitive::UnaryOp::kTrigamma)                 \\\n  OF_PP_MAKE_TUPLE_SEQ(\"erf\", ep::primitive::UnaryOp::kErf)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"erfc\", ep::primitive::UnaryOp::kErfc)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"exp\", ep::primitive::UnaryOp::kExp)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"exp2\", ep::primitive::UnaryOp::kExp2)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"expm1\", ep::primitive::UnaryOp::kExpm1)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"floor\", ep::primitive::UnaryOp::kFloor)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"lgamma\", ep::primitive::UnaryOp::kLgamma)                     \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log\", ep::primitive::UnaryOp::kLog)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log2\", ep::primitive::UnaryOp::kLog2)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log10\", ep::primitive::UnaryOp::kLog10)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log1p\", ep::primitive::UnaryOp::kLog1p)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log_sigmoid\", ep::primitive::UnaryOp::kLogSigmoid)            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"negative\", ep::primitive::UnaryOp::kNegative)                 \\\n  OF_PP_MAKE_TUPLE_SEQ(\"reciprocal\", ep::primitive::UnaryOp::kReciprocal)             \\\n  OF_PP_MAKE_TUPLE_SEQ(\"reciprocal_no_nan\", ep::primitive::UnaryOp::kReciprocalNoNan) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"rint\", ep::primitive::UnaryOp::kRint)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"round\", 
ep::primitive::UnaryOp::kRound)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"rsqrt\", ep::primitive::UnaryOp::kRsqrt)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sigmoid\", ep::primitive::UnaryOp::kSigmoid)                   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sign\", ep::primitive::UnaryOp::kSign)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sin\", ep::primitive::UnaryOp::kSin)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sinh\", ep::primitive::UnaryOp::kSinh)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sqrt\", ep::primitive::UnaryOp::kSqrt)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"square\", ep::primitive::UnaryOp::kSquare)                     \\\n  OF_PP_MAKE_TUPLE_SEQ(\"tan\", ep::primitive::UnaryOp::kTan)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"not_equal_zero\", ep::primitive::UnaryOp::kNotEqualZero)       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"bitwise_not\", ep::primitive::UnaryOp::kBitwiseNot)\n\n#define MATH_UNARY_ELEMENTWISE_GRAD_WITH_DY_X_PRIMITIVE_SEQ                                     \\\n  OF_PP_MAKE_TUPLE_SEQ(\"abs_grad\", ep::primitive::BinaryOp::kAbsBackwardWithDyX)                \\\n  OF_PP_MAKE_TUPLE_SEQ(\"acos_grad\", ep::primitive::BinaryOp::kAcosBackwardWithDyX)              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"acosh_grad\", ep::primitive::BinaryOp::kAcoshBackwardWithDyX)            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"asin_grad\", ep::primitive::BinaryOp::kAsinBackwardWithDyX)              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"asinh_grad\", ep::primitive::BinaryOp::kAsinhBackwardWithDyX)            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"atan_grad\", ep::primitive::BinaryOp::kAtanBackwardWithDyX)              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"atanh_grad\", ep::primitive::BinaryOp::kAtanhBackwardWithDyX)            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"cos_grad\", ep::primitive::BinaryOp::kCosBackwardWithDyX)                \\\n  OF_PP_MAKE_TUPLE_SEQ(\"cosh_grad\", ep::primitive::BinaryOp::kCoshBackwardWithDyX)              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"digamma_grad\", ep::primitive::BinaryOp::kDigammaBackwardWithDyX)        \\\n  OF_PP_MAKE_TUPLE_SEQ(\"erf_grad\", ep::primitive::BinaryOp::kErfBackwardWithDyX)                \\\n  OF_PP_MAKE_TUPLE_SEQ(\"erfc_grad\", ep::primitive::BinaryOp::kErfcBackwardWithDyX)              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"exp_grad\", ep::primitive::BinaryOp::kExpBackwardWithDyX)                \\\n  OF_PP_MAKE_TUPLE_SEQ(\"exp2_grad\", ep::primitive::BinaryOp::kExp2BackwardWithDyX)              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"expm1_grad\", ep::primitive::BinaryOp::kExpm1BackwardWithDyX)            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log_grad\", ep::primitive::BinaryOp::kLogBackwardWithDyX)                \\\n  OF_PP_MAKE_TUPLE_SEQ(\"lgamma_grad\", ep::primitive::BinaryOp::kLgammaBackwardWithDyX)          \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log2_grad\", ep::primitive::BinaryOp::kLog2BackwardWithDyX)              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log10_grad\", ep::primitive::BinaryOp::kLog10BackwardWithDyX)            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log1p_grad\", ep::primitive::BinaryOp::kLog1pBackwardWithDyX)            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log_sigmoid_grad\", ep::primitive::BinaryOp::kLogSigmoidBackwardWithDyX) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"reciprocal_grad\", ep::primitive::BinaryOp::kReciprocalBackwardWithDyX)  \\\n  OF_PP_MAKE_TUPLE_SEQ(\"reciprocal_no_nan_grad\",                                                \\\n                       ep::primitive::BinaryOp::kReciprocalNoNanBackwardWithDyX)                \\\n  OF_PP_MAKE_TUPLE_SEQ(\"rsqrt_grad\", 
ep::primitive::BinaryOp::kRsqrtBackwardWithDyX)            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sin_grad\", ep::primitive::BinaryOp::kSinBackwardWithDyX)                \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sinh_grad\", ep::primitive::BinaryOp::kSinhBackwardWithDyX)              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sqrt_grad\", ep::primitive::BinaryOp::kSqrtBackwardWithDyX)              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"square_grad\", ep::primitive::BinaryOp::kSquareBackwardWithDyX)          \\\n  OF_PP_MAKE_TUPLE_SEQ(\"tan_grad\", ep::primitive::BinaryOp::kTanBackwardWithDyX)\n\n#define MATH_UNARY_ELEMENTWISE_GRAD_WITH_DY_Y_PRIMITIVE_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sigmoid_grad\", ep::primitive::BinaryOp::kSigmoidBackwardWithDyY)\n\n#define REGISTER_MATH_UNARY_PRIMITIVE_KERNEL(name, UnaryOp)                               \\\n  REGISTER_USER_KERNEL(name)                                                              \\\n      .SetCreateFn([]() {                                                                 \\\n        return user_op::NewOpKernel<UnaryPrimitiveKernel>(                                \\\n            \"y\", \"x\", [](user_op::KernelComputeContext* ctx) {                            \\\n              const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0);   \\\n              const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0);   \\\n              return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>( \\\n                  ctx->device_type(), UnaryOp, src->data_type(), dst->data_type());       \\\n            });                                                                           \\\n      })                                                                                  \\\n      .SetIsMatchedHob(UnaryPrimitiveExists(UnaryOp, \"y\", \"x\"));\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_PRIMITIVE_KERNEL, MATH_UNARY_ELEMENTWISE_PRIMITIVE_SEQ)\n\n#define REGISTER_MATH_UNARY_GRAD_PRIMITIVE_WITH_DY_X_KERNEL(name, BinaryOp)                     \\\n  REGISTER_USER_KERNEL(name)                                                                    \\\n      .SetCreateFn([]() {                                                                       \\\n        return user_op::NewOpKernel<                                                            \\\n            BinaryPrimitiveKernel>(\"dx\", \"dy\", \"x\", [](user_op::KernelComputeContext* ctx) {    \\\n          const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);            \\\n          const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);            \\\n          return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>( \\\n              ctx->device_type(), BinaryOp, src->data_type(), dst->data_type(),                 \\\n              1 /*max_num_dims*/);                                                              \\\n        });                                                                                     \\\n      })                                                                                        \\\n      .SetIsMatchedHob(BinaryPrimitiveExists(BinaryOp, \"dx\", \"dy\"));\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_GRAD_PRIMITIVE_WITH_DY_X_KERNEL,\n                     MATH_UNARY_ELEMENTWISE_GRAD_WITH_DY_X_PRIMITIVE_SEQ)\n\n#define REGISTER_MATH_UNARY_GRAD_PRIMITIVE_WITH_DY_Y_KERNEL(name, BinaryOp)                     \\\n  REGISTER_USER_KERNEL(name)                                                            
        \\\n      .SetCreateFn([]() {                                                                       \\\n        return user_op::NewOpKernel<                                                            \\\n            BinaryPrimitiveKernel>(\"dx\", \"dy\", \"y\", [](user_op::KernelComputeContext* ctx) {    \\\n          const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);            \\\n          const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);            \\\n          return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>( \\\n              ctx->device_type(), BinaryOp, src->data_type(), dst->data_type(),                 \\\n              1 /*max_num_dims*/);                                                              \\\n        });                                                                                     \\\n      })                                                                                        \\\n      .SetIsMatchedHob(BinaryPrimitiveExists(BinaryOp, \"dx\", \"dy\"));\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_GRAD_PRIMITIVE_WITH_DY_Y_KERNEL,\n                     MATH_UNARY_ELEMENTWISE_GRAD_WITH_DY_Y_PRIMITIVE_SEQ)\n\n}  // namespace oneflow\n"
  },
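Every `REGISTER_USER_KERNEL(name).SetCreateFn(...)` expansion in the file above ultimately stores a kernel factory under the op's name, to be looked up and invoked when that op is dispatched. Below is a distilled, hypothetical sketch of that registration pattern; the tiny registry and stand-in `Kernel` type are illustrative and deliberately much simpler than OneFlow's real user_op machinery.

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <map>
#include <memory>
#include <string>

// Stand-in for a registered kernel: one virtual entry point.
struct Kernel {
  virtual ~Kernel() = default;
  virtual double Run(double x) const = 0;
};

using Factory = std::function<std::unique_ptr<Kernel>()>;

// Function-local static avoids the static initialization order problem.
std::map<std::string, Factory>& Registry() {
  static std::map<std::string, Factory> registry;
  return registry;
}

struct SqrtKernel final : Kernel {
  double Run(double x) const override { return std::sqrt(x); }
};

// What one macro expansion boils down to: at static-init time, register a
// creator lambda under the op name.
const bool kSqrtRegistered = (Registry()["sqrt"] = [] {
  return std::unique_ptr<Kernel>(new SqrtKernel);
}, true);

int main() {
  auto kernel = Registry().at("sqrt")();  // look up by op name, then create
  assert(kernel->Run(9.0) == 3.0);
  return 0;
}
```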
  {
    "path": "oneflow/user/kernels/matmul_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/framework/config_def.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include \"oneflow/core/ep/include/primitive/batch_matmul.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_matmul.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) {\n  return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N;\n}\n\ntemplate<typename Context>\nep::primitive::BlasTransposeType GetBlasTransposeType(Context* ctx, const std::string& attr) {\n  return GetBlasTransposeType(ctx->template Attr<bool>(attr));\n}\n\nvoid InferMatmulMNK(const ShapeView& a_shape, const ShapeView& b_shape, const ShapeView& c_shape,\n                    ep::primitive::BlasTransposeType transpose_a,\n                    ep::primitive::BlasTransposeType transpose_b, size_t* m, size_t* n, size_t* k) {\n  const int64_t num_a_axes = a_shape.NumAxes();\n  CHECK_GE(num_a_axes, 2);\n  const int64_t num_b_axes = b_shape.NumAxes();\n  CHECK_GE(num_b_axes, 2);\n  const int64_t num_c_axes = c_shape.NumAxes();\n  CHECK_GE(num_c_axes, 2);\n  if (transpose_a == ep::primitive::BlasTransposeType::N) {\n    *m = a_shape.At(num_a_axes - 2);\n    *k = a_shape.At(num_a_axes - 1);\n  } else if (transpose_a == ep::primitive::BlasTransposeType::T) {\n    *m = a_shape.At(num_a_axes - 1);\n    *k = a_shape.At(num_a_axes - 2);\n  } else {\n    UNIMPLEMENTED();\n  }\n  if (transpose_b == ep::primitive::BlasTransposeType::N) {\n    CHECK_EQ(b_shape.At(num_b_axes - 2), *k);\n    *n = b_shape.At(num_b_axes - 1);\n  } else if (transpose_b == ep::primitive::BlasTransposeType::T) {\n    CHECK_EQ(b_shape.At(num_b_axes - 1), *k);\n    *n = b_shape.At(num_b_axes - 2);\n  } else {\n    UNIMPLEMENTED();\n  }\n  CHECK_EQ(c_shape.At(num_c_axes - 2), *m);\n  CHECK_EQ(c_shape.At(num_c_axes - 1), *n);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Memcpy> NewMemcpyPrimitive(Context* ctx) {\n  return ep::primitive::NewPrimitive<ep::primitive::MemcpyFactory>(\n      ctx->device_type(), ep::primitive::MemcpyKind::kDtoD);\n}\n\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitive(DeviceType device_type,\n                                                          DataType data_type, bool transpose_a,\n                                                          bool transpose_b) {\n  const auto trans_a = GetBlasTransposeType(transpose_a);\n  const auto trans_b = GetBlasTransposeType(transpose_b);\n  return ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(device_type, data_type, trans_a,\n                                                                   trans_b);\n}\n\ntemplate<typename 
Context>\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, ctx->template Attr<bool>(\"transpose_a\"),\n                            ctx->template Attr<bool>(\"transpose_b\"));\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::BatchMatmul> NewBatchMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->data_type();\n  const auto trans_a = GetBlasTransposeType(ctx, \"transpose_a\");\n  const auto trans_b = GetBlasTransposeType(ctx, \"transpose_b\");\n  return ep::primitive::NewPrimitive<ep::primitive::BatchMatmulFactory>(\n      ctx->device_type(), data_type, trans_a, trans_b);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::BroadcastMatmul> NewBroadcastMatmulPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->data_type();\n  const auto trans_a = GetBlasTransposeType(ctx, \"transpose_a\");\n  const auto trans_b = GetBlasTransposeType(ctx, \"transpose_b\");\n  const int64_t a_num_axes = ctx->TensorDesc4ArgNameAndIndex(\"a\", 0)->shape().NumAxes();\n  const int64_t b_num_axes = ctx->TensorDesc4ArgNameAndIndex(\"b\", 0)->shape().NumAxes();\n  const int64_t max_num_axes = std::max(a_num_axes, b_num_axes);\n  return ep::primitive::NewPrimitive<ep::primitive::BroadcastMatmulFactory>(\n      ctx->device_type(), data_type, trans_a, trans_b, max_num_axes);\n}\n\nauto MemcpyPrimitiveExists() {\n  return hob::make_custom(\"MemcpyPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewMemcpyPrimitive(&ctx).operator bool();\n  });\n}\n\nauto MatmulPrimitiveExists() {\n  return hob::make_custom(\"MatmulPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewMatmulPrimitive(&ctx).operator bool();\n  });\n}\n\nauto BatchMatmulPrimitiveExists() {\n  return hob::make_custom(\"BatchMatmulPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewBatchMatmulPrimitive(&ctx).operator bool();\n  });\n}\n\nauto BroadcastMatmulPrimitiveExists() {\n  return hob::make_custom(\"BroadcastMatmulPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewBroadcastMatmulPrimitive(&ctx).operator bool();\n                          });\n}\n\nclass MatmulKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  MatmulKernel() = default;\n  ~MatmulKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto trans_a = GetBlasTransposeType(ctx, \"transpose_a\");\n    const auto trans_b = GetBlasTransposeType(ctx, \"transpose_b\");\n    const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex(\"a\", 0);\n    CHECK_EQ(a->shape_view().NumAxes(), 2);\n    const DataType data_type = a->data_type();\n    const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex(\"b\", 0);\n    CHECK_EQ(b->shape_view().NumAxes(), 2);\n    CHECK_EQ(b->data_type(), data_type);\n\n    // elem_cnt() returns int64_t; keep the full width to avoid narrowing on large matrices.\n    const int64_t elem_cnt_a = a->shape_view().elem_cnt();\n    const int64_t elem_cnt_b = b->shape_view().elem_cnt();\n    CHECK_GE(elem_cnt_a, 0);\n    CHECK_GE(elem_cnt_b, 0);\n    if (elem_cnt_a == 0 || elem_cnt_b == 0) { return; }\n\n    
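// If the fused _add_to_output input is present, it is first copied into out and the matmul\n    // below runs with beta = 1.0, so the product accumulates on top of that copy.\n    user_op::Tensor* out = 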
ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(out->shape_view().NumAxes(), 2);\n    CHECK_EQ(out->data_type(), data_type);\n    size_t m = 0, n = 0, k = 0;\n    InferMatmulMNK(a->shape_view(), b->shape_view(), out->shape_view(), trans_a, trans_b, &m, &n,\n                   &k);\n    const double alpha = ctx->Attr<double>(\"alpha\");\n    double beta = 0.0;\n    if (ctx->has_input(\"_add_to_output\", 0)) {\n      const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n      CHECK_EQ(add_to_output->data_type(), data_type);\n      CHECK_EQ(add_to_output->shape_view(), out->shape_view());\n      auto memcpy = NewMemcpyPrimitive(ctx);\n      CHECK(memcpy);\n      memcpy->Launch(ctx->stream(), out->mut_dptr(), add_to_output->dptr(),\n                     add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(data_type));\n      beta = 1.0;\n    }\n    auto matmul = NewMatmulPrimitive(ctx);\n    CHECK(matmul);\n    matmul->Launch(ctx->stream(), m, n, k, alpha, a->dptr(), b->dptr(), beta, out->mut_dptr());\n  }\n};\n\nREGISTER_USER_KERNEL(\"matmul\")\n    .SetCreateFn<MatmulKernel>()\n    .SetIsMatchedHob(MemcpyPrimitiveExists() && MatmulPrimitiveExists())\n    .SetInplaceProposalFn([](const user_op::InferContext& ctx,\n                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      if (ctx.has_input(\"_add_to_output\", 0)) {\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"_add_to_output\", 0, true));\n      }\n      return Maybe<void>::Ok();\n    });\n\nclass BatchMatmulKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  BatchMatmulKernel() = default;\n  ~BatchMatmulKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto trans_a = GetBlasTransposeType(ctx, \"transpose_a\");\n    const auto trans_b = GetBlasTransposeType(ctx, \"transpose_b\");\n    const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex(\"a\", 0);\n    const DataType data_type = a->data_type();\n    const int64_t num_axes = a->shape_view().NumAxes();\n    CHECK_GT(num_axes, 2);\n    const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex(\"b\", 0);\n    CHECK_EQ(b->data_type(), data_type);\n    CHECK_EQ(b->shape_view().NumAxes(), num_axes);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(out->data_type(), data_type);\n    CHECK_EQ(out->shape_view().NumAxes(), num_axes);\n    size_t m = 0;\n    size_t n = 0;\n    size_t k = 0;\n    InferMatmulMNK(a->shape_view(), b->shape_view(), out->shape_view(), trans_a, trans_b, &m, &n,\n                   &k);\n    size_t batch_size = 1;\n    for (size_t i = 0; i < num_axes - 2; ++i) {\n      const int64_t dim_size = a->shape_view().At(i);\n      CHECK_GT(dim_size, 0);\n      CHECK_EQ(b->shape_view().At(i), dim_size);\n      CHECK_EQ(out->shape_view().At(i), dim_size);\n      batch_size *= dim_size;\n    }\n    const double alpha = ctx->Attr<double>(\"alpha\");\n    double beta = 0.0;\n    if (ctx->has_input(\"_add_to_output\", 0)) {\n      const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n      CHECK_EQ(add_to_output->data_type(), data_type);\n      CHECK_EQ(add_to_output->shape_view(), out->shape_view());\n      auto memcpy = NewMemcpyPrimitive(ctx);\n      CHECK(memcpy);\n      memcpy->Launch(ctx->stream(), out->mut_dptr(), 
add_to_output->dptr(),\n                     add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(data_type));\n      beta = 1.0;\n    }\n    auto batch_matmul = NewBatchMatmulPrimitive(ctx);\n    CHECK(batch_matmul);\n    batch_matmul->Launch(ctx->stream(), batch_size, m, n, k, alpha, a->dptr(), b->dptr(), beta,\n                         out->mut_dptr());\n  }\n};\n\nREGISTER_USER_KERNEL(\"batch_matmul\")\n    .SetCreateFn<BatchMatmulKernel>()\n    .SetIsMatchedHob(MemcpyPrimitiveExists() && BatchMatmulPrimitiveExists())\n    .SetInplaceProposalFn([](const user_op::InferContext& ctx,\n                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      if (ctx.has_input(\"_add_to_output\", 0)) {\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"_add_to_output\", 0, true));\n      }\n      return Maybe<void>::Ok();\n    });\n\nclass BroadcastMatmulKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  BroadcastMatmulKernel() = default;\n  ~BroadcastMatmulKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    double alpha = ctx->Attr<double>(\"alpha\");\n\n    const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex(\"a\", 0);\n    const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex(\"b\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    double beta = 0.0;\n    if (ctx->has_input(\"_add_to_output\", 0)) {\n      const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n      CHECK_EQ(add_to_output->shape_view(), out->shape_view());\n      auto memcpy = NewMemcpyPrimitive(ctx);\n      CHECK(memcpy);\n      memcpy->Launch(\n          ctx->stream(), out->mut_dptr(), add_to_output->dptr(),\n          add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));\n      beta = 1.0;\n    }\n\n    const int64_t a_num_axes = a->shape_view().NumAxes();\n    const int64_t b_num_axes = b->shape_view().NumAxes();\n    const int64_t out_num_axes = out->shape_view().NumAxes();\n    auto broadcast_matmul = NewBroadcastMatmulPrimitive(ctx);\n    CHECK(broadcast_matmul);\n    broadcast_matmul->Launch(ctx->stream(), alpha, a_num_axes, a->shape_view().ptr(), a->dptr(),\n                             b_num_axes, b->shape_view().ptr(), b->dptr(), beta, out_num_axes,\n                             out->shape_view().ptr(), out->mut_dptr());\n  }\n};\n\nREGISTER_USER_KERNEL(\"broadcast_matmul\")\n    .SetCreateFn<BroadcastMatmulKernel>()\n    .SetIsMatchedHob(MemcpyPrimitiveExists() && BroadcastMatmulPrimitiveExists())\n    .SetInplaceProposalFn([](const user_op::InferContext& ctx,\n                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      if (ctx.has_input(\"_add_to_output\", 0)) {\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"_add_to_output\", 0, true));\n      }\n      return Maybe<void>::Ok();\n    });\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitiveForBroadcastMatmulGradB(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, true, false);\n}\n\nclass BroadcastMatmulGradBKernel final : public user_op::OpKernel,\n                                         public user_op::CudaGraphSupport {\n 
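// grad_b of broadcast_matmul: the leading (batch) axes of a and b are flattened into\n // one reduction axis k, so out(m, n) = a(k, m)^T * b(k, n) below.\n 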
public:\n  BroadcastMatmulGradBKernel() = default;\n  ~BroadcastMatmulGradBKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    double alpha = ctx->Attr<double>(\"alpha\");\n    const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex(\"a\", 0);\n    const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex(\"b\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    double beta = 0.0;\n    if (ctx->has_input(\"_add_to_output\", 0)) {\n      const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n      CHECK_EQ(add_to_output->shape_view(), out->shape_view());\n      auto memcpy = NewMemcpyPrimitive(ctx);\n      CHECK(memcpy);\n      memcpy->Launch(\n          ctx->stream(), out->mut_dptr(), add_to_output->dptr(),\n          add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));\n      beta = 1.0;\n    }\n\n    CHECK_EQ(a->shape_view().NumAxes(), b->shape_view().NumAxes());\n    int64_t k = a->shape_view().Count(0, a->shape_view().NumAxes() - 1);\n    CHECK_EQ(b->shape_view().Count(0, b->shape_view().NumAxes() - 1), k);\n    int64_t m = a->shape_view().At(a->shape_view().NumAxes() - 1);\n    int64_t n = b->shape_view().At(b->shape_view().NumAxes() - 1);\n    auto matmul = NewMatmulPrimitiveForBroadcastMatmulGradB(ctx);\n    CHECK(matmul);\n    matmul->Launch(ctx->stream(), m, n, k, alpha, a->dptr(), b->dptr(), beta, out->mut_dptr());\n  }\n};\n\nauto PrimitiveExistsForBroadcastMatmulGradB() {\n  return hob::make_custom(\"PrimitiveExistsForBroadcastMatmulGradB\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewMatmulPrimitiveForBroadcastMatmulGradB(&ctx).operator bool();\n                          });\n}\n\nREGISTER_USER_KERNEL(\"broadcast_matmul_grad_b\")\n    .SetCreateFn<BroadcastMatmulGradBKernel>()\n    .SetIsMatchedHob(MemcpyPrimitiveExists() && PrimitiveExistsForBroadcastMatmulGradB())\n    .SetInplaceProposalFn([](const user_op::InferContext& ctx,\n                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      if (ctx.has_input(\"_add_to_output\", 0)) {\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"_add_to_output\", 0, true));\n      }\n      return Maybe<void>::Ok();\n    });\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/matrix_vector_product_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/framework/config_def.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/memcpy.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) {\n  return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N;\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Memcpy> NewMemcpyPrimitive(Context* ctx) {\n  return ep::primitive::NewPrimitive<ep::primitive::MemcpyFactory>(\n      ctx->device_type(), ep::primitive::MemcpyKind::kDtoD);\n}\n\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitive(DeviceType device_type,\n                                                          DataType data_type, bool transpose_a,\n                                                          bool transpose_b) {\n  const auto trans_a = GetBlasTransposeType(transpose_a);\n  const auto trans_b = GetBlasTransposeType(transpose_b);\n  return ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(device_type, data_type, trans_a,\n                                                                   trans_b);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewMatrixVectorProductPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, false, false);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewMatrixVectorProductGradAPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, false, true);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewMatrixVectorProductGradBPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, true, false);\n}\n\nauto MatrixVectorProductPrimitiveExists() {\n  return hob::make_custom(\"NewMatrixVectorProductPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewMatrixVectorProductPrimitive(&ctx).operator bool();\n                          });\n}\n\nauto MatrixVectorProductGradAPrimitiveExists() {\n  return hob::make_custom(\"NewMatrixVectorProductGradAPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewMatrixVectorProductGradAPrimitive(&ctx).operator bool();\n                          });\n}\n\nauto MatrixVectorProductGradBPrimitiveExists() {\n  return 
hob::make_custom(\"NewMatrixVectorProductGradBPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewMatrixVectorProductGradBPrimitive(&ctx).operator bool();\n                          });\n}\n\nclass MatrixVectorProductKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  MatrixVectorProductKernel() = default;\n  ~MatrixVectorProductKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    /*\n    A(m, k) matmul B(k) -> (m, k) matmul (k, 1) -> (m, 1) -> (m)\n    */\n    const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex(\"a\", 0);\n    CHECK_EQ(a->shape_view().NumAxes(), 2) << \"A Numdims should be equal to 2. \";\n    const DataType data_type = a->data_type();\n    const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex(\"b\", 0);\n    CHECK_EQ(b->shape_view().NumAxes(), 1) << \"B Numdims should be equal to 1. \";\n    CHECK_EQ(b->data_type(), data_type) << \"Matrix A Datatype should be equal to Vector B\";\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(out->shape_view().NumAxes(), 1) << \"Out Numdims should be equal to 1. \";\n    CHECK_EQ(out->data_type(), data_type) << \"Out Datatype should be equal to input's. \";\n    size_t m = a->shape_view().At(0);\n    size_t k = a->shape_view().At(1);\n    size_t n = 1;\n    const double alpha = 1.0;\n    double beta = 0.0;\n    auto matmul = NewMatrixVectorProductPrimitive(ctx);\n    CHECK(matmul);\n    matmul->Launch(ctx->stream(), m, n, k, alpha, a->dptr(), b->dptr(), beta, out->mut_dptr());\n  }\n};\n\nREGISTER_USER_KERNEL(\"matrix_vector_product\")\n    .SetCreateFn<MatrixVectorProductKernel>()\n    .SetIsMatchedHob(MatrixVectorProductPrimitiveExists());\n\nclass MatrixVectorProductGradAKernel final : public user_op::OpKernel,\n                                             public user_op::CudaGraphSupport {\n public:\n  MatrixVectorProductGradAKernel() = default;\n  ~MatrixVectorProductGradAKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    /*\n    A(m, k) matmul B(k) -> (m, k) matmul (k, 1) -> (m, 1) -> (m)\n    GradA = dy (m) matmul B(k) -> (m, 1) (k, 1)_transpose\n    */\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex(\"b\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    size_t m = dy->shape_view().At(0);\n    size_t k = 1;\n    size_t n = b->shape_view().At(0);\n    const double alpha = 1.0;\n    double beta = 0.0;\n    auto matmul = NewMatrixVectorProductGradAPrimitive(ctx);\n    CHECK(matmul);\n    matmul->Launch(ctx->stream(), m, n, k, alpha, dy->dptr(), b->dptr(), beta, dx->mut_dptr());\n  }\n};\n\nREGISTER_USER_KERNEL(\"matrix_vector_product_grad_a\")\n    .SetCreateFn<MatrixVectorProductGradAKernel>()\n    .SetIsMatchedHob(MatrixVectorProductGradAPrimitiveExists());\n\nclass MatrixVectorProductGradBKernel final : public user_op::OpKernel,\n                                             public user_op::CudaGraphSupport {\n public:\n  MatrixVectorProductGradBKernel() = default;\n  ~MatrixVectorProductGradBKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void 
Compute(user_op::KernelComputeContext* ctx) const override {\n    /*\n    A(m, k) matmul B(k) -> (m, k) matmul (k, 1) -> (m, 1) -> (m)\n    GradB = dy_transpose (1, m) matmul A(m, k)\n    */\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex(\"a\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    size_t m = 1;\n    size_t k = dy->shape_view().At(0);\n    size_t n = a->shape_view().At(1);\n    const double alpha = 1.0;\n    double beta = 0.0;\n    auto matmul = NewMatrixVectorProductGradBPrimitive(ctx);\n    CHECK(matmul);\n    matmul->Launch(ctx->stream(), m, n, k, alpha, dy->dptr(), a->dptr(), beta, dx->mut_dptr());\n  }\n};\n\nREGISTER_USER_KERNEL(\"matrix_vector_product_grad_b\")\n    .SetCreateFn<MatrixVectorProductGradBKernel>()\n    .SetIsMatchedHob(MatrixVectorProductGradBPrimitiveExists());\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/max_pool_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/max_pool_kernel_util.h\"\n\nnamespace oneflow {\n\nstruct PoolOpKernelCache final : public user_op::OpKernelCache {\n  MaxPoolParams3D params_3d;\n  explicit PoolOpKernelCache(const MaxPoolParams3D& params_3d) : params_3d(params_3d) {}\n  const MaxPoolParams3D& GetParams3D() const { return params_3d; }\n};\n\nstd::shared_ptr<PoolOpKernelCache> CreatePoolOpKernelCache(user_op::KernelCacheContext* ctx,\n                                                           const int32_t& dim) {\n  const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0)->shape();\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>(\"padding\");\n  const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n  const std::vector<int32_t>& stride = ctx->Attr<std::vector<int32_t>>(\"stride\");\n  const std::vector<int32_t>& dilation = ctx->Attr<std::vector<int32_t>>(\"dilation\");\n  const bool return_indices = ctx->Attr<bool>(\"return_indices\");\n  const bool ceil_mode = ctx->Attr<bool>(\"ceil_mode\");\n\n  MaxPoolParams3D params_3d = MaxPoolParams3D(dim, x_shape, data_format, padding, kernel_size,\n                                              stride, dilation, return_indices, ceil_mode);\n  std::shared_ptr<PoolOpKernelCache> cache(new PoolOpKernelCache(params_3d));\n  return cache;\n}\n\nnamespace {\n\ntemplate<typename T, typename IDX>\nvoid Maxpool2dForwardComputeCLast(const NdIndexOffsetHelper<IDX, 4>& index_helper, IDX elem_num,\n                                  const T* src, T* dest, int64_t* indice_ptr,\n                                  const int32_t padding_h, const int32_t padding_w,\n                                  const int32_t n_batch, const int32_t n_channel,\n                                  const int32_t x_height, const int32_t x_width,\n                                  const int32_t y_height, const int32_t y_width,\n                                  const int32_t kernel_size_h, const int32_t kernel_size_w,\n                                  const int32_t stride_h, const int32_t stride_w,\n                                  const int32_t dilation_h, const int32_t dilation_w) {\n  IDX n = 0, h = 0, w = 0, c = 0;\n  for (IDX num = 0; num < elem_num; ++num) {\n    index_helper.OffsetToNdIndex(num, n, h, w, c);\n\n    const IDX x_start_idx = n * x_height * x_width * n_channel;\n    const IDX y_start_idx = n * y_height * y_width * n_channel;\n    IDX hstart = h * stride_h - padding_h;\n    IDX wstart = w * stride_w - padding_w;\n    const IDX hend = (hstart + (kernel_size_h - 1) * dilation_h + 1) <= x_height\n                         ? (hstart + (kernel_size_h - 1) * dilation_h + 1)\n                         : x_height;\n    const IDX wend = (wstart + (kernel_size_w - 1) * dilation_w + 1) <= x_width\n                         ? 
(wstart + (kernel_size_w - 1) * dilation_w + 1)\n                         : x_width;\n\n    while (hstart < 0) { hstart += dilation_h; }\n    while (wstart < 0) { wstart += dilation_w; }\n    /* compute max value(src[src_idx]) in kernel box region, and save the value to dest[num] */\n    IDX max_index = hstart * x_width + wstart;\n    IDX src_idx = 0;\n    /* equal to -std::numeric_limits<T>::infinity(); */\n    T max_value = detail::numeric_limits<T>::lower_bound();\n\n    for (IDX i = hstart; i < hend; i += dilation_h) {\n      for (IDX j = wstart; j < wend; j += dilation_w) {\n        const IDX window_idx = i * x_width * n_channel + j * n_channel + c;\n        const IDX search_idx = x_start_idx + window_idx;\n        T val = src[search_idx];\n        if (val > max_value || detail::numerics<T>::isnan(val)) {\n          max_value = val;\n          max_index = window_idx;\n          src_idx = search_idx;\n        }\n      }\n    }\n    const IDX out_idx = y_start_idx + h * y_width * n_channel + w * n_channel + c;\n    dest[out_idx] = src[src_idx];\n    indice_ptr[out_idx] = max_index;\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename IDX>\nstruct PoolKernelUtil<DeviceType::kCPU, T, IDX> {\n  static void Maxpool1dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                               const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr,\n                               const MaxPoolParams3D& params_3d) {\n    Maxpool1dForwardCompute<T, IDX>(\n        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[2],\n        params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(4),\n        params_3d.pool_size_3d()[2], params_3d.stride_3d()[2], params_3d.dilation_3d()[2]);\n  }\n\n  static void Maxpool1dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                                const IDX elem_num, const T* src, T* dest,\n                                const int64_t* indice_ptr, const MaxPoolParams3D& params_3d) {\n    Maxpool1dBackwardCompute<T, IDX>(index_helper, elem_num, src, dest, indice_ptr,\n                                     params_3d.num_batch(), params_3d.num_channel(),\n                                     params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(4));\n  }\n\n  static void Maxpool2dForwardCFirst(ep::Stream* stream,\n                                     const NdIndexOffsetHelper<IDX, 3>& index_helper,\n                                     const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr,\n                                     const MaxPoolParams3D& params_3d) {\n    Maxpool2dForwardComputeCFirst<T, IDX>(\n        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[1],\n        params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),\n        params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[1],\n        params_3d.pool_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2],\n        params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]);\n  }\n\n  static void Maxpool2dBackwardCFirst(ep::Stream* stream,\n                                      const NdIndexOffsetHelper<IDX, 3>& index_helper,\n                                      const IDX elem_num, const T* src, T* dest,\n                                      const int64_t* indice_ptr, const MaxPoolParams3D& params_3d) {\n    Maxpool2dBackwardComputeCFirst<T, IDX>(\n        index_helper, elem_num, src, dest, 
indice_ptr, params_3d.num_batch(),\n        params_3d.num_channel(), params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),\n        params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4));\n  }\n\n  static void Maxpool2dForwardCLast(ep::Stream* stream,\n                                    const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                                    const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr,\n                                    const MaxPoolParams3D& params_3d) {\n    Maxpool2dForwardComputeCLast<T, IDX>(\n        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[1],\n        params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),\n        params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(3),\n        params_3d.GetYShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.dilation_3d()[1],\n        params_3d.dilation_3d()[2]);\n  }\n\n  static void Maxpool2dBackwardCLast(ep::Stream* stream,\n                                     const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                                     const IDX elem_num, const T* src, T* dest,\n                                     const int64_t* indice_ptr, const MaxPoolParams3D& params_3d) {\n    Maxpool2dBackwardComputeCLast<T, IDX>(\n        index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(),\n        params_3d.num_channel(), params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),\n        params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4));\n  }\n\n  static void Maxpool3dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                               const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr,\n                               const MaxPoolParams3D& params_3d) {\n    Maxpool3dForwardCompute<T, IDX>(\n        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[0],\n        params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(),\n        params_3d.num_channel(), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3),\n        params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[0], params_3d.pool_size_3d()[1],\n        params_3d.pool_size_3d()[2], params_3d.stride_3d()[0], params_3d.stride_3d()[1],\n        params_3d.stride_3d()[2], params_3d.dilation_3d()[0], params_3d.dilation_3d()[1],\n        params_3d.dilation_3d()[2]);\n  }\n\n  static void Maxpool3dBackward(ep::Stream* stream,\n                                const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                                const IDX elem_num, const T* src, T* dest,\n                                const int64_t* indice_ptr, const MaxPoolParams3D& params_3d) {\n    Maxpool3dBackwardCompute<T, IDX>(index_helper, elem_num, src, dest, indice_ptr,\n                                     params_3d.num_batch(), params_3d.num_channel(),\n                                     params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3),\n                                     params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(2),\n                                     params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4));\n  }\n};\n\n
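// Each kernel below picks its index type at runtime: 32-bit NdIndexOffsetHelper indexing when\n// the output element count fits in int32_t and 64-bit indexing otherwise.\ntemplate<DeviceType device_type, typename T>\nclass MaxPool1dKernel final : public user_op::OpKernel {\n public:\n  MaxPool1dKernel() = default;\n  ~MaxPool1dKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() 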
const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreatePoolOpKernelCache(ctx, 1);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex(\"indice\", 0);\n\n    const auto* pool_cache = dynamic_cast<const PoolOpKernelCache*>(cache);\n    const MaxPoolParams3D& params_3d = pool_cache->GetParams3D();\n\n    const int64_t elem_num = y->shape_view().elem_cnt();\n    const T* src = x->dptr<T>();\n    T* dest = y->mut_dptr<T>();\n    int64_t* indice_ptr = indice->mut_dptr<int64_t>();\n\n    DimVector y_vector(2);\n    y_vector.at(0) = y->shape_view().At(0) * y->shape_view().At(1);\n    y_vector.at(1) = y->shape_view().At(2);\n    if (elem_num < GetMaxVal<int32_t>()) {\n      NdIndexOffsetHelper<int32_t, 2> index_helper(y_vector.data());\n      PoolKernelUtil<device_type, T, int32_t>::Maxpool1dForward(\n          ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);\n    } else {\n      NdIndexOffsetHelper<int64_t, 2> index_helper(y_vector.data());\n      PoolKernelUtil<device_type, T, int64_t>::Maxpool1dForward(\n          ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);\n    }\n  }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass MaxPool1dGradKernel final : public user_op::OpKernel {\n public:\n  MaxPool1dGradKernel() = default;\n  ~MaxPool1dGradKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreatePoolOpKernelCache(ctx, 1);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex(\"indice\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const auto* pool_cache = dynamic_cast<const PoolOpKernelCache*>(cache);\n    const MaxPoolParams3D& params_3d = pool_cache->GetParams3D();\n\n    const int64_t elem_num = dy->shape_view().elem_cnt();\n    const T* src = dy->dptr<T>();\n    const int64_t* indice_ptr = indice->dptr<int64_t>();\n    T* dest = dx->mut_dptr<T>();\n    DimVector dy_vector(2);\n    dy_vector.at(0) = dy->shape_view().At(0) * dy->shape_view().At(1);\n    dy_vector.at(1) = dy->shape_view().At(2);\n    size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type());\n    Memset<device_type>(ctx->stream(), dest, 0, out_bytes_size);\n\n    if (elem_num < GetMaxVal<int32_t>()) {\n      NdIndexOffsetHelper<int32_t, 2> index_helper(dy_vector.data());\n      PoolKernelUtil<device_type, T, int32_t>::Maxpool1dBackward(\n          ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);\n    } else {\n      NdIndexOffsetHelper<int64_t, 2> index_helper(dy_vector.data());\n      PoolKernelUtil<device_type, T, int64_t>::Maxpool1dBackward(\n          ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);\n    }\n  
};\n};\n\ntemplate<DeviceType device_type, typename T>\nclass MaxPool2dKernel final : public user_op::OpKernel {\n public:\n  MaxPool2dKernel() = default;\n  ~MaxPool2dKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreatePoolOpKernelCache(ctx, 2);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex(\"indice\", 0);\n\n    const auto* pool_cache = dynamic_cast<const PoolOpKernelCache*>(cache);\n    const MaxPoolParams3D& params_3d = pool_cache->GetParams3D();\n\n    const int64_t elem_num = y->shape_view().elem_cnt();\n\n    const T* src = x->dptr<T>();\n    T* dest = y->mut_dptr<T>();\n    int64_t* indice_ptr = indice->mut_dptr<int64_t>();\n\n    const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n    if (data_format == \"channels_first\") {\n      DimVector y_vector(3);\n      y_vector.at(0) = y->shape_view().At(0) * y->shape_view().At(1);\n      y_vector.at(1) = y->shape_view().At(2);\n      y_vector.at(2) = y->shape_view().At(3);\n      if (elem_num < GetMaxVal<int32_t>()) {\n        NdIndexOffsetHelper<int32_t, 3> index_helper(y_vector.data());\n        PoolKernelUtil<device_type, T, int32_t>::Maxpool2dForwardCFirst(\n            ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);\n      } else {\n        NdIndexOffsetHelper<int64_t, 3> index_helper(y_vector.data());\n        PoolKernelUtil<device_type, T, int64_t>::Maxpool2dForwardCFirst(\n            ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);\n      }\n    } else if (data_format == \"channels_last\") {\n      DimVector y_vector;\n      y->shape_view().ToDimVector(&y_vector);\n      if (elem_num < GetMaxVal<int32_t>()) {\n        NdIndexOffsetHelper<int32_t, 4> index_helper(y_vector.data());\n        PoolKernelUtil<device_type, T, int32_t>::Maxpool2dForwardCLast(\n            ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);\n      } else {\n        NdIndexOffsetHelper<int64_t, 4> index_helper(y_vector.data());\n        PoolKernelUtil<device_type, T, int64_t>::Maxpool2dForwardCLast(\n            ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);\n      }\n    } else {\n      UNIMPLEMENTED() << \"Unsupported data_format\";\n    }\n  };\n};\n\ntemplate<DeviceType device_type, typename T>\nclass MaxPool2dGradKernel final : public user_op::OpKernel {\n public:\n  MaxPool2dGradKernel() = default;\n  ~MaxPool2dGradKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreatePoolOpKernelCache(ctx, 2);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex(\"indice\", 0);\n    user_op::Tensor* dx = 
ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const auto* pool_cache = dynamic_cast<const PoolOpKernelCache*>(cache);\n    const MaxPoolParams3D& params_3d = pool_cache->GetParams3D();\n\n    const int64_t elem_num = dy->shape_view().elem_cnt();\n    const T* src = dy->dptr<T>();\n    const int64_t* indice_ptr = indice->dptr<int64_t>();\n    T* dest = dx->mut_dptr<T>();\n\n    size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type());\n    Memset<device_type>(ctx->stream(), dest, 0, out_bytes_size);\n\n    const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n\n    if (data_format == \"channels_first\") {\n      DimVector dy_vector(3);\n      dy_vector.at(0) = dy->shape_view().At(0) * dy->shape_view().At(1);\n      dy_vector.at(1) = dy->shape_view().At(2);\n      dy_vector.at(2) = dy->shape_view().At(3);\n      if (elem_num < GetMaxVal<int32_t>()) {\n        NdIndexOffsetHelper<int32_t, 3> index_helper(dy_vector.data());\n        PoolKernelUtil<device_type, T, int32_t>::Maxpool2dBackwardCFirst(\n            ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);\n      } else {\n        NdIndexOffsetHelper<int64_t, 3> index_helper(dy_vector.data());\n        PoolKernelUtil<device_type, T, int64_t>::Maxpool2dBackwardCFirst(\n            ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);\n      }\n    } else if (data_format == \"channels_last\") {\n      DimVector dy_vector;\n      dy->shape_view().ToDimVector(&dy_vector);\n      if (elem_num < GetMaxVal<int32_t>()) {\n        NdIndexOffsetHelper<int32_t, 4> index_helper(dy_vector.data());\n        PoolKernelUtil<device_type, T, int32_t>::Maxpool2dBackwardCLast(\n            ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);\n      } else {\n        NdIndexOffsetHelper<int64_t, 4> index_helper(dy_vector.data());\n        PoolKernelUtil<device_type, T, int64_t>::Maxpool2dBackwardCLast(\n            ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);\n      }\n    } else {\n      UNIMPLEMENTED() << \"Unsupported data_format\";\n    }\n  };\n};\n\ntemplate<DeviceType device_type, typename T>\nclass MaxPool3dKernel final : public user_op::OpKernel {\n public:\n  MaxPool3dKernel() = default;\n  ~MaxPool3dKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreatePoolOpKernelCache(ctx, 3);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex(\"indice\", 0);\n\n    const auto* pool_cache = dynamic_cast<const PoolOpKernelCache*>(cache);\n    const MaxPoolParams3D& params_3d = pool_cache->GetParams3D();\n\n    const int64_t elem_num = y->shape_view().elem_cnt();\n    const T* src = x->dptr<T>();\n    T* dest = y->mut_dptr<T>();\n    int64_t* indice_ptr = indice->mut_dptr<int64_t>();\n\n    DimVector y_vector(4);\n    y_vector.at(0) = y->shape_view().At(0) * y->shape_view().At(1);\n    y_vector.at(1) = y->shape_view().At(2);\n    y_vector.at(2) = y->shape_view().At(3);\n    y_vector.at(3) = y->shape_view().At(4);\n\n    if (elem_num < 
GetMaxVal<int32_t>()) {\n      NdIndexOffsetHelper<int32_t, 4> index_helper(y_vector.data());\n      PoolKernelUtil<device_type, T, int32_t>::Maxpool3dForward(\n          ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);\n    } else {\n      NdIndexOffsetHelper<int64_t, 4> index_helper(y_vector.data());\n      PoolKernelUtil<device_type, T, int64_t>::Maxpool3dForward(\n          ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);\n    }\n  };\n};\n\ntemplate<DeviceType device_type, typename T>\nclass MaxPool3dGradKernel final : public user_op::OpKernel {\n public:\n  MaxPool3dGradKernel() = default;\n  ~MaxPool3dGradKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreatePoolOpKernelCache(ctx, 3);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex(\"indice\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const auto* pool_cache = dynamic_cast<const PoolOpKernelCache*>(cache);\n    const MaxPoolParams3D& params_3d = pool_cache->GetParams3D();\n\n    const int64_t elem_num = dy->shape_view().elem_cnt();\n    const T* src = dy->dptr<T>();\n    const int64_t* indice_ptr = indice->dptr<int64_t>();\n    T* dest = dx->mut_dptr<T>();\n\n    DimVector dy_vector(4);\n    dy_vector.at(0) = dy->shape_view().At(0) * dy->shape_view().At(1);\n    dy_vector.at(1) = dy->shape_view().At(2);\n    dy_vector.at(2) = dy->shape_view().At(3);\n    dy_vector.at(3) = dy->shape_view().At(4);\n\n    size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type());\n    Memset<device_type>(ctx->stream(), dest, 0, out_bytes_size);\n\n    if (elem_num < GetMaxVal<int32_t>()) {\n      NdIndexOffsetHelper<int32_t, 4> index_helper(dy_vector.data());\n      PoolKernelUtil<device_type, T, int32_t>::Maxpool3dBackward(\n          ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);\n    } else {\n      NdIndexOffsetHelper<int64_t, 4> index_helper(dy_vector.data());\n      PoolKernelUtil<device_type, T, int64_t>::Maxpool3dBackward(\n          ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);\n    }\n  };\n};\n\n#define REGISTER_POOL_KERNELS(device, dtype)                                            \\\n  REGISTER_USER_KERNEL(\"max_pool_1d\")                                                   \\\n      .SetCreateFn<MaxPool1dKernel<device, dtype>>()                                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"max_pool_1d_grad\")                                              \\\n      .SetCreateFn<MaxPool1dGradKernel<device, dtype>>()                                \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"max_pool_2d\")                                                   \\\n      
.SetCreateFn<MaxPool2dKernel<device, dtype>>()                                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"max_pool_2d_grad\")                                              \\\n      .SetCreateFn<MaxPool2dGradKernel<device, dtype>>()                                \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"max_pool_3d\")                                                   \\\n      .SetCreateFn<MaxPool3dKernel<device, dtype>>()                                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"max_pool_3d_grad\")                                              \\\n      .SetCreateFn<MaxPool3dGradKernel<device, dtype>>()                                \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\n#define REGISTER_POOL_WITH_DEVICE(device) \\\n  REGISTER_POOL_KERNELS(device, int32_t)  \\\n  REGISTER_POOL_KERNELS(device, float)    \\\n  REGISTER_POOL_KERNELS(device, double)\n\nREGISTER_POOL_WITH_DEVICE(DeviceType::kCPU)\n\n#ifdef WITH_CUDA\nREGISTER_POOL_WITH_DEVICE(DeviceType::kCUDA)\nREGISTER_POOL_KERNELS(DeviceType::kCUDA, half)\n#endif\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_POOL_KERNEL_UTIL, (DeviceType::kCPU),\n                                 POOL_DATA_TYPE_CPU_SEQ, POOL_IDX_DATA_TYPE_SEQ);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/max_pool_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cstdint>\n#ifdef WITH_CUDA\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/user/kernels/max_pool_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <cuda_fp16.h>\n\nnamespace oneflow {\nnamespace {\n\nconstexpr int kBlockSize = cuda::elementwise::kBlockSize;\n\nint GetMinThreadNum(int64_t elem_num) { return std::min<int64_t>(elem_num, kBlockSize); }\n\nint GetNumBlocks(int64_t elem_cnt) {\n  int num_blocks = 0;\n  OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(elem_cnt, &num_blocks));\n  return num_blocks;\n}\n\ntemplate<typename T, typename IDX>\n__device__ __inline__ void Maxpool2dForwardComputeCLast(\n    const NdIndexOffsetHelper<IDX, 4>& index_helper, IDX elem_num, const T* src, T* dest,\n    int64_t* indice_ptr, const int32_t padding_h, const int32_t padding_w, const int64_t n_batch,\n    const int64_t n_channel, const int64_t x_height, const int64_t x_width, const int64_t y_height,\n    const int64_t y_width, const int32_t kernel_size_h, const int32_t kernel_size_w,\n    const int32_t stride_h, const int32_t stride_w, const int32_t dilation_h,\n    const int32_t dilation_w) {\n  IDX n, h, w, c;\n  CUDA_1D_KERNEL_LOOP(num, elem_num) {\n    index_helper.OffsetToNdIndex(num, n, h, w, c);\n\n    const IDX x_start_idx = n * n_channel * x_width * x_height;\n    const IDX y_start_idx = n * n_channel * y_height * y_width;\n    IDX hstart = h * stride_h - padding_h;\n    IDX wstart = w * stride_w - padding_w;\n    const IDX hend = (hstart + (kernel_size_h - 1) * dilation_h + 1) <= x_height\n                         ? (hstart + (kernel_size_h - 1) * dilation_h + 1)\n                         : x_height;\n    const IDX wend = (wstart + (kernel_size_w - 1) * dilation_w + 1) <= x_width\n                         ? 
(wstart + (kernel_size_w - 1) * dilation_w + 1)\n                         : x_width;\n\n    while (hstart < 0) { hstart += dilation_h; }\n    while (wstart < 0) { wstart += dilation_w; }\n    /* compute max value(src[src_idx]) in kernel box region, and save the value to dest[num] */\n    IDX max_index = hstart * x_width + wstart;\n    IDX src_idx = 0;\n    /* equal to -std::numeric_limits<T>::infinity(); */\n    T max_value = detail::numeric_limits<T>::lower_bound();\n\n    for (IDX i = hstart; i < hend; i += dilation_h) {\n      for (IDX j = wstart; j < wend; j += dilation_w) {\n        const IDX window_idx = i * x_width * n_channel + j * n_channel + c;\n        const IDX search_idx = x_start_idx + window_idx;\n        T val = src[search_idx];\n        if (val > max_value || detail::numerics<T>::isnan(val)) {\n          max_value = val;\n          max_index = window_idx;\n          src_idx = search_idx;\n        }\n      }\n    }\n    const IDX out_idx = y_start_idx + h * y_width * n_channel + w * n_channel + c;\n    dest[out_idx] = src[src_idx];\n    indice_ptr[out_idx] = max_index;\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoCUDAMaxPool1dForward(const NdIndexOffsetHelper<IDX, 2> index_helper, IDX elem_num,\n                                const T* src, T* dest, int64_t* indice_ptr, int32_t padding_l,\n                                int32_t n_batch, int32_t n_channel, int32_t x_length,\n                                int32_t kernel_size_l, int32_t stride_l, int32_t dilation_l) {\n  Maxpool1dForwardCompute<T, IDX>(index_helper, elem_num, src, dest, indice_ptr, padding_l, n_batch,\n                                  n_channel, x_length, kernel_size_l, stride_l, dilation_l);\n};\n\ntemplate<typename T, typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoCUDAMaxPool2dForwardCFirst(const NdIndexOffsetHelper<IDX, 3> index_helper, IDX elem_num,\n                                      const T* src, T* dest, int64_t* indice_ptr, int32_t padding_h,\n                                      int32_t padding_w, int32_t n_batch, int32_t n_channel,\n                                      int32_t x_height, int32_t x_width, int32_t kernel_size_h,\n                                      int32_t kernel_size_w, int32_t stride_h, int32_t stride_w,\n                                      int32_t dilation_h, int32_t dilation_w) {\n  Maxpool2dForwardComputeCFirst<T, IDX>(\n      index_helper, elem_num, src, dest, indice_ptr, padding_h, padding_w, n_batch, n_channel,\n      x_height, x_width, kernel_size_h, kernel_size_w, stride_h, stride_w, dilation_h, dilation_w);\n};\n\ntemplate<typename T, typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoCUDAMaxPool2dForwardCLast(const NdIndexOffsetHelper<IDX, 4> index_helper, IDX elem_num,\n                                     const T* src, T* dest, int64_t* indice_ptr, int32_t padding_h,\n                                     int32_t padding_w, int32_t n_batch, int32_t n_channel,\n                                     int32_t x_height, int32_t x_width, int32_t y_height,\n                                     int32_t y_width, int32_t kernel_size_h, int32_t kernel_size_w,\n                                     int32_t stride_h, int32_t stride_w, int32_t dilation_h,\n                                     int32_t dilation_w) {\n  Maxpool2dForwardComputeCLast<T, IDX>(index_helper, elem_num, src, dest, indice_ptr, padding_h,\n                                       padding_w, 
n_batch, n_channel, x_height, x_width, y_height,\n                                       y_width, kernel_size_h, kernel_size_w, stride_h, stride_w,\n                                       dilation_h, dilation_w);\n};\n\ntemplate<typename T, typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoCUDAMaxPool3dForward(const NdIndexOffsetHelper<IDX, 4> index_helper, IDX elem_num,\n                                const T* src, T* dest, int64_t* indice_ptr, int32_t padding_t,\n                                int32_t padding_h, int32_t padding_w, int32_t n_batch,\n                                int32_t n_channel, int32_t x_time, int32_t x_height,\n                                int32_t x_width, int32_t kernel_size_t, int32_t kernel_size_h,\n                                int32_t kernel_size_w, int32_t stride_t, int32_t stride_h,\n                                int32_t stride_w, int32_t dilation_t, int32_t dilation_h,\n                                int32_t dilation_w) {\n  Maxpool3dForwardCompute<T, IDX>(index_helper, elem_num, src, dest, indice_ptr, padding_t,\n                                  padding_h, padding_w, n_batch, n_channel, x_time, x_height,\n                                  x_width, kernel_size_t, kernel_size_h, kernel_size_w, stride_t,\n                                  stride_h, stride_w, dilation_t, dilation_h, dilation_w);\n};\n\ntemplate<typename T, typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoCUDAMaxPool1dBackward(const NdIndexOffsetHelper<IDX, 2> index_helper, const IDX elem_num,\n                                 const T* src, T* dest, const int64_t* indice_ptr,\n                                 const int32_t n_batch, const int32_t n_channel,\n                                 const int32_t src_length, const int32_t dst_length) {\n  Maxpool1dBackwardCompute<T, IDX>(index_helper, elem_num, src, dest, indice_ptr, n_batch,\n                                   n_channel, src_length, dst_length);\n};\n\ntemplate<typename T, typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoCUDAMaxPool2dBackwardCFirst(const NdIndexOffsetHelper<IDX, 3> index_helper,\n                                       const IDX elem_num, const T* src, T* dest,\n                                       const int64_t* indice_ptr, const int32_t n_batch,\n                                       const int32_t n_channel, const int32_t src_height,\n                                       const int32_t src_width, const int32_t dst_height,\n                                       const int32_t dst_width) {\n  Maxpool2dBackwardComputeCFirst<T, IDX>(index_helper, elem_num, src, dest, indice_ptr, n_batch,\n                                         n_channel, src_height, src_width, dst_height, dst_width);\n};\n\ntemplate<typename T, typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoCUDAMaxPool2dBackwardCLast(const NdIndexOffsetHelper<IDX, 4> index_helper,\n                                      const IDX elem_num, const T* src, T* dest,\n                                      const int64_t* indice_ptr, const int32_t n_batch,\n                                      const int32_t n_channel, const int32_t src_height,\n                                      const int32_t src_width, const int32_t dst_height,\n                                      const int32_t dst_width) {\n  Maxpool2dBackwardComputeCLast<T, IDX>(index_helper, elem_num, src, dest, indice_ptr, n_batch,\n                                        n_channel, src_height, src_width, dst_height, 
dst_width);\n};\n\ntemplate<typename T, typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoCUDAMaxPool3dBackward(const NdIndexOffsetHelper<IDX, 4> index_helper, const IDX elem_num,\n                                 const T* src, T* dest, const int64_t* indice_ptr,\n                                 const int32_t n_batch, const int32_t n_channel,\n                                 const int32_t src_time, const int32_t src_height,\n                                 const int32_t src_width, const int32_t dst_time,\n                                 const int32_t dst_height, const int32_t dst_width) {\n  Maxpool3dBackwardCompute<T, IDX>(index_helper, elem_num, src, dest, indice_ptr, n_batch,\n                                   n_channel, src_time, src_height, src_width, dst_time, dst_height,\n                                   dst_width);\n};\n\ntemplate<typename T, typename IDX>\nstruct PoolKernelUtil<DeviceType::kCUDA, T, IDX> {\n  static void Maxpool1dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                               const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr,\n                               const MaxPoolParams3D& params_3d) {\n    DoCUDAMaxPool1dForward<T, IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                     stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[2],\n        params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(4),\n        params_3d.pool_size_3d()[2], params_3d.stride_3d()[2], params_3d.dilation_3d()[2]);\n  }\n\n  static void Maxpool1dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                                const IDX elem_num, const T* src, T* dest,\n                                const int64_t* indice_ptr, const MaxPoolParams3D& params_3d) {\n    DoCUDAMaxPool1dBackward<T, IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                      stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(),\n        params_3d.num_channel(), params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(4));\n  }\n\n  static void Maxpool2dForwardCFirst(ep::Stream* stream,\n                                     const NdIndexOffsetHelper<IDX, 3>& index_helper,\n                                     const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr,\n                                     const MaxPoolParams3D& params_3d) {\n    DoCUDAMaxPool2dForwardCFirst<T, IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                           stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[1],\n        params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),\n        params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[1],\n        params_3d.pool_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2],\n        params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]);\n  }\n\n  static void Maxpool2dBackwardCFirst(ep::Stream* stream,\n                                      const NdIndexOffsetHelper<IDX, 3>& index_helper,\n                                      const IDX elem_num, const T* src, T* dest,\n                                      const int64_t* indice_ptr, const MaxPoolParams3D& 
params_3d) {\n    DoCUDAMaxPool2dBackwardCFirst<T, IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                            stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(),\n        params_3d.num_channel(), params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),\n        params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4));\n  }\n\n  static void Maxpool2dForwardCLast(ep::Stream* stream,\n                                    const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                                    const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr,\n                                    const MaxPoolParams3D& params_3d) {\n    DoCUDAMaxPool2dForwardCLast<T, IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                          stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[1],\n        params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),\n        params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(3),\n        params_3d.GetYShape5D().At(4), params_3d.pool_size_3d()[1], params_3d.pool_size_3d()[2],\n        params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.dilation_3d()[1],\n        params_3d.dilation_3d()[2]);\n  }\n\n  static void Maxpool2dBackwardCLast(ep::Stream* stream,\n                                     const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                                     const IDX elem_num, const T* src, T* dest,\n                                     const int64_t* indice_ptr, const MaxPoolParams3D& params_3d) {\n    DoCUDAMaxPool2dBackwardCLast<T, IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                           stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(),\n        params_3d.num_channel(), params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),\n        params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4));\n  }\n\n  static void Maxpool3dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                               const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr,\n                               const MaxPoolParams3D& params_3d) {\n    DoCUDAMaxPool3dForward<T, IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                     stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[0],\n        params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(),\n        params_3d.num_channel(), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3),\n        params_3d.GetXShape5D().At(4), params_3d.pool_size_3d()[0], params_3d.pool_size_3d()[1],\n        params_3d.pool_size_3d()[2], params_3d.stride_3d()[0], params_3d.stride_3d()[1],\n        params_3d.stride_3d()[2], params_3d.dilation_3d()[0], params_3d.dilation_3d()[1],\n        params_3d.dilation_3d()[2]);\n  }\n\n  static void Maxpool3dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                                const IDX elem_num, const T* src, T* dest,\n                                const int64_t* indice_ptr, const MaxPoolParams3D& params_3d) {\n    
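/* Launch-shape note (a sketch of the mechanism, not part of the kernel contract): GetNumBlocks\n    (via cuda::elementwise) bounds the grid, and the XPU_1D_KERNEL_LOOP inside the kernel\n    grid-strides, so all elem_num elements are covered even when the grid is clamped. */\n    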
DoCUDAMaxPool3dBackward<T, IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                      stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(),\n        params_3d.num_channel(), params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3),\n        params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3),\n        params_3d.GetXShape5D().At(4));\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_POOL_KERNEL_UTIL, (DeviceType::kCUDA),\n                                 POOL_DATA_TYPE_CUDA_SEQ, POOL_IDX_DATA_TYPE_SEQ);\n\n}  // namespace oneflow\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/max_pool_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/utils/pool_util.h\"\n#include \"oneflow/user/kernels/max_pool_kernel_util.h\"\n\nnamespace oneflow {\n\nvoid GetWindowedOutputShape(int64_t input_size, int32_t filter_size, int32_t stride,\n                            int32_t padding, bool ceil_mode, int32_t dilation_rate,\n                            int64_t* output_ptr) {\n  int64_t output_size = (input_size + 2 * padding - dilation_rate * (filter_size - 1) - 1 + stride\n                         + (ceil_mode ? stride - 1 : 0))\n                        / stride;\n\n  if (ceil_mode) {\n    // ensure that the last pool starts inside the image\n    // needed to avoid problems in ceil mode\n    if ((output_size - 1) * stride >= input_size + padding) { --output_size; }\n  }\n  *output_ptr = output_size;\n}\n\nvoid Get3DOutputShape(const DimVector& in, const std::vector<int32_t>& pool_size,\n                      const std::vector<int32_t>& strides, const std::vector<int32_t>& padding,\n                      const bool ceil_mode, std::vector<int32_t> dilation_rate, DimVector* out) {\n  out->clear();\n  out->resize(3);\n  FOR_RANGE(size_t, i, 0, 3) {\n    int64_t* out_ptr = &(*out).at(i);\n    GetWindowedOutputShape(in.at(i), pool_size.at(i), strides.at(i), padding.at(i), ceil_mode,\n                           dilation_rate.at(i), out_ptr);\n  }\n}\n\nMaxPoolParams3D::MaxPoolParams3D(const int32_t dim, const ShapeView& x_shape,\n                                 const std::string& data_format,\n                                 const std::vector<int32_t>& padding,\n                                 const std::vector<int32_t>& kernel_size,\n                                 const std::vector<int32_t>& stride,\n                                 const std::vector<int32_t>& dilation, const bool return_indices,\n                                 const bool ceil_mode)\n    : dim_(dim),\n      data_format_(data_format),\n      padding_(Get3DVec<Get3DVecType::kPad>(padding, dim)),\n      pool_size_3d_(Get3DVec(kernel_size, dim)),\n      stride_3d_(Get3DVec(stride, dim)),\n      dilation_3d_(Get3DVec(dilation, dim)),\n      return_indices_(return_indices),\n      ceil_mode_(ceil_mode) {\n  x_3d_ = {GetInDim(x_shape, data_format, 0, dim), GetInDim(x_shape, data_format, 1, dim),\n           GetInDim(x_shape, data_format, 2, dim)};\n  Get3DOutputShape(x_3d_, pool_size_3d_, stride_3d_, padding_, ceil_mode_, dilation_3d_, &y_3d_);\n  if (data_format == \"channels_first\") {\n    channel_num_ = x_shape.At(1);\n  } else {\n    CHECK_EQ(data_format_, \"channels_last\")\n        << \"data_format must be 'channels_first' or 'channels_last'\";\n    channel_num_ = x_shape.At(x_shape.NumAxes() - 1);\n  }\n  batch_num_ = x_shape.At(0);\n}\n\nvoid MaxPoolParams3D::Reset(const ShapeView& x_shape) {\n  x_3d_ = {GetInDim(x_shape, data_format_, 0, dim_), GetInDim(x_shape, data_format_, 1, dim_),\n           GetInDim(x_shape, data_format_, 2, 
dim_)};\n  Get3DOutputShape(x_3d_, pool_size_3d_, stride_3d_, padding_, ceil_mode_, dilation_3d_, &y_3d_);\n}\n\nShape MaxPoolParams3D::GetYShape() const {\n  DimVector y_dim_vec;\n  if (dim_ == 1) {\n    y_dim_vec = {y_3d_.at(2)};\n  } else if (dim_ == 2) {\n    y_dim_vec = {y_3d_.at(1), y_3d_.at(2)};\n  } else if (dim_ == 3) {\n    y_dim_vec = {y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)};\n  } else {\n    UNIMPLEMENTED();\n  }\n  if (data_format_ == \"channels_first\") {\n    y_dim_vec.insert(y_dim_vec.begin(), channel_num_);\n  } else {\n    CHECK_EQ(data_format_, \"channels_last\")\n        << \"data_format must be 'channels_first' or 'channels_last'\";\n    y_dim_vec.insert(y_dim_vec.end(), channel_num_);\n  }\n  y_dim_vec.insert(y_dim_vec.begin(), batch_num_);\n  return Shape(y_dim_vec);\n}\n\nShape MaxPoolParams3D::GetXShape5D() const {\n  return Shape({batch_num_, channel_num_, x_3d_.at(0), x_3d_.at(1), x_3d_.at(2)});\n}\n\nShape MaxPoolParams3D::GetYShape5D() const {\n  return Shape({batch_num_, channel_num_, y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)});\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/max_pool_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_POOL_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_POOL_KERNEL_UTIL_H_\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/ndarray/xpu_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/operator/operator_util.h\"\n#include \"oneflow/core/kernel/util/numerics.cuh\"\n#include \"oneflow/core/kernel/util/numeric_limits.cuh\"\n#ifdef WITH_CUDA\n#include \"oneflow/core/cuda/atomic.cuh\"\n#endif  // WITH_CUDA\n\nnamespace oneflow {\n\n#define POOL_DATA_TYPE_SEQ                        \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)   \\\n  OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)\n\n#define POOL_IDX_DATA_TYPE_SEQ                    \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)\n\n#define POOL_DATA_TYPE_CPU_SEQ POOL_DATA_TYPE_SEQ\n#define POOL_DATA_TYPE_CUDA_SEQ POOL_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)\n\ntypedef small_vector<int64_t, SHAPE_MAX_AXIS_SIZE> FixedDimVector;\n\ntemplate<typename T>\nstruct DeviceAdd {\n  OF_DEVICE_FUNC static void Invoke(const T* x, T* y) {\n#if defined(__CUDA_ARCH__)\n    cuda::atomic::Add(y, *x);\n#else\n    *y += *x;\n#endif\n  };\n};\n\nclass MaxPoolParams3D {\n public:\n  MaxPoolParams3D(const int32_t dim, const ShapeView& x_shape, const std::string& data_format,\n                  const std::vector<int32_t>& padding, const std::vector<int32_t>& kernel_size,\n                  const std::vector<int32_t>& stride, const std::vector<int32_t>& dilation,\n                  const bool return_indices, const bool ceil_mode);\n  ~MaxPoolParams3D() = default;\n\n  const std::string& data_format() const { return data_format_; }\n  const std::vector<int32_t>& padding() const { return padding_; }\n  const std::vector<int32_t>& pool_size_3d() const { return pool_size_3d_; }\n  const std::vector<int32_t>& stride_3d() const { return stride_3d_; }\n  const std::vector<int32_t>& dilation_3d() const { return dilation_3d_; }\n  const bool& return_indices() const { return return_indices_; }\n  const bool& ceil_mode() const { return ceil_mode_; }\n  const int32_t& num_batch() const { return batch_num_; }\n  const int32_t& num_channel() const { return channel_num_; }\n\n  void Reset(const ShapeView& x_shape);\n  Shape GetYShape() const;\n  Shape GetXShape5D() const;\n  Shape GetYShape5D() const;\n\n private:\n  int32_t dim_;\n  FixedDimVector x_3d_;\n  FixedDimVector y_3d_;\n  std::string data_format_;\n  std::vector<int32_t> padding_;\n  std::vector<int32_t> pool_size_3d_;\n  std::vector<int32_t> stride_3d_;\n  std::vector<int32_t> dilation_3d_;\n  bool return_indices_;\n  bool ceil_mode_;\n  int32_t batch_num_;\n  int32_t channel_num_;\n};\n\ntemplate<DeviceType device_type, typename T, 
typename IDX>\nstruct PoolKernelUtil {\n  static void Maxpool1dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                               const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr,\n                               const MaxPoolParams3D& params_3d);\n\n  static void Maxpool1dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                                const IDX elem_num, const T* src, T* dest,\n                                const int64_t* indice_ptr, const MaxPoolParams3D& params_3d);\n\n  static void Maxpool2dForwardCFirst(ep::Stream* stream,\n                                     const NdIndexOffsetHelper<IDX, 3>& index_helper,\n                                     const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr,\n                                     const MaxPoolParams3D& params_3d);\n\n  static void Maxpool2dBackwardCFirst(ep::Stream* stream,\n                                      const NdIndexOffsetHelper<IDX, 3>& index_helper,\n                                      const IDX elem_num, const T* src, T* dest,\n                                      const int64_t* indice_ptr, const MaxPoolParams3D& params_3d);\n\n  static void Maxpool2dForwardCLast(ep::Stream* stream,\n                                    const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                                    const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr,\n                                    const MaxPoolParams3D& params_3d);\n\n  static void Maxpool2dBackwardCLast(ep::Stream* stream,\n                                     const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                                     const IDX elem_num, const T* src, T* dest,\n                                     const int64_t* indice_ptr, const MaxPoolParams3D& params_3d);\n\n  static void Maxpool3dForward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                               const IDX elem_num, const T* src, T* dest, int64_t* indice_ptr,\n                               const MaxPoolParams3D& params_3d);\n\n  static void Maxpool3dBackward(ep::Stream* stream, const NdIndexOffsetHelper<IDX, 4>& index_helper,\n                                const IDX elem_num, const T* src, T* dest,\n                                const int64_t* indice_ptr, const MaxPoolParams3D& params_3d);\n};\n\ntemplate<typename T, typename IDX>\nOF_DEVICE_FUNC void Maxpool1dForwardCompute(const NdIndexOffsetHelper<IDX, 2> index_helper,\n                                            IDX elem_num, const T* src, T* dest,\n                                            int64_t* indice_ptr, const int32_t padding_l,\n                                            const int32_t n_batch, const int32_t n_channel,\n                                            const int32_t x_length, const int32_t kernel_size_l,\n                                            const int32_t stride_l, const int32_t dilation_l) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, l;\n    index_helper.OffsetToNdIndex(num, n_c, l);\n\n    IDX lstart = l * stride_l - padding_l;\n    const IDX lend = (lstart + (kernel_size_l - 1) * dilation_l + 1) <= x_length\n                         ? 
(lstart + (kernel_size_l - 1) * dilation_l + 1)\n                         : x_length;\n\n    while (lstart < 0) { lstart += dilation_l; }\n\n    /* compute max value(src[src_idx]) in kernel box region, and save the value to dest[num] */\n    IDX max_index = lstart;\n\n    /* equal to -std::numeric_limits<T>::infinity(); */\n    T max_value = detail::numeric_limits<T>::lower_bound();\n    const T* data = src + n_c * x_length;\n    for (IDX idx = lstart; idx < lend; idx += dilation_l) {\n      const IDX window_idx = idx;\n      T val = data[window_idx];\n      if (val > max_value || detail::numerics<T>::isnan(val)) {\n        max_value = val;\n        max_index = idx;\n      }\n    }\n    dest[num] = max_value;\n    indice_ptr[num] = max_index;\n  }\n}\n\ntemplate<typename T, typename IDX>\nOF_DEVICE_FUNC void Maxpool1dBackwardCompute(const NdIndexOffsetHelper<IDX, 2> index_helper,\n                                             const IDX elem_num, const T* src, T* dest,\n                                             const int64_t* indice_ptr, const int32_t n_batch,\n                                             const int32_t n_channel, const int32_t src_length,\n                                             const int32_t dst_length) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, l;\n    index_helper.OffsetToNdIndex(num, n_c, l);\n\n    const IDX src_start = n_c * src_length;\n    const IDX dst_start = n_c * dst_length;\n    const IDX index = src_start + l;\n    const IDX max_index = dst_start + indice_ptr[index];\n    if (max_index != -1) {\n      /* update gradient, equals to dest[max_index] += src[index]; */\n      DeviceAdd<T>::Invoke(src + index, dest + max_index);\n    }\n  }\n}\n\ntemplate<typename T, typename IDX>\nOF_DEVICE_FUNC void Maxpool2dForwardComputeCFirst(\n    const NdIndexOffsetHelper<IDX, 3> index_helper, IDX elem_num, const T* src, T* dest,\n    int64_t* indice_ptr, const int32_t padding_h, const int32_t padding_w, const int32_t n_batch,\n    const int32_t n_channel, const int32_t x_height, const int32_t x_width,\n    const int32_t kernel_size_h, const int32_t kernel_size_w, const int32_t stride_h,\n    const int32_t stride_w, const int32_t dilation_h, const int32_t dilation_w) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, h, w;\n    index_helper.OffsetToNdIndex(num, n_c, h, w);\n    IDX hstart = h * stride_h - padding_h;\n    IDX wstart = w * stride_w - padding_w;\n    const IDX hend = (hstart + (kernel_size_h - 1) * dilation_h + 1) <= x_height\n                         ? (hstart + (kernel_size_h - 1) * dilation_h + 1)\n                         : x_height;\n    const IDX wend = (wstart + (kernel_size_w - 1) * dilation_w + 1) <= x_width\n                         ? 
(wstart + (kernel_size_w - 1) * dilation_w + 1)\n                         : x_width;\n    while (hstart < 0) { hstart += dilation_h; }\n    while (wstart < 0) { wstart += dilation_w; }\n    /* equivalent to -std::numeric_limits<T>::infinity() */\n    T max_value = detail::numeric_limits<T>::lower_bound();\n    /* compute the max value in the kernel window and save it to dest[num] */\n    IDX max_index = hstart * x_width + wstart;\n    const T* data = src + n_c * x_width * x_height;\n    for (IDX i = hstart; i < hend; i += dilation_h) {\n      for (IDX j = wstart; j < wend; j += dilation_w) {\n        const IDX window_idx = i * x_width + j;\n        T val = data[window_idx];\n        /* NOTE:\n        std::isnan(val) only supports a few data types, see:\n        https://en.cppreference.com/w/cpp/numeric/math/isnan\n        When this file is compiled with gcc/g++ 4.x, the following exception is thrown at runtime:\n\n        new_kernel_util.cu:24] Check failed: cudaMemcpyAsync(dst, src, sz, cudaMemcpyDefault,\n        ctx->cuda_stream() ) : unspecified launch failure (719)\n\n        Compiling with gcc/g++ 7.x works fine; the exact reason is still unknown.\n        */\n        if (val > max_value || detail::numerics<T>::isnan(val)) {\n          max_index = window_idx;\n          max_value = val;\n        }\n      }\n    }\n    dest[num] = max_value;\n    indice_ptr[num] = max_index;\n  }\n}\n\ntemplate<typename T, typename IDX>\nOF_DEVICE_FUNC void Maxpool2dBackwardComputeCFirst(\n    const NdIndexOffsetHelper<IDX, 3> index_helper, const IDX elem_num, const T* src, T* dest,\n    const int64_t* indice_ptr, const int32_t n_batch, const int32_t n_channel,\n    const int32_t src_height, const int32_t src_width, const int32_t dst_height,\n    const int32_t dst_width) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, h, w;\n    index_helper.OffsetToNdIndex(num, n_c, h, w);\n\n    const IDX src_start = n_c * src_height * src_width;\n    const IDX dst_start = n_c * dst_height * dst_width;\n    const IDX index = src_start + h * src_width + w;\n\n    const IDX max_index = dst_start + indice_ptr[index];\n    if (max_index != -1) {\n      /* update the gradient, equivalent to dest[max_index] += src[index]; */\n      DeviceAdd<T>::Invoke(src + index, dest + max_index);\n    }\n  }\n}\n\ntemplate<typename T, typename IDX>\nOF_DEVICE_FUNC void Maxpool2dBackwardComputeCLast(const NdIndexOffsetHelper<IDX, 4> index_helper,\n                                                  const IDX elem_num, const T* src, T* dest,\n                                                  const int64_t* indice_ptr, const int32_t n_batch,\n                                                  const int32_t n_channel, const int32_t src_height,\n                                                  const int32_t src_width, const int32_t dst_height,\n                                                  const int32_t dst_width) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n, h, w, c;\n    /* the index helper is expected to be built over the NHWC dy shape, so decompose in\n       (n, h, w, c) order */\n    index_helper.OffsetToNdIndex(num, n, h, w, c);\n    const IDX src_start = n * src_height * src_width * n_channel;\n    const IDX dst_start = n * dst_height * dst_width * n_channel;\n    /* NHWC indexing, matching the channels-last forward pass */\n    const IDX index = src_start + h * src_width * n_channel + w * n_channel + c;\n    const IDX max_index = dst_start + indice_ptr[index];\n    if (max_index != -1) {\n      /* update the gradient, equivalent to dest[max_index] += src[index]; */\n      DeviceAdd<T>::Invoke(src + index, dest + max_index);\n    }\n  }\n}\n\n
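/* Worked example of the dilated-window arithmetic used by the forward kernels in this file:\n   with stride 2, padding 2, kernel_size 3 and dilation 2, output index h = 0 gives\n   hstart = 0 * 2 - 2 = -2 and hend = min(hstart + (3 - 1) * 2 + 1, x_height) = min(3, x_height);\n   the while-loop then advances hstart to 0, so (assuming x_height >= 3) the window samples\n   rows i = 0 and i = 2. */\n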
template<typename T, typename IDX>\nOF_DEVICE_FUNC void Maxpool3dForwardCompute(\n    const NdIndexOffsetHelper<IDX, 4> index_helper, IDX elem_num, const T* src, T* dest,\n    int64_t* indice_ptr, const int32_t padding_t, const int32_t padding_h, const int32_t padding_w,\n    const int32_t n_batch, const int32_t n_channel, const int32_t x_time, const int32_t x_height,\n    const int32_t x_width, const int32_t kernel_size_t, const int32_t kernel_size_h,\n    const int32_t kernel_size_w, const int32_t stride_t, const int32_t stride_h,\n    const int32_t stride_w, const int32_t dilation_t, const int32_t dilation_h,\n    const int32_t dilation_w) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, t, h, w;\n    index_helper.OffsetToNdIndex(num, n_c, t, h, w);\n\n    IDX tstart = t * stride_t - padding_t;\n    IDX hstart = h * stride_h - padding_h;\n    IDX wstart = w * stride_w - padding_w;\n\n    const IDX tend_raw = tstart + (kernel_size_t - 1) * dilation_t + 1;\n    const IDX hend_raw = hstart + (kernel_size_h - 1) * dilation_h + 1;\n    const IDX wend_raw = wstart + (kernel_size_w - 1) * dilation_w + 1;\n    const IDX tend = tend_raw <= x_time ? tend_raw : x_time;\n    const IDX hend = hend_raw <= x_height ? hend_raw : x_height;\n    const IDX wend = wend_raw <= x_width ? wend_raw : x_width;\n\n    while (tstart < 0) { tstart += dilation_t; }\n    while (hstart < 0) { hstart += dilation_h; }\n    while (wstart < 0) { wstart += dilation_w; }\n\n    IDX max_index = tstart * x_height * x_width + hstart * x_width + wstart;\n    const T* data = src + n_c * x_time * x_width * x_height;\n    T max_value = detail::numeric_limits<T>::lower_bound();\n    for (IDX zi = tstart; zi < tend; zi += dilation_t) {\n      for (IDX i = hstart; i < hend; i += dilation_h) {\n        for (IDX j = wstart; j < wend; j += dilation_w) {\n          const IDX window_idx = zi * x_height * x_width + i * x_width + j;\n          T val = data[window_idx];\n          if (val > max_value || detail::numerics<T>::isnan(val)) {\n            max_value = val;\n            max_index = window_idx;\n          }\n        }\n      }\n    }\n    /* write the output and the argmax location once, after the whole window has been scanned */\n    dest[num] = max_value;\n    indice_ptr[num] = max_index;\n  }\n}\n\ntemplate<typename T, typename IDX>\nOF_DEVICE_FUNC void Maxpool3dBackwardCompute(const NdIndexOffsetHelper<IDX, 4> index_helper,\n                                             const IDX elem_num, const T* src, T* dest,\n                                             const int64_t* indice_ptr, const int32_t n_batch,\n                                             const int32_t n_channel, const int32_t src_time,\n                                             const int32_t src_height, const int32_t src_width,\n                                             const int32_t dst_time, const int32_t dst_height,\n                                             const int32_t dst_width) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX n_c, t, h, w;\n    index_helper.OffsetToNdIndex(num, n_c, t, h, w);\n\n    const IDX src_start = n_c * src_time * src_height * src_width;\n    const IDX dst_start = n_c * dst_time * dst_height * dst_width;\n    const IDX index = src_start + t * src_height * src_width + h * src_width + w;\n    const IDX max_index = dst_start + indice_ptr[index];\n\n    if (max_index != -1) { DeviceAdd<T>::Invoke(src + index, dest + max_index); }\n  }\n}\n\n
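/* Example expansion of the macro below (illustrative):\n   INSTANTIATE_POOL_KERNEL_UTIL(DeviceType::kCPU, (float, DataType::kFloat),\n                                (int32_t, DataType::kInt32))\n   produces `template struct PoolKernelUtil<DeviceType::kCPU, float, int32_t>;`. */\n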
#define INSTANTIATE_POOL_KERNEL_UTIL(device_type_v, dtype_pair, index_dtype_pair) \\\n  template struct PoolKernelUtil<device_type_v, OF_PP_PAIR_FIRST(dtype_pair),     \\\n                                 OF_PP_PAIR_FIRST(index_dtype_pair)>;\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_POOL_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/max_unpool_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"fmt/core.h\"\n#include \"oneflow/core/common/bfloat16.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/user/kernels/max_unpool_kernel_util.h\"\n\nnamespace oneflow {\nnamespace {\n\ntemplate<typename T, typename IDX, typename F>\nvoid MaxUnpoolNdForwardOrBackward(const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                                  const IDX elem_num, const int64_t* indice_ptr,\n                                  const int64_t hwd_size, const int64_t out_elem_num, const F& f) {\n  XPU_1D_KERNEL_LOOP(num, elem_num) {\n    IDX bc_idx, hwd_idx;\n    index_helper.OffsetToNdIndex(num, bc_idx, hwd_idx);\n    IDX idx = bc_idx * hwd_size + indice_ptr[num];\n    CHECK_OR_THROW(idx >= 0 && idx < out_elem_num) << fmt::format(\n        \"Found an invalid max index: {}, output volumes are of size {}\", idx, out_elem_num);\n    f(num, idx);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename IDX>\nstruct UnpoolKernelUtil<DeviceType::kCPU, T, IDX> {\n  static void MaxUnpoolNdForward(ep::Stream* stream,\n                                 const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                                 const IDX elem_num, const T* src, T* dest,\n                                 const int64_t* indice_ptr, const int64_t y_hwd_size,\n                                 const int64_t y_elem_num) {\n    MaxUnpoolNdForwardOrBackward<T>(index_helper, elem_num, indice_ptr, y_hwd_size, y_elem_num,\n                                    [&](int64_t num, IDX idx) { dest[idx] = src[num]; });\n  }\n\n  static void MaxUnpoolNdBackward(ep::Stream* stream,\n                                  const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                                  const IDX elem_num, const T* src, T* dest,\n                                  const int64_t* indice_ptr, const int64_t dy_hwd_size,\n                                  const int64_t dy_elem_num) {\n    MaxUnpoolNdForwardOrBackward<T>(index_helper, elem_num, indice_ptr, dy_hwd_size, dy_elem_num,\n                                    [&](int64_t num, IDX idx) { dest[num] = src[idx]; });\n  }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass MaxUnpoolNdKernel final : public user_op::OpKernel {\n public:\n  MaxUnpoolNdKernel() = default;\n  ~MaxUnpoolNdKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n\n    const int64_t elem_num = x->shape_view().elem_cnt();\n    const T* src = x->dptr<T>();\n    const int64_t* indice_ptr = indice->dptr<int64_t>();\n    T* dest = y->mut_dptr<T>();\n\n    DimVector x_vector(2);\n    x_vector.at(0) = 
x->shape_view().At(0) * x->shape_view().At(1);\n    /* seed std::accumulate with int64_t{1} so the product is computed in 64-bit, not int */\n    x_vector.at(1) = std::accumulate(x->shape_view().begin() + 2, x->shape_view().end(),\n                                     int64_t{1}, std::multiplies<int64_t>());\n    const int64_t y_hwd_size = std::accumulate(y->shape_view().begin() + 2, y->shape_view().end(),\n                                               int64_t{1}, std::multiplies<int64_t>());\n\n    std::unique_ptr<ep::primitive::Memset> memset_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->device_type());\n    CHECK(memset_primitive);\n    memset_primitive->Launch(ctx->stream(), dest, 0,\n                             y->shape_view().elem_cnt() * GetSizeOfDataType(y->data_type()));\n\n    const int64_t y_elem_num = y->shape_view().elem_cnt();\n\n    if (elem_num < GetMaxVal<int32_t>()) {\n      NdIndexOffsetHelper<int32_t, 2> index_helper(x_vector.data());\n      UnpoolKernelUtil<device_type, T, int32_t>::MaxUnpoolNdForward(\n          ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, y_hwd_size, y_elem_num);\n    } else {\n      NdIndexOffsetHelper<int64_t, 2> index_helper(x_vector.data());\n      UnpoolKernelUtil<device_type, T, int64_t>::MaxUnpoolNdForward(\n          ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, y_hwd_size, y_elem_num);\n    }\n  }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass MaxUnpoolNdGradKernel final : public user_op::OpKernel {\n public:\n  MaxUnpoolNdGradKernel() = default;\n  ~MaxUnpoolNdGradKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const int64_t elem_num = dx->shape_view().elem_cnt();\n    const T* src = dy->dptr<T>();\n    const int64_t* indice_ptr = indice->dptr<int64_t>();\n    T* dest = dx->mut_dptr<T>();\n\n    DimVector dx_vector(2);\n    dx_vector.at(0) = dx->shape_view().At(0) * dx->shape_view().At(1);\n    dx_vector.at(1) = std::accumulate(dx->shape_view().begin() + 2, dx->shape_view().end(),\n                                      int64_t{1}, std::multiplies<int64_t>());\n    const int64_t dy_hwd_size = std::accumulate(dy->shape_view().begin() + 2,\n                                                dy->shape_view().end(), int64_t{1},\n                                                std::multiplies<int64_t>());\n\n    const int64_t dy_elem_num = dy->shape_view().elem_cnt();\n\n    if (elem_num < GetMaxVal<int32_t>()) {\n      NdIndexOffsetHelper<int32_t, 2> index_helper(dx_vector.data());\n      UnpoolKernelUtil<device_type, T, int32_t>::MaxUnpoolNdBackward(\n          ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, dy_hwd_size, dy_elem_num);\n    } else {\n      NdIndexOffsetHelper<int64_t, 2> index_helper(dx_vector.data());\n      UnpoolKernelUtil<device_type, T, int64_t>::MaxUnpoolNdBackward(\n          ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, dy_hwd_size, dy_elem_num);\n    }\n  }\n};\n\n#define REGISTER_UNPOOL_KERNELS(device, dtype)                                          \\\n  REGISTER_USER_KERNEL(\"max_unpool_1d\")                                                 \\\n      .SetCreateFn<MaxUnpoolNdKernel<device, dtype>>()                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                            
 \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"max_unpool_2d\")                                                 \\\n      .SetCreateFn<MaxUnpoolNdKernel<device, dtype>>()                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"max_unpool_3d\")                                                 \\\n      .SetCreateFn<MaxUnpoolNdKernel<device, dtype>>()                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"max_unpool_1d_grad\")                                            \\\n      .SetCreateFn<MaxUnpoolNdGradKernel<device, dtype>>()                              \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"max_unpool_2d_grad\")                                            \\\n      .SetCreateFn<MaxUnpoolNdGradKernel<device, dtype>>()                              \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"max_unpool_3d_grad\")                                            \\\n      .SetCreateFn<MaxUnpoolNdGradKernel<device, dtype>>()                              \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                             \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_UNPOOL_KERNEL_UTIL, (DeviceType::kCPU),\n                                 UNPOOL_DATA_TYPE_CPU_SEQ, UNPOOL_IDX_DATA_TYPE_SEQ);\n\n#define REGISTER_UNPOOL_WITH_DEVICE(device) \\\n  REGISTER_UNPOOL_KERNELS(device, int32_t)  \\\n  REGISTER_UNPOOL_KERNELS(device, int64_t)  \\\n  REGISTER_UNPOOL_KERNELS(device, float)    \\\n  REGISTER_UNPOOL_KERNELS(device, double)\n\nREGISTER_UNPOOL_WITH_DEVICE(DeviceType::kCPU)\nREGISTER_UNPOOL_KERNELS(DeviceType::kCPU, float16)\nREGISTER_UNPOOL_KERNELS(DeviceType::kCPU, bfloat16)\n\n#ifdef WITH_CUDA\nREGISTER_UNPOOL_WITH_DEVICE(DeviceType::kCUDA)\nREGISTER_UNPOOL_KERNELS(DeviceType::kCUDA, half)\n#if CUDA_VERSION >= 11000\nREGISTER_UNPOOL_KERNELS(DeviceType::kCUDA, nv_bfloat16)\n#endif  // CUDA_VERSION >= 11000\n#endif  // WITH_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/max_unpool_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/user/kernels/max_unpool_kernel_util.h\"\n#include <cuda_fp16.h>\n\nnamespace oneflow {\nnamespace {\n\nconstexpr int kBlockSize = cuda::elementwise::kBlockSize;\n\nint GetMinThreadNum(int64_t elem_num) { return std::min<int64_t>(elem_num, kBlockSize); }\n\nint GetNumBlocks(int64_t elem_cnt) {\n  int num_blocks = 0;\n  OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(elem_cnt, &num_blocks));\n  return num_blocks;\n}\n\n}  // namespace\n\ntemplate<typename T, typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoCUDAMaxUnpoolNdForward(const NdIndexOffsetHelper<IDX, 2> index_helper, IDX elem_num,\n                                  const T* src, T* dest, const int64_t* indice_ptr,\n                                  const int64_t y_hwd_size, const int64_t y_elem_num) {\n  CUDA_1D_KERNEL_LOOP_T(IDX, num, elem_num) {\n    IDX bc_idx, hwd_idx;\n    index_helper.OffsetToNdIndex(num, bc_idx, hwd_idx);\n    IDX dest_idx = bc_idx * y_hwd_size + indice_ptr[num];\n    if (dest_idx >= 0 && dest_idx < y_elem_num) { dest[dest_idx] = src[num]; }\n  }\n}\n\ntemplate<typename T, typename IDX>\n__launch_bounds__(kBlockSize) __global__\n    void DoCUDAMaxUnpoolNdBackward(const NdIndexOffsetHelper<IDX, 2> index_helper, IDX elem_num,\n                                   const T* src, T* dest, const int64_t* indice_ptr,\n                                   const int64_t dy_hwd_size, const int64_t dy_elem_num) {\n  CUDA_1D_KERNEL_LOOP_T(IDX, num, elem_num) {\n    IDX bc_idx, hwd_idx;\n    index_helper.OffsetToNdIndex(num, bc_idx, hwd_idx);\n    IDX src_idx = bc_idx * dy_hwd_size + indice_ptr[num];\n    if (src_idx >= 0 && src_idx < dy_elem_num) {\n      dest[num] = src[src_idx];\n    } else {\n      dest[num] = 0.0f;\n    }\n  }\n}\n\ntemplate<typename T, typename IDX>\nstruct UnpoolKernelUtil<DeviceType::kCUDA, T, IDX> {\n  static void MaxUnpoolNdForward(ep::Stream* stream,\n                                 const NdIndexOffsetHelper<IDX, 2>& index_helper, IDX elem_num,\n                                 const T* src, T* dest, const int64_t* indice_ptr,\n                                 const int64_t y_hwd_size, const int64_t y_elem_num) {\n    DoCUDAMaxUnpoolNdForward<T, IDX><<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0,\n                                       stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, indice_ptr, y_hwd_size, y_elem_num);\n  }\n\n  static void MaxUnpoolNdBackward(ep::Stream* stream,\n                                  const NdIndexOffsetHelper<IDX, 2>& index_helper, IDX elem_num,\n                                  const T* src, T* dest, const int64_t* indice_ptr,\n                                  const int64_t dy_hwd_size, const int64_t dy_elem_num) {\n    DoCUDAMaxUnpoolNdBackward<T, IDX><<<GetNumBlocks(elem_num), 
GetMinThreadNum(elem_num), 0,\n                                        stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        index_helper, elem_num, src, dest, indice_ptr, dy_hwd_size, dy_elem_num);\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_UNPOOL_KERNEL_UTIL, (DeviceType::kCUDA),\n                                 UNPOOL_DATA_TYPE_CUDA_SEQ, UNPOOL_IDX_DATA_TYPE_SEQ);\n#if CUDA_VERSION >= 11000\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_UNPOOL_KERNEL_UTIL, (DeviceType::kCUDA),\n                                 OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16),\n                                 UNPOOL_IDX_DATA_TYPE_SEQ);\n#endif  // CUDA_VERSION >= 11000\n\n}  // namespace oneflow\n#endif  // WITH_CUDA\n"
  },
  {
    "path": "oneflow/user/kernels/max_unpool_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/operator/operator_util.h\"\n#include \"oneflow/user/utils/pool_util.h\"\n#include \"oneflow/user/kernels/max_unpool_kernel_util.h\"\n\nnamespace oneflow {\nnamespace {\n\nvoid GetWindowedOutputShape(int64_t input_size, int32_t filter_size, int32_t stride,\n                            int32_t padding, int64_t* output_ptr) {\n  int64_t output_size = (input_size - 1) * stride - 2 * padding + filter_size;\n  *output_ptr = output_size;\n}\n\nvoid Get3DOutputShape(const DimVector& in, const std::vector<int32_t>& pool_size,\n                      const std::vector<int32_t>& strides, const std::vector<int32_t>& padding,\n                      DimVector* out) {\n  out->clear();\n  out->resize(3);\n  FOR_RANGE(size_t, i, 0, 3) {\n    int64_t* out_ptr = &(*out).at(i);\n    GetWindowedOutputShape(in.at(i), pool_size.at(i), strides.at(i), padding.at(i), out_ptr);\n  }\n}\n\n}  // namespace\n\nMaxUnpoolParams3D::MaxUnpoolParams3D(const int32_t dim, const ShapeView& x_shape,\n                                     const std::vector<int32_t>& padding,\n                                     const std::vector<int32_t>& kernel_size,\n                                     const std::vector<int32_t>& stride)\n    : dim_(dim),\n      padding_(Get3DVec<Get3DVecType::kPad>(padding, dim)),\n      pool_size_3d_(Get3DVec(kernel_size, dim)),\n      stride_3d_(Get3DVec(stride, dim)),\n      batch_num_(x_shape.At(0)),\n      channel_num_(x_shape.At(1)) {\n  std::string data_format = \"channels_first\";\n  x_3d_ = {GetInDim(x_shape, data_format, 0, dim), GetInDim(x_shape, data_format, 1, dim),\n           GetInDim(x_shape, data_format, 2, dim)};\n  Get3DOutputShape(x_3d_, pool_size_3d_, stride_3d_, padding_, &y_3d_);\n}\n\nvoid MaxUnpoolParams3D::Reset(const ShapeView& x_shape) {\n  std::string data_format = \"channels_first\";\n  x_3d_ = {GetInDim(x_shape, data_format, 0, dim_), GetInDim(x_shape, data_format, 1, dim_),\n           GetInDim(x_shape, data_format, 2, dim_)};\n  Get3DOutputShape(x_3d_, pool_size_3d_, stride_3d_, padding_, &y_3d_);\n}\n\nint64_t MaxUnpoolParams3D::GetYStride() const { return y_3d_.at(0) * y_3d_.at(1) * y_3d_.at(2); }\n\nShape MaxUnpoolParams3D::GetYShape() const {\n  DimVector y_dim_vec;\n  if (dim_ == 1) {\n    y_dim_vec = {y_3d_.at(2)};\n  } else if (dim_ == 2) {\n    y_dim_vec = {y_3d_.at(1), y_3d_.at(2)};\n  } else if (dim_ == 3) {\n    y_dim_vec = {y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)};\n  } else {\n    UNIMPLEMENTED();\n  }\n  y_dim_vec.insert(y_dim_vec.begin(), channel_num_);\n  y_dim_vec.insert(y_dim_vec.begin(), batch_num_);\n  return Shape(y_dim_vec);\n}\n\nShape MaxUnpoolParams3D::GetXShape5D() const {\n  return Shape({batch_num_, channel_num_, x_3d_.at(0), x_3d_.at(1), x_3d_.at(2)});\n}\n\nShape MaxUnpoolParams3D::GetYShape5D() const {\n  return Shape({batch_num_, channel_num_, y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)});\n}\n\n}  // namespace 
oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/max_unpool_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_UNPOOL_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_UNPOOL_KERNEL_UTIL_H_\n#include \"oneflow/core/ndarray/xpu_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\n\n#define UNPOOL_DATA_TYPE_SEQ                      \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64) \\\n  OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)   \\\n  OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)\n\n#define UNPOOL_IDX_DATA_TYPE_SEQ                  \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)\n\n#define UNPOOL_DATA_TYPE_CPU_SEQ                                         \\\n  UNPOOL_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16) \\\n      OF_PP_MAKE_TUPLE_SEQ(bfloat16, DataType::kBFloat16)\n#define UNPOOL_DATA_TYPE_CUDA_SEQ \\\n  UNPOOL_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)\n// OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16)\n\ntypedef small_vector<int64_t, SHAPE_MAX_AXIS_SIZE> FixedDimVector;\n\nclass MaxUnpoolParams3D {\n public:\n  MaxUnpoolParams3D(const int32_t dim, const ShapeView& x_shape,\n                    const std::vector<int32_t>& padding, const std::vector<int32_t>& kernel_size,\n                    const std::vector<int32_t>& stride);\n  ~MaxUnpoolParams3D() = default;\n\n  const std::vector<int32_t>& padding() const { return padding_; }\n  const std::vector<int32_t>& pool_size_3d() const { return pool_size_3d_; }\n  const std::vector<int32_t>& stride_3d() const { return stride_3d_; }\n  const int32_t& num_batch() const { return batch_num_; }\n  const int32_t& num_channel() const { return channel_num_; }\n\n  void Reset(const ShapeView& x_shape);\n\n  Shape GetYShape() const;\n  Shape GetXShape5D() const;\n  Shape GetYShape5D() const;\n  int64_t GetYStride() const;\n\n private:\n  int32_t dim_;\n  FixedDimVector x_3d_;\n  FixedDimVector y_3d_;\n  std::vector<int32_t> padding_;\n  std::vector<int32_t> pool_size_3d_;\n  std::vector<int32_t> stride_3d_;\n  int32_t batch_num_;\n  int32_t channel_num_;\n};\n\ntemplate<DeviceType device_type, typename T, typename IDX>\nstruct UnpoolKernelUtil {\n  static void MaxUnpoolNdForward(ep::Stream* stream,\n                                 const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                                 const IDX elem_num, const T* src, T* dest,\n                                 const int64_t* indice_ptr, const int64_t y_hwd_size,\n                                 const int64_t y_elem_num);\n  static void MaxUnpoolNdBackward(ep::Stream* stream,\n                                  const NdIndexOffsetHelper<IDX, 2>& index_helper,\n                                  const IDX elem_num, const T* src, T* dest,\n                                  const int64_t* 
indice_ptr, const int64_t dy_hwd_size,\n                                  const int64_t dy_elem_num);\n};\n\n#define INSTANTIATE_UNPOOL_KERNEL_UTIL(device_type_v, dtype_pair, index_dtype_pair) \\\n  template struct UnpoolKernelUtil<device_type_v, OF_PP_PAIR_FIRST(dtype_pair),     \\\n                                   OF_PP_PAIR_FIRST(index_dtype_pair)>;\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_UNPOOL_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/median_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass CpuMedianKernel final : public user_op::OpKernel {\n public:\n  CpuMedianKernel() = default;\n  ~CpuMedianKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const int64_t size = in->shape_view().elem_cnt();\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    T* out_ptr = out->mut_dptr<T>();\n    Memcpy<DeviceType::kCPU>(ctx->stream(), tmp_buffer->mut_dptr<void>(), in->dptr<void>(),\n                             size * sizeof(T));\n    T* first = tmp_buffer->mut_dptr<T>();\n    T* last = first + size;\n    T* median = first + (size - 1) / 2;\n    std::nth_element(first, median, last);\n    *out_ptr = *median;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_MEDIAN_KERNEL(dtype)                                                  \\\n  REGISTER_USER_KERNEL(\"median\")                                                           \\\n      .SetCreateFn<CpuMedianKernel<dtype>>()                                               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                      \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                        \\\n        return ctx->InputShape(\"input\", 0).elem_cnt() * sizeof(dtype);                     \\\n      });\n\nREGISTER_CPU_MEDIAN_KERNEL(float)\nREGISTER_CPU_MEDIAN_KERNEL(double)\nREGISTER_CPU_MEDIAN_KERNEL(int8_t)\nREGISTER_CPU_MEDIAN_KERNEL(uint8_t)\nREGISTER_CPU_MEDIAN_KERNEL(int32_t)\nREGISTER_CPU_MEDIAN_KERNEL(int64_t)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/median_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/user/kernels/radix_sort.cuh\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass CudaMedianKernel final : public user_op::OpKernel {\n public:\n  CudaMedianKernel() = default;\n  ~CudaMedianKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    const int32_t instance_size = in->shape_view().elem_cnt();\n    const size_t sort_tensor_buffer_bytes = GetCudaAlignedSize(instance_size * sizeof(T));\n    SortKeysAscending(\n        in->dptr<T>(), 1, instance_size,\n        reinterpret_cast<void*>(tmp_buffer->mut_dptr<char>() + sort_tensor_buffer_bytes),\n        tmp_buffer->shape_view().elem_cnt() - sort_tensor_buffer_bytes, tmp_buffer->mut_dptr<T>(),\n        ctx->stream()->As<ep::CudaStream>()->cuda_stream());\n    Memcpy<DeviceType::kCUDA>(ctx->stream(), out->mut_dptr<T>(),\n                              tmp_buffer->mut_dptr<T>() + (instance_size - 1) / 2, sizeof(T));\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_MEDIAN_KERNEL(dtype)                                                   \\\n  REGISTER_USER_KERNEL(\"median\")                                                             \\\n      .SetCreateFn<CudaMedianKernel<dtype>>()                                                \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                       \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value))   \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                          \\\n        const Shape& in_shape = ctx->InputShape(\"input\", 0);                                 \\\n        const int32_t instance_size = in_shape.elem_cnt();                                   \\\n        size_t sort_tmp_buffer_bytes =                                                       \\\n            InferTempStorageForSortKeysAscending<dtype>(1, instance_size);                   \\\n        size_t sort_tensor_buffer_bytes = GetCudaAlignedSize(instance_size * sizeof(dtype)); \\\n        return sort_tmp_buffer_bytes + sort_tensor_buffer_bytes;                             \\\n      });\n\nREGISTER_CUDA_MEDIAN_KERNEL(float)\nREGISTER_CUDA_MEDIAN_KERNEL(double)\nREGISTER_CUDA_MEDIAN_KERNEL(int8_t)\nREGISTER_CUDA_MEDIAN_KERNEL(uint8_t)\nREGISTER_CUDA_MEDIAN_KERNEL(int32_t)\nREGISTER_CUDA_MEDIAN_KERNEL(int64_t)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/median_with_indices_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass CpuMedianWithIndicesKernel final : public user_op::OpKernel {\n public:\n  CpuMedianWithIndicesKernel() = default;\n  ~CpuMedianWithIndicesKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const int64_t num_axes = in->shape_view().NumAxes();\n    const int64_t size = in->shape_view().elem_cnt();\n    if (size == 0) return;\n    const int64_t stride = in->shape_view().At(num_axes - 1);\n    const int64_t instance_num = size / stride;\n    user_op::Tensor* values = ctx->Tensor4ArgNameAndIndex(\"values\", 0);\n    user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    Memcpy<DeviceType::kCPU>(ctx->stream(), tmp_buffer->mut_dptr<void>(), in->dptr<void>(),\n                             size * sizeof(T));\n    const int64_t thread_num =\n        std::min(instance_num, (int64_t)Singleton<ThreadPool>::Get()->thread_num());\n    const BalancedSplitter bs(instance_num, thread_num);\n    BlockingCounter bc(thread_num);\n    FOR_RANGE(int64_t, thread_id, 0, thread_num) {\n      const Range range = bs.At(thread_id);\n      Singleton<ThreadPool>::Get()->AddWork([=, &bc]() {\n        FOR_RANGE(int64_t, i, range.begin(), range.end()) {\n          T* in_ptr = tmp_buffer->mut_dptr<T>() + i * stride;\n          T* val_ptr = values->mut_dptr<T>() + i;\n          int64_t* ind_ptr = indices->mut_dptr<int64_t>() + i;\n          std::vector<int64_t> idx(stride);\n          auto first = idx.begin();\n          auto last = idx.end();\n          std::iota(first, last, 0);\n          auto nth = first;\n          nth += (stride - 1) / 2;\n          std::nth_element(first, nth, last, [&in_ptr](int64_t i, int64_t j) {\n            return in_ptr[i] < in_ptr[j] || (in_ptr[i] == in_ptr[j] && i < j);\n          });\n          *val_ptr = in_ptr[*nth];\n          *ind_ptr = *nth;\n        }\n        bc.Decrease();\n      });\n    }\n    bc.WaitForeverUntilCntEqualZero();\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_MEDIAN_WITH_INDICES_KERNEL(dtype)                                     \\\n  REGISTER_USER_KERNEL(\"median_with_indices\")                                              \\\n      .SetCreateFn<CpuMedianWithIndicesKernel<dtype>>()                                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                      \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> 
size_t {                        \\\n        return ctx->InputShape(\"input\", 0).elem_cnt() * sizeof(dtype);                     \\\n      });\n\nREGISTER_CPU_MEDIAN_WITH_INDICES_KERNEL(float)\nREGISTER_CPU_MEDIAN_WITH_INDICES_KERNEL(double)\nREGISTER_CPU_MEDIAN_WITH_INDICES_KERNEL(int8_t)\nREGISTER_CPU_MEDIAN_WITH_INDICES_KERNEL(uint8_t)\nREGISTER_CPU_MEDIAN_WITH_INDICES_KERNEL(int32_t)\nREGISTER_CPU_MEDIAN_WITH_INDICES_KERNEL(int64_t)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/median_with_indices_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/user/kernels/radix_sort.cuh\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename IDX>\n__global__ void MedianSelectCuda(const IDX reduce_elem_cnt, const IDX stride, const T* in,\n                                 const int64_t* sort_indices, T* values, int64_t* indices) {\n  IDX nth = (stride - 1) / 2;\n  CUDA_1D_KERNEL_LOOP_T(IDX, i, reduce_elem_cnt) {\n    values[i] = in[i * stride + nth];\n    indices[i] = sort_indices[i * stride + nth];\n  }\n}\n\nbool IsSafeUseIndex32(int64_t elem_cnt) { return elem_cnt < GetMaxVal<int32_t>() / 2; }\n\ntemplate<typename T>\nvoid DispatchIndexSize(ep::Stream* stream, const int64_t elem_cnt, const int64_t stride,\n                       const T* in, const int64_t* sort_indices, T* out, int64_t* out_indices) {\n  const int64_t reduce_elem_cnt = elem_cnt / stride;\n  if (IsSafeUseIndex32(elem_cnt)) {\n    RUN_CUDA_KERNEL((MedianSelectCuda<T, int32_t>), stream, reduce_elem_cnt, reduce_elem_cnt,\n                    stride, in, sort_indices, out, out_indices);\n  } else {\n    RUN_CUDA_KERNEL((MedianSelectCuda<T, int64_t>), stream, reduce_elem_cnt, reduce_elem_cnt,\n                    stride, in, sort_indices, out, out_indices);\n  }\n}\n\ntemplate<typename T>\nclass TmpBufferManager final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TmpBufferManager);\n  TmpBufferManager(size_t capacity, void* ptr, const ShapeView& in_shape)\n      : capacity_{capacity},\n        sorted_in_elem_cnt_{in_shape.elem_cnt()},\n        indices_elem_cnt_{sorted_in_elem_cnt_} {\n    const size_t sort_tensor_buffer_bytes = GetCudaAlignedSize(sorted_in_elem_cnt_ * sizeof(T));\n    const size_t sort_indices_buffer_bytes =\n        GetCudaAlignedSize(indices_elem_cnt_ * sizeof(int64_t));\n    sorted_in_ptr_ = reinterpret_cast<T*>(ptr);\n    in_indices_ptr_ = reinterpret_cast<int64_t*>(reinterpret_cast<char*>(sorted_in_ptr_)\n                                                 + sort_tensor_buffer_bytes);\n    out_indices_ptr_ = reinterpret_cast<int64_t*>(reinterpret_cast<char*>(in_indices_ptr_)\n                                                  + sort_indices_buffer_bytes);\n    temp_storage_ptr_ = reinterpret_cast<void*>(reinterpret_cast<char*>(out_indices_ptr_)\n                                                + sort_indices_buffer_bytes);\n    temp_storage_bytes_ = capacity_ - sort_tensor_buffer_bytes - sort_indices_buffer_bytes * 2;\n    CHECK_GE(temp_storage_bytes_, 0);\n  }\n  ~TmpBufferManager() = default;\n\n  T* SortedInPtr() const { return sorted_in_ptr_; }\n  int64_t* InIndicesPtr() const { return in_indices_ptr_; }\n  int64_t* OutIndicesPtr() const { return out_indices_ptr_; }\n  void* TempStoragePtr() const { return temp_storage_ptr_; }\n\n  size_t TempStorageBytes() const { return temp_storage_bytes_; }\n\n private:\n  size_t 
capacity_;\n\n  T* sorted_in_ptr_;\n  int64_t* in_indices_ptr_;\n  int64_t* out_indices_ptr_;\n  void* temp_storage_ptr_;\n\n  int64_t sorted_in_elem_cnt_;\n  int64_t indices_elem_cnt_;\n  size_t temp_storage_bytes_;\n};\n\n__global__ void InitializeIndices(int64_t elem_cnt, int64_t* indices_ptr, int64_t instance_size) {\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) { indices_ptr[i] = i % instance_size; };\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass CudaMedianWithIndicesKernel final : public user_op::OpKernel {\n public:\n  CudaMedianWithIndicesKernel() = default;\n  ~CudaMedianWithIndicesKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    if (in->shape_view().elem_cnt() == 0) return;\n    user_op::Tensor* values = ctx->Tensor4ArgNameAndIndex(\"values\", 0);\n    user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    TmpBufferManager<T> buf_manager(tmp_buffer->shape_view().elem_cnt(),\n                                    tmp_buffer->mut_dptr<void>(), in->shape_view());\n\n    const int64_t elem_cnt = in->shape_view().elem_cnt();\n    const int64_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1);\n    const int64_t instance_num = elem_cnt / instance_size;\n    RUN_CUDA_KERNEL(InitializeIndices, ctx->stream(), elem_cnt, elem_cnt,\n                    buf_manager.InIndicesPtr(), instance_size);\n    SortPairsAscending(in->dptr<T>(), buf_manager.InIndicesPtr(), instance_num, instance_size,\n                       buf_manager.TempStoragePtr(), buf_manager.TempStorageBytes(),\n                       buf_manager.SortedInPtr(), buf_manager.OutIndicesPtr(),\n                       ctx->stream()->As<ep::CudaStream>()->cuda_stream());\n    DispatchIndexSize(ctx->stream(), elem_cnt, instance_size, buf_manager.SortedInPtr(),\n                      buf_manager.OutIndicesPtr(), values->mut_dptr<T>(),\n                      indices->mut_dptr<int64_t>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_MEDIAN_WITH_INDICES_KERNEL(dtype)                                            \\\n  REGISTER_USER_KERNEL(\"median_with_indices\")                                                      \\\n      .SetCreateFn<CudaMedianWithIndicesKernel<dtype>>()                                           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                             \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value))         \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                                \\\n        const Shape& in_shape = ctx->InputShape(\"input\", 0);                                       \\\n        const int64_t instance_size = in_shape.dim_vec().back();                                   \\\n        const int64_t instance_num = in_shape.elem_cnt() / instance_size;                          \\\n        size_t sort_tmp_buffer_bytes =                                                             \\\n            InferTempStorageForSortPairsAscending<dtype, int64_t>(instance_num, instance_size);    \\\n        size_t sort_tensor_buffer_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(dtype)); \\\n        size_t sort_indices_buffer_bytes =                    
                                     \\\n            GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(int64_t));                             \\\n        return sort_tmp_buffer_bytes + sort_tensor_buffer_bytes + sort_indices_buffer_bytes * 2;   \\\n      });\n\nREGISTER_CUDA_MEDIAN_WITH_INDICES_KERNEL(float)\nREGISTER_CUDA_MEDIAN_WITH_INDICES_KERNEL(double)\nREGISTER_CUDA_MEDIAN_WITH_INDICES_KERNEL(int8_t)\nREGISTER_CUDA_MEDIAN_WITH_INDICES_KERNEL(uint8_t)\nREGISTER_CUDA_MEDIAN_WITH_INDICES_KERNEL(int32_t)\nREGISTER_CUDA_MEDIAN_WITH_INDICES_KERNEL(int64_t)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/min_max_observer_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n\n#include <algorithm>\n\nnamespace oneflow {\n\ntemplate<typename T>\nvoid GenQuantScaleSymmetric(const T* in_ptr, const int32_t quantization_bit,\n                            const int64_t num_elements, T* scale, T* zero_point) {\n  T in_max = *std::max_element(in_ptr, in_ptr + num_elements);\n  T in_min = *std::min_element(in_ptr, in_ptr + num_elements);\n\n  in_max = std::max(std::abs(in_max), std::abs(in_min));\n\n  T denominator = static_cast<T>(pow(2.0, quantization_bit - 1)) - 1;\n\n  *scale = in_max / denominator;\n  *zero_point = 0;\n}\n\ntemplate<typename T>\nvoid GenQuantScaleAffine(const T* in_ptr, const int32_t quantization_bit,\n                         const int64_t num_elements, T* scale, T* zero_point) {\n  T in_max = *std::max_element(in_ptr, in_ptr + num_elements);\n  T in_min = *std::min_element(in_ptr, in_ptr + num_elements);\n\n  T denominator = static_cast<T>(pow(2.0, quantization_bit)) - 1;\n\n  *scale = (in_max - in_min) / denominator;\n  *zero_point = -std::nearbyint(in_min / (*scale));\n}\n\ntemplate<typename T>\nvoid GenQuantScaleCambricon(const T* in_ptr, const int32_t quantization_bit,\n                            const int64_t num_elements, T* scale, T* zero_point) {\n  T in_max = *std::max_element(in_ptr, in_ptr + num_elements);\n  T in_min = *std::min_element(in_ptr, in_ptr + num_elements);\n\n  in_max = std::max(std::abs(in_max), std::abs(in_min));\n\n  *scale = std::floor(std::log2(in_max)) - (quantization_bit - 2);\n  *zero_point = 0;\n}\n\ntemplate<typename T>\nclass CpuMinMaxObserverKernel final : public user_op::OpKernel {\n public:\n  CpuMinMaxObserverKernel() = default;\n  ~CpuMinMaxObserverKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex(\"scale\", 0);\n    user_op::Tensor* zero_point = ctx->Tensor4ArgNameAndIndex(\"zero_point\", 0);\n\n    const std::string quantization_scheme = ctx->Attr<std::string>(\"quantization_scheme\");\n    const int32_t quantization_bit = ctx->Attr<int32_t>(\"quantization_bit\");\n    const bool per_layer_quantization = ctx->Attr<bool>(\"per_layer_quantization\");\n    const std::string quantization_formula = ctx->Attr<std::string>(\"quantization_formula\");\n\n    const T* in_ptr = in->dptr<T>();\n    T* scale_ptr = scale->mut_dptr<T>();\n    T* zero_point_ptr = zero_point->mut_dptr<T>();\n\n    if (quantization_formula == \"google\") {\n      // NOTE(Liang Depeng): per-layer quantization by default\n      int64_t outer_num = 1;\n      int64_t inner_num = in->shape_view().elem_cnt();\n      if (!per_layer_quantization) {  // per-channel quantization\n        outer_num = in->shape_view().At(0);\n        inner_num = in->shape_view().Count(1);\n      }\n\n      if 
(quantization_scheme == \"symmetric\") {\n        FOR_RANGE(int64_t, c, 0, outer_num) {\n          GenQuantScaleSymmetric(in_ptr, quantization_bit, inner_num, scale_ptr, zero_point_ptr);\n          in_ptr += inner_num;\n          scale_ptr += 1;\n          zero_point_ptr += 1;\n        }\n      } else {  // quantization_scheme == \"affine\"\n        FOR_RANGE(int64_t, c, 0, outer_num) {\n          GenQuantScaleAffine(in_ptr, quantization_bit, inner_num, scale_ptr, zero_point_ptr);\n          in_ptr += inner_num;\n          scale_ptr += 1;\n          zero_point_ptr += 1;\n        }\n      }\n    } else if (quantization_formula == \"cambricon\") {\n      if (!per_layer_quantization) {\n        UNIMPLEMENTED() << \" per-channel mode is not supported in cambricon scheme\";\n      }\n      GenQuantScaleCambricon(in_ptr, quantization_bit, in->shape_view().elem_cnt(), scale_ptr,\n                             zero_point_ptr);\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_MIN_MAX_OBSERVER_KERNEL(dtype)                       \\\n  REGISTER_USER_KERNEL(\"min_max_observer\")                            \\\n      .SetCreateFn<CpuMinMaxObserverKernel<dtype>>()                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value))\n\nREGISTER_MIN_MAX_OBSERVER_KERNEL(float);\nREGISTER_MIN_MAX_OBSERVER_KERNEL(double);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/min_max_observer_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\n#include <float.h>\n\nnamespace oneflow {\n\nnamespace {\n\n// NOTE(Liang Depeng): refer to\n// https://stackoverflow.com/questions/17371275/implementing-max-reduce-in-cuda\ntemplate<typename T>\n__global__ void ReduceMaxMinPerLayer(const T* input_ptr, const int64_t elements, T* max_ptr,\n                                     T* min_ptr) {\n  extern __shared__ unsigned char shared_max_min_memory[];\n  T* shared_max = reinterpret_cast<T*>(shared_max_min_memory);\n  T* shared_min = shared_max + blockDim.x;\n\n  int64_t tid = threadIdx.x;\n  int64_t gid = (blockDim.x * blockIdx.x) + tid;\n  shared_max[tid] = -FLT_MAX;\n  shared_min[tid] = -FLT_MAX;\n\n  while (gid < elements) {\n    shared_max[tid] = max(shared_max[tid], input_ptr[gid]);\n    shared_min[tid] = max(shared_min[tid], -input_ptr[gid]);\n    gid += gridDim.x * blockDim.x;\n  }\n  __syncthreads();\n  gid = (blockDim.x * blockIdx.x) + tid;\n  for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {\n    if (tid < s && gid < elements) {\n      shared_max[tid] = max(shared_max[tid], shared_max[tid + s]);\n      shared_min[tid] = max(shared_min[tid], shared_min[tid + s]);\n    }\n    __syncthreads();\n  }\n\n  if (tid == 0) {\n    cuda::atomic::Max(max_ptr, shared_max[0]);\n    cuda::atomic::Max(min_ptr, shared_min[0]);\n  }\n}\n\ntemplate<typename T>\n__global__ void ReduceMaxMinPerChannel(const T* input_ptr, const int64_t elements,\n                                       const int64_t num_channels, const int64_t panel_size,\n                                       T* max_ptr, T* min_ptr) {\n  extern __shared__ unsigned char shared_max_min_memory[];\n  T* shared_max = reinterpret_cast<T*>(shared_max_min_memory);\n  T* shared_min = shared_max + blockDim.x;\n\n  int64_t cur_channel = blockIdx.x;\n  int64_t tid = threadIdx.x;\n\n  while (cur_channel < num_channels) {\n    shared_max[tid] = -FLT_MAX;\n    shared_min[tid] = -FLT_MAX;\n\n    int64_t index = (panel_size * cur_channel) + tid;\n    int64_t end = panel_size * (cur_channel + 1);\n\n    while (index < end && index < elements) {\n      shared_max[tid] = max(shared_max[tid], input_ptr[index]);\n      shared_min[tid] = max(shared_min[tid], -input_ptr[index]);\n      index += blockDim.x;\n    }\n    __syncthreads();\n\n    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {\n      if (tid < s) {\n        shared_max[tid] = max(shared_max[tid], shared_max[tid + s]);\n        shared_min[tid] = max(shared_min[tid], shared_min[tid + s]);\n      }\n      __syncthreads();\n    }\n\n    if (tid == 0) {\n      cuda::atomic::Max(&max_ptr[cur_channel], shared_max[0]);\n      cuda::atomic::Max(&min_ptr[cur_channel], shared_min[0]);\n    }\n\n    // __syncthreads();\n    cur_channel += gridDim.x;\n  
}\n}\n\ntemplate<typename T>\n__global__ void InitMaxMin(const int64_t elements, T* max_ptr, T* min_ptr) {\n  int64_t tid = threadIdx.x;\n  int64_t gid = (blockDim.x * blockIdx.x) + tid;\n\n  while (gid < elements) {\n    max_ptr[gid] = -FLT_MAX;\n    min_ptr[gid] = -FLT_MAX;\n    gid += gridDim.x * blockDim.x;\n  }\n}\n\ntemplate<typename T>\n__global__ void CalScaleZeroPointSymmetric(const T* max_ptr, const T* min_ptr,\n                                           const int64_t elements, const double quantization_bit,\n                                           T* scale, T* zero_point) {\n  int64_t tid = threadIdx.x;\n  int64_t gid = (blockDim.x * blockIdx.x) + tid;\n\n  while (gid < elements) {\n    T weight_max = max(fabs(max_ptr[gid]), fabs(min_ptr[gid]));\n    T denominator = static_cast<T>(pow(2.0, quantization_bit - 1)) - 1;\n    scale[gid] = weight_max / denominator;\n    zero_point[gid] = 0;\n    gid += gridDim.x * blockDim.x;\n  }\n}\n\ntemplate<typename T>\n__global__ void CalScaleZeroPointAffine(const T* max_ptr, const T* min_ptr, const int64_t elements,\n                                        const double quantization_bit, T* scale, T* zero_point) {\n  int64_t tid = threadIdx.x;\n  int64_t gid = (blockDim.x * blockIdx.x) + tid;\n\n  while (gid < elements) {\n    T denominator = static_cast<T>(pow(2.0, quantization_bit)) - 1;\n    T min = -min_ptr[gid];\n    T s = (max_ptr[gid] - min) / denominator;\n    scale[gid] = s;\n    zero_point[gid] = -nearbyint(min / s);\n    gid += gridDim.x * blockDim.x;\n  }\n}\n\ntemplate<typename T>\n__global__ void CalScaleZeroPointCambricon(const T* max_ptr, const T* min_ptr,\n                                           const int64_t elements, const double quantization_bit,\n                                           T* scale, T* zero_point) {\n  int64_t tid = threadIdx.x;\n  int64_t gid = (blockDim.x * blockIdx.x) + tid;\n\n  while (gid < elements) {\n    T weight_max = max(fabs(max_ptr[gid]), fabs(min_ptr[gid]));\n    // T denominator = static_cast<T>(pow(2.0, quantization_bit - 1)) - 1;\n    scale[gid] = floor(log2(weight_max)) - (quantization_bit - 2);\n    zero_point[gid] = 0;\n    gid += gridDim.x * blockDim.x;\n  }\n}\n\nep::CudaLaunchConfig GetLaunchConfig(ep::CudaStream* stream, size_t thread_num,\n                                     size_t shared_mem_size) {\n  ep::CudaLaunchConfig config;\n  stream->InitLaunchConfigWithWaves(&config, thread_num, kCudaThreadsNumPerBlock, 1);\n  config.shared_mem_size = shared_mem_size;\n  return config;\n}\n\n}  // namespace\n\n#define LAUNCH_CUDA_KERNEL(func, stream, thread_num, shared_mem_size, ...) 
\\\n  (stream)->LaunchKernel(func, GetLaunchConfig((stream), thread_num, shared_mem_size), __VA_ARGS__);\n\ntemplate<typename T>\nclass GpuMinMaxObserverKernel final : public user_op::OpKernel {\n public:\n  GpuMinMaxObserverKernel() = default;\n  ~GpuMinMaxObserverKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex(\"scale\", 0);\n    user_op::Tensor* zero_point = ctx->Tensor4ArgNameAndIndex(\"zero_point\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    const std::string quantization_scheme = ctx->Attr<std::string>(\"quantization_scheme\");\n    const int32_t quantization_bit = ctx->Attr<int32_t>(\"quantization_bit\");\n    const bool per_layer_quantization = ctx->Attr<bool>(\"per_layer_quantization\");\n    const std::string quantization_formula = ctx->Attr<std::string>(\"quantization_formula\");\n\n    const int64_t elements = in->shape_view().elem_cnt();\n    const int64_t channel = scale->shape_view().At(0);\n    const int64_t panel_size = elements / channel;\n    T* max_ptr = tmp_buffer->mut_dptr<T>();\n    T* min_ptr = max_ptr + channel;\n    auto* cuda_stream = ctx->stream()->As<ep::CudaStream>();\n    LAUNCH_CUDA_KERNEL((InitMaxMin<T>), cuda_stream, channel, 0, channel, max_ptr, min_ptr);\n\n    if (per_layer_quantization) {\n      LAUNCH_CUDA_KERNEL((ReduceMaxMinPerLayer<T>), cuda_stream, elements,\n                         kCudaThreadsNumPerBlock * 2 * sizeof(T), in->dptr<T>(), elements, max_ptr,\n                         min_ptr);\n    } else {  // per-channel quantization\n      // NOTE(Liang Depeng): each block of threads will be responsible for\n      //                     computing the max and min values of the whole channel.\n      LAUNCH_CUDA_KERNEL((ReduceMaxMinPerChannel<T>), cuda_stream,\n                         channel * kCudaThreadsNumPerBlock, kCudaThreadsNumPerBlock * 2 * sizeof(T),\n                         in->dptr<T>(), elements, channel, panel_size, max_ptr, min_ptr);\n    }\n\n    if (quantization_formula == \"google\") {\n      if (quantization_scheme == \"symmetric\") {\n        LAUNCH_CUDA_KERNEL((CalScaleZeroPointSymmetric<T>), cuda_stream, channel, 0, max_ptr,\n                           min_ptr, channel, static_cast<double>(quantization_bit),\n                           scale->mut_dptr<T>(), zero_point->mut_dptr<T>());\n      } else {  // quantization_scheme == \"affine\"\n        LAUNCH_CUDA_KERNEL((CalScaleZeroPointAffine<T>), cuda_stream, channel, 0, max_ptr, min_ptr,\n                           channel, static_cast<double>(quantization_bit), scale->mut_dptr<T>(),\n                           zero_point->mut_dptr<T>());\n      }\n    } else if (quantization_formula == \"cambricon\") {\n      if (!per_layer_quantization) {\n        UNIMPLEMENTED() << \" per-channel mode is not supported in cambricon scheme\";\n      }\n      LAUNCH_CUDA_KERNEL((CalScaleZeroPointCambricon<T>), cuda_stream, channel, 0, max_ptr, min_ptr,\n                         channel, static_cast<double>(quantization_bit), scale->mut_dptr<T>(),\n                         zero_point->mut_dptr<T>());\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_MIN_MAX_OBSERVER_KERNEL(dtype)                                         \\\n  
REGISTER_USER_KERNEL(\"min_max_observer\")                                              \\\n      .SetCreateFn<GpuMinMaxObserverKernel<dtype>>()                                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                     \\\n        size_t tmp_buffer_size = 1;                                                     \\\n        if (ctx->Attr<bool>(\"per_layer_quantization\") == false) {                       \\\n          const Shape& in_shape = ctx->InputShape(\"in\", 0);                             \\\n          tmp_buffer_size = in_shape.At(0);                                             \\\n        }                                                                               \\\n        return 2 * tmp_buffer_size * sizeof(dtype);                                     \\\n      })\n\nREGISTER_MIN_MAX_OBSERVER_KERNEL(float);\nREGISTER_MIN_MAX_OBSERVER_KERNEL(double);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/mode_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n\nnamespace oneflow {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Memcpy> NewMemcpyPrimitive(Context* ctx) {\n  return ep::primitive::NewPrimitive<ep::primitive::MemcpyFactory>(\n      ctx->device_type(), ep::primitive::MemcpyKind::kDtoD);\n}\ntemplate<typename T>\nclass CpuModeKernel final : public user_op::OpKernel {\n public:\n  CpuModeKernel() = default;\n  ~CpuModeKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const int64_t num_axes = in->shape_view().NumAxes();\n    const int64_t size = in->shape_view().elem_cnt();\n    if (size == 0) return;\n    const int64_t stride = in->shape_view().At(num_axes - 1);\n    const int64_t instance_num = size / stride;\n    user_op::Tensor* values = ctx->Tensor4ArgNameAndIndex(\"values\", 0);\n    user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    auto memcpy = NewMemcpyPrimitive(ctx);\n    CHECK(memcpy);\n    memcpy->Launch(ctx->stream(), tmp_buffer->mut_dptr<void>(), in->dptr<void>(), size * sizeof(T));\n    const int64_t thread_num =\n        std::min(instance_num, (int64_t)Singleton<ThreadPool>::Get()->thread_num());\n    const BalancedSplitter bs(instance_num, thread_num);\n    BlockingCounter bc(thread_num);\n    FOR_RANGE(int64_t, thread_id, 0, thread_num) {\n      const Range range = bs.At(thread_id);\n      Singleton<ThreadPool>::Get()->AddWork([=, &bc]() {\n        FOR_RANGE(int64_t, i, range.begin(), range.end()) {\n          T* in_ptr = tmp_buffer->mut_dptr<T>() + i * stride;\n          T* val_ptr = values->mut_dptr<T>() + i;\n          int64_t* ind_ptr = indices->mut_dptr<int64_t>() + i;\n          std::vector<std::pair<T, int64_t>> elements(stride);\n          T mode = 0;\n          int64_t mode_idx = 0;\n          int64_t temp_freq = 0;\n          int64_t max_freq = 0;\n          FOR_RANGE(int64_t, idx, 0, stride) {\n            elements[idx] = std::make_pair(*(in_ptr + idx), idx);\n          }\n          std::sort(elements.begin(), elements.end(),\n                    [=](const auto& i, const auto& j) { return i.first < j.first; });\n          FOR_RANGE(int64_t, idx, 0, stride) {\n            temp_freq++;\n            if ((idx == stride - 1) || (elements[idx].first != elements[idx + 1].first)) {\n              if (temp_freq > max_freq) {\n                mode = elements[idx].first;\n                mode_idx = elements[idx].second;\n                max_freq = temp_freq;\n              }\n              temp_freq = 0;\n            }\n          }\n          *val_ptr = mode;\n          *ind_ptr = mode_idx;\n        }\n      
  bc.Decrease();\n      });\n    }\n    bc.WaitForeverUntilCntEqualZero();\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_MODE_KERNEL(dtype)                                                    \\\n  REGISTER_USER_KERNEL(\"mode\")                                                             \\\n      .SetCreateFn<CpuModeKernel<dtype>>()                                                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                      \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                        \\\n        return ctx->InputShape(\"input\", 0).elem_cnt() * sizeof(dtype);                     \\\n      });\n\nREGISTER_CPU_MODE_KERNEL(float)\nREGISTER_CPU_MODE_KERNEL(double)\nREGISTER_CPU_MODE_KERNEL(int8_t)\nREGISTER_CPU_MODE_KERNEL(uint8_t)\nREGISTER_CPU_MODE_KERNEL(int32_t)\nREGISTER_CPU_MODE_KERNEL(int64_t)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/model_update_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/model_update_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n// For bias correction compute in CPU.\ntemplate<typename T>\nT Fastpow(T a, int64_t b) {\n  T ans = static_cast<T>(1);\n  while (b) {\n    if (b & 1) { ans *= a; }\n    a *= a;\n    b >>= 1;\n  }\n  return ans;\n}\n\ntemplate<typename T>\nvoid SumSquares2(int64_t n, const T* src0, T* dst0, const T* src1, T* dst1) {\n  *dst0 += cblas_dot<T>(n, src0, 1, src0, 1);\n  *dst1 += cblas_dot<T>(n, src1, 1, src1, 1);\n}\n\n}  // namespace\n\ntemplate<typename T, typename G, typename C>\nstruct SGDUpdateKernelUtil<DeviceType::kCPU, T, G, C> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay,\n                     float learning_rate_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,\n                     C* model_copy);\n};\n\ntemplate<typename T, typename G, typename C>\nvoid SGDUpdateKernelUtil<DeviceType::kCPU, T, G, C>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay,\n    float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n    const int64_t* skip_if, const G* model_diff, T* model, C* model_copy) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  learning_rate_val *= lr_scale;\n  for (int64_t i = 0; i != n; ++i) {\n    if (model_copy != nullptr) {\n      FusedSGDUpdateFunctor<T, G, C>()(model_diff + i, model + i, model_copy + i, scale, l1, l2,\n                                       weight_decay, learning_rate_val);\n    } else {\n      SGDUpdateFunctor<T, G>()(model_diff + i, model + i, scale, l1, l2, weight_decay,\n                               learning_rate_val);\n    }\n  }\n}\n\ntemplate struct SGDUpdateKernelUtil<DeviceType::kCPU, float, float, float16>;\ntemplate struct SGDUpdateKernelUtil<DeviceType::kCPU, double, double, float16>;\n\ntemplate<typename T, typename K, typename IDX>\nstruct IndexedSlicesSGDUpdateKernelUtil<DeviceType::kCPU, T, K, IDX> {\n  static void Update(ep::Stream* stream, float weight_decay, float lr_scale, int64_t num_indices,\n                     int64_t feature_size, int64_t lower_bound, int64_t upper_bound,\n                     const IDX* num_unique_instance, const float* learning_rate, const K* indices,\n                     const T* values, T* model);\n};\n\ntemplate<typename T, typename K, typename IDX>\nvoid IndexedSlicesSGDUpdateKernelUtil<DeviceType::kCPU, T, K, IDX>::Update(\n    ep::Stream* stream, float weight_decay, float lr_scale, int64_t num_indices,\n    int64_t feature_size, int64_t lower_bound, int64_t upper_bound, 
const IDX* num_unique_instance,\n    const float* learning_rate, const K* indices, const T* values, T* model) {\n  const int64_t n = *num_unique_instance * feature_size;\n  T lr = *learning_rate;\n  lr *= lr_scale;\n  FOR_RANGE(int64_t, i, 0, n) {\n    const IDX indices_idx = i / feature_size;\n    const IDX inner_idx = i - indices_idx * feature_size;\n    const IDX instance_id = indices[indices_idx];\n    if (instance_id >= lower_bound && instance_id < upper_bound) {\n      const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx;\n      SGDUpdateFunctor<T, T>()(values + i, model + model_idx, static_cast<T>(1), 0.0, 0.0,\n                               weight_decay, lr);\n    }\n  }\n}\n\n#define INITIATE_INDEXED_SLICES_SGD_UPDATE_KERNEL_UTIL_CPU(val_type_pair, key_type_pair,  \\\n                                                           idx_type_pair)                 \\\n  template struct IndexedSlicesSGDUpdateKernelUtil<                                       \\\n      DeviceType::kCPU, OF_PP_PAIR_FIRST(val_type_pair), OF_PP_PAIR_FIRST(key_type_pair), \\\n      OF_PP_PAIR_FIRST(idx_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_INDEXED_SLICES_SGD_UPDATE_KERNEL_UTIL_CPU,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ);\n#undef INITIATE_INDEXED_SLICES_SGD_UPDATE_KERNEL_UTIL_CPU\n\ntemplate<typename T, typename G>\nstruct MomentumUpdateKernelUtil<DeviceType::kCPU, T, G> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta,\n                     float dampening, bool nesterov, bool maximize, float weight_decay,\n                     float learning_rate_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,\n                     T* momentum);\n};\n\ntemplate<typename T, typename G>\nvoid MomentumUpdateKernelUtil<DeviceType::kCPU, T, G>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening,\n    bool nesterov, bool maximize, float weight_decay, float learning_rate_val, float lr_scale,\n    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,\n    T* model, T* momentum) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  learning_rate_val *= lr_scale;\n  for (int64_t i = 0; i != n; ++i) {\n    MomentumUpdateFunctor<T, G>()(model_diff + i, model + i, momentum + i, scale, l1, l2, beta,\n                                  dampening, nesterov, maximize, weight_decay, learning_rate_val);\n  }\n}\n\ntemplate struct MomentumUpdateKernelUtil<DeviceType::kCPU, float, float>;\ntemplate struct MomentumUpdateKernelUtil<DeviceType::kCPU, double, double>;\n\ntemplate<typename T, typename K, typename IDX>\nstruct IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCPU, T, K, IDX> {\n  static void Update(ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize,\n                     float weight_decay, float lr_scale, int64_t num_instance, int64_t feature_size,\n                     int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,\n                     const float* learning_rate, const K* indices, const T* values, T* model,\n                     T* momentum);\n};\n\ntemplate<typename T, typename K, typename IDX>\nvoid 
IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCPU, T, K, IDX>::Update(\n    ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize, float weight_decay,\n    float lr_scale, int64_t num_instance, int64_t feature_size, int64_t lower_bound,\n    int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate,\n    const K* indices, const T* values, T* model, T* momentum) {\n  const int64_t n = *num_unique_instance * feature_size;\n  T lr = *learning_rate;\n  lr *= lr_scale;\n  for (int64_t i = 0; i != n; ++i) {\n    const IDX indices_idx = i / feature_size;\n    const IDX inner_idx = i - indices_idx * feature_size;\n    const IDX instance_id = indices[indices_idx];\n    if (instance_id >= lower_bound && instance_id < upper_bound) {\n      const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx;\n      MomentumUpdateFunctor<T, T>()(values + i, model + model_idx, momentum + model_idx, 1.0, 0.0,\n                                    0.0, beta, dampening, nesterov, maximize, weight_decay, lr);\n    }\n  }\n}\n\n#define INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDATE_KERNEL_UTIL_CPU(                 \\\n    val_type_pair, key_type_pair, idx_type_pair)                                          \\\n  template struct IndexedSlicesMomentumMdUpdateKernelUtil<                                \\\n      DeviceType::kCPU, OF_PP_PAIR_FIRST(val_type_pair), OF_PP_PAIR_FIRST(key_type_pair), \\\n      OF_PP_PAIR_FIRST(idx_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDATE_KERNEL_UTIL_CPU,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ);\n#undef INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDATE_KERNEL_UTIL_CPU\n\ntemplate<typename T, typename G, typename C>\nstruct AdamUpdateKernelUtil<DeviceType::kCPU, T, G, C> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1,\n                     float beta2, float epsilon, float weight_decay, bool amsgrad,\n                     bool do_bias_correction, float learning_rate_val, float lr_scale,\n                     float bias_correction1_val, float bias_correction2_val,\n                     const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if,\n                     const float* bias_correction1, const float* bias_correction2,\n                     const G* model_diff, T* model, C* model_copy, T* m, T* v, T* max_v);\n};\n\ntemplate<typename T, typename G, typename C>\nvoid AdamUpdateKernelUtil<DeviceType::kCPU, T, G, C>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1, float beta2,\n    float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction,\n    float learning_rate_val, float lr_scale, float bias_correction1_val, float bias_correction2_val,\n    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if,\n    const float* bias_correction1_ptr, const float* bias_correction2_ptr, const G* model_diff,\n    T* model, C* model_copy, T* m, T* v, T* max_v) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  if (bias_correction1_ptr != nullptr) { bias_correction1_val = *bias_correction1_ptr; }\n  if (bias_correction2_ptr != nullptr) { bias_correction2_val = *bias_correction2_ptr; }\n\n  learning_rate_val *= lr_scale;\n  FOR_RANGE(int64_t, i, 0, n) {\n    
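    // If a float16 copy of the model is kept (model_copy != nullptr), the fused
    // functor updates the master weights and the low-precision copy in one pass.\n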
if (model_copy != nullptr) {\n      FusedAdamUpdateFunctor<T, G, C>()(model_diff + i, model + i, model_copy + i, m + i, v + i,\n                                        max_v + i, scale, l1, l2, beta1, beta2, epsilon,\n                                        weight_decay, amsgrad, bias_correction1_val,\n                                        bias_correction2_val, learning_rate_val);\n    } else {\n      AdamUpdateFunctor<T, G>()(model_diff + i, model + i, m + i, v + i, max_v + i, scale, l1, l2,\n                                beta1, beta2, epsilon, weight_decay, amsgrad, bias_correction1_val,\n                                bias_correction2_val, learning_rate_val);\n    }\n  }\n}\n\ntemplate struct AdamUpdateKernelUtil<DeviceType::kCPU, float, float, float16>;\ntemplate struct AdamUpdateKernelUtil<DeviceType::kCPU, double, double, float16>;\n\ntemplate<typename T, typename K, typename IDX>\nstruct IndexedSlicesAdamMdUpdateKernelUtil<DeviceType::kCPU, T, K, IDX> {\n  static void Update(ep::Stream* stream, float beta1, float beta2, float epsilon,\n                     float weight_decay, bool amsgrad, bool do_bias_correction, float lr,\n                     float lr_scale, int64_t num_instance, int64_t feature_size,\n                     int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,\n                     const float* learning_rate, const float* bias_correction1_ptr,\n                     const float* bias_correction2_ptr, const K* indices, const T* values, T* model,\n                     T* m, T* v, T* max_v) {\n    if (learning_rate != nullptr) { lr = *learning_rate; }\n    lr *= lr_scale;\n    float bias_correction1 = 1.0;\n    float bias_correction2 = 1.0;\n    if (bias_correction1_ptr != nullptr) { bias_correction1 = *bias_correction1_ptr; }\n    if (bias_correction2_ptr != nullptr) { bias_correction2 = *bias_correction2_ptr; }\n\n    const int64_t n = *num_unique_instance * feature_size;\n    FOR_RANGE(int64_t, i, 0, n) {\n      const IDX indices_idx = i / feature_size;\n      const IDX inner_idx = i - indices_idx * feature_size;\n      const IDX instance_id = indices[indices_idx];\n\n      if (instance_id >= lower_bound && instance_id < upper_bound) {\n        const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx;\n        // max_v is persistent per-parameter state like m and v, so it is indexed\n        // by model_idx rather than by the batch offset i.\n        AdamUpdateFunctor<T, T>()(values + i, model + model_idx, m + model_idx, v + model_idx,\n                                  max_v + model_idx, /*scale=*/1.0, /*l1=*/0.0, /*l2=*/0.0, beta1,\n                                  beta2, epsilon, weight_decay, amsgrad, bias_correction1,\n                                  bias_correction2, lr);\n      }\n    }\n  }\n};\n\n#define INSTANTIATE_INDEXED_SLICES_ADAM_MODEL_UPDATE_KERNEL_UTIL_CPU(val_type_pair, key_type_pair, \\\n                                                                     idx_type_pair)                \\\n  template struct IndexedSlicesAdamMdUpdateKernelUtil<                                             \\\n      DeviceType::kCPU, OF_PP_PAIR_FIRST(val_type_pair), OF_PP_PAIR_FIRST(key_type_pair),          \\\n      OF_PP_PAIR_FIRST(idx_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_INDEXED_SLICES_ADAM_MODEL_UPDATE_KERNEL_UTIL_CPU,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ);\n#undef INSTANTIATE_INDEXED_SLICES_ADAM_MODEL_UPDATE_KERNEL_UTIL_CPU\n\ntemplate<typename T, typename G>\nstruct AdagradUpdateKernelUtil<DeviceType::kCPU, T, G> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, 
float l1, float l2, float lr_decay,\n                     float epsilon, float weight_decay, float learning_rate_val, float lr_scale,\n                     int64_t train_step, const float* learning_rate, const int64_t* train_step_ptr,\n                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,\n                     T* sum);\n};\n\ntemplate<typename T, typename G>\nvoid AdagradUpdateKernelUtil<DeviceType::kCPU, T, G>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_decay, float epsilon,\n    float weight_decay, float learning_rate_val, float lr_scale, int64_t train_step,\n    const float* learning_rate, const int64_t* train_step_ptr, const T* scale_by_ptr,\n    const int64_t* skip_if, const G* model_diff, T* model, T* sum) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n  if (train_step_ptr != nullptr) {\n    train_step = *train_step_ptr + 1;\n  }  // train_step_ptr start from zero.\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  learning_rate_val = learning_rate_val * lr_scale / (1 + (train_step - 1) * lr_decay);\n\n  FOR_RANGE(int64_t, i, 0, n) {\n    AdagradUpdateFunctor<T, G>()(model_diff + i, model + i, sum + i, scale, l1, l2, epsilon,\n                                 weight_decay, learning_rate_val);\n  }\n}\n\ntemplate struct AdagradUpdateKernelUtil<DeviceType::kCPU, float, float>;\ntemplate struct AdagradUpdateKernelUtil<DeviceType::kCPU, double, double>;\n\ntemplate<typename T, typename G>\nstruct LambUpdateKernelUtil<DeviceType::kCPU, T, G> {\n  static void Update(ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1,\n                     float beta2, float epsilon, float weight_decay, float learning_rate_val,\n                     float lr_scale, bool do_bias_correction, float bias_correction1_val,\n                     float bias_correction2_val, const float* learning_rate_ptr,\n                     const float* bias_correction1_ptr, const float* bias_correction2_ptr,\n                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,\n                     T* adam_diff, T* model, T* m, T* v, T* norm_buffer);\n};\n\ntemplate<typename T, typename G>\nvoid LambUpdateKernelUtil<DeviceType::kCPU, T, G>::Update(\n    ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, float beta2,\n    float epsilon, float weight_decay, float learning_rate_val, float lr_scale,\n    bool do_bias_correction, float bias_correction1_val, float bias_correction2_val,\n    const float* learning_rate_ptr, const float* bias_correction1_ptr,\n    const float* bias_correction2_ptr, const T* scale_by_ptr, const int64_t* skip_if,\n    const G* model_diff, T* adam_diff, T* model, T* m, T* v, T* norm_buffer) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate_ptr != nullptr) { learning_rate_val = *learning_rate_ptr; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  if (bias_correction1_ptr != nullptr) { bias_correction1_val = *bias_correction1_ptr; }\n  if (bias_correction2_ptr != nullptr) { bias_correction2_val = *bias_correction2_ptr; }\n\n  FOR_RANGE(int64_t, i, 0, n) {\n    LambGradFunctor<T, G>()(model_diff + i, adam_diff + i, model + i, m + i, v + i, scale, l1, l2,\n                            beta1, beta2, epsilon, do_bias_correction, bias_correction1_val,\n                            bias_correction2_val);\n  }\n  T* w_norm_2 = 
norm_buffer;\n  T* g_norm_2 = norm_buffer + 1;\n  Memset<DeviceType::kCPU>(stream, norm_buffer, 0, 2 * sizeof(T));\n  SumSquares2(n, model, w_norm_2, adam_diff, g_norm_2);\n  learning_rate_val *= lr_scale;\n  const float lr = LambLRFunctor<T>()(learning_rate_val, w_norm_2, g_norm_2);\n  FOR_RANGE(int64_t, i, 0, n) {\n    LambUpdateFunctor<T>()(lr, weight_decay, adam_diff + i, model + i);\n  }\n}\n\ntemplate struct LambUpdateKernelUtil<DeviceType::kCPU, float, float>;\ntemplate struct LambUpdateKernelUtil<DeviceType::kCPU, double, double>;\n\ntemplate<>\nstruct BiasCorrectionFactorKernelUtil<DeviceType::kCPU> {\n  static void BiasCorrectionFactorCompute(ep::Stream* stream, float beta, const int64_t* train_step,\n                                          float* out);\n};\n\nvoid BiasCorrectionFactorKernelUtil<DeviceType::kCPU>::BiasCorrectionFactorCompute(\n    ep::Stream* stream, float beta, const int64_t* train_step, float* out) {\n  const float bias_correction_factor = 1.0 - Fastpow<float>(beta, *train_step + 1);\n  *out = bias_correction_factor;\n}\n\ntemplate<typename T, typename G>\nstruct RmsPropUpdateKernelUtil<DeviceType::kCPU, T, G> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, bool centered,\n                     float epsilon, float weight_decay, float decay_rate, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const G* model_diff, T* model, T* mean_square,\n                     T* mean_gradient);\n};\n\ntemplate<typename T, typename G>\nvoid RmsPropUpdateKernelUtil<DeviceType::kCPU, T, G>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, bool centered, float epsilon,\n    float weight_decay, float decay_rate, float learning_rate_val, float lr_scale,\n    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,\n    T* model, T* mean_square, T* mean_gradient) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  learning_rate_val *= lr_scale;\n  if (centered) {\n    FOR_RANGE(int64_t, i, 0, n) {\n      RmsPropUpdateFunctor<T, G, true>()(model_diff + i, model + i, n, scale, l1, l2,\n                                         mean_square + i, mean_gradient + i, epsilon, weight_decay,\n                                         decay_rate, learning_rate_val);\n    }\n  } else {\n    FOR_RANGE(int64_t, i, 0, n) {\n      RmsPropUpdateFunctor<T, G, false>()(model_diff + i, model + i, n, scale, l1, l2,\n                                          mean_square + i, nullptr, epsilon, weight_decay,\n                                          decay_rate, learning_rate_val);\n    }\n  }\n}\n\ntemplate struct RmsPropUpdateKernelUtil<DeviceType::kCPU, float, float>;\ntemplate struct RmsPropUpdateKernelUtil<DeviceType::kCPU, double, double>;\n\ntemplate<typename T, typename G>\nstruct LarsUpdateKernelUtil<DeviceType::kCPU, T, G> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2,\n                     float momentum_beta, float epsilon, float lars_coefficient, float weight_decay,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const G* model_diff, T* model, T* momentum,\n                     T* data_tmp, T* 
model_diff_tmp);\n};\n\ntemplate<typename T, typename G>\nvoid LarsUpdateKernelUtil<DeviceType::kCPU, T, G>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float momentum_beta, float epsilon,\n    float lars_coefficient, float weight_decay, float lr_scale, const float* learning_rate,\n    const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* momentum,\n    T* data_tmp, T* model_diff_tmp) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  // data_tmp is uninitialized scratch space, so start the norm accumulators at\n  // zero instead of seeding them with whatever the buffer happens to contain.\n  T model_norm = 0;\n  T model_diff_norm = 0;\n  FOR_RANGE(int64_t, i, 0, n) {\n    model_diff_tmp[i] =\n        CastScaleRegularizeGradientFunctor<T, G>()(model_diff[i], model[i], scale, l1, l2);\n  }\n  SumSquares2(n, model, &model_norm, model_diff_tmp, &model_diff_norm);\n  model_norm = std::sqrt(model_norm);\n  model_diff_norm = std::sqrt(model_diff_norm);\n  T lars = static_cast<T>(1);\n  if (model_norm > 0 && model_diff_norm > 0) {\n    lars = lars_coefficient * model_norm / (epsilon + model_diff_norm + weight_decay * model_norm);\n  }\n  T lr = *learning_rate;\n  lr *= lr_scale;\n  T local_learning_rate = lr * lars;\n  FOR_RANGE(int64_t, i, 0, n) {\n    LarsUpdateFunctor<T>()(model_diff_tmp + i, model + i, momentum_beta, momentum + i, weight_decay,\n                           local_learning_rate);\n  }\n}\n\ntemplate struct LarsUpdateKernelUtil<DeviceType::kCPU, float, float>;\ntemplate struct LarsUpdateKernelUtil<DeviceType::kCPU, double, double>;\n\ntemplate<typename T, typename G>\nstruct FtrlUpdateKernelUtil<DeviceType::kCPU, T, G> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_power,\n                     float lambda1, float lambda2, float beta, float weight_decay,\n                     float learning_rate_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,\n                     T* accumulate, T* z);\n};\n\ntemplate<typename T, typename G>\nvoid FtrlUpdateKernelUtil<DeviceType::kCPU, T, G>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_power, float lambda1,\n    float lambda2, float beta, float weight_decay, float learning_rate_val, float lr_scale,\n    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,\n    T* model, T* accumulate, T* z) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  learning_rate_val *= lr_scale;\n  for (int64_t i = 0; i != n; ++i) {\n    FtrlUpdateFunctor<T, G>()(model_diff + i, model + i, accumulate + i, z + i, scale, l1, l2,\n                              lr_power, lambda1, lambda2, beta, weight_decay, learning_rate_val);\n  }\n}\n\ntemplate struct FtrlUpdateKernelUtil<DeviceType::kCPU, float, float>;\ntemplate struct FtrlUpdateKernelUtil<DeviceType::kCPU, double, double>;\n\ntemplate<typename T, typename G>\nstruct AdadeltaUpdateKernelUtil<DeviceType::kCPU, T, G> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float rho,\n                     float epsilon, bool maximize, float weight_decay, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                
     const int64_t* skip_if, const G* model_diff, T* model, T* square_avgs,\n                     T* acc_deltas);\n};\n\ntemplate<typename T, typename G>\nvoid AdadeltaUpdateKernelUtil<DeviceType::kCPU, T, G>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float rho, float epsilon,\n    bool maximize, float weight_decay, float learning_rate_val, float lr_scale,\n    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,\n    T* model, T* square_avgs, T* acc_deltas) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  learning_rate_val *= lr_scale;\n  for (int64_t i = 0; i != n; ++i) {\n    AdadeltaUpdateFunctor<T, G>()(model_diff + i, model + i, square_avgs + i, acc_deltas + i, scale,\n                                  l1, l2, rho, epsilon, maximize, weight_decay, learning_rate_val);\n  }\n}\n\ntemplate struct AdadeltaUpdateKernelUtil<DeviceType::kCPU, float, float>;\ntemplate struct AdadeltaUpdateKernelUtil<DeviceType::kCPU, double, double>;\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/model_update_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/user/kernels/model_update_kernel_util.h\"\n#include <cub/cub.cuh>\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename G, typename C>\n__global__ void SGDUpdateGpu(int64_t n, T scale, float l1, float l2, float weight_decay,\n                             float learning_rate_val, float lr_scale, const float* learning_rate,\n                             const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,\n                             T* model, C* model_copy) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  learning_rate_val *= lr_scale;\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    if (model_copy != nullptr) {\n      FusedSGDUpdateFunctor<T, G, C>()(model_diff + i, model + i, model_copy + i, scale, l1, l2,\n                                       weight_decay, learning_rate_val);\n    } else {\n      SGDUpdateFunctor<T, G>()(model_diff + i, model + i, scale, l1, l2, weight_decay,\n                               learning_rate_val);\n    }\n  }\n}\n\ntemplate<typename T, typename K, typename IDX>\n__global__ void IndexedSlicesSGDUpdateGpu(float weight_decay, float lr_scale,\n                                          const IDX feature_size, const int64_t lower_bound,\n                                          const int64_t upper_bound, const IDX* num_unique_instance,\n                                          const float* learning_rate, const K* indices,\n                                          const T* values, T* model) {\n  const int64_t n = *num_unique_instance * feature_size;\n  T lr = *learning_rate;\n  lr *= lr_scale;\n  CUDA_1D_KERNEL_LOOP_T(IDX, i, n) {\n    const IDX indices_idx = i / feature_size;\n    const IDX inner_idx = i - indices_idx * feature_size;\n    const IDX instance_id = indices[indices_idx];\n    if (instance_id >= lower_bound && instance_id < upper_bound) {\n      const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx;\n      SGDUpdateFunctor<T, T>()(values + i, model + model_idx, static_cast<T>(1), 0.0, 0.0,\n                               weight_decay, lr);\n    }\n  }\n}\n\ntemplate<typename T>\n__global__ void SumSquares2(int64_t n, const T* src0, T* dst0, const T* src1, T* dst1) {\n  T t_sum0 = 0;\n  T t_sum1 = 0;\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    t_sum0 += src0[i] * src0[i];\n    t_sum1 += src1[i] * src1[i];\n  }\n  typedef cub::BlockReduce<T, kCudaThreadsNumPerBlock> BlockReduce;\n  __shared__ typename BlockReduce::TempStorage temp_storage0;\n  __shared__ typename BlockReduce::TempStorage temp_storage1;\n  T b_sum0 = BlockReduce(temp_storage0).Sum(t_sum0);\n  T b_sum1 = 
BlockReduce(temp_storage1).Sum(t_sum1);\n  if (threadIdx.x == 0) {\n    cuda::atomic::Add(dst0, b_sum0);\n    cuda::atomic::Add(dst1, b_sum1);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename G, typename C>\nstruct SGDUpdateKernelUtil<DeviceType::kCUDA, T, G, C> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay,\n                     float learning_rate_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,\n                     C* model_copy);\n};\n\ntemplate<typename T, typename G, typename C>\nvoid SGDUpdateKernelUtil<DeviceType::kCUDA, T, G, C>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay,\n    float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n    const int64_t* skip_if, const G* model_diff, T* model, C* model_copy) {\n  SGDUpdateGpu<T, G, C><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                          stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      n, scale, l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate, scale_by_ptr,\n      skip_if, model_diff, model, model_copy);\n}\n\ntemplate<typename T, typename G>\nstruct SGDUpdateKernelUtil<DeviceType::kCUDA, T, G, float16> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay,\n                     float learning_rate_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,\n                     float16* model_copy);\n};\n\ntemplate<typename T, typename G>\nvoid SGDUpdateKernelUtil<DeviceType::kCUDA, T, G, float16>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay,\n    float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n    const int64_t* skip_if, const G* model_diff, T* model, float16* model_copy) {\n  SGDUpdateKernelUtil<DeviceType::kCUDA, T, G, half>::Update(\n      stream, n, scale, l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate,\n      scale_by_ptr, skip_if, model_diff, model, reinterpret_cast<half*>(model_copy));\n}\n\ntemplate<typename T>\nstruct SGDUpdateKernelUtil<DeviceType::kCUDA, T, float16, float16> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay,\n                     float learning_rate_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff,\n                     T* model, float16* model_copy);\n};\n\ntemplate<typename T>\nvoid SGDUpdateKernelUtil<DeviceType::kCUDA, T, float16, float16>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay,\n    float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n    const int64_t* skip_if, const float16* model_diff, T* model, float16* model_copy) {\n  SGDUpdateKernelUtil<DeviceType::kCUDA, T, half, half>::Update(\n      stream, n, scale, l1, l2, weight_decay, learning_rate_val, lr_scale, learning_rate,\n      scale_by_ptr, skip_if, reinterpret_cast<const half*>(model_diff), model,\n      reinterpret_cast<half*>(model_copy));\n}\n\ntemplate struct SGDUpdateKernelUtil<DeviceType::kCUDA, double, double, float16>;\ntemplate struct 
SGDUpdateKernelUtil<DeviceType::kCUDA, float, float, float16>;\ntemplate struct SGDUpdateKernelUtil<DeviceType::kCUDA, float, float16, float16>;\n\ntemplate<typename T, typename K, typename IDX>\nstruct IndexedSlicesSGDUpdateKernelUtil<DeviceType::kCUDA, T, K, IDX> {\n  static void Update(ep::Stream* stream, float weight_decay, float lr_scale, int64_t num_indices,\n                     int64_t feature_size, int64_t lower_bound, int64_t upper_bound,\n                     const IDX* num_unique_instance, const float* learning_rate, const K* indices,\n                     const T* values, T* model);\n};\n\ntemplate<typename T, typename K, typename IDX>\nvoid IndexedSlicesSGDUpdateKernelUtil<DeviceType::kCUDA, T, K, IDX>::Update(\n    ep::Stream* stream, float weight_decay, float lr_scale, int64_t num_indices,\n    int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,\n    const float* learning_rate, const K* indices, const T* values, T* model) {\n  IndexedSlicesSGDUpdateGpu<T, K, IDX>\n      <<<BlocksNum4ThreadsNum(num_indices * feature_size), kCudaThreadsNumPerBlock, 0,\n         stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          weight_decay, lr_scale, feature_size, lower_bound, upper_bound, num_unique_instance,\n          learning_rate, indices, values, model);\n}\n\n#define INITIATE_INDEXED_SLICES_SGD_UPDATE_KERNEL_UTIL_CUDA(val_type_pair, key_type_pair,  \\\n                                                            idx_type_pair)                 \\\n  template struct IndexedSlicesSGDUpdateKernelUtil<                                        \\\n      DeviceType::kCUDA, OF_PP_PAIR_FIRST(val_type_pair), OF_PP_PAIR_FIRST(key_type_pair), \\\n      OF_PP_PAIR_FIRST(idx_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_INDEXED_SLICES_SGD_UPDATE_KERNEL_UTIL_CUDA,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ);\n#undef INITIATE_INDEXED_SLICES_SGD_UPDATE_KERNEL_UTIL_CUDA\n\nnamespace {\n\ntemplate<typename T, typename G>\n__global__ void MomentumUpdateGpu(int64_t n, T scale, float l1, float l2, float beta,\n                                  float dampening, bool nesterov, bool maximize, float weight_decay,\n                                  float learning_rate_val, float lr_scale,\n                                  const float* learning_rate, const T* scale_by_ptr,\n                                  const int64_t* skip_if, const G* model_diff, T* model,\n                                  T* momentum) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  learning_rate_val *= lr_scale;\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    MomentumUpdateFunctor<T, G>()(model_diff + i, model + i, momentum + i, scale, l1, l2, beta,\n                                  dampening, nesterov, maximize, weight_decay, learning_rate_val);\n  }\n}\n\ntemplate<typename T, typename K, typename IDX>\n__global__ void IndexedSlicesMomentumUpdateGpu(T beta, float dampening, bool nesterov,\n                                               bool maximize, float weight_decay, float lr_scale,\n                                               int64_t feature_size, int64_t lower_bound,\n                                               int64_t upper_bound, const IDX* num_unique_instance,\n                                               const float* learning_rate, const K* indices,\n                           
                    const T* values, T* model, T* momentum) {\n  const int64_t n = *num_unique_instance * feature_size;\n  T lr = *learning_rate;\n  lr *= lr_scale;\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const IDX indices_idx = i / feature_size;\n    const IDX inner_idx = i - indices_idx * feature_size;\n    const IDX instance_id = indices[indices_idx];\n    if (instance_id >= lower_bound && instance_id < upper_bound) {\n      const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx;\n      MomentumUpdateFunctor<T, T>()(values + i, model + model_idx, momentum + model_idx,\n                                    static_cast<T>(1), 0.0, 0.0, beta, dampening, nesterov,\n                                    maximize, weight_decay, lr);\n    }\n  }\n}\n}  // namespace\n\ntemplate<typename T, typename G>\nstruct MomentumUpdateKernelUtil<DeviceType::kCUDA, T, G> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta,\n                     float dampening, bool nesterov, bool maximize, float weight_decay,\n                     float learning_rate_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,\n                     T* momentum);\n};\n\ntemplate<typename T, typename G>\nvoid MomentumUpdateKernelUtil<DeviceType::kCUDA, T, G>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening,\n    bool nesterov, bool maximize, float weight_decay, float learning_rate_val, float lr_scale,\n    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,\n    T* model, T* momentum) {\n  MomentumUpdateGpu<T, G><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                            stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      n, scale, l1, l2, beta, dampening, nesterov, maximize, weight_decay, learning_rate_val,\n      lr_scale, learning_rate, scale_by_ptr, skip_if, model_diff, model, momentum);\n}\n\ntemplate<typename T>\nstruct MomentumUpdateKernelUtil<DeviceType::kCUDA, T, float16> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta,\n                     float dampening, bool nesterov, bool maximize, float weight_decay,\n                     float learning_rate_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff,\n                     T* model, T* momentum);\n};\n\ntemplate<typename T>\nvoid MomentumUpdateKernelUtil<DeviceType::kCUDA, T, float16>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening,\n    bool nesterov, bool maximize, float weight_decay, float learning_rate_val, float lr_scale,\n    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if,\n    const float16* model_diff, T* model, T* momentum) {\n  MomentumUpdateKernelUtil<DeviceType::kCUDA, T, half>::Update(\n      stream, n, scale, l1, l2, beta, dampening, nesterov, maximize, weight_decay,\n      learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if,\n      reinterpret_cast<const half*>(model_diff), model, momentum);\n}\n\ntemplate struct MomentumUpdateKernelUtil<DeviceType::kCUDA, double, double>;\ntemplate struct MomentumUpdateKernelUtil<DeviceType::kCUDA, float, float>;\ntemplate struct MomentumUpdateKernelUtil<DeviceType::kCUDA, float, float16>;\n\ntemplate<typename T, typename 
K, typename IDX>\nstruct IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCUDA, T, K, IDX> {\n  static void Update(ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize,\n                     float weight_decay, float lr_scale, int64_t num_instance, int64_t feature_size,\n                     int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,\n                     const float* learning_rate, const K* indices, const T* values, T* model,\n                     T* momentum);\n};\n\ntemplate<typename T, typename K, typename IDX>\nvoid IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCUDA, T, K, IDX>::Update(\n    ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize, float weight_decay,\n    float lr_scale, int64_t num_instance, int64_t feature_size, int64_t lower_bound,\n    int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate,\n    const K* indices, const T* values, T* model, T* momentum) {\n  IndexedSlicesMomentumUpdateGpu<T, K, IDX>\n      <<<BlocksNum4ThreadsNum(num_instance * feature_size), kCudaThreadsNumPerBlock, 0,\n         stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          beta, dampening, nesterov, maximize, weight_decay, lr_scale, feature_size, lower_bound,\n          upper_bound, num_unique_instance, learning_rate, indices, values, model, momentum);\n}\n\n#define INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDATE_KERNEL_UTIL_CUDA(                 \\\n    val_type_pair, key_type_pair, idx_type_pair)                                           \\\n  template struct IndexedSlicesMomentumMdUpdateKernelUtil<                                 \\\n      DeviceType::kCUDA, OF_PP_PAIR_FIRST(val_type_pair), OF_PP_PAIR_FIRST(key_type_pair), \\\n      OF_PP_PAIR_FIRST(idx_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDATE_KERNEL_UTIL_CUDA,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ);\n#undef INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDATE_KERNEL_UTIL_CUDA\n\nnamespace {\n\n__global__ void BiasCorrectionFactorKernelGpu(float beta, const int64_t* train_step, float* out) {\n  const auto exponent = static_cast<double>(*train_step + 1);\n  const float bias_correction_factor = 1.0 - static_cast<float>(pow(beta, exponent));\n  *out = bias_correction_factor;\n}\n\ntemplate<typename T, typename G, typename C>\n__global__ void AdamUpdateGpu(int64_t n, T scale, float l1, float l2, float beta1, float beta2,\n                              float epsilon, float weight_decay, bool amsgrad,\n                              bool do_bias_correction, float learning_rate_val, float lr_scale,\n                              float bias_correction1_val, float bias_correction2_val,\n                              const float* learning_rate, const T* scale_by_ptr,\n                              const int64_t* skip_if, const float* bias_correction1_ptr,\n                              const float* bias_correction2_ptr, const G* model_diff, T* model,\n                              C* model_copy, T* m, T* v, T* max_v) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  if (bias_correction1_ptr != nullptr) { bias_correction1_val = *bias_correction1_ptr; }\n  if (bias_correction2_ptr != nullptr) { bias_correction2_val = *bias_correction2_ptr; }\n\n  learning_rate_val *= lr_scale;\n  
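// With a non-null model_copy, the fused functor below also stores the updated model as a cast copy.\n  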
CUDA_1D_KERNEL_LOOP(i, n) {\n    if (model_copy != nullptr) {\n      FusedAdamUpdateFunctor<T, G, C>()(model_diff + i, model + i, model_copy + i, m + i, v + i,\n                                        max_v + i, scale, l1, l2, beta1, beta2, epsilon,\n                                        weight_decay, amsgrad, bias_correction1_val,\n                                        bias_correction2_val, learning_rate_val);\n    } else {\n      AdamUpdateFunctor<T, G>()(model_diff + i, model + i, m + i, v + i, max_v + i, scale, l1, l2,\n                                beta1, beta2, epsilon, weight_decay, amsgrad, bias_correction1_val,\n                                bias_correction2_val, learning_rate_val);\n    }\n  }\n}\n\ntemplate<typename T>\n__global__ void AdamUpdateBetaTGpu(const T beta1, const T beta2, const int64_t* skip_if, T* beta1_t,\n                                   T* beta2_t) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  *beta1_t *= beta1;\n  *beta2_t *= beta2;\n}\n\ntemplate<typename T, typename K, typename IDX>\n__global__ void IndexedSlicesAdamUpdateGpu(\n    float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad,\n    bool do_bias_correction, float lr, float lr_scale, int64_t feature_size, int64_t lower_bound,\n    int64_t upper_bound, const IDX* num_unique_instance, const float* learning_rate,\n    const float* bias_correction1_ptr, const float* bias_correction2_ptr, const K* indices,\n    const T* values, T* model, T* m, T* v, T* max_v) {\n  if (learning_rate != nullptr) { lr = *learning_rate; }\n  lr *= lr_scale;\n  float bias_correction1 = 1.0;\n  float bias_correction2 = 1.0;\n  if (bias_correction1_ptr != nullptr) { bias_correction1 = *bias_correction1_ptr; }\n  if (bias_correction2_ptr != nullptr) { bias_correction2 = *bias_correction2_ptr; }\n  const int64_t n = *num_unique_instance * feature_size;\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const IDX indices_idx = i / feature_size;\n    const IDX inner_idx = i - indices_idx * feature_size;\n    const IDX instance_id = indices[indices_idx];\n    if (instance_id >= lower_bound && instance_id < upper_bound) {\n      const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx;\n      AdamUpdateFunctor<T, T>()(values + i, model + model_idx, m + model_idx, v + model_idx,\n                                max_v + i, static_cast<T>(1), 0, 0, beta1, beta2, epsilon,\n                                weight_decay, amsgrad, bias_correction1, bias_correction2, lr);\n    }\n  }\n}\n\ntemplate<typename T, typename G>\n__global__ void LambGradGpu(int64_t n, T scale, float l1, float l2, float beta1, float beta2,\n                            float epsilon, const T* scale_by_ptr, const int64_t* skip_if,\n                            const G* model_diff, T* adam_diff, T* model, T* m, T* v,\n                            bool do_bias_correction, float bias_correction1_val,\n                            float bias_correction2_val, const float* bias_correction1_ptr,\n                            const float* bias_correction2_ptr) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  if (bias_correction1_ptr != nullptr) { bias_correction1_val = *bias_correction1_ptr; }\n  if (bias_correction2_ptr != nullptr) { bias_correction2_val = *bias_correction2_ptr; }\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    LambGradFunctor<T, G>()(model_diff + i, adam_diff + i, model + i, m + i, v + i, scale, l1, l2,\n                            beta1, beta2, epsilon, 
do_bias_correction, bias_correction1_val,\n                            bias_correction2_val);\n  }\n}\n\ntemplate<typename T>\n__global__ void LambUpdateGpu(int64_t n, float weight_decay, float learning_rate_val,\n                              float lr_scale, const float* learning_rate_ptr,\n                              const int64_t* skip_if, const T* w_norm_2, const T* g_norm_2,\n                              const T* adam_diff, T* model) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate_ptr != nullptr) { learning_rate_val = *learning_rate_ptr; }\n  learning_rate_val *= lr_scale;\n  const float lr = LambLRFunctor<T>()(learning_rate_val, w_norm_2, g_norm_2);\n  CUDA_1D_KERNEL_LOOP(i, n) { LambUpdateFunctor<T>()(lr, weight_decay, adam_diff + i, model + i); }\n}\n\n}  // namespace\n\ntemplate<typename T, typename G, typename C>\nstruct AdamUpdateKernelUtil<DeviceType::kCUDA, T, G, C> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1,\n                     float beta2, float epsilon, float weight_decay, bool amsgrad,\n                     bool do_bias_correction, float learning_rate_val, float lr_scale,\n                     float bias_correction1_val, float bias_correction2_val,\n                     const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if,\n                     const float* bias_correction1_ptr, const float* bias_correction2_ptr,\n                     const G* model_diff, T* model, C* model_copy, T* m, T* v, T* max_v);\n};\n\ntemplate<typename T, typename G, typename C>\nvoid AdamUpdateKernelUtil<DeviceType::kCUDA, T, G, C>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1, float beta2,\n    float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction,\n    float learning_rate_val, float lr_scale, float bias_correction1_val, float bias_correction2_val,\n    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if,\n    const float* bias_correction1_ptr, const float* bias_correction2_ptr, const G* model_diff,\n    T* model, C* model_copy, T* m, T* v, T* max_v) {\n  AdamUpdateGpu<T, G, C><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                           stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      n, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction,\n      learning_rate_val, lr_scale, bias_correction1_val, bias_correction2_val, learning_rate,\n      scale_by_ptr, skip_if, bias_correction1_ptr, bias_correction2_ptr, model_diff, model,\n      model_copy, m, v, max_v);\n}\n\ntemplate<typename T, typename G>\nstruct AdamUpdateKernelUtil<DeviceType::kCUDA, T, G, float16> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1,\n                     float beta2, float epsilon, float weight_decay, bool amsgrad,\n                     bool do_bias_correction, float learning_rate_val, float lr_scale,\n                     float bias_correction1_val, float bias_correction2_val,\n                     const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if,\n                     const float* bias_correction1_ptr, const float* bias_correction2_ptr,\n                     const G* model_diff, T* model, float16* model_copy, T* m, T* v, T* max_v);\n};\n\ntemplate<typename T, typename G>\nvoid AdamUpdateKernelUtil<DeviceType::kCUDA, T, G, float16>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, 
float beta1, float beta2,\n    float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction,\n    float learning_rate_val, float lr_scale, float bias_correction1_val, float bias_correction2_val,\n    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if,\n    const float* bias_correction1_ptr, const float* bias_correction2_ptr, const G* model_diff,\n    T* model, float16* model_copy, T* m, T* v, T* max_v) {\n  AdamUpdateKernelUtil<DeviceType::kCUDA, T, G, half>::Update(\n      stream, n, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction,\n      learning_rate_val, lr_scale, bias_correction1_val, bias_correction2_val, learning_rate,\n      scale_by_ptr, skip_if, bias_correction1_ptr, bias_correction2_ptr, model_diff, model,\n      reinterpret_cast<half*>(model_copy), m, v, max_v);\n}\n\ntemplate<typename T>\nstruct AdamUpdateKernelUtil<DeviceType::kCUDA, T, float16, float16> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1,\n                     float beta2, float epsilon, float weight_decay, bool amsgrad,\n                     bool do_bias_correction, float learning_rate_val, float lr_scale,\n                     float bias_correction1_val, float bias_correction2_val,\n                     const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if,\n                     const float* bias_correction1_ptr, const float* bias_correction2_ptr,\n                     const float16* model_diff, T* model, float16* model_copy, T* m, T* v,\n                     T* max_v);\n};\n\ntemplate<typename T>\nvoid AdamUpdateKernelUtil<DeviceType::kCUDA, T, float16, float16>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1, float beta2,\n    float epsilon, float weight_decay, bool amsgrad, bool do_bias_correction,\n    float learning_rate_val, float lr_scale, float bias_correction1_val, float bias_correction2_val,\n    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if,\n    const float* bias_correction1_ptr, const float* bias_correction2_ptr, const float16* model_diff,\n    T* model, float16* model_copy, T* m, T* v, T* max_v) {\n  AdamUpdateKernelUtil<DeviceType::kCUDA, T, half, half>::Update(\n      stream, n, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction,\n      learning_rate_val, lr_scale, bias_correction1_val, bias_correction2_val, learning_rate,\n      scale_by_ptr, skip_if, bias_correction1_ptr, bias_correction2_ptr,\n      reinterpret_cast<const half*>(model_diff), model, reinterpret_cast<half*>(model_copy), m, v,\n      max_v);\n}\n\ntemplate struct AdamUpdateKernelUtil<DeviceType::kCUDA, float, float, float16>;\ntemplate struct AdamUpdateKernelUtil<DeviceType::kCUDA, double, double, float16>;\ntemplate struct AdamUpdateKernelUtil<DeviceType::kCUDA, float, float16, float16>;\n\ntemplate<typename T, typename G>\n__global__ void AdagradUpdateGpu(int64_t n, T scale, float l1, float l2, float lr_decay,\n                                 float epsilon, float weight_decay, float learning_rate_val,\n                                 float lr_scale, int64_t train_step, const float* learning_rate,\n                                 const int64_t* train_step_ptr, const T* scale_by_ptr,\n                                 const int64_t* skip_if, const G* model_diff, T* model, T* sum) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = 
*learning_rate; }\n  if (train_step_ptr != nullptr) {\n    train_step = *train_step_ptr + 1;\n  }  // train_step_ptr start from zero.\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  learning_rate_val = learning_rate_val * lr_scale / (1 + (train_step - 1) * lr_decay);\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    AdagradUpdateFunctor<T, G>()(model_diff + i, model + i, sum + i, scale, l1, l2, epsilon,\n                                 weight_decay, learning_rate_val);\n  }\n}\n\ntemplate<typename T, typename G>\nstruct AdagradUpdateKernelUtil<DeviceType::kCUDA, T, G> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_decay,\n                     float epsilon, float weight_decay, float learning_rate_val, float lr_scale,\n                     int64_t train_step, const float* learning_rate, const int64_t* train_step_ptr,\n                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,\n                     T* sum);\n};\n\ntemplate<typename T, typename G>\nvoid AdagradUpdateKernelUtil<DeviceType::kCUDA, T, G>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_decay, float epsilon,\n    float weight_decay, float learning_rate_val, float lr_scale, int64_t train_step,\n    const float* learning_rate, const int64_t* train_step_ptr, const T* scale_by_ptr,\n    const int64_t* skip_if, const G* model_diff, T* model, T* sum) {\n  AdagradUpdateGpu<T, G><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                           stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      n, scale, l1, l2, lr_decay, epsilon, weight_decay, learning_rate_val, lr_scale, train_step,\n      learning_rate, train_step_ptr, scale_by_ptr, skip_if, model_diff, model, sum);\n}\n\ntemplate struct AdagradUpdateKernelUtil<DeviceType::kCUDA, float, float>;\ntemplate struct AdagradUpdateKernelUtil<DeviceType::kCUDA, double, double>;\n\ntemplate<typename T, typename G>\nstruct LambUpdateKernelUtil<DeviceType::kCUDA, T, G> {\n  static void Update(ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1,\n                     float beta2, float epsilon, float weight_decay, float learning_rate_val,\n                     float lr_scale, bool do_bias_correction, float bias_correction1_val,\n                     float bias_correction2_val, const float* learning_rate_ptr,\n                     const float* bias_correction1_ptr, const float* bias_correction2_ptr,\n                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,\n                     T* adam_diff, T* model, T* m, T* v, T* norm_buffer);\n};\n\ntemplate<typename T, typename G>\nvoid LambUpdateKernelUtil<DeviceType::kCUDA, T, G>::Update(\n    ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, float beta2,\n    float epsilon, float weight_decay, float learning_rate_val, float lr_scale,\n    bool do_bias_correction, float bias_correction1_val, float bias_correction2_val,\n    const float* learning_rate_ptr, const float* bias_correction1_ptr,\n    const float* bias_correction2_ptr, const T* scale_by_ptr, const int64_t* skip_if,\n    const G* model_diff, T* adam_diff, T* model, T* m, T* v, T* norm_buffer) {\n  LambGradGpu<T, G><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                      stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      n, scale, l1, l2, beta1, beta2, epsilon, scale_by_ptr, skip_if, model_diff, adam_diff, model,\n      m, v, do_bias_correction, 
bias_correction1_val, bias_correction2_val, bias_correction1_ptr,\n      bias_correction2_ptr);\n  T* w_norm_2 = norm_buffer;\n  T* g_norm_2 = norm_buffer + 1;\n  Memset<DeviceType::kCUDA>(stream, norm_buffer, 0, 2 * sizeof(T));\n  SumSquares2<T>\n      <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n         stream->As<ep::CudaStream>()->cuda_stream()>>>(n, model, w_norm_2, adam_diff, g_norm_2);\n  LambUpdateGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                     stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      n, weight_decay, learning_rate_val, lr_scale, learning_rate_ptr, skip_if, w_norm_2, g_norm_2,\n      adam_diff, model);\n}\n\ntemplate<typename T>\nstruct LambUpdateKernelUtil<DeviceType::kCUDA, T, float16> {\n  static void Update(ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1,\n                     float beta2, float epsilon, float weight_decay, float learning_rate_val,\n                     float lr_scale, bool do_bias_correction, float bias_correction1_val,\n                     float bias_correction2_val, const float* learning_rate_ptr,\n                     const float* bias_correction1_ptr, const float* bias_correction2_ptr,\n                     const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff,\n                     T* adam_diff, T* model, T* m, T* v, T* norm_buffer);\n};\n\ntemplate<typename T>\nvoid LambUpdateKernelUtil<DeviceType::kCUDA, T, float16>::Update(\n    ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, float beta2,\n    float epsilon, float weight_decay, float learning_rate_val, float lr_scale,\n    bool do_bias_correction, float bias_correction1_val, float bias_correction2_val,\n    const float* learning_rate_ptr, const float* bias_correction1_ptr,\n    const float* bias_correction2_ptr, const T* scale_by_ptr, const int64_t* skip_if,\n    const float16* model_diff, T* adam_diff, T* model, T* m, T* v, T* norm_buffer) {\n  LambUpdateKernelUtil<DeviceType::kCUDA, T, half>::Update(\n      stream, n, scale, l1, l2, beta1, beta2, epsilon, weight_decay, learning_rate_val, lr_scale,\n      do_bias_correction, bias_correction1_val, bias_correction2_val, learning_rate_ptr,\n      bias_correction1_ptr, bias_correction2_ptr, scale_by_ptr, skip_if,\n      reinterpret_cast<const half*>(model_diff), adam_diff, model, m, v, norm_buffer);\n}\n\ntemplate struct LambUpdateKernelUtil<DeviceType::kCUDA, float, float>;\ntemplate struct LambUpdateKernelUtil<DeviceType::kCUDA, double, double>;\ntemplate struct LambUpdateKernelUtil<DeviceType::kCUDA, float, float16>;\n\ntemplate<typename T, typename K, typename IDX>\nstruct IndexedSlicesAdamMdUpdateKernelUtil<DeviceType::kCUDA, T, K, IDX> {\n  static void Update(ep::Stream* stream, float beta1, float beta2, float epsilon,\n                     float weight_decay, bool amsgrad, bool do_bias_correction, float lr,\n                     float lr_scale, int64_t num_instance, int64_t feature_size,\n                     int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,\n                     const float* learning_rate, const float* bias_correction1_ptr,\n                     const float* bias_correction2_ptr, const K* indices, const T* values, T* model,\n                     T* m, T* v, T* max_v);\n};\n\ntemplate<typename T, typename K, typename IDX>\nvoid IndexedSlicesAdamMdUpdateKernelUtil<DeviceType::kCUDA, T, K, IDX>::Update(\n    ep::Stream* stream, float beta1, float beta2, float epsilon, 
float weight_decay, bool amsgrad,\n    bool do_bias_correction, float lr, float lr_scale, int64_t num_instance, int64_t feature_size,\n    int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,\n    const float* learning_rate, const float* bias_correction1_ptr,\n    const float* bias_correction2_ptr, const K* indices, const T* values, T* model, T* m, T* v,\n    T* max_v) {\n  IndexedSlicesAdamUpdateGpu<T, K, IDX>\n      <<<BlocksNum4ThreadsNum(num_instance * feature_size), kCudaThreadsNumPerBlock, 0,\n         stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction, lr, lr_scale,\n          feature_size, lower_bound, upper_bound, num_unique_instance, learning_rate,\n          bias_correction1_ptr, bias_correction2_ptr, indices, values, model, m, v, max_v);\n}\n#define INSTANTIATE_INDEXED_SLICES_ADAM_MODEL_UPDATE_KERNEL_UTIL_CUDA(                     \\\n    val_type_pair, key_type_pair, idx_type_pair)                                           \\\n  template struct IndexedSlicesAdamMdUpdateKernelUtil<                                     \\\n      DeviceType::kCUDA, OF_PP_PAIR_FIRST(val_type_pair), OF_PP_PAIR_FIRST(key_type_pair), \\\n      OF_PP_PAIR_FIRST(idx_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_INDEXED_SLICES_ADAM_MODEL_UPDATE_KERNEL_UTIL_CUDA,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ);\n#undef INSTANTIATE_INDEXED_SLICES_ADAM_MODEL_UPDATE_KERNEL_UTIL_CUDA\n\ntemplate<>\nstruct BiasCorrectionFactorKernelUtil<DeviceType::kCUDA> {\n  static void BiasCorrectionFactorCompute(ep::Stream* stream, float beta, const int64_t* train_step,\n                                          float* out);\n};\n\nvoid BiasCorrectionFactorKernelUtil<DeviceType::kCUDA>::BiasCorrectionFactorCompute(\n    ep::Stream* stream, float beta, const int64_t* train_step, float* out) {\n  BiasCorrectionFactorKernelGpu<<<1, 1, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      beta, train_step, out);\n}\n\nnamespace {\n\ntemplate<typename T, typename G, bool centered>\n__global__ void RmsPropUpdateGpu(int64_t n, T scale, float l1, float l2, T* mean_square,\n                                 T* mean_gradient, float epsilon, float weight_decay,\n                                 float decay_rate, float learning_rate_val, float lr_scale,\n                                 const float* learning_rate, const T* scale_by_ptr,\n                                 const int64_t* skip_if, const G* model_diff, T* model) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  learning_rate_val *= lr_scale;\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    RmsPropUpdateFunctor<T, G, centered>()(model_diff + i, model + i, n, scale, l1, l2,\n                                           mean_square + i,\n                                           (centered ? 
mean_gradient + i : nullptr), epsilon,\n                                           weight_decay, decay_rate, learning_rate_val);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename G>\nstruct RmsPropUpdateKernelUtil<DeviceType::kCUDA, T, G> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, bool centered,\n                     float epsilon, float weight_decay, float decay_rate, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const G* model_diff, T* model, T* mean_square,\n                     T* mean_gradient);\n};\n\ntemplate<typename T, typename G>\nvoid RmsPropUpdateKernelUtil<DeviceType::kCUDA, T, G>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, bool centered, float epsilon,\n    float weight_decay, float decay_rate, float learning_rate_val, float lr_scale,\n    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,\n    T* model, T* mean_square, T* mean_gradient) {\n  if (centered) {\n    RmsPropUpdateGpu<T, G, true><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                                   stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        n, scale, l1, l2, mean_square, mean_gradient, epsilon, weight_decay, decay_rate,\n        learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, model_diff, model);\n  } else {\n    RmsPropUpdateGpu<T, G, false><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                                    stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        n, scale, l1, l2, mean_square, mean_gradient, epsilon, weight_decay, decay_rate,\n        learning_rate_val, lr_scale, learning_rate, scale_by_ptr, skip_if, model_diff, model);\n  }\n}\n\ntemplate<typename T>\nstruct RmsPropUpdateKernelUtil<DeviceType::kCUDA, T, float16> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, bool centered,\n                     float epsilon, float weight_decay, float decay_rate, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const float16* model_diff, T* model, T* mean_square,\n                     T* mean_gradient);\n};\n\ntemplate<typename T>\nvoid RmsPropUpdateKernelUtil<DeviceType::kCUDA, T, float16>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, bool centered, float epsilon,\n    float weight_decay, float decay_rate, float learning_rate_val, float lr_scale,\n    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if,\n    const float16* model_diff, T* model, T* mean_square, T* mean_gradient) {\n  RmsPropUpdateKernelUtil<DeviceType::kCUDA, T, half>::Update(\n      stream, n, scale, l1, l2, centered, epsilon, weight_decay, decay_rate, learning_rate_val,\n      lr_scale, learning_rate, scale_by_ptr, skip_if, reinterpret_cast<const half*>(model_diff),\n      model, mean_square, mean_gradient);\n}\n\ntemplate struct RmsPropUpdateKernelUtil<DeviceType::kCUDA, float, float>;\ntemplate struct RmsPropUpdateKernelUtil<DeviceType::kCUDA, double, double>;\ntemplate struct RmsPropUpdateKernelUtil<DeviceType::kCUDA, float, float16>;\n\nnamespace {\n\ntemplate<typename T, typename G>\n__global__ void LarsScaleModelDiffGpu(int64_t n, T scale, float l1, float l2, const T* scale_by_ptr,\n                                      const 
int64_t* skip_if, const G* model_diff, T* model,\n                                      T* model_diff_tmp) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    model_diff_tmp[i] =\n        CastScaleRegularizeGradientFunctor<T, G>()(model_diff[i], model[i], scale, l1, l2);\n  }\n}\n\ntemplate<typename T>\n__global__ void LarsGetLocalLearningRateGpu(const float* learning_rate, float lr_scale,\n                                            T weight_decay, T epsilon, T lars_coefficient,\n                                            const int64_t* skip_if, T* data_tmp) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  T* model_norm = &data_tmp[0];\n  T* model_diff_norm = &data_tmp[1];\n  T* local_learning_rate = &data_tmp[2];\n  *model_norm = std::sqrt(*model_norm);\n  *model_diff_norm = std::sqrt(*model_diff_norm);\n  T lars = static_cast<T>(1);\n  if (*model_norm > 0 && *model_diff_norm > 0) {\n    lars = lars_coefficient * (*model_norm)\n           / (epsilon + (*model_diff_norm) + weight_decay * (*model_norm));\n  }\n  T lr = *learning_rate;\n  lr *= lr_scale;\n  *local_learning_rate = lr * lars;\n}\n\ntemplate<typename T>\n__global__ void LarsUpdateGpu(int64_t n, float momentum_beta, T* momentum, float weight_decay,\n                              const int64_t* skip_if, T* local_learning_rate, T* model_diff_tmp,\n                              T* model) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    LarsUpdateFunctor<T>()(model_diff_tmp + i, model + i, momentum_beta, momentum + i, weight_decay,\n                           *local_learning_rate);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename G>\nstruct LarsUpdateKernelUtil<DeviceType::kCUDA, T, G> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2,\n                     float momentum_beta, float epsilon, float lars_coefficient, float weight_decay,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const G* model_diff, T* model, T* momentum,\n                     T* data_tmp, T* model_diff_tmp);\n};\n\ntemplate<typename T, typename G>\nvoid LarsUpdateKernelUtil<DeviceType::kCUDA, T, G>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float momentum_beta, float epsilon,\n    float lars_coefficient, float weight_decay, float lr_scale, const float* learning_rate,\n    const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model, T* momentum,\n    T* data_tmp, T* model_diff_tmp) {\n  LarsScaleModelDiffGpu<T, G><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                                stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      n, scale, l1, l2, scale_by_ptr, skip_if, model_diff, model, model_diff_tmp);\n  T* model_norm = data_tmp;\n  T* model_diff_norm = data_tmp + 1;\n  T* local_learning_rate = data_tmp + 2;\n  Memset<DeviceType::kCUDA>(stream, data_tmp, 0, 2 * sizeof(T));\n  SumSquares2<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                   stream->As<ep::CudaStream>()->cuda_stream()>>>(n, model, model_norm,\n                                                                  model_diff_tmp, model_diff_norm);\n  LarsGetLocalLearningRateGpu<T><<<1, 1, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      learning_rate, lr_scale, weight_decay, epsilon, lars_coefficient, 
skip_if, data_tmp);\n  LarsUpdateGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                     stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      n, momentum_beta, momentum, weight_decay, skip_if, local_learning_rate, model_diff_tmp,\n      model);\n}\n\ntemplate<typename T>\nstruct LarsUpdateKernelUtil<DeviceType::kCUDA, T, float16> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2,\n                     float momentum_beta, float epsilon, float lars_coefficient, float weight_decay,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const float16* model_diff, T* model, T* momentum,\n                     T* data_tmp, T* model_diff_tmp);\n};\n\ntemplate<typename T>\nvoid LarsUpdateKernelUtil<DeviceType::kCUDA, T, float16>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float momentum_beta, float epsilon,\n    float lars_coefficient, float weight_decay, float lr_scale, const float* learning_rate,\n    const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, T* model, T* momentum,\n    T* data_tmp, T* model_diff_tmp) {\n  LarsUpdateKernelUtil<DeviceType::kCUDA, T, half>::Update(\n      stream, n, scale, l1, l2, momentum_beta, epsilon, lars_coefficient, weight_decay, lr_scale,\n      learning_rate, scale_by_ptr, skip_if, reinterpret_cast<const half*>(model_diff), model,\n      momentum, data_tmp, model_diff_tmp);\n}\n\ntemplate struct LarsUpdateKernelUtil<DeviceType::kCUDA, float, float>;\ntemplate struct LarsUpdateKernelUtil<DeviceType::kCUDA, double, double>;\ntemplate struct LarsUpdateKernelUtil<DeviceType::kCUDA, float, float16>;\n\ntemplate<typename T, typename G>\n__global__ void FtrlUpdateGpu(int64_t n, T scale, float l1, float l2, float lr_power, float lambda1,\n                              float lambda2, float beta, float weight_decay,\n                              float learning_rate_val, float lr_scale, const float* learning_rate,\n                              const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,\n                              T* model, T* accumulate, T* z) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  learning_rate_val *= lr_scale;\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    FtrlUpdateFunctor<T, G>()(model_diff + i, model + i, accumulate + i, z + i, scale, l1, l2,\n                              lr_power, lambda1, lambda2, beta, weight_decay, learning_rate_val);\n  }\n}\n\ntemplate<typename T, typename G>\nstruct FtrlUpdateKernelUtil<DeviceType::kCUDA, T, G> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_power,\n                     float lambda1, float lambda2, float beta, float weight_decay,\n                     float learning_rate_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,\n                     T* accumulate, T* z);\n};\n\ntemplate<typename T, typename G>\nvoid FtrlUpdateKernelUtil<DeviceType::kCUDA, T, G>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_power, float lambda1,\n    float lambda2, float beta, float weight_decay, float learning_rate_val, float lr_scale,\n    const float* learning_rate, const T* scale_by_ptr, const int64_t* 
skip_if, const G* model_diff,\n    T* model, T* accumulate, T* z) {\n  FtrlUpdateGpu<T, G><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                        stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      n, scale, l1, l2, lr_power, lambda1, lambda2, beta, weight_decay, learning_rate_val, lr_scale,\n      learning_rate, scale_by_ptr, skip_if, model_diff, model, accumulate, z);\n}\n\ntemplate<typename T>\nstruct FtrlUpdateKernelUtil<DeviceType::kCUDA, T, float16> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_power,\n                     float lambda1, float lambda2, float beta, float weight_decay,\n                     float learning_rate_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff,\n                     T* model, T* accumulate, T* z);\n};\n\ntemplate<typename T>\nvoid FtrlUpdateKernelUtil<DeviceType::kCUDA, T, float16>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_power, float lambda1,\n    float lambda2, float beta, float weight_decay, float learning_rate_val, float lr_scale,\n    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if,\n    const float16* model_diff, T* model, T* accumulate, T* z) {\n  FtrlUpdateKernelUtil<DeviceType::kCUDA, T, half>::Update(\n      stream, n, scale, l1, l2, lr_power, lambda1, lambda2, beta, weight_decay, learning_rate_val,\n      lr_scale, learning_rate, scale_by_ptr, skip_if, reinterpret_cast<const half*>(model_diff),\n      model, accumulate, z);\n}\n\ntemplate struct FtrlUpdateKernelUtil<DeviceType::kCUDA, float, float>;\ntemplate struct FtrlUpdateKernelUtil<DeviceType::kCUDA, double, double>;\ntemplate struct FtrlUpdateKernelUtil<DeviceType::kCUDA, float, float16>;\n\ntemplate<typename T, typename G>\n__global__ void AdadeltaUpdateGpu(int64_t n, T scale, float l1, float l2, float rho, float epsilon,\n                                  bool maximize, float weight_decay, float learning_rate_val,\n                                  float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                                  const int64_t* skip_if, const G* model_diff, T* model,\n                                  T* square_avgs, T* acc_deltas) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  learning_rate_val *= lr_scale;\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    AdadeltaUpdateFunctor<T, G>()(model_diff + i, model + i, square_avgs + i, acc_deltas + i, scale,\n                                  l1, l2, rho, epsilon, maximize, weight_decay, learning_rate_val);\n  }\n}\n\ntemplate<typename T, typename G>\nstruct AdadeltaUpdateKernelUtil<DeviceType::kCUDA, T, G> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float rho,\n                     float epsilon, bool maximize, float weight_decay, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const G* model_diff, T* model, T* square_avgs,\n                     T* acc_deltas);\n};\n\ntemplate<typename T, typename G>\nvoid AdadeltaUpdateKernelUtil<DeviceType::kCUDA, T, G>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float rho, float epsilon,\n    bool maximize, float weight_decay, 
float learning_rate_val, float lr_scale,\n    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,\n    T* model, T* square_avgs, T* acc_deltas) {\n  AdadeltaUpdateGpu<T, G><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                            stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      n, scale, l1, l2, rho, epsilon, maximize, weight_decay, learning_rate_val, lr_scale,\n      learning_rate, scale_by_ptr, skip_if, model_diff, model, square_avgs, acc_deltas);\n}\n\ntemplate<typename T>\nstruct AdadeltaUpdateKernelUtil<DeviceType::kCUDA, T, float16> {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float rho,\n                     float epsilon, bool maximize, float weight_decay, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const float16* model_diff, T* model, T* square_avgs,\n                     T* acc_deltas);\n};\n\ntemplate<typename T>\nvoid AdadeltaUpdateKernelUtil<DeviceType::kCUDA, T, float16>::Update(\n    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float rho, float epsilon,\n    bool maximize, float weight_decay, float learning_rate_val, float lr_scale,\n    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if,\n    const float16* model_diff, T* model, T* square_avgs, T* acc_deltas) {\n  AdadeltaUpdateKernelUtil<DeviceType::kCUDA, T, half>::Update(\n      stream, n, scale, l1, l2, rho, epsilon, maximize, weight_decay, learning_rate_val, lr_scale,\n      learning_rate, scale_by_ptr, skip_if, reinterpret_cast<const half*>(model_diff), model,\n      square_avgs, acc_deltas);\n}\n\ntemplate struct AdadeltaUpdateKernelUtil<DeviceType::kCUDA, float, float>;\ntemplate struct AdadeltaUpdateKernelUtil<DeviceType::kCUDA, double, double>;\ntemplate struct AdadeltaUpdateKernelUtil<DeviceType::kCUDA, float, float16>;\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/model_update_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_MODEL_UPDATE_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_MODEL_UPDATE_KERNEL_UTIL_H_\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/user/kernels/math_unary_elementwise_func.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename G>\nstruct CastScaleRegularizeGradientFunctor {\n  OF_DEVICE_FUNC\n  T operator()(G model_diff, T model, T scale, float l1, float l2) const {\n    return static_cast<T>(model_diff) * scale + l1 * ((model >= 0) - (model <= 0)) + l2 * model;\n  }\n};\n\ntemplate<typename T, typename G>\nstruct SGDUpdateFunctor {\n  OF_DEVICE_FUNC\n  void operator()(const G* model_diff, T* model, T scale, float l1, float l2, float weight_decay,\n                  float learning_rate) const {\n    const T model_val = *model;\n    const T model_diff_t =\n        CastScaleRegularizeGradientFunctor<T, G>()(*model_diff, model_val, scale, l1, l2);\n    const T next_model = model_val - learning_rate * (model_diff_t + weight_decay * model_val);\n    *model = next_model;\n  }\n};\n\ntemplate<typename T, typename G, typename C>\nstruct FusedSGDUpdateFunctor {\n  OF_DEVICE_FUNC\n  void operator()(const G* model_diff, T* model, C* model_copy, T scale, float l1, float l2,\n                  float weight_decay, float learning_rate) const {\n    const T model_val = *model;\n    const T model_diff_t =\n        CastScaleRegularizeGradientFunctor<T, G>()(*model_diff, model_val, scale, l1, l2);\n    const T next_model = model_val - learning_rate * (model_diff_t + weight_decay * model_val);\n    *model = next_model;\n    *model_copy = static_cast<C>(next_model);\n  }\n};\n\ntemplate<DeviceType device_type, typename T, typename G, typename C>\nstruct SGDUpdateKernelUtil {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float weight_decay,\n                     float learning_rate_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,\n                     C* model_copy);\n};\n\ntemplate<DeviceType device_type, typename T, typename K, typename IDX>\nstruct IndexedSlicesSGDUpdateKernelUtil final {\n  static void Update(ep::Stream* stream, float weight_decay, float lr_scale, int64_t num_indices,\n                     int64_t feature_size, int64_t lower_bound, int64_t upper_bound,\n                     const IDX* num_unique_instance, const float* learning_rate, const K* indices,\n                     const T* values, T* model);\n};\n\ntemplate<typename T, typename G>\nstruct MomentumUpdateFunctor {\n  OF_DEVICE_FUNC\n  void operator()(const G* model_diff, T* model, T* momentum, T scale, float l1, float l2,\n                  float beta, float dampening, bool nesterov, bool maximize, float weight_decay,\n                  float learning_rate) const {\n    const T 
model_val = *model;\n    T model_diff_t =\n        CastScaleRegularizeGradientFunctor<T, G>()(*model_diff, model_val, scale, l1, l2);\n\n    T next_momentum = beta * *momentum + (1.0f - dampening) * model_diff_t;\n    *momentum = next_momentum;\n\n    if (!nesterov) {\n      model_diff_t = next_momentum;\n    } else {\n      model_diff_t += beta * next_momentum;\n    }\n\n    T alpha = -learning_rate;\n    if (maximize) { alpha = learning_rate; }\n    const T next_model =\n        model_val + alpha * model_diff_t - learning_rate * weight_decay * model_val;\n    *model = next_model;\n  }\n};\n\ntemplate<typename T, typename G>\nstruct AdamUpdateFunctor {\n  OF_DEVICE_FUNC\n  void operator()(const G* model_diff, T* model, T* m, T* v, T* max_v, T scale, float l1, float l2,\n                  float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad,\n                  float bias_correction1, float bias_correction2, float learning_rate) const {\n    const T model_val = *model;\n    T model_diff_t =\n        CastScaleRegularizeGradientFunctor<T, G>()(*model_diff, model_val, scale, l1, l2);\n\n    const T next_m = beta1 * *m + (1 - beta1) * model_diff_t;\n    *m = next_m;\n\n    const T next_v = beta2 * *v + (1 - beta2) * model_diff_t * model_diff_t;\n    *v = next_v;\n\n    T denom = 0;\n    if (amsgrad) {\n      const T next_max_v =\n          *max_v > next_v ? *max_v : next_v;  // using std::max is buggy in GPU kernels.\n      *max_v = next_max_v;\n      denom = (sqrt(next_max_v) / sqrt(bias_correction2)) + epsilon;\n    } else {\n      denom = (sqrt(next_v) / sqrt(bias_correction2)) + epsilon;\n    }\n    const T step_size = learning_rate / bias_correction1;\n    *model = model_val - step_size * (next_m / denom) - learning_rate * weight_decay * model_val;\n  }\n};\n\ntemplate<typename T, typename G, typename C>\nstruct FusedAdamUpdateFunctor {\n  OF_DEVICE_FUNC\n  void operator()(const G* model_diff, T* model, C* model_copy, T* m, T* v, T* max_v, T scale,\n                  float l1, float l2, float beta1, float beta2, float epsilon, float weight_decay,\n                  bool amsgrad, float bias_correction1, float bias_correction2,\n                  float learning_rate) const {\n    const T model_val = *model;\n    T model_diff_t =\n        CastScaleRegularizeGradientFunctor<T, G>()(*model_diff, model_val, scale, l1, l2);\n\n    const T next_m = beta1 * *m + (1 - beta1) * model_diff_t;\n    *m = next_m;\n\n    const T next_v = beta2 * *v + (1 - beta2) * model_diff_t * model_diff_t;\n    *v = next_v;\n\n    T denom = 0;\n    if (amsgrad) {\n      const T next_max_v =\n          *max_v > next_v ? 
*max_v : next_v;  // using std::max is buggy in GPU kernels.\n      *max_v = next_max_v;\n      denom = (sqrt(next_max_v) / sqrt(bias_correction2)) + epsilon;\n    } else {\n      denom = (sqrt(next_v) / sqrt(bias_correction2)) + epsilon;\n    }\n    const T step_size = learning_rate / bias_correction1;\n    const T next_model =\n        model_val - step_size * (next_m / denom) - learning_rate * weight_decay * model_val;\n    *model = next_model;\n    *model_copy = static_cast<C>(next_model);\n  }\n};\n\ntemplate<typename T, typename G>\nstruct AdagradUpdateFunctor {\n  OF_DEVICE_FUNC\n  void operator()(const G* model_diff, T* model, T* sum, T scale, float l1, float l2, float epsilon,\n                  float weight_decay, float learning_rate) {\n    const T model_val = *model;\n    T model_diff_t =\n        CastScaleRegularizeGradientFunctor<T, G>()(*model_diff, model_val, scale, l1, l2);\n    const T next_sum = *sum + model_diff_t * model_diff_t;\n    *sum = next_sum;\n    *model = model_val - learning_rate / (sqrt(next_sum) + epsilon) * model_diff_t\n             - learning_rate * weight_decay * model_val;\n  }\n};\n\ntemplate<typename T, typename G>\nstruct LambGradFunctor {\n  OF_DEVICE_FUNC\n  void operator()(const G* model_diff, T* adam_diff, T* model, T* m, T* v, float scale, float l1,\n                  float l2, float beta1, float beta2, float epsilon, bool do_bias_correction,\n                  float bias_correction1, float bias_correction2) const {\n    const T model_val = *model;\n    T model_diff_t =\n        CastScaleRegularizeGradientFunctor<T, G>()(*model_diff, model_val, scale, l1, l2);\n    const T next_m = beta1 * *m + (1 - beta1) * model_diff_t;\n    const T next_v = beta2 * *v + (1 - beta2) * model_diff_t * model_diff_t;\n    *m = next_m;\n    *v = next_v;\n    T numerator = 0;\n    T denominator = 0;\n    if (do_bias_correction) {\n      numerator = next_m / bias_correction1;\n      denominator = (sqrt(next_v) / sqrt(bias_correction2)) + epsilon;\n    } else {\n      numerator = next_m;\n      denominator = sqrt(next_v) + epsilon;\n    }\n    *adam_diff = numerator / denominator;\n  }\n};\n\ntemplate<typename T>\nstruct LambLRFunctor {\n  OF_DEVICE_FUNC\n  float operator()(const float learning_rate_val, const T* w_norm_2, const T* g_norm_2) const {\n    float lr = learning_rate_val;\n    const T w_norm_val = sqrt(*w_norm_2);\n    const T g_norm_val = sqrt(*g_norm_2);\n    T trust_ratio = 1;\n    if (w_norm_val > 0 && g_norm_val > 0) { trust_ratio = w_norm_val / g_norm_val; }\n    lr *= trust_ratio;\n    return lr;\n  }\n};\n\ntemplate<typename T>\nstruct LambUpdateFunctor {\n  OF_DEVICE_FUNC\n  void operator()(const float learning_rate, const float weight_decay, const T* adam_diff,\n                  T* model) const {\n    const T model_val = *model;\n    *model = model_val - learning_rate * (*adam_diff + weight_decay * model_val);\n  }\n};\n\ntemplate<typename T, typename G>\nstruct FtrlUpdateFunctor {\n  OF_DEVICE_FUNC void operator()(const G* model_diff, T* model, T* accumulate, T* z, T scale,\n                                 float l1, float l2, float lr_power, float lambda1, float lambda2,\n                                 float beta, float weight_decay, float learning_rate) {\n    const T model_val = *model;\n    const T z_val = *z;\n    const float lr_reciprocal = static_cast<float>(1.0) / learning_rate;\n    T model_diff_t =\n        CastScaleRegularizeGradientFunctor<T, G>()(*model_diff, model_val, scale, l1, l2);\n    const T accumulate_val = *accumulate;\n    
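// FTRL(-Proximal) bookkeeping for the step below: accumulate is the running sum of\n    // squared gradients, sigma = (next_accumulate^lr_power - accumulate^lr_power) / learning_rate\n    // tracks the per-step change of the adaptive step size, and z aggregates gradients minus\n    // sigma-weighted weights; the closed-form proximal solution keeps the weight at exactly\n    // zero while |z| < lambda1 (the L1 threshold).\n    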
const T next_accumulate_val = accumulate_val + model_diff_t * model_diff_t;\n    const T acc_powered = pow(accumulate_val, lr_power);\n    const T next_acc_powered = pow(next_accumulate_val, lr_power);\n    const T sigma = (next_acc_powered - acc_powered) * lr_reciprocal;\n    const T new_z_val = z_val + model_diff_t - sigma * model_val;\n    T new_model = static_cast<T>(0.0);\n    if (abs(new_z_val) >= lambda1) {\n      new_model = (copysign(lambda1, new_z_val) - new_z_val)\n                      / ((beta + next_acc_powered) * lr_reciprocal + lambda2)\n                  - learning_rate * weight_decay * model_val;\n    }\n    *model = new_model;\n    *accumulate = next_accumulate_val;\n    *z = new_z_val;\n  }\n};\n\ntemplate<typename T, typename G>\nstruct AdadeltaUpdateFunctor {\n  OF_DEVICE_FUNC void operator()(const G* model_diff, T* model, T* square_avgs, T* acc_deltas,\n                                 T scale, float l1, float l2, float rho, float epsilon,\n                                 bool maximize, float weight_decay, float learning_rate) {\n    const T model_val = *model;\n    T model_diff_val = *model_diff;\n    if (maximize) { model_diff_val = -model_diff_val; }\n    T model_diff_t =\n        CastScaleRegularizeGradientFunctor<T, G>()(model_diff_val, model_val, scale, l1, l2);\n    T square_avgs_val = *square_avgs;\n    T new_square_avgs_val = square_avgs_val * rho + model_diff_t * model_diff_t * (1.0f - rho);\n    T square_avgs_std = sqrt(new_square_avgs_val + epsilon);\n    T acc_delta_val = *acc_deltas;\n    T delta = sqrt(acc_delta_val + epsilon) / square_avgs_std * model_diff_t;\n    T new_acc_deltas = acc_delta_val * rho + delta * delta * (1.0f - rho);\n    T new_model = model_val - learning_rate * delta;\n    *model = new_model;\n    *square_avgs = new_square_avgs_val;\n    *acc_deltas = new_acc_deltas;\n  }\n};\n\ntemplate<DeviceType device_type>\nstruct BiasCorrectionFactorKernelUtil {\n public:\n  static void BiasCorrectionFactorCompute(ep::Stream* stream, float beta, const int64_t* train_step,\n                                          float* out);\n};\n\ntemplate<DeviceType device_type, typename T, typename G>\nstruct MomentumUpdateKernelUtil {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta,\n                     float dampening, bool nesterov, bool maximize, float weight_decay,\n                     float learning_rate_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,\n                     T* momentum);\n};\n\ntemplate<DeviceType device_type, typename T, typename K, typename IDX>\nstruct IndexedSlicesMomentumMdUpdateKernelUtil {\n  static void Update(ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize,\n                     float weight_decay, float lr_scale, int64_t num_instance, int64_t feature_size,\n                     int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,\n                     const float* learning_rate, const K* indices, const T* values, T* model,\n                     T* momentum);\n};\n\ntemplate<DeviceType device_type, typename T, typename G, typename C>\nstruct AdamUpdateKernelUtil {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta1,\n                     float beta2, float epsilon, float weight_decay, bool amsgrad,\n                     bool do_bias_correction, float learning_rate_val, float lr_scale,\n      
               float bias_correction1_val, float bias_correction2_val,\n                     const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if,\n                     const float* bias_correction1, const float* bias_correction2,\n                     const G* model_diff, T* model, C* model_copy, T* m, T* v, T* max_v);\n};\n\ntemplate<DeviceType device_type, typename T, typename G>\nstruct AdagradUpdateKernelUtil {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_decay,\n                     float epsilon, float weight_decay, float learning_rate_val, float lr_scale,\n                     int64_t train_step, const float* learning_rate, const int64_t* train_step_ptr,\n                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,\n                     T* sum);\n};\n\ntemplate<DeviceType device_type, typename T, typename K, typename IDX>\nstruct IndexedSlicesAdamMdUpdateKernelUtil {\n  static void Update(ep::Stream* stream, float beta1, float beta2, float epsilon,\n                     float weight_decay, bool amsgrad, bool do_bias_correction, float lr,\n                     float lr_scale, int64_t num_instance, int64_t feature_size,\n                     int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,\n                     const float* learning_rate, const float* bias_correction1_ptr,\n                     const float* bias_correction2_ptr, const K* indices, const T* values, T* model,\n                     T* m, T* v, T* max_v);\n};\n\ntemplate<DeviceType device_type, typename T, typename G>\nstruct LambUpdateKernelUtil {\n public:\n  static void Update(ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1,\n                     float beta2, float epsilon, float weight_decay, float learning_rate_val,\n                     float lr_scale, bool do_bias_correction, float bias_correction1_val,\n                     float bias_correction2_val, const float* learning_rate_ptr,\n                     const float* bias_correction1_ptr, const float* bias_correction2_ptr,\n                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,\n                     T* adam_diff, T* model, T* m, T* v, T* norm_buffer);\n};\n\ntemplate<DeviceType device_type, typename T, typename G>\nstruct FtrlUpdateKernelUtil {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float lr_power,\n                     float lambda1, float lambda2, float beta, float weight_decay,\n                     float learning_rate_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,\n                     T* accumulate, T* z);\n};\n\ntemplate<DeviceType device_type, typename T, typename G>\nstruct AdadeltaUpdateKernelUtil {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float rho,\n                     float epsilon, bool maximize, float weight_decay, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const G* model_diff, T* model, T* square_avgs,\n                     T* acc_deltas);\n};\n\ntemplate<typename T, typename G, bool centered>\nstruct RmsPropUpdateFunctor {\n  OF_DEVICE_FUNC\n  void operator()(const G* model_diff, T* model, int64_t n, T scale, float l1, float l2,\n                  T* 
mean_square, T* mean_gradient, float epsilon, float weight_decay,\n                  float decay_rate, const float learning_rate) const {\n    const T model_val = *model;\n    T model_diff_t = CastScaleRegularizeGradientFunctor<T, G>()(*model_diff, *model, scale, l1, l2);\n    T mean_square_val = *mean_square;\n    mean_square_val = (1 - decay_rate) * model_diff_t * model_diff_t + decay_rate * mean_square_val;\n    *mean_square = mean_square_val;\n    T denom_t;\n    if (centered) {\n      T mean_gradient_val = *mean_gradient;\n      mean_gradient_val = (1 - decay_rate) * model_diff_t + decay_rate * mean_gradient_val;\n      *mean_gradient = mean_gradient_val;\n      denom_t = mean_square_val - mean_gradient_val * mean_gradient_val;\n    } else {\n      denom_t = *mean_square;\n    }\n    *model = model_val - learning_rate * model_diff_t * RsqrtFunctor<T>::Forward(denom_t + epsilon);\n  }\n};\n\ntemplate<DeviceType device_type, typename T, typename G>\nstruct RmsPropUpdateKernelUtil {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, bool centered,\n                     float epsilon, float weight_decay, float decay_rate, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const G* model_diff, T* model, T* mean_square,\n                     T* mean_gradient);\n};\n\ntemplate<typename T>\nstruct LarsUpdateFunctor {\n  OF_DEVICE_FUNC\n  void operator()(T* model_diff_tmp, T* model, float momentum_beta, T* momentum, float weight_decay,\n                  const T local_learning_rate) const {\n    const T model_val = *model;\n    T next_momentum = *momentum * momentum_beta - local_learning_rate * *model_diff_tmp;\n    *momentum = next_momentum;\n    const T next_model = model_val + next_momentum - local_learning_rate * weight_decay * model_val;\n    *model = next_model;\n  }\n};\n\ntemplate<DeviceType device_type, typename T, typename G>\nstruct LarsUpdateKernelUtil {\n  static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2,\n                     float momentum_beta, float epsilon, float lars_coefficient, float weight_decay,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const G* model_diff, T* model, T* momentum,\n                     T* data_tmp, T* model_diff_tmp);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_MODEL_UPDATE_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/model_update_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/model_update_kernel_util.h\"\n#include \"oneflow/user/kernels/indexed_slices_reduce_sum_kernel_util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass TmpBufferManager final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TmpBufferManager);\n  TmpBufferManager(void* ptr, const int64_t num_indices, const int64_t num_values) : ptr_(ptr) {\n    CHECK_NE(num_indices, 0);\n    CHECK_NE(num_values, 0);\n    const size_t unique_diff_indices_bytes = GetCudaAlignedSize(num_indices * sizeof(K));\n    const size_t unique_diff_values_bytes = GetCudaAlignedSize(num_values * sizeof(T));\n    const size_t num_unique_diff_indices_bytes = GetCudaAlignedSize(1 * sizeof(int32_t));\n    CHECK_EQ(num_values % num_indices, 0);\n    IndexedSlicesReduceSumKernelUtil<device_type, K, T, int64_t>::GetReduceSumWorkspaceSizeInBytes(\n        nullptr, num_indices, num_values / num_indices, &unique_workspace_bytes_);\n    unique_diff_indices_offset_ = 0;\n    unique_diff_values_offset_ = unique_diff_indices_offset_ + unique_diff_indices_bytes;\n    num_unique_diff_indices_offset_ = unique_diff_values_offset_ + unique_diff_values_bytes;\n    unique_workspace_offset_ = num_unique_diff_indices_offset_ + num_unique_diff_indices_bytes;\n    CHECK_GE(unique_workspace_bytes_, 0);\n    total_buffer_size_ = unique_diff_indices_bytes + unique_diff_values_bytes\n                         + num_unique_diff_indices_bytes\n                         + static_cast<size_t>(unique_workspace_bytes_);\n  }\n  ~TmpBufferManager() = default;\n\n  int64_t UniqueWorkspaceBytes() const { return unique_workspace_bytes_; }\n  size_t GetTotalBufferSize() const { return total_buffer_size_; }\n  K* UniqueDiffIndicesPtr() const {\n    CHECK(ptr_ != nullptr);\n    return reinterpret_cast<K*>(reinterpret_cast<char*>(ptr_) + unique_diff_indices_offset_);\n  }\n  T* UniqueDiffValuesPtr() const {\n    CHECK(ptr_ != nullptr);\n    return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr_) + unique_diff_values_offset_);\n  }\n  int32_t* NumUniqueDiffIndicesPtr() const {\n    CHECK(ptr_ != nullptr);\n    return reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr_)\n                                      + num_unique_diff_indices_offset_);\n  }\n  char* UniqueWorkspacePtr() const {\n    CHECK(ptr_ != nullptr);\n    return reinterpret_cast<char*>(ptr_) + unique_workspace_offset_;\n  }\n\n private:\n  size_t unique_diff_indices_offset_;\n  size_t unique_diff_values_offset_;\n  size_t num_unique_diff_indices_offset_;\n  size_t unique_workspace_offset_;\n\n  int64_t unique_workspace_bytes_;\n  size_t total_buffer_size_;\n  void* ptr_;\n};\n\nclass IndexedSlicesUpdateOpKernelCache final : public user_op::OpKernelCache 
{\n public:\n  IndexedSlicesUpdateOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {}\n  ~IndexedSlicesUpdateOpKernelCache() override = default;\n\n  int64_t lower() const { return lower_; }\n  int64_t upper() const { return upper_; }\n\n private:\n  const int64_t lower_;\n  const int64_t upper_;\n};\n\nstd::shared_ptr<user_op::OpKernelCache> CreateIndexedSlicesUpdateOpKernelCache(\n    user_op::KernelCacheContext* ctx) {\n  const SbpParallel& model_sbp = ctx->SbpParallel4ArgNameAndIndex(\"model\", 0);\n  const user_op::TensorDesc* model_logical_desc =\n      ctx->LogicalTensorDesc4ArgNameAndIndex(\"model\", 0);\n  const int64_t num_model_instances = model_logical_desc->shape().At(0);\n  if (model_sbp.has_split_parallel() && model_sbp.split_parallel().axis() == 0\n      && ctx->parallel_ctx().parallel_num() > 1) {\n    CHECK(ctx->SbpParallel4ArgNameAndIndex(\"model_diff_indices\", 0).has_broadcast_parallel());\n    CHECK(ctx->SbpParallel4ArgNameAndIndex(\"model_diff_values\", 0).has_broadcast_parallel());\n    BalancedSplitter bs(num_model_instances, ctx->parallel_ctx().parallel_num());\n    return std::make_shared<IndexedSlicesUpdateOpKernelCache>(\n        bs.At(ctx->parallel_ctx().parallel_id()).begin(),\n        bs.At(ctx->parallel_ctx().parallel_id()).end());\n  } else {\n    return std::make_shared<IndexedSlicesUpdateOpKernelCache>(0, num_model_instances);\n  }\n}\n\ntemplate<DeviceType device_type, typename T, typename G, typename C>\nclass SGDUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  SGDUpdateKernel() = default;\n  ~SGDUpdateKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex(\"model_diff\", 0);\n    user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex(\"model\", 0);\n    const auto scale = ctx->Attr<double>(\"scale\");\n    const auto l1 = ctx->Attr<float>(\"l1\");\n    const auto l2 = ctx->Attr<float>(\"l2\");\n    const auto weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n    const float* learning_rate_ptr = nullptr;\n    C* model_copy_ptr = nullptr;\n    if (ctx->has_input(\"model_copy\", 0)) {\n      user_op::Tensor* model_copy = ctx->Tensor4ArgNameAndIndex(\"model_copy\", 0);\n      model_copy_ptr = model_copy->mut_dptr<C>();\n    }\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->data_type(), model->data_type());\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n    SGDUpdateKernelUtil<device_type, T, G, C>::Update(\n        ctx->stream(), model->shape_view().elem_cnt(), static_cast<T>(scale), l1, l2, 
weight_decay,\n        learning_rate_val, lr_scale, learning_rate_ptr, scale_by_ptr, skip_if_ptr,\n        model_diff->dptr<G>(), model->mut_dptr<T>(), model_copy_ptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_SGD_UPDATE_KERNEL(device, dtype, gtype, ctype)                           \\\n  REGISTER_USER_KERNEL(\"sgd_update\")                                                      \\\n      .SetCreateFn<SGDUpdateKernel<device, dtype, gtype, ctype>>()                        \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                               \\\n                       && (user_op::HobDataType(\"model\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"model_diff\", 0) == GetDataType<gtype>::value));\n\nREGISTER_SGD_UPDATE_KERNEL(DeviceType::kCPU, float, float, float16);\nREGISTER_SGD_UPDATE_KERNEL(DeviceType::kCPU, double, double, float16);\n#ifdef WITH_CUDA\nREGISTER_SGD_UPDATE_KERNEL(DeviceType::kCUDA, float, float16, float16);\nREGISTER_SGD_UPDATE_KERNEL(DeviceType::kCUDA, float, float, float16);\nREGISTER_SGD_UPDATE_KERNEL(DeviceType::kCUDA, double, double, float16);\n#endif  // WITH_CUDA\n\ntemplate<DeviceType device_type, typename T, typename K>\nuser_op::InferTmpSizeFn GenInferTmpSizeFn() {\n  return [](user_op::InferContext* ctx) {\n    const user_op::TensorDesc& indices = ctx->InputTensorDesc(\"model_diff_indices\", 0);\n    const user_op::TensorDesc& values = ctx->InputTensorDesc(\"model_diff_values\", 0);\n    const int64_t num_indices = indices.shape().elem_cnt();\n    const int64_t num_values = values.shape().elem_cnt();\n    TmpBufferManager<device_type, T, K> buffer_manager(nullptr, num_indices, num_values);\n    return buffer_manager.GetTotalBufferSize();\n  };\n}\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass IndexedSlicesSGDUpdateKernel final : public user_op::OpKernel {\n public:\n  IndexedSlicesSGDUpdateKernel() = default;\n  ~IndexedSlicesSGDUpdateKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateIndexedSlicesUpdateOpKernelCache(ctx);\n  }\n\n private:\n  using ReduceSumUtilT = IndexedSlicesReduceSumKernelUtil<device_type, K, T, int32_t>;\n  using MdUpdateUtilT = IndexedSlicesSGDUpdateKernelUtil<device_type, T, K, int32_t>;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n    const user_op::Tensor* model_diff_indices =\n        ctx->Tensor4ArgNameAndIndex(\"model_diff_indices\", 0);\n    const user_op::Tensor* model_diff_values = ctx->Tensor4ArgNameAndIndex(\"model_diff_values\", 0);\n    user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex(\"model\", 0);\n    const auto weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const auto lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n    const int64_t num_indices = model_diff_indices->shape_view().elem_cnt();\n    const int64_t num_values = model_diff_values->shape_view().elem_cnt();\n    if (num_indices == 0) {\n      CHECK_EQ(num_values, 0);\n      return;\n    }\n    CHECK_NE(num_values, 0);\n    CHECK_EQ(num_values % num_indices, 0);\n    const int64_t feature_size = num_values / num_indices;\n    auto* kernel_cache = dynamic_cast<const 
IndexedSlicesUpdateOpKernelCache*>(cache);\n    CHECK_NOTNULL(kernel_cache);\n    CHECK_EQ(model->shape_view().At(0), kernel_cache->upper() - kernel_cache->lower());\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    TmpBufferManager<device_type, T, K> buffer_manager(tmp_buffer->mut_dptr(), num_indices,\n                                                       num_values);\n    CHECK_GE(tmp_buffer->shape_view().elem_cnt(), buffer_manager.GetTotalBufferSize());\n    ReduceSumUtilT::ReduceSum(\n        ctx->stream(), num_indices, feature_size, model_diff_indices->dptr<K>(),\n        model_diff_values->dptr<T>(), buffer_manager.NumUniqueDiffIndicesPtr(),\n        buffer_manager.UniqueDiffIndicesPtr(), buffer_manager.UniqueDiffValuesPtr(),\n        buffer_manager.UniqueWorkspacePtr(), buffer_manager.UniqueWorkspaceBytes());\n    MdUpdateUtilT::Update(ctx->stream(), weight_decay, lr_scale, num_indices, feature_size,\n                          kernel_cache->lower(), kernel_cache->upper(),\n                          buffer_manager.NumUniqueDiffIndicesPtr(), learning_rate->dptr<float>(),\n                          buffer_manager.UniqueDiffIndicesPtr(),\n                          buffer_manager.UniqueDiffValuesPtr(), model->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_INDEXED_SLICES_SGD_UPDATE_KERNEL(device_type_v, data_type_pair,                 \\\n                                                  indices_type_pair)                             \\\n  REGISTER_USER_KERNEL(\"indexed_slices_sgd_update\")                                              \\\n      .SetCreateFn<IndexedSlicesSGDUpdateKernel<device_type_v, OF_PP_PAIR_FIRST(data_type_pair), \\\n                                                OF_PP_PAIR_FIRST(indices_type_pair)>>()          \\\n      .SetIsMatchedHob(                                                                          \\\n          (user_op::HobDeviceType() == device_type_v)                                            \\\n          && (user_op::HobDataType(\"model\", 0) == OF_PP_PAIR_SECOND(data_type_pair))             \\\n          && (user_op::HobDataType(\"model_diff_values\", 0) == OF_PP_PAIR_SECOND(data_type_pair)) \\\n          && (user_op::HobDataType(\"model_diff_indices\", 0)                                      \\\n              == OF_PP_PAIR_SECOND(indices_type_pair)))                                          \\\n      .SetInferTmpSizeFn(GenInferTmpSizeFn<device_type_v, OF_PP_PAIR_FIRST(data_type_pair),      \\\n                                           OF_PP_PAIR_FIRST(indices_type_pair)>());\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_INDEXED_SLICES_SGD_UPDATE_KERNEL, DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\ntemplate<DeviceType device_type, typename T, typename G>\nclass MomentumUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  MomentumUpdateKernel() = default;\n  ~MomentumUpdateKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    double scale = ctx->Attr<double>(\"scale\");\n    float l1 = ctx->Attr<float>(\"l1\");\n    float l2 = ctx->Attr<float>(\"l2\");\n    float beta = ctx->Attr<float>(\"beta\");\n    const float dampening = ctx->Attr<float>(\"dampening\");\n    const bool nesterov = 
ctx->Attr<bool>(\"nesterov\");\n    const bool maximize = ctx->Attr<bool>(\"maximize\");\n    float weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const auto lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n\n    const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex(\"model_diff\", 0);\n    user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex(\"model\", 0);\n    user_op::Tensor* momentum = ctx->Tensor4ArgNameAndIndex(\"momentum\", 0);\n    const float* learning_rate_ptr = nullptr;\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->data_type(), model->data_type());\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n    MomentumUpdateKernelUtil<device_type, T, G>::Update(\n        ctx->stream(), model->shape_view().elem_cnt(), static_cast<T>(scale), l1, l2, beta,\n        dampening, nesterov, maximize, weight_decay, learning_rate_val, lr_scale, learning_rate_ptr,\n        scale_by_ptr, skip_if_ptr, model_diff->dptr<G>(), model->mut_dptr<T>(),\n        momentum->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_MOMENTUM_UPDATE_KERNEL(device, dtype, gtype)                             \\\n  REGISTER_USER_KERNEL(\"momentum_update\")                                                 \\\n      .SetCreateFn<MomentumUpdateKernel<device, dtype, gtype>>()                          \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                               \\\n                       && (user_op::HobDataType(\"model\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"model_diff\", 0) == GetDataType<gtype>::value));\n\nREGISTER_MOMENTUM_UPDATE_KERNEL(DeviceType::kCPU, float, float);\nREGISTER_MOMENTUM_UPDATE_KERNEL(DeviceType::kCPU, double, double);\n#ifdef WITH_CUDA\nREGISTER_MOMENTUM_UPDATE_KERNEL(DeviceType::kCUDA, float, float16);\nREGISTER_MOMENTUM_UPDATE_KERNEL(DeviceType::kCUDA, float, float);\nREGISTER_MOMENTUM_UPDATE_KERNEL(DeviceType::kCUDA, double, double);\n#endif  // WITH_CUDA\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass IndexedSlicesMomentumUpdateKernel final : public user_op::OpKernel {\n public:\n  IndexedSlicesMomentumUpdateKernel() = default;\n  ~IndexedSlicesMomentumUpdateKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateIndexedSlicesUpdateOpKernelCache(ctx);\n  }\n\n private:\n  using ReduceSumUtilT = IndexedSlicesReduceSumKernelUtil<device_type, K, T, int32_t>;\n  using MdUpdateUtilT = IndexedSlicesMomentumMdUpdateKernelUtil<device_type, T, K, int32_t>;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) 
const override {\n    const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n    const user_op::Tensor* model_diff_indices =\n        ctx->Tensor4ArgNameAndIndex(\"model_diff_indices\", 0);\n    const user_op::Tensor* model_diff_values = ctx->Tensor4ArgNameAndIndex(\"model_diff_values\", 0);\n    user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex(\"model\", 0);\n    user_op::Tensor* momentum = ctx->Tensor4ArgNameAndIndex(\"momentum\", 0);\n    const auto beta = ctx->Attr<float>(\"beta\");\n    const float dampening = ctx->Attr<float>(\"dampening\");\n    const bool nesterov = ctx->Attr<bool>(\"nesterov\");\n    const bool maximize = ctx->Attr<bool>(\"maximize\");\n    const auto weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const float lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n    const int64_t num_indices = model_diff_indices->shape_view().elem_cnt();\n    const int64_t num_values = model_diff_values->shape_view().elem_cnt();\n    if (num_indices == 0) {\n      CHECK_EQ(num_values, 0);\n      return;\n    }\n    CHECK_NE(num_values, 0);\n    CHECK_EQ(num_values % num_indices, 0);\n    const int64_t feature_size = num_values / num_indices;\n    CHECK_EQ(feature_size,\n             model_diff_values->shape_view().Count(model_diff_indices->shape_view().NumAxes()));\n    auto* kernel_cache = dynamic_cast<const IndexedSlicesUpdateOpKernelCache*>(cache);\n    CHECK_NOTNULL(kernel_cache);\n    CHECK_EQ(model->shape_view().At(0), kernel_cache->upper() - kernel_cache->lower());\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    TmpBufferManager<device_type, T, K> buffer_manager(tmp_buffer->mut_dptr(), num_indices,\n                                                       num_values);\n    CHECK_GE(tmp_buffer->shape_view().elem_cnt(), buffer_manager.GetTotalBufferSize());\n    ReduceSumUtilT::ReduceSum(\n        ctx->stream(), num_indices, feature_size, model_diff_indices->dptr<K>(),\n        model_diff_values->dptr<T>(), buffer_manager.NumUniqueDiffIndicesPtr(),\n        buffer_manager.UniqueDiffIndicesPtr(), buffer_manager.UniqueDiffValuesPtr(),\n        buffer_manager.UniqueWorkspacePtr(), buffer_manager.UniqueWorkspaceBytes());\n    MdUpdateUtilT::Update(ctx->stream(), beta, dampening, nesterov, maximize, weight_decay,\n                          lr_scale, num_indices, feature_size, kernel_cache->lower(),\n                          kernel_cache->upper(), buffer_manager.NumUniqueDiffIndicesPtr(),\n                          learning_rate->dptr<float>(), buffer_manager.UniqueDiffIndicesPtr(),\n                          buffer_manager.UniqueDiffValuesPtr(), model->mut_dptr<T>(),\n                          momentum->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_INDEXED_SLICES_MOMENTUM_UPDATE_KERNEL(device_type_v, data_type_pair,              \\\n                                                       indices_type_pair)                          \\\n  REGISTER_USER_KERNEL(\"indexed_slices_momentum_update\")                                           \\\n      .SetCreateFn<IndexedSlicesMomentumUpdateKernel<                                              \\\n          device_type_v, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(indices_type_pair)>>() \\\n      .SetIsMatchedHob(                                                                            \\\n          (user_op::HobDeviceType() == device_type_v)                              
                \\\n          && (user_op::HobDataType(\"model\", 0) == OF_PP_PAIR_SECOND(data_type_pair))               \\\n          && (user_op::HobDataType(\"model_diff_values\", 0) == OF_PP_PAIR_SECOND(data_type_pair))   \\\n          && (user_op::HobDataType(\"model_diff_indices\", 0)                                        \\\n              == OF_PP_PAIR_SECOND(indices_type_pair)))                                            \\\n      .SetInferTmpSizeFn(GenInferTmpSizeFn<device_type_v, OF_PP_PAIR_FIRST(data_type_pair),        \\\n                                           OF_PP_PAIR_FIRST(indices_type_pair)>());\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_INDEXED_SLICES_MOMENTUM_UPDATE_KERNEL, DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\ntemplate<DeviceType device_type, typename T, typename G, typename C>\nclass AdamUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  AdamUpdateKernel() = default;\n  ~AdamUpdateKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex(\"model_diff\", 0);\n    user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex(\"model\", 0);\n    user_op::Tensor* m = ctx->Tensor4ArgNameAndIndex(\"m\", 0);\n    user_op::Tensor* v = ctx->Tensor4ArgNameAndIndex(\"v\", 0);\n\n    const auto scale = ctx->Attr<double>(\"scale\");\n    const auto l1 = ctx->Attr<float>(\"l1\");\n    const auto l2 = ctx->Attr<float>(\"l2\");\n    const auto beta1 = ctx->Attr<float>(\"beta1\");\n    const auto beta2 = ctx->Attr<float>(\"beta2\");\n    const auto epsilon = ctx->Attr<float>(\"epsilon\");\n    const auto weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const bool amsgrad = ctx->Attr<bool>(\"amsgrad\");\n    const bool do_bias_correction = ctx->Attr<bool>(\"do_bias_correction\");\n    const float lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n\n    T* max_v_ptr = nullptr;\n    if (amsgrad) {\n      user_op::Tensor* max_v = ctx->Tensor4ArgNameAndIndex(\"max_v\", 0);\n      max_v_ptr = max_v->mut_dptr<T>();\n      CHECK(max_v_ptr != nullptr);\n    }\n\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float* learning_rate_ptr = nullptr;\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n\n    const float bias_correction1_val = ctx->Attr<float>(\"bias_correction1_val\");\n    const float* bias_correction1_ptr = nullptr;\n    if (ctx->has_input(\"bias_correction1\", 0)) {\n      const user_op::Tensor* bias_correction1 = ctx->Tensor4ArgNameAndIndex(\"bias_correction1\", 0);\n      CHECK_EQ(bias_correction1->shape_view().elem_cnt(),\n               1);  // Just for Lazy Optional Input Check.\n      bias_correction1_ptr = bias_correction1->dptr<float>();\n    }\n\n    const float bias_correction2_val = ctx->Attr<float>(\"bias_correction2_val\");\n    const float* bias_correction2_ptr = nullptr;\n    if (ctx->has_input(\"bias_correction2\", 0)) {\n      const user_op::Tensor* bias_correction2 = ctx->Tensor4ArgNameAndIndex(\"bias_correction2\", 0);\n      CHECK_EQ(bias_correction2->shape_view().elem_cnt(),\n               1);  // Just for Lazy Optional Input Check.\n      bias_correction2_ptr = bias_correction2->dptr<float>();\n    }\n\n    const T* scale_by_ptr = 
nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->data_type(), model->data_type());\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n\n    C* model_copy_ptr = nullptr;\n    if (ctx->has_input(\"model_copy\", 0)) {\n      user_op::Tensor* model_copy = ctx->Tensor4ArgNameAndIndex(\"model_copy\", 0);\n      model_copy_ptr = model_copy->mut_dptr<C>();\n    }\n\n    AdamUpdateKernelUtil<device_type, T, G, C>::Update(\n        ctx->stream(), model->shape_view().elem_cnt(), static_cast<T>(scale), l1, l2, beta1, beta2,\n        epsilon, weight_decay, amsgrad, do_bias_correction, learning_rate_val, lr_scale,\n        bias_correction1_val, bias_correction2_val, learning_rate_ptr, scale_by_ptr, skip_if_ptr,\n        bias_correction1_ptr, bias_correction2_ptr, model_diff->dptr<G>(), model->mut_dptr<T>(),\n        model_copy_ptr, m->mut_dptr<T>(), v->mut_dptr<T>(), max_v_ptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_ADAM_UPDATE_KERNEL(device, dtype, gtype, ctype)                          \\\n  REGISTER_USER_KERNEL(\"adam_update\")                                                     \\\n      .SetCreateFn<AdamUpdateKernel<device, dtype, gtype, ctype>>()                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                               \\\n                       && (user_op::HobDataType(\"model\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"model_diff\", 0) == GetDataType<gtype>::value));\n\nREGISTER_ADAM_UPDATE_KERNEL(DeviceType::kCPU, float, float, float16);\nREGISTER_ADAM_UPDATE_KERNEL(DeviceType::kCPU, double, double, float16);\n#ifdef WITH_CUDA\nREGISTER_ADAM_UPDATE_KERNEL(DeviceType::kCUDA, float, float16, float16);\nREGISTER_ADAM_UPDATE_KERNEL(DeviceType::kCUDA, float, float, float16);\nREGISTER_ADAM_UPDATE_KERNEL(DeviceType::kCUDA, double, double, float16);\n#endif  // WITH_CUDA\n\ntemplate<DeviceType device_type, typename T, typename G>\nclass AdagradUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  AdagradUpdateKernel() = default;\n  ~AdagradUpdateKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex(\"model_diff\", 0);\n    user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex(\"model\", 0);\n    user_op::Tensor* sum = ctx->Tensor4ArgNameAndIndex(\"sum\", 0);\n    const auto scale = ctx->Attr<double>(\"scale\");\n    const auto l1 = ctx->Attr<float>(\"l1\");\n    const auto l2 = ctx->Attr<float>(\"l2\");\n    const auto lr_decay = ctx->Attr<float>(\"lr_decay\");\n    const auto epsilon = ctx->Attr<float>(\"epsilon\");\n    const auto weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float* learning_rate_ptr = nullptr;\n    const int64_t train_step_val = ctx->Attr<int32_t>(\"train_step_val\");\n    const int64_t* 
train_step_ptr = nullptr;\n    const float lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n    if (ctx->has_input(\"train_step\", 0)) {\n      const user_op::Tensor* train_step = ctx->Tensor4ArgNameAndIndex(\"train_step\", 0);\n      train_step_ptr = train_step->dptr<int64_t>();\n    }\n\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->data_type(), model->data_type());\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n    AdagradUpdateKernelUtil<device_type, T, G>::Update(\n        ctx->stream(), model->shape_view().elem_cnt(), static_cast<T>(scale), l1, l2, lr_decay,\n        epsilon, weight_decay, learning_rate_val, lr_scale, train_step_val, learning_rate_ptr,\n        train_step_ptr, scale_by_ptr, skip_if_ptr, model_diff->dptr<G>(), model->mut_dptr<T>(),\n        sum->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_ADAGRAD_UPDATE_KERNEL(device, dtype, gtype)                              \\\n  REGISTER_USER_KERNEL(\"adagrad_update\")                                                  \\\n      .SetCreateFn<AdagradUpdateKernel<device, dtype, gtype>>()                           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                               \\\n                       && (user_op::HobDataType(\"model\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"model_diff\", 0) == GetDataType<gtype>::value));\n\nREGISTER_ADAGRAD_UPDATE_KERNEL(DeviceType::kCPU, float, float);\nREGISTER_ADAGRAD_UPDATE_KERNEL(DeviceType::kCPU, double, double);\n#ifdef WITH_CUDA\nREGISTER_ADAGRAD_UPDATE_KERNEL(DeviceType::kCUDA, float, float);\nREGISTER_ADAGRAD_UPDATE_KERNEL(DeviceType::kCUDA, double, double);\n#endif  // WITH_CUDA\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass IndexedSlicesAdamUpdateKernel final : public user_op::OpKernel {\n public:\n  IndexedSlicesAdamUpdateKernel() = default;\n  ~IndexedSlicesAdamUpdateKernel() override = default;\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateIndexedSlicesUpdateOpKernelCache(ctx);\n  }\n\n private:\n  using ReduceSumUtilT = IndexedSlicesReduceSumKernelUtil<device_type, K, T, int32_t>;\n  using MdUpdateUtilT = IndexedSlicesAdamMdUpdateKernelUtil<device_type, T, K, int32_t>;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float* learning_rate_ptr = nullptr;\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      
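// When the optional learning_rate input is present, its device value is used by the\n      // update util in place of the static learning_rate_val attribute.\n      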
learning_rate_ptr = learning_rate->dptr<float>();\n    }\n\n    const float* bias_correction1_ptr = nullptr;\n    if (ctx->has_input(\"bias_correction1\", 0)) {\n      const user_op::Tensor* bias_correction1 = ctx->Tensor4ArgNameAndIndex(\"bias_correction1\", 0);\n      CHECK_EQ(bias_correction1->shape_view().elem_cnt(), 1);\n      bias_correction1_ptr = bias_correction1->dptr<float>();\n    }\n\n    const float* bias_correction2_ptr = nullptr;\n    if (ctx->has_input(\"bias_correction2\", 0)) {\n      const user_op::Tensor* bias_correction2 = ctx->Tensor4ArgNameAndIndex(\"bias_correction2\", 0);\n      CHECK_EQ(bias_correction2->shape_view().elem_cnt(), 1);\n      bias_correction2_ptr = bias_correction2->dptr<float>();\n    }\n\n    const user_op::Tensor* model_diff_indices =\n        ctx->Tensor4ArgNameAndIndex(\"model_diff_indices\", 0);\n    const user_op::Tensor* model_diff_values = ctx->Tensor4ArgNameAndIndex(\"model_diff_values\", 0);\n    user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex(\"model\", 0);\n    user_op::Tensor* m = ctx->Tensor4ArgNameAndIndex(\"m\", 0);\n    user_op::Tensor* v = ctx->Tensor4ArgNameAndIndex(\"v\", 0);\n\n    const auto beta1 = ctx->Attr<float>(\"beta1\");\n    const auto beta2 = ctx->Attr<float>(\"beta2\");\n    const auto epsilon = ctx->Attr<float>(\"epsilon\");\n    const auto weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const bool amsgrad = ctx->Attr<bool>(\"amsgrad\");\n    const bool do_bias_correction = ctx->Attr<bool>(\"do_bias_correction\");\n    const float lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n\n    T* max_v_ptr = nullptr;\n    if (amsgrad) {\n      user_op::Tensor* max_v = ctx->Tensor4ArgNameAndIndex(\"max_v\", 0);\n      max_v_ptr = max_v->mut_dptr<T>();\n    }\n\n    auto* kernel_cache = dynamic_cast<const IndexedSlicesUpdateOpKernelCache*>(cache);\n    CHECK_NOTNULL(kernel_cache);\n    CHECK_EQ(model->shape_view().At(0), kernel_cache->upper() - kernel_cache->lower());\n    const int64_t num_indices = model_diff_indices->shape_view().elem_cnt();\n    const int64_t num_values = model_diff_values->shape_view().elem_cnt();\n    if (num_indices == 0) {\n      CHECK_EQ(num_values, 0);\n      return;\n    }\n    CHECK_NE(num_values, 0);\n    CHECK_EQ(num_values % num_indices, 0);\n    const int64_t feature_size = num_values / num_indices;\n    CHECK_EQ(feature_size,\n             model_diff_values->shape_view().Count(model_diff_indices->shape_view().NumAxes()));\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    TmpBufferManager<device_type, T, K> buffer_manager(tmp_buffer->mut_dptr(), num_indices,\n                                                       num_values);\n    CHECK_GE(tmp_buffer->shape_view().elem_cnt(), buffer_manager.GetTotalBufferSize());\n\n    ReduceSumUtilT::ReduceSum(\n        ctx->stream(), num_indices, feature_size, model_diff_indices->dptr<K>(),\n        model_diff_values->dptr<T>(), buffer_manager.NumUniqueDiffIndicesPtr(),\n        buffer_manager.UniqueDiffIndicesPtr(), buffer_manager.UniqueDiffValuesPtr(),\n        buffer_manager.UniqueWorkspacePtr(), buffer_manager.UniqueWorkspaceBytes());\n\n    MdUpdateUtilT::Update(\n        ctx->stream(), beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction,\n        learning_rate_val, lr_scale, num_indices, feature_size, kernel_cache->lower(),\n        kernel_cache->upper(), buffer_manager.NumUniqueDiffIndicesPtr(), learning_rate_ptr,\n        bias_correction1_ptr, bias_correction2_ptr, 
buffer_manager.UniqueDiffIndicesPtr(),\n        buffer_manager.UniqueDiffValuesPtr(), model->mut_dptr<T>(), m->mut_dptr<T>(),\n        v->mut_dptr<T>(), max_v_ptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_INDEXED_SLICES_ADAM_UPDATE_KERNEL(device_type_v, data_type_pair,                 \\\n                                                   indices_type_pair)                             \\\n  REGISTER_USER_KERNEL(\"indexed_slices_adam_update\")                                              \\\n      .SetCreateFn<IndexedSlicesAdamUpdateKernel<device_type_v, OF_PP_PAIR_FIRST(data_type_pair), \\\n                                                 OF_PP_PAIR_FIRST(indices_type_pair)>>()          \\\n      .SetIsMatchedHob(                                                                           \\\n          (user_op::HobDeviceType() == device_type_v)                                             \\\n          && (user_op::HobDataType(\"model\", 0) == OF_PP_PAIR_SECOND(data_type_pair))              \\\n          && (user_op::HobDataType(\"model_diff_values\", 0) == OF_PP_PAIR_SECOND(data_type_pair))  \\\n          && (user_op::HobDataType(\"model_diff_indices\", 0)                                       \\\n              == OF_PP_PAIR_SECOND(indices_type_pair)))                                           \\\n      .SetInferTmpSizeFn(GenInferTmpSizeFn<device_type_v, OF_PP_PAIR_FIRST(data_type_pair),       \\\n                                           OF_PP_PAIR_FIRST(indices_type_pair)>());\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_INDEXED_SLICES_ADAM_UPDATE_KERNEL, DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\ntemplate<DeviceType device_type, typename T>\nclass LambTmpBufferManager final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LambTmpBufferManager);\n  LambTmpBufferManager(void* ptr, const int64_t n) : ptr_(ptr) {\n    const size_t adam_diff_bytes = GetCudaAlignedSize(n * sizeof(T));\n    norm_buffer_bytes_ = GetCudaAlignedSize(2 * sizeof(T));\n    adam_diff_offset_ = 0;\n    norm_buffer_offset_ = adam_diff_offset_ + adam_diff_bytes;\n    total_buffer_size_ = adam_diff_bytes + norm_buffer_bytes_;\n  }\n  ~LambTmpBufferManager() = default;\n\n  size_t GetNormBufferSize() const { return norm_buffer_bytes_; }\n  size_t GetTotalBufferSize() const { return total_buffer_size_; }\n\n  T* AdamDiffPtr() const {\n    CHECK(ptr_ != nullptr);\n    return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr_) + adam_diff_offset_);\n  }\n  T* NormBufferPtr() const {\n    CHECK(ptr_ != nullptr);\n    return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr_) + norm_buffer_offset_);\n  }\n\n private:\n  size_t adam_diff_offset_;\n  size_t norm_buffer_offset_;\n\n  size_t total_buffer_size_;\n  size_t norm_buffer_bytes_;\n  void* ptr_;\n};\n\ntemplate<DeviceType device_type, typename T, typename G>\nclass LambUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  LambUpdateKernel() = default;\n  ~LambUpdateKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex(\"model_diff\", 0);\n    user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex(\"model\", 0);\n    user_op::Tensor* m = ctx->Tensor4ArgNameAndIndex(\"m\", 0);\n    user_op::Tensor* v = ctx->Tensor4ArgNameAndIndex(\"v\", 0);\n    user_op::Tensor* tmp_buffer = 
ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    LambTmpBufferManager<device_type, T> tbm(tmp_buffer->mut_dptr(),\n                                             model->shape_view().elem_cnt());\n\n    const auto scale = ctx->Attr<double>(\"scale\");\n    const auto l1 = ctx->Attr<float>(\"l1\");\n    const auto l2 = ctx->Attr<float>(\"l2\");\n    const auto beta1 = ctx->Attr<float>(\"beta1\");\n    const auto beta2 = ctx->Attr<float>(\"beta2\");\n    const auto epsilon = ctx->Attr<float>(\"epsilon\");\n    const auto weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const auto lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n\n    const bool do_bias_correction = ctx->Attr<bool>(\"do_bias_correction\");\n    const float bias_correction1_val = ctx->Attr<float>(\"bias_correction1_val\");\n    const float* bias_correction1_ptr = nullptr;\n    if (ctx->has_input(\"bias_correction1\", 0)) {\n      const user_op::Tensor* bias_correction1 = ctx->Tensor4ArgNameAndIndex(\"bias_correction1\", 0);\n      // Just for Lazy optional input check.\n      CHECK_EQ(bias_correction1->shape_view().elem_cnt(), 1);\n      bias_correction1_ptr = bias_correction1->dptr<float>();\n    }\n    const float bias_correction2_val = ctx->Attr<float>(\"bias_correction2_val\");\n    const float* bias_correction2_ptr = nullptr;\n    if (ctx->has_input(\"bias_correction2\", 0)) {\n      const user_op::Tensor* bias_correction2 = ctx->Tensor4ArgNameAndIndex(\"bias_correction2\", 0);\n      CHECK_EQ(bias_correction2->shape_view().elem_cnt(), 1);\n      bias_correction2_ptr = bias_correction2->dptr<float>();\n    }\n\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float* learning_rate_ptr = nullptr;\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->data_type(), model->data_type());\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n\n    LambUpdateKernelUtil<device_type, T, G>::Update(\n        ctx->stream(), m->shape_view().elem_cnt(), scale, l1, l2, beta1, beta2, epsilon,\n        weight_decay, learning_rate_val, lr_scale, do_bias_correction, bias_correction1_val,\n        bias_correction2_val, learning_rate_ptr, bias_correction1_ptr, bias_correction2_ptr,\n        scale_by_ptr, skip_if_ptr, model_diff->dptr<G>(), tbm.AdamDiffPtr(), model->mut_dptr<T>(),\n        m->mut_dptr<T>(), v->mut_dptr<T>(), tbm.NormBufferPtr());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\ntemplate<DeviceType device_type, typename T>\nuser_op::InferTmpSizeFn LambGenInferTmpSizeFn() {\n  return [](user_op::InferContext* ctx) {\n    const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n    LambTmpBufferManager<device_type, T> tbm(nullptr, model.shape().elem_cnt());\n    return tbm.GetTotalBufferSize();\n  
};\n}\n\n#define REGISTER_LAMB_UPDATE_KERNEL(device, dtype, gtype)                                       \\\n  REGISTER_USER_KERNEL(\"lamb_update\")                                                           \\\n      .SetCreateFn<LambUpdateKernel<device, dtype, gtype>>()                                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                     \\\n                       && (user_op::HobDataType(\"model\", 0) == GetDataType<dtype>::value)       \\\n                       && (user_op::HobDataType(\"model_diff\", 0) == GetDataType<gtype>::value)) \\\n      .SetInferTmpSizeFn(LambGenInferTmpSizeFn<device, dtype>());\n\nREGISTER_LAMB_UPDATE_KERNEL(DeviceType::kCPU, float, float);\nREGISTER_LAMB_UPDATE_KERNEL(DeviceType::kCPU, double, double);\n#ifdef WITH_CUDA\nREGISTER_LAMB_UPDATE_KERNEL(DeviceType::kCUDA, float, float16);\nREGISTER_LAMB_UPDATE_KERNEL(DeviceType::kCUDA, float, float);\nREGISTER_LAMB_UPDATE_KERNEL(DeviceType::kCUDA, double, double);\n#endif  // WITH_CUDA\n\ntemplate<DeviceType device_type>\nclass BiasCorrectionFactorKernel final : public user_op::OpKernel,\n                                         public user_op::CudaGraphSupport {\n public:\n  BiasCorrectionFactorKernel() = default;\n  ~BiasCorrectionFactorKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* train_step = ctx->Tensor4ArgNameAndIndex(\"train_step\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const auto beta = ctx->Attr<float>(\"beta\");\n    BiasCorrectionFactorKernelUtil<device_type>::BiasCorrectionFactorCompute(\n        ctx->stream(), beta, train_step->dptr<int64_t>(), out->mut_dptr<float>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_ADAM_BIAS_CORRECTION_FACTOR_KERNEL(device) \\\n  REGISTER_USER_KERNEL(\"adam_bias_correction_factor\")       \\\n      .SetCreateFn<BiasCorrectionFactorKernel<device>>()    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device));\nREGISTER_ADAM_BIAS_CORRECTION_FACTOR_KERNEL(DeviceType::kCPU)\n#ifdef WITH_CUDA\nREGISTER_ADAM_BIAS_CORRECTION_FACTOR_KERNEL(DeviceType::kCUDA)\n#endif  // WITH_CUDA\n\ntemplate<DeviceType device_type, typename T, typename G>\nclass RmsPropUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  RmsPropUpdateKernel() = default;\n  ~RmsPropUpdateKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex(\"model_diff\", 0);\n    user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex(\"model\", 0);\n    user_op::Tensor* mean_square = ctx->Tensor4ArgNameAndIndex(\"mean_square\", 0);\n    const auto scale = ctx->Attr<double>(\"scale\");\n    const auto l1 = ctx->Attr<float>(\"l1\");\n    const auto l2 = ctx->Attr<float>(\"l2\");\n    const auto decay_rate = ctx->Attr<float>(\"decay_rate\");\n    const auto epsilon = ctx->Attr<float>(\"epsilon\");\n    const auto centered = ctx->Attr<bool>(\"centered\");\n    const auto weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n    const float* learning_rate_ptr = nullptr;\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = 
ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->data_type(), model->data_type());\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n    T* mean_gradient_ptr = nullptr;\n    if (centered) {\n      user_op::Tensor* mean_gradient = ctx->Tensor4ArgNameAndIndex(\"mean_gradient\", 0);\n      mean_gradient_ptr = mean_gradient->mut_dptr<T>();\n    }\n    RmsPropUpdateKernelUtil<device_type, T, G>::Update(\n        ctx->stream(), model->shape_view().elem_cnt(), static_cast<T>(scale), l1, l2, centered,\n        epsilon, weight_decay, decay_rate, learning_rate_val, lr_scale, learning_rate_ptr,\n        scale_by_ptr, skip_if_ptr, model_diff->dptr<G>(), model->mut_dptr<T>(),\n        mean_square->mut_dptr<T>(), mean_gradient_ptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_RMSPROP_UPDATE_KERNEL(device, dtype, gtype)                              \\\n  REGISTER_USER_KERNEL(\"rmsprop_update\")                                                  \\\n      .SetCreateFn<RmsPropUpdateKernel<device, dtype, gtype>>()                           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                               \\\n                       && (user_op::HobDataType(\"model\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"model_diff\", 0) == GetDataType<gtype>::value));\n\nREGISTER_RMSPROP_UPDATE_KERNEL(DeviceType::kCPU, float, float);\nREGISTER_RMSPROP_UPDATE_KERNEL(DeviceType::kCPU, double, double);\n#ifdef WITH_CUDA\nREGISTER_RMSPROP_UPDATE_KERNEL(DeviceType::kCUDA, float, float16);\nREGISTER_RMSPROP_UPDATE_KERNEL(DeviceType::kCUDA, float, float);\nREGISTER_RMSPROP_UPDATE_KERNEL(DeviceType::kCUDA, double, double);\n#endif  // WITH_CUDA\n\n// NOTE: carves the LARS tmp_buffer into a cuda-aligned model_diff scratch of n elements\n// of T followed by a 3-element data_tmp region.\ntemplate<DeviceType device_type, typename T>\nclass LarsTmpBufferManager final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(LarsTmpBufferManager);\n  LarsTmpBufferManager(void* ptr, const int64_t n) : ptr_(ptr) {\n    model_diff_size_ = GetCudaAlignedSize(n * sizeof(T));\n    model_diff_offset_ = 0;\n    data_tmp_size_ = GetCudaAlignedSize(3 * sizeof(T));\n    data_tmp_offset_ = model_diff_offset_ + model_diff_size_;\n    total_buffer_size_ = model_diff_size_ + data_tmp_size_;\n  }\n  ~LarsTmpBufferManager() = default;\n\n  size_t GetTotalBufferSize() const { return total_buffer_size_; }\n  size_t GetDataTmpBufferSize() const { return data_tmp_size_; }\n\n  T* ModelDiffPtr() const {\n    CHECK(ptr_ != nullptr);\n    return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr_) + model_diff_offset_);\n  }\n\n  T* DataTmpPtr() const {\n    CHECK(ptr_ != nullptr);\n    return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr_) + data_tmp_offset_);\n  }\n\n private:\n  size_t model_diff_offset_;\n  size_t model_diff_size_;\n  size_t data_tmp_offset_;\n  size_t data_tmp_size_;\n  size_t total_buffer_size_;\n  
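// Base of the op-provided tmp_buffer; left null when the manager is built inside\n  // LarsGenInferTmpSizeFn purely to compute sizes.\n  void* 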
ptr_;\n};\n\ntemplate<DeviceType device_type, typename T, typename G>\nclass LarsUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  LarsUpdateKernel() = default;\n  ~LarsUpdateKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n    const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex(\"model_diff\", 0);\n    user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex(\"model\", 0);\n    user_op::Tensor* momentum = ctx->Tensor4ArgNameAndIndex(\"momentum\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    LarsTmpBufferManager<device_type, T> tlm(tmp_buffer->mut_dptr(),\n                                             model->shape_view().elem_cnt());\n    const auto scale = ctx->Attr<double>(\"scale\");\n    const auto l1 = ctx->Attr<float>(\"l1\");\n    const auto l2 = ctx->Attr<float>(\"l2\");\n    const auto momentum_beta = ctx->Attr<float>(\"momentum_beta\");\n    const auto epsilon = ctx->Attr<float>(\"epsilon\");\n    const auto lars_coefficient = ctx->Attr<float>(\"lars_coefficient\");\n    const auto weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const auto lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->data_type(), model->data_type());\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n    LarsUpdateKernelUtil<device_type, T, G>::Update(\n        ctx->stream(), model->shape_view().elem_cnt(), static_cast<T>(scale), l1, l2, momentum_beta,\n        epsilon, lars_coefficient, weight_decay, lr_scale, learning_rate->dptr<float>(),\n        scale_by_ptr, skip_if_ptr, model_diff->dptr<G>(), model->mut_dptr<T>(),\n        momentum->mut_dptr<T>(), tlm.DataTmpPtr(), tlm.ModelDiffPtr());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\ntemplate<DeviceType device_type, typename T>\nuser_op::InferTmpSizeFn LarsGenInferTmpSizeFn() {\n  return [](user_op::InferContext* ctx) {\n    const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n    LarsTmpBufferManager<device_type, T> tlm(nullptr, model.shape().elem_cnt());\n    return tlm.GetTotalBufferSize();\n  };\n}\n\n#define REGISTER_LARS_UPDATE_KERNEL(device, dtype, gtype)                                       \\\n  REGISTER_USER_KERNEL(\"lars_update\")                                                           \\\n      .SetCreateFn<LarsUpdateKernel<device, dtype, gtype>>()                                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                     \\\n                       && (user_op::HobDataType(\"model\", 0) == GetDataType<dtype>::value)       \\\n                       && (user_op::HobDataType(\"model_diff\", 0) == GetDataType<gtype>::value)) \\\n      .SetInferTmpSizeFn(LarsGenInferTmpSizeFn<device, 
dtype>());\n\nREGISTER_LARS_UPDATE_KERNEL(DeviceType::kCPU, float, float);\nREGISTER_LARS_UPDATE_KERNEL(DeviceType::kCPU, double, double);\n#ifdef WITH_CUDA\nREGISTER_LARS_UPDATE_KERNEL(DeviceType::kCUDA, float, float16);\nREGISTER_LARS_UPDATE_KERNEL(DeviceType::kCUDA, float, float);\nREGISTER_LARS_UPDATE_KERNEL(DeviceType::kCUDA, double, double);\n#endif  // WITH_CUDA\n\ntemplate<DeviceType device_type, typename T, typename G>\nclass FtrlUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  FtrlUpdateKernel() = default;\n  ~FtrlUpdateKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex(\"model_diff\", 0);\n    user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex(\"model\", 0);\n    user_op::Tensor* accumulate = ctx->Tensor4ArgNameAndIndex(\"accumulate\", 0);\n    user_op::Tensor* z = ctx->Tensor4ArgNameAndIndex(\"z\", 0);\n    const auto scale = ctx->Attr<double>(\"scale\");\n    const auto l1 = ctx->Attr<float>(\"l1\");\n    const auto l2 = ctx->Attr<float>(\"l2\");\n    const float lr_power = ctx->Attr<float>(\"lr_power\");\n    const float lambda1 = ctx->Attr<float>(\"lambda1\");\n    const float lambda2 = ctx->Attr<float>(\"lambda2\");\n    const float beta = ctx->Attr<float>(\"beta\");\n    const float weight_decay = ctx->Attr<float>(\"weight_decay\");\n    // TODO(zhengzekang): weight_decay causes undefined behavior in the FTRL optimizer when\n    // `abs(new_z_val) < lambda1`, so it is rejected below.\n    CHECK_EQ(weight_decay, static_cast<float>(0.0))\n        << \"Setting weight_decay is currently not supported.\";\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n    const float* learning_rate_ptr = nullptr;\n\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->data_type(), model->data_type());\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n    FtrlUpdateKernelUtil<device_type, T, G>::Update(\n        ctx->stream(), model->shape_view().elem_cnt(), static_cast<T>(scale), l1, l2, lr_power,\n        lambda1, lambda2, beta, weight_decay, learning_rate_val, lr_scale, learning_rate_ptr,\n        scale_by_ptr, skip_if_ptr, model_diff->dptr<G>(), model->mut_dptr<T>(),\n        accumulate->mut_dptr<T>(), z->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_FTRL_UPDATE_KERNEL(device, dtype, gtype)                                 \\\n  REGISTER_USER_KERNEL(\"ftrl_update\")                                                     \\\n      .SetCreateFn<FtrlUpdateKernel<device, dtype, gtype>>()                              \\\n      
.SetIsMatchedHob((user_op::HobDeviceType() == device)                               \\\n                       && (user_op::HobDataType(\"model\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"model_diff\", 0) == GetDataType<gtype>::value));\n\nREGISTER_FTRL_UPDATE_KERNEL(DeviceType::kCPU, float, float);\nREGISTER_FTRL_UPDATE_KERNEL(DeviceType::kCPU, double, double);\n#ifdef WITH_CUDA\nREGISTER_FTRL_UPDATE_KERNEL(DeviceType::kCUDA, float, float16);\nREGISTER_FTRL_UPDATE_KERNEL(DeviceType::kCUDA, float, float);\nREGISTER_FTRL_UPDATE_KERNEL(DeviceType::kCUDA, double, double);\n#endif  // WITH_CUDA\n\ntemplate<DeviceType device_type, typename T, typename G>\nclass AdadeltaUpdateKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  AdadeltaUpdateKernel() = default;\n  ~AdadeltaUpdateKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex(\"model_diff\", 0);\n    user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex(\"model\", 0);\n    user_op::Tensor* square_avgs = ctx->Tensor4ArgNameAndIndex(\"square_avgs\", 0);\n    user_op::Tensor* acc_deltas = ctx->Tensor4ArgNameAndIndex(\"acc_deltas\", 0);\n    const auto scale = ctx->Attr<double>(\"scale\");\n    const auto l1 = ctx->Attr<float>(\"l1\");\n    const auto l2 = ctx->Attr<float>(\"l2\");\n    const float rho = ctx->Attr<float>(\"rho\");\n    const float epsilon = ctx->Attr<float>(\"epsilon\");\n    const bool maximize = ctx->Attr<bool>(\"maximize\");\n    const float weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n    const float* learning_rate_ptr = nullptr;\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->data_type(), model->data_type());\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n    AdadeltaUpdateKernelUtil<device_type, T, G>::Update(\n        ctx->stream(), model->shape_view().elem_cnt(), static_cast<T>(scale), l1, l2, rho, epsilon,\n        maximize, weight_decay, learning_rate_val, lr_scale, learning_rate_ptr, scale_by_ptr,\n        skip_if_ptr, model_diff->dptr<G>(), model->mut_dptr<T>(), square_avgs->mut_dptr<T>(),\n        acc_deltas->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_ADADELTA_UPDATE_KERNEL(device, dtype, gtype)                             \\\n  REGISTER_USER_KERNEL(\"adadelta_update\")                                                 \\\n      .SetCreateFn<AdadeltaUpdateKernel<device, dtype, gtype>>()                          \\\n      
.SetIsMatchedHob((user_op::HobDeviceType() == device)                               \\\n                       && (user_op::HobDataType(\"model\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"model_diff\", 0) == GetDataType<gtype>::value));\n\nREGISTER_ADADELTA_UPDATE_KERNEL(DeviceType::kCPU, float, float);\nREGISTER_ADADELTA_UPDATE_KERNEL(DeviceType::kCPU, double, double);\n#ifdef WITH_CUDA\nREGISTER_ADADELTA_UPDATE_KERNEL(DeviceType::kCUDA, float, float16);\nREGISTER_ADADELTA_UPDATE_KERNEL(DeviceType::kCUDA, float, float);\nREGISTER_ADADELTA_UPDATE_KERNEL(DeviceType::kCUDA, double, double);\n#endif  // WITH_CUDA\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/moving_average_min_max_observer_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n\n#include <algorithm>\n\nnamespace oneflow {\n\ntemplate<typename T>\nvoid GenQuantScalePerLayerSymmetric(const T* in, const int64_t current_train_step,\n                                    const int64_t stop_update_after_iters, const bool is_training,\n                                    const int32_t quantization_bit, const int64_t num_elements,\n                                    const float momentum, T* moving_max, T* moving_min, T* scale,\n                                    T* zero_point) {\n  if (current_train_step <= stop_update_after_iters && is_training) {\n    T in_max = *std::max_element(in, in + num_elements);\n    T in_min = *std::min_element(in, in + num_elements);\n\n    in_max = std::max(std::abs(in_max), std::abs(in_min));\n\n    T moving_max_val = *moving_max;\n\n    if (moving_max_val == 0) {\n      *moving_max = in_max;\n    } else {\n      *moving_max = moving_max_val * momentum + in_max * (1 - momentum);\n    }\n\n    // NOTE(Liang Depeng): symmetric quantization only use moving_max to calculate the scale\n    *moving_min = *moving_max;\n  }\n\n  T denominator = static_cast<T>(pow(2.0, quantization_bit - 1)) - 1;\n  *scale = (*moving_max) / denominator;\n  *zero_point = 0;\n}\n\ntemplate<typename T>\nvoid GenQuantScalePerLayerAffine(const T* in, const int64_t current_train_step,\n                                 const int64_t stop_update_after_iters, const bool is_training,\n                                 const int32_t quantization_bit, const int64_t num_elements,\n                                 const float momentum, T* moving_max, T* moving_min, T* scale,\n                                 T* zero_point) {\n  if (current_train_step <= stop_update_after_iters && is_training) {\n    T in_max = *std::max_element(in, in + num_elements);\n    T in_min = *std::min_element(in, in + num_elements);\n\n    T moving_max_val = *moving_max;\n    if (moving_max_val == 0) {\n      *moving_max = in_max;\n    } else {\n      *moving_max = moving_max_val * momentum + in_max * (1 - momentum);\n    }\n\n    T moving_min_val = *moving_min;\n    if (moving_min_val == 0) {\n      *moving_min = in_min;\n    } else {\n      *moving_min = moving_min_val * momentum + in_min * (1 - momentum);\n    }\n  }\n\n  T denominator = static_cast<T>(pow(2.0, quantization_bit)) - 1;\n  *scale = ((*moving_max) - (*moving_min)) / denominator;\n  *zero_point = -std::round((*moving_min) / (*scale));\n}\n\ntemplate<typename T>\nvoid GenQuantScalePerLayerCambricon(const T* in, const int64_t current_train_step,\n                                    const int64_t stop_update_after_iters, const bool is_training,\n                                    const int32_t quantization_bit, const int64_t num_elements,\n                                    const float momentum, T* moving_max, T* moving_min, T* scale,\n                                    T* 
zero_point) {\n  if (current_train_step <= stop_update_after_iters && is_training) {\n    T in_max = *std::max_element(in, in + num_elements);\n    T in_min = *std::min_element(in, in + num_elements);\n\n    in_max = std::max(std::abs(in_max), std::abs(in_min));\n\n    T moving_max_val = *moving_max;\n\n    if (moving_max_val == 0) {\n      *moving_max = in_max;\n    } else {\n      *moving_max = moving_max_val * momentum + in_max * (1 - momentum);\n    }\n\n    // NOTE(Liang Depeng): symmetric quantization only use moving_max to calculate the scale\n    *moving_min = *moving_max;\n  }\n\n  *scale = std::floor(std::log2(*moving_max)) - (quantization_bit - 2);\n  *zero_point = 0;\n}\n\ntemplate<typename T>\nclass CpuMovingAverageMinMaxObserverKernel final : public user_op::OpKernel {\n public:\n  CpuMovingAverageMinMaxObserverKernel() = default;\n  ~CpuMovingAverageMinMaxObserverKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* current_train_step =\n        ctx->Tensor4ArgNameAndIndex(\"current_train_step\", 0);\n    user_op::Tensor* moving_max = ctx->Tensor4ArgNameAndIndex(\"moving_max\", 0);\n    user_op::Tensor* moving_min = ctx->Tensor4ArgNameAndIndex(\"moving_min\", 0);\n    user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex(\"scale\", 0);\n    user_op::Tensor* zero_point = ctx->Tensor4ArgNameAndIndex(\"zero_point\", 0);\n\n    const std::string quantization_scheme = ctx->Attr<std::string>(\"quantization_scheme\");\n    const int32_t quantization_bit = ctx->Attr<int32_t>(\"quantization_bit\");\n    const float momentum = ctx->Attr<float>(\"momentum\");\n    const int64_t stop_update_after_iters = ctx->Attr<int64_t>(\"stop_update_after_iters\");\n    const bool is_training = ctx->Attr<bool>(\"training\");\n    const std::string quantization_formula = ctx->Attr<std::string>(\"quantization_formula\");\n\n    const T* in_ptr = in->dptr<T>();\n    const int64_t* current_train_step_ptr = current_train_step->dptr<int64_t>();\n    T* moving_max_ptr = moving_max->mut_dptr<T>();\n    T* moving_min_ptr = moving_min->mut_dptr<T>();\n    T* scale_ptr = scale->mut_dptr<T>();\n    T* zero_point_ptr = zero_point->mut_dptr<T>();\n\n    int64_t num_elements = in->shape_view().elem_cnt();\n\n    if (quantization_formula == \"google\") {\n      if (quantization_scheme == \"symmetric\") {\n        GenQuantScalePerLayerSymmetric(in_ptr, *current_train_step_ptr, stop_update_after_iters,\n                                       is_training, quantization_bit, num_elements, momentum,\n                                       moving_max_ptr, moving_min_ptr, scale_ptr, zero_point_ptr);\n      } else {  // quantization_scheme == \"affine\"\n        GenQuantScalePerLayerAffine(in_ptr, *current_train_step_ptr, stop_update_after_iters,\n                                    is_training, quantization_bit, num_elements, momentum,\n                                    moving_max_ptr, moving_min_ptr, scale_ptr, zero_point_ptr);\n      }\n    } else if (quantization_formula == \"cambricon\") {\n      GenQuantScalePerLayerCambricon(in_ptr, *current_train_step_ptr, stop_update_after_iters,\n                                     is_training, quantization_bit, num_elements, momentum,\n                                     moving_max_ptr, moving_min_ptr, scale_ptr, zero_point_ptr);\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const 
override { return false; }\n};\n\n#define REGISTER_MOVING_AVERAGE_MIN_MAX_OBSERVER_KERNEL(dtype)        \\\n  REGISTER_USER_KERNEL(\"moving_average_min_max_observer\")             \\\n      .SetCreateFn<CpuMovingAverageMinMaxObserverKernel<dtype>>()     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value))\n\nREGISTER_MOVING_AVERAGE_MIN_MAX_OBSERVER_KERNEL(float);\nREGISTER_MOVING_AVERAGE_MIN_MAX_OBSERVER_KERNEL(double);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/moving_average_min_max_observer_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\n#include <float.h>\n\nnamespace oneflow {\n\nnamespace {\n\n// NOTE(Liang Depeng): refer to\n// https://stackoverflow.com/questions/17371275/implementing-max-reduce-in-cuda\ntemplate<typename T>\n__global__ void ReduceMaxMinPerLayer(const T* input_ptr, const int64_t elements, T* max_ptr,\n                                     T* min_ptr) {\n  extern __shared__ unsigned char shared_max_min_memory[];\n  T* shared_max = reinterpret_cast<T*>(shared_max_min_memory);\n  T* shared_min = shared_max + blockDim.x;\n\n  int64_t tid = threadIdx.x;\n  int64_t gid = (blockDim.x * blockIdx.x) + tid;\n  shared_max[tid] = -FLT_MAX;\n  shared_min[tid] = -FLT_MAX;\n\n  while (gid < elements) {\n    shared_max[tid] = max(shared_max[tid], input_ptr[gid]);\n    shared_min[tid] = max(shared_min[tid], -input_ptr[gid]);\n    gid += gridDim.x * blockDim.x;\n  }\n  __syncthreads();\n  gid = (blockDim.x * blockIdx.x) + tid;\n  for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {\n    if (tid < s && gid < elements) {\n      shared_max[tid] = max(shared_max[tid], shared_max[tid + s]);\n      shared_min[tid] = max(shared_min[tid], shared_min[tid + s]);\n    }\n    __syncthreads();\n  }\n\n  if (tid == 0) {\n    cuda::atomic::Max(max_ptr, shared_max[0]);\n    cuda::atomic::Max(min_ptr, shared_min[0]);\n  }\n}\n\ntemplate<typename T>\n__global__ void InitMaxMin(const int64_t elements, T* max_ptr, T* min_ptr) {\n  int64_t tid = threadIdx.x;\n  int64_t gid = (blockDim.x * blockIdx.x) + tid;\n\n  while (gid < elements) {\n    max_ptr[gid] = -FLT_MAX;\n    min_ptr[gid] = -FLT_MAX;\n    gid += gridDim.x * blockDim.x;\n  }\n}\n\ntemplate<typename T>\n__global__ void CalScaleZeroPointSymmetric(const int64_t elements, const double quantization_bit,\n                                           const float momentum, const T* max_ptr, const T* min_ptr,\n                                           T* moving_max_ptr, T* moving_min_ptr, T* scale,\n                                           T* zero_point) {\n  int64_t tid = threadIdx.x;\n  int64_t gid = (blockDim.x * blockIdx.x) + tid;\n\n  while (gid < elements) {\n    T activation_max = max(fabs(max_ptr[gid]), fabs(min_ptr[gid]));\n    T denominator = static_cast<T>(pow(2.0, quantization_bit - 1)) - 1;\n\n    if (moving_max_ptr[gid] == 0)\n      moving_max_ptr[gid] = activation_max;\n    else\n      moving_max_ptr[gid] = moving_max_ptr[gid] * momentum + activation_max * (1 - momentum);\n\n    // NOTE(Liang Depeng): symmetric quantization only use moving_max to calculate the scale\n    moving_min_ptr[gid] = moving_max_ptr[gid];\n\n    scale[gid] = moving_max_ptr[gid] / denominator;\n    zero_point[gid] = 0;\n    gid += gridDim.x * blockDim.x;\n  }\n}\n\ntemplate<typename T>\n__global__ 
void CalFreezeScaleZeroPointSymmetric(const int64_t elements,\n                                                 const double quantization_bit,\n                                                 const float momentum, const T* moving_max_ptr,\n                                                 T* scale, T* zero_point) {\n  int64_t tid = threadIdx.x;\n  int64_t gid = (blockDim.x * blockIdx.x) + tid;\n\n  while (gid < elements) {\n    T denominator = static_cast<T>(pow(2.0, quantization_bit - 1)) - 1;\n    scale[gid] = moving_max_ptr[gid] / denominator;\n    zero_point[gid] = 0;\n    gid += gridDim.x * blockDim.x;\n  }\n}\n\ntemplate<typename T>\n__global__ void CalScaleZeroPointAffine(const int64_t elements, const double quantization_bit,\n                                        const float momentum, const T* max_ptr, const T* min_ptr,\n                                        T* moving_max_ptr, T* moving_min_ptr, T* scale,\n                                        T* zero_point) {\n  int64_t tid = threadIdx.x;\n  int64_t gid = (blockDim.x * blockIdx.x) + tid;\n\n  while (gid < elements) {\n    T denominator = static_cast<T>(pow(2.0, quantization_bit)) - 1;\n\n    if (moving_max_ptr[gid] == 0)\n      moving_max_ptr[gid] = max_ptr[gid];\n    else\n      moving_max_ptr[gid] = moving_max_ptr[gid] * momentum + max_ptr[gid] * (1 - momentum);\n\n    if (moving_min_ptr[gid] == 0)\n      moving_min_ptr[gid] = -min_ptr[gid];\n    else\n      moving_min_ptr[gid] = moving_min_ptr[gid] * momentum + -min_ptr[gid] * (1 - momentum);\n\n    T min = moving_min_ptr[gid];\n    T s = (moving_max_ptr[gid] - min) / denominator;\n\n    scale[gid] = s;\n    zero_point[gid] = -round(min / s);\n    gid += gridDim.x * blockDim.x;\n  }\n}\n\ntemplate<typename T>\n__global__ void CalFreezeScaleZeroPointAffine(const int64_t elements, const double quantization_bit,\n                                              const float momentum, const T* moving_max_ptr,\n                                              const T* moving_min_ptr, T* scale, T* zero_point) {\n  int64_t tid = threadIdx.x;\n  int64_t gid = (blockDim.x * blockIdx.x) + tid;\n\n  while (gid < elements) {\n    T denominator = static_cast<T>(pow(2.0, quantization_bit)) - 1;\n\n    T min = moving_min_ptr[gid];\n    T s = (moving_max_ptr[gid] - min) / denominator;\n\n    scale[gid] = s;\n    zero_point[gid] = -round(min / s);\n    gid += gridDim.x * blockDim.x;\n  }\n}\n\ntemplate<typename T>\n__global__ void CalScaleZeroPointCambricon(const int64_t elements, const double quantization_bit,\n                                           const float momentum, const T* max_ptr, const T* min_ptr,\n                                           T* moving_max_ptr, T* moving_min_ptr, T* scale,\n                                           T* zero_point) {\n  int64_t tid = threadIdx.x;\n  int64_t gid = (blockDim.x * blockIdx.x) + tid;\n\n  while (gid < elements) {\n    T activation_max = max(fabs(max_ptr[gid]), fabs(min_ptr[gid]));\n\n    if (moving_max_ptr[gid] == 0)\n      moving_max_ptr[gid] = activation_max;\n    else\n      moving_max_ptr[gid] = moving_max_ptr[gid] * momentum + activation_max * (1 - momentum);\n\n    // NOTE(Liang Depeng): cambricon quantization only use moving_max to calculate the scale\n    moving_min_ptr[gid] = moving_max_ptr[gid];\n\n    scale[gid] = floor(log2(moving_max_ptr[gid])) - (quantization_bit - 2);\n    zero_point[gid] = 0;\n    gid += gridDim.x * blockDim.x;\n  }\n}\n\ntemplate<typename T>\n__global__ void CalFreezeScaleZeroPointCambricon(const int64_t 
elements,\n                                                 const double quantization_bit,\n                                                 const float momentum, const T* moving_max_ptr,\n                                                 T* scale, T* zero_point) {\n  int64_t tid = threadIdx.x;\n  int64_t gid = (blockDim.x * blockIdx.x) + tid;\n\n  while (gid < elements) {\n    scale[gid] = floor(log2(moving_max_ptr[gid])) - (quantization_bit - 2);\n    zero_point[gid] = 0;\n    gid += gridDim.x * blockDim.x;\n  }\n}\n\nep::CudaLaunchConfig GetLaunchConfig(ep::CudaStream* stream, size_t thread_num,\n                                     size_t shared_mem_size) {\n  ep::CudaLaunchConfig config;\n  stream->InitLaunchConfigWithWaves(&config, thread_num, kCudaThreadsNumPerBlock, 1);\n  config.shared_mem_size = shared_mem_size;\n  return config;\n}\n\n}  // namespace\n\n#define LAUNCH_CUDA_KERNEL(func, stream, thread_num, shared_mem_size, ...) \\\n  (stream)->LaunchKernel(func, GetLaunchConfig((stream), thread_num, shared_mem_size), __VA_ARGS__);\n\ntemplate<typename T>\nclass GpuMovingAverageMinMaxObserverKernel final : public user_op::OpKernel {\n public:\n  GpuMovingAverageMinMaxObserverKernel() = default;\n  ~GpuMovingAverageMinMaxObserverKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* current_train_step =\n        ctx->Tensor4ArgNameAndIndex(\"current_train_step\", 0);\n    user_op::Tensor* moving_max = ctx->Tensor4ArgNameAndIndex(\"moving_max\", 0);\n    user_op::Tensor* moving_min = ctx->Tensor4ArgNameAndIndex(\"moving_min\", 0);\n    user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex(\"scale\", 0);\n    user_op::Tensor* zero_point = ctx->Tensor4ArgNameAndIndex(\"zero_point\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    const bool is_training = ctx->Attr<bool>(\"training\");\n    const int64_t stop_update_after_iters = ctx->Attr<int64_t>(\"stop_update_after_iters\");\n    const std::string quantization_scheme = ctx->Attr<std::string>(\"quantization_scheme\");\n    const int32_t quantization_bit = ctx->Attr<int32_t>(\"quantization_bit\");\n    const float momentum = ctx->Attr<float>(\"momentum\");\n    const std::string quantization_formula = ctx->Attr<std::string>(\"quantization_formula\");\n\n    int64_t elements = in->shape_view().elem_cnt();\n    T* max_ptr = tmp_buffer->mut_dptr<T>();\n    T* min_ptr = max_ptr + 1;\n\n    // current_train_step has to be inspected on the host to decide which kernels to launch.\n    int64_t* host_current_train_step_ptr = new int64_t[current_train_step->shape_view().elem_cnt()];\n    OF_CUDA_CHECK(cudaMemcpy(host_current_train_step_ptr, current_train_step->dptr<int64_t>(),\n                             current_train_step->shape_view().elem_cnt() * sizeof(int64_t),\n                             cudaMemcpyDefault));\n    auto* cuda_stream = ctx->stream()->As<ep::CudaStream>();\n    if (*host_current_train_step_ptr <= stop_update_after_iters && is_training) {\n      LAUNCH_CUDA_KERNEL((InitMaxMin<T>), cuda_stream, 1, 0, 1, max_ptr, min_ptr);\n      LAUNCH_CUDA_KERNEL((ReduceMaxMinPerLayer<T>), cuda_stream, elements,\n                         kCudaThreadsNumPerBlock * 2 * sizeof(T), in->dptr<T>(), elements, max_ptr,\n                         min_ptr);\n    }\n    // Once past stop_update_after_iters (or when not training), the moving range is frozen\n    // and only scale/zero_point are recomputed by the *Freeze* kernels.\n    bool moving = (*host_current_train_step_ptr <= stop_update_after_iters) && 
is_training;\n    if (quantization_formula == \"google\") {\n      if (quantization_scheme == \"symmetric\") {\n        if (moving) {\n          LAUNCH_CUDA_KERNEL((CalScaleZeroPointSymmetric<T>), cuda_stream, 1, 0, 1,\n                             static_cast<double>(quantization_bit), momentum, max_ptr, min_ptr,\n                             moving_max->mut_dptr<T>(), moving_min->mut_dptr<T>(),\n                             scale->mut_dptr<T>(), zero_point->mut_dptr<T>());\n        } else {\n          LAUNCH_CUDA_KERNEL((CalFreezeScaleZeroPointSymmetric<T>), cuda_stream, 1, 0, 1,\n                             static_cast<double>(quantization_bit), momentum, moving_max->dptr<T>(),\n                             scale->mut_dptr<T>(), zero_point->mut_dptr<T>());\n        }\n      } else {  // quantization_scheme == \"affine\"\n        if (moving) {\n          LAUNCH_CUDA_KERNEL((CalScaleZeroPointAffine<T>), cuda_stream, 1, 0, 1,\n                             static_cast<double>(quantization_bit), momentum, max_ptr, min_ptr,\n                             moving_max->mut_dptr<T>(), moving_min->mut_dptr<T>(),\n                             scale->mut_dptr<T>(), zero_point->mut_dptr<T>());\n        } else {\n          LAUNCH_CUDA_KERNEL((CalFreezeScaleZeroPointAffine<T>), cuda_stream, 1, 0, 1,\n                             static_cast<double>(quantization_bit), momentum, moving_max->dptr<T>(),\n                             moving_min->dptr<T>(), scale->mut_dptr<T>(),\n                             zero_point->mut_dptr<T>());\n        }\n      }\n    } else if (quantization_formula == \"cambricon\") {\n      if (moving) {\n        LAUNCH_CUDA_KERNEL((CalScaleZeroPointCambricon<T>), cuda_stream, 1, 0, 1,\n                           static_cast<double>(quantization_bit), momentum, max_ptr, min_ptr,\n                           moving_max->mut_dptr<T>(), moving_min->mut_dptr<T>(),\n                           scale->mut_dptr<T>(), zero_point->mut_dptr<T>());\n      } else {\n        LAUNCH_CUDA_KERNEL((CalFreezeScaleZeroPointCambricon<T>), cuda_stream, 1, 0, 1,\n                           static_cast<double>(quantization_bit), momentum, moving_max->dptr<T>(),\n                           scale->mut_dptr<T>(), zero_point->mut_dptr<T>());\n      }\n    } else {\n      UNIMPLEMENTED();\n    }\n\n    delete[] host_current_train_step_ptr;\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_MOVING_AVERAGE_MIN_MAX_OBSERVER_KERNEL(dtype)                          \\\n  REGISTER_USER_KERNEL(\"moving_average_min_max_observer\")                               \\\n      .SetCreateFn<GpuMovingAverageMinMaxObserverKernel<dtype>>()                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { return 2 * sizeof(dtype); })\n\nREGISTER_MOVING_AVERAGE_MIN_MAX_OBSERVER_KERNEL(float);\nREGISTER_MOVING_AVERAGE_MIN_MAX_OBSERVER_KERNEL(double);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/multi_reduce_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_MULTI_REDUCE_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_MULTI_REDUCE_KERNEL_UTIL_H_\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/device_type.h\"\n#include \"oneflow/core/common/device_type.pb.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstruct MultiReduceParam {\n  const T* data;\n  size_t size;\n};\n\ntemplate<DeviceType device_type, typename T, typename TransformFn, typename ReduceFn>\nstruct MultiReduce {\n  void operator()(ep::Stream* stream, TransformFn transform,\n                  const std::vector<MultiReduceParam<T>>& params, T init, T* ret, T* temp);\n};\n\ntemplate<typename T, typename TransformFn, typename ReduceFn>\nstruct MultiReduce<DeviceType::kCPU, T, TransformFn, ReduceFn> {\n  void operator()(ep::Stream* stream, TransformFn transform,\n                  const std::vector<MultiReduceParam<T>>& params, T init, T* ret, T* temp) {\n    *ret = init;\n    ReduceFn reduce{};\n    FOR_RANGE(size_t, i, 0, params.size()) {\n      const auto& p = params[i];\n      FOR_RANGE(size_t, j, 0, p.size) { *ret = reduce(*ret, transform(p.data[j])); }\n    }\n  }\n};\n\ntemplate<typename T>\nstruct BinaryAdd {\n  OF_DEVICE_FUNC T operator()(const T& x, const T& y) const { return x + y; }\n};\n\ntemplate<typename T>\nstruct BinaryMax {\n  OF_DEVICE_FUNC T operator()(const T& x, const T& y) const { return x > y ? x : y; }\n};\n\ntemplate<typename T>\nstruct BinaryMin {\n  OF_DEVICE_FUNC T operator()(const T& x, const T& y) const { return x < y ? x : y; }\n};\n\ntemplate<typename T>\nstruct Abs {\n  OF_DEVICE_FUNC T operator()(const T& x) const { return x < GetZeroVal<T>() ? -x : x; }\n};\n\ntemplate<typename T>\nstruct PowByZero {\n  OF_DEVICE_FUNC T operator()(const T& x) const {\n    return x != GetZeroVal<T>() ? GetOneVal<T>() : x;\n  }\n};\n\ntemplate<typename T>\nstruct Square {\n  OF_DEVICE_FUNC T operator()(const T& x) const { return x * x; }\n};\n\ntemplate<typename T>\nstruct AbsPow {\n  explicit AbsPow(const T& base) : base_(base) {}\n\n  OF_DEVICE_FUNC T operator()(const T& x) {\n    T abs_x = x < GetZeroVal<T>() ? -x : x;\n#if defined(__CUDA_ARCH__)\n    return pow(abs_x, base_);\n#else\n    return std::pow(abs_x, base_);\n#endif\n  }\n\n private:\n  T base_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_MULTI_REDUCE_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/multi_reduce_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/multi_reduce_kernels.h\"\n\nnamespace oneflow {\n\n#define REGISTER_MULTI_REDUCE_SUM_POW_ABS_CPU_KERNEL(dtype)               \\\n  REGISTER_USER_KERNEL(\"multi_reduce_sum_pow_abs\")                        \\\n      .SetCreateFn<MultiReduceSumPowAbsKernel<DeviceType::kCPU, dtype>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)     \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\n#define REGISTER_MULTI_REDUCE_XIMUM_ABS_CPU_KERNEL(op_type_name, ximum_enum, dtype)  \\\n  REGISTER_USER_KERNEL(op_type_name)                                                 \\\n      .SetCreateFn<MultiReduceXimumAbsKernel<DeviceType::kCPU, dtype, ximum_enum>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\n#define REGISTER_MULTI_REDUCE_XIMUM_ABS_CPU_KERNELS(dtype)                                     \\\n  REGISTER_MULTI_REDUCE_XIMUM_ABS_CPU_KERNEL(\"multi_reduce_max_abs\", Ximum::kMax, dtype)       \\\n  REGISTER_MULTI_REDUCE_XIMUM_ABS_CPU_KERNEL(\"multi_reduce_min_abs\", Ximum::kMin, dtype)       \\\n  REGISTER_MULTI_REDUCE_XIMUM_ABS_CPU_KERNEL(\"local_multi_reduce_max_abs\", Ximum::kMax, dtype) \\\n  REGISTER_MULTI_REDUCE_XIMUM_ABS_CPU_KERNEL(\"local_multi_reduce_min_abs\", Ximum::kMin, dtype)\n\nREGISTER_MULTI_REDUCE_SUM_POW_ABS_CPU_KERNEL(float)\nREGISTER_MULTI_REDUCE_SUM_POW_ABS_CPU_KERNEL(double)\n\nREGISTER_MULTI_REDUCE_XIMUM_ABS_CPU_KERNELS(float)\nREGISTER_MULTI_REDUCE_XIMUM_ABS_CPU_KERNELS(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/multi_reduce_kernels.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/multi_reduce_kernels.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include <cub/cub.cuh>\n#include <limits>\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr int64_t kMultiReduceMaxPackSize = 64;\n\ntemplate<typename T>\nstruct MultiReduceParamsPack {\n  MultiReduceParam<T> params[kMultiReduceMaxPackSize];\n  size_t size;\n};\n\ntemplate<typename T, typename TransformFn, typename ReduceFn>\n__global__ void MultiBlockReduceGpu(TransformFn transform,\n                                    const MultiReduceParamsPack<T> pack_params, const T init,\n                                    T* out) {\n  ReduceFn reduce_fn{};\n  T t_out = init;\n  for (int i = 0; i < pack_params.size; ++i) {\n    const auto& param = pack_params.params[i];\n    CUDA_1D_KERNEL_LOOP(j, param.size) { t_out = reduce_fn(t_out, transform(param.data[j])); }\n  }\n  typedef cub::BlockReduce<T, kCudaThreadsNumPerBlock> BlockReduce;\n  __shared__ typename BlockReduce::TempStorage temp_storage;\n  T b_out = BlockReduce(temp_storage).Reduce(t_out, reduce_fn);\n  if (threadIdx.x == 0) { out[blockIdx.x] = b_out; }\n}\n\nsize_t InferTempStorageSize(user_op::InferContext* ctx) {\n  auto input_size = ctx->input_size(\"x\");\n  if (input_size == 0) { return 0; }\n  int64_t max_elem_cnt = 0;\n  int64_t pack_size = 0;\n  int32_t num_blocks = 0;\n  for (size_t i = 0; i < input_size; ++i) {\n    int64_t elem_cnt = ctx->InputShape(\"x\", i).elem_cnt();\n    max_elem_cnt = std::max(max_elem_cnt, elem_cnt);\n    pack_size++;\n    if (pack_size == kMultiReduceMaxPackSize || i == input_size - 1) {\n      CHECK_LT(max_elem_cnt, std::numeric_limits<int32_t>::max());\n      num_blocks += BlocksNum4ThreadsNum(static_cast<int32_t>(max_elem_cnt));\n      max_elem_cnt = 0;\n      pack_size = 0;\n    }\n  }\n  CHECK_LT(num_blocks, kCudaThreadsNumPerBlock * kCudaThreadsNumPerBlock * kCudaThreadsNumPerBlock)\n      << \"Too much blocks needed for computing \" << ctx->op_name() << \", should be less than \"\n      << kCudaThreadsNumPerBlock << \"*\" << kCudaThreadsNumPerBlock << \"*\" << kCudaThreadsNumPerBlock\n      << \", but got \" << num_blocks;\n  size_t elem_size = GetSizeOfDataType(ctx->InputDType(\"x\", 0));\n  return GetCudaAlignedSize(num_blocks * elem_size * 2);\n}\n\n}  // namespace\n\ntemplate<typename T, typename TransformFn, typename ReduceFn>\nstruct MultiReduce<DeviceType::kCUDA, T, TransformFn, ReduceFn> {\n  void operator()(ep::Stream* stream, TransformFn transform,\n                  const std::vector<MultiReduceParam<T>>& params, T init, T* ret, T* temp) {\n    CHECK_NOTNULL(temp);\n    int32_t total_num_blocks = 0;\n    for (size_t i = 0; i < params.size(); i += kMultiReduceMaxPackSize) {\n      MultiReduceParamsPack<T> pack_params{};\n      size_t max_elem_cnt = 0;\n      pack_params.size 
= std::min<size_t>(kMultiReduceMaxPackSize, params.size() - i);\n      for (size_t j = 0; j < pack_params.size; ++j) {\n        pack_params.params[j] = params[i + j];\n        max_elem_cnt = std::max<size_t>(max_elem_cnt, pack_params.params[j].size);\n      }\n      int32_t num_blocks = BlocksNum4ThreadsNum(max_elem_cnt);\n      MultiBlockReduceGpu<T, TransformFn, ReduceFn>\n          <<<num_blocks, kCudaThreadsNumPerBlock, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              transform, pack_params, init, temp + total_num_blocks);\n      total_num_blocks += num_blocks;\n    }\n    size_t wksp_size = 0;\n    auto DeviceReduce = [&](void* temp_storage) -> void {\n      OF_CUDA_CHECK(cub::DeviceReduce::Reduce(temp_storage, wksp_size, temp, ret, total_num_blocks,\n                                              ReduceFn{}, init,\n                                              stream->As<ep::CudaStream>()->cuda_stream()));\n    };\n    DeviceReduce(nullptr);\n    // NOTE(zwx): We have allocated the temp storage with the space\n    //  that can hold all the elements to reduce,\n    //  normally the `temp_storage_bytes` for cub::DeviceReduce shouldn't exceed it.\n    CHECK_LE(wksp_size, total_num_blocks * sizeof(T))\n        << wksp_size << \" size in bytes of temp storage is needed for doing cub::DeviceReduce, \"\n        << \"but only allocated \" << total_num_blocks * sizeof(T);\n    DeviceReduce(temp + total_num_blocks);\n  }\n};\n\n#define REGISTER_MULTI_REDUCE_SUM_POW_ABS_CUDA_KERNEL(dtype)                           \\\n  REGISTER_USER_KERNEL(\"multi_reduce_sum_pow_abs\")                                     \\\n      .SetCreateFn<MultiReduceSumPowAbsKernel<DeviceType::kCUDA, dtype>>()             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                 \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn(InferTempStorageSize);\n\n#define REGISTER_MULTI_REDUCE_XIMUM_ABS_CUDA_KERNEL(op_type_name, ximum_enum, dtype)   \\\n  REGISTER_USER_KERNEL(op_type_name)                                                   \\\n      .SetCreateFn<MultiReduceXimumAbsKernel<DeviceType::kCUDA, dtype, ximum_enum>>()  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                 \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn(InferTempStorageSize);\n\n#define REGISTER_MULTI_REDUCE_XIMUM_ABS_CUDA_KERNELS(dtype)                                     \\\n  REGISTER_MULTI_REDUCE_XIMUM_ABS_CUDA_KERNEL(\"multi_reduce_max_abs\", Ximum::kMax, dtype)       \\\n  REGISTER_MULTI_REDUCE_XIMUM_ABS_CUDA_KERNEL(\"multi_reduce_min_abs\", Ximum::kMin, dtype)       \\\n  REGISTER_MULTI_REDUCE_XIMUM_ABS_CUDA_KERNEL(\"local_multi_reduce_max_abs\", Ximum::kMax, dtype) \\\n  REGISTER_MULTI_REDUCE_XIMUM_ABS_CUDA_KERNEL(\"local_multi_reduce_min_abs\", Ximum::kMin, dtype)\n\nREGISTER_MULTI_REDUCE_SUM_POW_ABS_CUDA_KERNEL(float)\nREGISTER_MULTI_REDUCE_SUM_POW_ABS_CUDA_KERNEL(double)\n\nREGISTER_MULTI_REDUCE_XIMUM_ABS_CUDA_KERNELS(float)\nREGISTER_MULTI_REDUCE_XIMUM_ABS_CUDA_KERNELS(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/multi_reduce_kernels.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_MULTI_REDUCE_KERNELS_H_\n#define ONEFLOW_USER_KERNELS_MULTI_REDUCE_KERNELS_H_\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/user/kernels/multi_reduce_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nclass MultiReduceSumPowAbsKernel final : public user_op::OpKernel,\n                                         public user_op::CudaGraphSupport {\n public:\n  MultiReduceSumPowAbsKernel() = default;\n  ~MultiReduceSumPowAbsKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache*) const override {\n    std::vector<MultiReduceParam<T>> params;\n    params.resize(ctx->input_size(\"x\"));\n    for (size_t i = 0; i < params.size(); ++i) {\n      const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", i);\n      params[i].size = x->shape_view().elem_cnt();\n      params[i].data = x->dptr<T>();\n    }\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    T* y_dptr = y->mut_dptr<T>();\n    user_op::Tensor* temp = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    T* tmp_dptr = temp ? 
temp->mut_dptr<T>() : nullptr;\n    float p = ctx->Attr<float>(\"p\");\n    if (p == 0) {\n      PowByZero<T> func{};\n      MultiReduce<device_type, T, decltype(func), BinaryAdd<T>> reduce_sum{};\n      reduce_sum(ctx->stream(), func, params, GetZeroVal<T>(), y_dptr, tmp_dptr);\n    } else if (p == 1) {\n      Abs<T> func{};\n      MultiReduce<device_type, T, decltype(func), BinaryAdd<T>> reduce_sum{};\n      reduce_sum(ctx->stream(), func, params, GetZeroVal<T>(), y_dptr, tmp_dptr);\n    } else if (p == 2) {\n      Square<T> func{};\n      MultiReduce<device_type, T, decltype(func), BinaryAdd<T>> reduce_sum{};\n      reduce_sum(ctx->stream(), func, params, GetZeroVal<T>(), y_dptr, tmp_dptr);\n    } else {\n      AbsPow<T> func{p};\n      MultiReduce<device_type, T, decltype(func), BinaryAdd<T>> reduce_sum{};\n      reduce_sum(ctx->stream(), func, params, GetZeroVal<T>(), y_dptr, tmp_dptr);\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nenum class Ximum {\n  kMax = 0,\n  kMin = 1,\n};\n\ntemplate<DeviceType device_type, typename T, Ximum X>\nclass MultiReduceXimumAbsKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  MultiReduceXimumAbsKernel() = default;\n  ~MultiReduceXimumAbsKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache*) const override {\n    std::vector<MultiReduceParam<T>> params;\n    params.resize(ctx->input_size(\"x\"));\n    for (size_t i = 0; i < params.size(); ++i) {\n      const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", i);\n      params[i].size = x->shape_view().elem_cnt();\n      params[i].data = x->dptr<T>();\n    }\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* temp = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    T* tmp_dptr = temp ? temp->mut_dptr<T>() : nullptr;\n    Abs<T> abs{};\n    if (X == Ximum::kMax) {\n      MultiReduce<device_type, T, decltype(abs), BinaryMax<T>> reduce_max{};\n      reduce_max(ctx->stream(), abs, params, GetZeroVal<T>(), y->mut_dptr<T>(), tmp_dptr);\n    } else if (X == Ximum::kMin) {\n      MultiReduce<device_type, T, decltype(abs), BinaryMin<T>> reduce_min{};\n      reduce_min(ctx->stream(), abs, params, std::numeric_limits<T>::max(), y->mut_dptr<T>(),\n                 tmp_dptr);\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_MULTI_REDUCE_KERNELS_H_\n"
  },
  {
    "path": "oneflow/user/kernels/multi_tensor_model_update_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/multi_tensor_model_update_kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<DeviceType device_type, typename T, typename G>\nclass MultiTensorSGDUpdateKernel final : public user_op::OpKernel,\n                                         public user_op::CudaGraphSupport {\n public:\n  MultiTensorSGDUpdateKernel() = default;\n  ~MultiTensorSGDUpdateKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const int64_t n_tensor = ctx->input_size(\"model\");\n    const double scale = ctx->Attr<double>(\"scale\");\n    const float l1 = ctx->Attr<float>(\"l1\");\n    const float l2 = ctx->Attr<float>(\"l2\");\n    const float weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const float* learning_rate_ptr = nullptr;\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->data_type(), ctx->Tensor4ArgNameAndIndex(\"model\", 0)->data_type());\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n\n    TensorTupleParams<2> tensor_tuple_params{};\n    int32_t count = 0;\n    int32_t total_elem_cnt = 0;\n    for (int tensor_idx = 0; tensor_idx < n_tensor; tensor_idx++) {\n      tensor_tuple_params.ptr[0][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"model\", tensor_idx))->mut_dptr();\n      tensor_tuple_params.ptr[1][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"model_diff\", tensor_idx))->mut_dptr();\n\n      const int64_t tensor_elem_cnt =\n          ctx->Tensor4ArgNameAndIndex(\"model\", tensor_idx)->shape_view().elem_cnt();\n      tensor_tuple_params.sizes[count] = tensor_elem_cnt;\n\n      count += 1;\n      total_elem_cnt += tensor_elem_cnt;\n      if (count == kMaxTuples || tensor_idx == n_tensor - 1) {\n        MultiTensorSGDUpdateKernelUtil<device_type, T, G>::Update(\n            ctx->stream(), total_elem_cnt, count, static_cast<T>(scale), l1, l2, weight_decay,\n       
     learning_rate_val, lr_scale, learning_rate_ptr, scale_by_ptr, skip_if_ptr,\n            tensor_tuple_params);\n        count = 0;\n        total_elem_cnt = 0;\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_MULTI_TENSOR_UPDATE_SGD_UPDATE_KERNEL(device, dtype, gtype)              \\\n  REGISTER_USER_KERNEL(\"multi_tensor_sgd_update\")                                         \\\n      .SetCreateFn<MultiTensorSGDUpdateKernel<device, dtype, gtype>>()                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                               \\\n                       && (user_op::HobDataType(\"model\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"model_diff\", 0) == GetDataType<gtype>::value));\n\n#ifdef WITH_CUDA\nREGISTER_MULTI_TENSOR_UPDATE_SGD_UPDATE_KERNEL(DeviceType::kCUDA, float, float16);\nREGISTER_MULTI_TENSOR_UPDATE_SGD_UPDATE_KERNEL(DeviceType::kCUDA, float, float);\nREGISTER_MULTI_TENSOR_UPDATE_SGD_UPDATE_KERNEL(DeviceType::kCUDA, double, double);\n#endif\n\ntemplate<DeviceType device_type, typename T, typename G>\nclass MultiTensorMomentumUpdateKernel final : public user_op::OpKernel,\n                                              public user_op::CudaGraphSupport {\n public:\n  MultiTensorMomentumUpdateKernel() = default;\n  ~MultiTensorMomentumUpdateKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const int64_t n_tensor = ctx->input_size(\"model\");\n    const double scale = ctx->Attr<double>(\"scale\");\n    const float l1 = ctx->Attr<float>(\"l1\");\n    const float l2 = ctx->Attr<float>(\"l2\");\n    const float weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const float* learning_rate_ptr = nullptr;\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n    const float momentum = ctx->Attr<float>(\"momentum\");\n    const float dampening = ctx->Attr<float>(\"dampening\");\n    const bool nesterov = ctx->Attr<bool>(\"nesterov\");\n    const bool maximize = ctx->Attr<bool>(\"maximize\");\n\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->data_type(), ctx->Tensor4ArgNameAndIndex(\"model\", 0)->data_type());\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n\n    TensorTupleParams<3> tensor_tuple_params{};\n    int32_t count = 0;\n    int32_t total_elem_cnt = 0;\n    for (int tensor_idx = 0; tensor_idx < n_tensor; tensor_idx++) {\n      tensor_tuple_params.ptr[0][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"model\", tensor_idx))->mut_dptr();\n      tensor_tuple_params.ptr[1][count] =\n          
(ctx->Tensor4ArgNameAndIndex(\"model_diff\", tensor_idx))->mut_dptr();\n      tensor_tuple_params.ptr[2][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"momentum_buf\", tensor_idx))->mut_dptr();\n\n      const int64_t tensor_elem_cnt =\n          ctx->Tensor4ArgNameAndIndex(\"model\", tensor_idx)->shape_view().elem_cnt();\n      tensor_tuple_params.sizes[count] = tensor_elem_cnt;\n\n      count += 1;\n      total_elem_cnt += tensor_elem_cnt;\n      if (count == kMaxTuples || tensor_idx == n_tensor - 1) {\n        MultiTensorMomentumUpdateKernelUtil<device_type, T, G>::Update(\n            ctx->stream(), total_elem_cnt, count, static_cast<T>(scale), l1, l2, weight_decay,\n            learning_rate_val, lr_scale, learning_rate_ptr, scale_by_ptr, skip_if_ptr, momentum,\n            dampening, nesterov, maximize, tensor_tuple_params);\n        count = 0;\n        total_elem_cnt = 0;\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_MULTI_TENSOR_UPDATE_MOMENTUM_UPDATE_KERNEL(device, dtype, gtype)              \\\n  REGISTER_USER_KERNEL(\"multi_tensor_momentum_update\")                                         \\\n      .SetCreateFn<MultiTensorMomentumUpdateKernel<device, dtype, gtype>>()                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                    \\\n                       && (user_op::HobDataType(\"model\", 0) == GetDataType<dtype>::value)      \\\n                       && (user_op::HobDataType(\"model_diff\", 0) == GetDataType<gtype>::value) \\\n                       && (user_op::HobDataType(\"momentum_buf\", 0) == GetDataType<gtype>::value));\n\n#ifdef WITH_CUDA\nREGISTER_MULTI_TENSOR_UPDATE_MOMENTUM_UPDATE_KERNEL(DeviceType::kCUDA, float, float16);\nREGISTER_MULTI_TENSOR_UPDATE_MOMENTUM_UPDATE_KERNEL(DeviceType::kCUDA, float, float);\nREGISTER_MULTI_TENSOR_UPDATE_MOMENTUM_UPDATE_KERNEL(DeviceType::kCUDA, double, double);\n#endif\n\ntemplate<DeviceType device_type, typename T, typename G>\nclass MultiTensorAdamUpdateKernel final : public user_op::OpKernel,\n                                          public user_op::CudaGraphSupport {\n public:\n  MultiTensorAdamUpdateKernel() = default;\n  ~MultiTensorAdamUpdateKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const int64_t n_tensor = ctx->input_size(\"model\");\n    const auto scale = ctx->Attr<double>(\"scale\");\n    const float l1 = ctx->Attr<float>(\"l1\");\n    const float l2 = ctx->Attr<float>(\"l2\");\n\n    const float beta1 = ctx->Attr<float>(\"beta1\");\n    const float beta2 = ctx->Attr<float>(\"beta2\");\n    const float epsilon = ctx->Attr<float>(\"epsilon\");\n    const float weight_decay = ctx->Attr<float>(\"weight_decay\");\n\n    const bool amsgrad = ctx->Attr<bool>(\"amsgrad\");\n    const bool do_bias_correction = ctx->Attr<bool>(\"do_bias_correction\");\n    if (amsgrad) { UNIMPLEMENTED() << \"Multi Tensor Adam Update do not support amsgrad = True. 
\"; }\n\n    const float* learning_rate_ptr = nullptr;\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n\n    const float bias_correction1_val = ctx->Attr<float>(\"bias_correction1_val\");\n    const float* bias_correction1_ptr = nullptr;\n    if (ctx->has_input(\"bias_correction1\", 0)) {\n      const user_op::Tensor* bias_correction1 = ctx->Tensor4ArgNameAndIndex(\"bias_correction1\", 0);\n      CHECK_EQ(bias_correction1->shape_view().elem_cnt(),\n               1);  // Just for Lazy Optional Input Check.\n      bias_correction1_ptr = bias_correction1->dptr<float>();\n    }\n\n    const float bias_correction2_val = ctx->Attr<float>(\"bias_correction2_val\");\n    const float* bias_correction2_ptr = nullptr;\n    if (ctx->has_input(\"bias_correction2\", 0)) {\n      const user_op::Tensor* bias_correction2 = ctx->Tensor4ArgNameAndIndex(\"bias_correction2\", 0);\n      CHECK_EQ(bias_correction2->shape_view().elem_cnt(),\n               1);  // Just for Lazy Optional Input Check.\n      bias_correction2_ptr = bias_correction2->dptr<float>();\n    }\n\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->data_type(), ctx->Tensor4ArgNameAndIndex(\"model\", 0)->data_type());\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n\n    TensorTupleParams<4> tensor_tuple_params{};\n    int32_t count = 0;\n    int32_t total_elem_cnt = 0;\n    for (int tensor_idx = 0; tensor_idx < n_tensor; tensor_idx++) {\n      tensor_tuple_params.ptr[0][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"model\", tensor_idx))->mut_dptr();\n      tensor_tuple_params.ptr[1][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"model_diff\", tensor_idx))->mut_dptr();\n      tensor_tuple_params.ptr[2][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"m\", tensor_idx))->mut_dptr();\n      tensor_tuple_params.ptr[3][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"v\", tensor_idx))->mut_dptr();\n      const int64_t tensor_elem_cnt =\n          ctx->Tensor4ArgNameAndIndex(\"model\", tensor_idx)->shape_view().elem_cnt();\n      tensor_tuple_params.sizes[count] = tensor_elem_cnt;\n\n      count += 1;\n      total_elem_cnt += tensor_elem_cnt;\n      if (count == kMaxTuples || tensor_idx == n_tensor - 1) {\n        MultiTensorAdamUpdateKernelUtil<device_type, T, G>::Update(\n            ctx->stream(), total_elem_cnt, count, static_cast<T>(scale), l1, l2, beta1, beta2,\n            epsilon, weight_decay, amsgrad, do_bias_correction, learning_rate_val,\n            bias_correction1_val, bias_correction2_val, lr_scale, learning_rate_ptr, scale_by_ptr,\n            skip_if_ptr, bias_correction1_ptr, bias_correction2_ptr, tensor_tuple_params);\n        count = 0;\n        total_elem_cnt = 0;\n      }\n    }\n  }\n  bool 
AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_MULTI_TENSOR_UPDATE_ADAM_UPDATE_KERNEL(device, dtype, gtype)             \\\n  REGISTER_USER_KERNEL(\"multi_tensor_adam_update\")                                        \\\n      .SetCreateFn<MultiTensorAdamUpdateKernel<device, dtype, gtype>>()                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                               \\\n                       && (user_op::HobDataType(\"model\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"model_diff\", 0) == GetDataType<gtype>::value));\n\n#ifdef WITH_CUDA\nREGISTER_MULTI_TENSOR_UPDATE_ADAM_UPDATE_KERNEL(DeviceType::kCUDA, float, float16);\nREGISTER_MULTI_TENSOR_UPDATE_ADAM_UPDATE_KERNEL(DeviceType::kCUDA, float, float);\nREGISTER_MULTI_TENSOR_UPDATE_ADAM_UPDATE_KERNEL(DeviceType::kCUDA, double, double);\n#endif\n\ntemplate<DeviceType device_type, typename T, typename G>\nclass MultiTensorSGDUpdateWithCastKernel final : public user_op::OpKernel,\n                                                 public user_op::CudaGraphSupport {\n public:\n  MultiTensorSGDUpdateWithCastKernel() = default;\n  ~MultiTensorSGDUpdateWithCastKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const int64_t n_tensor = ctx->input_size(\"model\");\n    const double scale = ctx->Attr<double>(\"scale\");\n    const float l1 = ctx->Attr<float>(\"l1\");\n    const float l2 = ctx->Attr<float>(\"l2\");\n    const float weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const float* learning_rate_ptr = nullptr;\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->data_type(), ctx->Tensor4ArgNameAndIndex(\"model\", 0)->data_type());\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n\n    TensorTupleParams<3> tensor_tuple_params{};\n    int32_t count = 0;\n    int32_t total_elem_cnt = 0;\n    for (int tensor_idx = 0; tensor_idx < n_tensor; tensor_idx++) {\n      tensor_tuple_params.ptr[0][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"model\", tensor_idx))->mut_dptr();\n      tensor_tuple_params.ptr[1][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"model_diff\", tensor_idx))->mut_dptr();\n      tensor_tuple_params.ptr[2][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"model_copy\", tensor_idx))->mut_dptr();\n\n      const int64_t tensor_elem_cnt =\n          ctx->Tensor4ArgNameAndIndex(\"model\", tensor_idx)->shape_view().elem_cnt();\n      tensor_tuple_params.sizes[count] = tensor_elem_cnt;\n\n      count += 1;\n      total_elem_cnt += 
tensor_elem_cnt;\n      if (count == kMaxTuples || tensor_idx == n_tensor - 1) {\n        MultiTensorSGDUpdateWithCastKernelUtil<device_type, T, G>::Update(\n            ctx->stream(), total_elem_cnt, count, static_cast<T>(scale), l1, l2, weight_decay,\n            learning_rate_val, lr_scale, learning_rate_ptr, scale_by_ptr, skip_if_ptr,\n            tensor_tuple_params);\n        count = 0;\n        total_elem_cnt = 0;\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_MULTI_TENSOR_UPDATE_SGD_UPDATE_WITH_CAST_KERNEL(device, dtype, gtype)         \\\n  REGISTER_USER_KERNEL(\"multi_tensor_sgd_update_with_cast\")                                    \\\n      .SetCreateFn<MultiTensorSGDUpdateWithCastKernel<device, dtype, gtype>>()                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                    \\\n                       && (user_op::HobDataType(\"model\", 0) == GetDataType<dtype>::value)      \\\n                       && (user_op::HobDataType(\"model_diff\", 0) == GetDataType<gtype>::value) \\\n                       && (user_op::HobDataType(\"model_copy\", 0) == GetDataType<float16>::value));\n\n#ifdef WITH_CUDA\nREGISTER_MULTI_TENSOR_UPDATE_SGD_UPDATE_WITH_CAST_KERNEL(DeviceType::kCUDA, float, float);\nREGISTER_MULTI_TENSOR_UPDATE_SGD_UPDATE_WITH_CAST_KERNEL(DeviceType::kCUDA, float, float16);\n#endif\n\ntemplate<DeviceType device_type, typename T, typename G>\nclass MultiTensorMomentumUpdateWithCastKernel final : public user_op::OpKernel,\n                                                      public user_op::CudaGraphSupport {\n public:\n  MultiTensorMomentumUpdateWithCastKernel() = default;\n  ~MultiTensorMomentumUpdateWithCastKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const int64_t n_tensor = ctx->input_size(\"model\");\n    const double scale = ctx->Attr<double>(\"scale\");\n    const float l1 = ctx->Attr<float>(\"l1\");\n    const float l2 = ctx->Attr<float>(\"l2\");\n    const float weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const float* learning_rate_ptr = nullptr;\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n    const float momentum = ctx->Attr<float>(\"momentum\");\n    const float dampening = ctx->Attr<float>(\"dampening\");\n    const bool nesterov = ctx->Attr<bool>(\"nesterov\");\n    const bool maximize = ctx->Attr<bool>(\"maximize\");\n\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->data_type(), ctx->Tensor4ArgNameAndIndex(\"model\", 0)->data_type());\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n\n    
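// Pack tensor pointers and sizes into TensorTupleParams; launch one fused update once\n    // kMaxTuples tensors are batched or the last tensor is reached.\n    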
TensorTupleParams<4> tensor_tuple_params{};\n    int32_t count = 0;\n    int32_t total_elem_cnt = 0;\n    for (int tensor_idx = 0; tensor_idx < n_tensor; tensor_idx++) {\n      tensor_tuple_params.ptr[0][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"model\", tensor_idx))->mut_dptr();\n      tensor_tuple_params.ptr[1][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"model_diff\", tensor_idx))->mut_dptr();\n      tensor_tuple_params.ptr[2][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"momentum_buf\", tensor_idx))->mut_dptr();\n      tensor_tuple_params.ptr[3][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"model_copy\", tensor_idx))->mut_dptr();\n\n      const int64_t tensor_elem_cnt =\n          ctx->Tensor4ArgNameAndIndex(\"model\", tensor_idx)->shape_view().elem_cnt();\n      tensor_tuple_params.sizes[count] = tensor_elem_cnt;\n\n      count += 1;\n      total_elem_cnt += tensor_elem_cnt;\n      if (count == kMaxTuples || tensor_idx == n_tensor - 1) {\n        MultiTensorMomentumUpdateWithCastKernelUtil<device_type, T, G>::Update(\n            ctx->stream(), total_elem_cnt, count, static_cast<T>(scale), l1, l2, weight_decay,\n            learning_rate_val, lr_scale, learning_rate_ptr, scale_by_ptr, skip_if_ptr, momentum,\n            dampening, nesterov, maximize, tensor_tuple_params);\n        count = 0;\n        total_elem_cnt = 0;\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_MULTI_TENSOR_UPDATE_MOMENTUM_UPDATE_WITH_CAST_KERNEL(device, dtype, gtype)      \\\n  REGISTER_USER_KERNEL(\"multi_tensor_momentum_update_with_cast\")                                 \\\n      .SetCreateFn<MultiTensorMomentumUpdateWithCastKernel<device, dtype, gtype>>()              \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                      \\\n                       && (user_op::HobDataType(\"model\", 0) == GetDataType<dtype>::value)        \\\n                       && (user_op::HobDataType(\"model_diff\", 0) == GetDataType<gtype>::value)   \\\n                       && (user_op::HobDataType(\"momentum_buf\", 0) == GetDataType<gtype>::value) \\\n                       && (user_op::HobDataType(\"model_copy\", 0) == GetDataType<float16>::value));\n\n#ifdef WITH_CUDA\nREGISTER_MULTI_TENSOR_UPDATE_MOMENTUM_UPDATE_WITH_CAST_KERNEL(DeviceType::kCUDA, float, float);\nREGISTER_MULTI_TENSOR_UPDATE_MOMENTUM_UPDATE_WITH_CAST_KERNEL(DeviceType::kCUDA, float, float16);\n#endif\n\ntemplate<DeviceType device_type, typename T, typename G>\nclass MultiTensorAdamUpdateWithCastKernel final : public user_op::OpKernel,\n                                                  public user_op::CudaGraphSupport {\n public:\n  MultiTensorAdamUpdateWithCastKernel() = default;\n  ~MultiTensorAdamUpdateWithCastKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const int64_t n_tensor = ctx->input_size(\"model\");\n    const auto scale = ctx->Attr<double>(\"scale\");\n    const float l1 = ctx->Attr<float>(\"l1\");\n    const float l2 = ctx->Attr<float>(\"l2\");\n\n    const float beta1 = ctx->Attr<float>(\"beta1\");\n    const float beta2 = ctx->Attr<float>(\"beta2\");\n    const float epsilon = ctx->Attr<float>(\"epsilon\");\n    const float weight_decay = ctx->Attr<float>(\"weight_decay\");\n\n    const bool amsgrad = ctx->Attr<bool>(\"amsgrad\");\n    const bool do_bias_correction = 
ctx->Attr<bool>(\"do_bias_correction\");\n    if (amsgrad) { UNIMPLEMENTED() << \"Multi Tensor Adam Update do not support amsgrad = True. \"; }\n\n    const float* learning_rate_ptr = nullptr;\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float lr_scale = ctx->Attr<float>(\"learning_rate_scale\");\n\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n\n    const float bias_correction1_val = ctx->Attr<float>(\"bias_correction1_val\");\n    const float* bias_correction1_ptr = nullptr;\n    if (ctx->has_input(\"bias_correction1\", 0)) {\n      const user_op::Tensor* bias_correction1 = ctx->Tensor4ArgNameAndIndex(\"bias_correction1\", 0);\n      CHECK_EQ(bias_correction1->shape_view().elem_cnt(),\n               1);  // Just for Lazy Optional Input Check.\n      bias_correction1_ptr = bias_correction1->dptr<float>();\n    }\n\n    const float bias_correction2_val = ctx->Attr<float>(\"bias_correction2_val\");\n    const float* bias_correction2_ptr = nullptr;\n    if (ctx->has_input(\"bias_correction2\", 0)) {\n      const user_op::Tensor* bias_correction2 = ctx->Tensor4ArgNameAndIndex(\"bias_correction2\", 0);\n      CHECK_EQ(bias_correction2->shape_view().elem_cnt(),\n               1);  // Just for Lazy Optional Input Check.\n      bias_correction2_ptr = bias_correction2->dptr<float>();\n    }\n\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->data_type(), ctx->Tensor4ArgNameAndIndex(\"model\", 0)->data_type());\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n\n    TensorTupleParams<5> tensor_tuple_params{};\n    int32_t count = 0;\n    int32_t total_elem_cnt = 0;\n    for (int tensor_idx = 0; tensor_idx < n_tensor; tensor_idx++) {\n      tensor_tuple_params.ptr[0][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"model\", tensor_idx))->mut_dptr();\n      tensor_tuple_params.ptr[1][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"model_diff\", tensor_idx))->mut_dptr();\n      tensor_tuple_params.ptr[2][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"m\", tensor_idx))->mut_dptr();\n      tensor_tuple_params.ptr[3][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"v\", tensor_idx))->mut_dptr();\n      tensor_tuple_params.ptr[4][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"model_copy\", tensor_idx))->mut_dptr();\n      const int64_t tensor_elem_cnt =\n          ctx->Tensor4ArgNameAndIndex(\"model\", tensor_idx)->shape_view().elem_cnt();\n      tensor_tuple_params.sizes[count] = tensor_elem_cnt;\n\n      count += 1;\n      total_elem_cnt += tensor_elem_cnt;\n      if (count == kMaxTuples || tensor_idx == n_tensor - 1) {\n        MultiTensorAdamUpdateWithCastKernelUtil<device_type, T, G>::Update(\n            ctx->stream(), total_elem_cnt, count, static_cast<T>(scale), l1, l2, beta1, beta2,\n            epsilon, weight_decay, amsgrad, do_bias_correction, 
learning_rate_val,\n            bias_correction1_val, bias_correction2_val, lr_scale, learning_rate_ptr, scale_by_ptr,\n            skip_if_ptr, bias_correction1_ptr, bias_correction2_ptr, tensor_tuple_params);\n        count = 0;\n        total_elem_cnt = 0;\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_MULTI_TENSOR_UPDATE_ADAM_UPDATE_WITH_CAST_KERNEL(device, dtype, gtype)        \\\n  REGISTER_USER_KERNEL(\"multi_tensor_adam_update_with_cast\")                                   \\\n      .SetCreateFn<MultiTensorAdamUpdateWithCastKernel<device, dtype, gtype>>()                \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                    \\\n                       && (user_op::HobDataType(\"model\", 0) == GetDataType<dtype>::value)      \\\n                       && (user_op::HobDataType(\"model_diff\", 0) == GetDataType<gtype>::value) \\\n                       && (user_op::HobDataType(\"model_copy\", 0) == GetDataType<float16>::value));\n\n#ifdef WITH_CUDA\nREGISTER_MULTI_TENSOR_UPDATE_ADAM_UPDATE_WITH_CAST_KERNEL(DeviceType::kCUDA, float, float);\nREGISTER_MULTI_TENSOR_UPDATE_ADAM_UPDATE_WITH_CAST_KERNEL(DeviceType::kCUDA, float, float16);\n#endif\n\ntemplate<DeviceType device_type, typename T>\nclass MultiTensorYoloV5WeightUpdateKernel final : public user_op::OpKernel,\n                                                  public user_op::CudaGraphSupport {\n public:\n  MultiTensorYoloV5WeightUpdateKernel() = default;\n  ~MultiTensorYoloV5WeightUpdateKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const int64_t n_tensor = ctx->input_size(\"model\");\n    const float d = ctx->Attr<float>(\"d\");\n\n    TensorTupleParams<2> tensor_tuple_params{};\n    int32_t count = 0;\n    int32_t total_elem_cnt = 0;\n    for (int tensor_idx = 0; tensor_idx < n_tensor; tensor_idx++) {\n      tensor_tuple_params.ptr[0][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"model\", tensor_idx))->mut_dptr();\n      tensor_tuple_params.ptr[1][count] =\n          (ctx->Tensor4ArgNameAndIndex(\"model_update\", tensor_idx))->mut_dptr();\n      const int64_t tensor_elem_cnt =\n          ctx->Tensor4ArgNameAndIndex(\"model\", tensor_idx)->shape_view().elem_cnt();\n      tensor_tuple_params.sizes[count] = tensor_elem_cnt;\n\n      count += 1;\n      total_elem_cnt += tensor_elem_cnt;\n      if (count == kMaxTuples || tensor_idx == n_tensor - 1) {\n        MultiTensorYoloV5WeightUpdateKernelUtil<device_type, T>::Update(\n            ctx->stream(), total_elem_cnt, count, d, tensor_tuple_params);\n        count = 0;\n        total_elem_cnt = 0;\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_MULTI_TENSOR_YOLOV5_WEIGHT_UPDATE_KERNEL(device, dtype) \\\n  REGISTER_USER_KERNEL(\"multi_tensor_yolov5_weight_update\")              \\\n      .SetCreateFn<MultiTensorYoloV5WeightUpdateKernel<device, dtype>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)              \\\n                       && (user_op::HobDataType(\"model\", 0) == GetDataType<dtype>::value));\n\n#ifdef WITH_CUDA\nREGISTER_MULTI_TENSOR_YOLOV5_WEIGHT_UPDATE_KERNEL(DeviceType::kCUDA, float);\n#endif\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/multi_tensor_model_update_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/model_update_kernel_util.h\"\n#include \"oneflow/user/kernels/multi_tensor_model_update_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nconstexpr int kBlockSize = 256;\nconstexpr int kUnrollSize = 4;\n\nunsigned int ComputeGridSize(ep::Stream* stream, const int32_t block_size, const int64_t elem_cnt) {\n  auto* cuda_stream = stream->As<ep::CudaStream>();\n  const int32_t max_threads_multi_process =\n      cuda_stream->device_properties().maxThreadsPerMultiProcessor;\n  const int32_t multi_processor_count = cuda_stream->device_properties().multiProcessorCount;\n  unsigned int blocks_per_sm = max_threads_multi_process / block_size;\n  unsigned int grid_size = ((elem_cnt + block_size - 1) / block_size);\n  grid_size = std::min((unsigned int)multi_processor_count * blocks_per_sm, grid_size);\n  return grid_size;\n}\n\ntemplate<typename T, typename G, int N>\n__global__ void MultiTensorSGDUpdateGpu(int64_t num_tensor, T scale, const float l1, const float l2,\n                                        const float weight_decay, float learning_rate_val,\n                                        float lr_scale, const float* learning_rate,\n                                        const T* scale_by_ptr, const int64_t* skip_if,\n                                        TensorTupleParams<N> tensor_tuple_params) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  learning_rate_val *= lr_scale;\n  int64_t v_block_id = blockIdx.x;\n  for (int64_t tensor_idx = 0; tensor_idx < num_tensor; tensor_idx++) {\n    const int64_t tensor_elem_cnt = tensor_tuple_params.sizes[tensor_idx];\n    T* model_ptr = (T*)tensor_tuple_params.ptr[0][tensor_idx];\n    G* model_diff_ptr = (G*)tensor_tuple_params.ptr[1][tensor_idx];\n    half* model_copy_ptr = nullptr;\n    if (N == 3) { model_copy_ptr = (half*)tensor_tuple_params.ptr[2][tensor_idx]; }\n\n    for (int64_t i = v_block_id * blockDim.x * kUnrollSize + threadIdx.x; i < tensor_elem_cnt;\n         i += blockDim.x * gridDim.x * kUnrollSize) {\n      T model_val[kUnrollSize] = {0};\n      G model_diff[kUnrollSize] = {0};\n\n#pragma unroll\n      for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) {\n        int64_t actual_idx = i + ilp * blockDim.x;\n        if (actual_idx < tensor_elem_cnt) {\n          model_val[ilp] = *(model_ptr + actual_idx);\n          model_diff[ilp] = *(model_diff_ptr + actual_idx);\n        }\n      }\n\n#pragma unroll\n      for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) {\n        int64_t actual_idx = i + ilp * blockDim.x;\n        if (actual_idx < tensor_elem_cnt) {\n          T model_diff_t = CastScaleRegularizeGradientFunctor<T, G>()(\n              model_diff[ilp], 
model_val[ilp], scale, l1, l2);\n          model_val[ilp] =\n              model_val[ilp] - learning_rate_val * (model_diff_t + weight_decay * model_val[ilp]);\n        }\n      }\n\n#pragma unroll\n      for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) {\n        int64_t actual_idx = i + ilp * blockDim.x;\n        if (actual_idx < tensor_elem_cnt) {\n          *(model_ptr + actual_idx) = model_val[ilp];\n          if (N == 3) { *(model_copy_ptr + actual_idx) = static_cast<half>(model_val[ilp]); }\n        }\n      }\n    }\n    v_block_id -= tensor_tuple_params.block_offset[tensor_idx];\n    if (v_block_id < 0) { v_block_id += gridDim.x; }\n  }\n}\n\ntemplate<typename T, typename G>\nstruct MultiTensorSGDUpdateKernelUtil<DeviceType::kCUDA, T, G> {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float weight_decay, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, TensorTupleParams<2> tensor_tuple_params);\n};\n\ntemplate<typename T, typename G>\nvoid MultiTensorSGDUpdateKernelUtil<DeviceType::kCUDA, T, G>::Update(\n    ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2,\n    float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate,\n    const T* scale_by_ptr, const int64_t* skip_if, TensorTupleParams<2> tensor_tuple_params) {\n  const unsigned int grid_size =\n      ComputeGridSize(stream->As<ep::CudaStream>(), kBlockSize, elem_cnt);\n  for (int i = 0; i < n_tensor; i++) {\n    tensor_tuple_params.block_offset[i] =\n        ((tensor_tuple_params.sizes[i] + kBlockSize * kUnrollSize - 1) / (kBlockSize * kUnrollSize))\n        % grid_size;\n  }\n  MultiTensorSGDUpdateGpu<T, G, 2>\n      <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          n_tensor, static_cast<T>(scale), l1, l2, weight_decay, learning_rate_val, lr_scale,\n          learning_rate, scale_by_ptr, skip_if, tensor_tuple_params);\n}\n\ntemplate<typename T>\nstruct MultiTensorSGDUpdateKernelUtil<DeviceType::kCUDA, T, float16> {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float weight_decay, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, TensorTupleParams<2> tensor_tuple_params);\n};\n\ntemplate<typename T>\nvoid MultiTensorSGDUpdateKernelUtil<DeviceType::kCUDA, T, float16>::Update(\n    ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2,\n    float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate,\n    const T* scale_by_ptr, const int64_t* skip_if, TensorTupleParams<2> tensor_tuple_params) {\n  MultiTensorSGDUpdateKernelUtil<DeviceType::kCUDA, T, half>::Update(\n      stream, elem_cnt, n_tensor, scale, l1, l2, weight_decay, learning_rate_val, lr_scale,\n      learning_rate, scale_by_ptr, skip_if, tensor_tuple_params);\n}\n\ntemplate struct MultiTensorSGDUpdateKernelUtil<DeviceType::kCUDA, double, double>;\ntemplate struct MultiTensorSGDUpdateKernelUtil<DeviceType::kCUDA, float, float>;\ntemplate struct MultiTensorSGDUpdateKernelUtil<DeviceType::kCUDA, float, float16>;\n\ntemplate<typename T, typename G, int N>\n__global__ void 
MultiTensorMomentumUpdateGpu(\n    int64_t num_tensor, T scale, const float l1, const float l2, const float weight_decay,\n    float learning_rate_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n    const int64_t* skip_if, const float momentum, const float dampening, const bool nesterov,\n    const bool maximize, TensorTupleParams<N> tensor_tuple_params) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  learning_rate_val *= lr_scale;\n  int64_t v_block_id = blockIdx.x;\n  for (int64_t tensor_idx = 0; tensor_idx < num_tensor; tensor_idx++) {\n    const int64_t tensor_elem_cnt = tensor_tuple_params.sizes[tensor_idx];\n    T* model_ptr = (T*)tensor_tuple_params.ptr[0][tensor_idx];\n    G* model_diff_ptr = (G*)tensor_tuple_params.ptr[1][tensor_idx];\n    T* momentum_buf_ptr = (T*)tensor_tuple_params.ptr[2][tensor_idx];\n    half* model_copy_ptr = nullptr;\n    if (N == 4) { model_copy_ptr = (half*)tensor_tuple_params.ptr[3][tensor_idx]; }\n\n    for (int64_t i = v_block_id * blockDim.x * kUnrollSize + threadIdx.x; i < tensor_elem_cnt;\n         i += blockDim.x * gridDim.x * kUnrollSize) {\n      T model_val[kUnrollSize] = {0};\n      G model_diff[kUnrollSize] = {0};\n      T momentum_buf[kUnrollSize] = {0};\n\n#pragma unroll\n      for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) {\n        int64_t actual_idx = i + ilp * blockDim.x;\n        if (actual_idx < tensor_elem_cnt) {\n          model_val[ilp] = *(model_ptr + actual_idx);\n          model_diff[ilp] = *(model_diff_ptr + actual_idx);\n          momentum_buf[ilp] = *(momentum_buf_ptr + actual_idx);\n        }\n      }\n\n#pragma unroll\n      for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) {\n        int64_t actual_idx = i + ilp * blockDim.x;\n        if (actual_idx < tensor_elem_cnt) {\n          T model_diff_t = CastScaleRegularizeGradientFunctor<T, G>()(\n              model_diff[ilp], model_val[ilp], scale, l1, l2);\n\n          if (weight_decay != 0.f) { model_diff_t += weight_decay * model_val[ilp]; }\n\n          momentum_buf[ilp] = momentum * momentum_buf[ilp] + (1.f - dampening) * model_diff_t;\n\n          if (nesterov)\n            model_diff_t += momentum * momentum_buf[ilp];\n          else\n            model_diff_t = momentum_buf[ilp];\n\n          T alpha = -learning_rate_val;\n          if (maximize) alpha = learning_rate_val;\n          model_val[ilp] += alpha * model_diff_t;\n        }\n      }\n\n#pragma unroll\n      for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) {\n        int64_t actual_idx = i + ilp * blockDim.x;\n        if (actual_idx < tensor_elem_cnt) {\n          *(model_ptr + actual_idx) = model_val[ilp];\n          *(momentum_buf_ptr + actual_idx) = momentum_buf[ilp];\n          if (N == 4) { *(model_copy_ptr + actual_idx) = static_cast<half>(model_val[ilp]); }\n        }\n      }\n    }\n    v_block_id -= tensor_tuple_params.block_offset[tensor_idx];\n    if (v_block_id < 0) { v_block_id += gridDim.x; }\n  }\n}\n\ntemplate<typename T, typename G>\nstruct MultiTensorMomentumUpdateKernelUtil<DeviceType::kCUDA, T, G> {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float weight_decay, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const 
float momentum, const float dampening,\n                     const bool nesterov, const bool maximize,\n                     TensorTupleParams<3> tensor_tuple_params);\n};\n\ntemplate<typename T, typename G>\nvoid MultiTensorMomentumUpdateKernelUtil<DeviceType::kCUDA, T, G>::Update(\n    ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2,\n    float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate,\n    const T* scale_by_ptr, const int64_t* skip_if, const float momentum, const float dampening,\n    const bool nesterov, const bool maximize, TensorTupleParams<3> tensor_tuple_params) {\n  const unsigned int grid_size =\n      ComputeGridSize(stream->As<ep::CudaStream>(), kBlockSize, elem_cnt);\n  for (int i = 0; i < n_tensor; i++) {\n    tensor_tuple_params.block_offset[i] =\n        ((tensor_tuple_params.sizes[i] + kBlockSize * kUnrollSize - 1) / (kBlockSize * kUnrollSize))\n        % grid_size;\n  }\n  MultiTensorMomentumUpdateGpu<T, G, 3>\n      <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          n_tensor, static_cast<T>(scale), l1, l2, weight_decay, learning_rate_val, lr_scale,\n          learning_rate, scale_by_ptr, skip_if, momentum, dampening, nesterov, maximize,\n          tensor_tuple_params);\n}\n\ntemplate<typename T>\nstruct MultiTensorMomentumUpdateKernelUtil<DeviceType::kCUDA, T, float16> {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float weight_decay, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const float momentum, const float dampening,\n                     const bool nesterov, const bool maximize,\n                     TensorTupleParams<3> tensor_tuple_params);\n};\n\ntemplate<typename T>\nvoid MultiTensorMomentumUpdateKernelUtil<DeviceType::kCUDA, T, float16>::Update(\n    ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2,\n    float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate,\n    const T* scale_by_ptr, const int64_t* skip_if, const float momentum, const float dampening,\n    const bool nesterov, const bool maximize, TensorTupleParams<3> tensor_tuple_params) {\n  MultiTensorMomentumUpdateKernelUtil<DeviceType::kCUDA, T, half>::Update(\n      stream, elem_cnt, n_tensor, scale, l1, l2, weight_decay, learning_rate_val, lr_scale,\n      learning_rate, scale_by_ptr, skip_if, momentum, dampening, nesterov, maximize,\n      tensor_tuple_params);\n}\n\ntemplate struct MultiTensorMomentumUpdateKernelUtil<DeviceType::kCUDA, double, double>;\ntemplate struct MultiTensorMomentumUpdateKernelUtil<DeviceType::kCUDA, float, float>;\ntemplate struct MultiTensorMomentumUpdateKernelUtil<DeviceType::kCUDA, float, float16>;\n\ntemplate<typename T, typename G, int N>\n__global__ void MultiTensorAdamUpdateGpu(int64_t num_tensor, T scale, float l1, float l2,\n                                         float beta1, float beta2, float epsilon,\n                                         float weight_decay, bool amsgrad, bool do_bias_correction,\n                                         float learning_rate_val, float bias_correction1_val,\n                                         float bias_correction2_val, float lr_scale,\n                                         const float* learning_rate, const 
T* scale_by_ptr,\n                                         const int64_t* skip_if, const float* bias_correction1_ptr,\n                                         const float* bias_correction2_ptr,\n                                         TensorTupleParams<N> tensor_tuple_params) {\n  if (skip_if != nullptr && *skip_if != 0) { return; }\n  if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n  if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n  if (bias_correction1_ptr != nullptr) { bias_correction1_val = *bias_correction1_ptr; }\n  if (bias_correction2_ptr != nullptr) { bias_correction2_val = *bias_correction2_ptr; }\n\n  learning_rate_val *= lr_scale;\n  int64_t v_block_id = blockIdx.x;\n  for (int64_t tensor_idx = 0; tensor_idx < num_tensor; tensor_idx++) {\n    const int64_t tensor_elem_cnt = tensor_tuple_params.sizes[tensor_idx];\n    T* model_ptr = (T*)tensor_tuple_params.ptr[0][tensor_idx];\n    G* model_diff_ptr = (G*)tensor_tuple_params.ptr[1][tensor_idx];\n    T* m_ptr = (T*)tensor_tuple_params.ptr[2][tensor_idx];\n    T* v_ptr = (T*)tensor_tuple_params.ptr[3][tensor_idx];\n    half* model_copy_ptr = nullptr;\n    if (N == 5) { model_copy_ptr = (half*)tensor_tuple_params.ptr[4][tensor_idx]; }\n\n    for (int64_t i = v_block_id * blockDim.x * kUnrollSize + threadIdx.x; i < tensor_elem_cnt;\n         i += blockDim.x * gridDim.x * kUnrollSize) {\n      T model_val[kUnrollSize] = {0};\n      T m_val[kUnrollSize] = {0};\n      T v_val[kUnrollSize] = {0};\n      G model_diff[kUnrollSize] = {0};\n\n#pragma unroll\n      for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) {\n        int64_t actual_idx = i + ilp * blockDim.x;\n        if (actual_idx < tensor_elem_cnt) {\n          model_val[ilp] = *(model_ptr + actual_idx);\n          m_val[ilp] = *(m_ptr + actual_idx);\n          v_val[ilp] = *(v_ptr + actual_idx);\n          model_diff[ilp] = *(model_diff_ptr + actual_idx);\n        }\n      }\n\n#pragma unroll\n      for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) {\n        int64_t actual_idx = i + ilp * blockDim.x;\n        if (actual_idx < tensor_elem_cnt) {\n          T model_diff_t = CastScaleRegularizeGradientFunctor<T, G>()(\n              model_diff[ilp], model_val[ilp], scale, l1, l2);\n\n          m_val[ilp] = beta1 * m_val[ilp] + (1 - beta1) * model_diff_t;\n          v_val[ilp] = beta2 * v_val[ilp] + (1 - beta2) * model_diff_t * model_diff_t;\n\n          T denom = (sqrt(v_val[ilp]) / sqrt(bias_correction2_val)) + epsilon;\n          const T step_size = learning_rate_val / bias_correction1_val;\n          model_val[ilp] = model_val[ilp] - step_size * (m_val[ilp] / denom)\n                           - learning_rate_val * weight_decay * model_val[ilp];\n        }\n      }\n\n#pragma unroll\n      for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) {\n        int64_t actual_idx = i + ilp * blockDim.x;\n        if (actual_idx < tensor_elem_cnt) {\n          *(model_ptr + actual_idx) = model_val[ilp];\n          *(m_ptr + actual_idx) = m_val[ilp];\n          *(v_ptr + actual_idx) = v_val[ilp];\n          if (N == 5) { *(model_copy_ptr + actual_idx) = static_cast<half>(model_val[ilp]); }\n        }\n      }\n    }\n    v_block_id -= tensor_tuple_params.block_offset[tensor_idx];\n    if (v_block_id < 0) { v_block_id += gridDim.x; }\n  }\n}\n\ntemplate<typename T, typename G>\nstruct MultiTensorAdamUpdateKernelUtil<DeviceType::kCUDA, T, G> {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float 
l1, float l2, float beta1, float beta2, float epsilon,\n                     float weight_decay, bool amsgrad, bool do_bias_correction,\n                     float learning_rate_val, float bias_correction1_val,\n                     float bias_correction2_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1,\n                     const float* bias_correction2, TensorTupleParams<4> tensor_tuple_params);\n};\n\ntemplate<typename T, typename G>\nvoid MultiTensorAdamUpdateKernelUtil<DeviceType::kCUDA, T, G>::Update(\n    ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2,\n    float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad,\n    bool do_bias_correction, float learning_rate_val, float bias_correction1_val,\n    float bias_correction2_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n    const int64_t* skip_if, const float* bias_correction1, const float* bias_correction2,\n    TensorTupleParams<4> tensor_tuple_params) {\n  const unsigned int grid_size =\n      ComputeGridSize(stream->As<ep::CudaStream>(), kBlockSize, elem_cnt);\n  for (int i = 0; i < n_tensor; i++) {\n    tensor_tuple_params.block_offset[i] =\n        ((tensor_tuple_params.sizes[i] + kBlockSize * kUnrollSize - 1) / (kBlockSize * kUnrollSize))\n        % grid_size;\n  }\n  MultiTensorAdamUpdateGpu<T, G>\n      <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          n_tensor, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction,\n          learning_rate_val, bias_correction1_val, bias_correction2_val, lr_scale, learning_rate,\n          scale_by_ptr, skip_if, bias_correction1, bias_correction2, tensor_tuple_params);\n}\n\ntemplate<typename T>\nstruct MultiTensorAdamUpdateKernelUtil<DeviceType::kCUDA, T, float16> {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float beta1, float beta2, float epsilon,\n                     float weight_decay, bool amsgrad, bool do_bias_correction,\n                     float learning_rate_val, float bias_correction1_val,\n                     float bias_correction2_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1,\n                     const float* bias_correction2, TensorTupleParams<4> tensor_tuple_params);\n};\n\ntemplate<typename T>\nvoid MultiTensorAdamUpdateKernelUtil<DeviceType::kCUDA, T, float16>::Update(\n    ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2,\n    float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad,\n    bool do_bias_correction, float learning_rate_val, float bias_correction1_val,\n    float bias_correction2_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n    const int64_t* skip_if, const float* bias_correction1, const float* bias_correction2,\n    TensorTupleParams<4> tensor_tuple_params) {\n  MultiTensorAdamUpdateKernelUtil<DeviceType::kCUDA, T, half>::Update(\n      stream, elem_cnt, n_tensor, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad,\n      do_bias_correction, learning_rate_val, bias_correction1_val, bias_correction2_val, lr_scale,\n      learning_rate, scale_by_ptr, skip_if, bias_correction1, bias_correction2,\n      
tensor_tuple_params);\n}\n\ntemplate struct MultiTensorAdamUpdateKernelUtil<DeviceType::kCUDA, double, double>;\ntemplate struct MultiTensorAdamUpdateKernelUtil<DeviceType::kCUDA, float, float>;\ntemplate struct MultiTensorAdamUpdateKernelUtil<DeviceType::kCUDA, float, float16>;\n\ntemplate<typename T, typename G>\nstruct MultiTensorSGDUpdateWithCastKernelUtil<DeviceType::kCUDA, T, G> {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float weight_decay, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, TensorTupleParams<3> tensor_tuple_params);\n};\n\ntemplate<typename T, typename G>\nvoid MultiTensorSGDUpdateWithCastKernelUtil<DeviceType::kCUDA, T, G>::Update(\n    ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2,\n    float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate,\n    const T* scale_by_ptr, const int64_t* skip_if, TensorTupleParams<3> tensor_tuple_params) {\n  const unsigned int grid_size =\n      ComputeGridSize(stream->As<ep::CudaStream>(), kBlockSize, elem_cnt);\n  for (int i = 0; i < n_tensor; i++) {\n    tensor_tuple_params.block_offset[i] =\n        ((tensor_tuple_params.sizes[i] + kBlockSize * kUnrollSize - 1) / (kBlockSize * kUnrollSize))\n        % grid_size;\n  }\n  MultiTensorSGDUpdateGpu<T, G, 3>\n      <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          n_tensor, static_cast<T>(scale), l1, l2, weight_decay, learning_rate_val, lr_scale,\n          learning_rate, scale_by_ptr, skip_if, tensor_tuple_params);\n}\n\ntemplate<typename T>\nstruct MultiTensorSGDUpdateWithCastKernelUtil<DeviceType::kCUDA, T, float16> {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float weight_decay, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, TensorTupleParams<3> tensor_tuple_params);\n};\n\ntemplate<typename T>\nvoid MultiTensorSGDUpdateWithCastKernelUtil<DeviceType::kCUDA, T, float16>::Update(\n    ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2,\n    float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate,\n    const T* scale_by_ptr, const int64_t* skip_if, TensorTupleParams<3> tensor_tuple_params) {\n  MultiTensorSGDUpdateWithCastKernelUtil<DeviceType::kCUDA, T, half>::Update(\n      stream, elem_cnt, n_tensor, scale, l1, l2, weight_decay, learning_rate_val, lr_scale,\n      learning_rate, scale_by_ptr, skip_if, tensor_tuple_params);\n}\n\ntemplate struct MultiTensorSGDUpdateWithCastKernelUtil<DeviceType::kCUDA, float, float>;\ntemplate struct MultiTensorSGDUpdateWithCastKernelUtil<DeviceType::kCUDA, float, float16>;\n\ntemplate<typename T, typename G>\nstruct MultiTensorMomentumUpdateWithCastKernelUtil<DeviceType::kCUDA, T, G> {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float weight_decay, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const float momentum, const float 
dampening,\n                     const bool nesterov, const bool maximize,\n                     TensorTupleParams<4> tensor_tuple_params);\n};\n\ntemplate<typename T, typename G>\nvoid MultiTensorMomentumUpdateWithCastKernelUtil<DeviceType::kCUDA, T, G>::Update(\n    ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2,\n    float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate,\n    const T* scale_by_ptr, const int64_t* skip_if, const float momentum, const float dampening,\n    const bool nesterov, const bool maximize, TensorTupleParams<4> tensor_tuple_params) {\n  const unsigned int grid_size =\n      ComputeGridSize(stream->As<ep::CudaStream>(), kBlockSize, elem_cnt);\n  for (int i = 0; i < n_tensor; i++) {\n    tensor_tuple_params.block_offset[i] =\n        ((tensor_tuple_params.sizes[i] + kBlockSize * kUnrollSize - 1) / (kBlockSize * kUnrollSize))\n        % grid_size;\n  }\n  MultiTensorMomentumUpdateGpu<T, G, 4>\n      <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          n_tensor, static_cast<T>(scale), l1, l2, weight_decay, learning_rate_val, lr_scale,\n          learning_rate, scale_by_ptr, skip_if, momentum, dampening, nesterov, maximize,\n          tensor_tuple_params);\n}\n\ntemplate<typename T>\nstruct MultiTensorMomentumUpdateWithCastKernelUtil<DeviceType::kCUDA, T, float16> {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float weight_decay, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const float momentum, const float dampening,\n                     const bool nesterov, const bool maximize,\n                     TensorTupleParams<4> tensor_tuple_params);\n};\n\ntemplate<typename T>\nvoid MultiTensorMomentumUpdateWithCastKernelUtil<DeviceType::kCUDA, T, float16>::Update(\n    ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2,\n    float weight_decay, float learning_rate_val, float lr_scale, const float* learning_rate,\n    const T* scale_by_ptr, const int64_t* skip_if, const float momentum, const float dampening,\n    const bool nesterov, const bool maximize, TensorTupleParams<4> tensor_tuple_params) {\n  MultiTensorMomentumUpdateWithCastKernelUtil<DeviceType::kCUDA, T, half>::Update(\n      stream, elem_cnt, n_tensor, scale, l1, l2, weight_decay, learning_rate_val, lr_scale,\n      learning_rate, scale_by_ptr, skip_if, momentum, dampening, nesterov, maximize,\n      tensor_tuple_params);\n}\n\ntemplate struct MultiTensorMomentumUpdateWithCastKernelUtil<DeviceType::kCUDA, float, float>;\ntemplate struct MultiTensorMomentumUpdateWithCastKernelUtil<DeviceType::kCUDA, float, float16>;\n\ntemplate<typename T, typename G>\nstruct MultiTensorAdamUpdateWithCastKernelUtil<DeviceType::kCUDA, T, G> {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float beta1, float beta2, float epsilon,\n                     float weight_decay, bool amsgrad, bool do_bias_correction,\n                     float learning_rate_val, float bias_correction1_val,\n                     float bias_correction2_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const float* 
bias_correction1,\n                     const float* bias_correction2, TensorTupleParams<5> tensor_tuple_params);\n};\n\ntemplate<typename T, typename G>\nvoid MultiTensorAdamUpdateWithCastKernelUtil<DeviceType::kCUDA, T, G>::Update(\n    ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2,\n    float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad,\n    bool do_bias_correction, float learning_rate_val, float bias_correction1_val,\n    float bias_correction2_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n    const int64_t* skip_if, const float* bias_correction1, const float* bias_correction2,\n    TensorTupleParams<5> tensor_tuple_params) {\n  const unsigned int grid_size =\n      ComputeGridSize(stream->As<ep::CudaStream>(), kBlockSize, elem_cnt);\n  for (int i = 0; i < n_tensor; i++) {\n    tensor_tuple_params.block_offset[i] =\n        ((tensor_tuple_params.sizes[i] + kBlockSize * kUnrollSize - 1) / (kBlockSize * kUnrollSize))\n        % grid_size;\n  }\n  MultiTensorAdamUpdateGpu<T, G, 5>\n      <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          n_tensor, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction,\n          learning_rate_val, bias_correction1_val, bias_correction2_val, lr_scale, learning_rate,\n          scale_by_ptr, skip_if, bias_correction1, bias_correction2, tensor_tuple_params);\n}\n\ntemplate<typename T>\nstruct MultiTensorAdamUpdateWithCastKernelUtil<DeviceType::kCUDA, T, float16> {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float beta1, float beta2, float epsilon,\n                     float weight_decay, bool amsgrad, bool do_bias_correction,\n                     float learning_rate_val, float bias_correction1_val,\n                     float bias_correction2_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1,\n                     const float* bias_correction2, TensorTupleParams<5> tensor_tuple_params);\n};\n\ntemplate<typename T>\nvoid MultiTensorAdamUpdateWithCastKernelUtil<DeviceType::kCUDA, T, float16>::Update(\n    ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale, float l1, float l2,\n    float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad,\n    bool do_bias_correction, float learning_rate_val, float bias_correction1_val,\n    float bias_correction2_val, float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n    const int64_t* skip_if, const float* bias_correction1, const float* bias_correction2,\n    TensorTupleParams<5> tensor_tuple_params) {\n  MultiTensorAdamUpdateWithCastKernelUtil<DeviceType::kCUDA, T, half>::Update(\n      stream, elem_cnt, n_tensor, scale, l1, l2, beta1, beta2, epsilon, weight_decay, amsgrad,\n      do_bias_correction, learning_rate_val, bias_correction1_val, bias_correction2_val, lr_scale,\n      learning_rate, scale_by_ptr, skip_if, bias_correction1, bias_correction2,\n      tensor_tuple_params);\n}\n\ntemplate struct MultiTensorAdamUpdateWithCastKernelUtil<DeviceType::kCUDA, float, float>;\ntemplate struct MultiTensorAdamUpdateWithCastKernelUtil<DeviceType::kCUDA, float, float16>;\n\ntemplate<typename T, int N>\n__global__ void MultiTensorYoloModelEmaUpdateGpu(int64_t num_tensor, const float d,\n                                 
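// d: EMA decay factor; each element below becomes d * model + (1 - d) * model_update.\n                                 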
                TensorTupleParams<N> tensor_tuple_params) {\n  int64_t v_block_id = blockIdx.x;\n  for (int64_t tensor_idx = 0; tensor_idx < num_tensor; tensor_idx++) {\n    const int64_t tensor_elem_cnt = tensor_tuple_params.sizes[tensor_idx];\n    T* model_ptr = (T*)tensor_tuple_params.ptr[0][tensor_idx];\n    T* model_update_ptr = (T*)tensor_tuple_params.ptr[1][tensor_idx];\n\n    for (int64_t i = v_block_id * blockDim.x * kUnrollSize + threadIdx.x; i < tensor_elem_cnt;\n         i += blockDim.x * gridDim.x * kUnrollSize) {\n      T model_val[kUnrollSize] = {0};\n      T model_update_val[kUnrollSize] = {0};\n\n#pragma unroll\n      for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) {\n        int64_t actual_idx = i + ilp * blockDim.x;\n        if (actual_idx < tensor_elem_cnt) {\n          model_val[ilp] = *(model_ptr + actual_idx);\n          model_update_val[ilp] = *(model_update_ptr + actual_idx);\n        }\n      }\n\n#pragma unroll\n      for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) {\n        int64_t actual_idx = i + ilp * blockDim.x;\n        if (actual_idx < tensor_elem_cnt) {\n          model_val[ilp] *= d;\n          model_val[ilp] += (1 - d) * model_update_val[ilp];\n        }\n      }\n\n#pragma unroll\n      for (int32_t ilp = 0; ilp < kUnrollSize; ilp++) {\n        int64_t actual_idx = i + ilp * blockDim.x;\n        if (actual_idx < tensor_elem_cnt) {\n          *(model_ptr + actual_idx) = model_val[ilp];\n          *(model_update_ptr + actual_idx) = model_update_val[ilp];\n        }\n      }\n    }\n    v_block_id -= tensor_tuple_params.block_offset[tensor_idx];\n    if (v_block_id < 0) { v_block_id += gridDim.x; }\n  }\n}\n\ntemplate<typename T>\nstruct MultiTensorYoloV5WeightUpdateKernelUtil<DeviceType::kCUDA, T> {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, float d,\n                     TensorTupleParams<2> tensor_tuple_params);\n};\n\ntemplate<>\nstruct MultiTensorYoloV5WeightUpdateKernelUtil<DeviceType::kCUDA, half> {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, float d,\n                     TensorTupleParams<2> tensor_tuple_params);\n};\n\ntemplate<typename T>\nvoid MultiTensorYoloV5WeightUpdateKernelUtil<DeviceType::kCUDA, T>::Update(\n    ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, float d,\n    TensorTupleParams<2> tensor_tuple_params) {\n  const unsigned int grid_size =\n      ComputeGridSize(stream->As<ep::CudaStream>(), kBlockSize, elem_cnt);\n  for (int i = 0; i < n_tensor; i++) {\n    tensor_tuple_params.block_offset[i] =\n        ((tensor_tuple_params.sizes[i] + kBlockSize * kUnrollSize - 1) / (kBlockSize * kUnrollSize))\n        % grid_size;\n  }\n  MultiTensorYoloModelEmaUpdateGpu<T>\n      <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          n_tensor, d, tensor_tuple_params);\n}\n\ntemplate struct MultiTensorYoloV5WeightUpdateKernelUtil<DeviceType::kCUDA, float>;\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/multi_tensor_model_update_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_MULTI_TENSOR_MODEL_UPDATE_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_MULTI_TENSOR_MODEL_UPDATE_KERNEL_UTIL_H_\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\n// Kernel arg size has 4K limit, but currently we set process 32 tensors in each kernel.\nconstexpr int kMaxTuples = 32;\n\ntemplate<int N>\nstruct TensorTupleParams {\n  void* ptr[N][kMaxTuples];\n  int64_t sizes[kMaxTuples];\n  int32_t block_offset[kMaxTuples];\n};\n\ntemplate<DeviceType device_type, typename T, typename G>\nstruct MultiTensorSGDUpdateKernelUtil {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float weight_decay, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, TensorTupleParams<2> tensor_tuple_params);\n};\n\ntemplate<DeviceType device_type, typename T, typename G>\nstruct MultiTensorMomentumUpdateKernelUtil {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float weight_decay, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const float momentum, const float dampening,\n                     const bool nesterov, const bool maximize,\n                     TensorTupleParams<3> tensor_tuple_params);\n};\n\ntemplate<DeviceType device_type, typename T, typename G>\nstruct MultiTensorAdamUpdateKernelUtil {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float beta1, float beta2, float epsilon,\n                     float weight_decay, bool amsgrad, bool do_bias_correction,\n                     float learning_rate_val, float bias_correction1_val,\n                     float bias_correction2_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1,\n                     const float* bias_correction2, TensorTupleParams<4> tensor_tuple_params);\n};\n\ntemplate<DeviceType device_type, typename T, typename G>\nstruct MultiTensorSGDUpdateWithCastKernelUtil {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float weight_decay, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, TensorTupleParams<3> tensor_tuple_params);\n};\n\ntemplate<DeviceType device_type, typename T, typename G>\nstruct MultiTensorMomentumUpdateWithCastKernelUtil 
{\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float weight_decay, float learning_rate_val,\n                     float lr_scale, const float* learning_rate, const T* scale_by_ptr,\n                     const int64_t* skip_if, const float momentum, const float dampening,\n                     const bool nesterov, const bool maximize,\n                     TensorTupleParams<4> tensor_tuple_params);\n};\n\ntemplate<DeviceType device_type, typename T, typename G>\nstruct MultiTensorAdamUpdateWithCastKernelUtil {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, T scale,\n                     float l1, float l2, float beta1, float beta2, float epsilon,\n                     float weight_decay, bool amsgrad, bool do_bias_correction,\n                     float learning_rate_val, float bias_correction1_val,\n                     float bias_correction2_val, float lr_scale, const float* learning_rate,\n                     const T* scale_by_ptr, const int64_t* skip_if, const float* bias_correction1,\n                     const float* bias_correction2, TensorTupleParams<5> tensor_tuple_params);\n};\n\ntemplate<DeviceType device_type, typename T>\nstruct MultiTensorYoloV5WeightUpdateKernelUtil {\n  static void Update(ep::Stream* stream, const int64_t elem_cnt, const int64_t n_tensor, float d,\n                     TensorTupleParams<2> tensor_tuple_params);\n};\n\n}  // namespace oneflow\n\n#endif\n"
  },
  {
    "path": "oneflow/user/kernels/mutable_cast_once_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Cast> NewCastPrimitive(Context* ctx) {\n  const DataType in_data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n  const DataType out_data_type = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::CastFactory>(ctx->device_type(), in_data_type,\n                                                                 out_data_type);\n}\n\nclass MutableCastOnceOpKernelState final : public OpKernelState {\n public:\n  MutableCastOnceOpKernelState() : cast_once_flag_(false) {}\n\n  void SetDone() {\n    if (!cast_once_flag_) { cast_once_flag_ = true; }\n  }\n\n  bool IsDone() { return cast_once_flag_; }\n\n private:\n  bool cast_once_flag_ = false;\n};\n\nclass MutableCastOnce final : public OpKernel {\n public:\n  MutableCastOnce() = default;\n  ~MutableCastOnce() = default;\n\n  std::shared_ptr<OpKernelState> CreateOpKernelState(KernelInitContext* ctx) const override {\n    return std::make_shared<MutableCastOnceOpKernelState>();\n  }\n\n private:\n  void Compute(KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* cast_state = CHECK_NOTNULL(dynamic_cast<MutableCastOnceOpKernelState*>(state));\n    if (cast_state->IsDone()) { return; }\n    const Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t elem_cnt = input_tensor->shape_view().elem_cnt();\n    CHECK_EQ(output_tensor->shape_view().elem_cnt(), elem_cnt);\n    auto cast_primitive = NewCastPrimitive(ctx);\n    CHECK(cast_primitive);\n    cast_primitive->Launch(ctx->stream(), input_tensor->dptr(), output_tensor->mut_dptr(),\n                           elem_cnt);\n    cast_state->SetDone();\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nauto CastPrimitiveExists() {\n  return hob::make_custom(\"CastPrimitiveExists\", [](const user_op::KernelRegContext& ctx) -> bool {\n    return NewCastPrimitive(&ctx).operator bool();\n  });\n}\n\nREGISTER_USER_KERNEL(\"mutable_cast_once\")\n    .SetCreateFn<MutableCastOnce>()\n    .SetIsMatchedHob(CastPrimitiveExists() == true);\n\n}  // namespace\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/narrow_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/include/primitive/copy_nd.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::CopyNd> NewCopyNdPrimitive(Context* ctx) {\n  return ep::primitive::NewPrimitive<ep::primitive::CopyNdFactory>(ctx->device_type(), 3);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Memset> NewMemsetPrimitive(Context* ctx) {\n  return ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->device_type());\n}\n\nauto CopyNdPrimitiveExists() {\n  return hob::make_custom(\"CopyNdPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewCopyNdPrimitive(&ctx).operator bool();\n  });\n}\n\nauto MemsetPrimitiveExists() {\n  return hob::make_custom(\"MemsetPrimitiveExists\", [](const KernelRegContext& ctx) {\n    return NewMemsetPrimitive(&ctx).operator bool();\n  });\n}\n\n}  // namespace\n\nclass NarrowKernel final : public user_op::OpKernel {\n public:\n  NarrowKernel() = default;\n  ~NarrowKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    if (in->shape_view().elem_cnt() == 0) { return; }\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t& dim = ctx->Attr<int64_t>(\"dim\");\n    const int64_t& start = ctx->Attr<int64_t>(\"start\");\n    int64_t length = out->shape_view().At(dim);\n    const ShapeView in_shape = in->shape_view();\n    auto copy_nd_primitive = NewCopyNdPrimitive(ctx);\n    CHECK(copy_nd_primitive);\n\n    const int64_t outer_dim = in_shape.Count(0, dim);\n    const int64_t inner_dim = in_shape.Count(dim + 1);\n    const int64_t narrow_dim = in_shape.At(dim);\n\n    DimVector dst_shape = {outer_dim, length, inner_dim};\n    DimVector dst_pos_vec = {0, 0, 0};\n\n    DimVector src_shape = {outer_dim, narrow_dim, inner_dim};\n    DimVector src_pos_vec = {0, start, 0};\n    DimVector extent_vec = {outer_dim, length, inner_dim};\n    copy_nd_primitive->Launch(ctx->stream(), out->data_type(), 3, out->mut_dptr(), dst_shape.data(),\n                              dst_pos_vec.data(), in->dptr(), src_shape.data(), src_pos_vec.data(),\n                              extent_vec.data());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nclass NarrowGradKernel final : public user_op::OpKernel {\n public:\n  NarrowGradKernel() = default;\n  ~NarrowGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const int64_t& dim = ctx->Attr<int64_t>(\"dim\");\n    const int64_t& 
start = ctx->Attr<int64_t>(\"start\");\n    int64_t length = dy->shape_view().At(dim);\n\n    size_t dx_byte_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type());\n    void* dst = dx->mut_dptr();\n    std::unique_ptr<ep::primitive::Memset> memset_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->device_type());\n    CHECK(memset_primitive);\n    memset_primitive->Launch(ctx->stream(), dst, 0, dx_byte_size);\n\n    auto copy_nd_primitive = NewCopyNdPrimitive(ctx);\n    CHECK(copy_nd_primitive);\n    const ShapeView dx_shape = dx->shape_view();\n\n    const int64_t outer_dim = dx_shape.Count(0, dim);\n    const int64_t inner_dim = dx_shape.Count(dim + 1);\n    const int64_t narrow_dim = dx_shape.At(dim);\n\n    DimVector dst_shape = {outer_dim, narrow_dim, inner_dim};\n    DimVector dst_pos_vec = {0, start, 0};\n\n    DimVector src_shape = {outer_dim, length, inner_dim};\n    DimVector src_pos_vec = {0, 0, 0};\n    DimVector extent_vec = {outer_dim, length, inner_dim};\n\n    copy_nd_primitive->Launch(ctx->stream(), dx->data_type(), 3, dst, dst_shape.data(),\n                              dst_pos_vec.data(), dy->dptr(), src_shape.data(), src_pos_vec.data(),\n                              extent_vec.data());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"narrow\").SetCreateFn<NarrowKernel>().SetIsMatchedHob(CopyNdPrimitiveExists()\n                                                                           == true);\nREGISTER_USER_KERNEL(\"narrow_grad\")\n    .SetCreateFn<NarrowGradKernel>()\n    .SetIsMatchedHob(MemsetPrimitiveExists() && CopyNdPrimitiveExists());\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/job/eager_nccl_comm_manager.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/user/ops/nccl_logical_util.h\"\n#include \"oneflow/user/kernels/collective_communication/include/all_reduce.h\"\n#include \"oneflow/user/kernels/collective_communication/include/all_gather.h\"\n#include \"oneflow/user/kernels/collective_communication/include/all_to_all.h\"\n\n#if (defined(WITH_CUDA) && (NCCL_VERSION_CODE > 2700)) || defined(WITH_NPU) || defined(WITH_MLU)\n\nnamespace oneflow {\n\nnamespace {\n\nauto AllReduceCollectiveCommunicationExists() {\n  return hob::make_custom(\"AllReduceCollectiveCommunicationExists\",\n                          [=](const user_op::KernelRegContext& ctx) {\n                            DeviceType device_type = ctx.device_type();\n                            return ccl::IsCommunicationContextRegistered(device_type)\n                                   && ccl::IsAllReduceRegistered(device_type);\n                          });\n}\n\nauto AllGatherCollectiveCommunicationExists() {\n  return hob::make_custom(\"AllGatherCollectiveCommunicationExists\",\n                          [=](const user_op::KernelRegContext& ctx) {\n                            DeviceType device_type = ctx.device_type();\n                            return ccl::IsCommunicationContextRegistered(device_type)\n                                   && ccl::IsAllGatherRegistered(device_type);\n                          });\n}\n\nauto AllToAllCollectiveCommunicationExists() {\n  return hob::make_custom(\"AllToAllCollectiveCommunicationExists\",\n                          [=](const user_op::KernelRegContext& ctx) {\n                            DeviceType device_type = ctx.device_type();\n                            return ccl::IsCommunicationContextRegistered(device_type)\n                                   && ccl::IsAllToAllRegistered(device_type);\n                          });\n}\n\nclass CclLogical2DSameDim0KernelCommState : public user_op::OpKernelState {\n public:\n  explicit CclLogical2DSameDim0KernelCommState(user_op::KernelInitContext* ctx)\n      : is_init_(false),\n        stream_name_(EagerCclCommMgr::kDefaultCclStreamName),\n        parallel_desc_(ctx->parallel_desc()),\n        this_parallel_id_(ctx->parallel_ctx().parallel_id()) {\n    if (ctx->op_conf().has_stream_name_hint()) { stream_name_ = ctx->op_conf().stream_name_hint(); }\n  }\n  ~CclLogical2DSameDim0KernelCommState() override = default;\n\n  const ccl::CclComm& ccl_comm() {\n    if (!is_init_) { Init(); }\n    return ccl_comm_;\n  }\n\n  int64_t num_ranks() {\n    if (!is_init_) { Init(); }\n    return num_ranks_;\n  }\n\n  const std::string& stream_name() const { return stream_name_; 
}\n\n private:\n  void Init() {\n    CHECK(!is_init_);\n    const Shape& hierarchy = *parallel_desc_.hierarchy();\n    CHECK_EQ(hierarchy.NumAxes(), 2);\n    const int64_t group_size = hierarchy.At(1);\n    EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    ccl_comm_ = comm_mgr->GetCclCommForParallelDescNdHierarchy(parallel_desc_, stream_name_,\n                                                               this_parallel_id_, \"SameDim0\");\n    num_ranks_ = group_size;\n    is_init_ = true;\n  }\n\n  bool is_init_;\n  std::string stream_name_;\n  ParallelDesc parallel_desc_;\n  int64_t this_parallel_id_;\n  int64_t num_ranks_{};\n  ccl::CclComm ccl_comm_{};\n};\n\nclass CclLogical2DSameDim0AllGatherNoncontinuousKernelState\n    : public CclLogical2DSameDim0KernelCommState {\n public:\n  explicit CclLogical2DSameDim0AllGatherNoncontinuousKernelState(user_op::KernelInitContext* ctx)\n      : CclLogical2DSameDim0KernelCommState(ctx), src_split_axis_(-1) {}\n  ~CclLogical2DSameDim0AllGatherNoncontinuousKernelState() override = default;\n\n  int64_t src_split_axis() const { return src_split_axis_; }\n  void set_src_split_axis(int64_t split_axis) { src_split_axis_ = split_axis; }\n\n private:\n  int64_t src_split_axis_;\n};\n\nclass CclLogical2DSameDim0All2AllKernelState : public CclLogical2DSameDim0KernelCommState {\n public:\n  explicit CclLogical2DSameDim0All2AllKernelState(user_op::KernelInitContext* ctx)\n      : CclLogical2DSameDim0KernelCommState(ctx), src_split_axis_(-1), dst_split_axis_(-1) {}\n  ~CclLogical2DSameDim0All2AllKernelState() override = default;\n\n  int64_t src_split_axis() const { return src_split_axis_; }\n  void set_src_split_axis(int64_t split_axis) { src_split_axis_ = split_axis; }\n  int64_t dst_split_axis() const { return dst_split_axis_; }\n  void set_dst_split_axis(int64_t split_axis) { dst_split_axis_ = split_axis; }\n\n private:\n  int64_t src_split_axis_;\n  int64_t dst_split_axis_;\n};\n\nclass CclLogical2DSameDim0AllReduce final : public user_op::OpKernel {\n public:\n  CclLogical2DSameDim0AllReduce() = default;\n  ~CclLogical2DSameDim0AllReduce() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<CclLogical2DSameDim0KernelCommState>(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* comm_state = dynamic_cast<CclLogical2DSameDim0KernelCommState*>(state);\n    CHECK(comm_state != nullptr);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(in->shape_view(), out->shape_view());\n    CHECK_EQ(in->data_type(), out->data_type());\n    VLOG(3) << \"[NcclLogical2D][SameDim0AllReduce] \" << comm_state->stream_name() << \" \"\n            << ctx->op_name() << std::endl;\n    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;\n    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }\n    ccl::CclComm ccl_comm = comm_state->ccl_comm();\n    std::unique_ptr<ccl::AllReduce> ccl_all_reduce =\n        ccl::NewCollectiveCommunication<ccl::AllReduce>(ctx->stream()->device_type(),\n                                                        in->data_type(), ccl_reduce_type);\n    ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), 
in->shape_view().elem_cnt(),\n                           ccl_comm);\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  bool IsKernelLaunchSynchronized() const override {\n    const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    return comm_mgr->IsAsyncLaunchCclLogicalKernel();\n  }\n};\n\nclass CclLogical2DSameDim0AllGather final : public user_op::OpKernel {\n public:\n  CclLogical2DSameDim0AllGather() = default;\n  ~CclLogical2DSameDim0AllGather() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<CclLogical2DSameDim0KernelCommState>(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* comm_state = dynamic_cast<CclLogical2DSameDim0KernelCommState*>(state);\n    CHECK(comm_state != nullptr);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(in->data_type(), out->data_type());\n    const int64_t num_ranks = comm_state->num_ranks();\n    CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());\n    VLOG(3) << \"[NcclLogical2D][SameDim0AllGather] \" << comm_state->stream_name() << \" \"\n            << ctx->op_name() << std::endl;\n\n    std::unique_ptr<ccl::AllGather> ccl_all_gather =\n        ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),\n                                                        in->data_type());\n    ccl::CclComm ccl_comm = comm_state->ccl_comm();\n    ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),\n                           ccl_comm);\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  bool IsKernelLaunchSynchronized() const override {\n    const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    return comm_mgr->IsAsyncLaunchCclLogicalKernel();\n  }\n};\n\ntemplate<typename T>\nclass CclLogical2DSameDim0AllGatherNoncontinuous final : public user_op::OpKernel {\n public:\n  CclLogical2DSameDim0AllGatherNoncontinuous() = default;\n  ~CclLogical2DSameDim0AllGatherNoncontinuous() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    auto state = std::make_shared<CclLogical2DSameDim0AllGatherNoncontinuousKernelState>(ctx);\n    NdSbp src_nd_sbp;\n    CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", &src_nd_sbp));\n    CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2);\n    CHECK(src_nd_sbp.sbp_parallel(1).has_split_parallel());\n    state->set_src_split_axis(src_nd_sbp.sbp_parallel(1).split_parallel().axis());\n    return state;\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state =\n        dynamic_cast<CclLogical2DSameDim0AllGatherNoncontinuousKernelState*>(state);\n    CHECK_NOTNULL(kernel_state);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const int64_t dtype_size = 
GetSizeOfDataType(in->data_type());\n    int64_t data_size = GetCudaAlignedSize(out->shape_view().elem_cnt() * dtype_size);\n    void* unpack_from_ptr = tmp_buffer->mut_dptr();\n    CHECK_EQ(tmp_buffer->shape_view().elem_cnt(), data_size);\n\n    CHECK_EQ(in->data_type(), out->data_type());\n    const int64_t num_ranks = kernel_state->num_ranks();\n    const int64_t in_split_axis = kernel_state->src_split_axis();\n\n    DimVector logical_shape_dim_vec;\n    in->shape_view().ToDimVector(&logical_shape_dim_vec);\n    logical_shape_dim_vec[in_split_axis] = logical_shape_dim_vec.at(in_split_axis) * num_ranks;\n\n    VLOG(3) << \"[NcclLogical2D][SameDim0AllGatherNoncontinuous] \" << kernel_state->stream_name()\n            << \" \" << ctx->op_name() << std::endl;\n\n    // NOTE(chengcheng): Do AllGather\n    CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());\n\n    std::unique_ptr<ccl::AllGather> ccl_all_gather =\n        ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),\n                                                        in->data_type());\n    ccl::CclComm ccl_comm = kernel_state->ccl_comm();\n    ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(),\n                           ccl_comm);\n\n    CHECK_GT(in_split_axis, 0);\n    // NOTE(chengcheng): Do unpack.\n    DimVector unpack_from_dim_vec = logical_shape_dim_vec;\n    CHECK_EQ(unpack_from_dim_vec.at(in_split_axis) % num_ranks, 0);\n    unpack_from_dim_vec[in_split_axis] = unpack_from_dim_vec.at(in_split_axis) / num_ranks;\n    unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks);\n    std::vector<int32_t> perm;\n    FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); }\n    perm.insert(perm.begin() + in_split_axis, 0);\n\n    auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n        ctx->stream()->device_type(), unpack_from_dim_vec.size());\n    CHECK(transpose);\n    transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(),\n                      unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  bool IsKernelLaunchSynchronized() const override {\n    const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    return comm_mgr->IsAsyncLaunchCclLogicalKernel();\n  }\n};\n\nsize_t Infer2DSameDim0AllGatherNoncontinuousKernelTmpBufferSize(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& out_tensor = ctx->OutputTensorDesc(\"out\", 0);\n  return GetCudaAlignedSize(out_tensor.shape().elem_cnt()\n                            * GetSizeOfDataType(out_tensor.data_type()));\n}\n\ntemplate<typename T>\nclass CclLogical2DSameDim0All2All final : public user_op::OpKernel {\n public:\n  CclLogical2DSameDim0All2All() = default;\n  ~CclLogical2DSameDim0All2All() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    auto state = std::make_shared<CclLogical2DSameDim0All2AllKernelState>(ctx);\n    NdSbp src_nd_sbp;\n    NdSbp dst_nd_sbp;\n    CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", &src_nd_sbp));\n    CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", &dst_nd_sbp));\n    CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2);\n    CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 2);\n    
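// Dataflow sketch for Compute() below, redistributing S(src_axis) -> S(dst_axis) within each\n    // dim-1 group: in --(pack transpose, if dst_axis != 0)--> pack_to_ptr --(all2all)-->\n    // unpack_from_ptr --(unpack transpose, if src_axis != 0)--> out. Both reduced nd-sbps must\n    // therefore be split on the second hierarchy dim, which the checks below enforce.\n    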
CHECK(src_nd_sbp.sbp_parallel(1).has_split_parallel());\n    CHECK(dst_nd_sbp.sbp_parallel(1).has_split_parallel());\n    state->set_src_split_axis(src_nd_sbp.sbp_parallel(1).split_parallel().axis());\n    state->set_dst_split_axis(dst_nd_sbp.sbp_parallel(1).split_parallel().axis());\n    return state;\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<CclLogical2DSameDim0All2AllKernelState*>(state);\n    CHECK_NOTNULL(kernel_state);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    int64_t tmp_size = 0;\n    const int64_t dtype_size = GetSizeOfDataType(in->data_type());\n    int64_t data_size = GetCudaAlignedSize(in->shape_view().elem_cnt() * dtype_size);\n    // NOTE(chengcheng): in (transpose)-> pack_to_ptr (all2all)-> unpack_from_ptr (transpose)-> out\n    const char* pack_to_ptr = in->dptr<char>();\n    char* unpack_from_ptr = out->mut_dptr<char>();\n    if (tmp_buffer) { tmp_size = tmp_buffer->shape_view().elem_cnt(); }\n    CHECK(tmp_size == 0 || tmp_size == data_size || tmp_size == data_size * 2);\n\n    CHECK_EQ(in->data_type(), out->data_type());\n    const int64_t num_ranks = kernel_state->num_ranks();\n    CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt());\n    const int64_t elem_cnt = in->shape_view().elem_cnt();\n    const int64_t in_split_axis = kernel_state->src_split_axis();\n    const int64_t out_split_axis = kernel_state->dst_split_axis();\n\n    DimVector logical_shape_dim_vec;\n    in->shape_view().ToDimVector(&logical_shape_dim_vec);\n    logical_shape_dim_vec[in_split_axis] = logical_shape_dim_vec.at(in_split_axis) * num_ranks;\n\n    VLOG(3) << \"[NcclLogical2D][SameDim0All2All] \" << kernel_state->stream_name() << \" \"\n            << ctx->op_name() << std::endl;\n\n    if (out_split_axis != 0) {\n      // NOTE(chengcheng): Do pack. Need transpose in -> pack_to\n      // pack use temp buffer offset: [0, data_size]\n      pack_to_ptr = CHECK_NOTNULL(tmp_buffer)->dptr<char>();\n      DimVector transpose_in_dim_vec = logical_shape_dim_vec;\n      CHECK_EQ(transpose_in_dim_vec.at(in_split_axis) % num_ranks, 0);\n      transpose_in_dim_vec[in_split_axis] = transpose_in_dim_vec.at(in_split_axis) / num_ranks;\n      CHECK_EQ(transpose_in_dim_vec.at(out_split_axis) % num_ranks, 0);\n      transpose_in_dim_vec[out_split_axis] = transpose_in_dim_vec.at(out_split_axis) / num_ranks;\n      transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + out_split_axis, num_ranks);\n      std::vector<int32_t> perm;\n      perm.emplace_back(out_split_axis);\n      FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) {\n        if (i != out_split_axis) { perm.emplace_back(i); }\n      }\n      auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n          ctx->stream()->device_type(), transpose_in_dim_vec.size());\n      CHECK(transpose);\n      transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(),\n                        transpose_in_dim_vec.data(), in->dptr(), perm.data(),\n                        tmp_buffer->mut_dptr());\n    }\n\n    if (in_split_axis != 0) {\n      // NOTE(chengcheng): Do unpack. 
Need transpose unpack_from -> out\n      // unpack use temp buffer offset: [tmp_size - data_size, tmp_size]\n      unpack_from_ptr = CHECK_NOTNULL(tmp_buffer)->mut_dptr<char>() + (tmp_size - data_size);\n    }\n\n    {\n      // NOTE(chengcheng): Do S2S\n      const int64_t elem_per_chunk = elem_cnt / num_ranks;\n      std::unique_ptr<ccl::AllToAll> all_to_all = ccl::NewCollectiveCommunication<ccl::AllToAll>(\n          ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks);\n      ccl::CclComm ccl_comm = kernel_state->ccl_comm();\n      all_to_all->Launch(ctx->stream(), const_cast<char*>(pack_to_ptr), elem_per_chunk,\n                         unpack_from_ptr, elem_per_chunk, ccl_comm);\n    }\n\n    if (in_split_axis != 0) {\n      // Do unpack.\n      CHECK(unpack_from_ptr != out->mut_dptr<char>());\n      DimVector unpack_from_dim_vec = logical_shape_dim_vec;\n      CHECK_EQ(unpack_from_dim_vec.at(in_split_axis) % num_ranks, 0);\n      unpack_from_dim_vec[in_split_axis] = unpack_from_dim_vec.at(in_split_axis) / num_ranks;\n      CHECK_EQ(unpack_from_dim_vec.at(out_split_axis) % num_ranks, 0);\n      unpack_from_dim_vec[out_split_axis] = unpack_from_dim_vec.at(out_split_axis) / num_ranks;\n      unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks);\n      std::vector<int32_t> perm;\n      FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); }\n      perm.insert(perm.begin() + in_split_axis, 0);\n      auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n          ctx->stream()->device_type(), unpack_from_dim_vec.size());\n      CHECK(transpose);\n      transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(),\n                        unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr());\n    }\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  bool IsKernelLaunchSynchronized() const override {\n    const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    return comm_mgr->IsAsyncLaunchCclLogicalKernel();\n  }\n};\n\nsize_t Infer2DSameDim0All2AllKernelTmpBufferSize(user_op::InferContext* ctx) {\n  size_t ret = 0;\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  size_t tensor_byte_size =\n      GetCudaAlignedSize(in_tensor.shape().elem_cnt() * GetSizeOfDataType(in_tensor.data_type()));\n  NdSbp src_nd_sbp;\n  NdSbp dst_nd_sbp;\n  CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", &src_nd_sbp));\n  CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", &dst_nd_sbp));\n  CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2);\n  CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 2);\n  CHECK(src_nd_sbp.sbp_parallel(1).has_split_parallel());\n  CHECK(dst_nd_sbp.sbp_parallel(1).has_split_parallel());\n  if (src_nd_sbp.sbp_parallel(1).split_parallel().axis() != 0) { ret += tensor_byte_size; }\n  if (dst_nd_sbp.sbp_parallel(1).split_parallel().axis() != 0) { ret += tensor_byte_size; }\n  return ret;\n}\n\nclass CclLogical2DSameDim1KernelCommState final : public user_op::OpKernelState {\n public:\n  explicit CclLogical2DSameDim1KernelCommState(user_op::KernelInitContext* ctx)\n      : is_init_(false),\n        stream_name_(EagerCclCommMgr::kDefaultCclStreamName),\n        parallel_desc_(ctx->parallel_desc()),\n        this_parallel_id_(ctx->parallel_ctx().parallel_id()) {\n    if (ctx->op_conf().has_stream_name_hint()) { stream_name_ = ctx->op_conf().stream_name_hint(); 
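\n      // NOTE: same convention as CclLogical2DSameDim0KernelCommState above: an op_conf\n      // stream_name_hint overrides the default ccl stream name used to key the communicator.\n    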
}\n  }\n  ~CclLogical2DSameDim1KernelCommState() = default;\n\n  const ccl::CclComm& ccl_comm() {\n    if (!is_init_) {\n      const Shape& hierarchy = *parallel_desc_.hierarchy();\n      CHECK_EQ(hierarchy.NumAxes(), 2);\n      EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n      ccl_comm_ = comm_mgr->GetCclCommForParallelDescNdHierarchy(parallel_desc_, stream_name_,\n                                                                 this_parallel_id_, \"SameDim1\");\n      is_init_ = true;\n    }\n    return ccl_comm_;\n  }\n\n  const std::string& stream_name() const { return stream_name_; }\n\n private:\n  bool is_init_;\n  std::string stream_name_;\n  ParallelDesc parallel_desc_;\n  int64_t this_parallel_id_;\n  ccl::CclComm ccl_comm_{};\n};\n\nclass CclLogical2DSameDim1AllReduce final : public user_op::OpKernel {\n public:\n  CclLogical2DSameDim1AllReduce() = default;\n  ~CclLogical2DSameDim1AllReduce() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<CclLogical2DSameDim1KernelCommState>(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* comm_state = dynamic_cast<CclLogical2DSameDim1KernelCommState*>(state);\n    CHECK(comm_state != nullptr);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(in->shape_view(), out->shape_view());\n    CHECK_EQ(in->data_type(), out->data_type());\n    VLOG(3) << \"[NcclLogical2D][SameDim1AllReduce] \" << comm_state->stream_name() << \" \"\n            << ctx->op_name() << std::endl;\n    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;\n    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }\n\n    ccl::CclComm ccl_comm = comm_state->ccl_comm();\n    std::unique_ptr<ccl::AllReduce> ccl_all_reduce =\n        ccl::NewCollectiveCommunication<ccl::AllReduce>(ctx->stream()->device_type(),\n                                                        in->data_type(), ccl_reduce_type);\n    ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),\n                           ccl_comm);\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  bool IsKernelLaunchSynchronized() const override {\n    const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    return comm_mgr->IsAsyncLaunchCclLogicalKernel();\n  }\n};\n\n}  // namespace\n\nREGISTER_USER_KERNEL(\"_nccl_logical_2D_same_dim0_all_reduce\")\n    .SetCreateFn<CclLogical2DSameDim0AllReduce>()\n    .SetIsMatchedHob(AllReduceCollectiveCommunicationExists());\n\nREGISTER_USER_KERNEL(\"_nccl_logical_2D_same_dim0_all_gather\")\n    .SetCreateFn<CclLogical2DSameDim0AllGather>()\n    .SetIsMatchedHob(AllGatherCollectiveCommunicationExists());\n\n#define REGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(dtype)                      \\\n  REGISTER_USER_KERNEL(\"_nccl_logical_2D_same_dim0_all_gather_noncontinuous\")            \\\n      .SetCreateFn<CclLogical2DSameDim0AllGatherNoncontinuous<dtype>>()                  \\\n      .SetIsMatchedHob(AllGatherCollectiveCommunicationExists()                          \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value)   
\\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn(Infer2DSameDim0AllGatherNoncontinuousKernelTmpBufferSize);\n\nREGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(bool)\nREGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(int8_t)\nREGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(int32_t)\nREGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(int64_t)\nREGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(float)\nREGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(double)\nREGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(float16)\n#if defined(__CUDA_BF16_TYPES_EXIST__)\nREGISTER_2D_SAME_DIM0_ALLGATHER_NONCONTINUOUS_KERNEL(nv_bfloat16)\n#endif\n\n#define REGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(dtype)                                      \\\n  REGISTER_USER_KERNEL(\"_nccl_logical_2D_same_dim0_all2all\")                             \\\n      .SetCreateFn<CclLogical2DSameDim0All2All<dtype>>()                                 \\\n      .SetIsMatchedHob(AllToAllCollectiveCommunicationExists()                           \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value)   \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn(Infer2DSameDim0All2AllKernelTmpBufferSize);\n\nREGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(bool)\nREGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(int8_t)\nREGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(int32_t)\nREGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(int64_t)\nREGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(float)\nREGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(double)\nREGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(float16)\n#if defined(__CUDA_BF16_TYPES_EXIST__)\nREGISTER_2D_SAME_DIM0_ALL2ALL_KERNEL(nv_bfloat16)\n#endif\n\nREGISTER_USER_KERNEL(\"_nccl_logical_2D_same_dim1_all_reduce\")\n    .SetCreateFn<CclLogical2DSameDim1AllReduce>()\n    .SetIsMatchedHob(AllReduceCollectiveCommunicationExists());\n\nREGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT(\"_nccl_logical_2D_same_dim0_all_reduce\");\nREGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT(\"_nccl_logical_2D_same_dim0_all_gather\");\nREGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT(\"_nccl_logical_2D_same_dim0_all_gather_noncontinuous\");\nREGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT(\"_nccl_logical_2D_same_dim0_all2all\");\nREGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT(\"_nccl_logical_2D_same_dim1_all_reduce\");\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA || WITH_NPU || WITH_MLU\n"
  },
  {
    "path": "oneflow/user/kernels/nccl_logical_fusion_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/job/eager_nccl_comm_manager.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/user/ops/nccl_logical_util.h\"\n#include \"collective_communication/include/collective_communication.h\"\n#include \"collective_communication/include/send.h\"\n#include \"collective_communication/include/recv.h\"\n#include \"collective_communication/include/all_gather.h\"\n#include \"collective_communication/include/all_reduce.h\"\n#include \"collective_communication/include/all_to_all.h\"\n#include \"collective_communication/include/reduce_scatter.h\"\n\n#if (defined(WITH_CUDA) && (NCCL_VERSION_CODE > 2700)) || defined(WITH_NPU) || defined(WITH_MLU)\n\nnamespace oneflow {\n\nnamespace {\n\nsize_t GetTmpBufferSizeByNcclType(const std::string& nccl_type, size_t in_tensor_byte_size,\n                                  size_t out_tensor_byte_size, const NdSbp& src_nd_sbp,\n                                  const NdSbp& dst_nd_sbp) {\n  if (nccl_type == \"_nccl_logical_all_gather_noncontinuous\") {\n    return out_tensor_byte_size;\n  } else if (nccl_type == \"_nccl_logical_reduce_scatter_noncontinuous\") {\n    return in_tensor_byte_size;\n  } else if (nccl_type == \"_nccl_logical_s2s\") {\n    size_t ret = 0;\n    CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 1);\n    CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 1);\n    CHECK(src_nd_sbp.sbp_parallel(0).has_split_parallel());\n    CHECK(dst_nd_sbp.sbp_parallel(0).has_split_parallel());\n    if (src_nd_sbp.sbp_parallel(0).split_parallel().axis() != 0) { ret += in_tensor_byte_size; }\n    if (dst_nd_sbp.sbp_parallel(0).split_parallel().axis() != 0) { ret += in_tensor_byte_size; }\n    return ret;\n  } else if (nccl_type == \"_nccl_logical_2D_same_dim0_all_gather_noncontinuous\") {\n    return out_tensor_byte_size;\n  } else if (nccl_type == \"_nccl_logical_2D_same_dim0_all2all\") {\n    size_t ret = 0;\n    CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2);\n    CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 2);\n    CHECK(src_nd_sbp.sbp_parallel(1).has_split_parallel());\n    CHECK(dst_nd_sbp.sbp_parallel(1).has_split_parallel());\n    if (src_nd_sbp.sbp_parallel(1).split_parallel().axis() != 0) { ret += in_tensor_byte_size; }\n    if (dst_nd_sbp.sbp_parallel(1).split_parallel().axis() != 0) { ret += in_tensor_byte_size; }\n    return ret;\n  }\n  return 0;\n}\n\nsize_t GetTensorByteSize(const user_op::TensorDesc& tensor_desc) {\n  return GetCudaAlignedSize(tensor_desc.shape().elem_cnt()\n                            * GetSizeOfDataType(tensor_desc.data_type()));\n}\n\nsize_t GetTensorByteSize(const user_op::Tensor& tensor) {\n  return GetCudaAlignedSize(tensor.shape_view().elem_cnt() * 
GetSizeOfDataType(tensor.data_type()));\n}\n\nclass CclLogicalFusionKernelState : public user_op::OpKernelState {\n public:\n  explicit CclLogicalFusionKernelState(user_op::KernelInitContext* ctx)\n      : is_init_(false),\n        stream_name_(EagerCclCommMgr::kDefaultCclStreamName),\n        parallel_desc_(ctx->parallel_desc()),\n        this_parallel_id_(ctx->parallel_ctx().parallel_id()),\n        num_ranks_(-1),\n        comm_key_(\"InvalidKey\"),\n        nccl_num_(-1) {\n    if (ctx->op_conf().has_stream_name_hint()) { stream_name_ = ctx->op_conf().stream_name_hint(); }\n    InitSplitAxisAndTmpBufferOffset(ctx);\n  }\n  ~CclLogicalFusionKernelState() override = default;\n\n  ccl::CclComm ccl_comm() {\n    if (!is_init_) { InitComm(); }\n    return ccl_comm_;\n  }\n\n  int64_t num_ranks() {\n    if (!is_init_) { InitComm(); }\n    return num_ranks_;\n  }\n\n  const std::string& stream_name() const { return stream_name_; }\n  int64_t src_split_axis(int32_t i) const {\n    CHECK_GE(i, 0);\n    CHECK_LT(i, src_split_axis_list_.size());\n    return src_split_axis_list_.at(i);\n  }\n  int64_t dst_split_axis(int32_t i) const {\n    CHECK_GE(i, 0);\n    CHECK_LT(i, dst_split_axis_list_.size());\n    return dst_split_axis_list_.at(i);\n  }\n\n  int32_t nccl_num() const { return nccl_num_; }\n  size_t tmp_buffer_offset(int32_t i) {\n    CHECK_GE(i, 0);\n    CHECK_LT(i, tmp_buffer_offset_.size());\n    return tmp_buffer_offset_.at(i);\n  }\n\n  size_t tmp_buffer_size(int32_t i) {\n    CHECK_GE(i, 0);\n    CHECK_LT(i, tmp_buffer_size_.size());\n    return tmp_buffer_size_.at(i);\n  }\n\n private:\n  void InitComm() {\n    CHECK(!is_init_);\n    const Shape& hierarchy = *parallel_desc_.hierarchy();\n\n    if (hierarchy.NumAxes() == 1) {\n      num_ranks_ = parallel_desc_.parallel_num();\n    } else if (hierarchy.NumAxes() == 2) {\n      CHECK(comm_key_ == \"SameDim0\" || comm_key_ == \"SameDim1\");\n      if (comm_key_ == \"SameDim0\") {\n        const int64_t group_size = hierarchy.At(1);\n        num_ranks_ = group_size;\n      } else if (comm_key_ == \"SameDim1\") {\n        const int64_t group_size = hierarchy.At(0);\n        num_ranks_ = group_size;\n      } else {\n        UNIMPLEMENTED();\n      }\n    } else {\n      UNIMPLEMENTED();\n    }\n\n    EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    ccl_comm_ = comm_mgr->GetCclCommForParallelDescNdHierarchy(parallel_desc_, stream_name_,\n                                                               this_parallel_id_, comm_key_);\n    is_init_ = true;\n  }\n\n  void UpdateOrCheckEqCommKey(const std::string& val) {\n    if (comm_key_ == \"InvalidKey\") {\n      comm_key_ = val;\n    } else {\n      CHECK_EQ(comm_key_, val);\n    }\n  }\n\n  void UpdateSplitAxisByNcclType(const std::string& nccl_type, const int32_t i,\n                                 const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp) {\n    if (nccl_type == \"_nccl_logical_all_gather_noncontinuous\") {\n      CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 1);\n      CHECK(src_nd_sbp.sbp_parallel(0).has_split_parallel());\n      src_split_axis_list_.at(i) = src_nd_sbp.sbp_parallel(0).split_parallel().axis();\n    } else if (nccl_type == \"_nccl_logical_reduce_scatter_noncontinuous\") {\n      CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 1);\n      CHECK(dst_nd_sbp.sbp_parallel(0).has_split_parallel());\n      dst_split_axis_list_.at(i) = dst_nd_sbp.sbp_parallel(0).split_parallel().axis();\n    } else if (nccl_type == \"_nccl_logical_s2s\") {\n      
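// e.g. a 1D S(0) -> S(1) (or the reverse) boxing: record both split axes for the later\n      // pack/unpack transposes; the CHECK_NE below ensures they differ, otherwise no\n      // all-to-all would be needed.\n      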
CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 1);\n      CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 1);\n      CHECK(src_nd_sbp.sbp_parallel(0).has_split_parallel());\n      CHECK(dst_nd_sbp.sbp_parallel(0).has_split_parallel());\n      src_split_axis_list_.at(i) = src_nd_sbp.sbp_parallel(0).split_parallel().axis();\n      dst_split_axis_list_.at(i) = dst_nd_sbp.sbp_parallel(0).split_parallel().axis();\n      CHECK_NE(src_split_axis_list_.at(i), dst_split_axis_list_.at(i));\n    } else if (nccl_type == \"_nccl_logical_2D_same_dim0_all_gather_noncontinuous\") {\n      CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2);\n      CHECK(src_nd_sbp.sbp_parallel(1).has_split_parallel());\n      src_split_axis_list_.at(i) = src_nd_sbp.sbp_parallel(1).split_parallel().axis();\n    } else if (nccl_type == \"_nccl_logical_2D_same_dim0_all2all\") {\n      CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 2);\n      CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 2);\n      CHECK(src_nd_sbp.sbp_parallel(1).has_split_parallel());\n      CHECK(dst_nd_sbp.sbp_parallel(1).has_split_parallel());\n      src_split_axis_list_.at(i) = src_nd_sbp.sbp_parallel(1).split_parallel().axis();\n      dst_split_axis_list_.at(i) = dst_nd_sbp.sbp_parallel(1).split_parallel().axis();\n      CHECK_NE(src_split_axis_list_.at(i), dst_split_axis_list_.at(i));\n    }\n  }\n\n  void InitSplitAxisAndTmpBufferOffset(user_op::KernelInitContext* ctx) {\n    nccl_num_ = ctx->input_size(\"in\");\n    const std::vector<std::string>& src_nd_sbp_str_list =\n        ctx->Attr<std::vector<std::string>>(\"src_nd_sbp_str_list\");\n    const std::vector<std::string>& dst_nd_sbp_str_list =\n        ctx->Attr<std::vector<std::string>>(\"dst_nd_sbp_str_list\");\n    const std::vector<std::string>& nccl_type_list =\n        ctx->Attr<std::vector<std::string>>(\"nccl_type_list\");\n\n    CHECK_EQ(nccl_num_, ctx->output_size(\"out\"));\n    src_split_axis_list_.resize(nccl_num_, -1);\n    dst_split_axis_list_.resize(nccl_num_, -1);\n\n    CHECK_EQ(src_nd_sbp_str_list.size(), nccl_num_);\n    CHECK_EQ(dst_nd_sbp_str_list.size(), nccl_num_);\n    CHECK_EQ(nccl_type_list.size(), nccl_num_);\n    CHECK_EQ(src_split_axis_list_.size(), nccl_num_);\n    CHECK_EQ(dst_split_axis_list_.size(), nccl_num_);\n\n    size_t total_buffer_size = 0;\n\n    for (int32_t i = 0; i < nccl_num_; ++i) {\n      NdSbp src_nd_sbp;\n      NdSbp dst_nd_sbp;\n      CHECK(ParseNdSbpFromLongString(src_nd_sbp_str_list.at(i), &src_nd_sbp));\n      CHECK(ParseNdSbpFromLongString(dst_nd_sbp_str_list.at(i), &dst_nd_sbp));\n      const std::string& nccl_type = nccl_type_list.at(i);\n      UpdateOrCheckEqCommKey(GetCommKeyFromNcclType(nccl_type));\n      UpdateSplitAxisByNcclType(nccl_type, i, src_nd_sbp, dst_nd_sbp);\n      size_t in_tensor_byte_size = GetTensorByteSize(*ctx->TensorDesc4ArgNameAndIndex(\"in\", i));\n      size_t out_tensor_byte_size = GetTensorByteSize(*ctx->TensorDesc4ArgNameAndIndex(\"out\", i));\n\n      tmp_buffer_offset_.push_back(total_buffer_size);\n      size_t tmp_buffer_size = GetTmpBufferSizeByNcclType(\n          nccl_type, in_tensor_byte_size, out_tensor_byte_size, src_nd_sbp, dst_nd_sbp);\n      tmp_buffer_size_.push_back(tmp_buffer_size);\n      total_buffer_size += tmp_buffer_size;\n    }\n    // NOTE(chengcheng): last element of vector is total_buffer_size\n    tmp_buffer_offset_.push_back(total_buffer_size);\n    CHECK_EQ(tmp_buffer_offset_.size(), nccl_num_ + 1);\n    CHECK_EQ(tmp_buffer_size_.size(), nccl_num_);\n    const user_op::TensorDesc* tmp_buffer_tensor_desc =\n        
ctx->TensorDesc4ArgNameAndIndex(\"tmp_buffer\", 0);\n    if (tmp_buffer_tensor_desc == nullptr) {\n      CHECK_EQ(total_buffer_size, 0);\n    } else {\n      CHECK_EQ(total_buffer_size, GetTensorByteSize(*tmp_buffer_tensor_desc));\n    }\n  }\n\n  bool is_init_;\n  std::string stream_name_;\n  ParallelDesc parallel_desc_;\n  int64_t this_parallel_id_;\n  int64_t num_ranks_;\n  std::string comm_key_;\n  int32_t nccl_num_;\n  std::vector<int64_t> src_split_axis_list_;\n  std::vector<int64_t> dst_split_axis_list_;\n  std::vector<size_t> tmp_buffer_offset_;\n  std::vector<size_t> tmp_buffer_size_;\n  ccl::CclComm ccl_comm_{};\n};\n\nclass CclLogicalFusionKernel final : public user_op::OpKernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CclLogicalFusionKernel);\n  CclLogicalFusionKernel() = default;\n  ~CclLogicalFusionKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<CclLogicalFusionKernelState>(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  bool IsKernelLaunchSynchronized() const override {\n    EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    return comm_mgr->IsAsyncLaunchCclLogicalKernel();\n    return true;\n  }\n};\n\nconst void* UpdatePackToPtrByNcclType(const void* pack_to_ptr, const std::string& nccl_type,\n                                      user_op::Tensor* tmp_buffer,\n                                      CclLogicalFusionKernelState* kernel_state, const int32_t i) {\n  CHECK_NOTNULL(tmp_buffer);\n  const void* tmp_dptr =\n      static_cast<const void*>(tmp_buffer->dptr<char>() + kernel_state->tmp_buffer_offset(i));\n  if (nccl_type == \"_nccl_logical_reduce_scatter_noncontinuous\") {\n    return tmp_dptr;\n  } else if (nccl_type == \"_nccl_logical_s2s\") {\n    if (kernel_state->dst_split_axis(i) != 0) {\n      return tmp_dptr;  // need do pack;\n    }\n  } else if (nccl_type == \"_nccl_logical_2D_same_dim0_all2all\") {\n    if (kernel_state->dst_split_axis(i) != 0) {\n      return tmp_dptr;  // need do pack;\n    }\n  }\n  return pack_to_ptr;\n}\n\nvoid* UpdateUnpackFromPtrByNcclType(void* unpack_from_ptr, const std::string& nccl_type,\n                                    user_op::Tensor* tmp_buffer, const user_op::Tensor* in,\n                                    CclLogicalFusionKernelState* kernel_state, const int32_t i) {\n  CHECK_NOTNULL(tmp_buffer);\n  void* tmp_dptr =\n      static_cast<void*>(tmp_buffer->mut_dptr<char>() + kernel_state->tmp_buffer_offset(i));\n  int64_t data_size = GetTensorByteSize(*in);\n  int64_t tmp_buffer_size = kernel_state->tmp_buffer_size(i);\n  if (nccl_type == \"_nccl_logical_all_gather_noncontinuous\") {\n    return tmp_dptr;\n  } else if (nccl_type == \"_nccl_logical_s2s\") {\n    if (kernel_state->src_split_axis(i) != 0) {\n      CHECK(tmp_buffer_size == data_size || tmp_buffer_size == 2 * data_size);\n      return static_cast<void*>(static_cast<char*>(tmp_dptr) + (tmp_buffer_size - data_size));\n    }\n  } else if (nccl_type == \"_nccl_logical_2D_same_dim0_all_gather_noncontinuous\") {\n    return tmp_dptr;\n  } else if (nccl_type == \"_nccl_logical_2D_same_dim0_all2all\") {\n    if (kernel_state->src_split_axis(i) != 0) {\n      CHECK(tmp_buffer_size == data_size || tmp_buffer_size 
== 2 * data_size);\n      return static_cast<void*>(static_cast<char*>(tmp_dptr) + (tmp_buffer_size - data_size));\n    }\n  }\n  return unpack_from_ptr;\n}\n\nvoid DoPackBeforeNcclGroup(void* pack_to_ptr, const std::string& nccl_type,\n                           const user_op::Tensor* in, user_op::KernelComputeContext* ctx,\n                           CclLogicalFusionKernelState* kernel_state, const int32_t i) {\n  if (nccl_type == \"_nccl_logical_reduce_scatter_noncontinuous\") {\n    // Do pack before reduce scatter\n    const int64_t num_ranks = kernel_state->num_ranks();\n    const int64_t out_split_axis = kernel_state->dst_split_axis(i);\n    DimVector transpose_in_dim_vec;\n    in->shape_view().ToDimVector(&transpose_in_dim_vec);\n\n    transpose_in_dim_vec[out_split_axis] = transpose_in_dim_vec.at(out_split_axis) / num_ranks;\n    transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + out_split_axis, num_ranks);\n    const Shape transpose_in_shape(transpose_in_dim_vec);\n    std::vector<int32_t> perm;\n    perm.emplace_back(out_split_axis);\n    FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) {\n      if (i != out_split_axis) { perm.emplace_back(i); }\n    }\n    auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n        ctx->stream()->device_type(), transpose_in_dim_vec.size());\n    CHECK(transpose);\n    transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(),\n                      transpose_in_dim_vec.data(), in->dptr(), perm.data(), pack_to_ptr);\n    VLOG(3) << \"[NcclLogicalFusion] op: \" << ctx->op_name() << \" , i= \" << i\n            << \", stream: \" << kernel_state->stream_name() << \" Do pack before [ReduceScatter]\";\n  } else if (nccl_type == \"_nccl_logical_s2s\") {\n    const int64_t out_split_axis = kernel_state->dst_split_axis(i);\n    if (out_split_axis != 0) {\n      // Do pack before all2all\n      const int64_t num_ranks = kernel_state->num_ranks();\n      DimVector transpose_in_dim_vec;\n      in->shape_view().ToDimVector(&transpose_in_dim_vec);\n      CHECK_EQ(transpose_in_dim_vec.at(out_split_axis) % num_ranks, 0);\n      transpose_in_dim_vec[out_split_axis] = transpose_in_dim_vec.at(out_split_axis) / num_ranks;\n      transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + out_split_axis, num_ranks);\n      std::vector<int32_t> perm;\n      perm.emplace_back(out_split_axis);\n      FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) {\n        if (i != out_split_axis) { perm.emplace_back(i); }\n      }\n      auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n          ctx->stream()->device_type(), transpose_in_dim_vec.size());\n      CHECK(transpose);\n      transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(),\n                        transpose_in_dim_vec.data(), in->dptr(), perm.data(), pack_to_ptr);\n      VLOG(3) << \"[NcclLogicalFusion] op: \" << ctx->op_name() << \" , i= \" << i\n              << \", stream: \" << kernel_state->stream_name() << \" Do pack before [All2All]\";\n    }\n  } else if (nccl_type == \"_nccl_logical_2D_same_dim0_all2all\") {\n    const int64_t out_split_axis = kernel_state->dst_split_axis(i);\n    if (out_split_axis != 0) {\n      const int64_t num_ranks = kernel_state->num_ranks();\n      DimVector transpose_in_dim_vec;\n      in->shape_view().ToDimVector(&transpose_in_dim_vec);\n      CHECK_EQ(transpose_in_dim_vec.at(out_split_axis) % num_ranks, 0);\n      transpose_in_dim_vec[out_split_axis] = 
transpose_in_dim_vec.at(out_split_axis) / num_ranks;\n      transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + out_split_axis, num_ranks);\n      std::vector<int32_t> perm;\n      perm.emplace_back(out_split_axis);\n      FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) {\n        if (i != out_split_axis) { perm.emplace_back(i); }\n      }\n      auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n          ctx->stream()->device_type(), transpose_in_dim_vec.size());\n      CHECK(transpose);\n      transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(),\n                        transpose_in_dim_vec.data(), in->dptr(), perm.data(), pack_to_ptr);\n      VLOG(3) << \"[NcclLogicalFusion] op: \" << ctx->op_name() << \" , i= \" << i\n              << \", stream: \" << kernel_state->stream_name()\n              << \" Do pack before [2DSameDim0All2All]\";\n    }\n  }\n}\n\nvoid DoNcclComputeByNcclTypeInGroup(const void* pack_to_ptr, void* unpack_from_ptr,\n                                    const std::string& nccl_type, const user_op::Tensor* in,\n                                    user_op::Tensor* out, user_op::KernelComputeContext* ctx,\n                                    CclLogicalFusionKernelState* kernel_state, const int32_t i,\n                                    ccl::CclComm ccl_comm) {\n  std::unique_ptr<ccl::Send> ccl_send =\n      ccl::NewCollectiveCommunication<ccl::Send>(ctx->stream()->device_type(), in->data_type());\n  std::unique_ptr<ccl::Recv> ccl_recv =\n      ccl::NewCollectiveCommunication<ccl::Recv>(ctx->stream()->device_type(), in->data_type());\n\n  const int64_t num_ranks = kernel_state->num_ranks();\n  VLOG(3) << \"[NcclLogicalFusion] op: \" << ctx->op_name() << \" , i= \" << i\n          << \", stream: \" << kernel_state->stream_name() << \" Try launch nccl_type: \" << nccl_type;\n  if (nccl_type == \"_nccl_logical_all_reduce\") {\n    CHECK(in->dptr() == pack_to_ptr);\n    CHECK(out->mut_dptr() == unpack_from_ptr);\n    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;\n    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }\n    std::unique_ptr<ccl::AllReduce> ccl_all_reduce =\n        ccl::NewCollectiveCommunication<ccl::AllReduce>(ctx->stream()->device_type(),\n                                                        in->data_type(), ccl_reduce_type);\n    ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),\n                           ccl_comm);\n\n  } else if (nccl_type == \"_nccl_logical_reduce_scatter\") {\n    CHECK(in->dptr() == pack_to_ptr);\n    CHECK(out->mut_dptr() == unpack_from_ptr);\n    CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt() * num_ranks);\n    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;\n    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }\n    std::unique_ptr<ccl::ReduceScatter> ccl_reduce_scatter =\n        ccl::NewCollectiveCommunication<ccl::ReduceScatter>(ctx->stream()->device_type(),\n                                                            in->data_type(), ccl_reduce_type);\n    ccl_reduce_scatter->Launch(ctx->stream(), in->dptr(), out->mut_dptr(),\n                               out->shape_view().elem_cnt(), ccl_comm);\n  } else if (nccl_type == \"_nccl_logical_all_gather\") {\n    CHECK(in->dptr() == pack_to_ptr);\n    CHECK(out->mut_dptr() == unpack_from_ptr);\n    CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, 
out->shape_view().elem_cnt());\n\n    std::unique_ptr<ccl::AllGather> ccl_all_gather =\n        ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),\n                                                        in->data_type());\n    ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),\n                           ccl_comm);\n  } else if (nccl_type == \"_nccl_logical_all_gather_noncontinuous\") {\n    CHECK(in->dptr() == pack_to_ptr);\n    CHECK(out->mut_dptr() != unpack_from_ptr);  // do unpack from ptr -> out\n    CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());\n    std::unique_ptr<ccl::AllGather> ccl_all_gather =\n        ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),\n                                                        in->data_type());\n    ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(),\n                           ccl_comm);\n  } else if (nccl_type == \"_nccl_logical_reduce_scatter_noncontinuous\") {\n    CHECK(in->dptr() != pack_to_ptr);  // do in -> pack to ptr\n    CHECK(out->mut_dptr() == unpack_from_ptr);\n    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;\n    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }\n    std::unique_ptr<ccl::ReduceScatter> ccl_reduce_scatter =\n        ccl::NewCollectiveCommunication<ccl::ReduceScatter>(ctx->stream()->device_type(),\n                                                            in->data_type(), ccl_reduce_type);\n    ccl_reduce_scatter->Launch(ctx->stream(), pack_to_ptr, out->mut_dptr(),\n                               out->shape_view().elem_cnt(), ccl_comm);\n  } else if (nccl_type == \"_nccl_logical_s2s\") {\n    const int64_t elem_cnt = in->shape_view().elem_cnt();\n    const int64_t elem_per_chunk = elem_cnt / num_ranks;\n    std::unique_ptr<ccl::AllToAll> all_to_all = ccl::NewCollectiveCommunication<ccl::AllToAll>(\n        ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks);\n    all_to_all->Launch(ctx->stream(), const_cast<void*>(pack_to_ptr), elem_per_chunk,\n                       unpack_from_ptr, elem_per_chunk, ccl_comm);\n\n  } else if (nccl_type == \"_nccl_logical_2D_same_dim0_all_reduce\") {\n    CHECK(in->dptr() == pack_to_ptr);\n    CHECK(out->mut_dptr() == unpack_from_ptr);\n    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;\n    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }\n    std::unique_ptr<ccl::AllReduce> ccl_all_reduce =\n        ccl::NewCollectiveCommunication<ccl::AllReduce>(ctx->stream()->device_type(),\n                                                        in->data_type(), ccl_reduce_type);\n    ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),\n                           ccl_comm);\n  } else if (nccl_type == \"_nccl_logical_2D_same_dim0_all_gather\") {\n    CHECK(in->dptr() == pack_to_ptr);\n    CHECK(out->mut_dptr() == unpack_from_ptr);\n    CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());\n    std::unique_ptr<ccl::AllGather> ccl_all_gather =\n        ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),\n                                                        in->data_type());\n    ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),\n                           
ccl_comm);\n  } else if (nccl_type == \"_nccl_logical_2D_same_dim0_all_gather_noncontinuous\") {\n    CHECK(in->dptr() == pack_to_ptr);\n    CHECK(out->mut_dptr() != unpack_from_ptr);  // do unpack from ptr -> out\n    CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());\n    std::unique_ptr<ccl::AllGather> ccl_all_gather =\n        ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),\n                                                        in->data_type());\n    ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(),\n                           ccl_comm);\n  } else if (nccl_type == \"_nccl_logical_2D_same_dim0_all2all\") {\n    const int64_t elem_cnt = in->shape_view().elem_cnt();\n    const int64_t elem_per_chunk = elem_cnt / num_ranks;\n    std::unique_ptr<ccl::AllToAll> all_to_all = ccl::NewCollectiveCommunication<ccl::AllToAll>(\n        ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks);\n    all_to_all->Launch(ctx->stream(), const_cast<void*>(pack_to_ptr), elem_per_chunk,\n                       unpack_from_ptr, elem_per_chunk, ccl_comm);\n  } else if (nccl_type == \"_nccl_logical_2D_same_dim1_all_reduce\") {\n    CHECK(in->dptr() == pack_to_ptr);\n    CHECK(out->mut_dptr() == unpack_from_ptr);\n    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;\n    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }\n    std::unique_ptr<ccl::AllReduce> ccl_all_reduce =\n        ccl::NewCollectiveCommunication<ccl::AllReduce>(ctx->stream()->device_type(),\n                                                        in->data_type(), ccl_reduce_type);\n    ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),\n                           ccl_comm);\n\n  } else {\n    UNIMPLEMENTED();\n  }\n  VLOG(3) << \"[NcclLogicalFusion] op: \" << ctx->op_name() << \" , i= \" << i\n          << \", stream: \" << kernel_state->stream_name() << \" launched nccl_type: \" << nccl_type;\n}\n\nvoid DoUnpackAfterNcclGroup(void* unpack_from_ptr, const std::string& nccl_type,\n                            const user_op::Tensor* in, user_op::Tensor* out,\n                            user_op::KernelComputeContext* ctx,\n                            CclLogicalFusionKernelState* kernel_state, const int32_t i) {\n  const int64_t num_ranks = kernel_state->num_ranks();\n  const int64_t in_split_axis = kernel_state->src_split_axis(i);\n  const int64_t out_split_axis = kernel_state->dst_split_axis(i);\n  if (nccl_type == \"_nccl_logical_all_gather_noncontinuous\") {\n    CHECK_GT(in_split_axis, 0);\n    DimVector unpack_from_dim_vec;\n    in->shape_view().ToDimVector(&unpack_from_dim_vec);\n    unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks);\n    std::vector<int32_t> perm;\n    FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); }\n    perm.insert(perm.begin() + in_split_axis, 0);\n    auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n        ctx->stream()->device_type(), unpack_from_dim_vec.size());\n    CHECK(transpose);\n    transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(),\n                      unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr());\n    VLOG(3) << \"[NcclLogicalFusion] op: \" << ctx->op_name() << \" , i= \" << i\n            << \", stream: \" << kernel_state->stream_name()\n            << \" Do 
unpack after [AllGatherNoncontinuous]\";\n  } else if (nccl_type == \"_nccl_logical_s2s\") {\n    CHECK_GE(in_split_axis, 0);\n    CHECK_GE(out_split_axis, 0);\n    if (in_split_axis != 0) {\n      // Do unpack.\n      CHECK(unpack_from_ptr != out->mut_dptr());\n\n      DimVector unpack_from_dim_vec;\n      in->shape_view().ToDimVector(&unpack_from_dim_vec);\n      CHECK_EQ(unpack_from_dim_vec.at(out_split_axis) % num_ranks, 0);\n      unpack_from_dim_vec[out_split_axis] = unpack_from_dim_vec.at(out_split_axis) / num_ranks;\n      unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks);\n      std::vector<int32_t> perm;\n      FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); }\n      perm.insert(perm.begin() + in_split_axis, 0);\n      auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n          ctx->stream()->device_type(), unpack_from_dim_vec.size());\n      CHECK(transpose);\n      transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(),\n                        unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr());\n      VLOG(3) << \"[NcclLogicalFusion] op: \" << ctx->op_name() << \" , i= \" << i\n              << \", stream: \" << kernel_state->stream_name() << \" Do unpack after [All2All]\";\n    }\n  } else if (nccl_type == \"_nccl_logical_2D_same_dim0_all_gather_noncontinuous\") {\n    DimVector unpack_from_dim_vec;\n    in->shape_view().ToDimVector(&unpack_from_dim_vec);\n    CHECK_GT(in_split_axis, 0);\n    // NOTE(chengcheng): Do unpack.\n    unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks);\n    std::vector<int32_t> perm;\n    FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); }\n    perm.insert(perm.begin() + in_split_axis, 0);\n\n    auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n        ctx->stream()->device_type(), unpack_from_dim_vec.size());\n    CHECK(transpose);\n    transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(),\n                      unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr());\n    VLOG(3) << \"[NcclLogicalFusion] op: \" << ctx->op_name() << \" , i= \" << i\n            << \", stream: \" << kernel_state->stream_name()\n            << \" Do unpack after [SameDim0AllGatherNoncontinuous]\";\n  } else if (nccl_type == \"_nccl_logical_2D_same_dim0_all2all\") {\n    CHECK_GE(in_split_axis, 0);\n    CHECK_GE(out_split_axis, 0);\n    if (in_split_axis != 0) {\n      DimVector unpack_from_dim_vec;\n      in->shape_view().ToDimVector(&unpack_from_dim_vec);\n      // Do unpack.\n      CHECK(unpack_from_ptr != out->mut_dptr());\n      CHECK_EQ(unpack_from_dim_vec.at(out_split_axis) % num_ranks, 0);\n      unpack_from_dim_vec[out_split_axis] = unpack_from_dim_vec.at(out_split_axis) / num_ranks;\n      unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks);\n      std::vector<int32_t> perm;\n      FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); }\n      perm.insert(perm.begin() + in_split_axis, 0);\n      auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n          ctx->stream()->device_type(), unpack_from_dim_vec.size());\n      CHECK(transpose);\n      transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(),\n                        unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr());\n    }\n  }\n}\n\nvoid 
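/* Resolve pack/unpack pointers, pack inputs, launch all collectives on one comm, unpack. */\n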
CclLogicalFusionKernel::Compute(user_op::KernelComputeContext* ctx,\n                                     user_op::OpKernelState* state,\n                                     const user_op::OpKernelCache*) const {\n  auto* kernel_state = dynamic_cast<CclLogicalFusionKernelState*>(state);\n  CHECK_NOTNULL(kernel_state);\n  const int32_t nccl_num = kernel_state->nccl_num();\n  const std::vector<std::string>& nccl_type_list =\n      ctx->Attr<std::vector<std::string>>(\"nccl_type_list\");\n\n  user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n  // NOTE(chengcheng):\n  //    pack:   in.dptr -> pack_to_ptr              // if not packed: pack_to_ptr = in.dptr\n  //    nccl:   pack_to_ptr -> unpack_from_ptr\n  //    unpack: unpack_from_ptr -> out.dptr         // if not unpacked: unpack_from_ptr = out.dptr\n  std::vector<const void*> pack_to_ptr_list(nccl_num, nullptr);\n  std::vector<void*> unpack_from_ptr_list(nccl_num, nullptr);\n  std::vector<DataType> dtype_list(nccl_num, DataType::kInvalidDataType);\n  for (int32_t i = 0; i < nccl_num; ++i) {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", i);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", i);\n    pack_to_ptr_list.at(i) = in->dptr();\n    unpack_from_ptr_list.at(i) = out->mut_dptr();\n    dtype_list.at(i) = in->data_type();\n    CHECK_EQ(dtype_list.at(i), out->data_type());\n  }\n\n  // try to do pack before all nccl\n  for (int32_t i = 0; i < nccl_num; ++i) {\n    if (kernel_state->tmp_buffer_size(i) == 0) { continue; }\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", i);\n    // TODO(chengcheng): refactor code by template.\n    pack_to_ptr_list.at(i) = UpdatePackToPtrByNcclType(pack_to_ptr_list.at(i), nccl_type_list.at(i),\n                                                       tmp_buffer, kernel_state, i);\n    unpack_from_ptr_list.at(i) = UpdateUnpackFromPtrByNcclType(\n        unpack_from_ptr_list.at(i), nccl_type_list.at(i), tmp_buffer, in, kernel_state, i);\n    DoPackBeforeNcclGroup(const_cast<void*>(pack_to_ptr_list.at(i)) /* mut dptr */,\n                          nccl_type_list.at(i), in, ctx, kernel_state, i);\n  }\n\n  // NOTE(chengcheng): the nccl comm must be initialized before ncclGroupStart.\n  ccl::CclComm ccl_comm = kernel_state->ccl_comm();\n\n  // do nccl compute in group\n  // TODO(zhaoluyang): replace ncclGroupStart/ncclGroupEnd with ccl CclGroupStart/CclGroupEnd\n  // OF_NCCL_CHECK(ncclGroupStart());\n  for (int32_t i = 0; i < nccl_num; ++i) {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", i);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", i);\n    DoNcclComputeByNcclTypeInGroup(pack_to_ptr_list.at(i), unpack_from_ptr_list.at(i),\n                                   nccl_type_list.at(i), in, out, ctx, kernel_state, i, ccl_comm);\n  }\n  // OF_NCCL_CHECK(ncclGroupEnd());\n\n  // try to do unpack after all nccl\n  for (int32_t i = 0; i < nccl_num; ++i) {\n    if (kernel_state->tmp_buffer_size(i) == 0) { continue; }\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", i);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", i);\n    DoUnpackAfterNcclGroup(unpack_from_ptr_list.at(i), nccl_type_list.at(i), in, out, ctx,\n                           kernel_state, i);\n  }\n}\n\nsize_t InferNcclLogicalFusionKernelTmpBufferSize(user_op::InferContext* ctx) {\n  size_t total_buffer_size = 0;\n  const auto& src_nd_sbp_str_list = 
ctx->Attr<std::vector<std::string>>(\"src_nd_sbp_str_list\");\n  const auto& dst_nd_sbp_str_list = ctx->Attr<std::vector<std::string>>(\"dst_nd_sbp_str_list\");\n  const auto& nccl_type_list = ctx->Attr<std::vector<std::string>>(\"nccl_type_list\");\n  int32_t nccl_num = nccl_type_list.size();\n  CHECK_EQ(nccl_num, ctx->input_size(\"in\"));\n  CHECK_EQ(nccl_num, ctx->output_size(\"out\"));\n  CHECK_EQ(nccl_num, src_nd_sbp_str_list.size());\n  CHECK_EQ(nccl_num, dst_nd_sbp_str_list.size());\n  for (int32_t i = 0; i < nccl_num; ++i) {\n    const std::string& nccl_type = nccl_type_list.at(i);\n    size_t in_tensor_byte_size = GetTensorByteSize(ctx->InputTensorDesc(\"in\", i));\n    size_t out_tensor_byte_size = GetTensorByteSize(ctx->OutputTensorDesc(\"out\", i));\n    NdSbp src_nd_sbp;\n    NdSbp dst_nd_sbp;\n    CHECK(ParseNdSbpFromLongString(src_nd_sbp_str_list.at(i), &src_nd_sbp));\n    CHECK(ParseNdSbpFromLongString(dst_nd_sbp_str_list.at(i), &dst_nd_sbp));\n    total_buffer_size += GetTmpBufferSizeByNcclType(nccl_type, in_tensor_byte_size,\n                                                    out_tensor_byte_size, src_nd_sbp, dst_nd_sbp);\n  }\n  return total_buffer_size;\n}\n\nREGISTER_USER_KERNEL(\"_nccl_logical_fusion\")\n    .SetCreateFn<CclLogicalFusionKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)\n                     || (user_op::HobDeviceType() == DeviceType::kNPU)\n                     || (user_op::HobDeviceType() == DeviceType::kMLU))\n    .SetInferTmpSizeFn(InferNcclLogicalFusionKernelTmpBufferSize);\n\n// TODO: SetIsMatchedHob support multi devices(not including cpu)\n}  // namespace\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA || WITH_NPU || WITH_MLU\n"
  },
  {
    "path": "oneflow/user/kernels/nccl_logical_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/job/eager_nccl_comm_manager.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/user/ops/nccl_logical_util.h\"\n#include \"oneflow/user/kernels/collective_communication/include/all_to_all.h\"\n#include \"oneflow/user/kernels/collective_communication/include/all_reduce.h\"\n#include \"oneflow/user/kernels/collective_communication/include/all_gather.h\"\n#include \"oneflow/user/kernels/collective_communication/include/reduce_scatter.h\"\n#include \"oneflow/user/kernels/collective_communication/include/broadcast.h\"\n#include \"oneflow/user/kernels/collective_communication/include/reduce.h\"\n\n#if (defined(WITH_CUDA) && (NCCL_VERSION_CODE > 2700)) || defined(WITH_NPU) || defined(WITH_MLU)\n\nnamespace oneflow {\n\nnamespace {\n\nauto AllReduceCollectiveCommunicationExists() {\n  return hob::make_custom(\"AllReduceCollectiveCommunicationExists\",\n                          [=](const user_op::KernelRegContext& ctx) {\n                            DeviceType device_type = ctx.device_type();\n                            return ccl::IsCommunicationContextRegistered(device_type)\n                                   && ccl::IsAllReduceRegistered(device_type);\n                          });\n}\n\nauto ReduceScatterCollectiveCommunicationExists() {\n  return hob::make_custom(\"ReduceScatterCollectiveCommunicationExists\",\n                          [=](const user_op::KernelRegContext& ctx) {\n                            DeviceType device_type = ctx.device_type();\n                            return ccl::IsCommunicationContextRegistered(device_type)\n                                   && ccl::IsReduceScatterRegistered(device_type);\n                          });\n}\n\nauto AllGatherCollectiveCommunicationExists() {\n  return hob::make_custom(\"AllGatherCollectiveCommunicationExists\",\n                          [=](const user_op::KernelRegContext& ctx) {\n                            DeviceType device_type = ctx.device_type();\n                            return ccl::IsCommunicationContextRegistered(device_type)\n                                   && ccl::IsAllGatherRegistered(device_type);\n                          });\n}\n\nauto ReduceCollectiveCommunicationExists() {\n  return hob::make_custom(\"ReduceCollectiveCommunicationExists\",\n                          [=](const user_op::KernelRegContext& ctx) {\n                            DeviceType device_type = ctx.device_type();\n                            return ccl::IsCommunicationContextRegistered(device_type)\n                                   && ccl::IsReduceRegistered(device_type);\n                          });\n}\n\nauto BroadcastCollectiveCommunicationExists() {\n  return 
hob::make_custom(\"BroadcastCollectiveCommunicationExists\",\n                          [=](const user_op::KernelRegContext& ctx) {\n                            DeviceType device_type = ctx.device_type();\n                            return ccl::IsCommunicationContextRegistered(device_type)\n                                   && ccl::IsBroadcastRegistered(device_type);\n                          });\n}\n\nauto AllToAllCollectiveCommunicationExists() {\n  return hob::make_custom(\"AllToAllCollectiveCommunicationExists\",\n                          [=](const user_op::KernelRegContext& ctx) {\n                            DeviceType device_type = ctx.device_type();\n                            return ccl::IsCommunicationContextRegistered(device_type)\n                                   && ccl::IsAllToAllRegistered(device_type);\n                          });\n}\n\nclass NcclLogicalKernelCommState : public user_op::OpKernelState {\n public:\n  explicit NcclLogicalKernelCommState(user_op::KernelInitContext* ctx)\n      : is_init_(false),\n        stream_name_(EagerCclCommMgr::kDefaultCclStreamName),\n        parallel_desc_(ctx->parallel_desc()) {\n    if (ctx->op_conf().has_stream_name_hint()) { stream_name_ = ctx->op_conf().stream_name_hint(); }\n  }\n  ~NcclLogicalKernelCommState() override = default;\n\n  const ccl::CclComm& ccl_comm() {\n    if (!is_init_) {\n      EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n      ccl_comm_ = comm_mgr->GetCclCommForParallelDescAndStreamName(parallel_desc_, stream_name_);\n      is_init_ = true;\n    }\n    return ccl_comm_;\n  }\n\n  const std::string& stream_name() const { return stream_name_; }\n\n private:\n  bool is_init_;\n  std::string stream_name_;\n  ParallelDesc parallel_desc_;\n  ccl::CclComm ccl_comm_{};\n};\n\nclass NcclLogicalAllGatherNoncontinuousKernelState : public NcclLogicalKernelCommState {\n public:\n  explicit NcclLogicalAllGatherNoncontinuousKernelState(user_op::KernelInitContext* ctx)\n      : NcclLogicalKernelCommState(ctx), src_split_axis_(-1) {}\n  ~NcclLogicalAllGatherNoncontinuousKernelState() override = default;\n\n  int64_t src_split_axis() const { return src_split_axis_; }\n  void set_src_split_axis(int64_t split_axis) { src_split_axis_ = split_axis; }\n\n private:\n  int64_t src_split_axis_;\n};\n\nclass NcclLogicalReduceScatterNoncontinuousKernelState : public NcclLogicalKernelCommState {\n public:\n  explicit NcclLogicalReduceScatterNoncontinuousKernelState(user_op::KernelInitContext* ctx)\n      : NcclLogicalKernelCommState(ctx), dst_split_axis_(-1) {}\n  ~NcclLogicalReduceScatterNoncontinuousKernelState() override = default;\n\n  int64_t dst_split_axis() const { return dst_split_axis_; }\n  void set_dst_split_axis(int64_t split_axis) { dst_split_axis_ = split_axis; }\n\n private:\n  int64_t dst_split_axis_;\n};\n\nclass NcclLogicalS2SKernelState : public NcclLogicalKernelCommState {\n public:\n  explicit NcclLogicalS2SKernelState(user_op::KernelInitContext* ctx)\n      : NcclLogicalKernelCommState(ctx), src_split_axis_(-1), dst_split_axis_(-1) {}\n  ~NcclLogicalS2SKernelState() override = default;\n\n  int64_t src_split_axis() const { return src_split_axis_; }\n  void set_src_split_axis(int64_t split_axis) { src_split_axis_ = split_axis; }\n  int64_t dst_split_axis() const { return dst_split_axis_; }\n  void set_dst_split_axis(int64_t split_axis) { dst_split_axis_ = split_axis; }\n\n private:\n  int64_t src_split_axis_;\n  int64_t dst_split_axis_;\n};\n\nclass CclLogicalAllReduceKernel 
final : public user_op::OpKernel {\n public:\n  CclLogicalAllReduceKernel() = default;\n  ~CclLogicalAllReduceKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<NcclLogicalKernelCommState>(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* comm_state = dynamic_cast<NcclLogicalKernelCommState*>(state);\n    CHECK(comm_state != nullptr);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(in->shape_view(), out->shape_view());\n    CHECK_EQ(in->data_type(), out->data_type());\n    VLOG(3) << \"[NcclLogical][AllReduce] \" << comm_state->stream_name() << \" \" << ctx->op_name()\n            << std::endl;\n\n    ccl::CclComm ccl_comm = comm_state->ccl_comm();\n    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;\n    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }\n    std::unique_ptr<ccl::AllReduce> ccl_all_reduce =\n        ccl::NewCollectiveCommunication<ccl::AllReduce>(ctx->stream()->device_type(),\n                                                        in->data_type(), ccl_reduce_type);\n    ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),\n                           ccl_comm);\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  bool IsKernelLaunchSynchronized() const override {\n    const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    return comm_mgr->IsAsyncLaunchCclLogicalKernel();\n  }\n};\n\nclass CclLogicalReduceScatterKernel final : public user_op::OpKernel {\n public:\n  CclLogicalReduceScatterKernel() = default;\n  ~CclLogicalReduceScatterKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<NcclLogicalKernelCommState>(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* comm_state = dynamic_cast<NcclLogicalKernelCommState*>(state);\n    CHECK(comm_state != nullptr);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(in->data_type(), out->data_type());\n    const int64_t num_ranks = ctx->parallel_ctx().parallel_num();\n    CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt() * num_ranks);\n    VLOG(3) << \"[NcclLogical][ReduceScatter] \" << comm_state->stream_name() << \" \" << ctx->op_name()\n            << std::endl;\n\n    ccl::CclComm ccl_comm = comm_state->ccl_comm();\n    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;\n    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }\n    std::unique_ptr<ccl::ReduceScatter> ccl_reduce_scatter =\n        ccl::NewCollectiveCommunication<ccl::ReduceScatter>(ctx->stream()->device_type(),\n                                                            in->data_type(), ccl_reduce_type);\n    ccl_reduce_scatter->Launch(ctx->stream(), in->dptr(), out->mut_dptr(),\n                               
out->shape_view().elem_cnt(), ccl_comm);\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  bool IsKernelLaunchSynchronized() const override {\n    const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    return comm_mgr->IsAsyncLaunchCclLogicalKernel();\n  }\n};\n\nclass CclLogicalAllGatherKernel final : public user_op::OpKernel {\n public:\n  CclLogicalAllGatherKernel() = default;\n  ~CclLogicalAllGatherKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<NcclLogicalKernelCommState>(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* comm_state = dynamic_cast<NcclLogicalKernelCommState*>(state);\n    CHECK(comm_state != nullptr);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(in->data_type(), out->data_type());\n    const int64_t num_ranks = ctx->parallel_ctx().parallel_num();\n    CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());\n    VLOG(3) << \"[NcclLogical][AllGather] \" << comm_state->stream_name() << \" \" << ctx->op_name()\n            << std::endl;\n\n    ccl::CclComm ccl_comm = comm_state->ccl_comm();\n    std::unique_ptr<ccl::AllGather> ccl_all_gather =\n        ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),\n                                                        in->data_type());\n    ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),\n                           ccl_comm);\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  bool IsKernelLaunchSynchronized() const override {\n    const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    return comm_mgr->IsAsyncLaunchCclLogicalKernel();\n  }\n};\n\ntemplate<typename T>\nclass CclLogicalAllGatherNoncontinuous final : public user_op::OpKernel {\n public:\n  CclLogicalAllGatherNoncontinuous() = default;\n  ~CclLogicalAllGatherNoncontinuous() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    auto state = std::make_shared<NcclLogicalAllGatherNoncontinuousKernelState>(ctx);\n    NdSbp src_nd_sbp;\n    CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", &src_nd_sbp));\n    CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 1);\n    CHECK(src_nd_sbp.sbp_parallel(0).has_split_parallel());\n    state->set_src_split_axis(src_nd_sbp.sbp_parallel(0).split_parallel().axis());\n    return state;\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<NcclLogicalAllGatherNoncontinuousKernelState*>(state);\n    CHECK_NOTNULL(kernel_state);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const int64_t dtype_size = GetSizeOfDataType(in->data_type());\n    int64_t data_size = GetCudaAlignedSize(out->shape_view().elem_cnt() * 
dtype_size);\n    void* unpack_from_ptr = tmp_buffer->mut_dptr();\n    CHECK_EQ(tmp_buffer->shape_view().elem_cnt(), data_size);\n\n    CHECK_EQ(in->data_type(), out->data_type());\n    const int64_t num_ranks = ctx->parallel_ctx().parallel_num();\n    const int64_t in_split_axis = kernel_state->src_split_axis();\n\n    DimVector logical_shape_dim_vec;\n    in->shape_view().ToDimVector(&logical_shape_dim_vec);\n    logical_shape_dim_vec[in_split_axis] = logical_shape_dim_vec.at(in_split_axis) * num_ranks;\n\n    VLOG(3) << \"[NcclLogical][AllGatherNoncontinuous] \" << kernel_state->stream_name() << \" \"\n            << ctx->op_name() << std::endl;\n\n    // NOTE(chengcheng): Do AllGather\n    CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());\n    ccl::CclComm ccl_comm = kernel_state->ccl_comm();\n    std::unique_ptr<ccl::AllGather> ccl_all_gather =\n        ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),\n                                                        in->data_type());\n    ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(),\n                           ccl_comm);\n\n    CHECK_GT(in_split_axis, 0);\n    // NOTE(chengcheng): Do unpack.\n    DimVector unpack_from_dim_vec = logical_shape_dim_vec;\n    CHECK_EQ(unpack_from_dim_vec.at(in_split_axis) % num_ranks, 0);\n    unpack_from_dim_vec[in_split_axis] = unpack_from_dim_vec.at(in_split_axis) / num_ranks;\n    unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks);\n    std::vector<int32_t> perm;\n    FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); }\n    perm.insert(perm.begin() + in_split_axis, 0);\n    auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n        ctx->stream()->device_type(), unpack_from_dim_vec.size());\n    CHECK(transpose);\n    transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(),\n                      unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr());\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  bool IsKernelLaunchSynchronized() const override {\n    const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    return comm_mgr->IsAsyncLaunchCclLogicalKernel();\n  }\n};\n\nsize_t InferAllGatherNoncontinuousKernelTmpBufferSize(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& out_tensor = ctx->OutputTensorDesc(\"out\", 0);\n  return GetCudaAlignedSize(out_tensor.shape().elem_cnt()\n                            * GetSizeOfDataType(out_tensor.data_type()));\n}\n\ntemplate<typename T>\nclass CclLogicalReduceScatterNoncontinuous final : public user_op::OpKernel {\n public:\n  CclLogicalReduceScatterNoncontinuous() = default;\n  ~CclLogicalReduceScatterNoncontinuous() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    auto state = std::make_shared<NcclLogicalReduceScatterNoncontinuousKernelState>(ctx);\n    NdSbp dst_nd_sbp;\n    CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", &dst_nd_sbp));\n    CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 1);\n    CHECK(dst_nd_sbp.sbp_parallel(0).has_split_parallel());\n    state->set_dst_split_axis(dst_nd_sbp.sbp_parallel(0).split_parallel().axis());\n    return state;\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n          
     const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<NcclLogicalReduceScatterNoncontinuousKernelState*>(state);\n    CHECK(kernel_state != nullptr);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const int64_t dtype_size = GetSizeOfDataType(in->data_type());\n    int64_t data_size = GetCudaAlignedSize(in->shape_view().elem_cnt() * dtype_size);\n    CHECK_EQ(tmp_buffer->shape_view().elem_cnt(), data_size);\n\n    CHECK_EQ(in->data_type(), out->data_type());\n    const int64_t num_ranks = ctx->parallel_ctx().parallel_num();\n    const int64_t out_split_axis = kernel_state->dst_split_axis();\n\n    DimVector logical_shape_dim_vec;\n    in->shape_view().ToDimVector(&logical_shape_dim_vec);\n\n    DimVector transpose_in_dim_vec = logical_shape_dim_vec;\n    transpose_in_dim_vec[out_split_axis] = transpose_in_dim_vec.at(out_split_axis) / num_ranks;\n    transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + out_split_axis, num_ranks);\n    const Shape transpose_in_shape(transpose_in_dim_vec);\n    std::vector<int32_t> perm;\n    perm.emplace_back(out_split_axis);\n    FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) {\n      if (i != out_split_axis) { perm.emplace_back(i); }\n    }\n    auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n        ctx->stream()->device_type(), transpose_in_dim_vec.size());\n    CHECK(transpose);\n    transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(),\n                      transpose_in_dim_vec.data(), in->dptr(), perm.data(), tmp_buffer->mut_dptr());\n    VLOG(3) << \"[NcclLogical][ReduceScatterNoncontinuous] \" << kernel_state->stream_name() << \" \"\n            << ctx->op_name() << std::endl;\n\n    ccl::CclComm ccl_comm = kernel_state->ccl_comm();\n    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;\n    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }\n    std::unique_ptr<ccl::ReduceScatter> ccl_reduce_scatter =\n        ccl::NewCollectiveCommunication<ccl::ReduceScatter>(ctx->stream()->device_type(),\n                                                            in->data_type(), ccl_reduce_type);\n    ccl_reduce_scatter->Launch(ctx->stream(), tmp_buffer->dptr(), out->mut_dptr(),\n                               out->shape_view().elem_cnt(), ccl_comm);\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  bool IsKernelLaunchSynchronized() const override {\n    const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    return comm_mgr->IsAsyncLaunchCclLogicalKernel();\n  }\n};\n\nsize_t InferReduceScatterNoncontinuousKernelTmpBufferSize(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  return GetCudaAlignedSize(in_tensor.shape().elem_cnt()\n                            * GetSizeOfDataType(in_tensor.data_type()));\n}\n\ntemplate<typename T>\nclass CclLogicalS2SKernel final : public user_op::OpKernel {\n public:\n  CclLogicalS2SKernel() = default;\n  ~CclLogicalS2SKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    auto state = std::make_shared<NcclLogicalS2SKernelState>(ctx);\n    NdSbp src_nd_sbp;\n    NdSbp 
dst_nd_sbp;\n    CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", &src_nd_sbp));\n    CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", &dst_nd_sbp));\n    CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 1);\n    CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 1);\n    CHECK(src_nd_sbp.sbp_parallel(0).has_split_parallel());\n    CHECK(dst_nd_sbp.sbp_parallel(0).has_split_parallel());\n    state->set_src_split_axis(src_nd_sbp.sbp_parallel(0).split_parallel().axis());\n    state->set_dst_split_axis(dst_nd_sbp.sbp_parallel(0).split_parallel().axis());\n    return state;\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<NcclLogicalS2SKernelState*>(state);\n    CHECK_NOTNULL(kernel_state);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    int64_t tmp_size = 0;\n    const int64_t dtype_size = GetSizeOfDataType(in->data_type());\n    int64_t data_size = GetCudaAlignedSize(in->shape_view().elem_cnt() * dtype_size);\n    // NOTE(chengcheng): in (transpose)-> pack_to_ptr (all2all)-> unpack_from_ptr (transpose)-> out\n    const char* pack_to_ptr = in->dptr<char>();\n    char* unpack_from_ptr = out->mut_dptr<char>();\n    if (tmp_buffer) { tmp_size = tmp_buffer->shape_view().elem_cnt(); }\n    CHECK(tmp_size == 0 || tmp_size == data_size || tmp_size == data_size * 2);\n\n    CHECK_EQ(in->data_type(), out->data_type());\n    const int64_t num_ranks = ctx->parallel_ctx().parallel_num();\n    CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt());\n    const int64_t elem_cnt = in->shape_view().elem_cnt();\n    const int64_t in_split_axis = kernel_state->src_split_axis();\n    const int64_t out_split_axis = kernel_state->dst_split_axis();\n\n    DimVector logical_shape_dim_vec;\n    in->shape_view().ToDimVector(&logical_shape_dim_vec);\n    logical_shape_dim_vec[in_split_axis] = logical_shape_dim_vec.at(in_split_axis) * num_ranks;\n\n    VLOG(3) << \"[NcclLogical][S2S] \" << kernel_state->stream_name() << \" \" << ctx->op_name()\n            << std::endl;\n\n    if (out_split_axis != 0) {\n      // NOTE(chengcheng): Do pack. 
Need transpose in -> pack_to\n      // pack use temp buffer offset: [0, data_size]\n      pack_to_ptr = CHECK_NOTNULL(tmp_buffer)->dptr<char>();\n      DimVector transpose_in_dim_vec = logical_shape_dim_vec;\n      CHECK_EQ(transpose_in_dim_vec.at(in_split_axis) % num_ranks, 0);\n      transpose_in_dim_vec[in_split_axis] = transpose_in_dim_vec.at(in_split_axis) / num_ranks;\n      CHECK_EQ(transpose_in_dim_vec.at(out_split_axis) % num_ranks, 0);\n      transpose_in_dim_vec[out_split_axis] = transpose_in_dim_vec.at(out_split_axis) / num_ranks;\n      transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + out_split_axis, num_ranks);\n      std::vector<int32_t> perm;\n      perm.emplace_back(out_split_axis);\n      FOR_RANGE(int64_t, i, 0, transpose_in_dim_vec.size()) {\n        if (i != out_split_axis) { perm.emplace_back(i); }\n      }\n      auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n          ctx->stream()->device_type(), transpose_in_dim_vec.size());\n      CHECK(transpose);\n      transpose->Launch(ctx->stream(), in->data_type(), transpose_in_dim_vec.size(),\n                        transpose_in_dim_vec.data(), in->dptr(), perm.data(),\n                        tmp_buffer->mut_dptr());\n    }\n\n    if (in_split_axis != 0) {\n      // NOTE(chengcheng): Do unpack. Need transpose unpack_from -> out\n      // unpack use temp buffer offset: [tmp_size - data_size, tmp_size]\n      unpack_from_ptr = CHECK_NOTNULL(tmp_buffer)->mut_dptr<char>() + (tmp_size - data_size);\n    }\n\n    {\n      // NOTE(chengcheng): Do S2S\n      const int64_t elem_per_chunk = elem_cnt / num_ranks;\n      std::unique_ptr<ccl::AllToAll> all_to_all = ccl::NewCollectiveCommunication<ccl::AllToAll>(\n          ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks);\n      ccl::CclComm ccl_comm = kernel_state->ccl_comm();\n      all_to_all->Launch(ctx->stream(), const_cast<char*>(pack_to_ptr), elem_per_chunk,\n                         unpack_from_ptr, elem_per_chunk, ccl_comm);\n    }\n\n    if (in_split_axis != 0) {\n      // Do unpack.\n      CHECK(unpack_from_ptr != out->mut_dptr<char>());\n      DimVector unpack_from_dim_vec = logical_shape_dim_vec;\n      CHECK_EQ(unpack_from_dim_vec.at(in_split_axis) % num_ranks, 0);\n      unpack_from_dim_vec[in_split_axis] = unpack_from_dim_vec.at(in_split_axis) / num_ranks;\n      CHECK_EQ(unpack_from_dim_vec.at(out_split_axis) % num_ranks, 0);\n      unpack_from_dim_vec[out_split_axis] = unpack_from_dim_vec.at(out_split_axis) / num_ranks;\n      unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks);\n      std::vector<int32_t> perm;\n      FOR_RANGE(int64_t, i, 1, unpack_from_dim_vec.size()) { perm.emplace_back(i); }\n      perm.insert(perm.begin() + in_split_axis, 0);\n      auto transpose = ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(\n          ctx->stream()->device_type(), unpack_from_dim_vec.size());\n      CHECK(transpose);\n      transpose->Launch(ctx->stream(), in->data_type(), unpack_from_dim_vec.size(),\n                        unpack_from_dim_vec.data(), unpack_from_ptr, perm.data(), out->mut_dptr());\n    }\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  bool IsKernelLaunchSynchronized() const override {\n    const EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    return comm_mgr->IsAsyncLaunchCclLogicalKernel();\n  }\n};\n\nsize_t InferS2SKernelTmpBufferSize(user_op::InferContext* ctx) {\n  size_t ret = 0;\n  
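// s2s may need two aligned buffers: pack if dst split axis != 0, unpack if src != 0.\n  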
const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  size_t tensor_byte_size =\n      GetCudaAlignedSize(in_tensor.shape().elem_cnt() * GetSizeOfDataType(in_tensor.data_type()));\n  NdSbp src_nd_sbp;\n  NdSbp dst_nd_sbp;\n  CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", &src_nd_sbp));\n  CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", &dst_nd_sbp));\n  CHECK_EQ(src_nd_sbp.sbp_parallel_size(), 1);\n  CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), 1);\n  CHECK(src_nd_sbp.sbp_parallel(0).has_split_parallel());\n  CHECK(dst_nd_sbp.sbp_parallel(0).has_split_parallel());\n  if (src_nd_sbp.sbp_parallel(0).split_parallel().axis() != 0) { ret += tensor_byte_size; }\n  if (dst_nd_sbp.sbp_parallel(0).split_parallel().axis() != 0) { ret += tensor_byte_size; }\n  return ret;\n}\n\n}  // namespace\n\nREGISTER_USER_KERNEL(\"_nccl_logical_all_reduce\")\n    .SetCreateFn<CclLogicalAllReduceKernel>()\n    .SetIsMatchedHob(AllReduceCollectiveCommunicationExists());\n\nREGISTER_USER_KERNEL(\"_nccl_logical_reduce_scatter\")\n    .SetCreateFn<CclLogicalReduceScatterKernel>()\n    .SetIsMatchedHob(ReduceScatterCollectiveCommunicationExists());\n\nREGISTER_USER_KERNEL(\"_nccl_logical_all_gather\")\n    .SetCreateFn<CclLogicalAllGatherKernel>()\n    .SetIsMatchedHob(AllGatherCollectiveCommunicationExists());\n\n#define REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(dtype)                                   \\\n  REGISTER_USER_KERNEL(\"_nccl_logical_all_gather_noncontinuous\")                         \\\n      .SetCreateFn<CclLogicalAllGatherNoncontinuous<dtype>>()                            \\\n      .SetIsMatchedHob(AllGatherCollectiveCommunicationExists()                          \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value)   \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn(InferAllGatherNoncontinuousKernelTmpBufferSize);\n\nREGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(bool)\nREGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(int8_t)\nREGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(int32_t)\nREGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(int64_t)\nREGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(float)\nREGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(double)\nREGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(float16)\n#if defined(__CUDA_BF16_TYPES_EXIST__)\nREGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(nv_bfloat16)\n#endif\n\n#define REGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(dtype)                              \\\n  REGISTER_USER_KERNEL(\"_nccl_logical_reduce_scatter_noncontinuous\")                     \\\n      .SetCreateFn<CclLogicalReduceScatterNoncontinuous<dtype>>()                        \\\n      .SetIsMatchedHob(ReduceScatterCollectiveCommunicationExists()                      \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value)   \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn(InferReduceScatterNoncontinuousKernelTmpBufferSize);\n\nREGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(bool)\nREGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(int8_t)\nREGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(int32_t)\nREGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(int64_t)\nREGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(float)\nREGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(double)\nREGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(float16)\n#if 
defined(__CUDA_BF16_TYPES_EXIST__)\nREGISTER_REDUCE_SCATTER_NONCONTINUOUS_KERNEL(nv_bfloat16)\n#endif\n\n#define REGISTER_S2S_KERNEL(dtype)                                                       \\\n  REGISTER_USER_KERNEL(\"_nccl_logical_s2s\")                                              \\\n      .SetCreateFn<CclLogicalS2SKernel<dtype>>()                                         \\\n      .SetIsMatchedHob(AllToAllCollectiveCommunicationExists()                           \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value)   \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn(InferS2SKernelTmpBufferSize);\n\nREGISTER_S2S_KERNEL(bool)\nREGISTER_S2S_KERNEL(int8_t)\nREGISTER_S2S_KERNEL(int32_t)\nREGISTER_S2S_KERNEL(int64_t)\nREGISTER_S2S_KERNEL(float)\nREGISTER_S2S_KERNEL(double)\nREGISTER_S2S_KERNEL(float16)\n#if defined(__CUDA_BF16_TYPES_EXIST__)\nREGISTER_S2S_KERNEL(nv_bfloat16)\n#endif\n\nREGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT(\"_nccl_logical_all_reduce\");\nREGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT(\"_nccl_logical_reduce_scatter\");\nREGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT(\"_nccl_logical_all_gather\");\nREGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT(\"_nccl_logical_all_gather_noncontinuous\");\nREGISTER_USER_KERNEL_UNIFIED_CCL_COMM_INIT(\"_nccl_logical_s2s\");\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA || WITH_NPU || WITH_MLU\n"
  },
  {
    "path": "oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"collective_communication/include/collective_communication.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/user/ops/nccl_logical_util.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/job/eager_nccl_comm_manager.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/register/tensor_slice_copier.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n#include \"oneflow/core/ep/include/primitive/add.h\"\n#include \"oneflow/core/operator/nccl_send_recv_boxing_op_util.h\"\n#include \"oneflow/user/kernels/collective_communication/include/all_to_all.h\"\n\n#if (defined(WITH_CUDA) && (NCCL_VERSION_CODE > 2700)) || defined(WITH_NPU) || defined(WITH_MLU)\n\nnamespace oneflow {\n\nclass CclLogicalSendRecvState final : public user_op::OpKernelState {\n public:\n  explicit CclLogicalSendRecvState(user_op::KernelInitContext* ctx);\n  const std::vector<std::shared_ptr<TensorSliceCopier>>& in_tensor_slice_copier_vec() const {\n    return in_tensor_slice_copier_vec_;\n  }\n  const std::vector<std::shared_ptr<TensorSliceCopier>>& out_tensor_slice_copier_vec() const {\n    return out_tensor_slice_copier_vec_;\n  }\n  bool src_nd_sbp_has_no_partial_parallel() const { return src_nd_sbp_no_partial_parallel_; }\n  const std::vector<int64_t>& send_elem_cnts() const { return send_elem_cnts_; }\n  const std::vector<int64_t>& recv_elem_cnts() const { return recv_elem_cnts_; }\n  ccl::CclComm ccl_comm() const { return GetOrCreateComm().ccl_comm; }\n\n private:\n  struct Comm {\n    Comm(ccl::CclComm comm) : ccl_comm(comm) {}\n    ccl::CclComm ccl_comm;\n  };\n\n  void InitComm() const;\n  const Comm& GetOrCreateComm() const {\n    if (!ccl_comm_) { InitComm(); }\n    return *ccl_comm_;\n  }\n\n  std::string stream_name_;\n  std::unique_ptr<ParallelDesc> parallel_desc_;\n  mutable std::unique_ptr<Comm> ccl_comm_;\n  bool src_nd_sbp_no_partial_parallel_;\n  std::vector<std::shared_ptr<TensorSliceCopier>> in_tensor_slice_copier_vec_;\n  std::vector<std::shared_ptr<TensorSliceCopier>> out_tensor_slice_copier_vec_;\n  std::vector<int64_t> send_elem_cnts_;\n  std::vector<int64_t> recv_elem_cnts_;\n};\n\nCclLogicalSendRecvState::CclLogicalSendRecvState(user_op::KernelInitContext* ctx)\n    : stream_name_(EagerCclCommMgr::kDefaultCclStreamName) {\n  if (ctx->op_conf().has_stream_name_hint()) { stream_name_ = ctx->op_conf().stream_name_hint(); }\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n  parallel_desc_ = std::make_unique<ParallelDesc>(ctx->parallel_desc());\n  NdSbp src_nd_sbp;\n  CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", &src_nd_sbp));\n  
NdSbp dst_nd_sbp;\n  CHECK_JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", &dst_nd_sbp));\n  const auto& parallel_hierarchy = parallel_desc_->hierarchy();\n  src_nd_sbp_no_partial_parallel_ = !NdSbpHasPartialParallel(src_nd_sbp);\n  CHECK_EQ(src_nd_sbp.sbp_parallel_size(), parallel_hierarchy->NumAxes());\n  CHECK_EQ(dst_nd_sbp.sbp_parallel_size(), parallel_hierarchy->NumAxes());\n  const user_op::TensorDesc* in_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex(\"in\", 0);\n  const DataType data_type = in_logical_desc->data_type();\n  const Shape& logical_shape = Shape(in_logical_desc->shape());\n  const DeviceType device_type = parallel_desc_->device_type();\n  const int64_t parallel_num = parallel_desc_->parallel_num();\n\n  std::vector<TensorSliceView> src_send_intersections;\n  std::vector<TensorSliceView> dst_recv_intersections;\n  GetRankSendRecvIntersection(parallel_id, /*merge_parallel_desc=*/*parallel_desc_,\n                              /*in_parallel_desc=*/*parallel_desc_,\n                              /*out_parallel_desc=*/*parallel_desc_, src_nd_sbp, dst_nd_sbp,\n                              logical_shape, &src_send_intersections, &dst_recv_intersections);\n\n  CHECK_EQ(src_send_intersections.size(), parallel_num);\n  send_elem_cnts_.resize(parallel_num);\n  in_tensor_slice_copier_vec_.resize(parallel_num);\n  const TensorSliceView& cur_rank_in_slice =\n      GetTensorSliceView4ParallelId(*parallel_hierarchy, src_nd_sbp, logical_shape, parallel_id);\n  for (int64_t i = 0; i < parallel_num; ++i) {\n    const TensorSliceView& intersection = src_send_intersections.at(i);\n    if (!intersection.IsEmpty()) {\n      send_elem_cnts_.at(i) = intersection.shape().elem_cnt();\n      in_tensor_slice_copier_vec_.at(i).reset(\n          new TensorSliceCopier(intersection, cur_rank_in_slice, data_type, device_type));\n    }\n  }\n\n  CHECK_EQ(dst_recv_intersections.size(), parallel_num);\n  recv_elem_cnts_.resize(parallel_num);\n  out_tensor_slice_copier_vec_.resize(parallel_num);\n  const TensorSliceView& cur_rank_out_slice =\n      GetTensorSliceView4ParallelId(*parallel_hierarchy, dst_nd_sbp, logical_shape, parallel_id);\n  for (int64_t i = 0; i < parallel_num; ++i) {\n    const TensorSliceView& intersection = dst_recv_intersections.at(i);\n    if (!intersection.IsEmpty()) {\n      recv_elem_cnts_.at(i) = intersection.shape().elem_cnt();\n      out_tensor_slice_copier_vec_.at(i).reset(\n          new TensorSliceCopier(cur_rank_out_slice, intersection, data_type, device_type));\n    }\n  }\n}\n\nvoid CclLogicalSendRecvState::InitComm() const {\n  EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n  ccl::CclComm ccl_comm =\n      comm_mgr->GetCclCommForParallelDescAndStreamName(*parallel_desc_.get(), stream_name_);\n  ccl_comm_.reset(new Comm(ccl_comm));\n}\n\nclass CclLogicalSendRecv final : public user_op::OpKernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CclLogicalSendRecv);\n  CclLogicalSendRecv() = default;\n  ~CclLogicalSendRecv() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<CclLogicalSendRecvState>(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  bool IsKernelLaunchSynchronized() const override {\n    
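// Mirrors the comm manager's async-launch setting for CCL logical kernels.\n    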
EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    return comm_mgr->IsAsyncLaunchCclLogicalKernel();\n  }\n};\n\nvoid CclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n                                 const user_op::OpKernelCache*) const {\n  auto* kernel_state = dynamic_cast<CclLogicalSendRecvState*>(state);\n  CHECK_NOTNULL(kernel_state);\n  const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n  user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n  user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n  ccl::CclComm ccl_comm = kernel_state->ccl_comm();\n  const std::vector<int64_t>& send_elem_cnts = kernel_state->send_elem_cnts();\n  const std::vector<int64_t>& recv_elem_cnts = kernel_state->recv_elem_cnts();\n  const int64_t parallel_num = send_elem_cnts.size();\n  const DataType data_type = in->data_type();\n\n  std::vector<void*> send_in_ptr;\n  std::vector<void*> recv_out_ptr;\n  std::vector<int64_t> send_offsets;\n  std::vector<int64_t> recv_offsets;\n  char* buf_ptr = tmp_buffer->mut_dptr<char>();\n  uint64_t offset = 0;\n  for (int64_t i = 0; i < parallel_num; ++i) {\n    void* send_ptr = reinterpret_cast<void*>(buf_ptr + offset);\n    send_in_ptr.push_back(send_ptr);\n    send_offsets.push_back(offset);\n    offset += send_elem_cnts.at(i) * GetSizeOfDataType(data_type);\n  }\n  const uint64_t recv_offset = offset;\n  for (int64_t i = 0; i < parallel_num; ++i) {\n    void* recv_ptr = reinterpret_cast<void*>(buf_ptr + offset);\n    recv_out_ptr.push_back(recv_ptr);\n    recv_offsets.push_back(offset - recv_offset);\n    offset += recv_elem_cnts.at(i) * GetSizeOfDataType(data_type);\n  }\n\n  const std::vector<std::shared_ptr<TensorSliceCopier>>& in_tensor_slice_copier_vec =\n      kernel_state->in_tensor_slice_copier_vec();\n  for (int64_t i = 0; i < parallel_num; ++i) {\n    if (in_tensor_slice_copier_vec.at(i)) {\n      in_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), send_in_ptr.at(i), in->dptr());\n    }\n  }\n\n  std::unique_ptr<ccl::AllToAll> all_to_all = ccl::NewCollectiveCommunication<ccl::AllToAll>(\n      ctx->stream()->device_type(), data_type, data_type, parallel_num);\n  void* send_buf = reinterpret_cast<void*>(buf_ptr);\n  void* recv_buf = reinterpret_cast<void*>(buf_ptr + recv_offset);\n  all_to_all->Launch(ctx->stream(), send_buf, send_elem_cnts.data(), send_offsets.data(), recv_buf,\n                     recv_elem_cnts.data(), recv_offsets.data(), ccl_comm, /*has_input=*/true,\n                     /*has_output=*/true);\n\n  const std::vector<std::shared_ptr<TensorSliceCopier>>& out_tensor_slice_copier_vec =\n      kernel_state->out_tensor_slice_copier_vec();\n\n  if (kernel_state->src_nd_sbp_has_no_partial_parallel()) {\n    for (int64_t i = 0; i < parallel_num; ++i) {\n      if (out_tensor_slice_copier_vec.at(i)) {\n        out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out->mut_dptr(), recv_out_ptr.at(i));\n      }\n    }\n  } else {\n    std::unique_ptr<ep::primitive::Add> add_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::AddFactory>(ctx->stream()->device_type(),\n                                                               out->data_type());\n    CHECK(add_primitive);\n    std::unique_ptr<ep::primitive::Memset> memset_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->stream()->device_type());\n    CHECK(memset_primitive);\n    bool is_first_slice = true;\n    for 
(int64_t i = 0; i < parallel_num; ++i) {\n      if (out_tensor_slice_copier_vec.at(i)) {\n        if (is_first_slice) {\n          is_first_slice = false;\n          if (recv_elem_cnts.at(i) != out->shape_view().elem_cnt()) {\n            // the received slice does not cover the whole output, so zero-fill out first\n            memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0,\n                                     out->shape_view().elem_cnt() * GetSizeOfDataType(data_type));\n          }\n          out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out->mut_dptr(),\n                                                  recv_out_ptr.at(i));\n        } else {\n          if (recv_elem_cnts.at(i) == out->shape_view().elem_cnt()) {\n            add_primitive->Launch(ctx->stream(), out->dptr(), recv_out_ptr.at(i), out->mut_dptr(),\n                                  out->shape_view().elem_cnt());\n          } else {\n            void* out_buf = reinterpret_cast<void*>(buf_ptr + offset);\n            memset_primitive->Launch(ctx->stream(), out_buf, 0,\n                                     out->shape_view().elem_cnt() * GetSizeOfDataType(data_type));\n            out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out_buf, recv_out_ptr.at(i));\n            add_primitive->Launch(ctx->stream(), out->dptr(), out_buf, out->mut_dptr(),\n                                  out->shape_view().elem_cnt());\n          }\n        }\n      }\n    }\n  }\n}\n\nsize_t InferTmpBufferSize(user_op::InferContext* ctx) {\n  const Shape& out_shape = ctx->OutputShape(\"out\", 0);\n  const user_op::TensorDesc* logical_in_tensor = ctx->LogicalTensorDesc4ArgNameAndIndex(\"in\", 0);\n  const Shape& logical_shape = logical_in_tensor->shape();\n  const DataType data_type = logical_in_tensor->data_type();\n\n  const NdSbp& src_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  const NdSbp& dst_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  const int64_t parallel_num = ctx->parallel_num();\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n\n  std::vector<TensorSliceView> src_send_intersections;\n  std::vector<TensorSliceView> dst_recv_intersections;\n  const auto& parallel_desc = ctx->parallel_desc();\n  GetRankSendRecvIntersection(parallel_id, /*merge_parallel_desc=*/parallel_desc,\n                              /*in_parallel_desc=*/parallel_desc,\n                              /*out_parallel_desc=*/parallel_desc, src_nd_sbp, dst_nd_sbp,\n                              logical_shape, &src_send_intersections, &dst_recv_intersections);\n  int64_t buf_count = 0;\n  CHECK_EQ(src_send_intersections.size(), parallel_num);\n  for (int64_t i = 0; i < parallel_num; ++i) {\n    const TensorSliceView& intersection = src_send_intersections.at(i);\n    if (!intersection.IsEmpty()) { buf_count += intersection.shape().elem_cnt(); }\n  }\n  for (int64_t i = 0; i < parallel_num; ++i) {\n    const TensorSliceView& intersection = dst_recv_intersections.at(i);\n    if (!intersection.IsEmpty()) { buf_count += intersection.shape().elem_cnt(); }\n  }\n  if (NdSbpHasPartialParallel(src_nd_sbp)) {\n    // Note: when src_nd_sbp has partial_sum, an out-sized buffer is needed to copy and add to out.\n    buf_count += out_shape.elem_cnt();\n  }\n  return buf_count * GetSizeOfDataType(data_type);\n}\n\n// TODO(zhaoluyang): make SetIsMatchedHob support multiple devices (not including cpu)\nREGISTER_USER_KERNEL(\"_nccl_logical_send_recv\")\n    .SetCreateFn<CclLogicalSendRecv>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)\n                     || 
(user_op::HobDeviceType() == DeviceType::kNPU)\n                     || (user_op::HobDeviceType() == DeviceType::kMLU))\n    .SetInferTmpSizeFn(InferTmpBufferSize);\n\n}  // namespace oneflow\n\n#endif  // WITH_CUDA || WITH_NPU || WITH_MLU\n"
  },
  {
    "path": "oneflow/user/kernels/nd_index_slice_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/nd_index_slice_kernels.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename I>\nstruct GatherNdFunctor<DeviceType::kCPU, T, I> final {\n  void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices,\n                  const T* dense, T* slices) const {\n    DoGatherNd(args.num_slices * args.slice_size, args.slice_size, args.index_ndims,\n               args.dense_shape, indices, dense, slices);\n  }\n};\n\ntemplate<typename T, typename I>\nstruct ScatterNdAddFunctor<DeviceType::kCPU, T, I> final {\n  void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices,\n                  const T* slices, T* dense) const {\n    DoScatterNdAdd<DeviceType::kCPU>(args.num_slices * args.slice_size, args.slice_size,\n                                     args.index_ndims, args.dense_shape, indices, slices, dense);\n  }\n};\n\ntemplate<typename T, typename I>\nstruct ScatterNdUpdateFunctor<DeviceType::kCPU, T, I> final {\n  void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices,\n                  const T* slices, T* dense) const {\n    DoScatterNdUpdate<DeviceType::kCPU>(args.num_slices * args.slice_size, args.slice_size,\n                                        args.index_ndims, args.dense_shape, indices, slices, dense);\n  }\n};\n\ntemplate<typename T, typename I>\nstruct ScatterNdUpdateWithStrideFunctor<DeviceType::kCPU, T, I> final {\n  void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices,\n                  const T* slices, T* dense) const {\n    DoScatterNdUpdateWithStride<DeviceType::kCPU>(args.num_slices * args.slice_size, args, indices,\n                                                  slices, dense);\n  }\n};\n\ntemplate<typename T, typename I>\nstruct FillByNdIndexFunctor<DeviceType::kCPU, T, I> final {\n  void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, T* dense,\n                  T value) const {\n    DoFillByNdIndex(args.num_slices * args.slice_size, args.slice_size, args.index_ndims,\n                    args.dense_shape, indices, dense, value);\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ND_INDEX_SLICE_FUNCTORS, (DeviceType::kCPU),\n                                 ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ\n                                     BFLOAT16_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ\n                                         BOOL_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_ND_INDEX_SLICE_KERNELS, (DeviceType::kCPU),\n                                 ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ\n                                     BFLOAT16_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ\n                                         BOOL_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n\n}  // 
namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/nd_index_slice_kernels.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/nd_index_slice_kernels.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename I>\n__global__ void CudaGatherNd(NdIndexSliceArgs args, const I* indices, const T* dense, T* slices) {\n  DoGatherNd(args.num_slices * args.slice_size, args.slice_size, args.index_ndims, args.dense_shape,\n             indices, dense, slices);\n}\n\ntemplate<typename T, typename I>\n__global__ void CudaScatterNdAdd(NdIndexSliceArgs args, const I* indices, const T* slices,\n                                 T* dense) {\n  DoScatterNdAdd<DeviceType::kCUDA>(args.num_slices * args.slice_size, args.slice_size,\n                                    args.index_ndims, args.dense_shape, indices, slices, dense);\n}\n\ntemplate<typename T, typename I>\n__global__ void CudaScatterNdUpdate(NdIndexSliceArgs args, const I* indices, const T* slices,\n                                    T* dense) {\n  DoScatterNdUpdate<DeviceType::kCUDA>(args.num_slices * args.slice_size, args.slice_size,\n                                       args.index_ndims, args.dense_shape, indices, slices, dense);\n}\n\ntemplate<typename T, typename I>\n__global__ void CudaScatterNdUpdateWithStride(NdIndexSliceArgs args, const I* indices,\n                                              const T* slices, T* dense) {\n  DoScatterNdUpdateWithStride<DeviceType::kCUDA>(args.num_slices * args.slice_size, args, indices,\n                                                 slices, dense);\n}\n\ntemplate<typename T, typename I>\n__global__ void CudaFillByNdIndex(NdIndexSliceArgs args, const I* indices, T* dense, T value) {\n  DoFillByNdIndex(args.num_slices * args.slice_size, args.slice_size, args.index_ndims,\n                  args.dense_shape, indices, dense, value);\n}\n\n}  // namespace\n\ntemplate<typename T, typename I>\nstruct GatherNdFunctor<DeviceType::kCUDA, T, I> final {\n  void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices,\n                  const T* dense, T* slices) const {\n    RUN_CUDA_KERNEL((CudaGatherNd<T, I>), stream, args.num_slices * args.slice_size, args, indices,\n                    dense, slices);\n  }\n};\n\ntemplate<typename T, typename I>\nstruct ScatterNdAddFunctor<DeviceType::kCUDA, T, I> final {\n  void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices,\n                  const T* slices, T* dense) const {\n    RUN_CUDA_KERNEL((CudaScatterNdAdd<T, I>), stream, args.num_slices * args.slice_size, args,\n                    indices, slices, dense);\n  }\n};\n\ntemplate<typename T, typename I>\nstruct ScatterNdUpdateFunctor<DeviceType::kCUDA, T, I> final {\n  void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices,\n                  const T* slices, T* dense) const {\n    RUN_CUDA_KERNEL((CudaScatterNdUpdate<T, I>), stream, args.num_slices * 
args.slice_size, args,\n                    indices, slices, dense);\n  }\n};\n\ntemplate<typename T, typename I>\nstruct ScatterNdUpdateWithStrideFunctor<DeviceType::kCUDA, T, I> final {\n  void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices,\n                  const T* slices, T* dense) const {\n    RUN_CUDA_KERNEL((CudaScatterNdUpdateWithStride<T, I>), stream,\n                    args.num_slices * args.slice_size, args, indices, slices, dense);\n  }\n};\n\ntemplate<typename T, typename I>\nstruct FillByNdIndexFunctor<DeviceType::kCUDA, T, I> final {\n  void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, T* dense,\n                  T value) const {\n    RUN_CUDA_KERNEL((CudaFillByNdIndex<T, I>), stream, args.num_slices * args.slice_size, args,\n                    indices, dense, value);\n  }\n};\n\ntemplate<typename T>\nstruct DeviceAdd<DeviceType::kCUDA, T> {\n  __device__ __forceinline__ static void Invoke(const T* x, T* y) { cuda::atomic::Add(y, *x); }\n};\n\ntemplate<>\nstruct DeviceAdd<DeviceType::kCUDA, bool> {\n  __device__ __forceinline__ static void Invoke(const bool* x, bool* y) { *y += *x; }\n};\n\ntemplate<>\nstruct DeviceAdd<DeviceType::kCUDA, uint8_t> {\n  __device__ __forceinline__ static void Invoke(const uint8_t* x, uint8_t* y) { *y += *x; }\n};\n\ntemplate<>\nstruct DeviceAdd<DeviceType::kCUDA, int8_t> {\n  __device__ __forceinline__ static void Invoke(const int8_t* x, int8_t* y) { *y += *x; }\n};\n\ntemplate<>\nstruct DeviceAdd<DeviceType::kCUDA, int64_t> {\n  __device__ __forceinline__ static void Invoke(const int64_t* x, int64_t* y) { *y += *x; }\n};\n\n#define CUDA_ATOMIC_ADD_SUPPORTED_DATA_TYPE_SEQ \\\n  FLOATING_DATA_TYPE_SEQ                        \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n    INSTANTIATE_GATHER_ND_FUNCTOR, (DeviceType::kCUDA),\n    ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SCATTER_ND_ADD_FUNCTOR, (DeviceType::kCUDA),\n                                 CUDA_ATOMIC_ADD_SUPPORTED_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_FILL_BY_ND_INDEX_FUNCTOR, (DeviceType::kCUDA),\n                                 ARITHMETIC_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n    REGISTER_GATHER_ND_KERNELS, (DeviceType::kCUDA),\n    ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n    REGISTER_SCATTER_ND_KERNELS, (DeviceType::kCUDA),\n    ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SCATTER_ND_LIKE_KERNELS, (DeviceType::kCUDA),\n                                 CUDA_ATOMIC_ADD_SUPPORTED_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(\n    REGISTER_TENSOR_GATHER_ND_UPDATE_KERNELS, (DeviceType::kCUDA),\n    ARITHMETIC_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_TENSOR_GATHER_ND_ADD_KERNELS, (DeviceType::kCUDA),\n                                 CUDA_ATOMIC_ADD_SUPPORTED_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\ntemplate<>\nstruct DeviceAdd<DeviceType::kCUDA, float16> {\n  __device__ __forceinline__ static void 
Invoke(const float16* x, float16* y) {\n    cuda::atomic::Add(reinterpret_cast<half*>(y), *(reinterpret_cast<const half*>(x)));\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ND_INDEX_SLICE_FUNCTORS, (DeviceType::kCUDA),\n                                 FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_ND_INDEX_SLICE_KERNELS, (DeviceType::kCUDA),\n                                 FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n#if defined(__CUDA_BF16_TYPES_EXIST__)\ntemplate<>\nstruct DeviceAdd<DeviceType::kCUDA, bfloat16> {\n  __device__ __forceinline__ static void Invoke(const bfloat16* x, bfloat16* y) {\n    cuda::atomic::Add(reinterpret_cast<nv_bfloat16*>(y),\n                      *(reinterpret_cast<const nv_bfloat16*>(x)));\n  }\n};\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ND_INDEX_SLICE_FUNCTORS, (DeviceType::kCUDA),\n                                 BFLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_ND_INDEX_SLICE_KERNELS, (DeviceType::kCUDA),\n                                 BFLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n#endif\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/nd_index_slice_kernels.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_ND_INDEX_SLICE_KERNELS_H_\n#define ONEFLOW_USER_KERNELS_ND_INDEX_SLICE_KERNELS_H_\n\n#include \"oneflow/user/kernels/nd_index_slice_util.h\"\n#include \"oneflow/core/common/tensor_meta.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, typename I>\nclass GatherNdKernel final : public user_op::OpKernel {\n public:\n  GatherNdKernel() = default;\n  ~GatherNdKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T, typename I>\nclass ScatterNdKernel final : public user_op::OpKernel {\n public:\n  ScatterNdKernel() = default;\n  ~ScatterNdKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T, typename I>\nclass TensorScatterNdUpdateKernel final : public user_op::OpKernel {\n public:\n  TensorScatterNdUpdateKernel() = default;\n  ~TensorScatterNdUpdateKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T, typename I>\nclass TensorScatterNdAddKernel final : public user_op::OpKernel {\n public:\n  TensorScatterNdAddKernel() = default;\n  ~TensorScatterNdAddKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T, typename I>\nvoid GatherNdKernel<device_type, T, I>::Compute(user_op::KernelComputeContext* ctx) const {\n  const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n  const user_op::Tensor* params = ctx->Tensor4ArgNameAndIndex(\"params\", 0);\n  user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n  if (params->shape_view().elem_cnt() == 0 || indices->shape_view().elem_cnt() == 0) { return; }\n  auto args = ConstructNdIndexSliceArgs(*params, *out, *indices);\n  GatherNdFunctor<device_type, T, I>()(ctx->stream(), args, indices->dptr<I>(), params->dptr<T>(),\n                                       out->mut_dptr<T>());\n}\n\ntemplate<DeviceType device_type, typename T, typename I>\nvoid ScatterNdKernel<device_type, T, I>::Compute(user_op::KernelComputeContext* ctx) const {\n  const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n  const user_op::Tensor* updates = ctx->Tensor4ArgNameAndIndex(\"updates\", 0);\n  user_op::Tensor* out = 
ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n  size_t out_bytes_size = out->shape_view().elem_cnt() * GetSizeOfDataType(out->data_type());\n  Memset<device_type>(ctx->stream(), out->mut_dptr<T>(), 0, out_bytes_size);\n  if (indices->shape_view().elem_cnt() == 0) { return; }\n  auto args = ConstructNdIndexSliceArgs(*out, *updates, *indices);\n  ScatterNdAddFunctor<device_type, T, I>()(ctx->stream(), args, indices->dptr<I>(),\n                                           updates->dptr<T>(), out->mut_dptr<T>());\n}\n\ntemplate<DeviceType device_type, typename T, typename I>\nvoid TensorScatterNdUpdateKernel<device_type, T, I>::Compute(\n    user_op::KernelComputeContext* ctx) const {\n  const user_op::Tensor* params = ctx->Tensor4ArgNameAndIndex(\"params\", 0);\n  const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n  const user_op::Tensor* updates = ctx->Tensor4ArgNameAndIndex(\"updates\", 0);\n  user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n  size_t out_bytes_size = out->shape_view().elem_cnt() * GetSizeOfDataType(out->data_type());\n  Memcpy<device_type>(ctx->stream(), out->mut_dptr<T>(), params->dptr<T>(), out_bytes_size);\n  if (indices->shape_view().elem_cnt() == 0) { return; }\n  auto args = ConstructNdIndexSliceArgs(*params, *updates, *indices);\n  if (one::IsContiguous(params->shape_view(), params->stride())\n      && one::IsContiguous(updates->shape_view(), updates->stride())) {\n    ScatterNdUpdateFunctor<device_type, T, I>()(ctx->stream(), args, indices->dptr<I>(),\n                                                updates->dptr<T>(), out->mut_dptr<T>());\n  } else {\n    ScatterNdUpdateWithStrideFunctor<device_type, T, I>()(ctx->stream(), args, indices->dptr<I>(),\n                                                          updates->dptr<T>(), out->mut_dptr<T>());\n  }\n}\n\ntemplate<DeviceType device_type, typename T, typename I>\nvoid TensorScatterNdAddKernel<device_type, T, I>::Compute(\n    user_op::KernelComputeContext* ctx) const {\n  const user_op::Tensor* params = ctx->Tensor4ArgNameAndIndex(\"params\", 0);\n  const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n  const user_op::Tensor* updates = ctx->Tensor4ArgNameAndIndex(\"updates\", 0);\n  user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n  size_t out_bytes_size = out->shape_view().elem_cnt() * GetSizeOfDataType(out->data_type());\n  Memcpy<device_type>(ctx->stream(), out->mut_dptr<T>(), params->dptr<T>(), out_bytes_size);\n  if (indices->shape_view().elem_cnt() == 0) { return; }\n  auto args = ConstructNdIndexSliceArgs(*params, *updates, *indices);\n  ScatterNdAddFunctor<device_type, T, I>()(ctx->stream(), args, indices->dptr<I>(),\n                                           updates->dptr<T>(), out->mut_dptr<T>());\n}\n\n#define REGISTER_GATHER_SCATTER_ND_KERNELS(op_type_name, op, device_type_v, dtype_pair,            \\\n                                           itype_pair)                                             \\\n  REGISTER_USER_KERNEL(#op_type_name)                                                              \\\n      .SetCreateFn<                                                                                \\\n          op##Kernel<device_type_v, OF_PP_PAIR_FIRST(dtype_pair), OF_PP_PAIR_FIRST(itype_pair)>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type_v)                                 \\\n                       && (user_op::HobDataType(\"indices\", 0) == OF_PP_PAIR_SECOND(itype_pair))    \\\n            
           && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(dtype_pair)));\n\n#define REGISTER_TENSOR_SCATTER_ND_OPT_KERNELS(op_type_name, opt, device_type_v, dtype_pair,    \\\n                                               itype_pair)                                      \\\n  REGISTER_USER_KERNEL(#op_type_name)                                                           \\\n      .SetCreateFn<TensorScatterNd##opt##Kernel<device_type_v, OF_PP_PAIR_FIRST(dtype_pair),    \\\n                                                OF_PP_PAIR_FIRST(itype_pair)>>()                \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type_v)                              \\\n                       && (user_op::HobDataType(\"indices\", 0) == OF_PP_PAIR_SECOND(itype_pair)) \\\n                       && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(dtype_pair)))    \\\n      .SetInplaceProposalFn(                                                                    \\\n          [](const user_op::InferContext&,                                                      \\\n             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {            \\\n            OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"params\", 0, true));               \\\n            return Maybe<void>::Ok();                                                           \\\n          });\n\n#define REGISTER_GATHER_ND_KERNELS(device_type_v, dtype_pair, itype_pair) \\\n  REGISTER_GATHER_SCATTER_ND_KERNELS(gather_nd, GatherNd, device_type_v, dtype_pair, itype_pair)\n\n#define REGISTER_SCATTER_ND_KERNELS(device_type_v, dtype_pair, itype_pair) \\\n  REGISTER_GATHER_SCATTER_ND_KERNELS(scatter_nd, ScatterNd, device_type_v, dtype_pair, itype_pair)\n\n#define REGISTER_SCATTER_ND_LIKE_KERNELS(device_type_v, dtype_pair, itype_pair)             \\\n  REGISTER_GATHER_SCATTER_ND_KERNELS(scatter_nd_like, ScatterNd, device_type_v, dtype_pair, \\\n                                     itype_pair)\n\n#define REGISTER_TENSOR_GATHER_ND_UPDATE_KERNELS(device_type_v, dtype_pair, itype_pair)   \\\n  REGISTER_TENSOR_SCATTER_ND_OPT_KERNELS(tensor_scatter_nd_update, Update, device_type_v, \\\n                                         dtype_pair, itype_pair)\n\n#define REGISTER_TENSOR_GATHER_ND_ADD_KERNELS(device_type_v, dtype_pair, itype_pair)            \\\n  REGISTER_TENSOR_SCATTER_ND_OPT_KERNELS(tensor_scatter_nd_add, Add, device_type_v, dtype_pair, \\\n                                         itype_pair)\n\n#define REGISTER_ND_INDEX_SLICE_KERNELS(device_type_v, dtype_pair, itype_pair)    \\\n  REGISTER_GATHER_ND_KERNELS(device_type_v, dtype_pair, itype_pair)               \\\n  REGISTER_SCATTER_ND_KERNELS(device_type_v, dtype_pair, itype_pair)              \\\n  REGISTER_SCATTER_ND_LIKE_KERNELS(device_type_v, dtype_pair, itype_pair)         \\\n  REGISTER_TENSOR_GATHER_ND_UPDATE_KERNELS(device_type_v, dtype_pair, itype_pair) \\\n  REGISTER_TENSOR_GATHER_ND_ADD_KERNELS(device_type_v, dtype_pair, itype_pair)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_ND_INDEX_SLICE_KERNELS_H_\n"
  },
  {
    "path": "oneflow/user/kernels/nd_index_slice_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_ND_INDEX_SLICE_UTIL_H_\n#define ONEFLOW_USER_KERNELS_ND_INDEX_SLICE_UTIL_H_\n\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/xpu_util.h\"\n\nnamespace oneflow {\n\nstruct NdIndexSliceArgs {\n  static const size_t kMaxDims = 8;\n  int64_t num_slices;   // The number of slices (indices_shape.Count(0, -1))\n  int64_t slice_size;   // The element_cnt of each slice (sliced_shape.Count(indices_num_axes-1))\n  int64_t index_ndims;  // The number of dims which are sliced (indices_shape.At(-1))\n  int64_t dense_ndims;\n  int64_t dense_shape[kMaxDims];\n  int64_t dense_stride[kMaxDims];\n  int64_t slices_ndims;\n  int64_t slices_shape[kMaxDims];\n  int64_t slices_stride[kMaxDims];\n};\n\ninline NdIndexSliceArgs ConstructNdIndexSliceArgs(const user_op::Tensor& dense,\n                                                  const user_op::Tensor& slices,\n                                                  const user_op::Tensor& indices) {\n  NdIndexSliceArgs args;\n  std::memset(&args, 0, sizeof(NdIndexSliceArgs));\n  args.num_slices = indices.shape_view().Count(0, indices.shape_view().NumAxes() - 1);\n  args.index_ndims = indices.shape_view().At(indices.shape_view().NumAxes() - 1);\n  args.slice_size = slices.shape_view().Count(indices.shape_view().NumAxes() - 1);\n\n  args.dense_ndims = dense.shape_view().NumAxes();\n  FOR_RANGE(int64_t, i, 0, dense.shape_view().NumAxes()) {\n    args.dense_shape[i] = dense.shape_view().At(i);\n    args.dense_stride[i] = dense.stride().at(i);\n  }\n  args.slices_ndims = slices.shape_view().NumAxes();\n  FOR_RANGE(int64_t, i, 0, slices.stride().size()) {\n    args.slices_shape[i] = slices.shape_view().At(i);\n    args.slices_stride[i] = slices.stride().at(i);\n  }\n  return args;\n}\n\ntemplate<DeviceType device_type, typename T, typename I>\nstruct GatherNdFunctor final {\n  void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices,\n                  const T* dense, T* slices) const;\n};\n\ntemplate<DeviceType device_type, typename T, typename I>\nstruct ScatterNdAddFunctor final {\n  void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices,\n                  const T* slices, T* dense) const;\n};\n\ntemplate<DeviceType device_type, typename T, typename I>\nstruct ScatterNdUpdateFunctor final {\n  void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices,\n                  const T* slices, T* dense) const;\n};\n\ntemplate<DeviceType device_type, typename T, typename I>\nstruct ScatterNdUpdateWithStrideFunctor final {\n  void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices,\n                  const T* slices, T* dense) const;\n};\n\ntemplate<DeviceType device_type, typename T, typename I>\nstruct FillByNdIndexFunctor final {\n  void 
operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, T* dense,\n                  T value) const;\n};\n\ntemplate<typename I>\nOF_DEVICE_FUNC int64_t OffsetInSliceToOffsetInDense(int64_t slice_size, int64_t index_ndims,\n                                                    const int64_t* dense_shape, const I* indices,\n                                                    int64_t n) {\n  int64_t slice_idx = n / slice_size;\n  const I* nd_index = indices + slice_idx * index_ndims;\n  int64_t offset = 0;\n  int64_t product = 1;\n  int64_t shifted_index = 0;\n  for (int64_t i = index_ndims - 1; i >= 0; --i) {\n#if defined(__CUDACC__)\n    assert(nd_index[i] < dense_shape[i] && nd_index[i] >= -dense_shape[i] && \"index out of bounds\");\n#else\n    CHECK(nd_index[i] < dense_shape[i] && nd_index[i] >= -dense_shape[i])\n        << \"IndexError: index \" << nd_index[i] << \" is out of bounds for dimension \" << i\n        << \" with size \" << dense_shape[i];\n#endif\n    shifted_index = nd_index[i] < 0 && nd_index[i] >= -dense_shape[i] ? nd_index[i] + dense_shape[i]\n                                                                      : nd_index[i];\n    offset += shifted_index * product;\n    product *= dense_shape[i];\n  }\n  return offset * slice_size + n % slice_size;\n}\n\nOF_DEVICE_FUNC int64_t GetMemoryOffset4ElementIdx(int64_t n, int64_t ndims, const int64_t* shape,\n                                                  const int64_t* stride) {\n  int64_t offset = 0;\n  for (int64_t i = ndims - 1; i >= 0; --i) {\n    offset += n % shape[i] * stride[i];\n    n /= shape[i];\n  }\n  return offset;\n}\n\ntemplate<typename T, typename I>\nOF_DEVICE_FUNC void DoGatherNd(int64_t elem_cnt, int64_t slice_size, int64_t index_ndims,\n                               const int64_t* dense_shape, const I* indices, const T* dense,\n                               T* slices) {\n  XPU_1D_KERNEL_LOOP(i, elem_cnt) {\n    int64_t offset = OffsetInSliceToOffsetInDense(slice_size, index_ndims, dense_shape, indices, i);\n    slices[i] = dense[offset];\n  }\n}\n\ntemplate<DeviceType device_type, typename T>\nstruct DeviceAdd {\n  OF_DEVICE_FUNC static void Invoke(const T* x, T* y) { *y += *x; }\n};\n\ntemplate<DeviceType device_type, typename T, typename I>\nOF_DEVICE_FUNC void DoScatterNdAdd(int64_t elem_cnt, int64_t slice_size, int64_t index_ndims,\n                                   const int64_t* dense_shape, const I* indices, const T* slices,\n                                   T* dense) {\n  XPU_1D_KERNEL_LOOP(i, elem_cnt) {\n    int64_t offset = OffsetInSliceToOffsetInDense(slice_size, index_ndims, dense_shape, indices, i);\n    DeviceAdd<device_type, T>::Invoke(slices + i, dense + offset);\n  }\n}\n\ntemplate<DeviceType device_type, typename T, typename I>\nOF_DEVICE_FUNC void DoScatterNdUpdate(int64_t elem_cnt, int64_t slice_size, int64_t index_ndims,\n                                      const int64_t* dense_shape, const I* indices, const T* slices,\n                                      T* dense) {\n  XPU_1D_KERNEL_LOOP(i, elem_cnt) {\n    int64_t offset = OffsetInSliceToOffsetInDense(slice_size, index_ndims, dense_shape, indices, i);\n    dense[offset] = slices[i];\n  }\n}\n\ntemplate<DeviceType device_type, typename T, typename I>\nOF_DEVICE_FUNC void DoScatterNdUpdateWithStride(int64_t elem_cnt, const NdIndexSliceArgs& args,\n                                                const I* indices, const T* slices, T* dense) {\n  XPU_1D_KERNEL_LOOP(i, elem_cnt) {\n    // dense tensor 
memory offset\n    int64_t dense_index = OffsetInSliceToOffsetInDense(args.slice_size, args.index_ndims,\n                                                       args.dense_shape, indices, i);\n    int64_t dense_mem_offset = GetMemoryOffset4ElementIdx(dense_index, args.dense_ndims,\n                                                          args.dense_shape, args.dense_stride);\n    // update tensor memory offset\n    int64_t slice_mem_offset =\n        GetMemoryOffset4ElementIdx(i, args.slices_ndims, args.slices_shape, args.slices_stride);\n    dense[dense_mem_offset] = slices[slice_mem_offset];\n  }\n}\n\ntemplate<typename T, typename I>\nOF_DEVICE_FUNC void DoFillByNdIndex(int64_t elem_cnt, int64_t slice_size, int64_t index_ndims,\n                                    const int64_t* dense_shape, const I* indices, T* dense,\n                                    T value) {\n  XPU_1D_KERNEL_LOOP(i, elem_cnt) {\n    int64_t offset = OffsetInSliceToOffsetInDense(slice_size, index_ndims, dense_shape, indices, i);\n    dense[offset] = value;\n  }\n}\n\n#define INSTANTIATE_GATHER_ND_FUNCTOR(device_type_v, dtype_pair, itype_pair)   \\\n  template struct GatherNdFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair), \\\n                                  OF_PP_PAIR_FIRST(itype_pair)>;\n\n#define INSTANTIATE_SCATTER_ND_ADD_FUNCTOR(device_type_v, dtype_pair, itype_pair)  \\\n  template struct ScatterNdAddFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair), \\\n                                      OF_PP_PAIR_FIRST(itype_pair)>;\n\n#define INSTANTIATE_FILL_BY_ND_INDEX_FUNCTOR(device_type_v, dtype_pair, itype_pair) \\\n  template struct FillByNdIndexFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair), \\\n                                       OF_PP_PAIR_FIRST(itype_pair)>;\n\n#define INSTANTIATE_ND_INDEX_SLICE_FUNCTORS(device_type_v, dtype_pair, itype_pair) \\\n  INSTANTIATE_GATHER_ND_FUNCTOR(device_type_v, dtype_pair, itype_pair)             \\\n  INSTANTIATE_SCATTER_ND_ADD_FUNCTOR(device_type_v, dtype_pair, itype_pair)        \\\n  INSTANTIATE_FILL_BY_ND_INDEX_FUNCTOR(device_type_v, dtype_pair, itype_pair)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_ND_INDEX_SLICE_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/nll_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/user/kernels/nll_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass NLLKernelCache final : public user_op::OpKernelCache {\n public:\n  NLLKernelCache(int64_t class_start, int64_t num_classes)\n      : class_start_(class_start), num_classes_(num_classes) {}\n  ~NLLKernelCache() override = default;\n\n  int64_t class_start() const { return class_start_; }\n  int64_t num_classes() const { return num_classes_; }\n\n private:\n  const int64_t class_start_;\n  const int64_t num_classes_;\n};\n\nstd::shared_ptr<user_op::OpKernelCache> CreateNLLKernelCache(user_op::KernelCacheContext* ctx) {\n  CHECK_GT(ctx->parallel_ctx().parallel_num(), 0) << ctx->op_name() << \": invalid parallel_ctx\";\n  if (ctx->parallel_ctx().parallel_num() == 1) { return nullptr; }\n\n  const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"input\", 0);\n  const Shape& hierarchy = *ctx->parallel_desc().hierarchy();\n  CHECK_EQ(nd_sbp.sbp_parallel_size(), hierarchy.NumAxes())\n      << ctx->op_name() << \": Expected input sbp \" << NdSbpToString(nd_sbp) << \" match hierarchy \"\n      << hierarchy.ToString();\n\n  const Shape& shape = ctx->LogicalTensorDesc4ArgNameAndIndex(\"input\", 0)->shape();\n  const int64_t class_axis = shape.NumAxes() - 1;\n\n  bool split_class_dim = false;\n  for (const auto& sbp : nd_sbp.sbp_parallel()) {\n    if (sbp.has_split_parallel() && sbp.split_parallel().axis() == class_axis) {\n      split_class_dim = true;\n      break;\n    }\n  }\n\n  if (!split_class_dim) { return nullptr; }\n\n  TensorSliceView view =\n      GetTensorSliceView4ParallelId(hierarchy, nd_sbp, shape, ctx->parallel_ctx().parallel_id());\n  return std::make_shared<NLLKernelCache>(view.At(class_axis).begin(), view.At(class_axis).size());\n}\n\n}  // namespace\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass NLLKernel final : public user_op::OpKernel {\n public:\n  NLLKernel() = default;\n  ~NLLKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateNLLKernelCache(ctx);\n  }\n\n private:\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache* cache) const override {\n    const auto* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    const auto* target = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    auto* output = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    auto* out_weight = ctx->Tensor4ArgNameAndIndex(\"out_weight\", 0);\n\n    const int64_t N = target->shape_view().elem_cnt();\n    const int64_t C = input->shape_view().At(input->shape_view().NumAxes() - 1);\n 
    CHECK_LE(N, std::numeric_limits<int32_t>::max())\n        << \"Expected batch size not to exceed int32 numeric limits\";\n\n    K class_start = 0;\n    if (cache) {\n      const auto* spec_cache = dynamic_cast<const NLLKernelCache*>(cache);\n      CHECK_NOTNULL(spec_cache);\n      CHECK_EQ(spec_cache->num_classes(), C) << ctx->op_name() << \": expected num_classes \" << C\n                                             << \", got \" << spec_cache->num_classes();\n      class_start = spec_cache->class_start();\n    }\n\n    const K ignore_index = static_cast<K>(ctx->Attr<int64_t>(\"ignore_index\"));\n\n    const T* weight_dptr = nullptr;\n    if (ctx->has_input(\"weight\", 0)) {\n      weight_dptr = CHECK_NOTNULL(ctx->Tensor4ArgNameAndIndex(\"weight\", 0))->dptr<T>();\n    }\n\n    NLLKernelUtil<device_type, T, K>::Forward(ctx->stream(), static_cast<int32_t>(N),\n                                              static_cast<K>(C), class_start, ignore_index,\n                                              input->dptr<T>(), target->dptr<K>(), weight_dptr,\n                                              output->mut_dptr<T>(), out_weight->mut_dptr<T>());\n  }\n};\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass NLLGradKernel final : public user_op::OpKernel {\n public:\n  NLLGradKernel() = default;\n  ~NLLGradKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return CreateNLLKernelCache(ctx);\n  }\n\n private:\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache* cache) const override {\n    const auto* target = ctx->Tensor4ArgNameAndIndex(\"target\", 0);\n    const auto* out_grad = ctx->Tensor4ArgNameAndIndex(\"out_grad\", 0);\n    auto* in_grad = ctx->Tensor4ArgNameAndIndex(\"in_grad\", 0);\n\n    const int64_t N = target->shape_view().elem_cnt();\n    const int64_t C = in_grad->shape_view().At(in_grad->shape_view().NumAxes() - 1);\n    CHECK_LE(N, std::numeric_limits<int32_t>::max())\n        << \"Expected batch size not to exceed int32 numeric limits\";\n\n    K class_start = 0;\n    if (cache) {\n      const auto* spec_cache = dynamic_cast<const NLLKernelCache*>(cache);\n      CHECK_NOTNULL(spec_cache);\n      CHECK_EQ(spec_cache->num_classes(), C) << ctx->op_name() << \": expected num_classes \" << C\n                                             << \", got \" << spec_cache->num_classes();\n      class_start = spec_cache->class_start();\n    }\n\n    const K ignore_index = static_cast<K>(ctx->Attr<int64_t>(\"ignore_index\"));\n\n    const T* weight_dptr = nullptr;\n    if (ctx->has_input(\"weight\", 0)) {\n      weight_dptr = CHECK_NOTNULL(ctx->Tensor4ArgNameAndIndex(\"weight\", 0))->dptr<T>();\n    }\n\n    NLLKernelUtil<device_type, T, K>::Backward(\n        ctx->stream(), static_cast<int32_t>(N), static_cast<K>(C), class_start, ignore_index,\n        out_grad->dptr<T>(), target->dptr<K>(), weight_dptr, in_grad->mut_dptr<T>());\n  }\n};\n\n#define REGISTER_NLL_KERNELS(device, dtype, ltype)                                           \\\n  REGISTER_USER_KERNEL(\"nll\").SetCreateFn<NLLKernel<device, dtype, ltype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == device)                                                   \\\n      && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)                      
\\\n      && (user_op::HobDataType(\"target\", 0) == GetDataType<ltype>::value));                   \\\n  REGISTER_USER_KERNEL(\"nll_grad\")                                                            \\\n      .SetCreateFn<NLLGradKernel<device, dtype, ltype>>()                                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                   \\\n                       && (user_op::HobDataType(\"input\", 0) == GetDataType<dtype>::value)     \\\n                       && (user_op::HobDataType(\"target\", 0) == GetDataType<ltype>::value)    \\\n                       && (user_op::HobDataType(\"out_grad\", 0) == GetDataType<dtype>::value))\n\nREGISTER_NLL_KERNELS(DeviceType::kCPU, float, int32_t);\nREGISTER_NLL_KERNELS(DeviceType::kCPU, float, int64_t);\nREGISTER_NLL_KERNELS(DeviceType::kCPU, double, int32_t);\nREGISTER_NLL_KERNELS(DeviceType::kCPU, double, int64_t);\n\n#ifdef WITH_CUDA\n\nREGISTER_NLL_KERNELS(DeviceType::kCUDA, float, int32_t);\nREGISTER_NLL_KERNELS(DeviceType::kCUDA, float, int64_t);\nREGISTER_NLL_KERNELS(DeviceType::kCUDA, double, int32_t);\nREGISTER_NLL_KERNELS(DeviceType::kCUDA, double, int64_t);\nREGISTER_NLL_KERNELS(DeviceType::kCUDA, half, int32_t);\nREGISTER_NLL_KERNELS(DeviceType::kCUDA, half, int64_t);\n\n#endif  // WITH_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/nll_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/nll_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename K>\nstruct NLLKernelUtil<DeviceType::kCPU, T, K> {\n  static void Forward(ep::Stream* stream, const int32_t num_samples, const K num_classes,\n                      const K class_start, const K ignore_index, const T* input, const K* target,\n                      const T* weight, T* out, T* out_weight) {\n    FOR_RANGE(int32_t, i, 0, num_samples) {\n      K label = target[i];\n      T w = T{0};\n      T y = T{0};\n      if (label != ignore_index) {\n        label -= class_start;\n        if (label >= 0 && label < num_classes) {\n          w = weight ? weight[label] : T{1};\n          y = -(input[i * num_classes + label] * w);\n        }\n      }\n      out[i] = y;\n      out_weight[i] = w;\n    }\n  }\n\n  static void Backward(ep::Stream* stream, const int32_t num_samples, const K num_classes,\n                       const K class_start, const K ignore_index, const T* out_grad,\n                       const K* target, const T* weight, T* in_grad) {\n    Memset<DeviceType::kCPU>(stream, in_grad, 0,\n                             RoundUp(num_samples * num_classes * sizeof(T), kBlobBodyAlignSize));\n    FOR_RANGE(int32_t, i, 0, num_samples) {\n      K label = target[i];\n      if (label == ignore_index) { continue; }\n      label -= class_start;\n      if (label >= 0 && label < num_classes) {\n        const T w = weight ? -weight[label] : T(-1);\n        in_grad[i * num_classes + label] = out_grad[i] * w;\n      }\n    }\n  }\n};\n\ntemplate struct NLLKernelUtil<DeviceType::kCPU, float, int32_t>;\ntemplate struct NLLKernelUtil<DeviceType::kCPU, float, int64_t>;\ntemplate struct NLLKernelUtil<DeviceType::kCPU, double, int32_t>;\ntemplate struct NLLKernelUtil<DeviceType::kCPU, double, int64_t>;\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/nll_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/nll_kernel_util.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename K>\n__global__ void NLLForward(const int32_t num_samples, const K num_classes, const K class_start,\n                           const K ignore_index, const T* input, const K* target, const T* weight,\n                           T* out, T* out_weight) {\n  const T zero = GetZeroVal<T>();\n  const T one = GetOneVal<T>();\n  CUDA_1D_KERNEL_LOOP(i, num_samples) {\n    K label = target[i];\n    T w = zero;\n    T y = zero;\n    if (label != ignore_index) {\n      label -= class_start;\n      if (label >= 0 && label < num_classes) {\n        w = weight ? weight[label] : one;\n        y = -(input[i * num_classes + label] * w);\n      }\n    }\n    out[i] = y;\n    out_weight[i] = w;\n  }\n}\n\ntemplate<typename T, typename K>\n__global__ void NLLBackward(const int32_t num_samples, const K num_classes, const K class_start,\n                            const K ignore_index, const T* out_grad, const K* target,\n                            const T* weight, T* in_grad) {\n  const T one = GetOneVal<T>();\n  const T zero = GetZeroVal<T>();\n  CUDA_1D_KERNEL_LOOP_T(K, i, num_samples * num_classes) {\n    const K n = i / num_classes;\n    const K idx = i - n * num_classes;\n    const K label = target[n];\n    if (label != ignore_index && idx == label - class_start) {\n      in_grad[i] = out_grad[n] * (weight ? 
-weight[idx] : -one);\n    } else {\n      in_grad[i] = zero;\n    }\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename K>\nstruct NLLKernelUtil<DeviceType::kCUDA, T, K> {\n  static void Forward(ep::Stream* stream, const int32_t num_samples, const K num_classes,\n                      const K class_start, const K ignore_index, const T* input, const K* target,\n                      const T* weight, T* out, T* out_weight) {\n    NLLForward<<<BlocksNum4ThreadsNum(num_samples), kCudaThreadsNumPerBlock, 0,\n                 stream->As<ep::CudaStream>()->cuda_stream()>>>(num_samples, num_classes,\n                                                                class_start, ignore_index, input,\n                                                                target, weight, out, out_weight);\n  }\n\n  static void Backward(ep::Stream* stream, const int32_t num_samples, const K num_classes,\n                       const K class_start, const K ignore_index, const T* out_grad,\n                       const K* target, const T* weight, T* in_grad) {\n    // NLLBackward is elementwise over num_samples * num_classes, so size the grid by the\n    // full element count rather than by num_samples alone.\n    NLLBackward<<<BlocksNum4ThreadsNum(num_samples * num_classes), kCudaThreadsNumPerBlock, 0,\n                  stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        num_samples, num_classes, class_start, ignore_index, out_grad, target, weight, in_grad);\n  }\n};\n\ntemplate struct NLLKernelUtil<DeviceType::kCUDA, float, int32_t>;\ntemplate struct NLLKernelUtil<DeviceType::kCUDA, float, int64_t>;\ntemplate struct NLLKernelUtil<DeviceType::kCUDA, double, int32_t>;\ntemplate struct NLLKernelUtil<DeviceType::kCUDA, double, int64_t>;\ntemplate struct NLLKernelUtil<DeviceType::kCUDA, half, int32_t>;\ntemplate struct NLLKernelUtil<DeviceType::kCUDA, half, int64_t>;\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/nll_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_NLL_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_NLL_KERNEL_UTIL_H_\n\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, typename K>\nstruct NLLKernelUtil {\n  static void Forward(ep::Stream* stream, const int32_t num_samples, const K num_classes,\n                      const K class_start, const K ignore_index, const T* input, const K* target,\n                      const T* weight, T* out, T* out_weight);\n\n  static void Backward(ep::Stream* stream, const int32_t num_samples, const K num_classes,\n                       const K class_start, const K ignore_index, const T* out_grad,\n                       const K* target, const T* weight, T* in_grad);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_NLL_KERNEL_UTIL_H_\n"
  },
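  {
    "path": "oneflow/user/kernels/nll_kernel_util_sketch.cpp",
    "content": "// Editor's illustrative sketch; this file name is hypothetical and the file is not\n// part of the OneFlow build. It mirrors the per-sample math of\n// NLLKernelUtil<DeviceType::kCPU, T, K>::Forward on plain arrays so the semantics can\n// be checked in isolation: out[i] = -input[i * num_classes + label] * w and\n// out_weight[i] = w, where label = target[i] - class_start and w = weight[label] (or 1);\n// both outputs are zero when target[i] == ignore_index or label is out of range.\n#include <cstdint>\n#include <cstdio>\n#include <vector>\n\ntemplate<typename T, typename K>\nvoid NllForwardReference(int32_t num_samples, K num_classes, K class_start, K ignore_index,\n                         const T* input, const K* target, const T* weight, T* out,\n                         T* out_weight) {\n  for (int32_t i = 0; i < num_samples; ++i) {\n    K label = target[i];\n    T w = T{0};\n    T y = T{0};\n    if (label != ignore_index) {\n      label -= class_start;  // class_start is nonzero when classes are split across ranks\n      if (label >= 0 && label < num_classes) {\n        w = weight ? weight[label] : T{1};\n        y = -(input[i * num_classes + label] * w);\n      }\n    }\n    out[i] = y;\n    out_weight[i] = w;\n  }\n}\n\nint main() {\n  // 2 samples, 3 classes, no class split, ignore_index = -100, no class weights.\n  const std::vector<float> input = {0.1f, 0.2f, 0.7f, 0.5f, 0.3f, 0.2f};\n  const std::vector<int64_t> target = {2, -100};\n  std::vector<float> out(2), out_weight(2);\n  NllForwardReference<float, int64_t>(2, 3, 0, -100, input.data(), target.data(), nullptr,\n                                      out.data(), out_weight.data());\n  // Expected: out = {-0.7, 0}, out_weight = {1, 0}.\n  for (int i = 0; i < 2; ++i) { std::printf(\"%f %f\\n\", out[i], out_weight[i]); }\n  return 0;\n}\n"
  },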
  {
    "path": "oneflow/user/kernels/nms_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__inline__ T IoU(T const* const a, T const* const b) {\n  T interS = std::max(std::min(a[2], b[2]) - std::max(a[0], b[0]), static_cast<T>(0.f))\n             * std::max(std::min(a[3], b[3]) - std::max(a[1], b[1]), static_cast<T>(0.f));\n  T Sa = (a[2] - a[0]) * (a[3] - a[1]);\n  T Sb = (b[2] - b[0]) * (b[3] - b[1]);\n  return interS / (Sa + Sb - interS);\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass NmsCpuKernel final : public user_op::OpKernel {\n public:\n  NmsCpuKernel() = default;\n  ~NmsCpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* boxes_blob = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* keep_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const T* boxes = boxes_blob->dptr<T>();\n    int8_t* keep = keep_blob->mut_dptr<int8_t>();\n\n    const int num_boxes = boxes_blob->shape_view().At(0);\n    int num_keep = ctx->Attr<int>(\"keep_n\");\n    if (num_keep <= 0 || num_keep > num_boxes) { num_keep = num_boxes; }\n    const float iou_threshold = ctx->Attr<float>(\"iou_threshold\");\n    for (int i = 0; i < num_boxes; i++) { keep[i] = -1; }\n    for (int i = 0; i < num_boxes; i++) {\n      if (keep[i] == 0) continue;\n      keep[i] = 1;\n      for (int j = i + 1; j < num_boxes; j++) {\n        if (IoU(boxes + i * 4, boxes + j * 4) > iou_threshold) { keep[j] = 0; }\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_NMS_CPU_KERNEL(dtype)                                            \\\n  REGISTER_USER_KERNEL(\"nms\").SetCreateFn<NmsCpuKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCPU)                              \\\n      && (user_op::HobDataType(\"out\", 0) == DataType::kInt8)                      \\\n      && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value));\n\nREGISTER_NMS_CPU_KERNEL(float)\nREGISTER_NMS_CPU_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/nms_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr int kBlockSize = sizeof(int64_t) * 8;\n\ntemplate<typename T>\n__host__ __device__ __forceinline__ T CeilDiv(T a, T b) {\n  return (a + b - 1) / b;\n}\n\ntemplate<typename T>\n__host__ __device__ __forceinline__ T IoU(T const* const a, T const* const b) {\n  T interS =\n      max(min(a[2], b[2]) - max(a[0], b[0]), 0.f) * max(min(a[3], b[3]) - max(a[1], b[1]), 0.f);\n  T Sa = (a[2] - a[0]) * (a[3] - a[1]);\n  T Sb = (b[2] - b[0]) * (b[3] - b[1]);\n  return interS / (Sa + Sb - interS);\n}\n\ntemplate<typename T>\n__global__ void CalcSuppressionBitmaskMatrix(int num_boxes, float iou_threshold, const T* boxes,\n                                             int64_t* suppression_bmask_matrix) {\n  const int row = blockIdx.y;\n  const int col = blockIdx.x;\n\n  if (row > col) return;\n\n  const int row_size = min(num_boxes - row * kBlockSize, kBlockSize);\n  const int col_size = min(num_boxes - col * kBlockSize, kBlockSize);\n\n  __shared__ T block_boxes[kBlockSize * 4];\n  if (threadIdx.x < col_size) {\n    block_boxes[threadIdx.x * 4 + 0] = boxes[(kBlockSize * col + threadIdx.x) * 4 + 0];\n    block_boxes[threadIdx.x * 4 + 1] = boxes[(kBlockSize * col + threadIdx.x) * 4 + 1];\n    block_boxes[threadIdx.x * 4 + 2] = boxes[(kBlockSize * col + threadIdx.x) * 4 + 2];\n    block_boxes[threadIdx.x * 4 + 3] = boxes[(kBlockSize * col + threadIdx.x) * 4 + 3];\n  }\n  __syncthreads();\n\n  if (threadIdx.x < row_size) {\n    const int cur_box_idx = kBlockSize * row + threadIdx.x;\n    const T* cur_box_ptr = boxes + cur_box_idx * 4;\n    unsigned long long bits = 0;\n    int start = 0;\n    if (row == col) { start = threadIdx.x + 1; }\n    for (int i = start; i < col_size; i++) {\n      if (IoU(cur_box_ptr, block_boxes + i * 4) > iou_threshold) { bits |= 1Ull << i; }\n    }\n    suppression_bmask_matrix[cur_box_idx * gridDim.y + col] = bits;\n  }\n}\n\n__global__ void ScanSuppression(int num_boxes, int num_blocks, int num_keep,\n                                int64_t* suppression_bmask, int8_t* keep_mask) {\n  extern __shared__ int64_t remv[];\n  remv[threadIdx.x] = 0;\n  for (int i = 0; i < num_boxes; ++i) {\n    int block_n = i / kBlockSize;\n    int block_i = i % kBlockSize;\n    if (!(remv[block_n] & (1Ull << block_i))) {\n      remv[threadIdx.x] |= suppression_bmask[i * num_blocks + threadIdx.x];\n      if (threadIdx.x == block_n && num_keep > 0) {\n        keep_mask[i] = 1;\n        num_keep -= 1;\n      }\n    }\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass NmsGpuKernel final : public user_op::OpKernel {\n public:\n  NmsGpuKernel() = default;\n  ~NmsGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const 
override {\n    const user_op::Tensor* boxes_blob = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* keep_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_blob = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const T* boxes = boxes_blob->dptr<T>();\n    int8_t* keep = keep_blob->mut_dptr<int8_t>();\n    int64_t* suppression_mask = tmp_blob->mut_dptr<int64_t>();\n\n    const int num_boxes = boxes_blob->shape_view().At(0);\n    int num_keep = ctx->Attr<int>(\"keep_n\");\n    if (num_keep <= 0 || num_keep > num_boxes) { num_keep = num_boxes; }\n    const int num_blocks = CeilDiv<int>(num_boxes, kBlockSize);\n    Memset<DeviceType::kCUDA>(ctx->stream(), suppression_mask, 0,\n                              num_boxes * num_blocks * sizeof(int64_t));\n    Memset<DeviceType::kCUDA>(ctx->stream(), keep, 0, num_boxes * sizeof(int8_t));\n\n    dim3 blocks(num_blocks, num_blocks);\n    dim3 threads(kBlockSize);\n    CalcSuppressionBitmaskMatrix<<<blocks, threads, 0,\n                                   ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        num_boxes, ctx->Attr<float>(\"iou_threshold\"), boxes, suppression_mask);\n    ScanSuppression<<<1, num_blocks, num_blocks * sizeof(int64_t),\n                      ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        num_boxes, num_blocks, num_keep, suppression_mask, keep);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_NMS_CUDA_KERNEL(dtype)                                                 \\\n  REGISTER_USER_KERNEL(\"nms\")                                                           \\\n      .SetCreateFn<NmsGpuKernel<dtype>>()                                               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"out\", 0) == DataType::kInt8)           \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                               \\\n        const Shape& in_shape = ctx->Shape4ArgNameAndIndex(\"in\", 0);                    \\\n        int64_t num_boxes = in_shape.At(0);                                             \\\n        int64_t blocks = CeilDiv<int64_t>(num_boxes, kBlockSize);                       \\\n        return num_boxes * blocks * sizeof(int64_t);                                    \\\n      });\n\nREGISTER_NMS_CUDA_KERNEL(float)\nREGISTER_NMS_CUDA_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
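  {
    "path": "oneflow/user/kernels/nms_kernel_sketch.cpp",
    "content": "// Editor's illustrative sketch; this file name is hypothetical and the file is not\n// part of the OneFlow build. It is a serial version of the bitmask scheme used by\n// nms_kernel.cu: row i of the suppression matrix packs, in 64-bit words, which boxes\n// box i suppresses; the scan then walks boxes in score order and ORs the rows of\n// surviving boxes into a running \"removed\" mask (keep_n capping omitted).\n#include <cstdint>\n#include <cstdio>\n#include <vector>\n\nint main() {\n  constexpr int kBits = 64;  // one int64_t word covers 64 boxes, as in the CUDA kernel\n  const int num_boxes = 3;\n  const int num_words = (num_boxes + kBits - 1) / kBits;  // == 1 here\n  // suppression[i * num_words + w]: bit b set means box i suppresses box w * 64 + b.\n  std::vector<uint64_t> suppression(num_boxes * num_words, 0);\n  suppression[0] |= uint64_t{1} << 1;  // pretend IoU(box0, box1) > iou_threshold\n  std::vector<uint64_t> removed(num_words, 0);\n  std::vector<int8_t> keep(num_boxes, 0);\n  for (int i = 0; i < num_boxes; ++i) {\n    if (removed[i / kBits] & (uint64_t{1} << (i % kBits))) { continue; }  // suppressed\n    keep[i] = 1;\n    for (int w = 0; w < num_words; ++w) { removed[w] |= suppression[i * num_words + w]; }\n  }\n  for (int i = 0; i < num_boxes; ++i) { std::printf(\"%d\\n\", keep[i]); }  // expect 1 0 1\n  return 0;\n}\n"
  },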
  {
    "path": "oneflow/user/kernels/noncontiguous_binary_op.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/user_op_tensor.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/include/primitive/fast_integer_math.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n\nnamespace oneflow {\n\nnamespace {\n\n#define MaxDims 6\n#define MAX2(a, b) ((a) > (b)) ? (a) : (b)\n#define MAX3(a, b, c) MAX2(MAX2(a, b), c)\n\nusing cuda::elementwise::Packed;\n\n#define DEFINE_BINARY_FUNCTOR(OP, expr)                                                        \\\n  template<typename T>                                                                         \\\n  struct OP {                                                                                  \\\n    __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a expr b; } \\\n  };                                                                                           \\\n  template<>                                                                                   \\\n  struct OP<half> {                                                                            \\\n    __device__ __forceinline__ half operator()(const half& a, const half& b) const {           \\\n      return __float2half(__half2float(a) expr __half2float(b));                               \\\n    }                                                                                          \\\n  };\n\nDEFINE_BINARY_FUNCTOR(Add, +)\nDEFINE_BINARY_FUNCTOR(Sub, -)\nDEFINE_BINARY_FUNCTOR(Mul, *)\nDEFINE_BINARY_FUNCTOR(Div, /)\n#undef DEFINE_BINARY_FUNCTOR\n\n#define DEFINE_BINARY_OP_GRAD_FUNCTOR(OP, dl_expr, dr_expr)                                       \\\n  template<typename T>                                                                            \\\n  struct OP##Grad {                                                                               \\\n    __device__ __forceinline__ void operator()(const T& dout, const T& a, const T& b, T* da,      \\\n                                               T* db) const {                                     \\\n      *da = dl_expr dout;                                                                         \\\n      *db = dr_expr dout;                                                                         \\\n    }                                                                                             \\\n  };                                                                                              \\\n  template<>                                                                                      \\\n  struct OP##Grad<half> {                                                                         \\\n    __device__ __forceinline__ void operator()(const half& hdout, const half& ha, const half& hb, \\\n           
                                    half* hda, half* hdb) const {                      \\\n      float dout, a, b;                                                                           \\\n      dout = __half2float(hdout), a = __half2float(ha), b = __half2float(hb);                     \\\n      *hda = __float2half(dl_expr dout);                                                          \\\n      *hdb = __float2half(dr_expr dout);                                                          \\\n    }                                                                                             \\\n  };\n\nDEFINE_BINARY_OP_GRAD_FUNCTOR(Add, 1 *, 1 *)\nDEFINE_BINARY_OP_GRAD_FUNCTOR(Sub, 1 *, -1 *)\nDEFINE_BINARY_OP_GRAD_FUNCTOR(Mul, b*, a*)\nDEFINE_BINARY_OP_GRAD_FUNCTOR(Div, 1 / b*, -a / b / b*)\n#undef DEFINE_BINARY_OP_GRAD_FUNCTOR\n\ntemplate<int pack_size, typename IndexType, typename BinaryOp, typename R, typename T1, typename T2,\n         typename Store, typename Loader1, typename Loader2>\n__global__ void noncontiguous_binary_op_kernel(IndexType n_pack, Store y, Loader1 x1, Loader2 x2) {\n  Packed<R, pack_size> pack_y;\n  Packed<T1, pack_size> pack_x1;\n  Packed<T2, pack_size> pack_x2;\n  CUDA_1D_KERNEL_LOOP_T(IndexType, i, n_pack) {\n    x1.load(i, &pack_x1);\n    x2.load(i, &pack_x2);\n#pragma unroll\n    for (int j = 0; j < pack_size; ++j)\n      pack_y.elem[j] = BinaryOp()(static_cast<R>(pack_x1.elem[j]),\n                                  static_cast<R>(pack_x2.elem[j]));  // todo: Apply2\n    y.store(i, &pack_y);\n  }\n};\n\ntemplate<int pack_size, typename IndexType, typename FastIntegerMath, typename Src,\n         typename Dst = void>\nstruct LoadStore {\n  LoadStore(FastIntegerMath fast_integer_math[MaxDims], const int ndims, const int strides[MaxDims],\n            const Src* src, Dst* dst = nullptr, bool is_contiguous = false)\n      : ndims_(ndims), src_(src), dst_(dst), is_contiguous_(is_contiguous) {\n    for (int i = 0; i < ndims; i++) {\n      strides_[i] = static_cast<IndexType>(strides[i]);\n      fast_integer_math_[i] = fast_integer_math[i];\n    }\n  }\n\n  OF_DEVICE_FUNCTION IndexType index2offset(IndexType index) {\n    IndexType offset = 0;\n    IndexType div = 0, mod = 0;\n#pragma unroll\n    for (int dim = ndims_ - 1; dim >= 0; --dim) {\n      if (index == 0) break;\n      fast_integer_math_[dim].divmod(index, &div, &mod);\n      index = div;\n      offset += mod * strides_[dim];\n    }\n    return offset;\n  }\n\n  OF_DEVICE_FUNCTION void load(IndexType idx, Packed<Src, pack_size>* pack) {\n    IndexType offset;\n    if (is_contiguous_)\n      offset = idx * pack_size;\n    else\n      offset = index2offset(idx);\n    *pack = *(reinterpret_cast<const Packed<Src, pack_size>*>(src_ + offset));\n  }\n\n  OF_DEVICE_FUNCTION void store(IndexType idx, Packed<Dst, pack_size>* pack) {\n    IndexType offset;\n    if (is_contiguous_)\n      offset = idx * pack_size;\n    else\n      offset = index2offset(idx);\n    *(reinterpret_cast<Packed<Dst, pack_size>*>(dst_ + offset)) = *pack;\n  }\n\n  int ndims_;\n  int pack_dim_;\n  bool is_contiguous_;\n  const Src* src_;\n  Dst* dst_;\n  IndexType strides_[MaxDims];\n  FastIntegerMath fast_integer_math_[MaxDims];\n};\n\ntemplate<int pack_size, typename IndexType, typename BinaryOp, typename R, typename lhs,\n         typename rhs, typename Store, typename Load1, typename Load2>\nvoid launch_noncontiguous_binary_op_kernel(cudaStream_t stream, const IndexType n_pack,\n                                           Store& store, 
Load1& load1, Load2& load2) {\n  int num_blocks = 1, block_size = cuda::elementwise::kBlockSize;\n  cudaError_t err = cuda::elementwise::GetNumBlocks(n_pack, &num_blocks);\n  CHECK(err == cudaSuccess);\n  noncontiguous_binary_op_kernel<pack_size, IndexType, BinaryOp, R, lhs, rhs>\n      <<<num_blocks, block_size, 0, stream>>>(n_pack, store, load1, load2);\n}\n\ntemplate<int pack_size, typename IndexType, typename R, typename lhs, typename rhs, typename Store,\n         typename Load1, typename Load2>\nvoid dispatchOp(cudaStream_t stream, const std::string& op, const IndexType n_pack, Store& store,\n                Load1& load1, Load2& load2) {\n  if (op == \"add\")\n    launch_noncontiguous_binary_op_kernel<pack_size, IndexType, Add<R>, R, lhs, rhs>(\n        stream, n_pack, store, load1, load2);\n  else if (op == \"sub\")\n    launch_noncontiguous_binary_op_kernel<pack_size, IndexType, Sub<R>, R, lhs, rhs>(\n        stream, n_pack, store, load1, load2);\n  else if (op == \"mul\")\n    launch_noncontiguous_binary_op_kernel<pack_size, IndexType, Mul<R>, R, lhs, rhs>(\n        stream, n_pack, store, load1, load2);\n  else if (op == \"div\")\n    launch_noncontiguous_binary_op_kernel<pack_size, IndexType, Div<R>, R, lhs, rhs>(\n        stream, n_pack, store, load1, load2);\n  else\n    UNIMPLEMENTED_THEN_THROW();\n}\n\ntemplate<int pack_size, typename IndexType, typename R, typename lhs, typename rhs>\nvoid dispatchInplace(cudaStream_t stream, const bool inplace, const std::string& op,\n                     const int& ndims, const IndexType n_pack, const int sizes[MaxDims],\n                     const int strides[][MaxDims], R* y, const lhs* x1, const rhs* x2) {\n  typedef FastIntegerMath<IndexType> FastIntegerMathT;\n  FastIntegerMathT fast_integer_math[MaxDims];\n  for (int i = 0; i < ndims; ++i) fast_integer_math[i] = FastIntegerMathT(sizes[i]);\n  if (inplace) {\n    LoadStore<pack_size, IndexType, FastIntegerMathT, lhs, R> load_store(fast_integer_math, ndims,\n                                                                         strides[0], x1, y);\n    LoadStore<pack_size, IndexType, FastIntegerMathT, rhs> loader2(fast_integer_math, ndims,\n                                                                   strides[2], x2);\n    dispatchOp<pack_size, IndexType, R, lhs, rhs>(stream, op, n_pack, load_store, load_store,\n                                                  loader2);\n  } else {\n    LoadStore<pack_size, IndexType, FastIntegerMathT, lhs, R> store(fast_integer_math, ndims,\n                                                                    strides[0], nullptr, y);\n    LoadStore<pack_size, IndexType, FastIntegerMathT, lhs> loader1(fast_integer_math, ndims,\n                                                                   strides[1], x1);\n\n    LoadStore<pack_size, IndexType, FastIntegerMathT, rhs> loader2(fast_integer_math, ndims,\n                                                                   strides[2], x2);\n    dispatchOp<pack_size, IndexType, R, lhs, rhs>(stream, op, n_pack, store, loader1, loader2);\n  }\n}\n\ntemplate<int pack_size, typename R, typename lhs, typename rhs>\nvoid dispatchIndexType(cudaStream_t stream, const bool inplace, const std::string& op,\n                       const int& ndims, const int64_t& n_pack, const int sizes[MaxDims],\n                       const int strides[][MaxDims], R* y, const lhs* x1, const rhs* x2) {\n  if ((n_pack * pack_size) >> 30 == 0) {\n    int32_t n = (int32_t)n_pack;\n    dispatchInplace<pack_size, int32_t, R, lhs, 
rhs>(stream, inplace, op, ndims, n, sizes, strides,\n                                                     y, x1, x2);\n  } else\n    dispatchInplace<pack_size, int64_t, R, lhs, rhs>(stream, inplace, op, ndims, n_pack, sizes,\n                                                     strides, y, x1, x2);\n}\n\ntemplate<typename R, typename lhs, typename rhs>\nvoid dispatchPacksize(cudaStream_t stream, const bool inplace, const std::string& op,\n                      const int& ndims, const int64_t n_pack, int pack_size,\n                      const int sizes[MaxDims], const int strides[][MaxDims], R* y, const lhs* x1,\n                      const rhs* x2) {\n  if (pack_size == 8)\n    dispatchIndexType<8, R, lhs, rhs>(stream, inplace, op, ndims, n_pack, sizes, strides, y, x1,\n                                      x2);\n  else if (pack_size == 4)\n    dispatchIndexType<4, R, lhs, rhs>(stream, inplace, op, ndims, n_pack, sizes, strides, y, x1,\n                                      x2);\n  else if (pack_size == 2)\n    dispatchIndexType<2, R, lhs, rhs>(stream, inplace, op, ndims, n_pack, sizes, strides, y, x1,\n                                      x2);\n  else if (pack_size == 1)\n    dispatchIndexType<1, R, lhs, rhs>(stream, inplace, op, ndims, n_pack, sizes, strides, y, x1,\n                                      x2);\n  else\n    UNIMPLEMENTED();\n}\n}  // namespace\n\ntemplate<typename R, typename lhs, typename rhs>\nclass NonContiguousBinaryOpKernel final : public user_op::OpKernel {\n public:\n  NonContiguousBinaryOpKernel() = default;\n  ~NonContiguousBinaryOpKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const user_op::Tensor* x1 = ctx->Tensor4ArgNameAndIndex(\"lhs\", 0);\n    const user_op::Tensor* x2 = ctx->Tensor4ArgNameAndIndex(\"rhs\", 0);\n    const std::string op = ctx->Attr<std::string>(\"op\");\n    const bool inplace = ctx->Attr<bool>(\"inplace\");\n    int ndims = y->shape_view().NumAxes();\n    const ShapeView& shape = y->shape_view();\n    int sizes[MaxDims];\n    int strides[3][MaxDims];\n\n    int pack_size = 1;\n    int64_t elem_cnt = 1;\n    int max_elem_size = MAX3(GetSizeOfDataType(y->data_type()), GetSizeOfDataType(x1->data_type()),\n                             GetSizeOfDataType(x2->data_type()));\n    for (int i = 0; i < ndims; ++i) {\n      sizes[i] = shape.At(i);\n      elem_cnt *= shape.At(i);\n      strides[0][i] = y->stride()[i];\n      strides[1][i] = x1->stride()[i];\n      strides[2][i] = x2->stride()[i];\n      if (x1->stride()[i] == 1 && x2->stride()[i] == 1 && y->stride()[i] == 1) {\n        pack_size = 16 / max_elem_size;\n        while (pack_size > 1 && sizes[i] % pack_size) pack_size >>= 1;\n        sizes[i] = sizes[i] / pack_size;\n        strides[0][i] *= pack_size;\n        strides[1][i] *= pack_size;\n        strides[2][i] *= pack_size;\n      }\n    }\n\n    dispatchPacksize(ctx->stream()->As<ep::CudaStream>()->cuda_stream(), inplace, op, ndims,\n                     elem_cnt / pack_size, pack_size, sizes, strides, y->mut_dptr<R>(),\n                     x1->dptr<lhs>(), x2->dptr<rhs>());\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_KERNEL(dtype, lhs, rhs)          \\\n  REGISTER_USER_KERNEL(\"noncontiguous_binary_op\")                                     \\\n      
.SetCreateFn<NonContiguousBinaryOpKernel<dtype, lhs, rhs>>()                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<lhs>::value)   \\\n                       && (user_op::HobDataType(\"lhs\", 0) == GetDataType<lhs>::value) \\\n                       && (user_op::HobDataType(\"rhs\", 0) == GetDataType<rhs>::value));\n\n// output_type, lhs_type, rhs_type\nREGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_KERNEL(float, float, float)\nREGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_KERNEL(half, half, half)\n// #if CUDA_VERSION >= 11000\n// REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_KERNEL(nv_bfloat16, nv_bfloat16, nv_bfloat16)\n// #endif\n\n// ------------------------------------- grad kernel -------------------------------------\ntemplate<int pack_size, typename IndexType, typename BinaryOp, typename R, typename T1, typename T2,\n         typename Loadery, typename Loader1, typename Loader2>\n__global__ void noncontiguous_binary_op_grad_kernel(IndexType n_pack, Loadery dy, Loader1 load1,\n                                                    Loader2 load2) {\n  Packed<R, pack_size> pack_dy;\n  Packed<T1, pack_size> pack_x1;\n  Packed<T2, pack_size> pack_x2;\n  Packed<T1, pack_size> pack_dx1;\n  Packed<T2, pack_size> pack_dx2;\n  CUDA_1D_KERNEL_LOOP_T(IndexType, i, n_pack) {\n    load1.load(i, &pack_x1);\n    load2.load(i, &pack_x2);\n    dy.load(i, &pack_dy);\n#pragma unroll\n    for (int j = 0; j < pack_size; ++j)\n      BinaryOp()(pack_dy.elem[j], pack_x1.elem[j], pack_x2.elem[j], &pack_dx1.elem[j],\n                 &pack_dx2.elem[j]);  // todo: Apply2\n    load1.store(i, &pack_dx1);\n    load2.store(i, &pack_dx2);\n  }\n};\n\ntemplate<int pack_size, typename IndexType, typename BinaryOp, typename R, typename lhs,\n         typename rhs, typename Loady, typename Load1, typename Load2>\nvoid launch_noncontiguous_binary_op_grad_kernel(cudaStream_t stream, const IndexType n_pack,\n                                                Loady& load_y, Load1& load1, Load2& load2) {\n  int num_blocks = 1, block_size = cuda::elementwise::kBlockSize;\n  cudaError_t err = cuda::elementwise::GetNumBlocks(n_pack, &num_blocks);\n  CHECK(err == cudaSuccess);\n  noncontiguous_binary_op_grad_kernel<pack_size, IndexType, BinaryOp, R, lhs, rhs>\n      <<<num_blocks, block_size, 0, stream>>>(n_pack, load_y, load1, load2);\n}\n\ntemplate<int pack_size, typename IndexType, typename R, typename lhs, typename rhs, typename Loady,\n         typename Load1, typename Load2>\nvoid dispatchOpGrad(cudaStream_t stream, const std::string& op, const IndexType& n_pack,\n                    Loady& load_y, Load1& load1, Load2& load2) {\n  if (op == \"add\")\n    launch_noncontiguous_binary_op_grad_kernel<pack_size, IndexType, AddGrad<R>, R, lhs, rhs>(\n        stream, n_pack, load_y, load1, load2);\n  else if (op == \"sub\")\n    launch_noncontiguous_binary_op_grad_kernel<pack_size, IndexType, SubGrad<R>, R, lhs, rhs>(\n        stream, n_pack, load_y, load1, load2);\n  else if (op == \"mul\")\n    launch_noncontiguous_binary_op_grad_kernel<pack_size, IndexType, MulGrad<R>, R, lhs, rhs>(\n        stream, n_pack, load_y, load1, load2);\n  else if (op == \"div\")\n    launch_noncontiguous_binary_op_grad_kernel<pack_size, IndexType, DivGrad<R>, R, lhs, rhs>(\n        stream, n_pack, load_y, load1, load2);\n  else\n    UNIMPLEMENTED_THEN_THROW();\n}\n\ntemplate<int pack_size, typename IndexType, typename R, 
typename lhs, typename rhs>\nvoid dispatchLoader(cudaStream_t stream, const std::string& op, const int& ndims,\n                    const IndexType n_pack, const int sizes[MaxDims], const int strides[][MaxDims],\n                    lhs* dx1, rhs* dx2, const R* dy, const lhs* x1, const rhs* x2) {\n  typedef FastIntegerMath<IndexType> FastIntegerMathT;\n  FastIntegerMathT fast_integer_math[MaxDims];\n  for (int i = 0; i < ndims; ++i) fast_integer_math[i] = FastIntegerMathT(sizes[i]);\n  LoadStore<pack_size, IndexType, FastIntegerMathT, lhs, R> load_y(fast_integer_math, ndims,\n                                                                   strides[0], dy);\n  LoadStore<pack_size, IndexType, FastIntegerMathT, lhs, lhs> loader_store1(\n      fast_integer_math, ndims, strides[1], x1, dx1);\n\n  LoadStore<pack_size, IndexType, FastIntegerMathT, rhs, rhs> loader_store2(\n      fast_integer_math, ndims, strides[2], x2, dx2);\n  dispatchOpGrad<pack_size, IndexType, R, lhs, rhs>(stream, op, n_pack, load_y, loader_store1,\n                                                    loader_store2);\n}\n\ntemplate<int pack_size, typename R, typename lhs, typename rhs>\nvoid dispatchIndexTypeGrad(cudaStream_t stream, const std::string& op, const int& ndims,\n                           const int64_t& n_pack, const int sizes[MaxDims],\n                           const int strides[][MaxDims], lhs* dx1, rhs* dx2, const R* dy,\n                           const lhs* x1, const rhs* x2) {\n  if ((n_pack * pack_size) >> 30 == 0) {\n    int32_t n = (int32_t)n_pack;\n    dispatchLoader<pack_size, int32_t, R, lhs, rhs>(stream, op, ndims, n, sizes, strides, dx1, dx2,\n                                                    dy, x1, x2);\n  } else\n    dispatchLoader<pack_size, int64_t, R, lhs, rhs>(stream, op, ndims, n_pack, sizes, strides, dx1,\n                                                    dx2, dy, x1, x2);\n}\n\ntemplate<typename R, typename lhs, typename rhs>\nvoid dispatchPacksizeGrad(cudaStream_t stream, const std::string& op, const int& ndims,\n                          const int64_t& n_pack, int& pack_size, const int sizes[MaxDims],\n                          const int strides[][MaxDims], lhs* dx1, rhs* dx2, const R* dy,\n                          const lhs* x1, const rhs* x2) {\n  if (pack_size == 8)\n    dispatchIndexTypeGrad<8, R, lhs, rhs>(stream, op, ndims, n_pack, sizes, strides, dx1, dx2, dy,\n                                          x1, x2);\n  else if (pack_size == 4)\n    dispatchIndexTypeGrad<4, R, lhs, rhs>(stream, op, ndims, n_pack, sizes, strides, dx1, dx2, dy,\n                                          x1, x2);\n  else if (pack_size == 2)\n    dispatchIndexTypeGrad<2, R, lhs, rhs>(stream, op, ndims, n_pack, sizes, strides, dx1, dx2, dy,\n                                          x1, x2);\n  else if (pack_size == 1)\n    dispatchIndexTypeGrad<1, R, lhs, rhs>(stream, op, ndims, n_pack, sizes, strides, dx1, dx2, dy,\n                                          x1, x2);\n  else\n    UNIMPLEMENTED();\n}\n\ntemplate<typename R, typename lhs, typename rhs>\nclass NonContiguousBinaryOpGradKernel final : public user_op::OpKernel {\n public:\n  NonContiguousBinaryOpGradKernel() = default;\n  ~NonContiguousBinaryOpGradKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* x1 = ctx->Tensor4ArgNameAndIndex(\"lhs\", 
0);\n    const user_op::Tensor* x2 = ctx->Tensor4ArgNameAndIndex(\"rhs\", 0);\n    user_op::Tensor* dx1 = ctx->Tensor4ArgNameAndIndex(\"dlhs\", 0);\n    user_op::Tensor* dx2 = ctx->Tensor4ArgNameAndIndex(\"drhs\", 0);\n    const std::string op = ctx->Attr<std::string>(\"op\");\n    const bool inplace = ctx->Attr<bool>(\"inplace\");\n    CHECK(inplace == false) << \"inplace should be set to `false` to compute gradients.\";\n    int ndims = dy->shape_view().NumAxes();\n    const ShapeView& shape = dy->shape_view();\n    int sizes[MaxDims];\n    int strides[3][MaxDims];\n\n    int pack_size = 1;\n    int64_t elem_cnt = 1;\n    int max_elem_size = MAX3(GetSizeOfDataType(dy->data_type()), GetSizeOfDataType(x1->data_type()),\n                             GetSizeOfDataType(x2->data_type()));\n    for (int i = 0; i < ndims; ++i) {\n      sizes[i] = shape.At(i);\n      elem_cnt *= shape.At(i);\n      strides[0][i] = dy->stride()[i];\n      strides[1][i] = x1->stride()[i];\n      strides[2][i] = x2->stride()[i];\n      if (x1->stride()[i] == 1 && x2->stride()[i] == 1 && dy->stride()[i] == 1) {\n        pack_size = 16 / max_elem_size;\n        while (pack_size > 1 && sizes[i] % pack_size) pack_size >>= 1;\n        sizes[i] = sizes[i] / pack_size;\n        strides[0][i] *= pack_size;\n        strides[1][i] *= pack_size;\n        strides[2][i] *= pack_size;\n      }\n    }\n\n    dispatchPacksizeGrad(ctx->stream()->As<ep::CudaStream>()->cuda_stream(), op, ndims,\n                         elem_cnt / pack_size, pack_size, sizes, strides, dx1->mut_dptr<lhs>(),\n                         dx2->mut_dptr<rhs>(), dy->dptr<R>(), x1->dptr<lhs>(), x2->dptr<rhs>());\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_GRAD_KERNEL(dtype, lhs, rhs)      \\\n  REGISTER_USER_KERNEL(\"noncontiguous_binary_op_grad\")                                 \\\n      .SetCreateFn<NonContiguousBinaryOpGradKernel<dtype, lhs, rhs>>()                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                 \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value) \\\n                       && (user_op::HobDataType(\"lhs\", 0) == GetDataType<lhs>::value)  \\\n                       && (user_op::HobDataType(\"rhs\", 0) == GetDataType<rhs>::value));\n\n// output_type, lhs_type, rhs_type\nREGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_GRAD_KERNEL(float, float, float)\nREGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_GRAD_KERNEL(half, half, half)\n\n}  // namespace oneflow\n"
  },
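  {
    "path": "oneflow/user/kernels/noncontiguous_binary_op_sketch.cpp",
    "content": "// Editor's illustrative sketch; this file name is hypothetical and the file is not\n// part of the OneFlow build. It is a plain-C++ version of LoadStore::index2offset from\n// noncontiguous_binary_op.cu: a linear element index is peeled into per-dimension\n// coordinates (innermost dimension first) by div/mod on the logical sizes, and each\n// coordinate is multiplied by that tensor's stride, so one linear index can address\n// tensors with different memory layouts.\n#include <cstdint>\n#include <cstdio>\n\nint64_t Index2Offset(int64_t index, const int64_t* sizes, const int64_t* strides, int ndims) {\n  int64_t offset = 0;\n  for (int dim = ndims - 1; dim >= 0; --dim) {\n    if (index == 0) { break; }\n    const int64_t coord = index % sizes[dim];  // the kernel uses FastIntegerMath::divmod\n    index /= sizes[dim];\n    offset += coord * strides[dim];\n  }\n  return offset;\n}\n\nint main() {\n  // A 2x3 view that transposes a contiguous 3x2 buffer: sizes {2, 3}, strides {1, 2}.\n  // Linear index 4 -> coordinates (1, 1) -> offset 1 * 1 + 1 * 2 = 3.\n  const int64_t sizes[] = {2, 3};\n  const int64_t strides[] = {1, 2};\n  std::printf(\"%lld\\n\", static_cast<long long>(Index2Offset(4, sizes, strides, 2)));\n  return 0;\n}\n"
  },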
  {
    "path": "oneflow/user/kernels/nop_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass NopKernel final : public user_op::OpKernel {\n public:\n  NopKernel() = default;\n  ~NopKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    // do nothing\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_NOP_KERNEL(op_type_name) \\\n  REGISTER_USER_KERNEL(op_type_name).SetCreateFn<NopKernel>();\n\nREGISTER_NOP_KERNEL(\"cast_to_tick\")\nREGISTER_NOP_KERNEL(\"acc_ctrl_tick\")\nREGISTER_NOP_KERNEL(\"repeat\")\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/normalization_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstatic void ComputeMeanAndVar(const T* input_ptr, T* mean_ptr, T* inv_variance_ptr,\n                              T* moving_mean_ptr, T* moving_variance_ptr, const int64_t batch_size,\n                              const int64_t channel_size, const int64_t spatial_size,\n                              const float epsilon, const float momentum) {\n  // NOTE(Liang Depeng): the following parameters were used to compute mean and var\n  const int64_t jump_step = spatial_size * channel_size;\n  const int64_t reduce_count = batch_size * spatial_size;\n  const int64_t unbias_reduce_count = reduce_count - 1;\n  const T reduce_scale_factor = static_cast<T>(1) / reduce_count;\n  const T unbias_reduce_scale_factor = static_cast<T>(1) / unbias_reduce_count;\n  const T unbias_reduce_scale_factor_m2 = unbias_reduce_scale_factor * -static_cast<T>(2);\n  const T unbias_reduce_scale_factor_mn = reduce_count * unbias_reduce_scale_factor;\n\n  const T exponential_average_factor = 1.0f - momentum;\n\n  for (int64_t channel = 0; channel < channel_size; ++channel) {\n    const T* temp_input_ptr = input_ptr + channel * spatial_size;\n    T sum = 0;\n    T sum_square = 0;\n    for (int64_t batch = 0; batch < batch_size; ++batch) {\n      for (int64_t s = 0; s < spatial_size; ++s) {\n        const T x = temp_input_ptr[s];\n        sum += x;\n        sum_square += x * x;\n      }\n      temp_input_ptr += jump_step;\n    }\n\n    const T temp_mean = sum * reduce_scale_factor;\n    mean_ptr[channel] = temp_mean;\n\n    const T temp_mean_square = temp_mean * temp_mean;\n    const T temp_variance = sum_square * reduce_scale_factor - temp_mean_square;\n\n    const T temp_unbias_variance = sum_square * unbias_reduce_scale_factor\n                                   + unbias_reduce_scale_factor_m2 * temp_mean * sum\n                                   + unbias_reduce_scale_factor_mn * temp_mean_square;\n\n    inv_variance_ptr[channel] = static_cast<T>(1) / std::sqrt(temp_variance + epsilon);\n\n    if (moving_mean_ptr != nullptr && moving_variance_ptr != nullptr) {\n      moving_mean_ptr[channel] =\n          moving_mean_ptr[channel] * momentum + temp_mean * exponential_average_factor;\n      moving_variance_ptr[channel] = moving_variance_ptr[channel] * momentum\n                                     + temp_unbias_variance * exponential_average_factor;\n    }\n  }\n}\n\ntemplate<typename T>\nstatic void Normalize(const T* input_ptr, const T* mean_ptr, const T* variance_ptr,\n                      const T* gamma_ptr, const T* beta_ptr, T* output_ptr,\n                      const int64_t batch_size, const int64_t channel_size,\n                      const int64_t spatial_size, const float epsilon, const bool training) {\n  const T* temp_input_ptr = input_ptr;\n  T* temp_output_ptr = output_ptr;\n  const int64_t 
all_channels = batch_size * channel_size;\n  int64_t channel = -1;\n  for (int64_t ac = 0; ac < all_channels; ++ac) {\n    channel += 1;\n    if (channel >= channel_size) { channel = 0; }\n    T inv_variance = variance_ptr[channel];\n    if (!training) { inv_variance = 1.0f / std::sqrt(inv_variance + epsilon); }\n    const T gamma = gamma_ptr[channel] * inv_variance;\n    const T beta = beta_ptr[channel];\n    const T mean = mean_ptr[channel];\n    for (int64_t s = 0; s < spatial_size; ++s) {\n      temp_output_ptr[s] = (temp_input_ptr[s] - mean) * gamma + beta;\n    }\n    temp_input_ptr += spatial_size;\n    temp_output_ptr += spatial_size;\n  }\n}\n\ntemplate<typename T>\nstatic void AddToOutput(const T* add_to_output_ptr, T* output_ptr, const int64_t elem_count) {\n  for (int64_t i = 0; i < elem_count; ++i) { output_ptr[i] += add_to_output_ptr[i]; }\n}\n\ntemplate<typename T>\nstatic void AddRelu(const T* addend_ptr, int32_t* mask_ptr, T* output_ptr, const int64_t elem_cnt) {\n  const int32_t step = 32;\n  const int64_t outer_loop = elem_cnt / step;\n  const int64_t remain_loop_start_idx = outer_loop * step;\n\n  T* temp_output_ptr = output_ptr;\n  for (int64_t outer = 0; outer < outer_loop; ++outer) {\n    int32_t mask = 0;\n    for (int32_t s = 0; s < step; ++s) {\n      const T sum = temp_output_ptr[s] + addend_ptr[s];\n      const bool is_positive = (sum > 0);\n      mask = mask | (static_cast<int32_t>(is_positive) << s);\n      temp_output_ptr[s] = is_positive ? sum : 0;\n    }\n    mask_ptr[outer] = mask;\n    addend_ptr += step;\n    temp_output_ptr += step;\n  }\n  if (remain_loop_start_idx < elem_cnt) {\n    int32_t mask_val = 0;\n    const int32_t remain = elem_cnt - remain_loop_start_idx;\n    for (int32_t i = 0; i < remain; ++i) {\n      const T sum = temp_output_ptr[i] + addend_ptr[i];\n      const bool is_positive = (sum > 0);\n      mask_val = mask_val | (static_cast<int32_t>(is_positive) << i);\n      temp_output_ptr[i] = is_positive ? sum : 0;\n    }\n    mask_ptr[outer_loop] = mask_val;\n  }\n}\n\ntemplate<typename T>\nstatic void Relu(int32_t* mask_ptr, T* output_ptr, const int64_t elem_cnt) {\n  const int32_t step = 32;\n  const int64_t outer_loop = elem_cnt / step;\n  const int64_t remain_loop_start_idx = outer_loop * step;\n\n  T* temp_output_ptr = output_ptr;\n  for (int64_t outer = 0; outer < outer_loop; ++outer) {\n    int32_t mask_val = 0;\n    for (int32_t s = 0; s < step; ++s) {\n      const T output = temp_output_ptr[s];\n      const bool is_positive = (output > 0);\n      mask_val = mask_val | (static_cast<int32_t>(is_positive) << s);\n      temp_output_ptr[s] = is_positive ? output : 0;\n    }\n    mask_ptr[outer] = mask_val;\n    temp_output_ptr += step;\n  }\n  if (remain_loop_start_idx < elem_cnt) {\n    int32_t mask_val = 0;\n    const int32_t remain = elem_cnt - remain_loop_start_idx;\n    for (int32_t i = 0; i < remain; ++i) {\n      const T output = temp_output_ptr[i];\n      const bool is_positive = (output > 0);\n      mask_val = mask_val | (static_cast<int32_t>(is_positive) << i);\n      temp_output_ptr[i] = is_positive ? 
output : 0;\n    }\n    mask_ptr[outer_loop] = mask_val;\n  }\n}\n\ntemplate<typename T>\nstatic void AddReluGrad(const T* dy_ptr, const int32_t* mask_ptr, T* addend_diff_ptr,\n                        const int64_t elem_cnt) {\n  const int32_t step = 32;\n  const int64_t outer_loop = elem_cnt / step;\n  const int64_t remain_loop_start_idx = outer_loop * step;\n\n  for (int64_t outer = 0; outer < outer_loop; ++outer) {\n    const int32_t mask_val = mask_ptr[outer];\n    for (int32_t s = 0; s < step; ++s) {\n      bool is_positive = mask_val & (1 << s);\n      addend_diff_ptr[s] = static_cast<T>(is_positive) * dy_ptr[s];\n    }\n    addend_diff_ptr += step;\n    dy_ptr += step;\n  }\n\n  if (remain_loop_start_idx < elem_cnt) {\n    const int32_t mask_val = mask_ptr[outer_loop];\n    const int32_t remain = elem_cnt - remain_loop_start_idx;\n    for (int32_t i = 0; i < remain; ++i) {\n      bool is_positive = mask_val & (1 << i);\n      addend_diff_ptr[i] = static_cast<T>(is_positive) * dy_ptr[i];\n    }\n  }\n}\n\ntemplate<typename T>\nstatic void ReluGrad(const T* dy_ptr, const int32_t* mask_ptr, T* relu_dx_ptr,\n                     const int64_t elem_cnt) {\n  const int32_t step = 32;\n  const int64_t outer_loop = elem_cnt / step;\n  const int64_t remain_loop_start_idx = outer_loop * step;\n\n  for (int64_t outer = 0; outer < outer_loop; ++outer) {\n    const int32_t mask_val = mask_ptr[outer];\n    for (int32_t s = 0; s < step; ++s) {\n      bool is_positive = mask_val & (1 << s);\n      relu_dx_ptr[s] = static_cast<T>(is_positive) * dy_ptr[s];\n    }\n    relu_dx_ptr += step;\n    dy_ptr += step;\n  }\n\n  if (remain_loop_start_idx < elem_cnt) {\n    const int32_t mask_val = mask_ptr[outer_loop];\n    const int32_t remain = elem_cnt - remain_loop_start_idx;\n    for (int32_t i = 0; i < remain; ++i) {\n      bool is_positive = mask_val & (1 << i);\n      relu_dx_ptr[i] = static_cast<T>(is_positive) * dy_ptr[i];\n    }\n  }\n}\n\nstatic size_t InferGradTmpSizeForCpuKernel(user_op::InferContext* ctx) {\n  const auto& dy = ctx->InputTensorDesc(\"dy\", 0);\n  size_t tmp_size = 0;\n  if (ctx->op_type_name() == \"normalization_add_relu_grad\" && !ctx->has_output(\"addend_diff\", 0)) {\n    tmp_size += dy.shape().elem_cnt() * GetSizeOfDataType(dy.data_type());\n  }\n  return tmp_size;\n}\n\n// NOTE(Liang Depeng): helper functions to process datas for specific channel over all samples.\ntemplate<typename T, typename DataProcessor>\nstatic inline void ForEachFast(const T* data, const int64_t batch_size, const int64_t spatial_size,\n                               const int64_t jump_step, const int64_t channel_idx,\n                               DataProcessor data_processor) {\n  const int64_t start_offset = channel_idx * spatial_size;\n  const T* tmp_data = data + start_offset;\n  for (int64_t outer = 0; outer < batch_size; ++outer) {\n    for (int64_t i = 0; i < spatial_size; ++i) { data_processor(&tmp_data[i]); }\n    tmp_data += jump_step;\n  }\n}\n\ntemplate<typename T, typename DataProcessor>\nstatic inline void ForEachFast(const T* in_data1, const T* in_data2, const int64_t batch_size,\n                               const int64_t spatial_size, const int64_t jump_step,\n                               const int64_t channel_idx, DataProcessor data_processor) {\n  const int64_t start_offset = channel_idx * spatial_size;\n  const T* tmp_in_data1 = in_data1 + start_offset;\n  const T* tmp_in_data2 = in_data2 + start_offset;\n  for (int64_t outer = 0; outer < batch_size; ++outer) {\n    for 
(int64_t i = 0; i < spatial_size; ++i) {\n      data_processor(&tmp_in_data1[i], &tmp_in_data2[i]);\n    }\n    tmp_in_data1 += jump_step;\n    tmp_in_data2 += jump_step;\n  }\n}\n\ntemplate<typename T, typename DataProcessor>\nstatic inline void ForEachFast(const T* in_data, T* out_data, const int64_t batch_size,\n                               const int64_t spatial_size, const int64_t jump_step,\n                               const int64_t channel_idx, DataProcessor data_processor) {\n  const int64_t start_offset = channel_idx * spatial_size;\n  const T* tmp_in_data = in_data + start_offset;\n  T* tmp_out_data = out_data + start_offset;\n  for (int64_t outer = 0; outer < batch_size; ++outer) {\n    for (int64_t i = 0; i < spatial_size; ++i) {\n      data_processor(&tmp_in_data[i], &tmp_out_data[i]);\n    }\n    tmp_in_data += jump_step;\n    tmp_out_data += jump_step;\n  }\n}\n\ntemplate<typename T>\nclass NormalizationInferenceCpuKernel final : public user_op::OpKernel {\n public:\n  NormalizationInferenceCpuKernel() = default;\n  ~NormalizationInferenceCpuKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const bool training = ctx->Attr<bool>(\"training\");\n    CHECK(!training);\n    const auto* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    auto* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const auto* gamma = ctx->Tensor4ArgNameAndIndex(\"gamma\", 0);\n    const auto* beta = ctx->Tensor4ArgNameAndIndex(\"beta\", 0);\n    auto* moving_mean = ctx->Tensor4ArgNameAndIndex(\"moving_mean\", 0);\n    auto* moving_variance = ctx->Tensor4ArgNameAndIndex(\"moving_variance\", 0);\n    const auto axis = ctx->Attr<int32_t>(\"axis\");\n    const auto epsilon = ctx->Attr<float>(\"epsilon\");\n\n    const DataType data_type = x->data_type();\n    CHECK_EQ(x->shape_view(), y->shape_view());\n    CHECK_EQ(y->data_type(), data_type);\n    CHECK_GE(axis, 0);\n    CHECK_LT(axis, x->shape_view().NumAxes());\n\n    if (axis == 1) {  // NOTE(Liang Depeng): NCHW format\n      const T* input_ptr = x->dptr<T>();\n      const T* gamma_ptr = gamma->dptr<T>();\n      const T* beta_ptr = beta->dptr<T>();\n\n      T* output_ptr = y->mut_dptr<T>();\n      T* moving_mean_ptr = moving_mean->mut_dptr<T>();\n      T* moving_variance_ptr = moving_variance->mut_dptr<T>();\n\n      const int64_t batch_size = x->shape_view().At(0);\n      const int64_t channel_size = x->shape_view().At(axis);\n      const int64_t spatial_size = x->shape_view().Count(axis + 1);\n\n      // NOTE(Liang Depeng):\n      // compute the normalization result\n      Normalize(input_ptr, moving_mean_ptr, moving_variance_ptr, gamma_ptr, beta_ptr, output_ptr,\n                batch_size, channel_size, spatial_size, epsilon, false);\n\n      if (ctx->has_input(\"_add_to_output\", 0)) {\n        const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n        CHECK_EQ(add_to_output->data_type(), y->data_type());\n        CHECK_EQ(add_to_output->shape_view(), y->shape_view());\n        AddToOutput(add_to_output->dptr<T>(), output_ptr, x->shape_view().elem_cnt());\n      }\n\n    } else {  // TODO(Liang Depeng): NHWC format\n      UNIMPLEMENTED() << \"cpu normalization op only support nchw data_format now!\";\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_BN_INFERENCE_CPU_KERNEL(dtype)                                           \\\n  
REGISTER_USER_KERNEL(\"normalization\")                                                   \\\n      .SetCreateFn<NormalizationInferenceCpuKernel<dtype>>()                              \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                     \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)     \\\n                       && (user_op::HobAttr<bool>(\"training\") == false))                  \\\n      .SetInplaceProposalFn(                                                              \\\n          [](const user_op::InferContext& ctx,                                            \\\n             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {      \\\n            if (ctx.has_input(\"_add_to_output\", 0)) {                                     \\\n              OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"y\", 0, \"_add_to_output\", 0, true)); \\\n            }                                                                             \\\n            return Maybe<void>::Ok();                                                     \\\n          });\n\nREGISTER_BN_INFERENCE_CPU_KERNEL(float)\nREGISTER_BN_INFERENCE_CPU_KERNEL(double)\n\n#undef REGISTER_BN_INFERENCE_CPU_KERNEL\n\ntemplate<typename T>\nclass NormalizationTrainCpuKernel final : public user_op::OpKernel {\n public:\n  NormalizationTrainCpuKernel() = default;\n  ~NormalizationTrainCpuKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    if (ctx->op_type_name() == \"normalization\") { CHECK(ctx->Attr<bool>(\"training\")); }\n    const auto* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    auto* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n\n    const auto axis = ctx->Attr<int32_t>(\"axis\");\n    const auto epsilon = ctx->Attr<float>(\"epsilon\");\n    const auto momentum = ctx->Attr<float>(\"momentum\");\n\n    const DataType data_type = x->data_type();\n    CHECK_EQ(x->shape_view(), y->shape_view());\n    CHECK_EQ(y->data_type(), data_type);\n    CHECK_GE(axis, 0);\n    CHECK_LT(axis, x->shape_view().NumAxes());\n\n    const auto* gamma = ctx->Tensor4ArgNameAndIndex(\"gamma\", 0);\n    const auto* beta = ctx->Tensor4ArgNameAndIndex(\"beta\", 0);\n    auto* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    auto* inv_variance = ctx->Tensor4ArgNameAndIndex(\"inv_variance\", 0);\n\n    user_op::Tensor* moving_mean = nullptr;\n    user_op::Tensor* moving_variance = nullptr;\n    if (ctx->has_input(\"moving_mean\", 0)) {\n      CHECK(ctx->has_input(\"moving_variance\", 0));\n      moving_mean = ctx->Tensor4ArgNameAndIndex(\"moving_mean\", 0);\n      moving_variance = ctx->Tensor4ArgNameAndIndex(\"moving_variance\", 0);\n    }\n\n    if (axis == 1) {  // NOTE(Liang Depeng): NCHW format\n      const T* input_ptr = x->dptr<T>();\n      const T* gamma_ptr = gamma->dptr<T>();\n      const T* beta_ptr = beta->dptr<T>();\n\n      T* output_ptr = y->mut_dptr<T>();\n      T* mean_ptr = mean->mut_dptr<T>();\n      T* inv_variance_ptr = inv_variance->mut_dptr<T>();\n\n      T* moving_mean_ptr = nullptr;\n      T* moving_variance_ptr = nullptr;\n      if (moving_mean != nullptr && moving_variance != nullptr) {\n        moving_mean_ptr = moving_mean->mut_dptr<T>();\n        moving_variance_ptr = moving_variance->mut_dptr<T>();\n      }\n\n      const int64_t batch_size = x->shape_view().At(0);\n      const int64_t channel_size = x->shape_view().At(axis);\n      
const int64_t spatial_size = x->shape_view().Count(axis + 1);\n\n      // NOTE(Liang Depeng):\n      // Compute mean & inv_variance and update moving_mean & moving_variance for each channel.\n      ComputeMeanAndVar(input_ptr, mean_ptr, inv_variance_ptr, moving_mean_ptr, moving_variance_ptr,\n                        batch_size, channel_size, spatial_size, epsilon, momentum);\n\n      // NOTE(Liang Depeng):\n      // compute the normalization result\n      Normalize(input_ptr, mean_ptr, inv_variance_ptr, gamma_ptr, beta_ptr, output_ptr, batch_size,\n                channel_size, spatial_size, epsilon, true);\n\n      if (ctx->has_input(\"_add_to_output\", 0)) {\n        const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n        CHECK_EQ(add_to_output->data_type(), y->data_type());\n        CHECK_EQ(add_to_output->shape_view(), y->shape_view());\n        AddToOutput(add_to_output->dptr<T>(), output_ptr, x->shape_view().elem_cnt());\n      }\n\n      if (ctx->op_type_name() == \"normalization_add_relu\") {\n        CHECK(!ctx->has_input(\"_add_to_output\", 0));\n        auto* mask = ctx->Tensor4ArgNameAndIndex(\"reserve_space\", 0);\n\n        if (ctx->has_input(\"addend\", 0)) {\n          const auto* addend = ctx->Tensor4ArgNameAndIndex(\"addend\", 0);\n          AddRelu(addend->dptr<T>(), mask->mut_dptr<int32_t>(), output_ptr,\n                  x->shape_view().elem_cnt());\n        } else {\n          Relu(mask->mut_dptr<int32_t>(), output_ptr, x->shape_view().elem_cnt());\n        }\n      }\n    } else {  // TODO(Liang Depeng): NHWC format\n      UNIMPLEMENTED() << \"cpu normalization op only support nchw data_format now!\";\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_BN_TRAIN_CPU_KERNEL(dtype)                                               \\\n  REGISTER_USER_KERNEL(\"normalization\")                                                   \\\n      .SetCreateFn<NormalizationTrainCpuKernel<dtype>>()                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                     \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)     \\\n                       && (user_op::HobAttr<bool>(\"training\") == true))                   \\\n      .SetInplaceProposalFn(                                                              \\\n          [](const user_op::InferContext& ctx,                                            \\\n             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {      \\\n            if (ctx.has_input(\"_add_to_output\", 0)) {                                     \\\n              OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"y\", 0, \"_add_to_output\", 0, true)); \\\n            }                                                                             \\\n            return Maybe<void>::Ok();                                                     \\\n          });\n\nREGISTER_BN_TRAIN_CPU_KERNEL(float)\nREGISTER_BN_TRAIN_CPU_KERNEL(double)\n\n#undef REGISTER_BN_TRAIN_CPU_KERNEL\n\n#define REGISTER_BN_ADD_RELU_CPU_KERNEL(dtype)                        \\\n  REGISTER_USER_KERNEL(\"normalization_add_relu\")                      \\\n      .SetCreateFn<NormalizationTrainCpuKernel<dtype>>()              \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\nREGISTER_BN_ADD_RELU_CPU_KERNEL(float)\nREGISTER_BN_ADD_RELU_CPU_KERNEL(double)\n\n#undef 
REGISTER_BN_ADD_RELU_CPU_KERNEL\n\ntemplate<typename T>\nclass NormalizationGradCpuKernel final : public user_op::OpKernel {\n public:\n  NormalizationGradCpuKernel() = default;\n  ~NormalizationGradCpuKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    auto* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const auto* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const auto* gamma = ctx->Tensor4ArgNameAndIndex(\"gamma\", 0);\n    auto* gamma_diff = ctx->Tensor4ArgNameAndIndex(\"gamma_diff\", 0);\n    auto* beta_diff = ctx->Tensor4ArgNameAndIndex(\"beta_diff\", 0);\n    const auto* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    const auto* inv_variance = ctx->Tensor4ArgNameAndIndex(\"inv_variance\", 0);\n    auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const auto axis = ctx->Attr<int32_t>(\"axis\");\n\n    const DataType data_type = x->data_type();\n    CHECK_EQ(dy->shape_view(), x->shape_view());\n    CHECK_EQ(dy->data_type(), data_type);\n    CHECK_EQ(dx->shape_view(), x->shape_view());\n    CHECK_EQ(dx->data_type(), data_type);\n    CHECK_GE(axis, 0);\n    CHECK_LT(axis, x->shape_view().NumAxes());\n\n    const T* dy_ptr = nullptr;\n    if (ctx->op_type_name() == \"normalization_grad\") {\n      dy_ptr = dy->dptr<T>();\n    } else if (ctx->op_type_name() == \"normalization_add_relu_grad\") {\n      const auto* mask = ctx->Tensor4ArgNameAndIndex(\"reserve_space\", 0);\n      if (ctx->has_output(\"addend_diff\", 0)) {\n        user_op::Tensor* addend_diff = ctx->Tensor4ArgNameAndIndex(\"addend_diff\", 0);\n        AddReluGrad(dy->dptr<T>(), mask->dptr<int32_t>(), addend_diff->mut_dptr<T>(),\n                    dy->shape_view().elem_cnt());\n        dy_ptr = addend_diff->dptr<T>();\n      } else {\n        ReluGrad(dy->dptr<T>(), mask->dptr<int32_t>(), tmp_buffer->mut_dptr<T>(),\n                 dy->shape_view().elem_cnt());\n        dy_ptr = tmp_buffer->dptr<T>();\n      }\n\n    } else {\n      UNIMPLEMENTED();\n    }\n\n    if (axis == 1) {  // NOTE(Liang Depeng): NCHW format\n      const T* x_ptr = x->dptr<T>();\n      const T* gamma_ptr = gamma->dptr<T>();\n      const T* mean_ptr = mean->dptr<T>();\n      const T* inv_variance_ptr = inv_variance->dptr<T>();\n\n      T* dx_ptr = dx->mut_dptr<T>();\n      T* gamma_diff_ptr = gamma_diff->mut_dptr<T>();\n      T* beta_diff_ptr = beta_diff->mut_dptr<T>();\n\n      const int64_t batch_size = x->shape_view().At(0);\n      const int64_t channel_size = x->shape_view().At(axis);\n      const int64_t spatial_size = x->shape_view().Count(axis + 1);\n      const int64_t jump_step = spatial_size * channel_size;\n      const int64_t reduce_count = batch_size * spatial_size;\n\n      // NOTE(Liang Depeng):\n      // Borrow the MXNet implementation to compute dx, gamma_diff and beta_diff.\n      // For more details please refer to:\n      // https://github.com/apache/incubator-mxnet/blob/master/src/operator/nn/batch_norm.cc\n      // With N = reduce_count and inv_std = inv_variance_c, the loop below computes,\n      // per channel c:\n      //   beta_diff[c]  = sum(dy)\n      //   gamma_diff[c] = inv_std * sum((x - mean_c) * dy)  (= dotp * inv_std)\n      //   dx            = (dy - sum(dy) / N - (x - mean_c) * k) * inv_std * gamma_c,\n      //   where k = dotp * inv_std^2 / N.\n      for (int64_t channel = 0; channel < channel_size; ++channel) {\n        const T gamma_c = gamma_ptr[channel];\n        const T mean_c = mean_ptr[channel];\n        const T inv_variance_c = inv_variance_ptr[channel];\n\n        // NOTE(Liang Depeng): sum dy for a specific channel over all samples\n        T sum_dy_out = 0;\n        ForEachFast(dy_ptr, batch_size, spatial_size, jump_step, channel,\n                    
[&sum_dy_out](const T* dy_data) { sum_dy_out += *dy_data; });\n\n        // NOTE(Liang Depeng): dot product of x and dy\n        T dotp = 0;\n        ForEachFast(x_ptr, dy_ptr, batch_size, spatial_size, jump_step, channel,\n                    [&dotp, mean_c](const T* x_data, const T* dy_data) {\n                      dotp += (*x_data - mean_c) * (*dy_data);\n                    });\n\n        // NOTE(Liang Depeng): projection of dy onto the output, scaled by std\n        const T k = dotp * inv_variance_c * inv_variance_c / reduce_count;\n        const T iw = inv_variance_c * gamma_c;\n        const T grad_mean_c = sum_dy_out / reduce_count;\n        ForEachFast(\n            x_ptr, dx_ptr, batch_size, spatial_size, jump_step, channel,\n            [&mean_c, &k](const T* x_data, T* dx_data) { *dx_data = (*x_data - mean_c) * k; });\n\n        ForEachFast(dy_ptr, dx_ptr, batch_size, spatial_size, jump_step, channel,\n                    [iw, grad_mean_c](const T* dy_data, T* dx_data) {\n                      *dx_data = (*dy_data - grad_mean_c - *dx_data) * iw;\n                    });\n\n        gamma_diff_ptr[channel] = dotp * inv_variance_c;\n        beta_diff_ptr[channel] = sum_dy_out;\n      }\n\n    } else {  // TODO(Liang Depeng): NHWC format\n      UNIMPLEMENTED() << \"cpu normalization_grad op only support nchw data_format now!\";\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_BN_GRAD_CPU_KERNEL(dtype)                            \\\n  REGISTER_USER_KERNEL(\"normalization_grad\")                          \\\n      .SetCreateFn<NormalizationGradCpuKernel<dtype>>()               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_BN_GRAD_CPU_KERNEL(float)\nREGISTER_BN_GRAD_CPU_KERNEL(double)\n\n#undef REGISTER_BN_GRAD_CPU_KERNEL\n\n#define REGISTER_BN_ADD_RELU_GRAD_CPU_KERNEL(dtype)                                     \\\n  REGISTER_USER_KERNEL(\"normalization_add_relu_grad\")                                   \\\n      .SetCreateFn<NormalizationGradCpuKernel<dtype>>()                                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn(InferGradTmpSizeForCpuKernel);\n\nREGISTER_BN_ADD_RELU_GRAD_CPU_KERNEL(float)\nREGISTER_BN_ADD_RELU_GRAD_CPU_KERNEL(double)\n\n#undef REGISTER_BN_ADD_RELU_GRAD_CPU_KERNEL\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/normalization_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n\n#include <unordered_map>\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/device/cudnn_util.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/device/cuda_pseudo_bfloat16.h\"\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\n#include <cudnn.h>\n\n#if (CUDNN_VERSION >= 7401)\n#define BN_ENABLE_EX_API\n#endif\n\nnamespace oneflow {\n\nnamespace {\n\ncudnnBatchNormMode_t getCudnnBatchNormMode(const int64_t dim) {\n  if (dim == 2) {\n    return CUDNN_BATCHNORM_PER_ACTIVATION;\n  } else if (ParseBooleanFromEnv(\"ONEFLOW_ENABLE_NHWC\", false)) {\n    return CUDNN_BATCHNORM_SPATIAL_PERSISTENT;\n  } else {\n    // NOTE(Liang Depeng): The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was\n    // introduced in CuDNN 7 for performance optimization, but it results in\n    // accuracy losses in convolution models such as ResNeXt-101 and\n    // video R(2+1)D. We will fall back to the normal CUDNN_BATCHNORM_SPATIAL\n    return CUDNN_BATCHNORM_SPATIAL;\n  }\n}\n\nvoid InferDimSizeAndDataFormat(const ShapeView& x_shape, const int32_t axis, int32_t* n, int32_t* c,\n                               int32_t* h, int32_t* w, cudnnTensorFormat_t* format) {\n  if (x_shape.Count(axis + 1) == 1) {\n    if (axis == 0) {\n      *n = 1;\n      *h = 1;\n    } else {\n      *n = x_shape.At(0);\n      *h = x_shape.Count(1, axis);\n    }\n    *w = 1;\n    *c = x_shape.At(axis);\n    *format = CUDNN_TENSOR_NHWC;\n  } else {\n    *n = x_shape.Count(0, axis);\n    *c = x_shape.At(axis);\n    *h = x_shape.Count(axis + 1);\n    *w = 1;\n    *format = CUDNN_TENSOR_NCHW;\n  }\n}\n\nvoid InferXYCudnnTensorDesc(const ShapeView& xy_shape, const DataType& data_type,\n                            const int32_t axis, cudnnTensorDescriptor_t xy_desc) {\n  int32_t n, c, h, w;\n  cudnnTensorFormat_t format;\n  InferDimSizeAndDataFormat(xy_shape, axis, &n, &c, &h, &w, &format);\n  OF_CUDNN_CHECK(\n      cudnnSetTensor4dDescriptor(xy_desc, format, GetCudnnDataType(data_type), n, c, h, w));\n}\n\nvoid InferParamCudnnTensorDesc(const cudnnTensorDescriptor_t xy_desc, cudnnBatchNormMode_t mode,\n                               cudnnTensorDescriptor_t param_desc) {\n  OF_CUDNN_CHECK(cudnnDeriveBNTensorDescriptor(param_desc, xy_desc, mode));\n}\n\nclass CudnnTensorDescHelper final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudnnTensorDescHelper);\n  CudnnTensorDescHelper(const ShapeView& xy_shape, const DataType& data_type, const int32_t axis,\n                        cudnnBatchNormMode_t mode) {\n    OF_CUDNN_CHECK(cudnnCreateTensorDescriptor(&xy_desc_));\n    InferXYCudnnTensorDesc(xy_shape, data_type, axis, xy_desc_);\n    OF_CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc_));\n    InferParamCudnnTensorDesc(xy_desc_, 
mode, param_desc_);\n    int n, c, h, w, n_stride, c_stride, h_stride, w_stride;\n    OF_CUDNN_CHECK(cudnnGetTensor4dDescriptor(param_desc_, &param_data_type_, &n, &c, &h, &w,\n                                              &n_stride, &c_stride, &h_stride, &w_stride));\n    param_size_ = c;\n  }\n  ~CudnnTensorDescHelper() {\n    OF_CUDNN_CHECK(cudnnDestroyTensorDescriptor(param_desc_));\n    OF_CUDNN_CHECK(cudnnDestroyTensorDescriptor(xy_desc_));\n  }\n\n  cudnnTensorDescriptor_t xy_desc() const { return xy_desc_; }\n\n  cudnnTensorDescriptor_t param_desc() const { return param_desc_; }\n\n  void CheckParamTensor(const user_op::Tensor* tensor) const {\n    CHECK_NOTNULL(tensor);\n    CHECK_EQ(tensor->shape_view().NumAxes(), 1);\n    CHECK_EQ(tensor->shape_view().At(0), param_size_);\n    CHECK_EQ(GetCudnnDataType(tensor->data_type()), param_data_type_);\n  }\n\n private:\n  cudnnTensorDescriptor_t xy_desc_ = nullptr;\n  cudnnTensorDescriptor_t param_desc_ = nullptr;\n  cudnnDataType_t param_data_type_;\n  int32_t param_size_ = 0;\n};\n\nsize_t InferTrainWorkspaceSize(const ShapeView& x_shape, const DataType data_type,\n                               const int32_t axis) {\n#if defined(BN_ENABLE_EX_API)\n  cudnnBatchNormMode_t mode = getCudnnBatchNormMode(x_shape.NumAxes());\n  const CudnnTensorDescHelper desc_helper(x_shape, data_type, axis, mode);\n  size_t size_in_bytes;\n  cudnnHandle_t handle = Singleton<CudnnHandlePool>::Get()->Get();\n  OF_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(\n      handle, mode, CUDNN_BATCHNORM_OPS_BN, desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(),\n      desc_helper.param_desc(), nullptr, &size_in_bytes));\n  Singleton<CudnnHandlePool>::Get()->Put(handle);\n  return std::max(size_in_bytes, static_cast<size_t>(1));\n#else\n  return 1;\n#endif\n}\n\nsize_t InferTrainTmpSize(user_op::InferContext* ctx) {\n  const auto& x = ctx->InputTensorDesc(\"x\", 0);\n  const auto axis = ctx->Attr<int32_t>(\"axis\");\n  return InferTrainWorkspaceSize(x.shape(), x.data_type(), axis);\n}\n\nsize_t InferGradWorkspaceSize(const ShapeView& x_shape, const DataType data_type,\n                              const int32_t axis) {\n#if defined(BN_ENABLE_EX_API)\n  cudnnBatchNormMode_t mode = getCudnnBatchNormMode(x_shape.NumAxes());\n  const CudnnTensorDescHelper desc_helper(x_shape, data_type, axis, mode);\n  size_t size_in_bytes;\n  cudnnHandle_t handle = Singleton<CudnnHandlePool>::Get()->Get();\n  OF_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize(\n      handle, mode, CUDNN_BATCHNORM_OPS_BN, desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(),\n      nullptr, desc_helper.xy_desc(), desc_helper.param_desc(), nullptr, &size_in_bytes));\n  Singleton<CudnnHandlePool>::Get()->Put(handle);\n  return std::max(size_in_bytes, static_cast<size_t>(1));\n#else\n  return 1;\n#endif\n}\n\nsize_t InferGradTmpSize(user_op::InferContext* ctx) {\n  const auto& dy = ctx->InputTensorDesc(\"dy\", 0);\n  const auto axis = ctx->Attr<int32_t>(\"axis\");\n  size_t tmp_size = 0;\n  if (ctx->op_type_name() == \"normalization_add_relu_grad\" && !ctx->has_output(\"addend_diff\", 0)) {\n    tmp_size += GetCudaAlignedSize(dy.shape().elem_cnt() * GetSizeOfDataType(dy.data_type()));\n  }\n  tmp_size += GetCudaAlignedSize(InferGradWorkspaceSize(dy.shape(), dy.data_type(), axis));\n  return tmp_size;\n}\n\nconstexpr int64_t kCudaWarpSize = 32;\n\ntemplate<typename T>\n__global__ void ReluGpu(int64_t n, const T* x, T* y, int32_t* mask) {\n  const int32_t lane_id = 
threadIdx.x % kCudaWarpSize;\n  const T zero = static_cast<T>(0.f);\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T x_val = x[i];\n    const bool is_positive = (x_val > zero);\n    int32_t warp_mask = __ballot_sync(__activemask(), static_cast<int>(is_positive));\n    if (lane_id == 0) { mask[i / kCudaWarpSize] = warp_mask; }\n    y[i] = is_positive ? x_val : zero;\n  }\n}\n\ntemplate<typename T>\n__global__ void AddReluGpu(int64_t n, const T* x, const T* addend, T* y, int32_t* mask) {\n  const int32_t lane_id = threadIdx.x % kCudaWarpSize;\n  const T zero = static_cast<T>(0.f);\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    const T sum = x[i] + addend[i];\n    const bool is_positive = (sum > zero);\n    int32_t warp_mask = __ballot_sync(__activemask(), static_cast<int>(is_positive));\n    if (lane_id == 0) { mask[i / kCudaWarpSize] = warp_mask; }\n    y[i] = is_positive ? sum : zero;\n  }\n}\n\ntemplate<typename T>\nvoid Relu(ep::Stream* stream, int64_t n, const T* x, T* y, int32_t* mask) {\n  ReluGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n               stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, y, mask);\n}\n\ntemplate<typename T>\nvoid AddRelu(ep::Stream* stream, int64_t n, const T* x, const T* addend, T* y, int32_t* mask) {\n  AddReluGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                  stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, addend, y, mask);\n}\n\ntemplate<typename T>\n__global__ void ReluBackwardGpu(int64_t n, const int32_t* mask, const T* dy, T* addend_diff) {\n  int32_t lane_id = threadIdx.x % kCudaWarpSize;\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    int32_t mask_val = mask[i / kCudaWarpSize];\n    bool is_positive = mask_val & (1 << lane_id);\n    addend_diff[i] = static_cast<T>(is_positive) * dy[i];\n  }\n}\n\n#if CUDA_VERSION >= 11000\n\ntemplate<>\n__global__ void ReluBackwardGpu<nv_bfloat16>(int64_t n, const int32_t* mask, const nv_bfloat16* dy,\n                                             nv_bfloat16* addend_diff) {\n  int32_t lane_id = threadIdx.x % kCudaWarpSize;\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    int32_t mask_val = mask[i / kCudaWarpSize];\n    bool is_positive = mask_val & (1 << lane_id);\n    addend_diff[i] = static_cast<nv_bfloat16>(static_cast<float>(is_positive)) * dy[i];\n  }\n}\n\n#endif\n\ntemplate<typename T>\nvoid ReluBackward(ep::Stream* stream, int64_t n, const int32_t* mask, const T* dy, T* addend_diff) {\n  ReluBackwardGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                       stream->As<ep::CudaStream>()->cuda_stream()>>>(n, mask, dy, addend_diff);\n}\n\nvoid Relu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x, void* y,\n          int32_t* mask) {\n  if (data_type == kFloat) {\n    Relu<float>(stream, n, reinterpret_cast<const float*>(x), reinterpret_cast<float*>(y), mask);\n  } else if (data_type == kDouble) {\n    Relu<double>(stream, n, reinterpret_cast<const double*>(x), reinterpret_cast<double*>(y), mask);\n  } else if (data_type == kFloat16) {\n    Relu<half>(stream, n, reinterpret_cast<const half*>(x), reinterpret_cast<half*>(y), mask);\n  } else if (data_type == kBFloat16) {\n#if CUDA_VERSION >= 11000\n    Relu<nv_bfloat16>(stream, n, reinterpret_cast<const nv_bfloat16*>(x),\n                      reinterpret_cast<nv_bfloat16*>(y), mask);\n#else\n    UNIMPLEMENTED();\n#endif\n  } else {\n    UNIMPLEMENTED();\n  }\n}\nvoid AddRelu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x,\n             const void* addend, void* y, 
int32_t* mask) {\n  if (data_type == kFloat) {\n    AddRelu<float>(stream, n, reinterpret_cast<const float*>(x),\n                   reinterpret_cast<const float*>(addend), reinterpret_cast<float*>(y), mask);\n  } else if (data_type == kDouble) {\n    AddRelu<double>(stream, n, reinterpret_cast<const double*>(x),\n                    reinterpret_cast<const double*>(addend), reinterpret_cast<double*>(y), mask);\n  } else if (data_type == kFloat16) {\n    AddRelu<half>(stream, n, reinterpret_cast<const half*>(x),\n                  reinterpret_cast<const half*>(addend), reinterpret_cast<half*>(y), mask);\n  } else if (data_type == kBFloat16) {\n#if CUDA_VERSION >= 11000\n    AddRelu<nv_bfloat16>(stream, n, reinterpret_cast<const nv_bfloat16*>(x),\n                         reinterpret_cast<const nv_bfloat16*>(addend),\n                         reinterpret_cast<nv_bfloat16*>(y), mask);\n#else\n    UNIMPLEMENTED();\n#endif\n  } else {\n    UNIMPLEMENTED();\n  }\n}\nvoid ReluBackward(ep::Stream* stream, int64_t n, const DataType data_type, const int32_t* mask,\n                  const void* dy, void* addend_diff) {\n  if (data_type == kFloat) {\n    ReluBackward<float>(stream, n, mask, reinterpret_cast<const float*>(dy),\n                        reinterpret_cast<float*>(addend_diff));\n  } else if (data_type == kDouble) {\n    ReluBackward<double>(stream, n, mask, reinterpret_cast<const double*>(dy),\n                         reinterpret_cast<double*>(addend_diff));\n  } else if (data_type == kFloat16) {\n    ReluBackward<half>(stream, n, mask, reinterpret_cast<const half*>(dy),\n                       reinterpret_cast<half*>(addend_diff));\n  } else if (data_type == kBFloat16) {\n#if CUDA_VERSION >= 11000\n    ReluBackward<nv_bfloat16>(stream, n, mask, reinterpret_cast<const nv_bfloat16*>(dy),\n                              reinterpret_cast<nv_bfloat16*>(addend_diff));\n#else\n    UNIMPLEMENTED();\n#endif\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\nclass NormalizationInferenceKernel final : public user_op::OpKernel,\n                                           public user_op::CudaGraphSupport {\n public:\n  NormalizationInferenceKernel() = default;\n  ~NormalizationInferenceKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const bool training = ctx->Attr<bool>(\"training\");\n    CHECK(!training);\n    const auto* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    auto* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const auto* gamma = ctx->Tensor4ArgNameAndIndex(\"gamma\", 0);\n    const auto* beta = ctx->Tensor4ArgNameAndIndex(\"beta\", 0);\n    auto* moving_mean = ctx->Tensor4ArgNameAndIndex(\"moving_mean\", 0);\n    auto* moving_variance = ctx->Tensor4ArgNameAndIndex(\"moving_variance\", 0);\n    const auto axis = ctx->Attr<int32_t>(\"axis\");\n    const auto epsilon = ctx->Attr<float>(\"epsilon\");\n\n    const DataType data_type = x->data_type();\n    CHECK_EQ(x->shape_view(), y->shape_view());\n    CHECK_EQ(y->data_type(), data_type);\n    CHECK_GE(axis, 0);\n    CHECK_LT(axis, x->shape_view().NumAxes());\n\n    cudnnBatchNormMode_t mode = getCudnnBatchNormMode(x->shape_view().NumAxes());\n    const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis, mode);\n    desc_helper.CheckParamTensor(gamma);\n    desc_helper.CheckParamTensor(beta);\n    desc_helper.CheckParamTensor(moving_mean);\n    desc_helper.CheckParamTensor(moving_variance);\n\n    const void* sp_alpha = 
CudnnSPOnePtr(data_type);\n    const void* sp_beta;\n    if (ctx->has_input(\"_add_to_output\", 0)) {\n      const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n      CHECK_EQ(add_to_output->data_type(), y->data_type());\n      CHECK_EQ(add_to_output->shape_view(), y->shape_view());\n      Memcpy<DeviceType::kCUDA>(\n          ctx->stream(), y->mut_dptr<void>(), add_to_output->dptr<void>(),\n          add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));\n      sp_beta = CudnnSPOnePtr(data_type);\n    } else {\n      sp_beta = CudnnSPZeroPtr(data_type);\n    }\n\n    OF_CUDNN_CHECK(cudnnBatchNormalizationForwardInference(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), mode, sp_alpha, sp_beta,\n        desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), y->mut_dptr(),\n        desc_helper.param_desc(), gamma->dptr(), beta->dptr(), moving_mean->dptr(),\n        moving_variance->dptr(), epsilon));\n\n    if (ctx->op_type_name() == \"normalization_add_relu\") {\n      CHECK(!ctx->has_input(\"_add_to_output\", 0));\n      const int64_t elem_cnt = x->shape_view().elem_cnt();\n      auto* mask = ctx->Tensor4ArgNameAndIndex(\"reserve_space\", 0);\n      if (ctx->has_input(\"addend\", 0)) {\n        const auto* addend = ctx->Tensor4ArgNameAndIndex(\"addend\", 0);\n        AddRelu(ctx->stream(), elem_cnt, data_type, y->dptr(), addend->dptr(), y->mut_dptr(),\n                mask->mut_dptr<int32_t>());\n      } else {\n        Relu(ctx->stream(), elem_cnt, data_type, y->dptr(), y->mut_dptr(),\n             mask->mut_dptr<int32_t>());\n      }\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"normalization\")\n    .SetCreateFn<NormalizationInferenceKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)\n                     && (user_op::HobAttr<bool>(\"training\") == false))\n    .SetInplaceProposalFn([](const user_op::InferContext& ctx,\n                             user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {\n      if (ctx.has_input(\"_add_to_output\", 0)) {\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"y\", 0, \"_add_to_output\", 0, true));\n      }\n      return Maybe<void>::Ok();\n    });\n\nREGISTER_USER_KERNEL(\"normalization_add_relu\")\n    .SetCreateFn<NormalizationInferenceKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)\n                     && (user_op::HobAttr<bool>(\"training\") == false));\n\nclass NormalizationTrainKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  NormalizationTrainKernel() = default;\n  ~NormalizationTrainKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    if (ctx->op_type_name() == \"normalization\") { CHECK(ctx->Attr<bool>(\"training\")); }\n    const auto* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    auto* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n\n    const auto axis = ctx->Attr<int32_t>(\"axis\");\n    const auto epsilon = ctx->Attr<float>(\"epsilon\");\n    const auto momentum = ctx->Attr<float>(\"momentum\");\n\n    const DataType data_type = x->data_type();\n    CHECK_EQ(x->shape_view(), y->shape_view());\n    CHECK_EQ(y->data_type(), data_type);\n    CHECK_GE(axis, 0);\n    CHECK_LT(axis, x->shape_view().NumAxes());\n    cudnnBatchNormMode_t mode = 
getCudnnBatchNormMode(x->shape_view().NumAxes());\n    const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis, mode);\n\n    const auto* gamma = ctx->Tensor4ArgNameAndIndex(\"gamma\", 0);\n    const auto* beta = ctx->Tensor4ArgNameAndIndex(\"beta\", 0);\n    auto* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    auto* inv_variance = ctx->Tensor4ArgNameAndIndex(\"inv_variance\", 0);\n    desc_helper.CheckParamTensor(gamma);\n    desc_helper.CheckParamTensor(beta);\n    desc_helper.CheckParamTensor(mean);\n    desc_helper.CheckParamTensor(inv_variance);\n\n    user_op::Tensor* moving_mean = nullptr;\n    user_op::Tensor* moving_variance = nullptr;\n    if (ctx->has_input(\"moving_mean\", 0)) {\n      CHECK(ctx->has_input(\"moving_variance\", 0));\n      moving_mean = ctx->Tensor4ArgNameAndIndex(\"moving_mean\", 0);\n      moving_variance = ctx->Tensor4ArgNameAndIndex(\"moving_variance\", 0);\n      desc_helper.CheckParamTensor(moving_mean);\n      desc_helper.CheckParamTensor(moving_variance);\n    }\n\n    const void* sp_alpha = CudnnSPOnePtr(data_type);\n    const void* sp_beta;\n    if (ctx->has_input(\"_add_to_output\", 0)) {\n      const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex(\"_add_to_output\", 0);\n      CHECK_EQ(add_to_output->data_type(), y->data_type());\n      CHECK_EQ(add_to_output->shape_view(), y->shape_view());\n      Memcpy<DeviceType::kCUDA>(\n          ctx->stream(), y->mut_dptr<void>(), add_to_output->dptr<void>(),\n          add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));\n      sp_beta = CudnnSPOnePtr(data_type);\n    } else {\n      sp_beta = CudnnSPZeroPtr(data_type);\n    }\n\n#if defined(BN_ENABLE_EX_API)\n    size_t workspace_size;\n    OF_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), mode, CUDNN_BATCHNORM_OPS_BN,\n        desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(), desc_helper.param_desc(), nullptr,\n        &workspace_size));\n    size_t reserve_space_size;\n    OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), mode, CUDNN_BATCHNORM_OPS_BN, nullptr,\n        desc_helper.xy_desc(), &reserve_space_size));\n    auto* workspace = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    if (reserve_space_size == 0 && workspace_size <= workspace->shape_view().elem_cnt()) {\n      OF_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx(\n          ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), mode, CUDNN_BATCHNORM_OPS_BN,\n          sp_alpha, sp_beta, desc_helper.xy_desc(), x->dptr(), nullptr, nullptr,\n          desc_helper.xy_desc(), y->mut_dptr(), desc_helper.param_desc(), gamma->dptr(),\n          beta->dptr(), 1.0 - momentum, moving_mean ? moving_mean->mut_dptr() : NULL,\n          moving_variance ? moving_variance->mut_dptr() : NULL, epsilon, mean->mut_dptr(),\n          inv_variance->mut_dptr(), nullptr, workspace->mut_dptr(),\n          workspace->shape_view().elem_cnt(), nullptr, 0));\n    } else {\n      OF_CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(\n          ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), mode, sp_alpha, sp_beta,\n          desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), y->mut_dptr(),\n          desc_helper.param_desc(), gamma->dptr(), beta->dptr(), 1.0 - momentum,\n          moving_mean ? moving_mean->mut_dptr() : NULL,\n          moving_variance ? 
moving_variance->mut_dptr() : NULL, epsilon, mean->mut_dptr(),\n          inv_variance->mut_dptr()));\n    }\n#else\n    OF_CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), mode, sp_alpha, sp_beta,\n        desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), y->mut_dptr(),\n        desc_helper.param_desc(), gamma->dptr(), beta->dptr(), 1.0 - momentum,\n        moving_mean ? moving_mean->mut_dptr() : NULL,\n        moving_variance ? moving_variance->mut_dptr() : NULL, epsilon, mean->mut_dptr(),\n        inv_variance->mut_dptr()));\n#endif\n\n    if (ctx->op_type_name() == \"normalization_add_relu\") {\n      CHECK(!ctx->has_input(\"_add_to_output\", 0));\n      const int64_t elem_cnt = x->shape_view().elem_cnt();\n      auto* mask = ctx->Tensor4ArgNameAndIndex(\"reserve_space\", 0);\n      if (ctx->has_input(\"addend\", 0)) {\n        const auto* addend = ctx->Tensor4ArgNameAndIndex(\"addend\", 0);\n        AddRelu(ctx->stream(), elem_cnt, data_type, y->dptr(), addend->dptr(), y->mut_dptr(),\n                mask->mut_dptr<int32_t>());\n      } else {\n        Relu(ctx->stream(), elem_cnt, data_type, y->dptr(), y->mut_dptr(),\n             mask->mut_dptr<int32_t>());\n      }\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"normalization\")\n    .SetCreateFn<NormalizationTrainKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)\n                     && (user_op::HobAttr<bool>(\"training\") == true))\n    .SetInferTmpSizeFn(InferTrainTmpSize)\n    .SetInplaceProposalFn([](const user_op::InferContext& ctx,\n                             user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {\n      if (ctx.has_input(\"_add_to_output\", 0)) {\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"y\", 0, \"_add_to_output\", 0, true));\n      }\n      return Maybe<void>::Ok();\n    });\n\nREGISTER_USER_KERNEL(\"normalization_add_relu\")\n    .SetCreateFn<NormalizationTrainKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)\n                     && (user_op::HobAttr<bool>(\"training\") == true))\n    .SetInferTmpSizeFn(InferTrainTmpSize);\n\nclass NormalizationGradUserKernel final : public user_op::OpKernel,\n                                          public user_op::CudaGraphSupport {\n public:\n  NormalizationGradUserKernel() = default;\n  ~NormalizationGradUserKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    auto* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const auto* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const auto* gamma = ctx->Tensor4ArgNameAndIndex(\"gamma\", 0);\n    auto* gamma_diff = ctx->Tensor4ArgNameAndIndex(\"gamma_diff\", 0);\n    auto* beta_diff = ctx->Tensor4ArgNameAndIndex(\"beta_diff\", 0);\n    const auto* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    const auto* inv_variance = ctx->Tensor4ArgNameAndIndex(\"inv_variance\", 0);\n    auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const auto axis = ctx->Attr<int32_t>(\"axis\");\n    const auto epsilon = ctx->Attr<float>(\"epsilon\");\n\n    const DataType data_type = x->data_type();\n    CHECK_EQ(dy->shape_view(), x->shape_view());\n    CHECK_EQ(dy->data_type(), data_type);\n    CHECK_EQ(dx->shape_view(), 
x->shape_view());\n    CHECK_EQ(dx->data_type(), data_type);\n    CHECK_GE(axis, 0);\n    CHECK_LT(axis, x->shape_view().NumAxes());\n    cudnnBatchNormMode_t mode = getCudnnBatchNormMode(x->shape_view().NumAxes());\n    const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis, mode);\n    desc_helper.CheckParamTensor(gamma);\n    desc_helper.CheckParamTensor(gamma_diff);\n    desc_helper.CheckParamTensor(beta_diff);\n    desc_helper.CheckParamTensor(mean);\n    desc_helper.CheckParamTensor(inv_variance);\n\n    void* bn_workspace_ptr;\n    size_t bn_workspace_size;\n    const void* bn_dy_ptr;\n\n    if (ctx->op_type_name() == \"normalization_grad\") {\n      bn_workspace_ptr = tmp_buffer->mut_dptr();\n      bn_workspace_size = tmp_buffer->shape_view().elem_cnt();\n      bn_dy_ptr = dy->dptr();\n    } else if (ctx->op_type_name() == \"normalization_add_relu_grad\") {\n      const int64_t elem_cnt = dy->shape_view().elem_cnt();\n      const auto* mask = ctx->Tensor4ArgNameAndIndex(\"reserve_space\", 0);\n      user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n      if (ctx->has_output(\"addend_diff\", 0)) {\n        user_op::Tensor* addend_diff = ctx->Tensor4ArgNameAndIndex(\"addend_diff\", 0);\n        ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int32_t>(), dy->dptr(),\n                     addend_diff->mut_dptr());\n        bn_workspace_ptr = tmp_buffer->mut_dptr();\n        bn_workspace_size = tmp_buffer->shape_view().elem_cnt();\n        bn_dy_ptr = addend_diff->dptr();\n      } else {\n        const size_t tmp_buffer_size = tmp_buffer->shape_view().elem_cnt();\n        const size_t relu_dx_size =\n            GetCudaAlignedSize(dy->shape_view().elem_cnt() * GetSizeOfDataType(data_type));\n        CHECK_GE(tmp_buffer_size, relu_dx_size);\n        ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int32_t>(), dy->dptr(),\n                     tmp_buffer->mut_dptr());\n        bn_workspace_ptr = tmp_buffer->mut_dptr<char>() + relu_dx_size;\n        bn_workspace_size = tmp_buffer_size - relu_dx_size;\n        bn_dy_ptr = tmp_buffer->dptr();\n      }\n    } else {\n      UNIMPLEMENTED();\n    }\n\n#if defined(BN_ENABLE_EX_API)\n    size_t workspace_size;\n    OF_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), mode, CUDNN_BATCHNORM_OPS_BN,\n        desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(),\n        desc_helper.param_desc(), nullptr, &workspace_size));\n    size_t reserve_space_size;\n    OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), mode, CUDNN_BATCHNORM_OPS_BN, nullptr,\n        desc_helper.xy_desc(), &reserve_space_size));\n    if (reserve_space_size == 0 && workspace_size <= bn_workspace_size) {\n      OF_CUDNN_CHECK(cudnnBatchNormalizationBackwardEx(\n          ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), mode, CUDNN_BATCHNORM_OPS_BN,\n          CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), CudnnSPOnePtr(data_type),\n          CudnnSPZeroPtr(data_type), desc_helper.xy_desc(), x->dptr(), nullptr, nullptr,\n          desc_helper.xy_desc(), bn_dy_ptr, nullptr, nullptr, desc_helper.xy_desc(), dx->mut_dptr(),\n          desc_helper.param_desc(), gamma->dptr(), nullptr, gamma_diff->mut_dptr(),\n          beta_diff->mut_dptr(), epsilon, mean->dptr(), inv_variance->dptr(), nullptr,\n          bn_workspace_ptr, 
bn_workspace_size, nullptr, 0));\n    } else {\n      OF_CUDNN_CHECK(cudnnBatchNormalizationBackward(\n          ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), mode, CudnnSPOnePtr(data_type),\n          CudnnSPZeroPtr(data_type), CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type),\n          desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), bn_dy_ptr, desc_helper.xy_desc(),\n          dx->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), gamma_diff->mut_dptr(),\n          beta_diff->mut_dptr(), epsilon, mean->dptr(), inv_variance->dptr()));\n    }\n#else\n    OF_CUDNN_CHECK(cudnnBatchNormalizationBackward(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), mode, CudnnSPOnePtr(data_type),\n        CudnnSPZeroPtr(data_type), CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type),\n        desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), bn_dy_ptr, desc_helper.xy_desc(),\n        dx->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), gamma_diff->mut_dptr(),\n        beta_diff->mut_dptr(), epsilon, mean->dptr(), inv_variance->dptr()));\n#endif\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"normalization_grad\")\n    .SetCreateFn<NormalizationGradUserKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA))\n    .SetInferTmpSizeFn(InferGradTmpSize);\n\nREGISTER_USER_KERNEL(\"normalization_add_relu_grad\")\n    .SetCreateFn<NormalizationGradUserKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA))\n    .SetInferTmpSizeFn(InferGradTmpSize);\n\n#if (CUDNN_VERSION >= 7401)\n\nsize_t InferFusedNormalizationAddReluTmpSize(user_op::InferContext* ctx) {\n  const auto& x = ctx->InputTensorDesc(\"x\", 0);\n  const auto axis = ctx->Attr<int32_t>(\"axis\");\n  const CudnnTensorDescHelper desc_helper(x.shape(), x.data_type(), axis,\n                                          CUDNN_BATCHNORM_SPATIAL_PERSISTENT);\n  size_t size_in_bytes;\n  cudnnHandle_t handle = Singleton<CudnnHandlePool>::Get()->Get();\n  CudnnActivationDesc activation_desc(CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0);\n  cudnnBatchNormOps_t ops;\n  cudnnTensorDescriptor_t z_desc;\n  if (ctx->has_input(\"addend\", 0)) {\n    ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;\n    z_desc = desc_helper.xy_desc();\n  } else {\n    ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;\n    z_desc = nullptr;\n  }\n  OF_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(\n      handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, desc_helper.xy_desc(), z_desc,\n      desc_helper.xy_desc(), desc_helper.param_desc(), activation_desc.Get(), &size_in_bytes));\n  Singleton<CudnnHandlePool>::Get()->Put(handle);\n  return std::max(size_in_bytes, static_cast<size_t>(1));\n}\n\nsize_t InferFusedNormalizationAddReluGradTmpSize(user_op::InferContext* ctx) {\n  const auto& x = ctx->InputTensorDesc(\"x\", 0);\n  const auto axis = ctx->Attr<int32_t>(\"axis\");\n  const CudnnTensorDescHelper desc_helper(x.shape(), x.data_type(), axis,\n                                          CUDNN_BATCHNORM_SPATIAL_PERSISTENT);\n  size_t size_in_bytes;\n  cudnnHandle_t handle = Singleton<CudnnHandlePool>::Get()->Get();\n  CudnnActivationDesc activation_desc(CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0);\n  cudnnBatchNormOps_t ops;\n  cudnnTensorDescriptor_t z_desc;\n  if (ctx->has_output(\"addend_diff\", 0)) {\n    ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;\n    z_desc = 
desc_helper.xy_desc();\n  } else {\n    ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;\n    z_desc = nullptr;\n  }\n  OF_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize(\n      handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, desc_helper.xy_desc(), desc_helper.xy_desc(),\n      desc_helper.xy_desc(), z_desc, desc_helper.xy_desc(), desc_helper.param_desc(),\n      activation_desc.Get(), &size_in_bytes));\n  Singleton<CudnnHandlePool>::Get()->Put(handle);\n  return std::max(size_in_bytes, static_cast<size_t>(1));\n}\n\nclass FusedNormalizationAddReluKernel final : public user_op::OpKernel,\n                                              public user_op::CudaGraphSupport {\n public:\n  FusedNormalizationAddReluKernel() = default;\n  ~FusedNormalizationAddReluKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    auto* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const auto* gamma = ctx->Tensor4ArgNameAndIndex(\"gamma\", 0);\n    const auto* beta = ctx->Tensor4ArgNameAndIndex(\"beta\", 0);\n    auto* moving_mean = ctx->Tensor4ArgNameAndIndex(\"moving_mean\", 0);\n    auto* moving_variance = ctx->Tensor4ArgNameAndIndex(\"moving_variance\", 0);\n    const auto axis = ctx->Attr<int32_t>(\"axis\");\n    const auto epsilon = ctx->Attr<float>(\"epsilon\");\n    const auto momentum = ctx->Attr<float>(\"momentum\");\n    auto* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    auto* inv_variance = ctx->Tensor4ArgNameAndIndex(\"inv_variance\", 0);\n    auto* reserve_space = ctx->Tensor4ArgNameAndIndex(\"reserve_space\", 0);\n    auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    const DataType data_type = x->data_type();\n    CHECK_EQ(x->shape_view(), y->shape_view());\n    CHECK_EQ(y->data_type(), data_type);\n    CHECK_GE(axis, 0);\n    CHECK_LT(axis, x->shape_view().NumAxes());\n\n    const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis,\n                                            CUDNN_BATCHNORM_SPATIAL_PERSISTENT);\n    desc_helper.CheckParamTensor(gamma);\n    desc_helper.CheckParamTensor(beta);\n    desc_helper.CheckParamTensor(moving_mean);\n    desc_helper.CheckParamTensor(moving_variance);\n    desc_helper.CheckParamTensor(mean);\n    desc_helper.CheckParamTensor(inv_variance);\n\n    CudnnActivationDesc activation_desc(CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0);\n    cudnnTensorDescriptor_t z_desc;\n    const void* z_ptr;\n    cudnnBatchNormOps_t ops;\n    if (ctx->has_input(\"addend\", 0)) {\n      z_desc = desc_helper.xy_desc();\n      z_ptr = ctx->Tensor4ArgNameAndIndex(\"addend\", 0)->dptr();\n      ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;\n    } else {\n      z_desc = nullptr;\n      z_ptr = nullptr;\n      ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;\n    }\n\n    size_t min_workspace_size;\n    OF_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT,\n        ops, desc_helper.xy_desc(), z_desc, desc_helper.xy_desc(), desc_helper.param_desc(),\n        activation_desc.Get(), &min_workspace_size));\n    const size_t workspace_size = tmp_buffer->shape_view().elem_cnt();\n    CHECK_GE(workspace_size, min_workspace_size);\n    size_t min_reserve_space_size;\n    OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(\n        
ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT,\n        ops, activation_desc.Get(), desc_helper.xy_desc(), &min_reserve_space_size));\n    const size_t reserve_space_size = reserve_space->shape_view().elem_cnt();\n    CHECK_GE(reserve_space_size, min_reserve_space_size);\n\n    OF_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT,\n        ops, CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), desc_helper.xy_desc(), x->dptr(),\n        z_desc, z_ptr, desc_helper.xy_desc(), y->mut_dptr(), desc_helper.param_desc(),\n        gamma->dptr(), beta->dptr(), 1.0 - momentum, moving_mean->mut_dptr(),\n        moving_variance->mut_dptr(), epsilon, mean->mut_dptr(), inv_variance->mut_dptr(),\n        activation_desc.Get(), tmp_buffer->mut_dptr(), workspace_size, reserve_space->mut_dptr(),\n        reserve_space_size));\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"cudnn_fused_normalization_add_relu\")\n    .SetCreateFn<FusedNormalizationAddReluKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA))\n    .SetInferTmpSizeFn(InferFusedNormalizationAddReluTmpSize);\n\nclass FusedNormalizationAddReluGradUserKernel final : public user_op::OpKernel,\n                                                      public user_op::CudaGraphSupport {\n public:\n  FusedNormalizationAddReluGradUserKernel() = default;\n  ~FusedNormalizationAddReluGradUserKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const auto* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const auto* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    auto* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const auto* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const auto* gamma = ctx->Tensor4ArgNameAndIndex(\"gamma\", 0);\n    const auto* beta = ctx->Tensor4ArgNameAndIndex(\"beta\", 0);\n    auto* gamma_diff = ctx->Tensor4ArgNameAndIndex(\"gamma_diff\", 0);\n    auto* beta_diff = ctx->Tensor4ArgNameAndIndex(\"beta_diff\", 0);\n    const auto* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    const auto* inv_variance = ctx->Tensor4ArgNameAndIndex(\"inv_variance\", 0);\n    const auto* reserve_space = ctx->Tensor4ArgNameAndIndex(\"reserve_space\", 0);\n    auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const auto axis = ctx->Attr<int32_t>(\"axis\");\n    const auto epsilon = ctx->Attr<float>(\"epsilon\");\n\n    const DataType data_type = x->data_type();\n    CHECK_EQ(dy->shape_view(), x->shape_view());\n    CHECK_EQ(dy->data_type(), data_type);\n    CHECK_EQ(dx->shape_view(), x->shape_view());\n    CHECK_EQ(dx->data_type(), data_type);\n    CHECK_GE(axis, 0);\n    CHECK_LT(axis, x->shape_view().NumAxes());\n\n    const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis,\n                                            CUDNN_BATCHNORM_SPATIAL_PERSISTENT);\n    desc_helper.CheckParamTensor(gamma);\n    desc_helper.CheckParamTensor(beta);\n    desc_helper.CheckParamTensor(gamma_diff);\n    desc_helper.CheckParamTensor(beta_diff);\n    desc_helper.CheckParamTensor(mean);\n    desc_helper.CheckParamTensor(inv_variance);\n\n    CudnnActivationDesc activation_desc(CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0);\n    cudnnTensorDescriptor_t dz_desc;\n    void* dz_ptr;\n    
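// NOTE: when the graph fused an elementwise add into this op there is an\n    // addend_diff output; cuDNN is then asked for CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION\n    // and writes the gradient w.r.t. the addend through dz_desc / dz_ptr below,\n    // otherwise only BN + ReLU is differentiated and dz stays null.\n    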
cudnnBatchNormOps_t ops;\n    if (ctx->has_output(\"addend_diff\", 0)) {\n      dz_desc = desc_helper.xy_desc();\n      dz_ptr = ctx->Tensor4ArgNameAndIndex(\"addend_diff\", 0)->mut_dptr();\n      ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;\n    } else {\n      dz_desc = nullptr;\n      dz_ptr = nullptr;\n      ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;\n    }\n\n    size_t min_workspace_size;\n    OF_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT,\n        ops, desc_helper.xy_desc(), desc_helper.xy_desc(), desc_helper.xy_desc(), dz_desc,\n        desc_helper.xy_desc(), desc_helper.param_desc(), activation_desc.Get(),\n        &min_workspace_size));\n    const size_t workspace_size = tmp_buffer->shape_view().elem_cnt();\n    CHECK_GE(workspace_size, min_workspace_size);\n    size_t min_reserve_space_size;\n    OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT,\n        ops, activation_desc.Get(), desc_helper.xy_desc(), &min_reserve_space_size));\n    const size_t reserve_space_size = reserve_space->shape_view().elem_cnt();\n    CHECK_GE(reserve_space_size, min_reserve_space_size);\n    OF_CUDNN_CHECK(cudnnBatchNormalizationBackwardEx(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT,\n        ops, CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), CudnnSPOnePtr(data_type),\n        CudnnSPZeroPtr(data_type), desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(),\n        y->dptr(), desc_helper.xy_desc(), dy->dptr(), dz_desc, dz_ptr, desc_helper.xy_desc(),\n        dx->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), beta->dptr(),\n        gamma_diff->mut_dptr(), beta_diff->mut_dptr(), epsilon, mean->dptr(), inv_variance->dptr(),\n        activation_desc.Get(), tmp_buffer->mut_dptr(), workspace_size,\n        const_cast<void*>(reserve_space->dptr()), reserve_space_size));\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"cudnn_fused_normalization_add_relu_grad\")\n    .SetCreateFn<FusedNormalizationAddReluGradUserKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA))\n    .SetInferTmpSizeFn(InferFusedNormalizationAddReluGradTmpSize);\n\n#endif\n\n}  // namespace\n}  // namespace oneflow\n\n#endif\n"
  },
  {
    "path": "oneflow/user/kernels/nvtx_range_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\n#ifdef OF_ENABLE_PROFILER\n#include <nvtx3/nvToolsExt.h>\n#endif  // OF_ENABLE_PROFILER\n\nnamespace oneflow {\n\nnamespace {\n\n#ifdef OF_ENABLE_PROFILER\nstatic thread_local HashMap<std::string, nvtxRangeId_t> mark2range_id;\n#endif\n\n}  // namespace\n\nclass NvtxOpKernelState final : public user_op::OpKernelState {\n public:\n  NvtxOpKernelState() : counter_(0) {\n#ifndef OF_ENABLE_PROFILER\n    LOG(WARNING) << \"To use NVTX, run cmake with -DBUILD_PROFILER=ON\";\n#endif\n  }\n  ~NvtxOpKernelState() override = default;\n\n  int64_t counter() const { return counter_; }\n  void IncreaseCount() { counter_ += 1; }\n\n private:\n  int64_t counter_;\n};\n\nclass NvtxStartKernel final : public user_op::OpKernel {\n public:\n  NvtxStartKernel() = default;\n  ~NvtxStartKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<NvtxOpKernelState>();\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const ShapeView& in_shape = in->shape_view();\n    CHECK_EQ(out->shape_view(), in_shape);\n    const DataType in_data_type = in->data_type();\n    CHECK_EQ(out->data_type(), in_data_type);\n    Memcpy<DeviceType::kCUDA>(ctx->stream(), out->mut_dptr<void>(), in->dptr<void>(),\n                              in_shape.elem_cnt() * GetSizeOfDataType(in_data_type));\n#ifdef OF_ENABLE_PROFILER\n    auto* kernel_state = dynamic_cast<NvtxOpKernelState*>(state);\n    const std::string mark_prefix = ctx->Attr<std::string>(\"mark_prefix\");\n    const std::string mark = mark_prefix + \"-\" + std::to_string(kernel_state->counter());\n    nvtxRangeId_t range_id = nvtxRangeStartA(mark.c_str());\n    CHECK(mark2range_id.emplace(mark, range_id).second);\n    kernel_state->IncreaseCount();\n#endif  // OF_ENABLE_PROFILER\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"nvtx_start\")\n    .SetCreateFn<NvtxStartKernel>()\n    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA)\n    .SetInplaceProposalFn([](const user_op::InferContext&,\n                             user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {\n      OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"in\", 0, false));\n      return Maybe<void>::Ok();\n    });\n\nclass NvtxEndKernel final : public user_op::OpKernel {\n public:\n  NvtxEndKernel() = default;\n  ~NvtxEndKernel() override = default;\n\n  
std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<NvtxOpKernelState>();\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const ShapeView& in_shape = in->shape_view();\n    CHECK_EQ(out->shape_view(), in_shape);\n    const DataType in_data_type = in->data_type();\n    CHECK_EQ(out->data_type(), in_data_type);\n#ifdef OF_ENABLE_PROFILER\n    auto* kernel_state = dynamic_cast<NvtxOpKernelState*>(state);\n    const std::string mark_prefix = ctx->Attr<std::string>(\"mark_prefix\");\n    const std::string mark = mark_prefix + \"-\" + std::to_string(kernel_state->counter());\n    auto it = mark2range_id.find(mark);\n    CHECK(it != mark2range_id.end());\n    nvtxRangeId_t range_id = it->second;\n    mark2range_id.erase(it);\n    nvtxRangeEnd(range_id);\n    Memcpy<DeviceType::kCUDA>(ctx->stream(), out->mut_dptr<void>(), in->dptr<void>(),\n                              in_shape.elem_cnt() * GetSizeOfDataType(in_data_type));\n    kernel_state->IncreaseCount();\n#endif\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"nvtx_end\")\n    .SetCreateFn<NvtxEndKernel>()\n    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA)\n    .SetInplaceProposalFn([](const user_op::InferContext&,\n                             user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {\n      OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"in\", 0, false));\n      return Maybe<void>::Ok();\n    });\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/ofrecord_decoder_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/blocking_counter.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/user/image/random_crop_generator.h\"\n#include \"oneflow/user/image/image_util.h\"\n#include \"oneflow/user/kernels/random_crop_kernel_state.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n#include \"oneflow/user/image/jpeg_decoder.h\"\n\n#include <opencv2/opencv.hpp>\n#include <jpeglib.h>\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nvoid DecodeOneRawOFRecord(const Feature& feature, T* dptr, int64_t sample_elem_cnt, bool truncate,\n                          bool dim1_varying_length) {\n  if (feature.has_bytes_list()) {\n    CHECK_EQ(feature.bytes_list().value_size(), 1);\n    const auto& value0 = feature.bytes_list().value(0);\n    auto in_dptr = reinterpret_cast<const int8_t*>(value0.c_str());\n    sample_elem_cnt = std::min<int64_t>(sample_elem_cnt, value0.size());\n    std::transform(in_dptr, in_dptr + sample_elem_cnt, dptr,\n                   [](int8_t v) { return static_cast<T>(v); });\n  }\n#define DEFINE_ONE_ELIF(PbT, CppT)                                                       \\\n  else if (feature.has_##PbT##_list()) {                                                 \\\n    const auto& list = feature.PbT##_list();                                             \\\n    const CppT* in_dptr = list.value().data();                                           \\\n    const int64_t padding_elem_num = truncate ? 
sample_elem_cnt - list.value_size() : 0; \\\n    if (truncate) {                                                                      \\\n      sample_elem_cnt = std::min<int64_t>(sample_elem_cnt, list.value_size());           \\\n    } else {                                                                             \\\n      if (dim1_varying_length) {                                                         \\\n        sample_elem_cnt = list.value_size();                                             \\\n      } else {                                                                           \\\n        CHECK_EQ(sample_elem_cnt, list.value_size());                                    \\\n      }                                                                                  \\\n    }                                                                                    \\\n    std::transform(in_dptr, in_dptr + sample_elem_cnt, dptr,                             \\\n                   [](CppT v) { return static_cast<T>(v); });                            \\\n    if (padding_elem_num > 0) {                                                          \\\n      std::memset(dptr + sample_elem_cnt, 0, padding_elem_num * sizeof(T));              \\\n    }                                                                                    \\\n  }\n  DEFINE_ONE_ELIF(float, float)\n  DEFINE_ONE_ELIF(double, double)\n  DEFINE_ONE_ELIF(int32, int32_t)\n  DEFINE_ONE_ELIF(int64, int64_t)\n#undef DEFINE_ONE_ELIF\n  else {\n    UNIMPLEMENTED();\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass OFRecordRawDecoderKernel final : public user_op::OpKernel {\n public:\n  OFRecordRawDecoderKernel() = default;\n  ~OFRecordRawDecoderKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    // TODO(chengcheng): remove record num in record blob, fix by shape elem cnt\n    int64_t record_num = in_blob->shape_view().At(0);\n    int64_t sample_elem_cnt = out_blob->shape_view().Count(1);\n    CHECK(record_num > 0);\n    const OFRecord* records = in_blob->dptr<OFRecord>();\n    T* out_dptr = out_blob->mut_dptr<T>();\n    const std::string& name = ctx->Attr<std::string>(\"name\");\n\n    bool truncate = ctx->Attr<bool>(\"truncate\");\n    bool dim1_varying_length = ctx->Attr<bool>(\"dim1_varying_length\");\n\n    MultiThreadLoop(record_num, [&](size_t i) {\n      const OFRecord& record = *(records + i);\n      T* dptr = out_dptr + i * sample_elem_cnt;\n      CHECK(record.feature().find(name) != record.feature().end())\n          << \"Field \" << name << \" not found\";\n      const Feature& feature = record.feature().at(name);\n      DecodeOneRawOFRecord(feature, dptr, sample_elem_cnt, truncate, dim1_varying_length);\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_RAW_DECODER_KERNEL(dtype)                                       \\\n  REGISTER_USER_KERNEL(\"ofrecord_raw_decoder\")                                   \\\n      .SetCreateFn<OFRecordRawDecoderKernel<dtype>>()                            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)            \\\n                       && (user_op::HobDataType(\"in\", 0) == DataType::kOFRecord) \\\n                       && (user_op::HobDataType(\"out\", 0) == 
GetDataType<dtype>::value));\n\nREGISTER_RAW_DECODER_KERNEL(char)\nREGISTER_RAW_DECODER_KERNEL(float)\nREGISTER_RAW_DECODER_KERNEL(double)\nREGISTER_RAW_DECODER_KERNEL(int8_t)\nREGISTER_RAW_DECODER_KERNEL(int32_t)\nREGISTER_RAW_DECODER_KERNEL(int64_t)\nREGISTER_RAW_DECODER_KERNEL(uint8_t)\n\nclass OFRecordBytesDecoderKernel final : public user_op::OpKernel {\n public:\n  OFRecordBytesDecoderKernel() = default;\n  ~OFRecordBytesDecoderKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(out->shape_view(), in->shape_view());\n    CHECK_EQ(in->data_type(), DataType::kOFRecord);\n    CHECK_EQ(out->data_type(), DataType::kTensorBuffer);\n    const int64_t num_instances = in->shape_view().elem_cnt();\n    const auto* records = in->dptr<OFRecord>();\n    auto* buffers = out->mut_dptr<TensorBuffer>();\n    const std::string& name = ctx->Attr<std::string>(\"name\");\n    MultiThreadLoop(num_instances, [&](size_t i) {\n      const OFRecord& record = *(records + i);\n      TensorBuffer* buffer = buffers + i;\n      auto it = record.feature().find(name);\n      CHECK(it != record.feature().end()) << \"Field \" << name << \" not found\";\n      const Feature& feature = it->second;\n      CHECK(feature.has_bytes_list());\n      CHECK_EQ(feature.bytes_list().value_size(), 1);\n      const int64_t size = feature.bytes_list().value(0).size();\n      buffer->Resize(Shape({size}), DataType::kUInt8);\n      memcpy(buffer->mut_data(), feature.bytes_list().value(0).data(), size);\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"ofrecord_bytes_decoder\")\n    .SetCreateFn<OFRecordBytesDecoderKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"in\", 0) == DataType::kOFRecord)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kTensorBuffer));\n\nnamespace {\n\nvoid DecodeRandomCropImageFromOneRecord(const OFRecord& record, TensorBuffer* buffer,\n                                        const std::string& name, const std::string& color_space,\n                                        RandomCropGenerator* random_crop_gen) {\n  CHECK(record.feature().find(name) != record.feature().end()) << \"Field \" << name << \" not found\";\n  const Feature& feature = record.feature().at(name);\n  CHECK(feature.has_bytes_list());\n  CHECK(feature.bytes_list().value_size() == 1);\n  const std::string& src_data = feature.bytes_list().value(0);\n  cv::Mat image;\n\n  if (JpegPartialDecodeRandomCropImage(reinterpret_cast<const unsigned char*>(src_data.data()),\n                                       src_data.size(), random_crop_gen, nullptr, 0, &image)) {\n    // convert color space\n    // jpeg decode output RGB\n    if (ImageUtil::IsColor(color_space) && color_space != \"RGB\") {\n      ImageUtil::ConvertColor(\"RGB\", image, color_space, image);\n    }\n  } else {\n    OpenCvPartialDecodeRandomCropImage(reinterpret_cast<const unsigned char*>(src_data.data()),\n                                       src_data.size(), random_crop_gen, color_space, image);\n    // convert color space\n    // opencv decode output BGR\n    if (ImageUtil::IsColor(color_space) && color_space != \"BGR\") {\n      ImageUtil::ConvertColor(\"BGR\", image, color_space, image);\n    }\n  
}\n\n  int W = image.cols;\n  int H = image.rows;\n\n  CHECK(image.isContinuous());\n  const int c = ImageUtil::IsColor(color_space) ? 3 : 1;\n  CHECK_EQ(c, image.channels());\n  Shape image_shape({H, W, c});\n  buffer->Resize(image_shape, DataType::kUInt8);\n  CHECK_EQ(image_shape.elem_cnt(), buffer->nbytes());\n  CHECK_EQ(image_shape.elem_cnt(), image.total() * image.elemSize());\n  memcpy(buffer->mut_data<uint8_t>(), image.ptr(), image_shape.elem_cnt());\n}\n\n}  // namespace\n\nclass OFRecordImageDecoderRandomCropKernel final : public user_op::OpKernel {\n public:\n  OFRecordImageDecoderRandomCropKernel() = default;\n  ~OFRecordImageDecoderRandomCropKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return CreateRandomCropKernelState(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* crop_window_generators = dynamic_cast<RandomCropKernelState*>(state);\n    CHECK_NOTNULL(crop_window_generators);\n    user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    int64_t record_num = out_blob->shape_view().At(0);\n    CHECK(record_num > 0);\n    user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    CHECK_EQ(out_blob->shape_view(), in_blob->shape_view());\n    const OFRecord* records = in_blob->dptr<OFRecord>();\n    TensorBuffer* buffers = out_blob->mut_dptr<TensorBuffer>();\n    const std::string& name = ctx->Attr<std::string>(\"name\");\n    const std::string& color_space = ctx->Attr<std::string>(\"color_space\");\n\n    MultiThreadLoop(record_num, [&](size_t i) {\n      const OFRecord& record = *(records + i);\n      TensorBuffer* buffer = buffers + i;\n      RandomCropGenerator* gen = crop_window_generators->GetGenerator(i);\n      DecodeRandomCropImageFromOneRecord(record, buffer, name, color_space, gen);\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"ofrecord_image_decoder_random_crop\")\n    .SetCreateFn<OFRecordImageDecoderRandomCropKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"in\", 0) == DataType::kOFRecord)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kTensorBuffer));\n\nclass OFRecordImageDecoderKernel final : public user_op::OpKernel {\n public:\n  OFRecordImageDecoderKernel() = default;\n  ~OFRecordImageDecoderKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    int64_t record_num = out_blob->shape_view().At(0);\n    CHECK(record_num > 0);\n    user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    CHECK_EQ(out_blob->shape_view(), in_blob->shape_view());\n    const OFRecord* records = in_blob->dptr<OFRecord>();\n    TensorBuffer* buffers = out_blob->mut_dptr<TensorBuffer>();\n    const std::string& name = ctx->Attr<std::string>(\"name\");\n    const std::string& color_space = ctx->Attr<std::string>(\"color_space\");\n\n    MultiThreadLoop(record_num, [&](size_t i) {\n      const OFRecord& record = *(records + i);\n      TensorBuffer* buffer = buffers + i;\n      DecodeRandomCropImageFromOneRecord(record, buffer, name, color_space, nullptr);\n    });\n  }\n  bool 
AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"ofrecord_image_decoder\")\n    .SetCreateFn<OFRecordImageDecoderKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"in\", 0) == DataType::kOFRecord)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kTensorBuffer));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/ofrecord_image_classification_reader_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/data/ofrecord_image_classification_data_reader.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass OFRecordImageClassificationReaderKernelState final : public user_op::OpKernelState {\n public:\n  explicit OFRecordImageClassificationReaderKernelState(user_op::KernelInitContext* ctx)\n      : reader_(ctx) {}\n  ~OFRecordImageClassificationReaderKernelState() override = default;\n\n  void Read(user_op::KernelComputeContext* ctx) { reader_.Read(ctx); }\n\n private:\n  data::OFRecordImageClassificationDataReader reader_;\n};\n\n}  // namespace\n\nclass OFRecordImageClassificationReaderKernel final : public user_op::OpKernel {\n public:\n  OFRecordImageClassificationReaderKernel() = default;\n  ~OFRecordImageClassificationReaderKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<OFRecordImageClassificationReaderKernelState>(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* reader = dynamic_cast<OFRecordImageClassificationReaderKernelState*>(state);\n    CHECK_NOTNULL(reader);\n    reader->Read(ctx);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"ofrecord_image_classification_reader\")\n    .SetCreateFn<OFRecordImageClassificationReaderKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"image\", 0) == DataType::kTensorBuffer)\n                     && (user_op::HobDataType(\"label\", 0) == DataType::kTensorBuffer));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/ofrecord_reader_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/data/ofrecord_data_reader.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass OFRecordReaderWrapper final : public user_op::OpKernelState {\n public:\n  explicit OFRecordReaderWrapper(user_op::KernelInitContext* ctx) : reader_(ctx) {}\n  ~OFRecordReaderWrapper() = default;\n\n  void Read(user_op::KernelComputeContext* ctx) { reader_.Read(ctx); }\n\n private:\n  data::OFRecordDataReader reader_;\n};\n\n}  // namespace\n\nclass OFRecordReaderKernel final : public user_op::OpKernel {\n public:\n  OFRecordReaderKernel() = default;\n  ~OFRecordReaderKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    std::shared_ptr<OFRecordReaderWrapper> reader(new OFRecordReaderWrapper(ctx));\n    return reader;\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* reader = dynamic_cast<OFRecordReaderWrapper*>(state);\n    reader->Read(ctx);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"OFRecordReader\")\n    .SetCreateFn<OFRecordReaderKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kOFRecord));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/one_embedding_data_shuffle.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/device/nccl_util.h\"\n#include \"oneflow/core/job/eager_nccl_comm_manager.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/user/kernels/gather_kernel_util.h\"\n#include \"oneflow/user/kernels/unsorted_segment_sum_kernel_util.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/embedding/hash_functions.cuh\"\n\nnamespace oneflow {\n\nnamespace data_shuffle {\n\ntemplate<typename K>\nstruct TableEntry {\n  K key;\n  uint32_t value;\n};\n\ntemplate<typename U>\n__global__ void GenerateTableIds(int32_t elem_cnt, int32_t num_tables, U* table_ids) {\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) { table_ids[i] = i % num_tables; }\n}\n\nnamespace {\n\nconstexpr uint32_t PADDING_REV_INDEX = 0xffffffff;\n\ntemplate<typename K, typename V, typename IDX, typename HASH>\n__global__ void HashTableUniqueAndPartitionPairs(\n    const uint32_t table_capacity, const uint32_t num_keys, int32_t num_partition,\n    IDX* unique_counts, TableEntry<K>* table, const K* keys, const V* values,\n    K* partitioned_unique_keys, V* partitioned_unique_values, IDX* reverse_index,\n    bool need_process_values, const bool has_padding_idx, const int64_t padding_idx) {\n  CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_keys) {\n    IDX r_index_plus_one = 0;\n    const K key = keys[i];\n    if (has_padding_idx && key == padding_idx) {\n      reverse_index[i] = PADDING_REV_INDEX;\n    } else {\n      size_t key_hash = HASH()(key);\n      uint32_t partition_id = key_hash % num_partition;\n      IDX* unique_count = unique_counts + partition_id;\n      K* unique_keys = partitioned_unique_keys + partition_id * num_keys;\n      uint32_t pos = key_hash % table_capacity;\n      const K key_hi = (key | 0x1);\n      const K key_lo = (key & 0x1);\n      uint32_t counter = 0;\n      while (r_index_plus_one == 0) {\n        bool prob_next = false;\n        K* key_ptr = &table[pos].key;\n        volatile uint32_t* table_value_ptr = &table[pos].value;\n        const K old_key = cuda::atomic::CAS(key_ptr, 0, key_hi);\n        if (old_key == 0) {\n          IDX unique_pos = cuda::atomic::Add(unique_count, 1);\n          r_index_plus_one = unique_pos + 1;\n          unique_keys[unique_pos] = key;\n          if (need_process_values) {\n            partitioned_unique_values[partition_id * num_keys + unique_pos] = values[i];\n          }\n          *table_value_ptr = ((r_index_plus_one << 1U) | key_lo);\n        } else if (old_key == key_hi) {\n          const uint32_t value = *table_value_ptr;\n          if (value == 0) {\n            // do nothing\n          } else if ((value & 0x1) == key_lo) {\n            r_index_plus_one = (value >> 1U);\n          } else {\n            prob_next = true;\n          }\n        } else {\n          prob_next = true;\n        }\n        if (prob_next) {\n    
      pos += 1;\n          counter += 1;\n          if (pos >= table_capacity) { pos -= table_capacity; }\n          if (counter >= table_capacity) { __trap(); }\n        }\n      }\n      reverse_index[i] = partition_id * num_keys + r_index_plus_one - 1;\n    }\n  }\n}\n\ntemplate<typename IDX>\n__global__ void ComputeOffset(int32_t n, IDX* value) {\n  IDX sum = 0;\n  for (int i = 0; i < n; ++i) {\n    IDX count = value[i];\n    value[i] = sum;\n    sum += count;\n  }\n}\n\ntemplate<typename IDX>\n__global__ void ContiguousInverseUniquePartitionIndices(const int32_t num_ids, IDX* indices_offset,\n                                                        IDX* inverse_ptr) {\n  CUDA_1D_KERNEL_LOOP(i, num_ids) {\n    int inverse_indice = inverse_ptr[i];\n    int partition_id = inverse_indice / num_ids;\n    int partition_indice = inverse_indice - partition_id * num_ids;\n    int new_offset = indices_offset[partition_id];\n    inverse_ptr[i] = new_offset + partition_indice;\n  }\n}\n\ntemplate<typename T>\nvoid ShuffleData(cudaStream_t cuda_stream, ncclComm_t comm, DataType data_type,\n                 const std::vector<int64_t>& send_offsets,\n                 const std::vector<int64_t>& send_elem_cnt, const T* send_data,\n                 const std::vector<int64_t>& recv_offsets,\n                 const std::vector<int64_t>& recv_elem_cnt, T* recv_data) {\n  ncclDataType_t nccl_data_type = GetNcclDataType(data_type);\n  const int64_t parallel_num = send_offsets.size();\n  OF_NCCL_CHECK(ncclGroupStart());\n  for (int64_t i = 0; i < parallel_num; ++i) {\n    OF_NCCL_CHECK(ncclSend(send_data + send_offsets.at(i), send_elem_cnt.at(i), nccl_data_type, i,\n                           comm, cuda_stream));\n    OF_NCCL_CHECK(ncclRecv(recv_data + recv_offsets.at(i), recv_elem_cnt.at(i), nccl_data_type, i,\n                           comm, cuda_stream));\n  }\n  OF_NCCL_CHECK(ncclGroupEnd());\n}\n\ntemplate<typename IDX>\nvoid MakeShuffleIdParams(const IDX* host_num_unique_matrix, const int64_t num_ids,\n                         const int64_t row_size, int64_t parallel_id, int64_t parallel_num,\n                         std::vector<int64_t>* scatter_offset_vec,\n                         std::vector<int64_t>* scatter_elem_cnt_vec,\n                         std::vector<int64_t>* gather_offset_vec,\n                         std::vector<int64_t>* gather_elem_cnt_vec) {\n  scatter_offset_vec->resize(parallel_num);\n  scatter_elem_cnt_vec->resize(parallel_num);\n  gather_offset_vec->resize(parallel_num);\n  gather_elem_cnt_vec->resize(parallel_num);\n  int64_t gather_offset = 0;\n  for (int64_t i = 0; i < parallel_num; ++i) {\n    const int64_t scatter_elem_cnt =\n        host_num_unique_matrix[parallel_id * parallel_num + i] * row_size;\n    const int64_t gather_elem_cnt =\n        host_num_unique_matrix[i * parallel_num + parallel_id] * row_size;\n    scatter_offset_vec->at(i) = i * num_ids * row_size;\n    scatter_elem_cnt_vec->at(i) = scatter_elem_cnt;\n    gather_offset_vec->at(i) = gather_offset;\n    gather_elem_cnt_vec->at(i) = gather_elem_cnt;\n    gather_offset += gather_elem_cnt;\n  }\n}\n\ntemplate<typename IDX>\nvoid MakeShuffleParams(const IDX* host_num_unique_matrix, const int64_t num_ids,\n                       const int64_t row_size, int64_t parallel_id, int64_t parallel_num,\n                       std::vector<int64_t>* scatter_offset_vec,\n                       std::vector<int64_t>* scatter_elem_cnt_vec,\n                       std::vector<int64_t>* gather_offset_vec,\n                    
   std::vector<int64_t>* gather_elem_cnt_vec) {\n  scatter_offset_vec->resize(parallel_num);\n  scatter_elem_cnt_vec->resize(parallel_num);\n  gather_offset_vec->resize(parallel_num);\n  gather_elem_cnt_vec->resize(parallel_num);\n  int64_t gather_offset = 0;\n  int64_t scatter_offset = 0;\n  for (int64_t i = 0; i < parallel_num; ++i) {\n    const int64_t scatter_elem_cnt =\n        host_num_unique_matrix[parallel_id * parallel_num + i] * row_size;\n    const int64_t gather_elem_cnt =\n        host_num_unique_matrix[i * parallel_num + parallel_id] * row_size;\n    scatter_offset_vec->at(i) = scatter_offset;\n    scatter_elem_cnt_vec->at(i) = scatter_elem_cnt;\n    gather_offset_vec->at(i) = gather_offset;\n    gather_elem_cnt_vec->at(i) = gather_elem_cnt;\n    scatter_offset += scatter_elem_cnt;\n    gather_offset += gather_elem_cnt;\n  }\n}\n\ntemplate<typename K, typename U, typename IDX>\nvoid ShuffleIdsAndTableIds(cudaStream_t cuda_stream, ncclComm_t comm, int64_t parallel_id,\n                           int64_t parallel_num, int64_t num_ids, DataType ids_data_type,\n                           DataType table_ids_data_type, IDX* host_num_unique_matrix,\n                           K* partitioned_unique_ids, U* partitioned_unique_table_ids,\n                           K* received_ids, U* received_table_ids, int64_t* received_elem_cnt,\n                           bool need_process_table_ids) {\n  std::vector<int64_t> send_offsets;\n  std::vector<int64_t> send_elem_cnt;\n  std::vector<int64_t> recv_offsets;\n  std::vector<int64_t> recv_elem_cnt;\n  MakeShuffleIdParams(host_num_unique_matrix, num_ids, 1, parallel_id, parallel_num, &send_offsets,\n                      &send_elem_cnt, &recv_offsets, &recv_elem_cnt);\n  ShuffleData(cuda_stream, comm, ids_data_type, send_offsets, send_elem_cnt, partitioned_unique_ids,\n              recv_offsets, recv_elem_cnt, received_ids);\n  *received_elem_cnt = recv_offsets.at(parallel_num - 1) + recv_elem_cnt.at(parallel_num - 1);\n  if (need_process_table_ids) {\n    ShuffleData(cuda_stream, comm, table_ids_data_type, send_offsets, send_elem_cnt,\n                partitioned_unique_table_ids, recv_offsets, recv_elem_cnt, received_table_ids);\n  }\n}\n\ntemplate<typename K, typename IDX>\n__global__ void UnsortedSegmentHalfGpu(const IDX in_h2_elem_cnt, const IDX h2_inner_dim_size,\n                                       const IDX inner_dim_size, const half* data,\n                                       const K* segment_ids, const IDX num_segments,\n                                       half2* out_h2) {\n  CUDA_1D_KERNEL_LOOP_T(IDX, i, in_h2_elem_cnt) {\n    const IDX segment_id_idx = i / h2_inner_dim_size;\n    const IDX h2_inner_idx = i - segment_id_idx * h2_inner_dim_size;\n    const IDX inner_idx_0 = 2 * h2_inner_idx;\n    const IDX inner_idx_1 = inner_idx_0 + 1;\n    const half* data_row = data + segment_id_idx * inner_dim_size;\n    half2 val;\n    val.x = data_row[inner_idx_0];\n    val.y = (inner_idx_1 >= inner_dim_size) ? 
static_cast<half>(0) : data_row[inner_idx_1];\n    const IDX idx = segment_ids[segment_id_idx];\n    const IDX out_h2_offset = idx * h2_inner_dim_size + h2_inner_idx;\n    cuda::atomic::Add(out_h2 + out_h2_offset, val);\n  }\n}\n\ntemplate<typename T, typename K>\nstruct UnsortedSegmentSumPad {\n  void operator()(ep::Stream* stream, const K* segment_ids, const T* data, int64_t num_segment_ids,\n                  int64_t num_segments, int64_t inner_dim_size, int64_t padded_inner_dim_size,\n                  T* out) const {\n    UNIMPLEMENTED();\n  }\n};\n\ntemplate<typename K>\nstruct UnsortedSegmentSumPad<half, K> {\n  void operator()(ep::Stream* stream, const K* segment_ids, const half* data,\n                  int64_t num_segment_ids, int64_t num_segments, int64_t inner_dim_size,\n                  int64_t padded_inner_dim_size, half* out) const {\n    const int64_t data_elem_cnt = num_segment_ids * inner_dim_size;\n    const int64_t out_elem_cnt = num_segments * padded_inner_dim_size;\n    CHECK_EQ(padded_inner_dim_size % 2, 0);\n    CHECK_EQ(inner_dim_size + 1, padded_inner_dim_size);\n    const int64_t h2_inner_dim_size = padded_inner_dim_size / 2;\n    const int64_t in_h2_elem_cnt = num_segment_ids * h2_inner_dim_size;\n    if (std::max(data_elem_cnt, out_elem_cnt) < GetMaxVal<int32_t>() / 2) {\n      UnsortedSegmentHalfGpu<K, int32_t>\n          <<<BlocksNum4ThreadsNum(in_h2_elem_cnt), kCudaThreadsNumPerBlock, 0,\n             stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              in_h2_elem_cnt, h2_inner_dim_size, inner_dim_size, data, segment_ids, num_segments,\n              reinterpret_cast<half2*>(out));\n    } else {\n      UnsortedSegmentHalfGpu<K, int64_t>\n          <<<BlocksNum4ThreadsNum(in_h2_elem_cnt), kCudaThreadsNumPerBlock, 0,\n             stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              in_h2_elem_cnt, h2_inner_dim_size, inner_dim_size, data, segment_ids, num_segments,\n              reinterpret_cast<half2*>(out));\n    }\n  }\n};\n\ntemplate<typename T, typename K>\nvoid UnsortedSegmentSum(ep::Stream* stream, const K* segment_ids, const T* data,\n                        int64_t num_segment_ids, int64_t num_segments, int64_t inner_dim_size,\n                        int64_t padded_inner_dim_size, T* out) {\n  if (inner_dim_size == padded_inner_dim_size) {\n    UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, T, K, T>::UnsortedSegmentSum(\n        stream, segment_ids, data, num_segment_ids, num_segments, 1, inner_dim_size, 0, out);\n  } else {\n    CHECK_EQ(inner_dim_size + 1, padded_inner_dim_size);\n    UnsortedSegmentSumPad<T, K>()(stream, segment_ids, data, num_segment_ids, num_segments,\n                                  inner_dim_size, padded_inner_dim_size, out);\n  }\n}\n\n}  // namespace\n\ntemplate<typename K, typename V, typename IDX, typename HASH>\nvoid UniqueAndPartition(cudaStream_t cuda_stream, int64_t num_ids, size_t capacity,\n                        int64_t num_partition, const K* ids, const V* table_ids,\n                        IDX* num_partitioned_unique_ids_ptr, K* partitioned_unique_ids,\n                        V* partitioned_unique_table_ids, IDX* inverse_unique_partition_indices,\n                        void* workspace_ptr, size_t workspace_bytes, bool need_process_table_ids,\n                        const bool has_padding_idx, const int64_t padding_idx) {\n  size_t table_capacity_bytes = capacity * sizeof(TableEntry<K>);\n  CHECK_GE(workspace_bytes, table_capacity_bytes);\n  
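// The workspace doubles as the open-addressing hash table, so it must be zeroed before each\n  // launch: HashTableUniqueAndPartitionPairs claims an empty slot via cuda::atomic::CAS(key_ptr,\n  // 0, key | 0x1), so a slot key of 0 marks an empty slot and every stored key has its low bit\n  // set; the original key's low bit lives in the slot value ((r_index_plus_one << 1U) | key_lo)\n  // to distinguish the two keys that share one key_hi.\n  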
OF_CUDA_CHECK(cudaMemsetAsync(workspace_ptr, 0, table_capacity_bytes, cuda_stream));\n  OF_CUDA_CHECK(\n      cudaMemsetAsync(num_partitioned_unique_ids_ptr, 0, num_partition * sizeof(IDX), cuda_stream));\n  HashTableUniqueAndPartitionPairs<K, V, IDX, HASH>\n      <<<BlocksNum4ThreadsNum(num_ids), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(\n          capacity, num_ids, num_partition, num_partitioned_unique_ids_ptr,\n          reinterpret_cast<TableEntry<K>*>(workspace_ptr), ids, table_ids, partitioned_unique_ids,\n          partitioned_unique_table_ids, inverse_unique_partition_indices, need_process_table_ids,\n          has_padding_idx, padding_idx);\n}\n\ntemplate<typename T, typename IDX>\nvoid ShuffleEmbeddings(cudaStream_t cuda_stream, ncclComm_t comm, int64_t parallel_id,\n                       int64_t parallel_num, int64_t num_ids, int64_t embedding_size,\n                       DataType data_type, IDX* host_num_unique_matrix,\n                       const T* reverse_unique_cur_rank_embeddings, T* received_embeddings) {\n  std::vector<int64_t> send_offsets;\n  std::vector<int64_t> send_elem_cnt;\n  std::vector<int64_t> recv_offsets;\n  std::vector<int64_t> recv_elem_cnt;\n  MakeShuffleParams(host_num_unique_matrix, num_ids, embedding_size, parallel_id, parallel_num,\n                    &recv_offsets, &recv_elem_cnt, &send_offsets, &send_elem_cnt);\n  ShuffleData(cuda_stream, comm, data_type, send_offsets, send_elem_cnt,\n              reverse_unique_cur_rank_embeddings, recv_offsets, recv_elem_cnt, received_embeddings);\n}\n\n// Quantized Version.\ntemplate<typename T, typename IDX>\nvoid ShuffleEmbeddings(cudaStream_t cuda_stream, ncclComm_t comm, int64_t parallel_id,\n                       int64_t parallel_num, int64_t num_ids, int64_t embedding_size,\n                       DataType data_type, IDX* host_num_unique_matrix,\n                       int8_t* reverse_unique_cur_rank_embeddings, int8_t* received_embeddings,\n                       T* reverse_cur_rank_quantize_factor, T* recv_quantize_factor) {\n  std::vector<int64_t> send_offsets;\n  std::vector<int64_t> send_elem_cnt;\n  std::vector<int64_t> recv_offsets;\n  std::vector<int64_t> recv_elem_cnt;\n  // shuffle quantized_embedding\n  MakeShuffleParams(host_num_unique_matrix, num_ids, embedding_size, parallel_id, parallel_num,\n                    &recv_offsets, &recv_elem_cnt, &send_offsets, &send_elem_cnt);\n  ShuffleData(cuda_stream, comm, DataType::kInt8, send_offsets, send_elem_cnt,\n              reverse_unique_cur_rank_embeddings, recv_offsets, recv_elem_cnt, received_embeddings);\n  // shuffle quantize_factor\n  MakeShuffleParams(host_num_unique_matrix, num_ids, /*embedding_size=*/1, parallel_id,\n                    parallel_num, &recv_offsets, &recv_elem_cnt, &send_offsets, &send_elem_cnt);\n  ShuffleData(cuda_stream, comm, data_type, send_offsets, send_elem_cnt,\n              reverse_cur_rank_quantize_factor, recv_offsets, recv_elem_cnt, recv_quantize_factor);\n}\n\ntemplate<typename T, typename IDX>\nvoid ShuffleEmbeddingsGrad(cudaStream_t cuda_stream, ncclComm_t comm, int64_t parallel_id,\n                           int64_t parallel_num, int64_t num_ids, int64_t embedding_size,\n                           DataType data_type, IDX* host_num_unique_matrix,\n                           const T* unique_partition_embedding_grad, T* received_embeddings_grad) {\n  std::vector<int64_t> send_offsets;\n  std::vector<int64_t> send_elem_cnt;\n  std::vector<int64_t> recv_offsets;\n  std::vector<int64_t> recv_elem_cnt;\n  
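// Direction note: the gradient shuffle binds scatter_* to send_*, so each rank sends\n  // host_num_unique_matrix[parallel_id * parallel_num + i] * embedding_size elements to the\n  // rank i that owns those unique ids; this is the reverse of ShuffleEmbeddings above, which\n  // binds scatter_* to recv_* to route embeddings from the owning rank back to the requesting\n  // ranks.\n  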
MakeShuffleParams(host_num_unique_matrix, num_ids, embedding_size, parallel_id, parallel_num,\n                    &send_offsets, &send_elem_cnt, &recv_offsets, &recv_elem_cnt);\n  ShuffleData(cuda_stream, comm, data_type, send_offsets, send_elem_cnt,\n              unique_partition_embedding_grad, recv_offsets, recv_elem_cnt,\n              received_embeddings_grad);\n}\n\n// Quantize Version.\ntemplate<typename T, typename IDX>\nvoid ShuffleEmbeddingsGrad(cudaStream_t cuda_stream, ncclComm_t comm, int64_t parallel_id,\n                           int64_t parallel_num, int64_t num_ids, int64_t embedding_size,\n                           DataType data_type, IDX* host_num_unique_matrix,\n                           int8_t* unique_partition_embedding_grad,\n                           int8_t* received_embeddings_grad, T* cur_rank_quantize_factor,\n                           T* received_cur_rank_quantize_factor) {\n  std::vector<int64_t> send_offsets;\n  std::vector<int64_t> send_elem_cnt;\n  std::vector<int64_t> recv_offsets;\n  std::vector<int64_t> recv_elem_cnt;\n  // Shuffle Embedding Grad.\n  MakeShuffleParams(host_num_unique_matrix, num_ids, embedding_size, parallel_id, parallel_num,\n                    &send_offsets, &send_elem_cnt, &recv_offsets, &recv_elem_cnt);\n  ShuffleData(cuda_stream, comm, DataType::kInt8, send_offsets, send_elem_cnt,\n              unique_partition_embedding_grad, recv_offsets, recv_elem_cnt,\n              received_embeddings_grad);\n  // Shuffle Quantize factor.\n  MakeShuffleParams(host_num_unique_matrix, num_ids, /*embedding_size=*/1, parallel_id,\n                    parallel_num, &send_offsets, &send_elem_cnt, &recv_offsets, &recv_elem_cnt);\n  ShuffleData(cuda_stream, comm, data_type, send_offsets, send_elem_cnt, cur_rank_quantize_factor,\n              recv_offsets, recv_elem_cnt, received_cur_rank_quantize_factor);\n}\n\ninline int64_t GetPaddedEmbeddingSize(DataType data_type, int64_t embedding_size) {\n  if (data_type == DataType::kFloat16 && embedding_size % 2 != 0) {\n    return embedding_size + 1;\n  } else {\n    return embedding_size;\n  }\n}\n\ntemplate<typename T, typename IDX>\nvoid UniquePartitionEmbeddingGrad(ep::Stream* stream, int64_t unique_partitioned_num_ids,\n                                  int64_t num_ids, int64_t embedding_size,\n                                  int64_t padded_embedding_size, const IDX* host_num_unique_matrix,\n                                  const T* embedding_grad,\n                                  const IDX* inverse_unique_partition_indices,\n                                  T* unique_partition_embedding_grad) {\n  const int64_t valid_value_size = unique_partitioned_num_ids * padded_embedding_size * sizeof(T);\n  OF_CUDA_CHECK(cudaMemsetAsync(unique_partition_embedding_grad, 0, valid_value_size,\n                                stream->As<ep::CudaStream>()->cuda_stream()));\n  UnsortedSegmentSum<T, IDX>(stream, inverse_unique_partition_indices, embedding_grad, num_ids,\n                             unique_partitioned_num_ids, embedding_size, padded_embedding_size,\n                             unique_partition_embedding_grad);\n}\n\ntemplate<typename T, typename IDX>\nvoid UniqueCurRankEmbeddingGrad(ep::Stream* stream, DataType data_type, int64_t cur_rank_num_ids,\n                                int64_t num_unique, int64_t embedding_size,\n                                int64_t padded_embedding_size, bool only_zero_valid_grad,\n                                int64_t 
cur_rank_unique_embedding_grad_elem_cnt,\n                                const T* cur_rank_embedding_grad,\n                                const IDX* cur_rank_inverse_indices,\n                                T* cur_rank_unique_embedding_grad, T* tmp_buffer) {\n  cudaStream_t cuda_stream = stream->As<ep::CudaStream>()->cuda_stream();\n  // Memset cur_rank_unique_embedding_grad; if only_zero_valid_grad, only memset the valid data.\n  if (only_zero_valid_grad) {\n    OF_CUDA_CHECK(cudaMemsetAsync(cur_rank_unique_embedding_grad, 0,\n                                  num_unique * embedding_size * sizeof(T), cuda_stream));\n  } else {\n    OF_CUDA_CHECK(cudaMemsetAsync(cur_rank_unique_embedding_grad, 0,\n                                  cur_rank_unique_embedding_grad_elem_cnt * sizeof(T),\n                                  cuda_stream));\n  }\n  T* unsorted_segment_sum_out;\n  if (embedding_size != padded_embedding_size) {\n    unsorted_segment_sum_out = tmp_buffer;\n    size_t buffer_size = GetCudaAlignedSize(num_unique * padded_embedding_size * sizeof(T));\n    OF_CUDA_CHECK(cudaMemsetAsync(unsorted_segment_sum_out, 0, buffer_size, cuda_stream));\n  } else {\n    // cur_rank_unique_embedding_grad has already been memset, no need to memset it again.\n    unsorted_segment_sum_out = cur_rank_unique_embedding_grad;\n  }\n  UnsortedSegmentSum<T, IDX>(stream, cur_rank_inverse_indices, cur_rank_embedding_grad,\n                             cur_rank_num_ids, num_unique, padded_embedding_size,\n                             padded_embedding_size, unsorted_segment_sum_out);\n  if (embedding_size != padded_embedding_size) {\n    std::unique_ptr<ep::primitive::CopyNd> primitive =\n        ep::primitive::NewPrimitive<ep::primitive::CopyNdFactory>(DeviceType::kCUDA, 2);\n    DimVector dst_shape = {num_unique, embedding_size};\n    DimVector dst_pos_vec = {0, 0};\n    DimVector src_shape = {num_unique, padded_embedding_size};\n    DimVector src_pos_vec = {0, 0};\n    DimVector extent_vec = {num_unique, embedding_size};\n    primitive->Launch(stream, data_type, 2, cur_rank_unique_embedding_grad, dst_shape.data(),\n                      dst_pos_vec.data(), unsorted_segment_sum_out, src_shape.data(),\n                      src_pos_vec.data(), extent_vec.data());\n  }\n}\n\ntemplate<typename K, typename U, typename IDX>\nstruct IdShuffleDataPtrs {\n  const K* ids_ptr;\n  const U* table_ids_ptr;\n  IDX* num_partitioned_unique;\n  K* partitioned_unique_ids;\n  U* partitioned_unique_table_ids;\n  IDX* num_unique_matrix_ptr;\n  IDX* inverse_unique_partition_indices_ptr;\n  void* workspace_ptr;\n  size_t workspace_size;\n  K* received_ids;\n  U* received_table_ids;\n  IDX* cur_rank_num_unique_ptr;\n  K* cur_rank_unique_ids_ptr;\n  U* cur_rank_unique_table_ids_ptr;\n  IDX* cur_rank_inverse_indices_ptr;\n};\n\ntemplate<typename K, typename U, typename IDX>\nvoid IdShuffle(ep::Stream* stream, ncclComm_t comm, const IdShuffleDataPtrs<K, U, IDX>& data_ptrs,\n               int64_t num_ids, int64_t parallel_id, int64_t parallel_num,\n               DataType num_unique_matrix_dtype, DataType ids_dtype, DataType table_ids_dtype,\n               bool need_process_table_ids, const bool has_padding_idx, const int64_t padding_idx,\n               IDX* host_num_unique_matrix, IDX* host_num_keys) {\n  cudaStream_t cuda_stream = stream->As<ep::CudaStream>()->cuda_stream();\n  size_t hash_table_capacity = parallel_num * num_ids;\n  UniqueAndPartition<K, U, IDX, embedding::ShardingHash>(\n      cuda_stream, num_ids, hash_table_capacity, 
parallel_num, data_ptrs.ids_ptr,\n      data_ptrs.table_ids_ptr, data_ptrs.num_partitioned_unique, data_ptrs.partitioned_unique_ids,\n      data_ptrs.partitioned_unique_table_ids, data_ptrs.inverse_unique_partition_indices_ptr,\n      data_ptrs.workspace_ptr, data_ptrs.workspace_size, need_process_table_ids, has_padding_idx,\n      padding_idx);\n\n  OF_NCCL_CHECK(ncclAllGather(data_ptrs.num_partitioned_unique, data_ptrs.num_unique_matrix_ptr,\n                              parallel_num, GetNcclDataType(num_unique_matrix_dtype), comm,\n                              cuda_stream));\n\n  OF_CUDA_CHECK(cudaMemcpyAsync(host_num_unique_matrix, data_ptrs.num_unique_matrix_ptr,\n                                parallel_num * parallel_num * sizeof(IDX), cudaMemcpyDefault,\n                                cuda_stream));\n  CHECK_JUST(stream->Sync());\n  if (parallel_num > 1) {\n    // num_partitioned_unique is reused as the indices_offset buffer here, so this must run\n    // after ncclAllGather.\n    ComputeOffset<<<1, 1, 0, cuda_stream>>>(parallel_num, data_ptrs.num_partitioned_unique);\n    ContiguousInverseUniquePartitionIndices<<<BlocksNum4ThreadsNum(num_ids),\n                                              kCudaThreadsNumPerBlock, 0, cuda_stream>>>(\n        num_ids, data_ptrs.num_partitioned_unique, data_ptrs.inverse_unique_partition_indices_ptr);\n  }\n  int64_t received_elem_cnt = 0;\n  ShuffleIdsAndTableIds(cuda_stream, comm, parallel_id, parallel_num, num_ids, ids_dtype,\n                        table_ids_dtype, host_num_unique_matrix, data_ptrs.partitioned_unique_ids,\n                        data_ptrs.partitioned_unique_table_ids, data_ptrs.received_ids,\n                        data_ptrs.received_table_ids, &received_elem_cnt, need_process_table_ids);\n  UniqueAndPartition<K, U, IDX, embedding::LocalUniqueHash>(\n      cuda_stream, received_elem_cnt, hash_table_capacity, 1, data_ptrs.received_ids,\n      data_ptrs.received_table_ids, data_ptrs.cur_rank_num_unique_ptr,\n      data_ptrs.cur_rank_unique_ids_ptr, data_ptrs.cur_rank_unique_table_ids_ptr,\n      data_ptrs.cur_rank_inverse_indices_ptr, data_ptrs.workspace_ptr, data_ptrs.workspace_size,\n      need_process_table_ids, has_padding_idx, padding_idx);\n  if (!need_process_table_ids) {\n    OF_CUDA_CHECK(cudaMemsetAsync(data_ptrs.cur_rank_unique_table_ids_ptr, 0,\n                                  received_elem_cnt * sizeof(U), cuda_stream));\n  }\n  OF_CUDA_CHECK(cudaMemcpyAsync(host_num_keys, data_ptrs.cur_rank_num_unique_ptr, sizeof(IDX),\n                                cudaMemcpyDefault, cuda_stream));\n  CHECK_JUST(stream->Sync());\n}\n\n}  // namespace data_shuffle\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/one_embedding_embedding_gradient_shuffle_p2p_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/embedding/embedding_manager.h\"\n#include \"oneflow/core/control/ctrl_client.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include <cuda.h>\n\n#if CUDA_VERSION >= 11030\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, int pack_size>\nstruct alignas(sizeof(T) * pack_size) Pack {\n  T elem[pack_size];\n};\n\ntemplate<typename T, int32_t pack_size>\n__device__ __inline__ void AtomicAdd(Pack<T, pack_size>* address, Pack<T, pack_size> val) {\n#pragma unroll\n  for (int i = 0; i < pack_size; ++i) {\n    cuda::atomic::Add(reinterpret_cast<T*>(address) + i, static_cast<T>(val.elem[i]));\n  }\n}\n\ntemplate<>\n__device__ __inline__ void AtomicAdd<half, 2>(Pack<half, 2>* address, Pack<half, 2> val) {\n  half2 h2_val;\n  h2_val.x = static_cast<half>(val.elem[0]);\n  h2_val.y = static_cast<half>(val.elem[1]);\n  cuda::atomic::Add(reinterpret_cast<half2*>(address), h2_val);\n}\n\ntemplate<typename T, typename IDX, int pack_size, int N>\nstruct Param {\n  const IDX* cur_rank_inverse_indices;\n  const Pack<T, pack_size>* unique_partitioned_embedding_grads[N];\n  int32_t* is_kernel_start[N];\n  const IDX* num_unique_matrix;\n  Pack<T, pack_size>* cur_rank_unique_embedding_grad_ptr;\n};\n\ntemplate<typename T, typename IDX, int pack_size, int N>\n__global__ void EmbeddingGradientShuffleCudaKernel(int64_t parallel_id, int64_t parallel_num,\n                                                   int64_t embedding_num_pack,\n                                                   Param<T, IDX, pack_size, N> param) {\n#pragma unroll 1\n  for (int i = 0; i < parallel_num; ++i) {\n    int rank_id = (parallel_id + i) % parallel_num;\n    IDX cur_rank_index_offset = 0;\n    for (int k = 0; k < rank_id; ++k) {\n      cur_rank_index_offset += param.num_unique_matrix[k * parallel_num + parallel_id];\n    }\n    IDX in_index_offset = 0;\n    for (int k = 0; k < parallel_id; ++k) {\n      in_index_offset += param.num_unique_matrix[rank_id * parallel_num + k];\n    }\n    const IDX* cur_rank_inverse_indices_ptr =\n        param.cur_rank_inverse_indices + cur_rank_index_offset;\n    const Pack<T, pack_size>* unique_partitioned_embedding_grad_ptr =\n        param.unique_partitioned_embedding_grads[rank_id] + in_index_offset * embedding_num_pack;\n    Pack<T, pack_size>* cur_rank_unique_embedding_grad_ptr =\n        param.cur_rank_unique_embedding_grad_ptr;\n    const int copy_cnt =\n        param.num_unique_matrix[rank_id * parallel_num + parallel_id] * embedding_num_pack;\n    CUDA_1D_KERNEL_LOOP_T(int, j, copy_cnt) {\n      int in_row_id = j / embedding_num_pack;\n      int col_id = j - in_row_id * embedding_num_pack;\n      int out_row_id = 
cur_rank_inverse_indices_ptr[in_row_id];\n      Pack<T, pack_size> grad_val = unique_partitioned_embedding_grad_ptr[j];\n      AtomicAdd(cur_rank_unique_embedding_grad_ptr + out_row_id * embedding_num_pack + col_id,\n                grad_val);\n    }\n  }\n}\n\ntemplate<typename T, typename IDX, int pack_size, int N>\n__global__ void BarrierKernel(int32_t parallel_id, int32_t parallel_num,\n                              Param<T, IDX, pack_size, N> param) {\n  int count = param.is_kernel_start[parallel_id][parallel_id];\n  if (threadIdx.x < parallel_num) {\n    volatile int32_t* start_f = param.is_kernel_start[parallel_id];\n    volatile int32_t* remote_start_f = param.is_kernel_start[threadIdx.x];\n    start_f[threadIdx.x] = count + 1;\n    while (remote_start_f[parallel_id] < count + 1) {}\n  }\n}\n\nstruct IpcMemHandleOffset {\n  cudaIpcMemHandle_t handle;\n  int64_t offset;\n};\n\nvoid GetPtrs(user_op::KernelComputeContext* ctx,\n             std::vector<void*>* unique_partitioned_embedding_grad_ptr,\n             std::vector<void*>* is_kernel_start_ptr) {\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n  const int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n  unique_partitioned_embedding_grad_ptr->at(parallel_id) =\n      const_cast<void*>(ctx->Tensor4ArgNameAndIndex(\"embedding_grad\", 0)->dptr());\n  std::string name = ctx->op_name();\n  {\n    std::vector<IpcMemHandleOffset> push_handle_offset;\n    push_handle_offset.resize(2);\n    OF_CUDA_CHECK(cudaIpcGetMemHandle(&push_handle_offset.at(0).handle,\n                                      unique_partitioned_embedding_grad_ptr->at(parallel_id)));\n    OF_CUDA_CHECK(cudaIpcGetMemHandle(&push_handle_offset.at(1).handle,\n                                      is_kernel_start_ptr->at(parallel_id)));\n    cudaError_t (*func)(void*, CUpointer_attribute, CUdeviceptr);\n    OF_CUDA_CHECK(\n        cudaGetDriverEntryPoint(\"cuPointerGetAttribute\", (void**)(&func), cudaEnableDefault));\n    void* embedding_grad_base;\n    OF_CUDA_CHECK(func(&embedding_grad_base, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,\n                       (CUdeviceptr)(unique_partitioned_embedding_grad_ptr->at(parallel_id))));\n    push_handle_offset.at(0).offset =\n        reinterpret_cast<char*>(unique_partitioned_embedding_grad_ptr->at(parallel_id))\n        - reinterpret_cast<char*>(embedding_grad_base);\n    push_handle_offset.at(1).offset = 0;\n    Singleton<CtrlClient>::Get()->PushKV(\n        name + std::to_string(parallel_id),\n        std::string(reinterpret_cast<const char*>(push_handle_offset.data()),\n                    2 * sizeof(IpcMemHandleOffset)));\n  }\n  for (int64_t i = 0; i < parallel_num; ++i) {\n    std::string key = name + std::to_string(i);\n    if (parallel_id != i) {\n      std::vector<IpcMemHandleOffset> handle_offset;\n      handle_offset.resize(2);\n      Singleton<CtrlClient>::Get()->PullKV(key, [i, &handle_offset](const std::string& val) {\n        memcpy(handle_offset.data(), val.data(), 2 * sizeof(IpcMemHandleOffset));\n      });\n      OF_CUDA_CHECK(cudaIpcOpenMemHandle(&unique_partitioned_embedding_grad_ptr->at(i),\n                                         handle_offset.at(0).handle,\n                                         cudaIpcMemLazyEnablePeerAccess));\n      unique_partitioned_embedding_grad_ptr->at(i) =\n          reinterpret_cast<char*>(unique_partitioned_embedding_grad_ptr->at(i))\n          + handle_offset.at(0).offset;\n      OF_CUDA_CHECK(cudaIpcOpenMemHandle(&is_kernel_start_ptr->at(i), 
handle_offset.at(1).handle,\n                                         cudaIpcMemLazyEnablePeerAccess));\n      is_kernel_start_ptr->at(i) =\n          reinterpret_cast<char*>(is_kernel_start_ptr->at(i)) + handle_offset.at(1).offset;\n    }\n  }\n}\n\ntemplate<typename IDX>\nclass DataShuffleKernelState final : public user_op::OpKernelState {\n public:\n  explicit DataShuffleKernelState(user_op::KernelInitContext* ctx)\n      : device_index_(-1),\n        parallel_desc_(ctx->parallel_desc()),\n        parallel_id_(ctx->parallel_ctx().parallel_id()) {\n    OF_CUDA_CHECK(cudaGetDevice(&device_index_));\n    int64_t parallel_num = parallel_desc_.parallel_num();\n    unique_partitioned_embedding_grad_ptr_.resize(parallel_num);\n    is_kernel_start_ptr_.resize(parallel_num);\n    size_t is_kernel_start_size = GetCudaAlignedSize(parallel_num * sizeof(int32_t));\n    OF_CUDA_CHECK(cudaMalloc(&is_kernel_start_ptr_.at(parallel_id_), is_kernel_start_size));\n    OF_CUDA_CHECK(cudaMemset(is_kernel_start_ptr_.at(parallel_id_), 0, is_kernel_start_size));\n  }\n\n  ~DataShuffleKernelState() {\n    CudaCurrentDeviceGuard guard(device_index_);\n    OF_CUDA_CHECK(cudaFree(is_kernel_start_ptr_.at(parallel_id_)));\n  }\n\n  std::vector<void*>* UniquePartitionedEmbeddingGrads() {\n    return &unique_partitioned_embedding_grad_ptr_;\n  }\n\n  std::vector<void*>* IsKernelStart() { return &is_kernel_start_ptr_; }\n\n private:\n  int device_index_;\n  ParallelDesc parallel_desc_;\n  int64_t parallel_id_;\n  std::vector<void*> unique_partitioned_embedding_grad_ptr_;\n  std::vector<void*> is_kernel_start_ptr_;\n};\n\nconstexpr int pack_size = 2;\n\ntemplate<typename T, size_t pack>\n__global__ void MemsetCurRankEmbeddingGrad(int64_t parallel_id, int64_t parallel_num,\n                                           int64_t vector_size, const uint32_t* num_unique_matrix,\n                                           T* dst) {\n  size_t count = 0;\n  for (int i = 0; i < parallel_num; ++i) {\n    count += num_unique_matrix[i * parallel_num + parallel_id] * vector_size;\n  }\n  const size_t pack_count = count / pack;\n  Pack<T, pack> pack_value;\n  for (int i = 0; i < pack; ++i) { pack_value.elem[i] = static_cast<T>(0); }\n  auto* pack_dst = reinterpret_cast<Pack<T, pack>*>(dst);\n  CUDA_1D_KERNEL_LOOP_T(size_t, i, pack_count) { pack_dst[i] = pack_value; }\n  T* tail_dst = dst + pack_count * pack;\n  const size_t tail_count = count - pack_count * pack;\n  CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = static_cast<T>(0); }\n}\n\ntemplate<typename T, size_t pack>\ntypename std::enable_if<(pack != 0), void>::type LaunchPackMemsetCurRankEmbeddingGrad(\n    cudaStream_t stream, const uint32_t* num_unique_matrix, T* ptr, int sm_count,\n    int64_t vector_size, int64_t parallel_id, int64_t parallel_num) {\n  MemsetCurRankEmbeddingGrad<T, pack><<<2 * sm_count, 1024, 0, stream>>>(\n      parallel_id, parallel_num, vector_size, num_unique_matrix, ptr);\n}\n\ntemplate<typename T, size_t pack>\ntypename std::enable_if<(pack == 0), void>::type LaunchPackMemsetCurRankEmbeddingGrad(\n    cudaStream_t stream, const uint32_t* num_unique_matrix, T* ptr, int sm_count,\n    int64_t vector_size, int64_t parallel_id, int64_t parallel_num) {\n  LOG(FATAL) << \"wrong alignment\";\n}\n\ntemplate<typename T>\nvoid LaunchMemsetCurRankEmbeddingGrad(cudaStream_t stream, int sm_count, int64_t vector_size,\n                                      int64_t parallel_id, int64_t parallel_num,\n                                      const uint32_t* 
num_unique_matrix, T* ptr) {\n  auto uintptr = reinterpret_cast<std::uintptr_t>(ptr);\n  if (uintptr % 16 == 0) {\n    LaunchPackMemsetCurRankEmbeddingGrad<T, 16 / sizeof(T)>(\n        stream, num_unique_matrix, ptr, sm_count, vector_size, parallel_id, parallel_num);\n  } else if (uintptr % 8 == 0) {\n    LaunchPackMemsetCurRankEmbeddingGrad<T, 8 / sizeof(T)>(stream, num_unique_matrix, ptr, sm_count,\n                                                           vector_size, parallel_id, parallel_num);\n  } else if (uintptr % 4 == 0) {\n    LaunchPackMemsetCurRankEmbeddingGrad<T, 4 / sizeof(T)>(stream, num_unique_matrix, ptr, sm_count,\n                                                           vector_size, parallel_id, parallel_num);\n  } else if (uintptr % 2 == 0) {\n    LaunchPackMemsetCurRankEmbeddingGrad<T, 2 / sizeof(T)>(stream, num_unique_matrix, ptr, sm_count,\n                                                           vector_size, parallel_id, parallel_num);\n  } else {\n    LaunchPackMemsetCurRankEmbeddingGrad<T, 1 / sizeof(T)>(stream, num_unique_matrix, ptr, sm_count,\n                                                           vector_size, parallel_id, parallel_num);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename IDX>\nclass EmbeddingGraidientShuffleP2PKernel final : public user_op::OpKernel,\n                                                 public user_op::CudaGraphSupport {\n public:\n  EmbeddingGraidientShuffleP2PKernel() : current_iter_(0) {}\n  ~EmbeddingGraidientShuffleP2PKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<DataShuffleKernelState<IDX>>(ctx);\n  }\n\n  bool IsReadyForCapture(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n                         const user_op::OpKernelCache* cache) const override {\n    if (current_iter_ == 0) {\n      return false;\n    } else {\n      return true;\n    }\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    CHECK(!embedding::UseDynamicMemoryAllocation());\n    CHECK(ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_FUSE_EMBEDDING_INTERACTION\",\n                              false));  // only supports skipping the last gather.\n    CHECK(ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_ADD_ID_SHUFFLE_COPY_OUT\",\n                              true));  // without an identity op, cur_rank_inverse_indices will\n                                       // change every iteration because of register num=2.\n    auto* kernel_state = dynamic_cast<DataShuffleKernelState<IDX>*>(state);\n    CHECK(kernel_state != nullptr);\n    const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex(\"embedding_grad\", 0);\n    const user_op::Tensor* num_unique_matrix = ctx->Tensor4ArgNameAndIndex(\"num_unique_matrix\", 0);\n    const user_op::Tensor* cur_rank_inverse_indices =\n        ctx->Tensor4ArgNameAndIndex(\"cur_rank_inverse_indices\", 0);\n    user_op::Tensor* cur_rank_unique_embedding_grad =\n        ctx->Tensor4ArgNameAndIndex(\"cur_rank_unique_embedding_grad\", 0);\n\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    const bool only_zero_valid_grad = ctx->Attr<bool>(\"only_zero_valid_grad\");\n    const int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n    const int64_t parallel_id = 
ctx->parallel_ctx().parallel_id();\n    const int sm_count =\n        ctx->stream()->As<ep::CudaStream>()->device_properties().multiProcessorCount;\n    const bool skip_first_scatter = ctx->Attr<bool>(\"skip_first_scatter\");\n    CHECK(skip_first_scatter);\n    cudaStream_t cuda_stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();\n    if (current_iter_ == 0) {\n      GetPtrs(ctx, kernel_state->UniquePartitionedEmbeddingGrads(), kernel_state->IsKernelStart());\n    }\n    CHECK_EQ(kernel_state->UniquePartitionedEmbeddingGrads()->at(parallel_id),\n             embedding_grad->dptr());\n    Param<T, IDX, pack_size, 8> param;\n    CHECK_EQ(embedding_size % pack_size, 0);\n    CHECK_LE(parallel_num, 8);\n    param.cur_rank_unique_embedding_grad_ptr =\n        reinterpret_cast<Pack<T, pack_size>*>(cur_rank_unique_embedding_grad->mut_dptr<T>());\n    for (int i = 0; i < parallel_num; ++i) {\n      param.unique_partitioned_embedding_grads[i] = reinterpret_cast<Pack<T, pack_size>*>(\n          kernel_state->UniquePartitionedEmbeddingGrads()->at(i));\n      param.is_kernel_start[i] = reinterpret_cast<int32_t*>(kernel_state->IsKernelStart()->at(i));\n    }\n    param.cur_rank_inverse_indices = reinterpret_cast<const IDX*>(cur_rank_inverse_indices->dptr());\n    param.num_unique_matrix = reinterpret_cast<const uint32_t*>(num_unique_matrix->dptr());\n    int64_t embedding_num_pack = embedding_size / pack_size;\n    if (only_zero_valid_grad) {\n      LaunchMemsetCurRankEmbeddingGrad(cuda_stream, sm_count, embedding_size, parallel_id,\n                                       parallel_num,\n                                       reinterpret_cast<const uint32_t*>(num_unique_matrix->dptr()),\n                                       cur_rank_unique_embedding_grad->mut_dptr<T>());\n    } else {\n      OF_CUDA_CHECK(cudaMemsetAsync(\n          cur_rank_unique_embedding_grad->mut_dptr(), 0,\n          cur_rank_unique_embedding_grad->shape_view().elem_cnt() * sizeof(T), cuda_stream));\n    }\n    BarrierKernel<<<1, parallel_num, 0, cuda_stream>>>(parallel_id, parallel_num, param);\n    const int num_blocks =\n        2 * ctx->stream()->As<ep::CudaStream>()->device_properties().multiProcessorCount;\n    EmbeddingGradientShuffleCudaKernel<<<num_blocks, 1024, 0, cuda_stream>>>(\n        parallel_id, parallel_num, embedding_num_pack, param);\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\nREGISTER_USER_KERNEL(\"embedding_gradient_shuffle\")\n    .SetCreateFn<EmbeddingGradientShuffleP2PKernel<half, uint32_t>>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)\n                     && (user_op::HobDataType(\"embedding_grad\", 0) == DataType::kFloat16)\n                     && (user_op::HobDataType(\"num_unique_matrix\", 0) == DataType::kUInt32)\n                     && (user_op::HobAttr<bool>(\"skip_first_scatter\") == true)\n                     && (embedding::UseEmbeddingGradientShuffleP2PKernel(DataType::kFloat16,\n                                                                         DataType::kUInt32)));\n\n}  // namespace oneflow\n\n#endif  // CUDA_VERSION >= 11030\n"
  },
  {
    "path": "oneflow/user/kernels/one_embedding_embedding_shuffle_p2p_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/embedding/embedding_manager.h\"\n#include \"oneflow/core/control/ctrl_client.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include <cuda.h>\n\n#if CUDA_VERSION >= 11030\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, int pack_size>\nstruct alignas(sizeof(T) * pack_size) Pack {\n  T elem[pack_size];\n};\n\ntemplate<typename T, typename IDX, int pack_size, int N>\nstruct Param {\n  IDX* inverse_indices[N];\n  Pack<T, pack_size>* unique_embeddings[N];\n  int32_t* is_kernel_start[N];\n  const IDX* num_unique_matrix;\n  Pack<T, pack_size>* embedding_ptr;\n};\n\ntemplate<typename T, typename IDX, int pack_size, int N>\n__global__ void EmbeddingShuffleCudaKernel(int parallel_id, int parallel_num,\n                                           int embedding_num_pack,\n                                           Param<T, IDX, pack_size, N> param) {\n#pragma unroll 1\n  for (int i = 0; i < parallel_num; ++i) {\n    int rank_id = (parallel_id + i) % parallel_num;\n    IDX out_index_offset = 0;\n    for (int k = 0; k < rank_id; ++k) {\n      out_index_offset += param.num_unique_matrix[parallel_id * parallel_num + k];\n    }\n    IDX in_index_offset = 0;\n    for (int k = 0; k < parallel_id; ++k) {\n      in_index_offset += param.num_unique_matrix[k * parallel_num + rank_id];\n    }\n    const IDX* inverse_indices_ptr = param.inverse_indices[rank_id] + in_index_offset;\n    const Pack<T, pack_size>* unique_embeddings_ptr = param.unique_embeddings[rank_id];\n    Pack<T, pack_size>* embedding_ptr = param.embedding_ptr + out_index_offset * embedding_num_pack;\n    const int copy_cnt =\n        param.num_unique_matrix[parallel_id * parallel_num + rank_id] * embedding_num_pack;\n    CUDA_1D_KERNEL_LOOP_T(int, j, copy_cnt) {\n      int out_row_id = j / embedding_num_pack;\n      int in_row_id = inverse_indices_ptr[out_row_id];\n      int col_id = j - out_row_id * embedding_num_pack;\n      embedding_ptr[j] = unique_embeddings_ptr[in_row_id * embedding_num_pack + col_id];\n    }\n  }\n}\n\ntemplate<typename T, typename IDX, int pack_size, int N>\n__global__ void EmbeddingShuffleCopyKernel(int parallel_id, int parallel_num,\n                                           int embedding_num_pack,\n                                           Param<T, IDX, pack_size, N> param) {\n#pragma unroll 1\n  for (int i = 0; i < parallel_num; ++i) {\n    int rank_id = (parallel_id + i) % parallel_num;\n    IDX out_index_offset = 0;\n    for (int k = 0; k < rank_id; ++k) {\n      out_index_offset += param.num_unique_matrix[parallel_id * parallel_num + k];\n    }\n    IDX in_index_offset = 0;\n    for (int k = 0; k < parallel_id; ++k) {\n      in_index_offset += param.num_unique_matrix[k 
* parallel_num + rank_id];\n    }\n    const Pack<T, pack_size>* unique_embeddings_ptr =\n        param.unique_embeddings[rank_id] + in_index_offset * embedding_num_pack;\n    Pack<T, pack_size>* embedding_ptr = param.embedding_ptr + out_index_offset * embedding_num_pack;\n    const int copy_cnt =\n        param.num_unique_matrix[parallel_id * parallel_num + rank_id] * embedding_num_pack;\n    CUDA_1D_KERNEL_LOOP_T(int, j, copy_cnt) { embedding_ptr[j] = unique_embeddings_ptr[j]; }\n  }\n}\n\ntemplate<typename T, typename IDX, int pack_size>\n__global__ void GatherKernel(int parallel_id, int parallel_num, int embedding_num_pack,\n                             const IDX* num_unique_matrix, const IDX* inverse_indices,\n                             const Pack<T, pack_size>* unique_embeddings,\n                             Pack<T, pack_size>* gather_out_unique_embeddings) {\n  int cur_rank_num_ids = 0;\n  for (int i = 0; i < parallel_num; ++i) {\n    cur_rank_num_ids += num_unique_matrix[i * parallel_num + parallel_id];\n  }\n  int out_cnt = cur_rank_num_ids * embedding_num_pack;\n  CUDA_1D_KERNEL_LOOP_T(int, i, out_cnt) {\n    int out_row_id = i / embedding_num_pack;\n    int in_row_id = inverse_indices[out_row_id];\n    int col_id = i - out_row_id * embedding_num_pack;\n    gather_out_unique_embeddings[i] = unique_embeddings[in_row_id * embedding_num_pack + col_id];\n  }\n}\n\ntemplate<typename T, typename IDX, int pack_size, int N>\n__global__ void BarrierKernel(int32_t parallel_id, int32_t parallel_num,\n                              Param<T, IDX, pack_size, N> param) {\n  int count = param.is_kernel_start[parallel_id][parallel_id];\n  if (threadIdx.x < parallel_num) {\n    volatile int32_t* start_f = param.is_kernel_start[parallel_id];\n    volatile int32_t* remote_start_f = param.is_kernel_start[threadIdx.x];\n    start_f[threadIdx.x] = count + 1;\n    while (remote_start_f[parallel_id] < count + 1) {}\n  }\n}\n\nstruct IpcMemHandleOffset {\n  cudaIpcMemHandle_t handle;\n  int64_t offset;\n};\n\nbool DisableFuseGatherCopy() {\n  return ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_P2P_DISABLE_FUSE_GATHER_COPY\", false);\n}\n\nvoid GetPtrs(user_op::KernelComputeContext* ctx, std::vector<void*>* unique_embeddings_ptr,\n             std::vector<void*>* inverse_indices_ptr, std::vector<void*>* is_kernel_start_ptr) {\n  const int64_t num_ids =\n      ctx->TensorDesc4ArgNameAndIndex(\"inverse_unique_partition_indices\", 0)->shape().elem_cnt();\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n  const int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n  inverse_indices_ptr->at(parallel_id) =\n      const_cast<void*>(ctx->Tensor4ArgNameAndIndex(\"cur_rank_inverse_indices\", 0)->dptr());\n  if (DisableFuseGatherCopy()) {\n    unique_embeddings_ptr->at(parallel_id) =\n        ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0)->mut_dptr();\n  } else {\n    unique_embeddings_ptr->at(parallel_id) =\n        const_cast<void*>(ctx->Tensor4ArgNameAndIndex(\"cur_rank_embeddings\", 0)->dptr());\n  }\n\n  std::string name = ctx->op_name();\n  {\n    std::vector<IpcMemHandleOffset> push_handle_offset;\n    push_handle_offset.resize(3);\n    OF_CUDA_CHECK(cudaIpcGetMemHandle(&push_handle_offset.at(0).handle,\n                                      unique_embeddings_ptr->at(parallel_id)));\n    OF_CUDA_CHECK(cudaIpcGetMemHandle(&push_handle_offset.at(1).handle,\n                                      inverse_indices_ptr->at(parallel_id)));\n    
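// The is_kernel_start buffer is a dedicated cudaMalloc allocation, so its offset is 0; the\n    // tensor pointers above may point into the middle of a pooled allocation, so their offsets\n    // from the allocation base are computed below via cuPointerGetAttribute and shipped alongside.\n    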
OF_CUDA_CHECK(cudaIpcGetMemHandle(&push_handle_offset.at(2).handle,\n                                      is_kernel_start_ptr->at(parallel_id)));\n\n    cudaError_t (*func)(void*, CUpointer_attribute, CUdeviceptr);\n    OF_CUDA_CHECK(\n        cudaGetDriverEntryPoint(\"cuPointerGetAttribute\", (void**)(&func), cudaEnableDefault));\n    void* unique_embeddings_base;\n    OF_CUDA_CHECK(func(&unique_embeddings_base, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,\n                       (CUdeviceptr)(unique_embeddings_ptr->at(parallel_id))));\n    push_handle_offset.at(0).offset =\n        reinterpret_cast<char*>(unique_embeddings_ptr->at(parallel_id))\n        - reinterpret_cast<char*>(unique_embeddings_base);\n    void* inverse_indices_base;\n    OF_CUDA_CHECK(func(&inverse_indices_base, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,\n                       (CUdeviceptr)(inverse_indices_ptr->at(parallel_id))));\n    push_handle_offset.at(1).offset = reinterpret_cast<char*>(inverse_indices_ptr->at(parallel_id))\n                                      - reinterpret_cast<char*>(inverse_indices_base);\n    push_handle_offset.at(2).offset = 0;\n    Singleton<CtrlClient>::Get()->PushKV(\n        name + std::to_string(parallel_id),\n        std::string(reinterpret_cast<const char*>(push_handle_offset.data()),\n                    3 * sizeof(IpcMemHandleOffset)));\n  }\n  for (int64_t i = 0; i < parallel_num; ++i) {\n    std::string key = name + std::to_string(i);\n    if (parallel_id != i) {\n      std::vector<IpcMemHandleOffset> handle_offset;\n      handle_offset.resize(3);\n      Singleton<CtrlClient>::Get()->PullKV(key, [i, &handle_offset](const std::string& val) {\n        memcpy(handle_offset.data(), val.data(), 3 * sizeof(IpcMemHandleOffset));\n      });\n      OF_CUDA_CHECK(cudaIpcOpenMemHandle(&unique_embeddings_ptr->at(i), handle_offset.at(0).handle,\n                                         cudaIpcMemLazyEnablePeerAccess));\n      unique_embeddings_ptr->at(i) =\n          reinterpret_cast<char*>(unique_embeddings_ptr->at(i)) + handle_offset.at(0).offset;\n\n      OF_CUDA_CHECK(cudaIpcOpenMemHandle(&inverse_indices_ptr->at(i), handle_offset.at(1).handle,\n                                         cudaIpcMemLazyEnablePeerAccess));\n      inverse_indices_ptr->at(i) =\n          reinterpret_cast<char*>(inverse_indices_ptr->at(i)) + handle_offset.at(1).offset;\n\n      OF_CUDA_CHECK(cudaIpcOpenMemHandle(&is_kernel_start_ptr->at(i), handle_offset.at(2).handle,\n                                         cudaIpcMemLazyEnablePeerAccess));\n      is_kernel_start_ptr->at(i) =\n          reinterpret_cast<char*>(is_kernel_start_ptr->at(i)) + handle_offset.at(2).offset;\n    }\n  }\n}\n\ntemplate<typename IDX>\nclass DataShuffleKernelState final : public user_op::OpKernelState {\n public:\n  explicit DataShuffleKernelState(user_op::KernelInitContext* ctx)\n      : device_index_(-1),\n        parallel_desc_(ctx->parallel_desc()),\n        parallel_id_(ctx->parallel_ctx().parallel_id()) {\n    OF_CUDA_CHECK(cudaGetDevice(&device_index_));\n    int64_t parallel_num = parallel_desc_.parallel_num();\n    unique_embeddings_ptr_.resize(parallel_num);\n    inverse_indices_ptr_.resize(parallel_num);\n    is_kernel_start_ptr_.resize(parallel_num);\n    size_t is_kernel_start_size = GetCudaAlignedSize(parallel_num * sizeof(int32_t));\n    OF_CUDA_CHECK(cudaMalloc(&is_kernel_start_ptr_.at(parallel_id_), is_kernel_start_size));\n    OF_CUDA_CHECK(cudaMemset(is_kernel_start_ptr_.at(parallel_id_), 0, is_kernel_start_size));\n  }\n\n  
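// Only the locally allocated is_kernel_start buffer is freed here; pointers opened from\n  // remote ranks via cudaIpcOpenMemHandle remain owned by their producing processes.\n  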
~DataShuffleKernelState() {\n    CudaCurrentDeviceGuard guard(device_index_);\n    OF_CUDA_CHECK(cudaFree(is_kernel_start_ptr_.at(parallel_id_)));\n  }\n\n  std::vector<void*>* UniqueEmbeddings() { return &unique_embeddings_ptr_; }\n\n  std::vector<void*>* InverseIndices() { return &inverse_indices_ptr_; }\n\n  std::vector<void*>* IsKernelStart() { return &is_kernel_start_ptr_; }\n\n private:\n  int device_index_;\n  ParallelDesc parallel_desc_;\n  int64_t parallel_id_;\n  std::vector<void*> unique_embeddings_ptr_;\n  std::vector<void*> inverse_indices_ptr_;\n  std::vector<void*> is_kernel_start_ptr_;\n};\n\ntemplate<typename T, typename IDX, int pack_size>\nvoid LaunchKernel(user_op::KernelComputeContext* ctx, DataShuffleKernelState<IDX>* kernel_state) {\n  const int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n  const user_op::Tensor* num_unique_matrix = ctx->Tensor4ArgNameAndIndex(\"num_unique_matrix\", 0);\n  user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex(\"embeddings\", 0);\n  const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n  DataType data_type = embeddings->data_type();\n  Param<T, IDX, pack_size, 8> param;\n  CHECK_LE(parallel_num, 8);\n  param.embedding_ptr = reinterpret_cast<Pack<T, pack_size>*>(embeddings->mut_dptr<T>());\n  for (int i = 0; i < parallel_num; ++i) {\n    param.inverse_indices[i] = reinterpret_cast<IDX*>(kernel_state->InverseIndices()->at(i));\n    param.unique_embeddings[i] =\n        reinterpret_cast<Pack<T, pack_size>*>(kernel_state->UniqueEmbeddings()->at(i));\n    param.is_kernel_start[i] = reinterpret_cast<int32_t*>(kernel_state->IsKernelStart()->at(i));\n  }\n  param.num_unique_matrix = reinterpret_cast<const uint32_t*>(num_unique_matrix->dptr());\n  int64_t embedding_num_pack = embedding_size / pack_size;\n  cudaStream_t cuda_stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();\n  BarrierKernel<<<1, parallel_num, 0, cuda_stream>>>(parallel_id, parallel_num, param);\n  const int num_blocks =\n      2 * ctx->stream()->As<ep::CudaStream>()->device_properties().multiProcessorCount;\n\n  if (DisableFuseGatherCopy()) {\n    CHECK_EQ(kernel_state->UniqueEmbeddings()->at(parallel_id),\n             ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0)->dptr())\n        << parallel_id;\n    GatherKernel<<<num_blocks, 1024, 0, cuda_stream>>>(\n        parallel_id, parallel_num, embedding_num_pack, param.num_unique_matrix,\n        param.inverse_indices[parallel_id],\n        reinterpret_cast<const Pack<T, pack_size>*>(\n            ctx->Tensor4ArgNameAndIndex(\"cur_rank_embeddings\", 0)->dptr()),\n        param.unique_embeddings[parallel_id]);\n    EmbeddingShuffleCopyKernel<<<num_blocks, 1024, 0, cuda_stream>>>(parallel_id, parallel_num,\n                                                                     embedding_num_pack, param);\n  } else {\n    CHECK_EQ(kernel_state->UniqueEmbeddings()->at(parallel_id),\n             ctx->Tensor4ArgNameAndIndex(\"cur_rank_embeddings\", 0)->dptr())\n        << parallel_id;\n    EmbeddingShuffleCudaKernel<<<num_blocks, 1024, 0, cuda_stream>>>(parallel_id, parallel_num,\n                                                                     embedding_num_pack, param);\n  }\n  if (!ctx->Attr<bool>(\"is_train\")) {\n    BarrierKernel<<<1, parallel_num, 0, cuda_stream>>>(\n        parallel_id, parallel_num,\n        param);  // if in eval, should add last barrier.\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename 
IDX>\nclass EmbeddingShuffleP2PKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  EmbeddingShuffleP2PKernel() : current_iter_(0) {}\n  ~EmbeddingShuffleP2PKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<DataShuffleKernelState<IDX>>(ctx);\n  }\n\n  bool IsReadyForCapture(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n                         const user_op::OpKernelCache* cache) const override {\n    // The first iteration exchanges IPC pointers on the host, which cannot be graph-captured.\n    return current_iter_ != 0;\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    CHECK(!embedding::UseDynamicMemoryAllocation());\n    CHECK(ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_FUSE_EMBEDDING_INTERACTION\",\n                              false));  // only supports the skip-last-gather case.\n    CHECK(ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_ADD_ID_SHUFFLE_COPY_OUT\",\n                              true));  // without an identity op, cur_rank_inverse_indices\n                                       // changes every iteration because register num=2.\n    auto* kernel_state = dynamic_cast<DataShuffleKernelState<IDX>*>(state);\n    CHECK(kernel_state != nullptr);\n    const user_op::Tensor* cur_rank_inverse_indices =\n        ctx->Tensor4ArgNameAndIndex(\"cur_rank_inverse_indices\", 0);\n    const user_op::Tensor* inverse_unique_partition_indices =\n        ctx->Tensor4ArgNameAndIndex(\"inverse_unique_partition_indices\", 0);\n    const bool skip_last_gather = ctx->Attr<bool>(\"skip_last_gather\");\n    CHECK(skip_last_gather);\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    if (current_iter_ == 0) {\n      GetPtrs(ctx, kernel_state->UniqueEmbeddings(), kernel_state->InverseIndices(),\n              kernel_state->IsKernelStart());\n    }\n    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n    CHECK_EQ(kernel_state->InverseIndices()->at(parallel_id), cur_rank_inverse_indices->dptr())\n        << parallel_id;\n    if (embedding_size % 4 == 0) {\n      LaunchKernel<T, IDX, 4>(ctx, kernel_state);\n    } else if (embedding_size % 2 == 0) {\n      LaunchKernel<T, IDX, 2>(ctx, kernel_state);\n    } else {\n      LaunchKernel<T, IDX, 1>(ctx, kernel_state);\n    }\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\nREGISTER_USER_KERNEL(\"embedding_shuffle\")\n    .SetCreateFn<EmbeddingShuffleP2PKernel<half, uint32_t>>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)\n                     && (user_op::HobDataType(\"cur_rank_embeddings\", 0) == DataType::kFloat16)\n                     && (user_op::HobDataType(\"num_unique_matrix\", 0) == DataType::kUInt32)\n                     && (user_op::HobAttr<bool>(\"skip_last_gather\") == true)\n                     && (embedding::UseEmbeddingShuffleP2PKernel(DataType::kFloat16,\n                                                                 DataType::kUInt32)))\n    .SetInferTmpSizeFn([](user_op::InferContext* ctx) {\n      return GetCudaAlignedSize(ctx->InputTensorDesc(\"cur_rank_embeddings\", 0).shape().elem_cnt()\n                                * sizeof(half));\n    });\n\n}  // namespace 
oneflow\n\n#endif  // CUDA_VERSION >= 11030\n"
  },
  {
    "path": "oneflow/user/kernels/one_embedding_id_shuffle_p2p_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/embedding/hash_functions.cuh\"\n#include \"oneflow/core/embedding/embedding_manager.h\"\n#include \"oneflow/core/control/ctrl_client.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename K>\nstruct TableEntry {\n  K key;\n  uint32_t value;\n};\n\ntemplate<typename K, typename V, typename IDX, typename HASH>\n__global__ void HashTableUniqueAndPartitionPairs(\n    const uint32_t table_capacity, const uint32_t num_keys, int32_t num_partition,\n    IDX* unique_counts, TableEntry<K>* table, const K* keys, const V* values,\n    K* partitioned_unique_keys, V* partitioned_unique_values, IDX* reverse_index,\n    bool need_process_values, int32_t* is_kernel_start) {\n  CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_keys) {\n    IDX r_index_plus_one = 0;\n    const K key = keys[i];\n    size_t key_hash = HASH()(key);\n    uint32_t partition_id = key_hash % num_partition;\n    IDX* unique_count = unique_counts + partition_id;\n    K* unique_keys = partitioned_unique_keys + partition_id * num_keys;\n    uint32_t pos = key_hash % table_capacity;\n    const K key_hi = (key | 0x1);\n    const K key_lo = (key & 0x1);\n    uint32_t counter = 0;\n    while (r_index_plus_one == 0) {\n      bool prob_next = false;\n      K* key_ptr = &table[pos].key;\n      volatile uint32_t* table_value_ptr = &table[pos].value;\n      const K old_key = cuda::atomic::CAS(key_ptr, 0, key_hi);\n      if (old_key == 0) {\n        IDX unique_pos = cuda::atomic::Add(unique_count, 1);\n        r_index_plus_one = unique_pos + 1;\n        unique_keys[unique_pos] = key;\n        if (need_process_values) {\n          partitioned_unique_values[partition_id * num_keys + unique_pos] = values[i];\n        }\n        *table_value_ptr = ((r_index_plus_one << 1U) | key_lo);\n      } else if (old_key == key_hi) {\n        const uint32_t value = *table_value_ptr;\n        if (value == 0) {\n          // do nothing\n        } else if ((value & 0x1) == key_lo) {\n          r_index_plus_one = (value >> 1U);\n        } else {\n          prob_next = true;\n        }\n      } else {\n        prob_next = true;\n      }\n      if (prob_next) {\n        pos += 1;\n        counter += 1;\n        if (pos >= table_capacity) { pos -= table_capacity; }\n        if (counter >= table_capacity) { __trap(); }\n      }\n    }\n    reverse_index[i] = partition_id * num_keys + r_index_plus_one - 1;\n  }\n}\n\ntemplate<typename K, typename U, typename IDX, int N>\nstruct Param {\n  IDX* num_unique[N];\n  K* unique_ids[N];\n  U* unique_table_ids[N];\n  int32_t* is_kernel_start[N];\n  IDX* num_unique_matrix;\n  int32_t* counter;\n};\n\ntemplate<typename T, int pack_size>\nstruct alignas(sizeof(T) * pack_size) Pack {\n  T 
elem[pack_size];\n};\n\ntemplate<typename K, typename V, typename IDX, int N, int pack_size>\n__global__ void BarrierAndMemset(int32_t parallel_id, int32_t parallel_num,\n                                 Param<K, V, IDX, N> param, Pack<char, pack_size>* workspace_ptr,\n                                 size_t workspace_num_pack, IDX* counter, int num_counter) {\n  int count;\n  if (blockIdx.x == 0) {\n    count = param.is_kernel_start[parallel_id][parallel_id];\n    if (threadIdx.x < parallel_num) {\n      volatile int32_t* start_f = param.is_kernel_start[parallel_id];\n      start_f[threadIdx.x] = count + 1;\n    }\n  }\n  Pack<char, pack_size> pack_value;\n  for (int i = 0; i < pack_size; ++i) { pack_value.elem[i] = static_cast<char>(0); }\n  CUDA_1D_KERNEL_LOOP(i, workspace_num_pack) { workspace_ptr[i] = pack_value; }\n  int global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;\n  if (global_thread_id < num_counter) { counter[global_thread_id] = 0; }\n  if (blockIdx.x == 0) {\n    if (threadIdx.x < parallel_num) {\n      volatile int32_t* remote_start_f = param.is_kernel_start[threadIdx.x];\n      while (remote_start_f[parallel_id] < count + 1) {}\n    }\n  }\n}\n\ntemplate<typename K, typename V, typename IDX, typename HASH, int N>\n__global__ void HashTableUniquePairs(const uint32_t table_capacity, const uint32_t num_ids,\n                                     int32_t parallel_num, int32_t parallel_id, IDX* unique_count,\n                                     TableEntry<K>* table, Param<K, V, IDX, N> param,\n                                     K* unique_keys, V* unique_values, IDX* reverse_index,\n                                     bool need_process_values) {\n#pragma unroll 1\n  for (int i = 0; i < parallel_num; ++i) {\n    int rank_id = (parallel_id + i) % parallel_num;\n    const IDX* num_uniques = param.num_unique[rank_id];\n    CUDA_1D_KERNEL_LOOP_T(int, rank_index, num_uniques[parallel_id]) {\n      const K* keys = param.unique_ids[rank_id];\n      const V* values = param.unique_table_ids[rank_id];\n      IDX index_offset = 0;\n      for (int k = 0; k < rank_id; ++k) { index_offset += param.num_unique[k][parallel_id]; }\n      IDX r_index_plus_one = 0;\n      const K key = keys[rank_index];\n      size_t key_hash = HASH()(key);\n      uint32_t pos = key_hash % table_capacity;\n      const K key_hi = (key | 0x1);\n      const K key_lo = (key & 0x1);\n      uint32_t counter = 0;\n      while (r_index_plus_one == 0) {\n        bool prob_next = false;\n        K* key_ptr = &table[pos].key;\n        volatile uint32_t* table_value_ptr = &table[pos].value;\n        const K old_key = cuda::atomic::CAS(key_ptr, 0, key_hi);\n        if (old_key == 0) {\n          IDX unique_pos = cuda::atomic::Add(unique_count, 1);\n          r_index_plus_one = unique_pos + 1;\n          unique_keys[unique_pos] = key;\n          if (need_process_values) { unique_values[unique_pos] = values[rank_index]; }\n          *table_value_ptr = ((r_index_plus_one << 1U) | key_lo);\n        } else if (old_key == key_hi) {\n          const uint32_t value = *table_value_ptr;\n          if (value == 0) {\n            // do nothing\n          } else if ((value & 0x1) == key_lo) {\n            r_index_plus_one = (value >> 1U);\n          } else {\n            prob_next = true;\n          }\n        } else {\n          prob_next = true;\n        }\n        if (prob_next) {\n          pos += 
1;\n          counter += 1;\n          if (pos >= table_capacity) { pos -= table_capacity; }\n          if (counter >= table_capacity) { __trap(); }\n        }\n      }\n      reverse_index[rank_index + index_offset] = r_index_plus_one - 1;\n      if (rank_index < parallel_num) {\n        param.num_unique_matrix[i * parallel_num + rank_index] = param.num_unique[i][rank_index];\n      }\n    }\n  }\n}\n\ntemplate<typename U, typename IDX, int pack_size>\n__global__ void GenerateTableIdsAndMemsetUniqueWorkspace(int32_t elem_cnt, int32_t num_tables,\n                                                         U* table_ids,\n                                                         Pack<char, pack_size>* workspace_ptr,\n                                                         size_t workspace_num_pack, IDX* counter,\n                                                         int num_counter) {\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) { table_ids[i] = i % num_tables; }\n  Pack<char, pack_size> pack_value;\n  for (int i = 0; i < pack_size; ++i) { pack_value.elem[i] = static_cast<char>(0); }\n  CUDA_1D_KERNEL_LOOP(i, workspace_num_pack) { workspace_ptr[i] = pack_value; }\n  int global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;\n  if (global_thread_id < num_counter) { counter[global_thread_id] = 0; }\n}\n\ntemplate<typename K, typename V, typename IDX, typename HASH>\nvoid UniqueAndPartition(cudaStream_t cuda_stream, int64_t num_blocks, int64_t num_ids,\n                        size_t capacity, int64_t num_partition, const K* ids, const V* table_ids,\n                        IDX* num_partitioned_unique_ids_ptr, K* partitioned_unique_ids,\n                        V* partitioned_unique_table_ids, IDX* inverse_unique_partition_indices,\n                        void* workspace_ptr, size_t workspace_bytes, bool need_process_table_ids,\n                        int32_t* is_kernel_start_ptr) {\n  size_t table_capacity_bytes = capacity * sizeof(TableEntry<K>);\n  CHECK_GE(workspace_bytes, table_capacity_bytes);\n  HashTableUniqueAndPartitionPairs<K, V, IDX, HASH><<<num_blocks, 1024, 0, cuda_stream>>>(\n      capacity, num_ids, num_partition, num_partitioned_unique_ids_ptr,\n      reinterpret_cast<TableEntry<K>*>(workspace_ptr), ids, table_ids, partitioned_unique_ids,\n      partitioned_unique_table_ids, inverse_unique_partition_indices, need_process_table_ids,\n      is_kernel_start_ptr);\n}\n\nenum class IdShuffleBufferType { kTableIds = 0, kWorkspace, kMaxType };\n\ntemplate<typename K, typename U, typename IDX>\nclass IdShuffleTmpBufferManager final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(IdShuffleTmpBufferManager);\n  IdShuffleTmpBufferManager(void* ptr, const int64_t num_ids, const int64_t parallel_num,\n                            bool need_table_ids, bool need_process_table_ids)\n      : offset_(0),\n        offsets_(static_cast<size_t>(IdShuffleBufferType::kMaxType), -1),\n        sizes_(static_cast<size_t>(IdShuffleBufferType::kMaxType)),\n        ptr_(ptr) {\n    const int64_t num_table_ids = need_process_table_ids ? num_ids : 0;\n    const size_t table_ids_bytes = need_table_ids ? 
num_ids * sizeof(U) : 0;\n    AllocBuffer(IdShuffleBufferType::kTableIds, table_ids_bytes);\n    const size_t hash_table_capacity = parallel_num * num_ids;\n    AllocBuffer(IdShuffleBufferType::kWorkspace, hash_table_capacity * sizeof(TableEntry<K>));\n  }\n\n  template<typename T = void>\n  T* Ptr(IdShuffleBufferType type) {\n    CHECK(ptr_ != nullptr);\n    int64_t offset = offsets_.at(static_cast<size_t>(type));\n    CHECK_NE(offset, -1);\n    return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr_) + offset);\n  }\n\n  int64_t Size(IdShuffleBufferType type) { return sizes_.at(static_cast<size_t>(type)); }\n\n  size_t TotalBufferSize() const { return offset_; }\n\n private:\n  void AllocBuffer(IdShuffleBufferType type, size_t size) {\n    const size_t type_id = static_cast<size_t>(type);\n    CHECK_EQ(offsets_.at(type_id), -1);\n    offsets_.at(type_id) = offset_;\n    sizes_.at(type_id) = size;\n    offset_ += GetCudaAlignedSize(size);\n  }\n  size_t offset_;\n  std::vector<int64_t> offsets_;\n  std::vector<int64_t> sizes_;\n  void* ptr_;\n};\n\ntemplate<typename K, typename U, typename IDX>\nclass DataShuffleKernelState final : public user_op::OpKernelState {\n public:\n  explicit DataShuffleKernelState(user_op::KernelInitContext* ctx)\n      : device_index_(-1),\n        parallel_desc_(ctx->parallel_desc()),\n        parallel_id_(ctx->parallel_ctx().parallel_id()) {\n    OF_CUDA_CHECK(cudaGetDevice(&device_index_));\n    int64_t parallel_num = parallel_desc_.parallel_num();\n    OF_CUDA_CHECK(\n        cudaMallocHost(&host_num_unique_matrix_, parallel_num * parallel_num * sizeof(IDX)));\n    OF_CUDA_CHECK(cudaMallocHost(&host_cur_rank_num_unique_, sizeof(IDX)));\n    const std::string& embedding_name = ctx->Attr<std::string>(\"embedding_name\");\n    const int64_t parallel_id = parallel_id_;\n    embedding_state_ = Singleton<embedding::EmbeddingManager>::Get()->GetEmbeddingState(\n        embedding_name, parallel_id);\n    const int64_t num_ids = ctx->TensorDesc4ArgNameAndIndex(\"ids\", 0)->shape().elem_cnt();\n    num_partitioned_unique_size_ = GetCudaAlignedSize(parallel_num * sizeof(IDX));\n    partitioned_unique_ids_size_ = GetCudaAlignedSize(parallel_num * num_ids * sizeof(K));\n    partitioned_unique_table_ids_size_ = GetCudaAlignedSize(parallel_num * num_ids * sizeof(U));\n    is_kernel_start_size_ = GetCudaAlignedSize(parallel_num * sizeof(int32_t));\n    size_t buffer_size = num_partitioned_unique_size_ + partitioned_unique_ids_size_\n                         + partitioned_unique_table_ids_size_ + is_kernel_start_size_;\n    buffer_ptrs_.resize(parallel_num);\n    OF_CUDA_CHECK(cudaMalloc(&buffer_ptrs_.at(parallel_id), buffer_size));\n    OF_CUDA_CHECK(cudaMemset(buffer_ptrs_.at(parallel_id), 0, buffer_size));\n  }\n  ~DataShuffleKernelState() {\n    CudaCurrentDeviceGuard guard(device_index_);\n    OF_CUDA_CHECK(cudaFreeHost(host_cur_rank_num_unique_));\n    OF_CUDA_CHECK(cudaFreeHost(host_num_unique_matrix_));\n    OF_CUDA_CHECK(cudaFree(buffer_ptrs_.at(parallel_id_)));\n  }\n\n  std::vector<void*>* BufferPtrs() { return &buffer_ptrs_; }\n\n  IDX* HostNumUniqueMatrix() { return host_num_unique_matrix_; }\n\n  IDX* HostCurRankNumUnique() { return host_cur_rank_num_unique_; }\n\n  embedding::EmbeddingState* EmbeddingState() { return embedding_state_; }\n\n  IDX* NumPartitionedUnique(int64_t parallel_id) {\n    return reinterpret_cast<IDX*>(buffer_ptrs_.at(parallel_id));\n  }\n\n  K* PartitionedUniqueIds(int64_t parallel_id) {\n    return\n        
reinterpret_cast<K*>(reinterpret_cast<char*>(buffer_ptrs_.at(parallel_id))\n                                + num_partitioned_unique_size_);\n  }\n\n  U* PartitionedUniqueTableIds(int64_t parallel_id) {\n    return reinterpret_cast<U*>(reinterpret_cast<char*>(buffer_ptrs_.at(parallel_id))\n                                + num_partitioned_unique_size_ + partitioned_unique_ids_size_);\n  }\n\n  int32_t* IsKernelStart(int64_t parallel_id) {\n    return reinterpret_cast<int32_t*>(reinterpret_cast<char*>(buffer_ptrs_.at(parallel_id))\n                                      + num_partitioned_unique_size_ + partitioned_unique_ids_size_\n                                      + partitioned_unique_table_ids_size_);\n  }\n\n private:\n  int device_index_;\n  ParallelDesc parallel_desc_;\n  int64_t parallel_id_;\n  IDX* host_num_unique_matrix_;\n  IDX* host_cur_rank_num_unique_;\n  std::vector<void*> buffer_ptrs_;\n  size_t num_partitioned_unique_size_;\n  size_t partitioned_unique_ids_size_;\n  size_t partitioned_unique_table_ids_size_;\n  size_t is_kernel_start_size_;\n  embedding::EmbeddingState* embedding_state_;\n};\n\nvoid GetPtrs(user_op::KernelComputeContext* ctx, std::vector<void*>* buffer_ptrs) {\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n  const int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n  std::string name = ctx->op_name();\n  cudaIpcMemHandle_t handle;\n  OF_CUDA_CHECK(cudaIpcGetMemHandle(&handle, buffer_ptrs->at(parallel_id)));\n  Singleton<CtrlClient>::Get()->PushKV(\n      name + std::to_string(parallel_id),\n      std::string(reinterpret_cast<const char*>(&handle), sizeof(cudaIpcMemHandle_t)));\n  for (int64_t i = 0; i < parallel_num; ++i) {\n    std::string key = name + std::to_string(i);\n    if (parallel_id != i) {\n      cudaIpcMemHandle_t handle;\n      Singleton<CtrlClient>::Get()->PullKV(key, [&handle](const std::string& val) {\n        memcpy(&handle, val.data(), sizeof(cudaIpcMemHandle_t));\n      });\n      OF_CUDA_CHECK(\n          cudaIpcOpenMemHandle(&buffer_ptrs->at(i), handle, cudaIpcMemLazyEnablePeerAccess));\n    }\n  }\n}\n\ntemplate<typename K, typename V, typename IDX, int N>\n__global__ void BarrierAndComputeOut(int32_t parallel_id, int32_t parallel_num, int32_t num_ids,\n                                     Param<K, V, IDX, N> param, IDX* num_partitioned_unique,\n                                     IDX* inverse_ptr, IDX* num_unique_matrix,\n                                     IDX* host_num_unique_matrix, IDX* cur_rank_num_unique,\n                                     IDX* host_cur_rank_num_unique) {\n  int count;\n  if (blockIdx.x == 0) {\n    count = param.is_kernel_start[parallel_id][parallel_id];\n    if (threadIdx.x < parallel_num) {\n      volatile int32_t* start_f = param.is_kernel_start[parallel_id];\n      start_f[threadIdx.x] = count + 1;\n    }\n  }\n  if (parallel_num > 1) {\n    CUDA_1D_KERNEL_LOOP(i, num_ids) {\n      int inverse_indice = inverse_ptr[i];\n      int partition_id = inverse_indice / num_ids;\n      int partition_indice = inverse_indice - partition_id * num_ids;\n      int new_offset = 0;\n      for (int k = 0; k < partition_id; ++k) { new_offset += num_partitioned_unique[k]; }\n      inverse_ptr[i] = new_offset + partition_indice;\n    }\n  }\n  int global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;\n  if (global_thread_id < parallel_num * parallel_num) {\n    host_num_unique_matrix[global_thread_id] = num_unique_matrix[global_thread_id];\n  }\n  if (global_thread_id == 0) {\n    
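// thread 0 additionally publishes the final unique count to pinned host memory.\n    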
host_cur_rank_num_unique[global_thread_id] = cur_rank_num_unique[global_thread_id];\n  }\n  if (blockIdx.x == 0) {\n    if (threadIdx.x < parallel_num) {\n      volatile int32_t* remote_start_f = param.is_kernel_start[threadIdx.x];\n      while (remote_start_f[parallel_id] < count + 1) {}\n    }\n  }\n}\n\n}  // namespace\n\ntemplate<typename K, typename U, typename IDX>\nclass IdShuffleP2PKernel final : public user_op::OpKernel {\n public:\n  IdShuffleP2PKernel() : current_iter_(0) {}\n  ~IdShuffleP2PKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<DataShuffleKernelState<K, U, IDX>>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<DataShuffleKernelState<K, U, IDX>*>(state);\n    CHECK(kernel_state != nullptr);\n    const user_op::Tensor* ids = ctx->Tensor4ArgNameAndIndex(\"ids\", 0);\n    user_op::Tensor* num_unique_matrix = ctx->Tensor4ArgNameAndIndex(\"num_unique_matrix\", 0);\n    user_op::Tensor* inverse_unique_partition_indices =\n        ctx->Tensor4ArgNameAndIndex(\"inverse_unique_partition_indices\", 0);\n    user_op::Tensor* cur_rank_num_unique = ctx->Tensor4ArgNameAndIndex(\"cur_rank_num_unique\", 0);\n    user_op::Tensor* cur_rank_unique_ids = ctx->Tensor4ArgNameAndIndex(\"cur_rank_unique_ids\", 0);\n    user_op::Tensor* cur_rank_unique_table_ids =\n        ctx->Tensor4ArgNameAndIndex(\"cur_rank_unique_table_ids\", 0);\n    user_op::Tensor* cur_rank_inverse_indices =\n        ctx->Tensor4ArgNameAndIndex(\"cur_rank_inverse_indices\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const int32_t num_tables = ctx->Attr<int32_t>(\"num_tables\");\n    const bool has_table_ids = ctx->has_input(\"table_ids\", 0);\n    const bool need_gen_table_ids = (!has_table_ids && num_tables > 1);\n    const bool need_process_table_ids = (has_table_ids || num_tables > 1);\n    const int64_t num_ids = ids->shape_view().elem_cnt();\n    const int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n    cudaStream_t cuda_stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();\n    IdShuffleTmpBufferManager<K, U, IDX> buffer_manager(\n        tmp_buffer->mut_dptr(), num_ids, parallel_num, need_gen_table_ids, need_process_table_ids);\n    CHECK_GE(tmp_buffer->shape_view().elem_cnt(), buffer_manager.TotalBufferSize());\n    if (current_iter_ == 0) { GetPtrs(ctx, kernel_state->BufferPtrs()); }\n    const int num_blocks =\n        2 * ctx->stream()->As<ep::CudaStream>()->device_properties().multiProcessorCount;\n    IDX* num_partitioned_unique = kernel_state->NumPartitionedUnique(parallel_id);\n    K* partitioned_unique_ids = kernel_state->PartitionedUniqueIds(parallel_id);\n    U* partitioned_unique_table_ids = kernel_state->PartitionedUniqueTableIds(parallel_id);\n    IDX* num_unique_matrix_ptr = reinterpret_cast<IDX*>(num_unique_matrix->mut_dptr());\n    size_t hash_table_capacity = parallel_num * num_ids;\n    void* workspace_ptr = buffer_manager.Ptr(IdShuffleBufferType::kWorkspace);\n    size_t workspace_size = buffer_manager.Size(IdShuffleBufferType::kWorkspace);\n    const U* table_ids_ptr;\n    bool skip_memset = false;\n    if (has_table_ids) {\n      
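// table ids are provided as an input tensor and can be used directly.\n      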
const user_op::Tensor* table_ids = ctx->Tensor4ArgNameAndIndex(\"table_ids\", 0);\n      table_ids_ptr = reinterpret_cast<const U*>(table_ids->dptr());\n    } else if (need_gen_table_ids) {\n      CHECK_EQ(workspace_size % 16, 0);\n      CHECK_EQ(reinterpret_cast<std::uintptr_t>(workspace_ptr) % 16, 0);\n      GenerateTableIdsAndMemsetUniqueWorkspace<U, IDX, 16><<<num_blocks, 1024, 0, cuda_stream>>>(\n          num_ids, num_tables, buffer_manager.template Ptr<U>(IdShuffleBufferType::kTableIds),\n          reinterpret_cast<Pack<char, 16>*>(workspace_ptr), workspace_size / 16,\n          num_partitioned_unique, parallel_num);\n      table_ids_ptr = buffer_manager.template Ptr<U>(IdShuffleBufferType::kTableIds);\n      skip_memset = true;\n    } else {\n      table_ids_ptr = nullptr;\n    }\n    if (!skip_memset) {\n      OF_CUDA_CHECK(cudaMemsetAsync(workspace_ptr, 0, workspace_size, cuda_stream));\n      OF_CUDA_CHECK(\n          cudaMemsetAsync(num_partitioned_unique, 0, parallel_num * sizeof(IDX), cuda_stream));\n    }\n    UniqueAndPartition<K, U, IDX, embedding::ShardingHash>(\n        cuda_stream, num_blocks, num_ids, hash_table_capacity, parallel_num,\n        reinterpret_cast<const K*>(ids->dptr()), table_ids_ptr, num_partitioned_unique,\n        partitioned_unique_ids, partitioned_unique_table_ids,\n        reinterpret_cast<IDX*>(inverse_unique_partition_indices->mut_dptr()), workspace_ptr,\n        workspace_size, need_process_table_ids, kernel_state->IsKernelStart(parallel_id));\n\n    IDX* cur_rank_num_unique_ids_ptr = reinterpret_cast<IDX*>(cur_rank_num_unique->mut_dptr());\n    Param<K, U, IDX, 8> param;\n    CHECK_LE(parallel_num, 8);\n    for (int i = 0; i < parallel_num; ++i) {\n      param.num_unique[i] = kernel_state->NumPartitionedUnique(i);\n      param.unique_ids[i] = kernel_state->PartitionedUniqueIds(i) + parallel_id * num_ids;\n      param.unique_table_ids[i] =\n          kernel_state->PartitionedUniqueTableIds(i) + parallel_id * num_ids;\n      param.is_kernel_start[i] = kernel_state->IsKernelStart(i);\n    }\n    param.num_unique_matrix = num_unique_matrix_ptr;\n    CHECK_EQ(workspace_size % 16, 0);\n    CHECK_EQ(reinterpret_cast<std::uintptr_t>(workspace_ptr) % 16, 0);\n    int workspace_num_pack = workspace_size / 16;\n    BarrierAndMemset<<<num_blocks, 1024, 0, cuda_stream>>>(\n        parallel_id, parallel_num, param, reinterpret_cast<Pack<char, 16>*>(workspace_ptr),\n        workspace_num_pack, cur_rank_num_unique_ids_ptr, 1);\n    HashTableUniquePairs<K, U, IDX, embedding::LocalUniqueHash>\n        <<<num_blocks, 1024, 0, cuda_stream>>>(\n            hash_table_capacity, num_ids, parallel_num, parallel_id, cur_rank_num_unique_ids_ptr,\n            reinterpret_cast<TableEntry<K>*>(workspace_ptr), param,\n            reinterpret_cast<K*>(cur_rank_unique_ids->mut_dptr()),\n            reinterpret_cast<U*>(cur_rank_unique_table_ids->mut_dptr()),\n            reinterpret_cast<IDX*>(cur_rank_inverse_indices->mut_dptr()), need_process_table_ids);\n\n    IDX* host_num_unique_matrix = kernel_state->HostNumUniqueMatrix();\n    IDX* host_cur_rank_num_unique = kernel_state->HostCurRankNumUnique();\n    BarrierAndComputeOut<<<num_blocks, 1024, 0, cuda_stream>>>(\n        parallel_id, parallel_num, num_ids, param, num_partitioned_unique,\n        reinterpret_cast<IDX*>(inverse_unique_partition_indices->mut_dptr()), num_unique_matrix_ptr,\n        host_num_unique_matrix, cur_rank_num_unique_ids_ptr, host_cur_rank_num_unique);\n\n    if (!need_process_table_ids) {\n      
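// table ids were not computed in this case, so zero the output buffer.\n      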
OF_CUDA_CHECK(cudaMemsetAsync(cur_rank_unique_table_ids->mut_dptr(), 0,\n                                    cur_rank_unique_table_ids->shape_view().elem_cnt() * sizeof(U),\n                                    cuda_stream));\n    }\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    std::vector<uint32_t> num_unique_matrix_vec(parallel_num * parallel_num);\n    CHECK_JUST(ctx->stream()->Sync());\n    std::memcpy(num_unique_matrix_vec.data(), host_num_unique_matrix,\n                parallel_num * parallel_num * sizeof(IDX));\n    CHECK_EQ(sizeof(IDX), sizeof(uint32_t)) << \"assume sizeof(IDX) equals to sizeof(uint32_t)\";\n    embedding_state->SetIdNumUniqueMatrix(num_unique_matrix_vec, current_iter_);\n    uint32_t final_num_unique = *host_cur_rank_num_unique;\n    embedding_state->SetIdFinalNumUnique(final_num_unique, current_iter_);\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\n#define ID_DATA_TYPE_SEQ                            \\\n  OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)   \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)\n\n#define TABLE_ID_DATA_TYPE_SEQ                      \\\n  OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8)   \\\n  OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) \\\n  OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8)     \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)   \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)\n\n#define IDX_DATA_TYPE_SEQ                           \\\n  OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)\n\n#define REGISTER_CUDA_ID_SHUFFLE_P2P_KERNEL(k_dtype_pair, table_id_dtype_pair, idx_dtype_pair)   \\\n  REGISTER_USER_KERNEL(\"id_shuffle\")                                                             \\\n      .SetCreateFn<IdShuffleP2PKernel<OF_PP_PAIR_FIRST(k_dtype_pair),                            \\\n                                      OF_PP_PAIR_FIRST(table_id_dtype_pair),                     \\\n                                      OF_PP_PAIR_FIRST(idx_dtype_pair)>>()                       \\\n      .SetIsMatchedHob(                                                                          \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                        \\\n          && (user_op::HobDataType(\"ids\", 0) == OF_PP_PAIR_SECOND(k_dtype_pair))                 \\\n          && (user_op::HobDataType(\"cur_rank_unique_table_ids\", 0)                               \\\n              == OF_PP_PAIR_SECOND(table_id_dtype_pair))                                         \\\n          && (user_op::HobDataType(\"num_unique_matrix\", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair)) \\\n          && ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_ID_SHUFFLE_USE_P2P\", false))             \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                        \\\n        const user_op::TensorDesc& ids = ctx->InputTensorDesc(\"ids\", 0);                         \\\n        const bool has_table_ids = ctx->has_input(\"table_ids\", 0);                               \\\n        const int32_t num_tables = ctx->Attr<int32_t>(\"num_tables\");                             \\\n        const bool need_gen_table_ids = (!has_table_ids && 
num_tables > 1);                      \\\n        const bool need_process_table_ids = (has_table_ids || num_tables > 1);                   \\\n        IdShuffleTmpBufferManager<OF_PP_PAIR_FIRST(k_dtype_pair),                                \\\n                                  OF_PP_PAIR_FIRST(table_id_dtype_pair),                         \\\n                                  OF_PP_PAIR_FIRST(idx_dtype_pair)>                              \\\n            buffer_manager(nullptr, ids.shape().elem_cnt(), ctx->parallel_desc().parallel_num(), \\\n                           need_gen_table_ids, need_process_table_ids);                          \\\n        return buffer_manager.TotalBufferSize();                                                 \\\n      });\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ID_SHUFFLE_P2P_KERNEL, ID_DATA_TYPE_SEQ,\n                                 TABLE_ID_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/one_embedding_kernels.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/embedding/key_value_store.h\"\n#include \"oneflow/core/embedding/embedding_manager.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/ep/include/primitive/copy_nd.h\"\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#include \"oneflow/core/ep/include/device.h\"\n#include \"oneflow/user/kernels/one_embedding_data_shuffle.cuh\"\n#include <curand.h>\n#include <curand_kernel.h>\n\nnamespace oneflow {\n\nnamespace {\n\nenum class InitializerType { kUniform, kNormal, kConstant, kTruncNormal };\n\nstruct EmbeddingInitializer {\n  InitializerType type;\n  union {\n    struct {\n      float low;\n      float high;\n    } uniform_param;\n    struct {\n      float mean;\n      float std;\n    } normal_param;\n    struct {\n      float value;\n    } constant_param;\n    struct {\n      float mean;\n      float std;\n      float a;\n      float b;\n    } trunc_normal_param;\n  };\n\n  bool operator==(const EmbeddingInitializer& rhs) const {\n    if (this->type != rhs.type) { return false; }\n    if (rhs.type == InitializerType::kUniform) {\n      return (this->uniform_param.low == rhs.uniform_param.low)\n             && (this->uniform_param.high == rhs.uniform_param.high);\n    } else if (rhs.type == InitializerType::kNormal) {\n      return (this->normal_param.mean == rhs.normal_param.mean)\n             && (this->normal_param.std == rhs.normal_param.std);\n    } else if (rhs.type == InitializerType::kConstant) {\n      return this->constant_param.value == rhs.constant_param.value;\n    } else if (rhs.type == InitializerType::kTruncNormal) {\n      return (this->trunc_normal_param.mean == rhs.trunc_normal_param.mean)\n             && (this->trunc_normal_param.std == rhs.trunc_normal_param.std)\n             && (this->trunc_normal_param.a == rhs.trunc_normal_param.a)\n             && (this->trunc_normal_param.b == rhs.trunc_normal_param.b);\n    } else {\n      UNIMPLEMENTED();\n      return false;\n    }\n  }\n};\n\nvoid ParseInitializerFromJson(const nlohmann::json& initializer,\n                              EmbeddingInitializer* embedding_initializer) {\n  CHECK(initializer.contains(\"type\"));\n  CHECK(initializer[\"type\"].is_string());\n  std::string type = initializer[\"type\"].get<std::string>();\n  if (type == \"uniform\") {\n    embedding_initializer->type = InitializerType::kUniform;\n    CHECK(initializer.contains(\"low\"));\n    CHECK(initializer.contains(\"high\"));\n    CHECK(initializer[\"low\"].is_number());\n    CHECK(initializer[\"high\"].is_number());\n    embedding_initializer->uniform_param.low = initializer[\"low\"];\n    embedding_initializer->uniform_param.high = initializer[\"high\"];\n  } else if (type == \"normal\") {\n    CHECK(initializer.contains(\"mean\"));\n    CHECK(initializer.contains(\"std\"));\n    CHECK(initializer[\"mean\"].is_number());\n    
CHECK(initializer[\"std\"].is_number());\n    embedding_initializer->type = InitializerType::kNormal;\n    embedding_initializer->normal_param.mean = initializer[\"mean\"];\n    embedding_initializer->normal_param.std = initializer[\"std\"];\n  } else if (type == \"constant\") {\n    CHECK(initializer.contains(\"value\"));\n    CHECK(initializer[\"value\"].is_number());\n    embedding_initializer->type = InitializerType::kConstant;\n    embedding_initializer->constant_param.value = initializer[\"value\"];\n  } else if (type == \"trunc_normal\") {\n    CHECK(initializer.contains(\"mean\"));\n    CHECK(initializer.contains(\"std\"));\n    CHECK(initializer.contains(\"a\"));\n    CHECK(initializer.contains(\"b\"));\n    CHECK(initializer[\"mean\"].is_number());\n    CHECK(initializer[\"std\"].is_number());\n    CHECK(initializer[\"a\"].is_number());\n    CHECK(initializer[\"b\"].is_number());\n    embedding_initializer->type = InitializerType::kTruncNormal;\n    embedding_initializer->trunc_normal_param.mean = initializer[\"mean\"];\n    embedding_initializer->trunc_normal_param.std = initializer[\"std\"];\n    embedding_initializer->trunc_normal_param.a = initializer[\"a\"];\n    embedding_initializer->trunc_normal_param.b = initializer[\"b\"];\n  } else {\n    UNIMPLEMENTED() << \"Unsupported initializer type\";\n  }\n}\n\nint32_t ParseJsonToUniqueInitializerVecAndReturnOffset(\n    const nlohmann::json& initializer, std::vector<EmbeddingInitializer>* initializers) {\n  EmbeddingInitializer embedding_initializer;\n  ParseInitializerFromJson(initializer, &embedding_initializer);\n  for (int32_t i = 0; i < initializers->size(); ++i) {\n    if (initializers->at(i) == embedding_initializer) { return i; }\n  }\n  initializers->push_back(embedding_initializer);\n  return initializers->size() - 1;\n}\n\nvoid SetInitializerIndex(int32_t row_id, int32_t col_start, int32_t col_end, int64_t line_size,\n                         int8_t index, std::vector<int8_t>* initializer_index) {\n  int64_t row_offset = row_id * line_size;\n  for (int32_t col = col_start; col < col_end; ++col) {\n    initializer_index->at(row_offset + col) = index;\n  }\n}\n\nvoid ParseAndSetStateInitializerIndex(const std::string& state_initializer,\n                                      const int32_t num_tables, const int64_t line_size,\n                                      const int64_t embedding_size,\n                                      std::vector<EmbeddingInitializer>* initializer_params,\n                                      std::vector<int8_t>* initializer_index) {\n  if (line_size == embedding_size) { return; }\n  CHECK(!state_initializer.empty());\n  auto initializers = nlohmann::json::parse(state_initializer);\n  CHECK(initializers.is_array());\n  const int num_states = line_size / embedding_size - 1;\n  CHECK_EQ(num_states, initializers.size());\n  for (int32_t i = 0; i < num_states; ++i) {\n    int32_t offset =\n        ParseJsonToUniqueInitializerVecAndReturnOffset(initializers.at(i), initializer_params);\n    int32_t col_start = embedding_size + i * embedding_size;\n    int32_t col_end = col_start + embedding_size;\n    CHECK_LE(col_end, line_size);\n    for (int32_t j = 0; j < num_tables; ++j) {\n      SetInitializerIndex(j, col_start, col_end, line_size, offset, initializer_index);\n    }\n  }\n}\n\nvoid ParseAndSetStepInitializerIndex(const int32_t num_tables, const int64_t line_size,\n                                     const int64_t embedding_size,\n                                     
std::vector<EmbeddingInitializer>* initializer_params,\n                                     std::vector<int8_t>* initializer_index) {\n  // A line_size that is not a multiple of embedding_size means the trailing columns hold a\n  // per-row step counter, which is always zero-initialized here.\n  if (line_size % embedding_size == 0) { return; }\n  nlohmann::json initializer;\n  initializer[\"type\"] = \"constant\";\n  initializer[\"value\"] = 0.0;\n  int32_t offset = ParseJsonToUniqueInitializerVecAndReturnOffset(initializer, initializer_params);\n  int32_t col_start = line_size / embedding_size * embedding_size;\n  int32_t col_end = line_size;\n  CHECK_LE(col_end, line_size);\n  for (int32_t j = 0; j < num_tables; ++j) {\n    SetInitializerIndex(j, col_start, col_end, line_size, offset, initializer_index);\n  }\n}\n\nvoid ParseAndSetModelInitializerIndex(const nlohmann::json& tables,\n                                      const std::vector<int64_t>& column_dims,\n                                      const int32_t num_tables, const int32_t num_columns,\n                                      const int64_t line_size, const int64_t embedding_size,\n                                      std::vector<EmbeddingInitializer>* initializer_params,\n                                      std::vector<int8_t>* initializer_index) {\n  for (int32_t i = 0; i < num_tables; ++i) {\n    auto table = tables.at(i);\n    CHECK(table.contains(\"columns\"));\n    auto columns = table[\"columns\"];\n    CHECK(columns.is_array());\n    CHECK_EQ(num_columns, columns.size()) << \"columns size must equal num embedding dims\";\n    int32_t col_start = 0;\n    for (int k = 0; k < columns.size(); ++k) {\n      auto column = columns.at(k);\n      CHECK(column.contains(\"initializer\"));\n      int32_t offset =\n          ParseJsonToUniqueInitializerVecAndReturnOffset(column[\"initializer\"], initializer_params);\n      int32_t col_end = col_start + column_dims.at(k);\n      SetInitializerIndex(i, col_start, col_end, line_size, offset, initializer_index);\n      col_start = col_end;\n    }\n    CHECK_EQ(col_start, embedding_size);\n  }\n}\n\nvoid ParseInitializers(const int64_t line_size, const int64_t embedding_size,\n                       const std::string& state_initializer, const std::string& json_serialized,\n                       std::vector<EmbeddingInitializer>* initializer_params,\n                       std::vector<int8_t>* initializer_index) {\n  auto json_object = nlohmann::json::parse(json_serialized);\n  CHECK(json_object.contains(\"column_dims\"));\n  std::vector<int64_t> column_dims = json_object[\"column_dims\"];\n  const int32_t num_columns = column_dims.size();\n  CHECK(json_object.contains(\"tables\"));\n  auto tables = json_object[\"tables\"];\n  CHECK(tables.is_array());\n  const int32_t num_tables = tables.size();\n  initializer_index->resize(num_tables * line_size);\n  ParseAndSetStepInitializerIndex(num_tables, line_size, embedding_size, initializer_params,\n                                  initializer_index);\n  ParseAndSetStateInitializerIndex(state_initializer, num_tables, line_size, embedding_size,\n                                   initializer_params, initializer_index);\n  ParseAndSetModelInitializerIndex(tables, column_dims, num_tables, num_columns, line_size,\n                                   embedding_size, initializer_params, initializer_index);\n}\n\ntemplate<typename IDX>\nclass EmbeddingKernelState final : public user_op::OpKernelState {\n public:\n  explicit EmbeddingKernelState(user_op::KernelInitContext* ctx) : device_index_(-1) {\n    OF_CUDA_CHECK(cudaGetDevice(&device_index_));\n    OF_CUDA_CHECK(cudaMallocHost(&host_num_keys_, 
sizeof(IDX)));\n    const std::string& embedding_name = ctx->Attr<std::string>(\"embedding_name\");\n    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n    key_value_store_ = Singleton<embedding::EmbeddingManager>::Get()->GetKeyValueStore(\n        embedding_name, parallel_id);\n    uint32_t max_query_length =\n        ctx->TensorDesc4ArgNameAndIndex(\"unique_ids\", 0)->shape().elem_cnt();\n    key_value_store_->ReserveQueryLength(max_query_length);\n    embedding_state_ = Singleton<embedding::EmbeddingManager>::Get()->GetEmbeddingState(\n        embedding_name, parallel_id);\n\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n    const std::string& state_initializer = ctx->Attr<std::string>(\"state_initializer\");\n\n    std::vector<EmbeddingInitializer> initializer_param;\n    std::vector<int8_t> initializer_index;\n    ParseInitializers(line_size, embedding_size, state_initializer,\n                      ctx->Attr<std::string>(\"embedding_tables\"), &initializer_param,\n                      &initializer_index);\n\n    const size_t param_size_bytes = initializer_param.size() * sizeof(EmbeddingInitializer);\n    OF_CUDA_CHECK(cudaMallocHost(&host_initializer_param_, param_size_bytes));\n    std::memcpy(host_initializer_param_, initializer_param.data(), param_size_bytes);\n    OF_CUDA_CHECK(cudaMalloc(&device_initializer_param_, param_size_bytes));\n    OF_CUDA_CHECK(cudaMemcpyAsync(device_initializer_param_, host_initializer_param_,\n                                  param_size_bytes, cudaMemcpyDefault,\n                                  ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n\n    const size_t index_size_bytes = initializer_index.size() * sizeof(int8_t);\n    OF_CUDA_CHECK(cudaMallocHost(&host_initializer_index_, index_size_bytes));\n    std::memcpy(host_initializer_index_, initializer_index.data(), index_size_bytes);\n    OF_CUDA_CHECK(cudaMalloc(&device_initializer_index_, index_size_bytes));\n    OF_CUDA_CHECK(cudaMemcpyAsync(device_initializer_index_, host_initializer_index_,\n                                  index_size_bytes, cudaMemcpyDefault,\n                                  ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n  }\n  ~EmbeddingKernelState() override {\n    CudaCurrentDeviceGuard guard(device_index_);\n    OF_CUDA_CHECK(cudaFreeHost(host_num_keys_));\n    OF_CUDA_CHECK(cudaFreeHost(host_initializer_param_));\n    OF_CUDA_CHECK(cudaFree(device_initializer_param_));\n    OF_CUDA_CHECK(cudaFreeHost(host_initializer_index_));\n    OF_CUDA_CHECK(cudaFree(device_initializer_index_));\n  }\n\n  void* HostNumKeys() { return host_num_keys_; }\n\n  embedding::KeyValueStore* KeyValueStore() { return key_value_store_; }\n\n  embedding::EmbeddingState* EmbeddingState() { return embedding_state_; }\n\n  const int8_t* InitializerIndex() { return device_initializer_index_; }\n  const EmbeddingInitializer* Initializers() { return device_initializer_param_; }\n\n private:\n  int device_index_;\n  void* host_num_keys_;\n  embedding::KeyValueStore* key_value_store_;\n  embedding::EmbeddingState* embedding_state_;\n  EmbeddingInitializer* host_initializer_param_;\n  EmbeddingInitializer* device_initializer_param_;\n  int8_t* host_initializer_index_;\n  int8_t* device_initializer_index_;\n};\n\nclass EmbeddingPutKernelState final : public user_op::OpKernelState {\n public:\n  explicit EmbeddingPutKernelState(user_op::KernelInitContext* ctx) {\n    const 
std::string& embedding_name = ctx->Attr<std::string>(\"embedding_name\");\n    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n    key_value_store_ = Singleton<embedding::EmbeddingManager>::Get()->GetKeyValueStore(\n        embedding_name, parallel_id);\n    uint32_t max_query_length =\n        ctx->TensorDesc4ArgNameAndIndex(\"unique_ids\", 0)->shape().elem_cnt();\n    key_value_store_->ReserveQueryLength(max_query_length);\n    embedding_state_ = Singleton<embedding::EmbeddingManager>::Get()->GetEmbeddingState(\n        embedding_name, parallel_id);\n  }\n  ~EmbeddingPutKernelState() override = default;\n\n  embedding::KeyValueStore* KeyValueStore() { return key_value_store_; }\n  embedding::EmbeddingState* EmbeddingState() { return embedding_state_; }\n\n private:\n  embedding::KeyValueStore* key_value_store_;\n  embedding::EmbeddingState* embedding_state_;\n};\n\n// Fills rows that missed in the key-value store. Each element draws from the initializer\n// selected by its (table id, column) entry, using curand Philox seeded with (seed, id, col) so\n// that a given id always receives the same initial values.\ntemplate<typename T, typename K, typename U>\n__global__ void InitValueKernel(uint64_t seed, const int32_t line_size,\n                                const int32_t embedding_size,\n                                const EmbeddingInitializer* initializer_param,\n                                const int8_t* initializer_index, const K* unique_ids,\n                                const U* table_ids, const uint32_t* num_missing_keys,\n                                const uint32_t* missing_indices, T* values) {\n  int64_t n = *num_missing_keys * line_size;\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    int row = i / line_size;\n    int col = i - row * line_size;\n    const uint32_t index = missing_indices[row];\n    const int64_t offset = index * line_size + col;\n    const int32_t table_idx = table_ids[index];\n    const K id = unique_ids[index];\n    curandStatePhilox4_32_10_t state;\n    curand_init(seed, id, col, &state);\n    const int32_t initializer_idx = initializer_index[table_idx * line_size + col];\n    EmbeddingInitializer initializer = initializer_param[initializer_idx];\n    T value;\n    if (initializer.type == InitializerType::kUniform) {\n      const float low = initializer.uniform_param.low;\n      const float high = initializer.uniform_param.high;\n      value = curand_uniform(&state) * (high - low) + low;\n    } else if (initializer.type == InitializerType::kNormal) {\n      const float mean = initializer.normal_param.mean;\n      const float std = initializer.normal_param.std;\n      value = curand_normal(&state) * std + mean;\n    } else if (initializer.type == InitializerType::kConstant) {\n      value = initializer.constant_param.value;\n    } else if (initializer.type == InitializerType::kTruncNormal) {\n      const float mean = initializer.trunc_normal_param.mean;\n      const float std = initializer.trunc_normal_param.std;\n      const float a = initializer.trunc_normal_param.a;\n      const float b = initializer.trunc_normal_param.b;\n      while (true) {\n        value = curand_normal(&state) * std + mean;\n        if (value >= a && value <= b) { break; }\n        skipahead(line_size, &state);\n      }\n    } else {\n      __trap();\n    }\n    values[offset] = value;\n  }\n}\n\ntemplate<typename T, typename K, typename U, typename IDX>\nvoid LookupAndInitMissing(ep::Stream* stream, uint64_t seed, embedding::KeyValueStore* store,\n                          const EmbeddingInitializer* initializer_param,\n                          const int8_t* initializer_index, void* host_num_keys, uint32_t num_unique,\n                          const int64_t embedding_size, const int64_t line_size,\n             
             const bool put_to_store, const void* unique_ids, const void* table_ids,\n                          void* num_missing_ptr, void* missing_indices, void* store_values) {\n  store->Get(stream, num_unique, unique_ids, store_values,\n             reinterpret_cast<uint32_t*>(num_missing_ptr),\n             reinterpret_cast<uint32_t*>(missing_indices));\n  CHECK_GE(sizeof(IDX), sizeof(uint32_t));  // host_num_keys's buffer size is sizeof(IDX)\n  OF_CUDA_CHECK(cudaMemcpyAsync(host_num_keys, num_missing_ptr, sizeof(uint32_t), cudaMemcpyDefault,\n                                stream->As<ep::CudaStream>()->cuda_stream()));\n  CHECK_JUST(stream->Sync());\n  uint32_t num_missing = *reinterpret_cast<uint32_t*>(host_num_keys);\n  // init missing values\n  if (num_missing > 0) {\n    const int64_t elem_cnt = num_missing * line_size;\n    const int64_t num_blocks = BlocksNum4ThreadsNum(elem_cnt);\n    InitValueKernel<T, K, U>\n        <<<num_blocks, kCudaThreadsNumPerBlock, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            seed, line_size, embedding_size, initializer_param, initializer_index,\n            reinterpret_cast<const K*>(unique_ids), reinterpret_cast<const U*>(table_ids),\n            reinterpret_cast<uint32_t*>(num_missing_ptr),\n            reinterpret_cast<uint32_t*>(missing_indices), reinterpret_cast<T*>(store_values));\n  }\n  if (put_to_store) { store->Put(stream, num_unique, unique_ids, store_values); }\n}\n\ntemplate<typename T, typename K, typename U, typename IDX>\nvoid LookupAndInitMissing(ep::Stream* stream, EmbeddingKernelState<IDX>* kernel_state,\n                          uint64_t seed, uint32_t num_unique, const int64_t embedding_size,\n                          const int64_t line_size, const bool put_to_store, const void* unique_ids,\n                          const void* table_ids, void* num_missing_ptr, void* missing_indices,\n                          void* store_values) {\n  embedding::KeyValueStore* store = kernel_state->KeyValueStore();\n  const EmbeddingInitializer* initializer_param = kernel_state->Initializers();\n  const int8_t* initializer_index = kernel_state->InitializerIndex();\n  void* host_num_keys = kernel_state->HostNumKeys();\n  LookupAndInitMissing<T, K, U, IDX>(stream, seed, store, initializer_param, initializer_index,\n                                     host_num_keys, num_unique, embedding_size, line_size,\n                                     put_to_store, unique_ids, table_ids, num_missing_ptr,\n                                     missing_indices, store_values);\n}\n\ntemplate<typename T, size_t pack_size>\nstruct alignas(sizeof(T) * pack_size) Pack {\n  T elem[pack_size];\n};\n\ntemplate<typename T, typename K, typename U, typename V, int pack_size>\n__global__ void FusedInitSliceCast(const int32_t elem_cnt, uint64_t seed, const int32_t line_size,\n                                   const int32_t embedding_size, const int32_t line_num_pack,\n                                   const int32_t embedding_num_pack,\n                                   const EmbeddingInitializer* initializer_param,\n                                   const int8_t* initializer_index, const K* unique_ids,\n                                   const U* table_ids, const uint8_t* lookup_mask,\n                                   Pack<T, pack_size>* values, Pack<V, pack_size>* embeddings) {\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {\n    int row = i / line_num_pack;\n    int col = i - row * line_num_pack;\n    Pack<T, pack_size> value_i;\n    if 
(!lookup_mask[row]) {\n      const int32_t table_idx = table_ids[row];\n      const K id = unique_ids[row];\n      curandStatePhilox4_32_10_t state;\n      curand_init(seed, id, col, &state);\n#pragma unroll\n      for (int k = 0; k < pack_size; ++k) {\n        const int32_t initializer_idx =\n            initializer_index[table_idx * line_size + col * pack_size + k];\n        EmbeddingInitializer initializer = initializer_param[initializer_idx];\n        T value;\n        if (initializer.type == InitializerType::kUniform) {\n          const float low = initializer.uniform_param.low;\n          const float high = initializer.uniform_param.high;\n          value = curand_uniform(&state) * (high - low) + low;\n        } else if (initializer.type == InitializerType::kNormal) {\n          const float mean = initializer.normal_param.mean;\n          const float std = initializer.normal_param.std;\n          value = curand_normal(&state) * std + mean;\n        } else if (initializer.type == InitializerType::kConstant) {\n          value = initializer.constant_param.value;\n        } else if (initializer.type == InitializerType::kTruncNormal) {\n          const float mean = initializer.trunc_normal_param.mean;\n          const float std = initializer.trunc_normal_param.std;\n          const float a = initializer.trunc_normal_param.a;\n          const float b = initializer.trunc_normal_param.b;\n          while (true) {\n            value = curand_normal(&state) * std + mean;\n            if (value >= a && value <= b) { break; }\n            skipahead(line_size, &state);\n          }\n        } else {\n          __trap();\n        }\n        value_i.elem[k] = value;\n      }\n      values[i] = value_i;\n    } else {\n      value_i = values[i];\n    }\n    if (embeddings != nullptr && col < embedding_num_pack) {\n      int64_t embedding_offset = row * embedding_num_pack + col;\n      Pack<V, pack_size> embedding_i;\n#pragma unroll\n      for (int k = 0; k < pack_size; ++k) { embedding_i.elem[k] = static_cast<V>(value_i.elem[k]); }\n      embeddings[embedding_offset] = embedding_i;\n    }\n  }\n}\n\ntemplate<typename T, typename K, typename U, typename V>\nvoid InitMissingAndSliceCast(cudaStream_t cuda_stream, uint32_t num_unique,\n                             const int64_t embedding_size, const int64_t line_size, uint64_t seed,\n                             const EmbeddingInitializer* initializer_param,\n                             const int8_t* initializer_index, const void* unique_ids,\n                             const void* table_ids, const uint8_t* mask, T* values_ptr,\n                             V* embeddings_ptr) {\n  int32_t pack_size;\n  if (embedding_size % 4 == 0 && line_size % 4 == 0) {\n    pack_size = 4;\n  } else if (embedding_size % 2 == 0 && line_size % 2 == 0) {\n    pack_size = 2;\n  } else {\n    pack_size = 1;\n  }\n  int32_t embedding_num_pack = embedding_size / pack_size;\n  int32_t line_num_pack = line_size / pack_size;\n  int64_t value_elem_cnt = num_unique * line_size;\n  int64_t value_elem_num_pack = value_elem_cnt / pack_size;\n  const int64_t num_blocks = BlocksNum4ThreadsNum(value_elem_num_pack);\n  if (pack_size == 4) {\n    FusedInitSliceCast<T, K, U, V, 4><<<num_blocks, kCudaThreadsNumPerBlock, 0, cuda_stream>>>(\n        value_elem_num_pack, seed, line_size, embedding_size, line_num_pack, embedding_num_pack,\n        initializer_param, initializer_index, reinterpret_cast<const K*>(unique_ids),\n        reinterpret_cast<const U*>(table_ids), mask, 
reinterpret_cast<Pack<T, 4>*>(values_ptr),\n        reinterpret_cast<Pack<V, 4>*>(embeddings_ptr));\n  } else if (pack_size == 2) {\n    FusedInitSliceCast<T, K, U, V, 2><<<num_blocks, kCudaThreadsNumPerBlock, 0, cuda_stream>>>(\n        value_elem_num_pack, seed, line_size, embedding_size, line_num_pack, embedding_num_pack,\n        initializer_param, initializer_index, reinterpret_cast<const K*>(unique_ids),\n        reinterpret_cast<const U*>(table_ids), mask, reinterpret_cast<Pack<T, 2>*>(values_ptr),\n        reinterpret_cast<Pack<V, 2>*>(embeddings_ptr));\n  } else {\n    FusedInitSliceCast<T, K, U, V, 1><<<num_blocks, kCudaThreadsNumPerBlock, 0, cuda_stream>>>(\n        value_elem_num_pack, seed, line_size, embedding_size, line_num_pack, embedding_num_pack,\n        initializer_param, initializer_index, reinterpret_cast<const K*>(unique_ids),\n        reinterpret_cast<const U*>(table_ids), mask, reinterpret_cast<Pack<T, 1>*>(values_ptr),\n        reinterpret_cast<Pack<V, 1>*>(embeddings_ptr));\n  }\n}\n\ntemplate<typename T, typename K, typename U, typename IDX>\nvoid LookupAndFusedInitMissingSliceCast(ep::Stream* stream, EmbeddingKernelState<IDX>* kernel_state,\n                                        uint64_t seed, uint32_t num_unique,\n                                        const int64_t embedding_size, const int64_t line_size,\n                                        DataType value_dtype, DataType embedding_dtype,\n                                        const void* unique_ids, const void* table_ids,\n                                        uint8_t* lookup_mask_ptr, void* values_ptr,\n                                        void* embeddings_ptr) {\n  embedding::KeyValueStore* store = kernel_state->KeyValueStore();\n  const EmbeddingInitializer* initializer_param = kernel_state->Initializers();\n  const int8_t* initializer_index = kernel_state->InitializerIndex();\n  cudaStream_t cuda_stream = stream->As<ep::CudaStream>()->cuda_stream();\n  store->Get(stream, num_unique, unique_ids, values_ptr, lookup_mask_ptr);\n  if (embedding_dtype == value_dtype) {\n    InitMissingAndSliceCast<T, K, U, T>(\n        cuda_stream, num_unique, embedding_size, line_size, seed, initializer_param,\n        initializer_index, reinterpret_cast<const K*>(unique_ids),\n        reinterpret_cast<const U*>(table_ids), lookup_mask_ptr, reinterpret_cast<T*>(values_ptr),\n        reinterpret_cast<T*>(embeddings_ptr));\n  } else if (embedding_dtype == DataType::kFloat16) {\n    InitMissingAndSliceCast<T, K, U, half>(\n        cuda_stream, num_unique, embedding_size, line_size, seed, initializer_param,\n        initializer_index, reinterpret_cast<const K*>(unique_ids),\n        reinterpret_cast<const U*>(table_ids), lookup_mask_ptr, reinterpret_cast<T*>(values_ptr),\n        reinterpret_cast<half*>(embeddings_ptr));\n  } else {\n    UNIMPLEMENTED() << \"Unimplemented data_type \" << embedding_dtype;\n  }\n}\n\ntemplate<typename T, typename U>\n__global__ void Copy2D(int64_t out_elem_cnt, const int32_t in_cols, const int32_t out_cols,\n                       const T* in, U* out) {\n  CUDA_1D_KERNEL_LOOP(i, out_elem_cnt) {\n    const int32_t row = i / out_cols;\n    const int32_t col = i - row * out_cols;\n    const int64_t in_offset = row * in_cols + col;\n    out[i] = static_cast<U>(in[in_offset]);\n  }\n}\n\ntemplate<typename T>\nvoid CopyValuesToEmbeddings(ep::Stream* stream, int64_t num_unique, const int32_t embedding_size,\n                            const int32_t value_size, const DataType value_dtype,\n 
                           const DataType embedding_dtype, const T* values, void* embeddings) {\n  bool need_cast = (value_dtype != embedding_dtype);\n  bool need_copy_nd = (embedding_size != value_size);\n  CHECK(need_cast || need_copy_nd);\n  if (need_cast && !need_copy_nd) {\n    const int64_t cast_elem_count = num_unique * embedding_size;\n    std::unique_ptr<ep::primitive::Cast> cast_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::CastFactory>(DeviceType::kCUDA, value_dtype,\n                                                                embedding_dtype);\n    cast_primitive->Launch(stream, values, embeddings, cast_elem_count);\n  } else if (!need_cast && need_copy_nd) {\n    const int32_t ndims = 2;\n    DimVector src_pos_vec(ndims, 0);\n    DimVector dst_pos_vec(ndims, 0);\n    DimVector src_shape = {num_unique, value_size};\n    DimVector dst_shape = {num_unique, embedding_size};\n    DimVector extent_shape = {num_unique, embedding_size};\n    std::unique_ptr<ep::primitive::CopyNd> copy_nd_primitive =\n        ep::primitive::NewPrimitive<ep::primitive::CopyNdFactory>(DeviceType::kCUDA, ndims);\n    CHECK(copy_nd_primitive);\n    copy_nd_primitive->Launch(stream, value_dtype, ndims, embeddings, dst_shape.data(),\n                              dst_pos_vec.data(), values, src_shape.data(), src_pos_vec.data(),\n                              extent_shape.data());\n  } else {\n    const int64_t embedding_elem_cnt = num_unique * embedding_size;\n    if (embedding_dtype == DataType::kFloat16) {\n      Copy2D<T, half><<<BlocksNum4ThreadsNum(embedding_elem_cnt), kCudaThreadsNumPerBlock, 0,\n                        stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          embedding_elem_cnt, value_size, embedding_size, values,\n          reinterpret_cast<half*>(embeddings));\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n}\n\ntemplate<typename T, bool is_prefetch>\nuser_op::InferTmpSizeFn GenEmbeddingInferTmpSizeFn() {\n  return [](user_op::InferContext* ctx) {\n    size_t total_buffer_size = 0;\n    if (embedding::UseDynamicMemoryAllocation()) { return total_buffer_size; }\n    const user_op::TensorDesc& unique_ids = ctx->InputTensorDesc(\"unique_ids\", 0);\n    int64_t num_ids = unique_ids.shape().elem_cnt();\n    size_t num_missing_size = GetCudaAlignedSize(sizeof(uint32_t));\n    size_t missing_indices_size = GetCudaAlignedSize(num_ids * sizeof(uint32_t));\n    size_t value_buffer_size;\n    if (is_prefetch) {\n      size_t value_byte_size = ctx->Attr<int64_t>(\"line_size\") * sizeof(T);\n      value_buffer_size = GetCudaAlignedSize(num_ids * value_byte_size);\n    } else {\n      value_buffer_size = 0;\n    }\n    total_buffer_size = num_missing_size + missing_indices_size + value_buffer_size;\n    return total_buffer_size;\n  };\n}\n\nclass IdShuffleCopyOutKernelState final : public user_op::OpKernelState {\n public:\n  explicit IdShuffleCopyOutKernelState(user_op::KernelInitContext* ctx) {\n    const std::string& embedding_name = ctx->Attr<std::string>(\"embedding_name\");\n    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n    embedding_state_ = Singleton<embedding::EmbeddingManager>::Get()->GetEmbeddingState(\n        embedding_name, parallel_id);\n  }\n  ~IdShuffleCopyOutKernelState() override = default;\n\n  embedding::EmbeddingState* EmbeddingState() { return embedding_state_; }\n\n private:\n  embedding::EmbeddingState* embedding_state_;\n};\n\ntemplate<typename K, typename U, typename IDX>\nstruct IdShuffleCopyOutParam {\n  uint32_t 
final_num_unique_ids;\n  const K* cur_rank_unique_ids;\n  K* out_cur_rank_unique_ids;\n  const U* cur_rank_unique_table_ids;\n  U* out_cur_rank_unique_table_ids;\n  uint32_t cur_rank_num_ids;\n  const IDX* cur_rank_inverse_indices;\n  IDX* out_cur_rank_inverse_indices;\n  uint32_t num_ids;\n  const IDX* inverse_unique_partition_indices;\n  IDX* out_inverse_unique_partition_indices;\n  uint32_t num_unique_matrix_cnt;\n  const IDX* num_unique_matrix;\n  IDX* out_num_unique_matrix;\n  const IDX* cur_rank_num_unique;\n  IDX* out_cur_rank_num_unique;\n};\n\ntemplate<typename K, typename U, typename IDX>\n__global__ void CopyGpu(IdShuffleCopyOutParam<K, U, IDX> param) {\n  CUDA_1D_KERNEL_LOOP_T(uint32_t, i, param.final_num_unique_ids) {\n    param.out_cur_rank_unique_ids[i] = param.cur_rank_unique_ids[i];\n    param.out_cur_rank_unique_table_ids[i] = param.cur_rank_unique_table_ids[i];\n  }\n  CUDA_1D_KERNEL_LOOP_T(uint32_t, i, param.cur_rank_num_ids) {\n    param.out_cur_rank_inverse_indices[i] = param.cur_rank_inverse_indices[i];\n  }\n  CUDA_1D_KERNEL_LOOP_T(uint32_t, i, param.num_ids) {\n    param.out_inverse_unique_partition_indices[i] = param.inverse_unique_partition_indices[i];\n  }\n  CUDA_1D_KERNEL_LOOP_T(uint32_t, i, param.num_unique_matrix_cnt) {\n    param.out_num_unique_matrix[i] = param.num_unique_matrix[i];\n  }\n  if (blockIdx.x * blockDim.x + threadIdx.x == 0) {\n    *param.out_cur_rank_num_unique = *param.cur_rank_num_unique;\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename K, typename U, typename IDX>\nclass EmbeddingPrefetchKernel final : public user_op::OpKernel {\n public:\n  EmbeddingPrefetchKernel() : current_iter_(0){};\n  ~EmbeddingPrefetchKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<EmbeddingKernelState<IDX>>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<EmbeddingKernelState<IDX>*>(state);\n    CHECK(kernel_state != nullptr);\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    std::unique_ptr<embedding::TmpBufferAllocator> allocator =\n        embedding_state->NewTmpBufferAllocator(ctx);\n    uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_);\n    const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex(\"num_unique_ids\", 0);\n    const user_op::Tensor* unique_ids = ctx->Tensor4ArgNameAndIndex(\"unique_ids\", 0);\n    const user_op::Tensor* table_ids = ctx->Tensor4ArgNameAndIndex(\"table_ids\", 0);\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n    const int64_t seed = ctx->Attr<int64_t>(\"seed\");\n    void* num_missing_ptr;\n    allocator->Allocate(&num_missing_ptr, sizeof(uint32_t));\n    void* missing_indices_ptr;\n    allocator->Allocate(&missing_indices_ptr, num_unique * sizeof(uint32_t));\n    void* values_ptr;\n    allocator->Allocate(&values_ptr, num_unique * line_size * sizeof(T));\n    LookupAndInitMissing<T, K, U, IDX>(\n        ctx->stream(), kernel_state, seed, num_unique, embedding_size, line_size, true,\n        unique_ids->dptr(), table_ids->dptr(), num_missing_ptr, missing_indices_ptr, values_ptr);\n    allocator->Free(num_missing_ptr);\n    
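// The remaining scratch buffers are likewise only needed while LookupAndInitMissing runs;\n    // return them to the embedding state's temporary-buffer allocator before the next iteration.\n    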
allocator->Free(missing_indices_ptr);\n    allocator->Free(values_ptr);\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\n#define EMBEDDING_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)\n\n#define ID_DATA_TYPE_SEQ                            \\\n  OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)   \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)\n\n#define TABLE_ID_DATA_TYPE_SEQ                      \\\n  OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8)   \\\n  OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) \\\n  OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8)     \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)   \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)\n\n#define IDX_DATA_TYPE_SEQ                           \\\n  OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)\n\n#define REGISTER_CUDA_EMBEDDING_PREFETCH_KERNEL(t_dtype_pair, k_dtype_pair, table_dtype_pair,  \\\n                                                idx_dtype_pair)                                \\\n  REGISTER_USER_KERNEL(\"embedding_prefetch\")                                                   \\\n      .SetCreateFn<EmbeddingPrefetchKernel<                                                    \\\n          OF_PP_PAIR_FIRST(t_dtype_pair), OF_PP_PAIR_FIRST(k_dtype_pair),                      \\\n          OF_PP_PAIR_FIRST(table_dtype_pair), OF_PP_PAIR_FIRST(idx_dtype_pair)>>()             \\\n      .SetIsMatchedHob(                                                                        \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                      \\\n          && (user_op::HobDataType(\"unique_ids\", 0) == OF_PP_PAIR_SECOND(k_dtype_pair))        \\\n          && (user_op::HobDataType(\"table_ids\", 0) == OF_PP_PAIR_SECOND(table_dtype_pair))     \\\n          && (user_op::HobDataType(\"num_unique_ids\", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair))) \\\n      .SetInferTmpSizeFn(GenEmbeddingInferTmpSizeFn<OF_PP_PAIR_FIRST(t_dtype_pair), true>());\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_EMBEDDING_PREFETCH_KERNEL, EMBEDDING_DATA_TYPE_SEQ,\n                                 ID_DATA_TYPE_SEQ, TABLE_ID_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ)\n\ntemplate<typename T, typename K, typename U, typename IDX>\nclass EmbeddingLookupKernel final : public user_op::OpKernel {\n public:\n  EmbeddingLookupKernel() : current_iter_(0){};\n  ~EmbeddingLookupKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<EmbeddingKernelState<IDX>>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<EmbeddingKernelState<IDX>*>(state);\n    CHECK(kernel_state != nullptr);\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    std::unique_ptr<embedding::TmpBufferAllocator> allocator =\n        embedding_state->NewTmpBufferAllocator(ctx);\n    embedding_state->OnEmbeddingLookupStart(ctx, current_iter_);\n    const user_op::Tensor* 
num_unique_ids = ctx->Tensor4ArgNameAndIndex(\"num_unique_ids\", 0);\n    const user_op::Tensor* unique_ids = ctx->Tensor4ArgNameAndIndex(\"unique_ids\", 0);\n    const user_op::Tensor* table_ids = ctx->Tensor4ArgNameAndIndex(\"table_ids\", 0);\n    user_op::Tensor* unique_values = ctx->Tensor4ArgNameAndIndex(\"unique_values\", 0);\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n    const bool has_output_embeddings = ctx->has_output(\"embeddings\", 0);\n    const int64_t seed = ctx->Attr<int64_t>(\"seed\");\n    uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_);\n    void* values_ptr = embedding_state->LookupUniqueValues(current_iter_);\n    if (has_output_embeddings && kernel_state->KeyValueStore()->IsFusionSupported()) {\n      void* embeddings_ptr = embedding_state->LookupEmbeddings(current_iter_);\n      user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex(\"embeddings\", 0);\n      void* lookup_mask_ptr;\n      allocator->Allocate(&lookup_mask_ptr, num_unique * sizeof(uint8_t));\n      LookupAndFusedInitMissingSliceCast<T, K, U, IDX>(\n          ctx->stream(), kernel_state, seed, num_unique, embedding_size, line_size,\n          unique_values->data_type(), embeddings->data_type(), unique_ids->dptr(),\n          table_ids->dptr(), reinterpret_cast<uint8_t*>(lookup_mask_ptr), values_ptr,\n          embeddings_ptr);\n      allocator->Free(lookup_mask_ptr);\n    } else {\n      void* num_missing_ptr;\n      allocator->Allocate(&num_missing_ptr, sizeof(uint32_t));\n      void* missing_indices_ptr;\n      allocator->Allocate(&missing_indices_ptr, num_unique * sizeof(uint32_t));\n      LookupAndInitMissing<T, K, U, IDX>(\n          ctx->stream(), kernel_state, seed, num_unique, embedding_size, line_size, false,\n          unique_ids->dptr(), table_ids->dptr(), num_missing_ptr, missing_indices_ptr, values_ptr);\n      allocator->Free(num_missing_ptr);\n      allocator->Free(missing_indices_ptr);\n      if (has_output_embeddings) {\n        void* embeddings_ptr = embedding_state->LookupEmbeddings(current_iter_);\n        user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex(\"embeddings\", 0);\n        CopyValuesToEmbeddings<T>(ctx->stream(), num_unique, embedding_size, line_size,\n                                  unique_values->data_type(), embeddings->data_type(),\n                                  reinterpret_cast<T*>(values_ptr), embeddings_ptr);\n      }\n    }\n    embedding_state->OnEmbeddingLookupEnd(ctx, current_iter_);\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\n#define REGISTER_CUDA_EMBEDDING_LOOKUP_KERNEL(t_dtype_pair, k_dtype_pair, table_dtype_pair,    \\\n                                              idx_dtype_pair)                                  \\\n  REGISTER_USER_KERNEL(\"embedding_lookup\")                                                     \\\n      .SetCreateFn<EmbeddingLookupKernel<                                                      \\\n          OF_PP_PAIR_FIRST(t_dtype_pair), OF_PP_PAIR_FIRST(k_dtype_pair),                      \\\n          OF_PP_PAIR_FIRST(table_dtype_pair), OF_PP_PAIR_FIRST(idx_dtype_pair)>>()             \\\n      .SetIsMatchedHob(                                                                        \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                      \\\n          && 
(user_op::HobDataType(\"unique_values\", 0) == OF_PP_PAIR_SECOND(t_dtype_pair))     \\\n          && (user_op::HobDataType(\"unique_ids\", 0) == OF_PP_PAIR_SECOND(k_dtype_pair))        \\\n          && (user_op::HobDataType(\"table_ids\", 0) == OF_PP_PAIR_SECOND(table_dtype_pair))     \\\n          && (user_op::HobDataType(\"num_unique_ids\", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair))) \\\n      .SetInferTmpSizeFn(GenEmbeddingInferTmpSizeFn<OF_PP_PAIR_FIRST(t_dtype_pair), false>());\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_EMBEDDING_LOOKUP_KERNEL, EMBEDDING_DATA_TYPE_SEQ,\n                                 ID_DATA_TYPE_SEQ, TABLE_ID_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ)\n\ntemplate<typename IDX>\nclass EmbeddingPutKernel final : public user_op::OpKernel {\n public:\n  EmbeddingPutKernel() : current_iter_(0){};\n  ~EmbeddingPutKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<EmbeddingPutKernelState>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<EmbeddingPutKernelState*>(state);\n    CHECK(kernel_state != nullptr);\n    embedding::KeyValueStore* store = kernel_state->KeyValueStore();\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    embedding_state->OnEmbeddingPutStart(ctx, current_iter_);\n    const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex(\"num_unique_ids\", 0);\n    const user_op::Tensor* unique_ids = ctx->Tensor4ArgNameAndIndex(\"unique_ids\", 0);\n    const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex(\"unique_embeddings\", 0);\n    uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_);\n    store->Put(ctx->stream(), num_unique, unique_ids->dptr(),\n               embedding_state->EmbeddingPutUniqueEmbeddings(current_iter_));\n    embedding_state->OnEmbeddingPutEnd(ctx, current_iter_);\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\n#define REGISTER_CUDA_EMBEDDING_PUT_KERNEL(dtype, typeproto)           \\\n  REGISTER_USER_KERNEL(\"embedding_put\")                                \\\n      .SetCreateFn<EmbeddingPutKernel<dtype>>()                        \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"num_unique_ids\", 0) == typeproto));\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_CUDA_EMBEDDING_PUT_KERNEL, IDX_DATA_TYPE_SEQ)\n\ntemplate<typename IDX>\nclass OneEmbeddingFusedSgdUpdatePutKernel final : public user_op::OpKernel {\n public:\n  OneEmbeddingFusedSgdUpdatePutKernel() : current_iter_(0){};\n  ~OneEmbeddingFusedSgdUpdatePutKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<EmbeddingPutKernelState>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<EmbeddingPutKernelState*>(state);\n    CHECK(kernel_state != nullptr);\n    embedding::KeyValueStore* store = 
kernel_state->KeyValueStore();\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    embedding_state->OnEmbeddingFusedUpdatePutStart(ctx, current_iter_);\n    const user_op::Tensor* unique_ids = ctx->Tensor4ArgNameAndIndex(\"unique_ids\", 0);\n    const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex(\"embedding_grad\", 0);\n    const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n    const float* learning_rate_ptr = learning_rate->dptr<float>();\n    const auto scale = ctx->Attr<double>(\"scale\");\n    uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_);\n    store->FusedHalfUpdatePut(\n        ctx->stream(), num_unique, unique_ids->dptr(),\n        embedding_state->EmbeddingFusedUpdatePutUniqueEmbeddings(current_iter_),\n        embedding_grad->dptr(), learning_rate_ptr, scale);\n    embedding_state->OnEmbeddingFusedUpdatePutEnd(ctx, current_iter_);\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\n#define REGISTER_CUDA_ONE_EMBEDDING_FUSED_SGD_UPDATE_PUT_KERNEL(dtype, typeproto)            \\\n  REGISTER_USER_KERNEL(\"one_embedding_fused_sgd_update_put\")                                 \\\n      .SetCreateFn<OneEmbeddingFusedSgdUpdatePutKernel<dtype>>()                             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                       \\\n                       && (user_op::HobDataType(\"num_unique_ids\", 0) == typeproto)           \\\n                       && (user_op::HobDataType(\"unique_embeddings\", 0) == DataType::kFloat) \\\n                       && (user_op::HobDataType(\"embedding_grad\", 0) == DataType::kFloat16));\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_FUSED_SGD_UPDATE_PUT_KERNEL, IDX_DATA_TYPE_SEQ)\n\ntemplate<typename K, typename U, typename IDX>\nclass IdShuffleCopyOutKernel final : public user_op::OpKernel {\n public:\n  IdShuffleCopyOutKernel() : current_iter_(0){};\n  ~IdShuffleCopyOutKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<IdShuffleCopyOutKernelState>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<IdShuffleCopyOutKernelState*>(state);\n    CHECK(kernel_state != nullptr);\n    const int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    const uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_);\n    const std::vector<uint32_t>& num_unique_matrix_vec =\n        embedding_state->GetIdNumUniqueMatrix(current_iter_);\n    uint32_t cur_rank_num_ids = 0;\n    for (int64_t i = 0; i < parallel_num; ++i) {\n      cur_rank_num_ids += num_unique_matrix_vec.at(i * parallel_num + parallel_id);\n    }\n    IdShuffleCopyOutParam<K, U, IDX> param;\n    param.final_num_unique_ids = num_unique;\n    param.cur_rank_unique_ids =\n        reinterpret_cast<const K*>(ctx->Tensor4ArgNameAndIndex(\"cur_rank_unique_ids\", 0)->dptr());\n    param.out_cur_rank_unique_ids =\n        
reinterpret_cast<K*>(ctx->Tensor4ArgNameAndIndex(\"out_cur_rank_unique_ids\", 0)->mut_dptr());\n    param.cur_rank_unique_table_ids = reinterpret_cast<const U*>(\n        ctx->Tensor4ArgNameAndIndex(\"cur_rank_unique_table_ids\", 0)->dptr());\n    param.out_cur_rank_unique_table_ids = reinterpret_cast<U*>(\n        ctx->Tensor4ArgNameAndIndex(\"out_cur_rank_unique_table_ids\", 0)->mut_dptr());\n    param.cur_rank_num_ids = cur_rank_num_ids;\n    param.cur_rank_inverse_indices = reinterpret_cast<const IDX*>(\n        ctx->Tensor4ArgNameAndIndex(\"cur_rank_inverse_indices\", 0)->dptr());\n    param.out_cur_rank_inverse_indices = reinterpret_cast<IDX*>(\n        ctx->Tensor4ArgNameAndIndex(\"out_cur_rank_inverse_indices\", 0)->mut_dptr());\n    param.num_ids =\n        ctx->Tensor4ArgNameAndIndex(\"inverse_unique_partition_indices\", 0)->shape_view().elem_cnt();\n    param.inverse_unique_partition_indices = reinterpret_cast<const IDX*>(\n        ctx->Tensor4ArgNameAndIndex(\"inverse_unique_partition_indices\", 0)->dptr());\n    param.out_inverse_unique_partition_indices = reinterpret_cast<IDX*>(\n        ctx->Tensor4ArgNameAndIndex(\"out_inverse_unique_partition_indices\", 0)->mut_dptr());\n    param.num_unique_matrix_cnt = parallel_num * parallel_num;\n    param.num_unique_matrix =\n        reinterpret_cast<const IDX*>(ctx->Tensor4ArgNameAndIndex(\"num_unique_matrix\", 0)->dptr());\n    param.out_num_unique_matrix =\n        reinterpret_cast<IDX*>(ctx->Tensor4ArgNameAndIndex(\"out_num_unique_matrix\", 0)->mut_dptr());\n    param.cur_rank_num_unique =\n        reinterpret_cast<const IDX*>(ctx->Tensor4ArgNameAndIndex(\"cur_rank_num_unique\", 0)->dptr());\n    param.out_cur_rank_num_unique = reinterpret_cast<IDX*>(\n        ctx->Tensor4ArgNameAndIndex(\"out_cur_rank_num_unique\", 0)->mut_dptr());\n\n    CopyGpu<K, U, IDX><<<BlocksNum4ThreadsNum(param.num_ids), kCudaThreadsNumPerBlock, 0,\n                         ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(param);\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\n#define REGISTER_CUDA_ID_SHUFFLE_COPY_OUT_KERNEL(k_dtype_pair, table_id_dtype_pair,              \\\n                                                 idx_dtype_pair)                                 \\\n  REGISTER_USER_KERNEL(\"id_shuffle_copy_out\")                                                    \\\n      .SetCreateFn<IdShuffleCopyOutKernel<OF_PP_PAIR_FIRST(k_dtype_pair),                        \\\n                                          OF_PP_PAIR_FIRST(table_id_dtype_pair),                 \\\n                                          OF_PP_PAIR_FIRST(idx_dtype_pair)>>()                   \\\n      .SetIsMatchedHob(                                                                          \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                        \\\n          && (user_op::HobDataType(\"cur_rank_unique_ids\", 0) == OF_PP_PAIR_SECOND(k_dtype_pair)) \\\n          && (user_op::HobDataType(\"cur_rank_unique_table_ids\", 0)                               \\\n              == OF_PP_PAIR_SECOND(table_id_dtype_pair))                                         \\\n          && (user_op::HobDataType(\"num_unique_matrix\", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ID_SHUFFLE_COPY_OUT_KERNEL, ID_DATA_TYPE_SEQ,\n                                 TABLE_ID_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ)\n\nenum class 
FusedEmbeddingBufferType {\n  // id shuffle\n  kNumPartitionedUnique = 0,\n  kPartitionedUniqueIds,\n  kReceivedIds,\n  kTableIds,\n  kPartitionedUniqueTableIds,\n  kReceivedTableIds,\n  kWorkspace,\n  kNumUniqueMatrix,\n  kInverseUniquePartitionIndices,\n  kCurRankNumUnique,\n  kCurRankUniqueIds,\n  kCurRankUniqueTableIds,\n  kCurRankInverseIndices,\n  // embedding lookup\n  kNumMissing,\n  kMissingIndices,\n  kCurRankUniqueValues,\n  kCurRankUniqueEmbeddings,\n  // embedding shuffle\n  kReverseUniqueCurRankEmbeddings,\n  kReceivedEmbeddings,\n  kMaxType\n};\n\ntemplate<typename K, typename U, typename IDX>\nclass FusedEmbeddingTmpBufferManager final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(FusedEmbeddingTmpBufferManager);\n  FusedEmbeddingTmpBufferManager(void* ptr, const int64_t num_ids, const int64_t parallel_num,\n                                 bool need_process_table_ids, int64_t line_size,\n                                 int64_t embedding_size, bool need_unique_values,\n                                 bool need_embeddings, DataType value_dtype,\n                                 DataType embedding_dtype)\n      : offset_(0),\n        offsets_(static_cast<size_t>(FusedEmbeddingBufferType::kMaxType), -1),\n        sizes_(static_cast<size_t>(FusedEmbeddingBufferType::kMaxType)),\n        ptr_(ptr) {\n    // id shuffle\n    const int64_t num_table_ids = need_process_table_ids ? num_ids : 0;\n    const size_t table_ids_bytes = need_process_table_ids ? num_ids * sizeof(U) : 0;\n    AllocBuffer(FusedEmbeddingBufferType::kNumPartitionedUnique, parallel_num * sizeof(IDX));\n    size_t partitioned_ids_bytes = parallel_num * num_ids * sizeof(K);\n    AllocBuffer(FusedEmbeddingBufferType::kPartitionedUniqueIds, partitioned_ids_bytes);\n    AllocBuffer(FusedEmbeddingBufferType::kReceivedIds, partitioned_ids_bytes);\n    AllocBuffer(FusedEmbeddingBufferType::kTableIds, table_ids_bytes);\n    size_t partitioned_table_ids_bytes = parallel_num * num_table_ids * sizeof(U);\n    AllocBuffer(FusedEmbeddingBufferType::kPartitionedUniqueTableIds, partitioned_table_ids_bytes);\n    AllocBuffer(FusedEmbeddingBufferType::kReceivedTableIds, partitioned_table_ids_bytes);\n    const size_t hash_table_capacity = parallel_num * num_ids;\n    AllocBuffer(FusedEmbeddingBufferType::kWorkspace,\n                hash_table_capacity * sizeof(data_shuffle::TableEntry<K>));\n    size_t num_unique_matrix_bytes = parallel_num * parallel_num * sizeof(IDX);\n    AllocBuffer(FusedEmbeddingBufferType::kNumUniqueMatrix, num_unique_matrix_bytes);\n    size_t inverse_unique_partition_indices_bytes = num_ids * sizeof(IDX);\n    AllocBuffer(FusedEmbeddingBufferType::kInverseUniquePartitionIndices,\n                inverse_unique_partition_indices_bytes);\n    size_t cur_rank_num_ids = parallel_num * num_ids;\n    size_t cur_rank_num_table_ids = cur_rank_num_ids;\n    size_t cur_rank_num_unique_bytes = sizeof(uint32_t);\n    AllocBuffer(FusedEmbeddingBufferType::kCurRankNumUnique, cur_rank_num_unique_bytes);\n    size_t cur_rank_unique_ids_bytes = cur_rank_num_ids * sizeof(K);\n    AllocBuffer(FusedEmbeddingBufferType::kCurRankUniqueIds, cur_rank_unique_ids_bytes);\n    size_t cur_rank_unique_table_ids_bytes = cur_rank_num_table_ids * sizeof(U);\n    AllocBuffer(FusedEmbeddingBufferType::kCurRankUniqueTableIds, cur_rank_unique_table_ids_bytes);\n    size_t cur_rank_inverse_indices_bytes = cur_rank_num_ids * sizeof(IDX);\n    AllocBuffer(FusedEmbeddingBufferType::kCurRankInverseIndices, cur_rank_inverse_indices_bytes);\n    // 
embedding lookup\n    size_t num_missing_bytes = sizeof(uint32_t);\n    AllocBuffer(FusedEmbeddingBufferType::kNumMissing, num_missing_bytes);\n    size_t missing_indices_bytes = cur_rank_num_ids * sizeof(uint32_t);\n    AllocBuffer(FusedEmbeddingBufferType::kMissingIndices, missing_indices_bytes);\n    if (need_unique_values) {\n      size_t cur_rank_unique_values_bytes =\n          cur_rank_num_ids * line_size * GetSizeOfDataType(value_dtype);\n      AllocBuffer(FusedEmbeddingBufferType::kCurRankUniqueValues, cur_rank_unique_values_bytes);\n    }\n    if (need_embeddings) {\n      size_t cur_rank_unique_embeddings_bytes =\n          cur_rank_num_ids * embedding_size * GetSizeOfDataType(embedding_dtype);\n      AllocBuffer(FusedEmbeddingBufferType::kCurRankUniqueEmbeddings,\n                  cur_rank_unique_embeddings_bytes);\n    }\n    // embedding shuffle\n    size_t reverse_unique_cur_rank_embeddings_bytes =\n        cur_rank_num_ids * embedding_size * GetSizeOfDataType(embedding_dtype);\n    AllocBuffer(FusedEmbeddingBufferType::kReverseUniqueCurRankEmbeddings,\n                reverse_unique_cur_rank_embeddings_bytes);\n    size_t received_embeddings_bytes =\n        cur_rank_num_ids * embedding_size * GetSizeOfDataType(embedding_dtype);\n    AllocBuffer(FusedEmbeddingBufferType::kReceivedEmbeddings, received_embeddings_bytes);\n  }\n\n  template<typename T = void>\n  T* Ptr(FusedEmbeddingBufferType type) const {\n    CHECK(ptr_ != nullptr);\n    int64_t offset = offsets_.at(static_cast<size_t>(type));\n    CHECK_NE(offset, -1);\n    return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr_) + offset);\n  }\n\n  int64_t Size(FusedEmbeddingBufferType type) const { return sizes_.at(static_cast<size_t>(type)); }\n\n  size_t TotalBufferSize() const { return offset_; }\n\n private:\n  void AllocBuffer(FusedEmbeddingBufferType type, size_t size) {\n    const size_t type_id = static_cast<size_t>(type);\n    CHECK_EQ(offsets_.at(type_id), -1);\n    offsets_.at(type_id) = offset_;\n    sizes_.at(type_id) = size;\n    offset_ += GetCudaAlignedSize(size);\n  }\n  size_t offset_;\n  std::vector<int64_t> offsets_;\n  std::vector<int64_t> sizes_;\n  void* ptr_;\n};\n\nvoid MakeConstantInitializerAttr(const int64_t embedding_size, const int64_t line_size,\n                                 const std::vector<float>& values, std::string* initializer_attr) {\n  if (embedding_size == line_size) { return; }\n  const int32_t num_states = line_size / embedding_size - 1;\n  CHECK_GT(num_states, 0) << \"num_states \" << num_states;\n  CHECK(values.size() == 0 || num_states == values.size())\n      << \"must set \" << num_states << \" optimizer state init values, but got \" << values.size();\n  nlohmann::json initializers;\n  for (int32_t i = 0; i < num_states; ++i) {\n    nlohmann::json initializer;\n    initializer[\"type\"] = \"constant\";\n    const float initial_value = values.size() > 0 ? 
values.at(i) : 0.0;\n    initializer[\"value\"] = initial_value;\n    initializers.push_back(initializer);\n  }\n  *initializer_attr = initializers.dump();\n}\n\ntemplate<typename IDX>\nclass OneEmbeddingFusedLookupKernelState final : public user_op::OpKernelState {\n public:\n  explicit OneEmbeddingFusedLookupKernelState(user_op::KernelInitContext* ctx)\n      : device_index_(-1),\n        stream_name_(EagerNcclCommMgr::kDefaultStreamName),\n        parallel_desc_(ctx->parallel_desc()) {\n    OF_CUDA_CHECK(cudaGetDevice(&device_index_));\n    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n    const int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n    OF_CUDA_CHECK(cudaMallocHost(&host_num_keys_, sizeof(IDX)));\n    OF_CUDA_CHECK(\n        cudaMallocHost(&host_num_unique_matrix_, parallel_num * parallel_num * sizeof(IDX)));\n    const std::string& embedding_name = ctx->Attr<std::string>(\"embedding_name\");\n    key_value_store_ = Singleton<embedding::EmbeddingManager>::Get()->GetKeyValueStore(\n        embedding_name, parallel_id);\n    uint32_t max_query_length =\n        ctx->TensorDesc4ArgNameAndIndex(\"ids\", 0)->shape().elem_cnt() * parallel_num;\n    key_value_store_->ReserveQueryLength(max_query_length);\n\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n    // Note(guoran): This op has no optimizer info, so the embedding state initializer is set to\n    // constant 0, which may cause errors for optimizers with an initial_accumulator_value, such\n    // as Adagrad and FTRL.\n    std::string state_initializer;\n    MakeConstantInitializerAttr(embedding_size, line_size, {}, &state_initializer);\n\n    std::vector<EmbeddingInitializer> initializer_param;\n    std::vector<int8_t> initializer_index;\n    ParseInitializers(line_size, embedding_size, state_initializer,\n                      ctx->Attr<std::string>(\"embedding_tables\"), &initializer_param,\n                      &initializer_index);\n\n    const size_t param_size_bytes = initializer_param.size() * sizeof(EmbeddingInitializer);\n    OF_CUDA_CHECK(cudaMallocHost(&host_initializer_param_, param_size_bytes));\n    std::memcpy(host_initializer_param_, initializer_param.data(), param_size_bytes);\n    OF_CUDA_CHECK(cudaMalloc(&device_initializer_param_, param_size_bytes));\n    OF_CUDA_CHECK(cudaMemcpyAsync(device_initializer_param_, host_initializer_param_,\n                                  param_size_bytes, cudaMemcpyDefault,\n                                  ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n\n    const size_t index_size_bytes = initializer_index.size() * sizeof(int8_t);\n    OF_CUDA_CHECK(cudaMallocHost(&host_initializer_index_, index_size_bytes));\n    std::memcpy(host_initializer_index_, initializer_index.data(), index_size_bytes);\n    OF_CUDA_CHECK(cudaMalloc(&device_initializer_index_, index_size_bytes));\n    OF_CUDA_CHECK(cudaMemcpyAsync(device_initializer_index_, host_initializer_index_,\n                                  index_size_bytes, cudaMemcpyDefault,\n                                  ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n  }\n  ~OneEmbeddingFusedLookupKernelState() override {\n    CudaCurrentDeviceGuard guard(device_index_);\n    OF_CUDA_CHECK(cudaFreeHost(host_num_keys_));\n    OF_CUDA_CHECK(cudaFreeHost(host_num_unique_matrix_));\n    OF_CUDA_CHECK(cudaFreeHost(host_initializer_param_));\n    OF_CUDA_CHECK(cudaFree(device_initializer_param_));\n    
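// The initializer index table is released the same way: the pinned host staging copy first,\n    // then its device copy.\n    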
OF_CUDA_CHECK(cudaFreeHost(host_initializer_index_));\n    OF_CUDA_CHECK(cudaFree(device_initializer_index_));\n  }\n\n  ncclComm_t comm() { return GetOrCreate().comm; }\n\n  IDX* HostNumUniqueMatrix() { return host_num_unique_matrix_; }\n\n  IDX* HostNumKeys() { return host_num_keys_; }\n\n  embedding::KeyValueStore* KeyValueStore() { return key_value_store_; }\n\n  const int8_t* InitializerIndex() { return device_initializer_index_; }\n  const EmbeddingInitializer* Initializers() { return device_initializer_param_; }\n\n private:\n  struct Comm {\n    Comm(ncclComm_t comm) : comm(comm) {}\n    ncclComm_t comm;\n  };\n\n  const Comm& GetOrCreate() {\n    if (!comm_) { Init(); }\n    return *comm_;\n  }\n\n  void Init() {\n    std::set<std::pair<int64_t, int64_t>> device_set;\n    for (int64_t parallel_id = 0; parallel_id < parallel_desc_.parallel_num(); ++parallel_id) {\n      int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id));\n      int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id));\n      device_set.emplace(std::make_pair(machine_id, device_id));\n    }\n    EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());\n    ncclComm_t comm;\n    comm =\n        comm_mgr->As<EagerNcclCommMgr>()->GetCommForDeviceAndStreamName(device_set, stream_name_);\n    comm_.reset(new Comm(comm));\n  }\n\n  int device_index_;\n  std::string stream_name_;\n  ParallelDesc parallel_desc_;\n  std::unique_ptr<Comm> comm_;\n  IDX* host_num_keys_;\n  IDX* host_num_unique_matrix_;\n  embedding::KeyValueStore* key_value_store_;\n\n  EmbeddingInitializer* host_initializer_param_;\n  EmbeddingInitializer* device_initializer_param_;\n  int8_t* host_initializer_index_;\n  int8_t* device_initializer_index_;\n};\n\ntemplate<typename T, typename K, typename U, typename IDX>\nvoid LookupAndInitMissing(ep::Stream* stream, OneEmbeddingFusedLookupKernelState<IDX>* kernel_state,\n                          uint64_t seed, uint32_t num_unique, const int64_t embedding_size,\n                          const int64_t line_size, const bool put_to_store, const void* unique_ids,\n                          const void* table_ids, void* num_missing_ptr, void* missing_indices,\n                          void* store_values) {\n  embedding::KeyValueStore* store = kernel_state->KeyValueStore();\n  const EmbeddingInitializer* initializer_param = kernel_state->Initializers();\n  const int8_t* initializer_index = kernel_state->InitializerIndex();\n  void* host_num_keys = kernel_state->HostNumKeys();\n  LookupAndInitMissing<T, K, U, IDX>(stream, seed, store, initializer_param, initializer_index,\n                                     host_num_keys, num_unique, embedding_size, line_size,\n                                     put_to_store, unique_ids, table_ids, num_missing_ptr,\n                                     missing_indices, store_values);\n}\n\ntemplate<typename K, typename U, typename IDX>\nvoid SetIdShuffleDataPtrsParam(const void* ids_ptr,\n                               const FusedEmbeddingTmpBufferManager<K, U, IDX>& buffer_manager,\n                               data_shuffle::IdShuffleDataPtrs<K, U, IDX>* data_ptrs) {\n  data_ptrs->ids_ptr = reinterpret_cast<const K*>(ids_ptr);\n  data_ptrs->table_ids_ptr = buffer_manager.template Ptr<U>(FusedEmbeddingBufferType::kTableIds);\n  data_ptrs->num_partitioned_unique =\n      buffer_manager.template Ptr<IDX>(FusedEmbeddingBufferType::kNumPartitionedUnique);\n  data_ptrs->partitioned_unique_ids =\n      
buffer_manager.template Ptr<K>(FusedEmbeddingBufferType::kPartitionedUniqueIds);\n  data_ptrs->partitioned_unique_table_ids =\n      buffer_manager.template Ptr<U>(FusedEmbeddingBufferType::kPartitionedUniqueTableIds);\n  data_ptrs->workspace_ptr = buffer_manager.Ptr(FusedEmbeddingBufferType::kWorkspace);\n  data_ptrs->workspace_size = buffer_manager.Size(FusedEmbeddingBufferType::kWorkspace);\n  data_ptrs->received_ids = buffer_manager.template Ptr<K>(FusedEmbeddingBufferType::kReceivedIds);\n  data_ptrs->received_table_ids =\n      buffer_manager.template Ptr<U>(FusedEmbeddingBufferType::kReceivedTableIds);\n  data_ptrs->inverse_unique_partition_indices_ptr =\n      buffer_manager.template Ptr<IDX>(FusedEmbeddingBufferType::kInverseUniquePartitionIndices);\n  data_ptrs->num_unique_matrix_ptr =\n      buffer_manager.template Ptr<IDX>(FusedEmbeddingBufferType::kNumUniqueMatrix);\n  data_ptrs->cur_rank_num_unique_ptr =\n      buffer_manager.template Ptr<IDX>(FusedEmbeddingBufferType::kCurRankNumUnique);\n  data_ptrs->cur_rank_unique_ids_ptr =\n      buffer_manager.template Ptr<K>(FusedEmbeddingBufferType::kCurRankUniqueIds);\n  data_ptrs->cur_rank_unique_table_ids_ptr =\n      buffer_manager.template Ptr<U>(FusedEmbeddingBufferType::kCurRankUniqueTableIds);\n  data_ptrs->cur_rank_inverse_indices_ptr =\n      buffer_manager.template Ptr<IDX>(FusedEmbeddingBufferType::kCurRankInverseIndices);\n}\n\ntemplate<typename K, typename T, typename V, typename U, typename IDX>\nclass OneEmbeddingFusedLookupKernel final : public user_op::OpKernel {\n public:\n  OneEmbeddingFusedLookupKernel() = default;\n  ~OneEmbeddingFusedLookupKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<OneEmbeddingFusedLookupKernelState<IDX>>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    // IDX type is uint32_t, table_ids type is uint8_t.\n    DataType num_unique_matrix_dtype = DataType::kUInt32;\n    DataType table_ids_dtype = DataType::kUInt8;\n    CHECK_EQ(sizeof(IDX), GetSizeOfDataType(num_unique_matrix_dtype));\n    CHECK_EQ(sizeof(U), GetSizeOfDataType(table_ids_dtype));\n    auto* kernel_state = dynamic_cast<OneEmbeddingFusedLookupKernelState<IDX>*>(state);\n    CHECK(kernel_state != nullptr);\n    const user_op::Tensor* ids = ctx->Tensor4ArgNameAndIndex(\"ids\", 0);\n    user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex(\"embeddings\", 0);\n    const int32_t num_tables = ctx->Attr<int32_t>(\"num_tables\");\n    // table_ids defaults to uint8_t, so num_tables cannot be greater than 256.\n    CHECK_LE(num_tables, 256) << num_tables;\n    const bool has_table_ids = ctx->has_input(\"table_ids\", 0);\n    const bool need_process_table_ids = (has_table_ids || num_tables > 1);\n    const int64_t num_ids = ids->shape_view().elem_cnt();\n    const int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n    cudaStream_t cuda_stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();\n    DataType value_dtype = ctx->Attr<DataType>(\"dtype\");\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n    const int64_t padding_idx = ctx->Attr<int64_t>(\"padding_idx\");\n   
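 // Padding attrs are forwarded to the id shuffle so padded ids can be handled there.\n   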
 const bool has_padding_idx = ctx->Attr<bool>(\"has_padding_idx\");\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    bool need_unique_values = true;\n    bool need_embeddings =\n        (line_size != embedding_size) || (value_dtype != embeddings->data_type());\n    FusedEmbeddingTmpBufferManager<K, U, IDX> buffer_manager(\n        tmp_buffer->mut_dptr(), num_ids, parallel_num, need_process_table_ids, line_size,\n        embedding_size, need_unique_values, need_embeddings, value_dtype, embeddings->data_type());\n    CHECK_GE(tmp_buffer->shape_view().elem_cnt(), buffer_manager.TotalBufferSize());\n    ncclComm_t comm = kernel_state->comm();\n    IDX* host_num_unique_matrix = kernel_state->HostNumUniqueMatrix();\n    IDX* host_num_keys = kernel_state->HostNumKeys();\n    data_shuffle::IdShuffleDataPtrs<K, U, IDX> data_ptrs;\n    SetIdShuffleDataPtrsParam(ids->dptr(), buffer_manager, &data_ptrs);\n    // overwrite data_ptrs.table_ids_ptr\n    if (need_process_table_ids) {\n      U* tmp_table_ids_ptr = buffer_manager.template Ptr<U>(FusedEmbeddingBufferType::kTableIds);\n      data_ptrs.table_ids_ptr = tmp_table_ids_ptr;\n      if (has_table_ids) {\n        // The default table_ids data_type is uint8; if the table_ids input has a different\n        // data_type, cast it to uint8.\n        const user_op::Tensor* table_ids = ctx->Tensor4ArgNameAndIndex(\"table_ids\", 0);\n        if (table_ids->data_type() != table_ids_dtype) {\n          std::unique_ptr<ep::primitive::Cast> cast_primitive =\n              ep::primitive::NewPrimitive<ep::primitive::CastFactory>(\n                  DeviceType::kCUDA, table_ids->data_type(), table_ids_dtype);\n          cast_primitive->Launch(ctx->stream(), table_ids->dptr(), tmp_table_ids_ptr,\n                                 table_ids->shape_view().elem_cnt());\n        } else {\n          data_ptrs.table_ids_ptr = reinterpret_cast<const U*>(table_ids->dptr());\n        }\n      } else {\n        const int32_t num_tables = ctx->Attr<int32_t>(\"num_tables\");\n        data_shuffle::GenerateTableIds<<<BlocksNum4ThreadsNum(num_ids), kCudaThreadsNumPerBlock, 0,\n                                         cuda_stream>>>(num_ids, num_tables, tmp_table_ids_ptr);\n      }\n    } else {\n      data_ptrs.table_ids_ptr = nullptr;\n    }\n\n    data_shuffle::IdShuffle(ctx->stream(), comm, data_ptrs, num_ids, parallel_id, parallel_num,\n                            num_unique_matrix_dtype, ids->data_type(), table_ids_dtype,\n                            need_process_table_ids, has_padding_idx, padding_idx,\n                            host_num_unique_matrix, host_num_keys);\n    uint32_t num_unique = *host_num_keys;\n\n    // Lookup and put; if is_full_cache, do not put to store.\n    uint32_t* num_missing_ptr =\n        buffer_manager.template Ptr<uint32_t>(FusedEmbeddingBufferType::kNumMissing);\n    uint32_t* missing_indices_ptr =\n        buffer_manager.template Ptr<uint32_t>(FusedEmbeddingBufferType::kMissingIndices);\n    void* values_ptr =\n        buffer_manager.template Ptr<V>(FusedEmbeddingBufferType::kCurRankUniqueValues);\n    T* cur_rank_embeddings_ptr =\n        need_embeddings\n            ? 
buffer_manager.template Ptr<T>(FusedEmbeddingBufferType::kCurRankUniqueEmbeddings)\n            : reinterpret_cast<T*>(values_ptr);\n    const bool is_full_cache = ctx->Attr<bool>(\"is_full_cache\");\n    const bool put_to_store = (!is_full_cache);\n    const int64_t seed = ctx->Attr<int64_t>(\"seed\");\n    LookupAndInitMissing<V, K, U, IDX>(\n        ctx->stream(), kernel_state, seed, num_unique, embedding_size, line_size, put_to_store,\n        data_ptrs.cur_rank_unique_ids_ptr, data_ptrs.cur_rank_unique_table_ids_ptr, num_missing_ptr,\n        missing_indices_ptr, values_ptr);\n    if (need_embeddings) {\n      CopyValuesToEmbeddings<V>(ctx->stream(), num_unique, embedding_size, line_size, value_dtype,\n                                embeddings->data_type(), reinterpret_cast<V*>(values_ptr),\n                                cur_rank_embeddings_ptr);\n    }\n\n    // embedding shuffle\n    int64_t cur_rank_num_ids = 0;\n    for (int64_t i = 0; i < parallel_num; ++i) {\n      cur_rank_num_ids += host_num_unique_matrix[i * parallel_num + parallel_id];\n    }\n    int64_t unique_partitioned_num_ids = 0;\n    for (int64_t i = 0; i < parallel_num; ++i) {\n      unique_partitioned_num_ids += host_num_unique_matrix[parallel_id * parallel_num + i];\n    }\n    T* reverse_unique_cur_rank_embeddings_ptr =\n        buffer_manager.template Ptr<T>(FusedEmbeddingBufferType::kReverseUniqueCurRankEmbeddings);\n    T* received_embeddings_ptr =\n        buffer_manager.template Ptr<T>(FusedEmbeddingBufferType::kReceivedEmbeddings);\n    GatherKernelUtilImpl<DeviceType::kCUDA, T, IDX>::Forward(\n        ctx->stream(), data_ptrs.cur_rank_inverse_indices_ptr, cur_rank_num_ids,\n        cur_rank_embeddings_ptr, Shape({1, num_unique, embedding_size}),\n        reverse_unique_cur_rank_embeddings_ptr, 0);\n\n    data_shuffle::ShuffleEmbeddings(cuda_stream, comm, parallel_id, parallel_num, num_ids,\n                                    embedding_size, embeddings->data_type(), host_num_unique_matrix,\n                                    reverse_unique_cur_rank_embeddings_ptr,\n                                    received_embeddings_ptr);\n    GatherKernelUtilImpl<DeviceType::kCUDA, T, IDX>::Forward(\n        ctx->stream(), data_ptrs.inverse_unique_partition_indices_ptr, num_ids,\n        received_embeddings_ptr, Shape({1, unique_partitioned_num_ids, embedding_size}),\n        embeddings->mut_dptr<T>(), 0);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nauto SingleDeviceKernel() {\n  return hob::make_custom(\"SingleDeviceKernel\", [](const user_op::KernelRegContext& ctx) {\n    return (ctx.parallel_ctx().parallel_num() == 1);\n  });\n}\n\n// Note(guoran): U defaults to uint8_t and IDX to uint32_t. 
Because table_ids is optional, it\n// cannot be used in the hob; if the table_ids input dtype is not uint8_t, it is cast to uint8_t in\n// the kernel.\n#define REGISTER_CUDA_ONE_EMBEDDING_FUSED_LOOKUP_KERNEL(k_dtype_pair, t_dtype_pair, v_dtype_pair) \\\n  REGISTER_USER_KERNEL(\"one_embedding_fused_lookup\")                                              \\\n      .SetCreateFn<OneEmbeddingFusedLookupKernel<                                                 \\\n          OF_PP_PAIR_FIRST(k_dtype_pair), OF_PP_PAIR_FIRST(t_dtype_pair),                         \\\n          OF_PP_PAIR_FIRST(v_dtype_pair), uint8_t, uint32_t>>()                                   \\\n      .SetIsMatchedHob(                                                                           \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                         \\\n          && (user_op::HobDataType(\"ids\", 0) == OF_PP_PAIR_SECOND(k_dtype_pair))                  \\\n          && (user_op::HobDataType(\"embeddings\", 0) == OF_PP_PAIR_SECOND(t_dtype_pair))           \\\n          && (user_op::HobAttr<DataType>(\"dtype\") == OF_PP_PAIR_SECOND(v_dtype_pair))             \\\n          && !SingleDeviceKernel())                                                               \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                         \\\n        const user_op::TensorDesc& ids = ctx->InputTensorDesc(\"ids\", 0);                          \\\n        const user_op::TensorDesc& embeddings = ctx->OutputTensorDesc(\"embeddings\", 0);           \\\n        const bool has_table_ids = ctx->has_input(\"table_ids\", 0);                                \\\n        const int32_t num_tables = ctx->Attr<int32_t>(\"num_tables\");                              \\\n        const bool need_process_table_ids = (has_table_ids || num_tables > 1);                    \\\n        DataType value_dtype = ctx->Attr<DataType>(\"dtype\");                                      \\\n        const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");                      \\\n        const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");                                \\\n        bool need_embeddings =                                                                    \\\n            (line_size != embedding_size) || (value_dtype != embeddings.data_type());             \\\n        FusedEmbeddingTmpBufferManager<OF_PP_PAIR_FIRST(k_dtype_pair), uint8_t, uint32_t>         \\\n            buffer_manager(nullptr, ids.shape().elem_cnt(), ctx->parallel_ctx().parallel_num(),   \\\n                           need_process_table_ids, line_size, embedding_size, true,               \\\n                           need_embeddings, value_dtype, embeddings.data_type());                 \\\n        return buffer_manager.TotalBufferSize();                                                  \\\n      });\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_FUSED_LOOKUP_KERNEL, ID_DATA_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, EMBEDDING_DATA_TYPE_SEQ)\n\ntemplate<typename IDX>\nclass OneEmbeddingFusedLookupLocalKernelState final : public user_op::OpKernelState {\n public:\n  explicit OneEmbeddingFusedLookupLocalKernelState(user_op::KernelInitContext* ctx)\n      : device_index_(-1) {\n    OF_CUDA_CHECK(cudaGetDevice(&device_index_));\n    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n    OF_CUDA_CHECK(cudaMallocHost(&host_num_keys_, sizeof(IDX)));\n    
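// host_num_keys_ is pinned so Compute can read the unique count after an async copy plus a\n    // stream sync.\n    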
const std::string& embedding_name = ctx->Attr<std::string>(\"embedding_name\");\n    key_value_store_ = Singleton<embedding::EmbeddingManager>::Get()->GetKeyValueStore(\n        embedding_name, parallel_id);\n    uint32_t max_query_length = ctx->TensorDesc4ArgNameAndIndex(\"ids\", 0)->shape().elem_cnt();\n    key_value_store_->ReserveQueryLength(max_query_length);\n\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n    // Note(guoran): This op has no optimizer info, so the embedding state initializer is constant\n    // 0, which may cause errors for optimizers with initial_accumulator_value (adagrad, ftrl).\n    std::string state_initializer;\n    MakeConstantInitializerAttr(embedding_size, line_size, {}, &state_initializer);\n\n    std::vector<EmbeddingInitializer> initializer_param;\n    std::vector<int8_t> initializer_index;\n    ParseInitializers(line_size, embedding_size, state_initializer,\n                      ctx->Attr<std::string>(\"embedding_tables\"), &initializer_param,\n                      &initializer_index);\n\n    const size_t param_size_bytes = initializer_param.size() * sizeof(EmbeddingInitializer);\n    OF_CUDA_CHECK(cudaMallocHost(&host_initializer_param_, param_size_bytes));\n    std::memcpy(host_initializer_param_, initializer_param.data(), param_size_bytes);\n    OF_CUDA_CHECK(cudaMalloc(&device_initializer_param_, param_size_bytes));\n    OF_CUDA_CHECK(cudaMemcpyAsync(device_initializer_param_, host_initializer_param_,\n                                  param_size_bytes, cudaMemcpyDefault,\n                                  ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n\n    const size_t index_size_bytes = initializer_index.size() * sizeof(int8_t);\n    OF_CUDA_CHECK(cudaMallocHost(&host_initializer_index_, index_size_bytes));\n    std::memcpy(host_initializer_index_, initializer_index.data(), index_size_bytes);\n    OF_CUDA_CHECK(cudaMalloc(&device_initializer_index_, index_size_bytes));\n    OF_CUDA_CHECK(cudaMemcpyAsync(device_initializer_index_, host_initializer_index_,\n                                  index_size_bytes, cudaMemcpyDefault,\n                                  ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n  }\n  ~OneEmbeddingFusedLookupLocalKernelState() override {\n    CudaCurrentDeviceGuard guard(device_index_);\n    OF_CUDA_CHECK(cudaFreeHost(host_num_keys_));\n    OF_CUDA_CHECK(cudaFreeHost(host_initializer_param_));\n    OF_CUDA_CHECK(cudaFree(device_initializer_param_));\n    OF_CUDA_CHECK(cudaFreeHost(host_initializer_index_));\n    OF_CUDA_CHECK(cudaFree(device_initializer_index_));\n  }\n\n  IDX* HostNumKeys() { return host_num_keys_; }\n\n  embedding::KeyValueStore* KeyValueStore() { return key_value_store_; }\n\n  const int8_t* InitializerIndex() { return device_initializer_index_; }\n  const EmbeddingInitializer* Initializers() { return device_initializer_param_; }\n\n private:\n  int device_index_;\n  IDX* host_num_keys_;\n  embedding::KeyValueStore* key_value_store_;\n\n  EmbeddingInitializer* host_initializer_param_;\n  EmbeddingInitializer* device_initializer_param_;\n  int8_t* host_initializer_index_;\n  int8_t* device_initializer_index_;\n};\n\ntemplate<typename T, typename K, typename U, typename IDX>\nvoid LookupAndInitMissing(ep::Stream* stream,\n                          OneEmbeddingFusedLookupLocalKernelState<IDX>* kernel_state, uint64_t seed,\n                          uint32_t num_unique, const int64_t 
embedding_size,\n                          const int64_t line_size, const bool put_to_store, const void* unique_ids,\n                          const void* table_ids, void* num_missing_ptr, void* missing_indices,\n                          void* store_values) {\n  embedding::KeyValueStore* store = kernel_state->KeyValueStore();\n  const EmbeddingInitializer* initializer_param = kernel_state->Initializers();\n  const int8_t* initializer_index = kernel_state->InitializerIndex();\n  void* host_num_keys = kernel_state->HostNumKeys();\n  LookupAndInitMissing<T, K, U, IDX>(stream, seed, store, initializer_param, initializer_index,\n                                     host_num_keys, num_unique, embedding_size, line_size,\n                                     put_to_store, unique_ids, table_ids, num_missing_ptr,\n                                     missing_indices, store_values);\n}\n\ntemplate<typename K, typename U, typename IDX>\nclass FusedLocalEmbeddingTmpBufferManager final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(FusedLocalEmbeddingTmpBufferManager);\n  FusedLocalEmbeddingTmpBufferManager(void* ptr, const int64_t num_ids, bool need_process_table_ids,\n                                      int64_t line_size, int64_t embedding_size,\n                                      bool need_embeddings, DataType value_dtype,\n                                      DataType embedding_dtype)\n      : offset_(0),\n        offsets_(static_cast<size_t>(FusedEmbeddingBufferType::kMaxType), -1),\n        sizes_(static_cast<size_t>(FusedEmbeddingBufferType::kMaxType)),\n        ptr_(ptr) {\n    // id shuffle\n    const size_t table_ids_bytes = need_process_table_ids ? num_ids * sizeof(U) : 0;\n    AllocBuffer(FusedEmbeddingBufferType::kTableIds, table_ids_bytes);\n    const size_t hash_table_capacity = num_ids;\n    AllocBuffer(FusedEmbeddingBufferType::kWorkspace,\n                hash_table_capacity * sizeof(data_shuffle::TableEntry<K>));\n    size_t cur_rank_num_ids = num_ids;\n    size_t cur_rank_num_table_ids = cur_rank_num_ids;\n    size_t cur_rank_num_unique_bytes = sizeof(uint32_t);\n    AllocBuffer(FusedEmbeddingBufferType::kCurRankNumUnique, cur_rank_num_unique_bytes);\n    size_t cur_rank_unique_ids_bytes = cur_rank_num_ids * sizeof(K);\n    AllocBuffer(FusedEmbeddingBufferType::kCurRankUniqueIds, cur_rank_unique_ids_bytes);\n    size_t cur_rank_unique_table_ids_bytes = cur_rank_num_table_ids * sizeof(U);\n    AllocBuffer(FusedEmbeddingBufferType::kCurRankUniqueTableIds, cur_rank_unique_table_ids_bytes);\n    size_t cur_rank_inverse_indices_bytes = cur_rank_num_ids * sizeof(IDX);\n    AllocBuffer(FusedEmbeddingBufferType::kCurRankInverseIndices, cur_rank_inverse_indices_bytes);\n    // embedding lookup\n    size_t num_missing_bytes = sizeof(uint32_t);\n    AllocBuffer(FusedEmbeddingBufferType::kNumMissing, num_missing_bytes);\n    size_t missing_indices_bytes = cur_rank_num_ids * sizeof(uint32_t);\n    AllocBuffer(FusedEmbeddingBufferType::kMissingIndices, missing_indices_bytes);\n    size_t cur_rank_unique_values_bytes =\n        cur_rank_num_ids * line_size * GetSizeOfDataType(value_dtype);\n    AllocBuffer(FusedEmbeddingBufferType::kCurRankUniqueValues, cur_rank_unique_values_bytes);\n    if (need_embeddings) {\n      size_t cur_rank_unique_embeddings_bytes =\n          cur_rank_num_ids * embedding_size * GetSizeOfDataType(embedding_dtype);\n      AllocBuffer(FusedEmbeddingBufferType::kCurRankUniqueEmbeddings,\n                  cur_rank_unique_embeddings_bytes);\n    }\n  }\n\n  
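// Returns a typed pointer into the single tmp_buffer at the offset recorded by AllocBuffer;\n  // e.g. (hypothetical numbers) num_ids = 4 and sizeof(K) = 8 make kCurRankUniqueIds 32 bytes,\n  // and AllocBuffer advances the running offset by GetCudaAlignedSize(32).\n  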
template<typename T = void>\n  T* Ptr(FusedEmbeddingBufferType type) const {\n    CHECK(ptr_ != nullptr);\n    int64_t offset = offsets_.at(static_cast<size_t>(type));\n    CHECK_NE(offset, -1);\n    return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr_) + offset);\n  }\n\n  int64_t Size(FusedEmbeddingBufferType type) const { return sizes_.at(static_cast<size_t>(type)); }\n\n  size_t TotalBufferSize() const { return offset_; }\n\n private:\n  void AllocBuffer(FusedEmbeddingBufferType type, size_t size) {\n    const size_t type_id = static_cast<size_t>(type);\n    CHECK_EQ(offsets_.at(type_id), -1);\n    offsets_.at(type_id) = offset_;\n    sizes_.at(type_id) = size;\n    offset_ += GetCudaAlignedSize(size);\n  }\n  size_t offset_;\n  std::vector<int64_t> offsets_;\n  std::vector<int64_t> sizes_;\n  void* ptr_;\n};\n\ntemplate<typename K, typename T, typename V, typename U, typename IDX>\nclass OneEmbeddingFusedLookupLocalKernel final : public user_op::OpKernel {\n public:\n  OneEmbeddingFusedLookupLocalKernel() = default;\n  ~OneEmbeddingFusedLookupLocalKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<OneEmbeddingFusedLookupLocalKernelState<IDX>>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    // IDX type is uint32_t, table_ids type is uint8_t.\n    DataType num_unique_matrix_dtype = DataType::kUInt32;\n    DataType table_ids_dtype = DataType::kUInt8;\n    CHECK_EQ(sizeof(IDX), GetSizeOfDataType(num_unique_matrix_dtype));\n    CHECK_EQ(sizeof(U), GetSizeOfDataType(table_ids_dtype));\n    auto* kernel_state = dynamic_cast<OneEmbeddingFusedLookupLocalKernelState<IDX>*>(state);\n    CHECK(kernel_state != nullptr);\n    const user_op::Tensor* ids = ctx->Tensor4ArgNameAndIndex(\"ids\", 0);\n    user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex(\"embeddings\", 0);\n    const int32_t num_tables = ctx->Attr<int32_t>(\"num_tables\");\n    // table_ids defaults to uint8_t, so num_tables cannot be greater than 256.\n    CHECK_LE(num_tables, 256) << num_tables;\n    const bool has_table_ids = ctx->has_input(\"table_ids\", 0);\n    const bool need_process_table_ids = (has_table_ids || num_tables > 1);\n    const int64_t num_ids = ids->shape_view().elem_cnt();\n    cudaStream_t cuda_stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();\n    DataType value_dtype = ctx->Attr<DataType>(\"dtype\");\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n    const int64_t padding_idx = ctx->Attr<int64_t>(\"padding_idx\");\n    const bool has_padding_idx = ctx->Attr<bool>(\"has_padding_idx\");\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    bool need_embeddings =\n        (line_size != embedding_size) || (value_dtype != embeddings->data_type());\n    FusedLocalEmbeddingTmpBufferManager<K, U, IDX> buffer_manager(\n        tmp_buffer->mut_dptr(), num_ids, need_process_table_ids, line_size, embedding_size,\n        need_embeddings, value_dtype, embeddings->data_type());\n    CHECK_GE(tmp_buffer->shape_view().elem_cnt(), buffer_manager.TotalBufferSize());\n    IDX* host_num_keys = kernel_state->HostNumKeys();\n\n    const U* table_ids_ptr = nullptr;\n    if 
(need_process_table_ids) {\n      U* tmp_table_ids_ptr = buffer_manager.template Ptr<U>(FusedEmbeddingBufferType::kTableIds);\n      table_ids_ptr = tmp_table_ids_ptr;\n      if (has_table_ids) {\n        // The default table_ids data_type is uint8; if the table_ids input has a different\n        // data_type, cast it to uint8.\n        const user_op::Tensor* table_ids = ctx->Tensor4ArgNameAndIndex(\"table_ids\", 0);\n        if (table_ids->data_type() != table_ids_dtype) {\n          std::unique_ptr<ep::primitive::Cast> cast_primitive =\n              ep::primitive::NewPrimitive<ep::primitive::CastFactory>(\n                  DeviceType::kCUDA, table_ids->data_type(), table_ids_dtype);\n          cast_primitive->Launch(ctx->stream(), table_ids->dptr(), tmp_table_ids_ptr,\n                                 table_ids->shape_view().elem_cnt());\n        } else {\n          table_ids_ptr = reinterpret_cast<const U*>(table_ids->dptr());\n        }\n      } else {\n        const int32_t num_tables = ctx->Attr<int32_t>(\"num_tables\");\n        data_shuffle::GenerateTableIds<<<BlocksNum4ThreadsNum(num_ids), kCudaThreadsNumPerBlock, 0,\n                                         cuda_stream>>>(num_ids, num_tables, tmp_table_ids_ptr);\n      }\n    }\n    IDX* num_unique_ptr =\n        buffer_manager.template Ptr<IDX>(FusedEmbeddingBufferType::kCurRankNumUnique);\n    K* unique_ids_ptr = buffer_manager.template Ptr<K>(FusedEmbeddingBufferType::kCurRankUniqueIds);\n    U* unique_table_ids_ptr =\n        buffer_manager.template Ptr<U>(FusedEmbeddingBufferType::kCurRankUniqueTableIds);\n    IDX* inverse_indices_ptr =\n        buffer_manager.template Ptr<IDX>(FusedEmbeddingBufferType::kCurRankInverseIndices);\n    void* workspace_ptr = buffer_manager.Ptr(FusedEmbeddingBufferType::kWorkspace);\n    const size_t workspace_bytes = buffer_manager.Size(FusedEmbeddingBufferType::kWorkspace);\n    int64_t hash_capacity = num_ids;\n    data_shuffle::UniqueAndPartition<K, U, IDX, embedding::GlobalUniqueHash>(\n        cuda_stream, num_ids, hash_capacity, 1, reinterpret_cast<const K*>(ids->dptr()),\n        table_ids_ptr, num_unique_ptr, unique_ids_ptr, unique_table_ids_ptr, inverse_indices_ptr,\n        reinterpret_cast<data_shuffle::TableEntry<K>*>(workspace_ptr), workspace_bytes,\n        need_process_table_ids, has_padding_idx, padding_idx);\n\n    OF_CUDA_CHECK(cudaMemcpyAsync(host_num_keys, num_unique_ptr, sizeof(IDX), cudaMemcpyDefault,\n                                  cuda_stream));\n    CHECK_JUST(ctx->stream()->Sync());\n\n    uint32_t num_unique = *host_num_keys;\n\n    // Lookup and put; if is_full_cache, do not put to store.\n    uint32_t* num_missing_ptr =\n        buffer_manager.template Ptr<uint32_t>(FusedEmbeddingBufferType::kNumMissing);\n    uint32_t* missing_indices_ptr =\n        buffer_manager.template Ptr<uint32_t>(FusedEmbeddingBufferType::kMissingIndices);\n    void* values_ptr =\n        buffer_manager.template Ptr<V>(FusedEmbeddingBufferType::kCurRankUniqueValues);\n    T* cur_rank_embeddings_ptr =\n        need_embeddings\n            ? 
buffer_manager.template Ptr<T>(FusedEmbeddingBufferType::kCurRankUniqueEmbeddings)\n            : reinterpret_cast<T*>(values_ptr);\n    const bool is_full_cache = ctx->Attr<bool>(\"is_full_cache\");\n    const bool put_to_store = (!is_full_cache);\n    const int64_t seed = ctx->Attr<int64_t>(\"seed\");\n    LookupAndInitMissing<V, K, U, IDX>(\n        ctx->stream(), kernel_state, seed, num_unique, embedding_size, line_size, put_to_store,\n        unique_ids_ptr, unique_table_ids_ptr, num_missing_ptr, missing_indices_ptr, values_ptr);\n    if (need_embeddings) {\n      CopyValuesToEmbeddings<V>(ctx->stream(), num_unique, embedding_size, line_size, value_dtype,\n                                embeddings->data_type(), reinterpret_cast<V*>(values_ptr),\n                                cur_rank_embeddings_ptr);\n    }\n    // gather\n    GatherKernelUtilImpl<DeviceType::kCUDA, T, IDX>::Forward(\n        ctx->stream(), inverse_indices_ptr, num_ids, cur_rank_embeddings_ptr,\n        Shape({1, num_unique, embedding_size}), embeddings->mut_dptr<T>(), 0);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n// Note(guoran): U defaults to uint8_t and IDX to uint32_t. Because table_ids is optional, it\n// cannot be used in the hob; if the table_ids input dtype is not uint8_t, it is cast to uint8_t in\n// the kernel.\n#define REGISTER_CUDA_ONE_EMBEDDING_FUSED_LOOKUP_LOCAL_KERNEL(k_dtype_pair, t_dtype_pair,         \\\n                                                              v_dtype_pair)                       \\\n  REGISTER_USER_KERNEL(\"one_embedding_fused_lookup\")                                              \\\n      .SetCreateFn<OneEmbeddingFusedLookupLocalKernel<                                            \\\n          OF_PP_PAIR_FIRST(k_dtype_pair), OF_PP_PAIR_FIRST(t_dtype_pair),                         \\\n          OF_PP_PAIR_FIRST(v_dtype_pair), uint8_t, uint32_t>>()                                   \\\n      .SetIsMatchedHob(                                                                           \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                         \\\n          && (user_op::HobDataType(\"ids\", 0) == OF_PP_PAIR_SECOND(k_dtype_pair))                  \\\n          && (user_op::HobDataType(\"embeddings\", 0) == OF_PP_PAIR_SECOND(t_dtype_pair))           \\\n          && (user_op::HobAttr<DataType>(\"dtype\") == OF_PP_PAIR_SECOND(v_dtype_pair))             \\\n          && SingleDeviceKernel())                                                                \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                         \\\n        const user_op::TensorDesc& ids = ctx->InputTensorDesc(\"ids\", 0);                          \\\n        const user_op::TensorDesc& embeddings = ctx->OutputTensorDesc(\"embeddings\", 0);           \\\n        const bool has_table_ids = ctx->has_input(\"table_ids\", 0);                                \\\n        const int32_t num_tables = ctx->Attr<int32_t>(\"num_tables\");                              \\\n        const bool need_process_table_ids = (has_table_ids || num_tables > 1);                    \\\n        DataType value_dtype = ctx->Attr<DataType>(\"dtype\");                                      \\\n        const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");                      \\\n        const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");                                \\\n        bool need_embeddings =       
                                                             \\\n            (line_size != embedding_size) || (value_dtype != embeddings.data_type());             \\\n        FusedLocalEmbeddingTmpBufferManager<OF_PP_PAIR_FIRST(k_dtype_pair), uint8_t, uint32_t>    \\\n            buffer_manager(nullptr, ids.shape().elem_cnt(), need_process_table_ids, line_size,    \\\n                           embedding_size, need_embeddings, value_dtype, embeddings.data_type()); \\\n        return buffer_manager.TotalBufferSize();                                                  \\\n      });\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_FUSED_LOOKUP_LOCAL_KERNEL,\n                                 ID_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ,\n                                 EMBEDDING_DATA_TYPE_SEQ)\n\nclass OneEmbeddingFusedLookupGradKernel final : public user_op::OpKernel {\n public:\n  OneEmbeddingFusedLookupGradKernel() = default;\n  ~OneEmbeddingFusedLookupGradKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    // do nothing\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"one_embedding_fused_lookup_grad\")\n    .SetCreateFn<OneEmbeddingFusedLookupGradKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA));\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/one_embedding_update_kernels.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/user/kernels/model_update_kernel_util.h\"\n#include \"oneflow/core/embedding/embedding_manager.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename G, typename IDX>\n__global__ void SGDUpdateKernel(const int64_t embedding_size, T scale, float l1, float l2,\n                                float weight_decay, float learning_rate_val,\n                                const IDX* num_unique_ids, const float* learning_rate,\n                                const T* scale_by_ptr, const T* down_scale_by_ptr,\n                                const int64_t* skip_if, const G* model_diff, const T* model,\n                                T* updated_model) {\n  if (skip_if != nullptr && *skip_if != 0) {\n    const int64_t n = *num_unique_ids * embedding_size;\n    CUDA_1D_KERNEL_LOOP(i, n) { updated_model[i] = model[i]; }\n  } else {\n    if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n    if (down_scale_by_ptr != nullptr) { scale /= *down_scale_by_ptr; }\n    if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n    const int64_t n = *num_unique_ids * embedding_size;\n    CUDA_1D_KERNEL_LOOP(i, n) {\n      updated_model[i] = model[i];\n      SGDUpdateFunctor<T, G>()(model_diff + i, updated_model + i, scale, l1, l2, weight_decay,\n                               learning_rate_val);\n    }\n  }\n}\n\n__device__ void GetMomentumOffset(const int32_t line_size, const int32_t embedding_size,\n                                  int64_t model_diff_offset, int64_t* model_offset,\n                                  int64_t* momentum_offset) {\n  const int32_t row = model_diff_offset / embedding_size;\n  const int32_t col = model_diff_offset - row * embedding_size;\n  *model_offset = row * line_size + col;\n  *momentum_offset = *model_offset + embedding_size;\n}\n\ntemplate<typename T, typename G, typename IDX>\n__global__ void MomentumUpdateKernel(const int64_t line_size, const int64_t embedding_size, T scale,\n                                     float l1, float l2, float weight_decay, float beta,\n                                     float dampening, bool nesterov, bool maximize,\n                                     float learning_rate_val, const IDX* num_unique_ids,\n                                     const float* learning_rate, const T* scale_by_ptr,\n                                     const T* down_scale_by_ptr, const int64_t* skip_if,\n                                     const G* model_diff, const T* unique_values,\n                                     T* updated_unique_values) {\n  if (skip_if != nullptr && *skip_if != 0) {\n    const int64_t n = *num_unique_ids * line_size;\n    CUDA_1D_KERNEL_LOOP(i, n) { updated_unique_values[i] = unique_values[i]; }\n  } else {\n    if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n    
if (down_scale_by_ptr != nullptr) { scale /= *down_scale_by_ptr; }\n    if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n    const int64_t n = *num_unique_ids * embedding_size;\n    CUDA_1D_KERNEL_LOOP(i, n) {\n      int64_t model_offset;\n      int64_t momentum_offset;\n      GetMomentumOffset(line_size, embedding_size, i, &model_offset, &momentum_offset);\n      updated_unique_values[model_offset] = unique_values[model_offset];\n      updated_unique_values[momentum_offset] = unique_values[momentum_offset];\n      MomentumUpdateFunctor<T, G>()(model_diff + i, updated_unique_values + model_offset,\n                                    updated_unique_values + momentum_offset, scale, l1, l2, beta,\n                                    dampening, nesterov, maximize, weight_decay, learning_rate_val);\n    }\n  }\n}\n\n__device__ void GetAdamOffset(const int32_t line_size, const int32_t embedding_size,\n                              int64_t model_diff_offset, int64_t* model_offset, int64_t* m_offset,\n                              int64_t* v_offset) {\n  const int32_t row = model_diff_offset / embedding_size;\n  const int32_t col = model_diff_offset - row * embedding_size;\n  *model_offset = row * line_size + col;\n  *m_offset = *model_offset + embedding_size;\n  *v_offset = *model_offset + 2 * embedding_size;\n}\n\ntemplate<typename T, typename G, typename IDX>\n__global__ void AdamUpdateKernel(const int32_t line_size, const int32_t embedding_size, T scale,\n                                 float l1, float l2, float weight_decay, float beta1, float beta2,\n                                 float epsilon, float learning_rate_val, float bias_correction1_val,\n                                 float bias_correction2_val, const float* bias_correction1_ptr,\n                                 const float* bias_correction2_ptr, const IDX* num_unique_ids,\n                                 const float* learning_rate, const T* scale_by_ptr,\n                                 const T* down_scale_by_ptr, const int64_t* skip_if,\n                                 const G* model_diff, const T* unique_values,\n                                 T* updated_unique_values) {\n  if (skip_if != nullptr && *skip_if != 0) {\n    const int64_t n = *num_unique_ids * line_size;\n    CUDA_1D_KERNEL_LOOP(i, n) {\n      // Here n is the unique_values elem_cnt, so there is no need to use GetAdamOffset.\n      updated_unique_values[i] = unique_values[i];\n    }\n  } else {\n    if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n    if (down_scale_by_ptr != nullptr) { scale /= *down_scale_by_ptr; }\n    if (bias_correction1_ptr != nullptr) { bias_correction1_val = *bias_correction1_ptr; }\n    if (bias_correction2_ptr != nullptr) { bias_correction2_val = *bias_correction2_ptr; }\n    if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n    const int64_t n = *num_unique_ids * embedding_size;\n    // Here n is the model_diff elem_cnt.\n    CUDA_1D_KERNEL_LOOP(i, n) {\n      int64_t model_offset;\n      int64_t m_offset;\n      int64_t v_offset;\n      GetAdamOffset(line_size, embedding_size, i, &model_offset, &m_offset, &v_offset);\n      updated_unique_values[model_offset] = unique_values[model_offset];\n      updated_unique_values[m_offset] = unique_values[m_offset];\n      updated_unique_values[v_offset] = unique_values[v_offset];\n      AdamUpdateFunctor<T, G>()(model_diff + i, updated_unique_values + model_offset,\n                                updated_unique_values + m_offset, 
updated_unique_values + v_offset,\n                                nullptr, scale, l1, l2, beta1, beta2, epsilon, weight_decay, false,\n                                bias_correction1_val, bias_correction2_val, learning_rate_val);\n    }\n  }\n}\n\n// Note(guoran): The SmartDecaySparseAdam is from\n// https://github.com/pytorch/pytorch/blob/master/caffe2/sgd/adam_op.h#L57\ntemplate<typename T, typename G, typename IDX>\n__global__ void SmartDecaySparseAdamUpdateKernel(\n    const int32_t line_size, const int32_t embedding_size, T scale, float l1, float l2,\n    float weight_decay, float beta1, float beta2, float epsilon, float learning_rate_val,\n    int64_t step_col_offset, const IDX* num_unique_ids, const float* learning_rate,\n    const int64_t* train_step_ptr, const T* scale_by_ptr, const T* down_scale_by_ptr,\n    const int64_t* skip_if, const G* model_diff, const T* unique_values, T* updated_unique_values) {\n  if (skip_if != nullptr && *skip_if != 0) {\n    const int64_t n = *num_unique_ids * line_size;\n    CUDA_1D_KERNEL_LOOP(i, n) {\n      // Here n is the unique_values elem_cnt, so there is no need to use GetAdamOffset.\n      updated_unique_values[i] = unique_values[i];\n    }\n  } else {\n    if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n    if (down_scale_by_ptr != nullptr) { scale /= *down_scale_by_ptr; }\n    if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n    const int64_t n = *num_unique_ids * embedding_size;\n    // Here n is the model_diff elem_cnt.\n    CUDA_1D_KERNEL_LOOP(i, n) {\n      const int32_t row = i / embedding_size;\n      const int32_t col = i - row * embedding_size;\n      int64_t model_offset = row * line_size + col;\n      int64_t m_offset = model_offset + embedding_size;\n      int64_t v_offset = model_offset + 2 * embedding_size;\n      int64_t step_offset = row * line_size + step_col_offset;\n      const T model_val = *(unique_values + model_offset);\n      const T m_val = *(unique_values + m_offset);\n      const T v_val = *(unique_values + v_offset);\n      T model_diff_t =\n          CastScaleRegularizeGradientFunctor<T, G>()(*(model_diff + i), model_val, scale, l1, l2);\n      int64_t prev_step = *reinterpret_cast<const int64_t*>(unique_values + step_offset);\n      int64_t cur_step = *train_step_ptr + 1;\n      int64_t skip_step = cur_step - prev_step;\n      float catchup = 0.0;\n      if (skip_step > 1) {\n        catchup = m_val * beta1 * (1 - pow(beta1, skip_step - 1)) / (1 - beta1);\n      }\n      const T next_m = pow(beta1, skip_step) * m_val + (1 - beta1) * model_diff_t;\n      const T next_v = pow(beta2, skip_step) * v_val + (1 - beta2) * model_diff_t * model_diff_t;\n      updated_unique_values[m_offset] = next_m;\n      updated_unique_values[v_offset] = next_v;\n      updated_unique_values[model_offset] =\n          model_val - (learning_rate_val * (next_m + catchup)) / (sqrt(next_v) + epsilon);\n      if (col == 0) { *reinterpret_cast<int64_t*>(updated_unique_values + step_offset) = cur_step; }\n    }\n  }\n}\n\ntemplate<typename T, typename G, typename IDX>\n__global__ void AdagradUpdateKernel(const int64_t line_size, const int64_t embedding_size, T scale,\n                                    float l1, float l2, float weight_decay, float lr_decay,\n                                    float epsilon, float learning_rate_val, int64_t train_step,\n                                    const IDX* num_unique_ids, const float* learning_rate,\n                                    const int64_t* train_step_ptr, const T* 
scale_by_ptr,\n                                    const T* down_scale_by_ptr, const int64_t* skip_if,\n                                    const G* model_diff, const T* unique_values,\n                                    T* updated_unique_values) {\n  if (skip_if != nullptr && *skip_if != 0) {\n    const int64_t n = *num_unique_ids * line_size;\n    CUDA_1D_KERNEL_LOOP(i, n) { updated_unique_values[i] = unique_values[i]; }\n  } else {\n    if (train_step_ptr != nullptr) { train_step = *train_step_ptr + 1; }\n    if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n    if (down_scale_by_ptr != nullptr) { scale /= *down_scale_by_ptr; }\n    if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n    learning_rate_val = learning_rate_val / (1 + (train_step - 1) * lr_decay);\n    const int64_t n = *num_unique_ids * embedding_size;\n    CUDA_1D_KERNEL_LOOP(i, n) {\n      int64_t model_offset;\n      int64_t sum_offset;\n      GetMomentumOffset(line_size, embedding_size, i, &model_offset, &sum_offset);\n      updated_unique_values[model_offset] = unique_values[model_offset];\n      updated_unique_values[sum_offset] = unique_values[sum_offset];\n      AdagradUpdateFunctor<T, G>()(model_diff + i, updated_unique_values + model_offset,\n                                   updated_unique_values + sum_offset, scale, l1, l2, epsilon,\n                                   weight_decay, learning_rate_val);\n    }\n  }\n}\n\n__device__ void GetFtrlOffset(const int32_t line_size, const int32_t embedding_size,\n                              int64_t model_diff_offset, int64_t* model_offset,\n                              int64_t* accumulate_offset, int64_t* z_offset) {\n  const int32_t row = model_diff_offset / embedding_size;\n  const int32_t col = model_diff_offset - row * embedding_size;\n  *model_offset = row * line_size + col;\n  *accumulate_offset = *model_offset + embedding_size;\n  *z_offset = *model_offset + 2 * embedding_size;\n}\n\ntemplate<typename T, typename G, typename IDX>\n__global__ void FtrlUpdateKernel(const int32_t line_size, const int32_t embedding_size, T scale,\n                                 float l1, float l2, float weight_decay, float lr_power,\n                                 float lambda1, float lambda2, float beta, float learning_rate_val,\n                                 const IDX* num_unique_ids, const float* learning_rate,\n                                 const T* scale_by_ptr, const T* down_scale_by_ptr,\n                                 const int64_t* skip_if, const G* model_diff,\n                                 const T* unique_values, T* updated_unique_values) {\n  if (skip_if != nullptr && *skip_if != 0) {\n    const int64_t n = *num_unique_ids * line_size;\n    CUDA_1D_KERNEL_LOOP(i, n) { updated_unique_values[i] = unique_values[i]; }\n  } else {\n    if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }\n    if (down_scale_by_ptr != nullptr) { scale /= *down_scale_by_ptr; }\n    if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }\n    const int64_t n = *num_unique_ids * embedding_size;\n    CUDA_1D_KERNEL_LOOP(i, n) {\n      int64_t model_offset;\n      int64_t accumulate_offset;\n      int64_t z_offset;\n      GetFtrlOffset(line_size, embedding_size, i, &model_offset, &accumulate_offset, &z_offset);\n      updated_unique_values[model_offset] = unique_values[model_offset];\n      updated_unique_values[accumulate_offset] = unique_values[accumulate_offset];\n      updated_unique_values[z_offset] = unique_values[z_offset];\n    
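  // Each FTRL row packs [model | accumulator | z]; the current state is copied first so the\n      // functor can update it in place in the output buffer.\n    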
  FtrlUpdateFunctor<T, G>()(model_diff + i, updated_unique_values + model_offset,\n                                updated_unique_values + accumulate_offset,\n                                updated_unique_values + z_offset, scale, l1, l2, lr_power, lambda1,\n                                lambda2, beta, weight_decay, learning_rate_val);\n    }\n  }\n}\n\nclass EmbeddingUpdateKernelState final : public user_op::OpKernelState {\n public:\n  explicit EmbeddingUpdateKernelState(user_op::KernelInitContext* ctx) {\n    const std::string& embedding_name = ctx->Attr<std::string>(\"embedding_name\");\n    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n    embedding_state_ = Singleton<embedding::EmbeddingManager>::Get()->GetEmbeddingState(\n        embedding_name, parallel_id);\n  }\n  ~EmbeddingUpdateKernelState() override = default;\n\n  embedding::EmbeddingState* EmbeddingState() { return embedding_state_; }\n\n private:\n  embedding::EmbeddingState* embedding_state_;\n};\n\n}  // namespace\n\ntemplate<typename T, typename G, typename IDX>\nclass SgdEmbeddingUpdateKernel final : public user_op::OpKernel {\n public:\n  SgdEmbeddingUpdateKernel() : current_iter_(0){};\n  ~SgdEmbeddingUpdateKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<EmbeddingUpdateKernelState>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<EmbeddingUpdateKernelState*>(state);\n    CHECK(kernel_state != nullptr);\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    embedding_state->OnEmbeddingUpdateStart(ctx, current_iter_);\n    const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex(\"num_unique_ids\", 0);\n    const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex(\"embedding_grad\", 0);\n    CHECK_EQ(embedding_grad->shape_view().NumAxes(), 2);\n    const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    CHECK_EQ(line_size, embedding_size);\n    const auto scale = ctx->Attr<double>(\"scale\");\n    const float l1 = ctx->Attr<float>(\"l1\");\n    const float l2 = ctx->Attr<float>(\"l2\");\n    const auto weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float* learning_rate_ptr = nullptr;\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const T* down_scale_by_ptr = nullptr;\n    if (ctx->has_input(\"down_scale_by_tensor\", 0)) {\n      const user_op::Tensor* down_scale_by_tensor =\n          ctx->Tensor4ArgNameAndIndex(\"down_scale_by_tensor\", 0);\n      CHECK_EQ(down_scale_by_tensor->shape_view().elem_cnt(), 1);\n      down_scale_by_ptr = down_scale_by_tensor->dptr<T>();\n    }\n    const 
int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n    // update kernel\n    const T* unique_embeddings_ptr =\n        reinterpret_cast<const T*>(embedding_state->EmbeddingUpdateUniqueEmbeddings(current_iter_));\n    T* updated_unique_embeddings_ptr = reinterpret_cast<T*>(\n        embedding_state->EmbeddingUpdateUpdatedUniqueEmbeddings(current_iter_));\n    const uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_);\n    const int64_t embedding_grad_elem_cnt = num_unique * embedding_size;\n    SGDUpdateKernel<T, G, IDX>\n        <<<BlocksNum4ThreadsNum(embedding_grad_elem_cnt), kCudaThreadsNumPerBlock, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n            embedding_size, scale, l1, l2, weight_decay, learning_rate_val,\n            reinterpret_cast<const IDX*>(num_unique_ids->dptr()), learning_rate_ptr, scale_by_ptr,\n            down_scale_by_ptr, skip_if_ptr, embedding_grad->dptr<G>(), unique_embeddings_ptr,\n            updated_unique_embeddings_ptr);\n    embedding_state->OnEmbeddingUpdateEnd(ctx, current_iter_);\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\n#define IDX_DATA_TYPE_SEQ                           \\\n  OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)\n\n#define REGISTER_CUDA_ONE_EMBEDDING_SGD_UPDATE_KERNEL(t_dtype_pair, g_type_pair, idx_dtype_pair)  \\\n  REGISTER_USER_KERNEL(\"one_embedding_sgd_update\")                                                \\\n      .SetCreateFn<                                                                               \\\n          SgdEmbeddingUpdateKernel<OF_PP_PAIR_FIRST(t_dtype_pair), OF_PP_PAIR_FIRST(g_type_pair), \\\n                                   OF_PP_PAIR_FIRST(idx_dtype_pair)>>()                           \\\n      .SetIsMatchedHob(                                                                           \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                         \\\n          && (user_op::HobDataType(\"num_unique_ids\", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair))     \\\n          && (user_op::HobDataType(\"embedding_grad\", 0) == OF_PP_PAIR_SECOND(g_type_pair))        \\\n          && (user_op::HobDataType(\"unique_embeddings\", 0) == OF_PP_PAIR_SECOND(t_dtype_pair)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_SGD_UPDATE_KERNEL,\n                                 FLOATING_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ,\n                                 IDX_DATA_TYPE_SEQ)\n\ntemplate<typename T, typename G, typename IDX>\nclass MomentumEmbeddingUpdateKernel final : public user_op::OpKernel {\n public:\n  MomentumEmbeddingUpdateKernel() : current_iter_(0){};\n  ~MomentumEmbeddingUpdateKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<EmbeddingUpdateKernelState>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = 
dynamic_cast<EmbeddingUpdateKernelState*>(state);\n    CHECK(kernel_state != nullptr);\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    embedding_state->OnEmbeddingUpdateStart(ctx, current_iter_);\n    const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex(\"num_unique_ids\", 0);\n    const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex(\"embedding_grad\", 0);\n    CHECK_EQ(embedding_grad->shape_view().NumAxes(), 2);\n    const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    CHECK_EQ(line_size, embedding_size * 2);\n    const float l1 = ctx->Attr<float>(\"l1\");\n    const float l2 = ctx->Attr<float>(\"l2\");\n    const auto weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const auto beta = ctx->Attr<float>(\"beta\");\n    // TODO: Support dampening, nesterov, maximize in OneEmbeddingMomentumUpdate(zhengzekang).\n    const float dampening = 0.0;\n    const bool nesterov = false;\n    const bool maximize = false;\n    const auto scale = ctx->Attr<double>(\"scale\");\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const T* down_scale_by_ptr = nullptr;\n    if (ctx->has_input(\"down_scale_by_tensor\", 0)) {\n      const user_op::Tensor* down_scale_by_tensor =\n          ctx->Tensor4ArgNameAndIndex(\"down_scale_by_tensor\", 0);\n      CHECK_EQ(down_scale_by_tensor->shape_view().elem_cnt(), 1);\n      down_scale_by_ptr = down_scale_by_tensor->dptr<T>();\n    }\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float* learning_rate_ptr = nullptr;\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n    // update kernel\n    const T* unique_embeddings_ptr =\n        reinterpret_cast<const T*>(embedding_state->EmbeddingUpdateUniqueEmbeddings(current_iter_));\n    T* updated_unique_embeddings_ptr = reinterpret_cast<T*>(\n        embedding_state->EmbeddingUpdateUpdatedUniqueEmbeddings(current_iter_));\n    const uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_);\n    const int64_t embedding_grad_elem_cnt = num_unique * embedding_size;\n    MomentumUpdateKernel<T, G, IDX>\n        <<<BlocksNum4ThreadsNum(embedding_grad_elem_cnt), kCudaThreadsNumPerBlock, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n            line_size, embedding_size, scale, l1, l2, weight_decay, beta, dampening, nesterov,\n            maximize, learning_rate_val, reinterpret_cast<const IDX*>(num_unique_ids->dptr()),\n            learning_rate_ptr, scale_by_ptr, down_scale_by_ptr, skip_if_ptr,\n            embedding_grad->dptr<G>(), unique_embeddings_ptr, updated_unique_embeddings_ptr);\n    embedding_state->OnEmbeddingUpdateEnd(ctx, current_iter_);\n    current_iter_++;\n  }\n  bool 
AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\n#define REGISTER_CUDA_ONE_EMBEDDING_MOMENTUM_UPDATE_KERNEL(t_dtype_pair, g_type_pair,         \\\n                                                           idx_dtype_pair)                    \\\n  REGISTER_USER_KERNEL(\"one_embedding_momentum_update\")                                       \\\n      .SetCreateFn<MomentumEmbeddingUpdateKernel<OF_PP_PAIR_FIRST(t_dtype_pair),              \\\n                                                 OF_PP_PAIR_FIRST(g_type_pair),               \\\n                                                 OF_PP_PAIR_FIRST(idx_dtype_pair)>>()         \\\n      .SetIsMatchedHob(                                                                       \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                     \\\n          && (user_op::HobDataType(\"num_unique_ids\", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair)) \\\n          && (user_op::HobDataType(\"embedding_grad\", 0) == OF_PP_PAIR_SECOND(g_type_pair))    \\\n          && (user_op::HobDataType(\"unique_embeddings\", 0) == OF_PP_PAIR_SECOND(t_dtype_pair)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_MOMENTUM_UPDATE_KERNEL,\n                                 FLOATING_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ,\n                                 IDX_DATA_TYPE_SEQ)\n\ntemplate<typename T, typename G, typename IDX>\nclass AdamEmbeddingUpdateKernel final : public user_op::OpKernel {\n public:\n  AdamEmbeddingUpdateKernel() : current_iter_(0){};\n  ~AdamEmbeddingUpdateKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<EmbeddingUpdateKernelState>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<EmbeddingUpdateKernelState*>(state);\n    CHECK(kernel_state != nullptr);\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    embedding_state->OnEmbeddingUpdateStart(ctx, current_iter_);\n    const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex(\"num_unique_ids\", 0);\n    const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex(\"unique_embeddings\", 0);\n    const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex(\"embedding_grad\", 0);\n    user_op::Tensor* updated_unique_embeddings =\n        ctx->Tensor4ArgNameAndIndex(\"updated_unique_embeddings\", 0);\n    CHECK_EQ(embedding_grad->shape_view().NumAxes(), 2);\n    const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    CHECK_EQ(line_size, embedding_size * 3);\n\n    const float l1 = ctx->Attr<float>(\"l1\");\n    const float l2 = ctx->Attr<float>(\"l2\");\n    const auto weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const auto beta1 = ctx->Attr<float>(\"beta1\");\n    const auto beta2 = ctx->Attr<float>(\"beta2\");\n    const auto epsilon = ctx->Attr<float>(\"epsilon\");\n    const bool do_bias_correction = ctx->Attr<bool>(\"do_bias_correction\");\n    const auto scale = ctx->Attr<double>(\"scale\");\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const 
user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const T* down_scale_by_ptr = nullptr;\n    if (ctx->has_input(\"down_scale_by_tensor\", 0)) {\n      const user_op::Tensor* down_scale_by_tensor =\n          ctx->Tensor4ArgNameAndIndex(\"down_scale_by_tensor\", 0);\n      CHECK_EQ(down_scale_by_tensor->shape_view().elem_cnt(), 1);\n      down_scale_by_ptr = down_scale_by_tensor->dptr<T>();\n    }\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float* learning_rate_ptr = nullptr;\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n    const float bias_correction1_val = ctx->Attr<float>(\"bias_correction1_val\");\n    const float* bias_correction1_ptr = nullptr;\n    if (ctx->has_input(\"bias_correction1\", 0)) {\n      bias_correction1_ptr = ctx->Tensor4ArgNameAndIndex(\"bias_correction1\", 0)->dptr<float>();\n    }\n    const float bias_correction2_val = ctx->Attr<float>(\"bias_correction2_val\");\n    const float* bias_correction2_ptr = nullptr;\n    if (ctx->has_input(\"bias_correction2\", 0)) {\n      bias_correction2_ptr = ctx->Tensor4ArgNameAndIndex(\"bias_correction2\", 0)->dptr<float>();\n    }\n    // update kernel\n    const T* unique_embeddings_ptr =\n        reinterpret_cast<const T*>(embedding_state->EmbeddingUpdateUniqueEmbeddings(current_iter_));\n    T* updated_unique_embeddings_ptr = reinterpret_cast<T*>(\n        embedding_state->EmbeddingUpdateUpdatedUniqueEmbeddings(current_iter_));\n    const uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_);\n    const int64_t embedding_grad_elem_cnt = num_unique * embedding_size;\n    AdamUpdateKernel<T, G, IDX>\n        <<<BlocksNum4ThreadsNum(embedding_grad_elem_cnt), kCudaThreadsNumPerBlock, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n            line_size, embedding_size, static_cast<T>(scale), l1, l2, weight_decay, beta1, beta2,\n            epsilon, learning_rate_val, bias_correction1_val, bias_correction2_val,\n            bias_correction1_ptr, bias_correction2_ptr,\n            reinterpret_cast<const IDX*>(num_unique_ids->dptr()), learning_rate_ptr, scale_by_ptr,\n            down_scale_by_ptr, skip_if_ptr, embedding_grad->dptr<G>(), unique_embeddings_ptr,\n            updated_unique_embeddings_ptr);\n    embedding_state->OnEmbeddingUpdateEnd(ctx, current_iter_);\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\n#define REGISTER_CUDA_ONE_EMBEDDING_ADAM_UPDATE_KERNEL(t_dtype_pair, g_type_pair, idx_dtype_pair)  \\\n  REGISTER_USER_KERNEL(\"one_embedding_adam_update\")                                                \\\n      .SetCreateFn<                                                                                \\\n          AdamEmbeddingUpdateKernel<OF_PP_PAIR_FIRST(t_dtype_pair), OF_PP_PAIR_FIRST(g_type_pair), \\\n                                    
OF_PP_PAIR_FIRST(idx_dtype_pair)>>()                           \\\n      .SetIsMatchedHob(                                                                            \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                          \\\n          && (user_op::HobDataType(\"num_unique_ids\", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair))      \\\n          && (user_op::HobDataType(\"embedding_grad\", 0) == OF_PP_PAIR_SECOND(g_type_pair))         \\\n          && (user_op::HobDataType(\"unique_embeddings\", 0) == OF_PP_PAIR_SECOND(t_dtype_pair)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_ADAM_UPDATE_KERNEL,\n                                 FLOATING_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ,\n                                 IDX_DATA_TYPE_SEQ)\n\ntemplate<typename T, typename G, typename IDX>\nclass SmartDecaySparseAdamEmbeddingUpdateKernel final : public user_op::OpKernel {\n public:\n  SmartDecaySparseAdamEmbeddingUpdateKernel() : current_iter_(0){};\n  ~SmartDecaySparseAdamEmbeddingUpdateKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<EmbeddingUpdateKernelState>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<EmbeddingUpdateKernelState*>(state);\n    CHECK(kernel_state != nullptr);\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    embedding_state->OnEmbeddingUpdateStart(ctx, current_iter_);\n    const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex(\"num_unique_ids\", 0);\n    const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex(\"embedding_grad\", 0);\n    user_op::Tensor* updated_unique_embeddings =\n        ctx->Tensor4ArgNameAndIndex(\"updated_unique_embeddings\", 0);\n    CHECK_EQ(embedding_grad->shape_view().NumAxes(), 2);\n    const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    const float l1 = ctx->Attr<float>(\"l1\");\n    const float l2 = ctx->Attr<float>(\"l2\");\n    const auto weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const auto beta1 = ctx->Attr<float>(\"beta1\");\n    const auto beta2 = ctx->Attr<float>(\"beta2\");\n    const auto epsilon = ctx->Attr<float>(\"epsilon\");\n    const bool do_bias_correction = ctx->Attr<bool>(\"do_bias_correction\");\n    const auto scale = ctx->Attr<double>(\"scale\");\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const T* down_scale_by_ptr = nullptr;\n    if (ctx->has_input(\"down_scale_by_tensor\", 0)) {\n      const user_op::Tensor* down_scale_by_tensor =\n          ctx->Tensor4ArgNameAndIndex(\"down_scale_by_tensor\", 0);\n      CHECK_EQ(down_scale_by_tensor->shape_view().elem_cnt(), 1);\n      down_scale_by_ptr = down_scale_by_tensor->dptr<T>();\n    }\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float* learning_rate_ptr = nullptr;\n    if 
(ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n    const int64_t train_step_val = ctx->Attr<int64_t>(\"train_step_val\");\n    const int64_t* train_step_ptr = nullptr;\n    if (ctx->has_input(\"train_step\", 0)) {\n      const user_op::Tensor* train_step = ctx->Tensor4ArgNameAndIndex(\"train_step\", 0);\n      train_step_ptr = train_step->dptr<int64_t>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n    // update kernel\n    const T* unique_embeddings_ptr =\n        reinterpret_cast<const T*>(embedding_state->EmbeddingUpdateUniqueEmbeddings(current_iter_));\n    T* updated_unique_embeddings_ptr = reinterpret_cast<T*>(\n        embedding_state->EmbeddingUpdateUpdatedUniqueEmbeddings(current_iter_));\n    const uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_);\n    const int64_t embedding_grad_elem_cnt = num_unique * embedding_size;\n\n    const int64_t value_dtype_size = GetSizeOfDataType(updated_unique_embeddings->data_type());\n    const int64_t step_dtype_size = sizeof(int64_t);\n    const int64_t model_and_states_bytes = embedding_size * 3 * value_dtype_size;\n    const int64_t align_to_step_size_bytes =\n        (model_and_states_bytes + step_dtype_size - 1) / step_dtype_size * step_dtype_size;\n    const int64_t step_col_offset = align_to_step_size_bytes / value_dtype_size;\n    const int64_t smart_decay_sparse_adam_line_size =\n        (align_to_step_size_bytes + step_dtype_size) / value_dtype_size;\n    CHECK_EQ(line_size, smart_decay_sparse_adam_line_size);\n\n    SmartDecaySparseAdamUpdateKernel<T, G, IDX>\n        <<<BlocksNum4ThreadsNum(embedding_grad_elem_cnt), kCudaThreadsNumPerBlock, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n            line_size, embedding_size, static_cast<T>(scale), l1, l2, weight_decay, beta1, beta2,\n            epsilon, learning_rate_val, step_col_offset,\n            reinterpret_cast<const IDX*>(num_unique_ids->dptr()), learning_rate_ptr, train_step_ptr,\n            scale_by_ptr, down_scale_by_ptr, skip_if_ptr, embedding_grad->dptr<G>(),\n            unique_embeddings_ptr, updated_unique_embeddings_ptr);\n    embedding_state->OnEmbeddingUpdateEnd(ctx, current_iter_);\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\n#define REGISTER_CUDA_ONE_EMBEDDING_SMART_DECAY_SPARSE_ADAM_UPDATE_KERNEL(                        \\\n    t_dtype_pair, g_type_pair, idx_dtype_pair)                                                    \\\n  REGISTER_USER_KERNEL(\"one_embedding_smart_decay_sparse_adam_update\")                            \\\n      .SetCreateFn<SmartDecaySparseAdamEmbeddingUpdateKernel<OF_PP_PAIR_FIRST(t_dtype_pair),      \\\n                                                             OF_PP_PAIR_FIRST(g_type_pair),       \\\n                                                             OF_PP_PAIR_FIRST(idx_dtype_pair)>>() \\\n      .SetIsMatchedHob(                                                                           \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                         \\\n   
       && (user_op::HobDataType(\"num_unique_ids\", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair))     \\\n          && (user_op::HobDataType(\"embedding_grad\", 0) == OF_PP_PAIR_SECOND(g_type_pair))        \\\n          && (user_op::HobDataType(\"unique_embeddings\", 0) == OF_PP_PAIR_SECOND(t_dtype_pair)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_SMART_DECAY_SPARSE_ADAM_UPDATE_KERNEL,\n                                 FLOATING_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ,\n                                 IDX_DATA_TYPE_SEQ)\n\ntemplate<typename T, typename G, typename IDX>\nclass AdagradEmbeddingUpdateKernel final : public user_op::OpKernel {\n public:\n  AdagradEmbeddingUpdateKernel() : current_iter_(0){};\n  ~AdagradEmbeddingUpdateKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<EmbeddingUpdateKernelState>(ctx);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<EmbeddingUpdateKernelState*>(state);\n    CHECK(kernel_state != nullptr);\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    embedding_state->OnEmbeddingUpdateStart(ctx, current_iter_);\n    const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex(\"num_unique_ids\", 0);\n    const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex(\"unique_embeddings\", 0);\n    const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex(\"embedding_grad\", 0);\n    user_op::Tensor* updated_unique_embeddings =\n        ctx->Tensor4ArgNameAndIndex(\"updated_unique_embeddings\", 0);\n    CHECK_EQ(embedding_grad->shape_view().NumAxes(), 2);\n    const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    CHECK_EQ(line_size, embedding_size * 2);\n    const float l1 = ctx->Attr<float>(\"l1\");\n    const float l2 = ctx->Attr<float>(\"l2\");\n    const auto weight_decay = ctx->Attr<float>(\"weight_decay\");\n    const auto lr_decay = ctx->Attr<float>(\"lr_decay\");\n    const auto epsilon = ctx->Attr<float>(\"epsilon\");\n    const auto scale = ctx->Attr<double>(\"scale\");\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const T* down_scale_by_ptr = nullptr;\n    if (ctx->has_input(\"down_scale_by_tensor\", 0)) {\n      const user_op::Tensor* down_scale_by_tensor =\n          ctx->Tensor4ArgNameAndIndex(\"down_scale_by_tensor\", 0);\n      CHECK_EQ(down_scale_by_tensor->shape_view().elem_cnt(), 1);\n      down_scale_by_ptr = down_scale_by_tensor->dptr<T>();\n    }\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float* learning_rate_ptr = nullptr;\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n    const int64_t train_step_val = ctx->Attr<int64_t>(\"train_step_val\");\n    const int64_t* 
train_step_ptr = nullptr;\n    if (ctx->has_input(\"train_step\", 0)) {\n      const user_op::Tensor* train_step = ctx->Tensor4ArgNameAndIndex(\"train_step\", 0);\n      train_step_ptr = train_step->dptr<int64_t>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n    // update kernel\n    const T* unique_embeddings_ptr =\n        reinterpret_cast<const T*>(embedding_state->EmbeddingUpdateUniqueEmbeddings(current_iter_));\n    T* updated_unique_embeddings_ptr = reinterpret_cast<T*>(\n        embedding_state->EmbeddingUpdateUpdatedUniqueEmbeddings(current_iter_));\n    const uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_);\n    const int64_t embedding_grad_elem_cnt = num_unique * embedding_size;\n    AdagradUpdateKernel<T, G, IDX>\n        <<<BlocksNum4ThreadsNum(embedding_grad_elem_cnt), kCudaThreadsNumPerBlock, 0,\n           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n            line_size, embedding_size, static_cast<T>(scale), l1, l2, weight_decay, lr_decay,\n            epsilon, learning_rate_val, train_step_val,\n            reinterpret_cast<const IDX*>(num_unique_ids->dptr()), learning_rate_ptr, train_step_ptr,\n            scale_by_ptr, down_scale_by_ptr, skip_if_ptr, embedding_grad->dptr<G>(),\n            unique_embeddings_ptr, updated_unique_embeddings_ptr);\n    embedding_state->OnEmbeddingUpdateEnd(ctx, current_iter_);\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n\n#define REGISTER_CUDA_ONE_EMBEDDING_ADAGRAD_UPDATE_KERNEL(t_dtype_pair, g_type_pair,          \\\n                                                          idx_dtype_pair)                     \\\n  REGISTER_USER_KERNEL(\"one_embedding_adagrad_update\")                                        \\\n      .SetCreateFn<AdagradEmbeddingUpdateKernel<OF_PP_PAIR_FIRST(t_dtype_pair),               \\\n                                                OF_PP_PAIR_FIRST(g_type_pair),                \\\n                                                OF_PP_PAIR_FIRST(idx_dtype_pair)>>()          \\\n      .SetIsMatchedHob(                                                                       \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                     \\\n          && (user_op::HobDataType(\"num_unique_ids\", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair)) \\\n          && (user_op::HobDataType(\"embedding_grad\", 0) == OF_PP_PAIR_SECOND(g_type_pair))    \\\n          && (user_op::HobDataType(\"unique_embeddings\", 0) == OF_PP_PAIR_SECOND(t_dtype_pair)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_ADAGRAD_UPDATE_KERNEL,\n                                 FLOATING_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ,\n                                 IDX_DATA_TYPE_SEQ)\n\ntemplate<typename T, typename G, typename IDX>\nclass FtrlEmbeddingUpdateKernel final : public user_op::OpKernel {\n public:\n  FtrlEmbeddingUpdateKernel() : current_iter_(0){};\n  ~FtrlEmbeddingUpdateKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<EmbeddingUpdateKernelState>(ctx);\n  }\n\n private:\n  using 
user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* kernel_state = dynamic_cast<EmbeddingUpdateKernelState*>(state);\n    CHECK(kernel_state != nullptr);\n    embedding::EmbeddingState* embedding_state = kernel_state->EmbeddingState();\n    embedding_state->OnEmbeddingUpdateStart(ctx, current_iter_);\n    const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex(\"num_unique_ids\", 0);\n    const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex(\"embedding_grad\", 0);\n    CHECK_EQ(embedding_grad->shape_view().NumAxes(), 2)\n        << \"The NumAxes of embedding_grad should be equal to 2. \";\n    const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n    const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n    CHECK_EQ(line_size, embedding_size * 3)\n        << \"The line_size should be equal to 3 x embedding_size. \";\n    const float l1 = 0.0;\n    const float l2 = 0.0;\n    const float weight_decay = ctx->Attr<float>(\"weight_decay\");\n    // TODO(zhengzekang): Undefined behavior for ftrl optimizer with weight_decay in `abs(new_z_val)\n    // < lambda1` condition.\n    CHECK_EQ(weight_decay, static_cast<float>(0.0))\n        << \"Setting weight_decay is not supported yet; it must be 0. \";\n    const float lr_power = ctx->Attr<float>(\"lr_power\");\n    const float lambda1 = ctx->Attr<float>(\"lambda1\");\n    const float lambda2 = ctx->Attr<float>(\"lambda2\");\n    const float beta = ctx->Attr<float>(\"beta\");\n    const double scale = ctx->Attr<double>(\"scale\");\n    const T* scale_by_ptr = nullptr;\n    if (ctx->has_input(\"scale_by_tensor\", 0)) {\n      const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex(\"scale_by_tensor\", 0);\n      CHECK_EQ(scale_by_tensor->shape_view().elem_cnt(), 1);\n      scale_by_ptr = scale_by_tensor->dptr<T>();\n    }\n    const T* down_scale_by_ptr = nullptr;\n    if (ctx->has_input(\"down_scale_by_tensor\", 0)) {\n      const user_op::Tensor* down_scale_by_tensor =\n          ctx->Tensor4ArgNameAndIndex(\"down_scale_by_tensor\", 0);\n      CHECK_EQ(down_scale_by_tensor->shape_view().elem_cnt(), 1);\n      down_scale_by_ptr = down_scale_by_tensor->dptr<T>();\n    }\n    const float learning_rate_val = ctx->Attr<float>(\"learning_rate_val\");\n    const float* learning_rate_ptr = nullptr;\n    if (ctx->has_input(\"learning_rate\", 0)) {\n      const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex(\"learning_rate\", 0);\n      learning_rate_ptr = learning_rate->dptr<float>();\n    }\n    const int64_t* skip_if_ptr = nullptr;\n    if (ctx->has_input(\"skip_if\", 0)) {\n      const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex(\"skip_if\", 0);\n      CHECK_EQ(skip_if->shape_view().elem_cnt(), 1);\n      skip_if_ptr = skip_if->dptr<int64_t>();\n    }\n    // update kernel\n    const T* unique_embeddings_ptr =\n        reinterpret_cast<const T*>(embedding_state->EmbeddingUpdateUniqueEmbeddings(current_iter_));\n    T* updated_unique_embeddings_ptr = reinterpret_cast<T*>(\n        embedding_state->EmbeddingUpdateUpdatedUniqueEmbeddings(current_iter_));\n    const uint32_t num_unique = embedding_state->GetIdNumUnique(current_iter_);\n    const int64_t embedding_grad_elem_cnt = num_unique * embedding_size;\n    FtrlUpdateKernel<T, G, IDX>\n        <<<BlocksNum4ThreadsNum(embedding_grad_elem_cnt), kCudaThreadsNumPerBlock, 0,\n           
ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n            line_size, embedding_size, static_cast<T>(scale), l1, l2, weight_decay, lr_power,\n            lambda1, lambda2, beta, learning_rate_val,\n            reinterpret_cast<const IDX*>(num_unique_ids->dptr()), learning_rate_ptr, scale_by_ptr,\n            down_scale_by_ptr, skip_if_ptr, embedding_grad->dptr<G>(), unique_embeddings_ptr,\n            updated_unique_embeddings_ptr);\n    embedding_state->OnEmbeddingUpdateEnd(ctx, current_iter_);\n    current_iter_++;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  mutable int64_t current_iter_;\n};\n#define REGISTER_CUDA_ONE_EMBEDDING_FTRL_UPDATE_KERNEL(t_dtype_pair, g_type_pair, idx_dtype_pair)  \\\n  REGISTER_USER_KERNEL(\"one_embedding_ftrl_update\")                                                \\\n      .SetCreateFn<                                                                                \\\n          FtrlEmbeddingUpdateKernel<OF_PP_PAIR_FIRST(t_dtype_pair), OF_PP_PAIR_FIRST(g_type_pair), \\\n                                    OF_PP_PAIR_FIRST(idx_dtype_pair)>>()                           \\\n      .SetIsMatchedHob(                                                                            \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                          \\\n          && (user_op::HobDataType(\"num_unique_ids\", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair))      \\\n          && (user_op::HobDataType(\"embedding_grad\", 0) == OF_PP_PAIR_SECOND(g_type_pair))         \\\n          && (user_op::HobDataType(\"unique_embeddings\", 0) == OF_PP_PAIR_SECOND(t_dtype_pair)));\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ONE_EMBEDDING_FTRL_UPDATE_KERNEL,\n                                 FLOATING_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ,\n                                 IDX_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
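Every optimizer in the file above shares one storage convention: a table row (a "line" of `line_size` values) holds the embedding vector followed by its optimizer state, which is why each kernel `CHECK_EQ`s `line_size` against a multiple of `embedding_size` (2x for momentum, 3x for Adam and FTRL). The smart-decay sparse Adam variant also appends an `int64_t` step counter, padded so the counter starts on an 8-byte boundary. A standalone sketch of that layout arithmetic, using hypothetical names rather than OneFlow API:

```cpp
#include <cstdint>
#include <cstdio>

struct SmartDecayLayout {
  int64_t step_col_offset;  // column (in value elements) where the step counter starts
  int64_t line_size;        // total columns per line, in value elements
};

// Mirrors the arithmetic the smart-decay kernel CHECK_EQs against the "line_size" attr.
SmartDecayLayout ComputeSmartDecayLayout(int64_t embedding_size, int64_t value_dtype_size) {
  const int64_t step_dtype_size = sizeof(int64_t);  // the appended step counter
  const int64_t model_and_states_bytes = embedding_size * 3 * value_dtype_size;  // embedding, m, v
  // Round up so the int64_t step counter lands on an 8-byte boundary.
  const int64_t aligned_bytes =
      (model_and_states_bytes + step_dtype_size - 1) / step_dtype_size * step_dtype_size;
  return {aligned_bytes / value_dtype_size, (aligned_bytes + step_dtype_size) / value_dtype_size};
}

int main() {
  // float values, embedding_size = 128: 128 * 3 * 4 = 1536 bytes (already 8-byte
  // aligned), so step_col_offset = 384 and line_size = (1536 + 8) / 4 = 386.
  const SmartDecayLayout layout = ComputeSmartDecayLayout(128, sizeof(float));
  std::printf("step_col_offset=%lld line_size=%lld\n",
              static_cast<long long>(layout.step_col_offset),
              static_cast<long long>(layout.line_size));
  return 0;
}
```

In that example the 8-byte step counter occupies two float columns (384 and 385), which is why `line_size` comes out two columns larger than `3 * embedding_size`.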
  {
    "path": "oneflow/user/kernels/one_hot_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename K>\nclass CpuOneHotKernel final : public user_op::OpKernel {\n public:\n  CpuOneHotKernel() = default;\n  ~CpuOneHotKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t num_indices = indices->shape_view().elem_cnt();\n    const int64_t depth = ctx->Attr<int64_t>(\"depth\");\n    const DataType dtype = ctx->Attr<DataType>(\"dtype\");\n    const T on_value = IsFloatingDataType(dtype)\n                           ? static_cast<T>(ctx->Attr<double>(\"floating_on_value\"))\n                           : static_cast<T>(ctx->Attr<int64_t>(\"integer_on_value\"));\n    const T off_value = IsFloatingDataType(dtype)\n                            ? static_cast<T>(ctx->Attr<double>(\"floating_off_value\"))\n                            : static_cast<T>(ctx->Attr<int64_t>(\"integer_off_value\"));\n    const K* indices_dptr = indices->dptr<K>();\n    T* out_dptr = out->mut_dptr<T>();\n    std::unique_ptr<ep::primitive::Fill> fill =\n        ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->stream()->device_type(),\n                                                                out->data_type());\n    CHECK(fill);\n    fill->Launch(ctx->stream(), out->mut_dptr(), off_value, out->shape_view().elem_cnt());\n    FOR_RANGE(int64_t, i, 0, num_indices) {\n      const int64_t idx = indices_dptr[i];\n      CHECK_GE(idx, 0);\n      CHECK_LT(idx, depth);\n      out_dptr[i * depth + idx] = on_value;\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_ONE_HOT_KERNEL(dtype, itype)                                               \\\n  REGISTER_USER_KERNEL(\"one_hot\").SetCreateFn<CpuOneHotKernel<dtype, itype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCPU)                                            \\\n      && (user_op::HobDataType(\"indices\", 0) == GetDataType<itype>::value)                      \\\n      && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CPU_ONE_HOT_KERNEL(int32_t, int32_t)\nREGISTER_CPU_ONE_HOT_KERNEL(int32_t, int64_t)\nREGISTER_CPU_ONE_HOT_KERNEL(int64_t, int32_t)\nREGISTER_CPU_ONE_HOT_KERNEL(int64_t, int64_t)\nREGISTER_CPU_ONE_HOT_KERNEL(float, int32_t)\nREGISTER_CPU_ONE_HOT_KERNEL(float, int64_t)\nREGISTER_CPU_ONE_HOT_KERNEL(double, int32_t)\nREGISTER_CPU_ONE_HOT_KERNEL(double, int64_t)\n\n}  // namespace oneflow\n"
  },
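The CPU kernel above is a two-pass algorithm: a Fill primitive first writes `off_value` over the whole output, then a scatter loop writes `on_value` at `out[i * depth + idx]`. A framework-free sketch of the same logic (the `OneHot` helper is hypothetical, not OneFlow API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

template<typename T, typename K>
std::vector<T> OneHot(const std::vector<K>& indices, int64_t depth, T on_value, T off_value) {
  // Pass 1: what the Fill primitive does -- every element starts as off_value.
  std::vector<T> out(indices.size() * depth, off_value);
  // Pass 2: scatter on_value, one element per input index.
  for (size_t i = 0; i < indices.size(); ++i) {
    const int64_t idx = static_cast<int64_t>(indices[i]);
    assert(idx >= 0 && idx < depth);  // mirrors the kernel's CHECK_GE / CHECK_LT
    out[i * depth + idx] = on_value;
  }
  return out;
}

// OneHot<float, int32_t>({1, 0, 2}, /*depth=*/3, 1.0f, 0.0f)
//   -> {0,1,0,  1,0,0,  0,0,1}
```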
  {
    "path": "oneflow/user/kernels/one_hot_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename K>\n__global__ void OneHotEncodeGpu(int64_t elem_cnt, const int64_t depth, const T on_value,\n                                const T off_value, const K* indices, T* out) {\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {\n    const int64_t row = i / depth;\n    const int64_t col = i - row * depth;\n    const int64_t idx = indices[row];\n    assert(idx >= 0 && idx < depth);\n    out[i] = (idx == col) ? on_value : off_value;\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename K>\nclass GpuOneHotKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  GpuOneHotKernel() = default;\n  ~GpuOneHotKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex(\"indices\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t num_indices = indices->shape_view().elem_cnt();\n    const int64_t depth = ctx->Attr<int64_t>(\"depth\");\n    const DataType dtype = ctx->Attr<DataType>(\"dtype\");\n    const T on_value = IsFloatingDataType(dtype)\n                           ? static_cast<T>(ctx->Attr<double>(\"floating_on_value\"))\n                           : static_cast<T>(ctx->Attr<int64_t>(\"integer_on_value\"));\n    const T off_value = IsFloatingDataType(dtype)\n                            ? 
static_cast<T>(ctx->Attr<double>(\"floating_off_value\"))\n                            : static_cast<T>(ctx->Attr<int64_t>(\"integer_off_value\"));\n    RUN_CUDA_KERNEL((OneHotEncodeGpu<T, K>), ctx->stream(), num_indices * depth,\n                    num_indices * depth, depth, on_value, off_value, indices->dptr<K>(),\n                    out->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_ONE_HOT_KERNEL(dtype, itype)                                              \\\n  REGISTER_USER_KERNEL(\"one_hot\").SetCreateFn<GpuOneHotKernel<dtype, itype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCUDA)                                           \\\n      && (user_op::HobDataType(\"indices\", 0) == GetDataType<itype>::value)                      \\\n      && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CUDA_ONE_HOT_KERNEL(int32_t, int32_t)\nREGISTER_CUDA_ONE_HOT_KERNEL(int32_t, int64_t)\nREGISTER_CUDA_ONE_HOT_KERNEL(int64_t, int32_t)\nREGISTER_CUDA_ONE_HOT_KERNEL(int64_t, int64_t)\nREGISTER_CUDA_ONE_HOT_KERNEL(float, int32_t)\nREGISTER_CUDA_ONE_HOT_KERNEL(float, int64_t)\nREGISTER_CUDA_ONE_HOT_KERNEL(double, int32_t)\nREGISTER_CUDA_ONE_HOT_KERNEL(double, int64_t)\n\n}  // namespace oneflow\n"
  },
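Unlike the CPU version, the CUDA kernel needs no separate Fill pass: each of the `num_indices * depth` output elements is written exactly once, choosing `on_value` or `off_value` by comparing its column against the row's index. A host-side sketch of the same flat-index arithmetic (`OneHotFlat` is a hypothetical name; the grid-stride `CUDA_1D_KERNEL_LOOP` becomes a plain loop here):

```cpp
#include <cstdint>

template<typename T, typename K>
void OneHotFlat(int64_t elem_cnt, int64_t depth, T on_value, T off_value, const K* indices,
                T* out) {
  // On the GPU this loop body runs once per thread over the flat output.
  for (int64_t i = 0; i < elem_cnt; ++i) {
    const int64_t row = i / depth;        // which input index this element belongs to
    const int64_t col = i - row * depth;  // position within the one-hot row
    out[i] = (static_cast<int64_t>(indices[row]) == col) ? on_value : off_value;
  }
}
```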
  {
    "path": "oneflow/user/kernels/ones_like_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/switch_func.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Fill> NewFillPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->device_type(), data_type);\n}\n\nclass OnesLikeKernel final : public user_op::OpKernel {\n public:\n  OnesLikeKernel() = default;\n  ~OnesLikeKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    std::unique_ptr<ep::primitive::Fill> fill =\n        ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->stream()->device_type(),\n                                                                out->data_type());\n    CHECK(fill);\n    fill->Launch(ctx->stream(), out->mut_dptr(), 1, out->shape_view().elem_cnt());\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nauto FillPrimitiveExists() {\n  return hob::make_custom(\"FillPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewFillPrimitive(&ctx).operator bool();\n  });\n}\n\nREGISTER_USER_KERNEL(\"ones_like\")\n    .SetCreateFn<OnesLikeKernel>()\n    .SetIsMatchedHob(FillPrimitiveExists());\n\n}  // namespace\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
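`ones_like` leans entirely on the Fill primitive and only registers the kernel when such a primitive can actually be constructed for the output dtype (`FillPrimitiveExists`). A sketch of that dispatch idea, with invented types standing in for OneFlow's primitive machinery:

```cpp
#include <cstdint>
#include <functional>
#include <memory>

enum class DType { kFloat, kInt32, kComplex64 };

using Fill = std::function<void(void* dst, int64_t elem_cnt)>;

template<typename T>
Fill MakeTypedFill(T value) {
  return [value](void* dst, int64_t elem_cnt) {
    T* typed = static_cast<T*>(dst);
    for (int64_t i = 0; i < elem_cnt; ++i) { typed[i] = value; }
  };
}

// Returns nullptr when no fill exists for the dtype; a registration predicate
// (like FillPrimitiveExists above) can then reject the kernel up front instead
// of failing inside Compute().
std::unique_ptr<Fill> NewOnesFill(DType dtype) {
  switch (dtype) {
    case DType::kFloat: return std::make_unique<Fill>(MakeTypedFill<float>(1.0f));
    case DType::kInt32: return std::make_unique<Fill>(MakeTypedFill<int32_t>(1));
    default: return nullptr;  // e.g. kComplex64: the primitive "does not exist"
  }
}
```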
  {
    "path": "oneflow/user/kernels/op_kernel_wrapper.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_OP_KERNEL_STATE_WRAPPER_H_\n#define ONEFLOW_USER_KERNELS_OP_KERNEL_STATE_WRAPPER_H_\n\n#include \"oneflow/core/framework/op_kernel.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass OpKernelStateWrapper final : public user_op::OpKernelState {\n public:\n  template<typename... Args>\n  explicit OpKernelStateWrapper(Args&&... args) : data_(std::forward<Args>(args)...) {}\n\n  ~OpKernelStateWrapper() = default;\n\n  const T& Get() const { return data_; }\n  T* Mutable() { return &data_; }\n\n private:\n  T data_;\n};\n\ntemplate<typename T>\nclass OpKernelCacheWrapper final : public user_op::OpKernelCache {\n public:\n  template<typename... Args>\n  explicit OpKernelCacheWrapper(Args&&... args) : data_(std::forward<Args>(args)...) {}\n\n  ~OpKernelCacheWrapper() = default;\n\n  const T& Get() const { return data_; }\n  T* Mutable() { return &data_; }\n\n private:\n  T data_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_OP_KERNEL_STATE_WRAPPER_H_\n"
  },
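These wrappers let any value type act as kernel state or cache without a new `OpKernelState` subclass each time; constructor arguments are perfect-forwarded into the wrapped `T`. A self-contained usage sketch (with a stand-in base class so it compiles outside the framework):

```cpp
#include <cstddef>
#include <memory>
#include <utility>

// Stand-in for user_op::OpKernelState, just to make the sketch self-contained.
struct OpKernelStateBase {
  virtual ~OpKernelStateBase() = default;
};

template<typename T>
class StateWrapper final : public OpKernelStateBase {
 public:
  template<typename... Args>
  explicit StateWrapper(Args&&... args) : data_(std::forward<Args>(args)...) {}
  const T& Get() const { return data_; }
  T* Mutable() { return &data_; }

 private:
  T data_;
};

int main() {
  // The pack kernel below stores its (cursor, pack_num) pair exactly this way.
  auto state = std::make_shared<StateWrapper<std::pair<size_t, size_t>>>(0, 4);
  state->Mutable()->first += 1;  // mutate through Mutable()
  return state->Get().first == 1 ? 0 : 1;  // read through Get()
}
```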
  {
    "path": "oneflow/user/kernels/p2p_comm_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/control/global_process_ctx.h\"\n#include \"oneflow/core/job/rank_group.h\"\n#include \"oneflow/core/framework/instructions_builder.h\"\n#include \"oneflow/user/kernels/collective_communication/include/send.h\"\n#include \"oneflow/user/kernels/collective_communication/include/recv.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nnamespace {\n\nauto SendCollectiveCommunicationExists() {\n  return hob::make_custom(\"SendCollectiveCommunicationExists\",\n                          [=](const user_op::KernelRegContext& ctx) {\n                            DeviceType device_type = ctx.device_type();\n                            return ccl::IsSendRegistered(device_type);\n                          });\n}\n\nauto RecvCollectiveCommunicationExists() {\n  return hob::make_custom(\"RecvCollectiveCommunicationExists\",\n                          [=](const user_op::KernelRegContext& ctx) {\n                            DeviceType device_type = ctx.device_type();\n                            return ccl::IsRecvRegistered(device_type);\n                          });\n}\n\n}  // namespace\n\nclass SendKernel final : public user_op::OpKernel {\n public:\n  SendKernel() = default;\n  ~SendKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const auto& dst_process_id = ctx->Attr<int64_t>(\"dst_process_id\");\n    std::unique_ptr<ccl::Send> send =\n        ccl::NewCollectiveCommunication<ccl::Send>(ctx->device_type(), in->data_type());\n    send->Launch(ctx->stream(), in->dptr(), in->shape_view().elem_cnt(), dst_process_id);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nclass RecvKernel final : public user_op::OpKernel {\n public:\n  RecvKernel() = default;\n  ~RecvKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const auto& src_process_id = ctx->Attr<int64_t>(\"src_process_id\");\n    std::unique_ptr<ccl::Recv> recv =\n        ccl::NewCollectiveCommunication<ccl::Recv>(ctx->device_type(), out->data_type());\n    recv->Launch(ctx->stream(), out->mut_dptr(), out->shape_view().elem_cnt(), src_process_id);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"send\").SetCreateFn<SendKernel>().SetIsMatchedHob(\n    SendCollectiveCommunicationExists());\n\nREGISTER_USER_KERNEL(\"recv\").SetCreateFn<RecvKernel>().SetIsMatchedHob(\n    RecvCollectiveCommunicationExists());\n}  // namespace\n\n}  // namespace oneflow\n"
  },
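Both kernels here are only matched when a Send/Recv collective is registered for the device type, and `NewCollectiveCommunication` then constructs the device-specific implementation. A minimal sketch of that register-then-query pattern (all names invented for the sketch; this is not the actual `ccl` API):

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>

enum class DeviceType { kCPU, kCUDA };

// Stand-in for ccl::Send: a point-to-point send of n elements to a rank.
struct Send {
  virtual ~Send() = default;
  virtual void Launch(const void* src, size_t n, int64_t dst_rank) = 0;
};

class SendRegistry {
 public:
  using Factory = std::function<std::unique_ptr<Send>()>;
  static SendRegistry& Get() {
    static SendRegistry registry;
    return registry;
  }
  void Register(DeviceType device, Factory factory) { factories_[device] = std::move(factory); }
  // What a hob predicate like SendCollectiveCommunicationExists() would query.
  bool IsRegistered(DeviceType device) const { return factories_.count(device) > 0; }
  std::unique_ptr<Send> New(DeviceType device) const { return factories_.at(device)(); }

 private:
  std::map<DeviceType, Factory> factories_;
};
```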
  {
    "path": "oneflow/user/kernels/pack_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<DeviceType device_type>\nclass PackKernel final : public user_op::OpKernel {\n public:\n  PackKernel() = default;\n  ~PackKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<OpKernelStateWrapper<std::pair<size_t, size_t>>>(\n        std::make_pair<size_t, size_t>(0, 0));\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(in->data_type(), out->data_type());\n    const auto pack_num = ctx->Attr<int32_t>(\"pack_num\");\n    if (in->shape_view().NumAxes() > 0) {\n      CHECK_EQ(in->shape_view().NumAxes(), out->shape_view().NumAxes());\n      CHECK_EQ(out->shape_view().At(0), in->shape_view().At(0) * pack_num);\n      for (int64_t i = 1; i < in->shape_view().NumAxes(); ++i) {\n        CHECK_EQ(out->shape_view().At(i), in->shape_view().At(i));\n      }\n    } else {\n      // NOTE(chengcheng): for Scalar input pack\n      CHECK_EQ(in->shape_view().NumAxes(), 0);\n      CHECK_EQ(out->shape_view().NumAxes(), 1);\n      CHECK_EQ(in->shape_view().elem_cnt(), 1);\n      CHECK_EQ(out->shape_view().elem_cnt(), pack_num);\n    }\n    const int64_t copy_size = in->shape_view().elem_cnt() * GetSizeOfDataType(out->data_type());\n    auto* state_wrapper = dynamic_cast<OpKernelStateWrapper<std::pair<size_t, size_t>>*>(state);\n    CHECK_NOTNULL(state_wrapper);\n    const size_t index = state_wrapper->Get().first;\n    CHECK_EQ(state_wrapper->Get().second, pack_num);\n    Memcpy<device_type>(ctx->stream(), out->mut_dptr<char>() + index * copy_size, in->dptr<char>(),\n                        copy_size);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_PACK_KERNEL(device)                                              \\\n  REGISTER_USER_KERNEL(\"pack\").SetCreateFn<PackKernel<device>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == device));\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_PACK_KERNEL, DEVICE_TYPE_SEQ)\n#if defined(WITH_MLU)\nREGISTER_PACK_KERNEL(DeviceType::kMLU)\n#endif\n#undef REGISTER_PACK_KERNEL\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
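`pack` is invoked once per input: each `Compute()` call reads the `(cursor, pack_num)` pair from the wrapped state, memcpys its input into the cursor's `copy_size` slice of the shared output, and wraps the cursor after `pack_num` calls so the state can be reused. The same per-call step, stripped of the framework (names hypothetical):

```cpp
#include <cstddef>
#include <cstring>

// Mirrors the kernel's std::pair<size_t, size_t> state: (write cursor, pack_num).
struct PackState {
  size_t index = 0;
  size_t pack_num = 1;
};

// One Compute() call: copy this input into its slice, then advance and wrap the cursor.
void Pack(PackState* state, const char* in, size_t copy_size, char* out) {
  std::memcpy(out + state->index * copy_size, in, copy_size);
  state->index = (state->index + 1) % state->pack_num;
}

// Packing four 8-byte inputs fills a 32-byte output slice by slice:
//   PackState s{0, 4};
//   for (int i = 0; i < 4; ++i) { Pack(&s, inputs[i], 8, out); }
```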
  {
    "path": "oneflow/user/kernels/pad_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/constant_pad.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::ConstantPad> NewConstantPadPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::ConstantPadFactory>(ctx->device_type(),\n                                                                        data_type);\n}\n\nauto ConstantPadPrimitiveExists() {\n  return hob::make_custom(\"ConstantPadPrimitiveExists\", [](const KernelRegContext& ctx) {\n    return NewConstantPadPrimitive(&ctx).operator bool();\n  });\n}\n\n}  // namespace\n\nclass PadKernel final : public OpKernel, public CudaGraphSupport {\n public:\n  PadKernel() = default;\n  ~PadKernel() = default;\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    const Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    if (y->shape_view().NumAxes() > 0 && y->shape_view().elem_cnt() == 0) {\n      // if output is 0-shape tensor, than do nothing and return\n      return;\n    }\n\n    Scalar value;\n    if (IsIntegralDataType(x->data_type()) || x->data_type() == kBool) {\n      value = Scalar(ctx->Attr<int64_t>(\"integral_constant_value\"));\n    } else {\n      value = Scalar(ctx->Attr<double>(\"floating_constant_value\"));\n    }\n\n    const auto& padding_before = ctx->Attr<std::vector<int64_t>>(\"padding_before\");\n    const auto& padding_after = ctx->Attr<std::vector<int64_t>>(\"padding_after\");\n    const int64_t ndims = x->shape_view().NumAxes();\n    CHECK_EQ(padding_before.size(), ndims);\n\n    std::unique_ptr<ep::primitive::ConstantPad> pad_primitive = NewConstantPadPrimitive(ctx);\n    CHECK(pad_primitive);\n\n    pad_primitive->Launch(ctx->stream(), ndims, x->shape_view().ptr(), x->dptr(),\n                          padding_before.data(), padding_after.data(), value, y->mut_dptr());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"pad\").SetCreateFn<PadKernel>().SetIsMatchedHob(ConstantPadPrimitiveExists());\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
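`PadKernel` picks the pad constant from `integral_constant_value` for integer and bool outputs and from `floating_constant_value` otherwise, then delegates to the ConstantPad primitive with per-axis `padding_before`/`padding_after`. For a single axis, the primitive's contract reduces to the following sketch (`ConstantPad1D` is a hypothetical helper, not the primitive's real signature):

```cpp
#include <cstdint>
#include <vector>

template<typename T>
std::vector<T> ConstantPad1D(const std::vector<T>& x, int64_t before, int64_t after, T value) {
  // The output starts life filled with the pad constant...
  std::vector<T> y(before + static_cast<int64_t>(x.size()) + after, value);
  // ...and the input is copied into the middle.
  for (size_t i = 0; i < x.size(); ++i) { y[before + i] = x[i]; }
  return y;
}

// ConstantPad1D<float>({1, 2, 3}, /*before=*/2, /*after=*/1, 0.0f)
//   -> {0, 0, 1, 2, 3, 0}
```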
  {
    "path": "oneflow/user/kernels/partial_fc_sample_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/user/kernels/gather_kernel_util.h\"\n#include \"oneflow/core/common/not_equal_to_previous_adjacent_iterator.h\"\n#include <cub/cub.cuh>\n#include <curand.h>\n#include <curand_kernel.h>\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\nnamespace {\n\ntemplate<typename K>\nint64_t GetCubSortPairsTempStorageSize(int64_t n) {\n  size_t cub_sort_temp_store_size = 0;\n  OF_CUDA_CHECK((cub::DeviceRadixSort::SortPairs<K, K>(nullptr, cub_sort_temp_store_size, nullptr,\n                                                       nullptr, nullptr, nullptr, n)));\n  size_t temp_store_size = GetCudaAlignedSize(cub_sort_temp_store_size);\n  CHECK_GE(temp_store_size, 0);\n  CHECK_LT(temp_store_size, static_cast<size_t>(GetMaxVal<int64_t>()));\n  return static_cast<int64_t>(temp_store_size);\n}\n\ntemplate<typename K>\nint64_t GetCubScanTempStorageSize(int64_t n) {\n  size_t cub_scan_temp_store_size = 0;\n  NotEqualToPreviousAdjacentIterator<K, K> unique_counting_iter(nullptr, 0);\n  OF_CUDA_CHECK((cub::DeviceScan::InclusiveSum<NotEqualToPreviousAdjacentIterator<K, K>, K*>(\n      nullptr, cub_scan_temp_store_size, unique_counting_iter, nullptr, n)));\n  size_t temp_store_size = GetCudaAlignedSize(cub_scan_temp_store_size);\n  CHECK_GE(temp_store_size, 0);\n  CHECK_LT(temp_store_size, static_cast<size_t>(GetMaxVal<int64_t>()));\n  return static_cast<int64_t>(temp_store_size);\n}\n\ntemplate<typename K>\nclass TmpBufferManager final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TmpBufferManager);\n  TmpBufferManager(void* ptr, const int64_t device_num_class, const int64_t batch_size,\n                   const int64_t parallel_num)\n      : ptr_(ptr) {\n    const int64_t buffer_elem_cnt = std::max(device_num_class, batch_size);\n    const size_t cub_sort_keys_bytes = GetCudaAlignedSize(buffer_elem_cnt * sizeof(K));\n    const size_t cub_sort_values_bytes = GetCudaAlignedSize(buffer_elem_cnt * sizeof(K));\n    const size_t cub_sort_keys_out_bytes = GetCudaAlignedSize(buffer_elem_cnt * sizeof(K));\n    const size_t cub_sort_values_out_bytes = GetCudaAlignedSize(buffer_elem_cnt * sizeof(K));\n    const size_t bound_index_bytes = GetCudaAlignedSize((parallel_num + 1) * sizeof(K));\n    const size_t bound_value_bytes = GetCudaAlignedSize((parallel_num + 1) * sizeof(K));\n    cub_tmp_storage_bytes_ = std::max(GetCubSortPairsTempStorageSize<K>(buffer_elem_cnt),\n                                      GetCubScanTempStorageSize<K>(batch_size));\n    cub_sort_keys_offset_ = 0;\n    cub_sort_values_offset_ = cub_sort_keys_offset_ + cub_sort_keys_bytes;\n    cub_sort_keys_out_offset_ = cub_sort_values_offset_ + cub_sort_values_bytes;\n    cub_sort_values_out_offset_ = 
cub_sort_keys_out_offset_ + cub_sort_keys_out_bytes;\n    cub_tmp_storage_offset_ = cub_sort_values_out_offset_ + cub_sort_values_out_bytes;\n    bound_index_offset_ = cub_tmp_storage_offset_ + cub_tmp_storage_bytes_;\n    bound_value_offset_ = bound_index_offset_ + bound_index_bytes;\n    total_buffer_size_ = cub_sort_keys_bytes + cub_sort_values_bytes + cub_sort_keys_out_bytes\n                         + cub_sort_values_out_bytes + cub_tmp_storage_bytes_ + bound_index_bytes\n                         + bound_value_bytes;\n  }\n  ~TmpBufferManager() = default;\n\n  size_t GetTotalBufferSize() const { return total_buffer_size_; }\n  size_t GetCubTmpStorageSize() const { return cub_tmp_storage_bytes_; }\n  K* CubSortKeysPtr() const {\n    CHECK(ptr_ != nullptr);\n    return reinterpret_cast<K*>(reinterpret_cast<char*>(ptr_) + cub_sort_keys_offset_);\n  }\n  K* CubSortValuesPtr() const {\n    CHECK(ptr_ != nullptr);\n    return reinterpret_cast<K*>(reinterpret_cast<char*>(ptr_) + cub_sort_values_offset_);\n  }\n  K* CubSortKeysOutPtr() const {\n    CHECK(ptr_ != nullptr);\n    return reinterpret_cast<K*>(reinterpret_cast<char*>(ptr_) + cub_sort_keys_out_offset_);\n  }\n  K* CubSortValuesOutPtr() const {\n    CHECK(ptr_ != nullptr);\n    return reinterpret_cast<K*>(reinterpret_cast<char*>(ptr_) + cub_sort_values_out_offset_);\n  }\n  void* CubTmpStoragePtr() const {\n    CHECK(ptr_ != nullptr);\n    return reinterpret_cast<void*>(reinterpret_cast<char*>(ptr_) + cub_tmp_storage_offset_);\n  }\n  K* BoundIndexPtr() const {\n    CHECK(ptr_ != nullptr);\n    return reinterpret_cast<K*>(reinterpret_cast<char*>(ptr_) + bound_index_offset_);\n  }\n  K* BoundValuePtr() const {\n    CHECK(ptr_ != nullptr);\n    return reinterpret_cast<K*>(reinterpret_cast<char*>(ptr_) + bound_value_offset_);\n  }\n\n private:\n  size_t cub_sort_keys_offset_;\n  size_t cub_sort_values_offset_;\n  size_t cub_sort_keys_out_offset_;\n  size_t cub_sort_values_out_offset_;\n  size_t cub_tmp_storage_offset_;\n  size_t bound_index_offset_;\n  size_t bound_value_offset_;\n  size_t cub_tmp_storage_bytes_;\n  size_t total_buffer_size_;\n  void* ptr_;\n};\n\n__global__ void SetupKernel(int64_t seed, curandState* state) {\n  const int id = blockIdx.x * blockDim.x + threadIdx.x;\n  size_t local_seed = (static_cast<size_t>(seed) + 0x9e3779b9U + (static_cast<size_t>(id) << 6U)\n                       + (static_cast<size_t>(id) >> 2U));\n  curand_init(local_seed, 0, 0, &state[id]);\n}\n\ntemplate<typename K>\n__global__ void GenerateGpu(curandState* state, const int64_t n, const int64_t max_val, K* buffer) {\n  const int id = blockIdx.x * blockDim.x + threadIdx.x;\n  curandState localState = state[id];\n  CUDA_1D_KERNEL_LOOP(i, n) { buffer[i] = static_cast<K>(curand(&localState) % max_val); }\n  state[id] = localState;\n}\n\nclass DistributedPartialFcSampleOpKernelState final : public user_op::OpKernelState {\n public:\n  DistributedPartialFcSampleOpKernelState(ep::Stream* stream, int64_t lower, int64_t upper,\n                                          int64_t num_sample_per_rank, int64_t seed)\n      : lower_(lower), upper_(upper), num_sample_per_rank_(num_sample_per_rank) {\n    CHECK_NOTNULL(stream);\n    const int64_t num_classes = upper_ - lower_;\n    OF_CUDA_CHECK(cudaMalloc(&curand_states_, BlocksNum4ThreadsNum(num_classes)\n                                                  * kCudaThreadsNumPerBlock * sizeof(curandState)));\n    SetupKernel<<<BlocksNum4ThreadsNum(num_classes), kCudaThreadsNumPerBlock, 0,\n                  
stream->As<ep::CudaStream>()->cuda_stream()>>>(seed, curand_states_);\n  }\n  ~DistributedPartialFcSampleOpKernelState() {\n    cudaError_t ret = cudaFree(curand_states_);\n    if (ret != cudaErrorCudartUnloading) { OF_CUDA_CHECK(ret); }\n  };\n\n  int64_t lower() const { return lower_; }\n  int64_t upper() const { return upper_; }\n  int64_t num_sample_per_rank() const { return num_sample_per_rank_; }\n\n  template<typename K>\n  void GenRandom(ep::Stream* stream, const int64_t n, const int64_t max_val, K* buffer) {\n    GenerateGpu<K>\n        <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(curand_states_, n, max_val, buffer);\n  }\n\n private:\n  const int64_t lower_;\n  const int64_t upper_;\n  const int64_t num_sample_per_rank_;\n  curandState* curand_states_;\n};\n\ntemplate<typename K>\n__global__ void IotaKernel(int64_t n, K* out) {\n  CUDA_1D_KERNEL_LOOP(i, n) { out[i] = static_cast<K>(i); }\n}\n\ntemplate<typename K>\n__global__ void MarkPositive(const int64_t n, const int64_t offset, const int64_t num_classes,\n                             const K* labels, K* out) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    K label = labels[i] - offset;\n    if (label >= 0 && label < num_classes) { out[label] = label - num_classes; }\n  }\n}\n\ntemplate<typename K>\n__global__ void GetSampledLabel(const int64_t n, const int64_t offset, const K* label,\n                                K* sampled_label) {\n  CUDA_1D_KERNEL_LOOP(i, n) { sampled_label[i] = label[i] + offset; }\n}\n\ntemplate<typename K>\n__global__ void GetLabelMap(const int64_t n, const int64_t parallel_num,\n                            const int64_t num_sample_per_rank, const K* bound_index,\n                            const K* bound_value, K* label_map) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n#pragma unroll\n    for (int64_t j = 0; j < parallel_num; j++) {\n      if (i >= bound_index[j] && i < bound_index[j + 1]) {\n        label_map[i] = label_map[i] - bound_value[j] + j * num_sample_per_rank;\n      }\n    }\n  }\n}\n\ntemplate<typename K>\n__global__ void GetPartionBound(const int64_t n, const int64_t parallel_num,\n                                const int64_t num_classes_per_rank, const K* key_ptr,\n                                const K* value_ptr, K* bound_index, K* bound_value) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    if (i != 0) {\n      const K cur_in = key_ptr[i] / num_classes_per_rank;\n      const K pre_in = key_ptr[i - 1] / num_classes_per_rank;\n      if (cur_in > pre_in) {\n        assert(cur_in < parallel_num);\n#pragma unroll\n        for (int32_t j = pre_in + 1; j <= cur_in; ++j) {\n          bound_index[j] = static_cast<K>(i);\n          bound_value[j] = value_ptr[i];\n        }\n      }\n    }\n  }\n  CUDA_1D_KERNEL_LOOP(i, parallel_num + 1) {\n    const K first_in = key_ptr[0] / num_classes_per_rank;\n    const K last_in = key_ptr[n - 1] / num_classes_per_rank;\n    if (i <= first_in) {\n      bound_index[i] = 0;\n      bound_value[i] = value_ptr[0];\n    } else if (i > last_in) {\n      bound_index[i] = n;\n      bound_value[i] = value_ptr[n - 1];\n    }\n  }\n}\n\ntemplate<typename K>\n__global__ void GetMappedLabel(const int64_t n, const K* label_map_key, const K* label_map_value,\n                               K* mapped_label) {\n  CUDA_1D_KERNEL_LOOP(i, n) { mapped_label[label_map_key[i]] = label_map_value[i]; }\n}\n\ntemplate<typename K>\nvoid MapLabel(ep::Stream* stream, const int64_t num_classes, const int64_t batch_size,\n              const 
int64_t lower_bound, const int64_t parallel_num, const int64_t num_sample,\n              size_t temp_storage_bytes, const K* label_ptr, K* mapped_label_ptr,\n              K* cub_sort_values_ptr, K* cub_sort_keys_out_ptr, K* cub_sort_values_out_ptr,\n              void* cub_tmp_storage_ptr, K* bound_index_ptr, K* bound_value_ptr) {\n  IotaKernel<<<BlocksNum4ThreadsNum(batch_size), kCudaThreadsNumPerBlock, 0,\n               stream->As<ep::CudaStream>()->cuda_stream()>>>(batch_size, cub_sort_values_ptr);\n  OF_CUDA_CHECK((cub::DeviceRadixSort::SortPairs<K, K>(\n      cub_tmp_storage_ptr, temp_storage_bytes, label_ptr, cub_sort_keys_out_ptr,\n      cub_sort_values_ptr, cub_sort_values_out_ptr, batch_size, 0, sizeof(K) * 8,\n      stream->As<ep::CudaStream>()->cuda_stream())));\n  NotEqualToPreviousAdjacentIterator<K, K> unique_counting_iter(cub_sort_keys_out_ptr, 0);\n  OF_CUDA_CHECK((cub::DeviceScan::InclusiveSum<NotEqualToPreviousAdjacentIterator<K, K>, K*>(\n      cub_tmp_storage_ptr, temp_storage_bytes, unique_counting_iter, cub_sort_values_ptr,\n      batch_size, stream->As<ep::CudaStream>()->cuda_stream())));\n\n  GetPartionBound<<<BlocksNum4ThreadsNum(batch_size), kCudaThreadsNumPerBlock, 0,\n                    stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      batch_size, parallel_num, num_classes, cub_sort_keys_out_ptr, cub_sort_values_ptr,\n      bound_index_ptr, bound_value_ptr);\n\n  GetLabelMap<K><<<BlocksNum4ThreadsNum(batch_size), kCudaThreadsNumPerBlock, 0,\n                   stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      batch_size, parallel_num, num_sample, bound_index_ptr, bound_value_ptr, cub_sort_values_ptr);\n\n  GetMappedLabel<<<BlocksNum4ThreadsNum(batch_size), kCudaThreadsNumPerBlock, 0,\n                   stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      batch_size, cub_sort_values_out_ptr, cub_sort_values_ptr, mapped_label_ptr);\n}\n\n}  // namespace\n\ntemplate<typename T, typename K>\nclass DistributedPartialFcSampleGpuKernel final : public user_op::OpKernel {\n public:\n  DistributedPartialFcSampleGpuKernel() = default;\n  ~DistributedPartialFcSampleGpuKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const SbpParallel& in_sbp = ctx->SbpParallel4ArgNameAndIndex(\"weight\", 0);\n    const TensorDesc* in_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex(\"weight\", 0);\n    const int64_t class_num = in_logical_desc->shape().At(0);\n    const int64_t num_sample = ctx->Attr<int64_t>(\"num_sample\");\n    int64_t seed = ctx->Attr<int64_t>(\"seed\");\n    const int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n    const int64_t num_sample_per_rank = RoundUp(num_sample, parallel_num) / parallel_num;\n    if (in_sbp.has_split_parallel() && in_sbp.split_parallel().axis() == 0 && parallel_num > 1) {\n      std::seed_seq seq{seed};\n      std::vector<int64_t> seeds(parallel_num);\n      seq.generate(seeds.begin(), seeds.end());\n      seed = seeds.at(ctx->parallel_ctx().parallel_id());\n      CHECK(ctx->SbpParallel4ArgNameAndIndex(\"label\", 0).has_broadcast_parallel());\n      BalancedSplitter bs(class_num, parallel_num);\n      return std::make_shared<DistributedPartialFcSampleOpKernelState>(\n          ctx->stream(), bs.At(ctx->parallel_ctx().parallel_id()).begin(),\n          bs.At(ctx->parallel_ctx().parallel_id()).end(), num_sample_per_rank, seed);\n    } else {\n      return 
std::make_shared<DistributedPartialFcSampleOpKernelState>(ctx->stream(), 0, class_num,\n                                                                       num_sample_per_rank, seed);\n    }\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    user_op::Tensor* mapped_label = ctx->Tensor4ArgNameAndIndex(\"mapped_label\", 0);\n    user_op::Tensor* sampled_label = ctx->Tensor4ArgNameAndIndex(\"sampled_label\", 0);\n    user_op::Tensor* sampled_weight = ctx->Tensor4ArgNameAndIndex(\"sampled_weight\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    const int64_t batch_size = label->shape_view().At(0);\n    const int64_t num_classes = weight->shape_view().At(0);\n    const int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n    TmpBufferManager<K> buffer_manager(tmp_buffer->mut_dptr(), num_classes, batch_size,\n                                       parallel_num);\n\n    auto* kernel_state = dynamic_cast<DistributedPartialFcSampleOpKernelState*>(state);\n    CHECK_NOTNULL(kernel_state);\n    CHECK_EQ(num_classes, kernel_state->upper() - kernel_state->lower());\n    const int64_t lower_bound = kernel_state->lower();\n    const int64_t num_sample = kernel_state->num_sample_per_rank();\n    kernel_state->GenRandom<K>(ctx->stream(), num_classes, num_classes,\n                               buffer_manager.CubSortKeysPtr());\n    MarkPositive<<<BlocksNum4ThreadsNum(batch_size), kCudaThreadsNumPerBlock, 0,\n                   ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        batch_size, lower_bound, num_classes, label->dptr<K>(), buffer_manager.CubSortKeysPtr());\n    IotaKernel<<<BlocksNum4ThreadsNum(num_classes), kCudaThreadsNumPerBlock, 0,\n                 ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        num_classes, buffer_manager.CubSortValuesPtr());\n    size_t temp_storage_bytes = buffer_manager.GetCubTmpStorageSize();\n    OF_CUDA_CHECK((cub::DeviceRadixSort::SortPairs<K, K>(\n        buffer_manager.CubTmpStoragePtr(), temp_storage_bytes, buffer_manager.CubSortKeysPtr(),\n        buffer_manager.CubSortKeysOutPtr(), buffer_manager.CubSortValuesPtr(),\n        buffer_manager.CubSortValuesOutPtr(), num_classes, 0, sizeof(K) * 8,\n        ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n\n    GetSampledLabel<<<BlocksNum4ThreadsNum(num_sample), kCudaThreadsNumPerBlock, 0,\n                      ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        num_sample, lower_bound, buffer_manager.CubSortValuesOutPtr(),\n        sampled_label->mut_dptr<K>());\n\n    GatherKernelUtilImpl<DeviceType::kCUDA, T, K>::Forward(\n        ctx->stream(), buffer_manager.CubSortValuesOutPtr(), num_sample, weight->dptr<T>(),\n        Shape({1, num_classes, weight->shape_view().Count(1)}), sampled_weight->mut_dptr<T>(), 0);\n\n    MapLabel<K>(ctx->stream(), num_classes, batch_size, lower_bound, parallel_num, num_sample,\n                buffer_manager.GetCubTmpStorageSize(), label->dptr<K>(),\n                mapped_label->mut_dptr<K>(), buffer_manager.CubSortValuesPtr(),\n                buffer_manager.CubSortKeysOutPtr(), buffer_manager.CubSortValuesOutPtr(),\n                buffer_manager.CubTmpStoragePtr(), 
buffer_manager.BoundIndexPtr(),\n                buffer_manager.BoundValuePtr());\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_DISTRIBUTED_PARTIAL_FC_SAMPLE_CUDA_KERNEL(dtype_pair, ltype_pair)               \\\n  REGISTER_USER_KERNEL(\"distributed_partial_fc_sample\")                                          \\\n      .SetCreateFn<DistributedPartialFcSampleGpuKernel<OF_PP_PAIR_FIRST(dtype_pair),             \\\n                                                       OF_PP_PAIR_FIRST(ltype_pair)>>()          \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                           \\\n                       && (user_op::HobDataType(\"label\", 0) == OF_PP_PAIR_SECOND(ltype_pair))    \\\n                       && (user_op::HobDataType(\"weight\", 0) == OF_PP_PAIR_SECOND(dtype_pair)))  \\\n      .SetInferTmpSizeFn([](oneflow::user_op::InferContext* ctx) {                               \\\n        const int64_t num_classes = ctx->InputTensorDesc(\"weight\", 0).shape().At(0);             \\\n        const int64_t batch_size = ctx->InputTensorDesc(\"label\", 0).shape().At(0);               \\\n        const int64_t parallel_num = ctx->parallel_ctx().parallel_num();                         \\\n        TmpBufferManager<OF_PP_PAIR_FIRST(ltype_pair)> buffer_manager(nullptr, num_classes,      \\\n                                                                      batch_size, parallel_num); \\\n        return buffer_manager.GetTotalBufferSize();                                              \\\n      });\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_DISTRIBUTED_PARTIAL_FC_SAMPLE_CUDA_KERNEL,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\ntemplate<typename T, typename K>\nclass DistributedPartialFcSampleDisableBoxingGpuKernel final : public user_op::OpKernel {\n public:\n  DistributedPartialFcSampleDisableBoxingGpuKernel() = default;\n  ~DistributedPartialFcSampleDisableBoxingGpuKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    const user_op::Tensor* sampled_weight_diff =\n        ctx->Tensor4ArgNameAndIndex(\"sampled_weight_diff\", 0);\n    const user_op::Tensor* sampled_label = ctx->Tensor4ArgNameAndIndex(\"sampled_label\", 0);\n    user_op::Tensor* boxing_disabled_sampled_weight_diff =\n        ctx->Tensor4ArgNameAndIndex(\"boxing_disabled_sampled_weight_diff\", 0);\n    user_op::Tensor* boxing_disabled_sampled_label =\n        ctx->Tensor4ArgNameAndIndex(\"boxing_disabled_sampled_label\", 0);\n    Memcpy<DeviceType::kCUDA>(ctx->stream(), boxing_disabled_sampled_weight_diff->mut_dptr<void>(),\n                              sampled_weight_diff->dptr<void>(),\n                              sampled_weight_diff->shape_view().elem_cnt()\n                                  * GetSizeOfDataType(sampled_weight_diff->data_type()));\n    Memcpy<DeviceType::kCUDA>(\n        ctx->stream(), boxing_disabled_sampled_label->mut_dptr<void>(), sampled_label->dptr<void>(),\n        sampled_label->shape_view().elem_cnt() * GetSizeOfDataType(sampled_label->data_type()));\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_DISTRIBUTED_PARTIAL_FC_SAMPLE_DISABLE_BOXING_CUDA_KERNEL(dtype_pair, ltype_pair) \\\n  
REGISTER_USER_KERNEL(\"distributed_partial_fc_sample_disable_boxing\")                            \\\n      .SetCreateFn<DistributedPartialFcSampleDisableBoxingGpuKernel<                              \\\n          OF_PP_PAIR_FIRST(dtype_pair), OF_PP_PAIR_FIRST(ltype_pair)>>()                          \\\n      .SetIsMatchedHob(                                                                           \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                         \\\n          && (user_op::HobDataType(\"sampled_label\", 0) == OF_PP_PAIR_SECOND(ltype_pair))          \\\n          && (user_op::HobDataType(\"sampled_weight_diff\", 0) == OF_PP_PAIR_SECOND(dtype_pair)));\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_DISTRIBUTED_PARTIAL_FC_SAMPLE_DISABLE_BOXING_CUDA_KERNEL,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n}  // namespace user_op\n}  // namespace oneflow\n#endif\n"
  },
  {
    "path": "oneflow/user/kernels/pocketfft_hdronly.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n/*\nThis file is part of pocketfft.\n\nCopyright (C) 2010-2021 Max-Planck-Society\nCopyright (C) 2019-2020 Peter Bell\n\nFor the odd-sized DCT-IV transforms:\n  Copyright (C) 2003, 2007-14 Matteo Frigo\n  Copyright (C) 2003, 2007-14 Massachusetts Institute of Technology\n\nAuthors: Martin Reinecke, Peter Bell\n\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without modification,\nare permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n* Redistributions in binary form must reproduce the above copyright notice, this\n  list of conditions and the following disclaimer in the documentation and/or\n  other materials provided with the distribution.\n* Neither the name of the copyright holder nor the names of its contributors may\n  be used to endorse or promote products derived from this software without\n  specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\nANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\nWARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR\nANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\nLOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON\nANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\nSOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n*/\n\n#ifndef POCKETFFT_HDRONLY_H\n#define POCKETFFT_HDRONLY_H\n\n#ifndef __cplusplus\n#error This file is C++ and requires a C++ compiler.\n#endif\n\n#if !(__cplusplus >= 201103L || _MSVC_LANG + 0L >= 201103L)\n#error This file requires at least C++11 support.\n#endif\n\n#ifndef POCKETFFT_CACHE_SIZE\n#define POCKETFFT_CACHE_SIZE 0\n#endif\n\n#include <cmath>\n#include <cstdlib>\n#include <stdexcept>\n#include <memory>\n#include <vector>\n#include <complex>\n#include <algorithm>\n#if POCKETFFT_CACHE_SIZE != 0\n#include <array>\n#include <mutex>\n#endif\n\n#ifndef POCKETFFT_NO_MULTITHREADING\n#include <mutex>\n#include <condition_variable>\n#include <thread>\n#include <queue>\n#include <atomic>\n#include <functional>\n#include <new>\n\n#ifdef POCKETFFT_PTHREADS\n#include <pthread.h>\n#endif\n#endif\n\n#if defined(__GNUC__)\n#define POCKETFFT_NOINLINE __attribute__((noinline))\n#define POCKETFFT_RESTRICT __restrict__\n#elif defined(_MSC_VER)\n#define POCKETFFT_NOINLINE __declspec(noinline)\n#define POCKETFFT_RESTRICT __restrict\n#else\n#define POCKETFFT_NOINLINE\n#define POCKETFFT_RESTRICT\n#endif\n\nnamespace pocketfft {\n\nnamespace detail {\nusing std::ptrdiff_t;\nusing std::size_t;\n\n// Always use std:: for <cmath> functions\ntemplate<typename T>\nT cos(T) = delete;\ntemplate<typename T>\nT sin(T) = delete;\ntemplate<typename T>\nT sqrt(T) = delete;\n\nusing shape_t = std::vector<size_t>;\nusing stride_t = std::vector<ptrdiff_t>;\n\nconstexpr bool FORWARD = true, BACKWARD = false;\n\n// only enable vector support for gcc>=5.0 and clang>=5.0\n#ifndef POCKETFFT_NO_VECTORS\n#define POCKETFFT_NO_VECTORS\n#if defined(__INTEL_COMPILER)\n// do nothing. 
This is necessary because this compiler also sets __GNUC__.\n#elif defined(__clang__)\n// AppleClang has their own version numbering\n#ifdef __apple_build_version__\n#if (__clang_major__ > 9) || (__clang_major__ == 9 && __clang_minor__ >= 1)\n#undef POCKETFFT_NO_VECTORS\n#endif\n#elif __clang_major__ >= 5\n#undef POCKETFFT_NO_VECTORS\n#endif\n#elif defined(__GNUC__)\n#if __GNUC__ >= 5\n#undef POCKETFFT_NO_VECTORS\n#endif\n#endif\n#endif\n\ntemplate<typename T>\nstruct VLEN {\n  static constexpr size_t val = 1;\n};\n\n#ifndef POCKETFFT_NO_VECTORS\n#if (defined(__AVX512F__))\ntemplate<>\nstruct VLEN<float> {\n  static constexpr size_t val = 16;\n};\ntemplate<>\nstruct VLEN<double> {\n  static constexpr size_t val = 8;\n};\n#elif (defined(__AVX__))\ntemplate<>\nstruct VLEN<float> {\n  static constexpr size_t val = 8;\n};\ntemplate<>\nstruct VLEN<double> {\n  static constexpr size_t val = 4;\n};\n#elif (defined(__SSE2__))\ntemplate<>\nstruct VLEN<float> {\n  static constexpr size_t val = 4;\n};\ntemplate<>\nstruct VLEN<double> {\n  static constexpr size_t val = 2;\n};\n#elif (defined(__VSX__))\ntemplate<>\nstruct VLEN<float> {\n  static constexpr size_t val = 4;\n};\ntemplate<>\nstruct VLEN<double> {\n  static constexpr size_t val = 2;\n};\n#elif (defined(__ARM_NEON__) || defined(__ARM_NEON))\ntemplate<>\nstruct VLEN<float> {\n  static constexpr size_t val = 4;\n};\ntemplate<>\nstruct VLEN<double> {\n  static constexpr size_t val = 2;\n};\n#else\n#define POCKETFFT_NO_VECTORS\n#endif\n#endif\n\n#if __cplusplus >= 201703L\ninline void* aligned_alloc(size_t align, size_t size) {\n  // aligned_alloc() requires that the requested size is a multiple of \"align\"\n  void* ptr = ::aligned_alloc(align, (size + align - 1) & (~(align - 1)));\n  if (!ptr) throw std::bad_alloc();\n  return ptr;\n}\ninline void aligned_dealloc(void* ptr) { free(ptr); }\n#else  // portable emulation\ninline void* aligned_alloc(size_t align, size_t size) {\n  align = std::max(align, alignof(max_align_t));\n  void* ptr = malloc(size + align);\n  if (!ptr) throw std::bad_alloc();\n  void* res = reinterpret_cast<void*>((reinterpret_cast<uintptr_t>(ptr) & ~(uintptr_t(align - 1)))\n                                      + uintptr_t(align));\n  (reinterpret_cast<void**>(res))[-1] = ptr;\n  return res;\n}\ninline void aligned_dealloc(void* ptr) {\n  if (ptr) free((reinterpret_cast<void**>(ptr))[-1]);\n}\n#endif\n\ntemplate<typename T>\nclass arr {\n private:\n  T* p;\n  size_t sz;\n\n#if defined(POCKETFFT_NO_VECTORS)\n  static T* ralloc(size_t num) {\n    if (num == 0) return nullptr;\n    void* res = malloc(num * sizeof(T));\n    if (!res) throw std::bad_alloc();\n    return reinterpret_cast<T*>(res);\n  }\n  static void dealloc(T* ptr) { free(ptr); }\n#else\n  static T* ralloc(size_t num) {\n    if (num == 0) return nullptr;\n    void* ptr = aligned_alloc(64, num * sizeof(T));\n    return static_cast<T*>(ptr);\n  }\n  static void dealloc(T* ptr) { aligned_dealloc(ptr); }\n#endif\n\n public:\n  arr() : p(0), sz(0) {}\n  arr(size_t n) : p(ralloc(n)), sz(n) {}\n  arr(arr&& other) : p(other.p), sz(other.sz) {\n    other.p = nullptr;\n    other.sz = 0;\n  }\n  ~arr() { dealloc(p); }\n\n  void resize(size_t n) {\n    if (n == sz) return;\n    dealloc(p);\n    p = ralloc(n);\n    sz = n;\n  }\n\n  T& operator[](size_t idx) { return p[idx]; }\n  const T& operator[](size_t idx) const { return p[idx]; }\n\n  T* data() { return p; }\n  const T* data() const { return p; }\n\n  size_t size() const { return sz; }\n};\n\ntemplate<typename 
T>\nstruct cmplx {\n  T r, i;\n  cmplx() {}\n  cmplx(T r_, T i_) : r(r_), i(i_) {}\n  void Set(T r_, T i_) {\n    r = r_;\n    i = i_;\n  }\n  void Set(T r_) {\n    r = r_;\n    i = T(0);\n  }\n  cmplx& operator+=(const cmplx& other) {\n    r += other.r;\n    i += other.i;\n    return *this;\n  }\n  template<typename T2>\n  cmplx& operator*=(T2 other) {\n    r *= other;\n    i *= other;\n    return *this;\n  }\n  template<typename T2>\n  cmplx& operator*=(const cmplx<T2>& other) {\n    T tmp = r * other.r - i * other.i;\n    i = r * other.i + i * other.r;\n    r = tmp;\n    return *this;\n  }\n  template<typename T2>\n  cmplx& operator+=(const cmplx<T2>& other) {\n    r += other.r;\n    i += other.i;\n    return *this;\n  }\n  template<typename T2>\n  cmplx& operator-=(const cmplx<T2>& other) {\n    r -= other.r;\n    i -= other.i;\n    return *this;\n  }\n  template<typename T2>\n  auto operator*(const T2& other) const -> cmplx<decltype(r * other)> {\n    return {r * other, i * other};\n  }\n  template<typename T2>\n  auto operator+(const cmplx<T2>& other) const -> cmplx<decltype(r + other.r)> {\n    return {r + other.r, i + other.i};\n  }\n  template<typename T2>\n  auto operator-(const cmplx<T2>& other) const -> cmplx<decltype(r + other.r)> {\n    return {r - other.r, i - other.i};\n  }\n  template<typename T2>\n  auto operator*(const cmplx<T2>& other) const -> cmplx<decltype(r + other.r)> {\n    return {r * other.r - i * other.i, r * other.i + i * other.r};\n  }\n  template<bool fwd, typename T2>\n  auto special_mul(const cmplx<T2>& other) const -> cmplx<decltype(r + other.r)> {\n    using Tres = cmplx<decltype(r + other.r)>;\n    return fwd ? Tres(r * other.r + i * other.i, i * other.r - r * other.i)\n               : Tres(r * other.r - i * other.i, r * other.i + i * other.r);\n  }\n};\ntemplate<typename T>\ninline void PM(T& a, T& b, T c, T d) {\n  a = c + d;\n  b = c - d;\n}\ntemplate<typename T>\ninline void PMINPLACE(T& a, T& b) {\n  T t = a;\n  a += b;\n  b = t - b;\n}\ntemplate<typename T>\ninline void MPINPLACE(T& a, T& b) {\n  T t = a;\n  a -= b;\n  b = t + b;\n}\ntemplate<typename T>\ncmplx<T> conj(const cmplx<T>& a) {\n  return {a.r, -a.i};\n}\ntemplate<bool fwd, typename T, typename T2>\nvoid special_mul(const cmplx<T>& v1, const cmplx<T2>& v2, cmplx<T>& res) {\n  res = fwd ? cmplx<T>(v1.r * v2.r + v1.i * v2.i, v1.i * v2.r - v1.r * v2.i)\n            : cmplx<T>(v1.r * v2.r - v1.i * v2.i, v1.r * v2.i + v1.i * v2.r);\n}\n\ntemplate<typename T>\nvoid ROT90(cmplx<T>& a) {\n  auto tmp_ = a.r;\n  a.r = -a.i;\n  a.i = tmp_;\n}\ntemplate<bool fwd, typename T>\nvoid ROTX90(cmplx<T>& a) {\n  auto tmp_ = fwd ? -a.r : a.r;\n  a.r = fwd ? 
a.i : -a.i;\n  a.i = tmp_;\n}\n\n//\n// twiddle factor section\n//\ntemplate<typename T>\nclass sincos_2pibyn {\n private:\n  using Thigh = typename std::conditional<(sizeof(T) > sizeof(double)), T, double>::type;\n  size_t N, mask, shift;\n  arr<cmplx<Thigh>> v1, v2;\n\n  static cmplx<Thigh> calc(size_t x, size_t n, Thigh ang) {\n    x <<= 3;\n    if (x < 4 * n)  // first half\n    {\n      if (x < 2 * n)  // first quadrant\n      {\n        if (x < n) return cmplx<Thigh>(std::cos(Thigh(x) * ang), std::sin(Thigh(x) * ang));\n        return cmplx<Thigh>(std::sin(Thigh(2 * n - x) * ang), std::cos(Thigh(2 * n - x) * ang));\n      } else  // second quadrant\n      {\n        x -= 2 * n;\n        if (x < n) return cmplx<Thigh>(-std::sin(Thigh(x) * ang), std::cos(Thigh(x) * ang));\n        return cmplx<Thigh>(-std::cos(Thigh(2 * n - x) * ang), std::sin(Thigh(2 * n - x) * ang));\n      }\n    } else {\n      x = 8 * n - x;\n      if (x < 2 * n)  // third quadrant\n      {\n        if (x < n) return cmplx<Thigh>(std::cos(Thigh(x) * ang), -std::sin(Thigh(x) * ang));\n        return cmplx<Thigh>(std::sin(Thigh(2 * n - x) * ang), -std::cos(Thigh(2 * n - x) * ang));\n      } else  // fourth quadrant\n      {\n        x -= 2 * n;\n        if (x < n) return cmplx<Thigh>(-std::sin(Thigh(x) * ang), -std::cos(Thigh(x) * ang));\n        return cmplx<Thigh>(-std::cos(Thigh(2 * n - x) * ang), -std::sin(Thigh(2 * n - x) * ang));\n      }\n    }\n  }\n\n public:\n  POCKETFFT_NOINLINE sincos_2pibyn(size_t n) : N(n) {\n    constexpr auto pi = 3.141592653589793238462643383279502884197L;\n    Thigh ang = Thigh(0.25L * pi / n);\n    size_t nval = (n + 2) / 2;\n    shift = 1;\n    while ((size_t(1) << shift) * (size_t(1) << shift) < nval) ++shift;\n    mask = (size_t(1) << shift) - 1;\n    v1.resize(mask + 1);\n    v1[0].Set(Thigh(1), Thigh(0));\n    for (size_t i = 1; i < v1.size(); ++i) v1[i] = calc(i, n, ang);\n    v2.resize((nval + mask) / (mask + 1));\n    v2[0].Set(Thigh(1), Thigh(0));\n    for (size_t i = 1; i < v2.size(); ++i) v2[i] = calc(i * (mask + 1), n, ang);\n  }\n\n  cmplx<T> operator[](size_t idx) const {\n    if (2 * idx <= N) {\n      auto x1 = v1[idx & mask], x2 = v2[idx >> shift];\n      return cmplx<T>(T(x1.r * x2.r - x1.i * x2.i), T(x1.r * x2.i + x1.i * x2.r));\n    }\n    idx = N - idx;\n    auto x1 = v1[idx & mask], x2 = v2[idx >> shift];\n    return cmplx<T>(T(x1.r * x2.r - x1.i * x2.i), -T(x1.r * x2.i + x1.i * x2.r));\n  }\n};\n\nstruct util  // hack to avoid duplicate symbols\n{\n  static POCKETFFT_NOINLINE size_t largest_prime_factor(size_t n) {\n    size_t res = 1;\n    while ((n & 1) == 0) {\n      res = 2;\n      n >>= 1;\n    }\n    for (size_t x = 3; x * x <= n; x += 2)\n      while ((n % x) == 0) {\n        res = x;\n        n /= x;\n      }\n    if (n > 1) res = n;\n    return res;\n  }\n\n  static POCKETFFT_NOINLINE double cost_guess(size_t n) {\n    constexpr double lfp = 1.1;  // penalty for non-hardcoded larger factors\n    size_t ni = n;\n    double result = 0.;\n    while ((n & 1) == 0) {\n      result += 2;\n      n >>= 1;\n    }\n    for (size_t x = 3; x * x <= n; x += 2)\n      while ((n % x) == 0) {\n        result += (x <= 5) ? double(x) : lfp * double(x);  // penalize larger prime factors\n        n /= x;\n      }\n    if (n > 1) result += (n <= 5) ? 
double(n) : lfp * double(n);\n    return result * double(ni);\n  }\n\n  /* returns the smallest composite of 2, 3, 5, 7 and 11 which is >= n */\n  static POCKETFFT_NOINLINE size_t good_size_cmplx(size_t n) {\n    if (n <= 12) return n;\n\n    size_t bestfac = 2 * n;\n    for (size_t f11 = 1; f11 < bestfac; f11 *= 11)\n      for (size_t f117 = f11; f117 < bestfac; f117 *= 7)\n        for (size_t f1175 = f117; f1175 < bestfac; f1175 *= 5) {\n          size_t x = f1175;\n          while (x < n) x *= 2;\n          for (;;) {\n            if (x < n)\n              x *= 3;\n            else if (x > n) {\n              if (x < bestfac) bestfac = x;\n              if (x & 1) break;\n              x >>= 1;\n            } else\n              return n;\n          }\n        }\n    return bestfac;\n  }\n\n  /* returns the smallest composite of 2, 3, 5 which is >= n */\n  static POCKETFFT_NOINLINE size_t good_size_real(size_t n) {\n    if (n <= 6) return n;\n\n    size_t bestfac = 2 * n;\n    for (size_t f5 = 1; f5 < bestfac; f5 *= 5) {\n      size_t x = f5;\n      while (x < n) x *= 2;\n      for (;;) {\n        if (x < n)\n          x *= 3;\n        else if (x > n) {\n          if (x < bestfac) bestfac = x;\n          if (x & 1) break;\n          x >>= 1;\n        } else\n          return n;\n      }\n    }\n    return bestfac;\n  }\n\n  static size_t prod(const shape_t& shape) {\n    size_t res = 1;\n    for (auto sz : shape) res *= sz;\n    return res;\n  }\n\n  static POCKETFFT_NOINLINE void sanity_check(const shape_t& shape, const stride_t& stride_in,\n                                              const stride_t& stride_out, bool inplace) {\n    auto ndim = shape.size();\n    if (ndim < 1) throw std::runtime_error(\"ndim must be >= 1\");\n    if ((stride_in.size() != ndim) || (stride_out.size() != ndim))\n      throw std::runtime_error(\"stride dimension mismatch\");\n    if (inplace && (stride_in != stride_out)) throw std::runtime_error(\"stride mismatch\");\n  }\n\n  static POCKETFFT_NOINLINE void sanity_check(const shape_t& shape, const stride_t& stride_in,\n                                              const stride_t& stride_out, bool inplace,\n                                              const shape_t& axes) {\n    sanity_check(shape, stride_in, stride_out, inplace);\n    auto ndim = shape.size();\n    shape_t tmp(ndim, 0);\n    for (auto ax : axes) {\n      if (ax >= ndim) throw std::invalid_argument(\"bad axis number\");\n      if (++tmp[ax] > 1) throw std::invalid_argument(\"axis specified repeatedly\");\n    }\n  }\n\n  static POCKETFFT_NOINLINE void sanity_check(const shape_t& shape, const stride_t& stride_in,\n                                              const stride_t& stride_out, bool inplace,\n                                              size_t axis) {\n    sanity_check(shape, stride_in, stride_out, inplace);\n    if (axis >= shape.size()) throw std::invalid_argument(\"bad axis number\");\n  }\n\n#ifdef POCKETFFT_NO_MULTITHREADING\n  static size_t thread_count(size_t /*nthreads*/, const shape_t& /*shape*/, size_t /*axis*/,\n                             size_t /*vlen*/) {\n    return 1;\n  }\n#else\n  static size_t thread_count(size_t nthreads, const shape_t& shape, size_t axis, size_t vlen) {\n    if (nthreads == 1) return 1;\n    size_t size = prod(shape);\n    size_t parallel = size / (shape[axis] * vlen);\n    if (shape[axis] < 1000) parallel /= 4;\n    size_t max_threads = nthreads == 0 ? 
std::thread::hardware_concurrency() : nthreads;\n    return std::max(size_t(1), std::min(parallel, max_threads));\n  }\n#endif\n};\n\nnamespace threading {\n\n#ifdef POCKETFFT_NO_MULTITHREADING\n\nconstexpr inline size_t thread_id() { return 0; }\nconstexpr inline size_t num_threads() { return 1; }\n\ntemplate<typename Func>\nvoid thread_map(size_t /* nthreads */, Func f) {\n  f();\n}\n\n#else\n\ninline size_t& thread_id() {\n  static thread_local size_t thread_id_ = 0;\n  return thread_id_;\n}\ninline size_t& num_threads() {\n  static thread_local size_t num_threads_ = 1;\n  return num_threads_;\n}\nstatic const size_t max_threads = std::max(1u, std::thread::hardware_concurrency());\n\nclass latch {\n  std::atomic<size_t> num_left_;\n  std::mutex mut_;\n  std::condition_variable completed_;\n  using lock_t = std::unique_lock<std::mutex>;\n\n public:\n  latch(size_t n) : num_left_(n) {}\n\n  void count_down() {\n    lock_t lock(mut_);\n    if (--num_left_) return;\n    completed_.notify_all();\n  }\n\n  void wait() {\n    lock_t lock(mut_);\n    completed_.wait(lock, [this] { return is_ready(); });\n  }\n  bool is_ready() { return num_left_ == 0; }\n};\n\ntemplate<typename T>\nclass concurrent_queue {\n  std::queue<T> q_;\n  std::mutex mut_;\n  std::atomic<size_t> size_;\n  using lock_t = std::lock_guard<std::mutex>;\n\n public:\n  void push(T val) {\n    lock_t lock(mut_);\n    ++size_;\n    q_.push(std::move(val));\n  }\n\n  bool try_pop(T& val) {\n    if (size_ == 0) return false;\n    lock_t lock(mut_);\n    // Queue might have been emptied while we acquired the lock\n    if (q_.empty()) return false;\n\n    val = std::move(q_.front());\n    --size_;\n    q_.pop();\n    return true;\n  }\n\n  bool empty() const { return size_ == 0; }\n};\n\n// C++ allocator with support for over-aligned types\ntemplate<typename T>\nstruct aligned_allocator {\n  using value_type = T;\n  template<class U>\n  aligned_allocator(const aligned_allocator<U>&) {}\n  aligned_allocator() = default;\n\n  T* allocate(size_t n) {\n    void* mem = aligned_alloc(alignof(T), n * sizeof(T));\n    return static_cast<T*>(mem);\n  }\n\n  void deallocate(T* p, size_t /*n*/) { aligned_dealloc(p); }\n};\n\nclass thread_pool {\n  // A reasonable guess, probably close enough for most hardware\n  static constexpr size_t cache_line_size = 64;\n  struct alignas(cache_line_size) worker {\n    std::thread thread;\n    std::condition_variable work_ready;\n    std::mutex mut;\n    std::atomic_flag busy_flag = ATOMIC_FLAG_INIT;\n    std::function<void()> work;\n\n    void worker_main(std::atomic<bool>& shutdown_flag, std::atomic<size_t>& unscheduled_tasks,\n                     concurrent_queue<std::function<void()>>& overflow_work) {\n      using lock_t = std::unique_lock<std::mutex>;\n      bool expect_work = true;\n      while (!shutdown_flag || expect_work) {\n        std::function<void()> local_work;\n        if (expect_work || unscheduled_tasks == 0) {\n          lock_t lock(mut);\n          // Wait until there is work to be executed\n          work_ready.wait(lock, [&] { return (work || shutdown_flag); });\n          local_work.swap(work);\n          expect_work = false;\n        }\n\n        bool marked_busy = false;\n        if (local_work) {\n          marked_busy = true;\n          local_work();\n        }\n\n        if (!overflow_work.empty()) {\n          if (!marked_busy && busy_flag.test_and_set()) {\n            expect_work = true;\n            continue;\n          }\n          marked_busy = true;\n\n          while 
(overflow_work.try_pop(local_work)) {\n            --unscheduled_tasks;\n            local_work();\n          }\n        }\n\n        if (marked_busy) busy_flag.clear();\n      }\n    }\n  };\n\n  concurrent_queue<std::function<void()>> overflow_work_;\n  std::mutex mut_;\n  std::vector<worker, aligned_allocator<worker>> workers_;\n  std::atomic<bool> shutdown_;\n  std::atomic<size_t> unscheduled_tasks_;\n  using lock_t = std::lock_guard<std::mutex>;\n\n  void create_threads() {\n    lock_t lock(mut_);\n    size_t nthreads = workers_.size();\n    for (size_t i = 0; i < nthreads; ++i) {\n      try {\n        auto* worker = &workers_[i];\n        worker->busy_flag.clear();\n        worker->work = nullptr;\n        worker->thread = std::thread(\n            [worker, this] { worker->worker_main(shutdown_, unscheduled_tasks_, overflow_work_); });\n      } catch (...) {\n        shutdown_locked();\n        throw;\n      }\n    }\n  }\n\n  void shutdown_locked() {\n    shutdown_ = true;\n    for (auto& worker : workers_) worker.work_ready.notify_all();\n\n    for (auto& worker : workers_)\n      if (worker.thread.joinable()) worker.thread.join();\n  }\n\n public:\n  explicit thread_pool(size_t nthreads) : workers_(nthreads) { create_threads(); }\n\n  thread_pool() : thread_pool(max_threads) {}\n\n  ~thread_pool() { shutdown(); }\n\n  void submit(std::function<void()> work) {\n    lock_t lock(mut_);\n    if (shutdown_) throw std::runtime_error(\"Work item submitted after shutdown\");\n\n    ++unscheduled_tasks_;\n\n    // First check for any idle workers and wake those\n    for (auto& worker : workers_)\n      if (!worker.busy_flag.test_and_set()) {\n        --unscheduled_tasks_;\n        {\n          lock_t lock(worker.mut);\n          worker.work = std::move(work);\n        }\n        worker.work_ready.notify_one();\n        return;\n      }\n\n    // If no workers were idle, push onto the overflow queue for later\n    overflow_work_.push(std::move(work));\n  }\n\n  void shutdown() {\n    lock_t lock(mut_);\n    shutdown_locked();\n  }\n\n  void restart() {\n    shutdown_ = false;\n    create_threads();\n  }\n};\n\ninline thread_pool& get_pool() {\n  static thread_pool pool;\n#ifdef POCKETFFT_PTHREADS\n  static std::once_flag f;\n  std::call_once(f, [] {\n    pthread_atfork(\n        +[] { get_pool().shutdown(); },  // prepare\n        +[] { get_pool().restart(); },   // parent\n        +[] { get_pool().restart(); }    // child\n    );\n  });\n#endif\n\n  return pool;\n}\n\n/** Map a function f over nthreads */\ntemplate<typename Func>\nvoid thread_map(size_t nthreads, Func f) {\n  if (nthreads == 0) nthreads = max_threads;\n\n  if (nthreads == 1) {\n    f();\n    return;\n  }\n\n  auto& pool = get_pool();\n  latch counter(nthreads);\n  std::exception_ptr ex;\n  std::mutex ex_mut;\n  for (size_t i = 0; i < nthreads; ++i) {\n    pool.submit([&f, &counter, &ex, &ex_mut, i, nthreads] {\n      thread_id() = i;\n      num_threads() = nthreads;\n      try {\n        f();\n      } catch (...) 
{\n        std::lock_guard<std::mutex> lock(ex_mut);\n        ex = std::current_exception();\n      }\n      counter.count_down();\n    });\n  }\n  counter.wait();\n  if (ex) std::rethrow_exception(ex);\n}\n\n#endif\n\n}  // namespace threading\n\n//\n// complex FFTPACK transforms\n//\n\ntemplate<typename T0>\nclass cfftp {\n private:\n  struct fctdata {\n    size_t fct;\n    cmplx<T0>*tw, *tws;\n  };\n\n  size_t length;\n  arr<cmplx<T0>> mem;\n  std::vector<fctdata> fact;\n\n  void add_factor(size_t factor) { fact.push_back({factor, nullptr, nullptr}); }\n\n  template<bool fwd, typename T>\n  void pass2(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const cmplx<T0>* POCKETFFT_RESTRICT wa) const {\n    auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& {\n      return ch[a + ido * (b + l1 * c)];\n    };\n    auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + 2 * c)];\n    };\n    auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; };\n\n    if (ido == 1)\n      for (size_t k = 0; k < l1; ++k) {\n        CH(0, k, 0) = CC(0, 0, k) + CC(0, 1, k);\n        CH(0, k, 1) = CC(0, 0, k) - CC(0, 1, k);\n      }\n    else\n      for (size_t k = 0; k < l1; ++k) {\n        CH(0, k, 0) = CC(0, 0, k) + CC(0, 1, k);\n        CH(0, k, 1) = CC(0, 0, k) - CC(0, 1, k);\n        for (size_t i = 1; i < ido; ++i) {\n          CH(i, k, 0) = CC(i, 0, k) + CC(i, 1, k);\n          special_mul<fwd>(CC(i, 0, k) - CC(i, 1, k), WA(0, i), CH(i, k, 1));\n        }\n      }\n  }\n\n#define POCKETFFT_PREP3(idx)                \\\n  T t0 = CC(idx, 0, k), t1, t2;             \\\n  PM(t1, t2, CC(idx, 1, k), CC(idx, 2, k)); \\\n  CH(idx, k, 0) = t0 + t1;\n#define POCKETFFT_PARTSTEP3a(u1, u2, twr, twi) \\\n  {                                            \\\n    T ca = t0 + t1 * twr;                      \\\n    T cb{-t2.i * twi, t2.r * twi};             \\\n    PM(CH(0, k, u1), CH(0, k, u2), ca, cb);    \\\n  }\n#define POCKETFFT_PARTSTEP3b(u1, u2, twr, twi)              \\\n  {                                                         \\\n    T ca = t0 + t1 * twr;                                   \\\n    T cb{-t2.i * twi, t2.r * twi};                          \\\n    special_mul<fwd>(ca + cb, WA(u1 - 1, i), CH(i, k, u1)); \\\n    special_mul<fwd>(ca - cb, WA(u2 - 1, i), CH(i, k, u2)); \\\n  }\n  template<bool fwd, typename T>\n  void pass3(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const cmplx<T0>* POCKETFFT_RESTRICT wa) const {\n    constexpr T0 tw1r = -0.5, tw1i = (fwd ? 
-1 : 1) * T0(0.8660254037844386467637231707529362L);\n\n    auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& {\n      return ch[a + ido * (b + l1 * c)];\n    };\n    auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + 3 * c)];\n    };\n    auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; };\n\n    if (ido == 1)\n      for (size_t k = 0; k < l1; ++k) {\n        POCKETFFT_PREP3(0)\n        POCKETFFT_PARTSTEP3a(1, 2, tw1r, tw1i)\n      }\n    else\n      for (size_t k = 0; k < l1; ++k) {\n        {\n          POCKETFFT_PREP3(0)\n          POCKETFFT_PARTSTEP3a(1, 2, tw1r, tw1i)\n        }\n        for (size_t i = 1; i < ido; ++i) {\n          POCKETFFT_PREP3(i)\n          POCKETFFT_PARTSTEP3b(1, 2, tw1r, tw1i)\n        }\n      }\n  }\n\n#undef POCKETFFT_PARTSTEP3b\n#undef POCKETFFT_PARTSTEP3a\n#undef POCKETFFT_PREP3\n\n  template<bool fwd, typename T>\n  void pass4(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const cmplx<T0>* POCKETFFT_RESTRICT wa) const {\n    auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& {\n      return ch[a + ido * (b + l1 * c)];\n    };\n    auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + 4 * c)];\n    };\n    auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; };\n\n    if (ido == 1)\n      for (size_t k = 0; k < l1; ++k) {\n        T t1, t2, t3, t4;\n        PM(t2, t1, CC(0, 0, k), CC(0, 2, k));\n        PM(t3, t4, CC(0, 1, k), CC(0, 3, k));\n        ROTX90<fwd>(t4);\n        PM(CH(0, k, 0), CH(0, k, 2), t2, t3);\n        PM(CH(0, k, 1), CH(0, k, 3), t1, t4);\n      }\n    else\n      for (size_t k = 0; k < l1; ++k) {\n        {\n          T t1, t2, t3, t4;\n          PM(t2, t1, CC(0, 0, k), CC(0, 2, k));\n          PM(t3, t4, CC(0, 1, k), CC(0, 3, k));\n          ROTX90<fwd>(t4);\n          PM(CH(0, k, 0), CH(0, k, 2), t2, t3);\n          PM(CH(0, k, 1), CH(0, k, 3), t1, t4);\n        }\n        for (size_t i = 1; i < ido; ++i) {\n          T t1, t2, t3, t4;\n          T cc0 = CC(i, 0, k), cc1 = CC(i, 1, k), cc2 = CC(i, 2, k), cc3 = CC(i, 3, k);\n          PM(t2, t1, cc0, cc2);\n          PM(t3, t4, cc1, cc3);\n          ROTX90<fwd>(t4);\n          CH(i, k, 0) = t2 + t3;\n          special_mul<fwd>(t1 + t4, WA(0, i), CH(i, k, 1));\n          special_mul<fwd>(t2 - t3, WA(1, i), CH(i, k, 2));\n          special_mul<fwd>(t1 - t4, WA(2, i), CH(i, k, 3));\n        }\n      }\n  }\n\n#define POCKETFFT_PREP5(idx)                \\\n  T t0 = CC(idx, 0, k), t1, t2, t3, t4;     \\\n  PM(t1, t4, CC(idx, 1, k), CC(idx, 4, k)); \\\n  PM(t2, t3, CC(idx, 2, k), CC(idx, 3, k)); \\\n  CH(idx, k, 0).r = t0.r + t1.r + t2.r;     \\\n  CH(idx, k, 0).i = t0.i + t1.i + t2.i;\n\n#define POCKETFFT_PARTSTEP5a(u1, u2, twar, twbr, twai, twbi) \\\n  {                                                          \\\n    T ca, cb;                                                \\\n    ca.r = t0.r + twar * t1.r + twbr * t2.r;                 \\\n    ca.i = t0.i + twar * t1.i + twbr * t2.i;                 \\\n    cb.i = twai * t4.r twbi * t3.r;                          \\\n    cb.r = -(twai * t4.i twbi * t3.i);                       \\\n    PM(CH(0, k, u1), CH(0, k, u2), ca, cb);                  \\\n  }\n\n#define POCKETFFT_PARTSTEP5b(u1, u2, twar, twbr, twai, twbi) \\\n  {                                                          \\\n    T ca, cb, da, db;                                    
    \\\n    ca.r = t0.r + twar * t1.r + twbr * t2.r;                 \\\n    ca.i = t0.i + twar * t1.i + twbr * t2.i;                 \\\n    cb.i = twai * t4.r twbi * t3.r;                          \\\n    cb.r = -(twai * t4.i twbi * t3.i);                       \\\n    special_mul<fwd>(ca + cb, WA(u1 - 1, i), CH(i, k, u1));  \\\n    special_mul<fwd>(ca - cb, WA(u2 - 1, i), CH(i, k, u2));  \\\n  }\n  template<bool fwd, typename T>\n  void pass5(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const cmplx<T0>* POCKETFFT_RESTRICT wa) const {\n    constexpr T0 tw1r = T0(0.3090169943749474241022934171828191L),\n                 tw1i = (fwd ? -1 : 1) * T0(0.9510565162951535721164393333793821L),\n                 tw2r = T0(-0.8090169943749474241022934171828191L),\n                 tw2i = (fwd ? -1 : 1) * T0(0.5877852522924731291687059546390728L);\n\n    auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& {\n      return ch[a + ido * (b + l1 * c)];\n    };\n    auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + 5 * c)];\n    };\n    auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; };\n\n    if (ido == 1)\n      for (size_t k = 0; k < l1; ++k) {\n        POCKETFFT_PREP5(0)\n        POCKETFFT_PARTSTEP5a(1, 4, tw1r, tw2r, +tw1i, +tw2i)\n            POCKETFFT_PARTSTEP5a(2, 3, tw2r, tw1r, +tw2i, -tw1i)\n      }\n    else\n      for (size_t k = 0; k < l1; ++k) {\n        {\n          POCKETFFT_PREP5(0)\n          POCKETFFT_PARTSTEP5a(1, 4, tw1r, tw2r, +tw1i, +tw2i)\n              POCKETFFT_PARTSTEP5a(2, 3, tw2r, tw1r, +tw2i, -tw1i)\n        }\n        for (size_t i = 1; i < ido; ++i) {\n          POCKETFFT_PREP5(i)\n          POCKETFFT_PARTSTEP5b(1, 4, tw1r, tw2r, +tw1i, +tw2i)\n              POCKETFFT_PARTSTEP5b(2, 3, tw2r, tw1r, +tw2i, -tw1i)\n        }\n      }\n  }\n\n#undef POCKETFFT_PARTSTEP5b\n#undef POCKETFFT_PARTSTEP5a\n#undef POCKETFFT_PREP5\n\n#define POCKETFFT_PREP7(idx)                    \\\n  T t1 = CC(idx, 0, k), t2, t3, t4, t5, t6, t7; \\\n  PM(t2, t7, CC(idx, 1, k), CC(idx, 6, k));     \\\n  PM(t3, t6, CC(idx, 2, k), CC(idx, 5, k));     \\\n  PM(t4, t5, CC(idx, 3, k), CC(idx, 4, k));     \\\n  CH(idx, k, 0).r = t1.r + t2.r + t3.r + t4.r;  \\\n  CH(idx, k, 0).i = t1.i + t2.i + t3.i + t4.i;\n\n#define POCKETFFT_PARTSTEP7a0(u1, u2, x1, x2, x3, y1, y2, y3, out1, out2) \\\n  {                                                                       \\\n    T ca, cb;                                                             \\\n    ca.r = t1.r + x1 * t2.r + x2 * t3.r + x3 * t4.r;                      \\\n    ca.i = t1.i + x1 * t2.i + x2 * t3.i + x3 * t4.i;                      \\\n    cb.i = y1 * t7.r y2 * t6.r y3 * t5.r;                                 \\\n    cb.r = -(y1 * t7.i y2 * t6.i y3 * t5.i);                              \\\n    PM(out1, out2, ca, cb);                                               \\\n  }\n#define POCKETFFT_PARTSTEP7a(u1, u2, x1, x2, x3, y1, y2, y3) \\\n  POCKETFFT_PARTSTEP7a0(u1, u2, x1, x2, x3, y1, y2, y3, CH(0, k, u1), CH(0, k, u2))\n#define POCKETFFT_PARTSTEP7(u1, u2, x1, x2, x3, y1, y2, y3)       \\\n  {                                                               \\\n    T da, db;                                                     \\\n    POCKETFFT_PARTSTEP7a0(u1, u2, x1, x2, x3, y1, y2, y3, da, db) \\\n        special_mul<fwd>(da, WA(u1 - 1, i), CH(i, k, u1));        \\\n    special_mul<fwd>(db, WA(u2 - 1, i), CH(i, k, u2));         
   \\\n  }\n\n  template<bool fwd, typename T>\n  void pass7(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const cmplx<T0>* POCKETFFT_RESTRICT wa) const {\n    constexpr T0 tw1r = T0(0.6234898018587335305250048840042398L),\n                 tw1i = (fwd ? -1 : 1) * T0(0.7818314824680298087084445266740578L),\n                 tw2r = T0(-0.2225209339563144042889025644967948L),\n                 tw2i = (fwd ? -1 : 1) * T0(0.9749279121818236070181316829939312L),\n                 tw3r = T0(-0.9009688679024191262361023195074451L),\n                 tw3i = (fwd ? -1 : 1) * T0(0.433883739117558120475768332848359L);\n\n    auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& {\n      return ch[a + ido * (b + l1 * c)];\n    };\n    auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + 7 * c)];\n    };\n    auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; };\n\n    if (ido == 1)\n      for (size_t k = 0; k < l1; ++k) {\n        POCKETFFT_PREP7(0)\n        POCKETFFT_PARTSTEP7a(1, 6, tw1r, tw2r, tw3r, +tw1i, +tw2i, +tw3i)\n            POCKETFFT_PARTSTEP7a(2, 5, tw2r, tw3r, tw1r, +tw2i, -tw3i, -tw1i)\n                POCKETFFT_PARTSTEP7a(3, 4, tw3r, tw1r, tw2r, +tw3i, -tw1i, +tw2i)\n      }\n    else\n      for (size_t k = 0; k < l1; ++k) {\n        {\n          POCKETFFT_PREP7(0)\n          POCKETFFT_PARTSTEP7a(1, 6, tw1r, tw2r, tw3r, +tw1i, +tw2i, +tw3i)\n              POCKETFFT_PARTSTEP7a(2, 5, tw2r, tw3r, tw1r, +tw2i, -tw3i, -tw1i)\n                  POCKETFFT_PARTSTEP7a(3, 4, tw3r, tw1r, tw2r, +tw3i, -tw1i, +tw2i)\n        }\n        for (size_t i = 1; i < ido; ++i) {\n          POCKETFFT_PREP7(i)\n          POCKETFFT_PARTSTEP7(1, 6, tw1r, tw2r, tw3r, +tw1i, +tw2i, +tw3i)\n          POCKETFFT_PARTSTEP7(2, 5, tw2r, tw3r, tw1r, +tw2i, -tw3i, -tw1i)\n          POCKETFFT_PARTSTEP7(3, 4, tw3r, tw1r, tw2r, +tw3i, -tw1i, +tw2i)\n        }\n      }\n  }\n\n#undef POCKETFFT_PARTSTEP7\n#undef POCKETFFT_PARTSTEP7a0\n#undef POCKETFFT_PARTSTEP7a\n#undef POCKETFFT_PREP7\n\n  template<bool fwd, typename T>\n  void ROTX45(T& a) const {\n    constexpr T0 hsqt2 = T0(0.707106781186547524400844362104849L);\n    if (fwd) {\n      auto tmp_ = a.r;\n      a.r = hsqt2 * (a.r + a.i);\n      a.i = hsqt2 * (a.i - tmp_);\n    } else {\n      auto tmp_ = a.r;\n      a.r = hsqt2 * (a.r - a.i);\n      a.i = hsqt2 * (a.i + tmp_);\n    }\n  }\n  template<bool fwd, typename T>\n  void ROTX135(T& a) const {\n    constexpr T0 hsqt2 = T0(0.707106781186547524400844362104849L);\n    if (fwd) {\n      auto tmp_ = a.r;\n      a.r = hsqt2 * (a.i - a.r);\n      a.i = hsqt2 * (-tmp_ - a.i);\n    } else {\n      auto tmp_ = a.r;\n      a.r = hsqt2 * (-a.r - a.i);\n      a.i = hsqt2 * (tmp_ - a.i);\n    }\n  }\n\n  template<bool fwd, typename T>\n  void pass8(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const cmplx<T0>* POCKETFFT_RESTRICT wa) const {\n    auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& {\n      return ch[a + ido * (b + l1 * c)];\n    };\n    auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + 8 * c)];\n    };\n    auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; };\n\n    if (ido == 1)\n      for (size_t k = 0; k < l1; ++k) {\n        T a0, a1, a2, a3, a4, a5, a6, a7;\n        PM(a1, a5, CC(0, 1, k), CC(0, 5, k));\n        PM(a3, a7, CC(0, 3, k), CC(0, 7, k));\n        
PMINPLACE(a1, a3);\n        ROTX90<fwd>(a3);\n\n        ROTX90<fwd>(a7);\n        PMINPLACE(a5, a7);\n        ROTX45<fwd>(a5);\n        ROTX135<fwd>(a7);\n\n        PM(a0, a4, CC(0, 0, k), CC(0, 4, k));\n        PM(a2, a6, CC(0, 2, k), CC(0, 6, k));\n        PM(CH(0, k, 0), CH(0, k, 4), a0 + a2, a1);\n        PM(CH(0, k, 2), CH(0, k, 6), a0 - a2, a3);\n        ROTX90<fwd>(a6);\n        PM(CH(0, k, 1), CH(0, k, 5), a4 + a6, a5);\n        PM(CH(0, k, 3), CH(0, k, 7), a4 - a6, a7);\n      }\n    else\n      for (size_t k = 0; k < l1; ++k) {\n        {\n          T a0, a1, a2, a3, a4, a5, a6, a7;\n          PM(a1, a5, CC(0, 1, k), CC(0, 5, k));\n          PM(a3, a7, CC(0, 3, k), CC(0, 7, k));\n          PMINPLACE(a1, a3);\n          ROTX90<fwd>(a3);\n\n          ROTX90<fwd>(a7);\n          PMINPLACE(a5, a7);\n          ROTX45<fwd>(a5);\n          ROTX135<fwd>(a7);\n\n          PM(a0, a4, CC(0, 0, k), CC(0, 4, k));\n          PM(a2, a6, CC(0, 2, k), CC(0, 6, k));\n          PM(CH(0, k, 0), CH(0, k, 4), a0 + a2, a1);\n          PM(CH(0, k, 2), CH(0, k, 6), a0 - a2, a3);\n          ROTX90<fwd>(a6);\n          PM(CH(0, k, 1), CH(0, k, 5), a4 + a6, a5);\n          PM(CH(0, k, 3), CH(0, k, 7), a4 - a6, a7);\n        }\n        for (size_t i = 1; i < ido; ++i) {\n          T a0, a1, a2, a3, a4, a5, a6, a7;\n          PM(a1, a5, CC(i, 1, k), CC(i, 5, k));\n          PM(a3, a7, CC(i, 3, k), CC(i, 7, k));\n          ROTX90<fwd>(a7);\n          PMINPLACE(a1, a3);\n          ROTX90<fwd>(a3);\n          PMINPLACE(a5, a7);\n          ROTX45<fwd>(a5);\n          ROTX135<fwd>(a7);\n          PM(a0, a4, CC(i, 0, k), CC(i, 4, k));\n          PM(a2, a6, CC(i, 2, k), CC(i, 6, k));\n          PMINPLACE(a0, a2);\n          CH(i, k, 0) = a0 + a1;\n          special_mul<fwd>(a0 - a1, WA(3, i), CH(i, k, 4));\n          special_mul<fwd>(a2 + a3, WA(1, i), CH(i, k, 2));\n          special_mul<fwd>(a2 - a3, WA(5, i), CH(i, k, 6));\n          ROTX90<fwd>(a6);\n          PMINPLACE(a4, a6);\n          special_mul<fwd>(a4 + a5, WA(0, i), CH(i, k, 1));\n          special_mul<fwd>(a4 - a5, WA(4, i), CH(i, k, 5));\n          special_mul<fwd>(a6 + a7, WA(2, i), CH(i, k, 3));\n          special_mul<fwd>(a6 - a7, WA(6, i), CH(i, k, 7));\n        }\n      }\n  }\n\n#define POCKETFFT_PREP11(idx)                                     \\\n  T t1 = CC(idx, 0, k), t2, t3, t4, t5, t6, t7, t8, t9, t10, t11; \\\n  PM(t2, t11, CC(idx, 1, k), CC(idx, 10, k));                     \\\n  PM(t3, t10, CC(idx, 2, k), CC(idx, 9, k));                      \\\n  PM(t4, t9, CC(idx, 3, k), CC(idx, 8, k));                       \\\n  PM(t5, t8, CC(idx, 4, k), CC(idx, 7, k));                       \\\n  PM(t6, t7, CC(idx, 5, k), CC(idx, 6, k));                       \\\n  CH(idx, k, 0).r = t1.r + t2.r + t3.r + t4.r + t5.r + t6.r;      \\\n  CH(idx, k, 0).i = t1.i + t2.i + t3.i + t4.i + t5.i + t6.i;\n\n#define POCKETFFT_PARTSTEP11a0(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5, out1, out2) \\\n  {                                                                                        \\\n    T ca = t1 + t2 * x1 + t3 * x2 + t4 * x3 + t5 * x4 + t6 * x5, cb;                       \\\n    cb.i = y1 * t11.r y2 * t10.r y3 * t9.r y4 * t8.r y5 * t7.r;                            \\\n    cb.r = -(y1 * t11.i y2 * t10.i y3 * t9.i y4 * t8.i y5 * t7.i);                         \\\n    PM(out1, out2, ca, cb);                                                                \\\n  }\n#define POCKETFFT_PARTSTEP11a(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5) \\\n  
POCKETFFT_PARTSTEP11a0(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5, CH(0, k, u1), CH(0, k, u2))\n#define POCKETFFT_PARTSTEP11(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5)       \\\n  {                                                                                \\\n    T da, db;                                                                      \\\n    POCKETFFT_PARTSTEP11a0(u1, u2, x1, x2, x3, x4, x5, y1, y2, y3, y4, y5, da, db) \\\n        special_mul<fwd>(da, WA(u1 - 1, i), CH(i, k, u1));                         \\\n    special_mul<fwd>(db, WA(u2 - 1, i), CH(i, k, u2));                             \\\n  }\n\n  template<bool fwd, typename T>\n  void pass11(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n              const cmplx<T0>* POCKETFFT_RESTRICT wa) const {\n    constexpr T0 tw1r = T0(0.8412535328311811688618116489193677L),\n                 tw1i = (fwd ? -1 : 1) * T0(0.5406408174555975821076359543186917L),\n                 tw2r = T0(0.4154150130018864255292741492296232L),\n                 tw2i = (fwd ? -1 : 1) * T0(0.9096319953545183714117153830790285L),\n                 tw3r = T0(-0.1423148382732851404437926686163697L),\n                 tw3i = (fwd ? -1 : 1) * T0(0.9898214418809327323760920377767188L),\n                 tw4r = T0(-0.6548607339452850640569250724662936L),\n                 tw4i = (fwd ? -1 : 1) * T0(0.7557495743542582837740358439723444L),\n                 tw5r = T0(-0.9594929736144973898903680570663277L),\n                 tw5i = (fwd ? -1 : 1) * T0(0.2817325568414296977114179153466169L);\n\n    auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& {\n      return ch[a + ido * (b + l1 * c)];\n    };\n    auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + 11 * c)];\n    };\n    auto WA = [wa, ido](size_t x, size_t i) { return wa[i - 1 + x * (ido - 1)]; };\n\n    if (ido == 1)\n      for (size_t k = 0; k < l1; ++k) {\n        POCKETFFT_PREP11(0)\n        POCKETFFT_PARTSTEP11a(1, 10, tw1r, tw2r, tw3r, tw4r, tw5r, +tw1i, +tw2i, +tw3i, +tw4i,\n                              +tw5i) POCKETFFT_PARTSTEP11a(2, 9, tw2r, tw4r, tw5r, tw3r, tw1r,\n                                                           +tw2i, +tw4i, -tw5i, -tw3i, -tw1i)\n            POCKETFFT_PARTSTEP11a(3, 8, tw3r, tw5r, tw2r, tw1r, tw4r, +tw3i, -tw5i, -tw2i, +tw1i,\n                                  +tw4i) POCKETFFT_PARTSTEP11a(4, 7, tw4r, tw3r, tw1r, tw5r, tw2r,\n                                                               +tw4i, -tw3i, +tw1i, +tw5i, -tw2i)\n                POCKETFFT_PARTSTEP11a(5, 6, tw5r, tw1r, tw4r, tw2r, tw3r, +tw5i, -tw1i, +tw4i,\n                                      -tw2i, +tw3i)\n      }\n    else\n      for (size_t k = 0; k < l1; ++k) {\n        {\n          POCKETFFT_PREP11(0)\n          POCKETFFT_PARTSTEP11a(1, 10, tw1r, tw2r, tw3r, tw4r, tw5r, +tw1i, +tw2i, +tw3i, +tw4i,\n                                +tw5i) POCKETFFT_PARTSTEP11a(2, 9, tw2r, tw4r, tw5r, tw3r, tw1r,\n                                                             +tw2i, +tw4i, -tw5i, -tw3i, -tw1i)\n              POCKETFFT_PARTSTEP11a(3, 8, tw3r, tw5r, tw2r, tw1r, tw4r, +tw3i, -tw5i, -tw2i, +tw1i,\n                                    +tw4i) POCKETFFT_PARTSTEP11a(4, 7, tw4r, tw3r, tw1r, tw5r, tw2r,\n                                                                 +tw4i, -tw3i, +tw1i, +tw5i, -tw2i)\n                  POCKETFFT_PARTSTEP11a(5, 6, tw5r, tw1r, tw4r, tw2r, tw3r, +tw5i, -tw1i, +tw4i,\n                     
                   -tw2i, +tw3i)\n        }\n        for (size_t i = 1; i < ido; ++i) {\n          POCKETFFT_PREP11(i)\n          POCKETFFT_PARTSTEP11(1, 10, tw1r, tw2r, tw3r, tw4r, tw5r, +tw1i, +tw2i, +tw3i, +tw4i,\n                               +tw5i)\n          POCKETFFT_PARTSTEP11(2, 9, tw2r, tw4r, tw5r, tw3r, tw1r, +tw2i, +tw4i, -tw5i, -tw3i,\n                               -tw1i)\n          POCKETFFT_PARTSTEP11(3, 8, tw3r, tw5r, tw2r, tw1r, tw4r, +tw3i, -tw5i, -tw2i, +tw1i,\n                               +tw4i)\n          POCKETFFT_PARTSTEP11(4, 7, tw4r, tw3r, tw1r, tw5r, tw2r, +tw4i, -tw3i, +tw1i, +tw5i,\n                               -tw2i)\n          POCKETFFT_PARTSTEP11(5, 6, tw5r, tw1r, tw4r, tw2r, tw3r, +tw5i, -tw1i, +tw4i, -tw2i,\n                               +tw3i)\n        }\n      }\n  }\n\n#undef POCKETFFT_PARTSTEP11\n#undef POCKETFFT_PARTSTEP11a0\n#undef POCKETFFT_PARTSTEP11a\n#undef POCKETFFT_PREP11\n\n  template<bool fwd, typename T>\n  void passg(size_t ido, size_t ip, size_t l1, T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const cmplx<T0>* POCKETFFT_RESTRICT wa,\n             const cmplx<T0>* POCKETFFT_RESTRICT csarr) const {\n    const size_t cdim = ip;\n    size_t ipph = (ip + 1) / 2;\n    size_t idl1 = ido * l1;\n\n    auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& {\n      return ch[a + ido * (b + l1 * c)];\n    };\n    auto CC = [cc, ido, cdim](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + cdim * c)];\n    };\n    auto CX = [cc, ido, l1](size_t a, size_t b, size_t c) -> T& {\n      return cc[a + ido * (b + l1 * c)];\n    };\n    auto CX2 = [cc, idl1](size_t a, size_t b) -> T& { return cc[a + idl1 * b]; };\n    auto CH2 = [ch, idl1](size_t a, size_t b) -> const T& { return ch[a + idl1 * b]; };\n\n    arr<cmplx<T0>> wal(ip);\n    wal[0] = cmplx<T0>(1., 0.);\n    for (size_t i = 1; i < ip; ++i) wal[i] = cmplx<T0>(csarr[i].r, fwd ? 
-csarr[i].i : csarr[i].i);\n\n    for (size_t k = 0; k < l1; ++k)\n      for (size_t i = 0; i < ido; ++i) CH(i, k, 0) = CC(i, 0, k);\n    for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc)\n      for (size_t k = 0; k < l1; ++k)\n        for (size_t i = 0; i < ido; ++i) PM(CH(i, k, j), CH(i, k, jc), CC(i, j, k), CC(i, jc, k));\n    for (size_t k = 0; k < l1; ++k)\n      for (size_t i = 0; i < ido; ++i) {\n        T tmp = CH(i, k, 0);\n        for (size_t j = 1; j < ipph; ++j) tmp += CH(i, k, j);\n        CX(i, k, 0) = tmp;\n      }\n    for (size_t l = 1, lc = ip - 1; l < ipph; ++l, --lc) {\n      // j=0\n      for (size_t ik = 0; ik < idl1; ++ik) {\n        CX2(ik, l).r = CH2(ik, 0).r + wal[l].r * CH2(ik, 1).r + wal[2 * l].r * CH2(ik, 2).r;\n        CX2(ik, l).i = CH2(ik, 0).i + wal[l].r * CH2(ik, 1).i + wal[2 * l].r * CH2(ik, 2).i;\n        CX2(ik, lc).r = -wal[l].i * CH2(ik, ip - 1).i - wal[2 * l].i * CH2(ik, ip - 2).i;\n        CX2(ik, lc).i = wal[l].i * CH2(ik, ip - 1).r + wal[2 * l].i * CH2(ik, ip - 2).r;\n      }\n\n      size_t iwal = 2 * l;\n      size_t j = 3, jc = ip - 3;\n      for (; j < ipph - 1; j += 2, jc -= 2) {\n        iwal += l;\n        if (iwal > ip) iwal -= ip;\n        cmplx<T0> xwal = wal[iwal];\n        iwal += l;\n        if (iwal > ip) iwal -= ip;\n        cmplx<T0> xwal2 = wal[iwal];\n        for (size_t ik = 0; ik < idl1; ++ik) {\n          CX2(ik, l).r += CH2(ik, j).r * xwal.r + CH2(ik, j + 1).r * xwal2.r;\n          CX2(ik, l).i += CH2(ik, j).i * xwal.r + CH2(ik, j + 1).i * xwal2.r;\n          CX2(ik, lc).r -= CH2(ik, jc).i * xwal.i + CH2(ik, jc - 1).i * xwal2.i;\n          CX2(ik, lc).i += CH2(ik, jc).r * xwal.i + CH2(ik, jc - 1).r * xwal2.i;\n        }\n      }\n      for (; j < ipph; ++j, --jc) {\n        iwal += l;\n        if (iwal > ip) iwal -= ip;\n        cmplx<T0> xwal = wal[iwal];\n        for (size_t ik = 0; ik < idl1; ++ik) {\n          CX2(ik, l).r += CH2(ik, j).r * xwal.r;\n          CX2(ik, l).i += CH2(ik, j).i * xwal.r;\n          CX2(ik, lc).r -= CH2(ik, jc).i * xwal.i;\n          CX2(ik, lc).i += CH2(ik, jc).r * xwal.i;\n        }\n      }\n    }\n\n    // shuffling and twiddling\n    if (ido == 1)\n      for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc)\n        for (size_t ik = 0; ik < idl1; ++ik) {\n          T t1 = CX2(ik, j), t2 = CX2(ik, jc);\n          PM(CX2(ik, j), CX2(ik, jc), t1, t2);\n        }\n    else {\n      for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc)\n        for (size_t k = 0; k < l1; ++k) {\n          T t1 = CX(0, k, j), t2 = CX(0, k, jc);\n          PM(CX(0, k, j), CX(0, k, jc), t1, t2);\n          for (size_t i = 1; i < ido; ++i) {\n            T x1, x2;\n            PM(x1, x2, CX(i, k, j), CX(i, k, jc));\n            size_t idij = (j - 1) * (ido - 1) + i - 1;\n            special_mul<fwd>(x1, wa[idij], CX(i, k, j));\n            idij = (jc - 1) * (ido - 1) + i - 1;\n            special_mul<fwd>(x2, wa[idij], CX(i, k, jc));\n          }\n        }\n    }\n  }\n\n  template<bool fwd, typename T>\n  void pass_all(T c[], T0 fct) const {\n    if (length == 1) {\n      c[0] *= fct;\n      return;\n    }\n    size_t l1 = 1;\n    arr<T> ch(length);\n    T *p1 = c, *p2 = ch.data();\n\n    for (size_t k1 = 0; k1 < fact.size(); k1++) {\n      size_t ip = fact[k1].fct;\n      size_t l2 = ip * l1;\n      size_t ido = length / l2;\n      if (ip == 4)\n        pass4<fwd>(ido, l1, p1, p2, fact[k1].tw);\n      else if (ip == 8)\n        pass8<fwd>(ido, l1, p1, p2, fact[k1].tw);\n      else if (ip == 2)\n        
pass2<fwd>(ido, l1, p1, p2, fact[k1].tw);\n      else if (ip == 3)\n        pass3<fwd>(ido, l1, p1, p2, fact[k1].tw);\n      else if (ip == 5)\n        pass5<fwd>(ido, l1, p1, p2, fact[k1].tw);\n      else if (ip == 7)\n        pass7<fwd>(ido, l1, p1, p2, fact[k1].tw);\n      else if (ip == 11)\n        pass11<fwd>(ido, l1, p1, p2, fact[k1].tw);\n      else {\n        passg<fwd>(ido, ip, l1, p1, p2, fact[k1].tw, fact[k1].tws);\n        std::swap(p1, p2);\n      }\n      std::swap(p1, p2);\n      l1 = l2;\n    }\n    if (p1 != c) {\n      if (fct != 1.)\n        for (size_t i = 0; i < length; ++i) c[i] = ch[i] * fct;\n      else\n        std::copy_n(p1, length, c);\n    } else if (fct != 1.)\n      for (size_t i = 0; i < length; ++i) c[i] *= fct;\n  }\n\n public:\n  template<typename T>\n  void exec(T c[], T0 fct, bool fwd) const {\n    fwd ? pass_all<true>(c, fct) : pass_all<false>(c, fct);\n  }\n\n private:\n  POCKETFFT_NOINLINE void factorize() {\n    size_t len = length;\n    while ((len & 7) == 0) {\n      add_factor(8);\n      len >>= 3;\n    }\n    while ((len & 3) == 0) {\n      add_factor(4);\n      len >>= 2;\n    }\n    if ((len & 1) == 0) {\n      len >>= 1;\n      // factor 2 should be at the front of the factor list\n      add_factor(2);\n      std::swap(fact[0].fct, fact.back().fct);\n    }\n    for (size_t divisor = 3; divisor * divisor <= len; divisor += 2)\n      while ((len % divisor) == 0) {\n        add_factor(divisor);\n        len /= divisor;\n      }\n    if (len > 1) add_factor(len);\n  }\n\n  size_t twsize() const {\n    size_t twsize = 0, l1 = 1;\n    for (size_t k = 0; k < fact.size(); ++k) {\n      size_t ip = fact[k].fct, ido = length / (l1 * ip);\n      twsize += (ip - 1) * (ido - 1);\n      if (ip > 11) twsize += ip;\n      l1 *= ip;\n    }\n    return twsize;\n  }\n\n  void comp_twiddle() {\n    sincos_2pibyn<T0> twiddle(length);\n    size_t l1 = 1;\n    size_t memofs = 0;\n    for (size_t k = 0; k < fact.size(); ++k) {\n      size_t ip = fact[k].fct, ido = length / (l1 * ip);\n      fact[k].tw = mem.data() + memofs;\n      memofs += (ip - 1) * (ido - 1);\n      for (size_t j = 1; j < ip; ++j)\n        for (size_t i = 1; i < ido; ++i)\n          fact[k].tw[(j - 1) * (ido - 1) + i - 1] = twiddle[j * l1 * i];\n      if (ip > 11) {\n        fact[k].tws = mem.data() + memofs;\n        memofs += ip;\n        for (size_t j = 0; j < ip; ++j) fact[k].tws[j] = twiddle[j * l1 * ido];\n      }\n      l1 *= ip;\n    }\n  }\n\n public:\n  POCKETFFT_NOINLINE cfftp(size_t length_) : length(length_) {\n    if (length == 0) throw std::runtime_error(\"zero-length FFT requested\");\n    if (length == 1) return;\n    factorize();\n    mem.resize(twsize());\n    comp_twiddle();\n  }\n};\n\n//\n// real-valued FFTPACK transforms\n//\n\ntemplate<typename T0>\nclass rfftp {\n private:\n  struct fctdata {\n    size_t fct;\n    T0 *tw, *tws;\n  };\n\n  size_t length;\n  arr<T0> mem;\n  std::vector<fctdata> fact;\n\n  void add_factor(size_t factor) { fact.push_back({factor, nullptr, nullptr}); }\n\n  /* (a+ib) = conj(c+id) * (e+if) */\n  template<typename T1, typename T2, typename T3>\n  inline void MULPM(T1& a, T1& b, T2 c, T2 d, T3 e, T3 f) const {\n    a = c * e + d * f;\n    b = c * f - d * e;\n  }\n\n  template<typename T>\n  void radf2(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const T0* POCKETFFT_RESTRICT wa) const {\n    auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };\n    auto CC = [cc, ido, 
l1](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + l1 * c)];\n    };\n    auto CH = [ch, ido](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + 2 * c)]; };\n\n    for (size_t k = 0; k < l1; k++) PM(CH(0, 0, k), CH(ido - 1, 1, k), CC(0, k, 0), CC(0, k, 1));\n    if ((ido & 1) == 0)\n      for (size_t k = 0; k < l1; k++) {\n        CH(0, 1, k) = -CC(ido - 1, k, 1);\n        CH(ido - 1, 0, k) = CC(ido - 1, k, 0);\n      }\n    if (ido <= 2) return;\n    for (size_t k = 0; k < l1; k++)\n      for (size_t i = 2; i < ido; i += 2) {\n        size_t ic = ido - i;\n        T tr2, ti2;\n        MULPM(tr2, ti2, WA(0, i - 2), WA(0, i - 1), CC(i - 1, k, 1), CC(i, k, 1));\n        PM(CH(i - 1, 0, k), CH(ic - 1, 1, k), CC(i - 1, k, 0), tr2);\n        PM(CH(i, 0, k), CH(ic, 1, k), ti2, CC(i, k, 0));\n      }\n  }\n\n// a2=a+b; b2=i*(b-a);\n#define POCKETFFT_REARRANGE(rx, ix, ry, iy)                      \\\n  {                                                              \\\n    auto t1 = rx + ry, t2 = ry - rx, t3 = ix + iy, t4 = ix - iy; \\\n    rx = t1;                                                     \\\n    ix = t3;                                                     \\\n    ry = t4;                                                     \\\n    iy = t2;                                                     \\\n  }\n\n  template<typename T>\n  void radf3(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const T0* POCKETFFT_RESTRICT wa) const {\n    constexpr T0 taur = -0.5, taui = T0(0.8660254037844386467637231707529362L);\n\n    auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };\n    auto CC = [cc, ido, l1](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + l1 * c)];\n    };\n    auto CH = [ch, ido](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + 3 * c)]; };\n\n    for (size_t k = 0; k < l1; k++) {\n      T cr2 = CC(0, k, 1) + CC(0, k, 2);\n      CH(0, 0, k) = CC(0, k, 0) + cr2;\n      CH(0, 2, k) = taui * (CC(0, k, 2) - CC(0, k, 1));\n      CH(ido - 1, 1, k) = CC(0, k, 0) + taur * cr2;\n    }\n    if (ido == 1) return;\n    for (size_t k = 0; k < l1; k++)\n      for (size_t i = 2; i < ido; i += 2) {\n        size_t ic = ido - i;\n        T di2, di3, dr2, dr3;\n        MULPM(dr2, di2, WA(0, i - 2), WA(0, i - 1), CC(i - 1, k, 1),\n              CC(i, k, 1));  // d2=conj(WA0)*CC1\n        MULPM(dr3, di3, WA(1, i - 2), WA(1, i - 1), CC(i - 1, k, 2),\n              CC(i, k, 2));  // d3=conj(WA1)*CC2\n        POCKETFFT_REARRANGE(dr2, di2, dr3, di3);\n        CH(i - 1, 0, k) = CC(i - 1, k, 0) + dr2;  // c add\n        CH(i, 0, k) = CC(i, k, 0) + di2;\n        T tr2 = CC(i - 1, k, 0) + taur * dr2;  // c add\n        T ti2 = CC(i, k, 0) + taur * di2;\n        T tr3 = taui * dr3;  // t3 = taui*i*(d3-d2)?\n        T ti3 = taui * di3;\n        PM(CH(i - 1, 2, k), CH(ic - 1, 1, k), tr2, tr3);  // PM(i) = t2+t3\n        PM(CH(i, 2, k), CH(ic, 1, k), ti3, ti2);          // PM(ic) = conj(t2-t3)\n      }\n  }\n\n  template<typename T>\n  void radf4(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const T0* POCKETFFT_RESTRICT wa) const {\n    constexpr T0 hsqt2 = T0(0.707106781186547524400844362104849L);\n\n    auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };\n    auto CC = [cc, ido, l1](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + l1 * c)];\n    };\n    auto CH = [ch, 
ido](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + 4 * c)]; };\n\n    for (size_t k = 0; k < l1; k++) {\n      T tr1, tr2;\n      PM(tr1, CH(0, 2, k), CC(0, k, 3), CC(0, k, 1));\n      PM(tr2, CH(ido - 1, 1, k), CC(0, k, 0), CC(0, k, 2));\n      PM(CH(0, 0, k), CH(ido - 1, 3, k), tr2, tr1);\n    }\n    if ((ido & 1) == 0)\n      for (size_t k = 0; k < l1; k++) {\n        T ti1 = -hsqt2 * (CC(ido - 1, k, 1) + CC(ido - 1, k, 3));\n        T tr1 = hsqt2 * (CC(ido - 1, k, 1) - CC(ido - 1, k, 3));\n        PM(CH(ido - 1, 0, k), CH(ido - 1, 2, k), CC(ido - 1, k, 0), tr1);\n        PM(CH(0, 3, k), CH(0, 1, k), ti1, CC(ido - 1, k, 2));\n      }\n    if (ido <= 2) return;\n    for (size_t k = 0; k < l1; k++)\n      for (size_t i = 2; i < ido; i += 2) {\n        size_t ic = ido - i;\n        T ci2, ci3, ci4, cr2, cr3, cr4, ti1, ti2, ti3, ti4, tr1, tr2, tr3, tr4;\n        MULPM(cr2, ci2, WA(0, i - 2), WA(0, i - 1), CC(i - 1, k, 1), CC(i, k, 1));\n        MULPM(cr3, ci3, WA(1, i - 2), WA(1, i - 1), CC(i - 1, k, 2), CC(i, k, 2));\n        MULPM(cr4, ci4, WA(2, i - 2), WA(2, i - 1), CC(i - 1, k, 3), CC(i, k, 3));\n        PM(tr1, tr4, cr4, cr2);\n        PM(ti1, ti4, ci2, ci4);\n        PM(tr2, tr3, CC(i - 1, k, 0), cr3);\n        PM(ti2, ti3, CC(i, k, 0), ci3);\n        PM(CH(i - 1, 0, k), CH(ic - 1, 3, k), tr2, tr1);\n        PM(CH(i, 0, k), CH(ic, 3, k), ti1, ti2);\n        PM(CH(i - 1, 2, k), CH(ic - 1, 1, k), tr3, ti4);\n        PM(CH(i, 2, k), CH(ic, 1, k), tr4, ti3);\n      }\n  }\n\n  template<typename T>\n  void radf5(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const T0* POCKETFFT_RESTRICT wa) const {\n    constexpr T0 tr11 = T0(0.3090169943749474241022934171828191L),\n                 ti11 = T0(0.9510565162951535721164393333793821L),\n                 tr12 = T0(-0.8090169943749474241022934171828191L),\n                 ti12 = T0(0.5877852522924731291687059546390728L);\n\n    auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };\n    auto CC = [cc, ido, l1](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + l1 * c)];\n    };\n    auto CH = [ch, ido](size_t a, size_t b, size_t c) -> T& { return ch[a + ido * (b + 5 * c)]; };\n\n    for (size_t k = 0; k < l1; k++) {\n      T cr2, cr3, ci4, ci5;\n      PM(cr2, ci5, CC(0, k, 4), CC(0, k, 1));\n      PM(cr3, ci4, CC(0, k, 3), CC(0, k, 2));\n      CH(0, 0, k) = CC(0, k, 0) + cr2 + cr3;\n      CH(ido - 1, 1, k) = CC(0, k, 0) + tr11 * cr2 + tr12 * cr3;\n      CH(0, 2, k) = ti11 * ci5 + ti12 * ci4;\n      CH(ido - 1, 3, k) = CC(0, k, 0) + tr12 * cr2 + tr11 * cr3;\n      CH(0, 4, k) = ti12 * ci5 - ti11 * ci4;\n    }\n    if (ido == 1) return;\n    for (size_t k = 0; k < l1; ++k)\n      for (size_t i = 2, ic = ido - 2; i < ido; i += 2, ic -= 2) {\n        T di2, di3, di4, di5, dr2, dr3, dr4, dr5;\n        MULPM(dr2, di2, WA(0, i - 2), WA(0, i - 1), CC(i - 1, k, 1), CC(i, k, 1));\n        MULPM(dr3, di3, WA(1, i - 2), WA(1, i - 1), CC(i - 1, k, 2), CC(i, k, 2));\n        MULPM(dr4, di4, WA(2, i - 2), WA(2, i - 1), CC(i - 1, k, 3), CC(i, k, 3));\n        MULPM(dr5, di5, WA(3, i - 2), WA(3, i - 1), CC(i - 1, k, 4), CC(i, k, 4));\n        POCKETFFT_REARRANGE(dr2, di2, dr5, di5);\n        POCKETFFT_REARRANGE(dr3, di3, dr4, di4);\n        CH(i - 1, 0, k) = CC(i - 1, k, 0) + dr2 + dr3;\n        CH(i, 0, k) = CC(i, k, 0) + di2 + di3;\n        T tr2 = CC(i - 1, k, 0) + tr11 * dr2 + tr12 * dr3;\n        T ti2 = CC(i, k, 0) + tr11 * di2 + tr12 * di3;\n    
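    // tr3/ti3: same combination as tr2/ti2, with the coefficients tr11 and tr12 swapped\n    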
    T tr3 = CC(i - 1, k, 0) + tr12 * dr2 + tr11 * dr3;\n        T ti3 = CC(i, k, 0) + tr12 * di2 + tr11 * di3;\n        T tr5 = ti11 * dr5 + ti12 * dr4;\n        T ti5 = ti11 * di5 + ti12 * di4;\n        T tr4 = ti12 * dr5 - ti11 * dr4;\n        T ti4 = ti12 * di5 - ti11 * di4;\n        PM(CH(i - 1, 2, k), CH(ic - 1, 1, k), tr2, tr5);\n        PM(CH(i, 2, k), CH(ic, 1, k), ti5, ti2);\n        PM(CH(i - 1, 4, k), CH(ic - 1, 3, k), tr3, tr4);\n        PM(CH(i, 4, k), CH(ic, 3, k), ti4, ti3);\n      }\n  }\n\n#undef POCKETFFT_REARRANGE\n\n  template<typename T>\n  void radfg(size_t ido, size_t ip, size_t l1, T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const T0* POCKETFFT_RESTRICT wa, const T0* POCKETFFT_RESTRICT csarr) const {\n    const size_t cdim = ip;\n    size_t ipph = (ip + 1) / 2;\n    size_t idl1 = ido * l1;\n\n    auto CC = [cc, ido, cdim](size_t a, size_t b, size_t c) -> T& {\n      return cc[a + ido * (b + cdim * c)];\n    };\n    auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> const T& {\n      return ch[a + ido * (b + l1 * c)];\n    };\n    auto C1 = [cc, ido, l1](size_t a, size_t b, size_t c) -> T& {\n      return cc[a + ido * (b + l1 * c)];\n    };\n    auto C2 = [cc, idl1](size_t a, size_t b) -> T& { return cc[a + idl1 * b]; };\n    auto CH2 = [ch, idl1](size_t a, size_t b) -> T& { return ch[a + idl1 * b]; };\n\n    if (ido > 1) {\n      for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc)  // 114\n      {\n        size_t is = (j - 1) * (ido - 1), is2 = (jc - 1) * (ido - 1);\n        for (size_t k = 0; k < l1; ++k)  // 113\n        {\n          size_t idij = is;\n          size_t idij2 = is2;\n          for (size_t i = 1; i <= ido - 2; i += 2)  // 112\n          {\n            T t1 = C1(i, k, j), t2 = C1(i + 1, k, j), t3 = C1(i, k, jc), t4 = C1(i + 1, k, jc);\n            T x1 = wa[idij] * t1 + wa[idij + 1] * t2, x2 = wa[idij] * t2 - wa[idij + 1] * t1,\n              x3 = wa[idij2] * t3 + wa[idij2 + 1] * t4, x4 = wa[idij2] * t4 - wa[idij2 + 1] * t3;\n            PM(C1(i, k, j), C1(i + 1, k, jc), x3, x1);\n            PM(C1(i + 1, k, j), C1(i, k, jc), x2, x4);\n            idij += 2;\n            idij2 += 2;\n          }\n        }\n      }\n    }\n\n    for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc)  // 123\n      for (size_t k = 0; k < l1; ++k)                     // 122\n        MPINPLACE(C1(0, k, jc), C1(0, k, j));\n\n    // everything in C\n    // memset(ch,0,ip*l1*ido*sizeof(double));\n\n    for (size_t l = 1, lc = ip - 1; l < ipph; ++l, --lc)  // 127\n    {\n      for (size_t ik = 0; ik < idl1; ++ik)  // 124\n      {\n        CH2(ik, l) = C2(ik, 0) + csarr[2 * l] * C2(ik, 1) + csarr[4 * l] * C2(ik, 2);\n        CH2(ik, lc) = csarr[2 * l + 1] * C2(ik, ip - 1) + csarr[4 * l + 1] * C2(ik, ip - 2);\n      }\n      size_t iang = 2 * l;\n      size_t j = 3, jc = ip - 3;\n      for (; j < ipph - 3; j += 4, jc -= 4)  // 126\n      {\n        iang += l;\n        if (iang >= ip) iang -= ip;\n        T0 ar1 = csarr[2 * iang], ai1 = csarr[2 * iang + 1];\n        iang += l;\n        if (iang >= ip) iang -= ip;\n        T0 ar2 = csarr[2 * iang], ai2 = csarr[2 * iang + 1];\n        iang += l;\n        if (iang >= ip) iang -= ip;\n        T0 ar3 = csarr[2 * iang], ai3 = csarr[2 * iang + 1];\n        iang += l;\n        if (iang >= ip) iang -= ip;\n        T0 ar4 = csarr[2 * iang], ai4 = csarr[2 * iang + 1];\n        for (size_t ik = 0; ik < idl1; ++ik)  // 125\n        {\n          CH2(ik, l) +=\n              ar1 * C2(ik, j) + ar2 * C2(ik, j + 1) + 
ar3 * C2(ik, j + 2) + ar4 * C2(ik, j + 3);\n          CH2(ik, lc) +=\n              ai1 * C2(ik, jc) + ai2 * C2(ik, jc - 1) + ai3 * C2(ik, jc - 2) + ai4 * C2(ik, jc - 3);\n        }\n      }\n      for (; j < ipph - 1; j += 2, jc -= 2)  // 126\n      {\n        iang += l;\n        if (iang >= ip) iang -= ip;\n        T0 ar1 = csarr[2 * iang], ai1 = csarr[2 * iang + 1];\n        iang += l;\n        if (iang >= ip) iang -= ip;\n        T0 ar2 = csarr[2 * iang], ai2 = csarr[2 * iang + 1];\n        for (size_t ik = 0; ik < idl1; ++ik)  // 125\n        {\n          CH2(ik, l) += ar1 * C2(ik, j) + ar2 * C2(ik, j + 1);\n          CH2(ik, lc) += ai1 * C2(ik, jc) + ai2 * C2(ik, jc - 1);\n        }\n      }\n      for (; j < ipph; ++j, --jc)  // 126\n      {\n        iang += l;\n        if (iang >= ip) iang -= ip;\n        T0 ar = csarr[2 * iang], ai = csarr[2 * iang + 1];\n        for (size_t ik = 0; ik < idl1; ++ik)  // 125\n        {\n          CH2(ik, l) += ar * C2(ik, j);\n          CH2(ik, lc) += ai * C2(ik, jc);\n        }\n      }\n    }\n    for (size_t ik = 0; ik < idl1; ++ik)  // 101\n      CH2(ik, 0) = C2(ik, 0);\n    for (size_t j = 1; j < ipph; ++j)       // 129\n      for (size_t ik = 0; ik < idl1; ++ik)  // 128\n        CH2(ik, 0) += C2(ik, j);\n\n    // everything in CH at this point!\n    // memset(cc,0,ip*l1*ido*sizeof(double));\n\n    for (size_t k = 0; k < l1; ++k)     // 131\n      for (size_t i = 0; i < ido; ++i)  // 130\n        CC(i, 0, k) = CH(i, k, 0);\n\n    for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc)  // 137\n    {\n      size_t j2 = 2 * j - 1;\n      for (size_t k = 0; k < l1; ++k)  // 136\n      {\n        CC(ido - 1, j2, k) = CH(0, k, j);\n        CC(0, j2 + 1, k) = CH(0, k, jc);\n      }\n    }\n\n    if (ido == 1) return;\n\n    for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc)  // 140\n    {\n      size_t j2 = 2 * j - 1;\n      for (size_t k = 0; k < l1; ++k)                                        // 139\n        for (size_t i = 1, ic = ido - i - 2; i <= ido - 2; i += 2, ic -= 2)  // 138\n        {\n          CC(i, j2 + 1, k) = CH(i, k, j) + CH(i, k, jc);\n          CC(ic, j2, k) = CH(i, k, j) - CH(i, k, jc);\n          CC(i + 1, j2 + 1, k) = CH(i + 1, k, j) + CH(i + 1, k, jc);\n          CC(ic + 1, j2, k) = CH(i + 1, k, jc) - CH(i + 1, k, j);\n        }\n    }\n  }\n\n  template<typename T>\n  void radb2(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const T0* POCKETFFT_RESTRICT wa) const {\n    auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };\n    auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + 2 * c)];\n    };\n    auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& {\n      return ch[a + ido * (b + l1 * c)];\n    };\n\n    for (size_t k = 0; k < l1; k++) PM(CH(0, k, 0), CH(0, k, 1), CC(0, 0, k), CC(ido - 1, 1, k));\n    if ((ido & 1) == 0)\n      for (size_t k = 0; k < l1; k++) {\n        CH(ido - 1, k, 0) = 2 * CC(ido - 1, 0, k);\n        CH(ido - 1, k, 1) = -2 * CC(0, 1, k);\n      }\n    if (ido <= 2) return;\n    for (size_t k = 0; k < l1; ++k)\n      for (size_t i = 2; i < ido; i += 2) {\n        size_t ic = ido - i;\n        T ti2, tr2;\n        PM(CH(i - 1, k, 0), tr2, CC(i - 1, 0, k), CC(ic - 1, 1, k));\n        PM(ti2, CH(i, k, 0), CC(i, 0, k), CC(ic, 1, k));\n        MULPM(CH(i, k, 1), CH(i - 1, k, 1), WA(0, i - 2), WA(0, i - 1), ti2, tr2);\n      }\n  }\n\n  template<typename T>\n  void radb3(size_t ido, size_t 
l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const T0* POCKETFFT_RESTRICT wa) const {\n    constexpr T0 taur = -0.5, taui = T0(0.8660254037844386467637231707529362L);\n\n    auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };\n    auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + 3 * c)];\n    };\n    auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& {\n      return ch[a + ido * (b + l1 * c)];\n    };\n\n    for (size_t k = 0; k < l1; k++) {\n      T tr2 = 2 * CC(ido - 1, 1, k);\n      T cr2 = CC(0, 0, k) + taur * tr2;\n      CH(0, k, 0) = CC(0, 0, k) + tr2;\n      T ci3 = 2 * taui * CC(0, 2, k);\n      PM(CH(0, k, 2), CH(0, k, 1), cr2, ci3);\n    }\n    if (ido == 1) return;\n    for (size_t k = 0; k < l1; k++)\n      for (size_t i = 2, ic = ido - 2; i < ido; i += 2, ic -= 2) {\n        T tr2 = CC(i - 1, 2, k) + CC(ic - 1, 1, k);  // t2=CC(I) + conj(CC(ic))\n        T ti2 = CC(i, 2, k) - CC(ic, 1, k);\n        T cr2 = CC(i - 1, 0, k) + taur * tr2;  // c2=CC +taur*t2\n        T ci2 = CC(i, 0, k) + taur * ti2;\n        CH(i - 1, k, 0) = CC(i - 1, 0, k) + tr2;  // CH=CC+t2\n        CH(i, k, 0) = CC(i, 0, k) + ti2;\n        T cr3 = taui * (CC(i - 1, 2, k) - CC(ic - 1, 1, k));  // c3=taui*(CC(i)-conj(CC(ic)))\n        T ci3 = taui * (CC(i, 2, k) + CC(ic, 1, k));\n        T di2, di3, dr2, dr3;\n        PM(dr3, dr2, cr2, ci3);  // d2= (cr2-ci3, ci2+cr3) = c2+i*c3\n        PM(di2, di3, ci2, cr3);  // d3= (cr2+ci3, ci2-cr3) = c2-i*c3\n        MULPM(CH(i, k, 1), CH(i - 1, k, 1), WA(0, i - 2), WA(0, i - 1), di2, dr2);  // ch = WA*d2\n        MULPM(CH(i, k, 2), CH(i - 1, k, 2), WA(1, i - 2), WA(1, i - 1), di3, dr3);\n      }\n  }\n\n  template<typename T>\n  void radb4(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const T0* POCKETFFT_RESTRICT wa) const {\n    constexpr T0 sqrt2 = T0(1.414213562373095048801688724209698L);\n\n    auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };\n    auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + 4 * c)];\n    };\n    auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& {\n      return ch[a + ido * (b + l1 * c)];\n    };\n\n    for (size_t k = 0; k < l1; k++) {\n      T tr1, tr2;\n      PM(tr2, tr1, CC(0, 0, k), CC(ido - 1, 3, k));\n      T tr3 = 2 * CC(ido - 1, 1, k);\n      T tr4 = 2 * CC(0, 2, k);\n      PM(CH(0, k, 0), CH(0, k, 2), tr2, tr3);\n      PM(CH(0, k, 3), CH(0, k, 1), tr1, tr4);\n    }\n    if ((ido & 1) == 0)\n      for (size_t k = 0; k < l1; k++) {\n        T tr1, tr2, ti1, ti2;\n        PM(ti1, ti2, CC(0, 3, k), CC(0, 1, k));\n        PM(tr2, tr1, CC(ido - 1, 0, k), CC(ido - 1, 2, k));\n        CH(ido - 1, k, 0) = tr2 + tr2;\n        CH(ido - 1, k, 1) = sqrt2 * (tr1 - ti1);\n        CH(ido - 1, k, 2) = ti2 + ti2;\n        CH(ido - 1, k, 3) = -sqrt2 * (tr1 + ti1);\n      }\n    if (ido <= 2) return;\n    for (size_t k = 0; k < l1; ++k)\n      for (size_t i = 2; i < ido; i += 2) {\n        T ci2, ci3, ci4, cr2, cr3, cr4, ti1, ti2, ti3, ti4, tr1, tr2, tr3, tr4;\n        size_t ic = ido - i;\n        PM(tr2, tr1, CC(i - 1, 0, k), CC(ic - 1, 3, k));\n        PM(ti1, ti2, CC(i, 0, k), CC(ic, 3, k));\n        PM(tr4, ti3, CC(i, 2, k), CC(ic, 1, k));\n        PM(tr3, ti4, CC(i - 1, 2, k), CC(ic - 1, 1, k));\n        PM(CH(i - 1, k, 0), cr3, tr2, tr3);\n        PM(CH(i, k, 0), ci3, ti2, ti3);\n        PM(cr4, cr2, tr1, 
tr4);\n        PM(ci2, ci4, ti1, ti4);\n        MULPM(CH(i, k, 1), CH(i - 1, k, 1), WA(0, i - 2), WA(0, i - 1), ci2, cr2);\n        MULPM(CH(i, k, 2), CH(i - 1, k, 2), WA(1, i - 2), WA(1, i - 1), ci3, cr3);\n        MULPM(CH(i, k, 3), CH(i - 1, k, 3), WA(2, i - 2), WA(2, i - 1), ci4, cr4);\n      }\n  }\n\n  template<typename T>\n  void radb5(size_t ido, size_t l1, const T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const T0* POCKETFFT_RESTRICT wa) const {\n    constexpr T0 tr11 = T0(0.3090169943749474241022934171828191L),\n                 ti11 = T0(0.9510565162951535721164393333793821L),\n                 tr12 = T0(-0.8090169943749474241022934171828191L),\n                 ti12 = T0(0.5877852522924731291687059546390728L);\n\n    auto WA = [wa, ido](size_t x, size_t i) { return wa[i + x * (ido - 1)]; };\n    auto CC = [cc, ido](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + 5 * c)];\n    };\n    auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& {\n      return ch[a + ido * (b + l1 * c)];\n    };\n\n    for (size_t k = 0; k < l1; k++) {\n      T ti5 = CC(0, 2, k) + CC(0, 2, k);\n      T ti4 = CC(0, 4, k) + CC(0, 4, k);\n      T tr2 = CC(ido - 1, 1, k) + CC(ido - 1, 1, k);\n      T tr3 = CC(ido - 1, 3, k) + CC(ido - 1, 3, k);\n      CH(0, k, 0) = CC(0, 0, k) + tr2 + tr3;\n      T cr2 = CC(0, 0, k) + tr11 * tr2 + tr12 * tr3;\n      T cr3 = CC(0, 0, k) + tr12 * tr2 + tr11 * tr3;\n      T ci4, ci5;\n      MULPM(ci5, ci4, ti5, ti4, ti11, ti12);\n      PM(CH(0, k, 4), CH(0, k, 1), cr2, ci5);\n      PM(CH(0, k, 3), CH(0, k, 2), cr3, ci4);\n    }\n    if (ido == 1) return;\n    for (size_t k = 0; k < l1; ++k)\n      for (size_t i = 2, ic = ido - 2; i < ido; i += 2, ic -= 2) {\n        T tr2, tr3, tr4, tr5, ti2, ti3, ti4, ti5;\n        PM(tr2, tr5, CC(i - 1, 2, k), CC(ic - 1, 1, k));\n        PM(ti5, ti2, CC(i, 2, k), CC(ic, 1, k));\n        PM(tr3, tr4, CC(i - 1, 4, k), CC(ic - 1, 3, k));\n        PM(ti4, ti3, CC(i, 4, k), CC(ic, 3, k));\n        CH(i - 1, k, 0) = CC(i - 1, 0, k) + tr2 + tr3;\n        CH(i, k, 0) = CC(i, 0, k) + ti2 + ti3;\n        T cr2 = CC(i - 1, 0, k) + tr11 * tr2 + tr12 * tr3;\n        T ci2 = CC(i, 0, k) + tr11 * ti2 + tr12 * ti3;\n        T cr3 = CC(i - 1, 0, k) + tr12 * tr2 + tr11 * tr3;\n        T ci3 = CC(i, 0, k) + tr12 * ti2 + tr11 * ti3;\n        T ci4, ci5, cr5, cr4;\n        MULPM(cr5, cr4, tr5, tr4, ti11, ti12);\n        MULPM(ci5, ci4, ti5, ti4, ti11, ti12);\n        T dr2, dr3, dr4, dr5, di2, di3, di4, di5;\n        PM(dr4, dr3, cr3, ci4);\n        PM(di3, di4, ci3, cr4);\n        PM(dr5, dr2, cr2, ci5);\n        PM(di2, di5, ci2, cr5);\n        MULPM(CH(i, k, 1), CH(i - 1, k, 1), WA(0, i - 2), WA(0, i - 1), di2, dr2);\n        MULPM(CH(i, k, 2), CH(i - 1, k, 2), WA(1, i - 2), WA(1, i - 1), di3, dr3);\n        MULPM(CH(i, k, 3), CH(i - 1, k, 3), WA(2, i - 2), WA(2, i - 1), di4, dr4);\n        MULPM(CH(i, k, 4), CH(i - 1, k, 4), WA(3, i - 2), WA(3, i - 1), di5, dr5);\n      }\n  }\n\n  template<typename T>\n  void radbg(size_t ido, size_t ip, size_t l1, T* POCKETFFT_RESTRICT cc, T* POCKETFFT_RESTRICT ch,\n             const T0* POCKETFFT_RESTRICT wa, const T0* POCKETFFT_RESTRICT csarr) const {\n    const size_t cdim = ip;\n    size_t ipph = (ip + 1) / 2;\n    size_t idl1 = ido * l1;\n\n    auto CC = [cc, ido, cdim](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + cdim * c)];\n    };\n    auto CH = [ch, ido, l1](size_t a, size_t b, size_t c) -> T& {\n      return ch[a + ido * (b + 
l1 * c)];\n    };\n    auto C1 = [cc, ido, l1](size_t a, size_t b, size_t c) -> const T& {\n      return cc[a + ido * (b + l1 * c)];\n    };\n    auto C2 = [cc, idl1](size_t a, size_t b) -> T& { return cc[a + idl1 * b]; };\n    auto CH2 = [ch, idl1](size_t a, size_t b) -> T& { return ch[a + idl1 * b]; };\n\n    for (size_t k = 0; k < l1; ++k)     // 102\n      for (size_t i = 0; i < ido; ++i)  // 101\n        CH(i, k, 0) = CC(i, 0, k);\n    for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc)  // 108\n    {\n      size_t j2 = 2 * j - 1;\n      for (size_t k = 0; k < l1; ++k) {\n        CH(0, k, j) = 2 * CC(ido - 1, j2, k);\n        CH(0, k, jc) = 2 * CC(0, j2 + 1, k);\n      }\n    }\n\n    if (ido != 1) {\n      for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc)  // 111\n      {\n        size_t j2 = 2 * j - 1;\n        for (size_t k = 0; k < l1; ++k)\n          for (size_t i = 1, ic = ido - i - 2; i <= ido - 2; i += 2, ic -= 2)  // 109\n          {\n            CH(i, k, j) = CC(i, j2 + 1, k) + CC(ic, j2, k);\n            CH(i, k, jc) = CC(i, j2 + 1, k) - CC(ic, j2, k);\n            CH(i + 1, k, j) = CC(i + 1, j2 + 1, k) - CC(ic + 1, j2, k);\n            CH(i + 1, k, jc) = CC(i + 1, j2 + 1, k) + CC(ic + 1, j2, k);\n          }\n      }\n    }\n    for (size_t l = 1, lc = ip - 1; l < ipph; ++l, --lc) {\n      for (size_t ik = 0; ik < idl1; ++ik) {\n        C2(ik, l) = CH2(ik, 0) + csarr[2 * l] * CH2(ik, 1) + csarr[4 * l] * CH2(ik, 2);\n        C2(ik, lc) = csarr[2 * l + 1] * CH2(ik, ip - 1) + csarr[4 * l + 1] * CH2(ik, ip - 2);\n      }\n      size_t iang = 2 * l;\n      size_t j = 3, jc = ip - 3;\n      for (; j < ipph - 3; j += 4, jc -= 4) {\n        iang += l;\n        if (iang > ip) iang -= ip;\n        T0 ar1 = csarr[2 * iang], ai1 = csarr[2 * iang + 1];\n        iang += l;\n        if (iang > ip) iang -= ip;\n        T0 ar2 = csarr[2 * iang], ai2 = csarr[2 * iang + 1];\n        iang += l;\n        if (iang > ip) iang -= ip;\n        T0 ar3 = csarr[2 * iang], ai3 = csarr[2 * iang + 1];\n        iang += l;\n        if (iang > ip) iang -= ip;\n        T0 ar4 = csarr[2 * iang], ai4 = csarr[2 * iang + 1];\n        for (size_t ik = 0; ik < idl1; ++ik) {\n          C2(ik, l) +=\n              ar1 * CH2(ik, j) + ar2 * CH2(ik, j + 1) + ar3 * CH2(ik, j + 2) + ar4 * CH2(ik, j + 3);\n          C2(ik, lc) += ai1 * CH2(ik, jc) + ai2 * CH2(ik, jc - 1) + ai3 * CH2(ik, jc - 2)\n                        + ai4 * CH2(ik, jc - 3);\n        }\n      }\n      for (; j < ipph - 1; j += 2, jc -= 2) {\n        iang += l;\n        if (iang > ip) iang -= ip;\n        T0 ar1 = csarr[2 * iang], ai1 = csarr[2 * iang + 1];\n        iang += l;\n        if (iang > ip) iang -= ip;\n        T0 ar2 = csarr[2 * iang], ai2 = csarr[2 * iang + 1];\n        for (size_t ik = 0; ik < idl1; ++ik) {\n          C2(ik, l) += ar1 * CH2(ik, j) + ar2 * CH2(ik, j + 1);\n          C2(ik, lc) += ai1 * CH2(ik, jc) + ai2 * CH2(ik, jc - 1);\n        }\n      }\n      for (; j < ipph; ++j, --jc) {\n        iang += l;\n        if (iang > ip) iang -= ip;\n        T0 war = csarr[2 * iang], wai = csarr[2 * iang + 1];\n        for (size_t ik = 0; ik < idl1; ++ik) {\n          C2(ik, l) += war * CH2(ik, j);\n          C2(ik, lc) += wai * CH2(ik, jc);\n        }\n      }\n    }\n    for (size_t j = 1; j < ipph; ++j)\n      for (size_t ik = 0; ik < idl1; ++ik) CH2(ik, 0) += CH2(ik, j);\n    for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc)  // 124\n      for (size_t k = 0; k < l1; ++k) PM(CH(0, k, jc), CH(0, k, j), C1(0, k, j), C1(0, k, 
jc));\n\n    if (ido == 1) return;\n\n    for (size_t j = 1, jc = ip - 1; j < ipph; ++j, --jc)  // 127\n      for (size_t k = 0; k < l1; ++k)\n        for (size_t i = 1; i <= ido - 2; i += 2) {\n          CH(i, k, j) = C1(i, k, j) - C1(i + 1, k, jc);\n          CH(i, k, jc) = C1(i, k, j) + C1(i + 1, k, jc);\n          CH(i + 1, k, j) = C1(i + 1, k, j) + C1(i, k, jc);\n          CH(i + 1, k, jc) = C1(i + 1, k, j) - C1(i, k, jc);\n        }\n\n    // All in CH\n\n    for (size_t j = 1; j < ip; ++j) {\n      size_t is = (j - 1) * (ido - 1);\n      for (size_t k = 0; k < l1; ++k) {\n        size_t idij = is;\n        for (size_t i = 1; i <= ido - 2; i += 2) {\n          T t1 = CH(i, k, j), t2 = CH(i + 1, k, j);\n          CH(i, k, j) = wa[idij] * t1 - wa[idij + 1] * t2;\n          CH(i + 1, k, j) = wa[idij] * t2 + wa[idij + 1] * t1;\n          idij += 2;\n        }\n      }\n    }\n  }\n\n  template<typename T>\n  void copy_and_norm(T* c, T* p1, T0 fct) const {\n    if (p1 != c) {\n      if (fct != 1.)\n        for (size_t i = 0; i < length; ++i) c[i] = fct * p1[i];\n      else\n        std::copy_n(p1, length, c);\n    } else if (fct != 1.)\n      for (size_t i = 0; i < length; ++i) c[i] *= fct;\n  }\n\n public:\n  template<typename T>\n  void exec(T c[], T0 fct, bool r2hc) const {\n    if (length == 1) {\n      c[0] *= fct;\n      return;\n    }\n    size_t nf = fact.size();\n    arr<T> ch(length);\n    T *p1 = c, *p2 = ch.data();\n\n    if (r2hc)\n      for (size_t k1 = 0, l1 = length; k1 < nf; ++k1) {\n        size_t k = nf - k1 - 1;\n        size_t ip = fact[k].fct;\n        size_t ido = length / l1;\n        l1 /= ip;\n        if (ip == 4)\n          radf4(ido, l1, p1, p2, fact[k].tw);\n        else if (ip == 2)\n          radf2(ido, l1, p1, p2, fact[k].tw);\n        else if (ip == 3)\n          radf3(ido, l1, p1, p2, fact[k].tw);\n        else if (ip == 5)\n          radf5(ido, l1, p1, p2, fact[k].tw);\n        else {\n          radfg(ido, ip, l1, p1, p2, fact[k].tw, fact[k].tws);\n          std::swap(p1, p2);\n        }\n        std::swap(p1, p2);\n      }\n    else\n      for (size_t k = 0, l1 = 1; k < nf; k++) {\n        size_t ip = fact[k].fct, ido = length / (ip * l1);\n        if (ip == 4)\n          radb4(ido, l1, p1, p2, fact[k].tw);\n        else if (ip == 2)\n          radb2(ido, l1, p1, p2, fact[k].tw);\n        else if (ip == 3)\n          radb3(ido, l1, p1, p2, fact[k].tw);\n        else if (ip == 5)\n          radb5(ido, l1, p1, p2, fact[k].tw);\n        else\n          radbg(ido, ip, l1, p1, p2, fact[k].tw, fact[k].tws);\n        std::swap(p1, p2);\n        l1 *= ip;\n      }\n\n    copy_and_norm(c, p1, fct);\n  }\n\n private:\n  void factorize() {\n    size_t len = length;\n    while ((len % 4) == 0) {\n      add_factor(4);\n      len >>= 2;\n    }\n    if ((len % 2) == 0) {\n      len >>= 1;\n      // factor 2 should be at the front of the factor list\n      add_factor(2);\n      std::swap(fact[0].fct, fact.back().fct);\n    }\n    for (size_t divisor = 3; divisor * divisor <= len; divisor += 2)\n      while ((len % divisor) == 0) {\n        add_factor(divisor);\n        len /= divisor;\n      }\n    if (len > 1) add_factor(len);\n  }\n\n  size_t twsize() const {\n    size_t twsz = 0, l1 = 1;\n    for (size_t k = 0; k < fact.size(); ++k) {\n      size_t ip = fact[k].fct, ido = length / (l1 * ip);\n      twsz += (ip - 1) * (ido - 1);\n      if (ip > 5) twsz += 2 * ip;\n      l1 *= ip;\n    }\n    return twsz;\n  }\n\n  void comp_twiddle() {\n    sincos_2pibyn<T0> 
twid(length);\n    size_t l1 = 1;\n    T0* ptr = mem.data();\n    for (size_t k = 0; k < fact.size(); ++k) {\n      size_t ip = fact[k].fct, ido = length / (l1 * ip);\n      if (k < fact.size() - 1)  // last factor doesn't need twiddles\n      {\n        fact[k].tw = ptr;\n        ptr += (ip - 1) * (ido - 1);\n        for (size_t j = 1; j < ip; ++j)\n          for (size_t i = 1; i <= (ido - 1) / 2; ++i) {\n            fact[k].tw[(j - 1) * (ido - 1) + 2 * i - 2] = twid[j * l1 * i].r;\n            fact[k].tw[(j - 1) * (ido - 1) + 2 * i - 1] = twid[j * l1 * i].i;\n          }\n      }\n      if (ip > 5)  // special factors required by *g functions\n      {\n        fact[k].tws = ptr;\n        ptr += 2 * ip;\n        fact[k].tws[0] = 1.;\n        fact[k].tws[1] = 0.;\n        for (size_t i = 2, ic = 2 * ip - 2; i <= ic; i += 2, ic -= 2) {\n          fact[k].tws[i] = twid[i / 2 * (length / ip)].r;\n          fact[k].tws[i + 1] = twid[i / 2 * (length / ip)].i;\n          fact[k].tws[ic] = twid[i / 2 * (length / ip)].r;\n          fact[k].tws[ic + 1] = -twid[i / 2 * (length / ip)].i;\n        }\n      }\n      l1 *= ip;\n    }\n  }\n\n public:\n  POCKETFFT_NOINLINE rfftp(size_t length_) : length(length_) {\n    if (length == 0) throw std::runtime_error(\"zero-length FFT requested\");\n    if (length == 1) return;\n    factorize();\n    mem.resize(twsize());\n    comp_twiddle();\n  }\n};\n\n//\n// complex Bluestein transforms\n//\n\ntemplate<typename T0>\nclass fftblue {\n private:\n  size_t n, n2;\n  cfftp<T0> plan;\n  arr<cmplx<T0>> mem;\n  cmplx<T0>*bk, *bkf;\n\n  template<bool fwd, typename T>\n  void fft(cmplx<T> c[], T0 fct) const {\n    arr<cmplx<T>> akf(n2);\n\n    /* initialize a_k and FFT it */\n    for (size_t m = 0; m < n; ++m) special_mul<fwd>(c[m], bk[m], akf[m]);\n    auto zero = akf[0] * T0(0);\n    for (size_t m = n; m < n2; ++m) akf[m] = zero;\n\n    plan.exec(akf.data(), 1., true);\n\n    /* do the convolution */\n    akf[0] = akf[0].template special_mul<!fwd>(bkf[0]);\n    for (size_t m = 1; m < (n2 + 1) / 2; ++m) {\n      akf[m] = akf[m].template special_mul<!fwd>(bkf[m]);\n      akf[n2 - m] = akf[n2 - m].template special_mul<!fwd>(bkf[m]);\n    }\n    if ((n2 & 1) == 0) akf[n2 / 2] = akf[n2 / 2].template special_mul<!fwd>(bkf[n2 / 2]);\n\n    /* inverse FFT */\n    plan.exec(akf.data(), 1., false);\n\n    /* multiply by b_k */\n    for (size_t m = 0; m < n; ++m) c[m] = akf[m].template special_mul<fwd>(bk[m]) * fct;\n  }\n\n public:\n  POCKETFFT_NOINLINE fftblue(size_t length)\n      : n(length),\n        n2(util::good_size_cmplx(n * 2 - 1)),\n        plan(n2),\n        mem(n + n2 / 2 + 1),\n        bk(mem.data()),\n        bkf(mem.data() + n) {\n    /* initialize b_k */\n    sincos_2pibyn<T0> tmp(2 * n);\n    bk[0].Set(1, 0);\n\n    size_t coeff = 0;\n    for (size_t m = 1; m < n; ++m) {\n      coeff += 2 * m - 1;\n      if (coeff >= 2 * n) coeff -= 2 * n;\n      bk[m] = tmp[coeff];\n    }\n\n    /* initialize the zero-padded, Fourier transformed b_k. Add normalisation. */\n    arr<cmplx<T0>> tbkf(n2);\n    T0 xn2 = T0(1) / T0(n2);\n    tbkf[0] = bk[0] * xn2;\n    for (size_t m = 1; m < n; ++m) tbkf[m] = tbkf[n2 - m] = bk[m] * xn2;\n    for (size_t m = n; m <= (n2 - n); ++m) tbkf[m].Set(0., 0.);\n    plan.exec(tbkf.data(), 1., true);\n    for (size_t i = 0; i < n2 / 2 + 1; ++i) bkf[i] = tbkf[i];\n  }\n\n  template<typename T>\n  void exec(cmplx<T> c[], T0 fct, bool fwd) const {\n    fwd ? 
fft<true>(c, fct) : fft<false>(c, fct);\n  }\n\n  template<typename T>\n  void exec_r(T c[], T0 fct, bool fwd) {\n    arr<cmplx<T>> tmp(n);\n    if (fwd) {\n      auto zero = T0(0) * c[0];\n      for (size_t m = 0; m < n; ++m) tmp[m].Set(c[m], zero);\n      fft<true>(tmp.data(), fct);\n      c[0] = tmp[0].r;\n      std::copy_n(&tmp[1].r, n - 1, &c[1]);\n    } else {\n      tmp[0].Set(c[0], c[0] * 0);\n      std::copy_n(c + 1, n - 1, &tmp[1].r);\n      if ((n & 1) == 0) tmp[n / 2].i = T0(0) * c[0];\n      for (size_t m = 1; 2 * m < n; ++m) tmp[n - m].Set(tmp[m].r, -tmp[m].i);\n      fft<false>(tmp.data(), fct);\n      for (size_t m = 0; m < n; ++m) c[m] = tmp[m].r;\n    }\n  }\n};\n\n//\n// flexible (FFTPACK/Bluestein) complex 1D transform\n//\n\ntemplate<typename T0>\nclass pocketfft_c {\n private:\n  std::unique_ptr<cfftp<T0>> packplan;\n  std::unique_ptr<fftblue<T0>> blueplan;\n  size_t len;\n\n public:\n  POCKETFFT_NOINLINE pocketfft_c(size_t length) : len(length) {\n    if (length == 0) throw std::runtime_error(\"zero-length FFT requested\");\n    size_t tmp = (length < 50) ? 0 : util::largest_prime_factor(length);\n    if (tmp * tmp <= length) {\n      packplan = std::unique_ptr<cfftp<T0>>(new cfftp<T0>(length));\n      return;\n    }\n    double comp1 = util::cost_guess(length);\n    double comp2 = 2 * util::cost_guess(util::good_size_cmplx(2 * length - 1));\n    comp2 *= 1.5;       /* fudge factor that appears to give good overall performance */\n    if (comp2 < comp1)  // use Bluestein\n      blueplan = std::unique_ptr<fftblue<T0>>(new fftblue<T0>(length));\n    else\n      packplan = std::unique_ptr<cfftp<T0>>(new cfftp<T0>(length));\n  }\n\n  template<typename T>\n  POCKETFFT_NOINLINE void exec(cmplx<T> c[], T0 fct, bool fwd) const {\n    packplan ? packplan->exec(c, fct, fwd) : blueplan->exec(c, fct, fwd);\n  }\n\n  size_t length() const { return len; }\n};\n\n//\n// flexible (FFTPACK/Bluestein) real-valued 1D transform\n//\n\ntemplate<typename T0>\nclass pocketfft_r {\n private:\n  std::unique_ptr<rfftp<T0>> packplan;\n  std::unique_ptr<fftblue<T0>> blueplan;\n  size_t len;\n\n public:\n  POCKETFFT_NOINLINE pocketfft_r(size_t length) : len(length) {\n    if (length == 0) throw std::runtime_error(\"zero-length FFT requested\");\n    size_t tmp = (length < 50) ? 0 : util::largest_prime_factor(length);\n    if (tmp * tmp <= length) {\n      packplan = std::unique_ptr<rfftp<T0>>(new rfftp<T0>(length));\n      return;\n    }\n    double comp1 = 0.5 * util::cost_guess(length);\n    double comp2 = 2 * util::cost_guess(util::good_size_cmplx(2 * length - 1));\n    comp2 *= 1.5;       /* fudge factor that appears to give good overall performance */\n    if (comp2 < comp1)  // use Bluestein\n      blueplan = std::unique_ptr<fftblue<T0>>(new fftblue<T0>(length));\n    else\n      packplan = std::unique_ptr<rfftp<T0>>(new rfftp<T0>(length));\n  }\n\n  template<typename T>\n  POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool fwd) const {\n    packplan ? 
packplan->exec(c, fct, fwd) : blueplan->exec_r(c, fct, fwd);\n  }\n\n  size_t length() const { return len; }\n};\n\n//\n// sine/cosine transforms\n//\n\ntemplate<typename T0>\nclass T_dct1 {\n private:\n  pocketfft_r<T0> fftplan;\n\n public:\n  POCKETFFT_NOINLINE T_dct1(size_t length) : fftplan(2 * (length - 1)) {}\n\n  template<typename T>\n  POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, int /*type*/, bool /*cosine*/) const {\n    constexpr T0 sqrt2 = T0(1.414213562373095048801688724209698L);\n    size_t N = fftplan.length(), n = N / 2 + 1;\n    if (ortho) {\n      c[0] *= sqrt2;\n      c[n - 1] *= sqrt2;\n    }\n    arr<T> tmp(N);\n    tmp[0] = c[0];\n    for (size_t i = 1; i < n; ++i) tmp[i] = tmp[N - i] = c[i];\n    fftplan.exec(tmp.data(), fct, true);\n    c[0] = tmp[0];\n    for (size_t i = 1; i < n; ++i) c[i] = tmp[2 * i - 1];\n    if (ortho) {\n      c[0] *= sqrt2 * T0(0.5);\n      c[n - 1] *= sqrt2 * T0(0.5);\n    }\n  }\n\n  size_t length() const { return fftplan.length() / 2 + 1; }\n};\n\ntemplate<typename T0>\nclass T_dst1 {\n private:\n  pocketfft_r<T0> fftplan;\n\n public:\n  POCKETFFT_NOINLINE T_dst1(size_t length) : fftplan(2 * (length + 1)) {}\n\n  template<typename T>\n  POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/, int /*type*/, bool /*cosine*/) const {\n    size_t N = fftplan.length(), n = N / 2 - 1;\n    arr<T> tmp(N);\n    tmp[0] = tmp[n + 1] = c[0] * 0;\n    for (size_t i = 0; i < n; ++i) {\n      tmp[i + 1] = c[i];\n      tmp[N - 1 - i] = -c[i];\n    }\n    fftplan.exec(tmp.data(), fct, true);\n    for (size_t i = 0; i < n; ++i) c[i] = -tmp[2 * i + 2];\n  }\n\n  size_t length() const { return fftplan.length() / 2 - 1; }\n};\n\ntemplate<typename T0>\nclass T_dcst23 {\n private:\n  pocketfft_r<T0> fftplan;\n  std::vector<T0> twiddle;\n\n public:\n  POCKETFFT_NOINLINE T_dcst23(size_t length) : fftplan(length), twiddle(length) {\n    sincos_2pibyn<T0> tw(4 * length);\n    for (size_t i = 0; i < length; ++i) twiddle[i] = tw[i + 1].r;\n  }\n\n  template<typename T>\n  POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, int type, bool cosine) const {\n    constexpr T0 sqrt2 = T0(1.414213562373095048801688724209698L);\n    size_t N = length();\n    size_t NS2 = (N + 1) / 2;\n    if (type == 2) {\n      if (!cosine)\n        for (size_t k = 1; k < N; k += 2) c[k] = -c[k];\n      c[0] *= 2;\n      if ((N & 1) == 0) c[N - 1] *= 2;\n      for (size_t k = 1; k < N - 1; k += 2) MPINPLACE(c[k + 1], c[k]);\n      fftplan.exec(c, fct, false);\n      for (size_t k = 1, kc = N - 1; k < NS2; ++k, --kc) {\n        T t1 = twiddle[k - 1] * c[kc] + twiddle[kc - 1] * c[k];\n        T t2 = twiddle[k - 1] * c[k] - twiddle[kc - 1] * c[kc];\n        c[k] = T0(0.5) * (t1 + t2);\n        c[kc] = T0(0.5) * (t1 - t2);\n      }\n      if ((N & 1) == 0) c[NS2] *= twiddle[NS2 - 1];\n      if (!cosine)\n        for (size_t k = 0, kc = N - 1; k < kc; ++k, --kc) std::swap(c[k], c[kc]);\n      if (ortho) c[0] *= sqrt2 * T0(0.5);\n    } else {\n      if (ortho) c[0] *= sqrt2;\n      if (!cosine)\n        for (size_t k = 0, kc = N - 1; k < NS2; ++k, --kc) std::swap(c[k], c[kc]);\n      for (size_t k = 1, kc = N - 1; k < NS2; ++k, --kc) {\n        T t1 = c[k] + c[kc], t2 = c[k] - c[kc];\n        c[k] = twiddle[k - 1] * t2 + twiddle[kc - 1] * t1;\n        c[kc] = twiddle[k - 1] * t1 - twiddle[kc - 1] * t2;\n      }\n      if ((N & 1) == 0) c[NS2] *= 2 * twiddle[NS2 - 1];\n      fftplan.exec(c, fct, true);\n      for (size_t k = 1; k < N - 1; k += 2) MPINPLACE(c[k], c[k + 1]);\n     
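 // sine-transform case: the odd-indexed outputs change sign\n     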
 if (!cosine)\n        for (size_t k = 1; k < N; k += 2) c[k] = -c[k];\n    }\n  }\n\n  size_t length() const { return fftplan.length(); }\n};\n\ntemplate<typename T0>\nclass T_dcst4 {\n private:\n  size_t N;\n  std::unique_ptr<pocketfft_c<T0>> fft;\n  std::unique_ptr<pocketfft_r<T0>> rfft;\n  arr<cmplx<T0>> C2;\n\n public:\n  POCKETFFT_NOINLINE T_dcst4(size_t length)\n      : N(length),\n        fft((N & 1) ? nullptr : new pocketfft_c<T0>(N / 2)),\n        rfft((N & 1) ? new pocketfft_r<T0>(N) : nullptr),\n        C2((N & 1) ? 0 : N / 2) {\n    if ((N & 1) == 0) {\n      sincos_2pibyn<T0> tw(16 * N);\n      for (size_t i = 0; i < N / 2; ++i) C2[i] = conj(tw[8 * i + 1]);\n    }\n  }\n\n  template<typename T>\n  POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/, int /*type*/, bool cosine) const {\n    size_t n2 = N / 2;\n    if (!cosine)\n      for (size_t k = 0, kc = N - 1; k < n2; ++k, --kc) std::swap(c[k], c[kc]);\n    if (N & 1) {\n      // The following code is derived from the FFTW3 function apply_re11()\n      // and is released under the 3-clause BSD license with friendly\n      // permission of Matteo Frigo and Steven G. Johnson.\n\n      arr<T> y(N);\n      {\n        size_t i = 0, m = n2;\n        for (; m < N; ++i, m += 4) y[i] = c[m];\n        for (; m < 2 * N; ++i, m += 4) y[i] = -c[2 * N - m - 1];\n        for (; m < 3 * N; ++i, m += 4) y[i] = -c[m - 2 * N];\n        for (; m < 4 * N; ++i, m += 4) y[i] = c[4 * N - m - 1];\n        for (; i < N; ++i, m += 4) y[i] = c[m - 4 * N];\n      }\n      rfft->exec(y.data(), fct, true);\n      {\n        auto SGN = [](size_t i) {\n          constexpr T0 sqrt2 = T0(1.414213562373095048801688724209698L);\n          return (i & 2) ? -sqrt2 : sqrt2;\n        };\n        c[n2] = y[0] * SGN(n2 + 1);\n        size_t i = 0, i1 = 1, k = 1;\n        for (; k < n2; ++i, ++i1, k += 2) {\n          c[i] = y[2 * k - 1] * SGN(i1) + y[2 * k] * SGN(i);\n          c[N - i1] = y[2 * k - 1] * SGN(N - i) - y[2 * k] * SGN(N - i1);\n          c[n2 - i1] = y[2 * k + 1] * SGN(n2 - i) - y[2 * k + 2] * SGN(n2 - i1);\n          c[n2 + i1] = y[2 * k + 1] * SGN(n2 + i + 2) + y[2 * k + 2] * SGN(n2 + i1);\n        }\n        if (k == n2) {\n          c[i] = y[2 * k - 1] * SGN(i + 1) + y[2 * k] * SGN(i);\n          c[N - i1] = y[2 * k - 1] * SGN(i + 2) + y[2 * k] * SGN(i1);\n        }\n      }\n\n      // FFTW-derived code ends here\n    } else {\n      // even length algorithm from\n      // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/\n      arr<cmplx<T>> y(n2);\n      for (size_t i = 0; i < n2; ++i) {\n        y[i].Set(c[2 * i], c[N - 1 - 2 * i]);\n        y[i] *= C2[i];\n      }\n      fft->exec(y.data(), fct, true);\n      for (size_t i = 0, ic = n2 - 1; i < n2; ++i, --ic) {\n        c[2 * i] = 2 * (y[i].r * C2[i].r - y[i].i * C2[i].i);\n        c[2 * i + 1] = -2 * (y[ic].i * C2[ic].r + y[ic].r * C2[ic].i);\n      }\n    }\n    if (!cosine)\n      for (size_t k = 1; k < N; k += 2) c[k] = -c[k];\n  }\n\n  size_t length() const { return N; }\n};\n\n//\n// multi-D infrastructure\n//\n\ntemplate<typename T>\nstd::shared_ptr<T> get_plan(size_t length) {\n#if POCKETFFT_CACHE_SIZE == 0\n  return std::make_shared<T>(length);\n#else\n  constexpr size_t nmax = POCKETFFT_CACHE_SIZE;\n  static std::array<std::shared_ptr<T>, nmax> cache;\n  static std::array<size_t, nmax> last_access{{0}};\n  static size_t access_counter = 0;\n  static std::mutex mut;\n\n  auto find_in_cache = [&]() -> std::shared_ptr<T> {\n    for (size_t i 
= 0; i < nmax; ++i)\n      if (cache[i] && (cache[i]->length() == length)) {\n        // no need to update if this is already the most recent entry\n        if (last_access[i] != access_counter) {\n          last_access[i] = ++access_counter;\n          // Guard against overflow\n          if (access_counter == 0) last_access.fill(0);\n        }\n        return cache[i];\n      }\n\n    return nullptr;\n  };\n\n  {\n    std::lock_guard<std::mutex> lock(mut);\n    auto p = find_in_cache();\n    if (p) return p;\n  }\n  auto plan = std::make_shared<T>(length);\n  {\n    std::lock_guard<std::mutex> lock(mut);\n    auto p = find_in_cache();\n    if (p) return p;\n\n    size_t lru = 0;\n    for (size_t i = 1; i < nmax; ++i)\n      if (last_access[i] < last_access[lru]) lru = i;\n\n    cache[lru] = plan;\n    last_access[lru] = ++access_counter;\n  }\n  return plan;\n#endif\n}\n\nclass arr_info {\n protected:\n  shape_t shp;\n  stride_t str;\n\n public:\n  arr_info(const shape_t& shape_, const stride_t& stride_) : shp(shape_), str(stride_) {}\n  size_t ndim() const { return shp.size(); }\n  size_t size() const { return util::prod(shp); }\n  const shape_t& shape() const { return shp; }\n  size_t shape(size_t i) const { return shp[i]; }\n  const stride_t& stride() const { return str; }\n  const ptrdiff_t& stride(size_t i) const { return str[i]; }\n};\n\ntemplate<typename T>\nclass cndarr : public arr_info {\n protected:\n  const char* d;\n\n public:\n  cndarr(const void* data_, const shape_t& shape_, const stride_t& stride_)\n      : arr_info(shape_, stride_), d(reinterpret_cast<const char*>(data_)) {}\n  const T& operator[](ptrdiff_t ofs) const { return *reinterpret_cast<const T*>(d + ofs); }\n};\n\ntemplate<typename T>\nclass ndarr : public cndarr<T> {\n public:\n  ndarr(void* data_, const shape_t& shape_, const stride_t& stride_)\n      : cndarr<T>::cndarr(const_cast<const void*>(data_), shape_, stride_) {}\n  T& operator[](ptrdiff_t ofs) {\n    return *reinterpret_cast<T*>(const_cast<char*>(cndarr<T>::d + ofs));\n  }\n};\n\ntemplate<size_t N>\nclass multi_iter {\n private:\n  shape_t pos;\n  const arr_info &iarr, &oarr;\n  ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o;\n  size_t idim, rem;\n\n  void advance_i() {\n    for (int i_ = int(pos.size()) - 1; i_ >= 0; --i_) {\n      auto i = size_t(i_);\n      if (i == idim) continue;\n      p_ii += iarr.stride(i);\n      p_oi += oarr.stride(i);\n      if (++pos[i] < iarr.shape(i)) return;\n      pos[i] = 0;\n      p_ii -= ptrdiff_t(iarr.shape(i)) * iarr.stride(i);\n      p_oi -= ptrdiff_t(oarr.shape(i)) * oarr.stride(i);\n    }\n  }\n\n public:\n  multi_iter(const arr_info& iarr_, const arr_info& oarr_, size_t idim_)\n      : pos(iarr_.ndim(), 0),\n        iarr(iarr_),\n        oarr(oarr_),\n        p_ii(0),\n        str_i(iarr.stride(idim_)),\n        p_oi(0),\n        str_o(oarr.stride(idim_)),\n        idim(idim_),\n        rem(iarr.size() / iarr.shape(idim)) {\n    auto nshares = threading::num_threads();\n    if (nshares == 1) return;\n    if (nshares == 0) throw std::runtime_error(\"can't run with zero threads\");\n    auto myshare = threading::thread_id();\n    if (myshare >= nshares) throw std::runtime_error(\"impossible share requested\");\n    size_t nbase = rem / nshares;\n    size_t additional = rem % nshares;\n    size_t lo = myshare * nbase + ((myshare < additional) ? 
myshare : additional);\n    size_t hi = lo + nbase + (myshare < additional);\n    size_t todo = hi - lo;\n\n    size_t chunk = rem;\n    for (size_t i = 0; i < pos.size(); ++i) {\n      if (i == idim) continue;\n      chunk /= iarr.shape(i);\n      size_t n_advance = lo / chunk;\n      pos[i] += n_advance;\n      p_ii += ptrdiff_t(n_advance) * iarr.stride(i);\n      p_oi += ptrdiff_t(n_advance) * oarr.stride(i);\n      lo -= n_advance * chunk;\n    }\n    rem = todo;\n  }\n  void advance(size_t n) {\n    if (rem < n) throw std::runtime_error(\"underrun\");\n    for (size_t i = 0; i < n; ++i) {\n      p_i[i] = p_ii;\n      p_o[i] = p_oi;\n      advance_i();\n    }\n    rem -= n;\n  }\n  ptrdiff_t iofs(size_t i) const { return p_i[0] + ptrdiff_t(i) * str_i; }\n  ptrdiff_t iofs(size_t j, size_t i) const { return p_i[j] + ptrdiff_t(i) * str_i; }\n  ptrdiff_t oofs(size_t i) const { return p_o[0] + ptrdiff_t(i) * str_o; }\n  ptrdiff_t oofs(size_t j, size_t i) const { return p_o[j] + ptrdiff_t(i) * str_o; }\n  size_t length_in() const { return iarr.shape(idim); }\n  size_t length_out() const { return oarr.shape(idim); }\n  ptrdiff_t stride_in() const { return str_i; }\n  ptrdiff_t stride_out() const { return str_o; }\n  size_t remaining() const { return rem; }\n};\n\nclass simple_iter {\n private:\n  shape_t pos;\n  const arr_info& arr;\n  ptrdiff_t p;\n  size_t rem;\n\n public:\n  simple_iter(const arr_info& arr_) : pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size()) {}\n  void advance() {\n    --rem;\n    for (int i_ = int(pos.size()) - 1; i_ >= 0; --i_) {\n      auto i = size_t(i_);\n      p += arr.stride(i);\n      if (++pos[i] < arr.shape(i)) return;\n      pos[i] = 0;\n      p -= ptrdiff_t(arr.shape(i)) * arr.stride(i);\n    }\n  }\n  ptrdiff_t ofs() const { return p; }\n  size_t remaining() const { return rem; }\n};\n\nclass rev_iter {\n private:\n  shape_t pos;\n  const arr_info& arr;\n  std::vector<char> rev_axis;\n  std::vector<char> rev_jump;\n  size_t last_axis, last_size;\n  shape_t shp;\n  ptrdiff_t p, rp;\n  size_t rem;\n\n public:\n  rev_iter(const arr_info& arr_, const shape_t& axes)\n      : pos(arr_.ndim(), 0),\n        arr(arr_),\n        rev_axis(arr_.ndim(), 0),\n        rev_jump(arr_.ndim(), 1),\n        p(0),\n        rp(0) {\n    for (auto ax : axes) rev_axis[ax] = 1;\n    last_axis = axes.back();\n    last_size = arr.shape(last_axis) / 2 + 1;\n    shp = arr.shape();\n    shp[last_axis] = last_size;\n    rem = 1;\n    for (auto i : shp) rem *= i;\n  }\n  void advance() {\n    --rem;\n    for (int i_ = int(pos.size()) - 1; i_ >= 0; --i_) {\n      auto i = size_t(i_);\n      p += arr.stride(i);\n      if (!rev_axis[i])\n        rp += arr.stride(i);\n      else {\n        rp -= arr.stride(i);\n        if (rev_jump[i]) {\n          rp += ptrdiff_t(arr.shape(i)) * arr.stride(i);\n          rev_jump[i] = 0;\n        }\n      }\n      if (++pos[i] < shp[i]) return;\n      pos[i] = 0;\n      p -= ptrdiff_t(shp[i]) * arr.stride(i);\n      if (rev_axis[i]) {\n        rp -= ptrdiff_t(arr.shape(i) - shp[i]) * arr.stride(i);\n        rev_jump[i] = 1;\n      } else\n        rp -= ptrdiff_t(shp[i]) * arr.stride(i);\n    }\n  }\n  ptrdiff_t ofs() const { return p; }\n  ptrdiff_t rev_ofs() const { return rp; }\n  size_t remaining() const { return rem; }\n};\n\ntemplate<typename T>\nstruct VTYPE {};\ntemplate<typename T>\nusing vtype_t = typename VTYPE<T>::type;\n\n#ifndef POCKETFFT_NO_VECTORS\ntemplate<>\nstruct VTYPE<float> {\n  using type = float 
__attribute__((vector_size(VLEN<float>::val * sizeof(float))));\n};\ntemplate<>\nstruct VTYPE<double> {\n  using type = double __attribute__((vector_size(VLEN<double>::val * sizeof(double))));\n};\ntemplate<>\nstruct VTYPE<long double> {\n  using type =\n      long double __attribute__((vector_size(VLEN<long double>::val * sizeof(long double))));\n};\n#endif\n\ntemplate<typename T>\narr<char> alloc_tmp(const shape_t& shape, size_t axsize, size_t elemsize) {\n  auto othersize = util::prod(shape) / axsize;\n  auto tmpsize = axsize * ((othersize >= VLEN<T>::val) ? VLEN<T>::val : 1);\n  return arr<char>(tmpsize * elemsize);\n}\ntemplate<typename T>\narr<char> alloc_tmp(const shape_t& shape, const shape_t& axes, size_t elemsize) {\n  size_t fullsize = util::prod(shape);\n  size_t tmpsize = 0;\n  for (size_t i = 0; i < axes.size(); ++i) {\n    auto axsize = shape[axes[i]];\n    auto othersize = fullsize / axsize;\n    auto sz = axsize * ((othersize >= VLEN<T>::val) ? VLEN<T>::val : 1);\n    if (sz > tmpsize) tmpsize = sz;\n  }\n  return arr<char>(tmpsize * elemsize);\n}\n\ntemplate<typename T, size_t vlen>\nvoid copy_input(const multi_iter<vlen>& it, const cndarr<cmplx<T>>& src,\n                cmplx<vtype_t<T>>* POCKETFFT_RESTRICT dst) {\n  for (size_t i = 0; i < it.length_in(); ++i)\n    for (size_t j = 0; j < vlen; ++j) {\n      dst[i].r[j] = src[it.iofs(j, i)].r;\n      dst[i].i[j] = src[it.iofs(j, i)].i;\n    }\n}\n\ntemplate<typename T, size_t vlen>\nvoid copy_input(const multi_iter<vlen>& it, const cndarr<T>& src,\n                vtype_t<T>* POCKETFFT_RESTRICT dst) {\n  for (size_t i = 0; i < it.length_in(); ++i)\n    for (size_t j = 0; j < vlen; ++j) dst[i][j] = src[it.iofs(j, i)];\n}\n\ntemplate<typename T, size_t vlen>\nvoid copy_input(const multi_iter<vlen>& it, const cndarr<T>& src, T* POCKETFFT_RESTRICT dst) {\n  if (dst == &src[it.iofs(0)]) return;  // in-place\n  for (size_t i = 0; i < it.length_in(); ++i) dst[i] = src[it.iofs(i)];\n}\n\ntemplate<typename T, size_t vlen>\nvoid copy_output(const multi_iter<vlen>& it, const cmplx<vtype_t<T>>* POCKETFFT_RESTRICT src,\n                 ndarr<cmplx<T>>& dst) {\n  for (size_t i = 0; i < it.length_out(); ++i)\n    for (size_t j = 0; j < vlen; ++j) dst[it.oofs(j, i)].Set(src[i].r[j], src[i].i[j]);\n}\n\ntemplate<typename T, size_t vlen>\nvoid copy_output(const multi_iter<vlen>& it, const vtype_t<T>* POCKETFFT_RESTRICT src,\n                 ndarr<T>& dst) {\n  for (size_t i = 0; i < it.length_out(); ++i)\n    for (size_t j = 0; j < vlen; ++j) dst[it.oofs(j, i)] = src[i][j];\n}\n\ntemplate<typename T, size_t vlen>\nvoid copy_output(const multi_iter<vlen>& it, const T* POCKETFFT_RESTRICT src, ndarr<T>& dst) {\n  if (src == &dst[it.oofs(0)]) return;  // in-place\n  for (size_t i = 0; i < it.length_out(); ++i) dst[it.oofs(i)] = src[i];\n}\n\ntemplate<typename T>\nstruct add_vec {\n  using type = vtype_t<T>;\n};\ntemplate<typename T>\nstruct add_vec<cmplx<T>> {\n  using type = cmplx<vtype_t<T>>;\n};\ntemplate<typename T>\nusing add_vec_t = typename add_vec<T>::type;\n\ntemplate<typename Tplan, typename T, typename T0, typename Exec>\nPOCKETFFT_NOINLINE void general_nd(const cndarr<T>& in, ndarr<T>& out, const shape_t& axes, T0 fct,\n                                   size_t nthreads, const Exec& exec,\n                                   const bool allow_inplace = true) {\n  std::shared_ptr<Tplan> plan;\n\n  for (size_t iax = 0; iax < axes.size(); ++iax) {\n    size_t len = in.shape(axes[iax]);\n    if ((!plan) || (len != plan->length())) 
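// fetch a (possibly cached) plan only when the transform length changes\n      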
plan = get_plan<Tplan>(len);\n\n    threading::thread_map(util::thread_count(nthreads, in.shape(), axes[iax], VLEN<T>::val), [&] {\n      constexpr auto vlen = VLEN<T0>::val;\n      auto storage = alloc_tmp<T0>(in.shape(), len, sizeof(T));\n      const auto& tin(iax == 0 ? in : out);\n      multi_iter<vlen> it(tin, out, axes[iax]);\n#ifndef POCKETFFT_NO_VECTORS\n      if (vlen > 1)\n        while (it.remaining() >= vlen) {\n          it.advance(vlen);\n          auto tdatav = reinterpret_cast<add_vec_t<T>*>(storage.data());\n          exec(it, tin, out, tdatav, *plan, fct);\n        }\n#endif\n      while (it.remaining() > 0) {\n        it.advance(1);\n        auto buf = allow_inplace && it.stride_out() == sizeof(T)\n                       ? &out[it.oofs(0)]\n                       : reinterpret_cast<T*>(storage.data());\n        exec(it, tin, out, buf, *plan, fct);\n      }\n    });           // end of parallel region\n    fct = T0(1);  // factor has been applied, use 1 for remaining axes\n  }\n}\n\nstruct ExecC2C {\n  bool forward;\n\n  template<typename T0, typename T, size_t vlen>\n  void operator()(const multi_iter<vlen>& it, const cndarr<cmplx<T0>>& in, ndarr<cmplx<T0>>& out,\n                  T* buf, const pocketfft_c<T0>& plan, T0 fct) const {\n    copy_input(it, in, buf);\n    plan.exec(buf, fct, forward);\n    copy_output(it, buf, out);\n  }\n};\n\ntemplate<typename T, size_t vlen>\nvoid copy_hartley(const multi_iter<vlen>& it, const vtype_t<T>* POCKETFFT_RESTRICT src,\n                  ndarr<T>& dst) {\n  for (size_t j = 0; j < vlen; ++j) dst[it.oofs(j, 0)] = src[0][j];\n  size_t i = 1, i1 = 1, i2 = it.length_out() - 1;\n  for (i = 1; i < it.length_out() - 1; i += 2, ++i1, --i2)\n    for (size_t j = 0; j < vlen; ++j) {\n      dst[it.oofs(j, i1)] = src[i][j] + src[i + 1][j];\n      dst[it.oofs(j, i2)] = src[i][j] - src[i + 1][j];\n    }\n  if (i < it.length_out())\n    for (size_t j = 0; j < vlen; ++j) dst[it.oofs(j, i1)] = src[i][j];\n}\n\ntemplate<typename T, size_t vlen>\nvoid copy_hartley(const multi_iter<vlen>& it, const T* POCKETFFT_RESTRICT src, ndarr<T>& dst) {\n  dst[it.oofs(0)] = src[0];\n  size_t i = 1, i1 = 1, i2 = it.length_out() - 1;\n  for (i = 1; i < it.length_out() - 1; i += 2, ++i1, --i2) {\n    dst[it.oofs(i1)] = src[i] + src[i + 1];\n    dst[it.oofs(i2)] = src[i] - src[i + 1];\n  }\n  if (i < it.length_out()) dst[it.oofs(i1)] = src[i];\n}\n\nstruct ExecHartley {\n  template<typename T0, typename T, size_t vlen>\n  void operator()(const multi_iter<vlen>& it, const cndarr<T0>& in, ndarr<T0>& out, T* buf,\n                  const pocketfft_r<T0>& plan, T0 fct) const {\n    copy_input(it, in, buf);\n    plan.exec(buf, fct, true);\n    copy_hartley(it, buf, out);\n  }\n};\n\nstruct ExecDcst {\n  bool ortho;\n  int type;\n  bool cosine;\n\n  template<typename T0, typename T, typename Tplan, size_t vlen>\n  void operator()(const multi_iter<vlen>& it, const cndarr<T0>& in, ndarr<T0>& out, T* buf,\n                  const Tplan& plan, T0 fct) const {\n    copy_input(it, in, buf);\n    plan.exec(buf, fct, ortho, type, cosine);\n    copy_output(it, buf, out);\n  }\n};\n\ntemplate<typename T>\nPOCKETFFT_NOINLINE void general_r2c(const cndarr<T>& in, ndarr<cmplx<T>>& out, size_t axis,\n                                    bool forward, T fct, size_t nthreads) {\n  auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));\n  size_t len = in.shape(axis);\n  threading::thread_map(util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val), [&] {\n    constexpr auto vlen = 
VLEN<T>::val;\n    auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));\n    multi_iter<vlen> it(in, out, axis);\n#ifndef POCKETFFT_NO_VECTORS\n    if (vlen > 1)\n      while (it.remaining() >= vlen) {\n        it.advance(vlen);\n        auto tdatav = reinterpret_cast<vtype_t<T>*>(storage.data());\n        copy_input(it, in, tdatav);\n        plan->exec(tdatav, fct, true);\n        for (size_t j = 0; j < vlen; ++j) out[it.oofs(j, 0)].Set(tdatav[0][j]);\n        size_t i = 1, ii = 1;\n        if (forward)\n          for (; i < len - 1; i += 2, ++ii)\n            for (size_t j = 0; j < vlen; ++j)\n              out[it.oofs(j, ii)].Set(tdatav[i][j], tdatav[i + 1][j]);\n        else\n          for (; i < len - 1; i += 2, ++ii)\n            for (size_t j = 0; j < vlen; ++j)\n              out[it.oofs(j, ii)].Set(tdatav[i][j], -tdatav[i + 1][j]);\n        if (i < len)\n          for (size_t j = 0; j < vlen; ++j) out[it.oofs(j, ii)].Set(tdatav[i][j]);\n      }\n#endif\n    while (it.remaining() > 0) {\n      it.advance(1);\n      auto tdata = reinterpret_cast<T*>(storage.data());\n      copy_input(it, in, tdata);\n      plan->exec(tdata, fct, true);\n      out[it.oofs(0)].Set(tdata[0]);\n      size_t i = 1, ii = 1;\n      if (forward)\n        for (; i < len - 1; i += 2, ++ii) out[it.oofs(ii)].Set(tdata[i], tdata[i + 1]);\n      else\n        for (; i < len - 1; i += 2, ++ii) out[it.oofs(ii)].Set(tdata[i], -tdata[i + 1]);\n      if (i < len) out[it.oofs(ii)].Set(tdata[i]);\n    }\n  });  // end of parallel region\n}\ntemplate<typename T>\nPOCKETFFT_NOINLINE void general_c2r(const cndarr<cmplx<T>>& in, ndarr<T>& out, size_t axis,\n                                    bool forward, T fct, size_t nthreads) {\n  auto plan = get_plan<pocketfft_r<T>>(out.shape(axis));\n  size_t len = out.shape(axis);\n  threading::thread_map(util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val), [&] {\n    constexpr auto vlen = VLEN<T>::val;\n    auto storage = alloc_tmp<T>(out.shape(), len, sizeof(T));\n    multi_iter<vlen> it(in, out, axis);\n#ifndef POCKETFFT_NO_VECTORS\n    if (vlen > 1)\n      while (it.remaining() >= vlen) {\n        it.advance(vlen);\n        auto tdatav = reinterpret_cast<vtype_t<T>*>(storage.data());\n        for (size_t j = 0; j < vlen; ++j) tdatav[0][j] = in[it.iofs(j, 0)].r;\n        {\n          size_t i = 1, ii = 1;\n          if (forward)\n            for (; i < len - 1; i += 2, ++ii)\n              for (size_t j = 0; j < vlen; ++j) {\n                tdatav[i][j] = in[it.iofs(j, ii)].r;\n                tdatav[i + 1][j] = -in[it.iofs(j, ii)].i;\n              }\n          else\n            for (; i < len - 1; i += 2, ++ii)\n              for (size_t j = 0; j < vlen; ++j) {\n                tdatav[i][j] = in[it.iofs(j, ii)].r;\n                tdatav[i + 1][j] = in[it.iofs(j, ii)].i;\n              }\n          if (i < len)\n            for (size_t j = 0; j < vlen; ++j) tdatav[i][j] = in[it.iofs(j, ii)].r;\n        }\n        plan->exec(tdatav, fct, false);\n        copy_output(it, tdatav, out);\n      }\n#endif\n    while (it.remaining() > 0) {\n      it.advance(1);\n      auto tdata = reinterpret_cast<T*>(storage.data());\n      tdata[0] = in[it.iofs(0)].r;\n      {\n        size_t i = 1, ii = 1;\n        if (forward)\n          for (; i < len - 1; i += 2, ++ii) {\n            tdata[i] = in[it.iofs(ii)].r;\n            tdata[i + 1] = -in[it.iofs(ii)].i;\n          }\n        else\n          for (; i < len - 1; i += 2, ++ii) {\n            tdata[i] = in[it.iofs(ii)].r;\n       
     tdata[i + 1] = in[it.iofs(ii)].i;\n          }\n        if (i < len) tdata[i] = in[it.iofs(ii)].r;\n      }\n      plan->exec(tdata, fct, false);\n      copy_output(it, tdata, out);\n    }\n  });  // end of parallel region\n}\n\nstruct ExecR2R {\n  bool r2h, forward;\n\n  template<typename T0, typename T, size_t vlen>\n  void operator()(const multi_iter<vlen>& it, const cndarr<T0>& in, ndarr<T0>& out, T* buf,\n                  const pocketfft_r<T0>& plan, T0 fct) const {\n    copy_input(it, in, buf);\n    if ((!r2h) && forward)\n      for (size_t i = 2; i < it.length_out(); i += 2) buf[i] = -buf[i];\n    plan.exec(buf, fct, r2h);\n    if (r2h && (!forward))\n      for (size_t i = 2; i < it.length_out(); i += 2) buf[i] = -buf[i];\n    copy_output(it, buf, out);\n  }\n};\n\ntemplate<typename T>\nvoid c2c(const shape_t& shape, const stride_t& stride_in, const stride_t& stride_out,\n         const shape_t& axes, bool forward, const std::complex<T>* data_in,\n         std::complex<T>* data_out, T fct, size_t nthreads = 1) {\n  if (util::prod(shape) == 0) return;\n  util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes);\n  cndarr<cmplx<T>> ain(data_in, shape, stride_in);\n  ndarr<cmplx<T>> aout(data_out, shape, stride_out);\n  general_nd<pocketfft_c<T>>(ain, aout, axes, fct, nthreads, ExecC2C{forward});\n}\n\ntemplate<typename T>\nvoid dct(const shape_t& shape, const stride_t& stride_in, const stride_t& stride_out,\n         const shape_t& axes, int type, const T* data_in, T* data_out, T fct, bool ortho,\n         size_t nthreads = 1) {\n  if ((type < 1) || (type > 4)) throw std::invalid_argument(\"invalid DCT type\");\n  if (util::prod(shape) == 0) return;\n  util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes);\n  cndarr<T> ain(data_in, shape, stride_in);\n  ndarr<T> aout(data_out, shape, stride_out);\n  const ExecDcst exec{ortho, type, true};\n  if (type == 1)\n    general_nd<T_dct1<T>>(ain, aout, axes, fct, nthreads, exec);\n  else if (type == 4)\n    general_nd<T_dcst4<T>>(ain, aout, axes, fct, nthreads, exec);\n  else\n    general_nd<T_dcst23<T>>(ain, aout, axes, fct, nthreads, exec);\n}\n\ntemplate<typename T>\nvoid dst(const shape_t& shape, const stride_t& stride_in, const stride_t& stride_out,\n         const shape_t& axes, int type, const T* data_in, T* data_out, T fct, bool ortho,\n         size_t nthreads = 1) {\n  if ((type < 1) || (type > 4)) throw std::invalid_argument(\"invalid DST type\");\n  if (util::prod(shape) == 0) return;\n  util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes);\n  cndarr<T> ain(data_in, shape, stride_in);\n  ndarr<T> aout(data_out, shape, stride_out);\n  const ExecDcst exec{ortho, type, false};\n  if (type == 1)\n    general_nd<T_dst1<T>>(ain, aout, axes, fct, nthreads, exec);\n  else if (type == 4)\n    general_nd<T_dcst4<T>>(ain, aout, axes, fct, nthreads, exec);\n  else\n    general_nd<T_dcst23<T>>(ain, aout, axes, fct, nthreads, exec);\n}\n\ntemplate<typename T>\nvoid r2c(const shape_t& shape_in, const stride_t& stride_in, const stride_t& stride_out,\n         size_t axis, bool forward, const T* data_in, std::complex<T>* data_out, T fct,\n         size_t nthreads = 1) {\n  if (util::prod(shape_in) == 0) return;\n  util::sanity_check(shape_in, stride_in, stride_out, false, axis);\n  cndarr<T> ain(data_in, shape_in, stride_in);\n  shape_t shape_out(shape_in);\n  shape_out[axis] = shape_in[axis] / 2 + 1;\n  ndarr<cmplx<T>> aout(data_out, shape_out, stride_out);\n  general_r2c(ain, 
aout, axis, forward, fct, nthreads);\n}\n\ntemplate<typename T>\nvoid r2c(const shape_t& shape_in, const stride_t& stride_in, const stride_t& stride_out,\n         const shape_t& axes, bool forward, const T* data_in, std::complex<T>* data_out, T fct,\n         size_t nthreads = 1) {\n  if (util::prod(shape_in) == 0) return;\n  util::sanity_check(shape_in, stride_in, stride_out, false, axes);\n  r2c(shape_in, stride_in, stride_out, axes.back(), forward, data_in, data_out, fct, nthreads);\n  if (axes.size() == 1) return;\n  shape_t shape_out(shape_in);\n  shape_out[axes.back()] = shape_in[axes.back()] / 2 + 1;\n  auto newaxes = shape_t{axes.begin(), --axes.end()};\n  c2c(shape_out, stride_out, stride_out, newaxes, forward, data_out, data_out, T(1), nthreads);\n}\n\ntemplate<typename T>\nvoid c2r(const shape_t& shape_out, const stride_t& stride_in, const stride_t& stride_out,\n         size_t axis, bool forward, const std::complex<T>* data_in, T* data_out, T fct,\n         size_t nthreads = 1) {\n  if (util::prod(shape_out) == 0) return;\n  util::sanity_check(shape_out, stride_in, stride_out, false, axis);\n  shape_t shape_in(shape_out);\n  shape_in[axis] = shape_out[axis] / 2 + 1;\n  cndarr<cmplx<T>> ain(data_in, shape_in, stride_in);\n  ndarr<T> aout(data_out, shape_out, stride_out);\n  general_c2r(ain, aout, axis, forward, fct, nthreads);\n}\n\ntemplate<typename T>\nvoid c2r(const shape_t& shape_out, const stride_t& stride_in, const stride_t& stride_out,\n         const shape_t& axes, bool forward, const std::complex<T>* data_in, T* data_out, T fct,\n         size_t nthreads = 1) {\n  if (util::prod(shape_out) == 0) return;\n  if (axes.size() == 1)\n    return c2r(shape_out, stride_in, stride_out, axes[0], forward, data_in, data_out, fct,\n               nthreads);\n  util::sanity_check(shape_out, stride_in, stride_out, false, axes);\n  auto shape_in = shape_out;\n  shape_in[axes.back()] = shape_out[axes.back()] / 2 + 1;\n  auto nval = util::prod(shape_in);\n  stride_t stride_inter(shape_in.size());\n  stride_inter.back() = sizeof(cmplx<T>);\n  for (int i = int(shape_in.size()) - 2; i >= 0; --i)\n    stride_inter[size_t(i)] = stride_inter[size_t(i + 1)] * ptrdiff_t(shape_in[size_t(i + 1)]);\n  arr<std::complex<T>> tmp(nval);\n  auto newaxes = shape_t{axes.begin(), --axes.end()};\n  c2c(shape_in, stride_in, stride_inter, newaxes, forward, data_in, tmp.data(), T(1), nthreads);\n  c2r(shape_out, stride_inter, stride_out, axes.back(), forward, tmp.data(), data_out, fct,\n      nthreads);\n}\n\ntemplate<typename T>\nvoid r2r_fftpack(const shape_t& shape, const stride_t& stride_in, const stride_t& stride_out,\n                 const shape_t& axes, bool real2hermitian, bool forward, const T* data_in,\n                 T* data_out, T fct, size_t nthreads = 1) {\n  if (util::prod(shape) == 0) return;\n  util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes);\n  cndarr<T> ain(data_in, shape, stride_in);\n  ndarr<T> aout(data_out, shape, stride_out);\n  general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads, ExecR2R{real2hermitian, forward});\n}\n\ntemplate<typename T>\nvoid r2r_separable_hartley(const shape_t& shape, const stride_t& stride_in,\n                           const stride_t& stride_out, const shape_t& axes, const T* data_in,\n                           T* data_out, T fct, size_t nthreads = 1) {\n  if (util::prod(shape) == 0) return;\n  util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes);\n  cndarr<T> ain(data_in, shape, stride_in);\n  
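// Note: general_nd is invoked with allow_inplace = false below, because the\n  // Hartley post-processing (copy_hartley) writes its results back in a\n  // reordered pattern, so the scratch buffer must never alias the output row.\n  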
ndarr<T> aout(data_out, shape, stride_out);\n  general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads, ExecHartley{}, false);\n}\n\ntemplate<typename T>\nvoid r2r_genuine_hartley(const shape_t& shape, const stride_t& stride_in,\n                         const stride_t& stride_out, const shape_t& axes, const T* data_in,\n                         T* data_out, T fct, size_t nthreads = 1) {\n  if (util::prod(shape) == 0) return;\n  if (axes.size() == 1)\n    return r2r_separable_hartley(shape, stride_in, stride_out, axes, data_in, data_out, fct,\n                                 nthreads);\n  util::sanity_check(shape, stride_in, stride_out, data_in == data_out, axes);\n  shape_t tshp(shape);\n  tshp[axes.back()] = tshp[axes.back()] / 2 + 1;\n  arr<std::complex<T>> tdata(util::prod(tshp));\n  stride_t tstride(shape.size());\n  tstride.back() = sizeof(std::complex<T>);\n  for (size_t i = tstride.size() - 1; i > 0; --i) tstride[i - 1] = tstride[i] * ptrdiff_t(tshp[i]);\n  r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads);\n  cndarr<cmplx<T>> atmp(tdata.data(), tshp, tstride);\n  ndarr<T> aout(data_out, shape, stride_out);\n  simple_iter iin(atmp);\n  rev_iter iout(aout, axes);\n  while (iin.remaining() > 0) {\n    auto v = atmp[iin.ofs()];\n    aout[iout.ofs()] = v.r + v.i;\n    aout[iout.rev_ofs()] = v.r - v.i;\n    iin.advance();\n    iout.advance();\n  }\n}\n\n}  // namespace detail\n\nusing detail::BACKWARD;\nusing detail::c2c;\nusing detail::c2r;\nusing detail::dct;\nusing detail::dst;\nusing detail::FORWARD;\nusing detail::r2c;\nusing detail::r2r_fftpack;\nusing detail::r2r_genuine_hartley;\nusing detail::r2r_separable_hartley;\nusing detail::shape_t;\nusing detail::stride_t;\n\n}  // namespace pocketfft\n\n#undef POCKETFFT_NOINLINE\n#undef POCKETFFT_RESTRICT\n\n#endif  // POCKETFFT_HDRONLY_H\n"
  },
  {
    "path": "oneflow/user/kernels/pocketfftplan.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"pocketfft_hdronly.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\nnamespace {\n\nenum class FFT_EXCUTETYPE { R2C, C2C, C2R };\n\ntemplate<typename dtype>\nstruct PocketFFtParams {\n  bool IsForward;\n  FFT_EXCUTETYPE excute_type;\n  dtype fct;\n  pocketfft::shape_t axes;\n  pocketfft::stride_t in_stridef;\n  pocketfft::stride_t out_stridef;\n  pocketfft::shape_t input_shape;\n  pocketfft::shape_t output_shape;\n  PocketFFtParams() = default;\n  PocketFFtParams(const Shape& in_shape, const Shape& out_shape, const Stride& in_stride,\n                  const Stride& out_stride, const std::vector<int64_t>& dims, const bool is_forward,\n                  const dtype f, FFT_EXCUTETYPE type)\n      : IsForward(is_forward),\n        excute_type(type),\n        fct(f),\n        axes(dims.begin(), dims.end()),\n        in_stridef(in_stride.begin(), in_stride.end()),\n        out_stridef(out_stride.begin(), out_stride.end()) {\n    input_shape.resize(in_shape.size());\n    output_shape.resize(out_shape.size());\n\n    std::copy(in_shape.begin(), in_shape.end(), input_shape.begin());\n    std::copy(out_shape.begin(), out_shape.end(), output_shape.begin());\n\n    // calc element size\n    size_t in_elemsize = type == FFT_EXCUTETYPE::C2C || type == FFT_EXCUTETYPE::C2R\n                             ? sizeof(std::complex<dtype>)\n                             : sizeof(dtype);\n    size_t out_elemsize = type == FFT_EXCUTETYPE::R2C || type == FFT_EXCUTETYPE::C2C\n                              ? 
sizeof(std::complex<dtype>)\n                              : sizeof(dtype);\n    for (auto& s : in_stridef) { s *= in_elemsize; }\n    for (auto& s : out_stridef) { s *= out_elemsize; }\n  }\n};\n\ntemplate<typename dtype>\nclass PocketFFtConfig {\n public:\n  PocketFFtConfig(const PocketFFtConfig&) = delete;\n  PocketFFtConfig& operator=(PocketFFtConfig const&) = delete;\n\n  explicit PocketFFtConfig(const PocketFFtParams<dtype>& params) : fftparams(params) {}\n\n  void excute(const std::complex<dtype>* in, std::complex<dtype>* out) {\n    pocketfft::c2c(fftparams.input_shape, fftparams.in_stridef, fftparams.out_stridef,\n                   fftparams.axes, fftparams.IsForward, in, out, fftparams.fct);\n  }\n\n  void excute(const dtype* in, std::complex<dtype>* out) {\n    pocketfft::r2c(fftparams.input_shape, fftparams.in_stridef, fftparams.out_stridef,\n                   fftparams.axes, fftparams.IsForward, in, out, fftparams.fct);\n  }\n\n  void excute(const std::complex<dtype>* in, dtype* out) {\n    pocketfft::c2r(fftparams.output_shape, fftparams.in_stridef, fftparams.out_stridef,\n                   fftparams.axes, fftparams.IsForward, in, out, fftparams.fct);\n  }\n\n private:\n  PocketFFtParams<dtype> fftparams;\n};\n\n}  // namespace\n\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/user/kernels/prelu_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass CpuPReluKernel final : public user_op::OpKernel {\n public:\n  CpuPReluKernel() = default;\n  ~CpuPReluKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex(\"alpha\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const T* x_ptr = x->dptr<T>();\n    const T* alpha_ptr = alpha->dptr<T>();\n    T* y_ptr = y->mut_dptr<T>();\n    const int32_t elem_cnt = x->shape_view().elem_cnt();\n    const int32_t alpha_size = alpha->shape_view().elem_cnt();\n    const int batch = x->shape_view().At(0);\n    const int channels = (x->shape_view().NumAxes() == 1) ? 1 : x->shape_view().At(1);\n    const int32_t inner_size = elem_cnt / batch / channels;\n    FOR_RANGE(int32_t, i, 0, elem_cnt) {\n      y_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : x_ptr[i] * alpha_ptr[(i / inner_size) % alpha_size];\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_PRELU_KERNEL(dtype)                                              \\\n  REGISTER_USER_KERNEL(\"prelu\").SetCreateFn<CpuPReluKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCPU)                                  \\\n      && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CPU_PRELU_KERNEL(float)\nREGISTER_CPU_PRELU_KERNEL(double)\n\ntemplate<typename T>\nclass CpuPReluGradKernel final : public user_op::OpKernel {\n public:\n  CpuPReluGradKernel() = default;\n  ~CpuPReluGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex(\"alpha\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    user_op::Tensor* alpha_diff = ctx->Tensor4ArgNameAndIndex(\"alpha_diff\", 0);\n    const T* x_ptr = x->dptr<T>();\n    const T* alpha_ptr = alpha->dptr<T>();\n    const T* dy_ptr = dy->dptr<T>();\n    T* dx_ptr = dx->mut_dptr<T>();\n    T* alpha_diff_ptr = alpha_diff->mut_dptr<T>();\n\n    const int32_t elem_cnt = x->shape_view().elem_cnt();\n    const int32_t alpha_size = alpha->shape_view().elem_cnt();\n    const int batch = x->shape_view().At(0);\n    const int channels = (x->shape_view().NumAxes() == 1) ? 
1 : x->shape_view().At(1);\n    const int32_t inner_size = elem_cnt / batch / channels;\n\n    Memset<DeviceType::kCPU>(ctx->stream(), alpha_diff->mut_dptr<T>(), 0,\n                             alpha_diff->shape_view().elem_cnt() * sizeof(T));\n\n    for (int i = 0; i < elem_cnt; i++) {\n      const T x_i = x_ptr[i];\n      const T dy_i = dy_ptr[i];\n      const T alpha_i = alpha_ptr[(i / inner_size) % alpha_size];\n      dx_ptr[i] = x_i > 0 ? dy_i : dy_i * alpha_i;\n      alpha_diff_ptr[(i / inner_size) % alpha_size] += x_i > 0 ? 0 : dy_i * x_i;\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_PRELU_GRAD_KERNEL(dtype)                         \\\n  REGISTER_USER_KERNEL(\"prelu_grad\")                                  \\\n      .SetCreateFn<CpuPReluGradKernel<dtype>>()                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CPU_PRELU_GRAD_KERNEL(float)\nREGISTER_CPU_PRELU_GRAD_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/prelu_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nShape CreatePreluLeftExtendedShape(const ShapeView& shape, const int32_t alpha_size) {\n  DimVector dim_vec(shape.NumAxes());\n  dim_vec.at(0) = 1LL;\n  dim_vec.at(1) = alpha_size;\n  for (int i = 2; i < shape.NumAxes(); i++) { dim_vec.at(i) = 1LL; }\n  return Shape(std::move(dim_vec));\n}\n\ntemplate<typename T>\nstruct PreluForwardSingleAlphaFunctor {\n  OF_DEVICE_FUNC explicit PreluForwardSingleAlphaFunctor(const T alpha) : alpha(alpha) {}\n  __device__ T operator()(T x) const { return (x > static_cast<T>(0.0)) ? x : (alpha * x); }\n  const T alpha;\n};\n\ntemplate<typename T>\nstruct PreluForwardSingleAlphaPtrFunctor {\n  OF_DEVICE_FUNC explicit PreluForwardSingleAlphaPtrFunctor(const T* alpha_ptr)\n      : alpha_ptr(alpha_ptr) {}\n  __device__ PreluForwardSingleAlphaFunctor<T> operator()() const {\n    return PreluForwardSingleAlphaFunctor<T>(*alpha_ptr);\n  }\n  const T* alpha_ptr;\n};\n\ntemplate<typename T, typename IndexType, int pack_size, bool tail, bool alpha_requires_grad>\n__global__ void PReluBackwardSingleAlphaGpu(const IndexType elem_cnt, const int64_t n_tail,\n                                            const T* x, const T* alpha, const T* dy, T* dx,\n                                            T* alpha_diff, const T* tail_x, const T* tail_dy,\n                                            T* tail_dx, T* tail_alpha_diff) {\n  int32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;\n\n  using LoadType = cuda::elementwise::PackType<T, pack_size>;\n  using LoadPack = cuda::elementwise::Pack<T, pack_size>;\n  T zero_val = static_cast<T>(0);\n  T alpha_val = alpha[0];\n\n  for (int64_t linear_index = global_thread_id * pack_size; linear_index < elem_cnt;\n       linear_index += gridDim.x * blockDim.x * pack_size) {\n    const LoadType* x_load = reinterpret_cast<const LoadType*>(x + linear_index);\n    LoadPack x_vec;\n    x_vec.storage = *x_load;\n\n    const LoadType* dy_load = reinterpret_cast<const LoadType*>(dy + linear_index);\n    LoadPack dy_vec;\n    dy_vec.storage = *dy_load;\n\n    LoadPack dx_vec;\n    T zero_val = static_cast<T>(0.0);\n    if (alpha_requires_grad) {\n      LoadPack dalpha_vec;\n#pragma unroll\n      for (int i = 0; i < pack_size; i++) {\n        if (x_vec.elem[i] > zero_val) {\n          dx_vec.elem[i] = dy_vec.elem[i];\n          dalpha_vec.elem[i] = zero_val;\n        } else {\n          dx_vec.elem[i] = dy_vec.elem[i] * alpha_val;\n          dalpha_vec.elem[i] = dy_vec.elem[i] * x_vec.elem[i];\n        }\n      }\n      *(reinterpret_cast<LoadType*>(dx + linear_index)) = dx_vec.storage;\n      *(reinterpret_cast<LoadType*>(alpha_diff + linear_index)) = dalpha_vec.storage;\n    } else {\n#pragma 
unroll\n      for (int i = 0; i < pack_size; i++) {\n        if (x_vec.elem[i] > zero_val) {\n          dx_vec.elem[i] = dy_vec.elem[i];\n        } else {\n          dx_vec.elem[i] = dy_vec.elem[i] * alpha_val;\n        }\n      }\n      *(reinterpret_cast<LoadType*>(dx + linear_index)) = dx_vec.storage;\n    }\n  }\n\n  if (tail && global_thread_id < n_tail) {\n    const T tail_dy_val = tail_dy[global_thread_id];\n    if (tail_x[global_thread_id] > zero_val) {\n      tail_dx[global_thread_id] = tail_dy_val;\n      if (alpha_requires_grad) { tail_alpha_diff[global_thread_id] = zero_val; }\n    } else {\n      tail_dx[global_thread_id] = alpha_val * tail_dy_val;\n      if (alpha_requires_grad) {\n        tail_alpha_diff[global_thread_id] = tail_x[global_thread_id] * tail_dy_val;\n      }\n    }\n  }\n}\n\ntemplate<typename T>\n__global__ void BroadcastPReluMultiAlphaNaiveForwardGpu(const int32_t elem_cnt,\n                                                        const int32_t alpha_size,\n                                                        const int32_t inner_size, const T* x,\n                                                        const T* alpha, T* y) {\n  const T zero_val = static_cast<T>(0.0);\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {\n    const T x_i = x[i];\n    int32_t alpha_idx = (i / inner_size) % alpha_size;\n    y[i] = x_i > zero_val ? x_i : x_i * alpha[alpha_idx];\n  }\n}\n\ntemplate<typename T, typename IndexType, int pack_size>\n__global__ void PReluForwardMultiAlphaGpu(const IndexType elem_cnt, const IndexType alpha_size,\n                                          const IndexType inner_size, const T* x, const T* alpha,\n                                          T* y) {\n  int32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;\n\n  using LoadType = cuda::elementwise::PackType<T, pack_size>;\n  using LoadPack = cuda::elementwise::Pack<T, pack_size>;\n  T zero_val = static_cast<T>(0);\n  for (int64_t linear_index = global_thread_id * pack_size; linear_index < elem_cnt;\n       linear_index += gridDim.x * blockDim.x * pack_size) {\n    IndexType alpha_idx = (linear_index / inner_size) % alpha_size;\n\n    const LoadType* x_load = reinterpret_cast<const LoadType*>(x + linear_index);\n    LoadPack x_vec;\n    x_vec.storage = *x_load;\n\n    LoadPack y_vec;\n\n    T alpha_val = alpha[alpha_idx];\n#pragma unroll\n    for (int i = 0; i < pack_size; i++) {\n      y_vec.elem[i] = x_vec.elem[i] > zero_val ? 
x_vec.elem[i] : x_vec.elem[i] * alpha_val;\n    }\n    *(reinterpret_cast<LoadType*>(y + linear_index)) = y_vec.storage;\n  }\n}\n\ntemplate<typename T, bool alpha_requires_grad>\n__global__ void BroadcastPReluMultiAlphaNaiveBackwardGpu(const int32_t elem_cnt,\n                                                         const int32_t alpha_size,\n                                                         const int32_t inner_size, const T* x,\n                                                         const T* alpha, const T* dy, T* dx,\n                                                         T* alpha_diff) {\n  const T zero_val = static_cast<T>(0.0);\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {\n    const T x_i = x[i];\n    const T dy_i = dy[i];\n    int32_t alpha_i = (i / inner_size) % alpha_size;\n    if (x_i > zero_val) {\n      dx[i] = dy_i;\n      if (alpha_requires_grad) { alpha_diff[i] = zero_val; }\n    } else {\n      dx[i] = dy_i * alpha[alpha_i];\n      if (alpha_requires_grad) { alpha_diff[i] = dy_i * x_i; }\n    }\n  }\n}\n\ntemplate<typename T, typename IndexType, int pack_size, bool alpha_requires_grad>\n__global__ void PReluBackwardMultiAlphaGpu(const IndexType elem_cnt, const IndexType alpha_size,\n                                           const IndexType inner_size, const T* x, const T* alpha,\n                                           const T* dy, T* dx, T* alpha_diff) {\n  int32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;\n\n  using LoadType = cuda::elementwise::PackType<T, pack_size>;\n  using LoadPack = cuda::elementwise::Pack<T, pack_size>;\n  T zero_val = static_cast<T>(0);\n  for (int64_t linear_index = global_thread_id * pack_size; linear_index < elem_cnt;\n       linear_index += gridDim.x * blockDim.x * pack_size) {\n    IndexType alpha_idx = (linear_index / inner_size) % alpha_size;\n\n    const LoadType* x_load = reinterpret_cast<const LoadType*>(x + linear_index);\n    LoadPack x_vec;\n    x_vec.storage = *x_load;\n\n    const LoadType* dy_load = reinterpret_cast<const LoadType*>(dy + linear_index);\n    LoadPack dy_vec;\n    dy_vec.storage = *dy_load;\n\n    LoadPack dx_vec;\n    T alpha_val = alpha[alpha_idx];\n    if (alpha_requires_grad) {\n      LoadPack dalpha_vec;\n#pragma unroll\n      for (int i = 0; i < pack_size; i++) {\n        if (x_vec.elem[i] > zero_val) {\n          dx_vec.elem[i] = dy_vec.elem[i];\n          dalpha_vec.elem[i] = zero_val;\n        } else {\n          dx_vec.elem[i] = dy_vec.elem[i] * alpha_val;\n          dalpha_vec.elem[i] = dy_vec.elem[i] * x_vec.elem[i];\n        }\n      }\n      *(reinterpret_cast<LoadType*>(dx + linear_index)) = dx_vec.storage;\n      *(reinterpret_cast<LoadType*>(alpha_diff + linear_index)) = dalpha_vec.storage;\n    } else {\n#pragma unroll\n      for (int i = 0; i < pack_size; i++) {\n        if (x_vec.elem[i] > zero_val) {\n          dx_vec.elem[i] = dy_vec.elem[i];\n        } else {\n          dx_vec.elem[i] = dy_vec.elem[i] * alpha_val;\n        }\n      }\n      *(reinterpret_cast<LoadType*>(dx + linear_index)) = dx_vec.storage;\n    }\n  }\n}\n\nconstexpr int32_t kBlockSize = 256;\n\ntemplate<typename T>\nint GetLaunchPackSize(const int64_t inner_size) {\n  constexpr int type_pack_size = cuda::elementwise::PackSize<T>();\n  for (int launch_pack_size = 8; launch_pack_size > 0; launch_pack_size /= 2) {\n    if (type_pack_size >= launch_pack_size && inner_size % launch_pack_size == 0) {\n      return launch_pack_size;\n    }\n  }\n  return 
1;\n}\n\ntemplate<typename T, typename IndexType>\nvoid DispatchPreluForwardPackSize(ep::Stream* stream, const int64_t elem_cnt,\n                                  const int64_t alpha_size, const int64_t inner_size, const T* x,\n                                  const T* alpha, T* y) {\n  int grid_size;\n  const int pack_size = GetLaunchPackSize<T>(inner_size);\n  const int64_t pack_num = elem_cnt / pack_size;\n  cudaError_t err = cuda::elementwise::GetNumBlocks(pack_num, &grid_size);\n  OF_CUDA_CHECK(err);\n  if (pack_size == 8) {\n    PReluForwardMultiAlphaGpu<T, IndexType, 8>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            elem_cnt, alpha_size, inner_size, x, alpha, y);\n  } else if (pack_size == 4) {\n    PReluForwardMultiAlphaGpu<T, IndexType, 4>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            elem_cnt, alpha_size, inner_size, x, alpha, y);\n  } else if (pack_size == 2) {\n    PReluForwardMultiAlphaGpu<T, IndexType, 2>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            elem_cnt, alpha_size, inner_size, x, alpha, y);\n  } else {\n    BroadcastPReluMultiAlphaNaiveForwardGpu<T>\n        <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n            elem_cnt, alpha_size, inner_size, x, alpha, y);\n  }\n}\n\ntemplate<typename T>\nvoid DispatchPreluForwardIndex(ep::Stream* stream, const int64_t elem_cnt, const int64_t alpha_size,\n                               const int64_t inner_size, const T* x, const T* alpha, T* y) {\n  if (elem_cnt < GetMaxVal<int32_t>()) {\n    DispatchPreluForwardPackSize<T, int32_t>(stream, elem_cnt, alpha_size, inner_size, x, alpha, y);\n  } else {\n    DispatchPreluForwardPackSize<T, int64_t>(stream, elem_cnt, alpha_size, inner_size, x, alpha, y);\n  }\n}\n\ntemplate<typename T, typename IndexType>\nvoid DispatchPreluBackwardPackSize(ep::Stream* stream, const int64_t elem_cnt,\n                                   const int64_t alpha_size, const int64_t inner_size, const T* x,\n                                   const T* alpha, const T* dy, T* dx, T* alpha_diff,\n                                   const bool alpha_requires_grad) {\n  int grid_size;\n  const int pack_size = GetLaunchPackSize<T>(inner_size);\n  const int64_t pack_num = elem_cnt / pack_size;\n  cudaError_t err = cuda::elementwise::GetNumBlocks(pack_num, &grid_size);\n  OF_CUDA_CHECK(err);\n\n  if (pack_size == 8) {\n    if (alpha_requires_grad) {\n      PReluBackwardMultiAlphaGpu<T, IndexType, 8, true>\n          <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);\n    } else {\n      PReluBackwardMultiAlphaGpu<T, IndexType, 8, false>\n          <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);\n    }\n  } else if (pack_size == 4) {\n    if (alpha_requires_grad) {\n      PReluBackwardMultiAlphaGpu<T, IndexType, 4, true>\n          <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);\n    } else {\n      PReluBackwardMultiAlphaGpu<T, IndexType, 4, false>\n          <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);\n    }\n  } else if (pack_size == 2) {\n    
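// pack_size == 2 still takes the vectorized path; only pack_size == 1 falls\n    // back to the naive element-wise kernel in the final else branch.\n    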
if (alpha_requires_grad) {\n      PReluBackwardMultiAlphaGpu<T, IndexType, 2, true>\n          <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);\n    } else {\n      PReluBackwardMultiAlphaGpu<T, IndexType, 2, false>\n          <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);\n    }\n  } else {\n    if (alpha_requires_grad) {\n      BroadcastPReluMultiAlphaNaiveBackwardGpu<T, true>\n          <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);\n    } else {\n      BroadcastPReluMultiAlphaNaiveBackwardGpu<T, false>\n          <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);\n    }\n  }\n}\n\ntemplate<typename T>\nvoid DispatchPreluBackwardIndex(ep::Stream* stream, const int64_t elem_cnt,\n                                const int64_t alpha_size, const int64_t inner_size, const T* x,\n                                const T* alpha, const T* dy, T* dx, T* alpha_diff,\n                                const bool alpha_requires_grad) {\n  if (elem_cnt < GetMaxVal<int32_t>()) {\n    DispatchPreluBackwardPackSize<T, int32_t>(stream, elem_cnt, alpha_size, inner_size, x, alpha,\n                                              dy, dx, alpha_diff, alpha_requires_grad);\n  } else {\n    DispatchPreluBackwardPackSize<T, int64_t>(stream, elem_cnt, alpha_size, inner_size, x, alpha,\n                                              dy, dx, alpha_diff, alpha_requires_grad);\n  }\n}\n\ntemplate<typename T, typename IndexType>\nvoid DispatchPreluBackwardSingleAlphaTail(ep::Stream* stream, const IndexType elem_cnt, const T* x,\n                                          const T* alpha, const T* dy, T* dx, T* alpha_diff,\n                                          const bool alpha_requires_grad) {\n  constexpr int pack_size = cuda::elementwise::PackSize<T>();\n  const int64_t pack_num = elem_cnt / pack_size;\n  int grid_size;\n  cudaError_t err = cuda::elementwise::GetNumBlocks(pack_num, &grid_size);\n  OF_CUDA_CHECK(err);\n  const int64_t tail_offset = pack_num * pack_size;\n  const int64_t n_tail = elem_cnt - tail_offset;\n  const bool tail = n_tail > 0 ? 
true : false;\n  if (tail) {\n    if (alpha_requires_grad) {\n      PReluBackwardSingleAlphaGpu<T, IndexType, pack_size, true, true>\n          <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              elem_cnt, n_tail, x, alpha, dy, dx, alpha_diff, x + tail_offset, dy + tail_offset,\n              dx + tail_offset, alpha_diff + tail_offset);\n    } else {\n      PReluBackwardSingleAlphaGpu<T, IndexType, pack_size, true, false>\n          <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              elem_cnt, n_tail, x, alpha, dy, dx, alpha_diff, x + tail_offset, dy + tail_offset,\n              dx + tail_offset, alpha_diff + tail_offset);\n    }\n  } else {\n    if (alpha_requires_grad) {\n      PReluBackwardSingleAlphaGpu<T, IndexType, pack_size, false, true>\n          <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              elem_cnt, n_tail, x, alpha, dy, dx, alpha_diff, x + tail_offset, dy + tail_offset,\n              dx + tail_offset, alpha_diff + tail_offset);\n    } else {\n      PReluBackwardSingleAlphaGpu<T, IndexType, pack_size, false, false>\n          <<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              elem_cnt, n_tail, x, alpha, dy, dx, alpha_diff, x + tail_offset, dy + tail_offset,\n              dx + tail_offset, alpha_diff + tail_offset);\n    }\n  }\n}\n\ntemplate<typename T>\nvoid DispatchPreluBackwardSingleAlphaIndex(ep::Stream* stream, const int64_t elem_cnt, const T* x,\n                                           const T* alpha, const T* dy, T* dx, T* alpha_diff,\n                                           const bool alpha_requires_grad) {\n  if (elem_cnt < GetMaxVal<int32_t>()) {\n    DispatchPreluBackwardSingleAlphaTail<T, int32_t>(stream, elem_cnt, x, alpha, dy, dx, alpha_diff,\n                                                     alpha_requires_grad);\n  } else {\n    DispatchPreluBackwardSingleAlphaTail<T, int64_t>(stream, elem_cnt, x, alpha, dy, dx, alpha_diff,\n                                                     alpha_requires_grad);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuPReluKernel final : public user_op::OpKernel {\n public:\n  GpuPReluKernel() = default;\n  ~GpuPReluKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex(\"alpha\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const int32_t elem_cnt = x->shape_view().elem_cnt();\n    const int32_t batch = x->shape_view().At(0);\n    const int32_t channels = (x->shape_view().NumAxes() == 1) ? 
1 : x->shape_view().At(1);\n    const int32_t alpha_size = alpha->shape_view().elem_cnt();\n    const int32_t inner_size = elem_cnt / batch / channels;\n\n    if (alpha_size == 1) {\n      OF_CUDA_CHECK((cuda::elementwise::UnaryWithFactory(\n          PreluForwardSingleAlphaPtrFunctor<T>(reinterpret_cast<const T*>(alpha->dptr())), elem_cnt,\n          reinterpret_cast<T*>(y->mut_dptr()), reinterpret_cast<const T*>(x->dptr()),\n          ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n    } else {\n      DispatchPreluForwardIndex<T>(\n          ctx->stream(), elem_cnt, alpha_size, inner_size, reinterpret_cast<const T*>(x->dptr()),\n          reinterpret_cast<const T*>(alpha->dptr()), reinterpret_cast<T*>(y->mut_dptr()));\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_PRELU_KERNEL(dtype)                                             \\\n  REGISTER_USER_KERNEL(\"prelu\").SetCreateFn<GpuPReluKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCUDA)                                 \\\n      && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CUDA_PRELU_KERNEL(half)\nREGISTER_CUDA_PRELU_KERNEL(float)\nREGISTER_CUDA_PRELU_KERNEL(double)\n\ntemplate<typename T>\nclass GpuPReluGradKernel final : public user_op::OpKernel {\n public:\n  GpuPReluGradKernel() = default;\n  ~GpuPReluGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex(\"alpha\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    user_op::Tensor* alpha_diff = ctx->Tensor4ArgNameAndIndex(\"alpha_diff\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const bool alpha_requires_grad = ctx->Attr<bool>(\"alpha_requires_grad\");\n    const int32_t elem_cnt = x->shape_view().elem_cnt();\n    T* broadcasted_alpha_diff = tmp_buffer->mut_dptr<T>();\n    T* reduce_sum_tmp_buf = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>()\n                                                 + GetCudaAlignedSize(elem_cnt * sizeof(T)));\n\n    const int32_t batch = x->shape_view().At(0);\n    const int32_t channels = (x->shape_view().NumAxes() == 1) ? 
1 : x->shape_view().At(1);\n    const int32_t alpha_size = alpha->shape_view().elem_cnt();\n    const int32_t inner_size = elem_cnt / batch / channels;\n\n    const Shape& left_extended_shape =\n        CreatePreluLeftExtendedShape(ShapeView(x->shape_view()), alpha_size);\n    if (alpha_size == 1) {\n      DispatchPreluBackwardSingleAlphaIndex<T>(ctx->stream(), elem_cnt, x->dptr<T>(),\n                                               alpha->dptr<T>(), dy->dptr<T>(), dx->mut_dptr<T>(),\n                                               broadcasted_alpha_diff, alpha_requires_grad);\n    } else {\n      DispatchPreluBackwardIndex<T>(ctx->stream(), elem_cnt, alpha_size, inner_size, x->dptr<T>(),\n                                    alpha->dptr<T>(), dy->dptr<T>(), dx->mut_dptr<T>(),\n                                    broadcasted_alpha_diff, alpha_requires_grad);\n    }\n    if (alpha_requires_grad) {\n      NdarrayUtil<DeviceType::kCUDA, T>::ReduceSum(\n          ctx->stream(), XpuVarNdarray<T>(left_extended_shape, alpha_diff->mut_dptr<T>()),\n          XpuVarNdarray<const T>(x->shape_view(), broadcasted_alpha_diff),\n          XpuVarNdarray<T>(x->shape_view(), reduce_sum_tmp_buf));\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_PRELU_GRAD_KERNEL(dtype)                                          \\\n  REGISTER_USER_KERNEL(\"prelu_grad\")                                                    \\\n      .SetCreateFn<GpuPReluGradKernel<dtype>>()                                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                               \\\n        const Shape& in_shape = ctx->InputShape(\"x\", 0);                                \\\n        const Shape& alpha_shape = ctx->InputShape(\"alpha\", 0);                         \\\n        const int64_t tmp_buffer_size =                                                 \\\n            2 * GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(dtype));                \\\n        return tmp_buffer_size;                                                         \\\n      });\n\nREGISTER_CUDA_PRELU_GRAD_KERNEL(half)\nREGISTER_CUDA_PRELU_GRAD_KERNEL(float)\nREGISTER_CUDA_PRELU_GRAD_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/quantization_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n\n#include <algorithm>\n\nnamespace oneflow {\n\ntemplate<typename T>\nvoid QuantizationPerLayerSymmetric(const T* in_ptr, const T scale, const int32_t quantization_bit,\n                                   const int64_t num_elements, T* out_ptr) {\n  T upper_bound = static_cast<T>(pow(2.0, quantization_bit - 1)) - 1;\n  T lower_bound = -upper_bound - 1;\n  FOR_RANGE(int64_t, i, 0, num_elements) {\n    T out = std::nearbyint(in_ptr[i] / scale);\n    out = out > upper_bound ? upper_bound : out;\n    out = out < lower_bound ? lower_bound : out;\n    out_ptr[i] = out;\n  }\n}\n\ntemplate<typename T>\nvoid QuantizationPerLayerAffine(const T* in_ptr, const T scale, const T zero_point,\n                                const int32_t quantization_bit, const int64_t num_elements,\n                                T* out_ptr) {\n  T upper_bound = static_cast<T>(pow(2.0, quantization_bit)) - 1;\n  T lower_bound = 0;\n  uint8_t zero_point_uint8 = static_cast<uint8_t>(std::round(zero_point));\n  FOR_RANGE(int64_t, i, 0, num_elements) {\n    T out = std::nearbyint(in_ptr[i] / scale + zero_point_uint8);\n    out = out > upper_bound ? upper_bound : out;\n    out = out < lower_bound ? lower_bound : out;\n    out_ptr[i] = out;\n  }\n}\n\ntemplate<typename T>\nvoid QuantizationPerLayerCambricon(const T* in_ptr, const T shift, const int32_t quantization_bit,\n                                   const int64_t num_elements, T* out_ptr) {\n  T upper_bound = static_cast<T>(pow(2.0, quantization_bit - 1)) - 1;\n  T lower_bound = -upper_bound - 1;\n  T scale = static_cast<T>(pow(2.0, static_cast<int32_t>(shift)));\n  FOR_RANGE(int64_t, i, 0, num_elements) {\n    T out = std::nearbyint(in_ptr[i] / scale);\n    out = out > upper_bound ? upper_bound : out;\n    out = out < lower_bound ? 
lower_bound : out;\n    out_ptr[i] = out;\n  }\n}\n\ntemplate<typename T>\nclass CpuQuantizationKernel final : public user_op::OpKernel {\n public:\n  CpuQuantizationKernel() = default;\n  ~CpuQuantizationKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex(\"scale\", 0);\n    const user_op::Tensor* zero_point = ctx->Tensor4ArgNameAndIndex(\"zero_point\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const std::string quantization_scheme = ctx->Attr<std::string>(\"quantization_scheme\");\n    const int32_t quantization_bit = ctx->Attr<int32_t>(\"quantization_bit\");\n    const std::string quantization_formula = ctx->Attr<std::string>(\"quantization_formula\");\n\n    const T* in_ptr = in->dptr<T>();\n    const T* scale_ptr = scale->dptr<T>();\n    T* out_ptr = out->mut_dptr<T>();\n\n    // round to even\n    auto origin_round_mode = std::fegetround();\n    std::fesetround(FE_TONEAREST);\n\n    if (quantization_formula == \"google\") {\n      int64_t outer_num = 1;\n      int64_t inner_num = in->shape_view().elem_cnt();\n      if (scale->shape_view().elem_cnt() > 1) {  // per-channel quantization\n        outer_num = in->shape_view().At(0);\n        inner_num = in->shape_view().Count(1);\n      }\n\n      if (quantization_scheme == \"symmetric\") {\n        FOR_RANGE(int64_t, c, 0, outer_num) {\n          QuantizationPerLayerSymmetric(in_ptr, scale_ptr[c], quantization_bit, inner_num, out_ptr);\n          in_ptr += inner_num;\n          out_ptr += inner_num;\n        }\n      } else {  // quantization_scheme == \"affine\"\n        const T* zero_point_ptr = zero_point->dptr<T>();\n        FOR_RANGE(int64_t, c, 0, outer_num) {\n          QuantizationPerLayerAffine(in_ptr, scale_ptr[c], zero_point_ptr[c], quantization_bit,\n                                     inner_num, out_ptr);\n          in_ptr += inner_num;\n          out_ptr += inner_num;\n        }\n      }\n    } else if (quantization_formula == \"cambricon\") {\n      QuantizationPerLayerCambricon(in_ptr, scale_ptr[0], quantization_bit,\n                                    in->shape_view().elem_cnt(), out_ptr);\n    } else {\n      UNIMPLEMENTED();\n    }\n\n    std::fesetround(origin_round_mode);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_QUANTIZATION_KERNEL(dtype)                           \\\n  REGISTER_USER_KERNEL(\"quantization\")                                \\\n      .SetCreateFn<CpuQuantizationKernel<dtype>>()                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value))\n\nREGISTER_QUANTIZATION_KERNEL(float);\nREGISTER_QUANTIZATION_KERNEL(double);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/quantization_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.cuh\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__global__ void QuantizationSymmetric(const T* in_ptr, const T* scale_ptr, const int64_t scale_size,\n                                      const int64_t elements, const int64_t panel_size,\n                                      const double quantization_bit, T* out_ptr) {\n  int64_t gid = (blockDim.x * blockIdx.x) + threadIdx.x;\n  int64_t step = gridDim.x * blockDim.x;\n\n  T upper_bound = static_cast<T>(pow(2.0, quantization_bit - 1)) - 1;\n  T lower_bound = -upper_bound - 1;\n\n  while (gid < elements) {\n    int64_t channel_index = gid / panel_size;\n    int64_t scale_idx = min(scale_size - 1, channel_index);\n\n    T scale = scale_ptr[scale_idx];\n\n    T out = nearbyint(in_ptr[gid] / scale);\n    out = out > upper_bound ? upper_bound : out;\n    out = out < lower_bound ? lower_bound : out;\n    out_ptr[gid] = out;\n\n    gid += step;\n  }\n}\n\ntemplate<typename T>\n__global__ void QuantizationAffine(const T* in_ptr, const T* scale_ptr, const T* zero_point_ptr,\n                                   const int64_t scale_size, const int64_t elements,\n                                   const int64_t panel_size, const double quantization_bit,\n                                   T* out_ptr) {\n  int64_t gid = (blockDim.x * blockIdx.x) + threadIdx.x;\n  int64_t step = gridDim.x * blockDim.x;\n\n  T upper_bound = static_cast<T>(pow(2.0, quantization_bit)) - 1;\n  T lower_bound = 0;\n\n  while (gid < elements) {\n    int64_t channel_index = gid / panel_size;\n    int64_t scale_idx = min(scale_size - 1, channel_index);\n\n    T scale = scale_ptr[scale_idx];\n    T zero_point = zero_point_ptr[scale_idx];\n\n    T out = nearbyint(in_ptr[gid] / scale + zero_point);\n    out = out > upper_bound ? upper_bound : out;\n    out = out < lower_bound ? lower_bound : out;\n    out_ptr[gid] = out;\n\n    gid += step;\n  }\n}\n\ntemplate<typename T>\n__global__ void QuantizationCambricon(const T* in_ptr, const T* shift, const int64_t scale_size,\n                                      const int64_t elements, const int64_t panel_size,\n                                      const double quantization_bit, T* out_ptr) {\n  int64_t gid = (blockDim.x * blockIdx.x) + threadIdx.x;\n  int64_t step = gridDim.x * blockDim.x;\n\n  T upper_bound = static_cast<T>(pow(2.0, quantization_bit - 1)) - 1;\n  T lower_bound = -upper_bound - 1;\n\n  T scale = static_cast<T>(pow(2.0, static_cast<int32_t>(shift[0])));\n\n  while (gid < elements) {\n    T out = nearbyint(in_ptr[gid] / scale);\n    out = out > upper_bound ? upper_bound : out;\n    out = out < lower_bound ? 
lower_bound : out;\n    out_ptr[gid] = out;\n    gid += step;\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuQuantizationKernel final : public user_op::OpKernel {\n public:\n  GpuQuantizationKernel() = default;\n  ~GpuQuantizationKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex(\"scale\", 0);\n    const user_op::Tensor* zero_point = ctx->Tensor4ArgNameAndIndex(\"zero_point\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const std::string quantization_scheme = ctx->Attr<std::string>(\"quantization_scheme\");\n    const int32_t quantization_bit = ctx->Attr<int32_t>(\"quantization_bit\");\n    const std::string quantization_formula = ctx->Attr<std::string>(\"quantization_formula\");\n\n    const int64_t elements = in->shape_view().elem_cnt();\n    const int64_t panel_size = in->shape_view().Count(1);\n    const int64_t scale_size = scale->shape_view().elem_cnt();\n\n    // round to even\n    auto origin_round_mode = std::fegetround();\n    std::fesetround(FE_TONEAREST);\n\n    if (quantization_formula == \"google\") {\n      if (quantization_scheme == \"symmetric\") {\n        RUN_CUDA_KERNEL((QuantizationSymmetric<T>), ctx->stream(), elements, in->dptr<T>(),\n                        scale->dptr<T>(), scale_size, elements, panel_size, quantization_bit,\n                        out->mut_dptr<T>());\n      } else {  // quantization_scheme == \"affine\"\n        RUN_CUDA_KERNEL((QuantizationAffine<T>), ctx->stream(), elements, in->dptr<T>(),\n                        scale->dptr<T>(), zero_point->dptr<T>(), scale_size, elements, panel_size,\n                        quantization_bit, out->mut_dptr<T>());\n      }\n    } else if (quantization_formula == \"cambricon\") {\n      RUN_CUDA_KERNEL((QuantizationCambricon<T>), ctx->stream(), elements, in->dptr<T>(),\n                      scale->dptr<T>(), scale_size, elements, panel_size, quantization_bit,\n                      out->mut_dptr<T>());\n    } else {\n      UNIMPLEMENTED();\n    }\n\n    std::fesetround(origin_round_mode);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_QUANTIZATION_KERNEL(dtype)                            \\\n  REGISTER_USER_KERNEL(\"quantization\")                                 \\\n      .SetCreateFn<GpuQuantizationKernel<dtype>>()                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value))\n\nREGISTER_QUANTIZATION_KERNEL(float);\nREGISTER_QUANTIZATION_KERNEL(double);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/radix_sort.cuh",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_RADIX_SORT_CUH_\n#define ONEFLOW_USER_KERNELS_RADIX_SORT_CUH_\n\n#include <cub/cub.cuh>\n#include \"oneflow/core/device/cuda_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass MultiplyFunctor final {\n public:\n  MultiplyFunctor(int32_t num_col) : num_col_(num_col) {}\n  __host__ __device__ __forceinline__ int32_t operator()(int32_t idx) const {\n    return idx * num_col_;\n  }\n\n private:\n  int32_t num_col_;\n};\n\n}  // namespace\n\ntemplate<typename KeyType, typename ValueType>\nsize_t InferTempStorageForSortPairsAscending(int32_t num_row, int32_t num_col) {\n  size_t temp_storage_bytes = 0;\n  if (num_row > 1) {\n    using SegmentOffsetIter =\n        cub::TransformInputIterator<int32_t, MultiplyFunctor, cub::CountingInputIterator<int32_t>>;\n\n    cub::CountingInputIterator<int32_t> counting_iter(0);\n    MultiplyFunctor multiply_functor(num_col);\n    SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor);\n\n    auto err = cub::DeviceSegmentedRadixSort::SortPairs<KeyType, ValueType, SegmentOffsetIter>(\n        /* d_temp_storage */ nullptr,\n        /* temp_storage_bytes */ temp_storage_bytes,\n        /* d_keys_in */ nullptr,\n        /* d_keys_out */ nullptr,\n        /* d_values_in */ nullptr,\n        /* d_values_out */ nullptr,\n        /* num_items */ num_row * num_col,\n        /* num_segments */ num_row,\n        /* d_begin_offsets */ segment_offset_iter,\n        /* d_end_offsets */ segment_offset_iter + 1,\n        /* begin_bit */ 0,\n        /* end_bit */ sizeof(KeyType) * 8,\n        /* stream */ 0);\n    OF_CUDA_CHECK(err);\n  } else {\n    auto err = cub::DeviceRadixSort::SortPairs<KeyType, ValueType>(\n        /* d_temp_storage */ nullptr,\n        /* temp_storage_bytes */ temp_storage_bytes,\n        /* d_keys_in */ nullptr,\n        /* d_keys_out */ nullptr,\n        /* d_values_in */ nullptr,\n        /* d_values_out */ nullptr,\n        /* num_items */ num_row * num_col,\n        /* begin_bit */ 0,\n        /* end_bit */ sizeof(KeyType) * 8,\n        /* stream */ 0);\n    OF_CUDA_CHECK(err);\n  }\n\n  return temp_storage_bytes;\n}\n\ntemplate<typename KeyType, typename ValueType>\nsize_t InferTempStorageForSortPairsDescending(int32_t num_row, int32_t num_col) {\n  size_t temp_storage_bytes = 0;\n  if (num_row > 1) {\n    using SegmentOffsetIter =\n        cub::TransformInputIterator<int32_t, MultiplyFunctor, cub::CountingInputIterator<int32_t>>;\n\n    cub::CountingInputIterator<int32_t> counting_iter(0);\n    MultiplyFunctor multiply_functor(num_col);\n    SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor);\n\n    auto err =\n        cub::DeviceSegmentedRadixSort::SortPairsDescending<KeyType, ValueType, SegmentOffsetIter>(\n            /* d_temp_storage */ nullptr,\n            /* temp_storage_bytes */ temp_storage_bytes,\n            /* d_keys_in */ nullptr,\n            /* 
d_keys_out */ nullptr,\n            /* d_values_in */ nullptr,\n            /* d_values_out */ nullptr,\n            /* num_items */ num_row * num_col,\n            /* num_segments */ num_row,\n            /* d_begin_offsets */ segment_offset_iter,\n            /* d_end_offsets */ segment_offset_iter + 1,\n            /* begin_bit */ 0,\n            /* end_bit */ sizeof(KeyType) * 8,\n            /* stream */ 0);\n    OF_CUDA_CHECK(err);\n  } else {\n    auto err = cub::DeviceRadixSort::SortPairsDescending<KeyType, ValueType>(\n        /* d_temp_storage */ nullptr,\n        /* temp_storage_bytes */ temp_storage_bytes,\n        /* d_keys_in */ nullptr,\n        /* d_keys_out */ nullptr,\n        /* d_values_in */ nullptr,\n        /* d_values_out */ nullptr,\n        /* num_items */ num_row * num_col,\n        /* begin_bit */ 0,\n        /* end_bit */ sizeof(KeyType) * 8,\n        /* stream */ 0);\n    OF_CUDA_CHECK(err);\n  }\n\n  return temp_storage_bytes;\n}\n\ntemplate<typename KeyType>\nsize_t InferTempStorageForSortKeysAscending(int32_t num_row, int32_t num_col) {\n  size_t temp_storage_bytes = 0;\n  if (num_row > 1) {\n    using SegmentOffsetIter =\n        cub::TransformInputIterator<int32_t, MultiplyFunctor, cub::CountingInputIterator<int32_t>>;\n\n    cub::CountingInputIterator<int32_t> counting_iter(0);\n    MultiplyFunctor multiply_functor(num_col);\n    SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor);\n\n    auto err = cub::DeviceSegmentedRadixSort::SortKeys<KeyType, SegmentOffsetIter>(\n        /* d_temp_storage */ nullptr,\n        /* temp_storage_bytes */ temp_storage_bytes,\n        /* d_keys_in */ nullptr,\n        /* d_keys_out */ nullptr,\n        /* num_items */ num_row * num_col,\n        /* num_segments */ num_row,\n        /* d_begin_offsets */ segment_offset_iter,\n        /* d_end_offsets */ segment_offset_iter + 1,\n        /* begin_bit */ 0,\n        /* end_bit */ sizeof(KeyType) * 8,\n        /* stream */ 0);\n    OF_CUDA_CHECK(err);\n  } else {\n    auto err = cub::DeviceRadixSort::SortKeys<KeyType>(\n        /* d_temp_storage */ nullptr,\n        /* temp_storage_bytes */ temp_storage_bytes,\n        /* d_keys_in */ nullptr,\n        /* d_keys_out */ nullptr,\n        /* num_items */ num_row * num_col,\n        /* begin_bit */ 0,\n        /* end_bit */ sizeof(KeyType) * 8,\n        /* stream */ 0);\n    OF_CUDA_CHECK(err);\n  }\n  return temp_storage_bytes;\n}\n\ntemplate<typename KeyType>\nsize_t InferTempStorageForSortKeysDescending(int32_t num_row, int32_t num_col) {\n  size_t temp_storage_bytes = 0;\n  if (num_row > 1) {\n    using SegmentOffsetIter =\n        cub::TransformInputIterator<int32_t, MultiplyFunctor, cub::CountingInputIterator<int32_t>>;\n\n    cub::CountingInputIterator<int32_t> counting_iter(0);\n    MultiplyFunctor multiply_functor(num_col);\n    SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor);\n\n    auto err = cub::DeviceSegmentedRadixSort::SortKeysDescending<KeyType, SegmentOffsetIter>(\n        /* d_temp_storage */ nullptr,\n        /* temp_storage_bytes */ temp_storage_bytes,\n        /* d_keys_in */ nullptr,\n        /* d_keys_out */ nullptr,\n        /* num_items */ num_row * num_col,\n        /* num_segments */ num_row,\n        /* d_begin_offsets */ segment_offset_iter,\n        /* d_end_offsets */ segment_offset_iter + 1,\n        /* begin_bit */ 0,\n        /* end_bit */ sizeof(KeyType) * 8,\n        /* stream */ 0);\n    OF_CUDA_CHECK(err);\n  } else {\n    auto err = 
cub::DeviceRadixSort::SortKeysDescending<KeyType>(\n        /* d_temp_storage */ nullptr,\n        /* temp_storage_bytes */ temp_storage_bytes,\n        /* d_keys_in */ nullptr,\n        /* d_keys_out */ nullptr,\n        /* num_items */ num_row * num_col,\n        /* begin_bit */ 0,\n        /* end_bit */ sizeof(KeyType) * 8,\n        /* stream */ 0);\n    OF_CUDA_CHECK(err);\n  }\n\n  return temp_storage_bytes;\n}\n\ntemplate<typename KeyType, typename ValueType>\nvoid SortPairsAscending(const KeyType* keys_ptr, const ValueType* values_ptr, int32_t num_row,\n                        int32_t num_col, void* temp_storage_ptr, int32_t temp_storage_bytes,\n                        KeyType* sorted_keys_ptr, ValueType* sorted_values_ptr,\n                        cudaStream_t stream) {\n  size_t rt_inferred_temp_storage_bytes =\n      InferTempStorageForSortPairsAscending<KeyType, ValueType>(num_row, num_col);\n  CHECK_LE(rt_inferred_temp_storage_bytes, temp_storage_bytes);\n  if (num_row > 1) {\n    using SegmentOffsetIter =\n        cub::TransformInputIterator<int32_t, MultiplyFunctor, cub::CountingInputIterator<int32_t>>;\n\n    cub::CountingInputIterator<int32_t> counting_iter(0);\n    MultiplyFunctor multiply_functor(num_col);\n    SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor);\n\n    auto err = cub::DeviceSegmentedRadixSort::SortPairs(\n        /* d_temp_storage */ temp_storage_ptr,\n        /* temp_storage_bytes */ rt_inferred_temp_storage_bytes,\n        /* d_keys_in */ keys_ptr,\n        /* d_keys_out */ sorted_keys_ptr,\n        /* d_values_in */ values_ptr,\n        /* d_values_out */ sorted_values_ptr,\n        /* num_items */ num_row * num_col,\n        /* num_segments */ num_row,\n        /* d_begin_offsets */ segment_offset_iter,\n        /* d_end_offsets */ segment_offset_iter + 1,\n        /* begin_bit */ 0,\n        /* end_bit */ sizeof(KeyType) * 8,\n        /* stream */ stream);\n    OF_CUDA_CHECK(err);\n  } else {\n    auto err = cub::DeviceRadixSort::SortPairs(\n        /* d_temp_storage */ temp_storage_ptr,\n        /* temp_storage_bytes */ rt_inferred_temp_storage_bytes,\n        /* d_keys_in */ keys_ptr,\n        /* d_keys_out */ sorted_keys_ptr,\n        /* d_values_in */ values_ptr,\n        /* d_values_out */ sorted_values_ptr,\n        /* num_items */ num_row * num_col,\n        /* begin_bit */ 0,\n        /* end_bit */ sizeof(KeyType) * 8,\n        /* stream */ stream);\n    OF_CUDA_CHECK(err);\n  }\n}\n\ntemplate<typename KeyType, typename ValueType>\nvoid SortPairsDescending(const KeyType* keys_ptr, const ValueType* values_ptr, int32_t num_row,\n                         int32_t num_col, void* temp_storage_ptr, int32_t temp_storage_bytes,\n                         KeyType* sorted_keys_ptr, ValueType* sorted_values_ptr,\n                         cudaStream_t stream) {\n  size_t rt_inferred_temp_storage_bytes =\n      InferTempStorageForSortPairsDescending<KeyType, ValueType>(num_row, num_col);\n  CHECK_LE(rt_inferred_temp_storage_bytes, temp_storage_bytes);\n\n  if (num_row > 1) {\n    using SegmentOffsetIter =\n        cub::TransformInputIterator<int32_t, MultiplyFunctor, cub::CountingInputIterator<int32_t>>;\n\n    cub::CountingInputIterator<int32_t> counting_iter(0);\n    MultiplyFunctor multiply_functor(num_col);\n    SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor);\n\n    auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(\n        /* d_temp_storage */ temp_storage_ptr,\n        /* temp_storage_bytes */ 
rt_inferred_temp_storage_bytes,\n        /* d_keys_in */ keys_ptr,\n        /* d_keys_out */ sorted_keys_ptr,\n        /* d_values_in */ values_ptr,\n        /* d_values_out */ sorted_values_ptr,\n        /* num_items */ num_row * num_col,\n        /* num_segments */ num_row,\n        /* d_begin_offsets */ segment_offset_iter,\n        /* d_end_offsets */ segment_offset_iter + 1,\n        /* begin_bit */ 0,\n        /* end_bit */ sizeof(KeyType) * 8,\n        /* stream */ stream);\n    OF_CUDA_CHECK(err);\n  } else {\n    auto err = cub::DeviceRadixSort::SortPairsDescending(\n        /* d_temp_storage */ temp_storage_ptr,\n        /* temp_storage_bytes */ rt_inferred_temp_storage_bytes,\n        /* d_keys_in */ keys_ptr,\n        /* d_keys_out */ sorted_keys_ptr,\n        /* d_values_in */ values_ptr,\n        /* d_values_out */ sorted_values_ptr,\n        /* num_items */ num_row * num_col,\n        /* begin_bit */ 0,\n        /* end_bit */ sizeof(KeyType) * 8,\n        /* stream */ stream);\n    OF_CUDA_CHECK(err);\n  }\n}\n\ntemplate<typename KeyType>\nvoid SortKeysAscending(const KeyType* keys_ptr, int32_t num_row, int32_t num_col,\n                       void* temp_storage_ptr, int32_t temp_storage_bytes, KeyType* sorted_keys_ptr,\n                       cudaStream_t stream) {\n  size_t rt_inferred_temp_storage_bytes =\n      InferTempStorageForSortKeysAscending<KeyType>(num_row, num_col);\n  CHECK_LE(rt_inferred_temp_storage_bytes, temp_storage_bytes);\n\n  if (num_row > 1) {\n    using SegmentOffsetIter =\n        cub::TransformInputIterator<int32_t, MultiplyFunctor, cub::CountingInputIterator<int32_t>>;\n\n    cub::CountingInputIterator<int32_t> counting_iter(0);\n    MultiplyFunctor multiply_functor(num_col);\n    SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor);\n\n    auto err = cub::DeviceSegmentedRadixSort::SortKeys(\n        /* d_temp_storage */ temp_storage_ptr,\n        /* temp_storage_bytes */ rt_inferred_temp_storage_bytes,\n        /* d_keys_in */ keys_ptr,\n        /* d_keys_out */ sorted_keys_ptr,\n        /* num_items */ num_row * num_col,\n        /* num_segments */ num_row,\n        /* d_begin_offsets */ segment_offset_iter,\n        /* d_end_offsets */ segment_offset_iter + 1,\n        /* begin_bit */ 0,\n        /* end_bit */ sizeof(KeyType) * 8,\n        /* stream */ stream);\n    OF_CUDA_CHECK(err);\n  } else {\n    auto err = cub::DeviceRadixSort::SortKeys(\n        /* d_temp_storage */ temp_storage_ptr,\n        /* temp_storage_bytes */ rt_inferred_temp_storage_bytes,\n        /* d_keys_in */ keys_ptr,\n        /* d_keys_out */ sorted_keys_ptr,\n        /* num_items */ num_row * num_col,\n        /* begin_bit */ 0,\n        /* end_bit */ sizeof(KeyType) * 8,\n        /* stream */ stream);\n    OF_CUDA_CHECK(err);\n  }\n}\n\ntemplate<typename KeyType>\nvoid SortKeysDescending(const KeyType* keys_ptr, int32_t num_row, int32_t num_col,\n                        void* temp_storage_ptr, int32_t temp_storage_bytes,\n                        KeyType* sorted_keys_ptr, cudaStream_t stream) {\n  size_t rt_inferred_temp_storage_bytes =\n      InferTempStorageForSortKeysDescending<KeyType>(num_row, num_col);\n  CHECK_LE(rt_inferred_temp_storage_bytes, temp_storage_bytes);\n\n  if (num_row > 1) {\n    using SegmentOffsetIter =\n        cub::TransformInputIterator<int32_t, MultiplyFunctor, cub::CountingInputIterator<int32_t>>;\n\n    cub::CountingInputIterator<int32_t> counting_iter(0);\n    MultiplyFunctor multiply_functor(num_col);\n    
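// Segment i covers keys [i * num_col, (i + 1) * num_col): d_begin_offsets reads the\n    // transform iterator directly, and d_end_offsets reads the same iterator shifted by one.\n    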
SegmentOffsetIter segment_offset_iter(counting_iter, multiply_functor);\n\n    auto err = cub::DeviceSegmentedRadixSort::SortKeysDescending(\n        /* d_temp_storage */ temp_storage_ptr,\n        /* temp_storage_bytes */ rt_inferred_temp_storage_bytes,\n        /* d_keys_in */ keys_ptr,\n        /* d_keys_out */ sorted_keys_ptr,\n        /* num_items */ num_row * num_col,\n        /* num_segments */ num_row,\n        /* d_begin_offsets */ segment_offset_iter,\n        /* d_end_offsets */ segment_offset_iter + 1,\n        /* begin_bit */ 0,\n        /* end_bit */ sizeof(KeyType) * 8,\n        /* stream */ stream);\n    OF_CUDA_CHECK(err);\n  } else {\n    auto err = cub::DeviceRadixSort::SortKeysDescending(\n        /* d_temp_storage */ temp_storage_ptr,\n        /* temp_storage_bytes */ rt_inferred_temp_storage_bytes,\n        /* d_keys_in */ keys_ptr,\n        /* d_keys_out */ sorted_keys_ptr,\n        /* num_items */ num_row * num_col,\n        /* begin_bit */ 0,\n        /* end_bit */ sizeof(KeyType) * 8,\n        /* stream */ stream);\n    OF_CUDA_CHECK(err);\n  }\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_RADIX_SORT_CUH_\n"
  },
  {
    "path": "oneflow/user/kernels/random_crop_kernel_state.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/random_seed_util.h\"\n#include \"oneflow/user/kernels/random_crop_kernel_state.h\"\n\nnamespace oneflow {\n\nstd::shared_ptr<RandomCropKernelState> CreateRandomCropKernelState(\n    user_op::KernelInitContext* ctx) {\n  int32_t num_attempts = ctx->Attr<int32_t>(\"num_attempts\");\n  CHECK(num_attempts >= 1);\n  const std::vector<float>& random_aspect_ratio =\n      ctx->Attr<std::vector<float>>(\"random_aspect_ratio\");\n  CHECK(random_aspect_ratio.size() == 2 && 0 < random_aspect_ratio.at(0)\n        && random_aspect_ratio.at(0) <= random_aspect_ratio.at(1));\n  const std::vector<float>& random_area = ctx->Attr<std::vector<float>>(\"random_area\");\n  CHECK(random_area.size() == 2 && 0 < random_area.at(0) && random_area.at(0) <= random_area.at(1));\n  const user_op::TensorDesc* out_tensor_desc = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n  return std::shared_ptr<RandomCropKernelState>(new RandomCropKernelState(\n      out_tensor_desc->shape().elem_cnt(), CHECK_JUST(GetOpKernelRandomSeed(ctx)),\n      {random_aspect_ratio.at(0), random_aspect_ratio.at(1)},\n      {random_area.at(0), random_area.at(1)}, num_attempts));\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/random_crop_kernel_state.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_RANDOM_CROP_KERNEL_STATE_H_\n#define ONEFLOW_USER_KERNELS_RANDOM_CROP_KERNEL_STATE_H_\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/image/random_crop_generator.h\"\n\nnamespace oneflow {\n\nclass RandomCropKernelState final : public user_op::OpKernelState {\n public:\n  explicit RandomCropKernelState(int32_t size, int64_t seed, AspectRatioRange aspect_ratio_range,\n                                 AreaRange area_range, int32_t num_attempts)\n      : gens_(size) {\n    std::seed_seq seq{seed};\n    std::vector<int> seeds(size);\n    seq.generate(seeds.begin(), seeds.end());\n    for (int32_t i = 0; i < size; ++i) {\n      gens_.at(i).reset(\n          new RandomCropGenerator(aspect_ratio_range, area_range, seeds.at(i), num_attempts));\n    }\n  }\n  ~RandomCropKernelState() = default;\n\n  RandomCropGenerator* GetGenerator(int32_t idx) { return gens_.at(idx).get(); }\n\n private:\n  std::vector<std::shared_ptr<RandomCropGenerator>> gens_;\n};\n\nstd::shared_ptr<RandomCropKernelState> CreateRandomCropKernelState(user_op::KernelInitContext* ctx);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_RANDOM_CROP_KERNEL_STATE_H_\n"
  },
  {
    "path": "oneflow/user/kernels/random_mask_generator.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/random_mask_generator.h\"\n\nnamespace oneflow {\n\nvoid RandomMaskGenerator<DeviceType::kCPU>::Generate(ep::Stream* stream, const int64_t n,\n                                                     const float rate, bool* mask) {\n  CHECK_GE(n, 0);\n  std::uniform_real_distribution<float> random_distribution(GetZeroVal<float>(),\n                                                            GetOneVal<float>());\n  for (int64_t i = 0; i < n; ++i) { mask[i] = random_distribution(generator_->engine()) > rate; }\n}\n\ntemplate class RandomMaskGenerator<DeviceType::kCPU>;\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/random_mask_generator.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/random_mask_generator.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/user/kernels/distributions/distribution_template_util.cuh\"\n\nnamespace oneflow {\n\nnamespace {\n\nusing PackType = ulonglong2;\n\nunion Pack {\n  PackType p_value;\n  bool b_value[sizeof(PackType)];\n};\n\n__device__ bool GenMask(curandStatePhilox4_32_10_t* state, const float rate) {\n  return curand_uniform(state) > rate;\n}\n\n__global__ void GenerateGpu(uint64_t seed, uint64_t offset, const int64_t n, const float rate,\n                            bool* mask) {\n  const int id = blockIdx.x * blockDim.x + threadIdx.x;\n  curandStatePhilox4_32_10_t state;\n  curand_init(seed, id, offset, &state);\n  PackType* pack_mask = reinterpret_cast<PackType*>(mask);\n  Pack pack;\n  CUDA_1D_KERNEL_LOOP(i, n / sizeof(PackType)) {\n#pragma unroll\n    for (int j = 0; j < sizeof(PackType); j += 4) {\n      auto rand = curand_uniform4(&state);\n      pack.b_value[j] = (&rand.x)[0] > rate;\n      pack.b_value[j + 1] = (&rand.x)[1] > rate;\n      pack.b_value[j + 2] = (&rand.x)[2] > rate;\n      pack.b_value[j + 3] = (&rand.x)[3] > rate;\n    }\n    pack_mask[i] = pack.p_value;\n  }\n\n  const int32_t rem_cnt = n % sizeof(PackType);\n  const int32_t rem_offset = n - rem_cnt;\n  if (id < rem_cnt) { mask[id + rem_offset] = GenMask(&state, rate); }\n}\n\n}  // namespace\n\nvoid RandomMaskGenerator<DeviceType::kCUDA>::Generate(ep::Stream* stream, const int64_t n,\n                                                      const float rate, bool* mask) {\n  if (n == 0) return;\n  ep::CudaStream* cuda_stream = stream->As<ep::CudaStream>();\n  auto execution_policy = generator_->CalcExecutionPolicy(n, cuda_stream);\n\n  auto counter_offset = std::get<0>(execution_policy);\n  auto grid = std::get<1>(execution_policy);\n  auto block = std::get<2>(execution_policy);\n\n  uint64_t seed = generator_->current_seed();\n  uint64_t offset = generator_->get_philox_offset(counter_offset);\n\n  GenerateGpu<<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(seed, offset, n,\n                                                                               rate, mask);\n}\n\ntemplate class RandomMaskGenerator<DeviceType::kCUDA>;\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/random_mask_generator.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_RANDOM_MASK_GENERATOR_H_\n#define ONEFLOW_USER_KERNELS_RANDOM_MASK_GENERATOR_H_\n\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#ifdef WITH_CUDA\n#include <curand.h>\n#include <curand_kernel.h>\n#endif\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type>\nclass RandomMaskGenerator;\n\ntemplate<>\nclass RandomMaskGenerator<DeviceType::kCPU> final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RandomMaskGenerator);\n  RandomMaskGenerator(const std::shared_ptr<one::Generator>& generator,\n                      const int device_index = -1) {\n    generator_ = CHECK_JUST(generator->Get<ep::CPUGenerator>(device_index));\n  }\n  ~RandomMaskGenerator() = default;\n\n  void Generate(ep::Stream* stream, int64_t n, float rate, bool* mask);\n\n private:\n  std::shared_ptr<ep::CPUGenerator> generator_;\n};\n\n#ifdef WITH_CUDA\ntemplate<>\nclass RandomMaskGenerator<DeviceType::kCUDA> final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RandomMaskGenerator);\n  RandomMaskGenerator(const std::shared_ptr<one::Generator>& generator,\n                      const int device_index = -1) {\n    generator_ = CHECK_JUST(generator->Get<ep::CUDAGenerator>(device_index));\n  }\n  ~RandomMaskGenerator() = default;\n\n  void Generate(ep::Stream* stream, int64_t n, float rate, bool* mask);\n\n private:\n  std::shared_ptr<ep::CUDAGenerator> generator_;\n};\n#endif\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_RANDOM_MASK_GENERATOR_H_\n"
  },
  {
    "path": "oneflow/user/kernels/random_mask_like_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/random_mask_like_kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n#define REGISTER_RANDOM_MASK_LIKE_KERNEL(device)   \\\n  REGISTER_USER_KERNEL(\"random_mask_like\")         \\\n      .SetCreateFn<RandomMaskLikeKernel<device>>() \\\n      .SetIsMatchedHob(user_op::HobDeviceType() == device);\n\nREGISTER_RANDOM_MASK_LIKE_KERNEL(DeviceType::kCPU)\n#ifdef WITH_CUDA\nREGISTER_RANDOM_MASK_LIKE_KERNEL(DeviceType::kCUDA)\n#endif\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/random_mask_like_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_RANDOM_MASK_LIKE_KERNEL_H_\n#define ONEFLOW_USER_KERNELS_RANDOM_MASK_LIKE_KERNEL_H_\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/user/kernels/random_mask_generator.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n#include \"oneflow/core/ep/include/device.h\"\n\nnamespace oneflow {\n\nclass RandomMaskLikeKernelState : public user_op::OpKernelState {\n public:\n  explicit RandomMaskLikeKernelState(const std::shared_ptr<one::Generator>& generator)\n      : generator_(generator) {}\n\n  const std::shared_ptr<one::Generator>& generator() const { return generator_; }\n\n private:\n  std::shared_ptr<one::Generator> generator_;\n};\n\nnamespace {\n\ntemplate<DeviceType device_type>\nclass RandomMaskLikeKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  RandomMaskLikeKernel() = default;\n  ~RandomMaskLikeKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const auto& generator = CHECK_JUST(one::MakeGenerator(device_type));\n    generator->set_current_seed(\n        CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>(\"seed\"))));\n    return std::make_shared<RandomMaskLikeKernelState>(generator);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    const user_op::Tensor* like = ctx->Tensor4ArgNameAndIndex(\"like\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    int64_t elem_cnt = like->shape_view().elem_cnt();\n    bool* mask = out->mut_dptr<bool>();\n    auto* random_mask_like_state = dynamic_cast<RandomMaskLikeKernelState*>(state);\n    CHECK_NOTNULL(random_mask_like_state);\n    const auto& generator = random_mask_like_state->generator();\n    CHECK_NOTNULL(generator);\n    auto* stream = ctx->stream();\n    const auto device_index = stream->device()->device_index();\n    auto random_mask_like_gen =\n        std::make_shared<RandomMaskGenerator<device_type>>(generator, device_index);\n    random_mask_like_gen->Generate(stream, elem_cnt, ctx->Attr<float>(\"rate\"), mask);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_RANDOM_MASK_LIKE_KERNEL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/random_seed_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/random_seed_util.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/functional/impl/common.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n\nnamespace oneflow {\n\nMaybe<uint64_t> GetOpKernelRandomSeed(const user_op::KernelInitContext* ctx) {\n  int64_t seed = ctx->Attr<int64_t>(\"seed\");\n  if (!ctx->Attr<bool>(\"has_seed\")) { seed = NewRandomSeed(); }\n  return GetOpKernelRandomSeedInCurrentRank(ctx, seed);\n}\n\n// NOTE: Get random seed in current rank, and ensure that it will have same seed between\n// broadcast sbp and it will be different between split sbp.\n//\n// It will scan nd_sbp from last axis to first axis(It likes the algorithm in NdIndexOffsetHelper).\n// If sbp is broadcast, this axis will skip.\n// If sbp is split, it will use rand_id to accumulate the offset.\nMaybe<uint64_t> GetRandomSeedForRank(const ParallelDesc& placement, const NdSbp& nd_sbp,\n                                     uint64_t init_seed, int64_t rank_id) {\n  uint64_t seed = init_seed;\n  const Shape& hierarchy = *placement.hierarchy();\n  int64_t seed_idx = 0;\n  int64_t stride = 1;\n  for (int i = nd_sbp.sbp_parallel_size() - 1; i >= 0; --i) {\n    // coordinate at axis i\n    int coord = rank_id % hierarchy.At(i);\n    rank_id = (rank_id - coord) / hierarchy.At(i);\n    // coordinate reset to 0 if broadcast\n    if (nd_sbp.sbp_parallel(i).has_broadcast_parallel()) {\n      // do nothing\n    } else if (nd_sbp.sbp_parallel(i).has_split_parallel()) {\n      seed_idx += coord * stride;\n      stride *= hierarchy.At(i);\n    } else {\n      // other sbp is not allowed\n      return Error::RuntimeError() << \"random source op only support broadcast or split\";\n    }\n  }\n  std::seed_seq seq{init_seed};\n  std::vector<uint64_t> seeds(stride);\n  seq.generate(seeds.begin(), seeds.end());\n  seed = JUST(VectorAt(seeds, seed_idx));\n  return seed;\n}\n\nMaybe<uint64_t> GetOpKernelRandomSeedInCurrentRank(const user_op::KernelInitContext* ctx,\n                                                   uint64_t init_seed, const user_op::OpArg& arg) {\n  if (ctx->parallel_ctx().parallel_num() == 1) { return init_seed; }\n  CHECK_OR_RETURN(ctx->has_output(arg.name(), arg.index()))\n      << arg.name() << \"_\" << arg.index() << \" not exist\";\n  const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex(arg.name(), arg.index());\n  return GetRandomSeedForRank(ctx->parallel_desc(), nd_sbp, init_seed,\n                              ctx->parallel_ctx().parallel_id());\n}\n\nMaybe<one::Generator> GetGeneratorForLazyOrGlobal(const std::shared_ptr<one::Generator>& generator,\n                                                  bool is_lazy,\n                                                  const Optional<Symbol<ParallelDesc>>& placement,\n                                                  const Optional<Symbol<NdSbp>>& nd_sbp) {\n  bool is_global = 
placement.has_value() && nd_sbp.has_value();\n  if (!is_lazy && !is_global) { return generator; }\n\n  const auto& eager_cached_generator = generator->children_generators();\n\n  if (!is_lazy) {\n    Symbol<ParallelDesc> placement_val = JUST(placement);\n    Symbol<NdSbp> nd_sbp_val = JUST(nd_sbp);\n    if (eager_cached_generator.find(std::make_pair(placement_val, nd_sbp_val))\n        != eager_cached_generator.end()) {\n      return JUST(MapAt(eager_cached_generator, std::make_pair(placement_val, nd_sbp_val)));\n    }\n  }\n\n  uint64_t init_seed = 0;\n  if (is_lazy) {\n    auto cpu_gen = JUST(generator->Get<ep::CPUGenerator>(0));\n    CHECK_OR_RETURN(cpu_gen) << \"expect a CPUGenerator\";\n    init_seed = cpu_gen->engine()();\n  } else {\n    init_seed = generator->current_seed();\n  }\n  auto new_gen = JUST(one::MakeGenerator(JUST(generator->device())->type()));\n  if (is_lazy) {\n    new_gen->set_current_seed(init_seed);\n    return new_gen;\n  }\n\n  uint64_t rank_seed = init_seed;\n  if (JUST(placement)->parallel_num() > 1) {\n    JUST(one::functional::BroadcastSeedToAllRanks(&init_seed, /*root=*/0));\n    rank_seed = JUST(\n        GetRandomSeedForRank(*JUST(placement), *JUST(nd_sbp), init_seed, GlobalProcessCtx::Rank()));\n  }\n  new_gen->set_current_seed(rank_seed);\n  if (!is_lazy) { generator->add_children_generator(JUST(placement), JUST(nd_sbp), new_gen); }\n  return new_gen;\n}\n\nMaybe<one::Generator> GetGeneratorForLazyOrGlobal(const std::shared_ptr<one::Generator>& generator,\n                                                  bool is_lazy,\n                                                  const std::shared_ptr<one::Tensor>& input) {\n  if (input->is_global()) {\n    return GetGeneratorForLazyOrGlobal(generator, is_lazy, JUST(input->parallel_desc()),\n                                       JUST(input->nd_sbp()));\n  } else {\n    return GetGeneratorForLazyOrGlobal(generator, is_lazy, NullOpt, NullOpt);\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/random_seed_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_RANDOM_SEED_UTIL_H_\n#define ONEFLOW_USER_KERNELS_RANDOM_SEED_UTIL_H_\n\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n\nnamespace oneflow {\n\nMaybe<uint64_t> GetRandomSeedForRank(const ParallelDesc& placement, const NdSbp& nd_sbp,\n                                     uint64_t init_seed, int64_t rank_id);\n\nMaybe<uint64_t> GetOpKernelRandomSeed(const user_op::KernelInitContext* ctx);\nMaybe<uint64_t> GetOpKernelRandomSeedInCurrentRank(const user_op::KernelInitContext* ctx,\n                                                   uint64_t init_seed,\n                                                   const user_op::OpArg& arg = {\"out\", 0});\n\nMaybe<one::Generator> GetGeneratorForLazyOrGlobal(const std::shared_ptr<one::Generator>& generator,\n                                                  bool is_lazy,\n                                                  const Optional<Symbol<ParallelDesc>>& placement,\n                                                  const Optional<Symbol<NdSbp>>& nd_sbp);\n\nMaybe<one::Generator> GetGeneratorForLazyOrGlobal(const std::shared_ptr<one::Generator>& generator,\n                                                  bool is_lazy,\n                                                  const std::shared_ptr<one::Tensor>& input);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_RANDOM_SEED_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/randperm_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n#include \"oneflow/user/kernels/arange_kernel_util.h\"\n#include \"oneflow/user/kernels/distributions/common.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/register/tensor_slice_view.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\nclass CpuRandPermKernelCache final : public user_op::OpKernelCache {\n public:\n  CpuRandPermKernelCache(int32_t lower, int32_t upper) : lower_(lower), upper_(upper) {}\n  ~CpuRandPermKernelCache() override = default;\n\n  int32_t lower() const { return lower_; }\n  int32_t upper() const { return upper_; }\n\n private:\n  const int32_t lower_;\n  const int32_t upper_;\n};\n\nclass CpuRandPermKernel final : public user_op::OpKernel {\n public:\n  CpuRandPermKernel() = default;\n  ~CpuRandPermKernel() = default;\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n    if (parallel_num > 1) {\n      const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n      const Shape& hierarchy = *ctx->parallel_desc().hierarchy();\n      int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n      int32_t n = ctx->Attr<int32_t>(\"n\");\n      const Shape& logical_shape = Shape({n});\n      TensorSliceView view =\n          GetTensorSliceView4ParallelId(hierarchy, nd_sbp, logical_shape, parallel_id);\n      std::shared_ptr<CpuRandPermKernelCache> cache(\n          new CpuRandPermKernelCache(view.At(0).begin(), view.At(0).end()));\n      return cache;\n    } else {\n      return nullptr;\n    }\n  }\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const auto& generator = CHECK_JUST(one::MakeGenerator(kCPU));\n    generator->set_current_seed(\n        CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>(\"seed\"))));\n    return std::make_shared<DistributionKernelState>(generator);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache* cache) const override {\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    int32_t* output = out->mut_dptr<int32_t>();\n    const int32_t n = ctx->Attr<int32_t>(\"n\");\n    if (n == 0) { return; }\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    int32_t* temp = tmp_buffer->mut_dptr<int32_t>();\n    auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);\n    
CHECK_NOTNULL(distribution_state);\n    const auto& generator = distribution_state->generator();\n    CHECK_NOTNULL(generator);\n    const auto& cpu_generator = CHECK_JUST(generator->Get<ep::CPUGenerator>());\n    if (cache == nullptr) {\n      user_op::ArangeFunctor<DeviceType::kCPU, int32_t>()(ctx->stream(), 0, 1, n, output);\n      std::shuffle(output, output + n, cpu_generator->engine());\n    } else {\n      const auto* arange_cache = dynamic_cast<const CpuRandPermKernelCache*>(cache);\n      CHECK_NOTNULL(arange_cache);\n      user_op::ArangeFunctor<DeviceType::kCPU, int32_t>()(ctx->stream(), 0, 1, n, temp);\n      std::shuffle(temp, temp + n, cpu_generator->engine());\n      auto len = arange_cache->upper() - arange_cache->lower();\n      memcpy(output, temp + arange_cache->lower(), sizeof(int32_t) * len);\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"randperm\")\n    .SetCreateFn<CpuRandPermKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU))\n    .SetInferTmpSizeFn([](user_op::InferContext* ctx) {\n      const int32_t n = ctx->Attr<int32_t>(\"n\");\n      return n * sizeof(int32_t);\n    });\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/randperm_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <curand.h>\n#include <curand_kernel.h>\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n#include \"oneflow/user/kernels/arange_kernel_util.h\"\n#include \"oneflow/user/kernels/radix_sort.cuh\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n#include \"oneflow/user/kernels/distributions/common.h\"\n#include \"oneflow/user/kernels/distributions/distribution_template_util.cuh\"\n#include \"oneflow/core/ep/include/device.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/register/tensor_slice_view.h\"\n\nnamespace oneflow {\n__global__ void GeneKeysAndValues(const int32_t n, uint64_t seed, uint64_t offset, int32_t* values,\n                                  int32_t* keys) {\n  const int id = blockIdx.x * blockDim.x + threadIdx.x;\n  curandStatePhilox4_32_10_t state;\n  curand_init(seed, id, offset, &state);\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    keys[i] = curand(&state);\n    values[i] = i;\n  }\n}\n\n__global__ void tempcopy2output(const int32_t n, const int32_t offset, int32_t* temp,\n                                int32_t* output) {\n  CUDA_1D_KERNEL_LOOP(i, n) { output[i] = temp[offset + i]; }\n}\nclass GpuRandPermKernelCache final : public user_op::OpKernelCache {\n public:\n  GpuRandPermKernelCache(int32_t lower, int32_t upper) : lower_(lower), upper_(upper) {}\n  ~GpuRandPermKernelCache() override = default;\n\n  int32_t lower() const { return lower_; }\n  int32_t upper() const { return upper_; }\n\n private:\n  const int32_t lower_;\n  const int32_t upper_;\n};\n\nnamespace {\n\ntemplate<typename K>\nsize_t GetCubSortPairsTempStorageSize(int64_t n) {\n  size_t cub_sort_temp_store_size = 0;\n  OF_CUDA_CHECK((cub::DeviceRadixSort::SortPairs<K, K>(nullptr, cub_sort_temp_store_size, nullptr,\n                                                       nullptr, nullptr, nullptr, n)));\n  size_t temp_store_size = GetCudaAlignedSize(cub_sort_temp_store_size);\n  CHECK_GE(temp_store_size, 0) << \"temp_store_size should >= 0.\";\n  CHECK_LT(temp_store_size, static_cast<size_t>(GetMaxVal<int64_t>()))\n      << \"temp_store_size should < \" << static_cast<size_t>(GetMaxVal<int64_t>());\n  return temp_store_size;\n}\n\n}  // namespace\n\nclass GpuRandPermKernel final : public user_op::OpKernel {\n public:\n  GpuRandPermKernel() = default;\n  ~GpuRandPermKernel() = default;\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n    if (parallel_num > 1) {\n      const NdSbp& nd_sbp 
= ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n      const Shape& hierarchy = *ctx->parallel_desc().hierarchy();\n      int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n      int32_t n = ctx->Attr<int32_t>(\"n\");\n      const Shape& logical_shape = Shape({n});\n      TensorSliceView view =\n          GetTensorSliceView4ParallelId(hierarchy, nd_sbp, logical_shape, parallel_id);\n      std::shared_ptr<GpuRandPermKernelCache> cache(\n          new GpuRandPermKernelCache(view.At(0).begin(), view.At(0).end()));\n      return cache;\n    } else {\n      return nullptr;\n    }\n  }\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const auto& generator = CHECK_JUST(one::MakeGenerator(kCUDA));\n    generator->set_current_seed(\n        CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>(\"seed\"))));\n    return std::make_shared<DistributionKernelState>(generator);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache* cache) const override {\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    int32_t* output = out->mut_dptr<int32_t>();\n    const int32_t n = ctx->Attr<int32_t>(\"n\");\n    if (n == 0) { return; }\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);\n    CHECK_NOTNULL(distribution_state);\n    const auto& generator = distribution_state->generator();\n    CHECK_NOTNULL(generator);\n    auto* stream = ctx->stream();\n    const auto device_index = stream->device()->device_index();\n    const auto& gpu_generator = CHECK_JUST(generator->Get<ep::CUDAGenerator>(device_index));\n\n    ep::CudaStream* cuda_stream = stream->As<ep::CudaStream>();\n    auto execution_policy = gpu_generator->CalcExecutionPolicy(n, cuda_stream);\n\n    auto counter_offset = std::get<0>(execution_policy);\n    auto grid = std::get<1>(execution_policy);\n    auto block = std::get<2>(execution_policy);\n\n    uint64_t seed = gpu_generator->current_seed();\n    uint64_t offset = gpu_generator->get_philox_offset(counter_offset);\n\n    // layout for tmp |...key(in and out,2xN)..|....value....|.... 
space for sort function....|\n    // values are the desired indices, and keys are generated randomly.\n    void* tmp = tmp_buffer->mut_dptr<void>();\n    int32_t* key_base = reinterpret_cast<int32_t*>(tmp);\n\n    const int32_t key_aligned_bytes = GetCudaAlignedSize(n * sizeof(int32_t));\n    int32_t* value_base =\n        reinterpret_cast<int32_t*>(reinterpret_cast<char*>(key_base) + 2 * key_aligned_bytes);\n    const int32_t indices_aligned_bytes = GetCudaAlignedSize(n * sizeof(int32_t));\n    int32_t* temp_buffer_base =\n        reinterpret_cast<int32_t*>(reinterpret_cast<char*>(value_base) + indices_aligned_bytes);\n    const int32_t temp_buffer_aligned_bytes = GetCudaAlignedSize(n * sizeof(int32_t));\n\n    void* tmp_base = reinterpret_cast<void*>(reinterpret_cast<char*>(temp_buffer_base)\n                                             + temp_buffer_aligned_bytes);\n    size_t temp_storage_bytes = GetCubSortPairsTempStorageSize<int32_t>(n);\n    GeneKeysAndValues<<<grid, block, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        n, seed, offset, value_base, key_base);\n    if (cache == nullptr) {\n      auto err = cub::DeviceRadixSort::SortPairs(\n          /* d_temp_storage */ tmp_base,\n          /* temp_storage_bytes */ temp_storage_bytes,\n          /* d_keys_in */ key_base,\n          /* d_keys_out */ key_base + n,\n          /* d_values_in */ value_base,\n          /* d_values_out */ output,\n          /* num_items */ n,\n          /* begin_bit */ 0,\n          /* end_bit */ sizeof(int32_t) * 8,\n          /* stream */ ctx->stream()->As<ep::CudaStream>()->cuda_stream());\n      OF_CUDA_CHECK(err);\n    } else {\n      auto err = cub::DeviceRadixSort::SortPairs(\n          /* d_temp_storage */ tmp_base,\n          /* temp_storage_bytes */ temp_storage_bytes,\n          /* d_keys_in */ key_base,\n          /* d_keys_out */ key_base + n,\n          /* d_values_in */ value_base,\n          /* d_values_out */ temp_buffer_base,\n          /* num_items */ n,\n          /* begin_bit */ 0,\n          /* end_bit */ sizeof(int32_t) * 8,\n          /* stream */ ctx->stream()->As<ep::CudaStream>()->cuda_stream());\n      OF_CUDA_CHECK(err);\n      const auto* randperm_cache = dynamic_cast<const GpuRandPermKernelCache*>(cache);\n      CHECK_NOTNULL(randperm_cache);\n      auto len = randperm_cache->upper() - randperm_cache->lower();\n      const int64_t output_offset = randperm_cache->lower();\n      int32_t block_num = gpu_generator->max_block_num();\n      tempcopy2output<<<block_num, kCudaThreadsNumPerBlock, 0,\n                        ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          len, output_offset, temp_buffer_base, output);\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\nREGISTER_USER_KERNEL(\"randperm\")\n    .SetCreateFn<GpuRandPermKernel>()\n    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA)\n    .SetInferTmpSizeFn([](user_op::InferContext* ctx) {\n      const int32_t n = ctx->Attr<int32_t>(\"n\");\n      /* Sorted In */\n      const int32_t sorted_in_aligned_bytes = 2 * GetCudaAlignedSize(n * sizeof(int32_t));\n      /* Indices */\n      const int32_t indices_aligned_bytes = GetCudaAlignedSize(n * sizeof(int32_t));\n      const int32_t temp_aligned_bytes = GetCudaAlignedSize(n * sizeof(int32_t));\n\n      /* CUB Temp Storage */\n      const int32_t temp_storage_bytes = GetCubSortPairsTempStorageSize<int32_t>(n);\n\n      return sorted_in_aligned_bytes + indices_aligned_bytes + temp_storage_bytes\n             + temp_aligned_bytes;\n    });\n\n}  
// namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/raw_reader_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/buffer.h\"\n#include \"oneflow/core/embedding/posix_file.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/common/channel.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstruct Block {\n  size_t file_index;\n  size_t offset_in_file;\n};\n\nstruct BatchReaderRequest {\n  std::shared_ptr<std::vector<size_t>> blocks;\n  void* buffer{};\n};\n\nclass BatchReader {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BatchReader);\n  BatchReader(std::vector<std::unique_ptr<embedding::PosixFile>>&& files,\n              std::vector<Block>&& blocks, size_t block_size_bytes, size_t num_workers)\n      : head_(0),\n        tail_(0),\n        files_(std::move(files)),\n        blocks_(blocks),\n        block_size_bytes_(block_size_bytes),\n        num_workers_(num_workers) {\n    for (size_t i = 0; i < num_workers_; ++i) {\n      Worker worker;\n      auto* sq = new Channel<BatchReaderRequest>();\n      auto* cq = new Channel<BatchReaderRequest>();\n      worker.sq.reset(sq);\n      worker.cq.reset(cq);\n      worker.thread = std::thread([sq, cq, this]() {\n        while (true) {\n          BatchReaderRequest request;\n          auto status = sq->Receive(&request);\n          if (status == kChannelStatusErrorClosed) { break; }\n          CHECK_EQ(status, kChannelStatusSuccess) << \"channel error\";\n          size_t buffer_offset = 0;\n          for (size_t i = 0; i < request.blocks->size(); ++i) {\n            size_t block_index = request.blocks->at(i);\n            const Block& block = blocks_[block_index];\n            size_t remaining = block_size_bytes_;\n            size_t file_index = block.file_index;\n            size_t file_offset = block.offset_in_file;\n            while (remaining != 0) {\n              const size_t bytes_to_read =\n                  std::min(remaining, files_.at(file_index)->Size() - file_offset);\n              PCHECK(pread(files_[file_index]->fd(),\n                           reinterpret_cast<unsigned char*>(request.buffer) + buffer_offset,\n                           bytes_to_read, file_offset)\n                     == bytes_to_read)\n                  << \"file read error\";\n              remaining -= bytes_to_read;\n              buffer_offset += bytes_to_read;\n              if (remaining != 0) {\n                file_index = (file_index + 1) % files_.size();\n                file_offset = 0;\n              }\n            }\n          }\n          CHECK(cq->Send(std::move(request)) == kChannelStatusSuccess) << \"channel error\";\n        }\n      });\n      workers_.emplace_back(std::move(worker));\n    }\n  }\n  ~BatchReader() {\n    for (auto& work : workers_) { work.Close(); }\n  }\n\n  void SubmitRequest(BatchReaderRequest&& request) {\n    size_t worker_id = head_.fetch_add(1, std::memory_order_relaxed) 
% workers_.size();\n    workers_.at(worker_id).sq->Send(std::move(request));\n  }\n  void WaitCompleted(BatchReaderRequest* request) {\n    size_t worker_id = tail_.fetch_add(1, std::memory_order_relaxed) % workers_.size();\n    workers_.at(worker_id).cq->Receive(request);\n  }\n\n private:\n  struct Worker {\n    std::thread thread;\n    std::unique_ptr<Channel<BatchReaderRequest>> sq;\n    std::unique_ptr<Channel<BatchReaderRequest>> cq;\n    void Close() {\n      sq->Close();\n      cq->Close();\n      thread.join();\n    }\n  };\n  std::atomic<size_t> head_;\n  std::atomic<size_t> tail_;\n  std::vector<Worker> workers_;\n  std::vector<std::unique_ptr<embedding::PosixFile>> files_;\n  std::vector<Block> blocks_;\n  size_t block_size_bytes_;\n  size_t num_workers_;\n};\n\nsize_t GetNumShards(const Shape& hierarchy, const NdSbp& nd_sbp) {\n  size_t num_shards = 1;\n  FOR_RANGE(size_t, i, 0, nd_sbp.sbp_parallel_size()) {\n    const auto& sbp_parallel = nd_sbp.sbp_parallel(i);\n    if (sbp_parallel.has_split_parallel()) {\n      num_shards *= hierarchy.At(sbp_parallel.split_parallel().axis());\n    }\n  }\n  return num_shards;\n}\n\nsize_t GetShardIndex(const Shape& hierarchy, const NdSbp& nd_sbp, size_t rank) {\n  using index_helper_t = NdIndexOffsetHelper<int64_t, SHAPE_MAX_AXIS_SIZE>;\n  size_t ndim = hierarchy.NumAxes();\n  CHECK_GT(ndim, 0) << \"wrong hierarchy\";\n  CHECK_LE(ndim, SHAPE_MAX_AXIS_SIZE) << \"wrong hierarchy\";\n  index_helper_t index_helper(hierarchy.dim_vec().data(), ndim);\n  int64_t nd_index[SHAPE_MAX_AXIS_SIZE] = {0};\n  index_helper.OffsetToNdIndex(rank, nd_index);\n  size_t stride = 1;\n  size_t index = 0;\n  for (int i = ndim - 1; i >= 0; --i) {\n    const auto& sbp_parallel = nd_sbp.sbp_parallel(i);\n    if (sbp_parallel.has_split_parallel()) {\n      index += nd_index[i] * stride;\n      stride *= hierarchy.At(i);\n    }\n  }\n  return index;\n}\n\nclass BatchGenerator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(BatchGenerator);\n  BatchGenerator() = default;\n  virtual ~BatchGenerator() = default;\n\n  virtual void Next(size_t* blocks) = 0;\n};\n\nclass SequentialBatchGenerator : public BatchGenerator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(SequentialBatchGenerator);\n  SequentialBatchGenerator(size_t shard_index, size_t num_shards, size_t num_batches,\n                           size_t num_blocks_per_batch)\n      : shard_index_(shard_index),\n        num_shards_(num_shards),\n        num_batches_(num_batches),\n        num_blocks_per_batch_(num_blocks_per_batch),\n        num_blocks_per_local_batch_(num_blocks_per_batch_ / num_shards_),\n        next_batch_index_(0) {}\n  ~SequentialBatchGenerator() override = default;\n\n  void Next(size_t* blocks) override {\n    const size_t batch_index = next_batch_index_;\n    next_batch_index_ = (batch_index + 1) % num_batches_;\n    for (size_t i = 0; i < num_blocks_per_local_batch_; ++i) {\n      blocks[i] =\n          batch_index * num_blocks_per_batch_ + shard_index_ * num_blocks_per_local_batch_ + i;\n    }\n  }\n\n private:\n  size_t shard_index_;\n  size_t num_shards_;\n  size_t num_batches_;\n  size_t num_blocks_per_batch_;\n  size_t num_blocks_per_local_batch_;\n  size_t next_batch_index_;\n};\n\nclass RandomShuffleBatchGenerator : public BatchGenerator {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RandomShuffleBatchGenerator);\n  RandomShuffleBatchGenerator(size_t shard_index, size_t num_shards, size_t num_batches,\n                              size_t num_blocks_per_batch, std::mt19937_64 generator)\n      : 
shard_index_(shard_index),\n        num_shards_(num_shards),\n        num_batches_(num_batches),\n        num_blocks_per_batch_(num_blocks_per_batch),\n        num_blocks_per_local_batch_(num_blocks_per_batch_ / num_shards_),\n        current_batch_pos_(0),\n        generator_(generator) {\n    batches_.resize(num_batches_);\n    std::iota(batches_.begin(), batches_.end(), 0);\n  }\n  ~RandomShuffleBatchGenerator() override = default;\n\n  void Next(size_t* blocks) override {\n    size_t target_batch_pos =\n        generator_() % (batches_.size() - current_batch_pos_) + current_batch_pos_;\n    if (target_batch_pos != current_batch_pos_) {\n      std::swap(batches_[target_batch_pos], batches_[current_batch_pos_]);\n    }\n    const size_t batch_index = batches_[current_batch_pos_];\n    for (size_t i = 0; i < num_blocks_per_local_batch_; ++i) {\n      blocks[i] =\n          batch_index * num_blocks_per_batch_ + shard_index_ * num_blocks_per_local_batch_ + i;\n    }\n    current_batch_pos_ = (current_batch_pos_ + 1) % batches_.size();\n    if (current_batch_pos_ == 0) { shard_index_ = (shard_index_ + 1) % num_shards_; }\n  }\n\n private:\n  size_t shard_index_;\n  size_t num_shards_;\n  size_t num_batches_;\n  size_t num_blocks_per_batch_;\n  size_t num_blocks_per_local_batch_;\n  std::vector<size_t> batches_;\n  size_t current_batch_pos_;\n  std::mt19937_64 generator_;\n};\n\nclass RawReaderKernelState final : public user_op::OpKernelState {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RawReaderKernelState);\n  explicit RawReaderKernelState(user_op::KernelInitContext* ctx) {\n    const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n    num_shards_ = GetNumShards(*ctx->parallel_desc().hierarchy(), nd_sbp);\n    shard_index_ =\n        GetShardIndex(*ctx->parallel_desc().hierarchy(), nd_sbp, ctx->parallel_ctx().parallel_id());\n    batch_size_ = ctx->Attr<int64_t>(\"batch_size\");\n    CHECK_EQ(batch_size_ % num_shards_, 0) << \"batch_size must be a multiple of num_shards\";\n    local_batch_size_ = batch_size_ / num_shards_;\n    random_shuffle_ = ctx->Attr<bool>(\"random_shuffle\");\n    block_size_ = ctx->Attr<int64_t>(\"shuffle_block_size\");\n    if (block_size_ <= 0 || !random_shuffle_) { block_size_ = local_batch_size_; }\n    CHECK_EQ(batch_size_ % block_size_, 0) << \"batch_size must be a multiple of block_size\";\n    if (block_size_ > local_batch_size_) { block_size_ = local_batch_size_; }\n    const std::vector<std::string>& filenames = ctx->Attr<std::vector<std::string>>(\"files\");\n    const Shape& instance_shape = ctx->Attr<Shape>(\"shape\");\n    const size_t elem_cnt = instance_shape.elem_cnt();\n    CHECK_GT(elem_cnt, 0) << \"instance size must be greater than 0\";\n    DimVector dim_vec;\n    dim_vec.push_back(local_batch_size_);\n    for (int64_t i = 0; i < instance_shape.NumAxes(); ++i) {\n      dim_vec.push_back(instance_shape.At(i));\n    }\n    out_shape_ = Shape(dim_vec);\n    data_type_ = ctx->Attr<DataType>(\"data_type\");\n    instance_size_ = ctx->Attr<Shape>(\"shape\").elem_cnt() * GetSizeOfDataType(data_type_);\n    CHECK_GT(batch_size_, 0) << \"batch size must be greater than 0\";\n    size_t num_instances = 0;\n    std::vector<std::unique_ptr<embedding::PosixFile>> files;\n    int flags = O_RDONLY;\n    if (ParseBooleanFromEnv(\"ONEFLOW_RAW_READER_FORCE_DIRECT_IO\", false)) { flags |= O_DIRECT; }\n    for (const auto& filename : filenames) {\n      std::unique_ptr<embedding::PosixFile> file(new embedding::PosixFile(filename, flags, 0644));\n      if 
(file->Size() == 0) { continue; }\n      CHECK_EQ(file->Size() % instance_size_, 0) << \"file_size must be a multiple of instance_size\";\n      num_instances += file->Size() / instance_size_;\n      files.emplace_back(std::move(file));\n    }\n    if ((flags & O_DIRECT) != 0) {\n      num_batches_ = num_instances / batch_size_;\n    } else {\n      num_batches_ = RoundUp(num_instances, batch_size_) / batch_size_;\n    }\n    block_size_bytes_ = block_size_ * instance_size_;\n    local_batch_size_bytes_ = local_batch_size_ * instance_size_;\n    num_blocks_per_local_batch_ = local_batch_size_ / block_size_;\n    const size_t num_blocks = num_batches_ * (batch_size_ / block_size_);\n    size_t file_index = 0;\n    size_t offset_in_file = 0;\n    std::vector<Block> blocks;\n    // Map each fixed-size block to a (file, offset) pair, treating the files as one contiguous\n    // circular byte stream.\n    for (size_t i = 0; i < num_blocks; ++i) {\n      blocks.emplace_back(Block{file_index, offset_in_file});\n      size_t remaining = block_size_bytes_;\n      while (remaining != 0) {\n        if (files[file_index]->Size() - offset_in_file >= remaining) {\n          offset_in_file += remaining;\n          if (offset_in_file == files[file_index]->Size()) { offset_in_file = 0; }\n          remaining = 0;\n        } else {\n          remaining -= (files[file_index]->Size() - offset_in_file);\n          offset_in_file = 0;\n          file_index = (file_index + 1) % files.size();\n        }\n      }\n    }\n    if (random_shuffle_) {\n      std::mt19937_64 generator;\n      generator.seed(ctx->Attr<int64_t>(\"seed\"));\n      std::shuffle(blocks.begin(), blocks.end(), generator);\n      batch_generator_.reset(new RandomShuffleBatchGenerator(\n          shard_index_, num_shards_, num_batches_, batch_size_ / block_size_, generator));\n    } else {\n      batch_generator_.reset(new SequentialBatchGenerator(shard_index_, num_shards_, num_batches_,\n                                                          batch_size_ / block_size_));\n    }\n    const size_t num_workers = ParseIntegerFromEnv(\"ONEFLOW_RAW_READER_NUM_WORKERS\", 1);\n    batch_reader_.reset(\n        new BatchReader(std::move(files), std::move(blocks), block_size_bytes_, num_workers));\n    prefetching_qd_ = ParseIntegerFromEnv(\"ONEFLOW_RAW_READER_PREFETCHING_QUEUE_DEPTH\", 256);\n    // Pre-fill the request queue so the reader workers always have reads in flight.\n    for (size_t i = 0; i < prefetching_qd_; ++i) {\n      BatchReaderRequest request;\n      if (ctx->device_type() == DeviceType::kCPU) {\n        request.buffer = aligned_alloc(4096, RoundUp(local_batch_size_bytes_, 4096));  // NOLINT\n      } else if (ctx->device_type() == DeviceType::kCUDA) {\n#ifdef WITH_CUDA\n        int dev = 0;\n        OF_CUDA_CHECK(cudaGetDevice(&dev));\n        OF_CUDA_CHECK(NumaAwareCudaMallocHost(dev, &request.buffer, local_batch_size_bytes_));\n#else\n        UNIMPLEMENTED();\n#endif\n      } else {\n        UNIMPLEMENTED();\n      }\n      request.blocks = std::make_shared<std::vector<size_t>>(local_batch_size_ / block_size_);\n      batch_generator_->Next(request.blocks->data());\n      batch_reader_->SubmitRequest(std::move(request));\n    }\n    device_type_ = ctx->device_type();\n  }\n\n  ~RawReaderKernelState() {\n    for (size_t i = 0; i < prefetching_qd_; ++i) {\n      BatchReaderRequest request;\n      batch_reader_->WaitCompleted(&request);\n      if (device_type_ == DeviceType::kCPU) {\n        free(request.buffer);  // NOLINT\n      } else if (device_type_ == DeviceType::kCUDA) {\n#ifdef WITH_CUDA\n        OF_CUDA_CHECK(cudaFreeHost(request.buffer));\n#else\n        
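// unreachable in builds without CUDA: no CUDA buffer could have been allocated above\n        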
UNIMPLEMENTED();\n#endif\n      } else {\n        UNIMPLEMENTED();\n      }\n    }\n  }\n\n  void Next(user_op::KernelComputeContext* ctx) {\n    auto* tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(tensor->data_type(), data_type_) << \"data type mismatch\";\n    CHECK(tensor->shape_view() == ShapeView(out_shape_)) << \"shape mismatch\";\n    BatchReaderRequest request;\n    batch_reader_->WaitCompleted(&request);\n    if (ctx->stream()->device_type() == DeviceType::kCPU) {\n      std::memcpy(tensor->mut_dptr<char>(), request.buffer, local_batch_size_bytes_);\n    } else if (ctx->stream()->device_type() == DeviceType::kCUDA) {\n#ifdef WITH_CUDA\n      OF_CUDA_CHECK(cudaMemcpyAsync(tensor->mut_dptr<char>(), request.buffer,\n                                    local_batch_size_bytes_, cudaMemcpyDefault,\n                                    ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n#else\n      UNIMPLEMENTED();\n#endif\n    } else {\n      UNIMPLEMENTED() << \"only support CPU or CUDA\";\n    }\n    CHECK_JUST(ctx->stream()->Sync());\n    CHECK(request.blocks) << \"blocks is NULL\";\n    CHECK_EQ(request.blocks->size(), num_blocks_per_local_batch_) << \"blocks size mismatch\";\n    batch_generator_->Next(request.blocks->data());\n    batch_reader_->SubmitRequest(std::move(request));\n  }\n\n private:\n  size_t instance_size_;\n  size_t batch_size_;\n  size_t local_batch_size_;\n  size_t num_batches_;\n  size_t num_shards_;\n  size_t shard_index_;\n  size_t block_size_;\n  size_t block_size_bytes_;\n  size_t num_blocks_per_local_batch_;\n  size_t local_batch_size_bytes_;\n  bool random_shuffle_;\n  Shape out_shape_;\n  DataType data_type_;\n  std::unique_ptr<BatchGenerator> batch_generator_;\n  std::unique_ptr<BatchReader> batch_reader_;\n  DeviceType device_type_;\n  size_t prefetching_qd_;\n};\n\n}  // namespace\n\nclass RawReaderKernel final : public user_op::OpKernel {\n public:\n  RawReaderKernel() = default;\n  ~RawReaderKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    std::shared_ptr<RawReaderKernelState> state(new RawReaderKernelState(ctx));\n    return state;\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    auto* reader = CHECK_NOTNULL(dynamic_cast<RawReaderKernelState*>(state));\n    reader->Next(ctx);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"raw_reader\").SetCreateFn<RawReaderKernel>();\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/reduce_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n\n#ifdef WITH_CUDA\n#include \"oneflow/core/ep/cuda/cuda_device.h\"\n#endif  // WITH_CUDA\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) {\n  return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N;\n}\n\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitive(DeviceType device_type,\n                                                          DataType data_type, bool transpose_a,\n                                                          bool transpose_b) {\n  const auto trans_a = GetBlasTransposeType(transpose_a);\n  const auto trans_b = GetBlasTransposeType(transpose_b);\n  return ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(device_type, data_type, trans_a,\n                                                                   trans_b);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewReduceMatmulTransAPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"input_tensor\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true,\n                            /*transpose_b=*/false);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewReduceMatmulNoTransAPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"input_tensor\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false,\n                            /*transpose_b=*/false);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Fill> NewFillPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"output_tensor\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->device_type(), data_type);\n}\n\nauto ReduceMatmulTransAPrimitiveExists() {\n  return hob::make_custom(\"ReduceMatmulTransAPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewReduceMatmulTransAPrimitive(&ctx).operator bool();\n                          });\n}\n\nauto ReduceMatmulNoTransAPrimitiveExists() {\n  return hob::make_custom(\"ReduceMatmulNoTransAPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewReduceMatmulNoTransAPrimitive(&ctx).operator bool();\n           
               });\n}\n\nauto FillPrimitiveExists() {\n  return hob::make_custom(\"FillPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewFillPrimitive(&ctx).operator bool();\n  });\n}\n\ntemplate<template<typename> class BinaryFunc, DeviceType device_type, typename T, typename K>\nclass ReduceKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ReduceKernel() = default;\n  ~ReduceKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex(\"input_tensor\", 0);\n    user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex(\"output_tensor\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const auto& axis = ctx->Attr<std::vector<int32_t>>(\"axis\");\n    const int32_t output_elem_cnt = output_tensor->shape_view().elem_cnt();\n\n    if (input_tensor->shape_view().elem_cnt() == 0) {\n      if (output_tensor->shape_view().elem_cnt() != 0) {\n        Scalar init_value = [&]() {\n          if (std::is_same<BinaryFunc<T>, BinaryFuncAny<T>>::value) { return Scalar(0); }\n          if (std::is_same<BinaryFunc<T>, BinaryFuncAll<T>>::value) { return Scalar(1); }\n          return Scalar(0);\n        }();\n        CHECK_GE(output_elem_cnt, 0);\n        if (output_elem_cnt == 0) { return; }\n        std::unique_ptr<ep::primitive::Fill> fill = NewFillPrimitive(ctx);\n        CHECK(fill);\n        fill->Launch(ctx->stream(), output_tensor->mut_dptr<K>(), init_value, output_elem_cnt);\n      }\n      return;\n    }\n    const Shape& reduced_shape =\n        CreateReducedShape(input_tensor->shape_view(), {axis.begin(), axis.end()});\n    NdarrayReduce<device_type, T, BinaryFunc>::Reduce(\n        ctx->stream(), XpuVarNdarray<K>(reduced_shape, output_tensor->mut_dptr<K>()),\n        XpuVarNdarray<const T>(input_tensor->shape_view(), input_tensor->dptr<T>()),\n        XpuVarNdarray<T>(tmp_buffer->shape_view(), tmp_buffer->mut_dptr<T>()));\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n#define REGISTER_REDUCE_XPU_KERNEL(op_name, binary_func, device, dtype)                            \\\n  REGISTER_USER_KERNEL(op_name)                                                                    \\\n      .SetCreateFn<ReduceKernel<binary_func, device, dtype, dtype>>()                              \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                        \\\n                       && (user_op::HobDataType(\"output_tensor\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                          \\\n        const Shape& in_shape = ctx->InputShape(\"input_tensor\", 0);                                \\\n        return in_shape.elem_cnt() * sizeof(dtype);                                                \\\n      });\n\n#define REGISTER_REDUCE_LOGICAL_XPU_KERNEL(op_name, binary_func, device, dtype)                  \\\n  REGISTER_USER_KERNEL(op_name)                                                                  \\\n      .SetCreateFn<ReduceKernel<binary_func, device, dtype, bool>>()                             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                      \\\n                       && (user_op::HobDataType(\"input_tensor\", 0) == GetDataType<dtype>::value) \\\n                       
&& (user_op::HobDataType(\"output_tensor\", 0) == DataType::kBool)          \\\n                       && FillPrimitiveExists())                                                 \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                        \\\n        const Shape& in_shape = ctx->InputShape(\"input_tensor\", 0);                              \\\n        return in_shape.elem_cnt() * sizeof(dtype);                                              \\\n      });\n\n#define REGISTER_REDUCE_ARITHMETIC_KERNELS(device, dtype)                  \\\n  REGISTER_REDUCE_XPU_KERNEL(\"reduce_prod\", BinaryFuncProd, device, dtype) \\\n  REGISTER_REDUCE_XPU_KERNEL(\"reduce_min\", BinaryFuncMin, device, dtype)   \\\n  REGISTER_REDUCE_XPU_KERNEL(\"reduce_max\", BinaryFuncMax, device, dtype)\n\n#define REGISTER_REDUCE_NANSUM_KERNELS(device, dtype) \\\n  REGISTER_REDUCE_XPU_KERNEL(\"reduce_nansum\", BinaryFuncNanSum, device, dtype)\n\n#define REGISTER_REDUCE_ARITHMETIC_KERNELS_BY_DEVICE(device) \\\n  REGISTER_REDUCE_ARITHMETIC_KERNELS(device, bool)           \\\n  REGISTER_REDUCE_ARITHMETIC_KERNELS(device, float)          \\\n  REGISTER_REDUCE_ARITHMETIC_KERNELS(device, double)         \\\n  REGISTER_REDUCE_ARITHMETIC_KERNELS(device, int8_t)         \\\n  REGISTER_REDUCE_ARITHMETIC_KERNELS(device, uint8_t)        \\\n  REGISTER_REDUCE_ARITHMETIC_KERNELS(device, int32_t)        \\\n  REGISTER_REDUCE_ARITHMETIC_KERNELS(device, int64_t)\n\n#define REGISTER_REDUCE_NANSUM_KERNELS_BY_DEVICE(device) \\\n  REGISTER_REDUCE_NANSUM_KERNELS(device, float)          \\\n  REGISTER_REDUCE_NANSUM_KERNELS(device, double)\n\nREGISTER_REDUCE_ARITHMETIC_KERNELS_BY_DEVICE(DeviceType::kCPU)\nREGISTER_REDUCE_NANSUM_KERNELS_BY_DEVICE(DeviceType::kCPU)\n#ifdef WITH_CUDA\nREGISTER_REDUCE_ARITHMETIC_KERNELS_BY_DEVICE(DeviceType::kCUDA)\nREGISTER_REDUCE_NANSUM_KERNELS_BY_DEVICE(DeviceType::kCUDA)\n#endif\n\n#define REGISTER_REDUCE_SUM_KERNELS(device, dtype) \\\n  REGISTER_REDUCE_XPU_KERNEL(\"reduce_sum\", BinaryFuncSum, device, dtype)\n\n#define REGISTER_REDUCE_SUM_KERNELS_BY_DEVICE(device) \\\n  REGISTER_REDUCE_SUM_KERNELS(device, double)         \\\n  REGISTER_REDUCE_SUM_KERNELS(device, int8_t)         \\\n  REGISTER_REDUCE_SUM_KERNELS(device, uint8_t)        \\\n  REGISTER_REDUCE_SUM_KERNELS(device, int32_t)        \\\n  REGISTER_REDUCE_SUM_KERNELS(device, int64_t)\n\nREGISTER_REDUCE_SUM_KERNELS(DeviceType::kCPU, std::complex<float>)\nREGISTER_REDUCE_SUM_KERNELS(DeviceType::kCPU, std::complex<double>)\n#ifdef WITH_CUDA\nREGISTER_REDUCE_SUM_KERNELS(DeviceType::kCUDA, cuComplex)\nREGISTER_REDUCE_SUM_KERNELS(DeviceType::kCUDA, cuDoubleComplex)\n#endif\n\nREGISTER_REDUCE_SUM_KERNELS_BY_DEVICE(DeviceType::kCPU)\n#ifdef WITH_CUDA\nREGISTER_REDUCE_SUM_KERNELS_BY_DEVICE(DeviceType::kCUDA)\n#endif\nREGISTER_REDUCE_SUM_KERNELS(DeviceType::kCPU, float)\nREGISTER_REDUCE_SUM_KERNELS(DeviceType::kCPU, float16)\n\n#define REGISTER_REDUCE_LOGICAL_KERNELS(device)                                    \\\n  REGISTER_REDUCE_LOGICAL_XPU_KERNEL(\"reduce_any\", BinaryFuncAny, device, bool)    \\\n  REGISTER_REDUCE_LOGICAL_XPU_KERNEL(\"reduce_all\", BinaryFuncAll, device, bool)    \\\n  REGISTER_REDUCE_LOGICAL_XPU_KERNEL(\"reduce_any\", BinaryFuncAny, device, float)   \\\n  REGISTER_REDUCE_LOGICAL_XPU_KERNEL(\"reduce_all\", BinaryFuncAll, device, float)   \\\n  REGISTER_REDUCE_LOGICAL_XPU_KERNEL(\"reduce_any\", BinaryFuncAny, device, double)  \\\n  REGISTER_REDUCE_LOGICAL_XPU_KERNEL(\"reduce_all\", BinaryFuncAll, device, double) 
 \\\n  REGISTER_REDUCE_LOGICAL_XPU_KERNEL(\"reduce_any\", BinaryFuncAny, device, int8_t)  \\\n  REGISTER_REDUCE_LOGICAL_XPU_KERNEL(\"reduce_all\", BinaryFuncAll, device, int8_t)  \\\n  REGISTER_REDUCE_LOGICAL_XPU_KERNEL(\"reduce_any\", BinaryFuncAny, device, uint8_t) \\\n  REGISTER_REDUCE_LOGICAL_XPU_KERNEL(\"reduce_all\", BinaryFuncAll, device, uint8_t) \\\n  REGISTER_REDUCE_LOGICAL_XPU_KERNEL(\"reduce_any\", BinaryFuncAny, device, int32_t) \\\n  REGISTER_REDUCE_LOGICAL_XPU_KERNEL(\"reduce_all\", BinaryFuncAll, device, int32_t) \\\n  REGISTER_REDUCE_LOGICAL_XPU_KERNEL(\"reduce_any\", BinaryFuncAny, device, int64_t) \\\n  REGISTER_REDUCE_LOGICAL_XPU_KERNEL(\"reduce_all\", BinaryFuncAll, device, int64_t)\n\nREGISTER_REDUCE_LOGICAL_KERNELS(DeviceType::kCPU)\n#ifdef WITH_CUDA\nREGISTER_REDUCE_LOGICAL_KERNELS(DeviceType::kCUDA)\n\nnamespace {\n\nstd::vector<int32_t> RegularAxis(const std::vector<int32_t>& axis) {\n  std::vector<int32_t> regular_axis = axis;\n  std::sort(regular_axis.begin(), regular_axis.end());\n  return regular_axis;\n}\n\nvoid GetReduceSumLayout(const std::vector<int32_t>& axis, const ShapeView& in_shape,\n                        bool* is_axis_contiguous, int64_t* outer_size, int64_t* inner_size,\n                        int64_t* reduce_size) {\n  if (!axis.empty()) {\n    *is_axis_contiguous = ((axis.back() - axis.front() + 1) == axis.size());\n    *outer_size = in_shape.Count(0, axis.front());\n    *inner_size = in_shape.Count(axis.back() + 1);\n    *reduce_size = in_shape.Count(axis.front(), axis.back() + 1);\n  }\n}\n\n}  // namespace\n\nclass ReduceSumHalfKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ReduceSumHalfKernel() = default;\n  ~ReduceSumHalfKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    std::vector<int32_t> axis = RegularAxis(ctx->Attr<std::vector<int32_t>>(\"axis\"));\n    const user_op::Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex(\"input_tensor\", 0);\n    user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex(\"output_tensor\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const ShapeView& in_shape = input_tensor->shape_view();\n    const DataType data_type = input_tensor->data_type();\n    bool is_axis_contiguous = false;\n    int64_t outer_size = 0, inner_size = 0, reduce_size = 0;\n    GetReduceSumLayout(axis, in_shape, &is_axis_contiguous, &outer_size, &inner_size, &reduce_size);\n    if (is_axis_contiguous && (outer_size == 1 || inner_size == 1)) {\n      bool trans_a = (inner_size != 1);\n      const int32_t m = (inner_size == 1) ? 
outer_size : inner_size;\n      const int32_t n = 1;\n      const int32_t k = reduce_size;\n      const void* ones = nullptr;\n      auto* cuda_device = dynamic_cast<ep::CudaDevice*>(ctx->stream()->device());\n      if (cuda_device != nullptr) { ones = cuda_device->GetConstOnes(data_type, reduce_size); }\n      if (ones == nullptr) {\n        std::unique_ptr<ep::primitive::Fill> fill =\n            ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->stream()->device_type(),\n                                                                    data_type);\n        CHECK(fill);\n        fill->Launch(ctx->stream(), tmp_buffer->mut_dptr(), 1.0, reduce_size);\n        ones = tmp_buffer->dptr();\n      }\n      std::unique_ptr<ep::primitive::Matmul> matmul;\n      if (trans_a) {\n        matmul = NewReduceMatmulTransAPrimitive(ctx);\n      } else {\n        matmul = NewReduceMatmulNoTransAPrimitive(ctx);\n      }\n      matmul->Launch(ctx->stream(), m, n, k, 1.0, input_tensor->dptr(), ones, 0.0,\n                     output_tensor->mut_dptr());\n    } else {\n      const Shape& reduced_shape = CreateReducedShape(in_shape, {axis.begin(), axis.end()});\n      float* in_tmp_buffer = tmp_buffer->mut_dptr<float>();\n      const size_t in_tmp_buffer_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float));\n      float* out_tmp_buffer =\n          reinterpret_cast<float*>(tmp_buffer->mut_dptr<char>() + in_tmp_buffer_bytes);\n      const size_t out_tmp_buffer_bytes =\n          GetCudaAlignedSize(reduced_shape.elem_cnt() * sizeof(float));\n      float* reduce_tmp_buffer = reinterpret_cast<float*>(\n          tmp_buffer->mut_dptr<char>() + in_tmp_buffer_bytes + out_tmp_buffer_bytes);\n      const size_t reduce_tmp_buffer_bytes =\n          GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float));\n      CHECK_LE(in_tmp_buffer_bytes + out_tmp_buffer_bytes + reduce_tmp_buffer_bytes,\n               tmp_buffer->shape_view().elem_cnt());\n      auto h2f = ep::primitive::NewPrimitive<ep::primitive::CastFactory>(\n          ctx->device_type(), data_type, DataType::kFloat);\n      CHECK(h2f);\n      auto f2h = ep::primitive::NewPrimitive<ep::primitive::CastFactory>(\n          ctx->device_type(), DataType::kFloat, data_type);\n      CHECK(f2h);\n      h2f->Launch(ctx->stream(), input_tensor->dptr(), in_tmp_buffer, in_shape.elem_cnt());\n\n      NdarrayReduce<DeviceType::kCUDA, float, BinaryFuncSum>::Reduce(\n          ctx->stream(), XpuVarNdarray<float>(reduced_shape, out_tmp_buffer),\n          XpuVarNdarray<const float>(in_shape, in_tmp_buffer),\n          XpuVarNdarray<float>(in_shape, reduce_tmp_buffer));\n\n      f2h->Launch(ctx->stream(), out_tmp_buffer, output_tensor->mut_dptr(),\n                  output_tensor->shape_view().elem_cnt());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_REDUCE_SUM_HALF_KERNEL(dtype)                                                    \\\n  REGISTER_USER_KERNEL(\"reduce_sum\")                                                              \\\n      .SetCreateFn<ReduceSumHalfKernel>()                                                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                            \\\n                       && (user_op::HobDataType(\"output_tensor\", 0) == GetDataType<dtype>::value) \\\n                       && ReduceMatmulTransAPrimitiveExists()                                     \\\n                       && 
ReduceMatmulNoTransAPrimitiveExists())                                  \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                         \\\n        const Shape& in_shape = ctx->InputTensorDesc(\"input_tensor\", 0).shape();                  \\\n        const Shape& out_shape = ctx->OutputTensorDesc(\"output_tensor\", 0).shape();               \\\n        const auto& axis = RegularAxis(ctx->Attr<std::vector<int32_t>>(\"axis\"));                  \\\n        bool is_axis_contiguous = false;                                                          \\\n        int64_t outer_size = 0, inner_size = 0, reduce_size = 0;                                  \\\n        GetReduceSumLayout(axis, ShapeView(in_shape), &is_axis_contiguous, &outer_size,           \\\n                           &inner_size, &reduce_size);                                            \\\n        size_t tmp_bytes = 0;                                                                     \\\n        if (is_axis_contiguous && (outer_size == 1 || inner_size == 1)) {                         \\\n          tmp_bytes = GetCudaAlignedSize(reduce_size * sizeof(dtype));                            \\\n        } else {                                                                                  \\\n          tmp_bytes = (2 * GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float))                \\\n                       + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(float)));               \\\n        }                                                                                         \\\n        return tmp_bytes;                                                                         \\\n      });\n\nREGISTER_REDUCE_SUM_HALF_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_REDUCE_SUM_HALF_KERNEL(nv_bfloat16)\n#endif\n\nclass ReduceSumFloatCudaKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ReduceSumFloatCudaKernel() = default;\n  ~ReduceSumFloatCudaKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    std::vector<int32_t> axis = RegularAxis(ctx->Attr<std::vector<int32_t>>(\"axis\"));\n    const user_op::Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex(\"input_tensor\", 0);\n    user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex(\"output_tensor\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const ShapeView& in_shape = input_tensor->shape_view();\n    if (input_tensor->shape_view().elem_cnt() == 0) {\n      if (output_tensor->shape_view().elem_cnt() != 0) {\n        Memset<DeviceType::kCUDA>(\n            ctx->stream(), output_tensor->mut_dptr<float>(), 0,\n            output_tensor->shape_view().elem_cnt() * GetSizeOfDataType(output_tensor->data_type()));\n      }\n      return;\n    }\n    bool is_axis_contiguous = false;\n    int64_t outer_size = 0, inner_size = 0, reduce_size = 0;\n    GetReduceSumLayout(axis, in_shape, &is_axis_contiguous, &outer_size, &inner_size, &reduce_size);\n    const float* ones = nullptr;\n    auto* cuda_device = dynamic_cast<ep::CudaDevice*>(ctx->stream()->device());\n    if (cuda_device != nullptr) {\n      ones = static_cast<const float*>(cuda_device->GetConstOnes(DataType::kFloat, reduce_size));\n    }\n    if ((!axis.empty()) && in_shape.NumAxes() > 0 && is_axis_contiguous\n        && (outer_size == 1 || inner_size == 1) && ones != nullptr\n        && 
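// the cuBLAS-backed matmul path is opt-in via this environment variable\n        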
ParseBooleanFromEnv(\"ONEFLOW_KERNEL_REDUCE_SUM_USE_MATMUL\", false)) {\n      ep::primitive::BlasTransposeType trans_a = (inner_size == 1)\n                                                     ? ep::primitive::BlasTransposeType::N\n                                                     : ep::primitive::BlasTransposeType::T;\n      ep::primitive::BlasTransposeType trans_b = ep::primitive::BlasTransposeType::N;\n      const int32_t m = (inner_size == 1) ? outer_size : inner_size;\n      const int32_t n = 1;\n      const int32_t k = reduce_size;\n#if CUDA_VERSION >= 11000\n      CublasMathModeGuard guard(ctx->stream()->As<ep::CudaStream>()->cublas_handle());\n      // disable tf32\n      guard.SetMathMode(CUBLAS_DEFAULT_MATH);\n#endif  // defined(WITH_CUDA) && CUDA_VERSION >= 11000\n      auto matmul = ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(\n          DeviceType::kCUDA, DataType::kFloat, trans_a, trans_b);\n      CHECK(matmul);\n      matmul->Launch(ctx->stream(), m, n, k, 1.0, input_tensor->dptr(), ones, 0.0,\n                     output_tensor->mut_dptr());\n    } else {\n      const Shape& reduced_shape = CreateReducedShape(in_shape, {axis.begin(), axis.end()});\n      NdarrayReduce<DeviceType::kCUDA, float, BinaryFuncSum>::Reduce(\n          ctx->stream(), XpuVarNdarray<float>(reduced_shape, output_tensor->mut_dptr<float>()),\n          XpuVarNdarray<const float>(input_tensor->shape_view(), input_tensor->dptr<float>()),\n          XpuVarNdarray<float>(tmp_buffer->shape_view(), tmp_buffer->mut_dptr<float>()));\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"reduce_sum\")\n    .SetCreateFn<ReduceSumFloatCudaKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)\n                     && (user_op::HobDataType(\"output_tensor\", 0) == DataType::kFloat))\n    .SetInferTmpSizeFn([](user_op::InferContext* ctx) {\n      const Shape& in_shape = ctx->InputTensorDesc(\"input_tensor\", 0).shape();\n      return GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float));\n    });\n\n#endif  // WITH_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/reduce_like_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) {\n  return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N;\n}\n\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitive(DeviceType device_type,\n                                                          DataType data_type, bool transpose_a,\n                                                          bool transpose_b) {\n  const auto trans_a = GetBlasTransposeType(transpose_a);\n  const auto trans_b = GetBlasTransposeType(transpose_b);\n  return ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(device_type, data_type, trans_a,\n                                                                   trans_b);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewReduceMatmulTransAPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/true,\n                            /*transpose_b=*/false);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewReduceMatmulNoTransAPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, /*transpose_a=*/false,\n                            /*transpose_b=*/false);\n}\n\nauto ReduceMatmulTransAPrimitiveExists() {\n  return hob::make_custom(\"ReduceMatmulTransAPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewReduceMatmulTransAPrimitive(&ctx).operator bool();\n                          });\n}\n\nauto ReduceMatmulNoTransAPrimitiveExists() {\n  return hob::make_custom(\"ReduceMatmulNoTransAPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewReduceMatmulNoTransAPrimitive(&ctx).operator bool();\n                          });\n}\n\nsize_t ReduceSumLikeInferTmpSize(user_op::InferContext* ctx) {\n  if (ctx->Attr<std::vector<int32_t>>(\"axis\").empty()) { return 0; }\n  const user_op::TensorDesc& tensor_desc_x = ctx->InputTensorDesc(\"x\", 0);\n  return tensor_desc_x.shape().elem_cnt() * GetSizeOfDataType(tensor_desc_x.data_type());\n}\n\n}  // namespace\n\ntemplate<DeviceType device_type, typename T>\nclass ReduceSumLikeOpKernel final : public 
user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ReduceSumLikeOpKernel() = default;\n  ~ReduceSumLikeOpKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const auto& axis = ctx->Attr<std::vector<int32_t>>(\"axis\");\n    if (tensor_x->shape_view().elem_cnt() == 0) {\n      if (tensor_y->shape_view().elem_cnt() != 0) {\n        Memset<device_type>(\n            ctx->stream(), tensor_y->mut_dptr<T>(), 0,\n            tensor_y->shape_view().elem_cnt() * GetSizeOfDataType(tensor_y->data_type()));\n      }\n      return;\n    }\n    if (axis.empty()) {\n      CHECK_EQ(tensor_x->shape_view(), tensor_y->shape_view());\n      Memcpy<device_type>(\n          ctx->stream(), tensor_y->mut_dptr(), tensor_x->dptr(),\n          tensor_x->shape_view().elem_cnt() * GetSizeOfDataType(tensor_x->data_type()));\n    } else {\n      user_op::Tensor* tensor_tmp = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n      T* temp_storage = static_cast<T*>(tensor_tmp->mut_dptr());\n      NdarrayUtil<device_type, T>::ReduceSum(\n          ctx->stream(),\n          XpuVarNdarray<T>(CreateReducedShape(tensor_x->shape_view(), {axis.begin(), axis.end()}),\n                           tensor_y->mut_dptr<T>()),\n          XpuVarNdarray<const T>(tensor_x->shape_view(), tensor_x->dptr<T>(),\n                                 tensor_x->shape_view().NumAxes()),\n          XpuVarNdarray<T>(tensor_x->shape_view(), temp_storage, tensor_x->shape_view().NumAxes()));\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_REDUCE_SUM_LIKE_KERNEL(device, data_type_pair)                                \\\n  REGISTER_USER_KERNEL(\"reduce_sum_like\")                                                      \\\n      .SetCreateFn<ReduceSumLikeOpKernel<device, OF_PP_PAIR_FIRST(data_type_pair)>>()          \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                    \\\n                       && (user_op::HobDataType(\"y\", 0) == OF_PP_PAIR_SECOND(data_type_pair))) \\\n      .SetInferTmpSizeFn(ReduceSumLikeInferTmpSize);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL, DEVICE_TYPE_SEQ,\n                                 ARITHMETIC_DATA_TYPE_SEQ)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL,\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), COMPLEX_DATA_TYPE_SEQ);\n#if defined(WITH_CUDA)\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL,\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                                 OF_PP_MAKE_TUPLE_SEQ(cuComplex, DataType::kComplex64));\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL,\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                                 OF_PP_MAKE_TUPLE_SEQ(cuDoubleComplex, DataType::kComplex128));\n#endif  // WITH_CUDA\n\n#if defined(WITH_CUDA)\n\nnamespace {\n\nstd::vector<int32_t> RegularAxis(const std::vector<int32_t>& axis) {\n  std::vector<int32_t> regular_axis = axis;\n  std::sort(regular_axis.begin(), regular_axis.end());\n  return regular_axis;\n}\n\nvoid GetReduceSumLayout(const std::vector<int32_t>& axis, const ShapeView& in_shape,\n                        bool* is_axis_contiguous, int64_t* outer_size, int64_t* 
inner_size,\n                        int64_t* reduce_size) {\n  *is_axis_contiguous = ((axis.back() - axis.front() + 1) == axis.size());\n  *outer_size = in_shape.Count(0, axis.front());\n  *inner_size = in_shape.Count(axis.back() + 1);\n  *reduce_size = in_shape.Count(axis.front(), axis.back() + 1);\n}\n\n}  // namespace\n\nclass ReduceSumLikeHalfKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ReduceSumLikeHalfKernel() = default;\n  ~ReduceSumLikeHalfKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    std::vector<int32_t> axis = RegularAxis(ctx->Attr<std::vector<int32_t>>(\"axis\"));\n    const user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    if (axis.empty()) {\n      CHECK_EQ(tensor_x->shape_view(), tensor_y->shape_view());\n      Memcpy<DeviceType::kCUDA>(\n          ctx->stream(), tensor_y->mut_dptr(), tensor_x->dptr(),\n          tensor_x->shape_view().elem_cnt() * GetSizeOfDataType(tensor_x->data_type()));\n    } else {\n      user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n      const ShapeView& in_shape = tensor_x->shape_view();\n      bool is_axis_contiguous = false;\n      int64_t outer_size = 0, inner_size = 0, reduce_size = 0;\n      GetReduceSumLayout(axis, in_shape, &is_axis_contiguous, &outer_size, &inner_size,\n                         &reduce_size);\n      if (is_axis_contiguous && (outer_size == 1 || inner_size == 1)) {\n        bool trans_a = (inner_size != 1);\n        const int32_t m = (inner_size == 1) ? outer_size : inner_size;\n        const int32_t n = 1;\n        const int32_t k = reduce_size;\n        std::unique_ptr<ep::primitive::Fill> fill =\n            ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->stream()->device_type(),\n                                                                    tensor_x->data_type());\n        CHECK(fill);\n        fill->Launch(ctx->stream(), tmp_buffer->mut_dptr(), 1.0, reduce_size);\n\n        std::unique_ptr<ep::primitive::Matmul> matmul;\n        if (trans_a) {\n          matmul = NewReduceMatmulTransAPrimitive(ctx);\n        } else {\n          matmul = NewReduceMatmulNoTransAPrimitive(ctx);\n        }\n        CHECK(matmul);\n        matmul->Launch(ctx->stream(), m, n, k, 1.0, tensor_x->dptr(), tmp_buffer->dptr(), 0.0,\n                       tensor_y->mut_dptr());\n\n      } else {\n        const Shape& reduced_shape = CreateReducedShape(in_shape, {axis.begin(), axis.end()});\n        float* in_tmp_buffer = tmp_buffer->mut_dptr<float>();\n        const size_t in_tmp_buffer_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float));\n        float* out_tmp_buffer =\n            reinterpret_cast<float*>(tmp_buffer->mut_dptr<char>() + in_tmp_buffer_bytes);\n        const size_t out_tmp_buffer_bytes =\n            GetCudaAlignedSize(reduced_shape.elem_cnt() * sizeof(float));\n        float* reduce_tmp_buffer = reinterpret_cast<float*>(\n            tmp_buffer->mut_dptr<char>() + in_tmp_buffer_bytes + out_tmp_buffer_bytes);\n        const size_t reduce_tmp_buffer_bytes =\n            GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float));\n        CHECK_LE(in_tmp_buffer_bytes + out_tmp_buffer_bytes + reduce_tmp_buffer_bytes,\n                 tmp_buffer->shape_view().elem_cnt());\n        auto h2f = ep::primitive::NewPrimitive<ep::primitive::CastFactory>(\n            ctx->device_type(), 
tensor_x->data_type(), DataType::kFloat);\n        CHECK(h2f);\n        auto f2h = ep::primitive::NewPrimitive<ep::primitive::CastFactory>(\n            ctx->device_type(), DataType::kFloat, tensor_x->data_type());\n        CHECK(f2h);\n        h2f->Launch(ctx->stream(), tensor_x->dptr(), in_tmp_buffer, in_shape.elem_cnt());\n\n        NdarrayReduce<DeviceType::kCUDA, float, BinaryFuncSum>::Reduce(\n            ctx->stream(), XpuVarNdarray<float>(reduced_shape, out_tmp_buffer),\n            XpuVarNdarray<const float>(in_shape, in_tmp_buffer),\n            XpuVarNdarray<float>(in_shape, reduce_tmp_buffer));\n\n        f2h->Launch(ctx->stream(), out_tmp_buffer, tensor_y->mut_dptr(),\n                    tensor_y->shape_view().elem_cnt());\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_REDUCE_SUM_LIKE_HALF_KERNEL(dtype)                                     \\\n  REGISTER_USER_KERNEL(\"reduce_sum_like\")                                               \\\n      .SetCreateFn<ReduceSumLikeHalfKernel>()                                           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)   \\\n                       && ReduceMatmulTransAPrimitiveExists()                           \\\n                       && ReduceMatmulNoTransAPrimitiveExists())                        \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                               \\\n        const Shape& in_shape = ctx->InputTensorDesc(\"x\", 0).shape();                   \\\n        const Shape& out_shape = ctx->OutputTensorDesc(\"y\", 0).shape();                 \\\n        const auto& axis = RegularAxis(ctx->Attr<std::vector<int32_t>>(\"axis\"));        \\\n        if (axis.empty()) {                                                             \\\n          size_t tmp_bytes = 0;                                                         \\\n          return tmp_bytes;                                                             \\\n        }                                                                               \\\n        bool is_axis_contiguous = false;                                                \\\n        int64_t outer_size = 0, inner_size = 0, reduce_size = 0;                        \\\n        GetReduceSumLayout(axis, ShapeView(in_shape), &is_axis_contiguous, &outer_size, \\\n                           &inner_size, &reduce_size);                                  \\\n        size_t tmp_bytes = 0;                                                           \\\n        if (is_axis_contiguous && (outer_size == 1 || inner_size == 1)) {               \\\n          tmp_bytes = GetCudaAlignedSize(reduce_size * sizeof(dtype));                  \\\n        } else {                                                                        \\\n          tmp_bytes = (2 * GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(float))      \\\n                       + GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(float)));     \\\n        }                                                                               \\\n        return tmp_bytes;                                                               \\\n      });\nREGISTER_REDUCE_SUM_LIKE_HALF_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_REDUCE_SUM_LIKE_HALF_KERNEL(nv_bfloat16)\n#endif\n\n#endif\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/reflection_pad_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/user/kernels/reflection_pad_kernels_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\ntemplate<DeviceType device_type, typename IN_T>\nclass ReflectionPad1dKernel final : public OpKernel {\n public:\n  ReflectionPad1dKernel() = default;\n  ~ReflectionPad1dKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const auto& padding = ctx->Attr<std::vector<int64_t>>(\"padding\");\n    const int64_t ndims = x->shape_view().NumAxes();\n    CHECK_EQ(padding.size(), ndims - 1);\n    const int64_t n_idx = 0;\n    const int64_t c_idx = 1;\n    const int64_t w_idx = 2;\n\n    const int64_t pad_left = padding[0];\n\n    const int64_t n_batch = y->shape_view().At(n_idx);\n    const int64_t n_channel = y->shape_view().At(c_idx);\n    const int64_t y_width = y->shape_view().At(w_idx);\n    const int64_t x_width = x->shape_view().At(w_idx);\n\n    IN_T* dest = y->mut_dptr<IN_T>();\n    const IN_T* src = x->dptr<IN_T>();\n    DimVector y_vector;\n    y->shape_view().ToDimVector(&y_vector);\n    NdIndexOffsetHelper<int64_t, 3> index_helper(y_vector.data());\n\n    ReflectionPad1dFunctor<device_type, IN_T>()(ctx->stream(), src, dest, index_helper, n_batch,\n                                                n_channel, y_width, x_width, pad_left);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename IN_T>\nclass ReflectionPad1dGradKernel final : public OpKernel {\n public:\n  ReflectionPad1dGradKernel() = default;\n  ~ReflectionPad1dGradKernel() = default;\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    const Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const auto& padding = ctx->Attr<std::vector<int64_t>>(\"padding\");\n    const int64_t ndims = dy->shape_view().NumAxes();\n    CHECK_EQ(padding.size(), ndims - 1);\n\n    const int64_t n_idx = 0;\n    const int64_t c_idx = 1;\n    const int64_t w_idx = 2;\n\n    const int64_t pad_left = padding[0];\n    const int64_t n_batch = dy->shape_view().At(n_idx);\n    const int64_t n_channel = dy->shape_view().At(c_idx);\n    const int64_t dy_width = dy->shape_view().At(w_idx);\n    const int64_t dx_width = dx->shape_view().At(w_idx);\n\n    const IN_T* src = dy->dptr<IN_T>();\n    IN_T* dest = dx->mut_dptr<IN_T>();\n    DimVector dy_vector;\n    dy->shape_view().ToDimVector(&dy_vector);\n    NdIndexOffsetHelper<int64_t, 3> index_helper(dy_vector.data());\n\n    size_t out_bytes_size = dx->shape_view().elem_cnt() * 
GetSizeOfDataType(dx->data_type());\n    Memset<device_type>(ctx->stream(), dest, 0, out_bytes_size);\n\n    ReflectionPad1dGradFunctor<device_type, IN_T>()(ctx->stream(), src, dest, index_helper, n_batch,\n                                                    n_channel, dy_width, dx_width, pad_left);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename IN_T>\nclass ReflectionPad2dKernel final : public OpKernel {\n public:\n  ReflectionPad2dKernel() = default;\n  ~ReflectionPad2dKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const auto& padding = ctx->Attr<std::vector<int64_t>>(\"padding\");\n    const int64_t n_idx = 0;\n    const int64_t c_idx = 1;\n    const int64_t h_idx = 2;\n    const int64_t w_idx = 3;\n\n    const int64_t pad_left = padding[0];\n    const int64_t pad_top = padding[2];\n\n    const int64_t n_batch = y->shape_view().At(n_idx);\n    const int64_t n_channel = y->shape_view().At(c_idx);\n    const int64_t y_height = y->shape_view().At(h_idx);\n    const int64_t y_width = y->shape_view().At(w_idx);\n    const int64_t x_height = x->shape_view().At(h_idx);\n    const int64_t x_width = x->shape_view().At(w_idx);\n\n    IN_T* dest = y->mut_dptr<IN_T>();\n    const IN_T* src = x->dptr<IN_T>();\n    DimVector y_vector;\n    y->shape_view().ToDimVector(&y_vector);\n    NdIndexOffsetHelper<int64_t, 4> index_helper(y_vector.data());\n\n    ReflectionPad2dFunctor<device_type, IN_T>()(ctx->stream(), src, dest, index_helper, n_batch,\n                                                n_channel, y_height, y_width, x_height, x_width,\n                                                pad_left, pad_top);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename IN_T>\nclass ReflectionPad2dGradKernel final : public OpKernel {\n public:\n  ReflectionPad2dGradKernel() = default;\n  ~ReflectionPad2dGradKernel() = default;\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    const Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const auto& padding = ctx->Attr<std::vector<int64_t>>(\"padding\");\n\n    const int64_t n_idx = 0;\n    const int64_t c_idx = 1;\n    const int64_t h_idx = 2;\n    const int64_t w_idx = 3;\n\n    int64_t pad_left = padding[0];\n    int64_t pad_top = padding[2];\n    int64_t n_batch = dy->shape_view().At(n_idx);\n    int64_t n_channel = dy->shape_view().At(c_idx);\n    int64_t dy_height = dy->shape_view().At(h_idx);\n    int64_t dy_width = dy->shape_view().At(w_idx);\n    int64_t dx_height = dx->shape_view().At(h_idx);\n    int64_t dx_width = dx->shape_view().At(w_idx);\n\n    const IN_T* src = dy->dptr<IN_T>();\n    IN_T* dest = dx->mut_dptr<IN_T>();\n    DimVector dy_vector;\n    dy->shape_view().ToDimVector(&dy_vector);\n    NdIndexOffsetHelper<int64_t, 4> index_helper(dy_vector.data());\n\n    size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type());\n    Memset<device_type>(ctx->stream(), dest, 0, out_bytes_size);\n\n    ReflectionPad2dGradFunctor<device_type, IN_T>()(ctx->stream(), src, dest, index_helper, n_batch,\n                                                    n_channel, dy_height, dy_width, dx_height,\n              
                                      dx_width, pad_left, pad_top);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_REFLECTION_PAD_ND_KERNELS(device, dtype)                                \\\n  REGISTER_USER_KERNEL(\"reflection_pad1d\")                                               \\\n      .SetCreateFn<ReflectionPad1dKernel<device, dtype>>()                               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                              \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));  \\\n  REGISTER_USER_KERNEL(\"reflection_pad1d_grad\")                                          \\\n      .SetCreateFn<ReflectionPad1dGradKernel<device, dtype>>()                           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                              \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"reflection_pad2d\")                                               \\\n      .SetCreateFn<ReflectionPad2dKernel<device, dtype>>()                               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                              \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));  \\\n  REGISTER_USER_KERNEL(\"reflection_pad2d_grad\")                                          \\\n      .SetCreateFn<ReflectionPad2dGradKernel<device, dtype>>()                           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                              \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\n#define REGISTER_REFLECTION_PAD_ND_WITH_DEVICE(device) \\\n  REGISTER_REFLECTION_PAD_ND_KERNELS(device, float)    \\\n  REGISTER_REFLECTION_PAD_ND_KERNELS(device, double)   \\\n  REGISTER_REFLECTION_PAD_ND_KERNELS(device, int32_t)\n\nREGISTER_REFLECTION_PAD_ND_WITH_DEVICE(DeviceType::kCPU)\n#ifdef WITH_CUDA\nREGISTER_REFLECTION_PAD_ND_WITH_DEVICE(DeviceType::kCUDA)\nREGISTER_REFLECTION_PAD_ND_KERNELS(DeviceType::kCUDA, float16)\n#endif\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/reflection_pad_kernels_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/reflection_pad_kernels_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\ntemplate<typename IN_T>\nstruct ReflectionPad1dFunctor<DeviceType::kCPU, IN_T> final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t y_width, const int64_t x_width,\n                  const int64_t pad_left) {\n    const int64_t dest_num = n_channel * y_width;\n    const int64_t src_num = n_channel * x_width;\n    const int64_t elem_num = n_batch * dest_num;\n    DoReflectionPad1d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_width, x_width,\n                            pad_left);\n  }\n};\n\ntemplate<typename IN_T>\nstruct ReflectionPad1dGradFunctor<DeviceType::kCPU, IN_T> final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t dy_width, const int64_t dx_width,\n                  const int64_t pad_left) {\n    const int64_t dest_num = n_channel * dx_width;\n    const int64_t src_num = n_channel * dy_width;\n    const int64_t elem_num = n_batch * src_num;\n    DoReflectionPad1dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_width,\n                                dx_width, pad_left);\n  }\n};\n\ntemplate<typename IN_T>\nstruct ReflectionPad2dFunctor<DeviceType::kCPU, IN_T> final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t y_height, const int64_t y_width,\n                  const int64_t x_height, const int64_t x_width, const int64_t pad_left,\n                  const int64_t pad_top) {\n    const int64_t dest_num = n_channel * y_height * y_width;\n    const int64_t src_num = n_channel * x_height * x_width;\n    const int64_t elem_num = n_batch * dest_num;\n    DoReflectionPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width,\n                            x_height, x_width, pad_left, pad_top);\n  }\n};\n\ntemplate<typename IN_T>\nstruct ReflectionPad2dGradFunctor<DeviceType::kCPU, IN_T> final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t dy_height, const int64_t dy_width,\n                  const int64_t dx_height, const int64_t dx_width, const int64_t pad_left,\n                  const int64_t pad_top) {\n    const int64_t dest_num = n_channel * dx_height * 
dx_width;\n    const int64_t src_num = n_channel * dy_height * dy_width;\n    const int64_t elem_num = n_batch * src_num;\n    DoReflectionPad2dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_height,\n                                dy_width, dx_height, dx_width, pad_left, pad_top);\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD_FUNCTOR, (DeviceType::kCPU),\n                                 PADDING_DATA_TYPE_CPU_SEQ);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD_GRAD_FUNCTOR, (DeviceType::kCPU),\n                                 PADDING_DATA_TYPE_CPU_SEQ);\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/reflection_pad_kernels_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/reflection_pad_kernels_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\ntemplate<typename IN_T>\n__global__ void DoCUDAReflectionPad1d(const IN_T* src, IN_T* dest,\n                                      const NdIndexOffsetHelper<int64_t, 3> index_helper,\n                                      const int64_t elem_num, const int64_t src_num,\n                                      const int64_t dest_num, const int64_t y_width,\n                                      const int64_t x_width, const int64_t pad_left) {\n  DoReflectionPad1d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_width, x_width,\n                          pad_left);\n};\n\ntemplate<typename IN_T>\n__global__ void DoCUDAReflectionPad1dGrad(const IN_T* src, IN_T* dest,\n                                          const NdIndexOffsetHelper<int64_t, 3> index_helper,\n                                          const int64_t elem_num, const int64_t src_num,\n                                          const int64_t dest_num, const int64_t dy_width,\n                                          const int64_t dx_width, const int64_t pad_left) {\n  DoReflectionPad1dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_width,\n                              dx_width, pad_left);\n};\n\ntemplate<typename IN_T>\n__global__ void DoCUDAReflectionPad2d(const IN_T* src, IN_T* dest,\n                                      const NdIndexOffsetHelper<int64_t, 4> index_helper,\n                                      const int64_t elem_num, const int64_t src_num,\n                                      const int64_t dest_num, const int64_t y_height,\n                                      const int64_t y_width, const int64_t x_height,\n                                      const int64_t x_width, const int64_t pad_left,\n                                      const int64_t pad_top) {\n  DoReflectionPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width,\n                          x_height, x_width, pad_left, pad_top);\n};\n\ntemplate<typename IN_T>\n__global__ void DoCUDAReflectionPad2dGrad(const IN_T* src, IN_T* dest,\n                                          const NdIndexOffsetHelper<int64_t, 4> index_helper,\n                                          const int64_t elem_num, const int64_t src_num,\n                                          const int64_t dest_num, const int64_t dy_height,\n                                          const int64_t dy_width, const int64_t dx_height,\n                                          const int64_t dx_width, const int64_t pad_left,\n                                          const int64_t pad_top) {\n  DoReflectionPad2dGrad<IN_T>(src, dest, index_helper, 
elem_num, src_num, dest_num, dy_height,\n                              dy_width, dx_height, dx_width, pad_left, pad_top);\n};\n\ntemplate<typename IN_T>\nstruct ReflectionPad1dFunctor<DeviceType::kCUDA, IN_T> final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t y_width, const int64_t x_width,\n                  const int64_t pad_left) {\n    const int64_t dest_num = n_channel * y_width;\n    const int64_t src_num = n_channel * x_width;\n    const int64_t elem_num = n_batch * dest_num;\n    DoCUDAReflectionPad1d<IN_T><<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0,\n                                  stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        src, dest, index_helper, elem_num, src_num, dest_num, y_width, x_width, pad_left);\n  }\n};\n\n// float16 implementation\ntemplate<>\nvoid ReflectionPad1dFunctor<DeviceType::kCUDA, float16>::operator()(\n    ep::Stream* stream, const float16* src, float16* dest,\n    const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,\n    const int64_t n_channel, const int64_t y_width, const int64_t x_width, const int64_t pad_left) {\n  const int64_t dest_num = n_channel * y_width;\n  const int64_t src_num = n_channel * x_width;\n  const int64_t elem_num = n_batch * dest_num;\n  DoCUDAReflectionPad1d<half><<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0,\n                                stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper, elem_num,\n      src_num, dest_num, y_width, x_width, pad_left);\n}\n\ntemplate<typename IN_T>\nstruct ReflectionPad1dGradFunctor<DeviceType::kCUDA, IN_T> final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t dy_width, const int64_t dx_width,\n                  const int64_t pad_left) {\n    const int64_t dest_num = n_channel * dx_width;\n    const int64_t src_num = n_channel * dy_width;\n    const int64_t elem_num = n_batch * src_num;\n    DoCUDAReflectionPad1dGrad<IN_T><<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0,\n                                      stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        src, dest, index_helper, elem_num, src_num, dest_num, dy_width, dx_width, pad_left);\n  }\n};\n\n// float16 implementation\ntemplate<>\nvoid ReflectionPad1dGradFunctor<DeviceType::kCUDA, float16>::operator()(\n    ep::Stream* stream, const float16* src, float16* dest,\n    const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,\n    const int64_t n_channel, const int64_t dy_width, const int64_t dx_width,\n    const int64_t pad_left) {\n  const int64_t dest_num = n_channel * dx_width;\n  const int64_t src_num = n_channel * dy_width;\n  const int64_t elem_num = n_batch * src_num;\n  DoCUDAReflectionPad1dGrad<half><<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0,\n                                    stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper, elem_num,\n      src_num, dest_num, dy_width, dx_width, pad_left);\n}\n\ntemplate<typename IN_T>\nstruct ReflectionPad2dFunctor<DeviceType::kCUDA, IN_T> final {\n  
void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t y_height, const int64_t y_width,\n                  const int64_t x_height, const int64_t x_width, const int64_t pad_left,\n                  const int64_t pad_top) {\n    const int64_t dest_num = n_channel * y_height * y_width;\n    const int64_t src_num = n_channel * x_height * x_width;\n    const int64_t elem_num = n_batch * dest_num;\n    DoCUDAReflectionPad2d<IN_T><<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0,\n                                  stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, x_height, x_width,\n        pad_left, pad_top);\n  }\n};\n\n// float16 implementation\ntemplate<>\nvoid ReflectionPad2dFunctor<DeviceType::kCUDA, float16>::operator()(\n    ep::Stream* stream, const float16* src, float16* dest,\n    const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,\n    const int64_t n_channel, const int64_t y_height, const int64_t y_width, const int64_t x_height,\n    const int64_t x_width, const int64_t pad_left, const int64_t pad_top) {\n  const int64_t dest_num = n_channel * y_height * y_width;\n  const int64_t src_num = n_channel * x_height * x_width;\n  const int64_t elem_num = n_batch * dest_num;\n  DoCUDAReflectionPad2d<half><<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0,\n                                stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper, elem_num,\n      src_num, dest_num, y_height, y_width, x_height, x_width, pad_left, pad_top);\n}\n\ntemplate<typename IN_T>\nstruct ReflectionPad2dGradFunctor<DeviceType::kCUDA, IN_T> final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t dy_height, const int64_t dy_width,\n                  const int64_t dx_height, const int64_t dx_width, const int64_t pad_left,\n                  const int64_t pad_top) {\n    const int64_t dest_num = n_channel * dx_height * dx_width;\n    const int64_t src_num = n_channel * dy_height * dy_width;\n    const int64_t elem_num = n_batch * src_num;\n    DoCUDAReflectionPad2dGrad<IN_T><<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0,\n                                      stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        src, dest, index_helper, elem_num, src_num, dest_num, dy_height, dy_width, dx_height,\n        dx_width, pad_left, pad_top);\n  }\n};\n\n// float16 implementation\ntemplate<>\nvoid ReflectionPad2dGradFunctor<DeviceType::kCUDA, float16>::operator()(\n    ep::Stream* stream, const float16* src, float16* dest,\n    const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,\n    const int64_t n_channel, const int64_t dy_height, const int64_t dy_width,\n    const int64_t dx_height, const int64_t dx_width, const int64_t pad_left,\n    const int64_t pad_top) {\n  const int64_t dest_num = n_channel * dx_height * dx_width;\n  const int64_t src_num = n_channel * dy_height * dy_width;\n  const int64_t elem_num = n_batch * src_num;\n  DoCUDAReflectionPad2dGrad<half><<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0,\n                     
               stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper, elem_num,\n      src_num, dest_num, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top);\n}\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD_FUNCTOR,\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                                 PADDING_DATA_TYPE_CUDA_SEQ);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD_GRAD_FUNCTOR,\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                                 PADDING_DATA_TYPE_CUDA_SEQ);\n\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
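For reference, the branchy index mapping that `DoReflectionPad1d` applies per output column can be exercised on the host. A minimal sketch (the helper name `ReflectIndex` and the sample sizes are illustrative, not part of the file above):

```cpp
// Standalone check of the 1-D reflection index arithmetic: mirror around the
// left edge, pass interior columns through, mirror around the right edge,
// then shift back into source coordinates.
#include <cstdint>
#include <cstdio>

int64_t ReflectIndex(int64_t j, int64_t x_width, int64_t pad_left) {
  int64_t ip_x;
  if (j < pad_left) {
    ip_x = pad_left * 2 - j;  // mirror around the left edge
  } else if (j < x_width + pad_left) {
    ip_x = j;  // interior: copy through
  } else {
    ip_x = (x_width + pad_left - 1) * 2 - j;  // mirror around the right edge
  }
  return ip_x - pad_left;  // back to source coordinates
}

int main() {
  // x_width = 4, pad_left = pad_right = 2 -> y_width = 8
  for (int64_t j = 0; j < 8; ++j) {
    std::printf("%lld -> %lld\n", (long long)j, (long long)ReflectIndex(j, 4, 2));
  }
  return 0;
}
```

For `src = [a, b, c, d]` with two pads on each side this yields `[c, b, a, b, c, d, c, b]`: the mirror pivots on the edge elements without repeating them.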
  {
    "path": "oneflow/user/kernels/reflection_pad_kernels_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_REFLECTION_PAD_KERNELS_UTIL_H_\n#define ONEFLOW_USER_KERNELS_REFLECTION_PAD_KERNELS_UTIL_H_\n#ifdef WITH_CUDA\n#include \"oneflow/core/cuda/atomic.cuh\"\n#endif  // WITH_CUDA\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/ndarray/xpu_util.h\"\n\nnamespace oneflow {\n\n#define PADDING_DATA_TYPE_CPU_SEQ \\\n  FLOATING_DATA_TYPE_SEQ          \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)\n\n#define PADDING_DATA_TYPE_CUDA_SEQ \\\n  FLOAT16_DATA_TYPE_SEQ            \\\n  PADDING_DATA_TYPE_CPU_SEQ\n\nnamespace user_op {\n\ntemplate<typename T>\nstruct DeviceAdd {\n  OF_DEVICE_FUNC static void Invoke(const T* x, T* y) {\n#if defined(__CUDA_ARCH__)\n    cuda::atomic::Add(y, *x);\n#else\n    *y += *x;\n#endif\n  };\n};\n\ntemplate<DeviceType device_type, typename IN_T>\nstruct ReflectionPad1dFunctor final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t y_width, const int64_t x_width,\n                  const int64_t pad_left);\n};\n\ntemplate<DeviceType device_type, typename IN_T>\nstruct ReflectionPad1dGradFunctor final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t dy_width, const int64_t dx_width,\n                  const int64_t pad_left);\n};\n\ntemplate<DeviceType device_type, typename IN_T>\nstruct ReflectionPad2dFunctor final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t y_height, const int64_t y_width,\n                  const int64_t x_height, const int64_t x_width, const int64_t pad_left,\n                  const int64_t pad_top);\n};\n\ntemplate<DeviceType device_type, typename IN_T>\nstruct ReflectionPad2dGradFunctor final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t dy_height, const int64_t dy_width,\n                  const int64_t dx_height, const int64_t dx_width, const int64_t pad_left,\n                  const int64_t pad_top);\n};\n\ntemplate<typename IN_T>\nOF_DEVICE_FUNC void DoReflectionPad1d(const IN_T* src, IN_T* dest,\n                                      const NdIndexOffsetHelper<int64_t, 3>& index_helper,\n                                      const int64_t elem_num, const int64_t src_num,\n                                      const int64_t dest_num, const int64_t y_width,\n                     
                 const int64_t x_width, const int64_t pad_left) {\n  XPU_1D_KERNEL_LOOP(k, elem_num) {\n    int64_t n, c, j, ip_x;\n    int64_t coord_y[3];\n    index_helper.OffsetToNdIndex(k, coord_y);\n    n = coord_y[0];\n    c = coord_y[1];\n    j = coord_y[2];\n    if (j < pad_left) {\n      ip_x = pad_left * 2 - j;\n    } else if (j >= pad_left && j < x_width + pad_left) {\n      ip_x = j;\n    } else {\n      ip_x = (x_width + pad_left - 1) * 2 - j;\n    }\n\n    ip_x = ip_x - pad_left;\n    int64_t dest_index = n * dest_num + c * y_width + j;\n    int64_t src_index = n * src_num + c * x_width + ip_x;\n    dest[dest_index] = src[src_index];\n  }\n}\n\ntemplate<typename IN_T>\nOF_DEVICE_FUNC void DoReflectionPad1dGrad(const IN_T* src, IN_T* dest,\n                                          const NdIndexOffsetHelper<int64_t, 3>& index_helper,\n                                          const int64_t elem_num, const int64_t src_num,\n                                          const int64_t dest_num, const int64_t dy_width,\n                                          const int64_t dx_width, const int64_t pad_left) {\n  XPU_1D_KERNEL_LOOP(k, elem_num) {\n    int64_t n, c, j, ip_x;\n    int64_t coord[3];\n    index_helper.OffsetToNdIndex(k, coord);\n    n = coord[0];\n    c = coord[1];\n    j = coord[2];\n    if (j < pad_left) {\n      ip_x = pad_left * 2 - j;\n    } else if (j >= pad_left && j < dx_width + pad_left) {\n      ip_x = j;\n    } else {\n      ip_x = (dx_width + pad_left - 1) * 2 - j;\n    }\n\n    ip_x = ip_x - pad_left;\n\n    int64_t src_index = n * src_num + c * dy_width + j;\n    int64_t dest_index = n * dest_num + c * dx_width + ip_x;\n    DeviceAdd<IN_T>::Invoke(src + src_index, dest + dest_index);\n  }\n}\n\ntemplate<typename IN_T>\nOF_DEVICE_FUNC void DoReflectionPad2d(const IN_T* src, IN_T* dest,\n                                      const NdIndexOffsetHelper<int64_t, 4>& index_helper,\n                                      const int64_t elem_num, const int64_t src_num,\n                                      const int64_t dest_num, const int64_t y_height,\n                                      const int64_t y_width, const int64_t x_height,\n                                      const int64_t x_width, const int64_t pad_left,\n                                      const int64_t pad_top) {\n  XPU_1D_KERNEL_LOOP(k, elem_num) {\n    int64_t n, c, i, j, ip_x, ip_y;\n    int64_t coord_y[4];\n    index_helper.OffsetToNdIndex(k, coord_y);\n    n = coord_y[0];\n    c = coord_y[1];\n    i = coord_y[2];\n    j = coord_y[3];\n    if (j < pad_left) {\n      ip_x = pad_left * 2 - j;\n    } else if (j >= pad_left && j < x_width + pad_left) {\n      ip_x = j;\n    } else {\n      ip_x = (x_width + pad_left - 1) * 2 - j;\n    }\n\n    if (i < pad_top) {\n      ip_y = pad_top * 2 - i;\n    } else if (i >= pad_top && i < x_height + pad_top) {\n      ip_y = i;\n    } else {\n      ip_y = (x_height + pad_top - 1) * 2 - i;\n    }\n    ip_x = ip_x - pad_left;\n    ip_y = ip_y - pad_top;\n    int64_t dest_index = n * dest_num + c * y_width * y_height + i * y_width + j;\n    int64_t src_index = n * src_num + c * x_width * x_height + ip_y * x_width + ip_x;\n    dest[dest_index] = src[src_index];\n  }\n}\n\ntemplate<typename IN_T>\nOF_DEVICE_FUNC void DoReflectionPad2dGrad(const IN_T* src, IN_T* dest,\n                                          const NdIndexOffsetHelper<int64_t, 4>& index_helper,\n                                          const int64_t elem_num, const int64_t src_num,\n            
                               const int64_t dest_num, const int64_t dy_height,\n                                          const int64_t dy_width, const int64_t dx_height,\n                                          const int64_t dx_width, const int64_t pad_left,\n                                          const int64_t pad_top) {\n  XPU_1D_KERNEL_LOOP(k, elem_num) {\n    int64_t n, c, i, j, ip_x, ip_y;\n    int64_t coord[4];\n    index_helper.OffsetToNdIndex(k, coord);\n    n = coord[0];\n    c = coord[1];\n    i = coord[2];\n    j = coord[3];\n    if (j < pad_left) {\n      ip_x = pad_left * 2 - j;\n    } else if (j >= pad_left && j < dx_width + pad_left) {\n      ip_x = j;\n    } else {\n      ip_x = (dx_width + pad_left - 1) * 2 - j;\n    }\n\n    if (i < pad_top) {\n      ip_y = pad_top * 2 - i;\n    } else if (i >= pad_top && i < dx_height + pad_top) {\n      ip_y = i;\n    } else {\n      ip_y = (dx_height + pad_top - 1) * 2 - i;\n    }\n    ip_x = ip_x - pad_left;\n    ip_y = ip_y - pad_top;\n\n    int64_t src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j;\n    int64_t dest_index = n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x;\n    DeviceAdd<IN_T>::Invoke(src + src_index, dest + dest_index);\n  }\n}\n\n// macros for functor instantiation\n#define INSTANTIATE_REFLECTION_PAD_FUNCTOR(device_type_v, dtype_pair)                  \\\n  template struct ReflectionPad1dFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>; \\\n  template struct ReflectionPad2dFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;\n\n#define INSTANTIATE_REFLECTION_PAD_GRAD_FUNCTOR(device_type_v, dtype_pair)                 \\\n  template struct ReflectionPad1dGradFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>; \\\n  template struct ReflectionPad2dGradFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;\n\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_REFLECTION_PAD_KERNELS_UTIL_H_\n"
  },
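One detail worth calling out in the header above: the backward mappings are many-to-one, which is why `DeviceAdd` accumulates (atomically under `__CUDA_ARCH__`) rather than assigning. A host-side sketch of the fold-back for the 1-D case (sample values are illustrative):

```cpp
// With x_width = 4 and pad_left = pad_right = 2, the 8 output columns fold
// back onto only 4 input columns, so several dy elements contribute to the
// same dx element and the gradient must be accumulated, not stored.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t x_width = 4, pad_left = 2, y_width = 8;
  double dy[y_width] = {1, 1, 1, 1, 1, 1, 1, 1};
  double dx[x_width] = {0, 0, 0, 0};
  for (int64_t j = 0; j < y_width; ++j) {
    int64_t ip_x;
    if (j < pad_left) {
      ip_x = pad_left * 2 - j;
    } else if (j < x_width + pad_left) {
      ip_x = j;
    } else {
      ip_x = (x_width + pad_left - 1) * 2 - j;
    }
    dx[ip_x - pad_left] += dy[j];  // DeviceAdd: atomic on GPU, plain += on CPU
  }
  for (int64_t i = 0; i < x_width; ++i) { std::printf("dx[%lld] = %g\n", (long long)i, dx[i]); }
  return 0;
}
```

With all-ones `dy` this prints `dx = [1, 3, 3, 1]`: the near-edge interior columns receive contributions from the mirrored pad columns as well, so plain stores would race on the GPU.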
  {
    "path": "oneflow/user/kernels/repeat_interleave_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/roll_kernel_utils.h\"\n\n#include <algorithm>\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass CpuRepeatInterLeaveKernel final : public user_op::OpKernel {\n public:\n  CpuRepeatInterLeaveKernel() = default;\n  ~CpuRepeatInterLeaveKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* cumsum = ctx->Tensor4ArgNameAndIndex(\"cumsum\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const T* in_ptr = in->dptr<T>();\n    const T* cumsum_ptr = cumsum->dptr<T>();\n    T* out_ptr = out->mut_dptr<T>();\n    for (T i = 0; i < in->shape_view().At(0); i++) {\n      T end = cumsum_ptr[i];\n      T size = in_ptr[i];\n      T start = end - size;\n      for (T j = start; j < end; j++) { out_ptr[j] = i; }\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_REPEAT_INTER_LEAVE_KERNEL(dtype)                     \\\n  REGISTER_USER_KERNEL(\"repeat_interleave\")                           \\\n      .SetCreateFn<CpuRepeatInterLeaveKernel<dtype>>()                \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value))\n\nREGISTER_REPEAT_INTER_LEAVE_KERNEL(int32_t);\nREGISTER_REPEAT_INTER_LEAVE_KERNEL(int64_t);\n\n}  // namespace oneflow\n"
  },
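The CPU kernel above expands per-index repeat counts into output indices using a precomputed inclusive cumulative sum. A standalone sketch of the same loop (the sample data is illustrative):

```cpp
// Each index i is written into the half-open output slot range
// [cumsum[i] - in[i], cumsum[i]), exactly as in CpuRepeatInterLeaveKernel.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> in = {2, 1, 3};      // repeats per index
  std::vector<int64_t> cumsum = {2, 3, 6};  // inclusive prefix sum of `in`
  std::vector<int64_t> out(cumsum.back(), 0);
  for (size_t i = 0; i < in.size(); ++i) {
    int64_t end = cumsum[i];
    int64_t start = end - in[i];
    for (int64_t j = start; j < end; ++j) { out[j] = static_cast<int64_t>(i); }
  }
  for (int64_t v : out) { std::printf("%lld ", (long long)v); }  // 0 0 1 2 2 2
  return 0;
}
```

With repeats `[2, 1, 3]` the output is `[0, 0, 1, 2, 2, 2]`, matching `repeat_interleave` semantics.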
  {
    "path": "oneflow/user/kernels/repeat_interleave_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/roll_kernel_utils.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\n#include <algorithm>\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__global__ void repeat_interleave(const T* in_ptr, const T* cumsum_ptr, T* out_ptr,\n                                  const int64_t num) {\n  CUDA_1D_KERNEL_LOOP(i, num) {\n    T end = cumsum_ptr[i];\n    T size = in_ptr[i];\n    T start = end - size;\n    for (T j = start; j < end; j++) { out_ptr[j] = i; }\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuRepeatInterLeaveKernel final : public user_op::OpKernel {\n public:\n  GpuRepeatInterLeaveKernel() = default;\n  ~GpuRepeatInterLeaveKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const user_op::Tensor* cumsum = ctx->Tensor4ArgNameAndIndex(\"cumsum\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t& repeat_num = ctx->Attr<std::int64_t>(\"repeat_num\");\n    const T* in_ptr = in->dptr<T>();\n    const T* cumsum_ptr = cumsum->dptr<T>();\n    T* out_ptr = out->mut_dptr<T>();\n\n    repeat_interleave<T><<<BlocksNum4ThreadsNum(in->shape_view().At(0)), kCudaThreadsNumPerBlock, 0,\n                           ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        in_ptr, cumsum_ptr, out_ptr, in->shape_view().At(0));\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_REPEAT_INTER_LEAVE_KERNEL(dtype)                      \\\n  REGISTER_USER_KERNEL(\"repeat_interleave\")                            \\\n      .SetCreateFn<GpuRepeatInterLeaveKernel<dtype>>()                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value))\n\nREGISTER_REPEAT_INTER_LEAVE_KERNEL(int32_t);\nREGISTER_REPEAT_INTER_LEAVE_KERNEL(int64_t);\n\n}  // namespace oneflow\n"
  },
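`CUDA_1D_KERNEL_LOOP` in the kernel above is OneFlow's grid-stride loop macro. A self-contained analogue with the loop written out explicitly (the kernel name and launch shape here are illustrative; the real code sizes the grid with `BlocksNum4ThreadsNum`):

```cuda
// Grid-stride loop: each thread handles indices i, i + stride, i + 2*stride, ...
// so any grid size covers any problem size.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void RepeatInterleaveSketch(const int64_t* in, const int64_t* cumsum, int64_t* out,
                                       int64_t num) {
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += (int64_t)blockDim.x * gridDim.x) {
    int64_t end = cumsum[i];
    int64_t start = end - in[i];
    for (int64_t j = start; j < end; ++j) { out[j] = i; }
  }
}

int main() {
  int64_t *in, *cumsum, *out;
  cudaMallocManaged(&in, 3 * sizeof(int64_t));
  cudaMallocManaged(&cumsum, 3 * sizeof(int64_t));
  cudaMallocManaged(&out, 6 * sizeof(int64_t));
  in[0] = 2; in[1] = 1; in[2] = 3;
  cumsum[0] = 2; cumsum[1] = 3; cumsum[2] = 6;
  RepeatInterleaveSketch<<<1, 128>>>(in, cumsum, out, 3);
  cudaDeviceSynchronize();
  for (int i = 0; i < 6; ++i) { std::printf("%lld ", (long long)out[i]); }  // 0 0 1 2 2 2
  return 0;
}
```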
  {
    "path": "oneflow/user/kernels/replication_pad_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/user/kernels/replication_pad_kernels_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\ntemplate<DeviceType device_type, typename IN_T>\nclass ReplicationPad1dKernel final : public OpKernel {\n public:\n  ReplicationPad1dKernel() = default;\n  ~ReplicationPad1dKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const auto& padding = ctx->Attr<std::vector<int64_t>>(\"padding\");\n    const int64_t n_idx = 0;\n    const int64_t c_idx = 1;\n    const int64_t w_idx = 2;\n\n    const int64_t pad_left = padding[0];\n\n    const int64_t n_batch = y->shape_view().At(n_idx);\n    const int64_t n_channel = y->shape_view().At(c_idx);\n    const int64_t y_width = y->shape_view().At(w_idx);\n    const int64_t x_width = x->shape_view().At(w_idx);\n\n    IN_T* dest = y->mut_dptr<IN_T>();\n    const IN_T* src = x->dptr<IN_T>();\n    DimVector y_vector;\n    y->shape_view().ToDimVector(&y_vector);\n    NdIndexOffsetHelper<int64_t, 3> index_helper(y_vector.data());\n\n    ReplicationPad1dFunctor<device_type, IN_T>()(ctx->stream(), src, dest, index_helper, n_batch,\n                                                 n_channel, y_width, x_width, pad_left);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename IN_T>\nclass ReplicationPad1dGradKernel final : public OpKernel {\n public:\n  ReplicationPad1dGradKernel() = default;\n  ~ReplicationPad1dGradKernel() = default;\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    const Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const auto& padding = ctx->Attr<std::vector<int64_t>>(\"padding\");\n\n    const int64_t n_idx = 0;\n    const int64_t c_idx = 1;\n    const int64_t w_idx = 2;\n\n    const int64_t pad_left = padding[0];\n    const int64_t n_batch = dy->shape_view().At(n_idx);\n    const int64_t n_channel = dy->shape_view().At(c_idx);\n    const int64_t dy_width = dy->shape_view().At(w_idx);\n    const int64_t dx_width = dx->shape_view().At(w_idx);\n\n    const IN_T* src = dy->dptr<IN_T>();\n    IN_T* dest = dx->mut_dptr<IN_T>();\n    DimVector dy_vector;\n    dy->shape_view().ToDimVector(&dy_vector);\n    NdIndexOffsetHelper<int64_t, 3> index_helper(dy_vector.data());\n\n    size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type());\n    Memset<device_type>(ctx->stream(), dest, 0, out_bytes_size);\n\n    ReplicationPad1dGradFunctor<device_type, IN_T>()(\n        ctx->stream(), src, dest, 
index_helper, n_batch, n_channel, dy_width, dx_width, pad_left);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename IN_T>\nclass ReplicationPad2dKernel final : public OpKernel {\n public:\n  ReplicationPad2dKernel() = default;\n  ~ReplicationPad2dKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const auto& padding = ctx->Attr<std::vector<int64_t>>(\"padding\");\n    const int64_t n_idx = 0;\n    const int64_t c_idx = 1;\n    const int64_t h_idx = 2;\n    const int64_t w_idx = 3;\n\n    const int64_t pad_left = padding[0];\n    const int64_t pad_top = padding[2];\n\n    const int64_t n_batch = y->shape_view().At(n_idx);\n    const int64_t n_channel = y->shape_view().At(c_idx);\n    const int64_t y_height = y->shape_view().At(h_idx);\n    const int64_t y_width = y->shape_view().At(w_idx);\n    const int64_t x_height = x->shape_view().At(h_idx);\n    const int64_t x_width = x->shape_view().At(w_idx);\n\n    IN_T* dest = y->mut_dptr<IN_T>();\n    const IN_T* src = x->dptr<IN_T>();\n    DimVector y_vector;\n    y->shape_view().ToDimVector(&y_vector);\n    NdIndexOffsetHelper<int64_t, 4> index_helper(y_vector.data());\n\n    ReplicationPad2dFunctor<device_type, IN_T>()(ctx->stream(), src, dest, index_helper, n_batch,\n                                                 n_channel, y_height, y_width, x_height, x_width,\n                                                 pad_left, pad_top);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename IN_T>\nclass ReplicationPad2dGradKernel final : public OpKernel {\n public:\n  ReplicationPad2dGradKernel() = default;\n  ~ReplicationPad2dGradKernel() = default;\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    const Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const auto& padding = ctx->Attr<std::vector<int64_t>>(\"padding\");\n\n    const int64_t n_idx = 0;\n    const int64_t c_idx = 1;\n    const int64_t h_idx = 2;\n    const int64_t w_idx = 3;\n\n    const int64_t pad_left = padding[0];\n    const int64_t pad_top = padding[2];\n    const int64_t n_batch = dy->shape_view().At(n_idx);\n    const int64_t n_channel = dy->shape_view().At(c_idx);\n    const int64_t dy_height = dy->shape_view().At(h_idx);\n    const int64_t dy_width = dy->shape_view().At(w_idx);\n    const int64_t dx_height = dx->shape_view().At(h_idx);\n    const int64_t dx_width = dx->shape_view().At(w_idx);\n\n    const IN_T* src = dy->dptr<IN_T>();\n    IN_T* dest = dx->mut_dptr<IN_T>();\n    DimVector dy_vector;\n    dy->shape_view().ToDimVector(&dy_vector);\n    NdIndexOffsetHelper<int64_t, 4> index_helper(dy_vector.data());\n\n    size_t out_bytes_size = dx->shape_view().elem_cnt() * GetSizeOfDataType(dx->data_type());\n    Memset<device_type>(ctx->stream(), dest, 0, out_bytes_size);\n\n    ReplicationPad2dGradFunctor<device_type, IN_T>()(ctx->stream(), src, dest, index_helper,\n                                                     n_batch, n_channel, dy_height, dy_width,\n                                                     dx_height, dx_width, pad_left, pad_top);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define 
REGISTER_REPLICATION_PAD_ND_KERNELS(device, dtype)                               \\\n  REGISTER_USER_KERNEL(\"replication_pad1d\")                                              \\\n      .SetCreateFn<ReplicationPad1dKernel<device, dtype>>()                              \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                              \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));  \\\n  REGISTER_USER_KERNEL(\"replication_pad1d_grad\")                                         \\\n      .SetCreateFn<ReplicationPad1dGradKernel<device, dtype>>()                          \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                              \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"replication_pad2d\")                                              \\\n      .SetCreateFn<ReplicationPad2dKernel<device, dtype>>()                              \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                              \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));  \\\n  REGISTER_USER_KERNEL(\"replication_pad2d_grad\")                                         \\\n      .SetCreateFn<ReplicationPad2dGradKernel<device, dtype>>()                          \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                              \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\n#define REGISTER_REPLICATION_PAD_ND_WITH_DEVICE(device) \\\n  REGISTER_REPLICATION_PAD_ND_KERNELS(device, float)    \\\n  REGISTER_REPLICATION_PAD_ND_KERNELS(device, double)   \\\n  REGISTER_REPLICATION_PAD_ND_KERNELS(device, int32_t)\n\nREGISTER_REPLICATION_PAD_ND_WITH_DEVICE(DeviceType::kCPU)\n#ifdef WITH_CUDA\nREGISTER_REPLICATION_PAD_ND_WITH_DEVICE(DeviceType::kCUDA)\nREGISTER_REPLICATION_PAD_ND_KERNELS(DeviceType::kCUDA, float16)\n#endif\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
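The `padding` attribute above is read as `padding[0]` for left and `padding[2]` for top, i.e. the PyTorch-style `{left, right, top, bottom}` layout. A sketch of the implied shape relationship (the op's actual output shape comes from its TensorDesc inference, not from this snippet):

```cpp
// Padded width/height are the input extents plus the pads on each side; the
// forward kernels then launch one thread per *output* element, the grad
// kernels one per dy element.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> x_shape = {2, 3, 5, 4};  // N, C, H, W
  std::vector<int64_t> padding = {1, 2, 3, 4};  // left, right, top, bottom
  std::vector<int64_t> y_shape = x_shape;
  y_shape[3] += padding[0] + padding[1];  // W + left + right
  y_shape[2] += padding[2] + padding[3];  // H + top + bottom
  std::printf("y: %lld x %lld x %lld x %lld\n", (long long)y_shape[0], (long long)y_shape[1],
              (long long)y_shape[2], (long long)y_shape[3]);
  return 0;
}
```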
  {
    "path": "oneflow/user/kernels/replication_pad_kernels_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/replication_pad_kernels_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\ntemplate<typename IN_T>\nstruct ReplicationPad1dFunctor<DeviceType::kCPU, IN_T> final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t y_width, const int64_t x_width,\n                  const int64_t pad_left) {\n    const int64_t dest_num = n_channel * y_width;\n    const int64_t src_num = n_channel * x_width;\n    const int64_t elem_num = n_batch * dest_num;\n    DoReplicationPad1d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_width, x_width,\n                             pad_left);\n  }\n};\n\ntemplate<typename IN_T>\nstruct ReplicationPad1dGradFunctor<DeviceType::kCPU, IN_T> final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t dy_width, const int64_t dx_width,\n                  const int64_t pad_left) {\n    const int64_t dest_num = n_channel * dx_width;\n    const int64_t src_num = n_channel * dy_width;\n    const int64_t elem_num = n_batch * src_num;\n    DoReplicationPad1dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_width,\n                                 dx_width, pad_left);\n  }\n};\n\ntemplate<typename IN_T>\nstruct ReplicationPad2dFunctor<DeviceType::kCPU, IN_T> final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t y_height, const int64_t y_width,\n                  const int64_t x_height, const int64_t x_width, const int64_t pad_left,\n                  const int64_t pad_top) {\n    const int64_t dest_num = n_channel * y_height * y_width;\n    const int64_t src_num = n_channel * x_height * x_width;\n    const int64_t elem_num = n_batch * dest_num;\n    DoReplicationPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height,\n                             y_width, x_height, x_width, pad_left, pad_top);\n  }\n};\n\ntemplate<typename IN_T>\nstruct ReplicationPad2dGradFunctor<DeviceType::kCPU, IN_T> final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t dy_height, const int64_t dy_width,\n                  const int64_t dx_height, const int64_t dx_width, const int64_t pad_left,\n                  const int64_t pad_top) {\n    const int64_t dest_num = n_channel * dx_height 
* dx_width;\n    const int64_t src_num = n_channel * dy_height * dy_width;\n    const int64_t elem_num = n_batch * src_num;\n    DoReplicationPad2dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_height,\n                                 dy_width, dx_height, dx_width, pad_left, pad_top);\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REPLICATION_PAD_FUNCTOR, (DeviceType::kCPU),\n                                 PADDING_DATA_TYPE_CPU_SEQ);\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REPLICATION_PAD_GRAD_FUNCTOR, (DeviceType::kCPU),\n                                 PADDING_DATA_TYPE_CPU_SEQ);\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/replication_pad_kernels_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cstdint>\n#ifdef WITH_CUDA\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/replication_pad_kernels_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\ntemplate<typename IN_T>\n__global__ void DoCUDAReplicationPad1d(const IN_T* src, IN_T* dest,\n                                       const NdIndexOffsetHelper<int64_t, 3> index_helper,\n                                       const int64_t elem_num, const int64_t src_num,\n                                       const int64_t dest_num, const int64_t y_width,\n                                       const int64_t x_width, const int64_t pad_left) {\n  DoReplicationPad1d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_width, x_width,\n                           pad_left);\n};\n\ntemplate<typename IN_T>\n__global__ void DoCUDAReplicationPad1dGrad(const IN_T* src, IN_T* dest,\n                                           const NdIndexOffsetHelper<int64_t, 3> index_helper,\n                                           const int64_t elem_num, const int64_t src_num,\n                                           const int64_t dest_num, const int64_t dy_width,\n                                           const int64_t dx_width, const int64_t pad_left) {\n  DoReplicationPad1dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_width,\n                               dx_width, pad_left);\n};\n\ntemplate<typename IN_T>\n__global__ void DoCUDAReplicationPad2d(const IN_T* src, IN_T* dest,\n                                       const NdIndexOffsetHelper<int64_t, 4> index_helper,\n                                       const int64_t elem_num, const int64_t src_num,\n                                       const int64_t dest_num, const int64_t y_height,\n                                       const int64_t y_width, const int64_t x_height,\n                                       const int64_t x_width, const int64_t pad_left,\n                                       const int64_t pad_top) {\n  DoReplicationPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width,\n                           x_height, x_width, pad_left, pad_top);\n};\n\ntemplate<typename IN_T>\n__global__ void DoCUDAReplicationPad2dGrad(const IN_T* src, IN_T* dest,\n                                           const NdIndexOffsetHelper<int64_t, 4> index_helper,\n                                           const int64_t elem_num, const int64_t src_num,\n                                           const int64_t dest_num, const int64_t dy_height,\n                                           const int64_t dy_width, const int64_t dx_height,\n                                           const int64_t dx_width, const int64_t pad_left,\n                                           const int64_t pad_top) {\n  
DoReplicationPad2dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_height,\n                               dy_width, dx_height, dx_width, pad_left, pad_top);\n};\n\ntemplate<typename IN_T>\nstruct ReplicationPad1dFunctor<DeviceType::kCUDA, IN_T> final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t y_width, const int64_t x_width,\n                  const int64_t pad_left) {\n    const int64_t dest_num = n_channel * y_width;\n    const int64_t src_num = n_channel * x_width;\n    const int64_t elem_num = n_batch * dest_num;\n    DoCUDAReplicationPad1d<IN_T><<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0,\n                                   stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        src, dest, index_helper, elem_num, src_num, dest_num, y_width, x_width, pad_left);\n  }\n};\n\n// float16 implementation\ntemplate<>\nvoid ReplicationPad1dFunctor<DeviceType::kCUDA, float16>::operator()(\n    ep::Stream* stream, const float16* src, float16* dest,\n    const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,\n    const int64_t n_channel, const int64_t y_width, const int64_t x_width, const int64_t pad_left) {\n  const int64_t dest_num = n_channel * y_width;\n  const int64_t src_num = n_channel * x_width;\n  const int64_t elem_num = n_batch * dest_num;\n  DoCUDAReplicationPad1d<half><<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0,\n                                 stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper, elem_num,\n      src_num, dest_num, y_width, x_width, pad_left);\n}\n\ntemplate<typename IN_T>\nstruct ReplicationPad1dGradFunctor<DeviceType::kCUDA, IN_T> final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t dy_width, const int64_t dx_width,\n                  const int64_t pad_left) {\n    const int64_t dest_num = n_channel * dx_width;\n    const int64_t src_num = n_channel * dy_width;\n    const int64_t elem_num = n_batch * src_num;\n    DoCUDAReplicationPad1dGrad<IN_T><<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0,\n                                       stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        src, dest, index_helper, elem_num, src_num, dest_num, dy_width, dx_width, pad_left);\n  }\n};\n\n// float16 implementation\ntemplate<>\nvoid ReplicationPad1dGradFunctor<DeviceType::kCUDA, float16>::operator()(\n    ep::Stream* stream, const float16* src, float16* dest,\n    const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,\n    const int64_t n_channel, const int64_t dy_width, const int64_t dx_width,\n    const int64_t pad_left) {\n  const int64_t dest_num = n_channel * dx_width;\n  const int64_t src_num = n_channel * dy_width;\n  const int64_t elem_num = n_batch * src_num;\n  DoCUDAReplicationPad1dGrad<half><<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0,\n                                     stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper, elem_num,\n      src_num, dest_num, dy_width, dx_width, pad_left);\n}\n\ntemplate<typename 
IN_T>\nstruct ReplicationPad2dFunctor<DeviceType::kCUDA, IN_T> final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t y_height, const int64_t y_width,\n                  const int64_t x_height, const int64_t x_width, const int64_t pad_left,\n                  const int64_t pad_top) {\n    const int64_t dest_num = n_channel * y_height * y_width;\n    const int64_t src_num = n_channel * x_height * x_width;\n    const int64_t elem_num = n_batch * dest_num;\n    DoCUDAReplicationPad2d<IN_T><<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0,\n                                   stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, x_height, x_width,\n        pad_left, pad_top);\n  }\n};\n\n// float16 implementation\ntemplate<>\nvoid ReplicationPad2dFunctor<DeviceType::kCUDA, float16>::operator()(\n    ep::Stream* stream, const float16* src, float16* dest,\n    const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,\n    const int64_t n_channel, const int64_t y_height, const int64_t y_width, const int64_t x_height,\n    const int64_t x_width, const int64_t pad_left, const int64_t pad_top) {\n  const int64_t dest_num = n_channel * y_height * y_width;\n  const int64_t src_num = n_channel * x_height * x_width;\n  const int64_t elem_num = n_batch * dest_num;\n  DoCUDAReplicationPad2d<half><<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0,\n                                 stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper, elem_num,\n      src_num, dest_num, y_height, y_width, x_height, x_width, pad_left, pad_top);\n}\n\ntemplate<typename IN_T>\nstruct ReplicationPad2dGradFunctor<DeviceType::kCUDA, IN_T> final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t dy_height, const int64_t dy_width,\n                  const int64_t dx_height, const int64_t dx_width, const int64_t pad_left,\n                  const int64_t pad_top) {\n    const int64_t dest_num = n_channel * dx_height * dx_width;\n    const int64_t src_num = n_channel * dy_height * dy_width;\n    const int64_t elem_num = n_batch * src_num;\n    DoCUDAReplicationPad2dGrad<IN_T><<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0,\n                                       stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        src, dest, index_helper, elem_num, src_num, dest_num, dy_height, dy_width, dx_height,\n        dx_width, pad_left, pad_top);\n  }\n};\n\n// float16 implementation\ntemplate<>\nvoid ReplicationPad2dGradFunctor<DeviceType::kCUDA, float16>::operator()(\n    ep::Stream* stream, const float16* src, float16* dest,\n    const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,\n    const int64_t n_channel, const int64_t dy_height, const int64_t dy_width,\n    const int64_t dx_height, const int64_t dx_width, const int64_t pad_left,\n    const int64_t pad_top) {\n  const int64_t dest_num = n_channel * dx_height * dx_width;\n  const int64_t src_num = n_channel * dy_height * dy_width;\n  const int64_t elem_num = n_batch * src_num;\n  
DoCUDAReplicationPad2dGrad<half><<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0,\n                                     stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper, elem_num,\n      src_num, dest_num, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top);\n}\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REPLICATION_PAD_FUNCTOR,\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                                 PADDING_DATA_TYPE_CUDA_SEQ);\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REPLICATION_PAD_GRAD_FUNCTOR,\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                                 PADDING_DATA_TYPE_CUDA_SEQ);\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // WITH_CUDA\n"
  },
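Replication padding differs from reflection only in the index mapping: out-of-range columns clamp to the nearest edge instead of mirroring. A host-side sketch of the clamp used by `DoReplicationPad1d` (helper name and sizes illustrative):

```cpp
// Clamping index mapping: pad columns replicate the nearest edge element.
#include <cstdint>
#include <cstdio>

int64_t ClampIndex(int64_t j, int64_t x_width, int64_t pad_left) {
  int64_t ip_x;
  if (j < pad_left) {
    ip_x = pad_left;  // clamp to the left edge
  } else if (j < x_width + pad_left) {
    ip_x = j;  // interior
  } else {
    ip_x = x_width + pad_left - 1;  // clamp to the right edge
  }
  return ip_x - pad_left;
}

int main() {
  // src [a, b, c, d], pad_left = pad_right = 2 -> [a, a, a, b, c, d, d, d]
  for (int64_t j = 0; j < 8; ++j) {
    std::printf("%lld -> %lld\n", (long long)j, (long long)ClampIndex(j, 4, 2));
  }
  return 0;
}
```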
  {
    "path": "oneflow/user/kernels/replication_pad_kernels_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_REPLICATION_PAD_KERNELS_UTIL_H_\n#define ONEFLOW_USER_KERNELS_REPLICATION_PAD_KERNELS_UTIL_H_\n#ifdef WITH_CUDA\n#include \"oneflow/core/cuda/atomic.cuh\"\n#endif  // WITH_CUDA\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/ndarray/xpu_util.h\"\n\nnamespace oneflow {\n\n#define PADDING_DATA_TYPE_CPU_SEQ \\\n  FLOATING_DATA_TYPE_SEQ          \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)\n\n#define PADDING_DATA_TYPE_CUDA_SEQ \\\n  FLOAT16_DATA_TYPE_SEQ            \\\n  PADDING_DATA_TYPE_CPU_SEQ\n\nnamespace user_op {\n\ntemplate<typename T>\nstruct DeviceAdd {\n  OF_DEVICE_FUNC static void Invoke(const T* x, T* y) {\n#if defined(__CUDA_ARCH__)\n    cuda::atomic::Add(y, *x);\n#else\n    *y += *x;\n#endif\n  };\n};\n\ntemplate<DeviceType device_type, typename IN_T>\nstruct ReplicationPad1dFunctor final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t y_width, const int64_t x_width,\n                  const int64_t pad_left);\n};\n\ntemplate<DeviceType device_type, typename IN_T>\nstruct ReplicationPad1dGradFunctor final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t dy_width, const int64_t dx_width,\n                  const int64_t pad_left);\n};\n\ntemplate<DeviceType device_type, typename IN_T>\nstruct ReplicationPad2dFunctor final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t y_height, const int64_t y_width,\n                  const int64_t x_height, const int64_t x_width, const int64_t pad_left,\n                  const int64_t pad_top);\n};\n\ntemplate<DeviceType device_type, typename IN_T>\nstruct ReplicationPad2dGradFunctor final {\n  void operator()(ep::Stream* stream, const IN_T* src, IN_T* dest,\n                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, const int64_t n_batch,\n                  const int64_t n_channel, const int64_t dy_height, const int64_t dy_width,\n                  const int64_t dx_height, const int64_t dx_width, const int64_t pad_left,\n                  const int64_t pad_top);\n};\n\ntemplate<typename IN_T>\nOF_DEVICE_FUNC void DoReplicationPad1d(const IN_T* src, IN_T* dest,\n                                       const NdIndexOffsetHelper<int64_t, 3>& index_helper,\n                                       const int64_t elem_num, const int64_t src_num,\n                                       const int64_t dest_num, const int64_t y_width,\n           
                            const int64_t x_width, const int64_t pad_left) {\n  XPU_1D_KERNEL_LOOP(k, elem_num) {\n    int64_t n, c, j, ip_x;\n    int64_t coord_y[3];\n    index_helper.OffsetToNdIndex(k, coord_y);\n    n = coord_y[0];\n    c = coord_y[1];\n    j = coord_y[2];\n    if (j < pad_left) {\n      ip_x = pad_left;\n    } else if (j >= pad_left && j < x_width + pad_left) {\n      ip_x = j;\n    } else {\n      ip_x = x_width + pad_left - 1;\n    }\n\n    ip_x = ip_x - pad_left;\n    int64_t dest_index = n * dest_num + c * y_width + j;\n    int64_t src_index = n * src_num + c * x_width + ip_x;\n    dest[dest_index] = src[src_index];\n  }\n}\n\ntemplate<typename IN_T>\nOF_DEVICE_FUNC void DoReplicationPad1dGrad(const IN_T* src, IN_T* dest,\n                                           const NdIndexOffsetHelper<int64_t, 3>& index_helper,\n                                           const int64_t elem_num, const int64_t src_num,\n                                           const int64_t dest_num, const int64_t dy_width,\n                                           const int64_t dx_width, const int64_t pad_left) {\n  XPU_1D_KERNEL_LOOP(k, elem_num) {\n    int64_t n, c, j, ip_x;\n    int64_t coord[3];\n    index_helper.OffsetToNdIndex(k, coord);\n    n = coord[0];\n    c = coord[1];\n    j = coord[2];\n    if (j < pad_left) {\n      ip_x = pad_left;\n    } else if (j >= pad_left && j < dx_width + pad_left) {\n      ip_x = j;\n    } else {\n      ip_x = dx_width + pad_left - 1;\n    }\n\n    ip_x = ip_x - pad_left;\n\n    int64_t src_index = n * src_num + c * dy_width + j;\n    int64_t dest_index = n * dest_num + c * dx_width + ip_x;\n    DeviceAdd<IN_T>::Invoke(src + src_index, dest + dest_index);\n  }\n}\n\ntemplate<typename IN_T>\nOF_DEVICE_FUNC void DoReplicationPad2d(const IN_T* src, IN_T* dest,\n                                       const NdIndexOffsetHelper<int64_t, 4>& index_helper,\n                                       const int64_t elem_num, const int64_t src_num,\n                                       const int64_t dest_num, const int64_t y_height,\n                                       const int64_t y_width, const int64_t x_height,\n                                       const int64_t x_width, const int64_t pad_left,\n                                       const int64_t pad_top) {\n  XPU_1D_KERNEL_LOOP(k, elem_num) {\n    int64_t n, c, i, j, ip_x, ip_y;\n    int64_t coord_y[4];\n    index_helper.OffsetToNdIndex(k, coord_y);\n    n = coord_y[0];\n    c = coord_y[1];\n    i = coord_y[2];\n    j = coord_y[3];\n    if (j < pad_left) {\n      ip_x = pad_left;\n    } else if (j >= pad_left && j < x_width + pad_left) {\n      ip_x = j;\n    } else {\n      ip_x = x_width + pad_left - 1;\n    }\n\n    if (i < pad_top) {\n      ip_y = pad_top;\n    } else if (i >= pad_top && i < x_height + pad_top) {\n      ip_y = i;\n    } else {\n      ip_y = x_height + pad_top - 1;\n    }\n    ip_x = ip_x - pad_left;\n    ip_y = ip_y - pad_top;\n\n    int64_t dest_index = n * dest_num + c * y_width * y_height + i * y_width + j;\n    int64_t src_index = n * src_num + c * x_width * x_height + ip_y * x_width + ip_x;\n    dest[dest_index] = src[src_index];\n  }\n}\n\ntemplate<typename IN_T>\nOF_DEVICE_FUNC void DoReplicationPad2dGrad(const IN_T* src, IN_T* dest,\n                                           const NdIndexOffsetHelper<int64_t, 4>& index_helper,\n                                           const int64_t elem_num, const int64_t src_num,\n                                           const 
int64_t dest_num, const int64_t dy_height,\n                                           const int64_t dy_width, const int64_t dx_height,\n                                           const int64_t dx_width, const int64_t pad_left,\n                                           const int64_t pad_top) {\n  XPU_1D_KERNEL_LOOP(k, elem_num) {\n    int64_t n, c, i, j, ip_x, ip_y;\n    int64_t coord[4];\n    index_helper.OffsetToNdIndex(k, coord);\n    n = coord[0];\n    c = coord[1];\n    i = coord[2];\n    j = coord[3];\n    if (j < pad_left) {\n      ip_x = pad_left;\n    } else if (j >= pad_left && j < dx_width + pad_left) {\n      ip_x = j;\n    } else {\n      ip_x = dx_width + pad_left - 1;\n    }\n\n    if (i < pad_top) {\n      ip_y = pad_top;\n    } else if (i >= pad_top && i < dx_height + pad_top) {\n      ip_y = i;\n    } else {\n      ip_y = dx_height + pad_top - 1;\n    }\n    ip_x = ip_x - pad_left;\n    ip_y = ip_y - pad_top;\n\n    int64_t src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j;\n    int64_t dest_index = n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x;\n    DeviceAdd<IN_T>::Invoke(src + src_index, dest + dest_index);\n  }\n}\n\n// macros for functor instantiation (used by replication_pad_kernels_util.cpp/.cu)\n#define INSTANTIATE_REPLICATION_PAD_FUNCTOR(device_type_v, dtype_pair)                  \\\n  template struct ReplicationPad1dFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>; \\\n  template struct ReplicationPad2dFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;\n\n#define INSTANTIATE_REPLICATION_PAD_GRAD_FUNCTOR(device_type_v, dtype_pair)                 \\\n  template struct ReplicationPad1dGradFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>; \\\n  template struct ReplicationPad2dGradFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;\n\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_REPLICATION_PAD_KERNELS_UTIL_H_\n"
  },
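The `INSTANTIATE_*` macros at the bottom of the header, driven by `OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE`, emit one explicit instantiation per (device, dtype) pair so the functor definitions in the .cpp/.cu translation units link against the kernels. A self-contained analogue of that pattern (all names here are stand-ins, not OneFlow's):

```cpp
// Explicit template instantiation: the definitions can live in one
// translation unit while other units only see the declaration.
#include <cstdio>

enum class Dev { kCPU, kCUDA };

template<Dev dev, typename T>
struct PadFunctor {
  void operator()(const T* src, T* dst, int n) {
    for (int i = 0; i < n; ++i) { dst[i] = src[i]; }  // stand-in body
  }
};

// Hand-written form of what INSTANTIATE_REPLICATION_PAD_FUNCTOR generates
// for each (device, dtype) pair in the sequence product.
template struct PadFunctor<Dev::kCPU, float>;
template struct PadFunctor<Dev::kCPU, double>;

int main() {
  float src[3] = {1.f, 2.f, 3.f}, dst[3];
  PadFunctor<Dev::kCPU, float>()(src, dst, 3);
  std::printf("%g %g %g\n", dst[0], dst[1], dst[2]);
  return 0;
}
```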
  {
    "path": "oneflow/user/kernels/rms_norm_gpu_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include \"oneflow/core/cuda/rms_norm.cuh\"\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\n\nnamespace oneflow {\nnamespace cuda {\nnamespace rms_norm {\n\ntemplate<typename SRC, typename DST, bool affine>\nstruct AffineStore {\n  AffineStore(DST* dst, const DST* weight, int32_t row_size)\n      : dst(dst), weight(weight), row_size(row_size) {}\n\n  template<int N>\n  __device__ void store(const SRC* src, int32_t row, int32_t col) {\n    layer_norm::Pack<DST, N> dst_pack;\n    layer_norm::Pack<DST, N> weight_pack;\n    const int32_t offset = (row * row_size + col) / N;\n    const int32_t weight_offset = col / N;\n    if (affine) {\n      weight_pack.storage =\n          *(reinterpret_cast<const layer_norm::PackType<DST, N>*>(weight) + weight_offset);\n    }\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      if (affine) {\n        dst_pack.elem[i] = static_cast<DST>(src[i]) * weight_pack.elem[i];\n      } else {\n        dst_pack.elem[i] = static_cast<DST>(src[i]);\n      }\n    }\n    *(reinterpret_cast<layer_norm::PackType<DST, N>*>(dst) + offset) = dst_pack.storage;\n  }\n\n  DST* dst;\n  const DST* weight;\n  int32_t row_size;\n};\n\ntemplate<typename SRC, typename DST, bool affine>\nstruct AffineLoad {\n  AffineLoad(const SRC* src, const SRC* weight, int32_t row_size)\n      : src(src), weight(weight), row_size(row_size) {}\n\n  template<int N>\n  __device__ void load(DST* dst, int32_t row, int32_t col) const {\n    layer_norm::Pack<SRC, N> src_pack;\n    layer_norm::Pack<SRC, N> weight_pack;\n    const int32_t offset = (row * row_size + col) / N;\n    src_pack.storage = *(reinterpret_cast<const layer_norm::PackType<SRC, N>*>(src) + offset);\n    if (affine) {\n      const int32_t weight_offset = col / N;\n      weight_pack.storage =\n          *(reinterpret_cast<const layer_norm::PackType<SRC, N>*>(weight) + weight_offset);\n    }\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      if (affine) {\n        dst[i] = static_cast<DST>(src_pack.elem[i] * weight_pack.elem[i]);\n      } else {\n        dst[i] = static_cast<DST>(src_pack.elem[i]);\n      }\n    }\n  }\n  const SRC* src;\n  const SRC* weight;\n  int32_t row_size;\n};\n\ntemplate<typename T, typename ComputeType, bool affine>\nvoid DispatchRmsNormForwardAffine(ep::Stream* stream, const int64_t nrow, const int64_t ncol,\n                                  const double eps, const T* x_dptr, const T* w_dptr, T* y_dptr,\n                                  ComputeType* inv_rms) {\n  layer_norm::DirectLoad<T, ComputeType> load(x_dptr, ncol);\n  AffineStore<ComputeType, T, affine> store(y_dptr, w_dptr, ncol);\n  
OF_CUDA_CHECK((LaunchRmsNorm<decltype(load), decltype(store), ComputeType>(\n      stream->As<ep::CudaStream>()->cuda_stream(), load, store, nrow, ncol, eps, inv_rms)));\n}\n\ntemplate<typename T, typename ComputeType>\nvoid RmsNormForward(ep::Stream* stream, const int64_t nrow, const int64_t ncol, const double eps,\n                    const T* x_dptr, const T* w_dptr, T* y_dptr, ComputeType* inv_rms) {\n  if (w_dptr) {\n    DispatchRmsNormForwardAffine<T, ComputeType, true>(stream, nrow, ncol, eps, x_dptr, w_dptr,\n                                                       y_dptr, inv_rms);\n  } else {\n    DispatchRmsNormForwardAffine<T, ComputeType, false>(stream, nrow, ncol, eps, x_dptr, w_dptr,\n                                                        y_dptr, inv_rms);\n  }\n}\n\ntemplate<typename T, typename ComputeType, bool affine>\nvoid DispatchRmsNormBackwardAffine(ep::Stream* stream, const int64_t nrow, const int64_t ncol,\n                                   const T* dy_dptr, const T* x_dptr, const T* weight_dptr,\n                                   const ComputeType* inv_rms, T* dx_ptr) {\n  layer_norm::DirectLoad<T, ComputeType> load_x(x_dptr, ncol);\n  AffineLoad<T, ComputeType, affine> load_dy(dy_dptr, weight_dptr, ncol);\n  layer_norm::DirectStore<ComputeType, T> store(dx_ptr, ncol);\n  OF_CUDA_CHECK((rms_norm::LaunchRmsNormGrad(stream->As<ep::CudaStream>()->cuda_stream(), nrow,\n                                             ncol, load_x, load_dy, store, inv_rms)));\n}\n\n
template<typename T, typename ComputeType>\nvoid RmsNormBackward(ep::Stream* stream, const int64_t nrow, const int64_t ncol, const T* dy_dptr,\n                     const T* x_dptr, const T* weight_dptr, const ComputeType* inv_rms,\n                     T* dx_dptr) {\n  if (weight_dptr) {\n    DispatchRmsNormBackwardAffine<T, ComputeType, true>(stream, nrow, ncol, dy_dptr, x_dptr,\n                                                        weight_dptr, inv_rms, dx_dptr);\n  } else {\n    DispatchRmsNormBackwardAffine<T, ComputeType, false>(stream, nrow, ncol, dy_dptr, x_dptr,\n                                                         weight_dptr, inv_rms, dx_dptr);\n  }\n}\n\n}  // namespace rms_norm\n\n
template<typename T>\nclass RmsNormKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  RmsNormKernel() = default;\n  ~RmsNormKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* inv_rms = ctx->Tensor4ArgNameAndIndex(\"inv_rms\", 0);\n    const double eps = ctx->Attr<float>(\"epsilon\");\n    const Shape& normalized_shape = ctx->Attr<Shape>(\"normalized_shape\");\n    const int64_t ncol = normalized_shape.elem_cnt();\n    const int64_t nrow = inv_rms->shape_view().elem_cnt();\n    const T* weight_dptr = nullptr;\n    if (ctx->has_input(\"weight\", 0)) {\n      const auto* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n      CHECK_EQ(weight->shape_view().elem_cnt(), ncol);\n      weight_dptr = weight->dptr<T>();\n    }\n    CHECK_EQ(x->shape_view().elem_cnt(), ncol * nrow);\n    CHECK_LT(nrow * ncol, std::numeric_limits<int32_t>::max())\n        << \"Tensor size exceeds the int32 max limit. The kernel doesn't support large tensors.\";\n    using ComputeType = typename layer_norm::DefaultComputeType<T>::type;\n    rms_norm::RmsNormForward<T>(ctx->stream(), nrow, ncol, eps, x->dptr<T>(), weight_dptr,\n                                y->mut_dptr<T>(), inv_rms->mut_dptr<ComputeType>());\n  }\n};\n\n
#define REGISTER_RMS_NORM_CUDA_KERNEL(dtype)                           \\\n  REGISTER_USER_KERNEL(\"rms_norm\")                                     \\\n      .SetCreateFn<RmsNormKernel<dtype>>()                             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nREGISTER_RMS_NORM_CUDA_KERNEL(float)\nREGISTER_RMS_NORM_CUDA_KERNEL(double)\nREGISTER_RMS_NORM_CUDA_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_RMS_NORM_CUDA_KERNEL(nv_bfloat16)\n#endif\n\n
template<typename T>\nclass RmsNormGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  RmsNormGradKernel() = default;\n  ~RmsNormGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* inv_rms = ctx->Tensor4ArgNameAndIndex(\"inv_rms\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const int64_t nrow = inv_rms->shape_view().elem_cnt();\n    const int64_t ncol = x->shape_view().elem_cnt() / nrow;\n    const T* weight_dptr = nullptr;\n    if (ctx->has_input(\"weight\", 0)) {\n      const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n      CHECK_EQ(ncol, weight->shape_view().elem_cnt());\n      weight_dptr = weight->dptr<T>();\n    }\n    CHECK_LT(nrow * ncol, std::numeric_limits<int32_t>::max())\n        << \"Tensor size exceeds the int32 max limit. The kernel doesn't support large tensors.\";\n    using ComputeType = typename layer_norm::DefaultComputeType<T>::type;\n    rms_norm::RmsNormBackward<T>(ctx->stream(), nrow, ncol, dy->dptr<T>(), x->dptr<T>(),\n                                 weight_dptr, inv_rms->dptr<ComputeType>(), dx->mut_dptr<T>());\n  }\n};\n\n
#define REGISTER_RMS_NORM_GRAD_CUDA_KERNEL(dtype)                      \\\n  REGISTER_USER_KERNEL(\"rms_norm_grad\")                                \\\n      .SetCreateFn<RmsNormGradKernel<dtype>>()                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value));\n\nREGISTER_RMS_NORM_GRAD_CUDA_KERNEL(float)\nREGISTER_RMS_NORM_GRAD_CUDA_KERNEL(double)\nREGISTER_RMS_NORM_GRAD_CUDA_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_RMS_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)\n#endif\n\nnamespace {\n\nconstexpr int kNProcPerThread = 4;\n\n}  // namespace\n\n
template<typename T>\nclass RmsNormParamGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  RmsNormParamGradKernel() = default;\n  ~RmsNormParamGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* inv_rms = ctx->Tensor4ArgNameAndIndex(\"inv_rms\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    user_op::Tensor* weight_grad = ctx->Tensor4ArgNameAndIndex(\"weight_grad\", 0);\n    const int64_t nrow = inv_rms->shape_view().elem_cnt();\n    const int64_t ncol = weight_grad->shape_view().elem_cnt();\n    CHECK_LT(nrow * ncol, std::numeric_limits<int32_t>::max())\n        << \"Tensor size exceeds the int32 max limit. The kernel doesn't support large tensors.\";\n\n    // step 1: compute dy * (x * inv_rms) and partially reduce the rows within each block\n    const int block_dim_x = rms_norm::kWarpSize;\n    const int block_dim_y = rms_norm::kWarpSize / kNProcPerThread;\n    int grid_dim_x;\n    int grid_dim_y;\n    OF_CUDA_CHECK((rms_norm::GetGrid2Dim<kNProcPerThread, T>(nrow, ncol, block_dim_x, block_dim_y,\n                                                             &grid_dim_x, &grid_dim_y)));\n    // tmp weight shape [grid_dim_y, ncol] (reduce nrow -> grid_dim_y)\n    size_t tmp_weight_grad_size = grid_dim_y * ncol;\n    T* tmp_weight_grad_dptr = reinterpret_cast<T*>(tmp_buffer->mut_dptr());\n    using ComputeType = typename layer_norm::DefaultComputeType<T>::type;\n    dim3 grid_dims(grid_dim_x, grid_dim_y);\n    dim3 block_dims(block_dim_x, block_dim_y);\n    rms_norm::RmsNormParamGrad<kNProcPerThread, T, ComputeType>\n        <<<grid_dims, block_dims, 0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n            nrow, ncol, dy->dptr<T>(), x->dptr<T>(), inv_rms->dptr<ComputeType>(),\n            tmp_weight_grad_dptr);\n\n    // step 2: reduce the remaining rows through GEMM to calculate the weight grad\n    // fill a ones matrix with shape (grid_dim_y, 1)\n    const int32_t m = ncol;\n    const int32_t n = 1;\n    const int32_t k = grid_dim_y;\n    const DataType data_type = dy->data_type();\n    auto fill = ep::primitive::NewPrimitive<ep::primitive::FillFactory>(\n        ctx->stream()->device_type(), data_type);\n    CHECK(fill);\n    T* tmp_ones_dptr = tmp_buffer->mut_dptr<T>() + tmp_weight_grad_size;\n    fill->Launch(ctx->stream(), tmp_ones_dptr, 1.0, k);\n    // tmp weight grad (grid_dim_y, ncol) (T) * tmp ones (grid_dim_y, 1) (N)\n    // -> weight grad (ncol, 1)\n    auto matmul = ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(\n        ctx->stream()->device_type(), data_type, ep::primitive::BlasTransposeType::T,\n        ep::primitive::BlasTransposeType::N);\n    CHECK(matmul);\n    matmul->Launch(ctx->stream(), m, n, k, /*alpha*/ 1.0, tmp_weight_grad_dptr, tmp_ones_dptr,\n                   /*beta*/ 0.0, weight_grad->mut_dptr());\n  }\n};\n\n
template<typename T>\nsize_t InferRmsNormParamGradTempBufferSize(user_op::InferContext* ctx) {\n  const auto& shape = ctx->InputTensorDesc(\"dy\", 0).shape();\n  const auto& b_shape = ctx->InputTensorDesc(\"inv_rms\", 0).shape();\n  const int64_t nrow = b_shape.elem_cnt();\n  const int64_t ncol = shape.elem_cnt() / nrow;\n  const int block_dim_x = rms_norm::kWarpSize;\n  const int block_dim_y = rms_norm::kWarpSize / kNProcPerThread;\n  int grid_dim_x;\n  int grid_dim_y;\n  OF_CUDA_CHECK((rms_norm::GetGrid2Dim<kNProcPerThread, T>(nrow, ncol, block_dim_x, block_dim_y,\n                                                           &grid_dim_x, &grid_dim_y)));\n  return (grid_dim_y * ncol + grid_dim_y) * sizeof(T);\n}\n\n
#define REGISTER_RMS_NORM_PARAM_GRAD_GPU_KERNEL(dtype)                                  \\\n  REGISTER_USER_KERNEL(\"rms_norm_param_grad\")                                           \\\n      .SetCreateFn<RmsNormParamGradKernel<dtype>>()                                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"dy\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn(InferRmsNormParamGradTempBufferSize<dtype>);\n\nREGISTER_RMS_NORM_PARAM_GRAD_GPU_KERNEL(float)\nREGISTER_RMS_NORM_PARAM_GRAD_GPU_KERNEL(double)\nREGISTER_RMS_NORM_PARAM_GRAD_GPU_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_RMS_NORM_PARAM_GRAD_GPU_KERNEL(nv_bfloat16)\n#endif\n\n}  // namespace cuda\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/roc_auc_score_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename L, typename P>\ndouble RocAucScore(size_t n, const L* label, const P* pred, float* buffer) {\n  size_t p_samples_count = 0;\n  for (size_t i = 0; i < n; ++i) {\n    if (label[i] == 0) {\n      buffer[i] = -pred[i];\n    } else {\n      p_samples_count += 1;\n      buffer[i] = pred[i];\n    }\n  }\n  const size_t n_samples_count = n - p_samples_count;\n  constexpr size_t kParallelSortThreshold = 1024;\n  auto comp = [](float a, float b) { return fabs(a) < fabs(b); };\n  if (n < kParallelSortThreshold) {\n    std::sort(buffer, buffer + n, comp);\n  } else {\n    const size_t m2 = n / 2;\n    const size_t m1 = m2 / 2;\n    const size_t m3 = (m2 + n) / 2;\n    std::thread t0([&] { std::sort(buffer, buffer + m1, comp); });\n    std::thread t1([&] { std::sort(buffer + m1, buffer + m2, comp); });\n    std::thread t2([&] { std::sort(buffer + m2, buffer + m3, comp); });\n    std::thread t3([&] { std::sort(buffer + m3, buffer + n, comp); });\n    t0.join();\n    t1.join();\n    t2.join();\n    t3.join();\n    std::inplace_merge(buffer, buffer + m1, buffer + m2, comp);\n    std::inplace_merge(buffer + m2, buffer + m3, buffer + n, comp);\n    std::inplace_merge(buffer, buffer + m2, buffer + n, comp);\n  }\n  size_t tmp_n = 0;\n  double tmp_rank_sum = 0;\n  double rank_sum = 0;\n  size_t tmp_p_samples_count = 0;\n  for (size_t i = 0; i < n; ++i) {\n    if (i != 0 && fabs(buffer[i]) != fabs(buffer[i - 1])) {\n      rank_sum += tmp_p_samples_count * (tmp_rank_sum / tmp_n);\n      tmp_n = 0;\n      tmp_rank_sum = 0;\n      tmp_p_samples_count = 0;\n    }\n    if (buffer[i] > 0) { tmp_p_samples_count += 1; }\n    tmp_rank_sum += (i + 1);\n    tmp_n += 1;\n  }\n  rank_sum += tmp_p_samples_count * (tmp_rank_sum / tmp_n);\n  return (rank_sum - p_samples_count * (p_samples_count + 1) / 2)\n         / (p_samples_count * n_samples_count);\n}\n\ntemplate<typename L, typename P>\nclass RocAucScoreKernel final : public user_op::OpKernel {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(RocAucScoreKernel);\n  RocAucScoreKernel() = default;\n  ~RocAucScoreKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    const user_op::Tensor* pred = ctx->Tensor4ArgNameAndIndex(\"pred\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    P* out_ptr = out->mut_dptr<P>();\n    CHECK_EQ(label->shape_view().elem_cnt(), pred->shape_view().elem_cnt());\n    CHECK_EQ(out->shape_view().elem_cnt(), 1);\n    out_ptr[0] = RocAucScore(label->shape_view().elem_cnt(), label->dptr<L>(), pred->dptr<P>(),\n                             tmp_buffer->mut_dptr<float>());\n  }\n  bool 
AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_ROC_AUC_SCORE_KERNEL(label_type, label_cpp_type, pred_type, pred_cpp_type) \\\n  REGISTER_USER_KERNEL(\"roc_auc_score\")                                                     \\\n      .SetCreateFn<RocAucScoreKernel<label_cpp_type, pred_cpp_type>>()                      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                       \\\n                       && (user_op::HobDataType(\"label\", 0) == label_type)                  \\\n                       && (user_op::HobDataType(\"pred\", 0) == pred_type))                   \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                         \\\n        const Shape& pred_shape = ctx->InputShape(\"pred\", 0);                               \\\n        size_t tmp_buffer_size = pred_shape.elem_cnt() * sizeof(float);                     \\\n        return tmp_buffer_size;                                                             \\\n      })\nREGISTER_ROC_AUC_SCORE_KERNEL(DataType::kDouble, double, DataType::kFloat, float);\nREGISTER_ROC_AUC_SCORE_KERNEL(DataType::kFloat, float, DataType::kFloat, float);\nREGISTER_ROC_AUC_SCORE_KERNEL(DataType::kInt32, int, DataType::kFloat, float);\nREGISTER_ROC_AUC_SCORE_KERNEL(DataType::kInt64, int64_t, DataType::kFloat, float);\nREGISTER_ROC_AUC_SCORE_KERNEL(DataType::kInt8, int8_t, DataType::kFloat, float);\nREGISTER_ROC_AUC_SCORE_KERNEL(DataType::kUInt8, uint8_t, DataType::kFloat, float);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/roi_align_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__device__ T BilinearInterpolate(const T* channel_dptr, const int32_t height, const int32_t width,\n                                 T y, T x) {\n  if (y < -1.0 || y > height || x < -1.0 || x > width) { return 0; }\n\n  if (y <= 0) { y = 0; }\n  if (x <= 0) { x = 0; }\n  int32_t y_low = static_cast<int32_t>(y);\n  int32_t x_low = static_cast<int32_t>(x);\n  int32_t y_high = 0;\n  int32_t x_high = 0;\n\n  if (y_low >= height - 1) {\n    y_low = height - 1;\n    y_high = y_low;\n    y = static_cast<T>(y_low);\n  } else {\n    y_high = y_low + 1;\n  }\n\n  if (x_low >= width - 1) {\n    x_low = width - 1;\n    x_high = x_low;\n    x = static_cast<T>(x_low);\n  } else {\n    x_high = x_low + 1;\n  }\n\n  const T ly = y - y_low;\n  const T lx = x - x_low;\n  const T hy = 1.f - ly;\n  const T hx = 1.f - lx;\n\n  // https://en.wikipedia.org/wiki/Bilinear_interpolation\n  const int64_t q11 = y_low * width + x_low;\n  const int64_t q21 = y_low * width + x_high;\n  const int64_t q12 = y_high * width + x_low;\n  const int64_t q22 = y_high * width + x_high;\n  //  no 1 / (x_high - x_low) * (y_high - y_low) because it will always be 1 in RoI Align\n  return (hy * hx) * channel_dptr[q11] + (hy * lx) * channel_dptr[q21]\n         + (ly * hx) * channel_dptr[q12] + (ly * lx) * channel_dptr[q22];\n}\n\ntemplate<typename T>\n__device__ bool BilinearInterpolateDiff(const T bin_diff_avg, const int64_t height,\n                                        const int64_t width, T y, T x, T& diff11, T& diff21,\n                                        T& diff12, T& diff22, int32_t& x_low, int32_t& x_high,\n                                        int32_t& y_low, int32_t& y_high) {\n  if (y < -1.0 || y > height || x < -1.0 || x > width) { return false; }\n\n  if (y <= 0) { y = 0; }\n  if (x <= 0) { x = 0; }\n\n  y_low = static_cast<int32_t>(y);\n  x_low = static_cast<int32_t>(x);\n\n  if (y_low >= height - 1) {\n    y_low = height - 1;\n    y_high = y_low;\n    y = static_cast<T>(y_low);\n  } else {\n    y_high = y_low + 1;\n  }\n\n  if (x_low >= width - 1) {\n    x_low = width - 1;\n    x_high = x_low;\n    x = static_cast<T>(x_low);\n  } else {\n    x_high = x_low + 1;\n  }\n\n  const T ly = y - y_low;\n  const T lx = x - x_low;\n  const T hy = 1.f - ly;\n  const T hx = 1.f - lx;\n\n  diff11 = bin_diff_avg * hy * hx;\n  diff21 = bin_diff_avg * hy * lx;\n  diff12 = bin_diff_avg * ly * hx;\n  diff22 = bin_diff_avg * ly * lx;\n  return true;\n}\n\ntemplate<typename T>\n__global__ void RoiAlignForward(const int64_t nthreads, const T* in_dptr, const T* rois_dptr,\n                                const T spatial_scale, const int32_t sampling_ratio,\n                                const int64_t channel_num, 
const int64_t height,\n                                const int64_t width, const int64_t pooled_height,\n                                const int64_t pooled_width, const bool aligned, T* out_dptr) {\n  const int64_t pooled_area = pooled_height * pooled_width;\n  const int64_t channel_pooled_area = channel_num * pooled_height * pooled_width;\n  CUDA_1D_KERNEL_LOOP(index, nthreads) {\n    const int64_t h = (index / pooled_width) % pooled_height;\n    const int64_t w = index % pooled_width;\n    const int64_t c = (index / pooled_area) % channel_num;\n    const int64_t r = index / channel_pooled_area;\n    const T* offset_rois_dptr = rois_dptr + r * 5;\n    const int64_t n = static_cast<int64_t>(offset_rois_dptr[0]);\n    const T align_offset = aligned ? static_cast<T>(0.5) : static_cast<T>(0.f);\n    const T roi_start_w = offset_rois_dptr[1] * spatial_scale - align_offset;\n    const T roi_start_h = offset_rois_dptr[2] * spatial_scale - align_offset;\n    const T roi_end_w = offset_rois_dptr[3] * spatial_scale - align_offset;\n    const T roi_end_h = offset_rois_dptr[4] * spatial_scale - align_offset;\n    T roi_height = roi_end_h - roi_start_h;\n    T roi_width = roi_end_w - roi_start_w;\n    // aligned == false is for compatibility. the argument \"aligned\" doesn't have the semantic of\n    // determining minimum roi size\n    if (aligned == false) {\n      roi_height = max(roi_height, static_cast<T>(1.0));\n      roi_width = max(roi_width, static_cast<T>(1.0));\n    }\n    const T bin_height = static_cast<T>(roi_height) / static_cast<T>(pooled_height);\n    const T bin_width = static_cast<T>(roi_width) / static_cast<T>(pooled_width);\n    const int32_t bin_grid_height =\n        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height);\n    const int32_t bin_grid_width =\n        (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_width / pooled_width);\n    const T count = max(bin_grid_height * bin_grid_width, 1);\n    const T* channel_dptr = in_dptr + (n * channel_num + c) * height * width;\n    T out_val = 0.0;\n    FOR_RANGE(int64_t, grid_i, 0, bin_grid_height) {\n      // + .5f for center position\n      T y = roi_start_h + h * bin_height\n            + static_cast<T>(grid_i + 0.5f) * bin_height / static_cast<T>(bin_grid_height);\n      FOR_RANGE(int64_t, grid_j, 0, bin_grid_width) {\n        T x = roi_start_w + w * bin_width\n              + static_cast<T>(grid_j + 0.5f) * bin_width / static_cast<T>(bin_grid_width);\n        out_val += BilinearInterpolate(channel_dptr, height, width, y, x);\n      }\n    }\n    out_dptr[index] = out_val / count;\n  }\n}\n\ntemplate<typename T>\n__global__ void RoiAlignBackward(const int64_t nthreads, const T* out_diff_dptr, const T* rois_dptr,\n                                 const T spatial_scale, const int32_t sampling_ratio,\n                                 const int64_t channel_num, const int64_t height,\n                                 const int64_t width, const int64_t pooled_height,\n                                 const int64_t pooled_width, const bool aligned, T* in_diff_dptr) {\n  const int64_t pooled_area = pooled_height * pooled_width;\n  const int64_t channel_pooled_area = channel_num * pooled_height * pooled_width;\n  CUDA_1D_KERNEL_LOOP(index, nthreads) {\n    const int64_t h = (index / pooled_width) % pooled_height;\n    const int64_t w = index % pooled_width;\n    const int64_t c = (index / pooled_area) % channel_num;\n    const int64_t r = index / channel_pooled_area;\n    const T* offset_rois_dptr = rois_dptr + r * 5;\n    const int64_t n = static_cast<int64_t>(offset_rois_dptr[0]);\n    const T align_offset = aligned ? static_cast<T>(0.5) : static_cast<T>(0.f);\n    const T roi_start_w = offset_rois_dptr[1] * spatial_scale - align_offset;\n    const T roi_start_h = offset_rois_dptr[2] * spatial_scale - align_offset;\n    const T roi_end_w = offset_rois_dptr[3] * spatial_scale - align_offset;\n    const T roi_end_h = offset_rois_dptr[4] * spatial_scale - align_offset;\n    T roi_width = roi_end_w - roi_start_w;\n    T roi_height = roi_end_h - roi_start_h;\n    // aligned == false is for compatibility. the argument \"aligned\" doesn't have the semantic of\n    // determining minimum roi size\n    if (aligned == false) {\n      roi_height = max(roi_height, static_cast<T>(1.0));\n      roi_width = max(roi_width, static_cast<T>(1.0));\n    }\n    const T bin_height = static_cast<T>(roi_height) / static_cast<T>(pooled_height);\n    const T bin_width = static_cast<T>(roi_width) / static_cast<T>(pooled_width);\n    const int32_t bin_grid_height =\n        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height);\n    const int32_t bin_grid_width =\n        (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_width / pooled_width);\n\n    const T count = max(bin_grid_height * bin_grid_width, 1);\n    const T bin_diff_avg = out_diff_dptr[index] / count;\n    T* in_diff_channel_dptr = in_diff_dptr + (n * channel_num + c) * height * width;\n    FOR_RANGE(int64_t, grid_i, 0, bin_grid_height) {\n      // + .5f for center position\n      T y = roi_start_h + h * bin_height\n            + static_cast<T>(grid_i + 0.5f) * bin_height / static_cast<T>(bin_grid_height);\n      FOR_RANGE(int64_t, grid_j, 0, bin_grid_width) {\n        T x = roi_start_w + w * bin_width\n              + static_cast<T>(grid_j + 0.5f) * bin_width / static_cast<T>(bin_grid_width);\n        T diff11 = 0;\n        T diff21 = 0;\n        T diff12 = 0;\n        T diff22 = 0;\n        int32_t x_low = 0;\n        int32_t x_high = 0;\n        int32_t y_low = 0;\n        int32_t y_high = 0;\n        bool has_diff = BilinearInterpolateDiff(bin_diff_avg, height, width, y, x, diff11, diff21,\n                                                diff12, diff22, x_low, x_high, y_low, y_high);\n        if (has_diff) {\n          const int64_t q11 = y_low * width + x_low;\n          const int64_t q21 = y_low * width + x_high;\n          const int64_t q12 = y_high * width + x_low;\n          const int64_t q22 = y_high * width + x_high;\n          atomicAdd(in_diff_channel_dptr + q11, diff11);\n          atomicAdd(in_diff_channel_dptr + q21, diff21);\n          atomicAdd(in_diff_channel_dptr + q12, diff12);\n          atomicAdd(in_diff_channel_dptr + q22, diff22);\n        }\n      }\n    }\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass RoIAlignKernel final : public user_op::OpKernel {\n public:\n  RoIAlignKernel() = default;\n  ~RoIAlignKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_blob = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* rois_blob = ctx->Tensor4ArgNameAndIndex(\"rois\", 0);\n    if (rois_blob->shape_view().elem_cnt() == 0) { return; }\n    user_op::Tensor* y_blob = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const int32_t pooled_h = ctx->Attr<int32_t>(\"pooled_h\");\n    const int32_t pooled_w = ctx->Attr<int32_t>(\"pooled_w\");\n    const float spatial_scale = ctx->Attr<float>(\"spatial_scale\");\n    const int32_t sampling_ratio = ctx->Attr<int32_t>(\"sampling_ratio\");\n    const bool aligned = ctx->Attr<bool>(\"aligned\");\n\n    const int64_t elem_cnt = y_blob->shape_view().elem_cnt();\n    RoiAlignForward<T><<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                         ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n        elem_cnt, x_blob->dptr<T>(), rois_blob->dptr<T>(), spatial_scale, sampling_ratio,\n        x_blob->shape_view().At(1), x_blob->shape_view().At(2), x_blob->shape_view().At(3),\n        pooled_h, pooled_w, aligned, y_blob->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass RoIAlignGradKernel final : public user_op::OpKernel {\n public:\n  RoIAlignGradKernel() = default;\n  ~RoIAlignGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* dx_blob = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    if (dx_blob == nullptr) { return; }\n    Memset<DeviceType::kCUDA>(ctx->stream(), dx_blob->mut_dptr<T>(), 0,\n                    
          dx_blob->shape_view().elem_cnt() * sizeof(T));\n    const user_op::Tensor* dy_blob = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* rois_blob = ctx->Tensor4ArgNameAndIndex(\"rois\", 0);\n    const int32_t pooled_h = ctx->Attr<int32_t>(\"pooled_h\");\n    const int32_t pooled_w = ctx->Attr<int32_t>(\"pooled_w\");\n    const float spatial_scale = ctx->Attr<float>(\"spatial_scale\");\n    const int32_t sampling_ratio = ctx->Attr<int32_t>(\"sampling_ratio\");\n    const bool aligned = ctx->Attr<bool>(\"aligned\");\n\n    const int64_t elem_cnt = dy_blob->shape_view().elem_cnt();\n    if (elem_cnt > 0) {\n      RoiAlignBackward<T><<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                            ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          elem_cnt, dy_blob->dptr<T>(), rois_blob->dptr<T>(), spatial_scale, sampling_ratio,\n          dx_blob->shape_view().At(1), dx_blob->shape_view().At(2), dx_blob->shape_view().At(3),\n          pooled_h, pooled_w, aligned, dx_blob->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"roi_align\")\n    .SetCreateFn<RoIAlignKernel<float>>()\n    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA);\n\nREGISTER_USER_KERNEL(\"roi_align_grad\")\n    .SetCreateFn<RoIAlignGradKernel<float>>()\n    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/roll_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/roll_kernel_utils.h\"\n\n#include <algorithm>\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass CpuRollKernel final : public user_op::OpKernel {\n public:\n  CpuRollKernel() = default;\n  ~CpuRollKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const std::vector<int32_t>& shifts = ctx->Attr<std::vector<int32_t>>(\"shifts\");\n    const std::vector<int32_t>& dims = ctx->Attr<std::vector<int32_t>>(\"dims\");\n\n    SHAPE new_shape{};\n    SHIFTS new_shifts{};\n    int32_t num_axes = 0;\n    computeParams(in->shape_view(), shifts, dims, new_shifts.val, new_shape.val, &num_axes);\n\n    const T* in_ptr = in->dptr<T>();\n    T* out_ptr = out->mut_dptr<T>();\n    const int32_t size = out->shape_view().elem_cnt();\n\n    STRIDE stride{};\n    initStride(stride, new_shape, num_axes);\n\n    transformShifts(new_shifts.val, new_shape.val, num_axes);\n\n    for (int32_t i = 0; i < size; ++i) {\n      int shifted_i = switchGetShiftedIndex(i, new_shifts.val, new_shape.val, stride.val, num_axes);\n      out_ptr[i] = in_ptr[shifted_i];\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_ROLL_KERNEL(dtype)                                                 \\\n  REGISTER_USER_KERNEL(\"roll\").SetCreateFn<CpuRollKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCPU)                                \\\n      && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value))\n\nREGISTER_ROLL_KERNEL(float);\nREGISTER_ROLL_KERNEL(double);\nREGISTER_ROLL_KERNEL(bool);\nREGISTER_ROLL_KERNEL(uint8_t);\nREGISTER_ROLL_KERNEL(int8_t);\nREGISTER_ROLL_KERNEL(int32_t);\nREGISTER_ROLL_KERNEL(int64_t);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/roll_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/user/kernels/roll_kernel_utils.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, int Dim>\n__global__ void RollCudaKernel(const T* in_ptr, const SHIFTS shifts, const SHAPE shape,\n                               const STRIDE stride, const int64_t elements, T* out_ptr) {\n  int32_t global_index = (blockDim.x * blockIdx.x) + threadIdx.x;\n  int32_t step = gridDim.x * blockDim.x;\n  while (global_index < elements) {\n    int32_t shifted_global_index =\n        getShiftedIndex<Dim>(global_index, shifts.val, shape.val, stride.val);\n    out_ptr[global_index] = in_ptr[shifted_global_index];\n    global_index += step;\n  }\n}\n\ntemplate<typename T, int Dim>\nstruct GpuRollFunctor final {\n  void operator()(ep::Stream* stream, const T* in_ptr, const SHIFTS shifts, const SHAPE shape,\n                  const STRIDE stride, const int64_t elements, T* out_ptr) {\n    RollCudaKernel<T, Dim><<<BlocksNum4ThreadsNum(elements), kCudaThreadsNumPerBlock, 0,\n                             stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        in_ptr, shifts, shape, stride, elements, out_ptr);\n  }\n};\n\ntemplate<int Dim>\nstruct GpuRollFunctor<float16, Dim> final {\n  void operator()(ep::Stream* stream, const float16* in_ptr, const SHIFTS shifts, const SHAPE shape,\n                  const STRIDE stride, const int64_t elements, float16* out_ptr) {\n    RollCudaKernel<half, Dim><<<BlocksNum4ThreadsNum(elements), kCudaThreadsNumPerBlock, 0,\n                                stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        reinterpret_cast<const half*>(in_ptr), shifts, shape, stride, elements,\n        reinterpret_cast<half*>(out_ptr));\n  }\n};\n\ntemplate<typename T>\n__global__ void RollFlattenCudaKernel(const T* in_ptr, const int64_t start,\n                                      const int64_t elem_count_minus_start, const int64_t elements,\n                                      T* out_ptr) {\n  int64_t global_index = (blockDim.x * blockIdx.x) + threadIdx.x;\n  int32_t step = gridDim.x * blockDim.x;\n\n  while (global_index < elements) {\n    int64_t source_idx = 0;\n    if (global_index >= elem_count_minus_start) {\n      source_idx = global_index - elem_count_minus_start;\n    } else {\n      source_idx = global_index + start;\n    }\n    out_ptr[global_index] = in_ptr[source_idx];\n\n    global_index += step;\n  }\n}\n\ntemplate<typename T>\nstruct GpuRollFlattenFunctor final {\n  void operator()(ep::Stream* stream, const T* in_ptr, const int64_t start,\n                  const int64_t elem_count_minus_start, const int64_t elements, T* out_ptr) {\n    RollFlattenCudaKernel<T><<<BlocksNum4ThreadsNum(elements), kCudaThreadsNumPerBlock, 0,\n                               
stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        in_ptr, start, elem_count_minus_start, elements, out_ptr);\n  }\n};\n\ntemplate<>\nvoid GpuRollFlattenFunctor<float16>::operator()(ep::Stream* stream, const float16* in_ptr,\n                                                const int64_t start,\n                                                const int64_t elem_count_minus_start,\n                                                const int64_t elements, float16* out_ptr) {\n  RollFlattenCudaKernel<half><<<BlocksNum4ThreadsNum(elements), kCudaThreadsNumPerBlock, 0,\n                                stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      reinterpret_cast<const half*>(in_ptr), start, elem_count_minus_start, elements,\n      reinterpret_cast<half*>(out_ptr));\n}\n\ntemplate<typename T>\n__global__ void Roll1DimCudaKernel(const T* in_ptr, const int32_t stride_x_size,\n                                   const int32_t stride, const int32_t size_minus_start,\n                                   const int32_t size_minus_start_x_stride,\n                                   const int32_t start_x_stride, const int64_t elements,\n                                   T* out_ptr) {\n  int32_t global_index = (blockDim.x * blockIdx.x) + threadIdx.x;\n  int32_t step = gridDim.x * blockDim.x;\n\n  while (global_index < elements) {\n    // roll dim idx is the index of linear_index along the rolling dimension.\n    int32_t roll_dim_idx = global_index % stride_x_size / stride;\n    // index into the source data to find appropriate value.\n    int32_t source_idx = 0;\n    if (roll_dim_idx >= size_minus_start) {\n      source_idx = global_index - size_minus_start_x_stride;\n    } else {\n      source_idx = global_index + start_x_stride;\n    }\n    out_ptr[global_index] = in_ptr[source_idx];\n\n    global_index += step;\n  }\n}\n\ntemplate<typename T>\nstruct GpuRoll1DimFunctor final {\n  void operator()(ep::Stream* stream, const T* in_ptr, const int32_t stride_x_size,\n                  const int32_t stride, const int32_t size_minus_start,\n                  const int32_t size_minus_start_x_stride, const int32_t start_x_stride,\n                  const int64_t elements, T* out_ptr) {\n    Roll1DimCudaKernel<T><<<BlocksNum4ThreadsNum(elements), kCudaThreadsNumPerBlock, 0,\n                            stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        in_ptr, stride_x_size, stride, size_minus_start, size_minus_start_x_stride, start_x_stride,\n        elements, out_ptr);\n  }\n};\n\ntemplate<>\nvoid GpuRoll1DimFunctor<float16>::operator()(ep::Stream* stream, const float16* in_ptr,\n                                             const int32_t stride_x_size, const int32_t stride,\n                                             const int32_t size_minus_start,\n                                             const int32_t size_minus_start_x_stride,\n                                             const int32_t start_x_stride, const int64_t elements,\n                                             float16* out_ptr) {\n  Roll1DimCudaKernel<half><<<BlocksNum4ThreadsNum(elements), kCudaThreadsNumPerBlock, 0,\n                             stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      reinterpret_cast<const half*>(in_ptr), stride_x_size, stride, size_minus_start,\n      size_minus_start_x_stride, start_x_stride, elements, reinterpret_cast<half*>(out_ptr));\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuRollKernel final : public user_op::OpKernel {\n public:\n  GpuRollKernel() = default;\n  
~GpuRollKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const std::vector<int32_t>& shifts = ctx->Attr<std::vector<int32_t>>(\"shifts\");\n    const std::vector<int32_t>& dims = ctx->Attr<std::vector<int32_t>>(\"dims\");\n\n    const T* in_ptr = in->dptr<T>();\n    T* out_ptr = out->mut_dptr<T>();\n    const int64_t elem_count = out->shape_view().elem_cnt();\n\n    if (dims[0] == -1) {\n      // NOTE(Liang Depeng): Borrow the implementation of pytorch and simplify to 1d array case.\n      int64_t start = (elem_count - shifts[0]) % elem_count;\n      if (start < 0) start = start + elem_count;\n      const int64_t elem_count_minus_start = elem_count - start;\n      GpuRollFlattenFunctor<T>()(ctx->stream(), in_ptr, start, elem_count_minus_start, elem_count,\n                                 out_ptr);\n    } else {\n      SHAPE new_shape{};\n      SHIFTS new_shifts{};\n      int32_t num_axes = 0;\n      computeParams(in->shape_view(), shifts, dims, new_shifts.val, new_shape.val, &num_axes);\n\n      STRIDE stride{};\n      initStride(stride, new_shape, num_axes);\n\n      if (dims.size() == 1) {\n        // NOTE(Liang Depeng): Borrow the implementation of pytorch\n        const int32_t size = new_shape.val[dims[0]];\n        int32_t start = (size - new_shifts.val[dims[0]]) % size;\n        // Behavior of % is different in C++ vs Python for negative numbers. This\n        // corrects the difference.\n        if (start < 0) start = start + size;\n\n        const int32_t stride_x_size = stride.val[dims[0]] * size;\n        const int32_t size_minus_start = size - start;\n        const int32_t size_minus_start_x_stride = size_minus_start * stride.val[dims[0]];\n        const int32_t start_x_stride = start * stride.val[dims[0]];\n\n        GpuRoll1DimFunctor<T>()(ctx->stream(), in_ptr, stride_x_size, stride.val[dims[0]],\n                                size_minus_start, size_minus_start_x_stride, start_x_stride,\n                                elem_count, out_ptr);\n\n      } else {\n        transformShifts(new_shifts.val, new_shape.val, num_axes);\n        switch (num_axes) {\n          case 1:\n            GpuRollFunctor<T, 1>()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count,\n                                   out_ptr);\n            break;\n          case 2:\n            GpuRollFunctor<T, 2>()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count,\n                                   out_ptr);\n            break;\n          case 3:\n            GpuRollFunctor<T, 3>()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count,\n                                   out_ptr);\n            break;\n          case 4:\n            GpuRollFunctor<T, 4>()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count,\n                                   out_ptr);\n            break;\n          case 5:\n            GpuRollFunctor<T, 5>()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count,\n                                   out_ptr);\n            break;\n          case 6:\n            GpuRollFunctor<T, 6>()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count,\n                                   out_ptr);\n            break;\n          case 7:\n            GpuRollFunctor<T, 7>()(ctx->stream(), in_ptr, 
new_shifts, new_shape, stride, elem_count,\n                                   out_ptr);\n            break;\n          case 8:\n            GpuRollFunctor<T, 8>()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count,\n                                   out_ptr);\n            break;\n          case 9:\n            GpuRollFunctor<T, 9>()(ctx->stream(), in_ptr, new_shifts, new_shape, stride, elem_count,\n                                   out_ptr);\n            break;\n          case 10:\n            GpuRollFunctor<T, 10>()(ctx->stream(), in_ptr, new_shifts, new_shape, stride,\n                                    elem_count, out_ptr);\n            break;\n          case 11:\n            GpuRollFunctor<T, 11>()(ctx->stream(), in_ptr, new_shifts, new_shape, stride,\n                                    elem_count, out_ptr);\n            break;\n          case 12:\n            GpuRollFunctor<T, 12>()(ctx->stream(), in_ptr, new_shifts, new_shape, stride,\n                                    elem_count, out_ptr);\n            break;\n          case 13:\n            GpuRollFunctor<T, 13>()(ctx->stream(), in_ptr, new_shifts, new_shape, stride,\n                                    elem_count, out_ptr);\n            break;\n          case 14:\n            GpuRollFunctor<T, 14>()(ctx->stream(), in_ptr, new_shifts, new_shape, stride,\n                                    elem_count, out_ptr);\n            break;\n          case 15:\n            GpuRollFunctor<T, 15>()(ctx->stream(), in_ptr, new_shifts, new_shape, stride,\n                                    elem_count, out_ptr);\n            break;\n          case 16:\n            GpuRollFunctor<T, 16>()(ctx->stream(), in_ptr, new_shifts, new_shape, stride,\n                                    elem_count, out_ptr);\n            break;\n          default: break;\n        }\n      }\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_ROLL_KERNEL(dtype)                                                 \\\n  REGISTER_USER_KERNEL(\"roll\").SetCreateFn<GpuRollKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCUDA)                               \\\n      && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value))\n\nREGISTER_ROLL_KERNEL(float);\nREGISTER_ROLL_KERNEL(double);\nREGISTER_ROLL_KERNEL(float16);\nREGISTER_ROLL_KERNEL(bool);\nREGISTER_ROLL_KERNEL(uint8_t);\nREGISTER_ROLL_KERNEL(int8_t);\nREGISTER_ROLL_KERNEL(int32_t);\nREGISTER_ROLL_KERNEL(int64_t);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/roll_kernel_utils.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_ROLL_KERNEL_UTILS_H_\n#define ONEFLOW_ROLL_KERNEL_UTILS_H_\n\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconst int32_t kMaxDims = 16;\n\nstruct SHIFTS {\n  int32_t val[kMaxDims];\n};\n\nstruct SHAPE {\n  int32_t val[kMaxDims];\n};\n\nstruct STRIDE {\n  STRIDE() {\n    for (int i = 0; i < kMaxDims; ++i) { val[i] = 1; }\n  }\n  int32_t val[kMaxDims];\n};\n\ntemplate<int Dim>\nOF_DEVICE_FUNC int32_t getShiftedIndex(const int32_t global_index, const int32_t* shifts,\n                                       const int32_t* shape, const int32_t* stride) {\n  int32_t remaining = global_index;\n  int32_t shifted_global_index = 0;\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n  for (int32_t i = 0; i < Dim; ++i) {\n    const int32_t idx = remaining / stride[i];\n    // NOTE(Liang Depeng): Compute the shifted index of each axis.\n    int32_t shifted_idx = (idx - shifts[i]);\n    // NOTE(Liang Depeng): This correct the results.\n    if (shifted_idx < 0) shifted_idx = shifted_idx + shape[i];\n    if (shifted_idx >= shape[i]) shifted_idx = shifted_idx - shape[i];\n\n    shifted_global_index += shifted_idx * stride[i];\n    remaining = remaining - idx * stride[i];\n  }\n  return shifted_global_index;\n}\n\nOF_DEVICE_FUNC int32_t switchGetShiftedIndex(const int32_t global_index, const int32_t* shifts,\n                                             const int32_t* shape, const int32_t* stride, int n) {\n  switch (n) {\n    case 1: return getShiftedIndex<1>(global_index, shifts, shape, stride);\n    case 2: return getShiftedIndex<2>(global_index, shifts, shape, stride);\n    case 3: return getShiftedIndex<3>(global_index, shifts, shape, stride);\n    case 4: return getShiftedIndex<4>(global_index, shifts, shape, stride);\n    case 5: return getShiftedIndex<5>(global_index, shifts, shape, stride);\n    case 6: return getShiftedIndex<6>(global_index, shifts, shape, stride);\n    case 7: return getShiftedIndex<7>(global_index, shifts, shape, stride);\n    case 8: return getShiftedIndex<8>(global_index, shifts, shape, stride);\n    case 9: return getShiftedIndex<9>(global_index, shifts, shape, stride);\n    case 10: return getShiftedIndex<10>(global_index, shifts, shape, stride);\n    case 11: return getShiftedIndex<11>(global_index, shifts, shape, stride);\n    case 12: return getShiftedIndex<12>(global_index, shifts, shape, stride);\n    case 13: return getShiftedIndex<13>(global_index, shifts, shape, stride);\n    case 14: return getShiftedIndex<14>(global_index, shifts, shape, stride);\n    case 15: return getShiftedIndex<15>(global_index, shifts, shape, stride);\n    case 16: return getShiftedIndex<16>(global_index, shifts, shape, stride);\n  }\n  return 0;\n}\n\nstatic void initStride(STRIDE& stride, const SHAPE& dim_vec, const int32_t dims) {\n  for (int i = dims - 2; i >= 0; --i) { stride.val[i] = dim_vec.val[i + 1] * stride.val[i + 
1]; }\n}\n\nstatic void transformShifts(int32_t* shifts, int32_t* shape, int n) {\n  for (int i = 0; i < n; ++i) { shifts[i] = shifts[i] % shape[i]; }  // NOLINT\n}\n\nstatic void computeParams(const ShapeView& in_shape, const std::vector<int32_t>& shifts,\n                          const std::vector<int32_t>& dims, int32_t* new_shifts, int32_t* new_shape,\n                          int32_t* new_num_axes) {\n  if (dims[0] == -1) {\n    // NOTE(Liang Depeng):\n    // If user did not set the dims parameter,\n    // the input tensor will be flattened before rolling,\n    // which means we can think of the input tensor as an 1 dimensional array.\n    new_shifts[0] = shifts[0];\n    *new_num_axes = 1;\n    new_shape[0] = in_shape.elem_cnt();\n  } else {\n    std::map<int32_t, int32_t> dim_to_shift;\n    for (int i = 0; i < shifts.size(); ++i) { dim_to_shift.emplace(dims[i], shifts[i]); }\n    // NOTE(Liang Depeng):\n    // Compute the shift parameter for each axis.\n    // For those axis which user did not specified shift value, will be set to 0\n    for (int i = 0; i < in_shape.NumAxes(); ++i) {\n      if (dim_to_shift.count(i) > 0) {\n        new_shifts[i] = dim_to_shift.at(i);\n      } else {\n        new_shifts[i] = 0;\n      }\n      new_shape[i] = in_shape.At(i);\n    }\n    *new_num_axes = in_shape.NumAxes();\n  }\n}\n\n}  // namespace\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_ROLL_KERNEL_UTILS_H_\n"
  },
  {
    "path": "oneflow/user/kernels/rrelu_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/user/kernels/distributions/common.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename V>\nstatic T uniform_real(V val, T from, T to) {\n  constexpr auto MASK =\n      static_cast<V>((static_cast<uint64_t>(1) << std::numeric_limits<T>::digits) - 1);\n  constexpr auto DIVISOR =\n      static_cast<T>(1) / (static_cast<uint64_t>(1) << std::numeric_limits<T>::digits);\n  T x = (val & MASK) * DIVISOR;\n  return (x * (to - from) + from);\n}\n\nstatic uint64_t make64BitsFrom32Bits(uint32_t hi, uint32_t lo) {\n  return (static_cast<uint64_t>(hi) << 32) | lo;\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass CpuRReluKernel final : public user_op::OpKernel {\n public:\n  CpuRReluKernel() = default;\n  ~CpuRReluKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCPU));\n    generator->set_current_seed(CHECK_JUST(\n        GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>(\"seed\"), {\"output\", 0})));\n    return std::make_shared<DistributionKernelState>(generator);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const int64_t size = in->shape_view().elem_cnt();\n    if (size == 0) return;\n\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    user_op::Tensor* noise_data = ctx->Tensor4ArgNameAndIndex(\"noise_data\", 0);\n    const T& lower = ctx->Attr<float>(\"lower\");\n    const T& upper = ctx->Attr<float>(\"upper\");\n\n    T* out_ptr = out->mut_dptr<T>();\n    T* noise_ptr = noise_data->mut_dptr<T>();\n    const T* in_ptr = in->dptr<T>();\n\n    auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);\n    CHECK_NOTNULL(distribution_state);\n    const auto& generator = distribution_state->generator();\n    CHECK_NOTNULL(generator);\n    auto cpu_gen = CHECK_JUST(generator->Get<ep::CPUGenerator>());\n    std::lock_guard<std::mutex> lock(cpu_gen->mutex_);\n    ep::pytorch_mt19937_engine& engine = cpu_gen->torch_engine();\n\n    FOR_RANGE(int64_t, i, 0, size) {\n      if (*(in_ptr + i) >= 0) {\n        noise_ptr[i] = 1;\n        out_ptr[i] = in_ptr[i];\n      } else {\n        uint32_t random1 = engine();\n        uint32_t random2 = engine();\n        uint64_t rand_unit = make64BitsFrom32Bits(random1, random2);\n        T uniform_sample = uniform_real(rand_unit, lower, upper);\n        noise_ptr[i] = uniform_sample;\n        out_ptr[i] = in_ptr[i] * uniform_sample;\n      }\n    }\n  }\n  
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_RRelu_KERNEL(dtype)                                              \\\n  REGISTER_USER_KERNEL(\"rrelu\").SetCreateFn<CpuRReluKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCPU)                                  \\\n      && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value))\n\nREGISTER_CPU_RRelu_KERNEL(float);\nREGISTER_CPU_RRelu_KERNEL(double);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/rrelu_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/user/kernels/distributions/normal_distribution.h\"\n#include \"oneflow/user/kernels/distributions/distribution_template_util.cuh\"\n#include \"oneflow/user/kernels/distributions/common.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename ComputeType>\nstruct UniformTransformFunctor {\n  UniformTransformFunctor(ComputeType range, ComputeType lower) : range(range), lower(lower) {}\n  __device__ T operator()(ComputeType random_val) const {\n    return static_cast<T>(random_val * range + lower);\n  }\n  ComputeType range;\n  ComputeType lower;\n};\n\ntemplate<typename T, typename ComputeType, int unroll_factor, typename Distribution,\n         typename Transform>\nOF_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound)\n__global__\n    void RReluKernel(int64_t numel, uint64_t seed, uint64_t offset, const T* in_ptr, T* out_ptr,\n                     T* noise_data_ptr, Distribution dist_func, Transform transform_func) {\n  int idx = blockIdx.x * blockDim.x + threadIdx.x;\n  curandStatePhilox4_32_10_t state;\n  curand_init(seed, idx, offset, &state);\n\n  int rounded_size = ((numel - 1) / (blockDim.x * gridDim.x * unroll_factor) + 1) * blockDim.x\n                     * gridDim.x * unroll_factor;\n  for (int32_t linear_index = idx; linear_index < rounded_size;\n       linear_index += blockDim.x * gridDim.x * unroll_factor) {\n    auto rand = dist_func(&state);\n#pragma unroll\n    for (int ii = 0; ii < unroll_factor; ii++) {\n      int li = linear_index + blockDim.x * gridDim.x * ii;\n      if (li < numel) {\n        T r = transform_func(static_cast<ComputeType>((&rand.x)[ii]));\n        if (in_ptr[li] <= static_cast<T>(0)) {\n          out_ptr[li] = in_ptr[li] * r;\n          noise_data_ptr[li] = r;\n        } else {\n          out_ptr[li] = in_ptr[li];\n          noise_data_ptr[li] = static_cast<T>(1);\n        }\n      }\n    }\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass CudaRReluKernel final : public user_op::OpKernel {\n public:\n  CudaRReluKernel() = default;\n  ~CudaRReluKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCUDA));\n    generator->set_current_seed(CHECK_JUST(\n        GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>(\"seed\"), {\"output\", 0})));\n    return std::make_shared<DistributionKernelState>(generator);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n   
 const int64_t size = in->shape_view().elem_cnt();\n    if (size == 0) return;\n\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    user_op::Tensor* noise_data = ctx->Tensor4ArgNameAndIndex(\"noise_data\", 0);\n    const T lower = static_cast<T>(ctx->Attr<float>(\"lower\"));\n    const T upper = static_cast<T>(ctx->Attr<float>(\"upper\"));\n\n    T* out_ptr = out->mut_dptr<T>();\n    T* noise_ptr = noise_data->mut_dptr<T>();\n    const T* in_ptr = in->dptr<T>();\n\n    auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);\n    CHECK_NOTNULL(distribution_state);\n    const auto& generator = distribution_state->generator();\n    CHECK_NOTNULL(generator);\n    ep::CudaStream* cuda_stream = ctx->stream()->As<ep::CudaStream>();\n    const auto device_index = ctx->stream()->device()->device_index();\n    std::shared_ptr<ep::CUDAGenerator> cuda_gen =\n        CHECK_JUST(generator->Get<ep::CUDAGenerator>(device_index));\n    // CalcExecutionPolicy returns a (counter_offset, grid, block) tuple for the Philox launch.\n    auto execution_policy = cuda_gen->CalcExecutionPolicy(size, cuda_stream);\n    auto counter_offset = std::get<0>(execution_policy);\n    uint64_t seed = cuda_gen->current_seed();\n    uint64_t offset = cuda_gen->get_philox_offset(counter_offset);\n\n    auto grid = std::get<1>(execution_policy);\n    auto block = std::get<2>(execution_policy);\n\n    using ComputeType = typename distribution::DefaultComputeType<T>::type;\n    UniformTransformFunctor<T, ComputeType> transform_functor(\n        static_cast<ComputeType>(upper - lower), static_cast<ComputeType>(lower));\n    if (std::is_same<T, double>::value) {\n      DistributionFunctor<DistributionOp::kUniform2Double> dist_functor;\n      RReluKernel<T, ComputeType, 2, decltype(dist_functor), decltype(transform_functor)>\n          <<<grid, block, 0, cuda_stream->cuda_stream()>>>(\n              size, seed, offset, in_ptr, out_ptr, noise_ptr, dist_functor, transform_functor);\n    } else {\n      // float\n      DistributionFunctor<DistributionOp::kUniform4> dist_functor;\n      RReluKernel<T, ComputeType, 4, decltype(dist_functor), decltype(transform_functor)>\n          <<<grid, block, 0, cuda_stream->cuda_stream()>>>(\n              size, seed, offset, in_ptr, out_ptr, noise_ptr, dist_functor, transform_functor);\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n#define REGISTER_CUDA_RRELU_KERNEL(dtype)                                              \\\n  REGISTER_USER_KERNEL(\"rrelu\").SetCreateFn<CudaRReluKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCUDA)                                  \\\n      && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CUDA_RRELU_KERNEL(float)\nREGISTER_CUDA_RRELU_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/same_padding_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/ep/include/primitive/copy_nd.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Fill> NewFillPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->device_type(), data_type);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::CopyNd> NewCopyNdPrimitive(Context* ctx) {\n  const auto& in_arg_pair = ctx->inputs().front();\n  const int64_t ndims =\n      ctx->TensorDesc4ArgNameAndIndex(in_arg_pair.first, in_arg_pair.second)->shape().NumAxes();\n  return ep::primitive::NewPrimitive<ep::primitive::CopyNdFactory>(ctx->device_type(), ndims);\n}\n\nauto FillPrimitiveExists() {\n  return hob::make_custom(\"FillPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewFillPrimitive(&ctx).operator bool();\n  });\n}\n\nauto CopyNdPrimitiveExists() {\n  return hob::make_custom(\"CopyNdPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewCopyNdPrimitive(&ctx).operator bool();\n  });\n}\n\n}  // namespace\n\nclass SamePaddingKernel final : public user_op::OpKernel {\n public:\n  SamePaddingKernel() = default;\n  ~SamePaddingKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const int64_t num_axes = x->shape_view().NumAxes();\n    const std::string& padding = ctx->Attr<std::string>(\"padding\");\n    const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n    const std::vector<int32_t> kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n    const std::vector<int32_t> strides = ctx->Attr<std::vector<int32_t>>(\"strides\");\n    const std::vector<int32_t> dilation_rate = ctx->Attr<std::vector<int32_t>>(\"dilation_rate\");\n    std::vector<int64_t> padding_before(num_axes, 0);\n    const size_t idx_offset = IdxOffset(data_format);\n    const int32_t num_spatial_dims = x->shape_view().NumAxes() - 2;\n    for (int32_t i = 0; i < num_spatial_dims; ++i) {\n      int32_t padding_small = 0;\n      int32_t padding_large = 0;\n      CHECK_JUST(CalcSamePadding(x->shape_view().At(idx_offset + i), kernel_size.at(i),  // NOLINT\n                                 dilation_rate.at(i), strides.at(i), &padding_small,     // NOLINT\n                                 &padding_large));                                       // NOLINT\n      if (padding == \"same_lower\") {\n        padding_before[idx_offset + i] = padding_large;\n      } else if (padding == \"same_upper\") {\n        
padding_before[idx_offset + i] = padding_small;\n      } else {\n        UNIMPLEMENTED();\n      }\n      CHECK_EQ(y->shape_view().At(idx_offset + i),\n               x->shape_view().At(idx_offset + i) + padding_small + padding_large);\n    }\n    CHECK_EQ(padding_before.size(), num_axes);\n    std::unique_ptr<ep::primitive::Fill> fill_primitive = NewFillPrimitive(ctx);\n    CHECK(fill_primitive);\n    fill_primitive->Launch(ctx->stream(), y->mut_dptr(), Scalar(0), y->shape_view().elem_cnt());\n    DimVector src_pos_vec(num_axes, 0);\n    DimVector dst_pos_vec(padding_before.cbegin(), padding_before.cend());\n    std::unique_ptr<ep::primitive::CopyNd> copy_nd_primitive = NewCopyNdPrimitive(ctx);\n    CHECK(copy_nd_primitive);\n    copy_nd_primitive->Launch(ctx->stream(), x->data_type(), num_axes, y->mut_dptr(),\n                              y->shape_view().ptr(), dst_pos_vec.data(), x->dptr(),\n                              x->shape_view().ptr(), src_pos_vec.data(), x->shape_view().ptr());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"same_padding\")\n    .SetCreateFn<SamePaddingKernel>()\n    .SetIsMatchedHob(FillPrimitiveExists() && CopyNdPrimitiveExists());\n\nclass SamePaddingGradKernel final : public user_op::OpKernel {\n public:\n  SamePaddingGradKernel() = default;\n  ~SamePaddingGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const int64_t num_axes = dy->shape_view().NumAxes();\n    const std::string& padding = ctx->Attr<std::string>(\"padding\");\n    const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n    const std::vector<int32_t> kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n    const std::vector<int32_t> strides = ctx->Attr<std::vector<int32_t>>(\"strides\");\n    const std::vector<int32_t> dilation_rate = ctx->Attr<std::vector<int32_t>>(\"dilation_rate\");\n    std::vector<int64_t> padding_before(num_axes, 0);\n    const size_t idx_offset = IdxOffset(data_format);\n    const int32_t num_spatial_dims = dy->shape_view().NumAxes() - 2;\n    for (int32_t i = 0; i < num_spatial_dims; ++i) {\n      int32_t padding_small = 0;\n      int32_t padding_large = 0;\n      CHECK_JUST(CalcSamePadding(dx->shape_view().At(idx_offset + i), kernel_size.at(i),  // NOLINT\n                                 dilation_rate.at(i), strides.at(i), &padding_small,      // NOLINT\n                                 &padding_large));                                        // NOLINT\n      if (padding == \"same_lower\") {\n        padding_before[idx_offset + i] = padding_large;\n      } else if (padding == \"same_upper\") {\n        padding_before[idx_offset + i] = padding_small;\n      } else {\n        UNIMPLEMENTED();\n      }\n      CHECK_EQ(dy->shape_view().At(idx_offset + i),\n               dx->shape_view().At(idx_offset + i) + padding_small + padding_large);\n    }\n    DimVector dst_pos_vec(num_axes, 0);\n    DimVector src_pos_vec(padding_before.cbegin(), padding_before.cend());\n    std::unique_ptr<ep::primitive::CopyNd> primitive = NewCopyNdPrimitive(ctx);\n    CHECK(primitive);\n    primitive->Launch(ctx->stream(), dy->data_type(), num_axes, dx->mut_dptr(),\n                      dx->shape_view().ptr(), dst_pos_vec.data(), dy->dptr(),\n                      dy->shape_view().ptr(), 
src_pos_vec.data(), dx->shape_view().ptr());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"same_padding_grad\")\n    .SetCreateFn<SamePaddingGradKernel>()\n    .SetIsMatchedHob(CopyNdPrimitiveExists());\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/scalar_bitwise_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::BroadcastElementwiseBinary> NewBinaryPrimitive(\n    Context* ctx, ep::primitive::BinaryOp op) {\n  const user_op::TensorDesc* in = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n  const user_op::TensorDesc* out = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n  const int64_t ndims = in->shape().NumAxes();\n  return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n      ctx->device_type(), op, in->data_type(), out->data_type(), ndims);\n}\n\ntemplate<ep::primitive::BinaryOp op>\nauto PrimitiveExists() {\n  return hob::make_custom(\"BroadcastElementwiseBinaryPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewBinaryPrimitive(&ctx, op).operator bool();\n                          });\n}\n\ntemplate<ep::primitive::BinaryOp op>\nclass ScalarBitwiseKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ScalarBitwiseKernel() = default;\n  ~ScalarBitwiseKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    Scalar scalar_operand = ctx->Attr<int64_t>(\"operand\");\n\n    int64_t elem_cnt = out->shape_view().elem_cnt();\n    if (elem_cnt != 0) {\n      std::unique_ptr<ep::primitive::BroadcastElementwiseBinary> primitive =\n          NewBinaryPrimitive(ctx, op);\n      CHECK(primitive);\n      primitive->Launch(ctx->stream(), in->shape_view().NumAxes(), in->shape_view().ptr(),\n                        in->dptr(), scalar_operand, out->mut_dptr());\n    } else {\n      // For 0-d Tensor\n      return;\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UNARY_BITWISE_SCALAR_ELEMWISE_USER_KERNEL(kernel_name, binary_op) \\\n  REGISTER_USER_KERNEL(kernel_name)                                                \\\n      .SetCreateFn<ScalarBitwiseKernel<binary_op>>()                               \\\n      .SetIsMatchedHob(PrimitiveExists<binary_op>());\n\nREGISTER_UNARY_BITWISE_SCALAR_ELEMWISE_USER_KERNEL(\"scalar_bitwise_and\",\n                                                   ep::primitive::BinaryOp::kBitwiseAnd);\nREGISTER_UNARY_BITWISE_SCALAR_ELEMWISE_USER_KERNEL(\"scalar_bitwise_or\",\n                                                   ep::primitive::BinaryOp::kBitwiseOr);\nREGISTER_UNARY_BITWISE_SCALAR_ELEMWISE_USER_KERNEL(\"scalar_bitwise_xor\",\n                                                   
ep::primitive::BinaryOp::kBitwiseXor);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/scalar_by_tensor_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::BroadcastElementwiseBinary> NewBroadcastElementwiseBinaryPrimitive(\n    Context* ctx, ep::primitive::BinaryOp op) {\n  const user_op::TensorDesc* x = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0);\n  const user_op::TensorDesc* y = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0);\n  const int64_t ndims = y->shape().NumAxes();\n  return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n      ctx->device_type(), op, x->data_type(), y->data_type(), ndims);\n}\n\ntemplate<ep::primitive::BinaryOp op>\nauto BroadcastElementwiseBinaryPrimitiveExists() {\n  return hob::make_custom(\"BroadcastElementwiseBinaryPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewBroadcastElementwiseBinaryPrimitive(&ctx, op).operator bool();\n                          });\n}\n\ntemplate<ep::primitive::BinaryOp op>\nclass ScalarByTensorKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ScalarByTensorKernel() = default;\n  ~ScalarByTensorKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* scalar = ctx->Tensor4ArgNameAndIndex(\"scalar\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    int64_t elem_cnt = y->shape_view().elem_cnt();\n    if (elem_cnt != 0) {\n      std::unique_ptr<ep::primitive::BroadcastElementwiseBinary> primitive =\n          NewBroadcastElementwiseBinaryPrimitive(ctx, op);\n      CHECK(primitive);\n      primitive->Launch(ctx->stream(), x->shape_view().NumAxes(), x->shape_view().ptr(), x->dptr(),\n                        scalar->shape_view().NumAxes(), scalar->shape_view().ptr(), scalar->dptr(),\n                        y->mut_dptr());\n    } else {\n      // For 0-size Tensor\n      return;\n    }\n  };\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n#define REGISTER_SCALAR_BY_TENSOR_KERNEL(op_name, binary_op)                         \\\n  REGISTER_USER_KERNEL(op_name)                                                      \\\n      .SetCreateFn<ScalarByTensorKernel<binary_op>>()                                \\\n      .SetIsMatchedHob(BroadcastElementwiseBinaryPrimitiveExists<binary_op>())       \\\n      .SetInplaceProposalFn(                                                         \\\n          [](const user_op::InferContext&,                                           \\\n             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> 
Maybe<void> { \\\n            OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"y\", 0, \"x\", 0, true));           \\\n            return Maybe<void>::Ok();                                                \\\n          });\n\n#define SCALAR_BY_TENSOR_SEQ                                                  \\\n  OF_PP_MAKE_TUPLE_SEQ(\"scalar_add_by_tensor\", ep::primitive::BinaryOp::kAdd) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"scalar_sub_by_tensor\", ep::primitive::BinaryOp::kSub) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"scalar_mul_by_tensor\", ep::primitive::BinaryOp::kMul) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"scalar_div_by_tensor\", ep::primitive::BinaryOp::kDiv)\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_SCALAR_BY_TENSOR_KERNEL, SCALAR_BY_TENSOR_SEQ)\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/scalar_logical_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::BroadcastElementwiseBinary> NewBinaryPrimitive(\n    Context* ctx, ep::primitive::BinaryOp op) {\n  const user_op::TensorDesc* in = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n  const user_op::TensorDesc* out = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n  const int64_t ndims = in->shape().NumAxes();\n  return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n      ctx->device_type(), op, in->data_type(), out->data_type(), ndims);\n}\n\ntemplate<ep::primitive::BinaryOp op>\nauto PrimitiveExists() {\n  return hob::make_custom(\"BroadcastElementwiseBinaryPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewBinaryPrimitive(&ctx, op).operator bool();\n                          });\n}\n\ntemplate<ep::primitive::BinaryOp op>\nclass ScalarLogicalKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ScalarLogicalKernel() = default;\n  ~ScalarLogicalKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    Scalar scalar_operand;\n    if (ctx->Attr<bool>(\"has_int_operand\")) {\n      scalar_operand = ctx->Attr<int64_t>(\"int_operand\");\n    } else if (ctx->Attr<bool>(\"has_float_operand\")) {\n      scalar_operand = ctx->Attr<double>(\"float_operand\");\n    } else {\n      UNIMPLEMENTED();\n    }\n\n    int64_t elem_cnt = out->shape_view().elem_cnt();\n    if (elem_cnt != 0) {\n      std::unique_ptr<ep::primitive::BroadcastElementwiseBinary> primitive =\n          NewBinaryPrimitive(ctx, op);\n      CHECK(primitive);\n      primitive->Launch(ctx->stream(), in->shape_view().NumAxes(), in->shape_view().ptr(),\n                        in->dptr(), scalar_operand, out->mut_dptr());\n    } else {\n      // For 0-d Tensor\n      return;\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL(kernel_name, binary_op) \\\n  REGISTER_USER_KERNEL(kernel_name)                                                \\\n      .SetCreateFn<ScalarLogicalKernel<binary_op>>()                               \\\n      .SetIsMatchedHob(PrimitiveExists<binary_op>());\n\nREGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL(\"scalar_logical_equal\",\n                                                   ep::primitive::BinaryOp::kEqual);\nREGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL(\"scalar_logical_not_equal\",\n      
                                             ep::primitive::BinaryOp::kNotEqual);\nREGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL(\"scalar_logical_greater\",\n                                                   ep::primitive::BinaryOp::kGreaterThan);\nREGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL(\"scalar_logical_greater_equal\",\n                                                   ep::primitive::BinaryOp::kGreaterEqual);\nREGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL(\"scalar_logical_less\",\n                                                   ep::primitive::BinaryOp::kLessThan);\nREGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL(\"scalar_logical_less_equal\",\n                                                   ep::primitive::BinaryOp::kLessEqual);\nREGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL(\"scalar_logical_or\",\n                                                   ep::primitive::BinaryOp::kLogicalOr);\nREGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL(\"scalar_logical_xor\",\n                                                   ep::primitive::BinaryOp::kLogicalXor);\nREGISTER_UNARY_LOGICAL_SCALAR_ELEMWISE_USER_KERNEL(\"scalar_logical_and\",\n                                                   ep::primitive::BinaryOp::kLogicalAnd);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/scalar_math_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::BroadcastElementwiseBinary> NewBroadcastElementwiseBinaryPrimitive(\n    Context* ctx, ep::primitive::BinaryOp op) {\n  const user_op::TensorDesc* x = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n  const user_op::TensorDesc* y = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n  const int64_t ndims = y->shape().NumAxes();\n  return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n      ctx->device_type(), op, x->data_type(), y->data_type(), ndims);\n}\n\ntemplate<ep::primitive::BinaryOp op>\nauto BroadcastElementwiseBinaryPrimitiveExists() {\n  return hob::make_custom(\"BroadcastElementwiseBinaryPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewBroadcastElementwiseBinaryPrimitive(&ctx, op).operator bool();\n                          });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::BroadcastElementwiseBinary>\nNewBroadcastElementwiseAttrBinaryPrimitive(Context* ctx, ep::primitive::BinaryOp op) {\n  const user_op::TensorDesc* x = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0);\n  const user_op::TensorDesc* dy = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n  const int64_t ndims = dy->shape().NumAxes();\n  Scalar value;\n  if (ctx->template Attr<bool>(\"has_int_operand\")) {\n    value = Scalar(ctx->template Attr<int64_t>(\"int_operand\"));\n  } else if (ctx->template Attr<bool>(\"has_float_operand\")) {\n    value = Scalar(ctx->template Attr<double>(\"float_operand\"));\n  } else {\n    UNIMPLEMENTED();\n  }\n  return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n      ctx->device_type(), op, x->data_type(), dy->data_type(), ndims, value);\n}\n\ntemplate<ep::primitive::BinaryOp op>\nauto BroadcastElementwiseAttrBinaryPrimitiveExists() {\n  return hob::make_custom(\n      \"BroadcastElementwiseBinaryAttrPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n        return NewBroadcastElementwiseAttrBinaryPrimitive(&ctx, op).operator bool();\n      });\n}\n\n}  // namespace\n\ntemplate<ep::primitive::BinaryOp op>\nclass ScalarMathKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  ScalarMathKernel() = default;\n  ~ScalarMathKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    Scalar value;\n    if (ctx->Attr<bool>(\"has_int_operand\")) {\n      value = 
Scalar(ctx->Attr<int64_t>(\"int_operand\"));\n    } else if (ctx->Attr<bool>(\"has_float_operand\")) {\n      value = Scalar(ctx->Attr<double>(\"float_operand\"));\n    } else {\n      UNIMPLEMENTED();\n    }\n    int64_t elem_cnt = out->shape_view().elem_cnt();\n    if (elem_cnt != 0) {\n      const bool is_add_sub_0 =\n          (op == ep::primitive::BinaryOp::kAdd || op == ep::primitive::BinaryOp::kSub)\n          && value.Value<double>() == 0.0;\n      const bool is_mul_div_1 =\n          (op == ep::primitive::BinaryOp::kMul || op == ep::primitive::BinaryOp::kDiv)\n          && value.Value<double>() == 1.0;\n      if ((is_add_sub_0 || is_mul_div_1) && in->dptr() == out->dptr()) { return; }\n      std::unique_ptr<ep::primitive::BroadcastElementwiseBinary> primitive =\n          NewBroadcastElementwiseBinaryPrimitive(ctx, op);\n      CHECK(primitive);\n      primitive->Launch(ctx->stream(), in->shape_view().NumAxes(), in->shape_view().ptr(),\n                        in->dptr(), value, out->mut_dptr());\n    } else {\n      // For 0-size Tensor\n      return;\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<ep::primitive::BinaryOp op>\nclass ScalarReverseMathKernel final : public user_op::OpKernel {\n public:\n  ScalarReverseMathKernel() = default;\n  ~ScalarReverseMathKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    Scalar value;\n    if (ctx->Attr<bool>(\"has_int_operand\")) {\n      value = Scalar(ctx->Attr<int64_t>(\"int_operand\"));\n    } else if (ctx->Attr<bool>(\"has_float_operand\")) {\n      value = Scalar(ctx->Attr<double>(\"float_operand\"));\n    } else {\n      UNIMPLEMENTED();\n    }\n    int64_t elem_cnt = out->shape_view().elem_cnt();\n    if (elem_cnt != 0) {\n      std::unique_ptr<ep::primitive::BroadcastElementwiseBinary> primitive =\n          NewBroadcastElementwiseBinaryPrimitive(ctx, op);\n      CHECK(primitive);\n      primitive->Launch(ctx->stream(), value, in->shape_view().NumAxes(), in->shape_view().ptr(),\n                        in->dptr(), out->mut_dptr());\n    } else {\n      // For 0-size Tensor\n      return;\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define SCALAR_MATH_SEQ                                                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"scalar_add\", ep::primitive::BinaryOp::kAdd)           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"scalar_mul\", ep::primitive::BinaryOp::kMul)           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"scalar_div\", ep::primitive::BinaryOp::kDiv)           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"scalar_floordiv\", ep::primitive::BinaryOp::kFloorDiv) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"scalar_truncdiv\", ep::primitive::BinaryOp::kTruncDiv) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"scalar_fmod\", ep::primitive::BinaryOp::kFmod)         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"scalar_pow\", ep::primitive::BinaryOp::kPow)\n\n#define REGISTER_UNARY_MATH_SCALAR_ELEMWISE_USER_KERNEL(op_name, binary_op)          \\\n  REGISTER_USER_KERNEL(op_name)                                                      \\\n      .SetCreateFn<ScalarMathKernel<binary_op>>()                                    \\\n      .SetIsMatchedHob((BroadcastElementwiseBinaryPrimitiveExists<binary_op>()))     \\\n      .SetInplaceProposalFn(                                                         \\\n          [](const 
user_op::InferContext& ctx,                                       \\\n             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> { \\\n            OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"in\", 0, true));        \\\n            return Maybe<void>::Ok();                                                \\\n          });\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_UNARY_MATH_SCALAR_ELEMWISE_USER_KERNEL, SCALAR_MATH_SEQ)\n\n#define REGISTER_UNARY_MATH_SCALAR_REVERSE_ELEMWISE_USER_KERNEL(op_name, binary_op)                \\\n  REGISTER_USER_KERNEL(op_name).SetCreateFn<ScalarReverseMathKernel<binary_op>>().SetIsMatchedHob( \\\n      (BroadcastElementwiseBinaryPrimitiveExists<binary_op>()));\n\nREGISTER_UNARY_MATH_SCALAR_REVERSE_ELEMWISE_USER_KERNEL(\"scalar_reverse_pow\",\n                                                        ep::primitive::BinaryOp::kPow)\n\ntemplate<ep::primitive::BinaryOp op>\nclass ScalarPowGradKernel final : public user_op::OpKernel {\n public:\n  ScalarPowGradKernel() = default;\n  ~ScalarPowGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    const int64_t elem_cnt = dx_tensor->shape_view().elem_cnt();\n    if (elem_cnt != 0) {\n      std::unique_ptr<ep::primitive::BroadcastElementwiseBinary> primitive =\n          NewBroadcastElementwiseAttrBinaryPrimitive(ctx, op);\n      CHECK(primitive);\n      primitive->Launch(ctx->stream(), x_tensor->shape_view().NumAxes(),\n                        x_tensor->shape_view().ptr(), x_tensor->dptr(),\n                        dy_tensor->shape_view().NumAxes(), dy_tensor->shape_view().ptr(),\n                        dy_tensor->dptr(), dx_tensor->mut_dptr());\n    } else {\n      // For 0-size Tensor\n      return;\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_BINARY_MATH_WITH_ATTR_ELEMWISE_USER_KERNEL(op_name, binary_op)                \\\n  REGISTER_USER_KERNEL(op_name).SetCreateFn<ScalarPowGradKernel<binary_op>>().SetIsMatchedHob( \\\n      (BroadcastElementwiseAttrBinaryPrimitiveExists<binary_op>()));\n\nREGISTER_BINARY_MATH_WITH_ATTR_ELEMWISE_USER_KERNEL(\"scalar_pow_grad\",\n                                                    ep::primitive::BinaryOp::kScalarBasePowerGrad);\nREGISTER_BINARY_MATH_WITH_ATTR_ELEMWISE_USER_KERNEL(\"scalar_reverse_pow_grad\",\n                                                    ep::primitive::BinaryOp::kScalarExpPowerGrad);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/scaled_dot_product_attention_grad_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <cstddef>\n#include <cstdint>\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/framework/user_op_tensor.h\"\n\n#if CUDA_VERSION >= 11070\n\n#ifdef WITH_CUTLASS\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"cutlass/arch/mma.h\"\n#include \"cutlass/gemm/warp/mma.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n#include \"oneflow/user/kernels/scaled_dot_product_attention_kernel.h\"\n// from flash_attention\n#include \"oneflow/user/kernels/scaled_dot_product_attention_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\nstatic size_t InferTmpBufferSizeForFlashAttentionGradKernel(InferContext* ctx) {\n  const auto& q_shape = ctx->InputTensorDesc(\"query\", 0).shape();\n  const int batch_size = q_shape.At(0);\n  const int seqlen_q = q_shape.At(1);\n  const int num_heads = q_shape.At(2);\n  const int head_size = q_shape.At(3);\n  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n  const int head_size_rounded = round_multiple(head_size, 32);\n  const int seqlen_q_rounded = round_multiple(seqlen_q, 128);\n\n  size_t buffer_size = 0;\n  buffer_size += GetCudaAlignedSize(batch_size * num_heads * seqlen_q_rounded\n                                    * GetSizeOfDataType(DataType::kFloat));\n  buffer_size += GetCudaAlignedSize(batch_size * seqlen_q_rounded * num_heads * head_size_rounded\n                                    * GetSizeOfDataType(DataType::kFloat));\n  return buffer_size;\n}\n\nclass ScaledDotProductFlashAttentionGradKernel final : public user_op::OpKernel,\n                                                       public user_op::CudaGraphSupport {\n public:\n  ScaledDotProductFlashAttentionGradKernel() = default;\n  ~ScaledDotProductFlashAttentionGradKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const Tensor* grad_out = ctx->Tensor4ArgNameAndIndex(\"grad_out\", 0);\n    const Tensor* query = ctx->Tensor4ArgNameAndIndex(\"query\", 0);\n    const Tensor* key = ctx->Tensor4ArgNameAndIndex(\"key\", 0);\n    const Tensor* value = ctx->Tensor4ArgNameAndIndex(\"value\", 0);\n    const Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const Tensor* softmax_lse = ctx->Tensor4ArgNameAndIndex(\"softmax_lse\", 
0);\n    const Tensor* rng_state = ctx->Tensor4ArgNameAndIndex(\"rng_state\", 0);\n    const Tensor* alibi_slopes_ = nullptr;\n    if (ctx->has_input(\"alibi_slopes_\", 0)) {\n      alibi_slopes_ = ctx->Tensor4ArgNameAndIndex(\"alibi_slopes_\", 0);\n    }\n\n    const float p_dropout = ctx->Attr<float>(\"p_dropout\");\n    const float softmax_scale = ctx->Attr<float>(\"softmax_scale\");\n    bool is_causal = ctx->Attr<bool>(\"is_causal\");\n    int window_size_left = ctx->Attr<int32_t>(\"window_size_left\");\n    int window_size_right = ctx->Attr<int32_t>(\"window_size_right\");\n\n    Tensor* grad_q = ctx->Tensor4ArgNameAndIndex(\"grad_q\", 0);\n    Tensor* grad_k = ctx->Tensor4ArgNameAndIndex(\"grad_k\", 0);\n    Tensor* grad_v = ctx->Tensor4ArgNameAndIndex(\"grad_v\", 0);\n    Tensor* tmp = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    void* tmp_ptr = tmp->mut_dptr();\n\n    auto* cuda_device = dynamic_cast<ep::CudaDevice*>(ctx->stream()->device());\n    auto dprops = cuda_device->properties();\n    auto* cuda_stream = ctx->stream()->As<ep::CudaStream>();\n\n    bool is_dropout = p_dropout > 0.0f;\n\n    if (is_causal) { window_size_right = 0; }\n\n    const int arch = cuda_stream->cuda_arch() / 10;\n    const bool is_supported_arch = (arch == 80 || arch == 86 || arch == 89 || arch == 90);\n    CHECK(is_supported_arch);\n\n    const DataType data_type = query->data_type();\n    const bool is_supported_dtype =\n        (data_type == DataType::kFloat16 || data_type == DataType::kBFloat16);\n    CHECK(is_supported_dtype);\n    CHECK_EQ(key->data_type(), data_type);\n    CHECK_EQ(value->data_type(), data_type);\n    CHECK_EQ(grad_out->data_type(), data_type);\n    CHECK_EQ(out->data_type(), data_type);\n    CHECK_EQ(softmax_lse->data_type(), DataType::kFloat);\n    CHECK_EQ(rng_state->data_type(), DataType::kUInt64);\n\n    // check contiguous last dimension.\n    CHECK_EQ(CHECK_JUST(VectorAt(grad_out->stride(), 3)), 1);\n    CHECK_EQ(CHECK_JUST(VectorAt(query->stride(), 3)), 1);\n    CHECK_EQ(CHECK_JUST(VectorAt(key->stride(), 3)), 1);\n    CHECK_EQ(CHECK_JUST(VectorAt(value->stride(), 3)), 1);\n    CHECK_EQ(CHECK_JUST(VectorAt(out->stride(), 3)), 1);\n\n    const int batch_size = query->shape_view().At(0);\n    const int seqlen_q = query->shape_view().At(1);\n    const int num_heads = query->shape_view().At(2);\n    const int head_size = query->shape_view().At(3);\n    const int seqlen_k = key->shape_view().At(1);\n    const int num_heads_k = key->shape_view().At(2);\n    const int head_size_og = grad_out->shape_view().At(3);\n\n    // check tensor shape.\n    CHECK_EQ(grad_out->shape_view().At(0), batch_size);\n    CHECK_EQ(grad_out->shape_view().At(1), seqlen_q);\n    CHECK_EQ(grad_out->shape_view().At(2), num_heads);\n    CHECK_EQ(grad_out->shape_view().At(3), head_size_og);\n    CHECK_EQ(query->shape_view().At(0), batch_size);\n    CHECK_EQ(query->shape_view().At(1), seqlen_q);\n    CHECK_EQ(query->shape_view().At(2), num_heads);\n    CHECK_EQ(query->shape_view().At(3), head_size);\n    CHECK_EQ(key->shape_view().At(0), batch_size);\n    CHECK_EQ(key->shape_view().At(1), seqlen_k);\n    CHECK_EQ(key->shape_view().At(2), num_heads_k);\n    CHECK_EQ(key->shape_view().At(3), head_size);\n    CHECK_EQ(value->shape_view().At(0), batch_size);\n    CHECK_EQ(value->shape_view().At(1), seqlen_k);\n    CHECK_EQ(value->shape_view().At(2), num_heads_k);\n    CHECK_EQ(value->shape_view().At(3), head_size);\n    CHECK_EQ(out->shape_view().At(0), batch_size);\n    
CHECK_EQ(out->shape_view().At(1), seqlen_q);\n    CHECK_EQ(out->shape_view().At(2), num_heads);\n    CHECK_EQ(out->shape_view().At(3), head_size);\n    CHECK_EQ(softmax_lse->shape_view().At(0), batch_size);\n    CHECK_EQ(softmax_lse->shape_view().At(1), num_heads);\n    CHECK_EQ(softmax_lse->shape_view().At(2), seqlen_q);\n\n    CHECK_GT(batch_size, 0);   // batch size must be positive\n    CHECK_LE(head_size, 256);  // only support head dimensions at most 256\n    // FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout\n    // requires A100/A800 or H100/H800\n    if (head_size > 192 && (head_size <= 224 || is_dropout)) { CHECK((arch == 80 || arch == 90)); }\n    CHECK(num_heads % num_heads_k\n          == 0);  // Number of heads in key/value must divide number of heads in query\n\n    if (window_size_left >= seqlen_k) { window_size_left = -1; }\n    if (window_size_right >= seqlen_k) { window_size_right = -1; }\n\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    const int head_size_rounded = round_multiple(head_size, 32);\n    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);\n    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);\n\n    // bool loop = seqlen_k > blocksize_c;\n    // TODO: change later, for now set to true for simplicity\n    bool loop = true;\n\n    // size: batch_size x num_heads x seqlen_q_rounded; datatype: float\n    void* softmax_d_ptr = tmp_ptr;\n    tmp_ptr = reinterpret_cast<char*>(tmp_ptr)\n              + GetCudaAlignedSize(batch_size * num_heads * seqlen_q_rounded\n                                   * GetSizeOfDataType(DataType::kFloat));\n\n    // set to false by default.\n    // TODO(chende): can be obtained from the forward kernel (add an input in the python\n    // interface; it's only used for backward).\n    bool deterministic = false;\n\n    void* dq_accum_ptr = nullptr;\n    if (loop) {\n      // size: batch_size x seqlen_q_rounded x num_heads x head_size_rounded; datatype: float\n      dq_accum_ptr = tmp_ptr;\n    }\n\n    Flash_bwd_params params;\n\n    set_params_dgrad(params, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded,\n                     num_heads, num_heads_k, head_size, head_size_rounded, query, key, value, out,\n                     grad_out, grad_q, grad_k, grad_v, nullptr, nullptr,\n                     loop ? dq_accum_ptr : nullptr,\n                     // loop ? dk_accum.data_ptr() : nullptr,\n                     // loop ? dv_accum.data_ptr() : nullptr,\n                     nullptr, nullptr, const_cast<void*>(softmax_lse->dptr()), softmax_d_ptr,\n                     p_dropout, softmax_scale, window_size_left, window_size_right, deterministic);\n\n    params.dq_accum_split_stride =\n        !deterministic ? 
0 : seqlen_q_rounded * num_heads * head_size_rounded;\n\n    auto launch = &run_mha_bwd;\n\n    params.rng_state = const_cast<uint64_t*>(rng_state->dptr<uint64_t>());\n\n    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);\n\n    if (seqlen_q > 0) { launch(params, cuda_stream->cuda_stream()); }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(dtype)      \\\n  REGISTER_USER_KERNEL(\"scaled_dot_product_flash_attention_grad\")      \\\n      .SetCreateFn<ScaledDotProductFlashAttentionGradKernel>()         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"out\", 0) == dtype))   \\\n      .SetInferTmpSizeFn(InferTmpBufferSizeForFlashAttentionGradKernel);\n\nREGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(DataType::kFloat16)\nREGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(DataType::kBFloat16)\n\n}  // namespace\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // WITH_CUTLASS\n\n#endif  // CUDA_VERSION >= 11070\n"
  },
  {
    "path": "oneflow/user/kernels/scaled_dot_product_attention_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <cstddef>\n#include <cstdint>\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/common/throw.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/framework/user_op_tensor.h\"\n\n#if CUDA_VERSION >= 11070\n\n#ifdef WITH_CUTLASS\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"cutlass/arch/mma.h\"\n#include \"cutlass/gemm/warp/mma.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/user/kernels/random_seed_util.h\"\n#include \"oneflow/user/kernels/scaled_dot_product_attention_kernel.h\"\n// from flash_attention\n#include \"oneflow/user/kernels/scaled_dot_product_attention_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\nstatic size_t InferTmpBufferSizeForFlashAttentionKernel(InferContext* ctx) {\n  const float p_dropout = ctx->Attr<float>(\"p_dropout\");\n  const auto& q_shape = ctx->InputTensorDesc(\"query\", 0).shape();\n  const auto& k_shape = ctx->InputTensorDesc(\"key\", 0).shape();\n  const int batch_size = q_shape.At(0);\n  const int seqlen_q = q_shape.At(1);\n  const int num_heads = q_shape.At(2);\n  const int head_size_og = q_shape.At(3);\n  const int seqlen_k = k_shape.At(1);\n  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n  const int head_size = round_multiple(head_size_og, 8);\n  const int head_size_rounded = round_multiple(head_size, 32);\n\n  int dev;\n  {\n    cudaError_t err = cudaGetDevice(&dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  int sm_count;\n  {\n    cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);\n    if (err != cudaSuccess) { return err; }\n  }\n  const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 
128 : 64);\n  const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;\n  const int num_m_blocks = (seqlen_q + 64 - 1) / 64;\n  size_t buffer_size = 0;\n  // Temp buffers are only needed for split-KV; split-KV is not implemented for dropout.\n  if (p_dropout == 0.0f) {\n    int num_splits =\n        num_splits_heuristic(batch_size * num_heads * num_m_blocks, sm_count, num_n_blocks, 128);\n    buffer_size += GetCudaAlignedSize(num_splits * batch_size * num_heads * seqlen_q\n                                      * GetSizeOfDataType(DataType::kFloat));\n    buffer_size += GetCudaAlignedSize(num_splits * batch_size * num_heads * seqlen_q\n                                      * head_size_rounded * GetSizeOfDataType(DataType::kFloat));\n  }\n  return buffer_size;\n}\n\nclass ScaledDotProductFlashAttentionKernel final : public user_op::OpKernel,\n                                                   public user_op::CudaGraphSupport {\n public:\n  ScaledDotProductFlashAttentionKernel() = default;\n  ~ScaledDotProductFlashAttentionKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    const auto& generator = CHECK_JUST(one::MakeGenerator(DeviceType::kCUDA));\n    generator->set_current_seed(\n        CHECK_JUST(GetOpKernelRandomSeedInCurrentRank(ctx, ctx->Attr<int64_t>(\"seed\"))));\n    return std::make_shared<ScaledDotProductFlashAttentionKernelState>(generator);\n  }\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    const Tensor* query = ctx->Tensor4ArgNameAndIndex(\"query\", 0);\n    const Tensor* key = ctx->Tensor4ArgNameAndIndex(\"key\", 0);\n    const Tensor* value = ctx->Tensor4ArgNameAndIndex(\"value\", 0);\n    const Tensor* alibi_slopes_ = nullptr;\n    if (ctx->has_input(\"alibi_slopes_\", 0)) {\n      // Defaults to null; it never receives an input in the current flash-attn version.\n      alibi_slopes_ = ctx->Tensor4ArgNameAndIndex(\"alibi_slopes_\", 0);\n      CHECK(!alibi_slopes_) << \"alibi_slopes should not have a value\";\n    }\n\n    const float p_dropout = ctx->Attr<float>(\"p_dropout\");\n    const float softmax_scale = ctx->Attr<float>(\"softmax_scale\");\n    bool is_causal = ctx->Attr<bool>(\"is_causal\");\n    int window_size_left = ctx->Attr<int32_t>(\"window_size_left\");\n    int window_size_right = ctx->Attr<int32_t>(\"window_size_right\");\n    uint64_t seed = ctx->Attr<int64_t>(\"seed\");\n\n    Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    Tensor* softmax_lse = ctx->Tensor4ArgNameAndIndex(\"softmax_lse\", 0);\n    Tensor* rng_state = ctx->Tensor4ArgNameAndIndex(\"rng_state\", 0);\n    Tensor* tmp = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    void* tmp_ptr = tmp->mut_dptr();\n\n    auto* cuda_device = dynamic_cast<ep::CudaDevice*>(ctx->stream()->device());\n    auto dprops = cuda_device->properties();\n    auto* cuda_stream = ctx->stream()->As<ep::CudaStream>();\n\n    const int arch = cuda_stream->cuda_arch() / 10;\n    const bool is_supported_arch = (arch == 80 || arch == 86 || arch == 89 || arch == 90);\n    CHECK(is_supported_arch) << \"only supports CUDA Arch 80, 86, 89 and 90.\";\n\n    const DataType data_type = query->data_type();\n    const bool is_supported_dtype =\n        (data_type == DataType::kFloat16 || data_type == DataType::kBFloat16);\n    CHECK(is_supported_dtype);\n    CHECK_EQ(key->data_type(), 
data_type);\n    CHECK_EQ(value->data_type(), data_type);\n    CHECK_EQ(out->data_type(), data_type);\n\n    CHECK_EQ(softmax_lse->data_type(), DataType::kFloat);\n\n    // check contiguous last dimension.\n    CHECK_EQ(CHECK_JUST(VectorAt(query->stride(), 3)), 1);\n    CHECK_EQ(CHECK_JUST(VectorAt(key->stride(), 3)), 1);\n    CHECK_EQ(CHECK_JUST(VectorAt(value->stride(), 3)), 1);\n\n    const int batch_size = query->shape_view().At(0);\n    const int seqlen_q = query->shape_view().At(1);\n    const int num_heads = query->shape_view().At(2);\n    const int head_size_og = query->shape_view().At(3);\n    const int seqlen_k = key->shape_view().At(1);\n    const int num_heads_k = key->shape_view().At(2);\n\n    // check tensor shape.\n    CHECK_EQ(query->shape_view().At(0), batch_size);\n    CHECK_EQ(query->shape_view().At(1), seqlen_q);\n    CHECK_EQ(query->shape_view().At(2), num_heads);\n    CHECK_EQ(query->shape_view().At(3), head_size_og);\n    CHECK_EQ(key->shape_view().At(0), batch_size);\n    CHECK_EQ(key->shape_view().At(1), seqlen_k);\n    CHECK_EQ(key->shape_view().At(2), num_heads_k);\n    CHECK_EQ(key->shape_view().At(3), head_size_og);\n    CHECK_EQ(value->shape_view().At(0), batch_size);\n    CHECK_EQ(value->shape_view().At(1), seqlen_k);\n    CHECK_EQ(value->shape_view().At(2), num_heads_k);\n    CHECK_EQ(value->shape_view().At(3), head_size_og);\n    CHECK_EQ(out->shape_view().At(0), batch_size);\n    CHECK_EQ(out->shape_view().At(1), seqlen_q);\n    CHECK_EQ(out->shape_view().At(2), num_heads);\n    CHECK_EQ(out->shape_view().At(3), head_size_og);\n    CHECK_EQ(softmax_lse->shape_view().At(0), batch_size);\n    CHECK_EQ(softmax_lse->shape_view().At(1), num_heads);\n    CHECK_EQ(softmax_lse->shape_view().At(2), seqlen_q);\n\n    CHECK_GT(batch_size, 0);      // batch size must be positive\n    CHECK_LE(head_size_og, 256);  // only support head dimensions at most 256\n    CHECK(num_heads % num_heads_k\n          == 0);  // Number of heads in key/value must divide number of heads in query\n\n    if (window_size_left >= seqlen_k) { window_size_left = -1; }\n    if (window_size_right >= seqlen_k) { window_size_right = -1; }\n\n    // causal=true is the same as causal=false in this case\n    if (seqlen_q == 1 && !alibi_slopes_) { is_causal = false; }\n    if (is_causal) { window_size_right = 0; }\n\n    const int seqlenq_ngroups_swapped = 0;\n\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    const int head_size = round_multiple(head_size_og, 8);\n    const int head_size_rounded = round_multiple(head_size, 32);\n    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);\n    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);\n\n    Flash_fwd_params params;\n    set_params_fprop(params, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded,\n                     num_heads, num_heads_k, head_size, head_size_rounded, query, key, value, out,\n                     /*cu_seqlens_q_d=*/nullptr,\n                     /*cu_seqlens_k_d=*/nullptr,\n                     /*seqused_k=*/nullptr,\n                     /*return_softmax=*/nullptr, softmax_lse->mut_dptr(), p_dropout, softmax_scale,\n                     window_size_left, window_size_right);\n\n    int64_t counter_offset = params.b * params.h * 32;\n    params.rng_state = rng_state->mut_dptr<uint64_t>();\n\n    set_params_splitkv(params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,\n                       head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, 
tmp_ptr);\n\n    if (p_dropout > 0.0f) {\n      // Seed the Philox state for dropout from the generator held in the kernel state.\n      auto* flash_attention_kernel_state =\n          dynamic_cast<ScaledDotProductFlashAttentionKernelState*>(state);\n      CHECK_NOTNULL(flash_attention_kernel_state);\n      const auto& generator = flash_attention_kernel_state->generator();\n      CHECK_NOTNULL(generator);\n      const auto device_index = cuda_device->device_index();\n      std::shared_ptr<ep::CUDAGenerator> cuda_generator =\n          CHECK_JUST(generator->Get<ep::CUDAGenerator>(device_index));\n      params.philox_args =\n          at::PhiloxCudaState(seed, cuda_generator->get_philox_offset(counter_offset));\n    }\n\n    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);\n\n    if (seqlen_k > 0) { run_mha_fwd(params, cuda_stream->cuda_stream()); }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(dtype)      \\\n  REGISTER_USER_KERNEL(\"scaled_dot_product_flash_attention\")           \\\n      .SetCreateFn<ScaledDotProductFlashAttentionKernel>()             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"out\", 0) == dtype))   \\\n      .SetInferTmpSizeFn(InferTmpBufferSizeForFlashAttentionKernel);\n\nREGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(DataType::kFloat16)\nREGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(DataType::kBFloat16)\n\n}  // namespace\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // WITH_CUTLASS\n\n#endif  // CUDA_VERSION >= 11070\n"
  },
  {
    "path": "oneflow/user/kernels/scaled_dot_product_attention_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_USER_KERNELS_FLASH_ATTENTION_KERNEL_H_\n#define ONEFLOW_USER_KERNELS_FLASH_ATTENTION_KERNEL_H_\n\n#include \"oneflow/user/kernels/random_mask_generator.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nclass ScaledDotProductFlashAttentionKernelState : public user_op::OpKernelState {\n public:\n  explicit ScaledDotProductFlashAttentionKernelState(\n      const std::shared_ptr<one::Generator>& generator)\n      : generator_(generator) {}\n\n  const std::shared_ptr<one::Generator>& generator() const { return generator_; }\n\n private:\n  std::shared_ptr<one::Generator> generator_;\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_FLASH_ATTENTION_KERNEL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/scaled_dot_product_attention_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_USER_KERNELS_FLASH_ATTENTION_UTIL_H_\n#define ONEFLOW_USER_KERNELS_FLASH_ATTENTION_UTIL_H_\n\n#include \"oneflow/core/framework/user_op_tensor.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"flash.h\"\n#include \"static_switch.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\nvoid set_params_fprop(Flash_fwd_params& params,\n                      // sizes\n                      const size_t b, const size_t seqlen_q, const size_t seqlen_k,\n                      const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, const size_t h,\n                      const size_t h_k, const size_t d, const size_t d_rounded,\n                      // device pointers\n                      const Tensor* q, const Tensor* k, const Tensor* v, Tensor* out,\n                      void* cu_seqlens_q_d, void* cu_seqlens_k_d, void* seqused_k, void* p_d,\n                      void* softmax_lse_d, float p_dropout, float softmax_scale,\n                      int window_size_left, int window_size_right,\n                      bool seqlenq_ngroups_swapped = false) {\n  // Reset the parameters\n  std::memset(&params, 0, sizeof(params));\n\n  params.is_bf16 = q->data_type() == DataType::kBFloat16;\n\n  // Set the pointers and strides.\n  params.q_ptr = const_cast<void*>(q->dptr());\n  params.k_ptr = const_cast<void*>(k->dptr());\n  params.v_ptr = const_cast<void*>(v->dptr());\n  // All stride are in elements, not bytes.\n  params.q_row_stride = CHECK_JUST(VectorAt(q->stride(), 1));\n  params.k_row_stride = CHECK_JUST(VectorAt(k->stride(), 1));\n  params.v_row_stride = CHECK_JUST(VectorAt(v->stride(), 1));\n  params.q_head_stride = CHECK_JUST(VectorAt(q->stride(), 2));\n  params.k_head_stride = CHECK_JUST(VectorAt(k->stride(), 2));\n  params.v_head_stride = CHECK_JUST(VectorAt(v->stride(), 2));\n  params.o_ptr = out->mut_dptr();\n  params.o_row_stride = CHECK_JUST(VectorAt(out->stride(), 1));\n  params.o_head_stride = CHECK_JUST(VectorAt(out->stride(), 2));\n\n  if (cu_seqlens_q_d == nullptr) {\n    params.q_batch_stride = CHECK_JUST(VectorAt(q->stride(), 0));\n    params.k_batch_stride = CHECK_JUST(VectorAt(k->stride(), 0));\n    params.v_batch_stride = CHECK_JUST(VectorAt(v->stride(), 0));\n    params.o_batch_stride = CHECK_JUST(VectorAt(out->stride(), 0));\n    if (seqlenq_ngroups_swapped) {\n      params.q_batch_stride *= seqlen_q;\n      params.o_batch_stride *= seqlen_q;\n    }\n  }\n\n  params.cu_seqlens_q = static_cast<int*>(cu_seqlens_q_d);\n  params.cu_seqlens_k = static_cast<int*>(cu_seqlens_k_d);\n  params.seqused_k = static_cast<int*>(seqused_k);\n\n  // P = softmax(QK^T)\n  params.p_ptr = p_d;\n\n  // Softmax sum\n  params.softmax_lse_ptr = softmax_lse_d;\n\n  // Set the dimensions.\n  params.b = b;\n  params.h = h;\n  params.h_k = h_k;\n  params.h_h_k_ratio = h / h_k;\n  params.seqlen_q = seqlen_q;\n  params.seqlen_k = seqlen_k;\n  
params.seqlen_q_rounded = seqlen_q_rounded;\n  params.seqlen_k_rounded = seqlen_k_rounded;\n  params.d = d;\n  params.d_rounded = d_rounded;\n\n  // Set the different scale values.\n  params.scale_softmax = softmax_scale;\n  params.scale_softmax_log2 = softmax_scale * M_LOG2E;\n\n  // Set this to probability of keeping an element to simplify things.\n  params.p_dropout = 1.f - p_dropout;\n  // Convert p from float to int so we don't have to convert the random uint to float to compare.\n  // [Minor] We want to round down since when we do the comparison we use <= instead of <\n  // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));\n  // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));\n  params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));\n  params.rp_dropout = 1.f / params.p_dropout;\n  params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;\n  CHECK_LT(p_dropout, 1.f);\n#ifdef FLASHATTENTION_DISABLE_DROPOUT\n  CHECK_EQ(p_dropout, 0.0f) << \"This flash attention build does not support dropout.\";\n#endif\n\n  // Causal is the special case where window_size_right == 0 and window_size_left < 0.\n  // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.\n  params.is_causal = window_size_left < 0 && window_size_right == 0;\n\n  if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }\n  if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }\n  params.window_size_left = window_size_left;\n  params.window_size_right = window_size_right;\n\n#ifdef FLASHATTENTION_DISABLE_LOCAL\n  CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0))\n      << \"This flash attention build does not support local attention.\";\n#endif\n\n  params.is_seqlens_k_cumulative = true;\n\n#ifdef FLASHATTENTION_DISABLE_UNEVEN_K\n  CHECK_EQ(d, d_rounded)\n      << \"This flash attention build does not support headdim not being a multiple of 32.\";\n#endif\n}\n\nvoid set_params_dgrad(Flash_bwd_params& params,\n                      // sizes\n                      const size_t b, const size_t seqlen_q, const size_t seqlen_k,\n                      const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, const size_t h,\n                      const size_t h_k, const size_t d, const size_t d_rounded,\n                      // device pointers\n                      const Tensor* q, const Tensor* k, const Tensor* v, const Tensor* out,\n                      const Tensor* dout, Tensor* dq, Tensor* dk, Tensor* dv, void* cu_seqlens_q_d,\n                      void* cu_seqlens_k_d, void* dq_accum_d, void* dk_accum_d, void* dv_accum_d,\n                      void* softmax_lse_d, void* dsoftmax_sum_d, float p_dropout,\n                      float softmax_scale, int window_size_left, int window_size_right,\n                      bool deterministic) {\n  set_params_fprop(params, b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d,\n                   d_rounded, q, k, v, const_cast<Tensor*>(out), cu_seqlens_q_d, cu_seqlens_k_d,\n                   nullptr, nullptr, softmax_lse_d, p_dropout, softmax_scale, window_size_left,\n                   window_size_right);\n\n  // Set the pointers and strides.\n  params.do_ptr = const_cast<void*>(dout->dptr());\n  params.do_row_stride = CHECK_JUST(VectorAt(dout->stride(), 1));\n  params.do_head_stride = CHECK_JUST(VectorAt(dout->stride(), 
2));\n  params.dq_ptr = dq->mut_dptr();\n  params.dk_ptr = dk->mut_dptr();\n  params.dv_ptr = dv->mut_dptr();\n  params.dq_row_stride = CHECK_JUST(VectorAt(dq->stride(), 1));\n  params.dk_row_stride = CHECK_JUST(VectorAt(dk->stride(), 1));\n  params.dv_row_stride = CHECK_JUST(VectorAt(dv->stride(), 1));\n  params.dq_head_stride = CHECK_JUST(VectorAt(dq->stride(), 2));\n  params.dk_head_stride = CHECK_JUST(VectorAt(dk->stride(), 2));\n  params.dv_head_stride = CHECK_JUST(VectorAt(dv->stride(), 2));\n\n  if (cu_seqlens_q_d == nullptr) {\n    params.do_batch_stride = CHECK_JUST(VectorAt(dout->stride(), 0));\n    params.dq_batch_stride = CHECK_JUST(VectorAt(dq->stride(), 0));\n    params.dk_batch_stride = CHECK_JUST(VectorAt(dk->stride(), 0));\n    params.dv_batch_stride = CHECK_JUST(VectorAt(dv->stride(), 0));\n  }\n\n  params.dq_accum_ptr = dq_accum_d;\n  params.dk_accum_ptr = dk_accum_d;\n  params.dv_accum_ptr = dv_accum_d;\n\n  // Softmax sum\n  params.dsoftmax_sum = dsoftmax_sum_d;\n\n  params.deterministic = deterministic;\n}\n\nvoid run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) {\n  FP16_SWITCH(!params.is_bf16, [&] {\n    HEADDIM_SWITCH(params.d, [&] {\n      if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0\n        run_mha_fwd_<elem_type, kHeadDim>(params, stream);\n      } else {\n        run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);\n      }\n    });\n  });\n}\n\nvoid run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream) {\n  FP16_SWITCH(!params.is_bf16, [&] {\n    HEADDIM_SWITCH(params.d, [&] { run_mha_bwd_<elem_type, kHeadDim>(params, stream); });\n  });\n}\n\n// Find the number of splits that maximizes the occupancy. For example, if we have\n// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is\n// better than having 3 splits (efficiency = 0.67). However, we also don't want too many\n// splits as that would incur more HBM reads/writes.\n// So we find the best efficiency, then find the smallest number of splits that gets 85%\n// of the best efficiency.\ninline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks,\n                                int max_splits) {\n  // If we have enough to almost fill the SMs, then just use 1 split\n  if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }\n  max_splits = std::min({max_splits, num_SMs, num_n_blocks});\n  float max_efficiency = 0.f;\n  std::vector<float> efficiency;\n  efficiency.reserve(max_splits);\n  auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };\n  // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,\n  // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks\n  // (i.e. 
it's 11 splits anyway).\n  // So we check if the number of blocks per split is the same as the previous num_splits.\n  auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {\n    return num_splits == 1\n           || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);\n  };\n  for (int num_splits = 1; num_splits <= max_splits; num_splits++) {\n    if (!is_split_eligible(num_splits)) {\n      efficiency.push_back(0.f);\n    } else {\n      float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;\n      float eff = n_waves / ceil(n_waves);\n      // printf(\"num_splits = %d, eff = %f\\n\", num_splits, eff);\n      if (eff > max_efficiency) { max_efficiency = eff; }\n      efficiency.push_back(eff);\n    }\n  }\n  for (int num_splits = 1; num_splits <= max_splits; num_splits++) {\n    if (!is_split_eligible(num_splits)) { continue; }\n    if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {\n      // printf(\"num_splits chosen = %d\\n\", num_splits);\n      return num_splits;\n    }\n  }\n  return 1;\n}\n\nvoid set_params_splitkv(Flash_fwd_params& params, const int batch_size, const int num_heads,\n                        const int head_size, const int max_seqlen_k, const int max_seqlen_q,\n                        const int head_size_rounded, const float p_dropout, const int num_splits,\n                        cudaDeviceProp& dprops, void* tmp_ptr) {\n  // This needs to match run_mha_fwd_splitkv_dispatch\n  const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);\n  const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;\n  // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.\n  // In any case we don't expect seqlen_q to be larger than 64 for inference.\n  const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;\n  params.num_splits = num_splits;\n  if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout\n    if (num_splits < 1) {\n      params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks,\n                                               dprops.multiProcessorCount, num_n_blocks, 128);\n    }\n    if (params.num_splits > 1) {\n      size_t softmax_lse_accum_size =\n          params.num_splits * batch_size * num_heads * max_seqlen_q * sizeof(float);\n      params.softmax_lseaccum_ptr = tmp_ptr;\n      params.oaccum_ptr =\n          reinterpret_cast<char*>(tmp_ptr) + GetCudaAlignedSize(softmax_lse_accum_size);\n    }\n    CHECK_LE(params.num_splits, 128);\n  }\n}\n\nvoid set_params_alibi(Flash_fwd_params& params, const Tensor* alibi_slopes_, int batch_size,\n                      int num_heads) {\n  // TODO(ChenDe): need to support alibi params.\n  // default to null\n  CHECK(!alibi_slopes_) << \"alibi_slopes should be null.\";\n  params.alibi_slopes_ptr = nullptr;\n}\n\n}  // namespace\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_SCALED_DOT_PRODUCT_ATTENTION_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/search_sorted_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/user/kernels/search_sorted_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename K>\nclass CpuSearchSortedKernel final : public user_op::OpKernel {\n public:\n  CpuSearchSortedKernel() = default;\n  ~CpuSearchSortedKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* sorted_sequence = ctx->Tensor4ArgNameAndIndex(\"sorted_sequence\", 0);\n    const user_op::Tensor* values = ctx->Tensor4ArgNameAndIndex(\"values\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const bool& right = ctx->Attr<bool>(\"right\");\n    const T* values_ptr = values->dptr<T>();\n    const T* sequence_ptr = sorted_sequence->dptr<T>();\n    K* out_ptr = out->mut_dptr<K>();\n    const int32_t instance_num = values->shape_view().elem_cnt();\n    bool is_values_scalar = values->shape_view().NumAxes() == 0;\n    bool is_sequence_1d = (sorted_sequence->shape_view().NumAxes() == 1);\n    K values_shape_last =\n        is_values_scalar ? 1 : values->shape_view().At(values->shape_view().NumAxes() - 1);\n    K sequence_shape_last =\n        sorted_sequence->shape_view().At(sorted_sequence->shape_view().NumAxes() - 1);\n    FOR_RANGE(int32_t, i, 0, instance_num) {\n      K start_bd = is_sequence_1d ? 0 : i / values_shape_last * sequence_shape_last;\n      K end_bd = start_bd + sequence_shape_last;\n      K pos = !right\n                  ? 
cus_lower_bound<T, K>(start_bd, end_bd, values_ptr[i], sequence_ptr) - start_bd\n                  : cus_upper_bound<T, K>(start_bd, end_bd, values_ptr[i], sequence_ptr) - start_bd;\n\n      out_ptr[i] = pos;\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_SEARCH_SORTED_KERNEL(in_dtype, out_dtype)                              \\\n  REGISTER_USER_KERNEL(\"searchsorted\")                                                      \\\n      .SetCreateFn<                                                                         \\\n          CpuSearchSortedKernel<OF_PP_PAIR_FIRST(in_dtype), OF_PP_PAIR_FIRST(out_dtype)>>() \\\n      .SetIsMatchedHob(                                                                     \\\n          (user_op::HobDeviceType() == DeviceType::kCPU)                                    \\\n          && (user_op::HobDataType(\"sorted_sequence\", 0) == OF_PP_PAIR_SECOND(in_dtype))    \\\n          && (user_op::HobDataType(\"values\", 0) == OF_PP_PAIR_SECOND(in_dtype))             \\\n          && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(out_dtype)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CPU_SEARCH_SORTED_KERNEL, ARITHMETIC_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n\ntemplate<typename T, typename K>\nclass CpuSearchSortedScalarKernel final : public user_op::OpKernel {\n public:\n  CpuSearchSortedScalarKernel() = default;\n  ~CpuSearchSortedScalarKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* sorted_sequence = ctx->Tensor4ArgNameAndIndex(\"sorted_sequence\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const bool& right = ctx->Attr<bool>(\"right\");\n    const T& values = static_cast<T>(ctx->Attr<double>(\"values\"));\n\n    const T* sequence_ptr = sorted_sequence->dptr<T>();\n    K* out_ptr = out->mut_dptr<K>();\n    K sequence_shape_last = sorted_sequence->shape_view().At(0);\n\n    K pos = !right ? cus_lower_bound<T, K>(0, sequence_shape_last, values, sequence_ptr)\n                   : cus_upper_bound<T, K>(0, sequence_shape_last, values, sequence_ptr);\n\n    out_ptr[0] = pos;\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_SEARCH_SORTED_SCALAR_KERNEL(in_dtype, out_dtype)                             \\\n  REGISTER_USER_KERNEL(\"searchsorted_scalar\")                                                     \\\n      .SetCreateFn<                                                                               \\\n          CpuSearchSortedScalarKernel<OF_PP_PAIR_FIRST(in_dtype), OF_PP_PAIR_FIRST(out_dtype)>>() \\\n      .SetIsMatchedHob(                                                                           \\\n          (user_op::HobDeviceType() == DeviceType::kCPU)                                          \\\n          && (user_op::HobDataType(\"sorted_sequence\", 0) == OF_PP_PAIR_SECOND(in_dtype))          \\\n          && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(out_dtype)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CPU_SEARCH_SORTED_SCALAR_KERNEL, ARITHMETIC_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/search_sorted_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/user/kernels/search_sorted_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename K>\n__global__ void DoSearchSortedLogical(int32_t instance_num, bool is_sequence_1d,\n                                      K values_shape_last, K sequence_shape_last, bool right,\n                                      const T* values_ptr, const T* sequence_ptr, K* out_ptr) {\n  CUDA_1D_KERNEL_LOOP(i, instance_num) {\n    K start_bd = is_sequence_1d ? 0 : i / values_shape_last * sequence_shape_last;\n    K end_bd = start_bd + sequence_shape_last;\n    K pos = !right\n                ? cus_lower_bound<T, K>(start_bd, end_bd, values_ptr[i], sequence_ptr) - start_bd\n                : cus_upper_bound<T, K>(start_bd, end_bd, values_ptr[i], sequence_ptr) - start_bd;\n    out_ptr[i] = pos;\n  }\n}\n\ntemplate<typename T, typename K>\n__global__ void DoSearchSortedScalarLogical(K sequence_shape_last, bool right, const T values,\n                                            const T* sequence_ptr, K* out_ptr) {\n  CUDA_1D_KERNEL_LOOP(i, 1) {\n    K pos = !right ? cus_lower_bound<T, K>(0, sequence_shape_last, values, sequence_ptr)\n                   : cus_upper_bound<T, K>(0, sequence_shape_last, values, sequence_ptr);\n    out_ptr[0] = pos;\n  }\n}\n\ntemplate<typename T, typename K>\nclass GpuSearchSortedKernel final : public user_op::OpKernel {\n public:\n  GpuSearchSortedKernel() = default;\n  ~GpuSearchSortedKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* sorted_sequence = ctx->Tensor4ArgNameAndIndex(\"sorted_sequence\", 0);\n    const user_op::Tensor* values = ctx->Tensor4ArgNameAndIndex(\"values\", 0);\n\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const bool& right = ctx->Attr<bool>(\"right\");\n    const T* values_ptr = values->dptr<T>();\n    const T* sequence_ptr = sorted_sequence->dptr<T>();\n    K* out_ptr = out->mut_dptr<K>();\n    const int32_t instance_num = values->shape_view().elem_cnt();\n    bool is_values_scalar = values->shape_view().NumAxes() == 0;\n    bool is_sequence_1d = (sorted_sequence->shape_view().NumAxes() == 1);\n    K values_shape_last =\n        is_values_scalar ? 
1 : values->shape_view().At(values->shape_view().NumAxes() - 1);\n    K sequence_shape_last =\n        sorted_sequence->shape_view().At(sorted_sequence->shape_view().NumAxes() - 1);\n    RUN_CUDA_KERNEL((DoSearchSortedLogical<T, K>), ctx->stream(), instance_num, instance_num,\n                    is_sequence_1d, values_shape_last, sequence_shape_last, right, values_ptr,\n                    sequence_ptr, out_ptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_GPU_SEARCH_SORTED_KERNEL(in_dtype, out_dtype)                              \\\n  REGISTER_USER_KERNEL(\"searchsorted\")                                                      \\\n      .SetCreateFn<                                                                         \\\n          GpuSearchSortedKernel<OF_PP_PAIR_FIRST(in_dtype), OF_PP_PAIR_FIRST(out_dtype)>>() \\\n      .SetIsMatchedHob(                                                                     \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                   \\\n          && (user_op::HobDataType(\"sorted_sequence\", 0) == OF_PP_PAIR_SECOND(in_dtype))    \\\n          && (user_op::HobDataType(\"values\", 0) == OF_PP_PAIR_SECOND(in_dtype))             \\\n          && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(out_dtype)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_GPU_SEARCH_SORTED_KERNEL, ARITHMETIC_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n\ntemplate<typename T, typename K>\nclass GpuSearchSortedScalarKernel final : public user_op::OpKernel {\n public:\n  GpuSearchSortedScalarKernel() = default;\n  ~GpuSearchSortedScalarKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* sorted_sequence = ctx->Tensor4ArgNameAndIndex(\"sorted_sequence\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const bool& right = ctx->Attr<bool>(\"right\");\n    const T& values = static_cast<T>(ctx->Attr<double>(\"values\"));\n\n    const T* sequence_ptr = sorted_sequence->dptr<T>();\n    K* out_ptr = out->mut_dptr<K>();\n    K sequence_shape_last = sorted_sequence->shape_view().At(0);\n    RUN_CUDA_KERNEL((DoSearchSortedScalarLogical<T, K>), ctx->stream(), 1, sequence_shape_last,\n                    right, values, sequence_ptr, out_ptr);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_GPU_SEARCH_SORTED_SCALAR_KERNEL(in_dtype, out_dtype)                             \\\n  REGISTER_USER_KERNEL(\"searchsorted_scalar\")                                                     \\\n      .SetCreateFn<                                                                               \\\n          GpuSearchSortedScalarKernel<OF_PP_PAIR_FIRST(in_dtype), OF_PP_PAIR_FIRST(out_dtype)>>() \\\n      .SetIsMatchedHob(                                                                           \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                         \\\n          && (user_op::HobDataType(\"sorted_sequence\", 0) == OF_PP_PAIR_SECOND(in_dtype))          \\\n          && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(out_dtype)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_GPU_SEARCH_SORTED_SCALAR_KERNEL, ARITHMETIC_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/search_sorted_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\ntemplate<typename T, typename K>\nOF_DEVICE_FUNC K cus_lower_bound(K start, K end, const T val, const T* bd) {\n  while (start < end) {\n    const K mid = start + ((end - start) >> 1);\n    const T mid_val = bd[mid];\n    if (!(mid_val >= val)) {\n      start = mid + 1;\n    } else {\n      end = mid;\n    }\n  }\n  return start;\n}\n\ntemplate<typename T, typename K>\nOF_DEVICE_FUNC K cus_upper_bound(K start, K end, const T val, const T* bd) {\n  while (start < end) {\n    const K mid = start + ((end - start) >> 1);\n    const T mid_val = bd[mid];\n    if (!(mid_val > val)) {\n      start = mid + 1;\n    } else {\n      end = mid;\n    }\n  }\n  return start;\n}\n"
  },
  {
    "path": "oneflow/user/kernels/sigmoid_cross_entropy_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/sigmoid_cross_entropy_kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\ntemplate<template<typename, typename> class Opt, typename PredT, typename LabelT>\nstruct ElemwiseSigmoidCrossEntropyGradFunctor<DeviceType::kCPU, Opt, PredT, LabelT> final {\n  void operator()(ep::Stream* stream, int64_t n, PredT* prediction_diff, const PredT* prediction,\n                  const LabelT* label, const PredT* loss_diff) {\n    FOR_RANGE(int64_t, i, 0, n) {\n      prediction_diff[i] = Opt<PredT, LabelT>()(prediction[i], label[i], loss_diff[i]);\n    }\n  }\n};\n\ntemplate<template<typename, typename> class Opt, typename PredT, typename LabelT>\nstruct ElemwiseSigmoidCrossEntropyFunctor<DeviceType::kCPU, Opt, PredT, LabelT> final {\n  void operator()(ep::Stream* stream, int64_t n, PredT* loss, const PredT* prediction,\n                  const LabelT* label) {\n    FOR_RANGE(int64_t, i, 0, n) { loss[i] = Opt<PredT, LabelT>()(prediction[i], label[i]); }\n  }\n};\n}  // namespace\n\nREGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCPU, float, int32_t)\nREGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCPU, double, int32_t)\nREGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCPU, float, int8_t)\nREGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCPU, double, int8_t)\nREGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCPU, float, float)\nREGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCPU, double, double)\nREGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCPU, float, int32_t)\nREGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCPU, double, int32_t)\nREGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCPU, float, int8_t)\nREGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCPU, double, int8_t)\nREGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCPU, float, float)\nREGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCPU, double, double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/sigmoid_cross_entropy_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/user/kernels/sigmoid_cross_entropy_kernel.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\ntemplate<template<typename, typename> class Opt, typename PredT, typename LabelT>\nstruct ElemwiseSigmoidCrossEntropyGradFunctor<DeviceType::kCUDA, Opt, PredT, LabelT> final {\n  void operator()(ep::Stream* stream, int64_t n, PredT* prediction_diff, const PredT* prediction,\n                  const LabelT* label, const PredT* loss_diff) {\n    OF_CUDA_CHECK(cuda::elementwise::Ternary(Opt<PredT, LabelT>(), n, prediction_diff, prediction,\n                                             label, loss_diff,\n                                             stream->As<ep::CudaStream>()->cuda_stream()));\n  }\n};\n\ntemplate<template<typename, typename> class Opt, typename PredT, typename LabelT>\nstruct ElemwiseSigmoidCrossEntropyFunctor<DeviceType::kCUDA, Opt, PredT, LabelT> final {\n  void operator()(ep::Stream* stream, int64_t n, PredT* loss, const PredT* prediction,\n                  const LabelT* label) {\n    OF_CUDA_CHECK(cuda::elementwise::Binary(Opt<PredT, LabelT>(), n, loss, prediction, label,\n                                            stream->As<ep::CudaStream>()->cuda_stream()));\n  }\n};\n}  // namespace\nREGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCUDA, float, int32_t)\nREGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCUDA, double, int32_t)\nREGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCUDA, float, int8_t)\nREGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCUDA, double, int8_t)\nREGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCUDA, float, float)\nREGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(DeviceType::kCUDA, double, double)\nREGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCUDA, float, int32_t)\nREGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCUDA, double, int32_t)\nREGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCUDA, float, int8_t)\nREGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCUDA, double, int8_t)\nREGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCUDA, float, float)\nREGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(DeviceType::kCUDA, double, double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/sigmoid_cross_entropy_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_SIGMOID_CROSS_ENTROPY_KERNEL_H_\n#define ONEFLOW_USER_KERNELS_SIGMOID_CROSS_ENTROPY_KERNEL_H_\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/math_unary_elementwise_func.h\"\n\nnamespace oneflow {\n\ntemplate<typename PredT, typename LabelT>\nstruct SigmoidCrossEntropyFunctor {\n  OF_DEVICE_FUNC PredT operator()(const PredT prediction, const LabelT label) const {\n    return -1.f * prediction * (label - (prediction >= 0))\n           + LogFunctor<PredT>::Forward(\n               1 + ExpFunctor<PredT>::Forward(prediction - 2 * prediction * (prediction >= 0)));\n  }\n};\n\ntemplate<typename PredT, typename LabelT>\nstruct SigmoidCrossEntropyGradFunctor {\n  OF_DEVICE_FUNC PredT operator()(const PredT prediction, const LabelT label,\n                                  const PredT loss_diff) const {\n    return loss_diff * (1.f / (1.f + ExpFunctor<PredT>::Forward(-prediction)) - label);\n  }\n};\n\nnamespace {\ntemplate<DeviceType device_type, template<typename, typename> class Opt, typename PredT,\n         typename LabelT>\nstruct ElemwiseSigmoidCrossEntropyGradFunctor final {\n  void operator()(ep::Stream* stream, int64_t n, PredT* prediction_diff, const PredT* prediction,\n                  const LabelT* label, const PredT* loss_diff);\n};\n\ntemplate<DeviceType device_type, template<typename, typename> class Opt, typename PredT,\n         typename LabelT>\nstruct ElemwiseSigmoidCrossEntropyFunctor final {\n  void operator()(ep::Stream* stream, int64_t n, PredT* loss, const PredT* prediction,\n                  const LabelT* label);\n};\n}  // namespace\n\ntemplate<DeviceType device_type, template<typename, typename> class Opt, typename PredT,\n         typename LabelT>\nclass SigmoidCrossEntropyKernel final : public user_op::OpKernel {\n public:\n  SigmoidCrossEntropyKernel() = default;\n  ~SigmoidCrossEntropyKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex(\"prediction\", 0);\n    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    user_op::Tensor* loss = ctx->Tensor4ArgNameAndIndex(\"loss\", 0);\n    const auto n = prediction->shape_view().elem_cnt();\n    ElemwiseSigmoidCrossEntropyFunctor<device_type, Opt, PredT, LabelT>()(\n        ctx->stream(), n, loss->mut_dptr<PredT>(), prediction->dptr<PredT>(),\n        label->dptr<LabelT>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SIGMOID_CROSS_ENTROPY_KERNEL(device_type, dtype, ltype)                      \\\n  REGISTER_USER_KERNEL(\"sigmoid_cross_entropy\")                                               \\\n      .SetCreateFn<                                                                           \\\n          
SigmoidCrossEntropyKernel<device_type, SigmoidCrossEntropyFunctor, dtype, ltype>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type)                              \\\n                       && (user_op::HobDataType(\"label\", 0) == GetDataType<ltype>::value)     \\\n                       && (user_op::HobDataType(\"loss\", 0) == GetDataType<dtype>::value));\n\ntemplate<DeviceType device_type, template<typename, typename> class Opt, typename PredT,\n         typename LabelT>\nclass SigmoidCrossEntropyGradKernel final : public user_op::OpKernel {\n public:\n  SigmoidCrossEntropyGradKernel() = default;\n  ~SigmoidCrossEntropyGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    const user_op::Tensor* loss_diff = ctx->Tensor4ArgNameAndIndex(\"loss_diff\", 0);\n    const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex(\"prediction\", 0);\n    user_op::Tensor* prediction_diff = ctx->Tensor4ArgNameAndIndex(\"prediction_diff\", 0);\n    const int64_t n = prediction->shape_view().elem_cnt();\n    ElemwiseSigmoidCrossEntropyGradFunctor<device_type, Opt, PredT, LabelT>()(\n        ctx->stream(), n, prediction_diff->mut_dptr<PredT>(), prediction->dptr<PredT>(),\n        label->dptr<LabelT>(), loss_diff->dptr<PredT>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SIGMOID_CROSS_ENTROPY_GRAD_KERNEL(device_type, dtype, ltype)                 \\\n  REGISTER_USER_KERNEL(\"sigmoid_cross_entropy_grad\")                                          \\\n      .SetCreateFn<SigmoidCrossEntropyGradKernel<device_type, SigmoidCrossEntropyGradFunctor, \\\n                                                 dtype, ltype>>()                             \\\n      .SetIsMatchedHob(                                                                       \\\n          (user_op::HobDeviceType() == device_type)                                           \\\n          && (user_op::HobDataType(\"label\", 0) == GetDataType<ltype>::value)                  \\\n          && (user_op::HobDataType(\"prediction_diff\", 0) == GetDataType<dtype>::value));\n\n}  // namespace oneflow\n#endif  // ONEFLOW_USER_KERNELS_SIGMOID_CROSS_ENTROPY_KERNEL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/skip_layer_norm_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cudnn_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include <cub/cub.cuh>\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/cuda/layer_norm.cuh\"\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\n#include \"oneflow/core/device/cuda_pseudo_bfloat16.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename SRC, typename DST>\nstruct SkipLoad {\n  using LoadType = DST;\n  SkipLoad(const SRC* src, const SRC* bias, const SRC* skip, float alpha, int64_t row_size)\n      : src(src), bias(bias), skip(skip), alpha(alpha), row_size(row_size) {}\n  template<int N>\n  __device__ void load(DST* dst, int64_t row, int64_t col) const {\n    cuda::layer_norm::Pack<SRC, N> src_pack;\n    cuda::layer_norm::Pack<SRC, N> bias_pack;\n    cuda::layer_norm::Pack<SRC, N> skip_pack;\n    const int64_t offset = (row * row_size + col) / N;\n    const int64_t bias_offset = col / N;\n    src_pack.storage = *(reinterpret_cast<const cuda::layer_norm::PackType<SRC, N>*>(src) + offset);\n    if (bias) {\n      bias_pack.storage =\n          *(reinterpret_cast<const cuda::layer_norm::PackType<SRC, N>*>(bias) + bias_offset);\n    } else {\n#pragma unroll\n      for (int i = 0; i < N; ++i) { bias_pack.elem[i] = static_cast<SRC>(0.f); }\n    }\n    if (skip) {\n      skip_pack.storage =\n          *(reinterpret_cast<const cuda::layer_norm::PackType<SRC, N>*>(skip) + offset);\n    } else {\n#pragma unroll\n      for (int i = 0; i < N; ++i) { skip_pack.elem[i] = static_cast<SRC>(0.f); }\n    }\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      dst[i] = static_cast<DST>(src_pack.elem[i] + bias_pack.elem[i]\n                                + skip_pack.elem[i] * static_cast<SRC>(alpha));\n    }\n  }\n  const SRC* src;\n  const SRC* bias;\n  const SRC* skip;\n  double alpha;\n  int64_t row_size;\n};\n\ntemplate<typename SRC, typename DST, bool do_scale, bool do_center>\nstruct AffineStore {\n  AffineStore(DST* y, int64_t row_size, const DST* gamma, const DST* beta)\n      : y(y), row_size(row_size), gamma(gamma), beta(beta) {}\n  template<int N>\n  __device__ void store(const SRC* src, int64_t row, int64_t col) {\n    cuda::layer_norm::Pack<DST, N> y_pack;\n    cuda::layer_norm::Pack<DST, N> gamma_pack;\n    cuda::layer_norm::Pack<DST, N> beta_pack;\n    const int64_t offset = (row * row_size + col) / N;\n    const int64_t gamma_offset = col / N;\n    if (do_scale) {\n      gamma_pack.storage =\n          *(reinterpret_cast<const cuda::layer_norm::PackType<DST, N>*>(gamma) + gamma_offset);\n    } else {\n#pragma unroll\n      for (int i = 0; i < N; 
++i) { gamma_pack.elem[i] = static_cast<DST>(1.f); }\n    }\n    if (do_center) {\n      beta_pack.storage =\n          *(reinterpret_cast<const cuda::layer_norm::PackType<DST, N>*>(beta) + gamma_offset);\n    } else {\n#pragma unroll\n      for (int i = 0; i < N; ++i) { beta_pack.elem[i] = static_cast<DST>(0.f); }\n    }\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      DST normalized_i = static_cast<DST>(src[i]);\n      if (do_scale || do_center) {\n        y_pack.elem[i] = normalized_i * gamma_pack.elem[i] + beta_pack.elem[i];\n      } else {\n        y_pack.elem[i] = normalized_i;\n      }\n    }\n    *(reinterpret_cast<cuda::layer_norm::PackType<DST, N>*>(y) + offset) = y_pack.storage;\n  }\n  DST* y;\n  int64_t row_size;\n  const DST* gamma;\n  const DST* beta;\n};\n\ntemplate<typename T, bool do_scale, bool do_center>\nvoid LaunchSkipLayerNormForwardGpu(ep::Stream* stream, const int64_t num_instances,\n                                   const int64_t norm_size, const double epsilon, const T* x_ptr,\n                                   const T* gamma_ptr, const T* beta_ptr, const T* bias_ptr,\n                                   const T* skip_ptr, const double alpha, T* y_ptr,\n                                   user_op::Tensor* mean, user_op::Tensor* inv_variance) {\n  constexpr int32_t block_size = 128;\n  unsigned int nb_element = norm_size * num_instances;\n  unsigned int grid_size = (nb_element + block_size - 1) / block_size;\n  using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;\n  SkipLoad<T, T> load(x_ptr, bias_ptr, skip_ptr, alpha, norm_size);\n  AffineStore<ComputeType, T, do_scale, do_center> store(y_ptr, norm_size, gamma_ptr, beta_ptr);\n  cuda::layer_norm::DispatchLayerNorm<decltype(load), decltype(store), ComputeType>(\n      stream->As<ep::CudaStream>()->cuda_stream(), load, store, num_instances, norm_size, epsilon,\n      mean->mut_dptr<ComputeType>(), inv_variance->mut_dptr<ComputeType>());\n}\n\ntemplate<typename T>\nvoid DispatchSkipLayerNormForwardGpu(ep::Stream* stream, const int64_t num_instances,\n                                     const int64_t norm_size, const double epsilon, const T* x_ptr,\n                                     const T* gamma_ptr, const T* beta_ptr, const T* bias_ptr,\n                                     const T* skip_ptr, const double alpha, T* y_ptr,\n                                     user_op::Tensor* mean, user_op::Tensor* inv_variance) {\n#define LAUNCH_GPU_KERNEL(has_gamma, has_beta)                                                   \\\n  LaunchSkipLayerNormForwardGpu<T, has_gamma, has_beta>(                                         \\\n      stream, num_instances, norm_size, epsilon, x_ptr, gamma_ptr, beta_ptr, bias_ptr, skip_ptr, \\\n      alpha, y_ptr, mean, inv_variance);\n\n  if (gamma_ptr != nullptr && beta_ptr != nullptr) {\n    LAUNCH_GPU_KERNEL(true, true);\n  } else if (gamma_ptr != nullptr && beta_ptr == nullptr) {\n    LAUNCH_GPU_KERNEL(true, false);\n  } else if (gamma_ptr == nullptr && beta_ptr != nullptr) {\n    LAUNCH_GPU_KERNEL(false, true);\n  } else {\n    LAUNCH_GPU_KERNEL(false, false);\n  }\n\n#undef LAUNCH_GPU_KERNEL\n}\n\ntemplate<typename T>\nclass SkipLayerNormGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  SkipLayerNormGpuKernel() = default;\n  ~SkipLayerNormGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void 
Compute(user_op::KernelComputeContext* ctx) const override {\n    // obtain x and check its shape\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const ShapeView& x_shape = x->shape_view();\n    CHECK_GE(x_shape.NumAxes(), 2)\n        << \"number of axes of \\'x\\' should be greater than or equal to 2, but got \"\n        << x_shape.NumAxes();\n\n    // obtain gamma and check its shape\n    const T* gamma_ptr = nullptr;\n    ShapeView gamma_shape;\n    if (ctx->has_input(\"gamma\", 0)) {\n      const user_op::Tensor* gamma = ctx->Tensor4ArgNameAndIndex(\"gamma\", 0);\n      gamma_shape = gamma->shape_view();\n      gamma_ptr = gamma->dptr<T>();\n      CHECK_EQ(gamma_shape.NumAxes(), 1)\n          << \"number of axes of \\'gamma\\' should be equal to 1, but got \" << gamma_shape.NumAxes();\n      CHECK_EQ(gamma_shape.At(0), x_shape.At(x_shape.NumAxes() - 1))\n          << \"the size of \\'gamma\\'(\" << gamma_shape.At(0)\n          << \") is not consistent with the last dimension of \\'x\\'(\"\n          << x_shape.At(x_shape.NumAxes() - 1) << \")\";\n    }\n\n    // obtain beta and check its shape\n    const T* beta_ptr = nullptr;\n    ShapeView beta_shape;\n    if (ctx->has_input(\"beta\", 0)) {\n      const user_op::Tensor* beta = ctx->Tensor4ArgNameAndIndex(\"beta\", 0);\n      beta_shape = beta->shape_view();\n      beta_ptr = beta->dptr<T>();\n      CHECK_EQ(beta_shape.NumAxes(), 1)\n          << \"number of axes of \\'beta\\' should be equal to 1, but got \" << beta_shape.NumAxes();\n      CHECK_EQ(beta_shape.At(0), x_shape.At(x_shape.NumAxes() - 1))\n          << \"the size of \\'beta\\'(\" << beta_shape.At(0)\n          << \") is not consistent with the last dimension of \\'x\\'(\"\n          << x_shape.At(x_shape.NumAxes() - 1) << \")\";\n    }\n\n    // obtain bias and check its shape\n    const T* bias_ptr = nullptr;\n    ShapeView bias_shape;\n    if (ctx->has_input(\"bias\", 0)) {\n      const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex(\"bias\", 0);\n      bias_shape = bias->shape_view();\n      bias_ptr = bias->dptr<T>();\n      CHECK_EQ(bias_shape.NumAxes(), 1)\n          << \"number of axes of \\'bias\\' should be equal to 1, but got \" << bias_shape.NumAxes();\n      CHECK_EQ(bias_shape.At(0), x_shape.At(x_shape.NumAxes() - 1))\n          << \"the size of \\'bias\\'(\" << bias_shape.At(0)\n          << \") is not consistent with the last dimension of \\'x\\'(\"\n          << x_shape.At(x_shape.NumAxes() - 1) << \")\";\n    }\n\n    // obtain residual and check its shape\n    const T* skip_ptr = nullptr;\n    ShapeView skip_shape;\n    if (ctx->has_input(\"skip\", 0)) {\n      const user_op::Tensor* skip = ctx->Tensor4ArgNameAndIndex(\"skip\", 0);\n      skip_shape = skip->shape_view();\n      skip_ptr = skip->dptr<T>();\n      CHECK_EQ(skip_shape, x_shape);\n    }\n\n    // obtain the epsilon and alpha attributes\n    const double epsilon = ctx->Attr<double>(\"epsilon\");\n    const double alpha = ctx->Attr<double>(\"alpha\");\n\n    // obtain output tensors\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex(\"mean\", 0);\n    user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex(\"inv_variance\", 0);\n    const ShapeView& y_shape = y->shape_view();\n    const ShapeView& mean_shape = mean->shape_view();\n    const ShapeView& inv_variance_shape = inv_variance->shape_view();\n\n    // calculate number of instances and norm size\n    const int64_t num_instances = 
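/* one mean element per normalized instance */ 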
mean->shape_view().elem_cnt();\n    const int64_t norm_size = x->shape_view().elem_cnt() / num_instances;\n\n    // dispatch kernel\n    DispatchSkipLayerNormForwardGpu<T>(ctx->stream(), num_instances, norm_size, epsilon,\n                                       x->dptr<T>(), gamma_ptr, beta_ptr, bias_ptr, skip_ptr, alpha,\n                                       y->mut_dptr<T>(), mean, inv_variance);\n  }\n};\n\n}  // namespace\n\n#define REGISTER_SKIP_LAYER_NORM_CUDA_KERNEL(dtype)                    \\\n  REGISTER_USER_KERNEL(\"skip_layer_norm\")                              \\\n      .SetCreateFn<SkipLayerNormGpuKernel<dtype>>()                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\nREGISTER_SKIP_LAYER_NORM_CUDA_KERNEL(float)\nREGISTER_SKIP_LAYER_NORM_CUDA_KERNEL(double)\nREGISTER_SKIP_LAYER_NORM_CUDA_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_SKIP_LAYER_NORM_CUDA_KERNEL(nv_bfloat16)\n#endif\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/skip_rms_norm_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cudnn_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include <cub/cub.cuh>\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/fill.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/cuda/rms_norm.cuh\"\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\n#include \"oneflow/core/device/cuda_pseudo_bfloat16.h\"\n\nnamespace oneflow {\n\nnamespace cuda {\n\nnamespace rms_norm {\n\ntemplate<typename SRC, typename DST>\nstruct SkipLoad {\n  using LoadType = DST;\n  SkipLoad(const SRC* src, const SRC* bias, const SRC* skip, const float alpha, int64_t row_size)\n      : src(src), bias(bias), skip(skip), alpha(alpha), row_size(row_size) {}\n  template<int N>\n  __device__ void load(DST* dst, int64_t row, int64_t col) const {\n    layer_norm::Pack<SRC, N> src_pack;\n    layer_norm::Pack<SRC, N> bias_pack;\n    layer_norm::Pack<SRC, N> skip_pack;\n    const int64_t offset = (row * row_size + col) / N;\n    const int64_t bias_offset = col / N;\n    src_pack.storage = *(reinterpret_cast<const layer_norm::PackType<SRC, N>*>(src) + offset);\n    if (bias) {\n      bias_pack.storage =\n          *(reinterpret_cast<const layer_norm::PackType<SRC, N>*>(bias) + bias_offset);\n    } else {\n#pragma unroll\n      for (int i = 0; i < N; ++i) { bias_pack.elem[i] = static_cast<SRC>(0.f); }\n    }\n    if (skip) {\n      skip_pack.storage = *(reinterpret_cast<const layer_norm::PackType<SRC, N>*>(skip) + offset);\n    } else {\n#pragma unroll\n      for (int i = 0; i < N; ++i) { skip_pack.elem[i] = static_cast<SRC>(0.f); }\n    }\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      dst[i] = static_cast<DST>(src_pack.elem[i] + bias_pack.elem[i]\n                                + skip_pack.elem[i] * static_cast<SRC>(alpha));\n    }\n  }\n  const SRC* src;\n  const SRC* bias;\n  const SRC* skip;\n  float alpha;\n  int64_t row_size;\n};\n\ntemplate<typename SRC, typename DST, bool affine>\nstruct AffineStore {\n  AffineStore(DST* dst, const DST* weight, int32_t row_size)\n      : dst(dst), weight(weight), row_size(row_size) {}\n\n  template<int N>\n  __device__ void store(const SRC* src, int32_t row, int32_t col) {\n    layer_norm::Pack<DST, N> dst_pack;\n    layer_norm::Pack<DST, N> weight_pack;\n    const int32_t offset = (row * row_size + col) / N;\n    const int32_t weight_offset = col / N;\n    if (affine) {\n      weight_pack.storage =\n          *(reinterpret_cast<const layer_norm::PackType<DST, N>*>(weight) + weight_offset);\n    }\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      if (affine) {\n        dst_pack.elem[i] = static_cast<DST>(src[i]) * weight_pack.elem[i];\n      } else {\n   
     dst_pack.elem[i] = static_cast<DST>(src[i]);\n      }\n    }\n    *(reinterpret_cast<layer_norm::PackType<DST, N>*>(dst) + offset) = dst_pack.storage;\n  }\n\n  DST* dst;\n  const DST* weight;\n  int32_t row_size;\n};\n\ntemplate<typename T, typename ComputeType, bool affine>\nvoid DispatchSkipRmsNormForwardAffine(ep::Stream* stream, const int64_t nrow, const int64_t ncol,\n                                      const double eps, const double alpha, const T* x_dptr,\n                                      const T* w_dptr, const T* skip_dptr, const T* bias_dptr,\n                                      T* y_dptr, ComputeType* inv_rms) {\n  constexpr int32_t block_size = 128;\n  unsigned int nb_element = nrow * ncol;\n  unsigned int grid_size = (nb_element + block_size - 1) / block_size;\n  SkipLoad<T, ComputeType> load(x_dptr, bias_dptr, skip_dptr, alpha, ncol);\n  AffineStore<ComputeType, T, affine> store(y_dptr, w_dptr, ncol);\n  OF_CUDA_CHECK((LaunchRmsNorm<decltype(load), decltype(store), ComputeType>(\n      stream->As<ep::CudaStream>()->cuda_stream(), load, store, nrow, ncol, eps, inv_rms)));\n}\n\ntemplate<typename T, typename ComputeType>\nvoid SkipRmsNormForward(ep::Stream* stream, const int64_t nrow, const int64_t ncol,\n                        const double eps, const double alpha, const T* x_dptr, const T* w_dptr,\n                        const T* skip_dptr, const T* bias_dptr, T* y_dptr, ComputeType* inv_rms) {\n  if (w_dptr) {\n    DispatchSkipRmsNormForwardAffine<T, ComputeType, true>(\n        stream, nrow, ncol, eps, alpha, x_dptr, w_dptr, skip_dptr, bias_dptr, y_dptr, inv_rms);\n  } else {\n    DispatchSkipRmsNormForwardAffine<T, ComputeType, false>(\n        stream, nrow, ncol, eps, alpha, x_dptr, w_dptr, skip_dptr, bias_dptr, y_dptr, inv_rms);\n  }\n}\n\n}  // namespace rms_norm\n\ntemplate<typename T>\nclass SkipRmsNormGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  SkipRmsNormGpuKernel() = default;\n  ~SkipRmsNormGpuKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    // obtain x and check its shape\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const ShapeView& x_shape = x->shape_view();\n    CHECK_GE(x_shape.NumAxes(), 2)\n        << \"number of axes of \\'x\\' should be greater than or equal to 2, but got \"\n        << x_shape.NumAxes();\n\n    // obtain weight and check its shape\n    const T* weight_ptr = nullptr;\n    ShapeView weight_shape;\n    if (ctx->has_input(\"weight\", 0)) {\n      const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex(\"weight\", 0);\n      weight_shape = weight->shape_view();\n      weight_ptr = weight->dptr<T>();\n      CHECK_EQ(weight_shape.NumAxes(), 1)\n          << \"number of axes of \\'weight\\' should be equal to 1, but got \"\n          << weight_shape.NumAxes();\n      CHECK_EQ(weight_shape.At(0), x_shape.At(x_shape.NumAxes() - 1))\n          << \"the size of \\'weight\\'(\" << weight_shape.At(0)\n          << \") is not consistent with the last dimension of \\'x\\'(\"\n          << x_shape.At(x_shape.NumAxes() - 1) << \")\";\n    }\n\n    // obtain bias and check its shape\n    const T* bias_ptr = nullptr;\n    ShapeView bias_shape;\n    if (ctx->has_input(\"bias\", 0)) {\n      const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex(\"bias\", 0);\n      bias_shape = 
bias->shape_view();\n      bias_ptr = bias->dptr<T>();\n      CHECK_EQ(bias_shape.NumAxes(), 1)\n          << \"number of axes of \\'bias\\' should be equal to 1, but got \" << bias_shape.NumAxes();\n      CHECK_EQ(bias_shape.At(0), x_shape.At(x_shape.NumAxes() - 1))\n          << \"the size of \\'bias\\'(\" << bias_shape.At(0)\n          << \") is not consistent with the last dimension of \\'x\\'(\"\n          << x_shape.At(x_shape.NumAxes() - 1) << \")\";\n    }\n\n    // obtain skip and check its shape\n    const T* skip_ptr = nullptr;\n    ShapeView skip_shape;\n    if (ctx->has_input(\"skip\", 0)) {\n      const user_op::Tensor* skip = ctx->Tensor4ArgNameAndIndex(\"skip\", 0);\n      skip_shape = skip->shape_view();\n      skip_ptr = skip->dptr<T>();\n      CHECK_EQ(skip_shape, x_shape);\n    }\n\n    // obtain the epsilon and alpha attributes\n    const double epsilon = ctx->Attr<double>(\"epsilon\");\n    const double alpha = ctx->Attr<double>(\"alpha\");\n\n    // obtain output tensors\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* inv_rms = ctx->Tensor4ArgNameAndIndex(\"inv_rms\", 0);\n    const ShapeView& y_shape = y->shape_view();\n    const ShapeView& inv_rms_shape = inv_rms->shape_view();\n\n    // calculate number of instances and norm size\n    const int64_t nrow = inv_rms->shape_view().elem_cnt();\n    const int64_t ncol = x->shape_view().elem_cnt() / nrow;\n\n    // dispatch kernel\n    using ComputeType = typename layer_norm::DefaultComputeType<T>::type;\n    rms_norm::SkipRmsNormForward(ctx->stream(), nrow, ncol, epsilon, alpha, x->dptr<T>(),\n                                 weight_ptr, skip_ptr, bias_ptr, y->mut_dptr<T>(),\n                                 inv_rms->mut_dptr<ComputeType>());\n  }\n};\n\n#define REGISTER_SKIP_RMS_NORM_CUDA_KERNEL(dtype)                      \\\n  REGISTER_USER_KERNEL(\"skip_rms_norm\")                                \\\n      .SetCreateFn<SkipRmsNormGpuKernel<dtype>>()                      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\nREGISTER_SKIP_RMS_NORM_CUDA_KERNEL(float)\nREGISTER_SKIP_RMS_NORM_CUDA_KERNEL(double)\nREGISTER_SKIP_RMS_NORM_CUDA_KERNEL(half)\n#if CUDA_VERSION >= 11000\nREGISTER_SKIP_RMS_NORM_CUDA_KERNEL(nv_bfloat16)\n#endif\n\n}  // namespace cuda\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/slice_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/common/switch_func.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/slice_util.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconst int SPLIT_AXIS_FOR_NON_SPLIT = -1;\n\n// [start, end)\nint64_t GetSizeInSlice(const int64_t start, const int64_t end, const int64_t step) {\n  if (end <= start) { return 0; }\n  return (end - start - 1) / step + 1;\n}\n\nclass SliceContext final {\n public:\n  struct SplitInfo {\n    // These fields shows how the logical tensor is split.\n    // The logical tensor is split on the axis `split_axis`\n    // The physical tensor on current device is in the range [lower, upper)\n    // The length of the logical tensor on `split_axis` is `logical_length`\n    // Example:\n    // Variable shape = (8, 7, 6, 5), sbp = S(0), on 4 devices, then on the first card:\n    // split_axis = 0\n    // lower = 0\n    // upper = 2\n    // logical_length = 8\n    const int64_t split_axis;\n    const int64_t lower;\n    const int64_t upper;\n    const int64_t logical_length;\n  };\n\n  SliceContext() : axis_bitset_(0) {}\n\n  Maybe<void> PushSplitInfo(int64_t split_axis, int64_t lower, int64_t upper,\n                            int64_t logical_length) {\n    if (split_axis != SPLIT_AXIS_FOR_NON_SPLIT) {\n      // split_axis can only be push once\n      CHECK_OR_RETURN(!IsAxisPushed(split_axis))\n          << \"split_axis \" << split_axis << \" has been pushed to SliceContext\";\n      CHECK_GE_OR_RETURN(split_axis, 0) << \"split_axis >= 0 or equal to SPLIT_AXIS_FOR_NON_SPLIT\";\n\n      axis_bitset_ |= ((uint32_t)1 << split_axis);  // NOLINT\n    }\n    split_info_vec_.emplace_back(SplitInfo{split_axis, lower, upper, logical_length});\n    return Maybe<void>::Ok();\n  }\n  const std::vector<SplitInfo>& GetSplitInfo() const { return split_info_vec_; }\n  bool IsAxisPushed(int64_t split_axis) const {\n    if (split_axis == SPLIT_AXIS_FOR_NON_SPLIT) { return false; }\n    CHECK_GE(split_axis, 0) << \"split_axis >= 0 or equal to SPLIT_AXIS_FOR_NON_SPLIT\";\n    return (axis_bitset_ & ((uint32_t)1 << split_axis)) != 0;  // NOLINT\n  }\n\n private:\n  std::vector<SplitInfo> split_info_vec_;\n  uint32_t axis_bitset_;\n};\n\nvoid ConstructSliceParamsLarge(const SliceContext& ctx, const std::vector<int64_t>& start_vec,\n                               const std::vector<int64_t>& stop_vec,\n                               const std::vector<int64_t>& step_vec, const ShapeView& shape,\n                               SliceParams* slice_param) {\n  const int64_t ndim = shape.NumAxes();\n  CHECK_LE(ndim, kSliceMaxDims);\n  
CHECK_EQ(start_vec.size(), ndim);\n  CHECK_EQ(stop_vec.size(), ndim);\n  CHECK_EQ(step_vec.size(), ndim);\n\n  slice_param->ndim = ndim;\n  FOR_RANGE(int, i, 0, slice_param->ndim) {\n    const int64_t dim_size = shape.At(i);\n    const int64_t start_in_full_large = start_vec.at(i);\n    const int64_t stop_in_full_large = stop_vec.at(i);\n    const int64_t step = step_vec.at(i);\n    CHECK_GT(step, 0);\n    int64_t start_in_splitted_large = start_in_full_large;\n    int64_t stop_in_splitted_large = stop_in_full_large;\n    // large tensor has split sbp attribute\n    for (const auto& split_info : ctx.GetSplitInfo()) {\n      if (split_info.split_axis == i) {\n        if (start_in_splitted_large < split_info.lower) {\n          start_in_splitted_large =\n              split_info.lower\n              + (step - (split_info.lower - start_in_splitted_large) % step) % step;\n        }\n        start_in_splitted_large =\n            std::min(std::max(start_in_splitted_large, split_info.lower), split_info.upper);\n        stop_in_splitted_large =\n            std::min(std::max(stop_in_splitted_large, split_info.lower), split_info.upper);\n        start_in_splitted_large -= split_info.lower;\n        stop_in_splitted_large -= split_info.lower;\n      }\n    }\n    const int64_t slice_size =\n        GetSizeInSlice(start_in_splitted_large, stop_in_splitted_large, step);\n    slice_param->dims[i] = dim_size;\n    slice_param->start[i] = start_in_splitted_large;\n    slice_param->step[i] = step;\n    slice_param->size[i] = slice_size;\n  }\n}\n\nvoid ConstructSliceParamsSmall(const SliceContext& ctx, const std::vector<int64_t>& start_vec,\n                               const std::vector<int64_t>& stop_vec,\n                               const std::vector<int64_t>& step_vec, const ShapeView& shape,\n                               SliceParams* slice_param) {\n  const int64_t ndim = shape.NumAxes();\n  CHECK_LE(ndim, kSliceMaxDims);\n  CHECK_EQ(start_vec.size(), ndim);\n  CHECK_EQ(stop_vec.size(), ndim);\n  CHECK_EQ(step_vec.size(), ndim);\n\n  slice_param->ndim = ndim;\n  FOR_RANGE(int, i, 0, slice_param->ndim) {\n    const int64_t start_in_full_large = start_vec.at(i);\n    const int64_t step = step_vec.at(i);\n    CHECK_GT(step, 0);\n    // small tensor has broadcast/partialsum sbp attribute\n    const int64_t dim_size = shape.At(i);\n    int64_t start_in_full_small = 0;\n    int64_t stop_in_full_small = dim_size;\n    for (const auto& split_info : ctx.GetSplitInfo()) {\n      if (split_info.split_axis == i) {\n        start_in_full_small = GetSizeInSlice(start_in_full_large, split_info.lower, step);\n        stop_in_full_small = GetSizeInSlice(start_in_full_large, split_info.upper, step);\n        start_in_full_small = std::min(std::max<int64_t>(start_in_full_small, 0), dim_size);\n        stop_in_full_small = std::min(std::max<int64_t>(stop_in_full_small, 0), dim_size);\n      }\n    }\n    const int64_t slice_size = stop_in_full_small - start_in_full_small;\n    slice_param->dims[i] = dim_size;\n    slice_param->start[i] = start_in_full_small;\n    slice_param->step[i] = 1;\n    slice_param->size[i] = slice_size;\n  }\n}\n\nSliceParams ConstructSliceParams(user_op::KernelComputeContext* ctx, const user_op::Tensor* entire,\n                                 const user_op::Tensor* sliced) {\n  const auto& start_vec = ctx->Attr<std::vector<int64_t>>(\"start\");\n  const auto& stop_vec = ctx->Attr<std::vector<int64_t>>(\"stop\");\n  const auto& step_vec = ctx->Attr<std::vector<int64_t>>(\"step\");\n  
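// validate the slice attributes against the rank of the entire tensor\n  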
const int64_t ndim = entire->shape_view().NumAxes();\n  CHECK_LE(ndim, kSliceMaxDims);\n  if (entire->shape_view().NumAxes() == 1) {\n    CHECK_LE(sliced->shape_view().NumAxes(), 1);\n  } else {\n    CHECK_EQ(sliced->shape_view().NumAxes(), ndim);\n  }\n  CHECK_EQ(start_vec.size(), ndim);\n  CHECK_EQ(stop_vec.size(), ndim);\n  CHECK_EQ(step_vec.size(), ndim);\n\n  SliceParams params;\n  if (entire->shape_view().NumAxes() == 1 && sliced->shape_view().NumAxes() == 0) {\n    params.ndim = ndim;\n    params.dims[0] = entire->shape_view().At(0);\n    params.start[0] = RegulateSliceStart(start_vec.at(0), entire->shape_view().At(0));\n    params.step[0] = step_vec.at(0);\n    params.size[0] = 1;\n    return params;\n  }\n  params.ndim = ndim;\n  FOR_RANGE(int, i, 0, params.ndim) {\n    const int64_t dim_size = entire->shape_view().At(i);\n    const int64_t slice_size = sliced->shape_view().At(i);\n    const int64_t step = step_vec.at(i);\n    CHECK_NE(step, 0);\n    const int64_t start = RegulateSliceStart(start_vec.at(i), dim_size);\n    const int64_t stop = RegulateSliceStop(stop_vec.at(i), dim_size);\n    if (step > 0) {\n      CHECK_LT(start + step * (slice_size - 1), stop);\n    } else {\n      CHECK_GT(start + step * (slice_size - 1), stop);\n    }\n    params.dims[i] = dim_size;\n    params.start[i] = start;\n    params.step[i] = step;\n    params.size[i] = slice_size;\n  }\n  return params;\n}\n\n}  // namespace\n\ntemplate<DeviceType device_type, typename T>\nvoid WriteSlice(user_op::KernelComputeContext* ctx, const user_op::Tensor* src,\n                user_op::Tensor* dst, const SliceContext& slice_ctx,\n                const bool from_large_to_small) {\n  const user_op::Tensor* large = from_large_to_small ? src : dst;\n  const user_op::Tensor* small = from_large_to_small ? 
dst : src;\n  // Check physical tensor's shape\n  for (const auto& split_info : slice_ctx.GetSplitInfo()) {\n    if (split_info.split_axis != SPLIT_AXIS_FOR_NON_SPLIT) {\n      CHECK_EQ(large->shape_view().At(split_info.split_axis), split_info.upper - split_info.lower)\n          << \"split_info shape does not match the physical tensor shape\";\n    }\n  }\n\n  const std::vector<int64_t> start_attr = ctx->Attr<std::vector<int64_t>>(\"start\");\n  const std::vector<int64_t> stop_attr = ctx->Attr<std::vector<int64_t>>(\"stop\");\n  const std::vector<int64_t> step_attr = ctx->Attr<std::vector<int64_t>>(\"step\");\n  const int64_t ndim = start_attr.size();\n  std::vector<int64_t> positive_start_vec(ndim);\n  std::vector<int64_t> positive_stop_vec(ndim);\n\n  // recover the logical dims, then regulate start/stop against them\n  std::vector<int64_t> logical_dims(ndim);\n  {\n    for (int i = 0; i < ndim; i++) {\n      if (!slice_ctx.IsAxisPushed(i)) {\n        // axis is not split, so the logical shape equals the physical shape\n        logical_dims[i] = large->shape_view().At(i);\n      }\n    }\n    for (const auto& split_info : slice_ctx.GetSplitInfo()) {\n      if (split_info.split_axis != SPLIT_AXIS_FOR_NON_SPLIT) {\n        logical_dims[split_info.split_axis] = split_info.logical_length;\n      }\n    }\n  }\n  for (int i = 0; i < ndim; i++) {\n    positive_start_vec[i] = RegulateSliceStart(start_attr[i], logical_dims[i]);\n    positive_stop_vec[i] = RegulateSliceStop(stop_attr[i], logical_dims[i]);\n  }\n\n  SliceParams large_slice_param;\n  std::copy(large->stride().begin(), large->stride().end(), large_slice_param.stride);\n  SliceParams small_slice_param;\n  std::copy(small->stride().begin(), small->stride().end(), small_slice_param.stride);\n  ConstructSliceParamsLarge(slice_ctx, positive_start_vec, positive_stop_vec, step_attr,\n                            large->shape_view(), &large_slice_param);\n  ConstructSliceParamsSmall(slice_ctx, positive_start_vec, positive_stop_vec, step_attr,\n                            small->shape_view(), &small_slice_param);\n  CHECK_EQ(large_slice_param.elem_cnt(), small_slice_param.elem_cnt());\n  if (large_slice_param.ndim == 0 && small_slice_param.ndim == 0) {\n    // Copy data directly for scalar tensor\n    AutoMemcpy(ctx->stream(), dst->mut_dptr<T>(), src->dptr<T>(), sizeof(T), src->mem_case(),\n               dst->mem_case());\n    return;\n  }\n  if (from_large_to_small) {\n    if (small_slice_param.elem_cnt() == small->shape_view().elem_cnt()) {\n      SliceKernelUtil<device_type, T>::Forward(ctx->stream(), large_slice_param, src->dptr<T>(),\n                                               dst->mut_dptr<T>());\n    } else {\n      AutoMemset(ctx->stream(), dst->mut_dptr(), 0,\n                 dst->shape_view().elem_cnt() * GetSizeOfDataType(dst->data_type()),\n                 dst->mem_case());\n      SliceKernelUtil<device_type, T>::Forward(ctx->stream(), large_slice_param, small_slice_param,\n                                               src->dptr<T>(), dst->mut_dptr<T>());\n    }\n  } else {\n    SliceKernelUtil<device_type, T>::Forward(ctx->stream(), small_slice_param, large_slice_param,\n                                             src->dptr<T>(), dst->mut_dptr<T>());\n  }\n}\n\ntemplate<DeviceType device_type, typename T>\nclass SliceKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  SliceKernel() = default;\n  ~SliceKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n 
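   // build the SliceContext that records how the logical tensor is split on this rank\n 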
   SliceContext slice_ctx;\n    if (ctx->parallel_ctx().parallel_num() == 1) {\n      // split_axis == SPLIT_AXIS_FOR_NON_SPLIT means the sbp attribute is not 'split'\n      CHECK_JUST(slice_ctx.PushSplitInfo(SPLIT_AXIS_FOR_NON_SPLIT, 0, 0, 0));\n    } else {\n      const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();\n      NdSbp in_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"x\", 0);\n      {\n        const NdSbp& y_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"y\", 0);\n        // If x and y are both split along the same axis (it must be a full slice),\n        // we can treat the physical tensor as broadcast along this axis.\n        FOR_RANGE(int32_t, i, 0, parallel_hierarchy.NumAxes()) {\n          const SbpParallel& x_sbp = in_nd_sbp.sbp_parallel(i);\n          const SbpParallel& y_sbp = y_nd_sbp.sbp_parallel(i);\n          if (x_sbp.has_split_parallel() && y_sbp.has_split_parallel()) {\n            CHECK_EQ(x_sbp.split_parallel().axis(), y_sbp.split_parallel().axis());\n            in_nd_sbp.mutable_sbp_parallel(i)->clear_split_parallel();\n            in_nd_sbp.mutable_sbp_parallel(i)->mutable_broadcast_parallel();\n          }\n        }\n      }\n      const Shape& logical_shape = ctx->LogicalTensorDesc4ArgNameAndIndex(\"x\", 0)->shape();\n      const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n      const TensorSliceView& slice_view =\n          GetTensorSliceView4ParallelId(parallel_hierarchy, in_nd_sbp, logical_shape, parallel_id);\n      for (int i = 0; i < logical_shape.NumAxes(); ++i) {\n        const Range& range = slice_view.At(i);\n        if (range.begin() != 0 || range.end() != logical_shape.At(i)) {\n          CHECK_JUST(slice_ctx.PushSplitInfo(i, range.begin(), range.end(), logical_shape.At(i)));\n        }\n      }\n    }\n    return std::make_shared<OpKernelCacheWrapper<SliceContext>>(slice_ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    if (y_tensor->shape_view().elem_cnt() == 0) { return; }\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const SliceContext& slice_ctx =\n        dynamic_cast<const OpKernelCacheWrapper<SliceContext>*>(cache)->Get();\n    WriteSlice<device_type, T>(ctx, x_tensor, y_tensor, slice_ctx, /*from_large_to_small=*/true);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass SliceUpdateKernel final : public user_op::OpKernel {\n public:\n  SliceUpdateKernel() = default;\n  ~SliceUpdateKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    SliceContext slice_ctx;\n    if (ctx->parallel_ctx().parallel_num() == 1) {\n      // split_axis == SPLIT_AXIS_FOR_NON_SPLIT means the sbp attribute is not 'split'\n      CHECK_JUST(slice_ctx.PushSplitInfo(SPLIT_AXIS_FOR_NON_SPLIT, 0, 0, 0));\n    } else {\n      const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();\n      NdSbp ref_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"ref\", 0);\n      {\n        const NdSbp& value_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"value\", 0);\n        // If ref and value are both split along the same axis (a full slice),\n        // we can treat the physical tensor as broadcast along this axis.\n        for (int i = 0; i < parallel_hierarchy.NumAxes(); ++i) {\n        
  const SbpParallel& ref_sbp = ref_nd_sbp.sbp_parallel(i);\n          const SbpParallel& value_sbp = value_nd_sbp.sbp_parallel(i);\n          if (ref_sbp.has_split_parallel() && value_sbp.has_split_parallel()) {\n            CHECK_EQ(ref_sbp.split_parallel().axis(), value_sbp.split_parallel().axis());\n            ref_nd_sbp.mutable_sbp_parallel(i)->clear_split_parallel();\n            ref_nd_sbp.mutable_sbp_parallel(i)->mutable_broadcast_parallel();\n          }\n        }\n      }\n      const Shape& logical_shape = ctx->LogicalTensorDesc4ArgNameAndIndex(\"ref\", 0)->shape();\n      const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n      const TensorSliceView& slice_view =\n          GetTensorSliceView4ParallelId(parallel_hierarchy, ref_nd_sbp, logical_shape, parallel_id);\n      for (int i = 0; i < logical_shape.NumAxes(); ++i) {\n        const Range& range = slice_view.At(i);\n        if (range.begin() != 0 || range.end() != logical_shape.At(i)) {\n          CHECK_JUST(slice_ctx.PushSplitInfo(i, range.begin(), range.end(), logical_shape.At(i)));\n        }\n      }\n    }\n    return std::make_shared<OpKernelCacheWrapper<SliceContext>>(slice_ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* value_tensor = ctx->Tensor4ArgNameAndIndex(\"value\", 0);\n    user_op::Tensor* ref_tensor = ctx->Tensor4ArgNameAndIndex(\"ref\", 0);\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    if (y_tensor->shape_view().elem_cnt() == 0) { return; }\n    // In eager execution, y_tensor shares the same memory as ref_tensor\n    if (ref_tensor->dptr<T>() != y_tensor->dptr<T>()) {\n      // lazy mode: copy ref into y first\n      AutoMemcpy(ctx->stream(), y_tensor->mut_dptr<T>(), ref_tensor->dptr<T>(),\n                 y_tensor->shape_view().elem_cnt() * sizeof(T), ref_tensor->mem_case(),\n                 y_tensor->mem_case());\n    }\n    const SliceContext& slice_ctx =\n        dynamic_cast<const OpKernelCacheWrapper<SliceContext>*>(cache)->Get();\n    WriteSlice<device_type, T>(ctx, value_tensor, y_tensor, slice_ctx,\n                               /*from_large_to_small=*/false);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\ntemplate<DeviceType device_type, typename T>\nclass SliceGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  SliceGradKernel() = default;\n  ~SliceGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    size_t dx_byte_size = dx_tensor->shape_view().elem_cnt() * sizeof(T);\n    Memset<device_type>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0, dx_byte_size);\n    if (dy_tensor->shape_view().elem_cnt() == 0) { return; }\n    SliceParams params = ConstructSliceParams(ctx, dx_tensor, dy_tensor);\n    SliceKernelUtil<device_type, T>::Backward(ctx->stream(), params, dy_tensor->dptr<T>(),\n                                              dx_tensor->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SLICE_KERNEL(device, dtype)                                               \\\n  REGISTER_USER_KERNEL(\"slice\").SetCreateFn<SliceKernel<device, dtype>>().SetIsMatchedHob( \\\n      
(user_op::HobDeviceType() == device)                                                 \\\n      && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));                     \\\n  REGISTER_USER_KERNEL(\"slice_grad\")                                                       \\\n      .SetCreateFn<SliceGradKernel<device, dtype>>()                                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));   \\\n  REGISTER_USER_KERNEL(\"slice_update\")                                                     \\\n      .SetCreateFn<SliceUpdateKernel<device, dtype>>()                                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                \\\n                       && (user_op::HobDataType(\"ref\", 0) == GetDataType<dtype>::value));\n\n#define REGISTER_SLICE_KERNEL_WITH_DEVICE(device) \\\n  REGISTER_SLICE_KERNEL(device, bool)             \\\n  REGISTER_SLICE_KERNEL(device, float16)          \\\n  REGISTER_SLICE_KERNEL(device, float)            \\\n  REGISTER_SLICE_KERNEL(device, double)           \\\n  REGISTER_SLICE_KERNEL(device, int32_t)          \\\n  REGISTER_SLICE_KERNEL(device, int64_t)          \\\n  REGISTER_SLICE_KERNEL(device, int8_t)           \\\n  REGISTER_SLICE_KERNEL(device, uint8_t)\n\nREGISTER_SLICE_KERNEL(DeviceType::kCPU, std::complex<float>)\nREGISTER_SLICE_KERNEL(DeviceType::kCPU, std::complex<double>)\n#ifdef WITH_CUDA\nREGISTER_SLICE_KERNEL(DeviceType::kCUDA, cuComplex)\nREGISTER_SLICE_KERNEL(DeviceType::kCUDA, cuDoubleComplex)\n#endif\n\nREGISTER_SLICE_KERNEL_WITH_DEVICE(DeviceType::kCPU)\nREGISTER_SLICE_KERNEL(DeviceType::kCPU, bfloat16)\n#ifdef WITH_CUDA\nREGISTER_SLICE_KERNEL_WITH_DEVICE(DeviceType::kCUDA)\n#if CUDA_VERSION >= 11000\nREGISTER_SLICE_KERNEL(DeviceType::kCUDA, nv_bfloat16)\n#endif\n#endif\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/slice_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/slice_util.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n\nnamespace oneflow {\n\nSliceParams FoldContiguousFullSliceDimensions(const SliceParams& params) {\n  SliceParams fold_slice_params;\n  bool full_slice_on_prev_axis = false;\n  FOR_RANGE(int, i, 0, params.ndim) {\n    bool full_slice_on_cur_axis = params.IsFullSlice(i);\n    if (full_slice_on_cur_axis && full_slice_on_prev_axis) {\n      int cur_dim = fold_slice_params.ndim - 1;\n      fold_slice_params.dims[cur_dim] *= params.dims[i];\n      fold_slice_params.size[cur_dim] *= params.size[i];\n    } else {\n      int cur_dim = fold_slice_params.ndim;\n      fold_slice_params.dims[cur_dim] = params.dims[i];\n      fold_slice_params.start[cur_dim] = params.start[i];\n      fold_slice_params.step[cur_dim] = params.step[i];\n      fold_slice_params.size[cur_dim] = params.size[i];\n      fold_slice_params.ndim += 1;\n    }\n    full_slice_on_prev_axis = full_slice_on_cur_axis;\n  }\n  return fold_slice_params;\n}\n\ntemplate<typename T>\nstruct SliceKernelUtil<DeviceType::kCPU, T> {\n  static void Forward(ep::Stream* stream, const SliceParams& params, const T* entire, T* sliced) {\n    SliceParams fold_slice_params = FoldContiguousFullSliceDimensions(params);\n    SwitchDoForward(SwitchCase(fold_slice_params.ndim), stream, fold_slice_params, entire, sliced);\n  }\n\n  static void Forward(ep::Stream* stream, const SliceParams& entire_params,\n                      const SliceParams& sliced_params, const T* entire, T* sliced) {\n    SwitchDoForward(SwitchCase(entire_params.ndim), stream, entire_params, sliced_params, entire,\n                    sliced);\n  }\n\n  static void Backward(ep::Stream* stream, const SliceParams& params, const T* sliced, T* entire) {\n    SliceParams fold_slice_params = FoldContiguousFullSliceDimensions(params);\n    SwitchDoBackward(SwitchCase(fold_slice_params.ndim), stream, fold_slice_params, sliced, entire);\n  }\n\n private:\n  template<int NDIM>\n  static void DoForward(ep::Stream* stream, const SliceParams& params, const T* entire, T* sliced) {\n    CHECK_EQ(params.ndim, NDIM);\n    int64_t elem_cnt = params.elem_cnt();\n    SliceIndexHelper<NDIM> entire_idx_cvtr(params.dims);\n    SliceIndexHelper<NDIM> sliced_idx_cvtr(params.size);\n    MultiThreadLoop(elem_cnt, [&](int64_t i) {\n      int64_t offset = SliceOffsetToEntireOffset<NDIM>(i, params, entire_idx_cvtr, sliced_idx_cvtr);\n      sliced[i] = entire[offset];\n    });\n  }\n\n  template<typename DoEachT>\n  static void SteppedMultiThreadLoop(size_t elem_cnt, size_t step, const DoEachT& DoEach) {\n    if (elem_cnt == 0) { return; }\n    CHECK_GT(step, 0);\n    CHECK_EQ(elem_cnt % step, 0);\n    MultiThreadLoop(elem_cnt / step, [&](size_t i) { DoEach(i * step); });\n  }\n\n  template<int NDIM>\n  static void DoForward(ep::Stream* stream, const SliceParams& 
entire_params,\n                        const SliceParams& sliced_params, const T* entire, T* sliced) {\n    CHECK_EQ(entire_params.ndim, NDIM);\n    CHECK_EQ(sliced_params.ndim, NDIM);\n    int64_t elem_cnt = entire_params.elem_cnt();\n    SliceIndexHelper<NDIM> entire_splitted_large_idx_cvtr =\n        NdIndexStrideOffsetHelper<int64_t, NDIM>(entire_params.stride);\n    SliceIndexHelper<NDIM> sliced_splitted_large_idx_cvtr(entire_params.size);\n    SliceIndexHelper<NDIM> entire_full_small_idx_cvtr =\n        NdIndexStrideOffsetHelper<int64_t, NDIM>(sliced_params.stride);\n    SliceIndexHelper<NDIM> sliced_full_small_idx_cvtr(sliced_params.size);\n\n    int cnt = 1;\n    int entire_target_stride = 1;\n    int sliced_target_stride = 1;\n    // Calculate the length of the contiguous part\n    for (int i = NDIM - 1; i >= 0; i--) {\n      if (entire_params.stride[i] != entire_target_stride\n          || sliced_params.stride[i] != sliced_target_stride) {\n        break;\n      }\n      entire_target_stride *= entire_params.size[i];\n      sliced_target_stride *= sliced_params.size[i];\n      if (sliced_params.step[i] == 1 && entire_params.step[i] == 1) {\n        cnt *= sliced_params.size[i];\n      }\n      if (!entire_params.IsFullSlice(i) || !sliced_params.IsFullSlice(i)) { break; }\n    }\n    SteppedMultiThreadLoop(elem_cnt, cnt, [&](int64_t i) {\n      const int64_t entire_offset = SliceOffsetToEntireOffset<NDIM>(\n          i, entire_params, entire_splitted_large_idx_cvtr, sliced_splitted_large_idx_cvtr);\n      const int64_t sliced_offset = SliceOffsetToEntireOffset<NDIM>(\n          i, sliced_params, entire_full_small_idx_cvtr, sliced_full_small_idx_cvtr);\n      std::copy(entire + entire_offset, entire + entire_offset + cnt, sliced + sliced_offset);\n    });\n  }\n\n  template<int NDIM>\n  static void DoBackward(ep::Stream* stream, const SliceParams& params, const T* sliced,\n                         T* entire) {\n    CHECK_EQ(params.ndim, NDIM);\n    int64_t elem_cnt = params.elem_cnt();\n    SliceIndexHelper<NDIM> entire_idx_cvtr(params.dims);\n    SliceIndexHelper<NDIM> sliced_idx_cvtr(params.size);\n    MultiThreadLoop(elem_cnt, [&](int64_t i) {\n      int64_t offset = SliceOffsetToEntireOffset<NDIM>(i, params, entire_idx_cvtr, sliced_idx_cvtr);\n      entire[offset] = sliced[i];\n    });\n  }\n\n#define MAKE_SLICE_KERNEL_UTIL_SWITCH_ENTRY(func_name, N) \\\n  SliceKernelUtil<DeviceType::kCPU, T>::func_name<N>\n#define DEFINE_SLICE_KERNEL_UTIL_SWITCH_STATIC_METHOD(func_name)                  \\\n  DEFINE_STATIC_SWITCH_FUNC(void, func_name, MAKE_SLICE_KERNEL_UTIL_SWITCH_ENTRY, \\\n                            MAKE_NDIM_CTRV_SEQ(DIM_SEQ));\n\n  DEFINE_SLICE_KERNEL_UTIL_SWITCH_STATIC_METHOD(DoForward);\n  DEFINE_SLICE_KERNEL_UTIL_SWITCH_STATIC_METHOD(DoBackward);\n#undef DEFINE_SLICE_KERNEL_UTIL_SWITCH_STATIC_METHOD\n#undef MAKE_SLICE_KERNEL_UTIL_SWITCH_ENTRY\n};\n\nINSTANTIATE_SLICE_KERNEL_UTIL_WITH_DEVICE(DeviceType::kCPU)\nINSTANTIATE_SLICE_KERNEL_UTIL(DeviceType::kCPU, bfloat16)\nINSTANTIATE_SLICE_KERNEL_UTIL(DeviceType::kCPU, std::complex<float>)\nINSTANTIATE_SLICE_KERNEL_UTIL(DeviceType::kCPU, std::complex<double>)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/slice_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/slice_util.h\"\n#include \"oneflow/core/common/switch_func.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#if CUDA_VERSION >= 11000\n#include <cuda_bf16.h>\n#endif  // CUDA_VERSION >= 11000\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, int NDIM>\n__global__ void SliceForwardGpu(const int n, SliceParams params,\n                                SliceIndexHelper<NDIM> entire_idx_cvtr,\n                                SliceIndexHelper<NDIM> sliced_idx_cvtr, const T* entire,\n                                T* sliced) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    int64_t offset = SliceOffsetToEntireOffset<NDIM>(i, params, entire_idx_cvtr, sliced_idx_cvtr);\n    sliced[i] = entire[offset];\n  }\n}\n\ntemplate<typename T, int NDIM>\n__global__ void SliceForwardGpu(const int n, SliceParams entire_params, SliceParams sliced_params,\n                                SliceIndexHelper<NDIM> entire_splitted_large_idx_cvtr,\n                                SliceIndexHelper<NDIM> sliced_splitted_large_idx_cvtr,\n                                SliceIndexHelper<NDIM> entire_full_small_idx_cvtr,\n                                SliceIndexHelper<NDIM> sliced_full_small_idx_cvtr, const T* entire,\n                                T* sliced) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    int64_t entire_offset = SliceOffsetToEntireOffset<NDIM>(\n        i, entire_params, entire_splitted_large_idx_cvtr, sliced_splitted_large_idx_cvtr);\n    int64_t sliced_offset = SliceOffsetToEntireOffset<NDIM>(\n        i, sliced_params, entire_full_small_idx_cvtr, sliced_full_small_idx_cvtr);\n    sliced[sliced_offset] = entire[entire_offset];\n  }\n}\n\ntemplate<typename T, int NDIM>\n__global__ void SliceBackwardGpu(const int n, SliceParams params,\n                                 SliceIndexHelper<NDIM> entire_idx_cvtr,\n                                 SliceIndexHelper<NDIM> sliced_idx_cvtr, T* entire,\n                                 const T* sliced) {\n  CUDA_1D_KERNEL_LOOP(i, n) {\n    int64_t offset = SliceOffsetToEntireOffset<NDIM>(i, params, entire_idx_cvtr, sliced_idx_cvtr);\n    entire[offset] = sliced[i];\n  }\n}\n\ntemplate<typename T, int NDIM>\nvoid LaunchSliceForward(ep::Stream* stream, const SliceParams& params, const T* entire, T* sliced) {\n  CHECK_EQ(params.ndim, NDIM);\n  int64_t elem_cnt = params.elem_cnt();\n  SliceIndexHelper<NDIM> entire_idx_cvtr(params.dims);\n  SliceIndexHelper<NDIM> sliced_idx_cvtr(params.size);\n  if (elem_cnt == 0) { return; }\n  SliceForwardGpu<T, NDIM><<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                             stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      elem_cnt, params, entire_idx_cvtr, sliced_idx_cvtr, entire, sliced);\n}\n\ntemplate<typename T, int NDIM>\nvoid LaunchSliceForward(ep::Stream* stream, const SliceParams& entire_params,\n                        const SliceParams& 
sliced_params, const T* entire, T* sliced) {\n  CHECK_EQ(entire_params.ndim, NDIM);\n  CHECK_EQ(sliced_params.ndim, NDIM);\n  int64_t elem_cnt = entire_params.elem_cnt();\n  if (elem_cnt == 0) { return; }\n  SliceIndexHelper<NDIM> entire_splitted_large_idx_cvtr =\n      NdIndexStrideOffsetHelper<int64_t, NDIM>(entire_params.stride);\n  SliceIndexHelper<NDIM> sliced_splitted_large_idx_cvtr(entire_params.size);\n  SliceIndexHelper<NDIM> entire_full_small_idx_cvtr =\n      NdIndexStrideOffsetHelper<int64_t, NDIM>(sliced_params.stride);\n  SliceIndexHelper<NDIM> sliced_full_small_idx_cvtr(sliced_params.size);\n  SliceForwardGpu<T, NDIM><<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                             stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      elem_cnt, entire_params, sliced_params, entire_splitted_large_idx_cvtr,\n      sliced_splitted_large_idx_cvtr, entire_full_small_idx_cvtr, sliced_full_small_idx_cvtr,\n      entire, sliced);\n}\n\ntemplate<typename T, int NDIM>\nvoid LaunchSliceBackward(ep::Stream* stream, const SliceParams& params, const T* sliced,\n                         T* entire) {\n  CHECK_EQ(params.ndim, NDIM);\n  int64_t elem_cnt = params.elem_cnt();\n  SliceIndexHelper<NDIM> entire_idx_cvtr(params.dims);\n  SliceIndexHelper<NDIM> sliced_idx_cvtr(params.size);\n  if (elem_cnt == 0) { return; }\n  SliceBackwardGpu<T, NDIM><<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                              stream->As<ep::CudaStream>()->cuda_stream()>>>(\n      elem_cnt, params, entire_idx_cvtr, sliced_idx_cvtr, entire, sliced);\n}\n\ntemplate<typename T>\nstruct SliceSwitchUtil final {\n#define MAKE_SLICE_SWITCH_ENTRY(func_name, N) func_name<T, N>\n#define DEFINE_SLICE_SWITCH_UTIL_STATIC_METHOD(func_name) \\\n  DEFINE_STATIC_SWITCH_FUNC(void, func_name, MAKE_SLICE_SWITCH_ENTRY, MAKE_NDIM_CTRV_SEQ(DIM_SEQ))\n\n  DEFINE_SLICE_SWITCH_UTIL_STATIC_METHOD(LaunchSliceForward)\n  DEFINE_SLICE_SWITCH_UTIL_STATIC_METHOD(LaunchSliceBackward)\n#undef DEFINE_SLICE_SWITCH_UTIL_STATIC_METHOD\n#undef MAKE_SLICE_SWITCH_ENTRY\n};\n\ntemplate<typename T>\nsize_t GetPackSize(const SliceParams& params, const T* entire, const T* sliced) {\n  CHECK_GT(params.ndim, 0);\n  const int64_t last_dim = params.ndim - 1;\n  const int64_t mask = (params.dims[last_dim] * sizeof(T)) | (params.start[last_dim] * sizeof(T))\n                       | (params.size[last_dim] * sizeof(T))\n                       | static_cast<int64_t>(reinterpret_cast<uintptr_t>(entire))\n                       | static_cast<int64_t>(reinterpret_cast<uintptr_t>(sliced));\n  if ((mask & 0xF) == 0) {\n    return 16;\n  } else if ((mask & 0x7) == 0) {\n    return 8;\n  } else if ((mask & 0x3) == 0) {\n    return 4;\n  } else if ((mask & 0x1) == 0) {\n    return 2;\n  } else {\n    return 1;\n  }\n}\n\ntemplate<typename T>\nvoid GetPackedParams(const SliceParams& params, const T* entire, const T* sliced, size_t* pack_size,\n                     SliceParams* packed_params) {\n  CHECK_GT(params.ndim, 0);\n  const int64_t last_dim = params.ndim - 1;\n  if (params.step[last_dim] == 1) {\n    *pack_size = GetPackSize<T>(params, entire, sliced);\n    CHECK_GE(*pack_size, sizeof(T));\n    const int64_t elem_per_pack = *pack_size / sizeof(T);\n    *packed_params = params;\n    packed_params->dims[last_dim] /= elem_per_pack;\n    packed_params->start[last_dim] /= elem_per_pack;\n    packed_params->size[last_dim] /= elem_per_pack;\n  } else {\n    *pack_size = sizeof(T);\n    *packed_params = params;\n  
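  // strided innermost dim: no vectorized packing, fall back to element-sized access\n  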
}\n}\n\n}  // namespace\n\ntemplate<typename T>\nstruct SliceKernelUtil<DeviceType::kCUDA, T> {\n  static void Forward(ep::Stream* stream, const SliceParams& params, const T* entire, T* sliced) {\n    SliceParams fold_slice_params = FoldContiguousFullSliceDimensions(params);\n    size_t pack_size;\n    SliceParams packed_params{};\n    GetPackedParams<T>(fold_slice_params, entire, sliced, &pack_size, &packed_params);\n    if (pack_size == 1) {\n      SliceSwitchUtil<uint8_t>::SwitchLaunchSliceForward(\n          SwitchCase(packed_params.ndim), stream, packed_params,\n          reinterpret_cast<const uint8_t*>(entire), reinterpret_cast<uint8_t*>(sliced));\n    } else if (pack_size == 2) {\n      SliceSwitchUtil<uint16_t>::SwitchLaunchSliceForward(\n          SwitchCase(packed_params.ndim), stream, packed_params,\n          reinterpret_cast<const uint16_t*>(entire), reinterpret_cast<uint16_t*>(sliced));\n    } else if (pack_size == 4) {\n      SliceSwitchUtil<uint32_t>::SwitchLaunchSliceForward(\n          SwitchCase(packed_params.ndim), stream, packed_params,\n          reinterpret_cast<const uint32_t*>(entire), reinterpret_cast<uint32_t*>(sliced));\n    } else if (pack_size == 8) {\n      SliceSwitchUtil<uint64_t>::SwitchLaunchSliceForward(\n          SwitchCase(packed_params.ndim), stream, packed_params,\n          reinterpret_cast<const uint64_t*>(entire), reinterpret_cast<uint64_t*>(sliced));\n    } else if (pack_size == 16) {\n      SliceSwitchUtil<ulonglong2>::SwitchLaunchSliceForward(\n          SwitchCase(packed_params.ndim), stream, packed_params,\n          reinterpret_cast<const ulonglong2*>(entire), reinterpret_cast<ulonglong2*>(sliced));\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n\n  static void Forward(ep::Stream* stream, const SliceParams& entire_params,\n                      const SliceParams& sliced_params, const T* entire, T* sliced) {\n    SliceSwitchUtil<T>::SwitchLaunchSliceForward(SwitchCase(entire_params.ndim), stream,\n                                                 entire_params, sliced_params, entire, sliced);\n  }\n\n  static void Backward(ep::Stream* stream, const SliceParams& params, const T* sliced, T* entire) {\n    SliceParams fold_slice_params = FoldContiguousFullSliceDimensions(params);\n    size_t pack_size;\n    SliceParams packed_params{};\n    GetPackedParams<T>(fold_slice_params, entire, sliced, &pack_size, &packed_params);\n    if (pack_size == 1) {\n      SliceSwitchUtil<uint8_t>::SwitchLaunchSliceBackward(\n          SwitchCase(packed_params.ndim), stream, packed_params,\n          reinterpret_cast<const uint8_t*>(sliced), reinterpret_cast<uint8_t*>(entire));\n    } else if (pack_size == 2) {\n      SliceSwitchUtil<uint16_t>::SwitchLaunchSliceBackward(\n          SwitchCase(packed_params.ndim), stream, packed_params,\n          reinterpret_cast<const uint16_t*>(sliced), reinterpret_cast<uint16_t*>(entire));\n    } else if (pack_size == 4) {\n      SliceSwitchUtil<uint32_t>::SwitchLaunchSliceBackward(\n          SwitchCase(packed_params.ndim), stream, packed_params,\n          reinterpret_cast<const uint32_t*>(sliced), reinterpret_cast<uint32_t*>(entire));\n    } else if (pack_size == 8) {\n      SliceSwitchUtil<uint64_t>::SwitchLaunchSliceBackward(\n          SwitchCase(packed_params.ndim), stream, packed_params,\n          reinterpret_cast<const uint64_t*>(sliced), reinterpret_cast<uint64_t*>(entire));\n    } else if (pack_size == 16) {\n      SliceSwitchUtil<ulonglong2>::SwitchLaunchSliceBackward(\n          
SwitchCase(packed_params.ndim), stream, packed_params,\n          reinterpret_cast<const ulonglong2*>(sliced), reinterpret_cast<ulonglong2*>(entire));\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n};\n\nINSTANTIATE_SLICE_KERNEL_UTIL_WITH_DEVICE(DeviceType::kCUDA)\nINSTANTIATE_SLICE_KERNEL_UTIL(DeviceType::kCUDA, cuComplex)\nINSTANTIATE_SLICE_KERNEL_UTIL(DeviceType::kCUDA, cuDoubleComplex)\n#if CUDA_VERSION >= 11000\nINSTANTIATE_SLICE_KERNEL_UTIL(DeviceType::kCUDA, nv_bfloat16)\n#endif\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/slice_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_SLICE_UTIL_H_\n#define ONEFLOW_USER_KERNELS_SLICE_UTIL_H_\n\n#include <sstream>\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\ninline int64_t RegulateSliceStart(int64_t start, int64_t size) {\n  // slice start must be in range [-size, size)\n  // after changing to positive order it should be in range [0, size)\n  start = std::min(std::max(start, -size), size - 1);\n  return (start < 0) ? (start + size) : start;\n}\n\ninline int64_t RegulateSliceStop(int64_t stop, int64_t size) {\n  // slice stop must be in range [-size-1, size]\n  // after changing to positive order it should be in range [-1, size]\n  stop = std::min(std::max(stop, -size - 1), size);\n  return (stop < 0) ? (stop + size) : stop;\n}\n\nconstexpr size_t kSliceMaxDims = 8;\n\nstruct SliceParams {\n  int64_t ndim = 0;\n  int64_t dims[kSliceMaxDims]{0};\n  int64_t stride[kSliceMaxDims]{0};\n  int64_t start[kSliceMaxDims]{0};\n  int64_t step[kSliceMaxDims]{0};\n  int64_t size[kSliceMaxDims]{0};\n\n  int64_t elem_cnt() const {\n    if (ndim == 0) { return 0; }\n    int64_t elem_cnt = 1;\n    FOR_RANGE(int, i, 0, ndim) { elem_cnt *= size[i]; }\n    return elem_cnt;\n  }\n\n  bool IsFullSlice(int dim) const {\n    CHECK_GE(dim, 0);\n    CHECK_LT(dim, ndim);\n    if (step[dim] != 1) { return false; }\n    if (start[dim] != 0) { return false; }\n    if (size[dim] != dims[dim]) { return false; }\n    return true;\n  }\n\n  std::string ToString() const {\n    std::stringstream ss(\"SliceParams:\");\n    for (int i = 0; i < ndim; ++i) {\n      ss << \"\\n\\tdim: \" << i << \", start: \" << start[i] << \", step: \" << step[i]\n         << \", stride: \" << stride[i] << \", size: \" << size[i] << \", dims: \" << dims[i];\n    }\n    return ss.str();\n  }\n};\n\nSliceParams FoldContiguousFullSliceDimensions(const SliceParams& params);\n\ntemplate<int NDIM>\nusing SliceIndexHelper = NdIndexOffsetHelper<int64_t, NDIM>;\n\ntemplate<int NDIM>\nOF_DEVICE_FUNC int64_t SliceOffsetToEntireOffset(int64_t offset, const SliceParams& params,\n                                                 const SliceIndexHelper<NDIM>& entire_idx_cvtr,\n                                                 const SliceIndexHelper<NDIM>& sliced_idx_cvtr) {\n  int64_t nd_index[NDIM] = {0};\n  sliced_idx_cvtr.OffsetToNdIndex(offset, nd_index);\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n  for (int64_t i = 0; i < NDIM; ++i) {\n    nd_index[i] = params.start[i] + params.step[i] * nd_index[i];\n    assert(nd_index[i] >= 0);\n    assert(nd_index[i] < params.dims[i]);\n  }\n  return entire_idx_cvtr.NdIndexToOffset(nd_index);\n}\n\ntemplate<DeviceType device_type, typename T>\nstruct SliceKernelUtil {\n  static void Forward(ep::Stream* stream, const SliceParams& params, const T* entire, T* sliced);\n  static void 
Forward(ep::Stream* stream, const SliceParams& entire_params,\n                      const SliceParams& sliced_params, const T* entire, T* sliced);\n  static void Backward(ep::Stream* stream, const SliceParams& params, const T* sliced, T* entire);\n};\n\n#define INSTANTIATE_SLICE_KERNEL_UTIL(device, dtype) template struct SliceKernelUtil<device, dtype>;\n\n#define INSTANTIATE_SLICE_KERNEL_UTIL_WITH_DEVICE(device) \\\n  INSTANTIATE_SLICE_KERNEL_UTIL(device, bool)             \\\n  INSTANTIATE_SLICE_KERNEL_UTIL(device, float16)          \\\n  INSTANTIATE_SLICE_KERNEL_UTIL(device, float)            \\\n  INSTANTIATE_SLICE_KERNEL_UTIL(device, double)           \\\n  INSTANTIATE_SLICE_KERNEL_UTIL(device, int32_t)          \\\n  INSTANTIATE_SLICE_KERNEL_UTIL(device, int64_t)          \\\n  INSTANTIATE_SLICE_KERNEL_UTIL(device, int8_t)           \\\n  INSTANTIATE_SLICE_KERNEL_UTIL(device, uint8_t)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_SLICE_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/smooth_l1_loss_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/loss_kernel_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\nnamespace {\n\nusing namespace loss;\n\ntemplate<typename T>\nvoid ComputeSmoothL1Out(int64_t elem_cnt, const T* input, const T* target, T* out,\n                        const float beta) {\n  FOR_RANGE(int64_t, i, 0, elem_cnt) {\n    const T abs_diff = std::abs(input[i] - target[i]);\n    if (abs_diff < beta) {\n      out[i] = 0.5 * abs_diff * abs_diff / beta;\n    } else {\n      out[i] = abs_diff - 0.5 * beta;\n    }\n  }\n}\ntemplate<typename T>\nvoid ComputeSmoothL1GradOut(int64_t elem_cnt, const T* input, const T* target, const T* dy, T* dx,\n                            const float beta) {\n  FOR_RANGE(int64_t, i, 0, elem_cnt) {\n    const T diff = input[i] - target[i];\n    const T abs_diff = std::abs(diff);\n    if (abs_diff < beta) {\n      dx[i] = diff / beta;\n    } else {\n      dx[i] = (diff > GetZeroVal<T>()) - (diff < GetZeroVal<T>());\n    }\n    const T dy_val = dy[i];\n    dx[i] = dx[i] * dy_val;\n  }\n}\n\ntemplate<typename T>\nclass SmoothL1LossKernel : public SimpleLossKernel<DeviceType::kCPU, T, SmoothL1LossKernel<T>> {\n public:\n  void ComputeOut(user_op::KernelComputeContext* ctx, int64_t elem_cnt, const T* input,\n                  const T* target, T* out) const {\n    const float beta = ctx->Attr<float>(\"beta\");\n    ComputeSmoothL1Out(elem_cnt, input, target, out, beta);\n  }\n};\n\ntemplate<typename T>\nclass SmoothL1LossGradKernel\n    : public SimpleLossGradKernel<DeviceType::kCPU, T, SmoothL1LossGradKernel<T>> {\n public:\n  void ComputeOut(user_op::KernelComputeContext* ctx, int64_t elem_cnt, const T* input,\n                  const T* target, const T* dy, T* dx) const {\n    const float beta = ctx->Attr<float>(\"beta\");\n    ComputeSmoothL1GradOut(elem_cnt, input, target, dy, dx, beta);\n  }\n};\n\n}  // namespace\n\nREGISTER_SIMPLE_LOSS_KERNEL_CPU(\"smooth_l1_loss\", SmoothL1LossKernel)\nREGISTER_SIMPLE_LOSS_GRAD_KERNEL_CPU(\"smooth_l1_loss_grad\", SmoothL1LossGradKernel)\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/smooth_l1_loss_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/user/kernels/loss_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\nnamespace {\n\nusing namespace loss;\n\ntemplate<typename T>\nstruct SmoothL1Functor {\n  float beta_;\n  float inv_beta_;\n  T half_of_one_;\n  SmoothL1Functor(float beta)\n      : beta_(beta), inv_beta_(static_cast<float>(1.0 / beta)), half_of_one_(static_cast<T>(0.5)) {}\n\n  __device__ __forceinline__ T operator()(T input_val, T target_val) const {\n    const T abs_diff = abs(input_val - target_val);\n    if (abs_diff < beta_) {\n      return half_of_one_ * abs_diff * abs_diff * inv_beta_;\n    } else {\n      return abs_diff - half_of_one_ * beta_;\n    }\n  }\n};\n\ntemplate<>\nstruct SmoothL1Functor<half> {\n  half beta_;\n  half inv_beta_;\n  half zero_;\n  half half_of_one_;\n  SmoothL1Functor(float beta)\n      : beta_(__float2half(beta)),\n        inv_beta_(__float2half(static_cast<float>(1.0 / beta))),\n        zero_(__float2half(0.f)),\n        half_of_one_(__float2half(0.5f)) {}\n\n  __device__ __forceinline__ half operator()(half input_val, half target_val) const {\n    const half diff = input_val - target_val;\n    const half abs_diff = diff < zero_ ? __hneg(diff) : diff;\n    if (abs_diff < beta_) {\n      return half_of_one_ * abs_diff * abs_diff * inv_beta_;\n    } else {\n      return abs_diff - half_of_one_ * beta_;\n    }\n  }\n};\n\ntemplate<typename T>\nstruct SmoothL1GradFunctor {\n  float beta_;\n  float inv_beta_;\n  T zero_;\n  SmoothL1GradFunctor(float beta)\n      : beta_(beta), inv_beta_(static_cast<float>(1.0 / beta)), zero_(GetZeroVal<T>()) {}\n\n  __device__ __forceinline__ T operator()(T input_val, T target_val, T dy_val) const {\n    const T diff = input_val - target_val;\n    const T abs_diff = abs(diff);\n    T dx_val;\n    if (abs_diff < beta_) {\n      dx_val = diff * inv_beta_;\n    } else {\n      dx_val = (diff > zero_) - (diff < zero_);\n    }\n    return dx_val * dy_val;\n  }\n};\n\ntemplate<>\nstruct SmoothL1GradFunctor<half> {\n  half beta_;\n  half inv_beta_;\n  half zero_;\n  half one_;\n  SmoothL1GradFunctor(float beta)\n      : beta_(__float2half(beta)),\n        inv_beta_(__float2half(static_cast<float>(1.0 / beta))),\n        zero_(__float2half(0.f)),\n        one_(__float2half(1.f)) {}\n\n  __device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val) const {\n    const half diff = input_val - target_val;\n    const half abs_diff = diff < zero_ ? 
__hneg(diff) : diff;\n    half dx_val;\n    if (abs_diff < beta_) {\n      dx_val = diff * inv_beta_;\n    } else {\n      dx_val = (diff > zero_) - (diff < zero_);\n    }\n    return dx_val * dy_val;\n  }\n};\n\ntemplate<typename T>\nclass SmoothL1LossKernel : public SimpleLossKernel<DeviceType::kCUDA, T, SmoothL1LossKernel<T>> {\n public:\n  void ComputeOut(user_op::KernelComputeContext* ctx, int64_t elem_cnt, const T* input,\n                  const T* target, T* out) const {\n    const float beta = ctx->Attr<float>(\"beta\");\n    OF_CUDA_CHECK((cuda::elementwise::Binary(SmoothL1Functor<T>(beta), elem_cnt, out, input, target,\n                                             ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n  }\n};\n\ntemplate<typename T>\nclass SmoothL1LossGradKernel\n    : public SimpleLossGradKernel<DeviceType::kCUDA, T, SmoothL1LossGradKernel<T>> {\n public:\n  void ComputeOut(user_op::KernelComputeContext* ctx, int64_t elem_cnt, const T* input,\n                  const T* target, const T* dy, T* dx) const {\n    const float beta = ctx->Attr<float>(\"beta\");\n    OF_CUDA_CHECK(\n        (cuda::elementwise::Ternary(SmoothL1GradFunctor<T>(beta), elem_cnt, dx, input, target, dy,\n                                    ctx->stream()->As<ep::CudaStream>()->cuda_stream())));\n  }\n};\n\n}  // namespace\n\nREGISTER_SIMPLE_LOSS_KERNEL_CUDA(\"smooth_l1_loss\", SmoothL1LossKernel)\nREGISTER_SIMPLE_LOSS_GRAD_KERNEL_CUDA(\"smooth_l1_loss_grad\", SmoothL1LossGradKernel)\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/softmax_cross_entropy_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/softmax_cross_entropy_kernel.h\"\n#include \"oneflow/core/kernel/kernel_util.cuh\"\n\nnamespace oneflow {\nnamespace user_op {\n\ntemplate<typename T>\nstruct CrossEntropyKernelUtil<DeviceType::kCPU, T> {\n  static void ComputeEntropy(ep::Stream* stream, const int64_t num_instances,\n                             const int64_t num_classes, const T* x, const T* labels, T* y) {\n    FOR_RANGE(int64_t, i, 0, num_instances) {\n      T tmp = 0;\n      FOR_RANGE(int64_t, j, 0, num_classes) {\n        T label = labels[i * num_classes + j];\n        T prob = x[i * num_classes + j];\n        // tmp -= label * SafeLog(prob);\n        tmp -= label * logf((prob > 1e-20) ? prob : 1e-20);\n      }\n      y[i] = tmp;\n    }\n  }\n\n  static void ComputeDiffWithSoftmax(ep::Stream* stream, const int64_t elem_cnt,\n                                     const int64_t num_classes, const T* prob, const T* labels,\n                                     const T* dy, T* dx) {\n    FOR_RANGE(int64_t, i, 0, elem_cnt) {\n      const int32_t row_id = i / num_classes;\n      dx[i] = dy[row_id] * (prob[i] - labels[i]);\n    }\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SOFTMAX_CROSS_ENTROPY_KERNEL,\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SOFTMAX_CROSS_ENTROPY_GRAD_KERNEL,\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ)\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/softmax_cross_entropy_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/softmax_cross_entropy_kernel.h\"\n#include \"oneflow/core/kernel/kernel_util.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <cub/cub.cuh>\n\nnamespace oneflow {\nnamespace user_op {\n\nnamespace {\n\nconstexpr int64_t kCrossEntropyGpuBlockSize = 128;\n\ntemplate<typename T>\n__global__ void ComputeEntropyGpu(const int64_t num_instances, const int64_t num_classes,\n                                  const T* x, const T* labels, T* y) {\n  typedef cub::BlockReduce<T, kCrossEntropyGpuBlockSize> BlockReduce;\n  __shared__ typename BlockReduce::TempStorage temp_storage;\n  const int tid = threadIdx.x;\n  for (int row = blockIdx.x; row < num_instances; row += gridDim.x) {\n    const int row_offset = row * num_classes;\n    const T* in_row = x + row_offset;\n    const T* label_row = labels + row_offset;\n    T result = 0;\n    for (int col = tid; col < num_classes; col += kCrossEntropyGpuBlockSize) {\n      T label = label_row[col];\n      T prob = in_row[col];\n      result += -label * SafeLog(prob);\n    }\n    __syncthreads();\n    T row_reduce_result = BlockReduce(temp_storage).Reduce(result, cub::Sum());\n    if (0 == tid) { y[row] = row_reduce_result; }\n  }\n}\n\n__global__ void ComputeEntropyGpuHalf(const int64_t num_instances, const int64_t num_classes,\n                                      const half* x, const half* labels, half* y) {\n  typedef cub::BlockReduce<float, kCrossEntropyGpuBlockSize> BlockReduce;\n  __shared__ typename BlockReduce::TempStorage temp_storage;\n  const int tid = threadIdx.x;\n  for (int row = blockIdx.x; row < num_instances; row += gridDim.x) {\n    const int row_offset = row * num_classes;\n    const half* in_row = x + row_offset;\n    const half* label_row = labels + row_offset;\n    float result = 0;\n    for (int col = tid; col < num_classes; col += kCrossEntropyGpuBlockSize) {\n      float label = __half2float(label_row[col]);\n      float prob = __half2float(in_row[col]);\n      result += -label * SafeLog(prob);\n    }\n    __syncthreads();\n    float row_reduce_result = BlockReduce(temp_storage).Reduce(result, cub::Sum());\n    if (0 == tid) { y[row] = __float2half(row_reduce_result); }\n  }\n}\n\ntemplate<typename T>\n__global__ void ComputeDiffWithSoftmaxGpu(const int64_t elem_cnt, const int64_t num_classes,\n                                          const T* prob, const T* labels, const T* dy, T* dx) {\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {\n    const int32_t row_id = i / num_classes;\n    dx[i] = dy[row_id] * (prob[i] - labels[i]);\n  }\n}\n\n__global__ void ComputeDiffWithSoftmaxGpuHalf(const int64_t elem_cnt, const int64_t num_classes,\n                                              const half* prob, const half* labels, const half* dy,\n                                              half* dx) {\n#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {\n    
const int32_t row_id = i / num_classes;\n    dx[i] = __hmul(dy[row_id], __hsub(prob[i], labels[i]));\n  }\n#else\n  printf(\"use half need nvcc arch >= 530\");\n  assert(false);\n#endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/\n}\n\n}  // namespace\n\nint GetCrossEntropyNumBlocks(const int num_instances) {\n  return std::min(static_cast<int>(num_instances), kCudaMaxBlocksNum);\n}\n\nint GetCrossEntropyBlockSize() { return kCrossEntropyGpuBlockSize; }\n\ntemplate<typename T>\nstruct CrossEntropyKernelUtil<DeviceType::kCUDA, T> {\n  static void ComputeEntropy(ep::Stream* stream, const int64_t num_instances,\n                             const int64_t num_classes, const T* x, const T* labels, T* y) {\n    OF_CUDA_CHECK(cudaMemsetAsync(y, 0, sizeof(T) * num_instances,\n                                  stream->As<ep::CudaStream>()->cuda_stream()));\n    ComputeEntropyGpu<<<GetCrossEntropyNumBlocks(num_instances), GetCrossEntropyBlockSize(), 0,\n                        stream->As<ep::CudaStream>()->cuda_stream()>>>(num_instances, num_classes,\n                                                                       x, labels, y);\n  }\n\n  static void ComputeDiffWithSoftmax(ep::Stream* stream, const int64_t elem_cnt,\n                                     const int64_t num_classes, const T* prob, const T* labels,\n                                     const T* dy, T* dx) {\n    ComputeDiffWithSoftmaxGpu<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                                stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        elem_cnt, num_classes, prob, labels, dy, dx);\n  }\n};\n\ntemplate<>\nstruct CrossEntropyKernelUtil<DeviceType::kCUDA, float16> {\n  static void ComputeEntropy(ep::Stream* stream, const int64_t num_instances,\n                             const int64_t num_classes, const float16* x, const float16* labels,\n                             float16* y) {\n    OF_CUDA_CHECK(cudaMemsetAsync(y, 0, sizeof(float16) * num_instances,\n                                  stream->As<ep::CudaStream>()->cuda_stream()));\n    ComputeEntropyGpuHalf<<<GetCrossEntropyNumBlocks(num_instances), GetCrossEntropyBlockSize(), 0,\n                            stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        num_instances, num_classes, reinterpret_cast<const half*>(x),\n        reinterpret_cast<const half*>(labels), reinterpret_cast<half*>(y));\n  }\n\n  static void ComputeDiffWithSoftmax(ep::Stream* stream, const int64_t elem_cnt,\n                                     const int64_t num_classes, const float16* prob,\n                                     const float16* labels, const float16* dy, float16* dx) {\n    ComputeDiffWithSoftmaxGpuHalf<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                                    stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        elem_cnt, num_classes, reinterpret_cast<const half*>(prob),\n        reinterpret_cast<const half*>(labels), reinterpret_cast<const half*>(dy),\n        reinterpret_cast<half*>(dx));\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SOFTMAX_CROSS_ENTROPY_KERNEL,\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                                 FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SOFTMAX_CROSS_ENTROPY_GRAD_KERNEL,\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                                 FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ)\n\n}  // namespace 
user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/softmax_cross_entropy_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/include/primitive/softmax.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Softmax> NewSoftmaxPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"prediction\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::SoftmaxFactory>(ctx->device_type(), data_type);\n}\n\nauto SoftmaxPrimitiveExists() {\n  return hob::make_custom(\"SoftmaxPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewSoftmaxPrimitive(&ctx).operator bool();\n  });\n}\n\n}  // namespace\n\ntemplate<DeviceType device_type, typename T>\nstruct CrossEntropyKernelUtil {\n  static void ComputeEntropy(ep::Stream* stream, const int64_t num_instances,\n                             const int64_t num_classes, const T* x, const T* labels, T* y);\n  static void ComputeDiffWithSoftmax(ep::Stream* stream, const int64_t elem_cnt,\n                                     const int64_t num_classes, const T* prob, const T* labels,\n                                     const T* dy, T* dx);\n};\n\ntemplate<DeviceType device_type, typename T>\nclass SoftmaxCrossEntropyKernel final : public user_op::OpKernel {\n public:\n  SoftmaxCrossEntropyKernel() = default;\n  ~SoftmaxCrossEntropyKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex(\"prediction\", 0);\n    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex(\"prob\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const auto num_axes = label->shape_view().NumAxes();\n    const int64_t num_instances = label->shape_view().Count(0, num_axes - 1);\n    const int64_t num_classes = label->shape_view().At(num_axes - 1);\n    std::unique_ptr<ep::primitive::Softmax> primitive = NewSoftmaxPrimitive(ctx);\n    CHECK(primitive);\n    primitive->Launch(ctx->stream(), num_instances, num_classes, prediction->dptr(),\n                      prob->mut_dptr());\n\n    CrossEntropyKernelUtil<device_type, T>::ComputeEntropy(ctx->stream(), num_instances,\n                                                           num_classes, prob->dptr<T>(),\n                                                           label->dptr<T>(), out->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SOFTMAX_CROSS_ENTROPY_KERNEL(device_type_v, dtype_pair)                      \\\n  REGISTER_USER_KERNEL(\"softmax_cross_entropy\")                                               \\\n      .SetCreateFn<SoftmaxCrossEntropyKernel<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>>()  \\\n    
  .SetIsMatchedHob((user_op::HobDeviceType() == device_type_v)                            \\\n                       && (user_op::HobDataType(\"label\", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \\\n                       && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(dtype_pair))   \\\n                       && SoftmaxPrimitiveExists());\n\ntemplate<DeviceType device_type, typename T>\nclass SoftmaxCrossEntropyGradKernel final : public user_op::OpKernel {\n public:\n  SoftmaxCrossEntropyGradKernel() = default;\n  ~SoftmaxCrossEntropyGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex(\"prob\", 0);\n    user_op::Tensor* prediction_diff = ctx->Tensor4ArgNameAndIndex(\"prediction_diff\", 0);\n    const int64_t num_instances = dy->shape_view().elem_cnt();\n    CHECK_EQ(prob->shape_view().elem_cnt() % num_instances, 0);\n    const int64_t num_classes = prob->shape_view().elem_cnt() / num_instances;\n\n    CrossEntropyKernelUtil<device_type, T>::ComputeDiffWithSoftmax(\n        ctx->stream(), prediction_diff->shape_view().elem_cnt(), num_classes, prob->dptr<T>(),\n        label->dptr<T>(), dy->dptr<T>(), prediction_diff->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SOFTMAX_CROSS_ENTROPY_GRAD_KERNEL(device_type_v, dtype_pair)                    \\\n  REGISTER_USER_KERNEL(\"softmax_cross_entropy_grad\")                                             \\\n      .SetCreateFn<SoftmaxCrossEntropyGradKernel<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>>() \\\n      .SetIsMatchedHob(                                                                          \\\n          (user_op::HobDeviceType() == device_type_v)                                            \\\n          && (user_op::HobDataType(\"label\", 0) == OF_PP_PAIR_SECOND(dtype_pair))                 \\\n          && (user_op::HobDataType(\"prediction_diff\", 0) == OF_PP_PAIR_SECOND(dtype_pair)))      \\\n      .SetInplaceProposalFn([](const user_op::InferContext&,                                     \\\n                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {  \\\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"prediction_diff\", 0, \"prob\", 0, true));          \\\n        return Maybe<void>::Ok();                                                                \\\n      });\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/softmax_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/include/primitive/softmax.h\"\n#include \"oneflow/core/ep/include/primitive/softmax_backward.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Softmax> NewSoftmaxPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::SoftmaxFactory>(ctx->device_type(), data_type);\n}\n\nauto SoftmaxPrimitiveExists() {\n  return hob::make_custom(\"SoftmaxPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewSoftmaxPrimitive(&ctx).operator bool();\n  });\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::SoftmaxBackward> NewSoftmaxBackwardPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::SoftmaxBackwardFactory>(ctx->device_type(),\n                                                                            data_type);\n}\n\nauto SoftmaxBackwardPrimitiveExists() {\n  return hob::make_custom(\"SoftmaxBackwardPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewSoftmaxBackwardPrimitive(&ctx).operator bool();\n                          });\n}\n\n}  // namespace\n\nclass SoftmaxKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  SoftmaxKernel() = default;\n  ~SoftmaxKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const ShapeView& in_shape = in->shape_view();\n    const int64_t cols = in_shape.At(in_shape.NumAxes() - 1);\n    const int64_t rows = in_shape.Count(0, in_shape.NumAxes() - 1);\n    std::unique_ptr<ep::primitive::Softmax> primitive = NewSoftmaxPrimitive(ctx);\n    CHECK(primitive);\n    primitive->Launch(ctx->stream(), rows, cols, in->dptr(), out->mut_dptr());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"softmax\").SetCreateFn<SoftmaxKernel>().SetIsMatchedHob(\n    SoftmaxPrimitiveExists() == true);\n\nclass SoftmaxGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  SoftmaxGradKernel() = default;\n  ~SoftmaxGradKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 
0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const int64_t num_classes = y->shape_view().At(y->shape_view().NumAxes() - 1);\n    const int64_t num_instances = y->shape_view().elem_cnt() / num_classes;\n\n    std::unique_ptr<ep::primitive::SoftmaxBackward> primitive = NewSoftmaxBackwardPrimitive(ctx);\n    CHECK(primitive);\n    primitive->Launch(ctx->stream(), num_instances, num_classes, y->dptr(), dy->dptr(),\n                      dx->mut_dptr());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"softmax_grad\")\n    .SetCreateFn<SoftmaxGradKernel>()\n    .SetIsMatchedHob(SoftmaxBackwardPrimitiveExists() == true);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/sort_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass CpuSortKernel final : public user_op::OpKernel {\n public:\n  CpuSortKernel() = default;\n  ~CpuSortKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    Memcpy<DeviceType::kCPU>(ctx->stream(), out->mut_dptr<T>(), in->dptr<T>(),\n                             in->shape_view().elem_cnt() * sizeof(T));\n    const int32_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1);\n    const int32_t instance_num = in->shape_view().elem_cnt() / instance_size;\n    const std::string& direction = ctx->Attr<std::string>(\"direction\");\n    const bool is_ascending = direction == \"ASCENDING\";\n    const bool is_descending = direction == \"DESCENDING\";\n    FOR_RANGE(int32_t, i, 0, instance_num) {\n      T* out_ptr_i = out->mut_dptr<T>() + i * instance_size;\n      if (is_ascending) {\n        std::sort(out_ptr_i, out_ptr_i + instance_size, std::less<T>());\n      } else if (is_descending) {\n        std::sort(out_ptr_i, out_ptr_i + instance_size, std::greater<T>());\n      } else {\n        UNIMPLEMENTED();\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_SORT_KERNEL(dtype)                                             \\\n  REGISTER_USER_KERNEL(\"sort\").SetCreateFn<CpuSortKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCPU)                                \\\n      && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CPU_SORT_KERNEL(float)\nREGISTER_CPU_SORT_KERNEL(double)\nREGISTER_CPU_SORT_KERNEL(int32_t)\nREGISTER_CPU_SORT_KERNEL(int64_t)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/sort_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/user/kernels/radix_sort.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass GpuSortKernel final : public user_op::OpKernel {\n public:\n  GpuSortKernel() = default;\n  ~GpuSortKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    Memcpy<DeviceType::kCUDA>(ctx->stream(), out->mut_dptr<T>(), in->dptr<T>(),\n                              in->shape_view().elem_cnt() * sizeof(T));\n    const int32_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1);\n    const int32_t instance_num = in->shape_view().elem_cnt() / instance_size;\n    const std::string& direction = ctx->Attr<std::string>(\"direction\");\n    if (direction == \"ASCENDING\") {\n      SortKeysAscending(in->dptr<T>(), instance_num, instance_size, tmp_buffer->mut_dptr<void>(),\n                        tmp_buffer->shape_view().elem_cnt(), out->mut_dptr<T>(),\n                        ctx->stream()->As<ep::CudaStream>()->cuda_stream());\n    } else if (direction == \"DESCENDING\") {\n      SortKeysDescending(in->dptr<T>(), instance_num, instance_size, tmp_buffer->mut_dptr<void>(),\n                         tmp_buffer->shape_view().elem_cnt(), out->mut_dptr<T>(),\n                         ctx->stream()->As<ep::CudaStream>()->cuda_stream());\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_SORT_KERNEL(dtype)                                                    \\\n  REGISTER_USER_KERNEL(\"sort\")                                                              \\\n      .SetCreateFn<GpuSortKernel<dtype>>()                                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                      \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value))    \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                         \\\n        const Shape& in_shape = ctx->InputShape(\"in\", 0);                                   \\\n        const int32_t instance_size = in_shape.dim_vec().back();                            \\\n        const int32_t instance_num = in_shape.elem_cnt() / instance_size;                   \\\n        const std::string& direction = ctx->Attr<std::string>(\"direction\");                 \\\n        if (direction == \"ASCENDING\") {                                                     \\\n          
return InferTempStorageForSortKeysAscending<dtype>(instance_num, instance_size);  \\\n        } else if (direction == \"DESCENDING\") {                                             \\\n          return InferTempStorageForSortKeysDescending<dtype>(instance_num, instance_size); \\\n        } else {                                                                            \\\n          UNIMPLEMENTED();                                                                  \\\n          return 0;                                                                         \\\n        }                                                                                   \\\n      });\n\nREGISTER_CUDA_SORT_KERNEL(float)\nREGISTER_CUDA_SORT_KERNEL(double)\nREGISTER_CUDA_SORT_KERNEL(int32_t)\nREGISTER_CUDA_SORT_KERNEL(int64_t)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/sparse_cross_entropy_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/user/kernels/sparse_cross_entropy_kernel_util.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\nnamespace {\n\nclass SparseCrossEntropyOpKernelCache final : public user_op::OpKernelCache {\n public:\n  SparseCrossEntropyOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {}\n  ~SparseCrossEntropyOpKernelCache() override = default;\n\n  int64_t lower() const { return lower_; }\n  int64_t upper() const { return upper_; }\n\n private:\n  const int64_t lower_;\n  const int64_t upper_;\n};\n\n}  // namespace\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass SparseCrossEntropyKernel final : public user_op::OpKernel {\n public:\n  SparseCrossEntropyKernel() = default;\n  ~SparseCrossEntropyKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex(\"prediction\", 0);\n    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t num_instances = label->shape_view().elem_cnt();\n    CHECK_EQ(prediction->shape_view().elem_cnt() % num_instances, 0);\n    const int64_t num_classes = prediction->shape_view().elem_cnt() / num_instances;\n    const int64_t lower_bound = 0;\n    const int64_t depth = ctx->Attr<int64_t>(\"depth\");\n    SparseCrossEntropyKernelUtil<device_type, T, K>::ComputeEntropy(\n        ctx->stream(), num_instances, num_classes, depth, lower_bound, prediction->dptr<T>(),\n        label->dptr<K>(), out->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass SparseCrossEntropyMsKernel final : public user_op::OpKernel {\n public:\n  SparseCrossEntropyMsKernel() = default;\n  ~SparseCrossEntropyMsKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    if (ctx->parallel_ctx().parallel_num() > 1) {\n      const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"prediction\", 0);\n      const Shape& hierarchy = *ctx->parallel_desc().hierarchy();\n      const TensorDesc* prediction_logical_desc =\n          ctx->LogicalTensorDesc4ArgNameAndIndex(\"prediction\", 0);\n      const int64_t class_axis = prediction_logical_desc->shape().NumAxes() - 1;\n      TensorSliceView view = GetTensorSliceView4ParallelId(\n          hierarchy, nd_sbp, prediction_logical_desc->shape(), ctx->parallel_ctx().parallel_id());\n      return std::make_shared<SparseCrossEntropyOpKernelCache>(view.At(class_axis).begin(),\n                                                               
view.At(class_axis).end());\n    } else {\n      return nullptr;\n    }\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex(\"prediction\", 0);\n    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t num_instances = label->shape_view().elem_cnt();\n    CHECK_EQ(prediction->shape_view().elem_cnt() % num_instances, 0);\n    const int64_t num_classes = prediction->shape_view().elem_cnt() / num_instances;\n    const int64_t depth = ctx->Attr<int64_t>(\"depth\");\n    int64_t lower_bound = 0;\n    if (cache != nullptr) {\n      auto* kernel_cache = dynamic_cast<const SparseCrossEntropyOpKernelCache*>(cache);\n      CHECK_NOTNULL(kernel_cache);\n      CHECK_EQ(num_classes, kernel_cache->upper() - kernel_cache->lower());\n      lower_bound = kernel_cache->lower();\n    }\n    Memset<device_type>(ctx->stream(), out->mut_dptr(), 0,\n                        out->shape_view().elem_cnt() * GetSizeOfDataType(out->data_type()));\n    SparseCrossEntropyKernelUtil<device_type, T, K>::ComputeEntropy(\n        ctx->stream(), num_instances, num_classes, depth, lower_bound, prediction->dptr<T>(),\n        label->dptr<K>(), out->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SPARSE_CROSS_ENTROPY_KERNEL(kernel_class, kernel_name, device_type_v, dtype_pair, \\\n                                             ltype_pair)                                           \\\n  REGISTER_USER_KERNEL(kernel_name)                                                                \\\n      .SetCreateFn<kernel_class<device_type_v, OF_PP_PAIR_FIRST(dtype_pair),                       \\\n                                OF_PP_PAIR_FIRST(ltype_pair)>>()                                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type_v)                                 \\\n                       && (user_op::HobDataType(\"label\", 0) == OF_PP_PAIR_SECOND(ltype_pair))      \\\n                       && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(dtype_pair)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_CROSS_ENTROPY_KERNEL, (SparseCrossEntropyKernel),\n                                 (\"sparse_cross_entropy\"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU),\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n#ifdef WITH_CUDA\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_CROSS_ENTROPY_KERNEL, (SparseCrossEntropyKernel),\n                                 (\"sparse_cross_entropy\"), OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                                 FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n#endif\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_CROSS_ENTROPY_KERNEL, (SparseCrossEntropyMsKernel),\n                                 (\"sparse_cross_entropy_ms\"),\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n#ifdef WITH_CUDA\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_CROSS_ENTROPY_KERNEL, (SparseCrossEntropyMsKernel),\n                                 (\"sparse_cross_entropy_ms\"),\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                   
              FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n#endif\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass SparseCrossEntropyGradKernel final : public user_op::OpKernel {\n public:\n  SparseCrossEntropyGradKernel() = default;\n  ~SparseCrossEntropyGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex(\"prediction\", 0);\n    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* prediction_diff = ctx->Tensor4ArgNameAndIndex(\"prediction_diff\", 0);\n    const int64_t num_instances = label->shape_view().elem_cnt();\n    CHECK_EQ(prediction->shape_view().elem_cnt() % num_instances, 0);\n    const int64_t num_classes = prediction->shape_view().elem_cnt() / num_instances;\n    const int64_t lower_bound = 0;\n    const int64_t depth = ctx->Attr<int64_t>(\"depth\");\n    size_t prediction_diff_bytes_size =\n        prediction_diff->shape_view().elem_cnt() * GetSizeOfDataType(prediction_diff->data_type());\n    Memset<device_type>(ctx->stream(), prediction_diff->mut_dptr<T>(), 0,\n                        prediction_diff_bytes_size);\n    SparseCrossEntropyKernelUtil<device_type, T, K>::ComputeDiff(\n        ctx->stream(), num_instances, num_classes, depth, lower_bound, prediction->dptr<T>(),\n        label->dptr<K>(), dy->dptr<T>(), prediction_diff->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass SparseCrossEntropyMsGradKernel final : public user_op::OpKernel {\n public:\n  SparseCrossEntropyMsGradKernel() = default;\n  ~SparseCrossEntropyMsGradKernel() = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    if (ctx->parallel_ctx().parallel_num() > 1) {\n      const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"prediction\", 0);\n      const Shape& hierarchy = *ctx->parallel_desc().hierarchy();\n      const TensorDesc* prediction_logical_desc =\n          ctx->LogicalTensorDesc4ArgNameAndIndex(\"prediction\", 0);\n      const int64_t class_axis = prediction_logical_desc->shape().NumAxes() - 1;\n      TensorSliceView view = GetTensorSliceView4ParallelId(\n          hierarchy, nd_sbp, prediction_logical_desc->shape(), ctx->parallel_ctx().parallel_id());\n      return std::make_shared<SparseCrossEntropyOpKernelCache>(view.At(class_axis).begin(),\n                                                               view.At(class_axis).end());\n    } else {\n      return nullptr;\n    }\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex(\"prediction\", 0);\n    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* prediction_diff = ctx->Tensor4ArgNameAndIndex(\"prediction_diff\", 0);\n    const int64_t num_instances = label->shape_view().elem_cnt();\n    CHECK_EQ(prediction->shape_view().elem_cnt() % num_instances, 0);\n    const int64_t num_classes = prediction->shape_view().elem_cnt() / num_instances;\n    const int64_t depth = 
ctx->Attr<int64_t>(\"depth\");\n    int64_t lower_bound = 0;\n    if (cache != nullptr) {\n      auto* kernel_cache = dynamic_cast<const SparseCrossEntropyOpKernelCache*>(cache);\n      CHECK_NOTNULL(kernel_cache);\n      CHECK_EQ(num_classes, kernel_cache->upper() - kernel_cache->lower());\n      lower_bound = kernel_cache->lower();\n    }\n    size_t prediction_diff_bytes_size =\n        prediction_diff->shape_view().elem_cnt() * GetSizeOfDataType(prediction_diff->data_type());\n    Memset<device_type>(ctx->stream(), prediction_diff->mut_dptr<T>(), 0,\n                        prediction_diff_bytes_size);\n    SparseCrossEntropyKernelUtil<device_type, T, K>::ComputeDiff(\n        ctx->stream(), num_instances, num_classes, depth, lower_bound, prediction->dptr<T>(),\n        label->dptr<K>(), dy->dptr<T>(), prediction_diff->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SPARSE_CROSS_ENTROPY_GRAD_KERNEL(kernel_class, kernel_name, device_type_v, \\\n                                                  dtype_pair, ltype_pair)                   \\\n  REGISTER_USER_KERNEL(kernel_name)                                                         \\\n      .SetCreateFn<kernel_class<device_type_v, OF_PP_PAIR_FIRST(dtype_pair),                \\\n                                OF_PP_PAIR_FIRST(ltype_pair)>>()                            \\\n      .SetIsMatchedHob(                                                                     \\\n          (user_op::HobDeviceType() == device_type_v)                                       \\\n          && (user_op::HobDataType(\"label\", 0) == OF_PP_PAIR_SECOND(ltype_pair))            \\\n          && (user_op::HobDataType(\"prediction_diff\", 0) == OF_PP_PAIR_SECOND(dtype_pair)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_CROSS_ENTROPY_GRAD_KERNEL,\n                                 (SparseCrossEntropyGradKernel), (\"sparse_cross_entropy_grad\"),\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n#ifdef WITH_CUDA\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_CROSS_ENTROPY_GRAD_KERNEL,\n                                 (SparseCrossEntropyGradKernel), (\"sparse_cross_entropy_grad\"),\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                                 FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n#endif\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_CROSS_ENTROPY_GRAD_KERNEL,\n                                 (SparseCrossEntropyMsGradKernel), (\"sparse_cross_entropy_ms_grad\"),\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n#ifdef WITH_CUDA\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_CROSS_ENTROPY_GRAD_KERNEL,\n                                 (SparseCrossEntropyMsGradKernel), (\"sparse_cross_entropy_ms_grad\"),\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                                 FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n#endif\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/sparse_cross_entropy_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/sparse_cross_entropy_kernel_util.h\"\n#include \"oneflow/core/kernel/kernel_util.cuh\"\n\nnamespace oneflow {\nnamespace user_op {\n\ntemplate<typename T, typename K>\nstruct SparseCrossEntropyKernelUtil<DeviceType::kCPU, T, K> {\n  static void ComputeEntropy(ep::Stream* stream, const int64_t num_instances,\n                             const int64_t num_classes, const int64_t depth,\n                             const int64_t lower_bound, const T* x, const K* labels, T* y) {\n    FOR_RANGE(int64_t, i, 0, num_instances) {\n      CHECK_GE(labels[i], 0);\n      CHECK_LT(labels[i], depth);\n      K label = labels[i] - lower_bound;\n      if (label >= 0 && label < num_classes) { y[i] = -SafeLog(x[i * num_classes + label]); }\n    }\n  }\n\n  static void ComputeDiff(ep::Stream* stream, const int64_t num_instances,\n                          const int64_t num_classes, const int64_t depth, const int64_t lower_bound,\n                          const T* x, const K* labels, const T* dy, T* dx) {\n    FOR_RANGE(int64_t, i, 0, num_instances) {\n      CHECK_GE(labels[i], 0);\n      CHECK_LT(labels[i], depth);\n      K label = labels[i] - lower_bound;\n      if (label >= 0 && label < num_classes) {\n        dx[i * num_classes + label] = -dy[i] / MaxWithLogThreshold(x[i * num_classes + label]);\n      }\n    }\n  }\n\n  static void ComputeDiffWithSoftmax(ep::Stream* stream, const int64_t elem_cnt,\n                                     const int64_t num_classes, const int64_t depth,\n                                     const int64_t lower_bound, const T* prob, const K* labels,\n                                     const T* dy, T* dx) {\n    FOR_RANGE(int64_t, i, 0, elem_cnt) {\n      const int32_t row_id = i / num_classes;\n      const int32_t col_id = i - row_id * num_classes;\n      CHECK_GE(labels[row_id], 0);\n      CHECK_LT(labels[row_id], depth);\n      K label = labels[row_id] - lower_bound;\n\n      if (label == col_id) {\n        dx[i] = dy[row_id] * (prob[i] - 1);\n      } else {\n        dx[i] = dy[row_id] * prob[i];\n      }\n    }\n  }\n};\n\n#define INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CPU(data_type_pair, index_type_pair)          \\\n  template struct SparseCrossEntropyKernelUtil<DeviceType::kCPU, OF_PP_PAIR_FIRST(data_type_pair), \\\n                                               OF_PP_PAIR_FIRST(index_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CPU,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ);\n#undef INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CPU\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/sparse_cross_entropy_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/sparse_cross_entropy_kernel_util.h\"\n#include \"oneflow/core/kernel/kernel_util.cuh\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\nnamespace {\n\ntemplate<typename T, typename K>\n__global__ void ComputeEntropyGpu(const int64_t num_instances, const int64_t num_classes,\n                                  const int64_t depth, const int64_t lower_bound, const T* x,\n                                  const K* labels, T* y) {\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_instances) {\n    assert(labels[i] >= 0);\n    assert(labels[i] < depth);\n    K label = labels[i] - lower_bound;\n    if (label >= 0 && label < num_classes) { y[i] = -SafeLog(x[i * num_classes + label]); }\n  }\n}\n\ntemplate<typename K>\n__global__ void ComputeEntropyGpuHalf(const int64_t num_instances, const int64_t num_classes,\n                                      const int64_t depth, const int64_t lower_bound, const half* x,\n                                      const K* labels, half* y) {\n#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_instances) {\n    assert(labels[i] >= 0);\n    assert(labels[i] < depth);\n    K label = labels[i] - lower_bound;\n    if (label >= 0 && label < num_classes) {\n      y[i] = __float2half(-SafeLog(__half2float(x[i * num_classes + label])));\n    }\n  }\n#else\n  printf(\"use half need nvcc arch >= 530\");\n  assert(false);\n#endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/\n}\n\ntemplate<typename T, typename K>\n__global__ void ComputeDiffGpu(const int64_t num_instances, const int64_t num_classes,\n                               const int64_t depth, const int64_t lower_bound, const T* x,\n                               const K* labels, const T* dy, T* dx) {\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_instances) {\n    assert(labels[i] >= 0);\n    assert(labels[i] < depth);\n    K label = labels[i] - lower_bound;\n    if (label >= 0 && label < num_classes) {\n      dx[i * num_classes + label] = -dy[i] / MaxWithLogThreshold(x[i * num_classes + label]);\n    }\n  }\n}\n\ntemplate<typename K>\n__global__ void ComputeDiffGpuHalf(const int64_t num_instances, const int64_t num_classes,\n                                   const int64_t depth, const int64_t lower_bound, const half* x,\n                                   const K* labels, const half* dy, half* dx) {\n#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_instances) {\n    assert(labels[i] >= 0);\n    assert(labels[i] < depth);\n    K label = labels[i] - lower_bound;\n    if (label >= 0 && label < num_classes) {\n      dx[i * num_classes + label] =\n          __hneg(__hdiv(__float2half(dy[i]), MaxWithLogThreshold(x[i * num_classes + label])));\n    }\n  }\n#else\n  printf(\"use half need nvcc arch >= 
530\");\n  assert(false);\n#endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/\n}\n\ntemplate<typename T, typename K, typename IndexType>\n__global__ void ComputeDiffWithSoftmaxGpu(const int64_t elem_cnt, const int64_t num_classes,\n                                          const int64_t depth, const int64_t lower_bound,\n                                          const T* prob, const K* labels, const T* dy, T* dx) {\n  CUDA_1D_KERNEL_LOOP_T(IndexType, i, elem_cnt) {\n    const IndexType row_id = i / num_classes;\n    const IndexType col_id = i - row_id * num_classes;\n    assert(labels[row_id] >= 0);\n    assert(labels[row_id] < depth);\n    K label = labels[row_id] - lower_bound;\n    if (label == col_id) {\n      dx[i] = dy[row_id] * (prob[i] - 1);\n    } else {\n      dx[i] = dy[row_id] * prob[i];\n    }\n  }\n}\n\ntemplate<typename K, typename IndexType>\n__global__ void ComputeDiffWithSoftmaxGpuHalf(const int64_t elem_cnt, const int64_t num_classes,\n                                              const int64_t depth, const int64_t lower_bound,\n                                              const half* prob, const K* labels, const half* dy,\n                                              half* dx) {\n#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)\n  CUDA_1D_KERNEL_LOOP_T(IndexType, i, elem_cnt) {\n    // NOTE(chengcheng): int division ('/') of i will reduce performance of int64_t.\n    const IndexType row_id = i / num_classes;\n    const IndexType col_id = i - row_id * num_classes;\n    assert(labels[row_id] >= 0);\n    assert(labels[row_id] < depth);\n    K label = labels[row_id] - lower_bound;\n    if (label == col_id) {\n      dx[i] = __hmul(dy[row_id], __hsub(prob[i], __float2half(1.0)));\n    } else {\n      dx[i] = __hmul(dy[row_id], prob[i]);\n    }\n  }\n#else\n  printf(\"use half need nvcc arch >= 530\");\n  assert(false);\n#endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/\n}\n\ntemplate<typename K, typename IndexType>\n__global__ void ComputeDiffWithSoftmaxGpuHalf2(const int64_t elem_cnt, const int64_t num_classes,\n                                               const int64_t depth, const int64_t lower_bound,\n                                               const half* prob, const K* labels, const half* dy,\n                                               half* dx) {\n#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)\n  const int64_t h2_num_classes = num_classes / 2;\n  const int64_t h2_elem_cnt = elem_cnt / 2;\n  const auto* prob_h2 = reinterpret_cast<const half2*>(prob);\n  auto* dx_h2 = reinterpret_cast<half2*>(dx);\n  CUDA_1D_KERNEL_LOOP_T(IndexType, i, h2_elem_cnt) {\n    const IndexType row_id = i / h2_num_classes;\n    const IndexType h2_col_id = i - row_id * h2_num_classes;\n    assert(labels[row_id] >= 0);\n    assert(labels[row_id] < depth);\n    K label = labels[row_id] - lower_bound;\n    const half2 prob_h2_i = prob_h2[i];\n    const half dy_row = dy[row_id];\n    half2 dx_h2_i;\n    dx_h2_i.x = __hmul(dy_row, __hsub(prob_h2_i.x, static_cast<half>(label == 2 * h2_col_id)));\n    dx_h2_i.y = __hmul(dy_row, __hsub(prob_h2_i.y, static_cast<half>(label == 2 * h2_col_id + 1)));\n    dx_h2[i] = dx_h2_i;\n  }\n#else\n  printf(\"use half need nvcc arch >= 530\");\n  assert(false);\n#endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/\n}\n\n}  // namespace\n\ntemplate<typename T, typename K>\nstruct SparseCrossEntropyKernelUtil<DeviceType::kCUDA, T, K> {\n  static void ComputeEntropy(ep::Stream* stream, const int64_t num_instances,\n            
                 const int64_t num_classes, const int64_t depth,\n                             const int64_t lower_bound, const T* x, const K* labels, T* y) {\n    ComputeEntropyGpu<<<BlocksNum4ThreadsNum(num_instances), kCudaThreadsNumPerBlock, 0,\n                        stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        num_instances, num_classes, depth, lower_bound, x, labels, y);\n  }\n\n  static void ComputeDiff(ep::Stream* stream, const int64_t num_instances,\n                          const int64_t num_classes, const int64_t depth, const int64_t lower_bound,\n                          const T* x, const K* labels, const T* dy, T* dx) {\n    ComputeDiffGpu<<<BlocksNum4ThreadsNum(num_instances), kCudaThreadsNumPerBlock, 0,\n                     stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        num_instances, num_classes, depth, lower_bound, x, labels, dy, dx);\n  }\n\n  static void ComputeDiffWithSoftmax(ep::Stream* stream, const int64_t elem_cnt,\n                                     const int64_t num_classes, const int64_t depth,\n                                     const int64_t lower_bound, const T* prob, const K* labels,\n                                     const T* dy, T* dx) {\n    if (elem_cnt < GetMaxVal<int32_t>() / 2) {\n      ComputeDiffWithSoftmaxGpu<T, K, int32_t>\n          <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n             stream->As<ep::CudaStream>()->cuda_stream()>>>(elem_cnt, num_classes, depth,\n                                                            lower_bound, prob, labels, dy, dx);\n    } else {\n      ComputeDiffWithSoftmaxGpu<T, K, int64_t>\n          <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n             stream->As<ep::CudaStream>()->cuda_stream()>>>(elem_cnt, num_classes, depth,\n                                                            lower_bound, prob, labels, dy, dx);\n    }\n  }\n};\n\ntemplate<typename K>\nstruct SparseCrossEntropyKernelUtil<DeviceType::kCUDA, float16, K> {\n  static void ComputeEntropy(ep::Stream* stream, const int64_t num_instances,\n                             const int64_t num_classes, const int64_t depth,\n                             const int64_t lower_bound, const float16* x, const K* labels,\n                             float16* y) {\n    ComputeEntropyGpuHalf<K><<<BlocksNum4ThreadsNum(num_instances), kCudaThreadsNumPerBlock, 0,\n                               stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        num_instances, num_classes, depth, lower_bound, reinterpret_cast<const half*>(x), labels,\n        reinterpret_cast<half*>(y));\n  }\n\n  static void ComputeDiff(ep::Stream* stream, const int64_t num_instances,\n                          const int64_t num_classes, const int64_t depth, const int64_t lower_bound,\n                          const float16* x, const K* labels, const float16* dy, float16* dx) {\n    ComputeDiffGpuHalf<K><<<BlocksNum4ThreadsNum(num_instances), kCudaThreadsNumPerBlock, 0,\n                            stream->As<ep::CudaStream>()->cuda_stream()>>>(\n        num_instances, num_classes, depth, lower_bound, reinterpret_cast<const half*>(x), labels,\n        reinterpret_cast<const half*>(dy), reinterpret_cast<half*>(dx));\n  }\n\n  static void ComputeDiffWithSoftmax(ep::Stream* stream, const int64_t elem_cnt,\n                                     const int64_t num_classes, const int64_t depth,\n                                     const int64_t lower_bound, const float16* prob,\n                                     const K* 
labels, const float16* dy, float16* dx) {\n    if (num_classes % 2 == 0) {\n      if (elem_cnt < GetMaxVal<int32_t>() / 2) {\n        ComputeDiffWithSoftmaxGpuHalf2<K, int32_t>\n            <<<BlocksNum4ThreadsNum(elem_cnt / 2), kCudaThreadsNumPerBlock, 0,\n               stream->As<ep::CudaStream>()->cuda_stream()>>>(\n                elem_cnt, num_classes, depth, lower_bound, reinterpret_cast<const half*>(prob),\n                labels, reinterpret_cast<const half*>(dy), reinterpret_cast<half*>(dx));\n      } else {\n        ComputeDiffWithSoftmaxGpuHalf2<K, int64_t>\n            <<<BlocksNum4ThreadsNum(elem_cnt / 2), kCudaThreadsNumPerBlock, 0,\n               stream->As<ep::CudaStream>()->cuda_stream()>>>(\n                elem_cnt, num_classes, depth, lower_bound, reinterpret_cast<const half*>(prob),\n                labels, reinterpret_cast<const half*>(dy), reinterpret_cast<half*>(dx));\n      }\n    } else {\n      if (elem_cnt < GetMaxVal<int32_t>() / 2) {\n        ComputeDiffWithSoftmaxGpuHalf<K, int32_t>\n            <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n               stream->As<ep::CudaStream>()->cuda_stream()>>>(\n                elem_cnt, num_classes, depth, lower_bound, reinterpret_cast<const half*>(prob),\n                labels, reinterpret_cast<const half*>(dy), reinterpret_cast<half*>(dx));\n      } else {\n        ComputeDiffWithSoftmaxGpuHalf<K, int64_t>\n            <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n               stream->As<ep::CudaStream>()->cuda_stream()>>>(\n                elem_cnt, num_classes, depth, lower_bound, reinterpret_cast<const half*>(prob),\n                labels, reinterpret_cast<const half*>(dy), reinterpret_cast<half*>(dx));\n      }\n    }\n  }\n};\n\n#define INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CUDA(data_type_pair, index_type_pair) \\\n  template struct SparseCrossEntropyKernelUtil<                                            \\\n      DeviceType::kCUDA, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(index_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CUDA,\n                                 FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ);\n#undef INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CUDA\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/sparse_cross_entropy_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_H_\n\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\ntemplate<DeviceType device_type, typename T, typename K>\nstruct SparseCrossEntropyKernelUtil {\n  static void ComputeEntropy(ep::Stream* stream, const int64_t num_instances,\n                             const int64_t num_classes, const int64_t depth,\n                             const int64_t lower_bound, const T* x, const K* labels, T* y);\n  static void ComputeDiff(ep::Stream* stream, const int64_t num_instances,\n                          const int64_t num_classes, const int64_t depth, const int64_t lower_bound,\n                          const T* x, const K* labels, const T* dy, T* dx);\n  static void ComputeDiffWithSoftmax(ep::Stream* stream, const int64_t elem_cnt,\n                                     const int64_t num_classes, const int64_t depth,\n                                     const int64_t lower_bound, const T* prob, const K* labels,\n                                     const T* dy, T* dx);\n};\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/sparse_softmax_cross_entropy_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/user/kernels/sparse_cross_entropy_kernel_util.h\"\n#include \"oneflow/user/kernels/sparse_softmax_cross_entropy_kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/ep/include/primitive/log_softmax.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::LogSoftmax> NewLogSoftmaxPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"prediction\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::LogSoftmaxFactory>(ctx->device_type(),\n                                                                       data_type);\n}\n\nauto LogSoftmaxPrimitiveExists() {\n  return hob::make_custom(\"LogSoftmaxPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewLogSoftmaxPrimitive(&ctx).operator bool();\n  });\n}\n\nclass SparseSoftmaxCrossEntropyOpKernelCache final : public user_op::OpKernelCache {\n public:\n  SparseSoftmaxCrossEntropyOpKernelCache(int64_t lower, int64_t upper)\n      : lower_(lower), upper_(upper) {}\n  ~SparseSoftmaxCrossEntropyOpKernelCache() override = default;\n\n  int64_t lower() const { return lower_; }\n  int64_t upper() const { return upper_; }\n\n private:\n  const int64_t lower_;\n  const int64_t upper_;\n};\n\n}  // namespace\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass SparseSoftmaxCrossEntropyKernel final : public user_op::OpKernel,\n                                              public user_op::CudaGraphSupport {\n public:\n  SparseSoftmaxCrossEntropyKernel() = default;\n  ~SparseSoftmaxCrossEntropyKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex(\"prediction\", 0);\n    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex(\"prob\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t num_instances = label->shape_view().elem_cnt();\n    CHECK_EQ(prediction->shape_view().elem_cnt() % num_instances, 0);\n    const int64_t num_classes = prediction->shape_view().elem_cnt() / num_instances;\n    const int64_t lower_bound = 0;\n    const int64_t depth = ctx->Attr<int64_t>(\"depth\");\n\n    std::unique_ptr<ep::primitive::LogSoftmax> primitive = NewLogSoftmaxPrimitive(ctx);\n    CHECK(primitive);\n    primitive->Launch(ctx->stream(), num_instances, num_classes, prediction->dptr(),\n                      prob->mut_dptr());\n\n    const K* labels = label->dptr<K>();\n    const T* prob_ptr = prob->dptr<T>();\n    T* out_ptr = out->mut_dptr<T>();\n\n    
FOR_RANGE(int64_t, i, 0, num_instances) {\n      CHECK_GE(labels[i], 0);\n      CHECK_LT(labels[i], depth);\n      K _label = labels[i] - lower_bound;\n      if (_label >= 0 && _label < num_classes) { out_ptr[i] = -prob_ptr[i * num_classes + _label]; }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass SparseSoftmaxCrossEntropyMsKernel final : public user_op::OpKernel {\n public:\n  SparseSoftmaxCrossEntropyMsKernel() = default;\n  ~SparseSoftmaxCrossEntropyMsKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    LOG(FATAL) << \"SparseSoftmaxCrossEntropyMsKernel should be split into ops\";\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL(kernel_class, kernel_name, device_type_v, \\\n                                                     dtype_pair, ltype_pair)                   \\\n  REGISTER_USER_KERNEL(kernel_name)                                                            \\\n      .SetCreateFn<kernel_class<device_type_v, OF_PP_PAIR_FIRST(dtype_pair),                   \\\n                                OF_PP_PAIR_FIRST(ltype_pair)>>()                               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type_v)                             \\\n                       && (user_op::HobDataType(\"label\", 0) == OF_PP_PAIR_SECOND(ltype_pair))  \\\n                       && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(dtype_pair))    \\\n                       && LogSoftmaxPrimitiveExists());\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL,\n                                 (SparseSoftmaxCrossEntropyKernel),\n                                 (\"sparse_softmax_cross_entropy\"),\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL,\n                                 (SparseSoftmaxCrossEntropyMsKernel),\n                                 (\"sparse_softmax_cross_entropy_ms\"),\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n#ifdef WITH_CUDA\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL,\n                                 (SparseSoftmaxCrossEntropyMsKernel),\n                                 (\"sparse_softmax_cross_entropy_ms\"),\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                                 FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n#endif\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass SparseSoftmaxCrossEntropyGradKernel final : public user_op::OpKernel,\n                                                  public user_op::CudaGraphSupport {\n public:\n  SparseSoftmaxCrossEntropyGradKernel() = default;\n  ~SparseSoftmaxCrossEntropyGradKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex(\"prob\", 0);\n    user_op::Tensor* 
prediction_diff = ctx->Tensor4ArgNameAndIndex(\"prediction_diff\", 0);\n    const int64_t num_instances = label->shape_view().elem_cnt();\n    CHECK_EQ(prob->shape_view().elem_cnt() % num_instances, 0);\n    const int64_t num_classes = prob->shape_view().elem_cnt() / num_instances;\n    const int64_t lower_bound = 0;\n    const int64_t depth = ctx->Attr<int64_t>(\"depth\");\n    SparseSoftmaxCrossEntropyKernelUtil<device_type, T, K>::ComputeDiff(\n        ctx->stream(), prediction_diff->shape_view().elem_cnt(), num_classes, depth, lower_bound,\n        prob->dptr<T>(), label->dptr<K>(), dy->dptr<T>(), prediction_diff->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass SparseSoftmaxCrossEntropyMsGradKernel final : public user_op::OpKernel {\n public:\n  SparseSoftmaxCrossEntropyMsGradKernel() = default;\n  ~SparseSoftmaxCrossEntropyMsGradKernel() = default;\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    if (ctx->parallel_ctx().parallel_num() > 1) {\n      const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"prob\", 0);\n      const Shape& hierarchy = *ctx->parallel_desc().hierarchy();\n      const TensorDesc* prob_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex(\"prob\", 0);\n      const int64_t class_axis = prob_logical_desc->shape().NumAxes() - 1;\n      TensorSliceView view = GetTensorSliceView4ParallelId(\n          hierarchy, nd_sbp, prob_logical_desc->shape(), ctx->parallel_ctx().parallel_id());\n      return std::make_shared<SparseSoftmaxCrossEntropyOpKernelCache>(view.At(class_axis).begin(),\n                                                                      view.At(class_axis).end());\n    } else {\n      return nullptr;\n    }\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex(\"prob\", 0);\n    user_op::Tensor* prediction_diff = ctx->Tensor4ArgNameAndIndex(\"prediction_diff\", 0);\n    const int64_t num_instances = label->shape_view().elem_cnt();\n    CHECK_EQ(prob->shape_view().elem_cnt() % num_instances, 0);\n    const int64_t num_classes = prob->shape_view().elem_cnt() / num_instances;\n    const int64_t depth = ctx->Attr<int64_t>(\"depth\");\n    int64_t lower_bound = 0;\n    if (cache != nullptr) {\n      auto* kernel_cache = dynamic_cast<const SparseSoftmaxCrossEntropyOpKernelCache*>(cache);\n      CHECK_NOTNULL(kernel_cache);\n      CHECK_EQ(num_classes, kernel_cache->upper() - kernel_cache->lower());\n      lower_bound = kernel_cache->lower();\n    }\n    SparseCrossEntropyKernelUtil<device_type, T, K>::ComputeDiffWithSoftmax(\n        ctx->stream(), prediction_diff->shape_view().elem_cnt(), num_classes, depth, lower_bound,\n        prob->dptr<T>(), label->dptr<K>(), dy->dptr<T>(), prediction_diff->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_GRAD_KERNEL(kernel_class, kernel_name,             \\\n                                                          device_type_v, dtype_pair, ltype_pair) \\\n  REGISTER_USER_KERNEL(kernel_name)                   
                                           \\\n      .SetCreateFn<kernel_class<device_type_v, OF_PP_PAIR_FIRST(dtype_pair),                     \\\n                                OF_PP_PAIR_FIRST(ltype_pair)>>()                                 \\\n      .SetIsMatchedHob(                                                                          \\\n          (user_op::HobDeviceType() == device_type_v)                                            \\\n          && (user_op::HobDataType(\"label\", 0) == OF_PP_PAIR_SECOND(ltype_pair))                 \\\n          && (user_op::HobDataType(\"prediction_diff\", 0) == OF_PP_PAIR_SECOND(dtype_pair)))      \\\n      .SetInplaceProposalFn([](const user_op::InferContext&,                                     \\\n                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {  \\\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"prediction_diff\", 0, \"prob\", 0, true));          \\\n        return Maybe<void>::Ok();                                                                \\\n      });\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_GRAD_KERNEL,\n                                 (SparseSoftmaxCrossEntropyGradKernel),\n                                 (\"sparse_softmax_cross_entropy_grad\"),\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n#ifdef WITH_CUDA\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_GRAD_KERNEL,\n                                 (SparseSoftmaxCrossEntropyGradKernel),\n                                 (\"sparse_softmax_cross_entropy_grad\"),\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                                 FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n#endif\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_GRAD_KERNEL,\n                                 (SparseSoftmaxCrossEntropyMsGradKernel),\n                                 (\"sparse_softmax_cross_entropy_ms_grad\"),\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), FLOATING_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n#ifdef WITH_CUDA\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_GRAD_KERNEL,\n                                 (SparseSoftmaxCrossEntropyMsGradKernel),\n                                 (\"sparse_softmax_cross_entropy_ms_grad\"),\n                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),\n                                 FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n#endif\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/sparse_softmax_cross_entropy_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/sparse_cross_entropy_kernel_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/cuda/softmax.cuh\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\nnamespace {\n\ntemplate<typename T>\nvoid ComputeProb(ep::Stream* stream, const int64_t row, const int64_t col, const T* in, T* prob) {\n  using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;\n  cuda::softmax::DirectLoad<T, ComputeType> load(in, col);\n  cuda::softmax::DirectStore<ComputeType, T> store(prob, col);\n  OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmax<decltype(load), decltype(store), ComputeType>(\n      stream->As<ep::CudaStream>()->cuda_stream(), load, store, row, col)));\n}\n\ntemplate<>\nvoid ComputeProb(ep::Stream* stream, const int64_t row, const int64_t col, const float16* in,\n                 float16* prob) {\n  cuda::softmax::DirectLoad<half, float> load(reinterpret_cast<const half*>(in), col);\n  cuda::softmax::DirectStore<float, half> store(reinterpret_cast<half*>(prob), col);\n  OF_CUDA_CHECK((cuda::softmax::DispatchLogSoftmax<decltype(load), decltype(store), float>(\n      stream->As<ep::CudaStream>()->cuda_stream(), load, store, row, col)));\n}\n\ntemplate<typename T, typename K>\n__global__ void ComputeSparseSoftmaxCrossEntropyResultGpu(const int64_t num_instances,\n                                                          const int64_t num_classes,\n                                                          const int64_t depth,\n                                                          const int64_t lower_bound,\n                                                          const K* labels, const T* prob, T* out) {\n  CUDA_1D_KERNEL_LOOP_T(int64_t, i, num_instances) {\n    assert(labels[i] >= 0);\n    assert(labels[i] < depth);\n    K label = labels[i] - lower_bound;\n    if (label >= 0 && label < num_classes) { out[i] = -prob[i * num_classes + label]; }\n  }\n}\ntemplate<typename T, typename K>\ninline typename std::enable_if<std::is_floating_point<T>::value, void>::type\nComputeSparseSoftmaxCrossEntropyResult(ep::Stream* stream, const int64_t num_instances,\n                                       const int64_t num_classes, const int64_t depth,\n                                       const int64_t lower_bound, const K* labels, const T* prob,\n                                       T* out) {\n  ComputeSparseSoftmaxCrossEntropyResultGpu<T, K>\n      <<<BlocksNum4ThreadsNum(num_instances), kCudaThreadsNumPerBlock, 0,\n         stream->As<ep::CudaStream>()->cuda_stream()>>>(num_instances, num_classes, depth,\n                                                        lower_bound, labels, prob, out);\n}\ntemplate<typename T, typename K>\ninline typename std::enable_if<std::is_same<T, float16>::value, 
void>::type\nComputeSparseSoftmaxCrossEntropyResult(ep::Stream* stream, const int64_t num_instances,\n                                       const int64_t num_classes, const int64_t depth,\n                                       const int64_t lower_bound, const K* labels, const T* prob,\n                                       T* out) {\n#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)\n  ComputeSparseSoftmaxCrossEntropyResultGpu<half, K>\n      <<<BlocksNum4ThreadsNum(num_instances), kCudaThreadsNumPerBlock, 0,\n         stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          num_instances, num_classes, depth, lower_bound, labels,\n          reinterpret_cast<const half*>(prob), reinterpret_cast<half*>(out));\n#else\n  printf(\"half requires nvcc arch >= 530\");\n  assert(false);\n#endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/\n}\n}  // namespace\n\ntemplate<typename T, typename K>\nclass SparseSoftmaxCrossEntropyKernel final : public user_op::OpKernel,\n                                              public user_op::CudaGraphSupport {\n public:\n  SparseSoftmaxCrossEntropyKernel() = default;\n  ~SparseSoftmaxCrossEntropyKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex(\"prediction\", 0);\n    const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex(\"label\", 0);\n    user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex(\"prob\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const int64_t num_instances = label->shape_view().elem_cnt();\n    CHECK_EQ(prediction->shape_view().elem_cnt() % num_instances, 0);\n    const int64_t num_classes = prediction->shape_view().elem_cnt() / num_instances;\n    const int64_t lower_bound = 0;\n    const int64_t depth = ctx->Attr<int64_t>(\"depth\");\n\n    ComputeProb<T>(ctx->stream(), num_instances, num_classes, prediction->dptr<T>(),\n                   prob->mut_dptr<T>());\n    ComputeSparseSoftmaxCrossEntropyResult<T, K>(ctx->stream(), num_instances, num_classes, depth,\n                                                 lower_bound, label->dptr<K>(), prob->dptr<T>(),\n                                                 out->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL(dtype_pair, ltype_pair)                  \\\n  REGISTER_USER_KERNEL(\"sparse_softmax_cross_entropy\")                                        \\\n      .SetCreateFn<SparseSoftmaxCrossEntropyKernel<OF_PP_PAIR_FIRST(dtype_pair),              \\\n                                                   OF_PP_PAIR_FIRST(ltype_pair)>>()           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                        \\\n                       && (user_op::HobDataType(\"label\", 0) == OF_PP_PAIR_SECOND(ltype_pair)) \\\n                       && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(dtype_pair)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL,\n                                 FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/sparse_softmax_cross_entropy_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/sparse_softmax_cross_entropy_kernel_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\ntemplate<typename T, typename K>\nstruct SparseSoftmaxCrossEntropyKernelUtil<DeviceType::kCPU, T, K> {\n  static void ComputeDiff(ep::Stream* stream, const int64_t num_instances,\n                          const int64_t num_classes, const int64_t depth, const int64_t lower_bound,\n                          const T* prob, const K* labels, const T* dy, T* dx) {\n    FOR_RANGE(int64_t, i, 0, num_instances) {\n      const int32_t row_id = i / num_classes;\n      const int32_t col_id = i - row_id * num_classes;\n      CHECK_GE(labels[row_id], 0);\n      CHECK_LT(labels[row_id], depth);\n      K label = labels[row_id] - lower_bound;\n\n      if (label == col_id) {\n        dx[i] = dy[row_id] * (std::exp(prob[i]) - 1);\n      } else {\n        dx[i] = dy[row_id] * std::exp(prob[i]);\n      }\n    }\n  }\n};\n#define INSTANTIATE_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_CPU(data_type_pair, index_type_pair) \\\n  template struct SparseSoftmaxCrossEntropyKernelUtil<                                            \\\n      DeviceType::kCPU, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(index_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_CPU,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ);\n#undef INSTANTIATE_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_CPU\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/sparse_softmax_cross_entropy_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/sparse_softmax_cross_entropy_kernel_util.h\"\n#include \"oneflow/core/cuda/softmax.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\nnamespace user_op {\nnamespace {\n\ntemplate<typename T>\n__inline__ __device__ T Exp(T x);\n\ntemplate<>\n__inline__ __device__ float Exp<float>(float x) {\n#ifdef OF_SOFTMAX_USE_FAST_MATH\n  return __expf(x);\n#else\n  return exp(x);\n#endif\n}\n\ntemplate<>\n__inline__ __device__ double Exp<double>(double x) {\n  return exp(x);\n}\n\ntemplate<>\n__inline__ __device__ half Exp<half>(half x) {\n#ifdef OF_SOFTMAX_USE_FAST_MATH\n  return __float2half(__expf(__half2float(x)));\n#else\n  return __float2half(exp(__half2float(x)));\n#endif\n}\n\ntemplate<typename T, typename K, typename IndexType>\n__global__ void ComputeDiffGpu(const int64_t num_instances, const int64_t num_classes,\n                               const int64_t depth, const int64_t lower_bound, const T* prob,\n                               const K* labels, const T* dy, T* dx) {\n  CUDA_1D_KERNEL_LOOP_T(IndexType, i, num_instances) {\n    const IndexType row_id = i / num_classes;\n    const IndexType col_id = i - row_id * num_classes;\n    assert(labels[row_id] >= 0);\n    assert(labels[row_id] < depth);\n    K label = labels[row_id] - lower_bound;\n    if (label == col_id) {\n      dx[i] = dy[row_id] * (Exp(prob[i]) - 1);\n    } else {\n      dx[i] = dy[row_id] * Exp(prob[i]);\n    }\n  }\n}\n\ntemplate<typename K, typename IndexType>\n__global__ void ComputeDiffGpuHalf(const int64_t num_instances, const int64_t num_classes,\n                                   const int64_t depth, const int64_t lower_bound, const half* prob,\n                                   const K* labels, const half* dy, half* dx) {\n  CUDA_1D_KERNEL_LOOP_T(IndexType, i, num_instances) {\n    const IndexType row_id = i / num_classes;\n    const IndexType col_id = i - row_id * num_classes;\n    assert(labels[row_id] >= 0);\n    assert(labels[row_id] < depth);\n    K label = labels[row_id] - lower_bound;\n    if (label == col_id) {\n      dx[i] = __hmul(dy[row_id], __hsub(Exp(prob[i]), __float2half(1.0)));\n    } else {\n      dx[i] = __hmul(dy[row_id], Exp(prob[i]));\n    }\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename K>\nstruct SparseSoftmaxCrossEntropyKernelUtil<DeviceType::kCUDA, T, K> {\n  static void ComputeDiff(ep::Stream* stream, const int64_t num_instances,\n                          const int64_t num_classes, const int64_t depth, const int64_t lower_bound,\n                          const T* prob, const K* labels, const T* dy, T* dx) {\n    if (num_instances < GetMaxVal<int32_t>() / 2) {\n      ComputeDiffGpu<T, K, int32_t><<<BlocksNum4ThreadsNum(num_instances), kCudaThreadsNumPerBlock,\n                                      0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          num_instances, num_classes, depth, 
lower_bound, prob, labels, dy, dx);\n    } else {\n      // NOTE(chengcheng): integer division ('/') on int64_t indices is slow, so the int32_t\n      // specialization above is preferred whenever the element count fits.\n      ComputeDiffGpu<T, K, int64_t><<<BlocksNum4ThreadsNum(num_instances), kCudaThreadsNumPerBlock,\n                                      0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          num_instances, num_classes, depth, lower_bound, prob, labels, dy, dx);\n    }\n  }\n};\n\ntemplate<typename K>\nstruct SparseSoftmaxCrossEntropyKernelUtil<DeviceType::kCUDA, float16, K> {\n  static void ComputeDiff(ep::Stream* stream, const int64_t num_instances,\n                          const int64_t num_classes, const int64_t depth, const int64_t lower_bound,\n                          const float16* prob, const K* labels, const float16* dy, float16* dx) {\n    if (num_instances < GetMaxVal<int32_t>() / 2) {\n      ComputeDiffGpuHalf<K, int32_t><<<BlocksNum4ThreadsNum(num_instances), kCudaThreadsNumPerBlock,\n                                       0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          num_instances, num_classes, depth, lower_bound, reinterpret_cast<const half*>(prob),\n          labels, reinterpret_cast<const half*>(dy), reinterpret_cast<half*>(dx));\n    } else {\n      ComputeDiffGpuHalf<K, int64_t><<<BlocksNum4ThreadsNum(num_instances), kCudaThreadsNumPerBlock,\n                                       0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n          num_instances, num_classes, depth, lower_bound, reinterpret_cast<const half*>(prob),\n          labels, reinterpret_cast<const half*>(dy), reinterpret_cast<half*>(dx));\n    }\n  }\n};\n\n#define INSTANTIATE_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_CUDA(data_type_pair, index_type_pair) \\\n  template struct SparseSoftmaxCrossEntropyKernelUtil<                                             \\\n      DeviceType::kCUDA, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(index_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_CUDA,\n                                 FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ);\n#undef INSTANTIATE_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_CUDA\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/sparse_softmax_cross_entropy_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_H_\n\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\ntemplate<DeviceType device_type, typename T, typename K>\nstruct SparseSoftmaxCrossEntropyKernelUtil {\n  static void ComputeDiff(ep::Stream* stream, const int64_t elem_cnt, const int64_t num_classes,\n                          const int64_t depth, const int64_t lower_bound, const T* prob,\n                          const K* labels, const T* dy, T* dx);\n};\n\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/split_like_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/include/primitive/copy_nd.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::CopyNd> NewCopyNdPrimitive(Context* ctx) {\n  return ep::primitive::NewPrimitive<ep::primitive::CopyNdFactory>(ctx->device_type(), 2);\n}\n\nclass SplitLikeKernel final : public user_op::OpKernel {\n public:\n  SplitLikeKernel() = default;\n  ~SplitLikeKernel() override = default;\n\n private:\n  void InferShape(user_op::KernelInferContext* ctx) const override {\n    const auto axis = ctx->Attr<int64_t>(\"axis\");\n    const ShapeView& in_shape_view = ctx->ShapeView4ArgNameAndIndex(\"in\", 0);\n    int64_t total_dim_size = 0;\n    const int64_t like_num_axes = ctx->ShapeView4ArgNameAndIndex(\"like\", 0).NumAxes();\n    const int64_t in_num_axes = in_shape_view.NumAxes();\n    CHECK_LE(like_num_axes, in_num_axes);\n    CHECK_LT(axis, like_num_axes);\n    FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) {\n      const ShapeView& like_shape_view = ctx->ShapeView4ArgNameAndIndex(\"like\", i);\n      CHECK_EQ(like_shape_view.NumAxes(), like_num_axes);\n      FOR_RANGE(int64_t, j, 0, like_num_axes) {\n        if (j == axis) {\n          total_dim_size += like_shape_view.At(j);\n        } else {\n          CHECK_EQ(like_shape_view.At(j), in_shape_view.At(j));\n        }\n      }\n      if (ctx->TensorDesc4ArgNameAndIndex(\"out\", i)->is_dynamic()) {\n        auto mut_shape_view = ctx->MutShapeView4ArgNameAndIndex(\"out\", i);\n        DimVector out_i_dim_vec;\n        like_shape_view.ToDimVector(&out_i_dim_vec);\n        FOR_RANGE(int64_t, j, like_num_axes, in_num_axes) {\n          out_i_dim_vec.emplace_back(in_shape_view.At(j));\n        }\n        mut_shape_view.set_shape(Shape(out_i_dim_vec));\n      }\n    }\n    CHECK_EQ(total_dim_size, in_shape_view.At(axis));\n  }\n\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const auto axis = ctx->Attr<int64_t>(\"axis\");\n    const int64_t in_cols = in_tensor->shape_view().Count(axis);\n    const int64_t rows = in_tensor->shape_view().elem_cnt() / in_cols;\n    CHECK_GT(rows, 0);\n\n    auto primitive = NewCopyNdPrimitive(ctx);\n    CHECK(primitive);\n    int64_t in_col_offset = 0;\n    for (const auto& out_arg_pair : ctx->outputs()) {\n      user_op::Tensor* out_tensor =\n          ctx->Tensor4ArgNameAndIndex(out_arg_pair.first, out_arg_pair.second);\n      const int64_t out_cols = out_tensor->shape_view().Count(axis);\n      CHECK_EQ(out_tensor->shape_view().elem_cnt(), rows * out_cols);\n      if (out_cols > 0) {\n        DimVector dst_shape = {rows, out_cols};\n        DimVector dst_pos_vec = {0, 0};\n        DimVector src_shape = {rows, in_cols};\n        DimVector src_pos_vec = {0, in_col_offset};\n        DimVector extent_vec 
= {rows, out_cols};\n        primitive->Launch(ctx->stream(), out_tensor->data_type(), 2, out_tensor->mut_dptr(),\n                          dst_shape.data(), dst_pos_vec.data(), in_tensor->dptr(), src_shape.data(),\n                          src_pos_vec.data(), extent_vec.data());\n      }\n      in_col_offset += out_cols;\n    }\n    CHECK_EQ(in_col_offset, in_cols);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nauto CopyNdPrimitiveExists() {\n  return hob::make_custom(\"CopyNdPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewCopyNdPrimitive(&ctx).operator bool();\n  });\n}\n\n}  // namespace\n\nREGISTER_USER_KERNEL(\"split_like\")\n    .SetCreateFn<SplitLikeKernel>()\n    .SetIsMatchedHob(CopyNdPrimitiveExists() == true);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/sqrt_square_sum_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cstdint>\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/sqrt_square_sum_kernel_util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nint64_t getThreadNumBlocks(int64_t n) {\n  int64_t num_blocks = 1;\n#ifdef WITH_CUDA\n  num_blocks = BlocksNum4ThreadsNum(n);\n#endif\n  return num_blocks;\n}\n\ntemplate<DeviceType device_type, typename T>\nclass SqrtSquareSumKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  SqrtSquareSumKernel() = default;\n  ~SqrtSquareSumKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* tmp = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    SqrtSquareSumKernelUtil<device_type, T>::SqrtSquareSum(ctx->stream(),\n                                                           x->shape_view().elem_cnt(), x->dptr<T>(),\n                                                           y->mut_dptr<T>(), tmp->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SQUARE_SUM_KERNEL(device, dtype)                                     \\\n  REGISTER_USER_KERNEL(\"sqrt_square_sum\")                                             \\\n      .SetCreateFn<SqrtSquareSumKernel<device, OF_PP_PAIR_FIRST(dtype)>>()            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                           \\\n                       && (user_op::HobDataType(\"y\", 0) == OF_PP_PAIR_SECOND(dtype))) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                   \\\n        const auto& x_shape = ctx->InputTensorDesc(\"x\", 0).shape();                   \\\n        const int32_t num_blocks = getThreadNumBlocks(x_shape.Count(0));              \\\n        int64_t tmp_buffer_size = num_blocks;                                         \\\n        return tmp_buffer_size * sizeof(OF_PP_PAIR_FIRST(dtype));                     \\\n      });\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SQUARE_SUM_KERNEL, DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ)\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/sqrt_square_sum_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/sqrt_square_sum_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstruct SqrtSquareSumKernelUtil<DeviceType::kCPU, T> {\n  static void SqrtSquareSum(ep::Stream* stream, int64_t n, const T* x, T* y, T* tmp) {\n    T sum = 0;\n    FOR_RANGE(int64_t, i, 0, n) { sum += x[i] * x[i]; }\n    *y = std::sqrt(sum);\n  }\n};\n\n#define INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CPU(type_cpp, type_proto) \\\n  template struct SqrtSquareSumKernelUtil<DeviceType::kCPU, type_cpp>;\nOF_PP_FOR_EACH_TUPLE(INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CPU, FLOATING_DATA_TYPE_SEQ);\n#undef INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CPU\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/sqrt_square_sum_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/sqrt_square_sum_kernel_util.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <cub/cub.cuh>\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__global__ void SqrtSquareSumForOneThreadBlock(int64_t n, const T* x, T* y) {\n  T t_sum = 0;\n  CUDA_1D_KERNEL_LOOP(i, n) { t_sum += x[i] * x[i]; }\n  typedef cub::BlockReduce<T, kCudaThreadsNumPerBlock> BlockReduce;\n  __shared__ typename BlockReduce::TempStorage temp_storage;\n  T b_sum = BlockReduce(temp_storage).Sum(t_sum);\n  if (threadIdx.x == 0) { *y = sqrt(b_sum); }\n}\n\ntemplate<typename T>\n__global__ void SqrtSumForMultiThreadBlock(int64_t n, const T* x, T* y) {\n  T t_sum = 0;\n  CUDA_1D_KERNEL_LOOP(i, n) { t_sum += x[i]; }\n  typedef cub::BlockReduce<T, kCudaThreadsNumPerBlock> BlockReduce;\n  __shared__ typename BlockReduce::TempStorage temp_storage;\n  T b_sum = BlockReduce(temp_storage).Sum(t_sum);\n  if (threadIdx.x == 0) { *y = sqrt(b_sum); }\n}\n\ntemplate<typename T>\n__global__ void SquareSumForMultiThreadBlock(int64_t n, const T* x, T* tmp) {\n  T t_sum = 0;\n  CUDA_1D_KERNEL_LOOP(i, n) { t_sum += x[i] * x[i]; }\n  typedef cub::BlockReduce<T, kCudaThreadsNumPerBlock> BlockReduce;\n  __shared__ typename BlockReduce::TempStorage temp_storage;\n  T b_sum = BlockReduce(temp_storage).Sum(t_sum);\n  if (threadIdx.x == 0) { tmp[blockIdx.x] = b_sum; }\n}\n\n}  // namespace\n\ntemplate<typename T>\nstruct SqrtSquareSumKernelUtil<DeviceType::kCUDA, T> {\n  static void SqrtSquareSum(ep::Stream* stream, int64_t n, const T* x, T* y, T* tmp) {\n    const int32_t num_blocks = BlocksNum4ThreadsNum(n);\n    CHECK_GE(num_blocks, 0);\n    if (num_blocks == 1) {\n      SqrtSquareSumForOneThreadBlock<T>\n          <<<1, kCudaThreadsNumPerBlock, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, y);\n    } else {\n      Memset<DeviceType::kCUDA>(stream, y, 0, sizeof(T));\n      SquareSumForMultiThreadBlock<T>\n          <<<num_blocks, kCudaThreadsNumPerBlock, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              n, x, tmp);\n      SqrtSumForMultiThreadBlock<T>\n          <<<1, kCudaThreadsNumPerBlock, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              num_blocks, tmp, y);\n    }\n  }\n};\n\n#define INSTANTIATE_SQRT_SQUARE_SUM_KERNEL_UTIL_CUDA(type_cpp, type_proto) \\\n  template struct SqrtSquareSumKernelUtil<DeviceType::kCUDA, type_cpp>;\nOF_PP_FOR_EACH_TUPLE(INSTANTIATE_SQRT_SQUARE_SUM_KERNEL_UTIL_CUDA, FLOATING_DATA_TYPE_SEQ);\n#undef INSTANTIATE_SQRT_SQUARE_SUM_KERNEL_UTIL_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/sqrt_square_sum_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_SQUARE_SUM_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_SQUARE_SUM_KERNEL_UTIL_H_\n\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T>\nstruct SqrtSquareSumKernelUtil {\n  static void SqrtSquareSum(ep::Stream* stream, int64_t n, const T* x, T* y, T* tmp);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_SQUARE_SUM_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/square_sum_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/square_sum_kernel_util.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\ntemplate<DeviceType device_type, typename T>\nclass SquareSumKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  SquareSumKernel() = default;\n  ~SquareSumKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n\n    SquareSumKernelUtil<device_type, T>::SquareSum(ctx->stream(), x->shape_view().elem_cnt(),\n                                                   x->dptr<T>(), y->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SQUARE_SUM_KERNEL(device, dtype)                      \\\n  REGISTER_USER_KERNEL(\"square_sum\")                                   \\\n      .SetCreateFn<SquareSumKernel<device, OF_PP_PAIR_FIRST(dtype)>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)            \\\n                       && (user_op::HobDataType(\"y\", 0) == OF_PP_PAIR_SECOND(dtype)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SQUARE_SUM_KERNEL, DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ)\n\ntemplate<DeviceType device_type, typename T>\nclass MultiSquareSumKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  MultiSquareSumKernel() = default;\n  ~MultiSquareSumKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    std::vector<SquareSumParam<T>> params;\n    params.resize(ctx->input_size(\"x\"));\n    for (int64_t i = 0; i < params.size(); ++i) {\n      const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", i);\n      params[i].count = x->shape_view().elem_cnt();\n      params[i].ptr = x->dptr<T>();\n    }\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    SquareSumKernelUtil<device_type, T>::MultiSquareSum(ctx->stream(), params, y->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_MULTI_SQUARE_SUM_KERNEL(device, dtype)                     \\\n  REGISTER_USER_KERNEL(\"multi_square_sum\")                                  \\\n      .SetCreateFn<MultiSquareSumKernel<device, OF_PP_PAIR_FIRST(dtype)>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                 \\\n                       && (user_op::HobDataType(\"y\", 0) == OF_PP_PAIR_SECOND(dtype)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MULTI_SQUARE_SUM_KERNEL, DEVICE_TYPE_SEQ,\n                                 
FLOATING_DATA_TYPE_SEQ)\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/square_sum_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/square_sum_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstruct SquareSumKernelUtil<DeviceType::kCPU, T> {\n  static void SquareSum(ep::Stream* stream, int64_t n, const T* x, T* y) {\n    T sum = 0;\n    FOR_RANGE(int64_t, i, 0, n) { sum += x[i] * x[i]; }\n    *y = sum;\n  }\n\n  static void MultiSquareSum(ep::Stream* stream, const std::vector<SquareSumParam<T>>& params,\n                             T* y) {\n    T sum = 0;\n    FOR_RANGE(int64_t, i, 0, params.size()) {\n      const auto& p = params[i];\n      FOR_RANGE(int64_t, j, 0, p.count) { sum += p.ptr[j] * p.ptr[j]; }\n    }\n    *y = sum;\n  }\n};\n\n#define INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CPU(type_cpp, type_proto) \\\n  template struct SquareSumKernelUtil<DeviceType::kCPU, type_cpp>;\nOF_PP_FOR_EACH_TUPLE(INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CPU, FLOATING_DATA_TYPE_SEQ);\n#undef INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CPU\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/square_sum_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/square_sum_kernel_util.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <cub/cub.cuh>\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, bool ONE_BLOCK>\n__global__ void SquareSumGpu(int64_t n, const T* x, T* y) {\n  T t_sum = 0;\n  CUDA_1D_KERNEL_LOOP(i, n) { t_sum += x[i] * x[i]; }\n  typedef cub::BlockReduce<T, kCudaThreadsNumPerBlock> BlockReduce;\n  __shared__ typename BlockReduce::TempStorage temp_storage;\n  T b_sum = BlockReduce(temp_storage).Sum(t_sum);\n  if (threadIdx.x == 0) {\n    if (ONE_BLOCK) {\n      *y = b_sum;\n    } else {\n      cuda::atomic::Add(y, b_sum);\n    }\n  }\n}\n\nconstexpr int64_t kMultiSquareSumMaxSize = 64;\n\ntemplate<typename T>\nstruct MultiSquareSumParams {\n  SquareSumParam<T> params[kMultiSquareSumMaxSize];\n  int32_t size;\n};\n\ntemplate<typename T>\n__global__ void MultiSquareSumGpu(const MultiSquareSumParams<T> params, T* y) {\n  T t_sum = 0;\n  for (int i = 0; i < params.size; ++i) {\n    const SquareSumParam<T> param = params.params[i];\n    CUDA_1D_KERNEL_LOOP(j, param.count) { t_sum += param.ptr[j] * param.ptr[j]; }\n  }\n  typedef cub::BlockReduce<T, kCudaThreadsNumPerBlock> BlockReduce;\n  __shared__ typename BlockReduce::TempStorage temp_storage;\n  T b_sum = BlockReduce(temp_storage).Sum(t_sum);\n  if (threadIdx.x == 0) { cuda::atomic::Add(y, b_sum); }\n}\n\n}  // namespace\n\ntemplate<typename T>\nstruct SquareSumKernelUtil<DeviceType::kCUDA, T> {\n  static void SquareSum(ep::Stream* stream, int64_t n, const T* x, T* y) {\n    const int32_t num_blocks = BlocksNum4ThreadsNum(n);\n    CHECK_GE(num_blocks, 0);\n    if (num_blocks == 0) {\n      Memset<DeviceType::kCUDA>(stream, y, 0, sizeof(T));\n    } else if (num_blocks == 1) {\n      SquareSumGpu<T, true>\n          <<<1, kCudaThreadsNumPerBlock, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, y);\n    } else {\n      Memset<DeviceType::kCUDA>(stream, y, 0, sizeof(T));\n      SquareSumGpu<T, false>\n          <<<num_blocks, kCudaThreadsNumPerBlock, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              n, x, y);\n    }\n  }\n\n  static void MultiSquareSum(ep::Stream* stream, const std::vector<SquareSumParam<T>>& params,\n                             T* y) {\n    Memset<DeviceType::kCUDA>(stream, y, 0, sizeof(T));\n    for (int64_t start = 0; start < params.size(); start += kMultiSquareSumMaxSize) {\n      MultiSquareSumParams<T> gpu_params{};\n      int64_t max_count = 0;\n      gpu_params.size = std::min<int64_t>(start + kMultiSquareSumMaxSize, params.size()) - start;\n      for (int64_t i = 0; i < gpu_params.size; ++i) {\n        gpu_params.params[i] = params[start + i];\n        max_count = std::max(max_count, gpu_params.params[i].count);\n      }\n      MultiSquareSumGpu<T><<<BlocksNum4ThreadsNum(max_count), kCudaThreadsNumPerBlock, 0,\n                   
          stream->As<ep::CudaStream>()->cuda_stream()>>>(gpu_params, y);\n    }\n  }\n};\n\n#define INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CUDA(type_cpp, type_proto) \\\n  template struct SquareSumKernelUtil<DeviceType::kCUDA, type_cpp>;\nOF_PP_FOR_EACH_TUPLE(INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CUDA, FLOATING_DATA_TYPE_SEQ);\n#undef INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/square_sum_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_SQUARE_SUM_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_SQUARE_SUM_KERNEL_UTIL_H_\n\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstruct SquareSumParam {\n  const T* ptr;\n  int64_t count;\n};\n\ntemplate<DeviceType device_type, typename T>\nstruct SquareSumKernelUtil {\n  static void SquareSum(ep::Stream* stream, int64_t n, const T* x, T* y);\n  static void MultiSquareSum(ep::Stream* stream, const std::vector<SquareSumParam<T>>& params,\n                             T* y);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_SQUARE_SUM_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/ssp_variable_proxy_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<DeviceType device_type>\nclass SspVariableProxyKernel final : public user_op::OpKernel {\n public:\n  SspVariableProxyKernel() = default;\n  ~SspVariableProxyKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* var = ctx->Tensor4ArgNameAndIndex(\"var\", 0);\n    const user_op::Tensor* ref = ctx->Tensor4ArgNameAndIndex(\"ref\", 0);\n    CHECK_EQ(var->dptr(), ref->dptr());\n    user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex(\"value\", 0);\n    const ShapeView& in_shape = ref->shape_view();\n    CHECK_EQ(value->shape_view(), in_shape);\n    const DataType in_data_type = ref->data_type();\n    CHECK_EQ(value->data_type(), in_data_type);\n    Memcpy<device_type>(ctx->stream(), value->mut_dptr<void>(), ref->dptr<void>(),\n                        in_shape.elem_cnt() * GetSizeOfDataType(in_data_type));\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_SSP_VARIABLE_PROXY_KERNEL(device)                                              \\\n  REGISTER_USER_KERNEL(\"ssp_variable_proxy\")                                                    \\\n      .SetCreateFn<SspVariableProxyKernel<device>>()                                            \\\n      .SetIsMatchedHob(user_op::HobDeviceType() == device)                                      \\\n      .SetInplaceProposalFn([](const user_op::InferContext&,                                    \\\n                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \\\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"ref\", 0, \"var\", 0, true));                      \\\n        return Maybe<void>::Ok();                                                               \\\n      });\n\nREGISTER_SSP_VARIABLE_PROXY_KERNEL(DeviceType::kCPU)\n#ifdef WITH_CUDA\nREGISTER_SSP_VARIABLE_PROXY_KERNEL(DeviceType::kCUDA)\n#endif\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/stack_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/include/primitive/copy_nd.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::CopyNd> NewCopyNdPrimitive(Context* ctx) {\n  return ep::primitive::NewPrimitive<ep::primitive::CopyNdFactory>(ctx->device_type(), 2);\n}\n\nclass StackKernel final : public user_op::OpKernel {\n public:\n  StackKernel() = default;\n  ~StackKernel() = default;\n\n private:\n  void InferShape(user_op::KernelInferContext* ctx) const override {\n    const ShapeView& first_input_shape_view = ctx->ShapeView4ArgNameAndIndex(\"in\", 0);\n    const int64_t axis = ctx->Attr<int64_t>(\"axis\");\n    const int64_t in_num_axes = first_input_shape_view.NumAxes();\n    DimVector out_dim_vec(in_num_axes + 1);\n    for (int i = 0; i < in_num_axes + 1; i++) {\n      if (i == axis) {\n        continue;\n      } else {\n        out_dim_vec.at(i) = first_input_shape_view.At(i);\n      }\n    }\n    for (const auto& in_arg_pair : ctx->inputs()) {\n      const ShapeView& input_shape_view =\n          ctx->ShapeView4ArgNameAndIndex(in_arg_pair.first, in_arg_pair.second);\n      CHECK_EQ(input_shape_view.NumAxes(), first_input_shape_view.NumAxes());\n      FOR_RANGE(int64_t, i, 0, in_num_axes + 1) {\n        if (i == axis) {\n          out_dim_vec.at(axis) += 1;\n        } else if (i < axis) {\n          CHECK_EQ(input_shape_view.At(i), out_dim_vec.at(i))\n              << \" Stack expects each tensor to be equal size\"\n                 \", but got \"\n              << first_input_shape_view.ToString() << \" at first input and \"\n              << input_shape_view.ToString();\n        } else {\n          CHECK_EQ(input_shape_view.At(i - 1), out_dim_vec.at(i))\n              << \" Stack expects each tensor to be equal size\"\n                 \", but got \"\n              << first_input_shape_view.ToString() << \" at first input and \"\n              << input_shape_view.ToString();\n        }\n      }\n    }\n\n    ctx->MutShapeView4ArgNameAndIndex(\"out\", 0).set_shape(Shape(out_dim_vec));\n  }\n\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    if (out_tensor->shape_view().elem_cnt() == 0) { return; }\n    const int64_t axis = ctx->Attr<int64_t>(\"axis\");\n    const int64_t out_cols = out_tensor->shape_view().Count(axis);\n    const int64_t rows = out_tensor->shape_view().Count(0, axis);\n    CHECK_GT(rows, 0) << \"The multiplicative from axis 0 to axis \" << axis - 1\n                      << \" should be greater than 0. \";\n    auto primitive = NewCopyNdPrimitive(ctx);\n    CHECK(primitive) << \"Error in Stack kernel NewCopyNdPrimitive. 
\";\n    int64_t out_col_offset = 0;\n    for (const auto& in_arg_pair : ctx->inputs()) {\n      const user_op::Tensor* in_tensor =\n          ctx->Tensor4ArgNameAndIndex(in_arg_pair.first, in_arg_pair.second);\n      if (in_tensor->shape_view().elem_cnt() == 0) { continue; }\n      const int64_t in_cols = in_tensor->shape_view().Count(axis);\n      CHECK_EQ(in_tensor->shape_view().elem_cnt(), rows * in_cols)\n          << \"The element count of input tensor is not equal to `rows * in_cols`. \";\n      if (in_cols > 0) {\n        DimVector dst_shape = {rows, out_cols};\n        DimVector dst_pos_vec = {0, out_col_offset};\n        DimVector src_shape = {rows, in_cols};\n        DimVector src_pos_vec = {0, 0};\n        DimVector extent_vec = {rows, in_cols};\n        primitive->Launch(ctx->stream(), out_tensor->data_type(), 2, out_tensor->mut_dptr(),\n                          dst_shape.data(), dst_pos_vec.data(), in_tensor->dptr(), src_shape.data(),\n                          src_pos_vec.data(), extent_vec.data());\n      }\n      out_col_offset += in_cols;\n    }\n    CHECK_EQ(out_col_offset, out_cols) << \"The out column offset is not equal to out columns. \";\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nauto CopyNdPrimitiveExists() {\n  return hob::make_custom(\"CopyNdPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) -> bool {\n                            return NewCopyNdPrimitive(&ctx).operator bool();\n                          });\n}\n\n}  // namespace\n\nREGISTER_USER_KERNEL(\"stack\").SetCreateFn<StackKernel>().SetIsMatchedHob(CopyNdPrimitiveExists()\n                                                                         == true);\n\nclass StackGradKernel final : public user_op::OpKernel {\n public:\n  StackGradKernel() = default;\n  ~StackGradKernel() override = default;\n\n private:\n  void InferShape(user_op::KernelInferContext* ctx) const override {\n    const auto axis = ctx->Attr<int64_t>(\"axis\");\n    const ShapeView& in_shape_view = ctx->ShapeView4ArgNameAndIndex(\"in\", 0);\n    int64_t total_dim_size = 0;\n    const int64_t like_num_axes = ctx->ShapeView4ArgNameAndIndex(\"like\", 0).NumAxes();\n    const int64_t in_num_axes = in_shape_view.NumAxes();\n    CHECK_LE(like_num_axes, in_num_axes)\n        << \"The num axes of `like` tensor should be less equal to num axes of `in` tensor. \";\n    CHECK_LE(axis, like_num_axes)\n        << \"The axis should be less than or equal to num axes of `like` tensor. \";\n    FOR_RANGE(size_t, i, 0, ctx->outputs().size()) {\n      const ShapeView& like_shape_view = ctx->ShapeView4ArgNameAndIndex(\"like\", i);\n      CHECK_EQ(like_shape_view.NumAxes(), like_num_axes)\n          << \"The num axes of `like` tensor at index \" << i\n          << \" should be equal to first `like` tensor. \";\n      FOR_RANGE(int64_t, j, 0, like_num_axes + 1) {\n        if (j == axis) {\n          total_dim_size += like_shape_view.Count(j);\n        } else if (j < axis) {\n          CHECK_EQ(in_shape_view.At(j), like_shape_view.At(j))\n              << \" Stack Grad expects the shape of input tensor is equal to like tensor's. \"\n                 \", but got \"\n              << in_shape_view.ToString() << \" at input and \" << like_shape_view.ToString()\n              << \"at like \";\n        } else {\n          CHECK_EQ(in_shape_view.At(j), like_shape_view.At(j - 1))\n              << \" Stack Grad expects the shape of input tensor is equal to like tensor's. 
\"\n                 \", but got \"\n              << in_shape_view.ToString() << \" at input and \" << like_shape_view.ToString()\n              << \"at like \";\n        }\n      }\n\n      if (ctx->TensorDesc4ArgNameAndIndex(\"out\", i)->is_dynamic()) {\n        auto mut_shape_view = ctx->MutShapeView4ArgNameAndIndex(\"out\", i);\n        DimVector out_i_dim_vec;\n        like_shape_view.ToDimVector(&out_i_dim_vec);\n        mut_shape_view.set_shape(Shape(out_i_dim_vec));\n      }\n    }\n    CHECK_EQ(total_dim_size, in_shape_view.Count(axis))\n        << \"The sum of dim size of each `like` tensor should be equal to `in` tensor count from \"\n           \"axis \"\n        << axis;\n  }\n\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const int64_t axis = ctx->Attr<int64_t>(\"axis\");\n    const int64_t in_cols = in_tensor->shape_view().Count(axis);\n    const int64_t rows = in_tensor->shape_view().Count(0, axis);\n    CHECK_GT(rows, 0) << \"The multiplicative from axis 0 to axis \" << axis - 1\n                      << \" should be greater than 0. \";\n    auto primitive = NewCopyNdPrimitive(ctx);\n    CHECK(primitive) << \"Error in Stack Grad kernel NewCopyNdPrimitive. \";\n    int64_t in_col_offset = 0;\n    for (const auto& out_arg_pair : ctx->outputs()) {\n      user_op::Tensor* out_tensor =\n          ctx->Tensor4ArgNameAndIndex(out_arg_pair.first, out_arg_pair.second);\n      const int64_t out_cols = out_tensor->shape_view().Count(axis);\n      CHECK_EQ(out_tensor->shape_view().elem_cnt(), rows * out_cols)\n          << \"The element count of output tensor is not equal to `rows * out_cols`. \";\n      if (out_cols > 0) {\n        DimVector dst_shape = {rows, out_cols};\n        DimVector dst_pos_vec = {0, 0};\n        DimVector src_shape = {rows, in_cols};\n        DimVector src_pos_vec = {0, in_col_offset};\n        DimVector extent_vec = {rows, out_cols};\n        primitive->Launch(ctx->stream(), out_tensor->data_type(), 2, out_tensor->mut_dptr(),\n                          dst_shape.data(), dst_pos_vec.data(), in_tensor->dptr(), src_shape.data(),\n                          src_pos_vec.data(), extent_vec.data());\n      }\n      in_col_offset += out_cols;\n    }\n    CHECK_EQ(in_col_offset, in_cols) << \"The in column offset is not equal to in columns.\";\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"stack_grad\")\n    .SetCreateFn<StackGradKernel>()\n    .SetIsMatchedHob(CopyNdPrimitiveExists() == true);\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/stateful_opkernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/stateful_opkernel.h\"\n#include \"oneflow/core/framework/attr_value_accessor.h\"\n#include \"oneflow/core/framework/compute_complexity_fn_context.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/framework/attr_map.h\"\n#include \"oneflow/core/rpc/include/global_process_ctx.h\"\n#include \"oneflow/core/framework/global_tensor_infer_cache.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/profiler/profiler.h\"\n#include \"oneflow/core/profiler/profile_manager.h\"\n#include \"oneflow/core/profiler/event_recorder.h\"\n#include \"oneflow/core/eager/call_context.h\"\n\nnamespace oneflow {\nnamespace one {\n\nclass GlobalTensorInferResult;\n\nusing ArgVec = std::vector<std::pair<std::string, int32_t>>;\n\nusing EagerBlobObjectListRawPtr = const std::vector<std::shared_ptr<vm::EagerBlobObject>>*;\nusing GlobalTensorInferResultRawPtr = const GlobalTensorInferResult*;\n\nclass ZeroCopyBaseContextHelper {\n public:\n  ZeroCopyBaseContextHelper(const std::shared_ptr<const ArgTuple>& input_arg_tuple,\n                            const std::shared_ptr<const ArgTuple>& output_arg_tuple)\n      : input_arg_tuple_(input_arg_tuple), output_arg_tuple_(output_arg_tuple) {}\n\n#define RETURN_IF_FOUND(inputs, outputs, post_action)                                             \\\n  int32_t i = TryGetTensorTupleIndex(input_arg_tuple_->arg_name2bn_index2tensor_tuple_index(),    \\\n                                     arg_name, index);                                            \\\n  if (i >= 0) { return (inputs).at(i) post_action; }                                              \\\n  i = TryGetTensorTupleIndex(output_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), arg_name, \\\n                             index);                                                              \\\n  if (i >= 0) { return (outputs).at(i) post_action; }\n\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,\n                                                        const std::string& arg_name,\n                                                        const int32_t index) const {\n    RETURN_IF_FOUND(call_ctx->inputs(), call_ctx->outputs(), .get());\n    return nullptr;\n  }\n  user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,\n                                                     const std::string& arg_name,\n                                                     const int32_t index) const {\n    RETURN_IF_FOUND(call_ctx->inputs(), call_ctx->outputs(), .get());\n    return nullptr;\n  }\n\n  user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,\n                                          const int32_t index) const {\n   
 RETURN_IF_FOUND(call_ctx->inputs(), call_ctx->outputs(), .get());\n    if (arg_name == \"tmp_buffer\" && index == 0) { return call_ctx->mut_tmp_tensor(); }\n    return nullptr;\n  }\n\n  const GlobalTensorMeta* GlobalTensorMeta4ArgNameAndIndex(eager::CallContext* call_ctx,\n                                                           const std::string& arg_name,\n                                                           const int32_t index) const {\n    const auto& global_tensor_infer_result = call_ctx->global_tensor_infer_result();\n    RETURN_IF_FOUND(global_tensor_infer_result->input_tensor_metas(),\n                    global_tensor_infer_result->output_tensor_metas(), .shared_from_symbol().get());\n    return nullptr;\n  }\n\n  Optional<Symbol<ParallelDesc>> parallel_desc(eager::CallContext* call_ctx) const {\n    const auto& global_tensor_infer_result = call_ctx->global_tensor_infer_result();\n    if (!global_tensor_infer_result) { return Optional<Symbol<ParallelDesc>>(); }\n    if (!global_tensor_infer_result->input_tensor_metas().empty()) {\n      return global_tensor_infer_result->input_tensor_metas().at(0)->parallel_desc();\n    } else if (!global_tensor_infer_result->output_tensor_metas().empty()) {\n      return global_tensor_infer_result->output_tensor_metas().at(0)->parallel_desc();\n    } else {\n      UNIMPLEMENTED();\n      return Optional<Symbol<ParallelDesc>>();\n    }\n  }\n\n  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {\n    const auto& parallel_desc = this->parallel_desc(call_ctx);\n    if (parallel_desc.has_value()) {\n      const auto& parallel_desc_symbol = CHECK_JUST(parallel_desc);\n      return *CHECK_JUST(GetParallelContext4CurrentProcessCtx(parallel_desc_symbol));\n    } else {\n      static ParallelContext single_device_parallel_ctx(MakeSingleDeviceParallelCtx());\n      return single_device_parallel_ctx;\n    }\n  }\n\n  const ArgVec& inputs() const { return input_arg_tuple_->indexed_arg_name_and_index(); }\n  const ArgVec& outputs() const { return output_arg_tuple_->indexed_arg_name_and_index(); }\n\n private:\n  static int32_t TryGetTensorTupleIndex(const std::unordered_map<std::string, std::vector<int32_t>>&\n                                            arg_name2bn_index2tensor_tuple_index,\n                                        const std::string& arg_name, const int32_t arg_index) {\n    auto it = arg_name2bn_index2tensor_tuple_index.find(arg_name);\n    if (it != arg_name2bn_index2tensor_tuple_index.end()) { return it->second.at(arg_index); }\n    return -1;\n  }\n\n  static ParallelContext MakeSingleDeviceParallelCtx() {\n    ParallelContext single_device_parallel_ctx;\n    single_device_parallel_ctx.set_parallel_id(0);\n    single_device_parallel_ctx.set_parallel_num(1);\n    return single_device_parallel_ctx;\n  }\n\n  std::shared_ptr<const ArgTuple> input_arg_tuple_;\n  std::shared_ptr<const ArgTuple> output_arg_tuple_;\n};\n\nclass UserKernelBaseContextHelper final : public ZeroCopyBaseContextHelper {\n public:\n  UserKernelBaseContextHelper(DeviceType device_type,\n                              const std::shared_ptr<const ArgTuple>& input_arg_tuple,\n                              const std::shared_ptr<const ArgTuple>& output_arg_tuple)\n      : ZeroCopyBaseContextHelper(input_arg_tuple, output_arg_tuple), device_type_(device_type) {}\n\n  ~UserKernelBaseContextHelper() = default;\n\n  DeviceType device_type() const { return device_type_; }\n  const JobDesc& job_desc() const {\n    UNIMPLEMENTED();\n    return *(const 
JobDesc*)nullptr;\n  }\n\n private:\n  const DeviceType device_type_;\n};\n\nclass UserOpInferContextHelper final {\n public:\n  UserOpInferContextHelper(const user_op::UserOpConfWrapper* user_op_conf,\n                           const std::shared_ptr<const ArgTuple>& input_arg_tuple,\n                           const std::shared_ptr<const ArgTuple>& output_arg_tuple)\n      : user_op_conf_(user_op_conf),\n        zero_copy_base_ctx_helper_(input_arg_tuple, output_arg_tuple) {}\n\n  ~UserOpInferContextHelper() = default;\n\n  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,\n                                                               const std::string& arg_name,\n                                                               int32_t index) const {\n    UNIMPLEMENTED();\n    return nullptr;\n  }\n\n  const user_op::TensorDesc& InputTensorDesc(eager::CallContext* call_ctx,\n                                             const std::string& arg_name, int32_t index) const {\n    return *TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n  const user_op::TensorDesc& OutputTensorDesc(eager::CallContext* call_ctx,\n                                              const std::string& arg_name, int32_t index) const {\n    return *TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n  user_op::TensorDesc* MutOutputTensorDesc(eager::CallContext* call_ctx,\n                                           const std::string& arg_name, int32_t index) const {\n    return MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,\n                                                        const std::string& arg_name,\n                                                        int32_t index) const {\n    return zero_copy_base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n  user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,\n                                                     const std::string& arg_name,\n                                                     int32_t index) const {\n    return zero_copy_base_ctx_helper_.MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n\n  const Shape& InputShape(eager::CallContext* call_ctx, const std::string& arg_name,\n                          int32_t index) const {\n    return Shape4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n  const Shape& OutputShape(eager::CallContext* call_ctx, const std::string& arg_name,\n                           int32_t index) const {\n    return Shape4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n  void SetOutputShape(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index,\n                      const Shape& shape) const {\n    SetShape4ArgNameAndIndex(call_ctx, arg_name, index, shape);\n  }\n  const Shape& Shape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,\n                                     int32_t index) const {\n    return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index).shape();\n  }\n  void SetShape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,\n                                int32_t index, const Shape& shape) const {\n    return MutNonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->set_shape(shape);\n  }\n  const Stride& InputStride(eager::CallContext* call_ctx, const std::string& arg_name,\n                            
int32_t index) const {\n    return Stride4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n  const Stride& OutputStride(eager::CallContext* call_ctx, const std::string& arg_name,\n                             int32_t index) const {\n    return Stride4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n  void SetOutputStride(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index,\n                       const Stride& stride) const {\n    return SetStride4ArgNameAndIndex(call_ctx, arg_name, index, stride);\n  }\n  const Stride& Stride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,\n                                       int32_t index) const {\n    return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index).stride();\n  }\n  void SetStride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,\n                                 int32_t index, const Stride& stride) const {\n    return MutNonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->set_stride(stride);\n  }\n  DataType InputDType(eager::CallContext* call_ctx, const std::string& arg_name,\n                      int32_t index) const {\n    return Dtype4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n  DataType OutputDType(eager::CallContext* call_ctx, const std::string& arg_name,\n                       int32_t index) const {\n    return Dtype4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n  void SetOutputDType(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index,\n                      DataType data_type) const {\n    return SetDtype4ArgNameAndIndex(call_ctx, arg_name, index, data_type);\n  }\n  DataType Dtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,\n                                 int32_t index) const {\n    return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index).data_type();\n  }\n  void SetDtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,\n                                int32_t index, DataType data_type) const {\n    return MutNonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)\n        ->set_data_type(data_type);\n  }\n  bool InputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name,\n                      int32_t index) const {\n    return IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n  bool OutputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name,\n                       int32_t index) const {\n    return IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n  void SetOutputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name, int32_t index,\n                          bool is_dynamic) const {\n    return SetIsDynamic4ArgNameAndIndex(call_ctx, arg_name, index, is_dynamic);\n  }\n  bool IsDynamic4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,\n                                 int32_t index) const {\n    return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index).is_dynamic();\n  }\n  void SetIsDynamic4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,\n                                    int32_t index, bool is_dynamic) const {\n    return MutNonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)\n        ->set_is_dynamic(is_dynamic);\n  }\n\n  const ArgVec& inputs() const { return zero_copy_base_ctx_helper_.inputs(); }\n  const ArgVec& outputs() const { return zero_copy_base_ctx_helper_.outputs(); }\n  const 
JobDesc* job_desc() const {\n    UNIMPLEMENTED();\n    return nullptr;\n  }\n  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {\n    return zero_copy_base_ctx_helper_.parallel_ctx(call_ctx);\n  }\n  const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const {\n    return *CHECK_JUST(zero_copy_base_ctx_helper_.parallel_desc(call_ctx));\n  }\n  const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx,\n                                                 const std::string& arg_name, int32_t index) const {\n    const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index);\n    CHECK_EQ(nd_sbp.sbp_parallel_size(), 1);\n    return nd_sbp.sbp_parallel(0);\n  }\n  const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,\n                                     int32_t index) const {\n    return *CHECK_NOTNULL(zero_copy_base_ctx_helper_.GlobalTensorMeta4ArgNameAndIndex(\n                              call_ctx, arg_name, index))\n                ->nd_sbp();\n  }\n\n  int64_t parallel_num(eager::CallContext* call_ctx) const {\n    return parallel_ctx(call_ctx).parallel_num();\n  }\n\n  const std::string& input(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().input(arg_name, index);\n  }\n  const std::string& output(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().output(arg_name, index);\n  }\n  bool has_input(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().has_input(arg_name, index);\n  }\n  bool has_output(const std::string& arg_name, int32_t index) const {\n    return user_op_conf().has_output(arg_name, index);\n  }\n  int32_t input_size(const std::string& arg_name) const {\n    return user_op_conf().input_size(arg_name);\n  }\n  int32_t output_size(const std::string& arg_name) const {\n    return user_op_conf().output_size(arg_name);\n  }\n  const std::string& op_name() const { return user_op_conf().op_name(); }\n  const std::string& op_type_name() const { return user_op_conf().op_type_name(); }\n  const std::string& op_loc() const { return user_op_conf_->op_conf().loc(); }\n\n  const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,\n                                                           const std::string& attr_name) const {\n    return call_ctx->composed_attrs().Attr4Name(attr_name);\n  }\n\n private:\n  const user_op::TensorDesc& NonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,\n                                                               const std::string& arg_name,\n                                                               int32_t index) const {\n    const user_op::TensorDesc* tensor_desc = TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);\n    if (!tensor_desc) { LOG(FATAL) << \"Arg (\" << arg_name << \",\" << index << \") is not found\"; }\n    return *tensor_desc;\n  }\n  user_op::TensorDesc* MutNonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,\n                                                            const std::string& arg_name,\n                                                            int32_t index) const {\n    user_op::TensorDesc* tensor_desc = MutTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);\n    if (!tensor_desc) { LOG(FATAL) << \"Arg (\" << arg_name << \",\" << index << \") is not found\"; }\n    return tensor_desc;\n  }\n\n  
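// Not owned: this helper and the UserOpConfWrapper it points at are both owned by the\n  // enclosing StatefulOpKernel, so the raw pointer remains valid for the helper's lifetime.\n  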
const user_op::UserOpConfWrapper* user_op_conf_;\n  ZeroCopyBaseContextHelper zero_copy_base_ctx_helper_;\n};\n\nclass UserOpInferContext : public user_op::InferContext {\n public:\n  UserOpInferContext(const UserOpInferContextHelper* helper, eager::CallContext* call_ctx)\n      : helper_(helper), call_ctx_(call_ctx) {}\n\n  ~UserOpInferContext() override = default;\n\n  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                               int32_t index) const override {\n    return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);\n  }\n\n  const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name,\n                                             int32_t index) const override {\n    return helper_->InputTensorDesc(call_ctx_, arg_name, index);\n  }\n  const user_op::TensorDesc& OutputTensorDesc(const std::string& arg_name,\n                                              int32_t index) const override {\n    return helper_->OutputTensorDesc(call_ctx_, arg_name, index);\n  }\n  user_op::TensorDesc* MutOutputTensorDesc(const std::string& arg_name, int32_t index) override {\n    return helper_->MutOutputTensorDesc(call_ctx_, arg_name, index);\n  }\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const {\n    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);\n  }\n  user_op::TensorDesc* MutTensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) {\n    return helper_->MutTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);\n  }\n\n  const Shape& InputShape(const std::string& arg_name, int32_t index) const override {\n    return helper_->InputShape(call_ctx_, arg_name, index);\n  }\n  const Shape& OutputShape(const std::string& arg_name, int32_t index) const override {\n    return helper_->OutputShape(call_ctx_, arg_name, index);\n  }\n  void SetOutputShape(const std::string& arg_name, int32_t index, const Shape& shape) override {\n    return helper_->SetOutputShape(call_ctx_, arg_name, index, shape);\n  }\n  const Shape& Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    return helper_->Shape4ArgNameAndIndex(call_ctx_, arg_name, index);\n  }\n  void SetShape4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                const Shape& shape) override {\n    return helper_->SetShape4ArgNameAndIndex(call_ctx_, arg_name, index, shape);\n  }\n  const Stride& InputStride(const std::string& arg_name, int32_t index) const override {\n    return helper_->InputStride(call_ctx_, arg_name, index);\n  }\n  const Stride& OutputStride(const std::string& arg_name, int32_t index) const override {\n    return helper_->OutputStride(call_ctx_, arg_name, index);\n  }\n  void SetOutputStride(const std::string& arg_name, int32_t index, const Stride& stride) override {\n    return helper_->SetOutputStride(call_ctx_, arg_name, index, stride);\n  }\n  const Stride& Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    return helper_->Stride4ArgNameAndIndex(call_ctx_, arg_name, index);\n  }\n  void SetStride4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                 const Stride& stride) override {\n    return helper_->SetStride4ArgNameAndIndex(call_ctx_, arg_name, index, stride);\n  }\n  DataType InputDType(const std::string& 
arg_name, int32_t index) const override {\n    return helper_->InputDType(call_ctx_, arg_name, index);\n  }\n  DataType OutputDType(const std::string& arg_name, int32_t index) const override {\n    return helper_->OutputDType(call_ctx_, arg_name, index);\n  }\n  void SetOutputDType(const std::string& arg_name, int32_t index, DataType data_type) override {\n    return helper_->SetOutputDType(call_ctx_, arg_name, index, data_type);\n  }\n  DataType Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    return helper_->Dtype4ArgNameAndIndex(call_ctx_, arg_name, index);\n  }\n  void SetDtype4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                DataType data_type) override {\n    return helper_->SetDtype4ArgNameAndIndex(call_ctx_, arg_name, index, data_type);\n  }\n  MemoryFormat InputMemoryFormat(const std::string& arg_name, int32_t index) const override {\n    return MemoryFormat4ArgNameAndIndex(arg_name, index);\n  }\n  MemoryFormat OutputMemoryFormat(const std::string& arg_name, int32_t index) const override {\n    return MemoryFormat4ArgNameAndIndex(arg_name, index);\n  }\n  void SetOutputMemoryFormat(const std::string& arg_name, int32_t index,\n                             MemoryFormat memory_format) override {\n    return SetMemoryFormat4ArgNameAndIndex(arg_name, index, memory_format);\n  }\n  MemoryFormat MemoryFormat4ArgNameAndIndex(const std::string& arg_name,\n                                            int32_t index) const override {\n    return TensorDesc4ArgNameAndIndex(arg_name, index)->memory_format();\n  }\n  void SetMemoryFormat4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                       MemoryFormat memory_format) override {\n    MutTensorDesc4ArgNameAndIndex(arg_name, index)->set_memory_format(memory_format);\n  }\n  bool InputIsDynamic(const std::string& arg_name, int32_t index) const override {\n    return helper_->InputIsDynamic(call_ctx_, arg_name, index);\n  }\n  bool OutputIsDynamic(const std::string& arg_name, int32_t index) const override {\n    return helper_->OutputIsDynamic(call_ctx_, arg_name, index);\n  }\n  void SetOutputIsDynamic(const std::string& arg_name, int32_t index, bool is_dynamic) override {\n    return helper_->SetOutputIsDynamic(call_ctx_, arg_name, index, is_dynamic);\n  }\n  bool IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    return helper_->IsDynamic4ArgNameAndIndex(call_ctx_, arg_name, index);\n  }\n  void SetIsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index,\n                                    bool is_dynamic) override {\n    return helper_->SetIsDynamic4ArgNameAndIndex(call_ctx_, arg_name, index, is_dynamic);\n  }\n\n  const ArgVec& inputs() const override { return helper_->inputs(); }\n  const ArgVec& outputs() const override { return helper_->outputs(); }\n  const JobDesc* job_desc() const override { return helper_->job_desc(); }\n  const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }\n  const ParallelDesc& parallel_desc() const override { return helper_->parallel_desc(call_ctx_); }\n  const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,\n                                                 int32_t index) const override {\n    return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index);\n  }\n  const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n   
 return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index);\n  }\n\n  int64_t parallel_num() const override { return helper_->parallel_num(call_ctx_); }\n\n  const std::string& input(const std::string& arg_name, int32_t index) const override {\n    return helper_->input(arg_name, index);\n  }\n  const std::string& output(const std::string& arg_name, int32_t index) const override {\n    return helper_->output(arg_name, index);\n  }\n  bool has_input(const std::string& arg_name, int32_t index) const override {\n    return helper_->has_input(arg_name, index);\n  }\n  bool has_output(const std::string& arg_name, int32_t index) const override {\n    return helper_->has_output(arg_name, index);\n  }\n  int32_t input_size(const std::string& arg_name) const override {\n    return helper_->input_size(arg_name);\n  }\n  int32_t output_size(const std::string& arg_name) const override {\n    return helper_->output_size(arg_name);\n  }\n  const std::string& op_name() const override { return helper_->op_name(); }\n  const std::string& op_type_name() const override { return helper_->op_type_name(); }\n  const std::string& op_loc() const override { return helper_->op_loc(); }\n\n private:\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return helper_->Attr4Name(call_ctx_, attr_name);\n  }\n\n  const UserOpInferContextHelper* helper_;\n  eager::CallContext* call_ctx_;\n};\n\nclass UserKernelComputeContextHelper final {\n public:\n  UserKernelComputeContextHelper(DeviceType device_type,\n                                 const user_op::UserOpConfWrapper* user_op_conf,\n                                 const std::shared_ptr<const ArgTuple>& input_arg_tuple,\n                                 const std::shared_ptr<const ArgTuple>& output_arg_tuple)\n      : user_op_conf_(user_op_conf),\n        base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}\n\n  ~UserKernelComputeContextHelper() = default;\n\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,\n                                                        const std::string& arg_name,\n                                                        int32_t index) const {\n    return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n\n  user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,\n                                          int32_t index) const {\n    return base_ctx_helper_.Tensor4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n\n  DeviceType device_type() const { return base_ctx_helper_.device_type(); }\n  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {\n    return base_ctx_helper_.parallel_ctx(call_ctx);\n  }\n\n  const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }\n  const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }\n\n  const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,\n                                                           const std::string& attr_name) const {\n    return call_ctx->composed_attrs().Attr4Name(attr_name);\n  }\n\n private:\n  const user_op::UserOpConfWrapper* user_op_conf_;\n  UserKernelBaseContextHelper base_ctx_helper_;\n};\n\nclass UserKernelComputeContext final : public user_op::KernelComputeContext {\n public:\n  UserKernelComputeContext(const 
UserKernelComputeContextHelper* helper,\n                           eager::CallContext* call_ctx, ep::Stream* stream)\n      : helper_(helper), call_ctx_(call_ctx), stream_(stream) {}\n\n  ~UserKernelComputeContext() = default;\n\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const override {\n    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);\n  }\n\n  user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {\n    return helper_->Tensor4ArgNameAndIndex(call_ctx_, arg_name, index);\n  }\n\n  ep::Stream* stream() override {\n    CHECK_NOTNULL(stream_);\n    return stream_;\n  }\n\n  DeviceType device_type() const override { return helper_->device_type(); }\n\n  const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }\n\n  const ArgVec& inputs() const override { return helper_->inputs(); }\n  const ArgVec& outputs() const override { return helper_->outputs(); }\n\n private:\n  const user_op::UserOpConfWrapper& user_op_conf() const override {\n    return helper_->user_op_conf();\n  }\n\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return helper_->Attr4Name(call_ctx_, attr_name);\n  }\n\n  const UserKernelComputeContextHelper* helper_;\n  eager::CallContext* call_ctx_;\n  ep::Stream* stream_;\n};\n\nclass UserKernelRegContextHelper final {\n public:\n  UserKernelRegContextHelper(DeviceType device_type, const user_op::UserOpConfWrapper* user_op_conf,\n                             const std::shared_ptr<const ArgTuple>& input_arg_tuple,\n                             const std::shared_ptr<const ArgTuple>& output_arg_tuple)\n      : user_op_conf_(user_op_conf),\n        base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}\n  ~UserKernelRegContextHelper() = default;\n\n  DeviceType device_type() const { return base_ctx_helper_.device_type(); }\n  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {\n    return base_ctx_helper_.parallel_ctx(call_ctx);\n  }\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,\n                                                        const std::string& arg_name,\n                                                        int32_t index) const {\n    return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n  const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }\n  const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }\n\n  const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }\n\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,\n                                                           const std::string& attr_name) const {\n    return call_ctx->composed_attrs().Attr4Name(attr_name);\n  }\n\n private:\n  const user_op::UserOpConfWrapper* user_op_conf_;\n  UserKernelBaseContextHelper base_ctx_helper_;\n};\n\nclass UserKernelRegContext final : public user_op::KernelRegContext {\n public:\n  UserKernelRegContext(const UserKernelRegContextHelper* helper, eager::CallContext* call_ctx)\n      : helper_(helper), call_ctx_(call_ctx) {}\n  ~UserKernelRegContext() = default;\n\n  DeviceType device_type() const override { return helper_->device_type(); }\n  const ParallelContext& parallel_ctx() const override { 
return helper_->parallel_ctx(call_ctx_); }\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const override {\n    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);\n  }\n  const ArgVec& inputs() const override { return helper_->inputs(); }\n  const ArgVec& outputs() const override { return helper_->outputs(); }\n\n  const user_op::UserOpConfWrapper& user_op_conf() const override {\n    return helper_->user_op_conf();\n  }\n\n private:\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return helper_->Attr4Name(call_ctx_, attr_name);\n  }\n\n  const UserKernelRegContextHelper* helper_;\n  eager::CallContext* call_ctx_;\n};\n\nclass UserKernelInitAndCacheContextHelper final {\n public:\n  UserKernelInitAndCacheContextHelper(DeviceType device_type,\n                                      const user_op::UserOpConfWrapper* user_op_conf,\n                                      const std::shared_ptr<const ArgTuple>& input_arg_tuple,\n                                      const std::shared_ptr<const ArgTuple>& output_arg_tuple)\n      : user_op_conf_(user_op_conf),\n        base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}\n\n  ~UserKernelInitAndCacheContextHelper() = default;\n\n  DeviceType device_type() const { return base_ctx_helper_.device_type(); }\n  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {\n    return base_ctx_helper_.parallel_ctx(call_ctx);\n  }\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,\n                                                        const std::string& arg_name,\n                                                        int32_t index) const {\n    return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,\n                                                               const std::string& arg_name,\n                                                               int32_t index) const {\n    return base_ctx_helper_.GlobalTensorMeta4ArgNameAndIndex(call_ctx, arg_name, index);\n  }\n  const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx,\n                                                 const std::string& arg_name, int32_t index) const {\n    const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index);\n    CHECK_EQ(nd_sbp.sbp_parallel_size(), 1);\n    return nd_sbp.sbp_parallel(0);\n  }\n\n  const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,\n                                     int32_t index) const {\n    return *CHECK_NOTNULL(\n                base_ctx_helper_.GlobalTensorMeta4ArgNameAndIndex(call_ctx, arg_name, index))\n                ->nd_sbp();\n  }\n\n  const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }\n  const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }\n  const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const {\n    return *CHECK_JUST(base_ctx_helper_.parallel_desc(call_ctx));\n  }\n\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,\n                                                           const std::string& attr_name) const {\n    return call_ctx->composed_attrs().Attr4Name(attr_name);\n  }\n\n  const 
user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }\n\n private:\n  const user_op::UserOpConfWrapper* user_op_conf_;\n  UserKernelBaseContextHelper base_ctx_helper_;\n};\n\nclass UserKernelInitAndCacheContext final : public user_op::KernelInitContext,\n                                            public user_op::KernelCacheContext {\n public:\n  UserKernelInitAndCacheContext(const UserKernelInitAndCacheContextHelper* helper,\n                                eager::CallContext* call_ctx, ep::Stream* stream)\n      : helper_(helper), call_ctx_(call_ctx), stream_(stream) {}\n\n  ~UserKernelInitAndCacheContext() override = default;\n\n  ep::Stream* stream() override {\n    CHECK_NOTNULL(stream_);\n    return stream_;\n  }\n\n  DeviceType device_type() const override { return helper_->device_type(); }\n  const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }\n  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                        int32_t index) const override {\n    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);\n  }\n  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,\n                                                               int32_t index) const override {\n    return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);\n  }\n  const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,\n                                                 int32_t index) const override {\n    return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index);\n  }\n\n  const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {\n    return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index);\n  }\n\n  const ArgVec& inputs() const override { return helper_->inputs(); }\n  const ArgVec& outputs() const override { return helper_->outputs(); }\n  const ParallelDesc& parallel_desc() const override { return helper_->parallel_desc(call_ctx_); }\n\n private:\n  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(\n      const std::string& attr_name) const override {\n    return helper_->Attr4Name(call_ctx_, attr_name);\n  }\n\n  const user_op::UserOpConfWrapper& user_op_conf() const override {\n    return helper_->user_op_conf();\n  }\n\n  const UserKernelInitAndCacheContextHelper* helper_;\n  eager::CallContext* call_ctx_;\n  ep::Stream* stream_;\n};\n\nnamespace {\n\nMaybe<void> InitTensorTupleIndexes4Bns(const std::shared_ptr<const OperatorConf>& op_conf,\n                                       const ArgVec& indexed_input_pairs,\n                                       const ArgVec& indexed_output_pairs,\n                                       OpArgsVector<int64_t>* input_tuple_indexes4const_ibns,\n                                       OpArgsVector<int64_t>* input_tuple_indexes4mut_ibns,\n                                       OpArgsVector<int64_t>* output_tuple_indexes4mut_obns,\n                                       OpArgsVector<int64_t>* output_tuple_indexes4mut2_obns,\n                                       small_vector<bool>* output_tuple_indexes2is_mut2_type) {\n  const auto* op_reg_val =\n      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_conf->user_conf().op_type_name());\n  CHECK_NOTNULL_OR_RETURN(op_reg_val);\n\n  ArgModifierSignature arg_modifier_signature;\n  for (const auto& pair : indexed_input_pairs) 
{\n    const std::string ibn = GenRepeatedBn(pair.first, pair.second);\n    arg_modifier_signature.mutable_ibn2input_blob_modifier()->insert(\n        {ibn, user_op::InputArgModifier()});\n  }\n  for (const auto& pair : indexed_output_pairs) {\n    const std::string obn = GenRepeatedBn(pair.first, pair.second);\n    arg_modifier_signature.mutable_obn2output_blob_modifier()->insert(\n        {obn, user_op::OutputArgModifier()});\n  }\n  user_op::UserOpConfWrapper op_conf_wrapper(op_conf);\n  if (op_reg_val->input_arg_modify_fn) {\n    user_op::GetInputArgModifier GetInputArgModifierFn =\n        [&arg_modifier_signature](const std::string& in_arg_name,\n                                  int32_t in_arg_index) -> user_op::InputArgModifier* {\n      const std::string ibn = GenRepeatedBn(in_arg_name, in_arg_index);\n      auto* map = arg_modifier_signature.mutable_ibn2input_blob_modifier();\n      return &map->at(ibn);\n    };\n    JUST(op_reg_val->input_arg_modify_fn(GetInputArgModifierFn, op_conf_wrapper));\n  }\n  if (op_reg_val->output_arg_modify_fn) {\n    user_op::GetOutputArgModifier GetOutputArgModifierFn =\n        [&arg_modifier_signature](const std::string& in_arg_name,\n                                  int32_t in_arg_index) -> user_op::OutputArgModifier* {\n      const std::string obn = GenRepeatedBn(in_arg_name, in_arg_index);\n      auto* map = arg_modifier_signature.mutable_obn2output_blob_modifier();\n      return &map->at(obn);\n    };\n    JUST(op_reg_val->output_arg_modify_fn(GetOutputArgModifierFn, op_conf_wrapper));\n  }\n\n  for (int i = 0; i < indexed_input_pairs.size(); i++) {\n    const auto& pair = indexed_input_pairs.at(i);\n    const std::string ibn = GenRepeatedBn(pair.first, pair.second);\n    if (arg_modifier_signature.ibn2input_blob_modifier().at(ibn).is_mutable()) {\n      input_tuple_indexes4mut_ibns->emplace_back(i);\n    } else {\n      input_tuple_indexes4const_ibns->emplace_back(i);\n    }\n  }\n\n  for (int i = 0; i < indexed_output_pairs.size(); i++) {\n    const auto& pair = indexed_output_pairs.at(i);\n    const std::string obn = GenRepeatedBn(pair.first, pair.second);\n    if (arg_modifier_signature.obn2output_blob_modifier().at(obn).header_infered_before_compute()) {\n      output_tuple_indexes4mut_obns->emplace_back(i);\n      output_tuple_indexes2is_mut2_type->emplace_back(false);\n    } else {\n      output_tuple_indexes4mut2_obns->emplace_back(i);\n      output_tuple_indexes2is_mut2_type->emplace_back(true);\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<StatefulOpKernel> StatefulOpKernel::New(\n    const std::shared_ptr<OperatorConf>& op_conf, const Symbol<Stream>& stream,\n    const AttrMap& base_attrs, const std::shared_ptr<const ParallelDesc>& parallel_desc,\n    const std::shared_ptr<const ArgTuple>& input_arg_tuple,\n    const std::shared_ptr<const ArgTuple>& output_arg_tuple) {\n  auto opkernel = std::shared_ptr<StatefulOpKernel>(new StatefulOpKernel());\n  opkernel->base_attrs_ = base_attrs;\n  opkernel->op_conf_ = op_conf;\n  opkernel->user_op_conf_.reset(new user_op::UserOpConfWrapper(op_conf));\n  opkernel->stream_ = stream;\n  opkernel->input_arg_tuple_ = input_arg_tuple;\n  opkernel->output_arg_tuple_ = output_arg_tuple;\n\n  const DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(op_conf->device_tag()));\n  const user_op::UserOpConfWrapper* user_op_conf = opkernel->user_op_conf_.get();\n  opkernel->op_infer_ctx_helper_.reset(\n      new UserOpInferContextHelper(user_op_conf, 
input_arg_tuple, output_arg_tuple));\n\n  opkernel->init_and_cache_ctx_helper_.reset(new UserKernelInitAndCacheContextHelper(\n      device_type, opkernel->user_op_conf_.get(), opkernel->input_arg_tuple_,\n      opkernel->output_arg_tuple_));\n  opkernel->compute_ctx_helper_.reset(new UserKernelComputeContextHelper(\n      device_type, user_op_conf, input_arg_tuple, output_arg_tuple));\n  opkernel->reg_ctx_helper_.reset(\n      new UserKernelRegContextHelper(device_type, user_op_conf, input_arg_tuple, output_arg_tuple));\n  const auto* op_reg_val =\n      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(user_op_conf->op_type_name());\n  CHECK_NOTNULL_OR_RETURN(op_reg_val);\n  if (op_reg_val->logical_tensor_desc_infer_fn) {\n    opkernel->tensor_desc_infer_fn_ = op_reg_val->logical_tensor_desc_infer_fn;\n  } else {\n    return Error::UnimplementedError();\n  }\n  opkernel->data_type_infer_fn_ = op_reg_val->data_type_infer_fn;\n\n  JUST(InitTensorTupleIndexes4Bns(\n      op_conf, input_arg_tuple->indexed_arg_name_and_index(),\n      output_arg_tuple->indexed_arg_name_and_index(), &opkernel->input_tuple_indexes4const_ibns_,\n      &opkernel->input_tuple_indexes4mut_ibns_, &opkernel->output_tuple_indexes4mut_obns_,\n      &opkernel->output_tuple_indexes4mut2_obns_, &opkernel->output_tuple_indexes2is_mut2_type_));\n\n  return opkernel;\n}\n\nStatefulOpKernel::~StatefulOpKernel() = default;\n\nsize_t StatefulOpKernel::InferTmpSize(eager::CallContext* call_ctx,\n                                      const user_op::OpKernel* user_opkernel) const {\n  UserOpInferContext op_infer_ctx(op_infer_ctx_helper_.get(), call_ctx);\n  const auto& InferTmpSizeFn = GetInferTmpSizeFn(user_opkernel);\n  return InferTmpSizeFn(&op_infer_ctx);\n}\n\nMaybe<void> StatefulOpKernel::ChooseOpKernel(eager::CallContext* call_ctx,\n                                             const user_op::OpKernel** user_opkernel,\n                                             bool* need_temp_storage) {\n  DataType primary_dtype = kInvalidDataType;\n  const auto& inputs = call_ctx->inputs();\n  const auto& outputs = call_ctx->outputs();\n  if (likely(!inputs.empty())) {\n    primary_dtype = inputs[0]->data_type();\n  } else if (likely(!outputs.empty())) {\n    primary_dtype = outputs[0]->data_type();\n  } else {\n    // do nothing\n  }\n\n  UserKernelRegContext reg_ctx(reg_ctx_helper_.get(), call_ctx);\n  for (const auto& pair : dtype2cached_kernels_[primary_dtype]) {\n    if (likely(pair.first->is_matched_hob->get(reg_ctx))) {\n      *need_temp_storage = pair.first->need_temp_storage;\n      *user_opkernel = pair.second.get();\n      return Maybe<void>::Ok();\n    }\n  }\n\n  OF_PROFILER_RANGE_GUARD(\"fallback\");\n\n  const auto& op_type_name = user_op_conf_->op_type_name();\n  const auto* kernel_reg_val =\n      JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(op_type_name, reg_ctx));\n  CHECK_NOTNULL(kernel_reg_val);\n  auto* kernel = kernel_reg_val->create_fn();\n  dtype2cached_kernels_[primary_dtype].push_back(\n      {kernel_reg_val, std::shared_ptr<const user_op::OpKernel>(kernel)});\n\n  infer_tmp_size_fn_map_.emplace(kernel, &kernel_reg_val->infer_tmp_size_fn);\n  *need_temp_storage = kernel_reg_val->need_temp_storage;\n  *user_opkernel = kernel;\n  return Maybe<void>::Ok();\n}\n\nvoid StatefulOpKernel::TryInitOpKernelStateAndCache(eager::CallContext* call_ctx,\n                                                    ep::Stream* stream,\n                                                    const user_op::OpKernel* 
op_kernel,\n                                                    user_op::OpKernelState** state,\n                                                    user_op::OpKernelCache** cache) {\n  UserKernelInitAndCacheContext init_and_cache_ctx(init_and_cache_ctx_helper_.get(), call_ctx,\n                                                   stream);\n  if (state != nullptr) {\n    auto it = op_kernel_state_map_.find(op_kernel);\n    if (it != op_kernel_state_map_.end()) {\n      *state = it->second.get();\n    } else {\n      auto created_state = op_kernel->CreateOpKernelState(&init_and_cache_ctx);\n      op_kernel_state_map_.emplace(op_kernel, created_state);\n      *state = created_state.get();\n    }\n  }\n\n  {\n    auto& cache_in_map = op_kernel_cache_map_[op_kernel];\n    op_kernel->InitOpKernelCacheWithFlags(&init_and_cache_ctx,\n                                          user_op::OpKernelCache::kAllMayChanged, &cache_in_map);\n    *cache = cache_in_map.get();\n  }\n}\n\nconst user_op::InferTmpSizeFn& StatefulOpKernel::GetInferTmpSizeFn(\n    const user_op::OpKernel* op_kernel) const {\n  return *infer_tmp_size_fn_map_.at(op_kernel);\n}\n\nuser_op::TensorDescInferFn StatefulOpKernel::TensorDescInferFn() const {\n  return tensor_desc_infer_fn_;\n}\n\nuser_op::DataTypeInferFn StatefulOpKernel::DataTypeInferFn() const { return data_type_infer_fn_; }\n\nvoid StatefulOpKernel::Compute(eager::CallContext* call_ctx, ep::Stream* stream,\n                               const user_op::OpKernel* user_opkernel,\n                               user_op::OpKernelState* state,\n                               const user_op::OpKernelCache* cache) const {\n  UserKernelComputeContext compute_context(compute_ctx_helper_.get(), call_ctx, stream);\n  auto* compute_ctx = &compute_context;\n  OF_PROFILER_RANGE_GUARD(\"Compute\");\n  auto er_guard = CHECK_JUST(profiler::EventRecorder::CreateKernelEventRecorder(\n      op_type_name(),\n#if defined(WITH_CUDA)\n      [compute_ctx]() -> int64_t {\n        const auto CalMemorySize = [compute_ctx](const one::ArgVec& args) -> int64_t {\n          const auto Func = [compute_ctx](int64_t mem_size, const auto& pair) {\n            const auto tensor = compute_ctx->Tensor4ArgNameAndIndex(pair.first, pair.second);\n            return mem_size\n                   + tensor->shape_view().elem_cnt() * GetSizeOfDataType(tensor->data_type());\n          };\n          return std::accumulate(args.begin(), args.end(), static_cast<int64_t>(0), Func);\n        };\n        return CalMemorySize(compute_ctx->inputs()) + CalMemorySize(compute_ctx->outputs());\n      },\n#endif\n      [call_ctx]() -> std::pair<std::string, int64_t> {\n        std::stringstream ss;\n        std::size_t hash = 0;\n        for (size_t i = 0; i < call_ctx->inputs().size(); i++) {\n          const auto& shape = call_ctx->inputs().at(i)->shape();\n          ss << shape;\n          if (i != call_ctx->inputs().size() - 1) { ss << \", \"; }\n          AddHash(&hash, shape);\n        }\n        return {ss.str(), hash};\n      },\n      [call_ctx]() -> std::pair<std::string, int64_t> {\n        const std::string attr_str = call_ctx->composed_attrs().ToString();\n        return {attr_str, std::hash<std::string>{}(attr_str)};\n      }));\n  user_opkernel->Compute(compute_ctx, state, cache);\n  CHECK_JUST(compute_ctx->stream()->GetAsyncError());\n}\n\n}  // namespace one\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/stateful_opkernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_STATEFUL_OPKERNEL_H_\n#define ONEFLOW_USER_KERNELS_STATEFUL_OPKERNEL_H_\n\n#include \"oneflow/core/eager/eager_blob_object.h\"\n#include \"oneflow/core/common/tensor_meta.h\"\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/framework/op_kernel.h\"\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/framework/user_op_kernel_registry.h\"\n#include \"oneflow/core/framework/arg_tuple.h\"\n#include \"oneflow/core/framework/op_interpreter.h\"\n#include \"oneflow/core/common/op_args_vector.h\"\n\nnamespace oneflow {\n\nclass AttrMap;\n\nnamespace vm {\nstruct OpCallInstructionUtil;\n}\n\nnamespace eager {\nclass CallContext;\n}\n\nnamespace one {\n\nusing ArgVec = std::vector<std::pair<std::string, int32_t>>;\n\nclass UserKernelRegContextHelper;\nclass UserOpInferContextHelper;\nclass UserKernelInitAndCacheContextHelper;\nclass UserKernelComputeContextHelper;\n\nclass StatefulOpKernel final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(StatefulOpKernel);\n  static Maybe<StatefulOpKernel> New(const std::shared_ptr<OperatorConf>& op_conf,\n                                     const Symbol<Stream>& stream, const AttrMap& base_attrs,\n                                     const std::shared_ptr<const ParallelDesc>& parallel_desc,\n                                     const std::shared_ptr<const ArgTuple>& input_arg_tuple,\n                                     const std::shared_ptr<const ArgTuple>& output_arg_tuple);\n  ~StatefulOpKernel();\n  const Symbol<Stream>& stream() const { return stream_; }\n  const std::shared_ptr<MemoryCase>& mem_case() const { return stream_->device()->mem_case(); }\n  const std::string& op_type_name() const { return op_conf_->user_conf().op_type_name(); }\n  const OpArgsVector<int64_t>& input_tuple_indexes4const_ibns() const {\n    return input_tuple_indexes4const_ibns_;\n  }\n  const OpArgsVector<int64_t>& input_tuple_indexes4mut_ibns() const {\n    return input_tuple_indexes4mut_ibns_;\n  }\n  const OpArgsVector<int64_t>& output_tuple_indexes4mut_obns() const {\n    return output_tuple_indexes4mut_obns_;\n  }\n  const OpArgsVector<int64_t>& output_tuple_indexes4mut2_obns() const {\n    return output_tuple_indexes4mut2_obns_;\n  }\n\n  bool output_is_mut2_type(int64_t index) const {\n    return output_tuple_indexes2is_mut2_type_.at(index);\n  }\n\n  const AttrMap& base_attrs() const { return base_attrs_; }\n\n  size_t InferTmpSize(eager::CallContext* call_ctx, const user_op::OpKernel* user_opkernel) const;\n\n  Maybe<void> ChooseOpKernel(eager::CallContext* call_ctx, const user_op::OpKernel** user_opkernel,\n                             bool* need_temp_storage);\n\n  const OperatorConf& op_conf() const { return *op_conf_; }\n\n  const ArgTuple* input_arg_tuple() const { return input_arg_tuple_.get(); }\n  const ArgTuple* output_arg_tuple() const { return output_arg_tuple_.get(); }\n\n private:\n  
friend struct vm::OpCallInstructionUtil;\n  StatefulOpKernel() = default;\n\n  void Compute(eager::CallContext* call_ctx, ep::Stream* stream,\n               const user_op::OpKernel* user_opkernel, user_op::OpKernelState* state,\n               const user_op::OpKernelCache* cache) const;\n\n  user_op::TensorDescInferFn TensorDescInferFn() const;\n  user_op::DataTypeInferFn DataTypeInferFn() const;\n\n  void TryInitOpKernelStateAndCache(eager::CallContext* call_ctx, ep::Stream* stream,\n                                    const user_op::OpKernel* op_kernel,\n                                    user_op::OpKernelState** state, user_op::OpKernelCache** cache);\n\n  user_op::OpKernelState* mut_opkernel_state(const user_op::OpKernel* opkernel) {\n    return op_kernel_state_map_.at(opkernel).get();\n  }\n\n  const user_op::InferTmpSizeFn& GetInferTmpSizeFn(const user_op::OpKernel* op_kernel) const;\n\n  std::shared_ptr<OperatorConf> op_conf_;\n  AttrMap base_attrs_;\n  std::unique_ptr<user_op::UserOpConfWrapper> user_op_conf_;\n  Symbol<Stream> stream_;\n  std::unique_ptr<const UserKernelRegContextHelper> reg_ctx_helper_;\n  std::unique_ptr<const UserOpInferContextHelper> op_infer_ctx_helper_;\n  std::unique_ptr<const UserKernelInitAndCacheContextHelper> init_and_cache_ctx_helper_;\n  std::unique_ptr<const UserKernelComputeContextHelper> compute_ctx_helper_;\n  std::shared_ptr<const ArgTuple> input_arg_tuple_;\n  std::shared_ptr<const ArgTuple> output_arg_tuple_;\n  user_op::TensorDescInferFn tensor_desc_infer_fn_;\n  user_op::DataTypeInferFn data_type_infer_fn_;\n  // NOTE: every device has its own stateful local opkernel instance,\n  // so only group kernels by dtype\n  std::array<std::vector<std::pair<const user_op::OpKernelRegistryResult*,\n                                   std::shared_ptr<const user_op::OpKernel>>>,\n             DataType_ARRAYSIZE>\n      dtype2cached_kernels_;\n  HashMap<const user_op::OpKernel*, std::shared_ptr<user_op::OpKernelState>> op_kernel_state_map_;\n  HashMap<const user_op::OpKernel*, std::shared_ptr<user_op::OpKernelCache>> op_kernel_cache_map_;\n  HashMap<const user_op::OpKernel*, const user_op::InferTmpSizeFn*> infer_tmp_size_fn_map_;\n  OpArgsVector<int64_t> input_tuple_indexes4const_ibns_;\n  OpArgsVector<int64_t> input_tuple_indexes4mut_ibns_;\n  OpArgsVector<int64_t> output_tuple_indexes4mut_obns_;\n  OpArgsVector<int64_t> output_tuple_indexes4mut2_obns_;\n  OpArgsVector<bool> output_tuple_indexes2is_mut2_type_;\n};\n\n}  // namespace one\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_STATEFUL_OPKERNEL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/summary_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/summary/events_writer.h\"\n#include \"oneflow/user/summary/env_time.h\"\n#include \"oneflow/user/summary/histogram.h\"\n#include \"oneflow/user/summary/event_writer_helper.h\"\n\n#include <time.h>\n#include <cstdint>\n\nnamespace oneflow {\n\nnamespace summary {\n\ntemplate<typename T>\nclass SummaryWriteScalar final : public user_op::OpKernel {\n public:\n  SummaryWriteScalar() = default;\n  ~SummaryWriteScalar() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* step = ctx->Tensor4ArgNameAndIndex(\"step\", 0);\n    const user_op::Tensor* tag = ctx->Tensor4ArgNameAndIndex(\"tag\", 0);\n    const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n\n    T* tvalue = const_cast<T*>(value->dptr<T>());\n    CHECK_NOTNULL(tvalue);\n    int64_t* istep = const_cast<int64_t*>(step->dptr<int64_t>());\n    CHECK_NOTNULL(istep);\n    int8_t* ctag = const_cast<int8_t*>(tag->dptr<int8_t>());\n    CHECK_NOTNULL(ctag);\n    std::string tag_str(reinterpret_cast<char*>(ctag), tag->shape_view().elem_cnt());\n    EventWriterHelper<DeviceType::kCPU, T>::WriteScalarToFile(\n        istep[0], static_cast<double>(tvalue[0]), tag_str);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_SCALAR_USER_KERNEL(dtype)                            \\\n  REGISTER_USER_KERNEL(\"summary_write_scalar\")                        \\\n      .SetCreateFn<SummaryWriteScalar<dtype>>()                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value));\n\nREGISTER_SCALAR_USER_KERNEL(double)\nREGISTER_SCALAR_USER_KERNEL(float)\nREGISTER_SCALAR_USER_KERNEL(int64_t)\nREGISTER_SCALAR_USER_KERNEL(int32_t)\n\nclass CreateSummaryWriter final : public user_op::OpKernel {\n public:\n  CreateSummaryWriter() = default;\n  ~CreateSummaryWriter() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const std::string& logdir = ctx->Attr<std::string>(\"logdir\");\n    CHECK_JUST(Singleton<EventsWriter>::Get()->Init(logdir));\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\nREGISTER_USER_KERNEL(\"create_summary_writer\")\n    .SetCreateFn<CreateSummaryWriter>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU));\n\nclass FlushSummaryWriter final : public user_op::OpKernel {\n public:\n  FlushSummaryWriter() = default;\n  ~FlushSummaryWriter() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    Singleton<EventsWriter>::Get()->Flush();\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\nREGISTER_USER_KERNEL(\"flush_summary_writer\")\n   
 .SetCreateFn<FlushSummaryWriter>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU));\n\ntemplate<typename T>\nclass SummaryWriteHistogram final : public user_op::OpKernel {\n public:\n  SummaryWriteHistogram() = default;\n  ~SummaryWriteHistogram() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* step = ctx->Tensor4ArgNameAndIndex(\"step\", 0);\n    const user_op::Tensor* tag = ctx->Tensor4ArgNameAndIndex(\"tag\", 0);\n    const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    int64_t* istep = const_cast<int64_t*>(step->dptr<int64_t>());\n    CHECK_NOTNULL(istep);\n    int8_t* ctag = const_cast<int8_t*>(tag->dptr<int8_t>());\n    CHECK_NOTNULL(ctag);\n    std::string tag_str(reinterpret_cast<char*>(ctag), tag->shape_view().elem_cnt());\n    EventWriterHelper<DeviceType::kCPU, T>::WriteHistogramToFile(istep[0], *value, tag_str);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_HISTOGRAM_USER_KERNEL(dtype)                         \\\n  REGISTER_USER_KERNEL(\"summary_write_histogram\")                     \\\n      .SetCreateFn<SummaryWriteHistogram<dtype>>()                    \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value));\n\nREGISTER_HISTOGRAM_USER_KERNEL(double)\nREGISTER_HISTOGRAM_USER_KERNEL(float)\nREGISTER_HISTOGRAM_USER_KERNEL(int64_t)\nREGISTER_HISTOGRAM_USER_KERNEL(int32_t)\nREGISTER_HISTOGRAM_USER_KERNEL(int8_t)\nREGISTER_HISTOGRAM_USER_KERNEL(uint8_t)\n\ntemplate<typename T>\nclass SummaryWritePb final : public user_op::OpKernel {\n public:\n  SummaryWritePb() = default;\n  ~SummaryWritePb() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* step = ctx->Tensor4ArgNameAndIndex(\"step\", 0);\n    const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    int64_t* istep = const_cast<int64_t*>(step->dptr<int64_t>());\n    CHECK_NOTNULL(istep);\n    int8_t* cvalue = const_cast<int8_t*>(value->dptr<int8_t>());\n    CHECK_NOTNULL(cvalue);\n    std::string value_str(reinterpret_cast<char*>(cvalue), value->shape_view().elem_cnt());\n    EventWriterHelper<DeviceType::kCPU, T>::WritePbToFile(istep[0], value_str);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\nREGISTER_USER_KERNEL(\"summary_write_pb\")\n    .SetCreateFn<SummaryWritePb<int8_t>>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"in\", 0) == GetDataType<int8_t>::value));\n\ntemplate<typename T>\nclass SummaryWriteImage final : public user_op::OpKernel {\n public:\n  SummaryWriteImage() = default;\n  ~SummaryWriteImage() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* step = ctx->Tensor4ArgNameAndIndex(\"step\", 0);\n    const user_op::Tensor* tag = ctx->Tensor4ArgNameAndIndex(\"tag\", 0);\n    const user_op::Tensor* value = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    int64_t* istep = const_cast<int64_t*>(step->dptr<int64_t>());\n    CHECK_NOTNULL(istep);\n    char* ctag = const_cast<char*>(tag->dptr<char>());\n    CHECK_NOTNULL(ctag);\n    std::string tag_str(ctag, 
tag->shape_view().elem_cnt());\n    EventWriterHelper<DeviceType::kCPU, T>::WriteImageToFile(static_cast<int64_t>(istep[0]), *value,\n                                                             tag_str);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\nREGISTER_USER_KERNEL(\"summary_write_image\")\n    .SetCreateFn<SummaryWriteImage<uint8_t>>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"in\", 0) == GetDataType<uint8_t>::value));\n\n}  // namespace summary\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/tensor_buffer_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/common/tensor_buffer.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass TensorBufferToTensorKernel final : public user_op::OpKernel {\n public:\n  TensorBufferToTensorKernel() = default;\n  ~TensorBufferToTensorKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const ShapeView& in_shape = in->shape_view();\n    CHECK_EQ(in->data_type(), DataType::kTensorBuffer);\n    const ShapeView& out_shape = out->shape_view();\n    const auto& instance_shape = ctx->Attr<Shape>(\"instance_shape\");\n    CHECK_EQ(out_shape.NumAxes(), in_shape.NumAxes() + instance_shape.NumAxes());\n    FOR_RANGE(int64_t, i, 0, in_shape.NumAxes()) { CHECK_EQ(out_shape.At(i), in_shape.At(i)); }\n    FOR_RANGE(int64_t, i, 0, instance_shape.NumAxes()) {\n      CHECK_EQ(out_shape.At(i + in_shape.NumAxes()), instance_shape.At(i));\n    }\n    const auto data_type = ctx->Attr<DataType>(\"dtype\");\n    CHECK_EQ(out->data_type(), data_type);\n    const int64_t instance_size = instance_shape.elem_cnt() * GetSizeOfDataType(data_type);\n    const auto* in_ptr = in->dptr<TensorBuffer>();\n    auto* out_ptr = out->mut_dptr<char>();\n    MultiThreadLoop(in_shape.elem_cnt(), [&](size_t i) {\n      const TensorBuffer* tensor_buffer = in_ptr + i;\n      CHECK_EQ(tensor_buffer->nbytes(), instance_size);\n      CHECK_EQ(tensor_buffer->data_type(), data_type);\n      CHECK(tensor_buffer->shape_view() == instance_shape);\n      Memcpy<DeviceType::kCPU>(ctx->stream(), out_ptr + i * instance_size, tensor_buffer->data(),\n                               instance_size);\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"tensor_buffer_to_tensor\")\n    .SetCreateFn<TensorBufferToTensorKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"in\", 0) == DataType::kTensorBuffer));\n\nclass TensorToTensorBufferKernel final : public user_op::OpKernel {\n public:\n  TensorToTensorBufferKernel() = default;\n  ~TensorToTensorBufferKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const ShapeView& in_shape = in->shape_view();\n    const ShapeView& out_shape = out->shape_view();\n    const auto instance_dims = ctx->Attr<int32_t>(\"instance_dims\");\n    CHECK_LT(instance_dims, in_shape.NumAxes());\n    FOR_RANGE(int64_t, i, 0, 
in_shape.NumAxes() - instance_dims) {\n      CHECK_EQ(out_shape.At(i), in_shape.At(i));\n    }\n    DimVector instance_dim_vec;\n    FOR_RANGE(int64_t, i, in_shape.NumAxes() - instance_dims, in_shape.NumAxes()) {\n      instance_dim_vec.emplace_back(in_shape.At(i));\n    }\n    const Shape instance_shape(instance_dim_vec);\n    const auto data_type = in->data_type();\n    CHECK(IsTriviallyCopyableDataType(data_type));\n    const int64_t instance_size = instance_shape.elem_cnt() * GetSizeOfDataType(data_type);\n    const auto* in_ptr = in->dptr<char>();\n    auto* out_ptr = out->mut_dptr<TensorBuffer>();\n    MultiThreadLoop(in_shape.Count(0, in_shape.NumAxes() - instance_dims), [&](size_t i) {\n      TensorBuffer* tensor_buffer = out_ptr + i;\n      tensor_buffer->Resize(instance_shape, data_type);\n      CHECK_EQ(tensor_buffer->nbytes(), instance_size);\n      Memcpy<DeviceType::kCPU>(ctx->stream(), tensor_buffer->mut_data(), in_ptr + i * instance_size,\n                               instance_size);\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"tensor_to_tensor_buffer\")\n    .SetCreateFn<TensorToTensorBufferKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"out\", 0) == DataType::kTensorBuffer));\n\ntemplate<typename T>\nclass GenTensorBuffer final : public user_op::OpKernel {\n public:\n  GenTensorBuffer() = default;\n  ~GenTensorBuffer() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t num_tensor_buffers = ctx->Attr<Shape>(\"shape\").elem_cnt();\n    const std::vector<Shape>& shape_list = ctx->Attr<std::vector<Shape>>(\"shape_list\");\n    const std::vector<float>& value_list = ctx->Attr<std::vector<float>>(\"value_list\");\n    CHECK_EQ(num_tensor_buffers, shape_list.size());\n    CHECK_EQ(num_tensor_buffers, value_list.size());\n    MultiThreadLoop(num_tensor_buffers, [&](size_t i) {\n      TensorBuffer* tensor_buffer = out->mut_dptr<TensorBuffer>() + i;\n      const Shape& shape = shape_list.at(i);\n      tensor_buffer->Resize(shape, GetDataType<T>::value);\n      T* begin = reinterpret_cast<T*>(tensor_buffer->mut_data());\n      std::fill(begin, begin + shape.elem_cnt(), static_cast<T>(value_list.at(i)));\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_GEN_TENSOR_BUFFER_KERNEL(dtype)                      \\\n  REGISTER_USER_KERNEL(\"gen_tensor_buffer\")                           \\\n      .SetCreateFn<GenTensorBuffer<dtype>>()                          \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobAttr<DataType>(\"data_type\") == GetDataType<dtype>::value));\n\nREGISTER_GEN_TENSOR_BUFFER_KERNEL(int32_t)\nREGISTER_GEN_TENSOR_BUFFER_KERNEL(int64_t)\nREGISTER_GEN_TENSOR_BUFFER_KERNEL(float)\nREGISTER_GEN_TENSOR_BUFFER_KERNEL(double)\n\n#undef REGISTER_GEN_TENSOR_BUFFER_KERNEL\n\nclass TensorBufferToListOfTensors final : public user_op::OpKernel {\n public:\n  TensorBufferToListOfTensors() = default;\n  ~TensorBufferToListOfTensors() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    CHECK_GT(in->shape_view().elem_cnt(), 0);\n    
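// The input packs one independently shaped TensorBuffer per element; each\n    // buffer is copied into the \"out\" tensor of the same index (resized first\n    // when dynamic_out is set).\n    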
CHECK_EQ(in->data_type(), DataType::kTensorBuffer);\n    const DataType out_dtype = ctx->Attr<DataType>(\"out_dtype\");\n    CHECK(IsTriviallyCopyableDataType(out_dtype));\n    const bool dynamic_out = ctx->Attr<bool>(\"dynamic_out\");\n    const auto* in_ptr = in->dptr<TensorBuffer>();\n    MultiThreadLoop(in->shape_view().elem_cnt(), [&](size_t i) {\n      const TensorBuffer* tensor_buffer = in_ptr + i;\n      user_op::Tensor* out_i = ctx->Tensor4ArgNameAndIndex(\"out\", i);\n      CHECK_EQ(out_dtype, tensor_buffer->data_type());\n      if (dynamic_out) {\n        CHECK_LE(tensor_buffer->shape_view().elem_cnt(), out_i->shape_view().elem_cnt());\n        out_i->mut_shape_view().set_shape(tensor_buffer->shape_view());\n      } else {\n        CHECK_EQ(tensor_buffer->shape_view().elem_cnt(), out_i->shape_view().elem_cnt());\n      }\n      Memcpy<DeviceType::kCPU>(ctx->stream(), out_i->mut_dptr<void>(), tensor_buffer->data(),\n                               tensor_buffer->nbytes());\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\nREGISTER_USER_KERNEL(\"tensor_buffer_to_list_of_tensors\")\n    .SetCreateFn<TensorBufferToListOfTensors>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"in\", 0) == DataType::kTensorBuffer));\n\nclass TensorBufferToListOfTensorsV2 final : public user_op::OpKernel {\n public:\n  TensorBufferToListOfTensorsV2() = default;\n  ~TensorBufferToListOfTensorsV2() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    CHECK_GT(in->shape_view().elem_cnt(), 0);\n    CHECK_EQ(in->data_type(), DataType::kTensorBuffer);\n    const std::vector<DataType>& out_dtypes = ctx->Attr<std::vector<DataType>>(\"out_dtypes\");\n    const bool dynamic_out = ctx->Attr<bool>(\"dynamic_out\");\n    const auto* in_ptr = in->dptr<TensorBuffer>();\n    MultiThreadLoop(in->shape_view().elem_cnt(), [&](size_t i) {\n      CHECK(IsTriviallyCopyableDataType(out_dtypes[i]));\n      const TensorBuffer* tensor_buffer = in_ptr + i;\n      user_op::Tensor* out_i = ctx->Tensor4ArgNameAndIndex(\"out\", i);\n      CHECK_EQ(out_dtypes[i], tensor_buffer->data_type());\n      if (dynamic_out) {\n        CHECK_LE(tensor_buffer->shape_view().elem_cnt(), out_i->shape_view().elem_cnt());\n        out_i->mut_shape_view().set_shape(tensor_buffer->shape_view());\n      } else {\n        CHECK_EQ(tensor_buffer->shape_view().elem_cnt(), out_i->shape_view().elem_cnt());\n      }\n      Memcpy<DeviceType::kCPU>(ctx->stream(), out_i->mut_dptr<void>(), tensor_buffer->data(),\n                               tensor_buffer->nbytes());\n    });\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\nREGISTER_USER_KERNEL(\"tensor_buffer_to_list_of_tensors_v2\")\n    .SetCreateFn<TensorBufferToListOfTensorsV2>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)\n                     && (user_op::HobDataType(\"in\", 0) == DataType::kTensorBuffer));\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/tensor_constant_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ep/include/primitive/tensor_fill.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\nnamespace {\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::TensorFill> NewTensorFillPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->data_type();\n  return ep::primitive::NewPrimitive<ep::primitive::TensorFillFactory>(ctx->device_type(),\n                                                                       data_type);\n}\n\nclass TensorConstantKernel final : public OpKernel {\n public:\n  TensorConstantKernel() = default;\n  ~TensorConstantKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const Tensor* value_tensor = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK(value_tensor->shape_view().NumAxes() <= 1 && value_tensor->shape_view().elem_cnt() == 1)\n        << \"Only scalar tensor as filled value is supported!\";\n\n    const int64_t elem_cnt = out_tensor->shape_view().elem_cnt();\n    CHECK_GE(elem_cnt, 0);\n    if (elem_cnt == 0) { return; }\n    std::unique_ptr<ep::primitive::TensorFill> tensor_fill = NewTensorFillPrimitive(ctx);\n    CHECK(tensor_fill);\n    tensor_fill->Launch(ctx->stream(), value_tensor->raw_dptr(), out_tensor->mut_dptr(), elem_cnt);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nauto TensorFillPrimitiveExists() {\n  return hob::make_custom(\"TensorFillPrimitiveExists\", [](const user_op::KernelRegContext& ctx) {\n    return NewTensorFillPrimitive(&ctx).operator bool();\n  });\n}\n\nREGISTER_USER_KERNEL(\"tensor_constant\")\n    .SetCreateFn<TensorConstantKernel>()\n    .SetIsMatchedHob(TensorFillPrimitiveExists() == true);\n\n}  // namespace\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/tf_pool_cpu_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n#include \"oneflow/user/utils/pool_util.h\"\n#include \"oneflow/core/common/eigen_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstruct PoolOpKernelCache final : public user_op::OpKernelCache {\n  Params3D params_3d;\n  explicit PoolOpKernelCache(const Params3D& params_3d) : params_3d(params_3d) {}\n  const Params3D& GetParams3D() const { return params_3d; }\n};\n\nstd::shared_ptr<PoolOpKernelCache> InitPoolOpKernelCache(user_op::KernelCacheContext* ctx,\n                                                         const int32_t& dim) {\n  const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0)->shape();\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  const std::string& padding = ctx->Attr<std::string>(\"padding\");\n  const auto& padding_before = ctx->Attr<std::vector<int32_t>>(\"padding_before\");\n  const auto& padding_after = ctx->Attr<std::vector<int32_t>>(\"padding_after\");\n  const std::vector<int32_t>& pool_size = ctx->Attr<std::vector<int32_t>>(\"pool_size\");\n  const std::vector<int32_t>& strides = ctx->Attr<std::vector<int32_t>>(\"strides\");\n  const bool ceil_mode = ctx->Attr<bool>(\"ceil_mode\");\n  Params3D params_3d = Params3D(dim, x_shape, data_format, padding, padding_before, padding_after,\n                                pool_size, strides, ceil_mode);\n  std::shared_ptr<PoolOpKernelCache> state(new PoolOpKernelCache(params_3d));\n  return state;\n}\n\ntemplate<typename T>\nstruct PoolCpuKernelUtil {\n public:\n  typedef std::function<T()> ForwardInitialize;\n  typedef std::function<void(const T& lhs, T& rhs)> CFirstProcess;\n  typedef std::function<void(const int64_t in_col, const int64_t out_col,\n                             ConstEigenMatrixMap<T>& in_mat, EigenMatrixMap<T>& out_mat)>\n      CLastProcess;\n  typedef std::function<void(const int64_t size, T& out)> CFirstFinalize;\n  typedef std::function<void(const int64_t size, const int64_t col, EigenMatrixMap<T>& out_mat)>\n      CLastFinalize;\n  typedef std::function<void(const T& in, const T& out, const T& out_diff, const int64_t size,\n                             T& in_diff)>\n      CFirstProcessGrad;\n  typedef std::function<void(const int64_t out_col, const int64_t in_col, const int64_t size,\n                             ConstEigenArrayMap<T>& out_arr, ConstEigenArrayMap<T>& in_arr,\n                             ConstEigenArrayMap<T>& out_diff_arr, EigenArrayMap<T>& in_diff_arr)>\n      CLastProcessGrad;\n\n  static void CFirstForward(const Params3D& params_3d, const user_op::Tensor* in_blob,\n                            user_op::Tensor* out_blob, const ForwardInitialize& initialize,\n                            const CFirstProcess& process, const CFirstFinalize& finalize) {\n    const Shape& in = params_3d.GetXShape5D();\n    const Shape& out = 
params_3d.GetYShape5D();\n    const std::vector<int32_t>& pool_size = params_3d.pool_size_3d();\n    const std::vector<int32_t>& strides = params_3d.strides_3d();\n    const std::vector<int32_t>& padding_before = params_3d.padding_before_3d();\n\n    const T* input = in_blob->dptr<T>();\n    T* output = out_blob->mut_dptr<T>();\n    FOR_RANGE(int64_t, n, 0, in.At(0)) {\n      FOR_RANGE(int64_t, c, 0, in.At(1)) {\n        FOR_RANGE(int64_t, pd, 0, out.At(2)) {\n          int64_t dstart = pd * strides.at(0) - padding_before.at(0);\n          int64_t dend = std::min(dstart + pool_size.at(0), in.At(2));\n          dstart = std::max(dstart, static_cast<int64_t>(0));\n          FOR_RANGE(int64_t, ph, 0, out.At(3)) {\n            int64_t hstart = ph * strides.at(1) - padding_before.at(1);\n            int64_t hend = std::min(hstart + pool_size.at(1), in.At(3));\n            hstart = std::max(hstart, static_cast<int64_t>(0));\n            FOR_RANGE(int64_t, pw, 0, out.At(4)) {\n              int64_t wstart = pw * strides.at(2) - padding_before.at(2);\n              int64_t wend = std::min(wstart + pool_size.at(2), in.At(4));\n              wstart = std::max(wstart, static_cast<int64_t>(0));\n\n              const int64_t pool_index = pd * out.Count(3) + ph * out.At(4) + pw;\n              T res = initialize();\n              FOR_RANGE(int64_t, d, dstart, dend) {\n                FOR_RANGE(int64_t, h, hstart, hend) {\n                  FOR_RANGE(int64_t, w, wstart, wend) {\n                    const int64_t input_index = d * in.Count(3) + h * in.At(4) + w;\n                    process(input[input_index], res);\n                  }\n                }\n              }\n              finalize((dend - dstart) * (hend - hstart) * (wend - wstart), res);\n              output[pool_index] = res;\n            }\n          }\n        }\n        input += in.Count(2);\n        output += out.Count(2);\n      }\n    }\n  }\n\n  static void CFirstBackward(const Params3D& params_3d, const user_op::Tensor* out_diff_blob,\n                             const user_op::Tensor* out_blob, const user_op::Tensor* in_blob,\n                             user_op::Tensor* in_diff_blob, const CFirstProcessGrad& process) {\n    const Shape& in = params_3d.GetXShape5D();\n    const Shape& out = params_3d.GetYShape5D();\n    const std::vector<int32_t>& pool_size = params_3d.pool_size_3d();\n    const std::vector<int32_t>& strides = params_3d.strides_3d();\n    const std::vector<int32_t>& padding_before = params_3d.padding_before_3d();\n\n    const T* output_diff = out_diff_blob->dptr<T>();\n    const T* output = out_blob->dptr<T>();\n    const T* input = in_blob->dptr<T>();\n    std::memset(in_diff_blob->mut_dptr<T>(), 0, in.elem_cnt() * sizeof(T));\n    T* input_diff = in_diff_blob->mut_dptr<T>();\n    FOR_RANGE(int64_t, n, 0, in.At(0)) {\n      FOR_RANGE(int64_t, c, 0, in.At(1)) {\n        FOR_RANGE(int64_t, pd, 0, out.At(2)) {\n          int64_t dstart = pd * strides.at(0) - padding_before.at(0);\n          int64_t dend = std::min(dstart + pool_size.at(0), in.At(2));\n          dstart = std::max(dstart, static_cast<int64_t>(0));\n          FOR_RANGE(int64_t, ph, 0, out.At(3)) {\n            int64_t hstart = ph * strides.at(1) - padding_before.at(1);\n            int64_t hend = std::min(hstart + pool_size.at(1), in.At(3));\n            hstart = std::max(hstart, static_cast<int64_t>(0));\n            FOR_RANGE(int64_t, pw, 0, out.At(4)) {\n              int64_t wstart = pw * strides.at(2) - padding_before.at(2);\n              
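// clamp the pooling window to the valid input extent\n              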
int64_t wend = std::min(wstart + pool_size.at(2), in.At(4));\n              wstart = std::max(wstart, static_cast<int64_t>(0));\n\n              const int64_t size = (dend - dstart) * (hend - hstart) * (wend - wstart);\n              const int64_t pool_index = pd * out.Count(3) + ph * out.At(4) + pw;\n              FOR_RANGE(int64_t, d, dstart, dend) {\n                FOR_RANGE(int64_t, h, hstart, hend) {\n                  FOR_RANGE(int64_t, w, wstart, wend) {\n                    const int64_t index = d * in.Count(3) + h * in.At(4) + w;\n                    process(input[index], output[pool_index], output_diff[pool_index], size,\n                            input_diff[index]);\n                  }\n                }\n              }\n            }\n          }\n        }\n        // offset\n        input += in.Count(2);\n        input_diff += in.Count(2);\n        output += out.Count(2);\n        output_diff += out.Count(2);\n      }\n    }\n  }\n\n  static void CLastForward(const Params3D& params_3d, const user_op::Tensor* in_blob,\n                           user_op::Tensor* out_blob, const ForwardInitialize& forward_initialize,\n                           const CLastProcess& process, const CLastFinalize& finalize) {\n    const Shape& in = params_3d.GetXShape5D();\n    const Shape& out = params_3d.GetYShape5D();\n    const std::vector<int32_t>& pool_size = params_3d.pool_size_3d();\n    const std::vector<int32_t>& strides = params_3d.strides_3d();\n    const std::vector<int32_t>& padding_before = params_3d.padding_before_3d();\n\n    ConstEigenMatrixMap<T> in_mat(in_blob->dptr<T>(), in.At(1), in.elem_cnt() / in.At(1));\n    EigenMatrixMap<T> out_mat(out_blob->mut_dptr<T>(), out.At(1), out.elem_cnt() / out.At(1));\n    FOR_RANGE(int64_t, n, 0, in.At(0)) {\n      FOR_RANGE(int64_t, pd, 0, out.At(2)) {\n        int64_t dstart = pd * strides.at(0) - padding_before.at(0);\n        int64_t dend = std::min(dstart + pool_size.at(0), in.At(2));\n        dstart = std::max(dstart, static_cast<int64_t>(0));\n        FOR_RANGE(int64_t, ph, 0, out.At(3)) {\n          int64_t hstart = ph * strides.at(1) - padding_before.at(1);\n          int64_t hend = std::min(hstart + pool_size.at(1), in.At(3));\n          hstart = std::max(hstart, static_cast<int64_t>(0));\n          FOR_RANGE(int64_t, pw, 0, out.At(4)) {\n            int64_t wstart = pw * strides.at(2) - padding_before.at(2);\n            int64_t wend = std::min(wstart + pool_size.at(2), in.At(4));\n            wstart = std::max(wstart, static_cast<int64_t>(0));\n            const int out_col = ((n * out.At(2) + pd) * out.At(3) + ph) * out.At(4) + pw;\n            out_mat.col(out_col).setConstant(forward_initialize());\n            FOR_RANGE(int64_t, d, dstart, dend) {\n              FOR_RANGE(int64_t, h, hstart, hend) {\n                FOR_RANGE(int64_t, w, wstart, wend) {\n                  const int in_col = ((n * in.At(2) + d) * in.At(3) + h) * in.At(4) + w;\n                  process(in_col, out_col, in_mat, out_mat);\n                }\n              }\n            }\n            finalize((hend - hstart) * (wend - wstart) * (dend - dstart), out_col, out_mat);\n          }\n        }\n      }\n    }\n  }\n\n  static void CLastBackward(const Params3D& params_3d, const user_op::Tensor* out_diff_blob,\n                            const user_op::Tensor* out_blob, const user_op::Tensor* in_blob,\n                            user_op::Tensor* in_diff_blob, const CLastProcessGrad& process) {\n    const Shape& in = params_3d.GetXShape5D();\n    
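// Shapes are normalized to 5D (N, C, D, H, W), so the 1D/2D/3D variants\n    // share this single code path.\n    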
const Shape& out = params_3d.GetYShape5D();\n    const std::vector<int32_t>& pool_size = params_3d.pool_size_3d();\n    const std::vector<int32_t>& strides = params_3d.strides_3d();\n    const std::vector<int32_t>& padding_before = params_3d.padding_before_3d();\n\n    // Adapted from the Caffe2 implementation; correctness still needs verification.\n    ConstEigenArrayMap<T> out_mat(out_blob->dptr<T>(), out.At(1), out.elem_cnt() / out.At(1));\n    ConstEigenArrayMap<T> in_mat(in_blob->dptr<T>(), in.At(1), in.elem_cnt() / in.At(1));\n    ConstEigenArrayMap<T> out_diff_mat(out_diff_blob->dptr<T>(), out.At(1),\n                                       out.elem_cnt() / out.At(1));\n    std::memset(in_diff_blob->mut_dptr<T>(), 0, in.elem_cnt() * sizeof(T));\n    EigenArrayMap<T> in_diff_mat(in_diff_blob->mut_dptr<T>(), in.At(1), in.elem_cnt() / in.At(1));\n    FOR_RANGE(int64_t, n, 0, in.At(0)) {\n      FOR_RANGE(int64_t, pd, 0, out.At(2)) {\n        int64_t dstart = pd * strides.at(0) - padding_before.at(0);\n        int64_t dend = std::min(dstart + pool_size.at(0), in.At(2));\n        dstart = std::max(dstart, static_cast<int64_t>(0));\n        FOR_RANGE(int64_t, ph, 0, out.At(3)) {\n          int64_t hstart = ph * strides.at(1) - padding_before.at(1);\n          int64_t hend = std::min(hstart + pool_size.at(1), in.At(3));\n          hstart = std::max(hstart, static_cast<int64_t>(0));\n          FOR_RANGE(int64_t, pw, 0, out.At(4)) {\n            int64_t wstart = pw * strides.at(2) - padding_before.at(2);\n            int64_t wend = std::min(wstart + pool_size.at(2), in.At(4));\n            wstart = std::max(wstart, static_cast<int64_t>(0));\n            const int64_t pool_index = ((n * out.At(2) + pd) * out.At(3) + ph) * out.At(4) + pw;\n            const int64_t size = (dend - dstart) * (hend - hstart) * (wend - wstart);\n            FOR_RANGE(int64_t, d, dstart, dend) {\n              FOR_RANGE(int64_t, h, hstart, hend) {\n                FOR_RANGE(int64_t, w, wstart, wend) {\n                  const int64_t input_index = ((n * in.At(2) + d) * in.At(3) + h) * in.At(4) + w;\n                  process(pool_index, input_index, size, out_mat, in_mat, out_diff_mat,\n                          in_diff_mat);\n                }\n              }\n            }\n          }\n        }\n      }\n    }\n  }\n\n  static void AvgFWCompute(user_op::KernelComputeContext* ctx,\n                           const PoolOpKernelCache* pool_state) {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    CHECK_NOTNULL(pool_state);\n    const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n    if (data_format == \"channels_first\") {\n      CFirstForward(\n          pool_state->GetParams3D(), x, y, GetZeroVal<T>, [](const T& lhs, T& rhs) { rhs += lhs; },\n          [](const int64_t size, T& out) { out /= size; });\n    } else if (data_format == \"channels_last\") {\n      CLastForward(\n          pool_state->GetParams3D(), x, y, GetZeroVal<T>,\n          [](const int64_t in_col, const int64_t out_col, ConstEigenMatrixMap<T>& in_mat,\n             EigenMatrixMap<T>& out_mat) { out_mat.col(out_col) += in_mat.col(in_col); },\n          [](const int64_t size, const int64_t col, EigenMatrixMap<T>& out_mat) {\n            out_mat.col(col) /= size;\n          });\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n\n  static void AvgBWCompute(user_op::KernelComputeContext* ctx,\n                           const PoolOpKernelCache* pool_state) {\n    const 
user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    CHECK_NOTNULL(pool_state);\n    const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n    if (data_format == \"channels_first\") {\n      CFirstBackward(pool_state->GetParams3D(), dy, y, x, dx,\n                     [](const T& in, const T& out, const T& out_diff, const int64_t size,\n                        T& in_diff) { in_diff += (out_diff / static_cast<T>(size)); });\n    } else if (data_format == \"channels_last\") {\n      CLastBackward(pool_state->GetParams3D(), dy, y, x, dx,\n                    [](const int64_t out_col, const int64_t in_col, const int64_t size,\n                       ConstEigenArrayMap<T>& out_arr, ConstEigenArrayMap<T>& in_arr,\n                       ConstEigenArrayMap<T>& out_diff_arr, EigenArrayMap<T>& in_diff_arr) {\n                      in_diff_arr.col(in_col) += out_diff_arr.col(out_col) / static_cast<T>(size);\n                    });\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n\n  static void MaxFWCompute(user_op::KernelComputeContext* ctx,\n                           const PoolOpKernelCache* pool_state) {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    CHECK_NOTNULL(pool_state);\n    const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n    if (data_format == \"channels_first\") {\n      CFirstForward(\n          pool_state->GetParams3D(), x, y, GetMinVal<T>,\n          [](const T& lhs, T& rhs) {\n            if (lhs > rhs) { rhs = lhs; }\n          },\n          [](const int64_t size, T& out) {});\n    } else if (data_format == \"channels_last\") {\n      CLastForward(\n          pool_state->GetParams3D(), x, y, GetMinVal<T>,\n          [](const int64_t in_col, const int64_t out_col, ConstEigenMatrixMap<T>& in_mat,\n             EigenMatrixMap<T>& out_mat) {\n            out_mat.col(out_col) = out_mat.col(out_col).cwiseMax(in_mat.col(in_col));\n          },\n          [](const int64_t size, const int64_t col, EigenMatrixMap<T>& out_mat) {});\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n\n  static void MaxBWCompute(user_op::KernelComputeContext* ctx,\n                           const PoolOpKernelCache* pool_state) {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    CHECK_NOTNULL(pool_state);\n    const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n    if (data_format == \"channels_first\") {\n      CFirstBackward(\n          pool_state->GetParams3D(), dy, y, x, dx,\n          [](const T& in, const T& out, const T& out_diff, const int64_t size, T& in_diff) {\n            if (in == out) { in_diff += out_diff; }\n          });\n    } else if (data_format == \"channels_last\") {\n      CLastBackward(\n          pool_state->GetParams3D(), dy, y, x, dx,\n          [](const int64_t out_col, const int64_t in_col, const int64_t size,\n             ConstEigenArrayMap<T>& out_arr, ConstEigenArrayMap<T>& in_arr,\n             ConstEigenArrayMap<T>& out_diff_arr, EigenArrayMap<T>& in_diff_arr) 
{\n            in_diff_arr.col(in_col) +=\n                out_diff_arr.col(out_col)\n                * (in_arr.col(in_col).cwiseEqual(out_arr.col(out_col)).template cast<T>());\n          });\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n};\n\n}  // namespace\n\ntemplate<typename T>\nclass AvgPool1DCpuKernel final : public user_op::OpKernel {\n public:\n  AvgPool1DCpuKernel() = default;\n  ~AvgPool1DCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return InitPoolOpKernelCache(ctx, 1);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolCpuKernelUtil<T>::AvgFWCompute(ctx, dynamic_cast<const PoolOpKernelCache*>(cache));\n  };\n};\n\ntemplate<typename T>\nclass AvgPool1DGradCpuKernel final : public user_op::OpKernel {\n public:\n  AvgPool1DGradCpuKernel() = default;\n  ~AvgPool1DGradCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return InitPoolOpKernelCache(ctx, 1);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolCpuKernelUtil<T>::AvgBWCompute(ctx, dynamic_cast<const PoolOpKernelCache*>(cache));\n  };\n};\n\ntemplate<typename T>\nclass AvgPool2DCpuKernel final : public user_op::OpKernel {\n public:\n  AvgPool2DCpuKernel() = default;\n  ~AvgPool2DCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return InitPoolOpKernelCache(ctx, 2);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolCpuKernelUtil<T>::AvgFWCompute(ctx, dynamic_cast<const PoolOpKernelCache*>(cache));\n  };\n};\n\ntemplate<typename T>\nclass AvgPool2DGradCpuKernel final : public user_op::OpKernel {\n public:\n  AvgPool2DGradCpuKernel() = default;\n  ~AvgPool2DGradCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return InitPoolOpKernelCache(ctx, 2);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolCpuKernelUtil<T>::AvgBWCompute(ctx, dynamic_cast<const PoolOpKernelCache*>(cache));\n  };\n};\n\ntemplate<typename T>\nclass AvgPool3DCpuKernel final : public user_op::OpKernel {\n public:\n  AvgPool3DCpuKernel() = default;\n  ~AvgPool3DCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return InitPoolOpKernelCache(ctx, 3);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const 
override {\n    PoolCpuKernelUtil<T>::AvgFWCompute(ctx, dynamic_cast<const PoolOpKernelCache*>(cache));\n  };\n};\n\ntemplate<typename T>\nclass AvgPool3DGradCpuKernel final : public user_op::OpKernel {\n public:\n  AvgPool3DGradCpuKernel() = default;\n  ~AvgPool3DGradCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return InitPoolOpKernelCache(ctx, 3);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolCpuKernelUtil<T>::AvgBWCompute(ctx, dynamic_cast<const PoolOpKernelCache*>(cache));\n  };\n};\n\ntemplate<typename T>\nclass MaxPool1DCpuKernel final : public user_op::OpKernel {\n public:\n  MaxPool1DCpuKernel() = default;\n  ~MaxPool1DCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return InitPoolOpKernelCache(ctx, 1);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolCpuKernelUtil<T>::MaxFWCompute(ctx, dynamic_cast<const PoolOpKernelCache*>(cache));\n  };\n};\n\ntemplate<typename T>\nclass MaxPool1DGradCpuKernel final : public user_op::OpKernel {\n public:\n  MaxPool1DGradCpuKernel() = default;\n  ~MaxPool1DGradCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return InitPoolOpKernelCache(ctx, 1);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolCpuKernelUtil<T>::MaxBWCompute(ctx, dynamic_cast<const PoolOpKernelCache*>(cache));\n  };\n};\n\ntemplate<typename T>\nclass MaxPool2DCpuKernel final : public user_op::OpKernel {\n public:\n  MaxPool2DCpuKernel() = default;\n  ~MaxPool2DCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return InitPoolOpKernelCache(ctx, 2);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolCpuKernelUtil<T>::MaxFWCompute(ctx, dynamic_cast<const PoolOpKernelCache*>(cache));\n  };\n};\n\ntemplate<typename T>\nclass MaxPool2DGradCpuKernel final : public user_op::OpKernel {\n public:\n  MaxPool2DGradCpuKernel() = default;\n  ~MaxPool2DGradCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return InitPoolOpKernelCache(ctx, 2);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolCpuKernelUtil<T>::MaxBWCompute(ctx, dynamic_cast<const PoolOpKernelCache*>(cache));\n  };\n};\n\ntemplate<typename 
T>\nclass MaxPool3DCpuKernel final : public user_op::OpKernel {\n public:\n  MaxPool3DCpuKernel() = default;\n  ~MaxPool3DCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return InitPoolOpKernelCache(ctx, 3);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolCpuKernelUtil<T>::MaxFWCompute(ctx, dynamic_cast<const PoolOpKernelCache*>(cache));\n  };\n};\n\ntemplate<typename T>\nclass MaxPool3DGradCpuKernel final : public user_op::OpKernel {\n public:\n  MaxPool3DGradCpuKernel() = default;\n  ~MaxPool3DGradCpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return InitPoolOpKernelCache(ctx, 3);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolCpuKernelUtil<T>::MaxBWCompute(ctx, dynamic_cast<const PoolOpKernelCache*>(cache));\n  };\n};\n\n#define REGISTER_POOL_CPU_KERNEL(dtype)                                                 \\\n  REGISTER_USER_KERNEL(\"tf_avg_pool_1d\")                                                \\\n      .SetCreateFn<AvgPool1DCpuKernel<dtype>>()                                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"tf_avg_pool_1d_grad\")                                           \\\n      .SetCreateFn<AvgPool1DGradCpuKernel<dtype>>()                                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"tf_avg_pool_2d\")                                                \\\n      .SetCreateFn<AvgPool2DCpuKernel<dtype>>()                                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"tf_avg_pool_2d_grad\")                                           \\\n      .SetCreateFn<AvgPool2DGradCpuKernel<dtype>>()                                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"tf_avg_pool_3d\")                                                \\\n      .SetCreateFn<AvgPool3DCpuKernel<dtype>>()                                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"tf_avg_pool_3d_grad\")                                           \\\n      .SetCreateFn<AvgPool3DGradCpuKernel<dtype>>()                                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)    
               \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"tf_max_pool_1d\")                                                \\\n      .SetCreateFn<MaxPool1DCpuKernel<dtype>>()                                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"tf_max_pool_1d_grad\")                                           \\\n      .SetCreateFn<MaxPool1DGradCpuKernel<dtype>>()                                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"tf_max_pool_2d\")                                                \\\n      .SetCreateFn<MaxPool2DCpuKernel<dtype>>()                                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"tf_max_pool_2d_grad\")                                           \\\n      .SetCreateFn<MaxPool2DGradCpuKernel<dtype>>()                                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"tf_max_pool_3d\")                                                \\\n      .SetCreateFn<MaxPool3DCpuKernel<dtype>>()                                         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"tf_max_pool_3d_grad\")                                           \\\n      .SetCreateFn<MaxPool3DGradCpuKernel<dtype>>()                                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nREGISTER_POOL_CPU_KERNEL(float)\nREGISTER_POOL_CPU_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/tf_pool_gpu_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/utils/pool_util.h\"\n#include \"oneflow/core/device/cudnn_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass CudnnPoolDesc final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(CudnnPoolDesc);\n  CudnnPoolDesc(cudnnPoolingMode_t pooling_mode, int dims, const int* window, const int* padding,\n                const int* stride) {\n    OF_CUDNN_CHECK(cudnnCreatePoolingDescriptor(&val_));\n    OF_CUDNN_CHECK(cudnnSetPoolingNdDescriptor(val_, pooling_mode, CUDNN_NOT_PROPAGATE_NAN, dims,\n                                               window, padding, stride));\n  }\n\n  ~CudnnPoolDesc() { OF_CUDNN_CHECK(cudnnDestroyPoolingDescriptor(val_)); }\n\n  const cudnnPoolingDescriptor_t& Get() const { return val_; }\n\n private:\n  cudnnPoolingDescriptor_t val_;\n};\n\nclass GPUPoolOpKernelCache final : public user_op::OpKernelCache {\n public:\n  GPUPoolOpKernelCache(const int32_t dim, const std::string& pooling_type, const ShapeView& x_shape,\n                       const ShapeView& y_shape, const std::string& data_format,\n                       const DataType& dtype, const Params3D& params_3d)\n      : pooling_type_(pooling_type) {\n    Reset(dim, pooling_type, x_shape, y_shape, data_format, dtype, params_3d);\n  }\n  ~GPUPoolOpKernelCache() = default;\n\n  void Reset(const int32_t dim, const std::string& pooling_type, const ShapeView& x_shape,\n             const ShapeView& y_shape, const std::string& data_format, const DataType& dtype,\n             const Params3D& params_3d) {\n    FixedVector pool_size(dim);\n    FixedVector padding(dim);\n    FixedVector strides(dim);\n    FOR_RANGE(int, i, 0, dim) {\n      int32_t index_in_3d = i + 3 - dim;\n      pool_size[i] = params_3d.pool_size_3d().at(index_in_3d);\n      padding[i] = params_3d.padding_before_3d().at(index_in_3d);\n      strides[i] = params_3d.strides_3d().at(index_in_3d);\n    }\n\n    x_desc_.reset(new CudnnTensorDesc(dtype, x_shape, data_format));\n    y_desc_.reset(new CudnnTensorDesc(dtype, y_shape, data_format));\n    cudnnPoolingMode_t pooling_mode;\n    if (pooling_type == \"AVG\") {\n      pooling_mode = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;\n    } else if (pooling_type == \"MAX\") {\n      pooling_mode = CUDNN_POOLING_MAX;\n    } else {\n      UNIMPLEMENTED();\n    }\n    pooling_desc_.reset(\n        new CudnnPoolDesc(pooling_mode, dim, pool_size.data(), padding.data(), strides.data()));\n  }\n\n  static std::shared_ptr<GPUPoolOpKernelCache> FromKernelComputeContext(\n      const int32_t& dim, const std::string& pooling_type, user_op::KernelCacheContext* ctx) {\n    if (pooling_type != \"MAX\" && pooling_type != \"AVG\") { UNIMPLEMENTED(); }\n    const ShapeView& x_shape = ctx->TensorDesc4ArgNameAndIndex(\"x\", 
0)->shape();\n    const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n    const std::string& padding = ctx->Attr<std::string>(\"padding\");\n    const auto& padding_before = ctx->Attr<std::vector<int32_t>>(\"padding_before\");\n    const auto& padding_after = ctx->Attr<std::vector<int32_t>>(\"padding_after\");\n    const std::vector<int32_t>& pool_size = ctx->Attr<std::vector<int32_t>>(\"pool_size\");\n    const std::vector<int32_t>& strides = ctx->Attr<std::vector<int32_t>>(\"strides\");\n    const bool ceil_mode = ctx->Attr<bool>(\"ceil_mode\");\n    const Params3D params_3d(dim, x_shape, data_format, padding, padding_before, padding_after,\n                             pool_size, strides, ceil_mode);\n    const ShapeView& y_shape = ctx->TensorDesc4ArgNameAndIndex(\"y\", 0)->shape();\n    const DataType dtype = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0)->data_type();\n    return std::make_shared<GPUPoolOpKernelCache>(dim, pooling_type, x_shape, y_shape, data_format,\n                                                  dtype, params_3d);\n  }\n\n  const cudnnTensorDescriptor_t& cudnn_x_tensor_desc() const { return x_desc_->Get(); }\n  const cudnnTensorDescriptor_t& cudnn_y_tensor_desc() const { return y_desc_->Get(); }\n  const cudnnPoolingDescriptor_t& cudnn_pooling_desc() const { return pooling_desc_->Get(); }\n\n private:\n  std::unique_ptr<CudnnTensorDesc> x_desc_;\n  std::unique_ptr<CudnnTensorDesc> y_desc_;\n  std::unique_ptr<CudnnPoolDesc> pooling_desc_;\n  std::string pooling_type_;\n};\n\nstruct PoolGpuKernelUtil {\n  static void FWCompute(user_op::KernelComputeContext* ctx,\n                        const GPUPoolOpKernelCache* gpu_pool_op_kernel_cache) {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    CHECK(gpu_pool_op_kernel_cache != nullptr);\n    OF_CUDNN_CHECK(cudnnPoolingForward(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(),\n        gpu_pool_op_kernel_cache->cudnn_pooling_desc(), CudnnSPOnePtr(x->data_type()),\n        gpu_pool_op_kernel_cache->cudnn_x_tensor_desc(), x->dptr(), CudnnSPZeroPtr(x->data_type()),\n        gpu_pool_op_kernel_cache->cudnn_y_tensor_desc(), y->mut_dptr()));\n  }\n\n  static void BWCompute(user_op::KernelComputeContext* ctx,\n                        const GPUPoolOpKernelCache* gpu_pool_op_kernel_cache) {\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    CHECK(gpu_pool_op_kernel_cache != nullptr);\n    OF_CUDNN_CHECK(cudnnPoolingBackward(\n        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(),\n        gpu_pool_op_kernel_cache->cudnn_pooling_desc(), CudnnSPOnePtr(y->data_type()),\n        gpu_pool_op_kernel_cache->cudnn_y_tensor_desc(), y->dptr(),\n        gpu_pool_op_kernel_cache->cudnn_y_tensor_desc(), dy->dptr(),\n        gpu_pool_op_kernel_cache->cudnn_x_tensor_desc(), x->dptr(), CudnnSPZeroPtr(y->data_type()),\n        gpu_pool_op_kernel_cache->cudnn_x_tensor_desc(), dx->mut_dptr()));\n  }\n};\n\n}  // namespace\n\nclass AvgPool1DGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  AvgPool1DGpuKernel() = default;\n  ~AvgPool1DGpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  
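// InitOpKernelCache builds the cuDNN tensor/pooling descriptors once; Compute then reuses the\n  // cached descriptors for every launch.\n  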
std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return GPUPoolOpKernelCache::FromKernelComputeContext(1, \"AVG\", ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast<const GPUPoolOpKernelCache*>(cache));\n  };\n};\n\nclass AvgPool1DGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  AvgPool1DGradGpuKernel() = default;\n  ~AvgPool1DGradGpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return GPUPoolOpKernelCache::FromKernelComputeContext(1, \"AVG\", ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast<const GPUPoolOpKernelCache*>(cache));\n  };\n};\n\nclass AvgPool2DGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  AvgPool2DGpuKernel() = default;\n  ~AvgPool2DGpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return GPUPoolOpKernelCache::FromKernelComputeContext(2, \"AVG\", ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast<const GPUPoolOpKernelCache*>(cache));\n  };\n};\n\nclass AvgPool2DGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  AvgPool2DGradGpuKernel() = default;\n  ~AvgPool2DGradGpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return GPUPoolOpKernelCache::FromKernelComputeContext(2, \"AVG\", ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast<const GPUPoolOpKernelCache*>(cache));\n  };\n};\n\nclass AvgPool3DGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  AvgPool3DGpuKernel() = default;\n  ~AvgPool3DGpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return GPUPoolOpKernelCache::FromKernelComputeContext(3, \"AVG\", ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast<const GPUPoolOpKernelCache*>(cache));\n  };\n};\n\nclass AvgPool3DGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  AvgPool3DGradGpuKernel() = default;\n  ~AvgPool3DGradGpuKernel() = default;\n\n  bool 
AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return GPUPoolOpKernelCache::FromKernelComputeContext(3, \"AVG\", ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast<const GPUPoolOpKernelCache*>(cache));\n  };\n};\n\nclass MaxPool1DGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  MaxPool1DGpuKernel() = default;\n  ~MaxPool1DGpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return GPUPoolOpKernelCache::FromKernelComputeContext(1, \"MAX\", ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast<const GPUPoolOpKernelCache*>(cache));\n  };\n};\n\nclass MaxPool1DGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  MaxPool1DGradGpuKernel() = default;\n  ~MaxPool1DGradGpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return GPUPoolOpKernelCache::FromKernelComputeContext(1, \"MAX\", ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast<const GPUPoolOpKernelCache*>(cache));\n  };\n};\n\nclass MaxPool2DGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  MaxPool2DGpuKernel() = default;\n  ~MaxPool2DGpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return GPUPoolOpKernelCache::FromKernelComputeContext(2, \"MAX\", ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast<const GPUPoolOpKernelCache*>(cache));\n  };\n};\n\nclass MaxPool2DGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  MaxPool2DGradGpuKernel() = default;\n  ~MaxPool2DGradGpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return GPUPoolOpKernelCache::FromKernelComputeContext(2, \"MAX\", ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast<const GPUPoolOpKernelCache*>(cache));\n  };\n};\n\nclass MaxPool3DGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  
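// Same structure as the 1D/2D max-pool kernels; only the spatial dim (3) handed to the cache\n  // factory differs.\n  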
MaxPool3DGpuKernel() = default;\n  ~MaxPool3DGpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return GPUPoolOpKernelCache::FromKernelComputeContext(3, \"MAX\", ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast<const GPUPoolOpKernelCache*>(cache));\n  };\n};\n\nclass MaxPool3DGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  MaxPool3DGradGpuKernel() = default;\n  ~MaxPool3DGradGpuKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    return GPUPoolOpKernelCache::FromKernelComputeContext(3, \"MAX\", ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast<const GPUPoolOpKernelCache*>(cache));\n  };\n};\n\nREGISTER_USER_KERNEL(\"tf_avg_pool_1d\")\n    .SetCreateFn<AvgPool1DGpuKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA));\nREGISTER_USER_KERNEL(\"tf_avg_pool_1d_grad\")\n    .SetCreateFn<AvgPool1DGradGpuKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA));\nREGISTER_USER_KERNEL(\"tf_avg_pool_2d\")\n    .SetCreateFn<AvgPool2DGpuKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA));\nREGISTER_USER_KERNEL(\"tf_avg_pool_2d_grad\")\n    .SetCreateFn<AvgPool2DGradGpuKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA));\nREGISTER_USER_KERNEL(\"tf_avg_pool_3d\")\n    .SetCreateFn<AvgPool3DGpuKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA));\nREGISTER_USER_KERNEL(\"tf_avg_pool_3d_grad\")\n    .SetCreateFn<AvgPool3DGradGpuKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA));\nREGISTER_USER_KERNEL(\"tf_max_pool_1d\")\n    .SetCreateFn<MaxPool1DGpuKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA));\nREGISTER_USER_KERNEL(\"tf_max_pool_1d_grad\")\n    .SetCreateFn<MaxPool1DGradGpuKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA));\nREGISTER_USER_KERNEL(\"tf_max_pool_2d\")\n    .SetCreateFn<MaxPool2DGpuKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA));\nREGISTER_USER_KERNEL(\"tf_max_pool_2d_grad\")\n    .SetCreateFn<MaxPool2DGradGpuKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA));\nREGISTER_USER_KERNEL(\"tf_max_pool_3d\")\n    .SetCreateFn<MaxPool3DGpuKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA));\nREGISTER_USER_KERNEL(\"tf_max_pool_3d_grad\")\n    .SetCreateFn<MaxPool3DGradGpuKernel>()\n    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA));\n\n}  // namespace oneflow\n\n#endif\n"
  },
  {
    "path": "oneflow/user/kernels/tf_prelu_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass TfCpuPReluKernel final : public user_op::OpKernel {\n public:\n  TfCpuPReluKernel() = default;\n  ~TfCpuPReluKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex(\"alpha\", 0);\n    user_op::Tensor* broadcasted_alpha = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const T* x_ptr = x->dptr<T>();\n    T* y_ptr = y->mut_dptr<T>();\n    T* broadcasted_alpha_ptr = broadcasted_alpha->mut_dptr<T>();\n    const int32_t elem_cnt = x->shape_view().elem_cnt();\n    const Shape& left_extended_shape =\n        CreateLeftExtendedShape(ShapeView(alpha->shape_view()), x->shape_view().NumAxes());\n    NdarrayUtil<DeviceType::kCPU, T>::BroadcastTo(\n        ctx->stream(), XpuVarNdarray<T>(x->shape_view(), broadcasted_alpha_ptr),\n        XpuVarNdarray<const T>(left_extended_shape, alpha->dptr<T>()));\n    FOR_RANGE(int32_t, i, 0, elem_cnt) {\n      y_ptr[i] = x_ptr[i] > 0 ? 
x_ptr[i] : x_ptr[i] * broadcasted_alpha_ptr[i];\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_TF_CPU_PRELU_KERNEL(dtype)                                            \\\n  REGISTER_USER_KERNEL(\"tf_prelu\")                                                     \\\n      .SetCreateFn<TfCpuPReluKernel<dtype>>()                                          \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                  \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                              \\\n        const Shape& in_shape = ctx->InputShape(\"x\", 0);                               \\\n        return GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(dtype));                \\\n      });\n\nREGISTER_TF_CPU_PRELU_KERNEL(float)\nREGISTER_TF_CPU_PRELU_KERNEL(double)\n\ntemplate<typename T>\nclass TfCpuPReluGradKernel final : public user_op::OpKernel {\n public:\n  TfCpuPReluGradKernel() = default;\n  ~TfCpuPReluGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex(\"alpha\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    user_op::Tensor* alpha_diff = ctx->Tensor4ArgNameAndIndex(\"alpha_diff\", 0);\n    const T* x_ptr = x->dptr<T>();\n    const T* dy_ptr = dy->dptr<T>();\n    T* dx_ptr = dx->mut_dptr<T>();\n    const int32_t elem_cnt = x->shape_view().elem_cnt();\n    T* broadcasted_alpha_ptr = tmp_buffer->mut_dptr<T>();\n    T* broadcasted_alpha_diff = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>()\n                                                     + GetCudaAlignedSize(elem_cnt * sizeof(T)));\n    T* reduce_sum_tmp_buf = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>()\n                                                 + 2 * GetCudaAlignedSize(elem_cnt * sizeof(T)));\n    const Shape& left_extended_shape =\n        CreateLeftExtendedShape(ShapeView(alpha->shape_view()), x->shape_view().NumAxes());\n    NdarrayUtil<DeviceType::kCPU, T>::BroadcastTo(\n        ctx->stream(), XpuVarNdarray<T>(x->shape_view(), broadcasted_alpha_ptr),\n        XpuVarNdarray<const T>(left_extended_shape, alpha->dptr<T>()));\n    FOR_RANGE(int32_t, i, 0, elem_cnt) {\n      dx_ptr[i] = x_ptr[i] > 0 ? dy_ptr[i] : dy_ptr[i] * broadcasted_alpha_ptr[i];\n      broadcasted_alpha_diff[i] = x_ptr[i] > 0 ? 
0 : dy_ptr[i] * x_ptr[i];\n    }\n    NdarrayUtil<DeviceType::kCPU, T>::ReduceSum(\n        ctx->stream(), XpuVarNdarray<T>(left_extended_shape, alpha_diff->mut_dptr<T>()),\n        XpuVarNdarray<const T>(x->shape_view(), broadcasted_alpha_diff),\n        XpuVarNdarray<T>(x->shape_view(), reduce_sum_tmp_buf));\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_TF_CPU_PRELU_GRAD_KERNEL(dtype)                                        \\\n  REGISTER_USER_KERNEL(\"tf_prelu_grad\")                                                 \\\n      .SetCreateFn<TfCpuPReluGradKernel<dtype>>()                                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                               \\\n        const Shape& in_shape = ctx->InputShape(\"x\", 0);                                \\\n        return 3 * GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(dtype));             \\\n      });\n\nREGISTER_TF_CPU_PRELU_GRAD_KERNEL(float)\nREGISTER_TF_CPU_PRELU_GRAD_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/tf_prelu_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__global__ void BroadcastPReluForwardGpu(const int32_t elem_cnt, const int32_t alpha_size,\n                                         const int32_t inner_size, const T* x, const T* alpha,\n                                         T* y) {\n  T zero_val = static_cast<T>(0.0);\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {\n    const T x_i = x[i];\n    const T alpha_i = alpha[(i / inner_size) % alpha_size];\n    y[i] = x_i > zero_val ? x_i : x_i * alpha_i;\n  }\n}\n\ntemplate<typename T>\n__global__ void BroadcastPReluBackwardGpu(const int32_t elem_cnt, const int32_t alpha_size,\n                                          const int32_t inner_size, const T* x, const T* alpha,\n                                          const T* dy, T* dx, T* alpha_diff) {\n  T zero_val = static_cast<T>(0.0);\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {\n    const T x_i = x[i];\n    const T dy_i = dy[i];\n    const T alpha_i = alpha[(i / inner_size) % alpha_size];\n    T dx_i = zero_val;\n    T alpha_diff_i = zero_val;\n    if (x_i > zero_val) {\n      dx_i = dy_i;\n      alpha_diff_i = zero_val;\n    } else {\n      dx_i = dy_i * alpha_i;\n      alpha_diff_i = dy_i * x_i;\n    }\n    dx[i] = dx_i;\n    alpha_diff[i] = alpha_diff_i;\n  }\n}\n\ntemplate<typename T>\n__global__ void ElemwisePReluForwardGpu(const int32_t elem_cnt, const T* x, const T* alpha, T* y) {\n  T zero_val = static_cast<T>(0.0);\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {\n    const T x_i = x[i];\n    const T alpha_i = alpha[i];\n    y[i] = x_i > zero_val ? 
x_i : x_i * alpha_i;\n  }\n}\n\ntemplate<typename T>\n__global__ void ElemwisePReluBackwardGpu(const int32_t elem_cnt, const T* x, const T* alpha,\n                                         const T* dy, T* dx, T* alpha_diff) {\n  T zero_val = static_cast<T>(0.0);\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {\n    const T x_i = x[i];\n    const T dy_i = dy[i];\n    const T alpha_i = alpha[i];\n    T dx_i = zero_val;\n    T alpha_diff_i = zero_val;\n    if (x_i > zero_val) {\n      dx_i = dy_i;\n      alpha_diff_i = zero_val;\n    } else {\n      dx_i = dy_i * alpha_i;\n      alpha_diff_i = dy_i * x_i;\n    }\n    dx[i] = dx_i;\n    alpha_diff[i] = alpha_diff_i;\n  }\n}\n\nbool IsAlphaShapeContiguous(const ShapeView& alpha_shape, const ShapeView& x_shape) {\n  if (alpha_shape.elem_cnt() == 1) { return true; }\n  int64_t begin_idx = -1;\n  for (int64_t i = 0; i < alpha_shape.NumAxes(); ++i) {\n    if (alpha_shape.At(i) != 1) {\n      begin_idx = i;\n      break;\n    }\n  }\n  CHECK_NE(begin_idx, -1);\n  int64_t end_idx = -1;\n  for (int64_t i = alpha_shape.NumAxes(); i > 0; --i) {\n    if (alpha_shape.At(i - 1) != 1) {\n      end_idx = i;\n      break;\n    }\n  }\n  CHECK_NE(end_idx, -1);\n  if (alpha_shape.elem_cnt() == x_shape.Count(begin_idx + 1, end_idx + 1)) {\n    return true;\n  } else {\n    return false;\n  }\n}\n\nint32_t GetOuterSize(const ShapeView& alpha_shape, const ShapeView& x_shape) {\n  int32_t outer_size = x_shape.At(0);\n  for (int32_t i = 0; i < alpha_shape.NumAxes(); ++i) {\n    if (alpha_shape.At(i) == 1) {\n      outer_size *= x_shape.At(i + 1);\n    } else {\n      break;\n    }\n  }\n  return outer_size;\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass TfGpuPReluKernel final : public user_op::OpKernel {\n public:\n  TfGpuPReluKernel() = default;\n  ~TfGpuPReluKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex(\"alpha\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const int32_t elem_cnt = x->shape_view().elem_cnt();\n    if (IsAlphaShapeContiguous(alpha->shape_view(), x->shape_view())) {\n      const int32_t outer_size = GetOuterSize(alpha->shape_view(), x->shape_view());\n      const int32_t alpha_size = alpha->shape_view().elem_cnt();\n      const int32_t inner_size = elem_cnt / outer_size / alpha_size;\n      BroadcastPReluForwardGpu<T><<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                                    ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          elem_cnt, alpha_size, inner_size, x->dptr<T>(), alpha->dptr<T>(), y->mut_dptr<T>());\n    } else {\n      user_op::Tensor* broadcasted_alpha = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n      const Shape& left_extended_shape =\n          CreateLeftExtendedShape(ShapeView(alpha->shape_view()), x->shape_view().NumAxes());\n      NdarrayUtil<DeviceType::kCUDA, T>::BroadcastTo(\n          ctx->stream(), XpuVarNdarray<T>(x->shape_view(), broadcasted_alpha->mut_dptr<T>()),\n          XpuVarNdarray<const T>(left_extended_shape, alpha->dptr<T>()));\n      ElemwisePReluForwardGpu<T><<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                                   ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          elem_cnt, x->dptr<T>(), broadcasted_alpha->dptr<T>(), y->mut_dptr<T>());\n    }\n  }\n 
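 // The tmp_buffer path above is taken only when alpha's non-unit dims cannot be folded into an\n  // (outer, alpha, inner) indexing of x; alpha is then first broadcast to x's full shape.\n 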
 bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_TF_CUDA_PRELU_KERNEL(dtype)                                           \\\n  REGISTER_USER_KERNEL(\"tf_prelu\")                                                     \\\n      .SetCreateFn<TfGpuPReluKernel<dtype>>()                                          \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                 \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                              \\\n        const Shape& in_shape = ctx->InputShape(\"x\", 0);                               \\\n        const Shape& alpha_shape = ctx->InputShape(\"alpha\", 0);                        \\\n        const int64_t tmp_buffer_size =                                                \\\n            IsAlphaShapeContiguous(alpha_shape, in_shape)                              \\\n                ? 0                                                                    \\\n                : GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(dtype));             \\\n        return tmp_buffer_size;                                                        \\\n      });\n\nREGISTER_TF_CUDA_PRELU_KERNEL(half)\nREGISTER_TF_CUDA_PRELU_KERNEL(float)\nREGISTER_TF_CUDA_PRELU_KERNEL(double)\n\ntemplate<typename T>\nclass TfGpuPReluGradKernel final : public user_op::OpKernel {\n public:\n  TfGpuPReluGradKernel() = default;\n  ~TfGpuPReluGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex(\"alpha\", 0);\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    user_op::Tensor* alpha_diff = ctx->Tensor4ArgNameAndIndex(\"alpha_diff\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const int32_t elem_cnt = x->shape_view().elem_cnt();\n    T* broadcasted_alpha_diff = tmp_buffer->mut_dptr<T>();\n    T* reduce_sum_tmp_buf = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>()\n                                                 + GetCudaAlignedSize(elem_cnt * sizeof(T)));\n    const Shape& left_extended_shape =\n        CreateLeftExtendedShape(ShapeView(alpha->shape_view()), x->shape_view().NumAxes());\n    if (IsAlphaShapeContiguous(alpha->shape_view(), x->shape_view())) {\n      const int32_t outer_size = GetOuterSize(alpha->shape_view(), x->shape_view());\n      const int32_t alpha_size = alpha->shape_view().elem_cnt();\n      const int32_t inner_size = elem_cnt / outer_size / alpha_size;\n      BroadcastPReluBackwardGpu<T><<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                                     ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          elem_cnt, alpha_size, inner_size, x->dptr<T>(), alpha->dptr<T>(), dy->dptr<T>(),\n          dx->mut_dptr<T>(), broadcasted_alpha_diff);\n    } else {\n      T* broadcasted_alpha = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>()\n                                                  + 2 * GetCudaAlignedSize(elem_cnt * sizeof(T)));\n\n      NdarrayUtil<DeviceType::kCUDA, T>::BroadcastTo(\n          ctx->stream(), XpuVarNdarray<T>(x->shape_view(), broadcasted_alpha),\n          
XpuVarNdarray<const T>(left_extended_shape, alpha->dptr<T>()));\n\n      ElemwisePReluBackwardGpu<T><<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                                    ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          elem_cnt, x->dptr<T>(), broadcasted_alpha, dy->dptr<T>(), dx->mut_dptr<T>(),\n          broadcasted_alpha_diff);\n    }\n    NdarrayUtil<DeviceType::kCUDA, T>::ReduceSum(\n        ctx->stream(), XpuVarNdarray<T>(left_extended_shape, alpha_diff->mut_dptr<T>()),\n        XpuVarNdarray<const T>(x->shape_view(), broadcasted_alpha_diff),\n        XpuVarNdarray<T>(x->shape_view(), reduce_sum_tmp_buf));\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_TF_CUDA_PRELU_GRAD_KERNEL(dtype)                                       \\\n  REGISTER_USER_KERNEL(\"tf_prelu_grad\")                                                 \\\n      .SetCreateFn<TfGpuPReluGradKernel<dtype>>()                                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                               \\\n        const Shape& in_shape = ctx->InputShape(\"x\", 0);                                \\\n        const Shape& alpha_shape = ctx->InputShape(\"alpha\", 0);                         \\\n        const int64_t tmp_buffer_size =                                                 \\\n            IsAlphaShapeContiguous(alpha_shape, in_shape)                               \\\n                ? 2 * GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(dtype))           \\\n                : 3 * GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(dtype));          \\\n        return tmp_buffer_size;                                                         \\\n      });\n\nREGISTER_TF_CUDA_PRELU_GRAD_KERNEL(half)\nREGISTER_TF_CUDA_PRELU_GRAD_KERNEL(float)\nREGISTER_TF_CUDA_PRELU_GRAD_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/throw_error_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass ThrowErrorKernel final : public user_op::OpKernel {\n public:\n  ThrowErrorKernel() = default;\n  ~ThrowErrorKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    THROW(RuntimeError) << \"throw error kernel\";\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"throw_error\").SetCreateFn<ThrowErrorKernel>();\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/to_contiguous_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/shape_vec.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/user/kernels/to_contiguous_kernel.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nstruct ToContiguousUtil<DeviceType::kCPU, T> : ToContiguousUtilBase {\n  using ToContiguousUtilBase::ToContiguousUtilBase;\n\n  static constexpr size_t dsize = sizeof(T);\n\n  void operator()() {\n    if (contiguous_dim == -1) {\n      // 0-dim tensor\n      std::memcpy(out_dptr, in_dptr, block_size * dsize);\n    } else {\n      // if input tensor's strides equals to output's, than just copy one memory-contiguous tensor\n      bool is_same = true;\n      for (int64_t i = contiguous_dim; i != -1; --i) {\n        if (out_stride[i] != in_stride[i]) {\n          is_same = false;\n          break;\n        }\n      }\n      if (is_same) {\n        std::memcpy(out_dptr + out_offset * dsize, in_dptr + in_offset * dsize,\n                    element_count * dsize);\n      } else {\n        const int64_t ndim = contiguous_dim + 1;\n        int64_t coordinates[ndim];\n        for (int64_t i = 0; i < element_count; i += block_size) {\n          memset(coordinates, 0, sizeof(int64_t) * ndim);\n          out_offset = i;\n          in_offset = 0;\n          // compute coords(output offset to coords)\n          int64_t remaining = out_offset;\n          for (int i = 0; i < ndim; ++i) {\n            const int64_t idx = remaining / out_stride[i];\n            coordinates[i] = idx;\n            remaining = remaining - idx * out_stride[i];\n          }\n          // compute input offset\n          for (int64_t dim = 0; dim < ndim; ++dim) {\n            in_offset += in_stride[dim] * coordinates[dim];\n          }\n\n          // copy block_size data to output\n          std::memcpy(out_dptr + out_offset * dsize, in_dptr + in_offset * dsize,\n                      block_size * dsize);\n        }\n      }\n    }\n  }\n};\n\nnamespace {\n\ntemplate<DeviceType device_type, typename T>\nclass ToContiguousKernel final : public user_op::OpKernel {\n public:\n  ToContiguousKernel() = default;\n  ~ToContiguousKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const ShapeView& in_shape = in->shape_view();\n    CHECK_EQ(out->shape_view(), in_shape);\n    const DataType in_data_type = in->data_type();\n    CHECK_EQ(out->data_type(), in_data_type);\n\n    std::vector<int64_t> in_stride(in->stride().begin(), in->stride().end());\n\n    const char* in_dptr = static_cast<const char*>(in->raw_dptr());\n    char* out_dptr = static_cast<char*>(out->mut_raw_dptr());\n  
  ToContiguousUtil<device_type, T>(ctx->stream(), in_shape, in_stride, in_dptr, out_dptr)();\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_TO_CONTIGUOUS_KERNEL(device_type, cpp_type, data_type) \\\n  REGISTER_USER_KERNEL(\"to_contiguous\")                                 \\\n      .SetCreateFn<ToContiguousKernel<device_type, cpp_type>>()         \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device_type)        \\\n                       && (user_op::HobDataType(\"in\", 0) == data_type));\n\n#define REGISTER_TO_CONTIGUOUS_CPU_KERNEL(cpp_type, data_type) \\\n  REGISTER_TO_CONTIGUOUS_KERNEL(DeviceType::kCPU, cpp_type, data_type)\n#define REGISTER_TO_CONTIGUOUS_CUDA_KERNEL(cpp_type, data_type) \\\n  REGISTER_TO_CONTIGUOUS_KERNEL(DeviceType::kCUDA, cpp_type, data_type)\n\n#define REGISTER_TO_CONTIGUOUS_KERNEL_FOR_CPU_TYPES \\\n  OF_PP_FOR_EACH_TUPLE(REGISTER_TO_CONTIGUOUS_CPU_KERNEL, TO_CONTIGUOUS_CPU_TYPES)\n\n#define REGISTER_TO_CONTIGUOUS_KERNEL_FOR_CUDA_TYPES       \\\n  OF_PP_FOR_EACH_TUPLE(REGISTER_TO_CONTIGUOUS_CUDA_KERNEL, \\\n                       TO_CONTIGUOUS_COMMON_TYPES TO_CONTIGUOUS_CUDA_SPECIAL_TYPE)\n\nREGISTER_TO_CONTIGUOUS_KERNEL_FOR_CPU_TYPES\n#ifdef WITH_CUDA\nREGISTER_TO_CONTIGUOUS_KERNEL_FOR_CUDA_TYPES\n#endif\n\n}  // namespace\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/to_contiguous_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <type_traits>\n#include \"oneflow/core/common/device_type.pb.h\"\n#include \"oneflow/user/kernels/to_contiguous_kernel.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr int32_t kThreadWorkSize = 4;\nconstexpr int32_t kNumThreads = 32 * 4;\nconstexpr int32_t get_min_threads_num() { return kNumThreads; }\nconstexpr int32_t get_block_work_size() { return kThreadWorkSize * kNumThreads; }\nconstexpr int32_t get_num_blocks(int64_t elem_cnt) {\n  return (elem_cnt + get_block_work_size() - 1) / get_block_work_size();\n}\n\nstruct StrideParam {\n  int32_t stride[SHAPE_MAX_AXIS_SIZE];\n\n  StrideParam(const int64_t* stride_vec, const size_t ndim) {\n    for (size_t i = 0; i < ndim; ++i) { stride[i] = stride_vec[i]; }\n  }\n};\n\ntemplate<typename IndexType, size_t ndim>\n__device__ __forceinline__ IndexType compute_index(IndexType out_offset,\n                                                   const StrideParam& out_params,\n                                                   const StrideParam& in_params) {\n  IndexType in_offset = 0;\n  IndexType remaining = out_offset;\n\n#pragma unroll\n  for (size_t i = 0; i < ndim; ++i) {\n    const IndexType idx = static_cast<IndexType>(remaining / out_params.stride[i]);\n    remaining -= idx * out_params.stride[i];\n    in_offset += idx * in_params.stride[i];\n  }\n  return in_offset;\n}\n\ntemplate<typename T, typename IndexType, size_t ndim>\n__global__ void ToContiguousForwardGpuParallel(IndexType count, const StrideParam in_stride,\n                                               const StrideParam out_stride, const T* in_dptr,\n                                               T* out_dptr, const int32_t num_block_threads,\n                                               const int32_t thread_work_size,\n                                               const int32_t block_work_size) {\n  IndexType remaining = count - block_work_size * blockIdx.x;\n  IndexType idx = blockIdx.x;\n  IndexType thread_idx = threadIdx.x;\n#pragma unroll\n  for (int32_t i = 0; i < thread_work_size; i++) {\n    if (thread_idx >= remaining) { return; }\n    IndexType out_idx = thread_idx + block_work_size * idx;\n    IndexType in_idx = compute_index<IndexType, ndim>(out_idx, out_stride, in_stride);\n    out_dptr[out_idx] = in_dptr[in_idx];\n    thread_idx += num_block_threads;\n  }\n}\n\ntemplate<typename T, typename IndexType>\nvoid LaunchToContiguousKernel(ep::Stream* stream, IndexType count, const size_t ndim,\n                              IndexType block_size, const std::vector<int64_t>& in_stride,\n                              const DimVector& out_stride, const char* in_dptr, char* out_dptr) {\n  const int32_t num_blocks = get_num_blocks(count);\n  constexpr int32_t num_threads = get_min_threads_num();\n  constexpr int32_t block_work_size = 
get_block_work_size();\n  StrideParam param_in_stride(in_stride.data(), ndim), param_out_stride(out_stride.data(), ndim);\n\n  switch (ndim) {\n#define TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(dim)                                             \\\n  case dim:                                                                                 \\\n    ToContiguousForwardGpuParallel<T, IndexType, dim>                                       \\\n        <<<num_blocks, num_threads, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(      \\\n            count, param_in_stride, param_out_stride, reinterpret_cast<const T*>(in_dptr),  \\\n            reinterpret_cast<T*>(out_dptr), num_threads, kThreadWorkSize, block_work_size); \\\n    break;\n\n    TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(1)\n    TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(2)\n    TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(3)\n    TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(4)\n    TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(5)\n    TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(6)\n    TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(7)\n    TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(8)\n    TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(9)\n    TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(10)\n    TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(11)\n    TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(12)\n    TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(13)\n    TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(14)\n    TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(15)\n    TO_CONTIGUOUS_FORWARD_GPU_PARALLEL(16)\n    default: break;\n#undef TO_CONTIGUOUS_FORWARD_GPU_PARALLEL\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nstruct ToContiguousUtil<DeviceType::kCUDA, T> : ToContiguousUtilBase {\n  using ToContiguousUtilBase::ToContiguousUtilBase;\n  static constexpr size_t dsize = sizeof(T);\n  void operator()() {\n    const size_t ndims = contiguous_dim + 1;\n    if (ndims == 0) {\n      // 0-dim tensor\n      OF_CUDA_CHECK(cudaMemcpyAsync(out_dptr, in_dptr, block_size * dsize, cudaMemcpyDeviceToDevice,\n                                    stream->As<ep::CudaStream>()->cuda_stream()));\n    } else {\n      bool is_same = true;\n      for (int64_t i = contiguous_dim; i != -1; --i) {\n        if (out_stride[i] != in_stride[i]) {\n          is_same = false;\n          break;\n        }\n      }\n      if (is_same) {\n        // if the input tensor's strides equal the output's, just copy one memory-contiguous block\n        OF_CUDA_CHECK(cudaMemcpyAsync(out_dptr, in_dptr, element_count * dsize,\n                                      cudaMemcpyDeviceToDevice,\n                                      stream->As<ep::CudaStream>()->cuda_stream()));\n      } else {\n        if (element_count < GetMaxVal<int32_t>()) {\n          LaunchToContiguousKernel<T, int32_t>(stream, element_count, ndims, block_size, in_stride,\n                                               out_stride, in_dptr, out_dptr);\n        } else {\n          LaunchToContiguousKernel<T, int64_t>(stream, element_count, ndims, block_size, in_stride,\n                                               out_stride, in_dptr, out_dptr);\n        }\n      }\n    }\n  }\n};\n\n#define INSTANTIATE_TO_CONTIGUOUS_UTILS_FOR_CUDA(cpp_type, data_type) \\\n  template struct ToContiguousUtil<DeviceType::kCUDA, cpp_type>;\nOF_PP_FOR_EACH_TUPLE(INSTANTIATE_TO_CONTIGUOUS_UTILS_FOR_CUDA,\n                     TO_CONTIGUOUS_COMMON_TYPES TO_CONTIGUOUS_CUDA_SPECIAL_TYPE)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/to_contiguous_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_TO_CONTIGUOUS_KERNEL_H_\n#define ONEFLOW_USER_KERNELS_TO_CONTIGUOUS_KERNEL_H_\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/shape_vec.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\nclass ToContiguousUtilParam {\n protected:\n  ToContiguousUtilParam(ep::Stream* stream, const ShapeView& in_shape,\n                        const std::vector<int64_t>& in_stride, const char* in_dptr, char* out_dptr)\n      : stream(stream),\n        in_shape(in_shape),\n        in_stride(in_stride),\n        in_dptr(in_dptr),\n        out_dptr(out_dptr) {}\n\n  ep::Stream* stream;\n  const ShapeView& in_shape;\n  const std::vector<int64_t>& in_stride;\n  const char* in_dptr;\n  char* out_dptr;\n};\n\nclass ToContiguousUtilBase : public ToContiguousUtilParam {\n public:\n  ToContiguousUtilBase(ep::Stream* stream, const ShapeView& in_shape,\n                       const std::vector<int64_t>& in_stride, const char* in_dptr, char* out_dptr)\n      : ToContiguousUtilParam(stream, in_shape, in_stride, in_dptr, out_dptr),\n        block_size(1),\n        contiguous_dim(in_shape.NumAxes() - 1),\n        out_stride(in_shape.NumAxes()),\n        in_offset(0),\n        out_offset(0),\n        element_count(1) {\n    for (int64_t i = contiguous_dim; i != -1; --i) {\n      out_stride[i] = element_count;\n      element_count *= in_shape.At(i);\n    }\n    for (int64_t i = contiguous_dim; i != -1; --i) {\n      if (block_size == in_stride[i]) {\n        block_size *= in_shape.At(i);\n      } else {\n        break;\n      }\n    }\n  }\n\n  int64_t block_size = 1;\n  int64_t contiguous_dim = 0;\n\n  DimVector out_stride;\n\n  int64_t in_offset = 0;\n  int64_t out_offset = 0;\n  int64_t element_count = 1;\n};\n\ntemplate<DeviceType, typename>\nstruct ToContiguousUtil : ToContiguousUtilBase {\n  using ToContiguousUtilBase::ToContiguousUtilBase;\n\n  void operator()();\n};\n\n}  // namespace oneflow\n\n#define TO_CONTIGUOUS_COMMON_TYPES                  \\\n  OF_PP_MAKE_TUPLE_SEQ(bool, DataType::kBool)       \\\n  OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar)       \\\n  OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8)     \\\n  OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8)   \\\n  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)   \\\n  OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32) \\\n  OF_PP_MAKE_TUPLE_SEQ(int64_t, DataType::kInt64)   \\\n  OF_PP_MAKE_TUPLE_SEQ(uint64_t, DataType::kUInt64) \\\n  OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)     \\\n  OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)\n\n#define TO_CONTIGUOUS_CPU_TYPES                                          \\\n  TO_CONTIGUOUS_COMMON_TYPES COMPLEX_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ( \\\n      float16, DataType::kFloat16) OF_PP_MAKE_TUPLE_SEQ(bfloat16, DataType::kBFloat16)\n\n#ifdef WITH_CUDA\n#if CUDA_VERSION 
>= 11000\n#define TO_CONTIGUOUS_CUDA_SPECIAL_TYPE                  \\\n  OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)         \\\n  OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16) \\\n  OF_PP_MAKE_TUPLE_SEQ(cuComplex, DataType::kComplex64)  \\\n  OF_PP_MAKE_TUPLE_SEQ(cuDoubleComplex, DataType::kComplex128)\n#else\n#define TO_CONTIGUOUS_CUDA_SPECIAL_TYPE                 \\\n  OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)        \\\n  OF_PP_MAKE_TUPLE_SEQ(cuComplex, DataType::kComplex64) \\\n  OF_PP_MAKE_TUPLE_SEQ(cuDoubleComplex, DataType::kComplex128)\n#endif  // CUDA_VERSION >= 11000\n#endif  // WITH_CUDA\n#endif  // ONEFLOW_USER_KERNELS_TO_CONTIGUOUS_KERNEL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/top_k_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/thread/thread_manager.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nvoid ComputeTopOne(const T* in_ptr, const Range& range, int64_t instance_size, int64_t* out_ptr) {\n  FOR_RANGE(int64_t, i, range.begin(), range.end()) {\n    const T* in_ptr_i = in_ptr + i * instance_size;\n    out_ptr[i] = std::distance(in_ptr_i, std::max_element(in_ptr_i, in_ptr_i + instance_size));\n  }\n}\n\ntemplate<typename T>\nvoid ComputeTopK(const T* in_ptr, int64_t* indices_ptr, const Range& range, int64_t instance_size,\n                 int64_t k, bool sorted, int64_t* out_ptr) {\n  FOR_RANGE(int64_t, i, range.begin(), range.end()) {\n    const int64_t offset = i * instance_size;\n    const T* in_ptr_i = in_ptr + offset;\n    int64_t* indices_ptr_i = indices_ptr + offset;\n    std::iota(indices_ptr_i, indices_ptr_i + instance_size, 0);\n    auto comp = [&](const int64_t lhs, const int64_t rhs) {\n      const T l = in_ptr_i[lhs];\n      const T r = in_ptr_i[rhs];\n      if (l == r) {\n        return lhs < rhs;\n      } else {\n        return l > r;\n      }\n    };\n    std::nth_element(indices_ptr_i, indices_ptr_i + k, indices_ptr_i + instance_size, comp);\n    if (sorted) { std::sort(indices_ptr_i, indices_ptr_i + k, comp); }\n    std::copy(indices_ptr_i, indices_ptr_i + k, out_ptr + i * k);\n  }\n}\n\ntemplate<typename T>\nvoid CpuTopK(ep::Stream* /*stream*/, const T* in_ptr, int64_t* indices_ptr, int64_t instance_num,\n             int64_t instance_size, int64_t k, bool sorted, int64_t* out_ptr) {\n  const int64_t num_thread =\n      std::min(instance_num, static_cast<int64_t>(Singleton<ThreadPool>::Get()->thread_num()));\n  const BalancedSplitter bs(instance_num, num_thread);\n  BlockingCounter bc(num_thread);\n  FOR_RANGE(int64_t, thread_id, 0, num_thread) {\n    const Range range = bs.At(thread_id);\n    Singleton<ThreadPool>::Get()->AddWork([=, &bc]() {\n      if (k == 1) {\n        ComputeTopOne(in_ptr, range, instance_size, out_ptr);\n      } else {\n        ComputeTopK(in_ptr, indices_ptr, range, instance_size, k, sorted, out_ptr);\n      }\n      bc.Decrease();\n    });\n  }\n  bc.WaitForeverUntilCntEqualZero();\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass TopKCpuKernel final : public user_op::OpKernel {\n public:\n  TopKCpuKernel() = default;\n  ~TopKCpuKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    if (in->shape_view().elem_cnt() == 0) { return; }\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n\n    const int64_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1);\n    const int64_t 
instance_num = in->shape_view().elem_cnt() / instance_size;\n    const int64_t k = std::min(static_cast<int64_t>(ctx->Attr<int32_t>(\"k\")), instance_size);\n    int64_t* indices_ptr = tmp_buffer ? tmp_buffer->mut_dptr<int64_t>() : nullptr;\n    CpuTopK(ctx->stream(), in->dptr<T>(), indices_ptr, instance_num, instance_size, k,\n            ctx->Attr<bool>(\"sorted\"), out->mut_dptr<int64_t>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_TOP_K_KERNEL(dtype)                                                \\\n  REGISTER_USER_KERNEL(\"top_k\")                                                         \\\n      .SetCreateFn<TopKCpuKernel<dtype>>()                                              \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                               \\\n        const Shape& in_shape = ctx->InputShape(\"in\", 0);                               \\\n        return ctx->Attr<int32_t>(\"k\") > 1 ? in_shape.elem_cnt() * sizeof(int64_t) : 0; \\\n      });\n\nREGISTER_CPU_TOP_K_KERNEL(float)\nREGISTER_CPU_TOP_K_KERNEL(double)\nREGISTER_CPU_TOP_K_KERNEL(int8_t)\nREGISTER_CPU_TOP_K_KERNEL(uint8_t)\nREGISTER_CPU_TOP_K_KERNEL(int32_t)\nREGISTER_CPU_TOP_K_KERNEL(int64_t)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/top_k_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/radix_sort.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nT PowOf2Floor(T val, int64_t max_power) {\n  CHECK_GT(val, GetZeroVal<T>());\n  T max_floor = static_cast<T>(std::pow(2, max_power));\n  val = std::min(val, max_floor);\n  T ret = GetOneVal<T>();\n  while (true) {\n    ret *= 2;\n    if (ret >= val) { return ret == val ? ret : ret / 2; }\n  }\n}\n\ntemplate<typename T>\nT PowOf2Ceil(T val, int64_t max_power) {\n  CHECK_GT(val, GetZeroVal<T>());\n  T max_ceil = static_cast<T>(std::pow(2, max_power));\n  val = std::min(val, max_ceil);\n  T ret = GetOneVal<T>();\n  while (true) {\n    ret *= 2;\n    if (ret >= val) { return ret; }\n  }\n}\n\ntemplate<typename T, typename Compare>\n__device__ void BitonicSwap(T* data, const int64_t i, const int64_t j, const bool dir,\n                            const Compare& comp) {\n  if (comp(data[i], data[j]) == dir) {\n    T tmp = data[i];\n    data[i] = data[j];\n    data[j] = tmp;\n  }\n}\n\n// https://en.wikipedia.org/wiki/Bitonic_sorter\ntemplate<typename T, typename Compare>\n__device__ void BitonicSort(T* data, const int64_t elem_cnt, const Compare& comp) {\n  // The element count of instance should be pow-of-2\n  assert(elem_cnt > 0 && !(elem_cnt & (elem_cnt - 1)));\n\n  // Generate a bitonic sequence from input\n  for (int64_t size = 2; size <= elem_cnt / 2; size *= 2) {\n    // Merge 2 bitonic sequences of length 'size' into a bitonic sequence of length '2 * size'\n    for (int64_t stride = size / 2; stride > 0; stride /= 2) {\n      for (int64_t swap_id = threadIdx.x; swap_id < elem_cnt / 2; swap_id += blockDim.x) {\n        // Change dir at intervals of 'size / 2' swaps\n        const bool dir = swap_id & (size / 2);\n        // Locate the pair {pos, pos + stride} which is going te be swaped if needed\n        const int pos = 2 * swap_id - (swap_id & (stride - 1));\n\n        BitonicSwap(data, pos, pos + stride, dir, comp);\n\n        __syncthreads();\n      }\n    }\n  }\n\n  // Sort the bitonic sequence\n  for (int64_t stride = elem_cnt / 2; stride > 0; stride /= 2) {\n    for (int64_t swap_id = threadIdx.x; swap_id < elem_cnt / 2; swap_id += blockDim.x) {\n      // Locate the pair {pos, pos + stride} which is going te be swaped if needed\n      const int pos = 2 * swap_id - (swap_id & (stride - 1));\n\n      BitonicSwap(data, pos, pos + stride, false, comp);\n\n      __syncthreads();\n    }\n  }\n}\n\ntemplate<typename T>\nclass Entry final {\n public:\n  __device__ __forceinline__ Entry(int64_t index, T value) : index_(index), value_(value) {}\n\n  __device__ __forceinline__ int64_t GetIndex() const { return index_; }\n  __device__ __forceinline__ T GetValue() const { return value_; }\n  __device__ __forceinline__ void SetIndex(int64_t index) { index_ = index; }\n  
__device__ __forceinline__ void SetValue(T value) { value_ = value; }\n\n  __device__ __forceinline__ bool operator<(const Entry& entry) const {\n    return (value_ < entry.GetValue()) || (value_ == entry.GetValue() && index_ > entry.GetIndex());\n  }\n  __device__ __forceinline__ bool operator>(const Entry& entry) const {\n    return (value_ > entry.GetValue()) || (value_ == entry.GetValue() && index_ < entry.GetIndex());\n  }\n\n private:\n  int64_t index_;\n  T value_;\n};\n\ntemplate<typename T>\nclass MinHeap final {\n public:\n  __device__ __forceinline__ MinHeap(Entry<T>* data, const int64_t heap_size,\n                                     const int64_t init_index, const T init_value)\n      : data_(data), heap_size_(heap_size) {\n    for (int64_t i = 0; i < heap_size; ++i) {\n      data_[i].SetIndex(init_index);\n      data_[i].SetValue(init_value);\n    }\n  }\n  __device__ __forceinline__ Entry<T>& Top() { return data_[0]; }\n  __device__ __forceinline__ void Swap(const int64_t i, const int64_t j) {\n    auto tmp = data_[j];\n    data_[j] = data_[i];\n    data_[i] = tmp;\n  }\n  __device__ __forceinline__ void MinHeapify(int64_t index) {\n    while (true) {\n      const int64_t left = 2 * index + 1;\n      const int64_t right = 2 * index + 2;\n      int64_t min = index;\n      if (left < heap_size_ && data_[left] < data_[min]) { min = left; }\n      if (right < heap_size_ && data_[right] < data_[min]) { min = right; }\n      if (min == index) { return; }\n      Swap(min, index);\n      index = min;\n    }\n  }\n\n private:\n  Entry<T>* data_;\n  int64_t heap_size_;\n};\n\ntemplate<typename T>\n__global__ void HeapTopKKernel(const T* in_ptr, const int64_t instance_num,\n                               const int64_t instance_size, const int64_t k,\n                               const int64_t heap_size, const int64_t init_index,\n                               const T init_value, int64_t* out_ptr) {\n  extern __shared__ char smem[];\n  auto* shared_entries = reinterpret_cast<Entry<T>*>(smem);\n\n  // Divide the elements to be sorted into disjoint sets (# of sets == # of heaps).\n  // Each thread in the thread block manipulates one heap to select the top heap_size entries\n  // from its corresponding set\n  const T* input = in_ptr + blockIdx.x * instance_size;\n  auto heap =\n      MinHeap<T>(shared_entries + threadIdx.x * heap_size, heap_size, init_index, init_value);\n  for (int64_t i = threadIdx.x; i < instance_size; i += blockDim.x) {\n    auto entry = Entry<T>(i, input[i]);\n    if (entry > heap.Top()) {\n      heap.Top() = entry;\n      heap.MinHeapify(0);\n    }\n  }\n\n  __syncthreads();\n\n  // Merge all heaps into a unified, sorted array\n  BitonicSort(shared_entries, blockDim.x * heap_size,\n              [](const Entry<T>& x, const Entry<T>& y) { return x > y; });\n\n  // Write the top k elements of the sorted array to the output\n  for (int64_t i = threadIdx.x; i < k; i += blockDim.x) {\n    (out_ptr + blockIdx.x * k)[i] = shared_entries[i].GetIndex();\n  }\n}\n\n// Partitions one temporary buffer into sorted_in, indices and sorted_indices regions plus CUB\n// temp storage, each region with a CUDA-aligned size\ntemplate<typename T>\nclass TmpBufferManager final {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TmpBufferManager);\n  TmpBufferManager(int64_t capacity, void* ptr, const ShapeView& in_shape)\n      : capacity_{capacity},\n        sorted_in_elem_cnt_{in_shape.elem_cnt()},\n        indices_elem_cnt_{sorted_in_elem_cnt_},\n        sorted_indices_elem_cnt_{sorted_in_elem_cnt_} {\n    const int64_t sorted_in_aligned_bytes = GetCudaAlignedSize(sorted_in_elem_cnt_ * sizeof(T));\n    const int64_t indices_aligned_bytes = 
GetCudaAlignedSize(indices_elem_cnt_ * sizeof(int64_t));\n    const int64_t sorted_indices_aligned_bytes = indices_aligned_bytes;\n    sorted_in_ptr_ = reinterpret_cast<T*>(ptr);\n    indices_ptr_ = reinterpret_cast<int64_t*>(reinterpret_cast<char*>(sorted_in_ptr_)\n                                              + sorted_in_aligned_bytes);\n    sorted_indices_ptr_ =\n        reinterpret_cast<int64_t*>(reinterpret_cast<char*>(indices_ptr_) + indices_aligned_bytes);\n    temp_storage_ptr_ = reinterpret_cast<void*>(reinterpret_cast<char*>(sorted_indices_ptr_)\n                                                + sorted_indices_aligned_bytes);\n    temp_storage_bytes_ =\n        capacity_ - sorted_in_aligned_bytes - indices_aligned_bytes - sorted_indices_aligned_bytes;\n    CHECK_GE(temp_storage_bytes_, 0);\n  }\n  ~TmpBufferManager() = default;\n\n  T* SortedInPtr() const { return sorted_in_ptr_; }\n  int64_t* IndicesPtr() const { return indices_ptr_; }\n  int64_t* SortedIndicesPtr() const { return sorted_indices_ptr_; }\n  void* TempStoragePtr() const { return temp_storage_ptr_; }\n\n  int64_t TempStorageBytes() const { return temp_storage_bytes_; }\n\n private:\n  int64_t capacity_;\n\n  T* sorted_in_ptr_;\n  int64_t* indices_ptr_;\n  int64_t* sorted_indices_ptr_;\n  void* temp_storage_ptr_;\n\n  int64_t sorted_in_elem_cnt_;\n  int64_t indices_elem_cnt_;\n  int64_t sorted_indices_elem_cnt_;\n  int64_t temp_storage_bytes_;\n};\n\n__global__ void InitializeIndices(int64_t elem_cnt, int64_t* indices_ptr, int64_t instance_size) {\n  CUDA_1D_KERNEL_LOOP(i, elem_cnt) { indices_ptr[i] = i % instance_size; }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuTopKKernel final : public user_op::OpKernel {\n public:\n  GpuTopKKernel() = default;\n  ~GpuTopKKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    if (in->shape_view().elem_cnt() == 0) { return; }\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n\n    const int64_t elem_cnt = in->shape_view().elem_cnt();\n    const int64_t instance_size = in->shape_view().At(in->shape_view().NumAxes() - 1);\n    const int64_t instance_num = elem_cnt / instance_size;\n    const int64_t k = std::min(static_cast<int64_t>(ctx->Attr<int32_t>(\"k\")), instance_size);\n\n    // For large k (or a single instance), radix-sort whole instances and slice out the top k;\n    // otherwise use the per-block heap + bitonic-sort kernel\n    if (k > 30 || instance_num == 1) {\n      user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n      TmpBufferManager<T> buf_manager(static_cast<int64_t>(tmp_buffer->shape_view().elem_cnt()),\n                                      tmp_buffer->mut_dptr<void>(), in->shape_view());\n\n      InitializeIndices<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                          ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          elem_cnt, buf_manager.IndicesPtr(), instance_size);\n      SortPairsDescending(in->dptr<T>(), buf_manager.IndicesPtr(), instance_num, instance_size,\n                          buf_manager.TempStoragePtr(), buf_manager.TempStorageBytes(),\n                          buf_manager.SortedInPtr(), buf_manager.SortedIndicesPtr(),\n                          ctx->stream()->As<ep::CudaStream>()->cuda_stream());\n      // Copy the first k sorted indices of each instance to the output\n      // (dst pitch: k, src pitch: instance_size, in elements)\n      OF_CUDA_CHECK(cudaMemcpy2DAsync(\n          out->mut_dptr<int64_t>(), k * sizeof(int64_t), buf_manager.SortedIndicesPtr(),\n          instance_size * sizeof(int64_t), k * sizeof(int64_t), instance_num, cudaMemcpyDefault,\n        
  ctx->stream()->As<ep::CudaStream>()->cuda_stream()));\n    } else {\n      // Use as many heaps as possible (# of heaps == # of threads used in thread block).\n      // Limitation 1: size of shared memory\n      // We also need heap_size * num_heap to be a power of 2, which the bitonic sort requires\n      const int64_t heap_size = PowOf2Ceil(k, 16);\n      int32_t num_heap =\n          PowOf2Floor(kCudaMaxSharedMemoryByteSize / (heap_size * sizeof(Entry<T>)), 16);\n      // Limitation 2: # of threads in thread block\n      num_heap = std::min(num_heap, kCudaThreadsNumPerBlock);\n\n      HeapTopKKernel<T><<<instance_num, num_heap, num_heap * heap_size * sizeof(Entry<T>),\n                          ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          in->dptr<T>(), instance_num, instance_size, k, heap_size, GetMaxVal<int64_t>(),\n          GetMinVal<T>(), out->mut_dptr<int64_t>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_TOP_K_KERNEL(dtype)                                                        \\\n  REGISTER_USER_KERNEL(\"top_k\")                                                                  \\\n      .SetCreateFn<GpuTopKKernel<dtype>>()                                                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                           \\\n                       && (user_op::HobDataType(\"in\", 0) == GetDataType<dtype>::value))          \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                        \\\n        const Shape& in_shape = ctx->InputShape(\"in\", 0);                                        \\\n        const int64_t elem_cnt = in_shape.elem_cnt();                                            \\\n        const int64_t instance_size = in_shape.dim_vec().back();                                 \\\n        const int64_t instance_num = elem_cnt / instance_size;                                   \\\n                                                                                                 \\\n        /* Sorted In */                                                                          \\\n        const int64_t sorted_in_aligned_bytes = GetCudaAlignedSize(elem_cnt * sizeof(dtype));    \\\n        /* Indices */                                                                            \\\n        const int64_t indices_aligned_bytes = GetCudaAlignedSize(elem_cnt * sizeof(int64_t));    \\\n        /* Sorted Indices */                                                                     \\\n        const int64_t sorted_indices_aligned_bytes = indices_aligned_bytes;                      \\\n        /* CUB Temp Storage */                                                                   \\\n        int64_t temp_storage_bytes =                                                             \\\n            InferTempStorageForSortPairsDescending<dtype, int64_t>(instance_num, instance_size); \\\n                                                                                                 \\\n        return sorted_in_aligned_bytes + indices_aligned_bytes + sorted_indices_aligned_bytes    \\\n               + temp_storage_bytes;                                                             \\\n      
});\n\nREGISTER_CUDA_TOP_K_KERNEL(float)\nREGISTER_CUDA_TOP_K_KERNEL(double)\nREGISTER_CUDA_TOP_K_KERNEL(uint8_t)\nREGISTER_CUDA_TOP_K_KERNEL(int8_t)\nREGISTER_CUDA_TOP_K_KERNEL(int32_t)\nREGISTER_CUDA_TOP_K_KERNEL(int64_t)\nREGISTER_CUDA_TOP_K_KERNEL(half)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/transpose_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/permute.h\"\n#include \"oneflow/core/ep/common/primitive/permute.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\nbool IsIdentity(const ShapeView& in_shape, const std::vector<int32_t>& perm) {\n  constexpr int kMaxNumDims = 12;\n  CHECK_LE(in_shape.NumAxes(), kMaxNumDims);\n  CHECK_EQ(in_shape.NumAxes(), perm.size());\n\n  size_t simplified_num_dims{};\n  int64_t simplified_src_dims[kMaxNumDims]{};\n  int simplified_permutation[kMaxNumDims]{};\n  ep::primitive::permute::SimplifyPermutation<kMaxNumDims>(\n      in_shape.NumAxes(), in_shape.ptr(), perm.data(), &simplified_num_dims, simplified_src_dims,\n      simplified_permutation);\n  for (int i = 0; i < simplified_num_dims; ++i) {\n    if (simplified_permutation[i] != i) { return false; }\n  }\n  return true;\n}\n\n}  // namespace\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Permute> NewPermutePrimitive(Context* ctx) {\n  const int64_t num_dims = ctx->TensorDesc4ArgNameAndIndex(\"output\", 0)->shape().NumAxes();\n  return ep::primitive::NewPrimitive<ep::primitive::PermuteFactory>(ctx->device_type(), num_dims);\n}\n\nclass TransposeKernel final : public OpKernel, public user_op::CudaGraphSupport {\n public:\n  OF_DISALLOW_COPY_AND_MOVE(TransposeKernel);\n  TransposeKernel() = default;\n  ~TransposeKernel() override = default;\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    auto primitive = NewPermutePrimitive(ctx);\n    CHECK(primitive);\n\n    const Tensor* tensor_in = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    Tensor* tensor_out = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    const auto& perm = ctx->Attr<std::vector<int32_t>>(\"perm\");\n    const ShapeView& in_shape = tensor_in->shape_view();\n    DataType dtype = tensor_out->data_type();\n    size_t num_dims = tensor_in->shape_view().NumAxes();\n    const int64_t* src_dims = in_shape.ptr();\n\n    int64_t elem_cnt = tensor_out->shape_view().elem_cnt();\n\n    if (elem_cnt != 0) {\n      if (IsIdentity(in_shape, perm)) {\n        // if permute vector is 0,1,...,n, do data copy directly\n        AutoMemcpy(ctx->stream(), tensor_out->mut_dptr(), tensor_in->dptr(),\n                   elem_cnt * GetSizeOfDataType(dtype), tensor_out->mem_case(),\n                   tensor_in->mem_case());\n      } else {\n        primitive->Launch(ctx->stream(), dtype, num_dims, src_dims, tensor_in->dptr(), perm.data(),\n                          tensor_out->mut_dptr());\n      }\n\n    } else {\n      // For 0-d Tensor\n      return;\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nauto PermutePrimitiveExists() {\n  return hob::make_custom(\"PermutePrimitiveExists\", [](const 
user_op::KernelRegContext& ctx) {\n    return NewPermutePrimitive(&ctx).operator bool();\n  });\n}\n\nREGISTER_USER_KERNEL(\"transpose\")\n    .SetCreateFn<TransposeKernel>()\n    .SetIsMatchedHob(PermutePrimitiveExists() == true)\n    .SetInplaceProposalFn([](const user_op::InferContext& ctx,\n                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {\n      const ShapeView input_shape(ctx.InputShape(\"input\", 0));\n      const auto& perm = ctx.Attr<std::vector<int32_t>>(\"perm\");\n      if (IsIdentity(input_shape, perm)) {\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"output\", 0, \"input\", 0, false));\n      }\n      return Maybe<void>::Ok();\n    });\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/tril_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cstdint>\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass CpuTrilKernel final : public user_op::OpKernel {\n public:\n  CpuTrilKernel() = default;\n  ~CpuTrilKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const auto shape = x->shape_view();\n    const auto diagonal = ctx->Attr<int64_t>(\"diagonal\");\n    const int64_t num_rows = shape.At(shape.NumAxes() - 2);\n    const int64_t num_cols = shape.At(shape.NumAxes() - 1);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    T* y_dptr = y->mut_dptr<T>();\n    const T* x_dptr = x->dptr<T>();\n    const T fill = ctx->Attr<bool>(\"is_floating_fill_value\")\n                       ? static_cast<T>(ctx->Attr<double>(\"floating_fill_value\"))\n                       : static_cast<T>(ctx->Attr<int64_t>(\"integer_fill_value\"));\n    int64_t matrix_size = num_rows * num_cols;\n    for (int64_t k = 0; k < shape.elem_cnt(); ++k) {\n      int64_t offset_in_matrix = k % matrix_size;\n      int64_t i = offset_in_matrix / num_cols;\n      int64_t j = offset_in_matrix - num_cols * i;\n      y_dptr[k] = j > i + diagonal ? fill : x_dptr[k];\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_TRIL_KERNEL(dtype)                                             \\\n  REGISTER_USER_KERNEL(\"tril\").SetCreateFn<CpuTrilKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCPU)                                \\\n      && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CPU_TRIL_KERNEL(float)\nREGISTER_CPU_TRIL_KERNEL(double)\nREGISTER_CPU_TRIL_KERNEL(bool)\nREGISTER_CPU_TRIL_KERNEL(uint8_t)\nREGISTER_CPU_TRIL_KERNEL(int8_t)\nREGISTER_CPU_TRIL_KERNEL(int32_t)\nREGISTER_CPU_TRIL_KERNEL(int64_t)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/tril_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/util/cuda_half_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__global__ void TrilGpu(const int64_t elem_cnt, const int64_t num_rows, const int64_t num_cols,\n                        const int64_t diagonal, const T* x, const T fill, T* y) {\n  const int64_t matrix_size = num_rows * num_cols;\n  CUDA_1D_KERNEL_LOOP_T(int64_t, k, elem_cnt) {\n    const int64_t offset_in_matrix = k % matrix_size;\n    const int64_t i = offset_in_matrix / num_cols;\n    const int64_t j = offset_in_matrix - num_cols * i;\n    y[k] = j > i + diagonal ? fill : x[k];\n  }\n}\n\ntemplate<typename T>\n__global__ void TrilWarpProcessRowGpu(const int64_t total_rows, const int64_t num_rows,\n                                      const int64_t num_cols, const int64_t diagonal, const T* x,\n                                      const T fill, T* y) {\n  const int64_t warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kCudaWarpSize;\n  const int64_t lan_id = threadIdx.x % kCudaWarpSize;\n  const int64_t num_warp = blockDim.x * gridDim.x / kCudaWarpSize;\n  for (int64_t i = warp_id; i < total_rows; i += num_warp) {\n    const int64_t row = i % num_rows;\n    for (int64_t col = lan_id; col < num_cols; col += kCudaWarpSize) {\n      const int64_t idx = i * num_cols + col;\n      y[idx] = col > row + diagonal ? fill : x[idx];\n    }\n  }\n}\n\ntemplate<>\n__global__ void TrilWarpProcessRowGpu<half>(const int64_t total_rows, const int64_t num_rows,\n                                            const int64_t num_cols, const int64_t diagonal,\n                                            const half* x, const half fill, half* y) {\n  const int64_t h2_num_cols = num_cols / 2;\n  const auto* x_h2 = reinterpret_cast<const half2*>(x);\n  auto* y_h2 = reinterpret_cast<half2*>(y);\n\n  const int64_t warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kCudaWarpSize;\n  const int64_t lan_id = threadIdx.x % kCudaWarpSize;\n  const int64_t num_warp = blockDim.x * gridDim.x / kCudaWarpSize;\n  for (int64_t i = warp_id; i < total_rows; i += num_warp) {\n    const int64_t row = i % num_rows;\n    for (int64_t col = lan_id; col < h2_num_cols; col += kCudaWarpSize) {\n      const int64_t idx = i * h2_num_cols + col;\n      const half2 x_val = x_h2[idx];\n      half2 y_val;\n      y_val.x = (2 * col) > row + diagonal ? fill : x_val.x;\n      y_val.y = (2 * col + 1) > row + diagonal ? 
fill : x_val.y;\n      y_h2[idx] = y_val;\n    }\n  }\n}\n\ntemplate<typename T>\n__global__ void FusedScaleTrilGpu(const int64_t elem_cnt, const int64_t num_rows,\n                                  const int64_t num_cols, const int64_t diagonal, const T scale,\n                                  const T* x, const T fill, T* y) {\n  const int64_t matrix_size = num_rows * num_cols;\n  CUDA_1D_KERNEL_LOOP_T(int64_t, k, elem_cnt) {\n    const int64_t offset_in_matrix = k % matrix_size;\n    const int64_t i = offset_in_matrix / num_cols;\n    const int64_t j = offset_in_matrix - num_cols * i;\n    y[k] = j > i + diagonal ? fill : (scale * x[k]);\n  }\n}\n\ntemplate<typename T>\n__global__ void FusedScaleTrilWarpProcessRowGpu(const int64_t total_rows, const int64_t num_rows,\n                                                const int64_t num_cols, const int64_t diagonal,\n                                                const T scale, const T* x, const T fill, T* y) {\n  const int64_t warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kCudaWarpSize;\n  const int64_t lan_id = threadIdx.x % kCudaWarpSize;\n  const int64_t num_warp = blockDim.x * gridDim.x / kCudaWarpSize;\n  for (int64_t i = warp_id; i < total_rows; i += num_warp) {\n    const int64_t row = i % num_rows;\n    for (int64_t col = lan_id; col < num_cols; col += kCudaWarpSize) {\n      const int64_t idx = i * num_cols + col;\n      y[idx] = col > row + diagonal ? fill : (scale * x[idx]);\n    }\n  }\n}\n\ntemplate<>\n__global__ void FusedScaleTrilWarpProcessRowGpu<half>(const int64_t total_rows,\n                                                      const int64_t num_rows,\n                                                      const int64_t num_cols,\n                                                      const int64_t diagonal, const half scale,\n                                                      const half* x, const half fill, half* y) {\n  const int64_t h2_num_cols = num_cols / 2;\n  const auto* x_h2 = reinterpret_cast<const half2*>(x);\n  auto* y_h2 = reinterpret_cast<half2*>(y);\n  const half2 h2_scale = __half2half2(scale);\n  const int64_t warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kCudaWarpSize;\n  const int64_t lan_id = threadIdx.x % kCudaWarpSize;\n  const int64_t num_warp = blockDim.x * gridDim.x / kCudaWarpSize;\n  for (int64_t i = warp_id; i < total_rows; i += num_warp) {\n    const int64_t row = i % num_rows;\n    for (int64_t col = lan_id; col < h2_num_cols; col += kCudaWarpSize) {\n      const int64_t idx = i * h2_num_cols + col;\n      const half2 scaled_x = __hmul2(h2_scale, x_h2[idx]);\n      half2 y_val;\n      y_val.x = (2 * col) > row + diagonal ? fill : scaled_x.x;\n      y_val.y = (2 * col + 1) > row + diagonal ? fill : scaled_x.y;\n      y_h2[idx] = y_val;\n    }\n  }\n}\n\ntemplate<typename T>\nT GetAttrVal(bool is_floating_val, double floating_value, int64_t integer_value) {\n  return is_floating_val ? static_cast<T>(floating_value) : static_cast<T>(integer_value);\n}\n\ntemplate<>\nhalf GetAttrVal<half>(bool is_floating_val, double floating_value, int64_t integer_value) {\n  return is_floating_val ? 
__float2half(floating_value) : __float2half(integer_value);\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuTrilKernel final : public user_op::OpKernel {\n public:\n  GpuTrilKernel() = default;\n  ~GpuTrilKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const auto shape = x->shape_view();\n    const auto diagonal = ctx->Attr<int64_t>(\"diagonal\");\n    const int64_t num_rows = shape.At(shape.NumAxes() - 2);\n    const int64_t num_cols = shape.At(shape.NumAxes() - 1);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int32_t elem_cnt = shape.elem_cnt();\n    const T fill = GetAttrVal<T>(ctx->Attr<bool>(\"is_floating_fill_value\"),\n                                 ctx->Attr<double>(\"floating_fill_value\"),\n                                 ctx->Attr<int64_t>(\"integer_fill_value\"));\n    if (num_cols % (kCudaWarpSize * 2) == 0) {\n      const int64_t total_rows = elem_cnt / num_cols;\n      TrilWarpProcessRowGpu<<<BlocksNum4ThreadsNum(total_rows * kCudaWarpSize),\n                              kCudaThreadsNumPerBlock, 0,\n                              ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          total_rows, num_rows, num_cols, diagonal, x->dptr<T>(), fill, y->mut_dptr<T>());\n    } else {\n      TrilGpu<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          elem_cnt, num_rows, num_cols, diagonal, x->dptr<T>(), fill, y->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_TRIL_KERNEL(dtype)                                                        \\\n  REGISTER_USER_KERNEL(\"tril\")                                                                  \\\n      .SetCreateFn<GpuTrilKernel<dtype>>()                                                      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                          \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value))        \\\n      .SetInplaceProposalFn([](const user_op::InferContext&,                                    \\\n                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \\\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"in\", 0, true));                       \\\n        return Maybe<void>::Ok();                                                               \\\n      });\n\nREGISTER_CUDA_TRIL_KERNEL(float)\nREGISTER_CUDA_TRIL_KERNEL(double)\nREGISTER_CUDA_TRIL_KERNEL(bool)\nREGISTER_CUDA_TRIL_KERNEL(uint8_t)\nREGISTER_CUDA_TRIL_KERNEL(int8_t)\nREGISTER_CUDA_TRIL_KERNEL(int32_t)\nREGISTER_CUDA_TRIL_KERNEL(int64_t)\nREGISTER_CUDA_TRIL_KERNEL(half)\n\ntemplate<typename T>\nclass GpuFusedScaleTrilKernel final : public user_op::OpKernel {\n public:\n  GpuFusedScaleTrilKernel() = default;\n  ~GpuFusedScaleTrilKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const auto shape = x->shape_view();\n    const auto diagonal = ctx->Attr<int64_t>(\"diagonal\");\n    const int32_t num_rows = shape.At(shape.NumAxes() - 2);\n    const int32_t num_cols = 
shape.At(shape.NumAxes() - 1);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int32_t elem_cnt = shape.elem_cnt();\n    const T fill = GetAttrVal<T>(ctx->Attr<bool>(\"is_floating_fill_value\"),\n                                 ctx->Attr<double>(\"floating_fill_value\"),\n                                 ctx->Attr<int64_t>(\"integer_fill_value\"));\n    const T scale = GetAttrVal<T>(ctx->Attr<bool>(\"is_floating_scale_value\"),\n                                  ctx->Attr<double>(\"floating_scale_value\"),\n                                  ctx->Attr<int64_t>(\"integer_scale_value\"));\n    if (num_cols % (kCudaWarpSize * 2) == 0) {\n      const int64_t total_rows = elem_cnt / num_cols;\n      FusedScaleTrilWarpProcessRowGpu<<<BlocksNum4ThreadsNum(total_rows * kCudaWarpSize),\n                                        kCudaThreadsNumPerBlock, 0,\n                                        ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          total_rows, num_rows, num_cols, diagonal, scale, x->dptr<T>(), fill, y->mut_dptr<T>());\n    } else {\n      FusedScaleTrilGpu<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                          ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          elem_cnt, num_rows, num_cols, diagonal, scale, x->dptr<T>(), fill, y->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(dtype)                                            \\\n  REGISTER_USER_KERNEL(\"fused_scale_tril\")                                                      \\\n      .SetCreateFn<GpuFusedScaleTrilKernel<dtype>>()                                            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                          \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value))        \\\n      .SetInplaceProposalFn([](const user_op::InferContext&,                                    \\\n                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \\\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"in\", 0, true));                       \\\n        return Maybe<void>::Ok();                                                               \\\n      });\n\nREGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(float)\nREGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(double)\nREGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(bool)\nREGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(uint8_t)\nREGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(int8_t)\nREGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(int32_t)\nREGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(int64_t)\nREGISTER_CUDA_FUSED_SCALE_TRIL_KERNEL(half)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/triu_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass CpuTriuKernel final : public user_op::OpKernel {\n public:\n  CpuTriuKernel() = default;\n  ~CpuTriuKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const auto shape = x->shape_view();\n    const auto diagonal = ctx->Attr<int64_t>(\"diagonal\");\n    const int64_t num_rows = shape.At(shape.NumAxes() - 2);\n    const int64_t num_cols = shape.At(shape.NumAxes() - 1);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    T* y_dptr = y->mut_dptr<T>();\n    const T* x_dptr = x->dptr<T>();\n    int64_t matrix_size = num_rows * num_cols;\n    for (int64_t k = 0; k < shape.elem_cnt(); ++k) {\n      int64_t offset_in_matrix = k % matrix_size;\n      int64_t i = offset_in_matrix / num_cols;\n      int64_t j = offset_in_matrix - num_cols * i;\n      y_dptr[k] = j < i + diagonal ? static_cast<T>(0) : x_dptr[k];\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CPU_TRIU_KERNEL(dtype)                                             \\\n  REGISTER_USER_KERNEL(\"triu\").SetCreateFn<CpuTriuKernel<dtype>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == DeviceType::kCPU)                                \\\n      && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value));\n\nREGISTER_CPU_TRIU_KERNEL(float16)\nREGISTER_CPU_TRIU_KERNEL(float)\nREGISTER_CPU_TRIU_KERNEL(double)\nREGISTER_CPU_TRIU_KERNEL(uint8_t)\nREGISTER_CPU_TRIU_KERNEL(int8_t)\nREGISTER_CPU_TRIU_KERNEL(int32_t)\nREGISTER_CPU_TRIU_KERNEL(int64_t)\nREGISTER_CPU_TRIU_KERNEL(bool)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/triu_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/util/cuda_half_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__global__ void TriuGpu(const int64_t elem_cnt, const int64_t num_rows, const int64_t num_cols,\n                        const int64_t diagonal, const T* x, T* y) {\n  const int64_t matrix_size = num_rows * num_cols;\n  CUDA_1D_KERNEL_LOOP_T(int64_t, k, elem_cnt) {\n    const int64_t offset_in_matrix = k % matrix_size;\n    const int64_t i = offset_in_matrix / num_cols;\n    const int64_t j = offset_in_matrix - num_cols * i;\n    y[k] = j < i + diagonal ? static_cast<T>(0) : x[k];\n  }\n}\n\ntemplate<typename T>\n__global__ void TriuWarpProcessRowGpu(const int64_t total_rows, const int64_t num_rows,\n                                      const int64_t num_cols, const int64_t diagonal, const T* x,\n                                      T* y) {\n  const int64_t warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kCudaWarpSize;\n  const int64_t lan_id = threadIdx.x % kCudaWarpSize;\n  const int64_t num_warp = blockDim.x * gridDim.x / kCudaWarpSize;\n  for (int64_t i = warp_id; i < total_rows; i += num_warp) {\n    const int64_t row = i % num_rows;\n    for (int64_t col = lan_id; col < num_cols; col += kCudaWarpSize) {\n      const int64_t idx = i * num_cols + col;\n      y[idx] = col < row + diagonal ? static_cast<T>(0) : x[idx];\n    }\n  }\n}\n\ntemplate<>\n__global__ void TriuWarpProcessRowGpu<half>(const int64_t total_rows, const int64_t num_rows,\n                                            const int64_t num_cols, const int64_t diagonal,\n                                            const half* x, half* y) {\n  const int64_t h2_num_cols = num_cols / 2;\n  const auto* x_h2 = reinterpret_cast<const half2*>(x);\n  auto* y_h2 = reinterpret_cast<half2*>(y);\n\n  const int64_t warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kCudaWarpSize;\n  const int64_t lan_id = threadIdx.x % kCudaWarpSize;\n  const int64_t num_warp = blockDim.x * gridDim.x / kCudaWarpSize;\n  for (int64_t i = warp_id; i < total_rows; i += num_warp) {\n    const int64_t row = i % num_rows;\n    for (int64_t col = lan_id; col < h2_num_cols; col += kCudaWarpSize) {\n      const int64_t idx = i * h2_num_cols + col;\n      const half2 x_val = x_h2[idx];\n      half2 y_val;\n      y_val.x = (2 * col) < row + diagonal ? static_cast<half>(0) : x_val.x;\n      y_val.y = (2 * col + 1) < row + diagonal ? 
static_cast<half>(0) : x_val.y;\n      y_h2[idx] = y_val;\n    }\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuTriuKernel final : public user_op::OpKernel {\n public:\n  GpuTriuKernel() = default;\n  ~GpuTriuKernel() override = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    const auto shape = x->shape_view();\n    const auto diagonal = ctx->Attr<int64_t>(\"diagonal\");\n    const int64_t num_rows = shape.At(shape.NumAxes() - 2);\n    const int64_t num_cols = shape.At(shape.NumAxes() - 1);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int32_t elem_cnt = shape.elem_cnt();\n    if (elem_cnt == 0) { return; }\n    if (num_cols % (kCudaWarpSize * 2) == 0) {\n      const int64_t total_rows = elem_cnt / num_cols;\n      TriuWarpProcessRowGpu<<<BlocksNum4ThreadsNum(total_rows * kCudaWarpSize),\n                              kCudaThreadsNumPerBlock, 0,\n                              ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          total_rows, num_rows, num_cols, diagonal, x->dptr<T>(), y->mut_dptr<T>());\n    } else {\n      TriuGpu<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,\n                ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(\n          elem_cnt, num_rows, num_cols, diagonal, x->dptr<T>(), y->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_CUDA_TRIU_KERNEL(dtype)                                                        \\\n  REGISTER_USER_KERNEL(\"triu\")                                                                  \\\n      .SetCreateFn<GpuTriuKernel<dtype>>()                                                      \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                          \\\n                       && (user_op::HobDataType(\"out\", 0) == GetDataType<dtype>::value))        \\\n      .SetInplaceProposalFn([](const user_op::InferContext&,                                    \\\n                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \\\n        OF_RETURN_IF_ERROR(AddInplaceArgPairFn(\"out\", 0, \"in\", 0, true));                       \\\n        return Maybe<void>::Ok();                                                               \\\n      });\n\nREGISTER_CUDA_TRIU_KERNEL(half)\nREGISTER_CUDA_TRIU_KERNEL(float)\nREGISTER_CUDA_TRIU_KERNEL(double)\nREGISTER_CUDA_TRIU_KERNEL(uint8_t)\nREGISTER_CUDA_TRIU_KERNEL(int8_t)\nREGISTER_CUDA_TRIU_KERNEL(int32_t)\nREGISTER_CUDA_TRIU_KERNEL(int64_t)\nREGISTER_CUDA_TRIU_KERNEL(bool)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/tuple_identity_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<DeviceType device_type>\nclass TupleIdentityKernel final : public user_op::OpKernel {\n public:\n  TupleIdentityKernel() = default;\n  ~TupleIdentityKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const int64_t in_size = ctx->input_size(\"in\");\n    CHECK_EQ(ctx->output_size(\"out\"), in_size);\n    for (int64_t i = 0; i < in_size; ++i) {\n      const user_op::Tensor* in_i = ctx->Tensor4ArgNameAndIndex(\"in\", i);\n      user_op::Tensor* out_i = ctx->Tensor4ArgNameAndIndex(\"out\", i);\n      const DataType data_type = in_i->data_type();\n      CHECK_EQ(out_i->data_type(), data_type);\n      const ShapeView& shape = in_i->shape_view();\n      CHECK_EQ(out_i->shape_view(), shape);\n      Memcpy<device_type>(ctx->stream(), out_i->mut_dptr(), in_i->dptr(),\n                          shape.elem_cnt() * GetSizeOfDataType(data_type));\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_TUPLE_IDENTITY_KERNEL(device)    \\\n  REGISTER_USER_KERNEL(\"tuple_identity\")          \\\n      .SetCreateFn<TupleIdentityKernel<device>>() \\\n      .SetIsMatchedHob(user_op::HobDeviceType() == device);\n\nREGISTER_TUPLE_IDENTITY_KERNEL(DeviceType::kCPU)\n#ifdef WITH_CUDA\nREGISTER_TUPLE_IDENTITY_KERNEL(DeviceType::kCUDA)\n#endif\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/two_stage_reduce_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/core/ndarray/xpu_var_ndarray.h\"\n#include \"oneflow/user/kernels/two_stage_reduce_kernel_util.h\"\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#include \"oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\ntemplate<template<typename> class BinaryFunc, DeviceType device_type, typename T>\nclass ReduceDeviceStageKernel final : public OpKernel {\n public:\n  ReduceDeviceStageKernel() = default;\n  ~ReduceDeviceStageKernel() = default;\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    user_op::Tensor* count = ctx->Tensor4ArgNameAndIndex(\"count\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    T* reduce_tmp_buf = tmp_buffer->mut_dptr<T>();\n    int32_t* mask_tmp_buf = tmp_buffer->mut_dptr<int32_t>();\n    const size_t tmp_bytes =\n        GetCudaAlignedSize(in->shape_view().elem_cnt() * std::max(sizeof(T), sizeof(int32_t)));\n    int32_t* reduce_sum_tmp_buf =\n        reinterpret_cast<int32_t*>(tmp_buffer->mut_dptr<char>() + tmp_bytes);\n\n    NdarrayReduce<device_type, T, BinaryFunc>::Reduce(\n        ctx->stream(), XpuVarNdarray<T>(out->shape_view(), out->mut_dptr<T>()),\n        XpuVarNdarray<const T>(in->shape_view(), in->dptr<T>()),\n        XpuVarNdarray<T>(in->shape_view(), reduce_tmp_buf));\n    auto bcast_eq = ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n        ctx->device_type(), ep::primitive::BinaryOp::kEqual, in->data_type(), DataType::kBool,\n        in->shape_view().NumAxes());\n    CHECK(bcast_eq);\n    bcast_eq->Launch(ctx->stream(), in->shape_view().NumAxes(), in->shape_view().ptr(), in->dptr(),\n                     out->shape_view().NumAxes(), out->shape_view().ptr(), out->dptr(),\n                     mask->mut_dptr());\n\n    auto cast = ep::primitive::NewPrimitive<ep::primitive::CastFactory>(\n        ctx->device_type(), DataType::kInt8, DataType::kInt32);\n    CHECK(cast);\n\n    cast->Launch(ctx->stream(), mask->dptr<bool>(), mask_tmp_buf, mask->shape_view().elem_cnt());\n    NdarrayUtil<device_type, int32_t>::ReduceSum(\n        ctx->stream(), XpuVarNdarray<int32_t>(count->shape_view(), count->mut_dptr<int32_t>()),\n        XpuVarNdarray<const int32_t>(mask->shape_view(), mask_tmp_buf),\n        XpuVarNdarray<int32_t>(mask->shape_view(), reduce_sum_tmp_buf));\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nuser_op::InferTmpSizeFn GenDeviceStageInferTmpSizeFn() {\n  return 
[](user_op::InferContext* ctx) {\n    const Shape& in_shape = ctx->InputShape(\"in\", 0);\n    const size_t tmp_bytes =\n        GetCudaAlignedSize(in_shape.elem_cnt() * std::max(sizeof(T), sizeof(int32_t)));\n    const size_t reduce_sum_tmp_bytes = GetCudaAlignedSize(in_shape.elem_cnt() * sizeof(int32_t));\n    return tmp_bytes + reduce_sum_tmp_bytes;\n  };\n}\n\n#define REGISTER_REDUCE_DEVICE_STAGE_KERNEL(op_name, binary_func, device, dtype_pair)            \\\n  REGISTER_USER_KERNEL(op_name)                                                                  \\\n      .SetCreateFn<ReduceDeviceStageKernel<binary_func, device, OF_PP_PAIR_FIRST(dtype_pair)>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                      \\\n                       && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(dtype_pair)))     \\\n      .SetInferTmpSizeFn(GenDeviceStageInferTmpSizeFn<OF_PP_PAIR_FIRST(dtype_pair)>());\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_DEVICE_STAGE_KERNEL, (\"reduce_max_device_stage\"),\n                                 (BinaryFuncMax), DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ)\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_DEVICE_STAGE_KERNEL, (\"reduce_min_device_stage\"),\n                                 (BinaryFuncMin), DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ)\n\ntemplate<DeviceType device_type, typename T>\nclass ReduceDeviceStageGradKernel final : public OpKernel {\n public:\n  ReduceDeviceStageGradKernel() = default;\n  ~ReduceDeviceStageGradKernel() = default;\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    const user_op::Tensor* out_diff = ctx->Tensor4ArgNameAndIndex(\"out_diff\", 0);\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    const user_op::Tensor* count = ctx->Tensor4ArgNameAndIndex(\"count\", 0);\n    user_op::Tensor* in_diff = ctx->Tensor4ArgNameAndIndex(\"in_diff\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    T* tmp_buf_ptr = tmp_buffer->mut_dptr<T>();\n    const size_t tmp_bytes = GetCudaAlignedSize(out_diff->shape_view().elem_cnt() * sizeof(T));\n    T* broadcasted_tmp_buf_ptr = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + tmp_bytes);\n\n    TwoStageReduceKernelUtil<device_type, T, int32_t>::Divide(\n        ctx->stream(), out_diff->shape_view().elem_cnt(), out_diff->dptr<T>(),\n        count->dptr<int32_t>(), tmp_buf_ptr);\n\n    NdarrayUtil<device_type, T>::BroadcastTo(\n        ctx->stream(), XpuVarNdarray<T>(in_diff->shape_view(), broadcasted_tmp_buf_ptr),\n        XpuVarNdarray<const T>(out_diff->shape_view(), tmp_buf_ptr));\n\n    TwoStageReduceKernelUtil<device_type, T, bool>::Mask(\n        ctx->stream(), in_diff->shape_view().elem_cnt(), broadcasted_tmp_buf_ptr,\n        mask->dptr<bool>(), in_diff->mut_dptr<T>());\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nuser_op::InferTmpSizeFn GenDeviceStageGradInferTmpSizeFn() {\n  return [](user_op::InferContext* ctx) {\n    const Shape& out_diff_shape = ctx->InputShape(\"out_diff\", 0);\n    const Shape& in_diff_shape = ctx->OutputShape(\"in_diff\", 0);\n    const size_t tmp_bytes = GetCudaAlignedSize(out_diff_shape.elem_cnt() * sizeof(T));\n    const size_t broadcasted_tmp_bytes = GetCudaAlignedSize(in_diff_shape.elem_cnt() * sizeof(T));\n    return tmp_bytes + 
broadcasted_tmp_bytes;\n  };\n}\n\n#define REGISTER_REDUCE_DEVICE_STAGE_GRAD_KERNEL(op_name, device, dtype_pair)                    \\\n  REGISTER_USER_KERNEL(op_name)                                                                  \\\n      .SetCreateFn<ReduceDeviceStageGradKernel<device, OF_PP_PAIR_FIRST(dtype_pair)>>()          \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                      \\\n                       && (user_op::HobDataType(\"in_diff\", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \\\n      .SetInferTmpSizeFn(GenDeviceStageGradInferTmpSizeFn<OF_PP_PAIR_FIRST(dtype_pair)>());\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_DEVICE_STAGE_GRAD_KERNEL,\n                                 (\"reduce_max_device_stage_grad\"), DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ)\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_DEVICE_STAGE_GRAD_KERNEL,\n                                 (\"reduce_min_device_stage_grad\"), DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ)\n\ntemplate<template<typename> class BinaryFunc, DeviceType device_type, typename T>\nclass ReduceGlobalStageKernel final : public OpKernel {\n public:\n  ReduceGlobalStageKernel() = default;\n  ~ReduceGlobalStageKernel() = default;\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const auto& axis = ctx->Attr<std::vector<int32_t>>(\"axis\");\n    const Shape& reduced_shape = CreateReducedShape(in->shape_view(), {axis.begin(), axis.end()});\n    NdarrayReduce<device_type, T, BinaryFunc>::Reduce(\n        ctx->stream(), XpuVarNdarray<T>(reduced_shape, out->mut_dptr<T>()),\n        XpuVarNdarray<const T>(in->shape_view(), in->dptr<T>()),\n        XpuVarNdarray<T>(in->shape_view(), tmp_buffer->mut_dptr<T>()));\n\n    auto bcast_eq = ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(\n        ctx->device_type(), ep::primitive::BinaryOp::kEqual, in->data_type(), DataType::kBool,\n        in->shape_view().NumAxes());\n    CHECK(bcast_eq);\n    bcast_eq->Launch(ctx->stream(), in->shape_view().NumAxes(), in->shape_view().ptr(), in->dptr(),\n                     reduced_shape.NumAxes(), reduced_shape.dim_vec().data(), out->dptr(),\n                     mask->mut_dptr());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_REDUCE_GLOBAL_STAGE_KERNEL(op_name, binary_func, device, dtype_pair)            \\\n  REGISTER_USER_KERNEL(op_name)                                                                  \\\n      .SetCreateFn<ReduceGlobalStageKernel<binary_func, device, OF_PP_PAIR_FIRST(dtype_pair)>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                      \\\n                       && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(dtype_pair)))     \\\n      .SetInferTmpSizeFn([](InferContext* ctx) {                                                 \\\n        const Shape& in_shape = ctx->InputShape(\"in\", 0);                                        \\\n        return in_shape.elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair));                       \\\n    
  });\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_GLOBAL_STAGE_KERNEL, (\"reduce_max_global_stage\"),\n                                 (BinaryFuncMax), DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ)\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_GLOBAL_STAGE_KERNEL, (\"reduce_min_global_stage\"),\n                                 (BinaryFuncMin), DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ)\n\ntemplate<DeviceType device_type, typename T>\nclass ReduceGlobalStageGradKernel final : public OpKernel {\n public:\n  ReduceGlobalStageGradKernel() = default;\n  ~ReduceGlobalStageGradKernel() = default;\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    const user_op::Tensor* out_diff = ctx->Tensor4ArgNameAndIndex(\"out_diff\", 0);\n    const user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex(\"mask\", 0);\n    const user_op::Tensor* device_count = ctx->Tensor4ArgNameAndIndex(\"device_count\", 0);\n    user_op::Tensor* in_diff = ctx->Tensor4ArgNameAndIndex(\"in_diff\", 0);\n    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    int32_t* device_count_with_mask = tmp_buffer->mut_dptr<int32_t>();\n    const size_t device_count_with_mask_bytes =\n        GetCudaAlignedSize(device_count->shape_view().elem_cnt() * sizeof(int32_t));\n    int32_t* global_count =\n        reinterpret_cast<int32_t*>(tmp_buffer->mut_dptr<char>() + device_count_with_mask_bytes);\n    const size_t global_count_bytes =\n        GetCudaAlignedSize(out_diff->shape_view().elem_cnt() * sizeof(int32_t));\n    int32_t* reduce_sum_tmp_buf = reinterpret_cast<int32_t*>(\n        tmp_buffer->mut_dptr<char>() + device_count_with_mask_bytes + global_count_bytes);\n    const size_t reduce_sum_tmp_bytes =\n        GetCudaAlignedSize(device_count->shape_view().elem_cnt() * sizeof(int32_t));\n    T* divided_buf_ptr =\n        reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + device_count_with_mask_bytes\n                             + global_count_bytes + reduce_sum_tmp_bytes);\n    const size_t divided_buf_bytes =\n        GetCudaAlignedSize(out_diff->shape_view().elem_cnt() * sizeof(T));\n    T* broadcasted_divided_buf_ptr =\n        reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + device_count_with_mask_bytes\n                             + global_count_bytes + reduce_sum_tmp_bytes + divided_buf_bytes);\n\n    TwoStageReduceKernelUtil<device_type, int32_t, bool>::Mask(\n        ctx->stream(), device_count->shape_view().elem_cnt(), device_count->dptr<int32_t>(),\n        mask->dptr<bool>(), device_count_with_mask);\n\n    const auto& axis = ctx->Attr<std::vector<int32_t>>(\"axis\");\n    const Shape& reduced_shape =\n        CreateReducedShape(device_count->shape_view(), {axis.begin(), axis.end()});\n\n    NdarrayUtil<device_type, int32_t>::ReduceSum(\n        ctx->stream(), XpuVarNdarray<int32_t>(reduced_shape, global_count),\n        XpuVarNdarray<const int32_t>(device_count->shape_view(), device_count_with_mask),\n        XpuVarNdarray<int32_t>(device_count->shape_view(), reduce_sum_tmp_buf));\n\n    TwoStageReduceKernelUtil<device_type, T, int32_t>::Divide(\n        ctx->stream(), out_diff->shape_view().elem_cnt(), out_diff->dptr<T>(), global_count,\n        divided_buf_ptr);\n\n    NdarrayUtil<device_type, T>::BroadcastTo(\n        ctx->stream(), XpuVarNdarray<T>(in_diff->shape_view(), broadcasted_divided_buf_ptr),\n        XpuVarNdarray<const 
T>(out_diff->shape_view(), divided_buf_ptr));\n\n    TwoStageReduceKernelUtil<device_type, T, int32_t>::Scale(\n        ctx->stream(), in_diff->shape_view().elem_cnt(), broadcasted_divided_buf_ptr,\n        device_count_with_mask, in_diff->mut_dptr<T>());\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nuser_op::InferTmpSizeFn GenGlobalStageGradInferTmpSizeFn() {\n  return [](user_op::InferContext* ctx) {\n    const Shape& device_count_shape = ctx->InputShape(\"device_count\", 0);\n    const Shape& out_diff_shape = ctx->InputShape(\"out_diff\", 0);\n    const Shape& in_diff_shape = ctx->OutputShape(\"in_diff\", 0);\n    const size_t device_count_with_mask_bytes =\n        GetCudaAlignedSize(device_count_shape.elem_cnt() * sizeof(int32_t));\n    const size_t global_count_bytes =\n        GetCudaAlignedSize(out_diff_shape.elem_cnt() * sizeof(int32_t));\n    const size_t reduce_sum_tmp_bytes =\n        GetCudaAlignedSize(device_count_shape.elem_cnt() * sizeof(int32_t));\n    const size_t divided_buf_bytes = GetCudaAlignedSize(out_diff_shape.elem_cnt() * sizeof(T));\n    const size_t broadcasted_divided_buf_bytes =\n        GetCudaAlignedSize(in_diff_shape.elem_cnt() * sizeof(T));\n    const size_t total_bytes = device_count_with_mask_bytes + global_count_bytes\n                               + reduce_sum_tmp_bytes + divided_buf_bytes\n                               + broadcasted_divided_buf_bytes;\n    return total_bytes;\n  };\n}\n\n#define REGISTER_REDUCE_GLOBAL_STAGE_GRAD_KERNEL(op_name, device, dtype_pair)                    \\\n  REGISTER_USER_KERNEL(op_name)                                                                  \\\n      .SetCreateFn<ReduceGlobalStageGradKernel<device, OF_PP_PAIR_FIRST(dtype_pair)>>()          \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)                                      \\\n                       && (user_op::HobDataType(\"in_diff\", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \\\n      .SetInferTmpSizeFn(GenGlobalStageGradInferTmpSizeFn<OF_PP_PAIR_FIRST(dtype_pair)>());\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_GLOBAL_STAGE_GRAD_KERNEL,\n                                 (\"reduce_max_global_stage_grad\"), DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ)\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_GLOBAL_STAGE_GRAD_KERNEL,\n                                 (\"reduce_min_global_stage_grad\"), DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ)\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/two_stage_reduce_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/two_stage_reduce_kernel_util.h\"\n#include \"oneflow/core/common/data_type_seq.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename K>\nstruct TwoStageReduceKernelUtil<DeviceType::kCPU, T, K> {\n  static void Divide(ep::Stream* stream, const int64_t n, const T* x, const K* count, T* y) {\n    FOR_RANGE(int64_t, i, 0, n) { y[i] = x[i] / count[i]; }\n  }\n\n  static void Mask(ep::Stream* stream, const int64_t n, const T* x, const K* mask, T* y) {\n    FOR_RANGE(int64_t, i, 0, n) { y[i] = static_cast<T>(mask[i]) * x[i]; }\n  }\n\n  static void Scale(ep::Stream* stream, const int64_t n, const T* x, const K* scale, T* y) {\n    FOR_RANGE(int64_t, i, 0, n) { y[i] = x[i] * static_cast<T>(scale[i]); }\n  }\n};\n\n#define INSTANTIATE_TWO_STAGE_REDUCE_KERNEL_UTIL_CPU(data_type_pair, index_type_pair)          \\\n  template struct TwoStageReduceKernelUtil<DeviceType::kCPU, OF_PP_PAIR_FIRST(data_type_pair), \\\n                                           OF_PP_PAIR_FIRST(index_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_TWO_STAGE_REDUCE_KERNEL_UTIL_CPU,\n                                 FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ,\n                                 INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ);\n#undef INSTANTIATE_TWO_STAGE_REDUCE_KERNEL_UTIL_CPU\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/two_stage_reduce_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/two_stage_reduce_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename K>\n__global__ void DivideGpu(const int64_t n, const T* x, const K* count, T* y) {\n  CUDA_1D_KERNEL_LOOP(i, n) { y[i] = x[i] / count[i]; }\n}\n\ntemplate<typename T, typename K>\n__global__ void MaskGpu(const int64_t n, const T* x, const K* mask, T* y) {\n  CUDA_1D_KERNEL_LOOP(i, n) { y[i] = static_cast<T>(mask[i]) * x[i]; }\n}\n\ntemplate<typename T, typename K>\n__global__ void ScaleGpu(const int64_t n, const T* x, const K* scale, T* y) {\n  CUDA_1D_KERNEL_LOOP(i, n) { y[i] = x[i] * static_cast<T>(scale[i]); }\n}\n\n}  // namespace\n\ntemplate<typename T, typename K>\nstruct TwoStageReduceKernelUtil<DeviceType::kCUDA, T, K> {\n  static void Divide(ep::Stream* stream, const int64_t n, const T* x, const K* count, T* y) {\n    DivideGpu<T, K><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                      stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, count, y);\n  }\n\n  static void Mask(ep::Stream* stream, const int64_t n, const T* x, const K* mask, T* y) {\n    MaskGpu<T, K><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                    stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, mask, y);\n  }\n\n  static void Scale(ep::Stream* stream, const int64_t n, const T* x, const K* scale, T* y) {\n    ScaleGpu<T, K><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,\n                     stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, scale, y);\n  }\n};\n\n#define INSTANTIATE_TWO_STAGE_REDUCE_KERNEL_UTIL_CUDA(data_type_pair, index_type_pair)          \\\n  template struct TwoStageReduceKernelUtil<DeviceType::kCUDA, OF_PP_PAIR_FIRST(data_type_pair), \\\n                                           OF_PP_PAIR_FIRST(index_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_TWO_STAGE_REDUCE_KERNEL_UTIL_CUDA,\n                                 FLOATING_DATA_TYPE_SEQ INDEX_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ,\n                                 INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ);\n#undef INSTANTIATE_TWO_STAGE_REDUCE_KERNEL_UTIL_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/two_stage_reduce_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_TWO_STAGE_REDUCE_UTIL_H_\n#define ONEFLOW_USER_KERNELS_TWO_STAGE_REDUCE_UTIL_H_\n\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, typename K>\nstruct TwoStageReduceKernelUtil {\n  static void Divide(ep::Stream* stream, const int64_t n, const T* x, const K* count, T* y);\n  static void Mask(ep::Stream* stream, const int64_t n, const T* x, const K* mask, T* y);\n  static void Scale(ep::Stream* stream, const int64_t n, const T* x, const K* scale, T* y);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_TWO_STAGE_REDUCE_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/unfold_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/operator/operator_util.h\"\n#include \"oneflow/user/kernels/unfold_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\n// NDIM range: (1, 2, 3)\n// SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first\ntemplate<typename INDEX_T, int NDIM, int SDIM>\nclass UnfoldOpKernelState : public OpKernelState {\n public:\n  using ParamType = UnfoldParams<INDEX_T, NDIM, SDIM>;\n  UnfoldOpKernelState(const ShapeView& input_shape, const std::vector<int32_t>& kernel_size,\n                      const std::vector<int32_t>& padding, const std::vector<int32_t>& stride,\n                      const std::vector<int32_t>& dilation)\n      : params_(input_shape.At(0), input_shape.At(ParamType::kInputChannelDim),\n                input_shape.ptr() + SDIM, kernel_size.data(), padding.data(), stride.data(),\n                dilation.data()) {}\n  const ParamType& params() const { return params_; }\n\n private:\n  ParamType params_;\n};\n\ntemplate<typename INDEX_T, int NDIM, int SDIM>\nstd::shared_ptr<UnfoldOpKernelState<INDEX_T, NDIM, SDIM>> CreateUnfoldOpKernelState(\n    const ShapeView& input_shape, const std::vector<int32_t>& kernel_size,\n    const std::vector<int32_t>& padding, const std::vector<int32_t>& stride,\n    const std::vector<int32_t>& dilation) {\n  std::shared_ptr<UnfoldOpKernelState<INDEX_T, NDIM, SDIM>> state(\n      new UnfoldOpKernelState<INDEX_T, NDIM, SDIM>(input_shape, kernel_size, padding, stride,\n                                                   dilation));\n  return state;\n}\n\ntemplate<DeviceType device_type, typename T, typename INDEX_T, int NDIM, int SDIM>\nclass UnfoldKernel final : public OpKernel {\n public:\n  UnfoldKernel() = default;\n  ~UnfoldKernel() = default;\n\n private:\n  void Compute(KernelComputeContext* ctx) const override {\n    const Tensor* input = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    Tensor* output = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const std::vector<int32_t> kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n    const std::vector<int32_t> padding = ctx->Attr<std::vector<int32_t>>(\"padding\");\n    const std::vector<int32_t> stride = ctx->Attr<std::vector<int32_t>>(\"strides\");\n    const std::vector<int32_t> dilation = ctx->Attr<std::vector<int32_t>>(\"dilation_rate\");\n\n    const auto& state_ptr = CreateUnfoldOpKernelState<INDEX_T, NDIM, SDIM>(\n        input->shape_view(), kernel_size, padding, stride, dilation);\n\n    const UnfoldParams<INDEX_T, NDIM, SDIM> params = state_ptr->params();\n    UnfoldKernelUtil<device_type, T, INDEX_T, NDIM, SDIM>::Forward(\n        ctx->stream(), &params, input->dptr<T>(), output->mut_dptr<T>());\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n}  // namespace\n\n// Currently 
supports 4-D tensors in NCHW format\n#define REGISTER_UNFOLD_KERNEL(device, dtype)                    \\\n  REGISTER_USER_KERNEL(\"unfold\")                                 \\\n      .SetCreateFn<UnfoldKernel<device, dtype, int32_t, 2, 2>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == device)      \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UNFOLD_KERNEL(DeviceType::kCPU, float)\nREGISTER_UNFOLD_KERNEL(DeviceType::kCPU, double)\n\n#ifdef WITH_CUDA\nREGISTER_UNFOLD_KERNEL(DeviceType::kCUDA, float)\nREGISTER_UNFOLD_KERNEL(DeviceType::kCUDA, double)\n#endif  // WITH_CUDA\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/unfold_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/unfold_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\n// NDIM range: (1, 2, 3)\n// SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first\ntemplate<typename T, typename INDEX_T, int NDIM, int SDIM>\nstruct UnfoldKernelUtil<DeviceType::kCPU, T, INDEX_T, NDIM, SDIM> {\n  using ParamType = UnfoldParams<INDEX_T, NDIM, SDIM>;\n  static void Forward(ep::Stream* stream, const UnfoldParams<INDEX_T, NDIM, SDIM>* raw_params,\n                      const T* input_ptr, T* output_ptr) {\n    for (INDEX_T out_offset = 0; out_offset < raw_params->out_elem_cnt; ++out_offset) {\n      using ParamType = UnfoldParams<INDEX_T, NDIM, SDIM>;\n      INDEX_T in_index[ParamType::kInputNDim] = {0};\n      INDEX_T out_index[ParamType::kOutputNDim] = {0};\n      raw_params->out_index_helper.OffsetToNdIndex(out_offset, out_index);\n      if (!UnfoldIndexTransform<INDEX_T, NDIM, SDIM>(*raw_params, out_index, in_index)) {\n        INDEX_T in_offset = raw_params->in_index_helper.NdIndexToOffset(in_index);\n        output_ptr[out_offset] = input_ptr[in_offset];\n      } else {\n        output_ptr[out_offset] = static_cast<T>(kUnfoldPaddingValue);\n      }\n    }\n  }\n};\n\nINSTANTIATE_UNFOLD_KERNEL_UTIL_FOR_DEVICE(DeviceType::kCPU)\n\n}  // namespace user_op\n\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/user/kernels/unfold_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifdef WITH_CUDA\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/user/kernels/unfold_kernel_util.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\nconstexpr int kBlockSize = cuda::elementwise::kBlockSize;\n\nint GetNumBlocks(int64_t elem_cnt) {\n  int num_blocks = 0;\n  OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(elem_cnt, &num_blocks));\n  return num_blocks;\n}\n\n// NDIM range: (1, 2, 3)\n// SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first\ntemplate<typename T, typename INDEX_T, int NDIM, int SDIM>\n__global__ void CudaUnfoldForward(UnfoldParams<INDEX_T, NDIM, SDIM> params, const T* in, T* out) {\n  CUDA_1D_KERNEL_LOOP_T(INDEX_T, out_offset, params.out_elem_cnt) {\n    using ParamType = UnfoldParams<INDEX_T, NDIM, SDIM>;\n    INDEX_T in_index[ParamType::kInputNDim] = {0};\n    INDEX_T out_index[ParamType::kOutputNDim] = {0};\n    params.out_index_helper.OffsetToNdIndex(out_offset, out_index);\n    if (!UnfoldIndexTransform<INDEX_T, NDIM, SDIM>(params, out_index, in_index)) {\n      INDEX_T in_offset = params.in_index_helper.NdIndexToOffset(in_index);\n      out[out_offset] = in[in_offset];\n    } else {\n      out[out_offset] = static_cast<T>(kUnfoldPaddingValue);\n    }\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename INDEX_T, int NDIM, int SDIM>\nstruct UnfoldKernelUtil<DeviceType::kCUDA, T, INDEX_T, NDIM, SDIM> {\n  using ParamType = UnfoldParams<INDEX_T, NDIM, SDIM>;\n  static void Forward(ep::Stream* stream, const UnfoldParams<INDEX_T, NDIM, SDIM>* params,\n                      const T* input_ptr, T* output_ptr) {\n    CudaUnfoldForward<T, INDEX_T, NDIM, SDIM>\n        <<<GetNumBlocks(params->out_elem_cnt), kBlockSize, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(*params, input_ptr, output_ptr);\n  }\n};\nINSTANTIATE_UNFOLD_KERNEL_UTIL_FOR_DEVICE(DeviceType::kCUDA)\n}  // namespace user_op\n}  // namespace oneflow\n#endif  // WITH_CUDA"
  },
  {
    "path": "oneflow/user/kernels/unfold_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_UNFOLD_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_UNFOLD_KERNEL_UTIL_H_\n\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/common/switch_func.h\"\n#include \"oneflow/core/ndarray/xpu_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nconstexpr int kUnfoldPaddingValue = 0;\n\n// NDIM range: (1, 2, 3)\n// SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first\ntemplate<typename INDEX_T, int NDIM, int SDIM>\nstruct UnfoldParams {\n  static constexpr int kInputNDim = NDIM + 2;\n  static constexpr int kOutputNDim = NDIM * 2 + 2;\n  static constexpr int kInputChannelDim = (2 - SDIM) * NDIM + 1;\n  static constexpr int kOutputChannelDim = (2 - SDIM) * NDIM * 2 + 1;\n  static_assert(kInputChannelDim < kInputNDim, \"\");\n  static_assert(kOutputChannelDim < kOutputNDim, \"\");\n  UnfoldParams(const int64_t batch_size, const int64_t channels, const int64_t* spatial_dims,\n               const int32_t* kernel_size, const int32_t* padding, const int32_t* stride,\n               const int32_t* dilation);\n  INDEX_T in_elem_cnt;\n  INDEX_T out_elem_cnt;\n  INDEX_T dims[NDIM];\n  int padding[NDIM];\n  int stride[NDIM];\n  int dilation[NDIM];\n  NdIndexOffsetHelper<INDEX_T, kInputNDim> in_index_helper;\n  NdIndexOffsetHelper<INDEX_T, kOutputNDim> out_index_helper;\n};\n\n// NDIM range: (1, 2, 3)\n// SDIM range: (1, 2), 1 indicates channels_last, 2 indicates channels_first\ntemplate<typename INDEX_T, int NDIM, int SDIM>\nUnfoldParams<INDEX_T, NDIM, SDIM>::UnfoldParams(const int64_t batch_size, const int64_t channels,\n                                                const int64_t* spatial_dims,\n                                                const int32_t* kernel_size, const int32_t* padding,\n                                                const int32_t* stride, const int32_t* dilation)\n    : in_index_helper(0), out_index_helper(0) {\n  INDEX_T input_dims[kInputNDim] = {0};\n  INDEX_T output_dims[kOutputNDim] = {0};\n  in_elem_cnt = batch_size * channels;\n  out_elem_cnt = batch_size * channels;\n  input_dims[0] = batch_size;\n  output_dims[0] = batch_size;\n  input_dims[kInputChannelDim] = channels;\n  output_dims[kOutputChannelDim] = channels;\n  for (int d = 0; d < NDIM; ++d) {\n    this->in_elem_cnt *= spatial_dims[d];\n    this->dims[d] = spatial_dims[d];\n    this->padding[d] = padding[d];\n    this->stride[d] = stride[d];\n    this->dilation[d] = dilation[d];\n    input_dims[SDIM + d] = spatial_dims[d];\n    output_dims[SDIM + d] = kernel_size[d];\n    output_dims[SDIM + NDIM + d] =\n        (spatial_dims[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1) - 1) / stride[d] + 1;\n    out_elem_cnt *= output_dims[SDIM + d] * output_dims[SDIM + NDIM + d];\n  }\n  in_index_helper = 
NdIndexOffsetHelper<INDEX_T, kInputNDim>(input_dims);\n  out_index_helper = NdIndexOffsetHelper<INDEX_T, kOutputNDim>(output_dims);\n}\n\n// index_a format: (N, C, di, hi, wi, db, hb, wb) or (N, di, hi, wi, db, hb, wb, C)\n// index_b format: (N, C, D, H, W) or (N, D, H, W, C)\n// return: true indicates out-of-bound, otherwise in-bound\ntemplate<typename INDEX_T, int NDIM, int SDIM>\nOF_DEVICE_FUNC bool UnfoldIndexTransform(const UnfoldParams<INDEX_T, NDIM, SDIM>& params,\n                                         const INDEX_T* index_a, INDEX_T* index_b) {\n  // batch dim index transform\n  index_b[0] = index_a[0];\n  // channel dim index transform\n  using ParamType = UnfoldParams<INDEX_T, NDIM, SDIM>;\n  index_b[ParamType::kInputChannelDim] = index_a[ParamType::kOutputChannelDim];\n// spatial dim index transform\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n  // D,H,W spatial dim index transform\n  for (int64_t d = 0; d < NDIM; ++d) {\n    INDEX_T idx = index_a[SDIM + NDIM + d] * params.stride[d]\n                  + index_a[SDIM + d] * params.dilation[d] - params.padding[d];\n    if (idx < 0 || idx >= params.dims[d]) return true;\n    index_b[SDIM + d] = idx;\n  }\n  return false;\n}\n\ntemplate<DeviceType device_type, typename T, typename INDEX_T, int NDIM, int SDIM>\nstruct UnfoldKernelUtil {\n  static void Forward(ep::Stream* stream, const UnfoldParams<INDEX_T, NDIM, SDIM>* params,\n                      const T* input_ptr, T* output_ptr);\n};\n\n#define SPATIAL_NDIM_SEQ OF_PP_MAKE_TUPLE_SEQ(1) OF_PP_MAKE_TUPLE_SEQ(2) OF_PP_MAKE_TUPLE_SEQ(3)\n#define SPATIAL_DIM_SEQ OF_PP_MAKE_TUPLE_SEQ(1) OF_PP_MAKE_TUPLE_SEQ(2)\n\n#define INSTANTIATE_UNFOLD_KERNEL_UTIL(device, dtype, itype, ndim, sdim) \\\n  template struct UnfoldKernelUtil<device, dtype, itype, ndim, sdim>;\n\n#define INSTANTIATE_UNFOLD_KERNEL_UTIL_WITH_TYPE_PAIR(device, dtype_pair, itype_pair, ndim, sdim) \\\n  INSTANTIATE_UNFOLD_KERNEL_UTIL(device, OF_PP_PAIR_FIRST(dtype_pair),                            \\\n                                 OF_PP_PAIR_FIRST(itype_pair), ndim, sdim)\n\n#define INSTANTIATE_UNFOLD_KERNEL_UTIL_FOR_DEVICE(device)                                         \\\n  OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_UNFOLD_KERNEL_UTIL_WITH_TYPE_PAIR, (device),       \\\n                                   FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, SPATIAL_NDIM_SEQ, \\\n                                   SPATIAL_DIM_SEQ)\n\n}  // namespace user_op\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_UNFOLD_KERNEL_UTIL_H_"
  },
  {
    "path": "oneflow/user/kernels/unfold_tensor_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/operator/operator_util.h\"\n#include \"oneflow/user/kernels/unfold_tensor_kernel_utils.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass UnfoldTensorKernel final : public user_op::OpKernel {\n public:\n  UnfoldTensorKernel() = default;\n  ~UnfoldTensorKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n\n    const ShapeView& in_shape = in->shape_view();\n    std::vector<int32_t> out_shape;\n    out_shape.resize(out->shape_view().NumAxes());\n    for (int i = 0; i < out->shape_view().NumAxes(); ++i) {\n      out_shape[i] = out->shape_view().At(i);\n    }\n\n    const int32_t in_dims = in_shape.NumAxes();\n    const int32_t out_dims = out_shape.size();\n    const int32_t dimension = ctx->Attr<int32_t>(\"dimension\");\n    const int32_t step = ctx->Attr<int32_t>(\"step\");\n\n    std::vector<int32_t> in_stride(in_dims, 1);\n    for (int32_t i = in_dims - 2; i >= 0; --i) {\n      in_stride[i] = in_shape.At(i + 1) * in_stride.at(i + 1);\n    }\n\n    std::vector<int32_t> out_stride(in_dims + 1);\n    out_stride[in_dims] = in_dims == 0 ? 
1 : in_stride[dimension];\n    for (int d = 0; d < in_dims; ++d) {\n      if (d == dimension) {\n        out_stride[d] = step * in_stride[d];\n      } else {\n        out_stride[d] = in_stride[d];\n      }\n    }\n\n    const T* in_ptr = in->dptr<T>();\n    T* out_ptr = out->mut_dptr<T>();\n    const int32_t out_size = out->shape_view().elem_cnt();\n    for (int32_t i = 0; i < out_size; ++i) {\n      int offset = Offset(i, out_stride.data(), out_shape.data(), out_dims - 1);\n      out_ptr[i] = in_ptr[offset];\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UNFOLD_TENSOR_KERNEL(dtype)                          \\\n  REGISTER_USER_KERNEL(\"unfold_tensor\")                               \\\n      .SetCreateFn<UnfoldTensorKernel<dtype>>()                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UNFOLD_TENSOR_KERNEL(float)\nREGISTER_UNFOLD_TENSOR_KERNEL(double)\nREGISTER_UNFOLD_TENSOR_KERNEL(int64_t)\nREGISTER_UNFOLD_TENSOR_KERNEL(int32_t)\n\ntemplate<typename T>\nclass UnfoldTensorGradKernel final : public user_op::OpKernel {\n public:\n  UnfoldTensorGradKernel() = default;\n  ~UnfoldTensorGradKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dout = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* din = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const ShapeView& in_shape = in->shape_view();\n    const int32_t in_dims = in_shape.NumAxes();\n    std::vector<int32_t> din_stride(in_dims, 1);\n    for (int32_t i = in_dims - 2; i >= 0; --i) {\n      din_stride[i] = in_shape.At(i + 1) * din_stride.at(i + 1);\n    }\n\n    std::vector<int32_t> dout_shape;\n    dout_shape.resize(dout->shape_view().NumAxes());\n    for (int i = 0; i < dout->shape_view().NumAxes(); ++i) {\n      dout_shape[i] = dout->shape_view().At(i);\n    }\n\n    const int32_t dout_dims = dout_shape.size();\n    const int32_t dimension = ctx->Attr<int32_t>(\"dimension\");\n    const int32_t step = ctx->Attr<int32_t>(\"step\");\n\n    std::vector<int32_t> dout_stride(in_dims + 1);\n    dout_stride[in_dims] = in_dims == 0 ? 
1 : din_stride[dimension];\n    for (int d = 0; d < in_dims; ++d) {\n      if (d == dimension) {\n        dout_stride[d] = step * din_stride[d];\n      } else {\n        dout_stride[d] = din_stride[d];\n      }\n    }\n\n    const T* dout_ptr = dout->dptr<T>();\n    T* din_ptr = din->mut_dptr<T>();\n\n    std::fill(din_ptr, din_ptr + din->shape_view().elem_cnt(), static_cast<T>(0));\n    const int32_t dout_size = dout->shape_view().elem_cnt();\n    for (int32_t i = 0; i < dout_size; ++i) {\n      int offset = Offset(i, dout_stride.data(), dout_shape.data(), dout_dims - 1);\n      din_ptr[offset] += dout_ptr[i];\n    }\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UNFOLD_TENSOR_GRAD_KERNEL(dtype)                     \\\n  REGISTER_USER_KERNEL(\"unfold_tensor_grad\")                          \\\n      .SetCreateFn<UnfoldTensorGradKernel<dtype>>()                   \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UNFOLD_TENSOR_GRAD_KERNEL(float)\nREGISTER_UNFOLD_TENSOR_GRAD_KERNEL(double)\nREGISTER_UNFOLD_TENSOR_GRAD_KERNEL(int64_t)\nREGISTER_UNFOLD_TENSOR_GRAD_KERNEL(int32_t)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/unfold_tensor_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/user/kernels/unfold_tensor_kernel_utils.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconst int32_t NDIMS = 16;\nstruct STRIDES {\n  int32_t val[NDIMS];\n};\n\ntemplate<typename T>\n__global__ void UnfoldTensorCudaKernel(const T* in_ptr, const STRIDES out_stride,\n                                       const STRIDES out_shape, const int32_t out_dims,\n                                       const int32_t elements, T* out_ptr) {\n  int32_t gid = (blockDim.x * blockIdx.x) + threadIdx.x;\n  int32_t step = gridDim.x * blockDim.x;\n  while (gid < elements) {\n    int32_t offset = Offset(gid, out_stride.val, out_shape.val, out_dims - 1);\n    out_ptr[gid] = in_ptr[offset];\n    gid += step;\n  }\n}\n\ntemplate<typename T>\n__global__ void UnfoldTensorGradCudaKernel(const T* dout_ptr, const STRIDES dout_stride,\n                                           const STRIDES dout_shape, const int32_t dout_dims,\n                                           const int32_t elements, T* din_ptr) {\n  int32_t gid = (blockDim.x * blockIdx.x) + threadIdx.x;\n  int32_t step = gridDim.x * blockDim.x;\n  while (gid < elements) {\n    int32_t offset = Offset(gid, dout_stride.val, dout_shape.val, dout_dims - 1);\n    cuda::atomic::Add(&din_ptr[offset], dout_ptr[gid]);\n    gid += step;\n  }\n}\n\ntemplate<typename T>\n__global__ void InitPtr(const int32_t elements, T* ptr) {\n  int32_t gid = (blockDim.x * blockIdx.x) + threadIdx.x;\n  int32_t step = gridDim.x * blockDim.x;\n  while (gid < elements) {\n    ptr[gid] = static_cast<T>(0);\n    gid += step;\n  }\n}\n\ntemplate<typename T>\nstruct GpuUnfoldTensorFunctor final {\n  void operator()(ep::Stream* stream, const T* in_ptr, const STRIDES out_stride,\n                  const STRIDES out_shape, const int32_t out_dims, const int32_t elements,\n                  T* out_ptr) {\n    RUN_CUDA_KERNEL((UnfoldTensorCudaKernel<T>), stream, elements, in_ptr, out_stride, out_shape,\n                    out_dims, elements, out_ptr);\n  }\n};\n\ntemplate<typename T>\nstruct GpuUnfoldTensorGradFunctor final {\n  void operator()(ep::Stream* stream, const T* dout_ptr, const STRIDES dout_stride,\n                  const STRIDES dout_shape, const int32_t dout_dims, const int32_t dout_elements,\n                  const int32_t din_elements, T* din_ptr) {\n    RUN_CUDA_KERNEL((InitPtr<T>), stream, din_elements, din_elements, din_ptr);\n    RUN_CUDA_KERNEL((UnfoldTensorGradCudaKernel<T>), stream, dout_elements, dout_ptr, dout_stride,\n                    dout_shape, dout_dims, dout_elements, din_ptr);\n  }\n};\n\n}  // namespace\n\ntemplate<typename T>\nclass GpuUnfoldTensorKernel final : public user_op::OpKernel {\n public:\n  GpuUnfoldTensorKernel() = default;\n  ~GpuUnfoldTensorKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void 
Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n\n    const ShapeView& in_shape = in->shape_view();\n    std::vector<int32_t> out_shape;\n    out_shape.resize(out->shape_view().NumAxes());\n    for (int i = 0; i < out->shape_view().NumAxes(); ++i) {\n      out_shape[i] = out->shape_view().At(i);\n    }\n    const int32_t in_dims = in_shape.NumAxes();\n    const int32_t out_dims = out_shape.size();\n    const int32_t dimension = ctx->Attr<int32_t>(\"dimension\");\n    const int32_t step = ctx->Attr<int32_t>(\"step\");\n\n    std::vector<int32_t> in_stride(in_dims, 1);\n    for (int32_t i = in_dims - 2; i >= 0; --i) {\n      in_stride[i] = in_shape.At(i + 1) * in_stride.at(i + 1);\n    }\n\n    std::vector<int32_t> out_stride(in_dims + 1);\n    out_stride[in_dims] = in_dims == 0 ? 1 : in_stride[dimension];\n    for (int d = 0; d < in_dims; ++d) {\n      if (d == dimension) {\n        out_stride[d] = step * in_stride[d];\n      } else {\n        out_stride[d] = in_stride[d];\n      }\n    }\n\n    const T* in_ptr = in->dptr<T>();\n    T* out_ptr = out->mut_dptr<T>();\n    const int32_t out_size = out->shape_view().elem_cnt();\n\n    STRIDES out_stride_cuda;\n    for (int i = 0; i < out_dims; ++i) { out_stride_cuda.val[i] = out_stride[i]; }\n    STRIDES out_shape_cuda;\n    for (int i = 0; i < out_dims; ++i) { out_shape_cuda.val[i] = out_shape[i]; }\n\n    GpuUnfoldTensorFunctor<T>()(ctx->stream(), in_ptr, out_stride_cuda, out_shape_cuda, out_dims,\n                                out_size, out_ptr);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UNFOLD_TENSOR_KERNEL(dtype)                           \\\n  REGISTER_USER_KERNEL(\"unfold_tensor\")                                \\\n      .SetCreateFn<GpuUnfoldTensorKernel<dtype>>()                     \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value))\n\nREGISTER_UNFOLD_TENSOR_KERNEL(float);\nREGISTER_UNFOLD_TENSOR_KERNEL(double);\nREGISTER_UNFOLD_TENSOR_KERNEL(int32_t);\nREGISTER_UNFOLD_TENSOR_KERNEL(int64_t);\n\ntemplate<typename T>\nclass GpuUnfoldTensorGradKernel final : public user_op::OpKernel {\n public:\n  GpuUnfoldTensorGradKernel() = default;\n  ~GpuUnfoldTensorGradKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* dout = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* din = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const ShapeView& in_shape = in->shape_view();\n    const int32_t in_dims = in_shape.NumAxes();\n    std::vector<int32_t> din_stride(in_dims, 1);\n    for (int32_t i = in_dims - 2; i >= 0; --i) {\n      din_stride[i] = in_shape.At(i + 1) * din_stride.at(i + 1);\n    }\n\n    std::vector<int32_t> dout_shape;\n    dout_shape.resize(dout->shape_view().NumAxes());\n    for (int i = 0; i < dout->shape_view().NumAxes(); ++i) {\n      dout_shape[i] = dout->shape_view().At(i);\n    }\n\n    const int32_t dout_dims = dout_shape.size();\n    const int32_t dimension = ctx->Attr<int32_t>(\"dimension\");\n    const int32_t step = ctx->Attr<int32_t>(\"step\");\n\n    std::vector<int32_t> dout_stride(in_dims + 
1);\n    dout_stride[in_dims] = in_dims == 0 ? 1 : din_stride[dimension];\n    for (int d = 0; d < in_dims; ++d) {\n      if (d == dimension) {\n        dout_stride[d] = step * din_stride[d];\n      } else {\n        dout_stride[d] = din_stride[d];\n      }\n    }\n\n    STRIDES dout_stride_cuda;\n    for (int i = 0; i < dout_dims; ++i) { dout_stride_cuda.val[i] = dout_stride[i]; }\n    STRIDES dout_shape_cuda;\n    for (int i = 0; i < dout_dims; ++i) { dout_shape_cuda.val[i] = dout_shape[i]; }\n\n    const T* dout_ptr = dout->dptr<T>();\n    T* din_ptr = din->mut_dptr<T>();\n    const int32_t dout_size = dout->shape_view().elem_cnt();\n    const int32_t din_size = din->shape_view().elem_cnt();\n\n    GpuUnfoldTensorGradFunctor<T>()(ctx->stream(), dout_ptr, dout_stride_cuda, dout_shape_cuda,\n                                    dout_dims, dout_size, din_size, din_ptr);\n  }\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UNFOLD_TENSOR_GRAD_KERNEL(dtype)                      \\\n  REGISTER_USER_KERNEL(\"unfold_tensor_grad\")                           \\\n      .SetCreateFn<GpuUnfoldTensorGradKernel<dtype>>()                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \\\n                       && (user_op::HobDataType(\"x\", 0) == GetDataType<dtype>::value))\n\nREGISTER_UNFOLD_TENSOR_GRAD_KERNEL(float);\nREGISTER_UNFOLD_TENSOR_GRAD_KERNEL(double);\nREGISTER_UNFOLD_TENSOR_GRAD_KERNEL(int32_t);\nREGISTER_UNFOLD_TENSOR_GRAD_KERNEL(int64_t);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/unfold_tensor_kernel_utils.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_UNFOLD_TENSOR_KERNEL_UTILS_H_\n#define ONEFLOW_UNFOLD_TENSOR_KERNEL_UTILS_H_\n\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\nOF_DEVICE_FUNC int32_t Offset(int32_t in_offset, const int32_t* out_stride,\n                              const int32_t* out_shape, const int32_t n) {\n  int32_t remaining = 0;\n  int32_t out_offset = 0;\n#ifdef __CUDA_ARCH__\n#pragma unroll\n#endif\n  for (int32_t dim = n; dim >= 0; --dim) {\n    remaining = in_offset % out_shape[dim];\n    out_offset += remaining * out_stride[dim];\n    in_offset = in_offset / out_shape[dim];\n  }\n  return out_offset;\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_UNFOLD_TENSOR_KERNEL_UTILS_H_\n"
  },
  {
    "path": "oneflow/user/kernels/unique_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/unique_kernel_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass UniqueKernel final : public user_op::OpKernel {\n public:\n  UniqueKernel() = default;\n  ~UniqueKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* idx = ctx->Tensor4ArgNameAndIndex(\"idx\", 0);\n    user_op::Tensor* num_unique = ctx->Tensor4ArgNameAndIndex(\"num_unique\", 0);\n    user_op::Tensor* tmp = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const bool& sorted = ctx->Attr<bool>(\"sorted\");\n    void* tmp_ptr = tmp ? tmp->mut_dptr() : nullptr;\n    int64_t tmp_size = tmp ? tmp->shape_view().elem_cnt() * GetSizeOfDataType(tmp->data_type()) : 0;\n    UniqueKernelUtil<device_type, T, K>::Unique(\n        ctx->stream(), x->shape_view().elem_cnt(), x->dptr<T>(), num_unique->mut_dptr<K>(),\n        y->mut_dptr<T>(), idx->mut_dptr<K>(), tmp_ptr, tmp_size, sorted);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T, typename K>\nuser_op::InferTmpSizeFn GenInferTmpSizeFn() {\n  return [](user_op::InferContext* ctx) {\n    const auto& x = ctx->InputTensorDesc(\"x\", 0);\n    int64_t workspace_size_in_bytes = 0;\n    UniqueKernelUtil<device_type, T, K>::GetUniqueWorkspaceSizeInBytes(\n        nullptr, x.shape().elem_cnt(), &workspace_size_in_bytes);\n\n    return workspace_size_in_bytes;\n  };\n}\n\n#define REGISTER_UNIQUE_KERNEL(device_type_v, data_type_pair, indices_type_pair)            \\\n  REGISTER_USER_KERNEL(\"unique\")                                                            \\\n      .SetCreateFn<UniqueKernel<device_type_v, OF_PP_PAIR_FIRST(data_type_pair),            \\\n                                OF_PP_PAIR_FIRST(indices_type_pair)>>()                     \\\n      .SetIsMatchedHob(                                                                     \\\n          (user_op::HobDeviceType() == device_type_v)                                       \\\n          && (user_op::HobDataType(\"x\", 0) == OF_PP_PAIR_SECOND(data_type_pair))            \\\n          && (user_op::HobDataType(\"idx\", 0) == OF_PP_PAIR_SECOND(indices_type_pair)))      \\\n      .SetInferTmpSizeFn(GenInferTmpSizeFn<device_type_v, OF_PP_PAIR_FIRST(data_type_pair), \\\n                                           OF_PP_PAIR_FIRST(indices_type_pair)>());\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_UNIQUE_KERNEL, DEVICE_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ)\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/unique_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/unique_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename KEY, typename IDX>\nstruct UniqueKernelUtil<DeviceType::kCPU, KEY, IDX> {\n  static void Unique(ep::Stream* stream, int64_t n, const KEY* in, IDX* num_unique, KEY* unique_out,\n                     IDX* idx_out, void* workspace, int64_t workspace_size_in_bytes, bool sorted) {\n    UniqueKernelUtil<DeviceType::kCPU, KEY, IDX>::UniqueWithCounts(\n        stream, n, in, num_unique, unique_out, idx_out, nullptr, workspace, workspace_size_in_bytes,\n        sorted);\n  }\n  static void UniqueWithCounts(ep::Stream* stream, int64_t n, const KEY* in, IDX* num_unique,\n                               KEY* unique_out, IDX* idx_out, IDX* count, void* workspace,\n                               int64_t workspace_size_in_bytes, bool sorted) {\n    std::vector<int64_t> sorted_idx(n);\n    std::iota(sorted_idx.begin(), sorted_idx.end(), 0);\n    if (sorted) {\n      std::sort(sorted_idx.begin(), sorted_idx.end(),\n                [&in](size_t a, size_t b) { return in[a] < in[b]; });\n    }\n\n    HashMap<KEY, IDX> map;\n    for (int64_t i : sorted_idx) {\n      KEY in_i = in[i];\n      auto it = map.find(in_i);\n      if (it == map.end()) {\n        IDX idx = map.size();\n        if (count != nullptr) { count[idx] = 1; }\n        idx_out[i] = idx;\n        unique_out[idx] = in_i;\n        map[in_i] = idx;\n      } else {\n        IDX idx = it->second;\n        if (count != nullptr) { count[idx] += 1; }\n        idx_out[i] = idx;\n      }\n    }\n    *num_unique = map.size();\n  }\n\n  static void GetUniqueWorkspaceSizeInBytes(ep::Stream* stream, int64_t n,\n                                            int64_t* workspace_size_in_bytes) {\n    *workspace_size_in_bytes = 1;\n  }\n  static void GetUniqueWithCountsWorkspaceSizeInBytes(ep::Stream* stream, int64_t n,\n                                                      int64_t* workspace_size_in_bytes) {\n    *workspace_size_in_bytes = 1;\n  }\n};\n\n#define INSTANTIATE_UNIQUE_KERNEL_UTIL_CPU(key_type_pair, idx_type_pair)              \\\n  template struct UniqueKernelUtil<DeviceType::kCPU, OF_PP_PAIR_FIRST(key_type_pair), \\\n                                   OF_PP_PAIR_FIRST(idx_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_UNIQUE_KERNEL_UTIL_CPU, ARITHMETIC_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ);\n#undef INSTANTIATE_UNIQUE_KERNEL_UTIL_CPU\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/unique_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/unique_kernel_util.h\"\n#include \"oneflow/core/cuda/unique.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr cuda::unique::Flag kUniqueFlag = cuda::unique::kOutputInverseIndices;\nconstexpr cuda::unique::Flag kUniqueWithCountsFlag =\n    cuda::unique::kOutputInverseIndices | cuda::unique::kOutputCounts;\n\n}  // namespace\n\ntemplate<typename KEY, typename IDX>\nstruct UniqueKernelUtil<DeviceType::kCUDA, KEY, IDX> {\n  static void Unique(ep::Stream* stream, int64_t n, const KEY* in, IDX* num_unique, KEY* unique_out,\n                     IDX* idx_out, void* workspace, int64_t workspace_size_in_bytes, bool sorted);\n  static void UniqueWithCounts(ep::Stream* stream, int64_t n, const KEY* in, IDX* num_unique,\n                               KEY* unique_out, IDX* idx_out, IDX* count, void* workspace,\n                               int64_t workspace_size_in_bytes, bool sorted);\n  static void GetUniqueWorkspaceSizeInBytes(ep::Stream* stream, int64_t n,\n                                            int64_t* workspace_size_in_bytes);\n  static void GetUniqueWithCountsWorkspaceSizeInBytes(ep::Stream* stream, int64_t n,\n                                                      int64_t* workspace_size_in_bytes);\n};\n\ntemplate<typename KEY, typename IDX>\nvoid UniqueKernelUtil<DeviceType::kCUDA, KEY, IDX>::Unique(\n    ep::Stream* stream, int64_t n, const KEY* in, IDX* num_unique, KEY* unique_out, IDX* idx_out,\n    void* workspace, int64_t workspace_size_in_bytes,\n    bool sorted /* not used, always return sorted output in CUDA,it`s the same as torch.unique*/) {\n  OF_CUDA_CHECK((cuda::unique::Launch<KEY, IDX>(kUniqueFlag, n, in, unique_out, num_unique, idx_out,\n                                                nullptr, workspace, workspace_size_in_bytes,\n                                                stream->As<ep::CudaStream>()->cuda_stream())));\n}\n\ntemplate<typename KEY, typename IDX>\nvoid UniqueKernelUtil<DeviceType::kCUDA, KEY, IDX>::UniqueWithCounts(\n    ep::Stream* stream, int64_t n, const KEY* in, IDX* num_unique, KEY* unique_out, IDX* idx_out,\n    IDX* count, void* workspace, int64_t workspace_size_in_bytes,\n    bool sorted /* not used, always return sorted output in CUDA,it`s the same as torch.unique*/) {\n  OF_CUDA_CHECK((cuda::unique::Launch<KEY, IDX>(\n      kUniqueWithCountsFlag, n, in, unique_out, num_unique, idx_out, count, workspace,\n      workspace_size_in_bytes, stream->As<ep::CudaStream>()->cuda_stream())));\n}\n\ntemplate<typename KEY, typename IDX>\nvoid UniqueKernelUtil<DeviceType::kCUDA, KEY, IDX>::GetUniqueWorkspaceSizeInBytes(\n    ep::Stream* stream, int64_t n, int64_t* workspace_size_in_bytes) {\n  size_t ws = 0;\n  OF_CUDA_CHECK((cuda::unique::GetWorkspaceSize<KEY, IDX>(kUniqueFlag, n, &ws)));\n  *workspace_size_in_bytes = 
static_cast<int64_t>(ws);\n}\n\ntemplate<typename KEY, typename IDX>\nvoid UniqueKernelUtil<DeviceType::kCUDA, KEY, IDX>::GetUniqueWithCountsWorkspaceSizeInBytes(\n    ep::Stream* stream, int64_t n, int64_t* workspace_size_in_bytes) {\n  size_t ws = 0;\n  OF_CUDA_CHECK((cuda::unique::GetWorkspaceSize<KEY, IDX>(kUniqueWithCountsFlag, n, &ws)));\n  *workspace_size_in_bytes = static_cast<int64_t>(ws);\n}\n\n#define INSTANTIATE_UNIQUE_KERNEL_UTIL_CUDA(key_type_pair, idx_type_pair)              \\\n  template struct UniqueKernelUtil<DeviceType::kCUDA, OF_PP_PAIR_FIRST(key_type_pair), \\\n                                   OF_PP_PAIR_FIRST(idx_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_UNIQUE_KERNEL_UTIL_CUDA, ARITHMETIC_DATA_TYPE_SEQ,\n                                 INDEX_DATA_TYPE_SEQ);\n#undef INSTANTIATE_UNIQUE_KERNEL_UTIL_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/unique_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_UNIQUE_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_UNIQUE_KERNEL_UTIL_H_\n\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename KEY, typename IDX>\nstruct UniqueKernelUtil {\n  static void Unique(ep::Stream* stream, int64_t n, const KEY* in, IDX* num_unique, KEY* unique_out,\n                     IDX* idx_out, void* workspace, int64_t workspace_size_in_bytes, bool sorted);\n  static void UniqueWithCounts(ep::Stream* stream, int64_t n, const KEY* in, IDX* num_unique,\n                               KEY* unique_out, IDX* idx_out, IDX* count, void* workspace,\n                               int64_t workspace_size_in_bytes, bool sorted);\n  static void GetUniqueWorkspaceSizeInBytes(ep::Stream* stream, int64_t n,\n                                            int64_t* workspace_size_in_bytes);\n  static void GetUniqueWithCountsWorkspaceSizeInBytes(ep::Stream* stream, int64_t n,\n                                                      int64_t* workspace_size_in_bytes);\n};\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_UNIQUE_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/unique_with_counts_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/unique_kernel_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass UniqueWithCountsKernel final : public user_op::OpKernel {\n public:\n  UniqueWithCountsKernel() = default;\n  ~UniqueWithCountsKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* idx = ctx->Tensor4ArgNameAndIndex(\"idx\", 0);\n    user_op::Tensor* count = ctx->Tensor4ArgNameAndIndex(\"count\", 0);\n    user_op::Tensor* num_unique = ctx->Tensor4ArgNameAndIndex(\"num_unique\", 0);\n    user_op::Tensor* tmp = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    const bool& sorted = ctx->Attr<bool>(\"sorted\");\n    void* tmp_ptr = tmp ? tmp->mut_dptr() : nullptr;\n    int64_t tmp_size = tmp ? tmp->shape_view().elem_cnt() * GetSizeOfDataType(tmp->data_type()) : 0;\n    UniqueKernelUtil<device_type, T, K>::UniqueWithCounts(\n        ctx->stream(), x->shape_view().elem_cnt(), x->dptr<T>(), num_unique->mut_dptr<K>(),\n        y->mut_dptr<T>(), idx->mut_dptr<K>(), count->mut_dptr<K>(), tmp_ptr, tmp_size, sorted);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<DeviceType device_type, typename T, typename K>\nuser_op::InferTmpSizeFn GenInferTmpSizeFn() {\n  return [](user_op::InferContext* ctx) {\n    const auto& x = ctx->InputTensorDesc(\"x\", 0);\n    int64_t workspace_size_in_bytes;\n    UniqueKernelUtil<device_type, T, K>::GetUniqueWithCountsWorkspaceSizeInBytes(\n        nullptr, x.shape().elem_cnt(), &workspace_size_in_bytes);\n\n    return workspace_size_in_bytes;\n  };\n}\n\n#define REGISTER_UNIQUE_WITH_COUNTS_KERNEL(device_type_v, data_type_pair, indices_type_pair) \\\n  REGISTER_USER_KERNEL(\"unique_with_counts\")                                                 \\\n      .SetCreateFn<UniqueWithCountsKernel<device_type_v, OF_PP_PAIR_FIRST(data_type_pair),   \\\n                                          OF_PP_PAIR_FIRST(indices_type_pair)>>()            \\\n      .SetIsMatchedHob(                                                                      \\\n          (user_op::HobDeviceType() == device_type_v)                                        \\\n          && (user_op::HobDataType(\"x\", 0) == OF_PP_PAIR_SECOND(data_type_pair))             \\\n          && (user_op::HobDataType(\"idx\", 0) == OF_PP_PAIR_SECOND(indices_type_pair)))       \\\n      .SetInferTmpSizeFn(GenInferTmpSizeFn<device_type_v, OF_PP_PAIR_FIRST(data_type_pair),  \\\n                                           OF_PP_PAIR_FIRST(indices_type_pair)>());\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_UNIQUE_WITH_COUNTS_KERNEL, DEVICE_TYPE_SEQ,\n                
                 ARITHMETIC_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/unpack_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<DeviceType device_type>\nclass UnpackKernel final : public user_op::OpKernel {\n public:\n  UnpackKernel() = default;\n  ~UnpackKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(\n      user_op::KernelInitContext* ctx) const override {\n    return std::make_shared<OpKernelStateWrapper<std::pair<size_t, size_t>>>(\n        std::make_pair<size_t, size_t>(0, 0));\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,\n               const user_op::OpKernelCache*) const override {\n    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(\"in\", 0);\n    CHECK_GT(in->shape_view().NumAxes(), 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(in->data_type(), out->data_type());\n    CHECK_EQ(in->shape_view().NumAxes(), out->shape_view().NumAxes());\n    const auto unpack_num = ctx->Attr<int32_t>(\"unpack_num\");\n    CHECK_EQ(out->shape_view().At(0) * unpack_num, in->shape_view().At(0));\n    for (int64_t i = 1; i < in->shape_view().NumAxes(); ++i) {\n      CHECK_EQ(out->shape_view().At(i), in->shape_view().At(i));\n    }\n    const int64_t copy_size = out->shape_view().elem_cnt() * GetSizeOfDataType(out->data_type());\n    auto* state_wrapper = dynamic_cast<OpKernelStateWrapper<std::pair<size_t, size_t>>*>(state);\n    CHECK_NOTNULL(state_wrapper);\n    const size_t index = state_wrapper->Get().first;\n    CHECK_EQ(state_wrapper->Get().second, unpack_num);\n    Memcpy<device_type>(ctx->stream(), out->mut_dptr<char>(), in->dptr<char>() + index * copy_size,\n                        copy_size);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UNPACK_KERNEL(device)                                                \\\n  REGISTER_USER_KERNEL(\"unpack\").SetCreateFn<UnpackKernel<device>>().SetIsMatchedHob( \\\n      (user_op::HobDeviceType() == device));\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_UNPACK_KERNEL, DEVICE_TYPE_SEQ)\n#if defined(WITH_MLU)\nREGISTER_UNPACK_KERNEL(DeviceType::kMLU)\n#endif\n#undef REGISTER_UNPACK_KERNEL\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/unsorted_batch_segment_sum_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/batch_gather_kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\nShape GetFlatShape(const ShapeView& shape, const int64_t axis) {\n  CHECK_GT(shape.NumAxes(), 0);\n  CHECK_GE(axis, 0);\n  CHECK_LT(axis, shape.NumAxes());\n  return Shape({shape.Count(0, axis), shape.At(axis), shape.Count(axis + 1)});\n}\n\n}  // namespace\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass UnsortedBatchSegmentSumKernel final : public user_op::OpKernel,\n                                            public user_op::CudaGraphSupport {\n public:\n  UnsortedBatchSegmentSumKernel() = default;\n  ~UnsortedBatchSegmentSumKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* data = ctx->Tensor4ArgNameAndIndex(\"data\", 0);\n    const user_op::Tensor* segment_ids = ctx->Tensor4ArgNameAndIndex(\"segment_ids\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t axis = segment_ids->shape_view().NumAxes() - 1;\n    const Shape& flat_data_shape = GetFlatShape(data->shape_view(), axis);\n\n    Memset<device_type>(ctx->stream(), out->mut_dptr(), 0,\n                        out->shape_view().elem_cnt() * sizeof(T));\n    BatchGatherKernelUtilImpl<device_type, T, K>::Backward(\n        ctx->stream(), data->dptr<T>(), segment_ids->dptr<K>(), flat_data_shape,\n        out->shape_view().At(axis), out->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_UNSORTED_BATCH_SEGMENT_SUM_KERNEL(device, out_dtype, segment_ids_dtype)      \\\n  REGISTER_USER_KERNEL(\"unsorted_batch_segment_sum\")                                          \\\n      .SetCreateFn<UnsortedBatchSegmentSumKernel<device, OF_PP_PAIR_FIRST(out_dtype),         \\\n                                                 OF_PP_PAIR_FIRST(segment_ids_dtype)>>()      \\\n      .SetIsMatchedHob(                                                                       \\\n          (user_op::HobDeviceType() == device)                                                \\\n          && (user_op::HobDataType(\"segment_ids\", 0) == OF_PP_PAIR_SECOND(segment_ids_dtype)) \\\n          && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(out_dtype)));\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_UNSORTED_BATCH_SEGMENT_SUM_KERNEL, DEVICE_TYPE_SEQ,\n                                 FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/unsorted_segment_sum_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/unsorted_segment_sum_kernel_util.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/ep/include/primitive/cast.h\"\n#ifdef WITH_CUDA\n#include <cuda.h>\n#endif\n\nnamespace oneflow {\n\nnamespace user_op {\n\nnamespace {\n\nvoid CheckNdSbp(const Shape& hierarchy, int64_t sum_axis, const NdSbp& segment_ids_nd_sbp,\n                const NdSbp& data_nd_sbp, const NdSbp& out_nd_sbp) {\n  CHECK_EQ(hierarchy.NumAxes(), segment_ids_nd_sbp.sbp_parallel_size());\n  CHECK_EQ(hierarchy.NumAxes(), data_nd_sbp.sbp_parallel_size());\n  CHECK_EQ(hierarchy.NumAxes(), out_nd_sbp.sbp_parallel_size());\n  if (hierarchy.elem_cnt() == 1) { return; }\n  FOR_RANGE(int64_t, i, 0, hierarchy.NumAxes()) {\n    const auto& out_sbp = out_nd_sbp.sbp_parallel(i);\n    if (out_sbp.has_split_parallel() && out_sbp.split_parallel().axis() == sum_axis) {\n      CHECK(segment_ids_nd_sbp.sbp_parallel(i).has_broadcast_parallel());\n      CHECK(data_nd_sbp.sbp_parallel(i).has_broadcast_parallel());\n    }\n  }\n}\n\nclass UnsortedSegmentSumOpKernelCache final : public user_op::OpKernelCache {\n public:\n  UnsortedSegmentSumOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {}\n  ~UnsortedSegmentSumOpKernelCache() override = default;\n\n  int64_t lower() const { return lower_; }\n  int64_t upper() const { return upper_; }\n\n private:\n  const int64_t lower_;\n  const int64_t upper_;\n};\n\nstd::shared_ptr<user_op::OpKernelCache> CreateUnsortedSegmentSumOpKernelCache(\n    user_op::KernelCacheContext* ctx) {\n  if (ctx->parallel_ctx().parallel_num() > 1) {\n    const auto axis = ctx->Attr<int64_t>(\"axis\");\n    const NdSbp& out_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n    const Shape& hierarchy = *ctx->parallel_desc().hierarchy();\n    CheckNdSbp(hierarchy, axis, ctx->NdSbp4ArgNameAndIndex(\"segment_ids\", 0),\n               ctx->NdSbp4ArgNameAndIndex(\"data\", 0), out_nd_sbp);\n    const TensorDesc* out_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex(\"out\", 0);\n    TensorSliceView view = GetTensorSliceView4ParallelId(\n        hierarchy, out_nd_sbp, out_logical_desc->shape(), ctx->parallel_ctx().parallel_id());\n    return std::make_shared<UnsortedSegmentSumOpKernelCache>(view.At(axis).begin(),\n                                                             view.At(axis).end());\n  } else {\n    return nullptr;\n  }\n}\n\n}  // namespace\n\ntemplate<DeviceType device_type, typename T, typename K>\nclass UnsortedSegmentSumKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  UnsortedSegmentSumKernel() = default;\n  ~UnsortedSegmentSumKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n      user_op::KernelCacheContext* ctx) const override {\n    
return CreateUnsortedSegmentSumOpKernelCache(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* data = ctx->Tensor4ArgNameAndIndex(\"data\", 0);\n    const user_op::Tensor* segment_ids = ctx->Tensor4ArgNameAndIndex(\"segment_ids\", 0);\n    int64_t axis = ctx->Attr<int64_t>(\"axis\");\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    int64_t outer_dim_size = out->shape_view().Count(0, axis);\n    int64_t num_segments = out->shape_view().At(axis);\n    int64_t inner_dim_size = out->shape_view().Count(axis + 1);\n    int64_t num_segment_ids = segment_ids->shape_view().elem_cnt();\n    Memset<device_type>(ctx->stream(), out->mut_dptr(), 0,\n                        out->shape_view().elem_cnt() * sizeof(T));\n\n    int64_t offset = 0;\n    if (cache != nullptr) {\n      auto* sum_cache = dynamic_cast<const UnsortedSegmentSumOpKernelCache*>(cache);\n      CHECK_NOTNULL(sum_cache);\n      CHECK_EQ(out->shape_view().At(axis), sum_cache->upper() - sum_cache->lower());\n      offset = sum_cache->lower();\n    }\n\n    if (num_segment_ids != 0) {\n      UnsortedSegmentSumKernelUtil<device_type, T, K, T>::UnsortedSegmentSum(\n          ctx->stream(), segment_ids->dptr<K>(), data->dptr<T>(), num_segment_ids, num_segments,\n          outer_dim_size, inner_dim_size, offset, out->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_UNSORTED_SEGMENT_SUM_KERNEL(device, out_type, segment_ids_type, kernel_type) \\\n  REGISTER_USER_KERNEL(kernel_type)                                                           \\\n      .SetCreateFn<UnsortedSegmentSumKernel<device, OF_PP_PAIR_FIRST(out_type),               \\\n                                            OF_PP_PAIR_FIRST(segment_ids_type)>>()            \\\n      .SetIsMatchedHob(                                                                       \\\n          (user_op::HobDeviceType() == device)                                                \\\n          && (user_op::HobDataType(\"segment_ids\", 0) == OF_PP_PAIR_SECOND(segment_ids_type))  \\\n          && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(out_type)));\n\n#define REGISTER_UNSORTED_SEGMENT_SUM_KERNEL_CASE(device_type, out_type, segment_ids_type) \\\n  REGISTER_UNSORTED_SEGMENT_SUM_KERNEL(device_type, out_type, segment_ids_type,            \\\n                                       (\"unsorted_segment_sum\"))\n\n#define REGISTER_UNSORTED_SEGMENT_SUM_LIKE_KERNEL_CASE(device_type, out_type, segment_ids_type) \\\n  REGISTER_UNSORTED_SEGMENT_SUM_KERNEL(device_type, out_type, segment_ids_type,                 \\\n                                       (\"unsorted_segment_sum_like\"))\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_UNSORTED_SEGMENT_SUM_KERNEL_CASE, DEVICE_TYPE_SEQ,\n                                 UNSORTED_SEGMENT_SUM_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_UNSORTED_SEGMENT_SUM_LIKE_KERNEL_CASE, DEVICE_TYPE_SEQ,\n                                 UNSORTED_SEGMENT_SUM_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n#ifdef WITH_CUDA\ntemplate<typename T, typename K>\nclass UnsortedSegmentSumHalfKernel final : public user_op::OpKernel {\n public:\n  UnsortedSegmentSumHalfKernel() = default;\n  ~UnsortedSegmentSumHalfKernel() override = default;\n\n  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(\n     
 user_op::KernelCacheContext* ctx) const override {\n    return CreateUnsortedSegmentSumOpKernelCache(ctx);\n  }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,\n               const user_op::OpKernelCache* cache) const override {\n    const user_op::Tensor* data = ctx->Tensor4ArgNameAndIndex(\"data\", 0);\n    const user_op::Tensor* segment_ids = ctx->Tensor4ArgNameAndIndex(\"segment_ids\", 0);\n    int64_t axis = ctx->Attr<int64_t>(\"axis\");\n    user_op::Tensor* tmp_buf = ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    int64_t outer_dim_size = out->shape_view().Count(0, axis);\n    int64_t num_segments = out->shape_view().At(axis);\n    int64_t inner_dim_size = out->shape_view().Count(axis + 1);\n    int64_t num_segment_ids = segment_ids->shape_view().elem_cnt();\n    Memset<DeviceType::kCUDA>(ctx->stream(), tmp_buf->mut_dptr(), 0,\n                              out->shape_view().elem_cnt() * sizeof(float));\n    int64_t offset = 0;\n    if (cache != nullptr) {\n      auto* sum_cache = dynamic_cast<const UnsortedSegmentSumOpKernelCache*>(cache);\n      CHECK_NOTNULL(sum_cache);\n      CHECK_EQ(out->shape_view().At(axis), sum_cache->upper() - sum_cache->lower());\n      offset = sum_cache->lower();\n    }\n\n    UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, float, K, T>::UnsortedSegmentSum(\n        ctx->stream(), segment_ids->dptr<K>(), data->dptr<T>(), num_segment_ids, num_segments,\n        outer_dim_size, inner_dim_size, offset, tmp_buf->mut_dptr<float>());\n\n    auto f2h = ep::primitive::NewPrimitive<ep::primitive::CastFactory>(\n        ctx->device_type(), DataType::kFloat, out->data_type());\n    CHECK(f2h);\n    f2h->Launch(ctx->stream(), tmp_buf->dptr<float>(), out->mut_dptr<T>(),\n                out->shape_view().elem_cnt());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }\n};\n\n#define REGISTER_UNSORTED_SEGMENT_SUM_HALF_HALF_KERNEL(out_type, segment_ids_type, kernel_type) \\\n  REGISTER_USER_KERNEL(kernel_type)                                                             \\\n      .SetCreateFn<UnsortedSegmentSumHalfKernel<OF_PP_PAIR_FIRST(out_type),                     \\\n                                                OF_PP_PAIR_FIRST(segment_ids_type)>>()          \\\n      .SetIsMatchedHob(                                                                         \\\n          (user_op::HobDeviceType() == DeviceType::kCUDA)                                       \\\n          && (user_op::HobDataType(\"segment_ids\", 0) == OF_PP_PAIR_SECOND(segment_ids_type))    \\\n          && (user_op::HobDataType(\"out\", 0) == OF_PP_PAIR_SECOND(out_type)))                   \\\n      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                       \\\n        const Shape& out_shape = ctx->OutputShape(\"out\", 0);                                    \\\n        return GetCudaAlignedSize(out_shape.elem_cnt() * sizeof(float));                        \\\n      });\n\n#define REGISTER_UNSORTED_SEGMENT_SUM_HALF_KERNEL_CASE(out_type, segment_ids_type) \\\n  REGISTER_UNSORTED_SEGMENT_SUM_HALF_HALF_KERNEL(out_type, segment_ids_type,       \\\n                                                 (\"unsorted_segment_sum\"))         \\\n  REGISTER_UNSORTED_SEGMENT_SUM_HALF_HALF_KERNEL(out_type, segment_ids_type,       \\\n                                                 
(\"unsorted_segment_sum_like\"))\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_UNSORTED_SEGMENT_SUM_HALF_KERNEL_CASE,\n                                 FLOAT16_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ)\n\n#if CUDA_VERSION >= 11000\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_UNSORTED_SEGMENT_SUM_HALF_KERNEL_CASE,\n                                 OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16),\n                                 INDEX_DATA_TYPE_SEQ)\n#endif\n\n#undef REGISTER_UNSORTED_SEGMENT_SUM_HALF_KERNEL_CASE\n\n#endif  // WITH_CUDA\n\n}  // namespace user_op\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/unsorted_segment_sum_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/unsorted_segment_sum_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename K>\nstruct UnsortedSegmentSumKernelUtil<DeviceType::kCPU, T, K, T> final {\n  static void UnsortedSegmentSum(ep::Stream* stream, const K* segment_ids, const T* data,\n                                 int64_t num_segment_ids, int64_t num_segments,\n                                 int64_t outer_dim_size, int64_t inner_dim_size,\n                                 int64_t segment_id_offset, T* out);\n};\n\ntemplate<typename T, typename K>\nvoid UnsortedSegmentSumKernelUtil<DeviceType::kCPU, T, K, T>::UnsortedSegmentSum(\n    ep::Stream* stream, const K* segment_ids, const T* data, int64_t num_segment_ids,\n    int64_t num_segments, int64_t outer_dim_size, int64_t inner_dim_size, int64_t segment_id_offset,\n    T* out) {\n  FOR_RANGE(int64_t, outer_idx, 0, outer_dim_size) {\n    FOR_RANGE(int64_t, i, 0, num_segment_ids) {\n      CHECK_GE(segment_ids[i], 0);\n      const int64_t idx = segment_ids[i] - segment_id_offset;\n      T* to = out + outer_idx * num_segments * inner_dim_size + idx * inner_dim_size;\n      if (idx >= 0 && idx < num_segments) {\n        const T* from = data + outer_idx * num_segment_ids * inner_dim_size + i * inner_dim_size;\n        std::transform(from, from + inner_dim_size, to, to, std::plus<T>());\n      }\n    }\n  }\n}\n#define INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_CPU(in_type_pair, index_type_pair)             \\\n  template struct UnsortedSegmentSumKernelUtil<DeviceType::kCPU, OF_PP_PAIR_FIRST(in_type_pair), \\\n                                               OF_PP_PAIR_FIRST(index_type_pair),                \\\n                                               OF_PP_PAIR_FIRST(in_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_CPU,\n                                 UNSORTED_SEGMENT_SUM_DATA_TYPE_SEQ,\n                                 UNSORTED_SEGMENT_SUM_INDEX_TYPE_SEQ);\n\n#undef INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_CPU\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/unsorted_segment_sum_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/user/kernels/unsorted_segment_sum_kernel_util.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/core/kernel/kernel.h\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n#include <assert.h>\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__device__ __forceinline__ bool IsZero(T v) {\n  return v == 0;\n}\n\ntemplate<>\n__device__ __forceinline__ bool IsZero<half>(half v) {\n  return v == static_cast<half>(0);\n}\n\n#if CUDA_VERSION >= 11000\n\ntemplate<>\n__device__ __forceinline__ bool IsZero<nv_bfloat16>(nv_bfloat16 v) {\n  return v == __float2bfloat16(0);\n}\n\n#endif\n\ntemplate<>\n__device__ __forceinline__ bool IsZero<half2>(half2 v) {\n  return v.x == static_cast<half>(0) && v.y == static_cast<half>(0);\n}\n\ntemplate<typename T, typename K, typename IDX, typename U>\n__global__ void UnsortedSegmentSumGpu(const IDX data_elem_cnt,\n                                      const NdIndexOffsetHelper<IDX, 3> in_helper,\n                                      const NdIndexOffsetHelper<IDX, 3> out_helper, const U* data,\n                                      const K* segment_ids, const IDX num_segments,\n                                      const IDX segment_id_offset, T* out) {\n  CUDA_1D_KERNEL_LOOP_T(IDX, i, data_elem_cnt) {\n    const U val = data[i];\n    if (!IsZero(val)) {\n      IDX outer_idx, segment_id_idx, inner_idx;\n      in_helper.OffsetToNdIndex(i, outer_idx, segment_id_idx, inner_idx);\n      const K origin_idx = segment_ids[segment_id_idx];\n      assert(origin_idx >= 0);\n      const IDX idx = origin_idx - segment_id_offset;\n      if (idx >= 0 && idx < num_segments) {\n        const int64_t out_offset = out_helper.NdIndexToOffset(outer_idx, idx, inner_idx);\n        if (out_offset >= 0) { cuda::atomic::Add(out + out_offset, static_cast<T>(val)); }\n      }\n    }\n  }\n}\n\ntemplate<typename T, typename K, typename IDX, typename U>\n__global__ void UnsortedSegmentColSumGpu(const IDX data_elem_cnt,\n                                         const NdIndexOffsetHelper<IDX, 2> in_helper,\n                                         const NdIndexOffsetHelper<IDX, 2> out_helper,\n                                         const U* data, const K* segment_ids,\n                                         const IDX num_segments, const IDX segment_id_offset,\n                                         T* out) {\n  CUDA_1D_KERNEL_LOOP_T(IDX, i, data_elem_cnt) {\n    const U val = data[i];\n    if (!IsZero(val)) {\n      IDX outer_idx, segment_id_idx;\n      in_helper.OffsetToNdIndex(i, outer_idx, segment_id_idx);\n      const K origin_idx = segment_ids[segment_id_idx];\n      assert(origin_idx >= 0);\n      const IDX idx = origin_idx - segment_id_offset;\n      if (idx >= 0 && idx < num_segments) {\n        const int64_t out_offset = out_helper.NdIndexToOffset(outer_idx, 
idx);\n        if (out_offset >= 0) { cuda::atomic::Add(out + out_offset, static_cast<T>(val)); }\n      }\n    }\n  }\n}\n\ntemplate<typename T, typename K, typename IDX, typename U>\n__global__ void UnsortedSegmentRowSumGpu(const IDX data_elem_cnt,\n                                         const NdIndexOffsetHelper<IDX, 2> in_helper,\n                                         const NdIndexOffsetHelper<IDX, 2> out_helper,\n                                         const U* data, const K* segment_ids,\n                                         const IDX num_segments, const IDX segment_id_offset,\n                                         T* out) {\n  CUDA_1D_KERNEL_LOOP_T(IDX, i, data_elem_cnt) {\n    const U val = data[i];\n    if (!IsZero(val)) {\n      IDX segment_id_idx, inner_idx;\n      in_helper.OffsetToNdIndex(i, segment_id_idx, inner_idx);\n      const K origin_idx = segment_ids[segment_id_idx];\n      assert(origin_idx >= 0);\n      const IDX idx = origin_idx - segment_id_offset;\n      if (idx >= 0 && idx < num_segments) {\n        const int64_t out_offset = out_helper.NdIndexToOffset(idx, inner_idx);\n        if (out_offset >= 0) { cuda::atomic::Add(out + out_offset, static_cast<T>(val)); }\n      }\n    }\n  }\n}\n\ntemplate<typename T, typename K, typename IDX, typename U>\nvoid UnsortedSegmentSumUtil(ep::Stream* stream, const K* segment_ids, const U* data,\n                            IDX num_segment_ids, IDX num_segments, IDX outer_dim_size,\n                            IDX inner_dim_size, IDX segment_id_offset, T* out) {\n  const IDX data_elem_cnt = num_segment_ids * outer_dim_size * inner_dim_size;\n  if (inner_dim_size == 1) {\n    NdIndexOffsetHelper<IDX, 2> in_helper(outer_dim_size, num_segment_ids);\n    NdIndexOffsetHelper<IDX, 2> out_helper(outer_dim_size, num_segments);\n    UnsortedSegmentColSumGpu<T, K, IDX, U>\n        <<<BlocksNum4ThreadsNum(data_elem_cnt), kCudaThreadsNumPerBlock, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(data_elem_cnt, in_helper, out_helper,\n                                                          data, segment_ids, num_segments,\n                                                          segment_id_offset, out);\n\n  } else if (outer_dim_size == 1) {\n    NdIndexOffsetHelper<IDX, 2> in_helper(num_segment_ids, inner_dim_size);\n    NdIndexOffsetHelper<IDX, 2> out_helper(num_segments, inner_dim_size);\n    UnsortedSegmentRowSumGpu<T, K, IDX, U>\n        <<<BlocksNum4ThreadsNum(data_elem_cnt), kCudaThreadsNumPerBlock, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(data_elem_cnt, in_helper, out_helper,\n                                                          data, segment_ids, num_segments,\n                                                          segment_id_offset, out);\n\n  } else {\n    NdIndexOffsetHelper<IDX, 3> in_helper(outer_dim_size, num_segment_ids, inner_dim_size);\n    NdIndexOffsetHelper<IDX, 3> out_helper(outer_dim_size, num_segments, inner_dim_size);\n    UnsortedSegmentSumGpu<T, K, IDX, U>\n        <<<BlocksNum4ThreadsNum(data_elem_cnt), kCudaThreadsNumPerBlock, 0,\n           stream->As<ep::CudaStream>()->cuda_stream()>>>(data_elem_cnt, in_helper, out_helper,\n                                                          data, segment_ids, num_segments,\n                                                          segment_id_offset, out);\n  }\n}\n\ntemplate<typename T, typename K, typename IDX, typename U>\nvoid DispatchDataType(ep::Stream* stream, const K* segment_ids, const U* data,\n              
        int64_t num_segment_ids, int64_t num_segments, int64_t outer_dim_size,\n                      int64_t inner_dim_size, int64_t segment_id_offset, T* out) {\n  auto* cuda_stream = stream->As<ep::CudaStream>();\n  if (std::is_same<T, half>::value && std::is_same<U, half>::value\n      && cuda_stream->device_properties().major >= 6\n      && reinterpret_cast<uintptr_t>(data) % sizeof(half2) == 0\n      && reinterpret_cast<uintptr_t>(out) % sizeof(half2) == 0 && inner_dim_size % 2 == 0) {\n    UnsortedSegmentSumUtil<half2, K, IDX, half2>(\n        stream, segment_ids, reinterpret_cast<const half2*>(data), num_segment_ids, num_segments,\n        outer_dim_size, inner_dim_size / 2, segment_id_offset, reinterpret_cast<half2*>(out));\n  } else {\n    UnsortedSegmentSumUtil<T, K, IDX, U>(stream, segment_ids, data, num_segment_ids, num_segments,\n                                         outer_dim_size, inner_dim_size, segment_id_offset, out);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T, typename K, typename U>\nstruct UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, T, K, U> final {\n  static void UnsortedSegmentSum(ep::Stream* stream, const K* segment_ids, const U* data,\n                                 int64_t num_segment_ids, int64_t num_segments,\n                                 int64_t outer_dim_size, int64_t inner_dim_size,\n                                 int64_t segment_id_offset, T* out) {\n    const int64_t data_elem_cnt = num_segment_ids * outer_dim_size * inner_dim_size;\n    const int64_t out_elem_cnt = outer_dim_size * num_segments * inner_dim_size;\n\n    if (std::max(data_elem_cnt, out_elem_cnt) < GetMaxVal<int32_t>() / 2) {\n      DispatchDataType<T, K, int32_t, U>(stream, segment_ids, data, num_segment_ids, num_segments,\n                                         outer_dim_size, inner_dim_size, segment_id_offset, out);\n    } else {\n      DispatchDataType<T, K, int64_t, U>(stream, segment_ids, data, num_segment_ids, num_segments,\n                                         outer_dim_size, inner_dim_size, segment_id_offset, out);\n    }\n  }\n};\n\ntemplate<typename K>\nstruct UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, float, K, float16> final {\n  static void UnsortedSegmentSum(ep::Stream* stream, const K* segment_ids, const float16* data,\n                                 int64_t num_segment_ids, int64_t num_segments,\n                                 int64_t outer_dim_size, int64_t inner_dim_size,\n                                 int64_t segment_id_offset, float* out) {\n    UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, float, K, half>::UnsortedSegmentSum(\n        stream, segment_ids, reinterpret_cast<const half*>(data), num_segment_ids, num_segments,\n        outer_dim_size, inner_dim_size, segment_id_offset, out);\n  }\n};\n\n#define INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_CUDA(in_type_pair, index_type_pair)             \\\n  template struct UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, OF_PP_PAIR_FIRST(in_type_pair), \\\n                                               OF_PP_PAIR_FIRST(index_type_pair),                 \\\n                                               OF_PP_PAIR_FIRST(in_type_pair)>;\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_CUDA,\n                                 UNSORTED_SEGMENT_SUM_DATA_TYPE_SEQ,\n                                 UNSORTED_SEGMENT_SUM_INDEX_TYPE_SEQ);\n#undef INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_CUDA\n\n#define INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_HALF_CUDA(in_type_pair, 
index_type_pair,             \\\n                                                       out_type_pair)                             \\\n  template struct UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, OF_PP_PAIR_FIRST(in_type_pair), \\\n                                               OF_PP_PAIR_FIRST(index_type_pair),                 \\\n                                               OF_PP_PAIR_FIRST(out_type_pair)>;\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_HALF_CUDA,\n                                 OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat),\n                                 UNSORTED_SEGMENT_SUM_INDEX_TYPE_SEQ, FLOAT16_DATA_TYPE_SEQ);\n#if CUDA_VERSION >= 11000\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_HALF_CUDA,\n                                 OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat),\n                                 UNSORTED_SEGMENT_SUM_INDEX_TYPE_SEQ,\n                                 OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16));\n#endif\n\n#undef INITIATE_UNSORTED_SEGMENT_SUM_KERNEL_HALF_CUDA\n\ntemplate struct UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, half, uint32_t, half>;\ntemplate struct UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, half, int32_t, half>;\ntemplate struct UnsortedSegmentSumKernelUtil<DeviceType::kCUDA, half, int64_t, half>;\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/unsorted_segment_sum_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_KERNELS_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_H_\n#define ONEFLOW_CORE_KERNELS_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_H_\n\n#include \"oneflow/core/kernel/kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, typename K, typename U>\nstruct UnsortedSegmentSumKernelUtil final {\n  static void UnsortedSegmentSum(ep::Stream* stream, const K* segment_ids, const U* data,\n                                 int64_t num_segment_ids, int64_t num_segments,\n                                 int64_t outer_dim_size, int64_t inner_dim_size,\n                                 int64_t segment_id_offset, T* out);\n};\n\n#define UNSORTED_SEGMENT_SUM_DATA_TYPE_SEQ \\\n  FLOATING_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)\n#define UNSORTED_SEGMENT_SUM_INDEX_TYPE_SEQ \\\n  INDEX_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint32_t, DataType::kUInt32)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_KERNELS_UNSORTED_SEGMENT_SUM_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/upsample_bicubic_2d_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/user/kernels/upsample_kernel.h\"\n\nnamespace oneflow {\n\ntemplate<typename T>\nclass UpsampleBicubic2dCPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleBicubic2dCPUKernel() = default;\n  ~UpsampleBicubic2dCPUKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n\n    const T* in_ptr = x_tensor->dptr<T>();\n    T* out_ptr = y_tensor->mut_dptr<T>();\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n    const int nbatch = x_tensor->shape_view().At(0);\n    const int channels = x_tensor->shape_view().At(1);\n\n    const int64_t in_height = x_tensor->shape_view().At(2);\n    const int64_t in_width = x_tensor->shape_view().At(3);\n    const int64_t out_height = y_tensor->shape_view().At(2);\n    const int64_t out_width = y_tensor->shape_view().At(3);\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n\n    if (in_height == out_height && in_width == out_width) {\n      memcpy(out_ptr, in_ptr, sizeof(T) * nbatch * channels * in_height * in_width);\n    } else {\n      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);\n      const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);\n\n      for (int64_t output_y = 0; output_y < out_height; output_y++) {\n        for (int64_t output_x = 0; output_x < out_width; output_x++) {\n          const T* in = in_ptr;\n          T* out = out_ptr;\n\n          const T real_x = GetAreaPixel(scale_width, output_x, align_corners, /*cubic=*/true);\n          int64_t input_x = std::floor(real_x);\n          const T t_x = real_x - input_x;\n\n          const T real_y = GetAreaPixel(scale_height, output_y, align_corners, /*cubic=*/true);\n          int64_t input_y = std::floor(real_y);\n          const T t_y = real_y - input_y;\n\n          for (int64_t c = 0; c < channels * nbatch; c++) {\n            T coefficients[4];\n\n            // Interpolate 4 times in the x direction\n            for (int64_t i = 0; i < 4; i++) {\n              coefficients[i] =\n                  cubic_interp1d<T>(upsample_get_value_bounded<T>(in, in_width, in_height,\n                                                     
             input_x - 1, input_y - 1 + i),\n                                    upsample_get_value_bounded<T>(in, in_width, in_height,\n                                                                  input_x + 0, input_y - 1 + i),\n                                    upsample_get_value_bounded<T>(in, in_width, in_height,\n                                                                  input_x + 1, input_y - 1 + i),\n                                    upsample_get_value_bounded<T>(in, in_width, in_height,\n                                                                  input_x + 2, input_y - 1 + i),\n                                    t_x);\n            }\n\n            // Interpolate in the y direction using x interpolations\n            out[output_y * out_width + output_x] = cubic_interp1d<T>(\n                coefficients[0], coefficients[1], coefficients[2], coefficients[3], t_y);\n\n            // Move to next channel\n            in += in_width * in_height;\n            out += out_width * out_height;\n          }\n        }\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass UpsampleBicubic2dGradCPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleBicubic2dGradCPUKernel() = default;\n  ~UpsampleBicubic2dGradCPUKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    Memset<DeviceType::kCPU>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,\n                             dx_tensor->shape_view().elem_cnt() * sizeof(T));\n    user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    T* in_ptr = dx_tensor->mut_dptr<T>();\n    const T* out_ptr = dy_tensor->dptr<T>();\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n    const int nbatch = dx_tensor->shape_view().At(0);\n    int channels = dx_tensor->shape_view().At(1);\n    channels = channels * nbatch;\n\n    const int64_t in_height = dx_tensor->shape_view().At(2);\n    const int64_t in_width = dx_tensor->shape_view().At(3);\n    const int64_t out_height = dy_tensor->shape_view().At(2);\n    const int64_t out_width = dy_tensor->shape_view().At(3);\n\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n    if (in_height == out_height && in_width == out_width) {\n      memcpy(in_ptr, out_ptr, sizeof(T) * channels * in_height * in_width);\n    } else {\n      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);\n      const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);\n\n      for (int64_t output_y = 0; output_y < out_height; output_y++) {\n        for (int64_t output_x = 0; output_x < out_width; output_x++) {\n          T* in = in_ptr;\n          const T* out = out_ptr;\n\n          T real_x = GetAreaPixel(scale_width, output_x, align_corners, true);\n          int64_t input_x = std::floor(real_x);\n          T t_x = real_x - input_x;\n\n          T real_y = GetAreaPixel(scale_height, output_y, align_corners, true);\n          int64_t 
input_y = std::floor(real_y);\n          T t_y = real_y - input_y;\n\n          T x_coeffs[4];\n          T y_coeffs[4];\n\n          get_cubic_upsample_coefficients<T>(x_coeffs, t_x);\n          get_cubic_upsample_coefficients<T>(y_coeffs, t_y);\n\n          for (int64_t c = 0; c < channels; c++) {\n            T out_value = out[output_y * out_width + output_x];\n\n            for (int64_t i = 0; i < 4; i++) {\n              for (int64_t j = 0; j < 4; j++) {\n                upsample_increment_value_bounded<T>(in, in_width, in_height, input_x - 1 + i,\n                                                    input_y - 1 + j,\n                                                    out_value * y_coeffs[j] * x_coeffs[i]);\n              }\n            }\n\n            in += in_width * in_height;\n            out += out_width * out_height;\n          }\n        }\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UPSAMPLE_BICUBIC_CPU_KERNEL(dtype)                                     \\\n  REGISTER_USER_KERNEL(\"upsample_bicubic_2d\")                                           \\\n      .SetCreateFn<UpsampleBicubic2dCPUKernel<dtype>>()                                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"upsample_bicubic_2d_grad\")                                      \\\n      .SetCreateFn<UpsampleBicubic2dGradCPUKernel<dtype>>()                             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UPSAMPLE_BICUBIC_CPU_KERNEL(float)\nREGISTER_UPSAMPLE_BICUBIC_CPU_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/upsample_bicubic_2d_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/user/kernels/upsample_kernel.h\"\n#include \"oneflow/core/kernel/kernel_util.cuh\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__device__ void upsample_increment_value_bounded_cuda(T* data, int64_t width, int64_t height,\n                                                      int64_t element, int64_t x, int64_t y,\n                                                      T value) {\n  int64_t access_x = max(min(x, width - 1), static_cast<int64_t>(0));\n  int64_t access_y = max(min(y, height - 1), static_cast<int64_t>(0));\n  cuda::atomic::FastAdd(data, access_y * width + access_x, element, value);\n}\n\ntemplate<typename T>\n__global__ void UpsampleBicubic2dForward(const int64_t elem_cnt, const T* in_dptr,\n                                         const int64_t nbatch, const int64_t channels,\n                                         const int64_t in_height, const int64_t in_width,\n                                         const int64_t out_height, const int64_t out_width,\n                                         const float scale_height, const float scale_width,\n                                         bool align_corners, T* out_dptr) {\n  CUDA_1D_KERNEL_LOOP(idx, elem_cnt) {\n    const int output_x = idx % out_width;\n    const int output_y = idx / out_width;\n\n    const T* in = in_dptr;\n    T* out = out_dptr;\n\n    const T real_x = GetAreaPixel(scale_width, output_x, align_corners, /*cubic=*/true);\n    int64_t input_x = floor(1.0 * real_x);\n    const T t_x = real_x - input_x;\n\n    const T real_y = GetAreaPixel(scale_height, output_y, align_corners, /*cubic=*/true);\n    int64_t input_y = floor(1.0 * real_y);\n    const T t_y = real_y - input_y;\n\n    for (int64_t c = 0; c < channels * nbatch; c++) {\n      T coefficients[4];\n\n      // Interpolate 4 times in the x direction\n      for (int64_t i = 0; i < 4; i++) {\n        coefficients[i] = cubic_interp1d<T>(\n            upsample_get_value_bounded<T>(in, in_width, in_height, input_x - 1, input_y - 1 + i),\n            upsample_get_value_bounded<T>(in, in_width, in_height, input_x + 0, input_y - 1 + i),\n            upsample_get_value_bounded<T>(in, in_width, in_height, input_x + 1, input_y - 1 + i),\n            upsample_get_value_bounded<T>(in, in_width, in_height, input_x + 2, input_y - 1 + i),\n            t_x);\n      }\n\n      // Interpolate in the y direction using x interpolations\n      out[output_y * out_width + output_x] = cubic_interp1d<T>(\n          coefficients[0], coefficients[1], coefficients[2], coefficients[3], t_y);\n\n      // Move to next channel\n      in += in_width * in_height;\n      out += out_width * out_height;\n    }\n  }\n}\n\ntemplate<typename 
T>\n__global__ void UpsampleBicubic2dBackward(const int64_t elem_cnt, const T* dy_dptr,\n                                          const int64_t nbatch, const int64_t channels,\n                                          const int64_t in_height, const int64_t in_width,\n                                          const int64_t out_height, const int64_t out_width,\n                                          const float scale_height, const float scale_width,\n                                          bool align_corners, T* dx_dptr) {\n  CUDA_1D_KERNEL_LOOP(idx, elem_cnt) {\n    const int output_x = idx % out_width;\n    const int output_y = idx / out_width;\n\n    T* in = dx_dptr;\n    const T* out = dy_dptr;\n\n    T real_x = GetAreaPixel(scale_width, output_x, align_corners, true);\n    int64_t input_x = floor(1.0 * real_x);\n    T t_x = real_x - input_x;\n\n    T real_y = GetAreaPixel(scale_height, output_y, align_corners, true);\n    int64_t input_y = floor(1.0 * real_y);\n    T t_y = real_y - input_y;\n\n    T x_coeffs[4];\n    T y_coeffs[4];\n\n    get_cubic_upsample_coefficients<T>(x_coeffs, t_x);\n    get_cubic_upsample_coefficients<T>(y_coeffs, t_y);\n\n    for (int64_t c = 0; c < channels * nbatch; c++) {\n      T out_value = out[output_y * out_width + output_x];\n\n      for (int64_t i = 0; i < 4; i++) {\n        for (int64_t j = 0; j < 4; j++) {\n          upsample_increment_value_bounded_cuda<T>(in, in_width, in_height, elem_cnt,\n                                                   input_x - 1 + i, input_y - 1 + j,\n                                                   out_value * y_coeffs[j] * x_coeffs[i]);\n        }\n      }\n\n      in += in_width * in_height;\n      out += out_width * out_height;\n    }\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass UpsampleBicubic2dGPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleBicubic2dGPUKernel() = default;\n  ~UpsampleBicubic2dGPUKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const T* in_ptr = x_tensor->dptr<T>();\n    T* out_ptr = y_tensor->mut_dptr<T>();\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n\n    const int nbatch = x_tensor->shape_view().At(0);\n    const int channels = x_tensor->shape_view().At(1);\n    const int64_t in_height = x_tensor->shape_view().At(2);\n    const int64_t in_width = x_tensor->shape_view().At(3);\n    const int64_t out_height = y_tensor->shape_view().At(2);\n    const int64_t out_width = y_tensor->shape_view().At(3);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n    const int64_t elem_cnt = out_height * out_width;\n\n    if (in_height == out_height && in_width == out_width) {\n      Memcpy<DeviceType::kCUDA>(\n          ctx->stream(), y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),\n          x_tensor->shape_view().elem_cnt() * GetSizeOfDataType(x_tensor->data_type()));\n    } else {\n      const T scale_height = 
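/* area-pixel scale: maps output coordinates back to input coordinates */ 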
GetAreaPixelScale(in_height, out_height, align_corners, height_scale);\n      const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);\n\n      RUN_CUDA_KERNEL((UpsampleBicubic2dForward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                      x_tensor->dptr<T>(), nbatch, channels, in_height, in_width, out_height,\n                      out_width, scale_height, scale_width, align_corners, y_tensor->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass UpsampleBicubic2dGradGPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleBicubic2dGradGPUKernel() = default;\n  ~UpsampleBicubic2dGradGPUKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    Memset<DeviceType::kCUDA>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,\n                              dx_tensor->shape_view().elem_cnt() * sizeof(T));\n    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n\n    const int nbatch = dx_tensor->shape_view().At(0);\n    const int channels = dx_tensor->shape_view().At(1);\n    const int64_t in_height = dx_tensor->shape_view().At(2);\n    const int64_t in_width = dx_tensor->shape_view().At(3);\n    const int64_t out_height = dy_tensor->shape_view().At(2);\n    const int64_t out_width = dy_tensor->shape_view().At(3);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n    const int64_t elem_cnt = out_height * out_width;\n\n    if (in_height == out_height && in_width == out_width) {\n      Memcpy<DeviceType::kCUDA>(\n          ctx->stream(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),\n          dy_tensor->shape_view().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));\n    } else {\n      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);\n      const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);\n\n      RUN_CUDA_KERNEL((UpsampleBicubic2dBackward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                      dy_tensor->dptr<T>(), nbatch, channels, in_height, in_width, out_height,\n                      out_width, scale_height, scale_width, align_corners,\n                      dx_tensor->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UPSAMPLE_BICUBIC_CUDA_KERNEL(dtype)                                    \\\n  REGISTER_USER_KERNEL(\"upsample_bicubic_2d\")                                           \\\n      .SetCreateFn<UpsampleBicubic2dGPUKernel<dtype>>()                                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"upsample_bicubic_2d_grad\")                                      \\\n      
.SetCreateFn<UpsampleBicubic2dGradGPUKernel<dtype>>()                             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UPSAMPLE_BICUBIC_CUDA_KERNEL(float)\nREGISTER_UPSAMPLE_BICUBIC_CUDA_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/upsample_bilinear_2d_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/user/kernels/upsample_kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstatic void UpsampleBilinear2DForward(const int64_t elem_cnt, const T* in_dptr,\n                                      NdIndexOffsetHelper<int64_t, 4> in_helper,\n                                      NdIndexOffsetHelper<int64_t, 4> out_helper,\n                                      const int64_t in_height, const int64_t in_width,\n                                      const T scale_h, const T scale_w, const bool align_corners,\n                                      T* out_dptr) {\n  for (int64_t index = 0; index < elem_cnt; ++index) {\n    int64_t n, c, h, w;\n    out_helper.OffsetToNdIndex(index, n, c, h, w);\n    BilinearParam<T> params;\n    GetBilinearParam(align_corners, h, w, in_height, in_width, scale_h, scale_w, &params);\n    const int64_t top_offset = in_helper.NdIndexToOffset(n, c, params.top_h_index, 0);\n    const int64_t bottom_offset = in_helper.NdIndexToOffset(n, c, params.bottom_h_index, 0);\n    const T top_left = in_dptr[top_offset + params.left_w_index];\n    const T top_right = in_dptr[top_offset + params.right_w_index];\n    const T bottom_left = in_dptr[bottom_offset + params.left_w_index];\n    const T bottom_right = in_dptr[bottom_offset + params.right_w_index];\n    out_dptr[index] =\n        (1 - params.h_lerp) * ((1 - params.w_lerp) * top_left + params.w_lerp * top_right)\n        + params.h_lerp * ((1 - params.w_lerp) * bottom_left + params.w_lerp * bottom_right);\n  }\n}\n\ntemplate<typename T>\nstatic void UpsampleBilinearBackward(const int64_t elem_cnt, const T* dy_dptr,\n                                     NdIndexOffsetHelper<int64_t, 4> dy_helper,\n                                     NdIndexOffsetHelper<int64_t, 4> dx_helper,\n                                     const int64_t dx_height, const int64_t dx_width,\n                                     const T scale_h, const T scale_w, const bool align_corners,\n                                     T* dx_dptr) {\n  for (int64_t index = 0; index < elem_cnt; ++index) {\n    int64_t n, c, h, w;\n    dy_helper.OffsetToNdIndex(index, n, c, h, w);\n    BilinearParam<T> params;\n    GetBilinearParam(align_corners, h, w, dx_height, dx_width, scale_h, scale_w, &params);\n    const int64_t top_offset = dx_helper.NdIndexToOffset(n, c, params.top_h_index, 0);\n    const int64_t bottom_offset = dx_helper.NdIndexToOffset(n, c, params.bottom_h_index, 0);\n    const T dy = dy_dptr[index];\n    const T dbottom = params.h_lerp * dy;\n    T* dx_dptr_bottom_offset = dx_dptr + bottom_offset;\n    *(dx_dptr_bottom_offset + params.left_w_index) += static_cast<T>((1 - params.w_lerp) * dbottom);\n    *(dx_dptr_bottom_offset + params.right_w_index) += 
static_cast<T>(params.w_lerp * dbottom);\n    const T dtop = dy - dbottom;\n    T* dx_dptr_top_offset = dx_dptr + top_offset;\n    *(dx_dptr_top_offset + params.left_w_index) += static_cast<T>((1 - params.w_lerp) * dtop);\n    *(dx_dptr_top_offset + params.right_w_index) += static_cast<T>(params.w_lerp * dtop);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass UpsampleBilinear2DCPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleBilinear2DCPUKernel() = default;\n  ~UpsampleBilinear2DCPUKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    const int64_t elem_cnt = y_tensor->shape_view().elem_cnt();\n    NdIndexOffsetHelper<int64_t, 4> in_helper(\n        x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2),\n        x_tensor->shape_view().At(3));\n    NdIndexOffsetHelper<int64_t, 4> out_helper(\n        y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2),\n        y_tensor->shape_view().At(3));\n\n    const int64_t nbatch = x_tensor->shape_view().At(0);\n    const int64_t channels = x_tensor->shape_view().At(1);\n    const int64_t in_height = x_tensor->shape_view().At(2);\n    const int64_t in_width = x_tensor->shape_view().At(3);\n    const int64_t out_height = y_tensor->shape_view().At(2);\n    const int64_t out_width = y_tensor->shape_view().At(3);\n\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n\n    if (in_height == out_height && in_width == out_width) {\n      memcpy(y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),\n             sizeof(T) * nbatch * channels * in_height * in_width);\n    } else {\n      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);\n      const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);\n      UpsampleBilinear2DForward<T>(elem_cnt, x_tensor->dptr<T>(), in_helper, out_helper, in_height,\n                                   in_width, scale_height, scale_width, align_corners,\n                                   y_tensor->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass UpsampleBilinear2DGradCPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleBilinear2DGradCPUKernel() = default;\n  ~UpsampleBilinear2DGradCPUKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    Memset<DeviceType::kCPU>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,\n                             dx_tensor->shape_view().elem_cnt() * sizeof(T));\n    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n    const std::vector<int64_t> output_size = 
ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt();\n    NdIndexOffsetHelper<int64_t, 4> dy_helper(\n        dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1), dy_tensor->shape_view().At(2),\n        dy_tensor->shape_view().At(3));\n    NdIndexOffsetHelper<int64_t, 4> dx_helper(\n        dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1), dx_tensor->shape_view().At(2),\n        dx_tensor->shape_view().At(3));\n\n    const int64_t nbatch = dx_tensor->shape_view().At(0);\n    const int64_t channels = dx_tensor->shape_view().At(1);\n    const int64_t in_height = dx_tensor->shape_view().At(2);\n    const int64_t in_width = dx_tensor->shape_view().At(3);\n    const int64_t out_height = dy_tensor->shape_view().At(2);\n    const int64_t out_width = dy_tensor->shape_view().At(3);\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n\n    if (in_height == out_height && in_width == out_width) {\n      memcpy(dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),\n             sizeof(T) * nbatch * channels * in_height * in_width);\n    } else {\n      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);\n      const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);\n      UpsampleBilinearBackward<T>(elem_cnt, dy_tensor->dptr<T>(), dy_helper, dx_helper, in_height,\n                                  in_width, scale_height, scale_width, align_corners,\n                                  dx_tensor->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UPSAMPLE_BILINEAR_2D_CPU_KERNEL(dtype)                                 \\\n  REGISTER_USER_KERNEL(\"upsample_bilinear_2d\")                                          \\\n      .SetCreateFn<UpsampleBilinear2DCPUKernel<dtype>>()                                \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"upsample_bilinear_2d_grad\")                                     \\\n      .SetCreateFn<UpsampleBilinear2DGradCPUKernel<dtype>>()                            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UPSAMPLE_BILINEAR_2D_CPU_KERNEL(float)\nREGISTER_UPSAMPLE_BILINEAR_2D_CPU_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/upsample_bilinear_2d_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/user/kernels/upsample_kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n__device__ __forceinline__ void GetBilinearParamHalf(const bool align_corners, const int64_t h,\n                                                     const int64_t w, const int64_t in_height,\n                                                     const int64_t in_width, const double scale_h,\n                                                     const double scale_w,\n                                                     BilinearParam<half>* params) {\n  half h1r;\n  if (align_corners) {\n    h1r = static_cast<half>(scale_h * static_cast<double>(h));\n  } else {\n    h1r = h1r = static_cast<half>((static_cast<double>(h) + 0.5f) * scale_h - 0.5f);\n    h1r = h1r < static_cast<half>(0.0) ? static_cast<half>(0.0) : h1r;\n  }\n  const int64_t h1 = int(h1r);\n  const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;\n\n  half w1r;\n  if (align_corners) {\n    w1r = static_cast<half>(scale_w * static_cast<double>(w));\n  } else {\n    w1r = static_cast<half>((static_cast<double>(w) + 0.5f) * scale_w - 0.5f);\n    w1r = w1r < static_cast<half>(0.0) ? static_cast<half>(0.0) : w1r;\n  }\n  const int64_t w1 = int(w1r);\n  const int64_t w1p = (w1 < in_width - 1) ? 
1 : 0;\n\n  params->top_h_index = h1;\n  params->bottom_h_index = h1 + h1p;\n  params->h_lerp = h1r - static_cast<half>(h1 * 1.0);\n  params->left_w_index = w1;\n  params->right_w_index = w1 + w1p;\n  params->w_lerp = w1r - static_cast<half>(w1 * 1.0);\n}\n\ntemplate<typename T>\n__global__ void UpsampleBilinear2DForward(const int64_t elem_cnt, const T* in_dptr,\n                                          NdIndexOffsetHelper<int64_t, 4> in_helper,\n                                          NdIndexOffsetHelper<int64_t, 4> out_helper,\n                                          const int64_t in_height, const int64_t in_width,\n                                          const T scale_h, const T scale_w,\n                                          const bool align_corners, T* out_dptr) {\n  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {\n    int64_t n, c, h, w;\n    out_helper.OffsetToNdIndex(index, n, c, h, w);\n    BilinearParam<T> params;\n    GetBilinearParam(align_corners, h, w, in_height, in_width, scale_h, scale_w, &params);\n    const int64_t top_offset = in_helper.NdIndexToOffset(n, c, params.top_h_index, 0);\n    const int64_t bottom_offset = in_helper.NdIndexToOffset(n, c, params.bottom_h_index, 0);\n    const T top_left = in_dptr[top_offset + params.left_w_index];\n    const T top_right = in_dptr[top_offset + params.right_w_index];\n    const T bottom_left = in_dptr[bottom_offset + params.left_w_index];\n    const T bottom_right = in_dptr[bottom_offset + params.right_w_index];\n    out_dptr[index] =\n        (1 - params.h_lerp) * ((1 - params.w_lerp) * top_left + params.w_lerp * top_right)\n        + params.h_lerp * ((1 - params.w_lerp) * bottom_left + params.w_lerp * bottom_right);\n  }\n}\n\ntemplate<>\n__global__ void UpsampleBilinear2DForward(const int64_t elem_cnt, const half* in_dptr,\n                                          NdIndexOffsetHelper<int64_t, 4> in_helper,\n                                          NdIndexOffsetHelper<int64_t, 4> out_helper,\n                                          const int64_t in_height, const int64_t in_width,\n                                          const half scale_h, const half scale_w,\n                                          const bool align_corners, half* out_dptr) {\n  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {\n    int64_t n, c, h, w;\n    out_helper.OffsetToNdIndex(index, n, c, h, w);\n    BilinearParam<half> params;\n    GetBilinearParamHalf(align_corners, h, w, in_height, in_width, scale_h, scale_w, &params);\n    const int64_t top_offset = in_helper.NdIndexToOffset(n, c, params.top_h_index, 0);\n    const int64_t bottom_offset = in_helper.NdIndexToOffset(n, c, params.bottom_h_index, 0);\n    const half top_left = in_dptr[top_offset + params.left_w_index];\n    const half top_right = in_dptr[top_offset + params.right_w_index];\n    const half bottom_left = in_dptr[bottom_offset + params.left_w_index];\n    const half bottom_right = in_dptr[bottom_offset + params.right_w_index];\n    out_dptr[index] =\n        (static_cast<half>(1.0) - params.h_lerp)\n            * ((static_cast<half>(1.0) - params.w_lerp) * top_left + params.w_lerp * top_right)\n        + params.h_lerp\n              * ((static_cast<half>(1.0) - params.w_lerp) * bottom_left\n                 + params.w_lerp * bottom_right);\n  }\n}\n\ntemplate<typename T>\n__global__ void UpsampleBilinearBackward(const int64_t elem_cnt, const T* dy_dptr,\n                                         NdIndexOffsetHelper<int64_t, 4> dy_helper,\n                                         
NdIndexOffsetHelper<int64_t, 4> dx_helper,\n                                         const int64_t dx_height, const int64_t dx_width,\n                                         const T scale_h, const T scale_w, const bool align_corners,\n                                         T* dx_dptr) {\n  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {\n    int64_t n, c, h, w;\n    dy_helper.OffsetToNdIndex(index, n, c, h, w);\n    BilinearParam<T> params;\n    GetBilinearParam(align_corners, h, w, dx_height, dx_width, scale_h, scale_w, &params);\n    const int64_t top_offset = dx_helper.NdIndexToOffset(n, c, params.top_h_index, 0);\n    const int64_t bottom_offset = dx_helper.NdIndexToOffset(n, c, params.bottom_h_index, 0);\n    const T dy = dy_dptr[index];\n    const T dbottom = params.h_lerp * dy;\n    T* dx_dptr_bottom_offset = dx_dptr + bottom_offset;\n    cuda::atomic::FastAdd(dx_dptr_bottom_offset, params.left_w_index, elem_cnt,\n                          static_cast<T>((1 - params.w_lerp) * dbottom));\n    cuda::atomic::FastAdd(dx_dptr_bottom_offset, params.right_w_index, elem_cnt,\n                          static_cast<T>(params.w_lerp * dbottom));\n    const T dtop = dy - dbottom;\n    T* dx_dptr_top_offset = dx_dptr + top_offset;\n    cuda::atomic::FastAdd(dx_dptr_top_offset, params.left_w_index, elem_cnt,\n                          static_cast<T>((1 - params.w_lerp) * dtop));\n    cuda::atomic::FastAdd(dx_dptr_top_offset, params.right_w_index, elem_cnt,\n                          static_cast<T>(params.w_lerp * dtop));\n  }\n}\n\ntemplate<>\n__global__ void UpsampleBilinearBackward(const int64_t elem_cnt, const half* dy_dptr,\n                                         NdIndexOffsetHelper<int64_t, 4> dy_helper,\n                                         NdIndexOffsetHelper<int64_t, 4> dx_helper,\n                                         const int64_t dx_height, const int64_t dx_width,\n                                         const half scale_h, const half scale_w,\n                                         const bool align_corners, half* dx_dptr) {\n  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {\n    int64_t n, c, h, w;\n    dy_helper.OffsetToNdIndex(index, n, c, h, w);\n    BilinearParam<half> params;\n    GetBilinearParamHalf(align_corners, h, w, dx_height, dx_width, scale_h, scale_w, &params);\n    const int64_t top_offset = dx_helper.NdIndexToOffset(n, c, params.top_h_index, 0);\n    const int64_t bottom_offset = dx_helper.NdIndexToOffset(n, c, params.bottom_h_index, 0);\n    const half dy = dy_dptr[index];\n    const half dbottom = params.h_lerp * dy;\n    half* dx_dptr_bottom_offset = dx_dptr + bottom_offset;\n    cuda::atomic::FastAdd(dx_dptr_bottom_offset, params.left_w_index, elem_cnt,\n                          static_cast<half>((static_cast<half>(1.0) - params.w_lerp) * dbottom));\n    cuda::atomic::FastAdd(dx_dptr_bottom_offset, params.right_w_index, elem_cnt,\n                          static_cast<half>(params.w_lerp * dbottom));\n    const half dtop = dy - dbottom;\n    half* dx_dptr_top_offset = dx_dptr + top_offset;\n    cuda::atomic::FastAdd(dx_dptr_top_offset, params.left_w_index, elem_cnt,\n                          static_cast<half>((static_cast<half>(1.0) - params.w_lerp) * dtop));\n    cuda::atomic::FastAdd(dx_dptr_top_offset, params.right_w_index, elem_cnt,\n                          static_cast<half>(params.w_lerp * dtop));\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass UpsampleBilinear2DGPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleBilinear2DGPUKernel() 
= default;\n  ~UpsampleBilinear2DGPUKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    const int64_t elem_cnt = y_tensor->shape_view().elem_cnt();\n    NdIndexOffsetHelper<int64_t, 4> in_helper(\n        x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2),\n        x_tensor->shape_view().At(3));\n    NdIndexOffsetHelper<int64_t, 4> out_helper(\n        y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2),\n        y_tensor->shape_view().At(3));\n\n    const int64_t in_height = x_tensor->shape_view().At(2);\n    const int64_t in_width = x_tensor->shape_view().At(3);\n    const int64_t out_height = y_tensor->shape_view().At(2);\n    const int64_t out_width = y_tensor->shape_view().At(3);\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n    if (in_height == out_height && in_width == out_width) {\n      Memcpy<DeviceType::kCUDA>(\n          ctx->stream(), y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),\n          x_tensor->shape_view().elem_cnt() * GetSizeOfDataType(x_tensor->data_type()));\n    } else {\n      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);\n      const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);\n      RUN_CUDA_KERNEL((UpsampleBilinear2DForward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                      x_tensor->dptr<T>(), in_helper, out_helper, in_height, in_width, scale_height,\n                      scale_width, align_corners, y_tensor->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass UpsampleBilinear2DGradGPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleBilinear2DGradGPUKernel() = default;\n  ~UpsampleBilinear2DGradGPUKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    Memset<DeviceType::kCUDA>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,\n                              dx_tensor->shape_view().elem_cnt() * sizeof(T));\n    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt();\n    NdIndexOffsetHelper<int64_t, 4> dy_helper(\n        dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1), dy_tensor->shape_view().At(2),\n        dy_tensor->shape_view().At(3));\n    
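// Added note: dy_helper converts flat dy offsets to (n, c, h, w); dx_helper below maps those indices back to flat dx offsets for the atomic scatter-add.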
NdIndexOffsetHelper<int64_t, 4> dx_helper(\n        dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1), dx_tensor->shape_view().At(2),\n        dx_tensor->shape_view().At(3));\n\n    const int64_t in_height = dx_tensor->shape_view().At(2);\n    const int64_t in_width = dx_tensor->shape_view().At(3);\n    const int64_t out_height = dy_tensor->shape_view().At(2);\n    const int64_t out_width = dy_tensor->shape_view().At(3);\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n    if (in_height == out_height && in_width == out_width) {\n      Memcpy<DeviceType::kCUDA>(\n          ctx->stream(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),\n          dy_tensor->shape_view().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));\n    } else {\n      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);\n      const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);\n      RUN_CUDA_KERNEL((UpsampleBilinearBackward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                      dy_tensor->dptr<T>(), dy_helper, dx_helper, in_height, in_width, scale_height,\n                      scale_width, align_corners, dx_tensor->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UPSAMPLE_BILINEAR_2D_CUDA_KERNEL(dtype)                                \\\n  REGISTER_USER_KERNEL(\"upsample_bilinear_2d\")                                          \\\n      .SetCreateFn<UpsampleBilinear2DGPUKernel<dtype>>()                                \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"upsample_bilinear_2d_grad\")                                     \\\n      .SetCreateFn<UpsampleBilinear2DGradGPUKernel<dtype>>()                            \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UPSAMPLE_BILINEAR_2D_CUDA_KERNEL(half)\nREGISTER_UPSAMPLE_BILINEAR_2D_CUDA_KERNEL(float)\nREGISTER_UPSAMPLE_BILINEAR_2D_CUDA_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/upsample_kernel.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include <math.h>\n\nOF_DEVICE_FUNC double GetLinearInputIndex(const int64_t out_dim_idx, const double scale,\n                                          bool align_corners) {\n  if (align_corners) {\n    return static_cast<double>(scale * out_dim_idx);\n  } else {\n    double src_idx = scale * (out_dim_idx + 0.5) - 0.5;\n    return static_cast<double>(src_idx < 0 ? 0 : src_idx);\n  }\n}\n\nOF_DEVICE_FUNC static int64_t GetNearestInputIndex(const int64_t out_dim_idx, const double scale,\n                                                   const int64_t in_dim_size) {\n  int64_t index = static_cast<int64_t>(floorf(out_dim_idx * scale));\n  index = index > in_dim_size - 1 ? in_dim_size - 1 : index;\n  return index;\n}\n\nOF_DEVICE_FUNC double GetAreaPixelScale(const int64_t input_size, const int64_t output_size,\n                                        bool align_corners, const double scale) {\n  if (align_corners) {\n    if (output_size > 1) {\n      return static_cast<double>(input_size - 1) / (output_size - 1);\n    } else {\n      return 0;\n    }\n  } else {\n    return (scale > 0. ? 1.0 / scale : static_cast<double>(input_size) / output_size);\n  }\n}\n\nOF_DEVICE_FUNC double GetAreaPixel(const double scale, const int64_t dst_index, bool align_corners,\n                                   bool cubic = false) {\n  if (align_corners) {\n    return scale * dst_index;\n  } else {\n    double src_idx = scale * (dst_index + 0.5) - 0.5;\n    return (!cubic && src_idx < 0) ? static_cast<double>(0) : src_idx;\n  }\n}\n\ntemplate<typename T>\nstruct BilinearParam {\n  int64_t top_h_index;\n  int64_t bottom_h_index;\n  int64_t left_w_index;\n  int64_t right_w_index;\n  T w_lerp;\n  T h_lerp;\n};\n\ntemplate<typename T>\nOF_DEVICE_FUNC void GetBilinearParam(const bool align_corners, const int64_t h, const int64_t w,\n                                     const int64_t in_height, const int64_t in_width,\n                                     const double scale_h, const double scale_w,\n                                     BilinearParam<T>* params) {\n  T h1r;\n  if (align_corners) {\n    h1r = scale_h * static_cast<T>(h);\n  } else {\n    h1r = (static_cast<T>(h) + 0.5f) * scale_h - 0.5f;\n    h1r = h1r < 0 ? 0 : h1r;\n  }\n  const int64_t h1 = h1r;\n  const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;\n\n  T w1r;\n  if (align_corners) {\n    w1r = scale_w * static_cast<T>(w);\n  } else {\n    w1r = (static_cast<T>(w) + 0.5f) * scale_w - 0.5f;\n    w1r = w1r < 0 ? 0 : w1r;\n  }\n  const int64_t w1 = w1r;\n  const int64_t w1p = (w1 < in_width - 1) ? 
1 : 0;\n\n  params->top_h_index = h1;\n  params->bottom_h_index = h1 + h1p;\n  params->h_lerp = h1r - h1;\n  params->left_w_index = w1;\n  params->right_w_index = w1 + w1p;\n  params->w_lerp = w1r - w1;\n}\n\ntemplate<typename T>\nOF_DEVICE_FUNC void upsample_increment_value_bounded(T* data, int64_t width, int64_t height,\n                                                     int64_t x, int64_t y, T value) {\n  int64_t access_x = std::max(std::min(x, width - 1), static_cast<int64_t>(0));\n  int64_t access_y = std::max(std::min(y, height - 1), static_cast<int64_t>(0));\n  data[access_y * width + access_x] += value;\n}\n\ntemplate<typename T>\nOF_DEVICE_FUNC T upsample_get_value_bounded(const T* data, const int64_t width,\n                                            const int64_t height, const int64_t x,\n                                            const int64_t y) {\n  int64_t access_x = x;\n  access_x = access_x > width - 1 ? width - 1 : access_x;\n  access_x = access_x < 0 ? 0 : access_x;\n\n  int64_t access_y = y;\n  access_y = access_y > height - 1 ? height - 1 : access_y;\n  access_y = access_y < 0 ? 0 : access_y;\n\n  return data[access_y * width + access_x];\n}\n\n// Based on\n// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm\n\ntemplate<typename T>\nOF_DEVICE_FUNC T cubic_convolution1(const T x, const T A) {\n  return ((A + static_cast<T>(2.0)) * x - (A + static_cast<T>(3.0))) * x * x + static_cast<T>(1.0);\n}\n\ntemplate<typename T>\nOF_DEVICE_FUNC T cubic_convolution2(const T x, const T A) {\n  return ((A * x - static_cast<T>(5.0) * A) * x + static_cast<T>(8.0) * A) * x\n         - static_cast<T>(4.0) * A;\n}\n\ntemplate<typename T>\nOF_DEVICE_FUNC void get_cubic_upsample_coefficients(T coeffs[4], const T t) {\n  T A = -0.75;\n\n  T x1 = t;\n  coeffs[0] = cubic_convolution2<T>(x1 + 1.0, A);\n  coeffs[1] = cubic_convolution1<T>(x1, A);\n\n  // opposite coefficients\n  T x2 = 1.0 - t;\n  coeffs[2] = cubic_convolution1<T>(x2, A);\n  coeffs[3] = cubic_convolution2<T>(x2 + 1.0, A);\n}\n\ntemplate<typename T>\nOF_DEVICE_FUNC T cubic_interp1d(const T x0, const T x1, const T x2, const T x3, const T t) {\n  T coeffs[4];\n  get_cubic_upsample_coefficients<T>(coeffs, t);\n  return x0 * coeffs[0] * 1.0 + x1 * coeffs[1] * 1.0 + x2 * coeffs[2] * 1.0 + x3 * coeffs[3] * 1.0;\n}\n"
  },
  {
    "path": "oneflow/user/kernels/upsample_linear_1d_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/user/kernels/upsample_kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstatic void UpsampleLinear1DForward(const int64_t elem_cnt, const T* in_dptr,\n                                    NdIndexOffsetHelper<int64_t, 3> in_helper,\n                                    NdIndexOffsetHelper<int64_t, 3> out_helper, const int in_height,\n                                    const double scale_factor, bool align_corners, T* out_dptr) {\n  for (int64_t index = 0; index < elem_cnt; ++index) {\n    int64_t n, c, h;\n    out_helper.OffsetToNdIndex(index, n, c, h);\n    const double h1r = GetLinearInputIndex(h, scale_factor, align_corners);\n    const int64_t h1 = h1r;\n    const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;\n    const double h1lambda = h1r - h1;\n    const double h0lambda = static_cast<double>(1.) - h1lambda;\n    out_dptr[index] = h0lambda * in_dptr[in_helper.NdIndexToOffset(n, c, h1)]\n                      + h1lambda * in_dptr[in_helper.NdIndexToOffset(n, c, h1 + h1p)];\n  }\n}\n\ntemplate<typename T>\nstatic void UpsampleLinear1DBackward(const int64_t elem_cnt, const T* dy_dptr,\n                                     NdIndexOffsetHelper<int64_t, 3> dy_helper,\n                                     NdIndexOffsetHelper<int64_t, 3> dx_helper, const int in_height,\n                                     const double scale_factor, bool align_corners, T* dx_dptr) {\n  for (int64_t index = 0; index < elem_cnt; ++index) {\n    int64_t n, c, h;\n    dy_helper.OffsetToNdIndex(index, n, c, h);\n    const double h1r = GetLinearInputIndex(h, scale_factor, align_corners);\n    const int64_t h1 = h1r;\n    const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;\n    const double h1lambda = h1r - h1;\n    const double h0lambda = static_cast<double>(1.) 
- h1lambda;\n\n    *(dx_dptr + dx_helper.NdIndexToOffset(n, c, h1)) += h0lambda * dy_dptr[index];\n    *(dx_dptr + dx_helper.NdIndexToOffset(n, c, h1 + h1p)) += h1lambda * dy_dptr[index];\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass UpsampleLinear1DCPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleLinear1DCPUKernel() = default;\n  ~UpsampleLinear1DCPUKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n    const int64_t elem_cnt = y_tensor->shape_view().elem_cnt();\n    NdIndexOffsetHelper<int64_t, 3> in_helper(\n        x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2));\n    NdIndexOffsetHelper<int64_t, 3> out_helper(\n        y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2));\n    const int64_t nbatch = x_tensor->shape_view().At(0);\n    const int64_t channels = x_tensor->shape_view().At(1);\n    const int64_t in_height = x_tensor->shape_view().At(2);\n    const int64_t out_height = y_tensor->shape_view().At(2);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"scale_factor\");\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n    }\n\n    if (in_height == out_height) {\n      memcpy(y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),\n             sizeof(T) * nbatch * channels * in_height);\n    } else {\n      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);\n      UpsampleLinear1DForward<T>(elem_cnt, x_tensor->dptr<T>(), in_helper, out_helper, in_height,\n                                 scale_height, align_corners, y_tensor->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass UpsampleLinearGrad1DCPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleLinearGrad1DCPUKernel() = default;\n  ~UpsampleLinearGrad1DCPUKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    Memset<DeviceType::kCPU>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,\n                             dx_tensor->shape_view().elem_cnt() * sizeof(T));\n    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n\n    NdIndexOffsetHelper<int64_t, 3> dy_helper(dy_tensor->shape_view().At(0),\n                                              dy_tensor->shape_view().At(1),\n                                              dy_tensor->shape_view().At(2));\n    NdIndexOffsetHelper<int64_t, 3> dx_helper(dx_tensor->shape_view().At(0),\n                                              dx_tensor->shape_view().At(1),\n                                              dx_tensor->shape_view().At(2));\n    const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt();\n\n    const int64_t nbatch = dx_tensor->shape_view().At(0);\n    const int64_t channels = dx_tensor->shape_view().At(1);\n    const int64_t in_height = dx_tensor->shape_view().At(2);\n    const int64_t 
out_height = dy_tensor->shape_view().At(2);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"scale_factor\");\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n    }\n\n    if (in_height == out_height) {\n      memcpy(dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),\n             sizeof(T) * nbatch * channels * in_height);\n    } else {\n      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);\n      UpsampleLinear1DBackward<T>(elem_cnt, dy_tensor->dptr<T>(), dy_helper, dx_helper, in_height,\n                                  scale_height, align_corners, dx_tensor->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UPSAMPLELINEAR1D_CPU_KERNEL(dtype)                                     \\\n  REGISTER_USER_KERNEL(\"upsample_linear_1d\")                                            \\\n      .SetCreateFn<UpsampleLinear1DCPUKernel<dtype>>()                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"upsample_linear_1d_grad\")                                       \\\n      .SetCreateFn<UpsampleLinearGrad1DCPUKernel<dtype>>()                              \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UPSAMPLELINEAR1D_CPU_KERNEL(float)\nREGISTER_UPSAMPLELINEAR1D_CPU_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/upsample_linear_1d_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/user/kernels/upsample_kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__global__ void UpsampleLinear1DForward(const int64_t elem_cnt, const T* in_dptr,\n                                        NdIndexOffsetHelper<int64_t, 3> in_helper,\n                                        NdIndexOffsetHelper<int64_t, 3> out_helper,\n                                        const int in_height, const double scale_factor,\n                                        bool align_corners, T* out_dptr) {\n  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {\n    int64_t n, c, h;\n    out_helper.OffsetToNdIndex(index, n, c, h);\n    const double h1r = GetLinearInputIndex(h, scale_factor, align_corners);\n    const int64_t h1 = h1r;\n    const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;\n    const double h1lambda = h1r - h1;\n    const double h0lambda = static_cast<double>(1.) - h1lambda;\n    out_dptr[index] = h0lambda * in_dptr[in_helper.NdIndexToOffset(n, c, h1)]\n                      + h1lambda * in_dptr[in_helper.NdIndexToOffset(n, c, h1 + h1p)];\n  }\n}\n\ntemplate<typename T>\n__global__ void UpsampleLinear1DBackward(const int64_t elem_cnt, const T* dy_dptr,\n                                         NdIndexOffsetHelper<int64_t, 3> dy_helper,\n                                         NdIndexOffsetHelper<int64_t, 3> dx_helper,\n                                         const int in_height, const double scale_factor,\n                                         bool align_corners, T* dx_dptr) {\n  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {\n    int64_t n, c, h;\n    dy_helper.OffsetToNdIndex(index, n, c, h);\n    const double h1r = GetLinearInputIndex(h, scale_factor, align_corners);\n    const int64_t h1 = h1r;\n    const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;\n    const double h1lambda = h1r - h1;\n    const double h0lambda = static_cast<double>(1.) 
- h1lambda;\n\n    cuda::atomic::FastAdd(dx_dptr, dx_helper.NdIndexToOffset(n, c, h1), elem_cnt,\n                          static_cast<T>(h0lambda * dy_dptr[index]));\n    cuda::atomic::FastAdd(dx_dptr, dx_helper.NdIndexToOffset(n, c, h1 + h1p), elem_cnt,\n                          static_cast<T>(h1lambda * dy_dptr[index]));\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass UpsampleLinear1DGPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleLinear1DGPUKernel() = default;\n  ~UpsampleLinear1DGPUKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n    const int64_t elem_cnt = y_tensor->shape_view().elem_cnt();\n    NdIndexOffsetHelper<int64_t, 3> in_helper(\n        x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2));\n    NdIndexOffsetHelper<int64_t, 3> out_helper(\n        y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2));\n    const int64_t in_height = x_tensor->shape_view().At(2);\n    const int64_t out_height = y_tensor->shape_view().At(2);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"scale_factor\");\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n    }\n    if (in_height == out_height) {\n      Memcpy<DeviceType::kCUDA>(\n          ctx->stream(), y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),\n          x_tensor->shape_view().elem_cnt() * GetSizeOfDataType(x_tensor->data_type()));\n    } else {\n      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);\n      RUN_CUDA_KERNEL((UpsampleLinear1DForward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                      x_tensor->dptr<T>(), in_helper, out_helper, in_height, scale_height,\n                      align_corners, y_tensor->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass UpsampleLinearGrad1DGPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleLinearGrad1DGPUKernel() = default;\n  ~UpsampleLinearGrad1DGPUKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    Memset<DeviceType::kCUDA>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,\n                              dx_tensor->shape_view().elem_cnt() * sizeof(T));\n    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n\n    NdIndexOffsetHelper<int64_t, 3> dy_helper(dy_tensor->shape_view().At(0),\n                                              dy_tensor->shape_view().At(1),\n                                              dy_tensor->shape_view().At(2));\n    NdIndexOffsetHelper<int64_t, 3> dx_helper(dx_tensor->shape_view().At(0),\n                                              dx_tensor->shape_view().At(1),\n                                              dx_tensor->shape_view().At(2));\n    const int64_t elem_cnt = 
dy_tensor->shape_view().elem_cnt();\n    const int64_t in_height = dx_tensor->shape_view().At(2);\n    const int64_t out_height = dy_tensor->shape_view().At(2);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"scale_factor\");\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n    }\n    if (in_height == out_height) {\n      Memcpy<DeviceType::kCUDA>(\n          ctx->stream(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),\n          dy_tensor->shape_view().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));\n    } else {\n      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);\n      RUN_CUDA_KERNEL((UpsampleLinear1DBackward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                      dy_tensor->dptr<T>(), dy_helper, dx_helper, in_height, scale_height,\n                      align_corners, dx_tensor->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UPSAMPLELINEAR1D_CUDA_KERNEL(dtype)                                    \\\n  REGISTER_USER_KERNEL(\"upsample_linear_1d\")                                            \\\n      .SetCreateFn<UpsampleLinear1DGPUKernel<dtype>>()                                  \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"upsample_linear_1d_grad\")                                       \\\n      .SetCreateFn<UpsampleLinearGrad1DGPUKernel<dtype>>()                              \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UPSAMPLELINEAR1D_CUDA_KERNEL(float)\nREGISTER_UPSAMPLELINEAR1D_CUDA_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/upsample_nearest_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/user/kernels/upsample_kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstatic void UpsampleNearest1DForward(const int64_t elem_cnt, const T* in_dptr,\n                                     NdIndexOffsetHelper<int64_t, 3> in_helper,\n                                     NdIndexOffsetHelper<int64_t, 3> out_helper,\n                                     const int64_t in_height, const double scale_factor,\n                                     T* out_dptr) {\n  for (int64_t index = 0; index < elem_cnt; ++index) {\n    int64_t n, c, h;\n    out_helper.OffsetToNdIndex(index, n, c, h);\n    const int64_t in_h = GetNearestInputIndex(h, scale_factor, in_height);\n    out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_h)];\n  }\n}\n\ntemplate<typename T>\nstatic void UpsampleNearest1DBackward(const int64_t elem_cnt, const T* dy_dptr,\n                                      NdIndexOffsetHelper<int64_t, 3> dy_helper,\n                                      NdIndexOffsetHelper<int64_t, 3> dx_helper,\n                                      const int64_t in_height, const double scale_factor,\n                                      T* dx_dptr) {\n  for (int64_t index = 0; index < elem_cnt; ++index) {\n    int64_t n, c, h;\n    dy_helper.OffsetToNdIndex(index, n, c, h);\n    const int64_t dx_h = GetNearestInputIndex(h, scale_factor, in_height);\n    *(dx_dptr + dx_helper.NdIndexToOffset(n, c, dx_h)) += dy_dptr[index];\n  }\n}\n\ntemplate<typename T>\nstatic void UpsampleNearest2DForward(const int64_t elem_cnt, const T* in_dptr,\n                                     NdIndexOffsetHelper<int64_t, 4> in_helper,\n                                     NdIndexOffsetHelper<int64_t, 4> out_helper,\n                                     const int64_t in_height, const int64_t in_width,\n                                     const double scale_h, const double scale_w, T* out_dptr) {\n  for (int64_t index = 0; index < elem_cnt; ++index) {\n    int64_t n, c, h, w;\n    out_helper.OffsetToNdIndex(index, n, c, h, w);\n    const int64_t in_h = GetNearestInputIndex(h, scale_h, in_height);\n    const int64_t in_w = GetNearestInputIndex(w, scale_w, in_width);\n    out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_h, in_w)];\n  }\n}\n\ntemplate<typename T>\nstatic void UpsampleNearest2DBackward(const int64_t elem_cnt, const T* dy_dptr,\n                                      NdIndexOffsetHelper<int64_t, 4> dy_helper,\n                                      NdIndexOffsetHelper<int64_t, 4> dx_helper,\n                                      const int64_t dx_height, const int64_t dx_width,\n                                      const double scale_h, const double scale_w, T* dx_dptr) {\n  for (int64_t index = 0; 
index < elem_cnt; ++index) {\n    int64_t n, c, h, w;\n    dy_helper.OffsetToNdIndex(index, n, c, h, w);\n    const int64_t dx_h = GetNearestInputIndex(h, scale_h, dx_height);\n    const int64_t dx_w = GetNearestInputIndex(w, scale_w, dx_width);\n    *(dx_dptr + dx_helper.NdIndexToOffset(n, c, dx_h, dx_w)) += dy_dptr[index];\n  }\n}\n\ntemplate<typename T>\nstatic void UpsampleNearest3DForward(const int64_t elem_cnt, const T* in_dptr,\n                                     NdIndexOffsetHelper<int64_t, 5> in_helper,\n                                     NdIndexOffsetHelper<int64_t, 5> out_helper,\n                                     const int64_t in_depth, const int64_t in_height,\n                                     const int64_t in_width, const float scale_d,\n                                     const float scale_h, const float scale_w, T* out_dptr) {\n  for (int64_t index = 0; index < elem_cnt; ++index) {\n    int64_t n, c, d, h, w;\n    out_helper.OffsetToNdIndex(index, n, c, d, h, w);\n    const int64_t in_h = GetNearestInputIndex(h, scale_h, in_height);\n    const int64_t in_w = GetNearestInputIndex(w, scale_w, in_width);\n    const int64_t in_d = GetNearestInputIndex(d, scale_d, in_depth);\n    out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_d, in_h, in_w)];\n  }\n}\n\ntemplate<typename T>\nstatic void UpsampleNearest3DBackward(const int64_t elem_cnt, const T* dy_dptr,\n                                      NdIndexOffsetHelper<int64_t, 5> dy_helper,\n                                      NdIndexOffsetHelper<int64_t, 5> dx_helper,\n                                      const int64_t in_depth, const int64_t in_height,\n                                      const int64_t in_width, const float scale_d,\n                                      const float scale_h, const float scale_w, T* dx_dptr) {\n  for (int64_t index = 0; index < elem_cnt; ++index) {\n    int64_t n, c, d, h, w;\n    dy_helper.OffsetToNdIndex(index, n, c, d, h, w);\n    const int64_t dx_h = GetNearestInputIndex(h, scale_h, in_height);\n    const int64_t dx_w = GetNearestInputIndex(w, scale_w, in_width);\n    const int64_t in_d = GetNearestInputIndex(d, scale_d, in_depth);\n    *(dx_dptr + dx_helper.NdIndexToOffset(n, c, in_d, dx_h, dx_w)) += dy_dptr[index];\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass UpsampleNearest1DCPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleNearest1DCPUKernel() = default;\n  ~UpsampleNearest1DCPUKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const int64_t elem_cnt = y_tensor->shape_view().elem_cnt();\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"scale_factor\");\n    const int64_t nbatch = x_tensor->shape_view().At(0);\n    const int64_t channels = x_tensor->shape_view().At(1);\n    const int64_t in_height = x_tensor->shape_view().At(2);\n    const int64_t out_height = y_tensor->shape_view().At(2);\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n    }\n\n    if (in_height == out_height) {\n      memcpy(y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),\n             sizeof(T) * nbatch * channels * in_height);\n    } else {\n      NdIndexOffsetHelper<int64_t, 3> in_helper(\n     
     x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2));\n      NdIndexOffsetHelper<int64_t, 3> out_helper(\n          y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2));\n      UpsampleNearest1DForward<T>(elem_cnt, x_tensor->dptr<T>(), in_helper, out_helper,\n                                  x_tensor->shape_view().At(2), 1.f / height_scale,\n                                  y_tensor->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass UpsampleNearestGrad1DCPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleNearestGrad1DCPUKernel() = default;\n  ~UpsampleNearestGrad1DCPUKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    Memset<DeviceType::kCPU>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,\n                             dx_tensor->shape_view().elem_cnt() * sizeof(T));\n    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"scale_factor\");\n    const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt();\n    const int64_t nbatch = dx_tensor->shape_view().At(0);\n    const int64_t channels = dx_tensor->shape_view().At(1);\n    const int64_t in_height = dx_tensor->shape_view().At(2);\n    const int64_t out_height = dy_tensor->shape_view().At(2);\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n    }\n    if (in_height == out_height) {\n      memcpy(dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),\n             sizeof(T) * nbatch * channels * in_height);\n    } else {\n      NdIndexOffsetHelper<int64_t, 3> dy_helper(dy_tensor->shape_view().At(0),\n                                                dy_tensor->shape_view().At(1),\n                                                dy_tensor->shape_view().At(2));\n      NdIndexOffsetHelper<int64_t, 3> dx_helper(dx_tensor->shape_view().At(0),\n                                                dx_tensor->shape_view().At(1),\n                                                dx_tensor->shape_view().At(2));\n      UpsampleNearest1DBackward<T>(elem_cnt, dy_tensor->dptr<T>(), dy_helper, dx_helper,\n                                   dx_tensor->shape_view().At(2), 1.f / height_scale,\n                                   dx_tensor->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UPSAMPNEAREST1D_CPU_KERNEL(dtype)                                      \\\n  REGISTER_USER_KERNEL(\"upsample_nearest_1d\")                                           \\\n      .SetCreateFn<UpsampleNearest1DCPUKernel<dtype>>()                                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"upsample_nearest_1d_grad\")                                      \\\n      .SetCreateFn<UpsampleNearestGrad1DCPUKernel<dtype>>()                             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && 
(user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UPSAMPNEAREST1D_CPU_KERNEL(float)\nREGISTER_UPSAMPNEAREST1D_CPU_KERNEL(double)\n\ntemplate<typename T>\nclass UpsampleNearest2DCPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleNearest2DCPUKernel() = default;\n  ~UpsampleNearest2DCPUKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    const int64_t nbatch = x_tensor->shape_view().At(0);\n    const int64_t channels = x_tensor->shape_view().At(1);\n    const int64_t in_height = x_tensor->shape_view().At(2);\n    const int64_t in_width = x_tensor->shape_view().At(3);\n    const int64_t out_height = y_tensor->shape_view().At(2);\n    const int64_t out_width = y_tensor->shape_view().At(3);\n    const int64_t elem_cnt = y_tensor->shape_view().elem_cnt();\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n\n    if (in_height == out_height && in_width == out_width) {\n      memcpy(y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),\n             sizeof(T) * nbatch * channels * in_height * in_width);\n    } else {\n      NdIndexOffsetHelper<int64_t, 4> in_helper(\n          x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2),\n          x_tensor->shape_view().At(3));\n      NdIndexOffsetHelper<int64_t, 4> out_helper(\n          y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2),\n          y_tensor->shape_view().At(3));\n      UpsampleNearest2DForward<T>(elem_cnt, x_tensor->dptr<T>(), in_helper, out_helper,\n                                  x_tensor->shape_view().At(2), x_tensor->shape_view().At(3),\n                                  1.f / height_scale, 1.f / width_scale, y_tensor->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass UpsampleNearest2DGradCPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleNearest2DGradCPUKernel() = default;\n  ~UpsampleNearest2DGradCPUKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    Memset<DeviceType::kCPU>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,\n                             dx_tensor->shape_view().elem_cnt() * sizeof(T));\n    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    const int64_t nbatch = dx_tensor->shape_view().At(0);\n    const int64_t channels = dx_tensor->shape_view().At(1);\n    const int64_t in_height = dx_tensor->shape_view().At(2);\n    const int64_t in_width = dx_tensor->shape_view().At(3);\n    const int64_t out_height = dy_tensor->shape_view().At(2);\n    const int64_t 
out_width = dy_tensor->shape_view().At(3);\n    const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt();\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n\n    if (in_height == out_height && in_width == out_width) {\n      memcpy(dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),\n             sizeof(T) * nbatch * channels * in_height * in_width);\n    } else {\n      NdIndexOffsetHelper<int64_t, 4> dy_helper(\n          dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1),\n          dy_tensor->shape_view().At(2), dy_tensor->shape_view().At(3));\n      NdIndexOffsetHelper<int64_t, 4> dx_helper(\n          dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1),\n          dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3));\n      UpsampleNearest2DBackward<T>(elem_cnt, dy_tensor->dptr<T>(), dy_helper, dx_helper,\n                                   dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3),\n                                   1.f / height_scale, 1.f / width_scale, dx_tensor->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UPSAMPLE_NEAREST_2D_CPU_KERNEL(dtype)                                  \\\n  REGISTER_USER_KERNEL(\"upsample_nearest_2d\")                                           \\\n      .SetCreateFn<UpsampleNearest2DCPUKernel<dtype>>()                                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"upsample_nearest_2d_grad\")                                      \\\n      .SetCreateFn<UpsampleNearest2DGradCPUKernel<dtype>>()                             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UPSAMPLE_NEAREST_2D_CPU_KERNEL(float)\nREGISTER_UPSAMPLE_NEAREST_2D_CPU_KERNEL(double)\n\ntemplate<typename T>\nclass UpsampleNearest3DCPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleNearest3DCPUKernel() = default;\n  ~UpsampleNearest3DCPUKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_blob = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y_blob = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double depth_scale = ctx->Attr<double>(\"depth_scale\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    const int64_t in_depth = x_blob->shape_view().At(2);\n    const int64_t in_height = x_blob->shape_view().At(3);\n    const int64_t in_width = x_blob->shape_view().At(4);\n    const int64_t out_depth = y_blob->shape_view().At(2);\n    const int64_t out_height = y_blob->shape_view().At(3);\n    const int64_t out_width = y_blob->shape_view().At(4);\n    const int64_t elem_cnt = y_blob->shape_view().elem_cnt();\n    if (!output_size.empty()) {\n      depth_scale = static_cast<double>(out_depth) / static_cast<double>(in_depth);\n      height_scale = static_cast<double>(out_height) / 
static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n    NdIndexOffsetHelper<int64_t, 5> in_helper(\n        x_blob->shape_view().At(0), x_blob->shape_view().At(1), x_blob->shape_view().At(2),\n        x_blob->shape_view().At(3), x_blob->shape_view().At(4));\n    NdIndexOffsetHelper<int64_t, 5> out_helper(\n        y_blob->shape_view().At(0), y_blob->shape_view().At(1), y_blob->shape_view().At(2),\n        y_blob->shape_view().At(3), y_blob->shape_view().At(4));\n    UpsampleNearest3DForward<T>(elem_cnt, x_blob->dptr<T>(), in_helper, out_helper,\n                                x_blob->shape_view().At(2), x_blob->shape_view().At(3),\n                                x_blob->shape_view().At(4), 1.f / depth_scale, 1.f / height_scale,\n                                1.f / width_scale, y_blob->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass UpsampleNearestGrad3DCPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleNearestGrad3DCPUKernel() = default;\n  ~UpsampleNearestGrad3DCPUKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* dx_blob = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    if (dx_blob == nullptr) { return; }\n    Memset<DeviceType::kCPU>(ctx->stream(), dx_blob->mut_dptr<T>(), 0,\n                             dx_blob->shape_view().elem_cnt() * sizeof(T));\n    const user_op::Tensor* dy_blob = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double depth_scale = ctx->Attr<double>(\"depth_scale\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    const int64_t in_depth = dx_blob->shape_view().At(2);\n    const int64_t in_height = dx_blob->shape_view().At(3);\n    const int64_t in_width = dx_blob->shape_view().At(4);\n    const int64_t out_depth = dy_blob->shape_view().At(2);\n    const int64_t out_height = dy_blob->shape_view().At(3);\n    const int64_t out_width = dy_blob->shape_view().At(4);\n    const int64_t elem_cnt = dy_blob->shape_view().elem_cnt();\n    if (!output_size.empty()) {\n      depth_scale = static_cast<double>(out_depth) / static_cast<double>(in_depth);\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n    NdIndexOffsetHelper<int64_t, 5> dy_helper(\n        dy_blob->shape_view().At(0), dy_blob->shape_view().At(1), dy_blob->shape_view().At(2),\n        dy_blob->shape_view().At(3), dy_blob->shape_view().At(4));\n    NdIndexOffsetHelper<int64_t, 5> dx_helper(\n        dx_blob->shape_view().At(0), dx_blob->shape_view().At(1), dx_blob->shape_view().At(2),\n        dx_blob->shape_view().At(3), dx_blob->shape_view().At(4));\n    UpsampleNearest3DBackward<T>(elem_cnt, dy_blob->dptr<T>(), dy_helper, dx_helper,\n                                 dx_blob->shape_view().At(2), dx_blob->shape_view().At(3),\n                                 dx_blob->shape_view().At(4), 1.f / depth_scale, 1.f / height_scale,\n                                 1.f / width_scale, dx_blob->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UPSAMPNEAREST3D_CPU_KERNEL(dtype)                                
      \\\n  REGISTER_USER_KERNEL(\"upsample_nearest_3d\")                                           \\\n      .SetCreateFn<UpsampleNearest3DCPUKernel<dtype>>()                                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"upsample_nearest_3d_grad\")                                      \\\n      .SetCreateFn<UpsampleNearestGrad3DCPUKernel<dtype>>()                             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UPSAMPNEAREST3D_CPU_KERNEL(float)\nREGISTER_UPSAMPNEAREST3D_CPU_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/upsample_nearest_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/kernel/kernel_util.cuh\"\n#include \"oneflow/user/kernels/upsample_kernel.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__global__ void UpsampleNearest1DForward(const int64_t elem_cnt, const T* in_dptr,\n                                         NdIndexOffsetHelper<int64_t, 3> in_helper,\n                                         NdIndexOffsetHelper<int64_t, 3> out_helper,\n                                         const int64_t in_height, const double scale_factor,\n                                         T* out_dptr) {\n  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {\n    int64_t n, c, h;\n    out_helper.OffsetToNdIndex(index, n, c, h);\n    const int64_t in_h = GetNearestInputIndex(h, scale_factor, in_height);\n    out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_h)];\n  }\n}\n\ntemplate<typename T>\n__global__ void UpsampleNearest1DBackward(const int64_t elem_cnt, const T* dy_dptr,\n                                          NdIndexOffsetHelper<int64_t, 3> dy_helper,\n                                          NdIndexOffsetHelper<int64_t, 3> dx_helper,\n                                          const int64_t in_height, const double scale_factor,\n                                          T* dx_dptr) {\n  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {\n    int64_t n, c, h;\n    dy_helper.OffsetToNdIndex(index, n, c, h);\n    const int64_t dx_h = GetNearestInputIndex(h, scale_factor, in_height);\n    cuda::atomic::FastAdd(dx_dptr, dx_helper.NdIndexToOffset(n, c, dx_h), elem_cnt,\n                          static_cast<T>(dy_dptr[index]));\n  }\n}\n\ntemplate<typename T>\n__global__ void UpsampleNearest2DForward(const int64_t elem_cnt, const T* in_dptr,\n                                         NdIndexOffsetHelper<int64_t, 4> in_helper,\n                                         NdIndexOffsetHelper<int64_t, 4> out_helper,\n                                         const int64_t in_height, const int64_t in_width,\n                                         const double scale_h, const double scale_w, T* out_dptr) {\n  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {\n    int64_t n, c, h, w;\n    out_helper.OffsetToNdIndex(index, n, c, h, w);\n    const int64_t in_h = GetNearestInputIndex(h, scale_h, in_height);\n    const int64_t in_w = GetNearestInputIndex(w, scale_w, in_width);\n    out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_h, in_w)];\n  }\n}\n\ntemplate<typename T>\nstruct alignas(2 * sizeof(T)) Pack2X {\n  T x;\n  T y;\n};\n\ntemplate<typename T>\n__global__ void UpsampleNearest2D2XForward(const int32_t in_elem_cnt, const T* in_dptr,\n                                           const int32_t in_height, const int32_t 
in_width,\n                                           T* out_dptr) {\n  const int32_t in_hw_size = in_width * in_height;\n  CUDA_1D_KERNEL_LOOP(index, in_elem_cnt) {\n    const T in_value = in_dptr[index];\n    const int32_t nc_idx = index / in_hw_size;\n    const int32_t hw_off = index - nc_idx * in_hw_size;\n    const int32_t h = hw_off / in_width;\n    const int32_t w = hw_off - h * in_width;\n    Pack2X<T> out_value{in_value, in_value};\n    Pack2X<T>* out_pack_dptr = reinterpret_cast<Pack2X<T>*>(out_dptr);\n    out_pack_dptr[nc_idx * in_hw_size * 2 + h * 2 * in_width + w] = out_value;\n    out_pack_dptr[nc_idx * in_hw_size * 2 + (h * 2 + 1) * in_width + w] = out_value;\n  }\n}\n\ntemplate<typename T>\n__global__ void UpsampleNearest2DBackward(const int64_t elem_cnt, const T* dy_dptr,\n                                          NdIndexOffsetHelper<int64_t, 4> dy_helper,\n                                          NdIndexOffsetHelper<int64_t, 4> dx_helper,\n                                          const int64_t dx_height, const int64_t dx_width,\n                                          const double scale_h, const double scale_w, T* dx_dptr) {\n  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {\n    int64_t n, c, h, w;\n    dy_helper.OffsetToNdIndex(index, n, c, h, w);\n    const int64_t dx_h = GetNearestInputIndex(h, scale_h, dx_height);\n    const int64_t dx_w = GetNearestInputIndex(w, scale_w, dx_width);\n    cuda::atomic::FastAdd(dx_dptr, dx_helper.NdIndexToOffset(n, c, dx_h, dx_w), elem_cnt,\n                          static_cast<T>(dy_dptr[index]));\n  }\n}\n\ntemplate<typename T>\n__global__ void UpsampleNearest2D2XBackward(const int32_t in_elem_cnt, const T* dy_dptr,\n                                            const int32_t dx_height, const int32_t dx_width,\n                                            T* dx_dptr) {\n  const int32_t dx_hw_size = dx_height * dx_width;\n  CUDA_1D_KERNEL_LOOP(index, in_elem_cnt) {\n    T dx_value = 0.0;\n    const int32_t nc_idx = index / dx_hw_size;\n    const int32_t dx_hw_off = index - nc_idx * dx_hw_size;\n    const int32_t dx_h = dx_hw_off / dx_width;\n    const int32_t dx_w = dx_hw_off - dx_h * dx_width;\n    const Pack2X<T>* dy_pack_dptr = reinterpret_cast<const Pack2X<T>*>(dy_dptr);\n    const Pack2X<T> dy_pack_value1 =\n        dy_pack_dptr[nc_idx * dx_hw_size * 2 + dx_h * 2 * dx_width + dx_w];\n    const Pack2X<T> dy_pack_value2 =\n        dy_pack_dptr[nc_idx * dx_hw_size * 2 + (dx_h * 2 + 1) * dx_width + dx_w];\n    dx_value += dy_pack_value1.x;\n    dx_value += dy_pack_value1.y;\n    dx_value += dy_pack_value2.x;\n    dx_value += dy_pack_value2.y;\n    dx_dptr[index] = dx_value;\n  }\n}\n\ntemplate<typename T>\n__global__ void UpsampleNearest3DForward(const int64_t elem_cnt, const T* in_dptr,\n                                         NdIndexOffsetHelper<int64_t, 5> in_helper,\n                                         NdIndexOffsetHelper<int64_t, 5> out_helper,\n                                         const int64_t in_depth, const int64_t in_height,\n                                         const int64_t in_width, const float scale_d,\n                                         const float scale_h, const float scale_w, T* out_dptr) {\n  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {\n    int64_t n, c, d, h, w;\n    out_helper.OffsetToNdIndex(index, n, c, d, h, w);\n    const int64_t in_h = GetNearestInputIndex(h, scale_h, in_height);\n    const int64_t in_w = GetNearestInputIndex(w, scale_w, in_width);\n    const int64_t in_d = GetNearestInputIndex(d, 
scale_d, in_depth);\n    out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_d, in_h, in_w)];\n  }\n}\n\ntemplate<typename T>\n__global__ void UpsampleNearest3DBackward(const int64_t elem_cnt, const T* dy_dptr,\n                                          NdIndexOffsetHelper<int64_t, 5> dy_helper,\n                                          NdIndexOffsetHelper<int64_t, 5> dx_helper,\n                                          const int64_t in_depth, const int64_t in_height,\n                                          const int64_t in_width, const float scale_d,\n                                          const float scale_h, const float scale_w, T* dx_dptr) {\n  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {\n    int64_t n, c, d, h, w;\n    dy_helper.OffsetToNdIndex(index, n, c, d, h, w);\n    const int64_t dx_h = GetNearestInputIndex(h, scale_h, in_height);\n    const int64_t dx_w = GetNearestInputIndex(w, scale_w, in_width);\n    const int64_t in_d = GetNearestInputIndex(d, scale_d, in_depth);\n    cuda::atomic::FastAdd(dx_dptr, dx_helper.NdIndexToOffset(n, c, in_d, dx_h, dx_w), elem_cnt,\n                          static_cast<T>(dy_dptr[index]));\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass UpsampleNearest1DGPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleNearest1DGPUKernel() = default;\n  ~UpsampleNearest1DGPUKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"scale_factor\");\n    const int64_t elem_cnt = y_tensor->shape_view().elem_cnt();\n    const int64_t in_height = x_tensor->shape_view().At(2);\n    const int64_t out_height = y_tensor->shape_view().At(2);\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n    }\n    if (in_height == out_height) {\n      Memcpy<DeviceType::kCUDA>(\n          ctx->stream(), y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),\n          x_tensor->shape_view().elem_cnt() * GetSizeOfDataType(x_tensor->data_type()));\n    } else {\n      NdIndexOffsetHelper<int64_t, 3> in_helper(\n          x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2));\n      NdIndexOffsetHelper<int64_t, 3> out_helper(\n          y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2));\n      RUN_CUDA_KERNEL((UpsampleNearest1DForward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                      x_tensor->dptr<T>(), in_helper, out_helper, x_tensor->shape_view().At(2),\n                      1.f / height_scale, y_tensor->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass UpsampleNearestGrad1DGPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleNearestGrad1DGPUKernel() = default;\n  ~UpsampleNearestGrad1DGPUKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const std::vector<int64_t> output_size = 
ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"scale_factor\");\n    const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt();\n    const int64_t in_height = dx_tensor->shape_view().At(2);\n    const int64_t out_height = dy_tensor->shape_view().At(2);\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n    }\n    if (in_height == out_height) {\n      Memcpy<DeviceType::kCUDA>(\n          ctx->stream(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),\n          dy_tensor->shape_view().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));\n    } else {\n      Memset<DeviceType::kCUDA>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,\n                                dx_tensor->shape_view().elem_cnt() * sizeof(T));\n      NdIndexOffsetHelper<int64_t, 3> dy_helper(dy_tensor->shape_view().At(0),\n                                                dy_tensor->shape_view().At(1),\n                                                dy_tensor->shape_view().At(2));\n      NdIndexOffsetHelper<int64_t, 3> dx_helper(dx_tensor->shape_view().At(0),\n                                                dx_tensor->shape_view().At(1),\n                                                dx_tensor->shape_view().At(2));\n      RUN_CUDA_KERNEL((UpsampleNearest1DBackward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                      dy_tensor->dptr<T>(), dy_helper, dx_helper, dx_tensor->shape_view().At(2),\n                      1.f / height_scale, dx_tensor->mut_dptr<T>());\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UPSAMPNEAREST1D_CUDA_KERNEL(dtype)                                     \\\n  REGISTER_USER_KERNEL(\"upsample_nearest_1d\")                                           \\\n      .SetCreateFn<UpsampleNearest1DGPUKernel<dtype>>()                                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"upsample_nearest_1d_grad\")                                      \\\n      .SetCreateFn<UpsampleNearestGrad1DGPUKernel<dtype>>()                             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UPSAMPNEAREST1D_CUDA_KERNEL(float)\nREGISTER_UPSAMPNEAREST1D_CUDA_KERNEL(double)\n\ntemplate<typename T>\nclass UpsampleNearest2DGPUKernel final : public user_op::OpKernel,\n                                         public user_op::CudaGraphSupport {\n public:\n  UpsampleNearest2DGPUKernel() = default;\n  ~UpsampleNearest2DGPUKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    const int64_t out_elem_cnt = y_tensor->shape_view().elem_cnt();\n    const int64_t in_elem_cnt = x_tensor->shape_view().elem_cnt();\n    const int64_t in_height = 
x_tensor->shape_view().At(2);\n    const int64_t in_width = x_tensor->shape_view().At(3);\n    const int64_t out_height = y_tensor->shape_view().At(2);\n    const int64_t out_width = y_tensor->shape_view().At(3);\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n\n    if (in_height == out_height && in_width == out_width) {\n      Memcpy<DeviceType::kCUDA>(\n          ctx->stream(), y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),\n          x_tensor->shape_view().elem_cnt() * GetSizeOfDataType(x_tensor->data_type()));\n    } else {\n      const int64_t n = x_tensor->shape_view().At(0);\n      const int64_t c = x_tensor->shape_view().At(1);\n      if (out_height == 2 * in_height && out_width == 2 * in_width && in_elem_cnt <= 1 << 29) {\n        RUN_CUDA_KERNEL(UpsampleNearest2D2XForward<T>, ctx->stream(), in_elem_cnt, in_elem_cnt,\n                        x_tensor->dptr<T>(), in_height, in_width, y_tensor->mut_dptr<T>());\n      } else {\n        NdIndexOffsetHelper<int64_t, 4> in_helper(n, c, in_height, in_width);\n        NdIndexOffsetHelper<int64_t, 4> out_helper(n, c, out_height, out_width);\n        RUN_CUDA_KERNEL((UpsampleNearest2DForward<T>), ctx->stream(), out_elem_cnt, out_elem_cnt,\n                        x_tensor->dptr<T>(), in_helper, out_helper, in_height, in_width,\n                        1.f / height_scale, 1.f / width_scale, y_tensor->mut_dptr<T>());\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass UpsampleNearest2DGradGPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleNearest2DGradGPUKernel() = default;\n  ~UpsampleNearest2DGradGPUKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt();\n    const int64_t in_elem_cnt = dx_tensor->shape_view().elem_cnt();\n    const int64_t in_height = dx_tensor->shape_view().At(2);\n    const int64_t in_width = dx_tensor->shape_view().At(3);\n    const int64_t out_height = dy_tensor->shape_view().At(2);\n    const int64_t out_width = dy_tensor->shape_view().At(3);\n    if (!output_size.empty()) {\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n    if (in_height == out_height && in_width == out_width) {\n      Memcpy<DeviceType::kCUDA>(\n          ctx->stream(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),\n          dy_tensor->shape_view().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));\n    } else {\n      if (out_height == 2 * in_height && out_width == 2 * in_width && in_elem_cnt <= 1 << 29) {\n        RUN_CUDA_KERNEL(UpsampleNearest2D2XBackward<T>, ctx->stream(), in_elem_cnt, in_elem_cnt,\n                        dy_tensor->dptr<T>(), dx_tensor->shape_view().At(2),\n                        
dx_tensor->shape_view().At(3), dx_tensor->mut_dptr<T>());\n      } else {\n        Memset<DeviceType::kCUDA>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,\n                                  dx_tensor->shape_view().elem_cnt() * sizeof(T));\n        NdIndexOffsetHelper<int64_t, 4> dy_helper(\n            dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1),\n            dy_tensor->shape_view().At(2), dy_tensor->shape_view().At(3));\n        NdIndexOffsetHelper<int64_t, 4> dx_helper(\n            dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1),\n            dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3));\n        RUN_CUDA_KERNEL((UpsampleNearest2DBackward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                        dy_tensor->dptr<T>(), dy_helper, dx_helper, dx_tensor->shape_view().At(2),\n                        dx_tensor->shape_view().At(3), 1.f / height_scale, 1.f / width_scale,\n                        dx_tensor->mut_dptr<T>());\n      }\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UPSAMPLE_NEAREST_2D_CUDA_KERNEL(dtype)                                 \\\n  REGISTER_USER_KERNEL(\"upsample_nearest_2d\")                                           \\\n      .SetCreateFn<UpsampleNearest2DGPUKernel<dtype>>()                                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"upsample_nearest_2d_grad\")                                      \\\n      .SetCreateFn<UpsampleNearest2DGradGPUKernel<dtype>>()                             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UPSAMPLE_NEAREST_2D_CUDA_KERNEL(float)\nREGISTER_UPSAMPLE_NEAREST_2D_CUDA_KERNEL(half)\nREGISTER_UPSAMPLE_NEAREST_2D_CUDA_KERNEL(double)\n\ntemplate<typename T>\nclass UpsampleNearest3DGPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleNearest3DGPUKernel() = default;\n  ~UpsampleNearest3DGPUKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double depth_scale = ctx->Attr<double>(\"depth_scale\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    const int64_t in_depth = x_tensor->shape_view().At(2);\n    const int64_t in_height = x_tensor->shape_view().At(3);\n    const int64_t in_width = x_tensor->shape_view().At(4);\n    const int64_t out_depth = y_tensor->shape_view().At(2);\n    const int64_t out_height = y_tensor->shape_view().At(3);\n    const int64_t out_width = y_tensor->shape_view().At(4);\n    const int64_t elem_cnt = y_tensor->shape_view().elem_cnt();\n    if (!output_size.empty()) {\n      depth_scale = static_cast<double>(out_depth) / static_cast<double>(in_depth);\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n    
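// Note (descriptive comment): unlike the 1D/2D nearest kernels above, this 3D forward has no\n    // same-shape Memcpy fast path; the gather kernel is always launched.\n    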
NdIndexOffsetHelper<int64_t, 5> in_helper(\n        x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2),\n        x_tensor->shape_view().At(3), x_tensor->shape_view().At(4));\n    NdIndexOffsetHelper<int64_t, 5> out_helper(\n        y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2),\n        y_tensor->shape_view().At(3), y_tensor->shape_view().At(4));\n    RUN_CUDA_KERNEL((UpsampleNearest3DForward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                    x_tensor->dptr<T>(), in_helper, out_helper, x_tensor->shape_view().At(2),\n                    x_tensor->shape_view().At(3), x_tensor->shape_view().At(4), 1.f / depth_scale,\n                    1.f / height_scale, 1.f / width_scale, y_tensor->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass UpsampleNearestGrad3DGPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleNearestGrad3DGPUKernel() = default;\n  ~UpsampleNearestGrad3DGPUKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    Memset<DeviceType::kCUDA>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,\n                              dx_tensor->shape_view().elem_cnt() * sizeof(T));\n    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double depth_scale = ctx->Attr<double>(\"depth_scale\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    const int64_t in_depth = dx_tensor->shape_view().At(2);\n    const int64_t in_height = dx_tensor->shape_view().At(3);\n    const int64_t in_width = dx_tensor->shape_view().At(4);\n    const int64_t out_depth = dy_tensor->shape_view().At(2);\n    const int64_t out_height = dy_tensor->shape_view().At(3);\n    const int64_t out_width = dy_tensor->shape_view().At(4);\n    const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt();\n    if (!output_size.empty()) {\n      depth_scale = static_cast<double>(out_depth) / static_cast<double>(in_depth);\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n    NdIndexOffsetHelper<int64_t, 5> dy_helper(\n        dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1), dy_tensor->shape_view().At(2),\n        dy_tensor->shape_view().At(3), dy_tensor->shape_view().At(4));\n    NdIndexOffsetHelper<int64_t, 5> dx_helper(\n        dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1), dx_tensor->shape_view().At(2),\n        dx_tensor->shape_view().At(3), dx_tensor->shape_view().At(4));\n    RUN_CUDA_KERNEL((UpsampleNearest3DBackward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                    dy_tensor->dptr<T>(), dy_helper, dx_helper, dx_tensor->shape_view().At(2),\n                    dx_tensor->shape_view().At(3), dx_tensor->shape_view().At(4), 1.f / depth_scale,\n                    1.f / height_scale, 1.f / width_scale, dx_tensor->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UPSAMPNEAREST3D_CUDA_KERNEL(dtype)                                     \\\n  
REGISTER_USER_KERNEL(\"upsample_nearest_3d\")                                           \\\n      .SetCreateFn<UpsampleNearest3DGPUKernel<dtype>>()                                 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"upsample_nearest_3d_grad\")                                      \\\n      .SetCreateFn<UpsampleNearestGrad3DGPUKernel<dtype>>()                             \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UPSAMPNEAREST3D_CUDA_KERNEL(float)\nREGISTER_UPSAMPNEAREST3D_CUDA_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/upsample_trilinear_3d_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/user/kernels/upsample_kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nstatic void UpsampleTrilinear3DForward(const int64_t elem_cnt, const T* in_dptr,\n                                       NdIndexOffsetHelper<int64_t, 5> in_helper,\n                                       NdIndexOffsetHelper<int64_t, 5> out_helper,\n                                       const int64_t in_depth, const int64_t in_height,\n                                       const int64_t in_width, const T rdepth, const T rheight,\n                                       const T rwidth, const bool align_corners, T* out_dptr) {\n  for (int64_t index = 0; index < elem_cnt; ++index) {\n    int64_t n, c, d, h, w;\n    out_helper.OffsetToNdIndex(index, n, c, d, h, w);\n\n    const T t1r = GetAreaPixel(rdepth, d, align_corners);\n    const int64_t t1 = t1r;\n    const int64_t t1p = (t1 < in_depth - 1) ? 1 : 0;\n    const T t1lambda = t1r - t1;\n    const T t0lambda = static_cast<T>(1.) - t1lambda;\n\n    const T h1r = GetAreaPixel(rheight, h, align_corners);\n    const int64_t h1 = h1r;\n    const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;\n    const T h1lambda = h1r - h1;\n    const T h0lambda = static_cast<T>(1.) - h1lambda;\n\n    const T w1r = GetAreaPixel(rwidth, w, align_corners);\n    const int64_t w1 = w1r;\n    const int64_t w1p = (w1 < in_width - 1) ? 1 : 0;\n    const T w1lambda = w1r - w1;\n    const T w0lambda = static_cast<T>(1.) 
- w1lambda;\n\n    const T* pos1 = &in_dptr[in_helper.NdIndexToOffset(n, c, t1, h1, w1)];\n\n    out_dptr[index] =\n        t0lambda\n            * (h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p])\n               + h1lambda\n                     * (w0lambda * pos1[h1p * in_width] + w1lambda * pos1[h1p * in_width + w1p]))\n        + t1lambda\n              * (h0lambda\n                     * (w0lambda * pos1[t1p * in_height * in_width]\n                        + w1lambda * pos1[t1p * in_height * in_width + w1p])\n                 + h1lambda\n                       * (w0lambda * pos1[t1p * in_height * in_width + h1p * in_width]\n                          + w1lambda * pos1[t1p * in_height * in_width + h1p * in_width + w1p]));\n  }\n}\n\ntemplate<typename T>\nstatic void UpsampleTrilinear3DBackward(const int64_t elem_cnt, const T* dy_dptr,\n                                        NdIndexOffsetHelper<int64_t, 5> dy_helper,\n                                        NdIndexOffsetHelper<int64_t, 5> dx_helper,\n                                        const int64_t in_depth, const int64_t in_height,\n                                        const int64_t in_width, const T rdepth, const T rheight,\n                                        const T rwidth, const bool align_corners, T* dx_dptr) {\n  for (int64_t index = 0; index < elem_cnt; ++index) {\n    int64_t n, c, d, h, w;\n    dy_helper.OffsetToNdIndex(index, n, c, d, h, w);\n\n    const T t1r = GetAreaPixel(rdepth, d, align_corners);\n    const int64_t t1 = t1r;\n    const int64_t t1p = (t1 < in_depth - 1) ? 1 : 0;\n    const T t1lambda = t1r - t1;\n    const T t0lambda = static_cast<T>(1.) - t1lambda;\n\n    const T h1r = GetAreaPixel(rheight, h, align_corners);\n    const int64_t h1 = h1r;\n    const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;\n    const T h1lambda = h1r - h1;\n    const T h0lambda = static_cast<T>(1.) - h1lambda;\n\n    const T w1r = GetAreaPixel(rwidth, w, align_corners);\n    const int64_t w1 = w1r;\n    const int64_t w1p = (w1 < in_width - 1) ? 1 : 0;\n    const T w1lambda = w1r - w1;\n    const T w0lambda = static_cast<T>(1.) 
- w1lambda;\n\n    T* pos1 = &dx_dptr[dx_helper.NdIndexToOffset(n, c, t1, h1, w1)];\n    const T* pos2 = &dy_dptr[index];\n\n    pos1[0] += t0lambda * h0lambda * w0lambda * pos2[0];\n    pos1[w1p] += t0lambda * h0lambda * w1lambda * pos2[0];\n    pos1[h1p * in_width] += t0lambda * h1lambda * w0lambda * pos2[0];\n    pos1[h1p * in_width + w1p] += t0lambda * h1lambda * w1lambda * pos2[0];\n    pos1[t1p * in_height * in_width] += t1lambda * h0lambda * w0lambda * pos2[0];\n    pos1[t1p * in_height * in_width + w1p] += t1lambda * h0lambda * w1lambda * pos2[0];\n    pos1[t1p * in_height * in_width + h1p * in_width] += t1lambda * h1lambda * w0lambda * pos2[0];\n    pos1[t1p * in_height * in_width + h1p * in_width + w1p] +=\n        t1lambda * h1lambda * w1lambda * pos2[0];\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass UpsampleTrilinear3DCPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleTrilinear3DCPUKernel() = default;\n  ~UpsampleTrilinear3DCPUKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n    const int64_t elem_cnt = y_tensor->shape_view().elem_cnt();\n    NdIndexOffsetHelper<int64_t, 5> in_helper(\n        x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2),\n        x_tensor->shape_view().At(3), x_tensor->shape_view().At(4));\n    NdIndexOffsetHelper<int64_t, 5> out_helper(\n        y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2),\n        y_tensor->shape_view().At(3), y_tensor->shape_view().At(4));\n\n    const int64_t in_depth = x_tensor->shape_view().At(2);\n    const int64_t in_height = x_tensor->shape_view().At(3);\n    const int64_t in_width = x_tensor->shape_view().At(4);\n\n    const int64_t out_depth = y_tensor->shape_view().At(2);\n    const int64_t out_height = y_tensor->shape_view().At(3);\n    const int64_t out_width = y_tensor->shape_view().At(4);\n\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double depth_scale = ctx->Attr<double>(\"depth_scale\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    if (!output_size.empty()) {\n      depth_scale = static_cast<double>(out_depth) / static_cast<double>(in_depth);\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n\n    const T scale_depth = GetAreaPixelScale(in_depth, out_depth, align_corners, depth_scale);\n    const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);\n    const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);\n\n    UpsampleTrilinear3DForward<T>(elem_cnt, x_tensor->dptr<T>(), in_helper, out_helper,\n                                  x_tensor->shape_view().At(2), x_tensor->shape_view().At(3),\n                                  x_tensor->shape_view().At(4), scale_depth, scale_height,\n                                  scale_width, align_corners, y_tensor->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass 
UpsampleTrilinearGrad3DCPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleTrilinearGrad3DCPUKernel() = default;\n  ~UpsampleTrilinearGrad3DCPUKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    Memset<DeviceType::kCPU>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,\n                             dx_tensor->shape_view().elem_cnt() * sizeof(T));\n    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n    const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt();\n    NdIndexOffsetHelper<int64_t, 5> dy_helper(\n        dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1), dy_tensor->shape_view().At(2),\n        dy_tensor->shape_view().At(3), dy_tensor->shape_view().At(4));\n    NdIndexOffsetHelper<int64_t, 5> dx_helper(\n        dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1), dx_tensor->shape_view().At(2),\n        dx_tensor->shape_view().At(3), dx_tensor->shape_view().At(4));\n\n    const int64_t in_depth = dx_tensor->shape_view().At(2);\n    const int64_t in_height = dx_tensor->shape_view().At(3);\n    const int64_t in_width = dx_tensor->shape_view().At(4);\n\n    const int64_t out_depth = dy_tensor->shape_view().At(2);\n    const int64_t out_height = dy_tensor->shape_view().At(3);\n    const int64_t out_width = dy_tensor->shape_view().At(4);\n\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double depth_scale = ctx->Attr<double>(\"depth_scale\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    if (!output_size.empty()) {\n      depth_scale = static_cast<double>(out_depth) / static_cast<double>(in_depth);\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n\n    const T scale_depth = GetAreaPixelScale(in_depth, out_depth, align_corners, depth_scale);\n    const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);\n    const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);\n\n    UpsampleTrilinear3DBackward<T>(elem_cnt, dy_tensor->dptr<T>(), dy_helper, dx_helper,\n                                   dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3),\n                                   dx_tensor->shape_view().At(4), scale_depth, scale_height,\n                                   scale_width, align_corners, dx_tensor->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UPSAMPTRILINEAR3D_CPU_KERNEL(dtype)                                    \\\n  REGISTER_USER_KERNEL(\"upsample_trilinear_3d\")                                         \\\n      .SetCreateFn<UpsampleTrilinear3DCPUKernel<dtype>>()                               \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"upsample_trilinear_3d_grad\")                                    \\\n      .SetCreateFn<UpsampleTrilinearGrad3DCPUKernel<dtype>>()                           \\\n      
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                   \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UPSAMPTRILINEAR3D_CPU_KERNEL(float)\nREGISTER_UPSAMPTRILINEAR3D_CPU_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/upsample_trilinear_3d_kernel.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/common/nd_index_offset_helper.h\"\n#include \"oneflow/core/cuda/atomic.cuh\"\n#include \"oneflow/user/kernels/upsample_kernel.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\n__global__ void UpsampleTrilinear3DForward(const int64_t elem_cnt, const T* in_dptr,\n                                           NdIndexOffsetHelper<int64_t, 5> in_helper,\n                                           NdIndexOffsetHelper<int64_t, 5> out_helper,\n                                           const int64_t in_depth, const int64_t in_height,\n                                           const int64_t in_width, const T rdepth, const T rheight,\n                                           const T rwidth, const bool align_corners, T* out_dptr) {\n  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {\n    int64_t n, c, d, h, w;\n    out_helper.OffsetToNdIndex(index, n, c, d, h, w);\n\n    const T t1r = GetAreaPixel(rdepth, d, align_corners);\n    const int64_t t1 = t1r;\n    const int64_t t1p = (t1 < in_depth - 1) ? 1 : 0;\n    const T t1lambda = t1r - t1;\n    const T t0lambda = static_cast<T>(1.) - t1lambda;\n\n    const T h1r = GetAreaPixel(rheight, h, align_corners);\n    const int64_t h1 = h1r;\n    const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;\n    const T h1lambda = h1r - h1;\n    const T h0lambda = static_cast<T>(1.) - h1lambda;\n\n    const T w1r = GetAreaPixel(rwidth, w, align_corners);\n    const int64_t w1 = w1r;\n    const int64_t w1p = (w1 < in_width - 1) ? 1 : 0;\n    const T w1lambda = w1r - w1;\n    const T w0lambda = static_cast<T>(1.) 
- w1lambda;\n\n    const T* pos1 = &in_dptr[in_helper.NdIndexToOffset(n, c, t1, h1, w1)];\n\n    out_dptr[index] =\n        t0lambda\n            * (h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p])\n               + h1lambda\n                     * (w0lambda * pos1[h1p * in_width] + w1lambda * pos1[h1p * in_width + w1p]))\n        + t1lambda\n              * (h0lambda\n                     * (w0lambda * pos1[t1p * in_height * in_width]\n                        + w1lambda * pos1[t1p * in_height * in_width + w1p])\n                 + h1lambda\n                       * (w0lambda * pos1[t1p * in_height * in_width + h1p * in_width]\n                          + w1lambda * pos1[t1p * in_height * in_width + h1p * in_width + w1p]));\n  }\n}\n\ntemplate<typename T>\n__global__ void UpsampleTrilinear3DBackward(const int64_t elem_cnt, const T* dy_dptr,\n                                            NdIndexOffsetHelper<int64_t, 5> dy_helper,\n                                            NdIndexOffsetHelper<int64_t, 5> dx_helper,\n                                            const int64_t in_depth, const int64_t in_height,\n                                            const int64_t in_width, const T rdepth, const T rheight,\n                                            const T rwidth, const bool align_corners, T* dx_dptr) {\n  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {\n    int64_t n, c, d, h, w;\n    dy_helper.OffsetToNdIndex(index, n, c, d, h, w);\n\n    const T t1r = GetAreaPixel(rdepth, d, align_corners);\n    const int64_t t1 = t1r;\n    const int64_t t1p = (t1 < in_depth - 1) ? 1 : 0;\n    const T t1lambda = t1r - t1;\n    const T t0lambda = static_cast<T>(1.) - t1lambda;\n\n    const T h1r = GetAreaPixel(rheight, h, align_corners);\n    const int64_t h1 = h1r;\n    const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;\n    const T h1lambda = h1r - h1;\n    const T h0lambda = static_cast<T>(1.) - h1lambda;\n\n    const T w1r = GetAreaPixel(rwidth, w, align_corners);\n    const int64_t w1 = w1r;\n    const int64_t w1p = (w1 < in_width - 1) ? 1 : 0;\n    const T w1lambda = w1r - w1;\n    const T w0lambda = static_cast<T>(1.) 
- w1lambda;\n\n    T* pos1 = &dx_dptr[dx_helper.NdIndexToOffset(n, c, t1, h1, w1)];\n    const T* pos2 = &dy_dptr[index];\n\n    cuda::atomic::FastAdd(pos1, 0, elem_cnt, t0lambda * h0lambda * w0lambda * pos2[0]);\n    cuda::atomic::FastAdd(pos1, w1p, elem_cnt, t0lambda * h0lambda * w1lambda * pos2[0]);\n    cuda::atomic::FastAdd(pos1, h1p * in_width, elem_cnt, t0lambda * h1lambda * w0lambda * pos2[0]);\n    cuda::atomic::FastAdd(pos1, h1p * in_width + w1p, elem_cnt,\n                          t0lambda * h1lambda * w1lambda * pos2[0]);\n    cuda::atomic::FastAdd(pos1, t1p * in_height * in_width, elem_cnt,\n                          t1lambda * h0lambda * w0lambda * pos2[0]);\n    cuda::atomic::FastAdd(pos1, t1p * in_height * in_width + w1p, elem_cnt,\n                          t1lambda * h0lambda * w1lambda * pos2[0]);\n    cuda::atomic::FastAdd(pos1, t1p * in_height * in_width + h1p * in_width, elem_cnt,\n                          t1lambda * h1lambda * w0lambda * pos2[0]);\n    cuda::atomic::FastAdd(pos1, t1p * in_height * in_width + h1p * in_width + w1p, elem_cnt,\n                          t1lambda * h1lambda * w1lambda * pos2[0]);\n  }\n}\n\n}  // namespace\n\ntemplate<typename T>\nclass UpsampleTrilinear3DGPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleTrilinear3DGPUKernel() = default;\n  ~UpsampleTrilinear3DGPUKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n    const int64_t elem_cnt = y_tensor->shape_view().elem_cnt();\n    NdIndexOffsetHelper<int64_t, 5> in_helper(\n        x_tensor->shape_view().At(0), x_tensor->shape_view().At(1), x_tensor->shape_view().At(2),\n        x_tensor->shape_view().At(3), x_tensor->shape_view().At(4));\n    NdIndexOffsetHelper<int64_t, 5> out_helper(\n        y_tensor->shape_view().At(0), y_tensor->shape_view().At(1), y_tensor->shape_view().At(2),\n        y_tensor->shape_view().At(3), y_tensor->shape_view().At(4));\n\n    const int64_t in_depth = x_tensor->shape_view().At(2);\n    const int64_t in_height = x_tensor->shape_view().At(3);\n    const int64_t in_width = x_tensor->shape_view().At(4);\n\n    const int64_t out_depth = y_tensor->shape_view().At(2);\n    const int64_t out_height = y_tensor->shape_view().At(3);\n    const int64_t out_width = y_tensor->shape_view().At(4);\n\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double depth_scale = ctx->Attr<double>(\"depth_scale\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    if (!output_size.empty()) {\n      depth_scale = static_cast<double>(out_depth) / static_cast<double>(in_depth);\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n\n    const T scale_depth = GetAreaPixelScale(in_depth, out_depth, align_corners, depth_scale);\n    const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);\n    const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);\n\n    RUN_CUDA_KERNEL((UpsampleTrilinear3DForward<T>), ctx->stream(), elem_cnt, elem_cnt,\n        
            x_tensor->dptr<T>(), in_helper, out_helper, x_tensor->shape_view().At(2),\n                    x_tensor->shape_view().At(3), x_tensor->shape_view().At(4), scale_depth,\n                    scale_height, scale_width, align_corners, y_tensor->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\ntemplate<typename T>\nclass UpsampleTrilinearGrad3DGPUKernel final : public user_op::OpKernel {\n public:\n  UpsampleTrilinearGrad3DGPUKernel() = default;\n  ~UpsampleTrilinearGrad3DGPUKernel() = default;\n\n private:\n  using user_op::OpKernel::Compute;\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n\n    Memset<DeviceType::kCUDA>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,\n                              dx_tensor->shape_view().elem_cnt() * sizeof(T));\n    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const bool align_corners = ctx->Attr<bool>(\"align_corners\");\n    const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt();\n    NdIndexOffsetHelper<int64_t, 5> dy_helper(\n        dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1), dy_tensor->shape_view().At(2),\n        dy_tensor->shape_view().At(3), dy_tensor->shape_view().At(4));\n    NdIndexOffsetHelper<int64_t, 5> dx_helper(\n        dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1), dx_tensor->shape_view().At(2),\n        dx_tensor->shape_view().At(3), dx_tensor->shape_view().At(4));\n\n    const int64_t in_depth = dx_tensor->shape_view().At(2);\n    const int64_t in_height = dx_tensor->shape_view().At(3);\n    const int64_t in_width = dx_tensor->shape_view().At(4);\n\n    const int64_t out_depth = dy_tensor->shape_view().At(2);\n    const int64_t out_height = dy_tensor->shape_view().At(3);\n    const int64_t out_width = dy_tensor->shape_view().At(4);\n\n    const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    double depth_scale = ctx->Attr<double>(\"depth_scale\");\n    double height_scale = ctx->Attr<double>(\"height_scale\");\n    double width_scale = ctx->Attr<double>(\"width_scale\");\n    if (!output_size.empty()) {\n      depth_scale = static_cast<double>(out_depth) / static_cast<double>(in_depth);\n      height_scale = static_cast<double>(out_height) / static_cast<double>(in_height);\n      width_scale = static_cast<double>(out_width) / static_cast<double>(in_width);\n    }\n\n    const T scale_depth = GetAreaPixelScale(in_depth, out_depth, align_corners, depth_scale);\n    const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);\n    const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);\n\n    RUN_CUDA_KERNEL((UpsampleTrilinear3DBackward<T>), ctx->stream(), elem_cnt, elem_cnt,\n                    dy_tensor->dptr<T>(), dy_helper, dx_helper, dx_tensor->shape_view().At(2),\n                    dx_tensor->shape_view().At(3), dx_tensor->shape_view().At(4), scale_depth,\n                    scale_height, scale_width, align_corners, dx_tensor->mut_dptr<T>());\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_UPSAMPTRILINEAR3D_CUDA_KERNEL(dtype)                                   \\\n  REGISTER_USER_KERNEL(\"upsample_trilinear_3d\")                                         \\\n      .SetCreateFn<UpsampleTrilinear3DGPUKernel<dtype>>()                              
 \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"y\", 0) == GetDataType<dtype>::value)); \\\n  REGISTER_USER_KERNEL(\"upsample_trilinear_3d_grad\")                                    \\\n      .SetCreateFn<UpsampleTrilinearGrad3DGPUKernel<dtype>>()                           \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                  \\\n                       && (user_op::HobDataType(\"dx\", 0) == GetDataType<dtype>::value));\n\nREGISTER_UPSAMPTRILINEAR3D_CUDA_KERNEL(float)\nREGISTER_UPSAMPTRILINEAR3D_CUDA_KERNEL(double)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/util_ops_kernels.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/elementwise_primitive_kernel.h\"\n\nnamespace oneflow {\nnamespace user_op {\n#define UTIL_OPS_SEQ                                            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"isinf\", ep::primitive::UnaryOp::kIsInf) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"isnan\", ep::primitive::UnaryOp::kIsNan) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"isfinite\", ep::primitive::UnaryOp::kIsFinite)\n\n#define RISTER_UTIL_OPS(op_name, op_kind)                                                 \\\n  REGISTER_USER_KERNEL(op_name)                                                           \\\n      .SetCreateFn([]() {                                                                 \\\n        return user_op::NewOpKernel<UnaryPrimitiveKernel>(                                \\\n            \"out\", \"in\", [](user_op::KernelComputeContext* ctx) {                         \\\n              const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);  \\\n              const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0); \\\n              return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>( \\\n                  ctx->device_type(), op_kind, src->data_type(), dst->data_type());       \\\n            });                                                                           \\\n      })                                                                                  \\\n      .SetIsMatchedHob(UnaryPrimitiveExists(op_kind, \"out\", \"in\"));\n\nOF_PP_FOR_EACH_TUPLE(RISTER_UTIL_OPS, UTIL_OPS_SEQ)\n#undef RISTER_UTIL_OPS\n#undef UTIL_OPS_SEQ\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/variance_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/ndarray_reduce.h\"\n#include \"oneflow/core/ndarray/ndarray_util.h\"\n#include \"oneflow/user/kernels/variance_kernel_util.h\"\n\nnamespace oneflow {\nnamespace user_op {\n\ntemplate<DeviceType device_type, typename T, typename ComputeType>\nclass VarKernel final : public user_op::OpKernel {\n public:\n  VarKernel() = default;\n  ~VarKernel() override = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex(\"input\", 0);\n    user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex(\"output\", 0);\n    const bool unbiased = ctx->Attr<bool>(\"unbiased\");\n    const T* in_ptr = input->dptr<T>();\n    T* out_ptr = output->mut_dptr<T>();\n    const std::vector<int32_t> axis = ctx->Attr<std::vector<int32_t>>(\"dim\");\n    const int64_t input_dim_element = input->shape_view().elem_cnt();\n    int64_t axis_dim_element = 1;\n    for (int64_t i = 0; i < axis.size(); ++i) {\n      axis_dim_element *= input->shape_view().At(axis[i]);\n    }\n    // when computing the variance with all the elements, the implementation of cuda kernel may use\n    // tmp buffer for computation.\n    ComputeType* tmp_buffer_ptr =\n        (input_dim_element > 0\n         && (axis.size() == input->shape_view().NumAxes() || input_dim_element == axis_dim_element)\n         && DeviceType::kCUDA == device_type)\n            ? 
ctx->Tensor4ArgNameAndIndex(\"tmp_buffer\", 0)->mut_dptr<ComputeType>()\n            : nullptr;\n    VarParamHelper param_helper(input->shape_view(), axis, unbiased);\n    VarFunctor<device_type, T, ComputeType>()(ctx->stream(), in_ptr, out_ptr, tmp_buffer_ptr,\n                                              param_helper.param);\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\n#define REGISTER_VAR_CPU_KERNEL(dtype, compute_type)                   \\\n  REGISTER_USER_KERNEL(\"var\")                                          \\\n      .SetCreateFn<VarKernel<DeviceType::kCPU, dtype, compute_type>>() \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)  \\\n                       && (user_op::HobAttr<DataType>(\"dtype\") == GetDataType<dtype>::value));\nREGISTER_VAR_CPU_KERNEL(float, double)\nREGISTER_VAR_CPU_KERNEL(double, double)\nREGISTER_VAR_CPU_KERNEL(float16, double)\nREGISTER_VAR_CPU_KERNEL(bfloat16, double)\n#undef REGISTER_VAR_CPU_KERNEL\n\n#ifdef WITH_CUDA\n\ntemplate<typename ComputeType>\nsize_t InferTmpBufferSize(user_op::InferContext* ctx) {\n  const TensorDesc& input = ctx->InputTensorDesc(\"input\", 0);\n  const Shape& input_shape = input.shape();\n  const std::vector<int32_t> axis = ctx->Attr<std::vector<int32_t>>(\"dim\");\n  const int64_t input_dim_element = input.shape().elem_cnt();\n  int64_t axis_dim_element = 1;\n  for (int64_t i = 0; i < axis.size(); ++i) { axis_dim_element *= input.shape().At(axis[i]); }\n  if (input_dim_element > 0\n      && (axis.size() == input_shape.NumAxes() || input_dim_element == axis_dim_element)) {\n    return GetCudaAlignedSize(\n        std::min(static_cast<int32_t>(std::ceil(std::sqrt(input.shape().elem_cnt()))),\n                 kCudaMaxBlocksNum)\n        * sizeof(ComputeType) * 3);\n  }\n  return 0;\n}\n\n#define REGISTER_VAR_CUDA_KERNEL(dtype, compute_type)                                         \\\n  REGISTER_USER_KERNEL(\"var\")                                                                 \\\n      .SetCreateFn<VarKernel<DeviceType::kCUDA, dtype, compute_type>>()                       \\\n      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                        \\\n                       && (user_op::HobAttr<DataType>(\"dtype\") == GetDataType<dtype>::value)) \\\n      .SetInferTmpSizeFn(InferTmpBufferSize<compute_type>);\n\nREGISTER_VAR_CUDA_KERNEL(float, double)\nREGISTER_VAR_CUDA_KERNEL(double, double)\nREGISTER_VAR_CUDA_KERNEL(half, double)\n#if CUDA_VERSION >= 11000\nREGISTER_VAR_CUDA_KERNEL(nv_bfloat16, double)\n#endif  // CUDA_VERSION >= 11000\n#undef REGISTER_VAR_CUDA_KERNEL\n#endif  // WITH_CUDA\n\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/variance_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/variance_kernel_util.h\"\n\nnamespace oneflow {\n\nnamespace user_op {\n\ntemplate<typename T, typename ComputeType>\nstruct VarFunctor<DeviceType::kCPU, T, ComputeType> final {\n  void operator()(ep::Stream* stream, const T* in_ptr, T* out_ptr, ComputeType* tmp_buffer_ptr,\n                  const VarParam var_param) {\n    // if var_param.parallel_num is 0, do nothing, return 0-size tensor\n    if (IsNanOut(var_param)) {\n      for (size_t i = 0; i < var_param.parallel_num; i++) {\n        out_ptr[i] = std::numeric_limits<T>::quiet_NaN();\n      }\n    } else {\n      for (size_t i = 0; i < var_param.parallel_num; i++) {\n        const size_t input_offset = LinearIndex2Offset(\n            i, var_param.dim_size_in_caxis, var_param.stride_in_caxis, var_param.caxis_size);\n        ComputeVarUsingWelford<T, ComputeType>(&in_ptr[input_offset], &out_ptr[i], var_param);\n      }\n    }\n  }\n};\n\ntemplate struct VarFunctor<DeviceType::kCPU, float, double>;\ntemplate struct VarFunctor<DeviceType::kCPU, double, double>;\ntemplate struct VarFunctor<DeviceType::kCPU, float16, double>;\ntemplate struct VarFunctor<DeviceType::kCPU, bfloat16, double>;\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/variance_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/variance_kernel_util.h\"\n#include \"oneflow/core/cuda/layer_norm.cuh\"\n\nnamespace oneflow {\nnamespace user_op {\n\nnamespace {\ntemplate<typename T>\n__inline__ __device__ T Nan();\n\ntemplate<>\n__inline__ __device__ float Nan<float>() {\n  return CUDART_NAN_F;\n}\n\ntemplate<>\n__inline__ __device__ double Nan<double>() {\n  return CUDART_NAN;\n}\n\ntemplate<>\n__inline__ __device__ half Nan<half>() {\n  return half(CUDART_NAN_F);\n}\n\n#if CUDA_VERSION >= 11000\ntemplate<>\n__inline__ __device__ nv_bfloat16 Nan<nv_bfloat16>() {\n  return nv_bfloat16(CUDART_NAN_F);\n}\n#endif\n}  // namespace\n\ntemplate<typename T, typename ComputeType>\n__global__ void ComputeVarUsingWelfordWrapper(const T* in_ptr, T* out_ptr, const VarParam var_param,\n                                              bool is_nan) {\n  if (is_nan) {\n    CUDA_1D_KERNEL_LOOP(i, var_param.parallel_num) { out_ptr[i] = Nan<T>(); }\n  } else {\n    CUDA_1D_KERNEL_LOOP(i, var_param.parallel_num) {\n      const size_t input_offset = LinearIndex2Offset(\n          i, var_param.dim_size_in_caxis, var_param.stride_in_caxis, var_param.caxis_size);\n      ComputeVarUsingWelford<T, ComputeType>(&in_ptr[input_offset], &out_ptr[i], var_param);\n    }\n  }\n}\n\nnamespace {\ntemplate<typename T, typename ComputeType>\ninline __device__ void WelfordReduce(const T* in_ptr, ComputeType* mean, ComputeType* m2,\n                                     ComputeType* count, const size_t total_elem_cnt,\n                                     const size_t start, const size_t step) {\n  ComputeType old_mean = 0.0;\n  for (size_t i = start; i < total_elem_cnt; i += step) {\n    ++(*count);\n    old_mean = *mean;\n    *mean += (static_cast<ComputeType>(in_ptr[i]) - *mean) / *count;\n    *m2 += (static_cast<ComputeType>(in_ptr[i]) - *mean)\n           * (static_cast<ComputeType>(in_ptr[i]) - old_mean);\n  }\n}\n\ntemplate<typename T>\ninline __device__ void WelfordCombine(const T* b_mean, const T* b_m2, const T* b_count, T* mean,\n                                      T* m2, T* count, const size_t total_elem_cnt,\n                                      const size_t start, const size_t step) {\n  for (size_t i = start; i < total_elem_cnt; i += step) {\n    cuda::layer_norm::WelfordCombine(b_mean[i], b_m2[i], b_count[i], mean, m2, count);\n  }\n}\n__device__ int32_t done_block_count = 0;\n}  // namespace\n\ntemplate<typename T, typename ComputeType>\n__global__ void ComputeVarScalarOut(const T* in_ptr, T* out_ptr, ComputeType* tmp_buffer_ptr,\n                                    const VarParam var_param, bool is_nan) {\n  if (is_nan) {\n    if (blockIdx.x == 0 && threadIdx.x == 0) { *out_ptr = Nan<T>(); }\n    return;\n  }\n  const size_t elems_per_block = var_param.elem_cnt / gridDim.x;\n  const size_t elems_per_thread = elems_per_block / blockDim.x;\n  // tail element number in block\n  size_t tail_elems = 
elems_per_block % blockDim.x;\n\n  ComputeType thread_mean = 0.0;\n  ComputeType thread_m2 = 0.0;\n  ComputeType thread_count = 0.0;\n  // each thread handles its own elements\n  if (elems_per_thread > 0) {\n    const size_t block_offset = blockIdx.x * elems_per_block;\n    WelfordReduce<T, ComputeType>(&in_ptr[block_offset], &thread_mean, &thread_m2, &thread_count,\n                                  elems_per_block - tail_elems, threadIdx.x, blockDim.x);\n  }\n  // thread 0 of the last block also handles the tail elements between blocks\n  if (blockIdx.x == gridDim.x - 1 && threadIdx.x == 0) {\n    tail_elems += var_param.elem_cnt % gridDim.x;\n  }\n  // thread 0 handles the tail elements\n  if (tail_elems != 0 && threadIdx.x == 0) {\n    const size_t tail_offset = blockIdx.x * elems_per_block + blockDim.x * elems_per_thread;\n    WelfordReduce<T, ComputeType>(&in_ptr[tail_offset], &thread_mean, &thread_m2, &thread_count,\n                                  tail_elems,\n                                  /*tail start=*/0, /*step=*/1);\n  }\n\n  ComputeType block_mean = 0;\n  ComputeType block_m2 = 0;\n  ComputeType block_count = 0;\n  cuda::layer_norm::WelfordBlockAllReduce<ComputeType>(thread_mean, thread_m2, thread_count,\n                                                       &block_mean, &block_m2, &block_count);\n\n  if (gridDim.x == 1) {\n    if (threadIdx.x == 0) {\n      *out_ptr =\n          cuda::layer_norm::Div(block_m2, (var_param.unbiased ? block_count - 1 : block_count));\n    }\n    return;\n  }\n\n  ComputeType* tmp_mean_ptr = tmp_buffer_ptr;\n  ComputeType* tmp_m2_ptr = &tmp_mean_ptr[gridDim.x];\n  ComputeType* tmp_count_ptr = &tmp_m2_ptr[gridDim.x];\n  if (threadIdx.x == 0) {\n    tmp_mean_ptr[blockIdx.x] = block_mean;\n    tmp_m2_ptr[blockIdx.x] = block_m2;\n    tmp_count_ptr[blockIdx.x] = block_count;\n  }\n  __shared__ bool is_last_block;\n  if (threadIdx.x == 0) { is_last_block = atomicAdd(&done_block_count, 1) == gridDim.x - 1; }\n  __syncthreads();\n  if (is_last_block) {\n    ComputeType last_block_thread_mean = 0;\n    ComputeType last_block_thread_m2 = 0;\n    ComputeType last_block_thread_count = 0;\n    const size_t welforddatas_per_thread = gridDim.x / blockDim.x;\n    const size_t tail_welforddatas = gridDim.x % blockDim.x;\n\n    if (welforddatas_per_thread > 0) {\n      WelfordCombine(tmp_mean_ptr, tmp_m2_ptr, tmp_count_ptr, &last_block_thread_mean,\n                     &last_block_thread_m2, &last_block_thread_count, gridDim.x - tail_welforddatas,\n                     threadIdx.x, blockDim.x);\n    }\n    // thread 0 handles the tail Welford data\n    if (tail_welforddatas != 0 && threadIdx.x == 0) {\n      const size_t last_block_tail_offset = blockDim.x * welforddatas_per_thread;\n      WelfordCombine(&tmp_mean_ptr[last_block_tail_offset], &tmp_m2_ptr[last_block_tail_offset],\n                     &tmp_count_ptr[last_block_tail_offset], &last_block_thread_mean,\n                     &last_block_thread_m2, &last_block_thread_count, tail_welforddatas,\n                     /*tail start=*/0, /*step=*/1);\n    }\n    ComputeType final_mean = 0;\n    ComputeType final_m2 = 0;\n    ComputeType final_count = 0;\n    cuda::layer_norm::WelfordBlockAllReduce<ComputeType>(\n        last_block_thread_mean, last_block_thread_m2, last_block_thread_count, &final_mean,\n        &final_m2, &final_count);\n    if (threadIdx.x == 0) {\n      *out_ptr =\n          cuda::layer_norm::Div(final_m2, (var_param.unbiased ? 
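/* Bessel's correction */ 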
final_count - 1 : final_count));\n      done_block_count = 0;\n    }\n  }\n}\n\ntemplate<typename T, typename ComputeType>\nstruct VarFunctor<DeviceType::kCUDA, T, ComputeType> final {\n  void operator()(ep::Stream* stream, const T* in_ptr, T* out_ptr, ComputeType* tmp_buffer_ptr,\n                  const VarParam var_param) {\n    int grid_dim = 0;\n    int block_dim = 0;\n    SetGridDimAndBlockDim(var_param.elem_cnt, &grid_dim, &block_dim);\n    if (var_param.parallel_num == 1) {\n      ComputeVarScalarOut<T, ComputeType>\n          <<<grid_dim, block_dim, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(\n              in_ptr, out_ptr, tmp_buffer_ptr, var_param, IsNanOut(var_param));\n    } else {\n      // if var_param.parallel_num is 0, do nothing; the output is a 0-size tensor\n      if (var_param.parallel_num == 0) { return; }\n      RUN_CUDA_KERNEL((ComputeVarUsingWelfordWrapper<T, ComputeType>), stream,\n                      var_param.parallel_num, in_ptr, out_ptr, var_param, IsNanOut(var_param));\n    }\n  }\n};\n\ntemplate struct VarFunctor<DeviceType::kCUDA, float, double>;\ntemplate struct VarFunctor<DeviceType::kCUDA, double, double>;\ntemplate struct VarFunctor<DeviceType::kCUDA, half, double>;\n\n#if CUDA_VERSION >= 11000\ntemplate struct VarFunctor<DeviceType::kCUDA, nv_bfloat16, double>;\n#endif\n}  // namespace user_op\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/variance_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_VARIANCE_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_VARIANCE_KERNEL_UTIL_H_\n\n#include \"oneflow/core/common/device_type.pb.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n#include \"oneflow/core/device/cuda_util.h\"\nnamespace oneflow {\nnamespace user_op {\n\nOF_DEVICE_FUNC size_t LinearIndex2Offset(const size_t linear_index,\n                                         const int32_t* dim_size_in_axis_ptr,\n                                         const int32_t* stride_vec_ptr, const int32_t size) {\n  // low dim at begin\n  size_t offset = 0;\n  size_t num_dim = 0;\n  for (int j = 0; j < size; j++) {\n    num_dim = (j == 0 ? linear_index : (num_dim / dim_size_in_axis_ptr[j - 1]));\n    offset += num_dim % dim_size_in_axis_ptr[j] * stride_vec_ptr[j];\n  }\n  return offset;\n}\n\nnamespace {\nconstexpr size_t MaxDims = 8;\n}  // namespace\n\nstruct VarParam {\n  VarParam() : unbiased(true), parallel_num(1), elem_cnt(1), axis_size(1), caxis_size(1) {}\n  bool unbiased;\n  size_t parallel_num;\n  size_t elem_cnt;\n  int32_t axis_size;\n  int32_t caxis_size;\n  int32_t stride_in_axis[MaxDims];\n  int32_t dim_size_in_axis[MaxDims];\n  int32_t stride_in_caxis[MaxDims];\n  int32_t dim_size_in_caxis[MaxDims];\n};\n\nclass VarParamHelper final {\n public:\n  VarParamHelper() = delete;\n  explicit VarParamHelper(const ShapeView& input_shape, const std::vector<int32_t> axis,\n                          const bool unbiased)\n      : axis_(axis), input_shape_(input_shape) {\n    param.unbiased = unbiased;\n    ComputeStrideVec(axis_, param.stride_in_axis);\n    caxis_ = GetCAxis();\n    ComputeStrideVec(caxis_, param.stride_in_caxis);\n    GetDimSizeInAxis(axis_, param.dim_size_in_axis);\n    GetDimSizeInAxis(caxis_, param.dim_size_in_caxis);\n    ComputeElemCntAndParallelNum();\n    param.axis_size = axis_.size();\n    param.caxis_size = caxis_.size();\n  }\n\n  VarParam param;\n\n private:\n  void ComputeElemCntAndParallelNum() {\n    for (int i = 0; i < axis_.size(); i++) { param.elem_cnt *= input_shape_.At(axis_[i]); }\n    for (int i = 0; i < caxis_.size(); i++) { param.parallel_num *= input_shape_.At(caxis_[i]); }\n  }\n\n  void ComputeStrideVec(const std::vector<int32_t> axis, int32_t* stride_vec) {\n    // low dim at begin\n    const int axis_size = axis.size();\n    for (int i = 0; i < axis_size; i++) {\n      int stride = 1;\n      if (axis.at(i) + 1 == input_shape_.NumAxes()) {\n        stride_vec[axis_size - 1 - i] = 1;\n      } else {\n        for (int j = axis.at(i) + 1; j < input_shape_.NumAxes(); j++) {\n          stride *= input_shape_.At(j);\n        }\n        stride_vec[axis_size - 1 - i] = stride;\n      }\n    }\n  }\n\n  std::vector<int32_t> GetCAxis() {\n    std::vector<int32_t> caxis;\n    caxis.resize(input_shape_.NumAxes());\n    
std::iota(caxis.begin(), caxis.end(), 0);\n    for (int i = 0; i < axis_.size(); i++) { caxis.erase(caxis.begin() + axis_.at(i) - i); }\n    return caxis;\n  }\n\n  void GetDimSizeInAxis(const std::vector<int32_t> axis, int32_t* dim_size_in_axis) {\n    // the innermost (low) dimension is stored first\n    const int axis_size = axis.size();\n    for (int i = 0; i < axis_size; i++) {\n      dim_size_in_axis[axis_size - 1 - i] = input_shape_.At(axis.at(i));\n    }\n  }\n\n  const std::vector<int32_t> axis_;\n  const ShapeView input_shape_;\n  std::vector<int32_t> caxis_;\n};\n\ntemplate<typename T, typename ComputeType>\nOF_DEVICE_FUNC void ComputeVarUsingWelford(const T* in_ptr, T* out_ptr, const VarParam& var_param) {\n  size_t count = 0;\n  // PyTorch uses double even for float data, so using float here would introduce accuracy errors.\n  ComputeType mean = 0.0;\n  ComputeType old_mean = 0.0;\n  ComputeType m2 = 0.0;\n  for (size_t i = 0; i < var_param.elem_cnt; i++) {\n    const size_t offset = LinearIndex2Offset(i, var_param.dim_size_in_axis,\n                                             var_param.stride_in_axis, var_param.axis_size);\n    count++;\n    old_mean = mean;\n    mean += (static_cast<ComputeType>(in_ptr[offset]) - mean) / count;\n    m2 += (static_cast<ComputeType>(in_ptr[offset]) - mean)\n          * (static_cast<ComputeType>(in_ptr[offset]) - old_mean);\n  }\n  *out_ptr = m2 / (var_param.unbiased ? count - 1 : count);\n}\n\nnamespace {\n\nOF_DEVICE_FUNC bool IsNanOut(const VarParam var_param) {\n  return (var_param.elem_cnt == 0) || (var_param.elem_cnt == 1 && var_param.unbiased == true);\n}\n\n#ifdef WITH_CUDA\nvoid SetGridDimAndBlockDim(const size_t total_elem_cnt, int* grid_dim, int* block_dim) {\n  // when total_elem_cnt > 2 * kCudaThreadsNumPerBlock, launch multiple blocks (two-stage reduction)\n  if (total_elem_cnt > (kCudaThreadsNumPerBlock << 1)) {\n    *grid_dim =\n        std::min(static_cast<int32_t>(std::ceil(std::sqrt(total_elem_cnt))), kCudaMaxBlocksNum);\n    *block_dim = kCudaThreadsNumPerBlock;\n  } else {\n    *grid_dim = 1;\n    int32_t aligned_block_dim =\n        (total_elem_cnt >= kCudaThreadsNumPerBlock)\n            ? kCudaThreadsNumPerBlock\n            // avoid block_dim = 0 when total_elem_cnt is 0\n            : std::max<long unsigned>((total_elem_cnt + kCudaWarpSize - 1) / kCudaWarpSize, 1)\n                  * kCudaWarpSize;\n    *block_dim = std::min(aligned_block_dim, kCudaThreadsNumPerBlock);\n  }\n}\n#endif  // WITH_CUDA\n}  // namespace\n\ntemplate<DeviceType device_type, typename T, typename ComputeType>\nstruct VarFunctor final {\n  void operator()(ep::Stream* stream, const T* in_ptr, T* out_ptr, ComputeType* tmp_buffer_ptr,\n                  const VarParam var_param);\n};\n\n}  // namespace user_op\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_VARIANCE_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/vector_matrix_product_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/new_kernel_util.h\"\n#include \"oneflow/core/framework/config_def.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/matmul.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nep::primitive::BlasTransposeType GetBlasTransposeType(bool transpose) {\n  return transpose ? ep::primitive::BlasTransposeType::T : ep::primitive::BlasTransposeType::N;\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Memcpy> NewMemcpyPrimitive(Context* ctx) {\n  return ep::primitive::NewPrimitive<ep::primitive::MemcpyFactory>(\n      ctx->device_type(), ep::primitive::MemcpyKind::kDtoD);\n}\n\nstd::unique_ptr<ep::primitive::Matmul> NewMatmulPrimitive(DeviceType device_type,\n                                                          DataType data_type, bool transpose_a,\n                                                          bool transpose_b) {\n  const auto trans_a = GetBlasTransposeType(transpose_a);\n  const auto trans_b = GetBlasTransposeType(transpose_b);\n  return ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(device_type, data_type, trans_a,\n                                                                   trans_b);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewVectorMatrixProductPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, false, false);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewVectorMatrixProductGradAPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, false, true);\n}\n\ntemplate<typename Context>\nstd::unique_ptr<ep::primitive::Matmul> NewVectorMatrixProductGradBPrimitive(Context* ctx) {\n  const DataType data_type = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0)->data_type();\n  return NewMatmulPrimitive(ctx->device_type(), data_type, true, false);\n}\n\nauto VectorMatrixProductPrimitiveExists() {\n  return hob::make_custom(\"NewVectorMatrixProductPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewVectorMatrixProductPrimitive(&ctx).operator bool();\n                          });\n}\n\nauto VectorMatrixProductGradAPrimitiveExists() {\n  return hob::make_custom(\"NewVectorMatrixProductGradAPrimitiveExists\",\n                          [](const user_op::KernelRegContext& ctx) {\n                            return NewVectorMatrixProductGradAPrimitive(&ctx).operator bool();\n                          });\n}\n\nauto VectorMatrixProductGradBPrimitiveExists() {\n  return hob::make_custom(\"NewVectorMatrixProductGradBPrimitiveExists\",\n                          
[](const user_op::KernelRegContext& ctx) {\n                            return NewVectorMatrixProductGradBPrimitive(&ctx).operator bool();\n                          });\n}\n\nclass VectorMatrixProductKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  VectorMatrixProductKernel() = default;\n  ~VectorMatrixProductKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    /*\n    A(k, ) matmul B(k, n) -> (1, k) matmul (k, n) -> (1, n) -> (n)\n    */\n    const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex(\"a\", 0);\n    CHECK_EQ(a->shape_view().NumAxes(), 1) << \"A Numdims should be equal to 1. \";\n    const DataType data_type = a->data_type();\n    const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex(\"b\", 0);\n    CHECK_EQ(b->shape_view().NumAxes(), 2) << \"B Numdims should be equal to 2. \";\n    CHECK_EQ(b->data_type(), data_type) << \"Datatype of matrix B should be equal to that of vector A. \";\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    CHECK_EQ(out->shape_view().NumAxes(), 1) << \"Out Numdims should be equal to 1. \";\n    CHECK_EQ(out->data_type(), data_type) << \"Out Datatype should be equal to input's. \";\n    size_t m = 1;\n    size_t k = a->shape_view().At(0);\n    size_t n = b->shape_view().At(1);\n    const double alpha = 1.0;\n    double beta = 0.0;\n    auto matmul = NewVectorMatrixProductPrimitive(ctx);\n    CHECK(matmul);\n    matmul->Launch(ctx->stream(), m, n, k, alpha, a->dptr(), b->dptr(), beta, out->mut_dptr());\n  }\n};\n\nREGISTER_USER_KERNEL(\"vector_matrix_product\")\n    .SetCreateFn<VectorMatrixProductKernel>()\n    .SetIsMatchedHob(VectorMatrixProductPrimitiveExists());\n\nclass VectorMatrixProductGradAKernel final : public user_op::OpKernel,\n                                             public user_op::CudaGraphSupport {\n public:\n  VectorMatrixProductGradAKernel() = default;\n  ~VectorMatrixProductGradAKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    /*\n    A(k, ) matmul B(k, n) -> (1, k) matmul (k, n) -> (1, n) -> (n)\n    GradA = dy (n) matmul B_transpose(n, k) -> (1, n) matmul (n, k)\n    */\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* b = ctx->Tensor4ArgNameAndIndex(\"b\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    size_t m = 1;\n    size_t k = dy->shape_view().At(0);\n    size_t n = b->shape_view().At(0);\n    const double alpha = 1.0;\n    double beta = 0.0;\n    auto matmul = NewVectorMatrixProductGradAPrimitive(ctx);\n    CHECK(matmul);\n    matmul->Launch(ctx->stream(), m, n, k, alpha, dy->dptr(), b->dptr(), beta, dx->mut_dptr());\n  }\n};\n\nREGISTER_USER_KERNEL(\"vector_matrix_product_grad_a\")\n    .SetCreateFn<VectorMatrixProductGradAKernel>()\n    .SetIsMatchedHob(VectorMatrixProductGradAPrimitiveExists());\n\nclass VectorMatrixProductGradBKernel final : public user_op::OpKernel,\n                                             public user_op::CudaGraphSupport {\n public:\n  VectorMatrixProductGradBKernel() = default;\n  ~VectorMatrixProductGradBKernel() = default;\n\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    /*\n    A(k, 
) matmul B(k, n) -> (1, k) matmul (k, n) -> (1, n) -> (n)\n    GradB = a_transpose (k, 1) matmul dy (1, n)\n    */\n    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex(\"dy\", 0);\n    const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex(\"a\", 0);\n    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex(\"dx\", 0);\n    size_t m = a->shape_view().At(0);\n    size_t k = 1;\n    size_t n = dy->shape_view().At(0);\n    const double alpha = 1.0;\n    double beta = 0.0;\n    auto matmul = NewVectorMatrixProductGradBPrimitive(ctx);\n    CHECK(matmul);\n    matmul->Launch(ctx->stream(), m, n, k, alpha, a->dptr(), dy->dptr(), beta, dx->mut_dptr());\n  }\n};\n\nREGISTER_USER_KERNEL(\"vector_matrix_product_grad_b\")\n    .SetCreateFn<VectorMatrixProductGradBKernel>()\n    .SetIsMatchedHob(VectorMatrixProductGradBPrimitiveExists());\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/where_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/where.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename Context>\nauto NewPrimitive(Context* ctx) -> std::unique_ptr<ep::primitive::Where> {\n  const user_op::TensorDesc* cond_desc = ctx->TensorDesc4ArgNameAndIndex(\"condition\", 0);\n  const user_op::TensorDesc* out_desc = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n  return ep::primitive::NewPrimitive<ep::primitive::WhereFactory>(\n      ctx->device_type(), cond_desc->data_type(), out_desc->data_type(),\n      out_desc->shape().NumAxes());\n}\n\nauto PrimitiveExists() {\n  return hob::make_custom(\"PrimitiveExists\", [](const user_op::KernelRegContext& ctx) -> bool {\n    return NewPrimitive(&ctx).operator bool();\n  });\n}\n\nclass WhereKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {\n public:\n  WhereKernel() = default;\n  ~WhereKernel() = default;\n\n private:\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    const user_op::Tensor* cond = ctx->Tensor4ArgNameAndIndex(\"condition\", 0);\n    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex(\"x\", 0);\n    const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex(\"y\", 0);\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    if (out->shape_view().elem_cnt() == 0) { return; }\n    auto primitive = NewPrimitive(ctx);\n    CHECK(primitive);\n    primitive->Launch(ctx->stream(), cond->shape_view().size(), cond->shape_view().ptr(),\n                      cond->dptr(), x->shape_view().size(), x->shape_view().ptr(), x->dptr(),\n                      y->shape_view().size(), y->shape_view().ptr(), y->dptr(), out->mut_dptr());\n  }\n};\n\nREGISTER_USER_KERNEL(\"where\").SetCreateFn<WhereKernel>().SetIsMatchedHob(PrimitiveExists() == true);\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/where_kernel_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/where_kernel_util.h\"\n\nnamespace oneflow {\n\ntemplate<typename T, typename CondT>\nstruct WhereKernelUtil<DeviceType::kCPU, T, CondT> {\n  static void Where(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond, const T* lhs,\n                    const T* rhs, T* out) {\n    FOR_RANGE(int64_t, i, 0, elem_cnt) { out[i] = static_cast<bool>(cond[i]) ? lhs[i] : rhs[i]; }\n  }\n  static void WhereXScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond,\n                           const T x_scalar, const T* rhs, T* out) {\n    FOR_RANGE(int64_t, i, 0, elem_cnt) { out[i] = static_cast<bool>(cond[i]) ? x_scalar : rhs[i]; }\n  }\n  static void WhereYScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond,\n                           const T* lhs, const T y_scalar, T* out) {\n    FOR_RANGE(int64_t, i, 0, elem_cnt) { out[i] = static_cast<bool>(cond[i]) ? lhs[i] : y_scalar; }\n  }\n  static void WhereXYScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond,\n                            const T x_scalar, const T y_scalar, T* out) {\n    FOR_RANGE(int64_t, i, 0, elem_cnt) {\n      out[i] = static_cast<bool>(cond[i]) ? x_scalar : y_scalar;\n    }\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_WHERE_FUNCTOR, (DeviceType::kCPU),\n                                 ARITHMETIC_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ,\n                                 INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/where_kernel_util.cu",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/kernels/where_kernel_util.h\"\n#include \"oneflow/core/cuda/elementwise.cuh\"\n#include \"oneflow/core/ep/cuda/cuda_stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T, typename CondT>\nstruct WhereFunctor {\n  OF_DEVICE_FUNC T operator()(CondT cond, T lhs, T rhs) const {\n    return static_cast<bool>(cond) ? lhs : rhs;\n  }\n};\n\ntemplate<typename T, typename CondT>\nstruct WhereScalarXFunctor {\n  OF_DEVICE_FUNC explicit WhereScalarXFunctor(T scalar) : x_scalar(scalar) {}\n  OF_DEVICE_FUNC T operator()(CondT cond, T rhs) const {\n    return static_cast<bool>(cond) ? x_scalar : rhs;\n  }\n  const T x_scalar;\n};\n\ntemplate<typename T, typename CondT>\nstruct WhereScalarYFunctor {\n  OF_DEVICE_FUNC explicit WhereScalarYFunctor(T scalar) : y_scalar(scalar) {}\n  OF_DEVICE_FUNC T operator()(CondT cond, T lhs) const {\n    return static_cast<bool>(cond) ? lhs : y_scalar;\n  }\n  const T y_scalar;\n};\n\ntemplate<typename T, typename CondT>\nstruct WhereScalarXYFunctor {\n  OF_DEVICE_FUNC explicit WhereScalarXYFunctor(T x_scalar, T y_scalar)\n      : x_scalar(x_scalar), y_scalar(y_scalar) {}\n  OF_DEVICE_FUNC T operator()(CondT cond) const {\n    return static_cast<bool>(cond) ? 
x_scalar : y_scalar;\n  }\n  const T x_scalar;\n  const T y_scalar;\n};\n\n}  // namespace\n\ntemplate<typename T, typename CondT>\nstruct WhereKernelUtil<DeviceType::kCUDA, T, CondT> {\n  static void Where(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond, const T* lhs,\n                    const T* rhs, T* out) {\n    cuda::elementwise::Ternary(WhereFunctor<T, CondT>(), elem_cnt, out, cond, lhs, rhs,\n                               stream->As<ep::CudaStream>()->cuda_stream());\n  }\n  static void WhereXScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond,\n                           const T x_scalar, const T* rhs, T* out) {\n    cuda::elementwise::Binary(WhereScalarXFunctor<T, CondT>(x_scalar), elem_cnt, out, cond, rhs,\n                              stream->As<ep::CudaStream>()->cuda_stream());\n  }\n  static void WhereYScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond,\n                           const T* lhs, const T y_scalar, T* out) {\n    cuda::elementwise::Binary(WhereScalarYFunctor<T, CondT>(y_scalar), elem_cnt, out, cond, lhs,\n                              stream->As<ep::CudaStream>()->cuda_stream());\n  }\n  static void WhereXYScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond,\n                            const T x_scalar, const T y_scalar, T* out) {\n    cuda::elementwise::Unary(WhereScalarXYFunctor<T, CondT>(x_scalar, y_scalar), elem_cnt, out,\n                             cond, stream->As<ep::CudaStream>()->cuda_stream());\n  }\n};\n\nOF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_WHERE_FUNCTOR, (DeviceType::kCUDA),\n                                 ARITHMETIC_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ,\n                                 INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/kernels/where_kernel_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_KERNELS_WHERE_KERNEL_UTIL_H_\n#define ONEFLOW_USER_KERNELS_WHERE_KERNEL_UTIL_H_\n\n#include \"oneflow/core/kernel/kernel_util.h\"\n#include \"oneflow/core/ep/include/stream.h\"\n\nnamespace oneflow {\n\ntemplate<DeviceType device_type, typename T, typename CondT>\nstruct WhereKernelUtil {\n  static void Where(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond, const T* lhs,\n                    const T* rhs, T* out);\n  static void WhereXScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond,\n                           const T x_scalar, const T* rhs, T* out);\n  static void WhereYScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond,\n                           const T* lhs, const T y_scalar, T* out);\n  static void WhereXYScalar(ep::Stream* stream, const int64_t elem_cnt, const CondT* cond,\n                            const T x_scalar, const T y_scalar, T* out);\n};\n\n#define INSTANTIATE_WHERE_FUNCTOR(device_type_v, dtype_pair, ctype_pair)       \\\n  template struct WhereKernelUtil<device_type_v, OF_PP_PAIR_FIRST(dtype_pair), \\\n                                  OF_PP_PAIR_FIRST(ctype_pair)>;\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_KERNELS_WHERE_KERNEL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/kernels/zero_like_kernel.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/kernel/cuda_graph_support.h\"\n#include \"oneflow/core/ep/include/primitive/memset.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nclass ZeroLikeKernel final : public user_op::OpKernel {\n public:\n  ZeroLikeKernel() = default;\n  ~ZeroLikeKernel() = default;\n\n private:\n  void Compute(user_op::KernelComputeContext* ctx) const override {\n    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(\"out\", 0);\n    const int64_t elem_cnt = out->shape_view().elem_cnt();\n    if (elem_cnt > 0) {\n      std::unique_ptr<ep::primitive::Memset> primitive =\n          ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->stream()->device_type());\n      CHECK(primitive) << \"Can not create Memset primitive for device type \"\n                       << ctx->stream()->device_type();\n      primitive->Launch(ctx->stream(), out->mut_dptr(), 0,\n                        elem_cnt * GetSizeOfDataType(out->data_type()));\n    }\n  }\n  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }\n};\n\nREGISTER_USER_KERNEL(\"zero_like\").SetCreateFn<ZeroLikeKernel>();\n\n}  // namespace\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/acc_ctrl_tick_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> AccCtrlTickOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, Shape({1}));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> AccCtrlTickOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> AccCtrlTickOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> AccCtrlTickOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  const NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex(\"in\", 0);\n  const Shape& parallel_hierarchy = ctx->parallel_hierarchy();\n  CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(),  // NOLINT(maybe-need-error-msg)\n                     parallel_hierarchy.NumAxes());    // NOLINT(maybe-need-error-msg)\n\n  NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  in_distribution->clear_sbp_parallel();\n  out_distribution->clear_sbp_parallel();\n  // in use hint\n  in_distribution->CopyFrom(in_dis_hint);\n\n  for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) {\n    // out dim1 = broadcast\n    out_distribution->add_sbp_parallel()->mutable_broadcast_parallel();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> AccCtrlTickOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> AccCtrlTickOp::InferOutputBlobTimeShape(\n    user_op::InferOutputBlobTimeShapeFnContext* ctx) {\n  const int32_t max_acc_num = ctx->user_op_conf().attr<int32_t>(\"max_acc_num\");\n  const Shape& in_time_shape = ctx->TimeShape4InputArgNameAndIndex(\"in\", 0);\n  DimVector time_shape_dim_vec = in_time_shape.dim_vec();  // NOLINT(maybe-need-error-msg)\n  CHECK_OR_RETURN(!time_shape_dim_vec.empty());            // NOLINT(maybe-need-error-msg)\n  if (time_shape_dim_vec.back() == max_acc_num) {\n    time_shape_dim_vec.pop_back();\n  } else if (time_shape_dim_vec.back() % max_acc_num == 0) {\n    time_shape_dim_vec.back() /= max_acc_num;\n  } else {\n    const int64_t elem_cnt = in_time_shape.elem_cnt();\n    CHECK_EQ_OR_RETURN(elem_cnt % max_acc_num, 0);\n    time_shape_dim_vec.resize(1);\n    time_shape_dim_vec.back() = elem_cnt / max_acc_num;\n  }\n  *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/acc_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> AccOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> AccOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> AccOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return AccOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> AccOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> AccOp::InferOutputBlobTimeShape(\n    user_op::InferOutputBlobTimeShapeFnContext* ctx) {\n  const int32_t max_acc_num = ctx->user_op_conf().attr<int32_t>(\"max_acc_num\");\n  const Shape& in_time_shape = ctx->TimeShape4InputArgNameAndIndex(\"in\", 0);\n  DimVector time_shape_dim_vec = in_time_shape.dim_vec();\n  CHECK_OR_RETURN(!time_shape_dim_vec.empty());\n  if (time_shape_dim_vec.back() == max_acc_num) {\n    time_shape_dim_vec.pop_back();\n  } else if (time_shape_dim_vec.back() % max_acc_num == 0) {\n    time_shape_dim_vec.back() /= max_acc_num;\n  } else {\n    const int64_t elem_cnt = in_time_shape.elem_cnt();\n    time_shape_dim_vec.resize(1);\n    time_shape_dim_vec.back() = elem_cnt / max_acc_num;\n  }\n  *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/adaptive_max_pool_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferFWTensorDesc(user_op::InferContext* ctx) {\n  std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  DimVector out_shape(x_shape.NumAxes());\n  out_shape[0] = x_shape.dim_vec()[0];\n  out_shape[1] = x_shape.dim_vec()[1];\n  if (data_format == \"channels_first\") {\n    out_shape[1] = x_shape.dim_vec()[1];\n    for (int i = 2; i < out_shape.size(); ++i) {\n      out_shape[i] = output_size.size() > i - 2 ? output_size[i - 2] : output_size[0];\n    }\n  } else {\n    out_shape[3] = x_shape.dim_vec()[3];\n    for (int i = 1; i < out_shape.size() - 1; ++i) {\n      out_shape[i] = output_size.size() > i - 1 ? output_size[i - 1] : output_size[0];\n    }\n  }\n\n  ctx->SetOutputShape(\"y\", 0, Shape(out_shape));\n  ctx->SetOutputShape(\"index\", 0, Shape(out_shape));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferBWTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"x\", 0));\n  ctx->SetOutputIsDynamic(\"dx\", 0, ctx->InputIsDynamic(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FwGetSbpFn(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  // only for nchw\n  FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"y\", 0), i)\n        .Split(user_op::OpArg(\"index\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BwGetSbpFn(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Split(user_op::OpArg(\"index\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferFWDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"index\", 0, DataType::kInt64);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferBWDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n#define DEF_ADAPTIVE_MAX_POOL_OP(op_class_name_prefix)                                            \\\n  /* static */ Maybe<void> op_class_name_prefix##Op::InferLogicalTensorDesc(                      \\\n   
   user_op::InferContext* ctx) {                                                               \\\n    return InferFWTensorDesc(ctx);                                                                \\\n  }                                                                                               \\\n                                                                                                  \\\n  /*static*/ Maybe<void> op_class_name_prefix##Op::InferPhysicalTensorDesc(                       \\\n      user_op::InferContext* ctx) {                                                               \\\n    return InferLogicalTensorDesc(ctx);                                                           \\\n  }                                                                                               \\\n                                                                                                  \\\n  /* static */ Maybe<void> op_class_name_prefix##Op::GetSbp(user_op::SbpContext* ctx) {           \\\n    return FwGetSbpFn(ctx);                                                                       \\\n  }                                                                                               \\\n                                                                                                  \\\n  /* static */ Maybe<void> op_class_name_prefix##Op::InferDataType(user_op::InferContext* ctx) {  \\\n    return InferFWDataType(ctx);                                                                  \\\n  }                                                                                               \\\n                                                                                                  \\\n  /* static */ Maybe<void> op_class_name_prefix##GradOp::InferLogicalTensorDesc(                  \\\n      user_op::InferContext* ctx) {                                                               \\\n    return InferBWTensorDesc(ctx);                                                                \\\n  }                                                                                               \\\n                                                                                                  \\\n  /*static*/                                                                                      \\\n  Maybe<void> op_class_name_prefix##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferLogicalTensorDesc(ctx);                                                           \\\n  }                                                                                               \\\n                                                                                                  \\\n  /* static */                                                                                    \\\n  Maybe<void> op_class_name_prefix##GradOp::GetSbp(user_op::SbpContext* ctx) {                    \\\n    return BwGetSbpFn(ctx);                                                                       \\\n  }                                                                                               \\\n                                                                                                  \\\n  /* static */                                                                                    \\\n  Maybe<void> op_class_name_prefix##GradOp::InferDataType(user_op::InferContext* ctx) {           \\\n    return InferBWDataType(ctx);                                                          
        \\\n  }\n\nDEF_ADAPTIVE_MAX_POOL_OP(AdaptiveMaxPool1D);\nDEF_ADAPTIVE_MAX_POOL_OP(AdaptiveMaxPool2D);\nDEF_ADAPTIVE_MAX_POOL_OP(AdaptiveMaxPool3D);\n\n#undef DEF_ADAPTIVE_MAX_POOL_OP\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/adaptive_pool_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferFWTensorDesc(user_op::InferContext* ctx) {\n  std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  DimVector out_shape(x_shape.NumAxes());\n  out_shape[0] = x_shape.dim_vec()[0];\n  if (data_format == \"channels_first\") {\n    out_shape[1] = x_shape.dim_vec()[1];\n    for (int i = 2; i < out_shape.size(); ++i) {\n      out_shape[i] = output_size.size() > i - 2 ? output_size[i - 2] : output_size[0];\n    }\n  } else {\n    out_shape[3] = x_shape.dim_vec()[3];\n    for (int i = 1; i < out_shape.size() - 1; ++i) {\n      out_shape[i] = output_size.size() > i - 1 ? output_size[i - 1] : output_size[0];\n    }\n  }\n\n  ctx->SetOutputShape(\"y\", 0, Shape(out_shape));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferBWTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"x\", 0));\n  ctx->SetOutputIsDynamic(\"dx\", 0, ctx->InputIsDynamic(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FwGetSbpFn(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  // only for nchw\n  FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Split(user_op::OpArg(\"y\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BwGetSbpFn(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferFWDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferBWDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n#define DEF_ADAPTIVE_AVG_POOL_OP(op_class_name_prefix)                                           \\\n  /* static */ Maybe<void> op_class_name_prefix##Op::InferLogicalTensorDesc(                     \\\n      user_op::InferContext* ctx) {                                                              \\\n    return InferFWTensorDesc(ctx);                                                               \\\n  }                                                                                
              \\\n                                                                                                 \\\n  /*static*/ Maybe<void> op_class_name_prefix##Op::InferPhysicalTensorDesc(                      \\\n      user_op::InferContext* ctx) {                                                              \\\n    return InferLogicalTensorDesc(ctx);                                                          \\\n  }                                                                                              \\\n                                                                                                 \\\n  /* static */ Maybe<void> op_class_name_prefix##Op::GetSbp(user_op::SbpContext* ctx) {          \\\n    return FwGetSbpFn(ctx);                                                                      \\\n  }                                                                                              \\\n                                                                                                 \\\n  /* static */ Maybe<void> op_class_name_prefix##Op::InferDataType(user_op::InferContext* ctx) { \\\n    return InferFWDataType(ctx);                                                                 \\\n  }                                                                                              \\\n                                                                                                 \\\n  /* static */ Maybe<void> op_class_name_prefix##GradOp::InferLogicalTensorDesc(                 \\\n      user_op::InferContext* ctx) {                                                              \\\n    return InferBWTensorDesc(ctx);                                                               \\\n  }                                                                                              \\\n                                                                                                 \\\n  /*static*/ Maybe<void> op_class_name_prefix##GradOp::InferPhysicalTensorDesc(                  \\\n      user_op::InferContext* ctx) {                                                              \\\n    return InferLogicalTensorDesc(ctx);                                                          \\\n  }                                                                                              \\\n                                                                                                 \\\n  /* static */ Maybe<void> op_class_name_prefix##GradOp::GetSbp(user_op::SbpContext* ctx) {      \\\n    return BwGetSbpFn(ctx);                                                                      \\\n  }                                                                                              \\\n                                                                                                 \\\n  /* static */ Maybe<void> op_class_name_prefix##GradOp::InferDataType(                          \\\n      user_op::InferContext* ctx) {                                                              \\\n    return InferBWDataType(ctx);                                                                 \\\n  }\n\nDEF_ADAPTIVE_AVG_POOL_OP(AdaptiveAvgPool1D)\nDEF_ADAPTIVE_AVG_POOL_OP(AdaptiveAvgPool2D)\nDEF_ADAPTIVE_AVG_POOL_OP(AdaptiveAvgPool3D)\n\n#undef DEF_ADAPTIVE_AVG_POOL_OP\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/add_n_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n/* static */ Maybe<void> AddNOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const auto& in_0 = ctx->InputTensorDesc(\"in\", 0);\n  auto* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_NOTNULL_OR_RETURN(out);  // NOLINT(maybe-need-error-msg)\n  for (const auto& pair : ctx->inputs()) {\n    const auto& cur_in = ctx->InputTensorDesc(pair.first, pair.second);\n    if (in_0.shape().NumAxes() > 0 && cur_in.shape().NumAxes() > 0) {\n      CHECK_EQ_OR_RETURN(in_0.shape(), cur_in.shape())\n          << Error::RuntimeError()\n          << \"inconsistent tensor size, expected all tensor to have the same shapes, \"\n          << \"but got \" << in_0.shape().ToString() << \" and \" << cur_in.shape().ToString();\n    }\n  }\n  out->set_shape(in_0.shape());\n  out->set_is_dynamic(in_0.is_dynamic());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> AddNOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> AddNOp::GetSbp(user_op::SbpContext* ctx) {\n  int64_t num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0).shape().NumAxes();\n  for (int64_t i = 0; i < num_axes; ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(user_op::OpArg(\"out\", 0)).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> AddNOp::InferDataType(user_op::InferContext* ctx) {\n  const auto& in_0 = ctx->InputTensorDesc(\"in\", 0);\n  auto* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_NOTNULL_OR_RETURN(out);  // NOLINT(maybe-need-error-msg)\n  for (const auto& pair : ctx->inputs()) {\n    const auto& cur_in = ctx->InputTensorDesc(pair.first, pair.second);\n    CHECK_EQ_OR_RETURN(in_0.data_type(), cur_in.data_type())\n        << Error::RuntimeError() << ctx->op_name()\n        << \" expected all tenser to have same type, but found \" << DataType_Name(in_0.data_type())\n        << \" and \" << DataType_Name(cur_in.data_type());\n  }\n  out->set_data_type(in_0.data_type());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> AddNOp::CheckAttr(const user_op::UserOpDefWrapper&,\n                                         const user_op::UserOpConfWrapper& op_conf) {\n  CHECK_OR_RETURN(op_conf.input_size(\"in\") >= 2)\n      << Error::RuntimeError()\n      << \"The number of input tensors should be greater than or equal to 2\";\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/affine_grid_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> CheckAttr_(const user_op::UserOpDefWrapper& def,\n                       const user_op::UserOpConfWrapper& conf) {\n  bool pass_checked = true;\n  std::stringstream err;\n  err << \"Illegal value for \" << conf.op_type_name() << \" op \" << conf.op_name() << \": \";\n\n  const auto& size = conf.attr<Shape>(\"size\");\n  if (size.NumAxes() != 4 && size.NumAxes() != 5) {\n    err << \"dimension of size can't be:\" << size.NumAxes();\n    pass_checked = false;\n  }\n\n  for (int i = 0; i < size.NumAxes(); i++) {\n    if (size.At(i) <= 0) { err << \"element of size can't be:\" << size.At(i); }\n  }\n\n  if (pass_checked) {\n    return Maybe<void>::Ok();\n  } else {\n    return oneflow::Error::CheckFailedError() << err.str();\n  }\n}\n\n}  // namespace\n\n/* static */ Maybe<void> AffineGridOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& theta = ctx->InputTensorDesc(\"theta\", 0);\n  user_op::TensorDesc* grid = ctx->MutOutputTensorDesc(\"grid\", 0);\n  const Shape& size = ctx->Attr<Shape>(\"size\");\n  // Only support 2D or 3D affine grid with NCHW layout\n  // For 2D grid: theta = { N, 2, 3 },\n  //              size  = { N, C, H, W }\n  //              grid  = { N, H, W, 2 }\n  // For 3D grid: theta = { N, 3, 4 },\n  //              size  = { N, C, D, H, W }\n  //              grid  = { N, D, H, W, 3 }\n  bool is_2d_grid = true;\n  if (theta.shape().At(1) == 2) {\n    CHECK_EQ_OR_RETURN(theta.shape().At(2), 3) << \"Theta shape  MUST be (N, 2, 3) or (N, 3, 4)\";\n    CHECK_EQ_OR_RETURN(size.NumAxes(), 4) << \"Dimension of size MUST be 4, when 2d affine grid\";\n    CHECK_EQ_OR_RETURN(theta.shape().At(0), size.At(0))\n        << \"Theta and size MUST have same batch dimension\";\n    is_2d_grid = true;\n  } else if (theta.shape().At(1) == 3) {\n    CHECK_EQ_OR_RETURN(theta.shape().At(2), 4) << \"Theta shape  MUST be (N, 2, 3) or (N, 3, 4)\";\n    CHECK_EQ_OR_RETURN(size.NumAxes(), 5) << \"Dimension of size MUST be 4, when 3d affine grid\";\n    CHECK_EQ_OR_RETURN(theta.shape().At(0), size.At(0))\n        << \"Theta and size MUST have same batch dimension\";\n    is_2d_grid = false;\n  } else {\n    CHECK_OR_RETURN(false) << \"Theta MUST be 2D or 3D grid\";\n  }\n  grid->set_is_dynamic(theta.is_dynamic());\n  Shape grid_shape;\n  if (is_2d_grid) {\n    grid_shape = {size.At(0), size.At(2), size.At(3), 2};\n  } else {\n    grid_shape = {size.At(0), size.At(2), size.At(3), size.At(4), 3};\n  }\n  grid->set_shape(grid_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> AffineGridOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& theta = ctx->InputTensorDesc(\"theta\", 0);\n  user_op::TensorDesc* grid = 
ctx->MutOutputTensorDesc(\"grid\", 0);\n  const Shape& size = ctx->Attr<Shape>(\"size\");\n  // Only support 2D or 3D affine grid with NCHW layout\n  // For 2D grid: theta = { N, 2, 3 },\n  //              size  = { N, C, H, W }\n  //              grid  = { N, H, W, 2 }\n  // For 3D grid: theta = { N, 3, 4 },\n  //              size  = { N, C, D, H, W }\n  //              grid  = { N, D, H, W, 3 }\n  const Shape& theta_shape = theta.shape();\n  bool is_2d_grid = true;\n  if (theta_shape.At(1) == 2) {\n    CHECK_EQ_OR_RETURN(theta_shape.At(2), 3) << \"Theta shape  MUST be (N, 2, 3) or (N, 3, 4)\";\n    CHECK_EQ_OR_RETURN(size.NumAxes(), 4) << \"Dimension of size MUST be 4, when 2d affine grid\";\n    is_2d_grid = true;\n  } else if (theta_shape.At(1) == 3) {\n    CHECK_EQ_OR_RETURN(theta_shape.At(2), 4) << \"Theta shape  MUST be (N, 2, 3) or (N, 3, 4)\";\n    CHECK_EQ_OR_RETURN(size.NumAxes(), 5) << \"Dimension of size MUST be 4, when 3d affine grid\";\n    is_2d_grid = false;\n  } else {\n    CHECK_OR_RETURN(false) << \"Theta MUST be 2D or 3D grid\";\n  }\n\n  int64_t N = size.At(0);\n  const int64_t& parallel_num = ctx->parallel_ctx().parallel_num();\n  if (parallel_num > 1) {\n    const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"theta\", 0);\n    Shape logical_shape = theta_shape;\n    logical_shape.Set(0, size.At(0));\n    const auto& physical_shape =\n        JUST(GetPhysicalShape(logical_shape, nd_sbp, ctx->parallel_desc(), ctx->parallel_ctx()));\n    N = physical_shape->At(0);\n  }\n  CHECK_EQ_OR_RETURN(theta_shape.At(0), N)\n      << \"The dimension 0 size of theta shape should be \" << N << \", but got \" << theta_shape.At(0);\n\n  grid->set_is_dynamic(theta.is_dynamic());\n  Shape grid_shape;\n  if (is_2d_grid) {\n    grid_shape = {N, size.At(2), size.At(3), 2};\n  } else {\n    grid_shape = {N, size.At(2), size.At(3), size.At(4), 3};\n  }\n  grid->set_shape(grid_shape);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> AffineGridOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"theta\", 0), 0)\n      .Split(user_op::OpArg(\"grid\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> AffineGridOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                                 const user_op::UserOpConfWrapper& conf) {\n  return CheckAttr_(def, conf);\n}\n\n/* static */ Maybe<void> AffineGridOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"grid\", 0, ctx->InputDType(\"theta\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> AffineGridGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dgrid = ctx->InputTensorDesc(\"dgrid\", 0);\n  const Shape& size = ctx->Attr<Shape>(\"size\");\n  if (size.NumAxes() == 4) {\n    ctx->MutOutputTensorDesc(\"dtheta\", 0)->set_shape(Shape({dgrid.shape().At(0), 2, 3}));\n  } else if (size.NumAxes() == 5) {\n    ctx->MutOutputTensorDesc(\"dtheta\", 0)->set_shape(Shape({dgrid.shape().At(0), 3, 4}));\n  } else {\n    CHECK_OR_RETURN(false) << \"size MUST be 4D or 5D\";\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> AffineGridGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> AffineGridGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dgrid\", 0), 0)\n      .Split(user_op::OpArg(\"dtheta\", 0), 0)\n      .Build();\n  return 
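\n  // Only the batch dim (axis 0) of dgrid and dtheta is split; the spatial axes of the grid\n  // are kept whole on every rank.\n  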
Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> AffineGridGradOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                                     const user_op::UserOpConfWrapper& conf) {\n  return CheckAttr_(def, conf);\n}\n\n/* static */ Maybe<void> AffineGridGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dtheta\", 0, ctx->InputDType(\"dgrid\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/amp_white_identity_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> AmpWhiteIdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_shape(in.shape());\n  out->set_is_dynamic(in.is_dynamic());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> AmpWhiteIdentityOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> AmpWhiteIdentityOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& in = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  for (int i = 0; i < in.shape().NumAxes(); ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> AmpWhiteIdentityOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_data_type(in.data_type());\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> AmpBlackIdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_shape(in.shape());\n  out->set_is_dynamic(in.is_dynamic());\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> AmpBlackIdentityOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& in = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  for (int i = 0; i < in.shape().NumAxes(); ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> AmpBlackIdentityOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_data_type(in.data_type());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/arange_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> ArangeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  DataType dtype = ctx->Attr<DataType>(\"dtype\");\n  int64_t range_elem_cnt = 0;\n  if (IsIntegralDataType(dtype)) {\n    int64_t integer_delta = ctx->Attr<int64_t>(\"integer_delta\");\n    CHECK_NE_OR_RETURN(integer_delta, static_cast<int64_t>(0))\n        << \"RuntimeError: step must be nonzero. \";\n    int64_t integer_start = ctx->Attr<int64_t>(\"integer_start\");\n    int64_t integer_limit = ctx->Attr<int64_t>(\"integer_limit\");\n    // CHECK when limit > start, delta > 0; limit < start, delta < 0;\n    CHECK_GE_OR_RETURN((integer_limit - integer_start) / integer_delta, static_cast<int64_t>(0))\n        << \"RuntimeError: upper bound and larger bound inconsistent with step sign\";\n    range_elem_cnt = std::ceil(static_cast<double>(integer_limit - integer_start) / integer_delta);\n  } else {\n    double float_delta = ctx->Attr<double>(\"float_delta\");\n    CHECK_NE_OR_RETURN(float_delta, static_cast<double>(0.0))\n        << \"RuntimeError: step must be nonzero. \";\n    double float_start = ctx->Attr<double>(\"float_start\");\n    double float_limit = ctx->Attr<double>(\"float_limit\");\n    // CHECK when limit > start, delta > 0; limit < start, delta < 0;\n    // CHECK_GE For 0-Dim Tensor\n    CHECK_GE_OR_RETURN((float_limit - float_start) / float_delta, static_cast<double>(0.0))\n        << \"RuntimeError: upper bound and larger bound inconsistent with step sign\";\n    range_elem_cnt = std::ceil(static_cast<double>(float_limit - float_start) / float_delta);\n  }\n  ctx->SetOutputShape(\"out\", 0, Shape({range_elem_cnt}));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ArangeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  DataType dtype = ctx->Attr<DataType>(\"dtype\");\n  int64_t range_elem_cnt = 0;\n  if (IsIntegralDataType(dtype)) {\n    int64_t integer_delta = ctx->Attr<int64_t>(\"integer_delta\");\n    if (integer_delta == static_cast<int64_t>(0)) {\n      return Error::RuntimeError() << \" step must be nonzero. 
\";\n    }\n    int64_t integer_start = ctx->Attr<int64_t>(\"integer_start\");\n    int64_t integer_limit = ctx->Attr<int64_t>(\"integer_limit\");\n    // CHECK when limit > start, delta > 0; limit < start, delta < 0;\n    if ((integer_limit - integer_start) / integer_delta < static_cast<int64_t>(0)) {\n      return Error::RuntimeError() << \" upper bound and larger bound inconsistent with step sign\";\n    }\n    range_elem_cnt = std::ceil(static_cast<double>(integer_limit - integer_start) / integer_delta);\n  } else {\n    double float_delta = ctx->Attr<double>(\"float_delta\");\n    if (float_delta == static_cast<double>(0.0)) {\n      return Error::RuntimeError() << \" step must be nonzero. \";\n    }\n    double float_start = ctx->Attr<double>(\"float_start\");\n    double float_limit = ctx->Attr<double>(\"float_limit\");\n    // CHECK when limit > start, delta > 0; limit < start, delta < 0;\n    // CHECK_GE For 0-Dim Tensor\n    if ((float_limit - float_start) / float_delta < static_cast<double>(0.0)) {\n      return Error::RuntimeError() << \" upper bound and larger bound inconsistent with step sign\";\n    }\n    range_elem_cnt = std::ceil(static_cast<double>(float_limit - float_start) / float_delta);\n  }\n  const Shape& logical_shape = Shape({range_elem_cnt});\n  const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();\n\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n  const auto tensor_slice_view =\n      GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id);\n  const Shape& physical_shape = tensor_slice_view.shape();\n\n  ctx->SetOutputShape(\"out\", 0, physical_shape);\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ArangeOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ArangeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  SbpParallel default_sbp;\n  default_sbp.mutable_broadcast_parallel();\n  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);\n}\n\n/* static */ Maybe<void> ArangeOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->Attr<DataType>(\"dtype\"));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/arg_sort_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> ArgSortOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ArgSortOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ArgSortOp::GetSbp(user_op::SbpContext* ctx) {\n  // The current implementation can only do arg_sort in the last dimension and should use\n  // Broadcast (by default) instead of Split for that dimension\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ArgSortOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                              const user_op::UserOpConfWrapper& conf) {\n  const std::string& direction = conf.attr<std::string>(\"direction\");\n  CHECK_OR_RETURN(direction == \"ASCENDING\" || direction == \"DESCENDING\")\n      << Error::RuntimeError()\n      << \"expected the input direction parameter value is \\\"ASCENDING\\\" or \\\"DESCENDING\\\", \"\n      << \"but found the value is \" << direction;\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ArgSortOp::InferDataType(user_op::InferContext* ctx) {\n  if (ctx->parallel_desc().device_type() == DeviceType::kNPU) {\n    ctx->SetOutputDType(\"out\", 0, DataType::kInt64);\n  } else {\n    ctx->SetOutputDType(\"out\", 0, DataType::kInt32);\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/arg_where_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferTensorDesc(user_op::InferContext* ctx) {\n  const Shape& input_shape = ctx->InputShape(\"input\", 0);\n  user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc(\"output\", 0);\n  output_desc->set_shape(Shape({input_shape.elem_cnt(), input_shape.NumAxes()}));\n  output_desc->set_is_dynamic(true);\n  user_op::TensorDesc* output_size_desc = ctx->MutOutputTensorDesc(\"output_size\", 0);\n  output_size_desc->set_shape(Shape({1}));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> ArgwhereOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> ArgwhereOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ArgwhereOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> ArgwhereOp::InferDataType(user_op::InferContext* ctx) {\n  const DataType dtype = ctx->Attr<DataType>(\"dtype\");\n  user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc(\"output\", 0);\n  output_desc->set_data_type(dtype);\n  user_op::TensorDesc* output_size_desc = ctx->MutOutputTensorDesc(\"output_size\", 0);\n  output_size_desc->set_data_type(dtype);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/argmax_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> ArgmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  auto dim_vec = ctx->InputShape(\"in\", 0).dim_vec();\n  dim_vec.pop_back();\n  ctx->SetOutputShape(\"out\", 0, Shape(std::move(dim_vec)));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ArgmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ArgmaxOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ArgmaxOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, DataType::kInt64);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/as_strided_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ auto AsStridedOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  const auto& size = ctx->Attr<std::vector<int64_t>>(\"size\");\n  const auto& stride = ctx->Attr<std::vector<int64_t>>(\"stride\");\n  CHECK_EQ_OR_RETURN(size.size(), stride.size()) << \"mismatch in length of strides and shape\";\n  DimVector out_vec;\n  out_vec.insert(out_vec.end(), size.cbegin(), size.cend());\n  user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc(\"output\", 0);\n  output_desc->set_shape(Shape(out_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ auto AsStridedOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  return AsStridedOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto AsStridedOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  return Maybe<void>::Ok();\n}\n/*static*/ auto AsStridedOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  ctx->SetOutputDType(\"output\", 0, ctx->InputDType(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ auto AsStridedGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  const Shape& input_shape = ctx->InputShape(\"input\", 0);\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx_desc->set_shape(input_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ auto AsStridedGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  return AsStridedGradOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto AsStridedGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  return Maybe<void>::Ok();\n}\n/*static*/ auto AsStridedGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/assign_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc(\"ref\", 0);\n  const user_op::TensorDesc& value_desc = ctx->InputTensorDesc(\"value\", 0);\n  CHECK_OR_RETURN(!ref_desc.is_dynamic());\n  CHECK_OR_RETURN(ref_desc.shape() == value_desc.shape());\n  if (ctx->has_input(\"condition\", 0)) {\n    const user_op::TensorDesc& condition = ctx->InputTensorDesc(\"condition\", 0);\n    CHECK_OR_RETURN(condition.shape().NumAxes() == 1);\n    CHECK_OR_RETURN(condition.shape().At(0) == 1);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetSbpSignatures(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& ref_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"ref\", 0);\n  FOR_RANGE(int64_t, axis, 0, ref_desc.shape().NumAxes()) {\n    if (ctx->user_op_conf().has_input(\"condition\", 0)) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"ref\", 0), axis)\n          .Split(user_op::OpArg(\"value\", 0), axis)\n          .Broadcast(user_op::OpArg(\"condition\", 0))\n          .Build();\n    } else {\n      ctx->NewBuilder().Split(ctx->inputs(), axis).Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                               const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* ref_modifier = GetInputArgModifierFn(\"ref\", 0);\n  CHECK_OR_RETURN(ref_modifier != nullptr);\n  ref_modifier->set_is_mutable(true);\n  user_op::InputArgModifier* value_modifier = GetInputArgModifierFn(\"value\", 0);\n  CHECK_OR_RETURN(value_modifier != nullptr);\n  value_modifier->set_requires_grad(false);\n  if (conf.has_input(\"condition\", 0)) {\n    user_op::InputArgModifier* condition_modifier = GetInputArgModifierFn(\"condition\", 0);\n    CHECK_OR_RETURN(condition_modifier != nullptr);\n    condition_modifier->set_requires_grad(false);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType_(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc(\"ref\", 0);\n  const user_op::TensorDesc& value_desc = ctx->InputTensorDesc(\"value\", 0);\n  CHECK_OR_RETURN(ref_desc.data_type() == value_desc.data_type())\n      << Error::RuntimeError() << DataType_Name(ref_desc.data_type()) << \" vs.\"\n      << DataType_Name(value_desc.data_type());\n  if (ctx->has_input(\"condition\", 0)) {\n    const user_op::TensorDesc& condition = ctx->InputTensorDesc(\"condition\", 0);\n    CHECK_OR_RETURN(IsIndexDataType(condition.data_type()));\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n#define DEF_ASSIGN_OP(op_class_name)                                                              \\\n  /* static */ Maybe<void> 
op_class_name::InferLogicalTensorDesc(user_op::InferContext* ctx) {    \\\n    return InferTensorDesc(ctx);                                                                  \\\n  }                                                                                               \\\n                                                                                                  \\\n  /*static*/ Maybe<void> op_class_name::InferPhysicalTensorDesc(user_op::InferContext* ctx) {     \\\n    return InferLogicalTensorDesc(ctx);                                                           \\\n  }                                                                                               \\\n                                                                                                  \\\n  /* static */ Maybe<void> op_class_name::GetSbp(user_op::SbpContext* ctx) {                      \\\n    return GetSbpSignatures(ctx);                                                                 \\\n  }                                                                                               \\\n                                                                                                  \\\n  /* static */ Maybe<void> op_class_name::ModifyInputArg(                                         \\\n      const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { \\\n    return InputArgModifierFn(GetInputArgModifierFn, conf);                                       \\\n  }                                                                                               \\\n                                                                                                  \\\n  /* static */ Maybe<void> op_class_name::InferDataType(user_op::InferContext* ctx) {             \\\n    return InferDataType_(ctx);                                                                   \\\n  }\n\nDEF_ASSIGN_OP(AssignUserOp)\nDEF_ASSIGN_OP(AssignIfOp)\nDEF_ASSIGN_OP(AssignIfNotOp)\n\n#undef DEF_ASSIGN_OP\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/avg_pool_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/avg_pool_kernel_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntypedef std::function<Maybe<void>(user_op::InferContext* ctx)> TensorDescInferFn;\n\nTensorDescInferFn AvgPoolMakeForwardTensorDescInferFn(const int32_t dim) {\n  return [dim](user_op::InferContext* ctx) -> Maybe<void> {\n    const Shape& x_shape = ctx->Shape4ArgNameAndIndex(\"x\", 0);\n    const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n    const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>(\"padding\");\n    const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n    const std::vector<int32_t>& stride = ctx->Attr<std::vector<int32_t>>(\"stride\");\n    const bool ceil_mode = ctx->Attr<bool>(\"ceil_mode\");\n    const bool count_include_pad = ctx->Attr<bool>(\"count_include_pad\");\n    const int32_t& divisor_override = ctx->Attr<int32_t>(\"divisor_override\");\n\n    CHECK_EQ_OR_RETURN(kernel_size.size(), dim)\n        << Error::RuntimeError() << \"kernel size.size() should equal to dim.\";\n    for (int32_t pool_dim : kernel_size) {\n      CHECK_GT_OR_RETURN(pool_dim, 0)\n          << Error::RuntimeError() << \"kernel size should great than 0, but got: \" << pool_dim;\n    }\n    CHECK_EQ_OR_RETURN(stride.size(), dim)\n        << Error::RuntimeError() << \"stride.size() should equal to dim.\";\n    for (int32_t stride_dim : stride) {\n      CHECK_GT_OR_RETURN(stride_dim, 0)\n          << Error::RuntimeError() << \"stride size should great than 0, but got: \" << stride_dim;\n    }\n    for (int32_t i = 0; i < padding.size(); i++) {\n      CHECK_GE_OR_RETURN(kernel_size[i], 2 * padding[i])\n          << \"pad should be smaller than half of kernel size\";\n    }\n\n    const AvgPoolParams3D params_3d(dim, x_shape, data_format, padding, kernel_size, stride,\n                                    ceil_mode, count_include_pad, divisor_override);\n    user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc(\"y\", 0);\n    *y_desc = ctx->InputTensorDesc(\"x\", 0);\n    y_desc->set_shape(params_3d.GetYShape());\n\n    return Maybe<void>::Ok();\n  };\n}\n\nMaybe<void> AvgPoolForwardGetSbpFn(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes() - 2)) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Split(user_op::OpArg(\"y\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AvgPoolBackwardGetSbpFn(user_op::SbpContext* ctx) {\n  FOR_RANGE(int64_t, i, 0, 2) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return 
Maybe<void>::Ok();\n}\n\n// Logically, the computation cost of a pool op is the product of the output element count and\n// the pool kernel element count; e.g. a (kh, kw) kernel over an (N, C, H', W') output costs\n// N * C * H' * W' * kh * kw. After applying sbp, we divide it by the parallel number on each\n// split dimension, because splitting the input and using partial sum for the output is not a\n// valid sbp for this op for now.\nMaybe<double> GetComputationCost(user_op::ComputeComplexityFnContext* ctx,\n                                 const std::string& blob_name) {\n  const std::vector<int32_t> pool_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n  // Accumulate in double (the init value fixes the accumulator type) to avoid int64_t overflow.\n  double logical_computation_cost = std::accumulate(\n      pool_size.begin(), pool_size.end(),\n      static_cast<double>(ctx->Shape4ArgNameAndIndex(blob_name, 0).elem_cnt()),\n      std::multiplies<double>());\n  const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy();\n  const auto& nd_sbp_y = ctx->NdSbp4ArgNameAndIndex(blob_name, 0);\n  for (int32_t dim_sbp = 0; dim_sbp < nd_sbp_y.sbp_parallel_size(); dim_sbp++) {\n    if (nd_sbp_y.sbp_parallel(dim_sbp).has_split_parallel()) {\n      logical_computation_cost /= parallel_hierarchy->At(dim_sbp);\n    }\n  }\n  return logical_computation_cost;\n}\n\nMaybe<void> BackwardTensorDescInferFn(user_op::InferContext* ctx) {\n  *ctx->MutOutputTensorDesc(\"dx\", 0) = ctx->InputTensorDesc(\"x\", 0);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FwInferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BwInferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n#define IMPLEMENT_AVGPOOL_FUNCS(name, ndim)                                              \\\n  /*static*/ Maybe<void> name##Op::GetSbp(user_op::SbpContext* ctx) {                    \\\n    return AvgPoolForwardGetSbpFn(ctx);                                                  \\\n  }                                                                                      \\\n  /*static*/ Maybe<void> name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return AvgPoolMakeForwardTensorDescInferFn(ndim)(ctx);                               \\\n  }                                                                                      \\\n  /*static*/ Maybe<void> name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferLogicalTensorDesc(ctx);                                                  \\\n  }                                                                                      \\\n  /*static*/ Maybe<void> name##Op::InferDataType(user_op::InferContext* ctx) {           \\\n    return FwInferDataType(ctx);                                                         \\\n  }                                                                                      \\\n  /*static*/ Maybe<double> name##Op::GetComputeComplexity(                               \\\n      user_op::ComputeComplexityFnContext* ctx) {                                        \\\n    return GetComputationCost(ctx, \"y\");                                                 \\\n  }\n\nIMPLEMENT_AVGPOOL_FUNCS(AvgPool1D, 1)\nIMPLEMENT_AVGPOOL_FUNCS(AvgPool2D, 2)\nIMPLEMENT_AVGPOOL_FUNCS(AvgPool3D, 3)\n#undef IMPLEMENT_AVGPOOL_FUNCS\n\n#define IMPLEMENT_AVGPOOL_BACKWARD_FUNCS(name)                                               \\\n  /*static*/ Maybe<void> name##GradOp::GetSbp(user_op::SbpContext* ctx) {                    \\\n    return AvgPoolBackwardGetSbpFn(ctx);                        
                             \\\n  }                                                                                          \\\n  /*static*/ Maybe<void> name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return BackwardTensorDescInferFn(ctx);                                                   \\\n  }                                                                                          \\\n  /*static*/ Maybe<void> name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferLogicalTensorDesc(ctx);                                                      \\\n  }                                                                                          \\\n  /*static*/ Maybe<void> name##GradOp::InferDataType(user_op::InferContext* ctx) {           \\\n    return BwInferDataType(ctx);                                                             \\\n  }                                                                                          \\\n  /*static*/ Maybe<double> name##GradOp::GetComputeComplexity(                               \\\n      user_op::ComputeComplexityFnContext* ctx) {                                            \\\n    return GetComputationCost(ctx, \"dy\");                                                    \\\n  }\n\nIMPLEMENT_AVGPOOL_BACKWARD_FUNCS(AvgPool1D)\nIMPLEMENT_AVGPOOL_BACKWARD_FUNCS(AvgPool2D)\nIMPLEMENT_AVGPOOL_BACKWARD_FUNCS(AvgPool3D)\n#undef IMPLEMENT_AVGPOOL_BACKWARD_FUNCS\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/batch_gather_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> BatchGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0)\n      << Error::RuntimeError() << \"The dimension of the input tensor should be greater than zero, \"\n      << \"but got \" << in.shape().NumAxes();\n  const user_op::TensorDesc& indices = ctx->InputTensorDesc(\"indices\", 0);\n  CHECK_GT_OR_RETURN(indices.shape().NumAxes(), 0)\n      << Error::RuntimeError()\n      << \"The dimension of the indices tensor should be greater than zero, \"\n      << \"but got \" << indices.shape().NumAxes();\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_LE_OR_RETURN(indices.shape().dim_vec().size(), in.shape().dim_vec().size())\n      << Error::RuntimeError()\n      << \"The dimension of the input tensor should be greater than or equal to the dimension of \"\n         \"the indices tensor, \"\n      << \"but found that the dimension of the input tensor is \" << in.shape().dim_vec().size()\n      << \", and the dimension of the indices tensor is \" << indices.shape().dim_vec().size();\n  FOR_RANGE(int64_t, i, 0, indices.shape().dim_vec().size() - 1) {\n    if (in.is_dynamic() && indices.is_dynamic() == false) {\n      CHECK_GE_OR_RETURN(indices.shape().dim_vec().at(i), in.shape().dim_vec().at(i))\n          << Error::RuntimeError()\n          << \"The size of indices tensor should be greater than or equal to the \"\n             \"size of input tensor \"\n          << \" at dimension \" << i\n          << \" when the input tensor is dynamic and the indices tensor is not dynamic\";\n    } else if (in.is_dynamic() == false && indices.is_dynamic()) {\n      LOG(FATAL)\n          << \"The indices tensor is not allowed to be dynamic when the input tensor is not dynamic\";\n    } else {\n      CHECK_EQ_OR_RETURN(indices.shape().dim_vec().at(i), in.shape().dim_vec().at(i))\n          << Error::RuntimeError()\n          << \"The size of indices tensor must match the size of input tensor\"\n          << \" at dimension \" << i << \" when two tensors are both dynamic or neither\";\n    }\n  }\n\n  DimVector dim_vec(in.shape().dim_vec());\n  dim_vec.at(indices.shape().NumAxes() - 1) = indices.shape().dim_vec().back();\n  out->set_shape(Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> BatchGatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BatchGatherOp::GetSbp(user_op::SbpContext* ctx) {\n  const int64_t indices_num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"indices\", 0).shape().NumAxes();\n  if (indices_num_axes > 1) {\n    FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) {\n 
     ctx->NewBuilder()\n          .Split(user_op::OpArg(\"indices\", 0), i)\n          .Split(user_op::OpArg(\"in\", 0), i)\n          .Split(user_op::OpArg(\"out\", 0), i)\n          .Build();\n    }\n  }\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"indices\", 0))\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BatchGatherOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn(\"indices\", 0);\n  CHECK_OR_RETURN(indices_modifier != nullptr);  // NOLINT(maybe-need-error-msg)\n  indices_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BatchGatherOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& indices = ctx->InputTensorDesc(\"indices\", 0);\n  CHECK_OR_RETURN(IsIndexDataType(indices.data_type()))\n      << Error::TypeError() << \"The dtype of the indices tensor must be int32 or int64\";\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_data_type(in.data_type());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/batch_norm_backward_elemt_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::function<Maybe<void>(const std::string&)> MakeSetOutTensorDescFn(user_op::InferContext* ctx,\n                                                                      const Shape& shape) {\n  return [=](const std::string& bn) -> Maybe<void> {\n    if (ctx->has_output(bn, 0)) {\n      auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0);\n      CHECK_OR_RETURN(tensor_desc != nullptr) << \"output tensordesc of \" << bn << \" is null.\";\n      tensor_desc->set_shape(shape);\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\nstd::function<Maybe<void>(const std::string&)> MakeSetOutDataTypeFn(user_op::InferContext* ctx,\n                                                                    DataType data_type) {\n  return [=](const std::string& bn) -> Maybe<void> {\n    if (ctx->has_output(bn, 0)) {\n      auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0);\n      CHECK_OR_RETURN(tensor_desc != nullptr) << \"output tensordesc of \" << bn << \" is null.\";\n      tensor_desc->set_data_type(data_type);\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\n}  // namespace\n\n/* static */ Maybe<void> BatchNormBackwardElemtOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const auto& x = ctx->InputTensorDesc(\"input\", 0);\n  const Shape& x_shape = x.shape();\n  const auto SetOutTensorDesc = MakeSetOutTensorDescFn(ctx, x_shape);\n  JUST(SetOutTensorDesc(\"grad_in\"));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> BatchNormBackwardElemtOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BatchNormBackwardElemtOp::GetSbp(user_op::SbpContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BatchNormBackwardElemtOp::InferDataType(user_op::InferContext* ctx) {\n  const auto& x = ctx->InputTensorDesc(\"input\", 0);\n  const auto data_type = x.data_type();\n  const DataType out_data_type = data_type == DataType::kFloat16 ? DataType::kFloat : data_type;\n  const auto SetOutDataType = MakeSetOutDataTypeFn(ctx, out_data_type);\n  JUST(SetOutDataType(\"grad_in\"));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/batch_norm_backward_reduce_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::function<Maybe<void>(const std::string&)> MakeSetOutTensorDescFn(user_op::InferContext* ctx,\n                                                                      const Shape& shape) {\n  return [=](const std::string& bn) -> Maybe<void> {\n    if (ctx->has_output(bn, 0)) {\n      auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0);\n      CHECK_OR_RETURN(tensor_desc != nullptr) << \"output tensordesc of \" << bn << \" is null.\";\n      tensor_desc->set_shape(shape);\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\nstd::function<Maybe<void>(const std::string&)> MakeSetOutDataTypeFn(user_op::InferContext* ctx,\n                                                                    DataType data_type) {\n  return [=](const std::string& bn) -> Maybe<void> {\n    if (ctx->has_output(bn, 0)) {\n      auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0);\n      CHECK_OR_RETURN(tensor_desc != nullptr) << \"output tensordesc of \" << bn << \" is null.\";\n      tensor_desc->set_data_type(data_type);\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\n}  // namespace\n\n/* static */ Maybe<void> BatchNormBackwardReduceOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const auto& x = ctx->InputTensorDesc(\"input\", 0);\n  const Shape& x_shape = x.shape();\n  const auto axis = ctx->Attr<int32_t>(\"axis\");\n  CHECK_GE_OR_RETURN(axis, 0) << \"channel axis should be larger than 0\";\n  CHECK_LT_OR_RETURN(axis, x_shape.NumAxes())\n      << \"channel axis should be less than \" << x_shape.NumAxes();\n  const Shape param_shape({x_shape.At(axis)});\n  const auto SetOutTensorDesc = MakeSetOutTensorDescFn(ctx, param_shape);\n  JUST(SetOutTensorDesc(\"sum_dy\"));\n  JUST(SetOutTensorDesc(\"sum_dy_xmu\"));\n  JUST(SetOutTensorDesc(\"grad_weight\"));\n  JUST(SetOutTensorDesc(\"grad_bias\"));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> BatchNormBackwardReduceOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BatchNormBackwardReduceOp::GetSbp(user_op::SbpContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BatchNormBackwardReduceOp::InferDataType(user_op::InferContext* ctx) {\n  const auto& x = ctx->InputTensorDesc(\"input\", 0);\n  const auto data_type = x.data_type();\n  const DataType out_data_type = data_type == DataType::kFloat16 ? DataType::kFloat : data_type;\n  const auto SetOutDataType = MakeSetOutDataTypeFn(ctx, out_data_type);\n  JUST(SetOutDataType(\"sum_dy\"));\n  JUST(SetOutDataType(\"sum_dy_xmu\"));\n  JUST(SetOutDataType(\"grad_weight\"));\n  JUST(SetOutDataType(\"grad_bias\"));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/batch_norm_elemt_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::function<Maybe<void>(const std::string&)> MakeSetOutTensorDescFn(user_op::InferContext* ctx,\n                                                                      const Shape& shape) {\n  return [=](const std::string& bn) -> Maybe<void> {\n    if (ctx->has_output(bn, 0)) {\n      auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0);\n      CHECK_OR_RETURN(tensor_desc != nullptr) << \"output tensordesc of \" << bn << \" is null.\";\n      tensor_desc->set_shape(shape);\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\nstd::function<Maybe<void>(const std::string&)> MakeSetOutDataTypeFn(user_op::InferContext* ctx,\n                                                                    DataType data_type) {\n  return [=](const std::string& bn) -> Maybe<void> {\n    if (ctx->has_output(bn, 0)) {\n      auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0);\n      CHECK_OR_RETURN(tensor_desc != nullptr) << \"output tensordesc of \" << bn << \" is null.\";\n      tensor_desc->set_data_type(data_type);\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\n}  // namespace\n\n/* static */ Maybe<void> BatchNormElemtOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const auto& x = ctx->InputTensorDesc(\"input\", 0);\n  const Shape& x_shape = x.shape();\n  const auto SetOutTensorDesc = MakeSetOutTensorDescFn(ctx, x_shape);\n  JUST(SetOutTensorDesc(\"output\"));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> BatchNormElemtOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BatchNormElemtOp::GetSbp(user_op::SbpContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BatchNormElemtOp::InferDataType(user_op::InferContext* ctx) {\n  const auto& x = ctx->InputTensorDesc(\"input\", 0);\n  const auto data_type = x.data_type();\n  const DataType out_data_type = data_type == DataType::kFloat16 ? DataType::kFloat : data_type;\n  const auto SetOutDataType = MakeSetOutDataTypeFn(ctx, out_data_type);\n  JUST(SetOutDataType(\"output\"));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/batch_norm_gather_stats_with_counts_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::function<Maybe<void>(const std::string&)> MakeSetOutTensorDescFn(user_op::InferContext* ctx,\n                                                                      const Shape& shape) {\n  return [=](const std::string& bn) -> Maybe<void> {\n    if (ctx->has_output(bn, 0)) {\n      auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0);\n      CHECK_OR_RETURN(tensor_desc != nullptr) << \"output tensordesc of \" << bn << \" is null.\";\n      tensor_desc->set_shape(shape);\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\nstd::function<Maybe<void>(const std::string&)> MakeSetOutDataTypeFn(user_op::InferContext* ctx,\n                                                                    DataType data_type) {\n  return [=](const std::string& bn) -> Maybe<void> {\n    if (ctx->has_output(bn, 0)) {\n      auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0);\n      CHECK_OR_RETURN(tensor_desc != nullptr) << \"output tensordesc of \" << bn << \" is null.\";\n      tensor_desc->set_data_type(data_type);\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\n}  // namespace\n\n/* static */ Maybe<void> BatchNormGatherStatsWithCountsOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const auto& mean = ctx->InputTensorDesc(\"mean\", 0);\n  const Shape& mean_shape = mean.shape();\n  const Shape param_shape({mean_shape.At(1)});\n  const auto SetOutTensorDesc = MakeSetOutTensorDescFn(ctx, param_shape);\n  JUST(SetOutTensorDesc(\"global_mean\"));\n  JUST(SetOutTensorDesc(\"global_invstd\"));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> BatchNormGatherStatsWithCountsOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BatchNormGatherStatsWithCountsOp::GetSbp(user_op::SbpContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BatchNormGatherStatsWithCountsOp::InferDataType(\n    user_op::InferContext* ctx) {\n  const auto& x = ctx->InputTensorDesc(\"input\", 0);\n  const auto data_type = x.data_type();\n  const DataType out_data_type = data_type == DataType::kFloat16 ? 
DataType::kFloat : data_type;\n  const auto SetOutDataType = MakeSetOutDataTypeFn(ctx, out_data_type);\n  JUST(SetOutDataType(\"global_mean\"));\n  JUST(SetOutDataType(\"global_invstd\"));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BatchNormGatherStatsWithCountsOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  if (conf.has_input(\"running_mean\", 0)) {\n    CHECK_OR_RETURN(conf.has_input(\"running_var\", 0))\n        << \"running_mean and running_var should be provided as inputs in the same time.\";\n    user_op::InputArgModifier* running_mean_modifier = GetInputArgModifierFn(\"running_mean\", 0);\n    CHECK_OR_RETURN(running_mean_modifier != nullptr)\n        << \"input arg modifier of running_mean is null.\";\n    running_mean_modifier->set_is_mutable(true);\n    running_mean_modifier->set_requires_grad(false);\n    user_op::InputArgModifier* running_var_modifier = GetInputArgModifierFn(\"running_var\", 0);\n    CHECK_OR_RETURN(running_var_modifier != nullptr)\n        << \"input arg modifier of running_var is null.\";\n    running_var_modifier->set_is_mutable(true);\n    running_var_modifier->set_requires_grad(false);\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/batch_norm_stats_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nstd::function<Maybe<void>(const std::string&)> MakeSetOutTensorDescFn(user_op::InferContext* ctx,\n                                                                      const Shape& shape) {\n  return [=](const std::string& bn) -> Maybe<void> {\n    if (ctx->has_output(bn, 0)) {\n      auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0);\n      CHECK_OR_RETURN(tensor_desc != nullptr) << \"output tensordesc of \" << bn << \" is null.\";\n      tensor_desc->set_shape(shape);\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\nstd::function<Maybe<void>(const std::string&)> MakeSetOutDataTypeFn(user_op::InferContext* ctx,\n                                                                    DataType data_type) {\n  return [=](const std::string& bn) -> Maybe<void> {\n    if (ctx->has_output(bn, 0)) {\n      auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0);\n      CHECK_OR_RETURN(tensor_desc != nullptr) << \"output tensordesc of \" << bn << \" is null.\";\n      tensor_desc->set_data_type(data_type);\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\n}  // namespace\n\n/* static */ Maybe<void> BatchNormStatsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const auto& x = ctx->InputTensorDesc(\"input\", 0);\n  const Shape& x_shape = x.shape();\n  const auto axis = ctx->Attr<int32_t>(\"axis\");\n  CHECK_GE_OR_RETURN(axis, 0) << \"channel axis should be larger than 0\";\n  CHECK_LT_OR_RETURN(axis, x_shape.NumAxes())\n      << \"channel axis should be less than \" << x_shape.NumAxes();\n  const Shape param_shape({x_shape.At(axis)});\n  const auto SetOutTensorDesc = MakeSetOutTensorDescFn(ctx, param_shape);\n  JUST(SetOutTensorDesc(\"mean\"));\n  JUST(SetOutTensorDesc(\"invstd\"));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> BatchNormStatsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BatchNormStatsOp::GetSbp(user_op::SbpContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BatchNormStatsOp::InferDataType(user_op::InferContext* ctx) {\n  const auto& x = ctx->InputTensorDesc(\"input\", 0);\n  const auto data_type = x.data_type();\n  const DataType out_data_type = data_type == DataType::kFloat16 ? DataType::kFloat : data_type;\n  const auto SetOutDataType = MakeSetOutDataTypeFn(ctx, out_data_type);\n  JUST(SetOutDataType(\"mean\"));\n  JUST(SetOutDataType(\"invstd\"));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/bernoulli_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> BernoulliOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  out_tensor->set_shape(in_tensor.shape());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> BernoulliOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BernoulliOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BernoulliOp::InferDataType(user_op::InferContext* ctx) {\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_tensor->set_data_type(ctx->Attr<DataType>(\"dtype\"));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/bias_add_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> BiasAddOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const auto& a_tensor_desc = ctx->InputTensorDesc(\"a\", 0);\n  const auto& b_tensor_desc = ctx->InputTensorDesc(\"b\", 0);\n  const auto bias_add_axis = ctx->Attr<int32_t>(\"axis\");\n  CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1)\n      << Error::RuntimeError() << \"Bias tensor has to be a one-dimensional vector\";\n  CHECK_GE_OR_RETURN(bias_add_axis, 0)\n      << Error::RuntimeError() << \"The size of the axis must greater than or equal to 0, \"\n      << \"but got \" << bias_add_axis;\n  CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes())\n      << Error::IndexError() << \"Dimension out of range (expected to be in range of [0\"\n      << \", \" << a_tensor_desc.shape().NumAxes() - 1 << \"],\"\n      << \" but got \" << bias_add_axis << \")\";\n  CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0))\n      << Error::RuntimeError() << \"The size of tensor \" << a_tensor_desc.shape().ToString()\n      << \" must match the size of tensor \" << b_tensor_desc.shape().ToString() << \" at dimension \"\n      << bias_add_axis;\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"a\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"a\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> BiasAddOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BiasAddOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto axis = ctx->Attr<int32_t>(\"axis\");\n  for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex(\"a\", 0).shape().NumAxes();\n       ++i) {\n    if (i == axis) { continue; }\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"a\", 0), i)\n        .Broadcast(user_op::OpArg(\"b\", 0))\n        .Split(ctx->outputs(), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"b\", 0), 0)\n      .Split(user_op::OpArg(\"a\", 0), axis)\n      .Split(ctx->outputs(), axis)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BiasAddOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"a\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/binary_cross_entropy_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/loss_op_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferTensorDescFn_(user_op::InferContext* ctx) {\n  const auto& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const auto& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic());\n  CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape());\n  if (ctx->has_input(\"weight\", 0)) {\n    const auto& weight_desc = ctx->InputTensorDesc(\"weight\", 0);\n    CHECK_EQ_OR_RETURN(weight_desc.is_dynamic(), input_desc.is_dynamic());\n    CHECK_EQ_OR_RETURN(weight_desc.shape(), input_desc.shape());\n  }\n\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_is_dynamic(input_desc.is_dynamic());\n  out_desc->set_shape(input_desc.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType_(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const user_op::TensorDesc& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(input_desc.data_type()) << \", but got \"\n      << DataType_Name(target_desc.data_type());\n  if (ctx->has_input(\"weight\", 0)) {\n    const auto& weight_desc = ctx->InputTensorDesc(\"weight\", 0);\n    CHECK_EQ_OR_RETURN(weight_desc.data_type(), input_desc.data_type())\n        << \"InferDataType Failed. 
Expected \" << DataType_Name(input_desc.data_type())\n        << \", but got \" << DataType_Name(weight_desc.data_type());\n  }\n\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"input\", 0));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferGradTensorDescFn(user_op::InferContext* ctx) {\n  const auto& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const auto& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  const auto& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic());\n  CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape());\n  CHECK_EQ_OR_RETURN(dy_desc.shape(), target_desc.shape());\n  if (ctx->has_input(\"weight\", 0)) {\n    const auto& weight_desc = ctx->InputTensorDesc(\"weight\", 0);\n    CHECK_EQ_OR_RETURN(weight_desc.is_dynamic(), input_desc.is_dynamic());\n    CHECK_EQ_OR_RETURN(weight_desc.shape(), input_desc.shape());\n  }\n\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx_desc->set_is_dynamic(input_desc.is_dynamic());\n  dx_desc->set_shape(input_desc.shape());\n\n  return Maybe<void>::Ok();\n}\nMaybe<void> InferGradDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const user_op::TensorDesc& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(input_desc.data_type()) << \", but got \"\n      << DataType_Name(target_desc.data_type());\n  if (ctx->has_input(\"weight\", 0)) {\n    const auto& weight_desc = ctx->InputTensorDesc(\"weight\", 0);\n    CHECK_EQ_OR_RETURN(weight_desc.data_type(), input_desc.data_type())\n        << \"InferDataType Failed. 
Expected \" << DataType_Name(input_desc.data_type())\n        << \", but got \" << DataType_Name(weight_desc.data_type());\n  }\n\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n\n  return Maybe<void>::Ok();\n}\n}  // namespace\n\n/* static */ Maybe<void> BinaryCrossEntropyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDescFn_(ctx);\n}\n\n/*static*/ Maybe<void> BinaryCrossEntropyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyOp::GetSbp(user_op::SbpContext* ctx) {\n  return GenLossForwardDefaultGetSbpFn()(ctx);\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* target_modifier = GetInputArgModifierFn(\"target\", 0);\n  CHECK_OR_RETURN(target_modifier != nullptr);\n  target_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType_(ctx);\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferGradTensorDescFn(ctx);\n}\n\n/*static*/ Maybe<void> BinaryCrossEntropyGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyGradOp::GetSbp(user_op::SbpContext* ctx) {\n  return GenLossBackwardDefaultGetSbpFn()(ctx);\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) {\n  return InferGradDataType(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/loss_op_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\nnamespace {\nMaybe<void> InferTensorDescFn(user_op::InferContext* ctx) {\n  const auto& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const auto& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic());\n  CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape())\n      << \"Input shape should be equal to Target shape. \";\n  if (ctx->has_input(\"weight\", 0)) {\n    const auto& weight_desc = ctx->InputTensorDesc(\"weight\", 0);\n    CHECK_EQ_OR_RETURN(weight_desc.is_dynamic(), input_desc.is_dynamic());\n    CHECK_EQ_OR_RETURN(weight_desc.shape(), input_desc.shape());\n  }\n  if (ctx->Attr<bool>(\"has_pos_weight\")) {\n    const auto& pos_weight_desc = ctx->InputTensorDesc(\"pos_weight\", 0);\n    CHECK_EQ_OR_RETURN(pos_weight_desc.is_dynamic(), input_desc.is_dynamic());\n  }\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_is_dynamic(input_desc.is_dynamic());\n  out_desc->set_shape(input_desc.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType_(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const user_op::TensorDesc& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_GE_OR_RETURN(DType::priority_order[input_desc.data_type()],\n                     DType::priority_order[DType::Float16()->data_type()]);\n  CHECK_GE_OR_RETURN(DType::priority_order[target_desc.data_type()],\n                     DType::priority_order[DType::Float16()->data_type()]);\n  if (ctx->has_input(\"weight\", 0)) {\n    const auto& weight_desc = ctx->InputTensorDesc(\"weight\", 0);\n    CHECK_EQ_OR_RETURN(weight_desc.data_type(), target_desc.data_type())\n        << \"InferDataType Failed. Expected \" << DataType_Name(target_desc.data_type())\n        << \", but got \" << DataType_Name(weight_desc.data_type());\n  }\n  if (ctx->Attr<bool>(\"has_pos_weight\")) {\n    const auto& pos_weight_desc = ctx->InputTensorDesc(\"pos_weight\", 0);\n    CHECK_EQ_OR_RETURN(pos_weight_desc.data_type(), target_desc.data_type())\n        << \"InferDataType Failed. 
Expected \" << DataType_Name(target_desc.data_type())\n        << \", but got \" << DataType_Name(pos_weight_desc.data_type());\n  }\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"target\", 0));\n\n  return Maybe<void>::Ok();\n}\nMaybe<void> InferGradTensorDescFn(user_op::InferContext* ctx) {\n  const auto& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const auto& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  const auto& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic());\n  CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape())\n      << \"Input shape should be equal to Target shape. \";\n  CHECK_EQ_OR_RETURN(dy_desc.shape(), target_desc.shape())\n      << \"Dy shape should be equal to Target shape. \";\n  if (ctx->has_input(\"weight\", 0)) {\n    const auto& weight_desc = ctx->InputTensorDesc(\"weight\", 0);\n    CHECK_EQ_OR_RETURN(weight_desc.is_dynamic(), input_desc.is_dynamic());\n    CHECK_EQ_OR_RETURN(weight_desc.shape(), input_desc.shape());\n  }\n  if (ctx->Attr<bool>(\"has_pos_weight\")) {\n    const auto& pos_weight_desc = ctx->InputTensorDesc(\"pos_weight\", 0);\n    CHECK_EQ_OR_RETURN(pos_weight_desc.is_dynamic(), input_desc.is_dynamic());\n  }\n\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx_desc->set_is_dynamic(input_desc.is_dynamic());\n  dx_desc->set_shape(input_desc.shape());\n\n  return Maybe<void>::Ok();\n}\nMaybe<void> InferGradDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const user_op::TensorDesc& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_GE_OR_RETURN(DType::priority_order[input_desc.data_type()],\n                     DType::priority_order[DType::Float16()->data_type()]);\n  CHECK_GE_OR_RETURN(DType::priority_order[target_desc.data_type()],\n                     DType::priority_order[DType::Float16()->data_type()]);\n  if (ctx->has_input(\"weight\", 0)) {\n    const auto& weight_desc = ctx->InputTensorDesc(\"weight\", 0);\n    CHECK_EQ_OR_RETURN(weight_desc.data_type(), target_desc.data_type())\n        << \"InferDataType Failed. Expected \" << DataType_Name(weight_desc.data_type())\n        << \", but got \" << DataType_Name(target_desc.data_type());\n  }\n  if (ctx->Attr<bool>(\"has_pos_weight\")) {\n    const auto& pos_weight_desc = ctx->InputTensorDesc(\"pos_weight\", 0);\n    CHECK_EQ_OR_RETURN(pos_weight_desc.data_type(), target_desc.data_type())\n        << \"InferDataType Failed. 
Expected \" << DataType_Name(target_desc.data_type())\n        << \", but got \" << DataType_Name(pos_weight_desc.data_type());\n  }\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"input\", 0));\n\n  return Maybe<void>::Ok();\n}\n}  // namespace\n\n/* static */ Maybe<void> BinaryCrossEntropyWithLogitsOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferTensorDescFn(ctx);\n}\n\n/*static*/ Maybe<void> BinaryCrossEntropyWithLogitsOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyWithLogitsOp::GetSbp(user_op::SbpContext* ctx) {\n  return GenLossForwardDefaultGetSbpFn(\n      [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) {\n        if (ctx->user_op_conf().has_input(\"pos_weight\", 0)) {\n          builder.Broadcast(user_op::OpArg(\"pos_weight\", 0));\n        }\n      })(ctx);\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyWithLogitsOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* target_modifier = GetInputArgModifierFn(\"target\", 0);\n  CHECK_OR_RETURN(target_modifier != nullptr);\n  target_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyWithLogitsOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType_(ctx);\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyWithLogitsGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferGradTensorDescFn(ctx);\n}\n\n/*static*/ Maybe<void> BinaryCrossEntropyWithLogitsGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyWithLogitsGradOp::GetSbp(user_op::SbpContext* ctx) {\n  return GenLossBackwardDefaultGetSbpFn(\n      [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) {\n        if (ctx->user_op_conf().has_input(\"pos_weight\", 0)) {\n          builder.Broadcast(user_op::OpArg(\"pos_weight\", 0));\n        }\n      })(ctx);\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyWithLogitsGradOp::InferDataType(\n    user_op::InferContext* ctx) {\n  return InferGradDataType(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/binary_cross_entropy_with_logits_reduce_mean_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferTensorDescFn(user_op::InferContext* ctx) {\n  const auto& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const auto& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape())\n      << \"Input shape should be equal to Target shape. \";\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_is_dynamic(false);\n  out_desc->set_shape(Shape({}));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferFwDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const user_op::TensorDesc& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_GE_OR_RETURN(DType::priority_order[input_desc.data_type()],\n                     DType::priority_order[DType::Float16()->data_type()]);\n  CHECK_GE_OR_RETURN(DType::priority_order[target_desc.data_type()],\n                     DType::priority_order[DType::Float16()->data_type()]);\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"target\", 0));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferGradTensorDescFn(user_op::InferContext* ctx) {\n  const auto& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const auto& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape())\n      << \"Input shape should be equal to Target shape. 
\";\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx_desc->set_is_dynamic(false);\n  dx_desc->set_shape(input_desc.shape());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferGradDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const user_op::TensorDesc& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_GE_OR_RETURN(DType::priority_order[input_desc.data_type()],\n                     DType::priority_order[DType::Float16()->data_type()]);\n  CHECK_GE_OR_RETURN(DType::priority_order[target_desc.data_type()],\n                     DType::priority_order[DType::Float16()->data_type()]);\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n}  // namespace\n\n/* static */ Maybe<void> BinaryCrossEntropyWithLogitsReduceMeanOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferTensorDescFn(ctx);\n}\n\n/*static*/ Maybe<void> BinaryCrossEntropyWithLogitsReduceMeanOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyWithLogitsReduceMeanOp::GetSbp(\n    user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"input\", 0), 0)\n      .Split(user_op::OpArg(\"target\", 0), 0)\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyWithLogitsReduceMeanOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* target_modifier = GetInputArgModifierFn(\"target\", 0);\n  CHECK_OR_RETURN(target_modifier != nullptr) << \"target_modifier should not be nullptr. \";\n  target_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyWithLogitsReduceMeanOp::InferDataType(\n    user_op::InferContext* ctx) {\n  return InferFwDataType(ctx);\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyWithLogitsReduceMeanGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferGradTensorDescFn(ctx);\n}\n\n/*static*/ Maybe<void> BinaryCrossEntropyWithLogitsReduceMeanGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyWithLogitsReduceMeanGradOp::GetSbp(\n    user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"input\", 0), 0)\n      .Split(user_op::OpArg(\"target\", 0), 0)\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Broadcast(user_op::OpArg(\"dy\", 0))\n      .Build();\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BinaryCrossEntropyWithLogitsReduceMeanGradOp::InferDataType(\n    user_op::InferContext* ctx) {\n  return InferGradDataType(ctx);\n}\n\n/* static */ Maybe<void> FusedBCEReduceMeanFwBwOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const auto& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const auto& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape())\n      << \"Input shape should be equal to Target shape. 
\";\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_is_dynamic(false);\n  out_desc->set_shape(Shape({}));\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx_desc->set_is_dynamic(false);\n  dx_desc->set_shape(input_desc.shape());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FusedBCEReduceMeanFwBwOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedBCEReduceMeanFwBwOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"input\", 0), 0)\n      .Split(user_op::OpArg(\"target\", 0), 0)\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedBCEReduceMeanFwBwOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const user_op::TensorDesc& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_GE_OR_RETURN(DType::priority_order[input_desc.data_type()],\n                     DType::priority_order[DType::Float16()->data_type()]);\n  CHECK_GE_OR_RETURN(DType::priority_order[target_desc.data_type()],\n                     DType::priority_order[DType::Float16()->data_type()]);\n  DataType out_dtype = ctx->Attr<DataType>(\"out_dtype\");\n  if (out_dtype == DataType::kInvalidDataType) { out_dtype = target_desc.data_type(); }\n  ctx->SetOutputDType(\"out\", 0, out_dtype);\n  ctx->SetOutputDType(\"dx\", 0, input_desc.data_type());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/bincount_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferTensorDesc(user_op::InferContext* ctx) {\n  user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  const int64_t size = ctx->Attr<int64_t>(\"size\");\n  output_desc->set_shape(Shape({size}));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> BinCountOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> BinCountOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BinCountOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> BinCountOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& input_desc = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  if (ctx->has_input(\"weight\", 0)) {\n    const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc(\"weight\", 0);\n    output_desc->set_data_type(weight_desc.data_type());\n  } else {\n    output_desc->set_data_type(input_desc.data_type());\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/broadcast_div_grad_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> BroadcastDivGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"dy\", 0, ctx->InputShape(\"y\", 0));\n  ctx->SetOutputIsDynamic(\"dy\", 0, ctx->InputIsDynamic(\"y\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> BroadcastDivGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BroadcastDivGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"y\", 0).shape();\n  const Shape& z_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"z\", 0).shape();\n  CHECK_LE_OR_RETURN(y_shape.NumAxes(), z_shape.NumAxes());\n  FOR_RANGE(int64_t, i, 0, y_shape.NumAxes()) {\n    const int64_t axis_y = y_shape.NumAxes() - 1 - i;\n    const int64_t axis_z = z_shape.NumAxes() - 1 - i;\n    if (y_shape.At(axis_y) == z_shape.At(axis_z)) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"y\", 0), axis_y)\n          .Split(user_op::OpArg(\"z\", 0), axis_z)\n          .Split(user_op::OpArg(\"dz\", 0), axis_z)\n          .Split(user_op::OpArg(\"dy\", 0), axis_y)\n          .Build();\n    } else {\n      ctx->NewBuilder()\n          .Broadcast(user_op::OpArg(\"y\", 0))\n          .Split(user_op::OpArg(\"z\", 0), axis_z)\n          .Split(user_op::OpArg(\"dz\", 0), axis_z)\n          .PartialSum(user_op::OpArg(\"dy\", 0))\n          .Build();\n    }\n  }\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"y\", 0))\n      .PartialSum(user_op::OpArg(\"z\", 0))\n      .Broadcast(user_op::OpArg(\"dz\", 0))\n      .PartialSum(user_op::OpArg(\"dy\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"y\", 0))\n      .Broadcast(user_op::OpArg(\"z\", 0))\n      .PartialSum(user_op::OpArg(\"dz\", 0))\n      .PartialSum(user_op::OpArg(\"dy\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BroadcastDivGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dy\", 0, ctx->InputDType(\"y\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/broadcast_like_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/operator/reduce_sbp_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> GetSbpSignatures(user_op::SbpContext* ctx) {\n  const auto& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape();\n  const auto& like_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"like\", 0).shape();\n  int32_t x_num_axes = x_shape.NumAxes();\n  int32_t like_num_axes = like_shape.NumAxes();\n  const auto& reduced_axes = ctx->Attr<std::vector<int32_t>>(\"broadcast_axes\");\n  HashSet<int32_t> conf_axes;\n  ReduceSbpUtil::GetRegularAxes(like_num_axes, reduced_axes, &conf_axes);\n  auto IsReducedAxis = ReduceSbpUtil::MakePredicatorIsReducedAxis(conf_axes, like_num_axes);\n  int32_t num_reduced_axis = 0;\n  FOR_RANGE(int64_t, i, 0, like_num_axes) {\n    if (IsReducedAxis(i)) {\n      ctx->NewBuilder()\n          .Broadcast(user_op::OpArg(\"x\", 0))\n          .Split(user_op::OpArg(\"like\", 0), i)\n          .Split(user_op::OpArg(\"y\", 0), i)\n          .Build();\n      if (x_num_axes < like_num_axes) { num_reduced_axis += 1; }\n    } else {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), i - num_reduced_axis)\n          .Split(user_op::OpArg(\"like\", 0), i)\n          .Split(user_op::OpArg(\"y\", 0), i)\n          .Build();\n    }\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"x\", 0))\n      .Broadcast(user_op::OpArg(\"like\", 0))\n      .PartialSum(user_op::OpArg(\"y\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"x\", 0))\n      .PartialSum(user_op::OpArg(\"like\", 0))\n      .Broadcast(user_op::OpArg(\"y\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\nbool IsAxesLegal(const AxisVector& axis_vec, const Shape& like_shape, const Shape& in_shape) {\n  Shape reduced_like_shape = CreateReducedShape(like_shape, axis_vec);\n  if (like_shape.NumAxes() > in_shape.NumAxes()) {\n    std::vector<int64_t> in_shape_vec;\n    in_shape_vec.reserve(in_shape.NumAxes());\n    std::vector<int64_t> like_shape_vec;\n    like_shape_vec.reserve(reduced_like_shape.NumAxes());\n    for (const int64_t& dim : in_shape.dim_vec()) {\n      if (dim != 1) { in_shape_vec.emplace_back(dim); }\n    }\n    for (const int64_t& dim : reduced_like_shape.dim_vec()) {\n      if (dim != 1) { like_shape_vec.emplace_back(dim); }\n    }\n    if (in_shape_vec.size() > like_shape_vec.size()) {\n      return false;\n    } else {\n      return std::equal(in_shape_vec.begin(), in_shape_vec.end(), like_shape_vec.begin());\n    }\n  }\n  return reduced_like_shape.dim_vec() == in_shape.dim_vec();\n}\n\nMaybe<void> InferTensorDesc(user_op::InferContext* ctx) {\n  const auto& broadcast_axes = 
ctx->Attr<std::vector<int32_t>>(\"broadcast_axes\");\n  CHECK_OR_RETURN(!broadcast_axes.empty());\n  const Shape& in_shape = ctx->InputShape(\"x\", 0);\n  const Shape& like_shape = ctx->InputShape(\"like\", 0);\n  const AxisVector axis_vec = {broadcast_axes.begin(), broadcast_axes.end()};\n  CHECK_OR_RETURN(IsAxesLegal(axis_vec, like_shape, in_shape))\n      << Error::RuntimeError() << \"Invalid input parameter: like shape:\" << like_shape.ToString()\n      << \", in shape:\" << in_shape.ToString() << \", axis_vec size:\" << axis_vec.size();\n  ctx->SetOutputShape(\"y\", 0, like_shape);\n  ctx->SetOutputStride(\"y\", 0, Stride(like_shape));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> BroadcastLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> BroadcastLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BroadcastLikeOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetSbpSignatures(ctx);\n}\n\n/* static */ Maybe<void> BroadcastLikeOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* like_modifier = GetInputArgModifierFn(\"like\", 0);\n  CHECK_OR_RETURN(like_modifier != nullptr);\n  like_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BroadcastLikeOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/buffer_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> IdentityBufferOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> IdentityBufferOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> IdentityBufferOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> IdentityBufferOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/cast_like_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> CastLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> CastLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CastLikeOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0).shape();\n  for (int i = 0; i < in_shape.NumAxes(); ++i) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), i)\n        .Split(user_op::OpArg(\"dtype_like\", 0), i)\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"dtype_like\", 0))\n      .Broadcast(user_op::OpArg(\"in\", 0))\n      .Broadcast(user_op::OpArg(\"out\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"dtype_like\", 0))\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"dtype_like\", 0))\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CastLikeOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* dtype_like_modifier = GetInputArgModifierFn(\"dtype_like\", 0);\n  CHECK_NOTNULL_OR_RETURN(dtype_like_modifier);\n  dtype_like_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CastLikeOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dtype_like_tensor_desc = ctx->InputTensorDesc(\"dtype_like\", 0);\n  user_op::TensorDesc* output_tensor_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  output_tensor_desc->set_data_type(dtype_like_tensor_desc.data_type());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/cast_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<Symbol<Stream>> MakeCastStream(const Symbol<Device>& in_device, const bool pin_memory) {\n  if (pin_memory) {\n    CHECK_OR_RETURN(in_device->type() == \"cpu\")\n        << \"cast op only support pin_memory in cpu device but got \" << in_device->type();\n    // TODO:(zhaoluyang) Parsing pin-memory-device from python\n    auto pin_device = JUST(Device::New(\"cuda\"));\n    return Stream::New(pin_device, StreamType::kPinnedCompute);\n  }\n  return Stream::New(in_device, StreamType::kCompute);\n}\n\n}  // namespace\n\n/* static */ Maybe<void> CastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& input_tensor_desc = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* output_tensor_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  output_tensor_desc->set_shape(input_tensor_desc.shape());\n  output_tensor_desc->set_stride(\n      input_tensor_desc.stride());  // output's stride should consistent with input's\n  output_tensor_desc->set_is_dynamic(input_tensor_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> CastOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CastOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CastOp::InferDataType(user_op::InferContext* ctx) {\n  user_op::TensorDesc* output_tensor_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  output_tensor_desc->set_data_type(ctx->Attr<DataType>(\"dtype\"));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> CastOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  const Symbol<Device>& in_device = ctx->InputTensorDevice4ArgNameAndIndex(\"in\", 0);\n  *ctx->OutputTensorDevice4ArgNameAndIndex(\"out\", 0) = in_device;\n  const bool pin_memory = ctx->Attr<bool>(\"pin_memory\");\n  return MakeCastStream(in_device, pin_memory);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/cast_to_static_shape_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> CastToStaticShapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc(\"output\", 0);\n  output_desc->set_shape(input_desc.shape());\n  output_desc->set_is_dynamic(false);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> CastToStaticShapeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CastToStaticShapeOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& input_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"input\", 0);\n  FOR_RANGE(int64_t, i, 0, input_desc.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"input\", 0), i)\n        .Split(user_op::OpArg(\"output\", 0), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"input\", 0))\n      .PartialSum(user_op::OpArg(\"output\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CastToStaticShapeOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"output\", 0, ctx->InputDType(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/cast_to_tick_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> CastToTickOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, Shape({1}));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> CastToTickOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CastToTickOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> CastToTickOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  const NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex(\"in\", 0);\n  const Shape& parallel_hierarchy = ctx->parallel_hierarchy();\n  CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), parallel_hierarchy.NumAxes());\n\n  NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  in_distribution->clear_sbp_parallel();\n  out_distribution->clear_sbp_parallel();\n  // in use hint\n  in_distribution->CopyFrom(in_dis_hint);\n\n  for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) {\n    // out dim1 = broadcast\n    out_distribution->add_sbp_parallel()->mutable_broadcast_parallel();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CastToTickOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/categorical_ordinal_encode_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> CategoricalOrdinalEncodeOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const Shape& table_shape = ctx->InputShape(\"table\", 0);\n  CHECK_EQ_OR_RETURN(table_shape.NumAxes(), 1);\n  CHECK_EQ_OR_RETURN(table_shape.elem_cnt() % 2, 0);\n  const Shape& size_shape = ctx->InputShape(\"size\", 0);\n  CHECK_EQ_OR_RETURN(size_shape.NumAxes(), 1);\n  CHECK_EQ_OR_RETURN(size_shape.elem_cnt(), 1);\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CategoricalOrdinalEncodeOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->parallel_ctx().parallel_num(), 1);\n  const Shape& table_shape = ctx->InputShape(\"table\", 0);\n  CHECK_EQ_OR_RETURN(table_shape.NumAxes(), 1);\n  CHECK_EQ_OR_RETURN(table_shape.elem_cnt() % 2, 0);\n  const Shape& size_shape = ctx->InputShape(\"size\", 0);\n  CHECK_EQ_OR_RETURN(size_shape.NumAxes(), 1);\n  CHECK_EQ_OR_RETURN(size_shape.elem_cnt(), 1);\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CategoricalOrdinalEncodeOp::GetSbp(user_op::SbpContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->parallel_num(), 1);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CategoricalOrdinalEncodeOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* table = GetInputArgModifierFn(\"table\", 0);\n  table->set_is_mutable(true);\n  table->set_requires_grad(false);\n  user_op::InputArgModifier* size = GetInputArgModifierFn(\"size\", 0);\n  size->set_is_mutable(true);\n  size->set_requires_grad(false);\n  user_op::InputArgModifier* in = GetInputArgModifierFn(\"in\", 0);\n  in->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CategoricalOrdinalEncodeOp::CheckAttr(\n    const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) {\n  CHECK_OR_RETURN(conf.attr<bool>(\"hash_precomputed\"));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CategoricalOrdinalEncodeOp::InferDataType(user_op::InferContext* ctx) {\n  DataType data_type = ctx->InputDType(\"in\", 0);\n  CHECK_OR_RETURN(IsIndexDataType(data_type));\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"table\", 0), data_type)\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"table\", 0))\n      << \", but got \" << DataType_Name(data_type);\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"size\", 0), data_type)\n      << \"InferDataType Failed. 
Expected \" << DataType_Name(ctx->InputDType(\"size\", 0))\n      << \", but got \" << DataType_Name(data_type);\n  ctx->SetOutputDType(\"out\", 0, data_type);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/celu_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> CeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> CeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CeluOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CeluOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& y_shape = ctx->InputShape(\"y\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == y_shape);\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> CeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CeluGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"y\", 0);\n  FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"y\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CeluGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"dy\", 0), ctx->InputDType(\"y\", 0))\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"dy\", 0))\n      << \", but got \" << DataType_Name(ctx->InputDType(\"y\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"y\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/clip_by_value_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferClipTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"y\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetClipSbpSignature(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Split(user_op::OpArg(\"y\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferClipGradTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetClipGradSbpSignature(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"x\", 0))\n      .PartialSum(user_op::OpArg(\"dy\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferClipTensorDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferClipGradDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n#define DEF_CLIP_BY_VALUE_OP(op_class_name_prefix)                                               \\\n  /* static */ Maybe<void> op_class_name_prefix##Op::InferLogicalTensorDesc(                     \\\n      user_op::InferContext* ctx) {                                                              \\\n    return InferClipTensorDesc(ctx);                                                             \\\n  }                                                                                              \\\n                                                                                                 \\\n  /*static*/ Maybe<void> op_class_name_prefix##Op::InferPhysicalTensorDesc(                      \\\n      user_op::InferContext* ctx) {                                                              \\\n    return InferLogicalTensorDesc(ctx);                                                          \\\n  }                                                                                              \\\n                                                                                                 \\\n  /* static */ Maybe<void> 
op_class_name_prefix##Op::GetSbp(user_op::SbpContext* ctx) {          \\\n    return GetClipSbpSignature(ctx);                                                             \\\n  }                                                                                              \\\n                                                                                                 \\\n  /* static */ Maybe<void> op_class_name_prefix##Op::InferDataType(user_op::InferContext* ctx) { \\\n    return InferClipTensorDataType(ctx);                                                         \\\n  }                                                                                              \\\n  /* static */ Maybe<void> op_class_name_prefix##GradOp::InferLogicalTensorDesc(                 \\\n      user_op::InferContext* ctx) {                                                              \\\n    return InferClipGradTensorDesc(ctx);                                                         \\\n  }                                                                                              \\\n  /*static*/ Maybe<void> op_class_name_prefix##GradOp::InferPhysicalTensorDesc(                  \\\n      user_op::InferContext* ctx) {                                                              \\\n    return InferLogicalTensorDesc(ctx);                                                          \\\n  }                                                                                              \\\n  /* static */ Maybe<void> op_class_name_prefix##GradOp::GetSbp(user_op::SbpContext* ctx) {      \\\n    return GetClipGradSbpSignature(ctx);                                                         \\\n  }                                                                                              \\\n  /* static */ Maybe<void> op_class_name_prefix##GradOp::InferDataType(                          \\\n      user_op::InferContext* ctx) {                                                              \\\n    return InferClipGradDataType(ctx);                                                           \\\n  }\n\nDEF_CLIP_BY_VALUE_OP(ClipByScalar)\nDEF_CLIP_BY_VALUE_OP(ClipByScalarMin)\nDEF_CLIP_BY_VALUE_OP(ClipByScalarMax)\n\n#undef DEF_CLIP_BY_VALUE_OP\n\n}  // namespace oneflow\n"
  },
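  {
    "path": "docs/sketches/op_boilerplate_macro_sketch.md",
    "content": "`clip_by_value_op.cpp` stamps out identical method definitions for `ClipByScalarOp`, `ClipByScalarMinOp`, and `ClipByScalarMaxOp` (plus their `GradOp` counterparts) with a single token-pasting macro, `DEF_CLIP_BY_VALUE_OP`. A minimal, self-contained sketch of the same technique; the `Foo`/`Bar` classes below are hypothetical, not OneFlow ops:\n\n```cpp\n#include <iostream>\n\n// Two op-like classes that share the same boilerplate definitions.\nstruct FooOp { static int Infer(); };\nstruct BarOp { static int Infer(); };\n\n// Token-pasting macro: one definition body stamped out per class prefix,\n// mirroring how DEF_CLIP_BY_VALUE_OP forwards to shared helper functions.\n#define DEF_SHARED_INFER(prefix, value) \\\n  int prefix##Op::Infer() { return (value); }\n\nDEF_SHARED_INFER(Foo, 1)\nDEF_SHARED_INFER(Bar, 2)\n\n#undef DEF_SHARED_INFER\n\nint main() { std::cout << FooOp::Infer() + BarOp::Infer() << std::endl; }\n```\n\nThe design choice is the usual X-macro trade-off: the boilerplate stays in one place and all stamped classes evolve together, at the cost of harder-to-step-through definitions.\n"
  },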
  {
    "path": "oneflow/user/ops/coco_reader_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> COCOReaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  int64_t batch_size = ctx->Attr<int64_t>(\"batch_size\");\n  user_op::TensorDesc* image_desc = ctx->MutOutputTensorDesc(\"image\", 0);\n  image_desc->set_shape(Shape({batch_size}));\n  user_op::TensorDesc* image_id_desc = ctx->MutOutputTensorDesc(\"image_id\", 0);\n  image_id_desc->set_shape(Shape({batch_size}));\n  user_op::TensorDesc* image_size_desc = ctx->MutOutputTensorDesc(\"image_size\", 0);\n  image_size_desc->set_shape(Shape({batch_size, 2}));\n  user_op::TensorDesc* bbox_desc = ctx->MutOutputTensorDesc(\"gt_bbox\", 0);\n  bbox_desc->set_shape(Shape({batch_size}));\n  user_op::TensorDesc* label_desc = ctx->MutOutputTensorDesc(\"gt_label\", 0);\n  label_desc->set_shape(Shape({batch_size}));\n  user_op::TensorDesc* segm_desc = ctx->MutOutputTensorDesc(\"gt_segm\", 0);\n  segm_desc->set_shape(Shape({batch_size}));\n  user_op::TensorDesc* segm_index_desc = ctx->MutOutputTensorDesc(\"gt_segm_index\", 0);\n  segm_index_desc->set_shape(Shape({batch_size}));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> COCOReaderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"image\", 0);\n  CHECK_OR_RETURN(nd_sbp == ctx->NdSbp4ArgNameAndIndex(\"image_id\", 0));\n  CHECK_OR_RETURN(nd_sbp == ctx->NdSbp4ArgNameAndIndex(\"image_size\", 0));\n  CHECK_OR_RETURN(nd_sbp == ctx->NdSbp4ArgNameAndIndex(\"gt_bbox\", 0));\n  CHECK_OR_RETURN(nd_sbp == ctx->NdSbp4ArgNameAndIndex(\"gt_label\", 0));\n  CHECK_OR_RETURN(nd_sbp == ctx->NdSbp4ArgNameAndIndex(\"gt_segm\", 0));\n  CHECK_OR_RETURN(nd_sbp == ctx->NdSbp4ArgNameAndIndex(\"gt_segm_index\", 0));\n\n  int64_t batch_size = ctx->Attr<int64_t>(\"batch_size\");\n  int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n  int64_t device_batch_size = batch_size;\n  if (parallel_num > 1) {\n    int64_t split_num = 1;\n    const Shape& hierarchy = *ctx->parallel_desc().hierarchy();\n    for (int32_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) {\n      if (nd_sbp.sbp_parallel(i).has_split_parallel()) { split_num *= hierarchy.At(i); }\n    }\n    CHECK_EQ_OR_RETURN(device_batch_size % split_num, 0);\n    device_batch_size /= split_num;\n  }\n\n  user_op::TensorDesc* image_desc = ctx->MutOutputTensorDesc(\"image\", 0);\n  image_desc->set_shape(Shape({device_batch_size}));\n  user_op::TensorDesc* image_id_desc = ctx->MutOutputTensorDesc(\"image_id\", 0);\n  image_id_desc->set_shape(Shape({device_batch_size}));\n  user_op::TensorDesc* image_size_desc = ctx->MutOutputTensorDesc(\"image_size\", 0);\n  image_size_desc->set_shape(Shape({device_batch_size, 2}));\n  user_op::TensorDesc* bbox_desc = ctx->MutOutputTensorDesc(\"gt_bbox\", 0);\n  
bbox_desc->set_shape(Shape({device_batch_size}));\n  user_op::TensorDesc* label_desc = ctx->MutOutputTensorDesc(\"gt_label\", 0);\n  label_desc->set_shape(Shape({device_batch_size}));\n  user_op::TensorDesc* segm_desc = ctx->MutOutputTensorDesc(\"gt_segm\", 0);\n  segm_desc->set_shape(Shape({device_batch_size}));\n  user_op::TensorDesc* segm_index_desc = ctx->MutOutputTensorDesc(\"gt_segm_index\", 0);\n  segm_index_desc->set_shape(Shape({device_batch_size}));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> COCOReaderOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Split(ctx->outputs(), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> COCOReaderOp::ModifyOutputArg(\n    const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::OutputArgModifier* image_modifier = GetOutputArgModifierFn(\"image\", 0);\n  CHECK_OR_RETURN(image_modifier != nullptr);\n  image_modifier->set_header_infered_before_compute(false);\n\n  user_op::OutputArgModifier* image_id_modifier = GetOutputArgModifierFn(\"image_id\", 0);\n  CHECK_OR_RETURN(image_id_modifier != nullptr);\n  image_id_modifier->set_header_infered_before_compute(false);\n\n  user_op::OutputArgModifier* image_size_modifier = GetOutputArgModifierFn(\"image_size\", 0);\n  CHECK_OR_RETURN(image_size_modifier != nullptr);\n  image_size_modifier->set_header_infered_before_compute(false);\n\n  user_op::OutputArgModifier* gt_bbox_modifier = GetOutputArgModifierFn(\"gt_bbox\", 0);\n  CHECK_OR_RETURN(gt_bbox_modifier != nullptr);\n  gt_bbox_modifier->set_header_infered_before_compute(false);\n\n  user_op::OutputArgModifier* gt_label_modifier = GetOutputArgModifierFn(\"gt_label\", 0);\n  CHECK_OR_RETURN(gt_label_modifier != nullptr);\n  gt_label_modifier->set_header_infered_before_compute(false);\n\n  user_op::OutputArgModifier* gt_segm_modifier = GetOutputArgModifierFn(\"gt_segm\", 0);\n  CHECK_OR_RETURN(gt_segm_modifier != nullptr);\n  gt_segm_modifier->set_header_infered_before_compute(false);\n\n  user_op::OutputArgModifier* gt_segm_index_modifier = GetOutputArgModifierFn(\"gt_segm_index\", 0);\n  CHECK_OR_RETURN(gt_segm_index_modifier != nullptr);\n  gt_segm_index_modifier->set_header_infered_before_compute(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> COCOReaderOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  SbpParallel default_sbp;\n  default_sbp.mutable_split_parallel()->set_axis(0);\n  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);\n}\n\n/* static */ Maybe<void> COCOReaderOp::InferDataType(user_op::InferContext* ctx) {\n  user_op::TensorDesc* image_desc = ctx->MutOutputTensorDesc(\"image\", 0);\n  image_desc->set_data_type(DataType::kTensorBuffer);\n  user_op::TensorDesc* image_id_desc = ctx->MutOutputTensorDesc(\"image_id\", 0);\n  image_id_desc->set_data_type(DataType::kInt64);\n  user_op::TensorDesc* image_size_desc = ctx->MutOutputTensorDesc(\"image_size\", 0);\n  image_size_desc->set_data_type(DataType::kInt32);\n  user_op::TensorDesc* bbox_desc = ctx->MutOutputTensorDesc(\"gt_bbox\", 0);\n  bbox_desc->set_data_type(DataType::kTensorBuffer);\n  user_op::TensorDesc* label_desc = ctx->MutOutputTensorDesc(\"gt_label\", 0);\n  label_desc->set_data_type(DataType::kTensorBuffer);\n  user_op::TensorDesc* segm_desc = ctx->MutOutputTensorDesc(\"gt_segm\", 0);\n  segm_desc->set_data_type(DataType::kTensorBuffer);\n  user_op::TensorDesc* segm_index_desc = ctx->MutOutputTensorDesc(\"gt_segm_index\", 0);\n  
segm_index_desc->set_data_type(DataType::kTensorBuffer);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
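  {
    "path": "docs/sketches/reader_device_batch_sketch.md",
    "content": "`COCOReaderOp::InferPhysicalTensorDesc` shrinks the logical batch by the product of the hierarchy dimensions whose SBP is split, and checks that the division is exact. A sketch of just that arithmetic with plain integers; `DeviceBatchSize` is an illustrative helper, not a OneFlow API:\n\n```cpp\n#include <cassert>\n#include <cstdint>\n#include <iostream>\n#include <vector>\n\n// device_batch_size = batch_size / prod(hierarchy dims carrying a split SBP),\n// mirroring the loop in COCOReaderOp::InferPhysicalTensorDesc.\nint64_t DeviceBatchSize(int64_t batch_size, const std::vector<int64_t>& hierarchy,\n                        const std::vector<bool>& is_split) {\n  int64_t split_num = 1;\n  for (size_t i = 0; i < hierarchy.size(); ++i) {\n    if (is_split[i]) { split_num *= hierarchy[i]; }\n  }\n  assert(batch_size % split_num == 0);  // the op CHECKs this and fails otherwise\n  return batch_size / split_num;\n}\n\nint main() {\n  // A (2, 4) hierarchy split on both axes: logical batch 64 -> 8 per device.\n  std::cout << DeviceBatchSize(64, {2, 4}, {true, true}) << std::endl;\n}\n```\n"
  },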
  {
    "path": "oneflow/user/ops/combined_margin_loss_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> CombinedMarginLossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& label = ctx->InputTensorDesc(\"label\", 0);\n  user_op::TensorDesc* theta = ctx->MutOutputTensorDesc(\"theta\", 0);\n  CHECK_EQ_OR_RETURN(label.shape().At(0), x.shape().At(0));\n  CHECK_GE_OR_RETURN(x.shape().NumAxes(), 2);\n  ctx->SetOutputShape(\"y\", 0, ctx->InputShape(\"x\", 0));\n  ctx->SetIsDynamic4ArgNameAndIndex(\"y\", 0, ctx->InputIsDynamic(\"x\", 0));\n  theta->set_is_dynamic(x.is_dynamic());\n  theta->set_shape(label.shape());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> CombinedMarginLossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CombinedMarginLossOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Split(user_op::OpArg(\"label\", 0), 0)\n      .Split(user_op::OpArg(\"y\", 0), 0)\n      .Split(user_op::OpArg(\"theta\", 0), 0)\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"x\", 0), 1)\n      .Broadcast(user_op::OpArg(\"label\", 0))\n      .Split(user_op::OpArg(\"y\", 0), 1)\n      .PartialSum(user_op::OpArg(\"theta\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CombinedMarginLossOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* label_arg_modifier = GetInputArgModifierFn(\"label\", 0);\n  label_arg_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CombinedMarginLossOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"theta\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CombinedMarginLossGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& label = ctx->InputTensorDesc(\"label\", 0);\n  const user_op::TensorDesc& theta = ctx->InputTensorDesc(\"theta\", 0);\n  CHECK_EQ_OR_RETURN(label.shape().At(0), dy.shape().At(0));\n  CHECK_EQ_OR_RETURN(label.shape().At(0), theta.shape().At(0));\n  CHECK_GE_OR_RETURN(dy.shape().NumAxes(), 2);\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"dy\", 0));\n  ctx->SetIsDynamic4ArgNameAndIndex(\"dx\", 0, ctx->InputIsDynamic(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> CombinedMarginLossGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ 
Maybe<void> CombinedMarginLossGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"label\", 0), 0)\n      .Split(user_op::OpArg(\"theta\", 0), 0)\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 1)\n      .Broadcast(user_op::OpArg(\"label\", 0))\n      .Broadcast(user_op::OpArg(\"theta\", 0))\n      .Split(user_op::OpArg(\"dx\", 0), 1)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CombinedMarginLossGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/comm_net_device_infer_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/ops/comm_net_device_infer_util.h\"\n#include \"oneflow/core/common/decorator.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<Symbol<Stream>> RawGetTransportDevice(Symbol<Device> device) {\n  return Stream::New(JUST(Device::New(device->type())), StreamType::kCcl);\n}\n\n}  // namespace\n\ndecltype(GetTransportDevice) GetTransportDevice = DECORATE(&RawGetTransportDevice, ThreadLocal);\n\nMaybe<Symbol<Device>> DefaultGetOutputDeivce(user_op::DeviceAndStreamInferContext* ctx) {\n  CHECK_GT_OR_RETURN(ctx->inputs().size(), 0);\n  return ctx->InputTensorDevice4ArgNameAndIndex(\"in\", 0);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/comm_net_device_infer_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_CORE_USER_OP_NCCL_DEVICE_INFER_UTIL_H_\n#define ONEFLOW_CORE_USER_OP_NCCL_DEVICE_INFER_UTIL_H_\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/common/env_var/eager.h\"\n#include \"oneflow/core/job/lazy_mode.h\"\n\nnamespace oneflow {\n\nextern Maybe<Symbol<Stream>> (*GetTransportDevice)(Symbol<Device>);\n\nMaybe<Symbol<Device>> DefaultGetOutputDeivce(user_op::DeviceAndStreamInferContext* ctx);\n\ntemplate<Maybe<Symbol<Device>> (*GetOutputDeivce)(user_op::DeviceAndStreamInferContext*) =\n             DefaultGetOutputDeivce>\nMaybe<Symbol<Stream>> DeviceAndStreamInferFn(user_op::DeviceAndStreamInferContext* ctx) {\n  Symbol<Device> output_device = JUST(GetOutputDeivce(ctx));\n  for (const auto& pair : ctx->outputs()) {\n    *ctx->OutputTensorDevice4ArgNameAndIndex(pair.first, pair.second) = output_device;\n  }\n  if (EagerNcclUseComputeStream() && !LazyMode::is_enabled()) {\n    return GetDefaultStreamByDevice(output_device);\n  }\n  return GetTransportDevice(output_device);\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_USER_OP_NCCL_DEVICE_INFER_UTIL_H_\n"
  },
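  {
    "path": "docs/sketches/device_infer_template_sketch.md",
    "content": "`comm_net_device_infer_util.h` parameterizes `DeviceAndStreamInferFn` by a function pointer with a default argument (`DefaultGetOutputDeivce`), so each comm-net op can swap in its own output-device policy at compile time. A self-contained sketch of that template pattern; all names below are hypothetical, not OneFlow APIs:\n\n```cpp\n#include <iostream>\n#include <string>\n\n// Default policy, playing the role of DefaultGetOutputDeivce.\nstd::string DefaultPickDevice() { return \"cpu\"; }\n\n// Alternative policy an op can plug in instead.\nstd::string AlwaysCuda() { return \"cuda\"; }\n\n// Non-type template parameter of function-pointer type with a default,\n// mirroring the shape of DeviceAndStreamInferFn.\ntemplate<std::string (*PickDevice)() = DefaultPickDevice>\nstd::string InferStreamDevice() {\n  return PickDevice();\n}\n\nint main() {\n  std::cout << InferStreamDevice() << std::endl;              // default policy -> cpu\n  std::cout << InferStreamDevice<AlwaysCuda>() << std::endl;  // swapped policy -> cuda\n}\n```\n\nBecause the policy is a compile-time parameter, each instantiation is a distinct function with the policy call inlined, with no per-call indirection.\n"
  },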
  {
    "path": "oneflow/user/ops/complex_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <map>\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nstatic std::map<DataType, DataType> complex_to_real_map{{DataType::kComplex32, DataType::kFloat16},\n                                                        {DataType::kComplex64, DataType::kFloat},\n                                                        {DataType::kComplex128, DataType::kDouble}};\nstatic std::map<DataType, DataType> real_to_complex_map{{DataType::kFloat16, DataType::kComplex32},\n                                                        {DataType::kFloat, DataType::kComplex64},\n                                                        {DataType::kDouble, DataType::kComplex128}};\n\n/*static*/ Maybe<void> RealOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);\n}\n/*static*/ Maybe<void> RealOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return user_op::TensorDescInferFnUtil::Unchanged(ctx);\n}\n/*static*/ Maybe<void> RealOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> RealOp::InferDataType(user_op::InferContext* ctx) {\n  const std::pair<std::string, int32_t>& input_arg = ctx->inputs().at(0);\n  const user_op::TensorDesc& tensor_desc = ctx->InputTensorDesc(input_arg.first, input_arg.second);\n  const std::pair<std::string, int32_t>& output_arg = ctx->outputs().at(0);\n  ctx->SetOutputDType(output_arg.first, output_arg.second,\n                      complex_to_real_map[tensor_desc.data_type()]);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> RealGradOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);\n}\n/*static*/ Maybe<void> RealGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return user_op::TensorDescInferFnUtil::Unchanged(ctx);\n}\n/*static*/ Maybe<void> RealGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> RealGradOp::InferDataType(user_op::InferContext* ctx) {\n  const std::pair<std::string, int32_t>& input_arg = ctx->inputs().at(0);\n  const user_op::TensorDesc& tensor_desc = ctx->InputTensorDesc(input_arg.first, input_arg.second);\n  const std::pair<std::string, int32_t>& output_arg = ctx->outputs().at(0);\n  ctx->SetOutputDType(output_arg.first, output_arg.second,\n                      real_to_complex_map[tensor_desc.data_type()]);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ImagOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);\n}\n/*static*/ Maybe<void> ImagOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return user_op::TensorDescInferFnUtil::Unchanged(ctx);\n}\n/*static*/ Maybe<void> ImagOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return 
InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ImagOp::InferDataType(user_op::InferContext* ctx) {\n  const std::pair<std::string, int32_t>& input_arg = ctx->inputs().at(0);\n  const user_op::TensorDesc& tensor_desc = ctx->InputTensorDesc(input_arg.first, input_arg.second);\n  const std::pair<std::string, int32_t>& output_arg = ctx->outputs().at(0);\n  ctx->SetOutputDType(output_arg.first, output_arg.second,\n                      complex_to_real_map[tensor_desc.data_type()]);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ImagGradOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);\n}\n/*static*/ Maybe<void> ImagGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return user_op::TensorDescInferFnUtil::Unchanged(ctx);\n}\n/*static*/ Maybe<void> ImagGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ImagGradOp::InferDataType(user_op::InferContext* ctx) {\n  const std::pair<std::string, int32_t>& input_arg = ctx->inputs().at(0);\n  const user_op::TensorDesc& tensor_desc = ctx->InputTensorDesc(input_arg.first, input_arg.second);\n  const std::pair<std::string, int32_t>& output_arg = ctx->outputs().at(0);\n  ctx->SetOutputDType(output_arg.first, output_arg.second,\n                      real_to_complex_map[tensor_desc.data_type()]);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ConjPhysicalOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);\n}\n/*static*/ Maybe<void> ConjPhysicalOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return user_op::TensorDescInferFnUtil::Unchanged(ctx);\n}\n/*static*/ Maybe<void> ConjPhysicalOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ConjPhysicalOp::InferDataType(user_op::InferContext* ctx) {\n  return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/concat_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> ConcatOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc(\"in\", 0);\n  const int64_t axis = ctx->Attr<int64_t>(\"axis\");\n  CHECK_GE_OR_RETURN(axis, 0);\n  CHECK_LT_OR_RETURN(axis, first_in_desc.shape().NumAxes());\n  DimVector out_dim_vec = first_in_desc.shape().dim_vec();\n  out_dim_vec.at(axis) = 0;\n  int64_t first_axes = first_in_desc.shape().NumAxes();\n  int64_t first_elemcnt = first_in_desc.shape().elem_cnt();\n  int64_t dynamic_dim_size = 0;\n  for (const auto& in_arg_pair : ctx->inputs()) {\n    const user_op::TensorDesc& in_desc =\n        ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second);\n    if (first_elemcnt == 0 and first_axes == 1) {\n      if (in_desc.shape().elem_cnt() != 0 or in_desc.shape().NumAxes() != 1) {\n        out_dim_vec = in_desc.shape().dim_vec();\n        out_dim_vec.at(axis) = 0;\n        first_axes = in_desc.shape().NumAxes();\n        first_elemcnt = in_desc.shape().elem_cnt();\n      } else {\n        continue;\n      }\n    } else if (in_desc.shape().elem_cnt() != 0 or in_desc.shape().NumAxes() != 1) {\n      CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), first_axes);\n    }\n    FOR_RANGE(int64_t, i, 0, in_desc.shape().NumAxes()) {\n      if (in_desc.shape().elem_cnt() == 0 and in_desc.shape().NumAxes() == 1) { continue; }\n      if (i == axis) {\n        if (in_desc.is_dynamic()) {\n          dynamic_dim_size += in_desc.shape().At(i);\n        } else {\n          out_dim_vec.at(axis) += in_desc.shape().At(i);\n        }\n      } else {\n        CHECK_EQ_OR_RETURN(in_desc.shape().At(i), out_dim_vec.at(i));\n      }\n    }\n  }\n\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  const int64_t max_dim_size = ctx->Attr<int64_t>(\"max_dim_size\");\n  CHECK_LE_OR_RETURN(out_dim_vec.at(axis), max_dim_size);\n  if (dynamic_dim_size == 0) {\n    out_desc->set_is_dynamic(false);\n  } else {\n    out_desc->set_is_dynamic(true);\n    out_dim_vec.at(axis) = max_dim_size;\n  }\n  out_desc->set_shape(Shape(out_dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ConcatOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ConcatOp::GetSbp(user_op::SbpContext* ctx) {\n  const int64_t axis = ctx->Attr<int64_t>(\"axis\");\n  const user_op::TensorDesc& first_in_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, first_in_desc.shape().NumAxes()) {\n    if (i == axis) { continue; }\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return 
Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ConcatOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc(\"in\", 0);\n  for (const auto& in_arg_pair : ctx->inputs()) {\n    const user_op::TensorDesc& in_desc =\n        ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second);\n    CHECK_EQ_OR_RETURN(in_desc.data_type(), first_in_desc.data_type())\n        << \"InferDataType Failed. Expected \" << DataType_Name(in_desc.data_type()) << \", but got \"\n        << DataType_Name(first_in_desc.data_type());\n  }\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_data_type(first_in_desc.data_type());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ConcatOp::CheckAttr(const user_op::UserOpDefWrapper&,\n                                           const user_op::UserOpConfWrapper& op_conf) {\n  CHECK_OR_RETURN(op_conf.input_size(\"in\") >= 2);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/constant_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> ConstantOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, Shape(ctx->Attr<Shape>(\"shape\").dim_vec()));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ConstantOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();\n  const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  const Shape& logical_shape = ctx->Attr<Shape>(\"shape\");\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n  const auto tensor_slice_view =\n      GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id);\n  const Shape& physical_shape = tensor_slice_view.shape();\n\n  ctx->SetOutputShape(\"out\", 0, physical_shape);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ConstantOp::GetSbp(user_op::SbpContext* ctx) { return Maybe<void>::Ok(); }\n\n/* static */ Maybe<void> ConstantOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  SbpParallel default_sbp;\n  default_sbp.mutable_broadcast_parallel();\n  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);\n}\n\n/* static */ Maybe<void> ConstantOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->Attr<DataType>(\"dtype\"));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
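  {
    "path": "docs/sketches/physical_shape_split_sketch.md",
    "content": "`ConstantOp::InferPhysicalTensorDesc` derives each rank's physical shape by slicing the logical shape with `GetTensorSliceView4ParallelId` under the output NdSbp. A simplified sketch of the even-split case, one hierarchy axis with `S(axis)`; `PhysicalShapeForSplit` is an illustrative helper, not the OneFlow function:\n\n```cpp\n#include <cassert>\n#include <cstdint>\n#include <iostream>\n#include <vector>\n\n// Even split of a logical shape along `axis` across `parallel_num` ranks;\n// the real slicing also handles uneven remainders and multi-axis hierarchies.\nstd::vector<int64_t> PhysicalShapeForSplit(std::vector<int64_t> logical, int axis,\n                                           int64_t parallel_num) {\n  assert(logical[axis] % parallel_num == 0);\n  logical[axis] /= parallel_num;\n  return logical;\n}\n\nint main() {\n  // Logical shape (8, 4) split S(0) over 4 ranks -> every rank holds (2, 4).\n  auto phys = PhysicalShapeForSplit({8, 4}, /*axis=*/0, /*parallel_num=*/4);\n  std::cout << phys[0] << \"x\" << phys[1] << std::endl;\n}\n```\n"
  },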
  {
    "path": "oneflow/user/ops/conv_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<size_t NDims>\nMaybe<void> InferTensorDesc4Conv(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_EQ_OR_RETURN(NDims + 2, in.shape().NumAxes())\n      << \"Conv\" << NDims << \"D op's input shape ndim should equal to \" << NDims + 2\n      << \" ,but got: \" << in.shape().NumAxes();\n\n  auto data_format = ctx->Attr<std::string>(\"data_format\");\n  auto kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n  CHECK_EQ_OR_RETURN(NDims, kernel_size.size());\n  int32_t filters = ctx->Attr<int32_t>(\"filters\");\n  size_t idx_offset = IdxOffset(data_format);\n  {\n    const auto& padding_before = ctx->Attr<std::vector<int32_t>>(\"padding_before\");\n    auto dilation_rate = ctx->Attr<std::vector<int32_t>>(\"dilation_rate\");\n    auto strides = ctx->Attr<std::vector<int32_t>>(\"strides\");\n    CHECK_EQ_OR_RETURN(NDims, dilation_rate.size());\n    CHECK_EQ_OR_RETURN(NDims, strides.size());\n    CHECK_EQ_OR_RETURN(NDims, padding_before.size());\n\n    user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n    DimVector out_shape(NDims + 2);\n    out_shape.at(0) = in.shape().At(0);\n    const size_t c_dim = data_format == \"channels_first\" ? 
1 : NDims + 1;\n    out_shape.at(c_dim) = filters;\n    for (int32_t i = 0; i < NDims; ++i) {\n      JUST(CalcConvOut(in.shape().At(idx_offset + i), kernel_size.at(i), dilation_rate.at(i),\n                       strides.at(i), padding_before.at(i), &out_shape.at(idx_offset + i)));\n    }\n    out->set_is_dynamic(in.is_dynamic());\n    out->set_shape(Shape(out_shape));\n  }\n\n  {\n    int32_t groups = ctx->Attr<int32_t>(\"groups\");\n    CHECK_GT_OR_RETURN(groups, 0);\n    CHECK_LE_OR_RETURN(groups, filters);\n    CHECK_EQ_OR_RETURN(filters % groups, 0);\n\n    DimVector weight_shape(in.shape().dim_vec());\n    weight_shape.at(0) = filters;\n    if (data_format == \"channels_first\") {\n      CHECK_LE_OR_RETURN(groups, weight_shape.at(1));\n      CHECK_EQ_OR_RETURN(weight_shape.at(1) % groups, 0);\n      weight_shape.at(1) = weight_shape.at(1) / groups;\n    } else if (data_format == \"channels_last\") {\n      CHECK_LE_OR_RETURN(groups, weight_shape.at(NDims + 1));\n      CHECK_EQ_OR_RETURN(weight_shape.at(NDims + 1) % groups, 0);\n      weight_shape.at(NDims + 1) = weight_shape.at(NDims + 1) / groups;\n    } else {\n      UNIMPLEMENTED_THEN_RETURN();\n    }\n    for (size_t i = 0; i < NDims; ++i) { weight_shape.at(idx_offset + i) = kernel_size.at(i); }\n\n    const user_op::TensorDesc& weight = ctx->InputTensorDesc(\"weight\", 0);\n    CHECK_EQ_OR_RETURN(weight.shape(), Shape(weight_shape));\n  }\n\n  bool has_bias = ctx->has_input(\"bias\", 0);\n  if (has_bias) {\n    const user_op::TensorDesc& bias = ctx->InputTensorDesc(\"bias\", 0);\n    CHECK_EQ_OR_RETURN(bias.shape(), Shape({filters}));\n  }\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetSbpSignatures4Conv(user_op::SbpContext* ctx) {\n  bool has_bias = false;\n  for (const auto& pair : ctx->inputs()) {\n    if (pair.first == \"bias\") {\n      CHECK_EQ_OR_RETURN(0, pair.second);\n      has_bias = true;\n      break;\n    }\n  }\n\n  if (has_bias) {\n    ctx->NewBuilder()\n        .Split(ctx->inputs(), 0)\n        .Split(user_op::OpArg(\"in\", 0), 0)\n        .Broadcast(user_op::OpArg(\"weight\", 0))\n        .Broadcast(user_op::OpArg(\"bias\", 0))\n        .Split(user_op::OpArg(\"out\", 0), 0)\n        .Build();\n  } else {\n    ctx->NewBuilder()\n        .Split(ctx->inputs(), 0)\n        .Split(user_op::OpArg(\"in\", 0), 0)\n        .Broadcast(user_op::OpArg(\"weight\", 0))\n        .Split(user_op::OpArg(\"out\", 0), 0)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/*\nExample for conv2d:\n\nComputationCost\n= ((k*k + k*k-1)*c + c-1 + bias?1:0) * out_channel * out_width * out_height * batch_size\n= (2*k*k*c - 1 + bias?1:0) * out_channel * out_width * out_height * batch_size\n≈ 2*k*k*c * out_channel * out_width * out_height * batch_size\n*/\nMaybe<double> ConvComputationCost(user_op::ComputeComplexityFnContext* ctx) {\n  const std::vector<int32_t> kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n  const std::string data_format = ctx->Attr<std::string>(\"data_format\");\n  const user_op::TensorDesc* in = ctx->TensorDesc4ArgNameAndIndex(\"in\", 0);\n  const size_t c_dim = data_format == \"channels_first\" ? 
1 : in->shape().NumAxes() - 1;\n  const int32_t c = in->shape().At(c_dim);\n  const user_op::TensorDesc* out = ctx->TensorDesc4ArgNameAndIndex(\"out\", 0);\n  double cost =\n      std::accumulate(kernel_size.begin(), kernel_size.end(), 1.0, std::multiplies<double>());\n  cost = cost * 2 * c;\n  cost *= std::accumulate(out->shape().dim_vec().begin(), out->shape().dim_vec().end(), 1.0,\n                          std::multiplies<double>());\n\n  const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy();\n  const auto& nd_sbp_out = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  for (int32_t dim_sbp = 0; dim_sbp < nd_sbp_out.sbp_parallel_size(); dim_sbp++) {\n    if (nd_sbp_out.sbp_parallel(dim_sbp).has_split_parallel()) {\n      cost /= parallel_hierarchy->At(dim_sbp);\n    }\n  }\n  return cost;\n}\n\ntemplate<size_t NDims>\nMaybe<void> CheckAttr_(const user_op::UserOpDefWrapper& def,\n                       const user_op::UserOpConfWrapper& conf) {\n  bool is_checked = true;\n  std::stringstream err;\n  err << \"Illegal value for \" << conf.op_type_name() << \" op \" << conf.op_name() << \": \";\n\n  const auto& data_format = conf.attr<std::string>(\"data_format\");\n  if (!(data_format == \"channels_first\" || data_format == \"channels_last\")) {\n    err << \" data_format:\" << data_format;\n    is_checked = false;\n  }\n\n  if (NDims != 0) {\n    const auto& padding_before = conf.attr<std::vector<int32_t>>(\"padding_before\");\n    if (padding_before.size() != NDims) {\n      err << \" padding_before: number of element is \" << padding_before.size();\n      is_checked = false;\n    }\n\n    const auto& kernel_size = conf.attr<std::vector<int32_t>>(\"kernel_size\");\n    if (kernel_size.size() != NDims) {\n      err << \" kernel_size: number of element is \" << kernel_size.size();\n      is_checked = false;\n    }\n\n    const auto& strides = conf.attr<std::vector<int32_t>>(\"strides\");\n    if (strides.size() != NDims) {\n      err << \" strides: number of element is \" << strides.size();\n      is_checked = false;\n    }\n\n    const auto& dilation_rate = conf.attr<std::vector<int32_t>>(\"dilation_rate\");\n    if (dilation_rate.size() != NDims) {\n      err << \" dilation_rate: number of element is \" << dilation_rate.size();\n      is_checked = false;\n    }\n  }\n\n  if (is_checked) {\n    return Maybe<void>::Ok();\n  } else {\n    return oneflow::Error::CheckFailedError() << err.str();\n  }\n}\n\n}  // namespace\n\n/* static */ Maybe<void> Conv1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc4Conv<1>(ctx);\n}\n\n/*static*/ Maybe<void> Conv1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> Conv1DOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetSbpSignatures4Conv(ctx);\n}\n\n/* static */ Maybe<double> Conv1DOp::GetComputeComplexity(\n    user_op::ComputeComplexityFnContext* ctx) {\n  return ConvComputationCost(ctx);\n}\n\n/* static */ Maybe<void> Conv1DOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                             const user_op::UserOpConfWrapper& conf) {\n  return CheckAttr_<1>(def, conf);\n}\n\n/* static */ Maybe<void> Conv1DOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> Conv2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc4Conv<2>(ctx);\n}\n\n/*static*/ Maybe<void> 
Conv2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> Conv2DOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetSbpSignatures4Conv(ctx);\n}\n\n/* static */ Maybe<double> Conv2DOp::GetComputeComplexity(\n    user_op::ComputeComplexityFnContext* ctx) {\n  return ConvComputationCost(ctx);\n}\n\n/* static */ Maybe<void> Conv2DOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                             const user_op::UserOpConfWrapper& conf) {\n  return CheckAttr_<2>(def, conf);\n}\n\n/* static */ Maybe<void> Conv2DOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> Conv3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc4Conv<3>(ctx);\n}\n\n/*static*/ Maybe<void> Conv3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> Conv3DOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetSbpSignatures4Conv(ctx);\n}\n\n/* static */ Maybe<double> Conv3DOp::GetComputeComplexity(\n    user_op::ComputeComplexityFnContext* ctx) {\n  return ConvComputationCost(ctx);\n}\n\n/* static */ Maybe<void> Conv3DOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                             const user_op::UserOpConfWrapper& conf) {\n  return CheckAttr_<3>(def, conf);\n}\n\n/* static */ Maybe<void> Conv3DOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ConvDataGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& x_like = ctx->InputTensorDesc(\"x_like\", 0);\n  const int32_t num_spatial_dims = ctx->Attr<int32_t>(\"num_spatial_dims\");\n  CHECK_GE_OR_RETURN(num_spatial_dims, 1);\n  CHECK_LE_OR_RETURN(num_spatial_dims, 3);\n  CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2);\n  CHECK_EQ_OR_RETURN(x_like.shape().NumAxes(), num_spatial_dims + 2);\n  if (ctx->has_input(\"_add_to_output\", 0)) {\n    const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc(\"_add_to_output\", 0);\n    CHECK_EQ_OR_RETURN(add_to_output.shape(), x_like.shape());\n  }\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"x_like\", 0));\n  ctx->SetOutputIsDynamic(\"dx\", 0, ctx->InputIsDynamic(\"x_like\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ConvDataGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ConvDataGradOp::GetSbp(user_op::SbpContext* ctx) {\n  std::vector<user_op::OpArg> split_args;\n  split_args.emplace_back(\"dy\", 0);\n  split_args.emplace_back(\"x_like\", 0);\n  split_args.emplace_back(\"dx\", 0);\n  if (ctx->user_op_conf().has_input(\"_add_to_output\", 0)) {\n    split_args.emplace_back(\"_add_to_output\", 0);\n  }\n  ctx->NewBuilder().Split(split_args, 0).Broadcast(user_op::OpArg(\"filter\", 0)).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ConvDataGradOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                                   const user_op::UserOpConfWrapper& conf) {\n  return CheckAttr_<0>(def, conf);\n}\n\n/* static */ Maybe<void> ConvDataGradOp::InferDataType(user_op::InferContext* ctx) {\n  
const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& x_like = ctx->InputTensorDesc(\"x_like\", 0);\n  CHECK_EQ_OR_RETURN(x_like.data_type(), dy.data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(dy.data_type()) << \", but got \"\n      << DataType_Name(x_like.data_type());\n  if (ctx->has_input(\"_add_to_output\", 0)) {\n    const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc(\"_add_to_output\", 0);\n    CHECK_EQ_OR_RETURN(add_to_output.data_type(), x_like.data_type())\n        << \"InferDataType Failed. Expected \" << DataType_Name(add_to_output.data_type())\n        << \", but got \" << DataType_Name(x_like.data_type());\n  }\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x_like\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<double> ConvDataGradOp::GetComputeComplexity(\n    user_op::ComputeComplexityFnContext* ctx) {\n  const std::vector<int32_t> kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n  const user_op::TensorDesc* dx = ctx->TensorDesc4ArgNameAndIndex(\"dx\", 0);\n  const user_op::TensorDesc* dy = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n  const size_t c_dim =\n      ctx->Attr<std::string>(\"data_format\") == \"channels_first\" ? 1 : dy->shape().NumAxes() - 1;\n\n  double cost =\n      std::accumulate(kernel_size.begin(), kernel_size.end(), 1.0, std::multiplies<double>())\n      * std::accumulate(dx->shape().dim_vec().begin(), dx->shape().dim_vec().end(), 1.0,\n                        std::multiplies<double>())\n      * 2.0 * dy->shape().At(c_dim);\n\n  const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"dx\", 0);\n  const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy();\n  for (int32_t dim_sbp = 0; dim_sbp < nd_sbp.sbp_parallel_size(); dim_sbp++) {\n    if (nd_sbp.sbp_parallel(dim_sbp).has_split_parallel()) {\n      cost /= parallel_hierarchy->At(dim_sbp);\n    }\n  }\n  return cost;\n}\n\n/* static */ Maybe<void> ConvFilterGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n\n  const int32_t num_spatial_dims = ctx->Attr<int32_t>(\"num_spatial_dims\");\n  const int32_t groups = ctx->Attr<int32_t>(\"groups\");\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  const std::vector<int32_t> kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n\n  CHECK_GE_OR_RETURN(num_spatial_dims, 1);\n  CHECK_LE_OR_RETURN(num_spatial_dims, 3);\n  CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2);\n  CHECK_EQ_OR_RETURN(x.shape().NumAxes(), num_spatial_dims + 2);\n  CHECK_GT_OR_RETURN(groups, 0);\n\n  DimVector filter_diff_dim_vec;\n  if (data_format == \"channels_first\") {\n    CHECK_LE_OR_RETURN(groups, x.shape().At(1));\n    CHECK_LE_OR_RETURN(groups, dy.shape().At(1));\n    CHECK_EQ_OR_RETURN(x.shape().At(1) % groups, 0);\n    CHECK_EQ_OR_RETURN(dy.shape().At(1) % groups, 0);\n    filter_diff_dim_vec.emplace_back(dy.shape().At(1));\n    filter_diff_dim_vec.emplace_back(x.shape().At(1) / groups);\n    filter_diff_dim_vec.insert(filter_diff_dim_vec.end(), kernel_size.cbegin(), kernel_size.cend());\n  } else {\n    CHECK_EQ_OR_RETURN(\"channels_last\", data_format);\n    CHECK_EQ_OR_RETURN(groups, 1);\n    filter_diff_dim_vec.emplace_back(dy.shape().dim_vec().back());\n    filter_diff_dim_vec.insert(filter_diff_dim_vec.end(), kernel_size.cbegin(), kernel_size.cend());\n    
filter_diff_dim_vec.emplace_back(x.shape().dim_vec().back() / groups);\n  }\n\n  user_op::TensorDesc* filter_diff = ctx->MutOutputTensorDesc(\"filter_diff\", 0);\n  filter_diff->set_shape(Shape(filter_diff_dim_vec));\n  filter_diff->set_is_dynamic(false);\n\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ConvFilterGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ConvFilterGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .PartialSum(user_op::OpArg(\"filter_diff\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ConvFilterGradOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                                     const user_op::UserOpConfWrapper& conf) {\n  return CheckAttr_<0>(def, conf);\n}\n\n/* static */ Maybe<void> ConvFilterGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  CHECK_EQ_OR_RETURN(x.data_type(), dy.data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(dy.data_type()) << \", but got \"\n      << DataType_Name(x.data_type());\n  user_op::TensorDesc* filter_diff = ctx->MutOutputTensorDesc(\"filter_diff\", 0);\n  filter_diff->set_data_type(x.data_type());\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<double> ConvFilterGradOp::GetComputeComplexity(\n    user_op::ComputeComplexityFnContext* ctx) {\n  const std::vector<int32_t> kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n  const user_op::TensorDesc* dy = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n  const user_op::TensorDesc* x = ctx->TensorDesc4ArgNameAndIndex(\"x\", 0);\n  const size_t c_dim =\n      ctx->Attr<std::string>(\"data_format\") == \"channels_first\" ? 
1 : x->shape().NumAxes() - 1;\n\n  double cost =\n      std::accumulate(kernel_size.begin(), kernel_size.end(), 1.0, std::multiplies<double>())\n      * std::accumulate(dy->shape().dim_vec().begin(), dy->shape().dim_vec().end(), 1.0,\n                        std::multiplies<double>())\n      * 2.0 * x->shape().At(c_dim);\n\n  const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"dy\", 0);\n  const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy();\n  for (int32_t dim_sbp = 0; dim_sbp < nd_sbp.sbp_parallel_size(); dim_sbp++) {\n    if (nd_sbp.sbp_parallel(dim_sbp).has_split_parallel()) {\n      cost /= parallel_hierarchy->At(dim_sbp);\n    }\n  }\n  return cost;\n}\n\n/* static */ Maybe<void> ConvBiasGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  user_op::TensorDesc* bias_diff = ctx->MutOutputTensorDesc(\"bias_diff\", 0);\n\n  int32_t num_spatial_dims = ctx->Attr<int32_t>(\"num_spatial_dims\");\n  std::string data_format = ctx->Attr<std::string>(\"data_format\");\n\n  CHECK_GE_OR_RETURN(num_spatial_dims, 1);\n  CHECK_LE_OR_RETURN(num_spatial_dims, 3);\n  CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2);\n  if (data_format == \"channels_first\") {\n    bias_diff->set_shape(Shape({dy.shape().At(1)}));\n  } else if (data_format == \"channels_last\") {\n    bias_diff->set_shape(Shape({dy.shape().At(dy.shape().NumAxes() - 1)}));\n  } else {\n    OF_UNIMPLEMENTED();\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ConvBiasGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ConvBiasGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .PartialSum(user_op::OpArg(\"bias_diff\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ConvBiasGradOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                                   const user_op::UserOpConfWrapper& conf) {\n  std::string data_format = conf.attr<std::string>(\"data_format\");\n  if (data_format == \"channels_first\" || data_format == \"channels_last\") {\n    return Maybe<void>::Ok();\n  }\n  return oneflow::Error::CheckFailedError() << \"Illegal value for \" << conf.op_type_name() << \" op \"\n                                            << conf.op_name() << \": data_format:\" << data_format;\n}\n\n/* static */ Maybe<void> ConvBiasGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  user_op::TensorDesc* bias_diff = ctx->MutOutputTensorDesc(\"bias_diff\", 0);\n  bias_diff->set_data_type(dy.data_type());\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<double> ConvBiasGradOp::GetComputeComplexity(\n    user_op::ComputeComplexityFnContext* ctx) {\n  const user_op::TensorDesc* dy = ctx->TensorDesc4ArgNameAndIndex(\"dy\", 0);\n  const std::string data_format = ctx->Attr<std::string>(\"data_format\");\n  double cost = std::accumulate(dy->shape().dim_vec().begin(), dy->shape().dim_vec().end(), 1.0,\n                                std::multiplies<double>());\n  const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"dy\", 0);\n  const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy();\n  for (int32_t dim_sbp = 0; dim_sbp < nd_sbp.sbp_parallel_size(); dim_sbp++) {\n    if (nd_sbp.sbp_parallel(dim_sbp).has_split_parallel()) {\n      cost /= 
parallel_hierarchy->At(dim_sbp);\n    }\n  }\n  return cost;\n}\n\n}  // namespace oneflow\n"
  },
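The `GetComputeComplexity` implementations in `conv_op.cpp` above all follow one cost model: roughly 2 × (kernel volume) × (elements of the tensor being produced) × (channel count of the reduced operand) FLOPs, then divided by the hierarchy extent of every split axis in the output's NdSbp. A minimal standalone sketch of that model, assuming illustrative names (`ConvCost` and `split_dims` are not OneFlow APIs):

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Sketch of the cost model used by the conv GetComputeComplexity functions:
// every produced element is a dot product over (channels * kernel volume)
// inputs, counted as one multiply plus one add (the factor 2.0). `split_dims`
// stands in for the hierarchy extents of each sbp dim with split parallelism.
double ConvCost(const std::vector<int64_t>& produced_shape, int64_t reduced_channels,
                const std::vector<int32_t>& kernel_size,
                const std::vector<int64_t>& split_dims) {
  double cost = 2.0 * static_cast<double>(reduced_channels)
                * std::accumulate(kernel_size.begin(), kernel_size.end(), 1.0,
                                  std::multiplies<double>())
                * std::accumulate(produced_shape.begin(), produced_shape.end(), 1.0,
                                  std::multiplies<double>());
  // Work is shared evenly across every split dimension of the placement.
  for (int64_t n : split_dims) { cost /= static_cast<double>(n); }
  return cost;
}
```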
  {
    "path": "oneflow/user/ops/convert_memory_format_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/ops/convert_memory_format_op.h\"\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nstatic Shape ComputeShapeIdentity(const Shape& shape) { return shape; }\n\nShape ComputeShapeContiguousToChannelsLast(const Shape& shape) {\n  int ndim = shape.size();\n  if (ndim <= 2) { return ComputeShapeIdentity(shape); }\n  Shape target_shape(ndim);\n  target_shape[0] = shape[0];\n  target_shape[ndim - 1] = shape[1];\n  for (int i = 0; i < ndim - 2; ++i) { target_shape[i + 1] = shape[i + 2]; }\n  return target_shape;\n}\n\nShape ComputeShapeChannelsLastToContiguous(const Shape& shape) {\n  int ndim = shape.size();\n  if (ndim <= 2) { return ComputeShapeIdentity(shape); }\n  Shape target_shape(ndim);\n  target_shape[0] = shape[0];\n  target_shape[1] = shape[ndim - 1];\n  for (int i = 0; i < ndim - 2; ++i) { target_shape[i + 2] = shape[i + 1]; }\n  return target_shape;\n}\n\nstatic Maybe<void> GetSbpIdentity(user_op::SbpContext* ctx, const Shape& shape) {\n  for (int32_t i = 0; i < shape.size(); ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nstatic Maybe<void> GetSbpContiguousToChannelsLast(user_op::SbpContext* ctx, const Shape& shape) {\n  int ndim = shape.size();\n  if (ndim <= 2) { return GetSbpIdentity(ctx, shape); }\n  ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();\n  ctx->NewBuilder().Split(ctx->inputs(), 1).Split(ctx->outputs(), ndim - 1).Build();\n  for (int32_t i = 0; i < ndim - 2; ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i + 2).Split(ctx->outputs(), i + 1).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nstatic Maybe<void> GetSbpChannelsLastToContiguous(user_op::SbpContext* ctx, const Shape& shape) {\n  int ndim = shape.size();\n  if (ndim <= 2) { return GetSbpIdentity(ctx, shape); }\n  ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();\n  ctx->NewBuilder().Split(ctx->inputs(), ndim - 1).Split(ctx->outputs(), 1).Build();\n  for (int32_t i = 0; i < ndim - 2; ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i + 1).Split(ctx->outputs(), i + 2).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nusing ComputeShapeFunc = std::function<Shape(const Shape&)>;\nusing GetSbpFunc = std::function<Maybe<void>(user_op::SbpContext* ctx, const Shape& shape)>;\n\nstatic ComputeShapeFunc compute_shape_funcs[kMemoryFormatCount][kMemoryFormatCount] = {\n    /*kContiguous->other*/ {ComputeShapeIdentity, ComputeShapeContiguousToChannelsLast},\n    /*kChannelsLast->other*/ {ComputeShapeChannelsLastToContiguous, ComputeShapeIdentity},\n};\n\nstatic GetSbpFunc get_sbp_funcs[kMemoryFormatCount][kMemoryFormatCount] = {\n    /*kContiguous->other*/ {GetSbpIdentity, GetSbpContiguousToChannelsLast},\n    /*kChannelsLast->other*/ {GetSbpChannelsLastToContiguous, 
GetSbpIdentity},\n};\n\nShape ComputeConvertMemoryFormatShape(const Shape& shape, MemoryFormat memory_format,\n                                      MemoryFormat target_memory_format) {\n  auto shape_func = compute_shape_funcs[memory_format][target_memory_format];\n  return shape_func(shape);\n}\n\nstatic Maybe<void> GetConvertMemoryFormatSbp(user_op::SbpContext* ctx, const Shape& shape,\n                                             MemoryFormat memory_format,\n                                             MemoryFormat target_memory_format) {\n  auto sbp_func = get_sbp_funcs[memory_format][target_memory_format];\n  return sbp_func(ctx, shape);\n}\n\n/*static*/ Maybe<void> ConvertMemoryFormatOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& input_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  const auto& memory_format = ctx->Attr<MemoryFormat>(\"memory_format\");\n\n  JUST(GetConvertMemoryFormatSbp(ctx, input_tensor.shape(), input_tensor.memory_format(),\n                                 memory_format));\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ConvertMemoryFormatOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out_tensor_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  const Shape& in_shape = in_tensor_desc.shape();\n  const auto& memory_format = ctx->Attr<MemoryFormat>(\"memory_format\");\n\n  out_tensor_desc->set_is_dynamic(in_tensor_desc.is_dynamic());\n  out_tensor_desc->set_shape(\n      ComputeConvertMemoryFormatShape(in_shape, in_tensor_desc.memory_format(), memory_format));\n  out_tensor_desc->set_memory_format(memory_format);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ConvertMemoryFormatOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> ConvertMemoryFormatOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
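`ComputeShapeContiguousToChannelsLast` above moves axis 1 (channels) to the last position and shifts the spatial axes left by one; `ComputeShapeChannelsLastToContiguous` is its inverse. A small sketch of the same permutation, using `std::vector` in place of the OneFlow-internal `Shape`:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of the permutation performed by ComputeShapeContiguousToChannelsLast:
// N stays first, C moves last, and the spatial dims shift left by one slot.
std::vector<int64_t> ToChannelsLast(const std::vector<int64_t>& s) {
  if (s.size() <= 2) { return s; }  // identity below rank 3, as in the op
  std::vector<int64_t> t(s.size());
  t.front() = s.front();                                      // N stays
  t.back() = s[1];                                            // C goes last
  for (size_t i = 2; i < s.size(); ++i) { t[i - 1] = s[i]; }  // H, W shift left
  return t;
}

int main() {
  // NCHW {8, 3, 32, 32} becomes NHWC {8, 32, 32, 3}.
  assert((ToChannelsLast({8, 3, 32, 32}) == std::vector<int64_t>{8, 32, 32, 3}));
}
```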
  {
    "path": "oneflow/user/ops/convert_memory_format_op.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nShape ComputeShapeContiguousToChannelsLast(const Shape& shape);\nShape ComputeShapeChannelsLastToContiguous(const Shape& shape);\n\nShape ComputeConvertMemoryFormatShape(const Shape& shape, MemoryFormat memory_format,\n                                      MemoryFormat target_memory_format);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/copy_hd_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferLogical(user_op::InferContext* ctx) {\n  UNIMPLEMENTED_THEN_RETURN() << \"copy hd should only exist in physical graph\";\n}\n\nMaybe<void> InferPhysical(user_op::InferContext* ctx) {\n  *ctx->MutOutputTensorDesc(\"out\", 0) = ctx->InputTensorDesc(\"in\", 0);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FwGetSbpFn(user_op::SbpContext* ctx) { return Maybe<void>::Ok(); }\n\nMaybe<void> InferFWDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\nMaybe<void> CopyD2HOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogical(ctx);\n}\n\nMaybe<void> CopyD2HOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferPhysical(ctx);\n}\n\nMaybe<void> CopyD2HOp::GetSbp(user_op::SbpContext* ctx) { return FwGetSbpFn(ctx); }\n\nMaybe<void> CopyD2HOp::InferDataType(user_op::InferContext* ctx) { return InferFWDataType(ctx); }\n\nMaybe<void> CopyH2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogical(ctx);\n}\n\nMaybe<void> CopyH2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferPhysical(ctx);\n}\n\nMaybe<void> CopyH2DOp::GetSbp(user_op::SbpContext* ctx) { return FwGetSbpFn(ctx); }\n\nMaybe<void> CopyH2DOp::InferDataType(user_op::InferContext* ctx) { return InferFWDataType(ctx); }\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/copy_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/stream.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/common/env_var/stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nStreamType GetH2DStreamType() {\n  if (ThreadLocalEnvBool<ONEFLOW_STREAM_ENABLE_H2D_STREAM>()) {\n    return StreamType::kHost2Device;\n  } else {\n    return StreamType::kCompute;\n  }\n}\n\nMaybe<Symbol<Stream>> MakeCopyStream(const Symbol<Device>& in_device,\n                                     const Symbol<Device>& out_device, const bool pin_memory) {\n  if (in_device->type() != \"cpu\" && out_device->type() == \"cpu\") {\n    return Stream::New(in_device, StreamType::kDevice2Host);\n  } else if (in_device->type() == \"cpu\" && out_device->type() != \"cpu\") {\n    return Stream::New(out_device, GetH2DStreamType());\n  } else if (in_device->type() == \"cpu\" && out_device->type() == \"cpu\" && pin_memory) {\n    // TODO:(zhaoluyang) Parsing pin-memory-device from python\n    auto pin_device = JUST(Device::New(\"cuda\"));\n    return Stream::New(pin_device, StreamType::kPinnedCompute);\n  } else {\n    CHECK_EQ_OR_RETURN(in_device->type(), out_device->type());\n    return Stream::New(out_device, StreamType::kCompute);\n  }\n}\n\n}  // namespace\n\n/* static */ Maybe<void> CopyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputStride(\"out\", 0, ctx->InputStride(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> CopyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CopyOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& inputs = ctx->inputs();\n  CHECK_EQ_OR_RETURN(inputs.size(), 1);\n  const auto& input =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(inputs[0].first, inputs[0].second);\n  for (int64_t axis = 0; axis < input.shape().NumAxes(); ++axis) {\n    ctx->NewBuilder().Split(inputs, axis).Split(ctx->outputs(), axis).Build();\n  }\n  ctx->NewBuilder().PartialSum(inputs).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CopyOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> CopyOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  Symbol<Device> out_device = ctx->Attr<Symbol<Device>>(\"device\");\n  *ctx->OutputTensorDevice4ArgNameAndIndex(\"out\", 0) = out_device;\n  const Symbol<Device>& in_device = ctx->InputTensorDevice4ArgNameAndIndex(\"in\", 0);\n  const bool pin_memory = ctx->Attr<bool>(\"pin_memory\");\n  return 
MakeCopyStream(in_device, out_device, pin_memory);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/count_not_finite_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> CountNotFiniteOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc(\"y\", 0);\n  y_desc->set_shape(Shape({1}));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> CountNotFiniteOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CountNotFiniteOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).PartialSum(user_op::OpArg(\"y\", 0)).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CountNotFiniteOp::InferDataType(user_op::InferContext* ctx) {\n  user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc(\"y\", 0);\n  y_desc->set_data_type(DataType::kInt64);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MultiCountNotFiniteOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc(\"y\", 0);\n  y_desc->set_shape(Shape({1}));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> MultiCountNotFiniteOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MultiCountNotFiniteOp::GetSbp(user_op::SbpContext* ctx) {\n  int64_t min_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape().NumAxes();\n  for (int64_t i = 1; i < ctx->user_op_conf().input_size(\"x\"); ++i) {\n    min_num_axes = std::min(min_num_axes,\n                            ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", i).shape().NumAxes());\n  }\n  for (int64_t i = 0; i < min_num_axes; ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(user_op::OpArg(\"y\", 0)).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MultiCountNotFiniteOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& first_x_desc = ctx->InputTensorDesc(\"x\", 0);\n  for (const auto& in_arg_pair : ctx->inputs()) {\n    const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second);\n    CHECK_EQ_OR_RETURN(x_desc.data_type(), first_x_desc.data_type())\n        << \"InferDataType Failed. 
Expected \" << DataType_Name(first_x_desc.data_type())\n        << \", but got \" << DataType_Name(x_desc.data_type());\n  }\n  user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc(\"y\", 0);\n  y_desc->set_data_type(DataType::kInt64);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> MultiCountNotFiniteOp::CheckAttr(const user_op::UserOpDefWrapper&,\n                                                        const user_op::UserOpConfWrapper& op_conf) {\n  CHECK_OR_RETURN(op_conf.input_size(\"x\") >= 1);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
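The `Split(x) -> PartialSum(y)` signatures above are valid because counting is additive: the number of non-finite elements in a tensor equals the sum of the counts over any disjoint shards, so each rank can count its own slice. A tiny illustrative sketch (not a OneFlow API):

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Why Split(x) -> PartialSum(y) works for count_not_finite: each rank counts
// its own shard, and the per-shard counts sum to the global count.
int64_t CountNotFinite(const std::vector<float>& shard) {
  int64_t n = 0;
  for (float v : shard) {
    if (!std::isfinite(v)) { ++n; }  // catches NaN as well as +/-Inf
  }
  return n;
}
// total = CountNotFinite(shard0) + CountNotFinite(shard1) + ...
```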
  {
    "path": "oneflow/user/ops/ctc_loss_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> CtcLossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& log_probs = ctx->InputTensorDesc(\"log_probs\", 0);\n  const user_op::TensorDesc& targets = ctx->InputTensorDesc(\"targets\", 0);\n  const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc(\"input_lengths\", 0);\n  const user_op::TensorDesc& target_lengths = ctx->InputTensorDesc(\"target_lengths\", 0);\n  const int64_t batch_size = log_probs.shape().At(1);\n  const int64_t max_target_length = ctx->Attr<int64_t>(\"max_target_length\");\n  if (targets.shape().NumAxes() == 2) {\n    CHECK_EQ_OR_RETURN(targets.shape().At(0), batch_size);\n    CHECK_GE_OR_RETURN(targets.shape().At(1), max_target_length);\n  }\n  CHECK_EQ_OR_RETURN(input_lengths.shape().At(0), batch_size);\n  CHECK_EQ_OR_RETURN(target_lengths.shape().At(0), batch_size);\n  CHECK_GE_OR_RETURN(ctx->Attr<int64_t>(\"blank\"), 0);\n  CHECK_LT_OR_RETURN(ctx->Attr<int64_t>(\"blank\"), log_probs.shape().At(2));\n\n  ctx->SetOutputShape(\"loss\", 0, Shape({batch_size}));\n  ctx->SetOutputShape(\"alpha\", 0,\n                      Shape({batch_size, log_probs.shape().At(0), 2 * max_target_length + 1}));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> CtcLossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CtcLossOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"log_probs\", 0), 1)  // `log_probs` batch axis is 1\n      .Split(user_op::OpArg(\"targets\", 0), 0)\n      .Split(user_op::OpArg(\"input_lengths\", 0), 0)\n      .Split(user_op::OpArg(\"target_lengths\", 0), 0)\n      .Split(user_op::OpArg(\"loss\", 0), 0)\n      .Split(user_op::OpArg(\"alpha\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CtcLossOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"loss\", 0, ctx->InputDType(\"log_probs\", 0));\n  ctx->SetOutputDType(\"alpha\", 0, ctx->InputDType(\"log_probs\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CtcLossGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& log_probs = ctx->InputTensorDesc(\"log_probs\", 0);\n  const user_op::TensorDesc& targets = ctx->InputTensorDesc(\"targets\", 0);\n  const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc(\"input_lengths\", 0);\n  const user_op::TensorDesc& target_lengths = ctx->InputTensorDesc(\"target_lengths\", 0);\n  const int64_t batch_size = log_probs.shape().At(1);\n  const int64_t max_target_length = ctx->Attr<int64_t>(\"max_target_length\");\n  if (targets.shape().NumAxes() == 2) {\n    CHECK_EQ_OR_RETURN(targets.shape().At(0), batch_size);\n    
CHECK_GE_OR_RETURN(targets.shape().At(1), max_target_length);\n  }\n  CHECK_EQ_OR_RETURN(input_lengths.shape().At(0), batch_size);\n  CHECK_EQ_OR_RETURN(target_lengths.shape().At(0), batch_size);\n  CHECK_GE_OR_RETURN(ctx->Attr<int64_t>(\"blank\"), 0);\n  CHECK_LT_OR_RETURN(ctx->Attr<int64_t>(\"blank\"), log_probs.shape().At(2));\n\n  ctx->SetOutputShape(\"grad\", 0, log_probs.shape());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> CtcLossGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CtcLossGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"grad_out\", 0), 0)\n      .Split(user_op::OpArg(\"log_probs\", 0), 1)  // `log_probs` batch axis is 1\n      .Split(user_op::OpArg(\"targets\", 0), 0)\n      .Split(user_op::OpArg(\"input_lengths\", 0), 0)\n      .Split(user_op::OpArg(\"target_lengths\", 0), 0)\n      .Split(user_op::OpArg(\"loss\", 0), 0)\n      .Split(user_op::OpArg(\"alpha\", 0), 0)\n      .Split(user_op::OpArg(\"grad\", 0), 1)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CtcLossGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"grad\", 0, ctx->InputDType(\"log_probs\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CtcGreedyDecoderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& log_probs = ctx->InputTensorDesc(\"log_probs\", 0);\n  const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc(\"input_lengths\", 0);\n  const int64_t batch_size = log_probs.shape().At(1);\n  CHECK_EQ_OR_RETURN(batch_size, input_lengths.shape().At(0));\n  ctx->SetOutputShape(\"decoded\", 0, Shape({batch_size, log_probs.shape().At(0)}));\n  ctx->SetOutputShape(\"neg_sum_logits\", 0, Shape({batch_size, 1}));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> CtcGreedyDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CtcGreedyDecoderOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"log_probs\", 0), 1)  // `log_probs` batch axis is 1\n      .Split(user_op::OpArg(\"input_lengths\", 0), 0)\n      .Split(user_op::OpArg(\"decoded\", 0), 0)\n      .Split(user_op::OpArg(\"neg_sum_logits\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CtcGreedyDecoderOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"decoded\", 0, ctx->InputDType(\"input_lengths\", 0));\n  ctx->SetOutputDType(\"neg_sum_logits\", 0, ctx->InputDType(\"log_probs\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
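The trailing `alpha` dimension of `2 * max_target_length + 1` above comes from CTC's forward algorithm, which runs over the blank-extended target: a blank is interleaved before, between, and after the S labels, giving 2S + 1 states. A sketch of that extension, with `blank` standing for the reserved blank id:

```cpp
#include <cstdint>
#include <vector>

// Why `alpha` has trailing dimension 2 * max_target_length + 1: CTC's forward
// pass works on the blank-extended label sequence, which has 2 * S + 1 states
// for a length-S target.
std::vector<int64_t> ExtendWithBlanks(const std::vector<int64_t>& target, int64_t blank) {
  std::vector<int64_t> extended(2 * target.size() + 1, blank);
  for (size_t s = 0; s < target.size(); ++s) { extended[2 * s + 1] = target[s]; }
  return extended;  // e.g. {c, a, t} -> {_, c, _, a, _, t, _}
}
```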
  {
    "path": "oneflow/user/ops/cublas_bias_add_relu_matmul_grad_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferTensorDesc4FusedMatmulBackward(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc(\"weight\", 0);\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  const int64_t bias_size = weight_desc.shape().At(1);\n  Shape d_grad_shape({dy_desc.shape().At(0), weight_desc.shape().At(1)});\n  ctx->SetOutputShape(\"d_grad\", 0, d_grad_shape);\n  ctx->SetOutputShape(\"d_bias\", 0, Shape({bias_size}));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType4MatmulBackward(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc(\"weight\", 0);\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  CHECK_EQ_OR_RETURN(weight_desc.data_type(), dy_desc.data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(dy_desc.data_type()) << \", but got \"\n      << DataType_Name(weight_desc.data_type());\n\n  user_op::TensorDesc* d_grad_desc = ctx->MutOutputTensorDesc(\"d_grad\", 0);\n  user_op::TensorDesc* d_bias_desc = ctx->MutOutputTensorDesc(\"d_bias\", 0);\n\n  d_grad_desc->set_data_type(dy_desc.data_type());\n  d_bias_desc->set_data_type(dy_desc.data_type());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> CublasBiasAddReluMatmulGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferTensorDesc4FusedMatmulBackward(ctx);\n}\n\n/*static*/ Maybe<void> CublasBiasAddReluMatmulGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CublasBiasAddReluMatmulGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"weight\", 0))\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"aux\", 0), 0)\n      .Split(user_op::OpArg(\"d_grad\", 0), 0)\n      .PartialSum(user_op::OpArg(\"d_bias\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CublasBiasAddReluMatmulGradOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType4MatmulBackward(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/cublas_fused_matmul_bias_add_grad_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferTensorDesc4MatmulBiasAddBackward(user_op::InferContext* ctx) {\n  /*\n  x (m, k)\n  w (n, k) need transpose\n  bias (n, )\n  y (m, n)\n  w_grad = dy_transpose matmul x\n  */\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n\n  const int64_t bias_size = dy_desc.shape().At(1);\n  Shape w_grad_shape({dy_desc.shape().At(1), x_desc.shape().At(1)});\n  ctx->SetOutputShape(\"w_grad\", 0, w_grad_shape);\n  ctx->SetOutputShape(\"b_grad\", 0, Shape({bias_size}));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType4MatmulBiasAddBackward(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  CHECK_EQ_OR_RETURN(x_desc.data_type(), dy_desc.data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(dy_desc.data_type()) << \", but got \"\n      << DataType_Name(x_desc.data_type());\n\n  user_op::TensorDesc* w_grad_desc = ctx->MutOutputTensorDesc(\"w_grad\", 0);\n  user_op::TensorDesc* b_grad_desc = ctx->MutOutputTensorDesc(\"b_grad\", 0);\n\n  w_grad_desc->set_data_type(dy_desc.data_type());\n  b_grad_desc->set_data_type(dy_desc.data_type());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> CublasMatmulBiasAddGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferTensorDesc4MatmulBiasAddBackward(ctx);\n}\n\n/*static*/ Maybe<void> CublasMatmulBiasAddGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CublasMatmulBiasAddGradOp::GetSbp(user_op::SbpContext* ctx) {\n  /*\n  dy need transpose.\n\n  assume dy(m, n), x(m, k), dbias=(n, 1)\n  dw = dy_T matmul x\n\n  */\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 1)\n      .Broadcast(user_op::OpArg(\"x\", 0))\n      .Split(user_op::OpArg(\"w_grad\", 0), 0)\n      .Split(user_op::OpArg(\"b_grad\", 0), 0)\n      .Build();\n\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .PartialSum(user_op::OpArg(\"w_grad\", 0))\n      .PartialSum(user_op::OpArg(\"b_grad\", 0))\n      .Build();\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CublasMatmulBiasAddGradOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType4MatmulBiasAddBackward(ctx);\n}\n\n}  // namespace oneflow\n"
  },
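The two SBP signatures in `CublasMatmulBiasAddGradOp::GetSbp` mirror the math `dw = dy^T matmul x`: splitting `dy` on its column axis (n) splits the rows of `dw`, while splitting both operands on the batch axis (m) leaves each rank with a partial sum over its shard of m, hence the `PartialSum` outputs. A naive illustrative sketch of that reduction (real kernels call cuBLAS):

```cpp
#include <vector>

// dw[n][k] = sum over m of dy[m][n] * x[m][k]. Sharding the m axis across
// ranks leaves each rank with a partial dw that must be summed, which is
// exactly the PartialSum signature above. Illustrative loops only.
using Mat = std::vector<std::vector<double>>;

Mat DyTransposeMatmulX(const Mat& dy /* m x n */, const Mat& x /* m x k */) {
  const size_t m = dy.size(), n = dy[0].size(), k = x[0].size();
  Mat dw(n, std::vector<double>(k, 0.0));
  for (size_t i = 0; i < m; ++i) {  // the reduced (batch) axis
    for (size_t j = 0; j < n; ++j) {
      for (size_t l = 0; l < k; ++l) { dw[j][l] += dy[i][j] * x[i][l]; }
    }
  }
  return dw;  // a shard of rows along m yields a partial sum of the same dw
}
```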
  {
    "path": "oneflow/user/ops/cublas_fused_mlp_grad_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferTensorDesc4FusedMatmulBackward(user_op::InferContext* ctx) {\n  const int64_t weight_num = ctx->input_size(\"weights\");\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  for (int idx = weight_num - 1; idx >= 0; idx--) {\n    const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc(\"weights\", idx);\n    ctx->SetOutputShape(\"d_weights\", idx, weight_desc.shape());\n    ctx->SetOutputShape(\"d_biases\", idx, Shape({weight_desc.shape().At(0)}));\n  }\n  ctx->SetOutputShape(\"d_x\", 0, x_desc.shape());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType4MatmulBackward(user_op::InferContext* ctx) {\n  const int64_t weight_num = ctx->input_size(\"weights\");\n  const int64_t dweight_num = ctx->output_size(\"d_weights\");\n  CHECK_EQ(weight_num, dweight_num) << \"The number of weights and d_weights should be equal. \";\n  const int64_t dbias_size = ctx->output_size(\"d_biases\");\n  CHECK_EQ(weight_num, dbias_size) << \"The number of d_biases should be equal to weight_num. \"\n                                      \"Because last layer's bias_grad is computed by ReduceSum. 
\";\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  for (int idx = weight_num - 1; idx >= 0; idx--) {\n    ctx->SetOutputDType(\"d_weights\", idx, dy_desc.data_type());\n    ctx->SetOutputDType(\"d_biases\", idx, dy_desc.data_type());\n  }\n  ctx->SetOutputDType(\"d_x\", 0, dy_desc.data_type());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> CublasFusedMLPGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc4FusedMatmulBackward(ctx);\n}\n\n/*static*/ Maybe<void> CublasFusedMLPGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CublasFusedMLPGradOp::GetSbp(user_op::SbpContext* ctx) {\n  auto builder = ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), 0);\n  builder.Split(user_op::OpArg(\"dy\", 0), 0);\n  for (int i = 0; i < ctx->user_op_conf().input_size(\"weights\"); ++i) {\n    builder.Broadcast(user_op::OpArg(\"weights\", i));\n  }\n  for (int i = 0; i < ctx->user_op_conf().input_size(\"cublas_aux\"); ++i) {\n    builder.Split(user_op::OpArg(\"cublas_aux\", i), 0);\n  }\n  for (int i = 0; i < ctx->user_op_conf().input_size(\"hidden\"); ++i) {\n    builder.Split(user_op::OpArg(\"hidden\", i), 0);\n  }\n\n  builder.Split(user_op::OpArg(\"d_x\", 0), 0);\n  if (ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE\", false)) {\n    // FusedMLPGradKernel do allreduce for dbias and dweight, so here convert from PartialSum to\n    // Broadcast.\n    for (int i = 0; i < ctx->user_op_conf().output_size(\"d_biases\"); ++i) {\n      builder.Broadcast(user_op::OpArg(\"d_biases\", i));\n    }\n    for (int i = 0; i < ctx->user_op_conf().output_size(\"d_weights\"); ++i) {\n      builder.Broadcast(user_op::OpArg(\"d_weights\", i));\n    }\n  } else {\n    for (int i = 0; i < ctx->user_op_conf().output_size(\"d_biases\"); ++i) {\n      builder.PartialSum(user_op::OpArg(\"d_biases\", i));\n    }\n    for (int i = 0; i < ctx->user_op_conf().output_size(\"d_weights\"); ++i) {\n      builder.PartialSum(user_op::OpArg(\"d_weights\", i));\n    }\n  }\n\n  builder.Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CublasFusedMLPGradOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType4MatmulBackward(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/cublas_fused_mlp_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr int32_t kAuxReluLdAlignRequirement = 128;\n\nlong AlignReluAuxLd(long aux_ld) {\n  /*\n  ReLu bit-mask matrix leading dimension in elements.\n  Must be divisible by 128 and be no less than the number of rows in the output matrix.\n  */\n  long old_aux_ld = aux_ld;\n  return ((old_aux_ld + kAuxReluLdAlignRequirement - 1) / kAuxReluLdAlignRequirement)\n         * kAuxReluLdAlignRequirement;\n}\n\nMaybe<void> InferTensorDesc4FusedMatmul(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  int32_t weight_size = ctx->input_size(\"weights\");\n  int32_t bias_size = ctx->input_size(\"biases\");\n  CHECK_EQ_OR_RETURN(weight_size, bias_size);\n  /*\n  A: (m, k)\n  B: (n, k) need transpose\n  C: (m, n)\n  */\n  int64_t m = 0, n = 0, k = 0, cublas_aux_ld = 0;\n  m = x_desc.shape().At(0);\n  k = x_desc.shape().At(1);\n\n  for (int32_t idx = 0; idx < weight_size; idx++) {\n    // skip first input weight.\n    const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc(\"weights\", idx);\n    const user_op::TensorDesc& bias_desc = ctx->InputTensorDesc(\"biases\", idx);\n    CHECK_EQ_OR_RETURN(weight_desc.shape().NumAxes(), 2);\n    CHECK_EQ_OR_RETURN(bias_desc.shape().NumAxes(), 1);\n\n    n = weight_desc.shape().At(0);\n    CHECK_EQ_OR_RETURN(bias_desc.shape().At(0), n);\n    CHECK_EQ_OR_RETURN(weight_desc.shape().At(1), k);\n\n    cublas_aux_ld = n;\n    // Set Middle result shape.\n    long cublas_aligned_aux_ld = AlignReluAuxLd(cublas_aux_ld);\n    int64_t aux_size = cublas_aligned_aux_ld / 32;  // Cause we use int32_t as dtype\n    ctx->SetOutputShape(\"cublas_aux\", idx, Shape({m, aux_size}));\n    ctx->SetOutputShape(\"hidden\", idx, Shape({m, n}));\n    // Set for next layer.\n    k = n;\n  }\n  ctx->SetOutputShape(\"out\", 0, Shape({m, n}));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType4Matmul(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc(\"x\", 0);\n\n  for (const auto& in_arg_pair : ctx->inputs()) {\n    const user_op::TensorDesc& in_desc =\n        ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second);\n    CHECK_EQ_OR_RETURN(in_desc.data_type(), first_in_desc.data_type())\n        << \"InferDataType Failed. 
Expected \" << DataType_Name(first_in_desc.data_type())\n        << \", but got \" << DataType_Name(in_desc.data_type());\n  }\n\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_data_type(first_in_desc.data_type());\n\n  for (int32_t i = 0; i < ctx->output_size(\"hidden\"); i++) {\n    user_op::TensorDesc* hidden_desc = ctx->MutOutputTensorDesc(\"hidden\", i);\n    hidden_desc->set_data_type(first_in_desc.data_type());\n  }\n\n  for (int32_t i = 0; i < ctx->output_size(\"cublas_aux\"); i++) {\n    user_op::TensorDesc* aux_desc = ctx->MutOutputTensorDesc(\"cublas_aux\", i);\n    aux_desc->set_data_type(DataType::kInt32);\n  }\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> CublasFusedMLPOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc4FusedMatmul(ctx);\n}\n\n/*static*/ Maybe<void> CublasFusedMLPOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CublasFusedMLPOp::GetSbp(user_op::SbpContext* ctx) {\n  // Currently Only support S0 B B B B ... S0\n  auto builder = ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), 0);\n  for (int i = 0; i < ctx->user_op_conf().input_size(\"weights\"); ++i) {\n    builder.Broadcast(user_op::OpArg(\"weights\", i));\n  }\n  for (int i = 0; i < ctx->user_op_conf().input_size(\"biases\"); ++i) {\n    builder.Broadcast(user_op::OpArg(\"biases\", i));\n  }\n  for (int i = 0; i < ctx->user_op_conf().output_size(\"cublas_aux\"); ++i) {\n    builder.Split(user_op::OpArg(\"cublas_aux\", i), 0);\n  }\n  for (int i = 0; i < ctx->user_op_conf().output_size(\"hidden\"); ++i) {\n    builder.Split(user_op::OpArg(\"hidden\", i), 0);\n  }\n  builder.Split(user_op::OpArg(\"out\", 0), 0);\n  builder.Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CublasFusedMLPOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType4Matmul(ctx);\n}\n\n}  // namespace oneflow\n"
  },
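The aux-buffer sizing in `InferTensorDesc4FusedMatmul` above packs one ReLU mask bit per output element: the leading dimension is rounded up to a multiple of 128, and each `int32_t` word holds 32 bits, hence `aligned_ld / 32` elements per row. A small sketch of just that arithmetic:

```cpp
#include <cassert>

// Sketch of the rounding performed by AlignReluAuxLd: round the leading
// dimension up to a multiple of 128, then divide by 32 to get the number of
// int32_t mask words per row.
constexpr long kAlign = 128;

long RoundUpTo128(long n) { return (n + kAlign - 1) / kAlign * kAlign; }

int main() {
  assert(RoundUpTo128(1) == 128);
  assert(RoundUpTo128(128) == 128);
  assert(RoundUpTo128(129) == 256);
  // A layer with n = 1000 outputs keeps 1024 / 32 = 32 int32_t words per row.
  assert(RoundUpTo128(1000) / 32 == 32);
}
```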
  {
    "path": "oneflow/user/ops/cum_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nMaybe<void> CumsumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"y\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CumsumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> CumsumOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& in_tensor_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  auto dim = ctx->Attr<int64_t>(\"dim\");\n  for (auto i = dim + 1; i < in_tensor_desc.shape().NumAxes(); i++) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Split(user_op::OpArg(\"y\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CumsumOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CumProdOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"y\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CumProdOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> CumProdOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& in_tensor_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  auto dim = ctx->Attr<int64_t>(\"dim\");\n  for (auto i = dim + 1; i < in_tensor_desc.shape().NumAxes(); i++) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Split(user_op::OpArg(\"y\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CumProdOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CumProdGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CumProdGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> CumProdGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& dy_tensor_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"dy\", 0);\n  auto dim = ctx->Attr<int64_t>(\"dim\");\n  for (auto i = dim + 1; i < dy_tensor_desc.shape().NumAxes(); i++) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"output\", 0), i)\n        .Split(user_op::OpArg(\"input\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CumProdGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/data_shuffle_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/embedding/embedding_manager.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> UniqueKeyValuePairOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& keys_shape = ctx->InputShape(\"keys\", 0);\n  const int32_t num_tables = ctx->Attr<int32_t>(\"num_tables\");\n  CHECK_GE_OR_RETURN(num_tables, 1) << \"num_tables must greater than 1, but get \" << num_tables;\n  if (ctx->has_input(\"values\", 0)) {\n    const Shape& values_shape = ctx->InputShape(\"values\", 0);\n    CHECK_EQ_OR_RETURN(keys_shape, values_shape) << \"keys_shape must equal to values_shape\";\n  } else {\n    if (num_tables > 1) {\n      CHECK_EQ_OR_RETURN(keys_shape.NumAxes(), 2);\n      CHECK_EQ_OR_RETURN(keys_shape.At(1), num_tables) << \"keys cols must equal to num_tables\";\n    }\n  }\n  ctx->SetOutputShape(\"num_unique\", 0, Shape({1}));\n  ctx->SetOutputShape(\"unique_keys\", 0, Shape({keys_shape.elem_cnt()}));\n  ctx->SetOutputShape(\"unique_values\", 0, Shape({keys_shape.elem_cnt()}));\n  ctx->SetOutputShape(\"inverse_indices\", 0, keys_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UniqueKeyValuePairOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> UniqueKeyValuePairOp::GetSbp(user_op::SbpContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> UniqueKeyValuePairOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"num_unique\", 0, DataType::kInt32);\n  ctx->SetOutputDType(\"unique_keys\", 0, ctx->InputDType(\"keys\", 0));\n  ctx->SetOutputDType(\"inverse_indices\", 0, DataType::kInt32);\n  if (ctx->has_input(\"values\", 0)) {\n    ctx->SetOutputDType(\"unique_values\", 0, ctx->InputDType(\"values\", 0));\n  } else {\n    ctx->SetOutputDType(\"unique_values\", 0, DataType::kUInt8);\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> IdShuffleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& ids_shape = ctx->InputShape(\"ids\", 0);\n  const int32_t num_tables = ctx->Attr<int32_t>(\"num_tables\");\n  CHECK_GE_OR_RETURN(num_tables, 1) << \"num_tables must greater than 1, but get \" << num_tables;\n  if (ctx->has_input(\"table_ids\", 0)) {\n    const Shape& table_ids_shape = ctx->InputShape(\"table_ids\", 0);\n    CHECK_EQ_OR_RETURN(ids_shape, table_ids_shape) << \"ids_shape must equal to table_ids_shape\";\n  } else {\n    if (num_tables > 1) {\n      CHECK_EQ_OR_RETURN(ids_shape.NumAxes(), 2);\n      CHECK_EQ_OR_RETURN(ids_shape.At(1), num_tables) << \"ids cols must equal to num_tables\";\n    }\n  }\n  const int64_t num_ids = ids_shape.elem_cnt();\n  const int64_t parallel_num = ctx->parallel_num();\n  ctx->SetOutputShape(\"num_unique_matrix\", 0, 
Shape({parallel_num * parallel_num}));\n  ctx->SetOutputShape(\"inverse_unique_partition_indices\", 0, ids_shape);\n  ctx->SetOutputShape(\"cur_rank_num_unique\", 0, Shape({1}));\n  ctx->SetOutputShape(\"cur_rank_unique_ids\", 0, Shape({num_ids * parallel_num}));\n  ctx->SetOutputShape(\"cur_rank_inverse_indices\", 0, Shape({num_ids * parallel_num}));\n  ctx->SetOutputShape(\"cur_rank_unique_table_ids\", 0, Shape({num_ids * parallel_num}));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> IdShuffleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> IdShuffleOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(ctx->inputs(), 0)\n      .Split(ctx->outputs(), 0)\n      .Broadcast(user_op::OpArg(\"num_unique_matrix\", 0))\n      .Broadcast(user_op::OpArg(\"cur_rank_num_unique\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> IdShuffleOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"num_unique_matrix\", 0, DataType::kUInt32);\n  ctx->SetOutputDType(\"inverse_unique_partition_indices\", 0, DataType::kUInt32);\n  ctx->SetOutputDType(\"cur_rank_num_unique\", 0, DataType::kUInt32);\n  ctx->SetOutputDType(\"cur_rank_unique_ids\", 0, ctx->InputDType(\"ids\", 0));\n  ctx->SetOutputDType(\"cur_rank_inverse_indices\", 0, DataType::kUInt32);\n  if (ctx->has_input(\"table_ids\", 0)) {\n    ctx->SetOutputDType(\"cur_rank_unique_table_ids\", 0, ctx->InputDType(\"table_ids\", 0));\n  } else {\n    ctx->SetOutputDType(\"cur_rank_unique_table_ids\", 0, DataType::kUInt8);\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EmbeddingShuffleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& cur_rank_embeddings_shape = ctx->InputShape(\"cur_rank_embeddings\", 0);\n  const Shape& num_unique_matrix_shape = ctx->InputShape(\"num_unique_matrix\", 0);\n  const Shape& cur_rank_inverse_indices_shape = ctx->InputShape(\"cur_rank_inverse_indices\", 0);\n  const Shape& inverse_unique_partition_indices_shape =\n      ctx->InputShape(\"inverse_unique_partition_indices\", 0);\n  const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n  const int64_t num_ids = inverse_unique_partition_indices_shape.elem_cnt();\n  const int64_t parallel_num = ctx->parallel_num();\n  if (embedding::UseDynamicMemoryAllocation()) {\n    CHECK_EQ_OR_RETURN(cur_rank_embeddings_shape.elem_cnt(), 1)\n        << \"if use dynamic memory allocation, cur_rank_embeddings elem_cnt should be 1.\";\n  } else {\n    CHECK_EQ_OR_RETURN(cur_rank_embeddings_shape.NumAxes(), 2)\n        << \"cur_rank_embeddings num_axes should be 2.\";\n    CHECK_EQ_OR_RETURN(cur_rank_embeddings_shape.At(0), parallel_num * num_ids)\n        << \" got \" << cur_rank_embeddings_shape.At(0) << \" and \" << parallel_num * num_ids;\n    CHECK_EQ_OR_RETURN(embedding_size, cur_rank_embeddings_shape.At(1))\n        << \" got \" << embedding_size << \" and \" << cur_rank_embeddings_shape.At(1);\n  }\n  CHECK_EQ_OR_RETURN(num_unique_matrix_shape.elem_cnt(), parallel_num * parallel_num);\n  CHECK_EQ_OR_RETURN(cur_rank_inverse_indices_shape.elem_cnt(), parallel_num * num_ids);\n  DimVector out_dim_vec = inverse_unique_partition_indices_shape.dim_vec();\n  out_dim_vec.push_back(embedding_size);\n  ctx->SetOutputShape(\"embeddings\", 0, Shape(out_dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EmbeddingShuffleOp::InferPhysicalTensorDesc(user_op::InferContext* 
ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EmbeddingShuffleOp::GetSbp(user_op::SbpContext* ctx) {\n  auto builder = ctx->NewBuilder()\n                     .Split(ctx->inputs(), 0)\n                     .Broadcast(user_op::OpArg(\"num_unique_matrix\", 0))\n                     .Split(ctx->outputs(), 0);\n  if (embedding::UseDynamicMemoryAllocation()) {\n    builder.Broadcast(user_op::OpArg(\"cur_rank_embeddings\", 0)).Build();\n  } else {\n    builder.Split(user_op::OpArg(\"cur_rank_embeddings\", 0), 0).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EmbeddingShuffleOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_OR_RETURN(ctx->InputDType(\"num_unique_matrix\", 0) == DataType::kUInt32);\n  CHECK_OR_RETURN(ctx->InputDType(\"cur_rank_inverse_indices\", 0) == DataType::kUInt32);\n  CHECK_OR_RETURN(ctx->InputDType(\"inverse_unique_partition_indices\", 0) == DataType::kUInt32);\n  ctx->SetOutputDType(\"embeddings\", 0, ctx->InputDType(\"cur_rank_embeddings\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EmbeddingGradientShuffleOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const Shape& embedding_grad_shape = ctx->InputShape(\"embedding_grad\", 0);\n  const Shape& num_unique_matrix_shape = ctx->InputShape(\"num_unique_matrix\", 0);\n  const Shape& cur_rank_inverse_indices_shape = ctx->InputShape(\"cur_rank_inverse_indices\", 0);\n  const Shape& inverse_unique_partition_indices_shape =\n      ctx->InputShape(\"inverse_unique_partition_indices\", 0);\n  const int64_t num_ids = inverse_unique_partition_indices_shape.elem_cnt();\n  const int64_t parallel_num = ctx->parallel_num();\n  CHECK_EQ_OR_RETURN(embedding_grad_shape.elem_cnt() % num_ids, 0);\n  const int64_t embedding_size = embedding_grad_shape.elem_cnt() / num_ids;\n  CHECK_EQ_OR_RETURN(num_unique_matrix_shape.elem_cnt(), parallel_num * parallel_num);\n  CHECK_EQ_OR_RETURN(cur_rank_inverse_indices_shape.elem_cnt(), parallel_num * num_ids);\n  DimVector out_dim_vec = cur_rank_inverse_indices_shape.dim_vec();\n  out_dim_vec.push_back(embedding_size);\n  ctx->SetOutputShape(\"cur_rank_unique_embedding_grad\", 0, Shape(out_dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EmbeddingGradientShuffleOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EmbeddingGradientShuffleOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(ctx->inputs(), 0)\n      .Broadcast(user_op::OpArg(\"num_unique_matrix\", 0))\n      .Split(ctx->outputs(), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EmbeddingGradientShuffleOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_OR_RETURN(ctx->InputDType(\"num_unique_matrix\", 0) == DataType::kUInt32);\n  CHECK_OR_RETURN(ctx->InputDType(\"cur_rank_inverse_indices\", 0) == DataType::kUInt32);\n  CHECK_OR_RETURN(ctx->InputDType(\"inverse_unique_partition_indices\", 0) == DataType::kUInt32);\n  ctx->SetOutputDType(\"cur_rank_unique_embedding_grad\", 0, ctx->InputDType(\"embedding_grad\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"in\", 0);\n  const Shape& indices_shape = ctx->InputShape(\"indices\", 0);\n  const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n  const int64_t num_ids = 
indices_shape.elem_cnt();\n  const int64_t parallel_num = ctx->parallel_num();\n  if (embedding::UseDynamicMemoryAllocation()) {\n    CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), 1)\n        << \"when dynamic memory allocation is used, in elem_cnt should be 1.\";\n  } else {\n    CHECK_EQ_OR_RETURN(in_shape.NumAxes(), 2) << \"in num_axes should be 2.\";\n    CHECK_EQ_OR_RETURN(in_shape.At(0), parallel_num * num_ids)\n        << \"in.shape[0] should equal parallel_num * num_ids, got \" << in_shape.At(0) << \" and \"\n        << parallel_num * num_ids;\n    CHECK_EQ_OR_RETURN(embedding_size, in_shape.At(1))\n        << \"embedding_size should equal in.shape[1], got \" << embedding_size << \" and \"\n        << in_shape.At(1);\n  }\n  DimVector out_dim_vec = indices_shape.dim_vec();\n  out_dim_vec.push_back(embedding_size);\n  ctx->SetOutputShape(\"out\", 0, Shape(out_dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> OneEmbeddingGatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> OneEmbeddingGatherOp::GetSbp(user_op::SbpContext* ctx) {\n  // Only used when parallel_num == 1.\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingGatherOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\nREGISTER_USER_OP_SAME_OUTPUT_BLOB_REGST_NUM_WITH_FUNC(\"id_shuffle\", []() {\n  if (!ParseBooleanFromEnv(\"ONEFLOW_ONE_EMBEDDING_DISABLE_PIPELINED_EXECUTION\", false)) {\n    return 2;\n  } else {\n    return 1;\n  }\n});\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/deconv_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<size_t NDims>\nMaybe<void> InferTensorDesc4DeConv(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n\n  CHECK_EQ_OR_RETURN(NDims + 2, in.shape().NumAxes());\n\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  const auto& kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n  CHECK_EQ_OR_RETURN(NDims, kernel_size.size());\n  const int32_t filters = ctx->Attr<int32_t>(\"filters\");\n  size_t idx_offset = IdxOffset(data_format);\n  int32_t groups = ctx->Attr<int32_t>(\"groups\");\n  {\n    const auto& dilation_rate = ctx->Attr<std::vector<int32_t>>(\"dilation_rate\");\n    const auto& output_padding = ctx->Attr<std::vector<int32_t>>(\"output_padding\");\n    const auto& strides = ctx->Attr<std::vector<int32_t>>(\"strides\");\n    const auto& padding_before = ctx->Attr<std::vector<int32_t>>(\"padding_before\");\n    CHECK_EQ_OR_RETURN(NDims, dilation_rate.size());\n    CHECK_EQ_OR_RETURN(NDims, strides.size());\n    CHECK_EQ_OR_RETURN(NDims, output_padding.size());\n\n    user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n    DimVector out_shape(NDims + 2);\n    out_shape.at(0) = in.shape().At(0);\n    const size_t c_dim = data_format == \"channels_first\" ? 1 : NDims + 1;\n    out_shape.at(c_dim) = filters;\n    for (int32_t i = 0; i < NDims; ++i) {\n      int32_t effective_filter_size = (kernel_size.at(i) - 1) * dilation_rate.at(i) + 1;\n      out_shape.at(idx_offset + i) = (in.shape().At(idx_offset + i) - 1) * strides.at(i)\n                                     - 2 * padding_before.at(i) + output_padding.at(i)\n                                     + effective_filter_size;\n    }\n    if (in.shape().At(0) != 0) {\n      for (int i = 0; i < out_shape.size(); i++) {\n        CHECK_GT_OR_RETURN(out_shape[i], 0)\n            << \"RuntimeError: Given input size per channel: (\" << Shape(in.shape())\n            << \"). Calculated output size per channel: (\" << Shape(out_shape)\n            << \"). 
Output size is too small\";\n      }\n    }\n    out->set_is_dynamic(in.is_dynamic());\n    out->set_shape(Shape(out_shape));\n  }\n\n  {\n    DimVector weight_shape(in.shape().dim_vec());\n    if (data_format == \"channels_first\") {\n      weight_shape.at(0) = in.shape().At(1);\n      weight_shape.at(1) = filters / groups;\n    } else if (data_format == \"channels_last\") {\n      weight_shape.at(0) = in.shape().At(NDims + 1);\n      weight_shape.at(NDims + 1) = filters / groups;\n    } else {\n      UNIMPLEMENTED_THEN_RETURN();\n    }\n    for (size_t i = 0; i < NDims; ++i) { weight_shape.at(idx_offset + i) = kernel_size.at(i); }\n    const user_op::TensorDesc& weight = ctx->InputTensorDesc(\"weight\", 0);\n    CHECK_EQ_OR_RETURN(weight.shape(), Shape(weight_shape));\n  }\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType_(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetSbpSignatures4DeConv(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"in\", 0), 0)\n      .Broadcast(user_op::OpArg(\"weight\", 0))\n      .Split(user_op::OpArg(\"out\", 0), 0)\n      .Build();\n\n  return Maybe<void>::Ok();\n}\n\ntemplate<size_t NDims>\nMaybe<void> CheckAttr_(const user_op::UserOpDefWrapper& def,\n                       const user_op::UserOpConfWrapper& conf) {\n  bool is_checked = true;\n  std::stringstream err;\n  err << \"Illegal value for \" << conf.op_type_name() << \" op \" << conf.op_name() << \": \";\n\n  const std::string& data_format = conf.attr<std::string>(\"data_format\");\n  if (!(data_format == \"channels_first\" || data_format == \"channels_last\")) {\n    err << \" data_format: \" << data_format;\n    is_checked = false;\n  }\n\n  if (NDims != 0) {\n    const auto& padding_before = conf.attr<std::vector<int32_t>>(\"padding_before\");\n    if (padding_before.size() != NDims) {\n      err << \" padding_before: number of elements is \" << padding_before.size();\n      is_checked = false;\n    }\n\n    const auto& kernel_size = conf.attr<std::vector<int32_t>>(\"kernel_size\");\n    if (kernel_size.size() != NDims) {\n      err << \" kernel_size: number of elements is \" << kernel_size.size();\n      is_checked = false;\n    }\n\n    const auto& strides = conf.attr<std::vector<int32_t>>(\"strides\");\n    if (strides.size() != NDims) {\n      err << \" strides: number of elements is \" << strides.size();\n      is_checked = false;\n    }\n\n    const auto& dilation_rate = conf.attr<std::vector<int32_t>>(\"dilation_rate\");\n    if (dilation_rate.size() != NDims) {\n      err << \" dilation_rate: number of elements is \" << dilation_rate.size();\n      is_checked = false;\n    }\n  }\n\n  if (is_checked) {\n    return Maybe<void>::Ok();\n  } else {\n    return oneflow::Error::CheckFailedError() << err.str();\n  }\n}\n\n}  // namespace\n\n/* static */ Maybe<void> Deconv1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc4DeConv<1>(ctx);\n}\n\n/*static*/ Maybe<void> Deconv1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> Deconv1DOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetSbpSignatures4DeConv(ctx);\n}\n\n/* static */ Maybe<void> Deconv1DOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                               const user_op::UserOpConfWrapper& conf) {\n  return CheckAttr_<1>(def, conf);\n}\n\n/* static */ Maybe<void> 
Deconv1DOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType_(ctx);\n}\n\n/* static */ Maybe<void> Deconv2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc4DeConv<2>(ctx);\n}\n\n/*static*/ Maybe<void> Deconv2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> Deconv2DOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetSbpSignatures4DeConv(ctx);\n}\n\n/* static */ Maybe<void> Deconv2DOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                               const user_op::UserOpConfWrapper& conf) {\n  return CheckAttr_<2>(def, conf);\n}\n\n/* static */ Maybe<void> Deconv2DOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType_(ctx);\n}\n\n/* static */ Maybe<void> Deconv3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc4DeConv<3>(ctx);\n}\n\n/*static*/ Maybe<void> Deconv3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> Deconv3DOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetSbpSignatures4DeConv(ctx);\n}\n\n/* static */ Maybe<void> Deconv3DOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                               const user_op::UserOpConfWrapper& conf) {\n  return CheckAttr_<3>(def, conf);\n}\n\n/* static */ Maybe<void> Deconv3DOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType_(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/deform_conv_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> DeformConv2dOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& input_shape = ctx->InputShape(\"input\", 0);\n  const Shape& weight_shape = ctx->InputShape(\"weight\", 0);\n  const Shape& offset_shape = ctx->InputShape(\"offset\", 0);\n  const Shape& mask_shape = ctx->InputShape(\"mask\", 0);\n  const int32_t kW = weight_shape.at(3);\n  const int32_t kH = weight_shape.at(2);\n  const int32_t dW = ctx->Attr<int32_t>(\"stride_w\");\n  const int32_t dH = ctx->Attr<int32_t>(\"stride_h\");\n  const int32_t padW = ctx->Attr<int32_t>(\"pad_w\");\n  const int32_t padH = ctx->Attr<int32_t>(\"pad_h\");\n  const int32_t dilationW = ctx->Attr<int32_t>(\"dilation_w\");\n  const int32_t dilationH = ctx->Attr<int32_t>(\"dilation_h\");\n  const int32_t deformable_group = ctx->Attr<int32_t>(\"offset_groups\");\n  const bool use_mask = ctx->Attr<bool>(\"use_mask\");\n  bool has_bias = ctx->has_input(\"bias\", 0);\n  if (has_bias) {\n    const Shape& bias_shape = ctx->InputShape(\"bias\", 0);\n    std::cout << \"bias_shape:\" << bias_shape.ToString() << std::endl;\n    CHECK_EQ_OR_RETURN(bias_shape.At(0), weight_shape.At(0));\n  }\n  CHECK_OR_RETURN(dW > 0 && dH > 0)\n      << Error::RuntimeError() << \"The stride must be greater than 0,but got \" << dW << \" and \"\n      << dH;\n  CHECK_OR_RETURN(kW > 0 && kH > 0)\n      << Error::RuntimeError() << \"The weight must be greater than 0,but got \" << kW << \" and \"\n      << kH;\n\n  CHECK_OR_RETURN(padW >= 0 && padH >= 0)\n      << Error::RuntimeError() << \"The pad must be greater than or equal to 0,but got \" << padW\n      << \" and \" << padH;\n  CHECK_OR_RETURN(dilationW > 0 && dilationH > 0)\n      << Error::RuntimeError() << \"The dilation must be greater than 0,but got \" << dilationH\n      << \" and \" << dilationW;\n\n  CHECK_EQ_OR_RETURN(input_shape.NumAxes(), 4);                   // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(weight_shape.NumAxes(), 4);                  // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(offset_shape.NumAxes(), 4);                  // NOLINT(maybe-need-error-msg)\n  if (use_mask) { CHECK_EQ_OR_RETURN(mask_shape.NumAxes(), 4); }  // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(weight_shape.At(2), kH);                     // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(weight_shape.At(3), kW);                     // NOLINT(maybe-need-error-msg)\n\n  CHECK_EQ_OR_RETURN(offset_shape.At(1), deformable_group * 2 * kW * kH)\n      << Error::RuntimeError() << \"offset.shape[1] is not valid: got: \" << offset_shape.At(1)\n      << \" ,expected: \" << deformable_group * 2 * kW * kH;\n\n  if (use_mask) {\n    CHECK_EQ_OR_RETURN(mask_shape.At(1), deformable_group * kW * kH)\n        << Error::RuntimeError() << 
\"mask.shape[1] is not valid: got: \" << mask_shape.At(1)\n        << \" expected: \" << deformable_group * kW * kH;\n  }\n  CHECK_EQ_OR_RETURN(offset_shape.At(0), input_shape.At(0))\n      << Error::RuntimeError() << \"invalid batch size of offset:got: \" << offset_shape.At(0)\n      << \" ,expected: \" << input_shape.At(0);\n\n  int64_t outputWidth = (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;\n  int64_t outputHeight = (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;\n  CHECK_OR_RETURN(outputWidth > 0 && outputHeight > 0)\n      << Error::RuntimeError() << \"Calculated output size too small - out_h: \" << outputHeight\n      << \" ,out_w: \" << outputWidth;\n  CHECK_OR_RETURN(offset_shape.At(2) == outputHeight && offset_shape.At(3) == outputWidth)\n      << Error::RuntimeError() << \"invalid offset output dims: got ( \" << offset_shape.At(2) << \", \"\n      << offset_shape.At(3) << \")\"\n      << \",expected: \"\n      << \"(\" << outputHeight << \", \" << outputWidth << \")\";\n\n  if (use_mask) {\n    CHECK_OR_RETURN(mask_shape.At(2) == outputHeight && mask_shape.At(3) == outputWidth)\n        << Error::RuntimeError() << \"invalid mask output dims: got ( \" << mask_shape.At(2) << \", \"\n        << mask_shape.At(3) << \")\"\n        << \",expected: \"\n        << \"(\" << outputHeight << \", \" << outputWidth << \")\";\n  }\n  ctx->SetOutputShape(\"output\", 0,\n                      Shape({input_shape.At(0), weight_shape.At(0), outputHeight, outputWidth}));\n  ctx->SetOutputIsDynamic(\"output\", 0, ctx->InputIsDynamic(\"input\", 0));\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DeformConv2dInputGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const Shape& input_shape = ctx->InputShape(\"input\", 0);\n  const Shape& weight_shape = ctx->InputShape(\"weight\", 0);\n  const Shape& offset_shape = ctx->InputShape(\"offset\", 0);\n  const Shape& output_grad_shape = ctx->InputShape(\"output_grad\", 0);\n  const int32_t kW = weight_shape.at(3);\n  const int32_t kH = weight_shape.at(2);\n  const int32_t dW = ctx->Attr<int32_t>(\"stride_w\");\n  const int32_t dH = ctx->Attr<int32_t>(\"stride_h\");\n  const int32_t padW = ctx->Attr<int32_t>(\"pad_w\");\n  const int32_t padH = ctx->Attr<int32_t>(\"pad_h\");\n  const int32_t dilationW = ctx->Attr<int32_t>(\"dilation_w\");\n  const int32_t dilationH = ctx->Attr<int32_t>(\"dilation_h\");\n  const bool use_mask = ctx->Attr<bool>(\"use_mask\");\n  CHECK_EQ_OR_RETURN(weight_shape.NumAxes(), 4);\n\n  int64_t outputWidth = (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;\n  int64_t outputHeight = (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;\n  CHECK_EQ_OR_RETURN(output_grad_shape.At(2), outputHeight);        // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(output_grad_shape.At(3), outputWidth);         // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(output_grad_shape.At(1), weight_shape.At(0));  // NOLINT(maybe-need-error-msg)\n  ctx->SetOutputShape(\"input_grad\", 0, ctx->InputShape(\"input\", 0));\n  ctx->SetOutputShape(\"offset_grad\", 0, ctx->InputShape(\"offset\", 0));\n  ctx->SetOutputIsDynamic(\"input_grad\", 0, ctx->InputIsDynamic(\"input\", 0));\n  ctx->SetOutputIsDynamic(\"offset_grad\", 0, false);\n\n  if (use_mask) {\n    ctx->SetOutputShape(\"mask_grad\", 0,\n                        Shape({offset_shape.At(0), offset_shape.At(1) / 2, offset_shape.At(2),\n                               
offset_shape.At(3)}));\n    ctx->SetOutputIsDynamic(\"mask_grad\", 0, false);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DeformConv2dParamGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& input_shape = ctx->InputShape(\"input\", 0);\n  const Shape& output_grad_shape = ctx->InputShape(\"output_grad\", 0);\n  const Shape& weight_shape = ctx->InputShape(\"weight\", 0);\n  const int32_t kW = weight_shape.at(3);\n  const int32_t kH = weight_shape.at(2);\n  const int32_t dW = ctx->Attr<int32_t>(\"stride_w\");\n  const int32_t dH = ctx->Attr<int32_t>(\"stride_h\");\n  const int32_t padW = ctx->Attr<int32_t>(\"pad_w\");\n  const int32_t padH = ctx->Attr<int32_t>(\"pad_h\");\n  const int32_t dilationW = ctx->Attr<int32_t>(\"dilation_w\");\n  const int32_t dilationH = ctx->Attr<int32_t>(\"dilation_h\");\n  int64_t outputWidth = (input_shape.At(3) + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;\n  int64_t outputHeight = (input_shape.At(2) + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;\n  CHECK_EQ_OR_RETURN(output_grad_shape.At(2), outputHeight);  // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(output_grad_shape.At(3), outputWidth);   // NOLINT(maybe-need-error-msg)\n  ctx->SetOutputShape(\"weight_grad\", 0, ctx->InputShape(\"weight\", 0));\n  ctx->SetOutputIsDynamic(\"weight_grad\", 0, false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DeformConv2dOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> DeformConv2dInputGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> DeformConv2dParamGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> DeformConv2dOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"output\", 0, ctx->InputDType(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DeformConv2dInputGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"input_grad\", 0, ctx->InputDType(\"input\", 0));\n  ctx->SetOutputDType(\"offset_grad\", 0, ctx->InputDType(\"offset\", 0));\n  const bool use_mask = ctx->Attr<bool>(\"use_mask\");\n  if (use_mask) { ctx->SetOutputDType(\"mask_grad\", 0, ctx->InputDType(\"mask\", 0)); }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DeformConv2dParamGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"weight_grad\", 0, ctx->InputDType(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DeformConv2dOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"input\", 0), 0)\n      .Split(user_op::OpArg(\"offset\", 0), 0)\n      .Split(user_op::OpArg(\"mask\", 0), 0)\n      .Broadcast(user_op::OpArg(\"weight\", 0))\n      .Split(user_op::OpArg(\"output\", 0), 0)\n      .Build();\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DeformConv2dInputGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"output_grad\", 0), 0)\n      .Split(user_op::OpArg(\"input\", 0), 0)\n      .Split(user_op::OpArg(\"offset\", 0), 0)\n      .Split(user_op::OpArg(\"mask\", 0), 0)\n      .Broadcast(user_op::OpArg(\"weight\", 0))\n      .Split(user_op::OpArg(\"input_grad\", 0), 0)\n      .Split(user_op::OpArg(\"offset_grad\", 0), 0)\n      .Split(user_op::OpArg(\"mask_grad\", 0), 0)\n   
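// every per-sample tensor is split on the batch axis; only the shared weight stays broadcast\n   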
   .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DeformConv2dParamGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"output_grad\", 0), 0)\n      .Broadcast(user_op::OpArg(\"weight\", 0))\n      .Split(user_op::OpArg(\"input\", 0), 0)\n      .Split(user_op::OpArg(\"mask\", 0), 0)\n      .Split(user_op::OpArg(\"offset\", 0), 0)\n      .PartialSum(user_op::OpArg(\"weight_grad\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/depend_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> DependOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> DependOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> DependOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), i)\n        .Broadcast(user_op::OpArg(\"depend_tensor\", 0))\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .Broadcast(user_op::OpArg(\"depend_tensor\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DependOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/det_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> DetOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const auto& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  auto x_shape = x_desc.shape();\n  ctx->SetOutputShape(\"y\", 0, Shape(x_shape.begin(), x_shape.end() - 2));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> DetOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> DetOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x.shape().NumAxes() - 2) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Split(user_op::OpArg(\"y\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> DetOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/diag_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> DiagOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  const int32_t diagonal = ctx->Attr<int32_t>(\"diagonal\");\n  const ShapeView& in_shape = in.shape();\n  const int32_t in_dim = in_shape.NumAxes();\n  CHECK_GE_OR_RETURN(in_dim, 1);\n  CHECK_LE_OR_RETURN(in_dim, 2);\n\n  DimVector out_dim_vec = {0};\n  if (in_dim == 1) {\n    int32_t out_tensor_size = in_shape.At(0) + std::abs(diagonal);\n    out_dim_vec[0] = out_tensor_size;\n    out_dim_vec.emplace_back(out_tensor_size);\n  } else {\n    if (diagonal >= 0) {\n      out_dim_vec[0] = std::min(in_shape.At(0), in_shape.At(1) - diagonal);\n    } else {\n      out_dim_vec[0] = std::min(in_shape.At(0) + diagonal, in_shape.At(1));\n    }\n    // For 0-size Tensor.\n    CHECK_GE_OR_RETURN(out_dim_vec[0], 0);  // NOLINT\n  }\n\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_is_dynamic(false);\n  out_desc->set_shape(Shape(out_dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> DiagOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> DiagOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DiagOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DiagGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  const Shape& in_shape = in.shape();\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx_desc->set_shape(Shape(in_shape.dim_vec()));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> DiagGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> DiagGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DiagGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/diagonal_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> DiagonalOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  const int32_t offset = ctx->Attr<int32_t>(\"offset\");\n  const ShapeView& in_shape = in.shape();\n  const int32_t in_dim = in_shape.NumAxes();\n  CHECK_GE_OR_RETURN(in_dim, 2);\n\n  DimVector out_dim_vec = {};\n  FOR_RANGE(int32_t, index, 2, in_dim) { out_dim_vec.push_back(in_shape.At(index)); }\n  int32_t last_dim = 0;\n  if (offset >= 0) {\n    last_dim = std::min(in_shape.At(0), in_shape.At(1) - offset);\n  } else {\n    last_dim = std::min(in_shape.At(0) + offset, in_shape.At(1));\n  }\n  if (last_dim < 0) { last_dim = 0; }\n  out_dim_vec.push_back(last_dim);\n\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_is_dynamic(false);\n  out_desc->set_shape(Shape(out_dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> DiagonalOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> DiagonalOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DiagonalOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DiagonalGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  const Shape& in_shape = in.shape();\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx_desc->set_shape(Shape(in_shape.dim_vec()));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> DiagonalGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> DiagonalGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DiagonalGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/dim_gather_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/dim_gather_kernel_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> DimGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"input\", 0);\n  int64_t input_num_axes = in.shape().NumAxes();\n  // For 0-dim tensor\n  CHECK_GE_OR_RETURN(input_num_axes, 0);  // NOLINT\n  CHECK_LE_OR_RETURN(input_num_axes, kDimGatherMaxDimCount);\n\n  const user_op::TensorDesc& index = ctx->InputTensorDesc(\"index\", 0);\n  int64_t index_num_axes = index.shape().NumAxes();\n\n  const int32_t dim = ctx->Attr<int32_t>(\"dim\");\n  // For 0-dim tensor\n  CHECK_GE_OR_RETURN(dim, 0);\n  CHECK_LE_OR_RETURN(dim, input_num_axes);                                         // NOLINT\n  if (input_num_axes > 0) { CHECK_GE_OR_RETURN(input_num_axes, index_num_axes); }  // NOLINT\n\n  CHECK_EQ_OR_RETURN(in.is_dynamic(), index.is_dynamic());\n\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"output\", 0);\n  out->set_shape(index.shape());\n\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> DimGatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> DimGatherOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& index_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"index\", 0);\n  int64_t index_num_axes = index_tensor.shape().NumAxes();\n  const int32_t dim = ctx->Attr<int32_t>(\"dim\");\n\n  FOR_RANGE(int64_t, i, 0, index_num_axes) {\n    if (i != dim) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"index\", 0), i)\n          .Split(user_op::OpArg(\"input\", 0), i)\n          .Split(user_op::OpArg(\"output\", 0), i)\n          .Build();\n    } else if (i == dim) {\n      ctx->NewBuilder()\n          .Broadcast(user_op::OpArg(\"input\", 0))\n          .Split(user_op::OpArg(\"index\", 0), i)\n          .Split(user_op::OpArg(\"output\", 0), i)\n          .Build();\n    }\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"input\", 0))\n      .Broadcast(user_op::OpArg(\"index\", 0))\n      .PartialSum(user_op::OpArg(\"output\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DimGatherOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn(\"index\", 0);\n  CHECK_OR_RETURN(indices_modifier != nullptr);\n  indices_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DimGatherOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& index = ctx->InputTensorDesc(\"index\", 0);\n  CHECK_OR_RETURN(IsIndexDataType(index.data_type()));\n  const user_op::TensorDesc& 
in = ctx->InputTensorDesc(\"input\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"output\", 0);\n  out->set_data_type(in.data_type());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/dim_scatter_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/error.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/user_op_registry.h\"\n#include \"oneflow/user/kernels/dim_scatter_kernel_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\nMaybe<void> InferTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc* input =\n      ctx->has_input(\"input\", 0) ? &ctx->InputTensorDesc(\"input\", 0) : nullptr;\n  const user_op::TensorDesc& index = ctx->InputTensorDesc(\"index\", 0);\n  const user_op::TensorDesc* like =\n      ctx->has_input(\"like\", 0) ? &ctx->InputTensorDesc(\"like\", 0) : nullptr;\n  const user_op::TensorDesc& src = ctx->InputTensorDesc(\"src\", 0);\n\n  int32_t dim = ctx->Attr<int32_t>(\"dim\");\n\n  // check index.numaxes == src.num_axes == input/like.numaxes\n  int64_t src_num_axes = src.shape().NumAxes();\n  // For 0-dim Tensor\n  CHECK_GE_OR_RETURN(src_num_axes, 0);  // NOLINT\n  CHECK_LE_OR_RETURN(src_num_axes, user_op::kDimGatherMaxDimCount);\n  int64_t index_num_axes = index.shape().NumAxes();\n  CHECK_EQ_OR_RETURN(src_num_axes, index_num_axes);\n\n  int64_t output_num_axes = 0;\n  if (input) {\n    output_num_axes = input->shape().NumAxes();\n  } else if (like) {\n    output_num_axes = like->shape().NumAxes();\n  } else {\n    OF_UNIMPLEMENTED() << \"Input tensor and like tensor cannot be empty simultaneously.\";\n  }\n  // For 0-dim Tensor\n  if (output_num_axes != 0 && index_num_axes != 0) {\n    CHECK_EQ_OR_RETURN(output_num_axes, index_num_axes);  // NOLINT\n  } else if (output_num_axes != 0) {\n    CHECK_LE_OR_RETURN(output_num_axes, 1);  // NOLINT\n  } else {\n    CHECK_LE_OR_RETURN(index_num_axes, 1);  // NOLINT\n  }\n\n  // check index.shape(i) <= input/like.shape(i)\n  FOR_RANGE(int64_t, i, 0, index_num_axes) {\n    if (i == dim) continue;\n    if (input) {\n      CHECK_LE_OR_RETURN(index.shape().At(i), input->shape().At(i));\n    } else {\n      CHECK_LE_OR_RETURN(index.shape().At(i), like->shape().At(i));\n    }\n  }\n\n  // check index.shape(i) <= src.shape(i)\n  FOR_RANGE(int64_t, i, 0, index_num_axes) {\n    if (i == dim) continue;\n    CHECK_LE_OR_RETURN(index.shape().At(i), src.shape().At(i));\n  }\n\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"output\", 0);\n  out->set_shape(input ? 
input->shape() : like->shape());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferScalarTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& input = ctx->InputTensorDesc(\"input\", 0);\n  const user_op::TensorDesc& index = ctx->InputTensorDesc(\"index\", 0);\n\n  int32_t dim = ctx->Attr<int32_t>(\"dim\");\n\n  // check index.num_axes <= input.num_axes\n  int64_t output_num_axes = input.shape().NumAxes();\n  int64_t index_num_axes = index.shape().NumAxes();\n  // For 0-dim tensor\n  CHECK_GE_OR_RETURN(output_num_axes, index_num_axes);  // NOLINT\n\n  // check index.shape(i) <= input/like.shape(i)\n  FOR_RANGE(int64_t, i, 0, index_num_axes) {\n    if (i == dim) continue;\n    CHECK_LE_OR_RETURN(index.shape().At(i), input.shape().At(i));\n  }\n\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"output\", 0);\n  out->set_shape(input.shape());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                               const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn(\"index\", 0);\n  CHECK_OR_RETURN(indices_modifier != nullptr);\n  indices_modifier->set_requires_grad(false);\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InputScalarArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                     const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn(\"index\", 0);\n  CHECK_OR_RETURN(indices_modifier != nullptr);\n  indices_modifier->set_requires_grad(false);\n\n  return Maybe<void>::Ok();\n}\n\nvoid SetSbp_(user_op::SbpContext* ctx, const char* like_or_input) {\n  const int32_t dim = ctx->Attr<int32_t>(\"dim\");\n\n  const Shape& index_tensor_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"index\", 0).shape();\n  const Shape& src_tensor_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"src\", 0).shape();\n  const Shape& input_tensor_shape =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(like_or_input, 0).shape();\n\n  FOR_RANGE(int64_t, i, 0, index_tensor_shape.NumAxes()) {\n    if (i == dim) { continue; }\n    int64_t len = index_tensor_shape.At(i);\n    if (len == src_tensor_shape.At(i) && len == input_tensor_shape.At(i)) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"index\", 0), i)\n          .Split(user_op::OpArg(\"src\", 0), i)\n          .Split(user_op::OpArg(like_or_input, 0), i)\n          .Split(user_op::OpArg(\"output\", 0), i)\n          .Build();\n    }\n  }\n\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"src\", 0))\n      .Broadcast(user_op::OpArg(\"index\", 0))\n      .PartialSum(user_op::OpArg(\"output\", 0))\n      .PartialSum(user_op::OpArg(like_or_input, 0))\n      .Build();\n}\n\nMaybe<void> SetSbpLike(user_op::SbpContext* ctx) {\n  SetSbp_(ctx, \"like\");\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SetSbpScatter(user_op::SbpContext* ctx) {\n  SetSbp_(ctx, \"input\");\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SetSbpScatterScalar(user_op::SbpContext* ctx) {\n  const int32_t dim = ctx->Attr<int32_t>(\"dim\");\n\n  const Shape& index_tensor_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"index\", 0).shape();\n  const Shape& input_tensor_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"input\", 0).shape();\n\n  FOR_RANGE(int64_t, i, 0, index_tensor_shape.NumAxes()) {\n    if (i == dim) { continue; }\n    if (index_tensor_shape.At(i) == 
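// the scalar variant has no src tensor, so a split only needs the index and input extents to agree\n        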
input_tensor_shape.At(i)) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"index\", 0), i)\n          .Split(user_op::OpArg(\"input\", 0), i)\n          .Split(user_op::OpArg(\"output\", 0), i)\n          .Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDtype(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& index = ctx->InputTensorDesc(\"index\", 0);\n  CHECK_OR_RETURN(IsIndexDataType(index.data_type()));\n  if (ctx->has_input(\"input\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"input\", 0), ctx->InputDType(\"src\", 0))\n        << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"src\", 0))\n        << \", but got \" << DataType_Name(ctx->InputDType(\"input\", 0));\n  } else {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"like\", 0), ctx->InputDType(\"src\", 0))\n        << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"like\", 0))\n        << \", but got \" << DataType_Name(ctx->InputDType(\"src\", 0));\n  }\n  ctx->SetOutputDType(\"output\", 0, ctx->InputDType(\"src\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferScalarDtype(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& index = ctx->InputTensorDesc(\"index\", 0);\n  CHECK_OR_RETURN(IsIndexDataType(index.data_type()));\n  ctx->SetOutputDType(\"output\", 0, ctx->InputDType(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> DimScatterAddLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> DimScatterAddLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> DimScatterAddLikeOp::GetSbp(user_op::SbpContext* ctx) {\n  return SetSbpLike(ctx);\n}\n\n/* static */ Maybe<void> DimScatterAddLikeOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return InputArgModifierFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> DimScatterAddLikeOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDtype(ctx);\n}\n\n#define DEF_SCATTER_OP(op_class_name)                                                             \\\n  /* static */ Maybe<void> op_class_name::InferLogicalTensorDesc(user_op::InferContext* ctx) {    \\\n    return InferTensorDesc(ctx);                                                                  \\\n  }                                                                                               \\\n                                                                                                  \\\n  /*static*/ Maybe<void> op_class_name::InferPhysicalTensorDesc(user_op::InferContext* ctx) {     \\\n    return InferLogicalTensorDesc(ctx);                                                           \\\n  }                                                                                               \\\n                                                                                                  \\\n  /* static */ Maybe<void> op_class_name::GetSbp(user_op::SbpContext* ctx) {                      \\\n    return SetSbpScatter(ctx);                                                                    \\\n  }                                                                                               \\\n                                                                                                  \\\n  /* static */ Maybe<void> 
op_class_name::ModifyInputArg(                                         \\\n      const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { \\\n    return InputArgModifierFn(GetInputArgModifierFn, conf);                                       \\\n  }                                                                                               \\\n                                                                                                  \\\n  /* static */ Maybe<void> op_class_name::InferDataType(user_op::InferContext* ctx) {             \\\n    return InferDtype(ctx);                                                                       \\\n  }\n\n#define DEF_SCATTER_SCALAR_OP(optypename)                                                         \\\n  /* static */ Maybe<void> optypename::InferLogicalTensorDesc(user_op::InferContext* ctx) {       \\\n    return InferScalarTensorDesc(ctx);                                                            \\\n  }                                                                                               \\\n                                                                                                  \\\n  /*static*/ Maybe<void> optypename::InferPhysicalTensorDesc(user_op::InferContext* ctx) {        \\\n    return InferLogicalTensorDesc(ctx);                                                           \\\n  }                                                                                               \\\n                                                                                                  \\\n  /* static */ Maybe<void> optypename::GetSbp(user_op::SbpContext* ctx) {                         \\\n    return SetSbpScatterScalar(ctx);                                                              \\\n  }                                                                                               \\\n                                                                                                  \\\n  /* static */ Maybe<void> optypename::ModifyInputArg(                                            \\\n      const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { \\\n    return InputScalarArgModifierFn(GetInputArgModifierFn, conf);                                 \\\n  }                                                                                               \\\n                                                                                                  \\\n  /* static */ Maybe<void> optypename::InferDataType(user_op::InferContext* ctx) {                \\\n    return InferScalarDtype(ctx);                                                                 \\\n  }\n\nDEF_SCATTER_OP(DimScatterAddOp);\nDEF_SCATTER_OP(DimScatterUpdateOp);\nDEF_SCATTER_OP(DimScatterMulOp);\n\nDEF_SCATTER_SCALAR_OP(DimScatterUpdateScalarOp);\nDEF_SCATTER_SCALAR_OP(DimScatterAddScalarOp);\nDEF_SCATTER_SCALAR_OP(DimScatterMulScalarOp);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/distributions/exponential_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> ExponentialOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& shape = ctx->Attr<Shape>(\"out_shape\");\n  DimVector dim_vec;\n  if (shape.NumAxes() > 0) {\n    dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend());\n  }\n  ctx->SetOutputShape(\"out\", 0, Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ExponentialOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();\n  const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  const Shape& logical_shape = ctx->Attr<Shape>(\"out_shape\");\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n  const auto tensor_slice_view =\n      GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id);\n  const Shape& physical_shape = tensor_slice_view.shape();\n\n  ctx->SetOutputShape(\"out\", 0, physical_shape);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ExponentialOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ExponentialOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  SbpParallel default_sbp;\n  default_sbp.mutable_broadcast_parallel();\n  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);\n}\n\n/* static */ Maybe<void> ExponentialOp::InferDataType(user_op::InferContext* ctx) {\n  auto dtype = ctx->Attr<DataType>(\"dtype\");\n  ctx->SetOutputDType(\"out\", 0, dtype);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/distributions/multinomial_with_replacement_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> MultinomialWithReplacementOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  int32_t num_samples = ctx->Attr<int32_t>(\"num_samples\");\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  if (x_shape.NumAxes() == 1) {\n    ctx->SetOutputShape(\"out\", 0, Shape({num_samples}));\n  } else {\n    ctx->SetOutputShape(\"out\", 0, Shape({x_shape.At(0), num_samples}));\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> MultinomialWithReplacementOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MultinomialWithReplacementOp::GetSbp(user_op::SbpContext* ctx) {\n  const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape();\n  if (x_shape.NumAxes() == 2) {\n    ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MultinomialWithReplacementOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, DataType::kInt64);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/distributions/normal_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> NormalOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->Attr<Shape>(\"shape\"));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> NormalOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();\n  const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  const Shape& logical_shape = ctx->Attr<Shape>(\"shape\");\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n  const auto tensor_slice_view =\n      GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id);\n  const Shape& physical_shape = tensor_slice_view.shape();\n\n  ctx->SetOutputShape(\"out\", 0, physical_shape);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> NormalOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> NormalOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  SbpParallel default_sbp;\n  default_sbp.mutable_broadcast_parallel();\n  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);\n}\n\n/* static */ Maybe<void> NormalOp::InferDataType(user_op::InferContext* ctx) {\n  auto dtype = ctx->Attr<DataType>(\"dtype\");\n  ctx->SetOutputDType(\"out\", 0, dtype);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/distributions/uniform_int_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> UniformIntOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& shape = ctx->Attr<Shape>(\"shape\");\n  DimVector dim_vec;\n  if (shape.NumAxes() > 0) {\n    dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend());\n  }\n  ctx->SetOutputShape(\"out\", 0, Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UniformIntOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();\n  const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  const Shape& logical_shape = ctx->Attr<Shape>(\"shape\");\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n  const auto tensor_slice_view =\n      GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id);\n  const Shape& physical_shape = tensor_slice_view.shape();\n\n  ctx->SetOutputShape(\"out\", 0, physical_shape);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> UniformIntOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> UniformIntOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  SbpParallel default_sbp;\n  default_sbp.mutable_broadcast_parallel();\n  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);\n}\n\n/* static */ Maybe<void> UniformIntOp::InferDataType(user_op::InferContext* ctx) {\n  auto dtype = ctx->Attr<DataType>(\"dtype\");\n  ctx->SetOutputDType(\"out\", 0, dtype);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/distributions/uniform_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> UniformOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& shape = ctx->Attr<Shape>(\"shape\");\n  DimVector dim_vec;\n  if (shape.NumAxes() > 0) {\n    dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend());\n  }\n  ctx->SetOutputShape(\"out\", 0, Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UniformOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();\n  const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  const Shape& logical_shape = ctx->Attr<Shape>(\"shape\");\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n  const auto tensor_slice_view =\n      GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id);\n  const Shape& physical_shape = tensor_slice_view.shape();\n\n  ctx->SetOutputShape(\"out\", 0, physical_shape);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> UniformOp::GetSbp(user_op::SbpContext* ctx) {\n  const Shape& logical_shape = ctx->Attr<Shape>(\"shape\");\n  int64_t num_axes = logical_shape.NumAxes();\n  for (int i = 0; i < num_axes; ++i) {\n    ctx->NewBuilder().Broadcast(ctx->inputs()).Split(ctx->outputs(), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> UniformOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  SbpParallel default_sbp;\n  default_sbp.mutable_broadcast_parallel();\n  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);\n}\n\n/* static */ Maybe<void> UniformOp::InferDataType(user_op::InferContext* ctx) {\n  auto dtype = ctx->Attr<DataType>(\"dtype\");\n  ctx->SetOutputDType(\"out\", 0, dtype);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> UniformOp::DumpNdSbpSignatureForOpConfFn(const NdSbpSignature& nd_sbp_sig,\n                                                                  OperatorConf* op_conf) {\n  return user_op::SetSrcOpNdSbp(nd_sbp_sig, \"out_0\", op_conf);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/dot_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> DotOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& y = ctx->InputTensorDesc(\"y\", 0);\n  CHECK_OR_RETURN(x.shape() == y.shape())\n      << Error::RuntimeError()\n      << \"inconsistent tensor size, expected tensor to have the same number of elements, but got \"\n      << x.shape().elem_cnt() << \" and \" << y.shape().elem_cnt() << \" elements respectively\";\n  CHECK_OR_RETURN(x.shape().NumAxes() == 1)\n      << Error::RuntimeError() << \"1D tensors expected, but got \" << x.shape().NumAxes()\n      << \"D tensors\";\n  ctx->SetOutputShape(\"out\", 0, Shape({}));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> DotOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> DotOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Split(user_op::OpArg(\"y\", 0), 0)\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DotOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& y = ctx->InputTensorDesc(\"y\", 0);\n  CHECK_OR_RETURN(x.data_type() == y.data_type())\n      << Error::RuntimeError() << \"expected both vectors to have same dtype, but found \"\n      << DataType_Name(x.data_type()) << \" and \" << DataType_Name(y.data_type());\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/dropout_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> DropoutOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"in\", 0);\n  ctx->SetOutputShape(\"out\", 0, in_shape);\n  ctx->SetOutputShape(\"mask\", 0, in_shape);\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> DropoutOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> DropoutOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DropoutOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                              const user_op::UserOpConfWrapper& conf) {\n  float rate = conf.attr<float>(\"rate\");\n  CHECK_GE_OR_RETURN(rate, 0.0);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DropoutOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  ctx->SetOutputDType(\"mask\", 0, DataType::kBool);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DropoutGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  ctx->SetOutputIsDynamic(\"dx\", 0, ctx->InputIsDynamic(\"dy\", 0));\n  CHECK_EQ_OR_RETURN(ctx->InputShape(\"mask\", 0), dy_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> DropoutGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> DropoutGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"dy\", 0);\n  FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"dy\", 0), axis)\n        .Split(user_op::OpArg(\"mask\", 0), axis)\n        .Split(user_op::OpArg(\"dx\", 0), axis)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DropoutGradOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                                  const user_op::UserOpConfWrapper& conf) {\n  float scale = conf.attr<float>(\"scale\");\n  CHECK_GT_OR_RETURN(scale, 1);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DropoutGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  
CHECK_EQ_OR_RETURN(ctx->InputDType(\"mask\", 0), DataType::kBool)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kBool) << \", but got \"\n      << DataType_Name(ctx->InputDType(\"mask\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> RandomMaskLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"like\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> RandomMaskLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> RandomMaskLikeOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& like_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"like\", 0);\n  FOR_RANGE(int64_t, axis, 0, like_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"like\", 0), axis)\n        .Split(user_op::OpArg(\"out\", 0), axis)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> RandomMaskLikeOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                                     const user_op::UserOpConfWrapper& conf) {\n  float rate = conf.attr<float>(\"rate\");\n  CHECK_GE_OR_RETURN(rate, 0);\n  CHECK_LT_OR_RETURN(rate, 1);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> RandomMaskLikeOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, DataType::kBool);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool IsScalarTensor(const user_op::TensorDesc* desc) {\n  return desc->shape().NumAxes() == 1 && desc->shape().At(0) == 1;\n}\n\nbool IsTensorWithType(const user_op::TensorDesc* desc, DataType data_type) {\n  return desc->data_type() == data_type;\n}\n\n}  // namespace\n\n/* static */ Maybe<void> DynamicLossScaleScheduleOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  CHECK_OR_RETURN(IsScalarTensor(&(ctx->InputTensorDesc(\"count_not_finite\", 0))));\n  CHECK_OR_RETURN(IsScalarTensor(&(ctx->InputTensorDesc(\"loss_scale\", 0))));\n  CHECK_OR_RETURN(IsScalarTensor(&(ctx->InputTensorDesc(\"good_step_counter\", 0))));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> DynamicLossScaleScheduleOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> DynamicLossScaleScheduleOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> DynamicLossScaleScheduleOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* loss_scale = GetInputArgModifierFn(\"loss_scale\", 0);\n  CHECK_OR_RETURN(loss_scale != nullptr);\n  loss_scale->set_is_mutable(true);\n  user_op::InputArgModifier* good_step_counter = GetInputArgModifierFn(\"good_step_counter\", 0);\n  CHECK_OR_RETURN(good_step_counter != nullptr);\n  good_step_counter->set_is_mutable(true);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> DynamicLossScaleScheduleOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_OR_RETURN(\n      IsTensorWithType(&(ctx->InputTensorDesc(\"count_not_finite\", 0)), DataType::kInt64));\n  CHECK_OR_RETURN(IsTensorWithType(&(ctx->InputTensorDesc(\"loss_scale\", 0)), DataType::kFloat));\n  CHECK_OR_RETURN(\n      IsTensorWithType(&(ctx->InputTensorDesc(\"good_step_counter\", 0)), DataType::kInt64));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/eager_b_to_s_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/user/ops/comm_net_device_infer_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n// Can only be called in local TODO: move this comment to ods\n/* static */ Maybe<void> EagerBToSOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& shape = ctx->Attr<Shape>(\"shape\");\n  const std::string& out_parallel_conf_txt = ctx->Attr<std::string>(\"out_parallel_conf\");\n  const int64_t out_split_axis = ctx->Attr<int64_t>(\"out_split_axis\");\n  Symbol<ParallelDesc> out_parallel_desc = JUST(TxtStringToPlacement(out_parallel_conf_txt));\n  DimVector dim_vec{shape.dim_vec()};\n  int64_t out_parallel_num = out_parallel_desc->parallel_num();\n  if (out_parallel_num > 1) {\n    CHECK_LT_OR_RETURN(out_split_axis, shape.NumAxes());\n    BalancedSplitter bs(shape.At(out_split_axis), out_parallel_num);\n    const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_parallel_desc));\n    int64_t parallel_id = opt_parallel_id->value_or(0);\n    dim_vec[out_split_axis] = bs.At(parallel_id).size();\n  }\n  ctx->SetOutputShape(\"out\", 0, Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EagerBToSOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EagerBToSOp::GetSbp(user_op::SbpContext* ctx) {\n  return Error::TypeError() << \"eager_b_to_s op doesn't support global tensor!\";\n}\n\n/* static */ Maybe<void> EagerBToSOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  return Error::TypeError() << \"eager_b_to_s op doesn't support global tensor!\";\n}\n\n/* static */ Maybe<void> EagerBToSOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> EagerBToSOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/eager_ccl_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/container_util.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/user/ops/comm_net_device_infer_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> EagerCclAllReduceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EagerCclAllReduceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EagerCclAllReduceOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().PartialSum(user_op::OpArg(\"in\", 0)).Broadcast(user_op::OpArg(\"out\", 0)).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EagerCclAllReduceOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> EagerCclAllReduceOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n/* static */ Maybe<void> EagerCclBroadcastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  size_t size = ctx->input_size(\"in\");\n  const std::vector<Shape>& shape_list = ctx->Attr<std::vector<Shape>>(\"shape_list\");\n  CHECK_EQ_OR_RETURN(size, ctx->output_size(\"out\"))\n      << \"the size of input tensor tuple should equal the size of output tensor tuple.\";\n  for (int i = 0; i < size; ++i) { ctx->SetOutputShape(\"out\", i, JUST(VectorAt(shape_list, i))); }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EagerCclBroadcastOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EagerCclBroadcastOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().PartialSum(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  ctx->NewBuilder().Split(ctx->inputs(), 0).Broadcast(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EagerCclBroadcastOp::InferDataType(user_op::InferContext* ctx) {\n  size_t size = ctx->input_size(\"in\");\n  CHECK_EQ_OR_RETURN(size, ctx->output_size(\"out\"))\n      << \"the size of input tensor tuple should equal the size of output tensor tuple.\";\n  for (int i = 0; i < size; ++i) { ctx->SetOutputDType(\"out\", i, ctx->InputDType(\"in\", i)); }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> EagerCclBroadcastOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return 
DeviceAndStreamInferFn(ctx);\n}\n\n/* static */ Maybe<void> EagerCclTouchOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EagerCclTouchOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EagerCclTouchOp::GetSbp(user_op::SbpContext* ctx) {\n  // local only\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EagerCclTouchOp::InferDataType(user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> EagerCclTouchOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n/* static */ Maybe<void> EagerCclReduceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EagerCclReduceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EagerCclReduceOp::GetSbp(user_op::SbpContext* ctx) {\n  UNIMPLEMENTED_THEN_RETURN() << \"global tensors are not supported\";\n}\n\n/* static */ Maybe<void> EagerCclReduceOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> EagerCclReduceOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n/* static */ Maybe<void> EagerCclReduceScatterOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EagerCclReduceScatterOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const auto& input_shape = ctx->InputShape(\"in\", 0);\n  const auto& shape = ctx->Attr<Shape>(\"output_shape\");\n  Symbol<ParallelDesc> parallel_desc =\n      JUST(TxtStringToPlacement(ctx->Attr<std::string>(\"parallel_conf\")));\n  CHECK_EQ_OR_RETURN(input_shape.elem_cnt(), shape.elem_cnt() * parallel_desc->parallel_num())\n      << Error::RuntimeError()\n      << \"input tensor size must be equal to world_size times output tensor size\";\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"in\", 0), ctx->Attr<DataType>(\"output_dtype\"))\n      << Error::RuntimeError() << \"output tensor must have the same type as input tensor\";\n  ctx->SetOutputShape(\"out\", 0, ctx->Attr<Shape>(\"output_shape\"));\n  ctx->SetOutputDType(\"out\", 0, ctx->Attr<DataType>(\"output_dtype\"));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EagerCclReduceScatterOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> EagerCclReduceScatterOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  const NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex(\"in\", 0);\n  NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1);\n  for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) {\n    CHECK_OR_RETURN(sbp_hint.has_partial_sum_parallel() || sbp_hint.has_broadcast_parallel());\n  }\n  in_nd_sbp->clear_sbp_parallel();\n  out_nd_sbp->clear_sbp_parallel();\n\n  // P2S or B2S\n  const Shape& 
parallel_hierarchy = ctx->parallel_hierarchy();\n  CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1);\n  in_nd_sbp->CopyFrom(in_dis_hint);\n  for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) {\n    out_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(0);\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EagerCclReduceScatterOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> EagerCclReduceScatterOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n/* static */ Maybe<void> EagerCclAllGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EagerCclAllGatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  const auto& input_shape = ctx->InputShape(\"in\", 0);\n  const auto& shape = ctx->Attr<Shape>(\"output_shape\");\n  Symbol<ParallelDesc> parallel_desc =\n      JUST(TxtStringToPlacement(ctx->Attr<std::string>(\"parallel_conf\")));\n  CHECK_EQ_OR_RETURN(input_shape.elem_cnt() * parallel_desc->parallel_num(), shape.elem_cnt())\n      << Error::RuntimeError()\n      << \"output tensor size must be equal to world_size times input tensor size\";\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"in\", 0), ctx->Attr<DataType>(\"output_dtype\"))\n      << Error::RuntimeError()\n      << \"output tensor must have the same type as input tensor\";\n  ctx->SetOutputShape(\"out\", 0, ctx->Attr<Shape>(\"output_shape\"));\n  ctx->SetOutputDType(\"out\", 0, ctx->Attr<DataType>(\"output_dtype\"));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EagerCclAllGatherOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> EagerCclAllGatherOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  const NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex(\"in\", 0);\n  NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1);\n  for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) {\n    CHECK_OR_RETURN(sbp_hint.has_split_parallel());\n    CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), 0);\n  }\n\n  in_nd_sbp->clear_sbp_parallel();\n  out_nd_sbp->clear_sbp_parallel();\n\n  // S(0)->B\n  const Shape& parallel_hierarchy = ctx->parallel_hierarchy();\n  CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1);\n  for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) {\n    in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(0);\n    out_nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EagerCclAllGatherOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> EagerCclAllGatherOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n/* static */ Maybe<void> 
EagerCclS2SOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EagerCclS2SOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> EagerCclS2SOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  const int64_t in_split_axis = ctx->user_op_conf().attr<int64_t>(\"in_split_axis\");\n  const int64_t out_split_axis = ctx->user_op_conf().attr<int64_t>(\"out_split_axis\");\n  const NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex(\"in\", 0);\n  NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1);\n  for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) {\n    CHECK_OR_RETURN(sbp_hint.has_split_parallel());\n    CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis);\n  }\n\n  in_nd_sbp->clear_sbp_parallel();\n  out_nd_sbp->clear_sbp_parallel();\n\n  // S(in)->S(out)\n  const Shape& parallel_hierarchy = ctx->parallel_hierarchy();\n  CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1);\n  for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) {\n    in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis);\n    out_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(out_split_axis);\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EagerCclS2SOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> EagerCclS2SOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/eager_p_to_b_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/user/ops/comm_net_device_infer_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n// Can only be called in local\n/* static */ Maybe<void> EagerPToBOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, Shape(ctx->Attr<Shape>(\"shape\").dim_vec()));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EagerPToBOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EagerPToBOp::GetSbp(user_op::SbpContext* ctx) {\n  return Error::TypeError() << \"eager_s_to_b op doesn't support global tensor!\";\n}\n\n/* static */ Maybe<void> EagerPToBOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  return Error::TypeError() << \"eager_s_to_b op doesn't support global tensor!\";\n}\n\n/* static */ Maybe<void> EagerPToBOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> EagerPToBOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/eager_p_to_s_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/user/ops/comm_net_device_infer_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> EagerPToSOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& shape = ctx->Attr<Shape>(\"shape\");\n  const std::string& out_parallel_conf_txt = ctx->Attr<std::string>(\"out_parallel_conf\");\n  const int64_t out_split_axis = ctx->Attr<int64_t>(\"out_split_axis\");\n  Symbol<ParallelDesc> out_parallel_desc = JUST(TxtStringToPlacement(out_parallel_conf_txt));\n  DimVector dim_vec{shape.dim_vec()};\n  int64_t out_parallel_num = out_parallel_desc->parallel_num();\n  if (out_parallel_num > 1) {\n    CHECK_LT_OR_RETURN(out_split_axis, shape.NumAxes());\n    BalancedSplitter bs(shape.At(out_split_axis), out_parallel_num);\n    const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_parallel_desc));\n    int64_t parallel_id = opt_parallel_id->value_or(0);\n    dim_vec[out_split_axis] = bs.At(parallel_id).size();\n  }\n  ctx->SetOutputShape(\"out\", 0, Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EagerPToSOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EagerPToSOp::GetSbp(user_op::SbpContext* ctx) {\n  return Error::TypeError() << \"eager_b_to_s op doesn't support global tensor!\";\n}\n\n/* static */ Maybe<void> EagerPToSOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  return Error::TypeError() << \"eager_b_to_s op doesn't support global tensor!\";\n}\n\n/* static */ Maybe<void> EagerPToSOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> EagerPToSOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/eager_s_to_b_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/user/ops/comm_net_device_infer_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> EagerSToBOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, Shape(ctx->Attr<Shape>(\"shape\").dim_vec()));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EagerSToBOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EagerSToBOp::GetSbp(user_op::SbpContext* ctx) {\n  return Error::TypeError() << \"eager_s_to_b op doesn't support global tensor!\";\n}\n\n/* static */ Maybe<void> EagerSToBOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  return Error::TypeError() << \"eager_s_to_b op doesn't support global tensor!\";\n}\n\n/* static */ Maybe<void> EagerSToBOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> EagerSToBOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/eager_s_to_p_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/user/ops/comm_net_device_infer_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> EagerSToPOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, Shape(ctx->Attr<Shape>(\"shape\").dim_vec()));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EagerSToPOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EagerSToPOp::GetSbp(user_op::SbpContext* ctx) {\n  return Error::TypeError() << \"eager_b_to_s op doesn't support global tensor!\";\n}\n\n/* static */ Maybe<void> EagerSToPOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  return Error::TypeError() << \"eager_b_to_s op doesn't support global tensor!\";\n}\n\n/* static */ Maybe<void> EagerSToPOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> EagerSToPOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/eager_s_to_s_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/common/decorator.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/user/ops/comm_net_device_infer_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> EagerNaiveSToSOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& shape = ctx->Attr<Shape>(\"shape\");\n  const std::string& out_parallel_conf_txt = ctx->Attr<std::string>(\"out_parallel_conf\");\n  const int64_t out_split_axis = ctx->Attr<int64_t>(\"out_split_axis\");\n  Symbol<ParallelDesc> out_parallel_desc = JUST(TxtStringToPlacement(out_parallel_conf_txt));\n  DimVector dim_vec{shape.dim_vec()};\n  int64_t out_parallel_num = out_parallel_desc->parallel_num();\n  if (out_parallel_num > 1) {\n    CHECK_LE_OR_RETURN(out_split_axis, shape.NumAxes());\n    BalancedSplitter bs(shape.At(out_split_axis), out_parallel_num);\n    const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_parallel_desc));\n    int64_t parallel_id = opt_parallel_id->value_or(0);\n    dim_vec[out_split_axis] = bs.At(parallel_id).size();\n  }\n  ctx->SetOutputShape(\"out\", 0, Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EagerNaiveSToSOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EagerNaiveSToSOp::GetSbp(user_op::SbpContext* ctx) {\n  return Error::TypeError() << \"eager_naive_s_to_s op doesn't support global tensor!\";\n}\n\n/* static */ Maybe<void> EagerNaiveSToSOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  return Error::TypeError() << \"eager_naive_s_to_s op doesn't support global tensor!\";\n}\n\n/* static */ Maybe<void> EagerNaiveSToSOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> EagerNaiveSToSOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/eager_symmetric_s_to_p_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job/parallel_desc.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/user/ops/comm_net_device_infer_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> EagerSymmetricSToPOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EagerSymmetricSToPOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EagerSymmetricSToPOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), i)\n        .PartialSum(user_op::OpArg(\"out\", 0))\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EagerSymmetricSToPOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  const int64_t in_split_axis = ctx->user_op_conf().attr<int64_t>(\"in_split_axis\");\n  const NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex(\"in\", 0);\n  NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1);\n  for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) {\n    CHECK_OR_RETURN(sbp_hint.has_split_parallel());\n    CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis);\n  }\n\n  in_nd_sbp->clear_sbp_parallel();\n  out_nd_sbp->clear_sbp_parallel();\n\n  const Shape& parallel_hierarchy = ctx->parallel_hierarchy();\n  CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1);\n  for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) {\n    in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis);\n    out_nd_sbp->add_sbp_parallel()->mutable_partial_sum_parallel();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EagerSymmetricSToPOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> EagerSymmetricSToPOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/elementwise_maximum_minimum_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\nusing namespace user_op;\n\nMaybe<void> GetSbpSignature_(SbpContext* ctx) {\n  const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape();\n  const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"y\", 0).shape();\n\n  FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) {\n    if (x_shape.At(i) == 1 && y_shape.At(i) == 1) { continue; }\n    if (x_shape.At(i) == y_shape.At(i)) {\n      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferTensorDesc_(InferContext* ctx) {\n  const TensorDesc& tensor_x = ctx->InputTensorDesc(\"x\", 0);\n  const TensorDesc& tensor_y = ctx->InputTensorDesc(\"y\", 0);\n\n  CHECK_EQ_OR_RETURN(tensor_x.shape().NumAxes(), tensor_y.shape().NumAxes())\n      << \"Shape of tensor x and y should be same\";\n\n  FOR_RANGE(int64_t, i, 0, tensor_x.shape().NumAxes()) {\n    CHECK_EQ_OR_RETURN(tensor_x.shape().At(i), tensor_y.shape().At(i));\n  }\n\n  TensorDesc* tensor_dx = ctx->MutOutputTensorDesc(\"dx\", 0);\n  TensorDesc* tensor_dy = ctx->MutOutputTensorDesc(\"dy\", 0);\n\n  if (tensor_dx) { tensor_dx->set_shape(tensor_x.shape()); }\n\n  if (tensor_dy) { tensor_dy->set_shape(tensor_y.shape()); }\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType_(InferContext* ctx) {\n  const TensorDesc& tensor_dz = ctx->InputTensorDesc(\"dz\", 0);\n  TensorDesc* tensor_dx = ctx->MutOutputTensorDesc(\"dx\", 0);\n  TensorDesc* tensor_dy = ctx->MutOutputTensorDesc(\"dy\", 0);\n\n  if (tensor_dx) { tensor_dx->set_data_type(tensor_dz.data_type()); }\n\n  if (tensor_dy) { tensor_dy->set_data_type(tensor_dz.data_type()); }\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n#define DEF_ELEMENTWISE_XIMUM_FW_OP(op_class_name_prefix)                                        \\\n  /* static */ Maybe<void> op_class_name_prefix##Op::InferLogicalTensorDesc(                     \\\n      user_op::InferContext* ctx) {                                                              \\\n    return user_op::TensorDescInferFnUtil::Unchanged(ctx);                                       \\\n  }                                                                                              \\\n                                                                                                 \\\n  /*static*/ Maybe<void> op_class_name_prefix##Op::InferPhysicalTensorDesc(                      \\\n      user_op::InferContext* ctx) {                                                              \\\n    return InferLogicalTensorDesc(ctx);                                                          \\\n  }                                                                                              \\\n             
                                                                                    \\\n  /* static */ Maybe<void> op_class_name_prefix##Op::GetSbp(user_op::SbpContext* ctx) {          \\\n    return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);                                         \\\n  }                                                                                              \\\n                                                                                                 \\\n  /* static */ Maybe<void> op_class_name_prefix##Op::InferDataType(user_op::InferContext* ctx) { \\\n    return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx);                               \\\n  }\n\n#define DEF_ELEMENTWISE_XIMUM_BW_OP(op_class_name_prefix)                                       \\\n  /* static */ Maybe<void> op_class_name_prefix##BackwardOp::InferLogicalTensorDesc(            \\\n      user_op::InferContext* ctx) {                                                             \\\n    return InferTensorDesc_(ctx);                                                               \\\n  }                                                                                             \\\n                                                                                                \\\n  /*static*/ Maybe<void> op_class_name_prefix##BackwardOp::InferPhysicalTensorDesc(             \\\n      user_op::InferContext* ctx) {                                                             \\\n    return InferLogicalTensorDesc(ctx);                                                         \\\n  }                                                                                             \\\n                                                                                                \\\n  /* static */ Maybe<void> op_class_name_prefix##BackwardOp::GetSbp(user_op::SbpContext* ctx) { \\\n    return GetSbpSignature_(ctx);                                                               \\\n  }                                                                                             \\\n                                                                                                \\\n  /* static */ Maybe<void> op_class_name_prefix##BackwardOp::InferDataType(                     \\\n      user_op::InferContext* ctx) {                                                             \\\n    return InferDataType_(ctx);                                                                 \\\n  }\n\n#define REGISTER_ELEMENTWISE_XIMUM_OP(op_type_name, op_class_name_prefix) \\\n  DEF_ELEMENTWISE_XIMUM_FW_OP(op_class_name_prefix);                      \\\n  DEF_ELEMENTWISE_XIMUM_BW_OP(op_class_name_prefix);\n\nREGISTER_ELEMENTWISE_XIMUM_OP(\"elementwise_maximum\", ElementwiseMaximum);\nREGISTER_ELEMENTWISE_XIMUM_OP(\"elementwise_minimum\", ElementwiseMinimum);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/elu_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> EluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EluOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EluOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == x_shape);\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EluGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EluGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"dy\", 0), ctx->InputDType(\"x\", 0))\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"dy\", 0))\n      << \", but got \" << DataType_Name(ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/embedding_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> EmbeddingRenormOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EmbeddingRenormOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> EmbeddingRenormOp::GetSbp(user_op::SbpContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EmbeddingRenormOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EmbeddingOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& weight_shape = ctx->InputShape(\"weight\", 0);\n  const Shape& indices_shape = ctx->InputShape(\"indices\", 0);\n\n  DimVector out_dim_vec;\n  out_dim_vec.insert(out_dim_vec.end(), indices_shape.dim_vec().cbegin(),\n                     indices_shape.dim_vec().cend());\n  out_dim_vec.push_back(weight_shape.At(1));\n\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_shape(Shape(out_dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EmbeddingOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> EmbeddingOp::GetSbp(user_op::SbpContext* ctx) {\n  const int64_t indices_num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"indices\", 0).shape().NumAxes();\n  const bool scale_grad_by_freq = ctx->Attr<bool>(\"scale_grad_by_freq\");\n\n  if (!scale_grad_by_freq) {\n    FOR_RANGE(int64_t, i, 0, indices_num_axes) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"indices\", 0), i)\n          .Broadcast(user_op::OpArg(\"weight\", 0))\n          .Split(user_op::OpArg(\"out\", 0), i)\n          .Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EmbeddingOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"weight\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EmbeddingOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn(\"indices\", 0);\n  CHECK_OR_RETURN(indices_modifier != nullptr);  // NOLINT(maybe-need-error-msg)\n  indices_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EmbeddingGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& weight_shape = ctx->InputShape(\"weight\", 0);\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx_desc->set_shape(weight_shape);\n\n  
return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EmbeddingGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return EmbeddingGradOp::InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> EmbeddingGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const bool scale_grad_by_freq = ctx->Attr<bool>(\"scale_grad_by_freq\");\n  const int64_t indices_num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"indices\", 0).shape().NumAxes();\n\n  if (!scale_grad_by_freq) {\n    for (int32_t i = 0; i < indices_num_axes; i++) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"dy\", 0), i)\n          .Broadcast(user_op::OpArg(\"weight\", 0))\n          .Split(user_op::OpArg(\"indices\", 0), i)\n          .PartialSum(user_op::OpArg(\"dx\", 0))\n          .Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EmbeddingGradOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn(\"indices\", 0);\n  CHECK_OR_RETURN(indices_modifier != nullptr);  // NOLINT(maybe-need-error-msg)\n  indices_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EmbeddingGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"weight\", 0), ctx->InputDType(\"dy\", 0))\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"dy\", 0))\n      << \", but got \" << DataType_Name(ctx->InputDType(\"weight\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/empty_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/core/framework/device.h\"\n#include \"oneflow/core/framework/stream.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<Symbol<Stream>> MakeEmptyStream(const Symbol<Device>& out_device, const bool pin_memory) {\n  if (pin_memory) {\n    CHECK_OR_RETURN(out_device->type() == \"cpu\")\n        << \"empty op only support pin_memory in cpu device but got \" << out_device->type();\n    // TODO:(zhaoluyang) Parsing pin-memory-device from python\n    auto pin_device = JUST(Device::New(\"cuda\"));\n    return Stream::New(pin_device, StreamType::kPinnedCompute);\n  }\n  return Stream::New(out_device, StreamType::kCompute);\n}\n\n}  // namespace\n\n/* static */ Maybe<void> EmptyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, Shape(ctx->Attr<Shape>(\"shape\").dim_vec()));\n  ctx->SetOutputStride(\"out\", 0, Stride(Shape(ctx->Attr<Shape>(\"shape\").dim_vec())));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EmptyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();\n  const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  const Shape& logical_shape = ctx->Attr<Shape>(\"shape\");\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n  const auto tensor_slice_view =\n      GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id);\n  const Shape& physical_shape = tensor_slice_view.shape();\n\n  ctx->SetOutputShape(\"out\", 0, physical_shape);\n  ctx->SetOutputStride(\"out\", 0, Stride(physical_shape));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EmptyOp::GetSbp(user_op::SbpContext* ctx) { return Maybe<void>::Ok(); }\n\n/* static */ Maybe<void> EmptyOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  SbpParallel default_sbp;\n  default_sbp.mutable_broadcast_parallel();\n  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);\n}\n\n/* static */ Maybe<void> EmptyOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->Attr<DataType>(\"dtype\"));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> EmptyOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  Symbol<Device> out_device =\n      JUST(Device::New(ctx->Attr<std::string>(\"device_type\"), ctx->Attr<int64_t>(\"device_id\")));\n  *ctx->OutputTensorDevice4ArgNameAndIndex(\"out\", 0) = out_device;\n  const bool pin_memory = ctx->Attr<bool>(\"pin_memory\");\n  return MakeEmptyStream(out_device, pin_memory);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/erfinv_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> ErfInvOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  ctx->SetOutputShape(\"y\", 0, x_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ErfInvOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ErfInvOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Split(user_op::OpArg(\"y\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ErfInvOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/expand_dims_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nint32_t TransformNegativeAxisToPositive(int32_t axis, const int32_t num_axes) {\n  axis = axis < 0 ? axis + num_axes + 1 : axis;\n  CHECK_GE(axis, 0);\n  CHECK_LE(axis, num_axes);\n  return axis;\n}\n\n}  // namespace\n\n/* static */ Maybe<void> ExpandDimsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"in\", 0);\n  const int32_t axis =\n      TransformNegativeAxisToPositive(ctx->Attr<int32_t>(\"axis\"), in_shape.NumAxes());\n\n  auto dim_vec = in_shape.dim_vec();\n  dim_vec.insert(dim_vec.begin() + axis, 1);\n  ctx->SetOutputShape(\"out\", 0, Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ExpandDimsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ExpandDimsOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  const int32_t axis =\n      TransformNegativeAxisToPositive(ctx->Attr<int32_t>(\"axis\"), in_tensor.shape().NumAxes());\n\n  auto dim_vec = in_tensor.shape().dim_vec();\n  FOR_RANGE(int32_t, in_axis, 0, dim_vec.size()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), in_axis)\n        .Split(user_op::OpArg(\"out\", 0), in_axis < axis ? in_axis : in_axis + 1)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ExpandDimsOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/expand_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferExpandOutputStride(const Shape& input_shape, const Stride& input_stride,\n                                    const Shape& expand_shape, Stride* output_stride) {\n  CHECK_EQ_OR_RETURN(input_shape.size(), input_stride.size());  // NOLINT(maybe-need-error-msg)\n  size_t lpad = expand_shape.size() - input_shape.size();\n  CHECK_GE_OR_RETURN(lpad, 0);  // NOLINT(maybe-need-error-msg)\n\n  output_stride->resize(expand_shape.size(), 0);\n  for (int i = expand_shape.size() - 1; i >= 0; --i) {\n    int64_t dim = i < lpad ? 1 : input_shape[i - lpad];\n    if (dim == expand_shape[i]) {\n      if (i >= lpad) {\n        output_stride->at(i) = input_stride[i - lpad];\n      } else if (i < expand_shape.size() - 1) {\n        output_stride->at(i) = output_stride->at(i + 1) * expand_shape[i + 1];\n      }\n    } else {\n      CHECK_EQ_OR_RETURN(dim, 1);  // NOLINT(maybe-need-error-msg)\n    }\n  }\n  // NOTE: expand op only can output contiguous stride,\n  // because lazy don't support to_contiguous op for now\n  *output_stride = Stride(expand_shape);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> ExpandOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& input_shape = ctx->InputShape(\"in\", 0);\n  const Stride& input_stride = ctx->InputStride(\"in\", 0);\n  const Shape& expand_shape = ctx->Attr<Shape>(\"expand_shape\");\n\n  ctx->SetOutputShape(\"out\", 0, expand_shape);\n\n  Stride output_stride;\n  JUST(InferExpandOutputStride(input_shape, input_stride, expand_shape, &output_stride));\n  ctx->SetOutputStride(\"out\", 0, output_stride);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ExpandOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& input_shape = ctx->InputShape(\"in\", 0);\n  const Stride& input_stride = ctx->InputStride(\"in\", 0);\n\n  const auto& global_expand_shape = ctx->Attr<Shape>(\"expand_shape\");\n  const auto& output_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  const auto& device_mesh = *ctx->parallel_desc().hierarchy();\n  const auto& rank = ctx->parallel_ctx().parallel_id();\n  const auto local_view =\n      GetTensorSliceView4ParallelId(device_mesh, output_sbp, global_expand_shape, rank);\n  const auto& local_expand_shape = local_view.shape();\n  ctx->SetOutputShape(\"out\", 0, local_expand_shape);\n\n  Stride output_stride;\n  JUST(InferExpandOutputStride(input_shape, input_stride, local_expand_shape, &output_stride));\n  ctx->SetOutputStride(\"out\", 0, output_stride);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ExpandOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& global_in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0).shape();\n  const auto& 
global_expand_shape = ctx->Attr<Shape>(\"expand_shape\");\n  size_t lpad = global_expand_shape.size() - global_in_shape.size();\n  CHECK_GE_OR_RETURN(lpad, 0);  // NOLINT(maybe-need-error-msg)\n\n  for (size_t i = 0; i < global_in_shape.size(); ++i) {\n    if (global_in_shape[i] == global_expand_shape[i + lpad]) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"in\", 0), i)\n          .Split(user_op::OpArg(\"out\", 0), i + lpad)\n          .Build();\n    }\n  }\n\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ExpandOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/eye_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> EyeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  int64_t rows = ctx->Attr<int64_t>(\"rows\");\n  int64_t cols = ctx->Attr<int64_t>(\"cols\");\n  ctx->SetOutputShape(\"out\", 0, Shape({rows, cols}));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EyeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EyeOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EyeOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->Attr<DataType>(\"dtype\"));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fake_quantization_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> FakeQuantizationOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"in\", 0);\n  const Shape& scale_shape = ctx->InputShape(\"scale\", 0);\n  const Shape& zero_point_shape = ctx->InputShape(\"zero_point\", 0);\n\n  // NOTE(Liang Depeng): scale_shape->elem_cnt() > 1 means per-channel quantization for\n  // convolution weights.\n  if (scale_shape.elem_cnt() > 1) {\n    CHECK_EQ_OR_RETURN(scale_shape.elem_cnt(), in_shape.At(0));\n    CHECK_EQ_OR_RETURN(zero_point_shape.elem_cnt(), in_shape.At(0));\n  }\n\n  ctx->SetOutputShape(\"out\", 0, in_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FakeQuantizationOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FakeQuantizationOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  const Shape& logical_scale_shape =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"scale\", 0).shape();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"in\", 0))\n      .Broadcast(user_op::OpArg(\"scale\", 0))\n      .Broadcast(user_op::OpArg(\"zero_point\", 0))\n      .Broadcast(user_op::OpArg(\"out\", 0))\n      .Build();\n  if (logical_scale_shape.elem_cnt() > 1) {\n    // NOTE(Liang Depeng): only consider convolution weight per-channel quantization\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), 0)\n        .Split(user_op::OpArg(\"scale\", 0), 0)\n        .Split(user_op::OpArg(\"zero_point\", 0), 0)\n        .Split(user_op::OpArg(\"out\", 0), 0)\n        .Build();\n  } else {\n    // NOTE(Liang Depeng): the sbp signature of per-layer quantization is the same as eltwise\n    // ops\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), 0)\n        .Broadcast(user_op::OpArg(\"scale\", 0))\n        .Broadcast(user_op::OpArg(\"zero_point\", 0))\n        .Split(user_op::OpArg(\"out\", 0), 0)\n        .Build();\n  }\n  FOR_RANGE(int64_t, i, 1, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), i)\n        .Broadcast(user_op::OpArg(\"scale\", 0))\n        .Broadcast(user_op::OpArg(\"zero_point\", 0))\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FakeQuantizationOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* scale = GetInputArgModifierFn(\"scale\", 0);\n  CHECK_OR_RETURN(scale != nullptr);\n  scale->set_requires_grad(false);\n\n  user_op::InputArgModifier* zero_point = GetInputArgModifierFn(\"zero_point\", 0);\n  
CHECK_OR_RETURN(zero_point != nullptr);\n  zero_point->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FakeQuantizationOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                                       const user_op::UserOpConfWrapper& conf) {\n  const int32_t quantization_bit = conf.attr<int32_t>(\"quantization_bit\");\n  CHECK_GT_OR_RETURN(quantization_bit, 1);\n  CHECK_LE_OR_RETURN(quantization_bit, 8);\n\n  std::string quantization_scheme = conf.attr<std::string>(\"quantization_scheme\");\n  CHECK_OR_RETURN(quantization_scheme == \"symmetric\" || quantization_scheme == \"affine\");\n\n  std::string quantization_formula = conf.attr<std::string>(\"quantization_formula\");\n  CHECK_OR_RETURN(quantization_formula == \"google\" || quantization_formula == \"cambricon\");\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FakeQuantizationOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fft_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cstdint>\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\nnamespace oneflow {\n\n/* static */ Maybe<void> FftC2COp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"input\", 0);\n  Stride out_stride = Stride(in_shape);  // contiguous\n  ctx->SetOutputShape(\"out\", 0, in_shape);\n  ctx->SetOutputStride(\"out\", 0, out_stride);\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FftC2COp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FftC2COp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"input\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FftC2COp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FftR2COp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"input\", 0);\n  const auto& dims = ctx->Attr<std::vector<int64_t>>(\"dims\");\n  bool onesided = ctx->Attr<bool>(\"onesided\");\n\n  Shape out_shape = in_shape;\n  auto last_dim = dims.back();\n  if (onesided) { out_shape[last_dim] = out_shape[last_dim] / 2 + 1; }\n  Stride out_stride = Stride(out_shape);\n  ctx->SetOutputShape(\"out\", 0, out_shape);\n  ctx->SetOutputStride(\"out\", 0, out_stride);\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FftR2COp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FftR2COp::GetSbp(user_op::SbpContext* ctx) {\n  // TO-DO : Validate sbp\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FftR2COp::InferDataType(user_op::InferContext* ctx) {\n  const DataType& input_type = ctx->InputDType(\"input\", 0);\n  switch (input_type) {\n    case (kFloat): ctx->SetOutputDType(\"out\", 0, kComplex64); break;\n    case (kDouble): ctx->SetOutputDType(\"out\", 0, kComplex128); break;\n    default: CHECK_OR_RETURN(false) << \"RuntimeError: dtype can't be handled\";\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FftC2ROp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"input\", 0);\n\n  const auto& dims = ctx->Attr<std::vector<int64_t>>(\"dims\");\n  int64_t last_dim_size = ctx->Attr<int64_t>(\"last_dim_size\");\n\n  Shape out_shape = 
in_shape;\n  out_shape[dims.back()] = last_dim_size;\n  Stride out_stride = Stride(out_shape);\n  ctx->SetOutputShape(\"out\", 0, out_shape);\n  ctx->SetOutputStride(\"out\", 0, out_stride);\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FftC2ROp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FftC2ROp::GetSbp(user_op::SbpContext* ctx) {\n  // TO-DO : Validate sbp\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FftC2ROp::InferDataType(user_op::InferContext* ctx) {\n  const DataType& input_type = ctx->InputDType(\"input\", 0);\n  switch (input_type) {\n    case (kComplex64): ctx->SetOutputDType(\"out\", 0, kFloat); break;\n    case (kComplex128): ctx->SetOutputDType(\"out\", 0, kDouble); break;\n    default: CHECK_OR_RETURN(false) << \"RuntimeError: dtype can't be handled\";\n  }\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/user/ops/fill_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> FillOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"in\", 0);\n  ctx->SetOutputShape(\"out\", 0, in_shape);\n  ctx->SetOutputStride(\"out\", 0, ctx->InputStride(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FillOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FillOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FillOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FillTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"in\", 0);\n  ctx->SetOutputShape(\"out\", 0, in_shape);\n  ctx->SetOutputStride(\"out\", 0, ctx->InputStride(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FillTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FillTensorOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), i)\n        .Broadcast(user_op::OpArg(\"value\", 0))\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"value\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FillTensorOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/flip_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ auto FlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  const int input_dims = x_desc.shape().NumAxes();\n  const std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>(\"dims\");\n  CHECK_OR_RETURN(dims.size() <= input_dims) << \"len of dims must less than len of input tensor\";\n  for (auto x : dims) { CHECK_OR_RETURN(x < input_dims) << \"dims parameter is illegal.\"; }\n  user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc(\"y\", 0);\n  y_desc->set_shape(x_desc.shape());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FlipOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  return FlipOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto FlipOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  const std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>(\"dims\");\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    bool flag = true;\n    for (auto x : dims) {\n      if (x == i) {\n        flag = false;\n        break;\n      }\n    }\n    if (flag) {\n      ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Split(user_op::OpArg(\"y\", 0), i).Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FlipOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/frac_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/tensor_desc.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ auto FracOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc(\"y\", 0);\n  y_desc->set_shape(x_desc.shape());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FracOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  return FracOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto FracOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FracOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_attention_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> ParseDims(const Shape& shape, const std::string& layout,\n                      const Optional<int64_t>& batch_size, const Optional<int64_t>& seq_len,\n                      const Optional<int64_t>& num_heads, const Optional<int64_t>& head_size,\n                      int64_t* b, int64_t* m, int64_t* h, int64_t* k, bool* bm_packed) {\n  if (shape.NumAxes() == 2) {\n    if (layout == \"(BM)(HK)\" || layout == \"(BM)(H2K)\" || layout == \"(BM)(H3K)\") {\n      *bm_packed = true;\n      CHECK_OR_RETURN(batch_size);\n      CHECK_OR_RETURN(seq_len);\n      *b = JUST(batch_size);\n      *m = JUST(seq_len);\n      int64_t packed_n = 0;\n      if (layout == \"(BM)(HK)\") {\n        packed_n = 1;\n      } else if (layout == \"(BM)(H2K)\") {\n        packed_n = 2;\n      } else if (layout == \"(BM)(H3K)\") {\n        packed_n = 3;\n      } else {\n        UNIMPLEMENTED_THEN_RETURN();\n      }\n      const int64_t hidden_size = shape.At(1);\n      if (num_heads) {\n        const int64_t expected_h = JUST(num_heads);\n        const int64_t packed_h = packed_n * expected_h;\n        CHECK_EQ_OR_RETURN(hidden_size % packed_h, 0);\n        *h = expected_h;\n        *k = hidden_size / packed_h;\n      } else if (head_size) {\n        const int64_t expected_k = JUST(head_size);\n        const int64_t packed_k = packed_n * expected_k;\n        CHECK_EQ_OR_RETURN(hidden_size % packed_k, 0);\n        *h = hidden_size / packed_k;\n        *k = expected_k;\n      } else {\n        UNIMPLEMENTED_THEN_RETURN();\n      }\n    } else {\n      UNIMPLEMENTED_THEN_RETURN();\n    }\n  } else if (shape.NumAxes() == 3) {\n    if (layout == \"BM(HK)\" || layout == \"MB(HK)\" || layout == \"BM(H2K)\" || layout == \"MB(H2K)\"\n        || layout == \"BM(H3K)\" || layout == \"MB(H3K)\") {\n      *bm_packed = false;\n      int64_t packed_n = 0;\n      if (layout == \"BM(HK)\") {\n        *b = shape.At(0);\n        *m = shape.At(1);\n        packed_n = 1;\n      } else if (layout == \"MB(HK)\") {\n        *b = shape.At(1);\n        *m = shape.At(0);\n        packed_n = 1;\n      } else if (layout == \"BM(H2K)\") {\n        *b = shape.At(0);\n        *m = shape.At(1);\n        packed_n = 2;\n      } else if (layout == \"MB(H2K)\") {\n        *b = shape.At(1);\n        *m = shape.At(0);\n        packed_n = 2;\n      } else if (layout == \"BM(H3K)\") {\n        *b = shape.At(0);\n        *m = shape.At(1);\n        packed_n = 3;\n      } else if (layout == \"MB(H3K)\") {\n        *b = shape.At(1);\n        *m = shape.At(0);\n        packed_n = 3;\n      } else {\n        UNIMPLEMENTED_THEN_RETURN();\n      }\n      const int64_t hidden_size = shape.At(2);\n      if (num_heads) {\n        const int64_t expected_h = JUST(num_heads);\n        const int64_t packed_h = 
packed_n * expected_h;\n        CHECK_EQ_OR_RETURN(hidden_size % packed_h, 0);\n        *h = expected_h;\n        *k = hidden_size / packed_h;\n      } else if (head_size) {\n        const int64_t expected_k = JUST(head_size);\n        const int64_t packed_k = packed_n * expected_k;\n        CHECK_EQ_OR_RETURN(hidden_size % packed_k, 0);\n        *h = hidden_size / packed_k;\n        *k = expected_k;\n      } else {\n        UNIMPLEMENTED_THEN_RETURN();\n      }\n    } else if (layout == \"(BM)HK\") {\n      *bm_packed = true;\n      CHECK_OR_RETURN(batch_size);\n      CHECK_OR_RETURN(seq_len);\n      *b = JUST(batch_size);\n      *m = JUST(seq_len);\n      *h = shape.At(1);\n      *k = shape.At(2);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN();\n    }\n  } else if (shape.NumAxes() == 4) {\n    *bm_packed = false;\n    if (layout == \"BMHK\") {\n      *b = shape.At(0);\n      *m = shape.At(1);\n      *h = shape.At(2);\n      *k = shape.At(3);\n    } else if (layout == \"BHMK\") {\n      *b = shape.At(0);\n      *m = shape.At(2);\n      *h = shape.At(1);\n      *k = shape.At(3);\n    } else if (layout == \"MBHK\") {\n      *b = shape.At(1);\n      *m = shape.At(0);\n      *h = shape.At(2);\n      *k = shape.At(3);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN();\n    }\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  if (batch_size) {\n    const int64_t expected_b = JUST(batch_size);\n    CHECK_EQ_OR_RETURN(*b, expected_b);\n  }\n  if (seq_len) {\n    const int64_t expected_m = JUST(seq_len);\n    CHECK_EQ_OR_RETURN(*m, expected_m);\n  }\n  if (num_heads) {\n    const int64_t expected_h = JUST(num_heads);\n    CHECK_EQ_OR_RETURN(*h, expected_h);\n  }\n  if (head_size) {\n    const int64_t expected_k = JUST(head_size);\n    CHECK_EQ_OR_RETURN(*k, expected_k);\n  }\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ParseDims(const Shape& shape, const std::string& layout,\n                      const Optional<int64_t>& num_heads, const Optional<int64_t>& head_size,\n                      int64_t* b, int64_t* m, int64_t* h, int64_t* k) {\n  bool bm_packed{};\n  return ParseDims(shape, layout, Optional<int64_t>(), Optional<int64_t>(), num_heads, head_size, b,\n                   m, h, k, &bm_packed);\n}\n\nMaybe<Shape> LayoutToShape(int64_t b, int64_t m, int64_t h, int64_t k, const std::string& layout) {\n  if (layout == \"BM(HK)\") {\n    return Shape({b, m, h * k});\n  } else if (layout == \"BM(H2K)\") {\n    return Shape({b, m, h * k * 2});\n  } else if (layout == \"BM(H3K)\") {\n    return Shape({b, m, h * k * 3});\n  } else if (layout == \"MB(HK)\") {\n    return Shape({m, b, h * k});\n  } else if (layout == \"MB(H2K)\") {\n    return Shape({m, b, h * k * 2});\n  } else if (layout == \"MB(H3K)\") {\n    return Shape({m, b, h * k * 3});\n  } else if (layout == \"BMHK\") {\n    return Shape({b, m, h, k});\n  } else if (layout == \"BHMK\") {\n    return Shape({b, h, m, k});\n  } else if (layout == \"MBHK\") {\n    return Shape({m, b, h, k});\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n}\n\nMaybe<void> ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t* b_split_axis,\n                           int64_t* h_split_axis) {\n  if (layout == \"BM(HK)\" || layout == \"BM(H2K)\" || layout == \"BM(H3K)\") {\n    *b_split_axis = 0;\n    if (can_hk_split) {\n      *h_split_axis = 2;\n    } else {\n      *h_split_axis = -1;\n    }\n  } else if (layout == \"MB(HK)\" || layout == \"MB(H2K)\" || layout == \"MB(H3K)\") {\n    *b_split_axis = 1;\n    if (can_hk_split) {\n      *h_split_axis = 2;\n    } else {\n      *h_split_axis = -1;\n    }\n  } else if (layout == \"BMHK\") {\n    *b_split_axis = 0;\n    *h_split_axis = 2;\n  } else if (layout == \"BHMK\") {\n    *b_split_axis = 0;\n    *h_split_axis = 1;\n  } else if (layout == \"MBHK\") {\n    *b_split_axis = 1;\n    *h_split_axis = 2;\n  } else if (layout == \"(BM)HK\") {\n    *b_split_axis = -1;\n    *h_split_axis = 1;\n  } else if (layout == \"(BM)(HK)\" || layout == \"(BM)(H2K)\" || layout == \"(BM)(H3K)\") {\n    *b_split_axis = -1;\n    if (can_hk_split) {\n      *h_split_axis = 1;\n    } else {\n      *h_split_axis = -1;\n    }\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/*static*/ auto FusedMultiHeadAttentionInferenceOp::InferDataType(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  DataType query_type = ctx->InputDType(\"query\", 0);\n  DataType key_type = ctx->InputDType(\"key\", 0);\n  DataType value_type = ctx->InputDType(\"value\", 0);\n  CHECK_EQ_OR_RETURN(key_type, query_type);\n  CHECK_EQ_OR_RETURN(value_type, query_type);\n  if (ctx->has_input(\"attn_bias\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"attn_bias\", 0), query_type);\n  }\n  if (ctx->has_input(\"query_seq_start\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"query_seq_start\", 0), DataType::kInt32);\n  }\n  if (ctx->has_input(\"key_seq_start\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"key_seq_start\", 0), DataType::kInt32);\n  }\n  if (ctx->has_input(\"key_seq_len\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"key_seq_len\", 0), DataType::kInt32);\n  }\n  ctx->SetOutputDType(\"out\", 0, query_type);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ auto FusedMultiHeadAttentionInferenceOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) -> Maybe<void> {\n  const int64_t query_head_size = ctx->Attr<int64_t>(\"query_head_size\");\n  CHECK_GE_OR_RETURN(query_head_size, 1);\n\n  Optional<int64_t> batch_size;\n  if (ctx->has_input(\"query_seq_start\", 0)) {\n    CHECK_OR_RETURN(ctx->has_input(\"key_seq_start\", 0));\n    const Shape& query_seq_start_shape = ctx->InputShape(\"query_seq_start\", 0);\n    CHECK_EQ_OR_RETURN(query_seq_start_shape.NumAxes(), 1);\n    CHECK_GT_OR_RETURN(query_seq_start_shape.At(0), 1);\n    CHECK_OR_RETURN(ctx->InputShape(\"key_seq_start\", 0) == query_seq_start_shape);\n    batch_size = query_seq_start_shape.At(0) - 1;\n    if (ctx->has_input(\"key_seq_len\", 0)) {\n      const Shape& key_seq_len_shape = ctx->InputShape(\"key_seq_len\", 0);\n      CHECK_EQ_OR_RETURN(key_seq_len_shape.NumAxes(), 1);\n      CHECK_EQ_OR_RETURN(key_seq_len_shape.At(0), query_seq_start_shape.At(0) - 1);\n    }\n  } else {\n    CHECK_OR_RETURN(!ctx->has_input(\"key_seq_start\", 0));\n    CHECK_OR_RETURN(!ctx->has_input(\"key_seq_len\", 0));\n  }\n\n  Optional<int64_t> query_max_seq_len;\n  const int64_t attr_query_max_seq_len = ctx->Attr<int64_t>(\"query_max_seq_len\");\n  if (attr_query_max_seq_len != 0) { query_max_seq_len = attr_query_max_seq_len; }\n  Optional<int64_t> key_max_seq_len;\n  const int64_t attr_key_max_seq_len = ctx->Attr<int64_t>(\"key_max_seq_len\");\n  if (attr_key_max_seq_len != 0) { key_max_seq_len = attr_key_max_seq_len; }\n  const Shape& query_shape = ctx->InputShape(\"query\", 0);\n  const std::string& query_layout = ctx->Attr<std::string>(\"query_layout\");\n  int64_t q_b = 0;\n  int64_t q_m = 0;\n  int64_t q_h = 0;\n  int64_t q_k = 0;\n  bool q_bm_packed = false;\n  JUST(ParseDims(query_shape, query_layout, 
batch_size, query_max_seq_len, Optional<int64_t>(),\n                 query_head_size, &q_b, &q_m, &q_h, &q_k, &q_bm_packed));\n  if (q_bm_packed) { CHECK_OR_RETURN(ctx->has_input(\"query_seq_start\", 0)); }\n\n  const Shape& key_shape = ctx->InputShape(\"key\", 0);\n  const std::string& key_layout = ctx->Attr<std::string>(\"key_layout\");\n  int64_t k_b = 0;\n  int64_t k_m = 0;\n  int64_t k_h = 0;\n  int64_t k_k = 0;\n  bool k_bm_packed = false;\n  JUST(ParseDims(key_shape, key_layout, q_b, key_max_seq_len, q_h, q_k, &k_b, &k_m, &k_h, &k_k,\n                 &k_bm_packed));\n  CHECK_EQ_OR_RETURN(k_b, q_b);\n  CHECK_EQ_OR_RETURN(k_h, q_h);\n  CHECK_EQ_OR_RETURN(k_bm_packed, q_bm_packed);\n\n  const Shape& value_shape = ctx->InputShape(\"value\", 0);\n  const std::string& value_layout = ctx->Attr<std::string>(\"value_layout\");\n  int64_t v_b = 0;\n  int64_t v_m = 0;\n  int64_t v_h = 0;\n  int64_t v_k = 0;\n  bool v_bm_packed = false;\n  JUST(ParseDims(value_shape, value_layout, q_b, k_m, q_h, Optional<int64_t>(), &v_b, &v_m, &v_h,\n                 &v_k, &v_bm_packed));\n  CHECK_EQ_OR_RETURN(v_b, q_b);\n  CHECK_EQ_OR_RETURN(v_m, k_m);\n  CHECK_EQ_OR_RETURN(v_bm_packed, k_bm_packed);\n\n  if (ctx->has_input(\"attn_bias\", 0)) {\n    const Shape& attn_bias_shape = ctx->InputShape(\"attn_bias\", 0);\n    const int64_t num_attn_bias_axes = attn_bias_shape.NumAxes();\n    CHECK_GE_OR_RETURN(num_attn_bias_axes, 1);\n    CHECK_LE_OR_RETURN(num_attn_bias_axes, 4);\n    DimVector padded_attn_bias_shape;\n    for (int i = 0; i < 4 - num_attn_bias_axes; ++i) { padded_attn_bias_shape.push_back(1); }\n    for (int i = 0; i < num_attn_bias_axes; ++i) {\n      padded_attn_bias_shape.push_back(attn_bias_shape.At(i));\n    }\n    CHECK_OR_RETURN(padded_attn_bias_shape.at(0) == 1 || padded_attn_bias_shape.at(0) == q_b);\n    CHECK_OR_RETURN(padded_attn_bias_shape.at(1) == 1 || padded_attn_bias_shape.at(1) == q_h);\n    CHECK_OR_RETURN(padded_attn_bias_shape.at(2) == 1 || padded_attn_bias_shape.at(2) >= q_m);\n    CHECK_OR_RETURN(padded_attn_bias_shape.at(3) >= k_m);\n  }\n  const std::string& output_layout = ctx->Attr<std::string>(\"output_layout\");\n  const bool o_bm_packed = output_layout == \"(BM)(HK)\";\n  CHECK_EQ_OR_RETURN(o_bm_packed, q_bm_packed);\n  if (output_layout == \"(BM)(HK)\") {\n    ctx->SetOutputShape(\"out\", 0, Shape({query_shape.At(0), q_h * v_k}));\n  } else if (output_layout == \"BM(HK)\") {\n    ctx->SetOutputShape(\"out\", 0, Shape({q_b, q_m, q_h * v_k}));\n  } else if (output_layout == \"MB(HK)\") {\n    ctx->SetOutputShape(\"out\", 0, Shape({q_m, q_b, q_h * v_k}));\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedMultiHeadAttentionInferenceOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) -> Maybe<void> {\n  return FusedMultiHeadAttentionInferenceOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto FusedMultiHeadAttentionInferenceOp::GetSbp(user_op::SbpContext* ctx)\n    -> Maybe<void> {\n  const int64_t query_head_size = ctx->user_op_conf().attr<int64_t>(\"query_head_size\");\n  const std::string& query_layout = ctx->user_op_conf().attr<std::string>(\"query_layout\");\n  const std::string& key_layout = ctx->user_op_conf().attr<std::string>(\"key_layout\");\n  const std::string& value_layout = ctx->user_op_conf().attr<std::string>(\"value_layout\");\n  const std::string& output_layout = ctx->user_op_conf().attr<std::string>(\"output_layout\");\n  int64_t num_heads = 0;\n  const user_op::TensorDesc& query = 
ctx->LogicalTensorDesc4InputArgNameAndIndex(\"query\", 0);\n  if (query.shape().NumAxes() == 2) {\n    if (query_layout == \"(BM)(HK)\") {\n      CHECK_EQ_OR_RETURN(query.shape().At(1) % query_head_size, 0);\n      num_heads = query.shape().At(1) / query_head_size;\n    } else if (query_layout == \"(BM)(H3K)\") {\n      CHECK_EQ_OR_RETURN(query.shape().At(1) % (query_head_size * 3), 0);\n      num_heads = query.shape().At(1) / (query_head_size * 3);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN();\n    }\n  } else if (query.shape().NumAxes() == 3) {\n    if (query_layout == \"BM(HK)\" || query_layout == \"MB(HK)\") {\n      CHECK_EQ_OR_RETURN(query.shape().At(2) % query_head_size, 0);\n      num_heads = query.shape().At(2) / query_head_size;\n    } else if (query_layout == \"BM(H3K)\" || query_layout == \"MB(H3K)\") {\n      CHECK_EQ_OR_RETURN(query.shape().At(2) % (query_head_size * 3), 0);\n      num_heads = query.shape().At(2) / (query_head_size * 3);\n    } else if (query_layout == \"(BM)HK\") {\n      num_heads = query.shape().At(1);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN();\n    }\n  } else if (query.shape().NumAxes() == 4) {\n    if (query_layout == \"BMHK\") {\n      num_heads = query.shape().At(2);\n    } else if (query_layout == \"BHMK\") {\n      num_heads = query.shape().At(1);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN();\n    }\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  const bool can_hk_split = num_heads % ctx->parallel_num() == 0;\n  int64_t q_b_split_axis = -1;\n  int64_t q_h_split_axis = -1;\n  JUST(ParseSplitAxis(query_layout, can_hk_split, &q_b_split_axis, &q_h_split_axis));\n  int64_t k_b_split_axis = -1;\n  int64_t k_h_split_axis = -1;\n  JUST(ParseSplitAxis(key_layout, can_hk_split, &k_b_split_axis, &k_h_split_axis));\n  int64_t v_b_split_axis = -1;\n  int64_t v_h_split_axis = -1;\n  JUST(ParseSplitAxis(value_layout, can_hk_split, &v_b_split_axis, &v_h_split_axis));\n  int64_t o_b_split_axis = -1;\n  int64_t o_h_split_axis = -1;\n  JUST(ParseSplitAxis(output_layout, can_hk_split, &o_b_split_axis, &o_h_split_axis));\n\n  std::vector<user_op::OpArg> attn_bias_arg;\n  if (ctx->user_op_conf().has_input(\"attn_bias\", 0)) { attn_bias_arg.emplace_back(\"attn_bias\", 0); }\n  std::vector<user_op::OpArg> var_len_args;\n  if (ctx->user_op_conf().has_input(\"query_seq_start\", 0)) {\n    var_len_args.emplace_back(\"query_seq_start\", 0);\n  }\n  if (ctx->user_op_conf().has_input(\"key_seq_start\", 0)) {\n    var_len_args.emplace_back(\"key_seq_start\", 0);\n  }\n  if (ctx->user_op_conf().has_input(\"key_seq_len\", 0)) {\n    var_len_args.emplace_back(\"key_seq_len\", 0);\n  }\n  if (q_b_split_axis >= 0 && k_b_split_axis >= 0 && v_b_split_axis >= 0 && o_b_split_axis >= 0\n      && var_len_args.empty()) {\n    bool broadcast_attn_bias = false;\n    if (ctx->user_op_conf().has_input(\"attn_bias\", 0)) {\n      const user_op::TensorDesc& attn_bias =\n          ctx->LogicalTensorDesc4InputArgNameAndIndex(\"attn_bias\", 0);\n      if (attn_bias.shape().NumAxes() < 4 || attn_bias.shape().At(0) == 1) {\n        broadcast_attn_bias = true;\n      }\n    }\n    if (broadcast_attn_bias) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"query\", 0), q_b_split_axis)\n          .Split(user_op::OpArg(\"key\", 0), k_b_split_axis)\n          .Split(user_op::OpArg(\"value\", 0), v_b_split_axis)\n          .Broadcast(attn_bias_arg)\n          .Split(ctx->outputs(), o_b_split_axis)\n          .Build();\n\n    } else {\n      ctx->NewBuilder()\n          
.Split(user_op::OpArg(\"query\", 0), q_b_split_axis)\n          .Split(user_op::OpArg(\"key\", 0), k_b_split_axis)\n          .Split(user_op::OpArg(\"value\", 0), v_b_split_axis)\n          .Split(attn_bias_arg, 0)\n          .Split(ctx->outputs(), o_b_split_axis)\n          .Build();\n    }\n  }\n  if (q_h_split_axis >= 0 && k_h_split_axis >= 0 && v_h_split_axis >= 0 && o_h_split_axis >= 0) {\n    bool broadcast_attn_bias = false;\n    if (ctx->user_op_conf().has_input(\"attn_bias\", 0)) {\n      const user_op::TensorDesc& attn_bias =\n          ctx->LogicalTensorDesc4InputArgNameAndIndex(\"attn_bias\", 0);\n      if (attn_bias.shape().NumAxes() == 4) {\n        if (attn_bias.shape().At(1) == 1) { broadcast_attn_bias = true; }\n      } else if (attn_bias.shape().NumAxes() == 3) {\n        if (attn_bias.shape().At(0) == 1) { broadcast_attn_bias = true; }\n      } else {\n        broadcast_attn_bias = true;\n      }\n    }\n    if (broadcast_attn_bias) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"query\", 0), q_h_split_axis)\n          .Split(user_op::OpArg(\"key\", 0), k_h_split_axis)\n          .Split(user_op::OpArg(\"value\", 0), v_h_split_axis)\n          .Broadcast(attn_bias_arg)\n          .Broadcast(var_len_args)\n          .Split(ctx->outputs(), o_h_split_axis)\n          .Build();\n\n    } else {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"query\", 0), q_h_split_axis)\n          .Split(user_op::OpArg(\"key\", 0), k_h_split_axis)\n          .Split(user_op::OpArg(\"value\", 0), v_h_split_axis)\n          .Split(attn_bias_arg, 1)\n          .Broadcast(var_len_args)\n          .Split(ctx->outputs(), o_h_split_axis)\n          .Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ auto FusedAttentionConcatPastKeyValueOp::InferDataType(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  const DataType data_type = ctx->InputDType(\"key\", 0);\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"value\", 0), data_type);\n  if (ctx->has_input(\"past_key\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"past_key\", 0), data_type);\n  }\n  if (ctx->has_input(\"past_value\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"past_value\", 0), data_type);\n  }\n  ctx->SetOutputDType(\"output_key\", 0, data_type);\n  ctx->SetOutputDType(\"output_value\", 0, data_type);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ auto FusedAttentionConcatPastKeyValueOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) -> Maybe<void> {\n  const int64_t key_head_size = ctx->Attr<int64_t>(\"key_head_size\");\n  CHECK_GE_OR_RETURN(key_head_size, 1);\n\n  const Shape& key_shape = ctx->InputShape(\"key\", 0);\n  const std::string& key_layout = ctx->Attr<std::string>(\"key_layout\");\n  int64_t k_b = 0;\n  int64_t k_m = 0;\n  int64_t k_h = 0;\n  int64_t k_k = 0;\n  JUST(\n      ParseDims(key_shape, key_layout, Optional<int64_t>(), key_head_size, &k_b, &k_m, &k_h, &k_k));\n\n  const Shape& value_shape = ctx->InputShape(\"value\", 0);\n  const std::string& value_layout = ctx->Attr<std::string>(\"value_layout\");\n  int64_t v_b = 0;\n  int64_t v_m = 0;\n  int64_t v_h = 0;\n  int64_t v_k = 0;\n  JUST(ParseDims(value_shape, value_layout, k_h, k_k, &v_b, &v_m, &v_h, &v_k));\n  CHECK_EQ_OR_RETURN(v_b, k_b);\n  CHECK_EQ_OR_RETURN(v_m, k_m);\n\n  int64_t past_k_b = 0;\n  int64_t past_k_m = 0;\n  int64_t past_k_h = 0;\n  int64_t past_k_k = 0;\n  int64_t past_v_b = 0;\n  int64_t past_v_m = 0;\n  int64_t past_v_h = 0;\n  int64_t past_v_k = 0;\n  const std::string& past_key_layout = 
ctx->Attr<std::string>(\"past_key_layout\");\n  const std::string& past_value_layout = ctx->Attr<std::string>(\"past_value_layout\");\n  if (ctx->has_input(\"past_key\", 0)) {\n    CHECK_OR_RETURN(ctx->has_input(\"past_value\", 0));\n    const Shape& past_key_shape = ctx->InputShape(\"past_key\", 0);\n    JUST(ParseDims(past_key_shape, past_key_layout, k_h, k_k, &past_k_b, &past_k_m, &past_k_h,\n                   &past_k_k));\n    CHECK_EQ_OR_RETURN(past_k_b, k_b);\n\n    const Shape& past_value_shape = ctx->InputShape(\"past_value\", 0);\n    JUST(ParseDims(past_value_shape, past_value_layout, k_h, k_k, &past_v_b, &past_v_m, &past_v_h,\n                   &past_v_k));\n    CHECK_EQ_OR_RETURN(past_v_b, k_b);\n    CHECK_EQ_OR_RETURN(past_v_m, past_k_m);\n  } else {\n    CHECK_OR_RETURN(!ctx->has_input(\"past_value\", 0));\n  }\n\n  ctx->SetOutputShape(\"output_key\", 0,\n                      *JUST(LayoutToShape(k_b, past_k_m + k_m, k_h, k_k, past_key_layout)));\n  ctx->SetOutputShape(\"output_value\", 0,\n                      *JUST(LayoutToShape(v_b, past_v_m + v_m, v_h, v_k, past_value_layout)));\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedAttentionConcatPastKeyValueOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) -> Maybe<void> {\n  return FusedAttentionConcatPastKeyValueOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto FusedAttentionConcatPastKeyValueOp::GetSbp(user_op::SbpContext* ctx)\n    -> Maybe<void> {\n  const int64_t key_head_size = ctx->user_op_conf().attr<int64_t>(\"key_head_size\");\n  const std::string& past_key_layout = ctx->user_op_conf().attr<std::string>(\"past_key_layout\");\n  const std::string& past_value_layout = ctx->user_op_conf().attr<std::string>(\"past_value_layout\");\n  const std::string& key_layout = ctx->user_op_conf().attr<std::string>(\"key_layout\");\n  const std::string& value_layout = ctx->user_op_conf().attr<std::string>(\"value_layout\");\n  int64_t num_heads = 0;\n  {\n    int64_t b = 0;\n    int64_t m = 0;\n    int64_t k = 0;\n\n    const user_op::TensorDesc& key = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"key\", 0);\n    JUST(ParseDims(key.shape(), key_layout, Optional<int64_t>(), key_head_size, &b, &m, &num_heads,\n                   &k));\n  }\n  const bool can_hk_split = num_heads % ctx->parallel_num() == 0;\n  int64_t past_k_b_split_axis = -1;\n  int64_t past_k_h_split_axis = -1;\n  JUST(ParseSplitAxis(past_key_layout, can_hk_split, &past_k_b_split_axis, &past_k_h_split_axis));\n  int64_t past_v_b_split_axis = -1;\n  int64_t past_v_h_split_axis = -1;\n  JUST(ParseSplitAxis(past_value_layout, can_hk_split, &past_v_b_split_axis, &past_v_h_split_axis));\n  int64_t k_b_split_axis = -1;\n  int64_t k_h_split_axis = -1;\n  JUST(ParseSplitAxis(key_layout, can_hk_split, &k_b_split_axis, &k_h_split_axis));\n  int64_t v_b_split_axis = -1;\n  int64_t v_h_split_axis = -1;\n  JUST(ParseSplitAxis(value_layout, can_hk_split, &v_b_split_axis, &v_h_split_axis));\n\n  std::vector<user_op::OpArg> past_key_arg;\n  if (ctx->user_op_conf().has_input(\"past_key\", 0)) { past_key_arg.emplace_back(\"past_key\", 0); }\n  std::vector<user_op::OpArg> past_value_arg;\n  if (ctx->user_op_conf().has_input(\"past_value\", 0)) {\n    past_value_arg.emplace_back(\"past_value\", 0);\n  }\n  if (past_k_b_split_axis >= 0 && past_v_b_split_axis >= 0 && k_b_split_axis >= 0\n      && v_b_split_axis >= 0) {\n    ctx->NewBuilder()\n        .Split(past_key_arg, past_k_b_split_axis)\n        .Split(past_value_arg, past_v_b_split_axis)\n        
.Split(user_op::OpArg(\"key\", 0), k_b_split_axis)\n        .Split(user_op::OpArg(\"value\", 0), v_b_split_axis)\n        .Split(user_op::OpArg(\"output_key\", 0), past_k_b_split_axis)\n        .Split(user_op::OpArg(\"output_value\", 0), past_v_b_split_axis)\n        .Build();\n  }\n\n  if (past_k_h_split_axis >= 0 && past_v_h_split_axis >= 0 && k_h_split_axis >= 0\n      && v_h_split_axis >= 0) {\n    ctx->NewBuilder()\n        .Split(past_key_arg, past_k_h_split_axis)\n        .Split(past_value_arg, past_v_h_split_axis)\n        .Split(user_op::OpArg(\"key\", 0), k_h_split_axis)\n        .Split(user_op::OpArg(\"value\", 0), v_h_split_axis)\n        .Split(user_op::OpArg(\"output_key\", 0), past_k_h_split_axis)\n        .Split(user_op::OpArg(\"output_value\", 0), past_v_h_split_axis)\n        .Build();\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedApplyRotaryEmbOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  const std::string& x_layout = ctx->Attr<std::string>(\"x_layout\");\n  const std::string& output_layout = ctx->Attr<std::string>(\"output_layout\");\n  const std::string& mode = ctx->Attr<std::string>(\"mode\");\n  const int64_t rotary_size = ctx->Attr<int64_t>(\"rotary_size\");\n  const int64_t k_size = ctx->Attr<int64_t>(\"k_size\");\n  const int64_t tensor_index = ctx->Attr<int64_t>(\"tensor_index\");\n\n  CHECK_OR_RETURN((tensor_index >= 0) && (tensor_index <= 2))\n      << \"tensor_index should be in range [0, 2].\";\n  CHECK_OR_RETURN((mode == \"interval\") || (mode == \"plane\"))\n      << \"mode should be either \\\"interval\\\" or \\\"plane\\\".\";\n\n  CHECK_OR_RETURN(output_layout != \"BM(H2K)\" && output_layout != \"BM(H3K)\"\n                  && output_layout != \"MB(H2K)\" && output_layout != \"MB(H3K)\")\n      << \"output_layout should not be \\\"BM(H2k)\\\", \\\"BM(H3K)\\\", \\\"MB(H2K)\\\", \\\"MB(H3K)\\\".\";\n\n  int64_t b = 0, m = 0, h = 0, k = 0;\n\n  JUST(ParseDims(x_desc.shape(), x_layout, Optional<int64_t>(), Optional<int64_t>(k_size), &b, &m,\n                 &h, &k));\n\n  CHECK_LE_OR_RETURN(rotary_size, k) << \"rotary_size should be no more than K of input x.\";\n\n  int64_t rotary_emb_dim = 1;\n\n  if (ctx->has_input(\"position_ids\", 0)) {\n    const user_op::TensorDesc& position_ids_desc = ctx->InputTensorDesc(\"position_ids\", 0);\n    CHECK_EQ_OR_RETURN(position_ids_desc.shape().NumAxes(), 3)\n        << \"ndims of position_ids should be equal to 3, either in form of B1M or B2M.\";\n    CHECK_EQ_OR_RETURN(position_ids_desc.shape().At(0), b)\n        << \"1st dim of position_ids should be equal to B.\";\n    CHECK_EQ_OR_RETURN(position_ids_desc.shape().At(2), m)\n        << \"3rd dim of position_ids should be equal to M.\";\n    rotary_emb_dim = position_ids_desc.shape().At(1);\n    CHECK_OR_RETURN(rotary_emb_dim == 1 || rotary_emb_dim == 2)\n        << \"2nd dim of position_ids should be 1 or 2.\";\n  }\n\n  const int64_t actual_rotary_size = rotary_size / rotary_emb_dim;\n  CHECK_EQ_OR_RETURN(actual_rotary_size % 2, 0)\n      << \"rotary_size should be a multiple of 2 * rotary_encoding_dim.\";\n\n  bool has_cos = ctx->has_input(\"cos\", 0);\n  bool has_sin = ctx->has_input(\"sin\", 0);\n  // TODO: fused_apply_rotary_emb have same logic no matter name\n  if (has_cos && has_sin) {\n    const user_op::TensorDesc& cos_desc = ctx->InputTensorDesc(\"cos\", 0);\n    const user_op::TensorDesc& sin_desc = ctx->InputTensorDesc(\"sin\", 0);\n    
CHECK_EQ_OR_RETURN(cos_desc.shape().NumAxes(), 2)\n        << \"The number of dimensions of cos should be equal to 2.\";\n    CHECK_OR_RETURN(cos_desc.shape() == sin_desc.shape())\n        << \"The dimensions of cos & sin should be the same.\";\n    CHECK_EQ_OR_RETURN(cos_desc.shape().At(1), actual_rotary_size)\n        << \"The 2nd dimension of cos & sin should be equal to rotary_size // \"\n           \"rotary_embedding_dimension.\";\n  } else if (!has_cos && !has_sin) {\n    // Do nothing\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n\n  if (!ctx->has_input(\"position_ids\", 0)) {\n    if (has_cos && has_sin) {\n      const user_op::TensorDesc& cos_desc = ctx->InputTensorDesc(\"cos\", 0);\n      CHECK_GE_OR_RETURN(cos_desc.shape().At(0), m)\n          << \"M of cos should be no less than M of x if position_ids is not given.\";\n      // K of cos & sin is checked inside ParseDims\n    }\n  }\n\n  Shape out_shape = *JUST(LayoutToShape(b, m, h, k, output_layout));\n  ctx->SetOutputShape(\"out\", 0, out_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FusedApplyRotaryEmbOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedApplyRotaryEmbOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  int num_heads = -1;\n  const int64_t k_size = ctx->Attr<int64_t>(\"k_size\");\n  const std::string& x_layout = ctx->Attr<std::string>(\"x_layout\");\n  const std::string& output_layout = ctx->Attr<std::string>(\"output_layout\");\n  if (x_desc.shape().NumAxes() == 2) {\n    if (x_layout == \"(BM)(HK)\") {\n      CHECK_EQ_OR_RETURN(x_desc.shape().At(1) % k_size, 0);\n      num_heads = x_desc.shape().At(1) / k_size;\n    } else if (x_layout == \"(BM)(H3K)\") {\n      CHECK_EQ_OR_RETURN(x_desc.shape().At(1) % (k_size * 3), 0);\n      num_heads = x_desc.shape().At(1) / (k_size * 3);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN();\n    }\n  } else if (x_desc.shape().NumAxes() == 3) {\n    if (x_layout == \"BM(HK)\" || x_layout == \"MB(HK)\") {\n      CHECK_EQ_OR_RETURN(x_desc.shape().At(2) % k_size, 0);\n      num_heads = x_desc.shape().At(2) / k_size;\n    } else if (x_layout == \"BM(H3K)\" || x_layout == \"MB(H3K)\") {\n      CHECK_EQ_OR_RETURN(x_desc.shape().At(2) % (k_size * 3), 0);\n      num_heads = x_desc.shape().At(2) / (k_size * 3);\n    } else if (x_layout == \"(BM)HK\") {\n      num_heads = x_desc.shape().At(1);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN();\n    }\n  } else if (x_desc.shape().NumAxes() == 4) {\n    if (x_layout == \"BMHK\") {\n      num_heads = x_desc.shape().At(2);\n    } else if (x_layout == \"BHMK\") {\n      num_heads = x_desc.shape().At(1);\n    } else {\n      UNIMPLEMENTED_THEN_RETURN();\n    }\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  const bool can_hk_split = num_heads % ctx->parallel_num() == 0;\n  int64_t x_b_split_axis = -1;\n  int64_t x_h_split_axis = -1;\n  JUST(ParseSplitAxis(x_layout, can_hk_split, &x_b_split_axis, &x_h_split_axis));\n  int64_t o_b_split_axis = -1;\n  int64_t o_h_split_axis = -1;\n  JUST(ParseSplitAxis(output_layout, can_hk_split, &o_b_split_axis, &o_h_split_axis));\n  if (x_b_split_axis >= 0 && o_b_split_axis >= 0) {\n    auto builder = ctx->NewBuilder()\n                       .Split(user_op::OpArg(\"x\", 0), x_b_split_axis)\n                       .Split(user_op::OpArg(\"out\", 0), o_b_split_axis);\n    if (ctx->user_op_conf().has_input(\"cos\", 0))\n  
    builder = builder.Broadcast(user_op::OpArg(\"cos\", 0)).Broadcast(user_op::OpArg(\"sin\", 0));\n    if (ctx->user_op_conf().has_input(\"position_ids\", 0))\n      builder = builder.Split(user_op::OpArg(\"position_ids\", 0), 0);\n    builder.Build();\n  }\n  if (x_h_split_axis >= 0 && o_h_split_axis >= 0) {\n    auto builder = ctx->NewBuilder()\n                       .Split(user_op::OpArg(\"x\", 0), x_h_split_axis)\n                       .Split(user_op::OpArg(\"out\", 0), o_h_split_axis);\n    if (ctx->user_op_conf().has_input(\"cos\", 0))\n      builder = builder.Broadcast(user_op::OpArg(\"cos\", 0)).Broadcast(user_op::OpArg(\"sin\", 0));\n    if (ctx->user_op_conf().has_input(\"position_ids\", 0))\n      builder = builder.Broadcast(user_op::OpArg(\"position_ids\", 0));\n    builder.Build();\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedApplyRotaryEmbOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc(\"x\", 0);\n\n  bool has_sinusoid = ctx->has_input(\"cos\", 0);\n\n  if (has_sinusoid) {\n    const user_op::TensorDesc& cos_desc = ctx->InputTensorDesc(\"cos\", 0);\n    const user_op::TensorDesc& sin_desc = ctx->InputTensorDesc(\"sin\", 0);\n\n    CHECK_EQ_OR_RETURN(cos_desc.data_type(), first_in_desc.data_type())\n        << \"InferDataType Failed. Expected \" << DataType_Name(first_in_desc.data_type())\n        << \", but got \" << DataType_Name(cos_desc.data_type());\n    CHECK_EQ_OR_RETURN(sin_desc.data_type(), first_in_desc.data_type())\n        << \"InferDataType Failed. Expected \" << DataType_Name(first_in_desc.data_type())\n        << \", but got \" << DataType_Name(sin_desc.data_type());\n  }\n\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_data_type(first_in_desc.data_type());\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_bias_add_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ auto FusedBiasAddGeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  const auto& a_tensor_desc = ctx->InputTensorDesc(\"a\", 0);\n  const auto& b_tensor_desc = ctx->InputTensorDesc(\"b\", 0);\n  const auto bias_add_axis = ctx->Attr<int32_t>(\"axis\");\n  CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1);\n  CHECK_GE_OR_RETURN(bias_add_axis, 0);\n  CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes());\n  CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0));\n  ctx->SetOutputShape(\"out\", 0, a_tensor_desc.shape());\n  ctx->SetOutputIsDynamic(\"out\", 0, a_tensor_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedBiasAddGeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  return FusedBiasAddGeluOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto FusedBiasAddGeluOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  const auto& a_tensor_desc = ctx->InputTensorDesc(\"a\", 0);\n  ctx->SetOutputDType(\"out\", 0, a_tensor_desc.data_type());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedBiasAddGeluOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  const auto axis = ctx->Attr<int32_t>(\"axis\");\n  for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex(\"a\", 0).shape().NumAxes();\n       ++i) {\n    if (i == axis) { continue; }\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"a\", 0), i)\n        .Broadcast(user_op::OpArg(\"b\", 0))\n        .Split(ctx->outputs(), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"b\", 0), 0)\n      .Split(user_op::OpArg(\"a\", 0), axis)\n      .Split(ctx->outputs(), axis)\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedBiasAddGeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  const auto& a_tensor_desc = ctx->InputTensorDesc(\"a\", 0);\n  const auto& b_tensor_desc = ctx->InputTensorDesc(\"b\", 0);\n  const auto bias_add_axis = ctx->Attr<int32_t>(\"axis\");\n  CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1);\n  CHECK_GE_OR_RETURN(bias_add_axis, 0);\n  CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes());\n  CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0));\n  ctx->SetOutputShape(\"dx\", 0, a_tensor_desc.shape());\n  ctx->SetOutputIsDynamic(\"dx\", 0, a_tensor_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ auto FusedBiasAddGeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  return FusedBiasAddGeluGradOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto FusedBiasAddGeluGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n 
 const auto& a_tensor_desc = ctx->InputTensorDesc(\"a\", 0);\n  ctx->SetOutputDType(\"dx\", 0, a_tensor_desc.data_type());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedBiasAddGeluGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  const auto axis = ctx->Attr<int32_t>(\"axis\");\n  for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex(\"a\", 0).shape().NumAxes();\n       ++i) {\n    if (i == axis) { continue; }\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"a\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Broadcast(user_op::OpArg(\"b\", 0))\n        .Split(ctx->outputs(), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"b\", 0), 0)\n      .Split(user_op::OpArg(\"a\", 0), axis)\n      .Split(user_op::OpArg(\"dy\", 0), axis)\n      .Split(ctx->outputs(), axis)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/*static*/ auto FusedBiasAddMaskScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  const auto& a_tensor_desc = ctx->InputTensorDesc(\"a\", 0);\n  const auto& mask_tensor_desc = ctx->InputTensorDesc(\"mask\", 0);\n  const auto& b_tensor_desc = ctx->InputTensorDesc(\"b\", 0);\n  const auto bias_add_axis = ctx->Attr<int32_t>(\"axis\");\n  CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1);\n  CHECK_GE_OR_RETURN(bias_add_axis, 0);\n  CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes());\n  CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0));\n  CHECK_EQ_OR_RETURN(a_tensor_desc.shape(), mask_tensor_desc.shape());\n  ctx->SetOutputShape(\"out\", 0, a_tensor_desc.shape());\n  ctx->SetOutputIsDynamic(\"out\", 0, a_tensor_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedBiasAddMaskScaleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  return FusedBiasAddMaskScaleOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto FusedBiasAddMaskScaleOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  const auto& a_tensor_desc = ctx->InputTensorDesc(\"a\", 0);\n  ctx->SetOutputDType(\"out\", 0, a_tensor_desc.data_type());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedBiasAddMaskScaleOp::ModifyInputArg(\n    const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&)\n    -> Maybe<void> {\n  user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn(\"mask\", 0);\n  CHECK_OR_RETURN(mask_modifier != nullptr);\n  mask_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedBiasAddMaskScaleOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  const auto axis = ctx->Attr<int32_t>(\"axis\");\n  std::vector<user_op::OpArg> split_args;\n  split_args.emplace_back(\"a\", 0);\n  split_args.emplace_back(\"mask\", 0);\n  split_args.emplace_back(\"out\", 0);\n  if (ctx->user_op_conf().has_input(\"_add_to_output\", 0)) {\n    split_args.emplace_back(\"_add_to_output\", 0);\n  }\n  for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex(\"a\", 0).shape().NumAxes();\n       ++i) {\n    if (i == axis) { continue; }\n    ctx->NewBuilder().Split(split_args, i).Broadcast(user_op::OpArg(\"b\", 0)).Build();\n  }\n  ctx->NewBuilder().Split(user_op::OpArg(\"b\", 0), 0).Split(split_args, axis).Build();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_bias_add_scale_mask_softmax_dropout_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool CheckBroadcastable(const Shape& shape, const Shape& broadcast_shape) {\n  int left_pad = broadcast_shape.size() - shape.size();\n  if (left_pad < 0) { return false; }\n  for (int i = 0; i < shape.size(); ++i) {\n    int j = i + left_pad;\n    if (shape[i] != 1 && shape[i] != broadcast_shape[j]) { return false; }\n  }\n  return true;\n}\n\nbool CheckBroadcastAndSimplifyDims(const Shape& shape, const Shape& broadcast_shape,\n                                   int& simplified_ndim, int64_t* simplified_dims) {\n  int lpad = broadcast_shape.size() - shape.size();\n  if (lpad < 0) { return false; }\n  simplified_ndim = 0;\n  bool prev_broadcast = false;\n  for (int i = 0; i < broadcast_shape.size(); ++i) {\n    int64_t dim = (i < lpad) ? 1 : shape[i - lpad];\n    int64_t broadcast_dim = broadcast_shape[i];\n    if (dim != 1 && dim != broadcast_dim) { return false; }\n    bool broadcast = (dim == 1 && broadcast_dim != 1);\n    if (simplified_ndim > 0 && broadcast == prev_broadcast) {\n      // fold to prev dim\n      simplified_dims[simplified_ndim - 1] *= dim;\n    } else {\n      simplified_dims[simplified_ndim] = dim;\n      simplified_ndim += 1;\n    }\n    prev_broadcast = broadcast;\n  }\n  return true;\n}\n\n// return lpad\nint GetBroadcastDims(const Shape& shape, const Shape& broadcast_shape,\n                     HashSet<int>& broadcast_dims) {\n  int lpad = broadcast_shape.size() - shape.size();\n  if (lpad < 0) { return lpad; }\n  for (int i = 0; i < broadcast_shape.size(); ++i) {\n    if (i < lpad) {\n      broadcast_dims.insert(i);\n    } else {\n      int j = i - lpad;\n      if (shape[j] == 1 && shape[j] != broadcast_shape[i]) { broadcast_dims.insert(i); }\n      if (shape[j] != 1 && shape[j] != broadcast_shape[i]) { return -1; }\n    }\n  }\n  return lpad;\n}\n\n}  // namespace\n\nMaybe<void> FusedBiasAddScaleMaskSoftmaxDropoutOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& bias_shape = ctx->InputShape(\"bias\", 0);\n  const Shape& mask_shape = ctx->InputShape(\"mask\", 0);\n  const Shape& dropout_mask_shape = ctx->InputShape(\"dropout_mask\", 0);\n\n  CHECK_GE_OR_RETURN(x_shape.size(), 2) << Error::RuntimeError() << \"x has at least 2 dimensions\";\n  CHECK_EQ_OR_RETURN(x_shape.back(), mask_shape.back())\n      << \" Last dimension of x and mask should be equal, which is softmax dimension.\";\n  CHECK_EQ_OR_RETURN(dropout_mask_shape, x_shape)\n      << Error::RuntimeError() << \"dropout_mask shape \" << dropout_mask_shape.ToString()\n      << \" should be equal to x shape \" << x_shape.ToString();\n\n  int simplified_bias_ndim = 0;\n  int simplified_mask_ndim = 0;\n  DimVector simplified_bias_dims(x_shape.size());\n  DimVector 
simplified_mask_dims(x_shape.size());\n  CHECK_OR_RETURN(CheckBroadcastAndSimplifyDims(bias_shape, x_shape, simplified_bias_ndim,\n                                                simplified_bias_dims.data()))\n      << Error::RuntimeError() << \"bias shape \" << bias_shape.ToString()\n      << \" could not be broadcast to x shape \" << x_shape.ToString();\n  CHECK_OR_RETURN(CheckBroadcastAndSimplifyDims(mask_shape, x_shape, simplified_mask_ndim,\n                                                simplified_mask_dims.data()))\n      << Error::RuntimeError() << \"mask shape \" << mask_shape.ToString()\n      << \" could not be broadcast to x shape \" << x_shape.ToString();\n  CHECK_GT_OR_RETURN(simplified_bias_ndim, 0);  // NOLINT(maybe-need-error-msg)\n  CHECK_GT_OR_RETURN(simplified_mask_ndim, 0);  // NOLINT(maybe-need-error-msg)\n  // (1, ) -> (K, )\n  // (M, 1) -> (M, N)\n  // (1, N) -> (M, N)\n  // (M, 1, N) -> (M, K, N)\n  if ((simplified_bias_ndim == 2 && simplified_bias_dims[0] != 1) || simplified_bias_ndim > 2) {\n    return Error::RuntimeError()\n           << \"bias only supports (1, N)->(M, N) broadcast, but got bias shape \"\n           << bias_shape.ToString() << \" broadcast to x shape \" << x_shape.ToString();\n  }\n\n  if (simplified_mask_ndim > 3 || (simplified_mask_ndim == 3 && simplified_mask_dims[1] != 1)) {\n    return Error::RuntimeError() << \"mask supports (M, 1)->(M, N) or (1, N)->(M, N) or (M, 1, \"\n                                    \"N)->(M, K, N) broadcast, but got mask shape \"\n                                 << mask_shape.ToString() << \" broadcast to x shape \"\n                                 << x_shape.ToString();\n  }\n\n  ctx->SetOutputShape(\"y\", 0, x_shape);\n  ctx->SetOutputShape(\"softmax_y\", 0, x_shape);\n  ctx->SetOutputIsDynamic(\"y\", 0, ctx->InputIsDynamic(\"x\", 0));\n  ctx->SetOutputIsDynamic(\"softmax_y\", 0, ctx->InputIsDynamic(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedBiasAddScaleMaskSoftmaxDropoutOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedBiasAddScaleMaskSoftmaxDropoutOp::InferDataType(user_op::InferContext* ctx) {\n  const DataType x_dtype = ctx->InputDType(\"x\", 0);\n  const DataType bias_dtype = ctx->InputDType(\"bias\", 0);\n  const DataType mask_dtype = ctx->InputDType(\"mask\", 0);\n  const DataType dropout_mask_dtype = ctx->InputDType(\"dropout_mask\", 0);\n\n  CHECK_EQ_OR_RETURN(bias_dtype, x_dtype)\n      << Error::RuntimeError() << \"Expected bias data type \" << DataType_Name(x_dtype)\n      << \", but got \" << DataType_Name(bias_dtype);\n  CHECK_OR_RETURN(IsBoolDataType(mask_dtype) || IsIntegralDataType(mask_dtype))\n      << Error::RuntimeError() << \"Expected mask data type to be bool or integer, but got \"\n      << DataType_Name(mask_dtype);\n  CHECK_OR_RETURN(IsBoolDataType(dropout_mask_dtype))\n      << Error::RuntimeError() << \"Expected dropout_mask data type to be bool, but got \"\n      << DataType_Name(dropout_mask_dtype);\n\n  ctx->SetOutputDType(\"y\", 0, x_dtype);\n  ctx->SetOutputDType(\"softmax_y\", 0, x_dtype);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedBiasAddScaleMaskSoftmaxDropoutOp::GetSbp(user_op::SbpContext* ctx) {\n  const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape();\n  const Shape& bias_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"bias\", 0).shape();\n  const Shape& mask_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"mask\", 
0).shape();\n  const Shape& dropout_mask_shape =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"dropout_mask\", 0).shape();\n\n  CHECK_GE_OR_RETURN(x_shape.size(), 2) << Error::RuntimeError() << \"x must have at least 2 dimensions\";\n  CHECK_EQ_OR_RETURN(dropout_mask_shape, x_shape)\n      << Error::RuntimeError() << \"dropout_mask shape \" << dropout_mask_shape.ToString()\n      << \" should be equal to x shape \" << x_shape.ToString();\n\n  HashSet<int> bias_broadcast_dims;\n  HashSet<int> mask_broadcast_dims;\n  int bias_lpad = GetBroadcastDims(bias_shape, x_shape, bias_broadcast_dims);\n  int mask_lpad = GetBroadcastDims(mask_shape, x_shape, mask_broadcast_dims);\n\n  CHECK_GE_OR_RETURN(bias_lpad, 0)\n      << Error::RuntimeError() << \"bias shape \" << bias_shape.ToString()\n      << \" could not be broadcast to x shape \" << x_shape.ToString();\n  CHECK_GE_OR_RETURN(mask_lpad, 0)\n      << Error::RuntimeError() << \"mask shape \" << mask_shape.ToString()\n      << \" could not be broadcast to x shape \" << x_shape.ToString();\n\n  std::vector<user_op::OpArg> split_args = {\n      {\"x\", 0},\n      {\"dropout_mask\", 0},\n      {\"y\", 0},\n      {\"softmax_y\", 0},\n  };\n\n  for (int i = 0; i < x_shape.size(); ++i) {\n    bool bias_can_split = (bias_broadcast_dims.find(i) == bias_broadcast_dims.end());\n    bool mask_can_split = (mask_broadcast_dims.find(i) == mask_broadcast_dims.end());\n    if (bias_can_split && mask_can_split) {\n      CHECK_GE_OR_RETURN(i, bias_lpad);  // NOLINT(maybe-need-error-msg)\n      CHECK_GE_OR_RETURN(i, mask_lpad);  // NOLINT(maybe-need-error-msg)\n      ctx->NewBuilder()\n          .Split(split_args, i)\n          .Split(user_op::OpArg(\"bias\", 0), i - bias_lpad)\n          .Split(user_op::OpArg(\"mask\", 0), i - mask_lpad)\n          .Build();\n    } else if (bias_can_split) {\n      CHECK_GE_OR_RETURN(i, bias_lpad);  // NOLINT(maybe-need-error-msg)\n      ctx->NewBuilder()\n          .Split(split_args, i)\n          .Split(user_op::OpArg(\"bias\", 0), i - bias_lpad)\n          .Broadcast(user_op::OpArg(\"mask\", 0))\n          .Build();\n    } else if (mask_can_split) {\n      CHECK_GE_OR_RETURN(i, mask_lpad);  // NOLINT(maybe-need-error-msg)\n      ctx->NewBuilder()\n          .Split(split_args, i)\n          .Broadcast(user_op::OpArg(\"bias\", 0))\n          .Split(user_op::OpArg(\"mask\", 0), i - mask_lpad)\n          .Build();\n    } else {\n      ctx->NewBuilder()\n          .Split(split_args, i)\n          .Broadcast(user_op::OpArg(\"bias\", 0))\n          .Broadcast(user_op::OpArg(\"mask\", 0))\n          .Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedBiasAddScaleMaskSoftmaxDropoutOp::ModifyInputArg(\n    const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn(\"mask\", 0);\n  user_op::InputArgModifier* dropout_mask_modifier = GetInputArgModifierFn(\"dropout_mask\", 0);\n  CHECK_OR_RETURN(mask_modifier != nullptr) << \" cannot find mask input.\";\n  CHECK_OR_RETURN(dropout_mask_modifier != nullptr) << \" cannot find dropout mask input.\";\n  mask_modifier->set_requires_grad(false);\n  dropout_mask_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_cast_scale_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nMaybe<void> FusedCastScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n  CHECK_EQ_OR_RETURN(scale_by_tensor.shape().NumAxes(), 1);\n  CHECK_EQ_OR_RETURN(scale_by_tensor.shape().At(0), 1);\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  y->set_is_dynamic(x.is_dynamic());\n  y->set_shape(x.shape());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedCastScaleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return FusedCastScaleOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedCastScaleOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  y->set_data_type(scale_by_tensor.data_type());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedCastScaleOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& x = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  for (int i = 0; i < x.shape().NumAxes(); ++i) {\n    ctx->NewBuilder()\n        .Broadcast(user_op::OpArg(\"scale_by_tensor\", 0))\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"y\", 0), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"scale_by_tensor\", 0))\n      .Broadcast(user_op::OpArg(\"x\", 0))\n      .PartialSum(user_op::OpArg(\"y\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"scale_by_tensor\", 0))\n      .PartialSum(user_op::OpArg(\"x\", 0))\n      .PartialSum(user_op::OpArg(\"y\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_center_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nMaybe<void> FusedCenterOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc(\"b1_x1\", 0);\n  const user_op::TensorDesc& b1_x2 = ctx->InputTensorDesc(\"b1_x2\", 0);\n  const user_op::TensorDesc& b1_y1 = ctx->InputTensorDesc(\"b1_y1\", 0);\n  const user_op::TensorDesc& b1_y2 = ctx->InputTensorDesc(\"b1_y2\", 0);\n  const user_op::TensorDesc& b2_x1 = ctx->InputTensorDesc(\"b2_x1\", 0);\n  const user_op::TensorDesc& b2_x2 = ctx->InputTensorDesc(\"b2_x2\", 0);\n  const user_op::TensorDesc& b2_y1 = ctx->InputTensorDesc(\"b2_y1\", 0);\n  const user_op::TensorDesc& b2_y2 = ctx->InputTensorDesc(\"b2_y2\", 0);\n\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_x2.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_y1.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_y2.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_x1.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_x2.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_y1.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_y2.shape());\n\n  user_op::TensorDesc* rho = ctx->MutOutputTensorDesc(\"rho2\", 0);\n  rho->set_is_dynamic(b1_x1.is_dynamic());\n  rho->set_shape(b1_x1.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedCenterOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return FusedCenterOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedCenterOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc(\"b1_x1\", 0);\n  const user_op::TensorDesc& b1_x2 = ctx->InputTensorDesc(\"b1_x2\", 0);\n  const user_op::TensorDesc& b1_y1 = ctx->InputTensorDesc(\"b1_y1\", 0);\n  const user_op::TensorDesc& b1_y2 = ctx->InputTensorDesc(\"b1_y2\", 0);\n  const user_op::TensorDesc& b2_x1 = ctx->InputTensorDesc(\"b2_x1\", 0);\n  const user_op::TensorDesc& b2_x2 = ctx->InputTensorDesc(\"b2_x2\", 0);\n  const user_op::TensorDesc& b2_y1 = ctx->InputTensorDesc(\"b2_y1\", 0);\n  const user_op::TensorDesc& b2_y2 = ctx->InputTensorDesc(\"b2_y2\", 0);\n\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_x2.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_y1.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_y2.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_x1.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_x2.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_y1.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_y2.data_type());\n\n  user_op::TensorDesc* rho = ctx->MutOutputTensorDesc(\"rho2\", 0);\n  rho->set_data_type(b1_x1.data_type());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedCenterOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"b1_x1\", 0);\n  FOR_RANGE(int64_t, i, 0, 
b1_x1.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"b1_x1\", 0), i)\n        .Split(user_op::OpArg(\"b1_x2\", 0), i)\n        .Split(user_op::OpArg(\"b1_y1\", 0), i)\n        .Split(user_op::OpArg(\"b1_y2\", 0), i)\n        .Split(user_op::OpArg(\"b2_x1\", 0), i)\n        .Split(user_op::OpArg(\"b2_x2\", 0), i)\n        .Split(user_op::OpArg(\"b2_y1\", 0), i)\n        .Split(user_op::OpArg(\"b2_y2\", 0), i)\n        .Split(user_op::OpArg(\"rho2\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedCenterGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc(\"b1_x1\", 0);\n  const user_op::TensorDesc& b1_x2 = ctx->InputTensorDesc(\"b1_x2\", 0);\n  const user_op::TensorDesc& b1_y1 = ctx->InputTensorDesc(\"b1_y1\", 0);\n  const user_op::TensorDesc& b1_y2 = ctx->InputTensorDesc(\"b1_y2\", 0);\n  const user_op::TensorDesc& b2_x1 = ctx->InputTensorDesc(\"b2_x1\", 0);\n  const user_op::TensorDesc& b2_x2 = ctx->InputTensorDesc(\"b2_x2\", 0);\n  const user_op::TensorDesc& b2_y1 = ctx->InputTensorDesc(\"b2_y1\", 0);\n  const user_op::TensorDesc& b2_y2 = ctx->InputTensorDesc(\"b2_y2\", 0);\n  const user_op::TensorDesc& rho2_diff = ctx->InputTensorDesc(\"rho2_diff\", 0);\n\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_x2.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_y1.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_y2.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_x1.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_x2.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_y1.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_y2.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), rho2_diff.shape());\n\n  user_op::TensorDesc* b1_x1_diff = ctx->MutOutputTensorDesc(\"b1_x1_diff\", 0);\n  b1_x1_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b1_x1_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b1_x2_diff = ctx->MutOutputTensorDesc(\"b1_x2_diff\", 0);\n  b1_x2_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b1_x2_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b2_x1_diff = ctx->MutOutputTensorDesc(\"b2_x1_diff\", 0);\n  b2_x1_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b2_x1_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b2_x2_diff = ctx->MutOutputTensorDesc(\"b2_x2_diff\", 0);\n  b2_x2_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b2_x2_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b1_y1_diff = ctx->MutOutputTensorDesc(\"b1_y1_diff\", 0);\n  b1_y1_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b1_y1_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b1_y2_diff = ctx->MutOutputTensorDesc(\"b1_y2_diff\", 0);\n  b1_y2_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b1_y2_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b2_y1_diff = ctx->MutOutputTensorDesc(\"b2_y1_diff\", 0);\n  b2_y1_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b2_y1_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b2_y2_diff = ctx->MutOutputTensorDesc(\"b2_y2_diff\", 0);\n  b2_y2_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b2_y2_diff->set_shape(b1_x1.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedCenterGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return FusedCenterGradOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedCenterGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc(\"b1_x1\", 0);\n  const user_op::TensorDesc& b1_x2 = ctx->InputTensorDesc(\"b1_x2\", 
0);\n  const user_op::TensorDesc& b1_y1 = ctx->InputTensorDesc(\"b1_y1\", 0);\n  const user_op::TensorDesc& b1_y2 = ctx->InputTensorDesc(\"b1_y2\", 0);\n  const user_op::TensorDesc& b2_x1 = ctx->InputTensorDesc(\"b2_x1\", 0);\n  const user_op::TensorDesc& b2_x2 = ctx->InputTensorDesc(\"b2_x2\", 0);\n  const user_op::TensorDesc& b2_y1 = ctx->InputTensorDesc(\"b2_y1\", 0);\n  const user_op::TensorDesc& b2_y2 = ctx->InputTensorDesc(\"b2_y2\", 0);\n  const user_op::TensorDesc& rho2_diff = ctx->InputTensorDesc(\"rho2_diff\", 0);\n\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_x2.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_y1.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_y2.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_x1.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_x2.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_y1.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_y2.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), rho2_diff.data_type());\n\n  user_op::TensorDesc* b1_x1_diff = ctx->MutOutputTensorDesc(\"b1_x1_diff\", 0);\n  b1_x1_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b1_x2_diff = ctx->MutOutputTensorDesc(\"b1_x2_diff\", 0);\n  b1_x2_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b2_x1_diff = ctx->MutOutputTensorDesc(\"b2_x1_diff\", 0);\n  b2_x1_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b2_x2_diff = ctx->MutOutputTensorDesc(\"b2_x2_diff\", 0);\n  b2_x2_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b1_y1_diff = ctx->MutOutputTensorDesc(\"b1_y1_diff\", 0);\n  b1_y1_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b1_y2_diff = ctx->MutOutputTensorDesc(\"b1_y2_diff\", 0);\n  b1_y2_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b2_y1_diff = ctx->MutOutputTensorDesc(\"b2_y1_diff\", 0);\n  b2_y1_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b2_y2_diff = ctx->MutOutputTensorDesc(\"b2_y2_diff\", 0);\n  b2_y2_diff->set_data_type(b1_x1.data_type());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedCenterGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"b1_x1\", 0);\n  FOR_RANGE(int64_t, i, 0, b1_x1.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"b1_x1\", 0), i)\n        .Split(user_op::OpArg(\"b1_x2\", 0), i)\n        .Split(user_op::OpArg(\"b1_y1\", 0), i)\n        .Split(user_op::OpArg(\"b1_y2\", 0), i)\n        .Split(user_op::OpArg(\"b2_x1\", 0), i)\n        .Split(user_op::OpArg(\"b2_x2\", 0), i)\n        .Split(user_op::OpArg(\"b2_y1\", 0), i)\n        .Split(user_op::OpArg(\"b2_y2\", 0), i)\n        .Split(user_op::OpArg(\"rho2_diff\", 0), i)\n        .Split(user_op::OpArg(\"b1_x1_diff\", 0), i)\n        .Split(user_op::OpArg(\"b1_x2_diff\", 0), i)\n        .Split(user_op::OpArg(\"b1_y1_diff\", 0), i)\n        .Split(user_op::OpArg(\"b1_y2_diff\", 0), i)\n        .Split(user_op::OpArg(\"b2_x1_diff\", 0), i)\n        .Split(user_op::OpArg(\"b2_x2_diff\", 0), i)\n        .Split(user_op::OpArg(\"b2_y1_diff\", 0), i)\n        .Split(user_op::OpArg(\"b2_y2_diff\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_clip_grad_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/framework/user_op_registry.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> SetInputArgModifierMutable(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                       const std::string& arg_name, int32_t arg_index) {\n  user_op::InputArgModifier* arg_modifier = GetInputArgModifierFn(arg_name, arg_index);\n  CHECK_NOTNULL_OR_RETURN(arg_modifier) << \"Arg Modifier should not be null. \";\n  arg_modifier->set_is_mutable(true);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                             const user_op::UserOpConfWrapper& conf) {\n  for (int64_t i = 0; i < conf.input_size(\"model_diff\"); i++) {\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model_diff\", i));\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> FusedClipGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const auto& in_0 = ctx->InputTensorDesc(\"model_diff\", 0);\n  auto* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  for (int64_t i = 1; i < ctx->input_size(\"model_diff\"); ++i) {\n    const auto& cur_in = ctx->InputTensorDesc(\"model_diff\", i);\n    CHECK_EQ_OR_RETURN(in_0.shape(), cur_in.shape())\n        << Error::RuntimeError()\n        << \"inconsistent tensor size, expected all tensor to have the same shape, \"\n        << \"but got \" << in_0.shape().DebugStr() << \" and \" << cur_in.shape().DebugStr();\n  }\n  out->set_shape(Shape({1}));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FusedClipGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedClipGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const int64_t num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"model_diff\", 0).shape().NumAxes();\n  for (int64_t i = 0; i < num_axes; ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(user_op::OpArg(\"out\", 0)).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedClipGradOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return InputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> FusedClipGradOp::InferDataType(user_op::InferContext* ctx) {\n  const auto& in_0 = ctx->InputTensorDesc(\"model_diff\", 0);\n  auto* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  const DataType data_type = in_0.data_type();\n  for (int64_t i = 1; i < ctx->input_size(\"model_diff\"); ++i) {\n    
const auto& cur_in = ctx->InputTensorDesc(\"model_diff\", i);\n    CHECK_EQ_OR_RETURN(cur_in.data_type(), data_type)\n        << Error::RuntimeError() << ctx->op_name()\n        << \" expected all tensors to have the same type, but found \" << DataType_Name(cur_in.data_type())\n        << \" and \" << DataType_Name(data_type);\n  }\n  out->set_data_type(data_type);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_codegeex_qkv_reshape.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/shape_vec.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nMaybe<void> FusedCodegeexQkvReshapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& query = ctx->InputTensorDesc(\"query\", 0);\n  const user_op::TensorDesc& key = ctx->InputTensorDesc(\"key\", 0);\n  const user_op::TensorDesc& value = ctx->InputTensorDesc(\"value\", 0);\n  const int32_t num_attention_heads = ctx->Attr<int32_t>(\"num_attention_heads\");\n  CHECK_EQ_OR_RETURN(query.shape().size(), 3) << \"query shape size should be equal 3\";\n  CHECK_EQ_OR_RETURN(key.shape().size(), 3) << \"key shape size should be equal 3\";\n  CHECK_EQ_OR_RETURN(value.shape().size(), 3) << \"value shape size should be equal 3\";\n  CHECK_EQ_OR_RETURN(query.shape(), key.shape())\n      << \"query, key, value should has same shape in codegeex attention block\";\n  CHECK_EQ_OR_RETURN(query.shape(), value.shape())\n      << \"query, key, value should has same shape in codegeex attention block\";\n  CHECK_EQ_OR_RETURN(query.shape()[2] % num_attention_heads, 0)\n      << \"hidden_size must be divisible by num_attention_heads\";\n\n  Shape new_shape(DimVector{query.shape()[0], query.shape()[1], num_attention_heads,\n                            query.shape()[2] / num_attention_heads});\n  user_op::TensorDesc* new_query = ctx->MutOutputTensorDesc(\"new_query\", 0);\n  new_query->set_is_dynamic(query.is_dynamic());\n  new_query->set_shape(new_shape);\n\n  user_op::TensorDesc* new_key = ctx->MutOutputTensorDesc(\"new_key\", 0);\n  new_key->set_is_dynamic(key.is_dynamic());\n  new_key->set_shape(new_shape);\n\n  user_op::TensorDesc* new_value = ctx->MutOutputTensorDesc(\"new_value\", 0);\n  new_value->set_is_dynamic(value.is_dynamic());\n  new_value->set_shape(new_shape);\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedCodegeexQkvReshapeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return FusedCodegeexQkvReshapeOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedCodegeexQkvReshapeOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& query = ctx->InputTensorDesc(\"query\", 0);\n  const user_op::TensorDesc& key = ctx->InputTensorDesc(\"key\", 0);\n  const user_op::TensorDesc& value = ctx->InputTensorDesc(\"value\", 0);\n\n  user_op::TensorDesc* new_query = ctx->MutOutputTensorDesc(\"new_query\", 0);\n  new_query->set_data_type(query.data_type());\n  user_op::TensorDesc* new_key = ctx->MutOutputTensorDesc(\"new_key\", 0);\n  new_key->set_data_type(key.data_type());\n  user_op::TensorDesc* new_value = ctx->MutOutputTensorDesc(\"new_value\", 0);\n  new_value->set_data_type(value.data_type());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedCodegeexQkvReshapeOp::GetSbp(user_op::SbpContext* ctx) {\n  const 
user_op::TensorDesc& query = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"query\", 0);\n  FOR_RANGE(int64_t, i, 0, query.shape().NumAxes() - 1) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"query\", 0), i)\n        .Split(user_op::OpArg(\"key\", 0), i)\n        .Split(user_op::OpArg(\"value\", 0), i)\n        .Split(user_op::OpArg(\"new_query\", 0), i)\n        .Split(user_op::OpArg(\"new_key\", 0), i)\n        .Split(user_op::OpArg(\"new_value\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_cross_feature_interaction_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> FusedCrossFeatureInteractionOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& weight_shape = ctx->InputShape(\"weight\", 0);\n  CHECK_EQ_OR_RETURN(x_shape.At(1), weight_shape.At(1)) << \"Matmul K dims should be equal. \";\n  ctx->SetOutputShape(\"matmul_result\", 0, Shape({x_shape.At(0), weight_shape.At(0)}));\n  const Shape& x0_shape = ctx->InputShape(\"x0\", 0);\n  const Shape& bias_shape = ctx->InputShape(\"bias\", 0);\n  CHECK_EQ_OR_RETURN(bias_shape.At(0), x0_shape.At(1)) << \"Bias dim should be equal to X0 dim1. \";\n  ctx->SetOutputShape(\"out\", 0, x0_shape);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedCrossFeatureInteractionOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedCrossFeatureInteractionOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Broadcast(user_op::OpArg(\"weight\", 0))\n      .Split(user_op::OpArg(\"x0\", 0), 0)\n      .Broadcast(user_op::OpArg(\"bias\", 0))\n      .Split(user_op::OpArg(\"matmul_result\", 0), 0)\n      .Split(user_op::OpArg(\"out\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedCrossFeatureInteractionOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"matmul_result\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedCrossFeatureInteractionV1GradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const Shape& x0_shape = ctx->InputShape(\"x0\", 0);\n  const Shape& weight_shape = ctx->InputShape(\"weight\", 0);\n  ctx->SetOutputShape(\"dx0\", 0, x0_shape);\n  ctx->SetOutputShape(\"dw\", 0, weight_shape);\n  ctx->SetOutputShape(\"dx\", 0, x0_shape);\n  ctx->SetOutputShape(\"dbias\", 0, Shape({x0_shape.At(1)}));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedCrossFeatureInteractionV1GradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedCrossFeatureInteractionV1GradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Broadcast(user_op::OpArg(\"weight\", 0))\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Split(user_op::OpArg(\"x0\", 0), 0)\n      .Split(user_op::OpArg(\"matmul_result\", 0), 0)\n      .Split(user_op::OpArg(\"dx0\", 0), 0)\n      .PartialSum(user_op::OpArg(\"dw\", 0))\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      
.PartialSum(user_op::OpArg(\"dbias\", 0))\n      .Build();\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedCrossFeatureInteractionV1GradOp::InferDataType(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx0\", 0, ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dw\", 0, ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dbias\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedCrossFeatureInteractionV2GradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const Shape& x0_shape = ctx->InputShape(\"x0\", 0);\n  const Shape& weight_shape = ctx->InputShape(\"weight\", 0);\n  ctx->SetOutputShape(\"dx0\", 0, x0_shape);\n  ctx->SetOutputShape(\"dw\", 0, weight_shape);\n  ctx->SetOutputShape(\"dx\", 0, x0_shape);\n  ctx->SetOutputShape(\"dbias\", 0, Shape({x0_shape.At(1)}));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedCrossFeatureInteractionV2GradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedCrossFeatureInteractionV2GradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Broadcast(user_op::OpArg(\"weight\", 0))\n      .Broadcast(user_op::OpArg(\"bias\", 0))\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Split(user_op::OpArg(\"x0\", 0), 0)\n      .Split(user_op::OpArg(\"matmul_result\", 0), 0)\n      .Split(user_op::OpArg(\"dx0\", 0), 0)\n      .PartialSum(user_op::OpArg(\"dw\", 0))\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .PartialSum(user_op::OpArg(\"dbias\", 0))\n      .Build();\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedCrossFeatureInteractionV2GradOp::InferDataType(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx0\", 0, ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dw\", 0, ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dbias\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_dot_feature_interaction_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> FusedDotFeatureInteractionOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const int64_t feature_input_size = ctx->input_size(\"features\");\n  CHECK_GE_OR_RETURN(feature_input_size, 1);\n  const Shape& first_feature_shape = ctx->InputShape(\"features\", 0);\n  CHECK_EQ_OR_RETURN(first_feature_shape.NumAxes(), 3);\n  const int64_t batch_size = first_feature_shape.At(0);\n  const int64_t vector_size = first_feature_shape.At(2);\n  int64_t features_concated_dim = first_feature_shape.At(1);\n  for (int64_t i = 1; i < feature_input_size; ++i) {\n    const Shape& feature_shape = ctx->InputShape(\"features\", i);\n    CHECK_EQ_OR_RETURN(feature_shape.NumAxes(), 3);\n    CHECK_EQ_OR_RETURN(feature_shape.At(0), batch_size);\n    CHECK_EQ_OR_RETURN(feature_shape.At(2), vector_size);\n    features_concated_dim += feature_shape.At(1);\n  }\n  const std::string& pooling = ctx->Attr<std::string>(\"pooling\");\n  if (pooling == \"sum\") {\n    ctx->SetOutputShape(\"out\", 0, Shape({batch_size, vector_size}));\n    return Maybe<void>::Ok();\n  }\n  if (ctx->has_input(\"sparse_feature\", 0)) {\n    CHECK_OR_RETURN(pooling == \"none\") << \"only none pooling support sparse feature.\";\n    CHECK_OR_RETURN(ctx->has_input(\"sparse_indices\", 0))\n        << \"if input sparse_feature exists, must have input sparse_indices.\";\n    const Shape& sparse_feature_shape = ctx->InputShape(\"sparse_feature\", 0);\n    const Shape& sparse_indices_shape = ctx->InputShape(\"sparse_indices\", 0);\n    CHECK_EQ_OR_RETURN(sparse_indices_shape.NumAxes(), 2)\n        << \"sparse_indices num_axes must be 2, but get \" << sparse_indices_shape.NumAxes();\n    CHECK_EQ_OR_RETURN(sparse_indices_shape.At(0), batch_size)\n        << \"get \" << sparse_indices_shape.At(0) << \" and \" << batch_size;\n    CHECK_EQ_OR_RETURN(sparse_feature_shape.At(sparse_feature_shape.NumAxes() - 1), vector_size)\n        << \"get \" << sparse_feature_shape.At(sparse_feature_shape.NumAxes() - 1) << \" and \"\n        << vector_size;\n    features_concated_dim += sparse_indices_shape.At(1);\n  }\n  const bool self_interaction = ctx->Attr<bool>(\"self_interaction\");\n  const int32_t output_padding = ctx->Attr<int32_t>(\"output_padding\");\n  const int64_t interaction_dim = self_interaction\n                                      ? 
features_concated_dim * (features_concated_dim + 1) / 2\n                                      : features_concated_dim * (features_concated_dim - 1) / 2;\n  int64_t out_dim = interaction_dim + output_padding;\n  if (ctx->has_input(\"output_concat\", 0)) {\n    const Shape& output_concat_shape = ctx->InputShape(\"output_concat\", 0);\n    CHECK_EQ_OR_RETURN(output_concat_shape.NumAxes(), 2);\n    CHECK_EQ_OR_RETURN(output_concat_shape.At(0), batch_size);\n    out_dim += output_concat_shape.At(1);\n  }\n  ctx->SetOutputShape(\"out\", 0, Shape({batch_size, out_dim}));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedDotFeatureInteractionOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedDotFeatureInteractionOp::GetSbp(user_op::SbpContext* ctx) {\n  auto builder = ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0);\n  if (ctx->user_op_conf().has_input(\"num_valid_sparse_feature\", 0)) {\n    builder.Broadcast(user_op::OpArg(\"num_valid_sparse_feature\", 0));\n  }\n  builder.Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedDotFeatureInteractionOp::InferDataType(user_op::InferContext* ctx) {\n  const int64_t feature_input_size = ctx->input_size(\"features\");\n  CHECK_GE_OR_RETURN(feature_input_size, 1);\n  DataType first_feature_dtype = ctx->InputDType(\"features\", 0);\n  for (int64_t i = 1; i < feature_input_size; ++i) {\n    CHECK_EQ_OR_RETURN(first_feature_dtype, ctx->InputDType(\"features\", i))\n        << \"InferDataType Failed. Expected \" << DataType_Name(first_feature_dtype)\n        << \", but got \" << DataType_Name(ctx->InputDType(\"features\", i));\n  }\n  if (ctx->has_input(\"output_concat\", 0)) {\n    CHECK_EQ_OR_RETURN(first_feature_dtype, ctx->InputDType(\"output_concat\", 0))\n        << \"InferDataType Failed. Expected \" << DataType_Name(first_feature_dtype)\n        << \", but got \" << DataType_Name(ctx->InputDType(\"output_concat\", 0));\n  }\n  if (ctx->has_input(\"sparse_feature\", 0)) {\n    CHECK_EQ_OR_RETURN(first_feature_dtype, ctx->InputDType(\"sparse_feature\", 0))\n        << \"InferDataType Failed. 
Expected \" << DataType_Name(ctx->InputDType(\"sparse_feature\", 0))\n        << \", but got \" << DataType_Name(first_feature_dtype);\n  }\n  ctx->SetOutputDType(\"out\", 0, first_feature_dtype);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedDotFeatureInteractionGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  const int64_t batch_size = dy_shape.At(0);\n  CHECK_EQ_OR_RETURN(ctx->output_size(\"features_grad\"), ctx->input_size(\"features\"))\n      << \"features_grad and features must have same size\";\n  for (int64_t i = 0; i < ctx->output_size(\"features_grad\"); ++i) {\n    ctx->SetOutputShape(\"features_grad\", i, ctx->InputShape(\"features\", i));\n  }\n  if (ctx->has_output(\"output_concat_grad\", 0)) {\n    const int32_t output_concat_grad_dim = ctx->Attr<int32_t>(\"output_concat_grad_dim\");\n    ctx->SetOutputShape(\"output_concat_grad\", 0, Shape({batch_size, output_concat_grad_dim}));\n  }\n  if (ctx->has_output(\"sparse_feature_grad\", 0)) {\n    ctx->SetOutputShape(\"sparse_feature_grad\", 0, ctx->InputShape(\"sparse_feature\", 0));\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedDotFeatureInteractionGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedDotFeatureInteractionGradOp::GetSbp(user_op::SbpContext* ctx) {\n  auto builder = ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0);\n  if (ctx->user_op_conf().has_input(\"num_valid_sparse_feature\", 0)) {\n    builder.Broadcast(user_op::OpArg(\"num_valid_sparse_feature\", 0));\n  }\n  builder.Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedDotFeatureInteractionGradOp::InferDataType(\n    user_op::InferContext* ctx) {\n  DataType dy_dtype = ctx->InputDType(\"dy\", 0);\n  for (int64_t i = 0; i < ctx->output_size(\"features_grad\"); ++i) {\n    ctx->SetOutputDType(\"features_grad\", i, dy_dtype);\n  }\n  if (ctx->has_output(\"output_concat_grad\", 0)) {\n    ctx->SetOutputDType(\"output_concat_grad\", 0, dy_dtype);\n  }\n  if (ctx->has_output(\"sparse_feature_grad\", 0)) {\n    ctx->SetOutputDType(\"sparse_feature_grad\", 0, dy_dtype);\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_get_boundding_boxes_coord_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nMaybe<void> FusedGetBounddingBoxesCoordOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x1 = ctx->InputTensorDesc(\"x1\", 0);\n  Shape x1_shape = x1.shape();\n\n  user_op::TensorDesc* b1_x1 = ctx->MutOutputTensorDesc(\"b1_x1\", 0);\n  b1_x1->set_is_dynamic(x1.is_dynamic());\n  b1_x1->set_shape(x1_shape);\n\n  user_op::TensorDesc* b1_x2 = ctx->MutOutputTensorDesc(\"b1_x2\", 0);\n  b1_x2->set_is_dynamic(x1.is_dynamic());\n  b1_x2->set_shape(x1_shape);\n\n  user_op::TensorDesc* b1_y1 = ctx->MutOutputTensorDesc(\"b1_y1\", 0);\n  b1_y1->set_is_dynamic(x1.is_dynamic());\n  b1_y1->set_shape(x1_shape);\n\n  user_op::TensorDesc* b1_y2 = ctx->MutOutputTensorDesc(\"b1_y2\", 0);\n  b1_y2->set_is_dynamic(x1.is_dynamic());\n  b1_y2->set_shape(x1_shape);\n\n  user_op::TensorDesc* b2_x1 = ctx->MutOutputTensorDesc(\"b2_x1\", 0);\n  b2_x1->set_is_dynamic(x1.is_dynamic());\n  b2_x1->set_shape(x1_shape);\n\n  user_op::TensorDesc* b2_x2 = ctx->MutOutputTensorDesc(\"b2_x2\", 0);\n  b2_x2->set_is_dynamic(x1.is_dynamic());\n  b2_x2->set_shape(x1_shape);\n\n  user_op::TensorDesc* b2_y1 = ctx->MutOutputTensorDesc(\"b2_y1\", 0);\n  b2_y1->set_is_dynamic(x1.is_dynamic());\n  b2_y1->set_shape(x1_shape);\n\n  user_op::TensorDesc* b2_y2 = ctx->MutOutputTensorDesc(\"b2_y2\", 0);\n  b2_y2->set_is_dynamic(x1.is_dynamic());\n  b2_y2->set_shape(x1_shape);\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetBounddingBoxesCoordOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return FusedGetBounddingBoxesCoordOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedGetBounddingBoxesCoordOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x1 = ctx->InputTensorDesc(\"x1\", 0);\n\n  user_op::TensorDesc* b1_x1 = ctx->MutOutputTensorDesc(\"b1_x1\", 0);\n  b1_x1->set_data_type(x1.data_type());\n\n  user_op::TensorDesc* b1_x2 = ctx->MutOutputTensorDesc(\"b1_x2\", 0);\n  b1_x2->set_data_type(x1.data_type());\n\n  user_op::TensorDesc* b1_y1 = ctx->MutOutputTensorDesc(\"b1_y1\", 0);\n  b1_y1->set_data_type(x1.data_type());\n\n  user_op::TensorDesc* b1_y2 = ctx->MutOutputTensorDesc(\"b1_y2\", 0);\n  b1_y2->set_data_type(x1.data_type());\n\n  user_op::TensorDesc* b2_x1 = ctx->MutOutputTensorDesc(\"b2_x1\", 0);\n  b2_x1->set_data_type(x1.data_type());\n\n  user_op::TensorDesc* b2_x2 = ctx->MutOutputTensorDesc(\"b2_x2\", 0);\n  b2_x2->set_data_type(x1.data_type());\n\n  user_op::TensorDesc* b2_y1 = ctx->MutOutputTensorDesc(\"b2_y1\", 0);\n  b2_y1->set_data_type(x1.data_type());\n\n  user_op::TensorDesc* b2_y2 = ctx->MutOutputTensorDesc(\"b2_y2\", 0);\n  b2_y2->set_data_type(x1.data_type());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetBounddingBoxesCoordOp::GetSbp(user_op::SbpContext* ctx) {\n  const 
user_op::TensorDesc& x1 = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x1\", 0);\n  FOR_RANGE(int64_t, i, 0, x1.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x1\", 0), i)\n        .Split(user_op::OpArg(\"y1\", 0), i)\n        .Split(user_op::OpArg(\"w1\", 0), i)\n        .Split(user_op::OpArg(\"h1\", 0), i)\n        .Split(user_op::OpArg(\"x2\", 0), i)\n        .Split(user_op::OpArg(\"y2\", 0), i)\n        .Split(user_op::OpArg(\"w2\", 0), i)\n        .Split(user_op::OpArg(\"h2\", 0), i)\n        .Split(user_op::OpArg(\"b1_x1\", 0), i)\n        .Split(user_op::OpArg(\"b1_x2\", 0), i)\n        .Split(user_op::OpArg(\"b1_y1\", 0), i)\n        .Split(user_op::OpArg(\"b1_y2\", 0), i)\n        .Split(user_op::OpArg(\"b2_x1\", 0), i)\n        .Split(user_op::OpArg(\"b2_x2\", 0), i)\n        .Split(user_op::OpArg(\"b2_y1\", 0), i)\n        .Split(user_op::OpArg(\"b2_y2\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetBounddingBoxesCoordGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& b1_x1_diff = ctx->InputTensorDesc(\"b1_x1_diff\", 0);\n\n  user_op::TensorDesc* x1_diff = ctx->MutOutputTensorDesc(\"x1_diff\", 0);\n  x1_diff->set_is_dynamic(b1_x1_diff.is_dynamic());\n  x1_diff->set_shape(b1_x1_diff.shape());\n\n  user_op::TensorDesc* y1_diff = ctx->MutOutputTensorDesc(\"y1_diff\", 0);\n  y1_diff->set_is_dynamic(b1_x1_diff.is_dynamic());\n  y1_diff->set_shape(b1_x1_diff.shape());\n\n  user_op::TensorDesc* w1_diff = ctx->MutOutputTensorDesc(\"w1_diff\", 0);\n  w1_diff->set_is_dynamic(b1_x1_diff.is_dynamic());\n  w1_diff->set_shape(b1_x1_diff.shape());\n\n  user_op::TensorDesc* h1_diff = ctx->MutOutputTensorDesc(\"h1_diff\", 0);\n  h1_diff->set_is_dynamic(b1_x1_diff.is_dynamic());\n  h1_diff->set_shape(b1_x1_diff.shape());\n\n  user_op::TensorDesc* x2_diff = ctx->MutOutputTensorDesc(\"x2_diff\", 0);\n  x2_diff->set_is_dynamic(b1_x1_diff.is_dynamic());\n  x2_diff->set_shape(b1_x1_diff.shape());\n\n  user_op::TensorDesc* y2_diff = ctx->MutOutputTensorDesc(\"y2_diff\", 0);\n  y2_diff->set_is_dynamic(b1_x1_diff.is_dynamic());\n  y2_diff->set_shape(b1_x1_diff.shape());\n\n  user_op::TensorDesc* w2_diff = ctx->MutOutputTensorDesc(\"w2_diff\", 0);\n  w2_diff->set_is_dynamic(b1_x1_diff.is_dynamic());\n  w2_diff->set_shape(b1_x1_diff.shape());\n\n  user_op::TensorDesc* h2_diff = ctx->MutOutputTensorDesc(\"h2_diff\", 0);\n  h2_diff->set_is_dynamic(b1_x1_diff.is_dynamic());\n  h2_diff->set_shape(b1_x1_diff.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetBounddingBoxesCoordGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return FusedGetBounddingBoxesCoordGradOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedGetBounddingBoxesCoordGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& b1_x1_diff = ctx->InputTensorDesc(\"b1_x1_diff\", 0);\n\n  user_op::TensorDesc* x1_diff = ctx->MutOutputTensorDesc(\"x1_diff\", 0);\n  x1_diff->set_data_type(b1_x1_diff.data_type());\n\n  user_op::TensorDesc* y1_diff = ctx->MutOutputTensorDesc(\"y1_diff\", 0);\n  y1_diff->set_data_type(b1_x1_diff.data_type());\n\n  user_op::TensorDesc* w1_diff = ctx->MutOutputTensorDesc(\"w1_diff\", 0);\n  w1_diff->set_data_type(b1_x1_diff.data_type());\n\n  user_op::TensorDesc* h1_diff = ctx->MutOutputTensorDesc(\"h1_diff\", 0);\n  h1_diff->set_data_type(b1_x1_diff.data_type());\n\n  user_op::TensorDesc* x2_diff = ctx->MutOutputTensorDesc(\"x2_diff\", 0);\n  
x2_diff->set_data_type(b1_x1_diff.data_type());\n\n  user_op::TensorDesc* y2_diff = ctx->MutOutputTensorDesc(\"y2_diff\", 0);\n  y2_diff->set_data_type(b1_x1_diff.data_type());\n\n  user_op::TensorDesc* w2_diff = ctx->MutOutputTensorDesc(\"w2_diff\", 0);\n  w2_diff->set_data_type(b1_x1_diff.data_type());\n\n  user_op::TensorDesc* h2_diff = ctx->MutOutputTensorDesc(\"h2_diff\", 0);\n  h2_diff->set_data_type(b1_x1_diff.data_type());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetBounddingBoxesCoordGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& b1_x1_diff =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"b1_x1_diff\", 0);\n  FOR_RANGE(int64_t, i, 0, b1_x1_diff.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"b1_x1_diff\", 0), i)\n        .Split(user_op::OpArg(\"b1_x2_diff\", 0), i)\n        .Split(user_op::OpArg(\"b1_y1_diff\", 0), i)\n        .Split(user_op::OpArg(\"b1_y2_diff\", 0), i)\n        .Split(user_op::OpArg(\"b2_x1_diff\", 0), i)\n        .Split(user_op::OpArg(\"b2_x2_diff\", 0), i)\n        .Split(user_op::OpArg(\"b2_y1_diff\", 0), i)\n        .Split(user_op::OpArg(\"b2_y2_diff\", 0), i)\n        .Split(user_op::OpArg(\"x1_diff\", 0), i)\n        .Split(user_op::OpArg(\"y1_diff\", 0), i)\n        .Split(user_op::OpArg(\"w1_diff\", 0), i)\n        .Split(user_op::OpArg(\"h1_diff\", 0), i)\n        .Split(user_op::OpArg(\"x2_diff\", 0), i)\n        .Split(user_op::OpArg(\"y2_diff\", 0), i)\n        .Split(user_op::OpArg(\"w2_diff\", 0), i)\n        .Split(user_op::OpArg(\"h2_diff\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_get_ciou_diagonal_angle_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nMaybe<void> FusedGetCiouDiagonalAngleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& w1 = ctx->InputTensorDesc(\"w1\", 0);\n  const user_op::TensorDesc& h1 = ctx->InputTensorDesc(\"h1\", 0);\n  const user_op::TensorDesc& w2 = ctx->InputTensorDesc(\"w2\", 0);\n  const user_op::TensorDesc& h2 = ctx->InputTensorDesc(\"h2\", 0);\n\n  CHECK_EQ_OR_RETURN(w1.shape(), h1.shape());\n  CHECK_EQ_OR_RETURN(w1.shape(), w2.shape());\n  CHECK_EQ_OR_RETURN(w1.shape(), h2.shape());\n\n  user_op::TensorDesc* v = ctx->MutOutputTensorDesc(\"v\", 0);\n  v->set_is_dynamic(w1.is_dynamic());\n  v->set_shape(w1.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetCiouDiagonalAngleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return FusedGetCiouDiagonalAngleOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedGetCiouDiagonalAngleOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& w1 = ctx->InputTensorDesc(\"w1\", 0);\n  const user_op::TensorDesc& h1 = ctx->InputTensorDesc(\"h1\", 0);\n  const user_op::TensorDesc& w2 = ctx->InputTensorDesc(\"w2\", 0);\n  const user_op::TensorDesc& h2 = ctx->InputTensorDesc(\"h2\", 0);\n\n  CHECK_EQ_OR_RETURN(w1.data_type(), h1.data_type());\n  CHECK_EQ_OR_RETURN(w1.data_type(), w2.data_type());\n  CHECK_EQ_OR_RETURN(w1.data_type(), h2.data_type());\n\n  user_op::TensorDesc* v = ctx->MutOutputTensorDesc(\"v\", 0);\n  v->set_data_type(w1.data_type());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetCiouDiagonalAngleOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"w1\", 0);\n  FOR_RANGE(int64_t, i, 0, b1_x1.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"w1\", 0), i)\n        .Split(user_op::OpArg(\"h1\", 0), i)\n        .Split(user_op::OpArg(\"w2\", 0), i)\n        .Split(user_op::OpArg(\"h2\", 0), i)\n        .Split(user_op::OpArg(\"v\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetCiouDiagonalAngleGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& w1 = ctx->InputTensorDesc(\"w1\", 0);\n  const user_op::TensorDesc& h1 = ctx->InputTensorDesc(\"h1\", 0);\n  const user_op::TensorDesc& w2 = ctx->InputTensorDesc(\"w2\", 0);\n  const user_op::TensorDesc& h2 = ctx->InputTensorDesc(\"h2\", 0);\n  const user_op::TensorDesc& v_diff = ctx->InputTensorDesc(\"v_diff\", 0);\n\n  CHECK_EQ_OR_RETURN(w1.shape(), h1.shape());\n  CHECK_EQ_OR_RETURN(w1.shape(), w2.shape());\n  CHECK_EQ_OR_RETURN(w1.shape(), h2.shape());\n  CHECK_EQ_OR_RETURN(w1.shape(), v_diff.shape());\n\n  user_op::TensorDesc* w1_diff = ctx->MutOutputTensorDesc(\"w1_diff\", 0);\n  
w1_diff->set_is_dynamic(w1.is_dynamic());\n  w1_diff->set_shape(w1.shape());\n\n  user_op::TensorDesc* h1_diff = ctx->MutOutputTensorDesc(\"h1_diff\", 0);\n  h1_diff->set_is_dynamic(w1.is_dynamic());\n  h1_diff->set_shape(w1.shape());\n\n  user_op::TensorDesc* w2_diff = ctx->MutOutputTensorDesc(\"w2_diff\", 0);\n  w2_diff->set_is_dynamic(w1.is_dynamic());\n  w2_diff->set_shape(w1.shape());\n\n  user_op::TensorDesc* h2_diff = ctx->MutOutputTensorDesc(\"h2_diff\", 0);\n  h2_diff->set_is_dynamic(w1.is_dynamic());\n  h2_diff->set_shape(w1.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetCiouDiagonalAngleGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return FusedGetCiouDiagonalAngleGradOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedGetCiouDiagonalAngleGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& w1 = ctx->InputTensorDesc(\"w1\", 0);\n  const user_op::TensorDesc& h1 = ctx->InputTensorDesc(\"h1\", 0);\n  const user_op::TensorDesc& w2 = ctx->InputTensorDesc(\"w2\", 0);\n  const user_op::TensorDesc& h2 = ctx->InputTensorDesc(\"h2\", 0);\n  const user_op::TensorDesc& v_diff = ctx->InputTensorDesc(\"v_diff\", 0);\n\n  CHECK_EQ_OR_RETURN(w1.data_type(), h1.data_type());\n  CHECK_EQ_OR_RETURN(w1.data_type(), w2.data_type());\n  CHECK_EQ_OR_RETURN(w1.data_type(), h2.data_type());\n  CHECK_EQ_OR_RETURN(w1.data_type(), v_diff.data_type());\n\n  user_op::TensorDesc* w1_diff = ctx->MutOutputTensorDesc(\"w1_diff\", 0);\n  w1_diff->set_is_dynamic(w1.is_dynamic());\n  w1_diff->set_data_type(w1.data_type());\n\n  user_op::TensorDesc* h1_diff = ctx->MutOutputTensorDesc(\"h1_diff\", 0);\n  h1_diff->set_is_dynamic(w1.is_dynamic());\n  h1_diff->set_data_type(w1.data_type());\n\n  user_op::TensorDesc* w2_diff = ctx->MutOutputTensorDesc(\"w2_diff\", 0);\n  w2_diff->set_is_dynamic(w1.is_dynamic());\n  w2_diff->set_data_type(w1.data_type());\n\n  user_op::TensorDesc* h2_diff = ctx->MutOutputTensorDesc(\"h2_diff\", 0);\n  h2_diff->set_is_dynamic(w1.is_dynamic());\n  h2_diff->set_data_type(w1.data_type());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetCiouDiagonalAngleGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& w1 = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"w1\", 0);\n  FOR_RANGE(int64_t, i, 0, w1.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"w1\", 0), i)\n        .Split(user_op::OpArg(\"h1\", 0), i)\n        .Split(user_op::OpArg(\"w2\", 0), i)\n        .Split(user_op::OpArg(\"h2\", 0), i)\n        .Split(user_op::OpArg(\"v_diff\", 0), i)\n        .Split(user_op::OpArg(\"w1_diff\", 0), i)\n        .Split(user_op::OpArg(\"h1_diff\", 0), i)\n        .Split(user_op::OpArg(\"w2_diff\", 0), i)\n        .Split(user_op::OpArg(\"h2_diff\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_get_ciou_result_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nMaybe<void> FusedGetCiouResultOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& v = ctx->InputTensorDesc(\"v\", 0);\n\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  y->set_is_dynamic(v.is_dynamic());\n  y->set_shape(v.shape());\n\n  user_op::TensorDesc* ahpha = ctx->MutOutputTensorDesc(\"alpha\", 0);\n  ahpha->set_is_dynamic(v.is_dynamic());\n  ahpha->set_shape(v.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetCiouResultOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return FusedGetCiouResultOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedGetCiouResultOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& v = ctx->InputTensorDesc(\"v\", 0);\n\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  y->set_data_type(v.data_type());\n\n  user_op::TensorDesc* alpha = ctx->MutOutputTensorDesc(\"alpha\", 0);\n  alpha->set_data_type(v.data_type());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetCiouResultOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& v = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"v\", 0);\n  FOR_RANGE(int64_t, i, 0, v.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"v\", 0), i)\n        .Split(user_op::OpArg(\"iou\", 0), i)\n        .Split(user_op::OpArg(\"rho2\", 0), i)\n        .Split(user_op::OpArg(\"c2\", 0), i)\n        .Split(user_op::OpArg(\"y\", 0), i)\n        .Split(user_op::OpArg(\"alpha\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetCiouResultGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n\n  user_op::TensorDesc* dv = ctx->MutOutputTensorDesc(\"dv\", 0);\n  dv->set_is_dynamic(dy.is_dynamic());\n  dv->set_shape(dy.shape());\n\n  user_op::TensorDesc* diou = ctx->MutOutputTensorDesc(\"diou\", 0);\n  diou->set_is_dynamic(dy.is_dynamic());\n  diou->set_shape(dy.shape());\n\n  user_op::TensorDesc* drho2 = ctx->MutOutputTensorDesc(\"drho2\", 0);\n  drho2->set_is_dynamic(dy.is_dynamic());\n  drho2->set_shape(dy.shape());\n\n  user_op::TensorDesc* dc2 = ctx->MutOutputTensorDesc(\"dc2\", 0);\n  dc2->set_is_dynamic(dy.is_dynamic());\n  dc2->set_shape(dy.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetCiouResultGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return FusedGetCiouResultGradOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedGetCiouResultGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n\n  user_op::TensorDesc* dv = ctx->MutOutputTensorDesc(\"dv\", 0);\n  dv->set_data_type(dy.data_type());\n\n  
user_op::TensorDesc* diou = ctx->MutOutputTensorDesc(\"diou\", 0);\n  diou->set_data_type(dy.data_type());\n\n  user_op::TensorDesc* drho2 = ctx->MutOutputTensorDesc(\"drho2\", 0);\n  drho2->set_data_type(dy.data_type());\n\n  user_op::TensorDesc* dc2 = ctx->MutOutputTensorDesc(\"dc2\", 0);\n  dc2->set_data_type(dy.data_type());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetCiouResultGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"dy\", 0);\n  FOR_RANGE(int64_t, i, 0, dy.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"alpha\", 0), i)\n        .Split(user_op::OpArg(\"rho2\", 0), i)\n        .Split(user_op::OpArg(\"c2\", 0), i)\n        .Split(user_op::OpArg(\"dv\", 0), i)\n        .Split(user_op::OpArg(\"diou\", 0), i)\n        .Split(user_op::OpArg(\"drho2\", 0), i)\n        .Split(user_op::OpArg(\"dc2\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_get_convex_diagonal_squared_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nMaybe<void> FusedGetConvexDiagonalSquaredOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc(\"b1_x1\", 0);\n\n  user_op::TensorDesc* c2 = ctx->MutOutputTensorDesc(\"c2\", 0);\n  c2->set_is_dynamic(b1_x1.is_dynamic());\n  c2->set_shape(b1_x1.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetConvexDiagonalSquaredOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return FusedGetConvexDiagonalSquaredOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedGetConvexDiagonalSquaredOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc(\"b1_x1\", 0);\n\n  user_op::TensorDesc* c2 = ctx->MutOutputTensorDesc(\"c2\", 0);\n  c2->set_data_type(b1_x1.data_type());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetConvexDiagonalSquaredOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"b1_x1\", 0);\n  FOR_RANGE(int64_t, i, 0, b1_x1.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"b1_x1\", 0), i)\n        .Split(user_op::OpArg(\"b1_x2\", 0), i)\n        .Split(user_op::OpArg(\"b2_x1\", 0), i)\n        .Split(user_op::OpArg(\"b2_x2\", 0), i)\n        .Split(user_op::OpArg(\"b1_y1\", 0), i)\n        .Split(user_op::OpArg(\"b1_y2\", 0), i)\n        .Split(user_op::OpArg(\"b2_y1\", 0), i)\n        .Split(user_op::OpArg(\"b2_y2\", 0), i)\n        .Split(user_op::OpArg(\"c2\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetConvexDiagonalSquaredGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc(\"b1_x1\", 0);\n\n  user_op::TensorDesc* b1_x1_diff = ctx->MutOutputTensorDesc(\"b1_x1_diff\", 0);\n  b1_x1_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b1_x1_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b1_x2_diff = ctx->MutOutputTensorDesc(\"b1_x2_diff\", 0);\n  b1_x2_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b1_x2_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b2_x1_diff = ctx->MutOutputTensorDesc(\"b2_x1_diff\", 0);\n  b2_x1_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b2_x1_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b2_x2_diff = ctx->MutOutputTensorDesc(\"b2_x2_diff\", 0);\n  b2_x2_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b2_x2_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b1_y1_diff = ctx->MutOutputTensorDesc(\"b1_y1_diff\", 0);\n  b1_y1_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b1_y1_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b1_y2_diff = ctx->MutOutputTensorDesc(\"b1_y2_diff\", 0);\n  b1_y2_diff->set_is_dynamic(b1_x1.is_dynamic());\n  
b1_y2_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b2_y1_diff = ctx->MutOutputTensorDesc(\"b2_y1_diff\", 0);\n  b2_y1_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b2_y1_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b2_y2_diff = ctx->MutOutputTensorDesc(\"b2_y2_diff\", 0);\n  b2_y2_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b2_y2_diff->set_shape(b1_x1.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetConvexDiagonalSquaredGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return FusedGetConvexDiagonalSquaredGradOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedGetConvexDiagonalSquaredGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc(\"b1_x1\", 0);\n\n  user_op::TensorDesc* b1_x1_diff = ctx->MutOutputTensorDesc(\"b1_x1_diff\", 0);\n  b1_x1_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b1_x2_diff = ctx->MutOutputTensorDesc(\"b1_x2_diff\", 0);\n  b1_x2_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b2_x1_diff = ctx->MutOutputTensorDesc(\"b2_x1_diff\", 0);\n  b2_x1_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b2_x2_diff = ctx->MutOutputTensorDesc(\"b2_x2_diff\", 0);\n  b2_x2_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b1_y1_diff = ctx->MutOutputTensorDesc(\"b1_y1_diff\", 0);\n  b1_y1_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b1_y2_diff = ctx->MutOutputTensorDesc(\"b1_y2_diff\", 0);\n  b1_y2_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b2_y1_diff = ctx->MutOutputTensorDesc(\"b2_y1_diff\", 0);\n  b2_y1_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b2_y2_diff = ctx->MutOutputTensorDesc(\"b2_y2_diff\", 0);\n  b2_y2_diff->set_data_type(b1_x1.data_type());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetConvexDiagonalSquaredGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"b1_x1\", 0);\n  FOR_RANGE(int64_t, i, 0, b1_x1.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"c2_diff\", 0), i)\n        .Split(user_op::OpArg(\"b1_x1\", 0), i)\n        .Split(user_op::OpArg(\"b1_x2\", 0), i)\n        .Split(user_op::OpArg(\"b2_x1\", 0), i)\n        .Split(user_op::OpArg(\"b2_x2\", 0), i)\n        .Split(user_op::OpArg(\"b1_y1\", 0), i)\n        .Split(user_op::OpArg(\"b1_y2\", 0), i)\n        .Split(user_op::OpArg(\"b2_y1\", 0), i)\n        .Split(user_op::OpArg(\"b2_y2\", 0), i)\n        .Split(user_op::OpArg(\"b1_x1_diff\", 0), i)\n        .Split(user_op::OpArg(\"b1_x2_diff\", 0), i)\n        .Split(user_op::OpArg(\"b2_x1_diff\", 0), i)\n        .Split(user_op::OpArg(\"b2_x2_diff\", 0), i)\n        .Split(user_op::OpArg(\"b1_y1_diff\", 0), i)\n        .Split(user_op::OpArg(\"b1_y2_diff\", 0), i)\n        .Split(user_op::OpArg(\"b2_y1_diff\", 0), i)\n        .Split(user_op::OpArg(\"b2_y2_diff\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_get_intersection_area_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nMaybe<void> FusedGetIntersectionAreaOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc(\"b1_x1\", 0);\n  const user_op::TensorDesc& b1_x2 = ctx->InputTensorDesc(\"b1_x2\", 0);\n  const user_op::TensorDesc& b1_y1 = ctx->InputTensorDesc(\"b1_y1\", 0);\n  const user_op::TensorDesc& b1_y2 = ctx->InputTensorDesc(\"b1_y2\", 0);\n  const user_op::TensorDesc& b2_x1 = ctx->InputTensorDesc(\"b2_x1\", 0);\n  const user_op::TensorDesc& b2_x2 = ctx->InputTensorDesc(\"b2_x2\", 0);\n  const user_op::TensorDesc& b2_y1 = ctx->InputTensorDesc(\"b2_y1\", 0);\n  const user_op::TensorDesc& b2_y2 = ctx->InputTensorDesc(\"b2_y2\", 0);\n\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_x2.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_y1.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_y2.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_x1.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_x2.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_y1.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_y2.shape());\n\n  user_op::TensorDesc* inter = ctx->MutOutputTensorDesc(\"inter\", 0);\n  inter->set_is_dynamic(b1_x1.is_dynamic());\n  inter->set_shape(b1_x1.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetIntersectionAreaOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return FusedGetIntersectionAreaOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedGetIntersectionAreaOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc(\"b1_x1\", 0);\n  const user_op::TensorDesc& b1_x2 = ctx->InputTensorDesc(\"b1_x2\", 0);\n  const user_op::TensorDesc& b1_y1 = ctx->InputTensorDesc(\"b1_y1\", 0);\n  const user_op::TensorDesc& b1_y2 = ctx->InputTensorDesc(\"b1_y2\", 0);\n  const user_op::TensorDesc& b2_x1 = ctx->InputTensorDesc(\"b2_x1\", 0);\n  const user_op::TensorDesc& b2_x2 = ctx->InputTensorDesc(\"b2_x2\", 0);\n  const user_op::TensorDesc& b2_y1 = ctx->InputTensorDesc(\"b2_y1\", 0);\n  const user_op::TensorDesc& b2_y2 = ctx->InputTensorDesc(\"b2_y2\", 0);\n\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_x2.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_y1.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_y2.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_x1.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_x2.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_y1.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_y2.data_type());\n\n  user_op::TensorDesc* inter = ctx->MutOutputTensorDesc(\"inter\", 0);\n  inter->set_data_type(b1_x1.data_type());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetIntersectionAreaOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = 
ctx->LogicalTensorDesc4InputArgNameAndIndex(\"b1_x1\", 0);\n  FOR_RANGE(int64_t, i, 0, b1_x1.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"b1_x1\", 0), i)\n        .Split(user_op::OpArg(\"b1_x2\", 0), i)\n        .Split(user_op::OpArg(\"b1_y1\", 0), i)\n        .Split(user_op::OpArg(\"b1_y2\", 0), i)\n        .Split(user_op::OpArg(\"b2_x1\", 0), i)\n        .Split(user_op::OpArg(\"b2_x2\", 0), i)\n        .Split(user_op::OpArg(\"b2_y1\", 0), i)\n        .Split(user_op::OpArg(\"b2_y2\", 0), i)\n        .Split(user_op::OpArg(\"inter\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetIntersectionAreaGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc(\"b1_x1\", 0);\n  const user_op::TensorDesc& b1_x2 = ctx->InputTensorDesc(\"b1_x2\", 0);\n  const user_op::TensorDesc& b1_y1 = ctx->InputTensorDesc(\"b1_y1\", 0);\n  const user_op::TensorDesc& b1_y2 = ctx->InputTensorDesc(\"b1_y2\", 0);\n  const user_op::TensorDesc& b2_x1 = ctx->InputTensorDesc(\"b2_x1\", 0);\n  const user_op::TensorDesc& b2_x2 = ctx->InputTensorDesc(\"b2_x2\", 0);\n  const user_op::TensorDesc& b2_y1 = ctx->InputTensorDesc(\"b2_y1\", 0);\n  const user_op::TensorDesc& b2_y2 = ctx->InputTensorDesc(\"b2_y2\", 0);\n  const user_op::TensorDesc& inter_diff = ctx->InputTensorDesc(\"inter_diff\", 0);\n\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_x2.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_y1.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b1_y2.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_x1.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_x2.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_y1.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), b2_y2.shape());\n  CHECK_EQ_OR_RETURN(b1_x1.shape(), inter_diff.shape());\n\n  user_op::TensorDesc* b1_x1_diff = ctx->MutOutputTensorDesc(\"b1_x1_diff\", 0);\n  b1_x1_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b1_x1_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b1_x2_diff = ctx->MutOutputTensorDesc(\"b1_x2_diff\", 0);\n  b1_x2_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b1_x2_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b2_x1_diff = ctx->MutOutputTensorDesc(\"b2_x1_diff\", 0);\n  b2_x1_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b2_x1_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b2_x2_diff = ctx->MutOutputTensorDesc(\"b2_x2_diff\", 0);\n  b2_x2_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b2_x2_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b1_y1_diff = ctx->MutOutputTensorDesc(\"b1_y1_diff\", 0);\n  b1_y1_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b1_y1_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b1_y2_diff = ctx->MutOutputTensorDesc(\"b1_y2_diff\", 0);\n  b1_y2_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b1_y2_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b2_y1_diff = ctx->MutOutputTensorDesc(\"b2_y1_diff\", 0);\n  b2_y1_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b2_y1_diff->set_shape(b1_x1.shape());\n\n  user_op::TensorDesc* b2_y2_diff = ctx->MutOutputTensorDesc(\"b2_y2_diff\", 0);\n  b2_y2_diff->set_is_dynamic(b1_x1.is_dynamic());\n  b2_y2_diff->set_shape(b1_x1.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetIntersectionAreaGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return FusedGetIntersectionAreaGradOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedGetIntersectionAreaGradOp::InferDataType(user_op::InferContext* ctx) {\n  
const user_op::TensorDesc& b1_x1 = ctx->InputTensorDesc(\"b1_x1\", 0);\n  const user_op::TensorDesc& b1_x2 = ctx->InputTensorDesc(\"b1_x2\", 0);\n  const user_op::TensorDesc& b1_y1 = ctx->InputTensorDesc(\"b1_y1\", 0);\n  const user_op::TensorDesc& b1_y2 = ctx->InputTensorDesc(\"b1_y2\", 0);\n  const user_op::TensorDesc& b2_x1 = ctx->InputTensorDesc(\"b2_x1\", 0);\n  const user_op::TensorDesc& b2_x2 = ctx->InputTensorDesc(\"b2_x2\", 0);\n  const user_op::TensorDesc& b2_y1 = ctx->InputTensorDesc(\"b2_y1\", 0);\n  const user_op::TensorDesc& b2_y2 = ctx->InputTensorDesc(\"b2_y2\", 0);\n  const user_op::TensorDesc& inter_diff = ctx->InputTensorDesc(\"inter_diff\", 0);\n\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_x2.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_y1.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b1_y2.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_x1.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_x2.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_y1.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), b2_y2.data_type());\n  CHECK_EQ_OR_RETURN(b1_x1.data_type(), inter_diff.data_type());\n\n  user_op::TensorDesc* b1_x1_diff = ctx->MutOutputTensorDesc(\"b1_x1_diff\", 0);\n  b1_x1_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b1_x2_diff = ctx->MutOutputTensorDesc(\"b1_x2_diff\", 0);\n  b1_x2_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b2_x1_diff = ctx->MutOutputTensorDesc(\"b2_x1_diff\", 0);\n  b2_x1_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b2_x2_diff = ctx->MutOutputTensorDesc(\"b2_x2_diff\", 0);\n  b2_x2_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b1_y1_diff = ctx->MutOutputTensorDesc(\"b1_y1_diff\", 0);\n  b1_y1_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b1_y2_diff = ctx->MutOutputTensorDesc(\"b1_y2_diff\", 0);\n  b1_y2_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b2_y1_diff = ctx->MutOutputTensorDesc(\"b2_y1_diff\", 0);\n  b2_y1_diff->set_data_type(b1_x1.data_type());\n\n  user_op::TensorDesc* b2_y2_diff = ctx->MutOutputTensorDesc(\"b2_y2_diff\", 0);\n  b2_y2_diff->set_data_type(b1_x1.data_type());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetIntersectionAreaGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& b1_x1 = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"b1_x1\", 0);\n  FOR_RANGE(int64_t, i, 0, b1_x1.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"b1_x1\", 0), i)\n        .Split(user_op::OpArg(\"b1_x2\", 0), i)\n        .Split(user_op::OpArg(\"b1_y1\", 0), i)\n        .Split(user_op::OpArg(\"b1_y2\", 0), i)\n        .Split(user_op::OpArg(\"b2_x1\", 0), i)\n        .Split(user_op::OpArg(\"b2_x2\", 0), i)\n        .Split(user_op::OpArg(\"b2_y1\", 0), i)\n        .Split(user_op::OpArg(\"b2_y2\", 0), i)\n        .Split(user_op::OpArg(\"inter_diff\", 0), i)\n        .Split(user_op::OpArg(\"b1_x1_diff\", 0), i)\n        .Split(user_op::OpArg(\"b1_x2_diff\", 0), i)\n        .Split(user_op::OpArg(\"b1_y1_diff\", 0), i)\n        .Split(user_op::OpArg(\"b1_y2_diff\", 0), i)\n        .Split(user_op::OpArg(\"b2_x1_diff\", 0), i)\n        .Split(user_op::OpArg(\"b2_x2_diff\", 0), i)\n        .Split(user_op::OpArg(\"b2_y1_diff\", 0), i)\n        .Split(user_op::OpArg(\"b2_y2_diff\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_get_iou_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nMaybe<void> FusedGetIouOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& w1 = ctx->InputTensorDesc(\"w1\", 0);\n\n  user_op::TensorDesc* iou = ctx->MutOutputTensorDesc(\"iou\", 0);\n  iou->set_is_dynamic(w1.is_dynamic());\n  iou->set_shape(w1.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetIouOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return FusedGetIouOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedGetIouOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& w1 = ctx->InputTensorDesc(\"w1\", 0);\n\n  user_op::TensorDesc* iou = ctx->MutOutputTensorDesc(\"iou\", 0);\n  iou->set_data_type(w1.data_type());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetIouOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& w1 = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"w1\", 0);\n  FOR_RANGE(int64_t, i, 0, w1.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"w1\", 0), i)\n        .Split(user_op::OpArg(\"h1\", 0), i)\n        .Split(user_op::OpArg(\"w2\", 0), i)\n        .Split(user_op::OpArg(\"h2\", 0), i)\n        .Split(user_op::OpArg(\"inter\", 0), i)\n        .Split(user_op::OpArg(\"iou\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetIouGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& diou = ctx->InputTensorDesc(\"diou\", 0);\n\n  user_op::TensorDesc* dw1 = ctx->MutOutputTensorDesc(\"dw1\", 0);\n  dw1->set_is_dynamic(diou.is_dynamic());\n  dw1->set_shape(diou.shape());\n\n  user_op::TensorDesc* dh1 = ctx->MutOutputTensorDesc(\"dh1\", 0);\n  dh1->set_is_dynamic(diou.is_dynamic());\n  dh1->set_shape(diou.shape());\n\n  user_op::TensorDesc* dinter = ctx->MutOutputTensorDesc(\"dinter\", 0);\n  dinter->set_is_dynamic(diou.is_dynamic());\n  dinter->set_shape(diou.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetIouGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return FusedGetIouGradOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> FusedGetIouGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& diou = ctx->InputTensorDesc(\"diou\", 0);\n\n  user_op::TensorDesc* dw1 = ctx->MutOutputTensorDesc(\"dw1\", 0);\n  dw1->set_data_type(diou.data_type());\n\n  user_op::TensorDesc* dh1 = ctx->MutOutputTensorDesc(\"dh1\", 0);\n  dh1->set_data_type(diou.data_type());\n\n  user_op::TensorDesc* dinter = ctx->MutOutputTensorDesc(\"dinter\", 0);\n  dinter->set_data_type(diou.data_type());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FusedGetIouGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"dy\", 0);\n  
FOR_RANGE(int64_t, i, 0, diou.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"diou\", 0), i)\n        .Split(user_op::OpArg(\"w1\", 0), i)\n        .Split(user_op::OpArg(\"h1\", 0), i)\n        .Split(user_op::OpArg(\"w2\", 0), i)\n        .Split(user_op::OpArg(\"h2\", 0), i)\n        .Split(user_op::OpArg(\"inter\", 0), i)\n        .Split(user_op::OpArg(\"dw1\", 0), i)\n        .Split(user_op::OpArg(\"dh1\", 0), i)\n        .Split(user_op::OpArg(\"dinter\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_glu_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ auto FusedGluOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  // check whether the user provide weight tensor v\n  bool is_split_mode = false;\n  if (ctx->user_op_conf().has_input(\"v\", 0)) { is_split_mode = true; }\n\n  bool has_b = ctx->user_op_conf().has_input(\"b\", 0);\n  bool has_c = ctx->user_op_conf().has_input(\"c\", 0);\n\n  // check whether the user provide bais tensors\n  CHECK_OR_RETURN(!(has_b && (is_split_mode && !has_c)))\n      << \"expected existance of c, when provide tensors w, v and b\";\n  bool has_bias = false;\n  if (has_b && (is_split_mode && has_c)) {\n    has_bias = true;\n  } else if (has_b && (!is_split_mode)) {\n    has_bias = true;\n  } else {\n    has_bias = false;\n  }\n\n  // data parallelism\n  for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape().NumAxes() - 1;\n       ++i) {\n    if (is_split_mode && has_bias) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), i)\n          .Broadcast(user_op::OpArg(\"w\", 0))\n          .Broadcast(user_op::OpArg(\"b\", 0))\n          .Broadcast(user_op::OpArg(\"v\", 0))\n          .Broadcast(user_op::OpArg(\"c\", 0))\n          .Split(ctx->outputs(), i)\n          .Build();\n    } else if (is_split_mode && !has_bias) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), i)\n          .Broadcast(user_op::OpArg(\"w\", 0))\n          .Broadcast(user_op::OpArg(\"v\", 0))\n          .Split(ctx->outputs(), i)\n          .Build();\n    } else if (!is_split_mode && has_bias) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), i)\n          .Broadcast(user_op::OpArg(\"w\", 0))\n          .Broadcast(user_op::OpArg(\"b\", 0))\n          .Split(ctx->outputs(), i)\n          .Build();\n    } else if (!is_split_mode && !has_bias) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), i)\n          .Broadcast(user_op::OpArg(\"w\", 0))\n          .Split(ctx->outputs(), i)\n          .Build();\n    }\n  }\n\n  // model parallelism\n  if (is_split_mode && has_bias) {\n    ctx->NewBuilder()\n        .Broadcast(user_op::OpArg(\"x\", 0))\n        .Split(user_op::OpArg(\"w\", 0), 0)\n        .Split(user_op::OpArg(\"b\", 0), 0)\n        .Split(user_op::OpArg(\"v\", 0), 0)\n        .Split(user_op::OpArg(\"c\", 0), 0)\n        .Split(ctx->outputs(),\n               ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape().NumAxes() - 1)\n        .Build();\n  } else if (is_split_mode && !has_bias) {\n    ctx->NewBuilder()\n        .Broadcast(user_op::OpArg(\"x\", 0))\n        .Split(user_op::OpArg(\"w\", 0), 0)\n        .Split(user_op::OpArg(\"v\", 0), 0)\n        .Split(ctx->outputs(),\n               ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape().NumAxes() - 1)\n    
.Build();\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ auto FusedGluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  // obtain input shape\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& w_shape = ctx->InputShape(\"w\", 0);\n\n  // check whether the user provides weight tensor v\n  bool is_split_mode = false;\n  if (ctx->has_input(\"v\", 0)) { is_split_mode = true; }\n\n  bool has_b = ctx->has_input(\"b\", 0);\n  bool has_c = ctx->has_input(\"c\", 0);\n\n  // check whether the user provides bias tensors\n  CHECK_OR_RETURN(!(has_b && (is_split_mode && !has_c)))\n      << \"expected existence of c when tensors w, v and b are provided\";\n  bool has_bias = false;\n  if (has_b && (is_split_mode && has_c)) {\n    has_bias = true;\n  } else if (has_b && (!is_split_mode)) {\n    has_bias = true;\n  } else {\n    has_bias = false;\n  }\n\n  // check dimensions of x, w and b\n  CHECK_GT_OR_RETURN(x_shape.NumAxes(), 1)\n      << \"number of axes of \\'x\\' should be greater than 1, but got \" << x_shape.NumAxes();\n  CHECK_EQ_OR_RETURN(w_shape.NumAxes(), 2)\n      << \"number of axes of \\'w\\' should be equal to 2, but got \" << w_shape.NumAxes();\n  if (has_bias) {\n    const Shape& b_shape = ctx->InputShape(\"b\", 0);\n    CHECK_EQ_OR_RETURN(b_shape.NumAxes(), 1)\n        << \"number of axes of \\'b\\' should be equal to 1, but got \" << b_shape.NumAxes();\n  }\n\n  // check input shapes of w and b\n  size_t x_num_axes = x_shape.NumAxes();\n  CHECK_EQ_OR_RETURN(w_shape.At(1), x_shape.At(x_num_axes - 1))\n      << \"dimension 1 of \\'w\\'(\" << w_shape.At(1)\n      << \") is not consistent with the last dimension of \\'x\\'(\" << x_shape.At(x_num_axes - 1)\n      << \")\";\n  if (has_bias) {\n    const Shape& b_shape = ctx->InputShape(\"b\", 0);\n    CHECK_EQ_OR_RETURN(b_shape.At(0), w_shape.At(0))\n        << \"dimension 0 of \\'b\\'(\" << b_shape.At(0)\n        << \") is not consistent with dimension 0 of \\'w\\'(\" << w_shape.At(0) << \")\";\n  }\n  if (!is_split_mode) {\n    // in non-split mode w packs both halves, so its output dimension must be even\n    CHECK_EQ_OR_RETURN(w_shape.At(0) % 2, 0) << \"dimension 0 of \\'w\\' is not divisible by 2\";\n  }\n\n  // check both dimensions and input shapes of v and c (optional)\n  if (is_split_mode) {\n    const Shape& v_shape = ctx->InputShape(\"v\", 0);\n\n    CHECK_EQ_OR_RETURN(v_shape.NumAxes(), 2)\n        << \"number of axes of \\'v\\' should be equal to 2, but got \" << v_shape.NumAxes();\n    CHECK_OR_RETURN(v_shape == w_shape) << \"the shape of \\'v\\' is not consistent with \\'w\\'\";\n\n    if (has_bias) {\n      const Shape& b_shape = ctx->InputShape(\"b\", 0);\n      const Shape& c_shape = ctx->InputShape(\"c\", 0);\n      CHECK_EQ_OR_RETURN(c_shape.NumAxes(), 1)\n          << \"number of axes of \\'c\\' should be equal to 1, but got \" << c_shape.NumAxes();\n      CHECK_OR_RETURN(c_shape == b_shape) << \"the shape of \\'c\\' is not consistent with \\'b\\'\";\n    }\n  }\n\n  // set shape of the output tensor y\n  Shape y_shape = x_shape;  // borrow from input shape\n  size_t y_num_axes = x_num_axes;\n  if (is_split_mode) {\n    y_shape.Set(y_num_axes - 1, w_shape.At(0));\n  } else {\n    y_shape.Set(y_num_axes - 1, w_shape.At(0) / 2);\n  }\n  user_op::TensorDesc* y_tensor = ctx->MutOutputTensorDesc(\"y\", 0);\n  y_tensor->set_shape(y_shape);\n\n  // set shape of the output tensors of both matmul_wx and matmul_vx\n  Shape matmul_wx_shape = x_shape;  // borrow from input shape\n  matmul_wx_shape.Set(x_num_axes - 1, w_shape.At(0));\n  
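// matmul_wx keeps the leading axes of x; its last dimension is w's row count, w_shape.At(0).\n  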
user_op::TensorDesc* matmul_wx_tensor = ctx->MutOutputTensorDesc(\"matmul_wx\", 0);\n  matmul_wx_tensor->set_shape(matmul_wx_shape);\n  if (is_split_mode) {\n    user_op::TensorDesc* matmul_vx_tensor = ctx->MutOutputTensorDesc(\"matmul_vx\", 0);\n    matmul_vx_tensor->set_shape(y_shape);\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ auto FusedGluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ auto FusedGluOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  // obtain input data types\n  DataType x_dtype = ctx->InputDType(\"x\", 0);\n\n  // check whether the user provides weight tensor v\n  bool is_split_mode = false;\n  if (ctx->has_input(\"v\", 0)) { is_split_mode = true; }\n\n  bool has_b = ctx->has_input(\"b\", 0);\n  bool has_c = ctx->has_input(\"c\", 0);\n\n  // check whether the user provides bias tensors\n  CHECK_OR_RETURN(!(has_b && (is_split_mode && !has_c)))\n      << \"expected existence of c when tensors w, v and b are provided\";\n  bool has_bias = false;\n  if (has_b && (is_split_mode && has_c)) {\n    has_bias = true;\n  } else if (has_b && (!is_split_mode)) {\n    has_bias = true;\n  } else {\n    has_bias = false;\n  }\n\n  // check types of x, w and b\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"w\", 0), x_dtype)\n      << \"data type of \\'w\\' is not consistent with \\'x\\'\";\n  if (has_bias) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"b\", 0), x_dtype)\n        << \"data type of \\'b\\' is not consistent with \\'x\\'\";\n  }\n\n  // check types of v and c (optional)\n  if (is_split_mode) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"v\", 0), x_dtype)\n        << \"data type of \\'v\\' is not consistent with \\'x\\'\";\n    if (has_bias) {\n      CHECK_EQ_OR_RETURN(ctx->InputDType(\"c\", 0), x_dtype)\n          << \"data type of \\'c\\' is not consistent with \\'x\\'\";\n    }\n  }\n\n  // set output data type\n  ctx->SetOutputDType(\"y\", 0, x_dtype);\n  ctx->SetOutputDType(\"matmul_wx\", 0, x_dtype);\n  if (is_split_mode) { ctx->SetOutputDType(\"matmul_vx\", 0, x_dtype); }\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_glu_without_linear_grad_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ auto FusedGluWithoutLinearGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  // check existance of optional args\n  bool is_split_mode = false;\n  if (ctx->user_op_conf().has_input(\"matmul_vx\", 0)) { is_split_mode = true; }\n\n  for (int64_t i = 0;\n       i < ctx->LogicalTensorDesc4InputArgNameAndIndex(\"dy\", 0).shape().NumAxes() - 1; ++i) {\n    if (is_split_mode) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"dy\", 0), i)\n          .Split(user_op::OpArg(\"matmul_wx\", 0), i)\n          .Split(user_op::OpArg(\"matmul_vx\", 0), i)\n          .Split(ctx->outputs(), i)\n          .Build();\n    } else {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"dy\", 0), i)\n          .Split(user_op::OpArg(\"matmul_wx\", 0), i)\n          .Split(ctx->outputs(), i)\n          .Build();\n    }\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ auto FusedGluWithoutLinearGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  // obtain input shape\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  const Shape& matmul_wx_shape = ctx->InputShape(\"matmul_wx\", 0);\n\n  // check existance of optional args\n  bool is_split_mode = false;\n  if (ctx->has_input(\"matmul_vx\", 0)) { is_split_mode = true; }\n\n  // obtain dimensions of dy and matmul_wx\n  size_t dy_num_axes = dy_shape.NumAxes();\n  size_t matmul_wx_num_axes = matmul_wx_shape.NumAxes();\n\n  // check dimensions of dy and matmul_wx\n  CHECK_GT_OR_RETURN(dy_num_axes, 1)\n      << \"number of axes of \\'dy\\' should have be greater than 1, yet get \" << dy_num_axes;\n  CHECK_GT_OR_RETURN(matmul_wx_num_axes, 1)\n      << \"number of axes of \\'matmul_wx\\' should have be greater than 1, yet get \"\n      << matmul_wx_num_axes;\n  CHECK_EQ_OR_RETURN(dy_num_axes, matmul_wx_num_axes)\n      << \"number of axes of \\'dy\\'(\" << dy_num_axes\n      << \") is not consistant with the one of \\'matmul_wx\\'(\" << matmul_wx_num_axes << \")\";\n\n  // check input shapes of dy and matmul_wx\n  for (uint64_t i = 0; i < dy_num_axes - 1; i++) {\n    size_t dy_size = dy_shape.At(i);\n    size_t matmul_wx_size = matmul_wx_shape.At(i);\n    CHECK_EQ_OR_RETURN(dy_size, matmul_wx_size)\n        << \"dimension \" << i << \"of \\'dy\\'(\" << dy_size << \") and \\'matmul_wx\\'(\" << matmul_wx_size\n        << \") is not consistent\";\n  }\n  if (is_split_mode) {\n    CHECK_EQ_OR_RETURN(dy_shape.At(dy_num_axes - 1), matmul_wx_shape.At(matmul_wx_num_axes - 1))\n        << \"the last dimension of \\'dy\\'(\" << dy_shape.At(dy_num_axes - 1)\n        << \") is not consistant with the last dimension of \\'matmul_wx\\'(\"\n        << matmul_wx_shape.At(matmul_wx_num_axes - 1) << \")\";\n  } else {\n    CHECK_EQ_OR_RETURN(2 * dy_shape.At(dy_num_axes - 1), 
matmul_wx_shape.At(matmul_wx_num_axes - 1))\n        << \"two times of the last dimension of \\'dy\\'(\" << 2 * dy_shape.At(dy_num_axes - 1)\n        << \") is not consistant with the last dimension of \\'matmul_wx\\'(\"\n        << matmul_wx_shape.At(matmul_wx_num_axes - 1) << \")\";\n  }\n\n  // check both dimensions and input shapes of matmul_vx (optional)\n  if (is_split_mode) {\n    // obtain input shape\n    const Shape& matmul_vx_shape = ctx->InputShape(\"matmul_vx\", 0);\n\n    // check dimensions of matmul_vx\n    size_t matmul_vx_num_axes = matmul_vx_shape.NumAxes();\n    CHECK_GT_OR_RETURN(matmul_vx_num_axes, 1)\n        << \"number of axes of \\'matmul_vx\\' should have be greater than 1, yet get \"\n        << matmul_vx_num_axes;\n    CHECK_EQ_OR_RETURN(matmul_vx_num_axes, dy_num_axes)\n        << \"number of axes of \\'dy\\'(\" << dy_num_axes\n        << \") is not consistant with the one of \\'matmul_vx\\'(\" << matmul_vx_num_axes << \")\";\n\n    // check input shapes of dy and matmul_vx\n    for (uint64_t i = 0; i < dy_num_axes - 1; i++) {\n      size_t dy_size = dy_shape.At(i);\n      size_t matmul_vx_size = matmul_vx_shape.At(i);\n      CHECK_EQ_OR_RETURN(dy_size, matmul_vx_size)\n          << \"dimension \" << i << \"of \\'dy\\'(\" << dy_size << \") and \\'matmul_vx\\'(\"\n          << matmul_vx_size << \") is not consistent\";\n    }\n    CHECK_EQ_OR_RETURN(matmul_vx_shape.At(matmul_vx_num_axes - 1), dy_shape.At(dy_num_axes - 1))\n        << \"the last dimension of \\'dy\\'(\" << dy_shape.At(dy_num_axes - 1)\n        << \") is not consistant with the last dimension of \\'matmul_vx\\'(\"\n        << matmul_vx_shape.At(matmul_vx_num_axes - 1) << \")\";\n  }\n\n  // set shape of the output tensor d_matmul_wx\n  Shape d_matmul_wx_shape = matmul_wx_shape;  // borrow from input shape\n  user_op::TensorDesc* d_matmul_wx_tensor = ctx->MutOutputTensorDesc(\"d_matmul_wx\", 0);\n  d_matmul_wx_tensor->set_shape(d_matmul_wx_shape);\n\n  // set shape of the output tensor d_matmul_vx (optional)\n  if (is_split_mode) {\n    const Shape& matmul_vx_shape = ctx->InputShape(\"matmul_vx\", 0);\n    Shape d_matmul_vx_shape = matmul_vx_shape;  // borrow from input shape\n    user_op::TensorDesc* d_matmul_vx_tensor = ctx->MutOutputTensorDesc(\"d_matmul_vx\", 0);\n    d_matmul_vx_tensor->set_shape(d_matmul_vx_shape);\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ auto FusedGluWithoutLinearGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ auto FusedGluWithoutLinearGradOp::InferDataType(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  // obtain input data types\n  DataType dy_dtype = ctx->InputDType(\"dy\", 0);\n\n  // check types of matmul_wx\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"matmul_wx\", 0), dy_dtype)\n      << \"data type of \\'matmul_wx\\' is not consitant with \\'dy\\'\";\n\n  bool is_split_mode = ctx->has_input(\"matmul_vx\", 0);\n\n  // check types of matmul_vx (optional)\n  if (is_split_mode) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"matmul_vx\", 0), dy_dtype)\n        << \"data type of \\'matmul_vx\\' is not consitant with \\'dy\\'\";\n  }\n\n  // set output data type\n  ctx->SetOutputDType(\"d_matmul_wx\", 0, dy_dtype);\n  if (is_split_mode) { ctx->SetOutputDType(\"d_matmul_vx\", 0, dy_dtype); }\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_gru_cell_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> FusedGruCellOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& hx_shape = ctx->InputShape(\"hx\", 0);\n  ctx->SetOutputShape(\"hy\", 0, hx_shape);\n  ctx->SetOutputShape(\"workspace\", 0, Shape({hx_shape.At(0), hx_shape.At(1) * 5}));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FusedGruCellOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedGruCellOp::GetSbp(user_op::SbpContext* ctx) {\n  // input_gates shape:  [batch_size, hidden_size * 3]\n  // hidden_gates shape: [batch_size, hidden_size * 3]\n  // hx shape:           [batch_size, hidden_size]\n  // input_bias shape:   [hidden_size * 3]\n  // hidden_bias shape:  [hidden_size * 3]\n\n  // hy shape:           [batch_size, hidden_size]\n  // workspace shape:    [batch_size, hidden_size * 5]\n\n  std::vector<user_op::OpArg> broadcast_args;\n  if (ctx->user_op_conf().has_input(\"input_bias\", 0)) {\n    broadcast_args.emplace_back(\"input_bias\", 0);\n  }\n  if (ctx->user_op_conf().has_input(\"hidden_bias\", 0)) {\n    broadcast_args.emplace_back(\"hidden_bias\", 0);\n  }\n\n  std::vector<user_op::OpArg> split_args;\n  split_args.emplace_back(\"input_gates\", 0);\n  split_args.emplace_back(\"hidden_gates\", 0);\n  split_args.emplace_back(\"hx\", 0);\n  split_args.emplace_back(\"hy\", 0);\n  split_args.emplace_back(\"workspace\", 0);\n\n  ctx->NewBuilder().Split(split_args, 0).Broadcast(broadcast_args).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedGruCellOp::InferDataType(user_op::InferContext* ctx) {\n  DataType in_types = ctx->InputDType(\"hx\", 0);\n  ctx->SetOutputDType(\"hy\", 0, in_types);\n  ctx->SetOutputDType(\"workspace\", 0, in_types);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedGruCellGradOp ::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& grad_hy_shape = ctx->InputShape(\"grad_hy\", 0);\n  DimVector dim_vec({grad_hy_shape.At(0), grad_hy_shape.At(1) * 3});\n  ctx->SetOutputShape(\"grad_input_gates\", 0, Shape(dim_vec));\n  ctx->SetOutputShape(\"grad_hidden_gates\", 0, Shape(dim_vec));\n\n  if (ctx->has_output(\"grad_hx\", 0)) { ctx->SetOutputShape(\"grad_hx\", 0, grad_hy_shape); }\n\n  if (ctx->has_output(\"grad_input_bias\", 0) && ctx->has_output(\"grad_hidden_bias\", 0)) {\n    ctx->SetOutputShape(\"grad_input_bias\", 0, Shape({grad_hy_shape.At(1) * 3}));\n    ctx->SetOutputShape(\"grad_hidden_bias\", 0, Shape({grad_hy_shape.At(1) * 3}));\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FusedGruCellGradOp ::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedGruCellGradOp 
::GetSbp(user_op::SbpContext* ctx) {\n  // grad_hy shape:       [batch_size, hidden_size]\n  // workspace shape:     [batch_size, hidden_size * 5]\n\n  // grad_input_gates shape:     [batch_size, hidden_size * 3]\n  // grad_hidden_gates shape:    [batch_size, hidden_size * 3]\n  // grad_hx shape:              [batch_size, hidden_size]\n  // grad_input_bias shape:      [hidden_size * 3]\n  // grad_hidden_bias shape:     [hidden_size * 3]\n\n  std::vector<user_op::OpArg> partial_sum_args;\n  if (ctx->user_op_conf().has_output(\"grad_input_bias\", 0)) {\n    partial_sum_args.emplace_back(\"grad_input_bias\", 0);\n  }\n  if (ctx->user_op_conf().has_output(\"grad_hidden_bias\", 0)) {\n    partial_sum_args.emplace_back(\"grad_hidden_bias\", 0);\n  }\n\n  std::vector<user_op::OpArg> split_args;\n  split_args.emplace_back(\"grad_hy\", 0);\n  split_args.emplace_back(\"workspace\", 0);\n  split_args.emplace_back(\"grad_input_gates\", 0);\n  split_args.emplace_back(\"grad_hidden_gates\", 0);\n\n  if (ctx->user_op_conf().has_output(\"grad_hx\", 0)) { split_args.emplace_back(\"grad_hx\", 0); }\n\n  ctx->NewBuilder().Split(split_args, 0).PartialSum(partial_sum_args).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedGruCellGradOp ::InferDataType(user_op::InferContext* ctx) {\n  DataType in_types = ctx->InputDType(\"grad_hy\", 0);\n  ctx->SetOutputDType(\"grad_input_gates\", 0, in_types);\n  ctx->SetOutputDType(\"grad_hidden_gates\", 0, in_types);\n  if (ctx->has_output(\"grad_hx\", 0)) { ctx->SetOutputDType(\"grad_hx\", 0, in_types); }\n  if (ctx->has_output(\"grad_input_bias\", 0)) {\n    ctx->SetOutputDType(\"grad_input_bias\", 0, in_types);\n  }\n  if (ctx->has_output(\"grad_hidden_bias\", 0)) {\n    ctx->SetOutputDType(\"grad_hidden_bias\", 0, in_types);\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_linear_with_groupwise_quantized_weight_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferTensorDesc4FusedMatmulBias(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  CHECK_GE_OR_RETURN(x_desc.shape().NumAxes(), 2);\n  const int64_t k = x_desc.shape().At(x_desc.shape().NumAxes() - 1);\n\n  const user_op::TensorDesc& w_desc = ctx->InputTensorDesc(\"w\", 0);\n  CHECK_EQ_OR_RETURN(w_desc.shape().NumAxes(), 2);\n  const int64_t n = w_desc.shape().At(0);\n  const int32_t num_bits = ctx->Attr<int32_t>(\"num_bits\");\n  if (num_bits == 8) {\n    CHECK_EQ_OR_RETURN(w_desc.shape().At(1), k);\n  } else if (num_bits == 4) {\n    CHECK_EQ_OR_RETURN(w_desc.shape().At(1) * 2, k);\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  const int64_t group_dim = ctx->Attr<int64_t>(\"group_dim\");\n  CHECK_OR_RETURN(group_dim == 0 || group_dim == 1);\n  const int64_t group_dim_size = group_dim == 0 ? 
n : k;\n  const int64_t group_size = ctx->Attr<int64_t>(\"group_size\");\n  CHECK_GT_OR_RETURN(group_size, 1);\n  CHECK_LE_OR_RETURN(group_size, group_dim_size);\n  CHECK_EQ_OR_RETURN(group_dim_size % group_size, 0);\n  const int64_t num_groups = group_dim_size / group_size;\n  const user_op::TensorDesc& w_scale_desc = ctx->InputTensorDesc(\"w_scale\", 0);\n  CHECK_EQ_OR_RETURN(w_scale_desc.shape().NumAxes(), 2);\n  if (group_dim == 0) {\n    CHECK_EQ_OR_RETURN(w_scale_desc.shape().At(0), num_groups);\n    CHECK_EQ_OR_RETURN(w_scale_desc.shape().At(1), k);\n  } else if (group_dim == 1) {\n    CHECK_EQ_OR_RETURN(w_scale_desc.shape().At(0), n);\n    CHECK_EQ_OR_RETURN(w_scale_desc.shape().At(1), num_groups);\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  Shape out_shape = x_desc.shape();\n  out_shape[x_desc.shape().NumAxes() - 1] = n;\n\n  if (ctx->has_input(\"b\", 0)) {\n    const user_op::TensorDesc& b_desc = ctx->InputTensorDesc(\"b\", 0);\n    CHECK_EQ_OR_RETURN(b_desc.shape().NumAxes(), 1);\n    CHECK_EQ_OR_RETURN(b_desc.shape().At(0), n);\n  }\n\n  if (ctx->has_input(\"w_zero\", 0)) {\n    const user_op::TensorDesc& w_zero_desc = ctx->InputTensorDesc(\"w_zero\", 0);\n    CHECK_OR_RETURN(w_zero_desc.shape() == w_scale_desc.shape());\n  }\n\n  ctx->SetOutputShape(\"out\", 0, out_shape);\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType4MatmulBias(user_op::InferContext* ctx) {\n  const DataType data_type = ctx->InputDType(\"x\", 0);\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"w_scale\", 0), data_type);\n  if (ctx->has_input(\"w_zero\", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType(\"w_zero\", 0), data_type); }\n  if (ctx->has_input(\"b\", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType(\"b\", 0), data_type); }\n  if (ctx->Attr<bool>(\"symmetric\")) {\n    CHECK_OR_RETURN(ctx->InputDType(\"w\", 0) == DataType::kUInt8\n                    || ctx->InputDType(\"w\", 0) == DataType::kInt8);\n  } else {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"w\", 0), DataType::kUInt8);\n  }\n  ctx->SetOutputDType(\"out\", 0, data_type);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> FusedLinearWithGroupwiseQuantizedWeightOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferTensorDesc4FusedMatmulBias(ctx);\n}\n\n/*static*/ Maybe<void> FusedLinearWithGroupwiseQuantizedWeightOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedLinearWithGroupwiseQuantizedWeightOp::GetSbp(\n    user_op::SbpContext* ctx) {\n  // (b, m, k) * (n, k)\n  const auto& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape();\n\n  const int64_t x_num_axes = x_shape.NumAxes();\n\n  const int64_t out_num_axes = x_num_axes;\n  const int32_t k_x_axis = x_num_axes - 1;\n\n  std::vector<user_op::OpArg> bias_args;\n  if (ctx->user_op_conf().has_input(\"b\", 0)) { bias_args.emplace_back(\"b\", 0); }\n\n  std::vector<user_op::OpArg> scale_args;\n  scale_args.emplace_back(\"w_scale\", 0);\n  if (ctx->user_op_conf().has_input(\"w_zero\", 0)) { scale_args.emplace_back(\"w_zero\", 0); }\n\n  for (int i = 0; i < x_shape.NumAxes() - 1; i++) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Broadcast(user_op::OpArg(\"w\", 0))\n        .Broadcast(scale_args)\n        .Broadcast(bias_args)\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n\n  const int64_t group_dim = ctx->user_op_conf().attr<int64_t>(\"group_dim\");\n  const int64_t group_size 
= ctx->user_op_conf().attr<int64_t>(\"group_size\");\n  CHECK_OR_RETURN(group_dim == 0 || group_dim == 1);\n  const auto& x_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  const auto& w_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"w\", 0);\n  CHECK_GE_OR_RETURN(x_desc.shape().NumAxes(), 2);\n  CHECK_EQ_OR_RETURN(w_desc.shape().NumAxes(), 2);\n  const int64_t k = x_desc.shape().At(x_desc.shape().NumAxes() - 1);\n  const int64_t n = w_desc.shape().At(0);\n  const int64_t group_dim_size = group_dim == 0 ? n : k;\n  CHECK_EQ_OR_RETURN(group_dim_size % group_size, 0);\n  const int64_t num_groups = group_dim_size / group_size;\n\n  // B x S(n_axis) -> S(n_axis)\n  if (group_dim == 1 || num_groups % ctx->parallel_num() == 0) {\n    ctx->NewBuilder()\n        .Broadcast(user_op::OpArg(\"x\", 0))\n        .Split(user_op::OpArg(\"w\", 0), 0)\n        .Split(scale_args, 0)\n        .Split(bias_args, 0)\n        .Split(user_op::OpArg(\"out\", 0), out_num_axes - 1)\n        .Build();\n  }\n\n  // S(x_k_axis) x S(w_k_axis) -> P\n  if (group_dim == 0 || num_groups % ctx->parallel_num() == 0) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), k_x_axis)\n        .Split(user_op::OpArg(\"w\", 0), 1)\n        .Split(scale_args, 1)\n        .PartialSum(bias_args)\n        .PartialSum(user_op::OpArg(\"out\", 0))\n        .Build();\n  }\n\n  // P x B -> P\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"x\", 0))\n      .Broadcast(user_op::OpArg(\"w\", 0))\n      .Broadcast(scale_args)\n      .PartialSum(bias_args)\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedLinearWithGroupwiseQuantizedWeightOp::InferDataType(\n    user_op::InferContext* ctx) {\n  return InferDataType4MatmulBias(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_lstm_cell_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> FusedLstmCellOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& cx_shape = ctx->InputShape(\"cx\", 0);\n  ctx->SetOutputShape(\"hy\", 0, cx_shape);\n  ctx->SetOutputShape(\"cy\", 0, cx_shape);\n  ctx->SetOutputShape(\"workspace\", 0, ctx->InputShape(\"input_gates\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FusedLstmCellOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedLstmCellOp::GetSbp(user_op::SbpContext* ctx) {\n  // input_gates shape:  [batch_size, hidden_size * 4]\n  // hidden_gates shape: [batch_size, hidden_size * 4]\n  // cx shape:           [batch_size, hidden_size]\n  // input_bias shape:   [hidden_size * 4]\n  // hidden_bias shape:  [hidden_size * 4]\n\n  // hy shape:           [batch_size, hidden_size]\n  // cy shape:           [batch_size, hidden_size]\n  // workspace shape:    [batch_size, hidden_size * 4]\n\n  std::vector<user_op::OpArg> broadcast_args;\n  if (ctx->user_op_conf().has_input(\"input_bias\", 0)) {\n    broadcast_args.emplace_back(\"input_bias\", 0);\n  }\n  if (ctx->user_op_conf().has_input(\"hidden_bias\", 0)) {\n    broadcast_args.emplace_back(\"hidden_bias\", 0);\n  }\n\n  std::vector<user_op::OpArg> split_args;\n  split_args.emplace_back(\"input_gates\", 0);\n  split_args.emplace_back(\"hidden_gates\", 0);\n  split_args.emplace_back(\"cx\", 0);\n  split_args.emplace_back(\"hy\", 0);\n  split_args.emplace_back(\"cy\", 0);\n  split_args.emplace_back(\"workspace\", 0);\n\n  ctx->NewBuilder().Split(split_args, 0).Broadcast(broadcast_args).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedLstmCellOp::InferDataType(user_op::InferContext* ctx) {\n  DataType in_types = ctx->InputDType(\"cx\", 0);\n  ctx->SetOutputDType(\"hy\", 0, in_types);\n  ctx->SetOutputDType(\"cy\", 0, in_types);\n  ctx->SetOutputDType(\"workspace\", 0, in_types);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedLstmCellGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"grad_gates\", 0, ctx->InputShape(\"workspace\", 0));\n\n  if (ctx->has_output(\"grad_cx\", 0)) {\n    ctx->SetOutputShape(\"grad_cx\", 0, ctx->InputShape(\"cx\", 0));\n  }\n\n  if (ctx->has_output(\"grad_bias\", 0)) {\n    ctx->SetOutputShape(\"grad_bias\", 0, Shape({ctx->InputShape(\"workspace\", 0).At(1)}));\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FusedLstmCellGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedLstmCellGradOp::GetSbp(user_op::SbpContext* ctx) {\n  // grad_hy shape:       [batch_size, hidden_size]\n  // grad_cy shape:       [batch_size, 
hidden_size]\n  // cx shape:            [batch_size, hidden_size]\n  // cy shape:            [batch_size, hidden_size]\n  // workspace shape:     [batch_size, hidden_size * 4]\n\n  // grad_gates shape:    [batch_size, hidden_size * 4]\n  // grad_cx shape:       [batch_size, hidden_size]\n  // grad_bias shape:     [hidden_size * 4]\n\n  std::vector<user_op::OpArg> partial_sum_args;\n  if (ctx->user_op_conf().has_output(\"grad_bias\", 0)) {\n    partial_sum_args.emplace_back(\"grad_bias\", 0);\n  }\n\n  std::vector<user_op::OpArg> split_args;\n  split_args.emplace_back(\"grad_hy\", 0);\n  split_args.emplace_back(\"grad_cy\", 0);\n  split_args.emplace_back(\"cx\", 0);\n  split_args.emplace_back(\"cy\", 0);\n  split_args.emplace_back(\"workspace\", 0);\n  split_args.emplace_back(\"grad_gates\", 0);\n\n  if (ctx->user_op_conf().has_output(\"grad_cx\", 0)) { split_args.emplace_back(\"grad_cx\", 0); }\n\n  ctx->NewBuilder().Split(split_args, 0).PartialSum(partial_sum_args).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedLstmCellGradOp::InferDataType(user_op::InferContext* ctx) {\n  DataType in_types = ctx->InputDType(\"grad_hy\", 0);\n  ctx->SetOutputDType(\"grad_gates\", 0, in_types);\n  if (ctx->has_output(\"grad_cx\", 0)) { ctx->SetOutputDType(\"grad_cx\", 0, in_types); }\n  if (ctx->has_output(\"grad_bias\", 0)) { ctx->SetOutputDType(\"grad_bias\", 0, in_types); }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_matmul_bias_add_relu_dropout_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nconstexpr int32_t kAuxReluLdAlignRequirement = 128;\n\nlong AlignReluAuxLd(long aux_ld) {\n  /*\n  ReLu bit-mask matrix leading dimension in elements.\n  Must be divisible by 128 and be no less than the number of rows in the output matrix.\n  */\n  long old_aux_ld = aux_ld;\n  return ((old_aux_ld + kAuxReluLdAlignRequirement - 1) / kAuxReluLdAlignRequirement)\n         * kAuxReluLdAlignRequirement;\n}\n\nMaybe<void> InferTensorDesc4FusedMatmul(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  int32_t weight_size = ctx->input_size(\"weights\");\n  int32_t bias_size = ctx->input_size(\"biases\");\n  CHECK_EQ_OR_RETURN(weight_size, bias_size) << \"Weight num should be equal to bias num. \";\n  /*\n  A: (m, k)\n  B: (n, k) need transpose\n  C: (m, n)\n  */\n  int64_t m = 0, n = 0, k = 0, cublas_aux_ld = 0;\n  m = x_desc.shape().At(0);\n  k = x_desc.shape().At(1);\n\n  for (int32_t idx = 0; idx < weight_size; idx++) {\n    // skip first input weight.\n    const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc(\"weights\", idx);\n    const user_op::TensorDesc& bias_desc = ctx->InputTensorDesc(\"biases\", idx);\n    CHECK_EQ_OR_RETURN(weight_desc.shape().NumAxes(), 2) << \"Weight's ndim should be equal to 2. \";\n    CHECK_EQ_OR_RETURN(bias_desc.shape().NumAxes(), 1) << \"Bias's ndim should be equal to 1. \";\n\n    n = weight_desc.shape().At(0);\n    CHECK_EQ_OR_RETURN(bias_desc.shape().At(0), n)\n        << \"Bias shape should be equal to N. Assume (M, K) matmul (N, K, transpose_b=True) \"\n           \"bias_add (N, ). \";\n    CHECK_EQ_OR_RETURN(weight_desc.shape().At(1), k)\n        << \"Weight shape should be equal to K. Assume (M, K) matmul (N, K, transpose_b=True) \"\n           \"bias_add (N, ). \";\n\n    cublas_aux_ld = n;\n    // Set Middle result shape.\n    long cublas_aligned_aux_ld = AlignReluAuxLd(cublas_aux_ld);\n    int64_t aux_size = cublas_aligned_aux_ld / 32;  // Cause we use int32_t as dtype\n    ctx->SetOutputShape(\"cublas_aux\", idx, Shape({m, aux_size}));\n    ctx->SetOutputShape(\"hidden\", idx, Shape({m, n}));\n    // Set for next layer.\n    k = n;\n  }\n  ctx->SetOutputShape(\"out\", 0, Shape({m, n}));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType4Matmul(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc(\"x\", 0);\n\n  for (const auto& in_arg_pair : ctx->inputs()) {\n    const user_op::TensorDesc& in_desc =\n        ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second);\n    CHECK_EQ_OR_RETURN(in_desc.data_type(), first_in_desc.data_type())\n        << \"InferDataType Failed. 
Expected \" << DataType_Name(in_desc.data_type()) << \", but got \"\n        << DataType_Name(first_in_desc.data_type());\n  }\n\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_data_type(first_in_desc.data_type());\n\n  for (int32_t i = 0; i < ctx->output_size(\"hidden\"); i++) {\n    user_op::TensorDesc* hidden_desc = ctx->MutOutputTensorDesc(\"hidden\", i);\n    hidden_desc->set_data_type(first_in_desc.data_type());\n  }\n\n  for (int32_t i = 0; i < ctx->output_size(\"cublas_aux\"); i++) {\n    user_op::TensorDesc* aux_desc = ctx->MutOutputTensorDesc(\"cublas_aux\", i);\n    aux_desc->set_data_type(DataType::kInt32);\n  }\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> FusedMatmulBiasAddReluDropoutOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferTensorDesc4FusedMatmul(ctx);\n}\n\n/*static*/ Maybe<void> FusedMatmulBiasAddReluDropoutOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedMatmulBiasAddReluDropoutOp::GetSbp(user_op::SbpContext* ctx) {\n  auto builder = ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), 0);\n  for (int i = 0; i < ctx->user_op_conf().input_size(\"weights\"); ++i) {\n    builder.Broadcast(user_op::OpArg(\"weights\", i));\n  }\n  for (int i = 0; i < ctx->user_op_conf().input_size(\"biases\"); ++i) {\n    builder.Broadcast(user_op::OpArg(\"biases\", i));\n  }\n  for (int i = 0; i < ctx->user_op_conf().output_size(\"cublas_aux\"); ++i) {\n    builder.Split(user_op::OpArg(\"cublas_aux\", i), 0);\n  }\n  for (int i = 0; i < ctx->user_op_conf().output_size(\"hidden\"); ++i) {\n    builder.Split(user_op::OpArg(\"hidden\", i), 0);\n  }\n  builder.Split(user_op::OpArg(\"out\", 0), 0);\n  builder.Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedMatmulBiasAddReluDropoutOp::InferDataType(\n    user_op::InferContext* ctx) {\n  return InferDataType4Matmul(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_matmul_bias_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferTensorDesc4FusedMatmulBias(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  /*\n  x: (m_i, ... m_1, k)\n  weight: (n, k) need transpose\n  bias: (n)\n  */\n\n  CHECK_GE_OR_RETURN(x_desc.shape().NumAxes(), 2);\n  const int64_t k = x_desc.shape().At(x_desc.shape().NumAxes() - 1);\n\n  const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc(\"weight\", 0);\n  const user_op::TensorDesc& bias_desc = ctx->InputTensorDesc(\"bias\", 0);\n  CHECK_EQ_OR_RETURN(weight_desc.shape().NumAxes(), 2);\n  CHECK_EQ_OR_RETURN(bias_desc.shape().NumAxes(), 1);\n\n  const int64_t n = weight_desc.shape().At(0);\n\n  CHECK_EQ_OR_RETURN(bias_desc.shape().At(0), n);\n  CHECK_EQ_OR_RETURN(weight_desc.shape().At(1), k);\n\n  Shape out_shape = x_desc.shape();\n  out_shape[x_desc.shape().NumAxes() - 1] = n;\n  ctx->SetOutputShape(\"out\", 0, out_shape);\n\n  if (ctx->has_input(\"_add_to_output\", 0)) {\n    const user_op::TensorDesc& _add_to_output_desc = ctx->InputTensorDesc(\"_add_to_output\", 0);\n    CHECK_EQ_OR_RETURN(_add_to_output_desc.shape(), out_shape);\n  }\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType4MatmulBias(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc(\"x\", 0);\n\n  for (const auto& in_arg_pair : ctx->inputs()) {\n    const user_op::TensorDesc& in_desc =\n        ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second);\n    CHECK_EQ_OR_RETURN(in_desc.data_type(), first_in_desc.data_type())\n        << \"InferDataType Failed. Expected \" << DataType_Name(first_in_desc.data_type())\n        << \", but got \" << DataType_Name(in_desc.data_type());\n  }\n\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_data_type(first_in_desc.data_type());\n\n  if (ctx->has_input(\"_add_to_output\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"_add_to_output\", 0), out_desc->data_type())\n        << \"InferDataType Failed. 
_add_to_output Expected \" << DataType_Name(out_desc->data_type())\n        << \", but got \" << DataType_Name(ctx->InputDType(\"_add_to_output\", 0));\n  }\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> FusedMatmulBiasOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc4FusedMatmulBias(ctx);\n}\n\n/*static*/ Maybe<void> FusedMatmulBiasOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedMatmulBiasOp::GetSbp(user_op::SbpContext* ctx) {\n  // (b, m, k) * (n, k)\n  const auto& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape();\n\n  const int64_t x_num_axes = x_shape.NumAxes();\n\n  const int64_t out_num_axes = x_num_axes;\n  const int32_t k_x_axis = x_num_axes - 1;\n\n  std::vector<user_op::OpArg> out_and_add_to_output_args;\n  out_and_add_to_output_args.emplace_back(\"out\", 0);\n  if (ctx->user_op_conf().has_input(\"_add_to_output\", 0)) {\n    out_and_add_to_output_args.emplace_back(\"_add_to_output\", 0);\n  }\n\n  for (int i = 0; i < x_shape.NumAxes() - 1; i++) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Broadcast(user_op::OpArg(\"weight\", 0))\n        .Broadcast(user_op::OpArg(\"bias\", 0))\n        .Split(out_and_add_to_output_args, i)\n        .Build();\n  }\n\n  // B x S(n_axis) -> S(n_axis)\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"x\", 0))\n      .Split(user_op::OpArg(\"weight\", 0), 0)\n      .Split(user_op::OpArg(\"bias\", 0), 0)\n      .Split(out_and_add_to_output_args, out_num_axes - 1)\n      .Build();\n\n  // S(x_k_axis) x S(w_k_axis) -> P\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"x\", 0), k_x_axis)\n      .Split(user_op::OpArg(\"weight\", 0), 1)\n      .PartialSum(user_op::OpArg(\"bias\", 0))\n      .PartialSum(out_and_add_to_output_args)\n      .Build();\n\n  // P x B -> P\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"x\", 0))\n      .Broadcast(user_op::OpArg(\"weight\", 0))\n      .PartialSum(user_op::OpArg(\"bias\", 0))\n      .PartialSum(out_and_add_to_output_args)\n      .Build();\n\n  // B x P -> P\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"x\", 0))\n      .PartialSum(user_op::OpArg(\"weight\", 0))\n      .PartialSum(user_op::OpArg(\"bias\", 0))\n      .PartialSum(out_and_add_to_output_args)\n      .Build();\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedMatmulBiasOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType4MatmulBias(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_relu_dropout_grad_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferTensorDesc4FusedReluDropoutGrad(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType4FusedReluDropoutGrad(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> FusedReluDropoutGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferTensorDesc4FusedReluDropoutGrad(ctx);\n}\n\n/*static*/ Maybe<void> FusedReluDropoutGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedReluDropoutGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"mask\", 0), 0)\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedReluDropoutGradOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType4FusedReluDropoutGrad(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_scale_mask_bias_softmax_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n\nnamespace oneflow {\n\n/*static*/ auto FusedScaleMaskBiasSoftmaxOp::InferDataType(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  DataType query_type = ctx->InputDType(\"x\", 0);\n  DataType mask_bias_type = ctx->InputDType(\"mask\", 0);\n  CHECK_EQ_OR_RETURN(mask_bias_type, query_type);\n\n  if (ctx->has_input(\"bias\", 0)) {\n    DataType bias_type = ctx->InputDType(\"bias\", 0);\n    CHECK_EQ_OR_RETURN(bias_type, query_type);\n  }\n  ctx->SetOutputDType(\"out\", 0, query_type);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ auto FusedScaleMaskBiasSoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  const float scale = ctx->Attr<float>(\"scale\");\n  CHECK_LE_OR_RETURN(scale, 1.);\n\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& mask_shape = ctx->InputShape(\"mask\", 0);\n  CHECK_OR_RETURN(x_shape[-1] == mask_shape[-1] && x_shape[0] == mask_shape[0]);\n  if (ctx->has_input(\"bias\", 0)) {\n    const Shape& bias_shape = ctx->InputShape(\"bias\", 0);\n    CHECK_OR_RETURN(mask_shape[-1] == bias_shape[-1]);\n    CHECK_OR_RETURN(mask_shape[0] == bias_shape[0] || bias_shape[0] == 1);\n    for (int i = 1; i < x_shape.NumAxes() - 1; i++) {\n      CHECK_OR_RETURN((mask_shape[i] == 1 || bias_shape[i] == 1)\n                      && mask_shape[i] * bias_shape[i] == x_shape[i]);\n    }\n  } else {\n    auto axes = x_shape.NumAxes();\n    bool reach1 = false;\n    for (int i = 0; i < axes - 1; i++) {\n      CHECK_OR_RETURN((mask_shape[i] == x_shape[i] && !reach1) || (1 == mask_shape[i]));\n      reach1 = (1 == mask_shape[i]);\n    }\n  }\n  ctx->SetOutputShape(\"out\", 0, x_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ auto FusedScaleMaskBiasSoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ auto FusedScaleMaskBiasSoftmaxOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  if (ctx->Attr<bool>(\"inplace\") == false)\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), 0)\n        .Split(user_op::OpArg(\"mask\", 0), 0)\n        .Broadcast(user_op::OpArg(\"bias\", 0))\n        .Split(user_op::OpArg(\"out\", 0), 0)\n        .Build();\n  else\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), 0)\n        .Split(user_op::OpArg(\"mask\", 0), 0)\n        .Broadcast(user_op::OpArg(\"bias\", 0))\n        .Build();\n  return Maybe<void>::Ok();\n}\n\n/*static*/ auto FusedScaleMaskBiasSoftmaxGradOp::InferDataType(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  DataType y_type = ctx->InputDType(\"y\", 0);\n  DataType dy_type = ctx->InputDType(\"dy\", 0);\n  
CHECK_EQ_OR_RETURN(y_type, dy_type);\n  ctx->SetOutputDType(\"dx\", 0, y_type);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ auto FusedScaleMaskBiasSoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  const Shape& y_shape = ctx->InputShape(\"y\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_EQ_OR_RETURN(y_shape, dy_shape);\n  ctx->SetOutputShape(\"dx\", 0, y_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ auto FusedScaleMaskBiasSoftmaxGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ auto FusedScaleMaskBiasSoftmaxGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"y\", 0), 0)\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ auto FusedScaleMaskSoftmaxDropoutOp::InferLogicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc(\"mask\", 0);\n  const auto x_shape = x_desc.shape();\n  const auto mask_shape = mask_desc.shape();\n  CHECK_EQ_OR_RETURN(x_desc.shape().At(x_shape.NumAxes() - 1),\n                     mask_desc.shape().At(mask_shape.NumAxes() - 1))\n      << \" last dim of x and mask is not equal.\";\n  ctx->SetOutputShape(\"y\", 0, x_desc.shape());\n  ctx->SetOutputIsDynamic(\"y\", 0, x_desc.is_dynamic());\n  ctx->SetOutputShape(\"softmax_y\", 0, x_desc.shape());\n  ctx->SetOutputIsDynamic(\"softmax_y\", 0, x_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedScaleMaskSoftmaxDropoutOp::InferPhysicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  return FusedScaleMaskSoftmaxDropoutOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto FusedScaleMaskSoftmaxDropoutOp::InferDataType(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc(\"mask\", 0);\n  CHECK_EQ_OR_RETURN(mask_desc.data_type(), DataType::kBool)\n      << \"InferDataType Failed. 
Expected \" << DataType_Name(DataType::kBool) << \", but got \"\n      << DataType_Name(mask_desc.data_type());\n  ctx->SetOutputDType(\"y\", 0, x_desc.data_type());\n  ctx->SetOutputDType(\"softmax_y\", 0, x_desc.data_type());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedScaleMaskSoftmaxDropoutOp::ModifyInputArg(\n    const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&)\n    -> Maybe<void> {\n  user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn(\"mask\", 0);\n  user_op::InputArgModifier* dropout_mask_modifier = GetInputArgModifierFn(\"dropout_mask\", 0);\n  CHECK_OR_RETURN(mask_modifier != nullptr) << \" cannot find mask input.\";\n  CHECK_OR_RETURN(dropout_mask_modifier != nullptr) << \" cannot find dropout mask input.\";\n  mask_modifier->set_requires_grad(false);\n  dropout_mask_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedScaleMaskSoftmaxDropoutOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2) << \" x num axes at least 2.\";\n  const user_op::TensorDesc& mask_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"mask\", 0);\n  CHECK_EQ_OR_RETURN(x_tensor.shape().NumAxes(), mask_tensor.shape().NumAxes())\n      << \" x num axes must equal with mask.\";\n  FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) {\n    // NOTE(chengcheng): mask support broadcast, when dim value = 1, sbp = broadcast\n    if (mask_tensor.shape().At(axis) == 1) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), axis)\n          .Broadcast(user_op::OpArg(\"mask\", 0))\n          .Split(user_op::OpArg(\"dropout_mask\", 0), axis)\n          .Split(user_op::OpArg(\"y\", 0), axis)\n          .Split(user_op::OpArg(\"softmax_y\", 0), axis)\n          .Build();\n    } else {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), axis)\n          .Split(user_op::OpArg(\"mask\", 0), axis)\n          .Split(user_op::OpArg(\"dropout_mask\", 0), axis)\n          .Split(user_op::OpArg(\"y\", 0), axis)\n          .Split(user_op::OpArg(\"softmax_y\", 0), axis)\n          .Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ auto FusedScaleMaskSoftmaxDropoutGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) -> Maybe<void> {\n  const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc(\"softmax_y\", 0);\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc(\"mask\", 0);\n  CHECK_EQ_OR_RETURN(dy_desc.shape(), softmax_y_desc.shape()) << \" dy and y shape must equal.\";\n  CHECK_EQ_OR_RETURN(dy_desc.shape().At(dy_desc.shape().NumAxes() - 1),\n                     mask_desc.shape().At(mask_desc.shape().NumAxes() - 1))\n      << \" last dim of y and mask is not equal.\";\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx_desc->set_shape(dy_desc.shape());\n  dx_desc->set_is_dynamic(dy_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedScaleMaskSoftmaxDropoutGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) -> Maybe<void> {\n  return FusedScaleMaskSoftmaxDropoutGradOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto FusedScaleMaskSoftmaxDropoutGradOp::InferDataType(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  const user_op::TensorDesc& 
softmax_y_desc = ctx->InputTensorDesc(\"softmax_y\", 0);\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc(\"mask\", 0);\n  CHECK_EQ_OR_RETURN(dy_desc.data_type(), softmax_y_desc.data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(softmax_y_desc.data_type())\n      << \", but got \" << DataType_Name(dy_desc.data_type());\n  CHECK_EQ_OR_RETURN(mask_desc.data_type(), DataType::kBool)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kBool) << \", but got \"\n      << DataType_Name(mask_desc.data_type());\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx_desc->set_data_type(dy_desc.data_type());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedScaleMaskSoftmaxDropoutGradOp::GetSbp(user_op::SbpContext* ctx)\n    -> Maybe<void> {\n  const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"dy\", 0);\n  CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2) << \" dy num axes at least 2.\";\n  const user_op::TensorDesc& mask_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"mask\", 0);\n  CHECK_EQ_OR_RETURN(dy_tensor.shape().NumAxes(), mask_tensor.shape().NumAxes())\n      << \" dy num axes must equal with mask.\";\n  FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) {\n    if (mask_tensor.shape().At(axis) == 1) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"softmax_y\", 0), axis)\n          .Split(user_op::OpArg(\"dy\", 0), axis)\n          .Broadcast(user_op::OpArg(\"mask\", 0))\n          .Split(user_op::OpArg(\"dropout_mask\", 0), axis)\n          .Split(user_op::OpArg(\"dx\", 0), axis)\n          .Build();\n    } else {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"softmax_y\", 0), axis)\n          .Split(user_op::OpArg(\"dy\", 0), axis)\n          .Split(user_op::OpArg(\"mask\", 0), axis)\n          .Split(user_op::OpArg(\"dropout_mask\", 0), axis)\n          .Split(user_op::OpArg(\"dx\", 0), axis)\n          .Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_scale_mask_softmax_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ auto FusedScaleMaskSoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc(\"mask\", 0);\n  const auto x_shape = x_desc.shape();\n  const auto mask_shape = mask_desc.shape();\n  CHECK_EQ_OR_RETURN(x_desc.shape().At(x_shape.NumAxes() - 1),\n                     mask_desc.shape().At(mask_shape.NumAxes() - 1))\n      << \" last dim of x and mask is not equal.\";\n  ctx->SetOutputShape(\"y\", 0, x_desc.shape());\n  ctx->SetOutputIsDynamic(\"y\", 0, x_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedScaleMaskSoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  return FusedScaleMaskSoftmaxOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto FusedScaleMaskSoftmaxOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc(\"mask\", 0);\n  CHECK_EQ_OR_RETURN(mask_desc.data_type(), DataType::kBool) << \" mask dtype only support bool.\";\n  ctx->SetOutputDType(\"y\", 0, x_desc.data_type());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedScaleMaskSoftmaxOp::ModifyInputArg(\n    const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&)\n    -> Maybe<void> {\n  user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn(\"mask\", 0);\n  CHECK_OR_RETURN(mask_modifier != nullptr) << \" cannot find mask input.\";\n  mask_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedScaleMaskSoftmaxOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2) << \" x num axes at least 2.\";\n  const user_op::TensorDesc& mask_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"mask\", 0);\n  CHECK_EQ_OR_RETURN(x_tensor.shape().NumAxes(), mask_tensor.shape().NumAxes())\n      << \" x num axes must equal with mask.\";\n  FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) {\n    if (mask_tensor.shape().At(axis) == 1) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), axis)\n          .Broadcast(user_op::OpArg(\"mask\", 0))\n          .Split(user_op::OpArg(\"y\", 0), axis)\n          .Build();\n    } else {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), axis)\n          .Split(user_op::OpArg(\"mask\", 0), axis)\n          .Split(user_op::OpArg(\"y\", 0), axis)\n          .Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ auto 
FusedScaleMaskSoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& y_desc = ctx->InputTensorDesc(\"y\", 0);\n  const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc(\"mask\", 0);\n  CHECK_EQ_OR_RETURN(dy_desc.shape(), y_desc.shape()) << \" dy and y shape must equal.\";\n  CHECK_EQ_OR_RETURN(y_desc.shape().At(y_desc.shape().NumAxes() - 1),\n                     mask_desc.shape().At(mask_desc.shape().NumAxes() - 1))\n      << \" last dim of y and mask is not equal.\";\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx_desc->set_shape(dy_desc.shape());\n  dx_desc->set_is_dynamic(dy_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedScaleMaskSoftmaxGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  return FusedScaleMaskSoftmaxGradOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto FusedScaleMaskSoftmaxGradOp::InferDataType(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& y_desc = ctx->InputTensorDesc(\"y\", 0);\n  const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc(\"mask\", 0);\n  CHECK_EQ_OR_RETURN(dy_desc.data_type(), y_desc.data_type()) << \" dy and y dtype must equal\";\n  CHECK_EQ_OR_RETURN(mask_desc.data_type(), DataType::kBool) << \" mask dtype only support bool.\";\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx_desc->set_data_type(dy_desc.data_type());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedScaleMaskSoftmaxGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"dy\", 0);\n  CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2) << \" dy num axes at least 2.\";\n  const user_op::TensorDesc& mask_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"mask\", 0);\n  CHECK_EQ_OR_RETURN(dy_tensor.shape().NumAxes(), mask_tensor.shape().NumAxes())\n      << \" dy num axes must equal with mask.\";\n  FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) {\n    if (mask_tensor.shape().At(axis) == 1) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"y\", 0), axis)\n          .Split(user_op::OpArg(\"dy\", 0), axis)\n          .Broadcast(user_op::OpArg(\"mask\", 0))\n          .Split(user_op::OpArg(\"dx\", 0), axis)\n          .Build();\n    } else {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"y\", 0), axis)\n          .Split(user_op::OpArg(\"dy\", 0), axis)\n          .Split(user_op::OpArg(\"mask\", 0), axis)\n          .Split(user_op::OpArg(\"dx\", 0), axis)\n          .Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n/*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  ctx->SetOutputShape(\"y\", 0, x_desc.shape());\n  ctx->SetOutputIsDynamic(\"y\", 0, x_desc.is_dynamic());\n  ctx->SetOutputShape(\"softmax_y\", 0, x_desc.shape());\n  ctx->SetOutputIsDynamic(\"softmax_y\", 0, x_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) -> Maybe<void> {\n  return FusedTrilScaleSoftmaxMaskScaleOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::InferDataType(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  ctx->SetOutputDType(\"y\", 0, x_desc.data_type());\n  ctx->SetOutputDType(\"softmax_y\", 0, x_desc.data_type());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::ModifyInputArg(\n    const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&)\n    -> Maybe<void> {\n  user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn(\"mask\", 0);\n  CHECK_OR_RETURN(mask_modifier != nullptr);\n  mask_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2);\n  FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), axis)\n        .Split(user_op::OpArg(\"mask\", 0), axis)\n        .Split(user_op::OpArg(\"y\", 0), axis)\n        .Split(user_op::OpArg(\"softmax_y\", 0), axis)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ auto FusedTrilScaleSoftmaxMaskScaleGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) -> Maybe<void> {\n  const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc(\"softmax_y\", 0);\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  CHECK_OR_RETURN(dy_desc.shape() == softmax_y_desc.shape());\n  dx_desc->set_shape(dy_desc.shape());\n  dx_desc->set_is_dynamic(dy_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedTrilScaleSoftmaxMaskScaleGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) -> Maybe<void> {\n  return FusedTrilScaleSoftmaxMaskScaleGradOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto 
FusedTrilScaleSoftmaxMaskScaleGradOp::InferDataType(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc(\"softmax_y\", 0);\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  CHECK_OR_RETURN(dy_desc.data_type() == softmax_y_desc.data_type());\n  dx_desc->set_data_type(dy_desc.data_type());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedTrilScaleSoftmaxMaskScaleGradOp::GetSbp(user_op::SbpContext* ctx)\n    -> Maybe<void> {\n  const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"dy\", 0);\n  CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2);\n  FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"softmax_y\", 0), axis)\n        .Split(user_op::OpArg(\"dy\", 0), axis)\n        .Split(user_op::OpArg(\"mask\", 0), axis)\n        .Split(user_op::OpArg(\"dx\", 0), axis)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::InferDataType(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  DataType dtype = ctx->InputDType(\"hidden_states\", 0);\n  ctx->SetOutputDType(\"query_mul_key\", 0, dtype);\n  ctx->SetOutputDType(\"value\", 0, dtype);\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) -> Maybe<void> {\n  CHECK_OR_RETURN(!(ctx->InputIsDynamic(\"hidden_states\", 0)));\n  int64_t head_size = ctx->Attr<int64_t>(\"head_size\");\n  const Shape& hidden_states_shape = ctx->InputShape(\"hidden_states\", 0);\n  // hidden_states_shape (seq_len, batch_size, hidden_size)\n  // layout is (seq_len, batch_size, num_heads, 3, head_size)\n  // for example shape (1024, 4, 12, 3, 64) -> (1024, 4, 12, 192) which stride is (9216, 2304,\n  // 192, 1)\n  CHECK_EQ_OR_RETURN(hidden_states_shape.NumAxes(), 3);\n  int64_t seq_len = hidden_states_shape.At(0);\n  int64_t batch_size = hidden_states_shape.At(1);\n  int64_t hidden_size = hidden_states_shape.At(2);\n  CHECK_EQ_OR_RETURN(hidden_size % (head_size * 3), 0);\n  int64_t num_heads = hidden_size / (head_size * 3);\n\n  ctx->SetOutputShape(\"query_mul_key\", 0, Shape({batch_size, num_heads, seq_len, seq_len}));\n  ctx->SetOutputShape(\"value\", 0, Shape({batch_size, num_heads, seq_len, head_size}));\n\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) -> Maybe<void> {\n  return FusedSelfAttentionQueryMulKeyAndValueOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::GetSbp(user_op::SbpContext* ctx)\n    -> Maybe<void> {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"hidden_states\", 0), 1)\n      .Split(user_op::OpArg(\"query_mul_key\", 0), 0)\n      .Split(user_op::OpArg(\"value\", 0), 0)\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"hidden_states\", 0), 2)\n      .Split(user_op::OpArg(\"query_mul_key\", 0), 1)\n      .Split(user_op::OpArg(\"value\", 0), 1)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::InferDataType(\n    user_op::InferContext* ctx) -> Maybe<void> {\n  DataType dtype = ctx->InputDType(\"query_mul_key_grad\", 0);\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"value_grad\", 0), dtype)\n      << \"InferDataType Failed. 
Expected \" << DataType_Name(dtype) << \", but got \"\n      << DataType_Name(ctx->InputDType(\"value_grad\", 0));\n  ctx->SetOutputDType(\"hidden_states_grad\", 0, dtype);\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) -> Maybe<void> {\n  CHECK_OR_RETURN(!(ctx->InputIsDynamic(\"query_mul_key_grad\", 0)));\n  CHECK_OR_RETURN(!(ctx->InputIsDynamic(\"value_grad\", 0)));\n  const Shape& h_shape = ctx->InputShape(\"hidden_states\", 0);\n  const Shape& qmk_grad_shape = ctx->InputShape(\"query_mul_key_grad\", 0);\n  const Shape& v_grad_shape = ctx->InputShape(\"value_grad\", 0);\n  CHECK_EQ_OR_RETURN(h_shape.NumAxes(), 3);\n  CHECK_EQ_OR_RETURN(qmk_grad_shape.NumAxes(), 4);\n  CHECK_EQ_OR_RETURN(v_grad_shape.NumAxes(), 4);\n  // hidden_states shape (s, b, H)\n  int64_t seq_len = h_shape.At(0);\n  int64_t batch_size = h_shape.At(1);\n  int64_t hidden_size = h_shape.At(2);\n  // value grad shape (b, n, s, h)\n  int64_t num_heads = v_grad_shape.At(1);\n  int64_t head_size = v_grad_shape.At(3);\n  CHECK_EQ_OR_RETURN(v_grad_shape.At(0), batch_size);\n  CHECK_EQ_OR_RETURN(v_grad_shape.At(2), seq_len);\n  CHECK_EQ_OR_RETURN(hidden_size, num_heads * 3 * head_size);\n  // qmk grad shape (b, n, sq, sk)\n  CHECK_EQ_OR_RETURN(qmk_grad_shape.At(0), batch_size);\n  CHECK_EQ_OR_RETURN(qmk_grad_shape.At(1), num_heads);\n  CHECK_EQ_OR_RETURN(qmk_grad_shape.At(2), seq_len);\n  CHECK_EQ_OR_RETURN(qmk_grad_shape.At(3), seq_len);\n\n  ctx->SetOutputShape(\"hidden_states_grad\", 0, h_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) -> Maybe<void> {\n  return FusedSelfAttentionQueryMulKeyAndValueGradOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::GetSbp(user_op::SbpContext* ctx)\n    -> Maybe<void> {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"query_mul_key_grad\", 0), 0)\n      .Split(user_op::OpArg(\"value_grad\", 0), 0)\n      .Split(user_op::OpArg(\"hidden_states\", 0), 1)\n      .Split(user_op::OpArg(\"hidden_states_grad\", 0), 1)\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"query_mul_key_grad\", 0), 1)\n      .Split(user_op::OpArg(\"value_grad\", 0), 1)\n      .Split(user_op::OpArg(\"hidden_states\", 0), 2)\n      .Split(user_op::OpArg(\"hidden_states_grad\", 0), 2)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/fused_weighted_sum_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> FusedWeightedSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const auto& in_0 = ctx->InputTensorDesc(\"in\", 0);\n  auto* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  for (int64_t i = 1; i < ctx->input_size(\"in\"); ++i) {\n    const auto& cur_in = ctx->InputTensorDesc(\"in\", i);\n    CHECK_EQ_OR_RETURN(in_0.shape(), cur_in.shape())\n        << Error::RuntimeError()\n        << \"inconsistent tensor size, expected all tensor to have the same shape, \"\n        << \"but got \" << in_0.shape().DebugStr() << \" and \" << cur_in.shape().DebugStr();\n  }\n  out->set_shape(in_0.shape());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FusedWeightedSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FusedWeightedSumOp::GetSbp(user_op::SbpContext* ctx) {\n  const int64_t num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0).shape().NumAxes();\n  for (int64_t i = 0; i < num_axes; ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(user_op::OpArg(\"out\", 0)).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FusedWeightedSumOp::InferDataType(user_op::InferContext* ctx) {\n  const auto& in_0 = ctx->InputTensorDesc(\"in\", 0);\n  auto* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  const DataType data_type = in_0.data_type();\n  for (int64_t i = 1; i < ctx->input_size(\"in\"); ++i) {\n    const auto& cur_in = ctx->InputTensorDesc(\"in\", i);\n    CHECK_EQ_OR_RETURN(cur_in.data_type(), data_type)\n        << Error::RuntimeError() << ctx->op_name()\n        << \" expected all tenser to have same type, but found \" << DataType_Name(cur_in.data_type())\n        << \" and \" << DataType_Name(data_type);\n  }\n  out->set_data_type(data_type);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FusedWeightedSumOp::CheckAttr(const user_op::UserOpDefWrapper&,\n                                                     const user_op::UserOpConfWrapper& op_conf) {\n  CHECK_OR_RETURN(op_conf.input_size(\"in\") >= 2)\n      << Error::RuntimeError()\n      << \"The number of input tensors should be greater than or equal to 2\";\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/gather_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ auto GatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0);\n  const int64_t axis = ctx->Attr<int64_t>(\"axis\");\n  const user_op::TensorDesc& indices = ctx->InputTensorDesc(\"indices\", 0);\n  // For 0-dim Tensor\n  CHECK_GE_OR_RETURN(indices.shape().NumAxes(), 0);  // NOLINT\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n\n  DimVector dim_vec;\n  dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(),\n                 in.shape().dim_vec().cbegin() + axis);\n  dim_vec.insert(dim_vec.end(), indices.shape().dim_vec().cbegin(),\n                 indices.shape().dim_vec().cend());\n  dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin() + axis + 1,\n                 in.shape().dim_vec().end());\n  out->set_shape(Shape(dim_vec));\n  out->set_is_dynamic(indices.is_dynamic() || in.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto GatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  return GatherOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto GatherOp::ModifyInputArg(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                         const user_op::UserOpConfWrapper&) -> Maybe<void> {\n  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn(\"indices\", 0);\n  CHECK_OR_RETURN(indices_modifier != nullptr);\n  indices_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n/*static*/ auto GatherOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  const int64_t in_num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0).shape().NumAxes();\n  const int64_t indices_num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"indices\", 0).shape().NumAxes();\n  const int64_t gather_axis = ctx->Attr<int64_t>(\"axis\");\n  CHECK_GE_OR_RETURN(gather_axis, 0);\n  CHECK_LT_OR_RETURN(gather_axis, in_num_axes);\n  FOR_RANGE(int64_t, i, 0, indices_num_axes) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"indices\", 0), i)\n        .Broadcast(user_op::OpArg(\"in\", 0))\n        .Split(user_op::OpArg(\"out\", 0), gather_axis + i)\n        .Build();\n  }\n  FOR_RANGE(int64_t, i, 0, in_num_axes) {\n    if (i == gather_axis) {\n      ctx->NewBuilder()\n          .Broadcast(user_op::OpArg(\"indices\", 0))\n          .Split(user_op::OpArg(\"in\", 0), i)\n          .PartialSum(user_op::OpArg(\"out\", 0))\n          .Build();\n    } else {\n      ctx->NewBuilder()\n          .Broadcast(user_op::OpArg(\"indices\", 0))\n          .Split(user_op::OpArg(\"in\", 0), i)\n          .Split(user_op::OpArg(\"out\", 0), i < gather_axis ? 
i : i + indices_num_axes - 1)\n          .Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ auto GatherOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  const user_op::TensorDesc& indices = ctx->InputTensorDesc(\"indices\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_OR_RETURN(IsIndexDataType(indices.data_type()));\n  out->set_data_type(in.data_type());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/gelu_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferGeluTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferGeluDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetGeluSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/*static*/ auto GeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  return InferGeluTensorDesc(ctx);\n}\n/*static*/ auto GeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  return InferGeluTensorDesc(ctx);\n}\n/*static*/ auto GeluOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  return InferGeluDataType(ctx);\n}\n/*static*/ auto GeluOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> { return GetGeluSbp(ctx); }\n\n/*static*/ auto FastGeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  return InferGeluTensorDesc(ctx);\n}\n/*static*/ auto FastGeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  return InferGeluTensorDesc(ctx);\n}\n/*static*/ auto FastGeluOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  return InferGeluDataType(ctx);\n}\n/*static*/ auto FastGeluOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  return GetGeluSbp(ctx);\n}\n\nnamespace {\n\nMaybe<void> InferGeluGradTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == x_shape)\n      << \"InferTensorDesc failed (\" << ctx->op_name() << \"). Expected x shape \"\n      << x_shape.ToString() << \" to be equal to dy shape \" << dy_shape.ToString();\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferGeluGradDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"x\", 0), ctx->InputDType(\"dy\", 0))\n      << \"InferDataType Failed. 
Expected \" << DataType_Name(ctx->InputDType(\"dy\", 0))\n      << \", but got \" << DataType_Name(ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetGeluGradSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"x\", 0))\n      .PartialSum(user_op::OpArg(\"dy\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/*static*/ auto GeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  return InferGeluGradTensorDesc(ctx);\n}\n/*static*/ auto GeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  return InferGeluGradTensorDesc(ctx);\n}\n/*static*/ auto GeluGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  return InferGeluGradDataType(ctx);\n}\n/*static*/ auto GeluGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  return GetGeluGradSbp(ctx);\n}\n\n/*static*/ auto FastGeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  return InferGeluGradTensorDesc(ctx);\n}\n/*static*/ auto FastGeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  return InferGeluGradTensorDesc(ctx);\n}\n/*static*/ auto FastGeluGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  return InferGeluGradDataType(ctx);\n}\n/*static*/ auto FastGeluGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  return GetGeluGradSbp(ctx);\n}\n\n/*static*/ Maybe<void> FusedFastGeluMulOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"in\", 0);\n  const Shape& m_shape = ctx->InputShape(\"multiplier\", 0);\n  CHECK_OR_RETURN(ctx->InputShape(\"multiplier\", 0) == in_shape)\n      << \"Expected multiplier shape \" << in_shape.ToString() << \", but got \" << m_shape.ToString();\n  ctx->SetOutputShape(\"out\", 0, in_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FusedFastGeluMulOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> FusedFastGeluMulOp::InferDataType(user_op::InferContext* ctx) {\n  const DataType in_dtype = ctx->InputDType(\"in\", 0);\n  const DataType m_dtype = ctx->InputDType(\"multiplier\", 0);\n  CHECK_EQ_OR_RETURN(m_dtype, in_dtype)\n      << \"Expected multiplier data type \" << DataType_Name(in_dtype) << \", but got \"\n      << DataType_Name(m_dtype);\n  ctx->SetOutputDType(\"out\", 0, in_dtype);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FusedFastGeluMulOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetGeluSbp(ctx);\n}\n\n/*static*/ Maybe<void> FusedFastGeluMulGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"in\", 0);\n  const Shape& out_diff_shape = ctx->InputShape(\"out_diff\", 0);\n  const Shape& m_shape = ctx->InputShape(\"multiplier\", 0);\n  CHECK_EQ_OR_RETURN(out_diff_shape, in_shape);  // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(m_shape, in_shape);         // NOLINT(maybe-need-error-msg)\n  
ctx->SetOutputShape(\"in_diff\", 0, in_shape);\n  ctx->SetOutputShape(\"multiplier_diff\", 0, m_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FusedFastGeluMulGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> FusedFastGeluMulGradOp::InferDataType(user_op::InferContext* ctx) {\n  const DataType in_dtype = ctx->InputDType(\"in\", 0);\n  const DataType out_diff_dtype = ctx->InputDType(\"out_diff\", 0);\n  const DataType m_dtype = ctx->InputDType(\"multiplier\", 0);\n  CHECK_EQ_OR_RETURN(out_diff_dtype, in_dtype);  // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(m_dtype, in_dtype);         // NOLINT(maybe-need-error-msg)\n  ctx->SetOutputDType(\"in_diff\", 0, in_dtype);\n  ctx->SetOutputDType(\"multiplier_diff\", 0, m_dtype);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FusedFastGeluMulGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"in\", 0))\n      .Broadcast(user_op::OpArg(\"multiplier\", 0))\n      .PartialSum(user_op::OpArg(\"out_diff\", 0))\n      .PartialSum(user_op::OpArg(\"in_diff\", 0))\n      .PartialSum(user_op::OpArg(\"multiplier_diff\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include <cstdint>\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) -> Maybe<void> {\n  ctx->SetOutputShape(\"y\", 0, Shape({ctx->InputShape(\"x\", 0).At(0)}));\n  return Maybe<void>::Ok();\n}\n/*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) -> Maybe<void> {\n  return GenerateRandomBatchPermutationIndicesOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto GenerateRandomBatchPermutationIndicesOp::GetSbp(user_op::SbpContext* ctx)\n    -> Maybe<void> {\n  ctx->NewBuilder().PartialSum(user_op::OpArg(\"x\", 0)).Broadcast(user_op::OpArg(\"y\", 0)).Build();\n  const auto& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Broadcast(user_op::OpArg(\"y\", 0)).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferDataType(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  ctx->SetOutputDType(\"y\", 0, DataType::kInt32);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/gpt_data_loader_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ auto MegatronGptMmapDataLoaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  int64_t batch_size = ctx->Attr<int64_t>(\"batch_size\");\n  int64_t sample_len = ctx->Attr<int64_t>(\"seq_length\") + ctx->Attr<int64_t>(\"label_length\");\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_shape(Shape({batch_size, sample_len}));\n  return Maybe<void>::Ok();\n}\n/*static*/ auto MegatronGptMmapDataLoaderOp::InferDataType(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  ctx->MutOutputTensorDesc(\"out\", 0)->set_data_type(ctx->Attr<DataType>(\"dtype\"));\n  return Maybe<void>::Ok();\n}\n/*static*/ auto MegatronGptMmapDataLoaderOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  ctx->NewBuilder().Split(ctx->outputs(), 0).Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ auto MegatronGptMmapDataLoaderOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx)\n    -> Maybe<void> {\n  SbpParallel default_sbp;\n  default_sbp.mutable_split_parallel()->set_axis(0);\n  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);\n}\n/*static*/ auto MegatronGptMmapDataLoaderOp::ModifyInputArg(\n    const user_op::GetInputArgModifier& GetInputArgModifierFn,\n    const user_op::UserOpConfWrapper& conf) -> Maybe<void> {\n  if (!conf.has_input(\"iteration\", 0)) { return Maybe<void>::Ok(); }\n  user_op::InputArgModifier* input_modifier = GetInputArgModifierFn(\"iteration\", 0);\n  CHECK_OR_RETURN(input_modifier != nullptr);\n  input_modifier->set_is_mutable(true);\n  input_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/greater_inplace_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool CheckBroadCastAble(const Shape& shape, const Shape& broadcast_shape) {\n  int left_pad = broadcast_shape.size() - shape.size();\n  if (left_pad < 0) { return false; }\n  for (int i = 0; i < shape.size(); ++i) {\n    int j = i + left_pad;\n    if (shape[i] != 1 && shape[i] != broadcast_shape[j]) { return false; }\n  }\n  return true;\n}\n\n}  // namespace\n\n/*static*/ Maybe<void> BroadCastInplaceGreaterOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const auto& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  const auto& y_desc = ctx->InputTensorDesc(\"y\", 0);\n  auto x_shape = x_desc.shape();\n  auto y_shape = y_desc.shape();\n  bool broadcast_status = CheckBroadCastAble(y_shape, x_shape);\n  CHECK_OR_RETURN(broadcast_status);\n  ctx->SetOutputShape(\"out\", 0, x_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> BroadCastInplaceGreaterOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return BroadCastInplaceGreaterOp::InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> BroadCastInplaceGreaterOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"y\", 0), i)\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> BroadCastInplaceGreaterOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ScalarLogicalInplaceGreaterOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ScalarLogicalInplaceGreaterOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return ScalarLogicalInplaceGreaterOp::InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> ScalarLogicalInplaceGreaterOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ScalarLogicalInplaceGreaterOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/grid_sample_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nMaybe<void> GridSampleOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                    const user_op::UserOpConfWrapper& conf) {\n  bool pass_checked = true;\n  std::stringstream err;\n  err << \"Illegal value for \" << conf.op_type_name() << \" op \" << conf.op_name() << \": \";\n\n  const auto& interpolation_mode = conf.attr<std::string>(\"interpolation_mode\");\n  if (!(interpolation_mode == \"bilinear\" || interpolation_mode == \"nearest\"\n        || interpolation_mode == \"bicubic\")) {\n    err << \" interpolation_mode:\" << interpolation_mode;\n    pass_checked = false;\n  }\n\n  const auto& padding_mode = conf.attr<std::string>(\"padding_mode\");\n  if (!(padding_mode == \"zeros\" || padding_mode == \"border\" || padding_mode == \"reflection\")) {\n    err << \" padding_mode:\" << padding_mode;\n    pass_checked = false;\n  }\n\n  if (pass_checked) {\n    return Maybe<void>::Ok();\n  } else {\n    return oneflow::Error::CheckFailedError() << err.str();\n  }\n}\n\n/*static*/ auto GridSampleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  const user_op::TensorDesc& input = ctx->InputTensorDesc(\"input\", 0);\n  const user_op::TensorDesc& grid = ctx->InputTensorDesc(\"grid\", 0);\n  user_op::TensorDesc& output = *(ctx->MutOutputTensorDesc(\"output\", 0));\n  // Only support 4D or 5D input with NCHW layout\n  // For 4D grid: input  = { N, C, H_in, W_in },\n  //              grid   = { N, H_out, W_out, 2 }\n  //              output = { N, C, H_out, W_out }\n  // For 5D grid: input  = { N, C, D_in, H_in, W_in },\n  //              grid   = { N, D_out, H_out, W_out, 3 }\n  //              output = { N, C, D_out, H_out, W_out }\n  const Shape& input_shape = input.shape();\n  const Shape& grid_shape = grid.shape();\n\n  bool is_4d_input = true;\n  if (input_shape.NumAxes() == 4) {\n    CHECK_EQ_OR_RETURN(grid_shape.NumAxes(), 4) << \"Grid and input MUST have same dimention\";\n    CHECK_EQ_OR_RETURN(grid_shape.At(3), 2) << \"Grid shape MUST (N, H_out, W_out, 2)\";\n    is_4d_input = true;\n  } else if (input_shape.NumAxes() == 5) {\n    CHECK_EQ_OR_RETURN(grid_shape.NumAxes(), 5) << \"Grid and input MUST have same dimention\";\n    CHECK_EQ_OR_RETURN(grid_shape.At(4), 3) << \"Grid shape MUST (N, H_out, W_out, 3)\";\n    if (ctx->Attr<std::string>(\"interpolation_mode\") == \"bicubic\") {\n      oneflow::Error::CheckFailedError() << \"Mode='bicubic' supports only 4-D input\";\n    }\n    is_4d_input = false;\n  } else {\n    CHECK_OR_RETURN(false) << \"MUST be 4D or 5D input\";\n  }\n  output.set_is_dynamic(grid.is_dynamic());\n  if (is_4d_input) {\n    output.set_shape(\n        Shape({input_shape.At(0), input_shape.At(1), grid_shape.At(1), grid_shape.At(2)}));\n  } else {\n    
output.set_shape(Shape({input_shape.At(0), input_shape.At(1), grid_shape.At(1),\n                            grid_shape.At(2), grid_shape.At(3)}));\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ auto GridSampleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  return GridSampleOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto GridSampleOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"input\", 0), 0)\n      .Split(user_op::OpArg(\"grid\", 0), 0)\n      .Split(user_op::OpArg(\"output\", 0), 0)\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"input\", 0), 1)\n      .Broadcast(user_op::OpArg(\"grid\", 0))\n      .Split(user_op::OpArg(\"output\", 0), 1)\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ auto GridSampleOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  ctx->SetOutputDType(\"output\", 0, ctx->InputDType(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GridSampleGradOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                        const user_op::UserOpConfWrapper& conf) {\n  return GridSampleOp::CheckAttr(def, conf);\n}\n\n/*static*/ auto GridSampleGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  ctx->MutOutputTensorDesc(\"dinput\", 0)->set_shape(ctx->InputTensorDesc(\"input\", 0).shape());\n  ctx->MutOutputTensorDesc(\"dgrid\", 0)->set_shape(ctx->InputTensorDesc(\"grid\", 0).shape());\n  return Maybe<void>::Ok();\n}\n/*static*/ auto GridSampleGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  return GridSampleGradOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ auto GridSampleGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"doutput\", 0), 0)\n      .Split(user_op::OpArg(\"input\", 0), 0)\n      .Split(user_op::OpArg(\"grid\", 0), 0)\n      .Split(user_op::OpArg(\"dinput\", 0), 0)\n      .Split(user_op::OpArg(\"dgrid\", 0), 0)\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"doutput\", 0), 1)\n      .Split(user_op::OpArg(\"input\", 0), 1)\n      .Broadcast(user_op::OpArg(\"grid\", 0))\n      .Split(user_op::OpArg(\"dinput\", 0), 1)\n      .PartialSum(user_op::OpArg(\"dgrid\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ auto GridSampleGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  ctx->SetOutputDType(\"dinput\", 0, ctx->InputDType(\"input\", 0));\n  ctx->SetOutputDType(\"dgrid\", 0, ctx->InputDType(\"grid\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/group_norm_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nDEFINE_ENV_BOOL(ONEFLOW_GROUP_NORM_USE_FP16_DIRECTLY, false);\n\nnamespace {\n\noneflow::DataType InferGnParamDataType(const DataType x_data_type) {\n  if (EnvBool<ONEFLOW_GROUP_NORM_USE_FP16_DIRECTLY>()) { return x_data_type; }\n  return (x_data_type == DataType::kFloat16 || x_data_type == DataType::kBFloat16)\n             ? DataType::kFloat\n             : x_data_type;\n}\n\n}  // namespace\n\n/* static */ Maybe<void> GroupNormOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  user_op::TensorDesc* mean = ctx->MutOutputTensorDesc(\"mean\", 0);\n  user_op::TensorDesc* inv_variance = ctx->MutOutputTensorDesc(\"inv_variance\", 0);\n  const bool affine = ctx->Attr<bool>(\"affine\");\n  const int32_t num_groups = ctx->Attr<int32_t>(\"num_groups\");\n  const int64_t batch_size = x.shape().At(0);\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  CHECK_GT_OR_RETURN(x.shape().NumAxes(), 2);\n  int64_t channel_size = 0;\n  if (data_format == \"channels_first\") {\n    channel_size = x.shape().At(1);\n  } else if (data_format == \"channels_last\") {\n    channel_size = x.shape().At(x.shape().NumAxes() - 1);\n  } else {\n    UNIMPLEMENTED_THEN_RETURN();\n  }\n  y->set_shape(x.shape());\n  y->set_is_dynamic(x.is_dynamic());\n  if (affine) {\n    const user_op::TensorDesc& gamma = ctx->InputTensorDesc(\"gamma\", 0);\n    CHECK_EQ_OR_RETURN(gamma.shape().At(0), channel_size);\n    const user_op::TensorDesc& beta = ctx->InputTensorDesc(\"beta\", 0);\n    CHECK_EQ_OR_RETURN(beta.shape().At(0), channel_size);\n  }\n  CHECK_EQ_OR_RETURN(channel_size % num_groups, 0) << \"Channels should be divisble by num_groups. 
\";\n  mean->set_shape(Shape({batch_size, num_groups}));\n  *inv_variance = *mean;\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> GroupNormOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> GroupNormOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(ctx->inputs(), 0)\n      .Split(ctx->outputs(), 0)\n      .Broadcast(user_op::OpArg(\"gamma\", 0))\n      .Broadcast(user_op::OpArg(\"beta\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> GroupNormOp::InferDataType(user_op::InferContext* ctx) {\n  const bool affine = ctx->Attr<bool>(\"affine\");\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  y->set_data_type(x.data_type());\n  if (affine) {\n    const user_op::TensorDesc& gamma = ctx->InputTensorDesc(\"gamma\", 0);\n    CHECK_EQ_OR_RETURN(gamma.data_type(), x.data_type())\n        << \"InferDataType Failed. Expected \" << DataType_Name(x.data_type()) << \", but got \"\n        << DataType_Name(gamma.data_type());\n    const user_op::TensorDesc& beta = ctx->InputTensorDesc(\"beta\", 0);\n    CHECK_EQ_OR_RETURN(beta.data_type(), x.data_type())\n        << \"InferDataType Failed. Expected \" << DataType_Name(x.data_type()) << \", but got \"\n        << DataType_Name(beta.data_type());\n  }\n  user_op::TensorDesc* mean = ctx->MutOutputTensorDesc(\"mean\", 0);\n  user_op::TensorDesc* inv_variance = ctx->MutOutputTensorDesc(\"inv_variance\", 0);\n  mean->set_data_type(InferGnParamDataType(x.data_type()));\n  inv_variance->set_data_type(mean->data_type());\n  return Maybe<void>::Ok();\n}\n\n// GroupNorm Grad\n/* static */ Maybe<void> GroupNormGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& mean = ctx->InputTensorDesc(\"mean\", 0);\n  const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc(\"inv_variance\", 0);\n  const int32_t num_groups = ctx->Attr<int32_t>(\"num_groups\");\n  user_op::TensorDesc* dx = ctx->MutOutputTensorDesc(\"dx\", 0);\n  CHECK_EQ_OR_RETURN(dy.shape(), x.shape());\n  const Shape& gn_param_shape = Shape({x.shape().At(0), num_groups});\n  CHECK_EQ_OR_RETURN(mean.shape(), gn_param_shape);\n  CHECK_EQ_OR_RETURN(inv_variance.shape(), gn_param_shape);\n  dx->set_shape(dy.shape());\n  dx->set_is_dynamic(dy.is_dynamic());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> GroupNormGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> GroupNormGradOp::GetSbp(user_op::SbpContext* ctx) {\n  std::vector<user_op::OpArg> broadcast_args;\n  if (ctx->user_op_conf().has_input(\"gamma\", 0)) {\n    broadcast_args.emplace_back(user_op::OpArg(\"gamma\", 0));\n  }\n\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Split(user_op::OpArg(\"mean\", 0), 0)\n      .Split(user_op::OpArg(\"inv_variance\", 0), 0)\n      .Split(ctx->outputs(), 0)\n      .Broadcast(broadcast_args)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> GroupNormGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  
CHECK_EQ_OR_RETURN(dy.data_type(), x.data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(x.data_type()) << \", but got \"\n      << DataType_Name(dy.data_type());\n  const user_op::TensorDesc& mean = ctx->InputTensorDesc(\"mean\", 0);\n  const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc(\"inv_variance\", 0);\n  const DataType& gn_param_data_type = InferGnParamDataType(x.data_type());\n  CHECK_EQ_OR_RETURN(mean.data_type(), gn_param_data_type)\n      << \"InferDataType Failed. Expected \" << DataType_Name(gn_param_data_type) << \", but got \"\n      << DataType_Name(mean.data_type());\n  CHECK_EQ_OR_RETURN(inv_variance.data_type(), gn_param_data_type)\n      << \"InferDataType Failed. Expected \" << DataType_Name(gn_param_data_type) << \", but got \"\n      << DataType_Name(inv_variance.data_type());\n  user_op::TensorDesc* dx = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx->set_data_type(dy.data_type());\n  return Maybe<void>::Ok();\n}\n\n// GroupNorm Param Grad\n/* static */ Maybe<void> GroupNormParamGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  user_op::TensorDesc* dgamma = ctx->MutOutputTensorDesc(\"dgamma\", 0);\n  user_op::TensorDesc* dbeta = ctx->MutOutputTensorDesc(\"dbeta\", 0);\n  const int64_t channel_size = x.shape().At(1);\n  dgamma->set_shape(Shape{channel_size});\n  dbeta->set_shape(Shape{channel_size});\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> GroupNormParamGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> GroupNormParamGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Split(user_op::OpArg(\"mean\", 0), 0)\n      .Split(user_op::OpArg(\"inv_variance\", 0), 0)\n      .PartialSum(ctx->outputs())\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> GroupNormParamGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  user_op::TensorDesc* dgamma = ctx->MutOutputTensorDesc(\"dgamma\", 0);\n  user_op::TensorDesc* dbeta = ctx->MutOutputTensorDesc(\"dbeta\", 0);\n  dgamma->set_data_type(dy.data_type());\n  dbeta->set_data_type(dy.data_type());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/grouped_matmul_bias_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> GroupedMatmulBiasOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const int64_t input_size = ctx->input_size(\"xs\");\n  CHECK_EQ_OR_RETURN(ctx->input_size(\"weights\"), input_size);\n  const bool has_biases = ctx->has_input(\"biases\", 0);\n  if (has_biases) { CHECK_EQ_OR_RETURN(ctx->input_size(\"biases\"), input_size); }\n  CHECK_EQ_OR_RETURN(ctx->output_size(\"ys\"), input_size);\n\n  const DataType data_type = ctx->InputTensorDesc(\"xs\", 0).data_type();\n  for (int64_t i = 0; i < input_size; ++i) {\n    const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"xs\", i);\n    CHECK_EQ_OR_RETURN(x_desc.data_type(), data_type);\n    CHECK_GE_OR_RETURN(x_desc.shape().NumAxes(), 2);\n    const int64_t k = x_desc.shape().At(x_desc.shape().NumAxes() - 1);\n    const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc(\"weights\", i);\n    CHECK_EQ_OR_RETURN(weight_desc.shape().NumAxes(), 2);\n    CHECK_EQ_OR_RETURN(weight_desc.shape().At(1), k);\n    const int64_t n = weight_desc.shape().At(0);\n    if (has_biases) {\n      const user_op::TensorDesc& bias_desc = ctx->InputTensorDesc(\"biases\", i);\n      CHECK_EQ_OR_RETURN(bias_desc.shape().NumAxes(), 1);\n      CHECK_EQ_OR_RETURN(bias_desc.shape().At(0), n);\n    }\n    user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc(\"ys\", i);\n    y_desc->set_data_type(data_type);\n    DimVector out_dim_vec = x_desc.shape().dim_vec();\n    out_dim_vec.back() = n;\n    y_desc->set_shape(Shape(out_dim_vec));\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> GroupedMatmulBiasOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> GroupedMatmulBiasOp::GetSbp(user_op::SbpContext* ctx) {\n  {\n    // s0 x b\n    auto builder = ctx->NewBuilder();\n    for (int64_t i = 0; i < ctx->user_op_conf().input_size(\"xs\"); ++i) {\n      builder.Split(user_op::OpArg(\"xs\", i), 0);\n    }\n    for (int i = 0; i < ctx->user_op_conf().input_size(\"weights\"); ++i) {\n      builder.Broadcast(user_op::OpArg(\"weights\", i));\n    }\n    for (int i = 0; i < ctx->user_op_conf().input_size(\"biases\"); ++i) {\n      builder.Broadcast(user_op::OpArg(\"biases\", i));\n    }\n    for (int i = 0; i < ctx->user_op_conf().output_size(\"ys\"); ++i) {\n      builder.Split(user_op::OpArg(\"ys\", i), 0);\n    }\n    builder.Build();\n  }\n\n  {\n    // b x s0\n    auto builder = ctx->NewBuilder();\n    for (int64_t i = 0; i < ctx->user_op_conf().input_size(\"xs\"); ++i) {\n      builder.Broadcast(user_op::OpArg(\"xs\", i));\n    }\n    for 
(int i = 0; i < ctx->user_op_conf().input_size(\"weights\"); ++i) {\n      builder.Split(user_op::OpArg(\"weights\", i), 0);\n    }\n    for (int i = 0; i < ctx->user_op_conf().input_size(\"biases\"); ++i) {\n      builder.Split(user_op::OpArg(\"biases\", i), 0);\n    }\n    for (int i = 0; i < ctx->user_op_conf().output_size(\"ys\"); ++i) {\n      builder.Split(user_op::OpArg(\"ys\", i),\n                    ctx->LogicalTensorDesc4InputArgNameAndIndex(\"xs\", i).shape().NumAxes() - 1);\n    }\n    builder.Build();\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> GroupedMatmulBiasOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc(\"xs\", 0);\n  for (const auto& in_arg_pair : ctx->inputs()) {\n    const user_op::TensorDesc& in_desc =\n        ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second);\n    CHECK_EQ_OR_RETURN(in_desc.data_type(), first_in_desc.data_type())\n        << \"InferDataType Failed. Expected \" << DataType_Name(first_in_desc.data_type())\n        << \", but got \" << DataType_Name(in_desc.data_type());\n  }\n  for (int32_t i = 0; i < ctx->output_size(\"ys\"); i++) {\n    user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc(\"ys\", i);\n    y_desc->set_data_type(first_in_desc.data_type());\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/groupwise_dequantize_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> GroupwiseDequantizeOp::GetSbp(user_op::SbpContext* ctx) {\n  const Shape& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0).shape();\n  const Shape& scale_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"scale\", 0).shape();\n  std::vector<user_op::OpArg> scale_zero_args;\n  scale_zero_args.emplace_back(user_op::OpArg(\"scale\", 0));\n  if (ctx->user_op_conf().has_input(\"zero\", 0)) {\n    scale_zero_args.emplace_back(user_op::OpArg(\"zero\", 0));\n  }\n  for (int32_t i = 0; i < in_shape.NumAxes(); ++i) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), i)\n        .Split(scale_zero_args, i)\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n  const int64_t group_dim = ctx->Attr<int64_t>(\"group_dim\");\n  if (scale_shape.At(group_dim) == 1) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), group_dim)\n        .Broadcast(scale_zero_args)\n        .Split(user_op::OpArg(\"out\", 0), group_dim)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> GroupwiseDequantizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"in\", 0);\n  const Shape& scale_shape = ctx->InputShape(\"scale\", 0);\n  const int32_t num_bits = ctx->Attr<int32_t>(\"num_bits\");\n  const int64_t group_dim = ctx->Attr<int64_t>(\"group_dim\");\n  const int64_t group_size = ctx->Attr<int64_t>(\"group_size\");\n  CHECK_OR_RETURN(num_bits == 4 || num_bits == 8);\n  CHECK_GE_OR_RETURN(in_shape.NumAxes(), 1);\n  CHECK_OR_RETURN(group_dim >= 0 && group_dim < in_shape.NumAxes());\n  Shape out_shape = in_shape;\n  out_shape.Set(out_shape.NumAxes() - 1, out_shape.At(out_shape.NumAxes() - 1) * (8 / num_bits));\n  const int64_t group_dim_size = out_shape.At(group_dim);\n  CHECK_GE_OR_RETURN(group_size, 0);\n  CHECK_EQ_OR_RETURN(group_dim_size % group_size, 0);\n  const int64_t num_groups = group_dim_size / group_size;\n  CHECK_EQ_OR_RETURN(scale_shape.NumAxes(), in_shape.NumAxes());\n  if (ctx->has_input(\"zero\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputShape(\"zero\", 0).NumAxes(), in_shape.NumAxes());\n  }\n  for (int64_t i = 0; i < out_shape.NumAxes(); ++i) {\n    if (i == group_dim) {\n      CHECK_EQ_OR_RETURN(scale_shape.At(i), num_groups);\n      if (ctx->has_input(\"zero\", 0)) {\n        CHECK_EQ_OR_RETURN(ctx->InputShape(\"zero\", 0).At(i), num_groups);\n      }\n    } else {\n      CHECK_EQ_OR_RETURN(scale_shape.At(i), out_shape.At(i));\n      if (ctx->has_input(\"zero\", 0)) {\n        CHECK_EQ_OR_RETURN(ctx->InputShape(\"zero\", 0).At(i), out_shape.At(i));\n      }\n    }\n  }\n  ctx->SetOutputShape(\"out\", 0, out_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> 
GroupwiseDequantizeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> GroupwiseDequantizeOp::InferDataType(user_op::InferContext* ctx) {\n  const DataType data_type = ctx->InputDType(\"scale\", 0);\n  if (ctx->has_input(\"zero\", 0)) { CHECK_EQ_OR_RETURN(ctx->InputDType(\"zero\", 0), data_type); }\n  if (ctx->Attr<bool>(\"symmetric\")) {\n    CHECK_OR_RETURN(ctx->InputDType(\"in\", 0) == DataType::kUInt8\n                    || ctx->InputDType(\"in\", 0) == DataType::kInt8);\n  } else {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"in\", 0), DataType::kUInt8);\n  }\n  ctx->SetOutputDType(\"out\", 0, data_type);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/hardshrink_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> HardShrinkOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HardShrinkOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> HardShrinkOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HardShrinkOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HardShrinkGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& y_shape = ctx->InputShape(\"y\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == y_shape) << \"The shape of y_grad and y must be same.\";\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HardShrinkGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> HardShrinkGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"y\", 0);\n  FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"y\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HardShrinkGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"dy\", 0), ctx->InputDType(\"y\", 0))\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"y\", 0)) << \", but got \"\n      << DataType_Name(ctx->InputDType(\"dy\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"y\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/hardsigmoid_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> HardsigmoidOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> HardsigmoidOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> HardsigmoidOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HardsigmoidOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HardsigmoidGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == x_shape);\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> HardsigmoidGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> HardsigmoidGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HardsigmoidGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"x\", 0), ctx->InputDType(\"dy\", 0))\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"dy\", 0))\n      << \", but got \" << DataType_Name(ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/hardswish_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> HardswishOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> HardswishOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> HardswishOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HardswishOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HardswishGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == x_shape);\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> HardswishGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> HardswishGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HardswishGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"x\", 0), ctx->InputDType(\"dy\", 0))\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"dy\", 0))\n      << \", but got \" << DataType_Name(ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/hardtanh_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> HardtanhOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  double min_val = ctx->Attr<double>(\"min_val\");\n  double max_val = ctx->Attr<double>(\"max_val\");\n  CHECK_LE_OR_RETURN(min_val, max_val);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> HardtanhOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> HardtanhOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HardtanhOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HardtanhGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& y_shape = ctx->InputShape(\"y\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == y_shape);\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  double min_val = ctx->Attr<double>(\"min_val\");\n  double max_val = ctx->Attr<double>(\"max_val\");\n  CHECK_LE_OR_RETURN(min_val, max_val);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> HardtanhGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> HardtanhGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"y\", 0);\n  FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"y\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HardtanhGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"y\", 0), ctx->InputDType(\"dy\", 0))\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"dy\", 0))\n      << \", but got \" << DataType_Name(ctx->InputDType(\"y\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"y\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/hierarchical_parallel_cast_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> HierarchicalParallelCastOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> HierarchicalParallelCastOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> HierarchicalParallelCastOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> HierarchicalParallelCastOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  const Shape& parallel_hierarchy = ctx->parallel_hierarchy();\n  const auto& conf = ctx->user_op_conf().attr<std::vector<std::string>>(\"nd_sbp\");\n  CHECK_EQ_OR_RETURN(conf.size(), parallel_hierarchy.NumAxes());\n  for (const std::string& sbp_str : conf) {\n    SbpParallel sbp_parallel;\n    CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_str, &sbp_parallel));\n    *in_distribution->add_sbp_parallel() = sbp_parallel;\n    *out_distribution->add_sbp_parallel() = sbp_parallel;\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HierarchicalParallelCastOp::GetNdSbpSignatureList(\n    user_op::GetNdSbpSignatureListContext* ctx) {\n  const auto& conf = ctx->Attr<std::vector<std::string>>(\"nd_sbp\");\n  NdSbpSignature nd_sbp_signature;\n  for (const std::string& sbp_str : conf) {\n    SbpParallel sbp_parallel;\n    CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_str, &sbp_parallel));\n    *(*nd_sbp_signature.mutable_bn_in_op2nd_sbp())[GenRepeatedBn(\"in\", 0)].add_sbp_parallel() =\n        sbp_parallel;\n    *(*nd_sbp_signature.mutable_bn_in_op2nd_sbp())[GenRepeatedBn(\"out\", 0)].add_sbp_parallel() =\n        sbp_parallel;\n  }\n  ctx->AddNdSbpSignature(nd_sbp_signature);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HierarchicalParallelCastOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HierarchicalParallelCastLikeOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> HierarchicalParallelCastLikeOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> 
HierarchicalParallelCastLikeOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> HierarchicalParallelCastLikeOp::InferNdSbp(\n    user_op::InferNdSbpFnContext* ctx) {\n  NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  NdSbp* like_distribution = ctx->NdSbp4ArgNameAndIndex(\"like\", 0);\n  const NdSbp& hint_distribution = ctx->NdSbpHint4InputArgNameAndIndex(\"like\", 0);\n  *in_distribution = hint_distribution;\n  *out_distribution = hint_distribution;\n  *like_distribution = hint_distribution;\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> HierarchicalParallelCastLikeOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/identity_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> IdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> IdentityOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> IdentityOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> IdentityOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/image_batch_align_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<typename T>\nbool PowerOfTwo(T x) {\n  static_assert(std::is_integral<T>::value, \"T must be integral\");\n  return x != 0 && (x & (x - 1)) == 0;\n}\n\n}  // namespace\n\n/* static */ Maybe<void> ImageBatchAlignOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1);\n  const Shape& shape_attr = ctx->Attr<Shape>(\"shape\");\n  const bool dynamic_out = ctx->Attr<bool>(\"dynamic_out\");\n  DimVector dim_vec(shape_attr.NumAxes() + 1);\n  dim_vec.at(0) = in_desc.shape().elem_cnt();\n  FOR_RANGE(int64_t, i, 0, shape_attr.NumAxes()) { dim_vec.at(i + 1) = shape_attr.At(i); }\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_shape(Shape(dim_vec));\n  out_desc->set_is_dynamic(dynamic_out);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ImageBatchAlignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ImageBatchAlignOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ImageBatchAlignOp::ModifyOutputArg(\n    const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn(\"out\", 0);\n  CHECK_OR_RETURN(out_modifier != nullptr);\n  out_modifier->set_header_infered_before_compute(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ImageBatchAlignOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                                      const user_op::UserOpConfWrapper& conf) {\n  bool check_failed = false;\n  std::stringstream err;\n  err << \"Illegal attr value for \" << conf.op_type_name() << \" op, op_name: \" << conf.op_name();\n  const Shape& shape = conf.attr<Shape>(\"shape\");\n  if (shape.NumAxes() != 3) {\n    err << \", shape: \" << shape.ToString() << \" (image shape must has 3 axes)\";\n    check_failed = true;\n  }\n  DataType data_type = conf.attr<DataType>(\"data_type\");\n  if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) {\n    err << \", data_type: \" << data_type << \" (only support kUInt8 and kFloat for now)\";\n    check_failed = true;\n  }\n  int32_t alignment = conf.attr<int32_t>(\"alignment\");\n  if (alignment < 0) {\n    err << \", alignment: \" << alignment << \" (alignment must be greater than or equal to 0)\";\n    check_failed = true;\n  } else if (alignment != 0 && !PowerOfTwo(alignment)) {\n    err << \", alignment: \" << alignment\n        << \" (alignment must be power of 2 when it's not equal 
to 0)\";\n    check_failed = true;\n  }\n  if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ImageBatchAlignOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer);\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_data_type(ctx->Attr<DataType>(\"data_type\"));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/image_decode_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> ImageDecodeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) >= 1);\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_shape(in_desc.shape());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ImageDecodeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ImageDecodeOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ImageDecodeOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                                  const user_op::UserOpConfWrapper& conf) {\n  bool check_failed = false;\n  std::stringstream err;\n  err << \"Illegal attr value for \" << conf.op_type_name() << \" op, op_name: \" << conf.op_name();\n  const std::string& color_space = conf.attr<std::string>(\"color_space\");\n  if (color_space != \"BGR\" && color_space != \"RGB\" && color_space != \"GRAY\") {\n    err << \", color_space: \" << color_space\n        << \" (color_space can only be one of BGR, RGB and GRAY)\";\n    check_failed = true;\n  }\n  DataType data_type = conf.attr<DataType>(\"data_type\");\n  if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) {\n    err << \", data_type: \" << data_type << \" (only support kUInt8 and kFloat for now)\";\n    check_failed = true;\n  }\n  if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ImageDecodeOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer);\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_data_type(DataType::kTensorBuffer);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/image_object_preprocess_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> ImageObjectGetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> ImageFlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), 1);\n  const int N = in_desc.shape().elem_cnt();\n\n  const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc(\"flip_code\", 0);\n  CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N);\n\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ImageFlipOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ImageFlipOp::GetSbp(user_op::SbpContext* ctx) {\n  return ImageObjectGetSbp(ctx);\n}\n\n/* static */ Maybe<void> ImageFlipOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_EQ_OR_RETURN(in_desc.data_type(), DataType::kTensorBuffer)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kTensorBuffer) << \", but got \"\n      << DataType_Name(in_desc.data_type());\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ObjectBboxFlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc(\"bbox\", 0);\n  CHECK_EQ_OR_RETURN(bbox_desc.shape().NumAxes(), 1);\n  const int N = bbox_desc.shape().elem_cnt();\n\n  const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc(\"image_size\", 0);\n  CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2);\n\n  const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc(\"flip_code\", 0);\n  CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N);\n\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"bbox\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"bbox\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ObjectBboxFlipOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ObjectBboxFlipOp::GetSbp(user_op::SbpContext* ctx) {\n  return ImageObjectGetSbp(ctx);\n}\n\n/* static */ Maybe<void> ObjectBboxFlipOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc(\"bbox\", 0);\n  CHECK_EQ_OR_RETURN(bbox_desc.data_type(), DataType::kTensorBuffer)\n      << \"InferDataType Failed. 
Expected \" << DataType_Name(DataType::kTensorBuffer) << \", but got \"\n      << DataType_Name(bbox_desc.data_type());\n  const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc(\"image_size\", 0);\n  CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kInt32) << \", but got \"\n      << DataType_Name(image_size_desc.data_type());\n  const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc(\"flip_code\", 0);\n  CHECK_EQ_OR_RETURN(flip_code_desc.data_type(), DataType::kInt8)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kInt8) << \", but got \"\n      << DataType_Name(flip_code_desc.data_type());\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"bbox\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ObjectBboxScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc(\"bbox\", 0);\n  CHECK_EQ_OR_RETURN(bbox_desc.shape().NumAxes(), 1);\n  const int N = bbox_desc.shape().elem_cnt();\n\n  const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc(\"scale\", 0);\n  CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2);\n\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"bbox\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"bbox\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ObjectBboxScaleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ObjectBboxScaleOp::GetSbp(user_op::SbpContext* ctx) {\n  return ImageObjectGetSbp(ctx);\n}\n\n/* static */ Maybe<void> ObjectBboxScaleOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc(\"bbox\", 0);\n  CHECK_EQ_OR_RETURN(bbox_desc.data_type(), DataType::kTensorBuffer)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kTensorBuffer) << \", but got \"\n      << DataType_Name(bbox_desc.data_type());\n  const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc(\"scale\", 0);\n  CHECK_EQ_OR_RETURN(scale_desc.data_type(), DataType::kFloat)\n      << \"InferDataType Failed. 
Expected \" << DataType_Name(DataType::kFloat) << \", but got \"\n      << DataType_Name(scale_desc.data_type());\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"bbox\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ObjectSegmentationPolygonFlipOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc(\"poly\", 0);\n  CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1);\n  const int N = poly_desc.shape().elem_cnt();\n\n  const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc(\"image_size\", 0);\n  CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2);\n\n  const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc(\"flip_code\", 0);\n  CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N);\n\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"poly\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"poly\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ObjectSegmentationPolygonFlipOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ObjectSegmentationPolygonFlipOp::GetSbp(user_op::SbpContext* ctx) {\n  return ImageObjectGetSbp(ctx);\n}\n\n/* static */ Maybe<void> ObjectSegmentationPolygonFlipOp::InferDataType(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc(\"poly\", 0);\n  CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kTensorBuffer) << \", but got \"\n      << DataType_Name(poly_desc.data_type());\n  const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc(\"image_size\", 0);\n  CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kInt32) << \", but got \"\n      << DataType_Name(image_size_desc.data_type());\n  const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc(\"flip_code\", 0);\n  CHECK_EQ_OR_RETURN(flip_code_desc.data_type(), DataType::kInt8)\n      << \"InferDataType Failed. 
Expected \" << DataType_Name(DataType::kInt8) << \", but got \"\n      << DataType_Name(flip_code_desc.data_type());\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"poly\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ObjectSegmentationPolygonScaleOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc(\"poly\", 0);\n  CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1);\n  const int N = poly_desc.shape().elem_cnt();\n\n  const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc(\"scale\", 0);\n  CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2);\n\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"poly\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"poly\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ObjectSegmentationPolygonScaleOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ObjectSegmentationPolygonScaleOp::GetSbp(user_op::SbpContext* ctx) {\n  return ImageObjectGetSbp(ctx);\n}\n\n/* static */ Maybe<void> ObjectSegmentationPolygonScaleOp::InferDataType(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc(\"poly\", 0);\n  CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kTensorBuffer) << \", but got \"\n      << DataType_Name(poly_desc.data_type());\n  const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc(\"scale\", 0);\n  CHECK_EQ_OR_RETURN(scale_desc.data_type(), DataType::kFloat)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kFloat) << \", but got \"\n      << DataType_Name(scale_desc.data_type());\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"poly\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ImageNormalizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), 1);\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ImageNormalizeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ImageNormalizeOp::GetSbp(user_op::SbpContext* ctx) {\n  return ImageObjectGetSbp(ctx);\n}\n\n/* static */ Maybe<void> ImageNormalizeOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_EQ_OR_RETURN(in_desc.data_type(), DataType::kTensorBuffer)\n      << \"InferDataType Failed. 
Expected \" << DataType_Name(DataType::kTensorBuffer) << \", but got \"\n      << DataType_Name(in_desc.data_type());\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ObjectSegmentationPolygonToMaskOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc(\"poly\", 0);\n  CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1);\n  const int N = poly_desc.shape().elem_cnt();\n\n  const user_op::TensorDesc& poly_index_desc = ctx->InputTensorDesc(\"poly_index\", 0);\n  CHECK_EQ_OR_RETURN(poly_index_desc.shape().NumAxes(), 1);\n  CHECK_EQ_OR_RETURN(poly_index_desc.shape().elem_cnt(), N);\n\n  const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc(\"image_size\", 0);\n  CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2);\n\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"poly\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"poly\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ObjectSegmentationPolygonToMaskOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ObjectSegmentationPolygonToMaskOp::GetSbp(user_op::SbpContext* ctx) {\n  return ImageObjectGetSbp(ctx);\n}\n\n/* static */ Maybe<void> ObjectSegmentationPolygonToMaskOp::InferDataType(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc(\"poly\", 0);\n  CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kTensorBuffer) << \", but got \"\n      << DataType_Name(poly_desc.data_type());\n  const user_op::TensorDesc& poly_index_desc = ctx->InputTensorDesc(\"poly_index\", 0);\n  CHECK_EQ_OR_RETURN(poly_index_desc.data_type(), DataType::kTensorBuffer)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kTensorBuffer) << \", but got \"\n      << DataType_Name(poly_desc.data_type());\n  const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc(\"image_size\", 0);\n  CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kInt32) << \", but got \"\n      << DataType_Name(image_size_desc.data_type());\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"poly\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/image_preprocess_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n#include \"oneflow/user/image/image_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> CropMirrorNormalizeFromTensorbufferOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  bool has_mirror = ctx->has_input(\"mirror\", 0);\n  if (has_mirror) {\n    const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc(\"mirror\", 0);\n    CHECK_OR_RETURN(mirror_tensor.shape().NumAxes() == 1\n                    && in_tensor.shape().At(0) == mirror_tensor.shape().At(0));\n  }\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  int64_t N = in_tensor.shape().At(0);\n  int64_t H = ctx->Attr<int64_t>(\"crop_h\");\n  int64_t W = ctx->Attr<int64_t>(\"crop_w\");\n  std::string color_space = ctx->Attr<std::string>(\"color_space\");\n  int64_t C = ImageUtil::IsColor(color_space) ? 3 : 1;\n\n  CHECK_OR_RETURN(H != 0 && W != 0);\n  CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1);\n  std::string output_layout = ctx->Attr<std::string>(\"output_layout\");\n  if (output_layout == \"NCHW\") {\n    out_tensor->set_shape(Shape({N, C, H, W}));\n  } else if (output_layout == \"NHWC\") {\n    out_tensor->set_shape(Shape({N, H, W, C}));\n  } else {\n    return Error::CheckFailedError() << \"output_layout: \" << output_layout << \" is not supported\";\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> CropMirrorNormalizeFromTensorbufferOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CropMirrorNormalizeFromTensorbufferOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CropMirrorNormalizeFromTensorbufferOp::InferDataType(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_EQ_OR_RETURN(in_tensor.data_type(), DataType::kTensorBuffer)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kTensorBuffer) << \", but got \"\n      << DataType_Name(in_tensor.data_type());\n  bool has_mirror = ctx->has_input(\"mirror\", 0);\n  if (has_mirror) {\n    const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc(\"mirror\", 0);\n    CHECK_EQ_OR_RETURN(mirror_tensor.data_type(), DataType::kInt8)\n        << \"InferDataType Failed. 
Expected \" << DataType_Name(DataType::kInt8) << \", but got \"\n        << DataType_Name(mirror_tensor.data_type());\n  }\n\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  DataType output_dtype = ctx->Attr<DataType>(\"output_dtype\");\n  CHECK_EQ_OR_RETURN(output_dtype,\n                     DataType::kFloat)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kFloat) << \", but got \"\n      << DataType_Name(output_dtype);  // only support float now; for float16 in future\n  out_tensor->set_data_type(output_dtype);\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CropMirrorNormalizeFromUint8Op::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  bool has_mirror = ctx->has_input(\"mirror\", 0);\n  if (has_mirror) {\n    const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc(\"mirror\", 0);\n    CHECK_OR_RETURN(mirror_tensor.shape().NumAxes() == 1\n                    && in_tensor.shape().At(0) == mirror_tensor.shape().At(0));\n  }\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  int64_t N = in_tensor.shape().At(0);\n  int64_t H = ctx->Attr<int64_t>(\"crop_h\");\n  int64_t W = ctx->Attr<int64_t>(\"crop_w\");\n  std::string color_space = ctx->Attr<std::string>(\"color_space\");\n  int64_t C = ImageUtil::IsColor(color_space) ? 3 : 1;\n  CHECK_EQ_OR_RETURN(in_tensor.shape().NumAxes(), 4);  // {N, H, W, C}\n  CHECK_EQ_OR_RETURN(in_tensor.shape().At(3), C);\n  if (H == 0 || W == 0) {\n    H = in_tensor.shape().At(1);\n    W = in_tensor.shape().At(2);\n  } else {\n    H = std::min(H, in_tensor.shape().At(1));\n    W = std::min(W, in_tensor.shape().At(2));\n  }\n  std::string output_layout = ctx->Attr<std::string>(\"output_layout\");\n  if (output_layout == \"NCHW\") {\n    out_tensor->set_shape(Shape({N, C, H, W}));\n  } else if (output_layout == \"NHWC\") {\n    out_tensor->set_shape(Shape({N, H, W, C}));\n  } else {\n    return Error::CheckFailedError() << \"output_layout: \" << output_layout << \" is not supported\";\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> CropMirrorNormalizeFromUint8Op::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CropMirrorNormalizeFromUint8Op::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CropMirrorNormalizeFromUint8Op::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_EQ_OR_RETURN(in_tensor.data_type(), DataType::kUInt8)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kUInt8) << \", but got \"\n      << DataType_Name(in_tensor.data_type());\n  bool has_mirror = ctx->has_input(\"mirror\", 0);\n  if (has_mirror) {\n    const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc(\"mirror\", 0);\n    CHECK_EQ_OR_RETURN(mirror_tensor.data_type(), DataType::kInt8)\n        << \"InferDataType Failed. 
Expected \" << DataType_Name(DataType::kInt8) << \", but got \"\n        << DataType_Name(mirror_tensor.data_type());\n  }\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  DataType output_dtype = ctx->Attr<DataType>(\"output_dtype\");\n  CHECK_EQ_OR_RETURN(output_dtype,\n                     DataType::kFloat)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kFloat) << \", but got \"\n      << DataType_Name(output_dtype);  // only support float now; for float16 in future\n  out_tensor->set_data_type(output_dtype);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CoinFlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  int64_t batch_size = ctx->Attr<int64_t>(\"batch_size\");\n  out_tensor->set_shape(Shape({batch_size}));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CoinFlipOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();\n  const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  int64_t batch_size = ctx->Attr<int64_t>(\"batch_size\");\n  const Shape logical_shape = Shape({batch_size});\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n\n  const auto tensor_slice_view =\n      GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id);\n  const Shape& physical_shape = tensor_slice_view.shape();\n  ctx->SetOutputShape(\"out\", 0, physical_shape);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CoinFlipOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(user_op::OpArg(\"out\", 0), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CoinFlipOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  const Shape& hierarchy = ctx->parallel_hierarchy();\n  NdSbp* output_dist = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  // the input may be produced by tick which should be broadcast parallel dist\n  std::vector<NdSbp*> inputs_dist;\n  for (const auto& arg_pair : ctx->inputs()) {\n    inputs_dist.emplace_back(ctx->NdSbp4ArgNameAndIndex(arg_pair.first, arg_pair.second));\n  }\n  const auto& dist_conf = ctx->user_op_conf().attr<std::vector<std::string>>(\"nd_sbp\");\n  if (dist_conf.size() == 0) {\n    FOR_RANGE(int, i, 0, hierarchy.NumAxes()) {\n      output_dist->add_sbp_parallel()->mutable_split_parallel()->set_axis(0);\n      for (auto* input_dist : inputs_dist) {\n        input_dist->add_sbp_parallel()->mutable_broadcast_parallel();\n      }\n    }\n  } else {\n    CHECK_EQ_OR_RETURN(dist_conf.size(), hierarchy.NumAxes());\n    for (const std::string& sbp_str : dist_conf) {\n      SbpParallel sbp_parallel;\n      CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_str, &sbp_parallel));\n      CHECK_OR_RETURN(\n          (sbp_parallel.has_split_parallel() && sbp_parallel.split_parallel().axis() == 0)\n          || sbp_parallel.has_broadcast_parallel());\n      *output_dist->add_sbp_parallel() = sbp_parallel;\n      for (auto* input_dist : inputs_dist) {\n        input_dist->add_sbp_parallel()->mutable_broadcast_parallel();\n      }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> CoinFlipOp::InferDataType(user_op::InferContext* ctx) {\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_tensor->set_data_type(DataType::kInt8);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> 
ImageRandomCropOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_tensor->set_shape(in_tensor.shape());\n  out_tensor->set_is_dynamic(in_tensor.is_dynamic());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ImageRandomCropOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ImageRandomCropOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);\n}\n\n/* static */ Maybe<void> ImageRandomCropOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* in_modifier = GetInputArgModifierFn(\"in\", 0);\n  CHECK_NOTNULL_OR_RETURN(in_modifier);\n  in_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ImageRandomCropOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer);\n  ctx->SetOutputDType(\"out\", 0, in_tensor.data_type());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/image_resize_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/image/image_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> ImageResizeToFixedOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().elem_cnt() > 0);\n  int64_t batch_size = in_tensor.shape().elem_cnt();\n  int64_t target_width = ctx->Attr<int64_t>(\"target_width\");\n  int64_t target_height = ctx->Attr<int64_t>(\"target_height\");\n  int64_t channels = ctx->Attr<int64_t>(\"channels\");\n\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_tensor->set_shape(Shape({batch_size, target_height, target_width, channels}));\n  out_tensor->set_is_dynamic(in_tensor.is_dynamic());\n\n  user_op::TensorDesc* scale_tensor = ctx->MutOutputTensorDesc(\"scale\", 0);\n  scale_tensor->set_shape(Shape({batch_size, 2}));\n  scale_tensor->set_is_dynamic(in_tensor.is_dynamic());\n\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ImageResizeToFixedOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ImageResizeToFixedOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ImageResizeToFixedOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                                         const user_op::UserOpConfWrapper& conf) {\n  bool check_failed = false;\n  std::ostringstream err;\n  err << \"Illegal attr value for \" << conf.op_type_name() << \" op, op_name: \" << conf.op_name();\n  int64_t target_width = conf.attr<int64_t>(\"target_width\");\n  int64_t target_height = conf.attr<int64_t>(\"target_height\");\n  if (target_width <= 0 || target_height <= 0) {\n    err << \", target_width: \" << target_width << \", target_height: \" << target_height;\n    check_failed = true;\n  }\n  int64_t channels = conf.attr<int64_t>(\"channels\");\n  if (channels != 1 && channels != 3) {\n    err << \", channels: \" << channels << \" (channels can only be 1 or 3)\";\n    check_failed = true;\n  }\n  DataType data_type = conf.attr<DataType>(\"data_type\");\n  if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) {\n    err << \", data_type: \" << data_type << \" (only support kUInt8 and kFloat for now)\";\n    check_failed = true;\n  }\n  const std::string& interp_type = conf.attr<std::string>(\"interpolation_type\");\n  if (!CheckInterpolationValid(interp_type, err)) { check_failed = true; }\n  if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> 
ImageResizeToFixedOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer);\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_tensor->set_data_type(ctx->Attr<DataType>(\"data_type\"));\n  user_op::TensorDesc* scale_tensor = ctx->MutOutputTensorDesc(\"scale\", 0);\n  scale_tensor->set_data_type(DataType::kFloat);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ImageResizeKeepAspectRatioOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) > 0);\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_shape(in_desc.shape());\n  user_op::TensorDesc* size_desc = ctx->MutOutputTensorDesc(\"size\", 0);\n  size_desc->set_shape(in_desc.shape());\n  user_op::TensorDesc* scale_desc = ctx->MutOutputTensorDesc(\"scale\", 0);\n  scale_desc->set_shape(in_desc.shape());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ImageResizeKeepAspectRatioOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ImageResizeKeepAspectRatioOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ImageResizeKeepAspectRatioOp::CheckAttr(\n    const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) {\n  bool check_failed = false;\n  std::ostringstream err;\n  err << \"Illegal attr value for \" << conf.op_type_name() << \" op, op_name: \" << conf.op_name();\n  const int32_t target_size = conf.attr<int32_t>(\"target_size\");\n  const int32_t max_size = conf.attr<int32_t>(\"max_size\");\n  if (target_size <= 0) {\n    err << \", target_size: \" << target_size << \" (target_size must be greater than 0)\";\n    check_failed = true;\n  }\n  if (max_size < target_size && max_size > 0) {\n    err << \", max_size: \" << max_size\n        << \" (max_size must be greater than target_size or equal to 0)\";\n    check_failed = true;\n  }\n  const std::string& interp_type = conf.attr<std::string>(\"interpolation_type\");\n  if (!CheckInterpolationValid(interp_type, err)) { check_failed = true; }\n  if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ImageResizeKeepAspectRatioOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer);\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_data_type(DataType::kTensorBuffer);\n  user_op::TensorDesc* size_desc = ctx->MutOutputTensorDesc(\"size\", 0);\n  size_desc->set_data_type(DataType::kTensorBuffer);\n  user_op::TensorDesc* scale_desc = ctx->MutOutputTensorDesc(\"scale\", 0);\n  scale_desc->set_data_type(DataType::kTensorBuffer);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/image_target_resize_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> ImageTargetResizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) >= 1);\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_shape(in_desc.shape());\n  user_op::TensorDesc* size_desc = ctx->MutOutputTensorDesc(\"size\", 0);\n  size_desc->set_shape(Shape({in_desc.shape().elem_cnt(), 2}));\n  user_op::TensorDesc* scale_desc = ctx->MutOutputTensorDesc(\"scale\", 0);\n  scale_desc->set_shape(Shape({in_desc.shape().elem_cnt(), 2}));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ImageTargetResizeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ImageTargetResizeOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ImageTargetResizeOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                                        const user_op::UserOpConfWrapper& conf) {\n  bool check_failed = false;\n  std::stringstream err;\n  err << \"Illegal attr value for \" << conf.op_type_name() << \" op, op_name: \" << conf.op_name();\n  const int32_t target_size = conf.attr<int32_t>(\"target_size\");\n  const int32_t max_size = conf.attr<int32_t>(\"max_size\");\n  if (target_size <= 0) {\n    err << \", target_size: \" << target_size << \" (target_size must be greater than 0)\";\n    check_failed = true;\n  }\n  if (max_size < target_size) {\n    err << \", max_size: \" << max_size << \" (max_size must be greater than 0)\";\n    check_failed = true;\n  }\n  if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ImageTargetResizeOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer);\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_data_type(DataType::kTensorBuffer);\n  user_op::TensorDesc* size_desc = ctx->MutOutputTensorDesc(\"size\", 0);\n  size_desc->set_data_type(DataType::kInt32);\n  user_op::TensorDesc* scale_desc = ctx->MutOutputTensorDesc(\"scale\", 0);\n  scale_desc->set_data_type(DataType::kFloat);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/in_top_k_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> InTopKOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& targets = ctx->InputTensorDesc(\"targets\", 0);\n  const user_op::TensorDesc& predictions = ctx->InputTensorDesc(\"predictions\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_EQ_OR_RETURN(targets.shape().NumAxes(), 1);      // NOLINT(maybe-need-error-msg)\n  CHECK_EQ_OR_RETURN(predictions.shape().NumAxes(), 2);  // NOLINT(maybe-need-error-msg)\n  const bool is_dynamic = targets.is_dynamic();\n  CHECK_EQ_OR_RETURN(is_dynamic, predictions.is_dynamic());  // NOLINT(maybe-need-error-msg)\n  out->set_is_dynamic(is_dynamic);\n  out->set_shape(targets.shape());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> InTopKOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> InTopKOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> InTopKOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& targets = ctx->InputTensorDesc(\"targets\", 0);\n  CHECK_OR_RETURN(IsIndexDataType(targets.data_type())) << \" targets data type must be index type\";\n  const user_op::TensorDesc& predictions = ctx->InputTensorDesc(\"predictions\", 0);\n  CHECK_EQ_OR_RETURN(predictions.data_type(), DataType::kFloat)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kFloat) << \", but got \"\n      << DataType_Name(predictions.data_type());\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_data_type(kBool);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/index_add_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> IndexAddOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& input_shape = ctx->InputShape(\"input\", 0);\n  ctx->SetOutputShape(\"output\", 0, input_shape);\n  ctx->SetOutputStride(\"output\", 0, ctx->InputStride(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> IndexAddOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> IndexAddOp::GetSbp(user_op::SbpContext* ctx) {\n  // TODO(yangzhimin): support more valid sbp signature.\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> IndexAddOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"output\", 0, ctx->InputDType(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/indexed_slices_reduce_sum_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> IndexedSlicesReduceSumOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x_indices = ctx->InputTensorDesc(\"x_indices\", 0);\n  const user_op::TensorDesc& x_values = ctx->InputTensorDesc(\"x_values\", 0);\n  CHECK_LT_OR_RETURN(x_indices.shape().NumAxes(), x_values.shape().NumAxes());\n  FOR_RANGE(int64_t, i, 0, x_indices.shape().NumAxes()) {\n    CHECK_EQ_OR_RETURN(x_indices.shape().At(i), x_values.shape().At(i));\n  }\n\n  const int64_t n = x_indices.shape().elem_cnt();\n  const int64_t m = x_values.shape().elem_cnt() / n;\n  user_op::TensorDesc* y_indices = ctx->MutOutputTensorDesc(\"y_indices\", 0);\n  user_op::TensorDesc* y_values = ctx->MutOutputTensorDesc(\"y_values\", 0);\n  *y_indices = x_indices;\n  y_indices->set_shape(Shape({n}));\n  *y_values = x_values;\n  y_values->set_shape(Shape({n, m}));\n  user_op::TensorDesc* num_unique = ctx->MutOutputTensorDesc(\"num_unique\", 0);\n  num_unique->set_shape(Shape({1}));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> IndexedSlicesReduceSumOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> IndexedSlicesReduceSumOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> IndexedSlicesReduceSumOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x_indices = ctx->InputTensorDesc(\"x_indices\", 0);\n  CHECK_OR_RETURN(IsIndexDataType(x_indices.data_type()));\n  user_op::TensorDesc* num_unique = ctx->MutOutputTensorDesc(\"num_unique\", 0);\n  num_unique->set_data_type(DataType::kInt64);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/inv_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> InvOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"y\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> InvOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> InvOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x.shape().NumAxes() - 2) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Split(user_op::OpArg(\"y\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> InvOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/kl_div_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/loss_op_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\nnamespace {\nMaybe<void> KlInferTensorDescFn(user_op::InferContext* ctx) {\n  const auto& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const auto& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic());\n  CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape());\n\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_is_dynamic(input_desc.is_dynamic());\n  out_desc->set_shape(input_desc.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> KlInferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const user_op::TensorDesc& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(target_desc.data_type()) << \", but got \"\n      << DataType_Name(input_desc.data_type());\n\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"input\", 0));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferGradTensorDescFn(user_op::InferContext* ctx) {\n  const auto& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const auto& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  const auto& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic());\n  CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape());\n  CHECK_EQ_OR_RETURN(dy_desc.shape(), target_desc.shape());\n\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx_desc->set_is_dynamic(input_desc.is_dynamic());\n  dx_desc->set_shape(input_desc.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferGradDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const user_op::TensorDesc& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type())\n      << \"InferDataType Failed. 
Expected \" << DataType_Name(target_desc.data_type()) << \", but got \"\n      << DataType_Name(input_desc.data_type());\n\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> KlDivLossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return KlInferTensorDescFn(ctx);\n}\n\n/*static*/ Maybe<void> KlDivLossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> KlDivLossOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"input\", 0).shape();\n  FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> KlDivLossOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* target_modifier = GetInputArgModifierFn(\"target\", 0);\n  CHECK_OR_RETURN(target_modifier != nullptr);\n  target_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> KlDivLossOp::InferDataType(user_op::InferContext* ctx) {\n  return KlInferDataType(ctx);\n}\n\n/* static */ Maybe<void> KlDivLossGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferGradTensorDescFn(ctx);\n}\n\n/*static*/ Maybe<void> KlDivLossGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> KlDivLossGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"input\", 0).shape();\n  FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"input\", 0), i)\n        .Split(user_op::OpArg(\"target\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> KlDivLossGradOp::InferDataType(user_op::InferContext* ctx) {\n  return InferGradDataType(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/l1_l2_regularize_gradient_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& model_diff = ctx->InputTensorDesc(\"model_diff\", 0);\n  CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape());\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"model\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"model\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetSbpSignatures(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"model\", 0);\n  FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) {\n    ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> L1L2RegularizeGradientOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> L1L2RegularizeGradientOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> L1L2RegularizeGradientOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetSbpSignatures(ctx);\n}\n\n/* static */ Maybe<void> L1L2RegularizeGradientOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& model_diff = ctx->InputTensorDesc(\"model_diff\", 0);\n  CHECK_EQ_OR_RETURN(model_diff.data_type(), model.data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(model.data_type()) << \", but got \"\n      << DataType_Name(model_diff.data_type());\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"model\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/l2_normalize_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> L2NormalizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const int32_t axis = ctx->Attr<int32_t>(\"axis\");\n  const float epsilon = ctx->Attr<float>(\"epsilon\");\n  CHECK_GE_OR_RETURN(axis, 0);\n  CHECK_LT_OR_RETURN(axis, x_shape.NumAxes());\n  CHECK_GT_OR_RETURN(epsilon, 0);\n  ctx->SetOutputShape(\"y\", 0, x_shape);\n  Shape square_x_sum_shape = x_shape;\n  square_x_sum_shape.Set(axis, 1);\n  ctx->SetOutputShape(\"square_x_sum\", 0, square_x_sum_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> L2NormalizeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> L2NormalizeOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  const int32_t axis = ctx->Attr<int32_t>(\"axis\");\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    if (i != axis) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), i)\n          .Split(user_op::OpArg(\"y\", 0), i)\n          .Split(user_op::OpArg(\"square_x_sum\", 0), i)\n          .Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> L2NormalizeOp::InferDataType(user_op::InferContext* ctx) {\n  DataType x_dtype = ctx->InputDType(\"x\", 0);\n  DataType square_x_sum_dtype = x_dtype;\n  if (x_dtype == DataType::kFloat16 || x_dtype == DataType::kBFloat16) {\n    square_x_sum_dtype = DataType::kFloat;\n  }\n  ctx->SetOutputDType(\"square_x_sum\", 0, square_x_sum_dtype);\n  ctx->SetOutputDType(\"y\", 0, x_dtype);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> L2NormalizeGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  const Shape& y_shape = ctx->InputShape(\"y\", 0);\n  const Shape& square_x_sum_shape = ctx->InputShape(\"square_x_sum\", 0);\n  const int32_t axis = ctx->Attr<int32_t>(\"axis\");\n  const float epsilon = ctx->Attr<float>(\"epsilon\");\n  CHECK_EQ_OR_RETURN(dy_shape, y_shape);\n  CHECK_GE_OR_RETURN(axis, 0);\n  CHECK_LT_OR_RETURN(axis, dy_shape.NumAxes());\n  CHECK_GT_OR_RETURN(epsilon, 0);\n  FOR_RANGE(int32_t, i, 0, dy_shape.NumAxes()) {\n    if (i == axis) {\n      CHECK_EQ_OR_RETURN(square_x_sum_shape.At(i), 1);\n    } else {\n      CHECK_EQ_OR_RETURN(square_x_sum_shape.At(i), dy_shape.At(i));\n    }\n  }\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> L2NormalizeGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> L2NormalizeGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const 
user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"y\", 0);\n  const int32_t axis = ctx->Attr<int32_t>(\"axis\");\n  FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) {\n    if (i != axis) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"y\", 0), i)\n          .Split(user_op::OpArg(\"dy\", 0), i)\n          .Split(user_op::OpArg(\"square_x_sum\", 0), i)\n          .Split(user_op::OpArg(\"dx\", 0), i)\n          .Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> L2NormalizeGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"y\", 0), ctx->InputDType(\"dy\", 0))\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"dy\", 0))\n      << \", but got \" << DataType_Name(ctx->InputDType(\"y\", 0));\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"y\", 0), ctx->InputDType(\"square_x_sum\", 0))\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"square_x_sum\", 0))\n      << \", but got \" << DataType_Name(ctx->InputDType(\"y\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/layer_norm_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nDEFINE_ENV_BOOL(ONEFLOW_LAYER_NORM_PARAM_KEEP_DIM, false);\n\nnamespace {\n\nint64_t ShiftNegativeAxisIfNeed(const Shape& shape, int64_t axis) {\n  const int64_t shifted = axis < 0 ? axis + shape.NumAxes() : axis;\n  CHECK_GE(shifted, 0);\n  CHECK_LT(shifted, shape.NumAxes());\n  return shifted;\n}\n\nShape InferBnParamShape(const Shape& x_shape, const int64_t begin_norm_axis) {\n  DimVector bn_param_shape_dim_vec;\n  bn_param_shape_dim_vec.insert(bn_param_shape_dim_vec.end(), x_shape.dim_vec().cbegin(),\n                                x_shape.dim_vec().cbegin() + begin_norm_axis);\n  if (EnvBool<ONEFLOW_LAYER_NORM_PARAM_KEEP_DIM>()) {\n    while (bn_param_shape_dim_vec.size() < x_shape.dim_vec().size()) {\n      bn_param_shape_dim_vec.push_back(1);\n    }\n  }\n  const Shape bn_param_shape(bn_param_shape_dim_vec);\n  return bn_param_shape;\n}\n\noneflow::DataType InferBnParamDataType(const DataType x_data_type) {\n  return (x_data_type == DataType::kFloat16 || x_data_type == DataType::kBFloat16)\n             ? DataType::kFloat\n             : x_data_type;\n}\n\n}  // namespace\n\n/* static */ Maybe<void> LayerNormOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  user_op::TensorDesc* mean = ctx->MutOutputTensorDesc(\"mean\", 0);\n  user_op::TensorDesc* inv_variance = ctx->MutOutputTensorDesc(\"inv_variance\", 0);\n  const bool center = ctx->Attr<bool>(\"center\");\n  const bool scale = ctx->Attr<bool>(\"scale\");\n  const int64_t begin_params_axis =\n      ShiftNegativeAxisIfNeed(x.shape(), ctx->Attr<int64_t>(\"begin_params_axis\"));\n  y->set_shape(x.shape());\n  y->set_is_dynamic(x.is_dynamic());\n  DimVector param_shape_dim_vec;\n  param_shape_dim_vec.insert(param_shape_dim_vec.end(),\n                             x.shape().dim_vec().cbegin() + begin_params_axis,\n                             x.shape().dim_vec().cend());\n  const Shape param_shape(param_shape_dim_vec);\n  if (center) {\n    const user_op::TensorDesc& beta = ctx->InputTensorDesc(\"beta\", 0);\n    CHECK_EQ_OR_RETURN(beta.shape(), param_shape);\n  }\n  if (scale) {\n    const user_op::TensorDesc& gamma = ctx->InputTensorDesc(\"gamma\", 0);\n    CHECK_EQ_OR_RETURN(gamma.shape(), param_shape);\n  }\n  const int64_t begin_norm_axis =\n      ShiftNegativeAxisIfNeed(x.shape(), ctx->Attr<int64_t>(\"begin_norm_axis\"));\n  if (begin_norm_axis != begin_params_axis) {\n    return Error::RuntimeError() << \"begin_norm_axis must equal to begin_params_axis, but got \"\n                                 << begin_norm_axis << \" vs \" << begin_params_axis;\n  }\n  mean->set_shape(InferBnParamShape(x.shape(), begin_norm_axis));\n  *inv_variance = *mean;\n  return 
Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> LayerNormOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> LayerNormOp::GetSbp(user_op::SbpContext* ctx) {\n  const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape();\n  int64_t begin_norm_axis = ShiftNegativeAxisIfNeed(x_shape, ctx->Attr<int64_t>(\"begin_norm_axis\"));\n  int64_t begin_params_axis =\n      ShiftNegativeAxisIfNeed(x_shape, ctx->Attr<int64_t>(\"begin_params_axis\"));\n  for (int i = 0; i < std::min(begin_norm_axis, begin_params_axis); ++i) {\n    ctx->NewBuilder()\n        .Split(ctx->inputs(), i)\n        .Split(ctx->outputs(), i)\n        .Broadcast(user_op::OpArg(\"gamma\", 0))\n        .Broadcast(user_op::OpArg(\"beta\", 0))\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> LayerNormOp::InferDataType(user_op::InferContext* ctx) {\n  const bool center = ctx->Attr<bool>(\"center\");\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  y->set_data_type(x.data_type());\n  if (center) {\n    const user_op::TensorDesc& beta = ctx->InputTensorDesc(\"beta\", 0);\n    CHECK_EQ_OR_RETURN(beta.data_type(), x.data_type())\n        << \"InferDataType Failed. Expected \" << DataType_Name(x.data_type()) << \", but got \"\n        << DataType_Name(beta.data_type());\n  }\n  const bool scale = ctx->Attr<bool>(\"scale\");\n  if (scale) {\n    const user_op::TensorDesc& gamma = ctx->InputTensorDesc(\"gamma\", 0);\n    CHECK_EQ_OR_RETURN(gamma.data_type(), x.data_type())\n        << \"InferDataType Failed. Expected \" << DataType_Name(x.data_type()) << \", but got \"\n        << DataType_Name(gamma.data_type());\n  }\n  user_op::TensorDesc* mean = ctx->MutOutputTensorDesc(\"mean\", 0);\n  user_op::TensorDesc* inv_variance = ctx->MutOutputTensorDesc(\"inv_variance\", 0);\n  mean->set_data_type(InferBnParamDataType(x.data_type()));\n  inv_variance->set_data_type(mean->data_type());\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> LayerNormGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& mean = ctx->InputTensorDesc(\"mean\", 0);\n  const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc(\"inv_variance\", 0);\n  user_op::TensorDesc* dx = ctx->MutOutputTensorDesc(\"dx\", 0);\n  CHECK_EQ_OR_RETURN(dy.shape(), x.shape());\n  const int64_t begin_norm_axis = ctx->Attr<int64_t>(\"begin_norm_axis\");\n  CHECK_GT_OR_RETURN(begin_norm_axis, 0);\n  const Shape& bn_param_shape = InferBnParamShape(x.shape(), begin_norm_axis);\n  CHECK_EQ_OR_RETURN(mean.shape(), bn_param_shape);\n  CHECK_EQ_OR_RETURN(inv_variance.shape(), bn_param_shape);\n  dx->set_shape(dy.shape());\n  dx->set_is_dynamic(dy.is_dynamic());\n  if (ctx->has_input(\"_add_to_output\", 0)) {\n    const auto& add_to_output = ctx->InputTensorDesc(\"_add_to_output\", 0);\n    CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape());\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> LayerNormGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> LayerNormGradOp::GetSbp(user_op::SbpContext* ctx) {\n  std::vector<user_op::OpArg> broadcast_args;\n  if (ctx->user_op_conf().has_input(\"gamma\", 0)) {\n    
broadcast_args.emplace_back(user_op::OpArg(\"gamma\", 0));\n  }\n  int64_t begin_norm_axis = ctx->Attr<int64_t>(\"begin_norm_axis\");\n  for (int i = 0; i < begin_norm_axis; ++i) {\n    ctx->NewBuilder()\n        .Split(ctx->inputs(), i)\n        .Split(ctx->outputs(), i)\n        .Broadcast(broadcast_args)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> LayerNormGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  CHECK_EQ_OR_RETURN(dy.data_type(), x.data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(x.data_type()) << \", but got \"\n      << DataType_Name(dy.data_type());\n  const user_op::TensorDesc& mean = ctx->InputTensorDesc(\"mean\", 0);\n  const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc(\"inv_variance\", 0);\n  DataType bn_param_data_type = InferBnParamDataType(x.data_type());\n  CHECK_EQ_OR_RETURN(mean.data_type(), bn_param_data_type)\n      << \"InferDataType Failed. Expected \" << DataType_Name(bn_param_data_type) << \", but got \"\n      << DataType_Name(mean.data_type());\n  CHECK_EQ_OR_RETURN(inv_variance.data_type(), bn_param_data_type)\n      << \"InferDataType Failed. Expected \" << DataType_Name(bn_param_data_type) << \", but got \"\n      << DataType_Name(inv_variance.data_type());\n  user_op::TensorDesc* dx = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx->set_data_type(dy.data_type());\n  if (ctx->has_input(\"_add_to_output\", 0)) {\n    const auto& add_to_output = ctx->InputTensorDesc(\"_add_to_output\", 0);\n    CHECK_EQ_OR_RETURN(add_to_output.data_type(), dx->data_type())\n        << \"InferDataType Failed. Expected \" << DataType_Name(dx->data_type()) << \", but got \"\n        << DataType_Name(add_to_output.data_type());\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> LayerNormParamGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  // TODO: tsai: replace lambda with user op if\n  auto has_tensor = [ctx](const std::string& bn) -> bool {\n    bool ret = false;\n    for (const auto& t : ctx->inputs()) {\n      if (bn == t.first) { return true; }\n    }\n    for (const auto& t : ctx->outputs()) {\n      if (bn == t.first) { return true; }\n    }\n    return ret;\n  };\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  const int64_t begin_params_axis = ctx->Attr<int64_t>(\"begin_params_axis\");\n  const bool has_beta_diff = has_tensor(\"beta_diff\");\n  const bool has_gamma_diff = has_tensor(\"gamma_diff\");\n  CHECK_GE_OR_RETURN(begin_params_axis, 1);\n  CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes());\n  DimVector param_shape_dim_vec;\n  param_shape_dim_vec.insert(param_shape_dim_vec.end(),\n                             dy.shape().dim_vec().cbegin() + begin_params_axis,\n                             dy.shape().dim_vec().cend());\n  const Shape param_shape(param_shape_dim_vec);\n  if (has_beta_diff) {\n    user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc(\"beta_diff\", 0);\n    beta_diff->set_shape(param_shape);\n  }\n  if (has_gamma_diff) {\n    user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc(\"gamma_diff\", 0);\n    gamma_diff->set_shape(param_shape);\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> LayerNormParamGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> 
LayerNormParamGradOp::GetSbp(user_op::SbpContext* ctx) {\n  int64_t begin_params_axis = ctx->Attr<int64_t>(\"begin_params_axis\");\n  for (int i = 0; i < begin_params_axis; ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(ctx->outputs()).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> LayerNormParamGradOp::InferDataType(user_op::InferContext* ctx) {\n  auto has_tensor = [ctx](const std::string& bn) -> bool {\n    bool ret = false;\n    for (auto& t : ctx->inputs()) {\n      if (bn == t.first) { return true; }\n    }\n    for (auto& t : ctx->outputs()) {\n      if (bn == t.first) { return true; }\n    }\n    return ret;\n  };\n  const bool has_beta_diff = has_tensor(\"beta_diff\");\n  const bool has_gamma_diff = has_tensor(\"gamma_diff\");\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  if (has_beta_diff) {\n    user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc(\"beta_diff\", 0);\n    beta_diff->set_data_type(dy.data_type());\n  }\n  if (has_gamma_diff) {\n    user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc(\"gamma_diff\", 0);\n    gamma_diff->set_data_type(dy.data_type());\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FuseLayerNormGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& mean = ctx->InputTensorDesc(\"mean\", 0);\n  const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc(\"inv_variance\", 0);\n  user_op::TensorDesc* dx = ctx->MutOutputTensorDesc(\"dx\", 0);\n  CHECK_EQ_OR_RETURN(dy.shape(), x.shape()) << \"dy and x shapes should be equal.\";\n  const int64_t begin_norm_axis = ctx->Attr<int64_t>(\"begin_norm_axis\");\n  CHECK_GT_OR_RETURN(begin_norm_axis, 0) << \"begin_norm_axis must be greater than 0.\";\n  const Shape& bn_param_shape = InferBnParamShape(x.shape(), begin_norm_axis);\n  CHECK_EQ_OR_RETURN(mean.shape(), bn_param_shape) << \"mean shape must match bn_param_shape.\";\n  CHECK_EQ_OR_RETURN(inv_variance.shape(), bn_param_shape)\n      << \"inv_variance shape must match bn_param_shape.\";\n  dx->set_shape(dy.shape());\n  dx->set_is_dynamic(dy.is_dynamic());\n  if (ctx->has_input(\"_add_to_output\", 0)) {\n    const auto& add_to_output = ctx->InputTensorDesc(\"_add_to_output\", 0);\n    CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape())\n        << \"add_to_output shape must match dx shape.\";\n  }\n\n  auto has_tensor = [ctx](const std::string& bn) -> bool {\n    bool ret = false;\n    for (const auto& t : ctx->inputs()) {\n      if (bn == t.first) { return true; }\n    }\n    for (const auto& t : ctx->outputs()) {\n      if (bn == t.first) { return true; }\n    }\n    return ret;\n  };\n  const int64_t begin_params_axis = ctx->Attr<int64_t>(\"begin_params_axis\");\n  const bool has_beta_diff = has_tensor(\"beta_diff\");\n  const bool has_gamma_diff = has_tensor(\"gamma_diff\");\n  CHECK_GE_OR_RETURN(begin_params_axis, 1)\n      << \"begin_params_axis must be greater than or equal to 1.\";\n  CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes())\n      << \"begin_params_axis must be less than the number of axes in dy shape.\";\n  DimVector param_shape_dim_vec;\n  param_shape_dim_vec.insert(param_shape_dim_vec.end(),\n                             dy.shape().dim_vec().cbegin() + begin_params_axis,\n                             dy.shape().dim_vec().cend());\n  const Shape 
param_shape(param_shape_dim_vec);\n  if (has_beta_diff) {\n    user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc(\"beta_diff\", 0);\n    beta_diff->set_shape(param_shape);\n  }\n  if (has_gamma_diff) {\n    user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc(\"gamma_diff\", 0);\n    gamma_diff->set_shape(param_shape);\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FuseLayerNormGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FuseLayerNormGradOp::GetSbp(user_op::SbpContext* ctx) {\n  std::vector<user_op::OpArg> broadcast_args;\n  if (ctx->user_op_conf().has_input(\"gamma\", 0)) { broadcast_args.emplace_back(\"gamma\", 0); }\n  int64_t begin_norm_axis = ctx->Attr<int64_t>(\"begin_norm_axis\");\n  int64_t begin_params_axis = ctx->Attr<int64_t>(\"begin_params_axis\");\n  CHECK_EQ(begin_norm_axis, begin_params_axis)\n      << \"begin_norm_axis and begin_params_axis must be equal, but got \" << begin_norm_axis\n      << \" and \" << begin_params_axis;\n  for (int i = 0; i < begin_norm_axis; ++i) {\n    ctx->NewBuilder()\n        .Split(ctx->inputs(), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .PartialSum(user_op::OpArg(\"gamma_diff\", 0))\n        .PartialSum(user_op::OpArg(\"beta_diff\", 0))\n        .Broadcast(broadcast_args)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FuseLayerNormGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  CHECK_EQ_OR_RETURN(dy.data_type(), x.data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(x.data_type()) << \", but got \"\n      << DataType_Name(dy.data_type());\n  const user_op::TensorDesc& mean = ctx->InputTensorDesc(\"mean\", 0);\n  const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc(\"inv_variance\", 0);\n  DataType bn_param_data_type = InferBnParamDataType(x.data_type());\n  CHECK_EQ_OR_RETURN(mean.data_type(), bn_param_data_type)\n      << \"InferDataType Failed. Expected \" << DataType_Name(bn_param_data_type) << \", but got \"\n      << DataType_Name(mean.data_type());\n  CHECK_EQ_OR_RETURN(inv_variance.data_type(), bn_param_data_type)\n      << \"InferDataType Failed. Expected \" << DataType_Name(bn_param_data_type) << \", but got \"\n      << DataType_Name(inv_variance.data_type());\n  user_op::TensorDesc* dx = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx->set_data_type(dy.data_type());\n  if (ctx->has_input(\"_add_to_output\", 0)) {\n    const auto& add_to_output = ctx->InputTensorDesc(\"_add_to_output\", 0);\n    CHECK_EQ_OR_RETURN(add_to_output.data_type(), dx->data_type())\n        << \"InferDataType Failed. 
Expected \" << DataType_Name(dx->data_type()) << \", but got \"\n        << DataType_Name(add_to_output.data_type());\n  }\n\n  auto has_tensor = [ctx](const std::string& bn) -> bool {\n    bool ret = false;\n    for (auto& t : ctx->inputs()) {\n      if (bn == t.first) { return true; }\n    }\n    for (auto& t : ctx->outputs()) {\n      if (bn == t.first) { return true; }\n    }\n    return ret;\n  };\n  const bool has_beta_diff = has_tensor(\"beta_diff\");\n  const bool has_gamma_diff = has_tensor(\"gamma_diff\");\n  if (has_beta_diff) {\n    user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc(\"beta_diff\", 0);\n    beta_diff->set_data_type(dy.data_type());\n  }\n  if (has_gamma_diff) {\n    user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc(\"gamma_diff\", 0);\n    gamma_diff->set_data_type(dy.data_type());\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/leaky_relu_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> LeakyReluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"y\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> LeakyReluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> LeakyReluOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Split(user_op::OpArg(\"y\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> LeakyReluOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> LeakyReluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == x_shape);\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> LeakyReluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> LeakyReluGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"x\", 0))\n      .PartialSum(user_op::OpArg(\"dy\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> LeakyReluGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"x\", 0), ctx->InputDType(\"dy\", 0))\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"dy\", 0))\n      << \", but got \" << DataType_Name(ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/lerp_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nMaybe<void> LerpOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& start = ctx->InputTensorDesc(\"start\", 0);\n  const user_op::TensorDesc& end = ctx->InputTensorDesc(\"end\", 0);\n  const user_op::TensorDesc& weight = ctx->InputTensorDesc(\"weight\", 0);\n\n  CHECK_EQ_OR_RETURN(start.shape(), end.shape())\n      << \"The size of tensor start\" << start.shape() << \"must match the size of tensor end\"\n      << end.shape();\n  if (weight.shape().elem_cnt() != 1) {\n    CHECK_EQ_OR_RETURN(start.shape(), weight.shape())\n        << \"The size of tensor start\" << start.shape() << \"must match the size of tensor weight\"\n        << weight.shape();\n  }\n\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_is_dynamic(start.is_dynamic());\n  out->set_shape(start.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LerpOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return LerpOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> LerpOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& start = ctx->InputTensorDesc(\"start\", 0);\n  const user_op::TensorDesc& end = ctx->InputTensorDesc(\"end\", 0);\n  const user_op::TensorDesc& weight = ctx->InputTensorDesc(\"weight\", 0);\n\n  CHECK_EQ_OR_RETURN(start.data_type(), end.data_type())\n      << Error::RuntimeError() << \"expected dtype \" << start.data_type()\n      << \" for `end` but got dtype \" << end.data_type();\n  CHECK_EQ_OR_RETURN(start.data_type(), weight.data_type())\n      << Error::RuntimeError() << \"expected dtype \" << start.data_type()\n      << \" for `weight` but got dtype \" << weight.data_type();\n\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_data_type(start.data_type());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LerpOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& start = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"start\", 0);\n  FOR_RANGE(int64_t, i, 0, start.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"start\", 0), i)\n        .Split(user_op::OpArg(\"end\", 0), i)\n        .Split(user_op::OpArg(\"weight\", 0), i)\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LerpGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& start = ctx->InputTensorDesc(\"start\", 0);\n  const user_op::TensorDesc& end = ctx->InputTensorDesc(\"end\", 0);\n  const user_op::TensorDesc& weight = ctx->InputTensorDesc(\"weight\", 0);\n  const user_op::TensorDesc& out_diff = ctx->InputTensorDesc(\"out_diff\", 0);\n\n  CHECK_EQ_OR_RETURN(start.shape(), end.shape())\n      << \"The size of tensor start\" << start.shape() 
<< \"must match the size of tensor end\"\n      << end.shape();\n  CHECK_EQ_OR_RETURN(start.shape(), weight.shape())\n      << \"The size of tensor start\" << start.shape() << \"must match the size of tensor weight\"\n      << weight.shape();\n  CHECK_EQ_OR_RETURN(start.shape(), out_diff.shape())\n      << \"The size of tensor start\" << start.shape() << \"must match the size of tensor out_diff\"\n      << out_diff.shape();\n\n  user_op::TensorDesc* start_diff = ctx->MutOutputTensorDesc(\"start_diff\", 0);\n  user_op::TensorDesc* end_diff = ctx->MutOutputTensorDesc(\"end_diff\", 0);\n  user_op::TensorDesc* weight_diff = ctx->MutOutputTensorDesc(\"weight_diff\", 0);\n  start_diff->set_is_dynamic(start.is_dynamic());\n  start_diff->set_shape(start.shape());\n\n  end_diff->set_is_dynamic(end.is_dynamic());\n  end_diff->set_shape(end.shape());\n\n  weight_diff->set_is_dynamic(weight.is_dynamic());\n  weight_diff->set_shape(weight.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LerpGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return LerpGradOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> LerpGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& start = ctx->InputTensorDesc(\"start\", 0);\n  const user_op::TensorDesc& end = ctx->InputTensorDesc(\"end\", 0);\n  const user_op::TensorDesc& weight = ctx->InputTensorDesc(\"weight\", 0);\n  const user_op::TensorDesc& out_diff = ctx->InputTensorDesc(\"out_diff\", 0);\n\n  CHECK_EQ_OR_RETURN(start.data_type(), end.data_type())\n      << Error::RuntimeError() << \"expected dtype \" << start.data_type()\n      << \" for `end` but got dtype \" << end.data_type();\n  CHECK_EQ_OR_RETURN(start.data_type(), weight.data_type())\n      << Error::RuntimeError() << \"expected dtype \" << start.data_type()\n      << \" for `weight` but got dtype \" << weight.data_type();\n  CHECK_EQ_OR_RETURN(start.data_type(), out_diff.data_type())\n      << Error::RuntimeError() << \"expected dtype \" << start.data_type()\n      << \" for `out_diff` but got dtype \" << out_diff.data_type();\n\n  user_op::TensorDesc* start_diff = ctx->MutOutputTensorDesc(\"start_diff\", 0);\n  user_op::TensorDesc* end_diff = ctx->MutOutputTensorDesc(\"end_diff\", 0);\n  user_op::TensorDesc* weight_diff = ctx->MutOutputTensorDesc(\"weight_diff\", 0);\n\n  start_diff->set_data_type(start.data_type());\n  end_diff->set_data_type(end.data_type());\n  weight_diff->set_data_type(weight.data_type());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LerpGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& start = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"start\", 0);\n  FOR_RANGE(int64_t, i, 0, start.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"start\", 0), i)\n        .Split(user_op::OpArg(\"end\", 0), i)\n        .Split(user_op::OpArg(\"weight\", 0), i)\n        .Split(user_op::OpArg(\"out_diff\", 0), i)\n        .Split(user_op::OpArg(\"start_diff\", 0), i)\n        .Split(user_op::OpArg(\"end_diff\", 0), i)\n        .Split(user_op::OpArg(\"weight_diff\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ScalarLerpOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& start = ctx->InputTensorDesc(\"start\", 0);\n  const user_op::TensorDesc& end = ctx->InputTensorDesc(\"end\", 0);\n\n  CHECK_EQ_OR_RETURN(start.shape(), end.shape())\n      << \"The size of tensor start\" << start.shape() << \"must match the size of tensor end\"\n  
    << end.shape();\n\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_is_dynamic(start.is_dynamic());\n  out->set_shape(start.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ScalarLerpOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return ScalarLerpOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> ScalarLerpOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& start = ctx->InputTensorDesc(\"start\", 0);\n  const user_op::TensorDesc& end = ctx->InputTensorDesc(\"end\", 0);\n\n  CHECK_EQ_OR_RETURN(start.data_type(), end.data_type())\n      << Error::RuntimeError() << \"expected dtype \" << start.data_type()\n      << \" for `end` but got dtype \" << end.data_type();\n\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_data_type(start.data_type());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ScalarLerpOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& start = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"start\", 0);\n  FOR_RANGE(int64_t, i, 0, start.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"start\", 0), i)\n        .Split(user_op::OpArg(\"end\", 0), i)\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ScalarLerpGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& start = ctx->InputTensorDesc(\"start\", 0);\n  const user_op::TensorDesc& end = ctx->InputTensorDesc(\"end\", 0);\n  const user_op::TensorDesc& out_diff = ctx->InputTensorDesc(\"out_diff\", 0);\n\n  CHECK_EQ_OR_RETURN(start.shape(), end.shape())\n      << \"The size of tensor start\" << start.shape() << \"must match the size of tensor end\"\n      << end.shape();\n  CHECK_EQ_OR_RETURN(start.shape(), out_diff.shape())\n      << \"The size of tensor start\" << start.shape() << \"must match the size of tensor out_diff\"\n      << out_diff.shape();\n\n  user_op::TensorDesc* start_diff = ctx->MutOutputTensorDesc(\"start_diff\", 0);\n  user_op::TensorDesc* end_diff = ctx->MutOutputTensorDesc(\"end_diff\", 0);\n  start_diff->set_is_dynamic(start.is_dynamic());\n  start_diff->set_shape(start.shape());\n\n  end_diff->set_is_dynamic(start.is_dynamic());\n  end_diff->set_shape(start.shape());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ScalarLerpGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return ScalarLerpGradOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> ScalarLerpGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& start = ctx->InputTensorDesc(\"start\", 0);\n  const user_op::TensorDesc& end = ctx->InputTensorDesc(\"end\", 0);\n  const user_op::TensorDesc& out_diff = ctx->InputTensorDesc(\"out_diff\", 0);\n\n  CHECK_EQ_OR_RETURN(start.data_type(), end.data_type())\n      << Error::RuntimeError() << \"expected dtype \" << start.data_type()\n      << \" for `end` but got dtype \" << end.data_type();\n  CHECK_EQ_OR_RETURN(start.data_type(), out_diff.data_type())\n      << Error::RuntimeError() << \"expected dtype \" << start.data_type()\n      << \" for `out_diff` but got dtype \" << out_diff.data_type();\n\n  user_op::TensorDesc* start_diff = ctx->MutOutputTensorDesc(\"start_diff\", 0);\n  user_op::TensorDesc* end_diff = ctx->MutOutputTensorDesc(\"end_diff\", 0);\n  start_diff->set_data_type(start.data_type());\n  end_diff->set_data_type(start.data_type());\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> 
ScalarLerpGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& start = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"start\", 0);\n  FOR_RANGE(int64_t, i, 0, start.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"start\", 0), i)\n        .Split(user_op::OpArg(\"end\", 0), i)\n        .Split(user_op::OpArg(\"out_diff\", 0), i)\n        .Split(user_op::OpArg(\"start_diff\", 0), i)\n        .Split(user_op::OpArg(\"end_diff\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/linalg_cross_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> LinalgCrossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> LinalgCrossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> LinalgCrossOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& input = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"input\", 0);\n  const int64_t num_axes = input.shape().NumAxes();\n  const int64_t dim = ctx->Attr<int64_t>(\"dim\");\n\n  FOR_RANGE(int64_t, i, 0, num_axes) {\n    if (i == dim) continue;\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"input\", 0), i)\n        .Split(user_op::OpArg(\"other\", 0), i)\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> LinalgCrossOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/user/ops/log_softmax_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> LogSoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"prob\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> LogSoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> LogSoftmaxOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes() - 1) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), axis)\n        .Split(user_op::OpArg(\"prob\", 0), axis)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> LogSoftmaxOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"prob\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> LogSoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& y_shape = ctx->InputShape(\"prob\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == y_shape);\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> LogSoftmaxGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> LogSoftmaxGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"prob\", 0);\n  FOR_RANGE(int64_t, axis, 0, y_tensor.shape().NumAxes() - 1) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"prob\", 0), axis)\n        .Split(user_op::OpArg(\"dy\", 0), axis)\n        .Split(user_op::OpArg(\"dx\", 0), axis)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> LogSoftmaxGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"prob\", 0), ctx->InputDType(\"dy\", 0))\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"dy\", 0))\n      << \", but got \" << DataType_Name(ctx->InputDType(\"prob\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"prob\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/logical_not_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferDataTypeLogicalNot(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, DataType::kBool);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> LogicalNotOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return user_op::TensorDescInferFnUtil::Unchanged(ctx);\n}\n\n/*static*/ Maybe<void> LogicalNotOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> LogicalNotOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);\n}\n\n/* static */ Maybe<void> LogicalNotOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataTypeLogicalNot(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/loss_op_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/ops/loss_op_util.h\"\n#include \"oneflow/core/common/just.h\"\n\nnamespace oneflow {\n\nuser_op::GetSbpFn GenLossForwardDefaultGetSbpFn(\n    const std::function<void(user_op::UserOpSbpSignatureBuilder& builder,\n                             user_op::SbpContext* ctx)>& f) {\n  return [=](user_op::SbpContext* ctx) -> Maybe<void> {\n    auto builder = ctx->NewBuilder()\n                       .Split(user_op::OpArg(\"input\", 0), 0)\n                       .Split(user_op::OpArg(\"target\", 0), 0)\n                       .Split(user_op::OpArg(\"out\", 0), 0);\n    if (ctx->user_op_conf().has_input(\"weight\", 0)) {\n      builder.Split(user_op::OpArg(\"weight\", 0), 0);\n    }\n    f(builder, ctx);\n    builder.Build();\n    return Maybe<void>::Ok();\n  };\n}\n\nuser_op::GetSbpFn GenLossBackwardDefaultGetSbpFn(\n    const std::function<void(user_op::UserOpSbpSignatureBuilder& builder,\n                             user_op::SbpContext* ctx)>& f) {\n  return [=](user_op::SbpContext* ctx) -> Maybe<void> {\n    auto builder = ctx->NewBuilder()\n                       .Split(user_op::OpArg(\"input\", 0), 0)\n                       .Split(user_op::OpArg(\"target\", 0), 0)\n                       .Split(user_op::OpArg(\"dx\", 0), 0)\n                       .Split(user_op::OpArg(\"dy\", 0), 0);\n    if (ctx->user_op_conf().has_input(\"weight\", 0)) {\n      builder.Split(user_op::OpArg(\"weight\", 0), 0);\n    }\n    f(builder, ctx);\n    builder.Build();\n    return Maybe<void>::Ok();\n  };\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/loss_op_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_OPS_LOSS_OP_UTIL_H_\n#define ONEFLOW_USER_OPS_LOSS_OP_UTIL_H_\n\n#include <functional>\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nuser_op::GetSbpFn GenLossForwardDefaultGetSbpFn(\n    const std::function<void(user_op::UserOpSbpSignatureBuilder& builder,\n                             user_op::SbpContext* ctx)>& f =\n        [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) {});\n\nuser_op::GetSbpFn GenLossBackwardDefaultGetSbpFn(\n    const std::function<void(user_op::UserOpSbpSignatureBuilder& builder,\n                             user_op::SbpContext* ctx)>& f =\n        [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) {});\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_OPS_LOSS_OP_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/ops/lu_composition_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> LUDecompositionOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const auto& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  auto x_shape = x_desc.shape();\n  ctx->SetOutputShape(\"pivot\", 0, Shape(x_shape.begin(), x_shape.end() - 1));\n  ctx->SetOutputShape(\"LU\", 0, x_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> LUDecompositionOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> LUDecompositionOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x.shape().NumAxes() - 2) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"LU\", 0), i)\n        .Split(user_op::OpArg(\"pivot\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> LUDecompositionOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"LU\", 0, ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"pivot\", 0, DataType::kInt32);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/masked_fill_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferMaskedFillTensorDesc(user_op::InferContext* ctx) {\n  const Shape& mask_shape = ctx->InputShape(\"mask\", 0);\n  ctx->SetOutputShape(\"out\", 0, mask_shape);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferMaskedFillDataType(user_op::InferContext* ctx) {\n  DataType mask_dtype = ctx->InputDType(\"mask\", 0);\n  CHECK_OR_RETURN(IsIntegralDataType(mask_dtype) || IsBoolDataType(mask_dtype))\n      << \" mask type must be integral or bool\";\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetMaskedFillSbpSignatures(user_op::SbpContext* ctx) {\n  const Shape& mask_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"mask\", 0).shape();\n  const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape();\n  FOR_RANGE(int64_t, i, 0, mask_shape.NumAxes()) {\n    if (mask_shape.At(i) == 1 && x_shape.At(i) == 1) { continue; }\n    if (mask_shape.At(i) == x_shape.At(i)) {\n      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n    } else if (mask_shape.At(i) == 1) {\n      ctx->NewBuilder()\n          .Broadcast(user_op::OpArg(\"mask\", 0))\n          .Split(user_op::OpArg(\"x\", 0), i)\n          .Split(ctx->outputs(), i)\n          .Build();\n    } else if (x_shape.At(i) == 1) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"mask\", 0), i)\n          .Broadcast(user_op::OpArg(\"x\", 0))\n          .Split(ctx->outputs(), i)\n          .Build();\n    } else {\n      UNIMPLEMENTED();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetMaskedFillInputArgModify(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                        const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* mask_arg_modifier = GetInputArgModifierFn(\"mask\", 0);\n  mask_arg_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> MaskedFillOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferMaskedFillTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> MaskedFillOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MaskedFillOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetMaskedFillSbpSignatures(ctx);\n}\n\n/* static */ Maybe<void> MaskedFillOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return GetMaskedFillInputArgModify(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> MaskedFillOp::InferDataType(user_op::InferContext* ctx) {\n  return InferMaskedFillDataType(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/math_binary_broadcast_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\n#include \"oneflow/user/ops/math_binary_broadcast_seq.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nbool IsScalarTensor(const user_op::TensorDesc* tensor) {\n  return tensor->shape().NumAxes() == 1 && tensor->shape().At(0) == 1;\n}\n\nbool IsZeroDimTensor(const user_op::TensorDesc* tensor) { return tensor->shape().NumAxes() == 0; }\n\nMaybe<void> InferTensorDescBinaryBroadcastNormal(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& tensor_x = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& tensor_y = ctx->InputTensorDesc(\"y\", 0);\n  user_op::TensorDesc* tensor_z = ctx->MutOutputTensorDesc(\"z\", 0);\n\n  size_t output_num_axes = std::max(tensor_x.shape().NumAxes(), tensor_y.shape().NumAxes());\n  if (IsZeroDimTensor(&tensor_x)) {\n    ctx->SetOutputShape(\"z\", 0, ctx->InputShape(\"y\", 0));\n    ctx->SetOutputIsDynamic(\"z\", 0, ctx->InputIsDynamic(\"y\", 0));\n  } else if (IsZeroDimTensor(&tensor_y)) {\n    ctx->SetOutputShape(\"z\", 0, ctx->InputShape(\"x\", 0));\n    ctx->SetOutputIsDynamic(\"z\", 0, ctx->InputIsDynamic(\"x\", 0));\n  } else if (IsScalarTensor(&tensor_x)) {\n    ctx->SetOutputShape(\"z\", 0, ctx->InputShape(\"y\", 0));\n    ctx->SetOutputIsDynamic(\"z\", 0, ctx->InputIsDynamic(\"y\", 0));\n  } else if (IsScalarTensor(&tensor_y)) {\n    ctx->SetOutputShape(\"z\", 0, ctx->InputShape(\"x\", 0));\n    ctx->SetOutputIsDynamic(\"z\", 0, ctx->InputIsDynamic(\"x\", 0));\n  } else {\n    const auto& x_shape = CreateLeftExtendedShape(ShapeView(tensor_x.shape()), output_num_axes);\n    const auto& y_shape = CreateLeftExtendedShape(ShapeView(tensor_y.shape()), output_num_axes);\n    ctx->SetOutputShape(\"z\", 0, ctx->InputShape(\"x\", 0));\n    ctx->SetOutputIsDynamic(\"z\", 0, ctx->InputIsDynamic(\"x\", 0));\n    Shape out_shape(x_shape);\n    FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) {\n      if (x_shape.At(i) != 1 && y_shape.At(i) != 1 && x_shape.At(i) != y_shape.At(i)) {\n        return Error::RuntimeError()\n               << \"The size of tensor a (\" << x_shape.At(i) << \") must match the size of tensor b (\"\n               << y_shape.At(i) << \") at non-singleton dimension \" << i;\n      }\n      out_shape.Set(i, (x_shape.At(i) == 0 || y_shape.At(i) == 0)\n                           ? 
0\n                           : std::max(x_shape.At(i), y_shape.At(i)));\n    }\n    tensor_z->set_shape(out_shape);\n  }\n  tensor_z->set_is_dynamic(tensor_x.is_dynamic() || tensor_y.is_dynamic());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferTensorDescBinaryBroadcastLogical(user_op::InferContext* ctx) {\n  return InferTensorDescBinaryBroadcastNormal(ctx);\n}\n\nMaybe<void> InferDataTypeBinaryBroadcastNormal(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& tensor_x = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& tensor_y = ctx->InputTensorDesc(\"y\", 0);\n  CHECK_EQ_OR_RETURN(tensor_x.data_type(), tensor_y.data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(tensor_x.data_type()) << \", but got \"\n      << DataType_Name(tensor_y.data_type());\n  ctx->SetOutputDType(\"z\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataTypeBinaryBroadcastLogical(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& tensor_x = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& tensor_y = ctx->InputTensorDesc(\"y\", 0);\n  CHECK_EQ_OR_RETURN(tensor_x.data_type(), tensor_y.data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(tensor_x.data_type()) << \", but got \"\n      << DataType_Name(tensor_y.data_type());\n  ctx->SetOutputDType(\"z\", 0, DataType::kBool);\n  return Maybe<void>::Ok();\n}\n\ntemplate<template<typename> class binary_func>\nvoid GenPartialSbpSign(user_op::SbpContext* ctx) {}\n\ntemplate<>\nvoid GenPartialSbpSign<BinaryFuncAdd>(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"x\", 0))\n      .PartialSum(user_op::OpArg(\"y\", 0))\n      .PartialSum(user_op::OpArg(\"z\", 0))\n      .Build();\n}\n\ntemplate<>\nvoid GenPartialSbpSign<BinaryFuncNanSum>(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"x\", 0))\n      .PartialSum(user_op::OpArg(\"y\", 0))\n      .PartialSum(user_op::OpArg(\"z\", 0))\n      .Build();\n}\n\ntemplate<>\nvoid GenPartialSbpSign<BinaryFuncSub>(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"x\", 0))\n      .PartialSum(user_op::OpArg(\"y\", 0))\n      .PartialSum(user_op::OpArg(\"z\", 0))\n      .Build();\n}\n\ntemplate<>\nvoid GenPartialSbpSign<BinaryFuncMul>(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"x\", 0))\n      .PartialSum(user_op::OpArg(\"y\", 0))\n      .PartialSum(user_op::OpArg(\"z\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"x\", 0))\n      .Broadcast(user_op::OpArg(\"y\", 0))\n      .PartialSum(user_op::OpArg(\"z\", 0))\n      .Build();\n}\n\ntemplate<>\nvoid GenPartialSbpSign<BinaryFuncDiv>(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"x\", 0))\n      .Broadcast(user_op::OpArg(\"y\", 0))\n      .PartialSum(user_op::OpArg(\"z\", 0))\n      .Build();\n}\n\ntemplate<template<typename> class binary_func>\nMaybe<void> GetBinaryBroadcastSbpSignature(user_op::SbpContext* ctx) {\n  const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape();\n  const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"y\", 0).shape();\n  if (x_shape.NumAxes() < y_shape.NumAxes()) {\n    FOR_RANGE(int64_t, i, 0, y_shape.NumAxes() - x_shape.NumAxes()) {\n      ctx->NewBuilder()\n          .Broadcast(user_op::OpArg(\"x\", 0))\n          .Split(user_op::OpArg(\"y\", 0), i)\n          
.Split(user_op::OpArg(\"z\", 0), i)\n          .Build();\n    }\n    FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), x_shape.NumAxes() - 1 - i)\n          .Split(user_op::OpArg(\"y\", 0), y_shape.NumAxes() - 1 - i)\n          .Split(ctx->outputs(), y_shape.NumAxes() - 1 - i)\n          .Build();\n    }\n  } else if (x_shape.NumAxes() > y_shape.NumAxes()) {\n    FOR_RANGE(int64_t, i, 0, x_shape.NumAxes() - y_shape.NumAxes()) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), i)\n          .Broadcast(user_op::OpArg(\"y\", 0))\n          .Split(user_op::OpArg(\"z\", 0), i)\n          .Build();\n    }\n    FOR_RANGE(int64_t, i, 0, y_shape.NumAxes()) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), x_shape.NumAxes() - 1 - i)\n          .Split(user_op::OpArg(\"y\", 0), y_shape.NumAxes() - 1 - i)\n          .Split(ctx->outputs(), x_shape.NumAxes() - 1 - i)\n          .Build();\n    }\n  } else {\n    FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) {\n      if (x_shape.At(i) == 1 && y_shape.At(i) == 1) { continue; }\n      if (x_shape.At(i) == y_shape.At(i)) {\n        ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n      } else if (x_shape.At(i) == 1) {\n        ctx->NewBuilder()\n            .Broadcast(user_op::OpArg(\"x\", 0))\n            .Split(user_op::OpArg(\"y\", 0), i)\n            .Split(ctx->outputs(), i)\n            .Build();\n      } else if (y_shape.At(i) == 1) {\n        ctx->NewBuilder()\n            .Split(user_op::OpArg(\"x\", 0), i)\n            .Broadcast(user_op::OpArg(\"y\", 0))\n            .Split(ctx->outputs(), i)\n            .Build();\n      } else {\n        UNIMPLEMENTED();\n      }\n    }\n  }\n  GenPartialSbpSign<binary_func>(ctx);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n#define REGISTER_BINARY_BROADCAST_NORMAL_USER_OP(op_name, suffix)                        \\\n  /* static */ Maybe<void> op_name::InferLogicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferTensorDescBinaryBroadcastNormal(ctx);                                    \\\n  }                                                                                      \\\n  /*static*/ Maybe<void> op_name::InferPhysicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return InferLogicalTensorDesc(ctx);                                                  \\\n  }                                                                                      \\\n  /* static */ Maybe<void> op_name::GetSbp(user_op::SbpContext* ctx) {                   \\\n    return GetBinaryBroadcastSbpSignature<BinaryFunc##suffix>(ctx);                      \\\n  }                                                                                      \\\n  /* static */ Maybe<void> op_name::InferDataType(user_op::InferContext* ctx) {          \\\n    return InferDataTypeBinaryBroadcastNormal(ctx);                                      \\\n  }\n\n#define REGISTER_BINARY_BROADCAST_LOGICAL_USER_OP(op_name, suffix)                       \\\n  /* static */ Maybe<void> op_name::InferLogicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferTensorDescBinaryBroadcastLogical(ctx);                                   \\\n  }                                                                                      \\\n  /*static*/ Maybe<void> op_name::InferPhysicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return InferLogicalTensorDesc(ctx);                                                  \\\n  
}                                                                                      \\\n  /* static */ Maybe<void> op_name::GetSbp(user_op::SbpContext* ctx) {                   \\\n    return GetBinaryBroadcastSbpSignature<BinaryFunc##suffix>(ctx);                      \\\n  }                                                                                      \\\n  /* static */ Maybe<void> op_name::InferDataType(user_op::InferContext* ctx) {          \\\n    return InferDataTypeBinaryBroadcastLogical(ctx);                                     \\\n  }\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_BINARY_BROADCAST_NORMAL_USER_OP, MATH_BINARY_BROADCAST_FUNC_SEQ_ODS)\nOF_PP_FOR_EACH_TUPLE(REGISTER_BINARY_BROADCAST_LOGICAL_USER_OP,\n                     MATH_BINARY_BROADCAST_LOGICAL_FUNC_SEQ_ODS)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/math_binary_broadcast_seq.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_OPS_MATH_BINARY_BROADCAST_SEQ_H_\n#define ONEFLOW_USER_OPS_MATH_BINARY_BROADCAST_SEQ_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\n#define MATH_BINARY_BROADCAST_FUNC_SEQ                      \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_add\", Add)                \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_sub\", Sub)                \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_mul\", Mul)                \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_div\", Div)                \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_minimum\", Min)            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_maximum\", Max)            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_bitwise_and\", BitwiseAnd) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_bitwise_or\", BitwiseOr)   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_bitwise_xor\", BitwiseXor) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_floor_mod\", FloorMod)     \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_fmod\", FMod)              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_pow\", Pow)                \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_zeta\", Zeta)\n\n#define MATH_BINARY_BROADCAST_LOGICAL_FUNC_SEQ          \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_equal\", EQ)           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_not_equal\", NE)       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_greater\", GT)         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_greater_equal\", GE)   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_less\", LT)            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_less_equal\", LE)      \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_logical_and\", AND)    \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_logical_or\", OR)      \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_logical_xor\", XOR)    \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_isclose_eq_nan\", IEN) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"broadcast_isclose_neq_nan\", INN)\n\n#define MATH_BINARY_BROADCAST_FUNC_SEQ_ODS                \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastAddOp, Add)               \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastSubOp, Sub)               \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastMulOp, Mul)               \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastDivOp, Div)               \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastMinimumOp, Min)           \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastMaximumOp, Max)           \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastBitwiseAndOp, BitwiseAnd) \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastBitwiseOrOp, BitwiseOr)   \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastBitwiseXorOp, BitwiseXor) \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastFloorModOp, FloorMod)     \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastFmodOp, FMod)             \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastPowOp, Pow)               \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastZetaOp, Zeta)\n\n#define MATH_BINARY_BROADCAST_LOGICAL_FUNC_SEQ_ODS      \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastEqualOp, EQ)            \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastNotEqualOp, NE)         \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastGreaterOp, GT)          \\\n  
OF_PP_MAKE_TUPLE_SEQ(BroadcastGreaterEqualOp, GE)     \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastLessOp, LT)             \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastLessEqualOp, LE)        \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastLogicalAndOp, AND)      \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastLogicalOrOp, OR)        \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastLogicalXorOp, XOR)      \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastIsCloseEqualNanOp, IEN) \\\n  OF_PP_MAKE_TUPLE_SEQ(BroadcastIsCloseNotEqualNanOp, INN)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_OPS_MATH_BINARY_BROADCAST_SEQ_H_\n"
  },
  {
    "path": "oneflow/user/ops/math_binary_elementwise_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/math_binary_elementwise_seq.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n#define MATH_ELEMENTWISE_DEFAULT_SET_FUNC(op_type)                                       \\\n  /* static */ Maybe<void> op_type::InferLogicalTensorDesc(user_op::InferContext* ctx) { \\\n    return user_op::TensorDescInferFnUtil::Unchanged(ctx);                               \\\n  }                                                                                      \\\n  /*static*/ Maybe<void> op_type::InferPhysicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return InferLogicalTensorDesc(ctx);                                                  \\\n  }                                                                                      \\\n  /* static */ Maybe<void> op_type::GetSbp(user_op::SbpContext* ctx) {                   \\\n    return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);                                 \\\n  }                                                                                      \\\n  /* static */ Maybe<void> op_type::InferDataType(user_op::InferContext* ctx) {          \\\n    return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx);                       \\\n  }\n\n#define REGISTER_MATH_BINARY_ELEMENTWISE_OP_AND_GRAD(math_binary_elementwise_type, func_prefix) \\\n  MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##Op);                                           \\\n                                                                                                \\\n  MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##XGradOp);                                      \\\n                                                                                                \\\n  MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##YGradOp);\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_MATH_BINARY_ELEMENTWISE_OP_AND_GRAD,\n                     MATH_BINARY_ELEMENTWISE_FUNC_SEQ_ODS)\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/math_binary_elementwise_seq.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_OPS_MATH_BINARY_ELEMENTWISE_SEQ_H_\n#define ONEFLOW_USER_OPS_MATH_BINARY_ELEMENTWISE_SEQ_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\n#define MATH_BINARY_ELEMENTWISE_FUNC_SEQ     \\\n  OF_PP_MAKE_TUPLE_SEQ(\"pow\", Pow)           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"atan2\", Atan2)       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"floordiv\", FloorDiv) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"truncdiv\", TruncDiv) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"xdivy\", Xdivy)       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"xlogy\", Xlogy)\n\n#define MATH_BINARY_ELEMENTWISE_FUNC_SEQ_ODS \\\n  OF_PP_MAKE_TUPLE_SEQ(\"pow\", Pow)           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"atan2\", Atan2)       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"floordiv\", Floordiv) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"truncdiv\", Truncdiv) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"xdivy\", Xdivy)       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"xlogy\", Xlogy)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_OPS_MATH_BINARY_ELEMENTWISE_SEQ_H_\n"
  },
  {
    "path": "oneflow/user/ops/math_unary_elementwise_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/math_unary_elementwise_seq.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n#define MATH_ELEMENTWISE_DEFAULT_SET_FUNC(op_type)                                       \\\n  /* static */ Maybe<void> op_type::InferLogicalTensorDesc(user_op::InferContext* ctx) { \\\n    return user_op::TensorDescInferFnUtil::Unchanged(ctx);                               \\\n  }                                                                                      \\\n  /*static*/ Maybe<void> op_type::InferPhysicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return InferLogicalTensorDesc(ctx);                                                  \\\n  }                                                                                      \\\n  /* static */ Maybe<void> op_type::GetSbp(user_op::SbpContext* ctx) {                   \\\n    return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);                                 \\\n  }                                                                                      \\\n  /* static */ Maybe<void> op_type::InferDataType(user_op::InferContext* ctx) {          \\\n    return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx);                       \\\n  }\n\n#define REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD_WITH_DY_X(math_unary_elementwise_type, \\\n                                                              func_prefix)                 \\\n  MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##Op)                                       \\\n  MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##GradOp)\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD_WITH_DY_X,\n                     MATH_UNARY_ELEMENTWISE_PRIMITIVE_FUNC_BWD_WITH_DY_X_SEQ)\n\n#define REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD_WITH_DY_Y(math_unary_elementwise_type, \\\n                                                              func_prefix)                 \\\n  MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##Op)                                       \\\n  MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##GradOp)\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD_WITH_DY_Y,\n                     MATH_UNARY_ELEMENTWISE_FUNC_BWD_WITH_DY_Y_SEQ)\n\n#define REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD_WITH_FILL(math_unary_elementwise_type, \\\n                                                              func_prefix)                 \\\n  MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##Op)\n\nOF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD_WITH_FILL,\n                     MATH_UNARY_ELEMENTWISE_FUNC_BWD_WITH_FILL_SEQ)\n\n// Negative's grad function = negative(dy), so here register negative op separately.\nMATH_ELEMENTWISE_DEFAULT_SET_FUNC(NegativeOp)\nMATH_ELEMENTWISE_DEFAULT_SET_FUNC(BitwiseNotOp)\nMATH_ELEMENTWISE_DEFAULT_SET_FUNC(TrigammaOp)\n}  
// namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/math_unary_elementwise_seq.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_OPS_MATH_UNARY_ELEMENTWISE_SEQ_H_\n#define ONEFLOW_USER_OPS_MATH_UNARY_ELEMENTWISE_SEQ_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\n#define MATH_UNARY_ELEMENTWISE_FUNC_SEQ                      \\\n  OF_PP_MAKE_TUPLE_SEQ(\"abs\", Abs)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"acos\", Acos)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"acosh\", Acosh)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"asin\", Asin)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"asinh\", Asinh)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"atan\", Atan)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"atanh\", Atanh)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"ceil\", Ceil)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"cos\", Cos)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"cosh\", Cosh)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"digamma\", Digamma)                   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"trigamma\", Trigamma)                 \\\n  OF_PP_MAKE_TUPLE_SEQ(\"erf\", Erf)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"erfc\", Erfc)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"exp\", Exp)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"exp2\", Exp2)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"expm1\", Expm1)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"floor\", Floor)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"lgamma\", Lgamma)                     \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log\", Log)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log2\", Log2)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log10\", Log10)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log1p\", Log1p)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log_sigmoid\", LogSigmoid)            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"negative\", Negative)                 \\\n  OF_PP_MAKE_TUPLE_SEQ(\"reciprocal\", Reciprocal)             \\\n  OF_PP_MAKE_TUPLE_SEQ(\"reciprocal_no_nan\", ReciprocalNoNan) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"rint\", Rint)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"round\", Round)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"rsqrt\", Rsqrt)                       \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sigmoid_v2\", Sigmoid)                \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sign\", Sign)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sin\", Sin)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sinh\", Sinh)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sqrt\", Sqrt)                         \\\n  OF_PP_MAKE_TUPLE_SEQ(\"square\", Square)                     \\\n  OF_PP_MAKE_TUPLE_SEQ(\"tan\", Tan)                           \\\n  OF_PP_MAKE_TUPLE_SEQ(\"not_equal_zero\", NotEqualZero)\n\n#define MATH_UNARY_ELEMENTWISE_PRIMITIVE_FUNC_BWD_WITH_DY_X_SEQ \\\n  OF_PP_MAKE_TUPLE_SEQ(\"abs\", Abs)                              \\\n  
OF_PP_MAKE_TUPLE_SEQ(\"acos\", Acos)                            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"acosh\", Acosh)                          \\\n  OF_PP_MAKE_TUPLE_SEQ(\"asin\", Asin)                            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"asinh\", Asinh)                          \\\n  OF_PP_MAKE_TUPLE_SEQ(\"atan\", Atan)                            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"atanh\", Atanh)                          \\\n  OF_PP_MAKE_TUPLE_SEQ(\"cos\", Cos)                              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"cosh\", Cosh)                            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"erf\", Erf)                              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"erfc\", Erfc)                            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"exp\", Exp)                              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"exp2\", Exp2)                            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"expm1\", Expm1)                          \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log\", Log)                              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"lgamma\", Lgamma)                        \\\n  OF_PP_MAKE_TUPLE_SEQ(\"digamma\", Digamma)                      \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log2\", Log2)                            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log10\", Log10)                          \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log1p\", Log1p)                          \\\n  OF_PP_MAKE_TUPLE_SEQ(\"log_sigmoid\", LogSigmoid)               \\\n  OF_PP_MAKE_TUPLE_SEQ(\"reciprocal\", Reciprocal)                \\\n  OF_PP_MAKE_TUPLE_SEQ(\"reciprocal_no_nan\", ReciprocalNoNan)    \\\n  OF_PP_MAKE_TUPLE_SEQ(\"rsqrt\", Rsqrt)                          \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sin\", Sin)                              \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sinh\", Sinh)                            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sqrt\", Sqrt)                            \\\n  OF_PP_MAKE_TUPLE_SEQ(\"square\", Square)                        \\\n  OF_PP_MAKE_TUPLE_SEQ(\"tan\", Tan)\n\n#define MATH_UNARY_ELEMENTWISE_FUNC_BWD_WITH_DY_Y_SEQ OF_PP_MAKE_TUPLE_SEQ(\"sigmoid\", Sigmoid)\n\n#define MATH_UNARY_ELEMENTWISE_FUNC_BWD_WITH_FILL_SEQ  \\\n  OF_PP_MAKE_TUPLE_SEQ(\"not_equal_zero\", NotEqualZero) \\\n  OF_PP_MAKE_TUPLE_SEQ(\"sign\", Sign)                   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"rint\", Rint)                   \\\n  OF_PP_MAKE_TUPLE_SEQ(\"round\", Round)                 \\\n  OF_PP_MAKE_TUPLE_SEQ(\"floor\", Floor)                 \\\n  OF_PP_MAKE_TUPLE_SEQ(\"ceil\", Ceil)\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_OPS_MATH_UNARY_ELEMENTWISE_SEQ_H_\n"
  },
  {
    "path": "oneflow/user/ops/matmul_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferTensorDesc4Matmul(user_op::InferContext* ctx) {\n  bool transpose_a = ctx->Attr<bool>(\"transpose_a\");\n  bool transpose_b = ctx->Attr<bool>(\"transpose_b\");\n\n  const user_op::TensorDesc& a = ctx->InputTensorDesc(\"a\", 0);\n  const user_op::TensorDesc& b = ctx->InputTensorDesc(\"b\", 0);\n  CHECK_EQ_OR_RETURN(a.shape().NumAxes(), b.shape().NumAxes());\n  CHECK_GE_OR_RETURN(a.shape().NumAxes(), 2);\n  size_t num_axes = a.shape().NumAxes();\n\n  if (num_axes > 2) {\n    for (int i = 0; i < num_axes - 2; ++i) { CHECK_EQ_OR_RETURN(a.shape().At(i), b.shape().At(i)); }\n  }\n\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n\n  Shape output = ctx->InputShape(\"a\", 0);\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"a\", 0));\n\n  int64_t m, n, k;  // tensor a (no trans): m*k, tensor b (no trans): k*n\n  if (!transpose_a) {\n    m = a.shape().At(num_axes - 2);\n    k = a.shape().At(num_axes - 1);\n  } else {\n    m = a.shape().At(num_axes - 1);\n    k = a.shape().At(num_axes - 2);\n  }\n  if (!transpose_b) {\n    CHECK_EQ_OR_RETURN(k, b.shape().At(num_axes - 2));\n    n = b.shape().At(num_axes - 1);\n  } else {\n    CHECK_EQ_OR_RETURN(k, b.shape().At(num_axes - 1));\n    n = b.shape().At(num_axes - 2);\n  }\n  output.Set(num_axes - 2, m);\n  output.Set(num_axes - 1, n);\n  out->set_shape(output);\n  if (ctx->has_input(\"_add_to_output\", 0)) {\n    const auto& add_to_output = ctx->InputTensorDesc(\"_add_to_output\", 0);\n    CHECK_EQ_OR_RETURN(add_to_output.shape(), out->shape());\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType4Matmul(user_op::InferContext* ctx) {\n  DataType dtype = ctx->InputDType(\"a\", 0);\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"b\", 0), dtype)\n      << \"InferDataType Failed. Expected \" << DataType_Name(dtype) << \", but got \"\n      << DataType_Name(ctx->InputDType(\"b\", 0));\n  if (ctx->has_input(\"_add_to_output\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"_add_to_output\", 0), dtype)\n        << \"InferDataType Failed. Expected \" << DataType_Name(dtype) << \", but got \"\n        << DataType_Name(ctx->InputDType(\"_add_to_output\", 0));\n  }\n  ctx->SetOutputDType(\"out\", 0, dtype);\n  return Maybe<void>::Ok();\n}\n\n// Theoretically computation cost of matrix multiplication is the products of the number of matrix\n// and first dimension of matrix a, second dimension of matrix a, second dimension of matrix\n// b. If there is any splitting sbp parallel, the computation cost will be divided by number of\n// machines. 
If we use S(1) at matrix a and S(0) at matrix b, then it will be P at output matrix.\n// This is why we don't use SbpParallel at output matrix.\nMaybe<double> GetComputationCost(user_op::ComputeComplexityFnContext* ctx) {\n  bool transpose_b = ctx->Attr<bool>(\"transpose_b\");\n  const Shape& shape_b = ctx->Shape4ArgNameAndIndex(\"b\", 0);\n  int64_t n = 0;\n  if (!transpose_b) {\n    n = shape_b.At(shape_b.NumAxes() - 1);\n  } else {\n    n = shape_b.At(shape_b.NumAxes() - 2);\n  }\n\n  double logical_computation_cost = 2 * ctx->Shape4ArgNameAndIndex(\"a\", 0).elem_cnt() * n;\n  const auto& nd_sbp_a = ctx->NdSbp4ArgNameAndIndex(\"a\", 0);\n  const auto& nd_sbp_b = ctx->NdSbp4ArgNameAndIndex(\"b\", 0);\n  const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy();\n  for (int32_t sbp_dim = 0; sbp_dim < nd_sbp_a.sbp_parallel_size(); sbp_dim++) {\n    if (nd_sbp_a.sbp_parallel(sbp_dim).has_split_parallel()\n        || nd_sbp_b.sbp_parallel(sbp_dim).has_split_parallel()) {\n      logical_computation_cost /= parallel_hierarchy->At(sbp_dim);\n    }\n  }\n  return logical_computation_cost;\n}\n\n}  // namespace\n\n/* static */ Maybe<void> MatmulOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc4Matmul(ctx);\n}\n\n/*static*/ Maybe<void> MatmulOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<double> MatmulOp::GetComputeComplexity(user_op::ComputeComplexityFnContext* ctx) {\n  return GetComputationCost(ctx);\n}\n\n/* static */ Maybe<void> MatmulOp::GetSbp(user_op::SbpContext* ctx) {\n  // (m, k_a) * (k_b, n) where k_a == k_b\n  int32_t m_axis = -1;\n  int32_t k_a_axis = -1;\n  int32_t k_b_axis = -1;\n  int32_t n_axis = -1;\n  if (ctx->Attr<bool>(\"transpose_a\")) {\n    m_axis = 1;\n    k_a_axis = 0;\n  } else {\n    m_axis = 0;\n    k_a_axis = 1;\n  }\n  if (ctx->Attr<bool>(\"transpose_b\")) {\n    k_b_axis = 1;\n    n_axis = 0;\n  } else {\n    k_b_axis = 0;\n    n_axis = 1;\n  }\n  std::vector<user_op::OpArg> out_and_add_to_output_args;\n  out_and_add_to_output_args.emplace_back(\"out\", 0);\n  if (ctx->user_op_conf().has_input(\"_add_to_output\", 0)) {\n    out_and_add_to_output_args.emplace_back(\"_add_to_output\", 0);\n  }\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"a\", 0), m_axis)\n      .Broadcast(user_op::OpArg(\"b\", 0))\n      .Split(out_and_add_to_output_args, 0)\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"a\", 0))\n      .Split(user_op::OpArg(\"b\", 0), n_axis)\n      .Split(out_and_add_to_output_args, 1)\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"a\", 0), k_a_axis)\n      .Split(user_op::OpArg(\"b\", 0), k_b_axis)\n      .PartialSum(out_and_add_to_output_args)\n      .Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"a\", 0))\n      .Broadcast(user_op::OpArg(\"b\", 0))\n      .PartialSum(out_and_add_to_output_args)\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"a\", 0))\n      .PartialSum(user_op::OpArg(\"b\", 0))\n      .PartialSum(out_and_add_to_output_args)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MatmulOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType4Matmul(ctx);\n}\n\n// BatchMatmul\n\n/* static */ Maybe<void> BatchMatmulOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc4Matmul(ctx);\n}\n\n/*static*/ Maybe<void> BatchMatmulOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) 
{\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BatchMatmulOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& a_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"a\", 0);\n  std::vector<user_op::OpArg> out_and_add_to_output_args;\n  out_and_add_to_output_args.emplace_back(\"out\", 0);\n  if (ctx->user_op_conf().has_input(\"_add_to_output\", 0)) {\n    out_and_add_to_output_args.emplace_back(\"_add_to_output\", 0);\n  }\n  int32_t num_axes = a_tensor.shape().NumAxes();\n  FOR_RANGE(int64_t, i, 0, num_axes - 2) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(out_and_add_to_output_args, i).Build();\n  }\n  int32_t m_axis = -1;\n  int32_t k_a_axis = -1;\n  int32_t k_b_axis = -1;\n  int32_t n_axis = -1;\n  if (ctx->Attr<bool>(\"transpose_a\")) {\n    m_axis = num_axes - 1;\n    k_a_axis = num_axes - 2;\n  } else {\n    m_axis = num_axes - 2;\n    k_a_axis = num_axes - 1;\n  }\n  if (ctx->Attr<bool>(\"transpose_b\")) {\n    k_b_axis = num_axes - 1;\n    n_axis = num_axes - 2;\n  } else {\n    k_b_axis = num_axes - 2;\n    n_axis = num_axes - 1;\n  }\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"a\", 0), m_axis)\n      .Broadcast(user_op::OpArg(\"b\", 0))\n      .Split(out_and_add_to_output_args, num_axes - 2)\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"a\", 0))\n      .Split(user_op::OpArg(\"b\", 0), n_axis)\n      .Split(out_and_add_to_output_args, num_axes - 1)\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"a\", 0), k_a_axis)\n      .Split(user_op::OpArg(\"b\", 0), k_b_axis)\n      .PartialSum(out_and_add_to_output_args)\n      .Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"a\", 0))\n      .Broadcast(user_op::OpArg(\"b\", 0))\n      .PartialSum(out_and_add_to_output_args)\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"a\", 0))\n      .PartialSum(user_op::OpArg(\"b\", 0))\n      .PartialSum(out_and_add_to_output_args)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<double> BatchMatmulOp::GetComputeComplexity(\n    user_op::ComputeComplexityFnContext* ctx) {\n  return GetComputationCost(ctx);\n}\n\n/* static */ Maybe<void> BatchMatmulOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType4Matmul(ctx);\n}\n\n// BroadcastMatmul\n\n/* static */ Maybe<void> BroadcastMatmulOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  bool transpose_a = ctx->Attr<bool>(\"transpose_a\");\n  bool transpose_b = ctx->Attr<bool>(\"transpose_b\");\n\n  const user_op::TensorDesc& a = ctx->InputTensorDesc(\"a\", 0);\n  const user_op::TensorDesc& b = ctx->InputTensorDesc(\"b\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n\n  const int64_t num_a_dims = a.shape().NumAxes();\n  const int64_t num_b_dims = b.shape().NumAxes();\n  const size_t num_max_batch_dims = std::max(num_a_dims, num_b_dims) - 2;\n  auto MakeGetBatchDim = [num_max_batch_dims](size_t num_dims, const Shape& shape_dim) {\n    const int64_t num_batch_dims = num_dims - 2;\n    const int64_t num_padding_dims = num_max_batch_dims - num_batch_dims;\n    return [num_padding_dims, shape_dim](size_t index) {\n      return index < num_padding_dims ? 
1 : shape_dim.At(index - num_padding_dims);\n    };\n  };\n  auto GetABatchDim = MakeGetBatchDim(num_a_dims, a.shape());\n  auto GetBBatchDim = MakeGetBatchDim(num_b_dims, b.shape());\n\n  DimVector out_dim_vec(std::max(num_a_dims, num_b_dims));\n  FOR_RANGE(int64_t, i, 0, out_dim_vec.size() - 2) {\n    // Set broadcast shape\n    //                       m  k          k  n\n    // For example: A(16, 1, 4, 8) B(1, 8, 8, 6)\n    // We First set the previous batch dims to broadcasted shape: C(16, 8)\n    // Then we emplace back m, n -> C(16, 8, 4, 6)\n    const int64_t a_batch_dim = GetABatchDim(i);\n    const int64_t b_batch_dim = GetBBatchDim(i);\n    CHECK(((a_batch_dim != 1 && b_batch_dim == 1) || (a_batch_dim == 1 && b_batch_dim != 1)\n           || (a_batch_dim == b_batch_dim)))\n        << \"Batch Dims could not broadcast, please check. \";\n    out_dim_vec[i] = std::max(a_batch_dim, b_batch_dim);\n  }\n  int64_t m = 0;\n  int64_t n = 0;\n  int64_t k = 0;  // tensor a (no trans): batch_dims*m*k, tensor b (no trans): batch_dims*k*n\n  if (!transpose_a) {\n    m = a.shape().At(num_a_dims - 2);\n    k = a.shape().At(num_a_dims - 1);\n  } else {\n    m = a.shape().At(num_a_dims - 1);\n    k = a.shape().At(num_a_dims - 2);\n  }\n  if (!transpose_b) {\n    CHECK_EQ_OR_RETURN(k, b.shape().At(num_b_dims - 2))\n        << \"K dim should be equal to b.shape().At(num_b_dims - 2). \";\n    n = b.shape().At(num_b_dims - 1);\n  } else {\n    CHECK_EQ_OR_RETURN(k, b.shape().At(num_b_dims - 1))\n        << \"K dim should be equal to b.shape().At(num_b_dims - 1). \";\n    n = b.shape().At(num_b_dims - 2);\n  }\n  out_dim_vec.at(num_max_batch_dims) = m;\n  out_dim_vec.at(num_max_batch_dims + 1) = n;\n  out->set_shape(Shape(out_dim_vec));\n\n  if (ctx->has_input(\"_add_to_output\", 0)) {\n    const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc(\"_add_to_output\", 0);\n    CHECK_EQ_OR_RETURN(add_to_output.shape(), out->shape());\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> BroadcastMatmulOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> BroadcastMatmulOp::GetSbp(user_op::SbpContext* ctx) {\n  // (b, m, k) * (k, n) when transpose_b is false\n  // (b, m, k) * (n, k) when transpose_b is true\n  bool transpose_a = ctx->Attr<bool>(\"transpose_a\");\n  bool transpose_b = ctx->Attr<bool>(\"transpose_b\");\n\n  const auto& a_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"a\", 0).shape();\n  const auto& b_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"b\", 0).shape();\n\n  const int64_t a_num_axes = a_shape.NumAxes();\n  const int64_t b_num_axes = b_shape.NumAxes();\n\n  int32_t m_a_axis = -1;\n  int32_t k_a_axis = -1;\n  int32_t k_b_axis = -1;\n  int32_t n_axis = -1;\n\n  if (transpose_a) {\n    m_a_axis = a_num_axes - 1;\n    k_a_axis = a_num_axes - 2;\n  } else {\n    m_a_axis = a_num_axes - 2;\n    k_a_axis = a_num_axes - 1;\n  }\n  if (transpose_b) {\n    k_b_axis = b_num_axes - 1;\n    n_axis = b_num_axes - 2;\n  } else {\n    k_b_axis = b_num_axes - 2;\n    n_axis = b_num_axes - 1;\n  }\n\n  std::vector<user_op::OpArg> out_and_add_to_output_args;\n  out_and_add_to_output_args.emplace_back(\"out\", 0);\n  if (ctx->user_op_conf().has_input(\"_add_to_output\", 0)) {\n    out_and_add_to_output_args.emplace_back(\"_add_to_output\", 0);\n  }\n\n  const int64_t a_batch_dims = a_num_axes - 2;\n  const int64_t b_batch_dims = b_num_axes - 2;\n  const int64_t max_num_axes = std::max(a_num_axes, 
b_num_axes);\n  const size_t num_max_batch_dims = max_num_axes - 2;\n  auto MakeGetBatchDim = [num_max_batch_dims](size_t num_dims, const Shape& shape_dim) {\n    const int64_t num_batch_dims = num_dims - 2;\n    const int64_t num_padding_dims = num_max_batch_dims - num_batch_dims;\n    return [num_padding_dims, shape_dim](size_t index) {\n      return index < num_padding_dims ? 1 : shape_dim.At(index - num_padding_dims);\n    };\n  };\n  auto GetABatchDim = MakeGetBatchDim(a_num_axes, a_shape);\n  auto GetBBatchDim = MakeGetBatchDim(b_num_axes, b_shape);\n\n  for (int i = 0; i < num_max_batch_dims; i++) {\n    const int64_t a_batch_dim = GetABatchDim(i);\n    const int64_t b_batch_dim = GetBBatchDim(i);\n\n    if (a_batch_dim == b_batch_dim && a_batch_dim != 1) {\n      // S(b axis) x S(b axis) -> S(b axis)\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"a\", 0), i - (num_max_batch_dims - a_batch_dims))\n          .Split(user_op::OpArg(\"b\", 0), i - (num_max_batch_dims - b_batch_dims))\n          .Split(out_and_add_to_output_args, i)\n          .Build();\n    } else if (a_batch_dim == 1 && b_batch_dim != 1) {\n      // B x S(b axis) -> S(b axis)\n      ctx->NewBuilder()\n          .Broadcast(user_op::OpArg(\"a\", 0))\n          .Split(user_op::OpArg(\"b\", 0), i - (num_max_batch_dims - b_batch_dims))\n          .Split(out_and_add_to_output_args, i)\n          .Build();\n    } else if (b_batch_dim == 1 && a_batch_dim != 1) {\n      // S(b axis) x B -> S(b axis)\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"a\", 0), i - (num_max_batch_dims - a_batch_dims))\n          .Broadcast(user_op::OpArg(\"b\", 0))\n          .Split(out_and_add_to_output_args, i)\n          .Build();\n    }\n  }\n\n  // S(m axis) x B -> S(m axis)\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"a\", 0), m_a_axis)\n      .Broadcast(user_op::OpArg(\"b\", 0))\n      .Split(out_and_add_to_output_args, max_num_axes - 2)\n      .Build();\n\n  // B x S(n_axis) -> S(n_axis)\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"a\", 0))\n      .Split(user_op::OpArg(\"b\", 0), n_axis)\n      .Split(out_and_add_to_output_args, max_num_axes - 1)\n      .Build();\n\n  // S(a_k_axis) x S(b_k_axis) -> P\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"a\", 0), k_a_axis)\n      .Split(user_op::OpArg(\"b\", 0), k_b_axis)\n      .PartialSum(out_and_add_to_output_args)\n      .Build();\n\n  // P x B -> P\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"a\", 0))\n      .Broadcast(user_op::OpArg(\"b\", 0))\n      .PartialSum(out_and_add_to_output_args)\n      .Build();\n\n  // B x P -> P\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"a\", 0))\n      .PartialSum(user_op::OpArg(\"b\", 0))\n      .PartialSum(out_and_add_to_output_args)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BroadcastMatmulOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType4Matmul(ctx);\n}\n\n/*static*/ Maybe<double> BroadcastMatmulOp::GetComputeComplexity(\n    user_op::ComputeComplexityFnContext* ctx) {\n  return GetComputationCost(ctx);\n}\n\n/* static */ Maybe<void> BroadcastMatmulGradBOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& a = ctx->InputTensorDesc(\"a\", 0);\n  const user_op::TensorDesc& b = ctx->InputTensorDesc(\"b\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n\n  CHECK_EQ_OR_RETURN(a.shape().NumAxes(), b.shape().NumAxes());\n  for (int i = 0; i < a.shape().NumAxes() - 1; ++i) {\n 
   CHECK_EQ_OR_RETURN(a.shape().At(i), b.shape().At(i));\n  }\n  out->set_shape(\n      Shape({a.shape().At(a.shape().NumAxes() - 1), b.shape().At(b.shape().NumAxes() - 1)}));\n\n  if (ctx->has_input(\"_add_to_output\", 0)) {\n    const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc(\"_add_to_output\", 0);\n    CHECK_EQ_OR_RETURN(add_to_output.shape(), out->shape());\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> BroadcastMatmulGradBOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<double> BroadcastMatmulGradBOp::GetComputeComplexity(\n    user_op::ComputeComplexityFnContext* ctx) {\n  const Shape& shape_a = ctx->Shape4ArgNameAndIndex(\"a\", 0);\n  // \"out\" is (a.last, b.last) and every leading axis is shared with \"b\", so the contracted size\n  // is a's last axis and the logical cost is 2 * elem_cnt(b) * a.last.\n  int64_t k = shape_a.At(shape_a.NumAxes() - 1);\n\n  double logical_computation_cost = 2 * ctx->Shape4ArgNameAndIndex(\"b\", 0).elem_cnt() * k;\n  const auto& nd_sbp_a = ctx->NdSbp4ArgNameAndIndex(\"a\", 0);\n  const auto& nd_sbp_b = ctx->NdSbp4ArgNameAndIndex(\"b\", 0);\n  const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy();\n  for (int32_t sbp_dim = 0; sbp_dim < nd_sbp_a.sbp_parallel_size(); sbp_dim++) {\n    if (nd_sbp_a.sbp_parallel(sbp_dim).has_split_parallel()\n        || nd_sbp_b.sbp_parallel(sbp_dim).has_split_parallel()) {\n      logical_computation_cost /= parallel_hierarchy->At(sbp_dim);\n    }\n  }\n  return logical_computation_cost;\n}\n\n/* static */ Maybe<void> BroadcastMatmulGradBOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& a_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"a\", 0).shape();\n  int64_t last_axis = a_shape.NumAxes() - 1;\n  std::vector<user_op::OpArg> out_and_add_to_output_args;\n  out_and_add_to_output_args.emplace_back(\"out\", 0);\n  if (ctx->user_op_conf().has_input(\"_add_to_output\", 0)) {\n    out_and_add_to_output_args.emplace_back(\"_add_to_output\", 0);\n  }\n  // S(b or m axis) x S(b or m axis) -> P\n  for (int64_t i = 0; i < last_axis; ++i) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"a\", 0), i)\n        .Split(user_op::OpArg(\"b\", 0), i)\n        .PartialSum(out_and_add_to_output_args)\n        .Build();\n  }\n  // (b, m, k) * (b, m, n) -> (k, n) [transpose a]\n  // S(k) x B -> S(0) or B x S(n) -> S(1)\n  // (b, m, n) * (b, m, k) -> (n, k) [transpose a]\n  // S(n) x B -> S(0) or B x S(k) -> S(1)\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"a\", 0), last_axis)\n      .Broadcast(user_op::OpArg(\"b\", 0))\n      .Split(out_and_add_to_output_args, 0)\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"a\", 0))\n      .Split(user_op::OpArg(\"b\", 0), last_axis)\n      .Split(out_and_add_to_output_args, 1)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> BroadcastMatmulGradBOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType4Matmul(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/matrix_vector_product_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferTensorDesc4MatrixVectorProduct(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& a = ctx->InputTensorDesc(\"a\", 0);\n  const user_op::TensorDesc& b = ctx->InputTensorDesc(\"b\", 0);\n  int64_t m = a.shape().At(0);\n  int64_t k = a.shape().At(1);\n  CHECK_EQ_OR_RETURN(k, b.shape().At(0)) << \"Dim K should be equal to vector b's dim0. \";\n  ctx->SetOutputShape(\"out\", 0, Shape({m}));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType4MatrixVectorProduct(user_op::InferContext* ctx) {\n  DataType dtype = ctx->InputDType(\"a\", 0);\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"b\", 0), dtype)\n      << \"InferDataType Failed. Expected \" << DataType_Name(dtype) << \", but got \"\n      << DataType_Name(ctx->InputDType(\"b\", 0));\n  ctx->SetOutputDType(\"out\", 0, dtype);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferTensorDesc4MatrixVectorProductGradA(user_op::InferContext* ctx) {\n  /*\n  A(m, k) matmul B(k) -> (m, k) matmul (k, 1) -> (m, 1) -> (m)\n  GradA = dy (m) matmul B(k) -> (m, 1) (k, 1)_transpose\n  */\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& b = ctx->InputTensorDesc(\"b\", 0);\n  int64_t m = dy.shape().At(0);\n  int64_t n = b.shape().At(0);\n  ctx->SetOutputShape(\"dx\", 0, Shape({m, n}));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferTensorDesc4MatrixVectorProductGradB(user_op::InferContext* ctx) {\n  /*\n  A(m, k) matmul B(k) -> (m, k) matmul (k, 1) -> (m, 1) -> (m)\n  GradB = dy_transpose (1, m) matmul A(m, k)\n  */\n  const user_op::TensorDesc& a = ctx->InputTensorDesc(\"a\", 0);\n  int64_t n = a.shape().At(1);\n  ctx->SetOutputShape(\"dx\", 0, Shape({n}));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType4Grad(user_op::InferContext* ctx) {\n  DataType dtype = ctx->InputDType(\"dy\", 0);\n  ctx->SetOutputDType(\"dx\", 0, dtype);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> MatrixVectorProductOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc4MatrixVectorProduct(ctx);\n}\n\n/*static*/ Maybe<void> MatrixVectorProductOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MatrixVectorProductOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"a\", 0), 0)\n      .Broadcast(user_op::OpArg(\"b\", 0))\n      .Split(user_op::OpArg(\"out\", 0), 0)\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"a\", 0), 1)\n      .Split(user_op::OpArg(\"b\", 0), 0)\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"a\", 0))\n      .Broadcast(user_op::OpArg(\"b\", 0))\n      
.PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"a\", 0))\n      .PartialSum(user_op::OpArg(\"b\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MatrixVectorProductOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType4MatrixVectorProduct(ctx);\n}\n\n/* static */ Maybe<void> MatrixVectorProductGradAOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferTensorDesc4MatrixVectorProductGradA(ctx);\n}\n\n/*static*/ Maybe<void> MatrixVectorProductGradAOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MatrixVectorProductGradAOp::GetSbp(user_op::SbpContext* ctx) {\n  /*\n  A(m, k) matmul B(k) -> (m, k) matmul (k, 1) -> (m, 1) -> (m)\n  GradA = dy (m) matmul B(k) -> (m, 1) (k, 1)_transpose\n  */\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Broadcast(user_op::OpArg(\"b\", 0))\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"dy\", 0))\n      .Broadcast(user_op::OpArg(\"b\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"dy\", 0))\n      .PartialSum(user_op::OpArg(\"b\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MatrixVectorProductGradAOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType4Grad(ctx);\n}\n\n/* static */ Maybe<void> MatrixVectorProductGradBOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferTensorDesc4MatrixVectorProductGradB(ctx);\n}\n\n/*static*/ Maybe<void> MatrixVectorProductGradBOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MatrixVectorProductGradBOp::GetSbp(user_op::SbpContext* ctx) {\n  /*\n  A(m, k) matmul B(k) -> (m, k) matmul (k, 1) -> (m, 1) -> (m)\n  dy = (m, )\n  GradB = dy_transpose (1, m) matmul A(m, k)\n  */\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"dy\", 0))\n      .Split(user_op::OpArg(\"a\", 0), 1)\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"a\", 0), 0)\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"dy\", 0))\n      .Broadcast(user_op::OpArg(\"a\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"dy\", 0))\n      .PartialSum(user_op::OpArg(\"a\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MatrixVectorProductGradBOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType4Grad(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/max_pool_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/max_pool_kernel_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntypedef std::function<Maybe<void>(user_op::InferContext* ctx)> TensorDescInferFn;\n\nTensorDescInferFn MaxPoolMakeForwardTensorDescInferFn(const int32_t dim) {\n  return [dim](user_op::InferContext* ctx) -> Maybe<void> {\n    const Shape& x_shape = ctx->InputShape(\"x\", 0);\n    const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n    const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>(\"padding\");\n    const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n    const std::vector<int32_t>& stride = ctx->Attr<std::vector<int32_t>>(\"stride\");\n    const std::vector<int32_t>& dilation = ctx->Attr<std::vector<int32_t>>(\"dilation\");\n    const bool return_indices = ctx->Attr<bool>(\"return_indices\");\n    const bool ceil_mode = ctx->Attr<bool>(\"ceil_mode\");\n\n    CHECK_EQ_OR_RETURN(kernel_size.size(), dim);\n    for (int32_t pool_dim : kernel_size) { CHECK_GT_OR_RETURN(pool_dim, 0); }\n    CHECK_EQ_OR_RETURN(stride.size(), dim);\n    for (int32_t stride_dim : stride) { CHECK_GT_OR_RETURN(stride_dim, 0); }\n    for (int32_t i = 0; i < padding.size(); i++) {\n      CHECK_GE_OR_RETURN(kernel_size[i], 2 * padding[i])\n          << \"pad should be smaller than half of kernel size\";\n    }\n\n    const MaxPoolParams3D params_3d(dim, x_shape, data_format, padding, kernel_size, stride,\n                                    dilation, return_indices, ceil_mode);\n    user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc(\"y\", 0);\n    *y_desc = ctx->InputTensorDesc(\"x\", 0);\n    y_desc->set_shape(params_3d.GetYShape());\n\n    user_op::TensorDesc* indice_desc = ctx->MutOutputTensorDesc(\"indice\", 0);\n    *indice_desc = *ctx->MutOutputTensorDesc(\"y\", 0);\n    indice_desc->set_shape(y_desc->shape());\n    indice_desc->set_data_type(kInt64);\n    return Maybe<void>::Ok();\n  };\n}\n\nMaybe<void> MaxPoolForwardGetSbpFn(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes() - 2)) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"y\", 0), i)\n        .Split(user_op::OpArg(\"indice\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MaxPoolBackwardGetSbpFn(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"indice\", 0), i)\n        
.Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n// Logically computation cost of pool op is the product of output data amount and pool kernal data\n// amount. After adding sbp, we just divide it by parallel number if output data is splitted because\n// splitting input and using partial sum for output is not a valid sbp for this op for now.\nMaybe<double> GetComputationCost(user_op::ComputeComplexityFnContext* ctx,\n                                 const std::string& blob_name) {\n  const std::vector<int32_t>& pool_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n  double logical_computation_cost = std::accumulate(\n      pool_size.begin(), pool_size.end(), ctx->Shape4ArgNameAndIndex(blob_name, 0).elem_cnt(),\n      std::multiplies<double>());\n  const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy();\n  const auto& nd_sbp_y = ctx->NdSbp4ArgNameAndIndex(blob_name, 0);\n  for (int32_t dim_sbp = 0; dim_sbp < nd_sbp_y.sbp_parallel_size(); dim_sbp++) {\n    if (nd_sbp_y.sbp_parallel(dim_sbp).has_split_parallel()) {\n      logical_computation_cost /= parallel_hierarchy->At(dim_sbp);\n    }\n  }\n  return logical_computation_cost;\n}\n\nMaybe<void> BackwardTensorDescInferFn(user_op::InferContext* ctx) {\n  *ctx->MutOutputTensorDesc(\"dx\", 0) = ctx->InputTensorDesc(\"x\", 0);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FwInferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BwInferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n}  // namespace\n\n#define IMPLEMENT_MAXPOOL_FUNCS(name, dim)                                               \\\n  /*static*/ Maybe<void> name##Op::GetSbp(user_op::SbpContext* ctx) {                    \\\n    return MaxPoolForwardGetSbpFn(ctx);                                                  \\\n  }                                                                                      \\\n  /*static*/ Maybe<void> name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return MaxPoolMakeForwardTensorDescInferFn(dim)(ctx);                                \\\n  }                                                                                      \\\n  /*static*/ Maybe<void> name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferLogicalTensorDesc(ctx);                                                  \\\n  }                                                                                      \\\n  /*static*/ Maybe<void> name##Op::InferDataType(user_op::InferContext* ctx) {           \\\n    return FwInferDataType(ctx);                                                         \\\n  }                                                                                      \\\n  /*static*/ Maybe<double> name##Op::GetComputeComplexity(                               \\\n      user_op::ComputeComplexityFnContext* ctx) {                                        \\\n    return GetComputationCost(ctx, \"y\");                                                 \\\n  }\n\nIMPLEMENT_MAXPOOL_FUNCS(MaxPool1D, 1)\nIMPLEMENT_MAXPOOL_FUNCS(MaxPool2D, 2)\nIMPLEMENT_MAXPOOL_FUNCS(MaxPool3D, 3)\n#undef IMPLEMENT_MAXPOOL_FUNCS\n\n#define IMPLEMENT_MAXPOOL_BACKWARD_FUNCS(name)                                               \\\n  /*static*/ Maybe<void> 
name##GradOp::GetSbp(user_op::SbpContext* ctx) {                    \\\n    return MaxPoolBackwardGetSbpFn(ctx);                                                     \\\n  }                                                                                          \\\n  /*static*/ Maybe<void> name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return BackwardTensorDescInferFn(ctx);                                                   \\\n  }                                                                                          \\\n  /*static*/ Maybe<void> name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferLogicalTensorDesc(ctx);                                                      \\\n  }                                                                                          \\\n  /*static*/ Maybe<void> name##GradOp::InferDataType(user_op::InferContext* ctx) {           \\\n    return BwInferDataType(ctx);                                                             \\\n  }                                                                                          \\\n  /*static*/ Maybe<double> name##GradOp::GetComputeComplexity(                               \\\n      user_op::ComputeComplexityFnContext* ctx) {                                            \\\n    return GetComputationCost(ctx, \"dy\");                                                    \\\n  }\n\nIMPLEMENT_MAXPOOL_BACKWARD_FUNCS(MaxPool1D)\nIMPLEMENT_MAXPOOL_BACKWARD_FUNCS(MaxPool2D)\nIMPLEMENT_MAXPOOL_BACKWARD_FUNCS(MaxPool3D)\n#undef IMPLEMENT_MAXPOOL_BACKWARD_FUNCS\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/max_unpool_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/user/kernels/max_unpool_kernel_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntypedef std::function<Maybe<void>(user_op::InferContext* ctx)> TensorDescInferFn;\n\nTensorDescInferFn MaxUnpoolMakeForwardTensorDescInferFn(const int32_t dim) {\n  return [dim](user_op::InferContext* ctx) -> Maybe<void> {\n    const Shape& x_shape = ctx->InputShape(\"x\", 0);\n    const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>(\"padding\");\n    const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n    const std::vector<int32_t>& stride = ctx->Attr<std::vector<int32_t>>(\"stride\");\n    Shape output_shape = Shape();\n    if (ctx->Attr<bool>(\"has_output_size\")) {\n      output_shape = ctx->Attr<Shape>(\"output_size\");\n    } else {\n      const MaxUnpoolParams3D params_3d(dim, x_shape, padding, kernel_size, stride);\n      output_shape = params_3d.GetYShape();\n    }\n\n    user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc(\"y\", 0);\n    *y_desc = ctx->InputTensorDesc(\"x\", 0);\n    y_desc->set_shape(output_shape);\n\n    return Maybe<void>::Ok();\n  };\n}\n\nMaybe<void> MaxUnpoolForwardGetSbpFn(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, std::min(2L, tensor.shape().NumAxes() - 2)) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"indices\", 0), i)\n        .Split(user_op::OpArg(\"y\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MaxUnpoolBackwardGetSbpFn(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, std::min(2L, tensor.shape().NumAxes())) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"indices\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BackwardTensorDescInferFn(user_op::InferContext* ctx) {\n  *ctx->MutOutputTensorDesc(\"dx\", 0) = ctx->InputTensorDesc(\"x\", 0);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FwInferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BwInferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n}  // namespace\n\n#define IMPLEMENT_MAXUNPOOL_FUNCS(name, dim)                                             \\\n  /*static*/ Maybe<void> name##Op::GetSbp(user_op::SbpContext* ctx) {                    
\\\n    return MaxUnpoolForwardGetSbpFn(ctx);                                                \\\n  }                                                                                      \\\n  /*static*/ Maybe<void> name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return MaxUnpoolMakeForwardTensorDescInferFn(dim)(ctx);                              \\\n  }                                                                                      \\\n  /*static*/ Maybe<void> name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferLogicalTensorDesc(ctx);                                                  \\\n  }                                                                                      \\\n  /*static*/ Maybe<void> name##Op::InferDataType(user_op::InferContext* ctx) {           \\\n    return FwInferDataType(ctx);                                                         \\\n  }\n\nIMPLEMENT_MAXUNPOOL_FUNCS(MaxUnpool1D, 1)\nIMPLEMENT_MAXUNPOOL_FUNCS(MaxUnpool2D, 2)\nIMPLEMENT_MAXUNPOOL_FUNCS(MaxUnpool3D, 3)\n#undef IMPLEMENT_MAXUNPOOL_FUNCS\n\n#define IMPLEMENT_MAXUNPOOL_BACKWARD_FUNCS(name)                                             \\\n  /*static*/ Maybe<void> name##GradOp::GetSbp(user_op::SbpContext* ctx) {                    \\\n    return MaxUnpoolBackwardGetSbpFn(ctx);                                                   \\\n  }                                                                                          \\\n  /*static*/ Maybe<void> name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return BackwardTensorDescInferFn(ctx);                                                   \\\n  }                                                                                          \\\n  /*static*/ Maybe<void> name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferLogicalTensorDesc(ctx);                                                      \\\n  }                                                                                          \\\n  /*static*/ Maybe<void> name##GradOp::InferDataType(user_op::InferContext* ctx) {           \\\n    return BwInferDataType(ctx);                                                             \\\n  }\n\nIMPLEMENT_MAXUNPOOL_BACKWARD_FUNCS(MaxUnpool1D)\nIMPLEMENT_MAXUNPOOL_BACKWARD_FUNCS(MaxUnpool2D)\nIMPLEMENT_MAXUNPOOL_BACKWARD_FUNCS(MaxUnpool3D)\n#undef IMPLEMENT_MAXUNPOOL_BACKWARD_FUNCS\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/median_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> MedianOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"input\", 0);\n  int64_t num_axes = in_tensor.shape().NumAxes();\n  if (num_axes == 0) {\n    ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> MedianOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& ones_shape = {1};\n  ctx->SetOutputShape(\"output\", 0, ones_shape.RemoveOnes({0}));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> MedianOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> MedianOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"output\", 0, ctx->InputDType(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/median_with_indices_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> MedianWithIndicesOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"input\", 0);\n  int64_t num_axes = in_tensor.shape().NumAxes();\n  FOR_RANGE(int64_t, i, 0, num_axes - 1) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  if (num_axes == 0) {\n    ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> MedianWithIndicesOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& input_shape = ctx->InputShape(\"input\", 0);\n  const Shape& reduce_shape = CreateReducedShape(input_shape, {-1});\n  ctx->SetOutputShape(\"values\", 0, reduce_shape.RemoveOnes({-1}));\n  ctx->SetOutputShape(\"indices\", 0, reduce_shape.RemoveOnes({-1}));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> MedianWithIndicesOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> MedianWithIndicesOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"values\", 0, ctx->InputDType(\"input\", 0));\n  ctx->SetOutputDType(\"indices\", 0, DataType::kInt64);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/min_max_observer_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> MinMaxObserverOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"in\", 0);\n\n  if (ctx->Attr<std::string>(\"quantization_formula\") == \"google\") {\n    if (ctx->Attr<bool>(\"per_layer_quantization\") == true) {\n      ctx->SetOutputShape(\"scale\", 0, Shape({1}));\n      ctx->SetOutputShape(\"zero_point\", 0, Shape({1}));\n    } else {\n      // NOTE(Liang Depeng): For now per-channel quantization only support axis 0\n      ctx->SetOutputShape(\"scale\", 0, Shape({in_shape.At(0)}));\n      ctx->SetOutputShape(\"zero_point\", 0, Shape({in_shape.At(0)}));\n    }\n  } else {  // quantization_formula == \"cambricon\"\n    ctx->SetOutputShape(\"scale\", 0, Shape({1}));\n    ctx->SetOutputShape(\"zero_point\", 0, Shape({1}));\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> MinMaxObserverOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MinMaxObserverOp::GetSbp(user_op::SbpContext* ctx) {\n  // NOTE(Liang Depeng): input needs to be broadcast in order to accurately calculate the\n  // global scale and zero_point\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MinMaxObserverOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* in = GetInputArgModifierFn(\"in\", 0);\n  CHECK_OR_RETURN(in != nullptr);\n  in->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MinMaxObserverOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                                     const user_op::UserOpConfWrapper& op_conf) {\n  int32_t quantization_bit = op_conf.attr<int32_t>(\"quantization_bit\");\n  CHECK_GT_OR_RETURN(quantization_bit, 1);\n  CHECK_LE_OR_RETURN(quantization_bit, 8);\n\n  std::string quantization_scheme = op_conf.attr<std::string>(\"quantization_scheme\");\n  CHECK_OR_RETURN(quantization_scheme == \"symmetric\" || quantization_scheme == \"affine\");\n\n  std::string quantization_formula = op_conf.attr<std::string>(\"quantization_formula\");\n  CHECK_OR_RETURN(quantization_formula == \"google\" || quantization_formula == \"cambricon\");\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MinMaxObserverOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"scale\", 0, ctx->InputDType(\"in\", 0));\n  ctx->SetOutputDType(\"zero_point\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/mish_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> MishOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> MishOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MishOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MishOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MishGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == x_shape);\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> MishGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MishGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MishGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"dy\", 0), ctx->InputDType(\"x\", 0))\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"dy\", 0))\n      << \", but got \" << DataType_Name(ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/mode_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> ModeOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"input\", 0);\n  int64_t num_axes = in_tensor.shape().NumAxes();\n  FOR_RANGE(int64_t, i, 0, num_axes - 1) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  if (num_axes == 0) {\n    ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ModeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& input_shape = ctx->InputShape(\"input\", 0);\n  const Shape& reduce_shape = CreateReducedShape(input_shape, {-1});\n  ctx->SetOutputShape(\"values\", 0, reduce_shape.RemoveOnes({-1}));\n  ctx->SetOutputShape(\"indices\", 0, reduce_shape.RemoveOnes({-1}));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ModeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ModeOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"values\", 0, ctx->InputDType(\"input\", 0));\n  ctx->SetOutputDType(\"indices\", 0, DataType::kInt64);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/model_update_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/framework/user_op_registry.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> CheckShapeLike(const user_op::TensorDesc* tensor_desc,\n                           const user_op::TensorDesc* like) {\n  CHECK_EQ_OR_RETURN(tensor_desc->shape(), like->shape());\n  return Maybe<void>::Ok();\n}\nMaybe<void> CheckDataTypeLike(const user_op::TensorDesc* tensor_desc,\n                              const user_op::TensorDesc* like) {\n  CHECK_EQ_OR_RETURN(tensor_desc->data_type(), like->data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(tensor_desc->data_type())\n      << \", but got \" << DataType_Name(like->data_type());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckScalarShape(const user_op::TensorDesc* tensor_desc) {\n  CHECK_OR_RETURN(tensor_desc->shape().NumAxes() == 0\n                  || (tensor_desc->shape().NumAxes() == 1 && tensor_desc->shape().At(0) == 1))\n      << tensor_desc->shape().DebugStr();\n  return Maybe<void>::Ok();\n}\nMaybe<void> CheckScalarDataType(const user_op::TensorDesc* tensor_desc, const DataType data_type) {\n  CHECK_EQ_OR_RETURN(tensor_desc->data_type(), data_type)\n      << \"InferDataType Failed. 
Expected \" << DataType_Name(tensor_desc->data_type())\n      << \", but got \" << DataType_Name(data_type);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckLearningRateShape(user_op::InferContext* ctx) {\n  if (ctx->has_input(\"learning_rate\", 0)) {\n    const user_op::TensorDesc& learning_rate = ctx->InputTensorDesc(\"learning_rate\", 0);\n    JUST(CheckScalarShape(&learning_rate));\n  }\n  return Maybe<void>::Ok();\n}\nMaybe<void> CheckLearningRateDataType(user_op::InferContext* ctx) {\n  if (ctx->has_input(\"learning_rate\", 0)) {\n    const user_op::TensorDesc& learning_rate = ctx->InputTensorDesc(\"learning_rate\", 0);\n    JUST(CheckScalarDataType(&learning_rate, DataType::kFloat));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckIndexedSlicesModelDiffDesc(const user_op::TensorDesc* model,\n                                            const user_op::TensorDesc* model_diff_indices,\n                                            const user_op::TensorDesc* model_diff_values) {\n  const int64_t num_indices_axes = model_diff_indices->shape().NumAxes();\n  const int64_t num_values_axes = model_diff_values->shape().NumAxes();\n  CHECK_GE_OR_RETURN(num_values_axes, num_indices_axes);\n  FOR_RANGE(int64_t, i, 0, num_indices_axes) {\n    CHECK_EQ_OR_RETURN(model_diff_values->shape().At(i), model_diff_indices->shape().At(i));\n  }\n  const int64_t num_model_axes = model->shape().NumAxes();\n  CHECK_EQ_OR_RETURN(num_model_axes, num_values_axes - num_indices_axes + 1);\n  FOR_RANGE(int64_t, i, 1, num_model_axes) {\n    CHECK_EQ_OR_RETURN(model->shape().At(i),\n                       model_diff_values->shape().At(num_indices_axes + i - 1));\n  }\n  return Maybe<void>::Ok();\n}\nMaybe<void> CheckIndexedSlicesModelDiffDataType(const user_op::TensorDesc* model,\n                                                const user_op::TensorDesc* model_diff_indices,\n                                                const user_op::TensorDesc* model_diff_values) {\n  CHECK_OR_RETURN(IsIndexDataType(model_diff_indices->data_type()));\n  CHECK_EQ_OR_RETURN(model->data_type(), model_diff_values->data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(model->data_type()) << \", but got \"\n      << DataType_Name(model_diff_values->data_type());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferSGDUpdateTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const Shape& shape = model.shape();\n  const user_op::TensorDesc& model_diff = ctx->InputTensorDesc(\"model_diff\", 0);\n  if (shape.NumAxes() > 0 && model_diff.shape().NumAxes() > 0) {\n    CHECK_EQ_OR_RETURN(model_diff.shape(), shape);\n  }\n  JUST(CheckLearningRateShape(ctx));\n  if (ctx->has_input(\"model_copy\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputTensorDesc(\"model_copy\", 0).shape(), shape)\n        << \"Model copy shape should be equal to Model shape. 
\";\n  }\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarShape(&scale_by_tensor));\n  }\n  return Maybe<void>::Ok();\n}\nMaybe<void> InferSGDUpdateDataType(user_op::InferContext* ctx) {\n  JUST(CheckLearningRateDataType(ctx));\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n    JUST(CheckScalarDataType(&scale_by_tensor, model.data_type()));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferIndexedSlicesSGDUpdateTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& model_diff_indices = ctx->InputTensorDesc(\"model_diff_indices\", 0);\n  const user_op::TensorDesc& model_diff_values = ctx->InputTensorDesc(\"model_diff_values\", 0);\n  JUST(CheckIndexedSlicesModelDiffDesc(&model, &model_diff_indices, &model_diff_values));\n  JUST(CheckLearningRateShape(ctx));\n  return Maybe<void>::Ok();\n}\nMaybe<void> InferIndexedSlicesSGDUpdateDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& model_diff_indices = ctx->InputTensorDesc(\"model_diff_indices\", 0);\n  const user_op::TensorDesc& model_diff_values = ctx->InputTensorDesc(\"model_diff_values\", 0);\n  JUST(CheckIndexedSlicesModelDiffDataType(&model, &model_diff_indices, &model_diff_values));\n  JUST(CheckLearningRateDataType(ctx));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferMomentumUpdateTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& model_diff = ctx->InputTensorDesc(\"model_diff\", 0);\n  CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape());\n  const user_op::TensorDesc& momentum = ctx->InputTensorDesc(\"momentum\", 0);\n  JUST(CheckShapeLike(&momentum, &model));\n  JUST(CheckLearningRateShape(ctx));\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarShape(&scale_by_tensor));\n  }\n  return Maybe<void>::Ok();\n}\nMaybe<void> InferMomentumUpdateDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& momentum = ctx->InputTensorDesc(\"momentum\", 0);\n  JUST(CheckDataTypeLike(&momentum, &model));\n  JUST(CheckLearningRateDataType(ctx));\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarDataType(&scale_by_tensor, model.data_type()));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferIndexedSlicesMomentumUpdateTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& model_diff_indices = ctx->InputTensorDesc(\"model_diff_indices\", 0);\n  const user_op::TensorDesc& model_diff_values = ctx->InputTensorDesc(\"model_diff_values\", 0);\n  JUST(CheckIndexedSlicesModelDiffDesc(&model, &model_diff_indices, &model_diff_values));\n  const user_op::TensorDesc& momentum = ctx->InputTensorDesc(\"momentum\", 0);\n  JUST(CheckShapeLike(&momentum, &model));\n  JUST(CheckLearningRateShape(ctx));\n  return Maybe<void>::Ok();\n}\nMaybe<void> 
InferIndexedSlicesMomentumUpdateDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& model_diff_indices = ctx->InputTensorDesc(\"model_diff_indices\", 0);\n  const user_op::TensorDesc& model_diff_values = ctx->InputTensorDesc(\"model_diff_values\", 0);\n  JUST(CheckIndexedSlicesModelDiffDataType(&model, &model_diff_indices, &model_diff_values));\n  const user_op::TensorDesc& momentum = ctx->InputTensorDesc(\"momentum\", 0);\n  JUST(CheckDataTypeLike(&momentum, &model));\n  JUST(CheckLearningRateDataType(ctx));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferAdamUpdateTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const Shape& shape = model.shape();\n  const user_op::TensorDesc& model_diff = ctx->InputTensorDesc(\"model_diff\", 0);\n  CHECK_EQ_OR_RETURN(model_diff.shape(), shape);\n  const user_op::TensorDesc& m = ctx->InputTensorDesc(\"m\", 0);\n  JUST(CheckShapeLike(&m, &model));\n  const user_op::TensorDesc& v = ctx->InputTensorDesc(\"v\", 0);\n  JUST(CheckShapeLike(&v, &model));\n  JUST(CheckLearningRateShape(ctx));\n  if (ctx->has_input(\"model_copy\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputTensorDesc(\"model_copy\", 0).shape(), shape)\n        << \"Model copy shape should be equal to Model shape. \";\n  }\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarShape(&scale_by_tensor));\n  }\n  return Maybe<void>::Ok();\n}\nMaybe<void> InferAdamUpdateDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& m = ctx->InputTensorDesc(\"m\", 0);\n  JUST(CheckDataTypeLike(&m, &model));\n  const user_op::TensorDesc& v = ctx->InputTensorDesc(\"v\", 0);\n  JUST(CheckDataTypeLike(&v, &model));\n  JUST(CheckLearningRateDataType(ctx));\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarDataType(&scale_by_tensor, model.data_type()));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferAdagradUpdateTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const Shape& shape = model.shape();\n  const user_op::TensorDesc& model_diff = ctx->InputTensorDesc(\"model_diff\", 0);\n  CHECK_EQ_OR_RETURN(model_diff.shape(), shape);\n  const user_op::TensorDesc& sum = ctx->InputTensorDesc(\"sum\", 0);\n  JUST(CheckShapeLike(&sum, &model));\n  JUST(CheckLearningRateShape(ctx));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferAdagradUpdateDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& sum = ctx->InputTensorDesc(\"sum\", 0);\n  JUST(CheckDataTypeLike(&sum, &model));\n  JUST(CheckLearningRateDataType(ctx));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferIndexedSlicesAdamUpdateTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& model_diff_indices = ctx->InputTensorDesc(\"model_diff_indices\", 0);\n  const user_op::TensorDesc& model_diff_values = ctx->InputTensorDesc(\"model_diff_values\", 0);\n  JUST(CheckIndexedSlicesModelDiffDesc(&model, &model_diff_indices, &model_diff_values));\n  
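// e.g. model (V, D), model_diff_indices (N,), model_diff_values (N, D): the layout check above\n  // requires values dims = indices dims + model dims minus the gathered leading axis.\n  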
JUST(CheckLearningRateShape(ctx));\n  return Maybe<void>::Ok();\n}\nMaybe<void> InferIndexedSlicesAdamUpdateDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& model_diff_indices = ctx->InputTensorDesc(\"model_diff_indices\", 0);\n  const user_op::TensorDesc& model_diff_values = ctx->InputTensorDesc(\"model_diff_values\", 0);\n  JUST(CheckIndexedSlicesModelDiffDataType(&model, &model_diff_indices, &model_diff_values));\n  JUST(CheckLearningRateDataType(ctx));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferLambUpdateTensorDesc(user_op::InferContext* ctx) {\n  const float beta1 = ctx->Attr<float>(\"beta1\");\n  const float beta2 = ctx->Attr<float>(\"beta2\");\n  CHECK_GE_OR_RETURN(beta1, 0);\n  CHECK_LT_OR_RETURN(beta1, 1);\n  CHECK_GE_OR_RETURN(beta2, 0);\n  CHECK_LT_OR_RETURN(beta2, 1);\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n\n  const Shape& shape = model.shape();\n  const user_op::TensorDesc& model_diff = ctx->InputTensorDesc(\"model_diff\", 0);\n  CHECK_EQ_OR_RETURN(model_diff.shape(), shape);\n  const user_op::TensorDesc& m = ctx->InputTensorDesc(\"m\", 0);\n  JUST(CheckShapeLike(&m, &model));\n  const user_op::TensorDesc& v = ctx->InputTensorDesc(\"v\", 0);\n  JUST(CheckShapeLike(&v, &model));\n  JUST(CheckLearningRateShape(ctx));\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarShape(&scale_by_tensor));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferLambUpdateDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& m = ctx->InputTensorDesc(\"m\", 0);\n  JUST(CheckDataTypeLike(&m, &model));\n  const user_op::TensorDesc& v = ctx->InputTensorDesc(\"v\", 0);\n  JUST(CheckDataTypeLike(&v, &model));\n  JUST(CheckLearningRateDataType(ctx));\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarDataType(&scale_by_tensor, model.data_type()));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferFtrlUpdateTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const Shape& shape = model.shape();\n  const user_op::TensorDesc& model_diff = ctx->InputTensorDesc(\"model_diff\", 0);\n  CHECK_EQ_OR_RETURN(model_diff.shape(), shape)\n      << \"Model Diff shape is not consistent with Weight shape. 
\";\n  const user_op::TensorDesc& accumulate = ctx->InputTensorDesc(\"accumulate\", 0);\n  const user_op::TensorDesc& z = ctx->InputTensorDesc(\"z\", 0);\n  JUST(CheckShapeLike(&accumulate, &model));\n  JUST(CheckShapeLike(&z, &model));\n  JUST(CheckLearningRateShape(ctx));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferFtrlUpdateDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& accumulate = ctx->InputTensorDesc(\"accumulate\", 0);\n  const user_op::TensorDesc& z = ctx->InputTensorDesc(\"z\", 0);\n  JUST(CheckDataTypeLike(&accumulate, &model));\n  JUST(CheckDataTypeLike(&z, &model));\n  JUST(CheckLearningRateDataType(ctx));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferAdadeltaUpdateTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& model_diff = ctx->InputTensorDesc(\"model_diff\", 0);\n  const user_op::TensorDesc& square_avgs = ctx->InputTensorDesc(\"square_avgs\", 0);\n  const user_op::TensorDesc& acc_deltas = ctx->InputTensorDesc(\"acc_deltas\", 0);\n  JUST(CheckShapeLike(&model_diff, &model));\n  JUST(CheckShapeLike(&square_avgs, &model));\n  JUST(CheckShapeLike(&acc_deltas, &model));\n  JUST(CheckLearningRateShape(ctx));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferAdadeltaUpdateDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& square_avgs = ctx->InputTensorDesc(\"square_avgs\", 0);\n  const user_op::TensorDesc& acc_deltas = ctx->InputTensorDesc(\"acc_deltas\", 0);\n  JUST(CheckDataTypeLike(&square_avgs, &model));\n  JUST(CheckDataTypeLike(&acc_deltas, &model));\n  JUST(CheckLearningRateDataType(ctx));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SetInputArgModifierMutable(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                       const std::string& arg_name, int32_t arg_index) {\n  user_op::InputArgModifier* arg_modifier = GetInputArgModifierFn(arg_name, arg_index);\n  CHECK_NOTNULL_OR_RETURN(arg_modifier);\n  arg_modifier->set_is_mutable(true);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AdamInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                 const user_op::UserOpConfWrapper& conf) {\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", 0));\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"m\", 0));\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"v\", 0));\n  if (conf.has_input(\"max_v\", 0)) {\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"max_v\", 0));\n  }\n  if (conf.has_input(\"model_copy\", 0)) {\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model_copy\", 0));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AdagradInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                    const user_op::UserOpConfWrapper& conf) {\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", 0));\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"sum\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LambInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                 const user_op::UserOpConfWrapper& conf) {\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", 0));\n  
JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"m\", 0));\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"v\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SgdInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                const user_op::UserOpConfWrapper& conf) {\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", 0));\n  if (conf.has_input(\"model_copy\", 0)) {\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model_copy\", 0));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> IndexedSlicesSgdInputArgModifyFn(\n    const user_op::GetInputArgModifier& GetInputArgModifierFn,\n    const user_op::UserOpConfWrapper& conf) {\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MomentumInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                     const user_op::UserOpConfWrapper& conf) {\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", 0));\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"momentum\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> IndexedSlicesMomentumInputArgModifyFn(\n    const user_op::GetInputArgModifier& GetInputArgModifierFn,\n    const user_op::UserOpConfWrapper& conf) {\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", 0));\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"momentum\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> RmsPropUpdateInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                          const user_op::UserOpConfWrapper& conf) {\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", 0));\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"mean_square\", 0));\n  if (conf.attr<bool>(\"centered\")) {\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"mean_gradient\", 0));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> LarsUpdateInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                       const user_op::UserOpConfWrapper& conf) {\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", 0));\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"momentum\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FtrlInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                 const user_op::UserOpConfWrapper& conf) {\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", 0));\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"accumulate\", 0));\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"z\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AdadeltaInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                     const user_op::UserOpConfWrapper& conf) {\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", 0));\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"square_avgs\", 0));\n  JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"acc_deltas\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferRmsPropUpdateTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n\n  const Shape& shape = model.shape();\n  const user_op::TensorDesc& model_diff = 
ctx->InputTensorDesc(\"model_diff\", 0);\n  CHECK_EQ_OR_RETURN(model_diff.shape(), shape);\n  const user_op::TensorDesc& mean_square = ctx->InputTensorDesc(\"mean_square\", 0);\n  JUST(CheckShapeLike(&mean_square, &model));\n  JUST(CheckLearningRateShape(ctx));\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarShape(&scale_by_tensor));\n  }\n  if (ctx->Attr<bool>(\"centered\")) {\n    CHECK_OR_RETURN(ctx->has_input(\"mean_gradient\", 0));\n    const user_op::TensorDesc& mean_gradient = ctx->InputTensorDesc(\"mean_gradient\", 0);\n    JUST(CheckShapeLike(&mean_gradient, &model));\n  } else {\n    CHECK_OR_RETURN(!ctx->has_input(\"mean_gradient\", 0));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferRmsPropUpdateDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& mean_square = ctx->InputTensorDesc(\"mean_square\", 0);\n  JUST(CheckDataTypeLike(&mean_square, &model));\n  JUST(CheckLearningRateDataType(ctx));\n  const DataType data_type = model.data_type();\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarDataType(&scale_by_tensor, data_type));\n  }\n  if (ctx->Attr<bool>(\"centered\")) {\n    CHECK_OR_RETURN(ctx->has_input(\"mean_gradient\", 0));\n    const user_op::TensorDesc& mean_gradient = ctx->InputTensorDesc(\"mean_gradient\", 0);\n    JUST(CheckDataTypeLike(&mean_gradient, &model));\n  }\n  return Maybe<void>::Ok();\n}\nMaybe<void> InferLarsUpdateTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n\n  const Shape& shape = model.shape();\n  const user_op::TensorDesc& model_diff = ctx->InputTensorDesc(\"model_diff\", 0);\n  CHECK_EQ_OR_RETURN(model_diff.shape(), shape);\n  const user_op::TensorDesc& momentum = ctx->InputTensorDesc(\"momentum\", 0);\n  JUST(CheckShapeLike(&momentum, &model));\n  JUST(CheckLearningRateShape(ctx));\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarShape(&scale_by_tensor));\n  }\n  return Maybe<void>::Ok();\n}\nMaybe<void> InferLarsUpdateDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", 0);\n  const user_op::TensorDesc& momentum = ctx->InputTensorDesc(\"momentum\", 0);\n  JUST(CheckDataTypeLike(&momentum, &model));\n  JUST(CheckLearningRateDataType(ctx));\n  const DataType data_type = model.data_type();\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarDataType(&scale_by_tensor, data_type));\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> SgdUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferSGDUpdateTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> SgdUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> SgdUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"model\", 0);\n  FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) {\n    auto builder = ctx->NewBuilder()\n                       .Broadcast(ctx->inputs())\n               
        .Split(user_op::OpArg(\"model\", 0), axis)\n                       .Split(user_op::OpArg(\"model_diff\", 0), axis);\n    if (ctx->user_op_conf().has_input(\"model_copy\", 0)) {\n      builder.Split(user_op::OpArg(\"model_copy\", 0), axis);\n    }\n    builder.Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> SgdUpdateOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return SgdInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> SgdUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  return InferSGDUpdateDataType(ctx);\n}\n\n/* static */ Maybe<void> IndexedSlicesSgdUpdateOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferIndexedSlicesSGDUpdateTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> IndexedSlicesSgdUpdateOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> IndexedSlicesSgdUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"model\", 0);\n  const user_op::TensorDesc& model_diff_indices =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"model_diff_indices\", 0);\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"learning_rate\", 0))\n      .Broadcast(user_op::OpArg(\"model_diff_indices\", 0))\n      .Broadcast(user_op::OpArg(\"model_diff_values\", 0))\n      .Split(user_op::OpArg(\"model\", 0), 0)\n      .Build();\n  FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Broadcast(user_op::OpArg(\"learning_rate\", 0))\n        .Broadcast(user_op::OpArg(\"model_diff_indices\", 0))\n        .Split(user_op::OpArg(\"model_diff_values\", 0), model_diff_indices.shape().NumAxes() + i - 1)\n        .Split(user_op::OpArg(\"model\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> IndexedSlicesSgdUpdateOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return IndexedSlicesSgdInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> IndexedSlicesSgdUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  return InferIndexedSlicesSGDUpdateDataType(ctx);\n}\n\n/* static */ Maybe<void> MomentumUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferMomentumUpdateTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> MomentumUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MomentumUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"model\", 0);\n  FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Broadcast(ctx->inputs())\n        .Split(user_op::OpArg(\"model\", 0), axis)\n        .Split(user_op::OpArg(\"model_diff\", 0), axis)\n        .Split(user_op::OpArg(\"momentum\", 0), axis)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MomentumUpdateOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return MomentumInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> MomentumUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  return InferMomentumUpdateDataType(ctx);\n}\n\n/* static */ 
Maybe<void> IndexedSlicesMomentumUpdateOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferIndexedSlicesMomentumUpdateTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> IndexedSlicesMomentumUpdateOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> IndexedSlicesMomentumUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"model\", 0);\n  const user_op::TensorDesc& model_diff_indices =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"model_diff_indices\", 0);\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"learning_rate\", 0))\n      .Broadcast(user_op::OpArg(\"model_diff_indices\", 0))\n      .Broadcast(user_op::OpArg(\"model_diff_values\", 0))\n      .Split(user_op::OpArg(\"model\", 0), 0)\n      .Split(user_op::OpArg(\"momentum\", 0), 0)\n      .Build();\n  FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Broadcast(user_op::OpArg(\"learning_rate\", 0))\n        .Broadcast(user_op::OpArg(\"model_diff_indices\", 0))\n        .Split(user_op::OpArg(\"model_diff_values\", 0), model_diff_indices.shape().NumAxes() + i - 1)\n        .Split(user_op::OpArg(\"model\", 0), i)\n        .Split(user_op::OpArg(\"momentum\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> IndexedSlicesMomentumUpdateOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return IndexedSlicesMomentumInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> IndexedSlicesMomentumUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  return InferIndexedSlicesMomentumUpdateDataType(ctx);\n}\n\n/* static */ Maybe<void> AdamUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferAdamUpdateTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> AdamUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> AdamUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"model\", 0);\n  FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) {\n    std::vector<user_op::OpArg> split_args;\n    split_args.emplace_back(\"model\", 0);\n    split_args.emplace_back(\"model_diff\", 0);\n    split_args.emplace_back(\"m\", 0);\n    split_args.emplace_back(\"v\", 0);\n    if (ctx->user_op_conf().has_input(\"max_v\", 0)) { split_args.emplace_back(\"max_v\", 0); }\n    auto builder = ctx->NewBuilder().Broadcast(ctx->inputs()).Split(split_args, axis);\n    if (ctx->user_op_conf().has_input(\"model_copy\", 0)) {\n      builder.Split(user_op::OpArg(\"model_copy\", 0), axis);\n    }\n    builder.Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> AdamUpdateOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return AdamInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> AdamUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  return InferAdamUpdateDataType(ctx);\n}\n\n/* static */ Maybe<void> AdagradUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferAdagradUpdateTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> AdagradUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return 
InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> AdagradUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"model\", 0);\n  FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Broadcast(ctx->inputs())\n        .Split(user_op::OpArg(\"model\", 0), axis)\n        .Split(user_op::OpArg(\"model_diff\", 0), axis)\n        .Split(user_op::OpArg(\"sum\", 0), axis)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> AdagradUpdateOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return AdagradInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> AdagradUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  return InferAdagradUpdateDataType(ctx);\n}\n\n/* static */ Maybe<void> IndexedSlicesAdamUpdateOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferIndexedSlicesAdamUpdateTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> IndexedSlicesAdamUpdateOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> IndexedSlicesAdamUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"model\", 0);\n  const user_op::TensorDesc& model_diff_indices =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"model_diff_indices\", 0);\n  std::vector<user_op::OpArg> broadcast_args;\n  broadcast_args.emplace_back(\"learning_rate\", 0);\n  broadcast_args.emplace_back(\"model_diff_indices\", 0);\n\n  std::vector<user_op::OpArg> split_args;\n  split_args.emplace_back(\"model\", 0);\n  split_args.emplace_back(\"m\", 0);\n  split_args.emplace_back(\"v\", 0);\n  if (ctx->user_op_conf().has_input(\"max_v\", 0)) { split_args.emplace_back(\"max_v\", 0); }\n\n  ctx->NewBuilder()\n      .Broadcast(broadcast_args)\n      .Broadcast(user_op::OpArg(\"model_diff_values\", 0))\n      .Split(split_args, 0)\n      .Build();\n\n  FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Broadcast(broadcast_args)\n        .Split(user_op::OpArg(\"model_diff_values\", 0), model_diff_indices.shape().NumAxes() + i - 1)\n        .Split(split_args, i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> IndexedSlicesAdamUpdateOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return AdamInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> IndexedSlicesAdamUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  return InferIndexedSlicesAdamUpdateDataType(ctx);\n}\n\n/* static */ Maybe<void> LambUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLambUpdateTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> LambUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> LambUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> LambUpdateOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return LambInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> LambUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  
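// InferLambUpdateDataType lives with the other update helpers in the anonymous namespace above.\n  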
return InferLambUpdateDataType(ctx);\n}\n\n/* static */ Maybe<void> AdamBiasCorrectionFactorOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"train_step\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> AdamBiasCorrectionFactorOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> AdamBiasCorrectionFactorOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> AdamBiasCorrectionFactorOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, DataType::kFloat);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> RmspropUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferRmsPropUpdateTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> RmspropUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> RmspropUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"model\", 0);\n  bool centered = ctx->Attr<bool>(\"centered\");\n  FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) {\n    if (centered) {\n      ctx->NewBuilder()\n          .Broadcast(ctx->inputs())\n          .Split(user_op::OpArg(\"model\", 0), axis)\n          .Split(user_op::OpArg(\"model_diff\", 0), axis)\n          .Split(user_op::OpArg(\"mean_square\", 0), axis)\n          .Split(user_op::OpArg(\"mean_gradient\", 0), axis)\n          .Build();\n    } else {\n      ctx->NewBuilder()\n          .Broadcast(ctx->inputs())\n          .Split(user_op::OpArg(\"model\", 0), axis)\n          .Split(user_op::OpArg(\"model_diff\", 0), axis)\n          .Split(user_op::OpArg(\"mean_square\", 0), axis)\n          .Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> RmspropUpdateOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return RmsPropUpdateInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> RmspropUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  return InferRmsPropUpdateDataType(ctx);\n}\n\n/* static */ Maybe<void> LarsUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLarsUpdateTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> LarsUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> LarsUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"model\", 0);\n  FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Broadcast(ctx->inputs())\n        .Split(user_op::OpArg(\"model\", 0), axis)\n        .Split(user_op::OpArg(\"model_diff\", 0), axis)\n        .Split(user_op::OpArg(\"momentum\", 0), axis)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> LarsUpdateOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return LarsUpdateInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> LarsUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  return InferLarsUpdateDataType(ctx);\n}\n\n/* static */ Maybe<void> FtrlUpdateOp::ModifyInputArg(\n    
const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return FtrlInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> FtrlUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferFtrlUpdateTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> FtrlUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> FtrlUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"model\", 0);\n  FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Broadcast(ctx->inputs())\n        .Split(user_op::OpArg(\"model\", 0), axis)\n        .Split(user_op::OpArg(\"model_diff\", 0), axis)\n        .Split(user_op::OpArg(\"accumulate\", 0), axis)\n        .Split(user_op::OpArg(\"z\", 0), axis)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> FtrlUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  return InferFtrlUpdateDataType(ctx);\n}\n\n/* static */ Maybe<void> AdadeltaUpdateOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return AdadeltaInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> AdadeltaUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferAdadeltaUpdateTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> AdadeltaUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> AdadeltaUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"model\", 0);\n  FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Broadcast(ctx->inputs())\n        .Split(user_op::OpArg(\"model\", 0), axis)\n        .Split(user_op::OpArg(\"model_diff\", 0), axis)\n        .Split(user_op::OpArg(\"square_avgs\", 0), axis)\n        .Split(user_op::OpArg(\"acc_deltas\", 0), axis)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> AdadeltaUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  return InferAdadeltaUpdateDataType(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/moving_average_min_max_observer_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> MovingAverageMinMaxObserverOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const Shape& moving_max_shape = ctx->InputShape(\"moving_max\", 0);\n  const Shape& moving_min_shape = ctx->InputShape(\"moving_min\", 0);\n  const Shape& current_train_step = ctx->InputShape(\"current_train_step\", 0);\n\n  // NOTE(Liang Depeng): for now only support per-layer quantization\n  // TODO(Liang Depeng): depthwise convolution support per-channel quantization\n  CHECK_OR_RETURN(moving_max_shape.NumAxes() == 1 && moving_max_shape.At(0) == 1);\n  CHECK_OR_RETURN(moving_min_shape.NumAxes() == 1 && moving_min_shape.At(0) == 1);\n\n  CHECK_OR_RETURN(current_train_step.NumAxes() == 1 && current_train_step.At(0) == 1);\n\n  ctx->SetOutputShape(\"scale\", 0, Shape({1}));\n  ctx->SetOutputShape(\"zero_point\", 0, Shape({1}));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> MovingAverageMinMaxObserverOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MovingAverageMinMaxObserverOp::GetSbp(user_op::SbpContext* ctx) {\n  // NOTE(Liang Depeng): all inputs need to be broadcast in order to accuratly calculate the\n  // global scale and zero_point\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MovingAverageMinMaxObserverOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* in = GetInputArgModifierFn(\"in\", 0);\n  CHECK_OR_RETURN(in != nullptr);\n  in->set_requires_grad(false);\n\n  user_op::InputArgModifier* current_train_step = GetInputArgModifierFn(\"current_train_step\", 0);\n  CHECK_OR_RETURN(current_train_step != nullptr);\n  current_train_step->set_requires_grad(false);\n\n  user_op::InputArgModifier* moving_max = GetInputArgModifierFn(\"moving_max\", 0);\n  CHECK_OR_RETURN(moving_max != nullptr);\n  moving_max->set_requires_grad(false);\n  moving_max->set_is_mutable(true);\n\n  user_op::InputArgModifier* moving_min = GetInputArgModifierFn(\"moving_min\", 0);\n  CHECK_OR_RETURN(moving_min != nullptr);\n  moving_min->set_requires_grad(false);\n  moving_min->set_is_mutable(true);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MovingAverageMinMaxObserverOp::CheckAttr(\n    const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& op_conf) {\n  int32_t quantization_bit = op_conf.attr<int32_t>(\"quantization_bit\");\n  CHECK_GT_OR_RETURN(quantization_bit, 1);\n  CHECK_LE_OR_RETURN(quantization_bit, 8);\n\n  std::string quantization_scheme = op_conf.attr<std::string>(\"quantization_scheme\");\n  CHECK_OR_RETURN(quantization_scheme == \"symmetric\" || quantization_scheme == \"affine\");\n\n  int64_t stop_update_after_iters 
= op_conf.attr<int64_t>(\"stop_update_after_iters\");\n  CHECK_GT_OR_RETURN(stop_update_after_iters, 0);\n\n  std::string quantization_formula = op_conf.attr<std::string>(\"quantization_formula\");\n  CHECK_OR_RETURN(quantization_formula == \"google\" || quantization_formula == \"cambricon\");\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MovingAverageMinMaxObserverOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"scale\", 0, ctx->InputDType(\"in\", 0));\n  ctx->SetOutputDType(\"zero_point\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/multi_reduce_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferMultiReduceOpShape(user_op::InferContext* ctx) {\n  CHECK_GT_OR_RETURN(ctx->input_size(\"x\"), 0) << ctx->op_name() << \"must have at least 1 input\";\n  ctx->SetOutputShape(\"y\", 0, Shape({}));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferMultiReduceOpDataType(user_op::InferContext* ctx) {\n  DataType x_0_dtype = ctx->InputDType(\"x\", 0);\n  for (size_t i = 1; i < ctx->input_size(\"x\"); ++i) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"x\", i), x_0_dtype)\n        << ctx->op_name() << \": the \" << i << \" th input has the different data type with others\";\n  }\n  ctx->SetOutputDType(\"y\", 0, x_0_dtype);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetMultiReduceOpSbp(user_op::SbpContext* ctx) {\n  const auto& x_0 = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  int64_t min_num_axes = x_0.shape().NumAxes();\n  for (size_t i = 1; i < ctx->user_op_conf().input_size(\"x\"); ++i) {\n    const auto& x_i = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", i);\n    min_num_axes = std::min(min_num_axes, x_i.shape().NumAxes());\n  }\n  for (int64_t i = 0; i < min_num_axes; ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(user_op::OpArg(\"y\", 0)).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferLocalMultiReduceOpLogicalShape(user_op::InferContext* ctx) {\n  CHECK_GT_OR_RETURN(ctx->input_size(\"x\"), 0) << ctx->op_name() << \"must have at least 1 input\";\n  const NdSbp& any_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"x\", 0);\n  for (int32_t i = 1; i < ctx->input_size(\"x\"); ++i) {\n    const NdSbp& input_i_sbp = ctx->NdSbp4ArgNameAndIndex(\"x\", i);\n    CHECK_OR_RETURN(input_i_sbp == any_nd_sbp)\n        << ctx->op_name() << \": the \" << i << \" th arg has the different sbp with others, \"\n        << NdSbpToString(input_i_sbp) << \" vs. 
\" << NdSbpToString(any_nd_sbp);\n  }\n  auto rank_mesh = ctx->parallel_desc().hierarchy();\n  CHECK_EQ_OR_RETURN(rank_mesh->NumAxes(), any_nd_sbp.sbp_parallel_size())\n      << ctx->op_name() << \": ndim of ranks of \" << *JUST(PlacementToString(ctx->parallel_desc()))\n      << \" is mismatched with the size of sbp \" << NdSbpToString(any_nd_sbp);\n  int64_t split_num = 1;\n  for (int64_t i = 0; i < rank_mesh->NumAxes(); ++i) {\n    if (any_nd_sbp.sbp_parallel(i).has_split_parallel()) { split_num *= rank_mesh->At(i); }\n  }\n  ctx->SetOutputShape(\"y\", 0, Shape({split_num}));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferLocalMultiReduceOpPhysicalShape(user_op::InferContext* ctx) {\n  CHECK_GT_OR_RETURN(ctx->input_size(\"x\"), 0) << ctx->op_name() << \"must have at least 1 input\";\n  ctx->SetOutputShape(\"y\", 0, Shape({1}));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetLocalMultiReduceOpSbp(user_op::SbpContext* ctx) {\n  const auto& x_0 = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  int64_t min_num_axes = x_0.shape().NumAxes();\n  for (size_t i = 1; i < ctx->user_op_conf().input_size(\"x\"); ++i) {\n    const auto& x_i = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", i);\n    min_num_axes = std::min(min_num_axes, x_i.shape().NumAxes());\n  }\n  for (int64_t i = 0; i < min_num_axes; ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg(\"y\", 0), 0).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n#define DEFINE_MULTI_REDUCE_OP_METHODS(op)                                 \\\n  Maybe<void> op##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferMultiReduceOpShape(ctx);                                   \\\n  }                                                                        \\\n  Maybe<void> op##Op::InferDataType(user_op::InferContext* ctx) {          \\\n    return InferMultiReduceOpDataType(ctx);                                \\\n  }                                                                        \\\n  Maybe<void> op##Op::GetSbp(user_op::SbpContext* ctx) { return GetMultiReduceOpSbp(ctx); }\n\nDEFINE_MULTI_REDUCE_OP_METHODS(MultiReduceSumPowAbs)\nDEFINE_MULTI_REDUCE_OP_METHODS(MultiReduceMaxAbs)\nDEFINE_MULTI_REDUCE_OP_METHODS(MultiReduceMinAbs)\n#undef DEFINE_MULTI_REDUCE_OP_METHODS\n\n#define DEFINE_LOCAL_MULTI_REDUCE_OP_METHODS(op)                            \\\n  Maybe<void> op##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return InferLocalMultiReduceOpLogicalShape(ctx);                        \\\n  }                                                                         \\\n  Maybe<void> op##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferLocalMultiReduceOpPhysicalShape(ctx);                       \\\n  }                                                                         \\\n  Maybe<void> op##Op::InferDataType(user_op::InferContext* ctx) {           \\\n    return InferMultiReduceOpDataType(ctx);                                 \\\n  }                                                                         \\\n  Maybe<void> op##Op::GetSbp(user_op::SbpContext* ctx) { return GetLocalMultiReduceOpSbp(ctx); }\n\nDEFINE_LOCAL_MULTI_REDUCE_OP_METHODS(LocalMultiReduceMaxAbs)\nDEFINE_LOCAL_MULTI_REDUCE_OP_METHODS(LocalMultiReduceMinAbs)\n#undef DEFINE_LOCAL_MULTI_REDUCE_OP_METHODS\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/multi_tensor_model_update_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/infer_util.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/framework/user_op_registry.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> CheckShapeLike(const user_op::TensorDesc* tensor_desc,\n                           const user_op::TensorDesc* like) {\n  CHECK_EQ_OR_RETURN(tensor_desc->shape(), like->shape())\n      << \"Tensordesc shape should be equal to Like shape. \";\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckDataTypeLike(const user_op::TensorDesc* tensor_desc,\n                              const user_op::TensorDesc* like) {\n  CHECK_EQ_OR_RETURN(tensor_desc->data_type(), like->data_type())\n      << \"Tensordesc DataType should be equal to Like DataType. \";\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckScalarShape(const user_op::TensorDesc* tensor_desc) {\n  CHECK_OR_RETURN(tensor_desc->shape().NumAxes() == 0\n                  || (tensor_desc->shape().NumAxes() == 1 && tensor_desc->shape().At(0) == 1))\n      << tensor_desc->shape().DebugStr();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckScalarDataType(const user_op::TensorDesc* tensor_desc, const DataType data_type) {\n  CHECK_EQ_OR_RETURN(tensor_desc->data_type(), data_type)\n      << \"TensorDesc DataType should be equal to Scalar DataType. \";\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckLearningRateShape(user_op::InferContext* ctx) {\n  if (ctx->has_input(\"learning_rate\", 0)) {\n    const user_op::TensorDesc& learning_rate = ctx->InputTensorDesc(\"learning_rate\", 0);\n    JUST(CheckScalarShape(&learning_rate));\n  }\n  return Maybe<void>::Ok();\n}\nMaybe<void> CheckLearningRateDataType(user_op::InferContext* ctx) {\n  if (ctx->has_input(\"learning_rate\", 0)) {\n    const user_op::TensorDesc& learning_rate = ctx->InputTensorDesc(\"learning_rate\", 0);\n    JUST(CheckScalarDataType(&learning_rate, DataType::kFloat));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SetInputArgModifierMutable(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                       const std::string& arg_name, int32_t arg_index) {\n  user_op::InputArgModifier* arg_modifier = GetInputArgModifierFn(arg_name, arg_index);\n  CHECK_NOTNULL_OR_RETURN(arg_modifier) << \"Arg Modifier should not be null. 
\";\n  arg_modifier->set_is_mutable(true);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferSGDUpdateTensorDesc(user_op::InferContext* ctx) {\n  const int64_t weight_size = ctx->input_size(\"model\");\n  for (int i = 0; i < weight_size; i++) {\n    const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", i);\n    const user_op::TensorDesc& model_diff = ctx->InputTensorDesc(\"model_diff\", i);\n    CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape())\n        << \"Model Diff shape should be equal to Model shape. \";\n  }\n  JUST(CheckLearningRateShape(ctx));\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarShape(&scale_by_tensor));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferSGDUpdateDataType(user_op::InferContext* ctx) {\n  JUST(CheckLearningRateDataType(ctx));\n  const user_op::TensorDesc& first_model_desc = ctx->InputTensorDesc(\"model\", 0);\n  const int64_t input_size = ctx->input_size(\"model\");\n  for (int64_t i = 0; i < input_size; i++) {\n    const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", i);\n    CHECK_EQ(model.data_type(), first_model_desc.data_type()) << \"Model DataType should be equal. \";\n  }\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarDataType(&scale_by_tensor, first_model_desc.data_type()));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SgdInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                const user_op::UserOpConfWrapper& conf) {\n  for (int64_t i = 0; i < conf.input_size(\"model\"); i++) {\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", i));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferMomentumUpdateTensorDesc(user_op::InferContext* ctx) {\n  const int64_t weight_size = ctx->input_size(\"model\");\n  for (int i = 0; i < weight_size; i++) {\n    const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", i);\n    const user_op::TensorDesc& model_diff = ctx->InputTensorDesc(\"model_diff\", i);\n    const user_op::TensorDesc& momentum_buf = ctx->InputTensorDesc(\"momentum_buf\", i);\n    CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape())\n        << \"Model Diff shape should be equal to Model shape. \";\n    CHECK_EQ_OR_RETURN(momentum_buf.shape(), model.shape())\n        << \"Momentum buf shape should be equal to Model shape. \";\n  }\n  JUST(CheckLearningRateShape(ctx));\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarShape(&scale_by_tensor));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferMomentumUpdateDataType(user_op::InferContext* ctx) {\n  JUST(CheckLearningRateDataType(ctx));\n  const user_op::TensorDesc& first_model_desc = ctx->InputTensorDesc(\"model\", 0);\n  const int64_t input_size = ctx->input_size(\"model\");\n  for (int64_t i = 0; i < input_size; i++) {\n    const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", i);\n    const user_op::TensorDesc& momentum_buf = ctx->InputTensorDesc(\"momentum_buf\", i);\n    CHECK_EQ(model.data_type(), first_model_desc.data_type()) << \"Model DataType should be equal. \";\n    CHECK_EQ(momentum_buf.data_type(), first_model_desc.data_type())\n        << \"Momentum buf DataType should be equal. 
\";\n  }\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarDataType(&scale_by_tensor, first_model_desc.data_type()));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MomentumInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                     const user_op::UserOpConfWrapper& conf) {\n  for (int64_t i = 0; i < conf.input_size(\"model\"); i++) {\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", i));\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"momentum_buf\", i));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferAdamUpdateTensorDesc(user_op::InferContext* ctx) {\n  const int64_t weight_size = ctx->input_size(\"model\");\n  for (int i = 0; i < weight_size; i++) {\n    const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", i);\n    const user_op::TensorDesc& model_diff = ctx->InputTensorDesc(\"model_diff\", i);\n    const user_op::TensorDesc& m = ctx->InputTensorDesc(\"m\", i);\n    const user_op::TensorDesc& v = ctx->InputTensorDesc(\"v\", i);\n\n    CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape())\n        << \"Model Diff shape should be equal to Model shape. \";\n    CHECK_EQ_OR_RETURN(m.shape(), model.shape()) << \"m shape should be equal to Model shape. \";\n    CHECK_EQ_OR_RETURN(v.shape(), model.shape()) << \"v shape should be equal to Model shape. \";\n  }\n  JUST(CheckLearningRateShape(ctx));\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarShape(&scale_by_tensor));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferAdamUpdateDataType(user_op::InferContext* ctx) {  // todo\n  JUST(CheckLearningRateDataType(ctx));\n  const user_op::TensorDesc& first_model_desc = ctx->InputTensorDesc(\"model\", 0);\n  const int64_t input_size = ctx->input_size(\"model\");\n  for (int64_t i = 0; i < input_size; i++) {\n    const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", i);\n    const user_op::TensorDesc& m = ctx->InputTensorDesc(\"m\", i);\n    const user_op::TensorDesc& v = ctx->InputTensorDesc(\"v\", i);\n    CHECK_EQ(model.data_type(), first_model_desc.data_type()) << \"Model DataType should be equal. \";\n    CHECK_EQ(m.data_type(), first_model_desc.data_type()) << \"m DataType should be equal. \";\n    CHECK_EQ(v.data_type(), first_model_desc.data_type()) << \"v DataType should be equal. 
\";\n  }\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarDataType(&scale_by_tensor, first_model_desc.data_type()));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AdamInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                 const user_op::UserOpConfWrapper& conf) {\n  for (int64_t i = 0; i < conf.input_size(\"model\"); i++) {\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", i));\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"m\", i));\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"v\", i));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferSGDUpdateWithCastTensorDesc(user_op::InferContext* ctx) {\n  const int64_t weight_size = ctx->input_size(\"model\");\n  for (int i = 0; i < weight_size; i++) {\n    const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", i);\n    const user_op::TensorDesc& model_copy = ctx->InputTensorDesc(\"model_copy\", i);\n    const user_op::TensorDesc& model_diff = ctx->InputTensorDesc(\"model_diff\", i);\n    CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape())\n        << \"Model diff shape should be equal to Model shape. \";\n    CHECK_EQ_OR_RETURN(model_copy.shape(), model.shape())\n        << \"Model copy shape should be equal to Model shape. \";\n  }\n  JUST(CheckLearningRateShape(ctx));\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarShape(&scale_by_tensor));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SgdWithCastInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                        const user_op::UserOpConfWrapper& conf) {\n  for (int64_t i = 0; i < conf.input_size(\"model\"); i++) {\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", i));\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model_copy\", i));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferMomentumUpdateWithCastTensorDesc(user_op::InferContext* ctx) {\n  const int64_t weight_size = ctx->input_size(\"model\");\n  for (int i = 0; i < weight_size; i++) {\n    const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", i);\n    const user_op::TensorDesc& model_copy = ctx->InputTensorDesc(\"model_copy\", i);\n    const user_op::TensorDesc& momentum_buf = ctx->InputTensorDesc(\"momentum_buf\", i);\n    const user_op::TensorDesc& model_diff = ctx->InputTensorDesc(\"model_diff\", i);\n    CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape())\n        << \"Model diff shape should be equal to Model shape. \";\n    CHECK_EQ_OR_RETURN(momentum_buf.shape(), model.shape())\n        << \"Momentum buf shape should be equal to Model shape. \";\n    CHECK_EQ_OR_RETURN(model_copy.shape(), model.shape())\n        << \"Model copy shape should be equal to Model shape. 
\";\n  }\n  JUST(CheckLearningRateShape(ctx));\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarShape(&scale_by_tensor));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> MomentumWithCastInputArgModifyFn(\n    const user_op::GetInputArgModifier& GetInputArgModifierFn,\n    const user_op::UserOpConfWrapper& conf) {\n  for (int64_t i = 0; i < conf.input_size(\"model\"); i++) {\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", i));\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"momentum_buf\", i));\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model_copy\", i));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferAdamUpdateWithCastTensorDesc(user_op::InferContext* ctx) {\n  const int64_t weight_size = ctx->input_size(\"model\");\n  for (int i = 0; i < weight_size; i++) {\n    const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", i);\n    const user_op::TensorDesc& model_diff = ctx->InputTensorDesc(\"model_diff\", i);\n    const user_op::TensorDesc& model_copy = ctx->InputTensorDesc(\"model_copy\", i);\n    const user_op::TensorDesc& m = ctx->InputTensorDesc(\"m\", i);\n    const user_op::TensorDesc& v = ctx->InputTensorDesc(\"v\", i);\n\n    CHECK_EQ_OR_RETURN(model_diff.shape(), model.shape())\n        << \"Model diff shape should be equal to Model shape. \";\n    CHECK_EQ_OR_RETURN(model_copy.shape(), model.shape())\n        << \"Model copy shape should be equal to Model shape. \";\n    CHECK_EQ_OR_RETURN(m.shape(), model.shape()) << \"m shape should be equal to Model shape. \";\n    CHECK_EQ_OR_RETURN(v.shape(), model.shape()) << \"v shape should be equal to Model shape. \";\n  }\n  JUST(CheckLearningRateShape(ctx));\n  if (ctx->has_input(\"scale_by_tensor\", 0)) {\n    const auto& scale_by_tensor = ctx->InputTensorDesc(\"scale_by_tensor\", 0);\n    JUST(CheckScalarShape(&scale_by_tensor));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AdamWithCastInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                         const user_op::UserOpConfWrapper& conf) {\n  for (int64_t i = 0; i < conf.input_size(\"model\"); i++) {\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", i));\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model_copy\", i));\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"m\", i));\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"v\", i));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferYoloV5WeightUpdateTensorDesc(user_op::InferContext* ctx) {\n  const int64_t weight_size = ctx->input_size(\"model\");\n  for (int i = 0; i < weight_size; i++) {\n    const user_op::TensorDesc& model_i = ctx->InputTensorDesc(\"model\", i);\n    const user_op::TensorDesc& model_update_i = ctx->InputTensorDesc(\"model_update\", i);\n    CHECK_EQ_OR_RETURN(model_update_i.shape(), model_i.shape())\n        << \"All Model shape should be equal to model_update shape.\";\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferYoloV5WeightUpdateDataType(user_op::InferContext* ctx) {\n  JUST(CheckLearningRateDataType(ctx));\n  const user_op::TensorDesc& first_model_desc = ctx->InputTensorDesc(\"model\", 0);\n  const int64_t input_size = ctx->input_size(\"model\");\n  for (int64_t i = 0; i < input_size; i++) {\n    const user_op::TensorDesc& model = ctx->InputTensorDesc(\"model\", i);\n    const 
user_op::TensorDesc& model_update_i = ctx->InputTensorDesc(\"model_update\", i);\n    CHECK_EQ_OR_RETURN(model.data_type(), first_model_desc.data_type())\n        << \"Model DataType should be equal. \";\n    CHECK_EQ_OR_RETURN(model_update_i.data_type(), first_model_desc.data_type())\n        << \"Model DataType should be equal to model_update DataType.\";\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> YoloV5WeightInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                                         const user_op::UserOpConfWrapper& conf) {\n  for (int64_t i = 0; i < conf.input_size(\"model\"); i++) {\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model\", i));\n    JUST(SetInputArgModifierMutable(GetInputArgModifierFn, \"model_update\", i));\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> MultiTensorSgdUpdateOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferSGDUpdateTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> MultiTensorSgdUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MultiTensorSgdUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MultiTensorSgdUpdateOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return SgdInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> MultiTensorSgdUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  return InferSGDUpdateDataType(ctx);\n}\n\n/* static */ Maybe<void> MultiTensorMomentumUpdateOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferMomentumUpdateTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> MultiTensorMomentumUpdateOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MultiTensorMomentumUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MultiTensorMomentumUpdateOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return MomentumInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> MultiTensorMomentumUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  return InferMomentumUpdateDataType(ctx);\n}\n\n/* static */ Maybe<void> MultiTensorAdamUpdateOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferAdamUpdateTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> MultiTensorAdamUpdateOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MultiTensorAdamUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MultiTensorAdamUpdateOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return AdamInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> MultiTensorAdamUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  return InferAdamUpdateDataType(ctx);\n}\n\n/* static */ Maybe<void> MultiTensorSgdUpdateWithCastOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return 
InferSGDUpdateWithCastTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> MultiTensorSgdUpdateWithCastOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MultiTensorSgdUpdateWithCastOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MultiTensorSgdUpdateWithCastOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return SgdWithCastInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> MultiTensorSgdUpdateWithCastOp::InferDataType(user_op::InferContext* ctx) {\n  return InferSGDUpdateDataType(ctx);\n}\n\n/* static */ Maybe<void> MultiTensorMomentumUpdateWithCastOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferMomentumUpdateWithCastTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> MultiTensorMomentumUpdateWithCastOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MultiTensorMomentumUpdateWithCastOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MultiTensorMomentumUpdateWithCastOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return MomentumWithCastInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> MultiTensorMomentumUpdateWithCastOp::InferDataType(\n    user_op::InferContext* ctx) {\n  return InferMomentumUpdateDataType(ctx);\n}\n\n/* static */ Maybe<void> MultiTensorAdamUpdateWithCastOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferAdamUpdateWithCastTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> MultiTensorAdamUpdateWithCastOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MultiTensorAdamUpdateWithCastOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MultiTensorAdamUpdateWithCastOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return AdamWithCastInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> MultiTensorAdamUpdateWithCastOp::InferDataType(\n    user_op::InferContext* ctx) {\n  return InferAdamUpdateDataType(ctx);\n}\n\n/* static */ Maybe<void> MultiTensorYoloV5WeightUpdateOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferYoloV5WeightUpdateTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> MultiTensorYoloV5WeightUpdateOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MultiTensorYoloV5WeightUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MultiTensorYoloV5WeightUpdateOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return YoloV5WeightInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> MultiTensorYoloV5WeightUpdateOp::InferDataType(\n    user_op::InferContext* ctx) {\n  return InferYoloV5WeightUpdateDataType(ctx);\n}\n\n}  
// namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/mutable_cast_once_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> MutableCastOnceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& input_tensor_desc = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* output_tensor_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  output_tensor_desc->set_shape(input_tensor_desc.shape());\n  output_tensor_desc->set_is_dynamic(input_tensor_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> MutableCastOnceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> MutableCastOnceOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> MutableCastOnceOp::InferDataType(user_op::InferContext* ctx) {\n  user_op::TensorDesc* output_tensor_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  output_tensor_desc->set_data_type(ctx->Attr<DataType>(\"dtype\"));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/narrow_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> NarrowOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0);\n  const int64_t& dim = ctx->Attr<int64_t>(\"dim\");\n  const int64_t& start = ctx->Attr<int64_t>(\"start\");\n  int64_t length = ctx->Attr<int64_t>(\"length\");\n  CHECK_GE_OR_RETURN(dim, 0);\n  CHECK_GE_OR_RETURN(start, 0);\n  CHECK_GE_OR_RETURN(length, 0);\n  // length should be input size if split the full slice dimension\n  if (start == 0 && length > in.shape().At(dim)) { length = in.shape().At(dim); }\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n\n  DimVector dim_vec;\n  dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), in.shape().dim_vec().cbegin() + dim);\n  dim_vec.insert(dim_vec.end(), length);\n  dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin() + dim + 1,\n                 in.shape().dim_vec().end());\n  out->set_shape(Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> NarrowOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> NarrowOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  const int64_t& dim = ctx->Attr<int64_t>(\"dim\");\n  const int64_t& length = ctx->Attr<int64_t>(\"length\");\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    if (i != dim) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"in\", 0), i)\n          .Split(user_op::OpArg(\"out\", 0), i)\n          .Build();\n    } else {\n      if (length == in_tensor.shape().At(i)) {\n        ctx->NewBuilder()\n            .Split(user_op::OpArg(\"in\", 0), i)\n            .Split(user_op::OpArg(\"out\", 0), i)\n            .Build();\n      }\n    }\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> NarrowOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_data_type(in.data_type());\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> NarrowGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& like_shape = ctx->InputShape(\"like\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  const int64_t ndim = dy_shape.NumAxes();\n  CHECK_EQ_OR_RETURN(like_shape.NumAxes(), ndim);\n\n  ctx->SetOutputShape(\"dx\", 0, like_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> 
NarrowGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> NarrowGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const Shape& like_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"like\", 0).shape();\n  const int64_t ndim = like_shape.NumAxes();\n  const int64_t& dim = ctx->Attr<int64_t>(\"dim\");\n  const int64_t& length = ctx->Attr<int64_t>(\"length\");\n  FOR_RANGE(int64_t, i, 0, ndim) {\n    if (i != dim) {\n      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n    } else {\n      if (length == like_shape.At(i)) {\n        ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n      }\n    }\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"dy\", 0))\n      .Broadcast(user_op::OpArg(\"like\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"dy\", 0))\n      .PartialSum(user_op::OpArg(\"like\", 0))\n      .Broadcast(user_op::OpArg(\"dx\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> NarrowGradOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* dy_modifier = GetInputArgModifierFn(\"dy\", 0);\n  CHECK_NOTNULL_OR_RETURN(dy_modifier);\n  dy_modifier->set_requires_grad(false);\n  user_op::InputArgModifier* like_modifier = GetInputArgModifierFn(\"like\", 0);\n  CHECK_NOTNULL_OR_RETURN(like_modifier);\n  like_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> NarrowGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/user/ops/comm_net_device_infer_util.h\"\n#include \"oneflow/user/ops/nccl_logical_util.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim0AllReduceOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim0AllReduceOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim0AllReduceOp::InferNdSbp(\n    user_op::InferNdSbpFnContext* ctx) {\n  NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  input_nd_sbp->clear_sbp_parallel();\n  output_nd_sbp->clear_sbp_parallel();\n\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", input_nd_sbp));\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", output_nd_sbp));\n  // (*, P) -> (*, B)\n  CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 2);\n  CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 2);\n  CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0) == output_nd_sbp->sbp_parallel(0));\n  CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(1).has_partial_sum_parallel());\n  CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(1).has_broadcast_parallel());\n  CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 2);\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim0AllReduceOp::InferDataType(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> _ncclLogical_2DSameDim0AllReduceOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim1AllReduceOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim1AllReduceOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim1AllReduceOp::InferNdSbp(\n    user_op::InferNdSbpFnContext* ctx) {\n  NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  input_nd_sbp->clear_sbp_parallel();\n  output_nd_sbp->clear_sbp_parallel();\n\n  
JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", input_nd_sbp));\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", output_nd_sbp));\n  // (P, *) -> (B, *)\n  CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 2);\n  CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 2);\n  CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0).has_partial_sum_parallel());\n  CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(0).has_broadcast_parallel());\n  CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(1) == output_nd_sbp->sbp_parallel(1));\n  CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 2);\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim1AllReduceOp::InferDataType(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> _ncclLogical_2DSameDim1AllReduceOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim0AllGatherOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim0AllGatherOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim0AllGatherOp::InferNdSbp(\n    user_op::InferNdSbpFnContext* ctx) {\n  NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  input_nd_sbp->clear_sbp_parallel();\n  output_nd_sbp->clear_sbp_parallel();\n\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", input_nd_sbp));\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", output_nd_sbp));\n  // (*, S(0)) -> (*, B)\n  CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 2);\n  CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 2);\n  CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0) == output_nd_sbp->sbp_parallel(0));\n  CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(1).has_split_parallel());\n  CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel(1).split_parallel().axis(), 0);\n  CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(1).has_broadcast_parallel());\n  CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 2);\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim0AllGatherOp::InferDataType(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> _ncclLogical_2DSameDim0AllGatherOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::GetSbp(\n    user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferNdSbp(\n    
user_op::InferNdSbpFnContext* ctx) {\n  NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  input_nd_sbp->clear_sbp_parallel();\n  output_nd_sbp->clear_sbp_parallel();\n\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", input_nd_sbp));\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", output_nd_sbp));\n  // (*, S(>=1)) -> (*, B)\n  CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 2);\n  CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 2);\n  CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0) == output_nd_sbp->sbp_parallel(0));\n  CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(1).has_split_parallel());\n  CHECK_GE_OR_RETURN(input_nd_sbp->sbp_parallel(1).split_parallel().axis(), 1);\n  CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(1).has_broadcast_parallel());\n  CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 2);\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferDataType(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>>\n_ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim0All2allOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim0All2allOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim0All2allOp::InferNdSbp(\n    user_op::InferNdSbpFnContext* ctx) {\n  NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  input_nd_sbp->clear_sbp_parallel();\n  output_nd_sbp->clear_sbp_parallel();\n\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", input_nd_sbp));\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", output_nd_sbp));\n  // (*, S) -> (*, S)\n  CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 2);\n  CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 2);\n  CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0) == output_nd_sbp->sbp_parallel(0));\n  CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(1).has_split_parallel());\n  CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(1).has_split_parallel());\n  CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 2);\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogical_2DSameDim0All2allOp::InferDataType(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> _ncclLogical_2DSameDim0All2allOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/nccl_logical_fusion_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/user/ops/nccl_logical_util.h\"\n#include \"oneflow/user/ops/comm_net_device_infer_util.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> _ncclLogicalFusionOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  int32_t nccl_size = ctx->input_size(\"in\");\n  CHECK_EQ_OR_RETURN(nccl_size, ctx->output_size(\"out\"));  // NOLINT\n  for (int32_t i = 0; i < nccl_size; ++i) {\n    ctx->SetOutputShape(\"out\", i, ctx->InputShape(\"in\", i));\n    ctx->SetOutputIsDynamic(\"out\", i, ctx->InputIsDynamic(\"in\", i));\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogicalFusionOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogicalFusionOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  int32_t nccl_size = ctx->inputs().size();\n  CHECK_EQ_OR_RETURN(nccl_size, ctx->outputs().size());  // NOLINT\n  const std::vector<std::string>& src_nd_sbp_str_list =\n      ctx->user_op_conf().attr<std::vector<std::string>>(\"src_nd_sbp_str_list\");\n  const std::vector<std::string>& dst_nd_sbp_str_list =\n      ctx->user_op_conf().attr<std::vector<std::string>>(\"dst_nd_sbp_str_list\");\n  CHECK_EQ_OR_RETURN(nccl_size, src_nd_sbp_str_list.size());  // NOLINT\n  CHECK_EQ_OR_RETURN(nccl_size, dst_nd_sbp_str_list.size());  // NOLINT\n  for (int32_t i = 0; i < nccl_size; ++i) {\n    NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", i);\n    NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", i);\n    input_nd_sbp->clear_sbp_parallel();\n    output_nd_sbp->clear_sbp_parallel();\n    CHECK_OR_RETURN(ParseNdSbpFromLongString(JUST(VectorAt(src_nd_sbp_str_list, i)), input_nd_sbp))\n        << Error::RuntimeError() << \" Cannot parse str: \" << JUST(VectorAt(src_nd_sbp_str_list, i))\n        << \" to input nd_sbp attr of op : \" << ctx->user_op_conf().op_name();\n    CHECK_OR_RETURN(ParseNdSbpFromLongString(JUST(VectorAt(dst_nd_sbp_str_list, i)), output_nd_sbp))\n        << Error::RuntimeError() << \" Cannot parse str: \" << JUST(VectorAt(dst_nd_sbp_str_list, i))\n        << \" to output nd_sbp attr of op : \" << ctx->user_op_conf().op_name();\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogicalFusionOp::InferDataType(user_op::InferContext* ctx) {\n  int32_t nccl_size = ctx->input_size(\"in\");\n  CHECK_EQ_OR_RETURN(nccl_size, ctx->output_size(\"out\"));  // NOLINT\n  for (int32_t i = 0; i < nccl_size; ++i) {\n    ctx->SetOutputDType(\"out\", i, ctx->InputDType(\"in\", i));\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> _ncclLogicalFusionOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* 
ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/nccl_logical_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/user/ops/nccl_logical_util.h\"\n#include \"oneflow/user/ops/comm_net_device_infer_util.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> _ncclLogicalAllReduceOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogicalAllReduceOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogicalAllReduceOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  input_nd_sbp->clear_sbp_parallel();\n  output_nd_sbp->clear_sbp_parallel();\n\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", input_nd_sbp));\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", output_nd_sbp));\n  // P->B\n  CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 1);\n  CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0).has_partial_sum_parallel());\n  CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(0).has_broadcast_parallel());\n  CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 1);\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogicalAllReduceOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> _ncclLogicalAllReduceOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogicalReduceScatterOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogicalReduceScatterOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogicalReduceScatterOp::InferNdSbp(\n    user_op::InferNdSbpFnContext* ctx) {\n  NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  input_nd_sbp->clear_sbp_parallel();\n  output_nd_sbp->clear_sbp_parallel();\n\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", input_nd_sbp));\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", output_nd_sbp));\n  // P->S(0)\n  
CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 1);\n  CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0).has_partial_sum_parallel());\n  CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(0).has_split_parallel());\n  CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel(0).split_parallel().axis(), 0);\n  CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 1);\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogicalReduceScatterOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> _ncclLogicalReduceScatterOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogicalAllGatherOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogicalAllGatherOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogicalAllGatherOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  input_nd_sbp->clear_sbp_parallel();\n  output_nd_sbp->clear_sbp_parallel();\n\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", input_nd_sbp));\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", output_nd_sbp));\n  // S(0)->B\n  CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 1);\n  CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0).has_split_parallel());\n  CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel(0).split_parallel().axis(), 0);\n  CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(0).has_broadcast_parallel());\n  CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 1);\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogicalAllGatherOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> _ncclLogicalAllGatherOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogicalAllGatherNoncontinuousOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogicalAllGatherNoncontinuousOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogicalAllGatherNoncontinuousOp::InferNdSbp(\n    user_op::InferNdSbpFnContext* ctx) {\n  NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  input_nd_sbp->clear_sbp_parallel();\n  output_nd_sbp->clear_sbp_parallel();\n\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", input_nd_sbp));\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", output_nd_sbp));\n  // S(>=1)->B\n  
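// \"noncontinuous\" here means the split axis is an inner axis (>= 1), so each rank's shard is not a contiguous block of the logical tensor\n  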
CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 1);\n  CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0).has_split_parallel());\n  CHECK_GE_OR_RETURN(input_nd_sbp->sbp_parallel(0).split_parallel().axis(), 1);\n  CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(0).has_broadcast_parallel());\n  CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 1);\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogicalAllGatherNoncontinuousOp::InferDataType(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> _ncclLogicalAllGatherNoncontinuousOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogicalReduceScatterNoncontinuousOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogicalReduceScatterNoncontinuousOp::GetSbp(\n    user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogicalReduceScatterNoncontinuousOp::InferNdSbp(\n    user_op::InferNdSbpFnContext* ctx) {\n  NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  input_nd_sbp->clear_sbp_parallel();\n  output_nd_sbp->clear_sbp_parallel();\n\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", input_nd_sbp));\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", output_nd_sbp));\n  // P->S(>=1)\n  CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 1) << \"input_nd_sbp should be 1d.\";\n  CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 1) << \"output_nd_sbp should be 1d.\";\n  CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0).has_partial_sum_parallel())\n      << \"input_nd_sbp should be partial_sum_parallel.\";\n  CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(0).has_split_parallel())\n      << \"output_nd_sbp should be split_parallel.\";\n  CHECK_GE_OR_RETURN(output_nd_sbp->sbp_parallel(0).split_parallel().axis(), 1)\n      << \"output_nd_sbp split axis should be greater than or equal to 1.\";\n  CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 1) << \"parallel_hierarchy should be 1d.\";\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogicalReduceScatterNoncontinuousOp::InferDataType(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> _ncclLogicalReduceScatterNoncontinuousOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogicalS2sOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogicalS2sOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogicalS2sOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  NdSbp* input_nd_sbp = 
ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  input_nd_sbp->clear_sbp_parallel();\n  output_nd_sbp->clear_sbp_parallel();\n\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", input_nd_sbp));\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", output_nd_sbp));\n  // S->S\n  CHECK_EQ_OR_RETURN(input_nd_sbp->sbp_parallel_size(), 1);\n  CHECK_EQ_OR_RETURN(output_nd_sbp->sbp_parallel_size(), 1);\n  CHECK_OR_RETURN(input_nd_sbp->sbp_parallel(0).has_split_parallel());\n  CHECK_OR_RETURN(output_nd_sbp->sbp_parallel(0).has_split_parallel());\n  CHECK_EQ_OR_RETURN(ctx->parallel_hierarchy().NumAxes(), 1);\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogicalS2sOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> _ncclLogicalS2sOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogicalSendRecvOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogicalSendRecvOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> _ncclLogicalSendRecvOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  NdSbp* input_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  NdSbp* output_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  input_nd_sbp->clear_sbp_parallel();\n  output_nd_sbp->clear_sbp_parallel();\n\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"src_reduced_nd_sbp\", input_nd_sbp));\n  JUST(GetNcclLogicalNdSbpFromAttr(ctx, \"dst_reduced_nd_sbp\", output_nd_sbp));\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> _ncclLogicalSendRecvOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<Symbol<Stream>> _ncclLogicalSendRecvOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/nccl_logical_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/user/ops/nccl_logical_util.h\"\n\nnamespace oneflow {\n\nstd::string GetCommKeyFromNcclType(const std::string& op_type_name) {\n  if (op_type_name == \"_nccl_logical_2D_same_dim0_all_reduce\"\n      || op_type_name == \"_nccl_logical_2D_same_dim0_all_gather\"\n      || op_type_name == \"_nccl_logical_2D_same_dim0_all_gather_noncontinuous\"\n      || op_type_name == \"_nccl_logical_2D_same_dim0_all2all\") {\n    return \"SameDim0\";\n  }\n  if (op_type_name == \"_nccl_logical_2D_same_dim1_all_reduce\") { return \"SameDim1\"; }\n  return \"\";\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/nccl_logical_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_OPS_NCCL_LOGICAL_UTIL_H_\n#define ONEFLOW_USER_OPS_NCCL_LOGICAL_UTIL_H_\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job/sbp_parallel.h\"\n\nnamespace oneflow {\n\ntemplate<typename ContextT, typename AttrT>\nstruct AttrFromContext {\n  const AttrT& operator()(ContextT*, const std::string&);\n};\n\ntemplate<typename AttrT>\nstruct AttrFromContext<user_op::InferNdSbpFnContext, AttrT> {\n  const AttrT& operator()(user_op::InferNdSbpFnContext* ctx, const std::string& attr_name) {\n    return ctx->user_op_conf().template attr<AttrT>(attr_name);\n  }\n};\n\ntemplate<typename AttrT>\nstruct AttrFromContext<user_op::KernelInitContext, AttrT> {\n  const AttrT& operator()(user_op::KernelInitContext* ctx, const std::string& attr_name) {\n    return ctx->Attr<AttrT>(attr_name);\n  }\n};\n\ntemplate<typename AttrT>\nstruct AttrFromContext<user_op::InferContext, AttrT> {\n  const AttrT& operator()(user_op::InferContext* ctx, const std::string& attr_name) {\n    return ctx->Attr<AttrT>(attr_name);\n  }\n};\n\ntemplate<typename ContextT>\nstruct OpTypeNameFromContext {\n  const std::string& operator()(ContextT*);\n};\n\ntemplate<>\nstruct OpTypeNameFromContext<user_op::InferNdSbpFnContext> {\n  const std::string& operator()(user_op::InferNdSbpFnContext* ctx) {\n    return ctx->user_op_conf().op_type_name();\n  }\n};\n\ntemplate<>\nstruct OpTypeNameFromContext<user_op::KernelInitContext> {\n  const std::string& operator()(user_op::KernelInitContext* ctx) { return ctx->op_type_name(); }\n};\n\ntemplate<>\nstruct OpTypeNameFromContext<user_op::InferContext> {\n  const std::string& operator()(user_op::InferContext* ctx) { return ctx->op_type_name(); }\n};\n\ntemplate<typename ContextT>\nMaybe<void> GetNcclLogicalNdSbpFromAttr(ContextT* ctx, const std::string& attr_name,\n                                        NdSbp* nd_sbp) {\n  const auto& sbp_str_list = AttrFromContext<ContextT, std::vector<std::string>>()(ctx, attr_name);\n\n  if (!ParseNdSbpFromStringList(sbp_str_list, nd_sbp)) {\n    std::ostringstream err;\n    err << \"invalid \" << attr_name << \": [\";\n    for (size_t i = 0; i < sbp_str_list.size(); ++i) {\n      const auto& sbp_str = sbp_str_list[i];\n      if (i == 0) {\n        err << sbp_str;\n      } else {\n        err << \", \" << sbp_str;\n      }\n    }\n    err << \"] for \" << OpTypeNameFromContext<ContextT>()(ctx);\n    return Error::RuntimeError() << err.str();\n  }\n\n  return Maybe<void>::Ok();\n}\n\nstd::string GetCommKeyFromNcclType(const std::string& op_type_name);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_OPS_NCCL_LOGICAL_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/ops/nd_index_slice_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> CheckScatterNdShape(const Shape& params_shape, const Shape& indices_shape,\n                                const Shape& updates_shape) {\n  int64_t batch_ndims = indices_shape.NumAxes() - 1;\n  int64_t index_ndims = indices_shape.At(batch_ndims);\n  CHECK_LE_OR_RETURN(batch_ndims, updates_shape.NumAxes());\n  CHECK_LE_OR_RETURN(index_ndims, params_shape.NumAxes());\n  FOR_RANGE(int64_t, i, 0, batch_ndims) {\n    CHECK_EQ_OR_RETURN(updates_shape.At(i), indices_shape.At(i));\n  }\n  int64_t slice_ndims = params_shape.NumAxes() - index_ndims;\n  CHECK_EQ_OR_RETURN(slice_ndims, updates_shape.NumAxes() - batch_ndims);\n  FOR_RANGE(int64_t, i, 0, slice_ndims) {\n    CHECK_EQ_OR_RETURN(updates_shape.At(i + batch_ndims), params_shape.At(i + index_ndims));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferScatterNdTensorDesc(user_op::InferContext* ctx) {\n  const Shape& indices_shape = ctx->InputShape(\"indices\", 0);\n  const Shape& updates_shape = ctx->InputShape(\"updates\", 0);\n  const Shape& params_shape = ctx->Attr<Shape>(\"shape\");\n  JUST(CheckScatterNdShape(params_shape, indices_shape, updates_shape));\n  ctx->SetOutputShape(\"out\", 0, params_shape);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferScatterNdDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"updates\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferScatterNdLikeTensorDesc(user_op::InferContext* ctx) {\n  const Shape& indices_shape = ctx->InputShape(\"indices\", 0);\n  const Shape& updates_shape = ctx->InputShape(\"updates\", 0);\n  const Shape& like_shape = ctx->InputShape(\"like\", 0);\n  JUST(CheckScatterNdShape(like_shape, indices_shape, updates_shape));\n  ctx->SetOutputShape(\"out\", 0, like_shape);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferScatterNdLikeDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"updates\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferTensorScatterNdOptTensorDesc(user_op::InferContext* ctx) {\n  const Shape& params_shape = ctx->InputShape(\"params\", 0);\n  const Shape& updates_shape = ctx->InputShape(\"updates\", 0);\n  const Shape& indices_shape = ctx->InputShape(\"indices\", 0);\n  JUST(CheckScatterNdShape(params_shape, indices_shape, updates_shape));\n  ctx->SetOutputShape(\"out\", 0, params_shape);\n  ctx->SetOutputStride(\"out\", 0, ctx->InputStride(\"params\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferTensorScatterNdOptDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"params\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetTensorScatterNdOptSbpSignatures(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& params_tensor =\n      
ctx->LogicalTensorDesc4InputArgNameAndIndex(\"params\", 0);\n  const user_op::TensorDesc& indices_tensor =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"indices\", 0);\n  int64_t indices_num_axes = indices_tensor.shape().NumAxes();\n  FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) {\n    ctx->NewBuilder()\n        .Broadcast(user_op::OpArg(\"params\", 0))\n        .Split(user_op::OpArg(\"indices\", 0), i)\n        .Split(user_op::OpArg(\"updates\", 0), i)\n        .Broadcast(user_op::OpArg(\"out\", 0))\n        .Build();\n  }\n  int64_t index_ndims = indices_tensor.shape().At(indices_num_axes - 1);\n  FOR_RANGE(int64_t, i, index_ndims, params_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"params\", 0), i)\n        .Broadcast(user_op::OpArg(\"indices\", 0))\n        .Split(user_op::OpArg(\"updates\", 0), i - index_ndims + indices_num_axes - 1)\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"params\", 0))\n      .Broadcast(user_op::OpArg(\"indices\", 0))\n      .PartialSum(user_op::OpArg(\"updates\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> GatherNdOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& params_shape = ctx->InputShape(\"params\", 0);\n  const Shape& indices_shape = ctx->InputShape(\"indices\", 0);\n  int64_t index_ndims = indices_shape.At(indices_shape.NumAxes() - 1);\n  CHECK_LE_OR_RETURN(index_ndims, params_shape.NumAxes());\n  DimVector out_shape_vec(indices_shape.dim_vec().cbegin(), indices_shape.dim_vec().cend() - 1);\n  FOR_RANGE(int64_t, i, index_ndims, params_shape.NumAxes()) {\n    out_shape_vec.emplace_back(params_shape.At(i));\n  }\n  const Shape& out_shape = Shape(out_shape_vec);\n  bool is_out_of_bounds = params_shape.Count(0) == 0 && out_shape.Count(0) != 0;\n  CHECK_OR_RETURN(!is_out_of_bounds)\n      << Error::IndexError() << \"The index is out of bounds for dimension with size 0\";\n  ctx->SetOutputShape(\"out\", 0, out_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> GatherNdOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> GatherNdOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& params_tensor =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"params\", 0);\n  const user_op::TensorDesc& indices_tensor =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"indices\", 0);\n  int64_t indices_num_axes = indices_tensor.shape().NumAxes();\n  FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) {\n    ctx->NewBuilder()\n        .Broadcast(user_op::OpArg(\"params\", 0))\n        .Split(user_op::OpArg(\"indices\", 0), i)\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n  int64_t index_ndims = indices_tensor.shape().At(indices_num_axes - 1);\n  FOR_RANGE(int64_t, i, index_ndims, params_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"params\", 0), i)\n        .Broadcast(user_op::OpArg(\"indices\", 0))\n        .Split(user_op::OpArg(\"out\", 0), i - index_ndims + indices_num_axes - 1)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"params\", 0))\n      .Broadcast(user_op::OpArg(\"indices\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> 
GatherNdOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn(\"indices\", 0);\n  CHECK_OR_RETURN(indices_modifier != nullptr);\n  indices_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> GatherNdOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"params\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ScatterNdOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferScatterNdTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> ScatterNdOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ScatterNdOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& indices_desc =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"indices\", 0);\n  int64_t indices_num_axes = indices_desc.shape().NumAxes();\n  FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"indices\", 0), i)\n        .Split(user_op::OpArg(\"updates\", 0), i)\n        .Broadcast(user_op::OpArg(\"out\", 0))\n        .Build();\n  }\n  const Shape& out_shape = ctx->Attr<Shape>(\"shape\");\n  int64_t index_ndims = indices_desc.shape().At(indices_num_axes - 1);\n  int64_t slice_ndims = out_shape.NumAxes() - index_ndims;\n  FOR_RANGE(int64_t, i, 0, slice_ndims) {\n    ctx->NewBuilder()\n        .Broadcast(user_op::OpArg(\"indices\", 0))\n        .Split(user_op::OpArg(\"updates\", 0), i + indices_num_axes - 1)\n        .Split(user_op::OpArg(\"out\", 0), i + index_ndims)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"updates\", 0))\n      .Broadcast(user_op::OpArg(\"indices\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ScatterNdOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn(\"indices\", 0);\n  CHECK_OR_RETURN(indices_modifier != nullptr);\n  indices_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ScatterNdOp::InferDataType(user_op::InferContext* ctx) {\n  return InferScatterNdDataType(ctx);\n}\n\n/* static */ Maybe<void> ScatterNdLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferScatterNdLikeTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> ScatterNdLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ScatterNdLikeOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& indices_tensor =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"indices\", 0);\n  int64_t indices_num_axes = indices_tensor.shape().NumAxes();\n  FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) {\n    ctx->NewBuilder()\n        .Broadcast(user_op::OpArg(\"like\", 0))\n        .Split(user_op::OpArg(\"indices\", 0), i)\n        .Split(user_op::OpArg(\"updates\", 0), i)\n        .Broadcast(user_op::OpArg(\"out\", 0))\n        .Build();\n  }\n  const Shape& out_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"like\", 0).shape();\n  int64_t index_ndims = indices_tensor.shape().At(indices_num_axes - 1);\n  int64_t slice_ndims = out_shape.NumAxes() - index_ndims;\n  
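// split on slice dims: out/like axis (i + index_ndims) pairs with updates axis (i + indices_num_axes - 1)\n  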
FOR_RANGE(int64_t, i, 0, slice_ndims) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"like\", 0), i + index_ndims)\n        .Broadcast(user_op::OpArg(\"indices\", 0))\n        .Split(user_op::OpArg(\"updates\", 0), i + indices_num_axes - 1)\n        .Split(user_op::OpArg(\"out\", 0), i + index_ndims)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"like\", 0))\n      .PartialSum(user_op::OpArg(\"updates\", 0))\n      .Broadcast(user_op::OpArg(\"indices\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ScatterNdLikeOp::InferDataType(user_op::InferContext* ctx) {\n  return InferScatterNdLikeDataType(ctx);\n}\n\n/* static */ Maybe<void> TensorScatterNdUpdateOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferTensorScatterNdOptTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> TensorScatterNdUpdateOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> TensorScatterNdUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetTensorScatterNdOptSbpSignatures(ctx);\n}\n\n/* static */ Maybe<void> TensorScatterNdUpdateOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn(\"indices\", 0);\n  CHECK_OR_RETURN(indices_modifier != nullptr);\n  indices_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> TensorScatterNdUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  return InferTensorScatterNdOptDataType(ctx);\n}\n\n/* static */ Maybe<void> TensorScatterNdAddOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorScatterNdOptTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> TensorScatterNdAddOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> TensorScatterNdAddOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetTensorScatterNdOptSbpSignatures(ctx);\n}\n\n/* static */ Maybe<void> TensorScatterNdAddOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn(\"indices\", 0);\n  CHECK_OR_RETURN(indices_modifier != nullptr);\n  indices_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> TensorScatterNdAddOp::InferDataType(user_op::InferContext* ctx) {\n  return InferTensorScatterNdOptDataType(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/nll_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> NLLOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType(\"target\", 0)))\n      << ctx->op_name() << \": expected target being integer type\";\n\n  DataType input_dtype = ctx->InputDType(\"input\", 0);\n  if (ctx->has_input(\"weight\", 0)) {\n    DataType weight_dtype = ctx->InputDType(\"weight\", 0);\n    CHECK_EQ_OR_RETURN(weight_dtype, input_dtype) << ctx->op_name() << \": expected weight dtype \"\n                                                  << input_dtype << \", but got \" << weight_dtype;\n  }\n\n  ctx->SetOutputDType(\"output\", 0, input_dtype);\n  ctx->SetOutputDType(\"out_weight\", 0, input_dtype);\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> NLLOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const auto& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const auto& target_desc = ctx->InputTensorDesc(\"target\", 0);\n\n  const bool is_dynamic = input_desc.is_dynamic();\n  CHECK_EQ_OR_RETURN(target_desc.is_dynamic(), is_dynamic)\n      << ctx->op_name() << \": expected the same dynamic with input and target\";\n  const int64_t K = input_desc.shape().NumAxes();\n  CHECK_GE_OR_RETURN(K, 2) << ctx->op_name() << \": expected 2 or more dimensions for input\";\n  CHECK_EQ_OR_RETURN(target_desc.shape().NumAxes(), K - 1)\n      << ctx->op_name() << \": expected 1 less diemensions than input for target\";\n  const int64_t N = target_desc.shape().elem_cnt();\n  const int64_t C = input_desc.shape().At(input_desc.shape().NumAxes() - 1);\n  CHECK_EQ_OR_RETURN(input_desc.shape().elem_cnt(), N * C)\n      << ctx->op_name() << \": expected input size \" << input_desc.shape().ToString()\n      << \" to match target size \" << target_desc.shape().ToString();\n\n  if (ctx->has_input(\"weight\", 0)) {\n    const auto& weight_desc = ctx->InputTensorDesc(\"weight\", 0);\n    CHECK_EQ_OR_RETURN(weight_desc.is_dynamic(), is_dynamic)\n        << ctx->op_name() << \": expected the same dynamic with input and weight\";\n    CHECK_EQ_OR_RETURN(weight_desc.shape().elem_cnt(), C)\n        << ctx->op_name() << \": expected weight size \" << C << \", got \"\n        << weight_desc.shape().ToString();\n  }\n\n  user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc(\"output\", 0);\n  output_desc->set_is_dynamic(is_dynamic);\n  output_desc->set_shape(Shape({N}));\n\n  user_op::TensorDesc* out_weight_desc = ctx->MutOutputTensorDesc(\"out_weight\", 0);\n  out_weight_desc->set_is_dynamic(is_dynamic);\n  out_weight_desc->set_shape(Shape({N}));\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> NLLOp::GetSbp(user_op::SbpContext* ctx) {\n  // split batch dim\n  auto builder1 = ctx->NewBuilder()\n                      .Split(user_op::OpArg(\"input\", 0), 0)\n          
            .Split(user_op::OpArg(\"target\", 0), 0)\n                      .Split(user_op::OpArg(\"output\", 0), 0)\n                      .Split(user_op::OpArg(\"out_weight\", 0), 0);\n  if (ctx->user_op_conf().has_input(\"weight\", 0)) {\n    builder1.Broadcast(user_op::OpArg(\"weight\", 0));\n  }\n  builder1.Build();\n\n  // split class dim\n  const auto& shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"input\", 0).shape();\n  auto builder2 = ctx->NewBuilder()\n                      .Split(user_op::OpArg(\"input\", 0), shape.NumAxes() - 1)\n                      .Broadcast(user_op::OpArg(\"target\", 0))\n                      .PartialSum(user_op::OpArg(\"output\", 0))\n                      .PartialSum(user_op::OpArg(\"out_weight\", 0));\n  if (ctx->user_op_conf().has_input(\"weight\", 0)) {\n    builder2.Split(user_op::OpArg(\"weight\", 0), 0);\n  }\n  builder2.Build();\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> NLLOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn,\n                                               const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* target_modifier = GetInputArgModifierFn(\"target\", 0);\n  CHECK_OR_RETURN(target_modifier != nullptr);\n  target_modifier->set_requires_grad(false);\n  if (conf.has_input(\"weight\", 0)) {\n    auto* weight_modifier = GetInputArgModifierFn(\"weight\", 0);\n    if (weight_modifier) { weight_modifier->set_requires_grad(false); }\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> NLLGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType(\"target\", 0)))\n      << ctx->op_name() << \": expected target being integer type\";\n\n  DataType input_dtype = ctx->InputDType(\"input\", 0);\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"out_grad\", 0), input_dtype)\n      << ctx->op_name() << \": expected out_grad dtype \" << input_dtype << \", got \"\n      << ctx->InputDType(\"out_grad\", 0);\n\n  if (ctx->has_input(\"weight\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"weight\", 0), input_dtype)\n        << ctx->op_name() << \": expected weight dtype \" << input_dtype << \", got \"\n        << ctx->InputDType(\"weight\", 0);\n  }\n\n  ctx->SetOutputDType(\"in_grad\", 0, input_dtype);\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> NLLGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const auto& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const auto& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  const auto& out_grad_desc = ctx->InputTensorDesc(\"out_grad\", 0);\n\n  bool is_dynamic = input_desc.is_dynamic();\n  CHECK_EQ_OR_RETURN(target_desc.is_dynamic(), is_dynamic)\n      << ctx->op_name() << \": expected target dynamic \" << is_dynamic;\n  CHECK_EQ_OR_RETURN(out_grad_desc.is_dynamic(), is_dynamic)\n      << ctx->op_name() << \": expected out_grad dynamic \" << is_dynamic;\n\n  const int64_t N = target_desc.shape().elem_cnt();\n  CHECK_EQ_OR_RETURN(out_grad_desc.shape().elem_cnt(), N)\n      << ctx->op_name() << \": expected out_grad size \" << N << \", got \"\n      << out_grad_desc.shape().ToString();\n\n  const int64_t C = input_desc.shape().At(input_desc.shape().NumAxes() - 1);\n  CHECK_EQ_OR_RETURN(input_desc.shape().elem_cnt(), N * C)\n      << ctx->op_name() << \": expected input size \" << N << \", got \"\n      << input_desc.shape().ToString();\n\n  if (ctx->has_input(\"weight\", 0)) {\n    const auto& weight_desc = ctx->InputTensorDesc(\"weight\", 0);\n    
CHECK_EQ_OR_RETURN(weight_desc.shape().elem_cnt(), C)\n        << ctx->op_name() << \": expected weight size \" << C << \", got \"\n        << weight_desc.shape().ToString();\n  }\n\n  user_op::TensorDesc* in_grad_desc = ctx->MutOutputTensorDesc(\"in_grad\", 0);\n  in_grad_desc->set_is_dynamic(is_dynamic);\n  in_grad_desc->set_shape(input_desc.shape());\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> NLLGradOp::GetSbp(user_op::SbpContext* ctx) {\n  // split batch dim\n  auto builder1 = ctx->NewBuilder()\n                      .Split(user_op::OpArg(\"input\", 0), 0)\n                      .Split(user_op::OpArg(\"target\", 0), 0)\n                      .Split(user_op::OpArg(\"out_grad\", 0), 0)\n                      .Split(user_op::OpArg(\"in_grad\", 0), 0);\n  if (ctx->user_op_conf().has_input(\"weight\", 0)) {\n    builder1.Broadcast(user_op::OpArg(\"weight\", 0));\n  }\n  builder1.Build();\n\n  // split class dim\n  const auto& shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"input\", 0).shape();\n  auto builder2 = ctx->NewBuilder()\n                      .Split(user_op::OpArg(\"input\", 0), shape.NumAxes() - 1)\n                      .Broadcast(user_op::OpArg(\"target\", 0))\n                      .Broadcast(user_op::OpArg(\"out_grad\", 0))\n                      .Split(user_op::OpArg(\"in_grad\", 0), shape.NumAxes() - 1);\n  if (ctx->user_op_conf().has_input(\"weight\", 0)) {\n    builder2.Split(user_op::OpArg(\"weight\", 0), 0);\n  }\n  builder2.Build();\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/nms_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferNmsTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, Shape({ctx->InputShape(\"in\", 0).At(0)}));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferNmsDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, DataType::kInt8);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> NmsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferNmsTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> NmsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> NmsOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> NmsOp::InferDataType(user_op::InferContext* ctx) {\n  return InferNmsDataType(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/nn_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/ops/nn_util.h\"\n\nnamespace oneflow {\n\nMaybe<void> CalcOutAndPadding(int64_t input_size, int32_t filter_size, int32_t dilation_rate,\n                              int32_t stride, const std::string& padding_type, int64_t* output_size,\n                              int32_t* padding_before, int32_t* padding_after) {\n  CHECK_GT_OR_RETURN(stride, 0);\n  CHECK_GE_OR_RETURN(dilation_rate, 1);\n\n  int32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1;\n  if (padding_type == \"valid\") {\n    if (output_size) { *output_size = (input_size - effective_filter_size + stride) / stride; }\n    if (padding_before) { *padding_before = 0; }\n    if (padding_after) { *padding_after = 0; }\n  } else if (padding_type == \"same\") {\n    int64_t tmp_output_size = (input_size + stride - 1) / stride;\n    if (output_size) { *output_size = tmp_output_size; }\n    const int32_t padding_needed = std::max(\n        0,\n        static_cast<int32_t>((tmp_output_size - 1) * stride + effective_filter_size - input_size));\n    // For odd values of total padding, add more padding at the 'right'\n    // side of the given dimension.\n    if (padding_before) { *padding_before = padding_needed / 2; }\n    if (padding_after) { *padding_after = padding_needed - padding_needed / 2; }\n  } else {\n    UNIMPLEMENTED();\n  }\n  if (output_size) { CHECK_GE_OR_RETURN((*output_size), 0); }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CalcSamePadding(int64_t input_size, int32_t filter_size, int32_t dilation_rate,\n                            int32_t stride, int32_t* padding_small, int32_t* padding_large) {\n  CHECK_GT_OR_RETURN(stride, 0);\n  CHECK_GE_OR_RETURN(dilation_rate, 1);\n\n  int32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1;\n  int64_t tmp_output_size = (input_size + stride - 1) / stride;\n  const int32_t padding_needed = std::max(\n      0, static_cast<int32_t>((tmp_output_size - 1) * stride + effective_filter_size - input_size));\n  if (padding_small) { *padding_small = padding_needed / 2; }\n  if (padding_large) { *padding_large = padding_needed - padding_needed / 2; }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CalcConvOut(int64_t input_size, int32_t filter_size, int32_t dilation_rate,\n                        int32_t stride, int32_t padding_before, int64_t* output_size) {\n  CHECK_GT_OR_RETURN(stride, 0);\n  CHECK_GE_OR_RETURN(dilation_rate, 1);\n\n  int32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1;\n  if (output_size) {\n    *output_size = (input_size + 2 * padding_before - effective_filter_size + stride) / stride;\n    CHECK_GE_OR_RETURN((*output_size), 0);\n  }\n  return Maybe<void>::Ok();\n}\n\nconst size_t IdxOffset(const std::string& data_format) {\n  if (data_format == \"channels_first\") {\n    return 2;\n  } else if (data_format == \"channels_last\") {\n    return 1;\n  } else {\n    UNIMPLEMENTED();\n  
}\n}\n\nconst int32_t ChannelIdx(const std::string& data_format, int32_t num_axes) {\n  if (data_format == \"channels_first\") {\n    return 1;\n  } else if (data_format == \"channels_last\") {\n    return num_axes - 1;\n  } else {\n    UNIMPLEMENTED();\n  }\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/nn_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_OPS_NN_UTIL_H_\n#define ONEFLOW_USER_OPS_NN_UTIL_H_\n\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nMaybe<void> CalcOutAndPadding(int64_t input_size, int32_t filter_size, int32_t dilation_rate,\n                              int32_t stride, const std::string& padding_type, int64_t* output_size,\n                              int32_t* padding_before, int32_t* padding_after);\n\nMaybe<void> CalcSamePadding(int64_t input_size, int32_t filter_size, int32_t dilation_rate,\n                            int32_t stride, int32_t* padding_small, int32_t* padding_large);\n\nMaybe<void> CalcConvOut(int64_t input_size, int32_t filter_size, int32_t dilation_rate,\n                        int32_t stride, int32_t padding_before, int64_t* output_size);\n\nconst size_t IdxOffset(const std::string& data_format);\nconst int32_t ChannelIdx(const std::string& data_format, int32_t num_axes);\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_OPS_NN_UTIL_H_\n"
  },
  {
    "path": "oneflow/user/ops/noncontiguous_binary_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/shape_view.h\"\n#include \"oneflow/core/common/stride.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> NonContiguousBinaryOp::GetSbp(user_op::SbpContext* ctx) {\n  // only support broadcast\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"lhs\", 0))\n      .Broadcast(user_op::OpArg(\"rhs\", 0))\n      .Broadcast(user_op::OpArg(\"y\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> NonContiguousBinaryOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& lhs = ctx->InputShape(\"lhs\", 0);\n  const Shape& rhs = ctx->InputShape(\"rhs\", 0);\n  CHECK_EQ(lhs.NumAxes(), rhs.NumAxes());\n  for (int i = 0; i < lhs.NumAxes(); i++) CHECK_EQ(lhs.At(i), rhs.At(i));\n  ctx->SetOutputShape(\"y\", 0, lhs);\n  const bool inplace = ctx->Attr<bool>(\"inplace\");\n  if (inplace) {\n    ctx->SetOutputStride(\"y\", 0, ctx->InputStride(\"lhs\", 0));\n  } else {  // set contiguous for y if not inplace\n    ctx->SetOutputStride(\"y\", 0, Stride(lhs));\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> NonContiguousBinaryOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> NonContiguousBinaryOp::InferDataType(user_op::InferContext* ctx) {\n  auto lhs = ctx->InputDType(\"lhs\", 0);\n  auto rhs = ctx->InputDType(\"rhs\", 0);\n  ctx->SetOutputDType(\"y\", 0, GetSizeOfDataType(lhs) >= GetSizeOfDataType(rhs) ? 
lhs : rhs);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> NonContiguousBinaryOpGrad::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"lhs\", 0))\n      .Broadcast(user_op::OpArg(\"rhs\", 0))\n      .Broadcast(user_op::OpArg(\"dy\", 0))\n      .Broadcast(user_op::OpArg(\"dlhs\", 0))\n      .Broadcast(user_op::OpArg(\"drhs\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> NonContiguousBinaryOpGrad::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const Shape& lhs = ctx->InputShape(\"lhs\", 0);\n  const Shape& rhs = ctx->InputShape(\"rhs\", 0);\n  CHECK_EQ(lhs.NumAxes(), rhs.NumAxes());\n  for (int i = 0; i < lhs.NumAxes(); i++) CHECK_EQ(lhs.At(i), rhs.At(i));\n  ctx->SetOutputShape(\"dlhs\", 0, lhs);\n  ctx->SetOutputStride(\"dlhs\", 0, ctx->InputStride(\"lhs\", 0));\n  ctx->SetOutputShape(\"drhs\", 0, rhs);\n  ctx->SetOutputStride(\"drhs\", 0, ctx->InputStride(\"rhs\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> NonContiguousBinaryOpGrad::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> NonContiguousBinaryOpGrad::InferDataType(user_op::InferContext* ctx) {\n  auto lhs = ctx->InputDType(\"lhs\", 0);\n  auto rhs = ctx->InputDType(\"rhs\", 0);\n  ctx->SetOutputDType(\"dlhs\", 0, lhs);\n  ctx->SetOutputDType(\"drhs\", 0, rhs);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/normalization_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#ifdef WITH_CUDA\n#include \"oneflow/core/device/cuda_util.h\"\n#include \"oneflow/core/device/cudnn_util.h\"\n#endif\n\nnamespace oneflow {\n\nnamespace {\n\nstd::function<Maybe<void>(const std::string&)> MakeCheckParamTensorDescFn(\n    user_op::InferContext* ctx, const Shape& shape) {\n  return [=](const std::string& bn) -> Maybe<void> {\n    if (ctx->has_input(bn, 0)) {\n      const auto& tensor_desc = ctx->InputTensorDesc(bn, 0);\n      CHECK_EQ_OR_RETURN(tensor_desc.shape(), shape);\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\nstd::function<Maybe<void>(const std::string&)> MakeCheckParamDataTypeFn(user_op::InferContext* ctx,\n                                                                        DataType data_type) {\n  return [=](const std::string& bn) -> Maybe<void> {\n    if (ctx->has_input(bn, 0)) {\n      const auto& tensor_desc = ctx->InputTensorDesc(bn, 0);\n      CHECK_EQ_OR_RETURN(tensor_desc.data_type(), data_type)\n          << \"InferDataType Failed. Expected \" << DataType_Name(tensor_desc.data_type())\n          << \", but got \" << DataType_Name(data_type);\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\nstd::function<Maybe<void>(const std::string&)> MakeSetParamTensorDescFn(user_op::InferContext* ctx,\n                                                                        const Shape& shape) {\n  return [=](const std::string& bn) -> Maybe<void> {\n    if (ctx->has_output(bn, 0)) {\n      auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0);\n      CHECK_OR_RETURN(tensor_desc != nullptr);\n      tensor_desc->set_shape(shape);\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\nstd::function<Maybe<void>(const std::string&)> MakeSetParamDataTypeFn(user_op::InferContext* ctx,\n                                                                      DataType data_type) {\n  return [=](const std::string& bn) -> Maybe<void> {\n    if (ctx->has_output(bn, 0)) {\n      auto* tensor_desc = ctx->MutOutputTensorDesc(bn, 0);\n      CHECK_OR_RETURN(tensor_desc != nullptr);\n      tensor_desc->set_data_type(data_type);\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\nMaybe<void> FwInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,\n                               const user_op::UserOpConfWrapper& conf) {\n  bool training = true;\n  if (conf.op_type_name() == \"normalization\" || conf.op_type_name() == \"normalization_add_relu\") {\n    training = conf.attr<bool>(\"training\");\n  }\n  if (conf.has_input(\"moving_mean\", 0)) {\n    CHECK_OR_RETURN(conf.has_input(\"moving_variance\", 0));\n    user_op::InputArgModifier* moving_mean_modifier = GetInputArgModifierFn(\"moving_mean\", 0);\n    CHECK_OR_RETURN(moving_mean_modifier != nullptr);\n    moving_mean_modifier->set_is_mutable(training);\n    moving_mean_modifier->set_requires_grad(false);\n    
user_op::InputArgModifier* moving_variance_modifier =\n        GetInputArgModifierFn(\"moving_variance\", 0);\n    CHECK_OR_RETURN(moving_variance_modifier != nullptr);\n    moving_variance_modifier->set_is_mutable(training);\n    moving_variance_modifier->set_requires_grad(false);\n  } else {\n    CHECK_OR_RETURN(training)\n        << \"Must have moving mean and moving variance for normalization in inference mode.\";\n  }\n  return Maybe<void>::Ok();\n}\n\n
Maybe<void> FwGetSbpFn(user_op::SbpContext* ctx) {\n  std::vector<user_op::OpArg> split_args;\n  split_args.emplace_back(\"x\", 0);\n  split_args.emplace_back(\"y\", 0);\n  if (ctx->user_op_conf().has_input(\"addend\", 0)) { split_args.emplace_back(\"addend\", 0); }\n  if (ctx->user_op_conf().has_input(\"_add_to_output\", 0)) {\n    split_args.emplace_back(\"_add_to_output\", 0);\n  }\n  std::vector<user_op::OpArg> broadcast_args;\n  broadcast_args.emplace_back(\"moving_mean\", 0);\n  broadcast_args.emplace_back(\"moving_variance\", 0);\n  broadcast_args.emplace_back(\"gamma\", 0);\n  broadcast_args.emplace_back(\"beta\", 0);\n  if (ctx->user_op_conf().has_output(\"mean\", 0)) { broadcast_args.emplace_back(\"mean\", 0); }\n  if (ctx->user_op_conf().has_output(\"inv_variance\", 0)) {\n    broadcast_args.emplace_back(\"inv_variance\", 0);\n  }\n  if (ctx->user_op_conf().has_output(\"reserve_space\", 0)) {\n    broadcast_args.emplace_back(\"reserve_space\", 0);\n  }\n  ctx->NewBuilder().Broadcast(broadcast_args).Split(split_args, 0).Build();\n  return Maybe<void>::Ok();\n}\n\n
user_op::TensorDescInferFn MakeFwTensorDescInferFn(\n    const std::function<Maybe<void>(user_op::InferContext* ctx, const user_op::TensorDesc* x,\n                                    user_op::TensorDesc* reserve_space)>& reserve_space_infer_fn) {\n  return [reserve_space_infer_fn](user_op::InferContext* ctx) -> Maybe<void> {\n#ifdef WITH_CUDA\n    // assume cudnn is enabled\n    CHECK_GE_OR_RETURN(ctx->Attr<float>(\"epsilon\"), CUDNN_BN_MIN_EPSILON);\n#endif\n    const auto& x = ctx->InputTensorDesc(\"x\", 0);\n    const auto data_type = x.data_type();\n    const Shape& x_shape = x.shape();\n    if (ctx->has_input(\"addend\", 0)) {\n      const auto& addend = ctx->InputTensorDesc(\"addend\", 0);\n      CHECK_EQ_OR_RETURN(addend.data_type(), data_type)\n          << \"InferDataType Failed. Expected \" << DataType_Name(data_type) << \", but got \"\n          << DataType_Name(addend.data_type());\n      CHECK_EQ_OR_RETURN(addend.shape(), x_shape);\n    }\n    if (ctx->has_input(\"_add_to_output\", 0)) {\n      const auto& add_to_output = ctx->InputTensorDesc(\"_add_to_output\", 0);\n      CHECK_EQ_OR_RETURN(add_to_output.data_type(), data_type)\n          << \"InferDataType Failed. Expected \" << DataType_Name(data_type)\n          << \", but got \" << DataType_Name(add_to_output.data_type());\n      CHECK_EQ_OR_RETURN(add_to_output.shape(), x_shape);\n    }\n    *ctx->MutOutputTensorDesc(\"y\", 0) = x;\n    const auto axis = ctx->Attr<int32_t>(\"axis\");\n    CHECK_GE_OR_RETURN(axis, 0);\n    CHECK_LT_OR_RETURN(axis, x_shape.NumAxes());\n    const Shape param_shape({x_shape.At(axis)});\n    const auto CheckParamTensorDesc = MakeCheckParamTensorDescFn(ctx, param_shape);\n    const auto SetParamTensorDesc = MakeSetParamTensorDescFn(ctx, param_shape);\n    if (ctx->has_input(\"moving_mean\", 0)) {\n      CHECK_OR_RETURN(ctx->has_input(\"moving_variance\", 0));\n      JUST(CheckParamTensorDesc(\"moving_mean\"));\n      JUST(CheckParamTensorDesc(\"moving_variance\"));\n    }\n    JUST(CheckParamTensorDesc(\"beta\"));\n    JUST(CheckParamTensorDesc(\"gamma\"));\n    JUST(SetParamTensorDesc(\"mean\"));\n    JUST(SetParamTensorDesc(\"inv_variance\"));\n    if (ctx->has_output(\"reserve_space\", 0)) {\n      CHECK_OR_RETURN(reserve_space_infer_fn);\n      JUST(reserve_space_infer_fn(ctx, &x, ctx->MutOutputTensorDesc(\"reserve_space\", 0)));\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\n
user_op::DataTypeInferFn MakeFwDataTypeInferFn(\n    const std::function<Maybe<void>(user_op::InferContext* ctx, const user_op::TensorDesc* x,\n                                    user_op::TensorDesc* reserve_space)>& reserve_space_infer_fn) {\n  return [reserve_space_infer_fn](user_op::InferContext* ctx) -> Maybe<void> {\n    const auto& x = ctx->InputTensorDesc(\"x\", 0);\n    const auto data_type = x.data_type();\n    if (ctx->has_input(\"addend\", 0)) {\n      const auto& addend = ctx->InputTensorDesc(\"addend\", 0);\n      CHECK_EQ_OR_RETURN(addend.data_type(), data_type)\n          << \"InferDataType Failed. Expected \" << DataType_Name(data_type) << \", but got \"\n          << DataType_Name(addend.data_type());\n    }\n    if (ctx->has_input(\"_add_to_output\", 0)) {\n      const auto& add_to_output = ctx->InputTensorDesc(\"_add_to_output\", 0);\n      CHECK_EQ_OR_RETURN(add_to_output.data_type(), data_type)\n          << \"InferDataType Failed. Expected \" << DataType_Name(data_type) << \", but got \"\n          << DataType_Name(add_to_output.data_type());\n    }\n    *ctx->MutOutputTensorDesc(\"y\", 0) = x;\n    const DataType param_data_type =\n        (data_type == DataType::kFloat16 || data_type == DataType::kBFloat16) ? DataType::kFloat\n                                                                              : data_type;\n    const auto CheckParamDataType = MakeCheckParamDataTypeFn(ctx, param_data_type);\n    const auto SetParamDataType = MakeSetParamDataTypeFn(ctx, param_data_type);\n    if (ctx->has_input(\"moving_mean\", 0)) {\n      CHECK_OR_RETURN(ctx->has_input(\"moving_variance\", 0));\n      JUST(CheckParamDataType(\"moving_mean\"));\n      JUST(CheckParamDataType(\"moving_variance\"));\n    }\n    CHECK_OR_RETURN(ctx->has_input(\"gamma\", 0));\n    JUST(CheckParamDataType(\"beta\"));\n    JUST(CheckParamDataType(\"gamma\"));\n    JUST(SetParamDataType(\"mean\"));\n    JUST(SetParamDataType(\"inv_variance\"));\n    if (ctx->has_output(\"reserve_space\", 0)) {\n      CHECK_OR_RETURN(reserve_space_infer_fn);\n      JUST(reserve_space_infer_fn(ctx, &x, ctx->MutOutputTensorDesc(\"reserve_space\", 0)));\n    }\n    return Maybe<void>::Ok();\n  };\n}\n\n
user_op::TensorDescInferFn MakeFwTensorDescInferFn() {\n  return MakeFwTensorDescInferFn(\n      std::function<Maybe<void>(user_op::InferContext * ctx, const user_op::TensorDesc* x,\n                                user_op::TensorDesc* reserve_space)>());\n}\n\nuser_op::DataTypeInferFn MakeFwDataTypeInferFn() {\n  return MakeFwDataTypeInferFn(\n      std::function<Maybe<void>(user_op::InferContext * ctx, const user_op::TensorDesc* x,\n                                user_op::TensorDesc* reserve_space)>());\n}\n\n}  // namespace\n\n
/* static */ Maybe<void> NormalizationOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return MakeFwTensorDescInferFn()(ctx);\n}\n\n/*static*/ Maybe<void> NormalizationOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> NormalizationOp::GetSbp(user_op::SbpContext* ctx) {\n  return FwGetSbpFn(ctx);\n}\n\n/* static */ Maybe<void> NormalizationOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return FwInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> NormalizationOp::InferDataType(user_op::InferContext* ctx) {\n  return MakeFwDataTypeInferFn()(ctx);\n}\n\n
/* static */ Maybe<void> NormalizationAddReluOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,\n                                    user_op::TensorDesc* reserve_space) -> Maybe<void> {\n    const auto& x_desc = ctx->InputTensorDesc(\"x\", 0);\n    size_t reserve_space_bits = x_desc.shape().elem_cnt();\n    int64_t parallel_num = ctx->parallel_num();\n    if (parallel_num != 1) {\n      // There is no need to call NdSbp4ArgNameAndIndex when parallel_num = 1 in local mode.\n      const NdSbp& x_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"x\", 0);\n      const Shape& hierarchy = *ctx->parallel_desc().hierarchy();\n      int64_t split_num = 1;\n      for (int32_t i = 0; i < x_nd_sbp.sbp_parallel_size(); ++i) {\n        if (x_nd_sbp.sbp_parallel(i).has_split_parallel()) {\n          CHECK_EQ_OR_RETURN(x_nd_sbp.sbp_parallel(i).split_parallel().axis(), 0)\n              << \"blob x in NormalizationAddReluOp only support B or S(0)\";\n          split_num *= hierarchy.At(i);\n        }\n      }\n      CHECK_EQ_OR_RETURN(reserve_space_bits % split_num, 0);\n      reserve_space_bits = reserve_space_bits / split_num;\n    }\n    reserve_space->set_shape(Shape({static_cast<int64_t>(RoundUp(reserve_space_bits, 32) / 
32)}));\n    return Maybe<void>::Ok();\n  })(ctx);\n}\n\n/* static */ Maybe<void> NormalizationAddReluOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,\n                                    user_op::TensorDesc* reserve_space) -> Maybe<void> {\n    const auto& x_desc = ctx->InputTensorDesc(\"x\", 0);\n    reserve_space->set_shape(\n        Shape({static_cast<int64_t>(RoundUp(x_desc.shape().elem_cnt(), 32) / 32)}));\n    return Maybe<void>::Ok();\n  })(ctx);\n}\n\n/* static */ Maybe<void> NormalizationAddReluOp::GetSbp(user_op::SbpContext* ctx) {\n  return FwGetSbpFn(ctx);\n}\n\n/* static */ Maybe<void> NormalizationAddReluOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return FwInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> NormalizationAddReluOp::InferDataType(user_op::InferContext* ctx) {\n  return MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,\n                                  user_op::TensorDesc* reserve_space) -> Maybe<void> {\n    reserve_space->set_data_type(DataType::kInt32);\n    return Maybe<void>::Ok();\n  })(ctx);\n}\n\n#if defined(WITH_CUDA) && (CUDNN_VERSION >= 7401)\n\nnamespace {\n\nvoid InferCudnnReserveSpaceSize(DataType data_type, cudnnBatchNormOps_t ops, int64_t n, int64_t c,\n                                int64_t h, int64_t w, size_t* reserve_space_size) {\n  cudnnHandle_t cudnn_handle = Singleton<CudnnHandlePool>::Get()->Get();\n  CudnnTensorDesc xy_desc(CUDNN_TENSOR_NHWC, data_type, n, c, h, w);\n  CudnnActivationDesc activation_desc(CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0);\n  OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(\n      cudnn_handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, activation_desc.Get(), xy_desc.Get(),\n      reserve_space_size));\n  Singleton<CudnnHandlePool>::Get()->Put(cudnn_handle);\n}\n\n}  // namespace\n\n/* static */ Maybe<void> CudnnFusedNormalizationAddReluOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,\n                                    user_op::TensorDesc* reserve_space) -> Maybe<void> {\n    const Shape& x_shape = x->shape();\n    const auto axis = ctx->Attr<int32_t>(\"axis\");\n    CHECK_EQ_OR_RETURN(x_shape.Count(axis + 1), 1);\n    int64_t n = x_shape.At(0);\n    {\n      const auto& x_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"x\", 0);\n      const Shape& hierarchy = *ctx->parallel_desc().hierarchy();\n      int64_t split_num = 1;\n      for (int32_t i = 0; i < x_nd_sbp.sbp_parallel_size(); ++i) {\n        if (x_nd_sbp.sbp_parallel(i).has_split_parallel()) {\n          CHECK_EQ_OR_RETURN(x_nd_sbp.sbp_parallel(i).split_parallel().axis(), 0)\n              << \"blob x in CudnnFusedNormalizationAddReluOp only support B or S(0)\";\n          split_num *= hierarchy.At(i);\n        }\n      }\n      CHECK_EQ_OR_RETURN(n % split_num, 0);\n      n = n / split_num;\n    }\n    int64_t h = x_shape.Count(1, axis);\n    int64_t w = 1;\n    int64_t c = x_shape.At(axis);\n    cudnnBatchNormOps_t ops;\n    if (ctx->has_input(\"addend\", 0)) {\n      ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;\n    } else {\n      ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;\n    }\n    size_t reserve_space_size;\n    InferCudnnReserveSpaceSize(x->data_type(), ops, n, c, h, w, 
&reserve_space_size);\n    reserve_space_size = std::max(reserve_space_size, GetOneVal<size_t>());\n    reserve_space->set_shape(Shape({static_cast<int64_t>(reserve_space_size)}));\n    return Maybe<void>::Ok();\n  })(ctx);\n}\n\n/* static */ Maybe<void> CudnnFusedNormalizationAddReluOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,\n                                    user_op::TensorDesc* reserve_space) -> Maybe<void> {\n    const Shape& x_shape = x->shape();\n    const auto axis = ctx->Attr<int32_t>(\"axis\");\n    CHECK_EQ_OR_RETURN(x_shape.Count(axis + 1), 1);\n    int64_t n = x_shape.At(0);\n    int64_t h = x_shape.Count(1, axis);\n    int64_t w = 1;\n    int64_t c = x_shape.At(axis);\n    cudnnBatchNormOps_t ops;\n    if (ctx->has_input(\"addend\", 0)) {\n      ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;\n    } else {\n      ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;\n    }\n    size_t reserve_space_size;\n    InferCudnnReserveSpaceSize(x->data_type(), ops, n, c, h, w, &reserve_space_size);\n    reserve_space_size = std::max(reserve_space_size, GetOneVal<size_t>());\n    reserve_space->set_shape(Shape({static_cast<int64_t>(reserve_space_size)}));\n    return Maybe<void>::Ok();\n  })(ctx);\n}\n\n/* static */ Maybe<void> CudnnFusedNormalizationAddReluOp::GetSbp(user_op::SbpContext* ctx) {\n  return FwGetSbpFn(ctx);\n}\n\n/* static */ Maybe<void> CudnnFusedNormalizationAddReluOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return FwInputArgModifyFn(GetInputArgModifierFn, conf);\n}\n\n/* static */ Maybe<void> CudnnFusedNormalizationAddReluOp::InferDataType(\n    user_op::InferContext* ctx) {\n  return MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,\n                                  user_op::TensorDesc* reserve_space) -> Maybe<void> {\n    reserve_space->set_data_type(DataType::kChar);\n    return Maybe<void>::Ok();\n  })(ctx);\n}\n\n#else\n\n/* static */ Maybe<void> CudnnFusedNormalizationAddReluOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return Error::UnimplementedError() << \"require CUDA and CuDNN >= 7401\";\n}\n\n/* static */ Maybe<void> CudnnFusedNormalizationAddReluOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return Error::UnimplementedError() << \"require CUDA and CuDNN >= 7401\";\n}\n\n/* static */ Maybe<void> CudnnFusedNormalizationAddReluOp::GetSbp(user_op::SbpContext* ctx) {\n  return Error::UnimplementedError() << \"require CUDA and CuDNN >= 7401\";\n}\n\n/* static */ Maybe<void> CudnnFusedNormalizationAddReluOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  return Error::UnimplementedError() << \"require CUDA and CuDNN >= 7401\";\n}\n\n/* static */ Maybe<void> CudnnFusedNormalizationAddReluOp::InferDataType(\n    user_op::InferContext* ctx) {\n  return Error::UnimplementedError() << \"require CUDA and CuDNN >= 7401\";\n}\n\n#endif  // WITH_CUDA\n\nnamespace {\n\nMaybe<void> BwTensorDescInferFn(user_op::InferContext* ctx) {\n#ifdef WITH_CUDA\n  // assume cudnn is enabled\n  CHECK_GE_OR_RETURN(ctx->Attr<float>(\"epsilon\"), CUDNN_BN_MIN_EPSILON);\n#endif\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  const Shape& x_shape = x.shape();\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  
CHECK_EQ_OR_RETURN(dy.shape(), x_shape);\n  if (ctx->has_input(\"y\", 0)) {\n    const user_op::TensorDesc& y = ctx->InputTensorDesc(\"y\", 0);\n    CHECK_EQ_OR_RETURN(y.shape(), x_shape);\n  }\n  *ctx->MutOutputTensorDesc(\"dx\", 0) = x;\n  if (ctx->has_output(\"addend_diff\", 0)) { *ctx->MutOutputTensorDesc(\"addend_diff\", 0) = x; }\n  const Shape param_shape({x_shape.At(ctx->Attr<int32_t>(\"axis\"))});\n  const auto CheckParamTensorDesc = MakeCheckParamTensorDescFn(ctx, param_shape);\n  const auto SetParamTensorDesc = MakeSetParamTensorDescFn(ctx, param_shape);\n  JUST(CheckParamTensorDesc(\"mean\"));\n  JUST(CheckParamTensorDesc(\"inv_variance\"));\n  JUST(CheckParamTensorDesc(\"gamma\"));\n  JUST(CheckParamTensorDesc(\"beta\"));\n  JUST(SetParamTensorDesc(\"gamma_diff\"));\n  JUST(SetParamTensorDesc(\"beta_diff\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BwDataTypeInferFn(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  const DataType x_type = x.data_type();\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  CHECK_EQ_OR_RETURN(dy.data_type(), x_type)\n      << \"InferDataType Failed. Expected \" << DataType_Name(x_type) << \", but got \"\n      << DataType_Name(dy.data_type());\n  if (ctx->has_input(\"y\", 0)) {\n    const user_op::TensorDesc& y = ctx->InputTensorDesc(\"y\", 0);\n    CHECK_EQ_OR_RETURN(y.data_type(), x_type)\n        << \"InferDataType Failed. Expected \" << DataType_Name(x_type) << \", but got \"\n        << DataType_Name(y.data_type());\n  }\n  *ctx->MutOutputTensorDesc(\"dx\", 0) = x;\n  if (ctx->has_output(\"addend_diff\", 0)) { *ctx->MutOutputTensorDesc(\"addend_diff\", 0) = x; }\n  const DataType param_data_type =\n      (x_type == DataType::kFloat16 || x_type == DataType::kBFloat16) ? 
DataType::kFloat : x_type;\n  const auto CheckParamDataType = MakeCheckParamDataTypeFn(ctx, param_data_type);\n  const auto SetParamDataType = MakeSetParamDataTypeFn(ctx, param_data_type);\n  JUST(CheckParamDataType(\"mean\"));\n  JUST(CheckParamDataType(\"inv_variance\"));\n  JUST(CheckParamDataType(\"gamma\"));\n  JUST(CheckParamDataType(\"beta\"));\n  JUST(SetParamDataType(\"gamma_diff\"));\n  JUST(SetParamDataType(\"beta_diff\"));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BwGetSbpFn(user_op::SbpContext* ctx) {\n  std::vector<user_op::OpArg> broadcast_args;\n  broadcast_args.emplace_back(\"mean\", 0);\n  broadcast_args.emplace_back(\"inv_variance\", 0);\n  broadcast_args.emplace_back(\"gamma\", 0);\n  if (ctx->user_op_conf().has_input(\"beta\", 0)) { broadcast_args.emplace_back(\"beta\", 0); }\n  if (ctx->user_op_conf().has_input(\"reserve_space\", 0)) {\n    broadcast_args.emplace_back(\"reserve_space\", 0);\n  }\n  std::vector<user_op::OpArg> partial_sum_args;\n  partial_sum_args.emplace_back(\"gamma_diff\", 0);\n  partial_sum_args.emplace_back(\"beta_diff\", 0);\n  std::vector<user_op::OpArg> split_args;\n  split_args.emplace_back(\"x\", 0);\n  split_args.emplace_back(\"dy\", 0);\n  split_args.emplace_back(\"dx\", 0);\n  if (ctx->user_op_conf().has_input(\"y\", 0)) { split_args.emplace_back(\"y\", 0); }\n  if (ctx->user_op_conf().has_output(\"addend_diff\", 0)) {\n    split_args.emplace_back(\"addend_diff\", 0);\n  }\n  ctx->NewBuilder()\n      .Broadcast(broadcast_args)\n      .PartialSum(partial_sum_args)\n      .Split(split_args, 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> NormalizationGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return BwTensorDescInferFn(ctx);\n}\n\n/*static*/ Maybe<void> NormalizationGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> NormalizationGradOp::GetSbp(user_op::SbpContext* ctx) {\n  return BwGetSbpFn(ctx);\n}\n\n/* static */ Maybe<void> NormalizationGradOp::InferDataType(user_op::InferContext* ctx) {\n  return BwDataTypeInferFn(ctx);\n}\n\n/* static */ Maybe<void> NormalizationAddReluGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return BwTensorDescInferFn(ctx);\n}\n\n/*static*/ Maybe<void> NormalizationAddReluGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> NormalizationAddReluGradOp::GetSbp(user_op::SbpContext* ctx) {\n  return BwGetSbpFn(ctx);\n}\n\n/* static */ Maybe<void> NormalizationAddReluGradOp::InferDataType(user_op::InferContext* ctx) {\n  return BwDataTypeInferFn(ctx);\n}\n\n#if defined(WITH_CUDA) && (CUDNN_VERSION >= 7401)\n\n/* static */ Maybe<void> CudnnFusedNormalizationAddReluGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return BwTensorDescInferFn(ctx);\n}\n\n/*static*/ Maybe<void> CudnnFusedNormalizationAddReluGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> CudnnFusedNormalizationAddReluGradOp::GetSbp(user_op::SbpContext* ctx) {\n  return BwGetSbpFn(ctx);\n}\n\n/* static */ Maybe<void> CudnnFusedNormalizationAddReluGradOp::InferDataType(\n    user_op::InferContext* ctx) {\n  return BwDataTypeInferFn(ctx);\n}\n\n#else\n\n/* static */ Maybe<void> CudnnFusedNormalizationAddReluGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return 
Error::UnimplementedError() << \"require CUDA and CuDNN >= 7401\";\n}\n\n/*static*/ Maybe<void> CudnnFusedNormalizationAddReluGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return Error::UnimplementedError() << \"require CUDA and CuDNN >= 7401\";\n}\n\n/* static */ Maybe<void> CudnnFusedNormalizationAddReluGradOp::GetSbp(user_op::SbpContext* ctx) {\n  return Error::UnimplementedError() << \"require CUDA and CuDNN >= 7401\";\n}\n\n/* static */ Maybe<void> CudnnFusedNormalizationAddReluGradOp::InferDataType(\n    user_op::InferContext* ctx) {\n  return Error::UnimplementedError() << \"require CUDA and CuDNN >= 7401\";\n}\n\n#endif\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/nvtx_range_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n#ifdef WITH_CUDA\n\n/* static */ Maybe<void> NvtxStartOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> NvtxStartOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> NvtxStartOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> NvtxStartOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> NvtxEndOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> NvtxEndOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> NvtxEndOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> NvtxEndOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n#else\n\n/* static */ Maybe<void> NvtxStartOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return Error::UnimplementedError() << \"require CUDA to use NVTX\";\n}\n\n/*static*/ Maybe<void> NvtxStartOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> NvtxStartOp::GetSbp(user_op::SbpContext* ctx) {\n  return Error::UnimplementedError() << \"require CUDA to use NVTX\";\n}\n\n/* static */ Maybe<void> NvtxStartOp::InferDataType(user_op::InferContext* ctx) {\n  return Error::UnimplementedError() << \"require CUDA to use NVTX\";\n}\n\n/* 
static */ Maybe<void> NvtxEndOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return Error::UnimplementedError() << \"require CUDA to use NVTX\";\n}\n\n/*static*/ Maybe<void> NvtxEndOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return Error::UnimplementedError() << \"require CUDA to use NVTX\";\n}\n\n/* static */ Maybe<void> NvtxEndOp::GetSbp(user_op::SbpContext* ctx) {\n  return Error::UnimplementedError() << \"require CUDA to use NVTX\";\n}\n\n/* static */ Maybe<void> NvtxEndOp::InferDataType(user_op::InferContext* ctx) {\n  return Error::UnimplementedError() << \"require CUDA to use NVTX\";\n}\n\n#endif  // WITH_CUDA\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/ofrecord_decoder_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> OfrecordRawDecoderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1);\n  Shape conf_shape = ctx->Attr<Shape>(\"shape\");\n  DimVector dim_vec(1 + conf_shape.NumAxes());\n  dim_vec[0] = in_tensor.shape().At(0);\n  for (int i = 1; i < dim_vec.size(); ++i) { dim_vec[i] = conf_shape.At(i - 1); }\n  out_tensor->set_shape(Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> OfrecordRawDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> OfrecordRawDecoderOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), 0).Split(user_op::OpArg(\"out\", 0), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OfrecordRawDecoderOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* in_modifier = GetInputArgModifierFn(\"in\", 0);\n  CHECK_NOTNULL_OR_RETURN(in_modifier);\n  in_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OfrecordRawDecoderOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord);\n  out_tensor->set_data_type(ctx->Attr<DataType>(\"data_type\"));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OfrecordBytesDecoderOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_is_dynamic(in.is_dynamic());\n  out->set_shape(in.shape());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> OfrecordBytesDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> OfrecordBytesDecoderOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);\n}\n\n/* static */ Maybe<void> OfrecordBytesDecoderOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* in_modifier = GetInputArgModifierFn(\"in\", 0);\n  CHECK_NOTNULL_OR_RETURN(in_modifier);\n  in_modifier->set_requires_grad(false);\n  return 
Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OfrecordBytesDecoderOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_OR_RETURN(in.data_type() == DataType::kOFRecord);\n  out->set_data_type(DataType::kTensorBuffer);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OfrecordImageDecoderOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1);\n  out_tensor->set_shape(in_tensor.shape());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> OfrecordImageDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> OfrecordImageDecoderOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), 0).Split(user_op::OpArg(\"out\", 0), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OfrecordImageDecoderOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* in_modifier = GetInputArgModifierFn(\"in\", 0);\n  CHECK_NOTNULL_OR_RETURN(in_modifier);\n  in_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OfrecordImageDecoderOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord);\n  out_tensor->set_data_type(DataType::kTensorBuffer);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OfrecordImageDecoderRandomCropOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1);\n  out_tensor->set_shape(in_tensor.shape());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> OfrecordImageDecoderRandomCropOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> OfrecordImageDecoderRandomCropOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), 0).Split(user_op::OpArg(\"out\", 0), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OfrecordImageDecoderRandomCropOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* in_modifier = GetInputArgModifierFn(\"in\", 0);\n  CHECK_NOTNULL_OR_RETURN(in_modifier);\n  in_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OfrecordImageDecoderRandomCropOp::InferDataType(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord);\n  out_tensor->set_data_type(DataType::kTensorBuffer);\n  return 
Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/ofrecord_image_classification_reader_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> OfrecordImageClassificationReaderOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  user_op::TensorDesc* image_tensor = ctx->MutOutputTensorDesc(\"image\", 0);\n  user_op::TensorDesc* label_tensor = ctx->MutOutputTensorDesc(\"label\", 0);\n  int32_t batch_size = ctx->Attr<int32_t>(\"batch_size\");\n  image_tensor->set_shape(Shape({batch_size}));\n  label_tensor->set_shape(Shape({batch_size}));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OfrecordImageClassificationReaderOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  user_op::TensorDesc* image_tensor = ctx->MutOutputTensorDesc(\"image\", 0);\n  user_op::TensorDesc* label_tensor = ctx->MutOutputTensorDesc(\"label\", 0);\n  int32_t local_batch_size = ctx->Attr<int32_t>(\"batch_size\");\n  int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n\n  if (parallel_num > 1) {\n    int64_t split_num = 1;\n    const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"image\", 0);\n    const Shape& hierarchy = *ctx->parallel_desc().hierarchy();\n    for (int32_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) {\n      if (nd_sbp.sbp_parallel(i).has_split_parallel()) { split_num *= hierarchy.At(i); }\n    }\n    CHECK_EQ_OR_RETURN(local_batch_size % split_num, 0);\n    local_batch_size /= split_num;\n  }\n  image_tensor->set_shape(Shape({local_batch_size}));\n  label_tensor->set_shape(Shape({local_batch_size}));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OfrecordImageClassificationReaderOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Split(ctx->outputs(), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OfrecordImageClassificationReaderOp::ModifyOutputArg(\n    const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::OutputArgModifier* image_modifier = GetOutputArgModifierFn(\"image\", 0);\n  CHECK_OR_RETURN(image_modifier != nullptr);\n  image_modifier->set_header_infered_before_compute(false);\n  user_op::OutputArgModifier* label_modifier = GetOutputArgModifierFn(\"label\", 0);\n  CHECK_OR_RETURN(label_modifier != nullptr);\n  label_modifier->set_header_infered_before_compute(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OfrecordImageClassificationReaderOp::InferDataType(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"image\", 0, DataType::kTensorBuffer);\n  ctx->SetOutputDType(\"label\", 0, DataType::kTensorBuffer);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/ofrecord_reader_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n/* static */ Maybe<void> OFRecordReaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_tensor->set_shape(Shape({ctx->Attr<int32_t>(\"batch_size\")}));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OFRecordReaderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  int32_t batch_size = ctx->Attr<int32_t>(\"batch_size\");\n  int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n  if (parallel_num > 1) {\n    int64_t split_num = 1;\n    const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n    const Shape& hierarchy = *ctx->parallel_desc().hierarchy();\n    for (int32_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) {\n      if (nd_sbp.sbp_parallel(i).has_split_parallel()) { split_num *= hierarchy.At(i); }\n    }\n    CHECK_EQ_OR_RETURN(batch_size % split_num, 0);\n    batch_size /= split_num;\n  }\n  out_tensor->set_shape(Shape({batch_size}));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OFRecordReaderOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Split(ctx->outputs(), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OFRecordReaderOp::GetNdSbpSignatureList(\n    user_op::GetNdSbpSignatureListContext* ctx) {\n  NdSbpSignature nd_sbp_signature;\n  SbpParallel split_sbp_parallel;\n  split_sbp_parallel.mutable_split_parallel()->set_axis(0);\n  for (int32_t dim_sbp = 0; dim_sbp < ctx->parallel_hierarchy().NumAxes(); dim_sbp++) {\n    *(*nd_sbp_signature.mutable_bn_in_op2nd_sbp())[GenRepeatedBn(\"out\", 0)].add_sbp_parallel() =\n        split_sbp_parallel;\n  }\n  ctx->AddNdSbpSignature(nd_sbp_signature);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<double> OFRecordReaderOp::GetComputeComplexity(\n    user_op::ComputeComplexityFnContext* ctx) {\n  // Don't support broadcast.\n  return double(ctx->Shape4ArgNameAndIndex(\"out\", 0).elem_cnt()\n                * GetSizeOfDataType(DataType::kOFRecord))\n         / ctx->parallel_desc().hierarchy()->elem_cnt();\n}\n\n/* static */ Maybe<void> OFRecordReaderOp::ModifyOutputArg(\n    const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn(\"out\", 0);\n  CHECK_OR_RETURN(out_modifier != nullptr);\n  // NOTE(chengcheng): OFRecordReader Only support static shape infer which will read all batch\n  //  size data with output shape (batch_size,)\n  // out_modifier->set_header_infered_before_compute(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> 
OFRecordReaderOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  SbpParallel default_sbp;\n  default_sbp.mutable_split_parallel()->set_axis(0);\n  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);\n}\n\n/* static */ Maybe<void> OFRecordReaderOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, DataType::kOFRecord);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/one_embedding_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/embedding/embedding_manager.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> OneEmbeddingFusedLookupOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const Shape& ids_shape = ctx->InputShape(\"ids\", 0);\n  if (ctx->has_input(\"table_ids\", 0)) {\n    const Shape& table_ids_shape = ctx->InputShape(\"table_ids\", 0);\n    CHECK_EQ_OR_RETURN(ids_shape, table_ids_shape) << \"table_ids shape must equal to ids shape\";\n  }\n  DimVector out_dim_vec = ids_shape.dim_vec();\n  const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n  out_dim_vec.push_back(embedding_size);\n  ctx->SetOutputShape(\"embeddings\", 0, Shape(out_dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> OneEmbeddingFusedLookupOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> OneEmbeddingFusedLookupOp::GetSbp(user_op::SbpContext* ctx) {\n  auto builder = ctx->NewBuilder()\n                     .Broadcast(user_op::OpArg(\"shadow\", 0))\n                     .Split(user_op::OpArg(\"ids\", 0), 0)\n                     .Split(user_op::OpArg(\"embeddings\", 0), 0);\n  if (ctx->user_op_conf().has_input(\"table_ids\", 0)) {\n    builder.Split(user_op::OpArg(\"table_ids\", 0), 0);\n  }\n  builder.Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingFusedLookupOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* shadow = GetInputArgModifierFn(\"shadow\", 0);\n  CHECK_OR_RETURN(shadow != nullptr) << \"shadow is nullptr\";\n  shadow->set_requires_grad(false);\n  user_op::InputArgModifier* ids = GetInputArgModifierFn(\"ids\", 0);\n  CHECK_OR_RETURN(ids != nullptr);\n  ids->set_requires_grad(false);\n  if (conf.has_input(\"table_ids\", 0)) {\n    user_op::InputArgModifier* table_ids = GetInputArgModifierFn(\"table_ids\", 0);\n    CHECK_OR_RETURN(table_ids != nullptr) << \"table_ids is nullptr\";\n    table_ids->set_requires_grad(false);\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingFusedLookupOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"embeddings\", 0, ctx->InputDType(\"shadow\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingFusedLookupGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> OneEmbeddingFusedLookupGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> OneEmbeddingFusedLookupGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"ids\", 0), 0)\n      
.Split(user_op::OpArg(\"embedding_grad\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingFusedLookupGradOp::InferDataType(user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EmbeddingPrefetchOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& num_unique_ids_shape = ctx->InputShape(\"num_unique_ids\", 0);\n  const Shape& unique_ids_shape = ctx->InputShape(\"unique_ids\", 0);\n  const Shape& table_ids_shape = ctx->InputShape(\"table_ids\", 0);\n  CHECK_EQ_OR_RETURN(unique_ids_shape, table_ids_shape)\n      << \"table_ids shape must equal to ids shape\";\n  CHECK_EQ_OR_RETURN(num_unique_ids_shape.elem_cnt(), 1);\n  ctx->SetOutputShape(\"context\", 0, num_unique_ids_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EmbeddingPrefetchOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EmbeddingPrefetchOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"num_unique_ids\", 0))\n      .Split(user_op::OpArg(\"unique_ids\", 0), 0)\n      .Split(user_op::OpArg(\"table_ids\", 0), 0)\n      .Broadcast(user_op::OpArg(\"context\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EmbeddingPrefetchOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"context\", 0, ctx->InputDType(\"num_unique_ids\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EmbeddingLookupOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& num_unique_ids_shape = ctx->InputShape(\"num_unique_ids\", 0);\n  const Shape& unique_ids_shape = ctx->InputShape(\"unique_ids\", 0);\n  const Shape& table_ids_shape = ctx->InputShape(\"table_ids\", 0);\n  CHECK_EQ_OR_RETURN(unique_ids_shape, table_ids_shape)\n      << \"table_ids shape must equal to ids shape\";\n  CHECK_EQ_OR_RETURN(num_unique_ids_shape.elem_cnt(), 1);\n  const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n  const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n  CHECK_NE_OR_RETURN(embedding_size, 0);\n  CHECK_NE_OR_RETURN(line_size, 0);\n  CHECK_GE_OR_RETURN(line_size, embedding_size);\n  const bool use_dynamic_memory_allocation = embedding::UseDynamicMemoryAllocation();\n  if (ctx->has_output(\"embeddings\", 0)) {\n    if (use_dynamic_memory_allocation) {\n      ctx->SetOutputShape(\"embeddings\", 0, Shape({1}));\n    } else {\n      DimVector embeddings_dim_vec = unique_ids_shape.dim_vec();\n      embeddings_dim_vec.push_back(embedding_size);\n      ctx->SetOutputShape(\"embeddings\", 0, Shape(embeddings_dim_vec));\n    }\n  }\n  if (use_dynamic_memory_allocation) {\n    ctx->SetOutputShape(\"unique_values\", 0, Shape({1}));\n  } else {\n    DimVector unique_values_dim_vec = unique_ids_shape.dim_vec();\n    unique_values_dim_vec.push_back(line_size);\n    ctx->SetOutputShape(\"unique_values\", 0, Shape(unique_values_dim_vec));\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EmbeddingLookupOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EmbeddingLookupOp::GetSbp(user_op::SbpContext* ctx) {\n  auto builder = ctx->NewBuilder()\n                     .Broadcast(user_op::OpArg(\"num_unique_ids\", 0))\n                     .Split(user_op::OpArg(\"unique_ids\", 0), 0)\n                     .Split(user_op::OpArg(\"table_ids\", 0), 0);\n  
if (ctx->user_op_conf().has_input(\"context\", 0)) {\n    builder.Broadcast(user_op::OpArg(\"context\", 0));\n  }\n  const bool use_dynamic_memory_allocation = embedding::UseDynamicMemoryAllocation();\n  if (use_dynamic_memory_allocation) {\n    builder.Broadcast(user_op::OpArg(\"unique_values\", 0));\n  } else {\n    builder.Split(user_op::OpArg(\"unique_values\", 0), 0);\n  }\n  if (ctx->user_op_conf().has_output(\"embeddings\", 0)) {\n    if (use_dynamic_memory_allocation) {\n      builder.Broadcast(user_op::OpArg(\"embeddings\", 0));\n    } else {\n      builder.Split(user_op::OpArg(\"embeddings\", 0), 0);\n    }\n  }\n  builder.Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EmbeddingLookupOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"unique_values\", 0, ctx->Attr<DataType>(\"dtype\"));\n  if (ctx->has_output(\"embeddings\", 0)) {\n    ctx->SetOutputDType(\"embeddings\", 0, ctx->Attr<DataType>(\"embeddings_dtype\"));\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EmbeddingPutOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> EmbeddingPutOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> EmbeddingPutOp::GetSbp(user_op::SbpContext* ctx) {\n  auto builder = ctx->NewBuilder()\n                     .Broadcast(user_op::OpArg(\"num_unique_ids\", 0))\n                     .Split(user_op::OpArg(\"unique_ids\", 0), 0);\n  if (embedding::UseDynamicMemoryAllocation()) {\n    builder.Broadcast(user_op::OpArg(\"unique_embeddings\", 0)).Build();\n  } else {\n    builder.Split(user_op::OpArg(\"unique_embeddings\", 0), 0).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> EmbeddingPutOp::InferDataType(user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckDataShape(user_op::InferContext* ctx) {\n  if (ctx->has_input(\"learning_rate\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputShape(\"learning_rate\", 0), Shape({1}));\n  }\n  if (ctx->has_input(\"down_scale_by_tensor\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputShape(\"down_scale_by_tensor\", 0), Shape({1}));\n  }\n  CHECK_EQ_OR_RETURN(ctx->InputShape(\"num_unique_ids\", 0), Shape({1}));\n  const Shape& embedding_grad_shape = ctx->InputShape(\"embedding_grad\", 0);\n  CHECK_EQ_OR_RETURN(embedding_grad_shape.NumAxes(), 2);\n  const Shape& unique_embeddings_shape = ctx->InputShape(\"unique_embeddings\", 0);\n  if (embedding::UseDynamicMemoryAllocation()) {\n    CHECK_EQ_OR_RETURN(unique_embeddings_shape.elem_cnt(), 1)\n        << \"if use dynamic memory allocation, unique_embeddings elem_cnt should be 1.\";\n  } else {\n    CHECK_EQ_OR_RETURN(unique_embeddings_shape.NumAxes(), 2)\n        << \"unique_embeddings num_axes should be 2.\";\n    CHECK_EQ_OR_RETURN(unique_embeddings_shape.At(0), embedding_grad_shape.At(0))\n        << \"got \" << unique_embeddings_shape.At(0) << \" and \" << embedding_grad_shape.At(0);\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckDataType(user_op::InferContext* ctx) {\n  if (ctx->has_input(\"learning_rate\", 0)) {\n    const DataType learning_rate_dtype = ctx->InputDType(\"learning_rate\", 0);\n    CHECK_EQ_OR_RETURN(learning_rate_dtype, DataType::kFloat)\n        << \"InferDataType Failed. 
Expected \" << DataType_Name(DataType::kFloat) << \", but got \"\n        << DataType_Name(learning_rate_dtype);\n  }\n  if (ctx->has_input(\"down_scale_by_tensor\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"down_scale_by_tensor\", 0),\n                       ctx->InputDType(\"unique_embeddings\", 0))\n        << \"InferDataType Failed. Expected \"\n        << DataType_Name(ctx->InputDType(\"unique_embeddings\", 0)) << \", but got \"\n        << DataType_Name(ctx->InputDType(\"down_scale_by_tensor\", 0));\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetEmbeddingUpdateSbp(user_op::SbpContext* ctx) {\n  auto builder = ctx->NewBuilder()\n                     .Broadcast(ctx->inputs())\n                     .Broadcast(user_op::OpArg(\"num_unique_ids\", 0))\n                     .Split(user_op::OpArg(\"embedding_grad\", 0), 0);\n  if (embedding::UseDynamicMemoryAllocation()) {\n    builder.Broadcast(user_op::OpArg(\"unique_embeddings\", 0))\n        .Broadcast(user_op::OpArg(\"updated_unique_embeddings\", 0))\n        .Build();\n  } else {\n    builder.Split(user_op::OpArg(\"unique_embeddings\", 0), 0)\n        .Split(user_op::OpArg(\"updated_unique_embeddings\", 0), 0)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingFusedSgdUpdatePutOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> OneEmbeddingFusedSgdUpdatePutOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> OneEmbeddingFusedSgdUpdatePutOp::GetSbp(user_op::SbpContext* ctx) {\n  auto builder = ctx->NewBuilder()\n                     .Broadcast(user_op::OpArg(\"learning_rate\", 0))\n                     .Broadcast(user_op::OpArg(\"num_unique_ids\", 0))\n                     .Split(user_op::OpArg(\"unique_ids\", 0), 0)\n                     .Split(user_op::OpArg(\"embedding_grad\", 0), 0);\n  if (embedding::UseDynamicMemoryAllocation()) {\n    builder.Broadcast(user_op::OpArg(\"unique_embeddings\", 0)).Build();\n  } else {\n    builder.Split(user_op::OpArg(\"unique_embeddings\", 0), 0).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingFusedSgdUpdatePutOp::InferDataType(\n    user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingSgdUpdateOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  JUST(CheckDataShape(ctx));\n  const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n  const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n  CHECK_NE_OR_RETURN(embedding_size, 0) << \"should set attr embedding_size\";\n  CHECK_NE_OR_RETURN(line_size, 0) << \"should set attr line_size\";\n  CHECK_EQ_OR_RETURN(line_size, embedding_size)\n      << \"when use SGD optimizer, line_size should equals to embedding_size, but get line_size: \"\n      << line_size << \" embedding_size: \" << embedding_size\n      << \", please set size_factor of store_options to 1.\";\n  const Shape& unique_embeddings_shape = ctx->InputShape(\"unique_embeddings\", 0);\n  ctx->SetOutputShape(\"updated_unique_embeddings\", 0, unique_embeddings_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> OneEmbeddingSgdUpdateOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> OneEmbeddingSgdUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  JUST(GetEmbeddingUpdateSbp(ctx));\n  
return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingSgdUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  JUST(CheckDataType(ctx));\n  ctx->SetOutputDType(\"updated_unique_embeddings\", 0, ctx->InputDType(\"unique_embeddings\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingMomentumUpdateOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  JUST(CheckDataShape(ctx));\n  const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n  const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n  CHECK_NE_OR_RETURN(embedding_size, 0) << \"should set attr embedding_size\";\n  CHECK_NE_OR_RETURN(line_size, 0) << \"should set attr line_size\";\n  CHECK_EQ_OR_RETURN(line_size, embedding_size * 2)\n      << \"when using Momentum optimizer, line_size should equals to embedding_size * 2, but get \"\n         \"line_size: \"\n      << line_size << \" embedding_size: \" << embedding_size\n      << \", please set size_factor of store_options to 2.\";\n  const Shape& unique_embeddings_shape = ctx->InputShape(\"unique_embeddings\", 0);\n  ctx->SetOutputShape(\"updated_unique_embeddings\", 0, unique_embeddings_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> OneEmbeddingMomentumUpdateOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> OneEmbeddingMomentumUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  JUST(GetEmbeddingUpdateSbp(ctx));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingMomentumUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  JUST(CheckDataType(ctx));\n  ctx->SetOutputDType(\"updated_unique_embeddings\", 0, ctx->InputDType(\"unique_embeddings\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingAdamUpdateOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  JUST(CheckDataShape(ctx));\n  const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n  const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n  CHECK_NE_OR_RETURN(embedding_size, 0) << \"should set attr embedding_size\";\n  CHECK_NE_OR_RETURN(line_size, 0) << \"should set attr line_size\";\n  CHECK_EQ_OR_RETURN(line_size, embedding_size * 3)\n      << \"when using Adam optimizer, line_size should equals to embedding_size * 3, but get \"\n         \"line_size: \"\n      << line_size << \" embedding_size: \" << embedding_size\n      << \", please set size_factor of store_options to 3.\";\n  const Shape& unique_embeddings_shape = ctx->InputShape(\"unique_embeddings\", 0);\n  ctx->SetOutputShape(\"updated_unique_embeddings\", 0, unique_embeddings_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> OneEmbeddingAdamUpdateOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> OneEmbeddingAdamUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  JUST(GetEmbeddingUpdateSbp(ctx));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingAdamUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  JUST(CheckDataType(ctx));\n  ctx->SetOutputDType(\"updated_unique_embeddings\", 0, ctx->InputDType(\"unique_embeddings\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingSmartDecaySparseAdamUpdateOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  JUST(CheckDataShape(ctx));\n  const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n  
const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n  CHECK_NE_OR_RETURN(embedding_size, 0) << \"should set attr embedding_size\";\n  CHECK_NE_OR_RETURN(line_size, 0) << \"should set attr line_size\";\n  const int64_t value_dtype_size = GetSizeOfDataType(ctx->InputDType(\"unique_embeddings\", 0));\n  const int64_t step_dtype_size = sizeof(int64_t);\n  const int64_t model_and_states_bytes = embedding_size * 3 * value_dtype_size;\n  const int64_t align_to_step_size_bytes =\n      (model_and_states_bytes + step_dtype_size - 1) / step_dtype_size * step_dtype_size;\n  const int64_t smart_decay_sparse_adam_line_size =\n      (align_to_step_size_bytes + step_dtype_size) / value_dtype_size;\n  CHECK_EQ_OR_RETURN(line_size, smart_decay_sparse_adam_line_size)\n      << \"when using SmartDecayAdam optimizer with embedding_size \" << embedding_size\n      << \", storage_dim should equals to \" << smart_decay_sparse_adam_line_size\n      << \", but got \"\n         \"storage_dim: \"\n      << line_size << \", please set storage_dim of store_options to \"\n      << smart_decay_sparse_adam_line_size;\n  const Shape& unique_embeddings_shape = ctx->InputShape(\"unique_embeddings\", 0);\n  ctx->SetOutputShape(\"updated_unique_embeddings\", 0, unique_embeddings_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> OneEmbeddingSmartDecaySparseAdamUpdateOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> OneEmbeddingSmartDecaySparseAdamUpdateOp::GetSbp(\n    user_op::SbpContext* ctx) {\n  JUST(GetEmbeddingUpdateSbp(ctx));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingSmartDecaySparseAdamUpdateOp::InferDataType(\n    user_op::InferContext* ctx) {\n  JUST(CheckDataType(ctx));\n  ctx->SetOutputDType(\"updated_unique_embeddings\", 0, ctx->InputDType(\"unique_embeddings\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingAdagradUpdateOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  JUST(CheckDataShape(ctx));\n  const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n  const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n  CHECK_NE_OR_RETURN(embedding_size, 0) << \"should set attr embedding_size\";\n  CHECK_NE_OR_RETURN(line_size, 0) << \"should set attr line_size\";\n  CHECK_EQ_OR_RETURN(line_size, embedding_size * 2)\n      << \"when using Adagrad optimizer, line_size should equals to embedding_size * 2, but get \"\n         \"line_size: \"\n      << line_size << \" embedding_size: \" << embedding_size\n      << \", please set size_factor of store_options to 2.\";\n  const Shape& unique_embeddings_shape = ctx->InputShape(\"unique_embeddings\", 0);\n  ctx->SetOutputShape(\"updated_unique_embeddings\", 0, unique_embeddings_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> OneEmbeddingAdagradUpdateOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> OneEmbeddingAdagradUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  JUST(GetEmbeddingUpdateSbp(ctx));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingAdagradUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  JUST(CheckDataType(ctx));\n  ctx->SetOutputDType(\"updated_unique_embeddings\", 0, ctx->InputDType(\"unique_embeddings\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingFtrlUpdateOp::InferLogicalTensorDesc(\n    
user_op::InferContext* ctx) {\n  JUST(CheckDataShape(ctx));\n  const int64_t embedding_size = ctx->Attr<int64_t>(\"embedding_size\");\n  const int64_t line_size = ctx->Attr<int64_t>(\"line_size\");\n  CHECK_NE_OR_RETURN(embedding_size, 0) << \"should set attr embedding_size\";\n  CHECK_NE_OR_RETURN(line_size, 0) << \"should set attr line_size\";\n  CHECK_EQ_OR_RETURN(line_size, embedding_size * 3)\n      << \"when using Ftrl optimizer, line_size should equals to embedding_size * 3, but get \"\n         \"line_size: \"\n      << line_size << \" embedding_size: \" << embedding_size\n      << \", please set size_factor of store_options to 3.\";\n  const Shape& unique_embeddings_shape = ctx->InputShape(\"unique_embeddings\", 0);\n  ctx->SetOutputShape(\"updated_unique_embeddings\", 0, unique_embeddings_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> OneEmbeddingFtrlUpdateOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> OneEmbeddingFtrlUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  JUST(GetEmbeddingUpdateSbp(ctx));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneEmbeddingFtrlUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  JUST(CheckDataType(ctx));\n  ctx->SetOutputDType(\"updated_unique_embeddings\", 0, ctx->InputDType(\"unique_embeddings\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> IdShuffleCopyOutOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(ctx->inputs(), 0)\n      .Split(ctx->outputs(), 0)\n      .Broadcast(user_op::OpArg(\"num_unique_matrix\", 0))\n      .Broadcast(user_op::OpArg(\"out_num_unique_matrix\", 0))\n      .Broadcast(user_op::OpArg(\"cur_rank_num_unique\", 0))\n      .Broadcast(user_op::OpArg(\"out_cur_rank_num_unique\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> IdShuffleCopyOutOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out_num_unique_matrix\", 0, ctx->InputShape(\"num_unique_matrix\", 0));\n  ctx->SetOutputShape(\"out_inverse_unique_partition_indices\", 0,\n                      ctx->InputShape(\"inverse_unique_partition_indices\", 0));\n  ctx->SetOutputShape(\"out_cur_rank_num_unique\", 0, ctx->InputShape(\"cur_rank_num_unique\", 0));\n  ctx->SetOutputShape(\"out_cur_rank_unique_ids\", 0, ctx->InputShape(\"cur_rank_unique_ids\", 0));\n  ctx->SetOutputShape(\"out_cur_rank_unique_table_ids\", 0,\n                      ctx->InputShape(\"cur_rank_unique_table_ids\", 0));\n  ctx->SetOutputShape(\"out_cur_rank_inverse_indices\", 0,\n                      ctx->InputShape(\"cur_rank_inverse_indices\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> IdShuffleCopyOutOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> IdShuffleCopyOutOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out_num_unique_matrix\", 0, ctx->InputDType(\"num_unique_matrix\", 0));\n  ctx->SetOutputDType(\"out_inverse_unique_partition_indices\", 0,\n                      ctx->InputDType(\"inverse_unique_partition_indices\", 0));\n  ctx->SetOutputDType(\"out_cur_rank_num_unique\", 0, ctx->InputDType(\"cur_rank_num_unique\", 0));\n  ctx->SetOutputDType(\"out_cur_rank_unique_ids\", 0, ctx->InputDType(\"cur_rank_unique_ids\", 0));\n  ctx->SetOutputDType(\"out_cur_rank_unique_table_ids\", 0,\n                      ctx->InputDType(\"cur_rank_unique_table_ids\", 0));\n 
 ctx->SetOutputDType(\"out_cur_rank_inverse_indices\", 0,\n                      ctx->InputDType(\"cur_rank_inverse_indices\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
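  {
    "path": "oneflow/user/ops/one_embedding_ops.sketch.cpp",
    "content": "// Hypothetical standalone sketch, NOT part of the OneFlow build: it mirrors\n// the line_size arithmetic checked in one_embedding_ops.cpp. SGD expects\n// line_size == embedding_size, Momentum/Adagrad expect 2x, Adam/Ftrl expect\n// 3x, and SmartDecaySparseAdam stores model/m/v plus one int64 step counter,\n// padded so the step slot stays 8-byte aligned.\n#include <cassert>\n#include <cstdint>\n\nint64_t SmartDecayAdamLineSize(int64_t embedding_size, int64_t value_dtype_size) {\n  const int64_t step_dtype_size = sizeof(int64_t);\n  const int64_t model_and_states_bytes = embedding_size * 3 * value_dtype_size;\n  // Round the model/m/v bytes up to a multiple of the step dtype size.\n  const int64_t aligned_bytes =\n      (model_and_states_bytes + step_dtype_size - 1) / step_dtype_size * step_dtype_size;\n  // Append one step counter, then convert bytes back to a count of values.\n  return (aligned_bytes + step_dtype_size) / value_dtype_size;\n}\n\nint main() {\n  // float values (4 bytes), embedding_size 8: 96 bytes of model/m/v are\n  // already aligned, plus 8 bytes of step -> 104 bytes -> 26 floats.\n  assert(SmartDecayAdamLineSize(8, 4) == 26);\n  // embedding_size 3: 36 bytes -> padded to 40, plus 8 -> 48 bytes -> 12 floats.\n  assert(SmartDecayAdamLineSize(3, 4) == 12);\n  return 0;\n}\n"
  },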
  {
    "path": "oneflow/user/ops/one_hot_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> OneHotOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const int64_t depth = ctx->Attr<int64_t>(\"depth\");\n  CHECK_GT_OR_RETURN(depth, 0);\n  const user_op::TensorDesc& indices_desc = ctx->InputTensorDesc(\"indices\", 0);\n  // For 0-dim Tensor\n  CHECK_GE_OR_RETURN(indices_desc.shape().NumAxes(), 0)\n      << \"indices dim must be great or equal than 0\";\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_is_dynamic(indices_desc.is_dynamic());\n  DimVector dim_vec = indices_desc.shape().dim_vec();\n  dim_vec.emplace_back(depth);\n  out_desc->set_shape(Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> OneHotOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> OneHotOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& indices_tensor =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"indices\", 0);\n  FOR_RANGE(int64_t, i, 0, indices_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"indices\", 0), i)\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneHotOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn,\n                                                  const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn(\"indices\", 0);\n  CHECK_OR_RETURN(indices_modifier != nullptr);\n  indices_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> OneHotOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& indices_desc = ctx->InputTensorDesc(\"indices\", 0);\n  CHECK_OR_RETURN(IsIndexDataType(indices_desc.data_type()));\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  DataType dtype = ctx->Attr<DataType>(\"dtype\");\n  out_desc->set_data_type(dtype);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
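  {
    "path": "oneflow/user/ops/one_hot_op.sketch.cpp",
    "content": "// Hypothetical standalone sketch, NOT part of the OneFlow build: it mirrors\n// the shape rule in one_hot_op.cpp, where the output shape is the indices\n// shape with a trailing `depth` axis appended. The 1/0 fill values below are\n// only for illustration.\n#include <cstdint>\n#include <iostream>\n#include <vector>\n\nint main() {\n  const int64_t depth = 4;\n  const std::vector<int64_t> indices = {1, 0, 3};  // shape (3,) -> output shape (3, 4)\n  std::vector<float> out(indices.size() * depth, 0.f);\n  for (size_t i = 0; i < indices.size(); ++i) { out[i * depth + indices[i]] = 1.f; }\n  for (size_t i = 0; i < indices.size(); ++i) {\n    for (int64_t j = 0; j < depth; ++j) { std::cout << out[i * depth + j] << \" \"; }\n    std::cout << \"\\n\";\n  }\n  return 0;\n}\n"
  },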
  {
    "path": "oneflow/user/ops/ones_like_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> OnesLikeOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& like_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"like\", 0);\n  FOR_RANGE(int64_t, i, 0, like_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"like\", 0), i)\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"like\", 0))\n      .Broadcast(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> OnesLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"like\", 0));\n  ctx->SetOutputStride(\"out\", 0, ctx->InputStride(\"like\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> OnesLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return OnesLikeOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> OnesLikeOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"like\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> OnesLikeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  const NdSbp& in_sbp = ctx->NdSbpHint4InputArgNameAndIndex(\"like\", 0);\n  NdSbp* like_distribution = ctx->NdSbp4ArgNameAndIndex(\"like\", 0);\n  NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  *like_distribution = in_sbp;\n  *out_distribution = in_sbp;\n  for (auto& sbp : *out_distribution->mutable_sbp_parallel()) {\n    if (sbp.has_partial_sum_parallel()) {\n      sbp.Clear();\n      *sbp.mutable_broadcast_parallel() = BroadcastParallel();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/p2p_comm_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/comm_net_device_infer_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> SendOp::GetSbp(user_op::SbpContext* ctx) { UNIMPLEMENTED_THEN_RETURN(); }\n/*static*/ Maybe<void> SendOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  // Do nothing.\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SendOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return SendOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SendOp::InferDataType(user_op::InferContext* ctx) {\n  // Do nothing.\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<Symbol<Stream>> SendOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn(ctx);\n}\n\nnamespace {\n\nMaybe<Symbol<Device>> GetRecvOutputDeivce(user_op::DeviceAndStreamInferContext* ctx) {\n  const std::string& device_type = ctx->Attr<std::string>(\"device_type\");\n  const int device_id = ctx->Attr<int64_t>(\"device_id\");\n  return Device::New(device_type, device_id);\n}\n\n}  // namespace\n\n/*static*/ Maybe<void> RecvOp::GetSbp(user_op::SbpContext* ctx) { UNIMPLEMENTED_THEN_RETURN(); }\n/*static*/ Maybe<void> RecvOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->Attr<Shape>(\"shape\"));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> RecvOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return RecvOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> RecvOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->Attr<DataType>(\"dtype\"));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<Symbol<Stream>> RecvOp::InferDeviceAndStream(\n    user_op::DeviceAndStreamInferContext* ctx) {\n  return DeviceAndStreamInferFn<&GetRecvOutputDeivce>(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/pack_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> PackOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> PackOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  const int32_t pack_num = ctx->Attr<int32_t>(\"pack_num\");\n  CHECK_GT_OR_RETURN(pack_num, 0);\n  Shape out_shape = in_desc.shape();\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_is_dynamic(in_desc.is_dynamic());\n  if (out_shape.NumAxes() > 0) {\n    out_shape.Set(0, out_shape.At(0) * pack_num);\n    out_desc->set_shape(out_shape);\n  } else {\n    // NOTE(chengcheng): for Scalar input pack\n    CHECK_EQ_OR_RETURN(out_shape.elem_cnt(), 1);\n    out_desc->set_shape(Shape({pack_num}));\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> PackOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return PackOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> PackOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> PackOp::InferOutputBlobTimeShape(\n    user_op::InferOutputBlobTimeShapeFnContext* ctx) {\n  const int32_t pack_num = ctx->user_op_conf().attr<int32_t>(\"pack_num\");\n  DimVector time_shape_dim_vec = ctx->TimeShape4InputArgNameAndIndex(\"in\", 0).dim_vec();\n  CHECK_OR_RETURN(!time_shape_dim_vec.empty());\n  CHECK_EQ_OR_RETURN(time_shape_dim_vec.back(), pack_num);\n  time_shape_dim_vec.pop_back();\n  if (time_shape_dim_vec.empty()) { time_shape_dim_vec.emplace_back(1); }\n  *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
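  {
    "path": "oneflow/user/ops/pack_op.sketch.cpp",
    "content": "// Hypothetical standalone sketch, NOT part of the OneFlow build: it mirrors\n// the shape arithmetic in pack_op.cpp. Packing multiplies dim 0 by pack_num,\n// a scalar input becomes a 1-D tensor of length pack_num, and the output time\n// shape drops a trailing pack_num factor from the input time shape.\n#include <cassert>\n#include <cstdint>\n#include <vector>\n\nstd::vector<int64_t> PackShape(std::vector<int64_t> in, int64_t pack_num) {\n  if (in.empty()) { return {pack_num}; }  // scalar input pack\n  in[0] *= pack_num;\n  return in;\n}\n\nint main() {\n  assert((PackShape({5, 3}, 4) == std::vector<int64_t>{20, 3}));\n  assert((PackShape({}, 4) == std::vector<int64_t>{4}));\n  return 0;\n}\n"
  },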
  {
    "path": "oneflow/user/ops/pad_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> PadOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  const auto& padding_before = ctx->Attr<std::vector<int64_t>>(\"padding_before\");\n  const auto& padding_after = ctx->Attr<std::vector<int64_t>>(\"padding_after\");\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    if (padding_before[i] == 0 && padding_after[i] == 0) {\n      ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Split(user_op::OpArg(\"y\", 0), i).Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> PadOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const auto& padding_before = ctx->Attr<std::vector<int64_t>>(\"padding_before\");\n  const auto& padding_after = ctx->Attr<std::vector<int64_t>>(\"padding_after\");\n  CHECK_EQ_OR_RETURN(padding_before.size(), x_shape.NumAxes());\n  CHECK_EQ_OR_RETURN(padding_after.size(), x_shape.NumAxes());\n  DimVector y_dim_vec(x_shape.NumAxes());\n  FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) {\n    y_dim_vec[i] = x_shape.At(i) + padding_before[i] + padding_after[i];\n  }\n  ctx->SetOutputShape(\"y\", 0, Shape(y_dim_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> PadOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return PadOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> PadOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
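  {
    "path": "oneflow/user/ops/pad_op.sketch.cpp",
    "content": "// Hypothetical standalone sketch, NOT part of the OneFlow build: it mirrors\n// the shape rule in pad_op.cpp, y[i] = x[i] + padding_before[i] +\n// padding_after[i]; note that GetSbp only allows splitting an axis whose\n// padding is zero on both sides.\n#include <cassert>\n#include <cstdint>\n#include <vector>\n\nint main() {\n  const std::vector<int64_t> x = {2, 3};\n  const std::vector<int64_t> before = {0, 1};\n  const std::vector<int64_t> after = {0, 2};\n  std::vector<int64_t> y(x.size());\n  for (size_t i = 0; i < x.size(); ++i) { y[i] = x[i] + before[i] + after[i]; }\n  assert((y == std::vector<int64_t>{2, 6}));\n  // Axis 0 is un-padded, so it stays splittable; axis 1 is padded, so it is not.\n  return 0;\n}\n"
  },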
  {
    "path": "oneflow/user/ops/parallel_cast_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> ParallelCastOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n/*static*/ Maybe<void> ParallelCastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ParallelCastOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return ParallelCastOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ParallelCastOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ParallelCastOp::InferSbpSignature(user_op::InferSbpSignatureFnContext* ctx) {\n  auto* bn2sbp = ctx->mutable_sbp_signature()->mutable_bn_in_op2sbp_parallel();\n  const std::string& ibn = GenRepeatedBn(\"in\", 0);\n  const std::string& obn = GenRepeatedBn(\"out\", 0);\n  const auto& sbp_parallel_str = ctx->Attr<std::string>(\"sbp_parallel\");\n  if (sbp_parallel_str.empty()) {\n    const auto& sbp_parallel = ctx->SbpParallelHint4InputArgNameAndIndex(\"in\", 0);\n    (*bn2sbp)[ibn] = sbp_parallel;\n    (*bn2sbp)[obn] = sbp_parallel;\n  } else {\n    SbpParallel sbp_parallel;\n    CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_parallel_str, &sbp_parallel))\n        << \"invalid sbp_parallel: \" << sbp_parallel_str;\n    if (sbp_parallel.has_split_parallel()) {\n      int64_t split_axis = sbp_parallel.split_parallel().axis();\n      const auto& in_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n      int64_t num_axes = in_desc.shape().NumAxes();\n      CHECK_GE_OR_RETURN(split_axis, 0);\n      CHECK_LT_OR_RETURN(split_axis, num_axes);\n    }\n    (*bn2sbp)[ibn] = sbp_parallel;\n    (*bn2sbp)[obn] = sbp_parallel;\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/partial_fc_sample_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> DistributedPartialFcSampleOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"weight\", 0), 0)\n      .Broadcast(user_op::OpArg(\"label\", 0))\n      .Broadcast(user_op::OpArg(\"mapped_label\", 0))\n      .Split(user_op::OpArg(\"sampled_label\", 0), 0)\n      .Split(user_op::OpArg(\"sampled_weight\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> DistributedPartialFcSampleOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const int64_t num_sample = ctx->Attr<int64_t>(\"num_sample\");\n  const user_op::TensorDesc& weight = ctx->InputTensorDesc(\"weight\", 0);\n  const user_op::TensorDesc& label = ctx->InputTensorDesc(\"label\", 0);\n  user_op::TensorDesc* mapped_label = ctx->MutOutputTensorDesc(\"mapped_label\", 0);\n  user_op::TensorDesc* sampled_weight = ctx->MutOutputTensorDesc(\"sampled_weight\", 0);\n  user_op::TensorDesc* sampled_label = ctx->MutOutputTensorDesc(\"sampled_label\", 0);\n  mapped_label->set_shape(label.shape());\n  mapped_label->set_is_dynamic(label.is_dynamic());\n  Shape sampled_weight_shape = weight.shape();\n  sampled_weight_shape.Set(0, num_sample);\n  sampled_weight->set_shape(sampled_weight_shape);\n  sampled_weight->set_is_dynamic(weight.is_dynamic());\n  Shape sampled_label_shape = label.shape();\n  sampled_label_shape.Set(0, num_sample);\n  sampled_label->set_shape(sampled_label_shape);\n  sampled_label->set_is_dynamic(label.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> DistributedPartialFcSampleOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const int64_t num_sample = ctx->Attr<int64_t>(\"num_sample\");\n  const int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n  CHECK_EQ_OR_RETURN(num_sample % parallel_num, 0);\n  const int64_t num_sample_per_rank = num_sample / parallel_num;\n  const user_op::TensorDesc& weight = ctx->InputTensorDesc(\"weight\", 0);\n  const user_op::TensorDesc& label = ctx->InputTensorDesc(\"label\", 0);\n  user_op::TensorDesc* mapped_label = ctx->MutOutputTensorDesc(\"mapped_label\", 0);\n  user_op::TensorDesc* sampled_weight = ctx->MutOutputTensorDesc(\"sampled_weight\", 0);\n  user_op::TensorDesc* sampled_label = ctx->MutOutputTensorDesc(\"sampled_label\", 0);\n  mapped_label->set_shape(label.shape());\n  mapped_label->set_is_dynamic(label.is_dynamic());\n  Shape sampled_weight_shape = weight.shape();\n  sampled_weight_shape.Set(0, num_sample_per_rank);\n  sampled_weight->set_shape(sampled_weight_shape);\n  sampled_weight->set_is_dynamic(weight.is_dynamic());\n  Shape sampled_label_shape = label.shape();\n  sampled_label_shape.Set(0, num_sample_per_rank);\n  sampled_label->set_shape(sampled_label_shape);\n  
sampled_label->set_is_dynamic(label.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> DistributedPartialFcSampleOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"mapped_label\", 0, ctx->InputDType(\"label\", 0));\n  ctx->SetOutputDType(\"sampled_weight\", 0, ctx->InputDType(\"weight\", 0));\n  ctx->SetOutputDType(\"sampled_label\", 0, ctx->InputDType(\"label\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> DistributedPartialFcSampleOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* label_modifier = GetInputArgModifierFn(\"label\", 0);\n  CHECK_NOTNULL_OR_RETURN(label_modifier);\n  label_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> DistributedPartialFcSampleDisableBoxingOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"sampled_weight_diff\", 0), 0)\n      .Split(user_op::OpArg(\"sampled_label\", 0), 0)\n      .Broadcast(user_op::OpArg(\"boxing_disabled_sampled_weight_diff\", 0))\n      .Broadcast(user_op::OpArg(\"boxing_disabled_sampled_label\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> DistributedPartialFcSampleDisableBoxingOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  user_op::TensorDesc* boxing_disabled_sampled_weight_diff =\n      ctx->MutOutputTensorDesc(\"boxing_disabled_sampled_weight_diff\", 0);\n  Shape boxing_disabled_sampled_weight_diff_shape = ctx->InputShape(\"sampled_weight_diff\", 0);\n  CHECK_EQ_OR_RETURN(boxing_disabled_sampled_weight_diff_shape.At(0) % ctx->parallel_num(), 0);\n  boxing_disabled_sampled_weight_diff_shape.Set(\n      0, boxing_disabled_sampled_weight_diff_shape.At(0) / ctx->parallel_num());\n  boxing_disabled_sampled_weight_diff->set_shape(boxing_disabled_sampled_weight_diff_shape);\n  boxing_disabled_sampled_weight_diff->set_is_dynamic(\n      ctx->InputIsDynamic(\"sampled_weight_diff\", 0));\n  user_op::TensorDesc* boxing_disabled_sampled_label =\n      ctx->MutOutputTensorDesc(\"boxing_disabled_sampled_label\", 0);\n  Shape boxing_disabled_sampled_label_shape = ctx->InputShape(\"sampled_label\", 0);\n  CHECK_EQ_OR_RETURN(boxing_disabled_sampled_label_shape.At(0) % ctx->parallel_num(), 0);\n  boxing_disabled_sampled_label_shape.Set(\n      0, boxing_disabled_sampled_label_shape.At(0) / ctx->parallel_num());\n  boxing_disabled_sampled_label->set_shape(boxing_disabled_sampled_label_shape);\n  boxing_disabled_sampled_label->set_is_dynamic(ctx->InputIsDynamic(\"sampled_label\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> DistributedPartialFcSampleDisableBoxingOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"boxing_disabled_sampled_weight_diff\", 0,\n                      ctx->InputShape(\"sampled_weight_diff\", 0));\n  ctx->SetOutputIsDynamic(\"boxing_disabled_sampled_weight_diff\", 0,\n                          ctx->InputIsDynamic(\"sampled_weight_diff\", 0));\n  ctx->SetOutputShape(\"boxing_disabled_sampled_label\", 0, ctx->InputShape(\"sampled_label\", 0));\n  ctx->SetOutputIsDynamic(\"boxing_disabled_sampled_label\", 0,\n                          ctx->InputIsDynamic(\"sampled_label\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> DistributedPartialFcSampleDisableBoxingOp::InferDataType(\n    user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"boxing_disabled_sampled_weight_diff\", 0,\n   
                   ctx->InputDType(\"sampled_weight_diff\", 0));\n  ctx->SetOutputDType(\"boxing_disabled_sampled_label\", 0, ctx->InputDType(\"sampled_label\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
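  {
    "path": "oneflow/user/ops/partial_fc_sample_op.sketch.cpp",
    "content": "// Hypothetical standalone sketch, NOT part of the OneFlow build: it mirrors\n// the logical-vs-physical split in partial_fc_sample_op.cpp, where the\n// logical sampled dim is num_sample and each of parallel_num ranks holds\n// num_sample / parallel_num rows (num_sample must divide evenly).\n#include <cassert>\n#include <cstdint>\n\nint64_t NumSamplePerRank(int64_t num_sample, int64_t parallel_num) {\n  assert(num_sample % parallel_num == 0);\n  return num_sample / parallel_num;\n}\n\nint main() {\n  // e.g. 4 ranks share 8192 sampled rows of the weight matrix.\n  assert(NumSamplePerRank(8192, 4) == 2048);\n  return 0;\n}\n"
  },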
  {
    "path": "oneflow/user/ops/pinned_identity_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> PinnedIdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> PinnedIdentityOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> PinnedIdentityOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> PinnedIdentityOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/prelu_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> PreluOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  const user_op::TensorDesc& alpha_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"alpha\", 0);\n  if (alpha_tensor.shape().At(0) != 1) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), 1)\n        .Split(user_op::OpArg(\"alpha\", 0), 0)\n        .Split(user_op::OpArg(\"y\", 0), 1)\n        .Build();\n  }\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    if (i == 1) continue;\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Broadcast(user_op::OpArg(\"alpha\", 0))\n        .Split(user_op::OpArg(\"y\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> PreluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& alpha_shape = ctx->InputShape(\"alpha\", 0);\n  CHECK_EQ_OR_RETURN(alpha_shape.NumAxes(), 1);\n  ctx->SetOutputShape(\"y\", 0, x_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> PreluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> PreluOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> PreluGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Broadcast(user_op::OpArg(\"alpha\", 0))\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .PartialSum(user_op::OpArg(\"alpha_diff\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"dy\", 0))\n      .Broadcast(user_op::OpArg(\"x\", 0))\n      .Broadcast(user_op::OpArg(\"alpha\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .PartialSum(user_op::OpArg(\"alpha_diff\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 1)\n      .Split(user_op::OpArg(\"x\", 0), 1)\n      .Split(user_op::OpArg(\"alpha\", 0), 0)\n      .Split(user_op::OpArg(\"dx\", 0), 1)\n      .Split(user_op::OpArg(\"alpha_diff\", 0), 0)\n      .Build();\n  FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"alpha\", 0), 0)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Split(user_op::OpArg(\"alpha_diff\", 0), 0)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ 
Maybe<void> PreluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  const Shape& alpha_shape = ctx->InputShape(\"alpha\", 0);\n  CHECK_EQ_OR_RETURN(alpha_shape.NumAxes(), 1);\n  CHECK_OR_RETURN((alpha_shape.At(0) == x_shape.At(1)) || (alpha_shape.At(0) == 1));\n  CHECK_EQ_OR_RETURN(dy_shape, x_shape);\n  ctx->SetOutputShape(\"dx\", 0, x_shape);\n  ctx->SetOutputShape(\"alpha_diff\", 0, alpha_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> PreluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> PreluGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"alpha_diff\", 0, ctx->InputDType(\"alpha\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
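  {
    "path": "oneflow/user/ops/prelu_op.sketch.cpp",
    "content": "// Hypothetical standalone sketch, NOT part of the OneFlow build: it mirrors\n// the alpha constraint in prelu_op.cpp, where alpha is 1-D with length 1 or\n// x.shape(1), i.e. one slope shared by every channel or one slope per channel.\n#include <cassert>\n#include <cstddef>\n#include <vector>\n\n// x laid out as (N, C) in row-major order; alpha has size 1 or C.\nstd::vector<float> PRelu(const std::vector<float>& x, size_t C, const std::vector<float>& alpha) {\n  assert(alpha.size() == 1 || alpha.size() == C);\n  std::vector<float> y(x.size());\n  for (size_t i = 0; i < x.size(); ++i) {\n    const float a = alpha.size() == 1 ? alpha[0] : alpha[i % C];\n    y[i] = x[i] >= 0.f ? x[i] : a * x[i];\n  }\n  return y;\n}\n\nint main() {\n  const std::vector<float> y = PRelu({-1.f, 2.f, -3.f, 4.f}, 2, {0.5f, 0.25f});\n  assert(y[0] == -0.5f && y[1] == 2.f && y[2] == -0.75f);\n  return 0;\n}\n"
  },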
  {
    "path": "oneflow/user/ops/quantization_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> QuantizationOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  const Shape& logical_scale_shape =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"scale\", 0).shape();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"in\", 0))\n      .Broadcast(user_op::OpArg(\"scale\", 0))\n      .Broadcast(user_op::OpArg(\"zero_point\", 0))\n      .Broadcast(user_op::OpArg(\"out\", 0))\n      .Build();\n  if (logical_scale_shape.elem_cnt() > 1) {\n    // NOTE(Liang Depeng): only consider convolution weight per-channel quantization\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), 0)\n        .Split(user_op::OpArg(\"scale\", 0), 0)\n        .Split(user_op::OpArg(\"zero_point\", 0), 0)\n        .Split(user_op::OpArg(\"out\", 0), 0)\n        .Build();\n  } else {\n    // NOTE(Liang Depeng): the sbp signature of per-layer quantization is the same as eltwise\n    // ops\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), 0)\n        .Broadcast(user_op::OpArg(\"scale\", 0))\n        .Broadcast(user_op::OpArg(\"zero_point\", 0))\n        .Split(user_op::OpArg(\"out\", 0), 0)\n        .Build();\n  }\n  FOR_RANGE(int64_t, i, 1, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), i)\n        .Broadcast(user_op::OpArg(\"scale\", 0))\n        .Broadcast(user_op::OpArg(\"zero_point\", 0))\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> QuantizationOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"in\", 0);\n  const Shape& scale_shape = ctx->InputShape(\"scale\", 0);\n  const Shape& zero_point_shape = ctx->InputShape(\"zero_point\", 0);\n\n  // NOTE(Liang Depeng): scale_shape->elem_cnt() > 1 means per-channel quantization for\n  // convolution weights.\n  if (scale_shape.elem_cnt() > 1) {\n    CHECK_EQ_OR_RETURN(scale_shape.elem_cnt(), in_shape.At(0));\n    CHECK_EQ_OR_RETURN(zero_point_shape.elem_cnt(), in_shape.At(0));\n  }\n\n  ctx->SetOutputShape(\"out\", 0, in_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> QuantizationOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> QuantizationOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> QuantizationOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* scale = GetInputArgModifierFn(\"scale\", 0);\n  CHECK_OR_RETURN(scale != 
nullptr);\n  scale->set_requires_grad(false);\n\n  user_op::InputArgModifier* zero_point = GetInputArgModifierFn(\"zero_point\", 0);\n  CHECK_OR_RETURN(zero_point != nullptr);\n  zero_point->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> QuantizationOp::CheckAttr(const user_op::UserOpDefWrapper&,\n                                                 const user_op::UserOpConfWrapper& op_conf) {\n  const int32_t quantization_bit = op_conf.attr<int32_t>(\"quantization_bit\");\n  CHECK_GT_OR_RETURN(quantization_bit, 1);\n  CHECK_LE_OR_RETURN(quantization_bit, 8);\n\n  std::string quantization_scheme = op_conf.attr<std::string>(\"quantization_scheme\");\n  CHECK_OR_RETURN(quantization_scheme == \"symmetric\" || quantization_scheme == \"affine\");\n\n  std::string quantization_formula = op_conf.attr<std::string>(\"quantization_formula\");\n  CHECK_OR_RETURN(quantization_formula == \"google\" || quantization_formula == \"cambricon\");\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
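  {
    "path": "oneflow/user/ops/quantization_op.sketch.cpp",
    "content": "// Hypothetical standalone sketch, NOT part of the OneFlow build: it mirrors\n// the attr validation in quantization_op.cpp -- quantization_bit must lie in\n// (1, 8], quantization_scheme is \"symmetric\" or \"affine\", and\n// quantization_formula is \"google\" or \"cambricon\".\n#include <cassert>\n#include <string>\n\nbool QuantizationAttrsValid(int bit, const std::string& scheme, const std::string& formula) {\n  if (bit <= 1 || bit > 8) { return false; }\n  if (scheme != \"symmetric\" && scheme != \"affine\") { return false; }\n  return formula == \"google\" || formula == \"cambricon\";\n}\n\nint main() {\n  assert(QuantizationAttrsValid(8, \"symmetric\", \"google\"));\n  assert(!QuantizationAttrsValid(1, \"symmetric\", \"google\"));   // too few bits\n  assert(!QuantizationAttrsValid(8, \"per_tensor\", \"google\"));  // unknown scheme\n  return 0;\n}\n"
  },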
  {
    "path": "oneflow/user/ops/quick_gelu_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> QuickGeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"y\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> QuickGeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> QuickGeluOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> QuickGeluOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Split(user_op::OpArg(\"y\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> QuickGeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == x_shape)\n      << \"InferTensorDesc failed (\" << ctx->op_name() << \"). Expected x shape \"\n      << x_shape.ToString() << \" to be equal to dy shape \" << dy_shape.ToString();\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> QuickGeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> QuickGeluGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"x\", 0), ctx->InputDType(\"dy\", 0))\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"dy\", 0))\n      << \", but got \" << DataType_Name(ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> QuickGeluGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"x\", 0))\n      .PartialSum(user_op::OpArg(\"dy\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
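  {
    "path": "sketches/quick_gelu_math_sketch.cpp",
    "content": "// Editor's sketch (hypothetical file, not part of OneFlow): QuickGeluOp above\n// only declares shapes, dtypes and SBP signatures; the math lives in kernels\n// that are not in this file. Quick-GELU is commonly defined as\n// y = x * sigmoid(1.702 * x); assuming that definition, the gradient below\n// follows from the product and chain rules, and is checked numerically.\n#include <cmath>\n#include <cstdio>\n\ndouble Sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }\n\ndouble QuickGelu(double x) { return x * Sigmoid(1.702 * x); }\n\n// d/dx [x * s(a*x)] = s(a*x) + a * x * s(a*x) * (1 - s(a*x)), with a = 1.702\ndouble QuickGeluGrad(double x, double dy) {\n  const double a = 1.702;\n  const double s = Sigmoid(a * x);\n  return dy * (s + a * x * s * (1.0 - s));\n}\n\nint main() {\n  // finite-difference check of the analytic gradient at x = 0.7\n  const double x = 0.7, eps = 1e-6;\n  const double fd = (QuickGelu(x + eps) - QuickGelu(x - eps)) / (2 * eps);\n  std::printf(\"analytic=%.8f fd=%.8f\\n\", QuickGeluGrad(x, 1.0), fd);\n  return 0;\n}\n"
  },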
  {
    "path": "oneflow/user/ops/randperm_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> RandpermOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  SbpParallel default_sbp;\n  default_sbp.mutable_broadcast_parallel();\n  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);\n}\n/*static*/ Maybe<void> RandpermOp::GetSbp(user_op::SbpContext* ctx) { return Maybe<void>::Ok(); }\n/*static*/ Maybe<void> RandpermOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  int32_t n = ctx->Attr<int32_t>(\"n\");\n  CHECK_GE_OR_RETURN(n, 0) << Error::RuntimeError()\n                           << \"Trying to create tensor with negative dimension \" << n << \":\"\n                           << \" [\" << n << \"]\";\n  ctx->SetOutputShape(\"out\", 0, Shape({n}));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> RandpermOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();\n  const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  int32_t n = ctx->Attr<int32_t>(\"n\");\n  const Shape& logical_shape = Shape({n});\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n  const auto tensor_slice_view =\n      GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id);\n  const Shape& physical_shape = tensor_slice_view.shape();\n\n  ctx->SetOutputShape(\"out\", 0, physical_shape);\n\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> RandpermOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, DataType::kInt32);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
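  {
    "path": "sketches/randperm_split_shape_sketch.cpp",
    "content": "// Editor's sketch (hypothetical file, not part of OneFlow; the helper is\n// invented for illustration): shows the kind of per-rank shape that\n// RandpermOp::InferPhysicalTensorDesc derives when \"out\" is split on axis 0.\n// Assumes a balanced split, where the first (n % parallel_num) ranks get one\n// extra element, as in oneflow/core/common/balanced_splitter.h.\n#include <cstdint>\n#include <cstdio>\n\nint64_t SplitDimSize(int64_t n, int64_t parallel_num, int64_t parallel_id) {\n  const int64_t base = n / parallel_num;\n  const int64_t remainder = n % parallel_num;\n  return base + (parallel_id < remainder ? 1 : 0);\n}\n\nint main() {\n  // n = 10 over 4 ranks -> per-rank sizes 3, 3, 2, 2 (they sum to 10)\n  for (int64_t rank = 0; rank < 4; ++rank) {\n    std::printf(\"rank %lld: out shape = (%lld,)\\n\", (long long)rank,\n                (long long)SplitDimSize(10, 4, rank));\n  }\n  return 0;\n}\n"
  },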
  {
    "path": "oneflow/user/ops/raw_reader_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> RawReaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& instance_shape = ctx->Attr<Shape>(\"shape\");\n  const int32_t batch_size = ctx->Attr<int64_t>(\"batch_size\");\n  DimVector dim_vec;\n  dim_vec.push_back(batch_size);\n  for (int64_t i = 0; i < instance_shape.NumAxes(); ++i) {\n    dim_vec.push_back(instance_shape.At(i));\n  }\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_tensor->set_shape(Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> RawReaderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  user_op::TensorDesc* out_tensor = ctx->MutOutputTensorDesc(\"out\", 0);\n  int32_t batch_size = ctx->Attr<int64_t>(\"batch_size\");\n  int64_t parallel_num = ctx->parallel_ctx().parallel_num();\n  if (parallel_num > 1) {\n    int64_t split_num = 1;\n    const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n    const Shape& hierarchy = *ctx->parallel_desc().hierarchy();\n    for (int32_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) {\n      if (nd_sbp.sbp_parallel(i).has_split_parallel()) { split_num *= hierarchy.At(i); }\n    }\n    CHECK_EQ_OR_RETURN(batch_size % split_num, 0) << \"batch_size must be a multiple of shard num\";\n    batch_size /= split_num;\n  }\n  const Shape& instance_shape = ctx->Attr<Shape>(\"shape\");\n  DimVector dim_vec;\n  dim_vec.push_back(batch_size);\n  for (int64_t i = 0; i < instance_shape.NumAxes(); ++i) {\n    dim_vec.push_back(instance_shape.At(i));\n  }\n  out_tensor->set_shape(Shape({dim_vec}));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> RawReaderOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Split(ctx->outputs(), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> RawReaderOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  SbpParallel default_sbp;\n  default_sbp.mutable_split_parallel()->set_axis(0);\n  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);\n}\n\n/* static */ Maybe<void> RawReaderOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->Attr<DataType>(\"data_type\"));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
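  {
    "path": "sketches/raw_reader_shard_batch_sketch.cpp",
    "content": "// Editor's sketch (hypothetical file, not part of OneFlow; helper names are\n// invented): restates the per-rank batch computation from\n// RawReaderOp::InferPhysicalTensorDesc with the NdSbp replaced by a plain\n// bool mask. split_num is the product of hierarchy dims whose SBP is split,\n// and the global batch must divide evenly by it.\n#include <cstdint>\n#include <stdexcept>\n#include <vector>\n\nint64_t PerRankBatchSize(int64_t batch_size, const std::vector<int64_t>& hierarchy,\n                         const std::vector<bool>& is_split_dim) {\n  int64_t split_num = 1;\n  for (size_t i = 0; i < hierarchy.size(); ++i) {\n    if (is_split_dim[i]) { split_num *= hierarchy[i]; }\n  }\n  if (batch_size % split_num != 0) {\n    throw std::runtime_error(\"batch_size must be a multiple of shard num\");\n  }\n  return batch_size / split_num;\n}\n\nint main() {\n  // 2x4 hierarchy with SBP (S(0), B): only the first dim shards the batch,\n  // so a global batch of 32 becomes 16 on each rank.\n  return PerRankBatchSize(32, {2, 4}, {true, false}) == 16 ? 0 : 1;\n}\n"
  },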
  {
    "path": "oneflow/user/ops/reduce_like_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/operator/reduce_sbp_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> ReduceSumLikeOp::GetSbp(user_op::SbpContext* ctx) {\n  int32_t num_axes = 0;\n  HashSet<int32_t> conf_axes;\n\n  const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  num_axes = in_tensor.shape().NumAxes();\n  const auto& reduced_axes = ctx->Attr<std::vector<int32_t>>(\"axis\");\n  ReduceSbpUtil::GetRegularAxes(num_axes, reduced_axes, &conf_axes);\n\n  const auto& like_num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"like\", 0).shape().NumAxes();\n  const bool keep_dims = (num_axes == like_num_axes);\n  auto IsReducedAxis = ReduceSbpUtil::MakePredicatorIsReducedAxis(conf_axes, num_axes);\n  int64_t num_reduced_axes = 0;\n  FOR_RANGE(int64_t, i, 0, num_axes) {\n    if (in_tensor.shape().at(i) == 1) {\n      num_reduced_axes += 1;\n    } else if (IsReducedAxis(i)) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), i)\n          .Broadcast(user_op::OpArg(\"like\", 0))\n          .PartialSum(user_op::OpArg(\"y\", 0))\n          .Build();\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), i)\n          .PartialSum(user_op::OpArg(\"like\", 0))\n          .PartialSum(user_op::OpArg(\"y\", 0))\n          .Build();\n      num_reduced_axes += 1;\n    } else {\n      const int64_t out_split_axis = keep_dims ? 
i : i - num_reduced_axes;\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), i)\n          .Split(user_op::OpArg(\"like\", 0), out_split_axis)\n          .Split(user_op::OpArg(\"y\", 0), out_split_axis)\n          .Build();\n    }\n  }\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"x\", 0))\n      .PartialSum(user_op::OpArg(\"like\", 0))\n      .Broadcast(user_op::OpArg(\"y\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReduceSumLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& like_tensor = ctx->InputTensorDesc(\"like\", 0);\n  const auto& axis = ctx->Attr<std::vector<int32_t>>(\"axis\");\n  if (axis.empty()) {\n    CHECK_EQ_OR_RETURN(x_tensor.shape(), like_tensor.shape())\n        << Error::RuntimeError()\n        << \"The shape of the x tensor must be consistent with the shape of the like tensor\"\n        << \" when the input axis list is empty\";\n  }\n\n  user_op::TensorDesc* y_tensor = ctx->MutOutputTensorDesc(\"y\", 0);\n  y_tensor->set_shape(like_tensor.shape());\n  y_tensor->set_is_dynamic(like_tensor.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReduceSumLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ReduceSumLikeOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReduceSumLikeOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* like_arg_modifier = GetInputArgModifierFn(\"like\", 0);\n  CHECK_OR_RETURN(like_arg_modifier != nullptr);  // NOLINT(maybe-need-error-msg)\n  like_arg_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
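  {
    "path": "sketches/reduce_sum_like_axis_map_sketch.cpp",
    "content": "// Editor's sketch (hypothetical file, not part of OneFlow): the axis\n// bookkeeping behind out_split_axis in ReduceSumLikeOp::GetSbp above. When\n// \"like\" has already dropped the reduced axes (keep_dims == false), input\n// axis i maps to output axis i minus the number of reduced axes seen so far.\n#include <cstdio>\n#include <set>\n#include <vector>\n\nint main() {\n  // x shape (4, 3, 5), reduce axis {1}, like shape (4, 5) -> keep_dims = false\n  const std::vector<int> x_shape = {4, 3, 5};\n  const std::set<int> reduced = {1};\n  int num_reduced_axes = 0;\n  for (int i = 0; i < (int)x_shape.size(); ++i) {\n    if (reduced.count(i)) {\n      ++num_reduced_axes;  // S(x, i) pairs with PartialSum on y, no y axis\n    } else {\n      const int out_split_axis = i - num_reduced_axes;  // keep_dims would keep i\n      std::printf(\"S(x, %d) <-> S(y, %d)\\n\", i, out_split_axis);\n    }\n  }\n  return 0;\n}\n"
  },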
  {
    "path": "oneflow/user/ops/reduce_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/operator/reduce_sbp_util.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nMaybe<void> InferTensorDescFn(user_op::InferContext* ctx) {\n  const Shape& input_shape = ctx->InputShape(\"input_tensor\", 0);\n  const auto& reduce_axes = ctx->Attr<std::vector<int32_t>>(\"axis\");\n  Shape output_shape;\n  // For 0-dim Tensor\n  if (reduce_axes.empty()) {\n    output_shape = input_shape;\n  } else {\n    const AxisVector reduce_axes_vec = {reduce_axes.begin(), reduce_axes.end()};\n    const Shape& reduce_shape = CreateReducedShape(input_shape, reduce_axes_vec);\n    const bool keepdims = ctx->Attr<bool>(\"keepdims\");\n    if (keepdims) {\n      output_shape = reduce_shape;\n    } else {\n      output_shape = reduce_shape.RemoveOnes(reduce_axes_vec);\n    }\n  }\n  ctx->SetOutputShape(\"output_tensor\", 0, output_shape);\n  ctx->SetOutputStride(\"output_tensor\", 0, Stride(output_shape));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"output_tensor\", 0, ctx->InputDType(\"input_tensor\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferLogicalDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"output_tensor\", 0, DataType::kBool);\n  return Maybe<void>::Ok();\n}\n\ntemplate<template<typename> class binary_func>\nvoid GeneratePartialSbp(user_op::SbpContext* ctx, int64_t axis) {\n  // TODO(lixinqi)\n}\n\ntemplate<>\nvoid GeneratePartialSbp<BinaryFuncSum>(user_op::SbpContext* ctx, int64_t axis) {\n  ctx->NewBuilder().Split(ctx->inputs(), axis).PartialSum(ctx->outputs()).Build();\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n}\n\ntemplate<template<typename> class binary_func>\nMaybe<void> GetSbpFn(user_op::SbpContext* ctx) {\n  const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"input_tensor\", 0);\n  int64_t num_axes = in_tensor.shape().NumAxes();\n  bool keep_dims = ctx->Attr<bool>(\"keepdims\");\n  const auto& reduce_axes = ctx->Attr<std::vector<int32_t>>(\"axis\");\n  HashSet<int32_t> conf_axes;\n  ReduceSbpUtil::GetRegularAxes(num_axes, reduce_axes, &conf_axes);\n  auto IsReducedAxis = ReduceSbpUtil::MakePredicatorIsReducedAxis(conf_axes, num_axes);\n  int32_t num_reduced_axes = 0;\n  FOR_RANGE(int64_t, i, 0, num_axes) {\n    if (IsReducedAxis(i)) {\n      GeneratePartialSbp<binary_func>(ctx, i);\n      num_reduced_axes += 1;\n    } else {\n      ctx->NewBuilder()\n          .Split(ctx->inputs(), i)\n          .Split(ctx->outputs(), keep_dims ? 
i : i - num_reduced_axes)\n          .Build();\n    }\n  }\n  if (num_axes == 0) {\n    ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n#define IMPLEMENT_REDUCE_OP_FUNCS(name, binary_func, infer_dtype_func)                   \\\n  /*static*/ Maybe<void> name##Op::GetSbp(user_op::SbpContext* ctx) {                    \\\n    return GetSbpFn<binary_func>(ctx);                                                   \\\n  }                                                                                      \\\n  /*static*/ Maybe<void> name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return InferTensorDescFn(ctx);                                                       \\\n  }                                                                                      \\\n  /*static*/ Maybe<void> name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferLogicalTensorDesc(ctx);                                                  \\\n  }                                                                                      \\\n  /*static*/ Maybe<void> name##Op::InferDataType(user_op::InferContext* ctx) {           \\\n    return infer_dtype_func(ctx);                                                        \\\n  }\n\nIMPLEMENT_REDUCE_OP_FUNCS(ReduceAny, BinaryFuncAny, InferLogicalDataType)\nIMPLEMENT_REDUCE_OP_FUNCS(ReduceAll, BinaryFuncAll, InferLogicalDataType)\nIMPLEMENT_REDUCE_OP_FUNCS(ReduceMin, BinaryFuncMin, oneflow::InferDataType)\nIMPLEMENT_REDUCE_OP_FUNCS(ReduceMax, BinaryFuncMax, oneflow::InferDataType)\nIMPLEMENT_REDUCE_OP_FUNCS(ReduceSum, BinaryFuncSum, oneflow::InferDataType)\nIMPLEMENT_REDUCE_OP_FUNCS(ReduceProd, BinaryFuncProd, oneflow::InferDataType)\nIMPLEMENT_REDUCE_OP_FUNCS(ReduceNanSum, BinaryFuncNanSum, oneflow::InferDataType)\n#undef IMPLEMENT_REDUCE_OP_FUNCS\n\n}  // namespace oneflow\n"
  },
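  {
    "path": "sketches/reduce_output_shape_sketch.cpp",
    "content": "// Editor's sketch (hypothetical file, not part of OneFlow; the helper is\n// invented): the output-shape rule of InferTensorDescFn above. Reduced axes\n// collapse to length 1; with keepdims == false those axes are then removed,\n// mimicking CreateReducedShape followed by RemoveOnes.\n#include <cstdio>\n#include <set>\n#include <vector>\n\nstd::vector<long> ReducedShape(const std::vector<long>& in, const std::set<int>& axes,\n                               bool keepdims) {\n  std::vector<long> out;\n  for (int i = 0; i < (int)in.size(); ++i) {\n    if (axes.count(i)) {\n      if (keepdims) { out.push_back(1); }  // keep the reduced axis as size 1\n    } else {\n      out.push_back(in[i]);\n    }\n  }\n  return out;\n}\n\nint main() {\n  // (2, 3, 4) reduced over axis 1: keepdims -> (2, 1, 4), otherwise -> (2, 4)\n  for (bool keep : {true, false}) {\n    for (long d : ReducedShape({2, 3, 4}, {1}, keep)) { std::printf(\"%ld \", d); }\n    std::printf(\"\\n\");\n  }\n  return 0;\n}\n"
  },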
  {
    "path": "oneflow/user/ops/reflection_pad_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<size_t ndim>\nMaybe<void> GetOpSbpSignature(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  const int64_t input_dims = x_tensor.shape().NumAxes();\n  const int64_t split_dims = input_dims - (ndim - 2);\n  FOR_RANGE(int64_t, i, 0, split_dims) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\ntemplate<size_t ndim>\nMaybe<void> GetOpGradSbpSignature(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"dy\", 0);\n  const int64_t grad_dims = dy_tensor.shape().NumAxes();\n  const int64_t split_dims = grad_dims - (ndim - 2);\n  FOR_RANGE(int64_t, i, 0, split_dims) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/*static*/ Maybe<void> ReflectionPad1DOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetOpSbpSignature<3>(ctx);\n}\n/*static*/ Maybe<void> ReflectionPad1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const auto& padding = ctx->Attr<std::vector<int64_t>>(\"padding\");\n  const int64_t n_idx = 0;\n  const int64_t c_idx = 1;\n  const int64_t w_idx = 2;\n\n  DimVector y_dim_vec(x_shape.NumAxes());\n  const int64_t w_x = x_shape.At(w_idx);\n\n  y_dim_vec[n_idx] = x_shape.At(n_idx);\n  y_dim_vec[c_idx] = x_shape.At(c_idx);\n  y_dim_vec[w_idx] = w_x + padding[0] + padding[1];\n\n  ctx->SetOutputShape(\"y\", 0, Shape(y_dim_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReflectionPad1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return ReflectionPad1DOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ReflectionPad1DOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReflectionPad1DOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* x_modifier = GetInputArgModifierFn(\"x\", 0);\n  CHECK_NOTNULL_OR_RETURN(x_modifier);  // NOLINT\n  x_modifier->set_requires_grad(true);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ReflectionPad1DGradOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetOpGradSbpSignature<3>(ctx);\n}\n/*static*/ Maybe<void> 
ReflectionPad1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  const auto& padding = ctx->Attr<std::vector<int64_t>>(\"padding\");\n  const int64_t n_idx = 0;\n  const int64_t c_idx = 1;\n  const int64_t w_idx = 2;\n\n  DimVector dx_dim_vec(dy_shape.NumAxes());\n  int64_t w_dy = dy_shape.At(w_idx);\n\n  dx_dim_vec[n_idx] = dy_shape.At(0);\n  dx_dim_vec[c_idx] = dy_shape.At(1);\n  dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1];\n\n  ctx->SetOutputShape(\"dx\", 0, Shape(dx_dim_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReflectionPad1DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return ReflectionPad1DGradOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ReflectionPad1DGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ReflectionPad2DOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetOpSbpSignature<4>(ctx);\n}\n/*static*/ Maybe<void> ReflectionPad2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const auto& padding = ctx->Attr<std::vector<int64_t>>(\"padding\");\n  const int64_t n_idx = 0;\n  const int64_t c_idx = 1;\n  const int64_t h_idx = 2;\n  const int64_t w_idx = 3;\n\n  DimVector y_dim_vec(x_shape.NumAxes());\n  const int64_t h_x = x_shape.At(h_idx);\n  const int64_t w_x = x_shape.At(w_idx);\n\n  y_dim_vec[n_idx] = x_shape.At(n_idx);\n  y_dim_vec[c_idx] = x_shape.At(c_idx);\n  y_dim_vec[h_idx] = h_x + padding[2] + padding[3];\n  y_dim_vec[w_idx] = w_x + padding[0] + padding[1];\n\n  ctx->SetOutputShape(\"y\", 0, Shape(y_dim_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReflectionPad2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return ReflectionPad2DOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ReflectionPad2DOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReflectionPad2DOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* x_modifier = GetInputArgModifierFn(\"x\", 0);\n  CHECK_NOTNULL_OR_RETURN(x_modifier);  // NOLINT\n  x_modifier->set_requires_grad(true);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ReflectionPad2DGradOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetOpGradSbpSignature<4>(ctx);\n}\n/*static*/ Maybe<void> ReflectionPad2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  const auto& padding = ctx->Attr<std::vector<int64_t>>(\"padding\");\n  const int64_t n_idx = 0;\n  const int64_t c_idx = 1;\n  const int64_t h_idx = 2;\n  const int64_t w_idx = 3;\n\n  DimVector dx_dim_vec(dy_shape.NumAxes());\n  int64_t h_dy = dy_shape.At(h_idx);\n  int64_t w_dy = dy_shape.At(w_idx);\n\n  dx_dim_vec[n_idx] = dy_shape.At(0);\n  dx_dim_vec[c_idx] = dy_shape.At(1);\n  dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3];\n  dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1];\n\n  ctx->SetOutputShape(\"dx\", 0, Shape(dx_dim_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReflectionPad2DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return ReflectionPad2DGradOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> 
ReflectionPad2DGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
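  {
    "path": "sketches/reflection_pad2d_shape_sketch.cpp",
    "content": "// Editor's sketch (hypothetical file, not part of OneFlow; the helper is\n// invented): the shape arithmetic of ReflectionPad2DOp above. The padding\n// attribute is laid out (left, right, top, bottom), so W grows by\n// padding[0] + padding[1] and H by padding[2] + padding[3]; the grad op\n// subtracts the same amounts from dy.\n#include <array>\n#include <cstdio>\n\nstd::array<long, 4> Pad2DOutShape(const std::array<long, 4>& nchw,\n                                  const std::array<long, 4>& padding) {\n  return {nchw[0], nchw[1],\n          nchw[2] + padding[2] + padding[3],   // H += top + bottom\n          nchw[3] + padding[0] + padding[1]};  // W += left + right\n}\n\nint main() {\n  // (1, 3, 8, 8) with padding (1, 1, 2, 2) -> (1, 3, 12, 10)\n  const auto y = Pad2DOutShape({1, 3, 8, 8}, {1, 1, 2, 2});\n  std::printf(\"(%ld, %ld, %ld, %ld)\\n\", y[0], y[1], y[2], y[3]);\n  return 0;\n}\n"
  },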
  {
    "path": "oneflow/user/ops/relu_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> ReluOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Split(user_op::OpArg(\"y\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"y\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ReluOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ReluGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"y\", 0);\n  FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"y\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& y_shape = ctx->InputShape(\"y\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == y_shape)\n      << Error::RuntimeError() << \"Tensors y and dy must have the same shape\";\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ReluGradOp::InferDataType(user_op::InferContext* ctx) {\n  DataType data_type = ctx->InputDType(\"y\", 0);\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"dy\", 0), data_type)\n      << \"InferDataType Failed. Expected \" << DataType_Name(data_type) << \", but got \"\n      << DataType_Name(ctx->InputDType(\"dy\", 0));\n  ctx->SetOutputDType(\"dx\", 0, data_type);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
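  {
    "path": "sketches/relu_grad_uses_y_sketch.cpp",
    "content": "// Editor's sketch (hypothetical file, not part of OneFlow): ReluGradOp above\n// consumes y rather than x, so the backward pass only needs the saved output:\n// dx = (y > 0) ? dy : 0. A scalar reference version:\n#include <cstdio>\n#include <vector>\n\nstd::vector<float> ReluGrad(const std::vector<float>& y, const std::vector<float>& dy) {\n  std::vector<float> dx(y.size());\n  for (size_t i = 0; i < y.size(); ++i) { dx[i] = y[i] > 0.0f ? dy[i] : 0.0f; }\n  return dx;\n}\n\nint main() {\n  const auto dx = ReluGrad({0.0f, 2.5f}, {1.0f, 1.0f});\n  std::printf(\"%.1f %.1f\\n\", dx[0], dx[1]);  // prints: 0.0 1.0\n  return 0;\n}\n"
  },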
  {
    "path": "oneflow/user/ops/repeat_interleave_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> Repeat_InterLeaveOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), i)\n        .Split(user_op::OpArg(\"cumsum\", 0), i)\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"cumsum\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> Repeat_InterLeaveOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const int64_t repeat_num = ctx->Attr<int64_t>(\"repeat_num\");\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_shape(Shape({repeat_num}));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> Repeat_InterLeaveOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> Repeat_InterLeaveOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
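  {
    "path": "sketches/repeat_interleave_semantics_sketch.cpp",
    "content": "// Editor's sketch (hypothetical file, not part of OneFlow): the tensors that\n// Repeat_InterLeaveOp wires together, assuming the usual repeat_interleave\n// contract (as in torch.repeat_interleave): \"in\" holds per-element repeat\n// counts, \"cumsum\" their inclusive prefix sums, and \"out\" (length repeat_num,\n// the attr above) receives index i repeated in[i] times.\n#include <cstdio>\n#include <vector>\n\nstd::vector<int> RepeatInterleave(const std::vector<int>& repeats) {\n  std::vector<int> out;\n  for (int i = 0; i < (int)repeats.size(); ++i) {\n    for (int r = 0; r < repeats[i]; ++r) { out.push_back(i); }\n  }\n  return out;\n}\n\nint main() {\n  // repeats (1, 2, 3) -> 0 1 1 2 2 2, so repeat_num must equal 6\n  for (int v : RepeatInterleave({1, 2, 3})) { std::printf(\"%d \", v); }\n  std::printf(\"\\n\");\n  return 0;\n}\n"
  },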
  {
    "path": "oneflow/user/ops/repeat_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> RepeatOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> RepeatOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> RepeatOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> RepeatOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> RepeatOp::InferOutputBlobTimeShape(\n    user_op::InferOutputBlobTimeShapeFnContext* ctx) {\n  DimVector dim_vec(ctx->TimeShape4InputArgNameAndIndex(\"in\", 0).dim_vec());\n  dim_vec.emplace_back(ctx->user_op_conf().attr<int32_t>(\"repeat_num\"));\n  *ctx->mut_output_blob_time_shape() = Shape(dim_vec);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
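  {
    "path": "sketches/repeat_time_shape_sketch.cpp",
    "content": "// Editor's sketch (hypothetical file, not part of OneFlow; the helper is\n// invented): RepeatOp above leaves the tensor shape unchanged but extends the\n// *time* shape in InferOutputBlobTimeShape, appending repeat_num so downstream\n// ops tick repeat_num times per upstream tick.\n#include <cstdio>\n#include <vector>\n\nstd::vector<long> RepeatTimeShape(std::vector<long> in_time_shape, long repeat_num) {\n  in_time_shape.push_back(repeat_num);\n  return in_time_shape;\n}\n\nint main() {\n  // upstream time shape (1,) with repeat_num = 4 -> (1, 4)\n  for (long d : RepeatTimeShape({1}, 4)) { std::printf(\"%ld \", d); }\n  std::printf(\"\\n\");\n  return 0;\n}\n"
  },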
  {
    "path": "oneflow/user/ops/replication_pad_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\ntemplate<size_t ndim>\nMaybe<void> GetOpSbpSignature(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  const int64_t input_dims = x_tensor.shape().NumAxes();\n  const int64_t first_two_dims = input_dims - (ndim - 2);\n  FOR_RANGE(int64_t, i, 0, first_two_dims) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\ntemplate<size_t ndim>\nMaybe<void> GetOpGradSbpSignature(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"dy\", 0);\n  const int64_t grad_dims = dy_tensor.shape().NumAxes();\n  CHECK_EQ_OR_RETURN(grad_dims, ndim);  // NOLINT\n  const int64_t first_two_dims = grad_dims - (ndim - 2);\n  FOR_RANGE(int64_t, i, 0, first_two_dims) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/*static*/ Maybe<void> ReplicationPad1DOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetOpSbpSignature<3>(ctx);\n}\n/*static*/ Maybe<void> ReplicationPad1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const auto& padding = ctx->Attr<std::vector<int64_t>>(\"padding\");\n  const int64_t n_idx = 0;\n  const int64_t c_idx = 1;\n  const int64_t w_idx = 2;\n\n  DimVector y_dim_vec(x_shape.NumAxes());\n  const int64_t w_x = x_shape.At(w_idx);\n\n  y_dim_vec[n_idx] = x_shape.At(n_idx);\n  y_dim_vec[c_idx] = x_shape.At(c_idx);\n  y_dim_vec[w_idx] = w_x + padding[0] + padding[1];\n\n  ctx->SetOutputShape(\"y\", 0, Shape(y_dim_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReplicationPad1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return ReplicationPad1DOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ReplicationPad1DOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReplicationPad1DOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* x_modifier = GetInputArgModifierFn(\"x\", 0);\n  CHECK_NOTNULL_OR_RETURN(x_modifier);\n  x_modifier->set_requires_grad(true);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ReplicationPad1DGradOp::GetSbp(user_op::SbpContext* ctx) {\n  return 
GetOpGradSbpSignature<3>(ctx);\n}\n/*static*/ Maybe<void> ReplicationPad1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  const auto& padding = ctx->Attr<std::vector<int64_t>>(\"padding\");\n  CHECK_EQ_OR_RETURN(padding.size(), dy_shape.NumAxes() - 1);  // NOLINT\n  const int64_t n_idx = 0;\n  const int64_t c_idx = 1;\n  const int64_t w_idx = 2;\n\n  DimVector dx_dim_vec(dy_shape.NumAxes());\n  int64_t w_dy = dy_shape.At(w_idx);\n\n  dx_dim_vec[n_idx] = dy_shape.At(0);\n  dx_dim_vec[c_idx] = dy_shape.At(1);\n  dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1];\n\n  ctx->SetOutputShape(\"dx\", 0, Shape(dx_dim_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReplicationPad1DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return ReplicationPad1DGradOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ReplicationPad1DGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ReplicationPad2DOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetOpSbpSignature<4>(ctx);\n}\n/*static*/ Maybe<void> ReplicationPad2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const auto& padding = ctx->Attr<std::vector<int64_t>>(\"padding\");\n  CHECK_EQ_OR_RETURN(padding.size(), x_shape.NumAxes());  // NOLINT\n  const int64_t n_idx = 0;\n  const int64_t c_idx = 1;\n  const int64_t h_idx = 2;\n  const int64_t w_idx = 3;\n\n  DimVector y_dim_vec(x_shape.NumAxes());\n  const int64_t h_x = x_shape.At(h_idx);\n  const int64_t w_x = x_shape.At(w_idx);\n\n  y_dim_vec[n_idx] = x_shape.At(n_idx);\n  y_dim_vec[c_idx] = x_shape.At(c_idx);\n  y_dim_vec[h_idx] = h_x + padding[2] + padding[3];\n  y_dim_vec[w_idx] = w_x + padding[0] + padding[1];\n\n  ctx->SetOutputShape(\"y\", 0, Shape(y_dim_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReplicationPad2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return ReplicationPad2DOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ReplicationPad2DOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReplicationPad2DOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* x_modifier = GetInputArgModifierFn(\"x\", 0);\n  CHECK_NOTNULL_OR_RETURN(x_modifier);  // NOLINT\n  x_modifier->set_requires_grad(true);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ReplicationPad2DGradOp::GetSbp(user_op::SbpContext* ctx) {\n  return GetOpGradSbpSignature<4>(ctx);\n}\n/*static*/ Maybe<void> ReplicationPad2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  const auto& padding = ctx->Attr<std::vector<int64_t>>(\"padding\");\n  CHECK_EQ_OR_RETURN(padding.size(), dy_shape.NumAxes());  // NOLINT\n  const int64_t n_idx = 0;\n  const int64_t c_idx = 1;\n  const int64_t h_idx = 2;\n  const int64_t w_idx = 3;\n\n  DimVector dx_dim_vec(dy_shape.NumAxes());\n  int64_t h_dy = dy_shape.At(h_idx);\n  int64_t w_dy = dy_shape.At(w_idx);\n\n  dx_dim_vec[n_idx] = dy_shape.At(0);\n  dx_dim_vec[c_idx] = dy_shape.At(1);\n  dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3];\n  dx_dim_vec[w_idx] = w_dy - padding[0] - 
padding[1];\n\n  ctx->SetOutputShape(\"dx\", 0, Shape(dx_dim_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReplicationPad2DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return ReplicationPad2DGradOp::InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ReplicationPad2DGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/reshape_like_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/reshape_user_op_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> ReshapeLikeOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0).shape();\n  const auto& like_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"like\", 0).shape();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"like\", 0))\n      .Broadcast(user_op::OpArg(\"in\", 0))\n      .Broadcast(user_op::OpArg(\"out\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"like\", 0))\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  user_op::UserOpSbpSignatureBuilder builder = ctx->NewBuilder();\n  return ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures(in_shape, like_shape, {{\"in\", 0}},\n                                                          {{\"like\", 0}, {\"out\", 0}},\n                                                          ctx->hierarchy_value(), &builder);\n}\n/*static*/ Maybe<void> ReshapeLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"in\", 0);\n  const Shape& like_shape = ctx->InputShape(\"like\", 0);\n  CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), like_shape.elem_cnt())\n      << Error::RuntimeError()\n      << \"The element number of the in tensor must be equal to the element number of the \"\n         \"like tensor, \"\n      << \"but got \" << in_shape.elem_cnt() << \" and \" << like_shape.elem_cnt();\n  ctx->SetOutputShape(\"out\", 0, like_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReshapeLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ReshapeLikeOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ReshapeLikeOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* like_modifier = GetInputArgModifierFn(\"like\", 0);\n  CHECK_NOTNULL_OR_RETURN(like_modifier);  // NOLINT(maybe-need-error-msg)\n  like_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/reshape_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/framework/sbp_infer_util.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/user/ops/reshape_user_op_util.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> ReshapeOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0).shape();\n  const Shape& shape = ctx->Attr<Shape>(\"shape\");\n  const auto& outshape = JUST(ReshapeUserOpUtil::GetLogicalOutBlobShape(in_shape, shape));\n  user_op::UserOpSbpSignatureBuilder builder = ctx->NewBuilder();\n  return ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures(\n      in_shape, *outshape, {{\"in\", 0}}, {{\"out\", 0}}, ctx->hierarchy_value(), &builder);\n}\n\n/*static*/ Maybe<void> ReshapeOp::EnumerateNdSbpSignatures(\n    user_op::GetNdSbpSignatureListContext* ctx) {\n  const Shape& in_shape = ctx->BlobShape4InputArgNameAndIndex(\"in\", 0);\n  const Shape& shape_attr = ctx->Attr<Shape>(\"shape\");\n  std::shared_ptr<Shape> out_shape_ptr =\n      JUST(ReshapeUserOpUtil::GetLogicalOutBlobShape(in_shape, shape_attr));\n\n  std::vector<NdSbpSignature>* nd_sbp_sig_list = ctx->MutNdSbpSignatureList();\n  JUST(ReshapeUserOpUtil::EnumerateNdSbpSignatures({{\"in\", 0}}, in_shape, {{\"out\", 0}},\n                                                   *out_shape_ptr, ctx->parallel_hierarchy(),\n                                                   nd_sbp_sig_list));\n\n  // Go down from the tail to the head, since we might drop the tail.\n  for (int32_t sbp_id = nd_sbp_sig_list->size() - 1; sbp_id >= 0; sbp_id--) {\n    auto& nd_sbp_sig = (*nd_sbp_sig_list)[sbp_id];\n    const auto& out_nd_sbp_it = nd_sbp_sig.bn_in_op2nd_sbp().find(\"out_0\");\n    CHECK_OR_RETURN(out_nd_sbp_it != nd_sbp_sig.bn_in_op2nd_sbp().end())\n        << \"can't get sbp for out_0\";\n    Shape out_logical_shape = *out_shape_ptr;\n    // filter by output only be needed here\n    // filter by input will be done in Operator::FilterNdSbpSignatureListByLogicalShape\n    if (JUST(FilterNdSbpByLogicalShape(out_nd_sbp_it->second, out_logical_shape,\n                                       ctx->parallel_hierarchy()))) {\n      // Remove the Nd SBP candidate\n      std::swap(nd_sbp_sig, nd_sbp_sig_list->back());\n      nd_sbp_sig_list->pop_back();\n    }\n  }\n\n  DeduplicateNdSbpSignatureList(nd_sbp_sig_list, {\"in_0\", \"out_0\"});\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ReshapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  Shape shape = ctx->Attr<Shape>(\"shape\");\n  const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out_tensor_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n\n  const Shape& in_shape = 
in_tensor_desc.shape();\n  CHECK_OR_RETURN(in_tensor_desc.is_dynamic() == false);  // NOLINT(maybe-need-error-msg)\n  out_tensor_desc->set_data_type(in_tensor_desc.data_type());\n  if (in_shape.NumAxes() == 0 || shape.NumAxes() == 0) {\n    // NOTE(chengcheng): input/output Scalar\n    // do nothing\n  } else {\n    CHECK_GE_OR_RETURN(shape.NumAxes(), 1);     // NOLINT(maybe-need-error-msg)\n    CHECK_GE_OR_RETURN(in_shape.NumAxes(), 1);  // NOLINT(maybe-need-error-msg)\n\n    int need_infer_axis = -1;\n    size_t count = 1;\n    for (int i = 0; i < shape.NumAxes(); ++i) {\n      if (shape.At(i) == -1) {\n        CHECK_EQ_OR_RETURN(need_infer_axis, -1)\n            << Error::RuntimeError() << \"Shape \" << shape.ToString()\n            << \" has more than 1 axis that needs to be inferred\";\n        need_infer_axis = i;\n      } else {\n        count *= shape.At(i);\n      }\n    }\n    if (need_infer_axis != -1) { shape.Set(need_infer_axis, in_shape.elem_cnt() / count); }\n  }\n  out_tensor_desc->set_shape(shape);\n  out_tensor_desc->set_stride(Stride(shape));\n  // For 0-size tensor, we don't need to check whether the input and output tensors have the same\n  // element size.\n  if (in_shape.elem_cnt() > 0) {\n    CHECK_EQ_OR_RETURN(shape.elem_cnt(), in_shape.elem_cnt())\n        << Error::RuntimeError() << \"Reshape infer ERROR! in op_name: \" << ctx->op_name()\n        << \" input shape is : \" << in_shape.ToString()\n        << \" , output shape is : \" << shape.ToString()\n        << \" , and reshape shape conf is : \" << ctx->Attr<Shape>(\"shape\").ToString()\n        << \" op_loc: \" << ctx->op_loc();\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ReshapeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  Shape logical_shape = ctx->Attr<Shape>(\"shape\");\n  const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out_tensor_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n\n  const Shape& in_shape = in_tensor_desc.shape();\n  out_tensor_desc->set_stride(Stride(in_tensor_desc.shape()));\n  out_tensor_desc->set_is_dynamic(in_tensor_desc.is_dynamic());\n  if (in_shape.NumAxes() == 0 || logical_shape.NumAxes() == 0) {\n    // NOTE(chengcheng): input/output Scalar\n    // do nothing\n  } else {\n    CHECK_GE_OR_RETURN(logical_shape.NumAxes(), 1);  // NOLINT(maybe-need-error-msg)\n    CHECK_GE_OR_RETURN(in_shape.NumAxes(), 1);       // NOLINT(maybe-need-error-msg)\n    const auto& in_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n    const Shape in_logical_shape =\n        *JUST(GetLogicalShape(in_shape, in_nd_sbp, ctx->parallel_desc()));\n    int need_infer_axis = -1;\n    size_t count = 1;\n    for (int i = 0; i < logical_shape.NumAxes(); ++i) {\n      if (logical_shape.At(i) == -1) {\n        CHECK_EQ_OR_RETURN(need_infer_axis, -1)\n            << Error::RuntimeError() << \"Shape \" << logical_shape.ToString()\n            << \" has more than 1 axis that needs to be inferred\";\n        need_infer_axis = i;\n      } else {\n        count *= logical_shape.At(i);\n      }\n    }\n    if (need_infer_axis != -1) {\n      logical_shape.Set(need_infer_axis, in_logical_shape.elem_cnt() / count);\n    }\n  }\n  const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  out_tensor_desc->set_shape(\n      *JUST(GetPhysicalShape(logical_shape, nd_sbp, ctx->parallel_desc(), ctx->parallel_ctx())));\n  out_tensor_desc->set_stride(Stride(out_tensor_desc->shape()));\n  CHECK_EQ_OR_RETURN(out_tensor_desc->shape().elem_cnt(), 
in_shape.elem_cnt())\n      << Error::RuntimeError() << \" Reshape infer ERROR! in op_name: \" << ctx->op_name()\n      << \" input shape is : \" << in_shape.ToString()\n      << \" , output shape is : \" << out_tensor_desc->shape().ToString()\n      << \" , output logical shape is \" << logical_shape.ToString()\n      << \" , and reshape shape conf is : \" << ctx->Attr<Shape>(\"shape\").ToString()\n      << \" op_loc: \" << ctx->op_loc();\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ReshapeOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
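  {
    "path": "sketches/reshape_minus_one_infer_sketch.cpp",
    "content": "// Editor's sketch (hypothetical file, not part of OneFlow; the helper is\n// invented): the \"-1\" handling shared by ReshapeOp::InferLogicalTensorDesc\n// and InferPhysicalTensorDesc above. At most one target dim may be -1; it is\n// replaced by elem_cnt divided by the product of the known dims.\n#include <cstdio>\n#include <stdexcept>\n#include <vector>\n\nstd::vector<long> InferReshape(long elem_cnt, std::vector<long> shape) {\n  long count = 1;\n  int need_infer_axis = -1;\n  for (int i = 0; i < (int)shape.size(); ++i) {\n    if (shape[i] == -1) {\n      if (need_infer_axis != -1) { throw std::runtime_error(\"more than one -1\"); }\n      need_infer_axis = i;\n    } else {\n      count *= shape[i];\n    }\n  }\n  if (need_infer_axis != -1) { shape[need_infer_axis] = elem_cnt / count; }\n  return shape;\n}\n\nint main() {\n  // 24 elements reshaped to (2, -1, 3) -> (2, 4, 3)\n  for (long d : InferReshape(24, {2, -1, 3})) { std::printf(\"%ld \", d); }\n  std::printf(\"\\n\");\n  return 0;\n}\n"
  },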
  {
    "path": "oneflow/user/ops/reshape_user_op_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/ops/reshape_user_op_util.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/common/cpp_attribute.h\"\n#include \"oneflow/core/common/container_util.h\"\n\nnamespace oneflow {\n\nMaybe<Shape> ReshapeUserOpUtil::GetLogicalOutBlobShape(const Shape& in_shape,\n                                                       const Shape& reshape) {\n  if (unlikely(in_shape.elem_cnt() == 0)) {\n    FOR_RANGE(int, axis, 0, reshape.NumAxes()) {\n      int64_t dim = reshape.At(axis);\n      if (dim == -1) {\n        return Error::RuntimeError()\n               << \"Cannot reshape tensor of 0 elements into shape \" << reshape.DebugStr()\n               << \" because the unspecified dimension size -1 can be any value and is ambiguous\";\n      } else if (dim < 0) {\n        return Error::RuntimeError() << \"Invalid shape dimension \" << dim\n                                     << \", the shape dimension can not to be less than 0\";\n      }\n    }\n    return std::make_shared<Shape>(reshape);\n  }\n  size_t total_elem_dim_exclude_minus_1 = 1;\n  bool has_minus_1 = false;\n  bool minus_1_axis = -1;\n  DimVector dim_vec;\n  FOR_RANGE(int, axis, 0, reshape.NumAxes()) {\n    int64_t dim = reshape.At(axis);\n    dim_vec.emplace_back(dim);\n    if (dim == -1) {\n      CHECK_OR_RETURN(has_minus_1 == false)\n          << Error::RuntimeError()\n          << \"There are multiple '-1' in the shape list, only one '-1' can be inferred\";\n      has_minus_1 = true;\n      minus_1_axis = axis;\n    } else if (dim > 0) {\n      CHECK_LE_OR_RETURN(dim, in_shape.elem_cnt())\n          << Error::RuntimeError() << \"Invalid axis: \" << axis << \", dim: \" << dim;\n      total_elem_dim_exclude_minus_1 *= dim;\n      CHECK_LE_OR_RETURN(total_elem_dim_exclude_minus_1, in_shape.elem_cnt())\n          << Error::RuntimeError()\n          << \"Element number in reshape_conf must be less than or equal to input blob, \"\n          << \"but got \" << total_elem_dim_exclude_minus_1 << \" and \" << in_shape.elem_cnt();\n    } else {\n      OF_UNIMPLEMENTED() << \"only positive number or -1 supported\";\n    }\n  }\n  CHECK_EQ_OR_RETURN(in_shape.elem_cnt() % total_elem_dim_exclude_minus_1, 0)\n      << Error::RuntimeError()\n      << \"Element number in input blob must be an integer multiple of reshape_conf, \"\n      << \"but got \" << in_shape.elem_cnt() << \" and \" << total_elem_dim_exclude_minus_1;\n  if (has_minus_1) {\n    dim_vec[minus_1_axis] = in_shape.elem_cnt() / total_elem_dim_exclude_minus_1;\n  } else {\n    CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), total_elem_dim_exclude_minus_1)\n        << \"Element number in input blob must be equal to reshape_conf, \"\n        << \"but got \" << in_shape.elem_cnt() << \" and \" << total_elem_dim_exclude_minus_1;\n  }\n  return std::make_shared<Shape>(dim_vec);\n}\n\nMaybe<void> ReshapeUserOpUtil::Squeeze(const Shape& origin, Shape* 
shape,\n                                       HashMap<int, int>* squeezed_axis2origin_axis) {\n  DimVector dim_vec;\n  FOR_RANGE(int, axis, 0, origin.NumAxes()) {\n    int64_t dim = origin.At(axis);\n    CHECK_GE_OR_RETURN(dim, 0) << Error::RuntimeError()\n                               << \"Trying to suqeeze tensor with negative dimension \" << dim\n                               << \" : \" << origin.DebugStr();\n    if (dim == 1) { continue; }\n    CHECK_OR_RETURN(squeezed_axis2origin_axis->emplace(dim_vec.size(), axis).second)\n        << \"emplace error\";  // NOLINT(maybe-need-error-msg)\n    dim_vec.emplace_back(dim);\n  }\n  *shape = Shape(dim_vec);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis(\n    const Shape& in_shape, const Shape& out_shape, const int64_t hierarchy_value,\n    HashMap<int, int>* group_start_in_axis2out_axis) {\n  CHECK_GE_OR_RETURN(in_shape.NumAxes(), 0)\n      << Error::RuntimeError()\n      << \"The dimension of input tensor must be greater than or equal to zero, \"\n      << \"but got \" << in_shape.NumAxes();  // support 0D tensor\n  CHECK_GE_OR_RETURN(out_shape.NumAxes(), 0)\n      << Error::RuntimeError()\n      << \"The dimension of output tensor must be greater than or equal to zero, \"\n      << \"but got \" << out_shape.NumAxes();  // support 0D tensor\n  CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), out_shape.elem_cnt())\n      << Error::RuntimeError()\n      << \"The element number of input tensor must be equal to output tensor, \"\n      << \"but got \" << in_shape.elem_cnt() << \" and \" << out_shape.elem_cnt();\n  // Initialization\n  // shape_count is the product of the axis length in [start_axis, end)\n  int64_t in_shape_count = 1;\n  int64_t out_shape_count = 1;\n  int64_t in_axis = in_shape.NumAxes();\n  int64_t out_axis = out_shape.NumAxes();\n  // Move forward functions\n  auto Move2NextAxis = [](const Shape& shape, int64_t* axis, int64_t* shape_count) {\n    (*axis)--;\n    if (*axis >= 0) { *shape_count *= shape.At(*axis); }\n  };\n  auto MoveInAxis = [&] { Move2NextAxis(in_shape, &in_axis, &in_shape_count); };\n  auto MoveOutAxis = [&] { Move2NextAxis(out_shape, &out_axis, &out_shape_count); };\n  // Move the first step\n  MoveInAxis();\n  MoveOutAxis();\n  // At the last step, both in_axis == out_axis == 0\n  // Then they would move to -1 simultaneously.\n  while (in_axis >= 0) {\n    if (in_shape_count == out_shape_count) {\n      // Record split axises\n      if (in_shape.At(in_axis) == out_shape.At(out_axis)\n          || (in_shape.At(in_axis) % hierarchy_value == 0\n              && out_shape.At(out_axis) % hierarchy_value == 0)) {\n        (*group_start_in_axis2out_axis)[in_axis] = out_axis;\n      }\n      // Move forward\n      MoveInAxis();\n      MoveOutAxis();\n    } else if (in_shape_count < out_shape_count) {\n      MoveInAxis();\n    } else {\n      // in_shape_count > out_shape_count\n      MoveOutAxis();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures(\n    const Shape& in_shape, const Shape& out_shape, const std::vector<user_op::OpArg>& in_args,\n    const std::vector<user_op::OpArg>& out_args, const int64_t hierarchy_value,\n    user_op::UserOpSbpSignatureBuilder* builder) {\n  if (in_shape.NumAxes() == 0 || in_shape.elem_cnt() == 0) {\n    return Maybe<void>::Ok();\n  }  // 0D/0Size tensor only support b2b\n  HashMap<int, int> squeezed_group_start_in_axis2out_axis;\n  HashMap<int, int> 
in_squeezed_axis2original_axis;\n  HashMap<int, int> out_squeezed_axis2original_axis;\n  {\n    Shape squeezed_in_shape;\n    Shape squeezed_out_shape;\n    JUST(ReshapeUserOpUtil::Squeeze(in_shape, &squeezed_in_shape, &in_squeezed_axis2original_axis));\n    JUST(ReshapeUserOpUtil::Squeeze(out_shape, &squeezed_out_shape,\n                                    &out_squeezed_axis2original_axis));\n    JUST(ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis(squeezed_in_shape, squeezed_out_shape,\n                                                        hierarchy_value,\n                                                        &squeezed_group_start_in_axis2out_axis));\n  }\n  for (const auto& pair : squeezed_group_start_in_axis2out_axis) {\n    int64_t start_in_axis = in_squeezed_axis2original_axis.at(pair.first);\n    int64_t start_out_axis = out_squeezed_axis2original_axis.at(pair.second);\n    builder->Split(in_args, start_in_axis).Split(out_args, start_out_axis).Build();\n  }\n  builder->PartialSum(in_args).PartialSum(out_args).Build();\n  return Maybe<void>::Ok();\n}\n\nnamespace {\n\nvoid FowardRankMesh(size_t depth, size_t max_depth, std::deque<int>& rank_axes_queue,\n                    std::vector<std::vector<int>>& rank_axes_subset) {\n  if (depth == max_depth) {\n    // skip empty subset\n    if (rank_axes_queue.empty()) { return; }\n    rank_axes_subset.emplace_back();\n    auto& rank_axes = rank_axes_subset.back();\n    for (int rank_axis : rank_axes_queue) { rank_axes.push_back(rank_axis); }\n  } else {\n    // forward by skip current depth axis\n    FowardRankMesh(depth + 1, max_depth, rank_axes_queue, rank_axes_subset);\n    // fowward by keep current depth axis\n    rank_axes_queue.push_back(depth);\n    FowardRankMesh(depth + 1, max_depth, rank_axes_queue, rank_axes_subset);\n    rank_axes_queue.pop_back();\n  }\n}\n\nvoid GenRankMeshSubset(size_t mesh_depth, std::vector<std::vector<int>>& rank_axes_subset) {\n  std::deque<int> rank_axes_queue;\n  FowardRankMesh(0, mesh_depth, rank_axes_queue, rank_axes_subset);\n}\n\n}  // namespace\n\nMaybe<void> ReshapeUserOpUtil::EnumerateNdSplitIn2OutAxis(\n    const Shape& in_shape, const std::vector<int>& origin_in_axes, const Shape& out_shape,\n    const std::vector<int>& origin_out_axes, const Shape& rank_mesh,\n    std::vector<std::map<int, std::pair<int, int>>>* nd_split_groups) {\n  CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), out_shape.elem_cnt());\n  CHECK_EQ_OR_RETURN(in_shape.size(), origin_in_axes.size());\n  CHECK_EQ_OR_RETURN(out_shape.size(), origin_out_axes.size());\n  // generate all subset of rank_mesh (keep order)\n  // for example rank_mesh=(2, 3, 5), subset include:\n  // (2, 3, 5)\n  // (2, 3)\n  // (2, 5)\n  // (2,)\n  // (3, 5)\n  // (3,)\n  // (5,)\n  std::vector<std::vector<int>> rank_axes_subset;\n  GenRankMeshSubset(rank_mesh.size(), rank_axes_subset);\n  // traverse all subset to detect contiguous nd-split signatures\n  // for example (6,) reshape to (2, 3) with rank_mesh=(2, 3)\n  // nd-split signatures include:\n  // S(0) -> S(0) with rank_axis=0 (1d)\n  // S(0) -> S(1) with rank_axis=1 (1d)\n  // [S(0), S(0)] -> [S(0), S(1)] with rank_mesh=(2,3) (2d)\n  for (const std::vector<int>& rank_axes : rank_axes_subset) {\n    int rank_axis_idx = 0;\n    int in_axis = in_shape.size() - 1;\n    int out_axis = out_shape.size() - 1;\n    int64_t in_dim_size = in_shape[in_axis];\n    int64_t out_dim_size = out_shape[out_axis];\n    // rank_axis -> {in_axis, out_axis}\n    std::map<int, std::pair<int, int>> rank_in2out_axis;\n    // 
walk from the tail axis to the head axis, since the dimensions\n    // of the in_shape and the out_shape passed in\n    // are in reverse order\n    while (in_axis >= 0 && out_axis >= 0 && rank_axis_idx < rank_axes.size()) {\n      // when dim_size == 1, move to the next axis to find a contiguous split axis\n      if (in_dim_size == 1) {\n        in_axis--;\n        // guard the index; the while condition handles in_axis < 0\n        if (in_axis >= 0) { in_dim_size = in_shape[in_axis]; }\n        continue;\n      }\n      if (out_dim_size == 1) {\n        out_axis--;\n        if (out_axis >= 0) { out_dim_size = out_shape[out_axis]; }\n        continue;\n      }\n      int rank_axis = rank_axes[rank_axis_idx];\n      int64_t rank_num = rank_mesh[rank_axis];\n      // dim_size being indivisible by rank_num means the split can't continue\n      if (in_dim_size % rank_num != 0 || out_dim_size % rank_num != 0) { break; }\n      // divide dim_size by rank_num at both in_axis and out_axis till dim_size == 1\n      in_dim_size /= rank_num;\n      out_dim_size /= rank_num;\n      int origin_in_axis = origin_in_axes[in_axis];\n      int origin_out_axis = origin_out_axes[out_axis];\n      // mark rank_axis as splittable at both in_axis and out_axis\n      rank_in2out_axis.emplace(rank_axis, std::make_pair(origin_in_axis, origin_out_axis));\n      rank_axis_idx++;\n    }\n    // ensure all rank axes are marked splittable with some axis (in and out)\n    if (rank_in2out_axis.size() == rank_axes.size()) {\n      nd_split_groups->emplace_back(std::move(rank_in2out_axis));\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReshapeUserOpUtil::EnumerateNdSplitIn2OutAxisGroups(\n    const Shape& in_shape, const Shape& out_shape, const Shape& rank_mesh,\n    std::vector<std::map<int, std::pair<int, int>>>* nd_sbp_in2out_sig_groups) {\n  int in_axis = in_shape.size();\n  int out_axis = out_shape.size();\n  int64_t in_count = 1;\n  int64_t out_count = 1;\n  auto MoveAxis = [](const Shape& shape, int& axis, int64_t& count) {\n    axis--;\n    if (axis >= 0 && axis < shape.size()) { count *= shape[axis]; }\n  };\n  auto MoveInAxis = [&]() { MoveAxis(in_shape, in_axis, in_count); };\n  auto MoveOutAxis = [&]() { MoveAxis(out_shape, out_axis, out_count); };\n  MoveInAxis();\n  MoveOutAxis();\n\n  DimVector group_in_dim_vec;\n  DimVector group_out_dim_vec;\n  std::vector<int> group_in_axes;\n  std::vector<int> group_out_axes;\n  group_in_axes.reserve(rank_mesh.size());\n  group_out_axes.reserve(rank_mesh.size());\n\n  // group reshape dimensions\n  // for example:\n  // (4, 5, 2, 3) reshaped to (2, 2, 5, 6) will be divided into 3 groups:\n  // (   4,| 5, | 2, 3)\n  // (2, 2,| 5, | 6)\n  // group1: (2, 3) -> (6)\n  // group2: (5,) -> (5)\n  // group3: (4,) -> (2, 2)\n  while (in_axis >= 0 && out_axis >= 0) {\n    // move in_axis when in_count < out_count\n    // move out_axis when out_count < in_count\n    // move both when in_count == out_count\n    if (in_count < out_count) {\n      // skip dim_size == 1\n      if (in_shape[in_axis] != 1) {\n        group_in_dim_vec.push_back(in_shape[in_axis]);\n        group_in_axes.push_back(in_axis);\n      }\n      MoveInAxis();\n    } else if (in_count > out_count) {\n      if (out_shape[out_axis] != 1) {\n        group_out_dim_vec.push_back(out_shape[out_axis]);\n        group_out_axes.push_back(out_axis);\n      }\n      MoveOutAxis();\n    } else {  // in_count == out_count\n      if (in_shape[in_axis] == out_shape[out_axis]) {\n        // group2 ((5,) -> (5,) in the example) reaches this branch\n        for (int rank_axis = 0; rank_axis < rank_mesh.size(); ++rank_axis) {\n          int64_t rank_num = 
rank_mesh[rank_axis];\n          if (in_shape[in_axis] % rank_num == 0) {\n            std::map<int, std::pair<int, int>> rank_in2out_split_axis{\n                {rank_axis, std::make_pair(in_axis, out_axis)}};\n            nd_sbp_in2out_sig_groups->emplace_back(std::move(rank_in2out_split_axis));\n          }\n        }\n      } else {\n        // the current reshape group (group1 and group3 in the example) finishes here\n        group_in_dim_vec.push_back(in_shape[in_axis]);\n        group_in_axes.push_back(in_axis);\n        group_out_dim_vec.push_back(out_shape[out_axis]);\n        group_out_axes.push_back(out_axis);\n        // enumerate all nd-split signatures for one group\n        JUST(EnumerateNdSplitIn2OutAxis(Shape(group_in_dim_vec), group_in_axes,\n                                        Shape(group_out_dim_vec), group_out_axes, rank_mesh,\n                                        nd_sbp_in2out_sig_groups));\n        group_in_dim_vec.clear();\n        group_out_dim_vec.clear();\n        group_in_axes.clear();\n        group_out_axes.clear();\n      }\n      MoveInAxis();\n      MoveOutAxis();\n    }\n  }\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReshapeUserOpUtil::DfsCombineNdSbpSignatureGroups(\n    const std::vector<std::map<int, std::pair<int, int>>>& nd_sbp_sig_groups, size_t rank_num_axes,\n    std::vector<std::vector<std::pair<int, int>>>* nd_sbp_sig_list) {\n  std::map<int, std::pair<int, int>> nd_sbp_sig_group;\n  std::set<std::vector<std::pair<int, int>>> nd_sbp_sig_set;\n  JUST(DfsCombineNdSbpSignatureGroups(nd_sbp_sig_groups, rank_num_axes, nd_sbp_sig_group,\n                                      nd_sbp_sig_set));\n  std::copy(nd_sbp_sig_set.begin(), nd_sbp_sig_set.end(), back_inserter(*nd_sbp_sig_list));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReshapeUserOpUtil::DfsCombineNdSbpSignatureGroups(\n    const std::vector<std::map<int, std::pair<int, int>>>& nd_sbp_sig_groups, size_t rank_num_axes,\n    const std::map<int, std::pair<int, int>>& nd_sbp_sig_group,\n    std::set<std::vector<std::pair<int, int>>>& nd_sbp_sig_set) {\n  if (nd_sbp_sig_group.size() == rank_num_axes) {\n    // a complete signature covers every rank axis; record it\n    std::vector<std::pair<int, int>> nd_sbp_sig;\n    for (int i = 0; i < rank_num_axes; ++i) {\n      nd_sbp_sig.emplace_back(JUST(MapAt(nd_sbp_sig_group, i)));\n    }\n    nd_sbp_sig_set.emplace(nd_sbp_sig);\n  } else {\n    // extend the partial signature with every group that covers disjoint rank axes\n    for (const auto& nd_sbp_sig_group_to_combine : nd_sbp_sig_groups) {\n      std::map<int, std::pair<int, int>> new_nd_sbp_sig_group = nd_sbp_sig_group;\n      bool combine_failed = false;\n      for (const auto& rank_in2out_pair : nd_sbp_sig_group_to_combine) {\n        int rank_axis = rank_in2out_pair.first;\n        if (nd_sbp_sig_group.find(rank_axis) != nd_sbp_sig_group.end()) {\n          combine_failed = true;\n          break;\n        }\n        CHECK_OR_RETURN(new_nd_sbp_sig_group.emplace(rank_axis, rank_in2out_pair.second).second);\n      }\n      if (!combine_failed) {\n        JUST(DfsCombineNdSbpSignatureGroups(nd_sbp_sig_groups, rank_num_axes, new_nd_sbp_sig_group,\n                                            nd_sbp_sig_set));\n      }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReshapeUserOpUtil::EnumerateNdSbpIn2OutSignatures(\n    const Shape& in_shape, const Shape& out_shape, const Shape& rank_mesh,\n    std::vector<std::vector<std::pair<int, int>>>* nd_sbp_in2out_signatures) {\n  CHECK_GT_OR_RETURN(in_shape.size(), 0)\n      << Error::RuntimeError() << \"The dimension of input tensor must be greater than zero, \"\n      << \"but got \" << 
in_shape.size();\n  CHECK_GT_OR_RETURN(out_shape.size(), 0)\n      << Error::RuntimeError() << \"The dimension of output tensor must be greater than zero, \"\n      << \"but got \" << out_shape.size();\n  CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), out_shape.elem_cnt())\n      << Error::RuntimeError()\n      << \"The element count of the input tensor must equal that of the output tensor, \"\n      << \"but got \" << in_shape.elem_cnt() << \" and \" << out_shape.elem_cnt();\n\n  // groups of rank_axis -> (in_axis, out_axis) mappings\n  std::vector<std::map<int, std::pair<int, int>>> nd_sbp_signature_groups;\n  JUST(EnumerateNdSplitIn2OutAxisGroups(in_shape, out_shape, rank_mesh, &nd_sbp_signature_groups));\n\n  std::map<int, std::pair<int, int>> nd_sbp_in2out_group;\n  for (int rank_axis = 0; rank_axis < rank_mesh.size(); ++rank_axis) {\n    // -1 indicates broadcast, -2 indicates partial sum\n    nd_sbp_in2out_group.emplace(rank_axis, std::make_pair(-1, -1));\n    nd_sbp_signature_groups.emplace_back(nd_sbp_in2out_group);\n    nd_sbp_in2out_group.clear();\n    nd_sbp_in2out_group.emplace(rank_axis, std::make_pair(-2, -2));\n    nd_sbp_signature_groups.emplace_back(nd_sbp_in2out_group);\n    nd_sbp_in2out_group.clear();\n  }\n\n  JUST(DfsCombineNdSbpSignatureGroups(nd_sbp_signature_groups, rank_mesh.size(),\n                                      nd_sbp_in2out_signatures));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReshapeUserOpUtil::FilterNdSbpIn2OutSignatures(\n    const Shape& in_shape, const Shape& out_shape, const Shape& rank_mesh,\n    std::vector<std::vector<std::pair<int, int>>>* nd_sbp_in2out_signatures) {\n  // filter the Nd SBP candidates\n  // Go down from the tail to the head, since we might drop the tail.\n  for (int i = nd_sbp_in2out_signatures->size() - 1; i >= 0; --i) {\n    auto& nd_sbp_sig = (*nd_sbp_in2out_signatures)[i];\n    CHECK_EQ_OR_RETURN(nd_sbp_sig.size(), rank_mesh.size());\n    bool match_failed = false;\n    DimVector in_dim_vec = in_shape.dim_vec();\n    DimVector out_dim_vec = out_shape.dim_vec();\n    for (int rank_axis = 0; rank_axis < nd_sbp_sig.size(); ++rank_axis) {\n      int64_t rank_num = rank_mesh[rank_axis];\n      int in_sig = nd_sbp_sig[rank_axis].first;\n      int out_sig = nd_sbp_sig[rank_axis].second;\n      if (in_sig >= 0) {\n        if (in_dim_vec[in_sig] % rank_num == 0) {\n          in_dim_vec[in_sig] /= rank_num;\n        } else {\n          match_failed = true;\n          break;\n        }\n      }\n      if (out_sig >= 0) {\n        if (out_dim_vec[out_sig] % rank_num == 0) {\n          out_dim_vec[out_sig] /= rank_num;\n        } else {\n          match_failed = true;\n          break;\n        }\n      }\n    }\n    if (match_failed) {\n      // swap the invalid Nd SBP with the tail and drop it\n      std::swap(nd_sbp_sig, nd_sbp_in2out_signatures->back());\n      nd_sbp_in2out_signatures->pop_back();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ReshapeUserOpUtil::EnumerateNdSbpSignatures(\n    const std::vector<user_op::OpArg>& in_args, const Shape& in_shape,\n    const std::vector<user_op::OpArg>& out_args, const Shape& out_shape, const Shape& rank_mesh,\n    std::vector<NdSbpSignature>* nd_sbp_sig_list) {\n  CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), out_shape.elem_cnt());\n  if (in_shape.elem_cnt() == 0) { return Maybe<void>::Ok(); }\n  if (in_shape.size() == 0 || out_shape.size() == 0) { return Maybe<void>::Ok(); }\n  std::vector<std::vector<std::pair<int, int>>> nd_sbp_in2out_sig_list;\n  
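// enumerate candidate (in_axis, out_axis) signatures per rank axis; -1 stands\n  // for broadcast and -2 for partial sum\n  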
JUST(EnumerateNdSbpIn2OutSignatures(in_shape, out_shape, rank_mesh, &nd_sbp_in2out_sig_list));\n  for (const auto& nd_sbp_in2out_axis : nd_sbp_in2out_sig_list) {\n    nd_sbp_sig_list->emplace_back();\n    auto& nd_sbp_sig = nd_sbp_sig_list->back();\n    for (const auto& in2out_axis : nd_sbp_in2out_axis) {\n      for (const auto& in_arg : in_args) {\n        const auto& ibn = in_arg.name() + \"_\" + std::to_string(in_arg.index());\n        auto& in_nd_sbp = (*nd_sbp_sig.mutable_bn_in_op2nd_sbp())[ibn];\n        auto* in_sbp = in_nd_sbp.add_sbp_parallel();\n        if (in2out_axis.first == -1) {\n          in_sbp->mutable_broadcast_parallel();\n        } else if (in2out_axis.first == -2) {\n          in_sbp->mutable_partial_sum_parallel();\n        } else {\n          in_sbp->mutable_split_parallel()->set_axis(in2out_axis.first);\n        }\n      }\n      for (const auto& out_arg : out_args) {\n        const auto& obn = out_arg.name() + \"_\" + std::to_string(out_arg.index());\n        auto& out_nd_sbp = (*nd_sbp_sig.mutable_bn_in_op2nd_sbp())[obn];\n        auto* out_sbp = out_nd_sbp.add_sbp_parallel();\n        if (in2out_axis.second == -1) {\n          out_sbp->mutable_broadcast_parallel();\n        } else if (in2out_axis.second == -2) {\n          out_sbp->mutable_partial_sum_parallel();\n        } else {\n          out_sbp->mutable_split_parallel()->set_axis(in2out_axis.second);\n        }\n      }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/reshape_user_op_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_OPS_RESHAPE_USER_OP_UTIL\n#define ONEFLOW_USER_OPS_RESHAPE_USER_OP_UTIL\n\n#include \"oneflow/core/framework/sbp_context.h\"\n#include \"oneflow/core/framework/user_op_conf.h\"\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\nstruct ReshapeUserOpUtil {\n  static Maybe<Shape> GetLogicalOutBlobShape(const Shape& in_shape, const Shape& reshape);\n  static Maybe<void> Squeeze(const Shape& origin, Shape* shape,\n                             HashMap<int, int>* squeezed_axis2origin_axis);\n  static Maybe<void> GetGroupStartInAxis2OutAxis(const Shape& in_shape, const Shape& out_shape,\n                                                 const int64_t hierarchy_value,\n                                                 HashMap<int, int>* group_start_in_axis2out_axis);\n  static Maybe<void> GetReshapeUserOpSbpSignatures(const Shape& in_shape, const Shape& out_shape,\n                                                   const std::vector<user_op::OpArg>& in_args,\n                                                   const std::vector<user_op::OpArg>& out_args,\n                                                   const int64_t hierarchy_value,\n                                                   user_op::UserOpSbpSignatureBuilder* builder);\n\n  static Maybe<void> DfsCombineNdSbpSignatureGroups(\n      const std::vector<std::map<int, std::pair<int, int>>>& nd_sbp_sig_groups,\n      size_t rank_num_axes, std::vector<std::vector<std::pair<int, int>>>* nd_sbp_sig_list);\n  static Maybe<void> DfsCombineNdSbpSignatureGroups(\n      const std::vector<std::map<int, std::pair<int, int>>>& nd_sbp_sig_groups,\n      size_t rank_num_axes, const std::map<int, std::pair<int, int>>& nd_sbp_sig_group,\n      std::set<std::vector<std::pair<int, int>>>& nd_sbp_sig_set);\n  static Maybe<void> EnumerateNdSplitIn2OutAxis(\n      const Shape& in_shape, const std::vector<int>& origin_in_axes, const Shape& out_shape,\n      const std::vector<int>& origin_out_axes, const Shape& rank_mesh,\n      std::vector<std::map<int, std::pair<int, int>>>* nd_split_groups);\n  static Maybe<void> EnumerateNdSplitIn2OutAxisGroups(\n      const Shape& in_shape, const Shape& out_shape, const Shape& rank_mesh,\n      std::vector<std::map<int, std::pair<int, int>>>* nd_sbp_in2out_sig_groups);\n  static Maybe<void> EnumerateNdSbpIn2OutSignatures(\n      const Shape& in_shape, const Shape& out_shape, const Shape& rank_mesh,\n      std::vector<std::vector<std::pair<int, int>>>* nd_sbp_in2out_signatures);\n  static Maybe<void> FilterNdSbpIn2OutSignatures(\n      const Shape& in_shape, const Shape& out_shape, const Shape& rank_mesh,\n      std::vector<std::vector<std::pair<int, int>>>* nd_sbp_in2out_signatures);\n  static Maybe<void> EnumerateNdSbpSignatures(const std::vector<user_op::OpArg>& in_args,\n                                              const Shape& in_shape,\n                                  
            const std::vector<user_op::OpArg>& out_args,\n                                              const Shape& out_shape, const Shape& rank_mesh,\n                                              std::vector<NdSbpSignature>* nd_sbp_sig_list);\n};\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_OPS_RESHAPE_USER_OP_UTIL\n"
  },
  {
    "path": "oneflow/user/ops/reshape_user_op_util_test.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/ops/reshape_user_op_util.h\"\n\n#include <gtest/gtest.h>\n\nnamespace oneflow {\nnamespace test {\n\nnamespace {\n\nstd::string NdSbpIn2OutSignaturesToString(\n    const std::vector<std::vector<std::pair<int, int>>>& nd_split_in2out_axis_list) {\n  std::ostringstream ss;\n  ss << \"{\";\n  int i = 0;\n  for (const auto& nd_split_axis : nd_split_in2out_axis_list) {\n    if (i > 0) { ss << \", \"; }\n    ss << \"{\";\n    int j = 0;\n    for (const auto& split_in2out_axis : nd_split_axis) {\n      if (j > 0) { ss << \", \"; }\n      ss << \"{\" << split_in2out_axis.first << \", \" << split_in2out_axis.second << \"}\";\n      j++;\n    }\n    ss << \"}\";\n    i++;\n  }\n  ss << \"}\";\n  return ss.str();\n}\n\nstd::string NdSbpSignatureGroupsToString(\n    const std::vector<std::map<int, std::pair<int, int>>>& nd_sbp_signature_groups) {\n  std::ostringstream ss;\n  ss << \"{\";\n  int i = 0;\n  for (const auto& nd_sbp_sig_group : nd_sbp_signature_groups) {\n    if (i > 0) { ss << \", \"; }\n    ss << \"{\";\n    int j = 0;\n    for (const auto& nd_sbp_sig_pair : nd_sbp_sig_group) {\n      if (j > 0) { ss << \", \"; }\n      ss << nd_sbp_sig_pair.first << \": {\" << nd_sbp_sig_pair.second.first << \", \"\n         << nd_sbp_sig_pair.second.second << \"}\";\n      j++;\n    }\n    ss << \"}\";\n    i++;\n  }\n  ss << \"}\";\n  return ss.str();\n}\n\nvoid TestEnumerateNdSbpIn2OutSignatures(\n    const Shape& in_shape, const Shape& out_shape, const Shape& rank_mesh,\n    const std::vector<std::map<int, std::pair<int, int>>>& expected_nd_sbp_in2out_sig_groups,\n    const std::vector<std::vector<std::pair<int, int>>>& expected_nd_sbp_in2out_sig_list) {\n  std::vector<std::map<int, std::pair<int, int>>> actual_nd_sbp_in2out_sig_groups;\n  CHECK_JUST(ReshapeUserOpUtil::EnumerateNdSplitIn2OutAxisGroups(in_shape, out_shape, rank_mesh,\n                                                                 &actual_nd_sbp_in2out_sig_groups));\n  std::sort(actual_nd_sbp_in2out_sig_groups.begin(), actual_nd_sbp_in2out_sig_groups.end());\n  ASSERT_EQ(expected_nd_sbp_in2out_sig_groups.size(), actual_nd_sbp_in2out_sig_groups.size());\n  for (size_t i = 0; i < actual_nd_sbp_in2out_sig_groups.size(); ++i) {\n    const auto& exp_nd_sbp_sig_group = expected_nd_sbp_in2out_sig_groups[i];\n    const auto& act_nd_sbp_sig_group = actual_nd_sbp_in2out_sig_groups[i];\n    ASSERT_EQ(exp_nd_sbp_sig_group.size(), act_nd_sbp_sig_group.size());\n    for (const auto& act_pair : act_nd_sbp_sig_group) {\n      auto exp_it = exp_nd_sbp_sig_group.find(act_pair.first);\n      ASSERT_TRUE(exp_it != exp_nd_sbp_sig_group.end());\n      ASSERT_EQ(exp_it->second.first, act_pair.second.first);\n      ASSERT_EQ(exp_it->second.second, act_pair.second.second);\n    }\n  }\n\n  std::vector<std::vector<std::pair<int, int>>> actual_nd_sbp_in2out_sig_list;\n  
CHECK_JUST(ReshapeUserOpUtil::EnumerateNdSbpIn2OutSignatures(in_shape, out_shape, rank_mesh,\n                                                               &actual_nd_sbp_in2out_sig_list));\n  CHECK_JUST(ReshapeUserOpUtil::FilterNdSbpIn2OutSignatures(in_shape, out_shape, rank_mesh,\n                                                            &actual_nd_sbp_in2out_sig_list));\n  std::sort(actual_nd_sbp_in2out_sig_list.begin(), actual_nd_sbp_in2out_sig_list.end());\n  ASSERT_EQ(expected_nd_sbp_in2out_sig_list.size(), actual_nd_sbp_in2out_sig_list.size());\n  for (size_t i = 0; i < actual_nd_sbp_in2out_sig_list.size(); ++i) {\n    const auto& exp_nd_sbp_sig = expected_nd_sbp_in2out_sig_list[i];\n    const auto& act_nd_sbp_sig = actual_nd_sbp_in2out_sig_list[i];\n    ASSERT_EQ(exp_nd_sbp_sig.size(), act_nd_sbp_sig.size());\n    for (size_t j = 0; j < act_nd_sbp_sig.size(); ++j) {\n      ASSERT_EQ(exp_nd_sbp_sig[j].first, act_nd_sbp_sig[j].first);\n      ASSERT_EQ(exp_nd_sbp_sig[j].second, act_nd_sbp_sig[j].second);\n    }\n  }\n}\n\n}  // namespace\n\nusing std::pair;\n\nTEST(ReshapeUserOpUtil, EnumerateNdSbpIn2OutSignatures) {\n  // clang-format off\n  // 2D-split\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {4}, /*out_shape*/ {2, 2}, /*rank_mesh*/ {2, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{0, 1}}},\n       {{1, pair{0, 0}}}},\n      {{{-2, -2}, {-2, -2}},\n       {{-2, -2}, {-1, -1}},\n       {{-2, -2}, {0, 0}},\n       {{-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}},\n       {{-1, -1}, {0, 0}},\n       {{0, 0}, {-2, -2}},\n       {{0, 0}, {-1, -1}},\n       {{0, 0}, {0, 1}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {12}, /*out_shape*/ {2, 2, 3}, /*rank_mesh*/ {2, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{0, 1}}},\n       {{1, pair{0, 0}}}},\n      {{{-2, -2}, {-2, -2}},\n       {{-2, -2}, {-1, -1}},\n       {{-2, -2}, {0, 0}},\n       {{-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}},\n       {{-1, -1}, {0, 0}},\n       {{0, 0}, {-2, -2}},\n       {{0, 0}, {-1, -1}},\n       {{0, 0}, {0, 1}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {2, 4}, /*out_shape*/ {8}, /*rank_mesh*/ {2, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{1, 0}}},\n       {{1, pair{0, 0}}}},\n      {{{-2, -2}, {-2, -2}},\n       {{-2, -2}, {-1, -1}},\n       {{-2, -2}, {0, 0}},\n       {{-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}},\n       {{-1, -1}, {0, 0}},\n       {{0, 0}, {-2, -2}},\n       {{0, 0}, {-1, -1}},\n       {{0, 0}, {1, 0}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {2, 1, 4}, /*out_shape*/ {8}, /*rank_mesh*/ {2, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{2, 0}}},\n       {{1, pair{0, 0}}}},\n      {{{-2, -2}, {-2, -2}},\n       {{-2, -2}, {-1, -1}},\n       {{-2, -2}, {0, 0}},\n       {{-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}},\n       {{-1, -1}, {0, 0}},\n       {{0, 0}, {-2, -2}},\n       {{0, 0}, {-1, -1}},\n       {{0, 0}, {2, 0}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {8, 2}, /*out_shape*/ {2, 4, 2}, /*rank_mesh*/ {2, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{0, 1}}},\n       {{0, pair{1, 2}}},\n       {{1, pair{0, 0}}},\n       {{1, pair{1, 2}}}},\n      {{{-2, -2}, {-2, -2}},\n       {{-2, -2}, {-1, -1}},\n       {{-2, -2}, {0, 0}},\n       {{-2, -2}, {1, 2}},\n       {{-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}},\n       {{-1, -1}, {0, 0}},\n       {{-1, -1}, {1, 
2}},\n       {{0, 0}, {-2, -2}},\n       {{0, 0}, {-1, -1}},\n       {{0, 0}, {0, 1}},\n       {{0, 0}, {1, 2}},\n       {{1, 2}, {-2, -2}},\n       {{1, 2}, {-1, -1}},\n       {{1, 2}, {0, 0}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {8, 1, 2}, /*out_shape*/ {2, 1, 4, 2}, /*rank_mesh*/ {2, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{0, 2}}},\n       {{0, pair{2, 3}}},\n       {{1, pair{0, 0}}},\n       {{1, pair{2, 3}}},},\n      {{{-2, -2}, {-2, -2}},\n       {{-2, -2}, {-1, -1}},\n       {{-2, -2}, {0, 0}},\n       {{-2, -2}, {2, 3}},\n       {{-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}},\n       {{-1, -1}, {0, 0}},\n       {{-1, -1}, {2, 3}},\n       {{0, 0}, {-2, -2}},\n       {{0, 0}, {-1, -1}},\n       {{0, 0}, {0, 2}},\n       {{0, 0}, {2, 3}},\n       {{2, 3}, {-2, -2}},\n       {{2, 3}, {-1, -1}},\n       {{2, 3}, {0, 0}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {3, 2, 3, 5}, /*out_shape*/ {3, 30}, /*rank_mesh*/ {2, 3},\n      {{{0, pair{1, 1}}},\n       {{0, pair{1, 1}}, {1, pair{2, 1}}},\n       {{1, pair{0, 0}}}},\n      {{{-2, -2}, {-2, -2}},\n       {{-2, -2}, {-1, -1}},\n       {{-2, -2}, {0, 0}},\n       {{-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}},\n       {{-1, -1}, {0, 0}},\n       {{1, 1}, {-2, -2}},\n       {{1, 1}, {-1, -1}},\n       {{1, 1}, {0, 0}},\n       {{1, 1}, {2, 1}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {2, 4}, /*out_shape*/ {4, 2}, /*rank_mesh*/ {2, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{1, 0}}},\n       {{1, pair{0, 0}}}},\n      {{{-2, -2}, {-2, -2}},\n       {{-2, -2}, {-1, -1}},\n       {{-2, -2}, {0, 0}},\n       {{-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}},\n       {{-1, -1}, {0, 0}},\n       {{0, 0}, {-2, -2}},\n       {{0, 0}, {-1, -1}},\n       {{0, 0}, {1, 0}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {4, 2}, /*out_shape*/ {2, 4}, /*rank_mesh*/ {2, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{0, 1}}},\n       {{1, pair{0, 0}}}},\n      {{{-2, -2}, {-2, -2}},\n       {{-2, -2}, {-1, -1}},\n       {{-2, -2}, {0, 0}},\n       {{-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}},\n       {{-1, -1}, {0, 0}},\n       {{0, 0}, {-2, -2}},\n       {{0, 0}, {-1, -1}},\n       {{0, 0}, {0, 1}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {4, 3}, /*out_shape*/ {3, 4}, /*rank_mesh*/ {2, 3}, \n      {},\n      {{{-2, -2}, {-2, -2}}, \n       {{-2, -2}, {-1, -1}}, \n       {{-1, -1}, {-2, -2}}, \n       {{-1, -1}, {-1, -1}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {2, 6}, /*out_shape*/ {4, 3}, /*rank_mesh*/ {2, 3},\n      {{{0, pair{0, 0}}}},\n      {{{-2, -2}, {-2, -2}},\n       {{-2, -2}, {-1, -1}},\n       {{-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}},\n       {{0, 0}, {-2, -2}},\n       {{0, 0}, {-1, -1}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {2, 2, 5, 4}, /*out_shape*/ {4, 5, 2, 2}, /*rank_mesh*/ {2, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{1, 0}}},\n       {{0, pair{3, 2}}},\n       {{0, pair{3, 2}}, {1, pair{3, 3}}},\n       {{1, pair{0, 0}}},\n       {{1, pair{3, 2}}}},\n      {{{-2, -2}, {-2, -2}},\n       {{-2, -2}, {-1, -1}},\n       {{-2, -2}, {0, 0}},\n       {{-2, -2}, {3, 2}},\n       {{-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}},\n       {{-1, -1}, {0, 0}},\n       {{-1, -1}, {3, 2}},\n       {{0, 0}, {-2, -2}},\n       {{0, 0}, {-1, -1}},\n       {{0, 0}, {1, 0}},\n       {{0, 0}, 
{3, 2}},\n       {{3, 2}, {-2, -2}},\n       {{3, 2}, {-1, -1}},\n       {{3, 2}, {0, 0}},\n       {{3, 2}, {3, 3}}});\n  // 3D-split\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {24}, /*out_shape*/ {2, 4, 3}, /*rank_mesh*/ {2, 2, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{0, 1}}},\n       {{0, pair{0, 0}}, {1, pair{0, 1}}, {2, pair{0, 1}}},\n       {{0, pair{0, 0}}, {2, pair{0, 1}}},\n       {{1, pair{0, 0}}},\n       {{1, pair{0, 0}}, {2, pair{0, 1}}},\n       {{2, pair{0, 0}}}},\n      {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {0, 0}},\n       {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {0, 0}},\n       {{-2, -2}, {0, 0}, {-2, -2}},   {{-2, -2}, {0, 0}, {-1, -1}},   {{-2, -2}, {0, 0}, {0, 1}},\n       {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {0, 0}},\n       {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {0, 0}},\n       {{-1, -1}, {0, 0}, {-2, -2}},   {{-1, -1}, {0, 0}, {-1, -1}},   {{-1, -1}, {0, 0}, {0, 1}},\n       {{0, 0}, {-2, -2}, {-2, -2}},   {{0, 0}, {-2, -2}, {-1, -1}},   {{0, 0}, {-2, -2}, {0, 1}},\n       {{0, 0}, {-1, -1}, {-2, -2}},   {{0, 0}, {-1, -1}, {-1, -1}},   {{0, 0}, {-1, -1}, {0, 1}},\n       {{0, 0}, {0, 1}, {-2, -2}},     {{0, 0}, {0, 1}, {-1, -1}},     {{0, 0}, {0, 1}, {0, 1}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {3, 24}, /*out_shape*/ {3, 2, 2, 6}, /*rank_mesh*/ {2, 2, 2},\n      {{{0, pair{1, 1}}},\n       {{0, pair{1, 1}}, {1, pair{1, 2}}},\n       {{0, pair{1, 1}}, {1, pair{1, 2}}, {2, pair{1, 3}}},\n       {{0, pair{1, 1}}, {2, pair{1, 2}}},\n       {{1, pair{1, 1}}},\n       {{1, pair{1, 1}}, {2, pair{1, 2}}},\n       {{2, pair{1, 1}}}},\n      {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {1, 1}},\n       {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {1, 1}},\n       {{-2, -2}, {1, 1}, {-2, -2}},   {{-2, -2}, {1, 1}, {-1, -1}},   {{-2, -2}, {1, 1}, {1, 2}},\n       {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {1, 1}},\n       {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {1, 1}},\n       {{-1, -1}, {1, 1}, {-2, -2}},   {{-1, -1}, {1, 1}, {-1, -1}},   {{-1, -1}, {1, 1}, {1, 2}},\n       {{1, 1}, {-2, -2}, {-2, -2}},   {{1, 1}, {-2, -2}, {-1, -1}},   {{1, 1}, {-2, -2}, {1, 2}},\n       {{1, 1}, {-1, -1}, {-2, -2}},   {{1, 1}, {-1, -1}, {-1, -1}},   {{1, 1}, {-1, -1}, {1, 2}},\n       {{1, 1}, {1, 2}, {-2, -2}},     {{1, 1}, {1, 2}, {-1, -1}},     {{1, 1}, {1, 2}, {1, 3}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {4, 77, 3}, /*out_shape*/ {2, 2, 77, 3}, /*rank_mesh*/ {2, 2, 3},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{0, 1}}},\n       {{1, pair{0, 0}}},\n       {{2, pair{2, 3}}}},\n      {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {2, 3}},\n       {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {2, 3}},\n       {{-2, -2}, {0, 0}, {-2, -2}},   {{-2, -2}, {0, 0}, {-1, -1}},   {{-2, -2}, {0, 0}, {2, 3}},\n       {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {2, 3}},\n       {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {2, 3}},\n       {{-1, -1}, {0, 0}, {-2, -2}},   {{-1, -1}, {0, 0}, 
{-1, -1}},   {{-1, -1}, {0, 0}, {2, 3}},\n       {{0, 0}, {-2, -2}, {-2, -2}},   {{0, 0}, {-2, -2}, {-1, -1}},   {{0, 0}, {-2, -2}, {2, 3}},\n       {{0, 0}, {-1, -1}, {-2, -2}},   {{0, 0}, {-1, -1}, {-1, -1}},   {{0, 0}, {-1, -1}, {2, 3}},\n       {{0, 0}, {0, 1}, {-2, -2}},     {{0, 0}, {0, 1}, {-1, -1}},     {{0, 0}, {0, 1}, {2, 3}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {2, 3, 2, 5}, /*out_shape*/ {12, 5}, /*rank_mesh*/ {2, 3, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{1, 0}}},\n       {{0, pair{0, 0}}, {1, pair{1, 0}}, {2, pair{2, 0}}},\n       {{2, pair{0, 0}}}},\n      {{{-2, -2}, {-2, -2}, {-2, -2}},\n       {{-2, -2}, {-2, -2}, {-1, -1}},\n       {{-2, -2}, {-2, -2}, {0, 0}},\n       {{-2, -2}, {-1, -1}, {-2, -2}},\n       {{-2, -2}, {-1, -1}, {-1, -1}},\n       {{-2, -2}, {-1, -1}, {0, 0}},\n       {{-1, -1}, {-2, -2}, {-2, -2}},\n       {{-1, -1}, {-2, -2}, {-1, -1}},\n       {{-1, -1}, {-2, -2}, {0, 0}},\n       {{-1, -1}, {-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}, {-1, -1}},\n       {{-1, -1}, {-1, -1}, {0, 0}},\n       {{0, 0}, {-2, -2}, {-2, -2}},\n       {{0, 0}, {-2, -2}, {-1, -1}},\n       {{0, 0}, {-1, -1}, {-2, -2}},\n       {{0, 0}, {-1, -1}, {-1, -1}},\n       {{0, 0}, {1, 0}, {-2, -2}},\n       {{0, 0}, {1, 0}, {-1, -1}},\n       {{0, 0}, {1, 0}, {2, 0}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {2, 1, 3, 2, 5}, /*out_shape*/ {12, 1, 5}, /*rank_mesh*/ {2, 3, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{2, 0}}},\n       {{0, pair{0, 0}}, {1, pair{2, 0}}, {2, pair{3, 0}}},\n       {{2, pair{0, 0}}}},\n      {{{-2, -2}, {-2, -2}, {-2, -2}},\n       {{-2, -2}, {-2, -2}, {-1, -1}},\n       {{-2, -2}, {-2, -2}, {0, 0}},\n       {{-2, -2}, {-1, -1}, {-2, -2}},\n       {{-2, -2}, {-1, -1}, {-1, -1}},\n       {{-2, -2}, {-1, -1}, {0, 0}},\n       {{-1, -1}, {-2, -2}, {-2, -2}},\n       {{-1, -1}, {-2, -2}, {-1, -1}},\n       {{-1, -1}, {-2, -2}, {0, 0}},\n       {{-1, -1}, {-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}, {-1, -1}},\n       {{-1, -1}, {-1, -1}, {0, 0}},\n       {{0, 0}, {-2, -2}, {-2, -2}},\n       {{0, 0}, {-2, -2}, {-1, -1}},\n       {{0, 0}, {-1, -1}, {-2, -2}},\n       {{0, 0}, {-1, -1}, {-1, -1}},\n       {{0, 0}, {2, 0}, {-2, -2}},\n       {{0, 0}, {2, 0}, {-1, -1}},\n       {{0, 0}, {2, 0}, {3, 0}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {8, 4}, /*out_shape*/ {2, 2, 8}, /*rank_mesh*/ {2, 2, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{0, 1}}},\n       {{0, pair{0, 0}}, {1, pair{0, 1}}, {2, pair{0, 2}}},\n       {{0, pair{0, 0}}, {2, pair{0, 1}}},\n       {{1, pair{0, 0}}},\n       {{1, pair{0, 0}}, {2, pair{0, 1}}},\n       {{2, pair{0, 0}}}},\n      {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}}, {{-2, -2}, {-2, -2}, {0, 0}},\n       {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}}, {{-2, -2}, {-1, -1}, {0, 0}},\n       {{-2, -2}, {0, 0}, {-2, -2}},   {{-2, -2}, {0, 0}, {-1, -1}},   {{-2, -2}, {0, 0}, {0, 1}},\n       {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {0, 0}},\n       {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {0, 0}},\n       {{-1, -1}, {0, 0}, {-2, -2}},   {{-1, -1}, {0, 0}, {-1, -1}},   {{-1, -1}, {0, 0}, {0, 1}},\n       {{0, 0}, {-2, -2}, {-2, -2}},   {{0, 0}, {-2, -2}, {-1, -1}},   {{0, 0}, {-2, -2}, {0, 1}},\n       {{0, 0}, {-1, -1}, {-2, -2}},   {{0, 0}, {-1, -1}, {-1, -1}},   {{0, 0}, 
{-1, -1}, {0, 1}},\n       {{0, 0}, {0, 1}, {-2, -2}},     {{0, 0}, {0, 1}, {-1, -1}},     {{0, 0}, {0, 1}, {0, 2}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {8, 2, 2}, /*out_shape*/ {2, 2, 4, 2}, /*rank_mesh*/ {2, 2, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{0, 1}}},\n       {{0, pair{0, 0}}, {1, pair{0, 1}}, {2, pair{0, 2}}},\n       {{0, pair{0, 0}}, {2, pair{0, 1}}},\n       {{0, pair{2, 3}}},\n       {{1, pair{0, 0}}},\n       {{1, pair{0, 0}}, {2, pair{0, 1}}},\n       {{1, pair{2, 3}}},\n       {{2, pair{0, 0}}},\n       {{2, pair{2, 3}}}},\n      {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}},\n       {{-2, -2}, {-2, -2}, {0, 0}},   {{-2, -2}, {-2, -2}, {2, 3}},\n       {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}},\n       {{-2, -2}, {-1, -1}, {0, 0}},   {{-2, -2}, {-1, -1}, {2, 3}},\n       {{-2, -2}, {0, 0}, {-2, -2}},   {{-2, -2}, {0, 0}, {-1, -1}},\n       {{-2, -2}, {0, 0}, {0, 1}},     {{-2, -2}, {0, 0}, {2, 3}},\n       {{-2, -2}, {2, 3}, {-2, -2}},   {{-2, -2}, {2, 3}, {-1, -1}},\n       {{-2, -2}, {2, 3}, {0, 0}},     {{-1, -1}, {-2, -2}, {-2, -2}},\n       {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {0, 0}},\n       {{-1, -1}, {-2, -2}, {2, 3}},   {{-1, -1}, {-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {0, 0}},\n       {{-1, -1}, {-1, -1}, {2, 3}},   {{-1, -1}, {0, 0}, {-2, -2}},\n       {{-1, -1}, {0, 0}, {-1, -1}},   {{-1, -1}, {0, 0}, {0, 1}},\n       {{-1, -1}, {0, 0}, {2, 3}},     {{-1, -1}, {2, 3}, {-2, -2}},\n       {{-1, -1}, {2, 3}, {-1, -1}},   {{-1, -1}, {2, 3}, {0, 0}},\n       {{0, 0}, {-2, -2}, {-2, -2}},   {{0, 0}, {-2, -2}, {-1, -1}},\n       {{0, 0}, {-2, -2}, {0, 1}},     {{0, 0}, {-2, -2}, {2, 3}},\n       {{0, 0}, {-1, -1}, {-2, -2}},   {{0, 0}, {-1, -1}, {-1, -1}},\n       {{0, 0}, {-1, -1}, {0, 1}},     {{0, 0}, {-1, -1}, {2, 3}},\n       {{0, 0}, {0, 1}, {-2, -2}},     {{0, 0}, {0, 1}, {-1, -1}},\n       {{0, 0}, {0, 1}, {0, 2}},       {{0, 0}, {0, 1}, {2, 3}},\n       {{0, 0}, {2, 3}, {-2, -2}},     {{0, 0}, {2, 3}, {-1, -1}},\n       {{0, 0}, {2, 3}, {0, 1}},       {{2, 3}, {-2, -2}, {-2, -2}},\n       {{2, 3}, {-2, -2}, {-1, -1}},   {{2, 3}, {-2, -2}, {0, 0}},\n       {{2, 3}, {-1, -1}, {-2, -2}},   {{2, 3}, {-1, -1}, {-1, -1}},\n       {{2, 3}, {-1, -1}, {0, 0}},     {{2, 3}, {0, 0}, {-2, -2}},\n       {{2, 3}, {0, 0}, {-1, -1}},     {{2, 3}, {0, 0}, {0, 1}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {8, 2, 1, 2}, /*out_shape*/ {2, 2, 1, 4, 2}, /*rank_mesh*/ {2, 2, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{0, 1}}},\n       {{0, pair{0, 0}}, {1, pair{0, 1}}, {2, pair{0, 3}}},\n       {{0, pair{0, 0}}, {2, pair{0, 1}}},\n       {{0, pair{3, 4}}},\n       {{1, pair{0, 0}}},\n       {{1, pair{0, 0}}, {2, pair{0, 1}}},\n       {{1, pair{3, 4}}},\n       {{2, pair{0, 0}}},\n       {{2, pair{3, 4}}}},\n      {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}},\n       {{-2, -2}, {-2, -2}, {0, 0}},   {{-2, -2}, {-2, -2}, {3, 4}},\n       {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}},\n       {{-2, -2}, {-1, -1}, {0, 0}},   {{-2, -2}, {-1, -1}, {3, 4}},\n       {{-2, -2}, {0, 0}, {-2, -2}},   {{-2, -2}, {0, 0}, {-1, -1}},\n       {{-2, -2}, {0, 0}, {0, 1}},     {{-2, -2}, {0, 0}, {3, 4}},\n       {{-2, -2}, {3, 4}, {-2, -2}},   {{-2, -2}, {3, 4}, {-1, -1}},\n       {{-2, -2}, {3, 4}, {0, 0}},     {{-1, -1}, {-2, -2}, {-2, -2}},\n       {{-1, -1}, {-2, -2}, {-1, 
-1}}, {{-1, -1}, {-2, -2}, {0, 0}},\n       {{-1, -1}, {-2, -2}, {3, 4}},   {{-1, -1}, {-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {0, 0}},\n       {{-1, -1}, {-1, -1}, {3, 4}},   {{-1, -1}, {0, 0}, {-2, -2}},\n       {{-1, -1}, {0, 0}, {-1, -1}},   {{-1, -1}, {0, 0}, {0, 1}},\n       {{-1, -1}, {0, 0}, {3, 4}},     {{-1, -1}, {3, 4}, {-2, -2}},\n       {{-1, -1}, {3, 4}, {-1, -1}},   {{-1, -1}, {3, 4}, {0, 0}},\n       {{0, 0}, {-2, -2}, {-2, -2}},   {{0, 0}, {-2, -2}, {-1, -1}},\n       {{0, 0}, {-2, -2}, {0, 1}},     {{0, 0}, {-2, -2}, {3, 4}},\n       {{0, 0}, {-1, -1}, {-2, -2}},   {{0, 0}, {-1, -1}, {-1, -1}},\n       {{0, 0}, {-1, -1}, {0, 1}},     {{0, 0}, {-1, -1}, {3, 4}},\n       {{0, 0}, {0, 1}, {-2, -2}},     {{0, 0}, {0, 1}, {-1, -1}},\n       {{0, 0}, {0, 1}, {0, 3}},       {{0, 0}, {0, 1}, {3, 4}},\n       {{0, 0}, {3, 4}, {-2, -2}},     {{0, 0}, {3, 4}, {-1, -1}},\n       {{0, 0}, {3, 4}, {0, 1}},       {{3, 4}, {-2, -2}, {-2, -2}},\n       {{3, 4}, {-2, -2}, {-1, -1}},   {{3, 4}, {-2, -2}, {0, 0}},\n       {{3, 4}, {-1, -1}, {-2, -2}},   {{3, 4}, {-1, -1}, {-1, -1}},\n       {{3, 4}, {-1, -1}, {0, 0}},     {{3, 4}, {0, 0}, {-2, -2}},\n       {{3, 4}, {0, 0}, {-1, -1}},     {{3, 4}, {0, 0}, {0, 1}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {6, 4}, /*out_shape*/ {2, 3, 2, 2}, /*rank_mesh*/ {2, 3, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{0, 1}}},\n       {{0, pair{1, 2}}},\n       {{0, pair{1, 2}}, {2, pair{1, 3}}},\n       {{2, pair{0, 0}}},\n       {{2, pair{1, 2}}}},\n      {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}},\n       {{-2, -2}, {-2, -2}, {0, 0}},   {{-2, -2}, {-2, -2}, {1, 2}},\n       {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}},\n       {{-2, -2}, {-1, -1}, {0, 0}},   {{-2, -2}, {-1, -1}, {1, 2}},\n       {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}},\n       {{-1, -1}, {-2, -2}, {0, 0}},   {{-1, -1}, {-2, -2}, {1, 2}},\n       {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}},\n       {{-1, -1}, {-1, -1}, {0, 0}},   {{-1, -1}, {-1, -1}, {1, 2}},\n       {{0, 0}, {-2, -2}, {-2, -2}},   {{0, 0}, {-2, -2}, {-1, -1}},\n       {{0, 0}, {-2, -2}, {1, 2}},     {{0, 0}, {-1, -1}, {-2, -2}},\n       {{0, 0}, {-1, -1}, {-1, -1}},   {{0, 0}, {-1, -1}, {1, 2}},\n       {{0, 0}, {0, 1}, {-2, -2}},     {{0, 0}, {0, 1}, {-1, -1}},\n       {{0, 0}, {0, 1}, {1, 2}},       {{1, 2}, {-2, -2}, {-2, -2}},\n       {{1, 2}, {-2, -2}, {-1, -1}},   {{1, 2}, {-2, -2}, {0, 0}},\n       {{1, 2}, {-2, -2}, {1, 3}},     {{1, 2}, {-1, -1}, {-2, -2}},\n       {{1, 2}, {-1, -1}, {-1, -1}},   {{1, 2}, {-1, -1}, {0, 0}},\n       {{1, 2}, {-1, -1}, {1, 3}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {6, 4, 1}, /*out_shape*/ {2, 1, 3, 2, 2}, /*rank_mesh*/ {2, 3, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{0, 2}}},\n       {{0, pair{1, 3}}},\n       {{0, pair{1, 3}}, {2, pair{1, 4}}},\n       {{2, pair{0, 0}}},\n       {{2, pair{1, 3}}}},\n      {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}},\n       {{-2, -2}, {-2, -2}, {0, 0}},   {{-2, -2}, {-2, -2}, {1, 3}},\n       {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}},\n       {{-2, -2}, {-1, -1}, {0, 0}},   {{-2, -2}, {-1, -1}, {1, 3}},\n       {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}},\n       {{-1, -1}, {-2, -2}, {0, 0}},   {{-1, -1}, {-2, -2}, {1, 3}},\n       {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, 
{-1, -1}, {-1, -1}},\n       {{-1, -1}, {-1, -1}, {0, 0}},   {{-1, -1}, {-1, -1}, {1, 3}},\n       {{0, 0}, {-2, -2}, {-2, -2}},   {{0, 0}, {-2, -2}, {-1, -1}},\n       {{0, 0}, {-2, -2}, {1, 3}},     {{0, 0}, {-1, -1}, {-2, -2}},\n       {{0, 0}, {-1, -1}, {-1, -1}},   {{0, 0}, {-1, -1}, {1, 3}},\n       {{0, 0}, {0, 2}, {-2, -2}},     {{0, 0}, {0, 2}, {-1, -1}},\n       {{0, 0}, {0, 2}, {1, 3}},       {{1, 3}, {-2, -2}, {-2, -2}},\n       {{1, 3}, {-2, -2}, {-1, -1}},   {{1, 3}, {-2, -2}, {0, 0}},\n       {{1, 3}, {-2, -2}, {1, 4}},     {{1, 3}, {-1, -1}, {-2, -2}},\n       {{1, 3}, {-1, -1}, {-1, -1}},   {{1, 3}, {-1, -1}, {0, 0}},\n       {{1, 3}, {-1, -1}, {1, 4}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {6, 3, 4}, /*out_shape*/ {2, 3, 3, 2, 2}, /*rank_mesh*/ {2, 3, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{0, 1}}},\n       {{0, pair{2, 3}}},\n       {{0, pair{2, 3}}, {2, pair{2, 4}}},\n       {{1, pair{1, 2}}},\n       {{2, pair{0, 0}}},\n       {{2, pair{2, 3}}}},\n      {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}},\n       {{-2, -2}, {-2, -2}, {0, 0}},   {{-2, -2}, {-2, -2}, {2, 3}},\n       {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}},\n       {{-2, -2}, {-1, -1}, {0, 0}},   {{-2, -2}, {-1, -1}, {2, 3}},\n       {{-2, -2}, {1, 2}, {-2, -2}},   {{-2, -2}, {1, 2}, {-1, -1}},\n       {{-2, -2}, {1, 2}, {0, 0}},     {{-2, -2}, {1, 2}, {2, 3}},\n       {{-1, -1}, {-2, -2}, {-2, -2}}, {{-1, -1}, {-2, -2}, {-1, -1}},\n       {{-1, -1}, {-2, -2}, {0, 0}},   {{-1, -1}, {-2, -2}, {2, 3}},\n       {{-1, -1}, {-1, -1}, {-2, -2}}, {{-1, -1}, {-1, -1}, {-1, -1}},\n       {{-1, -1}, {-1, -1}, {0, 0}},   {{-1, -1}, {-1, -1}, {2, 3}},\n       {{-1, -1}, {1, 2}, {-2, -2}},   {{-1, -1}, {1, 2}, {-1, -1}},\n       {{-1, -1}, {1, 2}, {0, 0}},     {{-1, -1}, {1, 2}, {2, 3}},\n       {{0, 0}, {-2, -2}, {-2, -2}},   {{0, 0}, {-2, -2}, {-1, -1}},\n       {{0, 0}, {-2, -2}, {2, 3}},     {{0, 0}, {-1, -1}, {-2, -2}},\n       {{0, 0}, {-1, -1}, {-1, -1}},   {{0, 0}, {-1, -1}, {2, 3}},\n       {{0, 0}, {0, 1}, {-2, -2}},     {{0, 0}, {0, 1}, {-1, -1}},\n       {{0, 0}, {0, 1}, {2, 3}},       {{0, 0}, {1, 2}, {-2, -2}},\n       {{0, 0}, {1, 2}, {-1, -1}},     {{0, 0}, {1, 2}, {2, 3}},\n       {{2, 3}, {-2, -2}, {-2, -2}},   {{2, 3}, {-2, -2}, {-1, -1}},\n       {{2, 3}, {-2, -2}, {0, 0}},     {{2, 3}, {-2, -2}, {2, 4}},\n       {{2, 3}, {-1, -1}, {-2, -2}},   {{2, 3}, {-1, -1}, {-1, -1}},\n       {{2, 3}, {-1, -1}, {0, 0}},     {{2, 3}, {-1, -1}, {2, 4}},\n       {{2, 3}, {1, 2}, {-2, -2}},     {{2, 3}, {1, 2}, {-1, -1}},\n       {{2, 3}, {1, 2}, {0, 0}},       {{2, 3}, {1, 2}, {2, 4}}});\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {2, 8}, /*out_shape*/ {2, 2, 2, 2}, /*rank_mesh*/ {2, 2, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{1, 1}}},\n       {{0, pair{1, 1}}, {1, pair{1, 2}}},\n       {{0, pair{1, 1}}, {1, pair{1, 2}}, {2, pair{1, 3}}},\n       {{0, pair{1, 1}}, {2, pair{1, 2}}},\n       {{1, pair{0, 0}}},\n       {{1, pair{1, 1}}},\n       {{1, pair{1, 1}}, {2, pair{1, 2}}},\n       {{2, pair{0, 0}}},\n       {{2, pair{1, 1}}}},\n      {{{-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}},\n       {{-2, -2}, {-2, -2}, {0, 0}},   {{-2, -2}, {-2, -2}, {1, 1}},\n       {{-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}},\n       {{-2, -2}, {-1, -1}, {0, 0}},   {{-2, -2}, {-1, -1}, {1, 1}},\n       {{-2, -2}, {0, 0}, {-2, -2}},   {{-2, -2}, {0, 0}, {-1, -1}},\n       {{-2, -2}, {0, 0}, 
{1, 1}},     {{-2, -2}, {1, 1}, {-2, -2}},\n       {{-2, -2}, {1, 1}, {-1, -1}},   {{-2, -2}, {1, 1}, {0, 0}},\n       {{-2, -2}, {1, 1}, {1, 2}},     {{-1, -1}, {-2, -2}, {-2, -2}},\n       {{-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {0, 0}},\n       {{-1, -1}, {-2, -2}, {1, 1}},   {{-1, -1}, {-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {0, 0}},\n       {{-1, -1}, {-1, -1}, {1, 1}},   {{-1, -1}, {0, 0}, {-2, -2}},\n       {{-1, -1}, {0, 0}, {-1, -1}},   {{-1, -1}, {0, 0}, {1, 1}},\n       {{-1, -1}, {1, 1}, {-2, -2}},   {{-1, -1}, {1, 1}, {-1, -1}},\n       {{-1, -1}, {1, 1}, {0, 0}},     {{-1, -1}, {1, 1}, {1, 2}},\n       {{0, 0}, {-2, -2}, {-2, -2}},   {{0, 0}, {-2, -2}, {-1, -1}},\n       {{0, 0}, {-2, -2}, {1, 1}},     {{0, 0}, {-1, -1}, {-2, -2}},\n       {{0, 0}, {-1, -1}, {-1, -1}},   {{0, 0}, {-1, -1}, {1, 1}},\n       {{0, 0}, {1, 1}, {-2, -2}},     {{0, 0}, {1, 1}, {-1, -1}},\n       {{0, 0}, {1, 1}, {1, 2}},       {{1, 1}, {-2, -2}, {-2, -2}},\n       {{1, 1}, {-2, -2}, {-1, -1}},   {{1, 1}, {-2, -2}, {0, 0}},\n       {{1, 1}, {-2, -2}, {1, 2}},     {{1, 1}, {-1, -1}, {-2, -2}},\n       {{1, 1}, {-1, -1}, {-1, -1}},   {{1, 1}, {-1, -1}, {0, 0}},\n       {{1, 1}, {-1, -1}, {1, 2}},     {{1, 1}, {0, 0}, {-2, -2}},\n       {{1, 1}, {0, 0}, {-1, -1}},     {{1, 1}, {0, 0}, {1, 2}},\n       {{1, 1}, {1, 2}, {-2, -2}},     {{1, 1}, {1, 2}, {-1, -1}},\n       {{1, 1}, {1, 2}, {0, 0}},       {{1, 1}, {1, 2}, {1, 3}}});\n  // 4D-split\n  TestEnumerateNdSbpIn2OutSignatures(\n      /*in_shape*/ {4, 77, 8}, /*out_shape*/ {2, 2, 77, 2, 4}, /*rank_mesh*/ {2, 2, 2, 2},\n      {{{0, pair{0, 0}}},\n       {{0, pair{0, 0}}, {1, pair{0, 1}}},\n       {{0, pair{0, 0}}, {2, pair{0, 1}}},\n       {{0, pair{0, 0}}, {3, pair{0, 1}}},\n       {{0, pair{2, 3}}},\n       {{0, pair{2, 3}}, {1, pair{2, 4}}},\n       {{0, pair{2, 3}}, {1, pair{2, 4}}, {2, pair{2, 4}}},\n       {{0, pair{2, 3}}, {1, pair{2, 4}}, {3, pair{2, 4}}},\n       {{0, pair{2, 3}}, {2, pair{2, 4}}},\n       {{0, pair{2, 3}}, {2, pair{2, 4}}, {3, pair{2, 4}}},\n       {{0, pair{2, 3}}, {3, pair{2, 4}}},\n       {{1, pair{0, 0}}},\n       {{1, pair{0, 0}}, {2, pair{0, 1}}},\n       {{1, pair{0, 0}}, {3, pair{0, 1}}},\n       {{1, pair{2, 3}}},\n       {{1, pair{2, 3}}, {2, pair{2, 4}}},\n       {{1, pair{2, 3}}, {2, pair{2, 4}}, {3, pair{2, 4}}},\n       {{1, pair{2, 3}}, {3, pair{2, 4}}},\n       {{2, pair{0, 0}}},\n       {{2, pair{0, 0}}, {3, pair{0, 1}}},\n       {{2, pair{2, 3}}},\n       {{2, pair{2, 3}}, {3, pair{2, 4}}},\n       {{3, pair{0, 0}}},\n       {{3, pair{2, 3}}}},\n      {{{-2, -2}, {-2, -2}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-2, -2}, {-1, -1}},\n       {{-2, -2}, {-2, -2}, {-2, -2}, {0, 0}},   {{-2, -2}, {-2, -2}, {-2, -2}, {2, 3}},\n       {{-2, -2}, {-2, -2}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-2, -2}, {-1, -1}, {-1, -1}},\n       {{-2, -2}, {-2, -2}, {-1, -1}, {0, 0}},   {{-2, -2}, {-2, -2}, {-1, -1}, {2, 3}},\n       {{-2, -2}, {-2, -2}, {0, 0}, {-2, -2}},   {{-2, -2}, {-2, -2}, {0, 0}, {-1, -1}},\n       {{-2, -2}, {-2, -2}, {0, 0}, {0, 1}},     {{-2, -2}, {-2, -2}, {0, 0}, {2, 3}},\n       {{-2, -2}, {-2, -2}, {2, 3}, {-2, -2}},   {{-2, -2}, {-2, -2}, {2, 3}, {-1, -1}},\n       {{-2, -2}, {-2, -2}, {2, 3}, {0, 0}},     {{-2, -2}, {-2, -2}, {2, 3}, {2, 4}},\n       {{-2, -2}, {-1, -1}, {-2, -2}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-2, -2}, {-1, -1}},\n       {{-2, -2}, {-1, -1}, {-2, -2}, {0, 0}},   {{-2, -2}, {-1, -1}, {-2, -2}, {2, 3}},\n       {{-2, -2}, {-1, 
-1}, {-1, -1}, {-2, -2}}, {{-2, -2}, {-1, -1}, {-1, -1}, {-1, -1}},\n       {{-2, -2}, {-1, -1}, {-1, -1}, {0, 0}},   {{-2, -2}, {-1, -1}, {-1, -1}, {2, 3}},\n       {{-2, -2}, {-1, -1}, {0, 0}, {-2, -2}},   {{-2, -2}, {-1, -1}, {0, 0}, {-1, -1}},\n       {{-2, -2}, {-1, -1}, {0, 0}, {0, 1}},     {{-2, -2}, {-1, -1}, {0, 0}, {2, 3}},\n       {{-2, -2}, {-1, -1}, {2, 3}, {-2, -2}},   {{-2, -2}, {-1, -1}, {2, 3}, {-1, -1}},\n       {{-2, -2}, {-1, -1}, {2, 3}, {0, 0}},     {{-2, -2}, {-1, -1}, {2, 3}, {2, 4}},\n       {{-2, -2}, {0, 0}, {-2, -2}, {-2, -2}},   {{-2, -2}, {0, 0}, {-2, -2}, {-1, -1}},\n       {{-2, -2}, {0, 0}, {-2, -2}, {0, 1}},     {{-2, -2}, {0, 0}, {-2, -2}, {2, 3}},\n       {{-2, -2}, {0, 0}, {-1, -1}, {-2, -2}},   {{-2, -2}, {0, 0}, {-1, -1}, {-1, -1}},\n       {{-2, -2}, {0, 0}, {-1, -1}, {0, 1}},     {{-2, -2}, {0, 0}, {-1, -1}, {2, 3}},\n       {{-2, -2}, {0, 0}, {0, 1}, {-2, -2}},     {{-2, -2}, {0, 0}, {0, 1}, {-1, -1}},\n       {{-2, -2}, {0, 0}, {0, 1}, {2, 3}},       {{-2, -2}, {0, 0}, {2, 3}, {-2, -2}},\n       {{-2, -2}, {0, 0}, {2, 3}, {-1, -1}},     {{-2, -2}, {0, 0}, {2, 3}, {0, 1}},\n       {{-2, -2}, {0, 0}, {2, 3}, {2, 4}},       {{-2, -2}, {2, 3}, {-2, -2}, {-2, -2}},\n       {{-2, -2}, {2, 3}, {-2, -2}, {-1, -1}},   {{-2, -2}, {2, 3}, {-2, -2}, {0, 0}},\n       {{-2, -2}, {2, 3}, {-2, -2}, {2, 4}},     {{-2, -2}, {2, 3}, {-1, -1}, {-2, -2}},\n       {{-2, -2}, {2, 3}, {-1, -1}, {-1, -1}},   {{-2, -2}, {2, 3}, {-1, -1}, {0, 0}},\n       {{-2, -2}, {2, 3}, {-1, -1}, {2, 4}},     {{-2, -2}, {2, 3}, {0, 0}, {-2, -2}},\n       {{-2, -2}, {2, 3}, {0, 0}, {-1, -1}},     {{-2, -2}, {2, 3}, {0, 0}, {0, 1}},\n       {{-2, -2}, {2, 3}, {0, 0}, {2, 4}},       {{-2, -2}, {2, 3}, {2, 4}, {-2, -2}},\n       {{-2, -2}, {2, 3}, {2, 4}, {-1, -1}},     {{-2, -2}, {2, 3}, {2, 4}, {0, 0}},\n       {{-2, -2}, {2, 3}, {2, 4}, {2, 4}},       {{-1, -1}, {-2, -2}, {-2, -2}, {-2, -2}},\n       {{-1, -1}, {-2, -2}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-2, -2}, {-2, -2}, {0, 0}},\n       {{-1, -1}, {-2, -2}, {-2, -2}, {2, 3}},   {{-1, -1}, {-2, -2}, {-1, -1}, {-2, -2}},\n       {{-1, -1}, {-2, -2}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-2, -2}, {-1, -1}, {0, 0}},\n       {{-1, -1}, {-2, -2}, {-1, -1}, {2, 3}},   {{-1, -1}, {-2, -2}, {0, 0}, {-2, -2}},\n       {{-1, -1}, {-2, -2}, {0, 0}, {-1, -1}},   {{-1, -1}, {-2, -2}, {0, 0}, {0, 1}},\n       {{-1, -1}, {-2, -2}, {0, 0}, {2, 3}},     {{-1, -1}, {-2, -2}, {2, 3}, {-2, -2}},\n       {{-1, -1}, {-2, -2}, {2, 3}, {-1, -1}},   {{-1, -1}, {-2, -2}, {2, 3}, {0, 0}},\n       {{-1, -1}, {-2, -2}, {2, 3}, {2, 4}},     {{-1, -1}, {-1, -1}, {-2, -2}, {-2, -2}},\n       {{-1, -1}, {-1, -1}, {-2, -2}, {-1, -1}}, {{-1, -1}, {-1, -1}, {-2, -2}, {0, 0}},\n       {{-1, -1}, {-1, -1}, {-2, -2}, {2, 3}},   {{-1, -1}, {-1, -1}, {-1, -1}, {-2, -2}},\n       {{-1, -1}, {-1, -1}, {-1, -1}, {-1, -1}}, {{-1, -1}, {-1, -1}, {-1, -1}, {0, 0}},\n       {{-1, -1}, {-1, -1}, {-1, -1}, {2, 3}},   {{-1, -1}, {-1, -1}, {0, 0}, {-2, -2}},\n       {{-1, -1}, {-1, -1}, {0, 0}, {-1, -1}},   {{-1, -1}, {-1, -1}, {0, 0}, {0, 1}},\n       {{-1, -1}, {-1, -1}, {0, 0}, {2, 3}},     {{-1, -1}, {-1, -1}, {2, 3}, {-2, -2}},\n       {{-1, -1}, {-1, -1}, {2, 3}, {-1, -1}},   {{-1, -1}, {-1, -1}, {2, 3}, {0, 0}},\n       {{-1, -1}, {-1, -1}, {2, 3}, {2, 4}},     {{-1, -1}, {0, 0}, {-2, -2}, {-2, -2}},\n       {{-1, -1}, {0, 0}, {-2, -2}, {-1, -1}},   {{-1, -1}, {0, 0}, {-2, -2}, {0, 1}},\n       {{-1, -1}, {0, 0}, {-2, -2}, {2, 3}},     {{-1, -1}, {0, 0}, {-1, -1}, {-2, -2}},\n       
{{-1, -1}, {0, 0}, {-1, -1}, {-1, -1}},   {{-1, -1}, {0, 0}, {-1, -1}, {0, 1}},\n       {{-1, -1}, {0, 0}, {-1, -1}, {2, 3}},     {{-1, -1}, {0, 0}, {0, 1}, {-2, -2}},\n       {{-1, -1}, {0, 0}, {0, 1}, {-1, -1}},     {{-1, -1}, {0, 0}, {0, 1}, {2, 3}},\n       {{-1, -1}, {0, 0}, {2, 3}, {-2, -2}},     {{-1, -1}, {0, 0}, {2, 3}, {-1, -1}},\n       {{-1, -1}, {0, 0}, {2, 3}, {0, 1}},       {{-1, -1}, {0, 0}, {2, 3}, {2, 4}},\n       {{-1, -1}, {2, 3}, {-2, -2}, {-2, -2}},   {{-1, -1}, {2, 3}, {-2, -2}, {-1, -1}},\n       {{-1, -1}, {2, 3}, {-2, -2}, {0, 0}},     {{-1, -1}, {2, 3}, {-2, -2}, {2, 4}},\n       {{-1, -1}, {2, 3}, {-1, -1}, {-2, -2}},   {{-1, -1}, {2, 3}, {-1, -1}, {-1, -1}},\n       {{-1, -1}, {2, 3}, {-1, -1}, {0, 0}},     {{-1, -1}, {2, 3}, {-1, -1}, {2, 4}},\n       {{-1, -1}, {2, 3}, {0, 0}, {-2, -2}},     {{-1, -1}, {2, 3}, {0, 0}, {-1, -1}},\n       {{-1, -1}, {2, 3}, {0, 0}, {0, 1}},       {{-1, -1}, {2, 3}, {0, 0}, {2, 4}},\n       {{-1, -1}, {2, 3}, {2, 4}, {-2, -2}},     {{-1, -1}, {2, 3}, {2, 4}, {-1, -1}},\n       {{-1, -1}, {2, 3}, {2, 4}, {0, 0}},       {{-1, -1}, {2, 3}, {2, 4}, {2, 4}},\n       {{0, 0}, {-2, -2}, {-2, -2}, {-2, -2}},   {{0, 0}, {-2, -2}, {-2, -2}, {-1, -1}},\n       {{0, 0}, {-2, -2}, {-2, -2}, {0, 1}},     {{0, 0}, {-2, -2}, {-2, -2}, {2, 3}},\n       {{0, 0}, {-2, -2}, {-1, -1}, {-2, -2}},   {{0, 0}, {-2, -2}, {-1, -1}, {-1, -1}},\n       {{0, 0}, {-2, -2}, {-1, -1}, {0, 1}},     {{0, 0}, {-2, -2}, {-1, -1}, {2, 3}},\n       {{0, 0}, {-2, -2}, {0, 1}, {-2, -2}},     {{0, 0}, {-2, -2}, {0, 1}, {-1, -1}},\n       {{0, 0}, {-2, -2}, {0, 1}, {2, 3}},       {{0, 0}, {-2, -2}, {2, 3}, {-2, -2}},\n       {{0, 0}, {-2, -2}, {2, 3}, {-1, -1}},     {{0, 0}, {-2, -2}, {2, 3}, {0, 1}},\n       {{0, 0}, {-2, -2}, {2, 3}, {2, 4}},       {{0, 0}, {-1, -1}, {-2, -2}, {-2, -2}},\n       {{0, 0}, {-1, -1}, {-2, -2}, {-1, -1}},   {{0, 0}, {-1, -1}, {-2, -2}, {0, 1}},\n       {{0, 0}, {-1, -1}, {-2, -2}, {2, 3}},     {{0, 0}, {-1, -1}, {-1, -1}, {-2, -2}},\n       {{0, 0}, {-1, -1}, {-1, -1}, {-1, -1}},   {{0, 0}, {-1, -1}, {-1, -1}, {0, 1}},\n       {{0, 0}, {-1, -1}, {-1, -1}, {2, 3}},     {{0, 0}, {-1, -1}, {0, 1}, {-2, -2}},\n       {{0, 0}, {-1, -1}, {0, 1}, {-1, -1}},     {{0, 0}, {-1, -1}, {0, 1}, {2, 3}},\n       {{0, 0}, {-1, -1}, {2, 3}, {-2, -2}},     {{0, 0}, {-1, -1}, {2, 3}, {-1, -1}},\n       {{0, 0}, {-1, -1}, {2, 3}, {0, 1}},       {{0, 0}, {-1, -1}, {2, 3}, {2, 4}},\n       {{0, 0}, {0, 1}, {-2, -2}, {-2, -2}},     {{0, 0}, {0, 1}, {-2, -2}, {-1, -1}},\n       {{0, 0}, {0, 1}, {-2, -2}, {2, 3}},       {{0, 0}, {0, 1}, {-1, -1}, {-2, -2}},\n       {{0, 0}, {0, 1}, {-1, -1}, {-1, -1}},     {{0, 0}, {0, 1}, {-1, -1}, {2, 3}},\n       {{0, 0}, {0, 1}, {2, 3}, {-2, -2}},       {{0, 0}, {0, 1}, {2, 3}, {-1, -1}},\n       {{0, 0}, {0, 1}, {2, 3}, {2, 4}},         {{0, 0}, {2, 3}, {-2, -2}, {-2, -2}},\n       {{0, 0}, {2, 3}, {-2, -2}, {-1, -1}},     {{0, 0}, {2, 3}, {-2, -2}, {0, 1}},\n       {{0, 0}, {2, 3}, {-2, -2}, {2, 4}},       {{0, 0}, {2, 3}, {-1, -1}, {-2, -2}},\n       {{0, 0}, {2, 3}, {-1, -1}, {-1, -1}},     {{0, 0}, {2, 3}, {-1, -1}, {0, 1}},\n       {{0, 0}, {2, 3}, {-1, -1}, {2, 4}},       {{0, 0}, {2, 3}, {0, 1}, {-2, -2}},\n       {{0, 0}, {2, 3}, {0, 1}, {-1, -1}},       {{0, 0}, {2, 3}, {0, 1}, {2, 4}},\n       {{0, 0}, {2, 3}, {2, 4}, {-2, -2}},       {{0, 0}, {2, 3}, {2, 4}, {-1, -1}},\n       {{0, 0}, {2, 3}, {2, 4}, {0, 1}},         {{0, 0}, {2, 3}, {2, 4}, {2, 4}},\n       {{2, 3}, {-2, -2}, {-2, -2}, {-2, -2}},   {{2, 3}, {-2, 
-2}, {-2, -2}, {-1, -1}},\n       {{2, 3}, {-2, -2}, {-2, -2}, {0, 0}},     {{2, 3}, {-2, -2}, {-2, -2}, {2, 4}},\n       {{2, 3}, {-2, -2}, {-1, -1}, {-2, -2}},   {{2, 3}, {-2, -2}, {-1, -1}, {-1, -1}},\n       {{2, 3}, {-2, -2}, {-1, -1}, {0, 0}},     {{2, 3}, {-2, -2}, {-1, -1}, {2, 4}},\n       {{2, 3}, {-2, -2}, {0, 0}, {-2, -2}},     {{2, 3}, {-2, -2}, {0, 0}, {-1, -1}},\n       {{2, 3}, {-2, -2}, {0, 0}, {0, 1}},       {{2, 3}, {-2, -2}, {0, 0}, {2, 4}},\n       {{2, 3}, {-2, -2}, {2, 4}, {-2, -2}},     {{2, 3}, {-2, -2}, {2, 4}, {-1, -1}},\n       {{2, 3}, {-2, -2}, {2, 4}, {0, 0}},       {{2, 3}, {-2, -2}, {2, 4}, {2, 4}},\n       {{2, 3}, {-1, -1}, {-2, -2}, {-2, -2}},   {{2, 3}, {-1, -1}, {-2, -2}, {-1, -1}},\n       {{2, 3}, {-1, -1}, {-2, -2}, {0, 0}},     {{2, 3}, {-1, -1}, {-2, -2}, {2, 4}},\n       {{2, 3}, {-1, -1}, {-1, -1}, {-2, -2}},   {{2, 3}, {-1, -1}, {-1, -1}, {-1, -1}},\n       {{2, 3}, {-1, -1}, {-1, -1}, {0, 0}},     {{2, 3}, {-1, -1}, {-1, -1}, {2, 4}},\n       {{2, 3}, {-1, -1}, {0, 0}, {-2, -2}},     {{2, 3}, {-1, -1}, {0, 0}, {-1, -1}},\n       {{2, 3}, {-1, -1}, {0, 0}, {0, 1}},       {{2, 3}, {-1, -1}, {0, 0}, {2, 4}},\n       {{2, 3}, {-1, -1}, {2, 4}, {-2, -2}},     {{2, 3}, {-1, -1}, {2, 4}, {-1, -1}},\n       {{2, 3}, {-1, -1}, {2, 4}, {0, 0}},       {{2, 3}, {-1, -1}, {2, 4}, {2, 4}},\n       {{2, 3}, {0, 0}, {-2, -2}, {-2, -2}},     {{2, 3}, {0, 0}, {-2, -2}, {-1, -1}},\n       {{2, 3}, {0, 0}, {-2, -2}, {0, 1}},       {{2, 3}, {0, 0}, {-2, -2}, {2, 4}},\n       {{2, 3}, {0, 0}, {-1, -1}, {-2, -2}},     {{2, 3}, {0, 0}, {-1, -1}, {-1, -1}},\n       {{2, 3}, {0, 0}, {-1, -1}, {0, 1}},       {{2, 3}, {0, 0}, {-1, -1}, {2, 4}},\n       {{2, 3}, {0, 0}, {0, 1}, {-2, -2}},       {{2, 3}, {0, 0}, {0, 1}, {-1, -1}},\n       {{2, 3}, {0, 0}, {0, 1}, {2, 4}},         {{2, 3}, {0, 0}, {2, 4}, {-2, -2}},\n       {{2, 3}, {0, 0}, {2, 4}, {-1, -1}},       {{2, 3}, {0, 0}, {2, 4}, {0, 1}},\n       {{2, 3}, {0, 0}, {2, 4}, {2, 4}},         {{2, 3}, {2, 4}, {-2, -2}, {-2, -2}},\n       {{2, 3}, {2, 4}, {-2, -2}, {-1, -1}},     {{2, 3}, {2, 4}, {-2, -2}, {0, 0}},\n       {{2, 3}, {2, 4}, {-2, -2}, {2, 4}},       {{2, 3}, {2, 4}, {-1, -1}, {-2, -2}},\n       {{2, 3}, {2, 4}, {-1, -1}, {-1, -1}},     {{2, 3}, {2, 4}, {-1, -1}, {0, 0}},\n       {{2, 3}, {2, 4}, {-1, -1}, {2, 4}},       {{2, 3}, {2, 4}, {0, 0}, {-2, -2}},\n       {{2, 3}, {2, 4}, {0, 0}, {-1, -1}},       {{2, 3}, {2, 4}, {0, 0}, {0, 1}},\n       {{2, 3}, {2, 4}, {0, 0}, {2, 4}},         {{2, 3}, {2, 4}, {2, 4}, {-2, -2}},\n       {{2, 3}, {2, 4}, {2, 4}, {-1, -1}},       {{2, 3}, {2, 4}, {2, 4}, {0, 0}}});\n  // clang-format on\n}\n\n}  // namespace test\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/rms_norm_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> RmsNormOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& normalized_shape = ctx->Attr<Shape>(\"normalized_shape\");\n  if (ctx->has_input(\"weight\", 0)) {\n    const Shape& w_shape = ctx->InputShape(\"weight\", 0);\n    CHECK_EQ_OR_RETURN(w_shape, normalized_shape)\n        << \"expected weight shape \" << normalized_shape.ToString() << \", got \"\n        << w_shape.ToString();\n  }\n  CHECK_LE_OR_RETURN(normalized_shape.size(), x_shape.size())\n      << \"invalid normalized shape \" << normalized_shape.ToString() << \" with input shape \"\n      << x_shape.ToString();\n  size_t batch_ndim = x_shape.size() - normalized_shape.size();\n  DimVector batch_dims(batch_ndim);\n  for (int i = 0; i < x_shape.size(); ++i) {\n    if (i < batch_ndim) {\n      batch_dims[i] = x_shape[i];\n    } else {\n      CHECK_EQ_OR_RETURN(normalized_shape[i - batch_ndim], x_shape[i])\n          << \"invalid normalized shape \" << normalized_shape.ToString() << \" with input shape \"\n          << x_shape.ToString();\n    }\n  }\n  ctx->SetOutputShape(\"y\", 0, x_shape);\n  ctx->SetOutputShape(\"inv_rms\", 0, Shape{batch_dims});\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> RmsNormOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> RmsNormOp::InferDataType(user_op::InferContext* ctx) {\n  DataType x_dtype = ctx->InputDType(\"x\", 0);\n  if (ctx->has_input(\"weight\", 0)) {\n    DataType w_dtype = ctx->InputDType(\"weight\", 0);\n    CHECK_EQ_OR_RETURN(w_dtype, x_dtype)\n        << \"RmsNormOp \" << ctx->op_name() << \" has different input dtype \" << DataType_Name(x_dtype)\n        << \" and param dtype \" << DataType_Name(w_dtype);\n  }\n  ctx->SetOutputDType(\"y\", 0, x_dtype);\n\n  DataType rms_dtype = x_dtype;\n  if (x_dtype == DataType::kFloat16 || x_dtype == DataType::kBFloat16) {\n    rms_dtype = DataType::kFloat;\n  }\n  ctx->SetOutputDType(\"inv_rms\", 0, rms_dtype);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> RmsNormOp::GetSbp(user_op::SbpContext* ctx) {\n  const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape();\n  const Shape& normalized_shape = ctx->Attr<Shape>(\"normalized_shape\");\n  size_t batch_ndim = x_shape.size() - normalized_shape.size();\n  for (int i = 0; i < batch_ndim; ++i) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Broadcast(user_op::OpArg(\"weight\", 0))\n        .Split(ctx->outputs(), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> RmsNormGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& shape = ctx->InputShape(\"dy\", 
0);\n  CHECK_EQ_OR_RETURN(ctx->InputShape(\"x\", 0), shape);  // NOLINT(maybe-need-error-msg)\n  // No need to check weight and inv_rms legality which should be guaranteed by forward op\n  ctx->SetOutputShape(\"dx\", 0, shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> RmsNormGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> RmsNormGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> RmsNormGradOp::GetSbp(user_op::SbpContext* ctx) {\n  std::vector<user_op::OpArg> split_args = {user_op::OpArg(\"dy\", 0), user_op::OpArg(\"x\", 0),\n                                            user_op::OpArg(\"inv_rms\", 0)};\n  std::vector<user_op::OpArg> broadcast_args;\n  if (ctx->user_op_conf().has_input(\"weight\", 0)) { broadcast_args.emplace_back(\"weight\", 0); }\n  const Shape& b_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"inv_rms\", 0).shape();\n  for (int i = 0; i < b_shape.size(); ++i) {\n    ctx->NewBuilder()\n        .Split(split_args, i)\n        .Broadcast(broadcast_args)\n        .Split(ctx->outputs(), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> RmsNormParamGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& shape = ctx->InputShape(\"dy\", 0);\n  CHECK_EQ_OR_RETURN(ctx->InputShape(\"x\", 0), shape);  // NOLINT(maybe-need-error-msg)\n  const Shape& b_shape = ctx->InputShape(\"inv_rms\", 0);\n\n  CHECK_LE_OR_RETURN(b_shape.size(), shape.size())\n      << \"invalid inv_rms shape \" << b_shape.ToString() << \" with dy shape \" << shape.ToString();\n  size_t n_ndim = shape.size() - b_shape.size();\n  DimVector n_shape_vec(n_ndim);\n  for (int i = 0; i < shape.size(); ++i) {\n    if (i < b_shape.size()) {\n      CHECK_EQ_OR_RETURN(b_shape[i], shape[i]) << \"invalid inv_rms shape \" << b_shape.ToString()\n                                               << \" with dy shape \" << shape.ToString();\n    } else {\n      n_shape_vec[i - b_shape.size()] = shape[i];\n    }\n  }\n  ctx->SetOutputShape(\"weight_grad\", 0, Shape{n_shape_vec});\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> RmsNormParamGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> RmsNormParamGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"weight_grad\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> RmsNormParamGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const Shape& b_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"inv_rms\", 0).shape();\n  for (int i = 0; i < b_shape.size(); ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(ctx->outputs()).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/roc_auc_score_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> RocAucScoreOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  const Shape& pred_shape = ctx->InputTensorDesc(\"pred\", 0).shape();\n  const Shape& label_shape = ctx->InputTensorDesc(\"label\", 0).shape();\n  CHECK_EQ_OR_RETURN(pred_shape.elem_cnt(), label_shape.elem_cnt())\n      << \"pred and label MUST have same element count.\";\n  out_desc->set_is_dynamic(false);\n  out_desc->set_shape(Shape({1}));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> RocAucScoreOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> RocAucScoreOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n\n/* static */ Maybe<void> RocAucScoreOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, DataType::kFloat);\n  const user_op::TensorDesc& label = ctx->InputTensorDesc(\"label\", 0);\n  CHECK_OR_RETURN(IsFloatingDataType(label.data_type()) || IsIntegralDataType(label.data_type()))\n      << \"Input `label` data type \" << DataType_Name(label.data_type()) << \" is not supported.\";\n  const user_op::TensorDesc& pred = ctx->InputTensorDesc(\"pred\", 0);\n  CHECK_OR_RETURN(pred.data_type() == DataType::kFloat)\n      << \"Input `pred` data type \" << DataType_Name(pred.data_type()) << \" is not supported.\";\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/roi_align_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> RoiAlignOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"x\", 0))\n      .Split(user_op::OpArg(\"rois\", 0), 0)\n      .Split(user_op::OpArg(\"y\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> RoiAlignOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& rois_shape = ctx->InputShape(\"rois\", 0);\n  const int32_t pooled_h = ctx->Attr<int32_t>(\"pooled_h\");\n  const int32_t pooled_w = ctx->Attr<int32_t>(\"pooled_w\");\n  // x: feature map (N, C, H, W)\n  CHECK_EQ_OR_RETURN(x_shape.NumAxes(), 4)\n      << Error::RuntimeError() << \"The dimension of x tensor must be equal to 4, \"\n      << \"but got \" << x_shape.NumAxes();\n  // rois: (R, 5)\n  CHECK_EQ_OR_RETURN(rois_shape.NumAxes(), 2)\n      << Error::RuntimeError() << \"The dimension of rois tensor must be equal to 2, \"\n      << \"but got \" << rois_shape.NumAxes();\n  CHECK_EQ_OR_RETURN(rois_shape.At(1), 5)\n      << Error::RuntimeError() << \"The size of rois tensor must be equal to 5 at dimension 1, \"\n      << \"but got \" << rois_shape.At(1);\n  // y: (R, C, pool_h, pool_w)\n  ctx->SetOutputShape(\"y\", 0, Shape({rois_shape.At(0), x_shape.At(1), pooled_h, pooled_w}));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> RoiAlignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> RoiAlignOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> RoiAlignOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn,\n                                                  const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* roi_modifier = GetInputArgModifierFn(\"rois\", 0);\n  CHECK_OR_RETURN(roi_modifier != nullptr);  // NOLINT(maybe-need-error-msg)\n  roi_modifier->set_requires_grad(false);\n  user_op::InputArgModifier* feat_modifier = GetInputArgModifierFn(\"x\", 0);\n  CHECK_OR_RETURN(feat_modifier != nullptr);  //  NOLINT(maybe-need-error-msg)\n  feat_modifier->set_requires_grad(true);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> RoiAlignGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Broadcast(user_op::OpArg(\"x_like\", 0))\n      .Split(user_op::OpArg(\"rois\", 0), 0)\n      .Broadcast(user_op::OpArg(\"dx\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> RoiAlignGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  const Shape& x_like_shape = 
ctx->InputShape(\"x_like\", 0);\n  const Shape& rois_shape = ctx->InputShape(\"rois\", 0);\n  const int32_t pooled_h = ctx->Attr<int32_t>(\"pooled_h\");\n  const int32_t pooled_w = ctx->Attr<int32_t>(\"pooled_w\");\n  // x: feature map (N, C, H, W)\n  CHECK_EQ_OR_RETURN(x_like_shape.NumAxes(), 4)\n      << Error::RuntimeError() << \"The dimension of x_like tensor must be equal to 4, \"\n      << \"but got \" << x_like_shape.NumAxes();\n\n  // rois: (R, 5)\n  CHECK_EQ_OR_RETURN(rois_shape.NumAxes(), 2)\n      << Error::RuntimeError() << \"The dimension of rois tensor must be equal to 2, \"\n      << \"but got \" << rois_shape.NumAxes();\n  CHECK_EQ_OR_RETURN(rois_shape.At(1), 5)\n      << Error::RuntimeError() << \"The size of rois tensor must be equal to 5 \"\n      << \"at dimension 1, \"\n      << \"but got \" << rois_shape.At(1);\n  // y: (R, C, pool_h, pool_w)\n  const Shape& y_shape = Shape({rois_shape.At(0), x_like_shape.At(1), pooled_h, pooled_w});\n  CHECK_EQ_OR_RETURN(y_shape, dy_shape)\n      << Error::RuntimeError() << \"Tensors y and dy must have same shape\";\n  ctx->SetOutputShape(\"dx\", 0, x_like_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> RoiAlignGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> RoiAlignGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"dy\", 0), ctx->InputDType(\"x_like\", 0))\n      << Error::TypeError() << \"The dy tensor and x_like tensor must have same type\";\n\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x_like\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/roll_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> RollOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  const std::vector<int32_t>& dims = ctx->Attr<std::vector<int32_t>>(\"dims\");\n\n  CHECK_GT_OR_RETURN(dims.size(), 0)\n      << Error::RuntimeError() << \"The input list of dims doesn't allow to be empty\";\n  // NOTE(Liang Depeng): (dims.size == 1 && dims[0] == -1) means that user call flow.roll with\n  // dims == None\n  if (dims[0] != -1) {\n    FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n      if (std::find(dims.begin(), dims.end(), i) == dims.end()) {\n        ctx->NewBuilder()\n            .Split(user_op::OpArg(\"in\", 0), i)\n            .Split(user_op::OpArg(\"out\", 0), i)\n            .Build();\n      }\n    }\n  }\n\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> RollOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"in\", 0);\n  ctx->SetOutputShape(\"out\", 0, in_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> RollOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> RollOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/rrelu_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> RReluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"in\", 0);\n  ctx->SetOutputShape(\"output\", 0, in_shape);\n  ctx->SetOutputShape(\"noise_data\", 0, in_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> RReluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> RReluOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> RReluOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"output\", 0, ctx->InputDType(\"in\", 0));\n  ctx->SetOutputDType(\"noise_data\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/same_padding_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/common/balanced_splitter.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> SamePaddingOp::GetSbp(user_op::SbpContext* ctx) {\n  const int32_t num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x_like\", 0).shape().NumAxes();\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), 0).Split(user_op::OpArg(\"y\", 0), 0).Build();\n  const int32_t channel_idx = ChannelIdx(data_format, num_axes);\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"x\", 0), channel_idx)\n      .Split(user_op::OpArg(\"y\", 0), channel_idx)\n      .Build();\n  ctx->NewBuilder().PartialSum(user_op::OpArg(\"x\", 0)).PartialSum(user_op::OpArg(\"y\", 0)).Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SamePaddingOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc(\"y\", 0);\n  y_desc->set_shape(x_desc.shape());\n  y_desc->set_is_dynamic(x_desc.is_dynamic());\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  const auto& kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n  const auto& strides = ctx->Attr<std::vector<int32_t>>(\"strides\");\n  const auto& dilation_rate = ctx->Attr<std::vector<int32_t>>(\"dilation_rate\");\n  const size_t idx_offset = IdxOffset(data_format);\n  const int32_t num_spatial_dims = x_desc.shape().NumAxes() - 2;\n  CHECK_EQ_OR_RETURN(num_spatial_dims, kernel_size.size())\n      << Error::RuntimeError()\n      << \"The dimension of x tensor must be equal to the size of kernel_size array plus 2, \"\n      << \"but got \" << num_spatial_dims << \" and \" << kernel_size.size();\n  CHECK_EQ_OR_RETURN(num_spatial_dims, strides.size())\n      << Error::RuntimeError()\n      << \"The dimension of x tensor must be equal to the size of strides array plus 2, \"\n      << \"but got \" << num_spatial_dims << \" and \" << strides.size();\n  CHECK_EQ_OR_RETURN(num_spatial_dims, dilation_rate.size())\n      << Error::RuntimeError()\n      << \"The dimension of x tensor must be equal to the size of dilation_rate array plus 2, \"\n      << \"but got \" << num_spatial_dims << \" and \" << dilation_rate.size();\n  DimVector y_dim_vec(x_desc.shape().dim_vec());\n  for (int32_t i = 0; i < num_spatial_dims; ++i) {\n    int32_t padding_small = 0;\n    int32_t padding_large = 0;\n    JUST(CalcSamePadding(x_desc.shape().At(idx_offset + i), kernel_size.at(i), dilation_rate.at(i),\n                         strides.at(i), &padding_small, &padding_large));\n    y_dim_vec[idx_offset + i] = x_desc.shape().At(idx_offset + i) + padding_small + padding_large;\n  }\n  
y_desc->set_shape(Shape(y_dim_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SamePaddingOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SamePaddingOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SamePaddingGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const int32_t num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x_like\", 0).shape().NumAxes();\n  const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"x_like\", 0), 0)\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Build();\n  const int32_t channel_idx = ChannelIdx(data_format, num_axes);\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"x_like\", 0), channel_idx)\n      .Split(user_op::OpArg(\"dy\", 0), channel_idx)\n      .Split(user_op::OpArg(\"dx\", 0), channel_idx)\n      .Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"x_like\", 0))\n      .PartialSum(user_op::OpArg(\"dy\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"x_like\", 0))\n      .PartialSum(user_op::OpArg(\"dy\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"x_like\", 0))\n      .Broadcast(user_op::OpArg(\"dy\", 0))\n      .Broadcast(user_op::OpArg(\"dx\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SamePaddingGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"x_like\", 0));\n  ctx->SetOutputIsDynamic(\"dx\", 0, ctx->InputIsDynamic(\"x_like\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SamePaddingGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SamePaddingGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x_like\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/scalar_bitwise_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n#define IMPLEMENT_SCALAR_BITWISE_OP_FUNCS(name)                                                  \\\n  /*static*/ Maybe<void> name##Op::GetSbp(user_op::SbpContext* ctx) {                            \\\n    const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0); \\\n    FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {                                      \\\n      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();                \\\n    }                                                                                            \\\n    return Maybe<void>::Ok();                                                                    \\\n  }                                                                                              \\\n  /*static*/ Maybe<void> name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) {          \\\n    ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));                                     \\\n    ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));                             \\\n    return Maybe<void>::Ok();                                                                    \\\n  }                                                                                              \\\n  /*static*/ Maybe<void> name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) {         \\\n    return InferLogicalTensorDesc(ctx);                                                          \\\n  }                                                                                              \\\n  /*static*/ Maybe<void> name##Op::InferDataType(user_op::InferContext* ctx) {                   \\\n    ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));                                     \\\n    return Maybe<void>::Ok();                                                                    \\\n  }\n\nIMPLEMENT_SCALAR_BITWISE_OP_FUNCS(ScalarBitwiseAnd);\nIMPLEMENT_SCALAR_BITWISE_OP_FUNCS(ScalarBitwiseOr);\nIMPLEMENT_SCALAR_BITWISE_OP_FUNCS(ScalarBitwiseXor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/scalar_by_tensor_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> TensorDescInferFn(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& scalar = ctx->InputTensorDesc(\"scalar\", 0);\n  CHECK_EQ_OR_RETURN(scalar.shape().elem_cnt(), 1)\n      << Error::RuntimeError() << \"The input scalar tensor is not a scalar\";\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  y->set_shape(x.shape());\n  y->set_is_dynamic(x.is_dynamic());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> DataTypeInferFn(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& scalar = ctx->InputTensorDesc(\"scalar\", 0);\n  CHECK_EQ_OR_RETURN(x.data_type(), scalar.data_type())\n      << Error::TypeError() << \"Tensors x and scalar have different type\";\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  y->set_data_type(x.data_type());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetBasicSbpSignature(user_op::SbpContext* ctx) {\n  const auto& x = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"y\", 0), i)\n        .Broadcast(user_op::OpArg(\"scalar\", 0))\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nusing GetSbpFn = std::function<Maybe<void>(user_op::SbpContext*)>;\nGetSbpFn MakeGetSbpFn(GetSbpFn extra) {\n  return [extra](user_op::SbpContext* ctx) -> Maybe<void> {\n    JUST(extra(ctx));\n    JUST(GetBasicSbpSignature(ctx));\n    return Maybe<void>::Ok();\n  };\n}\n\n}  // namespace\n\n/*static*/ Maybe<void> ScalarAddByTensorOp::GetSbp(user_op::SbpContext* ctx) {\n  return MakeGetSbpFn([](user_op::SbpContext* ctx) {\n    ctx->NewBuilder()\n        .PartialSum(user_op::OpArg(\"x\", 0))\n        .PartialSum(user_op::OpArg(\"scalar\", 0))\n        .PartialSum(user_op::OpArg(\"y\", 0))\n        .Build();\n    return Maybe<void>::Ok();\n  })(ctx);\n}\n/*static*/ Maybe<void> ScalarAddByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return TensorDescInferFn(ctx);\n}\n/*static*/ Maybe<void> ScalarAddByTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ScalarAddByTensorOp::InferDataType(user_op::InferContext* ctx) {\n  return DataTypeInferFn(ctx);\n}\n\n/*static*/ Maybe<void> HostScalarAddByTensorOp::GetSbp(user_op::SbpContext* ctx) {\n  return MakeGetSbpFn([](user_op::SbpContext* ctx) {\n    ctx->NewBuilder()\n        .PartialSum(user_op::OpArg(\"x\", 0))\n        .PartialSum(user_op::OpArg(\"scalar\", 0))\n        .PartialSum(user_op::OpArg(\"y\", 0))\n        .Build();\n    return 
Maybe<void>::Ok();\n  })(ctx);\n}\n/*static*/ Maybe<void> HostScalarAddByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return TensorDescInferFn(ctx);\n}\n/*static*/ Maybe<void> HostScalarAddByTensorOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> HostScalarAddByTensorOp::InferDataType(user_op::InferContext* ctx) {\n  return DataTypeInferFn(ctx);\n}\n\nREGISTER_OP_HOST_MEMORY_INPUT(\"host_scalar_add_by_tensor\", \"scalar\", 0);\n\n/*static*/ Maybe<void> ScalarSubByTensorOp::GetSbp(user_op::SbpContext* ctx) {\n  return MakeGetSbpFn([](user_op::SbpContext* ctx) {\n    ctx->NewBuilder()\n        .PartialSum(user_op::OpArg(\"x\", 0))\n        .PartialSum(user_op::OpArg(\"scalar\", 0))\n        .PartialSum(user_op::OpArg(\"y\", 0))\n        .Build();\n    return Maybe<void>::Ok();\n  })(ctx);\n}\n/*static*/ Maybe<void> ScalarSubByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return TensorDescInferFn(ctx);\n}\n/*static*/ Maybe<void> ScalarSubByTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ScalarSubByTensorOp::InferDataType(user_op::InferContext* ctx) {\n  return DataTypeInferFn(ctx);\n}\n\n/*static*/ Maybe<void> ScalarMulByTensorOp::GetSbp(user_op::SbpContext* ctx) {\n  return MakeGetSbpFn([](user_op::SbpContext* ctx) {\n    ctx->NewBuilder()\n        .PartialSum(user_op::OpArg(\"x\", 0))\n        .Broadcast(user_op::OpArg(\"scalar\", 0))\n        .PartialSum(user_op::OpArg(\"y\", 0))\n        .Build();\n    ctx->NewBuilder()\n        .Broadcast(user_op::OpArg(\"x\", 0))\n        .PartialSum(user_op::OpArg(\"scalar\", 0))\n        .PartialSum(user_op::OpArg(\"y\", 0))\n        .Build();\n    return Maybe<void>::Ok();\n  })(ctx);\n}\n/*static*/ Maybe<void> ScalarMulByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return TensorDescInferFn(ctx);\n}\n/*static*/ Maybe<void> ScalarMulByTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ScalarMulByTensorOp::InferDataType(user_op::InferContext* ctx) {\n  return DataTypeInferFn(ctx);\n}\n\n/*static*/ Maybe<void> ScalarDivByTensorOp::GetSbp(user_op::SbpContext* ctx) {\n  return MakeGetSbpFn([](user_op::SbpContext* ctx) {\n    ctx->NewBuilder()\n        .PartialSum(user_op::OpArg(\"x\", 0))\n        .Broadcast(user_op::OpArg(\"scalar\", 0))\n        .PartialSum(user_op::OpArg(\"y\", 0))\n        .Build();\n    return Maybe<void>::Ok();\n  })(ctx);\n}\n/*static*/ Maybe<void> ScalarDivByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return TensorDescInferFn(ctx);\n}\n/*static*/ Maybe<void> ScalarDivByTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ScalarDivByTensorOp::InferDataType(user_op::InferContext* ctx) {\n  return DataTypeInferFn(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/scalar_logical_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n#define IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(name)                                                  \\\n  /*static*/ Maybe<void> name##Op::GetSbp(user_op::SbpContext* ctx) {                            \\\n    const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0); \\\n    FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {                                      \\\n      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();                \\\n    }                                                                                            \\\n    return Maybe<void>::Ok();                                                                    \\\n  }                                                                                              \\\n  /*static*/ Maybe<void> name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) {          \\\n    ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));                                     \\\n    ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));                             \\\n    return Maybe<void>::Ok();                                                                    \\\n  }                                                                                              \\\n  /*static*/ Maybe<void> name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) {         \\\n    return InferLogicalTensorDesc(ctx);                                                          \\\n  }                                                                                              \\\n  /*static*/ Maybe<void> name##Op::InferDataType(user_op::InferContext* ctx) {                   \\\n    ctx->SetOutputDType(\"out\", 0, DataType::kBool);                                              \\\n    return Maybe<void>::Ok();                                                                    \\\n  }\n\nIMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalEqual);\nIMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalNotEqual);\nIMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalGreater);\nIMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalGreaterEqual);\nIMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalLess);\nIMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalLessEqual);\nIMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalAnd);\nIMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalOr);\nIMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalXor);\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/scalar_math_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> GetSbp4ScalarMath(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetSbp4ScalarMul(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n#define IMPLEMENT_SCALAR_MATH_OP_FUNCS(op_name, get_sbp_fn)                                        \\\n  /*static*/ Maybe<void> op_name##Op::GetSbp(user_op::SbpContext* ctx) { return get_sbp_fn(ctx); } \\\n  /*static*/ Maybe<void> op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) {         \\\n    ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));                                       \\\n    ctx->SetOutputIsDynamic(\"out\", 0, ctx->InputIsDynamic(\"in\", 0));                               \\\n    return Maybe<void>::Ok();                                                                      \\\n  }                                                                                                \\\n  /*static*/ Maybe<void> op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) {        \\\n    return InferLogicalTensorDesc(ctx);                                                            \\\n  }                                                                                                \\\n  /*static*/ Maybe<void> op_name##Op::InferDataType(user_op::InferContext* ctx) {                  \\\n    ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));                                       \\\n    return Maybe<void>::Ok();                                                                      \\\n  }\n\nIMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarAdd, GetSbp4ScalarMath)\nIMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarFloordiv, GetSbp4ScalarMath)\nIMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarTruncdiv, GetSbp4ScalarMath)\nIMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarFmod, GetSbp4ScalarMath)\nIMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarMul, GetSbp4ScalarMul)\nIMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarDiv, GetSbp4ScalarMul)\nIMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarPow, GetSbp4ScalarMath)\nIMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarReversePow, GetSbp4ScalarMath)\n#undef IMPLEMENT_SCALAR_MATH_OP_FUNCS\n\n/*static*/ Maybe<void> ScalarPowGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& 
x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ScalarPowGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ScalarPowGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ScalarPowGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"x\", 0), ctx->InputDType(\"dy\", 0))\n      << Error::TypeError() << \"Tensors dy and x must have same type\";\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> ScalarReversePowGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ScalarReversePowGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ScalarReversePowGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ScalarReversePowGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"x\", 0), ctx->InputDType(\"dy\", 0))\n      << Error::TypeError() << \"Tensors dy and x must have same type\";\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/scaled_dot_product_flash_attention_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/common/data_type.pb.h\"\n#include \"oneflow/core/common/just.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nMaybe<void> ScaledDotProductFlashAttentionOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& q_shape = ctx->InputShape(\"query\", 0);\n  const Shape& k_shape = ctx->InputShape(\"key\", 0);\n  const Shape& v_shape = ctx->InputShape(\"value\", 0);\n\n  auto batch_size = q_shape.At(0);\n  auto seqlen_q = q_shape.At(1);\n  auto num_heads = q_shape.At(2);\n  auto head_size_og = q_shape.At(3);\n  auto seqlen_k = k_shape.At(1);\n  auto num_heads_k = k_shape.At(2);\n\n  // check input tensor shape.\n  CHECK_EQ_OR_RETURN(batch_size, k_shape.At(0)) << \"query has different batch size from key.\";\n  CHECK_EQ_OR_RETURN(batch_size, v_shape.At(0)) << \"query has different batch size from value.\";\n\n  CHECK_EQ_OR_RETURN(seqlen_k, v_shape.At(1)) << \"key has different seqlen from value.\";\n  CHECK_EQ_OR_RETURN(num_heads_k, v_shape.At(2)) << \"key has different num_heads from value.\";\n\n  CHECK_EQ_OR_RETURN(head_size_og, k_shape.At(3)) << \"query has different head_size from key\";\n  CHECK_EQ_OR_RETURN(head_size_og, v_shape.At(3)) << \"query has different head_size from value\";\n\n  // batch size must be positive.\n  CHECK_GT_OR_RETURN(batch_size, 0) << \"batch size must be positive\";\n\n  // only support head dimensions at most 256.\n  CHECK_LE_OR_RETURN(head_size_og, 256) << \"only support head dimensions at most 256\";\n\n  // number of heads in key/value must devide number of heads in query.\n  CHECK_EQ_OR_RETURN(num_heads % num_heads_k, 0)\n      << \"number of heads in key/value must devide number of heads in query.\";\n\n  ctx->SetOutputShape(\"out\", 0, Shape({batch_size, seqlen_q, num_heads, head_size_og}));\n  // save for backward\n  ctx->SetOutputShape(\"softmax_lse\", 0, Shape({batch_size, num_heads, seqlen_q}));\n  // save seed and offset for backward.\n  ctx->SetOutputShape(\"rng_state\", 0, Shape({2}));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ScaledDotProductFlashAttentionOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return ScaledDotProductFlashAttentionOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> ScaledDotProductFlashAttentionOp::GetSbp(user_op::SbpContext* ctx) {\n  auto parallel_num = ctx->parallel_num();\n  const Shape& q_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"query\", 0).shape();\n  const Shape& k_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"key\", 0).shape();\n  auto num_heads = q_shape.At(2);\n  auto num_heads_k = k_shape.At(2);\n  bool can_spilt_num_heads =\n      num_heads == num_heads_k || (!(num_heads % parallel_num) && !(num_heads_k % parallel_num));\n  if (can_spilt_num_heads) {\n    // 
prior to split on num_heads.\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"query\", 0), 2)\n        .Split(user_op::OpArg(\"key\", 0), 2)\n        .Split(user_op::OpArg(\"value\", 0), 2)\n        .Split(user_op::OpArg(\"out\", 0), 2)\n        .Split(user_op::OpArg(\"softmax\", 0), 1)\n        .Broadcast(user_op::OpArg(\"rng_state\", 0))\n        .Build();\n  } else {\n    // otherwise split on batch_size.\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"query\", 0), 0)\n        .Split(user_op::OpArg(\"key\", 0), 0)\n        .Split(user_op::OpArg(\"value\", 0), 0)\n        .Split(user_op::OpArg(\"out\", 0), 0)\n        .Split(user_op::OpArg(\"softmax\", 0), 0)\n        .Broadcast(user_op::OpArg(\"rng_state\", 0))\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ScaledDotProductFlashAttentionOp::InferDataType(user_op::InferContext* ctx) {\n  auto q_datatype = ctx->InputDType(\"query\", 0);\n  auto k_datatype = ctx->InputDType(\"key\", 0);\n  auto v_datatype = ctx->InputDType(\"value\", 0);\n\n  CHECK_EQ_OR_RETURN(q_datatype, k_datatype) << \"query has different data type from key.\";\n  CHECK_EQ_OR_RETURN(q_datatype, v_datatype) << \"query has different data type from value.\";\n\n  ctx->SetOutputDType(\"out\", 0, q_datatype);\n  ctx->SetOutputDType(\"softmax_lse\", 0, DataType::kFloat);\n  ctx->SetOutputDType(\"rng_state\", 0, DataType::kUInt64);\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ScaledDotProductFlashAttentionGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const Shape& dout_shape = ctx->InputShape(\"grad_out\", 0);\n  const Shape& q_shape = ctx->InputShape(\"query\", 0);\n  const Shape& k_shape = ctx->InputShape(\"key\", 0);\n  const Shape& v_shape = ctx->InputShape(\"value\", 0);\n  const Shape& out_shape = ctx->InputShape(\"out\", 0);\n  const Shape& softmax_lse_shape = ctx->InputShape(\"softmax_lse\", 0);\n\n  auto batch_size = q_shape.At(0);\n  auto seqlen_q = q_shape.At(1);\n  auto num_heads = q_shape.At(2);\n  auto head_size = q_shape.At(3);\n  auto seqlen_k = k_shape.At(1);\n  auto num_heads_k = k_shape.At(2);\n  auto head_size_og = dout_shape.At(3);\n\n  // check input tensor shape.\n  CHECK_EQ_OR_RETURN(batch_size, k_shape.At(0)) << \"query has different batch size from key.\";\n  CHECK_EQ_OR_RETURN(batch_size, v_shape.At(0)) << \"query has different batch size from value.\";\n  CHECK_EQ_OR_RETURN(batch_size, dout_shape.At(0))\n      << \"query has different batch size from grad_out.\";\n  CHECK_EQ_OR_RETURN(batch_size, out_shape.At(0)) << \"query has different batch size from out.\";\n  CHECK_EQ_OR_RETURN(batch_size, softmax_lse_shape.At(0))\n      << \"query has different batch size from softmax_lse.\";\n\n  CHECK_EQ_OR_RETURN(seqlen_k, v_shape.At(1)) << \"key has different seqlen from value.\";\n  CHECK_EQ_OR_RETURN(num_heads_k, v_shape.At(2)) << \"key has different num_heads from value.\";\n\n  // dout should be padded in functional layer if needed.\n  CHECK_EQ_OR_RETURN(head_size_og, head_size) << \"grad_out has different head_size from query\";\n  CHECK_EQ_OR_RETURN(head_size, k_shape.At(3)) << \"query has different head_size from key\";\n  CHECK_EQ_OR_RETURN(head_size, v_shape.At(3)) << \"query has different head_size from value\";\n\n  // batch size must be positive.\n  CHECK_GT_OR_RETURN(batch_size, 0) << \"batch size must be positive\";\n\n  // only support head dimensions at most 256.\n  CHECK_LE_OR_RETURN(head_size_og, 256) << \"only support head dimensions at most 256\";\n\n  
CHECK_EQ_OR_RETURN(num_heads % num_heads_k, 0)\n      << \"number of heads in key/value must devide number of heads in query.\";\n\n  // grad_k/v should be expanded if needed(when num_heads != num_heads_k && num_heads % num_heads_k\n  // == 0).\n  ctx->SetOutputShape(\"grad_q\", 0, Shape({batch_size, seqlen_q, num_heads, head_size}));\n  ctx->SetOutputShape(\"grad_k\", 0, Shape({batch_size, seqlen_k, num_heads, head_size}));\n  ctx->SetOutputShape(\"grad_v\", 0, Shape({batch_size, seqlen_k, num_heads, head_size}));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ScaledDotProductFlashAttentionGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return ScaledDotProductFlashAttentionGradOp::InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> ScaledDotProductFlashAttentionGradOp::GetSbp(user_op::SbpContext* ctx) {\n  auto parallel_num = ctx->parallel_num();\n  const Shape& q_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"query\", 0).shape();\n  const Shape& k_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"key\", 0).shape();\n  auto num_heads = q_shape.At(2);\n  auto num_heads_k = k_shape.At(2);\n  bool can_spilt_num_heads =\n      num_heads == num_heads_k || (!(num_heads % parallel_num) && !(num_heads_k % parallel_num));\n  if (can_spilt_num_heads) {\n    // prior to split on num_heads.\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"grad_out\", 0), 2)\n        .Split(user_op::OpArg(\"query\", 0), 2)\n        .Split(user_op::OpArg(\"key\", 0), 2)\n        .Split(user_op::OpArg(\"value\", 0), 2)\n        .Split(user_op::OpArg(\"out\", 0), 2)\n        .Split(user_op::OpArg(\"softmax\", 0), 1)\n        .Broadcast(user_op::OpArg(\"rng_state\", 0))\n        .Split(user_op::OpArg(\"grad_q\", 0), 2)\n        .Split(user_op::OpArg(\"grad_k\", 0), 2)\n        .Split(user_op::OpArg(\"grad_v\", 0), 2)\n        .Build();\n  } else {\n    // otherwise split on batch_size.\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"grad_out\", 0), 0)\n        .Split(user_op::OpArg(\"query\", 0), 0)\n        .Split(user_op::OpArg(\"key\", 0), 0)\n        .Split(user_op::OpArg(\"value\", 0), 0)\n        .Split(user_op::OpArg(\"out\", 0), 0)\n        .Split(user_op::OpArg(\"softmax\", 0), 0)\n        .Broadcast(user_op::OpArg(\"rng_state\", 0))\n        .Split(user_op::OpArg(\"grad_q\", 0), 0)\n        .Split(user_op::OpArg(\"grad_k\", 0), 0)\n        .Split(user_op::OpArg(\"grad_v\", 0), 0)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> ScaledDotProductFlashAttentionGradOp::InferDataType(user_op::InferContext* ctx) {\n  auto dout_datatype = ctx->InputDType(\"grad_out\", 0);\n  auto q_datatype = ctx->InputDType(\"query\", 0);\n  auto k_datatype = ctx->InputDType(\"key\", 0);\n  auto v_datatype = ctx->InputDType(\"value\", 0);\n  auto out_datatype = ctx->InputDType(\"out\", 0);\n\n  CHECK_EQ_OR_RETURN(q_datatype, k_datatype) << \"query has different data type from key.\";\n  CHECK_EQ_OR_RETURN(q_datatype, v_datatype) << \"query has different data type from value.\";\n  CHECK_EQ_OR_RETURN(q_datatype, dout_datatype) << \"query has different data type from grad_out.\";\n  CHECK_EQ_OR_RETURN(q_datatype, out_datatype) << \"query has different data type from out.\";\n\n  ctx->SetOutputDType(\"grad_q\", 0, q_datatype);\n  ctx->SetOutputDType(\"grad_k\", 0, q_datatype);\n  ctx->SetOutputDType(\"grad_v\", 0, q_datatype);\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/search_sorted_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> SearchSortedOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"values\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SearchSortedOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> SearchSortedOp::GetSbp(user_op::SbpContext* ctx) {\n  // The current implementation can only do arg_sort in the last dimension and should use\n  // Broadcast (by default) instead of Split for that dimension\n  const user_op::TensorDesc& in_tensor =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"sorted_sequence\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> SearchSortedOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                                   const user_op::UserOpConfWrapper& conf) {\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> SearchSortedOp::InferDataType(user_op::InferContext* ctx) {\n  const bool& out_int32 = ctx->Attr<bool>(\"out_int32\");\n  if (out_int32) {\n    ctx->SetOutputDType(\"out\", 0, DataType::kInt32);\n  } else {\n    ctx->SetOutputDType(\"out\", 0, DataType::kInt64);\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> SearchSortedScalarOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, Shape({}));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SearchSortedScalarOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> SearchSortedScalarOp::GetSbp(user_op::SbpContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> SearchSortedScalarOp::CheckAttr(const user_op::UserOpDefWrapper& def,\n                                                         const user_op::UserOpConfWrapper& conf) {\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> SearchSortedScalarOp::InferDataType(user_op::InferContext* ctx) {\n  const bool& out_int32 = ctx->Attr<bool>(\"out_int32\");\n  if (out_int32) {\n    ctx->SetOutputDType(\"out\", 0, DataType::kInt32);\n  } else {\n    ctx->SetOutputDType(\"out\", 0, DataType::kInt64);\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/selu_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> SeluOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SeluOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SeluGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == x_shape)\n      << Error::RuntimeError() << \"Tensors dy and x must be the same shape\";\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SeluGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"dy\", 0), ctx->InputDType(\"x\", 0))\n      << Error::TypeError() << \"Tensors dy and x must have same type\";\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/sigmoid_cross_entropy_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> SigmoidCrossEntropyOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto num_out_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"prediction\", 0).shape().NumAxes();\n  FOR_RANGE(int64_t, i, 0, num_out_axes) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"prediction\", 0), i)\n        .Split(user_op::OpArg(\"label\", 0), i)\n        .Split(user_op::OpArg(\"loss\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SigmoidCrossEntropyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc(\"prediction\", 0);\n  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc(\"label\", 0);\n  CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape())\n      << Error::RuntimeError() << \"The size of label \" << label_desc.shape()\n      << \" must match the size of prediction \" << prediction_desc.shape();\n  user_op::TensorDesc* loss_desc = ctx->MutOutputTensorDesc(\"loss\", 0);\n  loss_desc->set_shape(prediction_desc.shape());\n  loss_desc->set_is_dynamic(prediction_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SigmoidCrossEntropyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SigmoidCrossEntropyOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"loss\", 0, ctx->InputDType(\"prediction\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SigmoidCrossEntropyOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn(\"label\", 0);\n  cond_arg_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SigmoidCrossEntropyGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto num_dy_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"loss_diff\", 0).shape().NumAxes();\n  FOR_RANGE(int64_t, i, 0, num_dy_axes) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"loss_diff\", 0), i)\n        .Split(user_op::OpArg(\"label\", 0), i)\n        .Split(user_op::OpArg(\"prediction\", 0), i)\n        .Split(user_op::OpArg(\"prediction_diff\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SigmoidCrossEntropyGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc(\"prediction\", 0);\n  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc(\"label\", 0);\n  const user_op::TensorDesc& loss_diff_desc = ctx->InputTensorDesc(\"loss_diff\", 0);\n  CHECK_EQ_OR_RETURN(label_desc.shape(), 
prediction_desc.shape())\n      << Error::RuntimeError() << \"The size of label \" << label_desc.shape()\n      << \" must match the size of prediction \" << prediction_desc.shape();\n  CHECK_EQ_OR_RETURN(loss_diff_desc.shape(), prediction_desc.shape())\n      << Error::RuntimeError() << \"The size of loss_diff \" << loss_diff_desc.shape()\n      << \" must match the size of prediction \" << prediction_desc.shape();\n  user_op::TensorDesc* prediction_diff = ctx->MutOutputTensorDesc(\"prediction_diff\", 0);\n  prediction_diff->set_shape(prediction_desc.shape());\n  prediction_diff->set_is_dynamic(prediction_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SigmoidCrossEntropyGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SigmoidCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"prediction_diff\", 0, ctx->InputDType(\"prediction\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SigmoidCrossEntropyGradOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn(\"label\", 0);\n  cond_arg_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/silu_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> SiluOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SiluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SiluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SiluOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SiluGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SiluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == x_shape) << Error::RuntimeError() << \"The size of dy \" << dy_shape\n                                       << \" must match the size of x \" << x_shape;\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SiluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SiluGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"dy\", 0), ctx->InputDType(\"x\", 0))\n      << Error::TypeError() << \"dy and x are expected to have the same dtype, but found \"\n      << DataType_Name(ctx->InputDType(\"dy\", 0)) << \" and \"\n      << DataType_Name(ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/skip_layer_norm_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\noneflow::DataType InferParamDataType(const DataType x_data_type) {\n  return (x_data_type == DataType::kFloat16 || x_data_type == DataType::kBFloat16)\n             ? DataType::kFloat\n             : x_data_type;\n}\n\n}  // namespace\n\n/* static */ auto SkipLayerNormOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape().NumAxes() - 1;\n       ++i) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"skip\", 0), i)\n        .Broadcast(user_op::OpArg(\"bias\", 0))\n        .Broadcast(user_op::OpArg(\"gamma\", 0))\n        .Broadcast(user_op::OpArg(\"beta\", 0))\n        .Split(ctx->outputs(), i)\n        .Build();\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ auto SkipLayerNormOp::InferLogicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  // check shape of x\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  CHECK_GE_OR_RETURN(x_shape.NumAxes(), 2)\n      << \"number of axes of \\'x\\' should have be greater than or equal to 2, yet get \"\n      << x_shape.NumAxes();\n\n  // check shape of gamma, beta and bias\n  if (ctx->has_input(\"gamma\", 0)) {\n    const Shape& gamma_shape = ctx->InputShape(\"gamma\", 0);\n    CHECK_EQ_OR_RETURN(gamma_shape.NumAxes(), 1)\n        << \"number of axes of \\'gamma\\' should be equal to 1, yet get \" << gamma_shape.NumAxes();\n    CHECK_EQ_OR_RETURN(gamma_shape.At(0), x_shape.At(x_shape.NumAxes() - 1))\n        << \"the size of \\'gamma\\'(\" << gamma_shape.At(0)\n        << \") is not consistant with the last dimension of \\'x\\'(\"\n        << x_shape.At(x_shape.NumAxes() - 1) << \")\";\n  }\n  if (ctx->has_input(\"beta\", 0)) {\n    const Shape& beta_shape = ctx->InputShape(\"beta\", 0);\n    CHECK_EQ_OR_RETURN(beta_shape.NumAxes(), 1)\n        << \"number of axes of \\'beta\\' should be equal to 1, yet get \" << beta_shape.NumAxes();\n    CHECK_EQ_OR_RETURN(beta_shape.At(0), x_shape.At(x_shape.NumAxes() - 1))\n        << \"the size of \\'beta\\'(\" << beta_shape.At(0)\n        << \") is not consistant with the last dimension of \\'x\\'(\"\n        << x_shape.At(x_shape.NumAxes() - 1) << \")\";\n  }\n  if (ctx->has_input(\"bias\", 0)) {\n    const Shape& bias_shape = ctx->InputShape(\"bias\", 0);\n    CHECK_EQ_OR_RETURN(bias_shape.NumAxes(), 1)\n        << \"number of axes of \\'bias\\' should be equal to 1, yet get \" << bias_shape.NumAxes();\n    CHECK_EQ_OR_RETURN(bias_shape.At(0), x_shape.At(x_shape.NumAxes() - 1))\n        << \"the size of \\'bias\\'(\" << bias_shape.At(0)\n        << \") is not consistant with the last dimension of \\'x\\'(\"\n        << x_shape.At(x_shape.NumAxes() - 1) << \")\";\n  }\n\n  // check shape of skip\n 
 if (ctx->has_input(\"skip\", 0)) {\n    const Shape& skip_shape = ctx->InputShape(\"skip\", 0);\n    CHECK_EQ_OR_RETURN(skip_shape, x_shape) << \"shape of \\'skip\\' is not the same as \\'x\\'\";\n  }\n\n  // set output shape of y\n  user_op::TensorDesc* y_tensor = ctx->MutOutputTensorDesc(\"y\", 0);\n  y_tensor->set_shape(x_shape);\n\n  // set output shape of mean and varience\n  DimVector mean_dim_vec;\n  mean_dim_vec.push_back(x_shape.Count(0, x_shape.NumAxes() - 1));\n  Shape mean_shape(mean_dim_vec);\n\n  user_op::TensorDesc* mean_tensor = ctx->MutOutputTensorDesc(\"mean\", 0);\n  user_op::TensorDesc* varience_tensor = ctx->MutOutputTensorDesc(\"inv_variance\", 0);\n  mean_tensor->set_shape(mean_shape);\n  varience_tensor->set_shape(mean_shape);\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ auto SkipLayerNormOp::InferPhysicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ auto SkipLayerNormOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  // obtain input data types\n  DataType x_dtype = ctx->InputDType(\"x\", 0);\n\n  // check data type of gamma\n  if (ctx->has_input(\"gamma\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"gamma\", 0), x_dtype)\n        << \"data type of \\'gamma\\' is not consitant with \\'x\\'\";\n  }\n\n  // check data type of bias\n  if (ctx->has_input(\"bias\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"bias\", 0), x_dtype)\n        << \"data type of \\'bias\\' is not consitant with \\'x\\'\";\n  }\n\n  // check data types of beta\n  if (ctx->has_input(\"beta\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"beta\", 0), x_dtype)\n        << \"data type of \\'beta\\' is not consitant with \\'x\\'\";\n  }\n\n  // check data types of skip\n  if (ctx->has_input(\"skip\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"skip\", 0), x_dtype)\n        << \"data type of \\'skip\\' is not consitant with \\'x\\'\";\n  }\n\n  // set output data type\n  ctx->SetOutputDType(\"y\", 0, x_dtype);\n  ctx->SetOutputDType(\"mean\", 0, InferParamDataType(x_dtype));\n  ctx->SetOutputDType(\"inv_variance\", 0, InferParamDataType(x_dtype));\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/skip_rms_norm_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\noneflow::DataType InferParamDataType(const DataType x_data_type) {\n  return (x_data_type == DataType::kFloat16 || x_data_type == DataType::kBFloat16)\n             ? DataType::kFloat\n             : x_data_type;\n}\n\n}  // namespace\n\n/* static */ auto SkipRmsNormOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {\n  for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape().NumAxes() - 1;\n       ++i) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"skip\", 0), i)\n        .Broadcast(user_op::OpArg(\"bias\", 0))\n        .Broadcast(user_op::OpArg(\"weight\", 0))\n        .Split(ctx->outputs(), i)\n        .Build();\n  }\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ auto SkipRmsNormOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {\n  // check shape of x\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  CHECK_GE_OR_RETURN(x_shape.NumAxes(), 2)\n      << \"number of axes of \\'x\\' should have be greater than or equal to 2, yet get \"\n      << x_shape.NumAxes();\n\n  // check shape of weight and bias\n  if (ctx->has_input(\"weight\", 0)) {\n    const Shape& weight_shape = ctx->InputShape(\"weight\", 0);\n    CHECK_EQ_OR_RETURN(weight_shape.NumAxes(), 1)\n        << \"number of axes of \\'weight\\' should be equal to 1, yet get \" << weight_shape.NumAxes();\n    CHECK_EQ_OR_RETURN(weight_shape.At(0), x_shape.At(x_shape.NumAxes() - 1))\n        << \"the size of \\'weight\\'(\" << weight_shape.At(0)\n        << \") is not consistant with the last dimension of \\'x\\'(\"\n        << x_shape.At(x_shape.NumAxes() - 1) << \")\";\n  }\n  if (ctx->has_input(\"bias\", 0)) {\n    const Shape& bias_shape = ctx->InputShape(\"bias\", 0);\n    CHECK_EQ_OR_RETURN(bias_shape.NumAxes(), 1)\n        << \"number of axes of \\'bias\\' should be equal to 1, yet get \" << bias_shape.NumAxes();\n    CHECK_EQ_OR_RETURN(bias_shape.At(0), x_shape.At(x_shape.NumAxes() - 1))\n        << \"the size of \\'bias\\'(\" << bias_shape.At(0)\n        << \") is not consistant with the last dimension of \\'x\\'(\"\n        << x_shape.At(x_shape.NumAxes() - 1) << \")\";\n  }\n\n  // check shape of skip\n  if (ctx->has_input(\"skip\", 0)) {\n    const Shape& skip_shape = ctx->InputShape(\"skip\", 0);\n    CHECK_EQ_OR_RETURN(skip_shape, x_shape) << \"shape of \\'skip\\' is not the same as \\'x\\'\";\n  }\n\n  // set output shape of y\n  user_op::TensorDesc* y_tensor = ctx->MutOutputTensorDesc(\"y\", 0);\n  y_tensor->set_shape(x_shape);\n\n  // set output shape of inv_rms\n  DimVector inv_rms_dim_vec;\n  inv_rms_dim_vec.push_back(x_shape.Count(0, x_shape.NumAxes() - 1));\n  Shape inv_rms_shape(inv_rms_dim_vec);\n  user_op::TensorDesc* inv_rms_tensor = 
ctx->MutOutputTensorDesc(\"inv_rms\", 0);\n  inv_rms_tensor->set_shape(inv_rms_shape);\n\n  return Maybe<void>::Ok();\n}\n\n/* static */ auto SkipRmsNormOp::InferPhysicalTensorDesc(user_op::InferContext* ctx)\n    -> Maybe<void> {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ auto SkipRmsNormOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {\n  // obtain input data types\n  DataType x_dtype = ctx->InputDType(\"x\", 0);\n\n  // check data type of bias\n  if (ctx->has_input(\"bias\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"bias\", 0), x_dtype)\n        << \"data type of \\'bias\\' is not consitant with \\'x\\'\";\n  }\n\n  // check data types of weight\n  if (ctx->has_input(\"weight\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"weight\", 0), x_dtype)\n        << \"data type of \\'weight\\' is not consitant with \\'x\\'\";\n  }\n\n  // check data types of skip\n  if (ctx->has_input(\"skip\", 0)) {\n    CHECK_EQ_OR_RETURN(ctx->InputDType(\"skip\", 0), x_dtype)\n        << \"data type of \\'skip\\' is not consitant with \\'x\\'\";\n  }\n\n  // set output data type\n  ctx->SetOutputDType(\"y\", 0, x_dtype);\n  ctx->SetOutputDType(\"inv_rms\", 0, InferParamDataType(x_dtype));\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/slice_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n#include \"oneflow/user/kernels/slice_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/operator/operator.h\"\n\nnamespace oneflow {\n\nnamespace {\nbool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) {\n  if (step != 1) { return false; }\n  if (start != 0) { return false; }\n  if (stop != size) { return false; }\n  return true;\n}\n}  // namespace\n\n/*static*/ Maybe<void> SliceUpdateOp::GetSbp(user_op::SbpContext* ctx) {\n  const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"ref\", 0).shape();\n  const int64_t ndim = x_shape.NumAxes();\n  const auto& start_vec = ctx->Attr<std::vector<int64_t>>(\"start\");\n  const auto& stop_vec = ctx->Attr<std::vector<int64_t>>(\"stop\");\n  const auto& step_vec = ctx->Attr<std::vector<int64_t>>(\"step\");\n  CHECK_EQ_OR_RETURN(start_vec.size(), ndim)\n      << Error::RuntimeError()\n      << \"The size of start list must be equal to the dimension of ref tensor, \"\n      << \"but got \" << start_vec.size() << \" and \" << ndim;\n  CHECK_EQ_OR_RETURN(stop_vec.size(), ndim)\n      << Error::RuntimeError()\n      << \"The size of stop list must be equal to the dimension of ref tensor, \"\n      << \"but got \" << stop_vec.size() << \" and \" << ndim;\n  CHECK_EQ_OR_RETURN(step_vec.size(), ndim)\n      << Error::RuntimeError()\n      << \"The size of step list must be equal to the dimension of ref tensor, \"\n      << \"but got \" << step_vec.size() << \" and \" << ndim;\n\n  FOR_RANGE(int64_t, axis, 0, ndim) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"ref\", 0), axis)\n        .Broadcast(user_op::OpArg(\"value\", 0))\n        .Split(user_op::OpArg(\"y\", 0), axis)\n        .Build();\n    // FullSlice support S+S->S\n    if (IsFullSlice(start_vec[axis], stop_vec[axis], step_vec[axis], x_shape.At(axis))) {\n      ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build();\n    }\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"ref\", 0))\n      .PartialSum(user_op::OpArg(\"value\", 0))\n      .PartialSum(user_op::OpArg(\"y\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SliceUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc(\"ref\", 0);\n  const Shape& value_shape = ctx->InputTensorDesc(\"value\", 0).shape();\n  const auto& start_vec = ctx->Attr<std::vector<int64_t>>(\"start\");\n  const auto& stop_vec = ctx->Attr<std::vector<int64_t>>(\"stop\");\n  const auto& step_vec = ctx->Attr<std::vector<int64_t>>(\"step\");\n  CHECK_OR_RETURN(!ref_desc.is_dynamic())\n      << Error::RuntimeError() << \"The ref tensor is not dynamic\";\n  FOR_RANGE(size_t, i, 0, step_vec.size()) {\n    const int64_t step = step_vec.at(i);\n    const 
int64_t start = start_vec.at(i);\n    const int64_t stop = stop_vec.at(i);\n    CHECK_GT_OR_RETURN(step, 0) << Error::RuntimeError()\n                                << \"The step list elements must be greater than 0, \"\n                                << \"but got \" << step << \" at index \" << i;\n\n    CHECK_GE_OR_RETURN(start, 0) << Error::RuntimeError()\n                                 << \"The start list elements must be greater than or equal to 0, \"\n                                 << \"but got \" << start << \" at index \" << i;\n    CHECK_GE_OR_RETURN(stop, 0) << Error::RuntimeError()\n                                << \"The stop list elements must be greater than or equal to 0, \"\n                                << \"but got \" << stop << \" at index \" << i;\n    CHECK_LE_OR_RETURN(start, stop) << Error::RuntimeError()\n                                    << \"The element in start list must be less than or equal to \"\n                                       \"the element in stop list at index \"\n                                    << i << \", but got \" << start << \" and \" << stop;\n    CHECK_EQ_OR_RETURN((stop - start + step - 1) / step, value_shape.At(i))\n        << Error::RuntimeError()\n        << \"The size of slice tuple must be equal to the size of value tensor at dimension \" << i\n        << \", but got \" << (stop - start + step - 1) / step << \" and \" << value_shape.At(i);\n  }\n  auto* y_desc = ctx->MutOutputTensorDesc(\"y\", 0);\n  y_desc->set_shape(ref_desc.shape());\n  y_desc->set_is_dynamic(ref_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SliceUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc(\"ref\", 0);\n  auto* y_desc = ctx->MutOutputTensorDesc(\"y\", 0);\n  y_desc->set_shape(ref_desc.shape());\n  y_desc->set_is_dynamic(ref_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SliceUpdateOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc(\"ref\", 0);\n  const user_op::TensorDesc& value_desc = ctx->InputTensorDesc(\"value\", 0);\n  CHECK_OR_RETURN(ref_desc.data_type() == value_desc.data_type())\n      << Error::TypeError() << \"Tensors ref and value must have the same dtype\";\n  auto* y_desc = ctx->MutOutputTensorDesc(\"y\", 0);\n  y_desc->set_data_type(ref_desc.data_type());\n  y_desc->set_stride(ref_desc.stride());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SliceOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& input_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  const Shape& in_shape = input_desc.shape();\n  int32_t ndim = in_shape.NumAxes();\n  const auto& start_vec = ctx->Attr<std::vector<int64_t>>(\"start\");\n  const auto& stop_vec = ctx->Attr<std::vector<int64_t>>(\"stop\");\n  const auto& step_vec = ctx->Attr<std::vector<int64_t>>(\"step\");\n  CHECK_EQ_OR_RETURN(start_vec.size(), ndim)\n      << \"start_vec's dim not equal to x shape's dim: \" << start_vec.size() << \" vs \" << ndim;\n  CHECK_EQ_OR_RETURN(stop_vec.size(), ndim)\n      << \"stop_vec's dim not equal to x shape's dim: \" << stop_vec.size() << \" vs \" << ndim;\n  CHECK_EQ_OR_RETURN(step_vec.size(), ndim)\n      << \"step_vec's dim not equal to x shape's dim: \" << step_vec.size() << \" vs \" << ndim;\n\n  FOR_RANGE(int64_t, axis, 0, input_desc.shape().NumAxes()) {\n    if (IsFullSlice(start_vec[axis], stop_vec[axis], step_vec[axis], 
in_shape.At(axis))) {\n      ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build();\n    } else {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), axis)\n          .PartialSum(user_op::OpArg(\"y\", 0))\n          .Build();\n    }\n  }\n  ctx->NewBuilder().PartialSum(user_op::OpArg(\"x\", 0)).PartialSum(user_op::OpArg(\"y\", 0)).Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SliceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const int64_t ndim = x_shape.NumAxes();\n  const auto& start_vec = ctx->Attr<std::vector<int64_t>>(\"start\");\n  const auto& stop_vec = ctx->Attr<std::vector<int64_t>>(\"stop\");\n  const auto& step_vec = ctx->Attr<std::vector<int64_t>>(\"step\");\n  DimVector dim_vec(ndim);\n  FOR_RANGE(size_t, i, 0, dim_vec.size()) {\n    const int64_t step = step_vec.at(i);\n    const int64_t start = start_vec.at(i);\n    const int64_t stop = stop_vec.at(i);\n    CHECK_GT_OR_RETURN(step, 0) << Error::RuntimeError()\n                                << \"The step list elements must be greater than 0, \"\n                                << \"but got \" << step << \" at index \" << i;\n    CHECK_GE_OR_RETURN(start, 0) << Error::RuntimeError()\n                                 << \"The start list elements must be greater than or equal to 0, \"\n                                 << \"but got \" << start << \" at index \" << i;\n    CHECK_GE_OR_RETURN(stop, 0) << Error::RuntimeError()\n                                << \"The stop list elements must be greater than or equal to 0, \"\n                                << \"but got \" << stop << \" at index \" << i;\n    CHECK_LE_OR_RETURN(start, stop) << Error::RuntimeError()\n                                    << \"The element in start list must be less than or equal to \"\n                                       \"the element in stop list at index \"\n                                    << i << \", but got \" << start << \" and \" << stop;\n    const int64_t diff = stop - start - 1;\n    dim_vec[i] = diff / step + 1;\n  }\n  ctx->SetOutputShape(\"y\", 0, Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SliceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const int64_t ndim = x_shape.NumAxes();\n  const auto& start_vec = ctx->Attr<std::vector<int64_t>>(\"start\");\n  const auto& stop_vec = ctx->Attr<std::vector<int64_t>>(\"stop\");\n  const auto& step_vec = ctx->Attr<std::vector<int64_t>>(\"step\");\n  DimVector dim_vec(ndim);  // logical shape in slice attributes\n  FOR_RANGE(size_t, i, 0, dim_vec.size()) {\n    const int64_t step = step_vec[i];\n    const int64_t start = start_vec[i];\n    const int64_t stop = stop_vec[i];\n    CHECK_GT_OR_RETURN(step, 0) << \"Slice step must be greater than 0\";\n    CHECK_GE_OR_RETURN(start, 0) << \"Slice start must be greater than or equal to 0\";\n    CHECK_GE_OR_RETURN(stop, 0) << \"Slice stop must be greater than or equal to 0\";\n    CHECK_LE_OR_RETURN(start, stop) << \"Slice start must be less than or equal to stop\";\n    const int64_t diff = stop - start - 1;\n    dim_vec[i] = diff / step + 1;\n  }\n  // Get physical shape with TensorSliceView\n  const NdSbp& y_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"y\", 0);\n  const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();\n  const Shape& logical_shape = Shape(dim_vec);\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n  
const TensorSliceView& slice_view =\n      GetTensorSliceView4ParallelId(parallel_hierarchy, y_nd_sbp, logical_shape, parallel_id);\n  ctx->SetOutputShape(\"y\", 0, Shape(slice_view.shape()));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SliceOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SliceGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const Shape& like_shape = ctx->Attr<Shape>(\"like_shape\");\n  const int64_t ndim = like_shape.NumAxes();\n  const auto& start_vec = ctx->Attr<std::vector<int64_t>>(\"start\");\n  const auto& stop_vec = ctx->Attr<std::vector<int64_t>>(\"stop\");\n  const auto& step_vec = ctx->Attr<std::vector<int64_t>>(\"step\");\n  CHECK_EQ_OR_RETURN(start_vec.size(), ndim)\n      << Error::RuntimeError()\n      << \"The size of start list must be equal to the dimension of like_shape, \"\n      << \"but got \" << start_vec.size() << \" and \" << ndim;\n  CHECK_EQ_OR_RETURN(stop_vec.size(), ndim)\n      << Error::RuntimeError()\n      << \"The size of stop list must be equal to the dimension of like_shape, \"\n      << \"but got \" << stop_vec.size() << \" and \" << ndim;\n  CHECK_EQ_OR_RETURN(step_vec.size(), ndim)\n      << Error::RuntimeError()\n      << \"The size of step list must be equal to the dimension of like_shape, \"\n      << \"but got \" << step_vec.size() << \" and \" << ndim;\n  FOR_RANGE(int, i, 0, ndim) {\n    if (IsFullSlice(start_vec[i], stop_vec[i], step_vec[i], like_shape.At(i))) {\n      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n    }\n  }\n  ctx->NewBuilder().PartialSum(user_op::OpArg(\"dy\", 0)).PartialSum(user_op::OpArg(\"dx\", 0)).Build();\n  ctx->NewBuilder().Broadcast(user_op::OpArg(\"dy\", 0)).Broadcast(user_op::OpArg(\"dx\", 0)).Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SliceGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& like_shape = ctx->Attr<Shape>(\"like_shape\");\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  const auto& start_vec = ctx->Attr<std::vector<int64_t>>(\"start\");\n  const auto& stop_vec = ctx->Attr<std::vector<int64_t>>(\"stop\");\n  const auto& step_vec = ctx->Attr<std::vector<int64_t>>(\"step\");\n\n  const int64_t ndim = dy_shape.NumAxes();\n  CHECK_EQ_OR_RETURN(start_vec.size(), ndim)\n      << Error::RuntimeError()\n      << \"The size of start list must be equal to the dimension of dy tensor, \"\n      << \"but got \" << start_vec.size() << \" and \" << ndim;\n  CHECK_EQ_OR_RETURN(stop_vec.size(), ndim)\n      << Error::RuntimeError()\n      << \"The size of stop list must be equal to the dimension of dy tensor, \"\n      << \"but got \" << stop_vec.size() << \" and \" << ndim;\n  CHECK_EQ_OR_RETURN(step_vec.size(), ndim)\n      << Error::RuntimeError()\n      << \"The size of step list must be equal to the dimension of dy tensor, \"\n      << \"but got \" << step_vec.size() << \" and \" << ndim;\n  ctx->SetOutputShape(\"dx\", 0, like_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SliceGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  Shape logical_shape = ctx->Attr<Shape>(\"like_shape\");\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx_desc->set_is_dynamic(dy_desc.is_dynamic());\n\n  const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"dx\", 0);\n  
dx_desc->set_shape(\n      *JUST(GetPhysicalShape(logical_shape, nd_sbp, ctx->parallel_desc(), ctx->parallel_ctx())));\n  int dx_ndim = dx_desc->shape().NumAxes();\n  int dy_ndim = dy_desc.shape().NumAxes();\n  CHECK_EQ_OR_RETURN(dx_ndim, dy_ndim)\n      << Error::RuntimeError() << \"The output dimension (\" << dx_ndim\n      << \") should be equal to the input dimension (\" << dy_ndim << \") for slice backward\";\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SliceGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SliceGradOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn,\n                                                   const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* dy_modifier = GetInputArgModifierFn(\"dy\", 0);\n  dy_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/smooth_l1_loss_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/loss_op_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> SmoothL1LossOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"input\", 0).shape();\n  FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SmoothL1LossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const auto& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const auto& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic())\n      << Error::RuntimeError()\n      << \"input and target are expected to have the same dynamic property, but found \"\n      << input_desc.is_dynamic() << \" and \" << target_desc.is_dynamic();\n  CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape())\n      << Error::RuntimeError() << \"The size of input \" << input_desc.shape()\n      << \" must match the size of target \" << target_desc.shape();\n  CHECK_GE_OR_RETURN(ctx->Attr<float>(\"beta\"), 0)\n      << Error::RuntimeError() << \"beta must be greater than or equal to 0, but found it to be \"\n      << ctx->Attr<float>(\"beta\");\n\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_is_dynamic(input_desc.is_dynamic());\n  out_desc->set_shape(input_desc.shape());\n\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SmoothL1LossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SmoothL1LossOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const user_op::TensorDesc& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type())\n      << Error::TypeError() << \"input and target are expected to have the same dtype, but found \"\n      << DataType_Name(input_desc.data_type()) << \" and \" << DataType_Name(target_desc.data_type());\n\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"input\", 0));\n\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SmoothL1LossOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* target_modifier = GetInputArgModifierFn(\"target\", 0);\n  CHECK_OR_RETURN(target_modifier != nullptr);  // NOLINT(maybe-need-error-msg)\n  target_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SmoothL1LossGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& input_shape = 
ctx->LogicalTensorDesc4InputArgNameAndIndex(\"input\", 0).shape();\n  FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"input\", 0), i)\n        .Split(user_op::OpArg(\"target\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SmoothL1LossGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const auto& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const auto& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  const auto& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic())\n      << Error::RuntimeError()\n      << \"input and target are expected to have the same dynamic property, but found \"\n      << input_desc.is_dynamic() << \" and \" << target_desc.is_dynamic();\n  CHECK_EQ_OR_RETURN(input_desc.shape(), target_desc.shape())\n      << Error::RuntimeError() << \"The size of input \" << input_desc.shape()\n      << \" must match the size of target \" << target_desc.shape();\n  CHECK_EQ_OR_RETURN(dy_desc.shape(), target_desc.shape())\n      << Error::RuntimeError() << \"The size of dy \" << dy_desc.shape()\n      << \" must match the size of target \" << target_desc.shape();\n\n  CHECK_GE_OR_RETURN(ctx->Attr<float>(\"beta\"), 0)\n      << Error::RuntimeError() << \"beta must be greater than or equal to 0, but found it to be \"\n      << ctx->Attr<float>(\"beta\");\n\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx_desc->set_is_dynamic(input_desc.is_dynamic());\n  dx_desc->set_shape(input_desc.shape());\n\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SmoothL1LossGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SmoothL1LossGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& input_desc = ctx->InputTensorDesc(\"input\", 0);\n  const user_op::TensorDesc& target_desc = ctx->InputTensorDesc(\"target\", 0);\n  CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type())\n      << Error::TypeError() << \"input and target are expected to have the same dtype, but found \"\n      << DataType_Name(input_desc.data_type()) << \" and \" << DataType_Name(target_desc.data_type());\n\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/softmax_cross_entropy_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> SoftmaxCrossEntropyOp::GetSbp(user_op::SbpContext* ctx) {\n  // ctx->LogicalTensorDesc4InputArgNameAndIndex(\"out\", 0) is not initialized here\n  const auto num_out_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"prediction\", 0).shape().NumAxes() - 1;\n  FOR_RANGE(int64_t, i, 0, num_out_axes) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"prediction\", 0), i)\n        .Split(user_op::OpArg(\"label\", 0), i)\n        .Split(user_op::OpArg(\"prob\", 0), i)\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SoftmaxCrossEntropyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc(\"prediction\", 0);\n  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc(\"label\", 0);\n  CHECK_EQ_OR_RETURN(prediction_desc.is_dynamic(), label_desc.is_dynamic())\n      << Error::RuntimeError()\n      << \"prediction and label are expected to have the same dynamic property, but found \"\n      << prediction_desc.is_dynamic() << \" and \" << label_desc.is_dynamic();\n  CHECK_GE_OR_RETURN(prediction_desc.shape().NumAxes(), 2)\n      << Error::RuntimeError()\n      << \"The dimension of prediction must be greater than or equal to 2, but found \"\n      << prediction_desc.shape().NumAxes();\n  CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape())\n      << Error::RuntimeError() << \"The size of label \" << label_desc.shape()\n      << \" must match the size of prediction \" << prediction_desc.shape();\n  const int64_t num_out_axes = prediction_desc.shape().NumAxes() - 1;\n  DimVector out_dim_vector;\n  FOR_RANGE(int64_t, i, 0, num_out_axes) {\n    out_dim_vector.emplace_back(prediction_desc.shape().At(i));\n  }\n  ctx->SetOutputShape(\"prob\", 0, ctx->InputShape(\"prediction\", 0));\n  ctx->SetOutputIsDynamic(\"prob\", 0, ctx->InputIsDynamic(\"prediction\", 0));\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_is_dynamic(prediction_desc.is_dynamic());\n  out_desc->set_shape(Shape(out_dim_vector));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SoftmaxCrossEntropyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SoftmaxCrossEntropyOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc(\"prediction\", 0);\n  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc(\"label\", 0);\n  CHECK_EQ_OR_RETURN(label_desc.data_type(), prediction_desc.data_type())\n      << Error::TypeError()\n      << \"label and prediction are expected to have the same dtype, but found \"\n      << 
DataType_Name(label_desc.data_type()) << \" and \"\n      << DataType_Name(prediction_desc.data_type());\n  ctx->SetOutputDType(\"prob\", 0, ctx->InputDType(\"prediction\", 0));\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_data_type(prediction_desc.data_type());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SoftmaxCrossEntropyOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn(\"label\", 0);\n  cond_arg_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SoftmaxCrossEntropyGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto num_dy_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"dy\", 0).shape().NumAxes();\n  FOR_RANGE(int64_t, i, 0, num_dy_axes) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"label\", 0), i)\n        .Split(user_op::OpArg(\"prob\", 0), i)\n        .Split(user_op::OpArg(\"prediction_diff\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SoftmaxCrossEntropyGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& prob_desc = ctx->InputTensorDesc(\"prob\", 0);\n  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc(\"label\", 0);\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  CHECK_EQ_OR_RETURN(prob_desc.is_dynamic(), label_desc.is_dynamic())\n      << Error::RuntimeError()\n      << \"prob and label are expected to have the same dynamic property, but found \"\n      << prob_desc.is_dynamic() << \" and \" << label_desc.is_dynamic();\n  CHECK_GE_OR_RETURN(prob_desc.shape().NumAxes(), 2)\n      << Error::RuntimeError()\n      << \"The dimension of prob must be greater than or equal to 2, but found \"\n      << prob_desc.shape().NumAxes();\n  CHECK_EQ_OR_RETURN(dy_desc.shape().NumAxes(), prob_desc.shape().NumAxes() - 1)\n      << Error::RuntimeError()\n      << \"The dimension of dy is expected to be less than that of prob by 1, but found \"\n      << dy_desc.shape().NumAxes() << \" and \" << prob_desc.shape().NumAxes() - 1;\n  FOR_RANGE(int64_t, i, 0, dy_desc.shape().NumAxes()) {\n    CHECK_EQ_OR_RETURN(dy_desc.shape().At(i), label_desc.shape().At(i))\n        << Error::RuntimeError() << \"The size of dy (\" << dy_desc.shape().At(i)\n        << \") must match the size of label (\" << label_desc.shape().At(i) << \") at dimension \" << i;\n  }\n  CHECK_EQ_OR_RETURN(label_desc.shape(), prob_desc.shape())\n      << Error::RuntimeError() << \"The size of label \" << label_desc.shape()\n      << \" must match the size of prob \" << prob_desc.shape();\n  ctx->SetOutputShape(\"prediction_diff\", 0, ctx->InputShape(\"prob\", 0));\n  ctx->SetOutputIsDynamic(\"prediction_diff\", 0, ctx->InputIsDynamic(\"prob\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SoftmaxCrossEntropyGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SoftmaxCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& prob_desc = ctx->InputTensorDesc(\"prob\", 0);\n  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc(\"label\", 0);\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  CHECK_EQ_OR_RETURN(label_desc.data_type(), prob_desc.data_type())\n  
    << Error::TypeError() << \"label and prob are expected to have the same dtype, but found \"\n      << DataType_Name(label_desc.data_type()) << \" and \" << DataType_Name(prob_desc.data_type());\n  CHECK_EQ_OR_RETURN(dy_desc.data_type(), prob_desc.data_type())\n      << Error::TypeError() << \"dy and prob are expected to have the same dtype, but found \"\n      << DataType_Name(dy_desc.data_type()) << \" and \" << DataType_Name(prob_desc.data_type());\n  ctx->SetOutputDType(\"prediction_diff\", 0, ctx->InputDType(\"prob\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/softmax_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> SoftmaxOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes() - 1) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), axis)\n        .Split(user_op::OpArg(\"out\", 0), axis)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SoftmaxOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n// Logically computation cost of pool op is the product of output data amount and pool kernal data\n// amount. After adding sbp, we just divide it by parallel number if output data is splitted because\n// splitting input and using partial sum for output is not a valid sbp for this op for now.\n/*static*/ Maybe<double> SoftmaxOp::GetComputeComplexity(user_op::ComputeComplexityFnContext* ctx) {\n  double logical_computation_cost = ctx->Shape4ArgNameAndIndex(\"in\", 0).elem_cnt() * 10;\n  const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy();\n  const auto& nd_sbp_in = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n  for (int32_t dim_sbp = 0; dim_sbp < nd_sbp_in.sbp_parallel_size(); dim_sbp++) {\n    if (nd_sbp_in.sbp_parallel(dim_sbp).has_split_parallel()) {\n      logical_computation_cost /= parallel_hierarchy->At(dim_sbp);\n    }\n  }\n  return logical_computation_cost;\n}\n\n/*static*/ Maybe<void> SoftmaxGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"y\", 0);\n  FOR_RANGE(int64_t, axis, 0, y_tensor.shape().NumAxes() - 1) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"y\", 0), axis)\n        .Split(user_op::OpArg(\"dy\", 0), axis)\n        .Split(user_op::OpArg(\"dx\", 0), axis)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& y_shape = ctx->InputShape(\"y\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == y_shape) << Error::RuntimeError() << \"The size of dy \" << dy_shape\n                                       << \" must match the size of y \" << y_shape;\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SoftmaxGradOp::InferPhysicalTensorDesc(user_op::InferContext* 
ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SoftmaxGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"dy\", 0), ctx->InputDType(\"y\", 0))\n      << Error::TypeError() << \"dy and y are expected to have the same dtype, but found \"\n      << DataType_Name(ctx->InputDType(\"dy\", 0)) << \" and \"\n      << DataType_Name(ctx->InputDType(\"y\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"y\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/softplus_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> SoftplusOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> SoftplusOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> SoftplusOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> SoftplusOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> SoftplusGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == x_shape) << Error::RuntimeError() << \"The size of dy \" << dy_shape\n                                       << \" must match the size of x \" << x_shape;\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> SoftplusGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> SoftplusGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> SoftplusGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"dy\", 0), ctx->InputDType(\"x\", 0))\n      << Error::TypeError() << \"dy and x are expected to have the same dtype, but found \"\n      << DataType_Name(ctx->InputDType(\"dy\", 0)) << \" and \"\n      << DataType_Name(ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/softshrink_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> SoftShrinkOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> SoftShrinkOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> SoftShrinkOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> SoftShrinkOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> SoftShrinkGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& y_shape = ctx->InputShape(\"y\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == y_shape) << Error::RuntimeError() << \"The size of dy \" << dy_shape\n                                       << \" must match the size of y \" << y_shape;\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> SoftShrinkGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> SoftShrinkGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"y\", 0);\n  FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"y\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> SoftShrinkGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"dy\", 0), ctx->InputDType(\"y\", 0))\n      << Error::TypeError() << \"dy and y are expected to have the same dtype, but found \"\n      << DataType_Name(ctx->InputDType(\"dy\", 0)) << \" and \"\n      << DataType_Name(ctx->InputDType(\"y\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"y\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/softsign_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> SoftsignOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SoftsignOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SoftsignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SoftsignOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SoftsignGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SoftsignGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == x_shape) << Error::RuntimeError() << \"The size of dy \" << dy_shape\n                                       << \" must match the size of x \" << x_shape;\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SoftsignGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SoftsignGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"dy\", 0), ctx->InputDType(\"x\", 0))\n      << Error::TypeError() << \"dy and x are expected to have the same dtype, but found \"\n      << DataType_Name(ctx->InputDType(\"dy\", 0)) << \" and \"\n      << DataType_Name(ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/sort_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> SortOp::GetSbp(user_op::SbpContext* ctx) {\n  // The current implementation can only do sort in the last dimension and should use Broadcast\n  // (by default) instead of Split for that dimension\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SortOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SortOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SortOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SortOp::CheckAttr(const user_op::UserOpDefWrapper&,\n                                         const user_op::UserOpConfWrapper& op_conf) {\n  const std::string& direction = op_conf.attr<std::string>(\"direction\");\n  CHECK_OR_RETURN(direction == \"ASCENDING\" || direction == \"DESCENDING\")\n      << Error::RuntimeError()\n      << \"The input direction parameter value is expected to be ASCENDING or DESCENDING, \"\n      << \"but found it to be \" << direction;\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/sparse_cross_entropy_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> CheckPredictionLabelDesc(const user_op::TensorDesc* prediction_desc,\n                                     const user_op::TensorDesc* label_desc) {\n  CHECK_EQ_OR_RETURN(prediction_desc->is_dynamic(), label_desc->is_dynamic())\n      << Error::RuntimeError()\n      << \"prediction and label are expected to have the same dynamic property, but found \"\n      << prediction_desc->is_dynamic() << \" and \" << label_desc->is_dynamic();\n  CHECK_GE_OR_RETURN(prediction_desc->shape().NumAxes(), 2)\n      << Error::RuntimeError()\n      << \"The dimension of prediction must be greater than or equal to 2, but found \"\n      << prediction_desc->shape().NumAxes();\n  const int64_t num_out_axes = prediction_desc->shape().NumAxes() - 1;\n  CHECK_EQ_OR_RETURN(label_desc->shape().NumAxes(), num_out_axes)\n      << Error::RuntimeError()\n      << \"The dimension of label is expected to be less than that of prediction by 1, but found \"\n      << label_desc->shape().NumAxes() << \" and \" << num_out_axes;\n  FOR_RANGE(int64_t, i, 0, num_out_axes) {\n    CHECK_EQ_OR_RETURN(prediction_desc->shape().At(i), label_desc->shape().At(i))\n        << Error::RuntimeError() << \"The size of prediction (\" << prediction_desc->shape().At(i)\n        << \") must match the size of label (\" << label_desc->shape().At(i) << \") at dimension \"\n        << i;\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferTensorDescFn(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc(\"prediction\", 0);\n  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc(\"label\", 0);\n  JUST(CheckPredictionLabelDesc(&prediction_desc, &label_desc));\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_is_dynamic(prediction_desc.is_dynamic());\n  out_desc->set_shape(label_desc.shape());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferGradTensorDescFn(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc(\"prediction\", 0);\n  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc(\"label\", 0);\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  JUST(CheckPredictionLabelDesc(&prediction_desc, &label_desc));\n  CHECK_EQ_OR_RETURN(dy_desc.shape(), label_desc.shape())\n      << Error::RuntimeError() << \"The size of dy \" << dy_desc.shape()\n      << \" must match the size of label \" << label_desc.shape();\n  ctx->SetOutputShape(\"prediction_diff\", 0, prediction_desc.shape());\n  ctx->SetOutputIsDynamic(\"prediction_diff\", 0, prediction_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& prediction_desc = 
ctx->InputTensorDesc(\"prediction\", 0);\n  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc(\"label\", 0);\n  CHECK_OR_RETURN(IsIndexDataType(label_desc.data_type()))\n      << Error::TypeError() << \"The dtype of label must be integer, but found \"\n      << DataType_Name(label_desc.data_type());\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_data_type(prediction_desc.data_type());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataTypeGrad(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc(\"prediction\", 0);\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc(\"label\", 0);\n  CHECK_OR_RETURN(IsIndexDataType(label_desc.data_type()))\n      << Error::TypeError() << \"The dtype of label must be integer, but found \"\n      << DataType_Name(label_desc.data_type());\n  CHECK_EQ_OR_RETURN(dy_desc.data_type(), prediction_desc.data_type())\n      << Error::TypeError() << \"dy and prediction are expected to have the same dtype, but found \"\n      << DataType_Name(dy_desc.data_type()) << \" and \"\n      << DataType_Name(prediction_desc.data_type());\n  ctx->SetOutputDType(\"prediction_diff\", 0, prediction_desc.data_type());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/*static*/ Maybe<void> SparseCrossEntropyOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"prediction\", 0), 0)\n      .Split(user_op::OpArg(\"label\", 0), 0)\n      .Split(user_op::OpArg(\"out\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SparseCrossEntropyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDescFn(ctx);\n}\n/*static*/ Maybe<void> SparseCrossEntropyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SparseCrossEntropyOp::InferDataType(user_op::InferContext* ctx) {\n  return oneflow::InferDataType(ctx);\n}\n/*static*/ Maybe<void> SparseCrossEntropyOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* label_modifier = GetInputArgModifierFn(\"label\", 0);\n  CHECK_OR_RETURN(label_modifier != nullptr);  // NOLINT(maybe-need-error-msg)\n  label_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SparseCrossEntropyMsOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& prediction =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"prediction\", 0);\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"prediction\", 0), 0)\n      .Split(user_op::OpArg(\"label\", 0), 0)\n      .Split(user_op::OpArg(\"out\", 0), 0)\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"prediction\", 0), prediction.shape().NumAxes() - 1)\n      .Broadcast(user_op::OpArg(\"label\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SparseCrossEntropyMsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDescFn(ctx);\n}\n/*static*/ Maybe<void> SparseCrossEntropyMsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SparseCrossEntropyMsOp::InferDataType(user_op::InferContext* ctx) {\n  return oneflow::InferDataType(ctx);\n}\n/*static*/ Maybe<void> 
SparseCrossEntropyMsOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* label_modifier = GetInputArgModifierFn(\"label\", 0);\n  CHECK_OR_RETURN(label_modifier != nullptr);  // NOLINT(maybe-need-error-msg)\n  label_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SparseCrossEntropyGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"prediction\", 0), 0)\n      .Split(user_op::OpArg(\"label\", 0), 0)\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"prediction_diff\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SparseCrossEntropyGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferGradTensorDescFn(ctx);\n}\n/*static*/ Maybe<void> SparseCrossEntropyGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SparseCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataTypeGrad(ctx);\n}\n\n/*static*/ Maybe<void> SparseCrossEntropyMsGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& prediction =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"prediction\", 0);\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"prediction\", 0), 0)\n      .Split(user_op::OpArg(\"label\", 0), 0)\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"prediction_diff\", 0), 0)\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"prediction\", 0), prediction.shape().NumAxes() - 1)\n      .Broadcast(user_op::OpArg(\"label\", 0))\n      .Broadcast(user_op::OpArg(\"dy\", 0))\n      .Split(user_op::OpArg(\"prediction_diff\", 0), prediction.shape().NumAxes() - 1)\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SparseCrossEntropyMsGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferGradTensorDescFn(ctx);\n}\n/*static*/ Maybe<void> SparseCrossEntropyMsGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SparseCrossEntropyMsGradOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataTypeGrad(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferTensorDescFn(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc(\"prediction\", 0);\n  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc(\"label\", 0);\n  CHECK_EQ_OR_RETURN(prediction_desc.is_dynamic(), label_desc.is_dynamic())\n      << Error::RuntimeError()\n      << \"prediction and label are expected to have the same dynamic property, but found \"\n      << prediction_desc.is_dynamic() << \" and \" << label_desc.is_dynamic();\n  CHECK_GE_OR_RETURN(prediction_desc.shape().NumAxes(), 2)\n      << Error::RuntimeError()\n      << \"The dimension of prediction must be greater than or equal to 2, but found \"\n      << prediction_desc.shape().NumAxes();\n  const int64_t num_out_axes = prediction_desc.shape().NumAxes() - 1;\n  CHECK_EQ_OR_RETURN(label_desc.shape().NumAxes(), num_out_axes)\n      << Error::RuntimeError()\n      << \"The dimension of label is expected to be less than that of prediction by 1, but found \"\n      << label_desc.shape().NumAxes() << \" and \" << num_out_axes;\n  FOR_RANGE(int64_t, i, 0, num_out_axes) {\n    CHECK_EQ_OR_RETURN(prediction_desc.shape().At(i), label_desc.shape().At(i))\n        << Error::RuntimeError() << \"The size of prediction (\" << prediction_desc.shape().At(i)\n        << \") must match the size of label (\" << label_desc.shape().At(i) << \") at dimension \" << i;\n  }\n  ctx->SetOutputIsDynamic(\"prob\", 0, prediction_desc.is_dynamic());\n  // 'prob' is just for compute prediction's grad, prob's grad will be ignored\n  ctx->SetOutputShape(\"prob\", 0, prediction_desc.shape());\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_is_dynamic(prediction_desc.is_dynamic());\n  out_desc->set_shape(label_desc.shape());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferGradTensorDescFn(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& prob_desc = ctx->InputTensorDesc(\"prob\", 0);\n  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc(\"label\", 0);\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  CHECK_EQ_OR_RETURN(prob_desc.is_dynamic(), label_desc.is_dynamic())\n      << Error::RuntimeError()\n      << \"prob and label are expected to have the same dynamic property, but found \"\n      << prob_desc.is_dynamic() << \" and \" << label_desc.is_dynamic();\n  CHECK_GE_OR_RETURN(prob_desc.shape().NumAxes(), 2)\n      << Error::RuntimeError()\n      << \"The dimension of prob must be greater than or equal to 2, but found \"\n      << prob_desc.shape().NumAxes();\n  const int64_t num_out_axes = prob_desc.shape().NumAxes() - 1;\n  CHECK_EQ_OR_RETURN(label_desc.shape().NumAxes(), num_out_axes)\n      << Error::RuntimeError()\n      << \"The dimension 
of label is expected to be less than that of prediction by 1, but found \"\n      << label_desc.shape().NumAxes() << \" and \" << num_out_axes;\n  FOR_RANGE(int64_t, i, 0, num_out_axes) {\n    CHECK_EQ_OR_RETURN(prob_desc.shape().At(i), label_desc.shape().At(i))\n        << Error::RuntimeError() << \"The size of prob (\" << prob_desc.shape().At(i)\n        << \") must match the size of label (\" << label_desc.shape().At(i) << \") at dimension \" << i;\n  }\n  CHECK_EQ_OR_RETURN(dy_desc.shape(), label_desc.shape())\n      << Error::RuntimeError() << \"The size of dy \" << dy_desc.shape()\n      << \" must match the size of label \" << label_desc.shape();\n  ctx->SetOutputShape(\"prediction_diff\", 0, prob_desc.shape());\n  ctx->SetOutputIsDynamic(\"prediction_diff\", 0, prob_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc(\"label\", 0);\n  CHECK_OR_RETURN(IsIndexDataType(label_desc.data_type()))\n      << Error::TypeError() << \"The dtype of label must be integer, but found \"\n      << DataType_Name(label_desc.data_type());\n  ctx->SetOutputDType(\"prob\", 0, ctx->InputDType(\"prediction\", 0));\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"prediction\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataTypeGrad(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& prob_desc = ctx->InputTensorDesc(\"prob\", 0);\n  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc(\"label\", 0);\n  CHECK_OR_RETURN(IsIndexDataType(label_desc.data_type()))\n      << Error::TypeError() << \"The dtype of label must be integer, but found \"\n      << DataType_Name(label_desc.data_type());\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  CHECK_EQ_OR_RETURN(dy_desc.data_type(), prob_desc.data_type())\n      << Error::TypeError() << \"dy and prob are expected to have the same dtype, but found \"\n      << DataType_Name(dy_desc.data_type()) << \" and \" << DataType_Name(prob_desc.data_type());\n  ctx->SetOutputDType(\"prediction_diff\", 0, prob_desc.data_type());\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AddSignature(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"prediction\", 0), 0)\n      .Split(user_op::OpArg(\"label\", 0), 0)\n      .Split(user_op::OpArg(\"prob\", 0), 0)\n      .Split(user_op::OpArg(\"out\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AddMsSignature(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& prediction =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"prediction\", 0);\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"prediction\", 0), 0)\n      .Split(user_op::OpArg(\"prob\", 0), 0)\n      .Split(user_op::OpArg(\"label\", 0), 0)\n      .Split(user_op::OpArg(\"out\", 0), 0)\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"prediction\", 0), prediction.shape().NumAxes() - 1)\n      .Split(user_op::OpArg(\"prob\", 0), prediction.shape().NumAxes() - 1)\n      .Broadcast(user_op::OpArg(\"label\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> AddGradSignature(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"label\", 0), 0)\n      .Split(user_op::OpArg(\"prob\", 0), 0)\n      .Split(user_op::OpArg(\"prediction_diff\", 0), 0)\n      .Build();\n  return 
Maybe<void>::Ok();\n}\n\nMaybe<void> AddGradMsSignature(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& prob = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"prob\", 0);\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"prob\", 0), 0)\n      .Split(user_op::OpArg(\"label\", 0), 0)\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"prediction_diff\", 0), 0)\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"prob\", 0), prob.shape().NumAxes() - 1)\n      .Broadcast(user_op::OpArg(\"label\", 0))\n      .Broadcast(user_op::OpArg(\"dy\", 0))\n      .Split(user_op::OpArg(\"prediction_diff\", 0), prob.shape().NumAxes() - 1)\n      .Build();\n  return Maybe<void>::Ok();\n}\n\ntemplate<Maybe<void> (*GetSbpSignature)(user_op::SbpContext*)>\nMaybe<void> GetSbpFn(user_op::SbpContext* ctx) {\n  JUST(GetSbpSignature(ctx));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n#define IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_OP_FUNCS(op_name, sbp_sig)                       \\\n  /*static*/ Maybe<void> op_name##Op::GetSbp(user_op::SbpContext* ctx) { return sbp_sig(ctx); } \\\n  /*static*/ Maybe<void> op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) {      \\\n    return InferTensorDescFn(ctx);                                                              \\\n  }                                                                                             \\\n  /*static*/ Maybe<void> op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) {     \\\n    return InferLogicalTensorDesc(ctx);                                                         \\\n  }                                                                                             \\\n  /*static*/ Maybe<void> op_name##Op::InferDataType(user_op::InferContext* ctx) {               \\\n    return oneflow::InferDataType(ctx);                                                         \\\n  }                                                                                             \\\n  /*static*/ Maybe<void> op_name##Op::ModifyInputArg(                                           \\\n      const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {    \\\n    user_op::InputArgModifier* label_modifier = GetInputArgModifierFn(\"label\", 0);              \\\n    CHECK_OR_RETURN(label_modifier != nullptr); /* NOLINT(maybe-need-error-msg) */              \\\n    label_modifier->set_requires_grad(false);                                                   \\\n    return Maybe<void>::Ok();                                                                   \\\n  }\n\nIMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_OP_FUNCS(SparseSoftmaxCrossEntropy, AddSignature);\nIMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_OP_FUNCS(SparseSoftmaxCrossEntropyMs, AddMsSignature);\n#undef IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_OP_FUNCS\n\n#define IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_OP_FUNCS(op_name, sbp_sig)                  \\\n  /*static*/ Maybe<void> op_name##GradOp::GetSbp(user_op::SbpContext* ctx) {                    \\\n    return sbp_sig(ctx);                                                                        \\\n  }                                                                                             \\\n  /*static*/ Maybe<void> op_name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return InferGradTensorDescFn(ctx);                                                          \\\n  }                                                                       
                      \\\n  /*static*/ Maybe<void> op_name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferLogicalTensorDesc(ctx);                                                         \\\n  }                                                                                             \\\n  /*static*/ Maybe<void> op_name##GradOp::InferDataType(user_op::InferContext* ctx) {           \\\n    return InferDataTypeGrad(ctx);                                                              \\\n  }\n\nIMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_OP_FUNCS(SparseSoftmaxCrossEntropy, AddGradSignature);\nIMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_OP_FUNCS(SparseSoftmaxCrossEntropyMs,\n                                                     AddGradMsSignature);\n#undef IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_OP_FUNCS\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/split_like_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> SplitLikeOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto axis = ctx->Attr<int64_t>(\"axis\");\n  const int64_t in_num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0).shape().NumAxes();\n  const int64_t like_num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"like\", 0).shape().NumAxes();\n  FOR_RANGE(int64_t, i, 0, like_num_axes) {\n    if (i == axis) { continue; }\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  std::vector<user_op::OpArg> like_arg_vec;\n  const size_t like_arg_size = ctx->outputs().size();\n  like_arg_vec.reserve(like_arg_size);\n  FOR_RANGE(int32_t, i, 0, like_arg_size) { like_arg_vec.emplace_back(\"like\", i); }\n  FOR_RANGE(int64_t, i, like_num_axes, in_num_axes) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), i)\n        .Broadcast(like_arg_vec)\n        .Split(ctx->outputs(), i)\n        .Build();\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), i)\n        .PartialSum(like_arg_vec)\n        .Split(ctx->outputs(), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(like_arg_vec)\n      .PartialSum(ctx->outputs())\n      .Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .Broadcast(like_arg_vec)\n      .PartialSum(ctx->outputs())\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"in\", 0))\n      .PartialSum(like_arg_vec)\n      .Broadcast(ctx->outputs())\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SplitLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const auto axis = ctx->Attr<int64_t>(\"axis\");\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  int64_t dynamic_dim_size = 0;\n  int64_t static_dim_size = 0;\n  const int64_t in_num_axes = ctx->InputTensorDesc(\"in\", 0).shape().NumAxes();\n  const int64_t like_num_axes = ctx->InputTensorDesc(\"like\", 0).shape().NumAxes();\n  CHECK_LE_OR_RETURN(like_num_axes, in_num_axes)\n      << Error::RuntimeError() << \"The dimension of like (\" << like_num_axes\n      << \") should be less than or equal to input (\" << in_num_axes << \")\";\n  CHECK_LT_OR_RETURN(axis, like_num_axes)\n      << Error::RuntimeError() << \"The axis (\" << axis\n      << \") should be less than the dimension of like (\" << like_num_axes << \")\";\n  FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) {\n    const user_op::TensorDesc& like_i_desc = ctx->InputTensorDesc(\"like\", i);\n    user_op::TensorDesc* out_i_desc = ctx->MutOutputTensorDesc(\"out\", i);\n    CHECK_EQ_OR_RETURN(like_i_desc.shape().NumAxes(), like_num_axes)\n        << Error::RuntimeError() << \"The dimension of like_i (\" << 
like_i_desc.shape().NumAxes()\n        << \") must match the dimension of the first like (\" << like_num_axes << \")\";\n    FOR_RANGE(int64_t, j, 0, like_num_axes) {\n      if (j == axis) {\n        if (like_i_desc.is_dynamic()) {\n          dynamic_dim_size += like_i_desc.shape().At(j);\n        } else {\n          static_dim_size += like_i_desc.shape().At(j);\n        }\n      } else {\n        CHECK_EQ_OR_RETURN(in_desc.shape().At(j), like_i_desc.shape().At(j))\n            << Error::RuntimeError() << \"The size of input (\" << in_desc.shape().At(j)\n            << \") must match the size of like_i (\" << like_i_desc.shape().At(j) << \") at dimension \"\n            << j;\n      }\n    }\n    DimVector out_i_dim_vec = like_i_desc.shape().dim_vec();\n    FOR_RANGE(int64_t, j, like_num_axes, in_num_axes) {\n      out_i_dim_vec.emplace_back(in_desc.shape().At(j));\n    }\n    out_i_desc->set_shape(Shape(out_i_dim_vec));\n    out_i_desc->set_is_dynamic(like_i_desc.is_dynamic());\n  }\n  if (dynamic_dim_size == 0) {\n    CHECK_EQ_OR_RETURN(static_dim_size, in_desc.shape().At(axis))\n        << Error::RuntimeError() << \"In non-dynamic shape situation, the total size of like (\"\n        << static_dim_size << \") should be equal to the size of input (\" << in_desc.shape().At(axis)\n        << \") at dimension \" << axis;\n  } else {\n    CHECK_LE_OR_RETURN(static_dim_size, in_desc.shape().At(axis))\n        << Error::RuntimeError() << \"In dynamic shape situation, the total size of like (\"\n        << static_dim_size << \") should be less than or equal to the size of input (\"\n        << in_desc.shape().At(axis) << \") at dimension \" << axis;\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SplitLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SplitLikeOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) {\n    user_op::TensorDesc* out_i_desc = ctx->MutOutputTensorDesc(\"out\", i);\n    out_i_desc->set_data_type(in_desc.data_type());\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SplitLikeOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn,\n                                                   const user_op::UserOpConfWrapper& user_op_conf) {\n  FOR_RANGE(int32_t, i, 0, user_op_conf.input_size(\"like\")) {\n    user_op::InputArgModifier* like_modifier = GetInputArgModifierFn(\"like\", i);\n    CHECK_NOTNULL_OR_RETURN(like_modifier);  // NOLINT(maybe-need-error-msg)\n    like_modifier->set_requires_grad(false);\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SplitLikeOp::CheckAttr(const user_op::UserOpDefWrapper&,\n                                              const user_op::UserOpConfWrapper& op_conf) {\n  CHECK_OR_RETURN(op_conf.input_size(\"like\") >= 1)\n      << Error::RuntimeError() << \"The number of like should be greater than or equal to 1\";\n  CHECK_OR_RETURN(op_conf.output_size(\"out\") >= 1)\n      << Error::RuntimeError() << \"The number of output should be greater than or equal to 1\";\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/sqrt_square_sum_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> SqrtSquareSumOp::GetSbp(user_op::SbpContext* ctx) {\n  const int64_t num_x_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape().NumAxes();\n  FOR_RANGE(int64_t, i, 0, num_x_axes) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).PartialSum(user_op::OpArg(\"y\", 0)).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SqrtSquareSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  y->set_shape(Shape({}));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SqrtSquareSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SqrtSquareSumOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/square_relu_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> SquareReLUOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"y\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SquareReLUOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> SquareReLUOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SquareReLUOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Split(user_op::OpArg(\"y\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SquareReLUGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == x_shape)\n      << \"InferTensorDesc failed (\" << ctx->op_name() << \"). Expected x shape \"\n      << x_shape.ToString() << \" to be equal to dy shape \" << dy_shape.ToString();\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SquareReLUGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> SquareReLUGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"x\", 0), ctx->InputDType(\"dy\", 0))\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"dy\", 0))\n      << \", but got \" << DataType_Name(ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SquareReLUGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/square_sum_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> SquareSumOp::GetSbp(user_op::SbpContext* ctx) {\n  const int64_t num_x_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape().NumAxes();\n  FOR_RANGE(int64_t, i, 0, num_x_axes) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).PartialSum(user_op::OpArg(\"y\", 0)).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SquareSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  y->set_shape(Shape({1}));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SquareSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SquareSumOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> MultiSquareSumOp::GetSbp(user_op::SbpContext* ctx) {\n  int64_t min_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape().NumAxes();\n  for (int64_t i = 1; i < ctx->user_op_conf().input_size(\"x\"); ++i) {\n    min_num_axes = std::min(min_num_axes,\n                            ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", i).shape().NumAxes());\n  }\n  for (int64_t i = 0; i < min_num_axes; ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(user_op::OpArg(\"y\", 0)).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> MultiSquareSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  y->set_shape(Shape({1}));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> MultiSquareSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> MultiSquareSumOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x_0 = ctx->InputTensorDesc(\"x\", 0);\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  for (int64_t i = 1; i < ctx->input_size(\"x\"); ++i) {\n    const user_op::TensorDesc& x_i = ctx->InputTensorDesc(\"x\", i);\n    CHECK_EQ_OR_RETURN(x_i.data_type(), x_0.data_type())\n        << Error::TypeError()\n        << \"All tensors are expected to have the same dtype, but found at least two dtypes, \"\n        << DataType_Name(x_i.data_type()) << \" and \" << DataType_Name(x_0.data_type());\n  }\n  y->set_data_type(x_0.data_type());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> MultiSquareSumOp::CheckAttr(const user_op::UserOpDefWrapper&,\n                                                   const user_op::UserOpConfWrapper& op_conf) {\n  CHECK_OR_RETURN(op_conf.input_size(\"x\") >= 1)\n      << Error::RuntimeError() << \"The number of x 
should be greater than or equal to 1\";\n  return Maybe<void>::Ok();\n}\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/squeeze_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> TransformNegativeAxesToPositive(const std::vector<int32_t>& axes_vec,\n                                            const int32_t num_axes, AxisVector* fixed_axes_vec) {\n  fixed_axes_vec->resize(axes_vec.size());\n  FOR_RANGE(size_t, i, 0, fixed_axes_vec->size()) {\n    CHECK_GE_OR_RETURN(axes_vec[i], -num_axes);\n    CHECK_LT_OR_RETURN(axes_vec[i], num_axes);\n    fixed_axes_vec->at(i) = axes_vec[i] >= 0 ? axes_vec[i] : axes_vec[i] + num_axes;\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckAndLabelAxesToSqueezeMinusOne(const AxisVector& axes, DimVector* dim_vec) {\n  for (const auto& axis : axes) {\n    CHECK_EQ_OR_RETURN(dim_vec->at(axis), 1);\n    dim_vec->at(axis) = -1;\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/*static*/ Maybe<void> SqueezeOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  AxisVector fixed_axes_vec;\n  JUST(TransformNegativeAxesToPositive(ctx->Attr<std::vector<int32_t>>(\"axes\"),\n                                       in_tensor.shape().NumAxes(), &fixed_axes_vec));\n\n  DimVector dim_vec = in_tensor.shape().dim_vec();\n  JUST(CheckAndLabelAxesToSqueezeMinusOne(fixed_axes_vec, &dim_vec));\n  int32_t out_axis = 0;\n  FOR_RANGE(int32_t, in_axis, 0, dim_vec.size()) {\n    if (dim_vec.at(in_axis) != -1) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"in\", 0), in_axis)\n          .Split(user_op::OpArg(\"out\", 0), out_axis)\n          .Build();\n      ++out_axis;\n    }\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SqueezeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"in\", 0);\n  AxisVector fixed_axes_vec;\n  JUST(TransformNegativeAxesToPositive(ctx->Attr<std::vector<int32_t>>(\"axes\"), in_shape.NumAxes(),\n                                       &fixed_axes_vec));\n\n  DimVector dim_vec = in_shape.dim_vec();\n  JUST(CheckAndLabelAxesToSqueezeMinusOne(fixed_axes_vec, &dim_vec));\n  dim_vec.erase(std::remove(dim_vec.begin(), dim_vec.end(), -1), dim_vec.end());\n  ctx->SetOutputShape(\"out\", 0, Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SqueezeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SqueezeOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/ssp_variable_proxy_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> SspVariableProxyOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto& var_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"var\", 0);\n  FOR_RANGE(int64_t, i, 0, var_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"var\", 0), i)\n        .Split(user_op::OpArg(\"ref\", 0), i)\n        .Split(user_op::OpArg(\"value\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SspVariableProxyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& var_shape = ctx->InputShape(\"var\", 0);\n  ctx->SetOutputShape(\"ref\", 0, var_shape);\n  ctx->SetOutputShape(\"value\", 0, var_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SspVariableProxyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SspVariableProxyOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"ref\", 0, ctx->InputDType(\"var\", 0));\n  ctx->SetOutputDType(\"value\", 0, ctx->InputDType(\"var\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> SspVariableProxyOp::ModifyOutputArg(\n    const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn(\"ref\", 0);\n  CHECK_OR_RETURN(out_modifier != nullptr);\n  out_modifier->set_is_mutable(true);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/stack_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> StackOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc(\"in\", 0);\n  const int64_t axis = ctx->Attr<int64_t>(\"axis\");\n  CHECK_GE_OR_RETURN(axis, 0) << \"The axis should be greater than or equal to 0.\";\n  const int64_t in_num_axes = first_in_desc.shape().NumAxes();\n  CHECK_LE_OR_RETURN(axis, in_num_axes)\n      << \"The axis should be less than or equal to input num axes.\";\n  DimVector out_dim_vec(in_num_axes + 1);\n  for (int i = 0; i < in_num_axes + 1; i++) {\n    if (i == axis) {\n      continue;\n    } else if (i < axis) {\n      out_dim_vec.at(i) = first_in_desc.shape().At(i);\n    } else {\n      out_dim_vec.at(i) = first_in_desc.shape().At(i - 1);\n    }\n  }\n  int64_t dynamic_dim_size = 0;\n  for (const auto& in_arg_pair : ctx->inputs()) {\n    const user_op::TensorDesc& in_desc =\n        ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second);\n    CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), first_in_desc.shape().NumAxes())\n        << \"The num axes of input should be equal to first input's num axes. 
\";\n    FOR_RANGE(int64_t, i, 0, in_num_axes + 1) {\n      if (i == axis) {\n        if (in_desc.is_dynamic()) {\n          dynamic_dim_size += 1;\n        } else {\n          out_dim_vec.at(axis) += 1;\n        }\n      } else if (i < axis) {\n        CHECK_EQ_OR_RETURN(in_desc.shape().At(i), out_dim_vec.at(i))\n            << \"The input shape at axis \" << i << \" is not equal to out shape at axis \" << i;\n      } else {\n        CHECK_EQ_OR_RETURN(in_desc.shape().At(i - 1), out_dim_vec.at(i))\n            << \"The input shape at axis \" << i - 1 << \" is not equal to out shape at axis \" << i;\n      }\n    }\n  }\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  const int64_t max_dim_size = ctx->Attr<int64_t>(\"max_dim_size\");\n  CHECK_LE_OR_RETURN(out_dim_vec.at(axis), max_dim_size)\n      << \"The out shape at axis \" << axis << \" should be less equal to \" << max_dim_size;\n  if (dynamic_dim_size == 0) {\n    out_desc->set_is_dynamic(false);\n  } else {\n    out_desc->set_is_dynamic(true);\n    out_dim_vec.at(axis) = max_dim_size;\n  }\n  out_desc->set_shape(Shape(out_dim_vec));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> StackOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> StackOp::GetSbp(user_op::SbpContext* ctx) {\n  const int64_t axis = ctx->Attr<int64_t>(\"axis\");\n  const user_op::TensorDesc& first_in_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, first_in_desc.shape().NumAxes()) {\n    /*\n    Stack can be view as expand_dims + concat.\n    For stack([(2, 4, 6), (2, 4, 6), axis=1]), it equals to [2, 4, 6]->[2, 1, 4, 6]. concat([2, 1,\n    4, 6], [2, 1, 4, 6], concat_dim=1) Concat split all the axis except the concat_dim.\n    */\n    if (i >= axis) {\n      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i + 1).Build();\n    } else {\n      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n    }\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> StackOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc(\"in\", 0);\n  for (const auto& in_arg_pair : ctx->inputs()) {\n    const user_op::TensorDesc& in_desc =\n        ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second);\n    CHECK_EQ_OR_RETURN(in_desc.data_type(), first_in_desc.data_type())\n        << \"InferDataType Failed. Expected \" << DataType_Name(first_in_desc.data_type())\n        << \", but got \" << DataType_Name(in_desc.data_type());\n  }\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  out_desc->set_data_type(first_in_desc.data_type());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> StackOp::CheckAttr(const user_op::UserOpDefWrapper&,\n                                          const user_op::UserOpConfWrapper& op_conf) {\n  CHECK_OR_RETURN(op_conf.input_size(\"in\") >= 1)\n      << \"The size of input should be greater than or equal to 1. 
\";\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> StackGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const auto axis = ctx->Attr<int64_t>(\"axis\");\n  const int64_t like_num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"like\", 0).shape().NumAxes();\n  std::vector<user_op::OpArg> like_arg_vec;\n  const size_t like_arg_size = ctx->outputs().size();\n  like_arg_vec.reserve(like_arg_size);\n  FOR_RANGE(int32_t, i, 0, like_arg_size) { like_arg_vec.emplace_back(\"like\", i); }\n  FOR_RANGE(int64_t, i, 0, like_num_axes) {\n    if (i >= axis) {\n      ctx->NewBuilder()\n          .Split(like_arg_vec, i)\n          .Split(ctx->outputs(), i)\n          .Split(user_op::OpArg(\"in\", 0), i + 1)\n          .Build();\n    } else {\n      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n    }\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(like_arg_vec)\n      .PartialSum(ctx->outputs())\n      .Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .Broadcast(like_arg_vec)\n      .PartialSum(ctx->outputs())\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"in\", 0))\n      .PartialSum(like_arg_vec)\n      .Broadcast(ctx->outputs())\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> StackGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const auto axis = ctx->Attr<int64_t>(\"axis\");\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  int64_t dynamic_dim_size = 0;\n  int64_t static_dim_size = 0;\n  const int64_t in_num_axes = ctx->InputTensorDesc(\"in\", 0).shape().NumAxes();\n  const int64_t like_num_axes = ctx->InputTensorDesc(\"like\", 0).shape().NumAxes();\n  CHECK_LE_OR_RETURN(like_num_axes, in_num_axes)\n      << \"The num axes of `like` tensor should be less equal to num axes of `in` tensor. \";\n  CHECK_LE_OR_RETURN(axis, like_num_axes)\n      << \"The axis should be less equal than num axes of `like` tensor. \";\n  FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) {\n    const user_op::TensorDesc& like_i_desc = ctx->InputTensorDesc(\"like\", i);\n    user_op::TensorDesc* out_i_desc = ctx->MutOutputTensorDesc(\"out\", i);\n    CHECK_EQ_OR_RETURN(like_i_desc.shape().NumAxes(), like_num_axes)\n        << \"The num axes of `like` tensor at index \" << i\n        << \" should be equal to first `like` tensor. \";\n    FOR_RANGE(int64_t, j, 0, like_num_axes + 1) {\n      if (j == axis) {\n        if (like_i_desc.is_dynamic()) {\n          dynamic_dim_size += like_i_desc.shape().Count(j);\n        } else {\n          static_dim_size += like_i_desc.shape().Count(j);\n        }\n      } else if (j < axis) {\n        CHECK_EQ_OR_RETURN(in_desc.shape().At(j), like_i_desc.shape().At(j))\n            << \" Stack Grad expects the shape of input tensor is equal to like tensor's. \"\n               \", but got \"\n            << in_desc.shape().ToString() << \" at input and \" << like_i_desc.shape().ToString()\n            << \"at like \";\n      } else {\n        CHECK_EQ_OR_RETURN(in_desc.shape().At(j), like_i_desc.shape().At(j - 1))\n            << \" Stack Grad expects the shape of input tensor is equal to like tensor's. 
\"\n               \", but got \"\n            << in_desc.shape().ToString() << \" at input and \" << like_i_desc.shape().ToString()\n            << \"at like \";\n      }\n    }\n    DimVector out_i_dim_vec = like_i_desc.shape().dim_vec();\n    out_i_desc->set_shape(Shape(out_i_dim_vec));\n    out_i_desc->set_is_dynamic(like_i_desc.is_dynamic());\n  }\n  if (dynamic_dim_size == 0) {\n    CHECK_EQ_OR_RETURN(static_dim_size, in_desc.shape().Count(axis))\n        << \"In non dynamic shape situation, the static dim size should be equal to input tensor \"\n           \"size. \";\n  } else {\n    CHECK_LE_OR_RETURN(static_dim_size, in_desc.shape().Count(axis))\n        << \"In dynamic shape situation, the static dim size should be less equal to input tensor \"\n           \"size. \";\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> StackGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> StackGradOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) {\n    user_op::TensorDesc* out_i_desc = ctx->MutOutputTensorDesc(\"out\", i);\n    out_i_desc->set_data_type(in_desc.data_type());\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> StackGradOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn,\n                                                   const user_op::UserOpConfWrapper& user_op_conf) {\n  FOR_RANGE(int32_t, i, 0, user_op_conf.input_size(\"like\")) {\n    user_op::InputArgModifier* like_modifier = GetInputArgModifierFn(\"like\", i);\n    CHECK_NOTNULL_OR_RETURN(like_modifier);\n    like_modifier->set_requires_grad(false);\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> StackGradOp::CheckAttr(const user_op::UserOpDefWrapper&,\n                                              const user_op::UserOpConfWrapper& op_conf) {\n  CHECK_OR_RETURN(op_conf.input_size(\"like\") >= 1)\n      << \"The count of like tensor should be greater than or equal to 1. \";\n  CHECK_OR_RETURN(op_conf.output_size(\"out\") >= 1)\n      << \"The count of out tensor should be greater than or equal to 1. \";\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/stft_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nconst Stride InferOutputStride(const Shape& in_shape, bool onesided = true,\n                               bool return_complex = false) {\n  // TODO(yzm):support return_complex\n  int last_dim_size = in_shape.At(2);\n  if (onesided) { last_dim_size = last_dim_size / 2 + 1; }\n  Stride out_stride(in_shape.NumAxes(), 0);\n  if (in_shape.At(0) == 1) {\n    out_stride = {2, 2 * last_dim_size, 1};\n  } else {\n    out_stride = {last_dim_size * 2 * in_shape.At(1), 2, 2 * last_dim_size, 1};\n  }\n  return out_stride;\n}\n\nconst Shape InferOutputShape(const Shape& in_shape, bool onesided = true,\n                             bool return_complex = false) {\n  // TODO(yzm):support return_complex\n  Shape out_shape;\n  int last_dim_size = in_shape.At(2);\n  if (onesided) { last_dim_size = last_dim_size / 2 + 1; }\n  if (in_shape.At(0) == 1) {\n    out_shape = {last_dim_size, in_shape.At(1), 2};\n  } else {\n    out_shape = {in_shape.At(0), last_dim_size, in_shape.At(1), 2};\n  }\n  return out_shape;\n}\n\n/* static */ Maybe<void> StftOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"input\", 0);\n  const bool onesided = ctx->Attr<bool>(\"onesided\");\n\n  const Stride& out_stride = InferOutputStride(in_shape, onesided);\n  const Shape& out_shape = InferOutputShape(in_shape, onesided);\n\n  ctx->SetOutputStride(\"output\", 0, out_stride);\n  ctx->SetOutputShape(\"output\", 0, out_shape);\n  ctx->SetOutputIsDynamic(\"output\", 0, ctx->InputIsDynamic(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> StftOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> StftOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> StftOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"output\", 0, ctx->InputDType(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<double> StftOp::GetComputeComplexity(user_op::ComputeComplexityFnContext* ctx) {\n  // TODO: add ComputeComplexityFun\n  return 0.0;\n}\n\n}  // namespace oneflow"
  },
  {
    "path": "oneflow/user/ops/summary_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> CheckStepShape(const Shape* step) {\n  CHECK_OR_RETURN(step->elem_cnt() == 1);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckStepShapeInCtx(user_op::InferContext* ctx) {\n  JUST(CheckStepShape(&ctx->InputShape(\"step\", 0)));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> CheckInAndStepScalar(user_op::InferContext* ctx) {\n  const Shape& in_shape = ctx->InputShape(\"in\", 0);\n  const Shape& step_shape = ctx->InputShape(\"step\", 0);\n  CHECK_OR_RETURN(in_shape.elem_cnt() == 1 && step_shape.elem_cnt() == 1);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/*static*/ Maybe<void> CreateSummaryWriterOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n/*static*/ Maybe<void> CreateSummaryWriterOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> CreateSummaryWriterOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> CreateSummaryWriterOp::InferDataType(user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FlushSummaryWriterOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n/*static*/ Maybe<void> FlushSummaryWriterOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> FlushSummaryWriterOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> FlushSummaryWriterOp::InferDataType(user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SummaryWriteScalarOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n/*static*/ Maybe<void> SummaryWriteScalarOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return CheckInAndStepScalar(ctx);\n}\n/*static*/ Maybe<void> SummaryWriteScalarOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SummaryWriteScalarOp::InferDataType(user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SummaryWriteHistogramOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n/*static*/ Maybe<void> SummaryWriteHistogramOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return CheckStepShapeInCtx(ctx);\n}\n/*static*/ Maybe<void> SummaryWriteHistogramOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SummaryWriteHistogramOp::InferDataType(user_op::InferContext* ctx) {\n  return 
Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SummaryWritePbOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n/*static*/ Maybe<void> SummaryWritePbOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return CheckStepShapeInCtx(ctx);\n}\n/*static*/ Maybe<void> SummaryWritePbOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SummaryWritePbOp::InferDataType(user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> SummaryWriteImageOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n/*static*/ Maybe<void> SummaryWriteImageOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return CheckStepShapeInCtx(ctx);\n}\n/*static*/ Maybe<void> SummaryWriteImageOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> SummaryWriteImageOp::InferDataType(user_op::InferContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/tanh_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> TanhOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);\n}\n/*static*/ Maybe<void> TanhOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return user_op::TensorDescInferFnUtil::Unchanged(ctx);\n}\n/*static*/ Maybe<void> TanhOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> TanhOp::InferDataType(user_op::InferContext* ctx) {\n  return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx);\n}\n\n/*static*/ Maybe<void> TanhGradOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);\n}\n/*static*/ Maybe<void> TanhGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return user_op::TensorDescInferFnUtil::Unchanged(ctx);\n}\n/*static*/ Maybe<void> TanhGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> TanhGradOp::InferDataType(user_op::InferContext* ctx) {\n  return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/tensor_buffer_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> TensorBufferToTensorOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TensorBufferToTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_is_dynamic(in.is_dynamic());\n  const auto& instance_shape = ctx->Attr<Shape>(\"instance_shape\");\n  DimVector dim_vec;\n  dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), in.shape().dim_vec().cend());\n  dim_vec.insert(dim_vec.end(), instance_shape.dim_vec().cbegin(), instance_shape.dim_vec().cend());\n  out->set_shape(Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TensorBufferToTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> TensorBufferToTensorOp::InferDataType(user_op::InferContext* ctx) {\n  const auto data_type = ctx->Attr<DataType>(\"dtype\");\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_OR_RETURN(IsTriviallyCopyableDataType(data_type));\n  out->set_data_type(data_type);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> TensorToTensorBufferOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  const auto& instance_dims = ctx->Attr<int32_t>(\"instance_dims\");\n  CHECK_LE_OR_RETURN(instance_dims, in.shape().NumAxes());\n  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - instance_dims) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TensorToTensorBufferOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  const Shape& in_shape = in.shape();\n  const auto& instance_dims = ctx->Attr<int32_t>(\"instance_dims\");\n  CHECK_LT_OR_RETURN(instance_dims, in_shape.NumAxes());\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_is_dynamic(in.is_dynamic());\n  DimVector out_dim_vec;\n  out_dim_vec.insert(out_dim_vec.end(), in_shape.dim_vec().cbegin(),\n                     in_shape.dim_vec().cend() - instance_dims);\n  out->set_shape(Shape(out_dim_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TensorToTensorBufferOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return 
InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> TensorToTensorBufferOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_OR_RETURN(IsTriviallyCopyableDataType(in.data_type()));\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_data_type(DataType::kTensorBuffer);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> GenTensorBufferOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n/*static*/ Maybe<void> GenTensorBufferOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  const Shape& shape = ctx->Attr<Shape>(\"shape\");\n  const int64_t num_tensor_buffers = shape.elem_cnt();\n  const std::vector<Shape>& shape_list = ctx->Attr<std::vector<Shape>>(\"shape_list\");\n  const std::vector<float>& value_list = ctx->Attr<std::vector<float>>(\"value_list\");\n  CHECK_EQ_OR_RETURN(num_tensor_buffers, shape_list.size());\n  CHECK_EQ_OR_RETURN(num_tensor_buffers, value_list.size());\n  out->set_shape(shape);\n  out->set_is_dynamic(ctx->Attr<bool>(\"dynamic_out\"));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> GenTensorBufferOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> GenTensorBufferOp::InferDataType(user_op::InferContext* ctx) {\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_data_type(DataType::kTensorBuffer);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> TensorBufferToListOfTensorsOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n/*static*/ Maybe<void> TensorBufferToListOfTensorsOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_GT_OR_RETURN(in.shape().elem_cnt(), 0);\n  CHECK_OR_RETURN(!in.is_dynamic());\n  const Shape& out_shape = ctx->Attr<Shape>(\"out_shape\");\n  const bool dynamic_out = ctx->Attr<bool>(\"dynamic_out\");\n  int64_t num_tensor_buffers = in.shape().elem_cnt();\n  for (int64_t i = 0; i < num_tensor_buffers; ++i) {\n    user_op::TensorDesc* out_i = ctx->MutOutputTensorDesc(\"out\", i);\n    out_i->set_shape(out_shape);\n    out_i->set_is_dynamic(dynamic_out);\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TensorBufferToListOfTensorsOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> TensorBufferToListOfTensorsOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_EQ_OR_RETURN(in.data_type(), DataType::kTensorBuffer)\n      << \"InferDataType Failed. 
Expected \" << DataType_Name(DataType::kTensorBuffer) << \", but got \"\n      << DataType_Name(in.data_type());\n  const DataType out_dtype = ctx->Attr<DataType>(\"out_dtype\");\n  CHECK_OR_RETURN(IsTriviallyCopyableDataType(out_dtype));\n  int64_t num_tensor_buffers = ctx->outputs().size();\n  for (int64_t i = 0; i < num_tensor_buffers; ++i) {\n    user_op::TensorDesc* out_i = ctx->MutOutputTensorDesc(\"out\", i);\n    out_i->set_data_type(out_dtype);\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TensorBufferToListOfTensorsOp::ModifyOutputArg(\n    const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  if (conf.attr<bool>(\"dynamic_out\")) {\n    FOR_RANGE(int64_t, i, 0, conf.output_size(\"out\")) {\n      user_op::OutputArgModifier* out_i_modifier = GetOutputArgModifierFn(\"out\", i);\n      CHECK_OR_RETURN(out_i_modifier != nullptr);\n      out_i_modifier->set_header_infered_before_compute(false);\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> TensorBufferToListOfTensorsOp::CheckAttr(\n    const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) {\n  CHECK_OR_RETURN(op_conf.output_size(\"out\") >= 1);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> TensorBufferToListOfTensorsV2Op::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n/*static*/ Maybe<void> TensorBufferToListOfTensorsV2Op::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_GT_OR_RETURN(in.shape().elem_cnt(), 0);\n  CHECK_OR_RETURN(!in.is_dynamic());\n  const std::vector<Shape>& out_shapes = ctx->Attr<std::vector<Shape>>(\"out_shapes\");\n  const bool dynamic_out = ctx->Attr<bool>(\"dynamic_out\");\n  int64_t num_tensor_buffers = in.shape().elem_cnt();\n  for (int64_t i = 0; i < num_tensor_buffers; ++i) {\n    user_op::TensorDesc* out_i = ctx->MutOutputTensorDesc(\"out\", i);\n    out_i->set_shape(out_shapes[i]);\n    out_i->set_is_dynamic(dynamic_out);\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TensorBufferToListOfTensorsV2Op::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> TensorBufferToListOfTensorsV2Op::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  CHECK_EQ_OR_RETURN(in.data_type(), DataType::kTensorBuffer)\n      << \"InferDataType Failed. 
Expected \" << DataType_Name(DataType::kTensorBuffer) << \", but got \"\n      << DataType_Name(in.data_type());\n  const std::vector<DataType>& out_dtypes = ctx->Attr<std::vector<DataType>>(\"out_dtypes\");\n  int64_t num_tensor_buffers = ctx->outputs().size();\n  for (int64_t i = 0; i < num_tensor_buffers; ++i) {\n    CHECK_OR_RETURN(IsTriviallyCopyableDataType(out_dtypes[i]));\n    user_op::TensorDesc* out_i = ctx->MutOutputTensorDesc(\"out\", i);\n    out_i->set_data_type(out_dtypes[i]);\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TensorBufferToListOfTensorsV2Op::ModifyOutputArg(\n    const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) {\n  if (conf.attr<bool>(\"dynamic_out\")) {\n    FOR_RANGE(int64_t, i, 0, conf.output_size(\"out\")) {\n      user_op::OutputArgModifier* out_i_modifier = GetOutputArgModifierFn(\"out\", i);\n      CHECK_OR_RETURN(out_i_modifier != nullptr);\n      out_i_modifier->set_header_infered_before_compute(false);\n    }\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TensorBufferToListOfTensorsV2Op::CheckAttr(\n    const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) {\n  CHECK_OR_RETURN(op_conf.output_size(\"out\") >= 1);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/tensor_constant_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/job/nd_sbp_util.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> TensorConstantOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, Shape(ctx->Attr<Shape>(\"shape\").dim_vec()));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> TensorConstantOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();\n  const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  const Shape& logical_shape = ctx->Attr<Shape>(\"shape\");\n  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();\n  const auto tensor_slice_view =\n      GetTensorSliceView4ParallelId(parallel_hierarchy, nd_sbp, logical_shape, parallel_id);\n  const Shape& physical_shape = tensor_slice_view.shape();\n\n  ctx->SetOutputShape(\"out\", 0, physical_shape);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> TensorConstantOp::GetSbp(user_op::SbpContext* ctx) {\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> TensorConstantOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  SbpParallel default_sbp;\n  default_sbp.mutable_broadcast_parallel();\n  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);\n}\n\n/* static */ Maybe<void> TensorConstantOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->Attr<DataType>(\"dtype\"));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/tf_pool_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/utils/pool_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\n// Logically computation cost of pool op is the product of output data amount and pool kernal data\n// amount. After adding sbp, we just divide it by parallel number if output data is splitted because\n// splitting input and using partial sum for output is not a valid sbp for this op for now.\nMaybe<double> GetComputationCost(user_op::ComputeComplexityFnContext* ctx) {\n  const std::vector<int32_t> pool_size = ctx->Attr<std::vector<int32_t>>(\"pool_size\");\n  double logical_computation_cost =\n      std::accumulate(pool_size.begin(), pool_size.end(),\n                      ctx->Shape4ArgNameAndIndex(\"y\", 0).elem_cnt(), std::multiplies<double>());\n  const auto& parallel_hierarchy = ctx->parallel_desc().hierarchy();\n  const auto& nd_sbp_y = ctx->NdSbp4ArgNameAndIndex(\"y\", 0);\n  for (int32_t dim_sbp = 0; dim_sbp < nd_sbp_y.sbp_parallel_size(); dim_sbp++) {\n    if (nd_sbp_y.sbp_parallel(dim_sbp).has_split_parallel()) {\n      logical_computation_cost /= parallel_hierarchy->At(dim_sbp);\n    }\n  }\n  return logical_computation_cost;\n}\n\ntypedef std::function<Maybe<void>(user_op::InferContext* ctx)> TensorDescInferFn;\n\nTensorDescInferFn MakeFwTensorDescInferFn(const int32_t dim) {\n  return [dim](user_op::InferContext* ctx) -> Maybe<void> {\n    const Shape& x_shape = ctx->InputShape(\"x\", 0);\n    const std::string& data_format = ctx->Attr<std::string>(\"data_format\");\n    const std::string& padding = ctx->Attr<std::string>(\"padding\");\n    const auto& padding_before = ctx->Attr<std::vector<int32_t>>(\"padding_before\");\n    const auto& padding_after = ctx->Attr<std::vector<int32_t>>(\"padding_after\");\n    const std::vector<int32_t> pool_size = ctx->Attr<std::vector<int32_t>>(\"pool_size\");\n    const std::vector<int32_t> strides = ctx->Attr<std::vector<int32_t>>(\"strides\");\n    const bool ceil_mode = ctx->Attr<bool>(\"ceil_mode\");\n\n    CHECK_EQ_OR_RETURN(pool_size.size(), dim);\n    for (int32_t pool_dim : pool_size) { CHECK_GT_OR_RETURN(pool_dim, 0); }\n    CHECK_EQ_OR_RETURN(strides.size(), dim);\n    for (int32_t stride_dim : strides) { CHECK_GT_OR_RETURN(stride_dim, 0); }\n\n    const Params3D params_3d(dim, x_shape, data_format, padding, padding_before, padding_after,\n                             pool_size, strides, ceil_mode);\n    user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc(\"y\", 0);\n    y_desc->set_shape(params_3d.GetYShape());\n    y_desc->set_is_dynamic(ctx->InputIsDynamic(\"x\", 0));\n    return Maybe<void>::Ok();\n  };\n}\n\nMaybe<void> BwTensorDescInferFn(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"x\", 0));\n  ctx->SetOutputIsDynamic(\"dx\", 0, ctx->InputIsDynamic(\"x\", 0));\n  return 
Maybe<void>::Ok();\n}\n\nMaybe<void> FwInferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BwInferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FwGetSbpFn(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Split(user_op::OpArg(\"y\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> BwGetSbpFn(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"y\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n#define IMPLEMENT_TF_POOL_FUNCS(name, dim)                                                      \\\n  /*static*/ Maybe<void> name##Op::GetSbp(user_op::SbpContext* ctx) { return FwGetSbpFn(ctx); } \\\n  /*static*/ Maybe<void> name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) {         \\\n    return MakeFwTensorDescInferFn(dim)(ctx);                                                   \\\n  }                                                                                             \\\n  /*static*/ Maybe<void> name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) {        \\\n    return InferLogicalTensorDesc(ctx);                                                         \\\n  }                                                                                             \\\n  /*static*/ Maybe<void> name##Op::InferDataType(user_op::InferContext* ctx) {                  \\\n    return FwInferDataType(ctx);                                                                \\\n  }                                                                                             \\\n  /*static*/ Maybe<double> name##Op::GetComputeComplexity(                                      \\\n      user_op::ComputeComplexityFnContext* ctx) {                                               \\\n    return GetComputationCost(ctx);                                                             \\\n  }\n\nIMPLEMENT_TF_POOL_FUNCS(TfAvgPool1D, 1)\nIMPLEMENT_TF_POOL_FUNCS(TfAvgPool2D, 2)\nIMPLEMENT_TF_POOL_FUNCS(TfAvgPool3D, 3)\nIMPLEMENT_TF_POOL_FUNCS(TfMaxPool1D, 1)\nIMPLEMENT_TF_POOL_FUNCS(TfMaxPool2D, 2)\nIMPLEMENT_TF_POOL_FUNCS(TfMaxPool3D, 3)\n#undef IMPLEMENT_TF_POOL_FUNCS\n\n#define IMPLEMENT_TF_POOL_BACKWARD_FUNCS(name)                                               \\\n  /*static*/ Maybe<void> name##GradOp::GetSbp(user_op::SbpContext* ctx) {                    \\\n    return BwGetSbpFn(ctx);                                                                  \\\n  }                                                                                          \\\n  /*static*/ Maybe<void> name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return BwTensorDescInferFn(ctx);                                                         \\\n  }                                                                                          \\\n  /*static*/ Maybe<void> 
name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferLogicalTensorDesc(ctx);                                                      \\\n  }                                                                                          \\\n  /*static*/ Maybe<void> name##GradOp::InferDataType(user_op::InferContext* ctx) {           \\\n    return BwInferDataType(ctx);                                                             \\\n  }                                                                                          \\\n  /*static*/ Maybe<double> name##GradOp::GetComputeComplexity(                               \\\n      user_op::ComputeComplexityFnContext* ctx) {                                            \\\n    return GetComputationCost(ctx);                                                          \\\n  }\n\nIMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfAvgPool1D)\nIMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfAvgPool2D)\nIMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfAvgPool3D)\nIMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfMaxPool1D)\nIMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfMaxPool2D)\nIMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfMaxPool3D)\n#undef IMPLEMENT_TF_POOL_BACKWARD_FUNCS\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/tf_prelu_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> TfPreluOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  const user_op::TensorDesc& alpha_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"alpha\", 0);\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Broadcast(user_op::OpArg(\"alpha\", 0))\n      .Split(user_op::OpArg(\"y\", 0), 0)\n      .Build();\n  FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) {\n    if (x_tensor.shape().At(i) == alpha_tensor.shape().At(i - 1)) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"x\", 0), i)\n          .Split(user_op::OpArg(\"alpha\", 0), i - 1)\n          .Split(user_op::OpArg(\"y\", 0), i)\n          .Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TfPreluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc(\"y\", 0);\n  const Shape& alpha_shape = ctx->InputShape(\"alpha\", 0);\n  CHECK_EQ_OR_RETURN(x_desc.shape().NumAxes(), alpha_shape.NumAxes() + 1);\n  FOR_RANGE(int64_t, i, 1, x_desc.shape().NumAxes()) {\n    CHECK_OR_RETURN((alpha_shape.At(i - 1) == x_desc.shape().At(i))\n                    || (alpha_shape.At(i - 1) == 1));\n  }\n  y_desc->set_shape(x_desc.shape());\n  y_desc->set_is_dynamic(x_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TfPreluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> TfPreluOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> TfPreluGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  const user_op::TensorDesc& alpha_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"alpha\", 0);\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Broadcast(user_op::OpArg(\"alpha\", 0))\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .PartialSum(user_op::OpArg(\"alpha_diff\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"dy\", 0))\n      .Broadcast(user_op::OpArg(\"x\", 0))\n      .Broadcast(user_op::OpArg(\"alpha\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .PartialSum(user_op::OpArg(\"alpha_diff\", 0))\n      .Build();\n  FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) {\n    if (x_tensor.shape().At(i) == alpha_tensor.shape().At(i - 1)) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"dy\", 0), i)\n    
      .Split(user_op::OpArg(\"x\", 0), i)\n          .Split(user_op::OpArg(\"alpha\", 0), i - 1)\n          .Split(user_op::OpArg(\"dx\", 0), i)\n          .Split(user_op::OpArg(\"alpha_diff\", 0), i - 1)\n          .Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TfPreluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc(\"dy\", 0);\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  const user_op::TensorDesc& alpha_desc = ctx->InputTensorDesc(\"alpha\", 0);\n  CHECK_EQ_OR_RETURN(x_desc.shape().NumAxes(), alpha_desc.shape().NumAxes() + 1);\n  FOR_RANGE(int64_t, i, 1, x_desc.shape().NumAxes()) {\n    CHECK_OR_RETURN((alpha_desc.shape().At(i - 1) == x_desc.shape().At(i))\n                    || (alpha_desc.shape().At(i - 1) == 1));\n  }\n  CHECK_EQ_OR_RETURN(dy_desc.shape(), x_desc.shape());\n  CHECK_EQ_OR_RETURN(dy_desc.data_type(), x_desc.data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"dy\", 0))\n      << \", but got \" << DataType_Name(x_desc.data_type());\n  dx_desc->set_shape(x_desc.shape());\n  dx_desc->set_is_dynamic(x_desc.is_dynamic());\n  ctx->SetOutputShape(\"alpha_diff\", 0, alpha_desc.shape());\n  ctx->SetOutputIsDynamic(\"alpha_diff\", 0, alpha_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TfPreluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> TfPreluGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"alpha_diff\", 0, ctx->InputDType(\"alpha\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/threshold_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> ThresholdOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ThresholdOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ThresholdOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ThresholdOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ThresholdGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(dy_shape == x_shape);\n  ctx->SetOutputShape(\"dx\", 0, dy_shape);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ThresholdGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> ThresholdGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"x\", 0), i)\n        .Split(user_op::OpArg(\"dy\", 0), i)\n        .Split(user_op::OpArg(\"dx\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> ThresholdGradOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"dy\", 0), ctx->InputDType(\"x\", 0))\n      << \"InferDataType Failed. Expected \" << DataType_Name(ctx->InputDType(\"dy\", 0))\n      << \", but got \" << DataType_Name(ctx->InputDType(\"x\", 0));\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/throw_error_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nMaybe<void> ThrowErrorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return user_op::TensorDescInferFnUtil::Unchanged(ctx);\n}\nMaybe<void> ThrowErrorOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\nMaybe<void> ThrowErrorOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/to_contiguous_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> ToContiguousOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ToContiguousOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  ctx->SetOutputStride(\"out\", 0, Stride(in_desc.shape()));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ToContiguousOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ToContiguousOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/top_k_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> TopKOp::GetSbp(user_op::SbpContext* ctx) {\n  // The current implementation can only do top_k in the last dimension and should use Broadcast\n  // (by default) instead of Split for that dimension\n  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TopKOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  Shape out_shape = ctx->InputShape(\"in\", 0);\n  out_shape.Set(\n      out_shape.NumAxes() - 1,\n      std::min(ctx->Attr<int32_t>(\"k\"), static_cast<int32_t>(out_shape.dim_vec().back())));\n  ctx->SetOutputShape(\"out\", 0, out_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TopKOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> TopKOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, DataType::kInt64);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/transpose_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nvoid CheckIsPerm(const std::vector<int32_t>& perm) {\n  std::vector<bool> is_used(perm.size(), false);\n  FOR_RANGE(size_t, i, 0, perm.size()) {\n    CHECK_GE(perm[i], 0);\n    CHECK_LE(perm[i], perm.size());\n    CHECK_EQ(is_used[perm[i]], false);\n    is_used[perm[i]] = true;\n  }\n}\n\n/*static*/ Maybe<void> TransposeOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& input_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"input\", 0);\n  const auto& perm = ctx->Attr<std::vector<int32_t>>(\"perm\");\n  CHECK_EQ_OR_RETURN(perm.size(), input_tensor.shape().NumAxes());\n  FOR_RANGE(int32_t, i, 0, perm.size()) {\n    int32_t axis = perm.at(i);\n    if (axis < 0) { axis += perm.size(); }\n    CHECK_GE_OR_RETURN(axis, 0);\n    CHECK_LT_OR_RETURN(axis, perm.size());\n    ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), i).Build();\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TransposeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc(\"input\", 0);\n  user_op::TensorDesc* out_tensor_desc = ctx->MutOutputTensorDesc(\"output\", 0);\n  const Shape& in_shape = in_tensor_desc.shape();\n  Shape out_shape = in_tensor_desc.shape();\n  const auto& perm = ctx->Attr<std::vector<int32_t>>(\"perm\");\n  CHECK_EQ_OR_RETURN(perm.size(), in_shape.NumAxes());\n  CheckIsPerm(perm);\n  // if (perm.at(0) != 0) { CHECK_OR_RETURN(!in_tensor_desc->is_dynamic()); }\n  out_tensor_desc->set_is_dynamic(in_tensor_desc.is_dynamic());\n  FOR_RANGE(size_t, i, 0, perm.size()) { out_shape.Set(i, in_shape.At(perm[i])); }\n  out_tensor_desc->set_shape(out_shape);\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TransposeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> TransposeOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"output\", 0, ctx->InputDType(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/tril_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> TrilOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  bool fill_zero = ctx->Attr<bool>(\"is_floating_fill_value\")\n                       ? (ctx->Attr<double>(\"floating_fill_value\") == static_cast<double>(0))\n                       : (ctx->Attr<int64_t>(\"integer_fill_value\") == static_cast<int64_t>(0));\n  if (fill_zero) {\n    ctx->NewBuilder()\n        .PartialSum(user_op::OpArg(\"in\", 0))\n        .PartialSum(user_op::OpArg(\"out\", 0))\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TrilOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2);\n  out->set_shape(in.shape());\n  out->set_is_dynamic(in.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TrilOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> TrilOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_data_type(in.data_type());\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> FusedScaleTrilOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  bool fill_zero = ctx->Attr<bool>(\"is_floating_fill_value\")\n                       ? 
(ctx->Attr<double>(\"floating_fill_value\") == static_cast<double>(0))\n                       : (ctx->Attr<int64_t>(\"integer_fill_value\") == static_cast<int64_t>(0));\n  if (fill_zero) {\n    ctx->NewBuilder()\n        .PartialSum(user_op::OpArg(\"in\", 0))\n        .PartialSum(user_op::OpArg(\"out\", 0))\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> FusedScaleTrilOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2);\n  out->set_shape(in.shape());\n  out->set_is_dynamic(in.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> FusedScaleTrilOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> FusedScaleTrilOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_data_type(in.data_type());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/triu_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> TriuOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TriuOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2);\n  out->set_shape(in.shape());\n  out->set_is_dynamic(in.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TriuOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> TriuOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"in\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  out->set_data_type(in.data_type());\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/trunc_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> TruncOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);\n}\n/*static*/ Maybe<void> TruncOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return user_op::TensorDescInferFnUtil::Unchanged(ctx);\n}\n/*static*/ Maybe<void> TruncOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> TruncOp::InferDataType(user_op::InferContext* ctx) {\n  return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx);\n}\n\n/* static */ Maybe<void> TruncGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return user_op::TensorDescInferFnUtil::Unchanged(ctx);\n}\n\n/*static*/ Maybe<void> TruncGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> TruncGradOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);\n}\n\n/* static */ Maybe<void> TruncGradOp::InferDataType(user_op::InferContext* ctx) {\n  return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/tuple_identity_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/operator/operator.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> TupleIdentityOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n/*static*/ Maybe<void> TupleIdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const int64_t in_size = ctx->input_size(\"in\");\n  CHECK_EQ_OR_RETURN(ctx->output_size(\"out\"), in_size);\n  for (int64_t i = 0; i < in_size; ++i) {\n    ctx->SetOutputShape(\"out\", i, ctx->InputShape(\"in\", i));\n    ctx->SetIsDynamic4ArgNameAndIndex(\"out\", i, ctx->InputIsDynamic(\"in\", i));\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TupleIdentityOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> TupleIdentityOp::InferDataType(user_op::InferContext* ctx) {\n  const int64_t in_size = ctx->input_size(\"in\");\n  CHECK_EQ_OR_RETURN(ctx->output_size(\"out\"), in_size);\n  for (int64_t i = 0; i < in_size; ++i) { ctx->SetOutputDType(\"out\", i, ctx->InputDType(\"in\", i)); }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TupleIdentityOp::InferSbpSignature(\n    user_op::InferSbpSignatureFnContext* ctx) {\n  SbpSignature* signature = ctx->mutable_sbp_signature();\n  const SbpSignature& sbp_signature_conf = ctx->sbp_signature_conf();\n  auto* bn2sbp = signature->mutable_bn_in_op2sbp_parallel();\n  const auto& bn2conf_sbp = sbp_signature_conf.bn_in_op2sbp_parallel();\n  const int64_t in_size = ctx->user_op_conf().input_size(\"in\");\n  CHECK_EQ_OR_RETURN(ctx->user_op_conf().output_size(\"out\"), in_size);\n  for (int64_t i = 0; i < in_size; ++i) {\n    const SbpParallel* sbp_parallel = nullptr;\n    const std::string ibn = GenRepeatedBn(\"in\", i);\n    const std::string& obn = GenRepeatedBn(\"out\", i);\n    const auto& conf_sbp_it = bn2conf_sbp.find(obn);\n    if (conf_sbp_it == bn2conf_sbp.end()) {\n      sbp_parallel = &ctx->SbpParallelHint4InputArgNameAndIndex(\"in\", i);\n    } else {\n      sbp_parallel = &conf_sbp_it->second;\n    }\n    (*bn2sbp)[ibn] = *sbp_parallel;\n    (*bn2sbp)[obn] = *sbp_parallel;\n  }\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> TupleIdentityOp::CheckAttr(const user_op::UserOpDefWrapper&,\n                                                  const user_op::UserOpConfWrapper& op_conf) {\n  CHECK_OR_RETURN(op_conf.input_size(\"in\") >= 1);\n  CHECK_OR_RETURN(op_conf.output_size(\"out\") >= 1);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/two_stage_reduce_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/operator/reduce_sbp_util.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferReduceDeviceStageDtypeFn(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  ctx->SetOutputDType(\"mask\", 0, DataType::kBool);\n  ctx->SetOutputDType(\"count\", 0, DataType::kInt32);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferReduceDeviceStageLogicalTensorDescFn(user_op::InferContext* ctx) {\n  const Shape& input_shape = ctx->InputShape(\"in\", 0);\n  const auto& axis = ctx->Attr<std::vector<int32_t>>(\"axis\");\n  const int64_t num_axes = input_shape.NumAxes();\n  Shape output_shape;\n  if (axis.empty()) {\n    output_shape = Shape::Ones(num_axes);\n  } else {\n    const ParallelDesc& parallel_desc = ctx->parallel_desc();\n    const NdSbp& in_nd_sbp = ctx->NdSbp4ArgNameAndIndex(\"in\", 0);\n    DimVector dim_vec = input_shape.dim_vec();\n    if (parallel_desc.hierarchy()->NumAxes() == 1) {\n      const auto& input_sbp = in_nd_sbp.sbp_parallel(0);\n      for (auto i : axis) {\n        const int64_t regular_axis = ShiftNegativeAxis(i, num_axes);\n        dim_vec.at(regular_axis) =\n            (input_sbp.has_split_parallel() && input_sbp.split_parallel().axis() == regular_axis)\n                ? 
parallel_desc.parallel_num()\n                : 1;\n      }\n    } else {\n      CHECK_EQ_OR_RETURN(axis.size(), 1);\n      const int64_t regular_axis = ShiftNegativeAxis(axis.at(0), num_axes);\n      dim_vec.at(regular_axis) = 1;\n      for (int64_t i = 0; i < parallel_desc.hierarchy()->NumAxes(); ++i) {\n        const auto& input_sbp = in_nd_sbp.sbp_parallel(i);\n        if (input_sbp.has_split_parallel() && input_sbp.split_parallel().axis() == regular_axis) {\n          dim_vec.at(regular_axis) *= parallel_desc.hierarchy()->At(i);\n        }\n      }\n    }\n    output_shape = Shape(dim_vec);\n  }\n  ctx->SetOutputShape(\"out\", 0, output_shape);\n  ctx->SetOutputShape(\"mask\", 0, input_shape);\n  ctx->SetOutputShape(\"count\", 0, output_shape);\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferReduceDeviceStagePhysicalTensorDescFn(user_op::InferContext* ctx) {\n  const Shape& input_shape = ctx->InputShape(\"in\", 0);\n  const auto& axis = ctx->Attr<std::vector<int32_t>>(\"axis\");\n  Shape output_shape;\n  if (axis.empty()) {\n    output_shape = Shape::Ones(input_shape.NumAxes());\n  } else {\n    const AxisVector axis_vec = {axis.begin(), axis.end()};\n    const Shape& reduced_shape = CreateReducedShape(input_shape, axis_vec);\n    output_shape = reduced_shape;\n  }\n\n  ctx->SetOutputShape(\"out\", 0, output_shape);\n  ctx->SetOutputShape(\"mask\", 0, input_shape);\n  ctx->SetOutputShape(\"count\", 0, output_shape);\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferReduceDeviceStageGradDtypeFn(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"mask\", 0), DataType::kBool)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kBool) << \", but got \"\n      << DataType_Name(ctx->InputDType(\"mask\", 0));\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"count\", 0), DataType::kInt32)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kInt32) << \", but got \"\n      << DataType_Name(ctx->InputDType(\"count\", 0));\n  ctx->SetOutputDType(\"in_diff\", 0, ctx->InputDType(\"out_diff\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferReduceDeviceStageGradTensorDescFn(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputShape(\"out_diff\", 0), ctx->InputShape(\"count\", 0));\n  ctx->SetOutputShape(\"in_diff\", 0, ctx->InputShape(\"mask\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferReduceGlobalStageDtypeFn(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"device_count\", 0), DataType::kInt32)\n      << \"InferDataType Failed. 
Expected \" << DataType_Name(DataType::kInt32) << \", but got \"\n      << DataType_Name(ctx->InputDType(\"device_count\", 0));\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"in\", 0));\n  ctx->SetOutputDType(\"mask\", 0, DataType::kBool);\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferReduceGlobalStageTensorDescFn(user_op::InferContext* ctx) {\n  const Shape& input_shape = ctx->InputShape(\"in\", 0);\n  const Shape& device_count_shape = ctx->InputShape(\"device_count\", 0);\n  CHECK_EQ_OR_RETURN(input_shape, device_count_shape);\n  const auto& axis = ctx->Attr<std::vector<int32_t>>(\"axis\");\n  bool keepdims = ctx->Attr<bool>(\"keepdims\");\n  Shape output_shape;\n  if (axis.empty()) {\n    if (keepdims) {\n      output_shape = Shape::Ones(input_shape.NumAxes());\n    } else {\n      output_shape = Shape({1});\n    }\n  } else {\n    const AxisVector axis_vec = {axis.begin(), axis.end()};\n    const Shape& reduced_shape = CreateReducedShape(input_shape, axis_vec);\n    if (keepdims) {\n      output_shape = reduced_shape;\n    } else {\n      output_shape = reduced_shape.RemoveOnes(axis_vec);\n    }\n  }\n\n  ctx->SetOutputShape(\"out\", 0, output_shape);\n  ctx->SetOutputShape(\"mask\", 0, input_shape);\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferReduceGlobalStageGradDtypeFn(user_op::InferContext* ctx) {\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"mask\", 0), DataType::kBool)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kBool) << \", but got \"\n      << DataType_Name(ctx->InputDType(\"mask\", 0));\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"device_count\", 0), DataType::kInt32)\n      << \"InferDataType Failed. Expected \" << DataType_Name(DataType::kInt32) << \", but got \"\n      << DataType_Name(ctx->InputDType(\"device_count\", 0));\n\n  ctx->SetOutputDType(\"in_diff\", 0, ctx->InputDType(\"out_diff\", 0));\n\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferReduceGlobalStageGradTensorDescFn(user_op::InferContext* ctx) {\n  const Shape& mask_shape = ctx->InputShape(\"mask\", 0);\n  const Shape& device_count_shape = ctx->InputShape(\"device_count\", 0);\n  CHECK_EQ_OR_RETURN(device_count_shape, mask_shape);\n  ctx->SetOutputShape(\"in_diff\", 0, mask_shape);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetReduceDeviceStageSbpFn(user_op::SbpContext* ctx) {\n  int32_t num_axes = 0;\n  HashSet<int32_t> conf_axes;\n  {\n    const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n    num_axes = in_tensor.shape().NumAxes();\n    const auto& reduced_axes = ctx->Attr<std::vector<int32_t>>(\"axis\");\n    conf_axes = {reduced_axes.begin(), reduced_axes.end()};\n  }\n  auto IsReducedAxis = ReduceSbpUtil::MakePredicatorIsReducedAxis(conf_axes, num_axes);\n  FOR_RANGE(int64_t, i, 0, num_axes) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"in\", 0), i)\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Split(user_op::OpArg(\"mask\", 0), i)\n        .Split(user_op::OpArg(\"count\", 0), i)\n        .Build();\n  }\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetReduceDeviceStageGradSbpFn(user_op::SbpContext* ctx) {\n  int32_t num_axes = 0;\n  HashSet<int32_t> conf_axes;\n  {\n    const auto& output_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"out_diff\", 0);\n    num_axes = output_tensor.shape().NumAxes();\n    const auto& reduced_axes = ctx->Attr<std::vector<int32_t>>(\"axis\");\n    conf_axes = {reduced_axes.begin(), reduced_axes.end()};\n  }\n  auto IsReducedAxis = 
ReduceSbpUtil::MakePredicatorIsReducedAxis(conf_axes, num_axes);\n  FOR_RANGE(int64_t, i, 0, num_axes) {\n    if (IsReducedAxis(i) || i == 0) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"out_diff\", 0), i)\n          .Split(user_op::OpArg(\"count\", 0), i)\n          .Split(user_op::OpArg(\"mask\", 0), i)\n          .Split(user_op::OpArg(\"in_diff\", 0), i)\n          .Build();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n#define IMPLEMENT_REDUCE_DEVICE_STAGE_USER_OP_FUNCS(op_name)                                \\\n  /*static*/ Maybe<void> op_name##Op::GetSbp(user_op::SbpContext* ctx) {                    \\\n    return GetReduceDeviceStageSbpFn(ctx);                                                  \\\n  }                                                                                         \\\n  /*static*/ Maybe<void> op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return InferReduceDeviceStageLogicalTensorDescFn(ctx);                                  \\\n  }                                                                                         \\\n  /*static*/ Maybe<void> op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferReduceDeviceStagePhysicalTensorDescFn(ctx);                                 \\\n  }                                                                                         \\\n  /*static*/ Maybe<void> op_name##Op::InferDataType(user_op::InferContext* ctx) {           \\\n    return InferReduceDeviceStageDtypeFn(ctx);                                              \\\n  }\n\nIMPLEMENT_REDUCE_DEVICE_STAGE_USER_OP_FUNCS(ReduceMinDeviceStage)\nIMPLEMENT_REDUCE_DEVICE_STAGE_USER_OP_FUNCS(ReduceMaxDeviceStage)\n#undef IMPLEMENT_REDUCE_DEVICE_STAGE_USER_OP_FUNCS\n\n#define IMPLEMENT_REDUCE_DEVICE_STAGE_USER_GRAD_OP_FUNCS(op_name)                               \\\n  /*static*/ Maybe<void> op_name##GradOp::GetSbp(user_op::SbpContext* ctx) {                    \\\n    return GetReduceDeviceStageGradSbpFn(ctx);                                                  \\\n  }                                                                                             \\\n  /*static*/ Maybe<void> op_name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return InferReduceDeviceStageGradTensorDescFn(ctx);                                         \\\n  }                                                                                             \\\n  /*static*/ Maybe<void> op_name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferLogicalTensorDesc(ctx);                                                         \\\n  }                                                                                             \\\n  /*static*/ Maybe<void> op_name##GradOp::InferDataType(user_op::InferContext* ctx) {           \\\n    return InferReduceDeviceStageGradDtypeFn(ctx);                                              \\\n  }\n\nIMPLEMENT_REDUCE_DEVICE_STAGE_USER_GRAD_OP_FUNCS(ReduceMinDeviceStage)\nIMPLEMENT_REDUCE_DEVICE_STAGE_USER_GRAD_OP_FUNCS(ReduceMaxDeviceStage)\n#undef IMPLEMENT_REDUCE_DEVICE_STAGE_USER_GRAD_OP_FUNCS\n\n#define IMPLEMENT_REDUCE_GLOBAL_STAGE_OP_FUNCS(op_name)                                          \\\n  /*static*/ Maybe<void> op_name##Op::GetSbp(user_op::SbpContext* ctx) {                         \\\n    ctx->NewBuilder()                                                                            \\\n        
.Split(user_op::OpArg(\"in\", 0), 0)                                                       \\\n        .Split(user_op::OpArg(\"device_count\", 0), 0)                                             \\\n        .Split(user_op::OpArg(\"out\", 0), 0)                                                      \\\n        .Split(user_op::OpArg(\"mask\", 0), 0)                                                     \\\n        .Build();                                                                                \\\n    return Maybe<void>::Ok();                                                                    \\\n  }                                                                                              \\\n  /*static*/ Maybe<void> op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) {       \\\n    return InferReduceGlobalStageTensorDescFn(ctx);                                              \\\n  }                                                                                              \\\n  /*static*/ Maybe<void> op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) {      \\\n    return InferLogicalTensorDesc(ctx);                                                          \\\n  }                                                                                              \\\n  /*static*/ Maybe<void> op_name##Op::InferDataType(user_op::InferContext* ctx) {                \\\n    return InferReduceGlobalStageDtypeFn(ctx);                                                   \\\n  }                                                                                              \\\n  /*static*/ Maybe<void> op_name##Op::ModifyInputArg(                                            \\\n      const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {     \\\n    user_op::InputArgModifier* device_count_modifier = GetInputArgModifierFn(\"device_count\", 0); \\\n    device_count_modifier->set_requires_grad(false);                                             \\\n    return Maybe<void>::Ok();                                                                    \\\n  }\n\nIMPLEMENT_REDUCE_GLOBAL_STAGE_OP_FUNCS(ReduceMinGlobalStage)\nIMPLEMENT_REDUCE_GLOBAL_STAGE_OP_FUNCS(ReduceMaxGlobalStage)\n#undef IMPLEMENT_REDUCE_GLOBAL_STAGE_OP_FUNCS\n\n#define IMPLEMENT_REDUCE_GLOBAL_STAGE_GRAD_OP_FUNCS(op_name)                                    \\\n  /*static*/ Maybe<void> op_name##GradOp::GetSbp(user_op::SbpContext* ctx) {                    \\\n    ctx->NewBuilder()                                                                           \\\n        .Split(user_op::OpArg(\"out_diff\", 0), 0)                                                \\\n        .Split(user_op::OpArg(\"mask\", 0), 0)                                                    \\\n        .Split(user_op::OpArg(\"device_count\", 0), 0)                                            \\\n        .Split(user_op::OpArg(\"in_diff\", 0), 0)                                                 \\\n        .Build();                                                                               \\\n    return Maybe<void>::Ok();                                                                   \\\n  }                                                                                             \\\n  /*static*/ Maybe<void> op_name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \\\n    return InferReduceGlobalStageGradTensorDescFn(ctx);                                         \\\n  }                                         
                                                    \\\n  /*static*/ Maybe<void> op_name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \\\n    return InferLogicalTensorDesc(ctx);                                                         \\\n  }                                                                                             \\\n  /*static*/ Maybe<void> op_name##GradOp::InferDataType(user_op::InferContext* ctx) {           \\\n    return InferReduceGlobalStageGradDtypeFn(ctx);                                              \\\n  }\n\nIMPLEMENT_REDUCE_GLOBAL_STAGE_GRAD_OP_FUNCS(ReduceMinGlobalStage)\nIMPLEMENT_REDUCE_GLOBAL_STAGE_GRAD_OP_FUNCS(ReduceMaxGlobalStage)\n#undef IMPLEMENT_REDUCE_GLOBAL_STAGE_GRAD_OP_FUNCS\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/unfold_fold_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/ops/nn_util.h\"\n#include \"oneflow/core/operator/operator_util.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> UnfoldTensorDescInferFn(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const int32_t spatial_ndim = x_shape.NumAxes() - 2;\n  std::string data_format = ctx->Attr<std::string>(\"data_format\");\n  std::vector<int32_t> padding = ctx->Attr<std::vector<int32_t>>(\"padding\");\n  const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n  const std::vector<int32_t>& strides = ctx->Attr<std::vector<int32_t>>(\"strides\");\n  const std::vector<int32_t>& dilation_rate = ctx->Attr<std::vector<int32_t>>(\"dilation_rate\");\n  const int32_t idx_offset = IdxOffset(data_format);\n  const size_t c_dim = data_format == \"channels_first\" ? 1 : spatial_ndim + 1;\n\n  CHECK_EQ_OR_RETURN(spatial_ndim, 2);  // only support 4-D tensor now.\n  CHECK_EQ_OR_RETURN(padding.size(), spatial_ndim);\n  for (int32_t pad : padding) { CHECK_GE_OR_RETURN(pad, 0); }\n  CHECK_EQ_OR_RETURN(kernel_size.size(), spatial_ndim);\n  for (int32_t kernel : kernel_size) { CHECK_GT_OR_RETURN(kernel, 0); }\n  CHECK_EQ_OR_RETURN(strides.size(), spatial_ndim);\n  for (int32_t stride : strides) { CHECK_GT_OR_RETURN(stride, 0); }\n  CHECK_EQ_OR_RETURN(dilation_rate.size(), spatial_ndim);\n  for (int32_t dilation : dilation_rate) { CHECK_GE_OR_RETURN(dilation, 1); }\n\n  std::vector<int64_t> dhw_shape(spatial_ndim);\n  for (int32_t i = 0; i < spatial_ndim; ++i) {\n    dhw_shape[i] =\n        (x_shape.At(idx_offset + i) + 2 * padding[i] - dilation_rate[i] * (kernel_size[i] - 1) - 1)\n            / strides[i]\n        + 1;\n  }\n\n  DimVector y_shape(3);\n  y_shape.at(0) = x_shape.At(0);\n  y_shape.at(1) =\n      x_shape.At(c_dim)\n      * std::accumulate(kernel_size.begin(), kernel_size.end(), 1, std::multiplies<int>());\n  y_shape.at(2) = std::accumulate(dhw_shape.begin(), dhw_shape.end(), 1, std::multiplies<int>());\n\n  ctx->SetOutputShape(\"y\", 0, Shape(y_shape));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> SetUnfoldDTypeFn(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetUnfoldSbpFn(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), 0).Split(user_op::OpArg(\"y\", 0), 0).Build();\n\n  ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), 1).Split(user_op::OpArg(\"y\", 0), 1).Build();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FoldTensorDescInferFn(user_op::InferContext* ctx) {\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const int32_t spatial_ndim = x_shape.NumAxes() - 1;  // (n, c*K*K, h*w)\n\n  std::string data_format = ctx->Attr<std::string>(\"data_format\");\n  
std::vector<int32_t> output_size = ctx->Attr<std::vector<int32_t>>(\"output_size\");\n  std::vector<int32_t> padding = ctx->Attr<std::vector<int32_t>>(\"padding\");\n  const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>(\"kernel_size\");\n  const std::vector<int32_t>& strides = ctx->Attr<std::vector<int32_t>>(\"strides\");\n  const std::vector<int32_t>& dilation_rate = ctx->Attr<std::vector<int32_t>>(\"dilation_rate\");\n  const size_t c_dim = data_format == \"channels_first\" ? 1 : spatial_ndim;\n  const size_t length_dim = data_format == \"channels_first\" ? spatial_ndim : 1;\n\n  const int32_t input_planes = x_shape.At(c_dim);\n  const int32_t input_length = x_shape.At(length_dim);\n\n  CHECK_EQ_OR_RETURN(spatial_ndim, 2);  // only support 4-D tensor now.\n  CHECK_EQ_OR_RETURN(output_size.size(), spatial_ndim);\n  CHECK_EQ_OR_RETURN(padding.size(), spatial_ndim);\n  for (int32_t pad : padding) { CHECK_GE_OR_RETURN(pad, 0); }\n  CHECK_EQ_OR_RETURN(kernel_size.size(), spatial_ndim);\n  for (int32_t kernel : kernel_size) { CHECK_GT_OR_RETURN(kernel, 0); }\n  CHECK_EQ_OR_RETURN(strides.size(), spatial_ndim);\n  for (int32_t stride : strides) { CHECK_GT_OR_RETURN(stride, 0); }\n  CHECK_EQ_OR_RETURN(dilation_rate.size(), spatial_ndim);\n  for (int32_t dilation : dilation_rate) { CHECK_GE_OR_RETURN(dilation, 1); }\n\n  CHECK_EQ_OR_RETURN(input_planes % (kernel_size[0] * kernel_size[1]),\n                     0);  // C*K*K must be divisible by K*K\n\n  const int32_t output_height =\n      (output_size[0] + 2 * padding[0] - dilation_rate[0] * (kernel_size[0] - 1) - 1) / strides[0]\n      + 1;\n  const int32_t output_width =\n      (output_size[1] + 2 * padding[1] - dilation_rate[1] * (kernel_size[1] - 1) - 1) / strides[1]\n      + 1;\n  CHECK_EQ_OR_RETURN(output_height * output_width, input_length);  // input_length == OH*OW\n\n  DimVector y_shape(4);\n  y_shape.at(0) = x_shape.At(0);\n  y_shape.at(1) = input_planes / (kernel_size[0] * kernel_size[1]);\n  y_shape.at(2) = output_size[0];\n  y_shape.at(3) = output_size[1];\n\n  ctx->SetOutputShape(\"y\", 0, Shape(y_shape));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> FoldDTypeFn(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> GetFoldSbpFn(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), 0).Split(user_op::OpArg(\"y\", 0), 0).Build();\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/*static*/ Maybe<void> UnfoldOp::GetSbp(user_op::SbpContext* ctx) { return GetUnfoldSbpFn(ctx); }\n/*static*/ Maybe<void> UnfoldOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return UnfoldTensorDescInferFn(ctx);\n}\n/*static*/ Maybe<void> UnfoldOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UnfoldOp::InferDataType(user_op::InferContext* ctx) {\n  return SetUnfoldDTypeFn(ctx);\n}\n\n/*static*/ Maybe<void> FoldOp::GetSbp(user_op::SbpContext* ctx) { return GetFoldSbpFn(ctx); }\n/*static*/ Maybe<void> FoldOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return FoldTensorDescInferFn(ctx);\n}\n/*static*/ Maybe<void> FoldOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> FoldOp::InferDataType(user_op::InferContext* ctx) {\n  return FoldDTypeFn(ctx);\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/unfold_tensor_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/unfold_tensor_kernel_utils.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> UnfoldTensorOp::GetSbp(user_op::SbpContext* ctx) {\n  const int32_t dimension = ctx->Attr<int32_t>(\"dimension\");\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    if (i != dimension) {\n      ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), i).Split(user_op::OpArg(\"y\", 0), i).Build();\n    }\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UnfoldTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"x\", 0);\n  const int32_t dimension = ctx->Attr<int32_t>(\"dimension\");\n  const int32_t size = ctx->Attr<int32_t>(\"size\");\n  const int32_t step = ctx->Attr<int32_t>(\"step\");\n\n  const Shape& in_shape = ctx->InputShape(\"x\", 0);\n  const int32_t in_dim = in_shape.NumAxes();\n  CHECK_GE_OR_RETURN(dimension, 0);\n  // NOTE(lixiang): remove -1 for 0-dim tensor\n  CHECK_LE_OR_RETURN(dimension, in_dim);\n\n  const int32_t max_size = in_dim == 0 ? 
1 : in_shape.At(dimension);\n  CHECK_GT_OR_RETURN(size, 0);\n  CHECK_LE_OR_RETURN(size, max_size);\n  CHECK_GT_OR_RETURN(step, 0);\n\n  DimVector out_shape(in_dim + 1);\n  out_shape[in_dim] = size;\n  FOR_RANGE(int32_t, d, 0, in_dim) {\n    int32_t in_size_at_d = in.shape().At(d);\n    if (d == dimension) {\n      out_shape.at(d) = (in_size_at_d - size) / step + 1;\n    } else {\n      out_shape.at(d) = in_size_at_d;\n    }\n  }\n  ctx->SetOutputShape(\"y\", 0, Shape(out_shape));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UnfoldTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UnfoldTensorOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UnfoldTensorGradOp::GetSbp(user_op::SbpContext* ctx) {\n  const int32_t dimension = ctx->Attr<int32_t>(\"dimension\");\n  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0);\n  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {\n    if (i != dimension) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"dy\", 0), i)\n          .Split(user_op::OpArg(\"x\", 0), i)\n          .Split(user_op::OpArg(\"dx\", 0), i)\n          .Build();\n    }\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UnfoldTensorGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in = ctx->InputTensorDesc(\"x\", 0);\n  const Shape& in_shape = in.shape();\n  user_op::TensorDesc* dx_desc = ctx->MutOutputTensorDesc(\"dx\", 0);\n  dx_desc->set_shape(Shape(in_shape.dim_vec()));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UnfoldTensorGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UnfoldTensorGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/unique_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> UniqueOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n/*static*/ Maybe<void> UniqueOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  CHECK_EQ_OR_RETURN(x.shape().NumAxes(), 1);\n\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  y->set_shape(x.shape());\n  y->set_is_dynamic(x.is_dynamic());\n\n  user_op::TensorDesc* idx = ctx->MutOutputTensorDesc(\"idx\", 0);\n  idx->set_shape(x.shape());\n  idx->set_is_dynamic(x.is_dynamic());\n\n  user_op::TensorDesc* num_unique = ctx->MutOutputTensorDesc(\"num_unique\", 0);\n  num_unique->set_shape(Shape({1}));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UniqueOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UniqueOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  auto out_idx = ctx->Attr<DataType>(\"out_idx\");\n  CHECK_OR_RETURN(IsIndexDataType(out_idx));\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  y->set_data_type(x.data_type());\n\n  user_op::TensorDesc* idx = ctx->MutOutputTensorDesc(\"idx\", 0);\n  idx->set_data_type(out_idx);\n\n  user_op::TensorDesc* num_unique = ctx->MutOutputTensorDesc(\"num_unique\", 0);\n  num_unique->set_data_type(out_idx);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/unique_with_counts_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> UniqueWithCountsOp::GetSbp(user_op::SbpContext* ctx) {\n  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);\n}\n/*static*/ Maybe<void> UniqueWithCountsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  CHECK_EQ_OR_RETURN(x.shape().NumAxes(), 1);\n\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  y->set_shape(x.shape());\n  y->set_is_dynamic(x.is_dynamic());\n\n  user_op::TensorDesc* idx = ctx->MutOutputTensorDesc(\"idx\", 0);\n  idx->set_shape(x.shape());\n  idx->set_is_dynamic(x.is_dynamic());\n\n  user_op::TensorDesc* count = ctx->MutOutputTensorDesc(\"count\", 0);\n  count->set_shape(x.shape());\n  count->set_is_dynamic(x.is_dynamic());\n\n  user_op::TensorDesc* num_unique = ctx->MutOutputTensorDesc(\"num_unique\", 0);\n  num_unique->set_shape(Shape({1}));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UniqueWithCountsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UniqueWithCountsOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& x = ctx->InputTensorDesc(\"x\", 0);\n  auto out_idx = ctx->Attr<DataType>(\"out_idx\");\n  CHECK_OR_RETURN(IsIndexDataType(out_idx));\n  user_op::TensorDesc* y = ctx->MutOutputTensorDesc(\"y\", 0);\n  y->set_data_type(x.data_type());\n\n  user_op::TensorDesc* idx = ctx->MutOutputTensorDesc(\"idx\", 0);\n  idx->set_data_type(out_idx);\n\n  user_op::TensorDesc* count = ctx->MutOutputTensorDesc(\"count\", 0);\n  count->set_data_type(out_idx);\n  user_op::TensorDesc* num_unique = ctx->MutOutputTensorDesc(\"num_unique\", 0);\n  num_unique->set_data_type(out_idx);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/unpack_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> UnpackOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) {\n    ctx->NewBuilder().Split(user_op::OpArg(\"in\", 0), i).Split(user_op::OpArg(\"out\", 0), i).Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"in\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UnpackOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  const Shape& in_shape = in_desc.shape();\n  CHECK_GT_OR_RETURN(in_shape.NumAxes(), 0);\n  const auto unpack_num = ctx->Attr<int32_t>(\"unpack_num\");\n  CHECK_EQ_OR_RETURN(in_shape.At(0) % unpack_num, 0);\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  Shape out_shape = in_desc.shape();\n  out_shape.Set(0, in_shape.At(0) / unpack_num);\n  out_desc->set_shape(out_shape);\n  out_desc->set_is_dynamic(in_desc.is_dynamic());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UnpackOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UnpackOp::InferDataType(user_op::InferContext* ctx) {\n  user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc(\"out\", 0);\n  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc(\"in\", 0);\n  out_desc->set_data_type(in_desc.data_type());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UnpackOp::InferOutputBlobTimeShape(\n    user_op::InferOutputBlobTimeShapeFnContext* ctx) {\n  const int32_t unpack_num = ctx->user_op_conf().attr<int32_t>(\"unpack_num\");\n  DimVector time_shape_dim_vec = ctx->TimeShape4InputArgNameAndIndex(\"in\", 0).dim_vec();\n  time_shape_dim_vec.emplace_back(unpack_num);\n  *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/unsorted_batch_segment_sum_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> UnsortedBatchSegmentSumOp::GetSbp(user_op::SbpContext* ctx) {\n  const int64_t segment_ids_num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"segment_ids\", 0).shape().NumAxes();\n  if (segment_ids_num_axes > 1) {\n    FOR_RANGE(int64_t, i, 0, segment_ids_num_axes - 1) {\n      ctx->NewBuilder()\n          .Split(user_op::OpArg(\"segment_ids\", 0), i)\n          .Split(user_op::OpArg(\"data\", 0), i)\n          .Split(user_op::OpArg(\"out\", 0), i)\n          .Build();\n    }\n  }\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"segment_ids\", 0))\n      .PartialSum(user_op::OpArg(\"data\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UnsortedBatchSegmentSumOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const user_op::TensorDesc& data = ctx->InputTensorDesc(\"data\", 0);\n  const user_op::TensorDesc& segment_ids = ctx->InputTensorDesc(\"segment_ids\", 0);\n  CHECK_GE_OR_RETURN(segment_ids.shape().NumAxes(), 1);\n  CHECK_GE_OR_RETURN(data.shape().NumAxes(), segment_ids.shape().NumAxes());\n  CHECK_EQ_OR_RETURN(segment_ids.is_dynamic(), data.is_dynamic());\n  const int64_t num_segments = ctx->Attr<int64_t>(\"num_segments\");\n  CHECK_GE_OR_RETURN(num_segments, 1);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n\n  FOR_RANGE(int64_t, i, 0, segment_ids.shape().NumAxes() - 1) {\n    CHECK_EQ_OR_RETURN(segment_ids.shape().At(i), data.shape().At(i));\n  }\n\n  DimVector dim_vec(data.shape().dim_vec());\n  dim_vec.at(segment_ids.shape().NumAxes() - 1) = num_segments;\n  out->set_shape(Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UnsortedBatchSegmentSumOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UnsortedBatchSegmentSumOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& data = ctx->InputTensorDesc(\"data\", 0);\n  const user_op::TensorDesc& segment_ids = ctx->InputTensorDesc(\"segment_ids\", 0);\n  user_op::TensorDesc* out = ctx->MutOutputTensorDesc(\"out\", 0);\n  CHECK_OR_RETURN(IsIndexDataType(segment_ids.data_type()));\n  out->set_data_type(data.data_type());\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UnsortedBatchSegmentSumOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn(\"segment_ids\", 0);\n  CHECK_NOTNULL_OR_RETURN(segment_ids_modifier);\n  segment_ids_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/unsorted_segment_sum_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> UnsortedSegmentSumOp::GetSbp(user_op::SbpContext* ctx) {\n  const int64_t data_num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"data\", 0).shape().NumAxes();\n  const int64_t segment_ids_num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"segment_ids\", 0).shape().NumAxes();\n  const int64_t axis = ctx->Attr<int64_t>(\"axis\");\n  FOR_RANGE(int64_t, i, 0, segment_ids_num_axes) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"segment_ids\", 0), i)\n        .Split(user_op::OpArg(\"data\", 0), i + axis)\n        .PartialSum(user_op::OpArg(\"out\", 0))\n        .Build();\n  }\n  FOR_RANGE(int64_t, i, 0, data_num_axes) {\n    if (i >= axis && i < axis + segment_ids_num_axes) { continue; }\n    const int64_t out_split_axis = (i < axis) ? i : i - segment_ids_num_axes + 1;\n    if (out_split_axis == axis) { continue; }\n    ctx->NewBuilder()\n        .Broadcast(user_op::OpArg(\"segment_ids\", 0))\n        .Split(user_op::OpArg(\"data\", 0), i)\n        .Split(user_op::OpArg(\"out\", 0), out_split_axis)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"segment_ids\", 0))\n      .PartialSum(user_op::OpArg(\"data\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UnsortedSegmentSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& data_shape = ctx->InputShape(\"data\", 0);\n  const int64_t axis = ctx->Attr<int64_t>(\"axis\");\n  const int64_t num_segments = ctx->Attr<int64_t>(\"num_segments\");\n  const Shape& segment_ids_shape = ctx->InputShape(\"segment_ids\", 0);\n\n  DimVector dim_vec;\n  dim_vec.insert(dim_vec.end(), data_shape.dim_vec().cbegin(),\n                 data_shape.dim_vec().cbegin() + axis);\n  dim_vec.emplace_back(num_segments);\n  dim_vec.insert(dim_vec.end(), data_shape.dim_vec().cbegin() + axis + segment_ids_shape.NumAxes(),\n                 data_shape.dim_vec().end());\n  ctx->SetOutputShape(\"out\", 0, Shape(dim_vec));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UnsortedSegmentSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UnsortedSegmentSumOp::InferDataType(user_op::InferContext* ctx) {\n  CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType(\"segment_ids\", 0)));\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"data\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UnsortedSegmentSumOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn(\"segment_ids\", 0);\n  CHECK_NOTNULL_OR_RETURN(segment_ids_modifier);\n  
segment_ids_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UnsortedSegmentSumLikeOp::GetSbp(user_op::SbpContext* ctx) {\n  const int64_t data_num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"data\", 0).shape().NumAxes();\n  const int64_t segment_ids_num_axes =\n      ctx->LogicalTensorDesc4InputArgNameAndIndex(\"segment_ids\", 0).shape().NumAxes();\n  const int64_t axis = ctx->Attr<int64_t>(\"axis\");\n  FOR_RANGE(int64_t, i, 0, segment_ids_num_axes) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"segment_ids\", 0), i)\n        .Split(user_op::OpArg(\"data\", 0), i + axis)\n        .Broadcast(user_op::OpArg(\"like\", 0))\n        .PartialSum(user_op::OpArg(\"out\", 0))\n        .Build();\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"segment_ids\", 0), i)\n        .Split(user_op::OpArg(\"data\", 0), i + axis)\n        .PartialSum(user_op::OpArg(\"like\", 0))\n        .PartialSum(user_op::OpArg(\"out\", 0))\n        .Build();\n  }\n  FOR_RANGE(int64_t, i, 0, data_num_axes) {\n    if (i >= axis && i < axis + segment_ids_num_axes) { continue; }\n    const int64_t out_split_axis = (i < axis) ? i : i - segment_ids_num_axes + 1;\n    if (out_split_axis == axis) { continue; }\n    ctx->NewBuilder()\n        .Broadcast(user_op::OpArg(\"segment_ids\", 0))\n        .Split(user_op::OpArg(\"data\", 0), i)\n        .Split(user_op::OpArg(\"like\", 0), out_split_axis)\n        .Split(user_op::OpArg(\"out\", 0), out_split_axis)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"segment_ids\", 0))\n      .PartialSum(user_op::OpArg(\"data\", 0))\n      .Broadcast(user_op::OpArg(\"like\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"segment_ids\", 0))\n      .PartialSum(user_op::OpArg(\"data\", 0))\n      .PartialSum(user_op::OpArg(\"like\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"segment_ids\", 0))\n      .Broadcast(user_op::OpArg(\"data\", 0))\n      .Split(user_op::OpArg(\"like\", 0), axis)\n      .Split(user_op::OpArg(\"out\", 0), axis)\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UnsortedSegmentSumLikeOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const Shape& data_shape = ctx->InputShape(\"data\", 0);\n  const Shape& like_shape = ctx->InputShape(\"like\", 0);\n  const Shape& segment_ids_shape = ctx->InputShape(\"segment_ids\", 0);\n  const int64_t axis = ctx->Attr<int64_t>(\"axis\");\n  CHECK_GE_OR_RETURN(axis, 0);\n  CHECK_LE_OR_RETURN(axis, like_shape.NumAxes());\n  FOR_RANGE(int64_t, i, 0, axis) { CHECK_EQ_OR_RETURN(like_shape.At(i), data_shape.At(i)); }\n  CHECK_EQ_OR_RETURN(data_shape.NumAxes() - segment_ids_shape.NumAxes() + 1, like_shape.NumAxes());\n  FOR_RANGE(int64_t, i, axis + 1, like_shape.NumAxes()) {\n    CHECK_EQ_OR_RETURN(like_shape.At(i), data_shape.At(i + segment_ids_shape.NumAxes() - 1));\n  }\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"like\", 0));\n  ctx->SetIsDynamic4ArgNameAndIndex(\"out\", 0, ctx->InputIsDynamic(\"like\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UnsortedSegmentSumLikeOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UnsortedSegmentSumLikeOp::InferDataType(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& data = 
ctx->InputTensorDesc(\"data\", 0);\n  const user_op::TensorDesc& like = ctx->InputTensorDesc(\"like\", 0);\n  CHECK_EQ_OR_RETURN(data.data_type(), like.data_type())\n      << \"InferDataType Failed. Expected \" << DataType_Name(like.data_type()) << \", but got \"\n      << DataType_Name(data.data_type());\n  CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType(\"segment_ids\", 0)));\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"data\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UnsortedSegmentSumLikeOp::ModifyInputArg(\n    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {\n  user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn(\"segment_ids\", 0);\n  CHECK_NOTNULL_OR_RETURN(segment_ids_modifier);\n  segment_ids_modifier->set_requires_grad(false);\n  user_op::InputArgModifier* like_modifier = GetInputArgModifierFn(\"like\", 0);\n  CHECK_NOTNULL_OR_RETURN(like_modifier);\n  like_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/upsample_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace {\nusing namespace oneflow;\ntemplate<int32_t N>\ntypename std::enable_if<(N <= 3), Maybe<void>>::type UpsamplingInferLogicalDesc(\n    user_op::InferContext* ctx, const std::string& func_name) {\n  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(\"x\", 0);\n  user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc(\"y\", 0);\n  if (N == 1) {\n    CHECK_OR_RETURN(ctx->Attr<std::string>(\"data_format\") == \"channels_first\"\n                    && x_desc.shape().NumAxes() == (N + 2))\n        << func_name << \" only supports NCH\";\n    int64_t input_width = x_desc.shape().At(2);\n    int64_t output_width = 0;\n    const double scale_factor = ctx->Attr<double>(\"scale_factor\");\n    std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    if (output_size.size()) {\n      output_width = output_size[0];\n    } else {\n      output_width = static_cast<int64_t>(scale_factor * input_width);\n    }\n    CHECK_OR_RETURN(input_width > 0 && output_width > 0)\n        << func_name\n        << \": Input and output sizes should be greater than 0, but got input (W: \" << input_width\n        << \") output (W: \" << output_width << \")\";\n    y_desc->set_shape(Shape({x_desc.shape().At(0), x_desc.shape().At(1), output_width}));\n  } else if (N == 2) {\n    CHECK_OR_RETURN(ctx->Attr<std::string>(\"data_format\") == \"channels_first\"\n                    && x_desc.shape().NumAxes() == (N + 2))\n        << func_name << \" only supports NCHW\";\n    const double height_scale = ctx->Attr<double>(\"height_scale\");\n    const double width_scale = ctx->Attr<double>(\"width_scale\");\n    std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    int64_t input_height = x_desc.shape().At(2);\n    int64_t input_width = x_desc.shape().At(3);\n    int64_t output_height = 0;\n    int64_t output_width = 0;\n    if (output_size.size()) {\n      output_height = output_size[0];\n      output_width = output_size[1];\n    } else {\n      output_height = static_cast<int64_t>(height_scale * input_height);\n      output_width = static_cast<int64_t>(width_scale * input_width);\n    }\n    CHECK_OR_RETURN(input_height > 0 && input_width > 0 && output_height > 0 && output_width > 0)\n        << func_name\n        << \": Input and output sizes should be greater than 0, but got input (H: \" << input_height\n        << \", W: \" << input_width << \") output (H: \" << output_height << \", W: \" << output_width\n        << \")\";\n    y_desc->set_shape(\n        Shape({x_desc.shape().At(0), x_desc.shape().At(1), output_height, output_width}));\n  } else if (N == 3) {\n    CHECK_OR_RETURN(ctx->Attr<std::string>(\"data_format\") == \"channels_first\"\n                    && x_desc.shape().NumAxes() == 5)\n        << func_name << \" only 
supports NCDHW\";\n    const double depth_scale = ctx->Attr<double>(\"depth_scale\");\n    const double height_scale = ctx->Attr<double>(\"height_scale\");\n    const double width_scale = ctx->Attr<double>(\"width_scale\");\n    std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>(\"output_size\");\n    int64_t input_depth = x_desc.shape().At(2);\n    int64_t input_height = x_desc.shape().At(3);\n    int64_t input_width = x_desc.shape().At(4);\n    int64_t output_depth = 0;\n    int64_t output_height = 0;\n    int64_t output_width = 0;\n    if (output_size.size()) {\n      output_depth = output_size[0];\n      output_height = output_size[1];\n      output_width = output_size[2];\n    } else {\n      output_depth = static_cast<int64_t>(depth_scale * input_depth);\n      output_height = static_cast<int64_t>(height_scale * input_height);\n      output_width = static_cast<int64_t>(width_scale * input_width);\n    }\n    CHECK_OR_RETURN(input_depth > 0 && input_height > 0 && input_width > 0 && output_depth > 0\n                    && output_height > 0 && output_width > 0)\n        << func_name\n        << \": Input and output sizes should be greater than 0, but got input (D: \" << input_depth\n        << \", H: \" << input_height << \", W: \" << input_width << \") output (D: \" << output_depth\n        << \"H: \" << output_height << \", W: \" << output_width << \")\";\n    y_desc->set_shape(Shape(\n        {x_desc.shape().At(0), x_desc.shape().At(1), output_depth, output_height, output_width}));\n  }\n  return Maybe<void>::Ok();\n}\n}  // namespace\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> UpsampleLinear1DOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), 0).Split(user_op::OpArg(\"y\", 0), 0).Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleLinear1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return UpsamplingInferLogicalDesc<1>(ctx, \"upsample_linear_1d\");\n}\n/*static*/ Maybe<void> UpsampleLinear1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UpsampleLinear1DOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UpsampleNearest1DOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), 0).Split(user_op::OpArg(\"y\", 0), 0).Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleNearest1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return UpsamplingInferLogicalDesc<1>(ctx, \"upsample_nearest_1d\");\n}\n/*static*/ Maybe<void> UpsampleNearest1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UpsampleNearest1DOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UpsampleNearest2DOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), 0).Split(user_op::OpArg(\"y\", 0), 0).Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleNearest2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return UpsamplingInferLogicalDesc<2>(ctx, \"upsample_nearest_2d\");\n}\n/*static*/ Maybe<void> UpsampleNearest2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return 
InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UpsampleNearest2DOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UpsampleBilinear2DOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), 0).Split(user_op::OpArg(\"y\", 0), 0).Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleBilinear2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return UpsamplingInferLogicalDesc<2>(ctx, \"upsample_bilinear_2d\");\n}\n/*static*/ Maybe<void> UpsampleBilinear2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UpsampleBilinear2DOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UpsampleBicubic2DOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), 0).Split(user_op::OpArg(\"y\", 0), 0).Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleBicubic2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return UpsamplingInferLogicalDesc<2>(ctx, \"upsample_bicubic_2d\");\n}\n/*static*/ Maybe<void> UpsampleBicubic2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UpsampleBicubic2DOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UpsampleNearest3DOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), 0).Split(user_op::OpArg(\"y\", 0), 0).Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleNearest3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return UpsamplingInferLogicalDesc<3>(ctx, \"upsample_nearest_3d\");\n}\n/*static*/ Maybe<void> UpsampleNearest3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UpsampleNearest3DOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UpsampleTrilinear3DOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Split(user_op::OpArg(\"x\", 0), 0).Split(user_op::OpArg(\"y\", 0), 0).Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleTrilinear3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return UpsamplingInferLogicalDesc<3>(ctx, \"upsample_trilinear_3d\");\n}\n/*static*/ Maybe<void> UpsampleTrilinear3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UpsampleTrilinear3DOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"y\", 0, ctx->InputDType(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UpsampleLinear1DGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleLinear1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  
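// dy must be a channels-first 3-D (NCH) tensor; dx takes the shape of x.\n  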
CHECK_OR_RETURN(ctx->Attr<std::string>(\"data_format\") == \"channels_first\"\n                  && dy_shape.NumAxes() == 3)\n      << \"upsample_linear_1d_grad only supports NCH\";\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleLinear1DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UpsampleLinear1DGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UpsampleNearest1DGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleNearest1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(ctx->Attr<std::string>(\"data_format\") == \"channels_first\"\n                  && dy_shape.NumAxes() == 3)\n      << \"upsample_nearest_1d_grad only supports NCH\";\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleNearest1DGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UpsampleNearest1DGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UpsampleNearest2DGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleNearest2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(ctx->Attr<std::string>(\"data_format\") == \"channels_first\"\n                  && dy_shape.NumAxes() == 4)\n      << \"upsample_nearest_2d_grad only supports NCHW\";\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleNearest2DGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UpsampleNearest2DGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UpsampleBilinear2DGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleBilinear2DGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(ctx->Attr<std::string>(\"data_format\") == \"channels_first\"\n                  && dy_shape.NumAxes() == 4)\n      << \"upsample_bilinear_2d_grad only supports NCHW\";\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleBilinear2DGradOp::InferPhysicalTensorDesc(\n    
user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UpsampleBilinear2DGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UpsampleBicubic2DGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleBicubic2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(ctx->Attr<std::string>(\"data_format\") == \"channels_first\"\n                  && dy_shape.NumAxes() == 4)\n      << \"upsample_bicubic_2d_grad only supports NCHW\";\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleBicubic2DGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UpsampleBicubic2DGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UpsampleNearest3DGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleNearest3DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(ctx->Attr<std::string>(\"data_format\") == \"channels_first\"\n                  && dy_shape.NumAxes() == 5)\n      << \"upsample_nearest_3d_grad only supports NCDHW\";\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleNearest3DGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UpsampleNearest3DGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> UpsampleTrilinear3DGradOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"x\", 0), 0)\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleTrilinear3DGradOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  const Shape& dy_shape = ctx->InputShape(\"dy\", 0);\n  CHECK_OR_RETURN(ctx->Attr<std::string>(\"data_format\") == \"channels_first\"\n                  && dy_shape.NumAxes() == 5)\n      << \"upsample_trilinear_3d_grad only supports NCDHW\";\n  ctx->SetOutputShape(\"dx\", 0, ctx->InputShape(\"x\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> UpsampleTrilinear3DGradOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> UpsampleTrilinear3DGradOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"dx\", 0, ctx->InputDType(\"dy\", 0));\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
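  {
    "path": "oneflow/user/ops/upsample_shape_inference_sketch.example.cpp",
    "content": "/*\nIllustrative sketch added for exposition; a hypothetical file, not part of the\nOneFlow build. It mirrors the shape-inference rule used by the upsample ops in\nupsample_op.cpp above: a non-empty `output_size` attribute takes precedence,\notherwise each spatial dim is scaled and truncated via static_cast<int64_t>.\nOnly the standard library is assumed.\n*/\n#include <cstdint>\n#include <iostream>\n#include <vector>\n\n// Mirrors UpsamplingInferLogicalDesc<3> for an NCDHW input (all names here are\n// local to this sketch).\nstd::vector<int64_t> InferUpsample3DShape(const std::vector<int64_t>& x_shape, double depth_scale,\n                                          double height_scale, double width_scale,\n                                          const std::vector<int64_t>& output_size) {\n  int64_t output_depth = 0;\n  int64_t output_height = 0;\n  int64_t output_width = 0;\n  if (!output_size.empty()) {\n    // An explicit output_size wins over the scale attributes.\n    output_depth = output_size[0];\n    output_height = output_size[1];\n    output_width = output_size[2];\n  } else {\n    // Scales apply per spatial dim; the product is truncated toward zero.\n    output_depth = static_cast<int64_t>(depth_scale * x_shape[2]);\n    output_height = static_cast<int64_t>(height_scale * x_shape[3]);\n    output_width = static_cast<int64_t>(width_scale * x_shape[4]);\n  }\n  return {x_shape[0], x_shape[1], output_depth, output_height, output_width};\n}\n\nint main() {\n  // (N=1, C=3, D=4, H=5, W=6) scaled by 2.0 -> (1, 3, 8, 10, 12).\n  for (int64_t d : InferUpsample3DShape({1, 3, 4, 5, 6}, 2.0, 2.0, 2.0, {})) {\n    std::cout << d << \" \";\n  }\n  std::cout << \"\\n\";\n  return 0;\n}\n"
  },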
  {
    "path": "oneflow/user/ops/util_ops.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/* static */ Maybe<void> IsNanOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> IsNanOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> IsNanOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> IsNanOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, DataType::kBool);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> IsInfOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> IsInfOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> IsInfOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> IsInfOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, DataType::kBool);\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> IsFiniteOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"in\", 0));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> IsFiniteOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> IsFiniteOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"in\", 0);\n  for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) {\n    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n  }\n  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> 
IsFiniteOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, DataType::kBool);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/ops/variance_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/operator/reduce_sbp_util.h\"\n#include \"oneflow/core/ndarray/binary_func.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nMaybe<void> VarOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& input_shape = ctx->InputShape(\"input\", 0);\n  const auto& reduce_axes = ctx->Attr<std::vector<int32_t>>(\"dim\");\n  CHECK_OR_RETURN(!reduce_axes.empty());\n  const AxisVector reduce_axes_vec = {reduce_axes.begin(), reduce_axes.end()};\n  const Shape& reduce_shape = CreateReducedShape(input_shape, reduce_axes_vec);\n  const bool keepdim = ctx->Attr<bool>(\"keepdim\");\n  Shape output_shape;\n  if (keepdim) {\n    output_shape = reduce_shape;\n  } else {\n    output_shape = reduce_shape.RemoveOnes(reduce_axes_vec);\n  }\n  ctx->SetOutputShape(\"output\", 0, output_shape);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> VarOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\nMaybe<void> VarOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"output\", 0, ctx->InputDType(\"input\", 0));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> VarOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();\n  const Shape& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"input\", 0).shape();\n  const int64_t ndim = input_shape.NumAxes();\n  const std::vector<int32_t> axis = ctx->Attr<std::vector<int32_t>>(\"dim\");\n  const bool keepdim = ctx->Attr<bool>(\"keepdim\");\n  if (keepdim) {\n    for (int i = 0; i < ndim; i++) {\n      if (std::find(axis.begin(), axis.end(), i) == axis.end()) {\n        ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();\n      }\n    }\n  } else {\n    int offset = 0;\n    for (int i = 0; i < ndim; i++) {\n      if (std::find(axis.begin(), axis.end(), i) == axis.end()) {\n        ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i - offset).Build();\n      } else {\n        offset += 1;\n      }\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
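  {
    "path": "oneflow/user/ops/variance_sbp_offset_sketch.example.cpp",
    "content": "/*\nIllustrative sketch added for exposition; a hypothetical file, not part of the\nOneFlow build. It walks through the split-axis bookkeeping in VarOp::GetSbp\n(variance_op.cpp above): with keepdim=false the reduced axes vanish from the\noutput, so an input split on axis i maps to an output split on i minus the\nnumber of reduced axes that precede i. Only the standard library is assumed.\n*/\n#include <algorithm>\n#include <iostream>\n#include <vector>\n\nint main() {\n  const int ndim = 4;                    // input rank\n  const std::vector<int> axis = {1, 2};  // the \"dim\" attribute (reduced axes)\n  int offset = 0;\n  for (int i = 0; i < ndim; i++) {\n    if (std::find(axis.begin(), axis.end(), i) == axis.end()) {\n      // Axis i survives the reduction: input split i pairs with output split i - offset.\n      std::cout << \"in axis \" << i << \" -> out axis \" << i - offset << \"\\n\";\n    } else {\n      offset += 1;  // each reduced axis shifts all later output axes left by one\n    }\n  }\n  // Prints: in axis 0 -> out axis 0, then in axis 3 -> out axis 1.\n  return 0;\n}\n"
  },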
  {
    "path": "oneflow/user/ops/vector_matrix_product_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<void> InferTensorDesc4VectorMatrixProduct(user_op::InferContext* ctx) {\n  const user_op::TensorDesc& a = ctx->InputTensorDesc(\"a\", 0);\n  const user_op::TensorDesc& b = ctx->InputTensorDesc(\"b\", 0);\n  int64_t k = a.shape().At(0);\n  CHECK_EQ_OR_RETURN(k, b.shape().At(0)) << \"Dim K should be equal to vector b's dim0. \";\n  int64_t n = b.shape().At(1);\n  ctx->SetOutputShape(\"out\", 0, Shape({n}));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType4VectorMatrixProduct(user_op::InferContext* ctx) {\n  DataType dtype = ctx->InputDType(\"a\", 0);\n  CHECK_EQ_OR_RETURN(ctx->InputDType(\"b\", 0), dtype)\n      << \"Matrix A datatype should be equal to Vector B. \";\n  ctx->SetOutputDType(\"out\", 0, dtype);\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferTensorDesc4VectorMatrixProductGradA(user_op::InferContext* ctx) {\n  /*\n  A(k, ) matmul B(k, n) -> (1, k) matmul (k, n) -> (1, n) -> (n)\n  GradA = dy (n) matmul B_transpose(n, k) -> (1, n) matmul (n, k)\n  */\n  const user_op::TensorDesc& b = ctx->InputTensorDesc(\"b\", 0);\n  int64_t k = b.shape().At(0);\n  ctx->SetOutputShape(\"dx\", 0, Shape({k}));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferTensorDesc4VectorMatrixProductGradB(user_op::InferContext* ctx) {\n  /*\n  A(k, ) matmul B(k, n) -> (1, k) matmul (k, n) -> (1, n) -> (n)\n  GradB = a (k, 1) matmul dy (1, n)\n  */\n  const user_op::TensorDesc& dy = ctx->InputTensorDesc(\"dy\", 0);\n  const user_op::TensorDesc& a = ctx->InputTensorDesc(\"a\", 0);\n  int64_t k = a.shape().At(0);\n  int64_t n = dy.shape().At(0);\n  ctx->SetOutputShape(\"dx\", 0, Shape({k, n}));\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> InferDataType4Grad(user_op::InferContext* ctx) {\n  DataType dtype = ctx->InputDType(\"dy\", 0);\n  ctx->SetOutputDType(\"dx\", 0, dtype);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace\n\n/* static */ Maybe<void> VectorMatrixProductOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  return InferTensorDesc4VectorMatrixProduct(ctx);\n}\n\n/*static*/ Maybe<void> VectorMatrixProductOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> VectorMatrixProductOp::GetSbp(user_op::SbpContext* ctx) {\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"a\", 0))\n      .Split(user_op::OpArg(\"b\", 0), 1)\n      .Split(user_op::OpArg(\"out\", 0), 0)\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"a\", 0), 0)\n      .Split(user_op::OpArg(\"b\", 0), 0)\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"a\", 0))\n      .Broadcast(user_op::OpArg(\"b\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  ctx->NewBuilder()\n  
    .Broadcast(user_op::OpArg(\"a\", 0))\n      .PartialSum(user_op::OpArg(\"b\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> VectorMatrixProductOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType4VectorMatrixProduct(ctx);\n}\n\n/* static */ Maybe<void> VectorMatrixProductGradAOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferTensorDesc4VectorMatrixProductGradA(ctx);\n}\n\n/*static*/ Maybe<void> VectorMatrixProductGradAOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> VectorMatrixProductGradAOp::GetSbp(user_op::SbpContext* ctx) {\n  /*\n  A(k, ) matmul B(k, n) -> (1, k) matmul (k, n) -> (1, n) -> (n)\n  GradA = dy (n) matmul B_transpose(n, k) -> (1, n) matmul (n, k)\n  */\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"dy\", 0))\n      .Split(user_op::OpArg(\"b\", 0), 0)\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Build();\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"b\", 0), 1)\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"dy\", 0))\n      .Broadcast(user_op::OpArg(\"b\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"dy\", 0))\n      .PartialSum(user_op::OpArg(\"b\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> VectorMatrixProductGradAOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType4Grad(ctx);\n}\n\n/* static */ Maybe<void> VectorMatrixProductGradBOp::InferLogicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferTensorDesc4VectorMatrixProductGradB(ctx);\n}\n\n/*static*/ Maybe<void> VectorMatrixProductGradBOp::InferPhysicalTensorDesc(\n    user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/* static */ Maybe<void> VectorMatrixProductGradBOp::GetSbp(user_op::SbpContext* ctx) {\n  /*\n  A(k, ) matmul B(k, n) -> (1, k) matmul (k, n) -> (1, n) -> (n)\n  A(k, ) -> (1, k)\n  GradB = a_transpose (k, 1) matmul dy (1, n)\n  */\n  ctx->NewBuilder()\n      .Split(user_op::OpArg(\"a\", 0), 0)\n      .Broadcast(user_op::OpArg(\"dy\", 0))\n      .Split(user_op::OpArg(\"dx\", 0), 0)\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"a\", 0))\n      .Split(user_op::OpArg(\"dy\", 0), 0)\n      .Split(user_op::OpArg(\"dx\", 0), 1)\n      .Build();\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"a\", 0))\n      .PartialSum(user_op::OpArg(\"dy\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"a\", 0))\n      .Broadcast(user_op::OpArg(\"dy\", 0))\n      .PartialSum(user_op::OpArg(\"dx\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n\n/* static */ Maybe<void> VectorMatrixProductGradBOp::InferDataType(user_op::InferContext* ctx) {\n  return InferDataType4Grad(ctx);\n}\n\n}  // namespace oneflow\n"
  },
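  {
    "path": "oneflow/user/ops/vector_matrix_product_sbp_sketch.example.cpp",
    "content": "/*\nIllustrative sketch added for exposition; a hypothetical file, not part of the\nOneFlow build. It numerically checks two of the SBP signatures declared in\nvector_matrix_product_op.cpp above for out = a(k) matmul b(k, n):\n  1) splitting b on dim 1 (columns) splits out on dim 0, and\n  2) splitting both a and b on dim 0 (k) yields per-shard partial sums whose\n     elementwise sum equals the full product.\nOnly the standard library is assumed.\n*/\n#include <cstddef>\n#include <iostream>\n#include <vector>\n\n// out[j] = sum_k a[k] * b[k][j]\nstd::vector<double> VecMat(const std::vector<double>& a,\n                           const std::vector<std::vector<double>>& b) {\n  std::vector<double> out(b[0].size(), 0.0);\n  for (size_t k = 0; k < a.size(); ++k) {\n    for (size_t j = 0; j < b[k].size(); ++j) { out[j] += a[k] * b[k][j]; }\n  }\n  return out;\n}\n\nint main() {\n  const std::vector<double> a = {1, 2};\n  const std::vector<std::vector<double>> b = {{1, 2, 3}, {4, 5, 6}};\n  const std::vector<double> full = VecMat(a, b);  // {9, 12, 15}\n\n  // Signature 2: shard0 owns k=0, shard1 owns k=1; summing the shards (the\n  // PartialSum reduction) reproduces the full result.\n  const std::vector<double> p0 = VecMat({a[0]}, {b[0]});\n  const std::vector<double> p1 = VecMat({a[1]}, {b[1]});\n  for (size_t j = 0; j < full.size(); ++j) {\n    std::cout << full[j] << \" == \" << p0[j] + p1[j] << \"\\n\";\n  }\n  return 0;\n}\n"
  },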
  {
    "path": "oneflow/user/ops/where_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n#include \"oneflow/core/framework/dtype.h\"\n\nnamespace oneflow {\n\nnamespace {\n\nMaybe<Shape> GetBroadcastShape(const Shape& cond_shape, const Shape& x_shape,\n                               const Shape& y_shape) {\n  size_t ndim = std::max(x_shape.size(), y_shape.size());\n  ndim = std::max(ndim, cond_shape.size());\n\n  DimVector broadcast_dim_vec(ndim);\n  for (size_t i = 0; i < ndim; ++i) {\n    size_t cond_lpad = ndim - cond_shape.size();\n    size_t x_lpad = ndim - x_shape.size();\n    size_t y_lpad = ndim - y_shape.size();\n    int64_t cond_dim = (i < cond_lpad) ? 1 : cond_shape[i - cond_lpad];\n    int64_t x_dim = (i < x_lpad) ? 1 : x_shape[i - x_lpad];\n    int64_t y_dim = (i < y_lpad) ? 1 : y_shape[i - y_lpad];\n    int64_t max_dim = std::max(x_dim, y_dim);\n    max_dim = std::max(max_dim, cond_dim);\n    broadcast_dim_vec[i] = max_dim;\n    if ((cond_dim != 1 && cond_dim != max_dim) || (x_dim != 1 && x_dim != max_dim)\n        || (y_dim != 1 && y_dim != max_dim)) {\n      return Error::RuntimeError() << \"The tensor cond with size \" << cond_shape.ToString()\n                                   << \", x with size \" << x_shape.ToString() << \" and y with size \"\n                                   << y_shape.ToString() << \" are not broadcastable.\";\n    }\n  }\n  return Shape(broadcast_dim_vec);\n}\n\n}  // namespace\n\n/*static*/ Maybe<void> WhereOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  const Shape& cond_shape = ctx->InputShape(\"condition\", 0);\n  const Shape& x_shape = ctx->InputShape(\"x\", 0);\n  const Shape& y_shape = ctx->InputShape(\"y\", 0);\n  ctx->SetOutputShape(\"out\", 0, *JUST(GetBroadcastShape(cond_shape, x_shape, y_shape)));\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> WhereOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n\n/*static*/ Maybe<void> WhereOp::InferDataType(user_op::InferContext* ctx) {\n  DataType cond_dtype = ctx->InputDType(\"condition\", 0);\n  CHECK_OR_RETURN(IsBoolDataType(cond_dtype) || IsIntegralDataType(cond_dtype));\n  DataType x_dtype = ctx->InputDType(\"x\", 0);\n  CHECK_EQ_OR_RETURN(x_dtype, ctx->InputDType(\"y\", 0))\n      << \"InferDataType Failed. 
Expected \" << DataType_Name(ctx->InputDType(\"y\", 0)) << \", but got \"\n      << DataType_Name(x_dtype);\n  ctx->SetOutputDType(\"out\", 0, x_dtype);\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> WhereOp::GetSbp(user_op::SbpContext* ctx) {\n  const Shape& cond_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"condition\", 0).shape();\n  const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"x\", 0).shape();\n  const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"y\", 0).shape();\n  Shape broadcast_shape = *JUST(GetBroadcastShape(cond_shape, x_shape, y_shape));\n  const size_t ndim = broadcast_shape.size();\n\n  std::vector<user_op::OpArg> broadcast_args;\n  std::vector<user_op::OpArg> split_args;\n  std::vector<int> split_dims;\n  broadcast_args.reserve(3);\n  split_args.reserve(3);\n  split_dims.reserve(3);\n\n  auto CheckArgCanSplit = [&](std::string&& arg_name, const int dim, const Shape& shape) {\n    size_t ddiff = ndim - shape.size();\n    int dim_size = (dim >= ddiff) ? shape[dim - ddiff] : 1;\n    if (dim_size == 1) {\n      broadcast_args.emplace_back(std::forward<decltype(arg_name)>(arg_name), 0);\n    } else {\n      split_args.emplace_back(std::forward<decltype(arg_name)>(arg_name), 0);\n      split_dims.push_back(dim - ddiff);\n    }\n  };\n\n  for (int i = 0; i < ndim; ++i) {\n    if (broadcast_shape[i] == 1) { continue; }\n    broadcast_args.clear();\n    split_args.clear();\n    split_dims.clear();\n    CheckArgCanSplit(\"x\", i, x_shape);\n    CheckArgCanSplit(\"y\", i, y_shape);\n    CheckArgCanSplit(\"condition\", i, cond_shape);\n\n    auto builder = ctx->NewBuilder();\n    builder.Broadcast(broadcast_args);\n    for (int i = 0; i < split_args.size(); ++i) { builder.Split(split_args[i], split_dims[i]); }\n    builder.Split(user_op::OpArg(\"out\", 0), i);\n    builder.Build();\n  }\n\n  ctx->NewBuilder()\n      .Broadcast(user_op::OpArg(\"condition\", 0))\n      .PartialSum(user_op::OpArg(\"x\", 0))\n      .PartialSum(user_op::OpArg(\"y\", 0))\n      .PartialSum(user_op::OpArg(\"out\", 0))\n      .Build();\n\n  return Maybe<void>::Ok();\n}\n\n/*static*/ Maybe<void> WhereOp::ModifyInputArg(const GetInputArgModifier& fn,\n                                               const user_op::UserOpConfWrapper& conf) {\n  user_op::InputArgModifier* cond_arg_modifier = fn(\"condition\", 0);\n  cond_arg_modifier->set_requires_grad(false);\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
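  {
    "path": "oneflow/user/ops/where_broadcast_sketch.example.cpp",
    "content": "/*\nIllustrative sketch added for exposition; a hypothetical file, not part of the\nOneFlow build. It mirrors GetBroadcastShape in where_op.cpp above: the three\nshapes are right-aligned, size-1 dims stretch to the largest dim, and a non-1\ndim that disagrees with the maximum is an error. Only the standard library is\nassumed.\n*/\n#include <algorithm>\n#include <cstdint>\n#include <iostream>\n#include <stdexcept>\n#include <vector>\n\nusing Shape = std::vector<int64_t>;\n\nShape Broadcast(const Shape& cond, const Shape& x, const Shape& y) {\n  const size_t ndim = std::max({cond.size(), x.size(), y.size()});\n  Shape out(ndim);\n  for (size_t i = 0; i < ndim; ++i) {\n    auto dim = [&](const Shape& s) {\n      size_t lpad = ndim - s.size();  // right-align shorter shapes with 1s\n      return (i < lpad) ? int64_t{1} : s[i - lpad];\n    };\n    const int64_t c = dim(cond), a = dim(x), b = dim(y);\n    const int64_t m = std::max({c, a, b});\n    if ((c != 1 && c != m) || (a != 1 && a != m) || (b != 1 && b != m)) {\n      throw std::runtime_error(\"not broadcastable\");\n    }\n    out[i] = m;\n  }\n  return out;\n}\n\nint main() {\n  // cond (4, 1), x (3, 1, 5), y (5) broadcast to out (3, 4, 5).\n  for (int64_t d : Broadcast({4, 1}, {3, 1, 5}, {5})) { std::cout << d << \" \"; }\n  std::cout << \"\\n\";\n  return 0;\n}\n"
  },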
  {
    "path": "oneflow/user/ops/zero_like_op.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/core/framework/op_generated.h\"\n\nnamespace oneflow {\n\n/*static*/ Maybe<void> ZeroLikeOp::GetSbp(user_op::SbpContext* ctx) {\n  const user_op::TensorDesc& like_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex(\"like\", 0);\n  FOR_RANGE(int64_t, i, 0, like_tensor.shape().NumAxes()) {\n    ctx->NewBuilder()\n        .Split(user_op::OpArg(\"like\", 0), i)\n        .Split(user_op::OpArg(\"out\", 0), i)\n        .Build();\n  }\n  ctx->NewBuilder()\n      .PartialSum(user_op::OpArg(\"like\", 0))\n      .Broadcast(user_op::OpArg(\"out\", 0))\n      .Build();\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ZeroLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {\n  ctx->SetOutputShape(\"out\", 0, ctx->InputShape(\"like\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ZeroLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {\n  return InferLogicalTensorDesc(ctx);\n}\n/*static*/ Maybe<void> ZeroLikeOp::InferDataType(user_op::InferContext* ctx) {\n  ctx->SetOutputDType(\"out\", 0, ctx->InputDType(\"like\", 0));\n  return Maybe<void>::Ok();\n}\n/*static*/ Maybe<void> ZeroLikeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {\n  const NdSbp& in_sbp = ctx->NdSbpHint4InputArgNameAndIndex(\"like\", 0);\n  NdSbp* like_distribution = ctx->NdSbp4ArgNameAndIndex(\"like\", 0);\n  NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex(\"out\", 0);\n  *like_distribution = in_sbp;\n  *out_distribution = in_sbp;\n  return Maybe<void>::Ok();\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/summary/crc32c.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_SUMMARY_CRC32C_H_\n#define ONEFLOW_USER_SUMMARY_CRC32C_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace summary {\n\nstatic const uint32_t table[256] = {\n    0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, 0xc79a971f, 0x35f1141c, 0x26a1e7e8, 0xd4ca64eb,\n    0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b, 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24,\n    0x105ec76f, 0xe235446c, 0xf165b798, 0x030e349b, 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384,\n    0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, 0x5d1d08bf, 0xaf768bbc, 0xbc267848, 0x4e4dfb4b,\n    0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a, 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35,\n    0xaa64d611, 0x580f5512, 0x4b5fa6e6, 0xb93425e5, 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa,\n    0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, 0xf779deae, 0x05125dad, 0x1642ae59, 0xe4292d5a,\n    0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a, 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595,\n    0x417b1dbc, 0xb3109ebf, 0xa0406d4b, 0x522bee48, 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957,\n    0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, 0x0c38d26c, 0xfe53516f, 0xed03a29b, 0x1f682198,\n    0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927, 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38,\n    0xdbfc821c, 0x2997011f, 0x3ac7f2eb, 0xc8ac71e8, 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7,\n    0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, 0xa65c047d, 0x5437877e, 0x4767748a, 0xb50cf789,\n    0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859, 0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46,\n    0x7198540d, 0x83f3d70e, 0x90a324fa, 0x62c8a7f9, 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6,\n    0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, 0x3cdb9bdd, 0xceb018de, 0xdde0eb2a, 0x2f8b6829,\n    0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c, 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93,\n    0x082f63b7, 0xfa44e0b4, 0xe9141340, 0x1b7f9043, 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c,\n    0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, 0x55326b08, 0xa759e80b, 0xb4091bff, 0x466298fc,\n    0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c, 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033,\n    0xa24bb5a6, 0x502036a5, 0x4370c551, 0xb11b4652, 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d,\n    0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, 0xef087a76, 0x1d63f975, 0x0e330a81, 0xfc588982,\n    0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d, 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622,\n    0x38cc2a06, 0xcaa7a905, 0xd9f75af1, 0x2b9cd9f2, 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed,\n    0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, 0x0417b1db, 0xf67c32d8, 0xe52cc12c, 0x1747422f,\n    0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff, 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0,\n    0xd3d3e1ab, 0x21b862a8, 0x32e8915c, 0xc083125f, 0x144976b4, 0xe622f5b7, 0xf5720643, 
0x07198540,\n    0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, 0x9e902e7b, 0x6cfbad78, 0x7fab5e8c, 0x8dc0dd8f,\n    0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee, 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1,\n    0x69e9f0d5, 0x9b8273d6, 0x88d28022, 0x7ab90321, 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e,\n    0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, 0x34f4f86a, 0xc69f7b69, 0xd5cf889d, 0x27a40b9e,\n    0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e, 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351};\n\ninline uint32_t GetCrc32(const char* buf, size_t size) {\n  const uint8_t* uchar_buf = reinterpret_cast<const uint8_t*>(buf);\n  uint32_t crc = 0 ^ 0xffffffffu;\n  for (size_t i = 0; i < size; ++i) { crc = table[(crc & 0xff) ^ uchar_buf[i]] ^ (crc >> 8); }\n  return crc ^ 0xffffffffu;\n}\n\ninline uint32_t MaskCrc32(uint32_t crc) { return ((crc >> 15) | (crc << 17)) + 0xa282ead8ul; }\n\n}  // namespace summary\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_SUMMARY_CRC32C_H_\n"
  },
  {
    "path": "oneflow/user/summary/env_time.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_SUMMARY_ENV_TIME_H_\n#define ONEFLOW_USER_SUMMARY_ENV_TIME_H_\n\n#include \"oneflow/core/common/util.h\"\n\nnamespace oneflow {\n\nnamespace summary {\n\nstatic constexpr uint64_t kMicroTimeToNanoTime = 1000ULL;\nstatic constexpr uint64_t kSecondToNanoTime = 1000ULL * 1000ULL * 1000ULL;\nstatic constexpr uint64_t kMircoTimeToSecondTime = 1000ULL * 1000ULL;\n\ninline uint64_t CurrentNanoTime() {\n  struct timespec ts;\n  clock_gettime(CLOCK_REALTIME, &ts);\n  return (static_cast<uint64_t>(ts.tv_sec) * kSecondToNanoTime + static_cast<uint64_t>(ts.tv_nsec));\n}\n\ninline uint64_t CurrentMircoTime() { return CurrentNanoTime() / kMicroTimeToNanoTime; }\n\ninline uint64_t CurrentSecondTime() { return CurrentMircoTime() / kMircoTimeToSecondTime; }\n\ninline double GetWallTime() {\n  return static_cast<double>(CurrentNanoTime() / kMicroTimeToNanoTime) / 1.0e6;\n}\n\n}  // namespace summary\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_SUMMARY_ENV_TIME_H_\n"
  },
  {
    "path": "oneflow/user/summary/event_writer_helper.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/summary/event_writer_helper.h\"\n#include \"oneflow/user/summary/env_time.h\"\n#include \"oneflow/user/summary/events_writer.h\"\n#include \"oneflow/user/summary/histogram.h\"\n#include \"oneflow/core/common/protobuf.h\"\n#include \"oneflow/core/summary/summary.pb.h\"\n#include \"oneflow/core/summary/event.pb.h\"\n\n#include <png.h>\n#include <zlib.h>\n#include <memory>\n#include <type_traits>\n#define USER_LIBPNG_VER_STRING \"1.6.24\"\n\nnamespace oneflow {\n\nnamespace summary {\n\nconst char* kScalarPluginName = \"scalars\";\nconst char* kHistogramPluginName = \"histograms\";\nconst char* kImagePluginName = \"images\";\n\nvoid SetPluginData(SummaryMetadata* metadata, const char* name) {\n  if (metadata->plugin_data().plugin_name().empty()) {\n    metadata->mutable_plugin_data()->set_plugin_name(name);\n  }\n}\n\nMaybe<void> FillScalarInSummary(const float& value, const std::string& tag, Summary* s) {\n  SummaryMetadata metadata;\n  SetPluginData(&metadata, kScalarPluginName);\n  Summary::Value* v = s->add_value();\n  v->set_tag(tag);\n  *v->mutable_metadata() = metadata;\n  v->set_simple_value(value);\n  return Maybe<void>::Ok();\n}\n\ntemplate<typename T>\nMaybe<void> FillHistogramInSummary(const user_op::Tensor& value, const std::string& tag,\n                                   Summary* s) {\n  SummaryMetadata metadata;\n  SetPluginData(&metadata, kHistogramPluginName);\n  Summary::Value* v = s->add_value();\n  v->set_tag(tag);\n  *v->mutable_metadata() = metadata;\n  summary::Histogram histo;\n  for (int64_t i = 0; i < value.shape_view().elem_cnt(); i++) {\n    double double_val = value.dptr<T>()[i];\n    histo.AppendValue(double_val);\n  }\n  histo.AppendToProto(v->mutable_histo());\n  return Maybe<void>::Ok();\n}\n\nvoid WriteImageDataFn(png_structp png_ptr, png_bytep data, png_size_t length) {\n  std::string* const s = reinterpret_cast<std::string*>(png_get_io_ptr(png_ptr));\n  s->append(reinterpret_cast<const char*>(data), length);\n}\n\nbool WriteImageToBuffer(const uint8_t* image, int width, int height, int depth,\n                        std::string* png_string) {\n  CHECK_NOTNULL(image);\n  CHECK_NOTNULL(png_string);\n  if (width == 0 || height == 0) return false;\n  png_string->resize(0);\n  png_infop info_ptr = nullptr;\n  png_structp png_ptr = png_create_write_struct(USER_LIBPNG_VER_STRING, 0, 0, 0);\n  if (png_ptr == nullptr) return false;\n  if (setjmp(png_jmpbuf(png_ptr))) {\n    png_destroy_write_struct(&png_ptr, info_ptr ? 
&info_ptr : nullptr);\n    return false;\n  }\n  info_ptr = png_create_info_struct(png_ptr);\n  if (info_ptr == nullptr) {\n    png_destroy_write_struct(&png_ptr, nullptr);\n    return false;\n  }\n  int color_type = -1;\n  switch (depth) {\n    case 1: color_type = PNG_COLOR_TYPE_GRAY; break;\n    case 2: color_type = PNG_COLOR_TYPE_GRAY_ALPHA; break;\n    case 3: color_type = PNG_COLOR_TYPE_RGB; break;\n    case 4: color_type = PNG_COLOR_TYPE_RGB_ALPHA; break;\n    default: png_destroy_write_struct(&png_ptr, &info_ptr); return false;\n  }\n  const int bit_depth = 8;\n  png_set_write_fn(png_ptr, png_string, WriteImageDataFn, nullptr);\n  png_set_compression_level(png_ptr, Z_DEFAULT_COMPRESSION);\n  png_set_compression_mem_level(png_ptr, MAX_MEM_LEVEL);\n  png_set_IHDR(png_ptr, info_ptr, width, height, bit_depth, color_type, PNG_INTERLACE_NONE,\n               PNG_COMPRESSION_TYPE_DEFAULT, PNG_FILTER_TYPE_DEFAULT);\n  png_write_info(png_ptr, info_ptr);\n  png_byte* row = reinterpret_cast<png_byte*>(const_cast<uint8_t*>(image));\n  int row_bytes = width * depth;\n  for (; height--; row += row_bytes) png_write_row(png_ptr, row);\n  png_write_end(png_ptr, nullptr);\n  png_destroy_write_struct(&png_ptr, &info_ptr);\n  return true;\n}\n\nMaybe<void> FillImageInSummary(const user_op::Tensor& tensor, const std::string& tag, Summary* s) {\n  SummaryMetadata metadata;\n  SetPluginData(&metadata, kImagePluginName);\n  if (!(tensor.shape_view().NumAxes() == 4\n        && (tensor.shape_view().At(3) == 1 || tensor.shape_view().At(3) == 3\n            || tensor.shape_view().At(3) == 4))) {\n    UNIMPLEMENTED();\n  }\n  if (!(tensor.shape_view().At(0) < (1LL << 31) && tensor.shape_view().At(1) < (1LL << 31)\n        && tensor.shape_view().At(2) < (1LL << 31)\n        && (tensor.shape_view().At(1) * tensor.shape_view().At(2)) < (1LL << 29))) {\n    UNIMPLEMENTED();\n  }\n  const int64_t batch_size = static_cast<int64_t>(tensor.shape_view().At(0));\n  const int64_t h = static_cast<int64_t>(tensor.shape_view().At(1));\n  const int64_t w = static_cast<int64_t>(tensor.shape_view().At(2));\n  const int64_t hw = h * w;\n  const int64_t depth = static_cast<int64_t>(tensor.shape_view().At(3));\n  if (tensor.data_type() == DataType::kUInt8) {\n    auto ith_image = [&tensor, hw, depth](int i) {\n      auto images = tensor.dptr<uint8_t>();\n      auto image_i = std::unique_ptr<uint8_t[]>{new uint8_t[hw * depth]};\n      memcpy(image_i.get(), images + i * hw * depth, hw * depth);\n      return image_i;\n    };\n    for (int i = 0; i < batch_size; ++i) {\n      Summary::Value* v = s->add_value();\n      *v->mutable_metadata() = metadata;\n      if (batch_size == 1) {\n        v->set_tag(tag);\n      } else {\n        v->set_tag(tag + std::to_string(i));\n      }\n      Image* si = v->mutable_image();\n      si->set_height(h);\n      si->set_width(w);\n      si->set_colorspace(depth);\n      auto image = ith_image(i);\n      if (!WriteImageToBuffer(image.get(), w, h, depth, si->mutable_encoded_image_string()))\n        UNIMPLEMENTED();\n    }\n  }\n  return Maybe<void>::Ok();\n}\n\ntemplate<typename T>\nstruct EventWriterHelper<DeviceType::kCPU, T> {\n  static void WritePbToFile(int64_t step, const std::string& value) {\n    std::unique_ptr<Event> e{new Event};\n    Summary sum;\n    TxtString2PbMessage(value, &sum);\n    e->set_step(step);\n    e->set_wall_time(GetWallTime());\n    *e->mutable_summary() = sum;\n    Singleton<EventsWriter>::Get()->AppendQueue(std::move(e));\n  }\n\n  static void 
WriteScalarToFile(int64_t step, float value, const std::string& tag) {\n    std::unique_ptr<Event> e{new Event};\n    e->set_step(step);\n    e->set_wall_time(GetWallTime());\n    CHECK_JUST(FillScalarInSummary(value, tag, e->mutable_summary()));\n    Singleton<EventsWriter>::Get()->AppendQueue(std::move(e));\n  }\n\n  static void WriteHistogramToFile(int64_t step, const user_op::Tensor& value,\n                                   const std::string& tag) {\n    std::unique_ptr<Event> e{new Event};\n    e->set_step(step);\n    e->set_wall_time(GetWallTime());\n    CHECK_JUST(FillHistogramInSummary<T>(value, tag, e->mutable_summary()));\n    Singleton<EventsWriter>::Get()->AppendQueue(std::move(e));\n  }\n\n  static void WriteImageToFile(int64_t step, const user_op::Tensor& tensor,\n                               const std::string& tag) {\n    std::unique_ptr<Event> e{new Event};\n    e->set_step(step);\n    e->set_wall_time(GetWallTime());\n    CHECK_JUST(FillImageInSummary(tensor, tag, e->mutable_summary()));\n    Singleton<EventsWriter>::Get()->AppendQueue(std::move(e));\n  }\n};\n\n#define INSTANTIATE_EVENT_WRITE_HELPER_CPU(dtype) \\\n  template struct EventWriterHelper<DeviceType::kCPU, dtype>;\n\nINSTANTIATE_EVENT_WRITE_HELPER_CPU(float)\nINSTANTIATE_EVENT_WRITE_HELPER_CPU(double)\nINSTANTIATE_EVENT_WRITE_HELPER_CPU(int32_t)\nINSTANTIATE_EVENT_WRITE_HELPER_CPU(int64_t)\nINSTANTIATE_EVENT_WRITE_HELPER_CPU(uint8_t)\nINSTANTIATE_EVENT_WRITE_HELPER_CPU(int8_t)\n\n}  // namespace summary\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/summary/event_writer_helper.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_SUMMARY_EVENT_WRITER_HELPER_H_\n#define ONEFLOW_USER_SUMMARY_EVENT_WRITER_HELPER_H_\n\n#include \"oneflow/core/framework/framework.h\"\n\nnamespace oneflow {\n\nnamespace summary {\n\nclass EventsWriter;\n\ntemplate<DeviceType device_type, typename T>\nstruct EventWriterHelper {\n  static void WritePbToFile(int64_t step, const std::string& value);\n  static void WriteScalarToFile(int64_t step, float value, const std::string& tag);\n  static void WriteHistogramToFile(int64_t step, const user_op::Tensor& value,\n                                   const std::string& tag);\n  static void WriteImageToFile(int64_t step, const user_op::Tensor& tensor, const std::string& tag);\n};\n\n}  // namespace summary\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_SUMMARY_EVENT_WRITER_HELPER_H_\n"
  },
  {
    "path": "oneflow/user/summary/events_writer.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/summary/events_writer.h\"\n#include \"oneflow/core/common/str_util.h\"\n#include \"oneflow/user/summary/env_time.h\"\n\nnamespace oneflow {\n\nnamespace summary {\n\nEventsWriter::EventsWriter() : is_inited_(false) {}\n\nEventsWriter::~EventsWriter() { Close(); }\n\nMaybe<void> EventsWriter::Init(const std::string& logdir) {\n  file_system_ = std::make_unique<fs::PosixFileSystem>();\n  log_dir_ = logdir + \"/event\";\n  file_system_->RecursivelyCreateDirIfNotExist(log_dir_);\n  JUST(TryToInit());\n  is_inited_ = true;\n  last_flush_time_ = CurrentMircoTime();\n  return Maybe<void>::Ok();\n}\n\nMaybe<void> EventsWriter::TryToInit() {\n  if (!filename_.empty()) {\n    if (!file_system_->FileExists(filename_)) {\n      LOG(WARNING) << \"Event log file was lost, attempting create a new log file!\";\n    } else {\n      return Maybe<void>::Ok();\n    }\n  }\n\n  int32_t current_time = CurrentSecondTime();\n  char fname[100] = {'\\0'};\n  snprintf(fname, 100, \"event.%d.log\", current_time);\n\n  filename_ = JoinPath(log_dir_, fname);\n  file_system_->NewWritableFile(filename_, &writable_file_);\n  CHECK_OR_RETURN(writable_file_ != nullptr);\n  {\n    Event event;\n    event.set_wall_time(current_time);\n    event.set_file_version(FILE_VERSION);\n    WriteEvent(event);\n    Flush();\n  }\n  return Maybe<void>::Ok();\n}\n\nvoid EventsWriter::AppendQueue(std::unique_ptr<Event> event) {\n  queue_mutex.lock();\n  event_queue_.emplace_back(std::move(event));\n  queue_mutex.unlock();\n  if (event_queue_.size() > MAX_QUEUE_NUM || CurrentMircoTime() - last_flush_time_ > FLUSH_TIME) {\n    Flush();\n  }\n}\n\nvoid EventsWriter::Flush() {\n  queue_mutex.lock();\n  for (const std::unique_ptr<Event>& e : event_queue_) { WriteEvent(*e); }\n  event_queue_.clear();\n  queue_mutex.unlock();\n  FileFlush();\n  last_flush_time_ = CurrentMircoTime();\n}\n\nvoid EventsWriter::WriteEvent(const Event& event) {\n  std::string event_str;\n  event.AppendToString(&event_str);\n  if (!TryToInit().IsOk()) {\n    LOG(ERROR) << \"Write failed because file could not be opened.\";\n    return;\n  }\n  if (writable_file_ == nullptr) {\n    LOG(WARNING) << \"Log file is closed!\";\n    return;\n  }\n\n  char head[kHeadSize];\n  char tail[kTailSize];\n  EncodeHead(head, event_str.size());\n  EncodeTail(tail, event_str.data(), event_str.size());\n  writable_file_->Append(head, sizeof(head));\n  writable_file_->Append(event_str.data(), event_str.size());\n  writable_file_->Append(tail, sizeof(tail));\n  FileFlush();\n}\n\nvoid EventsWriter::FileFlush() {\n  if (writable_file_ == nullptr) { return; }\n  writable_file_->Flush();\n}\n\nvoid EventsWriter::Close() {\n  if (!is_inited_) { return; }\n  queue_mutex.unlock();\n  Flush();\n  if (writable_file_ != nullptr) {\n    writable_file_->Close();\n    writable_file_.reset(nullptr);\n  }\n}\n\n}  // namespace summary\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/summary/events_writer.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_SUMMARY_EVENTS_WRITER_H_\n#define ONEFLOW_USER_SUMMARY_EVENTS_WRITER_H_\n\n#include \"oneflow/core/persistence/posix/posix_file_system.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/user/summary/crc32c.h\"\n#include \"oneflow/core/summary/event.pb.h\"\n\n#include <time.h>\n#include <mutex>\n\nnamespace oneflow {\n\nnamespace summary {\n\n#define MAX_QUEUE_NUM 10\n#define FLUSH_TIME 3 * 60 * 1000 * 1000\n#define FILE_VERSION \"brain.Event:3\"\nconst size_t kHeadSize = sizeof(uint64_t) + sizeof(uint32_t);\nconst size_t kTailSize = sizeof(uint32_t);\n\nclass EventsWriter {\n public:\n  EventsWriter();\n  ~EventsWriter();\n\n  Maybe<void> Init(const std::string& logdir);\n  void WriteEvent(const Event& event);\n  void Flush();\n  void Close();\n\n  void AppendQueue(std::unique_ptr<Event> event);\n  void FileFlush();\n\n private:\n  Maybe<void> TryToInit();\n  inline static void EncodeHead(char* head, size_t size);\n  inline static void EncodeTail(char* tail, const char* data, size_t size);\n\n  bool is_inited_;\n  std::string log_dir_;\n  std::string filename_;\n  std::unique_ptr<fs::FileSystem> file_system_;\n  std::unique_ptr<fs::WritableFile> writable_file_;\n  uint64_t last_flush_time_;\n  std::vector<std::unique_ptr<Event>> event_queue_;\n  std::mutex queue_mutex;\n  OF_DISALLOW_COPY(EventsWriter);\n};\n\nvoid EventsWriter::EncodeHead(char* head, size_t size) {\n  memcpy(head, &size, sizeof(size));\n  uint32_t value = MaskCrc32(GetCrc32(head, sizeof(uint64_t)));\n  memcpy(head + sizeof(uint64_t), &value, sizeof(value));\n}\n\nvoid EventsWriter::EncodeTail(char* tail, const char* data, size_t size) {\n  uint32_t value = MaskCrc32(GetCrc32(data, size));\n  memcpy(tail, &value, sizeof(value));\n}\n\n}  // namespace summary\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_SUMMARY_EVENTS_WRITER_H_\n"
  },
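  {
    "path": "oneflow/user/summary/record_framing_sketch.example.cpp",
    "content": "/*\nIllustrative sketch added for exposition; a hypothetical file, not part of the\nOneFlow build. It spells out the on-disk record framing that\nEventsWriter::WriteEvent assembles via EncodeHead/EncodeTail (events_writer.h\nabove):\n  [8-byte payload length][4-byte masked crc32c of those length bytes]\n  [payload bytes][4-byte masked crc32c of the payload]\nIt reuses only GetCrc32/MaskCrc32 from crc32c.h and assumes a 64-bit size_t,\nmatching the memcpy of a size_t in EncodeHead.\n*/\n#include <cstdint>\n#include <cstring>\n#include <iostream>\n#include <string>\n\n#include \"oneflow/user/summary/crc32c.h\"\n\nint main() {\n  const std::string payload = \"example event bytes\";\n\n  // Head: the payload length followed by a masked crc of the length field.\n  char head[sizeof(uint64_t) + sizeof(uint32_t)];\n  const uint64_t size = payload.size();\n  std::memcpy(head, &size, sizeof(size));\n  const uint32_t head_crc =\n      oneflow::summary::MaskCrc32(oneflow::summary::GetCrc32(head, sizeof(uint64_t)));\n  std::memcpy(head + sizeof(uint64_t), &head_crc, sizeof(head_crc));\n\n  // Tail: a masked crc of the payload itself, exactly as EncodeTail computes.\n  const uint32_t tail_crc =\n      oneflow::summary::MaskCrc32(oneflow::summary::GetCrc32(payload.data(), payload.size()));\n\n  std::cout << \"record bytes = \" << sizeof(head) + payload.size() + sizeof(tail_crc) << \" (head \"\n            << sizeof(head) << \", payload \" << payload.size() << \", tail \" << sizeof(tail_crc)\n            << \")\\n\";\n  return 0;\n}\n"
  },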
  {
    "path": "oneflow/user/summary/histogram.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/summary/histogram.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include <cfloat>\n#include <algorithm>\n\nnamespace oneflow {\n\nnamespace summary {\n\nstatic std::vector<double> defalut_container = {-DBL_MAX,\n                                                -451872326.521804,\n                                                -410793024.1107308,\n                                                -373448203.737028,\n                                                -339498367.0336618,\n                                                -308634879.1215107,\n                                                -280577162.83773696,\n                                                -255070148.03430632,\n                                                -231881952.75846028,\n                                                -210801775.23496386,\n                                                -191637977.48633078,\n                                                -174216343.1693916,\n                                                -158378493.79035598,\n                                                -143980448.9003236,\n                                                -130891317.18211237,\n                                                -118992106.52919304,\n                                                -108174642.2992664,\n                                                -98340583.90842399,\n                                                -89400530.82583998,\n                                                -81273209.8416727,\n                                                -73884736.21970245,\n                                                -67167942.01791131,\n                                                -61061765.47082846,\n                                                -55510695.882571325,\n                                                -50464268.984155744,\n                                                -45876608.16741431,\n                                                -41706007.424922094,\n                                                -37914552.20447463,\n                                                -34467774.731340565,\n                                                -31334340.664855056,\n                                                -28485764.24077732,\n                                                -25896149.309797563,\n                                                -23541953.91799778,\n                                                -21401776.28908889,\n                                                -19456160.26280808,\n                                                -17687418.420734618,\n                                                -16079471.291576924,\n                                                -14617701.174160838,\n                                                -13288819.249237124,\n                                                
-12080744.772033747,\n                                                -10982495.247303406,\n                                                -9984086.58845764,\n                                                -9076442.353143308,\n                                                -8251311.230130279,\n                                                -7501192.027391163,\n                                                -6819265.479446511,\n                                                -6199332.254042282,\n                                                -5635756.594583892,\n                                                -5123415.085985356,\n                                                -4657650.078168505,\n                                                -4234227.34378955,\n                                                -3849297.5852632266,\n                                                -3499361.4411483877,\n                                                -3181237.6737712612,\n                                                -2892034.2488829647,\n                                                -2629122.0444390588,\n                                                -2390110.949490053,\n                                                -2172828.135900048,\n                                                -1975298.30536368,\n                                                -1795725.7321487998,\n                                                -1632477.9383170905,\n                                                -1484070.8530155367,\n                                                -1349155.320923215,\n                                                -1226504.8372029227,\n                                                -1115004.3974572024,\n                                                -1013640.3613247294,\n                                                -921491.2375679357,\n                                                -837719.3068799415,\n                                                -761563.0062544922,\n                                                -692330.005685902,\n                                                -629390.9142599108,\n                                                -572173.5584181007,\n                                                -520157.7803800915,\n                                                -472870.7094364468,\n                                                -429882.4631240425,\n                                                -390802.23920367495,\n                                                -355274.76291243173,\n                                                -322977.0571931197,\n                                                -293615.50653919973,\n                                                -266923.1877629088,\n                                                -242657.4434208262,\n                                                -220597.67583711472,\n                                                -200543.34167010427,\n                                                -182312.12879100387,\n                                                -165738.2989009126,\n                                                -150671.18081901144,\n                                                -136973.80074455583,\n                                                -124521.63704050529,\n                                                -113201.48821864116,\n                                                -102910.44383512832,\n                                                
-93554.94894102574,\n                                                -85049.95358275066,\n                                                -77318.13962068241,\n                                                -70289.21783698401,\n                                                -63899.28894271274,\n                                                -58090.26267519339,\n                                                -52809.32970472126,\n                                                -48008.4815497466,\n                                                -43644.07413613327,\n                                                -39676.43103284842,\n                                                -36069.48275713492,\n                                                -32790.438870122656,\n                                                -29809.489881929687,\n                                                -27099.536256299714,\n                                                -24635.942051181555,\n                                                -22396.310955619592,\n                                                -20360.2826869269,\n                                                -18509.347897206273,\n                                                -16826.679906551155,\n                                                -15296.98173322832,\n                                                -13906.347030207562,\n                                                -12642.133663825056,\n                                                -11492.848785295504,\n                                                -10448.04435026864,\n                                                -9498.222136607854,\n                                                -8634.74739691623,\n                                                -7849.770360832936,\n                                                -7136.154873484486,\n                                                -6487.413521349533,\n                                                -5897.648655772302,\n                                                -5361.49877797482,\n                                                -4874.089798158927,\n                                                -4430.990725599024,\n                                                -4028.173386908203,\n                                                -3661.9758062801843,\n                                                -3329.0689148001675,\n                                                -3026.42628618197,\n                                                -2751.2966238017907,\n                                                -2501.1787489107187,\n                                                -2273.798862646108,\n                                                -2067.089875132825,\n                                                -1879.1726137571134,\n                                                -1708.338739779194,\n                                                -1553.0352179810852,\n                                                -1411.8501981646227,\n                                                -1283.500180149657,\n                                                -1166.8183455905971,\n                                                -1060.7439505369064,\n                                                -964.3126823062785,\n                                                -876.6478930057076,\n                                                -796.9526300051887,\n                                                -724.5023909138079,\n 
                                               -658.6385371943708,\n                                                -598.762306540337,\n                                                -544.3293695821245,\n                                                -494.844881438295,\n                                                -449.85898312572266,\n                                                -408.9627119324751,\n                                                -371.78428357497734,\n                                                -337.9857123408885,\n                                                -307.2597384917168,\n                                                -279.3270349924698,\n                                                -253.93366817497255,\n                                                -230.84878924997503,\n                                                -209.86253568179546,\n                                                -190.78412334708676,\n                                                -173.44011213371522,\n                                                -157.67282921246837,\n                                                -143.33893564769852,\n                                                -130.30812331608956,\n                                                -118.46193028735415,\n                                                -107.69266389759467,\n                                                -97.90242172508606,\n                                                -89.00220156826005,\n                                                -80.91109233478186,\n                                                -73.55553848616532,\n                                                -66.86867135105938,\n                                                -60.78970122823579,\n                                                -55.26336475294163,\n                                                -50.2394225026742,\n                                                -45.67220227515836,\n                                                -41.520183886507596,\n                                                -37.745621715006905,\n                                                -34.314201559097185,\n                                                -31.19472869008835,\n                                                -28.35884426371668,\n                                                -25.78076751246971,\n                                                -23.437061374972462,\n                                                -21.306419431793145,\n                                                -19.36947221072104,\n                                                -17.608611100655487,\n                                                -16.00782827332317,\n                                                -14.552571157566518,\n                                                -13.229610143242288,\n                                                -12.026918312038443,\n                                                -10.933562101853129,\n                                                -9.93960191077557,\n                                                -9.0360017370687,\n                                                -8.214547033698818,\n                                                -7.467770030635288,\n                                                -6.788881846032079,\n                                                -6.171710769120072,\n                                                -5.6106461537455194,\n                    
                            -5.100587412495926,\n                                                -4.636897647723569,\n                                                -4.215361497930517,\n                                                -3.8321468163004693,\n                                                -3.4837698330004265,\n                                                -3.167063484545842,\n                                                -2.8791486223144016,\n                                                -2.6174078384676376,\n                                                -2.379461671334216,\n                                                -2.163146973940196,\n                                                -1.9664972490365416,\n                                                -1.7877247718514013,\n                                                -1.6252043380467283,\n                                                -1.4774584891333893,\n                                                -1.3431440810303539,\n                                                -1.221040073663958,\n                                                -1.1100364306035981,\n                                                -1.0091240278214528,\n                                                -0.9173854798376843,\n                                                -0.8339867998524402,\n                                                -0.7581698180476728,\n                                                -0.689245289134248,\n                                                -0.62658662648568,\n                                                -0.5696242058960727,\n                                                -0.5178401871782479,\n                                                -0.47076380652567984,\n                                                -0.4279670968415271,\n                                                -0.389060997128661,\n                                                -0.35369181557150997,\n                                                -0.32153801415591815,\n                                                -0.2923072855962892,\n                                                -0.26573389599662656,\n                                                -0.2415762690878423,\n                                                -0.21961479007985663,\n                                                -0.199649809163506,\n                                                -0.18149982651227817,\n                                                -0.16499984228388923,\n                                                -0.14999985662171747,\n                                                -0.13636350601974315,\n                                                -0.12396682365431194,\n                                                -0.11269711241301085,\n                                                -0.1024519203754644,\n                                                -0.09313810943224035,\n                                                -0.08467100857476395,\n                                                -0.07697364415887631,\n                                                -0.069976040144433,\n                                                -0.06361458194948454,\n                                                -0.05783143813589502,\n                                                -0.05257403466899547,\n                                                -0.04779457697181406,\n                                                
-0.04344961542892187,\n                                                -0.03949965038992897,\n                                                -0.0359087730817536,\n                                                -0.032644339165230546,\n                                                -0.029676671968391403,\n                                                -0.02697879269853764,\n                                                -0.02452617518048876,\n                                                -0.022296522891353417,\n                                                -0.020269566264866742,\n                                                -0.018426878422606128,\n                                                -0.01675170765691466,\n                                                -0.01522882514264969,\n                                                -0.013844386493317899,\n                                                -0.012585805903016271,\n                                                -0.01144164173001479,\n                                                -0.010401492481831627,\n                                                -0.009455902256210569,\n                                                -0.008596274778373244,\n                                                -0.007814795253066584,\n                                                -0.007104359320969622,\n                                                -0.006458508473608747,\n                                                -0.005871371339644315,\n                                                -0.005337610308767558,\n                                                -0.004852373007970507,\n                                                -0.004411248189064097,\n                                                -0.004010225626421907,\n                                                -0.0036456596603835515,\n                                                -0.0033142360548941373,\n                                                -0.003012941868085579,\n                                                -0.0027390380618959806,\n                                                -0.0024900346017236183,\n                                                -0.0022636678197487437,\n                                                -0.0020578798361352213,\n                                                -0.0018707998510320192,\n                                                -0.0017007271373018355,\n                                                -0.001546115579365305,\n                                                -0.0014055596176048226,\n                                                -0.0012777814705498386,\n                                                -0.0011616195186816714,\n                                                -0.0010560177442560648,\n                                                -0.000960016131141877,\n                                                -0.0008727419374017063,\n                                                -0.0007934017612742784,\n                                                -0.0007212743284311622,\n                                                -0.0006557039349374201,\n                                                -0.0005960944863067456,\n                                                -0.0005419040784606777,\n                                                -0.0004926400713278887,\n                                                -0.0004478546102980806,\n                                                
-0.0004071405548164369,\n                                                -0.00037012777710585166,\n                                                -0.000336479797368956,\n                                                -0.00030589072488086905,\n                                                -0.0002780824771644264,\n                                                -0.00025280225196766033,\n                                                -0.0002298202290615094,\n                                                -0.00020892748096500852,\n                                                -0.00018993407360455317,\n                                                -0.00017266733964050286,\n                                                -0.00015697030876409349,\n                                                -0.00014270028069463043,\n                                                -0.00012972752790420947,\n                                                -0.00011793411627655406,\n                                                -0.0001072128329786855,\n                                                -9.7466211798805e-05,\n                                                -8.860564708982272e-05,\n                                                -8.05505882634752e-05,\n                                                -7.322780751225018e-05,\n                                                -6.657073410204561e-05,\n                                                -6.051884918367783e-05,\n                                                -5.5017135621525293e-05,\n                                                -5.0015577837750266e-05,\n                                                -4.546870712522751e-05,\n                                                -4.1335188295661374e-05,\n                                                -3.75774439051467e-05,\n                                                -3.416131264104245e-05,\n                                                -3.105573876458404e-05,\n                                                -2.8232489785985488e-05,\n                                                -2.566589980544135e-05,\n                                                -2.3332636186764862e-05,\n                                                -2.1211487442513508e-05,\n                                                -1.9283170402285007e-05,\n                                                -1.7530154911168186e-05,\n                                                -1.5936504464698348e-05,\n                                                -1.448773133154395e-05,\n                                                -1.3170664846858136e-05,\n                                                -1.197333167896194e-05,\n                                                -1.088484698087449e-05,\n                                                -9.895315437158626e-06,\n                                                -8.995741306507842e-06,\n                                                -8.177946642279855e-06,\n                                                -7.43449694752714e-06,\n                                                -6.758633588661036e-06,\n                                                -6.144212353328214e-06,\n                                                0.0,\n                                                6.144212353328214e-06,\n                                                6.758633588661036e-06,\n                                                7.43449694752714e-06,\n                                   
             8.177946642279855e-06,\n                                                8.995741306507842e-06,\n                                                9.895315437158626e-06,\n                                                1.088484698087449e-05,\n                                                1.197333167896194e-05,\n                                                1.3170664846858136e-05,\n                                                1.448773133154395e-05,\n                                                1.5936504464698348e-05,\n                                                1.7530154911168186e-05,\n                                                1.9283170402285007e-05,\n                                                2.1211487442513508e-05,\n                                                2.3332636186764862e-05,\n                                                2.566589980544135e-05,\n                                                2.8232489785985488e-05,\n                                                3.105573876458404e-05,\n                                                3.416131264104245e-05,\n                                                3.75774439051467e-05,\n                                                4.1335188295661374e-05,\n                                                4.546870712522751e-05,\n                                                5.0015577837750266e-05,\n                                                5.5017135621525293e-05,\n                                                6.051884918367783e-05,\n                                                6.657073410204561e-05,\n                                                7.322780751225018e-05,\n                                                8.05505882634752e-05,\n                                                8.860564708982272e-05,\n                                                9.7466211798805e-05,\n                                                0.0001072128329786855,\n                                                0.00011793411627655406,\n                                                0.00012972752790420947,\n                                                0.00014270028069463043,\n                                                0.00015697030876409349,\n                                                0.00017266733964050286,\n                                                0.00018993407360455317,\n                                                0.00020892748096500852,\n                                                0.0002298202290615094,\n                                                0.00025280225196766033,\n                                                0.0002780824771644264,\n                                                0.00030589072488086905,\n                                                0.000336479797368956,\n                                                0.00037012777710585166,\n                                                0.0004071405548164369,\n                                                0.0004478546102980806,\n                                                0.0004926400713278887,\n                                                0.0005419040784606777,\n                                                0.0005960944863067456,\n                                                0.0006557039349374201,\n                                                0.0007212743284311622,\n                                                0.0007934017612742784,\n                                               
 0.0008727419374017063,\n                                                0.000960016131141877,\n                                                0.0010560177442560648,\n                                                0.0011616195186816714,\n                                                0.0012777814705498386,\n                                                0.0014055596176048226,\n                                                0.001546115579365305,\n                                                0.0017007271373018355,\n                                                0.0018707998510320192,\n                                                0.0020578798361352213,\n                                                0.0022636678197487437,\n                                                0.0024900346017236183,\n                                                0.0027390380618959806,\n                                                0.003012941868085579,\n                                                0.0033142360548941373,\n                                                0.0036456596603835515,\n                                                0.004010225626421907,\n                                                0.004411248189064097,\n                                                0.004852373007970507,\n                                                0.005337610308767558,\n                                                0.005871371339644315,\n                                                0.006458508473608747,\n                                                0.007104359320969622,\n                                                0.007814795253066584,\n                                                0.008596274778373244,\n                                                0.009455902256210569,\n                                                0.010401492481831627,\n                                                0.01144164173001479,\n                                                0.012585805903016271,\n                                                0.013844386493317899,\n                                                0.01522882514264969,\n                                                0.01675170765691466,\n                                                0.018426878422606128,\n                                                0.020269566264866742,\n                                                0.022296522891353417,\n                                                0.02452617518048876,\n                                                0.02697879269853764,\n                                                0.029676671968391403,\n                                                0.032644339165230546,\n                                                0.0359087730817536,\n                                                0.03949965038992897,\n                                                0.04344961542892187,\n                                                0.04779457697181406,\n                                                0.05257403466899547,\n                                                0.05783143813589502,\n                                                0.06361458194948454,\n                                                0.069976040144433,\n                                                0.07697364415887631,\n                                                0.08467100857476395,\n                                                0.09313810943224035,\n                                                
0.1024519203754644,\n                                                0.11269711241301085,\n                                                0.12396682365431194,\n                                                0.13636350601974315,\n                                                0.14999985662171747,\n                                                0.16499984228388923,\n                                                0.18149982651227817,\n                                                0.199649809163506,\n                                                0.21961479007985663,\n                                                0.2415762690878423,\n                                                0.26573389599662656,\n                                                0.2923072855962892,\n                                                0.32153801415591815,\n                                                0.35369181557150997,\n                                                0.389060997128661,\n                                                0.4279670968415271,\n                                                0.47076380652567984,\n                                                0.5178401871782479,\n                                                0.5696242058960727,\n                                                0.62658662648568,\n                                                0.689245289134248,\n                                                0.7581698180476728,\n                                                0.8339867998524402,\n                                                0.9173854798376843,\n                                                1.0091240278214528,\n                                                1.1100364306035981,\n                                                1.221040073663958,\n                                                1.3431440810303539,\n                                                1.4774584891333893,\n                                                1.6252043380467283,\n                                                1.7877247718514013,\n                                                1.9664972490365416,\n                                                2.163146973940196,\n                                                2.379461671334216,\n                                                2.6174078384676376,\n                                                2.8791486223144016,\n                                                3.167063484545842,\n                                                3.4837698330004265,\n                                                3.8321468163004693,\n                                                4.215361497930517,\n                                                4.636897647723569,\n                                                5.100587412495926,\n                                                5.6106461537455194,\n                                                6.171710769120072,\n                                                6.788881846032079,\n                                                7.467770030635288,\n                                                8.214547033698818,\n                                                9.0360017370687,\n                                                9.93960191077557,\n                                                10.933562101853129,\n                                                12.026918312038443,\n                                                13.229610143242288,\n                         
                       14.552571157566518,\n                                                16.00782827332317,\n                                                17.608611100655487,\n                                                19.36947221072104,\n                                                21.306419431793145,\n                                                23.437061374972462,\n                                                25.78076751246971,\n                                                28.35884426371668,\n                                                31.19472869008835,\n                                                34.314201559097185,\n                                                37.745621715006905,\n                                                41.520183886507596,\n                                                45.67220227515836,\n                                                50.2394225026742,\n                                                55.26336475294163,\n                                                60.78970122823579,\n                                                66.86867135105938,\n                                                73.55553848616532,\n                                                80.91109233478186,\n                                                89.00220156826005,\n                                                97.90242172508606,\n                                                107.69266389759467,\n                                                118.46193028735415,\n                                                130.30812331608956,\n                                                143.33893564769852,\n                                                157.67282921246837,\n                                                173.44011213371522,\n                                                190.78412334708676,\n                                                209.86253568179546,\n                                                230.84878924997503,\n                                                253.93366817497255,\n                                                279.3270349924698,\n                                                307.2597384917168,\n                                                337.9857123408885,\n                                                371.78428357497734,\n                                                408.9627119324751,\n                                                449.85898312572266,\n                                                494.844881438295,\n                                                544.3293695821245,\n                                                598.762306540337,\n                                                658.6385371943708,\n                                                724.5023909138079,\n                                                796.9526300051887,\n                                                876.6478930057076,\n                                                964.3126823062785,\n                                                1060.7439505369064,\n                                                1166.8183455905971,\n                                                1283.500180149657,\n                                                1411.8501981646227,\n                                                1553.0352179810852,\n                                                1708.338739779194,\n                                                1879.1726137571134,\n                       
                         2067.089875132825,\n                                                2273.798862646108,\n                                                2501.1787489107187,\n                                                2751.2966238017907,\n                                                3026.42628618197,\n                                                3329.0689148001675,\n                                                3661.9758062801843,\n                                                4028.173386908203,\n                                                4430.990725599024,\n                                                4874.089798158927,\n                                                5361.49877797482,\n                                                5897.648655772302,\n                                                6487.413521349533,\n                                                7136.154873484486,\n                                                7849.770360832936,\n                                                8634.74739691623,\n                                                9498.222136607854,\n                                                10448.04435026864,\n                                                11492.848785295504,\n                                                12642.133663825056,\n                                                13906.347030207562,\n                                                15296.98173322832,\n                                                16826.679906551155,\n                                                18509.347897206273,\n                                                20360.2826869269,\n                                                22396.310955619592,\n                                                24635.942051181555,\n                                                27099.536256299714,\n                                                29809.489881929687,\n                                                32790.438870122656,\n                                                36069.48275713492,\n                                                39676.43103284842,\n                                                43644.07413613327,\n                                                48008.4815497466,\n                                                52809.32970472126,\n                                                58090.26267519339,\n                                                63899.28894271274,\n                                                70289.21783698401,\n                                                77318.13962068241,\n                                                85049.95358275066,\n                                                93554.94894102574,\n                                                102910.44383512832,\n                                                113201.48821864116,\n                                                124521.63704050529,\n                                                136973.80074455583,\n                                                150671.18081901144,\n                                                165738.2989009126,\n                                                182312.12879100387,\n                                                200543.34167010427,\n                                                220597.67583711472,\n                                                242657.4434208262,\n                                                266923.1877629088,\n                         
                       293615.50653919973,\n                                                322977.0571931197,\n                                                355274.76291243173,\n                                                390802.23920367495,\n                                                429882.4631240425,\n                                                472870.7094364468,\n                                                520157.7803800915,\n                                                572173.5584181007,\n                                                629390.9142599108,\n                                                692330.005685902,\n                                                761563.0062544922,\n                                                837719.3068799415,\n                                                921491.2375679357,\n                                                1013640.3613247294,\n                                                1115004.3974572024,\n                                                1226504.8372029227,\n                                                1349155.320923215,\n                                                1484070.8530155367,\n                                                1632477.9383170905,\n                                                1795725.7321487998,\n                                                1975298.30536368,\n                                                2172828.135900048,\n                                                2390110.949490053,\n                                                2629122.0444390588,\n                                                2892034.2488829647,\n                                                3181237.6737712612,\n                                                3499361.4411483877,\n                                                3849297.5852632266,\n                                                4234227.34378955,\n                                                4657650.078168505,\n                                                5123415.085985356,\n                                                5635756.594583892,\n                                                6199332.254042282,\n                                                6819265.479446511,\n                                                7501192.027391163,\n                                                8251311.230130279,\n                                                9076442.353143308,\n                                                9984086.58845764,\n                                                10982495.247303406,\n                                                12080744.772033747,\n                                                13288819.249237124,\n                                                14617701.174160838,\n                                                16079471.291576924,\n                                                17687418.420734618,\n                                                19456160.26280808,\n                                                21401776.28908889,\n                                                23541953.91799778,\n                                                25896149.309797563,\n                                                28485764.24077732,\n                                                31334340.664855056,\n                                                34467774.731340565,\n                                                37914552.20447463,\n                         
                        41706007.424922094,\n                                                45876608.16741431,\n                                                50464268.984155744,\n                                                55510695.882571325,\n                                                61061765.47082846,\n                                                67167942.01791131,\n                                                73884736.21970245,\n                                                81273209.8416727,\n                                                89400530.82583998,\n                                                98340583.90842399,\n                                                108174642.2992664,\n                                                118992106.52919304,\n                                                130891317.18211237,\n                                                143980448.9003236,\n                                                158378493.79035598,\n                                                174216343.1693916,\n                                                191637977.48633078,\n                                                210801775.23496386,\n                                                231881952.75846028,\n                                                255070148.03430632,\n                                                280577162.83773696,\n                                                308634879.1215107,\n                                                339498367.0336618,\n                                                373448203.737028,\n                                                410793024.1107308,\n                                                451872326.521804,\n                                                DBL_MAX};\n\nHistogram::Histogram()\n    : value_count_(0),\n      value_sum_(0),\n      sum_value_squares_(0),\n      min_value_(DBL_MAX),\n      max_value_(-DBL_MAX),\n      max_containers_(default_container),\n      containers_(default_container.size(), 0.0) {}\n\nvoid Histogram::AppendValue(double value) {\n  value_sum_ += value;\n  value_count_++;\n  sum_value_squares_ += value * value;\n  if (max_value_ < value) { max_value_ = value; }\n  if (min_value_ > value) { min_value_ = value; }\n  // upper_bound selects the first bucket whose upper bound is greater than value.\n  const size_t idx = std::upper_bound(max_containers_.begin(), max_containers_.end(), value)\n                     - max_containers_.begin();\n  CHECK_GT(containers_.size(), idx);\n  containers_.at(idx) += 1.0;\n}\n\nvoid Histogram::AppendToProto(HistogramProto* hist_proto) {\n  hist_proto->Clear();\n  hist_proto->set_num(value_count_);\n  hist_proto->set_sum(value_sum_);\n  hist_proto->set_min(min_value_);\n  hist_proto->set_max(max_value_);\n  hist_proto->set_sum_squares(sum_value_squares_);\n  for (size_t idx = 0; idx < containers_.size();) {\n    double num = containers_.at(idx);\n    double last = max_containers_.at(idx);\n    idx++;\n    if (num <= 0.0) {\n      // Collapse a run of consecutive empty buckets into a single empty bucket so the\n      // serialized proto stays compact.\n      while (idx < containers_.size() && containers_.at(idx) <= 0.0) {\n        last = max_containers_.at(idx);\n        num = containers_.at(idx);\n        idx++;\n      }\n    }\n    hist_proto->add_bucket_limit(last);\n    hist_proto->add_bucket(num);\n  }\n}\n\n}  // namespace summary\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/summary/histogram.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_SUMMARY_HISTOGRAM_H_\n#define ONEFLOW_USER_SUMMARY_HISTOGRAM_H_\n\n#include <vector>\n#include \"oneflow/core/summary/summary.pb.h\"\n\nnamespace oneflow {\n\nnamespace summary {\n\nclass Histogram {\n public:\n  Histogram();\n  ~Histogram() {}\n\n  void AppendValue(double value);\n  void AppendToProto(HistogramProto* proto);\n\n private:\n  double value_count_;\n  double value_sum_;\n  double sum_value_squares_;\n  double min_value_;\n  double max_value_;\n\n  std::vector<double> max_constainers_;\n  std::vector<double> containers_;\n};\n\n}  // namespace summary\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_SUMMARY_HISTOGRAM_H_\n"
  },
  {
    "path": "oneflow/user/summary/plan_to_physical_graph.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/summary/plan_to_physical_graph.h\"\n#include \"oneflow/core/summary/graph.pb.h\"\n#include \"oneflow/core/common/util.h\"\n#include \"oneflow/core/persistence/tee_persistent_log_stream.h\"\n#include \"oneflow/core/job/id_manager.h\"\n#include \"oneflow/core/framework/to_string.h\"\n#include \"oneflow/core/job/plan_util.h\"\n\nnamespace oneflow {\n\nnamespace summary {\n\nvoid PlanToPhysicalGraphFile(const Plan& plan) {\n  GraphDef physical_graph;\n  physical_graph.set_version(3);  // \"compute graph version number = 3\"\n  HashMap<int64_t, std::string> regst_desc_id2produce_op_name;\n  HashMap<int64_t, std::string> task_id2op_name;\n  HashSet<int64_t> ctrl_regst_desc_id_set;\n  for (const TaskProto& task : plan.task()) {\n    std::string op_name = \"\";\n    for (const ExecNodeProto& exec_node : task.exec_sequence().exec_node()) {\n      if (op_name != \"\") { op_name += \" && \"; }\n      op_name += (exec_node.kernel_conf().op_attribute().op_conf().name());\n    }\n    if (op_name == \"\") { continue; }\n    task_id2op_name.insert({task.task_id(), op_name});\n    for (const auto& pair : task.produced_regst_desc()) {\n      const RegstDescProto& regst = pair.second;\n      int64_t regst_desc_id = regst.regst_desc_id();\n      regst_desc_id2produce_op_name.insert({regst_desc_id, op_name});\n      if (regst.regst_desc_type().has_ctrl_regst_desc()) {\n        ctrl_regst_desc_id_set.insert(regst_desc_id);\n      }\n    }\n  }\n\n  for (const TaskProto& task : plan.task()) {\n    if (task_id2op_name.find(task.task_id()) == task_id2op_name.end()) { continue; }\n    NodeDef* node = physical_graph.add_node();\n    node->set_name(task_id2op_name.at(task.task_id()));\n    const OperatorConf& op_conf =\n        task.exec_sequence().exec_node(0).kernel_conf().op_attribute().op_conf();\n    DeviceType device_type = PlanUtil::GetStreamId(task).device_id().device_type();\n    node->set_device(*CHECK_JUST(DeviceTag4DeviceType(device_type)));\n    if (op_conf.has_user_conf()) {\n      const UserOpConf& user_op = op_conf.user_conf();\n      node->set_op(user_op.op_type_name());\n      node->mutable_attr()->insert(user_op.attr().begin(), user_op.attr().end());\n    } else {\n      // maybe need get op / attr by every different op_type_case\n      node->set_op(\"system_op\");\n    }\n    for (const auto& pair : task.consumed_regst_desc_id()) {\n      for (int64_t regst_desc_id : pair.second.regst_desc_id()) {\n        if (regst_desc_id2produce_op_name.find(regst_desc_id)\n            != regst_desc_id2produce_op_name.end()) {\n          std::string input_name = regst_desc_id2produce_op_name.at(regst_desc_id);\n          if (ctrl_regst_desc_id_set.find(regst_desc_id) != ctrl_regst_desc_id_set.end()) {\n            input_name = \"^\" + input_name;  // control edge\n          }\n          node->add_input(input_name);\n        }\n      }\n    }\n  }\n  
TeePersistentLogStream::Create(\"physical_graph\")->Write(physical_graph);\n}\n\n}  // namespace summary\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/summary/plan_to_physical_graph.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_SUMMARY_PLAN_TO_PHYSICAL_GRAPH_H_\n#define ONEFLOW_USER_SUMMARY_PLAN_TO_PHYSICAL_GRAPH_H_\n\n#include \"oneflow/core/job/plan.pb.h\"\n\nnamespace oneflow {\n\nnamespace summary {\n\nvoid PlanToPhysicalGraphFile(const Plan& plan);\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_SUMMARY_PLAN_TO_PHYSICAL_GRAPH_H_\n"
  },
  {
    "path": "oneflow/user/summary/summary_converter.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_SUMMARY_SUMMARY_CONVERTER_H_\n#define ONEFLOW_USER_SUMMARY_SUMMARY_CONVERTER_H_\n\n#include \"nlohmann/json.hpp\"\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/protobuf.h\"\n\nnamespace oneflow {\n\nnamespace summary {\n\nstatic void ConvertProtobufMsg2Json(nlohmann::json& json_value, const PbMessage& pb_msg);\n\nstatic void ConvertRepeatedField2Json(nlohmann::json& json_value, const PbMessage& pb_msg,\n                                      const PbFd* pb_field,\n                                      const google::protobuf::Reflection* pb_reflection) {\n  if (NULL == pb_field || NULL == pb_reflection) { ConvertProtobufMsg2Json(json_value, pb_msg); }\n\n  for (int i = 0; i < pb_reflection->FieldSize(pb_msg, pb_field); ++i) {\n    nlohmann::json tmp_json_value;\n    switch (pb_field->type()) {\n      case PbFd::TYPE_MESSAGE: {\n        const PbMessage& msg = pb_reflection->GetRepeatedMessage(pb_msg, pb_field, i);\n        if (0 != msg.ByteSize()) { ConvertProtobufMsg2Json(tmp_json_value, msg); }\n      } break;\n      case PbFd::TYPE_INT32:\n        tmp_json_value = pb_reflection->GetRepeatedInt32(pb_msg, pb_field, i);\n        break;\n      case PbFd::TYPE_UINT32:\n        tmp_json_value = pb_reflection->GetRepeatedUInt32(pb_msg, pb_field, i);\n        break;\n      case PbFd::TYPE_INT64: {\n        static char int64_str[25];\n        memset(int64_str, 0, sizeof(int64_str));\n        snprintf(int64_str, sizeof(int64_str), \"%lld\",\n                 (long long)pb_reflection->GetRepeatedInt64(pb_msg, pb_field, i));\n        tmp_json_value = int64_str;\n      } break;\n      case PbFd::TYPE_UINT64: {\n        static char uint64str[25];\n        memset(uint64str, 0, sizeof(uint64str));\n        snprintf(uint64str, sizeof(uint64str), \"%llu\",\n                 (unsigned long long)pb_reflection->GetRepeatedUInt64(pb_msg, pb_field, i));\n        tmp_json_value = uint64str;\n      } break;\n      case PbFd::TYPE_STRING:\n      case PbFd::TYPE_BYTES:\n        tmp_json_value = pb_reflection->GetRepeatedString(pb_msg, pb_field, i);\n        break;\n      case PbFd::TYPE_BOOL:\n        tmp_json_value = pb_reflection->GetRepeatedBool(pb_msg, pb_field, i);\n        break;\n      case PbFd::TYPE_ENUM:\n        tmp_json_value = pb_reflection->GetRepeatedEnum(pb_msg, pb_field, i)->name();\n        break;\n      case PbFd::TYPE_FLOAT:\n        tmp_json_value = pb_reflection->GetRepeatedFloat(pb_msg, pb_field, i);\n        break;\n      case PbFd::TYPE_DOUBLE:\n        tmp_json_value = pb_reflection->GetRepeatedDouble(pb_msg, pb_field, i);\n        break;\n      default: break;\n    }\n    json_value.emplace_back(tmp_json_value);\n  }\n}\n\nstatic void ConvertProtobufMsg2Json(nlohmann::json& json_value, const PbMessage& pb_msg) {\n  const google::protobuf::Descriptor* pb_descriptor = pb_msg.GetDescriptor();\n  const 
google::protobuf::Reflection* pb_reflection = pb_msg.GetReflection();\n\n  const int count = pb_descriptor->field_count();\n\n  for (int i = 0; i < count; ++i) {\n    const PbFd* pb_field = pb_descriptor->field(i);\n\n    if (pb_field->is_repeated()) {\n      if (pb_reflection->FieldSize(pb_msg, pb_field) > 0) {\n        ConvertRepeatedField2Json(json_value[pb_field->name()], pb_msg, pb_field, pb_reflection);\n      }\n      continue;\n    }\n\n    if (!pb_reflection->HasField(pb_msg, pb_field)) { continue; }\n\n    switch (pb_field->type()) {\n      case PbFd::TYPE_MESSAGE: {\n        const PbMessage& msg = pb_reflection->GetMessage(pb_msg, pb_field);\n        if (0 != msg.ByteSize()) { ConvertProtobufMsg2Json(json_value[pb_field->name()], msg); }\n      } break;\n      case PbFd::TYPE_INT32:\n        json_value[pb_field->name()] = pb_reflection->GetInt32(pb_msg, pb_field);\n        break;\n      case PbFd::TYPE_UINT32:\n        json_value[pb_field->name()] = pb_reflection->GetUInt32(pb_msg, pb_field);\n        break;\n      case PbFd::TYPE_INT64: {\n        static char int64_str[25];\n        memset(int64_str, 0, sizeof(int64_str));\n        snprintf(int64_str, sizeof(int64_str), \"%lld\",\n                 (long long)pb_reflection->GetInt64(pb_msg, pb_field));\n        json_value[pb_field->name()] = int64_str;\n      } break;\n      case PbFd::TYPE_UINT64: {\n        static char uint64_str[25];\n        memset(uint64_str, 0, sizeof(uint64_str));\n        snprintf(uint64_str, sizeof(uint64_str), \"%llu\",\n                 (unsigned long long)pb_reflection->GetUInt64(pb_msg, pb_field));\n        json_value[pb_field->name()] = uint64_str;\n      } break;\n      case PbFd::TYPE_STRING:\n      case PbFd::TYPE_BYTES: {\n        json_value[pb_field->name()] = pb_reflection->GetString(pb_msg, pb_field);\n      } break;\n      case PbFd::TYPE_BOOL: {\n        json_value[pb_field->name()] = pb_reflection->GetBool(pb_msg, pb_field);\n      } break;\n      case PbFd::TYPE_ENUM: {\n        json_value[pb_field->name()] = pb_reflection->GetEnum(pb_msg, pb_field)->name();\n      } break;\n      case PbFd::TYPE_FLOAT: {\n        json_value[pb_field->name()] = pb_reflection->GetFloat(pb_msg, pb_field);\n      } break;\n      case PbFd::TYPE_DOUBLE: {\n        json_value[pb_field->name()] = pb_reflection->GetDouble(pb_msg, pb_field);\n      } break;\n      default: break;\n    }\n  }\n}\n\n}  // namespace summary\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_SUMMARY_SUMMARY_CONVERTER_H_\n"
  },
  {
    "path": "oneflow/user/utils/pool_util.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"oneflow/user/utils/pool_util.h\"\n#include \"oneflow/core/operator/operator_util.h\"\n\nnamespace oneflow {\n\nParams3D::Params3D(const int32_t dim, const ShapeView& x_shape, const std::string& data_format,\n                   const std::string& padding, const std::vector<int32_t>& padding_before,\n                   const std::vector<int32_t>& padding_after, const std::vector<int32_t>& pool_size,\n                   const std::vector<int32_t>& strides, const bool ceil_mode)\n    : dim_(dim),\n      pool_size_3d_(Get3DVec(pool_size, dim)),\n      strides_3d_(Get3DVec(strides, dim)),\n      padding_before_3d_(Get3DVec<Get3DVecType::kPad>(padding_before, dim)),\n      padding_after_3d_(Get3DVec<Get3DVecType::kPad>(padding_after, dim)),\n      data_format_(data_format),\n      padding_(padding),\n      ceil_mode_(ceil_mode) {\n  x_3d_ = {GetInDim(x_shape, data_format, 0, dim), GetInDim(x_shape, data_format, 1, dim),\n           GetInDim(x_shape, data_format, 2, dim)};\n  Get3DOutputSize(x_3d_, pool_size_3d_, strides_3d_, padding_, ceil_mode_, nullptr, &y_3d_,\n                  &padding_before_3d_, &padding_after_3d_);\n  if (data_format == \"channels_first\") {\n    channel_num_ = x_shape.At(1);\n  } else {\n    CHECK_EQ(data_format_, \"channels_last\")\n        << \"data_format must be 'channels_first' or 'channels_last'\";\n    channel_num_ = x_shape.At(x_shape.NumAxes() - 1);\n  }\n  batch_num_ = x_shape.At(0);\n}\n\nvoid Params3D::Reset(const ShapeView& x_shape) {\n  x_3d_ = {GetInDim(x_shape, data_format_, 0, dim_), GetInDim(x_shape, data_format_, 1, dim_),\n           GetInDim(x_shape, data_format_, 2, dim_)};\n  Get3DOutputSize(x_3d_, pool_size_3d_, strides_3d_, padding_, ceil_mode_, nullptr, &y_3d_,\n                  &padding_before_3d_, &padding_after_3d_);\n}\n\nShape Params3D::GetYShape() const {\n  DimVector y_dim_vec;\n  if (dim_ == 1) {\n    y_dim_vec = {y_3d_.at(2)};\n  } else if (dim_ == 2) {\n    y_dim_vec = {y_3d_.at(1), y_3d_.at(2)};\n  } else if (dim_ == 3) {\n    y_dim_vec = {y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)};\n  } else {\n    UNIMPLEMENTED();\n  }\n  if (data_format_ == \"channels_first\") {\n    y_dim_vec.insert(y_dim_vec.begin(), channel_num_);\n  } else {\n    CHECK_EQ(data_format_, \"channels_last\")\n        << \"data_format must be 'channels_first' or 'channels_last'\";\n    y_dim_vec.insert(y_dim_vec.end(), channel_num_);\n  }\n  y_dim_vec.insert(y_dim_vec.begin(), batch_num_);\n  return Shape(y_dim_vec);\n}\n\nShape Params3D::GetXShape5D() const {\n  return Shape({batch_num_, channel_num_, x_3d_.at(0), x_3d_.at(1), x_3d_.at(2)});\n}\n\nShape Params3D::GetYShape5D() const {\n  return Shape({batch_num_, channel_num_, y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)});\n}\n\n}  // namespace oneflow\n"
  },
  {
    "path": "oneflow/user/utils/pool_util.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#ifndef ONEFLOW_USER_UTILS_POOL_UTIL_H_\n#define ONEFLOW_USER_UTILS_POOL_UTIL_H_\n#include \"oneflow/core/device/cudnn_util.h\"\n#include \"oneflow/core/framework/framework.h\"\n#include \"oneflow/user/kernels/op_kernel_wrapper.h\"\n\nnamespace oneflow {\n\ntypedef small_vector<int64_t, SHAPE_MAX_AXIS_SIZE> FixedDimVector;\ntypedef small_vector<int32_t, SHAPE_MAX_AXIS_SIZE> FixedVector;\n\nclass Params3D {\n public:\n  Params3D(const int32_t dim, const ShapeView& x_shape, const std::string& data_format,\n           const std::string& padding, const std::vector<int32_t>& padding_before,\n           const std::vector<int32_t>& padding_after, const std::vector<int32_t>& pool_size,\n           const std::vector<int32_t>& strides, const bool ceil_mode);\n  ~Params3D() = default;\n  void Reset(const ShapeView& x_shape);\n\n  Shape GetYShape() const;\n  Shape GetXShape5D() const;\n  Shape GetYShape5D() const;\n\n  const std::vector<int32_t>& pool_size_3d() const { return pool_size_3d_; }\n  const std::vector<int32_t>& strides_3d() const { return strides_3d_; }\n  const std::vector<int32_t>& padding_before_3d() const { return padding_before_3d_; }\n  const std::vector<int32_t>& padding_after_3d() const { return padding_after_3d_; }\n\n private:\n  int32_t dim_;\n  FixedDimVector x_3d_;\n  FixedDimVector y_3d_;\n  std::vector<int32_t> pool_size_3d_;\n  std::vector<int32_t> strides_3d_;\n  std::vector<int32_t> padding_before_3d_;\n  std::vector<int32_t> padding_after_3d_;\n  std::string data_format_;\n  std::string padding_;\n  bool ceil_mode_;\n  int64_t batch_num_;\n  int64_t channel_num_;\n};\n\nenum class Get3DVecType { kPad, kNonPad };\n\ntemplate<Get3DVecType get_3d_vec_type = Get3DVecType::kNonPad>\nstd::vector<int32_t> Get3DVec(const std::vector<int32_t>& original_vec, int32_t NDims) {\n  std::vector<int32_t> vec;\n  FOR_RANGE(uint8_t, dim, 0, 3) {\n    int64_t index = static_cast<int64_t>(dim) - (3 - NDims);\n    if (index < 0) {\n      vec.emplace_back(static_cast<int32_t>(get_3d_vec_type));  // kPad -> 0, kNonPad -> 1\n    } else {\n      vec.emplace_back(original_vec.at(index));\n    }\n  }\n  return vec;\n}\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_USER_UTILS_POOL_UTIL_H_\n"
  },
  {
    "path": "python/.gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n/oneflow/include\n/oneflow/core\n/oneflow/compatible/single_client/core\n/oneflow/version.py\nlib.py\n*.ast.py\nunittest-log-*\nlog\noutput\n"
  },
  {
    "path": "python/oneflow/_C/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow._oneflow_internal._C import *\nimport oneflow._C._nn as _nn\nimport warnings\n\n\ndef allclose(input, other, atol=1e-08, rtol=1e-05, equal_nan=False):\n    return isclose(input, other, atol, rtol, equal_nan).all().item()\n\n\ndef _log_api_usage_once(event):\n    warnings.warn(\"_log_api_usage_once is not implemented in oneflow\")\n"
  },
  {
    "path": "python/oneflow/_C/_nn.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nimport builtins\nfrom oneflow.framework.tensor import Tensor\nfrom typing import overload, Tuple, Any\n\n_device = flow.device\n_bool = builtins.bool\n_dtype = flow.dtype\n\n\n@overload\ndef _parse_to(\n    device: _device,\n    dtype: _dtype,\n    non_blocking: _bool,\n    copy: _bool,\n    *,\n    memory_format: Any,\n) -> Tuple[_device, _dtype, _bool, Any]:\n    ...\n\n\n@overload\ndef _parse_to(\n    dtype: _dtype, non_blocking: _bool, copy: _bool, *, memory_format: Any\n) -> Tuple[_device, _dtype, _bool, Any]:\n    ...\n\n\n@overload\ndef _parse_to(\n    tensor: Tensor, non_blocking: _bool, copy: _bool, *, memory_format: Any\n) -> Tuple[_device, _dtype, _bool, Any]:\n    ...\n\n\ndef _parse_to(*args, **kwargs):\n    # TODO: implement _parse_to natively\n    result = flow.tensor([]).to(*args, **kwargs)\n\n    return (result.device, result.dtype, False, None)\n"
  },
  {
    "path": "python/oneflow/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport sys\nimport collections\nimport warnings\n\n# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-environment-variables\nif \"CUDA_MODULE_LOADING\" not in os.environ:\n    os.environ[\"CUDA_MODULE_LOADING\"] = \"LAZY\"\n\nimport oneflow._oneflow_internal\n\noneflow._oneflow_internal.RegisterSignalHandler()\n\noneflow_python_base_dir = os.path.dirname(os.path.realpath(__file__))\noneflow._oneflow_internal.InitPythonPathsToBeKeptAndFilteredForDebugging(\n    oneflow_python_base_dir\n)\noneflow._oneflow_internal.InitNumpyCAPI()\noneflow._oneflow_internal.CheckAndClearRegistryFlag()\nSize = oneflow._oneflow_internal.Size\ndevice = oneflow._oneflow_internal.device\nplacement = oneflow._oneflow_internal.placement\n\nlocals()[\"dtype\"] = oneflow._oneflow_internal.dtype\nlocals()[\"bool\"] = oneflow._oneflow_internal.bool\nlocals()[\"float16\"] = oneflow._oneflow_internal.float16\nlocals()[\"half\"] = oneflow._oneflow_internal.float16\nlocals()[\"float32\"] = oneflow._oneflow_internal.float32\nlocals()[\"float\"] = oneflow._oneflow_internal.float\nlocals()[\"double\"] = oneflow._oneflow_internal.double\nlocals()[\"float64\"] = oneflow._oneflow_internal.float64\nlocals()[\"int8\"] = oneflow._oneflow_internal.int8\nlocals()[\"int\"] = oneflow._oneflow_internal.int32\nlocals()[\"int32\"] = oneflow._oneflow_internal.int32\nlocals()[\"int64\"] = oneflow._oneflow_internal.int64\nlocals()[\"long\"] = oneflow._oneflow_internal.int64\nlocals()[\"uint8\"] = oneflow._oneflow_internal.uint8\nlocals()[\"record\"] = oneflow._oneflow_internal.record\nlocals()[\"tensor_buffer\"] = oneflow._oneflow_internal.tensor_buffer\nlocals()[\"bfloat16\"] = oneflow._oneflow_internal.bfloat16\nlocals()[\"char\"] = oneflow._oneflow_internal.char\nlocals()[\"short\"] = oneflow._oneflow_internal.int16\nlocals()[\"int16\"] = oneflow._oneflow_internal.int16\n\nlocals()[\"cfloat\"] = oneflow._oneflow_internal.cfloat\nlocals()[\"complex64\"] = oneflow._oneflow_internal.complex64\nlocals()[\"cdouble\"] = oneflow._oneflow_internal.cdouble\nlocals()[\"complex128\"] = oneflow._oneflow_internal.complex128\n\nlocals()[\"layout\"] = oneflow._oneflow_internal.layout\nlocals()[\"strided\"] = oneflow._oneflow_internal.strided\n\nlocals()[\"memory_format\"] = oneflow._oneflow_internal.memory_format\nlocals()[\"contiguous_format\"] = oneflow._oneflow_internal.contiguous_format\nlocals()[\"channels_last\"] = oneflow._oneflow_internal.channels_last\nlocals()[\"preserve_format\"] = oneflow._oneflow_internal.preserve_format\nfrom oneflow.version import __version__\nfrom oneflow.version import __git_commit__\n\n_DEPRECATED = set()\n\n\ndef oneflow_deprecate(*api_names, **kwargs):\n    def Decorator(func_or_class):\n        _DEPRECATED.add(func_or_class)\n        return func_or_class\n\n    return Decorator\n\n\ndef is_deprecated(func_or_class):\n    return (\n        isinstance(func_or_class, 
collections.abc.Hashable)\n        and func_or_class in _DEPRECATED\n    )\n\n\ndef use_deterministic_algorithms(mode, *, warn_only=False):\n    # registers an empty method\n    warnings.warn(\"Oneflow temporarily does not support use_deterministic_algorithms.\")\n\n\nfrom oneflow._C import abs\nfrom oneflow._C import exp\nfrom oneflow._C import exp2\nfrom oneflow._C import acos\nfrom oneflow._C import acos as arccos\nfrom oneflow._C import acosh\nfrom oneflow._C import acosh as arccosh\nfrom oneflow._C import amin\nfrom oneflow._C import atanh\nfrom oneflow._C import atanh as arctanh\nfrom oneflow._C import batch_matmul as bmm\nfrom oneflow._C import baddbmm\nfrom oneflow._C import broadcast_like\nfrom oneflow._C import chunk\nfrom oneflow._C import digamma\nfrom oneflow._C import split\nfrom oneflow._C import sign\nfrom oneflow._C import sinh\nfrom oneflow._C import tan\nfrom oneflow._C import greater\nfrom oneflow._C import greater as gt\nfrom oneflow._C import greater_ as gt_\nfrom oneflow._C import greater_equal\nfrom oneflow._C import greater_equal as ge\nfrom oneflow._C import log\nfrom oneflow._C import log2\nfrom oneflow._C import log10\nfrom oneflow._C import logical_and\nfrom oneflow._C import logical_or\nfrom oneflow._C import logical_xor\nfrom oneflow._C import logical_not\nfrom oneflow._C import logaddexp\nfrom oneflow._C import quantile\nfrom oneflow._C import gelu_with_approximate as gelu\nfrom oneflow._C import quick_gelu\nfrom oneflow._C import square_relu\nfrom oneflow._C import mish\nfrom oneflow._C import repeat\nfrom oneflow._C import repeat_interleave\nfrom oneflow._C import tile\nfrom oneflow._C import sigmoid\nfrom oneflow._C import tanh\nfrom oneflow._C import as_strided\nfrom oneflow._C import as_strided_\nfrom oneflow._C import silu\nfrom oneflow._C import selu\nfrom oneflow._C import softshrink\nfrom oneflow._C import softsign\nfrom oneflow._C import cast\nfrom oneflow._C import diag\nfrom oneflow._C import log1p\nfrom oneflow._C import add\nfrom oneflow._C import addcdiv\nfrom oneflow._C import div, div_\nfrom oneflow._C import addcmul\nfrom oneflow._C import floor, floor_\nfrom oneflow._C import floor_divide\nfrom oneflow._C import frac, frac_\nfrom oneflow._C import mul\nfrom oneflow._C import negative\nfrom oneflow._C import negative as neg\nfrom oneflow._C import reciprocal\nfrom oneflow._C import sub\nfrom oneflow._C import sin, sin_\nfrom oneflow._C import asin\nfrom oneflow._C import asin as arcsin\nfrom oneflow._C import asinh\nfrom oneflow._C import asinh as arcsinh\nfrom oneflow._C import atan\nfrom oneflow._C import atan as arctan\nfrom oneflow._C import atan2\nfrom oneflow._C import ceil, ceil_\nfrom oneflow._C import clamp, clamp_, clamp_min, clamp_min_, clamp_max, clamp_max_\nfrom oneflow._C import clip, clip_\nfrom oneflow._C import cos\nfrom oneflow._C import cosh\nfrom oneflow._C import diagonal\nfrom oneflow._C import erf\nfrom oneflow._C import erfc\nfrom oneflow._C import expm1\nfrom oneflow._C import fmod\nfrom oneflow._C import flatten\nfrom oneflow._C import topk\nfrom oneflow._C import in_top_k\nfrom oneflow._C import lgamma\nfrom oneflow._C import minimum\nfrom oneflow._C import maximum\nfrom oneflow._C import max\nfrom oneflow._C import min\nfrom oneflow._C import median\nfrom oneflow._C import mode\nfrom oneflow._C import pow\nfrom oneflow._C import reduce_prod as prod\nfrom oneflow._C import reduce_sum as sum\nfrom oneflow._C import reduce_mean as mean\nfrom oneflow._C import reduce_all as all\nfrom oneflow._C import reduce_any as 
any\nfrom oneflow._C import reduce_nansum as nansum\nfrom oneflow._C import logsumexp\nfrom oneflow._C import rsqrt\nfrom oneflow._C import sqrt\nfrom oneflow._C import square\nfrom oneflow._C import matmul\nfrom oneflow._C import mm\nfrom oneflow._C import matrix_vector_product as mv\nfrom oneflow._C import bernoulli\nfrom oneflow._C import round, round_\nfrom oneflow._C import softplus\nfrom oneflow._C import threshold\nfrom oneflow._C import tril\nfrom oneflow._C import triu\nfrom oneflow._C import trunc\nfrom oneflow._C import pad\nfrom oneflow._C import transpose\nfrom oneflow._C import relu\nfrom oneflow._C import roc_auc_score\nfrom oneflow._C import softmax\nfrom oneflow._C import log_softmax\nfrom oneflow._C import argmax\nfrom oneflow._C import argmin\nfrom oneflow._C import std\n\nfrom oneflow._C import stft\nfrom oneflow._C import var\nfrom oneflow._C import stack, hstack, vstack, dstack, column_stack, row_stack\nfrom oneflow._C import atleast_1d, atleast_2d, atleast_3d\nfrom oneflow._C import squeeze\nfrom oneflow._C import narrow\nfrom oneflow._C import unsqueeze\nfrom oneflow._C import permute\nfrom oneflow._C import select\nfrom oneflow._C import unbind\nfrom oneflow._C import tensor_split\nfrom oneflow._C import hann_window\nfrom oneflow._C import hsplit\nfrom oneflow._C import vsplit\nfrom oneflow._C import concat\nfrom oneflow._C import concat as cat\nfrom oneflow._C import dim_gather as gather\nfrom oneflow._C import deform_conv2d\nfrom oneflow._C import gather_nd\nfrom oneflow._C import roi_align\nfrom oneflow._C import dot\nfrom oneflow._C import eye\nfrom oneflow._C import erfinv, erfinv_\nfrom oneflow._C import cumsum\nfrom oneflow._C import contiguous\nfrom oneflow._C import cumprod\nfrom oneflow._C import swapaxes\nfrom oneflow._C import amax\nfrom oneflow._C import swapdims\nfrom oneflow._C import t\nfrom oneflow._C import masked_fill\nfrom oneflow._C import masked_fill_\nfrom oneflow._C import equal\nfrom oneflow._C import broadcast_equal as eq\nfrom oneflow._C import not_equal\nfrom oneflow._C import not_equal as ne\nfrom oneflow._C import less as lt\nfrom oneflow._C import less_equal as le\nfrom oneflow._C import searchsorted\nfrom oneflow._C import flip\nfrom oneflow._C import index_select\nfrom oneflow._C import isnan\nfrom oneflow._C import isinf\nfrom oneflow._C import isfinite\nfrom oneflow._C import inv as inverse\nfrom oneflow._C import det\nfrom oneflow._C import iinfo, finfo\nfrom oneflow._C import multinomial\nfrom oneflow._C import linalg_cross as cross\nfrom oneflow._C import bincount\nfrom oneflow._C import isclose\nfrom oneflow._C import allclose\nfrom oneflow._C import lerp, lerp_\nfrom oneflow._C import index_add, index_add_\nfrom oneflow._C import sort\nfrom oneflow._C import clone\nfrom oneflow._C import bitwise_and, bitwise_or, bitwise_xor, bitwise_not\nfrom oneflow._C import real, imag, conj, conj_physical\n\nfrom oneflow._oneflow_internal import _set_num_threads as set_num_threads\n\nfrom . 
import sbp\n\nsbp.sbp.__call__ = lambda self: self\n\nimport atexit\n\nimport oneflow.framework.c_api_util\nimport oneflow.framework.register_class_method_util as register_class_method_util\n\n\nregister_class_method_util.RegisterMethod4Class()\nimport oneflow.framework.env_util as env_util\nimport oneflow.framework.scope_util as scope_util\nimport oneflow.framework.session_context as session_ctx\nfrom oneflow.framework.tensor_str import set_printoptions\n\n_oneflow_global_unique_env = env_util.GetEnv()\nsession_ctx.NewDefaultSession(_oneflow_global_unique_env)\n\noneflow._oneflow_internal.RegisterGILForeignLockHelper()\noneflow._oneflow_internal.autograd.graph.register_saved_tensors_hook_manager()\noneflow._oneflow_internal.RegisterStackGetter()\n\n\nclass ExitHook:\n    def __init__(self):\n        self.exit_code = None\n        self.exception = None\n\n        self._orig_exit = sys.exit\n        self._orig_excepthook = sys.excepthook\n\n        def exit(code=0):\n            self.exit_code = code\n            self._orig_exit(code)\n\n        sys.exit = exit\n\n        def exc_handler(exc_type, exc, *args):\n            self.exception = exc\n            self._orig_excepthook(exc_type, exc, *args)\n\n        sys.excepthook = exc_handler\n\n    def is_normal_exit(self):\n        if self.exit_code is not None:\n            return self.exit_code == 0\n        return self.exception is None\n\n\nhook = ExitHook()\n\n\ndef atexit_hook(hook):\n    _oneflow_global_unique_env.switch_to_shutting_down(hook.is_normal_exit())\n    oneflow.framework.session_context.TryCloseDefaultSession()\n\n\natexit.register(atexit_hook, hook)\ndel atexit_hook\ndel hook\ndel ExitHook\ndel atexit\ndel oneflow\n\n# default dtype\nfrom oneflow.framework.dtype import (\n    set_default_dtype,\n    set_default_tensor_type,\n    get_default_dtype,\n    is_floating_point,\n)\n\nimport oneflow._C\nfrom oneflow._C import tensor, batch_gather\nfrom oneflow._C import from_numpy, from_dlpack\n\nfrom oneflow.autograd import (\n    enable_grad,\n    set_grad_enabled,\n    no_grad,\n    inference_mode,\n    is_grad_enabled,\n)\nimport oneflow.nn.image\n\nfrom oneflow.framework.check_point_v2 import load\nfrom oneflow.framework.check_point_v2 import save\nfrom oneflow.framework.check_point_v2 import frombuffer\nfrom oneflow.framework.dtype import convert_oneflow_dtype_to_numpy_dtype, dtypes\nfrom oneflow.framework.function_util import FunctionConfig\nfrom oneflow.framework.function_util import FunctionConfig as function_config\nfrom oneflow.framework.generator import create_generator as Generator\nfrom oneflow.framework.generator import (\n    default_generator,\n    seed,\n    manual_seed,\n    initial_seed,\n    get_rng_state,\n    set_rng_state,\n)\n\n# NOTE(chengcheng) oneflow.Model is unavailable now.\n# from oneflow.framework.model import Model\nimport oneflow.utils.tensor\nimport oneflow.utils.global_view\nimport oneflow.utils.model_zoo\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.framework.tensor import is_nonzero\nfrom oneflow._oneflow_internal import to_dlpack\nfrom oneflow.framework.type_tensor import *\n\nfrom oneflow.framework.tensor import zero_\n\nfrom oneflow.nn.modules.pooling import (\n    adaptive_avg_pool1d,\n    adaptive_avg_pool2d,\n    adaptive_avg_pool3d,\n)\nfrom oneflow.nn.modules.einsum import einsum_op as einsum\nfrom oneflow.nn.modules.is_tensor import is_tensor_op as is_tensor\nfrom oneflow.nn.modules.arange import arange_op as arange\nfrom oneflow.nn.modules.linspace import linspace_op as 
linspace\nfrom oneflow.nn.modules.logspace import logspace_op as logspace\nfrom oneflow.nn.modules.argsort import argsort_op as argsort\nfrom oneflow.nn.modules.argwhere import argwhere_op as argwhere\nfrom oneflow.nn.modules.constant import ones_op as ones\nfrom oneflow.nn.modules.constant import zeros_op as zeros\nfrom oneflow.nn.modules.constant import zeros_like_op as zeros_like\nfrom oneflow.nn.modules.constant import ones_like_op as ones_like\nfrom oneflow.nn.modules.constant import full_op as full\nfrom oneflow.nn.modules.constant import full_like_op as full_like\nfrom oneflow.nn.modules.constant import new_ones_op as new_ones\nfrom oneflow.nn.modules.constant import new_zeros_op as new_zeros\nfrom oneflow.nn.modules.constant import new_full_op as new_full\nfrom oneflow.nn.modules.empty import empty_op as empty\nfrom oneflow.nn.modules.empty import new_empty_op as new_empty\nfrom oneflow.nn.modules.empty import empty_like_op as empty_like\nfrom oneflow._C import empty_strided\nfrom oneflow.nn.modules.dataset import tensor_buffer_to_list_of_tensors\nfrom oneflow._C import movedim\nfrom oneflow.nn.modules.expand import expand_op as expand\nfrom oneflow.nn.modules.distributed_partial_fc_sample import (\n    distributed_partial_fc_sample_op as distributed_partial_fc_sample,\n)\nfrom oneflow.nn.modules.roll import roll_op as roll\nfrom oneflow.nn.modules.masked_select import masked_select_op as masked_select\nfrom oneflow.nn.modules.math_ops import addmm_op as addmm\nfrom oneflow.nn.modules.nonzero import nonzero_op as nonzero\nfrom oneflow.nn.modules.nms import nms_op as nms\nfrom oneflow.nn.modules.numel import numel_op as numel\nfrom oneflow.nn.modules.meshgrid import meshgrid_op as meshgrid\nfrom oneflow.nn.modules.unique import unique_op as unique\nfrom oneflow._C import normal\nfrom oneflow._C import normal_\nfrom oneflow._C import rand\nfrom oneflow._C import randn\nfrom oneflow._C import randn_like\nfrom oneflow._C import randint\nfrom oneflow._C import randint_like\nfrom oneflow._C import randperm\nfrom oneflow.nn.modules.reshape import reshape_op as reshape\nfrom oneflow.nn.modules.reshape import view_op as view\nfrom oneflow.nn.modules.slice import slice_op as slice\nfrom oneflow.nn.modules.slice import slice_update_op as slice_update\nfrom oneflow.nn.modules.tensor_buffer import gen_tensor_buffer\nfrom oneflow.nn.modules.tensor_buffer import (\n    tensor_buffer_to_tensor_op as tensor_buffer_to_tensor,\n)\nfrom oneflow.nn.modules.tensordot import tensordot\nfrom oneflow.nn.modules.norm import norm\nfrom oneflow.nn.modules.as_tensor import as_tensor\nfrom oneflow.nn.modules.tensor_buffer import tensor_to_tensor_buffer\nfrom oneflow.nn.modules.global_cast import local_to_global_op as local_to_global\nfrom oneflow.nn.modules.global_cast import global_to_global_op as global_to_global\nfrom oneflow.nn.modules.global_cast import to_global_op as to_global\nfrom oneflow.nn.modules.global_cast import to_local_op as to_local\nfrom oneflow.nn.modules.where import where_op as where\nfrom oneflow.nn.modules.scatter import *\nfrom oneflow.nn.modules.broadcast_ops import (\n    broadcast_tensors,\n    broadcast_shapes,\n    broadcast_to,\n)\nfrom oneflow.ops.stateful_ops import StatefulOp as stateful_op\n\n# autocast\nfrom oneflow._oneflow_internal import (\n    is_autocast_enabled,\n    set_autocast_enabled,\n    get_autocast_gpu_dtype,\n    get_autocast_cpu_dtype,\n    set_autocast_gpu_dtype,\n    set_autocast_cpu_dtype,\n    is_autocast_cache_enabled,\n    set_autocast_cache_enabled,\n   
 clear_autocast_cache,\n)\nfrom oneflow.amp.autocast_mode import *\nfrom oneflow.jit import *\n\nfrom . import (\n    autograd,\n    distributed,\n    distributions,\n    linalg,\n    optim,\n    comm,\n    boxing,\n    backends,\n    amp,\n    hub,\n    fx,\n    fft,\n    special,\n)\nimport oneflow.utils.data\nimport oneflow.framework.docstr as docstr\nimport oneflow.cuda\nimport oneflow.multiprocessing\nimport oneflow.asyncs\nimport oneflow.one_embedding\nimport oneflow.profiler\nimport oneflow.mock_torch\nimport oneflow.remat\n\nif oneflow._oneflow_internal.flags.with_mlir():\n    oneflow_internal_path = oneflow._oneflow_internal.__file__\n    oneflow._oneflow_internal.ir.load_jit_shared_lib(oneflow_internal_path)\n"
  },
  {
    "path": "python/oneflow/__main__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport argparse\nimport os\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--doctor\", default=False, action=\"store_true\", required=False)\nargs = parser.parse_args()\n\n\ndef main():\n    if args.doctor:\n        import oneflow\n        import oneflow.sysconfig\n\n        print(\"path:\", oneflow.__path__)\n        print(\"version:\", oneflow.__version__)\n        print(\"git_commit:\", oneflow.__git_commit__)\n        print(\"cmake_build_type:\", oneflow.sysconfig.cmake_build_type())\n        print(\"rdma:\", oneflow.sysconfig.with_rdma())\n        print(\"mlir:\", oneflow.sysconfig.with_mlir())\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/oneflow/_dynamo/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport warnings\n\n# Reference: https://github.com/pytorch/pytorch/blob/v2.0.1/torch/_dynamo/__init__.py\n__all__ = [\n    \"allow_in_graph\",\n]\n\n\ndef allow_in_graph(fn):\n    \"\"\"\n    \"\"\"\n    if isinstance(fn, (list, tuple)):\n        return [allow_in_graph(x) for x in fn]\n    assert callable(fn), \"allow_in_graph expects a callable\"\n    warnings.warn(\n        \"The oneflow._dynamo.allow_in_graph interface is just to align the torch._dynamo.allow_in_graph interface and has no practical significance.\"\n    )\n    return fn\n"
  },
  {
    "path": "python/oneflow/_utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport sys\nimport traceback\nimport oneflow as flow\n\n\nclass KeyErrorMessage(str):\n    r\"\"\"str subclass that returns itself in repr\"\"\"\n\n    def __repr__(self):\n        return self\n\n\nclass ExceptionWrapper(object):\n    r\"\"\"Wraps an exception plus traceback to communicate across threads\"\"\"\n\n    def __init__(self, exc_info=None, where=\"in background\"):\n        # It is important that we don't store exc_info, see\n        # NOTE [ Python Traceback Reference Cycle Problem ]\n        if exc_info is None:\n            exc_info = sys.exc_info()\n        self.exc_type = exc_info[0]\n        self.exc_msg = \"\".join(traceback.format_exception(*exc_info))\n        self.where = where\n\n    def reraise(self):\n        r\"\"\"Reraises the wrapped exception in the current thread\"\"\"\n        # Format a message such as: \"Caught ValueError in DataLoader worker\n        # process 2. Original Traceback:\", followed by the traceback.\n        msg = \"Caught {} {}.\\nOriginal {}\".format(\n            self.exc_type.__name__, self.where, self.exc_msg\n        )\n        if self.exc_type == KeyError:\n            # KeyError calls repr() on its argument (usually a dict key). This\n            # makes stack traces unreadable. It will not be changed in Python\n            # (https://bugs.python.org/issue2651), so we work around it.\n            msg = KeyErrorMessage(msg)\n        elif getattr(self.exc_type, \"message\", None):\n            # Some exceptions have first argument as non-str but explicitly\n            # have message field\n            raise self.exc_type(message=msg)\n        raise self.exc_type(msg)\n\n\ndef _flatten_dense_tensors(tensors):\n    \"\"\"Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of\n    same dense type.\n\n    The api is referenced from https://github.com/pytorch/pytorch/blob/master/torch/_utils.py#L437\n\n    Since inputs are dense, the resulting tensor will be a concatenated 1D\n    buffer. Element-wise operation on this buffer will be equivalent to\n    operating individually.\n\n    Args:\n        tensors (Iterable[Tensor]): dense tensors to flatten.\n\n    Returns:\n        A contiguous 1D buffer containing input tensors.\n    \"\"\"\n    if len(tensors) == 1:\n        return flow._C.flatten(tensors[0])\n    else:\n        flatten_tensors = []\n        for tensor in tensors:\n            flatten_tensors.append(flow.flatten(tensor))\n        return flow.cat(flatten_tensors, 0)\n\n\ndef _unflatten_dense_tensors(flat, tensors):\n    \"\"\"View a flat buffer using the sizes of tensors. 
Assume that tensors are of\n    same dense type, and that flat is given by _flatten_dense_tensors.\n\n    The api is referenced from https://github.com/pytorch/pytorch/blob/master/torch/_utils.py#L474\n\n    Args:\n        flat (Tensor): flattened dense tensors to unflatten.\n        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to\n          unflatten flat.\n\n    Returns:\n        Unflattened dense tensors with sizes same as tensors and values from\n        flat.\n    \"\"\"\n    outputs = []\n    offset = 0\n    for tensor in tensors:\n        numel = tensor.numel()\n        if numel == 0:\n            outputs.append(flow.zeros_like(tensor))\n        else:\n            outputs.append(flow.narrow(flat, 0, offset, numel).view(tensor.size()))\n            offset += numel\n    return outputs\n
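\n\n# Round-trip sketch for the two helpers above (hypothetical tensors a and b;\n# assumes both share one dense dtype, as the helpers require):\n#   flat = _flatten_dense_tensors([a, b])\n#   a2, b2 = _unflatten_dense_tensors(flat, [a, b])\n#   # a2 and b2 have the shapes of a and b, with values copied from flat\n"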
  },
  {
    "path": "python/oneflow/amp/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom .grad_scaler import GradScaler\nfrom .grad_scaler import StaticGradScaler\nfrom .autocast_mode import *\n"
  },
  {
    "path": "python/oneflow/amp/autocast_mode.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport functools\nimport warnings\nfrom typing import Any, Optional\n\nimport oneflow as flow\nimport oneflow._oneflow_internal.lazy_mode as lazy_mode\n\n\n__all__ = [\"autocast_decorator\", \"autocast\"]\n\n\ndef autocast_decorator(autocast_instance, func):\n    @functools.wraps(func)\n    def decorate_autocast(*args, **kwargs):\n        with autocast_instance:\n            return func(*args, **kwargs)\n\n    return decorate_autocast\n\n\nclass autocast(object):\n    r\"\"\"\n    Note:\n      The following doc was origined by pytorch, see\n      https://github.com/pytorch/pytorch/blob/master/torch/amp/autocast_mode.py#L19-L179\n\n    Instances of :class:`autocast` serve as context managers or decorators that\n    allow regions of your script to run in mixed precision.\n\n    In these regions, ops run in an op-specific dtype chosen by autocast\n    to improve performance while maintaining accuracy.\n\n    When entering an autocast-enabled region, Tensors may be any type.\n    You should not call ``half()`` or ``bfloat16()`` on your model(s) or inputs when using autocasting.\n\n    :class:`autocast` should wrap only the forward pass(es) of your network, including the loss\n    computation(s).  Backward passes under autocast are not recommended.\n    Backward ops run in the same type that autocast used for corresponding forward ops.\n\n    Example for CUDA Devices::\n\n        # Creates model and optimizer in default precision\n        model = Net().cuda()\n        optimizer = optim.SGD(model.parameters(), ...)\n\n        for input, target in data:\n            optimizer.zero_grad()\n\n            # Enables autocasting for the forward pass (model + loss)\n            with oneflow.autocast(device_type=\"cuda\"):\n                output = model(input)\n                loss = loss_fn(output, target)\n\n            # Exits the context manager before backward()\n            loss.backward()\n            optimizer.step()\n\n\n    :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model::\n\n        class AutocastModel(nn.Module):\n            ...\n            @oneflow.autocast(device_type=\"cuda\")\n            def forward(self, input):\n                ...\n\n    Floating-point Tensors produced in an autocast-enabled region may be ``float16``.\n    After returning to an autocast-disabled region, using them with floating-point\n    Tensors of different dtypes may cause type mismatch errors.  
If so, cast the Tensor(s)\n    produced in the autocast region back to ``float32`` (or other dtype if desired).\n    If a Tensor from the autocast region is already ``float32``, the cast is a no-op,\n    and incurs no additional overhead.\n    CUDA Example::\n\n        # Creates some tensors in default dtype (here assumed to be float32)\n        a_float32 = oneflow.rand((8, 8), device=\"cuda\")\n        b_float32 = oneflow.rand((8, 8), device=\"cuda\")\n        c_float32 = oneflow.rand((8, 8), device=\"cuda\")\n        d_float32 = oneflow.rand((8, 8), device=\"cuda\")\n\n        with oneflow.autocast(device_type=\"cuda\"):\n            # oneflow.mm is on autocast's list of ops that should run in float16.\n            # Inputs are float32, but the op runs in float16 and produces float16 output.\n            # No manual casts are required.\n            e_float16 = oneflow.mm(a_float32, b_float32)\n            # Also handles mixed input types\n            f_float16 = oneflow.mm(d_float32, e_float16)\n\n        # After exiting autocast, calls f_float16.float() to use with d_float32\n        g_float32 = oneflow.mm(d_float32, f_float16.float())\n\n    CPU Training Example::\n\n        # Creates model and optimizer in default precision\n        model = Net()\n        optimizer = optim.SGD(model.parameters(), ...)\n\n        for epoch in epochs:\n            for input, target in data:\n                optimizer.zero_grad()\n\n                # Runs the forward pass with autocasting.\n                with oneflow.autocast(device_type=\"cpu\", dtype=oneflow.bfloat16):\n                    output = model(input)\n                    loss = loss_fn(output, target)\n\n                loss.backward()\n                optimizer.step()\n\n\n    CPU Inference Example::\n\n        # Creates model in default precision\n        model = Net().eval()\n\n        with oneflow.autocast(device_type=\"cpu\", dtype=oneflow.bfloat16):\n            for input in data:\n                # Runs the forward pass with autocasting.\n                output = model(input)\n\n    The autocast state is thread-local.  
If you want it enabled in a new thread, the context manager or decorator\n    must be invoked in that thread.\n\n    Args:\n        device_type(str, required):  Whether to use 'cuda' or 'cpu' device\n        enabled(bool, optional):  Whether autocasting should be enabled in the region.\n            Default: ``True``\n        dtype(oneflow_dtype, optional):  Whether to use oneflow.float16 or oneflow.bfloat16.\n        cache_enabled(bool, optional):  Whether the weight cache inside autocast should be enabled.\n            Default: ``True``\n    \"\"\"\n\n    def __init__(\n        self,\n        device_type: str,\n        dtype: Optional[flow.dtype] = None,\n        enabled: bool = True,\n        cache_enabled: Optional[bool] = None,\n    ):\n        self.device = device_type\n        if self.device == \"cuda\":\n            self.fast_dtype = flow.get_autocast_gpu_dtype()\n        elif self.device == \"cpu\":\n            self.fast_dtype = flow.get_autocast_cpu_dtype()\n        else:\n            raise RuntimeError(\n                \"User specified autocast device_type must be 'cuda' or 'cpu'\"\n            )\n        self.cache_enabled = flow.is_autocast_cache_enabled()\n\n        if dtype is not None:\n            self.fast_dtype = dtype\n        if cache_enabled is not None:\n            self.cache_enabled = cache_enabled\n\n        if self.device == \"cpu\":\n            warnings.warn(\n                \"CPU autocast is not supported currently. Disabling autocast.\"\n            )\n            enabled = False\n        if lazy_mode.is_enabled():\n            warnings.warn(\n                \"Autocast is not supported for lazy mode. Disabling autocast.\"\n            )\n            enabled = False\n        self.enabled = enabled\n\n    def __enter__(self):\n        self.autocast_mode = flow._oneflow_internal.AutoCastMode(\n            self.device, self.fast_dtype, self.enabled, self.cache_enabled\n        )\n        return self\n\n    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):\n        del self.autocast_mode\n\n    def __call__(self, func):\n        return autocast_decorator(self, func)\n"
  },
  {
    "path": "python/oneflow/amp/grad_scaler.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n\nclass GradScaler(object):\n    def __init__(\n        self,\n        init_scale=2.0 ** 16,\n        growth_factor=2.0,\n        backoff_factor=0.5,\n        growth_interval=2000,\n    ):\n        self._init_scale = init_scale\n        self._growth_factor = growth_factor\n        self._backoff_factor = backoff_factor\n        if self._backoff_factor != 1.0 / self._growth_factor:\n            raise ValueError(\n                \"Only support 1.0/growth_factor as backoff_factor at the moment, \"\n                \"got {}\".format(backoff_factor)\n            )\n        self._growth_interval = growth_interval\n\n    def _generate_conf_for_graph(self, train_conf):\n        train_conf.dynamic_loss_scale_policy.initial_loss_scale = self._init_scale\n        train_conf.dynamic_loss_scale_policy.increment_period = self._growth_interval\n        train_conf.dynamic_loss_scale_policy.multiplier = self._growth_factor\n\n\nclass StaticGradScaler(object):\n    def __init__(self, scale_factor):\n        if scale_factor <= 0.0:\n            raise ValueError(\"StaticGradScaler's scale_factor must > 0.0\")\n\n        self._scale_factor = scale_factor\n\n    def _generate_conf_for_graph(self, train_conf):\n        train_conf.loss_scale_factor = self._scale_factor\n"
  },
  {
    "path": "python/oneflow/ao/quantization.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n\nclass DeQuantStub:\n    def __init__(self, *args, **kwargs):\n        raise NotImplementedError(\n            \"The oneflow.ao.DeQuantStub interface is just to align the torch.ao.DeQuantStub interface and has no practical significance.\"\n        )\n\n\nclass QuantStub:\n    def __init__(self, *args, **kwargs):\n        raise NotImplementedError(\n            \"The oneflow.ao.QuantStub interface is just to align the torch.ao.QuantStub interface and has no practical significance.\"\n        )\n"
  },
  {
    "path": "python/oneflow/asyncs/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom .thread import Thread, thread\n"
  },
  {
    "path": "python/oneflow/asyncs/thread.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow._oneflow_internal\n\nThread = oneflow._oneflow_internal.AsyncThread\n\n\nclass thread:\n    r\"\"\"Context-manager to pick worker thread.\n    By default, all opkernels are excuted/launched in worker thread 0. Within this context, opkernels can be excuted/launched in the worker thread indicated by `thread_global_id`. \n    This context manager is thread local; it will not affect ops in other threads.\n    Also functions as a decorator. (Make sure to instantiate with parenthesis.)\n\n    Args:\n        worker_thread: a worker thread create with oneflow.asyncs.Thread. \n\n    For example:\n\n    .. code-block:: python\n        >>> import oneflow as flow\n        >>> with flow.asyncs.thread(flow.asyncs.Thread()):\n        ...     print(flow.ones(2, 2))\n        ...\n        tensor([[1., 1.],\n                [1., 1.]], dtype=oneflow.float32)\n    \"\"\"\n\n    def __init__(self, worker_thread: Thread):\n        self.stream_set_ = oneflow._oneflow_internal.StreamSet(worker_thread)\n        self.worker_thread_ = worker_thread\n\n    def __enter__(self):\n        self.guard_ = oneflow._oneflow_internal.StreamGuard(self.stream_set_)\n\n    def __exit__(self, type, value, traceback):\n        del self.guard_\n"
  },
  {
    "path": "python/oneflow/autograd/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nfrom oneflow.autograd.autograd import backward, grad\nfrom oneflow.autograd.autograd_function import Function\nfrom oneflow.autograd.autograd_mode import (\n    set_grad_enabled,\n    enable_grad,\n    inference_mode,\n    is_grad_enabled,\n    no_grad,\n)\nfrom oneflow.autograd.functional import vjp, jvp, jacobian, hessian, hvp, vhp\nfrom . import graph\n\n__all__ = [\n    \"backward\",\n    \"grad\",\n    \"Function\",\n    \"set_grad_enabled\",\n    \"enable_grad\",\n    \"inference_mode\",\n    \"is_grad_enabled\",\n    \"no_grad\",\n    \"vjp\",\n    \"jvp\",\n    \"jacobian\",\n    \"hessian\",\n    \"hvp\",\n    \"vhp\",\n]\n"
  },
  {
    "path": "python/oneflow/autograd/autograd.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Sequence, Tuple, Union\n\nfrom oneflow._oneflow_internal import TensorTuple\nfrom oneflow._oneflow_internal.autograd import backward as backward_api\nfrom oneflow._oneflow_internal.autograd import grad as grad_api\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.framework.tensor_tuple_util import convert_to_tensor_tuple\n\n\ndef grad(\n    outputs: Union[Tensor, Sequence[Tensor]],\n    inputs: Union[Tensor, Sequence[Tensor]],\n    grad_outputs: Union[Tensor, Sequence[Tensor], None] = None,\n    retain_graph: bool = False,\n    create_graph: bool = False,\n    allow_unused: bool = False,\n    is_grads_batched: bool = False,\n) -> Tuple[Tensor]:\n    r\"\"\"\n    Computes and returns the sum of gradients of outputs with respect to the inputs.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.autograd.grad.html.\n\n    The graph is differentiated using the chain rule. ``grad_outputs`` should be a sequence of\n    length matching ``outputs``, containing the \"vector\" in the Jacobian-vector product.\n    (``None`` is an acceptable value for that tensor don't require gradient.)\n\n    Args:\n        outputs (Sequence[Tensor]): Tensors of which the derivative will be computed.\n        inputs (Sequence[Tensor]): Inputs w.r.t. which the derivative will be returned(and not\n            accumulated into ``.grad``).\n        grad_outputs (Sequence[Tensor], optional): The \"vector\" in the Jacobian-vector product.\n            Usually gradients w.r.t. each output. None values can be specified for scalar Tensors\n            or ones that don't require grad. Defaults to None.\n        retain_graph (bool, optional): If ``False``, the graph used to compute the grads will be\n            reset after backward is complete. Defaults to ``False``. Note that in nearly all cases\n            setting this option to ``True`` is not needed and often can be worked around in a much\n            more efficient way. Defaults to the value of ``create_graph``.\n        create_graph (bool, optional): If ``True``, graph of the derivative will be constructed,\n            allowing to compute higher order derivative products. Defaults to ``False``.\n        allow_unused (bool, optional): If ``False``, specifying inputs that were not\n            used when computing outputs (and therefore their grad is always zero)\n            is an error. Defaults to ``False``.\n        is_grads_batched (bool, optional): If True, the first dimension of each tensor in\n            grad_outputs will be interpreted as the batch dimension. Instead of computing a single\n            vector-Jacobian product, we compute a batch of vector-Jacobian products for each “vector”\n            in the batch. This should lead to performance improvements when compared to manually\n            looping and performing backward multiple times. 
Defaults to ``False``.\n\n    Returns:\n        Tuple(Tensor): A tuple of tensors containing the gradients for each ``inputs``.\n    \"\"\"\n    in_grads = grad_api(\n        convert_to_tensor_tuple(outputs),\n        convert_to_tensor_tuple(inputs),\n        convert_to_tensor_tuple(grad_outputs),\n        retain_graph,\n        create_graph,\n        allow_unused,\n        is_grads_batched,\n    )\n    return tuple([x for x in in_grads])\n\n\ndef backward(\n    tensors: Union[Tensor, Sequence[Tensor]],\n    grad_tensors: Union[Tensor, Sequence[Tensor], None],\n    retain_graph: bool = False,\n    create_graph: bool = False,\n) -> None:\n    r\"\"\"\n    Computes the sum of gradients of given tensors with respect to graph leaves.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.autograd.backward.html.\n\n    The graph is differentiated using the chain rule. If any of ``tensors`` are non-scalar (i.e.\n    their data has more than one element) and require gradient, then the Jacobian-vector product\n    would be computed; in this case the function additionally requires specifying ``grad_tensors``.\n    It should be a sequence of matching length that contains the \"vector\" in the Jacobian-vector\n    product, usually the gradient of the differentiated function w.r.t. corresponding tensors.\n    (``None`` is an acceptable value for all tensors that don't need gradient.)\n\n    This function accumulates gradients in the leaves - you might need to zero ``.grad`` attributes\n    or set them to ``None`` before calling it.\n\n    Note:\n        Using this method with ``create_graph=True`` will create a reference cycle between the\n        parameter and its gradient which can cause a memory leak. We recommend using\n        ``autograd.grad`` when creating the graph to avoid this. If you have to use this function,\n        make sure to reset the ``.grad`` fields of your parameters to ``None`` after use to break\n        the cycle and avoid the leak.\n\n    Args:\n        tensors (Tensor or Sequence[Tensor]): Tensors of which the derivative will be computed.\n        grad_tensors (Tensor or Sequence[Tensor], optional): The \"vector\" in the Jacobian-vector\n            product, usually gradients w.r.t. each element of corresponding tensors. (None values can be\n            specified for scalar Tensors or ones that don't require grad.)\n        retain_graph (bool, optional): If ``False``, the graph used to compute the grads will be\n            reset after backward is complete. Defaults to ``False``. Note that in nearly all cases\n            setting this option to ``True`` is not needed and often can be worked around in a much\n            more efficient way.\n        create_graph (bool, optional): If ``True``, graph of the derivative will be constructed,\n            allowing to compute higher order derivative products. Defaults to ``False``.\n    \"\"\"\n    backward_api(\n        convert_to_tensor_tuple(tensors),\n        convert_to_tensor_tuple(grad_tensors),\n        retain_graph,\n        create_graph,\n    )\n
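\n\n# Eager-mode sketch of the two entry points on a scalar loss (assumes\n# ``import oneflow as flow``; a hedged usage note, not part of the API surface):\n#   x = flow.ones(2, requires_grad=True)\n#   (dx,) = grad((x * x).sum(), x)  # dx == 2 * x\n#   backward((x * x).sum(), None)   # accumulates 2 * x into x.grad\n"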
  },
  {
    "path": "python/oneflow/autograd/autograd_function.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nfrom oneflow._oneflow_internal import TensorTuple\nfrom oneflow._oneflow_internal.autograd import AutogradFunctionBase\n\n\nclass Function(AutogradFunctionBase):\n    r\"\"\"\n    Function(self)\n\n    Base class to create custom autograd.Function.\n\n    To create a custom autograd.Function, subclass this class and implement the ``forward()``\n    and ``backward()`` static methods. Then, to use your custom op in the forward pass, call the\n    class method ``apply()`` or ``__call__()``. Do not call ``forward()`` directly.\n\n    For example:\n\n    .. code-block:: python\n\n        class Exp(Function):\n            @staticmethod\n            def forward(ctx, i):\n                result = i.exp()\n                ctx.save_for_backward(result)\n                return result\n\n            @staticmethod\n            def backward(ctx, grad_output):\n                result, = ctx.saved_tensors\n                return grad_output * result\n\n        # Use it by calling the apply method or __call__ method\n        output = Exp.apply(input)  # output = Exp()(input)\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n    def __call__(self, *inputs):\n        r\"\"\"\n        See :meth:`self.apply`.\n        \"\"\"\n        return self.apply(*inputs)\n\n    @classmethod\n    def apply(cls, *inputs):\n        r\"\"\"\n        Calculate output tensors and build backward graph.\n        \"\"\"\n        return AutogradFunctionBase.apply(\n            cls.__name__, cls.forward, cls.backward, *inputs\n        )\n\n    @staticmethod\n    def forward(ctx, *inputs):\n        r\"\"\"\n        Override this function for custom forward calculation.\n        \"\"\"\n        raise NotImplementedError(\n            \"You must implement the forward function for custom autograd.Function.\"\n        )\n\n    @staticmethod\n    def backward(ctx, *out_grads):\n        r\"\"\"\n        Override this function for custom backward calculation.\n        \"\"\"\n        raise NotImplementedError(\n            \"You must implement the backward function for custom autograd.Function.\"\n        )\n"
  },
  {
    "path": "python/oneflow/autograd/autograd_mode.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport oneflow._oneflow_internal\nfrom oneflow._oneflow_internal.autograd import AutoGradMode\n\n\ndef is_grad_enabled():\n    r\"\"\"\n    Returns True if grad mode is currently enabled.\n    \"\"\"\n    return oneflow._oneflow_internal.autograd.is_grad_enabled()\n\n\nclass inference_mode:\n    r\"\"\"\n    Context-manager that enables or disables inference mode\n\n    InferenceMode is a new context manager analogous to no_grad to be used when you arecertain\n    your operations will have no interactions with autograd (e.g., model training). Code run\n    under this mode gets better performance by disabling view tracking and version counter bumps.\n\n    This context manager is thread local; it will not affect computation in other threads.\n\n    Also functions as a decorator. (Make sure to instantiate with parenthesis.)\n\n    Args:\n        mode (bool): Flag whether to enable or disable inference mode. (default: True)\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.ones(2, 3, requires_grad=True)\n        >>> with flow.inference_mode():\n        ...     y = x * x\n        >>> y.requires_grad\n        False\n        >>> @flow.inference_mode()\n        ... def no_grad_func(x):\n        ...     return x * x\n        >>> y = no_grad_func(x)\n        >>> y.requires_grad\n        False\n    \"\"\"\n\n    def __init__(self, mode=True):\n        self.infer_mode = mode\n\n    def __call__(self, func):\n        def wrapper(*args, **kwargs):\n            with AutoGradMode(not self.infer_mode):\n                return func(*args, **kwargs)\n\n        return wrapper\n\n    def __enter__(self):\n        self.grad_mode = AutoGradMode(not self.infer_mode)\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        pass\n\n\nclass enable_grad:\n    r\"\"\"\n    Context-manager that enabled gradient calculation.\n\n    Enables gradient calculation, if it has been disabled via no_grad.\n\n    This context manager is thread local; it will not affect computation in other threads.\n\n    Also functions as a decorator. (Make sure to instantiate with parenthesis.)\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.ones(2, 3, requires_grad=True)\n        >>> with flow.no_grad():\n        ...     with flow.enable_grad():\n        ...         y = x * x\n        >>> y.requires_grad\n        True\n        >>> @flow.enable_grad()\n        ... def no_grad_func(x):\n        ...     return x * x\n        >>> with flow.no_grad():\n        ...     
y = no_grad_func(x)\n        >>> y.requires_grad\n        True\n    \"\"\"\n\n    def __call__(self, func):\n        def wrapper(*args, **kwargs):\n            with AutoGradMode(True):\n                return func(*args, **kwargs)\n\n        return wrapper\n\n    def __enter__(self):\n        self.grad_mode = AutoGradMode(True)\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        pass\n\n\nclass no_grad:\n    r\"\"\"\n    Context-manager that disables gradient calculation.\n\n    Disabling gradient calculation is useful for inference, when you are sure that\n    you will not call Tensor.backward(). It will reduce memory consumption for computations\n    that would otherwise have requires_grad=True.\n\n    In this mode, the result of every computation will have requires_grad=False, even when\n    the inputs have requires_grad=True.\n\n    This context manager is thread local; it will not affect computation in other threads.\n\n    Also functions as a decorator. (Make sure to instantiate with parentheses.)\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.ones(2, 3, requires_grad=True)\n        >>> with flow.no_grad():\n        ...     y = x * x\n        >>> y.requires_grad\n        False\n        >>> @flow.no_grad()\n        ... def no_grad_func(x):\n        ...     return x * x\n        >>> y = no_grad_func(x)\n        >>> y.requires_grad\n        False\n    \"\"\"\n\n    def __call__(self, func):\n        def wrapper(*args, **kwargs):\n            with AutoGradMode(False):\n                return func(*args, **kwargs)\n\n        return wrapper\n\n    def __enter__(self):\n        self.grad_mode = AutoGradMode(False)\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        pass\n\n\nclass set_grad_enabled:\n    r\"\"\"\n    Context-manager that sets gradient calculation on or off.\n\n    Gradient calculation is enabled or disabled according to its argument ``is_train``.\n\n    This context manager is thread local; it will not affect computation in other threads.\n\n    Also functions as a decorator. (Make sure to instantiate with parentheses.)\n\n    Args:\n        is_train (bool): Flag whether to enable or disable gradient calculation. (default: True)\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.ones(2, 3, requires_grad=True)\n        >>> with flow.set_grad_enabled(True):\n        ...     y = x * x\n        >>> y.requires_grad\n        True\n        >>> @flow.set_grad_enabled(False)\n        ... def no_grad_func(x):\n        ...     
return x * x\n        >>> y = no_grad_func(x)\n        >>> y.requires_grad\n        False\n        \n    \"\"\"\n\n    def __init__(self, is_train=True):\n        self.is_train = is_train\n        self.prev_mode = is_grad_enabled()\n        oneflow._oneflow_internal.autograd.set_grad_enabled(is_train)\n\n    def __call__(self, func):\n        # recover grad mode set in __init__\n        oneflow._oneflow_internal.autograd.set_grad_enabled(self.prev_mode)\n\n        def wrapper(*args, **kwargs):\n            with AutoGradMode(self.is_train):\n                return func(*args, **kwargs)\n\n        return wrapper\n\n    def __enter__(self):\n        # recover grad mode set in __init__\n        oneflow._oneflow_internal.autograd.set_grad_enabled(self.prev_mode)\n        self.grad_mode = AutoGradMode(self.is_train)\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        pass\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/autograd/functional.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n# This code is referenced from https://github.com/pytorch/pytorch/blob/master/torch/autograd/functional.py and consistent with oneflow.\nfrom typing import List, Tuple\n\nimport oneflow as flow\n\n__all__ = [\"vjp\", \"jvp\", \"jacobian\", \"hessian\", \"hvp\", \"vhp\"]\n\n# Utility functions\n\n\ndef _as_tuple_nocheck(x):\n    if isinstance(x, tuple):\n        return x\n    elif isinstance(x, list):\n        return tuple(x)\n    else:\n        return (x,)\n\n\ndef _as_tuple(inp, arg_name=None, fn_name=None):\n    # Ensures that inp is a tuple of Tensors\n    # Returns whether or not the original inp was a tuple and the tupled version of the input\n    if arg_name is None and fn_name is None:\n        return _as_tuple_nocheck(inp)\n\n    is_inp_tuple = True\n    if not isinstance(inp, tuple):\n        inp = (inp,)\n        is_inp_tuple = False\n\n    for i, el in enumerate(inp):\n        if not isinstance(el, flow.Tensor):\n            if is_inp_tuple:\n                raise TypeError(\n                    f\"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the\"\n                    f\" value at index {i} has type {type(el)}.\"\n                )\n            else:\n                raise TypeError(\n                    f\"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the\"\n                    f\" given {arg_name} has type {type(el)}.\"\n                )\n\n    return is_inp_tuple, inp\n\n\ndef _tuple_postprocess(res, to_unpack):\n    # Unpacks a potentially nested tuple of Tensors\n    # to_unpack should be a single boolean or a tuple of two booleans.\n    # It is used to:\n    # - invert _as_tuple when res should match the inp given to _as_tuple\n    # - optionally remove nesting of two tuples created by multiple calls to _as_tuple\n    if isinstance(to_unpack, tuple):\n        assert len(to_unpack) == 2\n        if not to_unpack[1]:\n            res = tuple(el[0] for el in res)\n        if not to_unpack[0]:\n            res = res[0]\n    else:\n        if not to_unpack:\n            res = res[0]\n    return res\n\n\ndef _grad_preprocess(inputs, create_graph, need_graph):\n    # Preprocess the inputs to make sure they require gradient\n    # inputs is a tuple of Tensors to preprocess\n    # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs\n    # need_graph specifies if we internally want gradients to flow back to the Tensors in res\n    # Note that we *always* create a new Tensor object to be able to see the difference between\n    # inputs given as arguments and the same Tensors automatically captured by the user function.\n    res = []\n    for inp in inputs:\n        if create_graph and inp.requires_grad:\n            # Create at least a new Tensor object in a differentiable way\n            # oneflow.torch has no is_sparse attribute. 
https://github.com/Oneflow-Inc/oneflow/issues/10401\n            res.append(inp.view_as(inp))\n        else:\n            res.append(inp.detach().requires_grad_(need_graph))\n    return tuple(res)\n\n\ndef _grad_postprocess(inputs, create_graph):\n    # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not\n    # request it.\n    if isinstance(inputs[0], flow.Tensor):\n        if not create_graph:\n            return tuple(inp.detach() for inp in inputs)\n        else:\n            return inputs\n    else:\n        return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)\n\n\ndef _validate_v(v, other, is_other_tuple):\n    # This assumes that other is the correct shape, and v should match\n    # Both are assumed to be tuples of Tensors\n    if len(other) != len(v):\n        if is_other_tuple:\n            raise RuntimeError(\n                f\"v is a tuple of invalid length: should be {len(other)} but got {len(v)}.\"\n            )\n        else:\n            raise RuntimeError(\"The given v should contain a single Tensor.\")\n\n    for idx, (el_v, el_other) in enumerate(zip(v, other)):\n        if el_v.size() != el_other.size():\n            prepend = \"\"\n            if is_other_tuple:\n                prepend = f\"Entry {idx} in \"\n            raise RuntimeError(\n                f\"{prepend}v has invalid size: should be {el_other.size()} but got {el_v.size()}.\"\n            )\n\n\ndef _check_requires_grad(inputs, input_type, strict):\n    # Used to make all the necessary checks to raise nice errors in strict mode.\n    if not strict:\n        return\n\n    if input_type not in [\"outputs\", \"grad_inputs\", \"jacobian\", \"hessian\"]:\n        raise RuntimeError(\"Invalid input_type to _check_requires_grad\")\n    for i, inp in enumerate(inputs):\n        if inp is None:\n            # This can only be reached for grad_inputs.\n            raise RuntimeError(\n                f\"The output of the user-provided function is independent of input {i}.\"\n                \" This is not allowed in strict mode.\"\n            )\n        if not inp.requires_grad:\n            if input_type == \"hessian\":\n                raise RuntimeError(\n                    f\"The hessian of the user-provided function with respect to input {i}\"\n                    \" is independent of the input. This is not allowed in strict mode.\"\n                    \" You should ensure that your function is thrice differentiable and that\"\n                    \" the hessian depends on the inputs.\"\n                )\n            elif input_type == \"jacobian\":\n                raise RuntimeError(\n                    \"While computing the hessian, found that the jacobian of the user-provided\"\n                    f\" function with respect to input {i} is independent of the input. This is not\"\n                    \" allowed in strict mode. You should ensure that your function is twice\"\n                    \" differentiable and that the jacobian depends on the inputs (this would be\"\n                    \" violated by a linear function for example).\"\n                )\n            elif input_type == \"grad_inputs\":\n                raise RuntimeError(\n                    f\"The gradient with respect to input {i} is independent of the inputs of the\"\n                    \" user-provided function. 
This is not allowed in strict mode.\"\n                )\n            else:\n                raise RuntimeError(\n                    f\"Output {i} of the user-provided function does not require gradients.\"\n                    \" The outputs must be computed in a differentiable manner from the input\"\n                    \" when running in strict mode.\"\n                )\n\n\ndef _autograd_grad(\n    outputs, inputs, grad_outputs=None, create_graph=False, retain_graph=None,\n):\n    # Version of autograd.grad that accepts `None` in outputs and do not compute gradients for them.\n    # This has the extra constraint that inputs has to be a tuple\n    assert isinstance(outputs, tuple)\n    if grad_outputs is None:\n        grad_outputs = (None,) * len(outputs)\n    assert isinstance(grad_outputs, tuple)\n    assert len(outputs) == len(grad_outputs)\n\n    new_outputs: Tuple[flow.Tensor, ...] = tuple()\n    new_grad_outputs: Tuple[flow.Tensor, ...] = tuple()\n    for out, grad_out in zip(outputs, grad_outputs):\n        if out is not None and out.requires_grad:\n            new_outputs += (out,)\n            new_grad_outputs += (grad_out,)\n\n    if len(new_outputs) == 0:\n        # No differentiable output, we don't need to call the autograd engine\n        return (None,) * len(inputs)\n    else:\n        return flow.autograd.grad(\n            new_outputs,\n            inputs,\n            new_grad_outputs,\n            allow_unused=True,\n            create_graph=create_graph,\n            retain_graph=retain_graph,\n        )\n\n\ndef _fill_in_zeros(grads, refs, strict, create_graph, stage):\n    # Used to detect None in the grads and depending on the flags, either replace them\n    # with Tensors full of 0s of the appropriate size based on the refs or raise an error.\n    # strict and create graph allow us to detect when it is appropriate to raise an error\n    # stage gives us information of which backward call we consider to give good error message\n    if stage not in [\"back\", \"back_trick\", \"double_back\", \"double_back_trick\"]:\n        raise RuntimeError(f\"Invalid stage argument '{stage}' to _fill_in_zeros\")\n\n    res: Tuple[flow.Tensor, ...] = tuple()\n    for i, grads_i in enumerate(grads):\n        if grads_i is None:\n            if strict:\n                if stage == \"back\":\n                    raise RuntimeError(\n                        \"The output of the user-provided function is independent of \"\n                        f\"input {i}. This is not allowed in strict mode.\"\n                    )\n                elif stage == \"back_trick\":\n                    raise RuntimeError(\n                        f\"The gradient with respect to the input is independent of entry {i}\"\n                        \" in the grad_outputs when using the double backward trick to compute\"\n                        \" forward mode gradients. This is not allowed in strict mode.\"\n                    )\n                elif stage == \"double_back\":\n                    raise RuntimeError(\n                        \"The jacobian of the user-provided function is independent of \"\n                        f\"input {i}. This is not allowed in strict mode.\"\n                    )\n                else:\n                    raise RuntimeError(\n                        \"The hessian of the user-provided function is independent of \"\n                        f\"entry {i} in the grad_jacobian. 
This is not allowed in strict \"\n                        \"mode as it prevents from using the double backward trick to \"\n                        \"replace forward mode AD.\"\n                    )\n\n            grads_i = flow.zeros_like(refs[i])\n        else:\n            if strict and create_graph and not grads_i.requires_grad:\n                if \"double\" not in stage:\n                    raise RuntimeError(\n                        \"The jacobian of the user-provided function is independent of \"\n                        f\"input {i}. This is not allowed in strict mode when create_graph=True.\"\n                    )\n                else:\n                    raise RuntimeError(\n                        \"The hessian of the user-provided function is independent of \"\n                        f\"input {i}. This is not allowed in strict mode when create_graph=True.\"\n                    )\n\n        res += (grads_i,)\n\n    return res\n\n\n# Public API\n\n\ndef vjp(func, inputs, v=None, create_graph=False, strict=False):\n    r\"\"\"Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs.\n\n    The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.autograd.functional.vjp.html\n\n    Args:\n        func (function): a Python function that takes Tensor inputs and returns\n            a tuple of Tensors or a Tensor.\n        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.\n        v (tuple of Tensors or Tensor): The vector for which the vector\n            Jacobian product is computed.  Must be the same size as the output\n            of ``func``. This argument is optional when the output of ``func``\n            contains a single element and (if it is not provided) will be set\n            as a Tensor containing a single ``1``.\n        create_graph (bool, optional): If ``True``, both the output and result\n            will be computed in a differentiable way. Note that when ``strict``\n            is ``False``, the result can not require gradients or be\n            disconnected from the inputs.  Defaults to ``False``.\n        strict (bool, optional): If ``True``, an error will be raised when we\n            detect that there exists an input such that all the outputs are\n            independent of it. If ``False``, we return a Tensor of zeros as the\n            vjp for said inputs, which is the expected mathematical value.\n            Defaults to ``False``.\n\n    Returns:\n        output (tuple): tuple with:\n            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``\n\n            vjp (tuple of Tensors or Tensor): result of the dot product with\n            the same shape as the inputs.\n\n    Example:\n\n        >>> def exp_reducer(x):\n        ...     
return x.exp().sum(dim=1)\n        >>> inputs = flow.rand(4, 4)\n        >>> v = flow.ones(4)\n        >>> vjp(exp_reducer, inputs, v) # doctest: +ELLIPSIS\n        (tensor([5.7817, 7.2458, 5.7830, 6.7782]),\n         tensor([[1.4458, 1.3962, 1.3042, 1.6354],\n                [2.1288, 1.0652, 1.5483, 2.5035],\n                [2.2046, 1.1292, 1.1432, 1.3059],\n                [1.3225, 1.6652, 1.7753, 2.0152]]))\n\n        >>> vjp(exp_reducer, inputs, v, create_graph=True) # doctest: +ELLIPSIS\n        (tensor([5.7817, 7.2458, 5.7830, 6.7782], grad_fn=<SumBackward1>),\n         tensor([[1.4458, 1.3962, 1.3042, 1.6354],\n                [2.1288, 1.0652, 1.5483, 2.5035],\n                [2.2046, 1.1292, 1.1432, 1.3059],\n                [1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=<MulBackward0>))\n\n        >>> def adder(x, y):\n        ...     return 2 * x + 3 * y\n        >>> inputs = (flow.rand(2), flow.rand(2))\n        >>> v = flow.ones(2)\n        >>> vjp(adder, inputs, v)  # doctest: +ELLIPSIS\n        (tensor([2.4225, 2.3340]),\n         (tensor([2., 2.]), tensor([3., 3.])))\n    \"\"\"\n    with flow.enable_grad():\n        is_inputs_tuple, inputs = _as_tuple(inputs, \"inputs\", \"vjp\")\n        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)\n\n        outputs = func(*inputs)\n        is_outputs_tuple, outputs = _as_tuple(\n            outputs, \"outputs of the user-provided function\", \"vjp\"\n        )\n        _check_requires_grad(outputs, \"outputs\", strict=strict)\n\n        if v is not None:\n            _, v = _as_tuple(v, \"v\", \"vjp\")\n            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)\n            _validate_v(v, outputs, is_outputs_tuple)\n        else:\n            if len(outputs) != 1 or outputs[0].nelement() != 1:\n                raise RuntimeError(\n                    \"The vector v can only be None if the \"\n                    \"user-provided function returns \"\n                    \"a single Tensor with a single element.\"\n                )\n\n    enable_grad = True if create_graph else flow.is_grad_enabled()\n    with flow.set_grad_enabled(enable_grad):\n        grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph)\n        vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, \"back\")\n\n    # Cleanup objects and return them to the user\n    outputs = _grad_postprocess(outputs, create_graph)\n    vjp = _grad_postprocess(vjp, create_graph)\n\n    return (\n        _tuple_postprocess(outputs, is_outputs_tuple),\n        _tuple_postprocess(vjp, is_inputs_tuple),\n    )\n\n\ndef jvp(func, inputs, v=None, create_graph=False, strict=False):\n    r\"\"\"Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``.\n    \n    The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.autograd.functional.jvp.html\n\n    Args:\n        func (function): a Python function that takes Tensor inputs and returns\n            a tuple of Tensors or a Tensor.\n        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.\n        v (tuple of Tensors or Tensor): The vector for which the Jacobian\n            vector product is computed. Must be the same size as the input of\n            ``func``. 
This argument is optional when the input to ``func``\n            contains a single element and (if it is not provided) will be set\n            as a Tensor containing a single ``1``.\n        create_graph (bool, optional): If ``True``, both the output and result\n            will be computed in a differentiable way. Note that when ``strict``\n            is ``False``, the result can not require gradients or be\n            disconnected from the inputs.  Defaults to ``False``.\n        strict (bool, optional): If ``True``, an error will be raised when we\n            detect that there exists an input such that all the outputs are\n            independent of it. If ``False``, we return a Tensor of zeros as the\n            jvp for said inputs, which is the expected mathematical value.\n            Defaults to ``False``.\n\n    Returns:\n        output (tuple): tuple with:\n            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``\n\n            jvp (tuple of Tensors or Tensor): result of the dot product with\n            the same shape as the output.\n\n    Note:\n        ``autograd.functional.jvp`` computes the jvp by using the backward of\n        the backward (sometimes called the double backwards trick). This is not\n        the most performant way of computing the jvp. Please consider using\n        :func:`flow.func.jvp` instead.\n\n    Example:\n\n        >>> def exp_reducer(x):\n        ...     return x.exp().sum(dim=1)\n        >>> inputs = flow.rand(4, 4)\n        >>> v = flow.ones(4, 4)\n        >>> jvp(exp_reducer, inputs, v) # doctest: +ELLIPSIS\n        (tensor([6.3090, 4.6742, 7.9114, 8.2106]),\n         tensor([6.3090, 4.6742, 7.9114, 8.2106]))\n\n        >>> jvp(exp_reducer, inputs, v, create_graph=True) # doctest: +ELLIPSIS\n        (tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SumBackward1>),\n         tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SqueezeBackward1>))\n\n        >>> def adder(x, y):\n        ...     return 2 * x + 3 * y\n        >>> inputs = (flow.rand(2), flow.rand(2))\n        >>> v = (flow.ones(2), flow.ones(2))\n        >>> jvp(adder, inputs, v) # doctest: +ELLIPSIS\n        (tensor([2.2399, 2.5005]),\n         tensor([5., 5.]))\n\n    \"\"\"\n    with flow.enable_grad():\n        is_inputs_tuple, inputs = _as_tuple(inputs, \"inputs\", \"jvp\")\n        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)\n\n        if v is not None:\n            _, v = _as_tuple(v, \"v\", \"jvp\")\n            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)\n            _validate_v(v, inputs, is_inputs_tuple)\n        else:\n            if len(inputs) != 1 or inputs[0].nelement() != 1:\n                raise RuntimeError(\n                    \"The vector v can only be None if the input to \"\n                    \"the user-provided function is a single Tensor \"\n                    \"with a single element.\"\n                )\n\n        outputs = func(*inputs)\n        is_outputs_tuple, outputs = _as_tuple(\n            outputs, \"outputs of the user-provided function\", \"jvp\"\n        )\n        _check_requires_grad(outputs, \"outputs\", strict=strict)\n        # The backward is linear so the value of grad_outputs is not important as\n        # it won't appear in the double backward graph. 
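A hedged sketch of the\n        # double-backward trick (u is an illustrative dummy cotangent; the names do not\n        # match the code below):\n        #   u = flow.zeros_like(out, requires_grad=True)\n        #   g = flow.autograd.grad(out, x, u, create_graph=True)  # g = J^T u, linear in u\n        #   flow.autograd.grad(g, u, v)  # differentiating J^T u w.r.t. u with cotangent v gives J v\n        # 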
We only need to ensure that\n        # it does not contain inf or nan.\n        grad_outputs = tuple(\n            flow.zeros_like(out, requires_grad=True) for out in outputs\n        )\n\n        grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True)\n        _check_requires_grad(grad_inputs, \"grad_inputs\", strict=strict)\n\n    if create_graph:\n        with flow.enable_grad():\n            grad_res = _autograd_grad(\n                grad_inputs, grad_outputs, v, create_graph=create_graph\n            )\n            jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, \"back_trick\")\n    else:\n        grad_res = _autograd_grad(\n            grad_inputs, grad_outputs, v, create_graph=create_graph\n        )\n        jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, \"back_trick\")\n\n    # Cleanup objects and return them to the user\n    outputs = _grad_postprocess(outputs, create_graph)\n    jvp = _grad_postprocess(jvp, create_graph)\n\n    return (\n        _tuple_postprocess(outputs, is_outputs_tuple),\n        _tuple_postprocess(jvp, is_outputs_tuple),\n    )\n\n\ndef _construct_standard_basis_for(tensors, tensor_numels: Tuple[int, ...]):\n    # This function:\n    # - constructs an N=sum(tensor_numels) standard basis, i.e. an NxN identity matrix.\n    # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.\n    # - Each chunk corresponds to one tensor. The chunk has the same dtype and\n    #   device as the tensor.\n    #\n    # For example, with tensor_numels = [1, 2, 1], this function returns:\n    # ( tensor([[1],     tensor([[0, 0],      tensor([[0],\n    #           [0],             [1, 0],              [0],\n    #           [0],             [0, 1],              [0],\n    #           [0]])  ,         [0, 0]])  ,          [1]])  )\n    #\n    # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)\n    # Precondition: tensors always has at least one element.\n    #\n    # See NOTE: [Computing jacobian with vmap and grad for multiple tensors]\n    # for context behind this function. All the pre-conditions are guarded for\n    # in flow.autograd.functional.jacobian.\n    assert len(tensors) == len(tensor_numels)\n    assert len(tensors) > 0\n    total_numel = sum(tensor_numels)\n    chunks = tuple(\n        tensor.new_zeros(total_numel, tensor_numel)\n        for tensor, tensor_numel in zip(tensors, tensor_numels)\n    )\n    diag_start_idx = 0\n    for chunk, numel in zip(chunks, tensor_numels):\n        # fill_ does not support non-contiguous tensors, see https://github.com/Oneflow-Inc/oneflow/issues/10394\n        # chunk.diagonal(diag_start_idx).fill_(1)\n        for i in range(numel):\n            chunk[diag_start_idx + i][i] = 1\n        diag_start_idx += numel\n    return chunks\n\n\ndef _jacfwd(func, inputs, strict=False, vectorize=False):\n    if strict:\n        raise RuntimeError(\n            \"flow.autograd.functional.jacobian: `strict=True` \"\n            'and `strategy=\"forward-mode\"` are not supported together (yet). '\n            \"Please either set `strict=False` or \"\n            '`strategy=\"reverse-mode\"`.'\n        )\n    is_inputs_tuple, inputs = _as_tuple(inputs, \"inputs\", \"jacobian\")\n    output_info = []\n\n    if vectorize:\n        # Computing the Jacobian does not support vectorize, see https://github.com/Oneflow-Inc/oneflow/issues/10397\n        raise NotImplementedError(\"Computing Jacobian does not support vectorize.
\")\n    else:\n        raise NotImplementedError(\n            \"Computing Jacobian using forward-AD or forward-over-reverse Hessian is\"\n            \"only implemented for `vectorize=True`.\"\n        )\n\n\ndef jacobian(\n    func,\n    inputs,\n    create_graph=False,\n    strict=False,\n    vectorize=False,\n    strategy=\"reverse-mode\",\n):\n    r\"\"\"Compute the Jacobian of a given function.\n\n    The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.autograd.functional.jacobian.html\n\n    Args:\n        func (function): a Python function that takes Tensor inputs and returns\n            a tuple of Tensors or a Tensor.\n        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.\n        create_graph (bool, optional): If ``True``, the Jacobian will be\n            computed in a differentiable manner. Note that when ``strict`` is\n            ``False``, the result can not require gradients or be disconnected\n            from the inputs.  Defaults to ``False``.\n        strict (bool, optional): If ``True``, an error will be raised when we\n            detect that there exists an input such that all the outputs are\n            independent of it. If ``False``, we return a Tensor of zeros as the\n            jacobian for said inputs, which is the expected mathematical value.\n            Defaults to ``False``.\n        vectorize (bool, optional): This feature is experimental.\n            Please consider using :func:`flow.func.jacrev` or\n            :func:`flow.func.jacfwd` instead if you are looking for something\n            less experimental and more performant.\n            When computing the jacobian, usually we invoke\n            ``autograd.grad`` once per row of the jacobian. If this flag is\n            ``True``, we perform only a single ``autograd.grad`` call with\n            ``batched_grad=True`` which uses the vmap prototype feature.\n            Though this should lead to performance improvements in many cases,\n            because this feature is still experimental, there may be performance\n            cliffs. See :func:`flow.autograd.grad`'s ``batched_grad`` parameter for\n            more information.\n        strategy (str, optional): Set to ``\"forward-mode\"`` or ``\"reverse-mode\"`` to\n            determine whether the Jacobian will be computed with forward or reverse\n            mode AD. Currently, ``\"forward-mode\"`` requires ``vectorized=True``.\n            Defaults to ``\"reverse-mode\"``. If ``func`` has more outputs than\n            inputs, ``\"forward-mode\"`` tends to be more performant. Otherwise,\n            prefer to use ``\"reverse-mode\"``.\n\n    Returns:\n        Jacobian (Tensor or nested tuple of Tensors): if there is a single\n        input and output, this will be a single Tensor containing the\n        Jacobian for the linearized inputs and output. If one of the two is\n        a tuple, then the Jacobian will be a tuple of Tensors. If both of\n        them are tuples, then the Jacobian will be a tuple of tuple of\n        Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the\n        ``i``\\th output and ``j``\\th input and will have as size the\n        concatenation of the sizes of the corresponding output and the\n        corresponding input and will have same dtype and device as the\n        corresponding input. 
If strategy is ``forward-mode``, the dtype will be\n        that of the output; otherwise, the input.\n\n    Example:\n\n        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)\n        >>> def exp_reducer(x):\n        ...     return x.exp().sum(dim=1)\n        >>> inputs = flow.rand(2, 2)\n        >>> # xdoctest: +IGNORE_WANT(\"non-deterministic\")\n        >>> jacobian(exp_reducer, inputs)\n        tensor([[[1.4917, 2.4352],\n                 [0.0000, 0.0000]],\n                [[0.0000, 0.0000],\n                 [2.4369, 2.3799]]])\n\n        >>> jacobian(exp_reducer, inputs, create_graph=True)\n        tensor([[[1.4917, 2.4352],\n                 [0.0000, 0.0000]],\n                [[0.0000, 0.0000],\n                 [2.4369, 2.3799]]], grad_fn=<ViewBackward>)\n\n        >>> def exp_adder(x, y):\n        ...     return 2 * x.exp() + 3 * y\n        >>> inputs = (flow.rand(2), flow.rand(2))\n        >>> jacobian(exp_adder, inputs)\n        (tensor([[2.8052, 0.0000],\n                [0.0000, 3.3963]]),\n         tensor([[3., 0.],\n                 [0., 3.]]))\n    \"\"\"\n    assert strategy in (\"forward-mode\", \"reverse-mode\"), (\n        'Expected strategy to be either \"forward-mode\" or \"reverse-mode\". Hint: If your '\n        'function has more outputs than inputs, \"forward-mode\" tends to be more performant. '\n        'Otherwise, prefer to use \"reverse-mode\".'\n    )\n    if strategy == \"forward-mode\":\n        if create_graph:\n            raise NotImplementedError(\n                \"flow.autograd.functional.jacobian: `create_graph=True` \"\n                'and `strategy=\"forward-mode\"` are not supported together (yet). '\n                \"Please either set `create_graph=False` or \"\n                '`strategy=\"reverse-mode\"`.'\n            )\n        return _jacfwd(func, inputs, strict, vectorize)\n\n    with flow.enable_grad():\n        is_inputs_tuple, inputs = _as_tuple(inputs, \"inputs\", \"jacobian\")\n        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)\n\n        outputs = func(*inputs)\n        is_outputs_tuple, outputs = _as_tuple(\n            outputs, \"outputs of the user-provided function\", \"jacobian\"\n        )\n        _check_requires_grad(outputs, \"outputs\", strict=strict)\n\n        if vectorize:\n            if strict:\n                raise RuntimeError(\n                    \"flow.autograd.functional.jacobian: `strict=True` \"\n                    \"and `vectorize=True` are not supported together. \"\n                    \"Please either set `strict=False` or \"\n                    \"`vectorize=False`.\"\n                )\n            # NOTE: [Computing jacobian with vmap and grad for multiple outputs]\n            #\n            # Let's consider f(x) = (x**2, x.sum()) and let x = flow.randn(3).\n            # It turns out we can compute the jacobian of this function with a single\n            # call to autograd.grad by using vmap over the correct grad_outputs.\n            #\n            # Firstly, one way to compute the jacobian is to stack x**2 and x.sum()\n            # into a 4D vector. 
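(That is, a vector with 4 = 3 + 1 elements: the three\n            # entries of x**2 plus the scalar x.sum().)\n            # 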
E.g., use g(x) = flow.stack([x**2, x.sum()])\n            #\n            # To get the first row of the jacobian, we call\n            # >>> autograd.grad(g(x), x, grad_outputs=flow.tensor([1, 0, 0, 0]))\n            # To get the 2nd row of the jacobian, we call\n            # >>> autograd.grad(g(x), x, grad_outputs=flow.tensor([0, 1, 0, 0]))\n            # and so on.\n            #\n            # Using vmap, we can vectorize all 4 of these computations into one by\n            # passing the standard basis for R^4 as the grad_output.\n            # vmap(partial(autograd.grad, g(x), x))(flow.eye(4)).\n            #\n            # Now, how do we compute the jacobian *without stacking the output*?\n            # We can just split the standard basis across the outputs. So to\n            # compute the jacobian of f(x), we'd use\n            # >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))\n            # The grad_outputs looks like the following:\n            # ( flow.tensor([[1, 0, 0],\n            #                 [0, 1, 0],\n            #                 [0, 0, 1],\n            #                 [0, 0, 0]]),\n            #   flow.tensor([[0],\n            #                 [0],\n            #                 [0],\n            #                 [1]]) )\n            #\n            # But we're not done yet!\n            # >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...)))\n            # returns a Tensor of shape [4, 3]. We have to remember to split the\n            # jacobian of shape [4, 3] into two:\n            # - one of shape [3, 3] for the first output\n            # - one of shape [   3] for the second output\n\n            # Step 1: Construct grad_outputs by splitting the standard basis\n            output_numels = tuple(output.numel() for output in outputs)\n            grad_outputs = _construct_standard_basis_for(outputs, output_numels)\n            flat_outputs = tuple(output.reshape(-1) for output in outputs)\n\n            # Step 2: Call vmap + autograd.grad\n            def vjp(grad_output):\n                vj = list(\n                    _autograd_grad(\n                        flat_outputs, inputs, grad_output, create_graph=create_graph,\n                    )\n                )\n                for el_idx, vj_el in enumerate(vj):\n                    if vj_el is not None:\n                        continue\n                    vj[el_idx] = flow.zeros_like(inputs[el_idx]).expand(\n                        (sum(output_numels),) + inputs[el_idx].shape\n                    )\n                return tuple(vj)\n\n            jacobians_of_flat_output = vjp(grad_outputs)\n\n            # Step 3: The returned jacobian is one big tensor per input. 
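For example, with\n            # output_numels == (4, 1) and a single input of shape (3,), jac_input_i has shape\n            # [5, 3] and splits into pieces of shapes [4, 3] and [1, 3] (illustrative shapes,\n            # continuing the NOTE above). 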
In this step,\n            # we split each Tensor by output.\n            jacobian_input_output = []\n            for jac_input_i, input_i in zip(jacobians_of_flat_output, inputs):\n                jacobian_input_i_output = []\n                for jac, output_j in zip(\n                    jac_input_i.split(output_numels, dim=0), outputs\n                ):\n                    jacobian_input_i_output_j = jac.view(output_j.shape + input_i.shape)\n                    jacobian_input_i_output.append(jacobian_input_i_output_j)\n                jacobian_input_output.append(jacobian_input_i_output)\n\n            # Step 4: Right now, `jacobian` is a List[List[Tensor]].\n            # The outer List corresponds to the number of inputs,\n            # the inner List corresponds to the number of outputs.\n            # We need to exchange the order of these and convert to tuples\n            # before returning.\n            jacobian_output_input = tuple(zip(*jacobian_input_output))\n\n            jacobian_output_input = _grad_postprocess(\n                jacobian_output_input, create_graph\n            )\n            return _tuple_postprocess(\n                jacobian_output_input, (is_outputs_tuple, is_inputs_tuple)\n            )\n\n        jacobian: Tuple[flow.Tensor, ...] = tuple()\n\n        for i, out in enumerate(outputs):\n            # mypy complains that expression and variable have different types due to the empty list\n            jac_i: Tuple[List[flow.Tensor]] = tuple([] for _ in range(len(inputs)))  # type: ignore[assignment]\n            for j in range(out.nelement()):\n                vj = _autograd_grad(\n                    (out.reshape(-1)[j],),\n                    inputs,\n                    retain_graph=True,\n                    create_graph=create_graph,\n                )\n\n                for el_idx, (jac_i_el, vj_el, inp_el) in enumerate(\n                    zip(jac_i, vj, inputs)\n                ):\n                    if vj_el is not None:\n                        if strict and create_graph and not vj_el.requires_grad:\n                            msg = (\n                                \"The jacobian of the user-provided function is \"\n                                f\"independent of input {i}. This is not allowed in \"\n                                \"strict mode when create_graph=True.\"\n                            )\n                            raise RuntimeError(msg)\n                        jac_i_el.append(vj_el)\n                    else:\n                        if strict:\n                            msg = (\n                                f\"Output {i} of the user-provided function is \"\n                                f\"independent of input {el_idx}. 
This is not allowed in \"\n                                \"strict mode.\"\n                            )\n                            raise RuntimeError(msg)\n                        jac_i_el.append(flow.zeros_like(inp_el))\n\n            jacobian += (\n                tuple(\n                    flow.stack(jac_i_el, dim=0).view(\n                        out.size() + inputs[el_idx].size()  # type: ignore[operator]\n                    )\n                    for (el_idx, jac_i_el) in enumerate(jac_i)\n                ),\n            )\n\n        jacobian = _grad_postprocess(jacobian, create_graph)\n\n        return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple))\n\n\ndef hessian(\n    func,\n    inputs,\n    create_graph=False,\n    strict=False,\n    vectorize=False,\n    outer_jacobian_strategy=\"reverse-mode\",\n):\n    r\"\"\"Compute the Hessian of a given scalar function.\n\n        The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.autograd.functional.hessian.html\n\n    Args:\n        func (function): a Python function that takes Tensor inputs and returns\n            a Tensor with a single element.\n        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.\n        create_graph (bool, optional): If ``True``, the Hessian will be computed in\n            a differentiable manner. Note that when ``strict`` is ``False``, the result can not\n            require gradients or be disconnected from the inputs.\n            Defaults to ``False``.\n        strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input\n            such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the\n            hessian for said inputs, which is the expected mathematical value.\n            Defaults to ``False``.\n        vectorize (bool, optional): This feature is experimental.\n            Please consider using :func:`flow.func.hessian`\n            instead if you are looking for something less experimental and more performant.\n            When computing the hessian, usually we invoke\n            ``autograd.grad`` once per row of the hessian. If this flag is\n            ``True``, we use the vmap prototype feature as the backend to\n            vectorize calls to ``autograd.grad`` so we only invoke it once\n            instead of once per row. This should lead to performance\n            improvements in many use cases, however, due to this feature\n            being incomplete, there may be performance cliffs. Please\n            use `flow._C._debug_only_display_vmap_fallback_warnings(True)`\n            to show any performance warnings and file us issues if\n            warnings exist for your use case. Defaults to ``False``.\n        outer_jacobian_strategy (str, optional): The Hessian is computed by\n            computing the Jacobian of a Jacobian. The inner Jacobian is always\n            computed in reverse-mode AD. Setting strategy to ``\"forward-mode\"``\n            or ``\"reverse-mode\"`` determines whether the outer Jacobian will be\n            computed with forward or reverse mode AD. Currently, computing the outer\n            Jacobian in ``\"forward-mode\"`` requires ``vectorized=True``. 
Defaults\n            to ``\"reverse-mode\"``.\n\n    Returns:\n        Hessian (Tensor or a tuple of tuple of Tensors): if there is a single input,\n        this will be a single Tensor containing the Hessian for the input.\n        If it is a tuple, then the Hessian will be a tuple of tuples where\n        ``Hessian[i][j]`` will contain the Hessian of the ``i``\\th input\n        and ``j``\\th input with size the sum of the size of the ``i``\\th input plus\n        the size of the ``j``\\th input. ``Hessian[i][j]`` will have the same\n        dtype and device as the corresponding ``i``\\th input.\n\n    Example:\n\n        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)\n        >>> def pow_reducer(x):\n        ...     return x.pow(3).sum()\n        >>> inputs = flow.rand(2, 2)\n        >>> # xdoctest: +IGNORE_WANT(\"non-deterministic\")\n        >>> hessian(pow_reducer, inputs)\n        tensor([[[[5.2265, 0.0000],\n                  [0.0000, 0.0000]],\n                 [[0.0000, 4.8221],\n                  [0.0000, 0.0000]]],\n                [[[0.0000, 0.0000],\n                  [1.9456, 0.0000]],\n                 [[0.0000, 0.0000],\n                  [0.0000, 3.2550]]]])\n\n        >>> hessian(pow_reducer, inputs, create_graph=True)\n        tensor([[[[5.2265, 0.0000],\n                  [0.0000, 0.0000]],\n                 [[0.0000, 4.8221],\n                  [0.0000, 0.0000]]],\n                [[[0.0000, 0.0000],\n                  [1.9456, 0.0000]],\n                 [[0.0000, 0.0000],\n                  [0.0000, 3.2550]]]], grad_fn=<ViewBackward>)\n\n\n        >>> def pow_adder_reducer(x, y):\n        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()\n        >>> inputs = (flow.rand(2), flow.rand(2))\n        >>> hessian(pow_adder_reducer, inputs)\n        ((tensor([[4., 0.],\n                  [0., 4.]]),\n          tensor([[0., 0.],\n                  [0., 0.]])),\n         (tensor([[0., 0.],\n                  [0., 0.]]),\n          tensor([[6., 0.],\n                  [0., 6.]])))\n    \"\"\"\n    is_inputs_tuple, inputs = _as_tuple(inputs, \"inputs\", \"hessian\")\n    assert outer_jacobian_strategy in (\n        \"forward-mode\",\n        \"reverse-mode\",\n    ), 'Expected strategy to be either \"forward-mode\" or \"reverse-mode\".'\n\n    def ensure_single_output_function(*inp):\n        out = func(*inp)\n        is_out_tuple, t_out = _as_tuple(\n            out, \"outputs of the user-provided function\", \"hessian\"\n        )\n        _check_requires_grad(t_out, \"outputs\", strict=strict)\n\n        if is_out_tuple or not isinstance(out, flow.Tensor):\n            raise RuntimeError(\n                \"The function given to hessian should return a single Tensor\"\n            )\n\n        if out.nelement() != 1:\n            raise RuntimeError(\n                \"The Tensor returned by the function given to hessian should contain a single element\"\n            )\n\n        return out.squeeze()\n\n    def jac_func(*inp):\n        if outer_jacobian_strategy == \"forward-mode\":\n            # _grad_preprocess requires create_graph=True and input to require_grad\n            # or else the input will be detached\n            inp = tuple(t.requires_grad_(True) for t in inp)\n        jac = jacobian(ensure_single_output_function, inp, create_graph=True)\n        _check_requires_grad(jac, \"jacobian\", strict=strict)\n        return jac\n\n    res = jacobian(\n        jac_func,\n        inputs,\n        create_graph=create_graph,\n        strict=strict,\n        
vectorize=vectorize,\n        strategy=outer_jacobian_strategy,\n    )\n    return _tuple_postprocess(res, (is_inputs_tuple, is_inputs_tuple))\n\n\ndef vhp(func, inputs, v=None, create_graph=False, strict=False):\n    r\"\"\"Compute the dot product between vector ``v`` and the Hessian of a given scalar function at a specified point.\n\n    The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.autograd.functional.vhp.html\n\n    Args:\n        func (function): a Python function that takes Tensor inputs and returns\n            a Tensor with a single element.\n        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.\n        v (tuple of Tensors or Tensor): The vector for which the vector Hessian\n            product is computed. Must be the same size as the input of\n            ``func``. This argument is optional when ``func``'s input contains\n            a single element and (if it is not provided) will be set as a\n            Tensor containing a single ``1``.\n        create_graph (bool, optional): If ``True``, both the output and result\n            will be computed in a differentiable way. Note that when ``strict``\n            is ``False``, the result can not require gradients or be\n            disconnected from the inputs.\n            Defaults to ``False``.\n        strict (bool, optional): If ``True``, an error will be raised when we\n            detect that there exists an input such that all the outputs are\n            independent of it. If ``False``, we return a Tensor of zeros as the\n            vhp for said inputs, which is the expected mathematical value.\n            Defaults to ``False``.\n\n    Returns:\n        output (tuple): tuple with:\n            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``\n\n            vhp (tuple of Tensors or Tensor): result of the dot product with the\n            same shape as the inputs.\n\n    Example:\n\n        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)\n        >>> def pow_reducer(x):\n        ...     return x.pow(3).sum()\n        >>> inputs = flow.rand(2, 2)\n        >>> v = flow.ones(2, 2)\n        >>> # xdoctest: +IGNORE_WANT(\"non-deterministic\")\n        >>> vhp(pow_reducer, inputs, v)\n        (tensor(0.5591),\n         tensor([[1.0689, 1.2431],\n                 [3.0989, 4.4456]]))\n        >>> vhp(pow_reducer, inputs, v, create_graph=True)\n        (tensor(0.5591, grad_fn=<SumBackward0>),\n         tensor([[1.0689, 1.2431],\n                 [3.0989, 4.4456]], grad_fn=<MulBackward0>))\n        >>> def pow_adder_reducer(x, y):\n        ...     
return (2 * x.pow(2) + 3 * y.pow(2)).sum()\n        >>> inputs = (flow.rand(2), flow.rand(2))\n        >>> v = (flow.zeros(2), flow.ones(2))\n        >>> vhp(pow_adder_reducer, inputs, v)\n        (tensor(4.8053),\n         (tensor([0., 0.]),\n          tensor([6., 6.])))\n    \"\"\"\n    with flow.enable_grad():\n        is_inputs_tuple, inputs = _as_tuple(inputs, \"inputs\", \"vhp\")\n        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)\n\n        if v is not None:\n            _, v = _as_tuple(v, \"v\", \"vhp\")\n            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)\n            _validate_v(v, inputs, is_inputs_tuple)\n        else:\n            if len(inputs) != 1 or inputs[0].nelement() != 1:\n                raise RuntimeError(\n                    \"The vector v can only be None if the input to the user-provided function \"\n                    \"is a single Tensor with a single element.\"\n                )\n        outputs = func(*inputs)\n        is_outputs_tuple, outputs = _as_tuple(\n            outputs, \"outputs of the user-provided function\", \"vhp\"\n        )\n        _check_requires_grad(outputs, \"outputs\", strict=strict)\n\n        if is_outputs_tuple or not isinstance(outputs[0], flow.Tensor):\n            raise RuntimeError(\n                \"The function given to vhp should return a single Tensor\"\n            )\n\n        if outputs[0].nelement() != 1:\n            raise RuntimeError(\n                \"The Tensor returned by the function given to vhp should contain a single element\"\n            )\n\n        jac = _autograd_grad(outputs, inputs, create_graph=True)\n        _check_requires_grad(jac, \"jacobian\", strict=strict)\n\n    enable_grad = True if create_graph else flow.is_grad_enabled()\n    with flow.set_grad_enabled(enable_grad):\n        grad_res = _autograd_grad(jac, inputs, v, create_graph=create_graph)\n        vhp = _fill_in_zeros(grad_res, inputs, strict, create_graph, \"double_back\")\n\n    outputs = _grad_postprocess(outputs, create_graph)\n    vhp = _grad_postprocess(vhp, create_graph)\n\n    return (\n        _tuple_postprocess(outputs, is_outputs_tuple),\n        _tuple_postprocess(vhp, is_inputs_tuple),\n    )\n\n\ndef hvp(func, inputs, v=None, create_graph=False, strict=False):\n    r\"\"\"Compute the dot product between the scalar function's Hessian and a vector ``v`` at a specified point.\n\n    The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.autograd.functional.hvp.html\n\n    Args:\n        func (function): a Python function that takes Tensor inputs and returns\n            a Tensor with a single element.\n        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.\n        v (tuple of Tensors or Tensor): The vector for which the Hessian vector\n            product is computed. Must be the same size as the input of\n            ``func``. This argument is optional when ``func``'s input contains\n            a single element and (if it is not provided) will be set as a\n            Tensor containing a single ``1``.\n        create_graph (bool, optional): If ``True``, both the output and result will be\n            computed in a differentiable way. Note that when ``strict`` is\n            ``False``, the result can not require gradients or be disconnected\n            from the inputs.  
Defaults to ``False``.\n        strict (bool, optional): If ``True``, an error will be raised when we\n            detect that there exists an input such that all the outputs are\n            independent of it. If ``False``, we return a Tensor of zeros as the\n            hvp for said inputs, which is the expected mathematical value.\n            Defaults to ``False``.\n\n    Returns:\n        output (tuple): tuple with:\n            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``\n\n            hvp (tuple of Tensors or Tensor): result of the dot product with\n            the same shape as the inputs.\n\n    Example:\n\n        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)\n        >>> def pow_reducer(x):\n        ...     return x.pow(3).sum()\n        >>> inputs = flow.rand(2, 2)\n        >>> v = flow.ones(2, 2)\n        >>> # xdoctest: +IGNORE_WANT(\"non-deterministic\")\n        >>> hvp(pow_reducer, inputs, v)\n        (tensor(0.1448),\n         tensor([[2.0239, 1.6456],\n                 [2.4988, 1.4310]]))\n\n        >>> hvp(pow_reducer, inputs, v, create_graph=True)\n        (tensor(0.1448, grad_fn=<SumBackward0>),\n         tensor([[2.0239, 1.6456],\n                 [2.4988, 1.4310]], grad_fn=<MulBackward0>))\n\n\n        >>> def pow_adder_reducer(x, y):\n        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()\n        >>> inputs = (flow.rand(2), flow.rand(2))\n        >>> v = (flow.zeros(2), flow.ones(2))\n        >>> hvp(pow_adder_reducer, inputs, v)\n        (tensor(2.3030),\n         (tensor([0., 0.]),\n          tensor([6., 6.])))\n\n    Note:\n\n        This function is significantly slower than `vhp` due to backward mode AD constraints.\n        If your function is twice continuously differentiable, then hvp = vhp.t(). 
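(For an\n        illustrative example, take ``f = lambda x: (x ** 3).sum()``: its Hessian is diagonal,\n        hence symmetric, so ``hvp(f, x, v)[1]`` matches ``vhp(f, x, v)[1]``.) 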
So if you\n        know that your function satisfies this condition, you should use vhp instead that is\n        much faster with the current implementation.\n\n    \"\"\"\n    with flow.enable_grad():\n        is_inputs_tuple, inputs = _as_tuple(inputs, \"inputs\", \"hvp\")\n        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)\n\n        if v is not None:\n            _, v = _as_tuple(v, \"v\", \"hvp\")\n            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)\n            _validate_v(v, inputs, is_inputs_tuple)\n        else:\n            if len(inputs) != 1 or inputs[0].nelement() != 1:\n                raise RuntimeError(\n                    \"The vector v can only be None if the input to the user-provided function \"\n                    \"is a single Tensor with a single element.\"\n                )\n        outputs = func(*inputs)\n        is_outputs_tuple, outputs = _as_tuple(\n            outputs, \"outputs of the user-provided function\", \"hvp\"\n        )\n        _check_requires_grad(outputs, \"outputs\", strict=strict)\n\n        if is_outputs_tuple or not isinstance(outputs[0], flow.Tensor):\n            raise RuntimeError(\n                \"The function given to hvp should return a single Tensor\"\n            )\n\n        if outputs[0].nelement() != 1:\n            raise RuntimeError(\n                \"The Tensor returned by the function given to hvp should contain a single element\"\n            )\n\n        jac = _autograd_grad(outputs, inputs, create_graph=True)\n        _check_requires_grad(jac, \"jacobian\", strict=strict)\n\n        grad_jac = tuple(flow.zeros_like(inp, requires_grad=True) for inp in inputs)\n\n        double_back = _autograd_grad(jac, inputs, grad_jac, create_graph=True)\n        _check_requires_grad(jac, \"hessian\", strict=strict)\n\n    enable_grad = True if create_graph else flow.is_grad_enabled()\n    with flow.set_grad_enabled(enable_grad):\n        grad_res = _autograd_grad(double_back, grad_jac, v, create_graph=create_graph)\n        hvp = _fill_in_zeros(\n            grad_res, inputs, strict, create_graph, \"double_back_trick\"\n        )\n\n    outputs = _grad_postprocess(outputs, create_graph)\n    hvp = _grad_postprocess(hvp, create_graph)\n\n    return (\n        _tuple_postprocess(outputs, is_outputs_tuple),\n        _tuple_postprocess(hvp, is_inputs_tuple),\n    )\n"
  },
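The `Note` in `hvp` above says that for a twice continuously differentiable scalar function the Hessian is symmetric, so `hvp` and `vhp` should agree. A minimal sketch of that check, assuming these helpers are importable from `oneflow.autograd.functional` and that `flow.allclose` is available:

```python
import oneflow as flow
from oneflow.autograd.functional import hvp, vhp  # assumed import path

def quad(x):
    # sum(x * x) has the constant, symmetric Hessian 2 * I.
    return (x * x).sum()

x = flow.rand(3)
v = flow.ones(3)
_, hv = hvp(quad, x, v)  # H @ v, via the double-backward trick
_, vh = vhp(quad, x, v)  # v^T @ H, via one extra backward pass
print(flow.allclose(hv, vh))  # expected: True, since H is symmetric
```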
  {
    "path": "python/oneflow/autograd/graph.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# This file is mostly copied from PyTorch\n\nimport oneflow as flow\nfrom typing import Callable, Any\n\n\nclass saved_tensors_hooks:\n    \"\"\"Context-manager that sets a pair of pack / unpack hooks for saved tensors.\n\n    Use this context-manager to define how intermediary results of an operation\n    should be packed before saving, and unpacked on retrieval.\n\n    In that context, the ``pack_hook`` function will be called every time an\n    operation saves a tensor for backward (this includes intermediary results\n    saved using\n    :func:`~oneflow.autograd.function.save_for_backward` but\n    also those recorded by a OneFlow-defined operation). The output of\n    ``pack_hook`` is then stored in the computation graph instead of the\n    original tensor.\n\n    The ``unpack_hook`` is called when the saved tensor needs to be accessed,\n    namely when executing :func:`oneflow.Tensor.backward()` or\n    :func:`oneflow.autograd.grad()`. It takes as argument the *packed* object\n    returned by ``pack_hook`` and should return a tensor which has the same\n    content as the original tensor (passed as input to the corresponding\n    ``pack_hook``).\n\n    The hooks should have the following signatures:\n\n        pack_hook(tensor: Tensor) -> Any\n\n        unpack_hook(Any) -> Tensor\n\n    where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.\n\n    In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms\n    of value, size, dtype and device.\n\n    Example::\n\n        >>> def pack_hook(x):\n        ...     print(\"Packing\", x)\n        ...     return x\n        >>>\n        >>> def unpack_hook(x):\n        ...     print(\"Unpacking\", x)\n        ...     return x\n        >>>\n        >>> a = flow.ones(5, requires_grad=True)\n        >>> b = flow.ones(5, requires_grad=True) * 2\n        >>> with flow.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):\n        ...     y = a * b\n        Packing tensor([1., 1., 1., 1., 1.])\n        Packing tensor([2., 2., 2., 2., 2.])\n        >>> y.sum().backward()\n        Unpacking tensor([1., 1., 1., 1., 1.])\n        Unpacking tensor([2., 2., 2., 2., 2.])\n\n    .. warning ::\n        Performing an inplace operation on the input to either hook may lead\n        to undefined behavior.\n\n    .. warning ::\n        Only one pair of hooks is allowed at a time.
When recursively nesting this\n        context-manager, only the inner-most pair of hooks will be applied.\n    \"\"\"\n\n    def __init__(\n        self,\n        pack_hook: Callable[[\"flow.Tensor\"], Any],\n        unpack_hook: Callable[[Any], \"flow.Tensor\"],\n    ):\n        self.pack_hook = pack_hook\n        self.unpack_hook = unpack_hook\n\n    def __enter__(self):\n        flow._oneflow_internal.autograd.graph.append_new_hooks(\n            self.pack_hook, self.unpack_hook\n        )\n\n    def __exit__(self, *args: Any):\n        flow._oneflow_internal.autograd.graph.pop_hooks()\n"
  },
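A more realistic pair of hooks than the identity example above offloads saved activations to CPU during forward and restores them on demand in backward. A sketch, assuming a CUDA device is available:

```python
import oneflow as flow

def pack_to_cpu(t):
    # Called at save time: move the activation off the GPU.
    return t.to("cpu")

def unpack_to_cuda(t):
    # Called when backward needs the tensor: bring it back.
    return t.to("cuda")

a = flow.randn(4, requires_grad=True, device="cuda")
with flow.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_to_cuda):
    y = (a * a).sum()
y.backward()  # the unpack hook fires here
```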
  {
    "path": "python/oneflow/autograd/profiler.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.profiler.profiler import profile\n\n\ndef record_function():\n    raise NotImplementedError()\n\n\nclass emit_nvtx:\n    def __init__(self):\n        raise NotImplementedError()\n"
  },
  {
    "path": "python/oneflow/autoprof/__init__.py",
    "content": ""
  },
  {
    "path": "python/oneflow/autoprof/__main__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport atexit\nimport csv\nimport unittest\nimport os\nimport sys\nimport subprocess\nimport tempfile\n\nimport oneflow as flow\nimport oneflow.test_utils.automated_test_util.profiler as auto_profiler\nfrom oneflow.autoprof.util import *\n\n\ncsv_filename = os.getenv(\"ONEFLOW_PROFILE_CSV\", \"op_prof\")\n\nif csv_filename[-4:] != \".csv\":\n    csv_filename += \".csv\"\n\nf = open(csv_filename, \"w\")\n# all functions registered are called in last in, first out order\nif flow.support.env_var_util.parse_boolean_from_env(\n    \"ONEFLOW_PROFILE_PRINT_SUMMARY\", True\n):\n    atexit.register(print_summary_from_csv, csv_filename)\natexit.register(lambda f: f.close(), f)\n\nwriter = csv.writer(f)\n\nONLY_ONEFLOW = flow.support.env_var_util.parse_boolean_from_env(\n    \"ONEFLOW_PROFILE_ONLY_ONEFLOW\", False\n)\nONLY_PYTORCH = flow.support.env_var_util.parse_boolean_from_env(\n    \"ONEFLOW_PROFILE_ONLY_PYTORCH\", False\n)\nassert not (ONLY_ONEFLOW and ONLY_PYTORCH)\n\nif not ONLY_ONEFLOW and not ONLY_PYTORCH:\n    env = os.environ.copy()\n    env.update({\"ONEFLOW_PROFILE_ONLY_ONEFLOW\": \"1\"})\n    temp_f = tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".csv\", delete=False)\n    env.update({\"ONEFLOW_PROFILE_CSV\": temp_f.name})\n    env.update({\"ONEFLOW_PROFILE_PRINT_SUMMARY\": \"0\"})\n    subprocess.run([sys.executable, \"-m\", \"oneflow.autoprof\", *sys.argv[1:]], env=env)\n    temp_f.close()\n    temp_f = open(temp_f.name, \"r\")\n    rows = list(csv.reader(temp_f))\n    temp_f.close()\n    os.remove(temp_f.name)\n\n    env = os.environ.copy()\n    env.update({\"ONEFLOW_PROFILE_ONLY_PYTORCH\": \"1\"})\n    temp_f = tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".csv\", delete=False)\n    env.update({\"ONEFLOW_PROFILE_CSV\": temp_f.name})\n    env.update({\"ONEFLOW_PROFILE_PRINT_SUMMARY\": \"0\"})\n    subprocess.run([sys.executable, \"-m\", \"oneflow.autoprof\", *sys.argv[1:]], env=env)\n    temp_f.close()\n    temp_f = open(temp_f.name, \"r\")\n    rows.extend(list(csv.reader(temp_f))[1:])\n    temp_f.close()\n    os.remove(temp_f.name)\n\n    writer.writerows(rows)\n    exit(0)\n\nwriter.writerow(\n    [\n        \"OP\",\n        \"Args\",\n        \"Library\",\n        \"Kernel Time (us, GPU)\",\n        \"Kernel Bandwidth (GB/s, GPU)\",\n        \"Kernel Time (us, 1 CPU)\",\n        \"End-to-end Time (us, 1 CPU)\",\n        \"Kernel Time (us, 32 CPUs)\",\n        \"End-to-end Time (us, 32 CPUs)\",\n        \"Description\",\n    ]\n)\n\nauto_profiler.set_hardware_info_list([(\"cuda\", None), (\"cpu\", 1), (\"cpu\", 32)])\n\nif ONLY_ONEFLOW:\n    auto_profiler.profiled_framework = [\"oneflow\"]\nif ONLY_PYTORCH:\n    auto_profiler.profiled_framework = [\"pytorch\"]\n\nauto_profiler.set_profiler_hook(lambda profs: add_row(profs, writer, f))\n\n# Align with https://github.com/python/cpython/blob/3.10/Lib/unittest/__main__.py\n__unittest = True\n\nfrom 
unittest.main import main\n\nloader = unittest.TestLoader()\nloader.testMethodPrefix = \"profile_\"\n\nmain(module=None, testLoader=loader)\n"
  },
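Given the environment variables read above (`ONEFLOW_PROFILE_CSV`, `ONEFLOW_PROFILE_ONLY_ONEFLOW`, `ONEFLOW_PROFILE_PRINT_SUMMARY`), the module can also be driven from another script; the remaining argv is forwarded to `unittest`. A hypothetical launch sketch (the discovery directory name is illustrative):

```python
import os
import subprocess
import sys

# Profile only the OneFlow side, write to my_prof.csv, skip the summary table.
env = dict(
    os.environ,
    ONEFLOW_PROFILE_CSV="my_prof.csv",
    ONEFLOW_PROFILE_ONLY_ONEFLOW="1",
    ONEFLOW_PROFILE_PRINT_SUMMARY="0",
)
# Trailing arguments are standard unittest CLI flags.
subprocess.run(
    [sys.executable, "-m", "oneflow.autoprof", "discover", "-s", "profile_tests"],
    env=env,
)
```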
  {
    "path": "python/oneflow/autoprof/util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Iterable, Union, TypeVar\n\nfrom rich import box\nfrom rich.console import Console\nfrom rich.table import Table\n\nimport csv\nimport oneflow.test_utils.automated_test_util.profiler as auto_profiler\n\n\nT = TypeVar(\"T\")\n\n\ndef get_sole_value(x: Iterable[T]) -> T:\n    s = set(x)\n    assert len(s) == 1\n    return list(s)[0]\n\n\ndef get_pytorch_cpu_kernel_time(prof) -> Union[str, float]:\n    assert prof.num > 1\n    cpu_kernel_items = list(filter(lambda x: x.count >= prof.num, prof.key_averages()))\n    if len(cpu_kernel_items) == 0:\n        return \"-\"\n    kernel_cpu_time = (\n        sum(map(lambda x: x.self_cpu_time_total, cpu_kernel_items)) / prof.num\n    )\n    return round(kernel_cpu_time, 1)\n\n\ndef get_oneflow_cpu_kernel_time(prof) -> Union[str, float]:\n    assert prof.num > 1\n    cpu_kernel_items = list(filter(lambda x: x.count >= prof.num, prof.key_averages()))\n    if len(cpu_kernel_items) == 0:\n        return \"-\"\n    kernel_cpu_time = sum(map(lambda x: x.cpu_time_total, cpu_kernel_items)) / prof.num\n    return round(kernel_cpu_time, 1)\n\n\ndef get_pytorch_gpu_kernel_time(prof) -> Union[str, float]:\n    gpu_kernel_items = list(filter(lambda x: x.count >= prof.num, prof.key_averages()))\n    if len(gpu_kernel_items) == 0:\n        return \"-\"\n    kernel_gpu_time = (\n        sum(map(lambda x: x.self_cuda_time_total, gpu_kernel_items)) / prof.num\n    )\n    return round(kernel_gpu_time, 1)\n\n\ndef get_oneflow_gpu_kernel_time(prof) -> Union[str, float]:\n    gpu_kernel_items = list(\n        filter(lambda x: x.cuda_time_total is not None, prof.key_averages())\n    )\n    if len(gpu_kernel_items) == 0:\n        return \"-\"\n    kernel_gpu_time = sum(map(lambda x: x.cuda_time_total, gpu_kernel_items)) / prof.num\n    return round(kernel_gpu_time, 1)\n\n\ndef get_oneflow_gpu_kernel_bandwidth(prof) -> str:\n    gpu_kernel_items = list(\n        filter(lambda x: x.cuda_time_total is not None, prof.key_averages())\n    )\n    if len(gpu_kernel_items) == 0:\n        return \"-\"\n    if len(gpu_kernel_items) == 1:\n        return gpu_kernel_items[0].bandwidth\n    return \", \".join([f\"{x.name}: {x.bandwidth}\" for x in gpu_kernel_items])\n\n\ndef get_pytorch_cpu_end_to_end_time(prof) -> float:\n    total = get_sole_value(\n        filter(lambda x: x.key == auto_profiler.END_TO_END, prof.key_averages())\n    )\n    assert total.count == 1\n    return round(total.cpu_time / prof.num, 1)\n\n\ndef get_oneflow_cpu_end_to_end_time(prof) -> float:\n    total = list(\n        filter(lambda x: x.name == auto_profiler.END_TO_END, prof.key_averages())\n    )[0]\n    assert total.count == 1\n    return round(total.cpu_time / prof.num, 1)\n\n\ndef add_row(profs, writer, f):\n    non_none_profs = list(filter(lambda x: x is not None, profs))\n    op_name = get_sole_value([prof.op_name for prof in non_none_profs])\n    args_description = 
get_sole_value(\n        [prof.args_description for prof in non_none_profs]\n    )\n    additional_description = get_sole_value(\n        [prof.additional_description for prof in non_none_profs]\n    )\n    if \"oneflow\" in auto_profiler.profiled_framework:\n        writer.writerow(\n            [\n                op_name,\n                args_description,\n                \"OneFlow\",\n                get_oneflow_gpu_kernel_time(profs[0]),\n                get_oneflow_gpu_kernel_bandwidth(profs[0]),\n                get_oneflow_cpu_kernel_time(profs[1]),\n                get_oneflow_cpu_end_to_end_time(profs[1]),\n                get_oneflow_cpu_kernel_time(profs[2]),\n                get_oneflow_cpu_end_to_end_time(profs[2]),\n                additional_description,\n            ]\n        )\n    if \"pytorch\" in auto_profiler.profiled_framework:\n        writer.writerow(\n            [\n                op_name,\n                args_description,\n                \"PyTorch\",\n                get_pytorch_gpu_kernel_time(profs[3]),\n                \"-\",\n                get_pytorch_cpu_kernel_time(profs[4]),\n                get_pytorch_cpu_end_to_end_time(profs[4]),\n                get_pytorch_cpu_kernel_time(profs[5]),\n                get_pytorch_cpu_end_to_end_time(profs[5]),\n                additional_description,\n            ]\n        )\n    f.flush()\n\n\ndef print_summary_from_csv(filename) -> None:\n    print(\"----------------------------------------------------------------------\")\n    print(\n        'Summary (\"KT\" means \"Kernel Time\", \"ET\" means \"End-to-end Time\", in microseconds; \"BW\" means \"Bandwidth\" in GB/s):'\n    )\n    with open(filename, \"r\") as f:\n        table = Table(\n            \"OP\",\n            \"Args\",\n            \"Lib\",\n            \"KT(GPU)\",\n            \"BW(GPU)\",\n            \"KT(1 CPU)\",\n            \"ET(1 CPU)\",\n            \"KT(32 CPU)\",\n            \"ET(32 CPU)\",\n            box=box.SIMPLE,\n        )\n        for row in list(csv.reader(f))[1:]:\n            row[2] = {\"PyTorch\": \"PT\", \"OneFlow\": \"OF\"}[row[2]]\n            table.add_row(*row[:-1])\n        Console().print(table)\n"
  },
  {
    "path": "python/oneflow/backends/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom . import cuda\nfrom . import cudnn\nfrom . import mps\n"
  },
  {
    "path": "python/oneflow/backends/cuda/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport oneflow._oneflow_internal\n\n\nclass cuMatmulMode:\n    def __getattr__(self, name):\n        if name == \"allow_tf32\":\n            return oneflow._oneflow_internal.ep.is_matmul_allow_tf32()\n        elif name == \"allow_fp16_reduced_precision_reduction\":\n            return (\n                oneflow._oneflow_internal.ep.is_matmul_allow_fp16_reduced_precision_reduction()\n            )\n        raise AssertionError(\"Unknown attribute \" + name)\n\n    def __setattr__(self, name, value):\n        if name == \"allow_tf32\":\n            return oneflow._oneflow_internal.ep.set_matmul_allow_tf32(value)\n        elif name == \"allow_fp16_reduced_precision_reduction\":\n            return oneflow._oneflow_internal.ep.set_matmul_allow_fp16_reduced_precision_reduction(\n                value\n            )\n        raise AssertionError(\"Unknown attribute \" + name)\n\n\nmatmul = cuMatmulMode()\n"
  },
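Because `matmul` is a module-level `cuMatmulMode` instance, the flags are toggled by plain attribute access, mirroring `torch.backends.cuda.matmul`. For example:

```python
import oneflow as flow

# Read the current TF32 setting, then opt in for faster float32 matmuls.
print(flow.backends.cuda.matmul.allow_tf32)
flow.backends.cuda.matmul.allow_tf32 = True
flow.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
```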
  {
    "path": "python/oneflow/backends/cudnn/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.framework.config_util import (\n    api_reserved_device_mem_mbyte as set_reserved_mem_mbytes,\n)\nfrom oneflow.framework.config_util import (\n    api_enable_cudnn_fused_normalization_add_relu as enable_fused_normalization_add_relu,\n)\nfrom oneflow.framework.config_util import (\n    api_enable_cudnn_conv_heuristic_search_algo as enable_conv_heuristic_search_algo,\n)\n\nbenchmark = False\n"
  },
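The re-exported names act as global config setters. A sketch, assuming the setters take a boolean as their names suggest:

```python
import oneflow as flow

# Enable cuDNN's heuristic conv algorithm search and the fused
# normalization + add + relu kernel (assumed boolean setters).
flow.backends.cudnn.enable_conv_heuristic_search_algo(True)
flow.backends.cudnn.enable_fused_normalization_add_relu(True)
```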
  {
    "path": "python/oneflow/backends/mps/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n__all__ = [\"is_available\"]\n\n\ndef is_available() -> bool:\n    return False\n"
  },
  {
    "path": "python/oneflow/boxing/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.framework.config_util import api_enable_fusion as enable_fusion\nfrom . import nccl\n"
  },
  {
    "path": "python/oneflow/boxing/nccl/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.framework.config_util import (\n    api_nccl_fusion_threshold_mb as set_fusion_threshold_mbytes,\n    api_nccl_fusion_max_ops as set_fusion_max_ops_num,\n    api_nccl_fusion_all_reduce as allow_fuse_all_reduce,\n    api_nccl_fusion_reduce_scatter as allow_fuse_reduce_scatter,\n    api_nccl_fusion_all_gather as allow_fuse_all_gather,\n    api_nccl_fusion_reduce as allow_fuse_reduce,\n    api_nccl_fusion_broadcast as allow_fuse_broadcast,\n    api_nccl_enable_mixed_fusion as allow_fuse_mixed_ops,\n    api_nccl_fusion_all_reduce_use_buffer as enable_use_buffer_to_fuse_all_reduce,\n)\n\nfrom oneflow.framework.config_util import api_nccl_num_streams as set_stream_num\n\nfrom oneflow.framework.config_util import (\n    api_nccl_enable_all_to_all as enable_all_to_all,\n)\n\nfrom oneflow.framework.config_util import (\n    api_nccl_use_compute_stream as enable_use_compute_stream,\n)\n\nfrom oneflow.framework.config_util import (\n    api_disable_group_boxing_by_dst_parallel as disable_group_boxing_by_dst_parallel,\n)\n"
  },
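These aliases configure NCCL fusion before a global graph is built. A sketch (the threshold value is illustrative, and the setters are assumed to take an int and booleans respectively):

```python
import oneflow as flow

flow.boxing.nccl.set_fusion_threshold_mbytes(32)  # fuse small all-reduces up to 32 MB
flow.boxing.nccl.allow_fuse_all_reduce(True)      # allow all-reduce fusion
flow.boxing.nccl.enable_use_compute_stream(True)  # run NCCL on the compute stream
```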
  {
    "path": "python/oneflow/comm/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.comm.comm_ops import all_reduce\nfrom oneflow.comm.comm_ops import all_gather\nfrom oneflow.comm.comm_ops import all_gather_into_tensor\nfrom oneflow.comm.comm_ops import reduce_scatter_tensor\nfrom oneflow.comm.comm_ops import broadcast\nfrom oneflow.comm.comm_ops import scatter\nfrom oneflow.comm.comm_ops import reduce\nfrom oneflow.comm.comm_ops import all_to_all\nfrom oneflow.comm.comm_ops import barrier\nfrom oneflow.comm.comm_ops import reduce_scatter\nfrom oneflow.comm.comm_ops import gather\nfrom oneflow._C import send, recv\n"
  },
  {
    "path": "python/oneflow/comm/comm_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport oneflow as flow\nimport numpy as np\n\n\ndef all_reduce(tensor):\n    \"\"\"\n    Reduces the tensor data across all machines in such a way that all get\n    the final result.\n    After the call ``tensor`` is going to be bitwise identical in all processes.\n\n    Args:\n        tensor (Tensor): the input tensor\n\n    For example:\n\n    .. code-block:: python\n\n        >>> # We have 1 process group, 2 ranks.\n        >>> import oneflow as flow\n\n        >>> tensor = flow.tensor([[1, 2], [3, 4]], device=\"cuda\") + flow.env.get_local_rank()\n        >>> # tensor on rank0\n        >>> tensor # doctest: +ONLY_CHECK_RANK_0\n        tensor([[1, 2],\n                [3, 4]], device='cuda:0', dtype=oneflow.int64)\n        >>> # tensor on rank1\n        >>> tensor # doctest: +ONLY_CHECK_RANK_1\n        tensor([[2, 3],\n                [4, 5]], device='cuda:1', dtype=oneflow.int64)\n        >>> flow.comm.all_reduce(tensor)\n        >>> tensor.numpy()\n        array([[3, 5],\n               [7, 9]], dtype=int64)\n\n    \"\"\"\n    assert isinstance(tensor, flow._oneflow_internal.Tensor)\n    assert tensor.device.index == flow.env.get_local_rank()\n    assert tensor.is_local\n    flow._C.local_all_reduce(tensor, inplace=True)\n\n\n
def all_gather(tensor_list, tensor):\n    \"\"\"\n    Gathers tensors from the whole group in a list.\n\n    Args:\n        tensor_list (list[Tensor]): Output list. It should contain\n            correctly-sized tensors to be used for output of the collective.\n        tensor (Tensor): Tensor to be broadcast from current process.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> # We have 1 process group, 2 ranks.\n        >>> import oneflow as flow\n\n        >>> input = flow.tensor([[1, 2], [3, 4]], device=\"cuda\") + flow.env.get_local_rank()\n        >>> # input on rank0\n        >>> input # doctest: +ONLY_CHECK_RANK_0\n        tensor([[1, 2],\n                [3, 4]], device='cuda:0', dtype=oneflow.int64)\n        >>> # input on rank1\n        >>> input # doctest: +ONLY_CHECK_RANK_1\n        tensor([[2, 3],\n                [4, 5]], device='cuda:1', dtype=oneflow.int64)\n        >>> tensor_list = [flow.zeros(2, 2, dtype=flow.int64) for _ in range(2)]\n        >>> flow.comm.all_gather(tensor_list, input)\n        >>> # result on rank0\n        >>> tensor_list # doctest: +ONLY_CHECK_RANK_0\n        [tensor([[1, 2],\n                [3, 4]], device='cuda:0', dtype=oneflow.int64), tensor([[2, 3],\n                [4, 5]], device='cuda:0', dtype=oneflow.int64)]\n        >>> # result on rank1\n        >>> tensor_list # doctest: +ONLY_CHECK_RANK_1\n        [tensor([[1, 2],\n                [3, 4]], device='cuda:1', dtype=oneflow.int64), tensor([[2, 3],\n                [4, 5]], device='cuda:1', dtype=oneflow.int64)]\n\n    \"\"\"\n    assert isinstance(tensor, flow._oneflow_internal.Tensor)\n    assert isinstance(tensor_list, list)\n    assert len(tensor_list) == flow.env.get_world_size()\n    assert tensor.device.index == flow.env.get_local_rank()\n    assert tensor.is_local\n    tensor = tensor.expand(*([1] + list(tensor.shape)))\n    device_type = tensor.device.type\n    placement = flow.placement.all(device_type)\n    tensor = (\n        tensor.to_global(placement=placement, sbp=flow.sbp.split(0))\n        .to_global(placement=placement, sbp=flow.sbp.broadcast)\n        .to_local()\n    )\n    assert len(tensor_list) == flow.env.get_world_size()\n    # TODO(): getitem has bug on global tensor with size = [2, 1].\n    for i in range(tensor.shape[0]):\n        tensor_list[i] = tensor[i]\n\n\n
def all_gather_into_tensor(output_tensor, input_tensor):\n    \"\"\"\n    Gather tensors from all ranks and put them in a single output tensor.\n\n    Args:\n        output_tensor (Tensor): Output tensor to accommodate tensor elements\n            from all ranks. It must be correctly sized to have one of the\n            following forms:\n            (i) a concatenation of all the input tensors along the primary\n            dimension; for definition of \"concatenation\", see ``oneflow.cat()``;\n            (ii) a stack of all the input tensors along the primary dimension;\n            for definition of \"stack\", see ``oneflow.stack()``.\n            Examples below may better explain the supported output forms.\n        input_tensor (Tensor): Tensor to be gathered from current rank.\n            The input tensors in this API must have the same size across all ranks.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> # We have 1 process group, 2 ranks.\n        >>> # All tensors below are of flow.int64 dtype and on CUDA devices.\n        >>> import oneflow as flow\n        >>> tensor_in = flow.tensor([[1, 2, 3], [4, 5, 6]], dtype=flow.int64, device=\"cuda\") + flow.env.get_rank() * 6\n        >>> tensor_in # doctest: +ONLY_CHECK_RANK_0\n        tensor([[1, 2, 3],\n                [4, 5, 6]], device='cuda:0', dtype=oneflow.int64)\n        >>> # Output in concatenation form\n        >>> tensor_out = flow.zeros(4, 3, dtype=flow.int64, device=\"cuda\")\n        >>> flow.comm.all_gather_into_tensor(tensor_out, tensor_in)\n        >>> # result on rank0\n        >>> tensor_out # doctest: +ONLY_CHECK_RANK_0\n        tensor([[ 1,  2,  3],\n                [ 4,  5,  6],\n                [ 7,  8,  9],\n                [10, 11, 12]], device='cuda:0', dtype=oneflow.int64)\n        >>> # result on rank1\n        >>> tensor_out # doctest: +ONLY_CHECK_RANK_1\n        tensor([[ 1,  2,  3],\n                [ 4,  5,  6],\n                [ 7,  8,  9],\n                [10, 11, 12]], device='cuda:1', dtype=oneflow.int64)\n        >>> # Output in stack form\n        >>> tensor_out2 = flow.zeros(2, 3, 2, dtype=flow.int64, device=\"cuda\")\n        >>> flow.comm.all_gather_into_tensor(tensor_out2, tensor_in)\n        >>> # result on rank0\n        >>> tensor_out2 # doctest: +ONLY_CHECK_RANK_0\n        tensor([[[ 1,  2],\n                 [ 3,  4],\n                 [ 5,  6]],\n        <BLANKLINE>\n                [[ 7,  8],\n                 [ 9, 10],\n                 [11, 12]]], device='cuda:0', dtype=oneflow.int64)\n        >>> # result on rank1\n        >>> tensor_out2 # doctest: +ONLY_CHECK_RANK_1\n        tensor([[[ 1,  2],\n                 [ 3,  4],\n                 [ 5,  6]],\n        <BLANKLINE>\n                [[ 7,  8],\n                 [ 9, 10],\n                 [11, 12]]], device='cuda:1', dtype=oneflow.int64)\n\n    \"\"\"\n    assert output_tensor.is_local\n    assert input_tensor.is_local\n    flow._C.local_all_gather(output_tensor, input_tensor)\n\n\n
def broadcast(tensor, src):\n    \"\"\"\n    Broadcasts the tensor to the whole group.\n    ``tensor`` must have the same number of elements in all processes\n    participating in the collective.\n\n    Args:\n        tensor (Tensor): Data to be sent if ``src`` is the rank of current\n            process, and tensor to be used to save received data otherwise.\n        src (int): Source rank.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> # We have 1 process group, 2 ranks.\n        >>> import oneflow as flow\n        >>> tensor = flow.tensor([[1, 2], [3, 4]], device=\"cuda\") + flow.env.get_local_rank()\n        >>> # input on rank0\n        >>> tensor # doctest: +ONLY_CHECK_RANK_0\n        tensor([[1, 2],\n                [3, 4]], device='cuda:0', dtype=oneflow.int64)\n        >>> # input on rank1\n        >>> tensor # doctest: +ONLY_CHECK_RANK_1\n        tensor([[2, 3],\n                [4, 5]], device='cuda:1', dtype=oneflow.int64)\n        >>> flow.comm.broadcast(tensor, 0)\n        >>> # result on rank0\n        >>> tensor # doctest: +ONLY_CHECK_RANK_0\n        tensor([[1, 2],\n                [3, 4]], device='cuda:0', dtype=oneflow.int64)\n\n    \"\"\"\n    assert isinstance(src, int)\n    assert isinstance(tensor, flow._oneflow_internal.Tensor)\n    assert tensor.is_local\n    flow._C.comm_broadcast(tensor, src_rank=src, inplace=True)\n\n\n
def scatter(tensor, scatter_list=None, src=0):\n    \"\"\"\n    Scatters a list of tensors to all processes in a group.\n\n    Each process will receive exactly one tensor and store its data in the\n    ``tensor`` argument.\n\n    Args:\n        tensor (Tensor): Output tensor.\n        scatter_list (list[Tensor]): List of tensors to scatter (default is\n            None, must be specified on the source rank)\n        src (int): Source rank (default is 0)\n    \"\"\"\n    assert isinstance(src, int)\n    assert isinstance(tensor, flow._oneflow_internal.Tensor)\n    assert tensor.is_local\n    out_shape = tensor.shape\n    if flow.env.get_rank() == src:\n        assert isinstance(scatter_list, list)\n        assert len(scatter_list) == flow.env.get_world_size()\n        tensor.data = scatter_list[src]\n        for i in range(len(scatter_list)):\n            if i == src:\n                continue\n            assert isinstance(scatter_list[i], flow._oneflow_internal.Tensor)\n            assert scatter_list[i].is_local\n            assert (\n                scatter_list[i].shape == out_shape\n            ), f\"invalid tensor size at index {i}: {out_shape} vs {scatter_list[i].shape}\"\n            flow.comm.send(scatter_list[i], i)\n    # send/recv on the same rank is invalid\n    if flow.env.get_rank() != src:\n        flow.comm.recv(src, out=tensor)\n\n\n
def reduce(tensor, dst):\n    \"\"\"\n    Reduces the tensor data across all machines.\n\n    Only the process with rank ``dst`` is going to receive the final result.\n\n    Args:\n        tensor (Tensor): Input and output of the collective. The function\n            operates in-place.\n        dst (int): Destination rank\n\n    \"\"\"\n    assert isinstance(tensor, flow._oneflow_internal.Tensor)\n    assert tensor.is_local\n    assert isinstance(dst, int)\n    original_tensor = flow._C.identity(tensor)\n    flow.comm.all_reduce(tensor)\n    if flow.env.get_rank() != dst:\n        tensor.data = original_tensor\n\n\n
def all_to_all(output_tensor_list, input_tensor_list):\n    \"\"\"\n    Each process scatters a list of input tensors to all processes in a group and\n    returns the gathered list of tensors in the output list.\n\n    Args:\n        output_tensor_list (list[Tensor]): List of tensors to be gathered one\n            per rank.\n        input_tensor_list (list[Tensor]): List of tensors to scatter one per rank.\n\n    \"\"\"\n\n    def _check_list(tensor_list):\n        assert isinstance(tensor_list, list)\n        assert len(tensor_list) == flow.env.get_world_size()\n        shape = tensor_list[0].shape\n        dtype = tensor_list[0].dtype\n        device = tensor_list[0].device\n        for tensor in tensor_list:\n            assert isinstance(tensor, flow._oneflow_internal.Tensor)\n            assert tensor.is_local\n            assert shape == tensor.shape\n            assert dtype == tensor.dtype\n            assert device == tensor.device\n\n    _check_list(output_tensor_list)\n    _check_list(input_tensor_list)\n\n    assert input_tensor_list[0].shape == output_tensor_list[0].shape\n    assert input_tensor_list[0].dtype == output_tensor_list[0].dtype\n    assert input_tensor_list[0].device == output_tensor_list[0].device\n\n    for i in range(flow.env.get_world_size()):\n        flow.comm.scatter(\n            output_tensor_list[i],\n            input_tensor_list if i == flow.env.get_rank() else [],\n            src=i,\n        )\n\n\ndef barrier():\n    \"\"\"\n    Synchronizes all processes.\n\n    \"\"\"\n    flow._oneflow_internal.eager.ClusterSync()\n\n\n
def reduce_scatter(output, input_list):\n    \"\"\"\n    Reduces, then scatters a list of tensors to all processes in a group.\n\n    Args:\n        output (Tensor): Output tensor.\n        input_list (list[Tensor]): List of tensors to reduce and scatter.\n\n    \"\"\"\n    assert isinstance(output, flow._oneflow_internal.Tensor)\n    assert output.is_local\n    assert isinstance(input_list, list)\n    assert len(input_list) == flow.env.get_world_size()\n    output_shape = output.shape\n    device_type = output.device.type\n    placement = flow.placement.all(device_type)\n    reduced_tensor_list = []\n    for tensor in input_list:\n        assert tensor.is_local\n        assert tensor.shape == output_shape\n        tensor = tensor.to_global(\n            placement=placement, sbp=flow.sbp.partial_sum\n        ).to_global(placement=placement, sbp=flow.sbp.broadcast)\n        reduced_tensor_list.append(tensor.to_local())\n    output.data = reduced_tensor_list[flow.env.get_rank()]\n\n\n
def reduce_scatter_tensor(output_tensor, input_tensor):\n    \"\"\"\n    Reduces, then scatters a tensor to all ranks.\n\n    Args:\n        output_tensor (Tensor): Output tensor. It should have the same size across all\n            ranks.\n        input_tensor (Tensor): Input tensor to be reduced and scattered. Its size\n            should be output tensor size times the world size. The input tensor\n            can have one of the following shapes:\n            (i) a concatenation of the output tensors along the primary\n            dimension, or\n            (ii) a stack of the output tensors along the primary dimension.\n            For definition of \"concatenation\", see ``oneflow.cat()``.\n            For definition of \"stack\", see ``oneflow.stack()``.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> # We have 1 process group, 2 ranks.\n        >>> # All tensors below are of flow.int64 dtype and on CUDA devices.\n        >>> import oneflow as flow\n        >>> tensor_in = flow.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=flow.int64, device=\"cuda\")\n        >>> tensor_in # doctest: +ONLY_CHECK_RANK_0\n        tensor([[ 1,  2,  3],\n                [ 4,  5,  6],\n                [ 7,  8,  9],\n                [10, 11, 12]], device='cuda:0', dtype=oneflow.int64)\n        >>> # Output in concatenation form\n        >>> tensor_out = flow.zeros(2, 3, dtype=flow.int64, device=\"cuda\")\n        >>> flow.comm.reduce_scatter_tensor(tensor_out, tensor_in)\n        >>> # result on rank0\n        >>> tensor_out # doctest: +ONLY_CHECK_RANK_0\n        tensor([[ 2,  4,  6],\n                [ 8, 10, 12]], device='cuda:0', dtype=oneflow.int64)\n        >>> # result on rank1\n        >>> tensor_out # doctest: +ONLY_CHECK_RANK_1\n        tensor([[14, 16, 18],\n                [20, 22, 24]], device='cuda:1', dtype=oneflow.int64)\n        >>> # Output in stack form\n        >>> tensor_in2 = tensor_in.reshape(2, 3, 2)\n        >>> tensor_out2 = flow.zeros(2, 3, dtype=flow.int64, device=\"cuda\")\n        >>> flow.comm.reduce_scatter_tensor(tensor_out2, tensor_in2)\n        >>> # result on rank0\n        >>> tensor_out2 # doctest: +ONLY_CHECK_RANK_0\n        tensor([[ 2,  4,  6],\n                [ 8, 10, 12]], device='cuda:0', dtype=oneflow.int64)\n        >>> # result on rank1\n        >>> tensor_out2 # doctest: +ONLY_CHECK_RANK_1\n        tensor([[14, 16, 18],\n                [20, 22, 24]], device='cuda:1', dtype=oneflow.int64)\n\n    \"\"\"\n    assert output_tensor.is_local\n    assert input_tensor.is_local\n    flow._C.local_reduce_scatter(output_tensor, input_tensor)\n\n\n
def gather(tensor, gather_list=None, dst=0):\n    \"\"\"\n    Gathers a list of tensors in a single process.\n\n    Args:\n        tensor (Tensor): Input tensor.\n        gather_list (list[Tensor], optional): List of appropriately-sized\n            tensors to use for gathered data (default is None, must be specified\n            on the destination rank)\n        dst (int, optional): Destination rank (default is 0)\n\n    \"\"\"\n    assert isinstance(tensor, flow._oneflow_internal.Tensor)\n    assert tensor.is_local\n    shape = tensor.shape\n    dtype = tensor.dtype\n    tensor = tensor.expand(*([1] + list(shape)))\n    device_type = tensor.device.type\n    placement = flow.placement.all(device_type)\n    tensor = tensor.to_global(placement=placement, sbp=flow.sbp.split(0)).to_global(\n        placement=placement, sbp=flow.sbp.broadcast\n    )\n\n    if gather_list is None:\n        gather_list = [\n            flow.empty(shape, dtype=dtype) for _ in range(flow.env.get_world_size())\n        ]\n\n    assert gather_list is not None\n    assert isinstance(gather_list, list)\n    assert len(gather_list) == flow.env.get_world_size()\n    for i in range(tensor.shape[0]):\n        gather_list[i] = tensor[i].to_local()\n"
  },
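`scatter` and `gather` above have no doctest examples; a hypothetical two-rank sketch (launched with something like `python3 -m oneflow.distributed.launch --nproc_per_node 2 script.py`):

```python
import oneflow as flow

rank = flow.env.get_rank()
out = flow.zeros(2, dtype=flow.int64, device="cuda")

# Rank 0 supplies one chunk per rank; the other ranks only receive.
chunks = [flow.tensor([0, 1], device="cuda"), flow.tensor([2, 3], device="cuda")]
flow.comm.scatter(out, chunks if rank == 0 else None, src=0)
# out is now [0, 1] on rank 0 and [2, 3] on rank 1.

# Collect the chunks back; note that this implementation fills
# gather_list on every rank, not only on dst.
gathered = [flow.zeros(2, dtype=flow.int64) for _ in range(2)]
flow.comm.gather(out, gathered, dst=0)
```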
  {
    "path": "python/oneflow/cuda/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\n\nfrom oneflow.cuda.type_tensor import *\nfrom oneflow.cuda._utils import _get_device_index\n\nfrom typing import Optional, Tuple, Union, Any\n\ndefault_generators = flow._oneflow_internal.default_generators()\n\n\ndef is_available() -> bool:\n    r\"\"\"Returns a bool indicating if CUDA is currently available.\"\"\"\n    # This function never throws and returns 0 if driver is missing or can't\n    # be initialized\n    return device_count() > 0\n\n\ndef device_count() -> int:\n    r\"\"\"Returns the number of GPUs available.\"\"\"\n    return flow._oneflow_internal.CudaGetDeviceCount()\n\n\ndef current_device() -> int:\n    r\"\"\"Returns local rank as device index.\"\"\"\n    return flow._oneflow_internal.GetCudaDeviceIndex()\n\n\ndef get_device_properties(device: Union[flow.device, str, int] = None):\n    r\"\"\"Gets the properties of a device.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.cuda.get_device_properties.html.\n\n    Args:\n        device(oneflow.device or str or int): device for which to return the properties of the device.\n\n    Returns:\n        the properties of the device.\n    \"\"\"\n    device = _get_device_index(device, optional=True)\n    return flow._oneflow_internal._get_device_properties(device)\n\n\ndef get_device_capability(\n    device: Optional[Union[flow.device, str, int]] = None\n) -> Tuple[int, int]:\n    r\"\"\"Gets the cuda capability of a device.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.cuda.get_device_capability.html.\n\n    Args:\n        device (oneflow.device or int or str, optional): device for which to return the\n            device capability. It uses the current device, given by\n            :func:`~oneflow.cuda.current_device`, if :attr:`device` is ``None``\n            (default).\n\n    Returns:\n        tuple(int, int): the major and minor cuda capability of the device\n    \"\"\"\n    device_prop = get_device_properties(device)\n    return device_prop.major, device_prop.minor\n\n\ndef get_device_name(device: Optional[Union[flow.device, str, int]] = None) -> str:\n    r\"\"\"Gets the name of a device.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.cuda.get_device_name.html.\n\n    Args:\n        device (oneflow.device or int or str, optional): device for which to return the\n            name. 
It uses the current device, given by :func:`~oneflow.cuda.current_device`,\n            if :attr:`device` is ``None`` (default).\n\n    Returns:\n        str: the name of the device\n    \"\"\"\n    return get_device_properties(device).name\n\n\ndef manual_seed_all(seed) -> None:\n    r\"\"\"Sets the seed for generating random numbers on all GPUs.\n    \n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.cuda.manual_seed_all.html.\n\n    It's safe to call this function if CUDA is not available; in that\n    case, it is silently ignored.\n\n    Args:\n        seed (int): The desired seed.\n    \"\"\"\n    seed = int(seed)\n    flow._oneflow_internal.ManualSeedAllCudaGenerator(seed)\n\n\ndef manual_seed(seed: int) -> None:\n    r\"\"\"Sets the seed for generating random numbers for the current GPU.\n    \n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.cuda.manual_seed.html.\n\n    It's safe to call this function if CUDA is not available; in that\n    case, it is silently ignored.\n\n    Args:\n        seed (int): The desired seed.\n\n    .. warning::\n        If you are working with a multi-GPU model, this function is insufficient\n        to get determinism.  To seed all GPUs, use :func:`manual_seed_all`.\n    \"\"\"\n    seed = int(seed)\n    idx = current_device()\n    flow._oneflow_internal.manual_seed(seed, \"cuda\", idx)\n\n\ndef set_device(device: Union[flow.device, str, int]) -> None:\n    r\"\"\"Sets the current device.\n    \n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.cuda.set_device.html.\n\n    Usage of this function is discouraged in favor of :attr:`device`. In most\n    cases it's better to use ``CUDA_VISIBLE_DEVICES`` environmental variable.\n\n    Args:\n        device (flow.device or int): selected device. 
This function is a no-op\n            if this argument is negative.\n    \"\"\"\n    device_idx = _get_device_index(device)\n    if device_idx < 0:\n        return\n    if flow.env.get_world_size() > 0:\n        if device_idx == flow.env.get_local_rank():\n            return\n        raise ValueError(\n            \"Setting cuda device to a device whose index does not equal the local rank is not supported.\"\n        )\n    flow._oneflow_internal.SetCudaDeviceIndex(device_idx)\n\n\ndef synchronize(device: Union[flow.device, str, int, None] = None) -> None:\n    r\"\"\"\n\n    Waits for all kernels in all streams on a CUDA device to complete.\n\n    Note:\n        In the eager mode of oneflow, all operations are converted\n        into instructions executed in the virtual machine,\n        so in order to comply with the semantics of synchronization,\n        this function calls `eager.Sync()` before the device is synchronized,\n        which may affect the operations executed on other devices.\n\n    Args:\n        device (flow.device or int, optional): device for which to synchronize.\n            It uses the current device, given by :func:`~oneflow.cuda.current_device`,\n            if :attr:`device` is ``None`` (default).\n    \"\"\"\n    device_idx = _get_device_index(device, optional=True)\n    if device_idx >= 0:\n        flow._oneflow_internal.eager.Sync()\n        flow._oneflow_internal.CudaSynchronize(device_idx)\n\n\ndef empty_cache() -> None:\n    r\"\"\"\n\n    Releases all unoccupied cached memory currently held by the caching\n    allocators of all OneFlow streams, so that it can be re-allocated in OneFlow\n    streams or by other GPU applications, and becomes visible in `nvidia-smi`.\n\n    Note:\n            :func:`~flow.cuda.empty_cache` may enable one stream to release memory\n            and then freed memory can be used by another stream. It may also help reduce\n            fragmentation of GPU memory in certain cases.\n\n    \"\"\"\n    return flow._oneflow_internal.EmptyCache()\n\n\ndef mem_get_info(device: Any = None) -> Tuple[int, int]:\n    r\"\"\"Returns the global free and total GPU memory for a given\n    device using cudaMemGetInfo.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/stable/generated/torch.cuda.mem_get_info.html\n\n    Args:\n        device (flow.device or int, optional): selected device. Returns\n            statistic for the current device, given by :func:`~flow.cuda.current_device`,\n            if :attr:`device` is ``None`` (default).\n    \"\"\"\n    if device is None:\n        device = current_device()\n    device = _get_device_index(device)\n    return flow._oneflow_internal.CudaMemGetInfo(device)\n\n\nfrom .random import *  # noqa: F403\n\n\nclass Event:\n    def __init__(self):\n        raise NotImplementedError()\n"
  },
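Tying the query functions above together, a small capability probe, assuming at least one visible GPU:

```python
import oneflow as flow

if flow.cuda.is_available():
    idx = flow.cuda.current_device()
    print(flow.cuda.get_device_name(idx))        # e.g. the GPU model string
    print(flow.cuda.get_device_capability(idx))  # (major, minor)
    free, total = flow.cuda.mem_get_info(idx)
    print(f"{free} / {total} bytes free")
    flow.cuda.manual_seed_all(0)                 # seed every GPU generator
    flow.cuda.synchronize(idx)                   # wait for outstanding kernels
```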
  {
    "path": "python/oneflow/cuda/_utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom typing import Any, Optional\n\n\ndef _get_current_device_index() -> int:\n    r\"\"\"Checks if there are CUDA devices available and\n    returns the device index of the current default CUDA device.\n    Returns -1 in case there are no CUDA devices available.\n\n    Arguments: ``None``\n    \"\"\"\n    if flow.cuda.is_available():\n        return flow.cuda.current_device()\n    return -1\n\n\ndef _get_device_index(\n    device: Any, optional: bool = False, allow_cpu: bool = False\n) -> int:\n    r\"\"\"Gets the device index from :attr:`device`, which can be a flow.device\n    object, a Python integer, or ``None``.\n\n    If :attr:`device` is a flow.device object, returns the device index if it\n    is a CUDA device. Note that for a CUDA device without a specified index,\n    i.e., ``flow.device('cuda')``, this will return the current default CUDA\n    device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,\n    CPU devices will be accepted and ``-1`` will be returned in this case.\n\n    If :attr:`device` is a Python integer, it is returned as is.\n\n    If :attr:`device` is ``None``, this will return the current default CUDA\n    device if :attr:`optional` is ``True``.\n    \"\"\"\n    device_idx: Optional[int] = None\n    if isinstance(device, str):\n        device = flow.device(device)\n    if isinstance(device, flow.device):\n        if allow_cpu:\n            if device.type not in [\"cuda\", \"cpu\"]:\n                raise ValueError(\n                    \"Expected a cuda or cpu device, but got: {}\".format(device)\n                )\n        elif device.type != \"cuda\":\n            raise ValueError(\"Expected a cuda device, but got: {}\".format(device))\n        device_idx = -1 if device.type == \"cpu\" else device.index\n    if isinstance(device, int):\n        device_idx = device\n    if device_idx is None:\n        if optional:\n            device_idx = _get_current_device_index()\n        else:\n            raise ValueError(\n                \"Expected a flow.device with a specified index \"\n                \"or an integer, but got: {}\".format(device)\n            )\n    return device_idx\n"
  },
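The resolution rules of `_get_device_index` are easiest to see on concrete inputs. A sketch of the expected mappings (this is a private helper, so the import path is an implementation detail):

```python
import oneflow as flow
from oneflow.cuda._utils import _get_device_index

print(_get_device_index(1))                                   # 1: ints pass through
print(_get_device_index("cuda:0"))                            # 0: parsed via flow.device
print(_get_device_index(flow.device("cpu"), allow_cpu=True))  # -1 for CPU devices
print(_get_device_index(None, optional=True))                 # current device, or -1
```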
  {
    "path": "python/oneflow/cuda/amp/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom .autocast_mode import autocast\nfrom oneflow.amp import GradScaler\n"
  },
  {
    "path": "python/oneflow/cuda/amp/autocast_mode.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom typing import Any, Optional\n\n\nclass autocast(flow.amp.autocast_mode.autocast):\n    r\"\"\"\n    See :class:`oneflow.autocast`.\n    ``oneflow.cuda.amp.autocast(args...)`` is equivalent to ``oneflow.autocast(\"cuda\", args...)``\n    \"\"\"\n\n    def __init__(\n        self,\n        enabled: bool = True,\n        dtype: Optional[flow.dtype] = None,\n        cache_enabled: Optional[bool] = None,\n    ):\n        super().__init__(\n            \"cuda\", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled\n        )\n\n    def __enter__(self):\n        return super().__enter__()\n\n    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]\n        return super().__exit__(exc_type, exc_val, exc_tb)\n\n    def __call__(self, func):\n        return super().__call__(func)\n"
  },
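Since this class just forwards to `oneflow.autocast("cuda", ...)`, usage matches the generic autocast context manager. A sketch, assuming a CUDA device:

```python
import oneflow as flow

x = flow.randn(8, 8, device="cuda")
w = flow.randn(8, 8, device="cuda")
with flow.cuda.amp.autocast():
    y = flow.matmul(x, w)  # eligible ops run in the autocast dtype
```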
  {
    "path": "python/oneflow/cuda/random.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow import Tensor\nfrom typing import cast, Iterable, List, Union\nfrom . import current_device, device_count\n\n\ndef get_rng_state(device: Union[int, str, flow.device] = \"cuda\") -> Tensor:\n    r\"\"\"Returns the random number generator state of the specified GPU as a ByteTensor.\n\n    Args:\n        device (flow.device or int, optional): The device to return the RNG state of.\n            Default: ``'cuda'`` (i.e., ``flow.device('cuda')``, the current CUDA device).\n    \"\"\"\n    # TODO (add lazy initialization mechanism in OneFlow)\n    # _lazy_init()\n    if isinstance(device, str):\n        device = flow.device(device)\n    elif isinstance(device, int):\n        device = flow.device(\"cuda\", device)\n    idx = device.index\n    if idx is None:\n        idx = current_device()\n    default_generator = flow.cuda.default_generators[idx]\n    return default_generator.get_state()\n\n\ndef get_rng_state_all() -> List[Tensor]:\n    r\"\"\"Returns a list of ByteTensor representing the random number states of all devices.\"\"\"\n\n    results = []\n    for i in range(device_count()):\n        results.append(get_rng_state(i))\n    return results\n\n\ndef set_rng_state(\n    new_state: Tensor, device: Union[int, str, flow.device] = \"cuda\"\n) -> None:\n    r\"\"\"Sets the random number generator state of the specified GPU.\n\n    Args:\n        new_state (flow.ByteTensor): The desired state\n        device (flow.device or int, optional): The device to set the RNG state.\n            Default: ``'cuda'`` (i.e., ``flow.device('cuda')``, the current CUDA device).\n    \"\"\"\n    new_state_copy = new_state.clone()\n    if isinstance(device, str):\n        device = flow.device(device)\n    elif isinstance(device, int):\n        device = flow.device(\"cuda\", device)\n\n    if device.type == \"cpu\":\n        raise ValueError(\n            \"Cannot set RNG state for CPU device in flow.cuda.set_rng_state func!\"\n        )\n    idx = cast(flow.device, device).index\n    if idx is None:\n        idx = current_device()\n    default_generator = flow.cuda.default_generators[idx]\n    default_generator.set_state(new_state_copy)\n\n\ndef set_rng_state_all(new_states: Iterable[Tensor]) -> None:\n    r\"\"\"Sets the random number generator state of all devices.\n\n    Args:\n        new_states (Iterable of flow.ByteTensor): The desired state for each device\"\"\"\n    for i, state in enumerate(new_states):\n        set_rng_state(state, i)\n"
  },
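A minimal save/restore sketch for the two functions above (assumes at least one CUDA device and that they are re-exported under ``oneflow.cuda``)::

    import oneflow as flow

    state = flow.cuda.get_rng_state()  # ByteTensor snapshot of the generator
    a = flow.randn(3, device="cuda")
    flow.cuda.set_rng_state(state)     # rewind the generator
    b = flow.randn(3, device="cuda")
    # a and b should now be element-wise identical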
  {
    "path": "python/oneflow/cuda/type_tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nimport oneflow as flow\nfrom oneflow._C import cuda\n\nHalfTensor = cuda.HalfTensor\nFloatTensor = cuda.FloatTensor\nDoubleTensor = cuda.DoubleTensor\nBoolTensor = cuda.BoolTensor\nByteTensor = cuda.ByteTensor\nCharTensor = cuda.CharTensor\nIntTensor = cuda.IntTensor\nLongTensor = cuda.LongTensor\nComplexFloatTensor = cuda.ComplexFloatTensor\nComplexDoubleTensor = cuda.ComplexDoubleTensor\n\n\n__all__ = [\n    \"HalfTensor\",\n    \"FloatTensor\",\n    \"DoubleTensor\",\n    \"BoolTensor\",\n    \"ByteTensor\",\n    \"CharTensor\",\n    \"IntTensor\",\n    \"LongTensor\",\n    \"ComplexFloatTensor\",\n    \"ComplexDoubleTensor\",\n    # TODO: Add support for BFloat16Tensor, ComplexHalfTensor\n]\n"
  },
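A short sketch of the aliases above, assuming they behave like their ``torch.cuda.*Tensor`` counterparts (i.e., construction allocates on the current CUDA device)::

    import oneflow as flow

    t = flow.cuda.FloatTensor([1.0, 2.0, 3.0])
    print(t.device, t.dtype)  # expected: a cuda device and oneflow.float32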
  {
    "path": "python/oneflow/data.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.experimental.load_mnist import load_mnist\n"
  },
  {
    "path": "python/oneflow/distributed/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# Just for alignment with pytorch, not really useful\nfrom .constants import default_pg_timeout\n\nfrom typing import List, Optional\n\nimport oneflow as flow\n\n\nclass ReduceOp:\n    \"\"\"Reduce operation enum. Mainly for PyTorch compatibility.\n    Currently only support SUM.\n\n    See also :func:`oneflow.comm.all_reduce()`\n    \"\"\"\n\n    SUM = \"sum\"\n\n\ndef is_initialized() -> bool:\n    \"\"\"Always returns True. This function is only for PyTorch compatibility.\n\n    Returns:\n        True\n    \"\"\"\n    return True\n\n\n# PyTorch doesn't have torch.distributed.get_local_rank,\n# we add it for the consistency between flow.env and flow.distributed\nget_local_rank = flow.env.get_local_rank\n\n\ndef get_rank(group=None) -> int:\n    \"\"\"Alias of `oneflow.env.get_rank()` for PyTorch compatibility.\n\n    See also :func:`oneflow.env.get_rank()`\n    \"\"\"\n    assert group is None, \"group is not supported yet\"\n    return flow.env.get_rank()\n\n\ndef get_world_size(group=None) -> int:\n    \"\"\"Alias of `oneflow.env.get_world_size()` for PyTorch compatibility.\n\n    See also :func:`oneflow.env.get_world_size()`\n    \"\"\"\n    assert group is None, \"group is not supported yet\"\n    return flow.env.get_world_size()\n\n\ndef send(tensor: flow.Tensor, dst: int, group=None, tag: int = 0) -> None:\n    \"\"\"Alias of `oneflow.comm.send()` for PyTorch compatibility.\n\n    See also :func:`oneflow.comm.send()`\n    \"\"\"\n    assert group is None, \"group is not supported yet\"\n    assert tag == 0, \"tag is not supported yet\"\n    return flow.comm.send(tensor, dst)\n\n\ndef recv(tensor: flow.Tensor, src: int, group=None, tag: int = 0) -> None:\n    \"\"\"Alias of `oneflow.comm.recv()` for PyTorch compatibility.\n\n    See also :func:`oneflow.comm.recv()`\n    \"\"\"\n    assert group is None, \"group is not supported yet\"\n    assert tag == 0, \"tag is not supported yet\"\n    return flow.comm.recv(tensor, src)\n\n\ndef broadcast(\n    tensor: flow.Tensor, src: int, group=None, async_op: bool = False\n) -> None:\n    \"\"\"Alias of `oneflow.comm.broadcast()` for PyTorch compatibility.\n\n    See also :func:`oneflow.comm.broadcast()`\n    \"\"\"\n    assert group is None, \"group is not supported yet\"\n    assert async_op is False, \"async_op is not supported yet\"\n    return flow.comm.broadcast(tensor, src)\n\n\ndef barrier(group=None, async_op=False, device_ids=None) -> None:\n    \"\"\"Alias of `oneflow.comm.barrier()` for PyTorch compatibility.\n\n    See also :func:`oneflow.comm.barrier()`\n    \"\"\"\n    assert group is None, \"group is not supported yet\"\n    assert async_op is False, \"async_op is not supported yet\"\n    assert device_ids is None, \"device_ids is not supported yet\"\n    return flow.comm.barrier()\n\n\ndef all_reduce(\n    tensor: flow.Tensor, op: ReduceOp, group=None, async_op: bool = False\n) -> None:\n    \"\"\"Alias of 
`oneflow.comm.all_reduce()` for PyTorch compatibility.\n\n    See also :func:`oneflow.comm.all_reduce()`\n    \"\"\"\n    assert op == ReduceOp.SUM, \"only ReduceOp.SUM is supported\"\n    assert group is None, \"group is not supported yet\"\n    assert async_op is False, \"async_op is not supported yet\"\n    return flow.comm.all_reduce(tensor)\n\n\ndef all_gather(\n    tensor_list: List[flow.Tensor],\n    tensor: flow.Tensor,\n    group=None,\n    async_op: bool = False,\n) -> None:\n    \"\"\"Alias of `oneflow.comm.all_gather()` for PyTorch compatibility.\n\n    See also :func:`oneflow.comm.all_gather()`\n    \"\"\"\n    assert group is None, \"group is not supported yet\"\n    assert async_op is False, \"async_op is not supported yet\"\n    return flow.comm.all_gather(tensor_list, tensor)\n\n\ndef reduce(\n    tensor: flow.Tensor, dst: int, op: ReduceOp, group=None, async_op: bool = False\n) -> None:\n    \"\"\"Alias of `oneflow.comm.reduce()` for PyTorch compatibility.\n\n    See also :func:`oneflow.comm.reduce()`\n    \"\"\"\n    assert op == ReduceOp.SUM, \"only ReduceOp.SUM is supported\"\n    assert group is None, \"group is not supported yet\"\n    assert async_op is False, \"async_op is not supported yet\"\n    return flow.comm.reduce(tensor, dst)\n\n\ndef all_to_all(\n    output_tensor_list: List[flow.Tensor],\n    input_tensor_list: List[flow.Tensor],\n    group=None,\n    async_op: bool = False,\n) -> None:\n    \"\"\"Alias of `oneflow.comm.all_to_all()` for PyTorch compatibility.\n\n    See also :func:`oneflow.comm.all_to_all()`\n    \"\"\"\n    assert group is None, \"group is not supported yet\"\n    assert async_op is False, \"async_op is not supported yet\"\n    return flow.comm.all_to_all(output_tensor_list, input_tensor_list)\n\n\ndef reduce_scatter(\n    output: flow.Tensor,\n    input_list: List[flow.Tensor],\n    op: ReduceOp,\n    group=None,\n    async_op: bool = False,\n) -> None:\n    \"\"\"Alias of `oneflow.comm.reduce_scatter()` for PyTorch compatibility.\n\n    See also :func:`oneflow.comm.reduce_scatter()`\n    \"\"\"\n    assert op == ReduceOp.SUM, \"only ReduceOp.SUM is supported\"\n    assert group is None, \"group is not supported yet\"\n    assert async_op is False, \"async_op is not supported yet\"\n    return flow.comm.reduce_scatter(output, input_list)\n\n\ndef gather(\n    tensor: flow.Tensor,\n    gather_list: Optional[List[flow.Tensor]] = None,\n    dst: int = 0,\n    group=None,\n    async_op: bool = False,\n) -> None:\n    \"\"\"Alias of `oneflow.comm.gather()` for PyTorch compatibility.\n\n    See also :func:`oneflow.comm.gather()`\n    \"\"\"\n    assert group is None, \"group is not supported yet\"\n    assert async_op is False, \"async_op is not supported yet\"\n    return flow.comm.gather(tensor, gather_list, dst)\n\n\ndef is_available():\n    return True\n"
  },
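A minimal sketch of the compatibility layer above. It must run under a multi-process launcher (for example ``python -m oneflow.distributed.launch --nproc_per_node 2 script.py``) so that ranks exist; note that only ``group=None``, ``async_op=False``, and ``ReduceOp.SUM`` are accepted::

    import oneflow as flow
    import oneflow.distributed as dist

    t = flow.ones(2) * dist.get_rank()
    dist.all_reduce(t, dist.ReduceOp.SUM)  # sums the tensor across all ranks
    print(f"rank {dist.get_rank()}/{dist.get_world_size()}:", t)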
  {
    "path": "python/oneflow/distributed/constants.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# Just for alignment with pytorch, not really useful\n\nfrom datetime import timedelta\n\ndefault_pg_timeout = timedelta(milliseconds=30 * 60 * 1000)\n"
  },
  {
    "path": "python/oneflow/distributed/launch.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\"\"\"\nThis file is mostly copied from PyTorch v1.8.1 torch/distributed/launch.py\n\"\"\"\nimport os\nimport signal\nimport subprocess\nimport sys\nimport time\nfrom argparse import REMAINDER, ArgumentParser\nfrom typing import IO, Any, List, Optional\n\nstdout_filename = \"stdout\"\nstderr_filename = \"stderr\"\n\n\ndef parse_args():\n    \"\"\"\n    Helper function parsing the command line options\n    @retval ArgumentParser\n    \"\"\"\n    parser = ArgumentParser(\n        description=\"OneFlow distributed training launch helper utility that will spawn up multiple distributed processes\"\n    )\n    parser.add_argument(\n        \"--nnodes\",\n        type=int,\n        default=1,\n        help=\"The number of nodes to use for distributed training\",\n    )\n    parser.add_argument(\n        \"--node_rank\",\n        type=int,\n        default=0,\n        help=\"The rank of the node for multi-node distributed training\",\n    )\n    parser.add_argument(\n        \"--nproc_per_node\",\n        type=int,\n        default=1,\n        help=\"The number of processes to launch on each node, for GPU training, this is recommended to be set to the number of GPUs in your system so that each process can be bound to a single GPU.\",\n    )\n    parser.add_argument(\n        \"--master_addr\",\n        default=\"127.0.0.1\",\n        type=str,\n        help=\"Master node (rank 0)'s address, should be either the IP address or the hostname of node 0, for single node multi-proc training, the --master_addr can simply be 127.0.0.1\",\n    )\n    parser.add_argument(\n        \"--master_port\",\n        default=29500,\n        type=int,\n        help=\"Master node (rank 0)'s free port that needs to be used for communication during distributed training\",\n    )\n    parser.add_argument(\n        \"-m\",\n        \"--module\",\n        default=False,\n        action=\"store_true\",\n        help=\"Changes each process to interpret the launch script as a python module, executing with the same behavior as'python -m'.\",\n    )\n    parser.add_argument(\n        \"--no_python\",\n        default=False,\n        action=\"store_true\",\n        help='Do not prepend the training script with \"python\" - just exec it directly. Useful when the script is not a Python script.',\n    )\n    parser.add_argument(\n        \"--redirect_stdout_and_stderr\",\n        default=False,\n        action=\"store_true\",\n        help=f\"write the stdout and stderr to files\\n                    '{stdout_filename}' and '{stderr_filename}' in logdir.\",\n    )\n    parser.add_argument(\n        \"--logdir\",\n        default=\"log\",\n        type=str,\n        help=f\"Relative path to write subprocess logs to. Passing in a relative\\n        path will create a directory if needed. 
Note that\\n        successive runs with the same path to write logs to will overwrite existing logs,\\n        so be sure to save logs as needed.\",\n    )\n    parser.add_argument(\n        \"training_script\",\n        type=str,\n        help=\"The full path to the single GPU training program/script to be launched in parallel, followed by all the arguments for the training script\",\n    )\n    parser.add_argument(\"training_script_args\", nargs=REMAINDER)\n    return parser.parse_args()\n\n\ndef main():\n    args = parse_args()\n    dist_world_size = args.nproc_per_node * args.nnodes\n    current_env = os.environ.copy()\n    current_env[\"MASTER_ADDR\"] = args.master_addr\n    current_env[\"MASTER_PORT\"] = str(args.master_port)\n    current_env[\"WORLD_SIZE\"] = str(dist_world_size)\n\n    if args.master_port is None or args.master_port >= 2 ** 16:\n        raise ValueError(\n            f\"The port number of the master endpoint '{args.master_addr}:{args.master_port}' must be an integer \"\n            \"between 0 and 65535.\"\n        )\n\n    if \"OMP_NUM_THREADS\" not in os.environ and args.nproc_per_node > 1:\n        current_env[\"OMP_NUM_THREADS\"] = str(1)\n        print(\n            \"*****************************************\\n\"\n            \"Setting OMP_NUM_THREADS environment variable for each process \"\n            \"to {} by default, to avoid your system being overloaded; \"\n            \"please further tune the variable for optimal performance in \"\n            \"your application as needed. \\n\"\n            \"*****************************************\".format(\n                current_env[\"OMP_NUM_THREADS\"]\n            )\n        )\n\n    processes: List[Any] = []\n\n    if (\n        args.redirect_stdout_and_stderr\n        and os.path.exists(args.logdir)\n        and not os.path.isdir(args.logdir)\n    ):\n        raise ValueError(\"argument --logdir must be a path to a directory.\")\n\n    subprocess_file_handles = []\n    for local_rank in range(0, args.nproc_per_node):\n        dist_rank = args.nproc_per_node * args.node_rank + local_rank\n        current_env[\"RANK\"] = str(dist_rank)\n        current_env[\"LOCAL_RANK\"] = str(local_rank)\n        with_python = not args.no_python\n        cmd = []\n        if with_python:\n            cmd = [sys.executable, \"-u\"]\n            if args.module:\n                cmd.append(\"-m\")\n        elif args.module:\n            raise ValueError(\n                \"Don't use both the '--no_python' flag and the '--module' flag at the same time.\"\n            )\n        cmd.append(args.training_script)\n        cmd.extend(args.training_script_args)\n        stdout_handle: Optional[IO]\n        stderr_handle: Optional[IO]\n        log_directory_path = os.path.join(\n            os.getcwd(), args.logdir, f\"local_rank_{local_rank}\"\n        )\n        current_env[\"GLOG_log_dir\"] = log_directory_path\n        if args.redirect_stdout_and_stderr:\n            os.makedirs(log_directory_path, exist_ok=True)\n            node_rank = args.node_rank\n            stdout_handle = open(os.path.join(log_directory_path, stdout_filename), \"w\")\n            stderr_handle = open(os.path.join(log_directory_path, stderr_filename), \"w\")\n            subprocess_file_handles.append((stdout_handle, stderr_handle))\n            stdout_name = stdout_handle.name\n            stderr_name = stderr_handle.name\n            print(\n                f\"Note: Stdout and stderr for node {node_rank} rank {local_rank} will\\n            be 
written to {stdout_name}, {stderr_name} respectively.\"\n            )\n        sig_names = {2: \"SIGINT\", 15: \"SIGTERM\"}\n        last_return_code = None\n\n        # set killing flag to make sure the killing logic is only executed once\n        kill_flag = True\n\n        def sigkill_handler(signum, frame):\n            nonlocal kill_flag\n            if not kill_flag:\n                return\n            for process in processes:\n                print(f\"Killing subprocess {process.pid}\")\n            kill_flag = False\n            try:\n                # Note: os.kill or process.kill() may only kill the current process;\n                # killpg sends the signal to this process and all of its sub-processes\n                #\n                # Note: Worker processes launched by data loader will exit automatically\n                # when their parent process exits because of `_prctl_pr_set_pdeathsig`.\n                os.killpg(os.getpid(), signal.SIGTERM)\n            except Exception:\n                pass\n            if last_return_code is not None:\n                raise subprocess.CalledProcessError(\n                    returncode=last_return_code, cmd=cmd\n                )\n            if signum in sig_names:\n                print(f\"Main process received {sig_names[signum]}, exiting\")\n            sys.exit(1)\n\n        signal.signal(signal.SIGINT, sigkill_handler)\n        signal.signal(signal.SIGTERM, sigkill_handler)\n        stdout_handle = (\n            None\n            if not subprocess_file_handles\n            else subprocess_file_handles[local_rank][0]\n        )\n        stderr_handle = (\n            None\n            if not subprocess_file_handles\n            else subprocess_file_handles[local_rank][1]\n        )\n        process = subprocess.Popen(\n            cmd, env=current_env, stdout=stdout_handle, stderr=stderr_handle\n        )\n        processes.append(process)\n    try:\n        alive_processes = set(processes)\n        while len(alive_processes):\n            finished_processes = []\n            for process in alive_processes:\n                if process.poll() is None:\n                    continue\n                elif process.returncode != 0:\n                    last_return_code = process.returncode\n                    sigkill_handler(signal.SIGTERM, None)\n                else:\n                    finished_processes.append(process)\n            alive_processes = set(alive_processes) - set(finished_processes)\n            time.sleep(1)\n    finally:\n        for (stdout_handle, stderr_handle) in subprocess_file_handles:\n            stdout_handle.close()\n            stderr_handle.close()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
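A hypothetical training stub showing the environment variables the launcher exports to each worker (``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``, ``MASTER_ADDR``, ``MASTER_PORT``); launch it with, e.g., ``python -m oneflow.distributed.launch --nproc_per_node 2 train_stub.py``::

    # train_stub.py (illustrative name)
    import os

    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    print(f"worker {rank} (local {local_rank}) of {world_size}, "
          f"master at {os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}")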
  {
    "path": "python/oneflow/distributions/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nr\"\"\"\nThe documentation is referenced from: \nhttps://github.com/pytorch/pytorch/blob/master/torch/distributions/__init__.py\n\nThe ``distributions`` package contains parameterizable probability distributions\nand sampling functions. This allows the construction of stochastic computation\ngraphs and stochastic gradient estimators for optimization. This package\ngenerally follows the design of the `TensorFlow Distributions`_ package.\n\n.. _`TensorFlow Distributions`:\n    https://arxiv.org/abs/1711.10604\n\nIt is not possible to directly backpropagate through random samples. However,\nthere are two main methods for creating surrogate functions that can be\nbackpropagated through. These are the score function estimator/likelihood ratio\nestimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly\nseen as the basis for policy gradient methods in reinforcement learning, and the\npathwise derivative estimator is commonly seen in the reparameterization trick\nin variational autoencoders. Whilst the score function only requires the value\nof samples :math:`f(x)`, the pathwise derivative requires the derivative\n:math:`f'(x)`. The next sections discuss these two in a reinforcement learning\nexample. For more details see\n`Gradient Estimation Using Stochastic Computation Graphs`_ .\n\n.. _`Gradient Estimation Using Stochastic Computation Graphs`:\n     https://arxiv.org/abs/1506.05254\n\nScore function\n^^^^^^^^^^^^^^\n\nWhen the probability density function is differentiable with respect to its\nparameters, we only need :meth:`~oneflow.distributions.Distribution.sample` and\n:meth:`~oneflow.distributions.Distribution.log_prob` to implement REINFORCE:\n\n.. math::\n\n    \\Delta\\theta  = \\alpha r \\frac{\\partial\\log p(a|\\pi^\\theta(s))}{\\partial\\theta}\n\nwhere :math:`\\theta` are the parameters, :math:`\\alpha` is the learning rate,\n:math:`r` is the reward and :math:`p(a|\\pi^\\theta(s))` is the probability of\ntaking action :math:`a` in state :math:`s` given policy :math:`\\pi^\\theta`.\n\nIn practice we would sample an action from the output of a network, apply this\naction in an environment, and then use ``log_prob`` to construct an equivalent\nloss function. Note that we use a negative because optimizers use gradient\ndescent, whilst the rule above assumes gradient ascent. 
With a categorical\npolicy, the code for implementing REINFORCE would be as follows::\n\n    probs = policy_network(state)\n    # Note that this is equivalent to what used to be called multinomial\n    m = Categorical(probs)\n    action = m.sample()\n    next_state, reward = env.step(action)\n    loss = -m.log_prob(action) * reward\n    loss.backward()\n\nPathwise derivative\n^^^^^^^^^^^^^^^^^^^\n\nThe other way to implement these stochastic/policy gradients would be to use the\nreparameterization trick from the\n:meth:`~oneflow.distributions.Distribution.rsample` method, where the\nparameterized random variable can be constructed via a parameterized\ndeterministic function of a parameter-free random variable. The reparameterized\nsample therefore becomes differentiable. The code for implementing the pathwise\nderivative would be as follows::\n\n    params = policy_network(state)\n    m = Normal(*params)\n    # Any distribution with .has_rsample == True could work based on the application\n    action = m.rsample()\n    next_state, reward = env.step(action)  # Assuming that reward is differentiable\n    loss = -reward\n    loss.backward()\n\"\"\"\n\nfrom .distribution import Distribution\nfrom .categorical import Categorical\n"
  },
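A self-contained variant of the score-function example above. Note that this snapshot of ``Categorical`` does not define ``log_prob``, so the surrogate loss is formed manually from the normalized ``probs``; the constant reward stands in for an environment step::

    import oneflow as flow
    from oneflow.distributions import Categorical

    flow.manual_seed(0)
    probs = flow.tensor([0.1, 0.2, 0.3, 0.4], requires_grad=True)
    m = Categorical(probs=probs)
    action = m.sample()
    log_prob = flow.log(m.probs)[action.item()]  # this snapshot lacks log_prob()
    reward = 1.0                                 # stand-in for env.step(...)
    loss = -log_prob * reward                    # REINFORCE surrogate loss
    loss.backward()                              # populates probs.grad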
  {
    "path": "python/oneflow/distributions/categorical.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.distributions.distribution import Distribution\nfrom oneflow.distributions.utils import probs_to_logits, logits_to_probs\n\n# NOTE(Liang Depeng): modified from\n# https://github.com/pytorch/pytorch/blob/master/torch/distributions/categorical.py\n\n__all__ = [\"Categorical\"]\n\n\nclass Categorical(Distribution):\n    r\"\"\"\n    Creates a categorical distribution parameterized by either :attr:`probs` or\n    :attr:`logits` (but not both).\n\n    .. note::\n        It is equivalent to the distribution that :func:`oneflow.multinomial`\n        samples from.\n\n    Samples are integers from :math:`\\{0, \\ldots, K-1\\}` where `K` is ``probs.size(-1)``.\n    If `probs` is 1-dimensional with length-`K`, each element is the relative probability\n    of sampling the class at that index.\n    If `probs` is N-dimensional, the first N-1 dimensions are treated as a batch of\n    relative probability vectors.\n\n    .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,\n              and it will be normalized to sum to 1 along the last dimension. :attr:`probs`\n              will return this normalized value.\n              The `logits` argument will be interpreted as unnormalized log probabilities\n              and can therefore be any real number. It will likewise be normalized so that\n              the resulting probabilities sum to 1 along the last dimension. :attr:`logits`\n              will return this normalized value.\n\n    Args:\n        probs (Tensor): event probabilities\n        logits (Tensor): event log probabilities (unnormalized)\n\n    See also: :func:`oneflow.multinomial`\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> gen = flow.manual_seed(0)\n        >>> m = flow.distributions.categorical.Categorical(flow.tensor([ 0.25, 0.25, 0.25, 0.25 ]))\n        >>> m.sample()  # equal probability of 0, 1, 2, 3\n        tensor(3, dtype=oneflow.int64)\n    \"\"\"\n    has_enumerate_support = True\n\n    def __init__(self, probs=None, logits=None, validate_args=None):\n        if (probs is None) == (logits is None):\n            raise ValueError(\n                \"Either `probs` or `logits` must be specified, but not both.\"\n            )\n        assert validate_args is None\n\n        if probs is not None:\n            if probs.dim() < 1:\n                raise ValueError(\"`probs` parameter must be at least one-dimensional.\")\n            self.probs = probs / probs.sum(-1, keepdim=True)\n        else:\n            if logits.dim() < 1:\n                raise ValueError(\"`logits` parameter must be at least one-dimensional.\")\n            self.logits = logits\n            # Normalize\n\n            import math\n\n            def logsumexp(t):\n                if t.numel() != 0:\n                    maxes = flow.max(t, dim=-1, keepdim=True)[0]\n                    maxes.masked_fill_(flow.abs(maxes) == math.inf, 0)\n                    result = flow.sum(flow.exp(t - maxes), dim=-1, keepdim=True)\n                    return flow.log(result) + maxes\n                else:\n                    return flow.log(flow.sum(t, dim=-1, keepdim=True))\n\n            self.probs = logits_to_probs(logits - logsumexp(logits))\n\n        self._param = self.probs if probs is not None else self.logits\n        self._num_events = self._param.size()[-1]\n        batch_shape = (\n            self._param.size()[:-1] if self._param.ndimension() > 1 else flow.Size()\n        )\n        super(Categorical, self).__init__(batch_shape, validate_args=validate_args)\n\n    def logits(self):\n        return probs_to_logits(self.probs)\n\n    def probs(self):\n        return logits_to_probs(self.logits)\n\n    def sample(self, sample_shape=flow.Size()):\n        if not isinstance(sample_shape, flow.Size):\n            sample_shape = flow.Size(sample_shape)\n        probs_2d = self.probs.reshape(-1, self._num_events)\n        samples_2d = flow.multinomial(probs_2d, sample_shape.numel(), True).T\n        return samples_2d.reshape(self._extended_shape(sample_shape))\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
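A small check of the logits path in ``__init__`` above: subtracting ``logsumexp`` before ``logits_to_probs`` leaves the softmax unchanged, so ``m.probs`` should agree with a plain softmax of the raw logits::

    import oneflow as flow
    import oneflow.nn.functional as F
    from oneflow.distributions import Categorical

    logits = flow.tensor([1.0, 2.0, 3.0])
    m = Categorical(logits=logits)
    print(m.probs)                    # normalized probabilities
    print(F.softmax(logits, dim=-1))  # expected to match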
  {
    "path": "python/oneflow/distributions/distribution.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nimport warnings\nfrom typing import Dict, Optional, Any\n\n# NOTE(Liang Depeng): Modified from\n#                     https://github.com/pytorch/pytorch/blob/master/torch/distributions/distribution.py\n\n__all__ = [\"Distribution\"]\n\n\nclass Distribution(object):\n    r\"\"\"\n    Distribution is the abstract base class for probability distributions.\n    \"\"\"\n\n    has_rsample = False\n    has_enumerate_support = False\n    _validate_args = __debug__\n\n    @staticmethod\n    def set_default_validate_args(value):\n        \"\"\"\n        Sets whether validation is enabled or disabled.\n\n        The default behavior mimics Python's ``assert`` statement: validation\n        is on by default, but is disabled if Python is run in optimized mode\n        (via ``python -O``). Validation may be expensive, so you may want to\n        disable it once a model is working.\n\n        Args:\n            value (bool): Whether to enable validation.\n        \"\"\"\n        if value not in [True, False]:\n            raise ValueError\n        Distribution._validate_args = value\n\n    def __init__(\n        self, batch_shape=oneflow.Size(), event_shape=oneflow.Size(), validate_args=None\n    ):\n        self._batch_shape = batch_shape\n        self._event_shape = event_shape\n        assert validate_args is None, \"only support validate_args=None for now.\"\n        super(Distribution, self).__init__()\n\n    def expand(self, batch_shape, _instance=None):\n        \"\"\"\n        Returns a new distribution instance (or populates an existing instance\n        provided by a derived class) with batch dimensions expanded to\n        `batch_shape`. This method calls :class:`~oneflow.Tensor.expand` on\n        the distribution's parameters. As such, this does not allocate new\n        memory for the expanded distribution instance. 
Additionally,\n        this does not repeat any args checking or parameter broadcasting in\n        `__init__.py`, when an instance is first created.\n\n        Args:\n            batch_shape (oneflow.Size): the desired expanded size.\n            _instance: new instance provided by subclasses that\n                need to override `.expand`.\n\n        Returns:\n            New distribution instance with batch dimensions expanded to\n            `batch_size`.\n        \"\"\"\n        raise NotImplementedError\n\n    @property\n    def batch_shape(self):\n        \"\"\"\n        Returns the shape over which parameters are batched.\n        \"\"\"\n        return self._batch_shape\n\n    @property\n    def event_shape(self):\n        \"\"\"\n        Returns the shape of a single sample (without batching).\n        \"\"\"\n        return self._event_shape\n\n    @property\n    def mean(self):\n        \"\"\"\n        Returns the mean of the distribution.\n        \"\"\"\n        raise NotImplementedError\n\n    @property\n    def mode(self):\n        \"\"\"\n        Returns the mode of the distribution.\n        \"\"\"\n        raise NotImplementedError(f\"{self.__class__} does not implement mode\")\n\n    @property\n    def variance(self):\n        \"\"\"\n        Returns the variance of the distribution.\n        \"\"\"\n        raise NotImplementedError\n\n    @property\n    def stddev(self):\n        \"\"\"\n        Returns the standard deviation of the distribution.\n        \"\"\"\n        return self.variance.sqrt()\n\n    def sample(self, sample_shape=oneflow.Size()):\n        \"\"\"\n        Generates a sample_shape shaped sample or sample_shape shaped batch of\n        samples if the distribution parameters are batched.\n        \"\"\"\n        with oneflow.no_grad():\n            return self.rsample(sample_shape)\n\n    def rsample(self, sample_shape=oneflow.Size()):\n        \"\"\"\n        Generates a sample_shape shaped reparameterized sample or sample_shape\n        shaped batch of reparameterized samples if the distribution parameters\n        are batched.\n        \"\"\"\n        raise NotImplementedError\n\n    def sample_n(self, n):\n        \"\"\"\n        Generates n samples or n batches of samples if the distribution\n        parameters are batched.\n        \"\"\"\n        warnings.warn(\n            \"sample_n will be deprecated. Use .sample((n,)) instead\", UserWarning\n        )\n        return self.sample(oneflow.Size((n,)))\n\n    def log_prob(self, value):\n        \"\"\"\n        Returns the log of the probability density/mass function evaluated at\n        `value`.\n\n        Args:\n            value (Tensor):\n        \"\"\"\n        raise NotImplementedError\n\n    def cdf(self, value):\n        \"\"\"\n        Returns the cumulative density/mass function evaluated at\n        `value`.\n\n        Args:\n            value (Tensor):\n        \"\"\"\n        raise NotImplementedError\n\n    def icdf(self, value):\n        \"\"\"\n        Returns the inverse cumulative density/mass function evaluated at\n        `value`.\n\n        Args:\n            value (Tensor):\n        \"\"\"\n        raise NotImplementedError\n\n    def enumerate_support(self, expand=True):\n        \"\"\"\n        Returns tensor containing all values supported by a discrete\n        distribution. 
The result will enumerate over dimension 0, so the shape\n        of the result will be `(cardinality,) + batch_shape + event_shape`\n        (where `event_shape = ()` for univariate distributions).\n\n        Note that this enumerates over all batched tensors in lock-step\n        `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens\n        along dim 0, but with the remaining batch dimensions being\n        singleton dimensions, `[[0], [1], ..`.\n\n        To iterate over the full Cartesian product use\n        `itertools.product(m.enumerate_support())`.\n\n        Args:\n            expand (bool): whether to expand the support over the\n                batch dims to match the distribution's `batch_shape`.\n\n        Returns:\n            Tensor iterating over dimension 0.\n        \"\"\"\n        raise NotImplementedError\n\n    def entropy(self):\n        \"\"\"\n        Returns entropy of distribution, batched over batch_shape.\n\n        Returns:\n            Tensor of shape batch_shape.\n        \"\"\"\n        raise NotImplementedError\n\n    def perplexity(self):\n        \"\"\"\n        Returns perplexity of distribution, batched over batch_shape.\n\n        Returns:\n            Tensor of shape batch_shape.\n        \"\"\"\n        return oneflow.exp(self.entropy())\n\n    def _extended_shape(self, sample_shape=oneflow.Size()):\n        \"\"\"\n        Returns the size of the sample returned by the distribution, given\n        a `sample_shape`. Note, that the batch and event shapes of a distribution\n        instance are fixed at the time of construction. If this is empty, the\n        returned shape is upcast to (1,).\n\n        Args:\n            sample_shape (oneflow.Size): the size of the sample to be drawn.\n        \"\"\"\n        if not isinstance(sample_shape, oneflow.Size):\n            sample_shape = oneflow.Size(sample_shape)\n        return sample_shape + self._batch_shape + self._event_shape\n\n    def _validate_sample(self, value):\n        \"\"\"\n        Argument validation for distribution methods such as `log_prob`,\n        `cdf` and `icdf`. 
The rightmost dimensions of a value to be\n        scored via these methods must agree with the distribution's batch\n        and event shapes.\n\n        Args:\n            value (Tensor): the tensor whose log probability is to be\n                computed by the `log_prob` method.\n        Raises\n            ValueError: when the rightmost dimensions of `value` do not match the\n                distribution's batch and event shapes.\n        \"\"\"\n        if not isinstance(value, oneflow.Tensor):\n            raise ValueError(\"The value argument to log_prob must be a Tensor\")\n\n        event_dim_start = len(value.size()) - len(self._event_shape)\n        if value.size()[event_dim_start:] != self._event_shape:\n            raise ValueError(\n                \"The right-most size of value must match event_shape: {} vs {}.\".format(\n                    value.size(), self._event_shape\n                )\n            )\n\n        actual_shape = value.size()\n        expected_shape = self._batch_shape + self._event_shape\n        for i, j in zip(reversed(actual_shape), reversed(expected_shape)):\n            if i != 1 and j != 1 and i != j:\n                raise ValueError(\n                    \"Value is not broadcastable with batch_shape+event_shape: {} vs {}.\".format(\n                        actual_shape, expected_shape\n                    )\n                )\n        try:\n            support = self.support\n        except NotImplementedError:\n            warnings.warn(\n                f\"{self.__class__} does not define `support` to enable \"\n                + \"sample validation. Please initialize the distribution with \"\n                + \"`validate_args=False` to turn off validation.\"\n            )\n            return\n        assert support is not None\n        valid = support.check(value)\n        if not valid.all():\n            raise ValueError(\n                \"Expected value argument \"\n                f\"({type(value).__name__} of shape {tuple(value.shape)}) \"\n                f\"to be within the support ({repr(support)}) \"\n                f\"of the distribution {repr(self)}, \"\n                f\"but found invalid values:\\n{value}\"\n            )\n\n    def _get_checked_instance(self, cls, _instance=None):\n        if _instance is None and type(self).__init__ != cls.__init__:\n            raise NotImplementedError(\n                \"Subclass {} of {} that defines a custom __init__ method \"\n                \"must also define a custom .expand() method.\".format(\n                    self.__class__.__name__, cls.__name__\n                )\n            )\n        return self.__new__(type(self)) if _instance is None else _instance\n\n    def __repr__(self):\n        return self.__class__.__name__\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
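An illustrative concrete subclass (not part of the package), with just enough of the interface for ``Distribution.sample()`` to work; it shows how ``_extended_shape`` combines the sample, batch, and event shapes::

    import oneflow as flow
    from oneflow.distributions.distribution import Distribution

    class Uniform01(Distribution):  # hypothetical example class
        has_rsample = True

        def rsample(self, sample_shape=flow.Size()):
            shape = self._extended_shape(sample_shape)
            return flow.rand(*shape)

    u = Uniform01()
    print(u.sample((2, 3)).shape)  # oneflow.Size([2, 3])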
  {
    "path": "python/oneflow/distributions/utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom functools import update_wrapper\nfrom numbers import Number\nimport oneflow as flow\nimport oneflow.nn.functional as F\nfrom typing import Dict, Any\n\n# NOTE(Liang Depeng): modified from\n# https://github.com/pytorch/pytorch/blob/master/torch/distributions/utils.py\n\neuler_constant = 0.57721566490153286060  # Euler Mascheroni Constant\n\n\ndef logits_to_probs(logits, is_binary=False):\n    r\"\"\"\n    Converts a tensor of logits into probabilities. Note that for the\n    binary case, each value denotes log odds, whereas for the\n    multi-dimensional case, the values along the last dimension denote\n    the log probabilities (possibly unnormalized) of the events.\n    \"\"\"\n    if is_binary:\n        return flow.sigmoid(logits)\n    return F.softmax(logits, dim=-1)\n\n\ndef clamp_probs(probs):\n    eps = flow.finfo(probs.dtype).eps\n    return probs.clamp(min=eps, max=1 - eps)\n\n\ndef probs_to_logits(probs, is_binary=False):\n    r\"\"\"\n    Converts a tensor of probabilities into logits. For the binary case,\n    this denotes the probability of occurrence of the event indexed by `1`.\n    For the multi-dimensional case, the values along the last dimension\n    denote the probabilities of occurrence of each of the events.\n    \"\"\"\n    ps_clamped = clamp_probs(probs)\n    if is_binary:\n        return flow.log(ps_clamped) - flow.log1p(-ps_clamped)\n    return flow.log(ps_clamped)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
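A round-trip sketch for the helpers above: in the multi-class case ``probs_to_logits`` takes the log of the clamped probabilities, and ``logits_to_probs`` applies a softmax, which recovers the original (normalized) probabilities::

    import oneflow as flow
    from oneflow.distributions.utils import logits_to_probs, probs_to_logits

    p = flow.tensor([0.1, 0.2, 0.7])
    logits = probs_to_logits(p)   # log of clamped probs
    p2 = logits_to_probs(logits)  # softmax along the last dim
    print(p2)                     # expected to be close to p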
  {
    "path": "python/oneflow/env.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.framework.env_util import api_all_device_placement as all_device_placement\nimport oneflow._oneflow_internal\n\n\ndef get_local_rank():\n    \"\"\"Returns the local rank of current machine.\n    Local rank is not globally unique. It is only unique per process on a machine.\n\n    Returns:\n        The the local rank of process on current machine.\n\n    \"\"\"\n    return oneflow._oneflow_internal.GetLocalRank()\n\n\ndef get_rank():\n    \"\"\"Returns the rank of current process group.\n    Rank is globally unique, range of which is [0, world_size).\n\n    Returns:\n        The rank of the process group.\n\n    \"\"\"\n    return oneflow._oneflow_internal.GetRank()\n\n\ndef get_node_size():\n    \"\"\"Returns the number of machines in the current process group.\n\n    Returns:\n        The the number of machines in the process group.\n\n    \"\"\"\n    return oneflow._oneflow_internal.GetNodeSize()\n\n\ndef get_world_size():\n    \"\"\"Returns the number of processes in the current process group.\n\n    Returns:\n        The world size of the process group.\n\n    \"\"\"\n    return oneflow._oneflow_internal.GetWorldSize()\n\n\ndef init_rdma():\n    \"\"\"\n    Init RDMA in the current envirment. If the current envirment support \n    RDMA, turning on RDMA by calling oneflow.env.init_rdma() can speed up \n    data transfer.\n\n    Note:\n        - Make sure to avoid using fork() after oneflow.env.init_rdma() is invoked. \n          Otherwise, data corruption or segmentation fault  may result!\n\n        - Requires all devices to execute oneflow.env.init_rdma() simultaneously. \n          Otherwise, deadlock may result!\n\n\n    \"\"\"\n    oneflow._oneflow_internal.InitRDMA()\n\n\ndef rdma_is_initialized():\n    \"\"\"Returns whether RDMA is initialized in the current envirment or not.\n\n    Returns:\n        Whether RDMA is initialized or not.\n\n    \"\"\"\n    return oneflow._oneflow_internal.RDMAIsInitialized()\n\n\ndef destory_rdma():\n    \"\"\"Destory RDMA in the current envirment. \n    \"\"\"\n    return oneflow._oneflow_internal.DestoryRDMA()\n"
  },
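A rank-aware logging sketch using the helpers above; in a single-process run ``get_rank()`` is 0 and ``get_world_size()`` is 1, so it also works outside a launcher::

    import oneflow as flow

    if flow.env.get_rank() == 0:
        print("world size:", flow.env.get_world_size())
        print("local rank:", flow.env.get_local_rank())
        print("nodes:", flow.env.get_node_size())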
  {
    "path": "python/oneflow/experimental/load_mnist.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport hashlib\nimport os\n\nimport numpy as np\nimport requests\nfrom tqdm import tqdm\n\n\ndef get_sha256hash(file_path, Bytes=1024):\n    sha256hash = hashlib.sha256()\n    with open(file_path, \"rb\") as f:\n        while True:\n            data = f.read(Bytes)\n            if data:\n                sha256hash.update(data)\n            else:\n                break\n    ret = sha256hash.hexdigest()\n    return ret\n\n\ndef download_mnist_file(out_path, url):\n    resp = requests.get(url=url, stream=True)\n    size = int(resp.headers[\"Content-Length\"]) / 1024\n    print(\"File size: %.4f kb, downloading...\" % size)\n    with open(out_path, \"wb\") as f:\n        for data in tqdm(\n            iterable=resp.iter_content(1024), total=size, unit=\"k\", desc=out_path\n        ):\n            f.write(data)\n        print(\"Done!\")\n\n\ndef get_mnist_file(sha256, url, out_dir):\n    path = os.path.join(out_dir, \"mnist.npz\")\n    if not os.path.isfile(path):\n        download_mnist_file(path, url)\n    print(\"File mnist.npz already exist, path:\", path)\n    if not get_sha256hash(path) == sha256:\n        cheksum_fail = \"sha256 verification failed, remove {0} and try again\".format(\n            path\n        )\n        raise Exception(cheksum_fail)\n    return path\n\n\ndef load_mnist(\n    train_batch_size=100,\n    test_batch_size=100,\n    data_format=\"NCHW\",\n    url=\"https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist.npz\",\n    hash_check=\"63d4344077849053dc3036b247fa012b2b381de53fd055a66b539dffd76cf08e\",\n    out_dir=\".\",\n):\n    \"\"\"Load mnist dataset, return images and labels,\n            if  dataset doesn't exist, then download it to directory that out_dir specified\n\n    Args:\n        train_batch_size (int, optional): batch size for train. Defaults to 100.\n        test_batch_size (int, optional): batch size for test or evaluate. Defaults to 100.\n        data_format (str, optional): data format. Defaults to \"NCHW\".\n        url (str, optional): url to get mnist.npz. Defaults to \"https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist.npz\".\n        hash_check (str, optional): file hash value. Defaults to \"63d4344077849053dc3036b247fa012b2b381de53fd055a66b539dffd76cf08e\".\n        out_dir (str, optional): dir to save downloaded file. 
Defaults to \"./\".\n\n    Returns:\n        (train_images, train_labels), (test_images, test_labels)\n    \"\"\"\n    path = get_mnist_file(hash_check, url, out_dir)\n    with np.load(path, allow_pickle=True) as f:\n        (x_train, y_train) = (f[\"x_train\"], f[\"y_train\"])\n        (x_test, y_test) = (f[\"x_test\"], f[\"y_test\"])\n\n    def normalize(x, y, batch_size):\n        x = x.astype(np.float32) / 255.0\n        y = y.astype(np.int32)\n        if data_format == \"NCHW\":\n            images = x.reshape((-1, batch_size, 1, x.shape[1], x.shape[2]))\n        else:\n            images = x.reshape((-1, batch_size, x.shape[1], x.shape[2], 1))\n        labels = y.reshape((-1, batch_size))\n        return (images, labels)\n\n    (train_images, train_labels) = normalize(x_train, y_train, train_batch_size)\n    (test_images, test_labels) = normalize(x_test, y_test, test_batch_size)\n    return ((train_images, train_labels), (test_images, test_labels))\n"
  },
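A usage sketch for the loader above: on first call it downloads ``mnist.npz`` into the current directory, then returns pre-batched arrays (for ``NCHW``, images of shape ``(num_batches, batch_size, 1, 28, 28)``)::

    from oneflow.experimental.load_mnist import load_mnist

    (train_images, train_labels), (test_images, test_labels) = load_mnist(
        train_batch_size=100, test_batch_size=100
    )
    print(train_images.shape)  # (600, 100, 1, 28, 28) for the 60k training set
    print(train_labels.shape)  # (600, 100)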
  {
    "path": "python/oneflow/fft/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.framework.tensor import Tensor\nimport oneflow as flow\n\n\ndef fft(input, n=None, dim=-1, norm=None) -> Tensor:\n    r\"\"\"\n    \n    Computes the one dimensional discrete Fourier transform of :attr:`input`.\n    \n    Note:\n    The Fourier domain representation of any real signal satisfies the\n    Hermitian property: `X[i] = conj(X[-i])`. This function always returns both\n    the positive and negative frequency terms even though, for real inputs, the\n    negative frequencies are redundant. :func:`oneflow.fft.rfft` returns the\n    more compact one-sided representation where only the positive frequencies\n    are returned.\n\n    Args:\n        input (Tensor): the input tensor\n        n (int, optional): Signal length. If given, the input will either be zero-padded\n            or trimmed to this length before computing the FFT.\n        dim (int, optional): The dimension along which to take the one dimensional FFT.\n        norm (str, optional): Normalization mode. For the forward transform\n            (:func:`oneflow.fft.fft`), these correspond to:\n\n            * ``\"forward\"`` - normalize by ``1/n``\n            * ``\"backward\"`` - no normalization\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)\n\n            Calling the backward transform (:func:`oneflow.fft.ifft`) with the same\n            normalization mode will apply an overall normalization of ``1/n`` between\n            the two transforms. This is required to make :func:`oneflow.fft.ifft`\n            the exact inverse.\n\n            Default is ``\"backward\"`` (no normalization).\n\n    Example:\n    \n        >>> t = oneflow.arange(4)\n        >>> t\n        tensor([0, 1, 2, 3])\n        >>> oneflow.fft.fft(t)\n        tensor([ 6+0j, -2+2j, -2+0j, -2-2j], dtype=oneflow.complex64)\n\n        >>> t = oneflow.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j])\n        >>> oneflow.fft.fft(t)\n        tensor([12+16j, -8+0j, -4-4j,  -8j], dtype=oneflow.complex128)\n    \"\"\"\n    if n is None:\n        n = -1\n    return flow._C.fft(input, n, dim, norm)\n\n\ndef ifft(input, n=None, dim=-1, norm=None) -> Tensor:\n    r\"\"\"\n\n    Computes the one dimensional inverse discrete Fourier transform of :attr:`input`.\n\n    Args:\n        input (Tensor): the input tensor\n        n (int, optional): Signal length. If given, the input will either be zero-padded\n            or trimmed to this length before computing the IFFT.\n        dim (int, optional): The dimension along which to take the one dimensional IFFT.\n        norm (str, optional): Normalization mode. 
For the backward transform\n            (:func:`oneflow.fft.ifft`), these correspond to:\n\n            * ``\"forward\"`` - no normalization\n            * ``\"backward\"`` - normalize by ``1/n``\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)\n\n            Calling the forward transform (:func:`~oneflow.fft.fft`) with the same\n            normalization mode will apply an overall normalization of ``1/n`` between\n            the two transforms. This is required to make :func:`oneflow.fft.ifft`\n            the exact inverse.\n\n            Default is ``\"backward\"`` (normalize by ``1/n``).\n\n    Example:\n\n        >>> t = oneflow.tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j])\n        >>> oneflow.fft.ifft(t)\n        tensor([0j, (1+0j), (2+0j), (3+0j)], dtype=oneflow.complex128)\n    \"\"\"\n    if n is None:\n        n = -1\n    return flow._C.ifft(input, n, dim, norm)\n\n\ndef fft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor:\n    r\"\"\"\n\n    Computes the 2 dimensional discrete Fourier transform of :attr:`input`.\n    Equivalent to :func:`~oneflow.fft.fftn` but FFTs only the last two dimensions by default.\n\n    Note:\n        The Fourier domain representation of any real signal satisfies the\n        Hermitian property: ``X[i, j] = conj(X[-i, -j])``. This\n        function always returns all positive and negative frequency terms even\n        though, for real inputs, half of these values are redundant.\n        :func:`~oneflow.fft.rfft2` returns the more compact one-sided representation\n        where only the positive frequencies of the last dimension are returned.\n\n    Args:\n        input (Tensor): the input tensor\n        s (Tuple[int], optional): Signal size in the transformed dimensions.\n            If given, each dimension ``dim[i]`` will either be zero-padded or\n            trimmed to the length ``s[i]`` before computing the FFT.\n            If a length ``-1`` is specified, no padding is done in that dimension.\n            Default: ``s = [input.size(d) for d in dim]``\n        dim (Tuple[int], optional): Dimensions to be transformed.\n            Default: last two dimensions.\n        norm (str, optional): Normalization mode. For the forward transform\n            (:func:`oneflow.fft.fft2`), these correspond to:\n\n            * ``\"forward\"`` - normalize by ``1/n``\n            * ``\"backward\"`` - no normalization\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)\n\n            Where ``n = prod(s)`` is the logical FFT size.\n            Calling the backward transform (:func:`oneflow.fft.ifft2`) with the same\n            normalization mode will apply an overall normalization of ``1/n``\n            between the two transforms. 
This is required to make\n            :func:`~oneflow.fft.ifft2` the exact inverse.\n\n            Default is ``\"backward\"`` (no normalization).\n\n    \"\"\"\n    return flow._C.fft2(input, s, dim, norm)\n\n\ndef ifft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor:\n    r\"\"\"\n\n    Computes the 2 dimensional inverse discrete Fourier transform of :attr:`input`.\n    Equivalent to :func:`oneflow.fft.ifftn` but IFFTs only the last two dimensions by default.\n\n    Args:\n        input (Tensor): the input tensor\n        s (Tuple[int], optional): Signal size in the transformed dimensions.\n            If given, each dimension ``dim[i]`` will either be zero-padded or\n            trimmed to the length ``s[i]`` before computing the IFFT.\n            If a length ``-1`` is specified, no padding is done in that dimension.\n            Default: ``s = [input.size(d) for d in dim]``\n        dim (Tuple[int], optional): Dimensions to be transformed.\n            Default: last two dimensions.\n        norm (str, optional): Normalization mode. For the backward transform\n            (:func:`oneflow.fft.ifft2`), these correspond to:\n\n            * ``\"forward\"`` - no normalization\n            * ``\"backward\"`` - normalize by ``1/n``\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)\n\n            Where ``n = prod(s)`` is the logical IFFT size.\n            Calling the forward transform (:func:`oneflow.fft.fft2`) with the same\n            normalization mode will apply an overall normalization of ``1/n`` between\n            the two transforms. This is required to make :func:`oneflow.fft.ifft2`\n            the exact inverse.\n\n            Default is ``\"backward\"`` (normalize by ``1/n``).\n\n\n    \"\"\"\n    return flow._C.ifft2(input, s, dim, norm)\n\n\ndef fftn(input, s=None, dim=None, norm=None) -> Tensor:\n    r\"\"\"\n\n    Computes the N dimensional discrete Fourier transform of :attr:`input`.\n\n    Note:\n        The Fourier domain representation of any real signal satisfies the\n        Hermitian property: ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])``. This\n        function always returns all positive and negative frequency terms even\n        though, for real inputs, half of these values are redundant.\n        :func:`oneflow.fft.rfftn` returns the more compact one-sided representation\n        where only the positive frequencies of the last dimension are returned.\n\n    Args:\n        input (Tensor): the input tensor\n        s (Tuple[int], optional): Signal size in the transformed dimensions.\n            If given, each dimension ``dim[i]`` will either be zero-padded or\n            trimmed to the length ``s[i]`` before computing the FFT.\n            If a length ``-1`` is specified, no padding is done in that dimension.\n            Default: ``s = [input.size(d) for d in dim]``\n        dim (Tuple[int], optional): Dimensions to be transformed.\n            Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.\n        norm (str, optional): Normalization mode. 
\n    \"\"\"\n    return flow._C.fftn(input, s, dim, norm)\n\n\ndef ifftn(input, s=None, dim=None, norm=None) -> Tensor:\n    r\"\"\"\n\n    Computes the N-dimensional inverse discrete Fourier transform of :attr:`input`.\n\n    Args:\n        input (Tensor): the input tensor\n        s (Tuple[int], optional): Signal size in the transformed dimensions.\n            If given, each dimension ``dim[i]`` will either be zero-padded or\n            trimmed to the length ``s[i]`` before computing the IFFT.\n            If a length ``-1`` is specified, no padding is done in that dimension.\n            Default: ``s = [input.size(d) for d in dim]``\n        dim (Tuple[int], optional): Dimensions to be transformed.\n            Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.\n        norm (str, optional): Normalization mode. For the backward transform\n            (:func:`oneflow.fft.ifftn`), these correspond to:\n\n            * ``\"forward\"`` - no normalization\n            * ``\"backward\"`` - normalize by ``1/n``\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)\n\n            Where ``n = prod(s)`` is the logical IFFT size.\n            Calling the forward transform (:func:`oneflow.fft.fftn`) with the same\n            normalization mode will apply an overall normalization of ``1/n`` between\n            the two transforms. This is required to make :func:`oneflow.fft.ifftn`\n            the exact inverse.\n\n            Default is ``\"backward\"`` (normalize by ``1/n``).\n\n    \"\"\"\n    return flow._C.ifftn(input, s, dim, norm)\n\n\ndef rfft(input, n=None, dim=-1, norm=None) -> Tensor:\n    r\"\"\"\n\n    Computes the one dimensional Fourier transform of real-valued :attr:`input`.\n\n    The FFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])``, so\n    the output contains only the positive frequencies below the Nyquist frequency.\n    To compute the full output, use :func:`oneflow.fft.fft`.\n\n    Args:\n        input (Tensor): the real input tensor\n        n (int, optional): Signal length. If given, the input will either be zero-padded\n            or trimmed to this length before computing the real FFT.\n        dim (int, optional): The dimension along which to take the one dimensional real FFT.\n        norm (str, optional): Normalization mode. 
For the forward transform\n            (:func:`oneflow.fft.rfft`), these correspond to:\n\n            * ``\"forward\"`` - normalize by ``1/n``\n            * ``\"backward\"`` - no normalization\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)\n\n            Calling the backward transform (:func:`oneflow.fft.irfft`) with the same\n            normalization mode will apply an overall normalization of ``1/n`` between\n            the two transforms. This is required to make :func:`oneflow.fft.irfft`\n            the exact inverse.\n\n            Default is ``\"backward\"`` (no normalization).\n\n    Example:\n\n        >>> t = oneflow.arange(4)\n        >>> t\n        tensor([0, 1, 2, 3], dtype=oneflow.int64)\n        >>> oneflow.fft.rfft(t)\n        tensor([ (6+0j), (-2+2j), (-2+0j)], dtype=oneflow.complex64)\n\n        Compare against the full output from :func:`oneflow.fft.fft`:\n\n        >>> oneflow.fft.fft(t)\n        tensor([ (6+0j), (-2+2j), (-2+0j), (-2-2j)], dtype=oneflow.complex64)\n\n        Notice that the symmetric element ``T[-1] == T[1].conj()`` is omitted.\n        At the Nyquist frequency ``T[-2] == T[2]`` is its own symmetric pair,\n        and therefore must always be real-valued.\n    \"\"\"\n\n    if n is None:\n        n = -1\n    return flow._C.rfft(input, n, dim, norm)\n\n\ndef irfft(input, n=None, dim=-1, norm=None) -> Tensor:\n    r\"\"\"\n\n    Computes the inverse of :func:`oneflow.fft.rfft`.\n\n    :attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier\n    domain, as produced by :func:`oneflow.fft.rfft`. By the Hermitian property, the\n    output will be real-valued.\n\n    Note:\n        Some input frequencies must be real-valued to satisfy the Hermitian\n        property. In these cases the imaginary component will be ignored.\n        For example, any imaginary component in the zero-frequency term cannot\n        be represented in a real output and so will always be ignored.\n\n    Note:\n        The correct interpretation of the Hermitian input depends on the length of\n        the original data, as given by :attr:`n`. This is because each input shape\n        could correspond to either an odd or even length signal. By default, the\n        signal is assumed to be even length and odd signals will not round-trip\n        properly. So, it is recommended to always pass the signal length :attr:`n`.\n\n    Args:\n        input (Tensor): the input tensor representing a half-Hermitian signal\n        n (int, optional): Output signal length. This determines the length of the\n            output signal. If given, the input will either be zero-padded or trimmed to this\n            length before computing the real IFFT.\n            Defaults to even output: ``n=2*(input.size(dim) - 1)``.\n        dim (int, optional): The dimension along which to take the one dimensional real IFFT.\n        norm (str, optional): Normalization mode. For the backward transform\n            (:func:`oneflow.fft.irfft`), these correspond to:\n\n            * ``\"forward\"`` - no normalization\n            * ``\"backward\"`` - normalize by ``1/n``\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal)\n\n            Calling the forward transform (:func:`oneflow.fft.rfft`) with the same\n            normalization mode will apply an overall normalization of ``1/n`` between\n            the two transforms. This is required to make :func:`oneflow.fft.irfft`\n            the exact inverse.\n\n            Default is ``\"backward\"`` (normalize by ``1/n``).\n
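\n    Example:\n\n        A round-trip sketch (illustrative); per the Note above, the odd signal\n        length is passed back explicitly via :attr:`n`:\n\n        >>> t = oneflow.linspace(1, 2, 5)\n        >>> T = oneflow.fft.rfft(t)\n        >>> roundtrip = oneflow.fft.irfft(T, n=t.numel())  # n=5 recovers the odd length\n        >>> roundtrip.size()\n        oneflow.Size([5])\n        >>> oneflow.allclose(roundtrip, t)\n        True\n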
\n    \"\"\"\n\n    if n is None:\n        n = -1\n    return flow._C.irfft(input, n, dim, norm)\n\n\ndef rfft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor:\n    r\"\"\"\n\n    Computes the 2-dimensional discrete Fourier transform of real :attr:`input`.\n    Equivalent to :func:`oneflow.fft.rfftn` but FFTs only the last two dimensions by default.\n\n    The FFT of a real signal is Hermitian-symmetric, ``X[i, j] = conj(X[-i, -j])``,\n    so the full :func:`oneflow.fft.fft2` output contains redundant information.\n    :func:`oneflow.fft.rfft2` instead omits the negative frequencies in the last\n    dimension.\n\n    Args:\n        input (Tensor): the input tensor\n        s (Tuple[int], optional): Signal size in the transformed dimensions.\n            If given, each dimension ``dim[i]`` will either be zero-padded or\n            trimmed to the length ``s[i]`` before computing the real FFT.\n            If a length ``-1`` is specified, no padding is done in that dimension.\n            Default: ``s = [input.size(d) for d in dim]``\n        dim (Tuple[int], optional): Dimensions to be transformed.\n            Default: last two dimensions.\n        norm (str, optional): Normalization mode. For the forward transform\n            (:func:`oneflow.fft.rfft2`), these correspond to:\n\n            * ``\"forward\"`` - normalize by ``1/n``\n            * ``\"backward\"`` - no normalization\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the real FFT orthonormal)\n\n            Where ``n = prod(s)`` is the logical FFT size.\n            Calling the backward transform (:func:`oneflow.fft.irfft2`) with the same\n            normalization mode will apply an overall normalization of ``1/n`` between\n            the two transforms. This is required to make :func:`oneflow.fft.irfft2`\n            the exact inverse.\n\n            Default is ``\"backward\"`` (no normalization).\n\n    \"\"\"\n\n    return flow._C.rfft2(input, s, dim, norm)\n\n\ndef irfft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor:\n    r\"\"\"\n\n    Computes the inverse of :func:`oneflow.fft.rfft2`.\n    Equivalent to :func:`oneflow.fft.irfftn` but IFFTs only the last two dimensions by default.\n\n    :attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier\n    domain, as produced by :func:`oneflow.fft.rfft2`. By the Hermitian property, the\n    output will be real-valued.\n\n    Note:\n        Some input frequencies must be real-valued to satisfy the Hermitian\n        property. In these cases the imaginary component will be ignored.\n        For example, any imaginary component in the zero-frequency term cannot\n        be represented in a real output and so will always be ignored.\n\n    Note:\n        The correct interpretation of the Hermitian input depends on the length of\n        the original data, as given by :attr:`s`. This is because each input shape\n        could correspond to either an odd or even length signal. By default, the\n        signal is assumed to be even length and odd signals will not round-trip\n        properly. 
So, it is recommended to always pass the signal shape :attr:`s`.\n\n    Args:\n        input (Tensor): the input tensor\n        s (Tuple[int], optional): Signal size in the transformed dimensions.\n            If given, each dimension ``dim[i]`` will either be zero-padded or\n            trimmed to the length ``s[i]`` before computing the real FFT.\n            If a length ``-1`` is specified, no padding is done in that dimension.\n            Defaults to even output in the last dimension:\n            ``s[-1] = 2*(input.size(dim[-1]) - 1)``.\n        dim (Tuple[int], optional): Dimensions to be transformed.\n            The last dimension must be the half-Hermitian compressed dimension.\n            Default: last two dimensions.\n        norm (str, optional): Normalization mode. For the backward transform\n            (:func:`oneflow.fft.irfft2`), these correspond to:\n\n            * ``\"forward\"`` - no normalization\n            * ``\"backward\"`` - normalize by ``1/n``\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal)\n\n            Where ``n = prod(s)`` is the logical IFFT size.\n            Calling the forward transform (:func:`oneflow.fft.rfft2`) with the same\n            normalization mode will apply an overall normalization of ``1/n`` between\n            the two transforms. This is required to make :func:`oneflow.fft.irfft2`\n            the exact inverse.\n\n            Default is ``\"backward\"`` (normalize by ``1/n``).\n\n\n    \"\"\"\n    return flow._C.irfft2(input, s, dim, norm)\n\n\ndef rfftn(input, s=None, dim=None, norm=None) -> Tensor:\n    r\"\"\"\n\n    Computes the N-dimensional discrete Fourier transform of real :attr:`input`.\n\n    The FFT of a real signal is Hermitian-symmetric,\n    ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])`` so the full\n    :func:`oneflow.fft.fftn` output contains redundant information.\n    :func:`oneflow.fft.rfftn` instead omits the negative frequencies in the\n    last dimension.\n\n    Args:\n        input (Tensor): the input tensor\n        s (Tuple[int], optional): Signal size in the transformed dimensions.\n            If given, each dimension ``dim[i]`` will either be zero-padded or\n            trimmed to the length ``s[i]`` before computing the real FFT.\n            If a length ``-1`` is specified, no padding is done in that dimension.\n            Default: ``s = [input.size(d) for d in dim]``\n        dim (Tuple[int], optional): Dimensions to be transformed.\n            Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.\n        norm (str, optional): Normalization mode. For the forward transform\n            (:func:`oneflow.fft.rfftn`), these correspond to:\n\n            * ``\"forward\"`` - normalize by ``1/n``\n            * ``\"backward\"`` - no normalization\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the real FFT orthonormal)\n\n            Where ``n = prod(s)`` is the logical FFT size.\n            Calling the backward transform (:func:`oneflow.fft.irfftn`) with the same\n            normalization mode will apply an overall normalization of ``1/n`` between\n            the two transforms. 
This is required to make :func:`oneflow.fft.irfftn`\n            the exact inverse.\n\n            Default is ``\"backward\"`` (no normalization).\n\n    \"\"\"\n\n    return flow._C.rfftn(input, s, dim, norm)\n\n\ndef irfftn(input, s=None, dim=None, norm=None) -> Tensor:\n    r\"\"\"\n\n    Computes the inverse of :func:`oneflow.fft.rfftn`.\n\n    :attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier\n    domain, as produced by :func:`oneflow.fft.rfftn`. By the Hermitian property, the\n    output will be real-valued.\n\n    Note:\n        Some input frequencies must be real-valued to satisfy the Hermitian\n        property. In these cases the imaginary component will be ignored.\n        For example, any imaginary component in the zero-frequency term cannot\n        be represented in a real output and so will always be ignored.\n\n    Note:\n        The correct interpretation of the Hermitian input depends on the length of\n        the original data, as given by :attr:`s`. This is because each input shape\n        could correspond to either an odd or even length signal. By default, the\n        signal is assumed to be even length and odd signals will not round-trip\n        properly. So, it is recommended to always pass the signal shape :attr:`s`.\n\n    Args:\n        input (Tensor): the input tensor\n        s (Tuple[int], optional): Signal size in the transformed dimensions.\n            If given, each dimension ``dim[i]`` will either be zero-padded or\n            trimmed to the length ``s[i]`` before computing the real FFT.\n            If a length ``-1`` is specified, no padding is done in that dimension.\n            Defaults to even output in the last dimension:\n            ``s[-1] = 2*(input.size(dim[-1]) - 1)``.\n        dim (Tuple[int], optional): Dimensions to be transformed.\n            The last dimension must be the half-Hermitian compressed dimension.\n            Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.\n        norm (str, optional): Normalization mode. For the backward transform\n            (:func:`oneflow.fft.irfftn`), these correspond to:\n\n            * ``\"forward\"`` - no normalization\n            * ``\"backward\"`` - normalize by ``1/n``\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal)\n\n            Where ``n = prod(s)`` is the logical IFFT size.\n            Calling the forward transform (:func:`oneflow.fft.rfftn`) with the same\n            normalization mode will apply an overall normalization of ``1/n`` between\n            the two transforms. This is required to make :func:`oneflow.fft.irfftn`\n            the exact inverse.\n\n            Default is ``\"backward\"`` (normalize by ``1/n``).\n
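\n    Example:\n\n        A shape-only sketch (illustrative); passing the original signal shape\n        :attr:`s` recovers the odd last dimension that the default even-output\n        rule would otherwise miss:\n\n        >>> T = oneflow.rand(10, 9)\n        >>> t = oneflow.fft.rfftn(T)\n        >>> t.size()\n        oneflow.Size([10, 5])\n        >>> oneflow.fft.irfftn(t, T.size()).size()\n        oneflow.Size([10, 9])\n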
\n    \"\"\"\n    return flow._C.irfftn(input, s, dim, norm)\n\n\ndef hfft(input, n=None, dim=-1, norm=None) -> Tensor:\n    r\"\"\"\n    hfft(input, n=None, dim=-1, norm=None) -> Tensor\n\n    Computes the one dimensional discrete Fourier transform of a Hermitian\n    symmetric :attr:`input` signal.\n\n    Note:\n\n        :func:`oneflow.fft.hfft`/:func:`oneflow.fft.ihfft` are analogous to\n        :func:`oneflow.fft.rfft`/:func:`oneflow.fft.irfft`. The real FFT expects\n        a real signal in the time-domain and gives Hermitian symmetry in the\n        frequency-domain. The Hermitian FFT is the opposite; Hermitian symmetric in\n        the time-domain and real-valued in the frequency-domain. For this reason,\n        special care needs to be taken with the length argument :attr:`n`, in the\n        same way as with :func:`oneflow.fft.irfft`.\n\n    Note:\n        Because the signal is Hermitian in the time-domain, the result will be\n        real in the frequency domain. Note that some input frequencies must be\n        real-valued to satisfy the Hermitian property. In these cases the imaginary\n        component will be ignored. For example, any imaginary component in\n        ``input[0]`` would result in one or more complex frequency terms which\n        cannot be represented in a real output and so will always be ignored.\n\n    Note:\n        The correct interpretation of the Hermitian input depends on the length of\n        the original data, as given by :attr:`n`. This is because each input shape\n        could correspond to either an odd or even length signal. By default, the\n        signal is assumed to be even length and odd signals will not round-trip\n        properly. So, it is recommended to always pass the signal length :attr:`n`.\n\n    Args:\n        input (Tensor): the input tensor representing a half-Hermitian signal\n        n (int, optional): Output signal length. This determines the length of the\n            real output. If given, the input will either be zero-padded or trimmed to this\n            length before computing the Hermitian FFT.\n            Defaults to even output: ``n=2*(input.size(dim) - 1)``.\n        dim (int, optional): The dimension along which to take the one dimensional Hermitian FFT.\n        norm (str, optional): Normalization mode. For the forward transform\n            (:func:`oneflow.fft.hfft`), these correspond to:\n\n            * ``\"forward\"`` - normalize by ``1/n``\n            * ``\"backward\"`` - no normalization\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal)\n\n            Calling the backward transform (:func:`oneflow.fft.ihfft`) with the same\n            normalization mode will apply an overall normalization of ``1/n`` between\n            the two transforms. This is required to make :func:`oneflow.fft.ihfft`\n            the exact inverse.\n\n            Default is ``\"backward\"`` (no normalization).\n\n    Example:\n\n        Taking a real-valued frequency signal and bringing it into the time domain\n        gives Hermitian symmetric output:\n\n        >>> t = oneflow.linspace(0, 1, 5)\n        >>> t\n        tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000], dtype=oneflow.float32)\n        >>> T = oneflow.fft.ifft(t)\n        >>> T\n        tensor([ (0.5000-0.0000j), (-0.1250-0.1720j), (-0.1250-0.0406j), (-0.1250+0.0406j),\n                (-0.1250+0.1720j)], dtype=oneflow.complex64)\n\n        Note that ``T[1] == T[-1].conj()`` and ``T[2] == T[-2].conj()`` are\n        redundant. 
We can thus compute the forward transform without considering\n        negative frequencies:\n\n        >>> oneflow.fft.hfft(T[:3], n=5)\n        tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000], dtype=oneflow.float32)\n\n        Like with :func:`oneflow.fft.irfft`, the output length must be given in order\n        to recover an even length output:\n\n        >>> oneflow.fft.hfft(T[:3])\n        tensor([0.1250, 0.2809, 0.6250, 0.9691], dtype=oneflow.float32)\n    \"\"\"\n\n    if n is None:\n        n = -1\n    return flow._C.hfft(input, n, dim, norm)\n\n\ndef ihfft(input, n=None, dim=-1, norm=None) -> Tensor:\n    r\"\"\"\n\n    Computes the inverse of :func:`oneflow.fft.hfft`.\n\n    :attr:`input` must be a real-valued signal, interpreted in the Fourier domain.\n    The IFFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])``.\n    :func:`oneflow.fft.ihfft` represents this in the one-sided form where only the\n    positive frequencies below the Nyquist frequency are included. To compute the\n    full output, use :func:`oneflow.fft.ifft`.\n\n    Args:\n        input (Tensor): the real input tensor\n        n (int, optional): Signal length. If given, the input will either be zero-padded\n            or trimmed to this length before computing the Hermitian IFFT.\n        dim (int, optional): The dimension along which to take the one dimensional Hermitian IFFT.\n        norm (str, optional): Normalization mode. For the backward transform\n            (:func:`oneflow.fft.ihfft`), these correspond to:\n\n            * ``\"forward\"`` - no normalization\n            * ``\"backward\"`` - normalize by ``1/n``\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)\n\n            Calling the forward transform (:func:`oneflow.fft.hfft`) with the same\n            normalization mode will apply an overall normalization of ``1/n`` between\n            the two transforms. This is required to make :func:`oneflow.fft.ihfft`\n            the exact inverse.\n\n            Default is ``\"backward\"`` (normalize by ``1/n``).\n\n    Example:\n\n        >>> t = oneflow.arange(5)\n        >>> t\n        tensor([0, 1, 2, 3, 4], dtype=oneflow.int64)\n        >>> oneflow.fft.ihfft(t)\n        tensor([ (2.0000-0.0000j), (-0.5000-0.6882j), (-0.5000-0.1625j)], dtype=oneflow.complex64)\n\n        Compare against the full output from :func:`oneflow.fft.ifft`:\n\n        >>> oneflow.fft.ifft(t)\n        tensor([ (2.0000-0.0000j), (-0.5000-0.6882j), (-0.5000-0.1625j), (-0.5000+0.1625j),\n                (-0.5000+0.6882j)], dtype=oneflow.complex64)\n    \"\"\"\n    if n is None:\n        n = -1\n    return flow._C.ihfft(input, n, dim, norm)\n\n\ndef hfft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor:\n    r\"\"\"\n\n    Computes the 2-dimensional discrete Fourier transform of a Hermitian symmetric\n    :attr:`input` signal. Equivalent to :func:`oneflow.fft.hfftn` but only\n    transforms the last two dimensions by default.\n\n    :attr:`input` is interpreted as a one-sided Hermitian signal in the time\n    domain. 
By the Hermitian property, the Fourier transform will be real-valued.\n\n    Args:\n        input (Tensor): the input tensor\n        s (Tuple[int], optional): Signal size in the transformed dimensions.\n            If given, each dimension ``dim[i]`` will either be zero-padded or\n            trimmed to the length ``s[i]`` before computing the Hermitian FFT.\n            If a length ``-1`` is specified, no padding is done in that dimension.\n            Defaults to even output in the last dimension:\n            ``s[-1] = 2*(input.size(dim[-1]) - 1)``.\n        dim (Tuple[int], optional): Dimensions to be transformed.\n            The last dimension must be the half-Hermitian compressed dimension.\n            Default: last two dimensions.\n        norm (str, optional): Normalization mode. For the forward transform\n            (:func:`oneflow.fft.hfft2`), these correspond to:\n\n            * ``\"forward\"`` - normalize by ``1/n``\n            * ``\"backward\"`` - no normalization\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal)\n\n            Where ``n = prod(s)`` is the logical FFT size.\n            Calling the backward transform (:func:`oneflow.fft.ihfft2`) with the same\n            normalization mode will apply an overall normalization of ``1/n`` between\n            the two transforms. This is required to make :func:`oneflow.fft.ihfft2`\n            the exact inverse.\n\n            Default is ``\"backward\"`` (no normalization).\n\n    Example:\n\n        Starting from a real frequency-space signal, we can generate a\n        Hermitian-symmetric time-domain signal:\n\n        >>> T = oneflow.rand(10, 9)\n        >>> t = oneflow.fft.ihfft2(T)\n\n        Without specifying the output length to :func:`oneflow.fft.hfftn`, the\n        output will not round-trip properly because the input is odd-length in the\n        last dimension:\n\n        >>> oneflow.fft.hfft2(t).size()\n        oneflow.Size([10, 10])\n\n        So, it is recommended to always pass the signal shape :attr:`s`:\n\n        >>> roundtrip = oneflow.fft.hfft2(t, T.size())\n        >>> roundtrip.size()\n        oneflow.Size([10, 9])\n        >>> oneflow.allclose(roundtrip, T)\n        True\n\n    \"\"\"\n    return flow._C.hfft2(input, s, dim, norm)\n\n\ndef ihfft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor:\n    r\"\"\"\n\n    Computes the 2-dimensional inverse discrete Fourier transform of real\n    :attr:`input`. Equivalent to :func:`oneflow.fft.ihfftn` but transforms only the\n    last two dimensions by default.\n\n    Args:\n        input (Tensor): the input tensor\n        s (Tuple[int], optional): Signal size in the transformed dimensions.\n            If given, each dimension ``dim[i]`` will either be zero-padded or\n            trimmed to the length ``s[i]`` before computing the Hermitian IFFT.\n            If a length ``-1`` is specified, no padding is done in that dimension.\n            Default: ``s = [input.size(d) for d in dim]``\n        dim (Tuple[int], optional): Dimensions to be transformed.\n            Default: last two dimensions.\n        norm (str, optional): Normalization mode. For the backward transform\n            (:func:`oneflow.fft.ihfft2`), these correspond to:\n\n            * ``\"forward\"`` - no normalization\n            * ``\"backward\"`` - normalize by ``1/n``\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the Hermitian IFFT orthonormal)\n\n            Where ``n = prod(s)`` is the logical IFFT size.\n            Calling the forward transform (:func:`oneflow.fft.hfft2`) with the same\n            normalization mode will apply an overall normalization of ``1/n`` between\n            the two transforms. This is required to make :func:`oneflow.fft.ihfft2`\n            the exact inverse.\n\n            Default is ``\"backward\"`` (normalize by ``1/n``).\n
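\n    Example:\n\n        A shape-only sketch (illustrative); the one-sided output keeps\n        ``input.size(-1) // 2 + 1`` entries in the last dimension:\n\n        >>> T = oneflow.rand(10, 9)\n        >>> oneflow.fft.ihfft2(T).size()\n        oneflow.Size([10, 5])\n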
\n    \"\"\"\n    return flow._C.ihfft2(input, s, dim, norm)\n\n\ndef hfftn(input, s=None, dim=None, norm=None) -> Tensor:\n    r\"\"\"\n\n    Computes the N-dimensional discrete Fourier transform of a Hermitian symmetric\n    :attr:`input` signal.\n\n    :attr:`input` is interpreted as a one-sided Hermitian signal in the time\n    domain. By the Hermitian property, the Fourier transform will be real-valued.\n\n    Note:\n        :func:`oneflow.fft.hfftn`/:func:`oneflow.fft.ihfftn` are analogous to\n        :func:`oneflow.fft.rfftn`/:func:`oneflow.fft.irfftn`. The real FFT expects\n        a real signal in the time-domain and gives Hermitian symmetry in the\n        frequency-domain. The Hermitian FFT is the opposite; Hermitian symmetric in\n        the time-domain and real-valued in the frequency-domain. For this reason,\n        special care needs to be taken with the shape argument :attr:`s`, in the\n        same way as with :func:`oneflow.fft.irfftn`.\n\n    Note:\n        Some input frequencies must be real-valued to satisfy the Hermitian\n        property. In these cases the imaginary component will be ignored.\n        For example, any imaginary component in the zero-frequency term cannot\n        be represented in a real output and so will always be ignored.\n\n    Note:\n        The correct interpretation of the Hermitian input depends on the length of\n        the original data, as given by :attr:`s`. This is because each input shape\n        could correspond to either an odd or even length signal. By default, the\n        signal is assumed to be even length and odd signals will not round-trip\n        properly. It is recommended to always pass the signal shape :attr:`s`.\n\n    Args:\n        input (Tensor): the input tensor\n        s (Tuple[int], optional): Signal size in the transformed dimensions.\n            If given, each dimension ``dim[i]`` will either be zero-padded or\n            trimmed to the length ``s[i]`` before computing the real FFT.\n            If a length ``-1`` is specified, no padding is done in that dimension.\n            Defaults to even output in the last dimension:\n            ``s[-1] = 2*(input.size(dim[-1]) - 1)``.\n        dim (Tuple[int], optional): Dimensions to be transformed.\n            The last dimension must be the half-Hermitian compressed dimension.\n            Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.\n        norm (str, optional): Normalization mode. 
For the forward transform\n            (:func:`oneflow.fft.hfftn`), these correspond to:\n\n            * ``\"forward\"`` - normalize by ``1/n``\n            * ``\"backward\"`` - no normalization\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal)\n\n            Where ``n = prod(s)`` is the logical FFT size.\n            Calling the backward transform (:func:`oneflow.fft.ihfftn`) with the same\n            normalization mode will apply an overall normalization of ``1/n`` between\n            the two transforms. This is required to make :func:`oneflow.fft.ihfftn`\n            the exact inverse.\n\n            Default is ``\"backward\"`` (no normalization).\n\n    \"\"\"\n    return flow._C.hfftn(input, s, dim, norm)\n\n\ndef ihfftn(input, s=None, dim=None, norm=None) -> Tensor:\n    r\"\"\"\n\n    Computes the N-dimensional inverse discrete Fourier transform of real :attr:`input`.\n\n    :attr:`input` must be a real-valued signal, interpreted in the Fourier domain.\n    The n-dimensional IFFT of a real signal is Hermitian-symmetric,\n    ``X[i, j, ...] = conj(X[-i, -j, ...])``. :func:`oneflow.fft.ihfftn` represents\n    this in the one-sided form where only the positive frequencies below the\n    Nyquist frequency are included in the last signal dimension. To compute the\n    full output, use :func:`oneflow.fft.ifftn`.\n\n    Args:\n        input (Tensor): the input tensor\n        s (Tuple[int], optional): Signal size in the transformed dimensions.\n            If given, each dimension ``dim[i]`` will either be zero-padded or\n            trimmed to the length ``s[i]`` before computing the Hermitian IFFT.\n            If a length ``-1`` is specified, no padding is done in that dimension.\n            Default: ``s = [input.size(d) for d in dim]``\n        dim (Tuple[int], optional): Dimensions to be transformed.\n            Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.\n        norm (str, optional): Normalization mode. For the backward transform\n            (:func:`oneflow.fft.ihfftn`), these correspond to:\n\n            * ``\"forward\"`` - no normalization\n            * ``\"backward\"`` - normalize by ``1/n``\n            * ``\"ortho\"`` - normalize by ``1/sqrt(n)`` (making the Hermitian IFFT orthonormal)\n\n            Where ``n = prod(s)`` is the logical IFFT size.\n            Calling the forward transform (:func:`oneflow.fft.hfftn`) with the same\n            normalization mode will apply an overall normalization of ``1/n`` between\n            the two transforms. This is required to make :func:`oneflow.fft.ihfftn`\n            the exact inverse.\n\n            Default is ``\"backward\"`` (normalize by ``1/n``).\n\n    \"\"\"\n    return flow._C.ihfftn(input, s, dim, norm)\n"
  },
  {
    "path": "python/oneflow/framework/__init__.py",
    "content": ""
  },
  {
    "path": "python/oneflow/framework/args_tree.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Union, List, Tuple, Dict, Callable\nfrom collections import OrderedDict\nfrom oneflow.framework.tensor import Tensor\n\n\ndef _is_raw_type(value, raw_type):\n    # Special case for namedtuple return types\n    # For example, max(x, dim=1) return oneflow.return_types.max(values=..., indices=...)\n    if (\n        raw_type == tuple\n        and isinstance(value, tuple)\n        and type(value).__module__ == \"oneflow.return_types\"\n    ):\n        return True\n    return type(value) is raw_type\n\n\nclass NamedArg(object):\n    r\"\"\"\n    The class for wrapping over the input/output argument and associating each input/output argument with a prefix and name.\n    The input/output argument can be viewed as a tree. NamedArg basically wraps over each tree node on this tree.\n    The recursive structure of the input/output arguments are kept, for example:\n\n    input = [1, {key: \"value\" }] will be constructed into: \n        \n    named_input = NamedArg([NamedArg(1), NamedArg({key: NamedArg(\"value\")})])\n    \"\"\"\n\n    def __init__(\n        self, prefix=\"\", name=None, global_index=0, tensor_type=Tensor\n    ) -> None:\n        self._name = name if name is not None else str(global_index)\n        self._prefix = prefix\n        self._global_index = global_index\n        self._is_value_set = False\n        self._value = None\n        self._tensor_type = tensor_type\n\n    def prefix(self):\n        return self._prefix\n\n    def name(self):\n        return self._name\n\n    def global_index(self):\n        return self._global_index\n\n    def value(self):\n        assert self._is_value_set, \"self._value is not set yet\"\n        return self._value\n\n    def is_leaf(self):\n        assert self._is_value_set, \"self._value is not set yet\"\n        return not (\n            _is_raw_type(self._value, dict)\n            or _is_raw_type(self._value, OrderedDict)\n            or _is_raw_type(self._value, tuple)\n            or _is_raw_type(self._value, list)\n        )\n\n    def set_value(self, value):\n        assert not _is_raw_type(value, NamedArg), \"cannot accept value of type NamedArg\"\n        self._value = value\n        self._is_value_set = True\n\n    def __repr__(self):\n        repr_str = \"\"\n        repr_str += \"(name: \" + self._name\n        repr_str += \", idx: \" + str(self._global_index)\n        repr_str += \", type: \"\n        if _is_raw_type(self._value, tuple):\n            repr_str += \"TUPLE\"\n        elif _is_raw_type(self._value, list):\n            repr_str += \"LIST\"\n        elif _is_raw_type(self._value, dict) or _is_raw_type(self._value, OrderedDict):\n            repr_str += \"DICT\"\n        elif isinstance(self._value, self._tensor_type):\n            repr_str += \"TENSOR\"\n        elif self._value is None:\n            repr_str += \"NONE\"\n        else:\n            repr_str += \"OPAQUE\"\n\n        if 
isinstance(self._value, self._tensor_type):\n            repr_str += (\n                \", value: tensor(\"\n                + str(self._value.shape)\n                + \", \"\n                + str(self._value.dtype)\n                + \")\"\n            )\n        elif (\n            _is_raw_type(self._value, dict)\n            or _is_raw_type(self._value, OrderedDict)\n            or _is_raw_type(self._value, list)\n            or _is_raw_type(self._value, tuple)\n        ):\n            repr_str += \", value: \" + repr(self._value)\n        else:\n            repr_str += \", value: \" + repr(self._value)\n        repr_str += \")\"\n        return repr_str\n\n\nclass ArgsTree(object):\n    def __init__(\n        self,\n        io_args: Union[Tuple, List, Dict],\n        gen_name: bool = False,\n        root_prefix: str = \"\",\n        root_name: str = None,\n        tensor_type=Tensor,\n    ) -> None:\n\n        self._io_args = io_args\n        self._gen_name = gen_name\n        self._root_prefix = root_prefix\n        self._root_name = root_name\n        self._named_io_args = None\n        self._next_global_index = 0\n        self._tensor_type = tensor_type\n\n        if self._gen_name:\n            self._named_io_args = self._construct_named_io_args(\n                self._io_args, self._root_prefix, self._root_name\n            )\n\n    def gen_name(self):\n        return self._gen_name\n\n    def iter_nodes(self):\n        r\"\"\"\n        return a generator of the args tree nodes in the DFS manner. \n        The node returned can be of type NamedArg or non-NamedArg depending on whether gen_name is set. \n        If gen_name is set, the node will be NamedArg. \n        \"\"\"\n\n        if self._gen_name:\n            args_to_iter = self._named_io_args\n        else:\n            args_to_iter = self._io_args\n\n        # NOTE(lixiang): Generator expression and iterator are used.\n        #   This avoids generating the full list in memory and only processes the nodes that need to be processed,\n        #   reducing time and space consumption.\n        stack = [iter([args_to_iter])]\n        while len(stack) > 0:\n            try:\n                curr = next(stack[-1])\n                if _is_raw_type(curr, NamedArg):\n                    curr_value = curr.value()\n                else:\n                    curr_value = curr\n\n                if _is_raw_type(curr_value, list) or _is_raw_type(curr_value, tuple):\n                    children = curr_value\n                elif _is_raw_type(curr_value, dict) or _is_raw_type(\n                    curr_value, OrderedDict\n                ):\n                    children = curr_value.values()\n                else:\n                    children = None\n\n                if children:\n                    stack.append(iter(children))\n\n                yield curr\n\n            except StopIteration:\n                stack.pop()\n\n    def iter_named_nodes(self):\n        assert self._gen_name, \"Only use this if gen_name is set!\"\n        for named_node in self.iter_nodes():\n            yield (named_node.prefix() + \"_\" + named_node.name(), named_node)\n\n    def _construct_named_io_args(self, value, prefix: str, name: str) -> NamedArg:\n        arg = NamedArg(prefix, name, self._next_global_index, self._tensor_type)\n        self._next_global_index += 1\n\n        if _is_raw_type(value, list) or _is_raw_type(value, tuple):\n\n            def construct_func(enum):\n                (i, v) = enum\n                next_prefix = prefix + 
(\".\" if prefix else \"\") + str(i)\n                new_arg = self._construct_named_io_args(v, next_prefix, None)\n                return new_arg\n\n            arg.set_value(value.__class__(map(construct_func, enumerate(value))))\n\n        elif _is_raw_type(value, dict) or _is_raw_type(value, OrderedDict):\n\n            def construct_func(enum):\n                i, (key, v) = enum\n                next_prefix = prefix + (\".\" if prefix else \"\") + str(i)\n                new_arg = self._construct_named_io_args(v, next_prefix, key)\n                return key, new_arg\n\n            arg.set_value(\n                value.__class__(map(construct_func, enumerate(value.items())))\n            )\n        else:\n            arg.set_value(value)\n\n        return arg\n\n    def map_tuple_leaf(self, map_function: Callable):\n        r\"\"\"\n        When the type of io args is tuple or list, map the leaf of the arguments into map_function(leaf).\n        \"\"\"\n        assert map_function != None, \"map function cannot be None\"\n        assert isinstance(\n            self._io_args, (tuple, list)\n        ), \"only used when io args is a tuple or list of tensors\"\n\n        stack = []\n\n        # Cases handled: tuple(tensor, ...), such as input args.\n        if len(self._io_args) > 0 and isinstance(self._io_args[0], self._tensor_type):\n            for i in self._io_args:\n                mapped_value = map_function(i)\n                stack.append(mapped_value)\n\n            if isinstance(self._io_args, tuple):\n                return tuple(stack)\n            elif isinstance(self._io_args, list):\n                return stack\n\n        # Cases handled: tuple(tuple(tensor, ...), ), such as the output args of return.\n        elif (\n            len(self._io_args) > 0\n            and isinstance(self._io_args[0], (tuple, list))\n            and all(isinstance(arg, self._tensor_type) for arg in self._io_args[0])\n        ):\n            for i in self._io_args[0]:\n                mapped_value = map_function(i)\n                stack.append(mapped_value)\n\n            if isinstance(self._io_args[0], tuple):\n                return (tuple(stack),)\n            elif isinstance(self._io_args[0], list):\n                return (stack,)\n\n        # Other cases.\n        # Do not loop optimize, and continue to execute the recursive code (`_execute_mapping`).\n        else:\n            return self._execute_mapping(self._io_args, map_function)\n\n    def map_leaf(self, map_function: Callable):\n        r\"\"\"\n        Map the leaf of the arguments into map_function(leaf).\n        \"\"\"\n        assert map_function != None, \"map function cannot be None\"\n\n        if self._gen_name:\n            args_to_map = self._named_io_args\n        else:\n            args_to_map = self._io_args\n\n        return self._execute_mapping(args_to_map, map_function)\n\n    def _execute_mapping(self, value, map_function):\n        if _is_raw_type(value, tuple) or _is_raw_type(value, list):\n            mapped_value = value.__class__(\n                map(lambda x: self._execute_mapping(x, map_function), value)\n            )\n        elif _is_raw_type(value, dict) or _is_raw_type(value, OrderedDict):\n            mapped_value = value.__class__(\n                map(\n                    lambda x: (x[0], self._execute_mapping(x[1], map_function)),\n                    value.items(),\n                )\n            )\n        elif _is_raw_type(value, NamedArg):\n            if value.is_leaf():  # only map 
the leaf: TENSOR/NONE/OPAQUE\n                mapped_value = map_function(value)\n            else:\n                mapped_value = self._execute_mapping(value.value(), map_function)\n        else:\n            mapped_value = map_function(value)\n\n        return mapped_value\n\n    def __repr__(self):\n        if self._named_io_args:\n            return self._named_io_args.__repr__()\n        else:\n            return str(self.__class__)\n"
  },
  {
    "path": "python/oneflow/framework/attr_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nr\"\"\"\nGet the nested attribute given the owning object and attribute chain.\n\nFor example, if we want to get `resource.collective_boxing_conf.nccl_num_streams`\n\nwe can call `get_nested_attribute(resource, [\"collective_boxing_conf\", \"nccl_num_streams\"])\n\"\"\"\n\n\ndef get_nested_attribute(owning_object, attrs_chain):\n    if not isinstance(attrs_chain, list):\n        if isinstance(attrs_chain, str):\n            attrs_chain = [attrs_chain]\n        else:\n            assert False, (\n                \"attrs_chain should be either a string or a list, but get \"\n                + str(type(attrs_chain))\n            )\n\n    last_attr = owning_object\n    for att in attrs_chain:\n        assert hasattr(last_attr, att), (\n            repr(last_attr) + \" does not have attribute \" + att + \" !\"\n        )\n        last_attr = getattr(last_attr, att)\n    return last_attr\n\n\ndef SetProtoAttrValue(attr_value, py_value, default_attr_value):\n    if default_attr_value.HasField(\"at_bool\"):\n        if py_value is None:\n            py_value = True\n        assert type(py_value) is bool\n        attr_value.at_bool = py_value\n    elif default_attr_value.HasField(\"at_int64\"):\n        assert type(py_value) is int\n        attr_value.at_int64 = py_value\n    elif default_attr_value.HasField(\"at_double\"):\n        assert type(py_value) is float\n        attr_value.at_double = py_value\n    elif default_attr_value.HasField(\"at_string\"):\n        assert type(py_value) is str\n        attr_value.at_string = py_value\n    else:\n        raise ValueError(\n            \"config with type %s is invalid. supported types: [bool, int, float, str]\"\n            % type(py_value)\n        )\n"
  },
  {
    "path": "python/oneflow/framework/balanced_splitter.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n\ndef BalancedPartNums(total, part_size):\n    base = int(total / part_size)\n    remainder = total % part_size\n    return [base + int(i < remainder) for i in range(part_size)]\n\n\ndef BalancedRanges(total, part_size):\n    balanced_part_nums = BalancedPartNums(total, part_size)\n    ranges = []\n    start = 0\n    for part_num in balanced_part_nums:\n        end = start + part_num\n        ranges.append((start, end))\n        start = end\n    return ranges\n"
  },
  {
    "path": "python/oneflow/framework/c_api_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom google.protobuf import text_format\n\nimport oneflow\nimport oneflow.core.common.data_type_pb2 as dtype_util\nimport oneflow.core.common.error_pb2 as error_util\nimport oneflow.core.job.env_pb2 as env_pb2\nimport oneflow.core.job.job_pb2 as job_pb\nimport oneflow.core.job.job_conf_pb2 as job_conf_pb\nimport oneflow.core.job.job_set_pb2 as job_set_pb\nimport oneflow.core.job.placement_pb2 as placement_pb\nimport oneflow.core.job.resource_pb2 as resource_util\nimport oneflow.core.operator.op_attribute_pb2 as op_attribute_pb\nimport oneflow.core.operator.op_conf_pb2 as op_conf_util\nimport oneflow.core.record.record_pb2 as record_util\nimport oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_util\nfrom oneflow.core.framework.config_def_pb2 import ConfigDef\nfrom oneflow.core.job.inter_user_job_info_pb2 import InterUserJobInfo\n\n\ndef CurrentResource():\n    resource = oneflow._oneflow_internal.CurrentResource()\n    return text_format.Parse(resource, resource_util.Resource())\n\n\ndef EnvResource():\n    resource = oneflow._oneflow_internal.EnvResource()\n    return text_format.Parse(resource, resource_util.Resource())\n\n\ndef GetEnvContext(env_proto):\n    assert type(env_proto) is env_pb2.EnvProto\n    env_proto_str = text_format.MessageToString(env_proto)\n    env_ctx = oneflow._oneflow_internal.EnvContext(env_proto_str)\n    return env_ctx\n\n\ndef JobBuildAndInferCtx_Open(job_name):\n    job_name = str(job_name)\n    oneflow._oneflow_internal.JobBuildAndInferCtx_Open(job_name)\n\n\ndef CurJobBuildAndInferCtx_SetJobConf(job_config_proto):\n    assert type(job_config_proto) is job_conf_pb.JobConfigProto, type(job_config_proto)\n    job_config_proto_str = text_format.MessageToString(job_config_proto)\n    oneflow._oneflow_internal.CurJobBuildAndInferCtx_SetJobConf(job_config_proto_str)\n\n\ndef InferOpConf(op_conf_proto, upstream_signature):\n    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))\n    serialized_upstream_sig = str(text_format.MessageToString(upstream_signature))\n    op_attribute_str = oneflow._oneflow_internal.InferOpConf(\n        serialized_op_conf, serialized_upstream_sig\n    )\n    return text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())\n\n\ndef IsInterfaceOpConf(op_conf):\n    op_type_field = op_conf.WhichOneof(\"op_type\")\n    field_number = op_conf_util.OperatorConf.DESCRIPTOR.fields_by_name[\n        op_type_field\n    ].number\n    return oneflow._oneflow_internal.IsInterfaceOpTypeCase(field_number)\n\n\ndef GetOpParallelSymbolId(op_conf_proto):\n    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))\n    return oneflow._oneflow_internal.GetOpParallelSymbolId(serialized_op_conf)\n\n\ndef CheckAndCompleteUserOpConf(op_conf_proto):\n    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))\n    new_op_conf = 
oneflow._oneflow_internal.CheckAndCompleteUserOpConf(\n        serialized_op_conf\n    )\n    return text_format.Parse(new_op_conf, op_conf_util.OperatorConf())\n\n\ndef GetFunctionConfigDef():\n    func_config_def = oneflow._oneflow_internal.GetFunctionConfigDef()\n    return text_format.Parse(func_config_def, ConfigDef())\n\n\ndef GetScopeConfigDef():\n    scope_config_def = oneflow._oneflow_internal.GetScopeConfigDef()\n    return text_format.Parse(scope_config_def, ConfigDef())\n\n\ndef GetCurrentJob():\n    serialized_job = oneflow._oneflow_internal.GetSerializedCurrentJob()\n    ret = job_pb.Job()\n    ret.ParseFromString(serialized_job)\n    return ret\n"
  },
  {
    "path": "python/oneflow/framework/check_point_v2.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport contextlib\nimport os\nimport warnings\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n    Iterable,\n    List,\n    Optional,\n    Sequence,\n    Tuple,\n    Union,\n    IO,\n    BinaryIO,\n)\nfrom pathlib import Path\nimport pickle\nimport json\nfrom collections import OrderedDict\nimport io\n\nimport numpy as np\nfrom google.protobuf import text_format\n\nimport oneflow\nimport oneflow as flow\nimport oneflow._oneflow_internal\nimport oneflow.core.framework.variable_meta_info_pb2 as variable_meta_info_pb\nimport oneflow.framework.dtype as dtype_util\nimport oneflow.framework.id_util as id_util\nfrom oneflow.framework.tensor import Tensor\nimport oneflow.nn.graph.graph as graph_util\nfrom oneflow.framework.args_tree import ArgsTree\nimport pickle\nfrom oneflow.nn.graph import GraphTensor\n\nSNAPSHOT_DONE_FILENAME = \"snapshot_done\"\nMETA_INFO_FILENAME = \"meta\"\nPICKLE_FILENAME = \"pickled_data\"\nDATA_FILENAME = \"out\"\nPROTOCOL_VERSION = 1\nONEFLOW_MAGIC_KEY = \"__oneflow__\"\n\nMAP_LOCATION = Optional[\n    Union[Callable[[Tensor, str], Tensor], flow.device, str, flow.placement]\n]\nFILE_LIKE = Union[os.PathLike, BinaryIO, IO[bytes], Path]\n\n\nclass _opener(object):\n    def __init__(self, file_like):\n        self.file_like = file_like\n\n    def __enter__(self):\n        return self.file_like\n\n    def __exit__(self, *args):\n        pass\n\n\nclass _open_file(_opener):\n    def __init__(self, path, mode):\n        super(_open_file, self).__init__(open(path, mode))\n\n    def __exit__(self, *args):\n        self.file_like.close()\n\n\nclass _open_buffer_reader(_opener):\n    def __init__(self, buffer):\n        super(_open_buffer_reader, self).__init__(buffer)\n        _check_seekable(buffer)\n\n\nclass _open_buffer_writer(_opener):\n    def __exit__(self, *args):\n        self.file_like.flush()\n\n\ndef _open_file_like(path_or_buffer, mode):\n    if _is_path(path_or_buffer):\n        return _open_file(path_or_buffer, mode)\n    else:\n        if \"w\" in mode:\n            return _open_buffer_writer(path_or_buffer)\n        elif \"r\" in mode:\n            return _open_buffer_reader(path_or_buffer)\n        else:\n            raise RuntimeError(f\"Expected 'r' or 'w' in mode but got {mode}\")\n\n\ndef _is_path(path_or_buffer):\n    return isinstance(path_or_buffer, Path)\n\n\ndef _check_seekable(f) -> bool:\n    def raise_err_msg(patterns, e):\n        for p in patterns:\n            if p in str(e):\n                msg = (\n                    str(e)\n                    + \". 
You can only oneflow.load from a file that is seekable.\"\n                    + \" Please pre-load the data into a buffer like io.BytesIO and\"\n                    + \" try to load from it instead.\"\n                )\n                raise type(e)(msg)\n        raise e\n\n    try:\n        f.seek(f.tell())\n        return True\n    except (io.UnsupportedOperation, AttributeError) as e:\n        raise_err_msg([\"seek\", \"tell\"], e)\n    return False\n\n\nclass FileBackendVariableBlob:\n    def __init__(\n        self,\n        var_dir: str,\n        dtype: Optional[oneflow.dtype] = None,\n        shape: Optional[Sequence[int]] = None,\n    ):\n        data_path = os.path.join(var_dir, DATA_FILENAME)\n        if not os.path.isfile(data_path):\n            raise FileNotFoundError()\n        self.var_dir_ = var_dir\n        meta_info_path = os.path.join(self.var_dir_, META_INFO_FILENAME)\n        if os.path.exists(meta_info_path):\n            meta_info = variable_meta_info_pb.VariableMetaInfo()\n            with open(meta_info_path) as f:\n                text_format.Parse(f.read(), meta_info)\n            self.has_meta_info_ = True\n        else:\n            self.has_meta_info_ = False\n        if self.has_meta_info_:\n            assert dtype is None and shape is None\n            self.shape_ = tuple(meta_info.shape.dim)\n            self.dtype_ = dtype_util.convert_proto_dtype_to_oneflow_dtype(\n                meta_info.data_type\n            )\n        elif shape is not None and dtype is not None:\n            self.shape_ = shape\n            self.dtype_ = dtype\n            self.has_meta_info_ = True\n        elif shape is not None or dtype is not None:\n            raise RuntimeError(\"both or neither of shape and dtype should be None\")\n        else:\n            pass\n        if self.has_meta_info_:\n            itemsize = np.dtype(\n                dtype_util.convert_oneflow_dtype_to_numpy_dtype(self.dtype_)\n            ).itemsize\n            assert os.path.getsize(data_path) == np.prod(self.shape).item() * itemsize\n\n    @property\n    def file_path(self) -> str:\n        return os.path.join(self.var_dir_, DATA_FILENAME)\n\n    @property\n    def shape(self) -> Tuple[int]:\n        return self.shape_\n\n    @property\n    def quant_info(self):\n        raise NotImplementedError()\n\n    @property\n    def dtype(self) -> oneflow.dtype:\n        return self.dtype_\n\n    def numpy(self) -> np.ndarray:\n        if not self.has_meta_info_:\n            raise RuntimeError(\"This variable does not have meta info\")\n        return np.fromfile(\n            self.file_path,\n            dtype=dtype_util.convert_oneflow_dtype_to_numpy_dtype(self.dtype),\n        ).reshape(self.shape)\n\n\ndef _save_tensor_to_disk(tensor: \"oneflow.Tensor\", dir_name: Union[str, Path]) -> None:\n    os.makedirs(dir_name, exist_ok=True)\n    meta_info = variable_meta_info_pb.VariableMetaInfo()\n    meta_info.shape.dim[:] = tensor.shape\n    meta_info.data_type = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(\n        tensor.dtype\n    )\n    data_path = os.path.join(dir_name, DATA_FILENAME)\n    with open(data_path, \"wb\") as f:\n        f.write(tensor.numpy().tobytes())\n\n    with open(os.path.join(dir_name, META_INFO_FILENAME), \"w\") as f:\n        f.write(text_format.MessageToString(meta_info))\n\n\nValueContainer = Union[FileBackendVariableBlob, np.ndarray, \"oneflow.Tensor\"]\n\n\ndef _default_restore_location(storage, location=None):\n    return storage\n\n\ndef 
_get_restore_location(map_location):\n    if map_location is None:\n        restore_location = _default_restore_location\n    elif isinstance(map_location, (str, flow.device)):\n\n        def restore_location(storage, location=None):\n            return storage.to(device=map_location)\n\n    elif isinstance(map_location, flow.placement):\n\n        def restore_location(storage, location=None):\n            return storage.to_global(placement=map_location)\n\n    else:\n\n        def restore_location(storage, location=None):\n            result = map_location(storage, location)\n            if result is None:\n                result = _default_restore_location(storage, location)\n            return result\n\n    return restore_location\n\n\ndef smart_to(obj: Any, dest: MAP_LOCATION) -> \"oneflow.Tensor\":\n    if not isinstance(obj, flow.Tensor):\n        return obj\n    tensor = obj\n    restore_location = _get_restore_location(dest)\n    return restore_location(tensor, None)\n\n\ndef module_to(obj: flow.nn.Module, dest: MAP_LOCATION) -> \"oneflow.nn.Module\":\n    restore_location = _get_restore_location(dest)\n    # for nn.Module object, we will use a tensor to get the device\n    # to support dest with a Callable type\n    device = restore_location(flow.tensor([0])).device\n    obj.to(device)\n    return obj\n\n\ndef _map_location(obj: Any, map_location: MAP_LOCATION):\n    if isinstance(obj, flow.nn.Module):\n        return module_to(obj, map_location)\n    else:\n        res = ArgsTree(obj).map_leaf(lambda x: smart_to(x, map_location))\n        return res\n\n\ndef _LoadSingleVariable(\n    path: Optional[str],\n    global_src_rank: Optional[int] = None,\n    map_location: MAP_LOCATION = None,\n) -> \"flow.Tensor\":\n    if global_src_rank is not None:\n        rank = flow.env.get_rank()\n        if rank == global_src_rank:\n            file_backed_blob = FileBackendVariableBlob(path)\n            loaded = flow.tensor(file_backed_blob.numpy(), dtype=file_backed_blob.dtype)\n        else:\n            loaded = flow.tensor([])\n        loaded = loaded.to_global(\n            flow.placement(\"cpu\", [global_src_rank]), flow.sbp.broadcast\n        )\n    else:\n        loaded = flow.tensor(FileBackendVariableBlob(path).numpy())\n    return smart_to(loaded, map_location)\n\n\ndef _broadcast_py_object(obj, src: int = 0):\n    rank = flow.env.get_rank()\n    if src == rank:\n        obj_bytes = pickle.dumps(obj)\n        return pickle.loads(flow._oneflow_internal.cpu_broadcast(obj_bytes, src))\n    else:\n        return pickle.loads(flow._oneflow_internal.cpu_broadcast(None, src))\n\n\n# NOTE(jianhao):\n# (de)serializing a container of global tensors requires the order\n# of those tensors are the same across all ranks.\ndef tensor_getstate(self):\n    # context_data is not None means setstate/getstate is called inside\n    # flow.save or flow.load\n    if context_data is not None:\n        if context_data.global_rank is None:\n            assert (\n                self.is_local\n            ), \"Please set global_dst_rank in `flow.save` to save global tensor\"\n            tensor = self\n        else:\n            assert not self.is_local\n            # Boxing to cpu firstly to avoid extra gpu memory usage\n            tensor = (\n                self.to_global(\n                    sbp=self.sbp, placement=flow.placement(\"cpu\", self.placement.ranks)\n                )\n                .to_global(\n                    sbp=flow.sbp.broadcast,\n                    
placement=flow.placement(\"cpu\", [context_data.global_rank]),\n                )\n                .to_local()\n            )\n        if context_data.save_as_external_data:\n            if context_data.global_rank is None:\n                rel_dir_name = id_util.UniqueStr(\"tensor_\")\n            else:\n                rel_dir_name = f\"global_tensor_{self.global_id()}\"\n            abs_dir_name = context_data.path / rel_dir_name\n\n            if (\n                context_data.global_rank is None\n                or context_data.global_rank == flow.env.get_rank()\n            ):\n                _save_tensor_to_disk(tensor, abs_dir_name)\n\n            return {\"path\": rel_dir_name}\n        else:\n            return {\n                \"data\": tensor.numpy(),\n                \"dtype\": tensor.dtype,\n                \"device\": \"cpu\",\n            }\n    else:\n        if self.is_local:\n            if self.is_cuda:\n                device = \"cuda\"\n            else:\n                device = \"cpu\"\n            return {\"data\": self.numpy(), \"dtype\": self.dtype, \"device\": device}\n        else:\n            return {\n                \"data\": self.numpy(),\n                \"dtype\": self.dtype,\n                \"placement\": self.placement,\n                \"sbp\": self.sbp,\n            }\n\n\ndef tensor_setstate(self, pickle_dict):\n    if context_data is not None:\n        if context_data.save_as_external_data:\n            rel_dir_name = pickle_dict[\"path\"]\n            abs_dir_name = context_data.path / rel_dir_name\n            tmp_tensor = _LoadSingleVariable(\n                str(abs_dir_name), context_data.global_rank, context_data.map_location\n            )\n            self.__init__(tmp_tensor)\n        else:\n            self.__init__(flow.tensor(pickle_dict[\"data\"], dtype=pickle_dict[\"dtype\"]))\n    else:\n        if \"placement\" in pickle_dict:\n            return self.__init__(\n                flow.tensor(\n                    pickle_dict[\"data\"],\n                    dtype=pickle_dict[\"dtype\"],\n                    placement=pickle_dict[\"placement\"],\n                    sbp=pickle_dict[\"sbp\"],\n                )\n            )\n        else:\n            return self.__init__(\n                flow.tensor(\n                    pickle_dict[\"data\"],\n                    dtype=pickle_dict[\"dtype\"],\n                    device=pickle_dict[\"device\"],\n                )\n            )\n\n\ndef placement_getstate(self):\n    return {\n        \"type\": self.type,\n        \"ranks\": self.ranks,\n    }\n\n\ndef placement_setstate(self, state):\n    return self.__init__(state[\"type\"], state[\"ranks\"])\n\n\ndef RegisterMethods():\n    Tensor.__setstate__ = tensor_setstate\n    Tensor.__getstate__ = tensor_getstate\n    flow._oneflow_internal.placement.__getstate__ = placement_getstate\n    flow._oneflow_internal.placement.__setstate__ = placement_setstate\n\n\nload_methods = []\n\n\ndef load_if(condition):\n    def decorator(func):\n        def condition_always_returning_extra_data(*args, **kwargs):\n            res = condition(*args, **kwargs)\n            if isinstance(res, tuple):\n                assert len(res) == 2\n                assert isinstance(res[1], tuple)\n                return res\n            else:\n                return res, ()\n\n        load_methods.append((condition_always_returning_extra_data, func))\n        return func\n\n    return decorator\n\n\ndef is_dir_and_no_pickle_file(path: FILE_LIKE, 
support_pytorch_format: bool):\n    if _is_path(path) and path.is_dir():\n        pickle_path = path / PICKLE_FILENAME\n        return not pickle_path.exists()\n    return False\n\n\n@load_if(is_dir_and_no_pickle_file)\ndef legacy_load(\n    path: Path, global_src_rank: Optional[int], map_location: MAP_LOCATION,\n) -> Dict[str, \"flow.Tensor\"]:\n    assert os.path.isdir(path), \"Directory {} doesn't exist!\".format(path)\n    rank = flow.env.get_rank()\n    var_dict = {}\n    if global_src_rank is None or rank == global_src_rank:\n        all_files = os.listdir(path)\n        assert SNAPSHOT_DONE_FILENAME in all_files\n        all_files.remove(SNAPSHOT_DONE_FILENAME)\n        if global_src_rank is not None:\n            _broadcast_py_object(all_files, global_src_rank)\n    else:\n        all_files = _broadcast_py_object(None, global_src_rank)\n    for f in all_files:\n        var_dir = os.path.join(path, f)\n        try:\n            var_dict[f] = _LoadSingleVariable(var_dir, global_src_rank, map_location)\n        except FileNotFoundError:\n            warnings.warn(\n                f\"'{var_dir}' does not have valid tensor data. Please check it if it is unexpected.\",\n                stacklevel=2,\n            )\n    return var_dict\n\n\n@contextlib.contextmanager\ndef tensor_pickling_context(\n    path: Path,\n    global_rank: Optional[int],\n    mp: MAP_LOCATION,\n    save_as_external_data: bool,\n):\n    global context_data\n    context_data = ContextData(path, global_rank, mp, save_as_external_data)\n    try:\n        yield\n    finally:\n        context_data = None\n\n\ndef is_oneflow_pickle_file(path: FILE_LIKE, support_pytorch_format: bool) -> bool:\n    if _is_path(path) and not path.is_file():\n        return False\n    try:\n        with _open_file_like(path, \"rb\") as f:\n            content = pickle.load(f)\n            if ONEFLOW_MAGIC_KEY in content:\n                return True, (content,)\n            else:\n                return False\n    except:\n        return False\n\n\n# `path` is not used in this function, because the file is already loaded\n# and deserialized in `is_oneflow_pickle_file`, and the content is passed\n# as `content`.\n@load_if(is_oneflow_pickle_file)\ndef load_from_oneflow_single_file(\n    path: FILE_LIKE, global_src_rank, map_location: MAP_LOCATION, content: Any = None,\n):\n    rank = flow.env.get_rank()\n    if global_src_rank is None or rank == global_src_rank:\n        assert content[\"protocol_version\"] == PROTOCOL_VERSION\n        res = content[\"data\"]\n    else:\n        res = None\n\n    if global_src_rank is not None:\n        res = flow.utils.global_view.to_global(\n            res,\n            placement=flow.placement(\"cpu\", [global_src_rank]),\n            sbp=flow.sbp.broadcast,\n            warn_on_non_tensor_leaf=False,\n        )\n    res = _map_location(res, map_location)\n    return res\n\n\ndef is_file_and_support_pytorch_format(\n    path: FILE_LIKE, support_pytorch_format: bool\n) -> bool:\n    if not support_pytorch_format:\n        return False\n    if _is_path(path) and not path.is_file():\n        return False\n    try:\n        with flow.mock_torch.disable():\n            import torch\n\n            content = torch.load(path, map_location=\"cpu\")\n            return True, (content,)\n    except:\n        if os.getenv(\"ONEFLOW_DEBUG_CHECKPOINT\") == \"1\":\n            import traceback\n\n            traceback.print_exc()\n        return False\n\n\n@load_if(is_file_and_support_pytorch_format)\ndef 
load_from_pytorch_file(\n    path: FILE_LIKE, global_src_rank, map_location: MAP_LOCATION, torch_obj: Any = None\n):\n    if torch_obj is not None:\n        with flow.mock_torch.disable():\n            import torch\n\n            def torch_tensor_to_flow(x):\n                if isinstance(x, torch.Tensor):\n                    return flow.utils.tensor.from_torch(x)\n                else:\n                    return x\n\n            flow_obj = ArgsTree(torch_obj).map_leaf(torch_tensor_to_flow)\n    else:\n        flow_obj = None\n    if global_src_rank is not None:\n        flow_obj = flow.utils.global_view.to_global(\n            flow_obj,\n            placement=flow.placement(\"cpu\", [global_src_rank]),\n            sbp=flow.sbp.broadcast,\n            warn_on_non_tensor_leaf=False,\n        )\n    flow_obj = _map_location(flow_obj, map_location)\n    return flow_obj\n\n\ndef is_dir_and_has_pickle_file(path: FILE_LIKE, support_pytorch_format: bool) -> bool:\n    if _is_path(path) and path.is_dir():\n        pickle_path = path / PICKLE_FILENAME\n        return pickle_path.exists()\n    return False\n\n\n@load_if(is_dir_and_has_pickle_file)\ndef load_from_oneflow_pickle_dir(\n    path: Path, global_src_rank: Optional[int], map_location: MAP_LOCATION,\n):\n    rank = flow.env.get_rank()\n    pickle_path = path / PICKLE_FILENAME\n    if global_src_rank is not None:\n        if rank == global_src_rank:\n            pickle_bytes = pickle_path.read_bytes()\n            _broadcast_py_object(pickle_bytes, global_src_rank)\n        else:\n            pickle_bytes = _broadcast_py_object(None, global_src_rank)\n    else:\n        pickle_bytes = pickle_path.read_bytes()\n\n    if map_location is not None:\n        assert isinstance(\n            map_location, (str, flow.device, flow.placement)\n        ), \"'map_location' only supports str, device or placement.\"\n    with tensor_pickling_context(path, global_src_rank, map_location, True):\n        res = pickle.loads(pickle_bytes)\n    assert res[\"protocol_version\"] == PROTOCOL_VERSION\n    return res[\"data\"]\n\n\ndef load(\n    path: Union[FILE_LIKE, str],\n    global_src_rank: Optional[int] = None,\n    map_location: MAP_LOCATION = None,\n    *,\n    support_pytorch_format: bool = True,\n) -> Any:\n    r\"\"\"Loads an object saved with oneflow.save() from a directory.\n\n    Args:\n        path: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),\n            or a string or os.PathLike object containing a file name\n        global_src_rank (int, optional): The source rank for\n            loading global tensors. When specified, only the\n            process whose rank == global_src_rank will really\n            read the files in `path`, and tensors in the loaded\n            object will be consistent with placement =\n            `flow.placement('cuda', [global_src_rank])`\n        map_location (str, flow.device or flow.placement, callable, optional):\n            indicates the location where all tensors should be loaded.\n        support_pytorch_format (bool, optional): whether to support\n            loading the file saved by `torch.save`. 
Default: True\n\n    Returns:\n        The loaded object\n    \"\"\"\n    if isinstance(path, str):\n        path = Path(path)\n    rank = flow.env.get_rank()\n    if global_src_rank is None or global_src_rank == rank:\n        for i, (condition, load) in enumerate(load_methods):\n            is_ok, extra_data = condition(path, support_pytorch_format)\n            if is_ok:\n                if global_src_rank is not None:\n                    _broadcast_py_object(i, global_src_rank)\n                break\n        else:\n            if _is_path(path):\n                err_msg = f'Cannot load file \"{path}\"'\n            else:\n                err_msg = \"Cannot load the data\"\n            raise ValueError(err_msg)\n    else:\n        i = _broadcast_py_object(None, global_src_rank)\n        load = load_methods[i][1]\n        extra_data = ()\n\n    return load(path, global_src_rank, map_location, *extra_data)  # type: ignore\n\n\ndef save_one_embedding_info(state_dict: Any, path: Union[str, Path]) -> None:\n    path: Path = Path(path)\n\n    _embedding_info_dict = {\"embedding\": []}\n    os.makedirs(path, exist_ok=True)\n\n    _save_one_embedding_info_flag = False\n\n    for module in state_dict.keys():\n        if not isinstance(state_dict[module], OrderedDict):\n            continue\n        for module_key in state_dict[module].keys():\n            _info_dict = {}\n            if \"OneEmbeddingKeyValueOptions\" in module_key:\n                if not _save_one_embedding_info_flag:\n                    _save_one_embedding_info_flag = True\n\n                module_key_prefix = module_key.rstrip(\"OneEmbeddingKeyValueOptions\")\n\n                _embedding_info_dict[\"embedding\"].append(\n                    {\n                        \"snapshot\": state_dict[\"module\"][\n                            module_key_prefix + \"OneEmbeddingSnapshot\"\n                        ],\n                        \"kv_options\": json.loads(\n                            state_dict[\"module\"][\n                                module_key_prefix + \"OneEmbeddingKeyValueOptions\"\n                            ]\n                        ),\n                    }\n                )\n\n    if _save_one_embedding_info_flag:\n        with open(os.path.join(path, \"one_embedding_options.json\"), \"w\") as f:\n            f.write(json.dumps(_embedding_info_dict, indent=4))\n\n\ndef save(\n    obj: Any,\n    path_or_buffer: FILE_LIKE,\n    global_dst_rank: Optional[int] = None,\n    save_as_external_data: bool = False,\n) -> None:\n    r\"\"\"Save an object to a directory.\n\n    Args:\n        obj: The object to be saved\n        path_or_buffer: a file-like object (has to implement write and flush) or a string or\n           os.PathLike object containing a file name\n        global_dst_rank (int, optional): The destination rank for\n            saving global tensors. 
When specified, whole tensors\n            will be saved by the process whose rank ==\n            global_dst_rank, while other processes will not do any\n            disk I/O.\n        save_as_external_data (bool): useful only if path_or_buffer is a string or\n           os.PathLike object containing a file name\n    \"\"\"\n    if isinstance(path_or_buffer, str):\n        path_or_buffer = Path(path_or_buffer)\n\n    if isinstance(obj, graph_util.Graph):\n        if not _is_path(path_or_buffer):\n            raise ValueError(\n                \"path_or_buffer must be the type of {`str`, `pathlib.Path`} while obj is Graph\"\n            )\n        _save_graph(obj, path_or_buffer)\n        return\n\n    # this `path` is only used for `ContextData` and is set to empty when `path_or_buffer` is IO[bytes] or BinaryIO\n    path: Path = Path(path_or_buffer if _is_path(path_or_buffer) else \"\")\n    obj = {\"protocol_version\": PROTOCOL_VERSION, ONEFLOW_MAGIC_KEY: None, \"data\": obj}\n\n    with tensor_pickling_context(path, global_dst_rank, None, save_as_external_data):\n        pickled_bytes = pickle.dumps(obj)\n\n    if _is_path(path_or_buffer) and save_as_external_data:\n        path_or_buffer.mkdir(exist_ok=True)\n        path_or_buffer = path_or_buffer / PICKLE_FILENAME\n\n    def write_file():\n        with _open_file_like(path_or_buffer, \"wb\") as f:\n            f.write(pickled_bytes)\n\n    if global_dst_rank is not None:\n        assert isinstance(\n            global_dst_rank, int\n        ), f\"global_dst_rank expected type int, but got {type(global_dst_rank)}.\"\n        assert (\n            global_dst_rank >= 0 and global_dst_rank < flow.env.get_world_size()\n        ), f\"out of range (expected to be in range of [0, {flow.env.get_world_size()}), but got {global_dst_rank}).\"\n        if flow.env.get_rank() == global_dst_rank:\n            write_file()\n    else:\n        # global_dst_rank is None\n        write_file()\n\n\ndef _save_graph(obj: graph_util.Graph, path: Union[str, Path]):\n    path: Path = Path(path)\n    graph: graph_util.Graph = obj\n    if not graph._is_compiled:\n        raise RuntimeError(\"graph must be compiled first.\")\n\n    path.mkdir(exist_ok=True)\n\n    serialized_job = graph._forward_job_proto.SerializeToString()\n    oneflow._oneflow_internal.nn.graph.SaveJobToIR(serialized_job, str(path))\n\n    for x in graph._state():\n        _save_tensor_to_disk(\n            x.to(Tensor),\n            path / f\"{x.to(GraphTensor).name_prefix}{x.to(GraphTensor).name}\",\n        )\n\n    save_one_embedding_info(obj.state_dict(), path)\n\n\ndef frombuffer(\n    buffer: object,\n    dtype: oneflow.dtype,\n    count: int = -1,\n    offset: int = 0,\n    requires_grad: bool = False,\n):\n    return oneflow.tensor(\n        np.frombuffer(\n            buffer,\n            dtype_util.convert_oneflow_dtype_to_numpy_dtype(dtype),\n            count,\n            offset,\n        ),\n        dtype=dtype,\n        requires_grad=requires_grad,\n    )\n\n\nclass ContextData:\n    def __init__(\n        self,\n        path: Path,\n        global_rank: Optional[int],\n        map_location: Optional[Union[str, flow.device, flow.placement]],\n        save_as_external_data: bool,\n    ):\n        self.path = path\n        self.global_rank = global_rank\n        self.map_location = map_location\n        self.save_as_external_data = save_as_external_data\n\n\ncontext_data = None\n"
  },
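The module above fixes the on-disk layout: `flow.save` pickles a `{"protocol_version": ..., "__oneflow__": None, "data": ...}` wrapper, and `save_as_external_data=True` turns the target into a directory holding a `pickled_data` file plus one sub-directory per tensor (`out` data file and `meta` info file). A minimal usage sketch under those semantics (the state-dict contents and the `ckpt` paths are illustrative only):

.. code-block:: python

    import oneflow as flow

    state = {"w": flow.randn(2, 3)}

    # Single-file form: one pickle containing the wrapper dict.
    flow.save(state, "ckpt.of")
    # Directory form: "ckpt/" gets "pickled_data" plus per-tensor dirs.
    flow.save(state, "ckpt", save_as_external_data=True)

    # map_location relocates every tensor leaf while unpickling.
    loaded = flow.load("ckpt", map_location="cpu")

    # Global-tensor round trip (multi-process): only the chosen rank
    # touches the disk, and the result is broadcast from it.
    #   flow.save(state, "ckpt", global_dst_rank=0)
    #   loaded = flow.load("ckpt", global_src_rank=0)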
  {
    "path": "python/oneflow/framework/config_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport sys\nimport traceback\nfrom typing import Callable, List, Union\n\nimport oneflow._oneflow_internal\nimport oneflow.core.job.resource_pb2 as resource_util\nimport oneflow.framework.session_context as session_ctx\nimport oneflow.framework.attr_util as attr_util\n\n\ndef _set_resource_attr(attrs_chain: Union[List[str], str], attr_value, type_):\n    r\"\"\"\n    set the attribute of config_proto.resource to attr_value.\n    the attribute is specified as a string or a list of string.\n\n    for example, if we want to do this:\n        `config_proto.resource.machine_num = 1`\n\n    we can call `_set_resource_attr(\"machine_num\", 1)`\n\n    if we want to do:\n        `config_proto.resource.collective_boxing_conf.nccl_num_streams = 1`\n    \n    we can call `_set_resource_attr([\"collective_boxing_conf\", \"nccl_num_streams\"], 1)`\n`\n    \"\"\"\n    assert isinstance(attr_value, type_), (\n        \"Attribute \"\n        + repr(attrs_chain)\n        + \" type unmatched! Expected: \"\n        + str(type_)\n        + \" but get: \"\n        + str(type(attr_value))\n    )\n\n    if isinstance(attrs_chain, str):\n        attrs_chain = [attrs_chain]\n\n    session = session_ctx.GetDefaultSession()\n\n    # get the current resource config\n    resource_config = (\n        session.config_proto.resource\n        if session.status_ != session.Status.INITED\n        else session.resource\n    )\n\n    # update the current resource config\n    setattr(\n        attr_util.get_nested_attribute(\n            resource_config, attrs_chain[0:-1]\n        ),  # the owning object of the attribute to be updated\n        attrs_chain[-1],  # the attribute needs to be updated\n        attr_value,\n    )\n\n    # update the resource config eagerly if the session is already initialized\n    if session.status_ == session.Status.INITED:\n        session.update_resource_eagerly(resource_config)\n\n\ndef api_load_library(val: str) -> None:\n    \"\"\"Load necessary library for job now\n    Args:\n        val (str): path to shared object file\n    \"\"\"\n    assert type(val) is str\n    oneflow._oneflow_internal.LoadLibrary(val)\n\n\ndef api_numa_aware_cuda_malloc_host(val: bool = True) -> None:\n    \"\"\"Whether or not let numa know  that  cuda allocated host's memory.\n\n    Args:\n        val (bool, optional): True or False. Defaults to True.\n    \"\"\"\n    print(\n        \"'enable_numa_aware_cuda_malloc_host' has been deprecated, has no effect and will be removed in the future.\"\n    )\n\n\ndef api_reserved_device_mem_mbyte(val: int) -> None:\n    \"\"\"Set up the memory size of reserved device\n    Args:\n        val (int):  memory size, e.g. 
1024(mb)\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_reserved_device_mem_mbyte]\n    _set_resource_attr(attrs, val, type_)\n\n\ndef api_enable_cudnn_fused_normalization_add_relu(val: bool) -> None:\n    \"\"\"Whether to enable cudnn_fused_normalization_add_relu.\n\n    Args:\n        val (bool): whether to enable it or not\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_enable_cudnn_fused_normalization_add_relu]\n    _set_resource_attr(attrs, val, type_)\n\n\ndef api_enable_cudnn_conv_heuristic_search_algo(val: bool) -> None:\n    \"\"\"Whether to enable the cudnn conv operation to use the heuristic search algorithm.\n\n    Args:\n        val (bool): whether to enable it or not; the default value is True.\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_enable_cudnn_conv_heuristic_search_algo]\n    _set_resource_attr(attrs, val, type_)\n\n\ndef api_enable_fusion(val: bool = True) -> None:\n    \"\"\"Whether or not to allow fusion of operators\n\n    Args:\n        val (bool, optional): True or False. Defaults to True.\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_enable_fusion]\n    _set_resource_attr(attrs, val, type_)\n\n\ndef api_nccl_use_compute_stream(val: bool = False) -> None:\n    \"\"\"Whether or not NCCL uses the compute stream, to reuse NCCL memory and speed up\n\n    Args:\n        val (bool, optional): True or False. Defaults to False.\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_nccl_use_compute_stream]\n    _set_resource_attr(attrs, val, type_)\n\n\ndef api_disable_group_boxing_by_dst_parallel(val: bool = False) -> None:\n    \"\"\"Whether or not to disable the group-boxing-by-dst-parallel pass, which reduces the boxing memory life cycle.\n\n    Args:\n        val (bool, optional): True or False. Defaults to False.\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_disable_group_boxing_by_dst_parallel]\n    _set_resource_attr(attrs, val, type_)\n\n\ndef api_nccl_num_streams(val: int) -> None:\n    \"\"\"Set up the number of NCCL parallel streams used while boxing\n\n    Args:\n        val (int): number of streams\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_nccl_num_streams]\n    _set_resource_attr(attrs, val, type_)\n\n\ndef api_nccl_fusion_threshold_mb(val: int) -> None:\n    \"\"\"Set up the threshold for operator fusion\n\n    Args:\n        val (int): threshold size, e.g. 
10(mb)\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_nccl_fusion_threshold_mb]\n    _set_resource_attr(attrs, val, type_)\n\n\ndef api_nccl_fusion_all_reduce_use_buffer(val: bool) -> None:\n    \"\"\"Whether or not to use a buffer during the NCCL fusion process\n\n    Args:\n        val (bool): True or False\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_nccl_fusion_all_reduce_use_buffer]\n    _set_resource_attr(attrs, val, type_)\n\n\ndef api_nccl_fusion_all_reduce(val: bool) -> None:\n    \"\"\"Whether or not to use NCCL fusion during the all-reduce process\n\n    Args:\n        val (bool): True or False\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_nccl_fusion_all_reduce]\n    _set_resource_attr(attrs, val, type_)\n\n\ndef api_nccl_fusion_reduce_scatter(val: bool) -> None:\n    \"\"\"Whether or not to use NCCL fusion during the reduce-scatter process\n\n    Args:\n        val (bool): True or False\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_nccl_fusion_reduce_scatter]\n    _set_resource_attr(attrs, val, type_)\n\n\ndef api_nccl_fusion_all_gather(val: bool) -> None:\n    \"\"\"Whether or not to use NCCL fusion during the all-gather process\n\n    Args:\n        val (bool): True or False\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_nccl_fusion_all_gather]\n    _set_resource_attr(attrs, val, type_)\n\n\ndef api_nccl_fusion_reduce(val: bool) -> None:\n    \"\"\"Whether or not to use NCCL fusion during the reduce process\n\n    Args:\n        val (bool): True or False\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_nccl_fusion_reduce]\n    _set_resource_attr(attrs, val, type_)\n\n\ndef api_nccl_fusion_broadcast(val: bool) -> None:\n    \"\"\"Whether or not to use NCCL fusion during the broadcast process\n\n    Args:\n        val (bool): True or False\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_nccl_fusion_broadcast]\n    _set_resource_attr(attrs, val, type_)\n\n\ndef api_nccl_fusion_max_ops(val: int) -> None:\n    \"\"\"Maximum number of ops for NCCL fusion.\n\n    Args:\n        val (int): Maximum number of ops\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_nccl_fusion_max_ops]\n    _set_resource_attr(attrs, val, type_)\n\n\ndef api_nccl_enable_all_to_all(val: bool) -> None:\n    \"\"\"Whether or not to use NCCL all-to-all during s2s boxing\n\n    Args:\n        val (bool): True or False\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_nccl_enable_all_to_all]\n    _set_resource_attr(attrs, val, type_)\n\n\ndef api_nccl_enable_mixed_fusion(val: bool) -> None:\n    \"\"\"Whether or not to use NCCL mixed fusion\n\n    Args:\n        val (bool): True or False\n    \"\"\"\n\n    attrs, type_ = api_attrs_and_type[api_nccl_enable_mixed_fusion]\n    _set_resource_attr(attrs, val, type_)\n\n\napi_attrs_and_type = {\n    api_reserved_device_mem_mbyte: (\"reserved_device_mem_mbyte\", int),\n    api_enable_cudnn_fused_normalization_add_relu: (\n        [\"cudnn_conf\", \"enable_cudnn_fused_normalization_add_relu\"],\n        bool,\n    ),\n    api_enable_cudnn_conv_heuristic_search_algo: (\n        [\"cudnn_conf\", \"cudnn_conv_heuristic_search_algo\"],\n        bool,\n    ),\n    api_enable_fusion: ([\"collective_boxing_conf\", \"enable_fusion\"], bool),\n    api_nccl_use_compute_stream: (\"nccl_use_compute_stream\", bool),\n    api_disable_group_boxing_by_dst_parallel: (\n        \"disable_group_boxing_by_dst_parallel\",\n        bool,\n    ),\n    api_nccl_num_streams: ([\"collective_boxing_conf\", \"nccl_num_streams\"], int),\n    api_nccl_fusion_threshold_mb: (\n        [\"collective_boxing_conf\", \"nccl_fusion_threshold_mb\"],\n        int,\n    ),\n    api_nccl_fusion_all_reduce_use_buffer: (\n        [\"collective_boxing_conf\", \"nccl_fusion_all_reduce_use_buffer\"],\n        bool,\n    ),\n    api_nccl_fusion_all_reduce: (\n        [\"collective_boxing_conf\", \"nccl_fusion_all_reduce\"],\n        bool,\n    ),\n    api_nccl_fusion_reduce_scatter: (\n        [\"collective_boxing_conf\", \"nccl_fusion_reduce_scatter\"],\n        bool,\n    ),\n    api_nccl_fusion_all_gather: (\n        [\"collective_boxing_conf\", \"nccl_fusion_all_gather\"],\n        bool,\n    ),\n    api_nccl_fusion_reduce: ([\"collective_boxing_conf\", \"nccl_fusion_reduce\"], bool),\n    api_nccl_fusion_broadcast: (\n        [\"collective_boxing_conf\", \"nccl_fusion_broadcast\"],\n        bool,\n    ),\n    api_nccl_fusion_max_ops: ([\"collective_boxing_conf\", \"nccl_fusion_max_ops\"], int),\n    api_nccl_enable_all_to_all: (\n        [\"collective_boxing_conf\", \"nccl_enable_all_to_all\"],\n        bool,\n    ),\n    api_nccl_enable_mixed_fusion: (\n        [\"collective_boxing_conf\", \"nccl_enable_mixed_fusion\"],\n        bool,\n    ),\n}\n"
  },
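Every `api_*` setter above follows one pattern: look its own function object up in `api_attrs_and_type` to get the attribute chain and expected type, then let `_set_resource_attr` walk the chain on the resource proto via `attr_util.get_nested_attribute`. A self-contained sketch of that walk on a plain object tree (`_Node` and `set_nested_attr` are illustrative stand-ins, not OneFlow APIs):

.. code-block:: python

    from functools import reduce

    class _Node:
        """Stand-in for a nested protobuf message."""

    resource = _Node()
    resource.collective_boxing_conf = _Node()

    def set_nested_attr(root, attrs_chain, value):
        if isinstance(attrs_chain, str):
            attrs_chain = [attrs_chain]
        # Walk to the object owning the last attribute, then assign.
        owner = reduce(getattr, attrs_chain[:-1], root)
        setattr(owner, attrs_chain[-1], value)

    # Mirrors what api_nccl_num_streams(2) does on the real config proto:
    set_nested_attr(resource, ["collective_boxing_conf", "nccl_num_streams"], 2)
    assert resource.collective_boxing_conf.nccl_num_streams == 2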
  {
    "path": "python/oneflow/framework/distribute.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport traceback\nimport warnings\nfrom contextlib import contextmanager\n\nimport oneflow._oneflow_internal\n\n\ndef split_sbp(dim=None, **kwargs) -> oneflow._oneflow_internal.sbp.sbp:\n    \"\"\"\n    Generate a split signature which indicates the tensor will be split along `dim`.\n\n    Args:\n        dim (int): The dimension in which the tensor is split. \n\n    Returns:\n        SbpParallel: Split scheme object, often required by `to_global` method of `Tensor`\n\n    Example::\n        array = numpy.array([[1.0, 2.0], [3.0, 4.0]])\n        t1 = flow.tensor(array)\n        ct2 = t1.to_global(sbp=flow.sbp.split(0), placement=(\"cuda\", ranks=[0, 1, 2, 3]))\n\n    \"\"\"\n    if dim is None:\n        for key, value in kwargs.items():\n            if key == \"axis\":\n                if not isinstance(value, int):\n                    raise TypeError(\n                        \"split_sbp(): parameter must be int, not {}.\".format(\n                            type(value)\n                        )\n                    )\n                warnings.warn(\n                    \"This 'axis' parameter of oneflow.sbp.split() has been updated to 'dim' since OneFlow version 0.8.\"\n                )\n                dim = value\n            else:\n                raise TypeError(\n                    \"split_sbp() got an unexpected keyword argument '%s'.\" % key\n                )\n\n        if dim is None:\n            raise TypeError(\"split_sbp() missing 1 required argument: 'dim'.\")\n\n    else:\n        for key, value in kwargs.items():\n            if key == \"axis\":\n                raise TypeError(\n                    \"split_sbp() received an invalid combination of arguments - duplicate argument `axis`\"\n                )\n            else:\n                raise TypeError(\n                    \"split_sbp() got an unexpected keyword argument '%s'.\" % key\n                )\n\n    assert isinstance(dim, int)\n    return oneflow._oneflow_internal.sbp.split(dim)\n"
  },
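`split_sbp` itself is pure argument handling around `oneflow._oneflow_internal.sbp.split(dim)`, and the `axis` compatibility path is easiest to see by example (the function is exposed as `flow.sbp.split`; this sketch runs on a single process since it never builds a placement):

.. code-block:: python

    import oneflow as flow

    s0 = flow.sbp.split(0)        # canonical positional form
    s1 = flow.sbp.split(dim=1)    # keyword form

    # Legacy spelling: accepted, but emits the rename warning.
    s2 = flow.sbp.split(axis=0)

    # Passing both `dim` and `axis` raises TypeError ("duplicate argument").
    try:
        flow.sbp.split(0, axis=0)
    except TypeError as e:
        print(e)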
  {
    "path": "python/oneflow/framework/docstr/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom .math_ops import *\nfrom .random import *\nfrom .conv import *\nfrom .as_tensor import *\nfrom .pooling import *\nfrom .activation import *\nfrom .dropout import *\nfrom .vision import *\nfrom .norm import *\nfrom .normalization import *\nfrom .loss import *\nfrom .onehot import *\nfrom .comparison import *\nfrom .cast import *\nfrom .constant import *\nfrom .array_ops import *\nfrom .tensor import *\nfrom .tensor_attributes import *\nfrom .comm import *\nfrom .ctc_decode import *\nfrom .trigonometric_ops import *\nfrom .tensor_ops import *\nfrom .meshgrid import *\nfrom .dataset import *\nfrom .bmm import *\nfrom .flatten import *\nfrom .chunk import *\nfrom .broadcast_like import *\nfrom .arange import *\nfrom .split import *\nfrom .clamp import *\nfrom .erfinv import *\nfrom .swapaxes import *\nfrom .amax import *\nfrom .unbind import *\nfrom .repeat import *\nfrom .repeat_interleave import *\nfrom .tile import *\nfrom .tensor_t import *\nfrom .topk import *\nfrom .nms import *\nfrom .nonzero import *\nfrom .reduce_ops import *\nfrom .masked_fill import *\nfrom .expand import *\nfrom .flip import *\nfrom .in_top_k import *\nfrom .index_select import *\nfrom .sort import *\nfrom .is_floating_point import *\nfrom .swapdims import *\nfrom .where import *\nfrom .einsum import *\nfrom .oneflow import *\nfrom .argsort import *\nfrom .module import *\nfrom .util_ops import *\nfrom .tensordot import *\nfrom .searchsorted import *\nfrom .amin import *\nfrom .deconv import *\nfrom .inv import *\nfrom .logical_ops import *\nfrom .bitwise_ops import *\nfrom .distance import *\nfrom .addcdiv import *\nfrom .hann_window import *\nfrom .convolution import *\nfrom .linalg import *\nfrom .index_add import *\nfrom .baddbmm import *\nfrom .lerp import *\nfrom .quantile import *\nfrom .depend import *\nfrom .special_ops import *\n"
  },
  {
    "path": "python/oneflow/framework/docstr/activation.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow._C.prelu,\n    \"\"\"\n    prelu(x: Tensor, alpha: Tensor) -> Tensor  \n\n    Applies the element-wise function:\n\n    .. math::\n        prelu(x) = max(0,x) + alpha * min(0,x) \n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x = flow.tensor(np.asarray([[[[1, -2], [3, 4]]]]), dtype=flow.float32)\n        >>> alpha = flow.nn.Parameter(flow.tensor([1], dtype=flow.float32).fill_(0.25))\n        >>> flow.nn.functional.prelu(x, alpha)\n        tensor([[[[ 1.0000, -0.5000],\n                  [ 3.0000,  4.0000]]]], dtype=oneflow.float32,\n               grad_fn=<preluBackward>)\n   \n    See\n    :class:`~oneflow.nn.PReLU` for more details.\n \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.relu,\n    \"\"\"\n    Applies the rectified linear unit function element-wise. See :class:`~oneflow.nn.ReLU` for more details. \n\n    Args:\n        inplace: If set to ``True``, will do this operation in-place. Default: ``False``\n    \n    For examples:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> ndarr = np.asarray([1, -2, 3])\n        >>> input = flow.Tensor(ndarr)\n        >>> output = flow.relu(input)\n        >>> output\n        tensor([1., 0., 3.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.gelu,\n    r\"\"\"\n    gelu(x: Tensor) -> Tensor \n\n    Applies the Gaussian Error Linear Units function:\n\n    .. math:: \\\\text{GELU}(x) = x * \\Phi(x)\n\n    where :math:`\\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.\n\n    When the approximate argument is 'tanh', Gelu is estimated with:\n\n    .. math:: \\\\text{GELU}(x) = 0.5 * x * (1 + \\\\text{Tanh}(\\sqrt(2 / \\pi) * (x + 0.044715 * x^3)))\n\n    Args:\n        input (oneflow.Tensor): Input Tensor\n        approximate (string, optional): the gelu approximation algorithm to use:\n            ``'none'`` | ``'tanh'``. Default: ``'none'``\n\n    Returns:\n        oneflow.Tensor: A Tensor has same shape as the input.\n    \n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32)\n        >>> input = flow.tensor(x)\n\n        >>> out = flow.gelu(input)\n        >>> out\n        tensor([-0.1543,  0.0000,  0.3457], dtype=oneflow.float32)\n\n    See    \n    :class:`~oneflow.nn.GELU` for more details.\n \n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow._C.quick_gelu,\n    r\"\"\"\n    quick_gelu(x: Tensor) -> Tensor \n\n    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs\n\n    .. 
math::\n        \\\\text{QuickGELU}(x) = x * \\\\sigma(1.702x) = x * \\\\frac{1}{1 + \\\\exp(-1.702x)}\n\n    Args:\n        input (oneflow.Tensor): Input Tensor\n\n    Returns:\n        oneflow.Tensor: A Tensor has same shape as the input.\n\n    See    \n    :class:`~oneflow.nn.QuickGELU` for more details.\n \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.square_relu,\n    r\"\"\"\n    square_relu(x: Tensor) -> Tensor \n\n    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2\n\n    .. math::\n        \\\\text{ReLU}(x) = \\\\max(0, x) * \\\\max(0, x)\n\n    Args:\n        input (oneflow.Tensor): Input Tensor\n\n    Returns:\n        oneflow.Tensor: A Tensor has same shape as the input.\n\n    See    \n    :class:`~oneflow.nn.SquareReLU` for more details.\n \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.softmax,\n    r\"\"\"\n    softmax(x: Tensor, dim: int) -> Tensor \n\n    Softmax is defined as:\n\n    .. math::\n        \\text{Softmax}(x_{i}) = \\frac{\\\\exp(x_i)}{\\sum_j \\exp(x_j)}\n    \n    See :class:`~oneflow.nn.Softmax` for more details.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.log_softmax,\n    r\"\"\"\n    log_softmax(x: Tensor, dim: int) -> Tensor \n\n    LogSoftmax is defined as:\n\n    .. math::\n        \\text{LogSoftmax}(x_{i}) = \\log\\left(\\frac{\\exp(x_i) }{ \\sum_j \\exp(x_j)} \\right) = x_i - \\log({ \\sum_j \\exp(x_j)})\n    \n    See :class:`~oneflow.nn.LogSoftmax` for more details.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.gumbel_softmax,\n    r\"\"\"\n    gumbel_softmax(x: Tensor, dim: int, tau: float = 1.0, hard: bool = False) -> Tensor \n\n    Solve the problem that the output values of argmax do not reflect the probability distribution of the model's output.\n    Compensates for the fact that the argmax cannot participate in gradient back-propagation.\n\n    Gumbel is defined as:\n\n    .. math::\n        Gumbel_i = -log(-log(U_i)),\\ U_i \\sim U(0,1)\n\n    Add Noise ~ Gumbel:\n\n    .. math::\n        In = (In + Noise) / tau\n\n    Calculate Softmax value:\n\n    .. math::\n        gumbel\\_softmax(In)=\\frac{e^{In_i/tau}}{\\sum_{j=1}^n{e^{In_j/tau}}},i=1,2,3...n\n\n    Parameters:\n        x (oneflow.Tensor): the input Tensor.\n        dim (int, Tuple[int]): the dimension to softmax. \n        tau (double): the input tensor of Softmax should obey the Gumbel(x, tau).\n        hard (bool): if `hard=True`, the output tensor will be one-hot.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.softplus,\n    r\"\"\"\n    softplus(x: Tensor, beta: double = 1, threshold: double = 20) -> Tensor \n\n    Applies the element-wise function:\n\n    .. math::\n        \\text{Softplus}(x) = \\frac{1}{\\beta} * \\log(1 + \\exp(\\beta * x))   \n\n    For numerical stability the implementation reverts to the linear function\n    when :math:`input \\times \\beta > threshold`. \n    \n    See :class:`~oneflow.nn.Softplus` for more details.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.tanh,\n    r\"\"\"\n    tanh(x: Tensor) -> Tensor \n\n    The equation is:\n\n    .. math::\n\n        out = \\frac{e^x-e^{-x}}{e^x+e^{-x}}\n\n    See :class:`~oneflow.nn.Tanh` for more details.\n    \"\"\",\n)\nadd_docstr(\n    oneflow._C.logsigmoid,\n    r\"\"\"\n    logsigmoid(x: Tensor) -> Tensor \n\n    Applies the element-wise function:\n\n    .. math::\n        \\text{logsigmoid}(x) = \\log\\left(\\frac{ 1 }{ 1 + \\exp(-x)}\\right)\n   \n    For example:\n\n    .. 
code-block:: python\n\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32)\n        >>> input = flow.tensor(x)     \n          \n        >>> out = flow.nn.functional.logsigmoid(input)\n        >>> out\n        tensor([-0.9741, -0.6931, -0.4741], dtype=oneflow.float32)\n\n    See :class:`~oneflow.nn.LogSigmoid` for more details.\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.softsign,\n    r\"\"\"\n    softsign(x: Tensor) -> Tensor \n\n    The formula is: \n    \n    .. math::  \n    \n        softsign(x) = \\frac{x}{1 + |x|}\n    \n    For example:\n    \n    .. code-block:: python\n    \n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x = np.array([1, 2, 3]).astype(np.float32)\n        >>> input = flow.tensor(x) \n        >>> out = flow.nn.functional.softsign(input)\n        >>> out\n        tensor([0.5000, 0.6667, 0.7500], dtype=oneflow.float32)\n \n    See :class:`~oneflow.nn.Softsign` for more details.\n    \n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow.silu,\n    \"\"\"\n    silu(x: Tensor) -> Tensor\n\n    The formula is: \n\n    .. math::\n\n        \\text{silu}(x) = x * sigmoid(x)\n        \n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x = np.array([1, 2, 3]).astype(np.float32)\n        >>> input = flow.tensor(x)       \n        >>> out = flow.silu(input)\n        >>> out\n        tensor([0.7311, 1.7616, 2.8577], dtype=oneflow.float32)\n\n    See :class:`~oneflow.nn.SiLU` for more details.\n\n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow.mish,\n    \"\"\" \n    mish(x: Tensor) -> Tensor \n\n    Applies the element-wise function:\n\n    .. math::\n        \\text{mish}(x) = x * \\text{tanh}(\\text{softplus}(x))\n\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> x = np.array([1, 2, 3]).astype(np.float32)\n        >>> input = flow.tensor(x)       \n\n        >>> out = flow.mish(input)\n        >>> out\n        tensor([0.8651, 1.9440, 2.9865], dtype=oneflow.float32)\n\n    See :class:`~oneflow.nn.Mish` for more details.\n    \n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow._C.hardsigmoid,\n    \"\"\"\n    hardsigmoid(x: Tensor)-> Tensor\n\n    Applies the element-wise function\n\n    .. math::\n        \\text{Hardsigmoid}(x) = \\begin{cases}\n            0 & \\text{if~} x \\le -3, \\\\\n            1 & \\text{if~} x \\ge +3, \\\\\n            x / 6 + 1 / 2 & \\text{otherwise}\n        \\end{cases}\n\n    \n    See :class:`~oneflow.nn.Hardsigmoid` for more details.\n    \"\"\",\n)\nadd_docstr(\n    oneflow._C.hardswish,\n    \"\"\"\n    hardswish(x: Tensor)-> Tensor\n\n    Applies the hardswish function, element-wise, as described in the paper:\n\n    `Searching for MobileNetV3`_.\n\n    .. math::\n        \\text{Hardswish}(x) = \\begin{cases}\n            0 & \\text{if~} x \\le -3, \\\\\n            x & \\text{if~} x \\ge +3, \\\\\n            x \\cdot (x + 3) /6 & \\text{otherwise}\n        \\end{cases}\n\n    See :class:`~oneflow.nn.Hardswish` for more details.\n\n    .. 
_`Searching for MobileNetV3`:\n        https://arxiv.org/abs/1905.02244\n    \"\"\",\n)\nadd_docstr(\n    oneflow.sigmoid,\n    r\"\"\"\n    sigmoid(input) -> Tensor\n\n    Applies the element-wise function :math:`\\text{Sigmoid}(x) = \\frac{1}{1 + \\exp(-x)}`\n\n    See :class:`~oneflow.nn.Sigmoid` for more details.\n\n    For examples:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x = np.array([0.81733328, 0.43621480, 0.10351428])\n        >>> input = flow.tensor(x, dtype=flow.float32)\n        >>> out = flow.nn.functional.sigmoid(input)\n        >>> out\n        tensor([0.6937, 0.6074, 0.5259], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.hardtanh,\n    \"\"\"\n    hardtanh(input, min_val=-1., max_val=1.) -> Tensor\n\n    Applies the HardTanh function element-wise. See :class:`~oneflow.nn.Hardtanh` for more\n    details.\n\n    \"\"\",\n)\nadd_docstr(\n    oneflow._C.leaky_relu,\n    \"\"\"\n    leaky_relu(x: Tensor,  alpha :Float) -> Tensor\n\n    Applies element-wise,\n    :math:`\\text{LeakyReLU}(x) = \\max(0, x) + \\text{negative_slope} * \\min(0, x)`\n\n    See :class:`~oneflow.nn.LeakyReLU` for more details.\n\n    \"\"\",\n)\nadd_docstr(\n    oneflow._C.rrelu,\n    \"\"\"\n    rrelu(x: Tensor, lower: Float = 1.0 / 8, upper: Float = 1.0 / 3, training: bool = False, inplace: bool = False) -> Tensor\n\n    Applies the randomized leaky rectified liner unit function, element-wise\n    :math:`\\text{RReLU}(x) = \\max(0, x) + a * \\min(0, x)`\n\n    where :math:`a` is randomly sampled from uniform distribution\n    :math:`\\mathcal{U}(\\text{lower}, \\text{upper})`.\n    \n    See :class:`~oneflow.nn.RReLU` for more details.\n\n    \"\"\",\n)\nadd_docstr(\n    oneflow._C.rrelu_,\n    \"\"\"\n    rrelu(x: Tensor, lower: Float = 1.0 / 8, upper: Float = 1.0 / 3, training: bool = False) -> Tensor\n\n    In-place version of :func:`rrelu`.\n    \"\"\",\n)\nadd_docstr(\n    oneflow._C.elu,\n    \"\"\"\n    elu(x: Tensor, alpha :Float) -> Tensor\n\n    Applies element-wise,\n        :math:`\\text{ELU}(x) = \\max(0,x) + \\min(0, \\alpha * (\\exp(x) - 1))`.\n\n    See :class:`~oneflow.nn.ELU` for more details.\n\n    For examples:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32)\n        >>> input = flow.tensor(x)\n        >>> out = flow.nn.functional.elu(input, alpha=1.0)\n        >>> out\n        tensor([-0.3935,  0.0000,  0.5000], dtype=oneflow.float32)\n    \"\"\",\n)\nadd_docstr(\n    oneflow.selu,\n    \"\"\"\n    selu(x: Tensor) -> Tensor\n\n    Applies element-wise function\n\n    .. math::\n\n        \\text{SELU}(x) = scale * (\\max(0,x) + \\min(0, \\alpha * (\\exp(x) - 1)))`, with :math:`\\alpha=1.6732632423543772848170429916717` and  :math:`scale=1.0507009873554804934193349852946`.\n\n    See :class:`~oneflow.nn.SELU` for more details.\n\n    For examples:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x = np.array([1, 2, 3]).astype(np.float32)\n        >>> input = flow.tensor(x)\n        >>> out = flow.nn.functional.selu(input)\n        >>> out\n        tensor([1.0507, 2.1014, 3.1521], dtype=oneflow.float32)\n    \"\"\",\n)\nadd_docstr(\n    oneflow._C.glu,\n    \"\"\"\n    glu(input: Tensor, dim: int) -> Tensor \n\n    The equation is:\n\n    .. 
math::\n         GLU(input) = GLU(a, b) = a \\otimes sigmoid(b)\n    \n    .. note::\n        where input is split in half along dim to form a and b, ⊗ is the element-wise product between matrices.\n    \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n        >>> x = flow.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=flow.float32)\n        >>> y = nn.functional.glu(x)\n        >>> y\n        tensor([[0.9526, 1.9640],\n                [4.9954, 5.9980]], dtype=oneflow.float32)\n\n    See    \n    :class:`~oneflow.nn.GLU` for more details.\n \n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow._C.celu,\n    r\"\"\"\n    celu(x: Tensor, alpha: Float=1.0, inplace: bool=False) -> Tensor\n\n    Applies the element-wise function:\n\n    .. math::\n\n        \\text{CELU}(x) = \\max(0,x) + \\min(0, \\alpha * (\\exp(x/\\alpha) - 1))\n\n    See :class:`~oneflow.nn.CELU` for more details.\n\n    For examples:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32)\n        >>> input = flow.tensor(x)\n        >>> out = flow.nn.functional.celu(input, alpha=0.5)\n        >>> out\n        tensor([-0.3161,  0.0000,  0.5000], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.threshold,\n    \"\"\"\n    threshold(input: Tensor, threshold: float, value: float) -> Tensor\n\n    Thresholds each element of the input Tensor.\n\n    See :class:`~oneflow.nn.Threshold` for more details.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.hardshrink,\n    \"\"\"\n    hardshrink(input: Tensor, lambd: float=0.5, inplace: bool=False) -> Tensor\n\n    Applies the hard shrinkage function in an element-wise manner.\n\n    See :class:`~oneflow.nn.Hardshrink` for more details.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.softshrink,\n    \"\"\"\n    softshrink(input: Tensor, lambd: float=0.5, inplace: bool=False) -> Tensor\n\n    Applies the soft shrinkage function in an element-wise manner.\n\n    See :class:`~oneflow.nn.Softshrink` for more details.\n    \"\"\",\n)\n"
  },
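Of the entries above, only `gumbel_softmax` lacks a worked example; its three formulas compose as in this NumPy sketch (a hypothetical 1-D re-implementation for illustration, not the `oneflow._C.gumbel_softmax` kernel):

.. code-block:: python

    import numpy as np

    def gumbel_softmax_np(logits, tau=1.0, hard=False, rng=np.random):
        """Gumbel-Softmax over a 1-D logits vector."""
        u = rng.uniform(1e-10, 1.0, size=logits.shape)
        g = -np.log(-np.log(u))              # Gumbel_i = -log(-log(U_i))
        y = (logits + g) / tau               # In = (In + Noise) / tau
        y = np.exp(y - y.max())
        y = y / y.sum()                      # softmax
        if hard:                             # snap to one-hot of the argmax
            one_hot = np.zeros_like(y)
            one_hot[y.argmax()] = 1.0
            return one_hot
        return y

    probs = gumbel_softmax_np(np.array([1.0, 2.0, 3.0]), tau=0.5)
    assert np.isclose(probs.sum(), 1.0)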
  {
    "path": "python/oneflow/framework/docstr/addcdiv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.addcdiv,\n    r\"\"\"\n    addcdiv(input, tensor1, tensor2, *, value=1) -> Tensor\n\n    This function is equivalent to PyTorch’s addcdiv function. \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.addcdiv.html.\n    \n    Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`,\n    multiply the result by the scalar :attr:`value` and add it to :attr:`input`.\n\n    .. math::\n        \\text{out}_i = \\text{input}_i + \\text{value} \\times \\frac{\\text{tensor1}_i}{\\text{tensor2}_i}\n\n\n    The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be\n    `broadcastable`.\n\n    For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be\n    a real number, otherwise an integer.\n\n    Args:\n        input (Tensor): the tensor to be added\n        tensor1 (Tensor): the numerator tensor\n        tensor2 (Tensor): the denominator tensor\n\n    Keyword args:\n        value (Number, optional): multiplier for :math:`\\text{{tensor1}} / \\text{{tensor2}}`\n\n    Example::\n\n        >>> import oneflow as flow\n        >>> input = flow.tensor([ 0.3810,  1.2774, -0.2972, -0.3719])\n        >>> tensor1 = flow.tensor([0.8032,  0.2930, -0.8113, -0.2308])\n        >>> tensor2 = flow.tensor([[0.5], [1]])\n        >>> output = flow.addcdiv(input, tensor1, tensor2)\n        >>> output.shape\n        oneflow.Size([2, 4])\n    \"\"\",\n)\n"
  },
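The `addcdiv` example above only checks the output shape; the formula and the `(4,)` with `(2, 1)` broadcast can be cross-checked in NumPy with the same values:

.. code-block:: python

    import numpy as np

    inp = np.array([0.3810, 1.2774, -0.2972, -0.3719])
    t1 = np.array([0.8032, 0.2930, -0.8113, -0.2308])
    t2 = np.array([[0.5], [1.0]])

    # out_i = input_i + value * tensor1_i / tensor2_i, with value = 1
    out = inp + 1.0 * t1 / t2
    assert out.shape == (2, 4)   # (4,) and (2, 1) broadcast to (2, 4)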
  {
    "path": "python/oneflow/framework/docstr/amax.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.amax,\n    \"\"\"\n    oneflow.amax(input, dim=None, keepdim=False) -> Tensor\n\n    Returns the maximum along a dimension.\n\n    This function is equivalent to PyTorch’s amax function. \n\n    Args:\n        input (oneflow.Tensor): the input Tensor.\n        dim (int or List of int, optional): the dimension or the dimensions to reduce. Dim is None by default. \n        keepdim (bool, optional): whether to retain the dimension. keepdim is False by default. \n\n    Returns:\n        oneflow.Tensor: Maximum of the input tensor\n\n    For example:\n\n    .. code-block:: python\n    \n        >>> import oneflow as flow\n               \n        >>> x = flow.tensor([[[0,1],[2,3]],[[4,5],[6,7]]])\n        >>> flow.amax(x, 1)\n        tensor([[2, 3],\n                [6, 7]], dtype=oneflow.int64)\n        >>> flow.amax(x, 0)\n        tensor([[4, 5],\n                [6, 7]], dtype=oneflow.int64)\n        >>> flow.amax(x)\n        tensor(7, dtype=oneflow.int64)\n        >>> flow.amax(x, 0, True)\n        tensor([[[4, 5],\n                 [6, 7]]], dtype=oneflow.int64)\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/amin.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.amin,\n    \"\"\"\n    amin(input, dim, keepdim=False) -> Tensor  \n    \n    Returns the minimum value of each slice of the `input` tensor in the given dimension(s) `dim`.\n\n    If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim` where it is of size 1. Otherwise, `dim` is squeezed (see :func:`oneflow.squeeze`), resulting in the output tensor having 1 (or `len(dim)`) fewer dimension(s).\n    \n    This function is equivalent to PyTorch’s amin function. \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.amin.html.\n\n    Parameters:\n        input (oneflow.Tensor): the input Tensor.\n        dim (int, Tuple[int]): the dimension or dimensions to reduce. \n        keepdim (bool): whether the output tensor has `dim` retained or not.\n    \n    Example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n               \n        >>> x = flow.tensor([[[0,1],[2,3]],[[4,5],[6,7]]])\n        >>> flow.amin(x, 1)\n        tensor([[0, 1],\n                [4, 5]], dtype=oneflow.int64)\n        >>> flow.amin(x, 0)\n        tensor([[0, 1],\n                [2, 3]], dtype=oneflow.int64)\n        >>> flow.amin(x)\n        tensor(0, dtype=oneflow.int64)\n        >>> flow.amin(x, 0, True)\n        tensor([[[0, 1],\n                 [2, 3]]], dtype=oneflow.int64)\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/arange.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.arange,\n    \"\"\"\n    oneflow.arange(start: int = 0, end, step: int = 1, dtype: Optional[oneflow._oneflow_internal.dtype] = None, device: Optional[Union[oneflow._oneflow_internal.device, str]] = None, placement: Optional[oneflow._oneflow_internal.placement] = None, sbp: Optional[Union[oneflow._oneflow_internal.sbp.sbp, List[oneflow._oneflow_internal.sbp.sbp]]] = None, requires_grad: bool = False)\n\n    Returns a 1-D tensor of size :math:`\\\\left\\\\lfloor \\\\frac{\\\\text{end} - \\\\text{start}}{\\\\text{step}} \\\\right\\\\rfloor + 1`\n    with values from :attr:`start` to :attr:`end` with step :attr:`step`. Step is\n    the gap between two values in the tensor.\n\n    .. math::\n        \\\\text{out}_{i+1} = \\\\text{out}_i + \\\\text{step}.\n\n    Args:\n        start (int): the starting value for the set of points. Default: ``0``.\n        end (int): the ending value for the set of points\n        step (int): the gap between each pair of adjacent points. Default: ``1``.\n\n    Keyword args:\n        dtype(flow.dtype, optional): If `dtype` is not given, infer the `dtype` from the other input arguments. If any of start, end, or step are floating-point, the `dtype` is inferred to be the floating-point data type. Otherwise, the `dtype` is inferred to be `flow.int64`.\n        device(flow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor.\n        requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: `False`.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        \n        >>> y = flow.arange(0, 5)\n        >>> y\n        tensor([0, 1, 2, 3, 4], dtype=oneflow.int64)\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/argsort.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.argsort,\n    \"\"\"\n    argsort() -> Tensor\n    This operator sorts the input Tensor at specified dim and returns the indices of the sorted Tensor.\n\n    Args:\n        input (oneflow.Tensor): the input Tensor.\n        dim (int, optional): the dimension to be sorted. Defaults to the last dim (-1).\n        descending (bool, optional): controls the sorting order (ascending or descending).\n\n    Returns:\n        oneflow.Tensor: The indices of the sorted Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> x = np.array([[10, 2, 9, 3, 7],\n        ...               [1, 9, 4, 3, 2]]).astype(\"float32\")\n        >>> input = flow.Tensor(x)\n        >>> output = flow.argsort(input)\n        >>> output\n        tensor([[1, 3, 4, 2, 0],\n                [0, 4, 3, 2, 1]], dtype=oneflow.int32)\n        >>> output = flow.argsort(input, descending=True)\n        >>> output\n        tensor([[0, 2, 4, 3, 1],\n                [1, 2, 3, 4, 0]], dtype=oneflow.int32)\n        >>> output = flow.argsort(input, dim=0)\n        >>> output\n        tensor([[1, 0, 1, 0, 1],\n                [0, 1, 0, 1, 0]], dtype=oneflow.int32)\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/array_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.diagonal,\n    r\"\"\"\n    oneflow.diagonal(input, offset, dim1, dim2) -> Tensor\n    \n    Returns a partial view of input with the its diagonal elements with respect to dim1 and dim2 \n    appended as a dimension at the end of the shape.\n    \n    Args:\n        input (Tensor): the input tensor.Must be at least 2-dimensional.\n        offset (Optional[int], 0): which diagonal to consider. Default: 0 (main diagonal)\n        dim1 (Optional[int], 0): first dimension with respect to which to take diagonal. Default: 0\n        dim2 (Optional[int], 1): second dimension with respect to which to take diagonal. Default: 1\n    \n    Returns:\n        oneflow.Tensor: the output Tensor.\n\n    For example:\n    \n    .. code-block:: python\n\n        >>> import oneflow as flow\n        \n        >>> input = flow.randn(2,  3,  4)\n        >>> output = flow.diagonal(input, offset=1, dim1=1, dim2=0)\n        >>> output.shape\n        oneflow.Size([4, 1])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.diag,\n    r\"\"\"\n    If input is a vector (1-D tensor), then returns a 2-D square tensor with the elements of input as the diagonal.\n    If input is a matrix (2-D tensor), then returns a 1-D tensor with diagonal elements of input.\n\n    Args:\n        input (Tensor): the input tensor.\n        diagonal (Optional[int], 0): The diagonal to consider. \n            If diagonal = 0, it is the main diagonal. If diagonal > 0, it is above the main diagonal. If diagonal < 0, it is below the main diagonal. Defaults to 0.\n    \n    Returns:\n        oneflow.Tensor: the output Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> arr = np.array(\n        ...     [\n        ...        [1.0, 2.0, 3.0],\n        ...        [4.0, 5.0, 6.0],\n        ...        [7.0, 8.0, 9.0],\n        ...     ]\n        ... )\n\n        >>> input = flow.tensor(arr, dtype=flow.float32)\n        >>> flow.diag(input)\n        tensor([1., 5., 9.], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.tril,\n    r\"\"\"Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input along the specified diagonal, \n    the other elements of the result tensor out are set to 0.\n    \n    .. note::\n        - if diagonal = 0, the diagonal of the returned tensor will be the main diagonal,\n        - if diagonal > 0, the diagonal of the returned tensor will be above the main diagonal, \n        - if diagonal < 0, the diagonal of the returned tensor will be below the main diagonal.\n\n    Args:\n        input (Tensor): the input tensor. \n        diagonal (int, optional): the diagonal to specify. \n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> x = flow.tensor(np.ones(shape=(3, 3)).astype(np.float32))\n        >>> flow.tril(x)\n        tensor([[1., 0., 0.],\n                [1., 1., 0.],\n                [1., 1., 1.]], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.triu,\n    r\"\"\"Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input, \n    the other elements of the result tensor out are set to 0.\n    \n    Args:\n        input (Tensor): the input tensor. \n        diagonal (int, optional): the diagonal to consider\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        \n        >>> x = flow.tensor(np.ones(shape=(3, 3)).astype(np.float32))\n        >>> flow.triu(x)\n        tensor([[1., 1., 1.],\n                [0., 1., 1.],\n                [0., 0., 1.]], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.argmax,\n    r\"\"\"The op computes the index with the largest value of a Tensor at specified axis.\n\n    Args:\n        input (oneflow.Tensor): Input Tensor\n        dim (int, optional): dimension to be calculated. Defaults to the last dim (-1)\n        keepdim (bool optional):  whether the output tensor has dim retained or not. Ignored if dim=None.\n\n    Returns:\n        oneflow.Tensor: A Tensor(dtype=int64) contains the index with the largest value of `input`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        \n        >>> input = flow.tensor([[1, 3, 8, 7, 2],\n        ...            [1, 9, 4, 3, 2]], dtype=flow.float32)\n        >>> output = flow.argmax(input)\n        >>> output\n        tensor(6, dtype=oneflow.int64)\n        >>> output = flow.argmax(input, dim=1)\n        >>> output\n        tensor([2, 1], dtype=oneflow.int64)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.argmin,\n    r\"\"\"The op computes the index with the largest value of a Tensor at specified axis.\n\n    Args:\n        input (oneflow.Tensor): Input Tensor\n        dim (int, optional): dimension to be calculated. Defaults to the last dim (-1)\n        keepdim (bool optional):  whether the output tensor has dim retained or not. Ignored if dim=None.\n\n    Returns:\n        oneflow.Tensor: A Tensor(dtype=int64) contains the index with the largest value of `input`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        \n        >>> input = flow.tensor([[4, 3, 1, 0, 2],\n        ...            [5, 9, 7, 6, 8]], dtype=flow.float32)\n        >>> output = flow.argmin(input)\n        >>> output\n        tensor(3, dtype=oneflow.int64)\n        >>> output = flow.argmin(input, dim=1)\n        >>> output\n        tensor([3, 0], dtype=oneflow.int64)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.batch_gather,\n    r\"\"\"Gather the element in batch dims. \n    \n    Args:\n        in (Tensor): the input tensor. \n        indices (Tensor): the indices tensor, its dtype must be int32/64. \n\n    For example:\n\n    Example 1: \n\n    .. code-block:: python\n\n        >>> import oneflow as flow \n        >>> import numpy as np \n\n        >>> x = flow.Tensor(np.array([[1, 2, 3], \n        ...                           
[4, 5, 6]]))\n        >>> indices = flow.tensor(np.array([1, 0]).astype(np.int64))\n        >>> out = flow.batch_gather(x, indices)\n\n        tensor([[4., 5., 6.],\n                [1., 2., 3.]], dtype=oneflow.float32)\n\n    Example 2: \n\n    .. code-block:: python\n\n        >>> import oneflow as flow \n        >>> import numpy as np \n\n        >>> x = flow.Tensor(np.array([[[1, 2, 3], [4, 5, 6]], \n        ...                           [[1, 2, 3], [4, 5, 6]]]))\n        >>> indices = flow.tensor(np.array([[1, 0], \n        ...                                 [0, 1]]).astype(np.int64))\n        >>> out = flow.batch_gather(x, indices)\n\n        tensor([[[4., 5., 6.],\n                 [1., 2., 3.]],\n                [[1., 2., 3.],\n                 [4., 5., 6.]]], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.transpose,\n    r\"\"\"Returns a tensor that is a transposed version of input. The given dimensions dim0 and dim1 are swapped.\n\n    The resulting out tensor shares its underlying storage with the input tensor, so changing the content of one would change the content of the other.\n\n    Args:\n        input (oneflow.Tensor): the input tensor.\n        dim0 (int): the first dimension to be transposed.\n        dim1 (int): the second dimension to be transposed.\n    Returns:\n        Tensor: A transposed tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> input = flow.tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)\n        >>> out = flow.transpose(input, 0, 1).shape\n        >>> out\n        oneflow.Size([6, 2, 5, 3])\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.atleast_1d,\n    r\"\"\"\n    oneflow.atleast_1d(*tensors) -> Tensor or List[Tensor]\n\n    Returns a 1-dimensional view of each input tensor with zero dimensions. Input tensors with one or more dimensions are returned as-is.\n\n    The interface is consistent with PyTorch.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.atleast_1d.html.\n\n    Args:\n        tensors (List[oneflow.Tensor] or oneflow.Tensor): Tensor or list of tensors to be reshaped\n\n    Returns:\n        A `Tensor`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.randn(1)\n        >>> flow.atleast_1d(x).shape\n        oneflow.Size([1])\n        >>> x = flow.tensor(0)\n        >>> x.shape\n        oneflow.Size([])\n        >>> flow.atleast_1d(x).shape\n        oneflow.Size([1])\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.atleast_2d,\n    r\"\"\"\n    oneflow.atleast_2d(*tensors) -> Tensor or List[Tensor]\n\n    Returns a 2-dimensional view of each input tensor with zero dimensions. Input tensors with two or more dimensions are returned as-is.\n\n    The interface is consistent with PyTorch.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.atleast_2d.html.\n\n\n    Args:\n        tensors (List[oneflow.Tensor] or oneflow.Tensor): Tensor or list of tensors to be reshaped\n\n    Returns:\n        A `Tensor`\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor(0)\n        >>> x.shape\n        oneflow.Size([])\n        >>> flow.atleast_2d(x).shape\n        oneflow.Size([1, 1])\n        >>> x = flow.randn(3)\n        >>> flow.atleast_2d(x).shape\n        oneflow.Size([1, 3])\n        >>> x = flow.randn(3, 3)\n        >>> flow.atleast_2d(x).shape\n        oneflow.Size([3, 3])\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.atleast_3d,\n    r\"\"\"\n    oneflow.atleast_3d(*tensors) -> Tensor or List[Tensor]\n\n    Returns a 3-dimensional view of each input tensor with zero dimensions. Input tensors with three or more dimensions are returned as-is.\n\n    The interface is consistent with PyTorch.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.atleast_3d.html.\n\n    Args:\n        tensors (List[oneflow.Tensor] or oneflow.Tensor): Tensor or list of tensors to be reshaped\n\n    Returns:\n        A `Tensor`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor(0)\n        >>> flow.atleast_3d(x).shape\n        oneflow.Size([1, 1, 1])\n        >>> x = flow.randn(3)\n        >>> flow.atleast_3d(x).shape\n        oneflow.Size([1, 3, 1])\n        >>> x = flow.randn(3, 4)\n        >>> flow.atleast_3d(x).shape\n        oneflow.Size([3, 4, 1])\n        >>> x = flow.randn(3, 4, 5)\n        >>> flow.atleast_3d(x).shape\n        oneflow.Size([3, 4, 5])\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.stack,\n    r\"\"\"Concatenates a sequence of tensors along a new dimension.\n    The returned tensor shares the same underlying data with input tensors.\n\n    A :attr:`dim` value within the range `[-input.ndimension() - 1, input.ndimension() + 1]`\n    can be used. Negative :attr:`dim` will correspond to :meth:`stack`\n    applied at :attr:`dim` = ``dim + input.ndimension() + 1``.\n\n    Args:\n        inputs (List[oneflow.Tensor]): the list of input tensors. Each tensor should have the same shape.\n        dim (int): the index at which to insert the concatenated dimension.\n\n    Returns:\n        A `Tensor`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> x1 = flow.tensor(np.random.rand(1, 3, 5))\n        >>> x2 = flow.tensor(np.random.rand(1, 3, 5))\n        >>> y = flow.stack([x1, x2], dim = -1)\n        >>> y.shape\n        oneflow.Size([1, 3, 5, 2])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.hstack,\n    r\"\"\"\n    oneflow.hstack(tensors) -> Tensor\n\n    Stack tensors in :attr:`tensors` horizontally (column wise).\n\n    This is equivalent to concatenation tensors in :attr:`tensors` along the first axis for 1-D tensors, and along the second axis for all other tensors.\n\n    When there are tensors with dimension less than 1, these tensors will be reshaped by ``oneflow.atleast_1d()`` to 1-dims tensors before stacking.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.hstack.html.\n\n    Args:\n        tensors: (List[oneflow.Tensor]): sequence of tensors to stack\n\n    Returns:\n        A `Tensor`\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> x1 = flow.randn(5, 2)\n        >>> x2 = flow.randn(5, 3)\n        >>> flow.hstack([x1, x2]).shape\n        oneflow.Size([5, 5])\n        >>> x = flow.randn(5)\n        >>> flow.hstack([x, x]).shape\n        oneflow.Size([10])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.vstack,\n    r\"\"\"\n    oneflow.vstack(tensors) -> Tensor\n\n    Stack tensors in :attr:`tensors` vertically (row wise).\n\n    This is equivalent to concatenation tensors in :attr:`tensors` along the first axis.\n\n    When there are tensors with dimension less than 2, these tensors will be reshaped by ``oneflow.atleast_2d()`` to 2-D tensors before stacking.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.vstack.html.\n\n    Args:\n        tensors: (List[oneflow.Tensor]): sequence of tensors to stack\n\n    Returns:\n        A `Tensor`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x1 = flow.randn(2, 5)\n        >>> x2 = flow.randn(3, 5)\n        >>> flow.vstack([x1, x2]).shape\n        oneflow.Size([5, 5])\n        >>> x = flow.randn(5)\n        >>> flow.vstack([x, x]).shape\n        oneflow.Size([2, 5])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.dstack,\n    r\"\"\"\n    oneflow.dstack(tensors) -> Tensor\n\n    Stack tensors in :attr:`tensors` depthwish (along third axis).\n\n    This is equivalent to concatenation tensors in :attr:`tensors` along the third axis after 1-D and 2-D tensors have been reshaped by ``oneflow.atleast_3d()``.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.dstack.html.\n\n    Args:\n        tensors: (List[oneflow.Tensor]): sequence of tensors to stack\n\n    Returns:\n        A `Tensor`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x1 = flow.randn(2, 3, 4)\n        >>> x2 = flow.randn(2, 3, 2)\n        >>> flow.dstack([x1, x2]).shape\n        oneflow.Size([2, 3, 6])\n        >>> x = flow.randn(6, 4)\n        >>> flow.dstack([x, x]).shape\n        oneflow.Size([6, 4, 2])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.column_stack,\n    r\"\"\"\n    oneflow.column_stack(tensors) -> Tensor\n\n    Creates a new tensor by horizontally stacking the tensors in :attr:`tensors`.\n\n    Equivalent to :code:`oneflow.hstack(tensors)`, tensors with dimensions less than 2 will be reshaped to :code:`(t.numel(), 1)` before being stacked horizontally.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.column_stack.html.\n\n    Args:\n        tensors: (List[oneflow.Tensor]): sequence of tensors to stack\n\n    Returns:\n        A `Tensor`\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> x1 = flow.randn(5)\n        >>> x2 = flow.randn(5)\n        >>> flow.column_stack([x1, x2]).shape\n        oneflow.Size([5, 2])\n        >>> x1 = flow.randn(2, 5)\n        >>> x2 = flow.randn(2, 2)\n        >>> flow.column_stack([x1, x2]).shape\n        oneflow.Size([2, 7])\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.row_stack,\n    r\"\"\"\n    oneflow.row_stack(tensors) -> Tensor\n\n    Alias of ``oneflow.vstack()``.\n\n    Stack tensors in :attr:`tensors` vertically (row wise).\n\n    This is equivalent to concatenation tensors in :attr:`tensors` along the first axis.\n\n    When there are tensors with dimension less than 2, these tensors will be reshaped by ``oneflow.atleast_2d()`` to 2-D tensors before stacking.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.row_stack.html.\n\n    Args:\n        tensors: (List[oneflow.Tensor]): sequence of tensors to stack\n\n    Returns:\n        A `Tensor`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x1 = flow.randn(2, 5)\n        >>> x2 = flow.randn(3, 5)\n        >>> flow.vstack([x1, x2]).shape\n        oneflow.Size([5, 5])\n        >>> x = flow.randn(5)\n        >>> flow.vstack([x, x]).shape\n        oneflow.Size([2, 5])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.squeeze,\n    r\"\"\"This operator removes the specified dimention which size is 1 of the input Tensor.\n    If the `dim` is not specified, this operator will remove all the dimention which size is 1 of the input Tensor.\n\n    The amount of element in return value is the same as Tensor `input`.\n\n    Args:\n        input (oneflow.Tensor): the input Tensor.\n        dim (int, optinal): Defaults to None, if given, the input will be squeezed only in this dimension.\n\n    Returns:\n        Tensor: The result Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor(np.array([[[[1, 1, 1]]]]).astype(np.int32))\n        >>> input.shape\n        oneflow.Size([1, 1, 1, 3])\n        >>> out = flow.squeeze(input, dim=[1, 2]).shape\n        >>> out\n        oneflow.Size([1, 3])\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.cat,\n    r\"\"\"\n    cat(tensors, dim=0) -> Tensor \n\n    Concatenate two or more `Tensor` s at specified dim.\n\n    Analogous to `numpy.concatenate <https://docs.scipy.org/doc/numpy/reference/generated/numpy.concatenate.html>`_\n\n    Args:\n        inputs: a `list` of `Tensor`\n        dim: a `int`.\n\n    Returns:\n        A `Tensor`\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> input1 = flow.tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)\n        >>> input2 = flow.tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)\n        >>> input3 = flow.tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)\n\n        >>> out = flow.cat([input1, input2, input3], dim=1) # equal to using flow.concat()\n        >>> out.shape\n        oneflow.Size([2, 18, 5, 3])\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.gather,\n    \"\"\"\n    oneflow.gather(input, dim, index, sparse_grad=False) -> Tensor\n    \n    Gathers values along an axis specified by `dim`.\n\n    For a 3-D tensor the output is specified by::\n\n        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0\n        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1\n        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2\n\n    :attr:`input` and :attr:`index` must have the same number of dimensions.\n    It is also required that ``index.size(d) <= input.size(d)`` for all\n    dimensions ``d != dim``.  :attr:`out` will have the same shape as :attr:`index`.\n    Note that ``input`` and ``index`` do not broadcast against each other.\n\n    Args:\n        input (Tensor): the source tensor\n        dim (int): the axis along which to index\n        index (LongTensor): the indices of elements to gather\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = np.random.randn(3, 4, 3, 5)\n        >>> index = np.random.choice(np.arange(3), size=180, replace=True).reshape((3, 4, 3, 5))\n        >>> output = flow.gather(flow.Tensor(input), 1, flow.tensor(index, dtype=flow.int64))\n        >>> output.shape\n        oneflow.Size([3, 4, 3, 5])\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.gather_nd,\n    r\"\"\"\n    oneflow.gather_nd(input, index) -> Tensor\n    \n    This operator is a high-dimensional extension of `gather`, `index` is a K-dimensional\n    tensor, which is regarded as a index of input Tensor `input`.\n\n    Each element defines a slice of `input`:\n\n    .. math::\n\n        output[i_{0},i_{1},...,i_{K-2}] = input[index(i_{0},i_{1},...,i_{K-2})]\n\n\n    Args:\n        input: The input Tensor.\n        index: The slice indices.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor(np.array([[1, 2,3], [4, 5,6],[7,8,9]]), dtype=flow.float)\n        >>> index_1 = flow.tensor(np.array([[0], [2]]), dtype=flow.int)\n        >>> out_1 = flow.gather_nd(input,index_1)\n        >>> print(out_1.shape)\n        oneflow.Size([2, 3])\n        >>> out_1\n        tensor([[1., 2., 3.],\n                [7., 8., 9.]], dtype=oneflow.float32)\n        >>> index_2 = flow.tensor(np.array([[0,2], [2,1]]), dtype=flow.int)\n        >>> out_2 = flow.gather_nd(input,index_2)\n        >>> out_2\n        tensor([3., 8.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.bincount,\n    r\"\"\"oneflow.bincount(input, weights=None, minlength=0) → Tensor\n\n    The interface is consistent with PyTorch.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.bincount.html.\n    \n    Count the frequency of each value in an array of non-negative ints.\n\n    The number of bins (size 1) is one larger than the largest value in ``input`` unless ``input`` is empty,\n    in which case the result is a tensor of size 0. If ``minlength`` is specified,\n    the number of bins is at least ``minlength`` and if ``input`` is empty,\n    then the result is tensor of size ``minlength`` filled with zeros.\n    If ``n`` is the value at position ``i``, ``out[n] += weights[i]`` if ``weights`` is specified else ``out[n] += 1``.\n\n    Args:\n        input (oneflow.Tensor): 1-d int Tensor\n        weights (oneflow.Tensor): optional, weight for each value in the input tensor. Should be of same size as input tensor.\n        minlength (int): optional, minimum number of bins. Should be non-negative.\n    \n    For example:\n\n    .. code-block:: python \n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1, 2, 4, 6])\n        >>> flow.bincount(x)\n        tensor([0, 1, 1, 0, 1, 0, 1], dtype=oneflow.int64)\n        >>> x = flow.tensor([1, 2, 1])\n        >>> weights = flow.tensor([0.1, 0.2, 0.15])\n        >>> flow.bincount(x, weights=weights)\n        tensor([0.0000, 0.2500, 0.2000], dtype=oneflow.float32)\n        >>> flow.bincount(x, weights=weights, minlength=4)\n        tensor([0.0000, 0.2500, 0.2000, 0.0000], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.clone,\n    r\"\"\"oneflow.clone(input) → Tensor\n\n    Returns a copy of input.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.clone.html\n\n    .. note::\n        This function is differentiable, so gradients will flow back from the result\n        of this operation to ``input``. To create a tensor without an autograd relationship\n        to ``input`` see :meth:`detach`.\n\n    Args:\n        input (oneflow.Tensor): input Tensor to be cloned\n\n    For example:\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.Tensor([1, 2, 3])\n        >>> y = flow.clone(x)\n        >>> y\n        tensor([1., 2., 3.], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.frac,\n    r\"\"\"frac(input) → Tensor\n\n    Computes the fractional portion of each element in :attr:`input`.\n\n    .. 
math::\n        \\text{out}_{i} = \\text{input}_{i} - \\left\\lfloor |\\text{input}_{i}| \\right\\rfloor * \\operatorname{sgn}(\\text{input}_{i})\n\n    Args:\n        input: The input Tensor.\n\n    Returns:\n        Tensor: The fractional part of the argument.\n\n    For example:\n    \n        >>> import oneflow as flow\n        >>> flow.frac(flow.Tensor([1, 2.50, -3.21]))\n        tensor([ 0.0000,  0.5000, -0.2100], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.frac_,\n    r\"\"\"\n    In-place version of :func:`oneflow.frac`.\n    \"\"\",\n)\n"
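\n    For example (an illustrative sketch mirroring :func:`oneflow.frac`; the result is inspected via ``x`` so the sketch does not rely on the in-place call's return value):\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.Tensor([1, 2.5, -1.5])\n        >>> _ = flow.frac_(x)\n        >>> x\n        tensor([ 0.0000,  0.5000, -0.5000], dtype=oneflow.float32)\n    \"\"\",\n)\n"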
  },
  {
    "path": "python/oneflow/framework/docstr/as_tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.as_tensor,\n    r\"\"\"\n    as_tensor(data, dtype=None, device=None) -> Tensor\n    \n    Converts data into a tensor, sharing data and preserving autograd history if possible.\n\n    If data is already a tensor with the requeseted dtype and device then data itself is returned, but if data is a tensor with a different dtype or device then it’s copied as if using data.to(dtype=dtype, device=device).\n\n    If data is a NumPy array (an ndarray) with the same dtype and device then a tensor is constructed using oneflow.from_numpy.\n    \n    The interface is consistent with PyTorch.\n\n    Args:\n        data (array_like): Initial data for the tensor. Can be a list, tuple, NumPy ``ndarray``, scalar, and other types.\n        dtype (oneflow.dtype, optional): the desired data type of returned tensor. Default: if ``None``, infers data type from data.\n        device (oneflow.device, optional): the device of the constructed tensor. If ``None`` and data is a tensor then the device of data is used. If None and data is not a tensor then the result tensor is constructed on the CPU.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        \n        >>> a = np.array([1, 2, 3])\n        >>> t = flow.as_tensor(a, device=flow.device('cuda'))\n        >>> t\n        tensor([1, 2, 3], device='cuda:0', dtype=oneflow.int64)\n        >>> t[0] = -1\n        >>> a\n        array([1, 2, 3])\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/autograd.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\nfrom oneflow._oneflow_internal.autograd.Function import FunctionCtx\n\nadd_docstr(\n    FunctionCtx.saved_tensors, \"Get saved tensors in ctx.\",\n)\n\nadd_docstr(\n    FunctionCtx.save_for_backward,\n    \"Saves given tensors for a future call to ``backward()``.\",\n)\n\nadd_docstr(\n    FunctionCtx.mark_non_differentiable, \"Marks outputs as non-differentiable.\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/baddbmm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.baddbmm,\n    r\"\"\"\n    baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.baddbmm.html.\n\n    Performs a batch matrix-matrix product of matrices in :attr:`batch1` and :attr:`batch2`.\n    :attr:`input` is added to the final result.\n\n    :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same\n    number of matrices.\n\n    If :attr:`batch1` is a :math:`(b \\times n \\times m)` tensor, :attr:`batch2` is a\n    :math:`(b \\times m \\times p)` tensor, then :attr:`input` must be\n    broadcastable with a\n    :math:`(b \\times n \\times p)` tensor and :attr:`out` will be a\n    :math:`(b \\times n \\times p)` tensor.\n\n    .. math::\n        \\text{out}_i = \\beta\\ \\text{input}_i + \\alpha\\ (\\text{batch1}_i \\mathbin{@} \\text{batch2}_i)\n\n    If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in it will not be propagated.\n\n    For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and\n    :attr:`alpha` must be real numbers, otherwise they should be integers.\n\n    Args:\n    input (Tensor): the tensor to be added\n    batch1 (Tensor): the first batch of matrices to be multiplied\n    batch2 (Tensor): the second batch of matrices to be multiplied\n\n    Keyword args:\n        beta (Number, optional): multiplier for :attr:`input` (:math:`\\beta`)\n        alpha (Number, optional): multiplier for :math:`\\text{{batch1}} \\mathbin{{@}} \\text{{batch2}}` (:math:`\\alpha`)\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> input = flow.randn(10, 3, 5)\n        >>> batch1 = flow.randn(10, 3, 4)\n        >>> batch2 = flow.randn(10, 4, 5)\n        >>> of_out = flow.baddbmm(input, batch1, batch2)\n        >>> of_out.shape\n        oneflow.Size([10, 3, 5])\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/bitwise_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\n\nadd_docstr(\n    oneflow.bitwise_and,\n    \"\"\"\n    Computes the bitwise AND of input and other.\n    The input tensor must be of integral or Boolean types.\n    For bool tensors, it computes the logical AND.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.bitwise_and.html\n\n    Args:\n        input (oneflow.Tensor): The input Tensor\n        other (oneflow.Tensor): The Tensor to compute bitwise AND with\n\n    Returns:\n        oneflow.Tensor: The output Tensor\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1, 2, 3])\n        >>> flow.bitwise_and(x, 2)\n        tensor([0, 2, 2], dtype=oneflow.int64)\n        >>> y = flow.tensor([5, 6, 7])\n        >>> flow.bitwise_and(x, y)\n        tensor([1, 2, 3], dtype=oneflow.int64)\n\n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow.bitwise_or,\n    \"\"\"\n    Computes the bitwise OR of input and other.\n    The input tensor must be of integral or Boolean types.\n    For bool tensors, it computes the logical OR.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.bitwise_or.html\n\n    Args:\n        input (oneflow.Tensor): The input Tensor\n        other (oneflow.Tensor): The Tensor to compute OR with\n\n    Returns:\n        oneflow.Tensor: The output Tensor\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1, 2, 3])\n        >>> flow.bitwise_or(x, 4)\n        tensor([5, 6, 7], dtype=oneflow.int64)\n        >>> y = flow.tensor([5, 6, 7])\n        >>> flow.bitwise_or(x, y)\n        tensor([5, 6, 7], dtype=oneflow.int64)\n\n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow.bitwise_xor,\n    \"\"\"\n    Computes the bitwise XOR of input and other.\n    The input tensor must be of integral or Boolean types.\n    For bool tensors, it computes the logical XOR.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.bitwise_xor.html\n\n    Args:\n        input (oneflow.Tensor): The input Tensor\n        other (oneflow.Tensor): The Tensor to compute XOR with\n\n    Returns:\n        oneflow.Tensor: The output Tensor\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1, 2, 3])\n        >>> flow.bitwise_xor(x, 2)\n        tensor([3, 0, 1], dtype=oneflow.int64)\n        >>> y = flow.tensor([5, 6, 7])\n        >>> flow.bitwise_xor(x, y)\n        tensor([4, 4, 4], dtype=oneflow.int64)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.bitwise_not,\n    \"\"\"\n    Computes the bitwise NOT of input.\n    The input tensor must be of integral or Boolean types.\n    For bool tensors, it computes the logical NOT.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.bitwise_not.html\n\n    Args:\n        input (oneflow.Tensor): The input Tensor\n\n    Returns:\n        oneflow.Tensor: The output Tensor\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1, 2, 3])\n        >>> flow.bitwise_not(x)\n        tensor([-2, -3, -4], dtype=oneflow.int64)\n        >>> x = flow.tensor([0, 0, 1]).bool()\n        >>> flow.bitwise_not(x)\n        tensor([ True,  True, False], dtype=oneflow.bool)\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/bmm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.bmm,\n    \"\"\"\n    Performs a batch matrix-matrix product of matrices stored in input and mat2.\n\n    `input` and `mat2` must be 3-D tensors each containing the same number of matrices.\n\n    If input is a (b x n x m) tensor, mat2 is a (b x m x p) tensor, out will be a (b x n x p) tensor.\n\n    Args:\n        input(oneflow.Tensor):  the first batch of matrices to be multiplied\n        mat2(oneflow.Tensor): the second batch of matrices to be multiplied\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input1 = flow.randn(10, 3, 4)\n        >>> input2 = flow.randn(10, 4, 5)\n        >>> of_out = flow.bmm(input1, input2)\n        >>> of_out.shape\n        oneflow.Size([10, 3, 5])\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/broadcast_like.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.broadcast_like,\n    \"\"\"\n    This operator broadcast tensor `x` to `like_tensor` according to the broadcast_axes. \n\n    Args:\n        x (Tensor): The input Tensor. \n        like_tensor (Tensor): The like Tensor. \n        broadcast_axes (Optional[Sequence], optional): The axes you want to broadcast. Defaults to None.\n\n    Returns:\n        [Tensor]: Broadcasted input Tensor. \n\n    For example: \n\n    .. code:: python\n\n        >>> import oneflow as flow \n\n        >>> x = flow.randn(3, 1, 1)\n        >>> like_tensor = flow.randn(3, 4, 5)\n        >>> broadcast_tensor = flow.broadcast_like(x, like_tensor, broadcast_axes=[1, 2]) \n        >>> broadcast_tensor.shape\n        oneflow.Size([3, 4, 5])\n    \n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/cast.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.cast,\n    \"\"\"\n    \n    The operation takes input tensor `x` and casts it to the output with `dtype`\n\n    Args:\n        x (oneflow.Tensor): A Tensor\n        dtype (flow.dtype): Data type of the output tensor\n\n    Returns:\n        oneflow.Tensor: A Tensor with specific dtype.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> np_arr = np.random.randn(2, 3, 4, 5).astype(np.float32)\n        >>> input = flow.tensor(np_arr, dtype=flow.float32)\n        >>> output = flow.cast(input, flow.int8)\n        >>> np.array_equal(output.numpy(), np_arr.astype(np.int8))\n        True\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/chunk.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.chunk,\n    \"\"\"Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor. Last chunk will be bigger if the tensor size along the given dimension dim is not divisible by chunks.\n\n    Args:\n        input (oneflow.Tensor): The tensor to split.\n        chunks (int): Number of chunks to return.\n        dim (int): Dimension along which to split the tensor.\n\n    Returns:\n        List of Tensors.\n\n    For example:\n\n    .. code-block:: python\n    \n        >>> import oneflow as flow\n        >>> import numpy as np\n               \n        >>> arr = np.random.randn(5, 3, 6, 9).astype(np.float32)\n        >>> input = flow.tensor(arr)\n        >>> output = []\n        >>> chunks = 3\n        >>> output = flow.chunk(input, chunks=chunks, dim=2)\n        >>> out_shape = []\n        >>> for i in range(0, chunks):\n        ...     out_shape.append(output[i].numpy().shape)\n        >>> out_shape\n        [(5, 3, 2, 9), (5, 3, 2, 9), (5, 3, 2, 9)]\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/clamp.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.clamp,\n    \"\"\"\n    Clamp all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]` and return\n    a resulting tensor:\n\n    .. math::\n        y_i = \\\\begin{cases}\n            \\\\text{min} & \\\\text{if } x_i < \\\\text{min} \\\\\\\\\n            x_i & \\\\text{if } \\\\text{min} \\\\leq x_i \\\\leq \\\\text{max} \\\\\\\\\n            \\\\text{max} & \\\\text{if } x_i > \\\\text{max}\n        \\\\end{cases}\n\n    If :attr:`input` is of type `FloatTensor` or `DoubleTensor`, args :attr:`min`\n    and :attr:`max` must be real numbers, otherwise they should be integers.\n\n    Args:\n        input (Tensor): the input tensor.\n        min (Number): lower-bound of the range to be clamped to. Defaults to None.\n        max (Number): upper-bound of the range to be clamped to. Defaults to None.\n        out (Tensor, optional): the output tensor.\n\n    For example:\n\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> arr = np.array([0.2, 0.6, -1.5, -0.3])\n        >>> input = flow.Tensor(arr)\n        >>> output = flow.clamp(input, min=-0.5, max=0.5)\n        >>> output\n        tensor([ 0.2000,  0.5000, -0.5000, -0.3000], dtype=oneflow.float32)\n\n        >>> arr = np.array([0.2, 0.6, -1.5, -0.3])\n        >>> input = flow.Tensor(arr)\n        >>> output = flow.clamp(input, min=None, max=0.5)\n        >>> output\n        tensor([ 0.2000,  0.5000, -1.5000, -0.3000], dtype=oneflow.float32)\n\n        >>> arr = np.array([0.2, 0.6, -1.5, -0.3])\n        >>> input = flow.Tensor(arr)\n        >>> output = flow.clamp(input, min=-0.5, max=None)\n        >>> output\n        tensor([ 0.2000,  0.6000, -0.5000, -0.3000], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.clamp_min,\n    \"\"\"\n    Clamp all elements in :attr:`input` which are less than :attr:`min` to :attr:`min` and return\n    a resulting tensor:\n\n    .. math::\n        y_i = \\max(min, x_i)\n\n    If :attr:`input` is of type `FloatTensor` or `DoubleTensor`, args :attr:`min`\n    must be real numbers, otherwise they should be integers.\n\n    Args:\n        input (Tensor): the input tensor.\n        min (Number): lower-bound of the range to be clamped to.\n        out (Tensor, optional): the output tensor.\n\n    For example:\n\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> input = flow.Tensor([0.2, 0.6, -1.5, -0.3])\n        >>> output = flow.clamp_min(input, min=-0.5)\n        >>> output\n        tensor([ 0.2000,  0.6000, -0.5000, -0.3000], dtype=oneflow.float32)\n\n        >>> input = flow.Tensor([0.2, 0.6, -1.5, -0.3])\n        >>> output = flow.clamp_min(input, min=-2)\n        >>> output\n        tensor([ 0.2000,  0.6000, -1.5000, -0.3000], dtype=oneflow.float32)\n\n        >>> input = flow.Tensor([0.2, 0.6, -1.5, -0.3])\n        >>> output = flow.clamp_min(input, min=1)\n        >>> output\n        tensor([1., 1., 1., 1.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.clamp_max,\n    \"\"\"\n    Clamp all elements in :attr:`input` which are greater than :attr:`max` to :attr:`max` and return\n    a resulting tensor:\n\n    .. math::\n        y_i = \\min(max, x_i)\n\n    If :attr:`input` is of type `FloatTensor` or `DoubleTensor`, args :attr:`max`\n    must be real numbers, otherwise they should be integers.\n\n    Args:\n        input (Tensor): the input tensor.\n        max (Number): upper-bound of the range to be clamped to.\n        out (Tensor, optional): the output tensor.\n\n    For example:\n\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> input = flow.Tensor([0.2, 0.6, -1.5, -0.3])\n        >>> output = flow.clamp_max(input, max=-0.5)\n        >>> output\n        tensor([-0.5000, -0.5000, -1.5000, -0.5000], dtype=oneflow.float32)\n\n        >>> input = flow.Tensor([0.2, 0.6, -1.5, -0.3])\n        >>> output = flow.clamp_max(input, max=-2)\n        >>> output\n        tensor([-2., -2., -2., -2.], dtype=oneflow.float32)\n\n        >>> input = flow.Tensor([0.2, 0.6, -1.5, -0.3])\n        >>> output = flow.clamp_max(input, max=1)\n        >>> output\n        tensor([ 0.2000,  0.6000, -1.5000, -0.3000], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.clip,\n    \"\"\"\n    Alias for :func:`oneflow.clamp`. \n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/comm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.comm.send,\n    \"\"\"Sends a tensor synchronously.\n\n    Args:\n        tensor (Tensor): Tensor to send.\n        dst (int): Destination rank.\n        send_meta (Bool): Whether to send meta information (default is True)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.comm.recv,\n    \"\"\"Receives a tensor synchronously.\n    \n    All(send_meta is False) or none of shape, dtype and device should have value.\n\n    Args:\n        src (int, optional): Source rank. Will receive from any\n            process if unspecified.\n        shape (optional): output tensor shape.\n        dataType (optional): output tensor data type.\n        device (optional): output tensor device.\n        out (Tensor, optional): Tensor to fill with received data.\n    \n    Returns:\n        if out is None, return received tensor. otherwise got data from out self without return.\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/comparison.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.greater,\n    \"\"\"Returns the truth value of :math:`input > other` element-wise.\n\n    Args:\n        input (oneflow.Tensor): A Tensor\n        other (oneflow.Tensor): A Tensor\n\n    Returns:\n        oneflow.Tensor: A Tensor with bool type.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> input1 = flow.tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)\n        >>> input2 = flow.tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)\n\n        >>> out = flow.gt(input1, input2).shape\n        >>> out\n        oneflow.Size([2, 6, 5, 3])\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.greater_equal,\n    \"\"\"Returns the truth value of :math:`input >= other` element-wise.\n\n    Args:\n        input (oneflow.Tensor): A Tensor\n        other (oneflow.Tensor): A Tensor\n\n    Returns:\n        oneflow.Tensor: A Tensor with bool type.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> input1 = flow.tensor(np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32)\n        >>> input2 = flow.tensor(np.array([1, 1, 4]).astype(np.float32), dtype=flow.float32)\n\n        >>> out = flow.ge(input1, input2)\n        >>> out\n        tensor([ True,  True, False], dtype=oneflow.bool)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.eq,\n    \"\"\"eq(input, other) -> Tensor\n\n    Computes element-wise equality.\n    The second argument can be a number or a tensor whose shape is broadcastable with the first argument.\n\n    Args:\n        input (oneflow.Tensor): the tensor to compare\n        other (oneflow.Tensor, float or int): the target to compare\n\n    Returns:\n\n        - A boolean tensor that is True where :attr:`input` is equal to :attr:`other` and False elsewhere\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        \n        >>> input = flow.tensor(np.array([2, 3, 4, 5]), dtype=flow.float32)\n        >>> other = flow.tensor(np.array([2, 3, 4, 1]), dtype=flow.float32)\n\n        >>> y = flow.eq(input, other)\n        >>> y\n        tensor([ True,  True,  True, False], dtype=oneflow.bool)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.equal,\n    \"\"\"equal(input, other) -> bool\n\n    `True` if two tensors have the same size and elements, `False` otherwise.\n\n    Args:\n        input (oneflow.Tensor): the tensor to compare\n        other (oneflow.Tensor): the target to compare\n\n    Returns:\n        A boolean value\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        \n        >>> input = flow.tensor(np.array([2, 3, 4, 5]), dtype=flow.float32)\n        >>> other = flow.tensor(np.array([2, 3, 4, 1]), dtype=flow.float32)\n\n        >>> y = flow.equal(input, other)\n        >>> y\n        False\n\n        >>> y = flow.equal(input, input)\n        >>> y\n        True\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.ne,\n    \"\"\"ne(input, other) -> Tensor\n\n    Computes element-wise inequality.\n    The second argument can be a number or a tensor whose shape is broadcastable with the first argument.\n\n    Args:\n        input (oneflow.Tensor): the tensor to compare\n        other (oneflow.Tensor, float or int): the target to compare\n\n    Returns:\n\n        - A boolean tensor that is True where :attr:`input` is not equal to :attr:`other` and False elsewhere\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        \n        >>> input = flow.tensor(np.array([2, 3, 4, 5]), dtype=flow.float32)\n        >>> other = flow.tensor(np.array([2, 3, 4, 1]), dtype=flow.float32)\n\n        >>> y = flow.ne(input, other)\n        >>> y\n        tensor([False, False, False,  True], dtype=oneflow.bool)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.lt,\n    \"\"\"lt(input, other) -> Tensor\n\n    Returns the truth value of :math:`input < other` element-wise.\n\n    Args:\n        input (oneflow.Tensor): A Tensor\n        other (oneflow.Tensor): A Tensor\n\n    Returns:\n        oneflow.Tensor: A Tensor with bool type.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> input1 = flow.tensor(np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32)\n        >>> input2 = flow.tensor(np.array([1, 2, 4]).astype(np.float32), dtype=flow.float32)\n\n        >>> out = flow.lt(input1, input2)\n        >>> out\n        tensor([False, False,  True], dtype=oneflow.bool)\n\n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow.le,\n    \"\"\"le(input, other) -> Tensor\n    \n    Returns the truth value of :math:`input <= other` element-wise.\n\n    Args:\n        input (oneflow.Tensor): A Tensor\n        other (oneflow.Tensor): A Tensor\n\n    Returns:\n        oneflow.Tensor: A Tensor with bool type.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> input1 = flow.tensor(np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32)\n        >>> input2 = flow.tensor(np.array([1, 1, 4]).astype(np.float32), dtype=flow.float32)\n\n        >>> out = flow.le(input1, input2)\n        >>> out\n        tensor([ True, False,  True], dtype=oneflow.bool)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.isclose,\n    r\"\"\"isclose(input, other, atol=1e-08, rtol=1e-05, equal_nan=False) -> Tensor\n    \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.isclose.html\n\n    Returns a new tensor with boolean elements representing if each element of\n    :attr:`input` is \"close\" to the corresponding element of :attr:`other`.\n    Closeness is defined as:\n\n    .. 
math::\n        \\lvert \\text{input} - \\text{other} \\rvert \\leq \\texttt{atol} + \\texttt{rtol} \\times \\lvert \\text{other} \\rvert\n\n    Args:\n        input (oneflow.Tensor): first tensor to compare\n        other (oneflow.Tensor): second tensor to compare\n        atol (float, optional): absolute tolerance. Default: 1e-08\n        rtol (float, optional): relative tolerance. Default: 1e-05\n        equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. Default: ``False``\n\n    Returns:\n        oneflow.Tensor: A Tensor with bool type.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        \n        >>> flow.isclose(flow.tensor((1., 2, 3)), flow.tensor((1 + 1e-10, 3, 4)))\n        tensor([ True, False, False], dtype=oneflow.bool)\n\n        >>> flow.isclose(flow.tensor((float('inf'), 4)), flow.tensor((float('inf'), 6)), rtol=.5)\n        tensor([True, True], dtype=oneflow.bool)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.allclose,\n    r\"\"\"allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> bool\n    \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.allclose.html\n\n    This function checks if :attr:`input` and :attr:`other` satisfy the condition:\n\n    .. math::\n        \\lvert \\text{input} - \\text{other} \\rvert \\leq \\texttt{atol} + \\texttt{rtol} \\times \\lvert \\text{other} \\rvert\n\n    elementwise, for all elements of :attr:`input` and :attr:`other`. The behaviour of this function is analogous to\n    `numpy.allclose <https://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html>`_\n\n    Args:\n        input (oneflow.Tensor): first tensor to compare\n        other (oneflow.Tensor): second tensor to compare\n        atol (float, optional): absolute tolerance. Default: 1e-08\n        rtol (float, optional): relative tolerance. Default: 1e-05\n        equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. Default: ``False``\n\n    Returns:\n        bool: ``True`` if :attr:`input` and :attr:`other` are close element-wise, ``False`` otherwise.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> flow.allclose(flow.tensor([10000., 1e-07]), flow.tensor([10000.1, 1e-08]))\n        False\n        >>> flow.allclose(flow.tensor([10000., 1e-08]), flow.tensor([10000.1, 1e-09]))\n        True\n        >>> flow.allclose(flow.tensor([1.0, float('nan')]), flow.tensor([1.0, float('nan')]))\n        False\n        >>> flow.allclose(flow.tensor([1.0, float('nan')]), flow.tensor([1.0, float('nan')]), equal_nan=True)\n        True\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/constant.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.ones_like,\n    \"\"\"\n    ones_like(input, *, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor\n\n    The interface is consistent with PyTorch.    \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.ones_like.html.\n\n    Returns a tensor filled with the scalar value 1, with the same size as input.\n    flow.ones_like(input) is equivalent to flow.ones(input.shape, dtype=input.dtype)\n\n    Args:\n        input(Tensor): The size of input will determine size of the output tensor.\n        dtype (flow.dtype, optional):  the desired type of returned tensor. Default: if None, same flow.dtype as this tensor.\n        device (flow.device, optional): the desired device of returned tensor. Default: if None, same flow.device as this tensor.\n        placement (flow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`.\n        sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x = flow.tensor(np.random.rand(5), dtype=flow.float32)\n        >>> y = flow.ones_like(x)\n        >>> y\n        tensor([1., 1., 1., 1., 1.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.zeros_like,\n    \"\"\"\n    zeros_like(input, *, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor\n\n    The interface is consistent with PyTorch.    \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.zeros_like.html.\n\n    Returns a tensor filled with the scalar value 0, with the same size as input.\n    flow.zeros_like(input) is equivalent to flow.zeros(input.shape, dtype=input.dtype)\n\n    Args:\n        input(Tensor): The size of input will determine size of the output tensor.\n        dtype (flow.dtype, optional):  the desired type of returned tensor. Default: if None, same flow.dtype as this tensor.\n        device (flow.device, optional): the desired device of returned tensor. Default: if None, same flow.device as this tensor.\n        placement (flow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`.\n        sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. 
Default: if None, the returned tensor is a local one using the argument `device`.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x = flow.tensor(np.random.rand(5), dtype=flow.float32)\n        >>> y = flow.zeros_like(x)\n        >>> y\n        tensor([0., 0., 0., 0., 0.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.new_ones,\n    \"\"\"\n    new_ones(x, size=None, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.Tensor.new_ones.html.\n\n    Returns a Tensor of size ``size`` filled with 1. By default, the returned Tensor has the same oneflow.dtype and oneflow.device as this tensor.\n\n    Args:\n        size (int...): a list, tuple, or flow.Size of integers defining the shape of the output tensor.\n        dtype (flow.dtype, optional):  the desired type of returned tensor. Default: if None, same flow.dtype as this tensor.\n        device (flow.device, optional): the desired device of returned tensor. Default: if None, same flow.device as this tensor.\n        placement (flow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is a local one using the argument `device`.\n        sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is a local one using the argument `device`.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x = flow.Tensor(np.ones((1, 2, 3)))\n        >>> y = x.new_ones((2, 2))\n        >>> y\n        tensor([[1., 1.],\n                [1., 1.]], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.empty,\n    \"\"\"\n    empty(*size, *, dtype=None, device=None, placement=None, sbp=None, requires_grad=False, pin_memory=False) -> Tensor\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.empty.html.\n\n    Returns a tensor filled with uninitialized data.\n    The shape of the tensor is defined by the variable argument ``size``.\n\n    Args:\n        size (int... or oneflow.Size): Defining the shape of the output tensor.\n          Can be a variable number of arguments or a collection like a list or tuple or oneflow.Size.\n        dtype (flow.dtype, optional): The desired data type of returned tensor. Default: ``flow.float32``.\n        device (oneflow.device, optional): The desired device of returned local tensor. If None, uses the\n          current device.\n        placement (flow.placement, optional): The desired device of returned global tensor. If None, will\n          construct local tensor.\n        sbp (flow.sbp or List[flow.sbp], optional): The desired sbp of returned global tensor.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n        pin_memory (bool, optional): If set, the returned tensor will be allocated in pinned memory. Works only for CPU tensors. 
Default: False.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> y = flow.empty(4, 5)  # construct a local empty tensor\n        >>> y.shape\n        oneflow.Size([4, 5])\n        >>> y.is_global\n        False\n        >>> placement = flow.placement(\"cpu\", ranks=[0])\n        >>> y = flow.empty(4, 5, placement=placement, sbp=flow.sbp.broadcast)  # construct a global empty tensor\n        >>> y.is_global\n        True\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.empty_like,\n    \"\"\"\n    empty_like(input, *, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.empty_like.html.\n\n    Returns an uninitialized tensor with the same size as :attr:`input`.\n    ``oneflow.empty_like(input)`` is equivalent to\n    ``oneflow.empty(input.size(), dtype=input.dtype, device=input.device)``.\n\n    Args:\n        input(Tensor): The size of input will determine size of the output tensor.\n        dtype (flow.dtype, optional): The desired data type of returned tensor. Default: ``flow.float32``.\n        device (oneflow.device, optional): The desired device of returned local tensor. If None, uses the\n          current device.\n        placement (flow.placement, optional): The desired device of returned global tensor. If None, will\n          construct local tensor.\n        sbp (flow.sbp or List[flow.sbp], optional): The desired sbp of returned global tensor.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.randn(2, 3)\n        >>> y = flow.empty_like(x)\n        >>> y.shape\n        oneflow.Size([2, 3])\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/conv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow._C.conv1d,\n    r\"\"\"\n    conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor\n\n    Applies a 1D convolution over an input signal composed of several input\n    planes.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.conv1d.html.\n\n    See :class:`~oneflow.nn.Conv1d` for details and output shape.\n\n    Args:\n        input: input tensor of shape :math:`(\\text{minibatch} , \\text{in_channels} , iW)`\n        weight: filters of shape :math:`(\\text{out_channels} , \\frac{\\text{in_channels}}{\\text{groups}} , iW)`\n        bias: optional bias of shape :math:`(\\text{out_channels})`. Default: None.\n        stride: the stride of the convolving kernel. Can be a single number or a\n          tuple `(sW,)`. Default: 1\n        padding: implicit paddings on both sides of the input. Can be a\n          single number or a tuple `(padW,)`. Default: 0\n        dilation: the spacing between kernel elements. Can be a single number or\n          a tuple `(dW,)`. Default: 1\n        groups: split input into groups, :math:`\\text{in_channels}` should be divisible by the\n          number of groups. Default: 1\n\n    For examples:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n        \n        >>> inputs = flow.randn(33, 16, 30)\n        >>> filters = flow.randn(20, 16, 5)\n        >>> outputs = F.conv1d(inputs, filters)\n        \"\"\",\n)\nadd_docstr(\n    oneflow._C.conv2d,\n    r\"\"\"\n    conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor\n\n    Applies a 2D convolution over an input image composed of several input\n    planes.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.conv2d.html.\n\n    See :class:`~oneflow.nn.Conv2d` for details and output shape.\n\n    Args:\n        input: input tensor of shape :math:`(\\text{minibatch} , \\text{in_channels} , iH , iW)`\n        weight: filters of shape :math:`(\\text{out_channels} , \\frac{\\text{in_channels}}{\\text{groups}} , kH , kW)`\n        bias: optional bias of shape :math:`(\\text{out_channels})`. Default: None.\n        stride: the stride of the convolving kernel. Can be a single number or a\n          tuple `(sH, sW)`. Default: 1\n        padding: implicit paddings on both sides of the input. Can be a\n          single number or a tuple `(padH, padW)`. Default: 0\n        dilation: the spacing between kernel elements. Can be a single number or\n          a tuple `(dH, dW)`. Default: 1\n        groups: split input into groups, :math:`\\text{in_channels}` should be divisible by the\n          number of groups. Default: 1\n\n    For examples:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n        \n        >>> inputs = flow.randn(8, 4, 3, 3)\n        >>> filters = flow.randn(1, 4, 5, 5)\n        >>> outputs = F.conv2d(inputs, filters, padding=1)\n    \n        \"\"\",\n)\nadd_docstr(\n    oneflow._C.conv3d,\n    r\"\"\"\n    conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor\n\n    Applies a 3D convolution over an input image composed of several input\n    planes.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.conv3d.html.\n\n    See :class:`~oneflow.nn.Conv3d` for details and output shape.\n\n    Args:\n        input: input tensor of shape\n          :math:`(\\text{minibatch} , \\text{in_channels} , iD , iH , iW)`\n        weight: filters of shape\n          :math:`(\\text{out_channels} , \\frac{\\text{in_channels}}{\\text{groups}} , kD , kH , kW)`\n        bias: optional bias of shape :math:`(\\text{out_channels})`. Default: None.\n        stride: the stride of the convolving kernel. Can be a single number or a\n          tuple `(sD, sH, sW)`. Default: 1\n        padding: implicit paddings on both sides of the input. Can be a\n          single number or a tuple `(padD, padH, padW)`. Default: 0\n        dilation: the spacing between kernel elements. Can be a single number or\n          a tuple `(dD, dH, dW)`. Default: 1\n        groups: split input into groups, :math:`\\text{in_channels}` should be\n          divisible by the number of groups. Default: 1\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n        \n        >>> inputs = flow.randn(20, 16, 50, 10, 20)\n        >>> filters = flow.randn(33, 16, 3, 3, 3)\n        >>> outputs = F.conv3d(inputs, filters)\n        \n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/convolution.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.nn.functional.fold,\n    r\"\"\"\n    fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1)\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.fold.html.\n    \n    Combines an array of sliding local blocks into a large containing tensor.\n\n    .. warning::\n        Currently, only 3-D input tensors (batched image-like tensors) are supported, and only unbatched (3D) \n        or batched (4D) image-like output tensors are supported.\n\n    See :class:`oneflow.nn.Fold` for details.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.nn.functional.unfold,\n    r\"\"\"\n    unfold(input, kernel_size, dilation=1, padding=0, stride=1)\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.unfold.html.\n\n    Extracts sliding local blocks from a batched input tensor.\n\n    .. warning::\n        Currently, only 4-D input tensors (batched image-like tensors) are supported.\n\n    .. warning::\n\n        More than one element of the unfolded tensor may refer to a single\n        memory location. As a result, in-place operations (especially ones that\n        are vectorized) may result in incorrect behavior. If you need to write\n        to the tensor, please clone it first.\n\n\n    See :class:`oneflow.nn.Unfold` for details.\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/ctc_decode.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow._C.ctc_greedy_decoder,\n    \"\"\"\n    ctc_greedy_decoder(log_probs: Tensor, input_lengths: Tensor, merge_repeated: bool=True) -> Tensor\n\n    Performs greedy decoding on the logits given in input (best path).\n\n    Args:\n        log_probs(oneflow.Tensor): A Tensor of shape [input_length, batch_size, num_labels]. The logarithmized probabilities of the outputs (e.g. obtained with flow.nn.logsoftmax()).\n        input_lengths(oneflow.Tensor): A Tensor of shape [batch_size]. It represent the lengths of the inputs. And the lengths are specified for each sequence to achieve masking under the assumption that sequences are padded to equal lengths.\n        merge_repeated (bool, optional): If merge_repeated is True, merge repeated classes in output. This means that if consecutive logits' maximum indices are the same, only the first of these is emitted. Defaults to True.\n\n    Returns:\n        decoded(oneflow.Tensor): A Tensor of shape [batch_size, input_length], The decoded outputs.\n        neg_sum_logits(oneflow.Tensor): A float matrix (batch_size x 1) containing, for the sequence found, the negative of the sum of the greatest logit at each timeframe.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> log_probs = flow.tensor(\n        ...     [\n        ...         [[-1.54, -1.20, -1.95, -1.65, -1.81], [-1.84, -1.74, -1.58, -1.55, -1.12]],\n        ...         [[-1.68, -1.48, -1.89, -1.30, -2.07], [-1.13, -1.45, -1.24, -1.61, -1.66]],\n        ...         [[-1.56, -1.40, -2.83, -1.67, -1.48], [-1.20, -2.01, -2.05, -1.95, -1.24]],\n        ...         [[-2.09, -1.76, -1.36, -1.67, -1.45], [-1.85, -1.48, -1.34, -2.16, -1.55]],\n        ...     ]\n        ... )\n        >>> input_lengths = flow.tensor([4, 4])\n        >>> decoded, neg_sum_logits = flow.nn.functional.ctc_greedy_decoder(log_probs, input_lengths)\n        >>> decoded\n        tensor([[1, 3, 1, 2],\n                [0, 2, 0, 0]], dtype=oneflow.int64)\n        >>> neg_sum_logits\n        tensor([[5.2600],\n                [4.7900]], dtype=oneflow.float32)\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/dataset.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n"
  },
  {
    "path": "python/oneflow/framework/docstr/deconv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow._C.deconv1d,\n    r\"\"\"\n    conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor\n\n    Applies a 1D transposed convolution operator over an input signal composed of several input planes, sometimes also called “deconvolution”.\n    \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.conv_transpose1d.html\n\n    See :class:`~oneflow.nn.ConvTranspose1d` for details and output shape.\n\n    Args:\n        input: input tensor of shape :math:`(\\text{minibatch} , \\text{in_channels} , iW)`\n        weight: filters of shape :math:`(\\text{in_channels} , \\frac{\\text{out_channels}}{\\text{groups}} , kW)`\n        bias: optional bias of shape :math:`(\\text{out_channels})`. Default: None.\n        stride: the stride of the convolving kernel. Can be a single number or a\n          tuple `(sW,)`. Default: 1\n        padding: `dilation * (kernel_size - 1) - padding` zero-padding will be added to both sides of each dimension in the input. Can be a single number or a tuple `(padW,)`. Default: 0\n        output_padding: additional size added to one side of each dimension in the output shape. Can be a single number or a tuple `(out_padW)`. Default: 0\n        groups: split input into groups, :math:`\\text{in_channels}` should be divisible by the\n          number of groups. Default: 1\n        dilation: the spacing between kernel elements. Can be a single number or\n          a tuple `(dW,)`. Default: 1\n\n    For examples:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n        \n        >>> inputs = flow.randn(20, 16, 50)\n        >>> weights = flow.randn(16, 33, 5)\n        >>> outputs = F.conv_transpose1d(inputs, weights)\n        \"\"\",\n)\nadd_docstr(\n    oneflow._C.deconv2d,\n    r\"\"\"\n    conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor\n\n    Applies a 2D transposed convolution operator over an input image composed of several input planes, sometimes also called “deconvolution”.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.conv_transpose3d.html\n\n    See :class:`~oneflow.nn.ConvTranspose2d` for details and output shape.\n\n    Args:\n        input: input tensor of shape :math:`(\\text{minibatch} , \\text{in_channels} , iH , iW)`\n        weight: filters of shape :math:`(\\text{in_channels} , \\frac{\\text{out_channels}}{\\text{groups}} , kH , kW)`\n        bias: optional bias of shape :math:`(\\text{out_channels})`. Default: None.\n        stride: the stride of the convolving kernel. Can be a single number or a\n          tuple `(sH, sW)`. 
Default: 1\n        padding: `dilation * (kernel_size - 1) - padding` zero-padding will be added to both sides of each dimension in the input. Can be a single number or a tuple `(padH, padW)`. Default: 0\n        output_padding: additional size added to one side of each dimension in the output shape. Can be a single number or a tuple `(out_padH, out_padW)`. Default: 0\n        groups: split input into groups, :math:`\\text{in_channels}` should be divisible by the\n          number of groups. Default: 1\n        dilation: the spacing between kernel elements. Can be a single number or\n          a tuple `(dH, dW)`. Default: 1\n    \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n        \n        >>> inputs = flow.randn(1, 4, 5, 5)\n        >>> weights = flow.randn(4, 8, 3, 3)\n        >>> outputs = F.conv_transpose2d(inputs, weights, padding=1)\n        \"\"\",\n)\nadd_docstr(\n    oneflow._C.deconv3d,\n    r\"\"\"\n    conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor\n\n    Applies a 3D transposed convolution operator over an input image composed of several input planes, sometimes also called “deconvolution”.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.conv_transpose3d.html\n\n    See :class:`~oneflow.nn.ConvTranspose3d` for details and output shape.\n\n    Args:\n        input: input tensor of shape\n          :math:`(\\text{minibatch} , \\text{in_channels} , iT , iH , iW)`\n        weight: filters of shape\n          :math:`(\\text{in_channels} , \\frac{\\text{out_channels}}{\\text{groups}} , kT , kH , kW)`\n        bias: optional bias of shape :math:`(\\text{out_channels})`. Default: None.\n        stride: the stride of the convolving kernel. Can be a single number or a\n          tuple `(sD, sH, sW)`. Default: 1\n        padding: `dilation * (kernel_size - 1) - padding` zero-padding will be added to both sides of each dimension in the input. Can be a single number or a tuple `(padT, padH, padW)`. Default: 0\n        output_padding: additional size added to one side of each dimension in the output shape. Can be a single number or a tuple `(out_padT, out_padH, out_padW)`. Default: 0\n        groups: split input into groups, :math:`\\text{in_channels}` should be\n          divisible by the number of groups. Default: 1\n        dilation: the spacing between kernel elements. Can be a single number or\n          a tuple `(dT, dH, dW)`. Default: 1\n        \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n        \n        >>> inputs = flow.randn(20, 16, 50, 10, 20)\n        >>> weights = flow.randn(16, 33, 3, 3, 3)\n        >>> outputs = F.conv_transpose3d(inputs, weights)\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/depend.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow._C.depend,\n    r\"\"\"\n    Add control dependency to guarantee OP A is executed before OP B.\n    Used to prevent OPs from being rearranged or eliminated during graph compilation.\n    Args:\n        input (Tensor): a tensor intended to input OP B\n        depend (Tensor or List[Tensor]): one of the output tensors of OP A (support passing in multiple tensors form different OP)\n    Returns:\n        Tensor: the identity of \"input\" tensor\n    Examples:\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n        >>> import oneflow.nn.functional as F\n        >>> class Model(nn.Module):\n        ...     def __init__(self):\n        ...         super().__init__()\n        ...         self.OP_A = nn.Linear(128, 128)\n        ...         self.OP_B = nn.Linear(128, 128)\n        ...\n        ...     def forward(self, x):\n        ...         x1 = self.OP_A(x)\n        ...         x = F.depend(x, x1)\n        ...         return self.OP_B(x)\n        ...\n        >>> model = Model()\n        >>> class Graph(nn.Graph):\n        ...     def __init__(self) -> None:\n        ...         super().__init__()\n        ...         self.model = model\n        ...\n        ...     def build(self, x):\n        ...         return self.model(x)\n        ...\n        >>> graph = Graph()\n        >>> x = flow.randn([1, 128], dtype=flow.float32)\n        >>> y = graph(x)\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/distance.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow._C.cosine_similarity,\n    r\"\"\"\n    cosine_similarity(x1: Tensor, x2: Tensor, dim: int=1, eps: float=1e-8) -> Tensor\n\n    Returns cosine similarity between ``x1`` and ``x2``, computed along dim. ``x1`` and ``x2`` must be broadcastable\n    to a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is\n    squeezed (see :func:`oneflow.squeeze`), resulting in the\n    output tensor having 1 fewer dimension.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.cosine_similarity.html\n    \n    .. math ::\n        \\text{similarity} = \\dfrac{x_1 \\cdot x_2}{\\max(\\Vert x_1 \\Vert _2 \\cdot \\Vert x_2 \\Vert _2, \\epsilon)}\n    \n    Args:\n        x1 (Tensor): First input.\n        x2 (Tensor): Second input.\n        dim (int, optional): Dimension along which cosine similarity is computed. Default: 1\n        eps (float, optional): Small value to avoid division by zero.\n            Default: 1e-8\n\n    For examples:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n        >>> input1 = flow.randn(100, 128)\n        >>> input2 = flow.randn(100, 128)\n        >>> output = F.cosine_similarity(input1, input2)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.pairwise_distance,\n    r\"\"\"\n    pairwise_distance(x1: Tensor, x2: Tensor, dim: float=2.0, eps: float=1e-6, keepdim: bool=False) -> Tensor\n    Computes the pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm:\n\n    .. math ::\n        \\left \\| x \\right \\| _p = (\\sum_{i=1}^n \\left | x_i \\right |^p )^{\\frac{1}{p}}\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.PairwiseDistance.html.\n\n    Args:\n        x1 (Tensor): First input.\n        x2 (Tensor): Second input.\n        p (real): the norm degree. Default: 2\n        eps (float, optional): Small value to avoid division by zero. Default: 1e-6\n        keepdim (bool, optional): Determines whether or not to keep the vector dimension. Default: False\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x1 = flow.arange(12).reshape(3, 4)\n        >>> x2 = flow.arange(12).reshape(3, 4)\n        >>> output = flow.nn.functional.pairwise_distance(x1, x2, p=2)\n        >>> output\n        tensor([2.0000e-06, 2.0000e-06, 2.0000e-06], dtype=oneflow.float32)\n        >>> output.shape\n        oneflow.Size([3])\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/dropout.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow._C.dropout,\n    \"\"\"\n    dropout(x: Tensor, p: float = 0.5, training: bool = True, generator :Generator = None, *, addend: Tensor) -> Tensor \n    \n    During training, randomly zeroes some of the elements of the input\n    tensor with probability :attr:`p` using samples from a Bernoulli\n    distribution.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.functional.dropout.html.\n\n    Args:      \n        x(Tensor): A Tensor which will be applyed dropout. \n        p(float): probability of an element to be zeroed. Default: 0.5    \n        training(bool): If is True it will apply dropout. Default: True     \n        generator(Generator, optional):  A pseudorandom number generator for sampling\n        addend(Tensor, optional):  A Tensor add in result after dropout, it can be used in model's residual connection structure. Default: None  \n\n    Shape:\n        - Input: :math:`(*)`. Input can be of any shape\n        - Output: :math:`(*)`. Output is of the same shape as input\n\n    For example:\n\n    Example 1: \n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n       \n        >>> arr = np.array(\n        ...    [\n        ...        [-0.7797, 0.2264, 0.2458, 0.4163],\n        ...        [0.4299, 0.3626, -0.4892, 0.4141],\n        ...        [-1.4115, 1.2183, -0.5503, 0.6520],\n        ...    ]\n        ... )\n        >>> x = flow.tensor(arr, dtype=flow.float32)\n        >>> y = flow.nn.functional.dropout(x, p=0) \n\n        >>> arr = np.array(\n        ...    [\n        ...        [-0.7797, 0.2264, 0.2458, 0.4163],\n        ...        [0.4299, 0.3626, -0.4892, 0.4141],\n        ...        [-1.4115, 1.2183, -0.5503, 0.6520],\n        ...    ]\n        ... )\n        >>> x = flow.tensor(arr, dtype=flow.float32)\n        >>> generator = flow.Generator()\n        >>> y = flow.nn.functional.dropout(x, p=0.5, generator=generator) \n      \n    Example 2: \n    \n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n       \n        >>> arr = np.array(\n        ...    [\n        ...        [-0.7797, 0.2264, 0.2458, 0.4163],\n        ...        [0.4299, 0.3626, -0.4892, 0.4141],\n        ...        [-1.4115, 1.2183, -0.5503, 0.6520],\n        ...    ]\n        ... )\n        >>> x = flow.tensor(arr, dtype=flow.float32)\n        >>> addend = flow.ones((3, 4), dtype=flow.float32)\n        >>> y = flow.nn.functional.dropout(x, p=0, addend=addend) \n        >>> y #doctest: +ELLIPSIS\n        tensor([[ 0.2203,  1.2264,  1.2458,  1.4163],\n                [ 1.4299,  1.3626,  0.5108,  1.4141],\n                [-0.4115,  2.2183,  0.4497,  1.6520]], dtype=oneflow.float32)\n    \n    See :class:`~oneflow.nn.Dropout` for details.   
\n \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.dropout1d,\n    r\"\"\"\n    dropout1d(x: Tensor, p: float = 0.5, training: bool = True) -> Tensor \n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.functional.dropout1d.html.\n\n    Randomly zero out entire channels (a channel is a 1D feature map,\n    e.g., the :math:`j`-th channel of the :math:`i`-th sample in the\n    batched input is a 1D tensor :math:`\\text{input}[i, j]`) of the input tensor).\n    Each channel will be zeroed out independently on every forward call with\n    probability :attr:`p` using samples from a Bernoulli distribution.\n\n    See :class:`~oneflow.nn.Dropout1d` for details.\n\n    Args:\n        p: probability of a channel to be zeroed. Default: 0.5\n        training: apply dropout if is ``True``. Default: ``True``\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.dropout2d,\n    r\"\"\"\n    dropout1d(x: Tensor, p: float = 0.5, training: bool = True) -> Tensor \n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.functional.dropout2d.html.\n\n    Randomly zero out entire channels (a channel is a 2D feature map,\n    e.g., the :math:`j`-th channel of the :math:`i`-th sample in the\n    batched input is a 2D tensor :math:`\\text{input}[i, j]`) of the input tensor).\n    Each channel will be zeroed out independently on every forward call with\n    probability :attr:`p` using samples from a Bernoulli distribution.\n\n    See :class:`~oneflow.nn.Dropout2d` for details.\n\n    Args:\n        p: probability of a channel to be zeroed. Default: 0.5\n        training: apply dropout if is ``True``. Default: ``True``\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.dropout3d,\n    r\"\"\"\n    dropout1d(x: Tensor, p: float = 0.5, training: bool = True) -> Tensor \n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.functional.dropout3d.html.\n\n    Randomly zero out entire channels (a channel is a 3D feature map,\n    e.g., the :math:`j`-th channel of the :math:`i`-th sample in the\n    batched input is a 3D tensor :math:`\\text{input}[i, j]`) of the input tensor).\n    Each channel will be zeroed out independently on every forward call with\n    probability :attr:`p` using samples from a Bernoulli distribution.\n\n    See :class:`~oneflow.nn.Dropout3d` for details.\n\n    Args:\n        p: probability of a channel to be zeroed. Default: 0.5\n        training: apply dropout if is ``True``. Default: ``True``\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.nn.Dropout,\n    \"\"\"\n    During training, randomly zeroes some of the elements of the input\n    tensor with probability :attr:`p` using samples from a Bernoulli\n    distribution. Each channel will be zeroed out independently on every forward\n    call.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Dropout.html.\n\n    This has proven to be an effective technique for regularization and\n    preventing the co-adaptation of neurons as described in the paper\n    \"Improving neural networks by preventing co-adaptation of feature\n    detectors\".\n\n    Furthermore, the outputs are scaled by a factor of :math:`\\\\frac{1}{1-p}` during\n    training. This means that during evaluation the module simply computes an\n    identity function.\n\n    Additionally, we can pass an extra Tensor `addend` which shape is consistent with input Tensor. 
\n    The `addend` Tensor will be added to the result after dropout, which is very useful in a model's residual connection structure.\n\n    Args:\n        p: probability of an element to be zeroed. Default: 0.5\n        inplace: If set to ``True``, will do this operation in-place. Default: ``False``\n        generator: A pseudorandom number generator for sampling\n\n    Shape:\n        - Input: :math:`(*)`. Input can be of any shape\n        - Output: :math:`(*)`. Output is of the same shape as input\n\n    For example:\n\n    Example 1:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> m = flow.nn.Dropout(p=0)\n        >>> arr = np.array(\n        ...    [\n        ...        [-0.7797, 0.2264, 0.2458, 0.4163],\n        ...        [0.4299, 0.3626, -0.4892, 0.4141],\n        ...        [-1.4115, 1.2183, -0.5503, 0.6520],\n        ...    ]\n        ... )\n        >>> x = flow.Tensor(arr)\n        >>> y = m(x)\n        >>> y #doctest: +ELLIPSIS\n        tensor([[-0.7797,  0.2264,  0.2458,  0.4163],\n                [ 0.4299,  0.3626, -0.4892,  0.4141],\n                [-1.4115,  1.2183, -0.5503,  0.6520]], dtype=oneflow.float32)\n    \n    Example 2:\n    \n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> m = flow.nn.Dropout(p=0)\n        >>> arr = np.array(\n        ...    [\n        ...        [-0.7797, 0.2264, 0.2458, 0.4163],\n        ...        [0.4299, 0.3626, -0.4892, 0.4141],\n        ...        [-1.4115, 1.2183, -0.5503, 0.6520],\n        ...    ]\n        ... )\n        >>> x = flow.Tensor(arr)\n        >>> addend = flow.ones((3, 4), dtype=flow.float32)\n        >>> y = m(x, addend=addend)\n        >>> y #doctest: +ELLIPSIS\n        tensor([[ 0.2203,  1.2264,  1.2458,  1.4163],\n                [ 1.4299,  1.3626,  0.5108,  1.4141],\n                [-0.4115,  2.2183,  0.4497,  1.6520]], dtype=oneflow.float32)\n    \n    .. _Improving neural networks by preventing co-adaptation of feature\n        detectors: https://arxiv.org/abs/1207.0580\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.nn.Dropout1d,\n    r\"\"\"\n    Randomly zero out entire channels (a channel is a 1D feature map,\n    e.g., the :math:`j`-th channel of the :math:`i`-th sample in the\n    batched input is a 1D tensor :math:`\\text{input}[i, j]`).\n    Each channel will be zeroed out independently on every forward call with\n    probability :attr:`p` using samples from a Bernoulli distribution.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Dropout1d.html.\n\n    Usually the input comes from :class:`nn.Conv1d` modules.\n\n    As described in the paper\n    `Efficient Object Localization Using Convolutional Networks`_ ,\n    if adjacent pixels within feature maps are strongly correlated\n    (as is normally the case in early convolution layers) then i.i.d. 
dropout\n    will not regularize the activations and will otherwise just result\n    in an effective learning rate decrease.\n\n    In this case, :func:`oneflow.nn.Dropout1d` will help promote independence between\n    feature maps and should be used instead.\n\n    Args:\n        p (float, optional): probability of an element to be zeroed.\n        inplace (bool, optional): If set to ``True``, will do this operation\n            in-place\n\n    Shape:\n        - Input: :math:`(N, C, L)` or :math:`(C, L)`.\n        - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> m = flow.nn.Dropout1d(p=0)\n        >>> arr = np.array(\n        ...    [\n        ...        [-0.7797, 0.2264, 0.2458, 0.4163],\n        ...        [0.4299, 0.3626, -0.4892, 0.4141],\n        ...        [-1.4115, 1.2183, -0.5503, 0.6520],\n        ...    ]\n        ... )\n        >>> x = flow.Tensor(arr)\n        >>> y = m(x)\n        >>> y #doctest: +ELLIPSIS\n        tensor([[-0.7797,  0.2264,  0.2458,  0.4163],\n                [ 0.4299,  0.3626, -0.4892,  0.4141],\n                [-1.4115,  1.2183, -0.5503,  0.6520]], dtype=oneflow.float32)\n\n    .. _Efficient Object Localization Using Convolutional Networks:\n       https://arxiv.org/abs/1411.4280\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.nn.Dropout2d,\n    r\"\"\"\n    Randomly zero out entire channels (a channel is a 2D feature map,\n    e.g., the :math:`j`-th channel of the :math:`i`-th sample in the\n    batched input is a 2D tensor :math:`\\text{input}[i, j]`).\n    Each channel will be zeroed out independently on every forward call with\n    probability :attr:`p` using samples from a Bernoulli distribution.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Dropout2d.html.\n\n    Usually the input comes from :class:`nn.Conv2d` modules.\n\n    As described in the paper\n    `Efficient Object Localization Using Convolutional Networks`_ ,\n    if adjacent pixels within feature maps are strongly correlated\n    (as is normally the case in early convolution layers) then i.i.d. dropout\n    will not regularize the activations and will otherwise just result\n    in an effective learning rate decrease.\n\n    In this case, :func:`oneflow.nn.Dropout2d` will help promote independence between\n    feature maps and should be used instead.\n\n    Args:\n        p (float, optional): probability of an element to be zeroed.\n        inplace (bool, optional): If set to ``True``, will do this operation\n            in-place\n\n    Shape:\n        - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.\n        - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input).\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> m = flow.nn.Dropout2d(p=0)\n        >>> arr = np.array(\n        ...    [\n        ...        [-0.7797, 0.2264, 0.2458, 0.4163],\n        ...        [0.4299, 0.3626, -0.4892, 0.4141],\n        ...        [-1.4115, 1.2183, -0.5503, 0.6520],\n        ...    ]\n        ... )\n        >>> x = flow.Tensor(arr)\n        >>> y = m(x)\n        >>> y #doctest: +ELLIPSIS\n        tensor([[-0.7797,  0.2264,  0.2458,  0.4163],\n                [ 0.4299,  0.3626, -0.4892,  0.4141],\n                [-1.4115,  1.2183, -0.5503,  0.6520]], dtype=oneflow.float32)\n\n    .. 
_Efficient Object Localization Using Convolutional Networks:\n       https://arxiv.org/abs/1411.4280\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.nn.Dropout3d,\n    r\"\"\"\n    Randomly zero out entire channels (a channel is a 3D feature map,\n    e.g., the :math:`j`-th channel of the :math:`i`-th sample in the\n    batched input is a 3D tensor :math:`\\text{input}[i, j]`).\n    Each channel will be zeroed out independently on every forward call with\n    probability :attr:`p` using samples from a Bernoulli distribution.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Dropout3d.html.\n\n    Usually the input comes from :class:`nn.Conv3d` modules.\n\n    As described in the paper\n    `Efficient Object Localization Using Convolutional Networks`_ ,\n    if adjacent pixels within feature maps are strongly correlated\n    (as is normally the case in early convolution layers) then i.i.d. dropout\n    will not regularize the activations and will otherwise just result\n    in an effective learning rate decrease.\n\n    In this case, :func:`oneflow.nn.Dropout3d` will help promote independence between\n    feature maps and should be used instead.\n\n    Args:\n        p (float, optional): probability of an element to be zeroed.\n        inplace (bool, optional): If set to ``True``, will do this operation\n            in-place\n\n    Shape:\n        - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.\n        - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> m = flow.nn.Dropout3d(p=0)\n        >>> arr = np.array(\n        ...    [\n        ...        [-0.7797, 0.2264, 0.2458, 0.4163],\n        ...        [0.4299, 0.3626, -0.4892, 0.4141],\n        ...        [-1.4115, 1.2183, -0.5503, 0.6520],\n        ...    ]\n        ... )\n        >>> x = flow.Tensor(arr)\n        >>> y = m(x)\n        >>> y #doctest: +ELLIPSIS\n        tensor([[-0.7797,  0.2264,  0.2458,  0.4163],\n                [ 0.4299,  0.3626, -0.4892,  0.4141],\n                [-1.4115,  1.2183, -0.5503,  0.6520]], dtype=oneflow.float32)\n\n    .. _Efficient Object Localization Using Convolutional Networks:\n       https://arxiv.org/abs/1411.4280\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/einsum.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.einsum,\n    \"\"\"\n    einsum(equation, *operands) -> oneflow.Tensor\n\n    Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation\n    based on the Einstein summation convention.\n\n    Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them\n    in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of\n    this format are described below, but the general idea is to label every dimension of the input :attr:`operands`\n    with some subscript and define which subscripts are part of the output. The output is then computed by summing\n    the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the\n    output. For example, matrix multiplication can be computed using einsum as `flow.einsum(\"ij,jk->ik\", A, B)`.\n    Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why).\n\n    Equation:\n\n        The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of\n        the input :attr:`operands` in the same order as the dimensions, separating subcripts for each operand by a\n        comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same subscript\n        must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is\n        repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand\n        must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that\n        appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order.\n        The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based\n        on the subscripts, and then summing out the dimensions whose subscripts are not part of the output.\n\n        Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation\n        followed by the subscripts for the output. For instance, the following equation computes the transpose of a\n        matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and\n        at most once for the output.\n\n        Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis.\n        Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts,\n        e.g. 
for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` cover the third and fourth\n        dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the\n        'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not\n        explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions),\n        before the subscript labels that appear exactly once for the input operands. e.g. the following equation implements\n        batch matrix multiplication `'...ij,...jk'`.\n\n        A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis,\n        arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands.\n\n    .. note::\n\n        ``flow.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions\n        covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output.\n\n    .. note::\n\n        This function does not optimize the given expression, so a different formula for the same computation may\n        run faster or consume less memory. Projects like opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/)\n        can optimize the formula for you.\n\n    Args:\n        equation (String): The subscripts for the Einstein summation.\n        *operands (oneflow.Tensor): The tensors to compute the Einstein summation of.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        # trace\n        >>> flow.einsum('ii', flow.arange(4*4).reshape(4,4).to(flow.float32))\n        tensor(30., dtype=oneflow.float32)\n\n        # diagonal\n        >>> flow.einsum('ii->i', flow.arange(4*4).reshape(4,4).to(flow.float32))\n        tensor([ 0.,  5., 10., 15.], dtype=oneflow.float32)\n\n        # outer product\n        >>> x = flow.arange(5).to(flow.float32)\n        >>> y = flow.arange(4).to(flow.float32)\n        >>> flow.einsum('i,j->ij', x, y)\n        tensor([[ 0.,  0.,  0.,  0.],\n                [ 0.,  1.,  2.,  3.],\n                [ 0.,  2.,  4.,  6.],\n                [ 0.,  3.,  6.,  9.],\n                [ 0.,  4.,  8., 12.]], dtype=oneflow.float32)\n        \n        # batch matrix multiplication\n        >>> As = flow.arange(3*2*5).reshape(3,2,5).to(flow.float32)\n        >>> Bs = flow.arange(3*5*4).reshape(3,5,4).to(flow.float32)\n        >>> flow.einsum('bij,bjk->bik', As, Bs).shape\n        oneflow.Size([3, 2, 4])\n\n        # batch permute\n        >>> A = flow.randn(2, 3, 4, 5)\n        >>> flow.einsum('...ij->...ji', A).shape\n        oneflow.Size([2, 3, 5, 4])\n\n        # bilinear\n        >>> A = flow.randn(3,5,4)\n        >>> l = flow.randn(2,5)\n        >>> r = flow.randn(2,4)\n        >>> flow.einsum('bn,anm,bm->ba', l, A, r).shape\n        oneflow.Size([2, 3])\n\n    \"\"\",\n)\n"
  },
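A quick sanity check, not from the repo, of the matrix-multiplication reading of the equation described above; it assumes `flow.matmul` and `flow.allclose` behave like their PyTorch counterparts:

.. code-block:: python

    import oneflow as flow

    A = flow.randn(2, 3)
    B = flow.randn(3, 4)
    # "ij,jk->ik": j appears in the inputs but not in the output, so it is summed out
    out = flow.einsum("ij,jk->ik", A, B)
    assert flow.allclose(out, flow.matmul(A, B), atol=1e-5)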
  {
    "path": "python/oneflow/framework/docstr/erfinv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.erfinv,\n    \"\"\"Computes the inverse error function of :attr:`input`. The inverse error function is defined in the range :math:`(-1, 1)` as:\n\n    .. math::\n        \\mathrm{erfinv}(\\mathrm{erf}(x)) = x\n\n    Args:\n        input (oneflow.Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n    \n        >>> import oneflow as flow\n        >>> import numpy as np\n               \n        >>> input=flow.tensor(np.random.randn(3,3).astype(np.float32))\n        >>> of_out=flow.erfinv(input)\n        >>> of_out.shape\n        oneflow.Size([3, 3])\n\n\n    \"\"\",\n)\n"
  },
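A round-trip sketch, not from the repo, of the defining identity in the docstring above; it assumes `flow.erf` and `flow.allclose` behave as in PyTorch:

.. code-block:: python

    import oneflow as flow

    x = flow.tensor([-0.5, 0.0, 0.5])
    # erfinv undoes erf on (-1, 1)
    assert flow.allclose(flow.erfinv(flow.erf(x)), x, atol=1e-4)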
  {
    "path": "python/oneflow/framework/docstr/expand.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.expand,\n    \"\"\"\n    oneflow.expand(input, *sizes) -> Tensor,\n\n    This operator expand the input tensor to a larger size.\n\n    Passing -1 as the size for a dimension means not changing the size of that dimension.\n\n    Tensor can be also expanded to a larger number of dimensions and the new ones will be appended at the front.\n\n    For the new dimensions, the size cannot be set to -1.\n\n    Args:\n        input (oneflow.Tensor): the input Tensor.\n        *sizes  (oneflow.Size or int): The desired expanded size.\n\n    Returns:\n        oneflow.Tensor: The result Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x = np.array([[[[0, 1]],\n        ...               [[2, 3]],\n        ...               [[4, 5]]]]).astype(np.int32)\n        >>> input = flow.Tensor(x)\n        >>> input.shape\n        oneflow.Size([1, 3, 1, 2])\n        >>> out = input.expand(1, 3, 2, 2)\n        >>> out.shape\n        oneflow.Size([1, 3, 2, 2])\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/flatten.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.flatten,\n    \"\"\"Flattens a contiguous range of dims into a tensor.\n\n    Args:\n        start_dim: first dim to flatten (default = 0).\n        end_dim: last dim to flatten (default = -1).\n    \n    For example: \n\n    .. code-block:: python \n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> input = flow.randn(32, 1, 5, 5)\n        >>> output = flow.flatten(input, start_dim=1)\n        >>> output.shape\n        oneflow.Size([32, 25])\n\n    \"\"\",\n)\n"
  },
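A minimal sketch, not from the repo, showing that `flatten(input, start_dim=1)` is the usual "keep the batch dimension, merge the rest" reshape:

.. code-block:: python

    import oneflow as flow

    x = flow.randn(32, 1, 5, 5)
    out = flow.flatten(x, start_dim=1)   # dims 1..3 are merged
    ref = x.reshape(x.shape[0], -1)
    assert tuple(out.shape) == tuple(ref.shape) == (32, 25)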
  {
    "path": "python/oneflow/framework/docstr/flip.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.flip,\n    \"\"\"\n    flip(input, dims) -> Tensor\n\n    Reverse the order of a n-D tensor along given axis in dims.\n\n    .. note::\n        `flow.flip` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flip`,\n        which returns a view in constant time. Since copying a tensor's data is more work than viewing that data,\n        `flow.flip` is expected to be slower than `np.flip`.\n\n    Args:\n        input (Tensor): the input tensor\n        dims (a list or tuple): axis to flip on\n        \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        \n        >>> np_arr = np.arange(0, 8).reshape((2, 2, 2)).astype(np.float32)\n        >>> input = flow.Tensor(np_arr)\n        >>> input.shape\n        oneflow.Size([2, 2, 2])\n        >>> out = flow.flip(input, [0, 1])\n        >>> out\n        tensor([[[6., 7.],\n                 [4., 5.]],\n        <BLANKLINE>\n                [[2., 3.],\n                 [0., 1.]]], dtype=oneflow.float32)\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/hann_window.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.hann_window,\n    r\"\"\"\n    hann_window(window_length, periodic=True, *, device=None,  placement=None, sbp=None, dtype=None, requires_grad=False) -> Tensor\n\n    This function is equivalent to PyTorch’s hann_window function. \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.hann_window.html.\n\n    Hann window function.\n\n    .. math::\n        w[n] = \\frac{1}{2}\\ \\left[1 - \\cos \\left( \\frac{2 \\pi n}{N - 1} \\right)\\right] =\n                \\sin^2 \\left( \\frac{\\pi n}{N - 1} \\right),\n\n    where :math:`N` is the full window size.\n\n    The input :attr:`window_length` is a positive integer controlling the\n    returned window size. :attr:`periodic` flag determines whether the returned\n    window trims off the last duplicate value from the symmetric window. Therefore, if :attr:`periodic` is true, the :math:`N` in\n    above formula is in fact :math:`\\text{window_length} + 1`. Also, we always have\n    ``oneflow.hann_window(L, periodic=True)`` equal to\n    ``oneflow.hann_window(L + 1, periodic=False)[:-1])``.\n\n    .. note::\n        If :attr:`window_length` :math:`=1`, the returned window contains a single value 1.\n\n    Arguments:\n        window_length (int): the size of returned window\n        periodic (bool, optional): If True, returns a window to be used as periodic\n            function. If False, return a symmetric window.\n\n    Keyword args:\n        dtype (oneflow.dtype, optional): the data type to perform the computation in.\n            Default: if None, uses the global default dtype (see oneflow.get_default_dtype())\n            when both :attr:`start` and :attr:`end` are real,\n            and corresponding complex dtype when either is complex.\n        device (oneflow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor type\n        placement (oneflow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`.\n        sbp (oneflow.sbp.sbp or tuple of oneflow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    Returns:\n        Tensor: A 1-D tensor of size :math:`(\\text{{window_length}},)` containing the window\n\n    \"\"\",\n)\n"
  },
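The docstring above states the periodic/symmetric identity but ships no example; a minimal check, not from the repo, assuming `flow.allclose` behaves as in PyTorch:

.. code-block:: python

    import oneflow as flow

    L = 8
    w_periodic = flow.hann_window(L, periodic=True)
    w_symmetric = flow.hann_window(L + 1, periodic=False)
    # the periodic window is the symmetric one with the duplicate endpoint dropped
    assert flow.allclose(w_periodic, w_symmetric[:-1], atol=1e-6)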
  {
    "path": "python/oneflow/framework/docstr/in_top_k.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.in_top_k,\n    \"\"\"\n    in_top_k(targets, predictions, k) -> Tensor\n\n    Says whether the targets are in the top K predictions.\n\n    Args:\n        targets (Tensor): the target tensor of type int32 or int64.\n        predictions (Tensor): the predictions tensor of type float32 .\n        k (int): Number of top elements to look at for computing precision.\n\n    Returns:\n        oneflow.Tensor: A Tensor of type bool. Computed Precision at k as a bool Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> targets1 = flow.tensor(np.array([3, 1]), dtype=flow.int32)\n        >>> predictions1 = flow.tensor(np.array([[0.0, 1.0, 2.0, 3.0], [3.0, 2.0, 1.0, 0.0],]), dtype=flow.float32)\n        >>> out1 = flow.in_top_k(targets1, predictions1, k=1)\n        >>> out1\n        tensor([ True, False], dtype=oneflow.bool)\n        >>> out2 = flow.in_top_k(targets1, predictions1, k=2)\n        >>> out2\n        tensor([True, True], dtype=oneflow.bool)\n        >>> targets2 = flow.tensor(np.array([3, 1]), dtype=flow.int32, device=flow.device('cuda'))\n        >>> predictions2 = flow.tensor(np.array([[0.0, 1.0, 2.0, 3.0], [3.0, 2.0, 1.0, 0.0],]), dtype=flow.float32, device=flow.device('cuda'))\n        >>> out3 = flow.in_top_k(targets2, predictions2, k=1)\n        >>> out3\n        tensor([ True, False], device='cuda:0', dtype=oneflow.bool)\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/index_add.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.Tensor.index_add_,\n    r\"\"\"\n    index_add_(dim, index, source, *, alpha=1) -> Tensor\n\n    The interface is consistent with PyTorch.    \n\n    Accumulate the elements of :attr:`alpha` times ``source`` into the :attr:`self`\n    tensor by adding to the indices in the order given in :attr:`index`. For example,\n    if ``dim == 0``, ``index[i] == j``, and ``alpha=-1``, then the ``i``\\ th row of\n    ``source`` is subtracted from the ``j``\\ th row of :attr:`self`.\n\n    The :attr:`dim`\\ th dimension of ``source`` must have the same size as the\n    length of :attr:`index` (which must be a vector), and all other dimensions must\n    match :attr:`self`, or an error will be raised.\n\n    For a 3-D tensor the output is given as::\n\n        self[index[i], :, :] += alpha * src[i, :, :]  # if dim == 0\n        self[:, index[i], :] += alpha * src[:, i, :]  # if dim == 1\n        self[:, :, index[i]] += alpha * src[:, :, i]  # if dim == 2\n\n    Args:\n        dim (int): dimension along which to index\n        index (Tensor): indices of ``source`` to select from,\n                should have dtype either `oneflow.int64` or `oneflow.int32`\n        source (Tensor): the tensor containing values to add\n\n    Keyword args:\n        alpha (Number): the scalar multiplier for ``source``\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.ones(5, 3)\n        >>> t = flow.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=flow.float)\n        >>> index = flow.tensor([0, 4, 2])\n        >>> x.index_add_(0, index, t)\n        tensor([[ 2.,  3.,  4.],\n                [ 1.,  1.,  1.],\n                [ 8.,  9., 10.],\n                [ 1.,  1.,  1.],\n                [ 5.,  6.,  7.]], dtype=oneflow.float32)\n        >>> x.index_add_(0, index, t, alpha=-1)\n        tensor([[1., 1., 1.],\n                [1., 1., 1.],\n                [1., 1., 1.],\n                [1., 1., 1.],\n                [1., 1., 1.]], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.index_add,\n    r\"\"\"\n    index_add(input, dim, index, source, *, alpha=1, out=None) -> Tensor\n\n    See :meth:`oneflow.Tensor.index_add_` for function description.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.index_add_,\n    r\"\"\"\n    index_add_(dim, index, source, *, alpha=1) -> Tensor\n\n    Out-of-place version of :meth:`oneflow.Tensor.index_add_`.\n    \"\"\",\n)\n"
  },
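A loop-based restatement, not from the repo, of the `dim == 0` accumulation rule quoted above, using the same tensors as the docstring example:

.. code-block:: python

    import oneflow as flow

    x = flow.ones(5, 3)
    t = flow.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
    index = flow.tensor([0, 4, 2])

    # reference: self[index[i], :] += alpha * source[i, :] with alpha == 1
    ref = x.clone()
    for i in range(index.shape[0]):
        ref[index[i].item()] += t[i]

    x.index_add_(0, index, t)
    assert flow.allclose(x, ref)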
  {
    "path": "python/oneflow/framework/docstr/index_select.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.index_select,\n    \"\"\"\n    input.index_select(dim, index) -> Tensor\n\n    Select values along an axis specified by `dim`.\n\n    :attr:`index` must be an Int32 Tensor with 1-D.\n    :attr:`dim` must be in the range of input Dimensions.\n    value of :attr:`index` must be in the range of the dim-th of input.\n    Note that ``input`` and ``index`` do not broadcast against each other.  \n    \n    The interface is consistent with PyTorch.    \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.index_select.html.\n\n    Args:\n        input (Tensor): the source tensor\n        dim (int): the axis along which to index\n        index (Tensor): the 1-D tensor containing the indices to index\n    \n    For example:\n\n    .. code-block:: python\n    \n        >>> import oneflow as flow\n        >>> input = flow.tensor([[1,2,3],[4,5,6]], dtype=flow.int32)\n        >>> input \n        tensor([[1, 2, 3],\n                [4, 5, 6]], dtype=oneflow.int32)\n        >>> index = flow.tensor([0,1], dtype=flow.int64)\n        >>> output = flow.index_select(input, 1, index)\n        >>> output\n        tensor([[1, 2],\n                [4, 5]], dtype=oneflow.int32)\n        >>> output = input.index_select(1, index)\n        >>> output\n        tensor([[1, 2],\n                [4, 5]], dtype=oneflow.int32)\n    \n    ..\n        Feature Stage of Operator [index_select].\n        - Maintainer List [@QiangX-man, @hjchen2, @strint]\n        - Current Stage [ ]\n        - Alpha Stage Check List [ ]\n          - API(Compatible with PyTorch 1.11, anything incompatible must be noted in API Doc.)[Yes]\n          - Doc(API Doc must be provided and showed normally on the web page.)[Yes]\n          - Functionality and its' Test [ ]\n            - Functionality is highly compatiable with PyTorch 1.11. 
[Yes]\n            - eager local [Yes] [@QiangX-man, @hjchen2]\n              - forward [Yes]\n              - backward [Yes]\n              - gpu [Yes]\n              - cpu [Yes]\n            - graph local [ ] [@BBuf, @strint, @hjchen2]\n              - forward [Yes]\n              - backward [ ]\n              - gpu [Yes]\n              - cpu [Yes]\n          - Exception Handling\n            - Exception Message and Hint must be provided [ ]\n        - Beta Stage Check List [ ]\n          - API(High compatibility with PyTorch 1.11, shouldn't have anything incompatible for a naive reason.)[ ]\n          - Doc(Same standard as Alpha Stage)[ ]\n          - Functionality and its' Test [ ]\n            - eager global [ ]\n              - forward [ ]\n              - backward [ ]\n              - gpu [ ]\n              - cpu [ ]\n            - graph gloal [ ]\n              - forward [ ]\n              - backward [ ]\n              - gpu [ ]\n              - cpu [ ]\n          - Performance and Scalability(Must be evaluated.)[ ]\n            - CUDA kernel [ ]\n            - CPU kernel [ ]\n            - N nodes M devices [ ]\n          - Exception Handling [ ]\n            - Exception Message and Hint must be provided [ ]\n            - Try you best to do Exception Recovery [ ]\n        - Stable Stage Check List [ ]\n          - API(Same standard as Beta Stage)[ ]\n          - Doc(Same standard as Beta Stage)[ ]\n          - Functionality and its' Test [ ]\n            - fp16 and AMP [ ]\n            - NHWC [ ]\n          - Performance and Scalability(Must be evaluated.)[ ]\n          - Exception Handling [ ]\n    \"\"\",\n)\n"
  },
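An equivalence sketch, not from the repo, relating `index_select` to plain indexing along the selected dimension; it compares results through NumPy:

.. code-block:: python

    import numpy as np
    import oneflow as flow

    x = flow.tensor([[1, 2, 3], [4, 5, 6]], dtype=flow.int32)
    index = flow.tensor([0, 1], dtype=flow.int64)
    out = flow.index_select(x, 1, index)
    ref = x[:, index]  # selecting along dim 1 matches indexing that dimension
    assert np.array_equal(out.numpy(), ref.numpy())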
  {
    "path": "python/oneflow/framework/docstr/inv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\n\nadd_docstr(\n    oneflow.linalg.inv,\n    \"\"\"linalg.inv(A) -> Tensor\n\n    Computes the inverse of a square matrix if it exists.\n    Throws a `RuntimeError` if the matrix is not invertible.\n\n    Letting :math:`\\mathbb{K}` be :math:`\\mathbb{R}` or :math:`\\mathbb{C}`,\n    for a matrix :math:`A \\in \\mathbb{K}^{n \\times n}`,\n    its **inverse matrix** :math:`A^{-1} \\in \\mathbb{K}^{n \\times n}` (if it exists) is defined as\n\n    .. math::\n\n        A^{-1}A = AA^{-1} = \\mathrm{I}_n\n\n    where :math:`\\mathrm{I}_n` is the `n`-dimensional identity matrix.\n\n    The inverse matrix exists if and only if :math:`A` is `invertible`_. In this case,\n    the inverse is unique.\n\n    Supports input of float, double, cfloat and cdouble dtypes.\n    Also supports batches of matrices, and if :attr:`A` is a batch of matrices\n    then the output has the same batch dimensions.\n\n    Args:\n        A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions\n                    consisting of invertible matrices.\n\n    Raises:\n        RuntimeError: if the matrix :attr:`A` or any matrix in the batch of matrices :attr:`A` is not invertible.\n\n    Examples:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> A = flow.tensor([[ 1.3408, -0.7788,  1.0551, -0.5866],\n        ...                  [ 0.8480,  0.8350,  0.9781, -0.1297],\n        ...                  [-0.0881, -0.6142, -0.3833,  0.3232],\n        ...                  [ 1.2841,  0.7517, -0.3849,  0.2515]])\n        >>> flow.linalg.inv(A)\n        tensor([[ 0.3105, -0.0811,  0.1288,  0.5169],\n        ...     [-0.3457,  0.1716, -0.7133,  0.1987],\n        ...     [-0.0593,  1.1706,  0.8694, -0.6516],\n        ...     [-0.6427,  1.6923,  2.8049, -0.2541]], dtype=oneflow.float32)\n\n        >>> A = flow.tensor([[[ 0.6144,  0.1027, -0.1353],\n        ...                   [-1.4415, -0.6731,  0.3723],\n        ...                   [ 0.4069, -0.8940,  1.4056]],\n        ...                  [[-1.1891, -0.3897, -1.5015],\n        ...                   [ 0.3028,  1.1040,  0.2600],\n        ...                   [-1.6970,  0.4238,  0.9146]]])\n        >>> flow.linalg.inv(A)\n        tensor([[[ 1.6830,  0.0644,  0.1449],\n        ...      [-5.9755, -2.5206,  0.0925],\n        ...      [-4.2879, -1.6219,  0.7283]],\n        ...\n        ...     [[-0.2370,  0.0737, -0.4100],\n        ...      [ 0.1892,  0.9579,  0.0384],\n        ...      [-0.5274, -0.3070,  0.3148]]], dtype=oneflow.float32)\n\n    .. 
_invertible:\n        https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem\n    \n    ..\n        Feature Stage of Operator [linalg.inv].\n        - Maintainer List [@simonJJJ]\n        - Current Stage [pre Alpha]\n        - Alpha Stage Check List [ ]\n          - API(Compatible with PyTorch 1.11, anything incompatible must be noted in API Doc.)[Yes]\n          - Doc(API Doc must be provided and showed normally on the web page.)[Yes]\n          - Functionality and its' Test [ ]\n            - Functionality is highly compatiable with PyTorch 1.11. [Yes]\n            - eager local [Yes] [@simonJJJ]\n              - forward [Yes]\n              - backward [Yes]\n              - gpu [Yes]\n              - cpu [Yes]\n            - graph local [ ] [@simonJJJ]\n              - forward [Yes]\n              - backward [ ]\n              - gpu [Yes]\n              - cpu [Yes]\n          - Exception Handling\n            - Exception Message and Hint must be provided [Yes]\n        - Beta Stage Check List [ ]\n          - API(High compatibility with PyTorch 1.11, shouldn't have anything incompatible for a naive reason.)[ ]\n          - Doc(Same standard as Alpha Stage)[Yes]\n          - Functionality and its' Test [ ]\n            - eager global [Yes] [@simonJJJ]\n              - forward [Yes]\n              - backward [Yes]\n              - gpu [Yes]\n              - cpu [Yes]\n            - graph gloal [Yes]\n              - forward [Yes]\n              - backward [ ]\n              - gpu [Yes]\n              - cpu [Yes]\n          - Performance and Scalability(Must be evaluated.)[ ]\n            - CUDA kernel [ ]\n            - CPU kernel [ ]\n            - N nodes M devices [ ]\n          - Exception Handling [Yes]\n            - Exception Message and Hint must be provided [Yes]\n            - Try you best to do Exception Recovery [Yes]\n        - Stable Stage Check List [ ]\n          - API(Same standard as Beta Stage)[ ]\n          - Doc(Same standard as Beta Stage)[ ]\n          - Functionality and its' Test [ ]\n            - fp16 and AMP [ ]\n            - NHWC [ ]\n          - Performance and Scalability(Must be evaluated.)[ ]\n          - Exception Handling [ ]\n    \"\"\",\n)\n"
  },
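A defining-property check, not from the repo: multiplying by the computed inverse should give the identity. The `4 * flow.eye(4)` shift merely makes the random matrix comfortably invertible:

.. code-block:: python

    import oneflow as flow

    A = flow.randn(4, 4) + 4 * flow.eye(4)  # well-conditioned with high probability
    A_inv = flow.linalg.inv(A)
    assert flow.allclose(flow.matmul(A, A_inv), flow.eye(4), atol=1e-4)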
  {
    "path": "python/oneflow/framework/docstr/is_floating_point.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.is_floating_point,\n    r\"\"\"Returns True if the data type of input is a floating point data type i.e., one of `oneflow.float64` , `oneflow.float32` , `oneflow.float16`, and `oneflow.bfloat16`.\n\n    Args:\n        input  (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        \n        >>> input = flow.tensor([1, 2, 3, 4, 5], dtype=flow.int)\n        >>> output = flow.is_floating_point(input)\n        >>> output\n        False\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/lerp.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.lerp,\n    \"\"\"\n    lerp(start, end, weight) -> Tensor\n\n    The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.lerp.html.\n\n    Does a linear interpolation of two tensors `start` and `end` based on a scalar or tensor `weight` and returns the result.\n\n    The shapes of start` and `end` must be broadcastable. If `weight` is a tensor, then the shapes of `weight`, `start`, and `end` must be broadcastable.\n\n    .. math::\n        out_{i} = start_{i} + weight_{i} * (end_{i} - start_{i})\n\n    Args:\n        start (oneflow.Tensor): the tensor with the starting points.\n        end (oneflow.Tensor): the tensor with the ending points.\n        weight (float or oneflow.Tensor): the weight for the interpolation formula.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> start = flow.arange(1., 5.)\n        >>> end = flow.empty(4).fill_(10)\n        >>> flow.lerp(start, end, 0.5)\n        tensor([5.5000, 6.0000, 6.5000, 7.0000], dtype=oneflow.float32)\n        >>> flow.lerp(start, end, flow.full_like(start, 0.5))\n        tensor([5.5000, 6.0000, 6.5000, 7.0000], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.lerp_,\n    \"\"\"\n    In-place version of :func:`oneflow.lerp`\n    \"\"\",\n)\n"
  },
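A direct check, not from the repo, of the interpolation formula above, reusing the docstring's `start`/`end` tensors:

.. code-block:: python

    import oneflow as flow

    start = flow.arange(1., 5.)
    end = flow.empty(4).fill_(10)
    w = 0.3
    # out = start + w * (end - start)
    assert flow.allclose(flow.lerp(start, end, w), start + w * (end - start), atol=1e-6)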
  {
    "path": "python/oneflow/framework/docstr/linalg.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.linalg.cross,\n    \"\"\"linalg.cross(input, other, dim=-1) -> Tensor\n\n    Computes the cross product of two 3-dimensional vectors.\n\n    Supports input of float and double dtypes. \n    Also supports batches of vectors, for which it computes the product along the dimension dim. \n    In this case, the output has the same batch dimensions as the inputs broadcast to a common shape.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.11/generated/torch.linalg.cross.html\n\n    Args:\n        input (Tensor): the first input tensor.\n        other (Tensor): the second input tensor.\n        dim (int, optional): the dimension along which to take the cross-product. Default: `-1`\n\n    Raises:\n        RuntimeError:  If after broadcasting ``input.size(dim) != 3`` or ``other.size(dim) != 3``.\n    \n    Examples:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> a = flow.tensor([[ -0.3956, 1.1455,  1.6895],\n        ...                  [ -0.5849, 1.3672,  0.3599],\n        ...                  [ -1.1626, 0.7180, -0.0521],\n        ...                  [ -0.1339, 0.9902, -2.0225]])\n        >>> b = flow.tensor([[ -0.0257, -1.4725, -1.2251],\n        ...                  [ -1.1479, -0.7005, -1.9757],\n        ...                  [ -1.3904,  0.3726, -1.1836],\n        ...                  [ -0.9688, -0.7153,  0.2159]])\n        >>> flow.linalg.cross(a, b)\n        tensor([[ 1.0844, -0.5281,  0.6120],\n                [-2.4491, -1.5687,  1.9791],\n                [-0.8304, -1.3036,  0.5651],\n                [-1.2329,  1.9883,  1.0551]], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.cross,\n    \"\"\"cross(input, other, dim=None) -> Tensor\n\n    Returns the cross product of vectors in dimension `dim` of `input` and `other`.\n\n    Supports input of float and double dtypes. \n    Also supports batches of vectors, for which it computes the product along the dimension `dim`. \n    In this case, the output has the same batch dimensions as the inputs.\n\n    If `dim` is not given, it defaults to the first dimension found with the size 3. Note that this might be unexpected.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.11/generated/torch.cross.html\n\n    .. warning::\n        This function may change in a future PyTorch release to match the default behaviour in :func:`oneflow.linalg.cross`. We recommend using :func:`oneflow.linalg.cross`.\n\n    Args:\n        input (Tensor): the first input tensor.\n        other (Tensor): the second input tensor.\n        dim (int, optional): the dimension to take the cross-product in. Default: `None`\n    \n    Examples:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> a = flow.tensor([[ -0.3956, 1.1455,  1.6895],\n        ...                  [ -0.5849, 1.3672,  0.3599],\n        ...                  [ -1.1626, 0.7180, -0.0521],\n        ...                  [ -0.1339, 0.9902, -2.0225]])\n        >>> b = flow.tensor([[ -0.0257, -1.4725, -1.2251],\n        ...                  [ -1.1479, -0.7005, -1.9757],\n        ...                  [ -1.3904,  0.3726, -1.1836],\n        ...                  [ -0.9688, -0.7153,  0.2159]])\n        >>> flow.cross(a, b)\n        tensor([[ 1.0844, -0.5281,  0.6120],\n                [-2.4491, -1.5687,  1.9791],\n                [-0.8304, -1.3036,  0.5651],\n                [-1.2329,  1.9883,  1.0551]], dtype=oneflow.float32)\n    \"\"\",\n)\n"
  },
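A tiny sanity check, not from the repo: the cross product of the x and y unit vectors is the z unit vector:

.. code-block:: python

    import oneflow as flow

    a = flow.tensor([1.0, 0.0, 0.0])
    b = flow.tensor([0.0, 1.0, 0.0])
    assert flow.allclose(flow.linalg.cross(a, b), flow.tensor([0.0, 0.0, 1.0]))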
  {
    "path": "python/oneflow/framework/docstr/logaddexp.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.logaddexp,\n    \"\"\"\n    logaddexp(input, other, *, out=None) -> Tensor\n\n    The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.logaddexp.html.\n\n    Logarithm of the sum of exponentiations of the inputs.\n\n    Calculates pointwise :math:`\\log\\left(e^x + e^y\\right)`. This function is useful\n    in statistics where the calculated probabilities of events may be so small as to\n    exceed the range of normal floating point numbers. In such cases the logarithm\n    of the calculated probability is stored. This function allows adding\n    probabilities stored in such a fashion.\n\n    Args:\n        input (oneflow.Tensor): the input Tensor.\n        other (oneflow.Tensor): the second input Tensor.\n        out (oneflow.Tensor, optional): the output Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> flow.logaddexp(flow.tensor([-1.0]), flow.tensor([-1.0, -2, -3]))\n        tensor([-0.3069, -0.6867, -0.8731], dtype=oneflow.float32)\n        >>> flow.logaddexp(flow.tensor([-100.0, -200, -300]), flow.tensor([-1.0, -2, -3]))\n        tensor([-1., -2., -3.], dtype=oneflow.float32)\n \n    \"\"\",\n)\n"
  },
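A short demonstration, not from the repo, of why the log-space formulation matters: the naive computation underflows to `-inf` while `logaddexp` stays finite:

.. code-block:: python

    import math
    import oneflow as flow

    a = flow.tensor([-1000.0])
    naive = flow.log(flow.exp(a) + flow.exp(a))  # exp(-1000) underflows to 0, log(0) = -inf
    stable = flow.logaddexp(a, a)                # equals -1000 + log(2)
    assert math.isinf(naive.item())
    assert abs(stable.item() - (-1000.0 + math.log(2.0))) < 1e-3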
  {
    "path": "python/oneflow/framework/docstr/logical_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\n\nadd_docstr(\n    oneflow.logical_and,\n    \"\"\"\n    Computes the element-wise logical AND of the given input tensors.\n    Zeros are treated as False and nonzeros are treated as True.\n\n    Args:\n        input (oneflow.Tensor): The input Tensor\n        other (oneflow.Tensor): The Tensor to compute AND with\n\n    Returns:\n        oneflow.Tensor: The output Tensor\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> input1 = flow.tensor(np.array([1, 0, 1]).astype(np.float32), dtype=flow.float32)\n        >>> input2 = flow.tensor(np.array([1, 1, 0]).astype(np.float32), dtype=flow.float32)\n\n        >>> out = flow.logical_and(input1, input2)\n        >>> out\n        tensor([ True, False, False], dtype=oneflow.bool)\n\n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow.logical_or,\n    \"\"\"\n    Computes the element-wise logical OR of the given input tensors. \n    Zeros are treated as False and nonzeros are treated as True.\n\n    Args:\n        input (oneflow.Tensor): The input Tensor\n        other (oneflow.Tensor): The Tensor to compute OR with\n\n    Returns:\n        oneflow.Tensor: The output Tensor\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> input1 = flow.tensor(np.array([1, 0, 1]).astype(np.float32), dtype=flow.float32)\n        >>> input2 = flow.tensor(np.array([1, 0, 0]).astype(np.float32), dtype=flow.float32)\n\n        >>> out = flow.logical_or(input1, input2)\n        >>> out\n        tensor([ True, False,  True], dtype=oneflow.bool)\n\n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow.logical_xor,\n    \"\"\"\n    Computes the element-wise logical XOR of the given input tensors. \n    Zeros are treated as False and nonzeros are treated as True.\n\n    Args:\n        input (oneflow.Tensor): The input Tensor\n        other (oneflow.Tensor): The Tensor to compute XOR with\n\n    Returns:\n        oneflow.Tensor: The output Tensor\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> input1 = flow.tensor(np.array([1, 0, 1]).astype(np.float32), dtype=flow.float32)\n        >>> input2 = flow.tensor(np.array([1, 0, 0]).astype(np.float32), dtype=flow.float32)\n        >>> out = flow.logical_xor(input1, input2)\n        >>> out\n        tensor([False, False,  True], dtype=oneflow.bool)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.logical_not,\n    r\"\"\"\n    Computes the element-wise logical NOT of the given input tensor.\n    Zeros are treated as False and nonzeros are treated as True.\n\n    Args:\n        input (oneflow.Tensor): The input Tensor\n\n    Returns:\n        oneflow.Tensor: The output Tensor\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> input = flow.tensor([1, 0, -1], dtype=flow.float32)\n        >>> out = flow.logical_not(input)\n        >>> out\n        tensor([False,  True, False], dtype=oneflow.bool)\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/loss.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow._C.triplet_margin_loss,\n    r\"\"\"    \n    Creates a criterion that measures the triplet loss given an input\n    tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.\n    This is used for measuring a relative similarity between samples. A triplet\n    is composed by `a`, `p` and `n` (i.e., `anchor`, `positive examples` and `negative\n    examples` respectively). The shapes of all input tensors should be\n    :math:`(N, D)`.\n    \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.triplet_margin_loss.html.\n    \n    The distance swap is described in detail in the paper `Learning shallow\n    convolutional feature descriptors with triplet losses <http://www.bmva.org/bmvc/2016/papers/paper119/index.html>`__ by\n    V. Balntas, E. Riba et al.\n\n    The loss function for each sample in the mini-batch is:\n\n    .. math::\n        L(a, p, n) = \\max \\{d(a_i, p_i) - d(a_i, n_i) + {\\rm margin}, 0\\}\n\n\n    where\n\n    .. math::\n        d(x_i, y_i) = \\left\\lVert {\\bf x}_i - {\\bf y}_i \\right\\rVert_p\n\n    Args:\n        margin (float, optional): Default: :math:`1`.\n        p (float, optional): The norm degree for pairwise distance. Default: :math:`2.0`.\n        swap (bool, optional): The distance swap is described in detail in the paper\n            `Learning shallow convolutional feature descriptors with triplet losses` by\n            V. Balntas, E. Riba et al. Default: ``False``.\n        reduction (string, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n            ``'mean'``: the sum of the output will be divided by the number of\n            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`\n            and :attr:`reduce` are in the process of being deprecated, and in the meantime,\n            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``\n\n    Shape:\n        - Input: :math:`(N, D)` where :math:`D` is the vector dimension.\n        - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar\n          otherwise.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> triplet_loss = flow.nn.TripletMarginLoss(margin=1.0, p=2)\n        >>> anchor = np.array([[1, -1, 1],[-1, 1, -1], [1, 1, 1]])\n        >>> positive = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n        >>> negative = np.array([[2, 2, 2], [2, 2, 2], [2, 2, 2]])\n        >>> output = triplet_loss(flow.Tensor(anchor), flow.Tensor(positive), flow.Tensor(negative))\n        >>> output\n        tensor(6.2971, dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.cross_entropy,\n    r\"\"\"\n    cross_entropy(input, target, weight=None, ignore_index=-100, reduction=\"mean\", label_smoothing=0.0)\n\n    See :class:`~oneflow.nn.CrossEntropyLoss` for details.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.cross_entropy.html.\n\n\n    Args:\n        input (Tensor) : :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)`\n            in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K \\geq 1`\n            in the case of K-dimensional loss. `input` is expected to contain unnormalized scores\n            (often referred to as logits).\n        target (Tensor) : If containing class indices, shape :math:`(N)` where each value is\n            :math:`0 \\leq \\text{targets}[i] \\leq C-1`, or :math:`(N, d_1, d_2, ..., d_K)` with\n            :math:`K \\geq 1` in the case of K-dimensional loss. If containing class probabilities,\n            same shape as the input.\n        weight (Tensor, optional): a manual rescaling weight given to each\n            class. If given, has to be a Tensor of size `C`\n        ignore_index (int, optional): Specifies a target value that is ignored\n            and does not contribute to the input gradient. When :attr:`size_average` is\n            ``True``, the loss is averaged over non-ignored targets. Note that\n            :attr:`ignore_index` is only applicable when the target contains class indices.\n            Default: -100\n        reduction (string, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n            ``'mean'``: the sum of the output will be divided by the number of\n            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`\n            and :attr:`reduce` are in the process of being deprecated, and in the meantime,\n            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``\n        label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount\n            of smoothing when computing the loss, where 0.0 means no smoothing.\n            The targets become a mixture of the original ground truth and a uniform\n            distribution as described in `Rethinking the Inception Architecture for Computer Vision <https://arxiv.org/abs/1512.00567>`_.\n            Default: :math:`0.0`.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n        >>> input = flow.randn(3, 5, requires_grad=True)\n        >>> target = flow.ones(3, dtype=flow.int64)\n        >>> loss = F.cross_entropy(input, target)\n        >>> loss.backward()\n\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.l1_loss,\n    r\"\"\"\n    l1_loss(input, target, reduction=\"mean\") -> Tensor\n\n    This operator computes the L1 loss between each element in input and target.\n\n    see :class:`~oneflow.nn.L1Loss` for details.\n\n    Args:\n        input (Tensor): The input Tensor.\n        target (Tensor): The target Tensor.\n        reduction (string, optional): The reduce type, it can be one of \"none\", \"mean\", \"sum\". Defaults to \"mean\".\n    \n    Examples::\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n        >>> input = flow.randn(3, 4, requires_grad=True)\n        >>> target = flow.rand(3, 4, requires_grad=False)\n        >>> loss = F.l1_loss(input, target)\n        >>> loss.backward()\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.mse_loss,\n    r\"\"\"\n    mse_loss(input, target, reduction=\"mean\") -> Tensor\n\n    This operator computes the mean squared error (squared L2 norm) \n    loss between each element in input and target.\n\n    see :class:`~oneflow.nn.MSELoss` for details.\n\n    Args:\n        input (Tensor): The input Tensor.\n        target (Tensor): The target Tensor.\n        reduction (string, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n            ``'mean'``: the sum of the output will be divided by the number of\n            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``\n\n    Examples::\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n        >>> input = flow.randn(3, 4, requires_grad=True)\n        >>> target = flow.rand(3, 4, requires_grad=False)\n        >>> loss = F.mse_loss(input, target)\n        >>> loss.backward()\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.smooth_l1_loss,\n    \"\"\"\n    smooth_l1_loss(input: Tensor, target: Tensor, size_average: bool=True, reduce: bool=True, reduction: str='mean', beta: float=1.0) -> Tensor\n\n    Function that uses a squared term if the absolute\n    element-wise error falls below beta and an L1 term otherwise.\n\n    See :class:`~oneflow.nn.SmoothL1Loss` for details.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.binary_cross_entropy_loss,\n    r\"\"\"\n    binary_cross_entropy(input, target, weight=None, reduction='mean')\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.binary_cross_entropy.html.\n    \n    Function that measures the Binary Cross Entropy between the target and input probabilities.\n\n    See :class:`~oneflow.nn.BCELoss` for details.\n\n    Args:\n        input: Tensor of arbitrary shape as probabilities.\n        target: Tensor of the same shape as input with values between 0 and 1.\n        weight (Tensor, optional): a manual rescaling weight\n                if provided it's repeated to match input tensor shape\n        reduction (string, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. 
``'none'``: no reduction will be applied,\n            ``'mean'``: the sum of the output will be divided by the number of\n            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`\n            and :attr:`reduce` are in the process of being deprecated, and in the meantime,\n            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``\n\n    Examples::\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n        >>> input = flow.randn(3, 2, requires_grad=True)\n        >>> target = flow.rand(3, 2, requires_grad=False)\n        >>> loss = F.binary_cross_entropy(flow.sigmoid(input), target)\n        >>> loss.backward()\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.binary_cross_entropy_with_logits_loss,\n    r\"\"\"\n    binary_cross_entropy_with_logits(input, target, weight=None, reduction='mean', pos_weight=None)\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.binary_cross_entropy_with_logits.html.\n\n    Function that measures Binary Cross Entropy between target and input logits.\n\n    See :class:`~oneflow.nn.BCEWithLogitsLoss` for details.\n\n    Args:\n        input: Tensor of arbitrary shape as unnormalized scores (often referred to as logits).\n        target: Tensor of the same shape as input with values between 0 and 1\n        weight (Tensor, optional): a manual rescaling weight\n            if provided it's repeated to match input tensor shape\n        reduction (string, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n            ``'mean'``: the sum of the output will be divided by the number of\n            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`\n            and :attr:`reduce` are in the process of being deprecated, and in the meantime,\n            specifying either of those two args will override :attr:`reduction`. 
Default: ``'mean'``\n        pos_weight (Tensor, optional): a weight of positive examples.\n                Must be a vector with length equal to the number of classes.\n\n    Examples::\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n        >>> input = flow.randn(3, requires_grad=True)\n        >>> target = flow.randn(3)\n        >>> target[target >= 0] = 1\n        >>> target[target < 0] = 0\n        >>> loss = F.binary_cross_entropy_with_logits(input, target)\n        >>> loss.backward()\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.kl_div_loss,\n    r\"\"\"\n    kl_div_loss(input, target, reduction=\"mean\", log_target=False)\n\n    `The Kullback-Leibler divergence loss measure <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_\n\n    See :class:`~oneflow.nn.KLDivLoss` for details.\n\n    Args:\n        reduction (string, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.\n            ``'none'``: no reduction will be applied.\n            ``'batchmean'``: the sum of the output will be divided by the batch size.\n            ``'sum'``: the output will be summed.\n            ``'mean'``: the output will be divided by the number of elements in the output.\n            Default: ``'mean'``\n        log_target (bool, optional): Specifies whether `target` is passed in the log space.\n            Default: ``False``\n\n    .. note::\n        :attr:`reduction` = ``'mean'`` doesn't return the true kl divergence value, please use\n        :attr:`reduction` = ``'batchmean'`` which aligns with KL math definition.\n        In the next major release, ``'mean'`` will be changed to be the same as ``'batchmean'``.\n\n    Shape:\n        - Input: :math:`(N, *)` where :math:`*` means any number of additional\n          dimensions\n        - Target: :math:`(N, *)`, same shape as the input\n        - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`,\n          the same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor([-0.9021705, 0.08798598, 1.04686249], dtype=flow.float32)\n        >>> target = flow.tensor([1.22386942, -0.89729659, 0.01615712], dtype=flow.float32)\n        >>> out = flow.nn.functional.kl_div(input, target, reduction=\"none\", log_target=False)\n        >>> out\n        tensor([ 1.3514,  0.0000, -0.0836], dtype=oneflow.float32)\n        >>> out = flow.nn.functional.kl_div(input, target, reduction=\"mean\", log_target=False)\n        >>> out\n        tensor(0.4226, dtype=oneflow.float32)\n        >>> out = flow.nn.functional.kl_div(input, target, reduction=\"sum\", log_target=True)\n        >>> out\n        tensor(5.7801, dtype=oneflow.float32)\n\n    \"\"\",\n)\n"
  },
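A small check, not from the repo, of the `'batchmean'` note in the `kl_div` docstring above: it equals the `'sum'` result divided by the batch size. `F.log_softmax` and `flow.softmax` are assumed to behave as in PyTorch:

.. code-block:: python

    import oneflow as flow
    import oneflow.nn.functional as F

    input = F.log_softmax(flow.randn(4, 5), dim=1)   # log-probabilities
    target = flow.softmax(flow.randn(4, 5), dim=1)   # probabilities
    bm = F.kl_div(input, target, reduction="batchmean")
    s = F.kl_div(input, target, reduction="sum")
    assert flow.allclose(bm, s / input.shape[0], atol=1e-5)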
  {
    "path": "python/oneflow/framework/docstr/masked_fill.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.masked_fill,\n    \"\"\"\n    Fills elements of :attr:`self` tensor with :attr:`value` where :attr:`mask` is True.\n    The shape of :attr:`mask` must be broadcastable with the shape of the underlying tensor.\n\n    Args:\n        mask (BoolTensor): the boolean mask\n        value (float): the value to fill in with\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> in_arr = np.array(\n        ...     [[[-0.13169311,  0.97277078,  1.23305363,  1.56752789],\n        ...     [-1.51954275,  1.87629473, -0.53301206,  0.53006478],\n        ...     [-1.38244183, -2.63448052,  1.30845795, -0.67144869]],\n        ...     [[ 0.41502161,  0.14452418,  0.38968   , -1.76905653],\n        ...     [ 0.34675095, -0.7050969 , -0.7647731 , -0.73233418],\n        ...     [-1.90089858,  0.01262963,  0.74693893,  0.57132389]]]\n        ... )\n        >>> fill_value = 8.7654321 # random value e.g. -1e9 3.1415\n        >>> input = flow.tensor(in_arr, dtype=flow.float32)\n        >>> mask = flow.tensor((in_arr > 0).astype(np.int8), dtype=flow.int)\n        >>> output = flow.masked_fill(input, mask, fill_value)\n\n        # tensor([[[-0.1317,  8.7654,  8.7654,  8.7654],\n        #  [-1.5195,  8.7654, -0.533 ,  8.7654],\n        #  [-1.3824, -2.6345,  8.7654, -0.6714]],\n\n        # [[ 8.7654,  8.7654,  8.7654, -1.7691],\n        #  [ 8.7654, -0.7051, -0.7648, -0.7323],\n        #  [-1.9009,  8.7654,  8.7654,  8.7654]]], dtype=oneflow.float32)\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/math_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.abs,\n    r\"\"\"Return the absolute value of each element in input tensor:math:`y = |x|` element-wise.\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> x = flow.tensor(np.array([-1, 2, -3, 4]).astype(np.float32))\n        >>> flow.abs(x)\n        tensor([1., 2., 3., 4.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.add,\n    r\"\"\"\n    oneflow.add(input, other, *, alpha=1) -> Tensor\n    \n    Adds `other`, scaled by `alpha`, to `input`. Scalar and broadcast promotation are supported.\n\n    .. math::\n        out = input + alpha \\times other\n        \n    Args:\n        input (Union[int, float, oneflow.Tensor]): the input tensor.\n        other (Union[int, float, oneflow.Tensor]): the tensor or number to add to input.\n    \n    Keyword args:\n        alpha (Number, optional): the multiplier for `other`.\n\n    Returns:\n        oneflow.Tensor: the output Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        # element-wise add\n        >>> x = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> y = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> out = flow.add(x, y).numpy()\n        >>> out.shape\n        (2, 3)\n\n        # scalar add\n        >>> x = 5\n        >>> y = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> out = flow.add(x, y).numpy()\n        >>> out.shape\n        (2, 3)\n\n        # broadcast add\n        >>> x = flow.tensor(np.random.randn(1,1), dtype=flow.float32)\n        >>> y = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> out = flow.add(x, y).numpy()\n        >>> out.shape\n        (2, 3)\n        \n        # use alpha\n        >>> x = flow.zeros(2, 3)\n        >>> y = flow.ones(2, 3)\n        >>> out = flow.add(x, y, alpha=10)\n        >>> out\n        tensor([[10., 10., 10.],\n                [10., 10., 10.]], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.floor,\n    \"\"\"\n    Returns a new tensor with the arcsine of the elements of :attr:`input`.\n\n    .. math::\n        \\\\text{out}_{i} = \\\\lfloor \\\\text{input}_{i} \\\\rfloor\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor(np.array([-0.5,  1.5, 0,  0.8]), dtype=flow.float32)\n        >>> output = flow.floor(input)\n        >>> output.shape\n        oneflow.Size([4])\n        >>> output.numpy()\n        array([-1.,  1.,  0.,  0.], dtype=float32)\n\n        >>> input1 = flow.tensor(np.array([[0.8, 1.0], [-0.6, 2.5]]), dtype=flow.float32)\n        >>> output1 = input1.floor()\n        >>> output1.shape\n        oneflow.Size([2, 2])\n        >>> output1.numpy()\n        array([[ 0.,  1.],\n               [-1.,  2.]], dtype=float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.floor_,\n    r\"\"\"\n    In-place version of :func:`oneflow.floor`\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.div,\n    r\"\"\"\n    div(x, y, *, rounding_mode=None)\n\n    Computes the division of input by other for each element, scalar and broadcast promotation are supported.\n    The formula is:\n\n    .. math::\n        out = \\frac{input}{other}\n\n    Args:\n        input (Union[int, float, oneflow.Tensor]): input.\n        other (Union[int, float, oneflow.Tensor]): other.\n\n    Keyword Arguments:\n        rounding_mode (str, optional): It can be set as ``\"floor\"`` (roudning the results down)\n            or ``\"trunc\"`` (rounding the results towards zero). None for default (no rounding).\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        # element-wise divide\n        >>> input = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> out = flow.div(input,other).numpy()\n        >>> out.shape\n        (2, 3)\n\n        # scalar divide\n        >>> input = 5\n        >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> out = flow.div(input,other).numpy()\n        >>> out.shape\n        (2, 3)\n\n        # broadcast divide\n        >>> input = flow.tensor(np.random.randn(1,1), dtype=flow.float32)\n        >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> out = flow.div(input,other).numpy()\n        >>> out.shape\n        (2, 3)\n\n        # rounding_mode\n        >>> x = flow.tensor([ 0.3810,  1.2774, -0.2972, -0.3719,  0.4637])\n        >>> flow.div(x, 0.5)\n        tensor([ 0.7620,  2.5548, -0.5944, -0.7438,  0.9274], dtype=oneflow.float32)\n        >>> flow.div(x, 0.5, rounding_mode=\"floor\")\n        tensor([ 0.,  2., -1., -1.,  0.], dtype=oneflow.float32)\n        >>> flow.div(x, 0.5, rounding_mode=\"trunc\")\n        tensor([0., 2., -0., -0., 0.], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.mul,\n    r\"\"\"Computes the multiplication of input by other for each element, scalar and broadcast promotation are supported.\n\n    The formula is:\n\n    .. math::\n        \\text{out}_i = \\text{input}_i \\times \\text{other}_i\n\n    For example:\n\n    .. 
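code-block:: python\n\n        # An additional minimal, deterministic sketch (values chosen purely for\n        # illustration) of element-wise multiplication by a scalar:\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1., 2., 3.])\n        >>> flow.mul(x, 2.)\n        tensor([2., 4., 6.], dtype=oneflow.float32)\n\n    .. 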
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        # element-wise multiply\n        >>> input = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> out = flow.mul(input,other).numpy()\n        >>> out.shape\n        (2, 3)\n\n        # scalar multiply\n        >>> input = 5\n        >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> out = flow.mul(input,other).numpy()\n        >>> out.shape\n        (2, 3)\n\n        # broadcast multiply\n        >>> input = flow.tensor(np.random.randn(1,1), dtype=flow.float32)\n        >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> out = flow.mul(input,other).numpy()\n        >>> out.shape\n        (2, 3)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.reciprocal,\n    r\"\"\"Computes the safe reciprocal of x. If x is zero, the reciprocal will\n    also be set to zero.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x = flow.tensor(np.array([[1, 2, 3], [4, 5, 6]]), dtype=flow.float32)\n        >>> out = flow.reciprocal(x)\n        >>> out.numpy()\n        array([[1.        , 0.5       , 0.33333334],\n               [0.25      , 0.2       , 0.16666667]], dtype=float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.sub,\n    r\"\"\"Computes the subtraction of other from input for each element. Scalar and broadcast promotion are supported.\n    The formula is:\n\n    .. math::\n        out = input - other\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        # element-wise subtract\n        >>> input = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> out = flow.sub(input,other).numpy()\n        >>> out.shape\n        (2, 3)\n\n        # scalar subtract\n        >>> input = 5\n        >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> out = flow.sub(input,other).numpy()\n        >>> out.shape\n        (2, 3)\n\n        # broadcast subtract\n        >>> input = flow.tensor(np.random.randn(1,1), dtype=flow.float32)\n        >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> out = flow.sub(input,other).numpy()\n        >>> out.shape\n        (2, 3)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.asin,\n    r\"\"\"\n    Returns a new tensor with the arcsine of the elements of :attr:`input`.\n\n    .. math::\n        \\text{out}_{i} = \\sin^{-1}(\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor(np.array([-0.5,  0.8, 1.0,  -0.8]), dtype=flow.float32)\n        >>> output = flow.asin(input)\n        >>> output.shape\n        oneflow.Size([4])\n        >>> output\n        tensor([-0.5236,  0.9273,  1.5708, -0.9273], dtype=oneflow.float32)\n        >>> input1 = flow.tensor(np.array([[0.8, 1.0], [-0.6, -1.0]]), dtype=flow.float32)\n        >>> output1 = input1.asin()\n        >>> output1.shape\n        oneflow.Size([2, 2])\n        >>> output1\n        tensor([[ 0.9273,  1.5708],\n                [-0.6435, -1.5708]], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.asinh,\n    r\"\"\"\n    Returns a new tensor with the inverse hyperbolic sine of the elements of :attr:`input`.\n\n    .. math::\n        \\text{out}_{i} = \\sinh^{-1}(\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor(np.array([2, 3, 4]), dtype=flow.float32)\n        >>> output = flow.asinh(input)\n        >>> output.shape\n        oneflow.Size([3])\n        >>> output\n        tensor([1.4436, 1.8184, 2.0947], dtype=oneflow.float32)\n\n        >>> input1 = flow.tensor(np.array([[-1, 0, -0.4], [5, 7, 0.8]]), dtype=flow.float32)\n        >>> output1 = input1.asinh()\n        >>> output1.shape\n        oneflow.Size([2, 3])\n        >>> output1\n        tensor([[-0.8814,  0.0000, -0.3900],\n                [ 2.3124,  2.6441,  0.7327]], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.atan,\n    r\"\"\"\n    Returns a new tensor with the arctangent of the elements of :attr:`input`.\n\n    .. math::\n        \\text{out}_{i} = \\tan^{-1}(\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor(np.array([0.5, 0.6, 0.7]), dtype=flow.float32)\n        >>> output = flow.atan(input)\n        >>> output.shape\n        oneflow.Size([3])\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.ceil,\n    r\"\"\"Returns a new tensor with the ceil of the elements of :attr:`input`,\n    the smallest integer greater than or equal to each element.\n\n    The equation is:\n\n    .. math::\n        \\text{out}_{i} = \\left\\lceil \\text{input}_{i} \\right\\rceil\n\n    Args:\n        input (oneflow.Tensor): A Tensor.\n\n    Returns:\n        oneflow.Tensor: The result Tensor\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x = flow.tensor(np.array([0.1, -2, 3.4]).astype(np.float32))\n        >>> y = flow.ceil(x)\n        >>> y.shape\n        oneflow.Size([3])\n        >>> y\n        tensor([ 1., -2.,  4.], dtype=oneflow.float32)\n        >>> x = flow.tensor(np.array([[2.5, 4.6, 0.6],[7.8, 8.3, 9.2]]).astype(np.float32))\n        >>> y = x.ceil()\n        >>> y.shape\n        oneflow.Size([2, 3])\n        >>> y\n        tensor([[ 3.,  5.,  1.],\n                [ 8.,  9., 10.]], dtype=oneflow.float32)\n        >>> x = flow.tensor(np.array([[[2.2, 4.4, 6.5],[7.1, 8.2, 9.3]],[[10.6,11.2,12.2],[13.5,14.8,15.9]]]).astype(np.float32))\n        >>> y = flow.ceil(x)\n        >>> y.shape\n        oneflow.Size([2, 2, 3])\n        >>> y\n        tensor([[[ 3.,  5.,  7.],\n                 [ 8.,  9., 10.]],\n        <BLANKLINE>\n                [[11., 12., 13.],\n                 [14., 15., 16.]]], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(oneflow.ceil_, r\"\"\"In-place version of :func:`oneflow.ceil`\"\"\")\n\n\nadd_docstr(\n    oneflow.negative,\n    r\"\"\"This operator computes the negative value of Tensor.\n\n    Args:\n        input (oneflow.Tensor): A Tensor\n\n    Returns:\n        oneflow.Tensor: The result Tensor\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> input = flow.tensor(\n        ...    np.array([1.0, -1.0, 2.3]).astype(np.float32), dtype=flow.float32\n        ... )\n        >>> out = flow.negative(input)\n        >>> out\n        tensor([-1.0000,  1.0000, -2.3000], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.log1p,\n    r\"\"\"Returns a new tensor with the natural logarithm of (1 + input).\n\n    .. math::\n        \\text{out}_{i}=\\log_e(1+\\text{input}_{i})\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x = flow.tensor(np.array([1.3, 1.5, 2.7]), dtype=flow.float32)\n        >>> out = flow.log1p(x)\n        >>> out\n        tensor([0.8329, 0.9163, 1.3083], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.exp,\n    r\"\"\"\n\n    This operator computes the exponential of Tensor.\n\n    The equation is:\n\n    .. math::\n\n        out = e^x\n\n    Args:\n        x (oneflow.Tensor): A Tensor\n\n    Returns:\n        oneflow.Tensor: The result Tensor\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x = flow.tensor(np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32)\n        >>> y = flow.exp(x)\n        >>> y\n        tensor([ 2.7183,  7.3891, 20.0855], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.exp2,\n    r\"\"\"\n\n    This operator computes the base two exponential of Tensor.\n\n    The equation is:\n\n    .. math::\n\n        out = 2^x\n\n    Args:\n        x (oneflow.Tensor): A Tensor\n\n    Returns:\n        oneflow.Tensor: The result Tensor\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x = flow.tensor(np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32)\n        >>> y = flow.exp2(x)\n        >>> y\n        tensor([2., 4., 8.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.acos,\n    r\"\"\"\n    Returns a new tensor with the inverse cosine of the elements of :attr:`input`.\n\n    .. math::\n        \\text{out}_{i} = \\arccos(\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> arr = np.array([0.5, 0.6, 0.7])\n        >>> input = flow.tensor(arr, dtype=flow.float32)\n        >>> output = flow.acos(input)\n        >>> output\n        tensor([1.0472, 0.9273, 0.7954], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.acosh,\n    r\"\"\"\n    Returns a new tensor with the inverse hyperbolic cosine of the elements of :attr:`input`.\n\n    .. math::\n\n        \\text{out}_{i} = \\cosh^{-1}(\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x1 = flow.tensor(np.array([2, 3, 4]).astype(np.float32))\n        >>> out1 = flow.acosh(x1)\n        >>> out1\n        tensor([1.3170, 1.7627, 2.0634], dtype=oneflow.float32)\n        >>> x2 = flow.tensor(np.array([1.5, 2.6, 3.7]).astype(np.float32),device=flow.device('cuda'))\n        >>> out2 = flow.acosh(x2)\n        >>> out2\n        tensor([0.9624, 1.6094, 1.9827], device='cuda:0', dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.atanh,\n    r\"\"\"Returns a new tensor with the inverse hyperbolic tangent of the elements of :attr:`input`.\n\n    .. math::\n        \\text{out}_{i} = \\tanh^{-1}(\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> np_arr = np.array([0.5, 0.6, 0.7]).astype(np.float32)\n        >>> input = flow.tensor(np_arr, dtype=flow.float32)\n        >>> output = flow.atanh(input)\n        >>> output\n        tensor([0.5493, 0.6931, 0.8673], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.sign,\n    r\"\"\"Computes the sign of each element of the input tensor.\n\n    .. math::\n\n        \\text{out}_{i}  = \\text{sgn}(\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x1 = flow.tensor(np.array([-2, 0, 2]).astype(np.float32))\n        >>> out1 = flow.sign(x1)\n        >>> out1.numpy()\n        array([-1.,  0.,  1.], dtype=float32)\n        >>> x2 = flow.tensor(np.array([-3.2, -4.5, 5.8]).astype(np.float32),device=flow.device('cuda'))\n        >>> out2 = flow.sign(x2)\n        >>> out2.numpy()\n        array([-1., -1.,  1.], dtype=float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.sin,\n    r\"\"\"sin(x) -> Tensor\n\n    Returns a new tensor with the sine of the elements of :attr:`input`.\n\n    .. math::\n        \\text{y}_{i} = \\sin(\\text{x}_{i})\n\n    Args:\n        x (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x1 = flow.tensor(np.array([-0.5461,  0.1347, -2.7266, -0.2746]).astype(np.float32))\n        >>> y1 = flow.sin(x1)\n        >>> y1\n        tensor([-0.5194,  0.1343, -0.4032, -0.2712], dtype=oneflow.float32)\n\n        >>> x2 = flow.tensor(np.array([-1.4, 2.6, 3.7]).astype(np.float32), device=flow.device('cuda'))\n        >>> y2 = flow.sin(x2)\n        >>> y2\n        tensor([-0.9854,  0.5155, -0.5298], device='cuda:0', dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.sin_,\n    r\"\"\"\n    In-place version of :func:`oneflow.sin`\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.sinh,\n    r\"\"\"Returns a new tensor with the hyperbolic sine of the elements of :attr:`input`.\n\n    .. math::\n        \\text{out}_{i} = \\sinh(\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x1 = flow.tensor(np.array([1, 2, 3]), dtype=flow.float32)\n        >>> x2 = flow.tensor(np.array([1.53123589,0.54242598,0.15117185]), dtype=flow.float32)\n        >>> x3 = flow.tensor(np.array([1,0,-1]), dtype=flow.float32)\n\n        >>> flow.sinh(x1).numpy()\n        array([ 1.1752012,  3.6268604, 10.017875 ], dtype=float32)\n        >>> flow.sinh(x2).numpy()\n        array([2.20381  , 0.5694193, 0.1517483], dtype=float32)\n        >>> flow.sinh(x3).numpy()\n        array([ 1.1752012,  0.       , -1.1752012], dtype=float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.tan,\n    r\"\"\"Returns the tangent of the elements of :attr:`input`.\n\n    .. math::\n        \\text{out}_{i} = \\tan(\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> np_arr = np.array([-1/4*np.pi, 0, 1/4*np.pi]).astype(np.float32)\n        >>> input = flow.tensor(np_arr, dtype=flow.float32)\n        >>> output = flow.tan(input)\n        >>> output\n        tensor([-1.,  0.,  1.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.cos,\n    r\"\"\"\n    Returns a new tensor with the cosine of the elements of :attr:`input`.\n\n    .. math::\n        \\text{out}_{i} = \\cos(\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> arr = np.array([1.4309,  1.2706, -0.8562,  0.9796])\n        >>> input = flow.tensor(arr, dtype=flow.float32)\n        >>> output = flow.cos(input).numpy()\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.cosh,\n    r\"\"\"\n    Returns a new tensor with the hyperbolic cosine of the elements of :attr:`input`.\n\n    .. math::\n        \\text{out}_{i} = \\cosh(\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> arr = np.array([ 0.1632,  1.1835, -0.6979, -0.7325])\n        >>> input = flow.tensor(arr, dtype=flow.float32)\n        >>> output = flow.cosh(input).numpy()\n        >>> output\n        array([1.0133467, 1.7859949, 1.2535787, 1.2804903], dtype=float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.erf,\n    r\"\"\"Computes the error function of each element. 
The error function is defined as follows:\n\n    .. math::\n            \\operatorname{erf}(x)=\\frac{2}{\\sqrt{\\pi}} \\int_{0}^{x} e^{-t^{2}} d t\n\n    Args:\n        x (oneflow.Tensor): A Tensor\n\n    Returns:\n        oneflow.Tensor: The result Tensor\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> x = flow.tensor(np.array([0, -1., 10.]), dtype=flow.float32)\n        >>> out = flow.erf(x)\n        >>> out.shape\n        oneflow.Size([3])\n        >>> out.numpy()\n        array([ 0.       , -0.8427008,  1.       ], dtype=float32)\n\n        >>> x = flow.tensor(np.array([[0, -1., 10.], [5, 7, 0.8]]), dtype=flow.float32)\n        >>> out = flow.erf(x)\n        >>> out.shape\n        oneflow.Size([2, 3])\n        >>> out.numpy()\n        array([[ 0.        , -0.8427008 ,  1.        ],\n               [ 1.        ,  1.        ,  0.74210095]], dtype=float32)\n\n        >>> x = flow.tensor(np.array([[0, -1., 10.], [5, 7, 0.8], [2, 3, 4]]), dtype=flow.float32)\n        >>> out = x.erf()\n        >>> out.shape\n        oneflow.Size([3, 3])\n        >>> out.numpy()\n        array([[ 0.        , -0.8427008 ,  1.        ],\n               [ 1.        ,  1.        ,  0.74210095],\n               [ 0.9953223 ,  0.9999779 ,  1.        ]], dtype=float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.erfc,\n    r\"\"\"Computes the complementary error function of each element of input. The complementary error\n    function is defined as follows:\n\n    .. math::\n            \\operatorname{erfc}(x)=1-\\frac{2}{\\sqrt{\\pi}} \\int_{0}^{x} e^{-t^{2}} d t\n\n    Args:\n        x (oneflow.Tensor): A Tensor\n\n    Returns:\n        oneflow.Tensor: The result Tensor\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> x = flow.tensor(np.array([0, -1., 10.]), dtype=flow.float32)\n        >>> out = flow.erfc(x)\n        >>> out\n        tensor([1.0000e+00, 1.8427e+00, 2.8026e-45], dtype=oneflow.float32)\n\n        >>> x = flow.tensor(np.array([[0, -1., 10.], [5, 7, 0.8]]), dtype=flow.float32)\n        >>> out = flow.erfc(x)\n        >>> out\n        tensor([[1.0000e+00, 1.8427e+00, 2.8026e-45],\n                [1.5375e-12, 4.1838e-23, 2.5790e-01]], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.expm1,\n    r\"\"\"Returns a new tensor with the exponential of the elements minus 1\n    of :attr:`input`.\n\n\n    The equation is:\n\n    .. math::\n        y_{i} = e^{x_{i}} - 1\n\n    Args:\n        input (oneflow.Tensor): A Tensor.\n\n    Returns:\n        oneflow.Tensor: The result Tensor\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x = flow.tensor(np.array([1, 2, 3]).astype(np.float32))\n        >>> y = flow.expm1(x)\n        >>> y.shape\n        oneflow.Size([3])\n        >>> y\n        tensor([ 1.7183,  6.3891, 19.0855], dtype=oneflow.float32)\n\n        >>> x = flow.tensor(np.array([[[2, 4, 6],[7, 8, 9]],[[10,11,12],[13,14,15]]]).astype(np.float32))\n        >>> y = flow.expm1(x)\n        >>> print(y.shape)\n        oneflow.Size([2, 2, 3])\n        >>> print(y.numpy())\n        [[[6.3890562e+00 5.3598152e+01 4.0242880e+02]\n          [1.0956332e+03 2.9799580e+03 8.1020840e+03]]\n        <BLANKLINE>\n         [[2.2025465e+04 5.9873141e+04 1.6275380e+05]\n          [4.4241238e+05 1.2026032e+06 3.2690165e+06]]]\n\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.fmod,\n    r\"\"\"\n    fmod(input, other, *, out=None) -> Tensor\n\n    Computes the element-wise remainder of division.\n\n    The dividend and divisor may contain both integer and floating point\n    numbers. The remainder has the same sign as the dividend :attr:`input`.\n\n    Supports broadcasting to a common shape, integer and float inputs.\n\n\n    Args:\n        input (Tensor): the dividend\n        other (Tensor or Scalar): the divisor\n\n    Keyword args:\n        out (Tensor, optional): the output tensor.\n\n    Example::\n\n        >>> import oneflow as flow\n        >>> flow.fmod(flow.tensor([-3., -2, -1, 1, 2, 3], dtype=flow.float32), 2.)\n        tensor([-1., -0., -1.,  1.,  0.,  1.], dtype=oneflow.float32)\n        >>> flow.fmod(flow.tensor([1, 2, 3, 4, 5.], dtype=flow.float32), 1.5)\n        tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000], dtype=oneflow.float32)\n        >>> flow.fmod(flow.tensor([1, 2, 3, 4., -5]), flow.tensor([4, 2, 1, 3., 1]))\n        tensor([1., 0., 0., 1., -0.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.log,\n    r\"\"\"\n    Returns a new tensor with the natural logarithm of the elements of :attr:`input`.\n\n    .. math::\n        y_{i} = \\log_{e} (x_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> arr = np.random.randn(2, 3, 4, 5)\n        >>> input = flow.tensor(arr, dtype=flow.float32)\n        >>> output = flow.log(input)\n\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.log2,\n    \"\"\"\n    oneflow.log2(input) -> Tensor\n\n    Returns a new tensor with the logarithm to the base 2 of the elements of :attr:`input`.\n    \n    .. math::\n        y_{i} = \\\\log_{2} (x_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n    \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> arr = np.random.randn(2, 3, 4, 5)\n        >>> input = flow.tensor(arr, dtype=flow.float32)\n        >>> output = flow.log2(input)\n\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.log10,\n    \"\"\"\n    oneflow.log10(input) -> Tensor\n\n    Returns a new tensor with the logarithm to the base 10 of the elements of :attr:`input`.\n\n    .. math::\n        y_{i} = \\\\log_{10} (x_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. 
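code-block:: python\n\n        # An additional deterministic sketch (illustrative values): base-10\n        # logarithms of powers of ten are exact integers.\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1., 10., 100.])\n        >>> flow.log10(x)\n        tensor([0., 1., 2.], dtype=oneflow.float32)\n\n    .. 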
code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.ones(3, 3) * 10\n        >>> output = flow.log10(x)\n        >>> output\n        tensor([[1., 1., 1.],\n                [1., 1., 1.],\n                [1., 1., 1.]], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.minimum,\n    r\"\"\"Computes the element-wise minimum of x and y.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> x = flow.tensor((1, 2, -1), dtype=flow.float32)\n        >>> y = flow.tensor((3, 0, 4), dtype=flow.float32)\n        >>> flow.minimum(x, y)\n        tensor([ 1.,  0., -1.], dtype=oneflow.float32)\n\n        >>> x = flow.tensor((1,), dtype=flow.float32)\n        >>> y = flow.tensor((3, 0, 4), dtype=flow.float32)\n        >>> flow.minimum(x, y)\n        tensor([1., 0., 1.], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.maximum,\n    r\"\"\"Computes the element-wise maximum of x and y.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> x = flow.tensor((1, 2, -1), dtype=flow.float32)\n        >>> y = flow.tensor((3, 0, 4), dtype=flow.float32)\n        >>> flow.maximum(x, y)\n        tensor([3., 2., 4.], dtype=oneflow.float32)\n\n        >>> x = flow.tensor((1,), dtype=flow.float32)\n        >>> y = flow.tensor((3, 0, 4), dtype=flow.float32)\n        >>> flow.maximum(x, y)\n        tensor([3., 1., 4.], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.median,\n    r\"\"\"\n    median(input) -> Tensor\n\n    Returns the median of the values in input.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.median.html#torch.median\n\n    .. note::\n        The median is not unique for :attr:`input` tensors with an even number\n        of elements. In this case the lower of the two medians is returned.\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor((1, 2, -1), dtype=flow.float32)\n        >>> flow.median(x)\n        tensor(1., dtype=oneflow.float32)\n\n    .. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor)\n        :noindex:\n\n    Returns a tuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input`\n    in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`.\n\n    By default, :attr:`dim` is the last dimension of the :attr:`input` tensor.\n\n    If :attr:`keepdim` is ``True``, the output tensors are of the same size\n    as :attr:`input` except in the dimension :attr:`dim` where they are of size 1.\n    Otherwise, :attr:`dim` is squeezed (see :func:`flow.squeeze`), resulting in\n    the outputs tensor having 1 fewer dimension than :attr:`input`.\n\n    .. note::\n        The median is not unique for :attr:`input` tensors with an even number\n        of elements in the dimension :attr:`dim`. In this case the lower of the\n        two medians is returned.\n\n    Args:\n        input (Tensor): the input tensor.\n        dim (int): the dimension to reduce.\n        keepdim (bool): whether the output tensor has :attr:`dim` retained or not.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> a = flow.tensor([[ 0.2505, -0.3982, -0.9948,  0.3518, -1.3131],\n        ...    
[ 0.3180, -0.6993,  1.0436,  0.0438,  0.2270],\n        ...    [-0.2751,  0.7303,  0.2192,  0.3321,  0.2488],\n        ...    [ 1.0778, -1.9510,  0.7048,  0.4742, -0.7125]])\n        >>> result=flow.median(a, 1)\n        >>> result.values\n        tensor([-0.3982,  0.2270,  0.2488,  0.4742], dtype=oneflow.float32)\n        >>> result.indices\n        tensor([1, 4, 4, 3], dtype=oneflow.int64)\n        \n    ..\n        Feature Stage of Operator [median].\n        - Maintainer List [@simonJJJ]\n        - Current Stage [pre Alpha]\n        - Alpha Stage Check List [ ]\n          - API(Compatible with PyTorch 1.11, anything incompatible must be noted in API Doc.)[Yes]\n          - Doc(API Doc must be provided and shown normally on the web page.)[Yes]\n          - Functionality and its Test [ ]\n            - Functionality is highly compatible with PyTorch 1.11. [Yes]\n            - eager local [Yes] [@simonJJJ]\n              - forward [Yes]\n              - backward [Yes]\n              - gpu [Yes]\n              - cpu [Yes]\n            - graph local [ ] [@simonJJJ]\n              - forward [Yes]\n              - backward [ ]\n              - gpu [Yes]\n              - cpu [Yes]\n          - Exception Handling\n            - Exception Message and Hint must be provided [Yes]\n        - Beta Stage Check List [ ]\n          - API(High compatibility with PyTorch 1.11, shouldn't have anything incompatible for a naive reason.)[ ]\n          - Doc(Same standard as Alpha Stage)[Yes]\n          - Functionality and its Test [ ]\n            - eager global [Yes] [@simonJJJ]\n              - forward [Yes]\n              - backward [Yes]\n              - gpu [Yes]\n              - cpu [Yes]\n            - graph global [ ]\n              - forward [ ]\n              - backward [ ]\n              - gpu [ ]\n              - cpu [ ]\n          - Performance and Scalability(Must be evaluated.)[ ]\n            - CUDA kernel [ ]\n            - CPU kernel [ ]\n            - N nodes M devices [ ]\n          - Exception Handling [ ]\n            - Exception Message and Hint must be provided [ ]\n            - Try your best to do Exception Recovery [ ]\n        - Stable Stage Check List [ ]\n          - API(Same standard as Beta Stage)[ ]\n          - Doc(Same standard as Beta Stage)[ ]\n          - Functionality and its Test [ ]\n            - fp16 and AMP [ ]\n            - NHWC [ ]\n          - Performance and Scalability(Must be evaluated.)[ ]\n          - Exception Handling [ ]\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.mode,\n    r\"\"\"\n    oneflow.mode(input, dim=-1, keepdim=False)\n\n    Returns a namedtuple (values, indices) where values is the mode value of each row of \n    the input tensor in the given dimension dim, i.e. a value which appears most often in \n    that row, and indices is the index location of each mode value found.\n    \n    By default, :attr:`dim` is the last dimension of the :attr:`input` tensor.\n\n    If :attr:`keepdim` is ``True``, the output tensors are of the same size\n    as :attr:`input` except in the dimension :attr:`dim` where they are of size 1.\n    Otherwise, :attr:`dim` is squeezed (see :func:`flow.squeeze`), resulting in\n    the output tensor having 1 fewer dimension than :attr:`input`.\n    \n    Args:\n        input (Tensor): the input tensor.\n        dim (int): the dimension to reduce. Default: `-1`\n        keepdim (bool): whether the output tensor has dim retained or not. Default: `False`\n\n    Returns:\n        Tuple(oneflow.Tensor, oneflow.Tensor(dtype=int64)): the result tuple of two output\n        tensors (values, indices) \n        \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> x = flow.tensor([6, 2, 5, 3, 3, 5, 4, 3])\n        >>> result = flow.mode(x)\n        >>> result.values\n        tensor(3, dtype=oneflow.int64)\n        >>> result.indices\n        tensor(7, dtype=oneflow.int64)\n        >>> x = flow.Tensor([[2, 1, 2, 3], [2, 4, 3, 3]])\n        >>> result = flow.mode(x, dim=0)\n        >>> result.values\n        tensor([2., 1., 2., 3.], dtype=oneflow.float32)\n        >>> result.indices\n        tensor([1, 0, 0, 1], dtype=oneflow.int64)\n        \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.pow,\n    r\"\"\"Takes the power of each element in input with exponent and returns a tensor with the result. Exponent can be either a single float number, a single int number, or a tensor with the same shape as input.\n    When exponent is a scalar value, the operation applied is:\n\n    .. math::\n        \\text{out}_i = x_i ^ \\text{exponent}\n\n    When exponent is a tensor, the operation applied is:\n\n    .. math::\n        \\text{out}_i = x_i ^ {\\text{exponent}_i}\n\n    Args:\n        input (Tensor): the input tensor.\n        exponent (int, float, Tensor): the exponent.\n\n    Returns:\n        Tensor: The result of the power operation\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> x = flow.tensor(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), dtype=flow.float32)\n        >>> out = flow.pow(x, 2)\n        >>> out\n        tensor([ 1.,  4.,  9., 16., 25., 36.], dtype=oneflow.float32)\n\n        >>> x = flow.tensor(np.array([1.0, 2.0, 3.0, 4.0]), dtype=flow.float32)\n        >>> y = flow.tensor(np.array([1.0, 2.0, 3.0, 4.0]), dtype=flow.float32)\n        >>> out = flow.pow(x, y)\n        >>> out\n        tensor([  1.,   4.,  27., 256.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.rsqrt,\n    r\"\"\"Returns a new tensor with the reciprocal of the square-root of each of\n        the elements of :attr:`input`.\n\n        .. math::\n            \\text{out}_{i} = \\frac{1}{\\sqrt{\\text{input}_{i}}}\n\n        Args:\n            input (Tensor): the input tensor.\n\n        For example:\n\n        .. code-block:: python\n\n            >>> import oneflow as flow\n            >>> import numpy as np\n\n            >>> a = flow.tensor(np.array([1.0, 2.0, 3.0]), dtype=flow.float32)\n            >>> out = flow.rsqrt(a).numpy()\n            >>> out\n            array([1.        , 0.70710677, 0.57735026], dtype=float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.sqrt,\n    r\"\"\"Returns a new tensor with the square-root of the elements of :attr:`input`.\n\n        .. math::\n            \\text{out}_{i} = \\sqrt{\\text{input}_{i}}\n\n        Args:\n            input (Tensor): the input tensor.\n\n        For example:\n\n        .. code-block:: python\n\n            >>> import oneflow as flow\n            >>> import numpy as np\n\n            >>> arr = np.array([1.0, 2.0, 3.0])\n            >>> input = flow.tensor(arr, dtype=flow.float32)\n            >>> output = flow.sqrt(input).numpy()\n            >>> output\n            array([1.        , 1.4142135, 1.7320508], dtype=float32)\n        \"\"\",\n)\n\n\nadd_docstr(\n    oneflow.square,\n    r\"\"\"Returns a new tensor with the square of the elements of :attr:`input`.\n\n        .. math::\n            \\text{out}_{i} = \\text{input}_{i}^2\n\n        Args:\n            input (Tensor): the input tensor.\n\n        For example:\n\n        .. code-block:: python\n\n            >>> import oneflow as flow\n            >>> import numpy as np\n\n            >>> arr = np.array([1.0, 2.0, 3.0])\n            >>> input = flow.tensor(arr, dtype=flow.float32)\n            >>> output = flow.square(input).numpy()\n            >>> output\n            array([1., 4., 9.], dtype=float32)\n        \"\"\",\n)\n\nadd_docstr(\n    oneflow.matmul,\n    r\"\"\"\n    matmul(input, other) -> Tensor\n\n    This operator applies matrix multiplication to two Tensors.\n\n    Args:\n        input (oneflow.Tensor): A Tensor\n        other (oneflow.Tensor): A Tensor\n\n    Returns:\n        oneflow.Tensor: The result Tensor\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input1 = flow.tensor(np.random.randn(2, 6), dtype=flow.float32)\n        >>> input2 = flow.tensor(np.random.randn(6, 5), dtype=flow.float32)\n        >>> of_out = flow.matmul(input1, input2)\n        >>> of_out.shape\n        oneflow.Size([2, 5])\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.mv,\n    r\"\"\"\n    mv(input, vec) -> Tensor\n\n    Performs a matrix-vector product of the matrix :attr:`input` and the vector :attr:`vec`.\n\n    If :attr:`input` is a :math:`(n \\times m)` tensor, :attr:`vec` is a\n    1-D tensor of size `m`, :attr:`out` will be a 1-D tensor of size `n`.\n    \n    .. note:: This function does not broadcast.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.mv.html.\n\n    Args:\n        input (oneflow.Tensor): matrix to be matrix multiplied\n        vec (oneflow.Tensor): vector to be matrix multiplied\n\n    Returns:\n        oneflow.Tensor: the output Tensor\n    \n    For example:\n\n    .. code-block:: python\n    \n        >>> import oneflow as flow\n        >>> mat = flow.randn(2, 3)\n        >>> vec = flow.randn(3)\n        >>> out = flow.mv(mat, vec)\n        >>> out.shape\n        oneflow.Size([2])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.mm,\n    r\"\"\"\n    mm(input, mat2) -> Tensor\n    \n    Performs a matrix multiplication of the matrices :attr:`input` and :attr:`mat2`.\n\n    If :attr:`input` is a :math:`(n \\times m)` tensor, :attr:`mat2` is a\n    :math:`(m \\times p)` tensor, :attr:`out` will be a :math:`(n \\times p)` tensor.\n\n    .. note:: This function does not broadcast.\n            For broadcasting matrix products, see :func:`oneflow.matmul`.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.mm.html.\n\n    Args:\n        input (oneflow.Tensor): the first matrix to be matrix multiplied\n        mat2 (oneflow.Tensor): the second matrix to be matrix multiplied\n\n    Returns:\n        oneflow.Tensor: The result Tensor\n    \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> mat1 = flow.randn(2, 3)\n        >>> mat2 = flow.randn(3, 3)\n        >>> of_out = flow.mm(mat1, mat2)\n        >>> of_out.shape\n        oneflow.Size([2, 3])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.round,\n    r\"\"\"This operator rounds each element of the input tensor to the nearest integer.\n    \n    .. note::\n        This function implements the \"round half to even\" rule to break ties when a number is equidistant from two integers (e.g. `round(2.5)` is 2).\n    \n    Args:\n        input (oneflow.Tensor): A Tensor\n\n    Returns:\n        oneflow.Tensor: The result Tensor\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x1 = flow.tensor(np.array([1.49999, 1.500001, 2.7]).astype(np.float32))\n        >>> out1 = flow.round(x1)\n        >>> out1.numpy()\n        array([1., 2., 3.], dtype=float32)\n        >>> x2 = flow.tensor(np.array([2.499999, 7.5000001, 5.3, 6.8]).astype(np.float32))\n        >>> out2 = flow.round(x2)\n        >>> out2.numpy()\n        array([2., 8., 5., 7.], dtype=float32)\n\n    \"\"\",\n)\n\nadd_docstr(oneflow.round_, r\"\"\"In-place version of :func:`oneflow.round`.\"\"\")\n\nadd_docstr(\n    oneflow.std,\n    r\"\"\"\n    Returns the standard-deviation of each row of the :attr:`input` tensor in the\n    dimension :attr:`dim`. If :attr:`dim` is a list of dimensions,\n    reduce over all of them.\n\n    If keepdim is True, the output tensor is of the same size as input except in\n    the dimension(s) dim where it is of size 1. Otherwise, dim is squeezed,\n    resulting in the output tensor having 1 (or len(dim)) fewer dimension(s).\n\n    If :attr:`unbiased` is ``False``, then the standard-deviation will be calculated\n    via the biased estimator. Otherwise, Bessel's correction will be used.\n\n    Args:\n        input (Tensor): the input tensor.\n        dim (int or tuple of ints): the dimension or dimensions to reduce.\n        unbiased (bool): whether to use the unbiased estimation or not\n        keepdim (bool): whether the output tensor has `dim` retained or not.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> arr = np.array([1.0, 2.0, 3.0])\n        >>> input = flow.tensor(arr)\n        >>> output = flow.std(input, dim=0).numpy()\n        >>> output\n        array(1.)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.var,\n    r\"\"\"Returns the variance of each row of the `input` tensor in the given dimension `dim`.\n\n    If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim`\n    where it is of size 1. Otherwise, dim is squeezed (see `flow.squeeze()`), resulting in the output\n    tensor having 1 (or `len(dim)`) fewer dimension(s).\n\n    Args:\n        input (Tensor): the input tensor.\n        dim (int or tuple of ints): the dimension or dimensions to reduce. Defaults to None.\n        unbiased (bool, optional): whether to use Bessel’s correction (:math:`\\delta N = 1`). Defaults to True.\n        keepdim (bool, optional): whether the output tensor has dim retained or not. Defaults to False.\n\n    Returns:\n        Tensor: The result of variance on the specified axis of input Tensor\n\n    For example:\n\n    .. 
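code-block:: python\n\n        # A minimal deterministic sketch (illustrative values): the unbiased\n        # variance of [1., 2., 3., 4.] is 5/3.\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1., 2., 3., 4.])\n        >>> flow.var(x)\n        tensor(1.6667, dtype=oneflow.float32)\n\n    .. 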
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> input = flow.tensor(np.random.randn(2, 3, 4, 5))\n        >>> output = flow.var(input, 1, True)\n\n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow.dot,\n    r\"\"\"This operator computes the dot product of tensor input and other.\n\n    The equation is:\n\n\t$$\n        \\sum_{i=1}^{n}(x[i] * y[i])\n\t$$\n\n    Args:\n        input (Tensor):  first tensor in the dot product.\n        other (Tensor):  second tensor in the dot product.\n\n    Shape:\n        - input: Input must be 1D.\n        - other: Other must be 1D.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> flow.dot(flow.Tensor([2, 3]), flow.Tensor([2, 1]))\n        tensor(7., dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.select,\n    r\"\"\"\n    Slices the self tensor along the selected dimension at the given index. This function returns \n    a view of the original tensor with the given dimension removed.\n\n    Args:\n        input (Tensor): the input tensor.\n        dim  (int):  the dimension to slice.\n        select (int): the index to select with.\n\n    Returns:\n        oneflow.Tensor: the output Tensor.\n\n    For example:\n    \n    .. code-block:: python\n    \n        >>> import oneflow as flow\n        >>> input = flow.rand(3, 4, 5)\n        >>> out = flow.select(input, 0, 1)\n        >>> out.size()\n        oneflow.Size([4, 5])\n        >>> out = flow.select(input, 1, 1)\n        >>> out.size()\n        oneflow.Size([3, 5])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.movedim,\n    r\"\"\"\n    Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.\n    Other dimensions of input that are not explicitly moved remain in their original order and appear at the positions not specified in destination.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.movedim.html.\n\n    Args:\n        input (Tensor): the input tensor.\n        source  (int or a list): Original positions of the dims to move. These must be unique.\n        destination (int or a list): Destination positions for each of the original dims. These must also be unique.\n\n    Returns:\n        oneflow.Tensor: the output Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> input = flow.tensor(np.random.randn(2, 3, 4, 5), dtype=flow.float32)\n        >>> output = flow.movedim(input, 1, 0)\n        >>> output.shape\n        oneflow.Size([3, 2, 4, 5])\n        >>> output = flow.movedim(input, (1, 2), (0, 1))\n        >>> output.shape\n        oneflow.Size([3, 4, 2, 5])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.as_strided,\n    r\"\"\"\n    as_strided(input, size, stride, storage_offset=None) -> Tensor\n\n    Create a view of an existing oneflow.Tensor input with specified size, stride and storage_offset.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.as_strided.html.\n\n    Args:\n        input (Tensor): the input tensor.\n        size (tuple or ints): the shape of the output tensor.\n        stride (tuple or ints): the stride of the output tensor.\n        storage_offset (int): the offset in the underlying storage of the output tensor\n\n    Returns:\n        oneflow.Tensor: the output Tensor.\n\n    For example:\n\n    .. 
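code-block:: python\n\n        # A minimal deterministic sketch (illustrative values) of how size and\n        # stride index into the underlying storage: rows advance by 3 elements,\n        # columns by 1.\n        >>> import oneflow as flow\n        >>> base = flow.arange(6, dtype=flow.float32)\n        >>> flow.as_strided(base, (2, 3), (3, 1))\n        tensor([[0., 1., 2.],\n                [3., 4., 5.]], dtype=oneflow.float32)\n\n    .. 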
code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> input = flow.rand(2,3,5)\n        >>> output = flow.as_strided(input, (2,3,3), (1,2,3), 1)\n        >>> output.size()\n        oneflow.Size([2, 3, 3])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.addcmul,\n    r\"\"\"\n    oneflow.addcmul(input, tensor1, tensor2, *, value=1) -> Tensor\n\n    Performs the element-wise multiplication of tensor1 by tensor2, multiplies the result\n    by the scalar value and adds it to input.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.addcmul.html\n    \n    .. math::\n        \\text{out}_i = \\text{input}_i + value \\times\\  \\text{tensor1}_i \\times\\ \\text{tensor2}_i\n        \n    Args:\n        input (Tensor): the tensor to be added.\n        tensor1 (Tensor): the tensor to be multiplied.\n        tensor2 (Tensor): the tensor to be multiplied.\n    \n    Keyword args:\n        value (Number, optional): multiplier for :math:`tensor1 * tensor2`.\n\n    Returns:\n        oneflow.Tensor: the output Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        \n        >>> input = flow.rand(2, 3, 4)\n        >>> tensor1 = flow.rand(2, 3, 4)\n        >>> tensor2 = flow.rand(2, 3, 4)\n        >>> out = flow.addcmul(input, tensor1, tensor2, value=2)\n        >>> out.size()\n        oneflow.Size([2, 3, 4])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.eye,\n    \"\"\"oneflow.eye(n, m, *, device=None, requires_grad=False, placement=None, sbp) -> Tensor\n\n    This operator creates a 2-D Tensor with ones on the diagonal and zeros elsewhere.\n\n    Args:\n        n (int): the number of rows.\n        m (int, optional): the number of columns, with default being n. Defaults to None.\n\n    Keyword args:\n        device(Union[flow.device, str], optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor.\n        requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: `False`.\n        placement(oneflow._oneflow_internal.placement, optional): The placement attribute allows you to specify which physical device the tensor is stored on.\n        sbp(Union[oneflow._oneflow_internal.sbp.sbp, List[oneflow._oneflow_internal.sbp.sbp]], optional): When creating a global tensor, specify the SBP of the tensor.\n\n    Returns:\n        oneflow.Tensor: The result tensor with ones on the diagonal and zeros elsewhere.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> out = flow.eye(3, 3)\n        >>> out\n        tensor([[1., 0., 0.],\n                [0., 1., 0.],\n                [0., 0., 1.]], dtype=oneflow.float32)\n        >>> out = flow.eye(3, 3, device=\"cuda\")\n        >>> out\n        tensor([[1., 0., 0.],\n                [0., 1., 0.],\n                [0., 0., 1.]], device='cuda:0', dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.tensor_split,\n    r\"\"\"\n    Splits a tensor into multiple sub-tensors, all of which are views of input, along dimension\n    dim according to the indices or number of sections specified by indices_or_sections .\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.tensor_split.html.\n\n    Args:\n        input (Tensor): the input tensor.\n        indices_or_sections (int or a list): If indices_or_sections is an integer n, input is split into n sections \n            along dimension dim. If input is divisible by n along dimension dim, each section will be of equal size, \n            input.size(dim) / n. If input is not divisible by n, the first int(input.size(dim) % n)\n            sections will have size int(input.size(dim) / n) + 1, and the rest will have size int(input.size(dim) / n).\n            If indices_or_sections is a list or tuple of ints, then input is split along dimension dim at each of the indices in \n            the list, tuple or tensor. For instance, indices_or_sections=[2, 3] and dim=0 would result in the tensors \n            input[:2], input[2:3], and input[3:]. If indices_or_sections is a tensor, it must be a zero-dimensional or\n            one-dimensional long tensor on the CPU.\n        dim (int): dimension along which to split the tensor.\n\n    Returns:\n        oneflow.TensorTuple: the output TensorTuple.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> input = flow.rand(3,4,5)\n        >>> output = flow.tensor_split(input,(2,3),2)\n        >>> output[0].size()\n        oneflow.Size([3, 4, 2])\n        >>> output[1].size()\n        oneflow.Size([3, 4, 1])\n        >>> output[2].size()\n        oneflow.Size([3, 4, 2])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.hsplit,\n    r\"\"\"\n    hsplit(input, indices_or_sections) -> List of Tensors\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.hsplit.html.\n\n    Splits `input`, a tensor with one or more dimensions, into multiple tensors horizontally according to `indices_or_sections`.\n    Each split is a view of `input`.\n\n    If `input` is one dimensional this is equivalent to calling oneflow.tensor_split(input, indices_or_sections, dim=0) \n    (the split dimension is zero), and if `input` has two or more dimensions it’s equivalent to calling \n    oneflow.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), except that if `indices_or_sections`\n    is an integer it must evenly divide the split dimension or a runtime error will be thrown.\n\n    Args:\n        input (Tensor): the input tensor.\n        indices_or_sections (int or a list): See argument in :func:`oneflow.tensor_split()`.\n\n    Returns:\n        oneflow.TensorTuple: the output TensorTuple.\n\n    For example:\n\n    .. 
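code-block:: python\n\n        # A minimal deterministic sketch (illustrative values): a (2, 4) tensor\n        # split into two (2, 2) halves along dimension 1.\n        >>> import oneflow as flow\n        >>> x = flow.tensor([[0., 1., 2., 3.],\n        ...                  [4., 5., 6., 7.]])\n        >>> a, b = flow.hsplit(x, 2)\n        >>> a\n        tensor([[0., 1.],\n                [4., 5.]], dtype=oneflow.float32)\n        >>> b\n        tensor([[2., 3.],\n                [6., 7.]], dtype=oneflow.float32)\n\n    .. 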
code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> input = flow.rand(3,4,5,6)\n        >>> output = flow.hsplit(input,(1,3))\n        >>> output[0].size()\n        oneflow.Size([3, 1, 5, 6])\n        >>> output[1].size()\n        oneflow.Size([3, 2, 5, 6])\n        >>> output[2].size()\n        oneflow.Size([3, 1, 5, 6])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.vsplit,\n    r\"\"\"\n    Splits input, a tensor with two or more dimensions, into multiple tensors vertically according to indices_or_sections.\n    Each split is a view of input.\n    This is equivalent to calling oneflow.tensor_split(input, indices_or_sections, dim=0) (the split dimension is 0),\n    except that if indices_or_sections is an integer it must evenly divide the split dimension or a runtime error will be thrown.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.vsplit.html.\n\n    Args:\n        input (Tensor): the input tensor.\n        indices_or_sections (int or a list): If indices_or_sections is an integer n, input is split into n sections \n            along dimension 0. If input is divisible by n along dimension 0, each section will be of equal size, \n            input.size(0) / n. If input is not divisible by n, the first int(input.size(0) % n)\n            sections will have size int(input.size(0) / n) + 1, and the rest will have size int(input.size(0) / n).\n            If indices_or_sections is a list or tuple of ints, then input is split along dimension 0 at each of the indices in \n            the list, tuple or tensor. For instance, indices_or_sections=[2, 3] would result in the tensors \n            input[:2], input[2:3], and input[3:]. If indices_or_sections is a tensor, it must be a zero-dimensional or\n            one-dimensional long tensor on the CPU.\n\n    Returns:\n        oneflow.TensorTuple: the output TensorTuple.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> input = flow.rand(4, 4, 5, 6)\n        >>> output = flow.vsplit(input, (1, 3))\n        >>> output[0].size()\n        oneflow.Size([1, 4, 5, 6])\n        >>> output[1].size()\n        oneflow.Size([2, 4, 5, 6])\n        >>> output[2].size()\n        oneflow.Size([1, 4, 5, 6])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.cumsum,\n    r\"\"\"oneflow.cumsum(input, dim) -> Tensor\n    \n    This operator computes the cumulative sum of input elements in the given dimension.\n\n    The equation is:\n\n\t$$\n        y_{i}=x_{0}+x_{1}+...+x_{i}\n\t$$\n\n    Args:\n        input (Tensor):  the input ND tensor.\n        dim (int):  the dimension to do cumsum, whose valid range is [-N, N-1], where N is the number of dimensions of the tensor\n\n    Returns:\n        oneflow.Tensor: The result tensor with cumsum result.\n\n    For example:\n\n    .. 
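code-block:: python\n\n        # An additional deterministic sketch (illustrative values): accumulating\n        # along dim 0 sums down the rows instead of across the columns.\n        >>> import oneflow as flow\n        >>> flow.cumsum(flow.ones(3, 3), 0)\n        tensor([[1., 1., 1.],\n                [2., 2., 2.],\n                [3., 3., 3.]], dtype=oneflow.float32)\n\n    .. 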
code-block:: python\n\n        >>> import oneflow as flow\n        >>> input = flow.ones(3, 3)\n        >>> dim = 1\n        >>> flow.cumsum(input, dim)\n        tensor([[1., 2., 3.],\n                [1., 2., 3.],\n                [1., 2., 3.]], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.cumprod,\n    \"\"\"oneflow.cumprod(input, dim) -> Tensor\n\n    This operator computes the cumulative product of input elements in the given dimension.\n\n    The equation is:\n\n\t$$\n        y_{i}=x_{0}*x_{1}*...*x_{i}\n\t$$\n\n    Args:\n        input (Tensor):  the input tensor.\n        dim (int):  the dimension to do cumprod, whose valid range is [-N, N-1], where N is the number of dimensions of the tensor\n\n    Returns:\n        oneflow.Tensor: The result tensor with cumprod result.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> input=flow.tensor([1, 2, 3])\n        >>> flow.cumprod(input, dim=0)\n        tensor([1, 2, 6], dtype=oneflow.int64)\n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow.trunc,\n    r\"\"\"trunc(input) -> Tensor\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.trunc.html\n\n    Returns a new tensor with the truncated integer values of\n    the elements of :attr:`input`.\n\n    Args:\n        input(Tensor): the input tensor.\n\n    Example::\n\n        >>> import oneflow as flow\n        >>> a = flow.tensor([ 3.4742,  0.5466, -0.8008, -0.9079])\n        >>> flow.trunc(a)\n        tensor([3., 0., -0., -0.], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.digamma,\n    r\"\"\"digamma(input) -> Tensor\n\n    Computes the logarithmic derivative of the gamma function on :attr:`input`:\n\n    .. math::\n        \\digamma(x) = \\frac{d}{dx} \\ln\\left(\\Gamma\\left(x\\right)\\right) = \\frac{\\Gamma'(x)}{\\Gamma(x)}\n\n    Args:\n        input (Tensor): the tensor to compute the digamma function on\n\n    .. note::  This function is similar to SciPy's `scipy.special.digamma`.\n\n    Example::\n\n        >>> import oneflow as flow\n        >>> a = flow.tensor([1, 0.5])\n        >>> flow.digamma(a)\n        tensor([-0.5772, -1.9635], dtype=oneflow.float32)\n\n    \"\"\",\n)\n"
  },
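The running-sum and running-product definitions above are easy to verify by hand on a 1-D tensor. A minimal doctest-style sketch (the expected outputs follow directly from the formulas, not from a recorded test run):

.. code-block:: python

    >>> import oneflow as flow
    >>> x = flow.tensor([1., 2., 3.])
    >>> flow.cumsum(x, dim=0)   # [1, 1+2, 1+2+3]
    tensor([1., 3., 6.], dtype=oneflow.float32)
    >>> flow.cumprod(x, dim=0)  # [1, 1*2, 1*2*3]
    tensor([1., 2., 6.], dtype=oneflow.float32)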
  {
    "path": "python/oneflow/framework/docstr/meshgrid.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.meshgrid,\n    \"\"\"\n    Take :math:`N` tensors, each of which can be either scalar or 1-dimensional\n    vector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by\n    expanding the :math:`i` :sup:`th` input over dimensions defined by other inputs.\n    \n    The interface is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.meshgrid.html#torch.meshgrid\n\n    Args:\n        tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be\n            treated as tensors of size :math:`(1,)` automatically.  \n        indexing ((string, optional): the indexing mode, either \"xy\" or \"ij\", defaults to \"ij\".\n            If \"ij\" is selected, the dimensions are in the same order as the cardinality of the inputs.\n            If \"xy\" is selected, the first dimension corresponds to the cardinality of \n            the second input and the second dimension corresponds to the cardinality of the first input.\n\n    Returns:\n        seq (sequence of Tensors): If the input has :math:`k` tensors of size\n        :math:`(N_1,), (N_2,), \\\\ldots , (N_k,)`, then the output would also have :math:`k` tensors,\n        where all tensors are of size :math:`(N_1, N_2, \\\\ldots , N_k)`.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> input1 = flow.tensor(np.array([2, 2, 3]), dtype=flow.float32)\n        >>> input2 = flow.tensor(np.array([4, 5, 6]), dtype=flow.float32)\n        >>> of_x, of_y = flow.meshgrid(input1, input2)\n        >>> of_x\n        tensor([[2., 2., 2.],\n                [2., 2., 2.],\n                [3., 3., 3.]], dtype=oneflow.float32)\n        >>> of_y\n        tensor([[4., 5., 6.],\n                [4., 5., 6.],\n                [4., 5., 6.]], dtype=oneflow.float32)\n    \"\"\",\n)\n"
  },
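The meshgrid docstring documents both indexing modes but only demonstrates the default "ij" mode. A hedged sketch of "xy" mode, assuming it mirrors torch.meshgrid(..., indexing="xy") and swaps the first two output dimensions relative to "ij":

.. code-block:: python

    >>> import oneflow as flow
    >>> input1 = flow.tensor([2., 2., 3.])
    >>> input2 = flow.tensor([4., 5., 6.])
    >>> of_x, of_y = flow.meshgrid(input1, input2, indexing="xy")
    >>> of_x
    tensor([[2., 2., 3.],
            [2., 2., 3.],
            [2., 2., 3.]], dtype=oneflow.float32)
    >>> of_y
    tensor([[4., 4., 4.],
            [5., 5., 5.],
            [6., 6., 6.]], dtype=oneflow.float32)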
  {
    "path": "python/oneflow/framework/docstr/module.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.nn.Module.to_consistent,\n    \"\"\"\n    This interface is no longer available, please use :func:`oneflow.nn.Module.to_global` instead.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.nn.Module.to_global,\n    \"\"\"\n    Convert the parameters and buffers to global.\n\n    It performs the same :func:`oneflow.Tensor.to_global` conversion to each parameter and buffer in this module.\n\n\n    Note:\n        This method modifies the module in-place.\n\n        Both placement and sbp are required if the parameters and buffers of this module are local,\n        otherwise at least one of placement and sbp is required.\n\n    Args:\n        placement (flow.placement, optional): the desired placement of the parameters and buffers in this module. Default: None\n        sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp of the parameters and buffers in this module. Default: None\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> m = flow.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3)\n        >>> m.to_global(placement=flow.placement(\"cpu\", ranks=[0]), sbp=[flow.sbp.split(0)])\n        >>> m.weight.is_global\n        True\n        >>> m.bias.is_global\n        True\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/nms.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.nms,\n    \"\"\"\n    Performs non-maximum suppression (NMS) on the boxes according\n    to their intersection-over-union (IoU).\n\n    NMS iteratively removes lower scoring boxes which have an\n    IoU greater than iou_threshold with another (higher scoring)\n    box.\n\n    Args:\n        boxes (Tensor[N, 4]): boxes to perform NMS on. They\n            are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and\n            ``0 <= y1 < y2``.\n        scores (Tensor[N]): scores for each one of the boxes\n        iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold\n\n    Returns:\n        Tensor: int64 tensor with the indices of the elements that have been kept by NMS, sorted in decreasing order of scores\n    \"\"\",\n)\n"
  },
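Since the nms docstring ships without an example, here is a hedged sketch; the boxes are chosen so the IoU can be checked by hand. Boxes 0 and 1 overlap with IoU 81/119 ≈ 0.68 > 0.5, so the lower-scoring box 1 should be suppressed, and the expected output follows from that arithmetic:

.. code-block:: python

    >>> import oneflow as flow
    >>> boxes = flow.tensor([[0., 0., 10., 10.],
    ...                      [1., 1., 11., 11.],
    ...                      [20., 20., 30., 30.]])
    >>> scores = flow.tensor([0.9, 0.8, 0.7])
    >>> flow.nms(boxes, scores, iou_threshold=0.5)
    tensor([0, 2], dtype=oneflow.int64)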
  {
    "path": "python/oneflow/framework/docstr/nonzero.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.nonzero,\n    \"\"\"nonzero(input, *, out=None, as_tuple=False) -> Tensor or tuple of Tensors\n\n    .. note::\n        When :attr:`as_tuple` is ``False`` (default):  returns a\n        2-D tensor where each row is the index for a nonzero value.\n\n        When :attr:`as_tuple` is ``True``: returns a tuple of 1-D\n        index tensors, allowing for advanced indexing, so ``x[x.nonzero(as_tuple=True)]``\n        gives all nonzero values of tensor ``x``. Of the returned tuple, each index tensor\n        contains nonzero indices for a certain dimension.\n\n        See below for more details on the two behaviors.\n\n    **When** :attr:`as_tuple` **is** ``False`` **(default)**:\n\n    Returns a tensor containing the indices of all non-zero elements of\n    :attr:`input`.  Each row in the result contains the indices of a non-zero\n    element in :attr:`input`. The result is sorted lexicographically, with\n    the last index changing the fastest (C-style).\n\n    If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor\n    :attr:`out` is of size :math:`(z \\\\times n)`, where :math:`z` is the total number of\n    non-zero elements in the :attr:`input` tensor.\n\n    **When** :attr:`as_tuple` **is** ``True``:\n\n    Returns a tuple of 1-D tensors, one for each dimension in :attr:`input`,\n    each containing the indices (in that dimension) of all non-zero elements of\n    :attr:`input` .\n\n    If :attr:`input` has :math:`n` dimensions, then the resulting tuple contains :math:`n`\n    tensors of size :math:`z`, where :math:`z` is the total number of\n    non-zero elements in the :attr:`input` tensor.\n\n    As a special case, when :attr:`input` has zero dimensions and a nonzero scalar\n    value, it is treated as a one-dimensional tensor with one element.\n\n    Args:\n        input(Tensor): the input tensor.\n\n    Keyword args:\n        out (Tensor, optional): the output tensor containing indices\n\n    Returns:\n        Tensor or tuple of Tensors: If :attr:`as_tuple` is ``False``, the output\n        tensor containing indices. If :attr:`as_tuple` is ``True``, one 1-D tensor for\n        each dimension, containing the indices of each nonzero element along that\n        dimension.\n\n    Example::\n\n        >>> import oneflow as flow\n        >>> flow.nonzero(flow.tensor([1, 1, 1, 0, 1]))\n        tensor([[0],\n                [1],\n                [2],\n                [4]], dtype=oneflow.int64)\n        >>> flow.nonzero(flow.tensor([[0.6, 0.0, 0.0, 0.0],\n        ...                             [0.0, 0.4, 0.0, 0.0],\n        ...                             [0.0, 0.0, 1.2, 0.0],\n        ...                             
[0.0, 0.0, 0.0,-0.4]]))\n        tensor([[0, 0],\n                [1, 1],\n                [2, 2],\n                [3, 3]], dtype=oneflow.int64)\n        >>> flow.nonzero(flow.tensor([1, 1, 1, 0, 1]), as_tuple=True)\n        (tensor([0, 1, 2, 4], dtype=oneflow.int64),)\n        >>> flow.nonzero(flow.tensor([[0.6, 0.0, 0.0, 0.0],\n        ...                             [0.0, 0.4, 0.0, 0.0],\n        ...                             [0.0, 0.0, 1.2, 0.0],\n        ...                             [0.0, 0.0, 0.0,-0.4]]), as_tuple=True)\n        (tensor([0, 1, 2, 3], dtype=oneflow.int64), tensor([0, 1, 2, 3], dtype=oneflow.int64))\n        >>> flow.nonzero(flow.tensor(5), as_tuple=True)\n        (tensor([0], dtype=oneflow.int64),)\n\n    \"\"\",\n)\n"
  },
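The advanced-indexing identity stated in the note above, ``x[x.nonzero(as_tuple=True)]``, can be spot-checked directly. A small sketch, assuming tensors accept a tuple of index tensors exactly as the note says:

.. code-block:: python

    >>> import oneflow as flow
    >>> x = flow.tensor([[0.6, 0.0],
    ...                  [0.0, 0.4]])
    >>> x[x.nonzero(as_tuple=True)]
    tensor([0.6000, 0.4000], dtype=oneflow.float32)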
  {
    "path": "python/oneflow/framework/docstr/norm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\n\nadd_docstr(\n    oneflow.linalg.vector_norm,\n    \"\"\"linalg.vector_norm(input, ord=2, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor\n\n    Computes a vector norm.\n\n    Supports input of float, double dtypes.\n\n    This function does not necessarily treat multidimensonal attr:`input` as a batch of\n    vectors, instead:\n\n    - If :attr:`dim`\\\\ `= None`, :attr:`input` will be flattened before the norm is computed.\n    - If :attr:`dim` is an `int` or a `tuple`, the norm will be computed over these dimensions and the other dimensions will be treated as batch dimensions.\n\n    This behavior is for consistency with :func:`flow.linalg.norm`.\n\n    :attr:`ord` defines the vector norm that is computed. The following norms are supported:\n\n    ======================   ========================================================\n    :attr:`ord`              vector norm\n    ======================   ========================================================\n    `2` (default)            `2`-norm (see below)\n    `inf`                    `max(abs(x))`\n    `-inf`                   `min(abs(x))`\n    `0`                      `sum(x != 0)`\n    other `int` or `float`   `sum(abs(x)^{ord})^{(1 / ord)}`\n    ======================   ========================================================\n\n    where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.\n\n    Args:\n        input (Tensor): tensor, flattened by default, but this behavior can be\n            controlled using :attr:`dim`.\n        ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `2`\n        dim (int, Tuple[int], optional): dimensions over which to compute\n            the norm. See above for the behavior when :attr:`dim`\\\\ `= None`.\n            Default: `None`\n        keepdim (bool, optional): If set to `True`, the reduced dimensions are retained\n            in the result as dimensions with size one. Default: `False`\n\n    Returns:\n        A real-valued tensor.\n\n    Examples:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> from oneflow import linalg as LA\n        >>> import numpy as np\n        >>> a = flow.tensor(np.arange(9, dtype=np.float32) - 4)\n        >>> a\n        tensor([-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.], dtype=oneflow.float32)\n        >>> b = a.reshape(3, 3)\n        >>> b\n        tensor([[-4., -3., -2.],\n                [-1.,  0.,  1.],\n                [ 2.,  3.,  4.]], dtype=oneflow.float32)\n        >>> LA.vector_norm(a, ord=3.5)\n        tensor(5.4345, dtype=oneflow.float32)\n        >>> LA.vector_norm(b, ord=3.5)\n        tensor(5.4345, dtype=oneflow.float32)\n    \n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow.linalg.matrix_norm,\n    \"\"\"linalg.matrix_norm(input, ord='fro', dim=(-2, -1), keepdim=False, *, dtype=None, out=None) -> Tensor\n\n    Computes a matrix norm.\n\n    Supports input of float, double, cfloat and cdouble dtypes.\n    Also supports batches of matrices: the norm will be computed over the\n    dimensions specified by the 2-tuple :attr:`dim` and the other dimensions will\n    be treated as batch dimensions. The output will have the same batch dimensions.\n\n    :attr:`ord` defines the matrix norm that is computed. The following norms are supported:\n\n    ======================   ========================================================\n    :attr:`ord`              matrix norm\n    ======================   ========================================================\n    `'fro'` (default)        Frobenius norm\n    `'nuc'`                  -- not supported yet --\n    `inf`                    `max(sum(abs(x), dim=1))`\n    `-inf`                   `min(sum(abs(x), dim=1))`\n    `1`                      `max(sum(abs(x), dim=0))`\n    `-1`                     `min(sum(abs(x), dim=0))`\n    `2`                      -- not supported yet --\n    `-2`                     -- not supported yet --\n    ======================   ========================================================\n\n    where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.\n\n    Args:\n        input (Tensor): tensor with two or more dimensions. By default its\n            shape is interpreted as `(*, m, n)` where `*` is zero or more\n            batch dimensions, but this behavior can be controlled using :attr:`dim`.\n        ord (int, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `'fro'`\n        dim (Tuple[int, int], optional): dimensions over which to compute the norm. Default: `(-2, -1)`\n        keepdim (bool, optional): If set to `True`, the reduced dimensions are retained\n            in the result as dimensions with size one. Default: `False`\n\n    Returns:\n        A real-valued tensor.\n\n    Examples:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> from oneflow import linalg as LA\n        >>> import numpy as np\n        >>> a = flow.tensor(np.arange(9, dtype=np.float32)).reshape(3,3)\n        >>> a\n        tensor([[0., 1., 2.],\n                [3., 4., 5.],\n                [6., 7., 8.]], dtype=oneflow.float32)\n        >>> LA.matrix_norm(a)\n        tensor(14.2829, dtype=oneflow.float32)\n        >>> LA.matrix_norm(a, ord=-1)\n        tensor(9., dtype=oneflow.float32)\n        >>> b = a.expand(2, -1, -1)\n        >>> b\n        tensor([[[0., 1., 2.],\n                 [3., 4., 5.],\n                 [6., 7., 8.]],\n        <BLANKLINE>\n                [[0., 1., 2.],\n                 [3., 4., 5.],\n                 [6., 7., 8.]]], dtype=oneflow.float32)\n        >>> LA.matrix_norm(b, dim=(0, 2))\n        tensor([ 3.1623, 10.0000, 17.2627], dtype=oneflow.float32)\n    \n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow.linalg.norm,\n    \"\"\"linalg.norm(input, ord=None, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor\n\n    Returns the matrix norm or vector norm of a given tensor.\n\n    This function can calculate one of eight different types of matrix norms, or one\n    of an infinite number of vector norms, depending on both the number of reduction\n    dimensions and the value of the `ord` parameter.\n\n    Args:\n        input (Tensor): The input tensor. If dim is None, input must be 1-D or 2-D, unless :attr:`ord`\n            is None. If both :attr:`dim` and :attr:`ord` are None, the 2-norm of the input flattened to 1-D\n            will be returned. Its data type must be either a floating point or complex type. For complex\n            inputs, the norm is calculated on the absolute values of each element. If the input is\n            complex and neither :attr:`dtype` nor :attr:`out` is specified, the result's data type will\n            be the corresponding floating point type (e.g. float if :attr:`input` is complexfloat).\n\n        ord (int, inf, -inf, 'fro', 'nuc', optional): order of norm. 
Default: ``None``\n            The following norms can be calculated:\n\n            ==============  ============================  =================================\n            :attr:`ord`       norm for matrices             norm for vectors\n            ==============  ============================  =================================\n            None             Frobenius norm                `2`-norm\n            `'fro'`          Frobenius norm                -- not supported --\n            `'nuc'`          -- not supported yet --       -- not supported --\n            `inf`            `max(sum(abs(x), dim=1))`     `max(abs(x))`\n            `-inf`           `min(sum(abs(x), dim=1))`     `min(abs(x))`\n            `0`              -- not supported --           `sum(x != 0)`\n            `1`              `max(sum(abs(x), dim=0))`     as below\n            `-1`             `min(sum(abs(x), dim=0))`     as below\n            `2`              -- not supported yet --       as below\n            `-2`             -- not supported yet --       as below\n            other            -- not supported --           `sum(abs(x)^{ord})^{(1 / ord)}`\n            ==============  ============================  =================================\n\n            where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.\n\n        dim (int, 2-tuple of ints, 2-list of ints, optional): If :attr:`dim` is an int,\n            vector norm will be calculated over the specified dimension. If :attr:`dim`\n            is a 2-tuple of ints, matrix norm will be calculated over the specified\n            dimensions. If :attr:`dim` is None, matrix norm will be calculated\n            when the input tensor has two dimensions, and vector norm will be\n            calculated when the input tensor has one dimension. Default: ``None``\n\n        keepdim (bool, optional): If set to True, the reduced dimensions are retained\n            in the result as dimensions with size one. Default: ``False``\n\n        out (Tensor, optional): The output tensor.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> from oneflow import linalg as LA\n        >>> import numpy as np\n        >>> a = flow.tensor(np.arange(9, dtype=np.float32) - 4)\n        >>> a\n        tensor([-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.], dtype=oneflow.float32)\n        >>> b = a.reshape(3, 3)\n        >>> b\n        tensor([[-4., -3., -2.],\n                [-1.,  0.,  1.],\n                [ 2.,  3.,  4.]], dtype=oneflow.float32)\n        >>> LA.norm(a)\n        tensor(7.7460, dtype=oneflow.float32)\n        >>> LA.norm(b)\n        tensor(7.7460, dtype=oneflow.float32)\n        >>> LA.norm(b, 'fro')\n        tensor(7.7460, dtype=oneflow.float32)\n        >>> LA.norm(a, float('inf'))\n        tensor(4., dtype=oneflow.float32)\n        >>> LA.norm(b, float('inf'))\n        tensor(9., dtype=oneflow.float32)\n        >>> LA.norm(a, -float('inf'))\n        tensor(0., dtype=oneflow.float32)\n        >>> LA.norm(b, -float('inf'))\n        tensor(2., dtype=oneflow.float32)\n        >>> LA.norm(a, 1)\n        tensor(20., dtype=oneflow.float32)\n        >>> LA.norm(b, 1)\n        tensor(7., dtype=oneflow.float32)\n        >>> LA.norm(a, -1)\n        tensor(0., dtype=oneflow.float32)\n        >>> LA.norm(b, -1)\n        tensor(6., dtype=oneflow.float32)\n        >>> LA.norm(a, 2)\n        tensor(7.7460, dtype=oneflow.float32)\n        >>> LA.norm(a, -2)\n        tensor(0., dtype=oneflow.float32)\n        >>> LA.norm(a, 3)\n        tensor(5.8480, dtype=oneflow.float32)\n        >>> LA.norm(a, -3)\n        tensor(0., dtype=oneflow.float32)\n        >>> c = flow.tensor([[1., 2., 3.],\n        ...                   [-1, 1, 4]])\n        >>> LA.norm(c, dim=0)\n        tensor([1.4142, 2.2361, 5.0000], dtype=oneflow.float32)\n        >>> LA.norm(c, dim=1, keepdim = True)\n        tensor([[3.7417],\n                [4.2426]], dtype=oneflow.float32)\n        >>> LA.norm(c, ord=1, dim=1)\n        tensor([6., 6.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.normalize,\n    \"\"\"nn.functional.normalize(input: Tensor, p: float=2.0, dim: int=0, epsilon: float=1e-12) -> Tensor\n\n    Performs :math:`L_p` normalization of inputs over the specified dimension.\n\n    For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each\n    :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`dim` is transformed as:\n\n    .. math::\n        v = \\\\frac{v}{\\\\max(\\\\lVert v \\\\rVert_p, \\\\epsilon)}.\n\n    With the default arguments it uses the Euclidean norm over vectors along dimension :math:`0` for normalization.\n\n    Note that the gradient calculation of the input tensor differs across frameworks\n    when `input.shape[dim] == 1`.\n\n    Args:\n        input (oneflow.Tensor): input tensor of any shape\n        p (float): the exponent value in the norm formulation. Default: 2\n        dim (int): the dimension to reduce. Default: 0\n        epsilon (float): small value to avoid division by zero. Default: 1e-12\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([[1, 2], [3, 4]], dtype=flow.float32)\n        >>> out = flow.nn.functional.normalize(x, 2, 0)\n        >>> out\n        tensor([[0.3162, 0.4472],\n                [0.9487, 0.8944]], dtype=oneflow.float32)\n        >>> out = flow.nn.functional.normalize(x, 2, 1)\n        >>> out\n        tensor([[0.4472, 0.8944],\n                [0.6000, 0.8000]], dtype=oneflow.float32)\n\n    \"\"\",\n)\n"
  },
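The normalize formula above can be cross-checked against linalg.vector_norm: dividing each row by its 2-norm reproduces the dim=1 output shown in the normalize docstring. A sketch under that assumption:

.. code-block:: python

    >>> import oneflow as flow
    >>> x = flow.tensor([[1., 2.], [3., 4.]])
    >>> norms = flow.linalg.vector_norm(x, ord=2, dim=1, keepdim=True)
    >>> x / norms  # same result as flow.nn.functional.normalize(x, 2, 1)
    tensor([[0.4472, 0.8944],
            [0.6000, 0.8000]], dtype=oneflow.float32)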
  {
    "path": "python/oneflow/framework/docstr/normalization.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.nn.functional.layer_norm,\n    \"\"\"nn.functional.layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05) -> Tensor\n\n    Applies Layer Normalization for last certain number of dimensions.\n\n    See :class:`~oneflow.nn.LayerNorm` for details.\n\n    \"\"\",\n)\n"
  },
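nn.functional.layer_norm carries no example above, so here is a minimal hand-checkable sketch: for the row [1, 2, 3] the mean is 2 and the biased variance is 2/3, so the normalized values come out at roughly ±1.2247 around zero:

.. code-block:: python

    >>> import oneflow as flow
    >>> x = flow.tensor([[1., 2., 3.]])
    >>> flow.nn.functional.layer_norm(x, (3,))  # normalize over the last dimension
    tensor([[-1.2247,  0.0000,  1.2247]], dtype=oneflow.float32)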
  {
    "path": "python/oneflow/framework/docstr/oneflow.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.set_num_threads,\n    \"\"\"\n    Sets the number of threads used for intraop parallelism on CPU.\n    \n    .. WARNING::\n        To ensure that the correct number of threads is used, \n        set_num_threads must be called before running eager, eager globe or ddp.\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.get_default_dtype,\n    \"\"\"oneflow.get_default_dtype() -> oneflow._oneflow_internal.dtype\n\n    Returns the default floating point dtype.\n\n    Returns:\n        oneflow.dtype: The default floating point dtype.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> flow.set_default_dtype(flow.float32)\n        >>> flow.get_default_dtype()\n        oneflow.float32\n        >>> flow.set_default_dtype(flow.float64)\n        >>> flow.get_default_dtype()\n        oneflow.float64\n        >>> flow.set_default_tensor_type(flow.FloatTensor)\n        >>> flow.get_default_dtype()\n        oneflow.float32\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.set_default_dtype,\n    \"\"\"oneflow.set_default_dtype() -> None\n\n    Sets the default floating point type for those source operators which create Tensor.\n\n    The default floating point type is ``oneflow.float32``.\n\n    Args:\n        dtype (oneflow.dtype): The floating point dtype.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow\n        >>> oneflow.set_default_dtype(oneflow.float64)\n        >>> x = oneflow.randn(2, 3)\n        >>> x.dtype\n        oneflow.float64\n        >>> oneflow.set_default_dtype(oneflow.float32)\n        >>> x = oneflow.randn(2, 3)\n        >>> x.dtype\n        oneflow.float32\n    \"\"\",\n)\n"
  },
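set_num_threads has no usage example; per its own warning it must run before any eager work, so the natural pattern is to call it right after import. A minimal sketch (the thread count 4 is an arbitrary illustration):

.. code-block:: python

    >>> import oneflow as flow
    >>> flow.set_num_threads(4)  # must precede any eager / eager global / ddp execution
    >>> y = flow.ones(2, 2) + 1  # subsequent CPU ops may use up to 4 intraop threads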
  {
    "path": "python/oneflow/framework/docstr/onehot.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow._C.one_hot,\n    r\"\"\"\n    one_hot(input, num_classes=-1, on_value=1, off_value=0)\n    This operator generates a onehot Tensor from input Tensor.\n\n    If input Tensor's rank is `N`, the corresponding onehot Tensor's rank is `N+1`.\n\n    Args:\n        input (Tensor): The input Tensor.\n        num_classes (int): The length of onehot Tensor.\n        on_value (Union[int, float], optional): The fill value when `x[i] == i`. Defaults to 1.\n        off_value (Union[int, float], optional): The fill value when `x[i] != i`. Defaults to 0.\n    Note:\n\n        The data type of input tensor should be `int32` or `int64`.\n\n    Returns:\n        oneflow.Tensor.\n    \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> input=flow.tensor(np.array([0, 3, 1, 2]).astype(np.int64), dtype=flow.int64)\n        >>> out = flow.nn.functional.one_hot(input, num_classes=5)\n        >>> out\n        tensor([[1, 0, 0, 0, 0],\n                [0, 0, 0, 1, 0],\n                [0, 1, 0, 0, 0],\n                [0, 0, 1, 0, 0]], dtype=oneflow.int64)\n    \n    \"\"\",\n)\n"
  },
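The one_hot example leaves on_value / off_value at their defaults; a hedged sketch of the non-default fill values from the signature (the expected output follows from the one-hot definition, assuming the functional wrapper forwards both keyword arguments):

.. code-block:: python

    >>> import oneflow as flow
    >>> input = flow.tensor([0, 3, 1, 2], dtype=flow.int64)
    >>> flow.nn.functional.one_hot(input, num_classes=5, on_value=5, off_value=-1)
    tensor([[ 5, -1, -1, -1, -1],
            [-1, -1, -1,  5, -1],
            [-1,  5, -1, -1, -1],
            [-1, -1,  5, -1, -1]], dtype=oneflow.int64)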
  {
    "path": "python/oneflow/framework/docstr/pooling.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow._C.adaptive_avg_pool1d,\n    \"\"\"\n    adaptive_avg_pool1d(input, output_size) -> Tensor\n\n    Applies a 1D adaptive average pooling over an input signal composed of\n    several input planes.\n\n    See :class:`~oneflow.nn.AdaptiveAvgPool1d` for details and output shape.\n\n    Args:\n        input: the input tensor\n        output_size: the target output size (single integer)\n\n    For examples:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> arr = np.array([[[ 0.0558, -0.6875, -1.6544, -0.6226,  0.1018,  0.0502, -1.2538, 0.1491]]])\n        >>> input = flow.tensor(arr, dtype=flow.float32)\n        >>> flow.nn.functional.adaptive_avg_pool1d(input, output_size=[4])\n        tensor([[[-0.3158, -1.1385,  0.0760, -0.5524]]], dtype=oneflow.float32)\n\n    \"\"\",\n)\nadd_docstr(\n    oneflow._C.adaptive_avg_pool2d,\n    \"\"\"\n    adaptive_avg_pool2d(input, output_size) -> Tensor\n\n    Applies a 2D adaptive average pooling over an input signal composed of several input planes.\n\n    See :class:`~oneflow.nn.AdaptiveAvgPool2d` for details and output shape.\n\n    Args:\n        input: the input tensor\n        output_size: the target output size (single integer or double-integer tuple)\n\n    For examples:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> arr = np.array([[[[ 0.1004,  0.0488, -1.0515,  0.9466],[ 0.4538,  0.2361,  1.3437,  0.398 ],[ 0.0558, -0.6875, -1.6544, -0.6226],[ 0.1018,  0.0502, -1.2538,  0.1491]]]])\n        >>> input = flow.tensor(arr, dtype=flow.float32)\n        >>> outputs = flow.nn.functional.adaptive_avg_pool2d(input, (2, 2))\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.adaptive_avg_pool3d,\n    \"\"\"\n    adaptive_avg_pool3d(input, output_size) -> Tensor\n\n    Applies a 3D adaptive average pooling over an input signal composed of several input planes.\n\n    See :class:`~oneflow.nn.AdaptiveAvgPool3d` for details and output shape.\n\n    Args:\n        input: the input tensor\n        output_size: the target output size (single integer or triple-integer tuple)\n\n    For examples:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow         \n        >>> import numpy as np\n\n        >>> input = flow.tensor(np.random.randn(1, 1, 4, 4, 4), dtype=flow.float32)\n        >>> output = flow.nn.functional.adaptive_avg_pool3d(input, (2, 2, 2))\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.avg_pool1d,\n    \"\"\"\n    avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor\n\n    Applies a 1D average pooling over an input signal composed of several input planes.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.avg_pool1d.html\n\n    See :class:`~oneflow.nn.AvgPool1d` for details and output shape.\n\n    Args:\n        input: input tensor of shape :math:`(\\\\text{minibatch} , \\\\text{in_channels} , iW)`\n        kernel_size: the size of the window. Can be a single number or a tuple `(kW,)`\n        stride: the stride of the window. Can be a single number or a tuple `(sW,)`. Default: :attr:`kernel_size`\n        padding: implicit zero paddings on both sides of the input. Can be a single number or a tuple `(padW,)`. Default: 0\n        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape. Default: ``False``\n        count_include_pad: when True, will include the zero-padding in the averaging calculation. Default: ``True``\n\n    Examples::\n\n        >>> # pool of square window of size=3, stride=2\n        >>> import oneflow\n        >>> input = oneflow.tensor([[[1, 2, 3, 4, 5, 6, 7]]], dtype=oneflow.float32)\n        >>> oneflow.nn.functional.avg_pool1d(input, kernel_size=3, stride=2)\n        tensor([[[2., 4., 6.]]], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.avg_pool2d,\n    \"\"\"\n    avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=0) -> Tensor\n\n    Applies 2D average-pooling operation in :math:`kH \\\\times kW` regions by step size :math:`sH \\\\times sW` steps. The number of output features is equal to the number of input planes.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.avg_pool2d.html.\n\n    See :class:`~oneflow.nn.AvgPool2d` for details and output shape.\n\n    Args:\n        input: input tensor :math:`(\\\\text{minibatch} , \\\\text{in_channels} , iH , iW)`\n        kernel_size: size of the pooling region. Can be a single number or a tuple `(kH, kW)`\n        stride: stride of the pooling operation. Can be a single number or a tuple `(sH, sW)`. Default: :attr:`kernel_size`\n        padding: implicit zero paddings on both sides of the input. Can be a single number or a tuple `(padH, padW)`. Default: 0\n        ceil_mode: when True, will use `ceil` instead of `floor` in the formula to compute the output shape. Default: ``False``\n        count_include_pad: when True, will include the zero-padding in the averaging calculation. Default: ``True``\n        divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. Default: 0\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.avg_pool3d,\n    \"\"\"\n    avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=0) -> Tensor\n\n    Applies 3D average-pooling operation in :math:`kT \\\\times kH \\\\times kW` regions by step size :math:`sT \\\\times sH \\\\times sW` steps. 
The number of output features is equal to :math:`\\\\lfloor\\\\frac{\\\\text{input planes}}{sT}\\\\rfloor`.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.functional.avg_pool3d.html\n\n    See :class:`~oneflow.nn.AvgPool3d` for details and output shape.\n\n    Args:\n        input: input tensor :math:`(\\\\text{minibatch} , \\\\text{in_channels} , iT , iH , iW)`\n        kernel_size: size of the pooling region. Can be a single number or a tuple `(kT, kH, kW)`\n        stride: stride of the pooling operation. Can be a single number or a tuple `(sT, sH, sW)`. Default: :attr:`kernel_size`\n        padding: implicit zero paddings on both sides of the input. Can be a single number or a tuple `(padT, padH, padW)`. Default: 0\n        ceil_mode: when True, will use `ceil` instead of `floor` in the formula to compute the output shape\n        count_include_pad: when True, will include the zero-padding in the averaging calculation\n        divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. Default: 0\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.max_unpool1d,\n    \"\"\"\n    max_unpool1d(input, indices, kernel_size, stride=None, padding=0, output_size=None) -> Tensor\n\n    Computes a partial inverse of ``MaxPool1d``.\n\n    See :class:`MaxUnpool1d` for details.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.max_unpool2d,\n    \"\"\"\n    max_unpool2d(input, indices, kernel_size, stride=None, padding=0, output_size=None) -> Tensor\n\n    Computes a partial inverse of ``MaxPool2d``.\n\n    See :class:`MaxUnpool2d` for details.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.max_unpool3d,\n    \"\"\"\n    max_unpool3d(input, indices, kernel_size, stride=None, padding=0, output_size=None) -> Tensor\n\n    Computes a partial inverse of ``MaxPool3d``.\n\n    See :class:`MaxUnpool3d` for details.\n    \"\"\",\n)\n"
  },
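The max_unpool functions above carry no examples. A sketch of the usual pool/unpool round trip, assuming flow.nn.functional.max_pool1d supports return_indices=True as in PyTorch; non-maximum positions come back as zeros:

.. code-block:: python

    >>> import oneflow as flow
    >>> x = flow.tensor([[[1., 2., 3., 4.]]])
    >>> out, indices = flow.nn.functional.max_pool1d(x, kernel_size=2, return_indices=True)
    >>> out
    tensor([[[2., 4.]]], dtype=oneflow.float32)
    >>> flow.nn.functional.max_unpool1d(out, indices, kernel_size=2)
    tensor([[[0., 2., 0., 4.]]], dtype=oneflow.float32)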
  {
    "path": "python/oneflow/framework/docstr/quantile.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.quantile,\n    \"\"\"\n    quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor\n\n    The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.quantile.html.\n\n    Computes the q-th quantiles of each row of the :attr:`input` tensor along the dimension :attr:`dim`.\n\n    To compute the quantile, we map q in [0, 1] to the range of indices [0, n] to find the location\n    of the quantile in the sorted input. If the quantile lies between two data points ``a < b`` with\n    indices ``i`` and ``j`` in the sorted order, result is computed according to the given\n    :attr:`interpolation` method as follows:\n\n    - ``linear``: ``a + (b - a) * fraction``, where ``fraction`` is the fractional part of the computed quantile index.\n    - ``lower``: ``a``.\n    - ``higher``: ``b``.\n    - ``nearest``: ``a`` or ``b``, whichever's index is closer to the computed quantile index (rounding down for .5 fractions).\n    - ``midpoint``: ``(a + b) / 2``.\n\n    If :attr:`q` is a 1D tensor, the first dimension of the output represents the quantiles and has size\n    equal to the size of :attr:`q`, the remaining dimensions are what remains from the reduction.\n\n    .. note::\n        By default :attr:`dim` is ``None`` resulting in the :attr:`input` tensor being flattened before computation.\n \n    Args:\n        input (oneflow.Tensor): the input Tensor.\n        q (float or oneflow.Tensor): a scalar or 1D tensor of values in the range [0, 1].\n        dim (int, optional): the dimension to reduce. Default is None.\n        keepdim (bool, optional): whether the output tensor has dim retained or not. Default is False\n        interpolation (str, optional): interpolation method to use when the desired quantile lies between two data points.\n                                Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``.\n                                Default is ``linear``.\n        out (oneflow.Tensor, optional): the output Tensor.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> a = flow.arange(8.)\n        >>> q = flow.tensor([0.25, 0.5, 0.75])\n        >>> flow.quantile(a, q, dim=0, keepdim=True)\n        tensor([[1.7500],\n                [3.5000],\n                [5.2500]], dtype=oneflow.float32)\n        >>> a = flow.arange(4.)\n        >>> flow.quantile(a, 0.6, interpolation=\"linear\")\n        tensor(1.8000, dtype=oneflow.float32)\n        >>> flow.quantile(a, 0.6, interpolation=\"lower\")\n        tensor(1., dtype=oneflow.float32)\n        >>> flow.quantile(a, 0.6, interpolation=\"higher\")\n        tensor(2., dtype=oneflow.float32)\n        >>> flow.quantile(a, 0.6, interpolation=\"midpoint\")\n        tensor(1.5000, dtype=oneflow.float32)\n        >>> flow.quantile(a, 0.6, interpolation=\"nearest\")\n        tensor(2., dtype=oneflow.float32)\n        >>> flow.quantile(a, 0.4, interpolation=\"nearest\")\n        tensor(1., dtype=oneflow.float32)\n\n    \"\"\",\n)\n"
  },
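The ``linear`` rule above can be verified by hand: for a = [0, 1, ..., 7] and q = 0.25 the virtual index is 0.25 * 7 = 1.75, so a = 1, b = 2, fraction = 0.75, giving 1.75 exactly as in the first example. A sketch of that arithmetic:

.. code-block:: python

    >>> import oneflow as flow
    >>> a = flow.arange(8.)
    >>> 0.25 * (a.numel() - 1)  # virtual index: between sorted indices 1 and 2
    1.75
    >>> a[1] + (a[2] - a[1]) * 0.75  # a + (b - a) * fraction
    tensor(1.7500, dtype=oneflow.float32)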
  {
    "path": "python/oneflow/framework/docstr/random.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.bernoulli,\n    \"\"\"\n    bernoulli(input, p, *, generator=None, out=None)\n    \n    This operator returns a Tensor with binaray random numbers (0 / 1) from a Bernoulli distribution.\n\n    Args:\n        input (Tensor): the input tensor of probability values for the Bernoulli distribution\n        p (float, optional): the probability for the Bernoulli distribution. If specified, Bernoulli distribution will use p for sampling, not input\n        generator (Generator, optional): a pseudorandom number generator for sampling\n        out (Tensor, optional): the output tensor.\n\n    Shape:\n        - Input: :math:`(*)`. Input can be of any shape\n        - Output: :math:`(*)`. Output is of the same shape as input\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> arr = np.array(\n        ...    [\n        ...        [1.0, 1.0, 1.0],\n        ...        [1.0, 1.0, 1.0],\n        ...        [1.0, 1.0, 1.0],\n        ...    ]\n        ... )\n        >>> x = flow.tensor(arr, dtype=flow.float32)\n        >>> y = flow.bernoulli(x)\n        >>> y\n        tensor([[1., 1., 1.],\n                [1., 1., 1.],\n                [1., 1., 1.]], dtype=oneflow.float32)\n        >>> y = flow.bernoulli(x, 1)\n        >>> y\n        tensor([[1., 1., 1.],\n                [1., 1., 1.],\n                [1., 1., 1.]], dtype=oneflow.float32)\n        >>> y = flow.bernoulli(x, p=0)\n        >>> y\n        tensor([[0., 0., 0.],\n                [0., 0., 0.],\n                [0., 0., 0.]], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.randn,\n    \"\"\"\n    randn(*size, *, dtype=None, generator=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor\n\n    Returns a tensor filled with random numbers from a normal distribution with mean 0 and variance 1 (also called the standard normal distribution).\n\n    The shape of the tensor is defined by the variable argument ``size``.\n\n    Args:\n        size (int... or oneflow.Size): Defining the shape of the output tensor.\n          Can be a variable number of arguments or a collection like a list or tuple or oneflow.Size.\n        dtype (flow.dtype, optional): The desired data type of returned tensor. Default: ``flow.float32``.\n        generator (flow.Generator, optional): a pseudorandom number generator for sampling\n        device (flow.device, optional): The desired device of returned local tensor. If None, uses the\n          current device.\n        placement (flow.placement, optional): The desired device of returned global tensor. If None, will\n          construct local tensor.\n        sbp (flow.sbp, optional): The desired sbp of returned global tensor. 
Its length must be equal to the\n          number of dimensions of placement.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.randn(3,3) # construct local tensor\n        >>> x.shape\n        oneflow.Size([3, 3])\n        >>> x.is_global\n        False\n        >>> placement = flow.placement(\"cpu\", ranks=[0])\n        >>> sbp = flow.sbp.broadcast\n        >>> x = flow.randn(3,3,placement=placement,sbp=sbp) # construct global tensor\n        >>> x.is_global\n        True\n\n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow._C.randn_like,\n    \"\"\"\n    randn_like(input, *, dtype=None, generator=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor\n\n    Returns a tensor with the same size as `input` that is filled with random numbers from a normal distribution with mean 0 and variance 1.\n    flow.randn_like(input) is equivalent to flow.randn(input.size(), dtype=input.dtype, device=input.device).\n\n    Args:\n        input (oneflow.Tensor): the size of ``input`` will determine size of the output tensor.\n        dtype (flow.dtype, optional): The desired data type of returned tensor. Defaults to the dtype of `input`.\n        generator (flow.Generator, optional): a pseudorandom number generator for sampling\n        device (flow.device, optional): The desired device of returned local tensor. If None, defaults to the device of `input`.\n        placement (flow.placement, optional): The desired device of returned global tensor. If None, will\n          construct local tensor.\n        sbp (flow.sbp, optional): The desired sbp of returned global tensor. Its length must be equal to the\n          number of dimensions of placement. If None, will construct local tensor.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.randn(3,3) # construct local tensor\n        >>> y = flow.randn_like(x)\n        >>> y.shape\n        oneflow.Size([3, 3])\n        >>> y.is_global\n        False\n        >>> placement = flow.placement(\"cpu\", ranks=[0])\n        >>> sbp = flow.sbp.broadcast\n        >>> z = flow.randn_like(y, placement=placement, sbp=sbp) # construct global tensor\n        >>> z.is_global\n        True\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.rand,\n    \"\"\"\n    rand(*size, *, dtype=None, generator=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor\n\n    Returns a tensor filled with random numbers from a uniform distribution on the interval [0, 1).\n\n    The shape of the tensor is defined by the variable argument ``size``.\n\n    Args:\n        size (int... or oneflow.Size): Defining the shape of the output tensor.\n          Can be a variable number of arguments or a collection like a list or tuple or oneflow.Size.\n        dtype (flow.dtype, optional): The desired data type of returned tensor. Default: ``flow.float32``.\n        generator (flow.Generator, optional): a pseudorandom number generator for sampling\n        device (flow.device, optional): The desired device of returned local tensor. If None, uses the\n          current device.\n        placement (flow.placement, optional): The desired device of returned global tensor. 
If None, will\n          construct local tensor.\n        sbp (flow.sbp, optional): The desired sbp of returned global tensor. Its length must be equal to the\n          number of dimensions of placement.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.rand(3,3) # construct local tensor\n        >>> x.shape\n        oneflow.Size([3, 3])\n        >>> x.is_global\n        False\n        >>> placement = flow.placement(\"cpu\", ranks=[0])\n        >>> sbp = flow.sbp.broadcast\n        >>> x = flow.rand(3, 3, placement=placement, sbp=sbp) # construct global tensor\n        >>> x.is_global\n        True\n\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.normal,\n    r\"\"\"\n    The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.normal.html\n\n    normal(mean, std, *, generator=None, out=None) -> Tensor\n\n    Returns a tensor of random numbers drawn from separate normal distributions\n    whose mean and standard deviation are given.\n\n    The :attr:`mean` is a tensor with the mean of\n    each output element's normal distribution.\n\n    The :attr:`std` is a tensor with the standard deviation of\n    each output element's normal distribution.\n\n    The shapes of :attr:`mean` and :attr:`std` don't need to match, but the\n    total number of elements in each tensor needs to be the same.\n\n    .. note::\n        Infers the output shape from input arrays :attr:`mean` and :attr:`std`.\n        The output shape will have a dimensionality equal to the greater of the dimensionalities of :attr:`mean` and :attr:`std`.\n        Dimensions with size 1 in either :attr:`mean` or :attr:`std` are expanded to match the other.\n\n    Args:\n        mean (Tensor): the tensor of per-element means\n        std (Tensor): the tensor of per-element standard deviations\n\n    Keyword args:\n        generator (Generator, optional): Random number generator. Defaults to `oneflow::DefaultGenerator` if not provided.\n        out (Tensor, optional): Output tensor, will be resized and filled with the result. If not provided, a new tensor is created.\n\n    Example:\n\n    .. code-block:: python\n        \n        >>> import oneflow as flow\n        >>> generator = flow.Generator()\n        >>> generator.manual_seed(0) #doctest: +ELLIPSIS\n        <oneflow._oneflow_internal.Generator object at ...>\n        >>> z = flow.normal(mean=flow.arange(1., 11.), std=flow.arange(1, 0, -0.1), generator=generator)\n        >>> z[:5]\n        tensor([3.2122, 3.0468, 3.6192, 4.3387, 5.6261], dtype=oneflow.float32)\n\n\n    normal(mean=0.0, std, `*`, generator=None, out=None) -> Tensor\n\n    Similar to the function above, but the means are shared among all drawn elements.\n\n    Args:\n        mean (float, optional): the mean for all distributions\n        std (Tensor): the tensor of per-element standard deviations\n\n    Keyword args:\n        generator (Generator, optional): Random number generator. Defaults to `oneflow::DefaultGenerator` if not provided.\n        out (Tensor, optional): Output tensor, will be resized and filled with the result. If not provided, a new tensor is created.\n\n    Example:\n    \n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> flow.normal(mean=0.5, std=flow.arange(1., 6.)).shape\n        oneflow.Size([5])\n\n\n    normal(mean, std=1.0, `*`, generator=None, out=None) -> Tensor\n\n    Similar to the function above, but the standard deviations are shared among all drawn elements.\n\n    Args:\n        mean (Tensor): the tensor of per-element means\n        std (float, optional): the standard deviation\n\n    Keyword args:\n        generator (Generator, optional): Random number generator. Defaults to `oneflow::DefaultGenerator` if not provided.\n        out (Tensor): The output tensor.\n\n    Returns:\n        Tensor: The output tensor, with random normal values.\n\n    Example:\n    \n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> flow.normal(mean=flow.arange(1., 6.)).shape\n        oneflow.Size([5])\n\n    normal(mean, std, size, `*`, out=None, placement=None, sbp=None, generator=None, dtype=None, device=None, requires_grad=False) -> Tensor\n\n    Returns a tensor of random numbers drawn from separate normal distributions whose mean and standard deviation are given.\n\n    Args:\n        mean (float):  the mean for all distributions\n        std (float):  the standard deviation for all distributions\n        size (int...):  a sequence of integers defining the shape of the output tensor.\n\n    Keyword args:\n        out (Tensor, optional):  the output tensor.\n        placement (flow.placement, optional): The desired device of returned global tensor. If None, will\n          construct local tensor.\n        sbp (flow.sbp, optional): The desired sbp of returned global tensor. Its length must be equal to the\n          number of dimensions of placement.\n        generator(:class:`oneflow.Generator`, optional):  a pseudorandom number generator for sampling\n        dtype (:class:`oneflow.dtype`, optional): the desired data type of returned tensor.\n            Default: `oneflow.float32`.\n        device: the desired device of returned tensor. Default: cpu.\n        requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    Example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> generator = flow.Generator()\n        >>> generator.manual_seed(0) #doctest: +ELLIPSIS\n        <oneflow._oneflow_internal.Generator object at ...>\n        >>> y = flow.normal(0, 1, 5, generator=generator)\n        >>> y\n        tensor([2.2122, 1.1631, 0.7740, 0.4838, 1.0434], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.randint,\n    \"\"\"\n    randint(low=0, high, size, *, dtype=None, generator=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor\n\n    Returns a tensor filled with random integers generated uniformly between low (inclusive) and high (exclusive).\n\n    The shape of the tensor is defined by the variable argument ``size``.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.randint.html.\n\n    Args:\n        low (int, optional):  Lowest integer to be drawn from the distribution. 
Default: 0.\n        high (int):  One above the highest integer to be drawn from the distribution.\n        size (tuple or oneflow.Size):  Defines the shape of the output tensor.\n          Can be a variable number of arguments or a collection like a list or tuple or oneflow.Size.\n\n    Keyword args:\n        dtype (oneflow.dtype, optional): The desired data type of returned tensor. Default: ``flow.int64``.\n        generator (oneflow.Generator, optional): a pseudorandom number generator for sampling\n        device (oneflow.device, optional): The desired device of returned local tensor. If None, uses the\n          current device.\n        placement (oneflow.placement, optional): The desired device of returned global tensor. If None, will\n          construct local tensor.\n        sbp (oneflow.sbp, optional): The desired sbp of the returned global tensor. The number of sbp must equal the\n          number of dimensions of placement.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> generator = flow.Generator()\n        >>> generator.manual_seed(0) #doctest: +ELLIPSIS\n        <oneflow._oneflow_internal.Generator object at ...>\n        >>> y = flow.randint(0, 5, (3,3), generator=generator) # construct local tensor\n        >>> y\n        tensor([[2, 2, 3],\n                [4, 3, 4],\n                [2, 4, 2]], dtype=oneflow.int64)\n        >>> y.is_global\n        False\n        >>> placement = flow.placement(\"cpu\", ranks=[0])\n        >>> y = flow.randint(0, 5, (3,3), generator=generator, placement=placement, sbp=flow.sbp.broadcast) # construct global tensor\n        >>> y.is_global\n        True\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.randint_like,\n    \"\"\"\n    randint_like(input, low=0, high, *, dtype=None, generator=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor\n\n    Returns a tensor filled with random integers generated uniformly between low (inclusive) and high (exclusive), with the same shape as ``input``.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.randint_like.html.\n\n    Args:\n        input (oneflow.Tensor): the size of ``input`` will determine the size of the output tensor.\n        low (int, optional):  Lowest integer to be drawn from the distribution. Default: 0.\n        high (int):  One above the highest integer to be drawn from the distribution.\n\n    Keyword args:\n        dtype (oneflow.dtype, optional): The desired data type of returned tensor. Default: ``flow.int64``.\n        generator (oneflow.Generator, optional): a pseudorandom number generator for sampling\n        device (oneflow.device, optional): The desired device of returned local tensor. If None, uses the\n          current device.\n        placement (oneflow.placement, optional): The desired device of returned global tensor. If None, will\n          construct local tensor.\n        sbp (oneflow.sbp, optional): The desired sbp of the returned global tensor. The number of sbp must equal the\n          number of dimensions of placement.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> generator = flow.Generator()\n        >>> generator.manual_seed(0) #doctest: +ELLIPSIS\n        <oneflow._oneflow_internal.Generator object at ...>\n        >>> x = flow.randn(2, 2, generator=generator)\n        >>> y = flow.randint_like(x, 0, 5, generator=generator) # construct local tensor\n        >>> y\n        tensor([[3, 4],\n                [2, 4]], dtype=oneflow.int64)\n        >>> y.is_global\n        False\n        >>> placement = flow.placement(\"cpu\", ranks=[0])\n        >>> y = flow.randint_like(x, 0, 5, generator=generator, placement=placement, sbp=flow.sbp.broadcast) # construct global tensor\n        >>> y.is_global\n        True\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow._C.randperm,\n    r\"\"\"\n    randperm(n, *, generator=None, dtype=oneflow.int64, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor\n\n    Returns a random permutation of integers from ``0`` to ``n - 1``.\n\n    Args:\n        n (int): the upper bound (exclusive)\n\n    Keyword args:\n        generator (:class:`oneflow.Generator`, optional):  a pseudorandom number generator for sampling\n        dtype (:class:`oneflow.dtype`, optional): the desired data type of returned tensor.\n            Default: ``oneflow.int64``.\n        device (:class:`flow.device`, optional): the desired device of returned tensor. Default: cpu.\n        placement (:class:`flow.placement`, optional): The desired device of returned global tensor. If None,\n            will construct local tensor.\n        sbp (:class:`flow.sbp`, optional): The desired sbp of the returned global tensor. The number of sbp must equal the\n            number of dimensions of placement.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    Example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> generator = flow.Generator()\n        >>> generator.manual_seed(0) #doctest: +ELLIPSIS\n        <oneflow._oneflow_internal.Generator object at ...>\n        >>> y = flow.randperm(5, generator=generator) # construct local tensor\n        >>> y\n        tensor([2, 4, 3, 0, 1], dtype=oneflow.int64)\n        >>> y.is_global\n        False\n        >>> placement = flow.placement(\"cpu\", ranks=[0])\n        >>> y = flow.randperm(5, generator=generator, placement=placement, sbp=flow.sbp.broadcast) # construct global tensor\n        >>> y.is_global\n        True\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.multinomial,\n    \"\"\"\n    multinomial(input, num_samples, replacement=False, generator=None) -> LongTensor\n\n    Returns a tensor where each row contains :attr:`num_samples` indices sampled\n    from the multinomial probability distribution located in the corresponding row\n    of tensor :attr:`input`.\n\n    .. 
note::\n      The rows of :attr:`input` do not need to sum to one (in which case we use\n      the values as weights), but must be non-negative, finite and have\n      a non-zero sum.\n\n    Indices are ordered from left to right according to when each was sampled\n    (the first samples are placed in the first column).\n\n    If :attr:`input` is a vector, :attr:`out` is a vector of size :attr:`num_samples`.\n\n    If :attr:`input` is a matrix with `m` rows, :attr:`out` is a matrix of shape\n    ``(m, num_samples)``.\n\n    If replacement is ``True``, samples are drawn with replacement.\n\n    If not, they are drawn without replacement, which means that when a\n    sample index is drawn for a row, it cannot be drawn again for that row.\n\n    .. note::\n        When drawn without replacement, :attr:`num_samples` must be lower than the\n        number of non-zero elements in :attr:`input` (or the min number of non-zero\n        elements in each row of :attr:`input` if it is a matrix).\n\n    Args:\n        input (Tensor): the input tensor containing probabilities\n        num_samples (int): number of samples to draw\n        replacement (bool, optional): whether to draw with replacement or not\n        generator (:class:`oneflow.Generator`, optional): a pseudorandom number generator for sampling\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> gen = flow.manual_seed(0)\n        >>> weights = flow.tensor([0, 10, 3, 0], dtype=flow.float) # create a tensor of weights\n        >>> flow.multinomial(weights, 2)\n        tensor([1, 2], dtype=oneflow.int64)\n        >>> flow.multinomial(weights, 4, replacement=True)\n        tensor([1, 2, 1, 1], dtype=oneflow.int64)\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/reduce_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.max,\n    \"\"\"\n    oneflow.max(input, dim=None, keepdim=False)\n\n    Computes the maximum value of all elements in the input tensor.\n    \n    Args:\n        input (oneflow.Tensor): the Input Tensor\n        dim (int, optional): the dimension to reduce. Default: `None`\n        keepdim (bool, optional): whether the output tensor has dim retained or not. Default: `False`\n\n    Returns:\n        Tensor or Tuple(oneflow.Tensor, oneflow.Tensor(dtype=int64)): If :attr:`dim` is `None`, returns \n        the maximum value of all elements in the `input` tensor. Otherwise, returns a tuple of Tensor (values, indices), \n        where the `values` are the maximum value of all elements in the `input` tensor,\n        the `indices` are the indices of the elements in the original input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        \n        >>> input = flow.Tensor([[4, 1, 5], [2, 6, 3]])\n        >>> flow.max(input)\n        tensor(6., dtype=oneflow.float32)\n        >>> result = flow.max(input, dim=1)\n        >>> result.values\n        tensor([5., 6.], dtype=oneflow.float32)\n        >>> result.indices\n        tensor([2, 1], dtype=oneflow.int64)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.min,\n    \"\"\"\n    oneflow.min(input, dim=None, keepdim=False)\n    \n    Computes the minimum value of all elements in the input tensor.\n    \n    Args:\n        input (oneflow.Tensor): the Input Tensor\n        dim (int, optional): the dimension to reduce. Default: `None`\n        keepdim (bool, optional): whether the output tensor has dim retained or not. Default: `False`\n\n    Returns:\n        Tensor or Tuple(oneflow.Tensor, oneflow.Tensor(dtype=int64)): If :attr:`dim` is `None`, returns \n        the minimum value of all elements in the `input` tensor. Otherwise, returns a tuple of Tensor (values, indices), \n        where the `values` are the minimum value of all elements in the `input` tensor,\n        the `indices` are the indices of the elements in the original input tensor.\n    \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> input = flow.Tensor([[4, 1, 5], [2, 6, 3]])\n        >>> flow.min(input)\n        tensor(1., dtype=oneflow.float32)\n        >>> result = flow.min(input, dim=1)\n        >>> result.values\n        tensor([1., 2.], dtype=oneflow.float32)\n        >>> result.indices\n        tensor([1, 0], dtype=oneflow.int64)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.sum,\n    \"\"\"\n    oneflow.sum(input, dim=None, keepdim=False) -> Tensor\n\n    Computes the sum of row of elements in a tensor in the given dimension. 
If the dimension is None, the sum of all elements will be calculated.\n\n    If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim` where it is of size 1. Otherwise, `dim` is squeezed (see :func:`oneflow.squeeze()`), resulting in the output tensor having 1 (or `len(dim)`) fewer dimension(s).\n\n    Args:\n        input (oneflow.Tensor): the Input Tensor\n        dim (int or tuple of ints, optional): the dimension to reduce. Default: `None`\n        keepdim (bool, optional): whether the output tensor has dim retained or not. Default: `False`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> input = flow.Tensor([[1, 2, 3], [4, 5, 6]])\n        >>> flow.sum(input)\n        tensor(21., dtype=oneflow.float32)\n        >>> flow.sum(input, dim=0)\n        tensor([5., 7., 9.], dtype=oneflow.float32)\n        >>> flow.sum(input, dim=1)\n        tensor([ 6., 15.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.mean,\n    \"\"\"\n    oneflow.mean(input, dim=None, keepdim=False) -> Tensor\n\n    Computes the mean of the elements of the input tensor along the given dimension. If the dimension is None, the mean of all elements will be calculated.\n\n    If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim` where it is of size 1. Otherwise, `dim` is squeezed (see :func:`oneflow.squeeze()`), resulting in the output tensor having 1 (or `len(dim)`) fewer dimension(s).\n\n    Args:\n        input (oneflow.Tensor): the Input Tensor\n        dim (int or tuple of ints, optional): the dimension to reduce. Default: `None`\n        keepdim (bool, optional): whether the output tensor has dim retained or not. Default: `False`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> input = flow.Tensor([[1, 2, 3], [4, 5, 6]])\n        >>> flow.mean(input)\n        tensor(3.5000, dtype=oneflow.float32)\n        >>> flow.mean(input, dim=0)\n        tensor([2.5000, 3.5000, 4.5000], dtype=oneflow.float32)\n        >>> flow.mean(input, dim=1)\n        tensor([2., 5.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.prod,\n    \"\"\"\n    oneflow.prod(input, dim=None, keepdim=False) -> Tensor\n\n    Computes the product of the elements of the input tensor along the given dimension. If the dimension is None, the product of all elements will be calculated.\n\n    If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim` where it is of size 1. Otherwise, `dim` is squeezed (see :func:`oneflow.squeeze()`), resulting in the output tensor having 1 (or `len(dim)`) fewer dimension(s).\n\n    Args:\n        input (oneflow.Tensor): the Input Tensor\n        dim (int or tuple of ints, optional): the dimension to reduce. Default: `None`\n        keepdim (bool, optional): whether the output tensor has dim retained or not. Default: `False`\n
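\n    The effect of `keepdim` can be seen by comparing output shapes (an illustrative\n    sketch added for clarity, not part of the referenced PyTorch docs):\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> # illustrative: the reduced dimension is kept with size 1\n        >>> input = flow.Tensor([[1, 2, 3], [4, 5, 6]])\n        >>> flow.prod(input, dim=1, keepdim=True)\n        tensor([[  6.],\n                [120.]], dtype=oneflow.float32)\n\n    For example:\n\n    .. 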
code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> input = flow.Tensor([[1, 2, 3], [4, 5, 6]])\n        >>> flow.prod(input)\n        tensor(720., dtype=oneflow.float32)\n        >>> flow.prod(input, dim=0)\n        tensor([ 4., 10., 18.], dtype=oneflow.float32)\n        >>> flow.prod(input, dim=1)\n        tensor([  6., 120.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.all,\n    \"\"\"\n    oneflow.all(input, dim=None, keepdim=False) -> Tensor\n\n    For each row of `input` in the given dimension `dim`, returns True if all elements in the row evaluate to True, and False otherwise. If the dimension is None, computes whether all elements in the input tensor evaluate to True.\n\n    If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim` where it is of size 1. Otherwise, `dim` is squeezed (see :func:`oneflow.squeeze()`), resulting in the output tensor having 1 (or `len(dim)`) fewer dimension(s).\n\n    Args:\n        input (oneflow.Tensor): the Input Tensor\n        dim (int, optional): the dimension to reduce. Default: `None`\n        keepdim (bool, optional): whether the output tensor has dim retained or not. Default: `False`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> input = flow.Tensor([[1, 2, 3], [4, 5, 6]]) < 4\n        >>> input\n        tensor([[ True,  True,  True],\n                [False, False, False]], dtype=oneflow.bool)\n        >>> flow.all(input)\n        tensor(False, dtype=oneflow.bool)\n        >>> flow.all(input, 1)\n        tensor([ True, False], dtype=oneflow.bool)\n        >>> flow.all(input, 1, True)\n        tensor([[ True],\n                [False]], dtype=oneflow.bool)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.any,\n    \"\"\"\n    oneflow.any(input, dim=None, keepdim=False) -> Tensor\n\n    For each row of `input` in the given dimension `dim`, returns True if any element in the row evaluates to True, and False otherwise. If the dimension is None, computes whether any element in the input tensor evaluates to True.\n\n    If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim` where it is of size 1. Otherwise, `dim` is squeezed (see :func:`oneflow.squeeze()`), resulting in the output tensor having 1 (or `len(dim)`) fewer dimension(s).\n\n    Args:\n        input (oneflow.Tensor): the Input Tensor\n        dim (int, optional): the dimension to reduce. Default: `None`\n        keepdim (bool, optional): whether the output tensor has dim retained or not. Default: `False`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> input = flow.Tensor([[1, 2, 3], [4, 5, 6]]) < 4\n        >>> input\n        tensor([[ True,  True,  True],\n                [False, False, False]], dtype=oneflow.bool)\n        >>> flow.any(input)\n        tensor(True, dtype=oneflow.bool)\n        >>> flow.any(input, 0)\n        tensor([True, True, True], dtype=oneflow.bool)\n        >>> flow.any(input, 0, True)\n        tensor([[True, True, True]], dtype=oneflow.bool)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.nansum,\n    r\"\"\"oneflow.nansum(input, dim, keepdim=False, *, dtype=None) -> Tensor\n\n    Returns the sum of each row of the ``input`` tensor in the given dimension ``dim``,\n    treating Not a Numbers (NaNs) as zero. 
If ``dim`` is a list of dimensions,\n    reduce over all of them.\n\n    If ``keepdim`` is ``True``, the output tensor is of the same size as ``input`` except\n    in the dimension(s) ``dim`` where it is of size 1.\n    Otherwise, ``dim`` is squeezed (see :func:`oneflow.squeeze()`),\n    resulting in the output tensor having 1 (or ``len(dim)``) fewer dimension(s).\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nansum.html.\n\n    Args:\n        input (oneflow.Tensor): the Input Tensor\n        dim (int, optional): the dimension to reduce. Default: ``None``\n        keepdim (bool, optional): whether the output tensor has ``dim`` retained or not. Default: `False`\n        dtype (oneflow.dtype, optional): the desired data type of returned tensor.\n            If specified, the input tensor is cast to dtype before the operation is performed.\n            This is useful for preventing data type overflows. Default: ``None``.\n\n    Example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1., 2., float(\"nan\")])\n        >>> flow.nansum(x)\n        tensor(3., dtype=oneflow.float32)\n        >>> x = flow.tensor([[1., float(\"nan\")], [float(\"nan\"), 2]])\n        >>> flow.nansum(x, dim=1)\n        tensor([1., 2.], dtype=oneflow.float32)\n        >>> x = flow.tensor([float(\"nan\") for i in range(3)])\n        >>> flow.nansum(x)\n        tensor(0., dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.logsumexp,\n    r\"\"\"\n    oneflow.logsumexp(input, dim, keepdim=False) -> Tensor\n\n    Returns the log of summed exponentials of each row of the :attr:`input`\n    tensor in the given dimension :attr:`dim`. The computation is numerically\n    stabilized.\n\n    For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is\n\n    .. math::\n        \\text{logsumexp}(x)_{i} = \\log \\sum_j \\exp(x_{ij})\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.12/generated/torch.logsumexp.html.\n\n    Args:\n        input (oneflow.Tensor): the Input Tensor\n        dim (int or tuple of ints): the dimension or dimensions to reduce.\n        keepdim (bool, optional): whether the output tensor has dim retained or not. Default: `False`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> input = flow.Tensor([[1, 2, 3], [4, 5, 6]])\n        >>> flow.logsumexp(input, 0)\n        tensor([4.0486, 5.0486, 6.0486], dtype=oneflow.float32)\n        >>> flow.logsumexp(input, 1)\n        tensor([3.4076, 6.4076], dtype=oneflow.float32)\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/repeat.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.repeat,\n    \"\"\"\n    repeat(input, sizes) -> Tensor\n\n    This operator repeat the input tensor to a larger size along the specified dimensions.\n\n    Args:\n        input (oneflow.Tensor): the input Tensor.\n        sizes (flow.Shape or List): The number of times to repeat this tensor along each dimension.\n\n    Returns:\n        oneflow.Tensor: The result Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> np_arr = np.random.randn(5, 3, 6, 9).astype(np.float32)\n        >>> input = flow.Tensor(np_arr)\n        >>> out = input.repeat(1, 1, 2, 2)\n        >>> out.shape\n        oneflow.Size([5, 3, 12, 18])\n        >>> out = input.repeat(2, 1, 1, 2, 2)\n        >>> out.shape\n        oneflow.Size([2, 5, 3, 12, 18])\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/repeat_interleave.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.repeat_interleave,\n    \"\"\"\n    repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor\n\n    Repeat elements of a tensor.\n\n    .. warning::\n\n        This is different from :meth:`oneflow.Tensor.repeat` but similar to ``numpy.repeat``.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.repeat_interleave.html\n\n    Args:\n        input (oneflow.Tensor): the input Tensor.\n        repeats (Tensor or int): The number of repetitions for each element.\n            repeats is broadcasted to fit the shape of the given axis.\n        dim (int, optional): The dimension along which to repeat values.\n            By default, use the flattened input array, and return a flat output\n            array.\n\n    Keyword args:\n        output_size (int, optional): Total output size for the given axis\n            ( e.g. sum of repeats). If given, it will avoid stream syncronization\n            needed to calculate output shape of the tensor.\n\n    Returns:\n        oneflow.Tensor: Repeated tensor which has the same shape as input, except along the given axis.\n    \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1, 2, 3])\n        >>> y = flow.tensor([[1, 2], [3, 4]])\n        >>> flow.repeat_interleave(y, 2)\n        tensor([1, 1, 2, 2, 3, 3, 4, 4], dtype=oneflow.int64)\n        >>> flow.repeat_interleave(y, 3, dim=1)\n        tensor([[1, 1, 1, 2, 2, 2],\n                [3, 3, 3, 4, 4, 4]], dtype=oneflow.int64)\n        >>> flow.repeat_interleave(y, flow.tensor([1, 2]), dim=0)\n        tensor([[1, 2],\n                [3, 4],\n                [3, 4]], dtype=oneflow.int64)\n        >>> flow.repeat_interleave(y, flow.tensor([1, 2]), dim=0, output_size=3)\n        tensor([[1, 2],\n                [3, 4],\n                [3, 4]], dtype=oneflow.int64)\n    \n    ..\n        Feature Stage of Operator [repeat_interleave].\n        - Maintainer List [@BBuf]\n        - Current Stage [ ]\n        - Alpha Stage Check List [ ]\n          - API(Compatible with PyTorch 1.11, anything incompatible must be noted in API Doc.)[Yes]\n          - Doc(API Doc must be provided and showed normally on the web page.)[Yes]\n          - Functionality and its' Test [ ]\n            - Functionality is highly compatiable with PyTorch 1.11. 
[Yes]\n            - eager local [Yes] [@QiangX-man, @hjchen2]\n              - forward [Yes]\n              - backward [Yes]\n              - gpu [Yes]\n              - cpu [Yes]\n            - graph local [ ] [@BBuf, @strint, @hjchen2]\n              - forward [Yes]\n              - backward [ ]\n              - gpu [Yes]\n              - cpu [Yes]\n          - Exception Handling\n            - Exception Message and Hint must be provided [Yes]\n        - Beta Stage Check List [ ]\n          - API(High compatibility with PyTorch 1.11, shouldn't have anything incompatible for a naive reason.)[Yes]\n          - Doc(Same standard as Alpha Stage)[ ]\n          - Functionality and its Test [ ]\n            - eager global [ ]\n              - forward [ ]\n              - backward [ ]\n              - gpu [ ]\n              - cpu [ ]\n            - graph global [ ]\n              - forward [ ]\n              - backward [ ]\n              - gpu [ ]\n              - cpu [ ]\n          - Performance and Scalability(Must be evaluated.)[ ]\n            - CUDA kernel [ ]\n            - CPU kernel [ ]\n            - N nodes M devices [ ]\n          - Exception Handling [ ]\n            - Exception Message and Hint must be provided [ ]\n            - Try your best to do Exception Recovery [ ]\n        - Stable Stage Check List [ ]\n          - API(Same standard as Beta Stage)[ ]\n          - Doc(Same standard as Beta Stage)[ ]\n          - Functionality and its Test [ ]\n            - fp16 and AMP [ ]\n            - NHWC [ ]\n          - Performance and Scalability(Must be evaluated.)[ ]\n          - Exception Handling [ ]\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/roc_auc_score.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.roc_auc_score,\n    \"\"\"\n    oneflow.roc_auc_score(label, pred) -> Tensor\n\n    Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.\n\n    Note: Currently this implementation can only be used on CPU.\n\n    Args:\n        label (Tensor[N, 1]): True lable of the samples\n        pred (Tensor[N, 1]): Predicted probability value to be true\n        \n    Returns:\n        Tensor[1, ]: float32 tensor of auc score\n       \n    For example:\n\n    .. code-block:: python\n\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> label = flow.Tensor([0, 0, 1, 1])\n        >>> pred = flow.Tensor([0.1, 0.4, 0.35, 0.8])     \n          \n        >>> score = flow.roc_auc_score(label, pred)\n        >>> score\n        tensor([0.7500], dtype=oneflow.float32)\n\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/searchsorted.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.searchsorted,\n    \"\"\"\n    searchsorted() -> oneflow.Tensor\n\n    Find the indices from the innermost dimension of sorted_sequence such that, if the corresponding values\n    in values were inserted before the indices, the order of the corresponding innermost dimension within\n    sorted_sequence would be preserved. Return a new tensor with the same size as values. If right is False\n    (default), then the left boundary of sorted_sequence is closed. More formally, the returned index\n    satisfies the following rules:\n\n    =================  =========  ==========================================================================\n    sorted_sequence     right      returned index satisfies\n    =================  =========  ==========================================================================\n    1-D                 False      sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]\n    1-D                 True       sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]\n    N-D                 False      sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] \n                                                    <= sorted_sequence[m][n]...[l][i]\n    N-D                 True       sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] \n                                                    sorted_sequence[m][n]...[l][i]\n    =================  =========  ==========================================================================\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.searchsorted.html\n\n    Args:\n        sorted_sequence (Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the\n                                innermost dimension.\n        values (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s).\n        out_int32 (bool optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise.\n                                Default value is False, i.e. default output data type is torch.int64.\n        right (bool optional): if False, return the first suitable location that is found. If True, return the\n                                last such index. If no suitable index found, return 0 for non-numerical value\n                                (eg. nan, inf) or the size of innermost dimension within sorted_sequence (one\n                                pass the last index of the innermost dimension). In other words, if False, gets\n                                the lower bound index for each value in values on the corresponding innermost\n                                dimension of the sorted_sequence. 
If True, gets the upper bound index instead.\n                                Default value is False.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> sorted_sequence = flow.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]])\n        >>> sorted_sequence\n        tensor([[ 1,  3,  5,  7,  9],\n                [ 2,  4,  6,  8, 10]], dtype=oneflow.int64)\n        >>> values = flow.tensor([[3, 6, 9], [3, 6, 9]])\n        >>> values\n        tensor([[3, 6, 9],\n                [3, 6, 9]], dtype=oneflow.int64)\n        >>> flow.searchsorted(sorted_sequence, values)\n        tensor([[1, 3, 4],\n                [1, 2, 4]], dtype=oneflow.int64)\n        >>> flow.searchsorted(sorted_sequence, values, right=True)\n        tensor([[2, 3, 5],\n                [1, 3, 4]], dtype=oneflow.int64)\n        >>> sorted_sequence_1d = flow.tensor([1, 3, 5, 7, 9])\n        >>> sorted_sequence_1d\n        tensor([1, 3, 5, 7, 9], dtype=oneflow.int64)\n        >>> flow.searchsorted(sorted_sequence_1d, values)\n        tensor([[1, 3, 4],\n                [1, 3, 4]], dtype=oneflow.int64)\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/sort.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.sort,\n    \"\"\"Sorts the elements of the input tensor along a given dimension in ascending order by value.\n\n    Args:\n        input (oneflow.Tensor): the input Tensor.\n        dim (int, optional): the dimension to be sorted. Defaults to the last dim (-1).\n        descending (bool, optional): controls the sorting order (ascending or descending).\n\n    Returns:\n        Tuple(oneflow.Tensor, oneflow.Tensor(dtype=int32)): A tuple of (values, indices), where\n        where the values are the sorted values and the indices are the indices of the elements\n        in the original input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x = np.array([[1, 3, 8, 7, 2], [1, 9, 4, 3, 2]], dtype=np.float32)\n        >>> input = flow.Tensor(x)\n        >>> result = flow.sort(input)\n        >>> result.values\n        tensor([[1., 2., 3., 7., 8.],\n                [1., 2., 3., 4., 9.]], dtype=oneflow.float32)\n        >>> result.indices\n        tensor([[0, 4, 1, 3, 2],\n                [0, 4, 3, 2, 1]], dtype=oneflow.int32)\n        >>> result = flow.sort(input, descending=True)\n        >>> result.values\n        tensor([[8., 7., 3., 2., 1.],\n                [9., 4., 3., 2., 1.]], dtype=oneflow.float32)\n        >>> result.indices\n        tensor([[2, 3, 1, 4, 0],\n                [1, 2, 3, 4, 0]], dtype=oneflow.int32)\n        >>> result = flow.sort(input, dim=0)\n        >>> result.values\n        tensor([[1., 3., 4., 3., 2.],\n                [1., 9., 8., 7., 2.]], dtype=oneflow.float32)\n        >>> result.indices\n        tensor([[0, 0, 1, 1, 0],\n                [1, 1, 0, 0, 1]], dtype=oneflow.int32)\n \n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/special_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.special.digamma,\n    \"\"\"\n    Alias for :func:`oneflow.digamma`. \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.special.erf,\n    \"\"\"\n    Alias for :func:`oneflow.erf`. \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.special.erfc,\n    \"\"\"\n    Alias for :func:`oneflow.erfc`. \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.special.erfinv,\n    \"\"\"\n    Alias for :func:`oneflow.erfinv`. \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.special.exp2,\n    \"\"\"\n    Alias for :func:`oneflow.exp2`. \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.special.expm1,\n    \"\"\"\n    Alias for :func:`oneflow.expm1`. \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.special.log1p,\n    \"\"\"\n    Alias for :func:`oneflow.log1p`. \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.special.log_softmax,\n    \"\"\"\n    Alias for :func:`oneflow.nn.functional.log_softmax`. \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.special.logsumexp,\n    \"\"\"\n    Alias for :func:`oneflow.logsumexp`. \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.special.round,\n    \"\"\"\n    Alias for :func:`oneflow.round`. \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.special.softmax,\n    \"\"\"\n    Alias for :func:`oneflow.softmax`. \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.special.psi,\n    \"\"\"\n    Alias for :func:`oneflow.special.digamma`. \n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.special.zeta,\n    r\"\"\"\n    zeta(input, other) -> Tensor\n    \n    Computes the Hurwitz zeta function, elementwise.\n    \n    .. math::\n        \\zeta(x, q) = \\sum_{k=0}^{\\infty} \\frac{1}{(k + q)^x}\n    \n    Args:\n        input (Tensor): the input tensor corresponding to `x`.\n        other (Tensor): the input tensor corresponding to `q`.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([2., 4.])\n        >>> flow.special.zeta(x, 1)\n        tensor([1.6449, 1.0823], dtype=oneflow.float32)\n        >>> flow.special.zeta(x, flow.tensor([1., 2.]))\n        tensor([1.6449, 0.0823], dtype=oneflow.float32)\n        >>> flow.special.zeta(2,flow.tensor([1., 2.]))\n        tensor([1.6449, 0.6449], dtype=oneflow.float32)\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/split.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.split,\n    \"\"\"Splits the tensor into chunks.\n\n    If `split_size_or_sections` is an integer type, then x will be split into equally sized chunks (if possible).\n    Last chunk will be smaller if the tensor size along the given dimension `dim` is not divisible by split_size.\n\n    If `split_size_or_sections` is a list, then x will be split into `len(split_size_or_sections)` chunks\n    with sizes in `dim` according to `split_size_or_sections`.\n\n    Args:\n        x: tensor to split.\n        split_size_or_sections: size of a single chunk or list of sizes for each chunk.\n        dim: dimension along which to split the tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> a = flow.arange(10).view(5, 2)\n        >>> flow.split(a, 2)\n        (tensor([[0, 1],\n                [2, 3]], dtype=oneflow.int64), tensor([[4, 5],\n                [6, 7]], dtype=oneflow.int64), tensor([[8, 9]], dtype=oneflow.int64))\n        >>> flow.split(a, [1, 4])\n        (tensor([[0, 1]], dtype=oneflow.int64), tensor([[2, 3],\n                [4, 5],\n                [6, 7],\n                [8, 9]], dtype=oneflow.int64))\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/swapaxes.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow._C.swapaxes,\n    \"\"\"swapaxes(input, axis0, axis1) -> Tensor\n    \n    This function is equivalent to NumPy’s swapaxes function.\n\n    For example:\n\n    .. code-block:: python\n    \n        >>> import oneflow as flow\n               \n        >>> x = flow.tensor([[[0,1],[2,3]],[[4,5],[6,7]]])\n        >>> x.shape\n        oneflow.Size([2, 2, 2])\n        >>> flow.swapaxes(x, 0, 1).shape\n        oneflow.Size([2, 2, 2])\n        >>> flow.swapaxes(x, 0, 2).shape\n        oneflow.Size([2, 2, 2])\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/swapdims.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow._C.swapdims,\n    \"\"\"\n    swapdims(input, dim0, dim1) -> Tensor\n\n    This function is equivalent to torch’s swapdims function.\n\n    For example:\n\n    .. code-block:: python\n    \n        >>> import oneflow as flow\n\n        >>> x = flow.tensor([[[0,1],[2,3]],[[4,5],[6,7]]])\n        >>> x\n        tensor([[[0, 1],\n                 [2, 3]],\n        <BLANKLINE>\n                [[4, 5],\n                 [6, 7]]], dtype=oneflow.int64)\n        >>> flow.swapdims(x, 0, 1)\n        tensor([[[0, 1],\n                 [4, 5]],\n        <BLANKLINE>\n                [[2, 3],\n                 [6, 7]]], dtype=oneflow.int64)\n        >>> flow.swapdims(x, 0, 2)\n        tensor([[[0, 4],\n                 [2, 6]],\n        <BLANKLINE>\n                [[1, 5],\n                 [3, 7]]], dtype=oneflow.int64)\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.tensor,\n    r\"\"\"\n    Constructs a tensor with data, return a global tensor if placement and sbp are in kwargs,\n       otherwise return a local tensor.\n\n    Arguments:\n        data: Initial data for the tensor. Can be a list, tuple, NumPy ndarray, scalar or tensor.\n    Keyword Arguments:\n        dtype (oneflow.dtype, optional) – the desired data type of returned tensor.\n            Default: if None, infers data type from data.\n        device (oneflow.device, optional): the desired device of returned tensor. If placement\n            and sbp is None, uses the current cpu for the default tensor type.\n        placement (oneflow.placement, optional): the desired placement of returned tensor.\n        sbp (oneflow.sbp or tuple of oneflow.sbp, optional): the desired sbp of returned tensor.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False\n        pin_memory(bool, optional): If set, returned tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: False.\n\n    Note:\n        The Keyword Argument device is mutually exclusive with placement and sbp.\n\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> x = flow.tensor([1,2,3])\n        >>> x\n        tensor([1, 2, 3], dtype=oneflow.int64)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.from_numpy,\n    r\"\"\"\n    Creates a ``Tensor`` from a ``numpy.ndarray``.\n\n    The returned tensor and ndarray share the same memory. Modifications to the tensor\n    will be reflected in the ndarray and vice versa.\n\n    It currently accepts ndarray with dtypes of numpy.float64, numpy.float32, numpy.float16,\n    numpy.int64, numpy.int32, numpy.int8, numpy.uint8.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> np_arr = np.arange(6).reshape(2, 3)\n        >>> t = flow.from_numpy(np_arr)\n        >>> t\n        tensor([[0, 1, 2],\n                [3, 4, 5]], dtype=oneflow.int64)\n        >>> np_arr[0, 0] = -1\n        >>> t\n        tensor([[-1,  1,  2],\n                [ 3,  4,  5]], dtype=oneflow.int64)\n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow.Tensor.device,\n    r\"\"\"\n    Is the :class:`oneflow.device` where this Tensor is, which is invalid for a global tensor.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.Tensor.device.html.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.placement,\n    r\"\"\"\n    Is the :class:`oneflow.placement` where this Tensor is, which is invalid for a local tensor.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.sbp,\n    r\"\"\"\n    Is the ``oneflow.sbp`` representing how the data of the global tensor is distributed, which is invalid for a local tensor.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.is_global,\n    r\"\"\"\n    Returns whether this Tensor is a global tensor.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.is_lazy,\n    r\"\"\"\n    Returns whether this Tensor is a lazy tensor.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.atan2,\n    r\"\"\"\n    See :func:`oneflow.atan2`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.expand,\n    \"\"\"\n    Tensor.expand() -> Tensor\n\n    See :func:`oneflow.expand`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.expand_as,\n    \"\"\"\n    expand_as(other) -> Tensor\n\n    Expand this tensor to the same size as :attr:`other`.\n    ``self.expand_as(other)`` is equivalent to ``self.expand(other.size())``.\n\n    Please see :meth:`~Tensor.expand` for more information about ``expand``.\n\n    Args:\n        other (:class:`oneflow.Tensor`): The result tensor has the same size\n            as :attr:`other`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.flatten,\n    \"\"\"\n    See :func:`oneflow.flatten`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.floor,\n    \"\"\"\n    See :func:`oneflow.floor`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.floor_,\n    \"\"\"\n    See :func:`oneflow.floor_`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.flip,\n    \"\"\"\n    See :func:`oneflow.flip`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.in_top_k,\n    \"\"\"\n    Tensor.in_top_k(targets, predictions, k) -> Tensor\n\n    See :func:`oneflow.in_top_k`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.index_select,\n    \"\"\"\n    Tensor.index_select(dim, index) -> Tensor\n\n    See :func:`oneflow.index_select`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.numel,\n    \"\"\"\n    See :func:`oneflow.numel`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.offload,\n    \"\"\"\n    Transfers tensor data from GPU memory back to host (CPU) memory. If the tensor is already in host (CPU) memory, the operation does nothing and gives a warning.\n    Note that this operation only changes the storage of the tensor, and the tensor id will not change.\n\n    Note:\n\n        Both global tensors and local tensors of oneflow are applicable to this operation.\n\n        Use with :func:`oneflow.Tensor.load` and :func:`oneflow.Tensor.is_offloaded`.\n        The behavior of load() is the opposite of offload(); is_offloaded() returns a boolean indicating whether the tensor has been moved to CPU memory.
\n\n        In addition, support for offloading elements of :func:`oneflow.nn.Module.parameters` is provided.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> # local tensor\n        >>> x = flow.tensor(np.random.randn(1024, 1024, 100), dtype=flow.float32, device=flow.device(\"cuda\"), )\n        >>> before_id = id(x)\n        >>> x.offload() # Move the Tensor from the GPU to the CPU\n        >>> after_id = id(x)\n        >>> after_id == before_id\n        True\n        >>> x.is_offloaded()\n        True\n        >>> x.load() # Move the Tensor from the CPU to the GPU\n        >>> x.is_offloaded()\n        False\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> # global tensor\n        >>> # Run on 2 ranks respectively\n        >>> placement = flow.placement(\"cuda\", ranks=[0, 1])\n        >>> sbp = flow.sbp.broadcast\n        >>> x = flow.randn(1024, 1024, 100, dtype=flow.float32, placement=placement, sbp=sbp) # doctest: +SKIP\n        >>> before_id = id(x) # doctest: +SKIP\n        >>> x.offload() # doctest: +SKIP\n        >>> after_id = id(x) # doctest: +SKIP\n        >>> print(after_id == before_id) # doctest: +SKIP\n        >>> print(x.is_offloaded()) # doctest: +SKIP\n        >>> x.load() # doctest: +SKIP\n        >>> print(x.is_offloaded()) # doctest: +SKIP\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.load,\n    \"\"\"\n    Loads tensor data stored on the host (CPU) back to GPU memory. If the tensor is already in GPU memory, the operation does nothing and gives a warning.\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.is_offloaded,\n    \"\"\"\n    Tensor.is_offloaded() -> bool\n\n    Determines whether the tensor has been moved to CPU memory and the CUDA device memory has been released.\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.new_empty,\n    \"\"\"\n    Tensor.new_empty(*size, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor\n\n    Returns a Tensor of size :attr:`size` filled with uninitialized data. By default, the returned Tensor has the same :attr:`flow.dtype` and :attr:`flow.device` as this tensor.\n\n    Args:\n        size (int...): a list, tuple, or flow.Size of integers defining the shape of the output tensor.\n        dtype (flow.dtype, optional):  the desired type of returned tensor. Default: if None, same flow.dtype as this tensor.\n        device (flow.device, optional): the desired device of returned tensor. Default: if None, same flow.device as this tensor.\n        placement (flow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is a local tensor using the argument `device`.\n        sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is a local tensor using the argument `device`.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> x = flow.ones(())\n        >>> y = x.new_empty((2, 2))\n        >>> y.shape\n        oneflow.Size([2, 2])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.new_ones,\n    \"\"\"\n    Tensor.new_ones() -> Tensor\n\n    See :func:`oneflow.new_ones`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.new_zeros,\n    \"\"\"\n    Tensor.new_zeros(size=None, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor\n\n    Returns a Tensor of size :attr:`size` filled with 0. By default, the returned Tensor has the same oneflow.dtype, oneflow.device or oneflow.placement and oneflow.sbp as this tensor.\n\n    Args:\n        size (int...): a list, tuple, or flow.Size of integers defining the shape of the output tensor.\n        dtype (flow.dtype, optional):  the desired type of returned tensor. Default: if None, same flow.dtype as this tensor.\n        device (flow.device, optional): the desired device of returned tensor. Default: if None, same flow.device as this tensor.\n        placement (flow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is a local tensor using the argument `device`.\n        sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is a local tensor using the argument `device`.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x = flow.Tensor(np.ones((1, 2, 3)))\n        >>> y = x.new_zeros((2, 2))\n        >>> y\n        tensor([[0., 0.],\n                [0., 0.]], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.new_full,\n    \"\"\"\n    Tensor.new_full(size, fill_value, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor\n\n    Returns a Tensor of size :attr:`size` filled with :attr:`fill_value`. By default, the returned Tensor has the same oneflow.dtype, oneflow.device or oneflow.placement and oneflow.sbp as this tensor.\n\n    Args:\n        fill_value (scalar): the number to fill the output tensor with.\n        size (int...): a list, tuple, or flow.Size of integers defining the shape of the output tensor.\n        dtype (flow.dtype, optional):  the desired type of returned tensor. Default: if None, same flow.dtype as this tensor.\n        device (flow.device, optional): the desired device of returned tensor. Default: if None, same flow.device as this tensor.\n        placement (flow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is a local tensor using the argument `device`.\n        sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is a local tensor using the argument `device`.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> tensor = flow.ones((2,), dtype=flow.float64)\n        >>> tensor.new_full((3, 4), 3.141592)\n        tensor([[3.1416, 3.1416, 3.1416, 3.1416],\n                [3.1416, 3.1416, 3.1416, 3.1416],\n                [3.1416, 3.1416, 3.1416, 3.1416]], dtype=oneflow.float64)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.storage_offset,\n    \"\"\"\n    Tensor.storage_offset() -> int\n\n    Returns self tensor’s offset in the underlying storage in terms of the number of storage elements (not bytes).\n\n    Example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1, 2, 3, 4, 5])\n        >>> x.storage_offset()\n        0\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.local_to_global,\n    \"\"\"\n    Tensor.local_to_global(placement=None, sbp=None, *, check_meta=True, copy=False) -> Tensor\n\n    Creates a global tensor from a local tensor.\n\n    Note:\n        This tensor must be a local tensor.\n\n        Both placement and sbp are required.\n\n        The returned global tensor takes this tensor as its local component in the current rank.\n\n        There is usually no data communication, but when sbp is ``oneflow.sbp.broadcast``, the data on rank 0 will be broadcast to other ranks.\n\n    .. warning::\n        When the sbp is ``oneflow.sbp.broadcast``, the data on the non-0 ranks will be modified. If you want to keep the input local tensor unchanged,\n        please set the arg copy to True.\n\n    Args:\n        placement (flow.placement, optional): the desired placement of returned global tensor. Default: None\n        sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp of returned global tensor. Default: None\n    Keyword Args:\n        check_meta (bool, optional): indicates whether to check meta information when creating a global tensor from a local\n            tensor. It can only be set to False when the shape and dtype of the input local tensor on each rank are the same. If set to False, the\n            execution of local_to_global can be accelerated. Default: True\n        copy (bool, optional): When copy is set, the returned global tensor takes a copy of this tensor as its local component in the current rank. Default: False\n\n    .. code-block:: python\n\n        >>> # Run on 2 ranks respectively\n        >>> import oneflow as flow\n        >>> input = flow.tensor([0., 1.], dtype=flow.float32) # doctest: +SKIP\n        >>> output = input.local_to_global(placement=flow.placement(\"cpu\", ranks=[0, 1]), sbp=[flow.sbp.split(0)], check_meta=False) # doctest: +SKIP\n        >>> print(output.size()) # doctest: +SKIP\n        >>> print(output) # doctest: +SKIP\n\n    .. code-block:: python\n\n        >>> # results on rank 0\n        oneflow.Size([4])\n        tensor([0., 1., 0., 1.], placement=oneflow.placement(type=\"cpu\", ranks=[0, 1]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32)\n\n    .. 
code-block:: python\n\n        >>> # results on rank 1\n        oneflow.Size([4])\n        tensor([0., 1., 0., 1.], placement=oneflow.placement(type=\"cpu\", ranks=[0, 1]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.global_to_global,\n    \"\"\"\n    Tensor.global_to_global(placement=None, sbp=None, *, grad_sbp=None, check_meta=False, copy=False) -> Tensor\n\n    Performs Tensor placement and/or sbp conversion.\n\n    Note:\n        This tensor must be a global tensor.\n\n        At least one of placement and sbp is required.\n\n        If placement and sbp are both the same as this tensor's own placement and sbp, then this tensor itself is returned.\n\n    Args:\n        placement (flow.placement, optional): the desired placement of returned global tensor. Default: None\n        sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp of returned global tensor. Default: None\n    Keyword Args:\n        grad_sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): manually specify the sbp of this tensor's grad\n            tensor in the backward pass. If None, the grad tensor sbp will be inferred automatically. Default: None\n        check_meta (bool, optional): indicates whether to check meta information. If set to True, check the consistency\n            of the input meta information (placement and sbp) on each rank. Default: False\n        copy (bool, optional): When copy is set, a new Tensor is created even when the Tensor already matches the desired conversion. Default: False\n\n    .. code-block:: python\n\n        >>> # Run on 2 ranks respectively\n        >>> import oneflow as flow\n        >>> input = flow.tensor([0., 1.], dtype=flow.float32, placement=flow.placement(\"cpu\", ranks=[0, 1]), sbp=[flow.sbp.broadcast]) # doctest: +SKIP\n        >>> output = input.global_to_global(placement=flow.placement(\"cpu\", ranks=[0, 1]), sbp=[flow.sbp.split(0)]) # doctest: +SKIP\n        >>> print(output.size()) # doctest: +SKIP\n        >>> print(output) # doctest: +SKIP\n\n    .. code-block:: python\n\n        >>> # results on rank 0\n        oneflow.Size([2])\n        tensor([0., 1.], placement=oneflow.placement(type=\"cpu\", ranks=[0, 1]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32)\n\n    .. code-block:: python\n\n        >>> # results on rank 1\n        oneflow.Size([2])\n        tensor([0., 1.], placement=oneflow.placement(type=\"cpu\", ranks=[0, 1]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.to_global,\n    \"\"\"\n    Tensor.to_global(placement=None, sbp=None, **kwargs) -> Tensor\n\n    Creates a global tensor if this tensor is a local tensor, otherwise performs Tensor placement and/or sbp conversion.\n\n    Note:\n        This tensor can be a local tensor or a global tensor.\n\n        - For a local tensor\n\n          Both placement and sbp are required.\n\n          The returned global tensor takes this tensor as its local component in the current rank.\n\n          There is no data communication usually, but when sbp is ``oneflow.sbp.broadcast``, the data on rank 0 will be broadcast to other ranks.\n\n        - For a global tensor\n\n          At least one of placement and sbp is required.\n\n          If placement and sbp are both the same as this tensor's own placement and sbp, then this tensor itself is returned.\n\n    .. 
warning::\n        When the input tensor is a local tensor and sbp is ``oneflow.sbp.broadcast``, the data on the non-0 rank will be modified.\n        If you want to keep the input local tensor unchanged, please set the arg copy to True.\n\n    Args:\n        placement (flow.placement, optional): the desired placement of returned global tensor. Default: None\n        sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp of returned global tensor. Default: None\n    Keyword Args:\n        grad_sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): manually specify the sbp of this tensor's grad\n            tensor in the backward pass. If None, the grad tensor sbp will be inferred automatically. It is only used if this tensor is a\n            global tensor. Default: None\n        check_meta (bool, optional): indicates whether to check meta information. If set to True, check the input meta\n            information on each rank. Default: True if this tensor is a local tensor, False if this tensor is a global tensor\n        copy (bool, optional): When copy is set, a copy occurs in this operation. For a local tensor, the returned global tensor takes the\n            replication of this tensor as its local component in the current rank. For a global tensor, a new Tensor is created even when\n            the Tensor already matches the desired conversion. Default: False\n\n    For a local tensor:\n\n    .. code-block:: python\n\n        >>> # Run on 2 ranks respectively\n        >>> import oneflow as flow\n        >>> input = flow.tensor([0., 1.], dtype=flow.float32) # doctest: +SKIP\n        >>> output = input.to_global(placement=flow.placement(\"cpu\", ranks=[0, 1]), sbp=[flow.sbp.split(0)], check_meta=False) # doctest: +SKIP\n        >>> print(output.size()) # doctest: +SKIP\n        >>> print(output) # doctest: +SKIP\n\n    .. code-block:: python\n\n        >>> # results on rank 0\n        oneflow.Size([4])\n        tensor([0., 1., 0., 1.], placement=oneflow.placement(type=\"cpu\", ranks=[0, 1]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32)\n\n    .. code-block:: python\n\n        >>> # results on rank 1\n        oneflow.Size([4])\n        tensor([0., 1., 0., 1.], placement=oneflow.placement(type=\"cpu\", ranks=[0, 1]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32)\n\n    For a global tensor:\n\n    .. code-block:: python\n\n        >>> # Run on 2 ranks respectively\n        >>> import oneflow as flow\n        >>> input = flow.tensor([0., 1.], dtype=flow.float32, placement=flow.placement(\"cpu\", ranks=[0, 1]), sbp=[flow.sbp.broadcast]) # doctest: +SKIP\n        >>> output = input.to_global(placement=flow.placement(\"cpu\", ranks=[0, 1]), sbp=[flow.sbp.split(0)]) # doctest: +SKIP\n        >>> print(output.size()) # doctest: +SKIP\n        >>> print(output) # doctest: +SKIP\n\n    .. code-block:: python\n\n        >>> # results on rank 0\n        oneflow.Size([2])\n        tensor([0., 1.], placement=oneflow.placement(type=\"cpu\", ranks=[0, 1]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32)\n\n    .. 
code-block:: python\n\n        >>> # results on rank 1\n        oneflow.Size([2])\n        tensor([0., 1.], placement=oneflow.placement(type=\"cpu\", ranks=[0, 1]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.to_consistent,\n    \"\"\"\n    This interface is no longer available, please use :func:`oneflow.Tensor.to_global` instead.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.to_local,\n    \"\"\"\n    Tensor.to_local(**kwargs) -> Tensor\n\n    Returns the local component of this global tensor in the current rank.\n\n    Keyword Args:\n        copy (bool, optional): When copy is set, a new replicated tensor of the local component of this global tensor in the current rank is returned. Default: False\n\n    Note:\n        This tensor should be a global tensor, and it returns an empty tensor if there is no local component in the current rank.\n\n        No copy occurs in this operation if copy is not set.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> # Run on 2 ranks respectively\n        >>> import oneflow as flow\n        >>> x = flow.tensor([0., 1.], dtype=flow.float32, placement=flow.placement(\"cpu\", ranks=[0, 1]), sbp=[flow.sbp.split(0)]) # doctest: +SKIP\n        >>> y = x.to_local() # doctest: +SKIP\n        >>> print(y.size()) # doctest: +SKIP\n        >>> print(y) # doctest: +SKIP\n\n    .. code-block:: python\n\n        >>> # results on rank 0\n        oneflow.Size([1])\n        tensor([0.], dtype=oneflow.float32)\n\n    .. code-block:: python\n\n        >>> # results on rank 1\n        oneflow.Size([1])\n        tensor([1.], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.transpose,\n    \"\"\"\n    See :func:`oneflow.transpose`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.logical_not,\n    \"\"\"\n    logical_not() -> Tensor\n    See :func:`oneflow.logical_not`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.lerp,\n    \"\"\"\n    See :func:`oneflow.lerp`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.lerp_,\n    \"\"\"\n    See :func:`oneflow.lerp_`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.quantile,\n    \"\"\"\n    See :func:`oneflow.quantile`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.sqrt,\n    \"\"\"\n    See :func:`oneflow.sqrt`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.square,\n    \"\"\"\n    See :func:`oneflow.square`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.std,\n    \"\"\"\n    See :func:`oneflow.std`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.var,\n    \"\"\"\n    See :func:`oneflow.var`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.squeeze,\n    \"\"\"\n    Tensor.squeeze(dim=None) -> Tensor\n    See :func:`oneflow.squeeze`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.squeeze_,\n    \"\"\"\n    Tensor.squeeze_(dim=None) -> Tensor\n    In-place version of :func:`oneflow.Tensor.squeeze`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.unfold,\n    \"\"\"\n    Returns a view of the original tensor which contains all slices of size `size` from `self`\n    tensor in the dimension `dimension`.\n\n    The step between two slices is given by `step`.\n\n    If `sizedim` is the size of dimension `dimension` for `self`, the size of dimension `dimension` in the\n    returned tensor will be `(sizedim - size) / step + 1`.\n\n    An additional dimension of size `size` is appended in the returned tensor.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: 
https://pytorch.org/docs/1.10/generated/torch.Tensor.unfold.html.\n\n    Args:\n        dimension (int): dimension in which unfolding happens\n        size (int): the size of each slice that is unfolded\n        step (int): the step between each slice\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x = flow.arange(1, 8)\n        >>> x\n        tensor([1, 2, 3, 4, 5, 6, 7], dtype=oneflow.int64)\n        >>> x.unfold(0, 2, 1)\n        tensor([[1, 2],\n                [2, 3],\n                [3, 4],\n                [4, 5],\n                [5, 6],\n                [6, 7]], dtype=oneflow.int64)\n        >>> x.unfold(0, 2, 2)\n        tensor([[1, 2],\n                [3, 4],\n                [5, 6]], dtype=oneflow.int64)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.matmul,\n    \"\"\"\n    See :func:`oneflow.matmul`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.mv,\n    \"\"\"\n    See :func:`oneflow.mv`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.mm,\n    \"\"\"\n    See :func:`oneflow.mm`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.narrow,\n    \"\"\"\n    See :func:`oneflow.narrow`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.unsqueeze,\n    \"\"\"\n    Tensor.unsqueeze(dim) -> Tensor\n\n    See :func:`oneflow.unsqueeze`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.unsqueeze_,\n    \"\"\"\n    Tensor.unsqueeze_(dim) -> Tensor\n\n    In-place version of :func:`oneflow.Tensor.unsqueeze`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.as_strided,\n    \"\"\"\n    Tensor.as_strided(size, stride, storage_offset=None) -> Tensor\n\n    See :func:`oneflow.as_strided`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.as_strided_,\n    \"\"\"\n    Tensor.as_strided_(size, stride, storage_offset=None) -> Tensor\n\n    In-place version of :func:`oneflow.Tensor.as_strided`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.permute,\n    \"\"\"\n    See :func:`oneflow.permute`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.abs,\n    \"\"\"\n    See :func:`oneflow.abs`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.acos,\n    \"\"\"\n    See :func:`oneflow.acos`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.arccos,\n    \"\"\"\n    See :func:`oneflow.arccos`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.acosh,\n    \"\"\"\n    See :func:`oneflow.acosh`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.arccosh,\n    \"\"\"\n    See :func:`oneflow.arccosh`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.arctanh,\n    \"\"\"\n    See :func:`oneflow.arctanh`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.argmax,\n    \"\"\"\n    See :func:`oneflow.argmax`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.argmin,\n    \"\"\"\n    See :func:`oneflow.argmin`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.argsort,\n    \"\"\"\n    See :func:`oneflow.argsort`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.argwhere,\n    \"\"\"\n    See :func:`oneflow.argwhere`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.atanh,\n    \"\"\"\n    See :func:`oneflow.atanh`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.backward,\n    \"\"\"\n    Computes the gradient of the current tensor w.r.t. graph leaves.\n\n    The graph is differentiated using the chain rule. If the tensor is non-scalar (i.e. its data has more than one element) and requires gradient, the function additionally requires specifying gradient. 
It should be a tensor of matching type and location that contains the gradient of the differentiated function w.r.t. self.\n\n    This function accumulates gradients in the leaves - you might need to zero .grad attributes or set them to None before calling it. See Default gradient layouts for details on the memory layout of accumulated gradients.\n\n    Note:\n        If you run any forward ops, create gradient, and/or call backward in a user-specified CUDA stream context, see Stream semantics of backward passes.\n    Note:\n        When inputs are provided and a given input is not a leaf, the current implementation will call its grad_fn (though it is not strictly needed to get these gradients). It is an implementation detail on which the user should not rely. See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.Tensor.backward.html.\n\n    Args:\n        gradient (Tensor or None): Gradient w.r.t. the tensor. If it is a tensor, it will be automatically converted to a Tensor that does not require grad unless create_graph is True. None values can be specified for scalar Tensors or ones that don’t require grad. If a None value would be acceptable then this argument is optional.\n\n        retain_graph (bool, optional): If False, the graph used to compute the grads will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.\n\n        create_graph (bool, optional): If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to False.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.grad,\n    r\"\"\"\n    Return the gradient calculated by autograd functions. This property is None by default.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.grad_fn,\n    r\"\"\"\n    Return the function that created this tensor if its ``requires_grad`` is True.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.inverse,\n    \"\"\"\n    See :func:`oneflow.linalg.inv`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.trunc,\n    \"\"\"\n    See :func:`oneflow.trunc`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.is_leaf,\n    r\"\"\"\n    All Tensors that have ``requires_grad`` which is ``False`` will be leaf Tensors by convention.\n\n    For Tensors that have ``requires_grad`` which is ``True``, they will be leaf Tensors if they\n    were created by source operations.\n\n    Only leaf Tensors will have their ``grad`` populated during a call to ``backward()``. To get\n    ``grad`` populated for non-leaf Tensors, you can use ``retain_grad()``.\n\n    Compatible with PyTorch.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> a = flow.rand(10, requires_grad=False)\n        >>> a.is_leaf\n        True\n        >>> a = flow.rand(10, requires_grad=True)\n        >>> a.is_leaf\n        True\n        >>> b = a.cuda()\n        >>> b.is_leaf\n        False\n        >>> c = a + 2\n        >>> c.is_leaf\n        False\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.requires_grad,\n    r\"\"\"\n    Is ``True`` if gradients need to be computed for this Tensor, ``False`` otherwise.\n\n    Compatible with PyTorch.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.requires_grad_,\n    r\"\"\"oneflow.Tensor.requires_grad_(requires_grad=True) -> Tensor\n    Sets this tensor’s requires_grad attribute in-place. Returns this tensor.\n\n    Compatible with PyTorch.\n\n    Args:\n        requires_grad (bool): Change the requires_grad flag for this Tensor. Default is ``True``.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> a = flow.rand(10, requires_grad=False)\n        >>> a.requires_grad\n        False\n        >>> a = a.requires_grad_(requires_grad=True)\n        >>> a.requires_grad\n        True\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.register_hook,\n    r\"\"\"oneflow.Tensor.register_hook(hook)\n\n    Registers a backward hook.\n\n    The hook will be called every time a gradient with respect to the Tensor is computed.\n    The hook should have the following signature:\n\n    .. code-block::\n\n        hook(grad) -> Tensor or None\n\n\n    The hook should not modify its argument, but it can optionally return a new gradient which\n    will be used in place of ``grad``.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.ones(5, requires_grad=True)\n        >>> def hook(grad):\n        ...     return grad * 2\n        >>> x.register_hook(hook)\n        >>> y = x * 2\n        >>> y.sum().backward()\n        >>> x.grad\n        tensor([4., 4., 4., 4., 4.], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.retain_grad,\n    r\"\"\"\n    Enables this Tensor to have its ``grad`` populated during ``backward()``. 
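This is a no-op for leaf tensors.\n\n    Compatible with PyTorch.\n\n    For example (a minimal sketch; the printed gradient is all ones since\n    ``d(sum(y))/dy = 1``):\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.ones(3, requires_grad=True)\n        >>> y = x * 2\n        >>> y.retain_grad()\n        >>> y.sum().backward()\n        >>> y.grad\n        tensor([1., 1., 1.], dtype=oneflow.float32)\n    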
\"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.bmm,\n    \"\"\"\n    See :func:`oneflow.bmm`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.chunk,\n    \"\"\"\n    See :func:`oneflow.chunk`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.split,\n    \"\"\"\n    See :func:`oneflow.split`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.unbind,\n    \"\"\"\n    See :func:`oneflow.unbind`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.swapaxes,\n    \"\"\"\n    See :func:`oneflow.swapaxes`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.amax,\n    \"\"\"\n    See :func:`oneflow.amax`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.swapdims,\n    \"\"\"\n    See :func:`oneflow.swapdims`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.cast,\n    \"\"\"\n    See :func:`oneflow.cast`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.diag,\n    \"\"\"\n    See :func:`oneflow.diag`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.addcdiv,\n    \"\"\"\n    See :func:`oneflow.addcdiv`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.addcdiv_,\n    \"\"\"\n    In-place version of :func:`oneflow.Tensor.addcdiv`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.dim,\n    \"\"\"\n    Tensor.dim() → int\n\n    Returns the number of dimensions of self tensor.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.element_size,\n    \"\"\"\n    Tensor.element_size() → int\n\n    Returns the size in bytes of an individual element.\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.exp,\n    \"\"\"\n    See :func:`oneflow.exp`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.exp2,\n    \"\"\"\n    See :func:`oneflow.exp2`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.erf,\n    \"\"\"\n    Tensor.erf() -> Tensor\n\n    See :func:`oneflow.erf`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.erfc,\n    \"\"\"\n    Tensor.erfc() -> Tensor\n\n    See :func:`oneflow.erfc`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.erfinv,\n    \"\"\"\n    See :func:`oneflow.erfinv`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.erfinv_,\n    \"\"\"\n    In-place version of :func:`oneflow.erfinv`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.eq,\n    \"\"\"\n    See :func:`oneflow.eq`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.equal,\n    \"\"\"\n    See :func:`oneflow.equal`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.lt,\n    \"\"\"\n    See :func:`oneflow.lt`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.le,\n    \"\"\"\n    See :func:`oneflow.le`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.ne,\n    \"\"\"\n    See :func:`oneflow.ne`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.neg,\n    \"\"\"\n    See :func:`oneflow.neg`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.norm,\n    \"\"\"\n    See :func:`oneflow.norm`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.fill_,\n    \"\"\"\n    Tensor.fill_(value) → Tensor\n\n    Fills `self` tensor with the specified value.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.ge,\n    \"\"\"\n    See :func:`oneflow.ge`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.get_device,\n    \"\"\"\n    Tensor.get_device() -> Device ordinal (Integer)\n\n    For CUDA tensors, this function returns the device ordinal of the GPU on which the tensor resides. 
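For CPU tensors, an error is thrown.\n\n    For example (a minimal sketch; it requires a CUDA-enabled build, so the\n    lines are skipped under doctest):\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1., 2.], device=flow.device(\"cuda\")) # doctest: +SKIP\n        >>> x.get_device() # doctest: +SKIP\n        0\n    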
\"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.gt,\n    \"\"\"\n    See :func:`oneflow.gt`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.gt_,\n    \"\"\"Tensor.gt_(value) -> Tensor\n    In-place version of :func:`oneflow.Tensor.gt`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.log1p,\n    \"\"\"\n    See :func:`oneflow.log1p`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.mish,\n    \"\"\"\n    See :func:`oneflow.mish`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.mul,\n    \"\"\"Tensor.mul(value) -> Tensor\n    See :func:`oneflow.mul`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.mul_,\n    \"\"\"Tensor.mul_(value) -> Tensor\n\n    In-place version of :func:`oneflow.Tensor.mul`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.div_,\n    \"\"\"Tensor.div_(value) -> Tensor\n    In-place version of :func:`oneflow.Tensor.div`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.sub_,\n    \"\"\"Tensor.sub_(value) -> Tensor\n    In-place version of :func:`oneflow.Tensor.sub`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.negative,\n    \"\"\"\n    See :func:`oneflow.negative`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.nelement,\n    \"\"\"\n    Tensor.nelement() → int\n\n    Alias for numel()\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.normal_,\n    \"\"\"\n    normal_(mean=0, std=1, *, generator=None) -> Tensor\n\n    Fills :attr:`self` tensor with elements sampled from the normal distribution parameterized by :attr:`mean` and :attr:`std`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.numpy,\n    \"\"\"\n    Tensor.numpy() → numpy.ndarray\n\n    Returns self tensor as a NumPy ndarray. This tensor and the returned ndarray share the same underlying storage. Changes to\n    self tensor will be reflected in the ndarray and vice versa.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.pow,\n    \"\"\"\n    See :func:`oneflow.pow`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.relu,\n    \"\"\"\n    See :func:`oneflow.relu`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.roll,\n    \"\"\"\n    See :func:`oneflow.roll`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.round,\n    \"\"\"\n    See :func:`oneflow.round`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.round_,\n    \"\"\"\n    See :func:`oneflow.round_`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.reciprocal,\n    \"\"\"\n    See :func:`oneflow.reciprocal`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.add,\n    \"\"\"\n    See :func:`oneflow.add`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.addmm,\n    \"\"\"\n    See :func:`oneflow.addmm`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.add_,\n    \"\"\"\n    In-place version of :func:`oneflow.Tensor.add`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.addcmul,\n    \"\"\"\n    See :func:`oneflow.addcmul`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.addcmul_,\n    \"\"\"\n    In-place version of :func:`oneflow.Tensor.addcmul`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.asin,\n    \"\"\"\n    See :func:`oneflow.asin`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.asinh,\n    \"\"\"\n    See :func:`oneflow.asinh`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.arcsin,\n    \"\"\"\n    See :func:`oneflow.arcsin`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.arcsinh,\n    \"\"\"\n    See :func:`oneflow.arcsinh`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.sin,\n    \"\"\"\n    sin() -> Tensor\n\n    See :func:`oneflow.sin`\n\n    
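For example (a minimal, deterministic sketch; values are shown with the\n    standard 4-decimal tensor formatting used in this file):\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> flow.tensor([0., 0.5]).sin()\n        tensor([0.0000, 0.4794], dtype=oneflow.float32)\n    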
\"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.sin_,\n    \"\"\"\n    See :func:`oneflow.sin_`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.cos,\n    \"\"\"\n    See :func:`oneflow.cos`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.diagonal,\n    \"\"\"\n    See :func:`oneflow.diagonal`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.log,\n    \"\"\"\n    See :func:`oneflow.log`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.log2,\n    \"\"\"\n    See :func:`oneflow.log2`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.log10,\n    \"\"\"\n    See :func:`oneflow.log10`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.ndim,\n    \"\"\"\n    See :func:`oneflow.Tensor.dim`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.rsqrt,\n    \"\"\"\n    See :func:`oneflow.rsqrt`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.cosh,\n    \"\"\"\n    See :func:`oneflow.cosh`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.atan,\n    \"\"\"\n    See :func:`oneflow.atan`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.arctan,\n    \"\"\"\n    See :func:`oneflow.arctan`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.dot,\n    \"\"\"\n    See :func:`oneflow.dot`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.selu,\n    \"\"\"\n    See :func:`oneflow.selu`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.sigmoid,\n    \"\"\"\n    See :func:`oneflow.sigmoid`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.sign,\n    \"\"\"\n    See :func:`oneflow.sign`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.silu,\n    \"\"\"\n    See :func:`oneflow.silu`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.sinh,\n    \"\"\"\n    See :func:`oneflow.sinh`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.size,\n    \"\"\"\n    Returns the size of the self tensor. If dim is not specified, the returned value is a oneflow.Size, a subclass of tuple. If dim is specified, returns an int holding the size of that dimension.\n\n    The interface is consistent with PyTorch.\n\n    Args:\n        idx (int, optional): The dimension for which to retrieve the size.\n\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.softmax,\n    \"\"\"\n    See :func:`oneflow.softmax`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.softplus,\n    \"\"\"\n    See :func:`oneflow.softplus`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.softsign,\n    \"\"\"\n    See :func:`oneflow.softsign`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.tan,\n    \"\"\"\n    See :func:`oneflow.tan`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.tanh,\n    \"\"\"\n    See :func:`oneflow.tanh`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.tril,\n    \"\"\"\n    See :func:`oneflow.tril`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.triu,\n    \"\"\"\n    See :func:`oneflow.triu`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.uniform_,\n    \"\"\"\n    Tensor.uniform_(from=0, to=1) → Tensor\n\n    Fills self tensor with numbers sampled from the continuous uniform distribution:\n\n    .. math::\n        P(x)=1/(to-from)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.copy_,\n    \"\"\"\n    Copies the elements from src into self tensor and returns self.\n\n    The src tensor must be broadcastable with the self tensor. 
\"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.to,\n    \"\"\"Performs Tensor dtype and/or device conversion.\n        A flow.dtype and flow.device are inferred from the arguments of `input.to(*args, **kwargs)`.\n\n    .. note::\n        If the ``input`` Tensor already\n        has the correct :class:`flow.dtype` and :class:`flow.device`, then ``input`` is returned.\n        Otherwise, the returned tensor is a copy of ``input`` with the desired :class:`flow.dtype` and :class:`flow.device`.\n\n    Args:\n        input (oneflow.Tensor): An input tensor.\n        *args (oneflow.Tensor or oneflow.device or oneflow.dtype): Positional arguments\n        **kwargs (oneflow.device or oneflow.dtype): Key-value arguments\n\n    Returns:\n        oneflow.Tensor: A Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> arr = np.random.randint(1, 9, size=(1, 2, 3, 4))\n        >>> input = flow.Tensor(arr)\n        >>> output = input.to(dtype=flow.float32)\n        >>> np.array_equal(arr.astype(np.float32), output.numpy())\n        True\n\n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow.Tensor.half,\n    \"\"\"\n    self.half() is equivalent to self.to(dtype=oneflow.float16).\n\n    See :func:`oneflow.Tensor.to`\n\n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow.Tensor.gather,\n    \"\"\"\n    oneflow.Tensor.gather(dim, index) -> Tensor\n\n    See :func:`oneflow.gather`\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.clamp,\n    \"\"\"\n    See :func:`oneflow.clamp`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.clamp_,\n    \"\"\"\n    In-place version of :func:`oneflow.Tensor.clamp`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.clip,\n    \"\"\"\n    Alias for :func:`oneflow.Tensor.clamp`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.clip_,\n    \"\"\"\n    Alias for :func:`oneflow.Tensor.clamp_`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.cpu,\n    r\"\"\"Returns a copy of this object in CPU memory.\n    If this object is already in CPU memory and on the correct device, then no copy is performed and the original object is returned.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> input = flow.tensor([1, 2, 3, 4, 5], device=flow.device(\"cuda\"))\n        >>> output = input.cpu()\n        >>> output.device\n        device(type='cpu', index=0)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.cuda,\n    r\"\"\"Returns a copy of this object in CUDA memory.\n    If this object is already in CUDA memory and on the correct device, then no copy is performed and the original object is returned.\n\n    Args:\n        device  (flow.device): The destination GPU device. Defaults to the current CUDA device.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> input = flow.Tensor([1, 2, 3, 4, 5])\n        >>> output = input.cuda()\n        >>> output.device\n        device(type='cuda', index=0)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.cumprod,\n    \"\"\"\n    See :func:`oneflow.cumprod`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.cumsum,\n    \"\"\"\n    See :func:`oneflow.cumsum`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.repeat,\n    \"\"\"\n    Tensor.repeat(*size) -> Tensor\n\n    See :func:`oneflow.repeat`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.repeat_interleave,\n    \"\"\"\n    Tensor.repeat_interleave(repeats, dim=None, *, output_size=None) -> Tensor\n\n    See :func:`oneflow.repeat_interleave`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.t,\n    \"\"\"\n    Tensor.t() → Tensor\n\n    See :func:`oneflow.t`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.tile,\n    \"\"\"\n    Tensor.tile(*dims) -> Tensor\n\n    See :func:`oneflow.tile`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.T,\n    \"\"\"\n    Returns this Tensor with its dimensions reversed.\n\n    If `n` is the number of dimensions in `x`, `x.T` is equivalent to `x.permute(n-1, n-2, ..., 0)`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.fmod,\n    \"\"\"\n    Tensor.fmod(other) -> Tensor\n\n    See :func:`oneflow.fmod`\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.logical_and,\n    \"\"\"\n    logical_and() -> Tensor\n\n    See :func:`oneflow.logical_and`\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.logical_or,\n    \"\"\"\n    logical_or() -> Tensor\n\n    See :func:`oneflow.logical_or`\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.logical_xor,\n    \"\"\"\n    logical_xor() -> Tensor\n\n    See :func:`oneflow.logical_xor`\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.logsumexp,\n    \"\"\"\n    See :func:`oneflow.logsumexp`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.masked_fill,\n    \"\"\"\n    See :func:`oneflow.masked_fill`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.masked_fill_,\n    \"\"\"\n    In-place version of :meth:`oneflow.Tensor.masked_fill`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.masked_select,\n    \"\"\"\n    See :func:`oneflow.masked_select`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.sub,\n    \"\"\"\n    See :func:`oneflow.sub`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.div,\n    \"\"\"\n    See :func:`oneflow.div`\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.ceil,\n    \"\"\"\n    See :func:`oneflow.ceil`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.ceil_,\n    \"\"\"\n    See :func:`oneflow.ceil_`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.expm1,\n    \"\"\"\n    See :func:`oneflow.expm1`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.topk,\n    \"\"\"\n    See :func:`oneflow.topk`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.nms,\n    \"\"\"\n    See :func:`oneflow.nms`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.nonzero,\n    \"\"\"\n    nonzero(input, as_tuple=False) -> Tensor\n\n    See :func:`oneflow.nonzero`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.max,\n    \"\"\"\n    input.max(dim=None, keepdim=False) -> Tensor\n\n    See :func:`oneflow.max`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.min,\n    \"\"\"\n    input.min(dim=None, keepdim=False) -> Tensor\n\n    See :func:`oneflow.min`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.maximum,\n    \"\"\"\n    See :func:`oneflow.maximum`\n    \"\"\",\n)\n\nadd_docstr(\n    
oneflow.Tensor.median,\n    \"\"\"\n    See :func:`oneflow.median`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.minimum,\n    \"\"\"\n    See :func:`oneflow.minimum`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.mode,\n    \"\"\"\n    See :func:`oneflow.mode`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.sum,\n    \"\"\"\n    input.sum(dim=None, keepdim=False) -> Tensor\n\n    See :func:`oneflow.sum`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.all,\n    \"\"\"\n    input.all(dim=None, keepdim=False) -> Tensor\n\n    See :func:`oneflow.all`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.any,\n    \"\"\"\n    input.any(dim=None, keepdim=False) -> Tensor\n\n    See :func:`oneflow.any`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.mean,\n    \"\"\"\n    input.mean(dim=None, keepdim=False) -> Tensor\n\n    See :func:`oneflow.mean`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.prod,\n    \"\"\"\n    input.prod(dim=None, keepdim=False) -> Tensor\n\n    See :func:`oneflow.prod`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.reshape,\n    \"\"\"\n    See :func:`oneflow.reshape`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.reshape_as,\n    \"\"\"\n    Tensor.reshape_as(other) -> Tensor\n\n    Returns this tensor reshaped to the same shape as ``other``.\n    ``self.reshape_as(other)`` is equivalent to ``self.reshape(other.sizes())``.\n    This method returns a view if ``other.sizes()`` is compatible with the current shape.\n    See :func:`oneflow.Tensor.view` on when it is possible to return a view.\n\n    Please see :func:`oneflow.reshape` for more information about ``reshape``.\n\n    Args:\n        other (oneflow.Tensor): The result tensor has the same shape as ``other``.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.view,\n    \"\"\"\n    Returns a new tensor with the same data as the :attr:`self` tensor but of a\n    different :attr:`shape`.\n\n    The returned tensor shares the same data and must have the same number\n    of elements, but may have a different size. For a tensor to be viewed, the new\n    view size must be compatible with its original size and stride, i.e., each new\n    view dimension must either be a subspace of an original dimension, or only span\n    across original dimensions :math:`d, d+1, \\\\dots, d+k` that satisfy the following\n    contiguity-like condition that :math:`\\\\forall i = d, \\\\dots, d+k-1`,\n\n    .. math::\n\n      \\\\text{stride}[i] = \\\\text{stride}[i+1] \\\\times \\\\text{size}[i+1]\n\n    Otherwise, it will not be possible to view :attr:`self` tensor as :attr:`shape`\n    without copying it (e.g., via :meth:`contiguous`). When it is unclear whether a\n    :meth:`view` can be performed, it is advisable to use :meth:`reshape`, which\n    returns a view if the shapes are compatible, and copies (equivalent to calling\n    :meth:`contiguous`) otherwise.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.Tensor.view.html.\n\n    Args:\n        input: A Tensor.\n        *shape: flow.Size or int...\n    Returns:\n        A Tensor with the same type as `input`.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x = np.array(\n        ...    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]\n        ... 
).astype(np.float32)\n        >>> input = flow.Tensor(x)\n\n        >>> y = input.view(2, 2, 2, -1).numpy().shape\n        >>> y\n        (2, 2, 2, 2)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.view_as,\n    \"\"\"\n    Tensor.view_as(other) -> Tensor\n\n    View this tensor as the same size as :attr:`other`.\n    ``self.view_as(other)`` is equivalent to ``self.view(other.size())``.\n\n    Please see :meth:`~Tensor.view` for more information about ``view``.\n\n    Args:\n        other (:class:`oneflow.Tensor`): The result tensor has the same size\n            as :attr:`other`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.sort,\n    \"\"\"\n    See :func:`oneflow.sort`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.type_as,\n    r\"\"\"Returns this tensor cast to the type of the given tensor.\n        This is a no-op if the tensor is already of the correct type.\n\n    Args:\n        input  (Tensor): the input tensor.\n        target (Tensor): the tensor which has the desired type.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> input = flow.tensor(np.random.randn(1, 2, 3), dtype=flow.float32)\n        >>> target = flow.tensor(np.random.randn(4, 5, 6), dtype = flow.int32)\n        >>> input = input.type_as(target)\n        >>> input.dtype\n        oneflow.int32\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.bool,\n    r\"\"\"``Tensor.bool()`` is equivalent to ``Tensor.to(oneflow.bool)``. See :func:`oneflow.Tensor.to`.\n\n    Args:\n        input  (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> input = flow.tensor(np.random.randn(1, 2, 3), dtype=flow.float32)\n        >>> input = input.bool()\n        >>> input.dtype\n        oneflow.bool\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.int,\n    r\"\"\"``Tensor.int()`` is equivalent to ``Tensor.to(flow.int32)``. See :func:`oneflow.Tensor.to`.\n\n    Args:\n        input  (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> input = flow.tensor(np.random.randn(1, 2, 3), dtype=flow.float32)\n        >>> input = input.int()\n        >>> input.dtype\n        oneflow.int32\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.long,\n    r\"\"\"``Tensor.long()`` is equivalent to ``Tensor.to(flow.int64)``. See :func:`oneflow.Tensor.to`.\n\n    Args:\n        input  (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> input = flow.tensor(np.random.randn(1, 2, 3), dtype=flow.float32)\n        >>> input = input.long()\n        >>> input.dtype\n        oneflow.int64\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.float,\n    r\"\"\"``Tensor.float()`` is equivalent to ``Tensor.to(flow.float32)``. See :func:`oneflow.Tensor.to`.\n\n    Args:\n        input  (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> input = flow.tensor(np.random.randn(1, 2, 3), dtype=flow.int)\n        >>> input = input.float()\n        >>> input.dtype\n        oneflow.float32\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.double,\n    r\"\"\"``Tensor.double()`` is equivalent to ``Tensor.to(flow.float64)``. 
See :func:`oneflow.Tensor.to`.\n\n    Args:\n        input  (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> input = flow.tensor(np.random.randn(1, 2, 3), dtype=flow.int)\n        >>> input = input.double()\n        >>> input.dtype\n        oneflow.float64\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.is_contiguous,\n    r\"\"\"\n    Tensor.is_contiguous() -> bool\n\n    Returns True if `self` tensor is contiguous in memory.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.is_cuda,\n    r\"\"\"\n    Tensor.is_cuda -> bool\n\n    Is `True` if the Tensor is stored on the GPU, `False` otherwise.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.is_floating_point,\n    \"\"\"\n    See :func:`oneflow.is_floating_point`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.item,\n    r\"\"\"Returns the value of this tensor as a standard Python number. This only works for tensors with one element.\n    For other cases, see :func:`oneflow.Tensor.tolist`.\n\n    This operation is not differentiable.\n\n    Args:\n        input  (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1.0])\n        >>> x.item()\n        1.0\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.tolist,\n    r\"\"\"Returns the tensor as a (nested) list. For scalars, a standard Python number is returned,\n    just like with `item()`. Tensors are automatically moved to the CPU first if necessary.\n\n    This operation is not differentiable.\n\n    Args:\n        input  (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> input = flow.tensor([[1,2,3], [4,5,6]])\n        >>> input.tolist()\n        [[1, 2, 3], [4, 5, 6]]\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.where,\n    \"\"\"\n    See :func:`oneflow.where`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.zero_,\n    r\"\"\"\n    Tensor.zero_() -> Tensor\n\n    Fills `self` tensor with zeros.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.isnan,\n    \"\"\"\n    See :func:`oneflow.isnan`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.isinf,\n    \"\"\"\n    See :func:`oneflow.isinf`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.byte,\n    \"\"\"\n    self.byte() is equivalent to self.to(oneflow.uint8).\n    See :func:`oneflow.Tensor.to`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.amin,\n    \"\"\"\n    See :func:`oneflow.amin`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.pin_memory,\n    r\"\"\"\n    Tensor.pin_memory() -> Tensor\n\n    Copies the tensor to pinned memory, if it’s not already pinned.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.is_pinned,\n    r\"\"\"\n    Tensor.is_pinned() -> bool\n\n    Returns True if this tensor resides in pinned memory.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.type,\n    r\"\"\"\n    type(dtype=None, non_blocking=False, **kwargs) -> str or Tensor\n\n    Returns the type if dtype is not provided, else casts this object to the specified type.\n\n    If this is already of the correct type, no copy is performed and the original object is returned.\n\n    Args:\n        dtype (oneflow.dtype or oneflow.tensortype or string, optional): The desired type.\n        non_blocking (bool): (**Not Implemented yet**) If True, and the source is in pinned memory\n            and destination is on the GPU or vice versa, the 
copy is performed asynchronously with respect to the host.\n            Otherwise, the argument has no effect.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> a = flow.tensor([1, 2], dtype=flow.float32)\n        >>> a.type()\n        'oneflow.FloatTensor'\n        >>> a.type(flow.int8)  # dtype input\n        tensor([1, 2], dtype=oneflow.int8)\n        >>> a.type(flow.cuda.DoubleTensor)  # tensortype input\n        tensor([1., 2.], device='cuda:0', dtype=oneflow.float64)\n        >>> a.type(\"oneflow.HalfTensor\")  # string input\n        tensor([1., 2.], dtype=oneflow.float16)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.scatter,\n    \"\"\"\n    See :func:`oneflow.scatter`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.scatter_,\n    \"\"\"\n    In-place version of :func:`oneflow.Tensor.scatter`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.scatter_add,\n    \"\"\"\n    See :func:`oneflow.scatter_add`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.scatter_add_,\n    \"\"\"\n    In-place version of :func:`oneflow.Tensor.scatter_add`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.cross,\n    \"\"\"\n    See :func:`oneflow.cross`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.nansum,\n    \"\"\"\n    See :func:`oneflow.nansum`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1., 2., float(\"nan\")])\n        >>> x.nansum()\n        tensor(3., dtype=oneflow.float32)\n        >>> x = flow.tensor([[1., float(\"nan\")], [float(\"nan\"), 2]])\n        >>> x.nansum(dim=1, keepdim=True)\n        tensor([[1.],\n                [2.]], dtype=oneflow.float32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.bincount,\n    \"\"\"\n    See :func:`oneflow.bincount`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.Tensor([0, 2, 3]).int()\n        >>> x.bincount()\n        tensor([1, 0, 1, 1], dtype=oneflow.int64)\n        >>> weight = flow.Tensor([0.1, 0.2, 0.3])\n        >>> x.bincount(weight)\n        tensor([0.1000, 0.0000, 0.2000, 0.3000], dtype=oneflow.float32)\n        >>> x.bincount(weight, minlength=5)\n        tensor([0.1000, 0.0000, 0.2000, 0.3000, 0.0000], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.bernoulli,\n    \"\"\"\n    See :func:`oneflow.bernoulli`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.Tensor([1, 1, 1])\n        >>> x.bernoulli()\n        tensor([1., 1., 1.], dtype=oneflow.float32)\n        >>> x.bernoulli(p=0.0)\n        tensor([0., 0., 0.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.bernoulli_,\n    \"\"\"\n    The in-place version of :func:`oneflow.Tensor.bernoulli`.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.Tensor([1, 1, 1])\n        >>> x.bernoulli_(p=0.0)\n        tensor([0., 0., 0.], dtype=oneflow.float32)\n        >>> x\n        tensor([0., 0., 0.], dtype=oneflow.float32)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.broadcast_to,\n    \"\"\"\n    See :func:`oneflow.broadcast_to`\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.unique,\n    \"\"\"\n    See :func:`oneflow.unique`\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([3, 1, 2, 0, 2])\n        >>> x.unique()\n        tensor([0, 1, 2, 3], dtype=oneflow.int64)\n        >>> x, indices = x.unique(return_inverse=True)\n        >>> indices\n        tensor([3, 1, 2, 0, 2], dtype=oneflow.int32)\n        >>> x, counts = x.unique(return_counts=True)\n        >>> counts\n        tensor([1, 1, 1, 1], dtype=oneflow.int32)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.clone,\n    \"\"\"\n    See :func:`oneflow.clone`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1, 2, 3])\n        >>> x.clone()\n        tensor([1, 2, 3], dtype=oneflow.int64)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.bitwise_and,\n    \"\"\"\n    See :func:`oneflow.bitwise_and`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1, 2, 3])\n        >>> x.bitwise_and(4)\n        tensor([0, 0, 0], dtype=oneflow.int64)\n        >>> y = flow.tensor([2, 1, 0])\n        >>> x.bitwise_and(y)\n        tensor([0, 0, 0], dtype=oneflow.int64)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.bitwise_or,\n    \"\"\"\n    See :func:`oneflow.bitwise_or`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1, 2, 3])\n        >>> x.bitwise_or(4)\n        tensor([5, 6, 7], dtype=oneflow.int64)\n        >>> y = flow.tensor([2, 1, 0])\n        >>> x.bitwise_or(y)\n        tensor([3, 3, 3], dtype=oneflow.int64)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.bitwise_xor,\n    \"\"\"\n    See :func:`oneflow.bitwise_xor`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1, 2, 3])\n        >>> x.bitwise_xor(4)\n        tensor([5, 6, 7], dtype=oneflow.int64)\n        >>> y = flow.tensor([2, 1, 0])\n        >>> x.bitwise_xor(y)\n        tensor([3, 3, 3], dtype=oneflow.int64)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.new,\n    \"\"\"\n    Constructs a new tensor of the same data type and device (or placement and sbp) as self tensor.\n\n    Any valid argument combination to the tensor constructor is accepted by this method,\n    including sizes, NumPy ndarray, Python Sequence, etc. See :func:`oneflow.Tensor` for more details.\n\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.randn(3, 2)\n        >>> x.new()\n        tensor([], dtype=oneflow.float32)\n        >>> x.new(1, 2).shape\n        oneflow.Size([1, 2])\n        >>> x.new([1, 2])\n        tensor([1., 2.], dtype=oneflow.float32)\n        >>> y = flow.randn(3, 3)\n        >>> x.new(y).shape\n        oneflow.Size([3, 3])\n\n    .. warning::\n        When ``y`` is a global tensor, invoking ``Tensor.new(y)`` will raise an error.\n        Consider using ``Tensor.new(y.size())`` to create a tensor that has\n        the same placement and sbp as this tensor and the same size as ``y``.\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.baddbmm,\n    \"\"\"\n    See :func:`oneflow.baddbmm`\n\n    For example:\n\n    .. 
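code-block:: python\n\n        >>> # A deterministic, shape-only sketch: with all-zero inputs only the\n        >>> # result shape is checked (default alpha=1, beta=1 are assumed).\n        >>> import oneflow as flow\n        >>> x = flow.zeros(2, 3, 4)\n        >>> x.baddbmm(flow.zeros(2, 3, 5), flow.zeros(2, 5, 4)).shape\n        oneflow.Size([2, 3, 4])\n\n    .. 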
code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.randn(2, 3, 4)\n        >>> batch1 = flow.randn(2, 3, 5)\n        >>> batch2 = flow.randn(2, 5, 4)\n        >>> x.baddbmm(batch1, batch2, alpha=2, beta=2) # doctest: +SKIP\n    \"\"\",\n)\n\n\nadd_docstr(\n    oneflow.Tensor.frac,\n    r\"\"\"\n    See :func:`oneflow.frac`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.frac_,\n    r\"\"\"\n    In-place version of :func:`oneflow.Tensor.frac`.\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.Tensor.digamma,\n    \"\"\"\n    See :func:`oneflow.digamma`\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/tensor_attributes.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr, reset_docstr\n\noneflow.device.__doc__ = r\"\"\"\n    A :class:`oneflow.device` is an object representing the device on which a :class:`oneflow.Tensor` is or will be allocated.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/tensor_attributes.html#torch.torch.device.\n    \n    The :class:`oneflow.device` contains a device type ('cpu' or 'cuda') and optional device ordinal for the device type. If the \n    device ordinal is not present, this object will always represent the current device for the device type.\n\n    A :class:`oneflow.device`’s device can be accessed via the Tensor.device property.\n\n    A :class:`oneflow.device` can be constructed via a string or via a string and device ordinal\n\n    Via a string:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> flow.device('cuda:0')\n        device(type='cuda', index=0)\n\n        >>> flow.device('cpu')\n        device(type='cpu', index=0)\n\n        >>> flow.device('cuda')  # current cuda device\n        device(type='cuda', index=0)\n    \n    Via a string and device ordinal:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> flow.device('cuda', 0)\n        device(type='cuda', index=0)\n\n        >>> flow.device('cpu', 0)\n        device(type='cpu', index=0)\n    \n    Note:\n        The :class:`oneflow.device` argument in functions can generally be substituted with a string. This allows for fast prototyping of code.\n        \n        .. code-block:: python\n\n            >>> import oneflow as flow\n            >>> # Example of a function that takes in a oneflow.device\n            >>> cuda0 = flow.device('cuda:0')\n            >>> x = flow.randn(2,3, device=cuda0)\n        \n        .. code-block:: python\n\n            >>> # You can substitute the flow.device with a string\n            >>> x = flow.randn(2,3, device='cuda:0')\n\n\"\"\"\n\n\noneflow.placement.__doc__ = r\"\"\"\n    A ``oneflow.placement`` is an object representing the device group on which a :class:`oneflow.Tensor` is or will be allocated. The ``oneflow.placement`` contains a device type ('cpu' or 'cuda') and corresponding device sequence.\n    \n    A :class:`oneflow.Tensor`'s placement can be accessed via the Tensor.placement property.\n    \n    A oneflow.placement can be constructed in several ways:\n    \n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> p = flow.placement(type=\"cuda\", ranks=[0, 1, 2, 3])\n        >>> p\n        oneflow.placement(type=\"cuda\", ranks=[0, 1, 2, 3])\n        >>> p = flow.placement(type=\"cuda\", ranks=[[0, 1], [2, 3]])\n        >>> p\n        oneflow.placement(type=\"cuda\", ranks=[[0, 1], [2, 3]])\n\n    \"\"\"\n\nreset_docstr(\n    oneflow.placement.all,\n    r\"\"\"\n    oneflow.placement.all(device_type) -> oneflow.placement\n\n    Returns a placement that contains all available devices.\n\n    Args:\n        device_type (str): cuda or cpu\n\n    For example:\n\n    .. code-block:: python\n\n        # Runs on 4 ranks\n        import oneflow as flow\n\n        p = flow.placement.all(\"cuda\") # oneflow.placement(type=\"cuda\", ranks=[0, 1, 2, 3])\n        p = flow.placement.all(\"cpu\") # oneflow.placement(type=\"cpu\", ranks=[0, 1, 2, 3])\n\n    \"\"\",\n)\n\noneflow.sbp.sbp.__doc__ = r\"\"\"\n    A ``oneflow.sbp`` is an object representing how the data of the global tensor is distributed across the ranks of the ``Tensor`` placement.\n\n    ``oneflow.sbp`` includes three types:\n\n        - oneflow.sbp.split(dim)\n\n          Indicates that the global tensor is evenly divided according to the dimension `dim` and distributed on each rank.\n\n        - oneflow.sbp.broadcast()\n\n          Indicates that the global tensor is replicated on each rank.\n\n        - oneflow.sbp.partial_sum()\n\n          Indicates that the value of the global tensor is the element-wise sum of the local tensors distributed in each rank.\n\n\n    A :class:`oneflow.Tensor`'s sbp can be accessed via the Tensor.sbp property.\n\n    A ``oneflow.sbp`` can be constructed in several ways:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> s = flow.sbp.split(0)\n        >>> s\n        oneflow.sbp.split(dim=0)\n        >>> b = flow.sbp.broadcast()\n        >>> b\n        oneflow.sbp.broadcast\n        >>> p = flow.sbp.partial_sum()\n        >>> p\n        oneflow.sbp.partial_sum\n    \"\"\"\n"
  },
  {
    "path": "python/oneflow/framework/docstr/tensor_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.narrow,\n    r\"\"\"\n    narrow(x, dim: int, start: int, length: int) -> Tensor\n    \n    Returns a new tensor that is a narrowed version of `input` tensor.\n    The dimension `dim` is input from `start` to `start + length`.\n\n    Args:\n        input: the tensor to narrow.\n        dim: the dimension along which to narrow.\n        start: the starting dimension.\n        length: the distance to the ending dimension.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> input = flow.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n        >>> flow.narrow(input, 0, 0, 2)\n        tensor([[1, 2, 3],\n                [4, 5, 6]], dtype=oneflow.int64)\n        >>> flow.narrow(input, 1, 1, 2)\n        tensor([[2, 3],\n                [5, 6],\n                [8, 9]], dtype=oneflow.int64)\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.unsqueeze,\n    r\"\"\"\n    unsqueeze(input, dim) -> Tensor\n    \n    Returns a new tensor with a dimension of size one inserted at the\n    specified position.\n\n    The returned tensor shares the same underlying data with this tensor.\n\n    A :attr:`dim` value within the range `[-input.ndimension() - 1, input.ndimension() + 1)`\n    can be used. Negative :attr:`dim` will correspond to :meth:`unsqueeze`\n    applied at :attr:`dim` = ``dim + input.ndimension() + 1``.\n\n    Args:\n        input (Tensor): the input tensor.\n        dim (int): the index at which to insert the singleton dimension\n\n    For example: \n\n    .. code-block:: python \n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> x = flow.randn(2, 3, 4)\n        >>> y = x.unsqueeze(2)\n        >>> y.shape\n        oneflow.Size([2, 3, 1, 4])\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.permute,\n    r\"\"\"\n    permute(input, *dims) -> Tensor\n\n    Returns a view of the original tensor with its dimensions permuted.\n\n    Args:\n        dims (tuple of ints): The desired ordering of dimensions\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> input = flow.tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)\n        >>> output = flow.permute(input, (1, 0, 2, 3)).shape\n        >>> output\n        oneflow.Size([6, 2, 5, 3])\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/tensor_t.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.t,\n    \"\"\"\n    oneflow.t(input) → Tensor.\n\n        Expects `input` to be <= 2-D tensor and transposes dimensions 0 and 1. \n\n        0-D and 1-D tensors are returned as is. When input is a 2-D tensor this is equivalent to `transpose(input, 0, 1)`.\n\n    Args:\n        input (oneflow.Tensor): An input tensor.   \n \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> x = flow.tensor(np.random.randn(), dtype=flow.float32)\n        >>> flow.t(x).shape\n        oneflow.Size([])\n        >>> x = flow.tensor(np.random.randn(3), dtype=flow.float32)\n        >>> flow.t(x).shape\n        oneflow.Size([3])\n        >>> x = flow.tensor(np.random.randn(2,3), dtype=flow.float32)\n        >>> flow.t(x).shape\n        oneflow.Size([3, 2])\n    \n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/tensordot.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.tensordot,\n    r\"\"\"\n    tensordot(a, b, dims=Union[int, Tensor, Tuple[List[int], List[int]], List[List[int]]], out=None) -> Tensor\n    \n    Compute tensor dot along given dimensions.\n    \n    Given two tensors a and b, and dims which represent two lists containing dim indices, `tensordot` traverses the two\n    lists and calculate the tensor dot along every dim pair.\n\n    Args:\n        a (oneflow.Tensor): The input tensor to compute tensordot\n        b (oneflow.Tensor): The input tensor to compute tensordot\n        dims (int or list or tuple or oneflow.Tensor):\n            The dims to calculate tensordot.\n            If it's an integer or oneflow.Tensor with only one element,\n            the last dims of tensor `a` and the first dims of tensor `b` will be calculated.\n            If it's a list or tuple or oneflow.Tensor with more than one element,\n            it must contain two array-like object, which represent the dims of tensor a and tensor b to be calculated.\n        out (oneflow.Tensor): The tensor to save result (NOT IMPLEMENTED YET)\n        \n    Returns:\n        oneflow.Tensor: The result tensor\n\n    For example:\n    \n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> a = flow.randn(3, 4, 5)\n        >>> b = flow.randn(4, 5, 6)\n        >>> flow.tensordot(a, b, dims=2).shape\n        oneflow.Size([3, 6])\n        >>> b = flow.randn(5, 6, 7)\n        >>> flow.tensordot(a, b, dims=1).shape\n        oneflow.Size([3, 4, 6, 7])\n        >>> b = flow.randn(3, 4, 7)\n        >>> flow.tensordot(a, b, dims=[[0, 1], [0, 1]]).shape\n        oneflow.Size([5, 7])\n    \n    Note:\n\n        Three common use cases are:\n\n        - dims = 0 : tensor product :math:`a \\otimes b`\n\n        - dims = 1 : tensor dot product :math:`a \\cdot b`\n\n        - dims = 2 : (default) tensor double contraction :math:`a : b`\n\n        Part of this documentation is referenced from https://numpy.org/doc/stable/reference/generated/numpy.tensordot.html.\n\n\n    Note:\n        The operation is equivalent to the series of operations:\n\n        - Permute the dimensions of the tensor A that require tensordot to the end\n\n        - Permute the dimensions of the tensor B that require tensordot to the start\n\n        - Reshape the permuted tensor A into a 2-dimensional tensor, where the size of the 0th dimension is the product of the dimensions that do not require dot product, and the size of the 1st dimension is the product of the dimensions that require dot product\n\n        - Reshape the permuted tensor B into a 2-dimensional tensor, where the size of the 0th dimension is the product of the dimensions that require dot product, and the size of the 1st dimension is the product of the dimensions that do not require dot product\n\n        - Calculate the matrix multiplication of reshaped tensor A and reshaped tensor B\n\n        - Reshape the result of matrix multiplication, the target shape is the concatenation of the dimensions that do not require tensordot of tensor A and B\n\n    This series of operations can be equivalently represented by the following code:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> a = flow.randn(2, 4, 3)\n        >>> b = flow.randn(3, 4, 2)\n        >>> dims = [[0, 2], [2, 0]]\n        >>> permuted_a = a.permute(1, 0, 2) # 0, 2 are the dimensions requiring tensordot and are placed at the end in permuting\n        >>> permuted_b = b.permute(2, 0, 1) # 2, 0 are the dimensions requiring tensordot and are placed at the beginning in permuting\n        >>> reshaped_a = permuted_a.reshape(4, 2 * 3) # 4 is the size of the dimension of a that does not require tensordot\n        >>> reshaped_b = permuted_b.reshape(2 * 3, 4) # 4 is the size of the dimension of b that does not require tensordot\n        >>> matmul_result = flow.matmul(reshaped_a, reshaped_b)\n        >>> result = matmul_result.reshape(4, 4) # 4, 4 are the concatenation of dimensions that do not require tensordot of a and b\n        >>> flow.all(result == flow.tensordot(a, b, dims))\n        tensor(True, dtype=oneflow.bool)\n\n    ..\n        Feature Stage of Operator [tensordot].\n        - Maintainer List [@marigoold]\n        - Current Stage [ ]\n        - Alpha Stage Check List [ ]\n          - API(Compatible with PyTorch 1.11, anything incompatible must be noted in API Doc.)[Yes]\n          - Doc(API Doc must be provided and shown normally on the web page.)[Yes]\n          - Functionality and its Test [ ]\n            - Functionality is highly compatible with PyTorch 1.11. 
[ ] (out parameter is not implemented yet)\n            - eager local [Yes]\n              - forward [Yes]\n              - backward [Yes]\n              - gpu [Yes]\n              - cpu [Yes]\n            - graph local [ ] (when the type of param `dims` is oneflow.Tensor, the tensor.item() will make graph fail)\n              - forward [ ]\n              - backward [ ]\n              - gpu [ ]\n              - cpu [ ]\n          - Exception Handling\n            - Exception Message and Hint must be provided [Yes]\n        - Beta Stage Check List [ ]\n          - API(High compatibility with PyTorch 1.11, shouldn't have anything incompatible for a naive reason.)[ ]\n          - Doc(Same standard as Alpha Stage)[ ]\n          - Functionality and its Test [ ]\n            - eager global [ ]\n              - forward [ ]\n              - backward [ ]\n              - gpu [ ]\n              - cpu [ ]\n            - graph global [ ]\n              - forward [ ]\n              - backward [ ]\n              - gpu [ ]\n              - cpu [ ]\n          - Performance and Scalability(Must be evaluated.)[ ]\n            - CUDA kernel [ ]\n            - CPU kernel [ ]\n            - N nodes M devices [ ]\n          - Exception Handling [ ]\n            - Exception Message and Hint must be provided [ ]\n            - Try your best to do Exception Recovery [ ]\n        - Stable Stage Check List [ ]\n          - API(Same standard as Beta Stage)[ ]\n          - Doc(Same standard as Beta Stage)[ ]\n          - Functionality and its Test [ ]\n            - fp16 and AMP [ ]\n            - NHWC [ ]\n          - Performance and Scalability(Must be evaluated.)[ ]\n          - Exception Handling [ ]\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/tile.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.tile,\n    \"\"\"\n    tile(input, dims) -> Tensor\n\n    Constructs a tensor by repeating the elements of ``input``.  The ``dims`` argument specifies the number\n    of repetitions in each dimension.\n\n    If ``dims`` specifies fewer dimensions than ``input`` has, then ones are prepended to ``dims`` until\n    all dimensions are specified.  For example, if ``input`` has shape (8, 6, 4, 2) and ``dims`` is (2, 2),\n    then ``dims`` is treated as (1, 1, 2, 2).\n\n    Analogously, if ``input`` has fewer dimensions than ``dims`` specifies, then ``input`` is treated as\n    if it were unsqueezed at dimension zero until it has as many dimensions as ``dims`` specifies.\n    For example, if ``input`` has shape (4, 2) and ``dims`` is (3, 3, 2, 2), then ``input`` is treated as\n    if it had the shape (1, 1, 4, 2).\n\n    .. note::\n        This function is similar to NumPy’s tile function.\n    \n    The interface is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.tile.html.\n\n    Args:\n        input (oneflow.Tensor): the tensor whose elements to repeat.\n        dims (tuple): the number of repetitions per dimension.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        \n        >>> np_arr = np.random.randn(5, 3, 6, 9).astype(np.float32)\n        >>> input = flow.Tensor(np_arr)\n        >>> out = input.tile(2,1,2,1)\n        >>> out.shape\n        oneflow.Size([10, 3, 12, 9])\n        >>> x = np.random.randn(5, 2, 1)\n        >>> input = flow.Tensor(x)\n        >>> out = input.tile(3,4)\n        >>> out.shape\n        oneflow.Size([5, 6, 4])\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/topk.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.topk,\n    \"\"\"Finds the values and indices of the k largest entries at specified axis.\n\n    Args:\n        input (oneflow.Tensor): Input Tensor\n        k (int): the k in “top-k”\n        dim (int, optional): the dimension to sort along. Defaults to the last dim (-1)\n        largest (bool, optional): controls whether to return largest or smallest elements\n        sorted (bool, optional): controls whether to return the elements in sorted order (Only Support True Now!)\n\n    Returns:\n        Tuple(oneflow.Tensor, oneflow.Tensor(dtype=int32)): A tuple of (values, indices), where\n        the indices are the indices of the elements in the original input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x = np.array([[1, 3, 8, 7, 2], [1, 9, 4, 3, 2]], dtype=np.float32)\n        >>> result = flow.topk(flow.Tensor(x), k=3, dim=1)\n        >>> result.values\n        tensor([[8., 7., 3.],\n                [9., 4., 3.]], dtype=oneflow.float32)\n        >>> result.indices\n        tensor([[2, 3, 1],\n                [1, 2, 3]], dtype=oneflow.int64)\n        >>> result.values.shape\n        oneflow.Size([2, 3])\n        >>> result.indices.shape\n        oneflow.Size([2, 3])\n        >>> result = flow.topk(flow.Tensor(x), k=2, dim=1, largest=False)\n        >>> result.values\n        tensor([[1., 2.],\n                [1., 2.]], dtype=oneflow.float32)\n        >>> result.indices\n        tensor([[0, 4],\n                [0, 4]], dtype=oneflow.int64)\n        >>> result.values.shape\n        oneflow.Size([2, 2])\n        >>> result.indices.shape\n        oneflow.Size([2, 2])\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/trigonometric_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\n\nadd_docstr(\n    oneflow.atan2,\n    \"\"\"Element-wise arctangent of input{i}/other{i}\n    with consideration of the quadrant. Returns a new tensor with the signed\n    angles in radians between vector (other{i},input{i}) and vector (1, 0).\n\n    The shapes of input and other must be broadcastable.\n\n    Args:\n        input (Tensor): the first input tensor.\n\n        other (Tensor): the second input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> x1 = flow.Tensor(np.array([1,2,3]))\n        >>> y1 = flow.Tensor(np.array([3,2,1]))\n        >>> x2 = flow.Tensor(np.array([1.53123589,0.54242598,0.15117185]))\n        >>> y2 = flow.Tensor(np.array([-0.21906378,0.09467151,-0.75562878]))\n        >>> x3 = flow.Tensor(np.array([1,0,-1]))\n        >>> y3 = flow.Tensor(np.array([0,1,0]))\n\n        >>> flow.atan2(x1,y1).numpy()\n        array([0.32175055, 0.7853982 , 1.2490457 ], dtype=float32)\n        >>> flow.atan2(x2,y2).numpy()\n        array([1.7128955, 1.3980033, 2.9441385], dtype=float32)\n        >>> flow.atan2(x3,y3).numpy()\n        array([ 1.5707964,  0.       , -1.5707964], dtype=float32)\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/unbind.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.unbind,\n    \"\"\"\n    Removes a tensor dimension.\n\n    Returns a tuple of all slices along a given dimension, already without it.\n    \n    This function is equivalent to PyTorch's unbind function.\n\n    Args:\n        x(Tensor): the tensor to unbind\n        dim(int): dimension to remove\n\n    For example:\n\n    .. code-block:: python\n    \n        >>> import oneflow as flow\n               \n        >>> x = flow.tensor(range(12)).reshape([3,4])\n        >>> flow.unbind(x)\n        (tensor([0, 1, 2, 3], dtype=oneflow.int64), tensor([4, 5, 6, 7], dtype=oneflow.int64), tensor([ 8,  9, 10, 11], dtype=oneflow.int64))\n        >>> flow.unbind(x, 1)\n        (tensor([0, 4, 8], dtype=oneflow.int64), tensor([1, 5, 9], dtype=oneflow.int64), tensor([ 2,  6, 10], dtype=oneflow.int64), tensor([ 3,  7, 11], dtype=oneflow.int64))\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/util_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.isnan,\n    \"\"\"\n    isnan(input) -> Tensor \n    \n    This function is equivalent to PyTorch’s isnan function. \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.isnan.html?highlight=isnan#torch.isnan\n\n    Returns a new tensor with boolean elements representing if each element of input is NaN or not.\n\n    Args:\n        input(Tensor): the input tensor.\n\n    Returns:\n        A boolean tensor that is True where input is NaN and False elsewhere.\n\n    Example::\n\n        >>> import oneflow as flow\n        >>> flow.isnan(flow.tensor([1, float('nan'), 2]))\n        tensor([False,  True, False], dtype=oneflow.bool)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.isinf,\n    \"\"\"\n    isinf(input) -> Tensor \n\n    This function is equivalent to PyTorch’s isinf function. \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.isinf.html?highlight=isinf#torch.isinf\n\n    Tests if each element of input is infinite (positive or negative infinity) or not.\n\n    Args:\n        input(Tensor): the input tensor.\n\n    Returns:\n        A boolean tensor that is True where input is infinite and False elsewhere.\n\n    Example::\n\n        >>> import oneflow as flow\n        >>> flow.isinf(flow.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))\n        tensor([False,  True, False,  True, False], dtype=oneflow.bool)\n\n    \"\"\",\n)\n\nadd_docstr(\n    oneflow.isfinite,\n    \"\"\"\n    isfinite(input) -> Tensor \n\n    This function is equivalent to PyTorch’s isfinite function. \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.isfinite.html?highlight=isfinite#torch.isfinite\n\n    Returns a new tensor with boolean elements representing if each element is finite or not.\n\n    Args:\n        input(Tensor): the input tensor.\n\n    Returns:\n        A boolean tensor that is True where input is finite and False elsewhere.\n\n    Example::\n\n        >>> import oneflow as flow\n        >>> flow.isfinite(flow.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))\n        tensor([ True, False,  True, False, False], dtype=oneflow.bool)\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport oneflow._oneflow_internal\nfrom doctest import DocTestParser, DebugRunner, DocTestRunner\n\n\ndef _test_docstr(docstr, verbose=True, optionflags=0, raise_on_error=True):\n    parser = DocTestParser()\n    if raise_on_error:\n        runner = DebugRunner(verbose=verbose, optionflags=optionflags)\n    else:\n        runner = DocTestRunner(verbose=verbose, optionflags=optionflags)\n    test = parser.get_doctest(docstr, {}, __name__, __file__, 0)\n    runner.run(test)\n\n\ndef add_docstr(fun, docstr: str):\n    return oneflow._oneflow_internal.add_doc(fun, docstr)\n\n\ndef reset_docstr(o, docstr):\n    if type(o) == type:\n        assert hasattr(o, \"__doc__\"), str(o) + \" does not have a docstring!\"\n        setattr(o, \"__doc__\", docstr)\n        return o\n    else:\n        return oneflow._oneflow_internal.reset_doc(o, docstr)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/vision.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow._C.pad,\n    r\"\"\"\n    Pads tensor.\n\n    Padding size:\n        The padding size by which to pad some dimensions of :attr:`input`\n        are described starting from the last dimension and moving forward.\n        :math:`\\left\\lfloor\\frac{\\text{len(pad)}}{2}\\right\\rfloor` dimensions\n        of ``input`` will be padded.\n        For example, to pad only the last dimension of the input tensor, then\n        :attr:`pad` has the form\n        :math:`(\\text{padding_left}, \\text{padding_right})`;\n        to pad the last 2 dimensions of the input tensor, then use\n        :math:`(\\text{padding_left}, \\text{padding_right},`\n        :math:`\\text{padding_top}, \\text{padding_bottom})`;\n        to pad the last 3 dimensions, use\n        :math:`(\\text{padding_left}, \\text{padding_right},`\n        :math:`\\text{padding_top}, \\text{padding_bottom}`\n        :math:`\\text{padding_front}, \\text{padding_back})`.\n\n    Padding mode:\n        See :class:`oneflow.nn.ConstantPad2d`, :class:`oneflow.nn.ReflectionPad2d`, and\n        :class:`oneflow.nn.ReplicationPad2d` for concrete examples on how each of the\n        padding modes works. Constant padding is implemented for arbitrary dimensions.\n        Replicate padding is implemented for padding the last 3 dimensions of 5D input\n        tensor, or the last 2 dimensions of 4D input tensor, or the last dimension of\n        3D input tensor. Reflect padding is only implemented for padding the last 2\n        dimensions of 4D input tensor, or the last dimension of 3D input tensor.\n\n    Args:\n        input (Tensor): N-dimensional tensor\n        pad (tuple): m-elements tuple, where\n            :math:`\\frac{m}{2} \\leq` input dimensions and :math:`m` is even.\n        mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.\n            Default: ``'constant'``\n        value: fill value for ``'constant'`` padding. Default: ``0``\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> pad = [2, 2, 1, 1]\n        >>> input = flow.tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32))\n        >>> output = flow.nn.functional.pad(input, pad, mode = \"replicate\")\n        >>> output.shape\n        oneflow.Size([1, 2, 5, 7])\n        >>> output\n        tensor([[[[ 0.,  0.,  0.,  1.,  2.,  2.,  2.],\n                  [ 0.,  0.,  0.,  1.,  2.,  2.,  2.],\n                  [ 3.,  3.,  3.,  4.,  5.,  5.,  5.],\n                  [ 6.,  6.,  6.,  7.,  8.,  8.,  8.],\n                  [ 6.,  6.,  6.,  7.,  8.,  8.,  8.]],\n        <BLANKLINE>\n                 [[ 9.,  9.,  9., 10., 11., 11., 11.],\n                  [ 9.,  9.,  9., 10., 11., 11., 11.],\n                  [12., 12., 12., 13., 14., 14., 14.],\n                  [15., 15., 15., 16., 17., 17., 17.],\n                  [15., 15., 15., 16., 17., 17., 17.]]]], dtype=oneflow.float32)\n\n    See :class:`oneflow.nn.ConstantPad2d`, :class:`oneflow.nn.ReflectionPad2d`, and :class:`oneflow.nn.ReplicationPad2d` for concrete examples on how each of the padding modes works.\n        \n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/docstr/where.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.docstr.utils import add_docstr\n\nadd_docstr(\n    oneflow.where,\n    \"\"\"Return a tensor of elements selected from either :attr:`x` or :attr:`y`, depending on :attr:`condition`.\n    If the element in condition is larger than 0,\n\n    it will take the `x` element, else it will take the `y` element\n\n    .. note::\n        If :attr:`x` is None and :attr:`y` is None,  flow.where(condition) is \n        identical to flow.nonzero(condition, as_tuple=True).\n        \n        The tensors :attr:`condition`, :attr:`x`, :attr:`y` must be broadcastable.\n\n    Args:\n        condition (IntTensor): When 1 (nonzero), yield x, otherwise yield y\n        x (Tensor or Scalar): value (if :attr:x is a scalar) or values selected at indices\n                            where :attr:`condition` is True\n        y (Tensor or Scalar): value (if :attr:x is a scalar) or values selected at indices\n                            where :attr:`condition` is False\n    Returns:\n        Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`x`, :attr:`y`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> x = flow.tensor(\n        ...    np.array([[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]),\n        ...    dtype=flow.float32,\n        ... )\n        >>> y = flow.tensor(np.ones(shape=(3, 2)), dtype=flow.float32)\n        >>> condition = flow.tensor(np.array([[0, 1], [1, 0], [1, 0]]), dtype=flow.int32)\n        >>> out = condition.where(x, y)\n        >>> out #doctest: +ELLIPSIS\n        tensor([[1.0000, 0.3139],\n                ...\n                [0.0478, 1.0000]], dtype=oneflow.float32)\n\n    \"\"\",\n)\n"
  },
  {
    "path": "python/oneflow/framework/dtype.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport numpy as np\n\nimport oneflow\nimport oneflow._oneflow_internal\nimport oneflow.core.common.data_type_pb2 as data_type_pb2\nfrom oneflow._oneflow_internal import (\n    set_default_dtype,\n    get_default_dtype,\n)\n\n_dtypes = [\n    oneflow.bool,\n    oneflow.float,\n    oneflow.float32,\n    oneflow.double,\n    oneflow.float64,\n    oneflow.float16,\n    oneflow.int8,\n    oneflow.int16,\n    oneflow.int32,\n    oneflow.int64,\n    oneflow.uint8,\n    oneflow.record,\n    oneflow.tensor_buffer,\n    oneflow.bfloat16,\n    oneflow.complex64,\n    oneflow.cfloat,\n    oneflow.complex128,\n    oneflow.cdouble,\n]\n\n\ndef dtypes():\n    return _dtypes\n\n\ndef convert_proto_dtype_to_oneflow_dtype(proto_dtype):\n    return oneflow._oneflow_internal.deprecated.GetDTypeByDataType(proto_dtype)\n\n\n_ONEFLOW_DTYPE_TO_NUMPY_DTYPE = {\n    # >> np_bool = np.array([1,2], dtype=bool).dtype\n    # >> np_bool == bool\n    # True\n    oneflow.bool: bool,\n    oneflow.float: np.float32,\n    oneflow.float16: np.float16,\n    oneflow.float32: np.float32,\n    oneflow.float64: np.double,\n    oneflow.double: np.double,\n    oneflow.int8: np.int8,\n    oneflow.char: np.int8,\n    oneflow.int16: np.int16,\n    oneflow.int32: np.int32,\n    oneflow.int64: np.int64,\n    oneflow.uint8: np.uint8,\n    oneflow.complex64: np.complex64,\n    oneflow.cfloat: np.complex64,\n    oneflow.complex128: np.complex128,\n    oneflow.cdouble: np.complex128,\n}\n\n\ndef convert_oneflow_dtype_to_numpy_dtype(oneflow_dtype: oneflow.dtype):\n    if oneflow_dtype not in _ONEFLOW_DTYPE_TO_NUMPY_DTYPE:\n        raise NotImplementedError\n    return _ONEFLOW_DTYPE_TO_NUMPY_DTYPE[oneflow_dtype]\n\n\ndef convert_numpy_dtype_to_oneflow_dtype(numpy_dtype: np.dtype):\n    for (k, v) in _ONEFLOW_DTYPE_TO_NUMPY_DTYPE.items():\n        if v == numpy_dtype:\n            return k\n    raise NotImplementedError\n\n\ndel data_type_pb2\ndel np\n\n\ndef set_default_tensor_type(tensor_type):\n    \"\"\"Sets the default floating point type for those source operators which create Tensor.\n\n    The default floating point type is ``oneflow.FloatTensor``.\n\n    Args:\n        tensor_type (type or string): The floating point tensor type or its name.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow\n        >>> oneflow.set_default_tensor_type(oneflow.FloatTensor)\n        >>> x = oneflow.ones(2, 3)\n        >>> x.dtype\n        oneflow.float32\n        >>> oneflow.set_default_tensor_type(\"oneflow.DoubleTensor\")\n        >>> x = oneflow.ones(2, 3)\n        >>> x.dtype\n        oneflow.float64\n        >>> oneflow.set_default_tensor_type(oneflow.FloatTensor)\n        >>> x = oneflow.tensor([1.0, 2])\n        >>> x.dtype\n        oneflow.float32\n    \"\"\"\n\n    def _import_dotted_name(name):\n        \"\"\"\n        This function quotes from: https://github.com/pytorch/pytorch/blob/master/torch/_utils.py\n        \"\"\"\n        components = name.split(\".\")\n        obj = __import__(components[0])\n        for component in components[1:]:\n            obj = getattr(obj, component)\n        return obj\n\n    if isinstance(tensor_type, str):\n        tensor_type = _import_dotted_name(tensor_type)\n    oneflow._oneflow_internal.set_default_tensor_type(tensor_type)\n\n\ndef is_floating_point(input):\n    return input.is_floating_point()\n"
  },
  {
    "path": "python/oneflow/framework/env_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport socket\nimport traceback\nfrom contextlib import closing\nimport warnings\n\nimport oneflow._oneflow_internal\nimport oneflow.core.control.ctrl_bootstrap_pb2 as ctrl_bootstrap_pb\nimport oneflow.core.job.env_pb2 as env_pb\nimport oneflow.core.job.resource_pb2 as resource_util\nimport oneflow.framework.c_api_util as c_api_util\n\n\ndef api_all_device_placement(device_type: str) -> oneflow._oneflow_internal.placement:\n    r\"\"\"\n    oneflow.env.all_device_placement(device_type) -> oneflow.placement\n\n    Returns a placement that contains all available devices.\n\n    Note:\n        It is recommended to use `oneflow.placement.all` instead of this function.\n\n    Args:\n        device_type (str): cuda or cpu\n\n    For examples:\n\n    .. code-block:: python\n\n        # Runs on 4 ranks\n        import oneflow as flow\n\n        p = flow.env.all_device_placement(\"cuda\") # oneflow.placement(type=\"cuda\", ranks=[0, 1, 2, 3])\n        p = flow.env.all_device_placement(\"cpu\") # oneflow.placement(type=\"cpu\", ranks=[0, 1, 2, 3])\n\n    \"\"\"\n    return oneflow.placement.all(device_type)\n\n\ndef check_non_localhost_proxy_and_print_warning():\n    for env_var_name in [\"http_proxy\", \"HTTP_PROXY\", \"https_proxy\", \"HTTPS_PROXY\"]:\n        env_var_value = os.getenv(env_var_name)\n        if (\n            env_var_value is not None\n            and (not \"://localhost\" in env_var_value)\n            and (not \"://127.0.0.1\" in env_var_value)\n            and (not env_var_value.startswith(\"localhost\"))\n            and (not env_var_value.startswith(\"127.0.0.1\"))\n        ):\n            print(\n                f\"Proxy through another machine ({env_var_value}) is incompatible with OneFlow. 
Please unset them by `unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY`\"\n            )\n            break\n\n\ndef create_env():\n    \"\"\"create environment\n\n    Returns:\n        Env: [description]\n    \"\"\"\n    global default_env_proto\n    assert len(default_env_proto.machine) > 0\n    CompleteEnvProto(default_env_proto)\n    if default_env_proto.ctrl_bootstrap_conf.world_size > 1:\n        check_non_localhost_proxy_and_print_warning()\n    return c_api_util.GetEnvContext(default_env_proto)\n\n\ndef CompleteEnvProto(env_proto):\n    _UpdateDefaultEnvProtoByMultiClientEnvVars(env_proto)\n    if env_proto.HasField(\"ctrl_port\") == False:\n        if len(env_proto.machine) == 1:\n            env_proto.ctrl_port = _FindFreePort()\n        else:\n            raise ValueError(\n                \"a ctrl_port is required if running multi-node, set it with 'oneflow.env.ctrl_port([YOUR PORT])'\"\n            )\n\n\ndef _MakeMachine(machines):\n    if isinstance(machines, str):\n        machines = [machines]\n    rp_machine = env_pb.EnvProto().machine\n    for m_data in machines:\n        m = rp_machine.add()\n        if isinstance(m_data, str):\n            m.addr = m_data\n        elif isinstance(m_data, dict):\n            if \"addr\" in m_data:\n                m.addr = m_data[\"addr\"]\n            if \"ctrl_port_agent\" in m_data:\n                m.ctrl_port_agent = m_data[\"ctrl_port_agent\"]\n            if \"data_port_agent\" in m_data:\n                m.data_port_agent = m_data[\"data_port_agent\"]\n        else:\n            raise NotImplementedError\n    id = 0\n    addrs_for_check = set()\n    for m in rp_machine:\n        m.id = id\n        id += 1\n        assert m.addr not in addrs_for_check\n        addrs_for_check.add(m.addr)\n    return rp_machine\n\n\ndef _MakeBootstrapConf(bootstrap_info: dict):\n    global config_master_addr\n    assert config_master_addr.HasField(\"host\"), \"must config master host first\"\n    assert config_master_addr.HasField(\"port\"), \"must config master port first\"\n    assert config_world_size != 0, \"must config world size first\"\n    bootstrap_conf = ctrl_bootstrap_pb.BootstrapConf()\n    bootstrap_conf.master_addr.CopyFrom(config_master_addr)\n    bootstrap_conf.world_size = config_world_size\n    assert \"rank\" in bootstrap_info\n    bootstrap_conf.rank = bootstrap_info[\"rank\"]\n    if \"host\" in bootstrap_info:\n        bootstrap_conf.host = bootstrap_info[\"host\"]\n    global config_bootstrap_ctrl_port\n    if config_bootstrap_ctrl_port != 0:\n        bootstrap_conf.ctrl_port = config_bootstrap_ctrl_port\n    global config_node_size\n    if config_node_size != 0:\n        bootstrap_conf.node_size = config_node_size\n    return bootstrap_conf\n\n\ndef _DefaultEnvProto():\n    env_proto = env_pb.EnvProto()\n    machine = env_proto.machine.add()\n    machine.id = 0\n    machine.addr = \"127.0.0.1\"\n    return env_proto\n\n\ndef _FindFreePort():\n    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:\n        s.bind((\"localhost\", 0))\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        return s.getsockname()[1]\n\n\ndef CheckAndWarnAbnormalEnvVars():\n    env_var_names = [\"MASTER_ADDR\", \"MASTER_PORT\", \"WORLD_SIZE\", \"RANK\"]\n    env_var_without_value = [x for x in env_var_names if os.getenv(x) is None]\n    env_var_with_value = [x for x in env_var_names if os.getenv(x) is not None]\n    if len(env_var_with_value) != 0 and len(env_var_without_value) != 0:\n        
warnings.warn(\n            f\"Among four environment variables required for distributed training, only {', '.join('`{0}`'.format(x) for x in env_var_with_value)} are set, but {', '.join('`{0}`'.format(x) for x in env_var_without_value)} are not set.\"\n        )\n\n\ndef _UpdateDefaultEnvProtoByMultiClientEnvVars(env_proto):\n    def str2int(env_config):\n        return int(env_config)\n\n    bootstrap_conf = ctrl_bootstrap_pb.BootstrapConf()\n    master_addr = ctrl_bootstrap_pb.Address()\n    master_addr.host = os.getenv(\"MASTER_ADDR\", \"127.0.0.1\")\n    master_addr.port = str2int(os.getenv(\"MASTER_PORT\", _FindFreePort()))\n    bootstrap_conf.master_addr.CopyFrom(master_addr)\n    bootstrap_conf.world_size = str2int(os.getenv(\"WORLD_SIZE\", 1))\n    bootstrap_conf.rank = str2int(os.getenv(\"RANK\", 0))\n    env_proto.ctrl_bootstrap_conf.CopyFrom(bootstrap_conf)\n    cpp_logging_conf = env_pb.CppLoggingConf()\n    if os.getenv(\"GLOG_log_dir\"):\n        cpp_logging_conf.log_dir = os.getenv(\"GLOG_log_dir\")\n    if os.getenv(\"GLOG_logtostderr\"):\n        cpp_logging_conf.logtostderr = str2int(os.getenv(\"GLOG_logtostderr\"))\n    if os.getenv(\"GLOG_logbuflevel\"):\n        cpp_logging_conf.logbuflevel = str2int(os.getenv(\"GLOG_logbuflevel\"))\n    if os.getenv(\"GLOG_minloglevel\"):\n        cpp_logging_conf.minloglevel = str2int(os.getenv(\"GLOG_minloglevel\"))\n    env_proto.cpp_logging_conf.CopyFrom(cpp_logging_conf)\n\n\nclass EnvHolder(object):\n    def __init__(self):\n        CheckAndWarnAbnormalEnvVars()\n        self._env_cxt = create_env()\n        self._shutting_down = [False]\n\n    def is_shutting_down(self):\n        \"\"\"\n        Whether the interpreter is currently shutting down.\n        For use in finalizers, __del__ methods, and similar; it is advised\n        to early bind this function rather than look it up when calling it,\n        since at shutdown module globals may be cleared.\n\n        Please refer to: https://github.com/Oneflow-Inc/OneTeam/issues/1219#issuecomment-1092370402\n        This solution is obtained from cupy code: https://github.com/cupy/cupy/pull/2809\n        \"\"\"\n        return self._shutting_down[0]\n\n    def switch_to_shutting_down(self, is_normal_exit=True):\n        self._shutting_down[0] = True\n        self._env_cxt.SwitchToShuttingDownPhase(is_normal_exit)\n\n\ndef GetEnv():\n    return EnvHolder()\n\n\ndevice_tag2default_parallel_conf = {}\ndefault_env_proto = _DefaultEnvProto()\nconfig_master_addr = ctrl_bootstrap_pb.Address()\nconfig_world_size = 0\nconfig_bootstrap_ctrl_port = 0\nconfig_node_size = 0\nglobal_ctrl_bootstrap_confs = []\n"
  },
  {
    "path": "python/oneflow/framework/function_desc.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow._oneflow_internal\nimport oneflow.core.job.job_conf_pb2 as job_conf_pb\nimport oneflow.framework.hob as hob\nimport oneflow.framework.session_context as session_ctx\nimport oneflow.support.enable_if as enable_if\n\n\nclass FunctionAttribute(object):\n    def __init__(self):\n        self.default_placement_scope = None\n        self.default_distribute_strategy = None\n        self.allow_cpu_return_op = True\n\n\nclass FunctionDesc(object):\n    def __init__(self, job_func=None, job_config_proto=None, function_attribute=None):\n        if job_config_proto is None:\n            job_config_proto = job_conf_pb.JobConfigProto()\n        if function_attribute is None:\n            function_attribute = FunctionAttribute()\n        self.job_func = job_func\n        self.job_config_proto = job_config_proto\n        self.job_config_proto.predict_conf.SetInParent()\n        self.function_attribute = function_attribute\n\n    def IsTrainable(self):\n        if self.job_config_proto.HasField(\"train_conf\"):\n            return True\n        if self.job_config_proto.HasField(\"predict_conf\"):\n            return False\n        raise NotImplementedError\n\n    def HasAttr(self, attr_name):\n        if attr_name == \"flag_name2flag_value\":\n            return False\n        name2default = session_ctx.GetDefaultSession().function_flag_name2default_val\n        if attr_name in self.job_config_proto.flag_name2flag_value:\n            return True\n        return self.job_config_proto.HasField(attr_name)\n\n    def __getattr__(self, attr_name):\n        assert attr_name != \"flag_name2flag_value\"\n        flag_name2flag_value = self.job_config_proto.flag_name2flag_value\n        name2default = session_ctx.GetDefaultSession().function_flag_name2default_val\n        if attr_name not in name2default:\n            assert self.job_config_proto.HasField(attr_name)\n            return getattr(self.job_config_proto, attr_name)\n        attr_value = name2default[attr_name]\n        if attr_name in flag_name2flag_value:\n            attr_value = flag_name2flag_value[attr_name]\n        if attr_value.HasField(\"at_bool\"):\n            return attr_value.at_bool\n        elif attr_value.HasField(\"at_int64\"):\n            return attr_value.at_int64\n        elif attr_value.HasField(\"at_double\"):\n            return attr_value.at_double\n        elif attr_value.HasField(\"at_string\"):\n            return attr_value.at_string\n        else:\n            raise NotImplementedError()\n"
  },
  {
    "path": "python/oneflow/framework/function_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport copy\nimport functools\nimport inspect\nimport re\nimport sys\nimport traceback\nfrom typing import Any, Callable, Optional, Union\n\nimport oneflow._oneflow_internal\nimport oneflow.core.common.data_type_pb2 as data_type_pb\nimport oneflow.framework.session_context as session_ctx\nimport oneflow.support.enable_if as enable_if\nimport oneflow.support.pb_util as pb_util\nfrom oneflow import oneflow_deprecate\nfrom oneflow.framework.function_desc import FunctionDesc\n\n\nclass FunctionConfig(object):\n    \"\"\"OneFlow function's configurations.\n    \"\"\"\n\n    def __init__(self) -> None:\n        self.function_desc = FunctionDesc()\n\n    def __getattr__(\n        self, attr_name: str\n    ) -> Callable[[Optional[Union[bool, int, float, str]]], None]:\n        name2default = session_ctx.GetDefaultSession().function_flag_name2default_val\n        assert attr_name in name2default\n        flag_name2flag_value = self.function_desc.job_config_proto.flag_name2flag_value\n        default_val = name2default[attr_name]\n\n        def FunctionConfigSetter(\n            attr_value: Optional[Union[bool, int, float, str]] = None\n        ) -> None:\n            if default_val.HasField(\"at_bool\"):\n                if attr_value is None:\n                    attr_value = True\n                assert type(attr_value) is bool\n                flag_name2flag_value[attr_name].at_bool = attr_value\n            elif default_val.HasField(\"at_int64\"):\n                assert type(attr_value) is int\n                flag_name2flag_value[attr_name].at_int64 = attr_value\n            elif default_val.HasField(\"at_double\"):\n                assert type(attr_value) is float\n                flag_name2flag_value[attr_name].at_double = attr_value\n            elif default_val.HasField(\"at_string\"):\n                assert type(attr_value) is str\n                flag_name2flag_value[attr_name].at_string = attr_value\n            else:\n                raise NotImplementedError(\n                    \"config_flag `%s' with type %s is not supported\"\n                    % (attr_name, type(attr_value))\n                )\n\n        return FunctionConfigSetter\n\n\ndef _CloneFunctionDesc(func_desc, job_func):\n    new_func_desc = FunctionDesc(job_func=job_func)\n    new_func_desc.job_config_proto.CopyFrom(func_desc.job_config_proto)\n    new_func_desc.function_attribute = copy.deepcopy(func_desc.function_attribute)\n    return new_func_desc\n\n\ndef oneflow_function_config(*field_paths):\n    def Decorator(func):\n        global _class_property2return_obj_class\n        for field_path in field_paths:\n            fields = field_path.split(\".\")\n            assert len(fields) > 0\n            cls = FunctionConfig\n            for (index, field) in enumerate(fields):\n                assert field != \"function_desc\"\n                assert re.match(\"^[_\\\\w]+[_\\\\w\\\\d]*$\", 
field)\n                if (cls, field) not in _class_property2return_obj_class:\n                    class_name = \".\".join([\"function_config\"] + fields[: index + 1])\n\n                    def Init(self, function_desc):\n                        self.function_desc = function_desc\n\n                    config_class = type(class_name, (object,), dict(__init__=Init))\n                    setattr(cls, field, _MakeInnerJobConfigClassProperty(config_class))\n                    _class_property2return_obj_class[cls, field] = config_class\n                cls = _class_property2return_obj_class[cls, field]\n            cls.__call__ = _MakeLeafJobConfigCall(func)\n        return func\n\n    return Decorator\n\n\n_class_property2return_obj_class = {}\n\n\ndef _MakeInnerJobConfigClassProperty(return_obj_class):\n    return property(lambda self: return_obj_class(self.function_desc))\n\n\ndef _MakeLeafJobConfigCall(method):\n    return lambda self, *argv, **kwarg: method(self.function_desc, *argv, **kwarg)\n\n\n@oneflow_function_config(\"default_data_type\")\ndef set_default_data_type(func_desc, value):\n    \"\"\"Set default data type for job\n\n    Args:\n        func_desc ([type]): job function\n        value ([type]): data type. e.g. flow.float\n    \"\"\"\n    func_desc.job_config_proto.default_data_type = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(\n        value\n    )\n\n\n@oneflow_function_config(\"default_initializer_conf\")\ndef set_default_initializer_conf(func_desc, value):\n    \"\"\"Set default initial configuration for job\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    assert type(value) is dict\n    pb_util.PythonDict2PbMessage(\n        value, func_desc.job_config_proto.default_initializer_conf\n    )\n\n\n@oneflow_function_config(\"exp_run_conf\")\ndef set_exp_run_conf(func_desc, value):\n    \"\"\"Set experimental configuration for job\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    assert type(value) is dict\n    pb_util.PythonDict2PbMessage(value, func_desc.job_config_proto.exp_run_conf)\n\n\n@oneflow_function_config(\"static_mem_alloc_policy_white_list.has\")\ndef static_mem_alloc_policy_white_list_has_policy(func_desc, policy):\n    \"\"\"Get items from white list related to static memory allocation policy\n\n    Args:\n        func_desc ([type]): [description]\n        policy ([type]): [description]\n\n    Returns:\n        [type]: [description]\n    \"\"\"\n    return getattr(func_desc.job_config_proto.memory_allocation_algorithm_conf, policy)\n\n\n@oneflow_function_config(\"static_mem_alloc_policy_white_list.add\")\ndef static_mem_alloc_policy_white_list_add_policy(func_desc, policy):\n    \"\"\"Add item to white list related to static memory allocation policy\n\n    Args:\n        func_desc ([type]): [description]\n        policy ([type]): [description]\n    \"\"\"\n    setattr(func_desc.job_config_proto.memory_allocation_algorithm_conf, policy, True)\n\n\n@oneflow_function_config(\"static_mem_alloc_policy_white_list.remove\")\ndef static_mem_alloc_policy_white_list_remove_policy(func_desc, policy):\n    \"\"\"Remove item of white list related to static memory allocation policy\n\n    Args:\n        func_desc ([type]): [description]\n        policy ([type]): [description]\n    \"\"\"\n    setattr(func_desc.job_config_proto.memory_allocation_algorithm_conf, policy, False)\n\n\n@oneflow_function_config(\"static_mem_alloc_policy_white_list.policy_mem_size_first\")\ndef 
policy_mem_size_first(func_desc):\n    \"\"\"A static memory allocation policy called: mem_size_first\n\n    Args:\n        func_desc ([type]): [description]\n\n    Returns:\n        [type]: [description]\n    \"\"\"\n    return \"use_mem_size_first_algo\"\n\n\n@oneflow_function_config(\n    \"static_mem_alloc_policy_white_list.policy_mutual_exclusion_first\"\n)\ndef policy_mutual_exclusion_first(func_desc):\n    \"\"\"A static memory allocation policy called: mutual_exclusion_first\n\n    Args:\n        func_desc ([type]): [description]\n\n    Returns:\n        [type]: [description]\n    \"\"\"\n    return \"use_mutual_exclusion_first_algo\"\n\n\n@oneflow_function_config(\"static_mem_alloc_policy_white_list.policy_time_line\")\ndef policy_time_line(func_desc):\n    \"\"\"A static memory allocation policy called: time_line\n\n    Args:\n        func_desc ([type]): [description]\n\n    Returns:\n        [type]: [description]\n    \"\"\"\n    return \"use_time_line_algo\"\n\n\n@oneflow_function_config(\"static_mem_alloc_algo_white_list.show\")\ndef show_static_mem_alloc_algo_white_list(func_desc):\n    \"\"\"Show configuration of  static memory allocation policy,\n          including: \"use_mem_size_first_algo\", \"use_mutual_exclusion_first_algo\", \"use_time_line_algo\"\n\n    Args:\n        func_desc ([type]): [description]\n\n    Returns:\n        [type]: [description]\n    \"\"\"\n    return [\n        \"use_mem_size_first_algo\",\n        \"use_mutual_exclusion_first_algo\",\n        \"use_time_line_algo\",\n    ]\n\n\n@oneflow_function_config(\"enable_cudnn\")\ndef set_enable_cudnn(func_desc, value=True):\n    \"\"\"Whether use cudnn to accelerate job or not.\n\n    Args:\n        func_desc ([type]): [description]\n        value (bool, optional): [description]. Defaults to True.\n    \"\"\"\n    func_desc.job_config_proto.enable_cudnn = value\n\n\n@oneflow_function_config(\"cudnn_buf_limit_mbyte\")\ndef set_cudnn_buf_limit_mbyte(func_desc, value):\n    \"\"\"Set cudnn buffer limit, e.g. 
1024mb\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    func_desc.job_config_proto.cudnn_buf_limit_mbyte = value\n\n\n@oneflow_function_config(\"cudnn_conv_force_fwd_algo\")\ndef set_cudnn_conv_force_fwd_algo(func_desc, value):\n    \"\"\"Set value to cudnn conv_force_forward algorithm\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    func_desc.job_config_proto.cudnn_conv_force_fwd_algo = value\n\n\n@oneflow_function_config(\"cudnn_conv_force_bwd_data_algo\")\ndef set_cudnn_conv_force_bwd_data_algo(func_desc, value):\n    \"\"\"Set value to cudnn conv_force_backward_data algorithm\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    func_desc.job_config_proto.cudnn_conv_force_bwd_data_algo = value\n\n\n@oneflow_function_config(\"cudnn_conv_force_bwd_filter_algo\")\ndef set_cudnn_conv_force_bwd_filter_algo(func_desc, value):\n    \"\"\"Set value to cudnn conv_force_backward_filter algorithm\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    func_desc.job_config_proto.cudnn_conv_force_bwd_filter_algo = value\n\n\n@oneflow_function_config(\"cudnn_conv_heuristic_search_algo\")\ndef set_cudnn_conv_heuristic_search_algo(func_desc, value):\n    \"\"\"Set value to cudnn conv_heuristic_search algorithm\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    func_desc.job_config_proto.cudnn_conv_heuristic_search_algo = value\n\n\n@oneflow_function_config(\"enable_cudnn_fused_normalization_add_relu\")\ndef set_enable_cudnn_fused_normalization_add_relu(func_desc, value):\n    \"\"\"Whether enable cudnn_fused_normalization_add_relu.\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    func_desc.job_config_proto.enable_cudnn_fused_normalization_add_relu = value\n\n\n@oneflow_function_config(\"enable_fuse_add_to_output\")\ndef set_enable_fuse_add_to_output(func_desc, value):\n    \"\"\"Whether enable fuse_add_to_output.\n            If enabled, try to fuse a binary element-wise add to one of the predecessors to improve performance.\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    func_desc.job_config_proto.enable_fuse_add_to_output = value\n\n\n@oneflow_function_config(\"enable_fuse_cast_scale\")\ndef set_enable_fuse_cast_scale(func_desc, value=True):\n    \"\"\"Whether enable fuse_cast_scale.\n            If enabled, try to fuse cast and scalar_mul_by_tensor to improve performance.\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    func_desc.job_config_proto.enable_fuse_cast_scale = value\n\n\n@oneflow_function_config(\"cudnn_conv_use_deterministic_algo_only\")\ndef set_cudnn_conv_use_deterministic_algo_only(func_desc, value):\n    \"\"\"Set value to cudnn conv_use_deterministic_only algorithm\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    func_desc.job_config_proto.cudnn_conv_use_deterministic_algo_only = value\n\n\n@oneflow_function_config(\"enable_reuse_mem\")\ndef set_enable_reused_mem(func_desc, value=True):\n    \"\"\"Whether enable reuse memory or not\n\n    Args:\n        func_desc ([type]): [description]\n        value (bool, optional): [description]. 
Defaults to True.\n    \"\"\"\n    func_desc.job_config_proto.enable_reuse_mem = value\n\n\n@oneflow_function_config(\"enable_inplace\")\ndef set_enable_inplace(func_desc, value=True):\n    \"\"\"Whether to enable inplace or not\n\n    Args:\n        func_desc ([type]): [description]\n        value (bool, optional): [description]. Defaults to True.\n    \"\"\"\n    func_desc.job_config_proto.enable_inplace = value\n\n\n@oneflow_function_config(\"enable_inplace_in_reduce_struct\")\ndef set_enable_inplace_in_reduce_struct(func_desc, value=True):\n    print(\n        \"'enable_inplace_in_reduce_struct' has been deprecated, has no effect and will be removed in the future.\"\n    )\n\n\n@oneflow_function_config(\"enable_nccl\")\ndef set_enable_nccl(func_desc, value=True):\n    print(\n        \"'enable_nccl' has been deprecated, has no effect and will be removed in the future.\"\n    )\n\n\n@oneflow_function_config(\"use_nccl_inter_node_communication\")\ndef set_use_nccl_inter_node_communication(func_desc, value=True):\n    print(\n        \"'use_nccl_inter_node_communication' has been deprecated, has no effect and will be removed in the future.\"\n    )\n\n\n@oneflow_function_config(\"use_boxing_v2\")\ndef set_use_boxing_v2(func_desc, value=True):\n    print(\n        \"'use_boxing_v2' has been deprecated, has no effect and will be removed in the future.\"\n    )\n\n\n@oneflow_function_config(\"do_parallel_cast_before_widening_type_cast\")\ndef set_do_parallel_cast_before_widening_type_cast(func_desc, value=True):\n    func_desc.job_config_proto.do_parallel_cast_before_widening_type_cast = value\n\n\n@oneflow_function_config(\"enable_all_reduce_group\")\ndef set_enable_all_reduce_group(func_desc, value=True):\n    print(\n        \"'enable_all_reduce_group' has been deprecated, has no effect and will be removed in the future.\"\n    )\n\n\n@oneflow_function_config(\"all_reduce_group_num\")\ndef set_all_reduce_group_num(func_desc, value):\n    print(\n        \"'all_reduce_group_num' has been deprecated, has no effect and will be removed in the future.\"\n    )\n\n\n@oneflow_function_config(\"all_reduce_lazy_ratio\")\ndef set_all_reduce_lazy_ratio(func_desc, value):\n    print(\n        \"'all_reduce_lazy_ratio' has been deprecated, has no effect and will be removed in the future.\"\n    )\n\n\n@oneflow_function_config(\"all_reduce_group_min_mbyte\")\ndef set_all_reduce_group_min_mbyte(func_desc, value):\n    print(\n        \"'all_reduce_group_min_mbyte' has been deprecated, has no effect and will be removed in the future.\"\n    )\n\n\n@oneflow_function_config(\"all_reduce_group_size_warmup\")\ndef set_all_reduce_group_size_warmup(func_desc, value):\n    print(\n        \"'all_reduce_group_size_warmup' has been deprecated, has no effect and will be removed in the future.\"\n    )\n\n\n@oneflow_function_config(\"all_reduce_fp16\")\ndef set_all_reduce_fp16(func_desc, value=True):\n    print(\n        \"'all_reduce_fp16' has been deprecated, has no effect and will be removed in the future.\"\n    )\n\n\n@oneflow_function_config(\n    \"optimizer_placement_optimization_mode\",\n    \"train.optimizer_placement_optimization_mode\",\n)\ndef set_optimizer_placement_optimization_mode(func_desc, mode):\n    \"\"\"Enable optimizer_placement_optimization with the given mode\n\n    Args:\n        func_desc ([type]): [description]\n        mode (str): [description].\n    \"\"\"\n    assert mode in [\"non_distributed\", \"distributed_split\"]\n    func_desc.job_config_proto.optimizer_placement_optimization_mode 
= mode\n\n\n@oneflow_function_config(\n    \"optimizer_placement_optimization_threshold\",\n    \"train.optimizer_placement_optimization_threshold\",\n)\ndef set_optimizer_placement_optimization_threshold(func_desc, value):\n    func_desc.job_config_proto.optimizer_placement_optimization_threshold = value\n\n\n@oneflow_function_config(\"enable_non_distributed_optimizer\")\ndef set_enable_non_distributed_optimizer(func_desc, value=True):\n    \"\"\"Whether to enable the non_distributed optimizer or not\n\n    Args:\n        func_desc ([type]): [description]\n        value (bool, optional): [description]. Defaults to True.\n    \"\"\"\n    if value:\n        set_optimizer_placement_optimization_mode(func_desc, \"non_distributed\")\n\n\n@oneflow_function_config(\"disable_all_reduce_sequence\")\ndef set_disable_all_reduce_sequence(func_desc, value=True):\n    print(\n        \"'disable_all_reduce_sequence' has been deprecated, has no effect and will be removed in the future.\"\n    )\n\n\n@oneflow_function_config(\"prune_parallel_cast_ops\")\ndef set_prune_parallel_cast_ops(func_desc, value=True):\n    \"\"\"Whether to prune parallel cast operations or not.\n\n    Args:\n        func_desc ([type]): [description]\n        value (bool, optional): [description]. Defaults to True.\n    \"\"\"\n    func_desc.job_config_proto.prune_parallel_cast_ops = value\n\n\n@oneflow_function_config(\"prune_cast_to_static_shape_ops\")\ndef set_prune_cast_to_static_shape_ops(func_desc, value=True):\n    \"\"\"Whether to prune cast_to_static_shape operations or not\n\n    Args:\n        func_desc ([type]): [description]\n        value (bool, optional): [description]. Defaults to True.\n    \"\"\"\n    func_desc.job_config_proto.prune_cast_to_static_shape_ops = value\n\n\n@oneflow_function_config(\"prune_amp_white_identity_ops\")\ndef set_prune_amp_white_identity_ops(func_desc, value=True):\n    \"\"\"Whether to prune amp_white_identity operations or not.\n\n    Args:\n        func_desc ([type]): [description]\n        value (bool, optional): [description]. Defaults to True.\n    \"\"\"\n    func_desc.job_config_proto.prune_amp_white_identity_ops = value\n\n\n@oneflow_function_config(\"prune_depend_ops\")\ndef set_prune_depend_ops(func_desc, value=True):\n    \"\"\"Whether to prune depend operations or not.\n\n    Args:\n        func_desc ([type]): [description]\n        value (bool, optional): [description]. Defaults to True.\n    \"\"\"\n    func_desc.job_config_proto.prune_depend_ops = value\n\n\n@oneflow_function_config(\"non_distributed_optimizer_group_size_mbyte\")\ndef set_non_distributed_optimizer_group_size_mbyte(func_desc, value):\n    print(\n        \"'non_distributed_optimizer_group_size_mbyte' has been deprecated, has no effect and will be removed in the future.\"\n    )\n\n\n@oneflow_function_config(\n    \"enable_true_half_config_when_conv\", \"cudnn_conv_enable_true_half\"\n)\ndef set_cudnn_conv_enable_true_half(func_desc, value=True):\n    \"\"\"Whether to use true_half mode during the cudnn convolution calculation process or not.\n\n    Args:\n        func_desc ([type]): [description]\n        value (bool, optional): [description]. 
Defaults to True.\n    \"\"\"\n    func_desc.job_config_proto.cudnn_conv_enable_pseudo_half = not value\n\n\n@oneflow_function_config(\n    \"cudnn_conv_enable_pseudo_half\", \"enable_cudnn_conv_pseudo_half\"\n)\ndef set_cudnn_conv_enable_pseudo_half(func_desc, value):\n    \"\"\"Whether to enable pseudo_half mode during the cudnn convolution calculation process or not\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    func_desc.job_config_proto.cudnn_conv_enable_pseudo_half = value\n\n\n@oneflow_function_config(\"enable_float_compute_for_half_gemm\")\ndef set_enable_float_compute_for_half_gemm(func_desc, value=True):\n    \"\"\"Whether to enable float_compute or not.\n          If True, the type of the intermediate value is float when computing half gemm.\n\n    Args:\n        func_desc ([type]): [description]\n        value (bool, optional): [description]. Defaults to True.\n    \"\"\"\n    print(\n        \"WARNING: enable_float_compute_for_half_gemm has been deprecated, because we always use float compute for half gemm. Please remove it.\\n        \"\n    )\n    print(traceback.format_stack()[-3])\n\n\n@oneflow_function_config(\"enable_quantization_aware_training\")\n@oneflow_function_config(\"enable_qat\")\ndef set_enable_quantization_aware_training(func_desc, value=True):\n    \"\"\"If True, the job will use quantization aware training\n\n    Args:\n        func_desc ([type]): [description]\n        value (bool, optional): [description]. Defaults to True.\n    \"\"\"\n    func_desc.job_config_proto.enable_quantization_aware_training = value\n\n\n@oneflow_function_config(\"qat.per_channel_weight_quantization\")\ndef set_qat_per_channel(func_desc, value=True):\n    func_desc.job_config_proto.qat_config.per_channel_weight_quantization = value\n\n\n@oneflow_function_config(\"qat.symmetric\")\ndef set_qat_symmetric(func_desc, value=True):\n    func_desc.job_config_proto.qat_config.symmetric = value\n\n\n@oneflow_function_config(\"qat.moving_min_max_momentum\")\ndef set_qat_moving_min_max_momentum(func_desc, value: float):\n    func_desc.job_config_proto.qat_config.moving_min_max_momentum = value\n\n\n@oneflow_function_config(\"qat.moving_min_max_stop_update_after_iters\")\ndef set_qat_moving_min_max_stop_update_after_iters(func_desc, value: int):\n    func_desc.job_config_proto.qat_config.moving_min_max_stop_update_after_iters = value\n\n\n@oneflow_function_config(\"qat.target_backend\")\ndef set_qat_target_backend(func_desc, value: str):\n    func_desc.job_config_proto.qat_config.target_backend = value\n\n\n@oneflow_function_config(\"enable_auto_mixed_precision\")\ndef set_enable_auto_mixed_precision(func_desc, value=True):\n    \"\"\"If True, the job will use mixed precision mode, which means both float16 and float32 are used during model training.\n\n    Args:\n        func_desc ([type]): [description]\n        value (bool, optional): [description]. Defaults to True.\n    \"\"\"\n    func_desc.job_config_proto.enable_auto_mixed_precision = value\n\n\n@oneflow_function_config(\"enable_keep_header_only\")\ndef set_enable_keep_header_only(func_desc, value=True):\n    \"\"\"Deprecated api.\n\n    Args:\n        func_desc ([type]): [description]\n        value (bool, optional): [description]. Defaults to True.\n    \"\"\"\n    print(\"Sorry! 
enable_keep_header_only is deprecated and it doesn't work.\\n\")\n\n\n@oneflow_function_config(\"concurrency_width\")\ndef set_concurrency_width(func_desc, value):\n    \"\"\"Set up the concurrency width\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    func_desc.job_config_proto.concurrency_width = value\n\n\n@oneflow_function_config(\"train.model_update_conf\")\ndef set_model_update_conf(func_desc, value):\n    \"\"\"Set up the optimizer and the learning rate update method for the job\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    print(\n        \"WARNING: func_config.train.* has been deprecated. Please replace it with the new optimizer api.\\n        \"\n    )\n    print(traceback.format_stack()[-3])\n    assert type(value) is dict\n    func_desc.job_config_proto.train_conf.model_update_conf.SetInParent()\n    pb_util.PythonDict2PbMessage(\n        value, func_desc.job_config_proto.train_conf.model_update_conf\n    )\n\n\n@oneflow_function_config(\"indexed_slices_optimizer_conf\")\ndef set_indexed_slices_optimizer_conf(func_desc, value):\n    \"\"\"Set the indexed slices configuration of the optimizer\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    assert type(value) is dict\n    func_desc.job_config_proto.indexed_slices_optimizer_conf.SetInParent()\n    pb_util.PythonDict2PbMessage(\n        value, func_desc.job_config_proto.indexed_slices_optimizer_conf\n    )\n\n\n@oneflow_function_config(\"enable_fuse_model_update_ops\")\ndef set_enable_fuse_model_update_ops(func_desc, value=True):\n    \"\"\"Whether to enable fuse_model_update_ops.\n            If enabled, try to fuse cast + scale + l1_l2_regularize_gradient + model_update into one op to improve performance.\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    func_desc.job_config_proto.enable_fuse_model_update_ops = value\n\n\n@oneflow_function_config(\"enable_gradients_stats_aggregation\")\ndef set_enable_gradients_stats_aggregation(func_desc, value=True):\n    \"\"\"Whether to enable gradients_stats_aggregation.\n            If enabled, gradients stats ops (norm, finite, ...) will be aggregated.\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    func_desc.job_config_proto.enable_gradients_stats_aggregation = value\n\n\n@oneflow_function_config(\"train.loss_scale_factor\")\ndef set_loss_scale_factor(func_desc, value):\n    \"\"\"Set the scale factor for the loss\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    print(\n        \"WARNING: func_config.train.* has been deprecated. Please replace it with the new optimizer api.\\n        \"\n    )\n    print(traceback.format_stack()[-3])\n    func_desc.job_config_proto.train_conf.loss_scale_factor = value\n\n\n@oneflow_function_config(\"train.primary_lr\")\ndef set_primary_lr(func_desc, value):\n    \"\"\"Set the primary learning rate for the job\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    print(\n        \"WARNING: func_config.train.* has been deprecated. 
Please replace it with the new optimizer api.\\n        \"\n    )\n    print(traceback.format_stack()[-3])\n    func_desc.job_config_proto.train_conf.primary_lr = value\n\n\n@oneflow_function_config(\"train.secondary_lr\")\ndef set_secondary_lr(func_desc, value):\n    \"\"\"Set the secondary learning rate for the job\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    print(\n        \"WARNING: func_config.train.* has been deprecated. Please replace it with the new optimizer api.\\n        \"\n    )\n    print(traceback.format_stack()[-3])\n    func_desc.job_config_proto.train_conf.secondary_lr = value\n\n\n@oneflow_function_config(\"train.num_gradient_accumulation_steps\")\ndef set_num_gradient_accumulation_steps(func_desc, value):\n    func_desc.job_config_proto.num_gradient_accumulation_steps = value\n\n\n@oneflow_function_config(\"default_logical_view\")\ndef set_default_distribute_strategy(func_desc, value):\n    \"\"\"Set up the default distribute strategy for the job\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    assert isinstance(value, distribute_ctx.DistributeStrategy)\n    func_desc.function_attribute.default_distribute_strategy = value\n\n\n@oneflow_function_config(\"allow_cpu_return_op\")\ndef allow_cpu_return_op(func_desc, value):\n    \"\"\"Whether to allow operations returned from the CPU or not\n\n    Args:\n        func_desc ([type]): [description]\n        value ([type]): [description]\n    \"\"\"\n    func_desc.function_attribute.allow_cpu_return_op = value\n\n\n@oneflow_function_config(\"default_distribute_strategy\")\n@oneflow_deprecate()\ndef deprecated_set_default_distribute_strategy(*args, **kwargs):\n    print(\n        \"WARNING:\",\n        \"function_config.default_distribute_strategy\",\n        \"has been deprecated. Please use {} instead.\".format(\n            \"function_config.default_logical_view\"\n        ),\n    )\n    print(traceback.format_stack()[-3], file=sys.stderr)\n    set_default_distribute_strategy(*args, **kwargs)\n
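\n# --- Illustrative sketch (not executed): how the setters above are reached. ---\n# Each `@oneflow_function_config(\"some.key\")` registration exposes \"some.key\" as an\n# attribute path on a function-config object. Assuming the legacy lazy-mode API\n# (e.g. a hypothetical `func_config = oneflow.FunctionConfig()`), user code along\n# these lines ends up calling the setters defined in this module:\n#\n#     func_config.enable_auto_mixed_precision(True)  # -> set_enable_auto_mixed_precision\n#     func_config.cudnn_buf_limit_mbyte(1024)        # -> set_cudnn_buf_limit_mbyte\n"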
  },
  {
    "path": "python/oneflow/framework/generator.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nimport oneflow._oneflow_internal\n\n\ndef create_generator(device=None):\n    if device is None:\n        device = \"auto\"\n    return oneflow._oneflow_internal.create_generator(device)\n\n\ndef seed() -> int:\n    r\"\"\"\n    Sets the seed for generating random numbers to a non-deterministic\n    random number. Returns a 64 bit number used to seed the RNG.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.seed.html.\n    \"\"\"\n    seed = default_generator.seed()\n    oneflow._oneflow_internal.manual_seed(seed)\n    return seed\n\n\ndef manual_seed(seed):\n    r\"\"\"\n    Sets the seed for generating random numbers. Returns a\n    `oneflow.Generator` object.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.manual_seed.html.\n\n    Args:\n        seed (int): The desired seed. Value must be within the inclusive range\n            `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError\n            is raised. Negative inputs are remapped to positive values with the formula\n            `0xffff_ffff_ffff_ffff + seed`.\n    \"\"\"\n    seed = int(seed)\n    return oneflow._oneflow_internal.manual_seed(seed)\n\n\ndef initial_seed() -> int:\n    r\"\"\"\n    Returns the initial seed for generating random numbers as a\n    Python `long`.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/_modules/torch/random.html.\n    \n    \"\"\"\n    return default_generator.initial_seed()\n\n\ndef _getstate(self):\n    return {\"device\": str(self.device), \"state\": self.get_state()}\n\n\ndef _setstate(self, state_dict):\n    self.__init__(state_dict[\"device\"])\n    self.set_state(state_dict[\"state\"])\n\n\ndef get_rng_state():\n    r\"\"\"\n    Sets the random number generator state.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.get_rng_state.html.\n    \n    .. note: This function only works for CPU. For CUDA, please use\n             oneflow.manual_seed(seed), which works for both CPU and CUDA.\n\n    Args:\n        new_state (oneflow.ByteTensor): The desired state\n    \"\"\"\n    return oneflow.default_generator.get_state()\n\n\ndef set_rng_state(state):\n    \"\"\"\n    Returns the random number generator state as a `oneflow.ByteTensor`.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.set_rng_state.html.\n    \n    \"\"\"\n\n    return oneflow.default_generator.set_state(state)\n\n\ndefault_generator = oneflow._oneflow_internal.default_generator(\"cpu\")\noneflow._oneflow_internal.Generator.__getstate__ = _getstate\noneflow._oneflow_internal.Generator.__setstate__ = _setstate\n"
  },
  {
    "path": "python/oneflow/framework/graph_build_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom contextlib import contextmanager\nimport os\n\nfrom google.protobuf import text_format\nimport oneflow\n\nimport oneflow._oneflow_internal\nimport oneflow.core.job.scope_pb2 as scope_pb2_util\nimport oneflow.core.job.job_conf_pb2 as job_conf_pb\nimport oneflow.framework.attr_util as attr_util\nimport oneflow.framework.c_api_util as c_api_util\nimport oneflow.framework.scope_util as scope_util\nimport oneflow.framework.session_context as session_context\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.nn.graph.proxy import GraphBlockType\nimport oneflow._oneflow_internal._C as _C\n\nlazy_mode = oneflow._oneflow_internal.lazy_mode\n\n\n@contextmanager\ndef graph_build_context(config_proto, session):\n    prev_scope = oneflow._oneflow_internal.GetCurrentScope()\n    assert type(config_proto) is job_conf_pb.JobConfigProto, type(config_proto)\n    config_proto_str = text_format.MessageToString(config_proto)\n    new_scope = oneflow._oneflow_internal.MakeInitialScope(\n        config_proto_str, oneflow.placement(\"cpu\", [0]), False,  # is_local\n    )\n\n    graph_scope = _make_new_graph_scope(new_scope, config_proto.job_name)\n\n    oneflow._oneflow_internal.eager.Sync()\n    with lazy_mode.guard(True):\n        with JobBuildAndInferCtx(config_proto):\n            with BlockScopeContext(prev_scope, graph_scope):\n                yield\n\n\nclass JobBuildAndInferCtx(object):\n    def __init__(self, config_proto):\n        self._job_conf = config_proto\n\n    def __enter__(self):\n        c_api_util.JobBuildAndInferCtx_Open(self._job_conf.job_name)\n        c_api_util.CurJobBuildAndInferCtx_SetJobConf(self._job_conf)\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        if exc_type is None:\n            oneflow._oneflow_internal.JobBuildAndInferCtx_Close()\n            return True\n        else:\n            oneflow._oneflow_internal.JobBuildAndInferCtx_Close()\n            return False\n\n\nclass BlockScopeContext(object):\n    def __init__(self, prev_scope, new_scope):\n        assert prev_scope is not None\n        assert new_scope is not None\n        self._prev_scope = prev_scope\n        self._new_scope = new_scope\n\n    def __enter__(self):\n        assert oneflow._oneflow_internal.GetCurrentScope() is self._prev_scope\n        oneflow._oneflow_internal.GlobalScopeStackPush(self._new_scope)\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        assert oneflow._oneflow_internal.GetCurrentScope() is self._new_scope\n        oneflow._oneflow_internal.GlobalScopeStackPop()\n        assert oneflow._oneflow_internal.GetCurrentScope() is self._prev_scope\n        if exc_type is None:\n            return True\n        else:\n            return False\n\n\nclass DebugScopeContext(object):\n    def __init__(\n        self,\n        s_level,\n        v_level=0,\n        mode=False,\n        max_py_stack_depth=2,\n        only_user_py_stack=True,\n   
 ):\n        self._prev_v = oneflow._oneflow_internal.GetFLAGS_v()\n        self._prev_logtostderr = oneflow._oneflow_internal.GetFLAGS_alsologtostderr()\n        self._prev_mode = oneflow._oneflow_internal.GetGraphDebugMode()\n        self._prev_max_py_stack_depth = (\n            oneflow._oneflow_internal.GetGraphDebugMaxPyStackDepth()\n        )\n        self._prev_only_user_py_stack = (\n            oneflow._oneflow_internal.GetGraphDebugOnlyUserPyStack()\n        )\n        self._v = max(v_level, self._prev_v)\n        self._mode = mode\n        self._s = s_level\n        self._max_py_stack_depth = max(\n            max_py_stack_depth, self._prev_max_py_stack_depth\n        )\n        self._only_user_py_stack = only_user_py_stack\n\n    def __enter__(self):\n        oneflow._oneflow_internal.SetFLAGS_v(self._v)\n        oneflow._oneflow_internal.SetGraphDebugMode(self._mode)\n        if self._s == 0 and self._v >= 1:\n            oneflow._oneflow_internal.SetFLAGS_alsologtostderr(True)\n        oneflow._oneflow_internal.SetGraphDebugMaxPyStackDepth(self._max_py_stack_depth)\n        oneflow._oneflow_internal.SetGraphDebugOnlyUserPyStack(self._only_user_py_stack)\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        if self._s == 0 and self._v >= 1:\n            oneflow._oneflow_internal.SetFLAGS_alsologtostderr(self._prev_logtostderr)\n        oneflow._oneflow_internal.SetFLAGS_v(self._prev_v)\n        oneflow._oneflow_internal.SetGraphDebugMode(self._prev_mode)\n        oneflow._oneflow_internal.SetGraphDebugMaxPyStackDepth(\n            self._prev_max_py_stack_depth\n        )\n        oneflow._oneflow_internal.SetGraphDebugOnlyUserPyStack(\n            self._prev_only_user_py_stack\n        )\n\n\ndef _make_new_scope(prev_scope, scope_proto_str_setter):\n    new_scope = None\n\n    def build_scope(builder):\n        nonlocal new_scope\n        new_scope = builder.BuildScopeByProtoStrSetter(\n            prev_scope, scope_proto_str_setter\n        )\n        assert new_scope is not None\n\n    oneflow._oneflow_internal.deprecated.PhysicalRun(build_scope)\n    oneflow._oneflow_internal.eager.Sync()\n    return new_scope\n\n\ndef _make_new_graph_scope(prev_scope, graph_name):\n    assert prev_scope is not None\n    attr_dict = dict()\n    name2default = session_context.GetDefaultSession().scope_attr_name2default_val\n\n    def scope_proto_str_setter(serialized_scope_proto: str):\n        scope_proto = text_format.Parse(\n            serialized_scope_proto, scope_pb2_util.ScopeProto()\n        )\n        scope_proto.module_name = graph_name\n        return str(text_format.MessageToString(scope_proto))\n\n    return _make_new_scope(prev_scope, scope_proto_str_setter)\n\n\ndef make_new_blockgraph_scope(prev_scope, graph_block):\n    assert prev_scope is not None\n    assert graph_block is not None\n    attr_dict = dict()\n    if graph_block.stage_id is not None:\n        attr_dict[\"pipeline_stage_id_hint\"] = graph_block.stage_id\n    if graph_block.type == GraphBlockType.MODULE:\n        if graph_block.activation_checkpointing is not None:\n            attr_dict[\"checkpointing\"] = graph_block.activation_checkpointing\n\n    name2default = session_context.GetDefaultSession().scope_attr_name2default_val\n\n    def scope_proto_str_setter(serialized_scope_proto: str):\n        scope_proto = text_format.Parse(\n            serialized_scope_proto, scope_pb2_util.ScopeProto()\n        )\n        # set attr\n        for attr_name, py_value in attr_dict.items():\n            assert 
attr_name in name2default\n            attr_util.SetProtoAttrValue(\n                scope_proto.attr_name2attr_value[attr_name],\n                py_value,\n                name2default[attr_name],\n            )\n        # append name prefix\n        scope_proto.ClearField(\"scope_op_name_prefixes\")\n        scope_proto.scope_op_name_prefixes.append(\n            graph_block.name_prefix + graph_block.name\n        )\n        # set module name\n        if graph_block.type == GraphBlockType.MODULE:\n            scope_proto.module_name = graph_block.name_prefix + graph_block.name\n        return str(text_format.MessageToString(scope_proto))\n\n    return _make_new_scope(prev_scope, scope_proto_str_setter)\n\n\ndef make_new_name_scope(prev_scope, name):\n    assert prev_scope is not None\n\n    def scope_proto_str_setter(serialized_scope_proto: str):\n        scope_proto = text_format.Parse(\n            serialized_scope_proto, scope_pb2_util.ScopeProto()\n        )\n        # append name prefix\n        scope_proto.ClearField(\"scope_op_name_prefixes\")\n        scope_proto.scope_op_name_prefixes.append(name)\n        scope_proto.module_name = name\n        return str(text_format.MessageToString(scope_proto))\n\n    return _make_new_scope(prev_scope, scope_proto_str_setter)\n\n\ndef scope_to_proto(scope):\n    return text_format.Parse(scope._proto_str, scope_pb2_util.ScopeProto())\n\n\ndef build_graph_input_arg(op_name, arg):\n    assert isinstance(arg, Tensor)\n    input_conf = oneflow.core.operator.op_conf_pb2.FeedInputOpConf()\n    input_conf.in_0 = \"in_0\"  # Set the default value, otherwise the parsing fails\n    input_conf.out_0 = \"out_0\"\n    input_conf_str = text_format.MessageToString(input_conf)\n\n    input_op = oneflow._oneflow_internal.one.FeedInputOpExpr(\n        op_name, input_conf_str, [\"in_0\"], [\"out_0\"]\n    )\n    lazy_arg = _C.dispatch_feed_input(input_op, arg)\n    return lazy_arg\n\n\ndef build_graph_state(op_name, state_tensor, state_config):\n    var_conf = oneflow.core.operator.op_conf_pb2.FeedVariableOpConf()\n    var_conf.in_0 = \"in_0\"  # Set the default value, otherwise the parsing fails\n    var_conf.out_0 = \"out_0\"\n    var_conf_str = text_format.MessageToString(var_conf)\n\n    var_op = oneflow._oneflow_internal.one.FeedVariableOpExpr(\n        op_name, var_conf_str, [\"in_0\"], [\"out_0\"]\n    )\n    l2 = 0.0\n    if state_config is not None:\n        l2 = state_config.l2\n    elif state_tensor.requires_grad:\n        l2 = 0.0\n\n    assert isinstance(state_tensor, Tensor)\n    lazy_tensor = _C.dispatch_feed_variable(var_op, state_tensor, l2=l2)\n    return lazy_tensor\n\n\ndef build_graph_output(op_name, out):\n    assert isinstance(out, Tensor)\n\n    output_conf = oneflow.core.operator.op_conf_pb2.FetchOutputOpConf()\n    output_conf.in_0 = \"in_0\"  # Set the default value, otherwise the parsing fails\n    output_conf.out_0 = \"out_0\"\n    output_conf_str = text_format.MessageToString(output_conf)\n\n    output_op = oneflow._oneflow_internal.one.FetchOutputOpExpr(\n        op_name, output_conf_str, [\"in_0\"], [\"out_0\"]\n    )\n    fake_eager_out = _C.dispatch_fetch_output(output_op, out)\n    return fake_eager_out\n"
  },
  {
    "path": "python/oneflow/framework/hob.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.support.high_order_bool import bool_functor\n\n\n\"\"\"Example:\n@bool_functor(\"Current mode is %s\" % rt_mode.GLOBAL_MODE)\ndef in_global_mode(ctx):\n    return rt_mode.CurrentMode() == rt_mode.GLOBAL_MODE\n\"\"\"\n"
  },
  {
    "path": "python/oneflow/framework/id_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow._oneflow_internal\n\n\ndef UniqueStr(prefix):\n    return oneflow._oneflow_internal.UniqueStr(prefix)\n"
  },
  {
    "path": "python/oneflow/framework/infer_compiler/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\ntry:\n    import torch\nexcept ImportError:\n    print(\"You should install torch also when use `oneflow.framework.infer_compiler`.\")\n\nfrom .transform.custom_transform import register\nfrom .utils.patch_for_compiler import *\nfrom .with_fx_graph import fx_node_tranform\nfrom .with_fx_interpreter import OneFlowInterpreter\nfrom .with_oneflow_compile import compile_from_torch\nfrom .with_oneflow_backend import oneflow_backend\n"
  },
  {
    "path": "python/oneflow/framework/infer_compiler/import_tools/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\"\"\" Tools for importing modules and packages\"\"\"\nfrom .importer import LazyMocker, import_module_from_path\n"
  },
  {
    "path": "python/oneflow/framework/infer_compiler/import_tools/format_utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport inspect\nfrom types import FunctionType\nfrom typing import Union\n\n\nclass MockEntityNameFormatter:\n    def __init__(self, prefix: str = \"mock_\", suffix: str = \"_oflow\"):\n        self.prefix = prefix\n        self.suffix = suffix\n\n    def _format_pkg_name(self, pkg_name: str) -> str:\n        if pkg_name.startswith(self.prefix) and pkg_name.endswith(self.suffix):\n            return pkg_name\n        return self.prefix + pkg_name + self.suffix\n\n    def _reverse_pkg_name(self, pkg_name: str) -> str:\n        assert pkg_name.startswith(self.prefix) and pkg_name.endswith(\n            self.suffix\n        ), f\"Package name must start with {self.prefix} and end with {self.suffix}, but got {pkg_name}\"\n        return pkg_name[len(self.prefix) : -len(self.suffix)]\n\n    def _format_full_class_name(self, obj: Union[str, type, FunctionType]):\n        if isinstance(obj, type):\n            obj = f\"{obj.__module__}.{obj.__qualname__}\"\n\n        elif isinstance(obj, FunctionType):\n            module = inspect.getmodule(obj)\n            obj = f\"{module.__name__}.{obj.__qualname__}\"\n\n        assert isinstance(obj, str), f\"obj must be str, but got {type(obj)}\"\n\n        if \".\" in obj:\n            pkg_name, cls_name = obj.split(\".\", 1)\n            return f\"{self._format_pkg_name(pkg_name)}.{cls_name}\"\n        else:\n            return self._format_pkg_name(obj)\n\n    def format(self, entity: Union[str, type, FunctionType]) -> str:\n        return self._format_full_class_name(entity)\n\n    def unformat(self, mock_entity_name: str) -> str:\n        if \".\" in mock_entity_name:\n            pkg_name, cls_name = mock_entity_name.split(\".\", 1)\n            return f\"{self._reverse_pkg_name(pkg_name)}.{cls_name}\"\n        else:  # mock_entity_name is a pkg_name\n            return self._reverse_pkg_name(mock_entity_name)\n"
  },
  {
    "path": "python/oneflow/framework/infer_compiler/import_tools/importer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport importlib\nimport os\nimport sys\nfrom pathlib import Path\nfrom types import FunctionType, ModuleType\nfrom typing import Optional, Union\n\nfrom oneflow.mock_torch import DynamicMockModule\n\nfrom .format_utils import MockEntityNameFormatter\n\nif sys.version_info < (3, 8):\n    try:\n        from importlib_metadata import requires\n    except ImportError:\n        import subprocess\n\n        subprocess.check_call(\"pip install importlib_metadata\", shell=True)\n        subprocess.check_call(\"pip install packaging\", shell=True)\nelse:\n    from importlib.metadata import requires\n\n__all__ = [\"import_module_from_path\", \"LazyMocker\", \"is_need_mock\"]\n\n\ndef is_need_mock(cls) -> bool:\n    assert isinstance(cls, (type, str))\n    main_pkg = cls.__module__.split(\".\")[0]\n    try:\n        pkgs = requires(main_pkg)\n    except Exception as e:\n        return True\n    if pkgs:\n        for pkg in pkgs:\n            pkg = pkg.split(\" \")[0]\n            if pkg == \"torch\":\n                return True\n        return False\n    return True\n\n\ndef import_module_from_path(module_path: Union[str, Path]) -> ModuleType:\n    if isinstance(module_path, Path):\n        module_path = str(module_path)\n    module_name = os.path.basename(module_path)\n    if os.path.isfile(module_path):\n        sp = os.path.splitext(module_path)\n        module_name = sp[0]\n\n    if os.path.isfile(module_path):\n        module_spec = importlib.util.spec_from_file_location(module_name, module_path)\n        module_dir = os.path.split(module_path)[0]\n    else:\n        module_spec = importlib.util.spec_from_file_location(\n            module_name, os.path.join(module_path, \"__init__.py\")\n        )\n        module_dir = module_path\n\n    module = importlib.util.module_from_spec(module_spec)\n    sys.modules[module_name] = module\n    module_spec.loader.exec_module(module)\n    return module\n\n\nclass LazyMocker:\n    def __init__(self, prefix: str, suffix: str, tmp_dir: Optional[Union[str, Path]]):\n        self.prefix = prefix\n        self.suffix = suffix\n        self.tmp_dir = tmp_dir\n        self.mocked_packages = set()\n        self.cleanup_list = []\n\n    def mock_package(self, package: str):\n        pass\n\n    def cleanup(self):\n        pass\n\n    def get_mock_entity_name(self, entity: Union[str, type, FunctionType]):\n        formatter = MockEntityNameFormatter(prefix=self.prefix, suffix=self.suffix)\n        full_obj_name = formatter.format(entity)\n        return full_obj_name\n\n    def mock_entity(self, entity: Union[str, type, FunctionType]):\n        \"\"\"Mock the entity and return the mocked entity\n\n        Example:\n            >>> mocker = LazyMocker(prefix=\"mock_\", suffix=\"_of\", tmp_dir=\"tmp\")\n            >>> mocker.mock_entity(\"models.DemoModel\")\n            <class 'mock_models_of.DemoModel'>\n            >>> cls_obj = 
models.DemoModel\n            >>> mocker.mock_entity(cls_obj)\n            <class 'mock_models_of.DemoModel'>\n        \"\"\"\n        return self.load_entity_with_mock(entity)\n\n    def add_mocked_package(self, package: str):\n        if package in self.mocked_packages:\n            return\n\n        self.mocked_packages.add(package)\n        pkg_module = sys.modules.get(package, None)\n\n        # TODO remove code below\n        # fix the mock error in https://github.com/siliconflow/oneflow/blob/main/python/oneflow/mock_torch/mock_importer.py#L105-L118\n        if pkg_module and getattr(pkg_module, \"__file__\", None) is not None:\n            pkg_path = Path(pkg_module.__file__).parents[1]\n            if str(pkg_path) not in sys.path:\n                sys.path.append(str(pkg_path))\n\n    def load_entity_with_mock(self, entity: Union[str, type, FunctionType]):\n        formatter = MockEntityNameFormatter(prefix=self.prefix, suffix=self.suffix)\n        full_obj_name = formatter.format(entity)\n        attrs = full_obj_name.split(\".\")\n\n        # add package path to sys.path to avoid mock error\n        self.add_mocked_package(attrs[0])\n\n        mock_pkg = DynamicMockModule.from_package(attrs[0], verbose=False)\n        for name in attrs[1:]:\n            mock_pkg = getattr(mock_pkg, name)\n        return mock_pkg\n
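\n# --- Illustrative sketch (not executed): `is_need_mock` returns True when the\n# entity's top-level package declares a dependency on torch, or when its package\n# metadata cannot be inspected at all:\n#\n#     is_need_mock(\"diffusers.models.attention\")  # True if diffusers requires torch\n#     is_need_mock(int)                           # no metadata for \"builtins\" -> True\n"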
  },
  {
    "path": "python/oneflow/framework/infer_compiler/transform/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\"\"\"Module to convert PyTorch code to OneFlow.\"\"\"\nfrom .builtin_transform import (\n    ProxySubmodule,\n    default_converter,\n    get_attr,\n    map_args,\n    proxy_class,\n    torch2oflow,\n)\nfrom .custom_transform import register\nfrom .manager import transform_mgr\n"
  },
  {
    "path": "python/oneflow/framework/infer_compiler/transform/builtin_transform.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\"\"\"Convert torch object to oneflow object.\"\"\"\nimport importlib\nimport os\nimport types\nfrom collections import OrderedDict\nfrom collections.abc import Iterable\nfrom functools import partial, singledispatch\nfrom typing import Any, Union\n\nimport oneflow as flow\nimport torch\nfrom oneflow.framework.infer_compiler.import_tools.importer import is_need_mock\nfrom oneflow.framework.infer_compiler.utils.log_utils import logger\nfrom oneflow.framework.infer_compiler.utils.patch_for_diffusers import diffusers_checker\n\nfrom .manager import transform_mgr\n\n__all__ = [\n    \"proxy_class\",\n    \"ProxySubmodule\",\n    \"map_args\",\n    \"get_attr\",\n    \"torch2oflow\",\n    \"default_converter\",\n]\n\n\ndef singledispatch_proxy(func):\n    dispatcher = singledispatch(func)\n    _warning_set = set()\n\n    def wrapper(first_param, *args, **kwargs):\n        nonlocal _warning_set\n\n        before = first_param.__class__.__name__\n        result = dispatcher(first_param, *args, **kwargs)\n        after = result.__class__.__name__\n\n        description = f\"{before} transformed to  {after}\"\n\n        if before not in after and description not in _warning_set:\n            _warning_set.add(description)\n            logger.info(f\"instance_name: {description}\")\n        return result\n\n    wrapper.register = dispatcher.register\n    wrapper.dispatch = dispatcher.dispatch\n    return wrapper\n\n\ndef proxy_class(cls: type):\n    if cls.__module__.startswith(\"torch.\"):\n        mod_name = cls.__module__.replace(\"torch.\", \"oneflow.\")\n        mod = importlib.import_module(mod_name)\n        return getattr(mod, cls.__name__)\n\n    full_qualified_name = cls.__module__ + \".\" + cls.__qualname__\n    result = transform_mgr.transform_cls(full_qualified_name)\n    return result\n\n\nclass ProxySubmodule:\n    def __init__(self, submod):\n        self._oflow_proxy_submod = submod\n        self._oflow_proxy_parameters = {}\n        self._oflow_proxy_children = {}\n\n    def __getitem__(self, index):  # __getitem__\n        if isinstance(self._oflow_proxy_submod, Iterable):\n            submod = self._oflow_proxy_submod[index]\n            return torch2oflow(submod)\n\n        raise RuntimeError(f\"can't getitem for: {type(self._oflow_proxy_submod)}\")\n\n    def __repr__(self) -> str:\n        return \" oflow_proxy: \" + self._oflow_proxy_submod.__repr__()\n\n    def __getattribute__(self, attribute):\n        if attribute.startswith(\"_oflow_proxy\"):\n            return object.__getattribute__(self, attribute)\n        elif attribute in [\"forward\", \"_conv_forward\"]:\n            replacement = proxy_class(type(self._oflow_proxy_submod))\n            return lambda *args, **kwargs: getattr(replacement, attribute)(\n                self, *args, **kwargs\n            )\n        elif (\n            
diffusers_checker.is_attention_instance(self._oflow_proxy_submod)\n            and attribute == \"get_attention_scores\"\n        ):\n            replacement = proxy_class(type(self._oflow_proxy_submod))\n            return lambda *args, **kwargs: getattr(replacement, attribute)(\n                self, *args, **kwargs\n            )\n        elif (\n            isinstance(self._oflow_proxy_submod, torch.nn.Linear)\n            and attribute == \"use_fused_matmul_bias\"\n        ):\n            return (\n                self.bias is not None\n                and os.getenv(\"ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR\") == \"1\"\n            )\n        elif (\n            isinstance(self._oflow_proxy_submod, torch.nn.Dropout)\n            and attribute == \"generator\"\n        ):\n            return flow.Generator()\n        elif (\n            isinstance(self._oflow_proxy_submod, (torch.nn.Conv2d, torch.nn.Conv3d))\n            and attribute == \"channel_pos\"\n        ):\n            return \"channels_first\"\n        else:\n            a = getattr(self._oflow_proxy_submod, attribute)\n\n            if isinstance(a, (torch.nn.parameter.Parameter, torch.Tensor)):\n                # TODO(oneflow): assert a.requires_grad == False\n                if attribute not in self._oflow_proxy_parameters:\n                    a = torch2oflow(a)\n\n                    self._oflow_proxy_parameters[attribute] = a\n                else:\n                    a = self._oflow_proxy_parameters[attribute]\n            elif isinstance(\n                a, (torch.nn.Module, torch.nn.ModuleList, torch.nn.Sequential)\n            ):\n                if attribute not in self._oflow_proxy_children:\n                    a = torch2oflow(a)\n\n                    self._oflow_proxy_children[attribute] = a\n                else:\n                    a = self._oflow_proxy_children[attribute]\n\n            return a\n\n    def __call__(self, *args: Any, **kwargs: Any) -> Any:\n        replacement = proxy_class(type(self._oflow_proxy_submod))\n\n        if replacement is not None:\n            return replacement.__call__(self, *args, **kwargs)\n        else:\n            raise RuntimeError(\n                \"can't find oneflow module for: \" + str(type(self._oflow_proxy_submod))\n            )\n\n\n@singledispatch_proxy\ndef torch2oflow(mod, *args, **kwargs):\n    return default_converter(mod, *args, **kwargs)\n\n\ndef default_converter(obj, verbose=False, *, proxy_cls=None):\n    if not is_need_mock(type(obj)):\n        return obj\n    try:\n        new_obj_cls = proxy_class(type(obj)) if proxy_cls is None else proxy_cls\n\n        def init(self):\n            for k, _ in obj.__dict__.items():\n                attr = getattr(obj, k)\n                self.__dict__[k] = torch2oflow(attr)\n\n        of_obj_cls = type(str(new_obj_cls), (new_obj_cls,), {\"__init__\": init})\n        of_obj = of_obj_cls()\n\n        if verbose:\n            logger.info(f\"convert {type(obj)} to {type(of_obj)}\")\n        return of_obj\n    except Exception as e:\n        logger.warning(f\"Unsupported type: {type(obj)} {e}\")\n        # raise NotImplementedError(f\"Unsupported type: {obj}\")\n        return obj\n\n\n@torch2oflow.register\ndef _(mod: torch.nn.Module, verbose=False):\n    proxy_md = ProxySubmodule(mod)\n\n    new_md_cls = proxy_class(type(mod))\n\n    def init(self):\n        nonlocal proxy_md\n\n        flow.nn.Module.__init__(self)\n\n        self._parameters = OrderedDict()\n        self._buffers = OrderedDict()\n        self._modules 
= OrderedDict()\n        for n, p in list(proxy_md.named_parameters(\"\", False)):\n            self._parameters[n] = torch2oflow(p)\n        for n, b in list(proxy_md.named_buffers(\"\", False)):\n            self._buffers[n] = flow.utils.tensor.from_torch(b.data)\n        for n, m in proxy_md._modules.items():\n            self._modules[n] = torch2oflow(m)\n\n        for k, _ in proxy_md.__dict__.items():\n            if k not in self.__dict__:\n                attr = getattr(proxy_md, k)\n                try:\n                    self.__dict__[k] = torch2oflow(attr)\n\n                except Exception as e:\n                    logger.error(f\"convert {type(attr)} failed: {e}\")\n                    raise NotImplementedError(f\"Unsupported type: {type(attr)}\")\n\n    def proxy_getattr(self, attr):\n        nonlocal proxy_md\n\n        try:\n            return super().__getattribute__(attr)\n        except Exception:\n            if attr in self._modules:\n                return self._modules[attr]\n            if attr in self._parameters:\n                return self._parameters[attr]\n            elif attr in self._buffers:\n                return self._buffers[attr]\n            else:\n                return getattr(proxy_md, attr)\n\n    of_mod_cls = type(\n        str(new_md_cls), (new_md_cls,), {\"__init__\": init, \"__getattr__\": proxy_getattr}\n    )\n    of_mod = of_mod_cls()\n    if of_mod.training:\n        of_mod.training = False\n        if verbose:\n            logger.info(\n                f\"Warning: {type(of_mod)} is in training mode and is turned into eval mode, which is good for inference optimization.\"\n            )\n\n    if verbose:\n        logger.info(f\"convert {type(mod)} to {type(of_mod)}\")\n\n    return of_mod\n\n\n@torch2oflow.register\ndef _(mod: torch.nn.BatchNorm1d, verbose=False):\n    of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)\n    of_mod.channel_axis = 1\n\n    return of_mod\n\n\n@torch2oflow.register\ndef _(mod: torch.nn.BatchNorm2d, verbose=False):\n    of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)\n    if os.getenv(\"ONEFLOW_ENABLE_NHWC\"):\n        of_mod.channel_axis = 3\n    else:\n        of_mod.channel_axis = 1\n\n    return of_mod\n\n\n@torch2oflow.register\ndef _(mod: torch.nn.BatchNorm3d, verbose=False):\n    of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)\n    of_mod.channel_axis = 1\n\n    return of_mod\n\n\n@torch2oflow.register\ndef _(mod: torch.nn.MaxPool1d, verbose=False):\n    of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)\n    of_mod.channel_pos = \"channels_first\"\n\n    return of_mod\n\n\n@torch2oflow.register\ndef _(mod: torch.nn.MaxPool2d, verbose=False):\n    of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)\n    if os.getenv(\"ONEFLOW_ENABLE_NHWC\"):\n        of_mod.channel_pos = \"channels_last\"\n    else:\n        of_mod.channel_pos = \"channels_first\"\n\n    return of_mod\n\n\n@torch2oflow.register\ndef _(mod: torch.nn.MaxPool3d, verbose=False):\n    of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)\n    of_mod.channel_pos = \"channels_first\"\n\n    return of_mod\n\n\n@torch2oflow.register\ndef _(mod: torch.nn.AvgPool1d, verbose=False):\n    of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)\n    of_mod.channel_pos = \"channels_first\"\n\n    return of_mod\n\n\n@torch2oflow.register\ndef _(mod: torch.nn.AvgPool2d, verbose=False):\n    of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, 
verbose)\n    if os.getenv(\"ONEFLOW_ENABLE_NHWC\"):\n        of_mod.channel_pos = \"channels_last\"\n    else:\n        of_mod.channel_pos = \"channels_first\"\n\n    return of_mod\n\n\n@torch2oflow.register\ndef _(mod: torch.nn.AvgPool3d, verbose=False):\n    of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)\n    of_mod.channel_pos = \"channels_first\"\n\n    return of_mod\n\n\n@torch2oflow.register\ndef _(mod: torch.nn.AdaptiveAvgPool2d, verbose=False):\n    of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)\n    if os.getenv(\"ONEFLOW_ENABLE_NHWC\"):\n        of_mod.channel_pos = \"channels_last\"\n    else:\n        of_mod.channel_pos = \"channels_first\"\n\n    return of_mod\n\n\ntry:\n    from torchvision.ops import Conv2dNormActivation\n\n    @torch2oflow.register\n    def _(mod: Conv2dNormActivation, verbose=False):\n        return flow.nn.Sequential(*[torch2oflow(layer) for layer in mod])\n\n\nexcept ImportError:\n    logger.warning(\"Failed to import torchvision\")\n\n\n@torch2oflow.register\ndef _(mod: torch.nn.ModuleList, verbose=False):\n    of_mod_list = flow.nn.ModuleList()\n    for original_submod in mod:\n        submod = torch2oflow(original_submod, verbose)\n        of_mod_list.append(submod)\n\n    return of_mod_list\n\n\n@torch2oflow.register\ndef _(mod: torch.nn.Sequential, verbose=False):\n    of_mod_list = []\n    for original_submod in mod:\n        submod = torch2oflow(original_submod, verbose)\n        of_mod_list.append(submod)\n\n    of_mod_seq = proxy_class(type(mod))(*of_mod_list)\n    return of_mod_seq\n\n\n@torch2oflow.register\ndef _(mod: torch.nn.parameter.Parameter, verbose=False) -> flow.nn.Parameter:\n    data = flow.utils.tensor.from_torch(mod.data)\n    if mod.data.dtype == torch.int8:\n        mod.requires_grad_(False)\n        return flow.nn.Parameter(data.to(flow.int8), requires_grad=False)\n    return flow.nn.Parameter(data, requires_grad=mod.requires_grad)\n\n\n@torch2oflow.register\ndef _(mod: torch.Tensor, verbose=False) -> flow.Tensor:\n    return flow.utils.tensor.from_torch(mod)\n\n\n_dtype_map = {\n    \"torch.float16\": flow.float16,\n    \"torch.float32\": flow.float32,\n    \"torch.double\": flow.double,\n    \"torch.int8\": flow.int8,\n    \"torch.int32\": flow.int32,\n    \"torch.int64\": flow.int64,\n    \"torch.uint8\": flow.uint8,\n}\n\n\n@torch2oflow.register\ndef _(mod: torch.dtype, verbose=False) -> flow.dtype:\n    return _dtype_map[str(mod)]\n\n\n@torch2oflow.register\ndef _(mod: list, verbose=False) -> list:\n    return [torch2oflow(m, verbose) for m in mod]\n\n\n@torch2oflow.register\ndef _(mod: tuple, verbose=False) -> tuple:\n    return tuple(torch2oflow(m, verbose) for m in mod)\n\n\n@torch2oflow.register\ndef _(mod: OrderedDict, verbose=False) -> OrderedDict:\n    if \"OrderedDict\" not in f\"{mod}\":\n        return default_converter(mod, verbose)\n    else:\n        return default_converter(mod, verbose, proxy_cls=OrderedDict)\n\n\n@torch2oflow.register\ndef _(mod: set, verbose=False) -> set:\n    return set(torch2oflow(m, verbose) for m in mod)\n\n\n@torch2oflow.register(int)\n@torch2oflow.register(float)\n@torch2oflow.register(str)\n@torch2oflow.register(bool)\ndef _(mod, verbose=False) -> Union[int, float, str, bool]:\n    return mod\n\n\n@torch2oflow.register\ndef _(mod: None, verbose=False):\n    return mod\n\n\n@torch2oflow.register\ndef _(mod: types.BuiltinFunctionType, verbose=False):\n    if hasattr(mod, \"__module__\"):\n        mod_name = None\n        if 
mod.__module__.startswith(\"torch._C._nn\"):\n            # The equivalent of mod inside torch._C._nn may be\n            # defined in flow.nn.functional\n            if getattr(flow.nn.functional, mod.__name__, None) is not None:\n                mod_name = \"oneflow.nn.functional\"\n            else:\n                mod_name = mod.__module__.replace(\n                    \"torch._C._nn\", \"oneflow._oneflow_internal._C\"\n                )\n        elif mod.__module__.startswith(\"torch\"):\n            try:\n                if getattr(torch.nn.functional, mod.__name__) == mod:\n                    mod_name = \"oneflow.nn.functional\"\n            except AttributeError:\n                mod_name = mod.__module__.replace(\"torch\", \"oneflow\")\n        if mod_name is not None:\n            m = importlib.import_module(mod_name)\n            return getattr(m, mod.__name__)\n\n    return default_converter(mod, verbose)\n\n\n@torch2oflow.register\ndef _(mod: torch.device, verbose=False):\n    index = mod.index if mod.index is not None else 0\n    return flow.device(mod.type, index)\n\n\n@torch2oflow.register\ndef _(mod: dict, verbose=False) -> dict:\n    return {torch2oflow(k): torch2oflow(v, verbose) for k, v in mod.items()}\n\n\n@torch2oflow.register\ndef _(func: types.FunctionType, verbose=False):\n    return transform_mgr.transform_func(func)\n\n\n@torch2oflow.register\ndef _(mod: partial, verbose=False):\n    # https://docs.python.org/3/library/functools.html?highlight=partial#functools.partial\n    func = torch2oflow(mod.func)\n    args = torch2oflow(mod.args)\n    keywords = torch2oflow(mod.keywords)\n    return partial(func, *args, **keywords)\n\n\n############################################## Code For Onefx ##############################################\n\n\ndef map_args(args, kwargs):\n    args = [torch2oflow(a) for a in args]\n    kwargs = dict((k, torch2oflow(v)) for (k, v) in kwargs.items())\n    return (args, kwargs)\n\n\ndef get_attr(gm, node, torch2flow={}):\n    attr = getattr(gm, node.target)\n    if attr in torch2flow:\n        return torch2flow[attr]\n    of_attr = torch2oflow(attr)\n    torch2flow[attr] = of_attr\n    return of_attr\n
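\n# --- Illustrative sketch (not executed): `torch2oflow` dispatches on the type of\n# its first argument via the registrations above:\n#\n#     import torch\n#     of_tensor = torch2oflow(torch.randn(2, 2))      # -> oneflow.Tensor\n#     of_linear = torch2oflow(torch.nn.Linear(4, 4))  # -> proxied oneflow Module (eval mode)\n#     of_dtype = torch2oflow(torch.float16)           # -> oneflow.float16\n"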
  },
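  {
    "path": "python/oneflow/framework/infer_compiler/examples/example_register_converter.py",
    "content": "\"\"\"Editor's illustrative sketch; this file and its path are hypothetical and\nnot part of the original sources. It shows the dispatch contract used by the\nconverters in builtin_transform.py: `torch2oflow.register` keys on the first\nparameter's type annotation, and `torch2oflow(obj)` picks the converter for\n`type(obj)`. `MyScale`/`OfScale` are stand-in classes defined here.\"\"\"\nimport oneflow as flow\nimport torch\n\nfrom oneflow.framework.infer_compiler.transform.builtin_transform import torch2oflow\n\n\nclass MyScale(torch.nn.Module):\n    def __init__(self, factor: float):\n        super().__init__()\n        self.factor = factor\n\n    def forward(self, x):\n        return x * self.factor\n\n\nclass OfScale(flow.nn.Module):\n    def __init__(self, factor: float):\n        super().__init__()\n        self.factor = factor\n\n    def forward(self, x):\n        return x * self.factor\n\n\n@torch2oflow.register\ndef _(mod: MyScale, verbose=False):\n    # The annotation on `mod` tells the dispatcher that this converter\n    # handles MyScale instances.\n    return OfScale(mod.factor)\n\n\nif __name__ == \"__main__\":\n    of_mod = torch2oflow(MyScale(2.0))\n    assert isinstance(of_mod, OfScale)\n"
  },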
  {
    "path": "python/oneflow/framework/infer_compiler/transform/custom_transform.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\"\"\"A module for registering custom torch2oflow functions and classes.\"\"\"\nimport inspect\nfrom pathlib import Path\nfrom typing import Callable, Dict, List, Optional, Union\n\nfrom oneflow.framework.infer_compiler.utils.log_utils import logger\n\nfrom .builtin_transform import torch2oflow\nfrom .manager import transform_mgr\n\n__all__ = [\"register\"]\n\n\ndef register_torch2oflow_class(cls: type, replacement: type, verbose=True):\n    try:\n        key = transform_mgr.get_transformed_entity_name(cls)\n        transform_mgr.update_class_proxies({key: replacement}, verbose=verbose)\n\n    except Exception as e:\n        logger.warning(f\"Cannot register {cls} {replacement}. {e}\")\n\n\ndef register_torch2oflow_func(func, first_param_type=None, verbose=False):\n    if first_param_type is None:\n        params = inspect.signature(func).parameters\n        first_param_type = params[list(params.keys())[0]].annotation\n        if first_param_type == inspect._empty:\n            logger.warning(f\"Cannot register {func} {first_param_type}.\")\n    try:\n        torch2oflow.register(first_param_type)(func)\n        logger.debug(f\"Register {func} {first_param_type}\")\n        if verbose:\n            logger.info(f\"Register {func} {first_param_type}\")\n    except Exception as e:\n        logger.warning(f\"Cannot register {func} {first_param_type}. {e}\")\n\n\ndef ensure_list(obj):\n    if isinstance(obj, list):\n        return obj\n    return [obj]\n\n\ndef register(\n    *,\n    package_names: Optional[List[Union[Path, str]]] = None,\n    torch2oflow_class_map: Optional[Dict[type, type]] = None,\n    torch2oflow_funcs: Optional[List[Callable]] = None,\n):\n    if package_names:\n        package_names = ensure_list(package_names)\n        transform_mgr.load_class_proxies_from_packages(package_names)\n\n    if torch2oflow_class_map:\n        for torch_cls, of_cls in torch2oflow_class_map.items():\n            register_torch2oflow_class(torch_cls, of_cls)\n\n    if torch2oflow_funcs:\n        torch2oflow_funcs = ensure_list(torch2oflow_funcs)\n        for func in torch2oflow_funcs:\n            register_torch2oflow_func(func)\n"
  },
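  {
    "path": "python/oneflow/framework/infer_compiler/examples/example_custom_register.py",
    "content": "\"\"\"Editor's illustrative sketch; this file and its path are hypothetical and\nnot part of the original sources. It exercises the public `register` entry\npoint from custom_transform.py with stand-in classes: the class map feeds\n`update_class_proxies`, and the converter function's first-parameter\nannotation tells `register_torch2oflow_func` which type it handles.\"\"\"\nimport oneflow as flow\nimport torch\n\nfrom oneflow.framework.infer_compiler.transform.custom_transform import register\n\n\nclass MyTorchBlock(torch.nn.Module):\n    pass\n\n\nclass MyFlowBlock(flow.nn.Module):\n    pass\n\n\ndef convert_my_block(mod: MyTorchBlock, verbose=False):\n    return MyFlowBlock()\n\n\nregister(\n    torch2oflow_class_map={MyTorchBlock: MyFlowBlock},\n    torch2oflow_funcs=[convert_my_block],\n)\n"
  },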
  {
    "path": "python/oneflow/framework/infer_compiler/transform/manager.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport logging\nimport os\nimport types\nfrom pathlib import Path\nfrom typing import Dict, List, Union\n\nfrom oneflow.framework.infer_compiler.import_tools.importer import LazyMocker\nfrom oneflow.framework.infer_compiler.utils.log_utils import logger\n\n__all__ = [\"transform_mgr\"]\n\n\nclass TransformManager:\n    \"\"\"TransformManager\n\n    __init__ args:\n        `debug_mode`: Whether to print debug info.\n        `tmp_dir`: The temp dir to store mock files.\n    \"\"\"\n\n    def __init__(self, debug_mode=False, tmp_dir=\"./output\"):\n        self.debug_mode = debug_mode\n        self._torch_to_oflow_cls_map = {}\n        self._setup_logger()\n        self.mocker = LazyMocker(prefix=\"\", suffix=\"\", tmp_dir=None)\n\n    def _setup_logger(self):\n        name = \"ONEDIFF\"\n        level = logging.DEBUG if self.debug_mode else logging.ERROR\n        logger.configure_logging(name=name, file_name=None, level=level, log_dir=None)\n        self.logger = logger\n\n    def get_mocked_packages(self):\n        return self.mocker.mocked_packages\n\n    def load_class_proxies_from_packages(self, package_names: List[Union[Path, str]]):\n        self.logger.debug(f\"Loading modules: {package_names}\")\n        for package_name in package_names:\n            self.mocker.mock_package(package_name)\n            self.logger.info(f\"Loaded Mock Torch Package: {package_name} successfully\")\n\n    def update_class_proxies(self, class_proxy_dict: Dict[str, type], verbose=True):\n        \"\"\"Update `_torch_to_oflow_cls_map` with `class_proxy_dict`.\n\n        example:\n            `class_proxy_dict = {\"mock_torch.nn.Conv2d\": flow.nn.Conv2d}`\n\n        \"\"\"\n        self._torch_to_oflow_cls_map.update(class_proxy_dict)\n\n        debug_message = f\"Updated class proxies: {len(class_proxy_dict)}\"\n        debug_message += f\"\\n{class_proxy_dict}\\n\"\n        self.logger.debug(debug_message)\n\n    def _transform_entity(self, entity):\n        result = self.mocker.mock_entity(entity)\n        if result is None:\n            RuntimeError(f\"Failed to transform entity: {entity}\")\n        return result\n\n    def get_transformed_entity_name(self, entity):\n        return self.mocker.get_mock_entity_name(entity)\n\n    def transform_cls(self, full_cls_name: str):\n        \"\"\"Transform a class name to a mock class .\"\"\"\n        mock_full_cls_name = self.get_transformed_entity_name(full_cls_name)\n\n        if mock_full_cls_name in self._torch_to_oflow_cls_map:\n            use_value = self._torch_to_oflow_cls_map[mock_full_cls_name]\n            return use_value\n\n        mock_cls = self._transform_entity(mock_full_cls_name)\n        self._torch_to_oflow_cls_map[mock_full_cls_name] = mock_cls\n        return mock_cls\n\n    def transform_func(self, func: types.FunctionType):\n        # TODO: support transform function cache\n        return 
self._transform_entity(func)\n\n    def transform_package(self, package_name):\n        return self._transform_entity(package_name)\n\n\ndebug_mode = os.getenv(\"ONEDIFF_DEBUG\", \"0\") == \"1\"\ntransform_mgr = TransformManager(debug_mode=debug_mode, tmp_dir=None)\n\ntry:\n    import pydantic\n\n    if pydantic.VERSION < \"2.5.2\":\n        logger.warning(\n            f\"Pydantic version {pydantic.VERSION} is too low, please upgrade to 2.5.2 or higher.\"\n        )\n        from oneflow.mock_torch.mock_utils import MockEnableDisableMixin\n\n        MockEnableDisableMixin.hazard_list.append(\n            \"huggingface_hub.inference._text_generation\"\n        )\n\nexcept ImportError:\n    pass\n"
  },
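  {
    "path": "python/oneflow/framework/infer_compiler/examples/example_transform_manager.py",
    "content": "\"\"\"Editor's illustrative sketch; this file and its path are hypothetical and\nnot part of the original sources. It shows the class-proxy cache that\n`transform_cls` consults first: once a mocked class name is registered via\n`update_class_proxies`, lookups resolve to the proxy without invoking the\nLazyMocker. The `mock_torch.nn.Conv2d` key follows the example given in the\n`update_class_proxies` docstring.\"\"\"\nimport oneflow as flow\n\nfrom oneflow.framework.infer_compiler.transform.manager import transform_mgr\n\n# Seed the cache with an explicit proxy mapping.\ntransform_mgr.update_class_proxies({\"mock_torch.nn.Conv2d\": flow.nn.Conv2d})\n\n# Later name-based transforms of the same mocked name hit this cache.\nassert (\n    transform_mgr._torch_to_oflow_cls_map[\"mock_torch.nn.Conv2d\"] is flow.nn.Conv2d\n)\n"
  },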
  {
    "path": "python/oneflow/framework/infer_compiler/utils/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom .oneflow_exec_mode import oneflow_exec_mode, oneflow_exec_mode_enabled\n"
  },
  {
    "path": "python/oneflow/framework/infer_compiler/utils/args_tree_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nimport torch\nfrom oneflow.framework.args_tree import ArgsTree\n\n\ndef input_output_processor(func):\n    def process_input(*args, **kwargs):\n        def input_fn(value):\n            if isinstance(value, torch.Tensor):\n                # TODO: https://github.com/siliconflow/sd-team/issues/109\n                return flow.utils.tensor.from_torch(value.contiguous())\n            else:\n                return value\n\n        args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor)\n        out = args_tree.map_leaf(input_fn)\n        mapped_args = out[0]\n        mapped_kwargs = out[1]\n        return mapped_args, mapped_kwargs\n\n    def process_output(output):\n        def output_fn(value):\n            if isinstance(value, flow.Tensor):\n                return flow.utils.tensor.to_torch(value)\n            else:\n                return value\n\n        out_tree = ArgsTree((output, None), False)\n        out = out_tree.map_leaf(output_fn)\n        return out[0]\n\n    def wrapper(cls, *args, **kwargs):\n        mapped_args, mapped_kwargs = process_input(*args, **kwargs)\n        output = func(cls, *mapped_args, **mapped_kwargs)\n        return process_output(output)\n\n    return wrapper\n"
  },
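  {
    "path": "python/oneflow/framework/infer_compiler/examples/example_args_tree_util.py",
    "content": "\"\"\"Editor's illustrative sketch; this file and its path are hypothetical and\nnot part of the original sources. `input_output_processor` from\nargs_tree_util.py converts torch tensors in the (possibly nested) inputs to\noneflow tensors and converts oneflow tensors in the output back to torch, so\nthe wrapped callable only ever sees oneflow tensors. Note the wrapper passes\nthe bound object through as the first argument, so it fits methods.\"\"\"\nimport oneflow as flow\nimport torch\n\nfrom oneflow.framework.infer_compiler.utils.args_tree_util import (\n    input_output_processor,\n)\n\n\nclass Runner:\n    @input_output_processor\n    def __call__(self, x):\n        # Inside the wrapper, `x` has already been converted to oneflow.\n        assert isinstance(x, flow.Tensor)\n        return x + 1\n\n\nout = Runner()(torch.zeros(2, 3))\nassert isinstance(out, torch.Tensor)  # converted back on the way out\n"
  },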
  {
    "path": "python/oneflow/framework/infer_compiler/utils/cost_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport inspect\nimport time\nfrom functools import wraps\n\nimport oneflow as flow\n\nfrom .log_utils import logger\n\n\nclass cost_cnt:\n    def __init__(self, debug=False, message=\"\\t\"):\n        self.debug = debug\n        self.message = message\n\n    def __enter__(self):\n        if not self.debug:\n            return\n        flow._oneflow_internal.eager.Sync()\n        before_used = flow._oneflow_internal.GetCUDAMemoryUsed()\n        before_host_used = flow._oneflow_internal.GetCPUMemoryUsed()\n        logger.debug(f\"====> {self.message} try to run...\")\n        logger.debug(f\"{self.message} cuda mem before {before_used} MB\")\n        logger.debug(f\"{self.message} host mem before {before_host_used} MB\")\n        self.before_used = before_used\n        self.before_host_used = before_host_used\n        self.start_time = time.time()\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        if not self.debug:\n            return\n        flow._oneflow_internal.eager.Sync()\n        end_time = time.time()\n        after_used = flow._oneflow_internal.GetCUDAMemoryUsed()\n        after_host_used = flow._oneflow_internal.GetCPUMemoryUsed()\n        logger.debug(f\"{self.message} run time {end_time - self.start_time} seconds\")\n        logger.debug(f\"{self.message} cuda mem after {after_used} MB\")\n        logger.debug(f\"{self.message} cuda mem diff {after_used - self.before_used} MB\")\n        logger.debug(f\"{self.message} host mem after {after_host_used} MB\")\n        logger.debug(\n            f\"{self.message} host mem diff {after_host_used - self.before_host_used} MB\"\n        )\n        logger.debug(f\"<==== {self.message} finish run.\")\n\n    def __call__(self, func):\n        @wraps(func)\n        def clocked(*args, **kwargs):\n            if not self.debug:\n                return func(*args, **kwargs)\n            module = inspect.getmodule(func)\n            logger.debug(\n                f\"==> function {module.__name__}.{func.__name__}  try to run...\"\n            )\n            flow._oneflow_internal.eager.Sync()\n\n            before_used = flow._oneflow_internal.GetCUDAMemoryUsed()\n            logger.debug(f\"{func.__name__} cuda mem before {before_used} MB\")\n\n            before_host_used = flow._oneflow_internal.GetCPUMemoryUsed()\n            logger.debug(f\"{func.__name__} host mem before {before_host_used} MB\")\n\n            start_time = time.time()\n            out = func(*args, **kwargs)\n            flow._oneflow_internal.eager.Sync()\n            end_time = time.time()\n\n            logger.debug(f\"{func.__name__} run time {end_time - start_time} seconds\")\n\n            after_used = flow._oneflow_internal.GetCUDAMemoryUsed()\n            logger.debug(f\"{func.__name__} cuda mem after {after_used} MB\")\n\n            logger.debug(f\"{func.__name__} cuda mem diff {after_used - before_used} MB\")\n            
after_host_used = flow._oneflow_internal.GetCPUMemoryUsed()\n            logger.debug(f\"{func.__name__} host mem after {after_host_used} MB\")\n            logger.debug(\n                f\"{func.__name__} host mem diff {after_host_used - before_host_used} MB\"\n            )\n\n            logger.debug(f\"<== function {func.__name__} finish run.\")\n            logger.debug(\"\")\n            return out\n\n        return clocked\n"
  },
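  {
    "path": "python/oneflow/framework/infer_compiler/examples/example_cost_cnt.py",
    "content": "\"\"\"Editor's illustrative sketch; this file and its path are hypothetical and\nnot part of the original sources. `cost_cnt` from cost_util.py works both as\na decorator and as a context manager; with debug=True it logs wall time and\nCUDA/host memory deltas (the memory probes assume a CUDA build of oneflow),\nand with debug=False it is a no-op. The DEBUG messages are only visible when\nthe shared logger is configured at DEBUG level, e.g. via ONEDIFF_DEBUG=1.\"\"\"\nimport oneflow as flow\n\nfrom oneflow.framework.infer_compiler.utils.cost_util import cost_cnt\n\n\n@cost_cnt(debug=True)\ndef run_matmul():\n    a = flow.randn(512, 512)\n    return flow.matmul(a, a)\n\n\nrun_matmul()\n\n# The same class guards a block when used as a context manager.\nwith cost_cnt(debug=True, message=\"matmul block\"):\n    b = flow.randn(256, 256)\n    flow.matmul(b, b)\n"
  },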
  {
    "path": "python/oneflow/framework/infer_compiler/utils/log_utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport logging\nimport os\nimport time\nfrom pathlib import Path\n\n\nclass ColorFormatter(logging.Formatter):\n    COLORS = {\n        \"DEBUG\": \"\\033[34m\",  # Blue\n        \"INFO\": \"\\033[92m\",  # green\n        \"WARNING\": \"\\033[93m\",  # Yellow\n        \"ERROR\": \"\\033[91m\",  # Red\n        \"CRITICAL\": \"\\033[91m\",  # Red\n    }\n\n    def format(self, record):\n        log_message = super().format(record)\n        color = self.COLORS.get(record.levelname, \"\\033[0m\")  # Default to Reset color\n        return f\"{color}{log_message}\\033[0m\"\n\n\nclass ConfigurableLogger:\n    def __init__(self) -> None:\n        self.logger = logging.getLogger(__name__)\n\n    def __getattr__(self, name):\n        return getattr(self.logger, name)\n\n    def configure_logging(self, name, level, log_dir=None, file_name=None):\n        logger = logging.getLogger(name)\n\n        if logger.hasHandlers():\n            logger.warning(\"Logging handlers already exist for %s\", name)\n            return\n\n        logger.setLevel(level)\n\n        # Create a console formatter and add it to a console handler\n        console_formatter = ColorFormatter(\n            fmt=\"%(levelname)s [%(asctime)s] - %(message)s\", datefmt=\"%Y-%m-%d %H:%M:%S\"\n        )\n        console_handler = logging.StreamHandler()\n        console_handler.setFormatter(console_formatter)\n        logger.addHandler(console_handler)\n\n        # Create a file formatter and add it to a file handler if log_dir is provided\n        if log_dir:\n            log_dir = Path(log_dir)\n            os.makedirs(log_dir, exist_ok=True)\n\n            file_prefix = \"{}_\".format(\n                time.strftime(\"%Y-%m-%d_%H-%M-%S\", time.localtime())\n            )\n\n            if file_name:\n                log_file_name = file_prefix + file_name\n            else:\n                log_file_name = file_prefix + name + \".log\"\n\n            log_file = log_dir / log_file_name\n            file_formatter = logging.Formatter(\n                fmt=\"%(levelname)s [%(asctime)s] - %(message)s\",\n                datefmt=\"%Y-%m-%d %H:%M:%S\",\n            )\n            file_handler = logging.FileHandler(log_file, encoding=\"utf-8\")\n            file_handler.setFormatter(file_formatter)\n            logger.addHandler(file_handler)\n\n        self.logger = logger\n\n\nlogger = ConfigurableLogger()\n"
  },
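  {
    "path": "python/oneflow/framework/infer_compiler/examples/example_logging.py",
    "content": "\"\"\"Editor's illustrative sketch; this file and its path are hypothetical and\nnot part of the original sources. The shared `logger` from log_utils.py is a\nConfigurableLogger: `configure_logging` attaches a colorized console handler\nand, when `log_dir` is given, a timestamped plain-text file handler named\n<timestamp>_<name>.log (or <timestamp>_<file_name> when file_name is set).\"\"\"\nimport logging\n\nfrom oneflow.framework.infer_compiler.utils.log_utils import logger\n\nlogger.configure_logging(\n    name=\"MY_APP\",\n    level=logging.DEBUG,\n    log_dir=\"./logs\",  # optional; omit to log to the console only\n    file_name=None,\n)\nlogger.info(\"hello from the configured logger\")\n"
  },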
  {
    "path": "python/oneflow/framework/infer_compiler/utils/oneflow_exec_mode.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\n\n_ONEFLOW_EXEC_MODE = False\n\n\nclass oneflow_exec_mode(object):\n    def __init__(self, enabled=None):\n        if enabled is not None:\n            self.enabled = enabled\n        else:\n            self.enabled = True\n\n    def __enter__(self):\n        global _ONEFLOW_EXEC_MODE\n        self.prev_mode = _ONEFLOW_EXEC_MODE\n        _ONEFLOW_EXEC_MODE = self.enabled\n        self.prev_grad_mode = flow.is_grad_enabled()\n        _ = flow.set_grad_enabled(False)\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        global _ONEFLOW_EXEC_MODE\n        _ONEFLOW_EXEC_MODE = self.prev_mode\n        _ = flow.set_grad_enabled(self.prev_grad_mode)\n\n\ndef oneflow_exec_mode_enabled():\n    global _ONEFLOW_EXEC_MODE\n    return _ONEFLOW_EXEC_MODE\n"
  },
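  {
    "path": "python/oneflow/framework/infer_compiler/examples/example_exec_mode.py",
    "content": "\"\"\"Editor's illustrative sketch; this file and its path are hypothetical and\nnot part of the original sources. Entering `oneflow_exec_mode` flips the\nmodule-level flag reported by `oneflow_exec_mode_enabled()` and disables\noneflow autograd; both are restored on exit.\"\"\"\nimport oneflow as flow\n\nfrom oneflow.framework.infer_compiler.utils.oneflow_exec_mode import (\n    oneflow_exec_mode,\n    oneflow_exec_mode_enabled,\n)\n\nassert not oneflow_exec_mode_enabled()\nwith oneflow_exec_mode():\n    # Code here runs in oneflow exec mode with gradients disabled.\n    assert oneflow_exec_mode_enabled()\n    assert not flow.is_grad_enabled()\nassert not oneflow_exec_mode_enabled()\n"
  },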
  {
    "path": "python/oneflow/framework/infer_compiler/utils/param_utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Any, Dict, List\n\nimport oneflow as flow\nimport torch\n\n\ndef parse_device(args: List[Any], kwargs: Dict[str, Any]):\n    if \"device\" in kwargs:\n        return kwargs[\"device\"]\n    for x in args:\n        if isinstance(x, (flow.device, torch.device)):\n            return x\n        if x in [\"cpu\", \"cuda\"]:\n            return x\n    return None\n\n\ndef check_device(current_device, target_device) -> bool:\n    def _convert(device):\n        assert isinstance(device, (str, torch.device, flow.device))\n        if isinstance(device, torch.device):\n            index = device.index if device.index is not None else 0\n            return flow.device(device.type, index)\n        if isinstance(device, str):\n            return flow.device(device)\n        return device\n\n    return _convert(current_device) == _convert(target_device)\n"
  },
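  {
    "path": "python/oneflow/framework/infer_compiler/examples/example_param_utils.py",
    "content": "\"\"\"Editor's illustrative sketch; this file and its path are hypothetical and\nnot part of the original sources. `parse_device` looks for a device in kwargs\nfirst and then in the positional args; `check_device` normalizes\nstr/torch.device/flow.device operands before comparing, so a torch device and\nthe equivalent oneflow device compare equal.\"\"\"\nimport oneflow as flow\nimport torch\n\nfrom oneflow.framework.infer_compiler.utils.param_utils import (\n    check_device,\n    parse_device,\n)\n\ndev = parse_device([torch.device(\"cuda\", 0)], {})\nprint(dev)  # device(type='cuda', index=0)\n\nassert check_device(torch.device(\"cuda:0\"), flow.device(\"cuda\", 0))\nassert not check_device(\"cpu\", \"cuda\")\n"
  },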
  {
    "path": "python/oneflow/framework/infer_compiler/utils/patch_for_compiler.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport math\n\nimport oneflow as flow\nimport oneflow.nn.functional as F\n\n\nclass FakeCuda:\n    @staticmethod\n    def current_device():\n        return \"cuda:0\"\n\n    @staticmethod\n    def mem_get_info(dev):\n        return 1024 * 1024 * 1024, 1024 * 1024 * 1024\n\n    @staticmethod\n    def _scaled_dot_product_attention_math(\n        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False\n    ):\n        d_k = query.size(-1)\n\n        if is_causal:\n            assert attn_mask is None, \"Cannot use both attn_mask and is_causal=True\"\n            L, S = query.size(-2), key.size(-2)\n            attn_mask = flow.ones((L, S), dtype=flow.bool).tril()\n\n        if attn_mask is not None:\n            if attn_mask.dtype == flow.bool:\n                new_attn_mask = flow.empty(\n                    attn_mask.shape, dtype=query.dtype, device=query.device\n                )\n                mask = flow.logical_not(attn_mask)\n                new_attn_mask.masked_fill_(mask, float(\"-inf\"))\n                attn_mask = new_attn_mask\n\n        scores = flow.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)\n\n        if attn_mask is not None:\n            scores.add_(attn_mask)\n\n        p_attn = F.softmax(scores, dim=-1)\n\n        if dropout_p > 0.0:\n            generator = flow.Generator()\n            p_attn = flow.nn.functional.dropout(\n                p_attn, p=dropout_p, generator=generator\n            )\n\n        return flow.matmul(p_attn, value)\n\n    @staticmethod\n    def scaled_dot_product_attention(\n        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False\n    ):\n        \"\"\"Scaled Dot-Product Attention\n        Args:\n        query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.\n        key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.\n        value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.\n        attn_mask (optional Tensor): Attention mask; shape :math:`(N, ..., L, S)`. Two types of masks are supported.\n            A boolean mask where a value of True indicates that the element *should* take part in attention.\n            A float mask of the same type as query, key, value that is added to the attention score.\n        dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied\n        is_causal (bool): If true, assumes causal attention masking and errors if both attn_mask and is_causal\n            are set.\n\n        Returns:\n            output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.\n\n        Shape legend:\n            - :math:`N: \\text{Batch size} ... 
: \\text{Any number of other batch dimensions (optional)}`\n            - :math:`S: \\text{Source sequence length}`\n            - :math:`L: \\text{Target sequence length}`\n            - :math:`E: \\text{Embedding dimension of the query and key}`\n            - :math:`Ev: \\text{Embedding dimension of the value}`\n        \"\"\"\n        if attn_mask is not None or dropout_p > 0.0:\n            return FakeCuda._scaled_dot_product_attention_math(\n                query, key, value, attn_mask, dropout_p, is_causal\n            )\n\n        batch_size, num_heads, target_seq_len, head_dim = query.shape\n        out = flow._C.fused_multi_head_attention_inference_v2(\n            query=query,\n            query_layout=\"BHMK\",\n            query_head_size=head_dim,\n            key=key,\n            key_layout=\"BHMK\",\n            value=value,\n            value_layout=\"BHMK\",\n            output_layout=\"BM(HK)\",\n            causal=is_causal,\n        )\n        # (N, L, H x Ev) -> (N, H, L, Ev)\n        value_embed_dim = value.shape[-1]\n        out = out.view(batch_size, target_seq_len, num_heads, value_embed_dim).permute(\n            0, 2, 1, 3\n        )\n        return out\n\n\nflow.cuda.current_device = FakeCuda.current_device\nflow.cuda.mem_get_info = FakeCuda.mem_get_info\nflow.nn.functional.scaled_dot_product_attention = FakeCuda.scaled_dot_product_attention\nF.scaled_dot_product_attention = FakeCuda.scaled_dot_product_attention\n"
  },
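  {
    "path": "python/oneflow/framework/infer_compiler/examples/example_sdpa_patch.py",
    "content": "\"\"\"Editor's illustrative sketch; this file and its path are hypothetical and\nnot part of the original sources, and it assumes a CUDA build of oneflow\nsince the fast path calls the fused multi-head attention kernel. Importing\npatch_for_compiler rebinds flow.nn.functional.scaled_dot_product_attention;\nwith no attn_mask and dropout_p == 0.0 the fused kernel is used, otherwise\nthe math fallback runs.\"\"\"\nimport oneflow as flow\nimport oneflow.framework.infer_compiler.utils.patch_for_compiler  # applies the patch\n\nN, H, L, S, E = 2, 8, 16, 16, 64\nq = flow.randn(N, H, L, E, device=\"cuda\", dtype=flow.float16)\nk = flow.randn(N, H, S, E, device=\"cuda\", dtype=flow.float16)\nv = flow.randn(N, H, S, E, device=\"cuda\", dtype=flow.float16)\n\nout = flow.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)\nprint(out.shape)  # (N, H, L, E)\n"
  },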
  {
    "path": "python/oneflow/framework/infer_compiler/utils/patch_for_diffusers.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# TODO: remove this file to diffusers/src/infer_compiler_registry/register_diffusers\nfrom abc import ABC, abstractmethod\n\nfrom .log_utils import logger\n\ntry:\n    import diffusers\n    from diffusers.models.attention_processor import Attention\nexcept ImportError:\n    diffusers = None\n    logger.warning(\"diffusers not found, some features will be disabled.\")\n\n_IS_DIFFUSERS_AVAILABLE = diffusers is not None\n\n\nclass InstanceChecker(ABC):\n    @abstractmethod\n    def is_attention_instance(self, instance):\n        pass\n\n\nclass DiffusersChecker(InstanceChecker):\n    def is_attention_instance(self, instance):\n        if not _IS_DIFFUSERS_AVAILABLE:\n            return False\n        return isinstance(instance, Attention)\n\n\ndiffusers_checker = DiffusersChecker()\n"
  },
  {
    "path": "python/oneflow/framework/infer_compiler/with_fx_graph.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\n\nimport oneflow as flow\nimport torch\nimport torch.fx as fx\nfrom torch.fx.node import map_aggregate\n\nfrom .transform import get_attr, torch2oflow\n\n\ndef fx_node_tranform(gm):\n    of_gm = to_of_transform(gm)\n\n    enable_graph = os.getenv(\"ONEDIFF_INFER_COMPILER_USE_GRAPH\", \"True\").lower() in (\n        \"true\",\n        \"1\",\n        \"t\",\n    )\n\n    if not enable_graph:\n        oneflow_fn = of_gm.forward\n    else:\n        # Align this with env setting in `with_oneflow_compile`.\n        # Otherwise, infererence using PyTorch with OneFlow backend on\n        # multiple input shapes may crash\n        os.environ.setdefault(\"ONEFLOW_RUN_GRAPH_BY_VM\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_GRAPH_DELAY_VARIABLE_OP_EXECUTION\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MLIR_CSE\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MLIR_FUSE_FORWARD_OPS\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MLIR_GROUP_MATMUL\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MLIR_PREFER_NHWC\", \"0\")\n        os.environ.setdefault(\"ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR\", \"1\")\n        os.environ.setdefault(\n            \"ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP\", \"1\"\n        )\n        os.environ.setdefault(\n            \"ONEFLOW_KERNEL_GEMM_CUTLASS_IMPL_ENABLE_TUNING_WARMUP\", \"1\"\n        )\n        os.environ.setdefault(\"ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_KERNEL_GEMM_ENABLE_CUTLASS_IMPL\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MLIR_GROUP_MATMUL_QUANT\", \"1\")\n\n        class OfGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.fx_md = of_gm\n                self.config.enable_cudnn_conv_heuristic_search_algo(False)\n                self.config.allow_fuse_add_to_output(True)\n\n            def build(self, *args, **kwargs):\n                return self.fx_md(*args, **kwargs)\n\n        of_g = OfGraph()\n        oneflow_fn = lambda *args, **kwargs: of_g(*args, **kwargs)\n\n    return oneflow_fn\n\n\ndef to_of_transform(\n    gm: torch.fx.GraphModule, tracer_class: type = fx.Tracer\n) -> torch.fx.GraphModule:\n    name2node = {}\n    name2obj = {}\n    torch2flow = {}\n    of_g = flow.fx.Graph()\n    modules = 
dict(gm.named_modules())\n    for node in gm.graph.nodes:\n        if node.op == \"placeholder\":\n            of_node = of_g.create_node(\"placeholder\", node.target)\n            name2node[node.name] = of_node\n        elif node.op == \"output\":\n            of_node = of_g.output(node_replace_args(node.args, name2node)[0])\n            name2node[node.name] = of_node\n        elif node.op == \"call_function\":\n            of_node = of_g.create_node(\n                \"call_function\",\n                torch2oflow(node.target),\n                args=node_replace_args(node.args, name2node),\n                kwargs=node_replace_args(node.kwargs, name2node),\n            )\n            name2node[node.name] = of_node\n        elif node.op == \"call_method\":\n            of_node = of_g.create_node(\n                \"call_method\",\n                node.target,\n                args=node_replace_args(node.args, name2node),\n                kwargs=node_replace_args(node.kwargs, name2node),\n            )\n            name2node[node.name] = of_node\n        elif node.op == \"call_module\":\n            torch_md = modules[node.target]\n            name2obj[node.target] = torch2oflow(torch_md)\n\n            of_node = of_g.create_node(\n                \"call_module\",\n                node.target,\n                args=node_replace_args(node.args, name2node),\n                kwargs=node_replace_args(node.kwargs, name2node),\n            )\n            name2node[node.name] = of_node\n        elif node.op == \"get_attr\":\n            of_node = of_g.create_node(\"get_attr\", node.target)\n            name2node[node.name] = of_node\n            name2obj[node.target] = get_attr(gm, node, torch2flow)\n        else:\n            raise ValueError(f\"not valid node type{node.foramt_node()}\")\n\n    of_gm = flow.fx.GraphModule(name2obj, of_g)\n    of_gm.training = False\n    of_gm.graph.lint()\n    of_gm.recompile()\n    return of_gm\n\n\ndef replace_node(node, name2node):\n    if isinstance(node, torch.fx.Node):\n        return name2node[node.name]\n    else:\n        return torch2oflow(node)\n\n\ndef node_replace_args(args, name2node):\n    return map_aggregate(args, lambda node: replace_node(node, name2node))\n"
  },
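  {
    "path": "python/oneflow/framework/infer_compiler/examples/example_fx_graph.py",
    "content": "\"\"\"Editor's illustrative sketch; this file and its path are hypothetical and\nnot part of the original sources. `fx_node_tranform` takes a traced torch.fx\nGraphModule, rebuilds it node by node as an oneflow GraphModule via\n`to_of_transform`, and returns a callable (wrapped in an oneflow nn.Graph\nunless ONEDIFF_INFER_COMPILER_USE_GRAPH disables that). The net traced here\nuses only call_module nodes, which the builtin converters cover; the\nreturned callable expects oneflow tensors.\"\"\"\nimport oneflow as flow\nimport torch\nimport torch.fx as fx\n\nfrom oneflow.framework.infer_compiler.with_fx_graph import fx_node_tranform\n\n\nclass SmallNet(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linear = torch.nn.Linear(4, 4)\n        self.act = torch.nn.ReLU()\n\n    def forward(self, x):\n        return self.act(self.linear(x))\n\n\ngm = fx.symbolic_trace(SmallNet())\noneflow_fn = fx_node_tranform(gm)\n\ny = oneflow_fn(flow.randn(2, 4))\n"
  },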
  {
    "path": "python/oneflow/framework/infer_compiler/with_fx_interpreter.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Any, Dict, Tuple\n\nimport torch\nfrom oneflow.framework.infer_compiler.transform.builtin_transform import torch2oflow\n\nfrom .transform import ProxySubmodule, map_args\n\n\nclass OneFlowInterpreter(torch.fx.Interpreter):\n    from torch.fx.node import Argument, Target\n\n    def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any:\n        args, kwargs = map_args(args, kwargs)\n        target = torch2oflow(target)\n        return super().call_function(target, args, kwargs)\n\n    def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any:\n        args, kwargs = map_args(args, kwargs)\n        return super().call_method(target, args, kwargs)\n\n    def call_module(\n        self, target: \"Target\", args: Tuple[Argument, ...], kwargs: Dict[str, Any]\n    ) -> Any:\n        submod = self.fetch_attr(target)\n        submod = ProxySubmodule(submod)\n        return submod(*args, **kwargs)\n"
  },
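  {
    "path": "python/oneflow/framework/infer_compiler/examples/example_fx_interpreter.py",
    "content": "\"\"\"Editor's illustrative sketch; this file and its path are hypothetical and\nnot part of the original sources. OneFlowInterpreter re-executes a traced\ntorch.fx graph node by node, converting call targets and arguments to their\noneflow counterparts on the fly. The supported way to select it is the\nONEDIFF_INFER_COMPILER_USE_INTERPRETER switch read by oneflow_backend, set\nhere before compiling.\"\"\"\nimport os\n\nos.environ[\"ONEDIFF_INFER_COMPILER_USE_INTERPRETER\"] = \"1\"\n\nimport torch\n\nfrom oneflow.framework.infer_compiler.with_oneflow_backend import oneflow_backend\n\nmodel = torch.nn.Linear(4, 4)\ncompiled = torch.compile(model, backend=oneflow_backend)\n\n# Interpreted execution: no oneflow nn.Graph is built; each fx node runs\n# eagerly through the converted oneflow ops.\ny = compiled(torch.randn(2, 4))\n"
  },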
  {
    "path": "python/oneflow/framework/infer_compiler/with_oneflow_backend.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport torch\n\nimport oneflow as flow\nfrom oneflow.framework.args_tree import ArgsTree\nfrom .with_fx_graph import fx_node_tranform\nfrom .with_fx_interpreter import OneFlowInterpreter\n\n\ndef oneflow_backend(gm, example_inputs, *args, **kwargs):\n    with_interp = os.getenv(\n        \"ONEDIFF_INFER_COMPILER_USE_INTERPRETER\", \"False\"\n    ).lower() in (\"true\", \"1\", \"t\",)\n    if not with_interp:\n        transformed_fn = fx_node_tranform(gm)\n\n    def wrapped_forward(*args, **kwargs):\n        def input_fn(value):\n            if isinstance(value, torch.Tensor):\n                return flow.utils.tensor.from_torch(value.contiguous())\n            else:\n                return value\n\n        args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor)\n        out = args_tree.map_leaf(input_fn)\n        args = out[0]\n        if with_interp:\n            output = OneFlowInterpreter(gm, garbage_collect_values=False).run(\n                *args, **kwargs\n            )\n        else:\n            output = transformed_fn(*args, **kwargs)\n        if isinstance(output, tuple):\n            return tuple(flow.utils.tensor.to_torch(i) for i in output)\n        return flow.utils.tensor.to_torch(output)\n\n    return wrapped_forward\n"
  },
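  {
    "path": "python/oneflow/framework/infer_compiler/examples/example_oneflow_backend.py",
    "content": "\"\"\"Editor's illustrative sketch; this file and its path are hypothetical and\nnot part of the original sources. `oneflow_backend` is usable as a\ntorch.compile / dynamo backend: dynamo hands it the captured fx GraphModule,\nand the wrapper it returns converts torch tensor inputs to oneflow, runs the\ntransformed module, and converts the outputs back to torch.\"\"\"\nimport torch\n\nfrom oneflow.framework.infer_compiler.with_oneflow_backend import oneflow_backend\n\nmodel = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())\ncompiled = torch.compile(model, backend=oneflow_backend)\n\nx = torch.randn(1, 8)\ny = compiled(x)  # torch in, torch out; oneflow execution underneath\n"
  },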
  {
    "path": "python/oneflow/framework/infer_compiler/with_oneflow_compile.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport types\nfrom functools import wraps\nfrom itertools import chain\nfrom typing import Any\n\nimport oneflow as flow\nimport torch\nfrom oneflow.utils.tensor import to_torch\n\nfrom .transform.builtin_transform import torch2oflow\nfrom .transform.manager import transform_mgr\nfrom .utils.args_tree_util import input_output_processor\nfrom .utils.cost_util import cost_cnt\nfrom .utils.log_utils import logger\nfrom .utils.oneflow_exec_mode import oneflow_exec_mode, oneflow_exec_mode_enabled\nfrom .utils.param_utils import check_device, parse_device\n\n\nclass DualModule(torch.nn.Module):\n    def __init__(self, torch_module, oneflow_module):\n        torch.nn.Module.__init__(self)\n        self._torch_module = torch_module\n        self._oneflow_module = oneflow_module\n\n    @property\n    def oneflow_module(self):\n        if self._oneflow_module is not None:\n            return self._oneflow_module\n\n        logger.debug(f\"Convert {type(self._torch_module)} ...\")\n        self._oneflow_module = torch2oflow(self._torch_module)\n        logger.debug(f\"Convert {id(self._torch_module)} done!\")\n        return self._oneflow_module\n\n    @oneflow_module.deleter\n    def oneflow_module(self):\n        if self._oneflow_module:\n            del self._oneflow_module\n            setattr(self, \"_oneflow_module\", None)\n\n    def to(self, *args, **kwargs):\n        if oneflow_exec_mode_enabled():\n            self._oneflow_module.to(*args, **kwargs)\n        else:\n            if self._oneflow_module is not None:\n                args = [torch2oflow(v) for v in args]\n                kwargs = {k: torch2oflow(v) for k, v in kwargs.items()}\n                self._oneflow_module.to(*args, **kwargs)\n                self._torch_module_to_with_check(*args, **kwargs)\n            else:\n                self._torch_module.to(*args, **kwargs)\n\n    def _torch_module_to_with_check(self, *args, **kwargs):\n        def _align_tensor(torch_module, oneflow_module):\n            oneflow_tensor_list = set(\n                [x for x, _ in oneflow_module.named_parameters()]\n                + [x for x, _ in oneflow_module.named_buffers()]\n            )\n            for name, tensor in chain.from_iterable(\n                [torch_module.named_parameters(), torch_module.named_buffers(),]\n            ):\n                if name not in oneflow_tensor_list:\n                    tensor.data = tensor.to(*args, **kwargs)\n                else:\n                    oneflow_tensor = oneflow_module.get_parameter(name)\n                    if oneflow_tensor is None:\n                        tensor.data = tensor.to(*args, **kwargs)\n                    elif tensor.data_ptr() != oneflow_tensor.data_ptr():\n                        tensor.data = to_torch(oneflow_tensor.data)\n\n        oneflow_module_list = set([x for x, _ in self._oneflow_module.named_modules()])\n        for name, module 
in self._torch_module.named_modules():\n            if name not in oneflow_module_list:\n                module.to(*args, **kwargs)\n            else:\n                _align_tensor(module, self._oneflow_module.get_submodule(name))\n\n    def __getattr__(self, name):\n        if name == \"_torch_module\":\n            return self._modules[name]\n        if name == \"_oneflow_module\":\n            return super().__getattribute__(name)\n\n        torch_attr = getattr(self._torch_module, name)\n        oneflow_attr = (\n            None\n            if self._oneflow_module is None\n            else getattr(self._oneflow_module, name)\n        )\n        if isinstance(torch_attr, torch.nn.ModuleList):\n            oneflow_attr = (\n                [None] * len(torch_attr) if oneflow_attr is None else oneflow_attr\n            )\n            return DualModuleList(torch_attr, oneflow_attr)\n        elif isinstance(torch_attr, torch.nn.Module):\n            return get_mixed_dual_module(torch_attr.__class__)(torch_attr, oneflow_attr)\n        else:\n            return oneflow_attr if oneflow_exec_mode_enabled() else torch_attr\n\n    def __setattr__(self, name: str, value: Any) -> None:\n        if name in [\"_torch_module\", \"_oneflow_module\"]:\n            super().__setattr__(name, value)\n        else:  # TODO: avoid memory growth when setting attrs\n            try:\n                setattr(self._torch_module, name, value)\n                value = torch2oflow(value)\n                if isinstance(value, flow.Tensor):\n                    obj = getattr(self._oneflow_module, name)\n                    obj.copy_(value)\n                else:\n                    setattr(self._oneflow_module, name, value)\n            except Exception:\n                super().__setattr__(name, value)\n\n\nclass DualModuleList(torch.nn.ModuleList):\n    def __init__(self, torch_modules, oneflow_modules):\n        super().__init__()\n        assert len(torch_modules) == len(oneflow_modules)\n        self._torch_modules = torch_modules\n        self._oneflow_modules = oneflow_modules\n        dual_modules = []\n        for torch_module, oneflow_module in zip(\n            self._torch_modules, self._oneflow_modules\n        ):\n            dual_modules.append(\n                get_mixed_dual_module(torch_module.__class__)(\n                    torch_module, oneflow_module\n                )\n            )\n        # clear self._modules since `self._torch_modules = torch_modules` will append a module to self._modules\n        self._modules.clear()\n        self += dual_modules\n\n    def __setitem__(self, idx: int, module: DualModule):\n        idx = self._get_abs_string_index(idx)\n        setattr(self._torch_modules, str(idx), module._torch_module)\n        setattr(self._oneflow_modules, str(idx), module._oneflow_module)\n        return setattr(self, str(idx), module)\n\n    def __setattr__(self, name, value):\n        if name in (\"_torch_modules\", \"_oneflow_modules\"):\n            return object.__setattr__(self, name, value)\n        try:\n            if isinstance(value, DualModule):\n                setattr(self._torch_modules, name, value._torch_module)\n                setattr(self._oneflow_modules, name, value._oneflow_module)\n            else:\n                setattr(self._torch_modules, name, value)\n                value = torch2oflow(value)\n                setattr(self._oneflow_modules, name, value)\n        except Exception:\n            super().__setattr__(name, value)\n\n\ndef get_mixed_dual_module(module_cls):\n    class MixedDualModule(DualModule, module_cls):\n        def __init__(self, torch_module, oneflow_module):\n            DualModule.__init__(self, torch_module, oneflow_module)\n\n    return MixedDualModule\n\n\ndef graph_file_management(func):\n    @wraps(func)\n    def wrapper(self: \"DeployableModule\", *args, **kwargs):\n        graph_file = self._deployable_module_options.get(\"graph_file\", None)\n\n        # Load graph file\n        if graph_file is not None:\n            try:\n                if not os.path.exists(graph_file):\n                    logger.warning(\n                        f\"Graph file {graph_file} does not exist; a new graph will be generated.\"\n                    )\n\n                else:\n                    graph_device = self._deployable_module_options.get(\n                        \"graph_file_device\", None\n                    )\n\n                    self.load_graph(graph_file, torch2oflow(graph_device))\n                    logger.info(f\"Loaded graph file: {graph_file}\")\n\n                    graph_file = None\n                    self._deployable_module_options[\"graph_file\"] = None\n\n            except Exception as e:\n                logger.error(f\"Loading graph file {graph_file} failed: {e}\")\n\n        ret = func(self, *args, **kwargs)\n\n        # Save graph file\n        if graph_file is not None:\n            try:\n                # `dirname` may be empty when graph_file is a bare file name.\n                os.makedirs(os.path.dirname(graph_file) or \".\", exist_ok=True)\n                self.save_graph(graph_file)\n                logger.info(f\"Saved graph file: {graph_file}\")\n            except Exception as e:\n                logger.error(f\"Saving graph file {graph_file} failed: {e}\")\n            finally:\n                self._deployable_module_options[\"graph_file\"] = None\n\n        return ret\n\n    return wrapper\n\n\ndef handle_deployable_exception(func):\n    @wraps(func)\n    def wrapper(self, *args, **kwargs):\n        if transform_mgr.debug_mode:\n            return func(self, *args, **kwargs)\n        else:\n            try:\n                return func(self, *args, **kwargs)\n            except Exception as e:\n                logger.error(f\"Exception in {func.__name__}: {e}\")\n                logger.warning(\"Recompile oneflow module ...\")\n                del self._deployable_module_model.oneflow_module\n                self._deployable_module_dpl_graph = None\n                return func(self, *args, **kwargs)\n\n    return wrapper\n\n\nclass DeployableModule(torch.nn.Module):\n    def __init__(\n        self,\n        torch_module,\n        oneflow_module,\n        use_graph=True,\n        options={},\n        graph_path=None,\n        graph_device=None,\n    ):\n        torch.nn.Module.__init__(self)\n        self._deployable_module_model = get_mixed_dual_module(torch_module.__class__)(\n            torch_module, oneflow_module\n        )\n        self._deployable_module_use_graph = use_graph\n        self._deployable_module_options = options\n        self._deployable_module_dpl_graph = None\n        self._is_raw_deployable_module = True\n\n    @classmethod\n    def from_existing(cls, existing_module, use_graph=None, options=None):\n        torch_module = existing_module._deployable_module_model._torch_module\n        oneflow_module = existing_module._deployable_module_model._oneflow_module\n        instance = cls(torch_module, oneflow_module, use_graph, options)\n        instance._deployable_module_dpl_graph = (\n            
existing_module._deployable_module_dpl_graph if use_graph else None\n        )\n        return instance\n\n    def get_graph(self):\n        if self._deployable_module_dpl_graph is not None:\n            return self._deployable_module_dpl_graph\n        if \"size\" in self._deployable_module_options:\n            size = self._deployable_module_options[\"size\"]\n        else:\n            size = 9\n        if \"dynamic\" in self._deployable_module_options:\n            dynamic = self._deployable_module_options[\"dynamic\"]\n        else:\n            dynamic = True\n        self._deployable_module_dpl_graph = get_oneflow_graph(\n            self._deployable_module_model.oneflow_module, size, dynamic\n        )\n        if \"debug\" in self._deployable_module_options:\n            self._deployable_module_dpl_graph.debug(\n                self._deployable_module_options[\"debug\"]\n            )\n        return self._deployable_module_dpl_graph\n\n    @input_output_processor\n    @handle_deployable_exception\n    @graph_file_management\n    def apply_model(self, *args, **kwargs):\n        if self._deployable_module_use_graph:\n            dpl_graph = self.get_graph()\n            with oneflow_exec_mode():\n                output = dpl_graph(*args, **kwargs)\n        else:\n            with oneflow_exec_mode():\n                output = self._deployable_module_model.oneflow_module.apply_model(\n                    *args, **kwargs\n                )\n        return output\n\n    @input_output_processor\n    @handle_deployable_exception\n    @graph_file_management\n    def __call__(self, *args, **kwargs):\n        if self._deployable_module_use_graph:\n            dpl_graph = self.get_graph()\n            with oneflow_exec_mode():\n                output = dpl_graph(*args, **kwargs)\n        else:\n            with oneflow_exec_mode():\n                output = self._deployable_module_model.oneflow_module(*args, **kwargs)\n        return output\n\n    def to(self, *args, **kwargs):\n        if self._deployable_module_dpl_graph is None:\n            self._deployable_module_model.to(*args, **kwargs)\n            return self\n\n        # assert the target device is same as graph device\n        target_device = parse_device(args, kwargs)\n        if (\n            target_device is not None\n            and len(self._deployable_module_dpl_graph._blocks) > 0\n        ):\n            current_device = next(self._deployable_module_dpl_graph._state()).device\n            if not check_device(current_device, target_device):\n                raise RuntimeError(\n                    f\"After graph built, the device of graph can't be modified, current device: {current_device}, target device: {target_device}\"\n                )\n        self._deployable_module_model.to(*args, **kwargs)\n        return self\n\n    # TODO(): Just for transformers VAE decoder\n    @input_output_processor\n    @handle_deployable_exception\n    @graph_file_management\n    def decode(self, *args, **kwargs):\n        if self._deployable_module_use_graph:\n\n            def _build(graph, *args, **kwargs):\n                return graph.model.decode(*args, **kwargs)\n\n            dpl_graph = self.get_graph()\n            dpl_graph.build = types.MethodType(_build, dpl_graph)\n            with oneflow_exec_mode():\n                output = dpl_graph(*args, **kwargs)\n        else:\n            with oneflow_exec_mode():\n                output = self._deployable_module_model.oneflow_module.decode(\n                    *args, **kwargs\n   
             )\n        return output\n\n    def __getattr__(self, name):\n        if name in self._modules:\n            return self._modules[name]\n        return getattr(self._deployable_module_model, name)\n\n    def load_graph(self, file_path, device=None, run_warmup=True):\n        self.get_graph().warmup_with_load(file_path, device, run_warmup)\n\n    def warmup_with_load(self, file_path, device=None, run_warmup=True):\n        self.get_graph().warmup_with_load(file_path, device, run_warmup)\n\n    def save_graph(self, file_path):\n        self.get_graph().save_graph(file_path)\n\n\nclass OneflowGraph(flow.nn.Graph):\n    @flow.nn.Graph.with_dynamic_input_shape()\n    def __init__(self, model):\n        super().__init__(enable_get_runtime_state_dict=True)\n        self.model = model\n        self.config.enable_cudnn_conv_heuristic_search_algo(False)\n        self.config.allow_fuse_add_to_output(True)\n\n        os.environ.setdefault(\"ONEFLOW_GRAPH_DELAY_VARIABLE_OP_EXECUTION\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MLIR_CSE\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MLIR_FUSE_FORWARD_OPS\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MLIR_GROUP_MATMUL\", \"1\")\n        # TODO(lml): enable ONEFLOW_MLIR_PREFER_NHWC when related bug fix.\n        os.environ.setdefault(\"ONEFLOW_MLIR_PREFER_NHWC\", \"0\")\n        os.environ.setdefault(\"ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR\", \"1\")\n        os.environ.setdefault(\n            \"ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP\", \"1\"\n        )\n        os.environ.setdefault(\n            \"ONEFLOW_KERNEL_GEMM_CUTLASS_IMPL_ENABLE_TUNING_WARMUP\", \"1\"\n        )\n        os.environ.setdefault(\"ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_KERNEL_GEMM_ENABLE_CUTLASS_IMPL\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT\", \"1\")\n        os.environ.setdefault(\"ONEFLOW_MLIR_GROUP_MATMUL_QUANT\", \"1\")\n\n    def build(self, *args, **kwargs):\n        return self.model(*args, **kwargs)\n\n    @cost_cnt(transform_mgr.debug_mode)\n    def warmup_with_load(self, file_path, device=None, run_warmup=True):\n        state_dict = flow.load(file_path)\n        if device is not None:\n            state_dict = flow.nn.Graph.runtime_state_dict_to(state_dict, device)\n        self.load_runtime_state_dict(state_dict, warmup_with_run=run_warmup)\n\n    @cost_cnt(transform_mgr.debug_mode)\n    def save_graph(self, file_path):\n        state_dict = self.runtime_state_dict()\n        flow.save(state_dict, file_path)\n\n\ndef get_oneflow_graph(model, size=9, dynamic=True):\n    g = OneflowGraph(model)\n    g._dynamic_input_graph_cache.set_cache_size(size)\n    g._dynamic_input_graph_cache.enable_shared(dynamic)\n    return g\n\n\ndef state_dict_hook(module, state_dict, prefix, local_metadata):\n    pytorch_key_prefix = \"_deployable_module_model._torch_module.\"\n    new_state_dict = type(state_dict)()\n    for k, v in state_dict.items():\n        # 
_deployable_module_model._torch_module.out.2.weight => out.2.weight\n        if k.startswith(pytorch_key_prefix):\n            new_k = k[len(pytorch_key_prefix) :]\n            new_state_dict[new_k] = v\n        else:\n            new_state_dict[k] = v\n    return new_state_dict\n\n\n# Return a DeployableModule that uses module_cls as its parent class.\ndef get_mixed_deployable_module(module_cls):\n    class MixedDeployableModule(DeployableModule, module_cls):\n        def __init__(\n            self,\n            torch_module,\n            oneflow_module,\n            use_graph=True,\n            options={},\n            graph_path=None,\n            graph_device=None,\n        ):\n            DeployableModule.__init__(\n                self,\n                torch_module,\n                oneflow_module,\n                use_graph,\n                options,\n                graph_path,\n                graph_device,\n            )\n            self._is_raw_deployable_module = False\n\n        @classmethod\n        def from_existing(cls, existing_module, use_graph=None, options=None):\n            torch_module = existing_module._deployable_module_model._torch_module\n            oneflow_module = existing_module._deployable_module_model._oneflow_module\n            instance = cls(torch_module, oneflow_module, use_graph, options)\n            instance._deployable_module_dpl_graph = (\n                existing_module._deployable_module_dpl_graph if use_graph else None\n            )\n            return instance\n\n    return MixedDeployableModule\n\n\ndef compile_from_torch(\n    torch_module: torch.nn.Module, *, use_graph=True, options={},\n):\n    \"\"\"\n    Converts a torch module to an oneflow module.\n\n    Note:\n        Mappings from torch to oneflow should be registered via `infer_compiler.register(torch2oflow_class_map={TorchModule: OneflowModule})` before `compile_from_torch` is called.\n\n    Args:\n        torch_module (torch.nn.Module): Torch module to be compiled.\n        use_graph (bool, optional): If `True`, the graph of the compiled module can be saved and loaded to speed up the compile process. Defaults to `True`.\n        options (dict, optional):\n            size (int, optional): graph cache size. Defaults to `9`.\n            dynamic (bool, optional): If `True`, the graph of the compiled module can be shared with other modules. Defaults to `True`.\n            debug (int, optional): debug level. Defaults to `-1`.\n\n    Returns:\n        DeployableModule: Compiled oneflow module.\n    \"\"\"\n\n    def wrap_module(module):\n        if isinstance(module, DeployableModule):\n            assert not module._is_raw_deployable_module\n            return module.__class__.from_existing(module, use_graph, options)\n        else:\n            return get_mixed_deployable_module(module.__class__)(\n                module, None, use_graph, options\n            )\n\n    model = wrap_module(torch_module)\n    assert isinstance(model, DeployableModule)\n    assert isinstance(model, torch_module.__class__)\n    model._register_state_dict_hook(state_dict_hook)\n\n    return model\n"
  },
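The `compile_from_torch` entry point above wraps a `torch.nn.Module` into a `DeployableModule` whose compiled graph can be cached, saved, and reloaded via `save_graph`/`load_graph`. Below is a minimal, hypothetical usage sketch: the network, the option values, and the file name are invented, and the `infer_compiler.register` call is only indicated per the docstring.

```python
import torch

# Hypothetical torch network; shapes are arbitrary.
class MyTorchNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.proj(x)

# Per the docstring, a torch -> oneflow class mapping must be registered
# before compiling (MyOneflowNet is a hypothetical oneflow-side class):
# infer_compiler.register(torch2oflow_class_map={MyTorchNet: MyOneflowNet})

model = compile_from_torch(MyTorchNet(), use_graph=True, options={"size": 9})
out = model(torch.randn(2, 16))   # first call builds and caches the graph
model.save_graph("my_net.graph")  # persist the compiled runtime state dict
model.load_graph("my_net.graph")  # later runs can warm up from the file
```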
  {
    "path": "python/oneflow/framework/job_set_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Optional, TypeVar\n\nfrom oneflow.core.job.job_set_pb2 import JobSet\n\n_VT = TypeVar(\"_VT\")\n\n\n_default_job_set = JobSet()\n"
  },
  {
    "path": "python/oneflow/framework/model.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n__all__ = [\n    \"DataModule\",\n    \"NumpyDataModule\",\n    \"TrainingConfig\",\n    \"ValidationConfig\",\n    \"CheckpointConfig\",\n    \"Callback\",\n    \"Model\",\n]\nimport inspect\nfrom abc import ABC\nfrom typing import Any, List, Optional, Tuple, Union\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow._oneflow_internal\nimport oneflow.framework.dtype as dtype_util\nfrom oneflow.framework.function_util import FunctionConfig as ExecutionConfig\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.nn.modules.module import Module\nfrom oneflow.optim.optimizer import Optimizer as OOPOptimizer\n\n\nclass DataModule(Module):\n    def __init__(self, *args, **kwargs):\n        super().__init__()\n\n    def forward(self, step_idx: int = 0, optimizer_idx: int = 0):\n        pass\n\n    def infer_oneflow_data_placeholder(\n        self, batch: Tuple[Any] = None, optimizer_idx: int = 0\n    ):\n        return None\n\n\nclass NumpyDataModule(DataModule):\n    def __init__(self, *args, **kwargs):\n        super().__init__()\n\n    def forward(self, step_idx: int = 0, optimizer_idx: int = 0):\n        pass\n\n    def __call__(self, *args):\n        ret = self.forward(*args)\n        return ret\n\n    def infer_oneflow_data_placeholder(\n        self, batch: Tuple[np.ndarray, ...] 
= None, optimizer_idx: int = 0\n    ):\n        assert isinstance(batch, tuple), \"model.NumpyDataModule must return a tuple.\"\n        data_placeholder_list = []\n        for item in batch:\n            assert isinstance(\n                item, np.ndarray\n            ), \"model.NumpyDataModule must return a tuple of numpy arrays.\"\n            of_dtype = dtype_util.convert_numpy_dtype_to_oneflow_dtype(item.dtype)\n            # oneflow_typing refers to the deprecated lazy-mode oneflow.typing API\n            numpy_placeholder = oneflow_typing.Numpy.Placeholder(\n                shape=item.shape, dtype=of_dtype\n            )\n            data_placeholder_list.append(numpy_placeholder)\n        return data_placeholder_list\n\n\nclass TrainingConfig:\n    def __init__(self):\n        super().__init__()\n        self.exe_cfg = ExecutionConfig()\n        self.data = None\n        self.error_msg = \"\"\n\n    def config_execution(self, exe_cfg: ExecutionConfig = None):\n        self.exe_cfg = exe_cfg\n\n    def config_data(self, data: DataModule = None):\n        self.data = data\n\n    def check_valid(self):\n        is_valid = True\n        self.error_msg = \"\"\n        if not isinstance(self.exe_cfg, ExecutionConfig):\n            self.error_msg += \"model.TrainingConfig exe_cfg is not ExecutionConfig;\"\n            is_valid = False\n        if self.data is None:\n            self.error_msg += \"model.TrainingConfig data is None;\"\n            is_valid = False\n        if not isinstance(self.data, DataModule):\n            self.error_msg += \"model.TrainingConfig data is not DataModule;\"\n            is_valid = False\n        return is_valid\n\n\nclass ValidationConfig:\n    def __init__(self):\n        super().__init__()\n        self.exe_cfg = ExecutionConfig()\n        self.data = None\n        self.step_interval = 10\n        self.error_msg = \"\"\n\n    def config_execution(self, exe_cfg: ExecutionConfig = None):\n        self.exe_cfg = exe_cfg\n\n    def config_data(self, data: DataModule = None):\n        self.data = data\n\n    def config_step_interval(self, step_interval: int = 1):\n        self.step_interval = step_interval\n\n    def check_valid(self):\n        is_valid = True\n        self.error_msg = \"\"\n        if self.data is None:\n            self.error_msg += \"model.ValidationConfig data is None;\"\n            is_valid = False\n        if not isinstance(self.data, DataModule):\n            self.error_msg += \"model.ValidationConfig data is not DataModule;\"\n            is_valid = False\n        if self.step_interval <= 0 or not isinstance(self.step_interval, int):\n            self.error_msg += (\n                \"model.ValidationConfig step_interval is <= 0 or is not int;\"\n            )\n            is_valid = False\n        return is_valid\n\n\nclass CheckpointConfig(object):\n    def __init__(self):\n        self.need_load = False\n        self.load_dirpath = None\n        self.need_save = False\n        self.save_dirpath = None\n        self.save_step_interval = 1\n        self.error_msg = \"\"\n\n    def config_load(self, dirpath: str = None):\n        self.need_load = True\n        assert dirpath is not None, \"dirpath should not be None\"\n        self.load_dirpath = dirpath\n\n    def config_save(self, dirpath: str = None, step_interval: int = 1):\n        self.need_save = True\n        self.save_dirpath = dirpath\n        assert dirpath is not None, \"dirpath should not be None\"\n        self.save_step_interval = step_interval\n        assert step_interval > 0, \"step_interval should be greater than 0\"\n        assert 
isinstance(step_interval, int), \"step_interval should be int\"\n\n    def check_valid(self):\n        is_valid = True\n        self.error_msg = \"\"\n        return is_valid\n\n\nclass Callback(ABC):\n    \"\"\" Abstract base class used to build new callbacks.\n    \"\"\"\n\n    def on_training_step_end(\n        self,\n        outputs: Optional[Union[Tensor, Tuple[Tensor, ...]]],\n        step_idx: int = 0,\n        optimizer_idx: int = 0,\n    ):\n        pass\n\n    def on_validation_step_end(\n        self, outputs: Optional[Union[Tensor, Tuple[Tensor, ...]]], step_idx: int = 0,\n    ):\n        pass\n\n\nclass Model(ABC, Module):\n    \"\"\"A high level API for model training and validation.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__()\n        self._is_deprecated_function_style = (\n            kwargs[\"is_deprecated_function_style\"]\n            if \"is_deprecated_function_style\" in kwargs\n            else False\n        )\n\n    def forward(self, *args, **kwargs):\n        \"\"\"Same as `nn.Module.forward()`, here is to define the operations you want to use for prediction.\n        \"\"\"\n        raise NotImplementedError\n\n    def training_step(self, *args, **kwargs):\n        \"\"\"Operates on a single batch of data from the training set and return loss.\n        \"\"\"\n        raise NotImplementedError()\n\n    def validation_step(self, *args, **kwargs):\n        \"\"\"Operates on a single batch of data from the validation set.\n        \"\"\"\n        raise NotImplementedError()\n\n    def configure_optimizers(self):\n        \"\"\"Choose what optimizers and learning-rate schedulers to use in your optimization.\n        Normally you'd need one. But in the case of GANs or similar you might have multiple.\n        \"\"\"\n        raise NotImplementedError()\n\n    def fit(\n        self,\n        training_config: Optional[TrainingConfig] = None,\n        validation_config: Optional[ValidationConfig] = None,\n        checkpoint_config: Optional[CheckpointConfig] = None,\n        callbacks: Optional[Union[Callback, List[Callback]]] = None,\n        max_steps: int = 100,\n    ):\n        \"\"\" Runs the full training and validation routine.\n        \"\"\"\n        self._max_steps = max_steps\n        self._sub_models = self._get_and_check_sub_models(\n            training_config, validation_config, checkpoint_config, callbacks\n        )\n        if len(self._sub_models) == 0:\n            return\n        if self._checkpoint_model.is_valid:\n            self._checkpoint_model.load()\n        for step_idx in range(0, self._max_steps):\n            for sub_model in self._sub_models:\n                try:\n                    sub_model.step(step_idx)\n                except Exception as e:\n                    print(\n                        \"Model step_idx {} {} failed.\".format(step_idx, sub_model.name)\n                    )\n                    raise e\n\n    def method_overrided(self, method_name: str = None) -> bool:\n        return getattr(self.__class__, method_name) != getattr(Model, method_name)\n\n    def _get_and_check_sub_models(\n        self,\n        training_config: Optional[TrainingConfig] = None,\n        validation_config: Optional[ValidationConfig] = None,\n        checkpoint_config: Optional[CheckpointConfig] = None,\n        callbacks: Optional[Union[Callback, List[Callback]]] = None,\n    ):\n        sub_models = []\n        self._train_model = (\n            TrainModel(training_config, self, callbacks)\n           
 if self._is_deprecated_function_style\n            else TrainModelOOPStyle(training_config, self, callbacks)\n        )\n        if self._train_model.is_valid:\n            sub_models.append(self._train_model)\n        elif training_config is not None:\n            print(\n                self._train_model.error_msg,\n                \"{}'s fit() will not do training.\".format(self.__class__.__name__),\n            )\n        self._val_model = (\n            ValidateModel(validation_config, self, callbacks)\n            if self._is_deprecated_function_style\n            else ValidateModelOOPStyle(validation_config, self, callbacks)\n        )\n        if self._val_model.is_valid:\n            sub_models.append(self._val_model)\n        elif validation_config is not None:\n            print(\n                self._val_model.error_msg,\n                \"{}'s fit() will not do validation.\".format(self.__class__.__name__),\n            )\n        if len(sub_models) == 0:\n            print(\n                \"{}'s fit() will do nothing because there is no valid configuration.\".format(\n                    self.__class__.__name__\n                )\n            )\n            return sub_models\n        self._checkpoint_model = (\n            CheckpointModel(checkpoint_config, self, callbacks)\n            if self._is_deprecated_function_style\n            else CheckpointModelOOPStyle(checkpoint_config, self, callbacks)\n        )\n        if self._checkpoint_model.is_valid:\n            sub_models.append(self._checkpoint_model)\n        elif checkpoint_config is not None:\n            print(\n                self._checkpoint_model.error_msg,\n                \"{}'s fit() will not do checkpointing.\".format(self.__class__.__name__),\n            )\n        return sub_models\n\n\nclass SubModel(ABC):\n    def __init__(self, name, cfg, model, callbacks):\n        self._cfg = cfg\n        assert isinstance(model, Model)\n        self._model = model\n        self._cbs = callbacks\n        self.name = name\n        self.is_valid = True\n        self.error_msg = (\n            self._model.__class__.__name__ + \" \" + self.name + \" error message: \"\n        )\n        if not self._get_and_check_cfg():\n            self.is_valid = False\n        if not self._get_and_check_cbs():\n            self.is_valid = False\n\n    def step(self, step_idx: int = 0):\n        raise NotImplementedError\n\n    def _get_and_check_cfg(self):\n        if self._cfg is None:\n            self.error_msg += \"config is None;\"\n            return False\n        if not self._cfg.check_valid():\n            self.error_msg += self._cfg.error_msg\n            return False\n        else:\n            return True\n\n    def _get_and_check_cbs(self):\n        if self._cbs is None:\n            self._cbs = []\n            return True\n        if isinstance(self._cbs, Callback):\n            self._cbs = [self._cbs]\n            return True\n        if isinstance(self._cbs, list):\n            for cb in self._cbs:\n                assert isinstance(\n                    cb, Callback\n                ), \"model callbacks' type must be model.Callback or List[model.Callback].\"\n            return True\n        assert (\n            False\n        ), \"model callbacks' type must be model.Callback or List[model.Callback].\"\n\n    def _method_callback(self, method_name: str = None, *args, **kwargs):\n        for cb in self._cbs:\n            method = getattr(cb, method_name)\n            method(*args, **kwargs)\n\n\nclass 
TrainModel(SubModel):\n    def __init__(\n        self,\n        cfg: TrainingConfig = None,\n        model: Model = None,\n        callbacks: Optional[Union[Callback, List[Callback]]] = None,\n    ):\n        super().__init__(\"training\", cfg, model, callbacks)\n        if not self._get_and_check_step():\n            self.is_valid = False\n        if not self._get_and_check_opts():\n            self.is_valid = False\n        if self.is_valid and (not self._get_and_check_jobs()):\n            self.is_valid = False\n\n    def step(self, step_idx: int = 0):\n        assert self.is_valid, self.error_msg\n        for optimizer_idx in range(0, len(self._opts)):\n            outputs = None\n            if self._is_numpy_input:\n                batch = None\n                if step_idx == 0:\n                    batch = self._first_numpy_batch[optimizer_idx]\n                else:\n                    batch = self._cfg.data(step_idx, optimizer_idx)\n                outputs = self._jobs[optimizer_idx](*batch).get()\n            else:\n                outputs = self._jobs[optimizer_idx]().get()\n            self._method_callback(\n                \"on_training_step_end\",\n                outputs=outputs,\n                step_idx=step_idx,\n                optimizer_idx=optimizer_idx,\n            )\n\n    def _get_and_check_step(self):\n        if not self._model.method_overrided(\"training_step\"):\n            self.error_msg += \"model.training_step() is empty;\"\n            return False\n        else:\n            return True\n\n    def _get_and_check_opts(self):\n        self._opts = []\n        if not self._model.method_overrided(\"configure_optimizers\"):\n            self.error_msg += \"model.configure_optimizers() is empty;\"\n            return False\n        opt_conf = self._model.configure_optimizers()\n        if isinstance(opt_conf, Optimizer):\n            self._opts = [opt_conf]\n        elif isinstance(opt_conf, (list, tuple)):\n            for opt in opt_conf:\n                assert isinstance(\n                    opt, Optimizer\n                ), \"model.configure_optimizers() must return Optimizer                     or List[Optimizer, ...] or Tuple[Optimizer, ...]\"\n            self._opts = opt_conf\n        else:\n            assert (\n                False\n            ), \"model.configure_optimizers() must return Optimizer                 or List[Optimizer, ...] 
or Tuple[Optimizer, ...]\"\n        return True\n\n    def _get_and_check_jobs(self):\n        self._is_numpy_input = (\n            True if isinstance(self._cfg.data, NumpyDataModule) else False\n        )\n        self._jobs = []\n        if self._is_numpy_input:\n            self._first_numpy_batch = []\n            for optimizer_idx in range(0, len(self._opts)):\n                batch = self._cfg.data(0, optimizer_idx)\n                self._first_numpy_batch.insert(optimizer_idx, batch)\n                self._jobs.insert(\n                    optimizer_idx, self._construct_numpy_job(batch, optimizer_idx)\n                )\n        else:\n            for optimizer_idx in range(0, len(self._opts)):\n                self._jobs.insert(optimizer_idx, self._construct_job(optimizer_idx))\n        return True\n\n    def _construct_job(self, optimizer_idx: int = 0):\n        def job():\n            batch = self._cfg.data(0, optimizer_idx)\n            outputs = self._model.training_step(\n                batch=batch, optimizer_idx=optimizer_idx\n            )\n            loss = None\n            if isinstance(outputs, tuple) and len(outputs) > 0:\n                loss = outputs[0]\n            else:\n                loss = outputs\n            self._opts[optimizer_idx].minimize(loss)\n            return outputs\n\n        job.__name__ = (\n            self._model.__class__.__name__ + \"_Model_train_job_\" + str(optimizer_idx)\n        )\n        # api_oneflow_function is the deprecated lazy-mode job decorator\n        deco = api_oneflow_function(type=\"train\", function_config=self._cfg.exe_cfg)\n        return deco(job)\n\n    def _construct_numpy_job(self, batch, optimizer_idx):\n        def job(*input_batch):\n            outputs = self._model.training_step(\n                batch=input_batch, optimizer_idx=optimizer_idx\n            )\n            loss = None\n            if isinstance(outputs, tuple) and len(outputs) > 0:\n                loss = outputs[0]\n            else:\n                loss = outputs\n            self._opts[optimizer_idx].minimize(loss)\n            return outputs\n\n        _infer_job_signature(self._cfg.data, batch, optimizer_idx, job)\n        job.__name__ = (\n            self._model.__class__.__name__\n            + \"_Model_train_numpy_job_\"\n            + str(optimizer_idx)\n        )\n        deco = api_oneflow_function(type=\"train\", function_config=self._cfg.exe_cfg)\n        return deco(job)\n\n\nclass ValidateModel(SubModel):\n    def __init__(\n        self,\n        cfg: ValidationConfig = None,\n        model: Model = None,\n        callbacks: Optional[Union[Callback, List[Callback]]] = None,\n    ):\n        super().__init__(\"validation\", cfg, model, callbacks)\n        if not self._get_and_check_step():\n            self.is_valid = False\n        if self.is_valid and (not self._get_and_check_job()):\n            self.is_valid = False\n\n    def step(self, step_idx: int = 0):\n        assert self.is_valid\n        if (step_idx + 1) % self._cfg.step_interval == 0:\n            outputs = None\n            if self._is_numpy_input:\n                batch = None\n                if step_idx == 0:\n                    batch = self._first_numpy_batch\n                else:\n                    batch = self._cfg.data(step_idx, 0)\n                outputs = self._job(*batch).get()\n            else:\n                outputs = self._job().get()\n            self._method_callback(\n                \"on_validation_step_end\", step_idx=step_idx, outputs=outputs\n            )\n\n    def _get_and_check_step(self):\n        
if not self._model.method_overrided(\"validation_step\"):\n            self.error_msg += \"model.validation_step() is empty;\"\n            return False\n        else:\n            return True\n\n    def _get_and_check_job(self):\n        self._is_numpy_input = (\n            True if isinstance(self._cfg.data, NumpyDataModule) else False\n        )\n        self._job = None\n        if not self._is_numpy_input:\n            self._job = self._construct_job()\n        else:\n            batch = self._cfg.data(0, 0)\n            self._first_numpy_batch = batch\n            self._job = self._construct_numpy_job(batch)\n        return True\n\n    def _construct_job(self):\n        def job():\n            batch = self._cfg.data(0, 0)\n            return self._model.validation_step(batch)\n\n        job.__name__ = self._model.__class__.__name__ + \"_Model_eval_job\"\n        deco  # = api_oneflow_function(type=\"predict\", function_config=self._cfg.exe_cfg)\n        return deco(job)\n\n    def _construct_numpy_job(self, batch: Tuple[np.ndarray, ...] = None):\n        def job(*input_batch):\n            return self._model.validation_step(batch=input_batch)\n\n        _infer_job_signature(self._cfg.data, batch, 0, job)\n        job.__name__ = self._model.__class__.__name__ + \"_Model_eval_numpy_job\"\n        deco  # = api_oneflow_function(type=\"predict\", function_config=self._cfg.exe_cfg)\n        return deco(job)\n\n\nclass CheckpointModel(SubModel):\n    def __init__(\n        self,\n        cfg: CheckpointConfig = None,\n        model: Model = None,\n        callbacks: Optional[Union[Callback, List[Callback]]] = None,\n    ):\n        super().__init__(\"checkpointing\", cfg, model, callbacks)\n\n    def load(self):\n        assert self.is_valid\n        if self._cfg.need_load:\n            self._load_checkpoint(self._cfg.load_dirpath)\n\n    def step(self, step_idx: int = 0):\n        assert self.is_valid\n        if self._cfg.need_save:\n            if (step_idx + 1) % self._cfg.save_step_interval == 0:\n                self._save_checkpoint(\n                    dirpath=self._cfg.save_dirpath + \"-\" + str(step_idx)\n                )\n\n    def _load_checkpoint(self, dirpath: str):\n        \"\"\"Load model states from a checkpoint.\n        \"\"\"\n        stat_dict = flow.load(path=dirpath)\n        self._model.load_state_dict(stat_dict)\n\n    def _save_checkpoint(self, dirpath: str):\n        \"\"\"Save model states as a checkpoint.\n        \"\"\"\n        stat_dict = self._model.state_dict()\n        flow.save(stat_dict, dirpath)\n\n\nclass TrainModelOOPStyle(SubModel):\n    def __init__(\n        self,\n        cfg: TrainingConfig = None,\n        model: Model = None,\n        callbacks: Optional[Union[Callback, List[Callback]]] = None,\n    ):\n        super().__init__(\"training\", cfg, model, callbacks)\n        if not self._get_and_check_step():\n            self.is_valid = False\n        if not self._get_and_check_opts():\n            self.is_valid = False\n\n    def step(self, step_idx: int = 0):\n        assert self.is_valid, self.error_msg\n        for optimizer_idx in range(0, len(self._opts)):\n            batch = self._cfg.data(step_idx, optimizer_idx)\n            outputs = self._model.training_step(\n                batch=batch, optimizer_idx=optimizer_idx\n            )\n            loss = None\n            if isinstance(outputs, tuple) and len(outputs) > 0:\n                loss = outputs[0]\n            else:\n                loss = outputs\n            
loss.backward()\n            opt = self._opts[optimizer_idx]\n            opt.step()\n            opt.zero_grad()\n            self._method_callback(\n                \"on_training_step_end\",\n                outputs=outputs,\n                step_idx=step_idx,\n                optimizer_idx=optimizer_idx,\n            )\n\n    def _get_and_check_step(self):\n        if not self._model.method_overrided(\"training_step\"):\n            self.error_msg += \"model.training_step() is empty;\"\n            return False\n        else:\n            return True\n\n    def _get_and_check_opts(self):\n        self._opts = []\n        if not self._model.method_overrided(\"configure_optimizers\"):\n            self.error_msg += \"model.configure_optimizers() is empty;\"\n            return False\n        opt_conf = self._model.configure_optimizers()\n        if isinstance(opt_conf, OOPOptimizer):\n            self._opts = [opt_conf]\n        elif isinstance(opt_conf, (list, tuple)):\n            for opt in opt_conf:\n                assert isinstance(\n                    opt, OOPOptimizer\n                ), \"model.configure_optimizers() must return Optimizer                     or List[Optimizer, ...] or Tuple[Optimizer, ...]\"\n            self._opts = opt_conf\n        else:\n            assert (\n                False\n            ), \"model.configure_optimizers() must return Optimizer                 or List[Optimizer, ...] or Tuple[Optimizer, ...]\"\n        return True\n\n\nclass ValidateModelOOPStyle(SubModel):\n    def __init__(\n        self,\n        cfg: ValidationConfig = None,\n        model: Model = None,\n        callbacks: Optional[Union[Callback, List[Callback]]] = None,\n    ):\n        super().__init__(\"validation\", cfg, model, callbacks)\n        if not self._get_and_check_step():\n            self.is_valid = False\n\n    def step(self, step_idx: int = 0):\n        assert self.is_valid\n        if (step_idx + 1) % self._cfg.step_interval == 0:\n            outputs = None\n            with oneflow._oneflow_internal.autograd.no_grad():\n                inputs = self._cfg.data(step_idx, 0)\n                model_previous_mode = self._model.training\n                # switch to eval mode for validation, then restore the previous mode\n                self._model.eval()\n                outputs = self._model.validation_step(inputs)\n                self._model.train(model_previous_mode)\n            self._method_callback(\n                \"on_validation_step_end\", step_idx=step_idx, outputs=outputs\n            )\n\n    def _get_and_check_step(self):\n        if not self._model.method_overrided(\"validation_step\"):\n            self.error_msg += \"model.validation_step() is empty;\"\n            return False\n        else:\n            return True\n\n\nclass CheckpointModelOOPStyle(SubModel):\n    def __init__(\n        self,\n        cfg: CheckpointConfig = None,\n        model: Model = None,\n        callbacks: Optional[Union[Callback, List[Callback]]] = None,\n    ):\n        super().__init__(\"checkpointing\", cfg, model, callbacks)\n\n    def load(self):\n        assert self.is_valid\n        if self._cfg.need_load:\n            self._load_checkpoint(self._cfg.load_dirpath)\n\n    def step(self, step_idx: int = 0):\n        assert self.is_valid\n        if self._cfg.need_save:\n            if (step_idx + 1) % self._cfg.save_step_interval == 0:\n                self._save_checkpoint(\n                    dirpath=self._cfg.save_dirpath + \"-\" + str(step_idx)\n                )\n\n    def _load_checkpoint(self, dirpath: str):\n        \"\"\"Load model states 
from a checkpoint.\n        \"\"\"\n        stat_dict = flow.load(path=dirpath)\n        self._model.load_state_dict(stat_dict)\n\n    def _save_checkpoint(self, dirpath: str):\n        \"\"\"Save model states as a checkpoint.\n        \"\"\"\n        stat_dict = self._model.state_dict()\n        flow.save(stat_dict, dirpath)\n\n\ndef _infer_job_signature(data_module, batch, optimizer_idx, job):\n    para_list = []\n    placeholder_list = data_module.infer_oneflow_data_placeholder(batch, optimizer_idx)\n    for (i, placeholder) in enumerate(placeholder_list):\n        para_name = (\n            data_module.__class__.__name__\n            + \"_opt_\"\n            + str(optimizer_idx)\n            + \"_para_\"\n            + str(i)\n        )\n        para_list.append(\n            inspect.Parameter(\n                name=para_name,\n                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,\n                annotation=placeholder,\n            )\n        )\n    origin_sig = inspect.signature(job)\n    new_sig = origin_sig.replace(parameters=para_list)\n    job.__oneflow_function_signature__ = new_sig\n"
  },
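For the OOP-style path in `model.py` above, `fit()` expects an overridden `training_step` that returns a loss and a `configure_optimizers` that returns an optimizer (or a list/tuple of them), with batches produced by a `DataModule`. A minimal runnable sketch, assuming only the classes defined in this file plus standard `oneflow` ops; the toy data and shapes are invented:

```python
import oneflow as flow
from oneflow.framework.model import DataModule, Model, TrainingConfig

class RandomData(DataModule):
    # fit() calls the data module as data(step_idx, optimizer_idx)
    def forward(self, step_idx: int = 0, optimizer_idx: int = 0):
        return (flow.randn(4, 8), flow.randn(4, 1))

class ToyModel(Model):
    def __init__(self):
        super().__init__()  # OOP style: is_deprecated_function_style defaults to False
        self.linear = flow.nn.Linear(8, 1)

    def training_step(self, batch, optimizer_idx):
        x, y = batch
        return flow.mean((self.linear(x) - y) ** 2)  # returned loss gets .backward()

    def configure_optimizers(self):
        return flow.optim.SGD(self.parameters(), lr=0.01)

cfg = TrainingConfig()
cfg.config_data(RandomData())
ToyModel().fit(training_config=cfg, max_steps=3)
```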
  {
    "path": "python/oneflow/framework/multi_client_session.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport enum\nimport inspect\n\nfrom google.protobuf import text_format\n\nimport oneflow._oneflow_internal\nimport oneflow.core.job.job_set_pb2 as job_set_util\nimport oneflow.framework.c_api_util as c_api_util\nimport oneflow.framework.env_util as env_util\nimport oneflow.core.job.resource_pb2 as resource_pb\n\n\nclass MultiClientSession(object):\n    class Status(enum.Enum):\n        CREATED = 1\n        INITED = 2\n        CLOSED = 3\n\n    def __init__(self, env, sess_id):\n        self._id = sess_id\n        self._env = env\n        assert self._env is not None\n        # New a MultiClientSessionContext\n        self._session_ctx = oneflow._oneflow_internal.SessionContext(self._env._env_cxt)\n        self.config_proto_ = self._make_config_proto()\n        self.function_flag_name2default_val_ = {}\n        self._update_function_flag_name2defaultVal()\n        self.scope_attr_name2default_val_ = {}\n        self._update_scope_attr_name2defaultVal()\n        self.status_ = self.Status.CREATED\n\n    def __del__(self):\n        if self._env.is_shutting_down():\n            # After python shutting down, it's not safe to call oneflow\n            return\n        self._TryClose()\n\n    def TryInit(self):\n        self._check_status(self.Status.CREATED, self.Status.INITED)\n        if self.status_ == self.Status.CREATED:\n            config_proto_str = text_format.MessageToString(self.config_proto)\n            self._session_ctx.try_init(config_proto_str)\n            self.status_ = self.Status.INITED\n\n    def _TryClose(self):\n        if self.status_ != self.Status.CLOSED:\n            oneflow._oneflow_internal.ClearSessionId(self.id)\n        self.status_ = self.Status.CLOSED\n\n    @property\n    def status(self):\n        return self.status_\n\n    @property\n    def id(self):\n        return self._id\n\n    @property\n    def config_proto(self):\n        return self.config_proto_\n\n    @property\n    def resource(self):\n        self._check_status(self.Status.INITED)\n        return c_api_util.CurrentResource()\n\n    @property\n    def function_flag_name2default_val(self):\n        return self.function_flag_name2default_val_\n\n    @property\n    def scope_attr_name2default_val(self):\n        return self.scope_attr_name2default_val_\n\n    @property\n    def is_running(self):\n        return self.status_ == self.Status.INITED\n\n    def _check_status(self, *status):\n        check_success = False\n        for stat in status:\n            if self.status_ == stat:\n                check_success = True\n                break\n        if check_success is False:\n            caller_func_name = inspect.stack()[1].function\n            allowed_status = \" or \".join([str(stat) for stat in status])\n            raise ValueError(\n                \"The calling to {} is only allowed when status is {}, but current status is {}\".format(\n                    
caller_func_name, allowed_status, self.status_\n                )\n            )\n\n    def _make_config_proto(self):\n        config_proto = job_set_util.ConfigProto()\n        config_proto.resource.SetInParent()\n        config_proto.session_id = self.id\n        return config_proto\n\n    def _update_function_flag_name2defaultVal(self):\n        items = c_api_util.GetFunctionConfigDef().attr_name2attr_def.items()\n        self.function_flag_name2default_val_ = {k: v.default_val for (k, v) in items}\n\n    def _update_scope_attr_name2defaultVal(self):\n        items = c_api_util.GetScopeConfigDef().attr_name2attr_def.items()\n        self.scope_attr_name2default_val_ = {k: v.default_val for (k, v) in items}\n\n    def update_resource_eagerly(self, resource_config):\n        self._check_status(self.Status.INITED)\n        config_proto_str = text_format.MessageToString(resource_config)\n        self._session_ctx.update_resource(config_proto_str)\n"
  },
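`MultiClientSession._check_status` gates each call on an allowed set of states and uses `inspect.stack()` to name the offending caller in the error message. A standalone sketch of that pattern; the toy `Guarded` class below is illustrative, not part of oneflow:

```python
import enum
import inspect

class Status(enum.Enum):
    CREATED = 1
    INITED = 2
    CLOSED = 3

class Guarded:
    def __init__(self):
        self.status_ = Status.CREATED

    def _check_status(self, *status):
        # raise if the current status is not one of the allowed states
        if self.status_ not in status:
            caller = inspect.stack()[1].function  # name of the method that violated the gate
            allowed = " or ".join(str(s) for s in status)
            raise ValueError(
                "The calling to {} is only allowed when status is {}, "
                "but current status is {}".format(caller, allowed, self.status_)
            )

    def init(self):
        self._check_status(Status.CREATED)
        self.status_ = Status.INITED

g = Guarded()
g.init()        # fine: CREATED -> INITED
try:
    g.init()    # raises: init() is only allowed in CREATED
except ValueError as e:
    print(e)
```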
  {
    "path": "python/oneflow/framework/register_class_method_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow._oneflow_internal\nimport oneflow.framework.check_point_v2 as check_point_v2\nimport oneflow.framework.tensor as tensor_util\n\n\ndef RegisterMethod4Class():\n    tensor_util.RegisterMethods()\n    check_point_v2.RegisterMethods()\n"
  },
  {
    "path": "python/oneflow/framework/scope_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport traceback\nfrom contextlib import contextmanager\n\nfrom google.protobuf import text_format\n\nimport oneflow._oneflow_internal\nimport oneflow.core.job.scope_pb2 as scope_pb2_util\nimport oneflow.framework.attr_util as attr_util\nimport oneflow.framework.session_context as session_ctx\nfrom oneflow import oneflow_deprecate\n\n\ndef api_scope_config(**kwargs):\n    name2default = session_ctx.GetDefaultSession().scope_attr_name2default_val\n\n    def SetScopeProtoStr(serialized_scope_proto: str):\n        scope_proto = text_format.Parse(\n            serialized_scope_proto, scope_pb2_util.ScopeProto()\n        )\n        for (attr_name, py_value) in kwargs.items():\n            assert attr_name in name2default\n            attr_util.SetProtoAttrValue(\n                scope_proto.attr_name2attr_value[attr_name],\n                py_value,\n                name2default[attr_name],\n            )\n        return str(text_format.MessageToString(scope_proto))\n\n    sess = session_ctx.GetDefaultSession()\n    scope = MakeScope(\n        lambda old_scope, builder: builder.BuildScopeByProtoStrSetter(\n            old_scope, SetScopeProtoStr\n        )\n    )\n    return ScopeContext(scope)\n\n\ndef current_scope():\n    \"\"\" Return current scope\n    \"\"\"\n    return oneflow._oneflow_internal.GetCurrentScope()\n\n\nfrom oneflow import oneflow_deprecate\n\n\ndef MakeScope(build_func):\n    scope = None\n    old_scope = oneflow._oneflow_internal.GetCurrentScope()\n    assert old_scope is not None\n\n    def BuildScope(builder):\n        nonlocal scope\n        scope = build_func(old_scope, builder)\n        assert scope is not None\n\n    oneflow._oneflow_internal.deprecated.PhysicalRun(BuildScope)\n    return scope\n\n\n@contextmanager\ndef ScopeContext(scope):\n    old_scope = oneflow._oneflow_internal.GetCurrentScope()\n    oneflow._oneflow_internal.GlobalScopeStackPush(scope)\n    try:\n        yield\n    finally:\n        assert oneflow._oneflow_internal.GetCurrentScope() is scope\n        oneflow._oneflow_internal.GlobalScopeStackPop()\n        assert oneflow._oneflow_internal.GetCurrentScope() is old_scope\n"
  },
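`ScopeContext` above enforces strict stack discipline: the scope pushed on entry must still be on top at exit, and the previous scope must be back on top after the pop. A distilled sketch of that discipline, with a plain Python list standing in for oneflow's internal global scope stack:

```python
from contextlib import contextmanager

_scope_stack = ["root"]  # illustrative stand-in for the internal scope stack

def current_scope():
    return _scope_stack[-1]

@contextmanager
def scope_context(scope):
    old_scope = current_scope()
    _scope_stack.append(scope)       # analogue of GlobalScopeStackPush
    try:
        yield
    finally:
        # the body must not leave a different scope on top of the stack
        assert current_scope() == scope
        _scope_stack.pop()           # analogue of GlobalScopeStackPop
        assert current_scope() == old_scope

with scope_context("train"):
    assert current_scope() == "train"
assert current_scope() == "root"
```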
  {
    "path": "python/oneflow/framework/session_context.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport functools\n\nimport oneflow\nimport oneflow._oneflow_internal\nfrom oneflow.framework.multi_client_session import MultiClientSession\n\n\nclass SessionStatus:\n    OPEN = \"OPEN\"\n    RUNNING = \"RUNNING\"\n    CLOSED = \"CLOSED\"\n\n\ndef GetDefaultSession():\n    global _sess_id2sess\n    default_sess_id = oneflow._oneflow_internal.GetDefaultSessionId()\n    assert default_sess_id in _sess_id2sess\n    return _sess_id2sess[default_sess_id]\n\n\ndef NewDefaultSession(env):\n    session_id = oneflow._oneflow_internal.NewSessionId()\n    assert oneflow._oneflow_internal.RegsterSessionId(session_id)\n    new_default_sess = MultiClientSession(env, session_id)\n    global _sess_id2sess\n    assert new_default_sess.id not in _sess_id2sess\n    _sess_id2sess[new_default_sess.id] = new_default_sess\n\n\ndef TryCloseDefaultSession():\n    global _sess_id2sess\n    default_sess_id = oneflow._oneflow_internal.GetDefaultSessionId()\n    assert default_sess_id in _sess_id2sess\n    if default_sess_id in _sess_id2sess:\n        del _sess_id2sess[default_sess_id]\n    # Try clear to avoid using this outdated session.\n    oneflow._oneflow_internal.ClearSessionId(default_sess_id)\n\n\ndef try_init_default_session(func):\n    @functools.wraps(func)\n    def Func(*args, **kwargs):\n        GetDefaultSession().TryInit()\n        return func(*args, **kwargs)\n\n    return Func\n\n\n_sess_id2sess = {}\n"
  },
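`try_init_default_session` is a decorator that lazily initializes the default session before the wrapped API call runs; because `TryInit` is idempotent, only the first call pays the initialization cost. A self-contained sketch of the pattern, with a fake session object standing in for `MultiClientSession`:

```python
import functools

class FakeSession:
    def __init__(self):
        self.inited = False

    def TryInit(self):
        # idempotent: only the first call actually initializes
        if not self.inited:
            print("initializing session")
            self.inited = True

_default = FakeSession()

def try_init_default_session(func):
    @functools.wraps(func)
    def Func(*args, **kwargs):
        _default.TryInit()
        return func(*args, **kwargs)
    return Func

@try_init_default_session
def run_op(x):
    return x * 2

run_op(1)   # prints "initializing session"
run_op(2)   # no re-initialization
```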
  {
    "path": "python/oneflow/framework/sysconfig.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport imp\nimport importlib.util\nimport os\nfrom typing import List\n\nimport oneflow\nimport oneflow._oneflow_internal\n\n\ndef get_include() -> str:\n    return os.path.join(os.path.dirname(oneflow.__file__), \"include\")\n\n\ndef get_lib() -> str:\n    return os.path.dirname(oneflow.__file__)\n\n\ndef get_compile_flags() -> List[str]:\n    flags = []\n    flags.append(\"-I{}\".format(get_include()))\n    flags.append(\"-DHALF_ENABLE_CPP11_USER_LITERALS=0\")\n    if oneflow._oneflow_internal.flags.with_cuda():\n        flags.append(\"-DWITH_CUDA\")\n    if oneflow._oneflow_internal.flags.use_cxx11_abi():\n        flags.append(\"-D_GLIBCXX_USE_CXX11_ABI=1\")\n    else:\n        flags.append(\"-D_GLIBCXX_USE_CXX11_ABI=0\")\n    return flags\n\n\ndef get_liboneflow_link_flags() -> List[str]:\n    oneflow_python_module_path = get_lib()\n    # path in a pip release\n    oneflow_python_libs_path = f\"{oneflow_python_module_path}.libs\"\n    # path in a cmake build dir\n    if not os.path.exists(oneflow_python_libs_path):\n        from oneflow.version import __cmake_project_binary_dir__\n\n        oneflow_python_libs_path = __cmake_project_binary_dir__\n    return [\n        f\"-L{oneflow_python_libs_path}\",\n        f\"-l:oneflow\",\n        f\"-l:of_protoobj\",\n    ]\n\n\ndef get_link_flags() -> List[str]:\n    flags = []\n    flags.append(\"-L{}\".format(get_lib()))\n    (file, oneflow_internal_lib_path, _) = imp.find_module(\n        \"_oneflow_internal\", [get_lib()]\n    )\n    if file:\n        file.close()\n    flags.append(\"-l:{}\".format(os.path.basename(oneflow_internal_lib_path)))\n    return flags\n\n\ndef with_cuda() -> bool:\n    return oneflow._oneflow_internal.flags.with_cuda()\n\n\ndef get_cuda_version() -> int:\n    return oneflow._oneflow_internal.flags.cuda_version()\n\n\ndef has_rpc_backend_grpc() -> bool:\n    return oneflow._oneflow_internal.flags.has_rpc_backend_grpc()\n\n\ndef has_rpc_backend_local() -> bool:\n    return oneflow._oneflow_internal.flags.has_rpc_backend_local()\n\n\ndef cmake_build_type() -> str:\n    return oneflow._oneflow_internal.flags.cmake_build_type()\n\n\ndef with_rdma() -> bool:\n    return oneflow._oneflow_internal.flags.with_rdma()\n"
  },
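The `sysconfig` helpers above are meant to be composed into a compiler command line when building C++ extensions against an installed oneflow. A sketch under that assumption; `my_op.cpp` and `my_op.so` are hypothetical file names:

```python
import oneflow.framework.sysconfig as sysconfig

# assemble a build command for a hypothetical custom op
cmd = (
    ["g++", "-std=c++17", "-shared", "-fPIC", "my_op.cpp", "-o", "my_op.so"]
    + sysconfig.get_compile_flags()  # -I<oneflow include dir>, ABI/CUDA defines
    + sysconfig.get_link_flags()     # -L<oneflow lib dir>, -l:<_oneflow_internal lib>
)
print(" ".join(cmd))  # pass cmd to subprocess.check_call to actually compile
```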
  {
    "path": "python/oneflow/framework/tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom numbers import Number\nimport oneflow as flow\nimport oneflow.framework.tensor_str as tensor_str\nimport oneflow._oneflow_internal.lazy_mode as lazy_mode\n\nimport numpy as np\nfrom typing import Union\n\nTensor = flow._oneflow_internal.Tensor\nTensorTuple = flow._oneflow_internal.TensorTuple\n\n\ndef _ndim(self):\n    return len(self.shape)\n\n\ndef _backward(self, gradient=None, retain_graph=False, create_graph=False):\n    if lazy_mode.is_enabled():\n        assert (\n            self.is_lazy\n        ), \"nn.Graph only accept lazy tensor to call backward() in lazy mode.\"\n        assert (\n            not retain_graph\n        ), \"nn.Graph donot accept 'retain_graph' argument in backward() at the moment.\"\n        assert (\n            not create_graph\n        ), \"nn.Graph donot accept 'create_graph' argument in backward() at the moment.\"\n        flow._oneflow_internal.nn.graph.AddTensorAsGraphLoss(self)\n    flow.autograd.backward(self, gradient, retain_graph, create_graph)\n\n\ndef _str(self):\n    return self.__repr__()\n\n\ndef _repr(self):\n    return tensor_str._gen_tensor_str(self)\n\n\ndef _meta_repr(self):\n    return tensor_str._gen_tensor_meta_str(self)\n\n\ndef _eq(self, other):\n    if self is None and other is None:\n        return True\n    elif self is None or other is None:\n        return False\n    else:\n        return flow._C.broadcast_equal(self, other)\n\n\ndef _cuda(self, device: Union[int, str, flow.device] = None):\n    if device is None:\n        device = \"cuda\"\n    elif isinstance(device, int):\n        device = \"cuda:\" + str(device)\n    return self.to(device=device)\n\n\ndef _norm(self, p=None, dim=None, keepdim=False, dtype=None):\n    if type(p) == str or dim != None:\n        return flow._C.norm(self, p, dim, keepdim, dtype=dtype)\n    return flow._C.norm(self, p, dim, keepdim, dtype=dtype, for_norm=True)\n\n\ndef is_nonzero(input):\n    r\"\"\"\n    is_nonzero(input) -> (bool)\n\n    Returns True if the :attr:`input` is a single element tensor which is not equal to zero\n    after type conversions. i.e. not equal to ``flow.tensor([0.])`` or ``flow.tensor([0])``.\n\n    Throws a ``RuntimeError`` if ``input.shape.numel() != 1``\n\n    For Example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> flow.is_nonzero(flow.tensor([0.]))\n        False\n        >>> flow.is_nonzero(flow.tensor([1.5]))\n        True\n        >>> flow.is_nonzero(flow.tensor([3]))\n        True\n\n    \"\"\"\n    shape = input.shape\n    if shape.numel() == 0:\n        raise RuntimeError(\"bool value of Tensor with no values is ambiguous\")\n    if shape.numel() > 1:\n        raise RuntimeError(\"bool value of Tensor with more than one value is ambiguous\")\n    value = input.numpy().item()\n    return bool(value)\n\n\ndef _add(self, other, *, alpha=1):\n    return flow._C.add(self, other, alpha=alpha)\n\n\ndef _addmm(self, mat1, mat2, alpha=1, beta=1):\n    return flow.addmm(self, mat1, mat2, alpha, beta)\n\n\ndef _add_inplace(self, other, *, alpha=1):\n    return flow._C.add(self, other, alpha=alpha, inplace=True)\n\n\ndef _iadd(self, other):\n    return self.add_(other)\n\n\ndef _sub_inplace(self, other):\n    return flow._C.sub(self, other, inplace=True)\n\n\ndef _expand(self, *size):\n    return flow.expand(self, *size)\n\n\ndef _expand_as(input, other):\n    return flow.expand(input, *other.size())\n\n\ndef _argwhere(self):\n    return flow.argwhere(self)\n\n\ndef _index(self):\n    assert self.numel() == 1 and self.dtype in (\n        flow.uint8,\n        flow.int8,\n        flow.int32,\n        flow.int64,\n        flow.bool,\n    ), \"Only integer tensors of a single element can be converted to an index\"\n    return self.numpy().item()\n\n\ndef _scalar_float(self):\n    assert (\n        self.numel() == 1\n    ), \"only one element tensors can be converted to Python scalars\"\n    return self.numpy().astype(np.float64).item()\n\n\ndef _scalar_int(self):\n    assert (\n        self.numel() == 1\n    ), \"only one element tensors can be converted to Python scalars\"\n    return self.numpy().astype(np.int64).item()\n\n\ndef _new_empty(\n    self, *size, dtype=None, device=None, placement=None, sbp=None, requires_grad=False,\n):\n    return flow.new_empty(self, size, dtype, device, placement, sbp, requires_grad)\n\n\ndef _new_ones(\n    self, *size, dtype=None, device=None, placement=None, sbp=None, requires_grad=False,\n):\n    return flow.new_ones(self, size, dtype, device, placement, sbp, requires_grad)\n\n\ndef _new_zeros(\n    self, *size, dtype=None, device=None, placement=None, sbp=None, requires_grad=False,\n):\n    return flow.new_zeros(self, size, dtype, device, placement, sbp, requires_grad)\n\n\ndef _squeeze_inplace(self, dim=None):\n    return flow._C.squeeze_(self, dim=dim)\n\n\ndef _unsqueeze_inplace(self, dim=None):\n    return flow._C.unsqueeze_(self, dim=dim)\n\n\ndef _new_full(\n    self,\n    size,\n    fill_value,\n    dtype=None,\n    device=None,\n    placement=None,\n    sbp=None,\n    requires_grad=False,\n):\n    return flow.new_full(\n        self, size, fill_value, dtype, device, placement, sbp, requires_grad\n    )\n\n\ndef _argsort(self, dim=-1, descending=None):\n    return flow.argsort(self, dim=dim, descending=descending)\n\n\ndef _uniform(self, a=0, b=1):\n    return flow.nn.init.uniform_(self, a, b)\n\n\ndef _exponential(self, lambd=1.0, generator=None):\n    return flow._C.exponential_(self, lambd, generator)\n\n\ndef _trunc_normal_(\n    self, mean=0.0, std=1.0, a=-2.0, b=2.0,\n):\n    return flow.nn.init.trunc_normal_(self, mean=mean, std=std, a=a, b=b)\n\n\ndef _kaiming_uniform(\n    self, a=0, mode=\"fan_in\", nonlinearity=\"leaky_relu\", *, data_format=\"NCHW\"\n):\n    return 
flow.nn.init.kaiming_uniform_(\n        self, a=a, mode=mode, nonlinearity=nonlinearity, data_format=data_format\n    )\n\n\ndef _kaiming_normal(\n    self, a=0, mode=\"fan_in\", nonlinearity=\"leaky_relu\", *, data_format=\"NCHW\"\n):\n    return flow.nn.init.kaiming_normal_(\n        self, a=a, mode=mode, nonlinearity=nonlinearity, data_format=data_format\n    )\n\n\ndef _xavier_normal(self, gain=1.0, *, data_format=\"NCHW\"):\n    return flow.nn.init.xavier_normal_(self, gain=gain, data_format=data_format)\n\n\ndef _xavier_uniform(self, gain=1.0, *, data_format=\"NCHW\"):\n    return flow.nn.init.xavier_uniform_(self, gain=gain, data_format=data_format)\n\n\ndef _orthogonal(self, gain=1.0):\n    if self.ndimension() < 2:\n        raise ValueError(\"Only tensors with 2 or more dimensions are supported\")\n    rows = self.shape[0]\n    cols = np.prod(self.shape[1:])\n    flattened = np.random.normal(0.0, 1.0, size=(rows, cols))\n    if rows < cols:\n        flattened = flattened.T\n    # TODO\n    q, r = np.linalg.qr(flattened)\n    d = np.diag(r, 0)\n    d = np.sign(d)\n    q *= d\n    if rows < cols:\n        q = q.T\n    self = gain * flow.tensor(q.reshape(self.shape))\n    return self\n\n\ndef _normal(self, mean=0, std=1):\n    return flow.nn.init.normal_(self, mean=mean, std=std)\n\n\ndef _copy_from_numpy_to_eager_local_tensor(eager_local_tensor, np_arr):\n    assert np_arr.dtype == flow.convert_oneflow_dtype_to_numpy_dtype(\n        eager_local_tensor.dtype\n    )\n    assert np_arr.shape == tuple(eager_local_tensor.shape)\n    eager_local_tensor._copy_from_numpy(np_arr)\n\n\ndef _copy(self, other: Union[Tensor, np.ndarray]):\n    if isinstance(other, np.ndarray):\n        other = flow.from_numpy(other)\n    elif not isinstance(other, Tensor):\n        other = flow.tensor(other)\n    other = other.to(self.dtype)\n    if self.is_global:\n        assert other.is_global, \"Only global tensor can be assigned to global tensor.\"\n        if not (self.sbp == other.sbp and self.placement == other.placement):\n            other_cpu_placement = flow.placement(\"cpu\", other.placement.ranks)\n            other = other.to_global(placement=other_cpu_placement)\n            self_cpu_placement = flow.placement(\"cpu\", self.placement.ranks)\n            other = other.to_global(placement=self_cpu_placement, sbp=self.sbp)\n        flow._C.assign_local_tensor(self.to_local(), other.to_local())\n    else:\n        assert other.is_local, \"Only local tensor can be assigned to local tensor.\"\n        other = flow._C.broadcast_like(other, self)\n        if not self.is_contiguous():\n            # NOTE: slice_update supports non-contiguous input tensors\n            with flow.no_grad():\n                self[...] 
= other\n        else:\n            flow._C.assign_local_tensor(self, other)\n\n\ndef _format(self, format_spec):\n    if self.dim() == 0:\n        return self.numpy().tolist().__format__(format_spec)\n    return object.__format__(self, format_spec)\n\n\ndef _to(self, *args, **kwargs):\n    new_args = list()\n    # If device is single int, replace it with flow.device(\"cuda:{device}\")\n    if len(args) > 0 and isinstance(args[0], int):\n        new_args.append(flow.device(f\"cuda:{args[0]}\"))\n        for i in range(1, len(args)):\n            new_args.append(args[i])\n    else:\n        new_args = args\n    if (\"device\" in kwargs) and isinstance(kwargs[\"device\"], int):\n        kwargs[\"device\"] = flow.device(f\"cuda:{kwargs['device']}\")\n    return flow._C.to(self, *new_args, **kwargs)\n\n\ndef _tolist(self):\n    if self.numel() == 1 and self.ndim == 0:\n        return self.item()\n    return self.numpy().tolist()\n\n\ndef _repeat(self, *sizes):\n    if len(sizes) == 1:\n        new_sizes = sizes[0]\n        if isinstance(new_sizes, int):\n            new_sizes = (new_sizes,)\n    else:\n        new_sizes = sizes\n    return flow._C.repeat(self, new_sizes)\n\n\ndef _tile(self, *dims):\n    if len(dims) == 1:\n        new_dims = dims[0]\n        if isinstance(new_dims, int):\n            new_dims = (new_dims,)\n    else:\n        new_dims = dims\n    return flow._C.tile(self, new_dims)\n\n\ndef _T(self):\n    return flow._C.T(self)\n\n\ndef _nms(boxes, scores, iou_threshold: float):\n    return flow.nms(boxes, scores, iou_threshold)\n\n\ndef _nonzero(self, as_tuple=False):\n    return flow.nonzero(self, as_tuple)\n\n\ndef _prod(self, dim=[], keepdim=False):\n    return flow.prod(self, dim, keepdim)\n\n\ndef _masked_select(self, mask):\n    return flow.masked_select(self, mask)\n\n\ndef _sort(self, dim: int = -1, descending: bool = False):\n    return flow.sort(self, dim, descending)\n\n\ndef _where(self, x=None, y=None):\n    return flow.where(self, x, y)\n\n\ndef _numpy(self, dtype=None):\n    assert (\n        not self.is_lazy\n    ), \"tensor.numpy() is not allowed to be called in nn.Graph.build(*args) or be called by lazy tensor.\"\n    if self.is_global:\n        if self.placement.type == \"meta\":\n            raise TypeError(\"can't convert meta device type global tensor to numpy.\")\n    else:\n        if self.device.type == \"meta\":\n            raise TypeError(\"can't convert meta device type local tensor to numpy.\")\n\n    if self.dtype == flow.tensor_buffer:\n        shapes, dtypes = self._tensor_buffer_shapes_and_dtypes\n        tensors = flow.tensor_buffer_to_list_of_tensors(self, shapes, dtypes)\n        return [t.numpy() for t in tensors]\n    # TODO: support bfloat16 to numpy in C++\n    if self.dtype == flow.bfloat16:\n        self = self.to(flow.float32)\n    if self.is_global:\n        self_cpu_placement = flow.placement(\"cpu\", self.placement.ranks)\n        self = (\n            self.to_global(placement=self_cpu_placement)\n            .to_global(placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.broadcast)\n            .to_local()\n        )\n    assert self.is_local\n    if self.device != flow.device(\"cpu\"):\n        self = self.cpu()\n    result = self.to_numpy()\n    if dtype is None:\n        return result\n    return result.astype(dtype)\n\n\ndef zero_(self):\n    self.zero_()\n    return self\n\n\ndef _is_consistent(self):\n    raise RuntimeError(\".is_consistent has been removed, please use .is_global instead\")\n\n\ndef _to_consistent(self, 
*args, **kwargs):\n    raise RuntimeError(\".to_consistent has been removed, please use .to_global instead\")\n\n\ndef _new_tensor(\n    self, data, dtype=None, device=None, requires_grad=False, placement=None, sbp=None\n):\n    if dtype is None:\n        dtype = self.dtype\n    if self.is_local:\n        assert (\n            placement is None and sbp is None\n        ), \"self is local tensor, placement and sbp are expected to be None.\"\n        if device is None:\n            device = self.device\n        return flow.tensor(\n            data, dtype=dtype, device=device, requires_grad=requires_grad\n        )\n    else:\n        assert device is None, \"self is global tensor, device is expected to be None.\"\n        if placement is None:\n            placement = self.placement\n        if sbp is None:\n            sbp = self.sbp\n        return flow.tensor(\n            data, dtype=dtype, placement=placement, sbp=sbp, requires_grad=requires_grad\n        )\n\n\ndef _cumsum(self, dim, dtype=None):\n    return flow._C.cumsum(self, dim, dtype=dtype)\n\n\ndef _cumprod(self, dim, dtype=None):\n    return flow._C.cumprod(self, dim, dtype=dtype)\n\n\ndef _cross(self, other, dim=None):\n    return flow._C.cross(self, other, dim)\n\n\ndef _scatter(self, dim, index, src, *, reduce=None):\n    return flow._C.scatter(self, dim, index, src, reduce=reduce, inplace=False)\n\n\ndef _scatter_inplace(self, dim, index, src, *, reduce=None):\n    return flow._C.scatter(self, dim, index, src, reduce=reduce, inplace=True)\n\n\ndef _scatter_add_inplace(self, dim, index, src):\n    return flow._C.scatter_add(self, dim, index, src, inplace=True)\n\n\ndef _contains(self, element):\n    r\"\"\"Check if `element` is present in tensor\n\n    Args:\n        element (Tensor or scalar): element to be checked\n            for presence in current tensor\"\n    \"\"\"\n    if isinstance(element, (flow.Tensor, Number)):\n        # type hint doesn't understand the __contains__ result array\n        return (element == self).any().item()  # type: ignore[union-attr]\n\n    raise RuntimeError(\n        \"Tensor.__contains__ only supports Tensor or scalar, but you passed in a %s.\"\n        % type(element)\n    )\n\n\ndef _allclose(self, other, atol=1e-08, rtol=1e-05, equal_nan=False):\n    return flow._C.allclose(self, other, atol, rtol, equal_nan)\n\n\ndef _index_add(self, dim, index, source, alpha=1):\n    return flow._C.index_add(self, dim, index, source, alpha)\n\n\ndef _index_add_inplace(self, dim, index, source, alpha=1):\n    return flow._C.index_add_(self, dim, index, source, alpha)\n\n\ndef _as_strided(self, size, stride, storage_offset=0):\n    return flow._C.as_strided(self, size, stride, storage_offset)\n\n\ndef _as_strided_inplace(self, size, stride, storage_offset=0):\n    return flow._C.as_strided_(self, size, stride, storage_offset)\n\n\ndef _logaddexp(self, other):\n    return flow._C.logaddexp(self, other)\n\n\ndef _real(self):\n    return flow._C.real(self)\n\n\ndef _imag(self):\n    return flow._C.imag(self)\n\n\ndef _conj(self):\n    return flow._C.conj(self)\n\n\ndef _conj_physical(self):\n    return flow._C.conj_physical(self)\n\n\ndef _storage(self):\n    return self\n\n\n@property\ndef _layout(self):\n    return flow.strided\n\n\ndef RegisterMethods():\n    Tensor.ndim = property(_ndim)\n    Tensor.numpy = _numpy\n    Tensor.add = _add\n    Tensor.add_ = _add_inplace\n    Tensor.sub_ = _sub_inplace\n    Tensor.backward = _backward\n    Tensor.__str__ = _str\n    Tensor.__repr__ = _repr\n    
Tensor.__contains__ = _contains\n    Tensor.__bool__ = is_nonzero\n    Tensor.__iadd__ = _iadd\n    Tensor.addmm = _addmm\n    Tensor.__format__ = _format\n    Tensor.__index__ = _index\n    Tensor.__float__ = _scalar_float\n    Tensor.__int__ = _scalar_int\n    Tensor.__array__ = _numpy\n    Tensor.uniform_ = _uniform\n    Tensor.exponential_ = _exponential\n    Tensor.trunc_normal_ = _trunc_normal_\n    Tensor.kaiming_uniform_ = _kaiming_uniform\n    Tensor.kaiming_normal_ = _kaiming_normal\n    Tensor.xavier_normal_ = _xavier_normal\n    Tensor.xavier_uniform_ = _xavier_uniform\n    Tensor.orthogonal_ = _orthogonal\n    Tensor.normal_ = _normal\n    Tensor.copy_ = _copy\n    Tensor._meta_repr = _meta_repr\n    Tensor.argsort = _argsort\n    Tensor.argwhere = _argwhere\n    Tensor.expand = _expand\n    Tensor.expand_as = _expand_as\n    Tensor.new_empty = _new_empty\n    Tensor.new_ones = _new_ones\n    Tensor.new_zeros = _new_zeros\n    Tensor.new_full = _new_full\n    Tensor.squeeze_ = _squeeze_inplace\n    Tensor.unsqueeze_ = _unsqueeze_inplace\n    Tensor.where = _where\n    Tensor.norm = _norm\n    Tensor.repeat = _repeat\n    Tensor.tile = _tile\n    Tensor.to = _to\n    Tensor.T = property(_T)\n    Tensor.masked_select = _masked_select\n    Tensor.eq = _eq\n    Tensor.sort = _sort\n    Tensor.tolist = _tolist\n    Tensor.nms = _nms\n    Tensor.nonzero = _nonzero\n    Tensor.prod = _prod\n    Tensor.is_consistent = _is_consistent\n    Tensor.to_consistent = _to_consistent\n    Tensor.new_tensor = _new_tensor\n    Tensor.cumsum = _cumsum\n    Tensor.cumprod = _cumprod\n    Tensor.cross = _cross\n    Tensor.scatter = _scatter\n    Tensor.scatter_ = _scatter_inplace\n    Tensor.scatter_add_ = _scatter_add_inplace\n    Tensor.allclose = _allclose\n    Tensor.index_add = _index_add\n    Tensor.index_add_ = _index_add_inplace\n    Tensor.as_strided = _as_strided\n    Tensor.as_strided_ = _as_strided_inplace\n    Tensor.logaddexp = _logaddexp\n    Tensor.real = _real\n    Tensor.imag = _imag\n    Tensor.conj = _conj\n    Tensor.conj_physical = _conj_physical\n    Tensor.layout = _layout\n    Tensor.storage = _storage\n\n\ndef register_tensor_op(op_name):\n    def set_tensor_op(method):\n        setattr(Tensor, op_name, method)\n        return method\n\n    return set_tensor_op\n"
  },
  {
    "path": "python/oneflow/framework/tensor_str.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\"\"\"\nThis file is mostly referenced from PyTorch v1.8.1 torch/_tensor_str.py\n\"\"\"\n\n\nimport math\nimport numpy as np\nfrom typing import Optional\nimport oneflow as flow\nfrom oneflow.framework.tensor_str_util import _autoset_linewidth\nfrom oneflow.framework.tensor_str_util import _try_convert_to_local_tensor\n\n\nclass __PrinterOptions(object):\n    precision: int = 4\n    threshold: float = 1000\n    edgeitems: int = 3\n    userset_linewidth: int = None\n    sci_mode: Optional[bool] = None\n\n    autoset_linewidth: bool = True\n\n    @property\n    def linewidth(self):\n        return (\n            _autoset_linewidth() if self.autoset_linewidth else self.userset_linewidth\n        )\n\n    @linewidth.setter\n    def linewidth(self, value):\n        self.userset_linewidth = value\n\n\nPRINT_OPTS = __PrinterOptions()\n\n\ndef set_printoptions(\n    precision=None,\n    threshold=None,\n    edgeitems=None,\n    linewidth=None,\n    profile=None,\n    sci_mode=None,\n):\n    r\"\"\"Set options for printing. Items shamelessly taken from NumPy\n\n    Args:\n        precision: Number of digits of precision for floating point output\n            (default = 4).\n        threshold: Total number of array elements which trigger summarization\n            rather than full `repr` (default = 1000).\n        edgeitems: Number of array items in summary at beginning and end of\n            each dimension (default = 3).\n        linewidth: The number of characters per line for the purpose of\n            inserting line breaks (default = terminal_columns).\n        profile: Sane defaults for pretty printing. Can override with any of\n            the above options. (any one of `default`, `short`, `full`)\n        sci_mode: Enable (True) or disable (False) scientific notation. If\n            None (default) is specified, the value is defined by\n            `oneflow._tensor_str._Formatter`. This value is automatically chosen\n            by the framework.\n    .. 
note::\n        linewidth equals to terminal columns, manual setting will invalidate the default automatic setting.\n    \"\"\"\n    if profile is not None:\n        if profile == \"default\":\n            PRINT_OPTS.precision = 4\n            PRINT_OPTS.threshold = 1000\n            PRINT_OPTS.edgeitems = 3\n            PRINT_OPTS.linewidth = 80\n        elif profile == \"short\":\n            PRINT_OPTS.precision = 2\n            PRINT_OPTS.threshold = 1000\n            PRINT_OPTS.edgeitems = 2\n            PRINT_OPTS.linewidth = 80\n        elif profile == \"full\":\n            PRINT_OPTS.precision = 4\n            PRINT_OPTS.threshold = math.inf\n            PRINT_OPTS.edgeitems = 3\n            PRINT_OPTS.linewidth = 80\n\n    if precision is not None:\n        PRINT_OPTS.precision = precision\n    if threshold is not None:\n        PRINT_OPTS.threshold = threshold\n    if edgeitems is not None:\n        PRINT_OPTS.edgeitems = edgeitems\n    if linewidth is not None:\n        PRINT_OPTS.linewidth = linewidth\n    PRINT_OPTS.sci_mode = sci_mode\n    if profile is not None or linewidth is not None:\n        PRINT_OPTS.autoset_linewidth = False\n\n\nclass _Formatter(object):\n    def __init__(self, tensor):\n        self.floating_dtype = tensor.dtype.is_floating_point\n        self.int_mode = True\n        self.sci_mode = False\n        self.max_width = 1\n        self.random_sample_num = 50\n        tensor = _try_convert_to_local_tensor(tensor)\n\n        with flow.no_grad():\n            tensor_view = tensor.reshape(-1)\n\n        if not self.floating_dtype:\n            for value in tensor_view:\n                value_str = \"{}\".format(value)\n                self.max_width = max(self.max_width, len(value_str))\n\n        else:\n            nonzero_finite_vals = flow.masked_select(tensor_view, tensor_view.ne(0))\n            if nonzero_finite_vals.numel() == 0:\n                # no valid number, do nothing\n                return\n\n            nonzero_finite_abs = nonzero_finite_vals.abs()\n            nonzero_finite_min = nonzero_finite_abs.min().numpy().astype(np.float64)\n            nonzero_finite_max = nonzero_finite_abs.max().numpy().astype(np.float64)\n\n            for value in nonzero_finite_abs.numpy():\n                if value != np.ceil(value):\n                    self.int_mode = False\n                    break\n\n            if self.int_mode:\n                # Check if scientific representation should be used.\n                if (\n                    nonzero_finite_max / nonzero_finite_min > 1000.0\n                    or nonzero_finite_max > 1.0e8\n                ):\n                    self.sci_mode = True\n                    for value in nonzero_finite_vals:\n                        value_str = (\n                            (\"{{:.{}e}}\").format(PRINT_OPTS.precision).format(value)\n                        )\n                        self.max_width = max(self.max_width, len(value_str))\n                else:\n                    for value in nonzero_finite_vals:\n                        value_str = (\"{:.0f}\").format(value)\n                        self.max_width = max(self.max_width, len(value_str) + 1)\n            else:\n                if (\n                    nonzero_finite_max / nonzero_finite_min > 1000.0\n                    or nonzero_finite_max > 1.0e8\n                    or nonzero_finite_min < 1.0e-4\n                ):\n                    self.sci_mode = True\n                    for value in nonzero_finite_vals:\n                        
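# e.g. with the default precision of 4, 1234.5 is formatted as '1.2345e+03'\n                        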
value_str = (\n                            (\"{{:.{}e}}\").format(PRINT_OPTS.precision).format(value)\n                        )\n                        self.max_width = max(self.max_width, len(value_str))\n                else:\n                    for value in nonzero_finite_vals:\n                        value_str = (\n                            (\"{{:.{}f}}\").format(PRINT_OPTS.precision).format(value)\n                        )\n                        self.max_width = max(self.max_width, len(value_str))\n\n        if PRINT_OPTS.sci_mode is not None:\n            self.sci_mode = PRINT_OPTS.sci_mode\n\n    def width(self):\n        return self.max_width\n\n    def format(self, value):\n        if self.floating_dtype:\n            if self.sci_mode:\n                ret = (\n                    (\"{{:{}.{}e}}\")\n                    .format(self.max_width, PRINT_OPTS.precision)\n                    .format(value)\n                )\n            elif self.int_mode:\n                ret = \"{:.0f}\".format(value)\n                if not (math.isinf(value) or math.isnan(value)):\n                    ret += \".\"\n            else:\n                ret = (\"{{:.{}f}}\").format(PRINT_OPTS.precision).format(value)\n        else:\n            ret = \"{}\".format(value)\n        return (self.max_width - len(ret)) * \" \" + ret\n\n\ndef _scalar_str(self, formatter1):\n    return formatter1.format(_try_convert_to_local_tensor(self).tolist())\n\n\ndef _vector_str(self, indent, summarize, formatter1):\n    # length includes spaces and comma between elements\n    element_length = formatter1.width() + 2\n    elements_per_line = max(\n        1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length)))\n    )\n\n    def _val_formatter(val, formatter1=formatter1):\n        return formatter1.format(val)\n\n    if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:\n        left_values = _try_convert_to_local_tensor(\n            self[: PRINT_OPTS.edgeitems]\n        ).tolist()\n        right_values = _try_convert_to_local_tensor(\n            self[-PRINT_OPTS.edgeitems :]\n        ).tolist()\n        data = (\n            [_val_formatter(val) for val in left_values]\n            + [\" ...\"]\n            + [_val_formatter(val) for val in right_values]\n        )\n    else:\n        values = _try_convert_to_local_tensor(self).tolist()\n        data = [_val_formatter(val) for val in values]\n\n    data_lines = [\n        data[i : i + elements_per_line] for i in range(0, len(data), elements_per_line)\n    ]\n    lines = [\", \".join(line) for line in data_lines]\n    return \"[\" + (\",\" + \"\\n\" + \" \" * (indent + 1)).join(lines) + \"]\"\n\n\ndef _tensor_str_with_formatter(self, indent, summarize, formatter1):\n    dim = self.dim()\n\n    if dim == 0:\n        return _scalar_str(self, formatter1)\n\n    if dim == 1:\n        return _vector_str(self, indent, summarize, formatter1)\n\n    if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:\n        slices = (\n            [\n                _tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1,)\n                for i in range(0, PRINT_OPTS.edgeitems)\n            ]\n            + [\"...\"]\n            + [\n                _tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1,)\n                for i in range(self.shape[0] - PRINT_OPTS.edgeitems, self.shape[0])\n            ]\n        )\n    else:\n        slices = [\n            _tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1)\n  
          for i in range(0, self.size(0))\n        ]\n\n    tensor_str = (\",\" + \"\\n\" * (dim - 1) + \" \" * (indent + 1)).join(slices)\n    return \"[\" + tensor_str + \"]\"\n\n\ndef _tensor_str(self, indent):\n    summarize = self.numel() > PRINT_OPTS.threshold\n    if self.dtype is flow.float16:\n        self = self.float()\n\n    with flow.no_grad():\n        formatter = _Formatter(get_summarized_data(self) if summarize else self)\n        return _tensor_str_with_formatter(self, indent, summarize, formatter)\n\n\ndef _add_suffixes(tensor_str, suffixes, indent):\n    tensor_strs = [tensor_str]\n    last_line_len = len(tensor_str) - tensor_str.rfind(\"\\n\") + 1\n    for suffix in suffixes:\n        suffix_len = len(suffix)\n        if last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth:\n            tensor_strs.append(\",\\n\" + \" \" * indent + suffix)\n            last_line_len = indent + suffix_len\n        else:\n            tensor_strs.append(\", \" + suffix)\n            last_line_len += suffix_len + 2\n    tensor_strs.append(\")\")\n    return \"\".join(tensor_strs)\n\n\ndef get_summarized_data(self):\n    dim = self.dim()\n    if dim == 0:\n        return self\n    if dim == 1:\n        if self.size(0) > 2 * PRINT_OPTS.edgeitems:\n            return flow.cat(\n                (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :])\n            )\n        else:\n            return self\n    if self.size(0) > 2 * PRINT_OPTS.edgeitems:\n        start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)]\n        end = [\n            self[i] for i in range(self.shape[0] - PRINT_OPTS.edgeitems, self.shape[0])\n        ]\n        return flow.stack([get_summarized_data(x) for x in (start + end)])\n    else:\n        return flow.stack([get_summarized_data(x) for x in self])\n\n\ndef _format_tensor_on_cpu(tensor):\n    if tensor.is_global:\n        device = tensor.placement.type\n    else:\n        device = tensor.device.type\n    return device != \"cpu\" and device != \"cuda\"\n\n\ndef _gen_tensor_str_template(tensor, is_meta):\n    is_meta = is_meta or tensor.is_lazy\n    prefix = \"tensor(\"\n    indent = len(prefix)\n    suffixes = []\n\n    meta_device_flag = False\n    # tensor is local or global\n    if tensor.is_global:\n        if tensor.placement.type == \"meta\":\n            meta_device_flag = True\n        suffixes.append(f\"placement={str(tensor.placement)}\")\n        suffixes.append(f\"sbp={str(tensor.sbp)}\")\n    elif tensor.device.type != \"cpu\":\n        if tensor.device.type == \"meta\":\n            meta_device_flag = True\n        suffixes.append(\"device='\" + str(tensor.device) + \"'\")\n    if tensor.is_lazy:\n        suffixes.append(\"is_lazy='True'\")\n\n    # tensor is empty, meta or normal\n    if tensor.numel() == 0:\n        # Explicitly print the shape if it is not (0,), to match NumPy behavior\n        if tensor.dim() != 1:\n            suffixes.append(\"size=\" + str(tuple(tensor.shape)))\n        tensor_str = \"[]\"\n    elif is_meta or meta_device_flag:\n        tensor_str = \"...\"\n        suffixes.append(\"size=\" + str(tuple(tensor.shape)))\n    else:\n        if _format_tensor_on_cpu(tensor):\n            tensor_str = _tensor_str(tensor.detach().to(\"cpu\"), indent)\n        else:\n            tensor_str = _tensor_str(tensor, indent)\n\n    suffixes.append(\"dtype=\" + str(tensor.dtype))\n    if tensor.grad_fn is not None:\n        name = tensor.grad_fn.name()\n        suffixes.append(\"grad_fn=<{}>\".format(name))\n    elif 
tensor.requires_grad:\n        suffixes.append(\"requires_grad=True\")\n\n    return _add_suffixes(prefix + tensor_str, suffixes, indent)\n\n\ndef _gen_tensor_str(tensor):\n    return _gen_tensor_str_template(tensor, False)\n\n\ndef _gen_tensor_meta_str(tensor):\n    # meta\n    return _gen_tensor_str_template(tensor, True)\n"
  },
  {
    "path": "python/oneflow/framework/tensor_str_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport oneflow as flow\nfrom typing import Optional, Tuple\n\n\ndef _autoset_linewidth():\n    # os.terminal_size(columns, lines),\n    # columns represents width of the terminal window in characters\n    # and lines represents height of the terminal window in characters.\n    try:\n        linewidth = os.get_terminal_size()[0]\n    except OSError:\n        linewidth = 80\n    return linewidth\n\n\ndef _try_convert_to_local_tensor(tensor):\n    if tensor.is_global:\n        tensor = tensor.to_global(\n            placement=flow.placement.all(tensor.placement.type), sbp=flow.sbp.broadcast,\n        ).to_local()\n    return tensor\n"
  },
  {
    "path": "python/oneflow/framework/tensor_tuple_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport collections\nfrom typing import Optional, Sequence, Union\n\nfrom oneflow._oneflow_internal import Tensor, TensorTuple\n\n\ndef convert_to_tensor_tuple(args: Optional[Union[Tensor, Sequence[Tensor]]]):\n    if args is None:\n        return TensorTuple()\n    elif isinstance(args, collections.abc.Sequence):\n        return TensorTuple(args)\n    else:\n        tensor_tuple = TensorTuple()\n        tensor_tuple.append(args)\n        return tensor_tuple\n"
  },
  {
    "path": "python/oneflow/framework/type_tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nimport oneflow as flow\n\nfrom oneflow._C import (\n    HalfTensor,\n    FloatTensor,\n    DoubleTensor,\n    BoolTensor,\n    ByteTensor,\n    CharTensor,\n    IntTensor,\n    LongTensor,\n    ComplexFloatTensor,\n    ComplexDoubleTensor,\n)\n\n__all__ = [\n    \"HalfTensor\",\n    \"FloatTensor\",\n    \"DoubleTensor\",\n    \"BoolTensor\",\n    \"ByteTensor\",\n    \"CharTensor\",\n    \"IntTensor\",\n    \"LongTensor\",\n    \"ComplexFloatTensor\",\n    \"ComplexDoubleTensor\",\n    # TODO: Add support for BFloat16Tensor, ComplexHalfTensor\n]\n"
  },
  {
    "path": "python/oneflow/framework/unittest.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport atexit\nimport imp\nimport os\nimport socket\nimport subprocess\nimport sys\nimport unittest\nimport uuid\nimport doctest\nfrom contextlib import closing\nfrom tempfile import NamedTemporaryFile\nfrom typing import Any, Callable, Dict\n\nimport google.protobuf.text_format as pbtxt\n\nimport oneflow\nimport oneflow.env\nimport oneflow.sysconfig\nfrom oneflow.core.job.env_pb2 import EnvProto\n\n\ndef register_test_cases(\n    scope: Dict[str, Any],\n    directory: str,\n    filter_by_num_nodes: Callable[[bool], int],\n    base_class: unittest.TestCase = unittest.TestCase,\n) -> None:\n    def FilterTestPyFile(f):\n        return (\n            os.path.isfile(os.path.join(directory, f))\n            and f.endswith(\".py\")\n            and f.startswith(\"test\")\n        )\n\n    def FilterMethodName(module, name):\n        method = getattr(module, name)\n        return (\n            name.startswith(\"test\")\n            and callable(method)\n            and filter_by_num_nodes(_GetNumOfNodes(method))\n        )\n\n    onlytest_files = [f for f in os.listdir(directory) if FilterTestPyFile(f)]\n    for f in onlytest_files:\n        class_name = f[0:-3]\n        module = imp.load_source(class_name, os.path.join(directory, f))\n        test_func_names = [\n            name for name in dir(module) if FilterMethodName(module, name)\n        ]\n        method_dict = {k: getattr(module, k) for k in test_func_names}\n        scope[class_name] = type(class_name, (test_case_mixin, base_class), method_dict)\n\n\ndef num_nodes_required(num_nodes: int) -> Callable[[Callable], Callable]:\n    def Decorator(f):\n        f.__oneflow_test_case_num_nodes_required__ = num_nodes\n        return f\n\n    return Decorator\n\n\ndef _GetNumOfNodes(func):\n    if hasattr(func, \"__oneflow_test_case_num_nodes_required__\") == False:\n        return 1\n    return getattr(func, \"__oneflow_test_case_num_nodes_required__\")\n\n\ndef eager_execution_enabled():\n    return os.getenv(\"ONEFLOW_TEST_ENABLE_EAGER\") == \"1\"\n\n\ndef typing_check_enabled():\n    return os.getenv(\"ONEFLOW_TEST_ENABLE_TYPING_CHECK\") == \"1\"\n\n\ndef node_list():\n    node_list_str = os.getenv(\"ONEFLOW_TEST_NODE_LIST\")\n    assert node_list_str\n    return node_list_str.split(\",\")\n\n\ndef has_node_list():\n    if os.getenv(\"ONEFLOW_TEST_NODE_LIST\"):\n        return True\n    else:\n        return False\n\n\ndef node_size():\n    node_num_from_env = os.getenv(\"ONEFLOW_TEST_NODE_NUM\", None)\n    if node_num_from_env:\n        return int(node_num_from_env)\n    elif has_node_list():\n        node_list_from_env = node_list()\n        return len(node_list_from_env)\n    else:\n        return 1\n\n\ndef has_world_size():\n    return True\n\n\ndef world_size():\n    return oneflow.env.get_world_size()\n\n\ndef device_num():\n    device_num_str = os.getenv(\"ONEFLOW_TEST_DEVICE_NUM\")\n    if device_num_str:\n     
   return int(device_num_str)\n    else:\n        return 1\n\n\ndef enable_init_by_host_list():\n    return os.getenv(\"ONEFLOW_TEST_ENABLE_INIT_BY_HOST_LIST\") == \"1\"\n\n\ndef enable_multi_process():\n    return os.getenv(\"ONEFLOW_TEST_MULTI_PROCESS\") == \"1\"\n\n\ndef find_free_port():\n    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:\n        s.bind((\"localhost\", 0))\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        return s.getsockname()[1]\n\n\n_unittest_worker_initilized = False\n\n\ndef worker_agent_port():\n    port_txt = os.getenv(\"ONEFLOW_TEST_WORKER_AGENT_PORT\")\n    if port_txt:\n        return int(port_txt)\n    else:\n        return None\n\n\ndef worker_agent_authkey():\n    key = os.getenv(\"ONEFLOW_TEST_WORKER_AGENT_AUTHKEY\")\n    assert key\n    return key\n\n\ndef use_worker_agent():\n    return worker_agent_port() is not None\n\n\ndef cast(conn=None, cmd=None, msg=None):\n    cmd = \"cast/\" + cmd\n    print(\"[unittest]\", f\"[{cmd}]\", msg)\n    conn.send(cmd.encode())\n    conn.send(msg.encode())\n\n\ndef call(conn=None, cmd=None, msg=None):\n    cmd = \"call/\" + cmd\n    print(\"[unittest]\", f\"[{cmd}]\", msg)\n    conn.send(cmd.encode())\n    msg_ = \"\"\n    if msg is not None:\n        msg_ = msg\n    conn.send(msg_.encode())\n    return conn.recv().decode()\n\n\nTestCase = unittest.TestCase\n\n\ndef skip_unless(n, d):\n    if (n > 1 or d > 1) and oneflow.sysconfig.has_rpc_backend_grpc() == False:\n        return unittest.skip(\n            \"requires multi node rpc backend when node_size > 1 and device_num > 1\"\n        )\n    if node_size() == n and device_num() == d:\n        return lambda func: func\n    else:\n        return unittest.skip(\n            \"only runs when node_size is {} and device_num is {}\".format(n, d)\n        )\n\n\ndef skip_unless_1n1d():\n    return skip_unless(1, 1)\n\n\ndef skip_unless_1n2d():\n    return skip_unless(1, 2)\n\n\ndef skip_unless_1n4d():\n    return skip_unless(1, 4)\n\n\ndef skip_unless_2n1d():\n    return skip_unless(2, 1)\n\n\ndef skip_unless_2n2d():\n    return skip_unless(2, 2)\n\n\ndef skip_unless_2n4d():\n    return skip_unless(2, 4)\n\n\nclass CondSkipChecker(doctest.OutputChecker):\n    def __init__(self, check_flags):\n        self._check_flags = check_flags\n\n    def check_output(self, want, got, optionflags):\n        # default check_output without flag\n        if optionflags == 0:\n            return super(CondSkipChecker, self).check_output(want, got, optionflags)\n\n        target_rank_list = [bool(flag & optionflags) for flag in self._check_flags]\n        # wrong flag will be handled before here, so any(target_rank_list) is True\n        # not target rank\n        if target_rank_list.index(True) != oneflow.env.get_rank():\n            return True\n        elif target_rank_list.index(True) == oneflow.env.get_rank():\n            return super(CondSkipChecker, self).check_output(want, got, optionflags)\n\n\ndef check_multi_rank_docstr(module):\n    # supply customized flag ONLY_CHECK_RANK_{x} for docstr\n    check_flags = [\n        doctest.register_optionflag(f\"ONLY_CHECK_RANK_{i}\")\n        for i in range(oneflow.env.get_world_size())\n    ]\n    finder = doctest.DocTestFinder()\n    runner = doctest.DebugRunner(CondSkipChecker(check_flags))\n    for test in finder.find(module, module.__name__):\n        runner.run(test)\n"
  },
  {
    "path": "python/oneflow/fx/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\ntry:\n    from onefx import *\nexcept:\n\n    class Proxy:\n        def __init__(self):\n            raise NotImplementedError(\n                \"oneflow.fx.Proxy is only for compatibility with PyTorch and is not actually implemented.\"\n            )\n"
  },
  {
    "path": "python/oneflow/hub.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# This file was copyed from https://github.com/pytorch/pytorch/blob/master/torch/hub.py and consistent with oneflow.\n\nimport errno\nimport hashlib\nimport json\nimport os\nimport re\nimport shutil\nimport sys\nimport tempfile\nimport oneflow as flow\nimport warnings\nimport zipfile\nfrom pathlib import Path\nfrom typing import Dict, Optional, Any\nfrom urllib.error import HTTPError\nfrom urllib.request import urlopen, Request\nfrom urllib.parse import urlparse  # noqa: F401\n\ntry:\n    from tqdm.auto import (\n        tqdm,\n    )  # automatically select proper tqdm submodule if available\nexcept ImportError:\n    try:\n        from tqdm import tqdm\n    except ImportError:\n        # fake tqdm if it's not installed\n        class tqdm(object):  # type: ignore[no-redef]\n            def __init__(\n                self,\n                total=None,\n                disable=False,\n                unit=None,\n                unit_scale=None,\n                unit_divisor=None,\n            ):\n                self.total = total\n                self.disable = disable\n                self.n = 0\n                # ignore unit, unit_scale, unit_divisor; they're just for real tqdm\n\n            def update(self, n):\n                if self.disable:\n                    return\n\n                self.n += n\n                if self.total is None:\n                    sys.stderr.write(\"\\r{0:.1f} bytes\".format(self.n))\n                else:\n                    sys.stderr.write(\n                        \"\\r{0:.1f}%\".format(100 * self.n / float(self.total))\n                    )\n                sys.stderr.flush()\n\n            def close(self):\n                self.disable = True\n\n            def __enter__(self):\n                return self\n\n            def __exit__(self, exc_type, exc_val, exc_tb):\n                if self.disable:\n                    return\n\n                sys.stderr.write(\"\\n\")\n\n\n__all__ = [\n    \"download_url_to_file\",\n    \"get_dir\",\n    \"help\",\n    \"list\",\n    \"load\",\n    \"load_state_dict_from_url\",\n    \"set_dir\",\n]\n\n# matches bfd8deac from resnet18-bfd8deac.pth\nHASH_REGEX = re.compile(r\"-([a-f0-9]*)\\.\")\n\n_TRUSTED_REPO_OWNERS = \"oneflow\"\nENV_GITHUB_TOKEN = \"GITHUB_TOKEN\"\nENV_ONEFLOW_HOME = \"ONEFLOW_HOME\"\nENV_XDG_CACHE_HOME = \"XDG_CACHE_HOME\"\nDEFAULT_CACHE_DIR = \"~/.cache\"\nVAR_DEPENDENCY = \"dependencies\"\nMODULE_HUBCONF = \"hubconf.py\"\nREAD_DATA_CHUNK = 8192\n_hub_dir = None\n\n\n# Copied from tools/shared/module_loader to be included in oneflow package\ndef _import_module(name, path):\n    import importlib.util\n    from importlib.abc import Loader\n\n    spec = importlib.util.spec_from_file_location(name, path)\n    assert spec is not None\n    module = importlib.util.module_from_spec(spec)\n    assert isinstance(spec.loader, Loader)\n    spec.loader.exec_module(module)\n    return 
module\n\n\ndef _remove_if_exists(path):\n    if os.path.exists(path):\n        if os.path.isfile(path):\n            os.remove(path)\n        else:\n            shutil.rmtree(path)\n\n\ndef _git_archive_link(repo_owner, repo_name, ref):\n    # See https://docs.github.com/en/rest/reference/repos#download-a-repository-archive-zip\n    return f\"https://github.com/{repo_owner}/{repo_name}/zipball/{ref}\"\n\n\ndef _load_attr_from_module(module, func_name):\n    # Check if callable is defined in the module\n    if func_name not in dir(module):\n        return None\n    return getattr(module, func_name)\n\n\ndef _get_oneflow_home():\n    oneflow_home = os.path.expanduser(\n        os.getenv(\n            ENV_ONEFLOW_HOME,\n            os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), \"oneflow\"),\n        )\n    )\n    return oneflow_home\n\n\n_get_torch_home = _get_oneflow_home\n\n\ndef _parse_repo_info(github):\n    if \":\" in github:\n        repo_info, ref = github.split(\":\")\n    else:\n        repo_info, ref = github, None\n    repo_owner, repo_name = repo_info.split(\"/\")\n\n    if ref is None:\n        # The ref wasn't specified by the user, so we need to figure out the\n        # default branch: main or master. Our assumption is that if main exists\n        # then it's the default branch, otherwise it's master.\n        try:\n            with urlopen(f\"https://github.com/{repo_owner}/{repo_name}/tree/main/\"):\n                ref = \"main\"\n        except HTTPError as e:\n            if e.code == 404:\n                ref = \"master\"\n            else:\n                raise\n    return repo_owner, repo_name, ref\n\n\ndef _read_url(url):\n    with urlopen(url) as r:\n        return r.read().decode(r.headers.get_content_charset(\"utf-8\"))\n\n\ndef _validate_not_a_forked_repo(repo_owner, repo_name, ref):\n    # Use urlopen to avoid depending on local git.\n    headers = {\"Accept\": \"application/vnd.github.v3+json\"}\n    token = os.environ.get(ENV_GITHUB_TOKEN)\n    if token is not None:\n        headers[\"Authorization\"] = f\"token {token}\"\n    for url_prefix in (\n        f\"https://api.github.com/repos/{repo_owner}/{repo_name}/branches\",\n        f\"https://api.github.com/repos/{repo_owner}/{repo_name}/tags\",\n    ):\n        page = 0\n        while True:\n            page += 1\n            url = f\"{url_prefix}?per_page=100&page={page}\"\n            response = json.loads(_read_url(Request(url, headers=headers)))\n            # Empty response means no more data to process\n            if not response:\n                break\n            for br in response:\n                if br[\"name\"] == ref or br[\"commit\"][\"sha\"].startswith(ref):\n                    return\n\n    raise ValueError(\n        f\"Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. 
\"\n        \"If it's a commit from a forked repo, please call hub.load() with forked repo directly.\"\n    )\n\n\ndef _get_cache_or_reload(\n    github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False\n):\n    # Setup hub_dir to save downloaded files\n    hub_dir = get_dir()\n    if not os.path.exists(hub_dir):\n        os.makedirs(hub_dir)\n    # Parse github repo information\n    repo_owner, repo_name, ref = _parse_repo_info(github)\n    # Github allows branch name with slash '/',\n    # this causes confusion with path on both Linux and Windows.\n    # Backslash is not allowed in Github branch name so no need to\n    # to worry about it.\n    normalized_br = ref.replace(\"/\", \"_\")\n    # Github renames folder repo-v1.x.x to repo-1.x.x\n    # We don't know the repo name before downloading the zip file\n    # and inspect name from it.\n    # To check if cached repo exists, we need to normalize folder names.\n    owner_name_branch = \"_\".join([repo_owner, repo_name, normalized_br])\n    repo_dir = os.path.join(hub_dir, owner_name_branch)\n    # Check that the repo is in the trusted list\n    _check_repo_is_trusted(\n        repo_owner,\n        repo_name,\n        owner_name_branch,\n        trust_repo=trust_repo,\n        calling_fn=calling_fn,\n    )\n\n    use_cache = (not force_reload) and os.path.exists(repo_dir)\n\n    if use_cache:\n        if verbose:\n            sys.stderr.write(\"Using cache found in {}\\n\".format(repo_dir))\n    else:\n        # Validate the tag/branch is from the original repo instead of a forked repo\n        if not skip_validation:\n            _validate_not_a_forked_repo(repo_owner, repo_name, ref)\n\n        cached_file = os.path.join(hub_dir, normalized_br + \".zip\")\n        _remove_if_exists(cached_file)\n\n        try:\n            url = _git_archive_link(repo_owner, repo_name, ref)\n            sys.stderr.write('Downloading: \"{}\" to {}\\n'.format(url, cached_file))\n            download_url_to_file(url, cached_file, progress=False)\n        except HTTPError as err:\n            if err.code == 300:\n                # Getting a 300 Multiple Choices error likely means that the ref is both a tag and a branch\n                # in the repo. This can be disambiguated by explicitely using refs/heads/ or refs/tags\n                # See https://git-scm.com/book/en/v2/Git-Internals-Git-References\n                # Here, we do the same as git: we throw a warning, and assume the user wanted the branch\n                warnings.warn(\n                    f\"The ref {ref} is ambiguous. Perhaps it is both a tag and a branch in the repo? \"\n                    \"OneFlowhub will now assume that it's a branch. \"\n                    \"You can disambiguate tags and branches by explicitly passing refs/heads/branch_name or \"\n                    \"refs/tags/tag_name as the ref. 
That might require using skip_validation=True.\"\n                )\n                disambiguated_branch_ref = f\"refs/heads/{ref}\"\n                url = _git_archive_link(\n                    repo_owner, repo_name, ref=disambiguated_branch_ref\n                )\n                download_url_to_file(url, cached_file, progress=False)\n            else:\n                raise\n\n        with zipfile.ZipFile(cached_file) as cached_zipfile:\n            extraced_repo_name = cached_zipfile.infolist()[0].filename\n            extracted_repo = os.path.join(hub_dir, extraced_repo_name)\n            _remove_if_exists(extracted_repo)\n            # Unzip the code and rename the base folder\n            cached_zipfile.extractall(hub_dir)\n\n        _remove_if_exists(cached_file)\n        _remove_if_exists(repo_dir)\n        shutil.move(extracted_repo, repo_dir)  # rename the repo\n\n    return repo_dir\n\n\ndef _check_repo_is_trusted(\n    repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn=\"load\"\n):\n    hub_dir = get_dir()\n    filepath = os.path.join(hub_dir, \"trusted_list\")\n\n    if not os.path.exists(filepath):\n        Path(filepath).touch()\n    with open(filepath, \"r\") as file:\n        trusted_repos = tuple(line.strip() for line in file)\n\n    # To minimize friction of introducing the new trust_repo mechanism, we consider that\n    # if a repo was already downloaded by oneflowhub, then it is already trusted (even if it's not in the allowlist)\n    trusted_repos_legacy = next(os.walk(hub_dir))[1]\n\n    owner_name = \"_\".join([repo_owner, repo_name])\n    is_trusted = (\n        owner_name in trusted_repos\n        or owner_name_branch in trusted_repos_legacy\n        or repo_owner in _TRUSTED_REPO_OWNERS\n    )\n\n    # TODO: Remove `None` option in 1.14 and change the default to \"check\"\n    if trust_repo is None:\n        if not is_trusted:\n            warnings.warn(\n                \"You are about to download and run code from an untrusted repository. In a future release, this won't \"\n                \"be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., \"\n                \"trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, \"\n                f\"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with \"\n                f\"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for \"\n                f\"confirmation if the repo is not already trusted. This will eventually be the default behaviour\"\n            )\n        return\n\n    if (trust_repo is False) or (trust_repo == \"check\" and not is_trusted):\n        response = input(\n            f\"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. 
\"\n            \"Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?\"\n        )\n        if response.lower() in (\"y\", \"yes\"):\n            if is_trusted:\n                print(\"The repository is already trusted.\")\n        elif response.lower() in (\"n\", \"no\", \"\"):\n            raise Exception(\"Untrusted repository.\")\n        else:\n            raise ValueError(f\"Unrecognized response {response}.\")\n\n    # At this point we're sure that the user trusts the repo (or wants to trust it)\n    if not is_trusted:\n        with open(filepath, \"a\") as file:\n            file.write(owner_name + \"\\n\")\n\n\ndef _check_module_exists(name):\n    import importlib.util\n\n    return importlib.util.find_spec(name) is not None\n\n\ndef _check_dependencies(m):\n    dependencies = _load_attr_from_module(m, VAR_DEPENDENCY)\n\n    if dependencies is not None:\n        missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]\n        if len(missing_deps):\n            raise RuntimeError(\n                \"Missing dependencies: {}\".format(\", \".join(missing_deps))\n            )\n\n\ndef _load_entry_from_hubconf(m, model):\n    if not isinstance(model, str):\n        raise ValueError(\"Invalid input: model should be a string of function name\")\n\n    # Note that if a missing dependency is imported at top level of hubconf, it will\n    # throw before this function. It's a chicken and egg situation where we have to\n    # load hubconf to know what're the dependencies, but to import hubconf it requires\n    # a missing package. This is fine, Python will throw proper error message for users.\n    _check_dependencies(m)\n\n    func = _load_attr_from_module(m, model)\n\n    if func is None or not callable(func):\n        raise RuntimeError(\"Cannot find callable {} in hubconf\".format(model))\n\n    return func\n\n\ndef get_dir():\n    \"\"\"\n    Get the OneFlow Hub cache directory used for storing downloaded models & weights.\n    If :func:`~oneflow.hub.set_dir` is not called, default path is ``$ONEFLOW_HOME/hub`` where\n    environment variable ``$ONEFLOW_HOME`` defaults to ``$XDG_CACHE_HOME/oneflow``.\n    ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux\n    filesystem layout, with a default value ``~/.cache`` if the environment\n    variable is not set.\n    \"\"\"\n    # Issue warning to move data if old env is set\n    if os.getenv(\"ONEFLOW_HUB\"):\n        warnings.warn(\"ONEFLOW_HUB is deprecated, please use env ONEFLOW_HOME instead\")\n\n    if _hub_dir is not None:\n        return _hub_dir\n    return os.path.join(_get_oneflow_home(), \"hub\")\n\n\ndef set_dir(d):\n    \"\"\"\n    Optionally set the OneFlow Hub directory used to save downloaded models & weights.\n\n    Args:\n        d (str): path to a local folder to save downloaded models & weights.\n    \"\"\"\n    global _hub_dir\n    _hub_dir = os.path.expanduser(d)\n\n\ndef list(github, force_reload=False, skip_validation=False, trust_repo=None):\n    \"\"\"\n    List all callable entrypoints available in the repo specified by ``github``.\n\n    Args:\n        github (str): a string with format \"repo_owner/repo_name[:ref]\" with an optional\n            ref (tag or branch). If ``ref`` is not specified, the default branch is assumed to be ``main`` if\n            it exists, and otherwise ``master``. 
Example: 'Oneflow-Inc/vision:0.2.0'\n        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.\n            Default is ``False``.\n        skip_validation (bool, optional): if ``False``, oneflowhub will check that the branch or commit\n            specified by the ``github`` argument properly belongs to the repo owner. This will make\n            requests to the GitHub API; you can specify a non-default GitHub token by setting the\n            ``GITHUB_TOKEN`` environment variable. Default is ``False``.\n        trust_repo (bool, str or None): ``\"check\"``, ``True``, ``False`` or ``None``.\n            This parameter was introduced in v1.12 and helps ensuring that users\n            only run code from repos that they trust.\n            - If ``False``, a prompt will ask the user whether the repo should be trusted.\n\n            - If ``True``, the repo will be added to the trusted list and loaded without \n              requiring explicit confirmation.\n\n            - If ``\"check\"``, the repo will be checked against the list of\n              trusted repos in the cache. If it is not present in that list, the\n              behaviour will fall back onto the ``trust_repo=False`` option.\n\n            - If ``None``, this will raise a warning, inviting the user to set\n              ``trust_repo`` to either ``False``, ``True`` or ``\"check\"``. This\n              is only present for backward compatibility and will be removed in\n              v1.14.\n        \n            Default is ``None`` and will eventually change to ``\"check\"`` in v1.14.\n    \n    Returns:\n        list: The available callable entrypoints\n    \n    For example:\n\n        >>> entrypoints = oneflow.hub.list('Oneflow-Inc/vision', force_reload=True)\n    \n    \"\"\"\n    repo_dir = _get_cache_or_reload(\n        github,\n        force_reload,\n        trust_repo,\n        \"list\",\n        verbose=True,\n        skip_validation=skip_validation,\n    )\n\n    sys.path.insert(0, repo_dir)\n\n    hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)\n    hub_module = _import_module(MODULE_HUBCONF, hubconf_path)\n\n    sys.path.remove(repo_dir)\n\n    # We treat functions starting with '_' as internal helper functions\n    entrypoints = [\n        f\n        for f in dir(hub_module)\n        if callable(getattr(hub_module, f)) and not f.startswith(\"_\")\n    ]\n\n    return entrypoints\n\n\ndef help(github, model, force_reload=False, skip_validation=False, trust_repo=None):\n    \"\"\"\n    Show the docstring of entrypoint ``model``.\n\n    Args:\n        github (str): a string with format <repo_owner/repo_name[:ref]> with an optional\n            ref (a tag or a branch). If ``ref`` is not specified, the default branch is assumed\n            to be ``main`` if it exists, and otherwise ``master``.\n            Example: 'Oneflow-Inc/vision:0.2.0'\n        model (str): a string of entrypoint name defined in repo's ``hubconf.py``\n        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.\n            Default is ``False``.\n        skip_validation (bool, optional): if ``False``, oneflowhub will check that the ref\n            specified by the ``github`` argument properly belongs to the repo owner. This will make\n            requests to the GitHub API; you can specify a non-default GitHub token by setting the\n            ``GITHUB_TOKEN`` environment variable. 
Default is ``False``.\n        trust_repo (bool, str or None): ``\"check\"``, ``True``, ``False`` or ``None``.\n            This parameter was introduced in v1.12 and helps ensuring that users\n            only run code from repos that they trust.\n            \n            - If ``False``, a prompt will ask the user whether the repo should\n              be trusted.\n            \n            - If ``True``, the repo will be added to the trusted list and loaded\n              without requiring explicit confirmation.\n            \n            - If ``\"check\"``, the repo will be checked against the list of\n              trusted repos in the cache. If it is not present in that list, the\n              behaviour will fall back onto the ``trust_repo=False`` option.\n            \n            - If ``None``: this will raise a warning, inviting the user to set\n              ``trust_repo`` to either ``False``, ``True`` or ``\"check\"``. This\n              is only present for backward compatibility and will be removed in\n              v1.14.\n            \n            Default is ``None`` and will eventually change to ``\"check\"`` in v1.14.\n    \n    For example:\n        >>> print(oneflow.hub.help('Oneflow-Inc/vision', 'resnet18', force_reload=True))\n    \"\"\"\n    repo_dir = _get_cache_or_reload(\n        github,\n        force_reload,\n        trust_repo,\n        \"help\",\n        verbose=True,\n        skip_validation=skip_validation,\n    )\n\n    sys.path.insert(0, repo_dir)\n\n    hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)\n    hub_module = _import_module(MODULE_HUBCONF, hubconf_path)\n\n    sys.path.remove(repo_dir)\n\n    entry = _load_entry_from_hubconf(hub_module, model)\n\n    return entry.__doc__\n\n\ndef load(\n    repo_or_dir,\n    model,\n    *args,\n    source=\"github\",\n    trust_repo=None,\n    force_reload=False,\n    verbose=True,\n    skip_validation=False,\n    **kwargs,\n):\n    \"\"\"\n    Load a model from a github repo or a local directory.\n    Note: Loading a model is the typical use case, but this can also be used to\n    for loading other objects such as tokenizers, loss functions, etc.\n    If ``source`` is 'github', ``repo_or_dir`` is expected to be\n    of the form ``repo_owner/repo_name[:ref]`` with an optional\n    ref (a tag or a branch).\n    If ``source`` is 'local', ``repo_or_dir`` is expected to be a\n    path to a local directory.\n    \n    Args:\n        repo_or_dir (str): If ``source`` is 'github',\n            this should correspond to a github repo with format ``repo_owner/repo_name[:ref]`` with\n            an optional ref (tag or branch), for example 'Oneflow-Inc/vision:0.2.0'. If ``ref`` is not specified,\n            the default branch is assumed to be ``main`` if it exists, and otherwise ``master``.\n            If ``source`` is 'local'  then it should be a path to a local directory.\n        model (str): the name of a callable (entrypoint) defined in the\n            repo/dir's ``hubconf.py``.\n        *args (optional): the corresponding args for callable ``model``.\n        source (str, optional): 'github' or 'local'. Specifies how\n            ``repo_or_dir`` is to be interpreted. 
Default is 'github'.\n        trust_repo (bool, str or None): ``\"check\"``, ``True``, ``False`` or ``None``.\n            This parameter was introduced in v1.12 and helps ensuring that users\n            only run code from repos that they trust.\n            \n            - If ``False``, a prompt will ask the user whether the repo should\n              be trusted.\n            \n            - If ``True``, the repo will be added to the trusted list and loaded\n              without requiring explicit confirmation.\n            \n            - If ``\"check\"``, the repo will be checked against the list of\n              trusted repos in the cache. If it is not present in that list, the\n              behaviour will fall back onto the ``trust_repo=False`` option.\n            \n            - If ``None``: this will raise a warning, inviting the user to set\n              ``trust_repo`` to either ``False``, ``True`` or ``\"check\"``. This\n              is only present for backward compatibility and will be removed in\n              v1.14.\n            \n            Default is ``None`` and will eventually change to ``\"check\"`` in v1.14.\n        force_reload (bool, optional): whether to force a fresh download of\n            the github repo unconditionally. Does not have any effect if\n            ``source = 'local'``. Default is ``False``.\n        verbose (bool, optional): If ``False``, mute messages about hitting\n            local caches. Note that the message about first download cannot be\n            muted. Does not have any effect if ``source = 'local'``.\n            Default is ``True``.\n        skip_validation (bool, optional): if ``False``, oneflowhub will check that the branch or commit\n            specified by the ``github`` argument properly belongs to the repo owner. This will make\n            requests to the GitHub API; you can specify a non-default GitHub token by setting the\n            ``GITHUB_TOKEN`` environment variable. Default is ``False``.\n        **kwargs (optional): the corresponding kwargs for callable ``model``.\n    \n    Returns:\n        The output of the ``model`` callable when called with the given\n        ``*args`` and ``**kwargs``.\n    \n    For example:\n        >>> # from a github repo\n        >>> repo = 'Oneflow-Inc/vision'\n        >>> model = oneflow.hub.load(repo, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1')\n        >>> # from a local directory\n        >>> path = '/some/local/path/oneflow/vision'\n        >>> # xdoctest: +SKIP\n        >>> model = oneflow.hub.load(path, 'resnet50', weights='ResNet50_Weights.DEFAULT')\n    \"\"\"\n    source = source.lower()\n\n    if source not in (\"github\", \"local\"):\n        raise ValueError(\n            f'Unknown source: \"{source}\". 
Allowed values: \"github\" | \"local\".'\n        )\n\n    if source == \"github\":\n        repo_or_dir = _get_cache_or_reload(\n            repo_or_dir,\n            force_reload,\n            trust_repo,\n            \"load\",\n            verbose=verbose,\n            skip_validation=skip_validation,\n        )\n\n    model = _load_local(repo_or_dir, model, *args, **kwargs)\n    return model\n\n\ndef _load_local(hubconf_dir, model, *args, **kwargs):\n    \"\"\"\n    Load a model from a local directory with a ``hubconf.py``.\n    \n    Args:\n        hubconf_dir (str): path to a local directory that contains a\n            ``hubconf.py``.\n        model (str): name of an entrypoint defined in the directory's\n            ``hubconf.py``.\n        *args (optional): the corresponding args for callable ``model``.\n        **kwargs (optional): the corresponding kwargs for callable ``model``.\n    \n    Returns:\n        a single model with corresponding pretrained weights.\n    \n    For example:\n        >>> # xdoctest: +SKIP(\"stub local path\")\n        >>> path = '/some/local/path/oneflow/vision'\n        >>> model = _load_local(path, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1')\n    \"\"\"\n    sys.path.insert(0, hubconf_dir)\n\n    hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)\n    hub_module = _import_module(MODULE_HUBCONF, hubconf_path)\n\n    entry = _load_entry_from_hubconf(hub_module, model)\n    model = entry(*args, **kwargs)\n\n    sys.path.remove(hubconf_dir)\n\n    return model\n\n\ndef download_url_to_file(url, dst, hash_prefix=None, progress=True):\n    \"\"\"Download object at the given URL to a local path.\n    \n    Args:\n        url (str): URL of the object to download\n        dst (str): Full path where object will be saved, e.g. ``/tmp/temporary_file``\n        hash_prefix (str, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``.\n            Default: None\n        progress (bool, optional): whether or not to display a progress bar to stderr\n            Default: True\n    \n    For example:\n        >>> # xdoctest: +REQUIRES(POSIX)\n        >>> oneflow.hub.download_url_to_file('https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ResNet/resnet18.zip', '/tmp/temporary_file')\n    \"\"\"\n    file_size = None\n    req = Request(url, headers={\"User-Agent\": \"oneflow.hub\"})\n    u = urlopen(req)\n    meta = u.info()\n    if hasattr(meta, \"getheaders\"):\n        content_length = meta.getheaders(\"Content-Length\")\n    else:\n        content_length = meta.get_all(\"Content-Length\")\n    if content_length is not None and len(content_length) > 0:\n        file_size = int(content_length[0])\n\n    # We deliberately save it in a temp file and move it after\n    # download is complete. 
This prevents a local working checkpoint\n    # being overridden by a broken download.\n    dst = os.path.expanduser(dst)\n    dst_dir = os.path.dirname(dst)\n    f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)\n\n    try:\n        if hash_prefix is not None:\n            sha256 = hashlib.sha256()\n        with tqdm(\n            total=file_size,\n            disable=not progress,\n            unit=\"B\",\n            unit_scale=True,\n            unit_divisor=1024,\n        ) as pbar:\n            while True:\n                buffer = u.read(READ_DATA_CHUNK)\n                if len(buffer) == 0:\n                    break\n                f.write(buffer)\n                if hash_prefix is not None:\n                    sha256.update(buffer)\n                pbar.update(len(buffer))\n\n        f.close()\n        if hash_prefix is not None:\n            digest = sha256.hexdigest()\n            if digest[: len(hash_prefix)] != hash_prefix:\n                raise RuntimeError(\n                    'invalid hash value (expected \"{}\", got \"{}\")'.format(\n                        hash_prefix, digest\n                    )\n                )\n        shutil.move(f.name, dst)\n    finally:\n        f.close()\n        if os.path.exists(f.name):\n            os.remove(f.name)\n\n\n# Hub used to automatically extract zip files that users compressed manually.\n# We should remove this support since zip is now the default format for oneflow.save().\ndef _is_legacy_zip_format(filename):\n    return zipfile.is_zipfile(filename)\n\n\ndef _legacy_zip_load(filename, model_dir, map_location):\n    # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.\n    #       We deliberately don't handle tarfile here since our legacy serialization format was in tar.\n    #       E.g. resnet18-5c106cde.pth which is widely used.\n    with zipfile.ZipFile(filename) as f:\n        members = f.infolist()\n        f.extractall(model_dir)\n        extracted_name = members[0].filename\n        extracted_file = os.path.join(model_dir, extracted_name)\n    return flow.load(extracted_file, map_location=map_location)\n\n\ndef load_state_dict_from_url(\n    url: str,\n    model_dir: Optional[str] = None,\n    map_location=None,\n    progress: bool = True,\n    check_hash: bool = False,\n    file_name: Optional[str] = None,\n) -> Dict[str, Any]:\n    \"\"\"Loads the OneFlow serialized object at the given URL.\n    If the downloaded file is a zip file, it will be automatically\n    decompressed.\n    If the object is already present in `model_dir`, it's deserialized and\n    returned.\n    The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where\n    ``hub_dir`` is the directory returned by :func:`~oneflow.hub.get_dir`.\n    \n    Args:\n        url (str): URL of the object to download\n        model_dir (str, optional): directory in which to save the object\n        map_location (optional): a function or a dict specifying how to remap storage locations (see oneflow.load)\n        progress (bool, optional): whether or not to display a progress bar to stderr.\n            Default: True\n        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention\n            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more\n            digits of the SHA256 hash of the contents of the file. 
The hash is used to\n            ensure unique names and to verify the contents of the file.\n            Default: False\n        file_name (str, optional): name for the downloaded file. Filename from ``url`` will be used if not set.\n    \n    For example:\n        >>> state_dict = oneflow.hub.load_state_dict_from_url('https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ResNet/resnet18.zip')\n    \"\"\"\n    # Issue warning to move data if old env is set\n    if os.getenv(\"ONEFLOW_MODEL_ZOO\"):\n        warnings.warn(\n            \"ONEFLOW_MODEL_ZOO is deprecated, please use env ONEFLOW_HOME instead\"\n        )\n\n    if model_dir is None:\n        hub_dir = get_dir()\n        model_dir = os.path.join(hub_dir, \"checkpoints\")\n\n    try:\n        os.makedirs(model_dir)\n    except OSError as e:\n        if e.errno == errno.EEXIST:\n            # Directory already exists, ignore.\n            pass\n        else:\n            # Unexpected OSError, re-raise.\n            raise\n\n    parts = urlparse(url)\n    filename = os.path.basename(parts.path)\n    if file_name is not None:\n        filename = file_name\n    cached_file = os.path.join(model_dir, filename)\n    if not os.path.exists(cached_file):\n        sys.stderr.write('Downloading: \"{}\" to {}\\n'.format(url, cached_file))\n        hash_prefix = None\n        if check_hash:\n            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]\n            hash_prefix = r.group(1) if r else None\n        download_url_to_file(url, cached_file, hash_prefix, progress=progress)\n\n    if _is_legacy_zip_format(cached_file):\n        return _legacy_zip_load(cached_file, model_dir, map_location)\n\n    return flow.load(cached_file, map_location=map_location)\n"
  },
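A minimal usage sketch for the two public hub entry points above, `download_url_to_file` and `load_state_dict_from_url`. It assumes network access and reuses the ResNet18 checkpoint URL from the docstrings; everything else is the documented `oneflow.hub` API.

```python
# Minimal sketch of the hub download flow documented above.
import oneflow as flow

url = "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ResNet/resnet18.zip"

# Cached under <hub_dir>/checkpoints; a second call reuses the local file.
state_dict = flow.hub.load_state_dict_from_url(url, progress=True)
print(sorted(state_dict.keys())[:3])

# Or download to an explicit destination; pass hash_prefix to have the
# SHA256 of the file checked against it after the download completes.
flow.hub.download_url_to_file(url, "/tmp/resnet18.zip", hash_prefix=None)
```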
  {
    "path": "python/oneflow/ir/__main__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport argparse\nimport oneflow\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--gen_ods\", default=False, action=\"store_true\", required=False)\nargs = parser.parse_args()\n\nif __name__ == \"__main__\":\n    oneflow._oneflow_internal.ir.gen_ods()\n"
  },
  {
    "path": "python/oneflow/ir/ast_gen_transformer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nimport ast\n\n\nclass ASTTransformer(ast.NodeTransformer):\n    def visit_arg(self, node: ast.arg):\n        node.ast = oneflow._oneflow_internal.ir.arg_(node.arg)\n        return node\n\n    def visit_arguments(self, node: ast.arguments):\n        for arg in node.args:\n            self.visit(arg)\n\n        list = [arg.ast for arg in node.args]\n        node.ast = oneflow._oneflow_internal.ir.arguments_(list)\n        return node\n\n    def visit_FunctionDef(self, node: ast.FunctionDef):\n        for arg in node.body:\n            self.visit(arg)\n\n        body = [arg.ast for arg in node.body]\n        self.visit(node.args)\n        node.ast = oneflow._oneflow_internal.ir.FunctionDef_(\n            \"get_lr\", node.args.ast, body\n        )\n        return node\n\n    def visit_Return(self, node: ast.Return):\n        self.visit(node.value)\n\n        node.ast = oneflow._oneflow_internal.ir.Return_(node.value.ast)\n        return node\n\n    def visit_Assign(self, node: ast.Assign):\n        self.visit(node.value)\n        for arg in node.targets:\n            self.visit(arg)\n\n        targets = [arg.ast for arg in node.targets]\n        node.ast = oneflow._oneflow_internal.ir.Assign_(targets, node.value.ast)\n        return node\n\n    def visit_If(self, node: ast.If):\n        self.visit(node.test)\n        for arg in node.body:\n            self.visit(arg)\n\n        if node.orelse:\n            for arg in node.orelse:\n                self.visit(arg)\n\n        test = node.test.ast\n        body = [arg.ast for arg in node.body]\n        orelse = [arg.ast for arg in node.orelse]\n        node.ast = oneflow._oneflow_internal.ir.If_(test, body, orelse)\n        return node\n\n    def visit_Raise(self, node: ast.Raise):\n        print(ast.dump(node))\n        raise \"not suport yet now\"\n\n    def visit_Assert(self, node: ast.Assert):\n        print(ast.dump(node))\n        raise \"not suport yet now\"\n\n    def visit_Expr(self, node: ast.Expr):\n        print(ast.dump(node))\n        raise \"not suport yet now\"\n\n    def visit_BoolOp(self, node: ast.BoolOp):\n        print(ast.dump(node))\n        raise \"not suport yet now\"\n\n    def visit_BinOp(self, node: ast.BinOp):\n        self.visit(node.left)\n        self.visit(node.right)\n\n        left = node.left.ast\n        right = node.right.ast\n\n        def get_op(op: ast.operator):\n            list = [ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Pow]\n            res = 1\n            for elem in list:\n                if isinstance(op, elem):\n                    return res\n                res += 1\n\n        op = get_op(node.op)\n\n        node.ast = oneflow._oneflow_internal.ir.BinOp_(left, op, right)\n        return node\n\n    def visit_Lambda(self, node: ast.Lambda):\n        print(ast.dump(node))\n        raise \"not suport yet now\"\n\n    def visit_Compare(self, node: ast.Compare):\n   
     self.visit(node.left)\n\n        for arg in node.comparators:\n            self.visit(arg)\n\n        left = node.left.ast\n        comparators = [arg.ast for arg in node.comparators]\n\n        def get_op(op: ast.operator):\n            list = [ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE]\n            res = 1\n            for elem in list:\n                if isinstance(op, elem):\n                    return res\n                res += 1\n\n        ops = [get_op(arg) for arg in node.ops]\n\n        node.ast = oneflow._oneflow_internal.ir.Compare_(left, ops, comparators)\n        return node\n\n    def visit_Call(self, node: ast.Call):\n        self.visit(node.func)\n\n        for arg in node.args:\n            self.visit(arg)\n\n        func = node.func.ast\n        args = [arg.ast for arg in node.args]\n\n        node.ast = oneflow._oneflow_internal.ir.Call_(func, args)\n        return node\n\n    def visit_Constant(self, node: ast.Constant):\n        node.ast = oneflow._oneflow_internal.ir.Constant_(node.value)\n        return node\n\n    def visit_Num(self, node: ast.Num):\n        node.ast = oneflow._oneflow_internal.ir.Num_(node.value)\n        return node\n\n    def visit_Attribute(self, node: ast.Attribute):\n        self.visit(node.value)\n        value = node.value.ast\n\n        node.ast = oneflow._oneflow_internal.ir.Attribute_(value, node.attr)\n        return node\n\n    def visit_Name(self, node: ast.Name):\n        node.ast = oneflow._oneflow_internal.ir.Name_(node.id)\n        return node\n"
  },
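`ASTTransformer` lowers each visited node into a handle from `oneflow._oneflow_internal.ir` and stashes it on `node.ast`; binary and comparison operators are encoded as 1-based indices in declaration order. A standalone sketch of that encoding, with no oneflow internals required (`encode_op` is an illustrative name):

```python
import ast

# Same 1-based encoding as ASTTransformer.get_op: Add=1, Sub=2, Mult=3, Div=4, Pow=5.
OP_TYPES = [ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Pow]

def encode_op(op: ast.operator) -> int:
    for index, op_type in enumerate(OP_TYPES, start=1):
        if isinstance(op, op_type):
            return index
    raise NotImplementedError(f"operator {type(op).__name__} is not supported yet")

expr = ast.parse("base_lr * gamma", mode="eval").body  # an ast.BinOp
print(encode_op(expr.op))  # 3, i.e. Mult
```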
  {
    "path": "python/oneflow/ir/bisect_transformer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport ast\nfrom bisect import bisect\n\n\nclass BisectTransformer(ast.NodeTransformer):\n    def visit_FunctionDef(self, node: ast.FunctionDef):\n        self.body_index = 0\n        self.body = node.body\n        for stmt in node.body:\n            self.visit(stmt)\n        self.body_index += 1\n        return node\n\n    def visit_Call(self, node: ast.Call):\n        if isinstance(node.func, ast.Attribute):\n            func: ast.Attribute = node.func\n            if func.value.id == \"bisect\":\n                bisect_x_list = [\"bisect_right\", \"bisect_left\"]\n                if func.attr in bisect_x_list:\n                    op = ast.LtE\n                    if func.attr == \"bisect_right\":\n                        op = ast.Lt\n                    if not isinstance(node.args[0], ast.List):\n                        raise \"only support bisect.bisect_right(list, x)\"\n                    ls = node.args[0].elts\n                    cmp = node.args[1]\n                    index = 0\n                    for i in ls[::-1]:\n                        test = ast.Compare(cmp, [op()], [i])\n                        assign = ast.Assign(\n                            [ast.Name(\"tmp\")], ast.Constant(len(ls) - index - 1, None)\n                        )\n                        if \"orelse\" in locals():\n                            orelse = ast.If(test, [assign], [orelse])\n                        else:\n                            orelse = ast.If(test, [assign], [])\n                        index += 1\n                    self.body.insert(self.body_index, orelse)\n                    return ast.Name(\"tmp\")\n        return node\n"
  },
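To see what the rewrite produces, one can run the transformer on a small `get_lr`-style function and unparse the result (Python 3.9+ for `ast.unparse`; the flat import mirrors lr_jit.py below and assumes the same working directory):

```python
import ast
from bisect_transformer import BisectTransformer

src = """
def get_lr(base_lr, step):
    idx = bisect.bisect_right([10, 20, 30], step)
    return idx
"""
tree = ast.parse(src)
BisectTransformer().visit(tree)
print(ast.unparse(tree))
# The call is replaced by a chain of comparisons assigning `tmp`:
#   if step < 10:
#       tmp = 0
#   elif step < 20:
#       tmp = 1
#   elif step < 30:
#       tmp = 2
#   idx = tmp
#   return idx
```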
  {
    "path": "python/oneflow/ir/lr_jit.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport ast\nimport textwrap\nimport inspect\nimport oneflow\n\nimport unittest\nimport oneflow.unittest\n\nfrom ast_gen_transformer import ASTTransformer\nfrom math_params_transformer import MathParamsTransformer\nfrom self_params_transformer import SelfParamsTransformer\nfrom bisect_transformer import BisectTransformer\n\n\ndef lr_jit_register(lr_obj, is_dump=False):\n    _id = lr_obj.__class__.__name__\n    # load source txt\n    _src = textwrap.dedent(inspect.getsource(lr_obj.get_lr))\n    _ast = ast.parse(_src).body[0]\n\n    # transform param self\n    transformer = SelfParamsTransformer(lr_obj)\n    transformer.visit(_ast)\n\n    # transform for bisect lib\n    transformer = BisectTransformer()\n    transformer.visit(_ast)\n\n    # transform for math lib\n    transformer = MathParamsTransformer()\n    transformer.visit(_ast)\n\n    # feed transformed as to C++\n    transformer = ASTTransformer()\n    transformer.visit(_ast)\n\n    oneflow._oneflow_internal.ir.compile_and_register_lr_jit(_id, _ast.ast, is_dump)\n    return _id\n\n\ndef _test_current_lr_jit(test_case):\n    from oneflow.nn.optimizer.constant_lr import ConstantLR\n    from oneflow.nn.optimizer.cosine_annealing_lr import CosineAnnealingLR\n    from oneflow.nn.optimizer.cosine_decay_lr import CosineDecayLR\n    from oneflow.nn.optimizer.exponential_lr import ExponentialLR\n    from oneflow.nn.optimizer.lambda_lr import LambdaLR\n    from oneflow.nn.optimizer.linear_lr import LinearLR\n    from oneflow.nn.optimizer.multistep_lr import MultiStepLR\n    from oneflow.nn.optimizer.polynomial_lr import PolynomialLR\n    from oneflow.nn.optimizer.sequential_lr import SequentialLR\n    from oneflow.nn.optimizer.step_lr import StepLR\n    from oneflow.nn.optimizer.warmup_lr import WarmupLR\n\n    from oneflow.optim import SGD\n    from oneflow.nn import Parameter\n    import numpy as np\n\n    param = Parameter(oneflow.ones(3, 4))\n    optimizer = SGD([param], lr=0.001)\n\n    lr_jit = oneflow._oneflow_internal.ir.create_global_lr_jit()\n\n    lr_obj_list = [\n        # WarmupLR(optimizer),\n        StepLR(optimizer, 5),\n        # SequentialLR(optimizer),\n        PolynomialLR(optimizer, 5),\n        MultiStepLR(optimizer, [10, 20, 30]),\n        LinearLR(optimizer),\n        # LambdaLR(optimizer, [lambda step: 0.95 * step]),\n        ExponentialLR(optimizer, 1.1),\n        CosineDecayLR(optimizer, 10),\n        CosineAnnealingLR(optimizer, 50),\n        ConstantLR(optimizer),\n    ]\n\n    for lr_obj in lr_obj_list:\n        id_ = lr_jit_register(lr_obj, False)\n\n        ls = [[0.005, 5], [0.01, 10], [0.02, 21]]\n        for elem in ls:\n            base_lr = elem[0]\n            step = elem[1]\n            lr = lr_obj.get_lr(base_lr, step)\n            lr_jit = oneflow._oneflow_internal.ir.get_lr(id_, base_lr, step)\n            test_case.assertTrue(np.isclose(lr, 
lr_jit))\n\n\n@oneflow.unittest.skip_unless_1n1d()\nclass TestCurrentLRJIT(oneflow.unittest.MLIRTestCase):\n    def test_current_lr_jit(test_case):\n        _test_current_lr_jit(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
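The test above doubles as usage documentation; condensed, the registration and evaluation flow looks like this (assumes a OneFlow build with the MLIR lr-jit bindings and this directory on `sys.path`):

```python
# Condensed from _test_current_lr_jit above.
import oneflow
from oneflow.optim import SGD
from oneflow.nn.optimizer.step_lr import StepLR
from lr_jit import lr_jit_register

param = oneflow.nn.Parameter(oneflow.ones(3, 4))
lr_obj = StepLR(SGD([param], lr=0.001), 5)

oneflow._oneflow_internal.ir.create_global_lr_jit()
id_ = lr_jit_register(lr_obj, is_dump=False)  # compile get_lr once

jit_lr = oneflow._oneflow_internal.ir.get_lr(id_, 0.005, 5)
print(jit_lr, lr_obj.get_lr(0.005, 5))  # the two should match
```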
  {
    "path": "python/oneflow/ir/math_params_transformer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport ast\n\n\nclass MathParamsTransformer(ast.NodeTransformer):\n    def visit_Attribute(self, node):\n        import math\n\n        list = [\"pi\"]\n        if node.value.id == \"math\":\n            if node.attr in list:\n                _name = node.attr\n                _attr = getattr(math, _name)\n                return ast.Constant(_attr, None)\n        return node\n"
  },
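A quick check of the constant-folding behavior (flat import as above; `ast.unparse` needs Python 3.9+):

```python
import ast
from math_params_transformer import MathParamsTransformer

tree = ast.parse("y = x * math.pi")
MathParamsTransformer().visit(tree)
print(ast.unparse(tree))  # y = x * 3.141592653589793
```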
  {
    "path": "python/oneflow/ir/self_params_transformer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport ast\n\n\nclass SelfParamsTransformer(ast.NodeTransformer):\n    def __init__(self, lr_obj):\n        super().__init__()\n        self.lr_obj = lr_obj\n\n    def visit_Attribute(self, node):\n        if node.value.id == \"self\":\n            _name = node.attr\n            _attr = getattr(self.lr_obj, _name)\n            if isinstance(_attr, list):\n                ls = [ast.Constant(elem, None) for elem in _attr]\n                return ast.List(ls)\n            return ast.Constant(_attr, None)\n        return node\n\n    def visit_arguments(self, node: ast.arguments):\n        for index, item in enumerate(node.args):\n            if item.arg == \"self\":\n                node.args.pop(index)\n        return node\n"
  },
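The transformer freezes `self.*` reads into literals taken from a live schedule object and drops the `self` argument; a sketch with a stand-in object (`FakeLR` is illustrative only):

```python
import ast
from self_params_transformer import SelfParamsTransformer

class FakeLR:
    gamma = 0.1
    milestones = [10, 20, 30]

src = "def get_lr(self, base_lr, step):\n    return base_lr * self.gamma"
tree = ast.parse(src)
SelfParamsTransformer(FakeLR()).visit(tree)
print(ast.unparse(tree))
# def get_lr(base_lr, step):
#     return base_lr * 0.1
```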
  {
    "path": "python/oneflow/jit/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport warnings\nfrom typing import Any, Dict, List, Set, Tuple, Union, Callable\n\n\ndef script(\n    obj,\n    optimize=None,\n    _frames_up=0,\n    _rcb=None,\n    example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,\n):\n    warnings.warn(\n        \"The oneflow.jit.script interface is just to align the torch.jit.script interface and has no practical significance.\"\n    )\n    return obj\n\n\ndef ignore(drop=False, **kwargs):\n    warnings.warn(\n        \"The oneflow.jit.ignore interface is just to align the torch.jit.ignore interface and has no practical significance.\"\n    )\n\n    def decorator(fn):\n        return fn\n\n    return decorator\n\n\ndef unused(fn):\n    warnings.warn(\n        \"The oneflow.jit.unused interface is just to align the torch.jit.unused interface and has no practical significance.\"\n    )\n\n    return fn\n\n\ndef _script_if_tracing(fn):\n    warnings.warn(\n        \"The oneflow.jit._script_if_tracing interface is just to align the torch.jit._script_if_tracing interface and has no practical significance.\"\n    )\n\n    return fn\n\n\ndef _overload_method(fn):\n    warnings.warn(\n        \"The oneflow.jit._overload_method interface is just to align the torch.jit._overload_method interface and has no practical significance.\"\n    )\n\n    return fn\n\n\ndef is_scripting():\n    return False\n\n\ndef is_tracing():\n    return False\n\n\nclass _Final:\n    \"\"\"Mixin to prohibit subclassing\"\"\"\n\n    __slots__ = (\"__weakref__\",)\n\n    def __init_subclass__(self, *args, **kwds):\n        if \"_root\" not in kwds:\n            raise TypeError(\"Cannot subclass special typing classes\")\n\n\nclass _SpecialForm(_Final, _root=True):\n    __slots__ = (\"_name\", \"__doc__\", \"_getitem\")\n\n    def __init__(self, getitem):\n        self._getitem = getitem\n        self._name = getitem.__name__\n        self.__doc__ = getitem.__doc__\n\n    def __getattr__(self, item):\n        if item in {\"__name__\", \"__qualname__\"}:\n            return self._name\n\n        raise AttributeError(item)\n\n    def __mro_entries__(self, bases):\n        raise TypeError(f\"Cannot subclass {self!r}\")\n\n    def __repr__(self):\n        return \"typing.\" + self._name\n\n    def __reduce__(self):\n        return self._name\n\n    def __call__(self, *args, **kwds):\n        raise TypeError(f\"Cannot instantiate {self!r}\")\n\n    def __or__(self, other):\n        return Union[self, other]\n\n    def __ror__(self, other):\n        return Union[other, self]\n\n    def __instancecheck__(self, obj):\n        raise TypeError(f\"{self} cannot be used with isinstance()\")\n\n    def __subclasscheck__(self, cls):\n        raise TypeError(f\"{self} cannot be used with issubclass()\")\n\n    def __getitem__(self, parameters):\n        return self._getitem(self, parameters)\n\n\n@_SpecialForm\ndef Final(*args, **kwargs):\n    
warnings.warn(\n        \"The oneflow.jit.Final interface is just to align the torch.jit.Final interface and has no practical significance.\"\n    )\n\n\ndef interface(fn):\n    warnings.warn(\n        \"The oneflow.jit.interface interface is just to align the torch.jit.interface interface and has no practical significance.\"\n    )\n    return fn\n"
  },
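These shims exist only so torch-style code keeps running under the mock; decorating or scripting returns the object unchanged, and each call emits a UserWarning saying as much:

```python
import oneflow.jit

def helper(x):
    return x + 1

scripted = oneflow.jit.script(helper)  # warns, then returns helper unchanged
assert scripted is helper
assert not oneflow.jit.is_scripting() and not oneflow.jit.is_tracing()
```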
  {
    "path": "python/oneflow/jit/annotations.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Tuple, List\n\nBroadcastingList2 = Tuple\n"
  },
  {
    "path": "python/oneflow/library.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport warnings\n\nwarnings.warn(\n    \"The oneflow.library interface is just to align the torch.library interface and has no practical significance.\"\n)\n"
  },
  {
    "path": "python/oneflow/linalg.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\n\n\ndef norm(self, ord=None, dim=None, keepdim=False, dtype=None):\n    return flow._C.norm(self, ord, dim, keepdim, dtype=dtype)\n\n\ndef vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None):\n    return flow._C.vector_norm(self, ord, dim, keepdim, dtype=dtype)\n\n\ndef matrix_norm(self, ord=\"fro\", dim=(-2, -1), keepdim=False, dtype=None):\n    return flow._C.matrix_norm(self, ord, dim, keepdim, dtype=dtype)\n\n\ndef inv(self):\n    return flow._C.inv(self)\n\n\ndef diagonal(self, input, offset=0, dim1=-2, dim2=-1):\n    \"\"\"\n    Alias for :func:`oneflow.diagonal` with defaults :attr:`dim1`\\ `= -2`, :attr:`dim2`\\ `= -1`.\n    \"\"\"\n    return flow._C.diagonal(self, input, offset=offset, dim1=dim1, dim2=dim2)\n\n\ndef cross(input, other, dim=-1):\n    return flow._C.linalg_cross(input, other, dim=dim)\n\n\ndef det(A):\n    \"\"\"\n    Computes the determinant of a square matrix.\n\n    Supports input of float, double dtypes. Also supports batches of matrices,\n    and if A is a batch of matrices then the output has the same batch dimensions.\n\n    The interface is consistent with PyTorch.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.linalg.det.html\n\n    Args:\n        A (Tensor): tensor of shape (\\*, n, n) where \\* is zero or more batch dimensions.\n\n    Returns:\n        oneflow.Tensor: the output Tensor.\n\n    .. warning::\n        Currently, only CUDA11 and above versions are supported.\n\n    \"\"\"\n    return flow._C.det(A)\n\n\ndef solve():\n    raise NotImplementedError()\n"
  },
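A short usage sketch for the wrappers above; they forward directly to `flow._C` and mirror `torch.linalg` semantics:

```python
import oneflow as flow

A = flow.randn(3, 3)
print(flow.linalg.norm(A))                     # Frobenius norm by default
print(flow.linalg.matrix_norm(A))              # ord="fro", dim=(-2, -1)
print(flow.linalg.vector_norm(flow.randn(5)))  # ord=2
# det requires CUDA 11+ per the docstring above:
# print(flow.linalg.det(A.cuda()))
```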
  {
    "path": "python/oneflow/mock_torch/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom .mock_importer import ModuleWrapper, enable, disable\nfrom .mock_modules import DummyModule\nfrom .dyn_mock_mod import DynamicMockModule\n"
  },
  {
    "path": "python/oneflow/mock_torch/__main__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport argparse\nfrom pathlib import Path\nimport os\nimport sys\n\nif sys.version_info < (3, 8):\n    try:\n        from importlib_metadata import requires\n    except ImportError:\n        import subprocess\n\n        subprocess.check_call(\"pip install importlib_metadata\", shell=True)\n        subprocess.check_call(\"pip install packaging\", shell=True)\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    \"mock\",\n    choices=[\"enable\", \"disable\"],\n    help=\"enable/disable mocking 'import torch', default is enable\",\n    nargs=\"?\",\n    default=\"enable\",\n)\nparser.add_argument(\"--lazy\", action=\"store_true\")\nparser.add_argument(\"--verbose\", action=\"store_true\")\nargs = parser.parse_args()\n\ntorch_env = Path(__file__).parent\n\n\ndef main():\n    def is_torch_env(s):\n        if s.endswith(\"oneflow/mock_torch\"):\n            return True\n        return False\n\n    if args.mock == \"enable\":\n        print(\n            f\"export ONEFLOW_MOCK_TORCH_LAZY={args.lazy}; export ONEFLOW_MOCK_TORCH_VERBOSE={args.verbose}; export PYTHONPATH={str(torch_env)}:$PYTHONPATH\"\n        )\n    elif args.mock == \"disable\" and \"PYTHONPATH\" in os.environ:\n        paths = os.environ[\"PYTHONPATH\"].rstrip(\":\").split(\":\")\n        paths = [p for p in paths if not is_torch_env(p)]\n        if len(paths) == 0:\n            print(\n                \"unset PYTHONPATH; unset ONEFLOW_MOCK_TORCH_LAZY; unset ONEFLOW_MOCK_TORCH_VERBOSE\"\n            )\n            return\n        path = \":\".join(paths)\n        print(\n            f\"export PYTHONPATH={path}; unset ONEFLOW_MOCK_TORCH_LAZY; unset ONEFLOW_MOCK_TORCH_VERBOSE\"\n        )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
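The module prints shell `export`/`unset` commands and is meant to be wrapped in something like `eval $(python3 -m oneflow.mock_torch enable)`. A sketch that just captures the generated commands from Python:

```python
# Inspect the shell commands the CLI emits (normally consumed via
# `eval $(python3 -m oneflow.mock_torch enable --lazy)`).
import subprocess, sys

out = subprocess.check_output(
    [sys.executable, "-m", "oneflow.mock_torch", "enable", "--lazy"], text=True
)
print(out)  # export ONEFLOW_MOCK_TORCH_LAZY=True; export ...; export PYTHONPATH=...
```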
  {
    "path": "python/oneflow/mock_torch/dyn_mock_mod.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nfrom inspect import ismodule\nimport importlib\nfrom contextlib import contextmanager\nfrom types import ModuleType\nfrom typing import Dict, List\nfrom .mock_importer import enable\n\n\nclass DynamicMockModule(ModuleType):\n    def __init__(\n        self, pkg_name: str, obj_entity: ModuleType, main_pkg_enable: callable,\n    ):\n        self._pkg_name = pkg_name\n        self._obj_entity = obj_entity  # ModuleType or _LazyModule\n        self._main_pkg_enable = main_pkg_enable\n        self._intercept_dict = {}\n\n    def __repr__(self) -> str:\n        return f\"<DynamicMockModule {self._pkg_name} {self._obj_entity}>\"\n\n    def hijack(self, module_name: str, obj: object):\n        self._intercept_dict[module_name] = obj\n\n    @classmethod\n    def from_package(\n        cls,\n        main_pkg: str,\n        *,\n        lazy: bool = True,\n        verbose: bool = False,\n        extra_dict: Dict[str, str] = None,\n        required_dependencies: List[str] = [],\n    ):\n        assert isinstance(main_pkg, str)\n\n        @contextmanager\n        def main_pkg_enable():\n            with enable(\n                lazy=lazy,\n                verbose=verbose,\n                extra_dict=extra_dict,\n                main_pkg=main_pkg,\n                mock_version=True,\n                required_dependencies=required_dependencies,\n            ):\n                yield\n\n        with main_pkg_enable():\n            obj_entity = importlib.import_module(main_pkg)\n        return cls(main_pkg, obj_entity, main_pkg_enable)\n\n    def _get_module(self, _name: str):\n        # Fix Lazy import\n        # https://github.com/huggingface/diffusers/blob/main/src/diffusers/__init__.py#L728-L734\n        module_name = f\"{self._obj_entity.__name__}.{_name}\"\n        try:\n            return importlib.import_module(module_name)\n        except Exception as e:\n            raise RuntimeError(\n                f\"Failed to import {module_name} because of the following error (look up to see its\"\n                f\" traceback):\\n{e}\"\n            ) from e\n\n    def __getattr__(self, name: str):\n        fullname = f\"{self._obj_entity.__name__}.{name}\"\n        if fullname in self._intercept_dict:\n            return self._intercept_dict[fullname]\n\n        with self._main_pkg_enable():\n            obj_entity = getattr(self._obj_entity, name, None)\n            if obj_entity is None:\n                obj_entity = self._get_module(name)\n\n        if ismodule(obj_entity):\n            return DynamicMockModule(self._pkg_name, obj_entity, self._main_pkg_enable)\n\n        return obj_entity\n\n    def __all__(self):\n        with self._main_pkg_enable():\n            return dir(self._obj_entity)\n"
  },
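A sketch of `DynamicMockModule.from_package`, following the diffusers lazy-import case referenced in `_get_module`; `diffusers` must be installed for this to run and is illustrative only:

```python
from oneflow.mock_torch import DynamicMockModule

diffusers = DynamicMockModule.from_package("diffusers", lazy=True, verbose=False)

# Attribute access re-enters the mock context, so torch names inside
# diffusers resolve to oneflow; submodules come back wrapped as well.
models_mod = diffusers.models  # a DynamicMockModule

# hijack() pins a fully-qualified name on this wrapper to a custom object.
diffusers.hijack("diffusers.__version__", "0.0.0+mocked")
assert diffusers.__version__ == "0.0.0+mocked"
```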
  {
    "path": "python/oneflow/mock_torch/mock_importer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport builtins\nfrom functools import partial\nimport types\nfrom inspect import ismodule, currentframe\nfrom types import ModuleType\nfrom typing import Any, Dict, Optional\nfrom importlib.abc import MetaPathFinder, Loader\nfrom importlib.machinery import ModuleSpec\nfrom importlib.util import find_spec, module_from_spec\nimport sys\nfrom typing import List\nfrom zipimport import zipimporter\n\nimport oneflow.support.env_var_util as env_var_util\nfrom .mock_modules import MockModuleDict, DummyModule\nfrom .mock_utils import MockEnableDisableMixin\n\n\nerror_msg = \"\"\" is not implemented, please submit an issue at  \n'https://github.com/Oneflow-Inc/oneflow/issues' including the log information of the error, the \nminimum reproduction code, and the system information.\"\"\"\n\n\n# patch hasattr so that\n# 1. torch.not_exist returns DummyModule object, but\n# 2. hasattr(torch, \"not_exist\") still returns False\n_builtin_hasattr = builtins.hasattr\nif not isinstance(_builtin_hasattr, types.BuiltinFunctionType):\n    raise Exception(\"hasattr already patched by someone else!\")\n\n\ndef hasattr(obj, name):\n    return _builtin_hasattr(obj, name)\n\n\nbuiltins.hasattr = hasattr\n\n\ndef probably_called_from_hasattr():\n    frame = currentframe().f_back.f_back\n    return frame.f_code is hasattr.__code__\n\n\n# module wrapper with checks for existence of methods\nclass ModuleWrapper(ModuleType):\n    # TODO add selcted methods\n    def __init__(self, module):\n        self.module = module\n\n    def __setattr__(self, name, value):\n        super().__setattr__(name, value)\n        if name != \"module\":\n            setattr(self.module, name, value)\n\n    def __getattr__(self, name: str) -> Any:\n        if not hasattr(self.module, name):\n            if name == \"__path__\":\n                return None\n            if name == \"__all__\":\n                return [attr for attr in dir(self.module) if not attr.startswith(\"_\")]\n            new_name = self.module.__name__ + \".\" + name\n            if _importer.lazy and not probably_called_from_hasattr():\n                if _importer.verbose:\n                    print(\n                        f'\"{new_name}\" is not found in oneflow, use dummy object as fallback.'\n                    )\n                return DummyModule(new_name, verbose=_importer.verbose)\n            else:\n                if _importer.lazy and _importer.verbose:\n                    print(f\"hasattr({self.module.__name__}, {name}) returns False\")\n                raise AttributeError(new_name + error_msg)\n        attr = getattr(self.module, name)\n        if ismodule(attr):\n            return ModuleWrapper(attr)\n        else:\n            return attr\n\n\nclass OneflowImporter(MockEnableDisableMixin, MetaPathFinder, Loader):\n    def __init__(self):\n        # module_from_spec will try to call the loader's create_module, resulting in 
infinite recursion\n        self.in_create_module = False\n        self.enable = False\n        # both __init__.py of oneflow and torch can't be executed multiple times, so we use a cache\n        self.enable_mod_cache = {}\n        self.disable_mod_cache = {}\n        # Record modules loaded during mocking for deletion\n        self.delete_list = []\n\n    def find_spec(self, fullname, path, target=None):\n        if module_dict_global.in_forward_dict(\n            fullname\n        ):  # don't touch modules other than torch or extra libs module\n            # for first import of real torch, we use default meta path finders, not our own\n            if not self.enable and self.disable_mod_cache.get(fullname) is None:\n                return None\n            return ModuleSpec(fullname, self)\n        self.delete_list.append(fullname)\n        return None\n\n    def find_module(self, fullname, path=None):\n        spec = self.find_spec(fullname, path)\n        return spec\n\n    def create_module(self, spec):\n        if self.in_create_module:\n            return None\n        self.in_create_module = True\n        if self.enable:\n            if module_dict_global.in_forward_dict(spec.name):\n                oneflow_mod_fullname = module_dict_global.forward_name(spec.name)\n            if (\n                sys.modules.get(oneflow_mod_fullname) is None\n                and self.enable_mod_cache.get(spec.name) is None\n            ):\n                # get actual oneflow module\n                try:\n                    real_spec = find_spec(oneflow_mod_fullname)\n                except ModuleNotFoundError:\n                    real_spec = None\n                if real_spec is None:\n                    self.in_create_module = False\n                    if self.lazy:\n                        if self.verbose:\n                            print(\n                                f\"{oneflow_mod_fullname} is not found in oneflow, use dummy object as fallback.\"\n                            )\n                        return DummyModule(oneflow_mod_fullname, verbose=self.verbose)\n                    else:\n                        raise ModuleNotFoundError(oneflow_mod_fullname + error_msg)\n\n                real_mod = module_from_spec(real_spec)\n                loader = real_spec.loader\n                if isinstance(loader, zipimporter):\n                    # TODO: verify can mock torch as oneflow in zipimporter\n                    pass\n                else:\n                    loader.exec_module(real_mod)\n            else:\n                real_mod = sys.modules.get(oneflow_mod_fullname)\n                if real_mod is None:\n                    real_mod = self.enable_mod_cache[spec.name]\n            self.in_create_module = False\n            return real_mod\n        else:\n            torch_full_name = spec.name\n            real_mod = self.disable_mod_cache[torch_full_name]\n            self.in_create_module = False\n            return real_mod\n\n    def exec_module(self, module):\n        module_name = module.__name__\n        if module_dict_global.in_inverse_dict(module_name):\n            fullname = module_dict_global.inverse_name(module_name)\n        if self.enable:\n            if not isinstance(module, DummyModule):\n                module = ModuleWrapper(module)\n        sys.modules[fullname] = module\n        globals()[fullname] = module\n\n    def _enable(\n        self,\n        globals=None,\n        lazy=False,\n        verbose=False,\n        *,\n        main_pkg: str = 
None,\n        mock_version: bool = None,\n        required_dependencies: List[str] = [],\n        from_cli: bool = False,\n    ):\n\n        if verbose:\n            print(\"enable mock torch\", globals[\"__name__\"])\n\n        if self.enable:  # already enabled\n            of_importer_module_name = self.globals[\"__name__\"]\n            input_module_name = globals[\"__name__\"]\n            if of_importer_module_name != input_module_name:\n                print(\n                    f\"Warning: {of_importer_module_name} is already enabled, but {input_module_name} is trying to enable it again. skip.\"\n                )\n            return\n\n        # record config for re-enabling\n        self._mock_enable_config = {k: v for k, v in locals().items() if k != \"self\"}\n        # insert importer to the first place of meta_path\n        sys.meta_path.insert(0, self)\n\n        self.lazy = lazy\n        self.verbose = verbose\n        self.from_cli = from_cli\n        self.globals = globals\n\n        self.mock_enable(\n            globals=globals,\n            module_dict=module_dict_global,\n            main_pkg=main_pkg,\n            mock_version=mock_version,\n            required_dependencies=required_dependencies,\n            from_cli=from_cli,\n            verbose=verbose,\n        )\n        self.enable = True\n\n    def _disable(self, globals, *, verbose=False):\n        if verbose:\n            print(\n                \"disable mock torch in\",\n                globals[\"__name__\"],\n                \"\\tself.enable: \",\n                self.enable,\n            )\n\n        if not self.enable:  # already disabled\n            return\n\n        of_importer_module_name = self.globals[\"__name__\"]\n        input_module_name = globals[\"__name__\"]\n        if of_importer_module_name != input_module_name:\n            raise RuntimeError(\n                f\"Error: {of_importer_module_name} is enabled, but {input_module_name} is trying to disable it. 
must disable it in the same module.\"\n            )\n\n        self.mock_disable(\n            globals=globals,\n            module_dict=module_dict_global,\n            delete_list=self.delete_list,\n            from_cli=self.from_cli,\n        )\n\n        sys.meta_path.remove(self)\n        self.enable = False\n        self.delete_list = []\n        self.globals = None\n\n\n_importer = OneflowImporter()\n\n\nclass BaseMockConfig:\n    def __init__(\n        self,\n        lazy: Optional[bool] = None,\n        verbose: Optional[bool] = None,\n        extra_dict: Optional[Dict[str, str]] = None,\n        *,\n        main_pkg: Optional[str] = None,\n        mock_version: Optional[str] = None,\n        required_dependencies: List[str] = [],\n        _from_cli: bool = False,\n    ):\n        global module_dict_global\n        module_dict_global = MockModuleDict(extra_dict)\n        module_dict_global.add(\"torch\", \"oneflow\")\n\n        required_dependencies.extend(\n            [k for k in extra_dict or {} if k not in required_dependencies]\n        )\n        if \"torch\" not in required_dependencies:\n            required_dependencies.append(\"torch\")\n\n        parse_bool_env = partial(\n            env_var_util.parse_boolean_from_env, defalut_value=False\n        )\n\n        forcedly_disabled_by_env_var = parse_bool_env(\"ONEFLOW_DISABLE_MOCK_TORCH\")\n\n        lazy = lazy if lazy is not None else parse_bool_env(\"ONEFLOW_MOCK_TORCH_LAZY\")\n        verbose = (\n            verbose\n            if verbose is not None\n            else parse_bool_env(\"ONEFLOW_MOCK_TORCH_VERBOSE\")\n        )\n\n        self.lazy = lazy\n        self.verbose = verbose\n        self.forcedly_disabled_by_env_var = forcedly_disabled_by_env_var\n        self.required_dependencies = required_dependencies\n        self.parse_bool_env = parse_bool_env\n        self._from_cli = _from_cli\n        self.main_pkg = main_pkg\n        self.mock_version = mock_version\n\n\nclass enable(BaseMockConfig):\n    \"\"\"https://docs.oneflow.org/master/cookies/oneflow_torch.html\"\"\"\n\n    def __init__(\n        self,\n        lazy: Optional[bool] = None,\n        verbose: Optional[bool] = None,\n        extra_dict: Optional[Dict[str, str]] = None,\n        *,\n        main_pkg: Optional[str] = None,\n        mock_version: Optional[str] = None,\n        required_dependencies: List[str] = [],\n        _from_cli: bool = False,\n    ):\n        super().__init__(\n            lazy=lazy,\n            verbose=verbose,\n            extra_dict=extra_dict,\n            main_pkg=main_pkg,\n            mock_version=mock_version,\n            required_dependencies=required_dependencies,\n            _from_cli=_from_cli,\n        )\n\n        if self.forcedly_disabled_by_env_var:  # super().__init__ will set this\n            return\n\n        self.globals = currentframe().f_back.f_globals\n        self.skip_processing = False\n        if getattr(_importer, \"globals\", None) is not None:\n            import_name = _importer.globals[\"__name__\"]\n            if import_name == self.globals[\"__name__\"]:\n                self.skip_processing = True\n                return\n\n        self._importer_enable = _importer.enable\n        if self._importer_enable:\n            self._mock_enable_config = _importer._mock_enable_config\n            _importer._disable(_importer.globals, verbose=self.verbose)\n\n        _importer._enable(\n            self.globals,\n            lazy=self.lazy,\n            verbose=self.verbose,\n            
main_pkg=main_pkg,\n            mock_version=mock_version,\n            required_dependencies=required_dependencies,\n            from_cli=_from_cli,\n        )\n\n    def __enter__(self):\n        pass\n\n    def __exit__(self, exception_type, exception_value, traceback):\n\n        if self.forcedly_disabled_by_env_var or self.skip_processing:\n            return\n\n        _importer._disable(_importer.globals, verbose=self.verbose)\n\n        if self._importer_enable:\n            _importer._enable(\n                # When re-enabling mock torch, from_cli should always be False\n                **self._mock_enable_config,\n            )\n\n\nclass disable:\n    def __init__(self):\n        self._importer_enable = _importer.enable\n        if not self._importer_enable:\n            return\n\n        self.globals = currentframe().f_back.f_globals\n        self.lazy = _importer.lazy\n        self.verbose = _importer.verbose\n        self._mock_enable_config = _importer._mock_enable_config\n        _importer._disable(_importer.globals, verbose=self.verbose)\n\n    def __enter__(self):\n        pass\n\n    def __exit__(self, exception_type, exception_value, traceback):\n        if self._importer_enable:\n            _importer._enable(\n                # When re-enabling mock torch, from_cli should always be False\n                **self._mock_enable_config,\n            )\n\n\ndef is_enabled():\n    return _importer.enable\n"
  },
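Typical use of the `enable`/`disable` context managers above (see the cookies doc linked in `enable.__doc__`): inside the block, `import torch` resolves to oneflow, wrapped so that missing attributes either raise or, with `lazy=True`, fall back to `DummyModule`:

```python
from oneflow.mock_torch import enable, disable

with enable(lazy=True, verbose=False):
    import torch  # actually a ModuleWrapper around oneflow

    x = torch.ones(2, 3)          # a oneflow tensor
    y = torch.definitely_missing  # lazy=True -> DummyModule fallback
    assert not y                  # DummyModule is falsy

with disable():
    import torch  # the real torch again (if it is installed)
```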
  {
    "path": "python/oneflow/mock_torch/mock_modules.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom types import ModuleType\n\n__all__ = [\"MockModuleDict\", \"DummyModule\"]\n\n\nclass MockModuleDict:\n    def __init__(self, mapping=None):\n        if mapping is not None and not isinstance(mapping, dict):\n            raise ValueError(\"Extra mock library must be a dict.\")\n        self.forward = {}\n        self.inverse = {}\n        if mapping is not None:\n            for key, value in mapping.items():\n                self.add(key, value)\n\n    def add(self, key, value):\n        \"\"\"mock key thorugh value.\"\"\"\n        if key in self.forward or value in self.inverse:\n            raise ValueError(\"Key or value already exists.\")\n        self.forward[key] = value\n        self.inverse[value] = key\n\n    def remove(self, key=None, value=None):\n        if key is not None:\n            value = self.forward.pop(key)\n            self.inverse.pop(value)\n        elif value is not None:\n            key = self.inverse.pop(value)\n            self.forward.pop(key)\n        else:\n            raise ValueError(\"Must provide a key or value to remove.\")\n\n    def in_forward_dict(self, s):\n        return s.split(\".\")[0] in self.forward.keys()\n\n    def in_inverse_dict(self, s):\n        return s.split(\".\")[0] in self.inverse.keys()\n\n    def inverse_name(self, s: str):  # s: spec.name\n        return self.inverse[s.split(\".\")[0]] + s[len(s.split(\".\")[0]) :]\n\n    def forward_name(self, s: str):\n        return self.forward[s.split(\".\")[0]] + s[len(s.split(\".\")[0]) :]\n\n\nclass DummyModule(ModuleType):\n    def __init__(self, name, verbose=False):\n        super().__init__(name)\n        self._verbose = verbose\n\n    def __getattr__(self, name):\n        if self._verbose:\n            print(\n                f'\"{self.__name__}\" is a dummy object, and its attr \"{name}\" is accessed.'\n            )\n        if name == \"__path__\":\n            return None\n        if name == \"__all__\":\n            return []\n        if name == \"__file__\":\n            return None\n        if name == \"__mro_entries__\":\n            return lambda x: ()\n\n        return DummyModule(self.__name__ + \".\" + name, self._verbose)\n\n    def __getitem__(self, name):\n        new_name = f\"{self.__name__}[{name}]\"\n        if isinstance(name, int):\n            if self._verbose:\n                print(\n                    f'\"{self.__name__}\" is a dummy object, and `{new_name}` is called. 
Raising an IndexError to simulate an empty list.'\n                )\n            raise IndexError\n        if self._verbose:\n            print(f'\"{self.__name__}\" is a dummy object, and `{new_name}` is called.')\n        return DummyModule(new_name, self._verbose)\n\n    def __call__(self, *args, **kwargs):\n        new_name = f'{self.__name__}({\", \".join(map(repr, args))}, {\", \".join([\"{}={}\".format(k, repr(v)) for k, v in kwargs.items()])})'\n        if self._verbose:\n            print(f'\"{self.__name__}\" is a dummy object, and `{new_name}` is called.')\n        return DummyModule(new_name, self._verbose)\n\n    def __bool__(self):\n        if self._verbose:\n            print(\n                f'\"{self.__name__}\" is a dummy object, and its bool value is accessed.'\n            )\n        return False\n\n    def __enter__(self):\n        raise RuntimeError(\n            f'\"{self.__name__}\" is a dummy object, and does not support \"with\" statement.'\n        )\n\n    def __exit__(self, exception_type, exception_value, traceback):\n        raise RuntimeError(\n            f'\"{self.__name__}\" is a dummy object, and does not support \"with\" statement.'\n        )\n\n    def __subclasscheck__(self, subclass):\n        return False\n\n    def __instancecheck__(self, instance):\n        return False\n"
  },
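Both helpers are easy to exercise standalone: `MockModuleDict` keeps the bidirectional torch-to-oneflow name mapping used by the importer, and `DummyModule` absorbs attribute, call, and index access while staying falsy:

```python
from oneflow.mock_torch.mock_modules import MockModuleDict, DummyModule

d = MockModuleDict({"torchvision": "flowvision"})
d.add("torch", "oneflow")
assert d.in_forward_dict("torch.nn.functional")
assert d.forward_name("torch.nn.functional") == "oneflow.nn.functional"
assert d.inverse_name("oneflow.nn") == "torch.nn"

dummy = DummyModule("torch.fake")
assert not dummy                    # __bool__ -> False
chained = dummy.anything.more       # chained access returns another dummy
assert type(chained) is DummyModule
```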
  {
    "path": "python/oneflow/mock_torch/mock_utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport sys\nimport sysconfig\nimport pkgutil\nfrom collections import deque\nfrom importlib import import_module\n\nif sys.version_info <= (3, 8):\n    try:\n        from importlib_metadata import requires\n    except ImportError:\n        import subprocess\n\n        subprocess.check_call(\"pip install importlib_metadata\", shell=True)\n        subprocess.check_call(\"pip install packaging\", shell=True)\nelse:\n    from importlib.metadata import requires\n\nfrom packaging.requirements import Requirement\nfrom pathlib import Path\nfrom functools import lru_cache\nfrom typing import List, Optional\nfrom types import ModuleType\n\n\n__all__ = [\"MockEnableDisableMixin\"]\n\n\nclass PackageDependencyMixin:\n    \"\"\"Get all dependencies of a package filtered by a list of dependencies.\n\n    Example:\n        >>> import diffusers #  version 0.24.0\n        >>> op = PackageDependencyMixin()\n        >>> result = op.has_dependencies(\"diffusers\", [\"torch\"])\n        >>> print(result)\n        ['huggingface_hub', 'diffusers']\n    \"\"\"\n\n    pkg_cache = {}  # {pkg: [deps]}\n\n    @staticmethod\n    def find_matching_dependencies(\n        main_pkg: str, dependencies: List[str], max_visits=1000\n    ) -> List[str]:\n        @lru_cache()\n        def python_stdlib_packages():\n            # current python stdlib path\n            stdlib_path = sysconfig.get_paths()[\"stdlib\"]\n\n            # use pkgutil to list all modules in the standard library\n            python_modules = [\n                name for _, name, _ in pkgutil.iter_modules([stdlib_path])\n            ]\n\n            # combine built-in module names and Python modules\n            all_modules = list(sys.builtin_module_names) + python_modules\n\n            return all_modules\n\n        def format_package_name(pkg: str):\n            return Requirement(pkg).name.replace(\"-\", \"_\")\n\n        @lru_cache()\n        def get_requirements(pkg: str):\n\n            python_modules = python_stdlib_packages()\n            if pkg in python_modules:\n                return []\n            try:\n                direct_dependencies = requires(pkg)\n                if len(direct_dependencies) == 0:\n                    return []\n\n                result = set()\n                for pkg in direct_dependencies:\n                    pkg = format_package_name(pkg)\n                    if pkg == main_pkg:\n                        continue\n\n                    if pkg not in python_modules:\n                        result.add(pkg)\n\n                return list(result)\n\n            except:\n                return []\n\n        def is_leaf_package(pkg) -> bool:\n            if pkg in dependencies:\n                return True\n\n            return len(get_requirements(pkg)) == 0\n\n        main_pkg = format_package_name(main_pkg)\n\n        # build graph\n        graph = {}  # {dep: [pkg1, pkg2, ...]}\n        queue = 
deque([main_pkg])\n        visited = set()\n        stops = set()\n        while queue:\n            pkg = queue.popleft()\n            if is_leaf_package(pkg):\n                stops.add(pkg)\n                continue\n            if pkg in visited:\n                continue\n            visited.add(pkg)\n            if len(visited) > max_visits:\n                print(\n                    f\"\\033[1;33mWARNING: max_visits {max_visits} reached, stop searching.\\033[0m\"\n                )\n                break\n\n            for req in get_requirements(pkg):\n                graph.setdefault(req, set()).add(pkg)\n                queue.append(req)\n\n        # init cache and queue\n        cache = {}\n        visited.clear()\n        queue = deque(stops)\n        for pkg in stops:\n            cache[pkg] = True if pkg in dependencies else False\n\n        # bfs_from_stops\n        while queue:\n            pkg = queue.popleft()\n            if pkg in visited:\n                continue\n            visited.add(pkg)\n\n            for dep in graph.get(pkg, set()):\n                is_ok = cache.get(dep, False)\n                if cache[pkg] or is_ok:\n                    is_ok = True\n                cache[dep] = is_ok\n                queue.append(dep)\n\n        return [pkg for pkg, is_ok in cache.items() if is_ok]\n\n    @staticmethod\n    def varify_input(main_pkg, dependencies, callback, verbose=False):\n        try:\n            requires(main_pkg)\n        except:\n            if verbose:\n                print(\n                    f\"WARNING: main_pkg {main_pkg} has no meta information, please check if it is a valid package.\"\n                )\n                print(\"will set it as its own dependency to avoid error.\")\n            PackageDependencyMixin.pkg_cache[main_pkg] = [main_pkg] + dependencies\n\n        if not isinstance(main_pkg, str):\n            raise ValueError(\"main_pkg must be a string.\")\n        if not isinstance(dependencies, list):\n            raise ValueError(\"dependencies must be a list.\")\n        if not all([isinstance(dep, str) for dep in dependencies]):\n            raise ValueError(\"dependencies must be a list of strings.\")\n        if callback is not None and not callable(callback):\n            raise ValueError(\"callback must be a callable.\")\n\n    @classmethod\n    def has_dependencies(\n        self,\n        main_pkg: str,\n        dependencies: List[str],\n        callback: callable = None,\n        *,\n        verbose=False,\n    ) -> List[str]:\n        \"\"\"Check if a package has any dependencies in a list of dependencies.\"\"\"\n        PackageDependencyMixin.varify_input(main_pkg, dependencies, callback, verbose)\n\n        deps = PackageDependencyMixin.pkg_cache.get(main_pkg, None)\n        if deps is None:\n            deps = PackageDependencyMixin.find_matching_dependencies(\n                main_pkg, dependencies\n            )\n            PackageDependencyMixin.pkg_cache.update({main_pkg: deps})\n\n        if verbose:\n            print(\"PackageDependencyMixin : main_pkg=\", main_pkg, \", deps=\", deps)\n\n        if callback:\n            return callback(deps)\n        else:\n            return deps\n\n\nclass VersionMixin:\n    version_cache = {}\n\n    def mock_version(self, module_a: ModuleType, module_b: ModuleType):\n        \"\"\"Mock the version of module_a with the version of module_b.\"\"\"\n        if isinstance(module_a, str):\n            module_a = import_module(module_a)\n        if isinstance(module_b, 
str):\n            module_b = import_module(module_b)\n\n        attr_name = \"__version__\"\n        orig_attr = getattr(module_a, attr_name, None)\n        setattr(module_a, attr_name, getattr(module_b, attr_name, None))\n        VersionMixin.version_cache.update({module_a: (attr_name, orig_attr)})\n\n    def restore_version(self):\n        for module, (attr_name, orig_attr) in self.version_cache.items():\n            setattr(module, attr_name, orig_attr)\n        VersionMixin.version_cache.clear()\n\n\nclass MockEnableDisableMixin(PackageDependencyMixin, VersionMixin):\n    \"\"\"Mock torch package using  OneFlow.\"\"\"\n\n    # list of hazardous modules that may cause issues, handle with care\n    hazard_list = [\n        \"_distutils_hack\",\n        \"importlib\",\n        \"regex\",\n        \"tokenizers\",\n        \"safetensors._safetensors_rust\",\n    ]\n\n    def is_safe_module(self, module_key):\n        k = module_key\n        hazard_list = MockEnableDisableMixin.hazard_list\n\n        name = k if \".\" not in k else k[: k.find(\".\")]\n        if name in hazard_list or k in hazard_list:\n            return False\n        return True\n\n    def mock_enable(\n        self,\n        globals,  # parent_globals\n        module_dict,  # MockModuleDict object\n        *,\n        main_pkg: Optional[str] = None,\n        mock_version: Optional[str] = None,\n        required_dependencies: List[str],\n        from_cli=False,\n        verbose=False,\n        **kwargs,\n    ):\n        \"\"\"Mock torch package using  OneFlow.\n\n        Args:\n            `globals`: The globals() of the parent module.\n\n            `module_dict`:  MockModuleDict object.\n\n            `main_pkg`: The main package to mock.\n\n            `required_dependencies`: The dependencies to mock for the `main_pkg`.\n        \"\"\"\n        if mock_version:\n            mock_map = module_dict.forward\n            for pkg, mock_pkg in mock_map.items():\n                self.mock_version(pkg, mock_pkg)\n\n        if not hasattr(self, \"enable_mod_cache\"):\n            self.enable_mod_cache = {}\n        if not hasattr(self, \"disable_mod_cache\"):\n            self.disable_mod_cache = {}\n        if not hasattr(self, \"mock_safety_packages\"):\n            self.mock_safety_packages = set()\n\n        if main_pkg:\n            # Analyze the dependencies of the main package\n            cur_sys_modules = sys.modules.copy()\n            existing_deps = self.has_dependencies(\n                main_pkg,\n                dependencies=required_dependencies,\n                callback=lambda x: [dep for dep in x if dep in cur_sys_modules],\n                verbose=verbose,\n            )\n            if verbose:\n                print(\n                    \"Existing dependencies of \",\n                    \"main_pkg: \",\n                    main_pkg,\n                    \"existing_deps: \",\n                    existing_deps,\n                )\n\n            self.mock_safety_packages.update(existing_deps)\n\n        # disable non-safe modules loaded before mocking\n        def can_disable_mod_cache(k):  # module_key\n            if not self.is_safe_module(k):\n                return False\n            if module_dict.in_forward_dict(k):\n                return True\n            for dep_pkg in self.mock_safety_packages:\n                if k.startswith(dep_pkg + \".\") or k == dep_pkg:\n                    return True\n            return False\n\n        for k, v in sys.modules.copy().items():\n            
exclude_torch_from_cli = not (from_cli and k == \"torch\")\n            if not exclude_torch_from_cli:  # torch is imported from CLI\n                continue\n\n            if can_disable_mod_cache(k):\n                aliases = [alias for alias, value in globals.items() if value is v]\n                self.disable_mod_cache.update({k: (v, aliases)})\n                del sys.modules[k]\n                for alias in aliases:\n                    del globals[alias]\n\n        # restore modules loaded during mocking\n        for k, (v, aliases) in self.enable_mod_cache.items():\n            sys.modules.update({k: v})\n            for alias in aliases:\n                globals.update({alias: v})\n\n    def mock_disable(self, globals, module_dict, delete_list, from_cli=False):\n        \"\"\"Disable the mocked packages.\"\"\"\n        if not hasattr(self, \"enable_mod_cache\") or not hasattr(\n            self, \"disable_mod_cache\"\n        ):\n            RuntimeError(\"Please call mock_enable() first.\")\n\n        # disable modules loaded during mocking\n        def can_enable_mod_cache(k):  # module_key\n            if not self.is_safe_module(k):\n                return False\n            if module_dict.in_forward_dict(k):\n                return True\n            return k in delete_list\n\n        for k, v in sys.modules.copy().items():\n            if can_enable_mod_cache(k):\n                aliases = [alias for alias, value in globals.items() if value is v]\n                self.enable_mod_cache.update({k: (v, aliases)})\n                del sys.modules[k]\n                for alias in aliases:\n                    del globals[alias]\n\n        # restore modules loaded during before mocking\n        for k, (v, aliases) in self.disable_mod_cache.items():\n            sys.modules.update({k: v})\n            for alias in aliases:\n                globals.update({alias: v})\n\n        if from_cli:\n            torch_env = Path(__file__).parent\n            if str(torch_env) in sys.path:\n                sys.path.remove(str(torch_env))\n\n        self.restore_version()\n"
  },
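The `find_matching_dependencies` routine above is a two-pass BFS: a forward pass expands the requirement tree of `main_pkg` while recording reverse edges, and a backward pass propagates a "reaches one of the target dependencies" flag from the leaves back up. A minimal, self-contained sketch of the same idea over a hard-coded requirement graph (the package names and the `REQUIRES` table are made up; the real code reads this from package metadata via `requires()`):

```python
from collections import deque

# Hypothetical requirement graph: package -> direct requirements.
REQUIRES = {
    "diffusers": ["huggingface_hub", "numpy"],
    "huggingface_hub": ["torch", "requests"],
    "numpy": [], "requests": [], "torch": [],
}

def packages_reaching(main_pkg, targets):
    # Pass 1: BFS from main_pkg, building the reverse edge set.
    rdeps, queue, seen, leaves = {}, deque([main_pkg]), set(), set()
    while queue:
        pkg = queue.popleft()
        reqs = REQUIRES.get(pkg, [])
        if pkg in targets or not reqs:
            leaves.add(pkg)
            continue
        if pkg in seen:
            continue
        seen.add(pkg)
        for req in reqs:
            rdeps.setdefault(req, set()).add(pkg)
            queue.append(req)

    # Pass 2: BFS from the leaves; a package "reaches" a target if it is a
    # target itself or any of its requirements reaches one.
    reach = {pkg: pkg in targets for pkg in leaves}
    queue, seen = deque(leaves), set()
    while queue:
        pkg = queue.popleft()
        if pkg in seen:
            continue
        seen.add(pkg)
        for parent in rdeps.get(pkg, set()):
            reach[parent] = reach.get(parent, False) or reach[pkg]
            queue.append(parent)
    return [pkg for pkg, ok in reach.items() if ok]

print(packages_reaching("diffusers", {"torch"}))
# 'torch', 'huggingface_hub', 'diffusers' in some order
```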
  {
    "path": "python/oneflow/mock_torch/torch/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport sys\nimport oneflow\nfrom oneflow.mock_torch import ModuleWrapper, enable\n\n\ndef __getattr__(name: str):\n    return ModuleWrapper(oneflow).__getattr__(name)\n\n\nenable(_from_cli=True)\n"
  },
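The stub `torch/__init__.py` above leans on PEP 562 module-level `__getattr__`: any attribute not defined in the fake `torch` module is resolved at lookup time through `ModuleWrapper(oneflow)`. A tiny stand-alone illustration of the mechanism, with `math` playing the role of the delegated-to backend (the file name and backend choice are hypothetical):

```python
# fake_torch.py -- a stand-in module demonstrating PEP 562 (Python >= 3.7)
import math as _backend  # pretend `math` is the backend we delegate to

def __getattr__(name: str):
    # Called only when `name` is not found in this module's own globals.
    try:
        return getattr(_backend, name)
    except AttributeError:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

# >>> import fake_torch
# >>> fake_torch.sqrt(4.0)   # resolved via __getattr__ -> math.sqrt
# 2.0
```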
  {
    "path": "python/oneflow/model.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.framework.model import Callback, CheckpointConfig, DataModule\nfrom oneflow.framework.model import Model as Model\nfrom oneflow.framework.model import NumpyDataModule, TrainingConfig, ValidationConfig\n"
  },
  {
    "path": "python/oneflow/multiprocessing/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\"\"\"\noneflow.multiprocessing is a wrapper around the native :mod:`multiprocessing`\nmodule. It registers custom reducers, that use shared memory to provide shared\nviews on the same data in different processes. Once the tensor/storage is moved\nto shared_memory (see :func:`~oneflow.Tensor.share_memory_`), it will be possible\nto send it to other processes without making any copies.\n\nThe API is 100% compatible with the original module - it's enough to change\n``import multiprocessing`` to ``import oneflow.multiprocessing`` to have all the\ntensors sent through the queues or shared via other mechanisms, moved to shared\nmemory.\n\nBecause of the similarity of APIs we do not document most of this package\ncontents, and we recommend referring to very good docs of the original module.\n\"\"\"\nimport oneflow as flow\nimport sys\nfrom .reductions import init_reductions\nimport multiprocessing\n\n__all__ = [\n    \"set_sharing_strategy\",\n    \"get_sharing_strategy\",\n    \"get_all_sharing_strategies\",\n    \"unlink_all_shared_memory\",\n]\n\n\nfrom multiprocessing import *  # noqa: F403\n\n\n__all__ += multiprocessing.__all__  # type: ignore[attr-defined]\n\n\n# This call adds a Linux specific prctl(2) wrapper function to this module.\n# See https://github.com/pytorch/pytorch/pull/14391 for more information.\nflow._oneflow_internal._multiprocessing_init()\n\n\n\"\"\"Add helper function to spawn N processes and wait for completion of any of\nthem. This depends `mp.get_context` which was added in Python 3.4.\"\"\"\nfrom .spawn import (\n    spawn,\n    SpawnContext,\n    start_processes,\n    ProcessContext,\n    ProcessRaisedException,\n    ProcessExitedException,\n)\n\n\nif sys.platform == \"darwin\" or sys.platform == \"win32\":\n    _sharing_strategy = \"file_system\"\n    _all_sharing_strategies = {\"file_system\"}\nelse:\n    _sharing_strategy = \"file_descriptor\"\n    _all_sharing_strategies = {\"file_descriptor\", \"file_system\"}\n\n\ndef set_sharing_strategy(new_strategy):\n    \"\"\"Sets the strategy for sharing CPU tensors.\n\n    Args:\n        new_strategy (str): Name of the selected strategy. Should be one of\n            the values returned by :func:`get_all_sharing_strategies()`.\n    \"\"\"\n    global _sharing_strategy\n    assert new_strategy in _all_sharing_strategies\n    _sharing_strategy = new_strategy\n\n\ndef get_sharing_strategy():\n    \"\"\"Returns the current strategy for sharing CPU tensors.\"\"\"\n    return _sharing_strategy\n\n\ndef get_all_sharing_strategies():\n    \"\"\"Returns a set of sharing strategies supported on a current system.\"\"\"\n    return _all_sharing_strategies\n\n\ndef unlink_all_shared_memory():\n    flow._oneflow_internal.multiprocessing.unlink_all_shared_memory()\n\n\ninit_reductions()\n"
  },
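A short usage sketch of the wrapper module, assuming `oneflow` is installed; `Process` and `Queue` come from the `from multiprocessing import *` re-export, and the strategy accessors are the functions defined above:

```python
import oneflow.multiprocessing as mp

def worker(q):
    q.put("done")

if __name__ == "__main__":
    # "file_system" is always available; "file_descriptor" is Linux-only.
    mp.set_sharing_strategy("file_system")
    assert mp.get_sharing_strategy() == "file_system"

    q = mp.Queue()
    p = mp.Process(target=worker, args=(q,))
    p.start()
    print(q.get())  # "done"
    p.join()
```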
  {
    "path": "python/oneflow/multiprocessing/_atfork.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport sys\n\n__all__ = [\"register_after_fork\"]\n\nif sys.platform == \"win32\" or sys.version_info < (3, 7):\n    import multiprocessing.util as _util\n\n    def _register(func):\n        def wrapper(arg):\n            func()\n\n        _util.register_after_fork(_register, wrapper)\n\n\nelse:\n    import os\n\n    def _register(func):\n        os.register_at_fork(after_in_child=func)\n\n\ndef register_after_fork(func):\n    \"\"\"Register a callable to be executed in the child process after a fork.\n\n    Note:\n        In python < 3.7 this will only work with processes created using the\n        ``multiprocessing`` module. In python >= 3.7 it also works with\n        ``os.fork()``.\n\n    Args:\n        func (function): Function taking no arguments to be called in the child after fork\n\n    \"\"\"\n    _register(func)\n"
  },
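A hedged usage sketch of `register_after_fork` on the POSIX path (Python >= 3.7, non-Windows), assuming `oneflow` is installed; the registered callable runs in the child immediately after `os.fork()` returns there:

```python
import os
import sys
from oneflow.multiprocessing._atfork import register_after_fork

def reinit():
    # e.g. reseed RNGs or reset handles that must not be shared with the parent
    print(f"child {os.getpid()}: re-initializing per-process state")

if __name__ == "__main__" and sys.platform != "win32":
    register_after_fork(reinit)  # fires in every forked child
    pid = os.fork()
    if pid == 0:       # child: `reinit` has already run by this point
        os._exit(0)
    os.waitpid(pid, 0)
```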
  {
    "path": "python/oneflow/multiprocessing/pool.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport multiprocessing.pool\nimport multiprocessing.util as util\n\nfrom .queue import SimpleQueue\n\n\ndef clean_worker(*args, **kwargs):\n    import gc\n\n    multiprocessing.pool.worker(*args, **kwargs)\n    # Regular multiprocessing workers don't fully clean up after themselves,\n    # so we have to explicitly trigger garbage collection to make sure that all\n    # destructors are called...\n    gc.collect()\n\n\nclass Pool(multiprocessing.pool.Pool):\n    \"\"\"Pool implementation which uses our version of SimpleQueue.\n    This lets us pass tensors in shared memory across processes instead of\n    serializing the underlying data.\"\"\"\n\n    def _setup_queues(self):\n        self._inqueue = SimpleQueue()\n        self._outqueue = SimpleQueue()\n        self._quick_put = self._inqueue._writer.send\n        self._quick_get = self._outqueue._reader.recv\n\n    def _repopulate_pool(self):\n        \"\"\"Bring the number of pool processes up to the specified number,\n        for use after reaping workers which have exited.\n        \"\"\"\n        for i in range(self._processes - len(self._pool)):\n            # changed worker -> clean_worker\n            args = (\n                self._inqueue,\n                self._outqueue,\n                self._initializer,\n                self._initargs,\n                self._maxtasksperchild,\n            )\n            if hasattr(self, \"_wrap_exception\"):\n                args += (self._wrap_exception,)\n            w = self.Process(target=clean_worker, args=args)\n            self._pool.append(w)\n            w.name = w.name.replace(\"Process\", \"PoolWorker\")\n            w.daemon = True\n            w.start()\n            util.debug(\"added worker\")\n"
  },
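`clean_worker` exists to force a garbage-collection pass after the stock worker loop finishes, so destructors run before the child exits. The same wrapping trick, reduced to a self-contained stdlib sketch (`WithGC` and `work` are illustrative names; the wrapper is a picklable class so it also survives the 'spawn' start method):

```python
import gc
import multiprocessing as mp

class WithGC:
    # Picklable wrapper around a process target.
    def __init__(self, target):
        self.target = target

    def __call__(self, *args, **kwargs):
        try:
            self.target(*args, **kwargs)
        finally:
            # Run destructors, even for objects caught in reference cycles,
            # before the worker process exits.
            gc.collect()

def work(n):
    print("sum:", sum(range(n)))

if __name__ == "__main__":
    p = mp.Process(target=WithGC(work), args=(1000,))
    p.start()
    p.join()
```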
  {
    "path": "python/oneflow/multiprocessing/queue.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport io\nimport multiprocessing.queues\nfrom multiprocessing.reduction import ForkingPickler\nimport pickle\n\n\nclass ConnectionWrapper(object):\n    \"\"\"Proxy class for _multiprocessing.Connection which uses ForkingPickler to\n    serialize objects\"\"\"\n\n    def __init__(self, conn):\n        self.conn = conn\n\n    def send(self, obj):\n        buf = io.BytesIO()\n        ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj)\n        self.send_bytes(buf.getvalue())\n\n    def recv(self):\n        buf = self.recv_bytes()\n        return pickle.loads(buf)\n\n    def __getattr__(self, name):\n        if \"conn\" in self.__dict__:\n            return getattr(self.conn, name)\n        raise AttributeError(\n            \"'{}' object has no attribute '{}'\".format(type(self).__name__, \"conn\")\n        )\n\n\nclass Queue(multiprocessing.queues.Queue):\n    def __init__(self, *args, **kwargs):\n        super(Queue, self).__init__(*args, **kwargs)\n        self._reader: ConnectionWrapper = ConnectionWrapper(self._reader)\n        self._writer: ConnectionWrapper = ConnectionWrapper(self._writer)\n        self._send = self._writer.send\n        self._recv = self._reader.recv\n\n\nclass SimpleQueue(multiprocessing.queues.SimpleQueue):\n    def _make_methods(self):\n        if not isinstance(self._reader, ConnectionWrapper):\n            self._reader: ConnectionWrapper = ConnectionWrapper(self._reader)\n            self._writer: ConnectionWrapper = ConnectionWrapper(self._writer)\n        super(SimpleQueue, self)._make_methods()  # type: ignore[misc]\n"
  },
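`ConnectionWrapper` reroutes `send`/`recv` through `ForkingPickler`, which is what lets the reducers registered in `reductions.py` take over when tensors cross the pipe. A self-contained sketch of the proxy over a stdlib `Pipe` (no oneflow involved):

```python
import io
import pickle
from multiprocessing import Pipe
from multiprocessing.reduction import ForkingPickler

class ConnectionWrapper:
    def __init__(self, conn):
        self.conn = conn

    def send(self, obj):
        # Serialize with ForkingPickler so registered reducers apply.
        buf = io.BytesIO()
        ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj)
        self.conn.send_bytes(buf.getvalue())

    def recv(self):
        return pickle.loads(self.conn.recv_bytes())

if __name__ == "__main__":
    r, w = Pipe(duplex=False)
    reader, writer = ConnectionWrapper(r), ConnectionWrapper(w)
    writer.send({"step": 1, "loss": 0.25})
    print(reader.recv())  # {'step': 1, 'loss': 0.25}
```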
  {
    "path": "python/oneflow/multiprocessing/reductions.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom multiprocessing.reduction import ForkingPickler\n\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow.nn.parameter import Parameter\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.multiprocessing import shared_memory\n\n\ntry:\n    # Early load resource_sharer to prevent a partially initialized instance\n    # from being inherited in a forked child process. The reduce_storage method\n    # requires this module indirectly through DupFd(). The built-in mp.Queue\n    # class pickles arguments in a background thread which may overlap with the\n    # fork.\n    import multiprocessing.resource_sharer\nexcept ImportError:\n    pass\n\n\ndef rebuild_empty_tensor(shape, dtype, requires_grad):\n    t = flow.tensor([], dtype=dtype)\n    t.requires_grad = requires_grad\n    return t.reshape(*shape)\n\n\ndef rebuild_shm_tensor(shm, shape, dtype, requires_grad):\n    def delete_shm():\n        try:\n            # For unknown reasons delete_shm called in dataloader may fail\n            # with \"StopIteration\".\n            # An example is when dataloader is wrapped in a generator like\n            # `log_every`.\n            shm.close()\n            shm.unlink()\n        except:\n            pass\n\n    arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)\n    t = flow.from_numpy(arr)\n    t._register_storage_delete_hook(delete_shm)\n    t.requires_grad = requires_grad\n\n    return t\n\n\ndef rebuild_empty_parameter(shape, dtype, requires_grad):\n    t = flow.tensor([], dtype=dtype)\n    t = t.reshape(*shape)\n    return Parameter(t, requires_grad=requires_grad)\n\n\ndef rebuild_shm_parameter(shm, shape, dtype, requires_grad):\n    def delete_shm():\n        shm.close()\n        shm.unlink()\n\n    arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)\n    t = flow.from_numpy(arr)\n    t._register_storage_delete_hook(delete_shm)\n    return Parameter(t, requires_grad=requires_grad)\n\n\ndef reduce_tensor(tensor):\n    tensor_data = tensor.numpy()\n    requires_grad = tensor.requires_grad\n\n    if tensor_data.nbytes == 0:\n        return (rebuild_empty_tensor, (tensor.shape, tensor.dtype, requires_grad))\n    else:\n        shm = shared_memory.SharedMemory(create=True, size=tensor_data.nbytes)\n        shm_numpy = np.ndarray(\n            tensor_data.shape, dtype=tensor_data.dtype, buffer=shm.buf\n        )\n        shm_numpy[:] = tensor_data[:]\n        return (\n            rebuild_shm_tensor,\n            (shm, tensor_data.shape, tensor_data.dtype, requires_grad),\n        )\n\n\ndef reduce_parameter(tensor):\n    tensor_data = tensor.numpy()\n    requires_grad = tensor.requires_grad\n\n    if tensor_data.nbytes == 0:\n        return (rebuild_empty_parameter, (tensor, shape, tensor.dtype, requires_grad))\n    else:\n        shm = shared_memory.SharedMemory(create=True, size=tensor_data.nbytes)\n        shm_numpy = np.ndarray(\n            tensor_data.shape, 
dtype=tensor_data.dtype, buffer=shm.buf\n        )\n        shm_numpy[:] = tensor_data[:]\n        return (\n            rebuild_shm_parameter,\n            (shm, tensor_data.shape, tensor_data.dtype, requires_grad),\n        )\n\n\ndef init_reductions():\n    ForkingPickler.register(Tensor, reduce_tensor)\n    ForkingPickler.register(flow._oneflow_internal.Tensor, reduce_tensor)\n    ForkingPickler.register(Parameter, reduce_parameter)\n    ForkingPickler.register(flow._oneflow_internal.nn.Parameter, reduce_parameter)\n"
  },
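The reduce/rebuild pairs above follow the standard `ForkingPickler` protocol: the sender returns `(rebuild_fn, args)`, and the receiver calls `rebuild_fn(*args)` to re-attach to the shared block instead of copying bytes through the pipe. A stand-alone sketch of the same pattern using plain numpy arrays and the stdlib `multiprocessing.shared_memory` (Python >= 3.8) as stand-ins for tensors:

```python
import numpy as np
from multiprocessing import shared_memory
from multiprocessing.reduction import ForkingPickler  # noqa: F401

def rebuild_shm_array(shm, shape, dtype):
    # Receiving side: alias the shared block, no copy.
    return np.ndarray(shape, dtype=dtype, buffer=shm.buf)

def reduce_array(arr):
    # Sending side: copy once into shared memory, then ship only the
    # (rebuild_fn, args) pair -- the pickle itself stays tiny.
    shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
    np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)[:] = arr
    return rebuild_shm_array, (shm, arr.shape, arr.dtype)

if __name__ == "__main__":
    # Registering would make every ndarray sent over mp queues use this:
    # ForkingPickler.register(np.ndarray, reduce_array)
    a = np.arange(6, dtype=np.float32).reshape(2, 3)
    fn, args = reduce_array(a)   # what the pickler would transmit
    b = fn(*args)                # what the peer reconstructs
    print(np.array_equal(a, b))  # True

    # Real code ties cleanup to a storage delete hook; here, clean up manually.
    shm = args[0]
    del b          # drop the alias before releasing the block
    shm.close()
    shm.unlink()
```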
  {
    "path": "python/oneflow/multiprocessing/shared_memory/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\n\n__all__ = [\"SharedMemory\"]\n\n\nclass SharedMemory:\n    def __init__(self, name=None, create=False, size=0):\n        if not size >= 0:\n            raise ValueError(\"'size' must be a non-negative integer\")\n        if create:\n            if size == 0:\n                raise ValueError(\"'size' must be a positive number different from zero\")\n        self.shm_ = flow._oneflow_internal.multiprocessing.SharedMemory(\n            name=name if name is not None else \"\", create=create, size=size\n        )\n\n    def __del__(self):\n        try:\n            if hasattr(self, \"shm_\"):\n                self.close()\n        except OSError:\n            pass\n\n    def __reduce__(self):\n        return (\n            self.__class__,\n            (self.name, False, self.size,),\n        )\n\n    def __repr__(self):\n        return f\"{self.__class__.__name__}({self.name!r}, size={self.size})\"\n\n    @property\n    def buf(self):\n        \"A memoryview of contents of the shared memory block.\"\n        return self.shm_.buf\n\n    @property\n    def name(self):\n        \"Unique name that identifies the shared memory block.\"\n        return self.shm_.name\n\n    @property\n    def size(self):\n        \"Size in bytes.\"\n        return self.shm_.size\n\n    def close(self):\n        \"\"\"Closes access to the shared memory from this instance but does\n        not destroy the shared memory block.\"\"\"\n        return self.shm_.close()\n\n    def unlink(self):\n        \"\"\"Requests that the underlying shared memory block be destroyed.\n        In order to ensure proper cleanup of resources, unlink should be\n        called once (and only once) across all processes which have access\n        to the shared memory block.\"\"\"\n        return self.shm_.unlink()\n"
  },
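This class mirrors the surface of Python 3.8's `multiprocessing.shared_memory.SharedMemory`, backed by OneFlow internals. The intended lifecycle, shown here with the stdlib class (same `buf`/`name`/`close`/`unlink` API):

```python
from multiprocessing import shared_memory

# Producer: create a named block and write into it.
shm = shared_memory.SharedMemory(create=True, size=16)
shm.buf[:5] = b"hello"

# Consumer (possibly another process): attach by name, read, detach.
peer = shared_memory.SharedMemory(name=shm.name)
print(bytes(peer.buf[:5]))  # b'hello'
peer.close()                # this process is done with the mapping

shm.close()
shm.unlink()                # destroy the block exactly once, in one process
```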
  {
    "path": "python/oneflow/multiprocessing/spawn.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nfrom typing import Optional\nimport multiprocessing\nimport multiprocessing.connection\nimport signal\nimport sys\nimport warnings\n\nfrom oneflow.multiprocessing import _prctl_pr_set_pdeathsig  # type: ignore[attr-defined]\n\n\nclass ProcessException(Exception):\n    __slots__ = [\"error_index\", \"error_pid\"]\n\n    def __init__(self, msg: str, error_index: int, pid: int):\n        super().__init__(msg)\n        self.error_index = error_index\n        self.pid = pid\n\n\nclass ProcessRaisedException(ProcessException):\n    \"\"\"\n    Exception is thrown when the process failed due to exception\n    raised by the code.\n    \"\"\"\n\n    def __init__(\n        self, msg: str, error_index: int, error_pid: int,\n    ):\n        super().__init__(msg, error_index, error_pid)\n\n\nclass ProcessExitedException(ProcessException):\n    \"\"\"\n    Exception is thrown when the process failed due to signal\n    or exited with a specific code.\n    \"\"\"\n\n    __slots__ = [\"exit_code\"]\n\n    def __init__(\n        self,\n        msg: str,\n        error_index: int,\n        error_pid: int,\n        exit_code: int,\n        signal_name: Optional[str] = None,\n    ):\n        super().__init__(msg, error_index, error_pid)\n        self.exit_code = exit_code\n        self.signal_name = signal_name\n\n\ndef _wrap(fn, i, args, error_queue):\n    # prctl(2) is a Linux specific system call.\n    # On other systems the following function call has no effect.\n    # This is set to ensure that non-daemonic child processes can\n    # terminate if their parent terminates before they do.\n    _prctl_pr_set_pdeathsig(signal.SIGINT)\n\n    try:\n        fn(i, *args)\n    except KeyboardInterrupt:\n        pass  # SIGINT; Killed by parent, do nothing\n    except Exception:\n        # Propagate exception to parent process, keeping original traceback\n        import traceback\n\n        error_queue.put(traceback.format_exc())\n        sys.exit(1)\n\n\nclass ProcessContext:\n    def __init__(self, processes, error_queues):\n        self.error_queues = error_queues\n        self.processes = processes\n        self.sentinels = {\n            process.sentinel: index for index, process in enumerate(processes)\n        }\n\n    def pids(self):\n        return [int(process.pid) for process in self.processes]\n\n    def join(self, timeout=None):\n        r\"\"\"\n        Tries to join one or more processes in this spawn context.\n        If one of them exited with a non-zero exit status, this function\n        kills the remaining processes and raises an exception with the cause\n        of the first process exiting.\n\n        Returns ``True`` if all processes have been joined successfully,\n        ``False`` if there are more processes that need to be joined.\n\n        Args:\n            timeout (float): Wait this long before giving up on waiting.\n        \"\"\"\n        # Ensure this function 
can be called even when we're done.\n        if len(self.sentinels) == 0:\n            return True\n\n        # Wait for any process to fail or all of them to succeed.\n        ready = multiprocessing.connection.wait(self.sentinels.keys(), timeout=timeout,)\n\n        error_index = None\n        for sentinel in ready:\n            index = self.sentinels.pop(sentinel)\n            process = self.processes[index]\n            process.join()\n            if process.exitcode != 0:\n                error_index = index\n                break\n\n        # Return if there was no error.\n        if error_index is None:\n            # Return whether or not all processes have been joined.\n            return len(self.sentinels) == 0\n\n        # Assume failure. Terminate processes that are still alive.\n        for process in self.processes:\n            if process.is_alive():\n                process.terminate()\n            process.join()\n\n        # There won't be an error on the queue if the process crashed.\n        failed_process = self.processes[error_index]\n        if self.error_queues[error_index].empty():\n            exitcode = self.processes[error_index].exitcode\n            if exitcode < 0:\n                name = signal.Signals(-exitcode).name\n                raise ProcessExitedException(\n                    \"process %d terminated with signal %s\" % (error_index, name),\n                    error_index=error_index,\n                    error_pid=failed_process.pid,\n                    exit_code=exitcode,\n                    signal_name=name,\n                )\n            else:\n                raise ProcessExitedException(\n                    \"process %d terminated with exit code %d\" % (error_index, exitcode),\n                    error_index=error_index,\n                    error_pid=failed_process.pid,\n                    exit_code=exitcode,\n                )\n\n        original_trace = self.error_queues[error_index].get()\n        msg = \"\\n\\n-- Process %d terminated with the following error:\\n\" % error_index\n        msg += original_trace\n        raise ProcessRaisedException(msg, error_index, failed_process.pid)\n\n\nclass SpawnContext(ProcessContext):\n    def __init__(self, processes, error_queues):\n        warnings.warn(\"SpawnContext is renamed to ProcessContext since 1.4 release.\")\n        super(SpawnContext, self).__init__(processes, error_queues)\n\n    pass\n\n\n# Note: [start_processes]\n# mp.start_processes handles both start_method='spawn' and 'fork'. It's supposed to be a\n# more generalized API than mp.spawn. Currently we only document mp.spawn as it's the\n# CUDA compatible start_method. However, in environments like Ipython notebooks, 'fork'\n# works better than 'spawn'. 
Every helper function we created for mp.spawn is indeed\n# general enough, and backends like XLA can reuse them in Colab notebooks as well.\n# Currently we only add this API first, we can consider adding it to documentation as\n# needed in the future.\ndef start_processes(\n    fn, args=(), nprocs=1, join=True, daemon=False, start_method=\"spawn\"\n):\n    mp = multiprocessing.get_context(start_method)\n    error_queues = []\n    processes = []\n    for i in range(nprocs):\n        error_queue = mp.SimpleQueue()\n        process = mp.Process(\n            target=_wrap, args=(fn, i, args, error_queue), daemon=daemon,\n        )\n        process.start()\n        error_queues.append(error_queue)\n        processes.append(process)\n\n    context = ProcessContext(processes, error_queues)\n    if not join:\n        return context\n\n    # Loop on join until it returns True or raises an exception.\n    while not context.join():\n        pass\n\n\ndef spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method=\"spawn\"):\n    r\"\"\"Spawns ``nprocs`` processes that run ``fn`` with ``args``.\n\n    If one of the processes exits with a non-zero exit status, the\n    remaining processes are killed and an exception is raised with the\n    cause of termination. In the case an exception was caught in the\n    child process, it is forwarded and its traceback is included in\n    the exception raised in the parent process.\n\n    Args:\n        fn (function): Function is called as the entrypoint of the\n            spawned process. This function must be defined at the top\n            level of a module so it can be pickled and spawned. This\n            is a requirement imposed by multiprocessing.\n\n            The function is called as ``fn(i, *args)``, where ``i`` is\n            the process index and ``args`` is the passed through tuple\n            of arguments.\n\n        args (tuple): Arguments passed to ``fn``.\n        nprocs (int): Number of processes to spawn.\n        join (bool): Perform a blocking join on all processes.\n        daemon (bool): The spawned processes' daemon flag. If set to True,\n                       daemonic processes will be created.\n        start_method (string): (deprecated) this method will always use ``spawn``\n                               as the start method. To use a different start method\n                               use ``start_processes()``.\n\n    Returns:\n        None if ``join`` is ``True``,\n        :class:`~ProcessContext` if ``join`` is ``False``\n\n    \"\"\"\n    if start_method != \"spawn\":\n        msg = (\n            \"This method only supports start_method=spawn (got: %s).\\n\"\n            \"To use a different start_method use:\\n\\t\\t\"\n            \" oneflow.multiprocessing.start_processes(...)\" % start_method\n        )\n        warnings.warn(msg)\n    return start_processes(fn, args, nprocs, join, daemon, start_method=\"spawn\")\n"
  },
  {
    "path": "python/oneflow/nn/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom .modules import *\nfrom oneflow.nn.graph import Graph\nfrom oneflow.nn.modules.activation import (\n    ELU,\n    CELU,\n    GELU,\n    QuickGELU,\n    SquareReLU,\n    GLU,\n    Hardsigmoid,\n    Hardshrink,\n    Hardswish,\n    Hardtanh,\n    LeakyReLU,\n    RReLU,\n    LogSigmoid,\n    LogSoftmax,\n    Mish,\n    PReLU,\n    ReLU,\n    ReLU6,\n    Sigmoid,\n    Softmax,\n    Softshrink,\n    Softplus,\n    Tanh,\n    SELU,\n    SiLU,\n    Softsign,\n    Threshold,\n)\n\nfrom oneflow.nn.modules.all_reduce import AllReduce\nfrom oneflow.nn.modules.batchnorm import (\n    BatchNorm1d,\n    BatchNorm2d,\n    BatchNorm3d,\n    SyncBatchNorm,\n)\nfrom oneflow.nn.modules.batchnorm_fused import (\n    FusedBatchNorm1d,\n    FusedBatchNorm2d,\n    FusedBatchNorm3d,\n)\nfrom oneflow.nn.modules.fused_mlp import FusedMLP\n\nfrom oneflow.nn.modules.container import (\n    ModuleDict,\n    ModuleList,\n    ParameterDict,\n    ParameterList,\n    Sequential,\n)\nfrom oneflow.nn.modules.conv import (\n    Conv1d,\n    Conv2d,\n    Conv3d,\n    ConvTranspose1d,\n    ConvTranspose2d,\n    ConvTranspose3d,\n)\nfrom oneflow.nn.modules.distance import CosineSimilarity, PairwiseDistance\nfrom oneflow.nn.modules.min_max_observer import MinMaxObserver\nfrom oneflow.nn.modules.moving_average_min_max_observer import (\n    MovingAverageMinMaxObserver,\n)\nfrom oneflow.nn.modules.fake_quantization import FakeQuantization\nfrom oneflow.nn.modules.quantization import Quantization\nfrom oneflow.nn.modules.distributed_partial_fc_sample import (\n    DistributedPariticalFCSample,\n)\n\nfrom oneflow.nn.modules.dataset import (\n    COCOReader,\n    CoinFlip,\n    CropMirrorNormalize,\n    OFRecordImageDecoder,\n    OFRecordImageDecoderRandomCrop,\n    OFRecordImageGpuDecoderRandomCropResize,\n    OFRecordRawDecoder,\n    OFRecordRawDecoder as OfrecordRawDecoder,\n    OFRecordReader,\n    OFRecordReader as OfrecordReader,\n    OFRecordBytesDecoder,\n    GPTIndexedBinDataReader,\n    RawReader,\n)\n\nfrom oneflow.nn.modules.dropout import Dropout, Dropout1d, Dropout2d, Dropout3d\nfrom oneflow.nn.modules.flatten import Flatten\nfrom oneflow.nn.modules.instancenorm import (\n    InstanceNorm1d,\n    InstanceNorm2d,\n    InstanceNorm3d,\n)\nfrom oneflow.nn.modules.linear import Identity, Linear\nfrom oneflow.nn.modules.loss import (\n    BCELoss,\n    BCEWithLogitsLoss,\n    CrossEntropyLoss,\n    CTCLoss,\n    KLDivLoss,\n    L1Loss,\n    MarginRankingLoss,\n    MSELoss,\n    NLLLoss,\n    SmoothL1Loss,\n    CombinedMarginLoss,\n    TripletMarginLoss,\n)\nfrom oneflow.nn.modules.normalization import GroupNorm, LayerNorm, RMSLayerNorm, RMSNorm\nfrom oneflow.nn.modules.padding import (\n    ConstantPad1d,\n    ConstantPad2d,\n    ConstantPad3d,\n    ReflectionPad1d,\n    ReflectionPad2d,\n    ReplicationPad1d,\n    ReplicationPad2d,\n    ZeroPad2d,\n)\nfrom oneflow.nn.modules.pixelshuffle import 
PixelShufflev2 as PixelShuffle\nfrom oneflow.nn.modules.pooling import (\n    AvgPool1d,\n    AvgPool2d,\n    AvgPool3d,\n    MaxPool1d,\n    MaxPool2d,\n    MaxPool3d,\n    MaxUnpool1d,\n    MaxUnpool2d,\n    MaxUnpool3d,\n    AdaptiveAvgPool1d,\n    AdaptiveAvgPool2d,\n    AdaptiveAvgPool3d,\n    AdaptiveMaxPool1d,\n    AdaptiveMaxPool2d,\n    AdaptiveMaxPool3d,\n)\nfrom oneflow.nn.modules.sparse import Embedding\nfrom oneflow.nn.modules.upsampling import (\n    Upsample,\n    UpsamplingBilinear2d,\n    UpsamplingNearest2d,\n)\nfrom oneflow.nn.modules.fold import Fold, Unfold\n\nfrom oneflow.nn.parameter import Parameter\nfrom oneflow.nn import utils\n\nfrom . import functional\n\nfrom . import parallel\n\nfrom oneflow.nn.modules.rnn import (\n    RNNCellBase,\n    RNNCell,\n    LSTMCell,\n    GRUCell,\n    RNNBase,\n    RNN,\n    LSTM,\n    GRU,\n)\n\nfrom oneflow.nn.qat.conv import QatConv1d, QatConv2d, QatConv3d\n\n\nclass DataParallel(Module):\n    def __init__(self):\n        raise NotImplementedError()\n"
  },
  {
    "path": "python/oneflow/nn/common_types.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Tuple, TypeVar, Union\n\nT = TypeVar(\"T\")\n_scalar_or_tuple_any_t = Union[T, Tuple[T, ...]]\n_scalar_or_tuple_1_t = Union[T, Tuple[T]]\n_scalar_or_tuple_2_t = Union[T, Tuple[T, T]]\n_scalar_or_tuple_3_t = Union[T, Tuple[T, T, T]]\n_scalar_or_tuple_4_t = Union[T, Tuple[T, T, T, T]]\n_scalar_or_tuple_5_t = Union[T, Tuple[T, T, T, T, T]]\n_scalar_or_tuple_6_t = Union[T, Tuple[T, T, T, T, T, T]]\n_size_any_t = _scalar_or_tuple_any_t[int]\n_size_1_t = _scalar_or_tuple_1_t[int]\n_size_2_t = _scalar_or_tuple_2_t[int]\n_size_3_t = _scalar_or_tuple_3_t[int]\n_size_4_t = _scalar_or_tuple_4_t[int]\n_size_5_t = _scalar_or_tuple_5_t[int]\n_size_6_t = _scalar_or_tuple_6_t[int]\n_ratio_2_t = _scalar_or_tuple_2_t[float]\n_ratio_3_t = _scalar_or_tuple_3_t[float]\n_ratio_any_t = _scalar_or_tuple_any_t[float]\n"
  },
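These aliases encode "a scalar or an N-tuple" parameters such as `kernel_size` or `stride`. A tiny illustration of how such an alias reads in a signature and how callers typically normalize it (the `_pair` helper and `conv_output_hw` are hypothetical, not part of this file):

```python
from typing import Tuple, TypeVar, Union

T = TypeVar("T")
_scalar_or_tuple_2_t = Union[T, Tuple[T, T]]
_size_2_t = _scalar_or_tuple_2_t[int]

def _pair(x: _size_2_t) -> Tuple[int, int]:
    # Normalize 3 -> (3, 3); pass tuples through unchanged.
    return x if isinstance(x, tuple) else (x, x)

def conv_output_hw(hw: Tuple[int, int], kernel_size: _size_2_t) -> Tuple[int, int]:
    kh, kw = _pair(kernel_size)
    return hw[0] - kh + 1, hw[1] - kw + 1  # stride 1, no padding

print(conv_output_hw((10, 10), 3))       # (8, 8)
print(conv_output_hw((10, 10), (3, 5)))  # (8, 6)
```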
  {
    "path": "python/oneflow/nn/functional/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.nn.modules.interpolate import interpolate, interpolate_like\nfrom oneflow.nn.modules.affine_grid import affine_grid\nfrom oneflow.nn.modules.grid_sample import grid_sample\nfrom oneflow.nn.modules.sparse_softmax_cross_entropy import sparse_softmax_cross_entropy\nfrom oneflow._C import conv1d\nfrom oneflow._C import conv2d\nfrom oneflow._C import conv3d\nfrom oneflow._C import deconv1d as conv_transpose1d\nfrom oneflow._C import deconv2d as conv_transpose2d\nfrom oneflow._C import deconv3d as conv_transpose3d\nfrom oneflow._C import avg_pool1d\nfrom oneflow._C import avg_pool2d\nfrom oneflow._C import avg_pool3d\nfrom .maxpool import max_pool1d\nfrom .maxpool import max_pool2d\nfrom .maxpool import max_pool3d\nfrom .maxpool import adaptive_max_pool1d\nfrom .maxpool import adaptive_max_pool2d\nfrom .maxpool import adaptive_max_pool3d\nfrom oneflow._C import adaptive_avg_pool1d\nfrom oneflow._C import adaptive_avg_pool2d\nfrom oneflow._C import adaptive_avg_pool3d\nfrom oneflow._C import max_unpool1d\nfrom oneflow._C import max_unpool2d\nfrom oneflow._C import max_unpool3d\nfrom oneflow._C import cosine_similarity, pairwise_distance\nfrom oneflow._C import relu\nfrom oneflow._C import square_relu\nfrom oneflow._C import hardtanh\nfrom oneflow._C import hardsigmoid\nfrom oneflow._C import hardshrink\nfrom oneflow._C import hardswish\nfrom oneflow._C import leaky_relu\nfrom oneflow._C import rrelu, rrelu_\nfrom oneflow._C import elu\nfrom oneflow._C import celu\nfrom oneflow._C import selu\nfrom oneflow._C import sigmoid\nfrom oneflow._C import softshrink\nfrom oneflow._C import prelu\nfrom oneflow._C import gelu_with_approximate as gelu\nfrom oneflow._C import quick_gelu\nfrom oneflow._C import glu\nfrom oneflow._C import logsigmoid\nfrom oneflow._C import log_softmax\nfrom oneflow._C import softsign\nfrom .softmax import softmax\nfrom oneflow._C import softplus\nfrom oneflow._C import tanh\nfrom oneflow._C import threshold\nfrom oneflow._C import silu\nfrom oneflow._C import mish\nfrom oneflow.nn.modules.normalization import layer_norm, group_norm\nfrom oneflow._C import dropout, dropout1d, dropout2d, dropout3d\nfrom oneflow._C import smooth_l1_loss\nfrom .pad import pad\nfrom .batch_norm import batch_norm\nfrom oneflow._C import triplet_margin_loss\nfrom oneflow._C import ctc_greedy_decoder\nfrom .ctc_loss import ctc_loss\nfrom oneflow._C import one_hot\nfrom oneflow._C import normalize\nfrom oneflow._C import mse_loss\nfrom oneflow._C import l1_loss\nfrom oneflow._C import cross_entropy\nfrom oneflow._C import binary_cross_entropy_loss as binary_cross_entropy\nfrom oneflow._C import (\n    binary_cross_entropy_with_logits_loss as binary_cross_entropy_with_logits,\n)\nfrom oneflow.nn.modules.sparse import embedding\nfrom oneflow.nn.modules.linear import linear\nfrom oneflow.nn.modules.activation import relu6\nfrom oneflow.nn.modules.upsampling import Upsample 
as upsample\nfrom oneflow._C import unfold\nfrom oneflow._C import fold\nfrom .deform_conv import deform_conv2d\nfrom oneflow._C import kl_div_loss as kl_div\nfrom oneflow._C import gumbel_softmax\nfrom .depend import depend\n"
  },
  {
    "path": "python/oneflow/nn/functional/batch_norm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nfrom typing import List, Optional\n\nfrom oneflow.framework.tensor import Tensor\nimport oneflow as flow\n\n\ndef batch_norm(\n    input: Tensor,\n    running_mean: Optional[Tensor],\n    running_var: Optional[Tensor],\n    weight: Optional[Tensor] = None,\n    bias: Optional[Tensor] = None,\n    training: bool = False,\n    momentum: float = 0.1,\n    eps: float = 1e-5,\n) -> Tensor:\n    r\"\"\"Applies Batch Normalization for each channel across a batch of data.\n\n    See :class:`~oneflow.nn.BatchNorm1d`, :class:`~oneflow.nn.BatchNorm2d`,\n    :class:`~oneflow.nn.BatchNorm3d` for details.\n    \"\"\"\n    if input.ndim == 4 and os.getenv(\"ONEFLOW_ENABLE_NHWC\") == \"1\":\n        axis = 3\n    else:\n        axis = 1\n\n    return flow._C.normalization(\n        input, running_mean, running_var, weight, bias, axis, eps, momentum, training,\n    )\n"
  },
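In inference mode the call reduces to a per-channel affine transform along `axis`: `y = (x - mean) / sqrt(var + eps) * weight + bias`. A numpy reference of that formula for NCHW layout (`batch_norm_ref` is illustrative, not the OneFlow kernel):

```python
import numpy as np

def batch_norm_ref(x, running_mean, running_var, weight, bias, eps=1e-5, axis=1):
    # Reshape per-channel stats so they broadcast along `axis` (NCHW -> axis=1).
    shape = [1] * x.ndim
    shape[axis] = x.shape[axis]
    mean = running_mean.reshape(shape)
    var = running_var.reshape(shape)
    y = (x - mean) / np.sqrt(var + eps)
    return y * weight.reshape(shape) + bias.reshape(shape)

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
mean, var = x.mean(axis=(0, 2, 3)), x.var(axis=(0, 2, 3))
w, b = np.ones(3, np.float32), np.zeros(3, np.float32)
y = batch_norm_ref(x, mean, var, w, b)
print(np.allclose(y.mean(axis=(0, 2, 3)), 0, atol=1e-3))  # True
```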
  {
    "path": "python/oneflow/nn/functional/ctc_loss.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.framework.tensor import Tensor\nimport oneflow as flow\n\n\ndef ctc_loss(\n    log_probs: Tensor,\n    targets: Tensor,\n    input_lengths: Tensor,\n    target_lengths: Tensor,\n    blank=0,\n    reduction=\"mean\",\n    zero_infinity=False,\n) -> Tensor:\n    r\"\"\"\n    The Connectionist Temporal Classification loss.\n    \n    The documentation is referenced from:\n    https://pytorch.org/docs/stable/generated/torch.nn.functional.ctc_loss.html\n\n    See :class:`~oneflow.nn.CTCLoss` for details.\n    \n    Args:\n        log_probs: The logarithmized probabilities of the outputs.\n        targets: Targets cannot be blank. In the second form, the targets are assumed to be concatenated.\n        input_lengths: Lengths of the inputs.\n        target_lengths: Lengths of the targets.\n        blank: Black label, default 0.\n        reduction: Specifies the reduction to apply to the output:  ``'none'`` | ``'mean'`` | ``'sum'`` . Default ``'Mean'``.\n        zero_infinity: Whether to zero infinite losses and the associated gradients. Default ``False``.\n        \n    Example:\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n        >>> import oneflow.nn.functional as F\n        >>> log_probs = flow.tensor(\n        ...     [\n        ...         [[-1.1031, -0.7998, -1.5200], [-0.9808, -1.1363, -1.1908]],\n        ...         [[-1.2258, -1.0665, -1.0153], [-1.1135, -1.2331, -0.9671]],\n        ...         [[-1.3348, -0.6611, -1.5118], [-0.9823, -1.2355, -1.0941]],\n        ...         [[-1.3850, -1.3273, -0.7247], [-0.8235, -1.4783, -1.0994]],\n        ...         [[-0.9049, -0.8867, -1.6962], [-1.4938, -1.3630, -0.6547]],\n        ...     ],\n        ...     dtype=flow.float32,\n        ...     requires_grad=True,\n        ...     )\n        >>> targets = flow.tensor([[1, 2, 2], [1, 2, 2]], dtype=flow.int32, device=\"cuda\")\n        >>> input_lengths = flow.tensor([5, 5], dtype=flow.int32)\n        >>> target_lengths = flow.tensor([3, 3], dtype=flow.int32)\n        >>> out = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)\n        >>> out\n        tensor(1.1376, dtype=oneflow.float32, grad_fn=<scalar_mulBackward>)\n        \n    \"\"\"\n    max_target_length = 0\n    if targets.ndim == 1:\n        max_target_length = target_lengths.max().item()\n    elif targets.ndim == 2:\n        max_target_length = targets.shape[1]\n    return flow._C.ctc_loss(\n        log_probs,\n        targets,\n        input_lengths,\n        target_lengths,\n        max_target_length,\n        blank,\n        zero_infinity,\n        reduction,\n    )\n"
  },
  {
    "path": "python/oneflow/nn/functional/deform_conv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Optional, Tuple, Union\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\n\n\ndef deform_conv2d(\n    input: Tensor,\n    offset: Tensor,\n    weight: Tensor,\n    bias: Optional[Tensor] = None,\n    stride: Tuple[int, int] = (1, 1),\n    padding: Tuple[int, int] = (0, 0),\n    dilation: Tuple[int, int] = (1, 1),\n    mask: Optional[Tensor] = None,\n) -> Tensor:\n    r\"\"\"\n    Performs Deformable Convolution v2, described in\n    `Deformable ConvNets v2: More Deformable, Better Results\n    <https://arxiv.org/abs/1811.11168>`__ if :attr:`mask` is not ``None`` and\n    Performs Deformable Convolution, described in\n    `Deformable Convolutional Networks\n    <https://arxiv.org/abs/1703.06211>`__ if :attr:`mask` is ``None``.\n\n    Args:\n        input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor\n        offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width, out_height, out_width]):\n            offsets to be applied for each position in the convolution kernel.\n        weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]): convolution weights,\n            split into groups of size (in_channels // groups)\n        bias (Tensor[out_channels]): optional bias of shape (out_channels,). Default: None\n        stride (int or Tuple[int, int]): distance between convolution centers. Default: 1\n        padding (int or Tuple[int, int]): height/width of padding of zeroes around\n            each image. Default: 0\n        dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1\n        mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width, out_height, out_width]):\n            masks to be applied for each position in the convolution kernel. Default: None\n\n    Returns:\n        Tensor[batch_sz, out_channels, out_h, out_w]: result of convolution\n\n    Examples::\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n        >>> input = flow.rand(4, 3, 10, 10)\n        >>> kh, kw = 3, 3\n        >>> weight = flow.rand(5, 3, kh, kw)\n        >>> # offset and mask should have the same spatial size as the output\n        >>> # of the convolution. 
In this case, for an input of 10, stride of 1\n        >>> # and kernel size of 3, without padding, the output size is 8\n        >>> offset = flow.rand(4, 2 * kh * kw, 8, 8)\n        >>> mask = flow.rand(4, kh * kw, 8, 8)\n        >>> out = F.deform_conv2d(input, offset, weight, mask=mask)\n        >>> out.size()\n        oneflow.Size([4, 5, 8, 8])\n    \"\"\"\n    use_mask = mask is not None\n    if mask is None:\n        mask = flow.zeros((input.shape[0], 0), dtype=input.dtype).to(input.device)\n    stride_h = stride[0]\n    stride_w = stride[1]\n    pad_h = padding[0]\n    pad_w = padding[1]\n    dil_h = dilation[0]\n    dil_w = dilation[1]\n    weights_h, weights_w = weight.shape[-2:]\n\n    # TODO(yzm): Support rectangle convolution\n    if weights_h != weights_w:\n        raise NotImplementedError(\"Rectangle convolution is not supported currently.\")\n\n    if use_mask and len(mask.shape) != 4:\n        raise RuntimeError(\"The mask tensor must be 4-dimensional\")\n    if len(input.shape) != 4:\n        raise RuntimeError(\"The input tensor must be 4-dimensional\")\n    if len(weight.shape) != 4:\n        raise RuntimeError(\"The weight tensor must be 4-dimensional\")\n    if len(offset.shape) != 4:\n        raise RuntimeError(\"The offset tensor must be 4-dimensional\")\n\n    _, n_in_channels, _, _ = input.shape\n    n_offset_grps = offset.shape[1] // (2 * weights_h * weights_w)\n    n_weight_grps = n_in_channels // weight.shape[1]\n\n    if n_offset_grps == 0:\n        raise RuntimeError(\n            \"The shape of the offset tensor at dimension 1 is not valid. It should \"\n            \"be a multiple of 2 * weight.size[2] * weight.size[3].\\n\"\n            f\"Got offset.shape[1]={offset.shape[1]}, while 2 * weight.size[2] * weight.size[3]={2 * weights_h * weights_w}\"\n        )\n\n    return flow._C.deform_conv2d(\n        input,\n        weight,\n        offset,\n        mask,\n        bias,\n        stride_h,\n        stride_w,\n        pad_h,\n        pad_w,\n        dil_h,\n        dil_w,\n        n_weight_grps,\n        n_offset_grps,\n        use_mask,\n    )\n"
  },
  {
    "path": "python/oneflow/nn/functional/depend.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nfrom oneflow.framework.tensor import Tensor\nimport oneflow as flow\nfrom typing import Union, List\n\n\ndef depend(input: Tensor, depend: Union[Tensor, List[Tensor]]) -> Tensor:\n    r\"\"\"\n    Add control dependency to guarantee OP A is executed before OP B.\n    Used to prevent OPs from being rearranged or eliminated during graph compilation.\n\n    Args:\n        input (Tensor): a tensor intended to input OP B\n        depend (Tensor or List[Tensor]): one of the output tensors of OP A (support passing in multiple tensors form different OP)\n\n    Returns:\n        Tensor: the identity of \"input\" tensor\n\n    Examples:\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n        >>> import oneflow.nn.functional as F\n        >>> class Model(nn.Module):\n        ...     def __init__(self):\n        ...         super().__init__()\n        ...         self.OP_A = nn.Linear(128, 128)\n        ...         self.OP_B = nn.Linear(128, 128)\n        ...\n        ...     def forward(self, x):\n        ...         x1 = self.OP_A(x)\n        ...         x = F.depend(x, x1)\n        ...         return self.OP_B(x)\n        ...\n        >>> model = Model()\n        >>> class Graph(nn.Graph):\n        ...     def __init__(self) -> None:\n        ...         super().__init__()\n        ...         self.model = model\n        ...\n        ...     def build(self, x):\n        ...         return self.model(x)\n        ...\n        >>> graph = Graph()\n        >>> x = flow.randn([1, 128], dtype=flow.float32)\n        >>> y = graph(x)\n    \"\"\"\n    # avoid performance loss in eager mode\n    if not input.is_lazy:\n        return input\n\n    # avoid self-loop\n    if isinstance(depend, Tensor) and input is depend:\n        raise RuntimeError('\"input\" and \"depend\" can NOT be the same tensor.')\n\n    if isinstance(depend, List):\n        for idx, t_depend in enumerate(depend):\n            if input is t_depend:\n                raise RuntimeError(\n                    '\"input\" and \"depend[%d]\" are the same tensor, which is not allowed.'\n                    % idx\n                )\n\n    return flow._C.depend(input, depend)\n"
  },
  {
    "path": "python/oneflow/nn/functional/maxpool.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\n\n# oneflow._C.max_poolXd returns a TensorTuple, to align torch,\n# here we return different result according to the param `return_indices`.\ndef max_pool1d(\n    x,\n    kernel_size,\n    stride=None,\n    padding=0,\n    dilation=1,\n    return_indices=False,\n    ceil_mode=False,\n    data_format=\"channels_first\",\n):\n    r\"\"\"\n    max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False,ceil_mode=False, data_format=\"channels_first\")\n\n    Applies a 1D max pooling over an input signal composed of several input\n    planes.\n\n    The documentation is referenced from: https://pytorch.org/docs/master/generated/torch.nn.functional.max_pool1d.html.\n\n    .. note::\n        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from\n        what seen in :class:`~oneflow.nn.MaxPool1d`, and will change in a future release.\n\n    See :class:`~oneflow.nn.MaxPool1d` for details.\n\n    Args:\n        input: input tensor of shape :math:`(\\text{minibatch} , \\text{in_channels} , iW)`, minibatch dim optional.\n        kernel_size: the size of the window. Can be a single number or a tuple `(kW,)`\n        stride: the stride of the window. Can be a single number or a tuple `(sW,)`. Default: :attr:`kernel_size`\n        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.\n        dilation: The stride between elements within a sliding window, must be > 0.\n        return_indices: If ``True``, will return the argmax along with the max values.Useful for :class:`oneflow.nn.functional.max_unpool1d` later.\n        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window.\n    \"\"\"\n    _max_pool_out = oneflow._C.max_pool1d(\n        x,\n        kernel_size,\n        stride,\n        padding,\n        dilation,\n        return_indices,\n        ceil_mode,\n        data_format,\n    )\n    if return_indices:\n        return _max_pool_out\n    else:\n        return _max_pool_out[0]\n\n\ndef max_pool2d(\n    x,\n    kernel_size,\n    stride=None,\n    padding=0,\n    dilation=1,\n    return_indices=False,\n    ceil_mode=False,\n    data_format=\"channels_first\",\n):\n    r\"\"\"\n    max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False,data_format=\"channels_first\")\n\n    Applies a 2D max pooling over an input signal composed of several input\n    planes.\n\n    The documentation is referenced from: https://pytorch.org/docs/master/generated/torch.nn.functional.max_pool2d.html.\n\n    .. 
    .. note::\n        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from\n        what is seen in :class:`~oneflow.nn.MaxPool2d`, and will change in a future release.\n\n    See :class:`~oneflow.nn.MaxPool2d` for details.\n\n    Args:\n        input: input tensor :math:`(\\text{minibatch} , \\text{in_channels} , iH , iW)`, minibatch dim optional.\n        kernel_size: size of the pooling region. Can be a single number or a tuple `(kH, kW)`\n        stride: stride of the pooling operation. Can be a single number or a tuple `(sH, sW)`. Default: :attr:`kernel_size`\n        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.\n        dilation: The stride between elements within a sliding window, must be > 0.\n        return_indices: If ``True``, will return the argmax along with the max values. Useful for :class:`oneflow.nn.functional.max_unpool2d` later.\n        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window.\n    \"\"\"\n    _max_pool_out = oneflow._C.max_pool2d(\n        x,\n        kernel_size,\n        stride,\n        padding,\n        dilation,\n        return_indices,\n        ceil_mode,\n        data_format,\n    )\n    if return_indices:\n        return _max_pool_out\n    else:\n        return _max_pool_out[0]\n\n\ndef max_pool3d(\n    x,\n    kernel_size,\n    stride=None,\n    padding=0,\n    dilation=1,\n    return_indices=False,\n    ceil_mode=False,\n    data_format=\"channels_first\",\n):\n    r\"\"\"\n    max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False, data_format=\"channels_first\")\n\n    Applies a 3D max pooling over an input signal composed of several input\n    planes.\n\n    The documentation is referenced from: https://pytorch.org/docs/master/generated/torch.nn.functional.max_pool3d.html.\n\n    .. note::\n        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from\n        what is seen in :class:`~oneflow.nn.MaxPool3d`, and will change in a future release.\n\n    See :class:`~oneflow.nn.MaxPool3d` for details.\n\n    Args:\n        input: input tensor :math:`(\\text{minibatch} , \\text{in_channels} , iD, iH , iW)`, minibatch dim optional.\n        kernel_size: size of the pooling region. Can be a single number or a tuple `(kT, kH, kW)`\n        stride: stride of the pooling operation. Can be a single number or a tuple `(sT, sH, sW)`. Default: :attr:`kernel_size`\n        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.\n        dilation: The stride between elements within a sliding window, must be > 0.\n        return_indices: If ``True``, will return the argmax along with the max values. Useful for :class:`~oneflow.nn.functional.max_unpool3d` later.\n        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This ensures that every element in the input tensor is covered by a sliding window.
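\n\n    For example (a minimal sketch; the shapes below are illustrative):\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n        >>> x = flow.randn(1, 4, 8, 8, 8)\n        >>> F.max_pool3d(x, kernel_size=2, stride=2).shape\n        oneflow.Size([1, 4, 4, 4, 4])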
\n    \"\"\"\n    _max_pool_out = oneflow._C.max_pool3d(\n        x,\n        kernel_size,\n        stride,\n        padding,\n        dilation,\n        return_indices,\n        ceil_mode,\n        data_format,\n    )\n    if return_indices:\n        return _max_pool_out\n    else:\n        return _max_pool_out[0]\n\n\ndef adaptive_max_pool1d(input, output_size, return_indices: bool = False):\n    r\"\"\"Applies a 1D adaptive max pooling over an input signal composed of\n    several input planes.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.functional.adaptive_max_pool1d.html\n\n    See :class:`~oneflow.nn.AdaptiveMaxPool1d` for details and output shape.\n\n    Args:\n        output_size: the target output size (a single integer)\n        return_indices: whether to return pooling indices. Default: ``False``\n\n    \"\"\"\n\n    _out = oneflow._C.adaptive_max_pool1d(input, output_size)\n    if return_indices:\n        return _out\n    else:\n        return _out[0]\n\n\ndef adaptive_max_pool2d(\n    input, output_size, return_indices: bool = False, data_format=\"channels_first\"\n):\n    r\"\"\"Applies a 2D adaptive max pooling over an input signal composed of\n    several input planes.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.functional.adaptive_max_pool2d.html\n\n    See :class:`~oneflow.nn.AdaptiveMaxPool2d` for details and output shape.\n\n    Args:\n        output_size: the target output size (a single integer or\n            a tuple of two integers)\n        return_indices: whether to return pooling indices. Default: ``False``\n\n    \"\"\"\n    _out = oneflow._C.adaptive_max_pool2d(input, output_size, data_format=data_format)\n    if return_indices:\n        return _out\n    else:\n        return _out[0]\n\n\ndef adaptive_max_pool3d(input, output_size, return_indices: bool = False):\n    r\"\"\"Applies a 3D adaptive max pooling over an input signal composed of\n    several input planes.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.functional.adaptive_max_pool3d.html\n\n    See :class:`~oneflow.nn.AdaptiveMaxPool3d` for details and output shape.\n\n    Args:\n        output_size: the target output size (a single integer or\n            a tuple of three integers)\n        return_indices: whether to return pooling indices. Default: ``False``\n
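\n    For example (a brief sketch; the sizes below are illustrative):\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n        >>> x = flow.randn(1, 3, 8, 8, 8)\n        >>> F.adaptive_max_pool3d(x, 4).shape\n        oneflow.Size([1, 3, 4, 4, 4])\n\n    \"\"\"\n\n    _out = oneflow._C.adaptive_max_pool3d(input, output_size)\n    if return_indices:\n        return _out\n    else:\n        return _out[0]\n"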
  },
  {
    "path": "python/oneflow/nn/functional/pad.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import List\nfrom oneflow.framework.tensor import Tensor\nimport oneflow as flow\n\n\ndef pad(\n    input: Tensor, pad: List[int], mode: str = \"constant\", value: float = 0.0\n) -> Tensor:\n    r\"\"\"Pads tensor.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.functional.pad.html.\n\n    Padding size:\n        The padding size by which to pad some dimensions of :attr:`input`\n        are described starting from the last dimension and moving forward.\n        :math:`\\left\\lfloor\\frac{\\text{len(pad)}}{2}\\right\\rfloor` dimensions\n        of ``input`` will be padded.\n        For example, to pad only the last dimension of the input tensor, then\n        :attr:`pad` has the form\n        :math:`(\\text{padding_left}, \\text{padding_right})`;\n        to pad the last 2 dimensions of the input tensor, then use\n        :math:`(\\text{padding_left}, \\text{padding_right},`\n        :math:`\\text{padding_top}, \\text{padding_bottom})`;\n        to pad the last 3 dimensions, use\n        :math:`(\\text{padding_left}, \\text{padding_right},`\n        :math:`\\text{padding_top}, \\text{padding_bottom}`\n        :math:`\\text{padding_front}, \\text{padding_back})`.\n\n    Padding mode:\n        See :class:`oneflow.nn.ConstantPad2d`, :class:`oneflow.nn.ReflectionPad2d`, and\n        :class:`oneflow.nn.ReplicationPad2d` for concrete examples on how each of the\n        padding modes works. Constant padding is implemented for arbitrary dimensions.\n        Replicate and reflection padding is implemented for padding the last 3\n        dimensions of 5D input tensor, or the last 2 dimensions of 4D input\n        tensor, or the last dimension of 3D input tensor.\n\n    Note:\n        When using the CUDA backend, this operation may induce nondeterministic\n        behaviour in its backward pass that is not easily switched off.\n\n    Args:\n        input (Tensor): N-dimensional tensor\n        pad (tuple): m-elements tuple, where\n            :math:`\\frac{m}{2} \\leq` input dimensions and :math:`m` is even.\n        mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.\n            Default: ``'constant'``\n        value: fill value for ``'constant'`` padding. 
        value: fill value for ``'constant'`` padding. Default: ``0``\n\n    Examples::\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n        >>> t4d = flow.empty(3, 3, 4, 2)\n        >>> p1d = (1, 1)\n        >>> out = F.pad(t4d, p1d)\n        >>> out.size()\n        oneflow.Size([3, 3, 4, 4])\n\n    \"\"\"\n    assert len(pad) % 2 == 0, \"Padding length must be divisible by 2\"\n    assert len(pad) // 2 <= input.dim(), \"Padding length too large\"\n    if mode == \"constant\":\n        return flow._C.pad(input, pad, mode=\"constant\", value=value)\n    else:\n        assert (\n            value == 0.0\n        ), 'Padding mode \"{}\" doesn\\'t take in value argument'.format(mode)\n        if len(pad) == 2 and (input.dim() == 2 or input.dim() == 3):\n            if mode == \"reflect\":\n                return flow._C.pad(input, pad, mode=\"reflect\")\n            elif mode == \"replicate\":\n                return flow._C.pad(input, pad, mode=\"replicate\")\n            elif mode == \"circular\":\n                raise NotImplementedError(\n                    \"1D circular padding is not supported for now\"\n                )\n            else:\n                raise NotImplementedError\n\n        elif len(pad) == 4 and (input.dim() == 3 or input.dim() == 4):\n            if mode == \"reflect\":\n                return flow._C.pad(input, pad, mode=\"reflect\")\n            elif mode == \"replicate\":\n                return flow._C.pad(input, pad, mode=\"replicate\")\n            elif mode == \"circular\":\n                raise NotImplementedError(\n                    \"2D circular padding is not supported for now\"\n                )\n            else:\n                raise NotImplementedError\n\n        elif len(pad) == 6 and (input.dim() == 4 or input.dim() == 5):\n            if mode == \"reflect\":\n                raise NotImplementedError(\n                    \"3D reflect padding is not supported for now\"\n                )\n            elif mode == \"replicate\":\n                raise NotImplementedError(\n                    \"3D replicate padding is not supported for now\"\n                )\n            elif mode == \"circular\":\n                raise NotImplementedError(\n                    \"3D circular padding is not supported for now\"\n                )\n            else:\n                raise NotImplementedError\n        else:\n            raise NotImplementedError(\n                \"Only 2D, 3D, 4D, 5D padding with non-constant padding is supported for now\"\n            )\n"
  },
  {
    "path": "python/oneflow/nn/functional/softmax.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nfrom typing import List, Optional\n\nfrom oneflow.framework.tensor import Tensor\nimport oneflow as flow\n\n# ref https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py\ndef softmax(input: Tensor, dim: Optional[int] = None, dtype=None) -> Tensor:\n    r\"\"\"Applies a softmax function.\n    Softmax is defined as:\n    :math:`\\text{Softmax}(x_{i}) = \\frac{\\exp(x_i)}{\\sum_j \\exp(x_j)}`\n    It is applied to all slices along dim, and will re-scale them so that the elements\n    lie in the range `[0, 1]` and sum to 1.\n\n    See :class:`~oneflow.nn.Softmax` for more details.\n\n    Args:\n        input (Tensor): input\n        dim (int): A dimension along which softmax will be computed.\n        dtype (:class:`oneflow.dtype`, optional): the desired data type of returned tensor.\n            If specified, the input tensor is casted to :attr:`dtype` before the operation\n            is performed. This is useful for preventing data type overflows. Default: None.\n\n    .. note::\n        This function doesn't work directly with NLLLoss,\n        which expects the Log to be computed between the Softmax and itself.\n        Use log_softmax instead (it's faster and has better numerical properties).\n    \"\"\"\n    if dtype is None:\n        ret = flow._C.softmax(input, dim)\n    else:\n        ret = flow._C.softmax(input.to(dtype), dim)\n    return ret\n"
  },
  {
    "path": "python/oneflow/nn/graph/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom .graph import Graph\nfrom .proxy import Proxy\nfrom .graph_block import GraphModule\nfrom .graph_block import GraphTensor\n"
  },
  {
    "path": "python/oneflow/nn/graph/cache.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport weakref\nfrom collections import deque, OrderedDict\nfrom typing import Dict, Union\n\nfrom oneflow.framework.args_tree import ArgsTree\nfrom oneflow.framework.tensor import Tensor\nimport oneflow\n\n\nclass LRUCache(object):\n    _cnt: int = 0\n\n    def __init__(self, cache_size, keep_the_1st=True):\n        assert cache_size >= 2\n        self.cache_size = cache_size\n        self.hash_map = dict()\n        self.keep_the_1st = keep_the_1st\n        self.queue = deque()\n\n    def is_empty(self):\n        return len(self.hash_map) == 0\n\n    def is_full(self):\n        return len(self.hash_map) >= self.cache_size\n\n    def pop(self):\n        if len(self.queue) == 0:\n            return None\n        pop_key = self.queue.pop()\n        value = self.hash_map.pop(pop_key)\n        del value\n        return pop_key\n\n    def set(self, key, value):\n        new_key = None\n        old_key = None\n        if key in self.hash_map:\n            return new_key, old_key\n\n        if self.is_full():\n            old_key = self.pop()\n            assert old_key is not None, f\"Cache size is {self.cache_size}, at least 2.\"\n        assert not self.is_full()\n\n        if not (self.keep_the_1st and self.is_empty()):\n            self.queue.appendleft(key)\n\n        value._oneflow_graph_cache_order = LRUCache._cnt\n        LRUCache._cnt += 1\n        self.hash_map[key] = value\n        new_key = key\n        return new_key, old_key\n\n    def get(self, key):\n        if key in self.hash_map:\n            if key in self.queue:\n                self.queue.remove(key)\n                self.queue.appendleft(key)\n            return self.hash_map[key]\n\n        return None\n\n    def items(self):\n        for (key, value) in self.hash_map.items():\n            yield (key, value)\n\n\nclass AvoidRecursiveCacheCall(object):\n    def __init__(self, graph) -> None:\n        self._g = graph\n        self._prev_flag = self._g._run_with_cache\n\n    def __enter__(self):\n        self._g._run_with_cache = False\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        self._g._run_with_cache = self._prev_flag\n\n\nclass GraphCache(object):\n    def __init__(self, base_graph, cache_size=10, enable_graph_shared=True):\n        assert base_graph is not None and isinstance(base_graph, weakref.ProxyTypes)\n        self._base_graph = base_graph\n\n        self._cache_size = cache_size\n        self._cache = None\n\n        self._enable_shared = enable_graph_shared\n\n    def set_cache_size(self, cache_size):\n        self._cache_size = cache_size\n\n    def enable_shared(self, enabled=True):\n        self._enable_shared = enabled\n\n    def __call__(self, *args, **kwargs):\n        graph = self.get_graph(*args, **kwargs)\n        with AvoidRecursiveCacheCall(graph):\n            return graph(*args, **kwargs)\n\n    def _compile(self, *args, **kwargs):\n        graph = 
self.get_graph(*args, **kwargs)\n        with AvoidRecursiveCacheCall(graph):\n            return graph._compile(*args, **kwargs)\n\n    def runtime_state_dict(\n        self, destination=None, with_eager=False,\n    ) -> Dict[str, Dict[str, Union[Dict[str, Tensor], str]]]:\n        if destination is None:\n            destination = OrderedDict()\n            destination._metadata = OrderedDict()\n\n        for (key, graph) in self._cache.items():\n            with AvoidRecursiveCacheCall(graph):\n                state_dict = graph.runtime_state_dict(with_eager=with_eager)\n            state_dict[\"cache_order\"] = graph._oneflow_graph_cache_order\n            state_dict[\"cache_key\"] = key\n            destination[state_dict[\"graph_name\"]] = state_dict\n        return destination\n\n    @staticmethod\n    def runtime_state_dict_to(\n        state_dict: Union[\n            Dict[str, Union[Dict[str, Tensor], str]],\n            Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],\n        ],\n        device: str,\n    ) -> Union[\n        Dict[str, Union[Dict[str, Tensor], str]],\n        Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],\n    ]:\n        destination = OrderedDict()\n        destination._metadata = OrderedDict()\n        for (key, sub_state_dict) in state_dict.items():\n            dest_sub_state_dict = oneflow.nn.Graph.runtime_state_dict_to(\n                sub_state_dict, device\n            )\n            dest_sub_state_dict[\"cache_order\"] = sub_state_dict[\"cache_order\"]\n            dest_sub_state_dict[\"cache_key\"] = sub_state_dict[\"cache_key\"]\n            destination[key] = dest_sub_state_dict\n        return destination\n\n    def _init_and_get_a_graph_in_cache(self, cache_key):\n        self._base_graph._print(\n            0,\n            0,\n            self._base_graph._shallow_repr()\n            + f\" is creating a graph cache with key {cache_key}.\",\n        )\n        cur_is_base = False\n        if self._cache.is_empty():\n            # Has no graph yet\n            cur_is_base = True\n            graph = self._base_graph\n        else:\n            # Create new graph from base\n            graph = self._base_graph.__class__(\n                *self._base_graph._cached_init_args,\n                **self._base_graph._cached_init_kwargs,\n            )\n            graph._run_with_cache = False\n            graph._dynamic_input_graph_cache = None\n            graph._cached_init_args = None\n            graph._cached_init_kwargs = None\n\n        if self._enable_shared is True:\n            if cur_is_base:\n                graph.enable_shared()\n            else:\n                graph.share_from(self._base_graph)\n        new_key, old_key = self._cache.set(cache_key, graph)\n        if old_key is not None:\n            self._base_graph._print(\n                0,\n                0,\n                self._base_graph._shallow_repr()\n                + f\" cache is full(cache size {self._cache_size}), has deleted an old graph cache with key {old_key}.\",\n            )\n        assert new_key is not None\n\n        return graph\n\n    def load_runtime_state_dict(\n        self,\n        state_dict: Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],\n        *,\n        warmup_with_run: bool = False,\n    ) -> None:\n        graph_dict = dict()\n        for _, sub_state_dict in state_dict.items():\n            cache_order = sub_state_dict[\"cache_order\"]\n            graph_dict[cache_order] = sub_state_dict\n\n        if self._cache is None:\n    
        self._cache = LRUCache(self._cache_size)\n        for _, sub_state_dict in sorted(graph_dict.items()):\n            cache_key = sub_state_dict[\"cache_key\"]\n            graph = self._cache.get(cache_key)\n            assert graph is None\n            graph = self._init_and_get_a_graph_in_cache(cache_key)\n            with AvoidRecursiveCacheCall(graph):\n                graph.load_runtime_state_dict(\n                    sub_state_dict, warmup_with_run=warmup_with_run\n                )\n\n    def gen_key(self, *args, **kwargs):\n        flattened_shapes = []\n        args_tree = ArgsTree((args, kwargs), False)\n        for arg in args_tree.iter_nodes():\n            if isinstance(arg, Tensor):\n                flattened_shapes.append(arg.shape)\n        return tuple(flattened_shapes)\n\n    def get_graph(self, *args, **kwargs):\n        if self._cache is None:\n            self._cache = LRUCache(self._cache_size)\n\n        cache_key = hash(self.gen_key(*args, **kwargs))\n        graph = self._cache.get(cache_key)\n\n        # Create graph\n        if graph is None:\n            self._base_graph._print(\n                0,\n                0,\n                self._base_graph._shallow_repr()\n                + \" got a new input shape, is compiling a new graph.\",\n            )\n            graph = self._init_and_get_a_graph_in_cache(cache_key)\n\n        return graph\n"
  },
  {
    "path": "python/oneflow/nn/graph/graph.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport logging\nimport warnings\nimport os\nimport sys\nimport time\nimport inspect\nimport weakref\nfrom collections import OrderedDict\nfrom functools import partial, wraps\nfrom typing import Dict, Optional, Union, List, Callable\nfrom google.protobuf import text_format\nfrom copy import deepcopy\n\nimport oneflow\nimport oneflow._oneflow_internal\nimport oneflow.core.job.job_pb2 as job_pb\nimport oneflow.framework.c_api_util as c_api_util\nimport oneflow.framework.graph_build_util as graph_build_util\nimport oneflow.framework.session_context as session_ctx\nfrom oneflow.amp import GradScaler, StaticGradScaler\nfrom oneflow.env import get_rank\nfrom oneflow.framework.multi_client_session import MultiClientSession\nfrom oneflow.framework.tensor import Tensor, TensorTuple\nfrom oneflow.framework.tensor_tuple_util import convert_to_tensor_tuple\nfrom oneflow.nn.graph.proxy import (\n    Proxy,\n    GraphBlockType,\n    get_proxy_cls,\n    GraphModule,\n    GraphTensor,\n)\nfrom oneflow.nn.graph.graph_config import GraphConfig\nfrom oneflow.nn.graph.optimizer import OptDict, VariableConfig\nfrom oneflow.nn.graph.util import (\n    add_indent,\n    operators_repr,\n    GraphIR,\n    seq_to_func_return,\n    sys_exc_error_msg,\n    _rsd_sub_destination_to,\n    _job_to,\n    _plan_to,\n)\nfrom oneflow.framework.args_tree import ArgsTree\nfrom oneflow.nn.modules.module import Module\nfrom oneflow.nn.optimizer.lr_scheduler import LRScheduler\nfrom oneflow.optim.optimizer import Optimizer\n\n\nclass Graph(object):\n    r\"\"\"Base class for training or evaluating a neural network in static graph mode.\n\n    To use static graph mode for model training or evaluation in OneFlow, you should:\n\n    1. Define your customized graph as a subclass of ``nn.Graph``.\n    2. Add ``super().__init__()`` in your subclass's ``__init__()``.\n    3. Add modules to your graph as regular attributes.\n    4. Define computation logical in ``build()`` method.\n    5. Instantiate your graph then call it.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> class LinearGraph(flow.nn.Graph):\n        ...    def __init__(self):\n        ...        super().__init__()\n        ...        # Add a module to the graph.\n        ...        self.linear = flow.nn.Linear(3, 8, False)\n        ...    def build(self, x):\n        ...        # Use the module to build the computation logic of the graph.\n        ...        return self.linear(x)\n\n        # Instantiate the graph\n        >>> linear_graph = LinearGraph()\n        >>> x = flow.randn(4, 3)\n\n        # First call on graph will run graph's build() method to\n        # trace a computatioin graph. 
        # trace a computation graph. Then the computation graph will be\n        # optimized and executed for the first time.\n        >>> linear_graph(x).shape\n        oneflow.Size([4, 8])\n\n        # Later call on graph will execute the computation graph directly.\n        >>> linear_graph(x).shape\n        oneflow.Size([4, 8])\n\n    Note:\n        nn.Graph cannot be nested at the moment.\n    \"\"\"\n    _child_init_cnt = dict()\n\n    def __init__(\n        self,\n        *,\n        enable_get_runtime_state_dict: bool = False,\n        debug_v_level: int = -1,\n        debug_ranks: Optional[Union[int, List[int]]] = None,\n        debug_max_py_stack_depth: int = 2,\n        debug_only_user_py_stack=True,\n        debug_op_repr_with_py_stack=False,\n    ):\n        \"\"\"\n        Initializes internal Graph states. It MUST be called in the ``__init__`` method of the subclass.\n\n        For example:\n\n        .. code-block:: python\n\n            >>> import oneflow as flow\n            >>> class CustomGraph(flow.nn.Graph):\n            ...     def __init__(self):\n            ...         super().__init__() # MUST be called\n            ...         # Then define the graph attributes\n            ...     def build(self):\n            ...         pass\n\n        \"\"\"\n        self._generate_name()\n        self.config = GraphConfig()\n        self._blocks = OrderedDict()\n        self._opts = []\n        self._verbose = False\n        self._grad_scaler = None\n        self._variables_conf = OrderedDict()\n        self._additional_variable_tobe_loaded = OrderedDict()\n        self._is_compiled = False\n        self._is_user_mode = False\n        # Default is local view\n        self._is_global_view = False\n        # Optimize the overhead of graph input/output process\n        self._is_simple_tuple_input = False\n        self._is_simple_tuple_output = False\n\n        self._outputs_buffer_size = 2\n        self._cur_index_of_ouputs_buffer = 0\n\n        # For graph level op rewrite\n        self._unique_global_op_dict = dict()\n        self._unique_identity_op_dict = dict()\n\n        # Graph compilation related.\n        # forward graph job proto\n        self._forward_job_proto = None\n        # forward, backward and optimized graph job proto\n        self._full_job_proto = None\n        # completed graph job proto\n        self._compiled_job_proto = None\n        self._job_id = None\n        self._args_repr = []\n        self._outs_repr = []\n        self._oneflow_internal_graph_ir__ = None\n        enable_lazy_separate_compile = os.environ.get(\n            \"ONEFLOW_ENABLE_LAZY_SEPARATE_COMPILE\"\n        )\n        if enable_lazy_separate_compile is not None and enable_lazy_separate_compile == \"1\":\n            os.environ[\"ONEFLOW_LAZY_COMPILE_MODE\"] = \"rank_per_process\"\n            # Separate compile mode only works with nccl use compute stream and logical chain.\n            os.environ[\"ENABLE_LOGICAL_CHAIN\"] = \"1\"\n            oneflow.boxing.nccl.enable_use_compute_stream(True)\n\n        self._session = session_ctx.GetDefaultSession()\n        assert type(self._session) is MultiClientSession\n        self._session.TryInit()\n        self._c_nn_graph = None\n        self.env_enable_mlir_inference_opt = None\n\n        # For building a graph from another graph with a different input shape.\n        self._enable_shared_from_this = False\n        self._build_with_shared_graph = False\n\n        # For loading a graph from runtime states.\n        self.enable_save_runtime_state_dict(enable_get_runtime_state_dict)\n
        self._is_from_runtime_state_dict = False\n\n        # For running the graph with a dynamic shape cache\n        self._run_with_cache = False\n\n        # For debug\n        self._debug = False\n        self._debug_min_s_level = 2\n        self._debug_max_v_level = 0\n        self._debug_max_py_stack_depth = 2\n        self._debug_op_repr_with_py_stack = False\n        self._debug_only_user_py_stack = True\n        self.debug(\n            debug_v_level,\n            ranks=debug_ranks,\n            max_py_stack_depth=debug_max_py_stack_depth,\n            only_user_py_stack=debug_only_user_py_stack,\n            op_repr_with_py_stack=debug_op_repr_with_py_stack,\n        )\n\n    def build(self, *args, **kwargs):\n        r\"\"\"The ``build()`` method must be overridden to define neural network\n        computation logic.\n\n        The ``build()`` method of nn.Graph is very similar to the ``forward()``\n        method of nn.Module. It is used to describe the computation logic of\n        a neural network.\n\n        When a graph object is called for the first time, the ``build()``\n        method will be called implicitly to build the computation graph.\n\n        Make sure to call the modules' ``train()`` or ``eval()`` method before the\n        first call of your graph to make the modules execute the right\n        training or evaluation logic if needed.\n\n        For example:\n\n        .. code-block:: python\n\n            >>> import oneflow as flow\n            >>> linear = flow.nn.Linear(3, 8, False)\n            >>> class MyGraph(flow.nn.Graph):\n            ...     def __init__(self):\n            ...         super().__init__()\n            ...         self.model = linear\n            ...     def build(self, x):\n            ...         return self.model(x)\n\n            >>> linear_graph = MyGraph()\n            >>> x = flow.randn(4, 3)\n            >>> linear.eval() # make the linear module execute in evaluation mode\n            Linear(in_features=3, out_features=8, bias=False)\n            >>> y = linear_graph(x) # The build() method is called implicitly\n\n        Note:\n            ``build()`` method's inputs and outputs support list/tuple/dict,\n            but the items in them must be one of these types:\n\n            * ``Tensor``\n            * ``None``\n\n        \"\"\"\n        raise NotImplementedError(\n            \"nn.Graph.build() method must be overridden when subclassing the nn.Graph.\"\n        )\n\n    def __call__(self, *args, **kwargs):\n        r\"\"\"Call nn.Graph subclass instance to run your customized graph.\n\n        Call your customized graph after the instantiation:\n\n        For example:\n\n        .. code-block:: python\n\n            g = CustomGraph()\n            out_tensors = g(input_tensors)\n\n        The inputs of the ``__call__`` method must match the inputs of the ``build()``\n
        method. And the ``__call__`` method will return outputs matching the\n        outputs of the ``build()`` method.\n\n        Note:\n            The first call takes longer than later calls, because nn.Graph\n            will do the computation graph generation and optimization at the first call.\n\n            Do not override this function.\n        \"\"\"\n        # For caching graphs with dynamic input shapes.\n        if self._run_with_cache:\n            return self._dynamic_input_graph_cache(*args, **kwargs)\n\n        if not self._is_compiled:\n            self._compile(*args, **kwargs)\n\n        return self.__run(*args, **kwargs)\n\n    def add_optimizer(\n        self, optim: Optimizer, *, lr_sch: LRScheduler = None, is_sparse: bool = False,\n    ):\n        r\"\"\"Add an optimizer and a learning rate scheduler to the graph.\n\n        To do training with nn.Graph, you should do 2 more things:\n\n        1. Add at least one optimizer (learning rate schedulers are optional) with the ``add_optimizer()`` method.\n        2. Call the loss tensor's ``backward()`` method in the ``build()`` method.\n\n        Note that the computation graph will automatically execute these methods:\n\n        * optimizer's ``clip_grad()`` if an optimizer is set to do grad clipping.\n        * optimizer's ``step()``.\n        * optimizer's ``zero_grad()``.\n        * learning rate scheduler's ``step()``.\n\n        Also note that only scalar tensors are allowed to call ``backward()``\n        in ``nn.Graph.build()`` for the moment. So you may call methods such as ``Tensor.mean()``\n        to make the loss tensor a scalar tensor.\n\n        Note:\n            If you want to output the learning rate information for each step,\n            set the ``verbose`` parameter of the ``lr_scheduler`` to ``True``, and you will see the result at rank 0.\n\n            This feature is the same as eager mode.\n\n        For example:\n\n        .. code-block:: python\n\n            >>> import oneflow as flow\n            >>> loss_fn = flow.nn.MSELoss(reduction=\"sum\")\n            >>> model = flow.nn.Sequential(flow.nn.Linear(3, 1), flow.nn.Flatten(0, 1))\n            >>> optimizer = flow.optim.SGD(model.parameters(), lr=1e-6)\n            >>> class LinearTrainGraph(flow.nn.Graph):\n            ...     def __init__(self):\n            ...         super().__init__()\n            ...         self.model = model\n            ...         self.loss_fn = loss_fn\n            ...         # Add an optimizer\n            ...         self.add_optimizer(optimizer)\n            ...     def build(self, x, y):\n            ...         y_pred = self.model(x)\n            ...         loss = self.loss_fn(y_pred, y)\n            ...         # Call loss tensor's backward(), loss tensor must be a scalar tensor\n            ...         loss.backward()\n            ...         return loss\n\n            >>> linear_graph = LinearTrainGraph()\n            >>> x = flow.randn(10, 3)\n            >>> y = flow.randn(10)\n            >>> model.train() # make the model execute in training mode\n            Sequential(\n              (0): Linear(in_features=3, out_features=1, bias=True)\n              (1): Flatten(start_dim=0, end_dim=1)\n            )\n            >>> for t in range(3):\n            ...     loss = linear_graph(x, y)\n\n        Args:\n            optim (oneflow.optim.Optimizer): The optimizer.\n            lr_sch: The learning rate scheduler, see oneflow.optim.lr_scheduler.\n
            is_sparse: When set to True, treat optim as a sparse optimizer. Default is False.\n        \"\"\"\n        opt_dict = dict()\n        assert optim is not None, \"optimizer cannot be None\"\n        assert isinstance(\n            optim, Optimizer\n        ), \"optimizer must be an instance of Optimizer\"\n\n        opt_dict[\"optim\"] = optim\n        opt_dict[\"is_sparse\"] = bool(is_sparse)\n        if lr_sch is not None:\n            assert isinstance(lr_sch, LRScheduler)\n            assert (\n                lr_sch.optimizer is optim\n            ), \"lr_scheduler's optimizer must be the same optimizer in add_optimizer.\"\n            opt_dict[\"lr_sch\"] = lr_sch\n            self._verbose = opt_dict[\"lr_sch\"].verbose\n            rank = get_rank()\n            if rank != 0:\n                self._verbose = False\n        oneflow._oneflow_internal.SetGraphLRVerbose(self._verbose)\n        self._opts.append(opt_dict)\n        # Set the training config if an optimizer has been added to the graph.\n        if len(self._opts) == 1:\n            self.config._train(True)\n\n    def set_grad_scaler(self, grad_scaler: GradScaler = None):\n        r\"\"\"Set the GradScaler for gradient and loss scaling.\"\"\"\n        assert isinstance(grad_scaler, (GradScaler, StaticGradScaler))\n        self._grad_scaler = grad_scaler\n\n    def state_dict(\n        self, destination=None\n    ) -> Dict[str, Union[Dict[str, Tensor], Tensor]]:\n        r\"\"\"Returns a dictionary containing the whole state of the graph.\n\n        States of modules/optimizers/lr schedulers in a graph are included.\n\n        Keys of modules' state dicts correspond to their names in the graph.\n        Values of modules' state dicts correspond to their nn.Module's\n        state dict.\n\n        Other keys and tensors are states of optimizers/lr schedulers/etc.\n\n        Returns:\n            dict: a dictionary containing the whole state of the graph.\n
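\n        For example (a sketch; ``CustomGraph``, ``input_tensors`` and the save\n        path are illustrative, not part of the API):\n\n        .. code-block:: python\n\n            g = CustomGraph()\n            out_tensors = g(input_tensors)\n            state = g.state_dict()  # keys follow the module names in the graph\n            flow.save(state, \"./graph_states\")  # may be saved like a module state dict\n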
\n        \"\"\"\n        # Sync to make sure states have been updated.\n        oneflow._oneflow_internal.eager.Sync()\n        if destination is None:\n            destination = OrderedDict()\n            destination._metadata = OrderedDict()\n        # Get states from sub module block\n        for name, block in self._blocks.items():\n            assert block.to(GraphModule).type == GraphBlockType.MODULE\n            sub_destination = OrderedDict()\n            sub_destination._metadata = OrderedDict()\n            module = block.to(Module)\n            if module is not None:\n                module.state_dict(\n                    sub_destination, \"\", keep_vars=False,\n                )\n            destination[name] = sub_destination\n        # Get additional states.\n        # Additional variables are states in Optimizer/LRScheduler and free eager tensors of nn.Graph.\n        if self._is_compiled:\n            # Get from _c_nn_graph.\n            additional_var_names = self._c_nn_graph.additional_var_names\n            additional_var_tensors = self._c_nn_graph.additional_var_tensors\n            assert len(additional_var_names) == len(additional_var_tensors)\n            for i in range(len(additional_var_names)):\n                additional_tensor = additional_var_tensors[i]\n                if not self._is_global_view and additional_tensor.is_global:\n                    additional_tensor = additional_tensor.to_local()\n                destination[additional_var_names[i]] = additional_tensor\n        else:\n            # Get from loaded dict.\n            for name, item in self._additional_variable_tobe_loaded.items():\n                destination[name] = item\n        return destination\n\n    def load_state_dict(\n        self,\n        state_dict: Dict[str, Union[Dict[str, Tensor], Tensor]],\n        strict: bool = True,\n    ):\n        r\"\"\"Copies modules' states and other graph states from :attr:`state_dict`\n        into this graph. If :attr:`strict` is ``True``, then\n        the keys of :attr:`state_dict` must exactly match the keys returned\n        by this module's :meth:`nn.Graph.state_dict` function.\n\n        Args:\n            state_dict (dict): a dict containing modules' states and other graph states.\n            strict (bool, optional): whether to strictly enforce that the keys\n                in :attr:`state_dict` match the keys returned by this graph's\n                :meth:`nn.Graph.state_dict` function. Default: ``True``.\n\n        Note:\n            nn.Graph's state dict can only be loaded before the first call of a graph.\n        \"\"\"\n        assert (\n            not self._is_compiled\n        ), \"nn.Graph's state dict can only be loaded before the first call of a graph.\"\n        # Additional variables are states in Optimizer or LRScheduler of nn.Graph.\n        for name, item in state_dict.items():\n            if name in self._blocks:\n                # 1 load parameter/buffer to Modules\n                self._blocks[name].to(Module).load_state_dict(item, strict)\n            else:\n                # 2 store other states to CNNGraph; CNNGraph loads them after the job pass\n                assert isinstance(item, Tensor)\n                self._additional_variable_tobe_loaded[name] = item\n\n    @property\n    def name(self):\n        r\"\"\"Name auto-generated for this graph.\"\"\"\n        return self._name\n\n    @property\n    def is_compiled(self):\n        r\"\"\"Whether this graph is compiled or not\n        \"\"\"\n        return self._is_compiled\n\n    @property\n    def training(self):\n        r\"\"\"In training mode if the graph has an optimizer.\"\"\"\n        return self.config.training\n\n    def debug(\n        self,\n        v_level: int = -1,\n        *,\n        ranks: Optional[Union[int, List[int]]] = None,\n        max_py_stack_depth: int = 2,\n        only_user_py_stack=True,\n        op_repr_with_py_stack=False,\n    ) -> None:\n        r\"\"\"Open or close debug mode of the graph.\n\n        If in debug mode, logs of computation graph building info or warnings will be\n        printed. Otherwise, only errors will be printed.\n\n        Each nn.Module inside a nn.Graph also has a debug() method to enable debug mode.\n\n        Use ``v_level`` to choose the verbose debug info level; the max level is 3.\n        ``v_level`` -1 will disable the debug mode of the graph (i.e. no info will be printed).\n        ``v_level`` 0 will print warning and graph building stages. ``v_level`` 1 will additionally\n        print graph build info of each nn.Module. ``v_level`` 2 will additionally print graph build\n        info of each operation. ``v_level`` 3 will additionally print more detailed info of each\n        operation.\n\n        Use ``ranks`` to choose which ranks to print the debug information on.\n\n        Use ``max_py_stack_depth`` to specify the max Python stack depth for the debug information.\n\n        Use ``only_user_py_stack`` to only print the operators' locations which are from users' code or models.\n\n        Use ``op_repr_with_py_stack`` to print operators' locations when printing nn.Graph's repr.\n\n        For example:\n\n
        .. code-block:: python\n\n            g = CustomGraph()\n            g.debug(0)  # Open debug mode at v_level 0\n            out_tensors = g(input_tensors)  # Will print log for debug at the first call\n\n        Args:\n            v_level (int): choose the verbose debug info level; the max v_level is 3. The default v_level is -1, which closes the debug mode.\n            ranks (int or list(int)): choose ranks to print the debug information. Default rank ``0``.\n                You can choose any valid rank. ``ranks`` equal to ``-1`` means debug on all ranks.\n            max_py_stack_depth(int): the maximum depth for the Python stack debug information. Default: ``2``.\n            only_user_py_stack(bool): only print the operators' locations from users' code. Default: ``True``.\n            op_repr_with_py_stack(bool): print operators' locations when printing nn.Graph's repr. Default: ``False``.\n        \"\"\"\n        assert isinstance(v_level, int)\n        assert v_level >= -1, \"The min verbose debug info level is -1.\"\n        assert v_level <= 3, \"The max verbose debug info level is 3.\"\n        assert max_py_stack_depth >= 0, \"The min max stack depth is 0.\"\n        assert isinstance(max_py_stack_depth, int)\n        assert isinstance(only_user_py_stack, bool)\n        assert isinstance(op_repr_with_py_stack, bool)\n\n        if ranks is None:\n            rank_list = [0]\n        elif isinstance(ranks, int):\n            rank_list = [ranks]\n        elif isinstance(ranks, list):\n            rank_list = ranks\n        else:\n            raise ValueError(\"ranks must be int or List[int].\")\n\n        my_rank = get_rank()\n        if -1 in rank_list or my_rank in rank_list:\n            self._debug = v_level >= 0\n            if self._debug:\n                self._debug_min_s_level = 0\n                self._debug_max_v_level = max(0, v_level)\n            for name, block in self._blocks.items():\n                assert block.to(GraphModule).type == GraphBlockType.MODULE\n                block.to(GraphModule).debug(\n                    v_level,\n                    ranks=ranks,\n                    max_py_stack_depth=max_py_stack_depth,\n                    only_user_py_stack=only_user_py_stack,\n                    op_repr_with_py_stack=op_repr_with_py_stack,\n                )\n\n        self._debug_max_py_stack_depth = max_py_stack_depth\n        self._debug_op_repr_with_py_stack = op_repr_with_py_stack\n        self._debug_only_user_py_stack = only_user_py_stack\n\n    def __repr__(self):\n        r\"\"\"For printing the graph structure.\n\n        The graph structure can be printed after graph instantiation.\n\n        After the first call of graph, inputs and outputs will be added to\n        the graph structure.\n\n        For example:\n\n
        .. code-block:: python\n\n            g = CustomGraph()\n            print(g)\n\n            out_tensors = g(input_tensors)\n            print(g) # Input and output info is added\n\n        \"\"\"\n        child_lines = []\n        child_lines.append(add_indent(repr(self.config), 2))\n        if len(self._args_repr) > 0:\n            for in_str in self._args_repr:\n                input_str = add_indent(in_str, 2)\n                child_lines.append(input_str)\n\n        if len(self._blocks) > 0:\n            for n, m in self._blocks.items():\n                mod_str = repr(m)\n                mod_str = add_indent(mod_str, 2)\n                child_lines.append(mod_str)\n\n        for op_str in self._ops_repr():\n            child_lines.append(add_indent(op_str, 2))\n\n        if len(self._outs_repr) > 0:\n            for out_str in self._outs_repr:\n                output_str = add_indent(out_str, 2)\n                child_lines.append(output_str)\n\n        main_str = self._shallow_repr() + \": (\"\n        if len(child_lines) > 0:\n            main_str += \"\\n  \" + \"\\n  \".join(child_lines) + \"\\n\"\n        main_str += \")\"\n        return main_str\n\n    def _shallow_repr(self):\n        shallow_repr = \"(GRAPH:\" + self._name + \":\" + self.__class__.__name__ + \")\"\n        return shallow_repr\n\n    def _ops_repr(self):\n        r\"\"\"Generate operators' string representation of this graph\n        \"\"\"\n        if self._compiled_graph_proto is not None:\n            module_conf = self._compiled_graph_proto.module_name2module_conf[self.name]\n            if self._oneflow_internal_graph_ir__ is None:\n                self._oneflow_internal_graph_ir__ = GraphIR(self._compiled_graph_proto)\n            return operators_repr(\n                module_conf.ops,\n                self._oneflow_internal_graph_ir__,\n                self._debug_op_repr_with_py_stack,\n            )\n\n        return []\n\n    def __print(self, s_level=2, v_level=0, msg=None):\n        r\"\"\"Do print according to info level.\"\"\"\n        assert isinstance(s_level, int)\n        assert isinstance(v_level, int)\n        assert isinstance(msg, str) or isinstance(msg, Callable)\n        if s_level >= self._debug_min_s_level:\n            if (s_level > 0) or (s_level == 0 and v_level <= self._debug_max_v_level):\n                if isinstance(msg, str):\n                    print(msg, flush=True)\n                elif isinstance(msg, Callable):\n                    print(msg(), flush=True)\n\n    def _print(self, s_level=2, v_level=0, msg=None):\n        self.__print(s_level, v_level, msg)\n\n    @property\n    def _config_proto(self):\n        return self.config.proto\n\n    @property\n    def _optimization_conf_proto(self):\n        return self._session.resource\n\n    @property\n    def _graph_proto(self):\n        if not self._is_compiled:\n            self.__print(\n                2,\n                0,\n                f\"[ERROR]{self._shallow_repr()} has not been compiled, so its graph proto is None.\"\n                \" You can call the graph to trigger its compilation.\",\n            )\n        return self._forward_job_proto\n\n    @property\n    def _full_graph_proto(self):\n        if self._full_job_proto is None:\n            self.__print(\n                2,\n                0,\n                f\"[ERROR]{self._shallow_repr()} has not been compiled, so its full graph proto is None.\"\n                \" You can call the graph to trigger its compilation.\",\n            )\n
        return self._full_job_proto\n\n    @_full_graph_proto.setter\n    def _full_graph_proto(self, full_job_proto):\n        assert (\n            not self._is_compiled\n        ), \"nn.Graph's full graph proto can only be set before the first compilation.\"\n        self._full_job_proto = full_job_proto\n        self._c_nn_graph.job = full_job_proto.SerializeToString()\n\n    @property\n    def _compiled_graph_proto(self):\n        if not self._is_compiled and self._compiled_job_proto is None:\n            self.__print(\n                2,\n                0,\n                f\"[ERROR]{self._shallow_repr()} has not been compiled, so its compiled graph proto is None.\"\n                \" You can call the graph to trigger its compilation.\",\n            )\n        return self._compiled_job_proto\n\n    def _generate_name(self):\n        child_name = self.__class__.__name__\n        if Graph._child_init_cnt.get(child_name) is None:\n            Graph._child_init_cnt[child_name] = 0\n        self._name = child_name + \"_\" + str(Graph._child_init_cnt[child_name])\n        Graph._child_init_cnt[child_name] += 1\n\n    def _state(self):\n        for _, b in self._blocks.items():\n            pa_gen = b.parameters(recurse=True)\n            for pa in pa_gen:\n                yield pa\n            bu_gen = b.buffers(recurse=True)\n            for bu in bu_gen:\n                yield bu\n\n    def __ensure_state_tensors_contiguous(self):\n        for state_block in self._state():\n            state_tensor = state_block.to(Tensor)\n            if not state_tensor.is_contiguous():\n                state_tensor.contiguous_()\n\n    def _filter_states(self):\n        state_tensor_set = set()\n        state_tensors = []\n        state_op_names = []\n\n        for state_block in self._state():\n            state_tensor = state_block.to(Tensor)\n            # If any state tensor is a global tensor, the graph is in global view.\n            if state_tensor.is_global:\n                self._is_global_view = True\n            if state_tensor in state_tensor_set:\n                continue\n            op_name = (\n                state_block.to(GraphTensor).name_prefix\n                + state_block.to(GraphTensor).name\n            )\n            state_tensor_set.add(state_tensor)\n            state_tensors.append(state_tensor)\n            state_op_names.append(op_name)\n\n            if state_block.to(GraphTensor).type == GraphBlockType.PARAMETER:\n                self._variables_conf[state_tensor] = VariableConfig(op_name)\n\n        self._state_tensor_tuple = convert_to_tensor_tuple(state_tensors)\n        self._eager_state_op_names = deepcopy(state_op_names)\n        return state_op_names\n\n    def _generate_config_proto(self):\n        self.config.proto.job_name = self._name\n        self._outputs_buffer_size = self.config._outputs_buffer_size\n\n        if self._grad_scaler is not None:\n            self._grad_scaler._generate_conf_for_graph(self.config.proto.train_conf)\n\n        for opt in self._opts:\n            opt_dict = OptDict(opt)\n            self.config._generate_optimizer_and_variable_configs(\n                opt_dict, self._variables_conf\n            )\n\n    def _create_states_builder(self):\n        state2lazy_builder = dict()\n        for state_block in self._state():\n            state_tensor = state_block.to(Tensor)\n            op_name = (\n                state_block.to(GraphTensor).name_prefix\n                + state_block.to(GraphTensor).name\n            )\n
            if state_tensor in state2lazy_builder:\n                # Different tensor blocks may share the same tensor, so they need to\n                # share the same builder.\n                state_block.set_lazy_origin_builder(state2lazy_builder[state_tensor])\n            else:\n                if state_block.to(GraphTensor).type == GraphBlockType.PARAMETER:\n                    assert state_tensor in self._variables_conf\n                    state_config = self._variables_conf[state_tensor]\n                    op_name = state_config.name\n                else:\n                    state_config = None\n                # Init a new lazy tensor builder\n                state_block.lazy_origin_builder().name = op_name\n                state_block.lazy_origin_builder().method = partial(\n                    graph_build_util.build_graph_state,\n                    op_name,\n                    state_tensor,\n                    state_config,\n                )\n                state2lazy_builder[state_tensor] = state_block.lazy_origin_builder()\n\n    def _mark_variable_gradients(self):\n        variable = []\n        gradients = []\n        for state_block in self._state():\n            if (\n                state_block.to(GraphTensor).type == GraphBlockType.PARAMETER\n                and state_block.to(Tensor).grad is not None\n                and state_block.to(Tensor).grad.is_lazy\n            ):\n                variable.append(state_block.to(Tensor))\n                gradients.append(state_block.to(Tensor).grad)\n        oneflow._oneflow_internal.nn.graph.MarkVariableGradients(variable, gradients)\n\n    @staticmethod\n    def trace(func):\n        \"\"\"Trace a function to build a static graph and run it with nn.Graph.\n\n        After decorating a function with ``trace``, the function is turned into a naive `nn.Graph`.\n\n        Note:\n            This is just a quick way to run a simple function with nn.Graph.\n            If you want to do training or model save/load, customize a nn.Graph class instead, do not use ``trace``.\n\n        For example:\n\n        .. code-block:: python\n\n            >>> import oneflow as flow\n            >>> @flow.nn.Graph.trace\n            ... def test_func(x):\n
            ...     return x * 2\n            >>> input = flow.tensor((1, 2), dtype=flow.float32)\n            >>> out = test_func(input)\n            >>> out\n            tensor([2., 4.], dtype=oneflow.float32)\n\n        ..\n            Feature Stage of Feature [trace].\n            - Maintainer List [@strint]\n            - Current Stage [Pre-alpha, note that this is an experimental feature and may be removed without notice.]\n\n        \"\"\"\n        assert inspect.isfunction(\n            func\n        ), f\"nn.Graph.trace only supports functions currently, so {func} must be a function.\"\n        graph_cls_name = func.__name__ + \"_graph\"\n\n        def init(self):\n            super(graph_cls_name, self).__init__()\n\n        def build(self, *args, **kwargs):\n            return func(*args, **kwargs)\n\n        graph_cls_name = type(\n            graph_cls_name, (Graph,), {\"__init__\": init, \"build\": build,},\n        )\n\n        a_graph = graph_cls_name()\n\n        return a_graph\n\n    def _compile(self, *args, **kwargs):\n        if self._run_with_cache:\n            return self._dynamic_input_graph_cache._compile(*args, **kwargs)\n\n        if not self._is_compiled:\n            if not self._build_with_shared_graph:\n                return self._compile_new(*args, **kwargs)\n            else:\n                return self._compile_from_shared(*args, **kwargs)\n        else:\n            warnings.warn(\n                f\"{self._shallow_repr()} has been compiled, no need to compile again.\"\n            )\n            return\n\n    def _compile_new(self, *args, **kwargs):\n        if (\n            len(args) != 0\n            and isinstance(args, (tuple, list))\n            and len(kwargs) == 0\n            and all(isinstance(arg, Tensor) for arg in args)\n        ):\n            self._is_simple_tuple_input = True\n\n        self.__ensure_input_tensors_contiguous(*args, **kwargs)\n        _, eager_outputs = self.build_graph(*args, **kwargs)\n        if isinstance(eager_outputs, (tuple, list)) and all(\n            isinstance(arg, Tensor) for arg in eager_outputs\n        ):\n            self._is_simple_tuple_output = True\n        self.finish_compile_and_init_runtime()\n        return eager_outputs\n\n    def enable_shared(self, mode: bool = True):\n        if mode:\n            assert (\n                not self._is_compiled\n            ), \"enable_shared must be set before graph compile.\"\n            # If sharing is enabled, graph compile will generate more data for sharing.\n            self._enable_shared_from_this = True\n        else:\n            self._enable_shared_from_this = False\n\n    def share_from(self, shared_graph: \"Graph\") -> None:\n        assert isinstance(\n            shared_graph, Graph\n        ), \"shared_graph must be an instance of nn.Graph.\"\n        assert (\n            shared_graph._enable_shared_from_this\n        ), \"shared_graph must have been enabled to be shared.\"\n        assert shared_graph._is_compiled, \"shared_graph must have been compiled.\"\n        self._shared_graph = shared_graph\n        self._enable_shared_from_this = False\n        self._build_with_shared_graph = True\n
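\n    # A typical sharing flow (a sketch; MyGraph, x0 and x1 are illustrative names):\n    #\n    #   g0 = MyGraph(); g0.enable_shared(); g0(x0)  # compile the base graph first\n    #   g1 = MyGraph(); g1.share_from(g0); g1(x1)   # reuse g0's compiled graph for a new input shape\n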
in graph\n        state_op_names = self._filter_states()\n        # Generate new config.\n        if self._shared_graph._is_from_runtime_state_dict:\n            # Avoid using the same graph name as the loaded graphs.\n            self._name = (\n                self._name + \"_of_shared_from_loaded_\" + self._shared_graph.name\n            )\n        self._generate_config_proto()\n        # Deal with parameter and buffer\n        self._create_states_builder()\n\n        # Build current forward graph to generate some new attributes of this graph.\n        with graph_build_util.graph_build_context(self.config.proto, self._session):\n            self._job_id = (\n                oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobId()\n            )\n            # Deal with inputs\n            (input_op_names, lazy_args, lazy_kwargs, args_repr, _,) = self.__build_io(\n                \"input\", graph_build_util.build_graph_input_arg, *args, **kwargs\n            )\n            # Deal with module in self.build(*args)\n            self._is_user_mode = True\n            outputs = self.build(*lazy_args, **lazy_kwargs)\n            self._is_user_mode = False\n\n            # Always pack outputs to preserve the type of outputs\n            outputs = (outputs,)\n            (\n                output_op_names,\n                build_eager_outputs,\n                _,  # empty kwargs return\n                outs_repr,\n                out2name,\n            ) = self.__build_io(\"output\", graph_build_util.build_graph_output, *outputs)\n\n            # Save forward graph job proto\n            self._forward_job_proto = c_api_util.GetCurrentJob()\n\n        # Create op name vectors from shared graph and this graph.\n        assert len(self._forward_job_proto.net.op) == len(\n            self._shared_graph._forward_job_proto.net.op\n        )\n        # This graph and the shared graph's original graph have the same operators and operator order.\n        # We use this to find the corresponding operator in the shared graph.\n        shared_op_names_from_ordered_original_graph = []\n        for op_idx in range(len(self._forward_job_proto.net.op)):\n            shared_op_names_from_ordered_original_graph.append(\n                self._shared_graph._forward_job_proto.net.op[op_idx].name\n            )\n\n        # Copy the completed graph from the shared graph and reuse it.\n        self._compiled_job_proto = deepcopy(self._shared_graph._compiled_graph_proto)\n        self._compiled_job_proto.job_conf.job_name = self._name\n        # Create a c nn graph to run with lazy runtime.\n        self._c_nn_graph = oneflow._oneflow_internal.nn.graph.CNNGraph(\n            self._name,\n            self._compiled_job_proto.SerializeToString(),\n            self._job_id,\n            self._session._session_ctx,\n        )\n\n        # Build graph with new inputs from a compiled job of a shared graph.\n        inputs_tensor_tuple = convert_to_tensor_tuple(\n            self.__flatten_io(\"input\", *args, **kwargs)\n        )\n        input_op_names = self._shared_graph._input_op_names\n        self._c_nn_graph.build_with_new_input_from_shared_graph(\n            input_op_names,\n            inputs_tensor_tuple,\n            shared_op_names_from_ordered_original_graph,\n            self._forward_job_proto.SerializeToString(),\n        )\n        # Get new compiled job proto\n        compiled_job_str = self._c_nn_graph.get_current_job_str()\n        self._compiled_job_proto = job_pb.Job()\n        
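# build_with_new_input_from_shared_graph above may have rewritten the copied job (e.g.\n        # input/output shapes), so parse the serialized job back to get the effective proto.\n        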
self._compiled_job_proto.ParseFromString(compiled_job_str)\n\n        # Build output tensor buffer with new shape from the new compiled job proto.\n        self.__rebuild_outputs(\n            self._shared_graph._out2name,\n            self._compiled_job_proto,\n            self._shared_graph._build_eager_outputs,\n        )\n\n        # Register output/variable/buffer to _c_nn_graph\n        output_op_names = self._shared_graph._output_op_names\n        self._c_nn_graph.register_output_op_names_and_tensors(\n            output_op_names, self._outputs_tensor_tuple\n        )\n        self._state_tensor_tuple = self._shared_graph._state_tensor_tuple\n        self._c_nn_graph.register_variable_op_names_and_tensors(\n            self._shared_graph._state_op_names, self._state_tensor_tuple\n        )\n\n        self.__prepare_for_share_or_runtime_save(\n            input_op_names,\n            inputs_tensor_tuple,\n            output_op_names,\n            build_eager_outputs,\n            out2name,\n            *args,\n            **kwargs,\n        )\n\n        # Init runtime.\n        # TODO(strint): aligning states needs to take free eager tensors into account.\n        self._c_nn_graph.align_states_after_logical_graph_compile()\n        self._c_nn_graph.compile_plan_for_runtime()\n        self._c_nn_graph.init_runtime()\n        self._is_compiled = True\n        build_graph_end = time.perf_counter()\n        self.__print(\n            0,\n            0,\n            self._shallow_repr()\n            + \" building a shared graph and plan Done! Cost time: \"\n            + str(round(build_graph_end - build_graph_start, 2))\n            + \"s.\"\n            + \"\\n\",\n        )\n\n        return (seq_to_func_return(self._eager_outputs_buffer[0], True),)\n\n    def enable_save_runtime_state_dict(self, mode: bool = True):\n        if mode:\n            assert (\n                not self._is_compiled\n            ), \"enable_save_runtime_state_dict must be set before graph compile.\"\n            # If saving runtime states is enabled, graph compile will generate more data for saving.\n            self._enable_save_runtime_state_dict = True\n        else:\n            self._enable_save_runtime_state_dict = False\n\n    def runtime_state_dict(\n        self, destination=None, with_eager=False\n    ) -> Union[\n        Dict[str, Union[Dict[str, Tensor], str]],\n        Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],\n    ]:\n        if self._run_with_cache:\n            return self._dynamic_input_graph_cache.runtime_state_dict(\n                with_eager=with_eager\n            )\n\n        assert (\n            self._enable_save_runtime_state_dict\n        ), \"nn.Graph's runtime state dict can only be retrieved when enable_save_runtime_state_dict is set to True.\"\n        assert (\n            self._is_compiled\n        ), \"nn.Graph's runtime state dict can only be retrieved after the first call of the graph.\"\n\n        # Sync to make sure states have been updated.\n        oneflow._oneflow_internal.eager.Sync()\n        if destination is None:\n            destination = OrderedDict()\n            destination._metadata = OrderedDict()\n\n        destination[\"oneflow_version\"] = oneflow.__version__\n        destination[\"graph_name\"] = self.name\n        destination[\"job_id\"] = self._job_id\n\n        def _fill_sub_destination(dest_dict, name_list, tensor_tuple):\n            assert len(tensor_tuple) == len(name_list)\n            for name_idx in range(len(name_list)):\n                tensor_item = 
tensor_tuple[name_idx]\n                device_str = \":\".join(\n                    (tensor_item.device.type, str(tensor_item.device.index))\n                )\n                dest_dict[name_list[name_idx]] = (tensor_item, device_str)\n\n        # The original structure of the inputs/outputs is recorded with tuple indexes,\n        # so it can be used to rebuild the output buffer when loading.\n        tuple_idx = -1\n\n        def gen_index_in_tuple(item):\n            nonlocal tuple_idx\n            if isinstance(item, Tensor):\n                tuple_idx += 1\n                return \"_OFTPI\" + str(tuple_idx)  # oneflow tuple index\n            else:\n                return item\n\n        inputs_sub_destination = OrderedDict()\n        _fill_sub_destination(\n            inputs_sub_destination, self._input_op_names, self._inputs_tensor_tuple\n        )\n\n        _eager_inputs_args, _eager_inputs_kwargs = self.__map_io_lite(\n            gen_index_in_tuple, *self.inputs_original[0], **self.inputs_original[1],\n        )\n        destination[\"inputs\"] = inputs_sub_destination\n        destination[\"inputs_original\"] = (_eager_inputs_args, _eager_inputs_kwargs)\n\n        tuple_idx = -1\n        _eager_outputs, _ = self.__map_io_lite(gen_index_in_tuple, *self._eager_outputs)\n        destination[\"outputs_original\"] = _eager_outputs\n        assert len(self._outputs_tensor_tuple) == tuple_idx + 1\n        outputs_sub_destination = OrderedDict()\n        _fill_sub_destination(\n            outputs_sub_destination, self._output_op_names, self._outputs_tensor_tuple\n        )\n        destination[\"outputs\"] = outputs_sub_destination\n\n        destination[\"oneflow_with_eager_tensor\"] = with_eager\n        if not self._build_with_shared_graph:\n            _state_tensor_tuple4save = []\n            if with_eager:\n                _state_tensor_tuple4save = self._state_tensor_tuple\n            else:\n                assert len(self._state_tensor_tuple) == len(self._state_op_names)\n                for state_idx in range(len(self._state_tensor_tuple)):\n                    if self._state_op_names[state_idx] in self._eager_state_op_names:\n                        # This state tensor is from eager module. 
Just save a dummy tensor here.\n                        _state_tensor_tuple4save.append(\n                            oneflow.Tensor().to(\n                                self._state_tensor_tuple[state_idx].device\n                            )\n                        )\n                    else:\n                        _state_tensor_tuple4save.append(\n                            self._state_tensor_tuple[state_idx]\n                        )\n            states_sub_destination = OrderedDict()\n            _fill_sub_destination(\n                states_sub_destination, self._state_op_names, _state_tensor_tuple4save\n            )\n            destination[\"states\"] = states_sub_destination\n\n        destination[\"exe_plan\"] = self._c_nn_graph.plan\n        if self._enable_shared_from_this:\n            destination[\"forward_graph\"] = self._forward_job_proto\n            destination[\"compile_graph\"] = self._compiled_job_proto\n\n        destination[\"id_state\"] = oneflow._oneflow_internal.get_id_state()\n\n        return destination\n\n    def load_runtime_state_dict(\n        self,\n        state_dict: Union[\n            Dict[str, Union[Dict[str, Tensor], str]],\n            Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],\n        ],\n        *,\n        warmup_with_run: bool = True,\n    ) -> None:\n        if self._run_with_cache:\n            return self._dynamic_input_graph_cache.load_runtime_state_dict(\n                state_dict, warmup_with_run=warmup_with_run\n            )\n\n        build_graph_start = time.perf_counter()\n\n        # init id state\n        oneflow._oneflow_internal.set_id_state(state_dict[\"id_state\"])\n\n        self._is_from_runtime_state_dict = True\n        self._name = state_dict[\"graph_name\"]\n        if \"oneflow_version\" not in state_dict:\n            state_dict[\"oneflow_version\"] = \"none\"\n        if state_dict[\"oneflow_version\"] != oneflow.__version__:\n            warnings.warn(\n                f\"nn.Graph {self._name} WARNING: current oneflow version ({oneflow.__version__}) is loading \"\n                f\"runtime_state_dict from a different version ({state_dict['oneflow_version']}), \"\n                \"so there may be compatibility problems.\"\n            )\n        # Generate new config.\n        self._generate_config_proto()\n        self.__print(0, 0, self._shallow_repr() + \" start loading a graph and plan.\")\n        self._job_id = state_dict[\"job_id\"]\n\n        # Create a c nn graph to run with lazy runtime.\n        self._c_nn_graph = oneflow._oneflow_internal.nn.graph.CNNGraph(\n            self._name,\n            state_dict[\"exe_plan\"],\n            self._job_id,\n            self._session._session_ctx,\n            True,  # Init from plan\n        )\n\n        def _load_list_from_state_dict(state_dict):\n            name_list = []\n            tensor_list = []\n            for name, item in state_dict.items():\n                name_list.append(name)\n                tensor_of_item, device_of_item = item\n                tensor_list.append(tensor_of_item.to(device_of_item))\n            return (name_list, convert_to_tensor_tuple(tensor_list))\n\n        self._input_op_names, self._inputs_tensor_tuple = _load_list_from_state_dict(\n            state_dict[\"inputs\"]\n        )\n        self._output_op_names, self._outputs_tensor_tuple = _load_list_from_state_dict(\n            state_dict[\"outputs\"]\n        )\n        _eager_inputs_args_index, _eager_inputs_kwargs_index = state_dict[\n            
\"inputs_original\"\n        ]\n        _eager_outputs_index = state_dict[\"outputs_original\"]\n\n        def get_tensor_in_tuple(tensor_tuple, map_item):\n            if isinstance(map_item, str) and map_item.startswith(\"_OFTPI\"):\n                of_idx = int(map_item[6:])\n                return tensor_tuple[of_idx]\n            else:\n                return map_item\n\n        _eager_inputs_args, _eager_inputs_kwargs = self.__map_io_lite(\n            lambda map_item: get_tensor_in_tuple(self._inputs_tensor_tuple, map_item),\n            *_eager_inputs_args_index,\n            **_eager_inputs_kwargs_index,\n        )\n        _eager_outputs, _ = self.__map_io_lite(\n            lambda map_item: get_tensor_in_tuple(self._outputs_tensor_tuple, map_item),\n            *_eager_outputs_index,\n        )\n        self._eager_outputs = _eager_outputs\n\n        # The base graph need extra info to create new shared graph\n        if self._enable_shared_from_this:\n            self._forward_job_proto = state_dict[\"forward_graph\"]\n            self._compiled_job_proto = state_dict[\"compile_graph\"]\n            self._build_eager_outputs = self._eager_outputs\n            self._out2name = dict()\n            for output_idx in range(len(self._output_op_names)):\n                self._out2name[\n                    self._outputs_tensor_tuple[output_idx]\n                ] = self._output_op_names[output_idx]\n\n        # Load state tensor of modules\n        if \"oneflow_with_eager_tensor\" in state_dict:\n            with_eager = state_dict[\"oneflow_with_eager_tensor\"]\n        else:\n            with_eager = True\n\n        if self._build_with_shared_graph:\n            self._state_op_names = self._shared_graph._state_op_names\n            self._state_tensor_tuple = self._shared_graph._state_tensor_tuple\n        else:\n            self._state_op_names, self._state_tensor_tuple = _load_list_from_state_dict(\n                state_dict[\"states\"]\n            )\n            if type(self) != Graph:\n                # Graph init with eager module, try to share mem with eager module\n                states_from_eager = dict()\n                for state_block in self._state():\n                    state_tensor = state_block.to(Tensor)\n                    state_op_name = (\n                        state_block.to(GraphTensor).name_prefix\n                        + state_block.to(GraphTensor).name\n                    )\n                    states_from_eager[state_op_name] = state_tensor\n                for s_idx, s_name in enumerate(self._state_op_names):\n                    if s_name in states_from_eager:\n                        state_tensor_from_eager = states_from_eager[s_name]\n                        assert (\n                            state_tensor_from_eager.device\n                            == self._state_tensor_tuple[s_idx].device\n                        )\n                        if with_eager:\n                            assert oneflow.allclose(\n                                state_tensor_from_eager, self._state_tensor_tuple[s_idx]\n                            )\n                        self._state_tensor_tuple[s_idx] = state_tensor_from_eager\n                if not with_eager:\n                    for s_idx, s_name in enumerate(self._state_op_names):\n                        if (oneflow.numel(self._state_tensor_tuple[s_idx]) == 0) and (\n                            s_name not in states_from_eager\n                        ):\n                            warnings.warn(\n    
                            f\"Current graph is missing parameter {s_name}, but load_runtime_state_dict needs it. This may cause error later.\"\n                            )\n\n        self.__build_outputs_buffer()\n\n        self._c_nn_graph.register_input_op_names_and_tensors(\n            self._input_op_names, self._inputs_tensor_tuple\n        )\n        self._c_nn_graph.register_output_op_names_and_tensors(\n            self._output_op_names, self._outputs_tensor_tuple\n        )\n        self._c_nn_graph.register_variable_op_names_and_tensors(\n            self._state_op_names, self._state_tensor_tuple\n        )\n        self._c_nn_graph.align_states_after_logical_graph_compile()\n        self._c_nn_graph.init_runtime()\n        self._is_compiled = True\n        if warmup_with_run:\n            self.__run(\n                *_eager_inputs_args, **_eager_inputs_kwargs\n            )  # pre-run to warm up\n            oneflow._oneflow_internal.eager.Sync()\n        build_graph_end = time.perf_counter()\n        self.__print(\n            0,\n            0,\n            self._shallow_repr()\n            + \" load a graph and plan Done! Cost time: \"\n            + str(round(build_graph_end - build_graph_start, 2))\n            + \"s.\"\n            + \"\\n\",\n        )\n\n    @staticmethod\n    def runtime_state_dict_to(\n        state_dict: Union[\n            Dict[str, Union[Dict[str, Tensor], str]],\n            Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],\n        ],\n        device: str,\n    ) -> Union[\n        Dict[str, Union[Dict[str, Tensor], str]],\n        Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],\n    ]:\n        if \"job_id\" not in state_dict:\n            from oneflow.nn.graph.cache import GraphCache\n\n            return GraphCache.runtime_state_dict_to(state_dict, device)\n\n        dest_device = oneflow.device(device)\n        assert dest_device.type == \"cuda\", \"device must be cuda.\"\n\n        destination = OrderedDict()\n        destination._metadata = OrderedDict()\n        destination[\"oneflow_version\"] = state_dict[\"oneflow_version\"]\n        destination[\"graph_name\"] = state_dict[\"graph_name\"]\n        destination[\"job_id\"] = state_dict[\"job_id\"]\n        destination[\"inputs\"] = _rsd_sub_destination_to(state_dict[\"inputs\"], device)\n        destination[\"inputs_original\"] = state_dict[\"inputs_original\"]\n        destination[\"outputs\"] = _rsd_sub_destination_to(state_dict[\"outputs\"], device)\n        destination[\"outputs_original\"] = state_dict[\"outputs_original\"]\n        destination[\"oneflow_with_eager_tensor\"] = state_dict[\n            \"oneflow_with_eager_tensor\"\n        ]\n        if \"states\" in state_dict:\n            destination[\"states\"] = _rsd_sub_destination_to(\n                state_dict[\"states\"], device\n            )\n        destination[\"exe_plan\"] = _plan_to(state_dict[\"exe_plan\"], dest_device)\n        if \"forward_graph\" in state_dict:\n            forward_graph = deepcopy(state_dict[\"forward_graph\"])\n            _job_to(forward_graph, dest_device)\n            destination[\"forward_graph\"] = forward_graph\n        if \"compile_graph\" in state_dict:\n            compile_graph = deepcopy(state_dict[\"compile_graph\"])\n            _job_to(compile_graph, dest_device)\n            destination[\"compile_graph\"] = compile_graph\n        destination[\"id_state\"] = state_dict[\"id_state\"]\n        return destination\n\n    def build_graph(self, *args, **kwargs):\n        # 
Build graph\n        try:\n            self.__print(0, 0, self._shallow_repr() + \" start building graph.\")\n            assert not self._is_compiled, (\n                \"nn.Graph \" + self._name + \" has already been compiled.\"\n            )\n            build_graph_start = time.perf_counter()\n            with graph_build_util.DebugScopeContext(\n                self._debug_min_s_level,\n                self._debug_max_v_level,\n                self._debug,\n                self._debug_max_py_stack_depth,\n                self._debug_only_user_py_stack,\n            ):\n                outputs = self.__build_graph(*args, **kwargs)\n            build_graph_end = time.perf_counter()\n            self.__print(\n                0,\n                0,\n                self._shallow_repr()\n                + \" building graph Done! Cost time: \"\n                + str(round(build_graph_end - build_graph_start, 2))\n                + \"s.\"\n                + \"\\n\",\n            )\n            return outputs\n        except:\n            self.__print(\n                2, 0, \"[ERROR]\" + self._shallow_repr() + \" building graph got error.\"\n            )\n            raise\n\n    def finish_compile_and_init_runtime(self):\n        additional_var_names = list()\n        additional_var_tensors = list()\n        for name, tensor in self._additional_variable_tobe_loaded.items():\n            additional_var_names.append(name)\n            additional_var_tensors.append(tensor)\n        if len(additional_var_names) > 0:\n            self._c_nn_graph.register_additional_variable_names_and_tensors(\n                additional_var_names, convert_to_tensor_tuple(additional_var_tensors)\n            )\n        # Sync to make sure states have been loaded.\n        oneflow._oneflow_internal.eager.Sync()\n\n        # Compile graph to execution plan and init Runtime\n        try:\n            self.__print(\n                0, 0, self._shallow_repr() + \" start building plan.\",\n            )\n            compile_and_init_start = time.perf_counter()\n            with graph_build_util.DebugScopeContext(\n                self._debug_min_s_level,\n                self._debug_max_v_level,\n                self._debug,\n                self._debug_max_py_stack_depth,\n                self._debug_only_user_py_stack,\n            ):\n                self._c_nn_graph.align_states_after_logical_graph_compile()\n                self._c_nn_graph.complete_graph_for_runtime()\n                # Get compiled job\n                compiled_job_str = self._c_nn_graph.get_current_job_str()\n                self._compiled_job_proto = job_pb.Job()\n                self._compiled_job_proto.ParseFromString(compiled_job_str)\n                self.__print(\n                    0, 1, lambda: f\"{self.name} with operators:\\n\" + self.__repr__()\n                )\n                self._c_nn_graph.compile_plan_for_runtime()\n                self._c_nn_graph.init_runtime()\n\n            compile_and_init_end = time.perf_counter()\n            self.__print(\n                0,\n                0,\n                self._shallow_repr()\n                + \" building plan Done! 
Cost time: \"\n                + str(round(compile_and_init_end - compile_and_init_start, 2))\n                + \"s.\"\n                + \"\\n\",\n            )\n        except Exception as e:\n            print(e, file=sys.stderr)\n            self.__print(\n                2, 0, \"[ERROR]\" + self._shallow_repr() + \" building plan got error.\"\n            )\n            raise\n\n        self._is_compiled = True\n        # After compile, _additional_variable_tobe_loaded is useless.\n        self._additional_variable_tobe_loaded.clear()\n\n    def __build_graph(self, *args, **kwargs):\n        self.__ensure_state_tensors_contiguous()\n\n        # Filter to get unique states in graph\n        state_op_names = self._filter_states()\n\n        self._generate_config_proto()\n\n        # Deal with parameter and buffer\n        self.__print(\n            0,\n            1,\n            self._shallow_repr()\n            + \" start building graph builders of parameters and buffers.\",\n        )\n        self._create_states_builder()\n        self.__print(\n            0,\n            1,\n            self._shallow_repr()\n            + \" end building graph builders of parameters and buffers.\",\n        )\n\n        with graph_build_util.graph_build_context(self.config.proto, self._session):\n            # Deal with inputs\n            self.__print(0, 1, self._shallow_repr() + \" start building graph inputs.\")\n            (\n                input_op_names,\n                lazy_args,\n                lazy_kwargs,\n                self._args_repr,\n                _,\n            ) = self.__build_io(\n                \"input\", graph_build_util.build_graph_input_arg, *args, **kwargs\n            )\n            self.__print(0, 1, self._shallow_repr() + \" end building graph inputs.\")\n\n            # Deal with module in self.build(*args)\n            self.__print(0, 1, self._shallow_repr() + \" start building graph modules.\")\n            self._is_user_mode = True\n            outputs = self.build(*lazy_args, **lazy_kwargs)\n            self._is_user_mode = False\n            self.__print(0, 1, self._shallow_repr() + \" end building graph modules.\")\n\n            # Deal with outputs\n            self.__print(0, 1, self._shallow_repr() + \" start building graph outputs.\")\n            # Always pack output to remain type of outputs\n            outputs = (outputs,)\n\n            (\n                output_op_names,\n                build_eager_outputs,\n                _,  # empty kwargs return\n                self._outs_repr,\n                out2name,\n            ) = self.__build_io(\"output\", graph_build_util.build_graph_output, *outputs)\n\n            self.__print(0, 1, self._shallow_repr() + \" end building graph outputs.\")\n\n            # Save forward graph job proto\n            self._forward_job_proto = c_api_util.GetCurrentJob()\n\n            if self.training:\n                self._mark_variable_gradients()\n\n            self.__print(\n                0,\n                1,\n                self._shallow_repr() + \" start building graph with compile passes.\",\n            )\n            self.env_enable_mlir_inference_opt = os.getenv(\n                \"ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION\"\n            )\n            enable_mlir_inference_opt = (\n                False\n                if self.env_enable_mlir_inference_opt is None\n                else bool(self.env_enable_mlir_inference_opt)\n            )\n            modules_has_training = False\n            
for item in self._blocks.values():\n                if item.to(Module).training:\n                    modules_has_training = True\n                    break\n            if (\n                modules_has_training or self.training or self._is_global_view\n            ) and enable_mlir_inference_opt:\n                log_for_mlir_inference_opt = lambda extra_info: logging.warning(\n                    f\"environment variable ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION will be ignored {extra_info}.\"\n                )\n                if self.training:\n                    log_for_mlir_inference_opt(\"in training mode\")\n\n                if modules_has_training and not self.training:\n                    log_for_mlir_inference_opt(\n                        \"when not all modules in graph are in eval mode\"\n                    )\n\n                if self._is_global_view:\n                    log_for_mlir_inference_opt(\"in global mode\")\n                enable_mlir_inference_opt = False\n                del os.environ[\"ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION\"]\n            oneflow._oneflow_internal.FillVariableTensorMgr(\n                state_op_names, self._state_tensor_tuple\n            )\n            # Optimize the graph with compile passes.\n            oneflow._oneflow_internal.CurJobBuildAndInferCtx_Complete()\n            # Save full graph job proto after job Complete, to find the real output blob shapes and build them.\n            self._full_job_proto = c_api_util.GetCurrentJob()\n            self._job_id = (\n                oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobId()\n            )\n            self.__print(\n                0, 1, self._shallow_repr() + \" end building graph with compile passes.\"\n            )\n\n            # Re-build outputs according to full graph and outputs buffer config.\n            self.__print(\n                0,\n                1,\n                self._shallow_repr()\n                + \" start re-building graph outputs for optimization.\",\n            )\n            self.__rebuild_outputs(out2name, self._full_job_proto, build_eager_outputs)\n            self.__print(\n                0,\n                1,\n                self._shallow_repr()\n                + \" end re-building graph outputs for optimization.\",\n            )\n            # Create a c nn graph to run with lazy runtime.\n            self._c_nn_graph = oneflow._oneflow_internal.nn.graph.CNNGraph(\n                self._name,\n                self._full_job_proto.SerializeToString(),\n                self._job_id,\n                self._session._session_ctx,\n            )\n            # Register input/output/variable/buffer to _c_nn_graph\n            inputs_tensor_tuple = convert_to_tensor_tuple(\n                self.__flatten_io(\"input\", *args, **kwargs)\n            )\n            self._c_nn_graph.register_input_op_names_and_tensors(\n                input_op_names, inputs_tensor_tuple\n            )\n            self._c_nn_graph.register_output_op_names_and_tensors(\n                output_op_names, self._outputs_tensor_tuple\n            )\n            (\n                self._state_op_names,\n                state_tensors,\n            ) = oneflow._oneflow_internal.DumpVariableTensorMgr()\n            self._state_tensor_tuple = convert_to_tensor_tuple(state_tensors)\n\n            self._c_nn_graph.register_variable_op_names_and_tensors(\n                self._state_op_names, self._state_tensor_tuple\n            )\n\n            
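# Stash IO metadata for graph sharing (share_from) or runtime_state_dict saving;\n            # this only records data when one of those features was enabled before compile.\n            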
self.__prepare_for_share_or_runtime_save(\n                input_op_names,\n                inputs_tensor_tuple,\n                output_op_names,\n                build_eager_outputs,\n                out2name,\n                *args,\n                **kwargs,\n            )\n\n        # Clear useless dict used in graph build.\n        self._unique_global_op_dict.clear()\n        self._unique_identity_op_dict.clear()\n\n        # Always pack outputs to remain type of outputs\n        return (\n            self._full_job_proto,\n            seq_to_func_return(self._eager_outputs_buffer[0], True),\n        )\n\n    def __prepare_for_share_or_runtime_save(\n        self,\n        input_op_names,\n        inputs_tensor_tuple,\n        output_op_names,\n        build_eager_outputs,\n        out2name,\n        *args,\n        **kwargs,\n    ):\n        if self._enable_save_runtime_state_dict or self._enable_shared_from_this:\n            self._input_op_names = input_op_names\n            self._output_op_names = output_op_names\n\n        if self._enable_shared_from_this:\n            self._build_eager_outputs = build_eager_outputs\n            self._out2name = out2name\n\n        if self._enable_save_runtime_state_dict:\n            self._inputs_tensor_tuple = inputs_tensor_tuple\n            self.inputs_original = (args, kwargs)\n\n    def __rebuild_outputs(\n        self, out2name=None, compiled_graph_proto=None, build_eager_outputs=None\n    ):\n        # NOTE(chengcheng):\n        #   Lazy build output eager tensors.\n        #\n        #   After JobBuildAndInferCtxt.Complete, the output tensor shape\n        #   could be changed by JobPass, such as GradientAccumulationRewritePass.\n        def build_real_output(fake_eager_out):\n            lbn = out2name[fake_eager_out] + \"/out\"\n            assert lbn in compiled_graph_proto.helper.lbn2logical_blob_desc\n            blob_conf = compiled_graph_proto.helper.lbn2logical_blob_desc[lbn]\n\n            shape = tuple(blob_conf.shape.dim)\n            dtype = fake_eager_out.dtype\n\n            with oneflow._oneflow_internal.lazy_mode.guard(False):\n                if fake_eager_out.is_global:\n                    eager_out = oneflow.empty(\n                        shape,\n                        dtype=dtype,\n                        placement=fake_eager_out.placement,\n                        sbp=fake_eager_out.sbp,\n                    )\n                else:\n                    eager_out = oneflow.empty(\n                        shape, dtype=dtype, device=fake_eager_out.device\n                    )\n\n            return eager_out\n\n        self._eager_outputs, _ = self.__map_io(\n            \"output\", build_real_output, *build_eager_outputs\n        )\n\n        self.__build_outputs_buffer()\n\n    def __build_outputs_buffer(self):\n        def convert_to_synced_tensor_tuple(*args):\n            tensor_tuple = convert_to_tensor_tuple(*args)\n            # tensors acting as buffer should be synced once upon created.\n            oneflow._oneflow_internal.nn.graph.SoftSyncNNGraphBuffers(\n                tensor_tuple, self._c_nn_graph\n            )\n            return tensor_tuple\n\n        self._outputs_tensor_tuple = convert_to_synced_tensor_tuple(\n            self.__flatten_io(\"output\", *self._eager_outputs)\n        )\n        self._eager_outputs_buffer = [\n            self._eager_outputs,\n        ]\n        self._outputs_tensor_tuple_buffer = [\n            self._outputs_tensor_tuple,\n        ]\n\n        # Make outputs 
buffer\n        for i in range(self._outputs_buffer_size - 1):\n            outputs_buffer_item, _ = self.__empty_like_io(\n                \"output\", *self._eager_outputs\n            )\n            self._eager_outputs_buffer.append(outputs_buffer_item)\n            outputs_tensor_tuple_buffer_item = convert_to_synced_tensor_tuple(\n                self.__flatten_io(\"output\", *outputs_buffer_item)\n            )\n            self._outputs_tensor_tuple_buffer.append(outputs_tensor_tuple_buffer_item)\n\n        self.__check_outputs_buffer()\n\n    def __check_outputs_buffer(self):\n        has_len = len(self._outputs_tensor_tuple_buffer)\n        assert (\n            has_len == self._outputs_buffer_size\n        ), f\"nn.Graph's outputs buffer size {has_len} does not match the set value {self._outputs_buffer_size}.\"\n        # Check that there are no duplicated output buffer tensors.\n        out_id_dic = dict()\n\n        def check_id_and_add(t, name):\n            if t is not None:\n                tid = id(t)\n                assert (\n                    tid not in out_id_dic\n                ), f\"nn.Graph's outputs buffer got a conflicting tensor id {tid}, new item name {name}, old item name {out_id_dic[tid]}.\"\n                out_id_dic[tid] = name\n\n        for b_idx, buffer in enumerate(self._outputs_tensor_tuple_buffer):\n            for i_idx, item in enumerate(buffer):\n                check_id_and_add(\n                    item, \"graph_outputs_buffer_\" + str(b_idx) + \"_\" + str(i_idx)\n                )\n\n    def __run(self, *args, **kwargs):\n        try:\n            flattened_eager_args = self.__ensure_input_tensors_contiguous_and_flatten(\n                *args, **kwargs\n            )\n            if oneflow.support.env_var_util.parse_boolean_from_env(\n                \"ONEFLOW_RUN_GRAPH_BY_VM\", False\n            ):\n                eager_outputs = oneflow._oneflow_internal.nn.graph.RunLazyNNGraphByVM(\n                    convert_to_tensor_tuple(flattened_eager_args), self._c_nn_graph,\n                )\n                if len(eager_outputs) == 1:\n                    return eager_outputs[0]\n                else:\n                    return eager_outputs\n            else:\n                outputs_tensor_tuple = self._outputs_tensor_tuple_buffer[\n                    self._cur_index_of_ouputs_buffer\n                ]\n                eager_outputs = self._eager_outputs_buffer[\n                    self._cur_index_of_ouputs_buffer\n                ]\n                # oneflow._oneflow_internal.eager.Sync() NOTE(chengcheng): Need Sync?\n                oneflow._oneflow_internal.nn.graph.RunLazyNNGraph(\n                    convert_to_tensor_tuple(flattened_eager_args),\n                    outputs_tensor_tuple,\n                    self._c_nn_graph,\n                )\n                # Update outputs buffer reading index\n                self._cur_index_of_ouputs_buffer += 1\n                if self._cur_index_of_ouputs_buffer >= self._outputs_buffer_size:\n                    self._cur_index_of_ouputs_buffer = 0\n\n                # Copy outputs from buffer\n                eager_outputs, _ = self.__copy_io(\"output\", *eager_outputs)\n\n                # Make sure that last used devices of tensors in `outputs_tensor_tuple` are\n                # \"critical_section\".\n                # NNGraph's execution flow will be broken if `last_used_device` of `outputs_tensor_tuple`\n                # are not \"critical_section\".\n                
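# The eager outputs returned to the caller were copied above, so the buffer\n                # tensors below can be handed back to the graph and reused by a later run.\n                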
oneflow._oneflow_internal.nn.graph.SoftSyncNNGraphBuffers(\n                    outputs_tensor_tuple, self._c_nn_graph\n                )\n        except:\n            self.__print(\n                2,\n                0,\n                \"[ERROR]\"\n                + self._shallow_repr()\n                + \" run got error: \"\n                + sys_exc_error_msg(),\n            )\n            raise\n\n        # Always pack outputs to remain type of outputs\n        return seq_to_func_return(eager_outputs, True)\n\n    def __build_io(self, io_type, build_func, *args, **kwargs):\n        assert io_type in (\"input\", \"output\")\n        op_names = []\n        args_repr = []\n        tensor2op_name = {}\n\n        def build_tensor_or_any(tensor, name, repr_str):\n            if isinstance(tensor, Tensor):\n                build_arg = build_func(name, tensor)\n                op_names.append(name)\n                tensor2op_name[build_arg] = name\n            else:\n                build_arg = tensor\n\n            args_repr.append(repr_str)\n            self.__print(0, 1, repr_str)\n            return build_arg\n\n        args_tree = ArgsTree(\n            (args, kwargs), True, \"_\" + self.name + \"_\" + io_type, None\n        )\n\n        def leaf_arg_fn(arg):\n            name = arg.prefix() + \"_\" + arg.name()\n            if isinstance(arg.value(), Tensor):\n                arg_repr = self.__io_item_check_and_gen_repr(\n                    arg.value(), Tensor, io_type, name\n                )\n                build_arg = build_tensor_or_any(arg.value(), name, arg_repr)\n                return build_arg\n            else:  # Opaque\n                arg_repr = self.__io_item_check_and_gen_repr(\n                    arg.value(), None, io_type, name\n                )\n                build_arg = build_tensor_or_any(arg.value(), name, arg_repr)\n                return build_arg\n\n        out = args_tree.map_leaf(leaf_arg_fn)\n        build_args = out[0]\n        build_kwargs = out[1]\n\n        return op_names, build_args, build_kwargs, args_repr, tensor2op_name\n\n    def __io_item_check_and_gen_repr(self, item, expect_type, io_type, name):\n        assert io_type in (\"input\", \"output\")\n        if expect_type is None:\n            repr_str = (\n                \"[WARNING](\"\n                + io_type.upper()\n                + \":\"\n                + name\n                + \":\"\n                + str(type(item))\n                + \")\"\n            )\n            self.__print(1, 0, repr_str)\n            return repr_str\n        elif expect_type is not None and isinstance(item, expect_type):\n            if isinstance(item, Tensor):\n                repr_str = (\n                    \"(\" + io_type.upper() + \":\" + name + \":\" + item._meta_repr() + \")\"\n                )\n            else:\n                repr_str = (\n                    \"[WARNING](\"\n                    + io_type.upper()\n                    + \":\"\n                    + name\n                    + \":\"\n                    + str(type(item))\n                    + \")\"\n                )\n            return repr_str\n        else:\n            repr_str = (\n                \"[ERROR](\" + io_type.upper() + \":\" + name + \":\" + str(type(item)) + \")\"\n            )\n            self.__print(2, 0, repr_str)\n            raise NotImplementedError(\n                \"nn.Graph.build()'s input/output item only support types: Tensor/None.\"\n            )\n\n    def __map_io(self, io_type, func, *args, **kwargs):\n        assert 
io_type in (\"input\", \"output\")\n\n        def mapping_tensor_or_any(tensor):\n            if isinstance(tensor, Tensor):\n                mapped_arg = func(tensor)\n            else:\n                mapped_arg = tensor\n            return mapped_arg\n\n        def leaf_arg_fn(arg):\n            arg_value = arg.value()\n            return mapping_tensor_or_any(arg_value)\n\n        # NOTE(lixiang): Reduce the overhead of traversal and parsing of io args.\n        if self._is_simple_tuple_output or self._is_simple_tuple_input:\n            args_tree = ArgsTree(args, False)\n            out = args_tree.map_tuple_leaf(mapping_tensor_or_any)\n            return out, kwargs\n\n        args_tree = ArgsTree(\n            (args, kwargs), True, \"_\" + self.name + \"_\" + io_type, None\n        )\n\n        out = args_tree.map_leaf(leaf_arg_fn)\n        mapped_args = out[0]\n        mapped_kwargs = out[1]\n        return mapped_args, mapped_kwargs\n\n    def __map_io_lite(self, func, *args, **kwargs):\n        args_tree = ArgsTree((args, kwargs), False)\n        out = args_tree.map_leaf(func)\n        mapped_args = out[0]\n        mapped_kwargs = out[1]\n        return mapped_args, mapped_kwargs\n\n    def __flatten_io(self, io_type, *args, **kwargs):\n        flattened_args = []\n        args_tree = ArgsTree((args, kwargs), False)\n\n        for arg in args_tree.iter_nodes():\n            if isinstance(arg, Tensor):\n                flattened_args.append(arg)\n            else:\n                continue\n        return flattened_args\n\n    def __io_item_check(self, item, expect_type, io_type, name):\n        if expect_type is None and item is None:\n            return\n        elif expect_type is not None and isinstance(item, expect_type):\n            return\n        else:\n            assert io_type in (\"input\", \"output\")\n            repr_str = (\n                \"[ERROR](\" + io_type.upper() + \":\" + name + \":\" + str(type(item)) + \")\"\n            )\n            self.__print(2, 0, repr_str)\n            raise NotImplementedError(\n                \"nn.Graph.build()'s input/output item only support types: Tensor/None.\"\n            )\n\n    def __empty_like_io(self, io_type, *args, **kwargs):\n        def func(t):\n            shape = t.shape\n            dtype = t.dtype\n\n            with oneflow._oneflow_internal.lazy_mode.guard(False):\n                if t.is_global:\n                    eager_out = oneflow.empty(\n                        shape, dtype=dtype, placement=t.placement, sbp=t.sbp,\n                    )\n                else:\n                    eager_out = oneflow.empty(shape, dtype=dtype, device=t.device)\n\n            return eager_out\n\n        return self.__map_io(io_type, func, *args, **kwargs)\n\n    def __copy_io(self, io_type, *args, **kwargs):\n        def func(tensor):\n            with oneflow._oneflow_internal.lazy_mode.guard(False):\n                build_arg = tensor.to(copy=True)\n                return build_arg\n\n        return self.__map_io(io_type, func, *args, **kwargs)\n\n    def _add_module(self, name: str, module: Module = None) -> None:\n        r\"\"\"Adds module to the graph as a block so that the module will\n        be called in nn.Graph.build.\n\n        Args:\n            name (str): name of the child block. 
The child block can be accessed from this graph using the given name.\n            module (Module): child module to be added to the graph.\n\n        Just assign an nn.Module as an attribute of an nn.Graph, and _add_module will be\n        called to add the module as a ProxyModule:\n\n        For example:\n\n        .. code-block:: python\n\n            >>> import oneflow as flow\n            >>> import numpy as np\n            >>> class LinearGraph(flow.nn.Graph):\n            ...     def __init__(self):\n            ...         super().__init__()\n            ...         # add a nn.Module as a block to graph.\n            ...         self.linear = flow.nn.Linear(3, 8, False)\n            ...     def build(self, x):\n            ...         # call the nn.Module block.\n            ...         return self.linear(x)\n\n\n        The block can be accessed as an attribute using the given name:\n\n        .. code-block:: python\n\n            g = LinearGraph()\n            g(flow.Tensor(np.random.randn(8, 3)))\n            print(g.linear)\n            (MODULE:linear:Linear(in_features=3, out_features=8, bias=False)): (\n              (INPUT:_linear_input.0.0_2:tensor(..., is_lazy='True', size=(8, 3), dtype=oneflow.float32))\n              (PARAMETER:linear.weight:tensor(..., size=(8, 3), dtype=oneflow.float32, grad_fn=<accumulate_grad>)): ()\n              (OUTPUT:_linear_output.0.0_2:tensor(..., is_lazy='True', size=(8, 8), dtype=oneflow.float32,\n                     grad_fn=<matmulBackward>))\n              (GraphModule:linear()): (\n                (OPERATOR: linear.weight() -> (out:sbp=(B), size=(8, 3), dtype=(oneflow.float32)), placement=(oneflow.placement(type=\"cpu\", ranks=[0])))\n                (OPERATOR: linear-matmul-0(_LinearGraph_0_input.0.0_2/out:(sbp=(B), size=(8, 3), dtype=(oneflow.float32)), linear.weight/out:(sbp=(B), size=(8, 3), dtype=(oneflow.float32))) -> (linear-matmul-0/out_0:(sbp=(B), size=(8, 8), dtype=(oneflow.float32))), placement=(oneflow.placement(type=\"cpu\", ranks=[0])))\n              )\n            )\n        \"\"\"\n        if \"_name\" not in self.__dict__:\n            raise AttributeError(\n                \"Base class nn.Graph has not been initialized, \"\n                \"please call super().__init__() in subclass of nn.Graph \"\n                \"before assigning any attribute.\"\n            )\n        if not isinstance(module, Module) and module is not None:\n            raise TypeError(\"{} is not a Module subclass\".format(type(module)))\n        elif not isinstance(name, str):\n            raise TypeError(\"module name should be a string. Got {}\".format(type(name)))\n        elif hasattr(self, name) and name not in self._blocks:\n            raise KeyError(\"attribute '{}' already exists\".format(name))\n        elif \".\" in name:\n            raise KeyError('module name can\'t contain \".\", got: {}'.format(name))\n        elif name == \"\":\n            raise KeyError('module name can\'t be empty string \"\"')\n\n        self._blocks[name] = get_proxy_cls(module)(\n            module, \"\", name, weakref.proxy(self)\n        )\n\n    def __setattr__(self, name: str, value=None):\n        if isinstance(value, Module):\n            self._add_module(name, value)\n        elif isinstance(value, Optimizer):\n            raise AttributeError(\n                \"'{}' nn.Graph is not allowed to set Optimizer attribute named '{}'. \"\n                \"Please use add_optimizer(...) 
instead.\".format(\n                    type(self).__name__, name\n                )\n            )\n        elif isinstance(value, Tensor):\n            raise AttributeError(\n                \"'{}' nn.Graph is not allowed to set Tensor attribute named '{}'. \"\n                \"Please use nn.Module to hold the tensor, then add the nn.Module to nn.Graph.\".format(\n                    type(self).__name__, name\n                )\n            )\n        else:\n            object.__setattr__(self, name, value)\n\n    def __getattr__(self, name: str):\n        if \"_blocks\" in self.__dict__:\n            if name in self._blocks:\n                return self._blocks[name]\n        if name in self.__dict__:\n            return self.__dict__[name]\n        raise AttributeError(\n            \"'{}' object has no attribute '{}'\".format(type(self).__name__, name)\n        )\n\n    def __del__(self):\n        # Ensure vm has finished running this graph.\n        if self._session._env.is_shutting_down():\n            # After python shutting down, it's not safe to call oneflow._oneflow_internal.eager.\n            # But shutting down will do sync in SwitchToShuttingDownPhase.\n            # So it's safe to skip sync here.\n            return\n        oneflow._oneflow_internal.eager.Sync()\n        current_env_enable_mlir_inference_opt = os.getenv(\n            \"ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION\"\n        )\n        if (self.env_enable_mlir_inference_opt is not None) and (\n            current_env_enable_mlir_inference_opt is None\n        ):\n            os.environ[\n                \"ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION\"\n            ] = self.env_enable_mlir_inference_opt\n        oneflow._oneflow_internal.ResetVariableTensorMgr()\n\n    def __ensure_input_tensors_contiguous(self, *args, **kwargs):\n        args_tree = ArgsTree((args, kwargs), False)\n\n        def func(value):\n            if isinstance(value, Tensor) and not value.is_contiguous():\n                value.contiguous_()\n            return value\n\n        # NOTE(lixiang): Reduce the overhead of traversal and parsing of input args.\n        if self._is_simple_tuple_input:\n            args_tree.map_tuple_leaf(func)\n            return\n\n        args_tree.map_leaf(func)\n\n    def __ensure_input_tensors_contiguous_and_flatten(self, *args, **kwargs):\n        flattened_args = []\n\n        def func(value):\n            if isinstance(value, Tensor) and not value.is_contiguous():\n                value.contiguous_()\n            return value\n\n        # NOTE(lixiang): Reduce the overhead of traversal and parsing of input args.\n        if self._is_simple_tuple_input:\n            args_tree = ArgsTree(args, False)\n            # contiguous\n            args_tree.map_tuple_leaf(func)\n            # flatten\n            for arg in args_tree.iter_nodes():\n                if isinstance(arg, Tensor):\n                    flattened_args.append(arg)\n                else:\n                    continue\n            return flattened_args\n\n        args_tree = ArgsTree((args, kwargs), False)\n        # contiguous\n        args_tree.map_leaf(func)\n        # flatten\n        for arg in args_tree.iter_nodes():\n            if isinstance(arg, Tensor):\n                flattened_args.append(arg)\n            else:\n                continue\n        return flattened_args\n\n    @staticmethod\n    def with_dynamic_input_shape(*, size: int = 10, enable_shared: bool = True):\n        def deco_with_config(graph_init_func):\n         
   @wraps(graph_init_func)\n            def deco_func(self, *args, **kwargs):\n                graph_init_func(self, *args, **kwargs)\n                self._run_with_cache = True\n                import oneflow.nn.graph.cache as cache\n\n                self._dynamic_input_graph_cache = cache.GraphCache(\n                    weakref.proxy(self),\n                    cache_size=size,\n                    enable_graph_shared=enable_shared,\n                )\n                self._cached_init_args = args\n                self._cached_init_kwargs = kwargs\n\n            return deco_func\n\n        return deco_with_config\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/graph/graph_block.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport weakref\nfrom collections import OrderedDict\nfrom typing import Iterator, Optional, Set, Union, List\n\nimport oneflow._oneflow_internal\nfrom oneflow.env import get_rank\nfrom oneflow.framework import graph_build_util\nfrom oneflow.nn.graph.util import (\n    add_indent,\n    operators_repr,\n    GraphIR,\n)\n\n\nclass GraphBlockType:\n    NONE = \"NONE\"\n    MODULE = \"MODULE\"\n    PARAMETER = \"PARAMETER\"\n    BUFFER = \"BUFFER\"\n\n\n# Module or Tensor are both treated as Block.\nclass GraphBlock(object):\n    def __init__(\n        self,\n        prefix: str = \"\",\n        name: str = \"\",\n        belonged_graph: weakref.ProxyTypes = None,\n        belonged_proxy: weakref.ProxyTypes = None,\n        block_graph_type: GraphBlockType = GraphBlockType.NONE,\n    ):\n        self._name = name\n        self._name_prefix = prefix\n        self._type = block_graph_type\n        self._scope = None\n        self._prev_scope = None\n        assert belonged_graph is None or isinstance(belonged_graph, weakref.ProxyTypes)\n        self._belonged_graph = belonged_graph\n        assert belonged_proxy is None or isinstance(belonged_proxy, weakref.ProxyTypes)\n        self._belonged_proxy = belonged_proxy\n\n    @property\n    def name(self):\n        return self._name\n\n    @property\n    def name_prefix(self):\n        return self._name_prefix\n\n    @property\n    def type(self):\n        return self._type\n\n    @property\n    def prev_scope(self):\n        if self._prev_scope is None:\n            self._prev_scope = oneflow._oneflow_internal.GetCurrentScope()\n        return self._prev_scope\n\n    @property\n    def scope(self):\n        if self._scope is None:\n            self._scope = graph_build_util.make_new_blockgraph_scope(\n                self.prev_scope, self\n            )\n        return self._scope\n\n    def scope_context(self):\n        return graph_build_util.BlockScopeContext(self.prev_scope, self.scope)\n\n\nclass GraphModule(GraphBlock):\n    r\"\"\"GraphModule is the graph representation of a nn.Module in a nn.Graph.\n\n    When an nn.Module is added into an nn.Graph, it is wrapped into a ProxyModule. 
The ProxyModule has a GraphModule inside it.\n    You can get and set the GraphModule to enable graph optimization on the nn.Module.\n    \"\"\"\n\n    def __init__(\n        self,\n        prefix: str = \"\",\n        name: str = \"\",\n        belonged_graph: weakref.ProxyTypes = None,\n        belonged_proxy: weakref.ProxyTypes = None,\n    ):\n        super().__init__(\n            prefix, name, belonged_graph, belonged_proxy, GraphBlockType.MODULE\n        )\n        self._is_null = True\n        self._stage_id = None\n        self._stage_placement = None\n        self._activation_checkpointing = None\n\n        self._debug = False\n        self._debug_min_s_level = 2\n        self._debug_max_v_level = 0\n        self._debug_max_py_stack_depth = 2\n        self._debug_only_user_py_stack = True\n        self._debug_op_repr_with_py_stack = False\n        self._is_executing_forward = False\n        self._args_repr = []\n        self._outs_repr = []\n\n    def set_stage(self, stage_id: int = None, placement=None):\n        r\"\"\"Set stage id and placement of nn.Module in pipeline parallelism.\n\n        Args:\n            stage_id (int): stage id of this module.\n            placement (flow.placement): the placement of all tensors in this module.\n\n        Note:\n            tensor.to_global(placement) is applied automatically to all input tensors of\n            this module, so there is no need to write to_global() in the module's forward\n            when using pipeline parallelism (doing so is not recommended).\n\n        For example:\n\n        .. code-block:: python\n\n            # module0 and module1 are two nn.Module in a nn.Graph.\n            # When an nn.Module is added into an nn.Graph, it is wrapped into a ProxyModule.\n            # We can set Stage ID and Placement by using ProxyModule.to(GraphModule).set_stage()\n            # The Stage ID is numbered starting from 0 and increasing by 1.\n            # The Placement is the placement of all tensors in this module.\n            import oneflow as flow\n            from oneflow.nn.graph import GraphModule\n            P_0 = flow.placement(type = \"cuda\", ranks = [0, 1])\n            P_1 = flow.placement(type = \"cuda\", ranks = [2, 3])\n            self.module0.to(GraphModule).set_stage(stage_id = 0, placement = P_0)\n            self.module1.to(GraphModule).set_stage(stage_id = 1, placement = P_1)\n\n        \"\"\"\n\n        self._is_null = False\n        self._stage_id = stage_id\n        self._stage_placement = placement\n\n    # NOTE(lixiang): For the normal display of docstr, the API Doc of the get and set methods are written together in the stage_id function.\n    @property\n    def stage_id(self):\n        r\"\"\"Set/Get stage id of nn.Module/GraphModule in pipeline parallelism.\n        When calling stage_id(value: int = None), set different modules' stage ids to hint the graph\n        to prepare the right number of buffers in the pipeline. 
(Not Recommended, for easy and efficient pipeline\n        parallelism experience, please use set_stage(stage_id, placement))\n        \"\"\"\n        return self._stage_id\n\n    @stage_id.setter\n    def stage_id(self, value: int = None):\n        r\"\"\"Set stage id of Module in pipeline parallelism.\n        Set different module's stage id to hint the graph preparing right num of buffers in pipeline.\n        \"\"\"\n        print(\n            \"Warning: `stage_id = i` is deprecated, please use \\n\",\n            \" set_stage(i, placement) for easy and efficient Pipeline parallel experience.\",\n        )\n\n        self._is_null = False\n        self._stage_id = value\n\n    @property\n    def stage_placement(self):\n        return self._stage_placement\n\n    # NOTE(lixiang): For the normal display of docstr, the API Doc of the get and set methods are written together in the activation_checkpointing function.\n    @property\n    def activation_checkpointing(self):\n        r\"\"\"Set/Get whether do activation checkpointing in this nn.Module.\n\n        For example:\n\n        .. code-block:: python\n\n            import oneflow as flow\n            from oneflow.nn.graph import GraphModule\n\n            class Graph(flow.nn.Graph):\n                def __init__(self):\n                    super().__init__()\n                    self.linear1 = flow.nn.Linear(3, 5, False)\n                    self.linear2 = flow.nn.Linear(5, 8, False)\n                    self.linear1.to(GraphModule).activation_checkpointing = True\n                    self.linear2.to(GraphModule).activation_checkpointing = True\n\n                def build(self, x):\n                    y_pred = self.linear1(x)\n                    y_pred = self.linear2(y_pred)\n                    return y_pred\n\n            graph = Graph()\n\n        \"\"\"\n        return self._activation_checkpointing\n\n    @activation_checkpointing.setter\n    def activation_checkpointing(self, mode: bool = False):\n        r\"\"\"Set whether do activation checkpointing in this Module.\n        \"\"\"\n        self._is_null = False\n        self._activation_checkpointing = mode\n\n    def _config_repr(self):\n        main_str = (\n            \"(\"\n            + self.__class__.__name__\n            + \"(\"\n            + (\n                (\"stage_id=\" + str(self.stage_id) + \", \")\n                if self.stage_id is not None\n                else \"\"\n            )\n            + (\n                (\n                    \"activation_checkpointing=\"\n                    + str(self.activation_checkpointing)\n                    + \", \"\n                )\n                if self.activation_checkpointing is not None\n                else \"\"\n            )\n            + \"))\"\n        )\n        return main_str\n\n    def debug(\n        self,\n        v_level: int = 0,\n        *,\n        ranks: Optional[Union[int, List[int]]] = None,\n        max_py_stack_depth: int = 2,\n        only_user_py_stack=True,\n        op_repr_with_py_stack=False,\n    ) -> None:\n        assert isinstance(v_level, int)\n        assert isinstance(max_py_stack_depth, int)\n        assert isinstance(only_user_py_stack, bool)\n        assert isinstance(op_repr_with_py_stack, bool)\n\n        if ranks is None:\n            rank_list = [0]\n        elif isinstance(ranks, int):\n            rank_list = [ranks]\n        elif isinstance(ranks, list):\n            rank_list = ranks\n        else:\n            raise ValueError(\"ranks must be int or 
List[int].\")\n\n        my_rank = get_rank()\n        if -1 in rank_list or my_rank in rank_list:\n            self._debug = v_level >= 0\n            if self._debug:\n                self._debug_min_s_level = 0\n                self._debug_max_v_level = max(0, v_level)\n\n            self._debug_max_py_stack_depth = max_py_stack_depth\n            self._debug_only_user_py_stack = only_user_py_stack\n            self._debug_op_repr_with_py_stack = op_repr_with_py_stack\n\n            if self._type == GraphBlockType.MODULE:\n\n                def _set_child(d):\n                    for (_, n) in d.items():\n                        n.to(GraphModule).debug(\n                            v_level,\n                            ranks=ranks,\n                            max_py_stack_depth=max_py_stack_depth,\n                            only_user_py_stack=only_user_py_stack,\n                            op_repr_with_py_stack=op_repr_with_py_stack,\n                        )\n\n                assert self._belonged_proxy is not None and isinstance(\n                    self._belonged_proxy, weakref.ProxyTypes\n                )\n                _set_child(self._belonged_proxy._modules)\n\n    def _ops_repr(self):\n        r\"\"\"Generate operators' string representation of this GraphModule\n        \"\"\"\n        assert self._belonged_graph, (\n            \"ProxyModule: \"\n            + self._name_prefix\n            + self.name\n            + \"'s belonged graph is not set.\"\n        )\n\n        if self._belonged_graph._compiled_graph_proto is not None:\n            module_conf = self._belonged_graph._compiled_graph_proto.module_name2module_conf[\n                self.name_prefix + self.name\n            ]\n            if self._belonged_graph._oneflow_internal_graph_ir__ is None:\n                self._belonged_graph._oneflow_internal_graph_ir__ = GraphIR(\n                    self._belonged_graph._compiled_graph_proto\n                )\n            return operators_repr(\n                module_conf.ops,\n                self._belonged_graph._oneflow_internal_graph_ir__,\n                self._debug_op_repr_with_py_stack,\n            )\n\n        return []\n\n    def _shallow_repr(self):\n        main_str = (\n            \"(\"\n            + self.__class__.__name__\n            + \":\"\n            + self._name_prefix\n            + self._name\n            + \"(\"\n            + (\n                (\"stage_id=\" + str(self.stage_id) + \", \")\n                if self.stage_id is not None\n                else \"\"\n            )\n            + (\n                (\n                    \"activation_checkpointing=\"\n                    + str(self.activation_checkpointing)\n                    + \", \"\n                )\n                if self.activation_checkpointing is not None\n                else \"\"\n            )\n            + \"))\"\n        )\n        return main_str\n\n    def _repr_lines(self):\n        child_lines = []\n        for op_str in self._ops_repr():\n            child_lines.append(add_indent(op_str, 2))\n        return child_lines\n\n    def __repr__(self):\n        lines = None\n        child_lines = self._repr_lines()\n        if len(child_lines) > 0:\n            lines = child_lines\n\n        main_str = self._shallow_repr() + \": (\"\n        if lines is not None:\n            main_str += \"\\n  \" + \"\\n  \".join(lines) + \"\\n\"\n        main_str += \")\"\n        return main_str\n\n\nclass GraphTensor(GraphBlock):\n    r\"\"\"GraphTensor is the graph 
representation of a Tensor in a nn.Graph.\n    \"\"\"\n\n    def __init__(\n        self,\n        prefix: str = \"\",\n        name: str = \"\",\n        belonged_graph: weakref.ProxyTypes = None,\n        belonged_proxy: weakref.ProxyTypes = None,\n        tensor_graph_type: GraphBlockType = GraphBlockType.NONE,\n    ):\n        super().__init__(\n            prefix, name, belonged_graph, belonged_proxy, tensor_graph_type\n        )\n        self._stage_id = None\n        self._stage_placement = None\n\n    def set_stage(self, stage_id: int = None, placement=None):\n        self._stage_id = stage_id\n        self._stage_placement = placement\n\n    @property\n    def stage_id(self):\n        return self._stage_id\n\n    @stage_id.setter\n    def stage_id(self, value: int = None):\n        print(\n            \"Warning: `stage_id = i` is deprecated, please use \\n\",\n            \" set_stage(i, placement) for an easy and efficient pipeline parallelism experience.\",\n        )\n\n        self._stage_id = value\n\n    @property\n    def stage_placement(self):\n        return self._stage_placement\n"
  },
  {
    "path": "python/oneflow/nn/graph/graph_config.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\n\nfrom collections import OrderedDict\n\nimport oneflow.boxing.nccl as nccl_config\nfrom oneflow.nn.graph.optimizer import OptDict\nimport oneflow.core.job.job_conf_pb2 as job_conf_pb\nimport oneflow as flow\n\n\nclass GraphConfig(object):\n    r\"\"\"For configuration of nn.Graph.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self._outputs_buffer_size = 2\n        self.proto = job_conf_pb.JobConfigProto()\n        self._train(False)\n\n    def _train(self, mode: bool = True):\n        if mode:\n            self.proto.train_conf.SetInParent()\n        else:\n            self.proto.predict_conf.SetInParent()\n\n    @property\n    def training(self):\n        if self.proto.HasField(\"train_conf\"):\n            return True\n        if self.proto.HasField(\"predict_conf\"):\n            return False\n        raise NotImplementedError\n\n    def enable_amp(self, mode: bool = True, *, dtype: flow.dtype = flow.float16):\n        r\"\"\"If set to true, then graph will use mixed precision mode, it means use both float16 and float32 during model training.\n\n        For example:\n\n        .. code-block:: python\n\n            import oneflow as flow\n\n            class Graph(flow.nn.Graph):\n                def __init__(self):\n                    super().__init__()\n                    self.linear = flow.nn.Linear(3, 8, False)\n                    self.config.enable_amp(True) # Use mixed precision mode.\n                def build(self, x):\n                    return self.linear(x)\n\n            graph = Graph()\n\n        Args:\n            mode (bool, optional): The default value is True.\n\n\n        \"\"\"\n        assert type(mode) is bool\n        assert dtype in (flow.float16, flow.bfloat16)\n        self.proto.enable_auto_mixed_precision = mode\n        self.proto.mixed_precision_data_type = flow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(\n            dtype\n        )\n\n    def set_zero_redundancy_optimizer_mode(self, mode: str = \"distributed_split\"):\n        raise RuntimeError(\n            \"`set_zero_redundancy_optimizer_mode` has been changed to `enable_zero`, please use `enable_zero(True)` to activate ZeRO optimization.\"\n        )\n\n    def enable_zero(\n        self,\n        mode: bool = True,\n        *,\n        stage: int = 2,\n        shard_min_size: int = 1024,\n        shard_restore_level: int = 1,\n    ):\n        r\"\"\"Enable ZeRO redundancy optimizer.\n\n        This optimization will reduce optimizer states memory consumption as described\n        by ZeRO https://arxiv.org/abs/1910.02054 .\n\n        The default zero stage is 2.\n\n        For example:\n\n        .. 
\n        Args:\n            mode (bool, optional): The default value is True.\n            dtype (flow.dtype, optional): the half-precision data type, flow.float16 or flow.bfloat16. The default value is flow.float16.\n        \"\"\"\n        assert type(mode) is bool\n        assert dtype in (flow.float16, flow.bfloat16)\n        self.proto.enable_auto_mixed_precision = mode\n        self.proto.mixed_precision_data_type = flow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(\n            dtype\n        )\n\n    def set_zero_redundancy_optimizer_mode(self, mode: str = \"distributed_split\"):\n        raise RuntimeError(\n            \"`set_zero_redundancy_optimizer_mode` has been changed to `enable_zero`, please use `enable_zero(True)` to activate ZeRO optimization.\"\n        )\n\n    def enable_zero(\n        self,\n        mode: bool = True,\n        *,\n        stage: int = 2,\n        shard_min_size: int = 1024,\n        shard_restore_level: int = 1,\n    ):\n        r\"\"\"Enable ZeRO redundancy optimizer.\n\n        This optimization will reduce optimizer states memory consumption as described\n        by ZeRO https://arxiv.org/abs/1910.02054 .\n\n        The default zero stage is 2.\n\n        For example:\n\n        .. code-block:: python\n\n            import oneflow as flow\n\n            class Graph(flow.nn.Graph):\n                def __init__(self):\n                    super().__init__()\n                    self.linear = flow.nn.Linear(3, 8, False)\n                    self.config.enable_zero()\n                def build(self, x):\n                    return self.linear(x)\n\n            graph = Graph()\n\n        Args:\n            mode (bool): if set to true, optimizer states of Data Parallel will be sharded across devices.\n            stage (int): optimization stage, range from 1 to 3.\n            shard_min_size (int): min size (element count) of a shard of an optimizer state.\n            shard_restore_level (int): level to restore sharded parameter to whole parameter for consumer operators, level 0 is no restore, level 1 is soft restore, level 2 is hard restore. Note that this parameter is at pre-alpha stage.\n        \"\"\"\n        if not mode:\n            self.proto.optimizer_placement_optimization_mode = \"none\"\n            return\n        assert stage >= 1 and stage <= 3, \"ZeRO stage must range from 1 to 3.\"\n        assert (\n            shard_min_size > 0\n        ), \"ZeRO min size of a sharded optimizer state must be > 0.\"\n        if stage >= 1:\n            self.proto.optimizer_placement_optimization_mode = \"distributed_split\"\n            self.proto.optimizer_placement_optimization_threshold = shard_min_size\n            self.proto.optimizer_placement_optimization_shard_restore_level = (\n                shard_restore_level\n            )\n        if stage >= 2:\n            nccl_config.enable_use_compute_stream(True)\n        if stage >= 3:\n            nccl_config.disable_group_boxing_by_dst_parallel(True)\n\n    def allow_fuse_model_update_ops(self, mode: bool = True):\n        r\"\"\"If set to true, try to fuse cast + scale + l1_l2_regularize_gradient + model_update to one op to improve performance.\n\n        For example:\n\n        .. code-block:: python\n\n            import oneflow as flow\n\n            class Graph(flow.nn.Graph):\n                def __init__(self):\n                    super().__init__()\n                    self.linear = flow.nn.Linear(3, 8, False)\n                    self.config.allow_fuse_model_update_ops(True)\n                def build(self, x):\n                    return self.linear(x)\n\n            graph = Graph()\n\n        Args:\n            mode (bool, optional): The default value is True.\n        \"\"\"\n        self.proto.enable_fuse_model_update_ops = mode\n\n    def allow_fuse_add_to_output(self, mode: bool = True):\n        r\"\"\"If set to true, try to fuse a binary element-wise add operator to one of the predecessors to improve performance.\n\n        For example:\n\n        .. 
code-block:: python\n\n            import oneflow as flow\n\n            class Graph(flow.nn.Graph):\n                def __init__(self):\n                    super().__init__()\n                    self.bn1 = flow.nn.BatchNorm1d(100)\n                    self.config.allow_fuse_add_to_output(True)\n                def build(self, x):\n                    bn = self.bn1(x)\n                    out = bn + x\n                    return out\n\n            graph = Graph()\n\n        Args:\n            mode (bool, optional): The default value is True.\n        \"\"\"\n        self.proto.enable_fuse_add_to_output = mode\n\n    def allow_fuse_cast_scale(self, mode: bool = True):\n        r\"\"\"If set to true, try to fuse cast and scalar_mul_by_tensor to improve performance.\n\n        For example:\n\n        .. code-block:: python\n\n            import oneflow as flow\n\n            def model(x):\n                return flow.mul(1, flow.cast(x, flow.int8))\n\n            class Graph(flow.nn.Graph):\n                def __init__(self):\n                    super().__init__()\n                    self.m = model\n                    self.config.allow_fuse_cast_scale(True)\n                def build(self, x):\n                    return self.m(x)\n\n            graph = Graph()\n\n        Args:\n            mode (bool, optional): The default value is True.\n        \"\"\"\n        self.proto.enable_fuse_cast_scale = mode\n\n    def set_gradient_accumulation_steps(self, value):\n        r\"\"\"Set num of steps to accumulate gradient.\n\n        For example:\n\n        .. code-block:: python\n\n            import oneflow as flow\n\n            class Graph(flow.nn.Graph):\n                def __init__(self):\n                    super().__init__()\n                    self.linear = flow.nn.Linear(3, 8, False)\n                    # Let the graph do gradient accumulation; for example, pipeline parallelism depends on gradient accumulation.\n                    self.config.set_gradient_accumulation_steps(4)\n                def build(self, x):\n                    return self.linear(x)\n\n            graph = Graph()\n\n        Args:\n            value (int): num of steps.\n        \"\"\"\n        self.proto.num_gradient_accumulation_steps = value\n        if value > 1:\n            # NOTE(chengcheng): when using gradient accumulation, the optimizer nccl allreduce can NOT\n            #  overlap with backward, so making nccl use the compute stream is an optimization without\n            #  negative effects.\n            nccl_config.enable_use_compute_stream(True)\n\n    def set_outputs_buffer_size(self, value: int = 2):\n        r\"\"\"Set the outputs buffer size of ``nn.Graph``.\n\n        When the graph's outputs buffer size is greater than 1, multiple calls on the graph can work like a pipeline, which makes repeated calls take less time.\n\n        The default outputs buffer size is 2.\n\n        # TODO (lixiang): Explain the meaning of the size of buffer size and add sample code.\n        # The size of the buffer indicates the maximum number of iterations by which the outputs of the Graph and the actual asynchronous execution of the Graph can overlap.\n        # If the buffer size is 1, there is no pipeline. A size of 2 means that it can execute 1 iter ahead of time. A size of 3 means that two iters can be executed ahead of time.\n
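\n        For example (a minimal sketch in the style of the other examples in this file):\n\n        .. code-block:: python\n\n            import oneflow as flow\n\n            class Graph(flow.nn.Graph):\n                def __init__(self):\n                    super().__init__()\n                    self.linear = flow.nn.Linear(3, 8, False)\n                    # A buffer size of 3 lets up to two iterations run ahead of time.\n                    self.config.set_outputs_buffer_size(3)\n                def build(self, x):\n                    return self.linear(x)\n\n            graph = Graph()\n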
\n        Args:\n            value (int): graph outputs buffer size.\n        \"\"\"\n        assert isinstance(value, int)\n        assert value >= 1\n        self._outputs_buffer_size = value\n\n    def enable_cudnn_conv_heuristic_search_algo(self, mode: bool = True):\n        r\"\"\" Whether to enable the cudnn conv operation to use the heuristic search algorithm.\n\n        Note:\n            It is recommended to use `flow.backends.cudnn.enable_conv_heuristic_search_algo(False)` instead of this function.\n\n        For example:\n\n        .. code-block:: python\n\n            import oneflow as flow\n\n            class Graph(flow.nn.Graph):\n                def __init__(self):\n                    super().__init__()\n                    self.m = flow.nn.Conv2d(16, 32, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))\n                    # Do not enable the cudnn conv operation to use the heuristic search algorithm.\n                    self.config.enable_cudnn_conv_heuristic_search_algo(False)\n                def build(self, x):\n                    return self.m(x)\n\n            graph = Graph()\n\n        Args:\n            mode (bool, optional): The default value is True.\n        \"\"\"\n        self.proto.cudnn_conv_heuristic_search_algo = mode\n\n    def enable_straighten_algorithm(self, mode: str = \"MemoryFirst\"):\n        r\"\"\" Whether to enable the straighten algorithm.\n\n        straighten_algorithm_tag 1: Disable\n        Disable the straighten algorithm in the task graph.\n        Would use the original topological order for executing task nodes.\n\n        straighten_algorithm_tag 2: SpeedFirst\n        Under the second configuration, the straighten algorithm would try to speed up the training as much as possible.\n        If using nccl compute stream, setting the tag to 2 might not speed up the training.\n        If not using nccl compute stream, setting the tag to 2 might speed up data parallelism by 0.6% and model parallelism by 6%.\n        For memory considerations, enabling the straighten algorithm is forbidden when there is only one machine/device, and it is not recommended under pipeline parallelism.\n\n        straighten_algorithm_tag 3: MemoryFirst\n        Under the third configuration, the straighten algorithm would try to compress memory as much as possible.\n        It might save up to 13% of the memory for some models,\n        and might save nothing for some models.\n\n        straighten_algorithm_tag 4: OverlapCpuGpu\n        Under the fourth configuration, the straighten algorithm would try to run the cpu nodes and gpu nodes alternately.\n        Such a procedure would reduce the gaps of the execution on gpus.\n        It might speed up the training by 2%.\n        If no cpu nodes exist, the straighten_algorithm_tag would be switched to 3 automatically.\n\n        straighten_algorithm_tag 5: DelayShortGpu\n        Under the fifth configuration, the straighten algorithm would try to delay the cpu nodes.\n        Such a procedure would reduce the gaps of the execution on gpus.\n        It might speed up the validation (or training).\n        If no cpu nodes exist, the straighten_algorithm_tag would be switched to 3 automatically.\n
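\n        For example (a minimal sketch in the style of the other examples in this file):\n\n        .. code-block:: python\n\n            import oneflow as flow\n\n            class Graph(flow.nn.Graph):\n                def __init__(self):\n                    super().__init__()\n                    self.linear = flow.nn.Linear(3, 8, False)\n                    # Choose one of \"Disable\", \"SpeedFirst\", \"MemoryFirst\", \"OverlapCpuGpu\" and \"DelayShortGpu\".\n                    self.config.enable_straighten_algorithm(\"MemoryFirst\")\n                def build(self, x):\n                    return self.linear(x)\n\n            graph = Graph()\n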
\n        \"\"\"\n        assert (\n            mode == \"Disable\"\n            or mode == \"SpeedFirst\"\n            or mode == \"MemoryFirst\"\n            or mode == \"OverlapCpuGpu\"\n            or mode == \"DelayShortGpu\"\n        ), \"please choose one type among {Disable, SpeedFirst, MemoryFirst, OverlapCpuGpu, DelayShortGpu}\"\n        if mode == \"Disable\":\n            self.proto.straighten_algorithm_tag_in_task_graph = 1\n        elif mode == \"SpeedFirst\":\n            self.proto.straighten_algorithm_tag_in_task_graph = 2\n        elif mode == \"MemoryFirst\":\n            self.proto.straighten_algorithm_tag_in_task_graph = 3\n        elif mode == \"OverlapCpuGpu\":\n            self.proto.straighten_algorithm_tag_in_task_graph = 4\n        else:\n            self.proto.straighten_algorithm_tag_in_task_graph = 5\n\n    def enable_compress_memory(self, mode: bool = True):\n        \"\"\"If true, then the graph will try its best to find the minimum memory allocation strategy.\n        This process might take several minutes for a small graph and half an hour for a large one.\n        The compressed memory would be close to the lower bound of the peak memory.\n        It is very beneficial if you need to train for many batches.\n\n        Args:\n            mode (bool, optional): if true, enable memory compression. Default is True.\n        \"\"\"\n        self.proto.enable_compress_memory = mode\n\n    def enable_choose_best_memory_allocation(self, mode: bool = True):\n        \"\"\"If true, then the graph will go through all the memory allocation algorithms, including the\n        large memory first algorithm,\n        long lifetime first algorithm,\n        first in first allocates algorithm,\n        and large memory volume first algorithm,\n        each with compact insertion on and off.\n        Then the graph will choose the one with the least memory.\n\n        If false, the graph will directly choose\n        the large memory first algorithm with compact insertion off,\n        since the large memory first algorithm is the best one among those algorithms in most of our test cases,\n        and turning compact insertion off will save half of the running time of this algorithm.\n        \"\"\"\n        if mode:\n            self.proto.memory_allocation_algorithm_conf.use_mem_size_first_algo = True\n            self.proto.memory_allocation_algorithm_conf.use_lifetime_first_algo = True\n            self.proto.memory_allocation_algorithm_conf.use_time_line_algo = True\n            self.proto.memory_allocation_algorithm_conf.use_mem_volume_first_algo = True\n            self.proto.memory_compact_insert_conf.use_compact_insert = True\n            self.proto.memory_compact_insert_conf.use_non_compact_insert = True\n\n    def enable_auto_parallel(self, mode: bool = True):\n        \"\"\"If true, then the graph will use the auto parallel algorithm to select a parallelism strategy.\n\n        Args:\n            mode (bool, optional): if true, enable auto parallelism. Default is True.\n        \"\"\"\n        self.proto.enable_auto_parallel = mode\n\n    def enable_auto_parallel_ignore_user_sbp_config(self, mode: bool = True):\n        \"\"\"If true, it will ignore all user configurations of SBP.\n\n        Args:\n            mode (bool, optional): if true, ignore user SBP configurations. Default is True.\n        \"\"\"\n        self.proto.enable_auto_parallel_ignore_user_sbp_config = mode\n\n    def set_auto_parallel_computation_cost_ratio(self, ratio):\n        \"\"\"\n        Set the coefficient of computation cost in the auto-parallel algorithm.\n        \"\"\"\n        self.proto.auto_parallel_computation_cost_ratio = ratio\n\n    def set_auto_parallel_wait_time(self, cost):\n        \"\"\"\n        Set the wait time for the auto-parallel algorithm.\n\n        wait time: an auto-parallel parameter describing the variable extra time incurred when
\n        communication between devices occurs. It is added to the copy cost and may be reduced\n        when covered by computation cost.\n        \"\"\"\n        self.proto.auto_parallel_wait_time = cost\n\n    def enable_auto_parallel_trunk_algo(self, mode: bool = True):\n        \"\"\"\n        Find the trunk of the SBP graph, then reduce the wait time for tributaries.\n        \"\"\"\n        self.proto.enable_auto_parallel_trunk_algo = mode\n\n    def enable_auto_parallel_sbp_collector(self, mode: bool = True):\n        \"\"\"\n        Use \\\"sbp collector\\\" to create \\\"sbp proxy\\\" for nodes with multiple downstream operators.\n        \"\"\"\n        self.proto.enable_auto_parallel_sbp_collector = mode\n\n    def enable_auto_memory(self, mode: str = \"AdaptiveMemory\"):\n        r\"\"\" Whether to use a parallelism strategy with less memory.\n\n        Auto memory strategy 1: Disable\n        Disable auto memory in auto parallel.\n        Ignore the memory and try our best to speed up the training.\n\n        Auto memory strategy 2: SlightMemoryDown\n        Try to decrease the memory while maintaining the throughput.\n\n        Auto memory strategy 3: ModerateMemoryDown\n        Decrease the memory; throughput might or might not be affected.\n        Similar to data parallelism + ZeRO.\n\n        Auto memory strategy 4: HeavyMemoryDown\n        Try our best to decrease the memory, ignoring the throughput.\n\n        Auto memory strategy 5: AdaptiveMemory\n        Use normal auto parallelism without consideration of memory when we have enough memory.\n        Gradually decrease the memory usage to avoid running out of memory when memory is inadequate.\n        Always try to find the highest throughput under the current limitation of memory.\n
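\n        For example (a minimal sketch in the style of the other examples in this file):\n\n        .. code-block:: python\n\n            import oneflow as flow\n\n            class Graph(flow.nn.Graph):\n                def __init__(self):\n                    super().__init__()\n                    self.linear = flow.nn.Linear(3, 8, False)\n                    # Choose one of \"Disable\", \"SlightMemoryDown\", \"ModerateMemoryDown\", \"HeavyMemoryDown\" and \"AdaptiveMemory\".\n                    self.config.enable_auto_memory(\"AdaptiveMemory\")\n                def build(self, x):\n                    return self.linear(x)\n\n            graph = Graph()\n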
\n        \"\"\"\n        assert (\n            mode == \"Disable\"\n            or mode == \"SlightMemoryDown\"\n            or mode == \"ModerateMemoryDown\"\n            or mode == \"HeavyMemoryDown\"\n            or mode == \"AdaptiveMemory\"\n        )\n        if mode == \"Disable\":\n            self.proto.enable_auto_memory = 1\n        elif mode == \"SlightMemoryDown\":\n            self.proto.enable_auto_memory = 2\n        elif mode == \"ModerateMemoryDown\":\n            self.proto.enable_auto_memory = 3\n        elif mode == \"HeavyMemoryDown\":\n            self.proto.enable_auto_memory = 4\n        else:\n            self.proto.enable_auto_memory = 5\n\n    def enable_multi_tensor_update(self, mode: bool = True):\n        \"\"\"\n        Enable the multi-tensor update pass; it will merge small optimizer kernels to reduce kernel launch overhead.\n        \"\"\"\n        self.proto.enable_multi_tensor_update = mode\n\n    def enable_fused_model_update_cast(self, mode: bool = True):\n        \"\"\"\n        This option only works in AMP mode; it will fuse the optimizer update with the cast of model weights to half precision.\n        \"\"\"\n        self.proto.enable_fused_model_update_cast = mode\n\n    def _generate_optimizer_and_variable_configs(\n        self, opt_dict: OptDict = None, variables_conf: OrderedDict = None,\n    ):\n        opt_dict.generate_optimizer_and_variable_configs(self.proto, variables_conf)\n\n    def __repr__(self):\n        main_str = (\n            \"(\"\n            + \"CONFIG\"\n            + \":config:\"\n            + self.__class__.__name__\n            + \"(\"\n            + (\"training=\" + str(self.training) + \", \")\n            + \"))\"\n        )\n        return main_str\n"
  },
  {
    "path": "python/oneflow/nn/graph/optimizer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.optim.optimizer import Optimizer\nfrom oneflow.nn.optimizer.lr_scheduler import LRScheduler\n\n\nclass OptDict(object):\n    def __init__(self, opt_dict):\n        if not isinstance(opt_dict, dict):\n            raise ValueError(\"opt_dict is not a dict\")\n\n        if \"optim\" in opt_dict:\n            if isinstance(opt_dict[\"optim\"], Optimizer):\n                self._optimizer = opt_dict[\"optim\"]\n            else:\n                raise ValueError('opt_dict[\"optim\"] is not an instance of Optimizer.')\n        else:\n            raise ValueError(\"Key 'optim' doesn't exist in opt_dict.\")\n\n        if \"is_sparse\" in opt_dict and opt_dict[\"is_sparse\"] is True:\n            self._is_sparse = True\n        else:\n            self._is_sparse = False\n\n        self._lr_scheduler = None\n        if \"lr_sch\" in opt_dict:\n            if not isinstance(opt_dict[\"lr_sch\"], LRScheduler):\n                raise ValueError(\n                    'opt_dict[\"lr_sch\"] is not an instance of LRScheduler.'\n                )\n\n            if opt_dict[\"lr_sch\"].optimizer is not self._optimizer:\n                raise ValueError(\"lr_scheduler doesn't match optimizer.\")\n\n            self._lr_scheduler = opt_dict[\"lr_sch\"]\n\n    def generate_optimizer_and_variable_configs(self, job_conf, vars_conf):\n        train_conf = job_conf.train_conf\n\n        if self._optimizer is None:\n            return\n\n        # Check first\n        self._optimizer._check_variables_in_graph(vars_conf)\n        self._optimizer._check_variables_optimizer_bound(vars_conf)\n\n        opt_confs = self._optimizer._generate_conf_for_graph(train_conf, vars_conf)\n\n        if self._is_sparse:\n            self._optimizer._generate_indexed_slices_optimizer_conf(job_conf, vars_conf)\n\n        if self._lr_scheduler is None:\n            return\n\n        for opt_conf in opt_confs:\n            self._lr_scheduler._generate_conf_for_graph(opt_conf.learning_rate_decay)\n\n\nclass VariableConfig(object):\n    def __init__(self, name: str):\n        assert name != \"\"\n        self._name = name\n        self._l2 = 0.0\n        self._bound_opt = None\n\n    @property\n    def name(self):\n        return self._name\n\n    @property\n    def l2(self):\n        return self._l2\n\n    @l2.setter\n    def l2(self, l2: float = 0.0):\n        self._l2 = l2\n\n    @property\n    def bound_optimizer(self):\n        return self._bound_opt\n\n    @bound_optimizer.setter\n    def bound_optimizer(self, opt):\n        self._bound_opt = opt\n\n    def __repr__(self):\n        return \"(variable name: \" + self._name + \"):(l2: \" + str(self._l2) + \".)\"\n"
  },
  {
    "path": "python/oneflow/nn/graph/proxy.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Iterator, Optional, Set, Union, List\nimport weakref\nimport types\n\nimport oneflow._C\nimport oneflow._oneflow_internal\nfrom oneflow.framework import graph_build_util\nfrom oneflow.framework.tensor import Tensor, TensorTuple\nfrom oneflow.nn.modules.module import Module\nfrom oneflow.nn.modules.container import *\nfrom oneflow.nn.utils.container import *\nfrom oneflow.nn.parameter import Parameter\nfrom oneflow.nn.graph.graph_block import (\n    GraphBlockType,\n    GraphBlock,\n    GraphModule,\n    GraphTensor,\n)\nfrom oneflow.nn.graph.util import (\n    add_indent,\n    seq_to_func_return,\n)\nfrom oneflow.framework.args_tree import ArgsTree\n\n\ndef get_proxy_cls(item):\n    if isinstance(item, Sequential):\n        return ProxySequential\n    elif isinstance(item, ModuleList):\n        return ProxyModuleList\n    elif isinstance(item, ModuleDict):\n        return ProxyModuleDict\n    elif isinstance(item, ParameterList):\n        return ProxyParameterList\n    elif isinstance(item, ParameterDict):\n        return ProxyParameterDict\n    elif isinstance(item, Module):\n        return ProxyModule\n    elif isinstance(item, Tensor):\n        return ProxyTensor\n    else:\n        raise NotImplementedError()\n\n\nclass Proxy(object):\n    def __init__(self):\n        \"\"\" An ecution proxy of nn.Module or Tensor.\n\n        A proxy contains the original data(nn.Module or Tensor) and a graph representation of the original data.\n        \"\"\"\n        # The original data\n        self._oneflow_internal_origin__ = None\n        # The graph representation of the original data\n        self._oneflow_internal_graphblock__ = None\n\n    def to(self, *args, **kwargs):\n        \"\"\"\n        \"\"\"\n        if len(args) == 1 and issubclass(args[0], GraphBlock):\n            return self._oneflow_internal_graphblock__\n        elif len(args) == 1 and (args[0] is Module or args[0] is Tensor):\n            return self._oneflow_internal_origin__\n        else:\n            self._oneflow_internal_origin__.to(*args, **kwargs)\n\n\nclass ProxyModule(Proxy):\n    def __init__(\n        self,\n        origin: Module = None,\n        prefix: str = \"\",\n        name: str = \"\",\n        belonged_graph: weakref.ProxyTypes = None,\n    ):\n        assert not isinstance(origin, Proxy)\n        super().__init__()\n        self._oneflow_internal_graphblock__ = GraphModule(\n            prefix, name, belonged_graph, weakref.proxy(self)\n        )\n        self._modules = OrderedDict()\n        self._parameters = OrderedDict()\n        self._buffers = OrderedDict()\n\n        self._oneflow_internal_graphblock__set_origin(origin)\n\n    def _oneflow_internal_graphblock__set_origin(self, origin):\n        self._oneflow_internal_origin__ = origin\n        if origin is None:\n            return\n        assert isinstance(origin, Module)\n        for (n, m) in 
origin._modules.items():\n            self.__setattr__(\n                n,\n                get_proxy_cls(m)(\n                    m,\n                    self.to(GraphModule)._name_prefix\n                    + self.to(GraphModule)._name\n                    + \".\",\n                    n,\n                    self.to(GraphModule)._belonged_graph,\n                ),\n            )\n        for (n, p) in list(origin.named_parameters(\"\", False)):\n            self.__setattr__(\n                n,\n                get_proxy_cls(p)(\n                    p,\n                    self.to(GraphTensor)._name_prefix\n                    + self.to(GraphTensor)._name\n                    + \".\",\n                    n,\n                ),\n            )\n        for (n, b) in list(origin.named_buffers(\"\", False)):\n            self.__setattr__(\n                n,\n                get_proxy_cls(b)(\n                    b,\n                    self.to(GraphTensor)._name_prefix\n                    + self.to(GraphTensor)._name\n                    + \".\",\n                    n,\n                ),\n            )\n\n    def __call__(self, *args, **kwargs):\n        assert self.to(GraphModule)._type == GraphBlockType.MODULE\n        self.__print(0, 1, self._shallow_repr())\n\n        args_tree = ArgsTree(\n            (args, kwargs),\n            True,\n            \"_\"\n            + self.to(GraphModule).name_prefix\n            + self.to(GraphModule).name\n            + \"_input\",\n            None,\n        )\n\n        for (name, arg) in args_tree.iter_named_nodes():\n            if arg.is_leaf():\n                arg_value = arg.value()\n                meta_repr_str = (\n                    arg_value._meta_repr()\n                    if isinstance(arg_value, Tensor)\n                    else str(type(arg_value))\n                )\n                in_str = \"(INPUT:\" + name + \":\" + meta_repr_str + \")\"\n                if not isinstance(arg_value, Tensor):\n                    in_str = \"[WARNING]\" + in_str\n                self.to(GraphModule)._args_repr.append(in_str)\n                self.__print(0, 1, in_str)\n\n        def _print_state(d):\n            for (_, n) in d.items():\n                self.__print(0, 1, n._shallow_repr())\n\n        _print_state(self._parameters)\n        _print_state(self._buffers)\n\n        # NOTE: The original nn.Module's __call__ method is ignored, which means\n        # that hooks of nn.Modules are ignored. 
It is not recommended\n        # to use hooks of nn.Module in nn.Graph for the moment.\n        with graph_build_util.DebugScopeContext(\n            self.to(GraphModule)._debug_min_s_level,\n            self.to(GraphModule)._debug_max_v_level,\n            self.to(GraphModule)._debug,\n            self.to(GraphModule)._debug_max_py_stack_depth,\n            self.to(GraphModule)._debug_only_user_py_stack,\n        ):\n            result = self.__block_forward(*args, **kwargs)\n\n        outputs = ()\n        if not (type(result) is tuple or type(result) is list):\n            outputs = (result,)\n        else:\n            outputs = result\n\n        args_tree = ArgsTree(\n            (outputs, {}),\n            True,\n            \"_\"\n            + self.to(GraphModule).name_prefix\n            + self.to(GraphModule).name\n            + \"_output\",\n            None,\n        )\n\n        for (name, arg) in args_tree.iter_named_nodes():\n            if arg.is_leaf():\n                arg_value = arg.value()\n                meta_repr_str = (\n                    arg_value._meta_repr()\n                    if isinstance(arg_value, Tensor)\n                    else str(type(arg_value))\n                )\n                out_str = \"(OUTPUT:\" + name + \":\" + meta_repr_str + \")\"\n                if not isinstance(arg_value, Tensor):\n                    out_str = \"[WARNING]\" + out_str\n                self.to(GraphModule)._outs_repr.append(out_str)\n                self.__print(0, 1, out_str)\n\n        return result\n\n    @property\n    def __class__(self):\n        if self.to(GraphModule)._belonged_graph._is_user_mode:\n            return self.to(Module).__class__\n        else:\n            return type(self)\n\n    def __block_forward(self, *args, **kwargs):\n        self.to(GraphModule)._is_executing_forward = True\n        args, kwargs = self.__pre_forward_map(*args, **kwargs)\n        with self.to(GraphModule).scope_context():\n            # \"Instance method __func__ is the function object\", \"when an instance method object is called,\n            # the underlying function __func__ is called, inserting the class instance __self__ in front of\n            # the argument list.\"\n            # Reference: https://docs.python.org/3/reference/datamodel.html\n            unbound_forward_of_module_instance = self.to(Module).forward.__func__\n            result = unbound_forward_of_module_instance(self, *args, **kwargs)\n        self.to(GraphModule)._is_executing_forward = False\n        return result\n\n    def __pre_forward_map(self, *args, **kwargs):\n        # Insert an identity op when doing activation checkpointing or pipeline execution.\n        # An identity op outside the activation checkpointing scope will be the endpoint of an activation checkpointing segment.\n        # An identity op as the first op of a pipeline stage will make the backward op depend on the identity op within the stage;\n        # otherwise the backward op may depend on an op in the former stage, which will make the graph create unnecessary buffers.\n        if self.to(GraphModule)._stage_placement is not None:\n\n            def insert_to_global(t):\n                assert isinstance(t, Tensor)\n                return self.__get_or_create_global(\n                    t, self.to(GraphModule)._stage_placement\n                )\n\n            args, kwargs = self.__map_io(\n                \"input\", insert_to_global, \"insert_to_global\", *args, **kwargs\n            )\n
\n        if self.to(GraphModule).activation_checkpointing or (\n            self.to(GraphModule).stage_id is not None\n            and self.to(GraphModule).stage_id >= 0\n        ):\n\n            def insert_identity(t):\n                assert isinstance(t, Tensor)\n                return self.__get_or_create_identity(t)\n\n            args, kwargs = self.__map_io(\n                \"input\", insert_identity, \"insert_identity\", *args, **kwargs\n            )\n\n        return args, kwargs\n\n    def __get_or_create_global(self, input_tensor: Tensor = None, placement=None):\n        assert input_tensor is not None\n        assert placement is not None\n        key = str(id(input_tensor)) + str(placement)\n\n        # input_tensor + placement -> unique_global_tensor\n        if key not in self.to(GraphModule)._belonged_graph._unique_global_op_dict:\n            # store input tensor to avoid tensor id recycle\n            self.to(GraphModule)._belonged_graph._unique_global_op_dict[key] = (\n                input_tensor.to_global(placement=placement),\n                input_tensor,\n            )\n\n        return self.to(GraphModule)._belonged_graph._unique_global_op_dict[key][0]\n\n    def __get_or_create_identity(self, input_tensor: Tensor = None):\n        assert input_tensor is not None\n        key = input_tensor\n\n        # input_tensor(with placement) -> unique_identity_tensor\n        # When placement is different, the input tensor (the output tensor of __get_or_create_global) is different, so the\n        # key can use only the input tensor.\n        if key not in self.to(GraphModule)._belonged_graph._unique_identity_op_dict:\n            # Reuse the current module name for the identity op\n            ident_name_scope = graph_build_util.make_new_name_scope(\n                self.to(GraphModule).prev_scope,\n                self.to(GraphModule).name_prefix + self.to(GraphModule).name,\n            )\n            with graph_build_util.BlockScopeContext(\n                self.to(GraphModule).prev_scope, ident_name_scope\n            ):\n                # store input tensor to avoid tensor id recycle\n                self.to(GraphModule)._belonged_graph._unique_identity_op_dict[\n                    key\n                ] = oneflow._C.identity(input_tensor)\n\n        return self.to(GraphModule)._belonged_graph._unique_identity_op_dict[key]\n\n    def add_module(self, name: str, module: Optional[Module]) -> None:\n        if isinstance(module, Module):\n            self.__setattr__(\n                name,\n                get_proxy_cls(module)(\n                    module,\n                    self.to(GraphModule)._name_prefix\n                    + self.to(GraphModule)._name\n                    + \".\",\n                    name,\n                    self.to(GraphModule)._belonged_graph,\n                ),\n            )\n        elif isinstance(module, Proxy):\n            self.__setattr__(name, module)\n\n    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:\n        self.__setattr__(\n            name,\n            get_proxy_cls(param)(\n                param,\n                self.to(GraphModule)._name_prefix + self.to(GraphModule)._name + \".\",\n                name,\n            ),\n        )\n\n    def modules(self, memo: Optional[Set[\"Proxy\"]] = None) -> Iterator[\"Proxy\"]:\n        assert self.to(GraphModule)._type == GraphBlockType.MODULE\n        if memo is None:\n            memo = set()\n        if self not in memo:\n            memo.add(self)\n            yield 
self\n            for (name, module) in self._modules.items():\n                if module is None:\n                    continue\n                for m in module.modules(memo):\n                    yield m\n\n    def __map_io(self, io_type, func, func_desc, *args, **kwargs):\n        assert isinstance(func_desc, str)\n        assert io_type in (\"input\", \"output\")\n        mapped_args = []\n\n        def map_tensor(item):\n            assert isinstance(item, Tensor)\n            return func(item)\n\n        args_tree = ArgsTree(\n            (args, kwargs),\n            True,\n            \"_\"\n            + self.to(GraphModule).name_prefix\n            + self.to(GraphModule).name\n            + \"_\"\n            + io_type,\n            None,\n        )\n\n        def leaf_node_fn(leaf_node):\n            arg = leaf_node.value()\n            name = leaf_node.prefix() + \"_\" + leaf_node.name()\n            is_tensor, repr_str = self.__io_tensor_check_and_gen(arg, io_type, name)\n            if is_tensor:\n                self.__print(\n                    0,\n                    1,\n                    f\"{repr_str} is a Tensor, {func_desc} transformation has been done.\",\n                )\n                return map_tensor(arg)\n            else:\n                self.__print(\n                    0,\n                    0,\n                    f\"{repr_str} is not a Tensor, {func_desc} transformation will be ignored.\",\n                )\n                return arg\n\n        out = args_tree.map_leaf(leaf_node_fn)\n        mapped_args = out[0]\n        mapped_kwargs = out[1]\n        return mapped_args, mapped_kwargs\n\n    def __io_tensor_check_and_gen(self, item, io_type, name):\n        assert io_type in (\"input\", \"output\")\n        if isinstance(item, Tensor):\n            repr_str = (\n                \"(\" + io_type.upper() + \":\" + name + \":\" + item._meta_repr() + \")\"\n            )\n            return True, repr_str\n        else:\n            repr_str = (\n                \"[WARNING](\"\n                + io_type.upper()\n                + \":\"\n                + name\n                + \":\"\n                + str(type(item))\n                + \")\"\n            )\n            return False, repr_str\n\n    def __members(self, get_members_fn, recurse=True) -> Iterator[\"Proxy\"]:\n        assert self.to(GraphModule)._type == GraphBlockType.MODULE\n        memo = set()\n        modules = self.modules() if recurse else [self]\n        for module in modules:\n            members = get_members_fn(module)\n            for (k, v) in members:\n                if v is None or v in memo:\n                    continue\n                memo.add(v)\n                yield v\n\n    def parameters(self, recurse: bool = True) -> Iterator[\"Proxy\"]:\n        assert self.to(GraphModule)._type == GraphBlockType.MODULE\n        gen = self.__members(lambda module: module._parameters.items(), recurse=recurse)\n        for elem in gen:\n            yield elem\n\n    def buffers(self, recurse: bool = True) -> Iterator[\"Proxy\"]:\n        assert self.to(GraphModule)._type == GraphBlockType.MODULE\n        gen = self.__members(lambda module: module._buffers.items(), recurse=recurse)\n        for elem in gen:\n            yield elem\n\n    def __setattr__(self, name: str, value=None) -> None:\n        if value is None or not isinstance(value, Proxy):\n            self.__dict__[name] = value\n        else:\n            dicts_or_sets = (\n                self.__dict__,\n                
self._modules,\n                self._parameters,\n                self._buffers,\n            )\n            for d in dicts_or_sets:\n                if name in d:\n                    raise AttributeError(\n                        \"'{}' object has duplicated attribute named '{}'\".format(\n                            self.to(GraphModule)._name, name\n                        )\n                    )\n            if value.to(GraphModule).type == GraphBlockType.MODULE:\n                self._modules[name] = value\n            elif value.to(GraphTensor).type == GraphBlockType.PARAMETER:\n                self._parameters[name] = value\n            elif value.to(GraphTensor).type == GraphBlockType.BUFFER:\n                self._buffers[name] = value\n            else:\n                raise AttributeError(\n                    \"'{}' object is not allowed to set attribute named '{}'\".format(\n                        type(self).__name__, name\n                    )\n                )\n\n    def __getattr__(self, name: str):\n        if name in self.__dict__:\n            return self.__dict__[name]\n        # support get module\n        if \"_modules\" in self.__dict__:\n            modules = self.__dict__[\"_modules\"]\n            if name in modules:\n                return modules[name]\n        # support get parameter\n        p_state = self._get_from_states(name, \"_parameters\")\n        if p_state is not None:\n            return p_state\n        # support get buffer\n        b_state = self._get_from_states(name, \"_buffers\")\n        if b_state is not None:\n            return b_state\n        # support none parameter or buffer\n        if name in self.to(Module)._parameters:\n            p_none = self.to(Module)._parameters[name]\n            assert p_none is None\n            return None\n        if name in self.to(Module)._buffers:\n            b_none = self.to(Module)._buffers[name]\n            assert b_none is None\n            return None\n        if hasattr(self.to(Module), name):\n            # support getting normal attr from the nn.Module\n            attr = getattr(self.to(Module), name)\n            if isinstance(attr, types.MethodType):\n                # If the attr is MethodType, rebind the method to self\n                attr = types.MethodType(attr.__func__, self)\n            return attr\n        raise AttributeError(\n            \"'{}' '{}' object '{}' in nn.Graph has no attribute '{}'\".format(\n                self.to(GraphModule)._type,\n                type(self).__name__,\n                self.to(GraphModule)._name_prefix + self.to(GraphModule).name,\n                name,\n            )\n        )\n\n    def _get_from_states(self, name, states_name):\n        if states_name not in self.__dict__:\n            return None\n\n        _states = self.__dict__[states_name]\n        if name not in _states:\n            return None\n\n        _s_block = _states[name]\n        if graph_build_util.lazy_mode.is_enabled():\n            _s_block.try_build()\n            return _s_block.lazy_origin\n        elif (not graph_build_util.lazy_mode.is_enabled()) and self.to(\n            GraphModule\n        )._is_executing_forward:\n            # eager and inside nn.Graph.build()\n            return _s_block.to(Tensor)\n        else:\n            # outside nn.Graph.build()\n            return _s_block\n\n    def __repr__(self):\n        lines = None\n        child_lines = []\n        if len(self.to(GraphModule)._args_repr) > 
0:\n            for in_str in self.to(GraphModule)._args_repr:\n                input_str = add_indent(in_str, 2)\n                child_lines.append(input_str)\n\n        def _append_child(d):\n            for (_, n) in d.items():\n                n_str = repr(n)\n                n_str = add_indent(n_str, 2)\n                child_lines.append(n_str)\n\n        _append_child(self._parameters)\n        _append_child(self._buffers)\n        _append_child(self._modules)\n\n        if len(self.to(GraphModule)._outs_repr) > 0:\n            for out_str in self.to(GraphModule)._outs_repr:\n                output_str = add_indent(out_str, 2)\n                child_lines.append(output_str)\n\n        child_lines.append(add_indent(repr(self.to(GraphModule)), 2))\n\n        if len(child_lines) > 0:\n            lines = child_lines\n\n        main_str = self._shallow_repr() + \": (\"\n        if lines is not None:\n            main_str += \"\\n  \" + \"\\n  \".join(lines) + \"\\n\"\n        main_str += \")\"\n        return main_str\n\n    def _shallow_repr(self):\n        shallow_repr = (\n            \"(\"\n            + self.to(GraphModule)._type\n            + \":\"\n            + self.to(GraphModule)._name_prefix\n            + self.to(GraphModule)._name\n            + \":\"\n            + self._oneflow_internal_origin__._shallow_repr()\n            + \")\"\n        )\n        return shallow_repr\n\n    def __print(self, s_level=2, v_level=0, msg: str = \"\"):\n        r\"\"\"Do print according to info level.\n        \"\"\"\n        assert isinstance(s_level, int)\n        assert isinstance(v_level, int)\n        assert isinstance(msg, str)\n        if s_level >= self.to(GraphModule)._debug_min_s_level:\n            if (s_level > 0) or (\n                s_level == 0 and v_level <= self.to(GraphModule)._debug_max_v_level\n            ):\n                print(msg, flush=True)\n\n\nclass LazyBuilder(object):\n    def __init__(self, name: str = None, method=None):\n        self.name = name\n        self.method = method\n        self.result = None\n        self.finished = False\n\n    def try_build(self, block=None):\n        if not self.finished:\n            assert self.name is not None\n            assert self.method is not None\n            assert self.result is None\n            with block.to(GraphTensor).scope_context():\n                self.result = self.method()\n            self.finished = True\n\n\nclass ProxyTensor(Proxy):\n    def __init__(\n        self,\n        origin: Union[Parameter, Tensor] = None,\n        prefix: str = \"\",\n        name: str = \"\",\n        belonged_graph: weakref.ProxyTypes = None,\n    ):\n        assert not isinstance(origin, Proxy)\n        if isinstance(origin, Parameter):\n            self._oneflow_internal_graphblock__ = GraphTensor(\n                prefix,\n                name,\n                belonged_graph,\n                weakref.proxy(self),\n                GraphBlockType.PARAMETER,\n            )\n        elif isinstance(origin, Tensor):\n            self._oneflow_internal_graphblock__ = GraphTensor(\n                prefix, name, belonged_graph, weakref.proxy(self), GraphBlockType.BUFFER\n            )\n        else:\n            raise NotImplementedError()\n        self._lazy_origin_builder = LazyBuilder()\n        self.build_finished = False\n        self._oneflow_internal_graphblock__set_origin(origin)\n\n    def _oneflow_internal_graphblock__set_origin(self, origin):\n        self._oneflow_internal_origin__ = origin\n\n    @property\n 
   def lazy_origin(self):\n        assert (\n            self.to(GraphTensor)._type == GraphBlockType.PARAMETER\n            or self.to(GraphTensor)._type == GraphBlockType.BUFFER\n        ), \"Only Parameter or Buffer Proxy has lazy_origin\"\n        return self._lazy_origin_builder.result\n\n    def lazy_origin_builder(self):\n        assert (\n            self.to(GraphTensor)._type == GraphBlockType.PARAMETER\n            or self.to(GraphTensor)._type == GraphBlockType.BUFFER\n        ), \"Only Parameter or Buffer Proxy has lazy_origin_builder\"\n        return self._lazy_origin_builder\n\n    def set_lazy_origin_builder(self, builder=None):\n        assert (\n            self.to(GraphTensor)._type == GraphBlockType.PARAMETER\n            or self.to(GraphTensor)._type == GraphBlockType.BUFFER\n        ), \"Only Parameter or Buffer Proxy has lazy_origin_builder\"\n        self._lazy_origin_builder = builder\n\n    def try_build(self):\n        if not self.build_finished:\n            self._lazy_origin_builder.try_build(self)\n            self.build_finished = True\n\n    def __repr__(self):\n        lines = None\n        main_str = self._shallow_repr() + \": (\"\n        if lines is not None:\n            main_str += \"\\n  \" + \"\\n  \".join(lines) + \"\\n\"\n        main_str += \")\"\n        return main_str\n\n    def _shallow_repr(self):\n        shallow_repr = (\n            \"(\"\n            + self.to(GraphTensor)._type\n            + \":\"\n            + self.to(GraphTensor)._name_prefix\n            + self.to(GraphTensor)._name\n            + \":\"\n            + self._oneflow_internal_origin__._meta_repr()\n            + \")\"\n        )\n        return shallow_repr\n\n\nclass ProxySequential(get_seq(ProxyModule)):\n    def __init__(\n        self,\n        origin: Sequential = None,\n        prefix: str = \"\",\n        name: str = \"\",\n        belonged_graph: weakref.ProxyTypes = None,\n    ):\n        super().__init__()\n        self.to(GraphModule)._name_prefix = prefix\n        self.to(GraphModule)._name = name\n        self.to(GraphModule)._belonged_graph = belonged_graph\n        self.to(GraphModule)._belonged_block = weakref.proxy(self)\n        self._oneflow_internal_graphblock__set_origin(origin)\n\n\nclass ProxyModuleList(get_list(ProxyModule)):\n    def __init__(\n        self,\n        origin: ModuleList = None,\n        prefix: str = \"\",\n        name: str = \"\",\n        belonged_graph: weakref.ProxyTypes = None,\n    ):\n        if isinstance(origin, ModuleList):\n            super().__init__()\n            self.to(GraphModule)._name_prefix = prefix\n            self.to(GraphModule)._name = name\n            self.to(GraphModule)._belonged_graph = belonged_graph\n            self._oneflow_internal_graphblock__set_origin(origin)\n            # ModuleList is a container without forward() method,\n\n        elif isinstance(origin, list):\n            super().__init__(origin)\n            first = origin[0]\n            new_name = \"_idx\"\n            new_list = []\n            for item in origin:\n                new_name += \"-\" + item.to(GraphModule).name\n                new_list.append(item.to(Module))\n            new_module_list = ModuleList(new_list)\n            self.to(GraphModule)._name_prefix = (\n                first.to(GraphModule).name_prefix + first.to(GraphModule).name\n            )\n            self.to(GraphModule)._name = new_name\n            self.to(GraphModule)._belonged_graph = first.to(GraphModule)._belonged_graph\n            
self._oneflow_internal_origin__ = new_module_list\n\n\nclass ProxyModuleDict(get_dict(ProxyModule)):\n    def __init__(\n        self,\n        origin: ModuleDict = None,\n        prefix: str = \"\",\n        name: str = \"\",\n        belonged_graph: weakref.ProxyTypes = None,\n    ):\n        super().__init__()\n        self.to(GraphModule)._name_prefix = prefix\n        self.to(GraphModule)._name = name\n        self.to(GraphModule)._belonged_graph = belonged_graph\n        self.to(GraphModule)._belonged_block = weakref.proxy(self)\n        self._oneflow_internal_graphblock__set_origin(origin)\n\n\nclass ProxyParameterList(get_para_list(ProxyModule)):\n    def __init__(\n        self,\n        origin: ParameterList = None,\n        prefix: str = \"\",\n        name: str = \"\",\n        belonged_graph: weakref.ProxyTypes = None,\n    ):\n        super().__init__()\n        self.to(GraphModule)._name_prefix = prefix\n        self.to(GraphModule)._name = name\n        self.to(GraphModule)._belonged_graph = belonged_graph\n        self.to(GraphModule)._belonged_block = weakref.proxy(self)\n        self._oneflow_internal_graphblock__set_origin(origin)\n        self.to(GraphModule)._is_executing_forward = True\n\n    def __getitem__(self, idx):\n        assert isinstance(idx, int)\n        idx = self._get_abs_string_index(idx)\n        key = str(idx)\n        p_state = self._get_from_states(key, \"_parameters\")\n        if p_state is not None:\n            return p_state\n        else:\n            raise AttributeError(\"ParameterList doesn't contain \", key)\n\n\nclass ProxyParameterDict(get_para_dict(ProxyModule)):\n    def __init__(\n        self,\n        origin: ParameterDict = None,\n        prefix: str = \"\",\n        name: str = \"\",\n        belonged_graph: weakref.ProxyTypes = None,\n    ):\n        super().__init__()\n        self.to(GraphModule)._name_prefix = prefix\n        self.to(GraphModule)._name = name\n        self.to(GraphModule)._belonged_graph = belonged_graph\n        self.to(GraphModule)._belonged_block = weakref.proxy(self)\n        self._oneflow_internal_graphblock__set_origin(origin)\n        self.to(GraphModule)._is_executing_forward = True\n\n    def __getitem__(self, key: str):\n        p_state = self._get_from_states(key, \"_parameters\")\n        if p_state is not None:\n            return p_state\n        else:\n            raise AttributeError(\"ParameterDict doesn't contain key \", key)\n"
  },
  {
    "path": "python/oneflow/nn/graph/util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport sys\nfrom string import Template\nfrom typing import Callable, Dict, Union, List, Tuple, Optional\nfrom collections import OrderedDict\n\nfrom google.protobuf import text_format\nfrom google.protobuf.message import Message\n\nimport oneflow\nimport oneflow.core.job.job_pb2 as job_pb\nimport oneflow.core.job.plan_pb2 as plan_pb\nimport oneflow.core.common.device_type_pb2 as device_type\nimport oneflow.core.operator.op_conf_pb2 as op_conf_util\nfrom oneflow.framework.tensor import Tensor\n\n\ndef _nd_sbp2repr(nd_sbp):\n    dim_len = len(nd_sbp.sbp_parallel)\n    nd_sbp_str = \"sbp=(\"\n    for i in range(dim_len):\n        if i > 0:\n            nd_sbp_str += \", \"\n        sbp = nd_sbp.sbp_parallel[i]\n        if sbp.HasField(\"broadcast_parallel\"):\n            nd_sbp_str += \"B\"\n        elif sbp.HasField(\"partial_sum_parallel\"):\n            nd_sbp_str += \"P\"\n        elif sbp.HasField(\"split_parallel\"):\n            nd_sbp_str += \"S(\" + str(sbp.split_parallel.axis) + \")\"\n    nd_sbp_str += \")\"\n    return nd_sbp_str\n\n\ndef _blob_desc_repr(blob_desc):\n    desc_str = \"size=(\"\n    for i in range(len(blob_desc.shape.dim)):\n        if i > 0:\n            desc_str += \", \"\n        desc_str += str(blob_desc.shape.dim[i])\n    desc_str += \"), \"\n    desc_str += \"dtype=(\"\n    desc_str += str(oneflow.dtype.get(int(blob_desc.data_type)))\n    desc_str += \")\"\n    return desc_str\n\n\ndef _get_args_repr(ordered_bn, bn2lbn, bn2nd_sbp, lbn2blob_desc):\n    arg_repr_list = []\n    for bn in ordered_bn:\n        lbns = list(bn2lbn[bn].s)\n\n        # sbp repr\n        sub_bns_sbp = []\n        for bn_idx in range(len(lbns)):\n            sub_bn = bn + \"_\" + str(bn_idx)\n            nd_sbp = bn2nd_sbp[sub_bn]\n            sub_bns_sbp.append(_nd_sbp2repr(nd_sbp))\n\n        # TODO: placement repr\n\n        # shape repr and dtype\n        sub_bns_desc = []\n        for bn_idx in range(len(lbns)):\n            sub_bns_desc.append(_blob_desc_repr(lbn2blob_desc[lbns[bn_idx]]))\n\n        # sub arg repr\n        sub_arg_repr_list = []\n        for bn_idx in range(len(lbns)):\n            sub_arg_repr_list.append(\n                lbns[bn_idx]\n                + \":(\"\n                + sub_bns_sbp[bn_idx]\n                + \", \"\n                + sub_bns_desc[bn_idx]\n                + \")\"\n            )\n\n        if len(lbns) > 1:  # arg of multiple tensors\n            arg_repr_list.append(\"[\" + (\", \").join(sub_arg_repr_list) + \"]\")\n        else:\n            assert len(lbns) == 1\n            arg_repr_list.append(sub_arg_repr_list[0])\n\n    return arg_repr_list\n\n\ndef _get_user_op_io_repr(op_conf, bn2nd_sbp, lbn2blob_desc):\n    user_op_conf = op_conf.user_conf\n    input_sig_str = \", \".join(\n        _get_args_repr(\n            user_op_conf.input_order, user_op_conf.input, bn2nd_sbp, lbn2blob_desc\n        )\n    )\n    
output_sig_str = \", \".join(\n        _get_args_repr(\n            user_op_conf.output_order, user_op_conf.output, bn2nd_sbp, lbn2blob_desc\n        )\n    )\n    return input_sig_str, output_sig_str\n\n\ndef _get_var_op_io_repr(op_conf, bn2nd_sbp, lbn2blob_desc):\n    input_sig_str = \"\"\n    var_op_conf = op_conf.variable_conf\n    output_lbn = op_conf.name + \"/\" + var_op_conf.out\n    output_sig_str = var_op_conf.out\n    nd_sbp = bn2nd_sbp[var_op_conf.out]\n    output_sig_str += (\n        \":\" + _nd_sbp2repr(nd_sbp) + \", \" + _blob_desc_repr(lbn2blob_desc[output_lbn])\n    )\n    return input_sig_str, output_sig_str\n\n\ndef _get_iden_op_io_repr(op_conf, bn2nd_sbp, lbn2blob_desc):\n    iden_op_conf = op_conf.identity_conf\n    input_lbn = getattr(iden_op_conf, \"in\")\n    input_sig_str = (\n        input_lbn\n        + \":\"\n        + _nd_sbp2repr(bn2nd_sbp[\"in\"])\n        + \", \"\n        + _blob_desc_repr(lbn2blob_desc[input_lbn])\n    )\n\n    output_lbn = op_conf.name + \"/\" + iden_op_conf.out\n    output_sig_str = iden_op_conf.out\n    nd_sbp = bn2nd_sbp[iden_op_conf.out]\n    output_sig_str += (\n        \":\" + _nd_sbp2repr(nd_sbp) + \", \" + _blob_desc_repr(lbn2blob_desc[output_lbn])\n    )\n\n    return input_sig_str, output_sig_str\n\n\ndef _get_input_op_io_repr(op_conf, bn2nd_sbp, lbn2blob_desc):\n    op_input_conf = op_conf.input_conf\n    output_lbn = op_conf.name + \"/\" + op_input_conf.out\n    nd_sbp = bn2nd_sbp[op_input_conf.out]\n    output_sig_str = (\n        output_lbn\n        + \":\"\n        + _nd_sbp2repr(nd_sbp)\n        + \", \"\n        + _blob_desc_repr(lbn2blob_desc[output_lbn])\n    )\n    return \"\", output_sig_str\n\n\ndef _get_output_op_io_repr(op_conf, bn2nd_sbp, lbn2blob_desc):\n    op_output_conf = op_conf.output_conf\n    input_lbn = getattr(op_output_conf, \"in\")\n    output_lbn = op_conf.name + \"/\" + op_output_conf.out\n\n    input_sig_str = (\n        input_lbn\n        + \":\"\n        + _nd_sbp2repr(bn2nd_sbp[\"in\"])\n        + \", \"\n        + _blob_desc_repr(lbn2blob_desc[output_lbn])\n    )\n\n    nd_sbp = bn2nd_sbp[op_output_conf.out]\n    output_sig_str = (\n        output_lbn\n        + \":\"\n        + _nd_sbp2repr(nd_sbp)\n        + \", \"\n        + _blob_desc_repr(lbn2blob_desc[output_lbn])\n    )\n    return input_sig_str, output_sig_str\n\n\nclass GraphIR(object):\n    def __init__(self, g_proto: job_pb.Job):\n        assert g_proto is not None and isinstance(g_proto, job_pb.Job)\n        self._graph_proto = g_proto\n        self._op2conf = None\n        self._op2placement = None\n\n    def get_op_conf(self, op_name: str) -> Optional[op_conf_util.OperatorConf]:\n        if self._op2conf is None:\n            self._op2conf = dict()\n            for op_conf in self._graph_proto.net.op:\n                self._op2conf[op_conf.name] = op_conf\n        if op_name not in self._op2conf:\n            return None\n        return self._op2conf[op_name]\n\n    def get_op_placement(self, op_name: str) -> Optional[oneflow.placement]:\n        if self._op2placement is None:\n            self._op2placement = dict()\n            for group in self._graph_proto.placement.placement_group:\n                parallel_conf = group.parallel_conf\n                for this_op_name in group.op_set.op_name:\n                    self._op2placement[this_op_name] = oneflow.placement(\n                        proto_str=text_format.MessageToString(parallel_conf)\n                    )\n        if op_name not in self._op2placement:\n            
return None\n        return self._op2placement[op_name]\n\n\ndef _op_signature(\n    op: op_conf_util.OperatorConf,\n    graph_proto: job_pb.Job,\n    graph_ir: GraphIR,\n    show_op_loc: bool,\n) -> Tuple[bool, str]:\n    bn2nd_sbp = graph_proto.job_parallel_view_conf.op_name2nd_sbp_signature_conf[\n        op.name\n    ].bn_in_op2nd_sbp\n    lbn2blob_desc = graph_proto.helper.lbn2logical_blob_desc\n    signature_template = Template(\n        op.name\n        + \"($input) -> ($output)\"\n        + \", placement=(\"\n        + str(graph_ir.get_op_placement(op.name))\n        + \")\"\n    )\n    input_sig_str = \"...\"\n    output_sig_str = \"...\"\n\n    # Only deal with UserOpConf and VariableOpConf for now.\n    if op.HasField(\"user_conf\"):\n        input_sig_str, output_sig_str = _get_user_op_io_repr(\n            op, bn2nd_sbp, lbn2blob_desc\n        )\n    elif op.HasField(\"variable_conf\"):\n        input_sig_str, output_sig_str = _get_var_op_io_repr(\n            op, bn2nd_sbp, lbn2blob_desc\n        )\n    elif op.HasField(\"identity_conf\"):\n        input_sig_str, output_sig_str = _get_iden_op_io_repr(\n            op, bn2nd_sbp, lbn2blob_desc\n        )\n    elif op.HasField(\"input_conf\"):\n        input_sig_str, output_sig_str = _get_input_op_io_repr(\n            op, bn2nd_sbp, lbn2blob_desc\n        )\n    elif op.HasField(\"output_conf\"):\n        input_sig_str, output_sig_str = _get_output_op_io_repr(\n            op, bn2nd_sbp, lbn2blob_desc\n        )\n    elif op.name.startswith(\"System-\"):\n        return False, \"\"\n\n    op_str = \"(OPERATOR: \"\n    op_str += signature_template.substitute(input=input_sig_str, output=output_sig_str)\n\n    if show_op_loc and op.loc:\n        op_str += \", location=(\" + op.loc + \")\"\n\n    op_str += \")\"\n\n    return True, op_str\n\n\ndef operators_repr(ops: Message, graph_ir: GraphIR, show_op_loc: bool,) -> List[str]:\n    r\"\"\"Generate operators' string representation of this module\n    \"\"\"\n    graph_proto = graph_ir._graph_proto\n    ops_strs = []\n    for op in ops:\n        op_conf = graph_ir.get_op_conf(op)\n        if op_conf is None:\n            continue\n        assert isinstance(op_conf, op_conf_util.OperatorConf)\n        got_repr, op_str = _op_signature(op_conf, graph_proto, graph_ir, show_op_loc)\n        if got_repr:\n            ops_strs.append(op_str)\n    return ops_strs\n\n\ndef add_indent(in_s, num_spaces):\n    s = in_s.split(\"\\n\")\n    if len(s) == 1:\n        return in_s\n    first = s.pop(0)\n    s = [num_spaces * \" \" + line for line in s]\n    s = \"\\n\".join(s)\n    s = first + \"\\n\" + s\n    return s\n\n\ndef sys_exc_error_msg():\n    msg = \"\"\n    exc_info = sys.exc_info()\n    if len(exc_info) > 0:\n        msg += str(exc_info[0])\n    if len(exc_info) > 1:\n        msg += \" \" + str(exc_info[1])\n    return msg\n\n\ndef seq_to_func_return(seq, need_unpack=False):\n    if need_unpack:\n        return seq[0]\n    return seq\n\n\ndef _rsd_sub_destination_to(origin_dict, dest_device_str):\n    dest_dict = OrderedDict()\n    for k, v in origin_dict.items():\n        tensor_item, device_str = v\n        dest_dict[k] = (\n            tensor_item.to(device=oneflow.device(dest_device_str), copy=True),\n            dest_device_str,\n        )\n    return dest_dict\n\n\ndef _parallel_conf_to(parallel_conf, dest_device):\n    if parallel_conf.device_tag == \"cuda\":\n        assert len(parallel_conf.device_name) == 1\n        parallel_conf.device_name[0] = \"@0:\" + 
str(dest_device.index)\n\n\ndef _mem_case_to(mem_case, dest_device):\n    if mem_case.device_type == device_type.DeviceType.kCUDA:\n        mem_case.device_id = dest_device.index\n    if (\n        mem_case.HasField(\"pinned_device_type\")\n        and mem_case.pinned_device_type == device_type.DeviceType.kCUDA\n    ):\n        mem_case.pinned_device_id = dest_device.index\n\n\ndef _job_to(job, dest_device):\n    for pg in job.placement.placement_group:\n        _parallel_conf_to(pg.parallel_conf, dest_device)\n    for bpg in job.placement.blob_placement_group:\n        _parallel_conf_to(bpg.parallel_conf, dest_device)\n\n\ndef _modify_bits(original_num, k, j, new_num):\n    if k > j:\n        return original_num\n    mask = ((1 << (j - k + 1)) - 1) << k\n    cleared_num = original_num & ~mask\n    modified_num = cleared_num | ((new_num & ((1 << (j - k + 1)) - 1)) << k)\n    return modified_num\n\n\ndef _get_bits(original_num, k, j):\n    mask = ((1 << (j - k + 1)) - 1) << k\n    cleared_num = (original_num & mask) >> k\n\n    return cleared_num\n\n\ndef _task_id_to(task_id, dest_device):\n    if _get_bits(task_id, 43, 48) == 2:\n        new_id = _modify_bits(task_id, 36, 43, dest_device.index)\n\n        return new_id\n    else:\n        return task_id\n\n\ndef _thrd_id_to(thrd_id, dest_device):\n    if _get_bits(thrd_id, 22, 27) == 2:\n        new_id = _modify_bits(thrd_id, 15, 22, dest_device.index)\n        return new_id\n    else:\n        return thrd_id\n\n\ndef _plan_to(plan_str, dest_device):\n    plan = plan_pb.Plan()\n    plan.ParseFromString(plan_str)\n    for task in plan.task:\n        task.task_id = _task_id_to(task.task_id, dest_device)\n        task.thrd_id = _thrd_id_to(task.thrd_id, dest_device)\n        for node in task.exec_sequence.exec_node:\n            _parallel_conf_to(\n                node.kernel_conf.op_attribute.parallel_conf_signature.op_parallel_conf,\n                dest_device,\n            )\n        for name, regst in task.produced_regst_desc.items():\n            regst.producer_task_id = _task_id_to(regst.producer_task_id, dest_device)\n            for c_task_id_idx in range(len(regst.consumer_task_id)):\n                regst.consumer_task_id[c_task_id_idx] = _task_id_to(\n                    regst.consumer_task_id[c_task_id_idx], dest_device\n                )\n            _mem_case_to(regst.mem_case, dest_device)\n    for mem_block in plan.block_chunk_list.mem_block:\n        _mem_case_to(mem_block.mem_case, dest_device)\n        mem_block.thrd_id_hint = _thrd_id_to(mem_block.thrd_id_hint, dest_device)\n    for chunk in plan.block_chunk_list.chunk:\n        _mem_case_to(chunk.mem_case, dest_device)\n\n    new_ctrl_regst_desc_id2producer_task_id = {}\n    for (\n        regst_desc_id,\n        producer_task_id,\n    ) in plan.ctrl_regst_desc_info.ctrl_regst_desc_id2producer_task_id.items():\n        new_ctrl_regst_desc_id2producer_task_id[regst_desc_id] = _task_id_to(\n            producer_task_id, dest_device\n        )\n    for (\n        regst_desc_id,\n        producer_task_id,\n    ) in new_ctrl_regst_desc_id2producer_task_id.items():\n        plan.ctrl_regst_desc_info.ctrl_regst_desc_id2producer_task_id[\n            regst_desc_id\n        ] = producer_task_id\n\n    for job_id, op_attr_tab in plan.job_id2op_attribute_ref_table.items():\n        for _, op_attr in op_attr_tab.op_name2op_attribute.items():\n            _parallel_conf_to(\n                op_attr.parallel_conf_signature.op_parallel_conf, dest_device\n            )\n\n    return 
plan.SerializeToString()\n"
  },
  {
    "path": "python/oneflow/nn/image.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.nn.modules.dataset import ImageBatchAlign as batch_align\nfrom oneflow.nn.modules.dataset import ImageDecode as decode\nfrom oneflow.nn.modules.dataset import ImageFlip as flip\nfrom oneflow.nn.modules.dataset import ImageNormalize as normalize\nfrom oneflow.nn.modules.dataset import ImageResize as Resize\n"
  },
  {
    "path": "python/oneflow/nn/init.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport math\nimport warnings\n\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow.ops.util.initializer_util import (\n    calc_gain as calculate_gain,\n    calc_fan,\n    get_data_format,\n)\nfrom oneflow.framework.tensor import Tensor\nimport oneflow.framework.dtype as dtype_util\n\n\ndef uniform_(tensor, a=0.0, b=1.0):\n    r\"\"\"\n    \n    Fills the input Tensor with values drawn from the uniform\n    distribution :math:`\\mathcal{U}(a, b)`.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html.\n\n    Args:\n        tensor: an n-dimensional `oneflow.Tensor`\n        a: the lower bound of the uniform distribution\n        b: the upper bound of the uniform distribution\n\n    Examples:\n        >>> w = flow.empty(3, 5)\n        >>> nn.init.uniform_(w)\n    \"\"\"\n    assert a <= b, \"b must be greater than or equal to a,but got {%d} vs {%d}\" % (b, a)\n    with flow.no_grad():\n        return flow._C.uniform_(tensor, a, b)\n\n\ndef normal_(tensor, mean=0.0, std=1.0):\n    r\"\"\"\n    \n    Fills the input Tensor with values drawn from the normal\n    distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html.\n\n    Args:\n        tensor: an n-dimensional `oneflow.Tensor`\n        mean: the mean of the normal distribution\n        std: the standard deviation of the normal distribution\n\n    Examples:\n        >>> w = flow.empty(3, 5)\n        >>> nn.init.normal_(w)\n    \"\"\"\n    with flow.no_grad():\n        if tensor.is_local:\n            return flow.normal(mean=mean, std=std, size=tensor.shape, out=tensor)\n        else:\n            return flow.normal(\n                mean=mean,\n                std=std,\n                size=tensor.shape,\n                out=tensor,\n                placement=tensor.placement,\n                sbp=tensor.sbp,\n            )\n\n\ndef xavier_uniform_(tensor, gain=1.0, *, data_format=\"NCHW\"):\n    r\"\"\"\n    Fills the input `Tensor` with values according to the method\n    described in `Understanding the difficulty of training deep feedforward\n    neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform\n    distribution. The resulting tensor will have values sampled from\n    :math:`\\mathcal{U}(-a, a)` where\n\n    .. 
math::\n        a = \\text{gain} \\times \\sqrt{\\frac{6}{\\text{fan_in} + \\text{fan_out}}}\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html.\n\n    Also known as Glorot initialization.\n\n    Args:\n        tensor: an n-dimensional `oneflow.Tensor`\n        gain: an optional scaling factor\n\n    Examples:\n        >>> w = flow.empty(3, 5)\n        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))\n    \"\"\"\n    fan = calc_fan(tensor.shape, \"fan_sum\", get_data_format(data_format))\n    std = gain * math.sqrt(2.0 / fan)\n    bound = math.sqrt(3.0) * std\n    return uniform_(tensor, -bound, bound)\n\n\ndef xavier_normal_(tensor, gain=1.0, *, data_format=\"NCHW\"):\n    r\"\"\"\n    Fills the input `Tensor` with values according to the method\n    described in `Understanding the difficulty of training deep feedforward\n    neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal\n    distribution. The resulting tensor will have values sampled from\n    :math:`\\mathcal{N}(0, \\text{std}^2)` where\n\n    .. math::\n        \\text{std} = \\text{gain} \\times \\sqrt{\\frac{2}{\\text{fan_in} + \\text{fan_out}}}\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html.\n\n    Also known as Glorot initialization.\n\n    Args:\n        tensor: an n-dimensional `oneflow.Tensor`\n        gain: an optional scaling factor\n\n    Examples:\n        >>> w = flow.empty(3, 5)\n        >>> nn.init.xavier_normal_(w)\n    \"\"\"\n    if os.getenv(\"ONEFLOW_ENABLE_NHWC\") == \"1\":\n        data_format = \"NHWC\"\n    fan = calc_fan(tensor.shape, \"fan_sum\", get_data_format(data_format))\n    std = gain * math.sqrt(2.0 / fan)\n    return normal_(tensor, 0.0, std)\n\n\ndef orthogonal_(tensor, gain=1.0):\n    r\"\"\"\n    Fills the input `Tensor` with a (semi) orthogonal matrix, as\n    described in `Exact solutions to the nonlinear dynamics of learning in deep\n    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have\n    at least 2 dimensions, and for tensors with more than 2 dimensions the\n    trailing dimensions are flattened.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html.\n\n    Args:\n        tensor: an n-dimensional `torch.Tensor`, where :math:`n \\geq 2`\n        gain: optional scaling factor\n\n    Examples:\n        >>> w = flow.empty(3, 5)\n        >>> nn.init.orthogonal_(w)\n    \"\"\"\n    with flow.no_grad():\n        return tensor.orthogonal_(gain)\n\n\ndef kaiming_uniform_(\n    tensor, a=0, mode=\"fan_in\", nonlinearity=\"leaky_relu\", *, data_format=\"NCHW\"\n):\n    r\"\"\"\n    Fills the input `Tensor` with values according to the method\n    described in `Delving deep into rectifiers: Surpassing human-level\n    performance on ImageNet classification` - He, K. et al. (2015), using a\n    uniform distribution. The resulting tensor will have values sampled from\n    :math:`\\mathcal{U}(-\\text{bound}, \\text{bound})` where\n\n    .. 
math::\n        \\text{bound} = \\text{gain} \\times \\sqrt{\\frac{3}{\\text{fan_mode}}}\n    \n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html.\n\n    Also known as He initialization.\n\n    Args:\n        tensor: an n-dimensional `oneflow.Tensor`\n        a: the negative slope of the rectifier used after this layer (only\n            used with ``'leaky_relu'``)\n        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``\n            preserves the magnitude of the variance of the weights in the\n            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the\n            backwards pass.\n        nonlinearity: the non-linear function (`nn.functional` name),\n            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).\n\n    Examples:\n        >>> w = flow.empty(3, 5)\n        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')\n    \"\"\"\n    if os.getenv(\"ONEFLOW_ENABLE_NHWC\") == \"1\":\n        data_format = \"NHWC\"\n    fan = calc_fan(tensor.shape, mode, get_data_format(data_format))\n    gain = calculate_gain(nonlinearity, a)\n    std = gain / math.sqrt(fan)\n    bound = math.sqrt(3.0) * std\n    return uniform_(tensor, -bound, bound)\n\n\ndef kaiming_normal_(\n    tensor, a=0, mode=\"fan_in\", nonlinearity=\"leaky_relu\", *, data_format=\"NCHW\"\n):\n    r\"\"\"    \n    Fills the input `Tensor` with values according to the method\n    described in `Delving deep into rectifiers: Surpassing human-level\n    performance on ImageNet classification` - He, K. et al. (2015), using a\n    normal distribution. The resulting tensor will have values sampled from\n    :math:`\\mathcal{N}(0, \\text{std}^2)` where\n\n    .. math::\n        \\text{std} = \\frac{\\text{gain}}{\\sqrt{\\text{fan_mode}}}\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html.\n\n    Also known as He initialization.\n\n    Args:\n        tensor: an n-dimensional `oneflow.Tensor`\n        a: the negative slope of the rectifier used after this layer (only\n            used with ``'leaky_relu'``)\n        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``\n            preserves the magnitude of the variance of the weights in the\n            forward pass. 
Choosing ``'fan_out'`` preserves the magnitudes in the\n            backwards pass.\n        nonlinearity: the non-linear function (`nn.functional` name),\n            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).\n\n    Examples:\n        >>> w = flow.empty(3, 5)\n        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')\n    \"\"\"\n    if os.getenv(\"ONEFLOW_ENABLE_NHWC\") == \"1\":\n        data_format = \"NHWC\"\n    assert mode in [\"fan_in\", \"fan_out\"]\n    fan = calc_fan(tensor.shape, mode, get_data_format(data_format))\n    gain = calculate_gain(nonlinearity, a)\n    std = gain / math.sqrt(fan)\n    return normal_(tensor, 0.0, std)\n\n\n# The trunc_normal_ implementation is referenced from https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py#L22\ndef trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):\n    r\"\"\"Fills the input Tensor with values drawn from a normal distribution\n    with the given ``mean`` and ``std``, truncated to the range :math:`[a, b]`.\n    \"\"\"\n    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf\n    def norm_cdf(x):\n        # Computes standard normal cumulative distribution function\n        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0\n\n    if (mean < a - 2 * std) or (mean > b + 2 * std):\n        warnings.warn(\n            \"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. \"\n            \"The distribution of values may be incorrect.\",\n            stacklevel=2,\n        )\n\n    with flow.no_grad():\n        # Values are generated by using a truncated uniform distribution and\n        # then using the inverse CDF for the normal distribution.\n        # Get upper and lower cdf values\n        l = norm_cdf((a - mean) / std)\n        u = norm_cdf((b - mean) / std)\n\n        # Uniformly fill tensor with values from [l, u], then translate to\n        # [2l-1, 2u-1].\n        tensor.uniform_(2 * l - 1, 2 * u - 1)\n\n        # Use inverse cdf transform for normal distribution to get truncated\n        # standard normal\n        tensor.erfinv_()\n\n        # Transform to proper mean, std\n        tensor.mul_(std * math.sqrt(2.0))\n        tensor.add_(mean)\n\n        # Clamp to ensure it's in the proper range\n        tensor.clamp_(min=a, max=b)\n        return tensor\n\n\ndef constant_(tensor, val):\n    r\"\"\"\n    \n    Fills the input Tensor with the value :math:`\\text{val}`.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html.\n\n    Args:\n        tensor: an n-dimensional `oneflow.Tensor`\n        val: the value to fill the tensor with\n\n    Examples:\n        >>> w = flow.empty(3, 5)\n        >>> nn.init.constant_(w, 0.3)\n    \"\"\"\n    with flow.no_grad():\n        tensor[...] 
= val\n        return tensor\n\n\ndef ones_(tensor):\n    r\"\"\"\n    \n    Fills the input Tensor with the scalar value `1`.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html.\n\n    Args:\n        tensor: an n-dimensional `oneflow.Tensor`\n\n    Examples:\n        >>> w = flow.empty(3, 5)\n        >>> nn.init.ones_(w)\n    \"\"\"\n    with flow.no_grad():\n        return constant_(tensor, 1)\n\n\ndef zeros_(tensor):\n    r\"\"\"\n    \n    Fills the input Tensor with the scalar value `0`.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html.\n\n    Args:\n        tensor: an n-dimensional `oneflow.Tensor`\n\n    Examples:\n        >>> w = flow.empty(3, 5)\n        >>> nn.init.zeros_(w)\n    \"\"\"\n    with flow.no_grad():\n        return constant_(tensor, 0)\n\n\ndef eye_(tensor):\n    r\"\"\"\n    \n    Fills the 2-dimensional input `Tensor` with the identity\n    matrix. Preserves the identity of the inputs in `Linear` layers, where as\n    many inputs are preserved as possible.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/nn.init.html.\n\n    Args:\n        tensor: a 2-dimensional `oneflow.Tensor`\n\n    Examples:\n        >>> w = flow.empty(3, 5)\n        >>> nn.init.eye_(w)\n    \"\"\"\n    if tensor.ndimension() != 2:\n        raise ValueError(\"Only tensors with 2 dimensions are supported\")\n    with flow.no_grad():\n        return flow._C.eye_(tensor)\n\n\ndef _calculate_fan_in_and_fan_out(tensor):\n    dimensions = tensor.ndimension()\n    if dimensions < 2:\n        raise ValueError(\n            \"Fan in and fan out can not be computed for tensor with fewer than 2 dimensions\"\n        )\n    num_input_fmaps = tensor.size(1)\n    num_output_fmaps = tensor.size(0)\n    receptive_field_size = 1\n    if tensor.ndimension() > 2:\n        for s in tensor.size()[2:]:\n            receptive_field_size *= s\n    fan_in = num_input_fmaps * receptive_field_size\n    fan_out = num_output_fmaps * receptive_field_size\n    return (fan_in, fan_out)\n"
  },
  {
    "path": "python/oneflow/nn/modules/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom .module import Module\n"
  },
  {
    "path": "python/oneflow/nn/modules/_functions.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\n\n\nclass BackwardHookFunction(flow.autograd.Function):\n    @staticmethod\n    def forward(ctx, *args):\n        ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])\n        return args\n\n    @staticmethod\n    def backward(ctx, *args):\n        return args\n"
  },
  {
    "path": "python/oneflow/nn/modules/activation.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport warnings\nfrom typing import Optional\n\nimport oneflow as flow\nfrom oneflow.nn.modules.module import Module\n\n\nclass PReLU(Module):\n    \"\"\"Applies the element-wise function:\n\n    .. math::\n        PReLU(x) = \\\\max(0,x) + a * \\\\min(0,x)\n\n    Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single\n    parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,\n    a separate :math:`a` is used for each input channel.\n\n\n    .. note::\n        weight decay should not be used when learning :math:`a` for good performance.\n\n    .. note::\n        Channel dim is the 2nd dim of input. When input has dims < 2, then there is\n        no channel dim and the number of channels = 1.\n\n    Args:\n        num_parameters (int): number of :math:`a` to learn.\n            Although it takes an int as input, there is only two values are legitimate:\n            1, or the number of channels at input. Default: 1\n        init (float): the initial value of :math:`a`. Default: 0.25\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    Attr:\n        - weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> m = flow.nn.PReLU()\n        >>> input = flow.tensor(np.asarray([[[[1, -2], [3, 4]]]]), dtype=flow.float32)\n        >>> print(m(input).numpy())\n        [[[[ 1.  -0.5]\n           [ 3.   4. ]]]]\n\n    \"\"\"\n\n    def __init__(\n        self, num_parameters: int = 1, init: float = 0.25, device=None, dtype=None\n    ) -> None:\n        super().__init__()\n        self.num_parameters = num_parameters\n        self.weight = flow.nn.Parameter(\n            flow.empty(num_parameters, dtype=dtype, device=device).fill_(init)\n        )\n\n    def forward(self, x):\n        return flow._C.prelu(x, self.weight)\n\n    def extra_repr(self) -> str:\n        return \"num_parameters={}\".format(self.num_parameters)\n\n\nclass ReLU(Module):\n    \"\"\"Applies the rectified linear unit function element-wise:\n\n    :math:`\\\\text{ReLU}(x) = (x)^+ = \\\\max(0, x)`\n\n    Args:\n        inplace: can optionally do the operation in-place. Default: ``False``\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> relu = flow.nn.ReLU()\n        >>> ndarr = np.asarray([1, -2, 3])\n        >>> x = flow.Tensor(ndarr)\n        >>> relu(x)\n        tensor([1., 0., 3.], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self, inplace: bool = False):\n        super().__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        return flow._C.relu(x, self.inplace)\n\n    def extra_repr(self):\n        inplace_str = \"inplace=True\" if self.inplace else \"\"\n        return inplace_str\n\n\nclass ReLU6(Module):\n    \"\"\"Applies the element-wise function:\n\n    .. math::\n\n        \\\\text{Relu6}(x) = \\\\begin{cases}\n            6 & \\\\text{ if } x > 6 \\\\\\\\\n            0 & \\\\text{ if } x < 0 \\\\\\\\\n            x & \\\\text{ otherwise } \\\\\\\\\n        \\\\end{cases}\n\n    Args:\n        inplace: can optionally do the operation in-place. Default: ``False``\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> relu6 = flow.nn.ReLU6()\n\n        >>> out = relu6(input)\n        >>> out\n        tensor([0.0000, 0.0000, 0.5000], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self, inplace: bool = False):\n        super().__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        if self.inplace:\n            warnings.warn(\"ReLU6 module does not support inplace yet\")\n        return flow._C.hardtanh(x, min_val=0.0, max_val=6.0)\n\n    def extra_repr(self):\n        inplace_str = \"inplace=True\" if self.inplace else \"\"\n        return inplace_str\n\n\ndef relu6(input, inplace=False):\n    r\"\"\"relu6(input, inplace=False) -> Tensor\n\n    Applies the element-wise function :math:`\\text{ReLU6}(x) = \\min(\\max(0,x), 6)`.\n\n    See :class:`~oneflow.nn.ReLU6` for more details.\n    \"\"\"\n    if inplace:\n        warnings.warn(\"nn.functional.relu6 does not support inplace yet\")\n    return flow._C.hardtanh(input, min_val=0.0, max_val=6.0)\n\n\nclass Tanh(Module):\n    \"\"\"This operator computes the hyperbolic tangent value of a Tensor.\n\n    The equation is:\n\n    .. math::\n\n        out = \\\\frac{e^x-e^{-x}}{e^x+e^{-x}}\n\n    Args:\n        input (oneflow.Tensor): A Tensor\n\n    Returns:\n        oneflow.Tensor: The result Tensor\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> x = np.array([-1, 0, 1]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> tanh = flow.nn.Tanh()\n        >>> out = tanh(input)\n        >>> out\n        tensor([-0.7616,  0.0000,  0.7616], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, input):\n        return flow._C.tanh(input)\n\n\nclass ELU(Module):\n    \"\"\"Applies the element-wise function \n        :math:`\\\\text{ELU}(x) = \\\\begin{cases}x & \\\\text{ if } x \\\\gt 0  \\\\\\\\\\\\alpha*(exp(x)-1) & \\\\text{ if } x \\\\le 0 \\\\\\\\\\\\end{cases}`\n\n    Args:\n        alpha: the :math:`\\\\alpha` value for the ELU formulation. 
Default: 1.0\n        inplace: can optionally do the operation in-place. Default: ``False``\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> elu = flow.nn.ELU()\n\n        >>> out = elu(input)\n        >>> out\n        tensor([-0.3935,  0.0000,  0.5000], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self, alpha: float = 1.0, inplace: bool = False):\n        super().__init__()\n        self.alpha = alpha\n        self.inplace = inplace\n\n    def forward(self, x):\n        if self.inplace:\n            warnings.warn(\"ELU module does not support inplace yet\")\n        return flow._C.elu(x, alpha=self.alpha)\n\n    def extra_repr(self):\n        param_str = f\"alpha={self.alpha}\"\n        param_str += \", inplace=True\" if self.inplace else \"\"\n        return param_str\n\n\nclass CELU(Module):\n    \"\"\"Applies the element-wise function:\n\n    .. math::\n\n        \\\\text{CELU}(x, \\\\alpha) = \\\\begin{cases}\n\t\t\t\tx & \\\\text{ if } x \\\\ge 0  \\\\\\\\\n                \\\\alpha*(exp(\\\\frac{x}{\\\\alpha})-1) & \\\\text{ otherwise } \\\\\\\\\n    \t\t    \\\\end{cases}\n\n    Args:\n        alpha: the :math:`\\\\alpha` value for the CELU formulation. Default: 1.0\n        inplace: can optionally do the operation in-place. Default: ``False``\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> celu = flow.nn.CELU(alpha=0.5)\n\n        >>> out = celu(input)\n        >>> out\n        tensor([-0.3161,  0.0000,  0.5000], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self, alpha: float = 1.0, inplace: bool = False):\n        super().__init__()\n        self.alpha = alpha\n        self.inplace = inplace\n\n    def forward(self, x):\n        return flow._C.celu(x, alpha=self.alpha, inplace=self.inplace)\n\n    def extra_repr(self):\n        param_str = f\"alpha={self.alpha}\"\n        param_str += \", inplace=True\" if self.inplace else \"\"\n        return param_str\n\n\nclass GELU(Module):\n    \"\"\"\n    GELU(approximate='none') -> Tensor\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.GELU.html.\n\n    Applies the Gaussian Error Linear Units function:\n\n    .. math:: \\\\text{GELU}(x) = x * \\Phi(x)\n\n    where :math:`\\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.\n\n    When the approximate argument is 'tanh', GELU is estimated with:\n\n    .. math:: \\\\text{GELU}(x) = 0.5 * x * (1 + \\\\text{Tanh}(\\sqrt(2 / \\pi) * (x + 0.044715 * x^3)))\n\n    Args:\n        input (oneflow.Tensor): Input Tensor\n        approximate (string, optional): the gelu approximation algorithm to use:\n            ``'none'`` | ``'tanh'``. Default: ``'none'``\n\n    Returns:\n        oneflow.Tensor: A Tensor with the same shape as the input.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> gelu = flow.nn.GELU()\n\n        >>> out = gelu(input)\n        >>> out\n        tensor([-0.1543,  0.0000,  0.3457], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self, approximate: str = \"none\"):\n        super().__init__()\n        self.approximate = approximate\n\n    def forward(self, input):\n        if self.approximate == \"none\" or self.approximate == \"tanh\":\n            return flow._C.gelu_with_approximate(input, self.approximate)\n        else:\n            raise NotImplementedError\n\n\nclass QuickGELU(Module):\n    \"\"\"\n    QuickGELU() -> Tensor\n\n    Applies a GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs\n\n    .. math::\n        \\\\text{QuickGELU}(x) = x * \\\\sigma(1.702x) = x * \\\\frac{1}{1 + \\\\exp(-1.702x)}\n\n    Args:\n        input (oneflow.Tensor): Input Tensor\n\n    Returns:\n        oneflow.Tensor: A Tensor with the same shape as the input.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        \n        >>> input = flow.Tensor([-0.5, 0, 0.5])\n        >>> gelu = flow.nn.QuickGELU()\n\n        >>> out = gelu(input)\n        >>> out\n        tensor([-0.1496,  0.0000,  0.3504], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x):\n        return flow._C.quick_gelu(x)\n\n\nclass SquareReLU(Module):\n    \"\"\"\n    SquareReLU() -> Tensor\n\n    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2\n\n    .. math::\n        \\\\text{SquareReLU}(x) = \\\\max(0, x) * \\\\max(0, x)\n\n    Args:\n        input (oneflow.Tensor): Input Tensor\n\n    Returns:\n        oneflow.Tensor: A Tensor with the same shape as the input.\n        \n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> square_relu = flow.nn.SquareReLU()\n\n        >>> out = square_relu(input)\n        >>> out\n        tensor([0.0000, 0.0000, 0.2500], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x):\n        return flow._C.square_relu(x)\n\n\nclass Sigmoid(Module):\n    \"\"\"Applies the element-wise function:\n\n    .. math::\n        \\\\text{Sigmoid}(x) = \\\\sigma(x) = \\\\frac{1}{1 + \\\\exp(-x)}\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> x = flow.Tensor(np.array([0.81733328, 0.43621480, 0.10351428]))\n        >>> m = flow.nn.Sigmoid()\n        >>> out = m(x)\n        >>> out\n        tensor([0.6937, 0.6074, 0.5259], dtype=oneflow.float32)\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x):\n        return flow._C.sigmoid(x)\n\n\nclass Hardsigmoid(Module):\n    \"\"\"Applies the element-wise function:\n\n    .. 
math::\n        \\\\text{Hardsigmoid}(x) = \\\\begin{cases}\n            0 & \\\\text{ if } x \\\\le -3  \\\\\\\\\n            1 & \\\\text{ if } x \\\\ge +3 \\\\\\\\\n            \\\\frac{x}{6} + \\\\frac{1}{2} & \\\\text{ otherwise } \\\\\\\\\n        \\\\end{cases}\n\n    Args:\n        inplace: can optionally do the operation in-place. Default: ``False``\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> hardsigmoid = flow.nn.Hardsigmoid()\n\n        >>> out = hardsigmoid(input)\n        >>> out\n        tensor([0.4167, 0.5000, 0.5833], dtype=oneflow.float32)\n\n\n    \"\"\"\n\n    def __init__(self, inplace: bool = False):\n        super().__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        if self.inplace:\n            return flow._C.hardsigmoid(x, True)\n        return flow._C.hardsigmoid(x, False)\n\n    def extra_repr(self):\n        inplace_str = \"inplace=True\" if self.inplace else \"\"\n        return inplace_str\n\n\nclass Hardshrink(Module):\n    r\"\"\"\n    The Hardshrink activation.\n\n    The formula is:\n\n    .. math::\n        \\text{Hardshrink}(x) =\n        \\begin{cases}\n        x, & \\text{ if } x > \\lambda \\\\\n        x, & \\text{ if } x < -\\lambda \\\\\n        0, & \\text{ otherwise }\n        \\end{cases}\n\n    Args:\n        lambd: the :math:`\\lambda` value for the Hardshrink formulation. Default: 0.5\n        inplace: can optionally do the operation in-place. Default: ``False``\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. code-block:: python\n    \n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> x = np.array([-1.1, 0, 0.2, 0.5]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> hardshrink = flow.nn.Hardshrink(lambd=0.5)\n        >>> out = hardshrink(input)\n        >>> out\n        tensor([-1.1000,  0.0000,  0.0000,  0.0000], dtype=oneflow.float32)\n    \"\"\"\n\n    def __init__(self, lambd: float = 0.5, inplace: bool = False):\n        super().__init__()\n        self.inplace = inplace\n        self.lambd = lambd\n\n    def forward(self, x):\n        return flow._C.hardshrink(x, lambd=self.lambd, inplace=self.inplace)\n\n    def extra_repr(self) -> str:\n        param_str = f\"lambd={self.lambd}\"\n        param_str += \", inplace=True\" if self.inplace else \"\"\n        return param_str\n\n\nclass Softmax(Module):\n    \"\"\"Applies the Softmax function to an n-dimensional input Tensor\n    rescaling them so that the elements of the n-dimensional output Tensor\n    lie in the range [0,1] and sum to 1.\n\n    Softmax is defined as:\n\n    .. 
math::\n        \\\\text{Softmax}(x_{i}) = \\\\frac{\\\\exp(x_i)}{\\\\sum_j \\\\exp(x_j)}\n\n    When the input Tensor is a sparse tensor, the unspecified\n    values are treated as ``-inf``.\n\n    Shape:\n        - Input: :math:`(*)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(*)`, same shape as the input\n\n    Returns:\n        a Tensor of the same dimension and shape as the input with\n        values in the range [0, 1]\n\n    Args:\n        dim (int): A dimension along which Softmax will be computed (so every slice\n            along dim will sum to 1).\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> m = flow.nn.Softmax(dim = 2)\n        >>> x = flow.Tensor(\n        ...    np.array(\n        ...        [[[-0.46716809,  0.40112534,  0.61984003],\n        ...        [-1.31244969, -0.42528763,  1.47953856]]]\n        ...    )\n        ... )\n        >>> out = m(x)\n        >>> out\n        tensor([[[0.1575, 0.3754, 0.4671],\n                 [0.0507, 0.1230, 0.8263]]], dtype=oneflow.float32)\n    \"\"\"\n\n    def __init__(self, dim: Optional[int] = None):\n        super(Softmax, self).__init__()\n        self.dim = dim\n\n    def forward(self, x):\n        return flow._C.softmax(x, self.dim)\n\n    def extra_repr(self):\n        return f\"dim={self.dim}\"\n\n\nclass LogSoftmax(Module):\n    r\"\"\"Applies the LogSoftmax function to an n-dimensional\n    input Tensor.\n    The LogSoftmax formulation can be simplified as:\n\n    .. math::\n        \\text{LogSoftmax}(x_{i}) = \\log\\left(\\frac{\\exp(x_i) }{ \\sum_j \\exp(x_j)} \\right) = x_i - \\log({ \\sum_j \\exp(x_j)})\n\n    Args:\n        dim (int): A dimension along which LogSoftmax will be computed.\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> m = flow.nn.LogSoftmax(dim=1)\n        >>> x = flow.Tensor(\n        ...    np.array(\n        ...        [[ 0.4296, -1.1957,  2.5463],\n        ...        [ 1.2552, -1.5747,  0.6923]]\n        ...    )\n        ... )\n        >>> out = m(x)\n        >>> out\n        tensor([[-2.2513, -3.8766, -0.1346],\n                [-0.4877, -3.3176, -1.0506]], dtype=oneflow.float32)\n    \"\"\"\n\n    def __init__(self, dim: Optional[int] = None):\n        super(LogSoftmax, self).__init__()\n        self.dim = dim\n\n    def forward(self, x):\n        return flow._C.log_softmax(x, self.dim)\n\n    def extra_repr(self):\n        return f\"dim={self.dim}\"\n\n\nclass LogSigmoid(Module):\n    \"\"\"Applies the element-wise function:\n\n    .. math::\n        \\\\text{LogSigmoid}(x) = \\\\log\\\\left(\\\\frac{ 1 }{ 1 + \\\\exp(-x)}\\\\right)\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. 
code-block:: python\n\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> logsigmoid = flow.nn.LogSigmoid()\n\n        >>> out = logsigmoid(input)\n        >>> out\n        tensor([-0.9741, -0.6931, -0.4741], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x):\n        return flow._C.logsigmoid(x)\n\n\nclass Softplus(Module):\n    \"\"\"Applies the element-wise function:\n\n    .. math::\n        \\\\text{Softplus}(x) = \\\\frac{1}{\\\\beta} * \\\\log(1 + \\\\exp(\\\\beta * x))\n\n    SoftPlus is a smooth approximation to the ReLU function and can be used\n    to constrain the output of a machine to always be positive.\n\n    For numerical stability the implementation reverts to the linear function\n    when :math:`input \\\\times \\\\beta > threshold`.\n\n    Args:\n        beta: the :math:`\\\\beta` value for the Softplus formulation. Default: 1\n        threshold: values above this revert to a linear function. Default: 20\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> softplus = flow.nn.Softplus()\n\n        >>> out = softplus(input)\n        >>> out\n        tensor([0.4741, 0.6931, 0.9741], dtype=oneflow.float32)\n    \"\"\"\n\n    def __init__(self, beta: int = 1, threshold: int = 20):\n        super().__init__()\n        self.beta = beta\n        self.threshold = threshold\n\n    def forward(self, x):\n        return flow._C.softplus(x, beta=self.beta, threshold=self.threshold)\n\n    def extra_repr(self):\n        return f\"beta={self.beta}, threshold={self.threshold}\"\n\n\nclass Hardswish(Module):\n    \"\"\"Applies the hardswish function, element-wise, as described in the paper `Searching for MobileNetV3\n    <https://arxiv.org/abs/1905.02244>`__.\n\n    .. math::\n        \\\\text{Hardswish}(x) = \\\\begin{cases}\n            0 & \\\\text{ if } x \\\\le -3  \\\\\\\\\n            x & \\\\text{ if } x \\\\ge +3 \\\\\\\\\n            x*(x+3)/6 & \\\\text{ otherwise } \\\\\\\\\n        \\\\end{cases}\n\n    Args:\n        inplace: can optionally do the operation in-place. Default: ``False``\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    .. 
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> x = np.array([-0.5, 0, 0.5]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> hardswish = flow.nn.Hardswish()\n\n        >>> out = hardswish(input)\n        >>> out\n        tensor([-0.2083,  0.0000,  0.2917], dtype=oneflow.float32)\n        \n    \"\"\"\n\n    def __init__(self, inplace: bool = False):\n        super().__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        if self.inplace:\n            warnings.warn(\"Hardswish module does not support inplace yet\")\n        return flow._C.hardswish(x)\n\n    def extra_repr(self):\n        inplace_str = \"inplace=True\" if self.inplace else \"\"\n        return inplace_str\n\n\nclass Hardtanh(Module):\n    \"\"\"\n    Applies the HardTanh function element-wise\n\n    HardTanh is defined as:\n\n    .. math::\n        \\\\text{HardTanh}(x) = \\\\begin{cases}\n            1 & \\\\text{ if } x > 1 \\\\\\\\\n            -1 & \\\\text{ if } x < -1 \\\\\\\\\n            x & \\\\text{ otherwise } \\\\\\\\\n        \\\\end{cases}\n\n    The range of the linear region :math:`[-1, 1]` can be adjusted using\n    :attr:`min_val` and :attr:`max_val`.\n\n    Args:\n        min_val: minimum value of the linear region range. Default: -1\n        max_val: maximum value of the linear region range. Default: 1\n        inplace: can optionally do the operation in-place. Default: ``False``\n\n    Keyword arguments :attr:`min_value` and :attr:`max_value`\n    have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> m = flow.nn.Hardtanh()\n        >>> arr = np.array([0.2, 0.3, 3.0, 4.0])\n        >>> x = flow.Tensor(arr)\n        >>> out = m(x)\n        >>> out\n        tensor([0.2000, 0.3000, 1.0000, 1.0000], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        min_val: float = -1,\n        max_val: float = 1,\n        inplace: bool = False,\n        min_value: Optional[float] = None,\n        max_value: Optional[float] = None,\n    ):\n        super().__init__()\n        if min_value is not None:\n            warnings.warn(\n                \"keyword argument min_value is deprecated and renamed to min_val\"\n            )\n            min_val = min_value\n        if max_value is not None:\n            warnings.warn(\n                \"keyword argument max_value is deprecated and renamed to max_val\"\n            )\n            max_val = max_value\n        self.min_val = min_val\n        self.max_val = max_val\n        self.inplace = inplace\n\n    def forward(self, x):\n        if self.inplace:\n            warnings.warn(\"Hardtanh module does not support inplace yet\")\n        return flow._C.hardtanh(x, min_val=self.min_val, max_val=self.max_val)\n\n    def extra_repr(self):\n        param_str = f\"min_val={self.min_val}, max_val={self.max_val}\"\n        param_str += \", inplace=True\" if self.inplace else \"\"\n        return param_str\n\n\nclass LeakyReLU(Module):\n    \"\"\"Applies the element-wise function:\n\n    .. 
math::\n        \\\\text{LeakyRELU}(x) = \\\\begin{cases}\n            x, & \\\\text{ if } x \\\\geq 0 \\\\\\\\\n            \\\\text{negative_slope} \\\\times x, & \\\\text{ otherwise }\n        \\\\end{cases}\n\n    Args:\n        negative_slope: Controls the angle of the negative slope. Default: 1e-2\n        inplace: can optionally do the operation in-place. Default: ``False``\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> m = flow.nn.LeakyReLU(0.1)\n        >>> arr = np.array([0.2, 0.3, 3.0, 4.0])\n        >>> x = flow.Tensor(arr)\n        >>> out = m(x)\n        >>> out\n        tensor([0.2000, 0.3000, 3.0000, 4.0000], dtype=oneflow.float32)\n    \"\"\"\n\n    def __init__(self, negative_slope: float = 0.01, inplace: bool = False):\n        super().__init__()\n        self.negative_slope = negative_slope\n        self.inplace = inplace\n\n    def forward(self, x):\n        return flow._C.leaky_relu(x, alpha=self.negative_slope, inplace=self.inplace)\n\n    def extra_repr(self):\n        param_str = f\"negative_slope={self.negative_slope}\"\n        param_str += \", inplace=True\" if self.inplace else \"\"\n        return param_str\n\n\nclass RReLU(Module):\n    \"\"\"Applies the randomized leaky rectified linear unit function, element-wise:\n\n    .. math::\n        \\\\text{RReLU}(x) = \\\\begin{cases}\n            x, & \\\\text{ if } x \\\\geq 0 \\\\\\\\\n            a \\\\times x, & \\\\text{ otherwise }\n        \\\\end{cases}\n\n    where :math:`a` is randomly sampled from the uniform distribution\n    :math:`\\\\mathcal{U}(\\\\text{lower}, \\\\text{upper})`.\n\n    .. note::\n        See `Empirical Evaluation of Rectified Activations in Convolutional Network <https://arxiv.org/pdf/1505.00853.pdf>`_\n\n    Args:\n        lower: lower bound of the uniform distribution. Default: :math:`\\\\frac{1}{8}`\n        upper: upper bound of the uniform distribution. Default: :math:`\\\\frac{1}{3}`\n        inplace: can optionally do the operation in-place. Default: ``False``\n\n    Shape:\n        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.\n        - Output: :math:`(*)`, same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        \n        >>> m = flow.nn.RReLU(0.1, 0.3)\n        >>> arr = np.array([0.2, -0.3, -3.0, 4.0, 0.5, -2.2])\n        >>> x = flow.Tensor(arr)\n        >>> out = m(x)\n        >>> print(out) # doctest: +SKIP\n        tensor([ 0.2000, -0.0824, -0.5418,  4.0000,  0.5000, -0.4213], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(\n        self, lower: float = 1.0 / 8, upper: float = 1.0 / 3, inplace: bool = False\n    ):\n        super().__init__()\n        self.lower = lower\n        self.upper = upper\n        self.inplace = inplace\n\n    def forward(self, x):\n        return flow._C.rrelu(x, self.lower, self.upper, self.training, self.inplace)\n\n    def extra_repr(self):\n        param_str = f\"lower={self.lower}\"\n        param_str += f\", upper={self.upper}\"\n        param_str += \", inplace=True\" if self.inplace else \"\"\n        return param_str\n\n\nclass Mish(Module):\n    \"\"\"Applies the element-wise function:\n\n    .. 
math::\n        \\\\text{Mish}(x) = x * \\\\text{Tanh}(\\\\text{Softplus}(x))\n\n    .. note::\n        See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> x = np.array([1, 2, 3]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> mish = flow.nn.Mish()\n\n        >>> out = mish(input)\n        >>> out\n        tensor([0.8651, 1.9440, 2.9865], dtype=oneflow.float32)\n    \"\"\"\n\n    def __init__(self, inplace: bool = False):\n        super().__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        return flow._C.mish(x)\n\n\nclass SiLU(Module):\n    r\"\"\"SiLU(Swish) activation:\n\n    .. math::\n\n        \\text{SiLU}(x) = x * \\text{sigmoid}(x)\n\n    .. note::\n        See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_\n        where the SiLU (Sigmoid Linear Unit) was originally coined, and see\n        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation\n        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:\n        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_\n        where the SiLU was experimented with later.\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x = np.array([1, 2, 3]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> silu = flow.nn.SiLU()\n        >>> out = silu(input)\n        >>> out\n        tensor([0.7311, 1.7616, 2.8577], dtype=oneflow.float32)\n    \"\"\"\n\n    def __init__(self, inplace: bool = False):\n        super().__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        return flow._C.silu(x)\n\n\nclass SELU(Module):\n    r\"\"\"Applies the element-wise function:\n\n    .. math::\n\n        \\text{SELU}(x) = \\text{scale} * (\\max(0,x) + \\min(0, \\alpha * (\\exp(x) - 1)))\n\n    with :math:`\\alpha = 1.6732632423543772848170429916717` and\n\n    :math:`\\text{scale} = 1.0507009873554804934193349852946`.\n\n    .. warning::\n\n        When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation,\n        ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'``\n        in order to get `Self-Normalizing Neural Networks`_.\n        See :func:`oneflow.nn.init.calculate_gain` for more information.\n\n    More details can be found in the paper `Self-Normalizing Neural Networks <https://arxiv.org/abs/1706.02515>`_.\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> x = np.array([1, 2, 3]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> selu = flow.nn.SELU()\n        >>> out = selu(input)\n        >>> out\n        tensor([1.0507, 2.1014, 3.1521], dtype=oneflow.float32)\n    \"\"\"\n\n    def __init__(self, inplace: bool = False):\n        super().__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        return flow._C.selu(x)\n\n\nclass Softshrink(Module):\n    r\"\"\"\n    The Softshrink activation.\n\n    The formula is:\n\n    .. math::\n\n        \\text{Softshrink}(x) =\n        \\begin{cases}\n        x - \\lambda, & \\text{ if } x > \\lambda \\\\\n        x + \\lambda, & \\text{ if } x < -\\lambda \\\\\n        0, & \\text{ otherwise }\n        \\end{cases}\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Softshrink.html.\n\n    Args:\n        lambd: the :math:`\\lambda` value for the Softshrink formulation. Default: 0.5\n        inplace: can optionally do the operation in-place. Default: ``False``\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> x = np.array([-1, 0, 0.2, 0.5]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> softshrink = flow.nn.Softshrink(lambd=0.5)\n        >>> out = softshrink(input)\n        >>> out\n        tensor([-0.5000,  0.0000,  0.0000,  0.0000], dtype=oneflow.float32)\n    \"\"\"\n\n    def __init__(self, lambd: float = 0.5, inplace: bool = False):\n        super().__init__()\n        self.lambd = lambd\n        self.inplace = inplace\n\n    def forward(self, x):\n        return flow._C.softshrink(x, alpha=self.lambd, inplace=self.inplace)\n\n    def extra_repr(self) -> str:\n        param_str = f\"lambd={self.lambd}\"\n        param_str += \", inplace=True\" if self.inplace else \"\"\n        return param_str\n\n\nclass Softsign(Module):\n    r\"\"\"The SoftSign activation.\n\n    The formula is:\n\n    .. math::\n\n        \\text{SoftSign}(x) = \\frac{x}{1 + |x|}\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> x = np.array([1, 2, 3]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> softsign = flow.nn.Softsign()\n        >>> out = softsign(input)\n        >>> out\n        tensor([0.5000, 0.6667, 0.7500], dtype=oneflow.float32)\n    \"\"\"\n\n    def __init__(self, inplace: bool = False):\n        super().__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        return flow._C.softsign(x)\n\n\nclass GLU(Module):\n    r\"\"\"The GLU activation.\n\n    Args:\n        dim (int, optional): the dimension on which to split the input. 
Default: -1\n\n    Shape:\n        - Input: :math:`(\\ast_1, N, \\ast_2)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(\\ast_1, M, \\ast_2)` where :math:`M=N/2`\n\n    The formula is:\n\n    .. math::\n\n        GLU(input) = GLU(a, b) = a \\otimes \\text{sigmoid}(b)\n\n    .. note::\n        where the input is split in half along ``dim`` to form ``a`` and ``b``, and :math:`\\otimes` is the element-wise product.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n        >>> m = nn.GLU()\n        >>> x = flow.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=flow.float32)\n        >>> y = m(x)\n        >>> y\n        tensor([[0.9526, 1.9640],\n                [4.9954, 5.9980]], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self, dim: Optional[int] = -1):\n        super().__init__()\n        self.dim = dim\n\n    def forward(self, input):\n        return flow._C.glu(input, self.dim)\n\n\nclass Threshold(Module):\n    r\"\"\"The Threshold activation. Returns ``x`` if ``x`` is greater than ``threshold``, otherwise returns ``value``.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from https://pytorch.org/docs/1.10/generated/torch.nn.Threshold.html.\n\n    The formula is:\n\n    .. math::\n\n        \\text{Threshold}(x) =\n        \\begin{cases}\n        x, & \\text{ if } x > \\text{ threshold } \\\\\n        \\text{value }, & \\text{ otherwise }\n        \\end{cases}\n\n    Args:\n        threshold (float): the value to threshold at\n        value (float): the value to replace with\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n    Returns:\n        oneflow.Tensor: The result tensor\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x = np.array([-1, 0, 0.5, 1]).astype(np.float32)\n        >>> input = flow.Tensor(x)\n        >>> th = flow.nn.Threshold(threshold=0.5, value=0.2)\n        >>> out = th(input)\n        >>> out\n        tensor([0.2000, 0.2000, 0.2000, 1.0000], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self, threshold: float, value: float):\n        super().__init__()\n        self.threshold = threshold\n        self.value = value\n\n    def forward(self, input):\n        return flow._C.threshold(input, threshold=self.threshold, value=self.value)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/affine_grid.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import List\n\nimport oneflow as flow\n\n\ndef affine_grid(theta, size: List[int], align_corners: bool = False):\n    \"\"\"The interface is consistent with PyTorch.    \n    The documentation is referenced from: \n    https://pytorch.org/docs/1.10/generated/torch.nn.functional.affine_grid.html.\n\n    Generates a 2D or 3D flow field (sampling grid), given a batch of\n    affine matrices :attr:`theta`.\n\n    .. note::\n        This function is often used in conjunction with :func:`grid_sample`\n        to build `Spatial Transformer Networks`_ .\n\n    Args:\n        theta (Tensor): input batch of affine matrices with shape\n            (:math:`N, 2, 3`) for 2D or\n            (:math:`N, 3, 4`) for 3D\n        size (oneflow.Size): the target output image size.\n            (:math:`N, C, H, W` for 2D or\n            :math:`N, C, D, H, W` for 3D)\n            Example: oneflow.Size((32, 3, 24, 24))\n        align_corners (bool): if ``True``, consider ``-1`` and ``1``\n            to refer to the centers of the corner pixels rather than the image corners.\n            Refer to :func:`grid_sample` for a more complete description.\n            A grid generated by :func:`affine_grid` should be passed to :func:`grid_sample`\n            with the same setting for this option.\n            Default: ``False``\n\n    Returns:\n        output (Tensor): output Tensor of size (:math:`N, H, W, 2`)\n\n    .. _`Spatial Transformer Networks`:\n        https://arxiv.org/abs/1506.02025\n\n    Examples::\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor(np.arange(1., 7).reshape((1, 2, 3)), dtype=flow.float32)\n        >>> output = flow.nn.functional.affine_grid(input, flow.Size([1, 1, 2, 2]), align_corners=True)\n        >>> output\n        tensor([[[[ 0., -3.],\n                  [ 2.,  5.]],\n        <BLANKLINE>\n                 [[ 4.,  7.],\n                  [ 6., 15.]]]], dtype=oneflow.float32)\n    \"\"\"\n    y = flow._C.affine_grid(theta, size=size, align_corners=align_corners)\n    return y\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/all_reduce.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.nn.modules.module import Module\n\nfrom typing import Sequence\n\n\nclass AllReduce(Module):\n    def __init__(self, parallel_conf_str: str):\n        super().__init__()\n        self._op = (\n            flow.stateful_op(\"eager_ccl_all_reduce\").Input(\"in\").Output(\"out\").Build()\n        )\n        self.parallel_conf = parallel_conf_str\n\n    def forward(self, x):\n        assert x.device.type == \"cuda\"\n        assert x.device.index == flow.env.get_local_rank()\n        return flow._C.dispatch_eager_ccl_all_reduce(self._op, parallel_conf)\n"
  },
  {
    "path": "python/oneflow/nn/modules/arange.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import List, Union\nimport oneflow as flow\n\n\ndef arange_op(\n    start: Union[int, flow.Tensor] = None,\n    end: Union[int, flow.Tensor] = None,\n    step: int = 1,\n    dtype: flow.dtype = None,\n    device: Union[str, flow.device] = None,\n    placement: flow.placement = None,\n    sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None,\n    requires_grad: bool = False,\n):\n    if start is None:\n        start = 0\n    elif flow.is_tensor(start):\n        # support start as a Scalar Tensor\n        assert len(start.shape) == 0, \"start must be a Scalar\"\n        start = start.item()\n\n    if end is None:\n        end = start\n        start = 0\n    elif flow.is_tensor(end):\n        # support end as a Scalar Tensor\n        assert len(end.shape) == 0, \"end must be a Scalar\"\n        end = end.item()\n\n    if placement is None:\n        if isinstance(device, str):\n            device = flow.device(device)\n        res = flow._C.arange(start, end, step, dtype=dtype, device=device)\n    else:\n        assert isinstance(\n            placement, flow._oneflow_internal.placement\n        ), \"placement should be oneflow._oneflow_internal.placement type.\"\n        assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), \"sbp: %s\" % sbp\n        if isinstance(sbp, flow.sbp.sbp):\n            sbp = (sbp,)\n        else:\n            for elem in sbp:\n                assert isinstance(elem, flow.sbp.sbp), \"sbp: %s\" % sbp\n        assert len(sbp) == len(placement.ranks.shape)\n        res = flow._C.global_arange(\n            start, end, step, dtype=dtype, placement=placement, sbp=sbp\n        )\n\n    res.requires_grad = requires_grad\n    return res\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/argsort.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.framework.tensor import register_tensor_op\nfrom oneflow.nn.modules.module import Module\nfrom oneflow.ops.transpose_util import (\n    get_inversed_perm,\n    get_perm_when_transpose_axis_to_last_dim,\n)\n\n\ndef argsort_op(input, dim: int = -1, descending: bool = False):\n    num_dims = len(input.shape)\n    dim = dim if dim >= 0 else dim + num_dims\n    direction = \"DESCENDING\" if descending else \"ASCENDING\"\n    assert 0 <= dim < num_dims, \"dim out of range\"\n    if dim == num_dims - 1:\n        return flow._C.arg_sort(input, direction)\n    else:\n        perm = get_perm_when_transpose_axis_to_last_dim(num_dims, dim)\n        x = flow._C.transpose(input, perm=perm)\n        x = flow._C.arg_sort(x, direction)\n        return flow._C.transpose(x, perm=get_inversed_perm(perm))\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/argwhere.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Optional\n\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import register_tensor_op\nfrom oneflow.nn.modules.module import Module\n\n\ndef argwhere_op(input, dtype: Optional[flow.dtype] = flow.int32):\n    \"\"\"This operator finds the indices of input Tensor `input` elements that are non-zero. \n\n    It returns a list in which each element is a coordinate that points to a non-zero element in the condition.\n\n    Args:\n        input (oneflow.Tensor): the input Tensor.\n        dtype (Optional[flow.dtype], optional): The data type of output. Defaults to None.\n\n    Returns:\n        oneflow.Tensor: The result Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> x = np.array([[0, 1, 0],\n        ...            [2, 0, 2]]).astype(np.float32)\n\n        >>> input = flow.Tensor(x)\n        >>> output = flow.argwhere(input)\n        >>> output\n        tensor([[0, 1],\n                [1, 0],\n                [1, 2]], dtype=oneflow.int32)\n\n    \"\"\"\n\n    if input.is_lazy:\n        raise ValueError(\"A lazy tensor can not be applied to argwhere.\")\n\n    (res, size) = flow._C.argwhere(input, dtype=dtype)\n    slice_tup_list = [(0, size.numpy().item(), 1)]\n    return flow.slice(res, slice_tup_list=slice_tup_list)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/as_tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport numpy as np\nimport oneflow as flow\n\n\ndef as_tensor(data, dtype=None, device=None):\n    if flow.is_tensor(data):\n        if dtype is None:\n            dtype = data.dtype\n        if device is None:\n            device = data.device\n        if data.dtype is dtype and data.device is device:\n            return data\n        else:\n            data = data.to(dtype=dtype, device=device)\n    elif isinstance(data, (np.ndarray)):\n        if dtype is None:\n            if (device is None) or (device.type == \"cpu\"):\n                data = flow.from_numpy(data)\n            else:\n                data = flow.tensor(data, device=device)\n        else:\n            data_infer_flow_type = flow.framework.dtype.convert_numpy_dtype_to_oneflow_dtype(\n                data.dtype\n            )\n            if data_infer_flow_type is dtype:\n                if (device is None) or (device.type == \"cpu\"):\n                    data = flow.from_numpy(data)\n                else:\n                    data = flow.tensor(data, dtype=dtype, device=device)\n            else:\n                if (device is None) or (device.type == \"cpu\"):\n                    data = flow.tensor(data, dtype=dtype)\n                else:\n                    data = flow.tensor(data, dtype=dtype, device=device)\n    else:\n        # not shared memory in this case\n        data = flow.tensor(data)\n        if device is not None:\n            data = data.to(device)\n        if dtype is not None:\n            data = data.to(dtype)\n    return data\n"
  },
  {
    "path": "python/oneflow/nn/modules/batchnorm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Union\nimport os\n\nimport oneflow as flow\nfrom oneflow.nn.modules.module import Module\nfrom oneflow.autograd import Function\n\n\nclass _NormBase(Module):\n    \"\"\"Common base of _InstanceNorm and _BatchNorm\"\"\"\n\n    def __init__(\n        self,\n        num_features: int,\n        eps: float = 1e-05,\n        momentum: float = 0.1,\n        affine: bool = True,\n        track_running_stats: bool = True,\n    ) -> None:\n        super().__init__()\n        self.num_features = num_features\n        self.eps = eps\n        self.momentum = momentum\n        self.affine = affine\n        self.track_running_stats = track_running_stats\n        if self.affine:\n            self.weight = flow.nn.Parameter(flow.Tensor(num_features))\n            self.bias = flow.nn.Parameter(flow.Tensor(num_features))\n        else:\n            self.register_parameter(\"weight\", None)\n            self.register_parameter(\"bias\", None)\n        if self.track_running_stats:\n            self.register_buffer(\"running_mean\", flow.zeros(num_features))\n            self.register_buffer(\"running_var\", flow.ones(num_features))\n            self.register_buffer(\"num_batches_tracked\", flow.tensor(0, dtype=flow.long))\n        else:\n            self.register_buffer(\"running_mean\", None)\n            self.register_buffer(\"running_var\", None)\n            self.register_buffer(\"num_batches_tracked\", None)\n\n        self.reset_parameters()\n\n    def reset_running_stats(self) -> None:\n        if self.track_running_stats:\n            self.running_mean.zero_()\n            self.running_var.fill_(1)\n            self.num_batches_tracked.zero_()\n\n    def reset_parameters(self) -> None:\n        self.reset_running_stats()\n        if self.affine:\n            flow.nn.init.ones_(self.weight)\n            flow.nn.init.zeros_(self.bias)\n\n    def _check_input_dim(self, input):\n        raise NotImplementedError\n\n    def _load_from_state_dict(\n        self,\n        state_dict,\n        prefix,\n        local_metadata,\n        strict,\n        missing_keys,\n        unexpected_keys,\n        error_msgs,\n    ):\n        if self.track_running_stats:\n            num_batches_tracked_key = prefix + \"num_batches_tracked\"\n            if not num_batches_tracked_key in state_dict:\n                if self.running_mean.is_global:\n                    sbp = self.running_mean.sbp\n                    placement = self.running_mean.placement\n                    state_dict[num_batches_tracked_key] = flow.tensor(\n                        0, dtype=flow.long\n                    ).to_global(sbp=sbp, placement=placement)\n                else:\n                    state_dict[num_batches_tracked_key] = flow.tensor(\n                        0, dtype=flow.long\n                    )\n        super(_NormBase, self)._load_from_state_dict(\n            state_dict,\n            
prefix,\n            local_metadata,\n            strict,\n            missing_keys,\n            unexpected_keys,\n            error_msgs,\n        )\n\n    def extra_repr(self):\n        return \"{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}\".format(\n            **self.__dict__\n        )\n\n\nclass _BatchNorm(_NormBase):\n    def __init__(\n        self,\n        num_features,\n        eps=1e-05,\n        momentum=0.1,\n        affine=True,\n        track_running_stats=True,\n    ):\n        super().__init__(num_features, eps, momentum, affine, track_running_stats)\n        self.channel_axis = 1\n\n    def forward(self, x):\n        self._check_input_dim(x)\n        exponential_average_factor = self.momentum\n        if self.training and self.track_running_stats:\n            if self.num_batches_tracked is not None:\n                self.num_batches_tracked.add_(1)\n                if self.momentum is None:\n                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)\n        if self.training:\n            is_training = True\n        else:\n            is_training = (self.running_mean is None) and (self.running_var is None)\n        # NOTE(lixiang): If it is training mode, pass running_mean and running_var directly to the functor layer.\n        return flow._C.normalization(\n            x,\n            self.running_mean,\n            self.running_var,\n            self.weight,\n            self.bias,\n            axis=self.channel_axis,\n            epsilon=self.eps,\n            momentum=exponential_average_factor,\n            is_training=is_training,\n        )\n\n\nclass BatchNorm1d(_BatchNorm):\n    \"\"\"Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D\n    inputs with optional additional channel dimension) as described in the paper\n    `Batch Normalization: Accelerating Deep Network Training by Reducing\n    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .\n\n    .. math::\n\n        y = \\\\frac{x - \\\\mathrm{E}[x]}{\\\\sqrt{\\\\mathrm{Var}[x] + \\\\epsilon}} * \\\\gamma + \\\\beta\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and :math:`\\\\gamma` and :math:`\\\\beta` are learnable parameter vectors\n    of size `C` (where `C` is the input size). By default, the elements of :math:`\\\\gamma` are set\n    to 1 and the elements of :math:`\\\\beta` are set to 0. The standard-deviation is calculated\n    via the biased estimator, equivalent to `oneflow.var(input, unbiased=False)`.\n\n    Also by default, during training this layer keeps running estimates of its\n    computed mean and variance, which are then used for normalization during\n    evaluation. The running estimates are kept with a default :attr:`momentum`\n    of 0.1.\n\n    If :attr:`track_running_stats` is set to ``False``, this layer then does not\n    keep running estimates, and batch statistics are instead used during\n    evaluation time as well.\n\n    .. note::\n        This :attr:`momentum` argument is different from one used in optimizer\n        classes and the conventional notion of momentum. 
Mathematically, the\n        update rule for running statistics here is\n        :math:`\\\\hat{x}_\\\\text{new} = (1 - \\\\text{momentum}) \\\\times \\\\hat{x} + \\\\text{momentum} \\\\times x_t`,\n        where :math:`\\\\hat{x}` is the estimated statistic and :math:`x_t` is the\n        new observed value.\n\n    Because the Batch Normalization is done over the `C` dimension, computing statistics\n    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.\n\n    Args:\n        num_features: :math:`C` from an expected input of size\n            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Can be set to ``None`` for cumulative moving average\n            (i.e. simple average). Default: 0.1\n        affine: a boolean value that when set to ``True``, this module has\n            learnable affine parameters. Default: ``True``\n        track_running_stats: a boolean value that when set to ``True``, this\n            module tracks the running mean and variance, and when set to ``False``,\n            this module does not track such statistics, and initializes statistics\n            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.\n            When these buffers are ``None``, this module always uses batch statistics\n            in both training and eval modes. Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C)` or :math:`(N, C, L)`\n        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        \n        >>> x = flow.Tensor(np.random.randn(20, 100))\n        >>> m = flow.nn.BatchNorm1d(100)\n        >>> y = m(x)\n\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.ndim != 2 and input.ndim != 3:\n            raise ValueError(\n                \"expected 2D or 3D input (got {}D input)\".format(input.ndim)\n            )\n\n\nclass BatchNorm2d(_BatchNorm):\n    \"\"\"Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs\n    with additional channel dimension) as described in the paper\n    `Batch Normalization: Accelerating Deep Network Training by Reducing\n    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .\n\n    .. math::\n\n        y = \\\\frac{x - \\\\mathrm{E}[x]}{\\\\sqrt{\\\\mathrm{Var}[x] + \\\\epsilon}} * \\\\gamma + \\\\beta\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and :math:`\\\\gamma` and :math:`\\\\beta` are learnable parameter vectors\n    of size `C` (where `C` is the input size). By default, the elements of :math:`\\\\gamma` are set\n    to 1 and the elements of :math:`\\\\beta` are set to 0. The standard-deviation is calculated\n    via the biased estimator, equivalent to `oneflow.var(input, unbiased=False)`.\n\n    Also by default, during training this layer keeps running estimates of its\n    computed mean and variance, which are then used for normalization during\n    evaluation. The running estimates are kept with a default :attr:`momentum`\n    of 0.1.\n\n    If :attr:`track_running_stats` is set to ``False``, this layer then does not\n    keep running estimates, and batch statistics are instead used during\n    evaluation time as well.\n\n    .. 
note::\n        This :attr:`momentum` argument is different from one used in optimizer\n        classes and the conventional notion of momentum. Mathematically, the\n        update rule for running statistics here is\n        :math:`\\\\hat{x}_\\\\text{new} = (1 - \\\\text{momentum}) \\\\times \\\\hat{x} + \\\\text{momentum} \\\\times x_t`,\n        where :math:`\\\\hat{x}` is the estimated statistic and :math:`x_t` is the\n        new observed value.\n\n    Because the Batch Normalization is done over the `C` dimension, computing statistics\n    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.\n\n    Args:\n        num_features: :math:`C` from an expected input of size\n            :math:`(N, C, H, W)`\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Can be set to ``None`` for cumulative moving average\n            (i.e. simple average). Default: 0.1\n        affine: a boolean value that when set to ``True``, this module has\n            learnable affine parameters. Default: ``True``\n        track_running_stats: a boolean value that when set to ``True``, this\n            module tracks the running mean and variance, and when set to ``False``,\n            this module does not track such statistics, and initializes statistics\n            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.\n            When these buffers are ``None``, this module always uses batch statistics\n            in both training and eval modes. Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C, H, W)`\n        - Output: :math:`(N, C, H, W)` (same shape as input)\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        \n        >>> x = flow.Tensor(np.random.randn(4, 2, 8, 3))\n        >>> m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1)\n        >>> y = m(x)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        num_features,\n        eps=1e-05,\n        momentum=0.1,\n        affine=True,\n        track_running_stats=True,\n    ):\n        super().__init__(num_features, eps, momentum, affine, track_running_stats)\n        if os.getenv(\"ONEFLOW_ENABLE_NHWC\") == \"1\":\n            self.channel_axis = 3\n\n    def to_memory_format(self, memory_format) -> None:\n        if memory_format is flow.channels_last:\n            self.channel_axis = 3\n        elif memory_format is flow.contiguous_format:\n            self.channel_axis = 1\n\n    def _check_input_dim(self, input):\n        if input.ndim != 4:\n            raise ValueError(\"expected 4D input (got {}D input)\".format(input.ndim))\n\n\nclass BatchNorm3d(_BatchNorm):\n    r\"\"\"Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs\n    with additional channel dimension) as described in the paper\n    `Batch Normalization: Accelerating Deep Network Training by Reducing\n    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .\n\n    .. math::\n\n        y = \\frac{x - \\mathrm{E}[x]}{\\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and :math:`\\gamma` and :math:`\\beta` are learnable parameter vectors\n    of size `C` (where `C` is the input size). 
By default, the elements of :math:`\\gamma` are set\n    to 1 and the elements of :math:`\\beta` are set to 0. The standard-deviation is calculated\n    via the biased estimator, equivalent to `oneflow.var(input, unbiased=False)`.\n\n    Also by default, during training this layer keeps running estimates of its\n    computed mean and variance, which are then used for normalization during\n    evaluation. The running estimates are kept with a default :attr:`momentum`\n    of 0.1.\n\n    If :attr:`track_running_stats` is set to ``False``, this layer then does not\n    keep running estimates, and batch statistics are instead used during\n    evaluation time as well.\n\n    .. note::\n        This :attr:`momentum` argument is different from one used in optimizer\n        classes and the conventional notion of momentum. Mathematically, the\n        update rule for running statistics here is\n        :math:`\\hat{x}_\\text{new} = (1 - \\text{momentum}) \\times \\hat{x} + \\text{momentum} \\times x_t`,\n        where :math:`\\hat{x}` is the estimated statistic and :math:`x_t` is the\n        new observed value.\n\n    Because the Batch Normalization is done over the `C` dimension, computing statistics\n    on `(N, D, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.\n\n    Args:\n        num_features: :math:`C` from an expected input of size\n            :math:`(N, C, D, H, W)`\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Can be set to ``None`` for cumulative moving average\n            (i.e. simple average). Default: 0.1\n        affine: a boolean value that when set to ``True``, this module has\n            learnable affine parameters. Default: ``True``\n        track_running_stats: a boolean value that when set to ``True``, this\n            module tracks the running mean and variance, and when set to ``False``,\n            this module does not track such statistics, and initializes statistics\n            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.\n            When these buffers are ``None``, this module always uses batch statistics\n            in both training and eval modes. Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C, D, H, W)`\n        - Output: :math:`(N, C, D, H, W)` (same shape as input)\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> x = flow.Tensor(np.random.randn(3, 2, 5, 8, 4))\n        >>> m = flow.nn.BatchNorm3d(num_features=2, eps=1e-5, momentum=0.1)\n        >>> y = m(x)\n        >>> y.size()\n        oneflow.Size([3, 2, 5, 8, 4])\n\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.ndim != 5:\n            raise ValueError(\"expected 5D input (got {}D input)\".format(input.ndim))\n\n\nglobal_eps = 0.1\nglobal_momentum = 0.1\nglobal_world_size = 1\nglobal_axis = 1\n\n\nclass SyncBatchNormFunction(flow.autograd.Function):\n    @staticmethod\n    def forward(self, input, weight, bias, running_mean, running_var):\n        assert input.is_local, \"SyncBatchNorm does not support global tensor as input.\"\n\n        if not input.is_contiguous():\n            input = input.contiguous()\n        if weight is not None:\n            weight = weight.contiguous()\n\n        size = int(input.numel() // input.size(1))\n        if size == 1 and global_world_size < 2:\n            raise ValueError(\n                \"Expected more than 1 value per channel when training, got input size {}\".format(\n                    size\n                )\n            )\n\n        num_channels = input.shape[global_axis]\n        if input.numel() > 0:\n            # calculate mean/invstd for input.\n            mean, invstd = flow._C.batch_norm_stats(input, global_axis, global_eps)\n\n            count = flow.full(\n                (1,),\n                input.numel() // input.size(global_axis),\n                dtype=mean.dtype,\n                device=mean.device,\n            )\n\n            # C, C, 1 -> (2C + 1)\n            combined = flow.cat([mean, invstd, count], dim=0)\n        else:\n            # for empty input, set stats and the count to zero. 
The stats with\n            # zero count will be filtered out later when computing global mean\n            # & invstd, but they still need to participate in the all_gather\n            # collective communication to unblock other peer processes.\n            combined = flow.zeros(\n                2 * num_channels + 1, dtype=input.dtype, device=input.device\n            )\n\n        # Use allgather instead of allreduce because count could be different across\n        # ranks; a simple all-reduce op cannot give correct results.\n        # batch_norm_gather_stats_with_counts calculates global mean & invstd based on\n        # all gathered mean, invstd and count.\n        # world_size * (2C + 1)\n        combined_size = combined.numel()\n        combined_flat = flow.empty(\n            global_world_size,\n            combined_size,\n            dtype=combined.dtype,\n            device=combined.device,\n        )\n        flow.comm.all_gather_into_tensor(combined_flat, combined)\n        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1\n        mean_all, invstd_all, count_all = flow.split(combined_flat, num_channels, dim=1)\n\n        # remove stats from empty inputs\n        mask = count_all.squeeze(-1) >= 1\n        count_all = count_all[mask]\n        mean_all = mean_all[mask]\n        invstd_all = invstd_all[mask]\n\n        # calculate global mean & invstd\n        mean, invstd = flow._C.batch_norm_gather_stats_with_counts(\n            input,\n            mean_all,\n            invstd_all,\n            running_mean,\n            running_var,\n            global_momentum,\n            global_eps,\n            count_all.view(-1),\n        )\n\n        self.save_for_backward(input, weight, mean, invstd, count_all.to(flow.int32))\n\n        # apply element-wise normalization\n        if input.numel() > 0:\n            return flow._C.batch_norm_elemt(\n                input, weight, bias, mean, invstd, global_axis, global_eps\n            )\n        else:\n            return flow.zeros(*(input.shape), dtype=input.dtype, device=input.device)\n\n    @staticmethod\n    def backward(self, grad_output):\n        if not grad_output.is_contiguous():\n            grad_output = grad_output.contiguous()\n        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors\n        grad_input = grad_weight = grad_bias = None\n\n        channel_axis = 1\n        if os.getenv(\"ONEFLOW_ENABLE_NHWC\") == \"1\":\n            if saved_input.dim() == 3:\n                channel_axis = 2\n            elif saved_input.dim() == 4:\n                channel_axis = 3\n            elif saved_input.dim() == 5:\n                channel_axis = 4\n\n        # calculate local stats as well as grad_weight / grad_bias\n        sum_dy, sum_dy_xmu, grad_weight, grad_bias = flow._C.batch_norm_backward_reduce(\n            grad_output, saved_input, mean, invstd, channel_axis\n        )\n\n        # synchronizing stats used to calculate input gradient.\n        num_channels = sum_dy.shape[0]\n        combined = flow.cat([sum_dy, sum_dy_xmu], dim=0)\n        flow.comm.all_reduce(combined)\n        sum_dy, sum_dy_xmu = flow.split(combined, num_channels)\n\n        # backward pass for gradient calculation\n        grad_input = flow._C.batch_norm_backward_elemt(\n            grad_output,\n            saved_input,\n            mean,\n            invstd,\n            weight,\n            sum_dy,\n            sum_dy_xmu,\n            count_tensor,\n            channel_axis,\n        )\n\n        # 
synchronizing of grad_weight / grad_bias is not needed as distributed\n        # training would handle all reduce.\n        return grad_input, grad_weight, grad_bias, None, None\n\n\nclass SyncBatchNorm(_BatchNorm):\n    r\"\"\"Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs\n    with additional channel dimension) as described in the paper\n    `Batch Normalization: Accelerating Deep Network Training by Reducing\n    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .\n\n    .. math::\n\n        y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta\n\n    The mean and standard-deviation are calculated per-dimension over all\n    mini-batches of the same process groups. :math:`\\gamma` and :math:`\\beta`\n    are learnable parameter vectors of size `C` (where `C` is the input size).\n    By default, the elements of :math:`\\gamma` are sampled from\n    :math:`\\mathcal{U}(0, 1)` and the elements of :math:`\\beta` are set to 0.\n    The standard-deviation is calculated via the biased estimator, equivalent to\n    `oneflow.var(input, unbiased=False)`.\n\n    Also by default, during training this layer keeps running estimates of its\n    computed mean and variance, which are then used for normalization during\n    evaluation. The running estimates are kept with a default :attr:`momentum`\n    of 0.1.\n\n    If :attr:`track_running_stats` is set to ``False``, this layer then does not\n    keep running estimates, and batch statistics are instead used during\n    evaluation time as well.\n\n    .. note::\n        This :attr:`momentum` argument is different from one used in optimizer\n        classes and the conventional notion of momentum. Mathematically, the\n        update rule for running statistics here is\n        :math:`\\hat{x}_\\text{new} = (1 - \\text{momentum}) \\times \\hat{x} + \\text{momentum} \\times x_t`,\n        where :math:`\\hat{x}` is the estimated statistic and :math:`x_t` is the\n        new observed value.\n\n    Because the Batch Normalization is done for each channel in the ``C`` dimension, computing\n    statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch\n    Normalization or Spatio-temporal Batch Normalization.\n\n    Currently :class:`SyncBatchNorm` only supports\n    :class:`~oneflow.nn.DistributedDataParallel` (DDP) with single GPU per process. Use\n    :meth:`oneflow.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert\n    :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping\n    Network with DDP.\n\n    Args:\n        num_features: :math:`C` from an expected input of size\n            :math:`(N, C, +)`\n        eps: a value added to the denominator for numerical stability.\n            Default: ``1e-5``\n        momentum: the value used for the running_mean and running_var\n            computation. Can be set to ``None`` for cumulative moving average\n            (i.e. simple average). Default: 0.1\n        affine: a boolean value that when set to ``True``, this module has\n            learnable affine parameters. 
Default: ``True``\n        track_running_stats: a boolean value that when set to ``True``, this\n            module tracks the running mean and variance, and when set to ``False``,\n            this module does not track such statistics, and initializes statistics\n            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.\n            When these buffers are ``None``, this module always uses batch statistics\n            in both training and eval modes. Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C, +)`\n        - Output: :math:`(N, C, +)` (same shape as input)\n\n    .. note::\n        Synchronization of batchnorm statistics occurs only while training, i.e.\n        synchronization is disabled when ``model.eval()`` is set or if\n        ``self.training`` is otherwise ``False``.\n\n    Examples::\n\n        >>> import oneflow as flow\n        \n        >>> bn = flow.nn.BatchNorm2d(100)\n        >>> sync_bn = flow.nn.SyncBatchNorm.convert_sync_batchnorm(bn).cuda()\n        >>> input = flow.randn(20, 100, 35, 45, device=\"cuda\")\n        >>> output = sync_bn(input)\n    \"\"\"\n\n    def __init__(\n        self,\n        num_features: int,\n        eps: float = 1e-5,\n        momentum: float = 0.1,\n        affine: bool = True,\n        track_running_stats: bool = True,\n    ) -> None:\n        super().__init__(num_features, eps, momentum, affine, track_running_stats)\n\n    def _check_input_dim(self, input):\n        if input.dim() < 2:\n            raise ValueError(\n                \"expected at least 2D input (got {}D input)\".format(input.dim())\n            )\n        if os.getenv(\"ONEFLOW_ENABLE_NHWC\") == \"1\":\n            if input.dim() == 3:\n                self.channel_axis = 2\n            elif input.dim() == 4:\n                self.channel_axis = 3\n            elif input.dim() == 5:\n                self.channel_axis = 4\n\n    def _check_non_zero_input_channels(self, input):\n        if input.size(1) == 0:\n            raise ValueError(\n                \"SyncBatchNorm number of input channels should be non-zero\"\n            )\n\n    def forward(self, input):\n        # currently only GPU input is supported\n        if not input.is_cuda:\n            raise ValueError(\"SyncBatchNorm expected input tensor to be on GPU\")\n\n        self._check_input_dim(input)\n        self._check_non_zero_input_channels(input)\n\n        if self.momentum is None:\n            exponential_average_factor = 0.0\n        else:\n            exponential_average_factor = self.momentum\n\n        if self.training and self.track_running_stats:\n            assert self.num_batches_tracked is not None\n            self.num_batches_tracked.add_(1)\n            if self.momentum is None:  # use cumulative moving average\n                exponential_average_factor = 1.0 / self.num_batches_tracked.item()\n            else:  # use exponential moving average\n                exponential_average_factor = self.momentum\n\n        r\"\"\"\n        Decide whether the mini-batch stats should be used for normalization rather than the buffers.\n        Mini-batch stats are used in training mode, and in eval mode when buffers are None.\n        \"\"\"\n        if self.training:\n            bn_training = True\n        else:\n            bn_training = (self.running_mean is None) and (self.running_var is None)\n\n        # Don't sync batchnorm stats in inference mode (model.eval()).\n        need_sync = bn_training and self.training\n        if need_sync:\n            
need_sync = flow.env.get_world_size() > 1\n\n        # fallback to framework BN when synchronization is not necessary\n        if not need_sync:\n            return flow._C.normalization(\n                input,\n                self.running_mean,\n                self.running_var,\n                self.weight,\n                self.bias,\n                axis=self.channel_axis,\n                epsilon=self.eps,\n                momentum=exponential_average_factor,\n                is_training=bn_training,\n            )\n        else:\n            assert bn_training\n            global global_eps\n            global global_momentum\n            global global_world_size\n            global global_axis\n            global_eps = self.eps\n            global_momentum = exponential_average_factor\n            global_world_size = flow.env.get_world_size()\n            global_axis = self.channel_axis\n            assert (\n                self.track_running_stats\n            ), \"`track_running_stats` should be True when using SyncBatchNorm.\"\n            return SyncBatchNormFunction.apply(\n                input, self.weight, self.bias, self.running_mean, self.running_var,\n            )\n\n    @classmethod\n    def convert_sync_batchnorm(cls, module):\n        r\"\"\"Helper function to convert all :attr:`BatchNorm*D` layers in the model to\n        :class:`oneflow.nn.SyncBatchNorm` layers.\n\n        Args:\n            module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers\n\n        Returns:\n            The original :attr:`module` with the converted :class:`oneflow.nn.SyncBatchNorm`\n            layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,\n            a new :class:`oneflow.nn.SyncBatchNorm` layer object will be returned\n            instead.\n\n        Example::\n\n            >>> import oneflow as flow\n\n            >>> module = flow.nn.Sequential(flow.nn.Linear(20, 100), flow.nn.BatchNorm1d(100)).cuda()\n            >>> sync_bn_module = flow.nn.SyncBatchNorm.convert_sync_batchnorm(module)\n\n        \"\"\"\n        module_output = module\n        if isinstance(module, flow.nn.modules.batchnorm._BatchNorm):\n            module_output = flow.nn.SyncBatchNorm(\n                module.num_features,\n                module.eps,\n                module.momentum,\n                module.affine,\n                module.track_running_stats,\n            )\n            if module.affine:\n                with flow.no_grad():\n                    module_output.weight = module.weight\n                    module_output.bias = module.bias\n            module_output.running_mean = module.running_mean\n            module_output.running_var = module.running_var\n            module_output.num_batches_tracked = module.num_batches_tracked\n        for name, child in module.named_children():\n            module_output.add_module(name, cls.convert_sync_batchnorm(child))\n        del module\n        return module_output\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
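A small numeric check of the running-stats update rule quoted in the docstrings above (editor's sketch): starting from `running_mean = 0`, one training step yields `(1 - momentum) * 0 + momentum * batch_mean`.

    import oneflow as flow

    m = flow.nn.BatchNorm1d(3, momentum=0.1)   # running_mean starts at zeros
    x = flow.tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])  # batch mean = [2., 3., 4.]
    m(x)                                       # module is in training mode by default
    print(m.running_mean)                      # ~= 0.1 * [2., 3., 4.] = [0.2, 0.3, 0.4]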
  },
  {
    "path": "python/oneflow/nn/modules/batchnorm_fused.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Union\nimport os\n\nimport oneflow as flow\nfrom oneflow.nn.modules.module import Module\n\n\nclass _FusedNormBase(Module):\n    \"\"\"Common base of _FusedBatchNorm\"\"\"\n\n    def __init__(\n        self,\n        num_features: int,\n        eps: float = 1e-05,\n        momentum: float = 0.1,\n        affine: bool = True,\n        track_running_stats: bool = True,\n    ) -> None:\n        super().__init__()\n        self.num_features = num_features\n        self.eps = eps\n        self.momentum = momentum\n        self.affine = affine\n        self.track_running_stats = track_running_stats\n        if self.affine:\n            self.weight = flow.nn.Parameter(flow.Tensor(num_features))\n            self.bias = flow.nn.Parameter(flow.Tensor(num_features))\n        else:\n            self.register_parameter(\"weight\", None)\n            self.register_parameter(\"bias\", None)\n        if self.track_running_stats:\n            self.register_buffer(\"running_mean\", flow.Tensor(num_features))\n            self.register_buffer(\"running_var\", flow.Tensor(num_features))\n        else:\n            self.register_parameter(\"running_mean\", None)\n            self.register_parameter(\"running_var\", None)\n        self.reset_parameters()\n\n    def reset_running_stats(self) -> None:\n        if self.track_running_stats:\n            self.running_mean.fill_(0)\n            self.running_var.fill_(1)\n\n    def reset_parameters(self) -> None:\n        self.reset_running_stats()\n        if self.affine:\n            flow.nn.init.ones_(self.weight)\n            flow.nn.init.zeros_(self.bias)\n\n    def _check_input_dim(self, input):\n        raise NotImplementedError\n\n    def extra_repr(self):\n        return \"num_features={num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}\".format(\n            **self.__dict__\n        )\n\n\nclass _FusedBatchNorm(_FusedNormBase):\n    def __init__(\n        self,\n        num_features,\n        eps=1e-05,\n        momentum=0.1,\n        affine=True,\n        track_running_stats=True,\n    ):\n        super().__init__(num_features, eps, momentum, affine, track_running_stats)\n        self.channel_axis = 1\n\n    def forward(self, x, addend=None):\n        self._check_input_dim(x)\n\n        if self.training:\n            is_training = True\n        else:\n            is_training = (self.running_mean is None) and (self.running_var is None)\n        return flow._C.normalization_add_relu(\n            x,\n            addend if addend is not None else None,\n            self.running_mean\n            if not self.training or self.track_running_stats\n            else None,\n            self.running_var if not self.training or self.track_running_stats else None,\n            self.weight,\n            self.bias,\n            axis=self.channel_axis,\n            
epsilon=self.eps,\n            momentum=self.momentum,\n            is_training=is_training,\n        )\n\n\nclass FusedBatchNorm1d(_FusedBatchNorm):\n    \"\"\"Applies Fused Batch Normalization over a 2D or 3D input, the formula is: \n    \n    .. math:: \n\n        out = ReLU(BatchNorm(input) + addend)\n\n    The formula of Batch Normalization is: \n\n    .. math::\n\n        y = \\\\frac{x - \\\\mathrm{E}[x]}{\\\\sqrt{\\\\mathrm{Var}[x] + \\\\epsilon}} * \\\\gamma + \\\\beta\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and :math:`\\\\gamma` and :math:`\\\\beta` are learnable parameter vectors\n    of size `C` (where `C` is the input size). By default, the elements of :math:`\\\\gamma` are set\n    to 1 and the elements of :math:`\\\\beta` are set to 0. The standard-deviation is calculated\n    via the biased estimator, equivalent to `torch.var(input, unbiased=False)`.\n\n    Also by default, during training this layer keeps running estimates of its\n    computed mean and variance, which are then used for normalization during\n    evaluation. The running estimates are kept with a default :attr:`momentum`\n    of 0.1.\n\n    If :attr:`track_running_stats` is set to ``False``, this layer then does not\n    keep running estimates, and batch statistics are instead used during\n    evaluation time as well.\n\n    .. note::\n        This :attr:`momentum` argument is different from one used in optimizer\n        classes and the conventional notion of momentum. Mathematically, the\n        update rule for running statistics here is\n        :math:`\\\\hat{x}_\\\\text{new} = (1 - \\\\text{momentum}) \\\\times \\\\hat{x} + \\\\text{momentum} \\\\times x_t`,\n        where :math:`\\\\hat{x}` is the estimated statistic and :math:`x_t` is the\n        new observed value.\n\n    Because the Batch Normalization is done over the `C` dimension, computing statistics\n    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.\n\n    Args:\n        num_features: :math:`C` from an expected input of size\n            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Can be set to ``None`` for cumulative moving average\n            (i.e. simple average). Default: 0.1\n        affine: a boolean value that when set to ``True``, this module has\n            learnable affine parameters. Default: ``True``\n        track_running_stats: a boolean value that when set to ``True``, this\n            module tracks the running mean and variance, and when set to ``False``,\n            this module does not track such statistics, and initializes statistics\n            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.\n            When these buffers are ``None``, this module always uses batch statistics\n            in both training and eval modes. Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C)` or :math:`(N, C, L)`\n        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        \n        >>> x = flow.Tensor(np.random.randn(20, 100)).to(\"cuda\") # FusedBatchNorm is only supported on GPU currently. 
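\n        >>> # addend is optional; when it is None only relu(batch_norm(x)) is applied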
\n        >>> m = flow.nn.FusedBatchNorm1d(num_features=100, eps=1e-5, momentum=0.1).to(\"cuda\")\n        >>> y = m(x, addend=None)\n\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.ndim != 2 and input.ndim != 3:\n            raise ValueError(\n                \"expected 2D or 3D input (got {}D input)\".format(input.ndim)\n            )\n\n\nclass FusedBatchNorm2d(_FusedBatchNorm):\n    \"\"\"Applies Fused Batch Normalization over a 4D input, the formula is: \n    \n    .. math:: \n\n        out = ReLU(BatchNorm(input) + addend)\n\n    The formula of Batch Normalization is: \n\n    .. math::\n\n        y = \\\\frac{x - \\\\mathrm{E}[x]}{\\\\sqrt{\\\\mathrm{Var}[x] + \\\\epsilon}} * \\\\gamma + \\\\beta\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and :math:`\\\\gamma` and :math:`\\\\beta` are learnable parameter vectors\n    of size `C` (where `C` is the input size). By default, the elements of :math:`\\\\gamma` are set\n    to 1 and the elements of :math:`\\\\beta` are set to 0. The standard-deviation is calculated\n    via the biased estimator, equivalent to `torch.var(input, unbiased=False)`.\n\n    Also by default, during training this layer keeps running estimates of its\n    computed mean and variance, which are then used for normalization during\n    evaluation. The running estimates are kept with a default :attr:`momentum`\n    of 0.1.\n\n    If :attr:`track_running_stats` is set to ``False``, this layer then does not\n    keep running estimates, and batch statistics are instead used during\n    evaluation time as well.\n\n    .. note::\n        This :attr:`momentum` argument is different from one used in optimizer\n        classes and the conventional notion of momentum. Mathematically, the\n        update rule for running statistics here is\n        :math:`\\\\hat{x}_\\\\text{new} = (1 - \\\\text{momentum}) \\\\times \\\\hat{x} + \\\\text{momentum} \\\\times x_t`,\n        where :math:`\\\\hat{x}` is the estimated statistic and :math:`x_t` is the\n        new observed value.\n\n    Because the Batch Normalization is done over the `C` dimension, computing statistics\n    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.\n\n    Args:\n        num_features: :math:`C` from an expected input of size\n            :math:`(N, C, H, W)`\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Can be set to ``None`` for cumulative moving average\n            (i.e. simple average). Default: 0.1\n        affine: a boolean value that when set to ``True``, this module has\n            learnable affine parameters. Default: ``True``\n        track_running_stats: a boolean value that when set to ``True``, this\n            module tracks the running mean and variance, and when set to ``False``,\n            this module does not track such statistics, and initializes statistics\n            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.\n            When these buffers are ``None``, this module always uses batch statistics\n            in both training and eval modes. Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C, H, W)`\n        - Output: :math:`(N, C, H, W)` (same shape as input)\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        \n        >>> x = flow.Tensor(np.random.randn(4, 2, 8, 3)).to(\"cuda\") # FusedBatchNorm is only supported on GPU currently. \n        >>> m = flow.nn.FusedBatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(\"cuda\")\n        >>> y = m(x, addend=None)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        num_features,\n        eps=1e-05,\n        momentum=0.1,\n        affine=True,\n        track_running_stats=True,\n    ):\n        super().__init__(num_features, eps, momentum, affine, track_running_stats)\n        if os.getenv(\"ONEFLOW_ENABLE_NHWC\") == \"1\":\n            self.channel_axis = 3\n\n    def _check_input_dim(self, input):\n        if input.ndim != 4:\n            raise ValueError(\"expected 4D input (got {}D input)\".format(input.ndim))\n\n\nclass FusedBatchNorm3d(_FusedBatchNorm):\n    r\"\"\"Applies Fused Batch Normalization over a 5D input, the formula is: \n    \n    .. math:: \n\n        out = ReLU(BatchNorm(input) + addend)\n\n    The formula of Batch Normalization is: \n\n    .. math::\n\n        y = \\\\frac{x - \\\\mathrm{E}[x]}{\\\\sqrt{\\\\mathrm{Var}[x] + \\\\epsilon}} * \\\\gamma + \\\\beta\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and :math:`\\\\gamma` and :math:`\\\\beta` are learnable parameter vectors\n    of size `C` (where `C` is the input size). By default, the elements of :math:`\\\\gamma` are set\n    to 1 and the elements of :math:`\\\\beta` are set to 0. The standard-deviation is calculated\n    via the biased estimator, equivalent to `torch.var(input, unbiased=False)`.\n\n    Also by default, during training this layer keeps running estimates of its\n    computed mean and variance, which are then used for normalization during\n    evaluation. The running estimates are kept with a default :attr:`momentum`\n    of 0.1.\n\n    If :attr:`track_running_stats` is set to ``False``, this layer then does not\n    keep running estimates, and batch statistics are instead used during\n    evaluation time as well.\n\n    .. note::\n        This :attr:`momentum` argument is different from one used in optimizer\n        classes and the conventional notion of momentum. Mathematically, the\n        update rule for running statistics here is\n        :math:`\\\\hat{x}_\\\\text{new} = (1 - \\\\text{momentum}) \\\\times \\\\hat{x} + \\\\text{momentum} \\\\times x_t`,\n        where :math:`\\\\hat{x}` is the estimated statistic and :math:`x_t` is the\n        new observed value.\n\n    Because the Batch Normalization is done over the `C` dimension, computing statistics\n    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization.\n\n    Args:\n        num_features: :math:`C` from an expected input of size\n            :math:`(N, C, D, H, W)`\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Can be set to ``None`` for cumulative moving average\n            (i.e. simple average). Default: 0.1\n        affine: a boolean value that when set to ``True``, this module has\n            learnable affine parameters. 
Default: ``True``\n        track_running_stats: a boolean value that when set to ``True``, this\n            module tracks the running mean and variance, and when set to ``False``,\n            this module does not track such statistics, and initializes statistics\n            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.\n            When these buffers are ``None``, this module always uses batch statistics\n            in both training and eval modes. Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C, D, H, W)`\n        - Output: :math:`(N, C, D, H, W)` (same shape as input)\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> x = flow.Tensor(np.random.randn(3, 2, 5, 8, 4)).to(\"cuda\") # FusedBatchNorm is only supported on GPU currently. \n        >>> m = flow.nn.FusedBatchNorm3d(num_features=2, eps=1e-5, momentum=0.1).to(\"cuda\")\n        >>> y = m(x, addend=None)\n\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.ndim != 5:\n            raise ValueError(\"expected 5D input (got {}D input)\".format(input.ndim))\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/broadcast_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.nn.modules.utils import _single, _handle_size_arg\n\n\ndef broadcast_shapes(*shapes):\n    r\"\"\"broadcast_shapes(*shapes) -> Size\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.broadcast_shapes.html.\n\n    Similar to :func:`oneflow.broadcast_tensors` but for shapes.\n\n    This is equivalent to ``flow.broadcast_tensors(*map(flow.empty, shapes))[0].shape``\n    but avoids the need create to intermediate tensors.\n    This is useful for broadcasting tensors of common batch shape but different rightmost shape,\n    e.g. to broadcast mean vectors with covariance matrices.\n\n    Args:\n        \\*shapes (flow.Size): Shapes of tensors.\n\n    Returns:\n        A shape compatible with all input shapes.\n\n    Raises:\n        RuntimeError: If shapes are incompatible.\n\n    Example::\n\n        >>> import oneflow as flow\n        >>> flow.broadcast_shapes((2,), (3, 1), (1, 1, 1))\n        oneflow.Size([1, 3, 2])\n    \"\"\"\n    shapes = _single(shapes)\n    return flow._C.broadcast_shapes(shapes)\n\n\ndef broadcast_tensors(*tensors):\n    r\"\"\"broadcast_tensors(*tensors) -> List of Tensors\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.broadcast_tensors.html.\n\n    Broadcasts the given tensors according to ``broadcasting-semantics``.\n\n    Args:\n        *tensors: any number of tensors of the same type\n\n    .. warning::\n\n        More than one element of a broadcasted tensor may refer to a single\n        memory location. As a result, in-place operations (especially ones that\n        are vectorized) may result in incorrect behavior. If you need to write\n        to the tensors, please clone them first.\n\n    Example::\n\n        >>> import oneflow as flow\n        >>> x = flow.arange(3).view(1, 3)\n        >>> y = flow.arange(2).view(2, 1)\n        >>> a, b = flow.broadcast_tensors(x, y)\n        >>> a.size()\n        oneflow.Size([2, 3])\n        >>> a\n        tensor([[0, 1, 2],\n                [0, 1, 2]], dtype=oneflow.int64)\n    \"\"\"\n    tensors = _single(tensors)\n    return flow._C.broadcast_tensors(tensors)\n\n\ndef broadcast_to(input, shape):\n    r\"\"\"broadcast_to(input, shape) -> Tensors\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.broadcast_to.html.\n\n    Broadcasts ``input`` to the shape ``shape``. Equivalent to calling ``input.expand(shape)``. 
See :func:`oneflow.expand` for details.\n\n    Args:\n        input (oneflow.Tensor): the input tensor.\n        shape (list, tuple, or oneflow.Size): the new shape.\n\n    Example::\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([1, 2, 3])\n        >>> flow.broadcast_to(x, (3, 3))\n        tensor([[1, 2, 3],\n                [1, 2, 3],\n                [1, 2, 3]], dtype=oneflow.int64)\n    \"\"\"\n    shape = _handle_size_arg(shape)\n    shape = _single(shape)\n    return flow._C.broadcast_to(input, shape)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/constant.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import List, Optional, Union\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import register_tensor_op\nfrom oneflow.nn.common_types import _size_any_t\nfrom oneflow.nn.modules.utils import _single, _handle_size_arg\n\n\nclass _ConstantBase:\n    def __init__(\n        self,\n        size: Union[_size_any_t, flow.Size],\n        value: Union[float, int, complex],\n        dtype: Optional[flow.dtype],\n        device: Union[flow.device, int, str] = None,\n        placement: flow.placement = None,\n        sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None,\n        requires_grad: bool = False,\n    ) -> None:\n        assert size is not None, \"shape must not be None!\"\n        assert isinstance(\n            size, (int, tuple, list, flow.Size)\n        ), \"shape should be int or tuple int!\"\n        self.device = device\n        if isinstance(self.device, int):\n            self.device = flow.device(\"cuda\", self.device)\n        if isinstance(self.device, str):\n            self.device = flow.device(self.device)\n        self.requires_grad = requires_grad\n        size = _single(size)\n        if dtype is None:\n            dtype = flow.get_default_dtype()\n        if placement is None:\n            if device is None:\n                self.device = flow.device(\"cpu\")\n        else:\n            assert device is None\n        self.placement = placement\n        self.sbp = sbp\n        if placement is not None:\n            assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), \"sbp: %s\" % sbp\n            if isinstance(self.sbp, flow.sbp.sbp):\n                self.sbp = (self.sbp,)\n            else:\n                for elem in sbp:\n                    assert isinstance(elem, flow.sbp.sbp), \"sbp: %s\" % sbp\n            assert len(self.sbp) == len(placement.ranks.shape)\n        else:\n            assert sbp is None, \"sbp: %s\" % sbp\n        self.shape = size\n        self.value = value\n        self.dtype = dtype\n\n    def forward(self):\n        if self.placement is not None:\n            if isinstance(self.value, flow.Tensor):\n                assert (\n                    self.value.ndim <= 1 and self.value.numel() == 1\n                ), \"Only tensor with single element or scalar tensor are supported as value!\"\n                res = flow._C.global_tensor_constant(\n                    self.shape,\n                    self.value,\n                    dtype=self.dtype,\n                    placement=self.placement,\n                    sbp=self.sbp,\n                )\n            else:\n                res = flow._C.global_constant(\n                    self.shape,\n                    self.value,\n                    dtype=self.dtype,\n                    placement=self.placement,\n                    sbp=self.sbp,\n                )\n        else:\n            if isinstance(self.value, flow.Tensor):\n       
          assert (\n                    self.value.ndim <= 1 and self.value.numel() == 1\n                ), \"Only a single-element tensor or a scalar tensor is supported as value!\"\n                res = flow._C.tensor_constant(\n                    self.shape, self.value, dtype=self.dtype, device=self.device\n                )\n            else:\n                res = flow._C.constant(\n                    self.shape, self.value, dtype=self.dtype, device=self.device\n                )\n        res.requires_grad = self.requires_grad\n        return res\n\n\ndef _handle_meta_args(\n    input,\n    size: Union[_size_any_t, List[int], flow.Size, None] = None,\n    dtype: Optional[flow.dtype] = None,\n    device: Union[flow.device, str, None] = None,\n    placement: flow.placement = None,\n    sbp: Union[\n        flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp], None\n    ] = None,\n    requires_grad: bool = False,\n):\n    if isinstance(device, str):\n        device = flow.device(device)\n    if size is None:\n        new_size = input.shape\n    else:\n        new_size = _handle_size_arg(size)\n    if dtype is None:\n        new_dtype = input.dtype\n    else:\n        new_dtype = dtype\n    new_device = device\n    new_placement = placement\n    new_sbp = sbp\n    new_requires_grad = requires_grad\n\n    if new_device is not None:\n        assert (\n            new_placement is None\n        ), \"argument 'placement' must be None when argument 'device' exists\"\n        assert (\n            new_sbp is None\n        ), \"argument 'sbp' must be None when argument 'device' exists\"\n    elif new_device is None and new_placement is None and new_sbp is None:\n        new_device = input.device if input.is_local else None\n        new_placement = input.placement if input.is_global else None\n        new_sbp = input.sbp if input.is_global else None\n    else:\n        if new_placement is None and new_sbp is not None:\n            assert (\n                input.is_global\n            ), \"argument 'placement' must not be None when argument 'sbp' exists and the tensor is local\"\n            new_placement = input.placement\n        elif new_placement is not None and new_sbp is None:\n            assert (\n                input.is_global\n            ), \"argument 'sbp' must not be None when argument 'placement' exists and the tensor is local\"\n            new_sbp = input.sbp\n    assert isinstance(\n        new_size, (int, tuple, list, flow.Size)\n    ), \"argument 'size' must be a tuple of ints, not %s\" % (type(new_size))\n    assert isinstance(\n        new_dtype, flow.dtype\n    ), \"argument 'dtype' must be flow.dtype, not %s\" % (type(new_dtype))\n    if new_placement is not None:\n        assert isinstance(\n            new_placement, flow.placement\n        ), \"argument 'placement' must be flow.placement, not %s\" % (\n            type(new_placement)\n        )\n        assert isinstance(\n            new_sbp, (flow.sbp.sbp, tuple)\n        ), \"argument 'sbp' must be flow.sbp.sbp, not %s\" % (type(new_sbp))\n    else:\n        assert isinstance(\n            new_device, (str, flow.device)\n        ), \"argument 'device' must be flow.device, not %s\" % (type(new_device))\n    assert isinstance(\n        new_requires_grad, bool\n    ), \"argument 'requires_grad' must be bool, not %s\" % (type(new_requires_grad))\n\n    return new_size, new_dtype, new_device, new_placement, new_sbp, new_requires_grad\n\n\nclass Ones(_ConstantBase):\n    def __init__(\n        self,\n        
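# the fill value is fixed to 1 in the super().__init__ call below\n        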
size,\n        dtype=None,\n        device=None,\n        placement=None,\n        sbp=None,\n        requires_grad=False,\n    ):\n        super().__init__(size, 1, dtype, device, placement, sbp, requires_grad)\n\n\ndef ones_op(\n    *size: Union[_size_any_t, flow.Size, List[int]],\n    dtype: Optional[flow.dtype] = None,\n    device: Union[flow.device, str, None] = None,\n    placement: flow.placement = None,\n    sbp: Union[\n        flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp], None\n    ] = None,\n    requires_grad: bool = False,\n):\n    \"\"\"\n    Returns a tensor filled with the scalar value 1,\n    with the shape defined by the variable argument `size`.\n\n    Args:\n        size (an integer or tuple of integer values): defining the shape of the output tensor. Can be \\\\\n         a variable number of arguments or a collection like a list or tuple.\n        dtype (flow.dtype, optional): the desired data type of returned tensor.\n        device (flow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor type\n        placement (flow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`.\n        sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> y = flow.ones(5)\n        >>> y\n        tensor([1., 1., 1., 1., 1.], dtype=oneflow.float32)\n        >>> y = flow.ones(2,3) # construct local tensor\n        >>> y\n        tensor([[1., 1., 1.],\n                [1., 1., 1.]], dtype=oneflow.float32)\n        >>> placement = flow.placement(\"cpu\", ranks=[0])\n        >>> y = flow.ones(4, 5, placement=placement, sbp=flow.sbp.broadcast) # construct global tensor\n        >>> y.is_global\n        True\n\n\n    \"\"\"\n    size = _handle_size_arg(size)\n    return Ones(size, dtype, device, placement, sbp, requires_grad).forward()\n\n\ndef ones_like_op(\n    input,\n    dtype: Optional[flow.dtype] = None,\n    device: Union[flow.device, str, None] = None,\n    placement: flow.placement = None,\n    sbp: Union[\n        flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp], None\n    ] = None,\n    requires_grad: bool = False,\n):\n    (\n        new_size,\n        new_dtype,\n        new_device,\n        new_placement,\n        new_sbp,\n        new_requires_grad,\n    ) = _handle_meta_args(input, None, dtype, device, placement, sbp, requires_grad)\n    return Ones(\n        new_size, new_dtype, new_device, new_placement, new_sbp, new_requires_grad\n    ).forward()\n\n\nclass Zeros(_ConstantBase):\n    def __init__(\n        self,\n        size,\n        dtype=None,\n        device=None,\n        placement=None,\n        sbp=None,\n        requires_grad=False,\n    ):\n        super().__init__(size, 0, dtype, device, placement, sbp, requires_grad)\n\n\ndef zeros_op(\n    *size: Union[_size_any_t, flow.Size, List[int]],\n    dtype: Optional[flow.dtype] = None,\n    device: Union[flow.device, str, None] = None,\n    placement: flow.placement = None,\n    sbp: Union[\n        flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp], 
None\n    ] = None,\n    requires_grad: bool = False,\n):\n    \"\"\"\n    Returns a tensor filled with the scalar value 0,\n    with the shape defined by the variable argument `size`.\n\n    Args:\n        size (an integer or tuple of integer values): defining the shape of the output tensor. Can be \\\\\n         a variable number of arguments or a collection like a list or tuple.\n        dtype (flow.dtype, optional): the desired data type of returned tensor.\n        device (flow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor type\n        placement (flow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`.\n        sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> y = flow.zeros(5)\n        >>> y\n        tensor([0., 0., 0., 0., 0.], dtype=oneflow.float32)\n        >>> y = flow.zeros(2,3)\n        >>> y\n        tensor([[0., 0., 0.],\n                [0., 0., 0.]], dtype=oneflow.float32)\n\n    \"\"\"\n    size = _handle_size_arg(size)\n    return Zeros(size, dtype, device, placement, sbp, requires_grad).forward()\n\n\ndef zeros_like_op(\n    input,\n    dtype: Optional[flow.dtype] = None,\n    device: Union[flow.device, str, None] = None,\n    placement: flow.placement = None,\n    sbp: Union[\n        flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp], None\n    ] = None,\n    requires_grad: bool = False,\n):\n    (\n        new_size,\n        new_dtype,\n        new_device,\n        new_placement,\n        new_sbp,\n        new_requires_grad,\n    ) = _handle_meta_args(input, None, dtype, device, placement, sbp, requires_grad)\n    return Zeros(\n        new_size, new_dtype, new_device, new_placement, new_sbp, new_requires_grad\n    ).forward()\n\n\nclass Full(_ConstantBase):\n    def __init__(\n        self,\n        size,\n        value,\n        dtype,\n        device=None,\n        placement=None,\n        sbp=None,\n        requires_grad=False,\n    ):\n        super().__init__(size, value, dtype, device, placement, sbp, requires_grad)\n\n\ndef full_op(\n    size: Union[_size_any_t, flow.Size],\n    fill_value: Union[float, int, complex],\n    dtype: Optional[flow.dtype] = None,\n    device: Union[flow.device, str, None] = None,\n    placement: flow.placement = None,\n    sbp: Union[\n        flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp], None\n    ] = None,\n    requires_grad: bool = False,\n):\n    \"\"\"\n    Creates a tensor of size `size` filled with `fill_value`. \n    The tensor’s dtype is inferred from `fill_value` when `dtype` is not given.\n\n    Args:\n        size(int...): a list, tuple, or oneflow.Size of integers defining the shape of the output tensor.\n        fill_value(Scalar): the value to fill the output tensor with.\n        dtype (oneflow.dtype, optional): the desired data type of returned tensor.\n        device (oneflow.device, optional): the desired device of returned tensor. 
Default: if None, uses the current device for the default tensor type\n        placement (oneflow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`.\n        sbp (oneflow.sbp.sbp or tuple of oneflow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> y = flow.full((5,),5) \n        >>> y\n        tensor([5, 5, 5, 5, 5], dtype=oneflow.int64)\n        >>> y = flow.full((2,3),5.0) # construct local tensor\n        >>> y\n        tensor([[5., 5., 5.],\n                [5., 5., 5.]], dtype=oneflow.float32)\n        >>> placement = flow.placement(\"cpu\", ranks=[0])\n        >>> y = flow.full((2,3), 5.0, placement=placement, sbp=flow.sbp.broadcast)  # construct global tensor\n        >>> y.is_global\n        True\n\n    \"\"\"\n    size = _handle_size_arg(size)\n    if not isinstance(fill_value, (int, float, complex, flow.Tensor)):\n        # handle numpy scalar dtype\n        assert isinstance(\n            fill_value.dtype, np.dtype\n        ), \"fill_value must be a Python scalar or a NumPy scalar.\"\n        fill_value = fill_value.item()\n    if dtype is None:\n        dtype = flow.tensor(fill_value).dtype\n    return Full(\n        size, fill_value, dtype, device, placement, sbp, requires_grad\n    ).forward()\n\n\ndef full_like_op(\n    input,\n    fill_value,\n    dtype: Optional[flow.dtype] = None,\n    device: Union[flow.device, str, None] = None,\n    placement: flow.placement = None,\n    sbp: Union[\n        flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp], None\n    ] = None,\n    requires_grad: bool = False,\n):\n    \"\"\"\n    full_like(input, fill_value, \\*, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor\n    \n    Returns a tensor with the same size as :attr:`input` filled with :attr:`fill_value`.\n    ``oneflow.full_like(input, fill_value)`` is equivalent to\n    ``oneflow.full(input.size(), fill_value, dtype=input.dtype, device=input.device)``.\n\n    The interface is consistent with PyTorch.    \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.full_like.html.\n\n    Args:\n        input(oneflow.Tensor)\n        fill_value(Scalar): the value to fill the output tensor with.\n        dtype (oneflow.dtype, optional): the desired data type of returned tensor.\n        device (oneflow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor type\n        placement (oneflow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`.\n        sbp (oneflow.sbp.sbp or tuple of oneflow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.randn(2, 3)\n        >>> y = flow.full_like(x, 2.0)\n        >>> y\n        tensor([[2., 2., 2.],\n                [2., 2., 2.]], dtype=oneflow.float32)\n        >>> y = flow.full_like(x, 2, dtype=flow.int32)\n        >>> y\n        tensor([[2, 2, 2],\n                [2, 2, 2]], dtype=oneflow.int32)\n        >>> placement = flow.placement(\"cpu\", ranks=[0])\n        >>> y = flow.full_like(x, 5.0, placement=placement, sbp=flow.sbp.broadcast)  # construct global tensor\n        >>> y.is_global\n        True\n\n    \"\"\"\n    (\n        new_size,\n        new_dtype,\n        new_device,\n        new_placement,\n        new_sbp,\n        new_requires_grad,\n    ) = _handle_meta_args(input, None, dtype, device, placement, sbp, requires_grad)\n    return Full(\n        new_size,\n        fill_value,\n        new_dtype,\n        new_device,\n        new_placement,\n        new_sbp,\n        new_requires_grad,\n    ).forward()\n\n\ndef new_ones_op(\n    x, size=None, dtype=None, device=None, placement=None, sbp=None, requires_grad=False\n):\n    (\n        new_size,\n        new_dtype,\n        new_device,\n        new_placement,\n        new_sbp,\n        new_requires_grad,\n    ) = _handle_meta_args(x, size, dtype, device, placement, sbp, requires_grad)\n    if new_placement is not None:\n        # use the resolved placement/sbp, which may have been inherited from x\n        res = flow._C.global_constant(\n            new_size, 1.0, dtype=new_dtype, placement=new_placement, sbp=new_sbp\n        )\n    else:\n        res = flow._C.constant(new_size, 1.0, dtype=new_dtype, device=new_device)\n    res.requires_grad = new_requires_grad\n    return res\n\n\ndef new_zeros_op(\n    x, size=None, dtype=None, device=None, placement=None, sbp=None, requires_grad=False\n):\n    (\n        new_size,\n        new_dtype,\n        new_device,\n        new_placement,\n        new_sbp,\n        new_requires_grad,\n    ) = _handle_meta_args(x, size, dtype, device, placement, sbp, requires_grad)\n    if new_placement is not None:\n        res = flow._C.global_constant(\n            new_size, 0.0, dtype=new_dtype, placement=new_placement, sbp=new_sbp\n        )\n    else:\n        res = flow._C.constant(new_size, 0.0, dtype=new_dtype, device=new_device)\n    res.requires_grad = new_requires_grad\n    return res\n\n\ndef new_full_op(\n    x,\n    size,\n    fill_value,\n    dtype=None,\n    device=None,\n    placement=None,\n    sbp=None,\n    requires_grad=False,\n):\n    size = _handle_size_arg(size)\n    (\n        new_size,\n        new_dtype,\n        new_device,\n        new_placement,\n        new_sbp,\n        new_requires_grad,\n    ) = _handle_meta_args(x, size, dtype, device, placement, sbp, requires_grad)\n    if flow.is_tensor(fill_value):\n        assert (\n            len(fill_value.size()) == 0\n        ), \"new_full(): argument 'fill_value' must be Number, not Tensor\"\n        fill_value = fill_value.item()\n\n    if new_placement is not None:\n        res = flow._C.global_constant(\n            new_size, fill_value, dtype=new_dtype, placement=new_placement, sbp=new_sbp\n        )\n    else:\n        res = flow._C.constant(new_size, fill_value, dtype=new_dtype, device=new_device)\n    res.requires_grad = new_requires_grad\n    return res\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/container.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.nn.utils.container import *\nfrom oneflow.nn.modules.module import Module\n\n\nclass Sequential(get_seq(Module)):\n    \"\"\"A sequential container.\n\n    The interface is consistent with PyTorch.    \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Sequential.html?#torch.nn.Sequential.\n    \n    Modules will be added to it in the order they are passed in the constructor.\n    Alternatively, an ordered dict of modules can also be passed in.\n\n    To make it easier to understand, here is a small example:\n\n    .. code-block:: python\n\n        >>> import oneflow.nn as nn\n        >>> from collections import OrderedDict\n        >>> nn.Sequential(nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU()) #doctest: +ELLIPSIS\n        Sequential(\n          (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))\n          (1): ReLU()\n          (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))\n          (3): ReLU()\n        )\n        >>> nn.Sequential(OrderedDict([\n        ...    ('conv1', nn.Conv2d(1,20,5)),\n        ...    ('relu1', nn.ReLU()),\n        ...    ('conv2', nn.Conv2d(20,64,5)),\n        ...    ('relu2', nn.ReLU())\n        ... ])) #doctest: +ELLIPSIS\n        Sequential(\n          (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))\n          (relu1): ReLU()\n          (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))\n          (relu2): ReLU()\n        )\n\n    \"\"\"\n\n    pass\n\n\nclass ModuleList(get_list(Module)):\n    \"\"\"Holds submodules in a list.\n\n    The interface is consistent with PyTorch.    \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ModuleList.html?#torch.nn.ModuleList.\n    \n    :class:`~oneflow.nn.ModuleList` can be indexed like a regular Python list, but\n    modules it contains are properly registered, and will be visible by all\n    :class:`~oneflow.nn.Module` methods.\n    \n    Args:\n        modules (iterable, optional): an iterable of modules to add\n    \n    .. code-block:: python\n\n        >>> import oneflow.nn as nn\n\n        >>> class MyModule(nn.Module):\n        ...    def __init__(self):\n        ...        super(MyModule, self).__init__()\n        ...        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])\n        ...    def forward(self, x):\n        ...        # ModuleList can act as an iterable, or be indexed using ints\n        ...        for i, l in enumerate(self.linears):\n        ...            x = self.linears[i // 2](x) + l(x)\n        ...        
return x\n\n        >>> model = MyModule()\n        >>> model.linears\n        ModuleList(\n          (0): Linear(in_features=10, out_features=10, bias=True)\n          (1): Linear(in_features=10, out_features=10, bias=True)\n          (2): Linear(in_features=10, out_features=10, bias=True)\n          (3): Linear(in_features=10, out_features=10, bias=True)\n          (4): Linear(in_features=10, out_features=10, bias=True)\n          (5): Linear(in_features=10, out_features=10, bias=True)\n          (6): Linear(in_features=10, out_features=10, bias=True)\n          (7): Linear(in_features=10, out_features=10, bias=True)\n          (8): Linear(in_features=10, out_features=10, bias=True)\n          (9): Linear(in_features=10, out_features=10, bias=True)\n        )\n        \n\n    \"\"\"\n\n    pass\n\n\nclass ModuleDict(get_dict(Module)):\n    \"\"\"Holds submodules in a dictionary.\n\n    The interface is consistent with PyTorch.    \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ModuleDict.html?#torch.nn.ModuleDict.\n\n    :class:`~oneflow.nn.ModuleDict` can be indexed like a regular Python dictionary,\n    but modules it contains are properly registered, and will be visible by all\n    :class:`~oneflow.nn.Module` methods.\n\n    :class:`~oneflow.nn.ModuleDict` is an **ordered** dictionary that respects\n\n    * the order of insertion, and\n\n    * in :meth:`~oneflow.nn.ModuleDict.update`, the order of the merged\n      ``OrderedDict``, ``dict`` (started from Python 3.6) or another\n      :class:`~oneflow.nn.ModuleDict` (the argument to\n      :meth:`~oneflow.nn.ModuleDict.update`).\n\n    Note that :meth:`~oneflow.nn.ModuleDict.update` with other unordered mapping\n    types (e.g., Python's plain ``dict`` before Python version 3.6) does not\n    preserve the order of the merged mapping.\n\n    Args:\n        modules (iterable, optional): a mapping (dictionary) of (string: module)\n            or an iterable of key-value pairs of type (string, module)\n\n    .. code-block:: python\n\n        >>> import oneflow.nn as nn\n\n        >>> class MyModule(nn.Module):\n        ...    def __init__(self):\n        ...        super(MyModule, self).__init__()\n        ...        self.choices = nn.ModuleDict({\n        ...                'conv': nn.Conv2d(10, 10, 3),\n        ...                'pool': nn.MaxPool2d(3)\n        ...        })\n        ...        self.activations = nn.ModuleDict([\n        ...                ['lrelu', nn.LeakyReLU()],\n        ...                ['prelu', nn.PReLU()]\n        ...        ])\n        ...\n        ...    def forward(self, x, choice, act):\n        ...        x = self.choices[choice](x)\n        ...        x = self.activations[act](x)\n        ...        return x\n    \n        >>> model = MyModule()\n        >>> model.choices\n        ModuleDict(\n          (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n          (pool): MaxPool2d(kernel_size=(3, 3), stride=(3, 3), padding=(0, 0), dilation=(1, 1))\n        )\n    \"\"\"\n\n    pass\n\n\nclass ParameterList(get_para_list(Module)):\n    \"\"\"Holds parameters in a list.\n\n    :class:`~oneflow.nn.ParameterList` can be indexed like a regular Python\n    list, but parameters it contains are properly registered, and will be\n    visible by all :class:`~oneflow.nn.Module` methods.\n\n    The interface is consistent with PyTorch.    
\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ParameterList.html?#torch.nn.ParameterList.\n\n    Args:\n        parameters (iterable, optional): an iterable of :class:`~oneflow.nn.Parameter` to add\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n\n        >>> class MyModule(nn.Module):\n        ...    def __init__(self):\n        ...        super(MyModule, self).__init__()\n        ...        self.params = nn.ParameterList([nn.Parameter(flow.randn(10, 10)) for i in range(10)])\n        ...\n        ...    def forward(self, x):\n        ...        # ParameterList can act as an iterable, or be indexed using ints\n        ...        for i, p in enumerate(self.params):\n        ...            x = self.params[i // 2].mm(x) + p.mm(x)\n        ...        return x\n\n        >>> model = MyModule()\n        >>> model.params\n        ParameterList(\n            (0): Parameter containing: [<class 'oneflow.nn.Parameter'> of size 10x10]\n            (1): Parameter containing: [<class 'oneflow.nn.Parameter'> of size 10x10]\n            (2): Parameter containing: [<class 'oneflow.nn.Parameter'> of size 10x10]\n            (3): Parameter containing: [<class 'oneflow.nn.Parameter'> of size 10x10]\n            (4): Parameter containing: [<class 'oneflow.nn.Parameter'> of size 10x10]\n            (5): Parameter containing: [<class 'oneflow.nn.Parameter'> of size 10x10]\n            (6): Parameter containing: [<class 'oneflow.nn.Parameter'> of size 10x10]\n            (7): Parameter containing: [<class 'oneflow.nn.Parameter'> of size 10x10]\n            (8): Parameter containing: [<class 'oneflow.nn.Parameter'> of size 10x10]\n            (9): Parameter containing: [<class 'oneflow.nn.Parameter'> of size 10x10]\n        )\n    \"\"\"\n\n    pass\n\n\nclass ParameterDict(get_para_dict(Module)):\n    \"\"\"\n    Holds parameters in a dictionary.\n\n    ParameterDict can be indexed like a regular Python dictionary, but parameters it\n    contains are properly registered, and will be visible by all Module methods.\n\n    :class:`~oneflow.nn.ParameterDict` is an **ordered** dictionary that respects\n\n    * the order of insertion, and\n\n    * in :meth:`~oneflow.nn.ParameterDict.update`, the order of the merged ``OrderedDict``\n      or another :class:`~oneflow.nn.ParameterDict` (the argument to\n      :meth:`~oneflow.nn.ParameterDict.update`).\n\n    Note that :meth:`~oneflow.nn.ParameterDict.update` with other unordered mapping\n    types (e.g., Python's plain ``dict``) does not preserve the order of the\n    merged mapping.\n    \n    The interface is consistent with PyTorch.    \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.ParameterDict.html?#torch.nn.ParameterDict.\n\n    Args:\n        parameters (iterable, optional): a mapping (dictionary) of\n            (string : :class:`~oneflow.nn.Parameter`) or an iterable of key-value pairs\n            of type (string, :class:`~oneflow.nn.Parameter`)\n\n    .. code-block:: python\n        \n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n\n        >>> class MyModule(nn.Module):\n        ...    def __init__(self):\n        ...        super(MyModule, self).__init__()\n        ...        self.params = nn.ParameterDict({\n        ...                'left': nn.Parameter(flow.randn(5, 10)),\n        ...                'right': nn.Parameter(flow.randn(5, 10))\n        ...        
})\n        ...\n        ...    def forward(self, x, choice):\n        ...        x = self.params[choice].mm(x)\n        ...        return x\n\n        >>> model = MyModule()\n        >>> model.params\n        ParameterDict(\n            (left): Parameter containing: [<class 'oneflow.nn.Parameter'> of size 5x10]\n            (right): Parameter containing: [<class 'oneflow.nn.Parameter'> of size 5x10]\n        )\n    \"\"\"\n\n    pass\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/conv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport math\nimport os\n\nimport oneflow as flow\nfrom oneflow.nn import init\nfrom oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t\nfrom oneflow.nn.modules.module import Module\nfrom oneflow.nn.modules.utils import _pair, _single, _triple\n\nfrom typing import Union\n\n\ndef slice(x, begin, size):\n    ndim = len(x.shape)\n    if not isinstance(begin, (list, tuple)) or len(begin) != ndim:\n        raise ValueError(\n            \"begin must be a list/tuple with the same length as input tensor's number of dimensions\"\n        )\n    if not all((isinstance(b, int) or b is None for b in begin)):\n        raise ValueError(\"element of begin must be a int or None\")\n    if not isinstance(size, (list, tuple)) or len(size) != ndim:\n        raise ValueError(\n            \"size must be a list/tuple with the same length as input tensor's number of dimensions.\"\n        )\n    if not all((isinstance(s, int) or s is None for s in size)):\n        raise ValueError(\"element of size must be a int or None\")\n    slice_tup_list = []\n    for (b, s, dim_size) in zip(begin, size, x.shape):\n        (start, stop, step) = (None, None, 1)\n        if b is not None:\n            if b < -dim_size or b >= dim_size:\n                raise ValueError(\"element of begin is out of range\")\n            start = b\n        if s is not None:\n            if s == -1:\n                stop = dim_size\n            else:\n                if s <= 0 or s > dim_size:\n                    raise ValueError(\"element of size is invalid\")\n                if b + s < dim_size:\n                    stop = b + s\n        slice_tup_list.append((start, stop, step))\n    return flow.slice(x, slice_tup_list)\n\n\nclass ConvUtil(object):\n    @classmethod\n    def split(cls, x, axis, split_num):\n        split_len = x.shape[axis] // split_num\n        result_list = []\n        slice_begin = [0] * len(x.shape)\n        slice_size = [-1] * len(x.shape)\n        slice_size[axis] = split_len\n        for i in range(split_num):\n            slice_begin[axis] = i * split_len\n            result = slice(x, slice_begin, slice_size)\n            result_list.append(result)\n        return result_list\n\n\ndef get_padding(padding, kernel_size, dilation, stride):\n    valid_padding_strings = {\"same\", \"valid\"}\n    if isinstance(padding, str):\n        if padding not in valid_padding_strings:\n            raise ValueError(\n                \"Invalid padding string {!r}, should be one of {}\".format(\n                    padding, valid_padding_strings\n                )\n            )\n        if padding == \"same\" and any(s != 1 for s in list(stride)):\n            raise ValueError(\"padding='same' is not supported for strided convolutions\")\n\n    out_padding = [0] * len(kernel_size)\n    if padding == \"same\":\n        for d, k, i in zip(dilation, kernel_size, range(len(kernel_size) - 1, -1, -1)):\n   
          total_padding = d * (k - 1)\n            left_pad = total_padding // 2\n            out_padding[i] = left_pad\n    return out_padding\n\n\nclass Conv1d(Module):\n    \"\"\"Applies a 1D convolution over an input signal composed of several input\n    planes.\n    \n    The interface is consistent with PyTorch.    \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Conv1d.html.\n    \n    In the simplest case, the output value of the layer with input size\n    :math:`(N, C_{\\\\text{in}}, L)` and output :math:`(N, C_{\\\\text{out}}, L_{\\\\text{out}})` can be\n    precisely described as:\n\n    .. math::\n        \\\\text{out}(N_i, C_{\\\\text{out}_j}) = \\\\text{bias}(C_{\\\\text{out}_j}) +\n        \\\\sum_{k = 0}^{C_{in} - 1} \\\\text{weight}(C_{\\\\text{out}_j}, k)\n        \\\\star \\\\text{input}(N_i, k)\n\n    where :math:`\\\\star` is the valid `cross-correlation`_ operator,\n    :math:`N` is a batch size, :math:`C` denotes a number of channels,\n    :math:`L` is a length of signal sequence.\n\n    * :attr:`stride` controls the stride for the cross-correlation, a single\n      number or a one-element tuple.\n\n    * :attr:`padding` controls the amount of padding applied to the input. It\n      can be either a string {'valid', 'same'} or a tuple of ints giving the\n      amount of implicit padding applied on both sides.\n\n    * :attr:`dilation` controls the spacing between the kernel points; also\n      known as the à trous algorithm. It is harder to describe, but this `link`_\n      has a nice visualization of what :attr:`dilation` does.\n\n    Note:\n        ``padding='valid'`` is the same as no padding. ``padding='same'`` pads\n        the input so the output has the same shape as the input. However, this mode\n        doesn't support any stride values other than 1.\n\n    Args:\n        in_channels (int): Number of channels in the input image\n        out_channels (int): Number of channels produced by the convolution\n        kernel_size (int or tuple): Size of the convolving kernel\n        stride (int or tuple, optional): Stride of the convolution. Default: 1\n        padding (int, tuple or str, optional): Padding added to both sides of\n            the input. Default: 0\n        padding_mode (string, optional): ``'zeros'``. Default: ``'zeros'``\n        dilation (int or tuple, optional): Spacing between kernel\n            elements. Default: 1\n        groups (int, optional): Number of blocked connections from input\n            channels to output channels. Default: 1\n        bias (bool, optional): If ``True``, adds a learnable bias to the\n            output. Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C_{in}, L_{in})`\n        - Output: :math:`(N, C_{out}, L_{out})` where\n\n          .. 
math::\n              L_{out} = \\\\left\\\\lfloor\\\\frac{L_{in} + 2 \\\\times \\\\text{padding} - \\\\text{dilation}\n                        \\\\times (\\\\text{kernel\\\\_size} - 1) - 1}{\\\\text{stride}} + 1\\\\right\\\\rfloor\n\n    Attributes:\n        weight (Tensor): the learnable weights of the module of shape\n            :math:`(\\\\text{out\\\\_channels},\n            \\\\frac{\\\\text{in\\\\_channels}}{\\\\text{groups}}, \\\\text{kernel\\\\_size})`.\n            The values of these weights are sampled from\n            :math:`\\\\mathcal{U}(-\\\\sqrt{k}, \\\\sqrt{k})` where\n            :math:`k = \\\\frac{groups}{C_\\\\text{in} * \\\\text{kernel\\\\_size}}`\n        bias (Tensor):   the learnable bias of the module of shape\n            (out_channels). If :attr:`bias` is ``True``, then the values of these weights are\n            sampled from :math:`\\\\mathcal{U}(-\\\\sqrt{k}, \\\\sqrt{k})` where\n            :math:`k = \\\\frac{groups}{C_\\\\text{in} * \\\\text{kernel\\\\_size}}`\n\n    For example: \n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n        \n        >>> arr = np.random.randn(20, 16, 50)\n        >>> input = flow.Tensor(arr)\n        >>> m = nn.Conv1d(16, 33, 3, stride=2)\n        >>> output = m(input)\n\n    .. _cross-correlation:\n        https://en.wikipedia.org/wiki/Cross-correlation\n\n    .. _link:\n        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: _size_1_t,\n        stride: _size_1_t = 1,\n        padding: Union[str, _size_1_t] = 0,\n        dilation: _size_1_t = 1,\n        groups: int = 1,\n        bias: bool = True,\n        padding_mode: str = \"zeros\",\n        device=None,\n        dtype=None,\n    ):\n        super().__init__()\n        assert padding_mode == \"zeros\"\n        self.padding_mode = padding_mode\n        self.kernel_size = _single(kernel_size)\n        self.stride = _single(stride)\n        self.dilation = _single(dilation)\n        self.padding = (\n            get_padding(padding, self.kernel_size, self.dilation, self.stride)\n            if isinstance(padding, str)\n            else _single(padding)\n        )\n        self.groups = groups\n        self.channel_pos = \"channels_first\"\n        assert in_channels % groups == 0\n        assert out_channels % groups == 0\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.weight = flow.nn.Parameter(\n            flow.empty(\n                out_channels,\n                in_channels // groups,\n                *self.kernel_size,\n                dtype=dtype,\n                device=device\n            )\n        )\n        self.out_channel_groups = out_channels // groups\n        self.bias = None\n        if bias:\n            self.bias = flow.nn.Parameter(\n                flow.empty(out_channels, dtype=dtype, device=device)\n            )\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n        if self.bias is not None:\n            (fan_in, _) = init._calculate_fan_in_and_fan_out(self.weight)\n            bound = 1 / math.sqrt(fan_in)\n            init.uniform_(self.bias, -bound, bound)\n\n    def _conv_forward(self, x, weight, bias):\n        return flow._C.conv1d(\n            x,\n            weight,\n      
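            # self.bias stays None when the module is built with bias=False; the\n            # functional call accepts None, in which case no bias is added.\n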
      bias,\n            stride=self.stride,\n            padding=self.padding,\n            dilation=self.dilation,\n            groups=self.groups,\n            channel_pos=self.channel_pos,\n        )\n\n    def forward(self, x):\n        return self._conv_forward(x, self.weight, self.bias)\n\n    def extra_repr(self):\n        s = \"{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}\"\n        if self.padding != (0,) * len(self.padding):\n            s += \", padding={padding}\"\n        if self.dilation != (1,) * len(self.dilation):\n            s += \", dilation={dilation}\"\n        if self.groups != 1:\n            s += \", groups={groups}\"\n        if self.bias is None:\n            s += \", bias=False\"\n        if self.padding_mode != \"zeros\":\n            s += \", padding_mode={padding_mode}\"\n        return s.format(**self.__dict__)\n\n\nclass Conv2d(Module):\n    \"\"\"Applies a 2D convolution over an input signal composed of several input\n    planes.\n    The interface is consistent with PyTorch.    \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Conv2d.html.\n\n    In the simplest case, the output value of the layer with input size\n    :math:`(N, C_{\\\\text{in}}, H, W)` and output :math:`(N, C_{\\\\text{out}}, H_{\\\\text{out}}, W_{\\\\text{out}})`\n    can be precisely described as:\n\n    .. math::\n        \\\\text{out}(N_i, C_{\\\\text{out}_j}) = \\\\text{bias}(C_{\\\\text{out}_j}) +\n        \\\\sum_{k = 0}^{C_{\\\\text{in}} - 1} \\\\text{weight}(C_{\\\\text{out}_j}, k) \\\\star \\\\text{input}(N_i, k)\n\n\n    where :math:`\\\\star` is the valid 2D `cross-correlation`_ operator,\n    :math:`N` is a batch size, :math:`C` denotes a number of channels,\n    :math:`H` is a height of input planes in pixels, and :math:`W` is\n    width in pixels.\n\n\n    * :attr:`stride` controls the stride for the cross-correlation, a single\n      number or a tuple.\n    * :attr:`padding` controls the amount of implicit padding on both\n      sides for :attr:`padding` number of points for each dimension.\n    * :attr:`dilation` controls the spacing between the kernel points; also\n      known as the à trous algorithm. It is harder to describe, but this `link`_\n      has a nice visualization of what :attr:`dilation` does.\n    * :attr:`groups` controls the connections between inputs and outputs.\n      :attr:`in_channels` and :attr:`out_channels` must both be divisible by\n      :attr:`groups`. 
For example,\n\n        * At groups=1, all inputs are convolved to all outputs.\n        * At groups=2, the operation becomes equivalent to having two conv\n          layers side by side, each seeing half the input channels\n          and producing half the output channels, and both subsequently\n          concatenated.\n        * At groups= :attr:`in_channels`, each input channel is convolved with\n          its own set of filters (of size\n          :math:`\\\\frac{\\\\text{out_channels}}{\\\\text{in_channels}}`).\n\n    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:\n\n        - a single ``int`` -- in which case the same value is used for the height and width dimension\n        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,\n          and the second `int` for the width dimension\n\n    Note:\n        When `groups == in_channels` and `out_channels == K * in_channels`,\n        where `K` is a positive integer, this operation is also known as a \"depthwise convolution\".\n\n        In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`,\n        a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments\n        :math:`(C_\\\\text{in}=C_\\\\text{in}, C_\\\\text{out}=C_\\\\text{in} \\\\times \\\\text{K}, ..., \\\\text{groups}=C_\\\\text{in})`.\n\n\n    Args:\n        in_channels (int): Number of channels in the input image\n        out_channels (int): Number of channels produced by the convolution\n        kernel_size (int or tuple): Size of the convolving kernel\n        stride (int or tuple, optional): Stride of the convolution. Default: 1\n        padding (int or tuple, optional): Zero-padding added to both sides of\n            the input. Default: 0\n        padding_mode (string, optional): ``'zeros'``. Default: ``'zeros'``\n        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1\n        groups (int, optional): Number of blocked connections from input\n            channels to output channels. Default: 1\n        bias (bool, optional): If ``True``, adds a learnable bias to the\n            output. Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`\n        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where\n\n          .. math::\n              H_{out} = \\\\left\\\\lfloor\\\\frac{H_{in}  + 2 \\\\times \\\\text{padding}[0] - \\\\text{dilation}[0]\n                        \\\\times (\\\\text{kernel_size}[0] - 1) - 1}{\\\\text{stride}[0]} + 1\\\\right\\\\rfloor\n\n          .. math::\n              W_{out} = \\\\left\\\\lfloor\\\\frac{W_{in}  + 2 \\\\times \\\\text{padding}[1] - \\\\text{dilation}[1]\n                        \\\\times (\\\\text{kernel_size}[1] - 1) - 1}{\\\\text{stride}[1]} + 1\\\\right\\\\rfloor\n\n    Attributes:\n        - weight (Tensor): the learnable weights of the module of shape\n            :math:`(\\\\text{out_channels}, \\\\frac{\\\\text{in_channels}}{\\\\text{groups}},`\n            :math:`\\\\text{kernel_size[0]}, \\\\text{kernel_size[1]})`.\n            The values of these weights are sampled from\n            :math:`\\\\mathcal{U}(-\\\\sqrt{k}, \\\\sqrt{k})` where\n            :math:`k = \\\\frac{groups}{C_\\\\text{in} * \\\\prod_{i=0}^{1}\\\\text{kernel_size}[i]}`\n\n        - bias (Tensor):   the learnable bias of the module of shape\n            (out_channels).
If :attr:`bias` is ``True``,\n            then the values of these weights are\n            sampled from :math:`\\\\mathcal{U}(-\\\\sqrt{k}, \\\\sqrt{k})` where\n            :math:`k = \\\\frac{groups}{C_\\\\text{in} * \\\\prod_{i=0}^{1}\\\\text{kernel_size}[i]}`\n\n    For example: \n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n        \n        >>> arr = np.random.randn(20, 16, 50, 100)\n        >>> input = flow.Tensor(arr)\n        >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))\n        >>> output = m(input)\n\n    .. _cross-correlation:\n        https://en.wikipedia.org/wiki/Cross-correlation\n\n    .. _link:\n        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: _size_2_t,\n        stride: _size_2_t = 1,\n        padding: Union[str, _size_2_t] = 0,\n        dilation: _size_2_t = 1,\n        groups: int = 1,\n        bias: bool = True,\n        padding_mode: str = \"zeros\",\n        device=None,\n        dtype=None,\n    ):\n        super().__init__()\n        assert padding_mode == \"zeros\"\n        self.padding_mode = padding_mode\n        self.kernel_size = _pair(kernel_size)\n        self.stride = _pair(stride)\n        self.dilation = _pair(dilation)\n        self.padding = (\n            get_padding(padding, self.kernel_size, self.dilation, self.stride)\n            if isinstance(padding, str)\n            else _pair(padding)\n        )\n        self.groups = groups\n        self.transposed = False\n\n        if os.getenv(\"ONEFLOW_ENABLE_NHWC\") == \"1\":\n            self.channel_pos = \"channels_last\"\n            self.transposed = True\n        else:\n            self.channel_pos = \"channels_first\"\n\n        assert in_channels % groups == 0\n        assert out_channels % groups == 0\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        if self.channel_pos == \"channels_first\":\n            self.weight = flow.nn.Parameter(\n                flow.empty(\n                    out_channels,\n                    in_channels // groups,\n                    *self.kernel_size,\n                    device=device,\n                    dtype=dtype\n                )\n            )\n        else:\n            self.weight = flow.nn.Parameter(\n                flow.empty(\n                    out_channels,\n                    *self.kernel_size,\n                    in_channels // groups,\n                    device=device,\n                    dtype=dtype\n                )\n            )\n\n        self.out_channel_groups = out_channels // groups\n        self.bias = None\n        if bias:\n            self.bias = flow.nn.Parameter(\n                flow.empty(out_channels, device=device, dtype=dtype)\n            )\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n        if self.bias is not None:\n            (fan_in, _) = init._calculate_fan_in_and_fan_out(self.weight)\n            bound = 1 / math.sqrt(fan_in)\n            init.uniform_(self.bias, -bound, bound)\n\n    def to_memory_format(self, memory_format) -> None:\n        if self.channel_pos == \"channels_first\" and memory_format is flow.channels_last:\n            self.channel_pos = \"channels_last\"\n            with 
flow.no_grad():\n                self.weight.data = self.weight.to(memory_format=flow.channels_last)\n        elif (\n            self.channel_pos == \"channels_last\"\n            and memory_format is flow.contiguous_format\n        ):\n            self.channel_pos = \"channels_first\"\n            with flow.no_grad():\n                self.weight.data = self.weight.to(memory_format=flow.contiguous_format)\n\n    def _conv_forward(self, x, weight, bias):\n        return flow._C.conv2d(\n            x,\n            weight,\n            bias,\n            stride=self.stride,\n            padding=self.padding,\n            dilation=self.dilation,\n            groups=self.groups,\n            channel_pos=self.channel_pos,\n        )\n\n    def forward(self, x):\n        return self._conv_forward(x, self.weight, self.bias)\n\n    def extra_repr(self):\n        s = \"{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}\"\n        if self.padding != (0,) * len(self.padding):\n            s += \", padding={padding}\"\n        if self.dilation != (1,) * len(self.dilation):\n            s += \", dilation={dilation}\"\n        if self.groups != 1:\n            s += \", groups={groups}\"\n        if self.bias is None:\n            s += \", bias=False\"\n        if self.padding_mode != \"zeros\":\n            s += \", padding_mode={padding_mode}\"\n        return s.format(**self.__dict__)\n\n\nclass Conv3d(Module):\n    r\"\"\"Applies a 3D convolution over an input signal composed of several input\n    planes.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Conv3d.html.\n\n    In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)`\n    and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as:\n\n    .. math::\n        out(N_i, C_{out_j}) = bias(C_{out_j}) +\n                                \\\\sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \\\\star input(N_i, k)\n\n    where :math:`\\\\star` is the valid 3D `cross-correlation`_ operator\n\n    * :attr:`stride` controls the stride for the cross-correlation.\n\n    * :attr:`padding` controls the amount of padding applied to the input. It\n      can be either a string {'valid', 'same'} or a tuple of ints giving the\n      amount of implicit padding applied on both sides.\n\n    * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.\n      It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.\n\n    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:\n\n        - a single ``int`` -- in which case the same value is used for the depth, height and width dimension\n        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,\n          the second `int` for the height dimension and the third `int` for the width dimension\n\n    Note:\n        ``padding='valid'`` is the same as no padding. ``padding='same'`` pads\n        the input so the output has the same shape as the input.
However, this mode\n        doesn't support any stride values other than 1.\n\n    Args:\n        in_channels (int): Number of channels in the input image\n        out_channels (int): Number of channels produced by the convolution\n        kernel_size (int or tuple): Size of the convolving kernel\n        stride (int or tuple, optional): Stride of the convolution. Default: 1\n        padding (int, tuple or str, optional): Padding added to all six sides of\n            the input. Default: 0\n        padding_mode (string, optional): ``'zeros'``. Default: ``'zeros'``\n        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1\n        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1\n        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``\n    \n    Shape:\n        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`\n        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where\n\n          .. math::\n              D_{out} = \\left\\lfloor\\frac{D_{in} + 2 \\times \\text{padding}[0] - \\text{dilation}[0]\n                    \\times (\\text{kernel\\_size}[0] - 1) - 1}{\\text{stride}[0]} + 1\\right\\rfloor\n\n          .. math::\n              H_{out} = \\left\\lfloor\\frac{H_{in} + 2 \\times \\text{padding}[1] - \\text{dilation}[1]\n                    \\times (\\text{kernel\\_size}[1] - 1) - 1}{\\text{stride}[1]} + 1\\right\\rfloor\n\n          .. math::\n              W_{out} = \\left\\lfloor\\frac{W_{in} + 2 \\times \\text{padding}[2] - \\text{dilation}[2]\n                    \\times (\\text{kernel\\_size}[2] - 1) - 1}{\\text{stride}[2]} + 1\\right\\rfloor\n\n    Attributes:\n        weight (Tensor): the learnable weights of the module of shape\n                         :math:`(\\text{out\\_channels}, \\frac{\\text{in\\_channels}}{\\text{groups}},`\n                         :math:`\\text{kernel\\_size[0]}, \\text{kernel\\_size[1]}, \\text{kernel\\_size[2]})`.\n                         The values of these weights are sampled from\n                         :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where\n                         :math:`k = \\frac{groups}{C_\\text{in} * \\prod_{i=0}^{2}\\text{kernel\\_size}[i]}`\n        bias (Tensor):   the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,\n                         then the values of these weights are\n                         sampled from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where\n                         :math:`k = \\frac{groups}{C_\\text{in} * \\prod_{i=0}^{2}\\text{kernel\\_size}[i]}`\n\n    For example: \n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n\n        >>> arr = np.random.randn(1, 2, 5, 5, 5)\n        >>> input = flow.Tensor(arr)\n        >>> m = nn.Conv3d(2, 4, kernel_size=3, stride=1)\n        >>> output = m(input)\n        \n    .. _cross-correlation:\n        https://en.wikipedia.org/wiki/Cross-correlation\n    .. 
_link:\n        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: _size_3_t,\n        stride: _size_3_t = 1,\n        padding: Union[str, _size_3_t] = 0,\n        dilation: _size_3_t = 1,\n        groups: int = 1,\n        bias: bool = True,\n        padding_mode: str = \"zeros\",  # TODO: refine this type\n        device=None,\n        dtype=None,\n    ):\n        super().__init__()\n\n        assert padding_mode == \"zeros\"\n        self.padding_mode = padding_mode\n        self.kernel_size = _triple(kernel_size)\n        self.stride = _triple(stride)\n        self.dilation = _triple(dilation)\n        self.padding = (\n            get_padding(padding, self.kernel_size, self.dilation, self.stride)\n            if isinstance(padding, str)\n            else _triple(padding)\n        )\n        self.groups = groups\n        self.channel_pos = \"channels_first\"\n        assert in_channels % groups == 0, \"in_channels must be divisible by groups\"\n        assert out_channels % groups == 0, \"out_channels must be divisible by groups\"\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.weight = flow.nn.Parameter(\n            flow.empty(\n                out_channels,\n                in_channels // groups,\n                *self.kernel_size,\n                device=device,\n                dtype=dtype\n            )\n        )\n        self.out_channel_groups = out_channels // groups\n        self.bias = None\n        if bias:\n            self.bias = flow.nn.Parameter(\n                flow.empty(out_channels, device=device, dtype=dtype)\n            )\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n        if self.bias is not None:\n            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)\n            bound = 1 / math.sqrt(fan_in)\n            init.uniform_(self.bias, -bound, bound)\n\n    def to_memory_format(self, memory_format) -> None:\n        if self.channel_pos == \"channels_first\" and memory_format is flow.channels_last:\n            self.channel_pos = \"channels_last\"\n            with flow.no_grad():\n                self.weight.data = self.weight.to(memory_format=flow.channels_last)\n        elif (\n            self.channel_pos == \"channels_last\"\n            and memory_format is flow.contiguous_format\n        ):\n            self.channel_pos = \"channels_first\"\n            with flow.no_grad():\n                self.weight.data = self.weight.to(memory_format=flow.contiguous_format)\n\n    def _conv_forward(self, x, weight, bias):\n        return flow._C.conv3d(\n            x,\n            weight,\n            bias,\n            stride=self.stride,\n            padding=self.padding,\n            dilation=self.dilation,\n            groups=self.groups,\n            channel_pos=self.channel_pos,\n        )\n\n    def forward(self, x):\n        return self._conv_forward(x, self.weight, self.bias)\n\n    def extra_repr(self):\n        s = \"{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}\"\n        if self.padding != (0,) * len(self.padding):\n            s += \", padding={padding}\"\n        if self.dilation != (1,) * len(self.dilation):\n            s += \", dilation={dilation}\"\n        if self.groups != 1:\n            s += \", groups={groups}\"\n        if 
self.bias is None:\n            s += \", bias=False\"\n        if self.padding_mode != \"zeros\":\n            s += \", padding_mode={padding_mode}\"\n        return s.format(**self.__dict__)\n\n\nclass ConvTranspose1d(Module):\n    r\"\"\"Applies a 1D transposed convolution operator over an input image\n    composed of several input planes.\n\n    This module can be seen as the gradient of Conv1d with respect to its input.\n    It is also known as a fractionally-strided convolution or\n    a deconvolution (although it is not an actual deconvolution operation).\n\n    This module supports TensorFloat32.\n\n    * :attr:`stride` controls the stride for the cross-correlation.\n\n    * :attr:`padding` controls the amount of implicit zero padding on both\n      sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note\n      below for details.\n\n    * :attr:`output_padding` controls the additional size added to one side\n      of the output shape. See note below for details.\n\n    * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.\n      It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.\n\n    Note:\n        The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``\n        amount of zero padding to both sides of the input. This is set so that\n        when a :class:`~torch.nn.Conv1d` and a :class:`~torch.nn.ConvTranspose1d`\n        are initialized with the same parameters, they are inverses of each other in\n        regard to the input and output shapes. However, when ``stride > 1``,\n        :class:`~torch.nn.Conv1d` maps multiple input shapes to the same output\n        shape. :attr:`output_padding` is provided to resolve this ambiguity by\n        effectively increasing the calculated output shape on one side. Note\n        that :attr:`output_padding` is only used to find output shape, but does\n        not actually add zero-padding to output.\n\n    Note:\n        In some circumstances when using the CUDA backend with CuDNN, this operator\n        may select a nondeterministic algorithm to increase performance. If this is\n        undesirable, you can try to make the operation deterministic (potentially at\n        a performance cost) by setting ``torch.backends.cudnn.deterministic =\n        True``.\n        Please see the notes on randomness for background.\n\n\n    Args:\n        in_channels (int): Number of channels in the input image\n        out_channels (int): Number of channels produced by the convolution\n        kernel_size (int or tuple): Size of the convolving kernel\n        stride (int or tuple, optional): Stride of the convolution. Default: 1\n        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding\n            will be added to both sides of the input. Default: 0\n        output_padding (int or tuple, optional): Additional size added to one side\n            of the output shape. Default: 0\n        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1\n        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``\n        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1\n\n    Shape:\n        - Input: :math:`(N, C_{in}, L_{in})`\n        - Output: :math:`(N, C_{out}, L_{out})` where\n\n          ..
math::\n              L_{out} = (L_{in} - 1) \\times \\text{stride} - 2 \\times \\text{padding} + \\text{dilation}\n                        \\times (\\text{kernel\\_size} - 1) + \\text{output\\_padding} + 1\n\n    Attributes:\n        weight (Tensor): the learnable weights of the module of shape\n                         :math:`(\\text{in\\_channels}, \\frac{\\text{out\\_channels}}{\\text{groups}},`\n                         :math:`\\text{kernel\\_size})`.\n                         The values of these weights are sampled from\n                         :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where\n                         :math:`k = \\frac{groups}{C_\\text{out} * \\text{kernel\\_size}}`\n        bias (Tensor):   the learnable bias of the module of shape (out_channels).\n                         If :attr:`bias` is ``True``, then the values of these weights are\n                         sampled from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where\n                         :math:`k = \\frac{groups}{C_\\text{out} * \\text{kernel\\_size}}`\n\n    .. _cross-correlation:\n        https://en.wikipedia.org/wiki/Cross-correlation\n\n    .. _link:\n        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: _size_1_t,\n        stride: _size_1_t = 1,\n        padding: _size_1_t = 0,\n        output_padding: _size_1_t = 0,\n        groups: int = 1,\n        bias: bool = True,\n        dilation: _size_1_t = 1,\n        padding_mode: str = \"zeros\",\n    ) -> None:\n        super().__init__()\n        assert (\n            padding_mode == \"zeros\"\n        ), \"Only `zeros` padding mode is supported for ConvTranspose1d\"\n        self.kernel_size = _single(kernel_size)\n        self.stride = _single(stride)\n        self.padding = _single(padding)\n        self.dilation = _single(dilation)\n        self.output_padding = _single(output_padding)\n        self.groups = groups\n        assert in_channels % groups == 0\n        assert out_channels % groups == 0\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.weight = flow.nn.Parameter(\n            flow.Tensor(in_channels, out_channels // groups, *self.kernel_size)\n        )\n        self.filters = out_channels\n        self.bias = None\n        self._bias_add_op = None\n        if bias:\n            self.bias = flow.nn.Parameter(flow.Tensor(out_channels))\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n        if self.bias is not None:\n            (fan_in, _) = init._calculate_fan_in_and_fan_out(self.weight)\n            bound = 1 / math.sqrt(fan_in)\n            init.uniform_(self.bias, -bound, bound)\n\n    def forward(self, x):\n        return flow._C.deconv1d(\n            x,\n            self.weight,\n            self.bias,\n            self.stride,\n            self.padding,\n            self.output_padding,\n            self.groups,\n            self.dilation,\n            \"channels_first\",\n        )\n\n\nclass ConvTranspose2d(Module):\n    \"\"\"\n\n    Applies a 2D transposed convolution operator over an input image composed of several input planes.\n\n    This module can be seen as the gradient of Conv2d with respect to its input.\n    It is also known as a fractionally-strided convolution or\n    a deconvolution (although it is not an actual
deconvolution operation).\n\n    Args:\n        in_channels (int): Number of channels in the input image\n        out_channels (int): Number of channels produced by the convolution\n        kernel_size (int or tuple): Size of the convolving kernel\n        stride (int or tuple, optional): Stride of the convolution. Default: 1\n        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding\n            will be added to both sides of each dimension in the input. Default: 0\n        output_padding (int or tuple, optional): Additional size added to one side\n            of each dimension in the output shape. Default: 0\n        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1\n        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``\n        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1\n\n    Shape:\n        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`\n        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where\n\n        .. math::\n              H_{out} = (H_{in} - 1) \\\\times \\\\text{stride}[0] - 2 \\\\times \\\\text{padding}[0] + \\\\text{dilation}[0]\n                        \\\\times (\\\\text{kernel_size}[0] - 1) + \\\\text{output_padding}[0] + 1\n\n        .. math::\n              W_{out} = (W_{in} - 1) \\\\times \\\\text{stride}[1] - 2 \\\\times \\\\text{padding}[1] + \\\\text{dilation}[1]\n                        \\\\times (\\\\text{kernel_size}[1] - 1) + \\\\text{output_padding}[1] + 1\n\n    Attributes:\n        ConvTranspose2d.weight (Tensor): the learnable weights of the module of shape\n                         :math:`(\\\\text{in_channels}, \\\\frac{\\\\text{out_channels}}{\\\\text{groups}},`\n                         :math:`\\\\text{kernel_size[0]}, \\\\text{kernel_size[1]})`.\n                         The values of these weights are sampled from\n                         :math:`\\\\mathcal{U}(-\\\\sqrt{k}, \\\\sqrt{k})` where\n                         :math:`k = \\\\frac{groups}{C_\\\\text{out} * \\\\prod_{i=0}^{1}\\\\text{kernel_size}[i]}`\n        ConvTranspose2d.bias (Tensor): the learnable bias of the module of shape (out_channels).\n                         If :attr:`bias` is ``True``, then the values of these weights are\n                         sampled from :math:`\\\\mathcal{U}(-\\\\sqrt{k}, \\\\sqrt{k})` where\n                         :math:`k = \\\\frac{groups}{C_\\\\text{out} * \\\\prod_{i=0}^{1}\\\\text{kernel_size}[i]}`\n\n    Examples::\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n\n        >>> m = nn.ConvTranspose2d(16, 33, 3, stride=2)\n        >>> # non-square kernels and unequal stride and with padding\n        >>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))\n        >>> m = m.to(\"cuda\")\n        >>> input = flow.Tensor(np.random.randn(20, 16, 50, 100), device=flow.device(\"cuda\"))\n        >>> output = m(input)\n        >>> output.size()\n        oneflow.Size([20, 33, 93, 100])\n\n    .. _cross-correlation:\n        https://en.wikipedia.org/wiki/Cross-correlation\n\n    ..
_link:\n        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: _size_2_t,\n        stride: _size_2_t = 1,\n        padding: _size_2_t = 0,\n        output_padding: _size_2_t = 0,\n        groups: int = 1,\n        bias: bool = True,\n        dilation: int = 1,\n        padding_mode: str = \"zeros\",\n    ) -> None:\n        super().__init__()\n        assert padding_mode == \"zeros\"\n        self.kernel_size = _pair(kernel_size)\n        self.stride = _pair(stride)\n        self.padding = _pair(padding)\n        self.output_padding = _pair(output_padding)\n        self.dilation = _pair(dilation)\n        self.groups = groups\n        assert in_channels % groups == 0\n        assert out_channels % groups == 0\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.weight = flow.nn.Parameter(\n            flow.Tensor(in_channels, out_channels // groups, *self.kernel_size)\n        )\n        self.in_channel_groups = in_channels // groups\n        self.filters = out_channels\n        self.bias = None\n        self._bias_add_op = None\n        if bias:\n            self.bias = flow.nn.Parameter(flow.Tensor(out_channels))\n\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n        if self.bias is not None:\n            (fan_in, _) = init._calculate_fan_in_and_fan_out(self.weight)\n            bound = 1 / math.sqrt(fan_in)\n            init.uniform_(self.bias, -bound, bound)\n\n    def forward(self, x):\n        res = flow._C.deconv2d(\n            x,\n            self.weight,\n            self.bias,\n            self.stride,\n            self.padding,\n            self.output_padding,\n            self.groups,\n            self.dilation,\n            \"channels_first\",\n        )\n        return res\n\n\nclass ConvTranspose3d(Module):\n    r\"\"\"\n    Applies a 3D transposed convolution operator over an input image composed of several input\n    planes.\n    The transposed convolution operator multiplies each input value element-wise by a learnable kernel,\n    and sums over the outputs from all input feature planes.\n\n    This module can be seen as the gradient of Conv3d with respect to its input.\n    It is also known as a fractionally-strided convolution or\n    a deconvolution (although it is not an actual deconvolution operation).\n\n    This module supports TensorFloat32.\n\n    * :attr:`stride` controls the stride for the cross-correlation.\n\n    * :attr:`padding` controls the amount of implicit zero padding on both\n      sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note\n      below for details.\n\n    * :attr:`output_padding` controls the additional size added to one side\n      of the output shape. 
See note below for details.\n\n    * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.\n      It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.\n\n\n    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`\n    can either be:\n\n        - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions\n        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,\n          the second `int` for the height dimension and the third `int` for the width dimension\n\n    Note:\n        The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``\n        amount of zero padding to both sides of the input. This is set so that\n        when a :class:`~torch.nn.Conv3d` and a :class:`~torch.nn.ConvTranspose3d`\n        are initialized with the same parameters, they are inverses of each other in\n        regard to the input and output shapes. However, when ``stride > 1``,\n        :class:`~torch.nn.Conv3d` maps multiple input shapes to the same output\n        shape. :attr:`output_padding` is provided to resolve this ambiguity by\n        effectively increasing the calculated output shape on one side. Note\n        that :attr:`output_padding` is only used to find output shape, but does\n        not actually add zero-padding to output.\n\n\n    Args:\n        in_channels (int): Number of channels in the input image\n        out_channels (int): Number of channels produced by the convolution\n        kernel_size (int or tuple): Size of the convolving kernel\n        stride (int or tuple, optional): Stride of the convolution. Default: 1\n        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding\n            will be added to both sides of each dimension in the input. Default: 0\n        output_padding (int or tuple, optional): Additional size added to one side\n            of each dimension in the output shape. Default: 0\n        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1\n        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``\n        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1\n\n    Shape:\n        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`\n        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where\n\n        .. math::\n              D_{out} = (D_{in} - 1) \\\\times \\\\text{stride}[0] - 2 \\\\times \\\\text{padding}[0] + \\\\text{dilation}[0]\n                        \\\\times (\\\\text{kernel_size}[0] - 1) + \\\\text{output_padding}[0] + 1\n\n        .. math::\n              H_{out} = (H_{in} - 1) \\\\times \\\\text{stride}[1] - 2 \\\\times \\\\text{padding}[1] + \\\\text{dilation}[1]\n                        \\\\times (\\\\text{kernel_size}[1] - 1) + \\\\text{output_padding}[1] + 1\n\n        ..
math::\n              W_{out} = (W_{in} - 1) \\times \\text{stride}[2] - 2 \\times \\text{padding}[2] + \\text{dilation}[2]\n                        \\times (\\text{kernel_size}[2] - 1) + \\text{output_padding}[2] + 1\n\n\n    Attributes:\n        weight (Tensor): the learnable weights of the module of shape\n                         :math:`(\\text{in_channels}, \\frac{\\text{out_channels}}{\\text{groups}},`\n                         :math:`\\text{kernel_size[0]}, \\text{kernel_size[1]}, \\text{kernel_size[2]})`.\n                         The values of these weights are sampled from\n                         :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where\n                         :math:`k = \\frac{groups}{C_\\text{out} * \\prod_{i=0}^{2}\\text{kernel_size}[i]}`\n        bias (Tensor):   the learnable bias of the module of shape (out_channels)\n                         If :attr:`bias` is ``True``, then the values of these weights are\n                         sampled from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where\n                         :math:`k = \\frac{groups}{C_\\text{out} * \\prod_{i=0}^{2}\\text{kernel_size}[i]}`\n\n    Examples::\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n\n        >>> # With square kernels and equal stride\n        >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2)\n        >>> # non-square kernels and unequal stride and with padding\n        >>> m = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2))\n        >>> input = flow.randn(20, 16, 10, 50, 100)\n        >>> output = m(input)\n\n    .. _cross-correlation:\n        https://en.wikipedia.org/wiki/Cross-correlation\n\n    .. _link:\n        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: _size_3_t,\n        stride: _size_3_t = 1,\n        padding: _size_3_t = 0,\n        output_padding: _size_3_t = 0,\n        groups: int = 1,\n        bias: bool = True,\n        dilation: _size_3_t = 1,\n        padding_mode: str = \"zeros\",\n    ) -> None:\n        super().__init__()\n        assert padding_mode == \"zeros\", \"Only `zeros` padding mode is supported\"\n        self.kernel_size = _triple(kernel_size)\n        self.stride = _triple(stride)\n        self.padding = _triple(padding)\n        self.dilation = _triple(dilation)\n        self.output_padding = _triple(output_padding)\n        self.groups = groups\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        assert in_channels % groups == 0\n        assert out_channels % groups == 0\n        self.weight = flow.nn.Parameter(\n            flow.Tensor(in_channels, out_channels // groups, *self.kernel_size)\n        )\n        self.filters = out_channels\n        self.bias = None\n        self._bias_add_op = None\n        if bias:\n            self.bias = flow.nn.Parameter(flow.Tensor(out_channels))\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n        if self.bias is not None:\n            (fan_in, _) = init._calculate_fan_in_and_fan_out(self.weight)\n            bound = 1 / math.sqrt(fan_in)\n            init.uniform_(self.bias, -bound, bound)\n\n    def forward(self, x):\n        return flow._C.deconv3d(\n            x,\n            self.weight,\n            self.bias,\n            self.stride,\n            self.padding,\n          
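            # NOTE: output_padding below only enlarges the inferred output shape; it\n            # does not zero-pad the output tensor itself (see the docstring note).\n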
  self.output_padding,\n            self.groups,\n            self.dilation,\n            \"channels_first\",\n        )\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/dataset.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport random\nimport sys\nimport traceback\nfrom google.protobuf import text_format\nfrom typing import List, Optional, Sequence, Tuple, Union\n\nimport oneflow as flow\nimport oneflow._oneflow_internal._C as _C\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.framework.scope_util import current_scope\nfrom oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t, _size_any_t\nfrom oneflow.nn.modules.module import Module\nfrom oneflow.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple\nimport oneflow.framework.id_util as id_util\n\n\ndef local_gen_random_seed(seed=None):\n    if seed is None:\n        seed = -1\n        has_seed = False\n    else:\n        has_seed = True\n    return (seed, has_seed)\n\n\nclass OFRecordReader(Module):\n    def __init__(\n        self,\n        ofrecord_dir: str,\n        batch_size: int = 1,\n        data_part_num: int = 1,\n        part_name_prefix: str = \"part-\",\n        part_name_suffix_length: int = -1,\n        random_shuffle: bool = False,\n        shuffle_buffer_size: int = 1024,\n        shuffle_after_epoch: bool = False,\n        random_seed: int = -1,\n        device: Union[flow.device, str] = None,\n        placement: flow.placement = None,\n        sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None,\n        name: Optional[str] = None,\n    ):\n        super().__init__()\n\n        if name is not None:\n            print(\"WARNING: name has been deprecated and has NO effect.\\n\")\n        self.ofrecord_dir = ofrecord_dir\n        self.batch_size = batch_size\n        self.data_part_num = data_part_num\n        self.part_name_prefix = part_name_prefix\n        self.part_name_suffix_length = part_name_suffix_length\n        self.random_shuffle = random_shuffle\n        self.shuffle_buffer_size = shuffle_buffer_size\n        self.shuffle_after_epoch = shuffle_after_epoch\n\n        self.placement = placement\n        if placement is None:\n            self.device = device or flow.device(\"cpu\")\n        else:\n            assert device is None\n\n        if placement is not None:\n            assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), \"sbp: %s\" % sbp\n            if isinstance(sbp, flow.sbp.sbp):\n                sbp = (sbp,)\n            else:\n                for elem in sbp:\n                    assert isinstance(elem, flow.sbp.sbp), \"sbp: %s\" % sbp\n            assert len(sbp) == len(placement.ranks.shape)\n        else:\n            assert sbp is None, \"sbp: %s\" % sbp\n\n        self.sbp = sbp\n\n        (self.seed, self.has_seed) = local_gen_random_seed(random_seed)\n        self._op = flow.stateful_op(\"OFRecordReader\").Output(\"out\").Build()\n\n    def forward(self):\n        if self.placement is not None:\n            res = _C.dispatch_ofrecord_reader(\n                self._op,\n                data_dir=self.ofrecord_dir,\n                
data_part_num=self.data_part_num,\n                part_name_prefix=self.part_name_prefix,\n                part_name_suffix_length=self.part_name_suffix_length,\n                batch_size=self.batch_size,\n                shuffle_buffer_size=self.shuffle_buffer_size,\n                random_shuffle=self.random_shuffle,\n                shuffle_after_epoch=self.shuffle_after_epoch,\n                seed=self.seed,\n                sbp=self.sbp,\n                placement=self.placement,\n            )\n        else:\n            res = _C.dispatch_ofrecord_reader(\n                self._op,\n                data_dir=self.ofrecord_dir,\n                data_part_num=self.data_part_num,\n                part_name_prefix=self.part_name_prefix,\n                part_name_suffix_length=self.part_name_suffix_length,\n                batch_size=self.batch_size,\n                shuffle_buffer_size=self.shuffle_buffer_size,\n                random_shuffle=self.random_shuffle,\n                shuffle_after_epoch=self.shuffle_after_epoch,\n                seed=self.seed,\n                device=self.device,\n            )\n        return res\n\n\nclass OFRecordRawDecoder(Module):\n    def __init__(\n        self,\n        blob_name: str,\n        shape: Sequence[int],\n        dtype: flow.dtype,\n        dim1_varying_length: bool = False,\n        truncate: bool = False,\n        auto_zero_padding: bool = False,\n        name: Optional[str] = None,\n    ):\n        super().__init__()\n        if auto_zero_padding:\n            print(\n                \"WARNING: auto_zero_padding has been deprecated, please use truncate instead.\\n\"\n            )\n        if name is not None:\n            print(\"WARNING: name has been deprecated and has NO effect.\\n\")\n        self.blob_name = blob_name\n        self.shape = shape\n        self.dtype = dtype\n        self.dim1_varying_length = dim1_varying_length\n        self.truncate = truncate\n        self.auto_zero_padding = auto_zero_padding\n        self._op = (\n            flow.stateful_op(\"ofrecord_raw_decoder\").Input(\"in\").Output(\"out\").Build()\n        )\n\n    def forward(self, input):\n        res = _C.dispatch_ofrecord_raw_decoder(\n            self._op,\n            input,\n            name=self.blob_name,\n            shape=self.shape,\n            data_type=self.dtype,\n            dim1_varying_length=self.dim1_varying_length,\n            truncate=self.truncate or self.auto_zero_padding,\n        )\n        return res\n\n\nclass CoinFlip(Module):\n    r\"\"\"\n    CoinFlip(batch_size=1, random_seed=None, probability=0.5, device=None, placement=None, sbp=None)\n\n    Generates random boolean values following a Bernoulli distribution.\n\n    The probability of generating a value 1 (true) is determined by the ``probability`` argument.\n\n    The shape of the generated data can be either specified explicitly with a ``shape`` argument,\n    or chosen to match the shape of the input, if provided. If none are present, a single value per\n    sample is generated.\n\n    The documentation is referenced from:\n    https://docs.nvidia.com/deeplearning/dali/user-guide/docs/supported_ops_legacy.html#nvidia.dali.ops.CoinFlip.\n\n    Args:\n        batch_size (int, optional): Maximum batch size of the pipeline. Negative values for this parameter\n            are invalid - the default value may only be used with a serialized pipeline (the value stored in\n            the serialized pipeline is used instead).
In most cases, the actual batch size of the pipeline will be \n            equal to the maximum one. Default: 1\n        random_seed (int, optional): Random seed. Default: None\n        probability (float, optional): Probability of value 1. Default: 0.5\n        device (oneflow.device, optional): Desired device of returned tensor. Default: if None, uses the \n            current device for the default tensor type.\n        placement (oneflow.placement, optional):  Desired placement of returned global tensor. \n            Default: if None, the returned tensor is local one using the argument `device`.\n        sbp (oneflow.sbp.sbp or tuple of oneflow.sbp.sbp, optional): Desired sbp descriptor of returned \n            global tensor. Default: if None, the returned tensor is local one using the argument `device`.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        batch_size: int = 1,\n        random_seed: Optional[int] = None,\n        probability: float = 0.5,\n        device: Union[flow.device, str] = None,\n        placement: flow.placement = None,\n        sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None,\n    ):\n        super().__init__()\n        self.batch_size = batch_size\n        self.probability = probability\n\n        self.placement = placement\n        if placement is None:\n            self.device = device or flow.device(\"cpu\")\n            assert self.device == \"cpu\" or self.device == flow.device(\n                \"cpu\"\n            ), \"coin flip only supports cpu currently.\"\n        else:\n            assert device is None\n\n        if placement is not None:\n            assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), \"sbp: %s\" % sbp\n            if isinstance(sbp, flow.sbp.sbp):\n                sbp = (sbp,)\n            else:\n                for elem in sbp:\n                    assert isinstance(elem, flow.sbp.sbp), \"sbp: %s\" % sbp\n            assert len(sbp) == len(placement.ranks.shape)\n            assert (\n                self.placement.type == \"cpu\"\n            ), \"coin flip only supports cpu currently.\"\n        else:\n            assert sbp is None, \"sbp: %s\" % sbp\n\n        self.sbp = sbp\n\n        (self.seed, self.has_seed) = local_gen_random_seed(random_seed)\n\n        self._op = flow.stateful_op(\"coin_flip\").Output(\"out\").Build()\n\n    def forward(self):\n        if self.placement is not None:\n            res = _C.dispatch_coin_flip(\n                self._op,\n                batch_size=self.batch_size,\n                probability=self.probability,\n                has_seed=self.has_seed,\n                seed=self.seed,\n                placement=self.placement,\n                sbp=self.sbp,\n            )\n        else:\n            res = _C.dispatch_coin_flip(\n                self._op,\n                batch_size=self.batch_size,\n                probability=self.probability,\n                has_seed=self.has_seed,\n                seed=self.seed,\n                device=self.device,\n            )\n        return res\n\n\nclass CropMirrorNormalize(Module):\n    r\"\"\"\n    CropMirrorNormalize(color_space=\"BGR\", output_layout=\"NCHW\", crop_h=0, crop_w=0, crop_pos_y=0.5, crop_pos_x=0.5, mean= [0.0], std= [1.0], output_dtype=oneflow.float)\n\n    Performs fused cropping, normalization, format conversion\n    (NHWC to NCHW) if desired, and type casting.\n\n    Normalization takes the input images and produces the output by using the following formula:\n    \n    .. 
math::\n        output = (input - mean) / std\n\n    .. note::\n        If no cropping arguments are specified, only mirroring and normalization will occur.\n\n    This operator allows sequence inputs and supports volumetric data.\n\n    The documentation is referenced from:\n    https://docs.nvidia.com/deeplearning/dali/user-guide/docs/supported_ops_legacy.html#nvidia.dali.ops.CropMirrorNormalize.\n\n    Args:\n        color_space (str, optional): The color space of the input image. Default: \"BGR\"\n        output_layout (str, optional): Tensor data layout for the output. Default: \"NCHW\"\n        crop_h (int, optional): Cropping window height (in pixels). Default: 0\n        crop_w (int, optional): Cropping window width (in pixels). Default: 0\n        crop_pos_y (float, optional): Normalized (0.0 - 1.0) vertical position of the start of the cropping\n            window (typically, the upper left corner). Default: 0.5\n        crop_pos_x (float, optional): Normalized (0.0 - 1.0) horizontal position of the cropping window\n            (upper left corner). Default: 0.5\n        mean (float or list of float, optional): Mean pixel values for image normalization. Default: [0.0]\n        std (float or list of float, optional): Standard deviation values for image normalization.\n            Default: [1.0]\n        output_dtype (oneflow.dtype, optional): Output data type. Default: ``oneflow.float``\n\n    \"\"\"\n\n    def __init__(\n        self,\n        color_space: str = \"BGR\",\n        output_layout: str = \"NCHW\",\n        crop_h: int = 0,\n        crop_w: int = 0,\n        crop_pos_y: float = 0.5,\n        crop_pos_x: float = 0.5,\n        mean: Sequence[float] = [0.0],\n        std: Sequence[float] = [1.0],\n        output_dtype: flow.dtype = flow.float,\n    ):\n        super().__init__()\n        if output_layout != \"NCHW\":\n            print(\n                \"WARNING: output_layout has been deprecated.
Please use the environment variable ONEFLOW_ENABLE_NHWC and set it to 1.\"\n            )\n        if os.getenv(\"ONEFLOW_ENABLE_NHWC\") == \"1\":\n            output_layout = \"NHWC\"\n        else:\n            output_layout = \"NCHW\"\n\n        self.color_space = color_space\n        self.output_layout = output_layout\n        self.mean = mean\n        self.std = std\n        self.crop_h = crop_h\n        self.crop_w = crop_w\n        self.crop_pos_y = crop_pos_y\n        self.crop_pos_x = crop_pos_x\n        self.output_dtype = output_dtype\n\n        self._op_uint8_with_mirror = (\n            flow.stateful_op(\"crop_mirror_normalize_from_uint8\")\n            .Input(\"in\")\n            .Input(\"mirror\")\n            .Output(\"out\")\n            .Build()\n        )\n        self._op_uint8_no_mirror = (\n            flow.stateful_op(\"crop_mirror_normalize_from_uint8\")\n            .Input(\"in\")\n            .Output(\"out\")\n            .Build()\n        )\n        self._op_buffer_with_mirror = (\n            flow.stateful_op(\"crop_mirror_normalize_from_tensorbuffer\")\n            .Input(\"in\")\n            .Input(\"mirror\")\n            .Output(\"out\")\n            .Build()\n        )\n\n        self._op_buffer_no_mirror = (\n            flow.stateful_op(\"crop_mirror_normalize_from_tensorbuffer\")\n            .Input(\"in\")\n            .Output(\"out\")\n            .Build()\n        )\n\n    def forward(self, input, mirror=None):\n        if input.dtype is flow.uint8:\n            if mirror is not None:\n                res = _C.dispatch_crop_mirror_normalize_from_uint8(\n                    self._op_uint8_with_mirror,\n                    (input, mirror),\n                    color_space=self.color_space,\n                    output_layout=self.output_layout,\n                    mean=self.mean,\n                    std=self.std,\n                    crop_h=self.crop_h,\n                    crop_w=self.crop_w,\n                    crop_pos_x=self.crop_pos_x,\n                    crop_pos_y=self.crop_pos_y,\n                    output_dtype=self.output_dtype,\n                )\n            else:\n                res = _C.dispatch_crop_mirror_normalize_from_uint8(\n                    self._op_uint8_no_mirror,\n                    (input,),\n                    color_space=self.color_space,\n                    output_layout=self.output_layout,\n                    mean=self.mean,\n                    std=self.std,\n                    crop_h=self.crop_h,\n                    crop_w=self.crop_w,\n                    crop_pos_x=self.crop_pos_x,\n                    crop_pos_y=self.crop_pos_y,\n                    output_dtype=self.output_dtype,\n                )\n        elif input.dtype is flow.tensor_buffer:\n            if mirror is not None:\n                res = _C.dispatch_crop_mirror_normalize_from_tensorbuffer(\n                    self._op_buffer_with_mirror,\n                    (input, mirror),\n                    color_space=self.color_space,\n                    output_layout=self.output_layout,\n                    mean=self.mean,\n                    std=self.std,\n                    crop_h=self.crop_h,\n                    crop_w=self.crop_w,\n                    crop_pos_x=self.crop_pos_x,\n                    crop_pos_y=self.crop_pos_y,\n                    output_dtype=self.output_dtype,\n                )\n            else:\n                res = _C.dispatch_crop_mirror_normalize_from_tensorbuffer(\n
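                    # tensor_buffer input without a mirror tensor: no random flip,\n                    # just crop + normalize\n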
self._op_buffer_no_mirror,\n                    (input,),\n                    color_space=self.color_space,\n                    output_layout=self.output_layout,\n                    mean=self.mean,\n                    std=self.std,\n                    crop_h=self.crop_h,\n                    crop_w=self.crop_w,\n                    crop_pos_x=self.crop_pos_x,\n                    crop_pos_y=self.crop_pos_y,\n                    output_dtype=self.output_dtype,\n                )\n        else:\n            print(\n                \"ERROR! oneflow.nn.CropMirrorNormalize module does NOT support input dtype = \",\n                input.dtype,\n            )\n            raise NotImplementedError\n        return res\n\n\nclass OFRecordImageDecoderRandomCrop(Module):\n    def __init__(\n        self,\n        blob_name: str,\n        color_space: str = \"BGR\",\n        num_attempts: int = 10,\n        random_seed: Optional[int] = None,\n        random_area: Sequence[float] = [0.08, 1.0],\n        random_aspect_ratio: Sequence[float] = [0.75, 1.333333],\n    ):\n        super().__init__()\n        self.blob_name = blob_name\n        self.color_space = color_space\n        self.num_attempts = num_attempts\n        self.random_area = random_area\n        self.random_aspect_ratio = random_aspect_ratio\n        (self.seed, self.has_seed) = local_gen_random_seed(random_seed)\n        self._op = (\n            flow.stateful_op(\"ofrecord_image_decoder_random_crop\")\n            .Input(\"in\")\n            .Output(\"out\")\n            .Build()\n        )\n\n    def forward(self, input):\n        res = _C.dispatch_ofrecord_image_decoder_random_crop(\n            self._op,\n            input,\n            name=self.blob_name,\n            color_space=self.color_space,\n            num_attempts=self.num_attempts,\n            random_area=self.random_area,\n            random_aspect_ratio=self.random_aspect_ratio,\n            has_seed=self.has_seed,\n            seed=self.seed,\n        )\n        return res\n\n\nclass OFRecordImageDecoder(Module):\n    def __init__(self, blob_name: str, color_space: str = \"BGR\"):\n        super().__init__()\n        self._op = (\n            flow.stateful_op(\"ofrecord_image_decoder\").Input(\"in\").Output(\"out\").Build()\n        )\n        self.blob_name = blob_name\n        self.color_space = color_space\n\n    def forward(self, input):\n        res = _C.dispatch_ofrecord_image_decoder(\n            self._op, input, name=self.blob_name, color_space=self.color_space\n        )\n        return res\n\n\nclass OFRecordImageGpuDecoderRandomCropResize(Module):\n    def __init__(\n        self,\n        target_width: int,\n        target_height: int,\n        num_attempts: Optional[int] = 10,\n        seed: Optional[int] = 0,\n        random_area: Optional[Sequence[float]] = [0.08, 1.0],\n        random_aspect_ratio: Optional[Sequence[float]] = [0.75, 1.333333],\n        num_workers: Optional[int] = 3,\n        warmup_size: Optional[int] = 6400,\n        max_num_pixels: Optional[int] = 67108864,\n    ):\n        super().__init__()\n        self.target_width = target_width\n        self.target_height = target_height\n        self.num_attempts = num_attempts\n        self.seed = seed\n        assert len(random_area) == 2\n        self.random_area = random_area\n        assert len(random_aspect_ratio) == 2\n        self.random_aspect_ratio = random_aspect_ratio\n        self.num_workers = num_workers\n        self.warmup_size = warmup_size\n        self.max_num_pixels = 
max_num_pixels\n        gpu_decoder_conf = (\n            flow.core.operator.op_conf_pb2.ImageDecoderRandomCropResizeOpConf()\n        )\n        # `in` is a Python keyword, so `gpu_decoder_conf.in = \"error_input_need_to_be_replaced\"` fails to parse; use setattr instead\n        setattr(gpu_decoder_conf, \"in\", \"error_input_need_to_be_replaced\")\n        gpu_decoder_conf.out = \"out\"\n        gpu_decoder_conf.target_width = (\n            -1\n        )  # Set the default value, otherwise the parsing fails\n        gpu_decoder_conf.target_height = -1\n        gpu_decoder_conf_str = text_format.MessageToString(gpu_decoder_conf)\n        self._op = flow._oneflow_internal.one.ImageDecoderRandomCropResizeOpExpr(\n            id_util.UniqueStr(\"ImageGpuDecoder\"), gpu_decoder_conf_str, [\"in\"], [\"out\"]\n        )\n\n    def forward(self, input):\n        if not input.is_lazy:\n            print(\n                \"ERROR! oneflow.nn.OFRecordImageGpuDecoderRandomCropResize module \",\n                \"does NOT support eager execution, please use it in nn.Graph.\",\n            )\n            raise NotImplementedError\n        res = _C.dispatch_image_decoder_random_crop_resize(\n            self._op,\n            input,\n            target_width=self.target_width,\n            target_height=self.target_height,\n            num_attempts=self.num_attempts,\n            seed=self.seed,\n            random_area_min=self.random_area[0],\n            random_area_max=self.random_area[1],\n            random_aspect_ratio_min=self.random_aspect_ratio[0],\n            random_aspect_ratio_max=self.random_aspect_ratio[1],\n            num_workers=self.num_workers,\n            warmup_size=self.warmup_size,\n            max_num_pixels=self.max_num_pixels,\n        )\n        if not res.is_cuda:\n            print(\n                \"WARNING! 
oneflow.nn.OFRecordImageGpuDecoderRandomCropResize ONLY supports \",\n                \"CUDA runtime version >= 10.2, so it now falls back to the CPU decoding version.\",\n            )\n        return res\n\n\nclass TensorBufferToListOfTensors(Module):\n    def __init__(\n        self, out_shapes, out_dtypes, out_num: int = 1, dynamic_out: bool = False\n    ):\n        super().__init__()\n        self._op = (\n            flow.stateful_op(\"tensor_buffer_to_list_of_tensors_v2\")\n            .Input(\"in\")\n            .Output(\"out\", out_num)\n            .Build()\n        )\n        self.out_shapes = out_shapes\n        self.out_dtypes = out_dtypes\n        self.dynamic_out = dynamic_out\n\n    def forward(self, input):\n        return _C.dispatch_tensor_buffer_to_list_of_tensors_v2(\n            self._op,\n            input,\n            out_shapes=self.out_shapes,\n            out_dtypes=self.out_dtypes,\n            dynamic_out=self.dynamic_out,\n        )\n\n\ndef tensor_buffer_to_list_of_tensors(tensor, out_shapes, out_dtypes):\n    return TensorBufferToListOfTensors(\n        [list(out_shape) for out_shape in out_shapes], out_dtypes, len(out_shapes)\n    )(tensor)\n\n\nclass ImageResize(Module):\n    def __init__(\n        self,\n        target_size: Union[int, Sequence[int]] = None,\n        min_size: Optional[int] = None,\n        max_size: Optional[int] = None,\n        keep_aspect_ratio: bool = False,\n        resize_side: str = \"shorter\",\n        channels: int = 3,\n        dtype: Optional[flow.dtype] = None,\n        interpolation_type: str = \"auto\",\n        name: Optional[str] = None,\n        color_space: Optional[str] = None,\n        interp_type: Optional[str] = None,\n        resize_shorter: int = 0,\n        resize_x: int = 0,\n        resize_y: int = 0,\n    ):\n        super().__init__()\n        if name is not None:\n            print(\"WARNING: name has been deprecated and has NO effect.\\n\")\n        deprecated_param_used = False\n        if color_space is not None:\n            print(\n                \"WARNING: color_space has been deprecated. Please use channels instead.\"\n            )\n            print(traceback.format_stack()[-2])\n            deprecated_param_used = True\n            assert isinstance(color_space, str)\n            if color_space.upper() == \"RGB\" or color_space.upper() == \"BGR\":\n                channels = 3\n            elif color_space.upper() == \"GRAY\":\n                channels = 1\n            else:\n                raise ValueError(\"invalid color_space\")\n        self.channels = channels\n        if interp_type is not None:\n            print(\n                \"WARNING: interp_type has been deprecated. Please use interpolation_type instead.\"\n            )\n            print(traceback.format_stack()[-2])\n            deprecated_param_used = True\n            assert isinstance(interp_type, str)\n            if interp_type == \"Linear\":\n                interpolation_type = \"bilinear\"\n            elif interp_type == \"NN\":\n                interpolation_type = \"nearest_neighbor\"\n            elif interp_type == \"Cubic\":\n                interpolation_type = \"bicubic\"\n            else:\n                raise ValueError(\"invalid interp_type\")\n        self.interpolation_type = interpolation_type\n\n        if resize_x > 0 and resize_y > 0:\n            print(\n                \"WARNING: resize_x and resize_y have been deprecated. 
Please use target_size instead.\"\n            )\n            print(traceback.format_stack()[-2])\n            deprecated_param_used = True\n            target_size = (resize_x, resize_y)\n            keep_aspect_ratio = False\n        if resize_shorter > 0:\n            print(\n                \"WARNING: resize_shorter has been deprecated. Please use target_size instead.\"\n            )\n            print(traceback.format_stack()[-2])\n            deprecated_param_used = True\n            target_size = resize_shorter\n            keep_aspect_ratio = True\n            resize_side = \"shorter\"\n        self.keep_aspect_ratio = keep_aspect_ratio\n        if self.keep_aspect_ratio:\n            if not isinstance(target_size, int):\n                raise ValueError(\n                    \"target_size must be an int when keep_aspect_ratio is True\"\n                )\n            if min_size is None:\n                min_size = 0\n            if max_size is None:\n                max_size = 0\n            if resize_side == \"shorter\":\n                resize_longer = False\n            elif resize_side == \"longer\":\n                resize_longer = True\n            else:\n                raise ValueError('resize_side must be \"shorter\" or \"longer\"')\n            self.target_size = target_size\n            self.min_size = min_size\n            self.max_size = max_size\n            self.resize_longer = resize_longer\n            self._op = (\n                flow.stateful_op(\"image_resize_keep_aspect_ratio\")\n                .Input(\"in\")\n                .Output(\"out\")\n                .Output(\"size\")\n                .Output(\"scale\")\n                .Build()\n            )\n        else:\n            if (\n                not isinstance(target_size, (list, tuple))\n                or len(target_size) != 2\n                or (not all((isinstance(size, int) for size in target_size)))\n            ):\n                raise ValueError(\n                    \"target_size must be a form like (width, height) when keep_aspect_ratio is False\"\n                )\n            if dtype is None:\n                dtype = flow.uint8\n            self.dtype = dtype\n            (self.target_w, self.target_h) = target_size\n            self._op = (\n                flow.stateful_op(\"image_resize_to_fixed\")\n                .Input(\"in\")\n                .Output(\"out\")\n                .Output(\"scale\")\n                .Build()\n            )\n\n    def forward(self, input):\n        if self.keep_aspect_ratio:\n            res = _C.dispatch_image_resize_keep_aspect_ratio(\n                self._op,\n                input,\n                target_size=self.target_size,\n                min_size=self.min_size,\n                max_size=self.max_size,\n                resize_longer=self.resize_longer,\n                interpolation_type=self.interpolation_type,\n            )\n            new_size = flow.tensor_buffer_to_tensor(\n                res[1], dtype=flow.int32, instance_shape=(2,)\n            )\n            scale = flow.tensor_buffer_to_tensor(\n                res[2], dtype=flow.float32, instance_shape=(2,)\n            )\n        else:\n            res = _C.dispatch_image_resize_to_fixed(\n                self._op,\n                input,\n                target_width=self.target_w,\n                target_height=self.target_h,\n                channels=self.channels,\n                data_type=self.dtype,\n                interpolation_type=self.interpolation_type,\n         
   )\n            new_size = None\n            scale = res[1]\n        res_image = res[0]\n        return (res_image, scale, new_size)\n\n\ndef raw_decoder(\n    input_record,\n    blob_name: str,\n    shape: Sequence[int],\n    dtype: flow.dtype,\n    dim1_varying_length: bool = False,\n    truncate: bool = False,\n    auto_zero_padding: bool = False,\n    name: Optional[str] = None,\n):\n    if auto_zero_padding:\n        print(\n            \"WARNING: auto_zero_padding has been deprecated. Please use truncate instead.\\n\"\n        )\n    return OFRecordRawDecoder(\n        blob_name,\n        shape,\n        dtype,\n        dim1_varying_length,\n        truncate or auto_zero_padding,\n        name,\n    ).forward(input_record)\n\n\ndef get_ofrecord_handle(\n    ofrecord_dir: str,\n    batch_size: int = 1,\n    data_part_num: int = 1,\n    part_name_prefix: str = \"part-\",\n    part_name_suffix_length: int = -1,\n    random_shuffle: bool = False,\n    shuffle_buffer_size: int = 1024,\n    shuffle_after_epoch: bool = False,\n    name: Optional[str] = None,\n):\n    return OFRecordReader(\n        ofrecord_dir,\n        batch_size,\n        data_part_num,\n        part_name_prefix,\n        part_name_suffix_length,\n        random_shuffle,\n        shuffle_buffer_size,\n        shuffle_after_epoch,\n        name,\n    )()\n\n\nclass ImageFlip(Module):\n    \"\"\"This operator flips the images.\n\n    Each flip code corresponds to a different flip mode:\n\n    0 (0x00): Non Flip\n\n    1 (0x01): Horizontal Flip\n\n    2 (0x02): Vertical Flip\n\n    3 (0x03): Both Horizontal and Vertical Flip\n\n    Args:\n        images: The input images.\n        flip_code: The flip code.\n\n    Returns:\n        The result image.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n\n        >>> arr = np.array([\n        ...    [[[1, 2, 3], [3, 2, 1]],\n        ...     [[2, 3, 4], [4, 3, 2]]],\n        ...    [[[3, 4, 5], [5, 4, 3]],\n        ...     
[[4, 5, 6], [6, 5, 4]]]])\n        >>> image_tensors = flow.Tensor(arr, device=flow.device(\"cpu\"))\n        >>> image_tensor_buffer = flow.tensor_to_tensor_buffer(image_tensors, instance_dims=3)\n        >>> flip_code = flow.ones(arr.shape[0], dtype=flow.int8)\n        >>> output = nn.image.flip()(image_tensor_buffer, flip_code).numpy()\n        >>> output[0]\n        array([[[3., 2., 1.],\n                [1., 2., 3.]],\n        <BLANKLINE>\n               [[4., 3., 2.],\n                [2., 3., 4.]]], dtype=float32)\n        >>> output[1]\n        array([[[5., 4., 3.],\n                [3., 4., 5.]],\n        <BLANKLINE>\n               [[6., 5., 4.],\n                [4., 5., 6.]]], dtype=float32)\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, images, flip_code):\n        return flow._C.image_flip(images, flip_code=flip_code)\n\n\nclass ImageDecode(Module):\n    def __init__(self, dtype: flow.dtype = flow.uint8, color_space: str = \"BGR\"):\n        super().__init__()\n        self.color_space = color_space\n        self.dtype = dtype\n        self._op = flow.stateful_op(\"image_decode\").Input(\"in\").Output(\"out\").Build()\n\n    def forward(self, input):\n        return _C.dispatch_image_decode(\n            self._op, input, color_space=self.color_space, data_type=self.dtype\n        )\n\n\nclass ImageNormalize(Module):\n    def __init__(self, std: Sequence[float], mean: Sequence[float]):\n        super().__init__()\n        self.std = std\n        self.mean = mean\n        self._op = flow.stateful_op(\"image_normalize\").Input(\"in\").Output(\"out\").Build()\n\n    def forward(self, input):\n        return _C.dispatch_image_normalize(\n            self._op, input, mean=self.mean, std=self.std\n        )\n\n\nclass COCOReader(Module):\n    def __init__(\n        self,\n        annotation_file: str,\n        image_dir: str,\n        batch_size: int,\n        shuffle: bool = True,\n        random_seed: Optional[int] = None,\n        group_by_aspect_ratio: bool = True,\n        remove_images_without_annotations: bool = True,\n        stride_partition: bool = True,\n        device: Union[flow.device, str] = None,\n        placement: flow.placement = None,\n        sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None,\n    ):\n        super().__init__()\n\n        _handle_shuffle_args(self, shuffle, random_seed)\n        _handle_distributed_args(self, device, placement, sbp)\n\n        self.annotation_file = annotation_file\n        self.image_dir = image_dir\n        self.batch_size = batch_size\n        self.group_by_aspect_ratio = group_by_aspect_ratio\n        self.remove_images_without_annotations = remove_images_without_annotations\n        self.stride_partition = stride_partition\n\n        self._op = (\n            flow.stateful_op(\"COCOReader\")\n            .Output(\"image\")\n            .Output(\"image_id\")\n            .Output(\"image_size\")\n            .Output(\"gt_bbox\")\n            .Output(\"gt_label\")\n            .Output(\"gt_segm\")\n            .Output(\"gt_segm_index\")\n            .Build()\n        )\n\n    def forward(self):\n        if self.placement is None:\n            # local apply\n            outputs = _C.dispatch_coco_reader(\n                self._op,\n                session_id=current_scope().session_id,\n                annotation_file=self.annotation_file,\n                image_dir=self.image_dir,\n                batch_size=self.batch_size,\n                shuffle_after_epoch=self.shuffle,\n     
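# random_seed is -1 when shuffle is disabled (see _handle_shuffle_args)\n     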
           random_seed=self.random_seed,\n                group_by_ratio=self.group_by_aspect_ratio,\n                remove_images_without_annotations=self.remove_images_without_annotations,\n                stride_partition=self.stride_partition,\n                device=self.device,\n            )\n        else:\n            # consistent apply\n            outputs = _C.dispatch_coco_reader(\n                self._op,\n                session_id=current_scope().session_id,\n                annotation_file=self.annotation_file,\n                image_dir=self.image_dir,\n                batch_size=self.batch_size,\n                shuffle_after_epoch=self.shuffle,\n                random_seed=self.random_seed,\n                group_by_ratio=self.group_by_aspect_ratio,\n                remove_images_without_annotations=self.remove_images_without_annotations,\n                stride_partition=self.stride_partition,\n                placement=self.placement,\n                sbp=self.sbp,\n            )\n        return outputs\n\n\nclass ImageBatchAlign(Module):\n    def __init__(self, shape: Sequence[int], dtype: flow.dtype, alignment: int):\n        super().__init__()\n        self._op = (\n            flow.stateful_op(\"image_batch_align\").Input(\"in\").Output(\"out\").Build()\n        )\n        self.shape = shape\n        self.dtype = dtype\n        self.alignment = alignment\n\n    def forward(self, input):\n        return _C.dispatch_image_batch_align(\n            self._op,\n            input,\n            shape=self.shape,\n            data_type=self.dtype,\n            alignment=self.alignment,\n            dynamic_out=False,\n        )\n\n\nclass OFRecordBytesDecoder(Module):\n    r\"\"\"This operator reads a tensor as bytes. The output might need\n\n    further decoding, such as cv2.imdecode() for images and decode(\"utf-8\")\n\n    for characters, depending on the downstream task.\n\n    Args:\n        blob_name: The name of the target feature in OFRecord.\n\n        name: The name for this component in the graph.\n\n        input: the Tensor which might be provided by an OFRecordReader.\n\n    Returns:\n\n        The result Tensor encoded with bytes.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> def example():\n        ...      batch_size = 16\n        ...      record_reader = flow.nn.OFRecordReader(\n        ...         \"dataset/\",\n        ...         batch_size=batch_size,\n        ...         part_name_suffix_length=5,\n        ...      )\n        ...      val_record = record_reader()\n        ...      bytesdecoder_img = flow.nn.OFRecordBytesDecoder(\"encoded\")\n        ...      image_bytes_batch = bytesdecoder_img(val_record)\n        ...      image_bytes = image_bytes_batch.numpy()[0]\n        ...      return image_bytes\n        >>> example()  # doctest: +SKIP\n        array([255 216 255 ...  
79 255 217], dtype=uint8)\n\n    \"\"\"\n\n    def __init__(self, blob_name: str, name: Optional[str] = None):\n        super().__init__()\n        if name is not None:\n            print(\"WARNING: name has been deprecated and has NO effect.\\n\")\n        self._op = (\n            flow.stateful_op(\"ofrecord_bytes_decoder\").Input(\"in\").Output(\"out\").Build()\n        )\n        self.blob_name = blob_name\n\n    def forward(self, input):\n        return _C.dispatch_ofrecord_bytes_decoder(self._op, input, name=self.blob_name)\n\n\nclass GPTIndexedBinDataReader(Module):\n    def __init__(\n        self,\n        data_file_prefix: str,\n        seq_length: int,\n        num_samples: int,\n        batch_size: int,\n        dtype: flow.dtype = flow.int64,\n        shuffle: bool = True,\n        random_seed: Optional[int] = None,\n        split_sizes: Optional[Sequence[int]] = None,\n        split_index: Optional[int] = None,\n        device: Union[flow.device, str] = None,\n        placement: flow.placement = None,\n        sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None,\n    ):\n        super().__init__()\n\n        _handle_shuffle_args(self, shuffle, random_seed)\n        _handle_distributed_args(self, device, placement, sbp)\n\n        self.data_file_prefix = data_file_prefix\n        self.batch_size = batch_size\n        self.num_samples = num_samples\n        self.seq_length = seq_length\n        self.dtype = dtype\n\n        if split_index is None:\n            split_index = 0\n        self.split_index = split_index\n\n        if split_sizes is None:\n            split_sizes = (1,)\n        self.split_sizes = split_sizes\n\n        if split_index >= len(split_sizes):\n            raise ValueError(\n                \"split index {} is out of range, split_sizes {}\".format(\n                    split_index, split_sizes\n                )\n            )\n\n        self.op_ = (\n            flow.stateful_op(\"megatron_gpt_mmap_data_loader\").Output(\"out\").Build()\n        )\n\n    def forward(self):\n        if self.placement is None:\n            output = _C.dispatch_megatron_gpt_mmap_data_loader(\n                self.op_,\n                data_file_prefix=self.data_file_prefix,\n                seq_length=self.seq_length,\n                label_length=1,\n                num_samples=self.num_samples,\n                batch_size=self.batch_size,\n                dtype=self.dtype,\n                shuffle=self.shuffle,\n                random_seed=self.random_seed,\n                split_sizes=self.split_sizes,\n                split_index=self.split_index,\n                device=self.device,\n            )\n        else:\n            output = _C.dispatch_megatron_gpt_mmap_data_loader(\n                self.op_,\n                data_file_prefix=self.data_file_prefix,\n                seq_length=self.seq_length,\n                label_length=1,\n                num_samples=self.num_samples,\n                batch_size=self.batch_size,\n                dtype=self.dtype,\n                shuffle=self.shuffle,\n                random_seed=self.random_seed,\n                split_sizes=self.split_sizes,\n                split_index=self.split_index,\n                placement=self.placement,\n                sbp=self.sbp,\n            )\n        return output\n\n\nclass RawReader(Module):\n    def __init__(\n        self,\n        files: List[str],\n        shape: Sequence[int],\n        dtype: flow.dtype,\n        batch_size: int,\n        random_shuffle: bool = True,\n        
shuffle_block_size: int = 0,\n        random_seed: Optional[int] = None,\n        placement: flow.placement = None,\n        sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None,\n    ):\n        super().__init__()\n\n        _handle_shuffle_args(self, random_shuffle, random_seed)\n        _handle_distributed_args(self, None, placement, sbp)\n\n        self.files = files\n        self.shape = shape\n        self.dtype = dtype\n        self.batch_size = batch_size\n        self.shuffle_block_size = shuffle_block_size\n\n        self.op = flow.stateful_op(\"raw_reader\").Output(\"out\").Build()\n\n    def forward(self):\n        if self.placement is None:\n            output = _C.dispatch_raw_reader(\n                self.op,\n                files=self.files,\n                shape=self.shape,\n                data_type=self.dtype,\n                batch_size=self.batch_size,\n                random_shuffle=self.shuffle,\n                shuffle_block_size=self.shuffle_block_size,\n                random_seed=self.random_seed,\n                device=self.device,\n            )\n        else:\n            output = _C.dispatch_raw_reader(\n                self.op,\n                files=self.files,\n                shape=self.shape,\n                data_type=self.dtype,\n                batch_size=self.batch_size,\n                random_shuffle=self.shuffle,\n                shuffle_block_size=self.shuffle_block_size,\n                random_seed=self.random_seed,\n                placement=self.placement,\n                sbp=self.sbp,\n            )\n        return output\n\n\ndef _handle_distributed_args(module, device, placement, sbp):\n    module.placement = placement\n    if placement is None:\n        module.device = device or flow.device(\"cpu\")\n    else:\n        if device is not None:\n            raise ValueError(\n                \"The 'device' and 'placement' arguments can't be specified at the same time.\"\n            )\n\n        module.device = None\n\n        if isinstance(sbp, (tuple, list)):\n            for sbp_item in sbp:\n                if not isinstance(sbp_item, flow.sbp.sbp):\n                    raise ValueError(f\"invalid sbp item: {sbp_item}\")\n        elif isinstance(sbp, flow.sbp.sbp):\n            sbp = (sbp,)\n        else:\n            raise ValueError(f\"invalid 'sbp' argument: {sbp}\")\n\n        if len(sbp) != len(placement.ranks.shape):\n            raise ValueError(\n                \"The number of dimensions of sbp must equal the number of dimensions of placement.ranks:\"\n                f\" {len(sbp)} vs. {len(placement.ranks.shape)}\"\n            )\n\n    module.sbp = sbp\n\n\ndef _handle_shuffle_args(module, shuffle, random_seed):\n    module.shuffle = shuffle\n    if random_seed is None:\n        if shuffle:\n            module.random_seed = random.randrange(sys.maxsize)\n        else:\n            module.random_seed = -1\n    else:\n        assert isinstance(random_seed, int)\n        module.random_seed = random_seed\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/distance.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.nn.modules.module import Module\n\nfrom typing import Optional\n\n\nclass CosineSimilarity(Module):\n    r\"\"\"    \n    Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along `dim`.\n\n    .. math ::\n        \\text{similarity} = \\dfrac{x_1 \\cdot x_2}{\\max(\\Vert x_1 \\Vert _2 \\cdot \\Vert x_2 \\Vert _2, \\epsilon)}.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.CosineSimilarity.html#torch.nn.CosineSimilarity\n\n    Args:\n        dim (int, optional): Dimension where cosine similarity is computed. Default: 1\n        eps (float, optional): Small value to avoid division by zero.\n            Default: 1e-8\n    Shape:\n        - Input1: :math:`(\\ast_1, D, \\ast_2)` where D is at position `dim`.\n        - Input2: :math:`(\\ast_1, D, \\ast_2)`, same number of dimensions as x1, matching x1 size at dimension `dim`,\n              and broadcastable with x1 at other dimensions.\n        - Output: :math:`(\\ast_1, \\ast_2)`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> from oneflow import nn\n        >>> input1 = flow.randn(100, 128)\n        >>> input2 = flow.randn(100, 128)\n        >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6)\n        >>> output = cos(input1, input2)\n    \"\"\"\n\n    def __init__(self, dim: Optional[int] = 1, eps: Optional[float] = 1e-08,) -> None:\n        super().__init__()\n        self.dim = dim\n        self.eps = eps\n\n    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:\n        return flow._C.cosine_similarity(x1, x2, self.dim, self.eps)\n\n\nclass PairwiseDistance(Module):\n    r\"\"\"Computes the pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm:\n\n    .. math ::\n        \\left \\| x \\right \\| _p = (\\sum_{i=1}^n \\left | x_i \\right |^p )^{\\frac{1}{p}}\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.PairwiseDistance.html.\n\n    Args:\n        p (real): the norm degree. Default: 2\n        eps (float, optional): Small value to avoid division by zero. Default: 1e-6\n        keepdim (bool, optional): Determines whether or not to keep the vector dimension. Default: False\n\n    Shape:\n        - Input1: :math:`(N, D)` or :math:`(D)`, where N = batch dimension and D = vector dimension\n        - Input2: :math:`(N, D)` or :math:`(D)`, same shape as the input1\n        - Output: :math:`(N)` or :math:`()` based on input dimension. If keepdim is True, then :math:`(N, 1)` or :math:`(1)` based on input dimension.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> pdist = flow.nn.PairwiseDistance(p=2)\n        >>> x1 = flow.arange(12).reshape(3, 4)\n        >>> x2 = flow.arange(12).reshape(3, 4)\n        >>> pdist(x1, x2)\n        tensor([2.0000e-06, 2.0000e-06, 2.0000e-06], dtype=oneflow.float32)\n        >>> pdist(x1, x2).shape\n        oneflow.Size([3])\n\n    \"\"\"\n\n    def __init__(\n        self,\n        p: Optional[float] = 2.0,\n        eps: Optional[float] = 1e-06,\n        keepdim: Optional[bool] = False,\n    ) -> None:\n        super().__init__()\n        self.p = p\n        self.eps = eps\n        self.keepdim = keepdim\n\n    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:\n        return flow._C.pairwise_distance(\n            x1, x2, p=self.p, eps=self.eps, keepdim=self.keepdim\n        )\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/distributed_partial_fc_sample.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport warnings\nimport oneflow as flow\nimport oneflow._oneflow_internal._C as _C\nfrom oneflow.nn.modules.module import Module\n\n\nclass DistributedPariticalFCSample(Module):\n    def __init__(self, num_sample):\n        super().__init__()\n        self.num_sample = num_sample\n        self._op = (\n            flow.stateful_op(\"distributed_partial_fc_sample\")\n            .Input(\"weight\")\n            .Input(\"label\")\n            .Output(\"mapped_label\")\n            .Output(\"sampled_label\")\n            .Output(\"sampled_weight\")\n            .Build()\n        )\n\n    def forward(self, weight, label):\n        res = _C.dispatch_distributed_partial_fc_sample(\n            self._op, weight=weight, label=label, num_sample=self.num_sample\n        )\n        return res\n\n\ndef distributed_partial_fc_sample_op(weight, label, num_sample):\n    warnings.warn(\n        \"oneflow.distributed_partial_fc_sample is deprecated. Please use nn.DistributedPariticalFCSample module instead.\",\n        DeprecationWarning,\n    )\n    return DistributedPariticalFCSample(num_sample)(weight, label)\n"
  },
  {
    "path": "python/oneflow/nn/modules/dropout.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport random\nimport sys\n\nimport oneflow as flow\nimport oneflow.framework.id_util as id_util\nfrom oneflow.nn.modules.module import Module\n\n\nclass _DropoutNd(Module):\n    __constants__ = [\"p\", \"inplace\"]\n    p: float\n    inplace: bool\n\n    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:\n        super(_DropoutNd, self).__init__()\n        if p < 0 or p > 1:\n            raise ValueError(\n                \"dropout probability has to be between 0 and 1, but got {}\".format(p)\n            )\n        self.p = p\n        self.inplace = inplace\n\n    def extra_repr(self) -> str:\n        return \"p={}, inplace={}\".format(self.p, self.inplace)\n\n\nclass Dropout(_DropoutNd):\n    def __init__(self, p: float = 0.5, inplace: bool = False, generator=None):\n        _DropoutNd.__init__(self, p, inplace)\n        self.p = p\n        self.generator = generator\n\n    def forward(self, x, addend=None):\n        return flow._C.dropout(\n            x,\n            self.p,\n            self.training,\n            self.inplace,\n            self.generator,\n            addend=addend if addend is not None else None,\n        )\n\n\nclass Dropout1d(Dropout):\n    def forward(self, x, addend=None):\n        return flow._C.dropout1d(x, self.p, self.training)\n\n\nclass Dropout2d(Dropout):\n    def forward(self, x, addend=None):\n        return flow._C.dropout2d(x, self.p, self.training)\n\n\nclass Dropout3d(Dropout):\n    def forward(self, x, addend=None):\n        return flow._C.dropout3d(x, self.p, self.training)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/einsum.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\n\n\ndef einsum_op(equation, *operands):\n    return flow._C.einsum(equation, operands)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/empty.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nfrom typing import List, Optional, Union\n\nimport oneflow as flow\nfrom oneflow.nn.common_types import _size_any_t\nfrom oneflow.nn.modules.utils import _handle_size_arg, _single\n\n\ndef empty_op(\n    *size,\n    dtype: Optional[flow.dtype] = None,\n    device: Union[flow.device, str] = None,\n    placement: flow.placement = None,\n    sbp: Union[\n        flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp]\n    ] = None,\n    requires_grad: bool = False,\n    pin_memory: bool = False,\n):\n    assert size is not None, \"shape must not be None\"\n\n    shape = _single(_handle_size_arg(size))\n\n    if dtype is None:\n        dtype = flow.get_default_dtype()\n    if placement is not None:\n        assert (\n            device is None\n        ), \"argument 'device' must be None when argument 'placement' exist\"\n\n    if placement is not None:\n        assert (\n            sbp is not None\n        ), \"argument 'sbp' must not be None when argument 'placement' exist\"\n        assert isinstance(\n            sbp, (flow.sbp.sbp, tuple, list)\n        ), f\"argument 'sbp' must be flow.sbp.sbp, not %s\" % (type(sbp))\n        if isinstance(sbp, flow.sbp.sbp):\n            sbp = (sbp,)\n        else:\n            for elem in sbp:\n                assert isinstance(elem, flow.sbp.sbp), (\n                    \"Element in argument 'sbp' must be flow.sbp.sbp, not %s\"\n                    % (type(elem))\n                )\n        assert len(sbp) == len(placement.ranks.shape)\n    else:\n        assert sbp is None, \"argument 'sbp' must be None\"\n\n    if placement is not None:\n        tensor = flow._C.global_empty(shape, dtype=dtype, placement=placement, sbp=sbp)\n        tensor.requires_grad_(requires_grad)\n    else:\n        tensor = flow._C.empty(\n            shape,\n            dtype=dtype,\n            device=device,\n            requires_grad=requires_grad,\n            pin_memory=pin_memory,\n        )\n    return tensor\n\n\ndef empty_like_op(\n    input,\n    dtype: Optional[flow.dtype] = None,\n    device: Union[flow.device, str, None] = None,\n    placement: flow.placement = None,\n    sbp: flow._oneflow_internal.sbp.sbp = None,\n    requires_grad: bool = False,\n):\n    new_size = _single(_handle_size_arg(input.size()))\n    if placement is None and input.is_global and input.placement is not None:\n        placement = input.placement\n    if sbp is None and input.is_global and input.sbp is not None:\n        sbp = input.sbp\n    if dtype is None:\n        dtype = input.dtype\n    if placement is None and device is None:\n        device = input.device\n    return empty_op(\n        new_size,\n        dtype=dtype,\n        device=device,\n        placement=placement,\n        sbp=sbp,\n        requires_grad=requires_grad,\n    )\n\n\ndef new_empty_op(\n    x, size, dtype=None, device=None, placement=None, sbp=None, requires_grad=False\n):\n   
 new_size = _single(_handle_size_arg(size))\n    new_dtype = dtype\n    new_device = device\n    new_placement = placement\n    new_sbp = sbp\n\n    if dtype is None:\n        new_dtype = x.dtype\n    if device is None:\n        new_device = x.device if x.is_local else None\n    if placement is None:\n        new_placement = x.placement if x.is_global else None\n    if sbp is None:\n        new_sbp = x.sbp if x.is_global else None\n\n    return empty_op(\n        new_size,\n        dtype=new_dtype,\n        device=new_device,\n        placement=new_placement,\n        sbp=new_sbp,\n        requires_grad=requires_grad,\n    )\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/expand.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.nn.modules.utils import _single, _handle_size_arg\n\n\ndef expand_op(input, *sizes):\n    sizes = _handle_size_arg(sizes)\n    sizes = _single(sizes)\n    return flow._C.expand(input, sizes)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/fake_quantization.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.nn.modules.module import Module\n\n\nclass FakeQuantization(Module):\n    \"\"\"\n    \n    Simulate the quantize and dequantize operations in training time.\n\n    The output will be computed as:\n\n        if quantization_scheme == \"symmetric\":\n\n        .. math::\n\n            & quant\\\\_max = 2^{quantization\\\\_to\\\\_bit - 1} - 1\n\n            & quant\\\\_min = -quant\\\\_max\n\n            & clamp(round(x / scale), quant\\\\_min, quant\\\\_max) * scale\n\n        elif quantization_scheme == \"affine\":\n\n        .. math::\n\n            & quant\\\\_max = 2^{quantization\\\\_to\\\\_bit} - 1\n\n            & quant\\\\_min = 0\n\n            & (clamp(round(x / scale + zero\\\\_point), quant\\\\_min, quant\\\\_max) - zero\\\\_point) * scale\n\n    Args:\n        input(oneflow.Tensor):  the input value(s), in ``oneflow.float32``.\n        scale(oneflow.Tensor): quantization scale.\n        zero_point(oneflow.Tensor): quantization zero_point.\n        quantization_bit (int): Quantize input to uintX / intX, X can be in range [2, 8]. Defaults to 8.\n        quantization_scheme (str): \"symmetric\" or \"affine\", quantize to signed / unsigned integer. Defaults to \"symmetric\".\n        quantization_formula (str): Support \"google\" or \"cambricon\".\n\n    Returns:\n        oneflow.Tensor: Input tensor after quantize and dequantize operations.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> weight = (np.random.random((2, 3, 4, 5)) - 0.5).astype(np.float32)\n        \n        >>> input_tensor = flow.tensor(\n        ...    weight, dtype=flow.float32\n        ... )\n        \n        >>> quantization_bit = 8\n        >>> quantization_scheme = \"symmetric\"\n        >>> quantization_formula = \"google\"\n        >>> per_layer_quantization = True\n\n        >>> min_max_observer = flow.nn.MinMaxObserver(quantization_formula=quantization_formula, quantization_bit=quantization_bit,\n        ... quantization_scheme=quantization_scheme, per_layer_quantization=per_layer_quantization)\n        >>> fake_quantization = flow.nn.FakeQuantization(quantization_formula=quantization_formula, quantization_bit=quantization_bit, \n        ... quantization_scheme=quantization_scheme)\n\n        >>> scale, zero_point = min_max_observer(\n        ...    input_tensor,\n        ... )\n\n        >>> output_tensor = fake_quantization(\n        ...    input_tensor,\n        ...    scale,\n        ...    zero_point,\n        ... 
)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        quantization_formula: str = \"google\",\n        quantization_bit: int = 8,\n        quantization_scheme: str = \"symmetric\",\n    ) -> None:\n        super().__init__()\n        self.quantization_formula = quantization_formula\n        self.quantization_bit = quantization_bit\n        self.quantization_scheme = quantization_scheme\n\n    def forward(self, input, scale, zero_point):\n        return flow._C.fake_quantization(\n            input,\n            scale,\n            zero_point,\n            self.quantization_formula,\n            self.quantization_bit,\n            self.quantization_scheme,\n        )\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/flatten.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.framework.tensor import register_tensor_op\nfrom oneflow.nn.modules.module import Module\n\n\nclass Flatten(Module):\n    \"\"\"Flattens a contiguous range of dims into a tensor. For use with: nn.Sequential.\n\n    Args:\n        start_dim: first dim to flatten (default = 1).\n        end_dim: last dim to flatten (default = -1).\n    \n\n    For example: \n\n    .. code-block:: python \n\n        >>> import oneflow as flow\n        >>> input = flow.Tensor(32, 1, 5, 5)\n        >>> m = flow.nn.Flatten()\n        >>> output = m(input)\n        >>> output.shape\n        oneflow.Size([32, 25])\n\n    \"\"\"\n\n    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:\n        super().__init__()\n        self.start_dim = start_dim\n        self.end_dim = end_dim\n\n    def forward(self, input):\n        return flow._C.flatten(input, start_dim=self.start_dim, end_dim=self.end_dim)\n\n    def extra_repr(self) -> str:\n        return \"start_dim={}, end_dim={}\".format(self.start_dim, self.end_dim)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/fold.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.nn.common_types import _size_2_t\nfrom oneflow.nn.modules.module import Module\n\n\nclass Fold(Module):\n    r\"\"\"\n    Fold(output_size, kernel_size, dilation=1, padding=0, stride=1)\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Fold.html.\n\n    Combines an array of sliding local blocks into a large containing\n    tensor, it also called `col2img`\n\n    Consider a batched :attr:`input` tensor containing sliding local blocks,\n    e.g., patches of images, of shape :math:`(N, C \\times  \\prod(\\text{kernel_size}), L)`,\n    where :math:`N` is batch dimension, :math:`C \\times \\prod(\\text{kernel_size})`\n    is the number of values within a block (a block has :math:`\\prod(\\text{kernel_size})`\n    spatial locations each containing a :math:`C`-channeled vector), and\n    :math:`L` is the total number of blocks. (This is exactly the\n    same specification as the output shape of :class:`~oneflow.nn.Unfold`.) This\n    operation combines these local blocks into the large :attr:`output` tensor\n    of shape :math:`(N, C, \\text{output_size}[0], \\text{output_size}[1], \\dots)`\n    by summing the overlapping values. Similar to :class:`~oneflow.nn.Unfold`, the\n    arguments must satisfy\n\n    .. math::\n        L = \\prod_d \\left\\lfloor\\frac{\\text{output_size}[d] + 2 \\times \\text{padding}[d] %\n            - \\text{dilation}[d] \\times (\\text{kernel_size}[d] - 1) - 1}{\\text{stride}[d]} + 1\\right\\rfloor,\n\n    where :math:`d` is over all spatial dimensions.\n\n    * :attr:`output_size` describes the spatial shape of the large containing\n      tensor of the sliding local blocks. It is useful to resolve the ambiguity\n      when multiple input shapes map to same number of sliding blocks, e.g.,\n      with ``stride > 0``.\n\n    The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify\n    how the sliding blocks are retrieved.\n\n    * :attr:`stride` controls the stride for the sliding blocks.\n\n    * :attr:`padding` controls the amount of implicit zero-paddings on both\n      sides for :attr:`padding` number of points for each dimension before\n      reshaping.\n\n    * :attr:`dilation` controls the spacing between the kernel points; also known as \n      the à trous algorithm.\n\n    Args:\n        output_size (int or tuple): the shape of the spatial dimensions of the\n                                    output (i.e., ``output.sizes()[2:]``)\n        kernel_size (int or tuple): the size of the sliding blocks\n        stride (int or tuple): the stride of the sliding blocks in the input\n                               spatial dimensions. Default: 1\n        padding (int or tuple, optional): implicit zero padding to be added on\n                                          both sides of input. 
Default: 0\n        dilation (int or tuple, optional): a parameter that controls the\n                                           stride of elements within the\n                                           neighborhood. Default: 1\n\n    * If :attr:`output_size`, :attr:`kernel_size`, :attr:`dilation`,\n      :attr:`padding` or :attr:`stride` is an int or a tuple of length 1 then\n      their values will be replicated across all spatial dimensions.\n\n    * For the case of two output spatial dimensions this operation is sometimes\n      called ``col2im``.\n\n    .. note::\n        :class:`~oneflow.nn.Fold` calculates each combined value in the resulting\n        large tensor by summing all values from all containing blocks.\n        :class:`~oneflow.nn.Unfold` extracts the values in the local blocks by\n        copying from the large tensor. So, if the blocks overlap, they are not\n        inverses of each other.\n\n        In general, folding and unfolding operations are related as\n        follows. Consider :class:`~oneflow.nn.Fold` and\n        :class:`~oneflow.nn.Unfold` instances created with the same\n        parameters:\n\n        >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...)\n        >>> fold = nn.Fold(output_size=..., **fold_params)\n        >>> unfold = nn.Unfold(**fold_params)\n\n        Then for any (supported) ``input`` tensor the following\n        equality holds:\n\n        ::\n\n            fold(unfold(input)) == divisor * input\n\n        where ``divisor`` is a tensor that depends only on the shape\n        and dtype of the ``input``:\n\n        >>> input_ones = oneflow.ones(input.shape, dtype=input.dtype)\n        >>> divisor = fold(unfold(input_ones))\n\n        When the ``divisor`` tensor contains no zero elements, then\n        ``fold`` and ``unfold`` operations are inverses of each\n        other (up to constant divisor).\n\n    .. warning::\n        Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported.\n\n    Shape:\n        - Input: :math:`(N, C \\times \\prod(\\text{kernel_size}), L)` or :math:`(C \\times \\prod(\\text{kernel_size}), L)`\n        - Output: :math:`(N, C, \\text{output_size}[0], \\text{output_size}[1], \\dots)`\n          or :math:`(C, \\text{output_size}[0], \\text{output_size}[1], \\dots)` as described above\n\n    For example: \n\n    .. 
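code-block:: python\n\n        >>> # A hedged sketch of the shape bookkeeping above: a 4x4 image with a\n        >>> # 3x3 kernel and padding 1 yields L = 16 blocks of 9 values each.\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> unfold = flow.nn.Unfold(kernel_size=3, padding=1)\n        >>> fold = flow.nn.Fold(output_size=(4, 4), kernel_size=3, padding=1)\n        >>> x = flow.Tensor(np.random.randn(1, 1, 4, 4))\n        >>> unfold(x).shape\n        oneflow.Size([1, 9, 16])\n        >>> fold(unfold(x)).shape\n        oneflow.Size([1, 1, 4, 4])\n\n    .. 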
code-block:: python \n\n        >>> import oneflow as flow \n        >>> import numpy as np\n\n        >>> x_tensor = flow.Tensor(np.random.randn(1, 9, 16))\n        >>> fold = flow.nn.Fold(output_size=(4, 4), kernel_size=3, padding=1)\n        >>> out = fold(x_tensor)\n        >>> out.shape\n        oneflow.Size([1, 1, 4, 4])\n\n    \"\"\"\n\n    def __init__(\n        self,\n        output_size: _size_2_t,\n        kernel_size: _size_2_t,\n        dilation: _size_2_t = 1,\n        padding: _size_2_t = 0,\n        stride: _size_2_t = 1,\n    ) -> None:\n        super(Fold, self).__init__()\n        self.output_size = output_size\n        self.kernel_size = kernel_size\n        self.dilation = dilation\n        self.padding = padding\n        self.stride = stride\n\n    def forward(self, input):\n        return flow._C.fold(\n            input,\n            self.output_size,\n            self.kernel_size,\n            self.dilation,\n            self.padding,\n            self.stride,\n            \"channels_first\",\n        )\n\n    def extra_repr(self) -> str:\n        return (\n            \"output_size={output_size}, kernel_size={kernel_size}, \"\n            \"dilation={dilation}, padding={padding}, stride={stride}\".format(\n                **self.__dict__\n            )\n        )\n\n\nclass Unfold(Module):\n    r\"\"\"\n    Unfold(kernel_size, dilation=1, padding=0, stride=1)\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Unfold.html.\n\n    This op extracts elements in local sliding windows from the input tensor; this is also called ``im2col``. \n\n    Consider a batched :attr:`input` tensor of shape :math:`(N, C, *)`,\n    where :math:`N` is the batch dimension, :math:`C` is the channel dimension,\n    and :math:`*` represent arbitrary spatial dimensions. This operation flattens\n    each sliding :attr:`kernel_size`-sized block within the spatial dimensions\n    of :attr:`input` into a column (i.e., last dimension) of a 3-D :attr:`output`\n    tensor of shape :math:`(N, C \\times \\prod(\\text{kernel_size}), L)`, where\n    :math:`C \\times \\prod(\\text{kernel_size})` is the total number of values\n    within each block (a block has :math:`\\prod(\\text{kernel_size})` spatial\n    locations each containing a :math:`C`-channeled vector), and :math:`L` is\n    the total number of such blocks:\n\n    .. 
math::\n        L = \\prod_d \\left\\lfloor\\frac{\\text{spatial_size}[d] + 2 \\times \\text{padding}[d] %\n            - \\text{dilation}[d] \\times (\\text{kernel_size}[d] - 1) - 1}{\\text{stride}[d]} + 1\\right\\rfloor,\n\n    where :math:`\\text{spatial_size}` is formed by the spatial dimensions\n    of :attr:`input` (:math:`*` above), and :math:`d` is over all spatial\n    dimensions.\n\n    Therefore, indexing :attr:`output` at the last dimension (column dimension)\n    gives all values within a certain block.\n\n    The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify\n    how the sliding blocks are retrieved.\n\n    * :attr:`stride` controls the stride for the sliding blocks.\n\n    * :attr:`padding` controls the amount of implicit zero-paddings on both\n      sides for :attr:`padding` number of points for each dimension before\n      reshaping.\n\n    * :attr:`dilation` controls the spacing between the kernel points; also known as\n      the à trous algorithm.\n\n    Args:\n        kernel_size (int or tuple): the size of the sliding blocks\n        stride (int or tuple, optional): the stride of the sliding blocks in the input\n                                         spatial dimensions. Default: 1\n        padding (int or tuple, optional): implicit zero padding to be added on\n                                          both sides of input. Default: 0\n        dilation (int or tuple, optional): a parameter that controls the\n                                           stride of elements within the\n                                           neighborhood. Default: 1\n\n    * If :attr:`kernel_size`, :attr:`dilation`, :attr:`padding` or\n      :attr:`stride` is an int or a tuple of length 1, their values will be\n      replicated across all spatial dimensions.\n\n    * For the case of two input spatial dimensions this operation is sometimes\n      called ``im2col``.\n\n    .. note::\n        :class:`~oneflow.nn.Fold` calculates each combined value in the resulting\n        large tensor by summing all values from all containing blocks.\n        :class:`~oneflow.nn.Unfold` extracts the values in the local blocks by\n        copying from the large tensor. So, if the blocks overlap, they are not\n        inverses of each other.\n\n        In general, folding and unfolding operations are related as\n        follows. Consider :class:`~oneflow.nn.Fold` and\n        :class:`~oneflow.nn.Unfold` instances created with the same\n        parameters:\n\n        >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...)\n        >>> fold = nn.Fold(output_size=..., **fold_params)\n        >>> unfold = nn.Unfold(**fold_params)\n\n        Then for any (supported) ``input`` tensor the following\n        equality holds:\n\n        ::\n\n            fold(unfold(input)) == divisor * input\n\n        where ``divisor`` is a tensor that depends only on the shape\n        and dtype of the ``input``:\n\n        >>> input_ones = oneflow.ones(input.shape, dtype=input.dtype)\n        >>> divisor = fold(unfold(input_ones))\n\n        When the ``divisor`` tensor contains no zero elements, then\n        ``fold`` and ``unfold`` operations are inverses of each\n        other (up to constant divisor).\n\n    .. 
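code-block:: python\n\n        >>> # Hedged illustration of the divisor relation above, with concrete\n        >>> # parameters assumed (4x4 input, 3x3 kernel, padding 1, stride 1):\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> fold_params = dict(kernel_size=3, dilation=1, padding=1, stride=1)\n        >>> fold = flow.nn.Fold(output_size=(4, 4), **fold_params)\n        >>> unfold = flow.nn.Unfold(**fold_params)\n        >>> input = flow.Tensor(np.random.randn(1, 1, 4, 4))\n        >>> input_ones = flow.ones(1, 1, 4, 4)\n        >>> divisor = fold(unfold(input_ones))\n        >>> np.allclose(fold(unfold(input)).numpy(), (divisor * input).numpy())\n        True\n\n    .. 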
warning::\n        Currently, only 4-D input tensors (batched image-like tensors) are\n        supported.\n\n    Shape:\n        - Input: :math:`(N, C, *)`\n        - Output: :math:`(N, C \\times \\prod(\\text{kernel_size}), L)` as described above\n\n    For example: \n\n    .. code-block:: python \n\n        >>> import oneflow as flow \n        >>> import numpy as np \n\n        >>> x_tensor = flow.Tensor(np.random.randn(1, 1, 4, 4))\n        >>> unfold = flow.nn.Unfold(kernel_size=3, padding=1)\n        >>> out = unfold(x_tensor)\n        >>> out.shape\n        oneflow.Size([1, 9, 16])\n\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: _size_2_t,\n        dilation: _size_2_t = 1,\n        padding: _size_2_t = 0,\n        stride: _size_2_t = 1,\n    ) -> None:\n        super(Unfold, self).__init__()\n        self.kernel_size = kernel_size\n        self.dilation = dilation\n        self.padding = padding\n        self.stride = stride\n\n    def forward(self, input):\n        return flow._C.unfold(\n            input,\n            self.kernel_size,\n            self.dilation,\n            self.padding,\n            self.stride,\n            \"channels_first\",\n        )\n\n    def extra_repr(self) -> str:\n        return (\n            \"kernel_size={kernel_size}, dilation={dilation}, padding={padding},\"\n            \" stride={stride}\".format(**self.__dict__)\n        )\n"
  },
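  {
    "path": "examples/sketch_fold_unfold.py",
    "content": "# Hypothetical usage sketch (illustrative file, not part of the upstream\n# OneFlow tree): it demonstrates the fold/unfold relationship described in the\n# Fold and Unfold docstrings above, namely fold(unfold(input)) == divisor * input,\n# where divisor depends only on the shape and dtype of the input. The concrete\n# kernel_size/padding values are illustrative choices.\nimport numpy as np\nimport oneflow as flow\n\nfold_params = dict(kernel_size=3, dilation=1, padding=1, stride=1)\nfold = flow.nn.Fold(output_size=(4, 4), **fold_params)\nunfold = flow.nn.Unfold(**fold_params)\n\nx = flow.Tensor(np.random.randn(1, 1, 4, 4))\n# divisor counts how many sliding blocks cover each input location.\ndivisor = fold(unfold(flow.ones(x.shape, dtype=x.dtype)))\nassert np.allclose(fold(unfold(x)).numpy(), (divisor * x).numpy(), atol=1e-5)\n"
  },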
  {
    "path": "python/oneflow/nn/modules/fused_mlp.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport math\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.nn.init import _calculate_fan_in_and_fan_out\nfrom oneflow.nn.modules.module import Module\nfrom typing import Tuple\n\n\nclass FusedMLP(Module):\n    \"\"\"Applies a linear transformation with relu activation to the incoming data: :math:`y = ReLU(xA^T + b)`\n\n    Args:\n        in_features: size of each input sample\n\n        hidden_features: A tuple of each Linear layer hidden size\n\n        out_features: The final Linear layer hidden size\n\n        hidden_dropout_rate: A tuple of each hidden layer's dropout rate\n\n        out_dropout_rate: The final Linear layer's dropout rate\n\n    Shape:\n        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of\n          additional dimensions and :math:`H_{in} = {in\\\\_features}`\n\n        - Output: :math:`(N, *, H_{out})` where all but the last dimension\n          are the same shape as the input and :math:`H_{out} = {out\\\\_features}`.\n\n    Attr:\n        - :attr:`skip_final_activation`: Whether to skip final hidden layer's activation. Default: False. \n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n\n        >>> m = flow.nn.FusedMLP(128, [256, 512], 1024).to(\"cuda\")\n        >>> input = flow.Tensor(np.random.randn(1, 128)).to(\"cuda\")\n        >>> output = m(input)\n        >>> output.size()\n        oneflow.Size([1, 1024])\n\n    \"\"\"\n\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Tuple[int],\n        out_features: int,\n        hidden_dropout_rate: Tuple[float] = None,\n        out_dropout_rate: float = 0.0,\n        skip_final_activation=False,\n    ) -> None:\n        super().__init__()\n        self.in_features = in_features\n        self.hidden_features = hidden_features\n        self.out_features = out_features\n        # TODO(zzk): Add more activation support.\n        self.skip_final_activation = skip_final_activation\n        self.hidden_layer_num = len(hidden_features)\n        self.dropout_rate_list = (\n            hidden_dropout_rate\n            if hidden_dropout_rate\n            else [0.0] * (self.hidden_layer_num)\n        )\n        self.dropout_rate_list += [out_dropout_rate]\n        self.add_parameters()\n        self.reset_parameters()\n        self.use_dropout = False\n        for i in range(self.hidden_layer_num + 1):\n            if self.dropout_rate_list[i] != 0.0:\n                self.use_dropout = True\n                break\n\n    def add_parameters(self) -> None:\n        \"\"\"Register parameter in FusedMLP module. 
\n\n        \"\"\"\n        if self.hidden_layer_num != 0:\n            # First layer.\n            self.register_parameter(\n                f\"weight_{0}\",\n                flow.nn.Parameter(\n                    flow.Tensor(self.hidden_features[0], self.in_features)\n                ),\n            )\n            self.register_parameter(\n                f\"bias_{0}\", flow.nn.Parameter(flow.Tensor(self.hidden_features[0]))\n            )\n\n            # Middle Layer.\n            for idx in range(1, self.hidden_layer_num):\n                self.register_parameter(\n                    f\"weight_{idx}\",\n                    flow.nn.Parameter(\n                        flow.Tensor(\n                            self.hidden_features[idx], self.hidden_features[idx - 1],\n                        )\n                    ),\n                )\n                self.register_parameter(\n                    f\"bias_{idx}\",\n                    flow.nn.Parameter(flow.Tensor(self.hidden_features[idx])),\n                )\n\n            # Final Layer.\n            self.register_parameter(\n                f\"weight_{self.hidden_layer_num}\",\n                flow.nn.Parameter(\n                    flow.Tensor(\n                        self.out_features,\n                        self.hidden_features[self.hidden_layer_num - 1],\n                    )\n                ),\n            )\n            self.register_parameter(\n                f\"bias_{self.hidden_layer_num}\",\n                flow.nn.Parameter(flow.Tensor(self.out_features)),\n            )\n        else:\n            # there is only 1 layer.\n            self.register_parameter(\n                f\"weight_{0}\",\n                flow.nn.Parameter(flow.Tensor(self.out_features, self.in_features)),\n            )\n            self.register_parameter(\n                f\"bias_{0}\", flow.nn.Parameter(flow.Tensor(self.out_features))\n            )\n\n    def weight(self, i):\n        \"\"\"Returns the ith weight. \n\n        \"\"\"\n        return getattr(self, f\"weight_{i}\")\n\n    def weights(self):\n        \"\"\"Returns the weight list in FusedMLP module. \n\n        \"\"\"\n        return [self.weight(i) for i in range(self.hidden_layer_num + 1)]\n\n    def bias(self, i):\n        \"\"\"Return the ith bias. \n\n        \"\"\"\n        return getattr(self, f\"bias_{i}\")\n\n    def biases(self):\n        \"\"\"Returns the bias list in FusedMLP module. \n\n        \"\"\"\n        return [self.bias(i) for i in range(self.hidden_layer_num + 1)]\n\n    def reset_parameters(self) -> None:\n        \"\"\"Reset the parameters in FusedMLP module. 
\n\n        \"\"\"\n        for layer_idx in range(self.hidden_layer_num + 1):\n            flow.nn.init.kaiming_uniform_(self.weight(layer_idx), a=math.sqrt(5))\n            (fan_in, _) = _calculate_fan_in_and_fan_out(self.weight(layer_idx))\n            bound = 1 / math.sqrt(fan_in)\n            flow.nn.init.uniform_(self.bias(layer_idx), -bound, bound)\n\n    def forward(self, x):\n        if not self.training or not self.use_dropout:\n            return flow._C.fused_mlp(\n                x, self.weights(), self.biases(), self.skip_final_activation\n            )\n        else:\n            return flow._C.fused_matmul_bias_add_relu_dropout(\n                x,\n                self.weights(),\n                self.biases(),\n                self.skip_final_activation,\n                self.dropout_rate_list,\n            )\n\n    def extra_repr(self) -> str:\n        return \"in_features={}, hidden_features={}, out_features={}, skip_final_activation={}\".format(\n            self.in_features,\n            self.hidden_features,\n            self.out_features,\n            self.skip_final_activation,\n        )\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
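  {
    "path": "examples/sketch_fused_mlp.py",
    "content": "# Hypothetical usage sketch (illustrative file, not part of the upstream\n# OneFlow tree): it shows how FusedMLP registers one (weight_i, bias_i) pair\n# per stage and exposes them through weight(i)/bias(i) and weights()/biases().\n# Constructing the module needs no GPU; only the fused forward kernels do\n# (see the CUDA-based docstring example).\nimport oneflow as flow\n\nm = flow.nn.FusedMLP(128, [256, 512], 1024)\n# Expected shapes: weight_0 (256, 128), weight_1 (512, 256), weight_2 (1024, 512).\nfor i, (w, b) in enumerate(zip(m.weights(), m.biases())):\n    print(i, tuple(w.shape), tuple(b.shape))\n# weight(i) returns the same registered Parameter object as weights()[i].\nassert m.weight(0) is m.weights()[0]\n"
  },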
  {
    "path": "python/oneflow/nn/modules/global_cast.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.framework.tensor import register_tensor_op, Tensor\nfrom oneflow.nn.modules.module import Module\n\n\ndef _check_sbp(sbp):\n    if sbp is None:\n        pass\n    elif isinstance(sbp, (tuple, list)):\n        if not all(isinstance(sbp_item, flow.sbp.sbp) for sbp_item in sbp):\n            raise TypeError(\n                \"sbp parameter must be type of oneflow.sbp.sbp or list/tuple of oneflow.sbp.sbp\"\n            )\n    elif isinstance(sbp, flow.sbp.sbp):\n        sbp = (sbp,)\n    else:\n        raise TypeError(f\"Invalid parameter sbp with type {type(sbp)}\")\n\n    return sbp\n\n\ndef local_to_global_op(input, placement=None, sbp=None, *, check_meta=True, copy=False):\n    # Convert None to a tensor with shape 0, in order to input it into flow._C.to_global.\n    if input is None:\n        input = flow.tensor(())\n\n    assert isinstance(input, Tensor)\n    assert input.is_local, \"input must be a local tensor\"\n    if placement is None or sbp is None:\n        raise ValueError(\n            \"Converting a local tensor to global tensor must have placement and sbp parameters.\"\n        )\n\n    assert isinstance(\n        placement, flow.placement\n    ), f\"Invalid parameter placement with type {type(placement)}\"\n\n    sbp = _check_sbp(sbp)\n    grad_sbp = tuple()\n    return flow._C.to_global(input, placement, sbp, grad_sbp, check_meta, copy)\n\n\ndef global_to_global_op(\n    input, placement=None, sbp=None, *, grad_sbp=None, check_meta=False, copy=False\n):\n    assert isinstance(input, Tensor)\n    assert input.is_global, \"input must be a global tensor\"\n\n    sbp = _check_sbp(sbp)\n    if placement is None:\n        placement = input.placement\n\n    if sbp is None:\n        sbp = input.sbp\n\n    assert isinstance(\n        placement, flow.placement\n    ), f\"Invalid parameter placement with type {type(placement)}\"\n\n    grad_sbp = _check_sbp(grad_sbp)\n    if grad_sbp is None:\n        grad_sbp = tuple()\n    return flow._C.to_global(input, placement, sbp, grad_sbp, check_meta, copy)\n\n\ndef to_global_op(input, placement=None, sbp=None, **kwargs):\n    assert isinstance(input, Tensor)\n\n    if input.is_global:\n        return global_to_global_op(input=input, placement=placement, sbp=sbp, **kwargs)\n    else:\n        if \"grad_sbp\" in kwargs:\n            del kwargs[\"grad_sbp\"]\n        return local_to_global_op(input=input, placement=placement, sbp=sbp, **kwargs)\n\n\ndef to_local_op(input, *, copy=False):\n    assert input.is_global, \"Expected global tensor for to_local but got local tensor!\"\n    return flow._C.to_local(input, copy)\n"
  },
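  {
    "path": "examples/sketch_global_cast.py",
    "content": "# Hypothetical usage sketch (illustrative file, not part of the upstream\n# OneFlow tree): it walks through the local -> global -> local round trip\n# implemented by the ops in global_cast.py. It assumes the usual\n# Tensor.to_global/to_local bindings backed by these ops and a single-rank CPU\n# placement so it runs without a multi-device cluster.\nimport oneflow as flow\n\nx = flow.ones(2, 2)  # a local tensor\nplacement = flow.placement(\"cpu\", ranks=[0])\n# local_to_global_op requires both placement and sbp.\ng = x.to_global(placement=placement, sbp=flow.sbp.broadcast)\nassert g.is_global\n# global_to_global_op path: placement is kept when only sbp is passed.\ng2 = g.to_global(sbp=flow.sbp.broadcast)\n# to_local_op returns the shard held by the current rank.\nl = g2.to_local()\nassert l.is_local\n"
  },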
  {
    "path": "python/oneflow/nn/modules/grid_sample.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport oneflow as flow\n\n\ndef grid_sample(\n    input,\n    grid,\n    mode: str = \"bilinear\",\n    padding_mode: str = \"zeros\",\n    align_corners: bool = False,\n):\n    \"\"\"The interface is consistent with PyTorch.    \n    The documentation is referenced from: \n    https://pytorch.org/docs/1.10/generated/torch.nn.functional.grid_sample.html.\n\n    Given an :attr:`input` and a flow-field :attr:`grid`, computes the\n    ``output`` using :attr:`input` values and pixel locations from :attr:`grid`.\n\n    Currently, only spatial (4-D) and volumetric (5-D) :attr:`input` are\n    supported.\n\n    In the spatial (4-D) case, for :attr:`input` with shape\n    :math:`(N, C, H_{in}, W_{in})` and :attr:`grid` with shape\n    :math:`(N, H_{out}, W_{out}, 2)`, the output will have shape\n    :math:`(N, C, H_{out}, W_{out})`.\n\n    For each output location ``output[n, :, h, w]``, the size-2 vector\n    ``grid[n, h, w]`` specifies :attr:`input` pixel locations ``x`` and ``y``,\n    which are used to interpolate the output value ``output[n, :, h, w]``.\n    In the case of 5D inputs, ``grid[n, d, h, w]`` specifies the\n    ``x``, ``y``, ``z`` pixel locations for interpolating\n    ``output[n, :, d, h, w]``. :attr:`mode` argument specifies ``nearest`` or\n    ``bilinear`` interpolation method to sample the input pixels.\n\n    :attr:`grid` specifies the sampling pixel locations normalized by the\n    :attr:`input` spatial dimensions. Therefore, it should have most values in\n    the range of ``[-1, 1]``. For example, values ``x = -1, y = -1`` is the\n    left-top pixel of :attr:`input`, and values  ``x = 1, y = 1`` is the\n    right-bottom pixel of :attr:`input`.\n\n    If :attr:`grid` has values outside the range of ``[-1, 1]``, the corresponding\n    outputs are handled as defined by :attr:`padding_mode`. Options are\n\n        * ``padding_mode=\"zeros\"``: use ``0`` for out-of-bound grid locations,\n        * ``padding_mode=\"border\"``: use border values for out-of-bound grid locations,\n        * ``padding_mode=\"reflection\"``: use values at locations reflected by\n          the border for out-of-bound grid locations. 
For location far away\n          from the border, it will keep being reflected until becoming in bound,\n          e.g., (normalized) pixel location ``x = -3.5`` reflects by border ``-1``\n          and becomes ``x' = 1.5``, then reflects by border ``1`` and becomes\n          ``x'' = -0.5``.\n\n    Note:\n        This function is often used in conjunction with :func:`affine_grid`\n        to build `Spatial Transformer Networks`_ .\n\n    Note:\n        NaN values in :attr:`grid` would be interpreted as ``-1``.\n\n    Args:\n        input (Tensor): input of shape :math:`(N, C, H_{in}, W_{in})` (4-D case)\n                        or :math:`(N, C, D_{in}, H_{in}, W_{in})` (5-D case)\n        grid (Tensor): flow-field of shape :math:`(N, H_{out}, W_{out}, 2)` (4-D case)\n                       or :math:`(N, D_{out}, H_{out}, W_{out}, 3)` (5-D case)\n        mode (str): interpolation mode to calculate output values\n            ``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. Default: ``'bilinear'``\n            Note: ``mode='bicubic'`` supports only 4-D input.\n            When ``mode='bilinear'`` and the input is 5-D, the interpolation mode\n            used internally will actually be trilinear. However, when the input is 4-D,\n            the interpolation mode will legitimately be bilinear.\n        padding_mode (str): padding mode for outside grid values\n            ``'zeros'`` | ``'border'`` | ``'reflection'``. Default: ``'zeros'``\n        align_corners (bool): Geometrically, we consider the pixels of the\n            input  as squares rather than points.\n            If set to ``True``, the extrema (``-1`` and ``1``) are considered as referring\n            to the center points of the input's corner pixels. If set to ``False``, they\n            are instead considered as referring to the corner points of the input's corner\n            pixels, making the sampling more resolution agnostic.\n            This option parallels the ``align_corners`` option in\n            :func:`interpolate`, and so whichever option is used here\n            should also be used there to resize the input image before grid sampling.\n            Default: ``False``\n\n    Returns:\n        output (Tensor): output Tensor\n\n    .. _`Spatial Transformer Networks`:\n        https://arxiv.org/abs/1506.02025\n\n    .. note::\n        ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\\\\alpha=-0.75`.\n        The constant :math:`\\\\alpha` might be different from packages to packages.\n        For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively.\n        This algorithm may \"overshoot\" the range of values it's interpolating.\n        For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255].\n        Clamp the results with :func: `flow.clamp` to ensure they are within the valid range.\n    .. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation\n    .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51\n    .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908\n    \n    Examples::\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor(np.arange(1., 11).reshape((1, 1, 2, 5)), dtype=flow.float32)\n        >>> np_grid = np.array(\n        ...     
[[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]],\n        ...      [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]]\n        ... ).reshape(1, 2, 5, 2)\n        >>> grid = flow.tensor(np_grid, dtype=flow.float32)\n        >>> output = flow.nn.functional.grid_sample(input, grid, mode='nearest', padding_mode='zeros',\n        ...                                        align_corners=True)\n        >>> output\n        tensor([[[[0., 8., 5., 7., 9.],\n                  [1., 8., 5., 8., 0.]]]], dtype=oneflow.float32)\n    \"\"\"\n    y = flow._C.grid_sample(\n        input,\n        grid,\n        interpolation_mode=mode,\n        padding_mode=padding_mode,\n        align_corners=align_corners,\n    )\n    return y\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
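  {
    "path": "examples/sketch_grid_sample.py",
    "content": "# Hypothetical usage sketch (illustrative file, not part of the upstream\n# OneFlow tree): with align_corners=True, a grid holding the normalized\n# pixel-center coordinates (x = linspace(-1, 1, W), y = linspace(-1, 1, H)) is\n# the identity mapping, so grid_sample reproduces the input exactly.\nimport numpy as np\nimport oneflow as flow\n\nH, W = 2, 5\ninp = flow.tensor(np.arange(1.0, 11).reshape(1, 1, H, W), dtype=flow.float32)\nys, xs = np.meshgrid(np.linspace(-1, 1, H), np.linspace(-1, 1, W), indexing=\"ij\")\n# grid shape is (N, H, W, 2), with the last dimension ordered as (x, y).\ngrid = flow.tensor(np.stack([xs, ys], axis=-1)[None], dtype=flow.float32)\nout = flow.nn.functional.grid_sample(\n    inp, grid, mode=\"nearest\", padding_mode=\"zeros\", align_corners=True\n)\nassert np.allclose(out.numpy(), inp.numpy())\n"
  },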
  {
    "path": "python/oneflow/nn/modules/instancenorm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.nn.modules.batchnorm import _NormBase\n\n\nclass _InstanceNorm(_NormBase):\n    def __init__(\n        self,\n        num_features: int,\n        eps: float = 1e-05,\n        momentum: float = 0.1,\n        affine: bool = False,\n        track_running_stats: bool = False,\n    ):\n        super().__init__(num_features, eps, momentum, affine, track_running_stats)\n\n    def _forward(self, x):\n        axis = 1\n        params_shape = [x.shape[axis]]\n        weight = self.weight\n        bias = self.bias\n        nd_params_shape = [1] * len(x.shape)\n        nd_params_shape[axis] = params_shape[0]\n        mean = x.mean(2, keepdim=True)\n        variance = x.var(2, unbiased=False, keepdim=True)\n        normalized = (x - mean) / flow.sqrt(variance + self.eps)\n        if self.weight is not None and params_shape[0] == self.weight.nelement():\n            weight = flow.reshape(self.weight, shape=nd_params_shape)\n        if self.bias is not None and params_shape[0] == self.bias.nelement():\n            bias = flow.reshape(self.bias, shape=nd_params_shape)\n        if self.weight is not None:\n            normalized = normalized * weight\n        if self.bias is not None:\n            normalized = normalized + bias\n        return normalized\n\n    def forward(self, x):\n        self._check_input_dim(x)\n        reshape_to_1d = flow.reshape(x, [x.shape[0], x.shape[1], -1])\n        normalized_1d_out = self._forward(reshape_to_1d)\n        reshape_back_to_nd = flow.reshape(normalized_1d_out, list(x.shape))\n        return reshape_back_to_nd\n\n\nclass InstanceNorm1d(_InstanceNorm):\n    \"\"\"\n    Applies Instance Normalization over a 3D input (a mini-batch of 1D\n    inputs with optional additional channel dimension) as described in the paper\n    `Instance Normalization: The Missing Ingredient for Fast Stylization\n    <https://arxiv.org/abs/1607.08022>`__.\n\n    .. math::\n\n        y = \\\\frac{x - \\\\mathrm{E}[x]}{ \\\\sqrt{\\\\mathrm{Var}[x] + \\\\epsilon}} * \\\\gamma + \\\\beta\n\n    The mean and standard-deviation are calculated per-dimension separately\n    for each object in a mini-batch. :math:`\\\\gamma` and :math:`\\\\beta` are learnable parameter vectors\n    of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.\n    The standard-deviation is calculated via the biased estimator, equivalent to\n    `torch.var(input, unbiased=False)`.\n\n    By default, this layer uses instance statistics computed from input data in\n    both training and evaluation modes.\n\n    If :attr:`track_running_stats` is set to ``True``, during training this\n    layer keeps running estimates of its computed mean and variance, which are\n    then used for normalization during evaluation. The running estimates are\n    kept with a default :attr:`momentum` of 0.1.\n\n    .. 
note::\n        This :attr:`momentum` argument is different from one used in optimizer\n        classes and the conventional notion of momentum. Mathematically, the\n        update rule for running statistics here is\n        :math:`\\\\hat{x}_\\\\text{new} = (1 - \\\\text{momentum}) \\\\times \\\\hat{x} + \\\\text{momentum} \\\\times x_t`,\n        where :math:`\\\\hat{x}` is the estimated statistic and :math:`x_t` is the\n        new observed value.\n\n    .. note::\n        :class:`InstanceNorm1d` and :class:`LayerNorm` are very similar, but\n        have some subtle differences. :class:`InstanceNorm1d` is applied\n        on each channel of channeled data like multidimensional time series, but\n        :class:`LayerNorm` is usually applied on entire sample and often in NLP\n        tasks. Additionally, :class:`LayerNorm` applies elementwise affine\n        transform, while :class:`InstanceNorm1d` usually don't apply affine\n        transform.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.InstanceNorm1d.html.\n\n    Args:\n        num_features: :math:`C` from an expected input of size\n            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`\n        eps: a value added to the denominator for numerical stability. Default: 1e-5\n        momentum: the value used for the running_mean and running_var computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, this module has\n            learnable affine parameters, initialized the same way as done for batch normalization.\n            Default: ``False``.\n        track_running_stats: a boolean value that when set to ``True``, this\n            module tracks the running mean and variance, and when set to ``False``,\n            this module does not track such statistics and always uses batch\n            statistics in both training and eval modes. Default: ``False``\n\n    Shape:\n        - Input: :math:`(N, C, L)`\n        - Output: :math:`(N, C, L)` (same shape as input)\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> # Without Learnable Parameters\n        >>> m = flow.nn.InstanceNorm1d(100)\n        >>> # With Learnable Parameters\n        >>> m = flow.nn.InstanceNorm1d(100, affine=True)\n        >>> x = flow.Tensor(np.random.randn(20, 100, 40))\n        >>> output = m(x)\n\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() == 2:\n            raise ValueError(\n                \"InstanceNorm1d returns 0-filled tensor to 2D tensor.This is because InstanceNorm1d reshapes inputs to(1, N * C, ...) from (N, C,...) and this makesvariances 0.\"\n            )\n        if input.dim() != 3:\n            raise ValueError(\"expected 3D input (got {}D input)\".format(input.dim()))\n\n\nclass InstanceNorm2d(_InstanceNorm):\n    \"\"\"\n    Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs\n    with additional channel dimension) as described in the paper\n    `Instance Normalization: The Missing Ingredient for Fast Stylization\n    <https://arxiv.org/abs/1607.08022>`__.\n\n    .. math::\n\n        y = \\\\frac{x - \\\\mathrm{E}[x]}{ \\\\sqrt{\\\\mathrm{Var}[x] + \\\\epsilon}} * \\\\gamma + \\\\beta\n\n    The mean and standard-deviation are calculated per-dimension separately\n    for each object in a mini-batch. 
:math:`\\\\gamma` and :math:`\\\\beta` are learnable parameter vectors\n    of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.\n    The standard-deviation is calculated via the biased estimator, equivalent to\n    `torch.var(input, unbiased=False)`.\n\n    By default, this layer uses instance statistics computed from input data in\n    both training and evaluation modes.\n\n    If :attr:`track_running_stats` is set to ``True``, during training this\n    layer keeps running estimates of its computed mean and variance, which are\n    then used for normalization during evaluation. The running estimates are\n    kept with a default :attr:`momentum` of 0.1.\n\n    .. note::\n        This :attr:`momentum` argument is different from one used in optimizer\n        classes and the conventional notion of momentum. Mathematically, the\n        update rule for running statistics here is\n        :math:`\\\\hat{x}_\\\\text{new} = (1 - \\\\text{momentum}) \\\\times \\\\hat{x} + \\\\text{momentum} \\\\times x_t`,\n        where :math:`\\\\hat{x}` is the estimated statistic and :math:`x_t` is the\n        new observed value.\n\n    .. note::\n        :class:`InstanceNorm2d` and :class:`LayerNorm` are very similar, but\n        have some subtle differences. :class:`InstanceNorm2d` is applied\n        on each channel of channeled data like RGB images, but\n        :class:`LayerNorm` is usually applied on entire sample and often in NLP\n        tasks. Additionally, :class:`LayerNorm` applies elementwise affine\n        transform, while :class:`InstanceNorm2d` usually don't apply affine\n        transform.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.InstanceNorm2d.html.\n\n    Args:\n        num_features: :math:`C` from an expected input of size\n            :math:`(N, C, H, W)`\n        eps: a value added to the denominator for numerical stability. Default: 1e-5\n        momentum: the value used for the running_mean and running_var computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, this module has\n            learnable affine parameters, initialized the same way as done for batch normalization.\n            Default: ``False``.\n        track_running_stats: a boolean value that when set to ``True``, this\n            module tracks the running mean and variance, and when set to ``False``,\n            this module does not track such statistics and always uses batch\n            statistics in both training and eval modes. Default: ``False``\n\n    Shape:\n        - Input: :math:`(N, C, H, W)`\n        - Output: :math:`(N, C, H, W)` (same shape as input)\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> # Without Learnable Parameters\n        >>> m = flow.nn.InstanceNorm2d(100)\n        >>> # With Learnable Parameters\n        >>> m = flow.nn.InstanceNorm2d(100, affine=True)\n        >>> x = flow.Tensor(np.random.randn(20, 100, 35, 45))\n        >>> output = m(x)\n\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 4:\n            raise ValueError(\"expected 4D input (got {}D input)\".format(input.dim()))\n\n\nclass InstanceNorm3d(_InstanceNorm):\n    \"\"\"\n    Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs\n    with additional channel dimension) as described in the paper\n    `Instance Normalization: The Missing Ingredient for Fast Stylization\n    <https://arxiv.org/abs/1607.08022>`__.\n\n    .. math::\n\n        y = \\\\frac{x - \\\\mathrm{E}[x]}{ \\\\sqrt{\\\\mathrm{Var}[x] + \\\\epsilon}} * \\\\gamma + \\\\beta\n\n    The mean and standard-deviation are calculated per-dimension separately\n    for each object in a mini-batch. :math:`\\\\gamma` and :math:`\\\\beta` are learnable parameter vectors\n    of size C (where C is the input size) if :attr:`affine` is ``True``.\n    The standard-deviation is calculated via the biased estimator, equivalent to\n    `torch.var(input, unbiased=False)`.\n\n    By default, this layer uses instance statistics computed from input data in\n    both training and evaluation modes.\n\n    If :attr:`track_running_stats` is set to ``True``, during training this\n    layer keeps running estimates of its computed mean and variance, which are\n    then used for normalization during evaluation. The running estimates are\n    kept with a default :attr:`momentum` of 0.1.\n\n    .. note::\n        This :attr:`momentum` argument is different from one used in optimizer\n        classes and the conventional notion of momentum. Mathematically, the\n        update rule for running statistics here is\n        :math:`\\\\hat{x}_\\\\text{new} = (1 - \\\\text{momentum}) \\\\times \\\\hat{x} + \\\\text{momentum} \\\\times x_t`,\n        where :math:`\\\\hat{x}` is the estimated statistic and :math:`x_t` is the\n        new observed value.\n\n    .. note::\n        :class:`InstanceNorm3d` and :class:`LayerNorm` are very similar, but\n        have some subtle differences. :class:`InstanceNorm3d` is applied\n        on each channel of channeled data like 3D models with RGB color, but\n        :class:`LayerNorm` is usually applied on entire sample and often in NLP\n        tasks. Additionally, :class:`LayerNorm` applies elementwise affine\n        transform, while :class:`InstanceNorm3d` usually don't apply affine\n        transform.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.InstanceNorm3d.html.\n\n    Args:\n        num_features: :math:`C` from an expected input of size\n            :math:`(N, C, D, H, W)`\n        eps: a value added to the denominator for numerical stability. Default: 1e-5\n        momentum: the value used for the running_mean and running_var computation. 
Default: 0.1\n        affine: a boolean value that when set to ``True``, this module has\n            learnable affine parameters, initialized the same way as done for batch normalization.\n            Default: ``False``.\n        track_running_stats: a boolean value that when set to ``True``, this\n            module tracks the running mean and variance, and when set to ``False``,\n            this module does not track such statistics and always uses batch\n            statistics in both training and eval modes. Default: ``False``\n\n    Shape:\n        - Input: :math:`(N, C, D, H, W)`\n        - Output: :math:`(N, C, D, H, W)` (same shape as input)\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> # Without Learnable Parameters\n        >>> m = flow.nn.InstanceNorm3d(100)\n        >>> # With Learnable Parameters\n        >>> m = flow.nn.InstanceNorm3d(100, affine=True)\n        >>> x = flow.Tensor(np.random.randn(20, 100, 35, 45, 10))\n        >>> output = m(x)\n\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 5:\n            raise ValueError(\"expected 5D input (got {}D input)\".format(input.dim()))\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
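  {
    "path": "examples/sketch_instancenorm.py",
    "content": "# Hypothetical usage sketch (illustrative file, not part of the upstream\n# OneFlow tree): it checks the docstring's claim that InstanceNorm normalizes\n# each (sample, channel) slice with per-instance statistics and the biased\n# variance estimator, against a manual NumPy computation (eps defaults to 1e-05).\nimport numpy as np\nimport oneflow as flow\n\nm = flow.nn.InstanceNorm2d(3)  # affine=False by default\nx = flow.Tensor(np.random.randn(2, 3, 4, 4))\ny = m(x)\n\nxn = x.numpy()\nmean = xn.mean(axis=(2, 3), keepdims=True)\nvar = xn.var(axis=(2, 3), keepdims=True)  # biased estimator, as documented\nref = (xn - mean) / np.sqrt(var + 1e-05)\nassert np.allclose(y.numpy(), ref, atol=1e-4)\n"
  },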
  {
    "path": "python/oneflow/nn/modules/interpolate.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport math\nimport warnings\nfrom typing import Optional, Tuple, Union\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import register_tensor_op\nfrom oneflow.nn.modules.module import Module\n\n\nclass Interpolate:\n    def __init__(\n        self,\n        size: Optional[Union[int, Tuple[int, ...]]] = None,\n        scale_factor: Optional[Union[float, Tuple[float, ...]]] = None,\n        mode: str = \"nearest\",\n        align_corners: Optional[bool] = None,\n        recompute_scale_factor: Optional[bool] = None,\n    ):\n        self.size = size\n        if isinstance(scale_factor, tuple):\n            self.scale_factor = tuple((float(factor) for factor in scale_factor))\n        else:\n            self.scale_factor = float(scale_factor) if scale_factor else None\n        if mode in (\"nearest\", \"area\") and align_corners is not None:\n            raise ValueError(\n                \"align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear\"\n            )\n        self.mode = mode\n        self.recompute_scale_factor = recompute_scale_factor\n        if align_corners == None:\n            align_corners = False\n        self.align_corners = align_corners\n        self.height_scale = None\n        self.width_scale = None\n        if isinstance(self.scale_factor, float):\n            self.height_scale = self.scale_factor\n            self.width_scale = self.scale_factor\n        elif isinstance(self.scale_factor, tuple):\n            self.height_scale = self.scale_factor[0]\n            self.width_scale = self.scale_factor[1]\n        else:\n            pass\n        if self.mode not in (\n            \"nearest\",\n            \"bilinear\",\n            \"linear\",\n            \"area\",\n            \"bicubic\",\n            \"trilinear\",\n        ):\n            raise ValueError(\n                'interpolation must be \"nearest\" or \"bilinear\" or \"linear\" or \"area\" or \"bicubic\" or \"trilinear\".'\n            )\n        if self.mode == \"nearest\" and self.align_corners:\n            raise ValueError('interpolation \"nearest\" does not support align_corners.')\n\n    def forward(self, x):\n        if len(x.shape) == 3 and self.mode == \"bilinear\":\n            raise NotImplementedError(\"Got 3D input, but bilinear mode needs 4D input\")\n        if len(x.shape) == 3 and self.mode == \"trilinear\":\n            raise NotImplementedError(\"Got 3D input, but trilinear mode needs 5D input\")\n        if len(x.shape) == 4 and self.mode == \"linear\":\n            raise NotImplementedError(\"Got 4D input, but linear mode needs 3D input\")\n        if len(x.shape) == 4 and self.mode == \"trilinear\":\n            raise NotImplementedError(\"Got 4D input, but trilinear mode needs 5D input\")\n        if len(x.shape) == 5 and self.mode == \"linear\":\n            raise NotImplementedError(\"Got 5D input, 
but linear mode needs 3D input\")\n        if len(x.shape) == 5 and self.mode == \"bilinear\":\n            raise NotImplementedError(\"Got 5D input, but bilinear mode needs 4D input\")\n\n        dim = len(x.shape) - 2\n        if self.size is not None and self.scale_factor is not None:\n            raise ValueError(\"only one of size or scale_factor should be defined\")\n        elif self.size is not None:\n            assert self.scale_factor is None\n            scale_factors = []\n            if isinstance(self.size, (list, tuple)):\n                if len(self.size) != dim:\n                    raise ValueError(\n                        \"size shape must match input shape. Input is {}D, size is {}\".format(\n                            dim, len(self.size)\n                        )\n                    )\n                output_size = self.size\n            else:\n                output_size = [self.size for _ in range(dim)]\n            for i in range(dim):\n                scale_factors.append(output_size[i] / x.shape[i + 2])\n        elif self.scale_factor is not None:\n            assert self.size is None\n            output_size = None\n            if isinstance(self.scale_factor, (list, tuple)):\n                if len(self.scale_factor) != dim:\n                    raise ValueError(\n                        \"scale_factor shape must match input shape. Input is {}D, scale_factor is {}\".format(\n                            dim, len(self.scale_factor)\n                        )\n                    )\n                scale_factors = self.scale_factor\n            else:\n                scale_factors = [self.scale_factor for _ in range(dim)]\n        else:\n            raise ValueError(\"either size or scale_factor should be defined\")\n        if self.recompute_scale_factor and self.size is not None:\n            raise ValueError(\n                \"recompute_scale_factor is not meaningful with an explicit size.\"\n            )\n        if self.mode == \"area\" and output_size is None:\n            self.recompute_scale_factor = True\n        if self.recompute_scale_factor is True:\n            assert scale_factors is not None\n            output_size = [\n                int(math.floor(float(x.size(i + 2)) * scale_factors[i]))\n                for i in range(dim)\n            ]\n            scale_factors = []\n            for i in range(dim):\n                scale_factors.append(output_size[i] / x.shape[2 + i])\n        if len(x.shape) == 3 and self.mode == \"nearest\":\n            return flow._C.upsample_nearest_1d(\n                x,\n                scale_factor=scale_factors[0],\n                output_size=output_size,\n                data_format=\"channels_first\",\n            )\n        if len(x.shape) == 4 and self.mode == \"nearest\":\n            return flow._C.upsample_nearest_2d(\n                x,\n                height_scale=scale_factors[0],\n                width_scale=scale_factors[1],\n                output_size=output_size,\n                data_format=\"channels_first\",\n            )\n        if len(x.shape) == 5 and self.mode == \"nearest\":\n            return flow._C.upsample_nearest_3d(\n                x,\n                depth_scale=scale_factors[0],\n                height_scale=scale_factors[1],\n                width_scale=scale_factors[2],\n                output_size=output_size,\n                data_format=\"channels_first\",\n            )\n        if len(x.shape) == 3 and self.mode == \"area\":\n            assert output_size is not 
None\n            return flow._C.adaptive_avg_pool1d(x, output_size)\n        if len(x.shape) == 4 and self.mode == \"area\":\n            assert output_size is not None\n            return flow._C.adaptive_avg_pool2d(x, output_size)\n        if len(x.shape) == 5 and self.mode == \"area\":\n            assert output_size is not None\n            return flow._C.adaptive_avg_pool3d(x, output_size)\n        if len(x.shape) == 3 and self.mode == \"linear\":\n            assert self.align_corners is not None\n            return flow._C.upsample_linear_1d(\n                x,\n                scale_factor=scale_factors[0],\n                align_corners=self.align_corners,\n                output_size=output_size,\n                data_format=\"channels_first\",\n            )\n        if len(x.shape) == 4 and self.mode == \"bilinear\":\n            assert self.align_corners is not None\n            return flow._C.upsample_bilinear_2d(\n                x,\n                height_scale=scale_factors[0],\n                width_scale=scale_factors[1],\n                align_corners=self.align_corners,\n                output_size=output_size,\n                data_format=\"channels_first\",\n            )\n        if len(x.shape) == 4 and self.mode == \"bicubic\":\n            assert self.align_corners is not None\n            return flow._C.upsample_bicubic_2d(\n                x,\n                height_scale=scale_factors[0],\n                width_scale=scale_factors[1],\n                align_corners=self.align_corners,\n                output_size=output_size,\n                data_format=\"channels_first\",\n            )\n        if len(x.shape) == 5 and self.mode == \"trilinear\":\n            assert self.align_corners is not None\n            return flow._C.upsample_trilinear_3d(\n                x,\n                depth_scale=scale_factors[0],\n                height_scale=scale_factors[1],\n                width_scale=scale_factors[2],\n                align_corners=self.align_corners,\n                output_size=output_size,\n                data_format=\"channels_first\",\n            )\n\n        raise NotImplementedError(\n            \"Input Error: Only 3D, 4D and 5D input Tensors supported\"\n            \" (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area\"\n            \" (got {})\".format(len(x.shape), self.mode)\n        )\n\n\ndef interpolate(\n    input,\n    size=None,\n    scale_factor=None,\n    mode=\"nearest\",\n    align_corners=None,\n    recompute_scale_factor=None,\n):\n    \"\"\"The interface is consistent with PyTorch.    
\n    \n    The documentation is referenced from: https://pytorch.org/docs/1.10/_modules/torch/nn/functional.html#interpolate.\n    \n\n    Down/up samples the input to either the given :attr:`size` or the given\n    :attr:`scale_factor`\n\n    The algorithm used for interpolation is determined by :attr:`mode`.\n\n    Currently temporal, spatial and volumetric sampling are supported, i.e.\n    expected inputs are 3-D, 4-D or 5-D in shape.\n\n    The input dimensions are interpreted in the form:\n    `mini-batch x channels x [optional depth] x [optional height] x width`.\n\n    The modes available for resizing are: `nearest`, `linear` (3D-only),\n    `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`\n\n    Args:\n        input (Tensor): the input tensor\n        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):\n            output spatial size.\n        scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple.\n        mode (str): algorithm used for upsampling:\n            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |\n            ``'trilinear'`` | ``'area'``. Default: ``'nearest'``\n        align_corners (bool, optional): Geometrically, we consider the pixels of the\n            input and output as squares rather than points.\n            If set to ``True``, the input and output tensors are aligned by the\n            center points of their corner pixels, preserving the values at the corner pixels.\n            If set to ``False``, the input and output tensors are aligned by the corner\n            points of their corner pixels, and the interpolation uses edge value padding\n            for out-of-boundary values, making this operation *independent* of input size\n            when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`\n            is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``.\n            Default: ``False``\n        recompute_scale_factor (bool, optional): recompute the scale_factor for use in the\n            interpolation calculation.  When `scale_factor` is passed as a parameter, it is used\n            to compute the `output_size`.  If `recompute_scale_factor` is ``False`` or not specified,\n            the passed-in `scale_factor` will be used in the interpolation computation.\n            Otherwise, a new `scale_factor` will be computed based on the output and input sizes for\n            use in the interpolation computation (i.e. the computation will be identical to if the computed\n            `output_size` were passed-in explicitly).  Note that when `scale_factor` is floating-point,\n            the recomputed scale_factor may differ from the one passed in due to rounding and precision\n            issues.\n\n    .. note::\n        With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce\n        negative values or values greater than 255 for images.\n        Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot\n        when displaying the image.\n\n    .. warning::\n        With ``align_corners = True``, the linearly interpolating modes\n        (`linear`, `bilinear`, and `trilinear`) don't proportionally align the\n        output and input pixels, and thus the output values can depend on the\n        input size. This was the default behavior for these modes up to version\n        0.3.1. 
Since then, the default behavior is ``align_corners = False``.\n        See :class:`~torch.nn.Upsample` for concrete examples on how this\n        affects the outputs.\n\n    .. warning::\n        When scale_factor is specified, if recompute_scale_factor=True,\n        scale_factor is used to compute the output_size which will then\n        be used to infer new scales for the interpolation.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        \n        >>> input = flow.tensor(np.arange(1, 5).reshape((1, 1, 4)), dtype=flow.float32)\n        >>> output = flow.nn.functional.interpolate(input, scale_factor=2.0, mode=\"linear\")\n        >>> output\n        tensor([[[1.0000, 1.2500, 1.7500, 2.2500, 2.7500, 3.2500, 3.7500, 4.0000]]],\n               dtype=oneflow.float32)\n\n    \"\"\"\n    return Interpolate(\n        size=size,\n        scale_factor=scale_factor,\n        mode=mode,\n        align_corners=align_corners,\n        recompute_scale_factor=recompute_scale_factor,\n    ).forward(input)\n\n\ndef interpolate_like(\n    input, like, mode=\"nearest\", align_corners=None,\n):\n    \"\"\"The interface is consistent with PyTorch.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/_modules/torch/nn/functional.html#interpolate.\n\n\n    Down/up samples the input to the same shape as the `like` tensor.\n\n    The algorithm used for interpolation is determined by :attr:`mode`.\n\n    Currently temporal, spatial and volumetric sampling are supported, i.e.\n    expected inputs are 3-D, 4-D or 5-D in shape.\n\n    The input dimensions are interpreted in the form:\n    `mini-batch x channels x [optional depth] x [optional height] x width`.\n\n    The modes available for resizing are: `nearest`, `linear` (3D-only),\n    `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`\n\n    Args:\n        input (Tensor): the input tensor\n        like (Tensor): the like tensor\n        mode (str): algorithm used for upsampling:\n            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |\n            ``'trilinear'`` | ``'area'``. Default: ``'nearest'``\n        align_corners (bool, optional): Geometrically, we consider the pixels of the\n            input and output as squares rather than points.\n            If set to ``True``, the input and output tensors are aligned by the\n            center points of their corner pixels, preserving the values at the corner pixels.\n            If set to ``False``, the input and output tensors are aligned by the corner\n            points of their corner pixels, and the interpolation uses edge value padding\n            for out-of-boundary values. This only has an effect when :attr:`mode`\n            is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``.\n            Default: ``False``\n\n    .. note::\n        With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce\n        negative values or values greater than 255 for images.\n        Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot\n        when displaying the image.\n\n    .. warning::\n        With ``align_corners = True``, the linearly interpolating modes\n        (`linear`, `bilinear`, and `trilinear`) don't proportionally align the\n        output and input pixels, and thus the output values can depend on the\n        input size. This was the default behavior for these modes up to version\n        0.3.1. 
Since then, the default behavior is ``align_corners = False``.\n        See :class:`~torch.nn.Upsample` for concrete examples on how this\n        affects the outputs.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> input = flow.tensor(np.arange(1, 5).reshape((1, 1, 2, 2)), dtype=flow.float32)\n        >>> like = flow.randn(1, 1, 4, 4)\n        >>> output = flow.nn.functional.interpolate_like(input, like, mode=\"nearest\")\n        >>> output\n        tensor([[[[1., 1., 2., 2.],\n                  [1., 1., 2., 2.],\n                  [3., 3., 4., 4.],\n                  [3., 3., 4., 4.]]]], dtype=oneflow.float32)\n\n    \"\"\"\n    return Interpolate(\n        size=like.shape[2:], mode=mode, align_corners=align_corners,\n    ).forward(input)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
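  {
    "path": "examples/sketch_interpolate.py",
    "content": "# Hypothetical usage sketch (illustrative file, not part of the upstream\n# OneFlow tree): for a fixed input, size= and scale_factor= are two ways to\n# request the same output resolution, and interpolate_like simply borrows the\n# spatial dims of the like tensor as the target size.\nimport numpy as np\nimport oneflow as flow\n\nx = flow.tensor(np.arange(1, 5).reshape(1, 1, 2, 2), dtype=flow.float32)\na = flow.nn.functional.interpolate(x, scale_factor=2.0, mode=\"nearest\")\nb = flow.nn.functional.interpolate(x, size=(4, 4), mode=\"nearest\")\nassert np.allclose(a.numpy(), b.numpy())\n\nlike = flow.randn(1, 1, 4, 4)\nc = flow.nn.functional.interpolate_like(x, like, mode=\"nearest\")\nassert tuple(c.shape) == tuple(like.shape)\n"
  },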
  {
    "path": "python/oneflow/nn/modules/is_tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport oneflow as flow\n\n\ndef is_tensor_op(obj):\n    r\"\"\"\n    is_tensor(input) -> (bool)\n\n    Note that this function is simply doing ``isinstance(obj, Tensor)``.\n    Using that ``isinstance`` check is better for typechecking with mypy,\n    and more explicit - so it's recommended to use that instead of\n    ``is_tensor``.\n\n    Args:\n        obj (Object): Object to test\n    \n    For example:\n\n    .. code-block:: python\n    \n        >>> import oneflow as flow\n\n        >>> x=flow.tensor([1,2,3])\n        >>> flow.is_tensor(x)\n        True\n\n    \"\"\"\n    return isinstance(obj, flow.Tensor)\n"
  },
  {
    "path": "python/oneflow/nn/modules/linear.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport math\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.nn.init import _calculate_fan_in_and_fan_out\nfrom oneflow.nn.modules.module import Module\nimport os\n\n\nclass Identity(Module):\n    \"\"\"A placeholder identity operator that is argument-insensitive.\n\n    Args:\n        args: any argument (unused)\n        kwargs: any keyword argument (unused)\n\n    For example:\n\n    .. code-block:: python\n\n        import numpy as np\n        import oneflow as flow\n\n        m = flow.nn.Identity()\n        input = flow.Tensor(np.random.rand(2, 3, 4, 5))\n\n        output = m(input)\n\n        # output = input\n\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__()\n\n    def forward(self, input: Tensor) -> Tensor:\n        return input\n\n\nclass Linear(Module):\n    \"\"\"Applies a linear transformation to the incoming data: :math:`y = xA^T + b`\n\n    Args:\n\n        - in_features: size of each input sample\n\n        - out_features: size of each output sample\n\n        - bias: If set to ``False``, the layer will not learn an additive bias. Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of\n          additional dimensions and :math:`H_{in} = {in\\\\_features}`\n\n        - Output: :math:`(N, *, H_{out})` where all but the last dimension\n          are the same shape as the input and :math:`H_{out} = {out\\\\_features}`.\n\n    Attr:\n        - :attr:`weight`: the learnable weights of the module of shape :math:`({out\\\\_features}, {in\\\\_features})`. The values are initialized from :math:`\\\\mathcal{U}(-\\\\sqrt{k}, \\\\sqrt{k})`, where :math:`(k = 1 / {in\\\\_features})`\n\n        - :attr:`bias`: the learnable bias of the module of shape :math:`({out\\\\_features})`. If :attr:`bias` is ``True``, the values are initialized from :math:`\\\\mathcal{U}(-\\\\sqrt{k}, \\\\sqrt{k})` where :math:`(k = 1 / {in\\\\_features})`\n\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n\n        >>> m = flow.nn.Linear(20, 30, False)\n        >>> input = flow.Tensor(np.random.randn(128, 20))\n        >>> output = m(input)\n        >>> output.size()\n        oneflow.Size([128, 30])\n\n    \"\"\"\n\n    def __init__(\n        self,\n        in_features: int,\n        out_features: int,\n        bias: bool = True,\n        device=None,\n        dtype=None,\n    ) -> None:\n        super().__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.weight = flow.nn.Parameter(\n            flow.Tensor(out_features, in_features).to(dtype=dtype, device=device)\n        )\n        self.bias = (\n            flow.nn.Parameter(flow.Tensor(out_features).to(dtype=dtype, device=device))\n            if bias\n            else None\n        )\n        self.use_fused_matmul_bias = (\n            self.bias is not None\n            and os.getenv(\"ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR\") == \"1\"\n        )\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        if os.getenv(\"ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT\", \"0\") == \"1\":\n            return\n        flow.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n        if self.bias is not None:\n            (fan_in, _) = _calculate_fan_in_and_fan_out(self.weight)\n            bound = 1 / math.sqrt(fan_in)\n            flow.nn.init.uniform_(self.bias, -bound, bound)\n\n    def forward(self, x):\n        if self.use_fused_matmul_bias:\n            return flow._C.fused_matmul_bias(x, self.weight, self.bias)\n        else:\n            res = flow._C.matmul(x, self.weight, transpose_a=False, transpose_b=True)\n            if self.bias is not None:\n                res += self.bias\n            return res\n\n    def extra_repr(self) -> str:\n        return \"in_features={}, out_features={}, bias={}\".format(\n            self.in_features, self.out_features, self.bias is not None\n        )\n\n\ndef linear(input, weight, bias=None):\n    r\"\"\"\n    Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.\n\n    Shape:\n\n        - Input: :math:`(N, *, in\\_features)` N is the batch size, `*` means any number of\n          additional dimensions\n        - Weight: :math:`(out\\_features, in\\_features)`\n        - Bias: :math:`(out\\_features)`\n        - Output: :math:`(N, *, out\\_features)`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> input = flow.tensor(np.random.randn(128, 20))\n        >>> weight = flow.tensor(np.random.randn(30, 20))\n        >>> output = flow.nn.functional.linear(input, weight)\n        >>> output.size()\n        oneflow.Size([128, 30])\n\n    \"\"\"\n    res = flow._C.matmul(input, weight, transpose_a=False, transpose_b=True)\n    if bias is not None:\n        res += bias\n    return res\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
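A quick cross-check of the `Linear` module above, sketched under the assumption that the functional form in the same file is exported as `flow.nn.functional.linear` (as its own docstring example suggests): the module's forward pass is a matmul against the transposed weight plus an optional bias, so calling the functional form with the module's own parameters should reproduce the module output. When `ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR=1` and a bias is present, the fused kernel path is taken instead; it should agree within float tolerance.

.. code-block:: python

    import numpy as np
    import oneflow as flow

    m = flow.nn.Linear(20, 30)
    x = flow.Tensor(np.random.randn(128, 20))

    # forward() computes x @ W^T + b; the functional form below reuses the
    # module's own parameters, so the two results should match.
    y_module = m(x)
    y_functional = flow.nn.functional.linear(x, m.weight, m.bias)
    print(np.allclose(y_module.numpy(), y_functional.numpy(), atol=1e-5))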
  {
    "path": "python/oneflow/nn/modules/linspace.py",
    "content": "\"\"\"\r\nCopyright 2020 The OneFlow Authors. All rights reserved.\r\n\r\nLicensed under the Apache License, Version 2.0 (the \"License\");\r\nyou may not use this file except in compliance with the License.\r\nYou may obtain a copy of the License at\r\n\r\n    http://www.apache.org/licenses/LICENSE-2.0\r\n\r\nUnless required by applicable law or agreed to in writing, software\r\ndistributed under the License is distributed on an \"AS IS\" BASIS,\r\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\nSee the License for the specific language governing permissions and\r\nlimitations under the License.\r\n\"\"\"\r\nfrom typing import List, Optional, Union\r\nimport math\r\nimport oneflow as flow\r\n\r\n\r\ndef linspace_op(\r\n    start: Union[float, flow.Tensor],\r\n    end: Union[float, flow.Tensor],\r\n    steps: Union[int, flow.Tensor],\r\n    dtype: flow.dtype = flow.float32,\r\n    device: Union[str, flow.device] = None,\r\n    placement: flow.placement = None,\r\n    sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None,\r\n    requires_grad: bool = False,\r\n):\r\n    r\"\"\"\r\n    Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly\r\n    spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are:\r\n\r\n    .. math::\r\n        (\\text{start},\r\n        \\text{start} + \\frac{\\text{end} - \\text{start}}{\\text{steps} - 1},\r\n        \\ldots,\r\n        \\text{start} + (\\text{steps} - 2) * \\frac{\\text{end} - \\text{start}}{\\text{steps} - 1},\r\n        \\text{end})\r\n\r\n    Args:\r\n        start (float): the starting value for the set of points\r\n        end (float): the ending value for the set of points\r\n        steps (int): size of the constructed tensor\r\n\r\n    Keyword arguments:\r\n        dtype(flow.dtype, optional): If `dtype` is not given, the `dtype` is inferred to be `flow.float32`.\r\n        device(flow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor.\r\n        requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: `False`.\r\n\r\n    For example:\r\n\r\n    .. 
code-block:: python\r\n\r\n        >>> import oneflow as flow\r\n\r\n        >>> y = flow.linspace(3, 10, steps=5)\r\n        >>> y\r\n        tensor([ 3.0000,  4.7500,  6.5000,  8.2500, 10.0000], dtype=oneflow.float32)\r\n\r\n    \"\"\"\r\n\r\n    def is_scalar(tensor):\r\n        return tensor.ndim == 0 and tensor.nelement() == 1\r\n\r\n    if isinstance(start, flow.Tensor):\r\n        if not is_scalar(start):\r\n            raise TypeError(\r\n                \"linspace(): argument 'start' (position 1) must be Number, not Tensor\"\r\n            )\r\n        start = start.item()\r\n    if isinstance(end, flow.Tensor):\r\n        if not is_scalar(end):\r\n            raise TypeError(\r\n                \"linspace(): argument 'end' (position 2) must be Number, not Tensor\"\r\n            )\r\n        end = end.item()\r\n    if isinstance(steps, flow.Tensor):\r\n        if not is_scalar(steps):\r\n            raise TypeError(\r\n                \"linspace(): argument 'steps' (position 3) must be Number, not Tensor\"\r\n            )\r\n        if flow.is_floating_point(steps):\r\n            raise TypeError(\r\n                \"linspace(): argument 'steps' (position 3) must be int, not Tensor (with dtype: \"\r\n                + str(steps.dtype)\r\n                + \")\"\r\n            )\r\n        steps = steps.item()\r\n\r\n    if start == end:\r\n        return flow.full((steps,), start * 1.0)\r\n    step = 1.0\r\n    if steps == 0:\r\n        end = start\r\n    elif steps == 1:\r\n        end = start + 1.0\r\n    else:\r\n        step = (end - start) * 1.0 / (steps - 1)\r\n        if math.isclose(((end - start) / (steps - 1)) * (steps - 1), (end - start)):\r\n            end = end + step / 2.0\r\n    if placement is None:\r\n        if isinstance(device, str):\r\n            device = flow.device(device)\r\n        res = flow._C.arange(start, end, step, dtype=dtype, device=device)\r\n    else:\r\n        assert isinstance(\r\n            placement, flow._oneflow_internal.placement\r\n        ), \"placement should be oneflow._oneflow_internal.placement type.\"\r\n        assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), \"sbp: %s\" % sbp\r\n        if isinstance(sbp, flow.sbp.sbp):\r\n            sbp = (sbp,)\r\n        else:\r\n            for elem in sbp:\r\n                assert isinstance(elem, flow.sbp.sbp), \"sbp: %s\" % sbp\r\n        assert len(sbp) == len(placement.ranks.shape)\r\n        res = flow._C.global_arange(\r\n            start, end, step, dtype=dtype, placement=placement, sbp=sbp\r\n        )\r\n\r\n    res.requires_grad = requires_grad\r\n    return res\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    import doctest\r\n\r\n    doctest.testmod(raise_on_error=True)\r\n"
  },
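The wrapper above lowers `linspace` to `arange` with a computed step, with explicit branches for the degenerate cases. A minimal sketch of the observable behavior those branches imply (`start == end` yields a constant tensor, `steps == 1` yields just `start`); this is inferred from the code, not from the docstring:

.. code-block:: python

    import oneflow as flow

    # Regular case: an inclusive, evenly spaced range of length `steps`.
    y = flow.linspace(3, 10, steps=5)      # 3.0, 4.75, 6.5, 8.25, 10.0

    # Degenerate cases handled by the explicit branches above.
    const = flow.linspace(2, 2, steps=4)   # four copies of 2.0
    single = flow.linspace(0, 10, steps=1) # a single element: 0.0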
  {
    "path": "python/oneflow/nn/modules/logspace.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom cgitb import reset\nfrom typing import List, Optional, Union\nimport math\nimport oneflow as flow\n\n\ndef logspace_op(\n    start: float,\n    end: float,\n    steps: int,\n    base: Optional[float] = 10.0,\n    dtype: flow.dtype = None,\n    device: Union[str, flow.device] = None,\n    placement: flow.placement = None,\n    sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None,\n    requires_grad: bool = False,\n):\n    r\"\"\"\n    logspace(start, end, steps, base=10.0, *, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor\n\n    This function is equivalent to PyTorch’s logspace function. \n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.logspace.html.\n\n    Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly\n    spaced from :math:`{{\\text{{base}}}}^{{\\text{{start}}}}` to\n    :math:`{{\\text{{base}}}}^{{\\text{{end}}}}`, inclusive, on a logarithmic scale\n    with base :attr:`base`. That is, the values are:\n\n    .. math::\n        (\\text{base}^{\\text{start}},\n        \\text{base}^{(\\text{start} + \\frac{\\text{end} - \\text{start}}{ \\text{steps} - 1})},\n        \\ldots,\n        \\text{base}^{(\\text{start} + (\\text{steps} - 2) * \\frac{\\text{end} - \\text{start}}{ \\text{steps} - 1})},\n        \\text{base}^{\\text{end}})\n\n    Args:\n        start (float): the starting value for the set of points\n        end (float): the ending value for the set of points\n        steps (int): size of the constructed tensor\n        base (float, optional): base of the logarithm function. Default: ``10.0``.\n\n    Keyword arguments:\n        dtype (oneflow.dtype, optional): the data type to perform the computation in.\n            Default: if None, uses the global default dtype (see oneflow.get_default_dtype())\n            when both :attr:`start` and :attr:`end` are real,\n            and corresponding complex dtype when either is complex.\n        device (oneflow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor type\n        placement (oneflow.placement, optional): the desired placement of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`.\n        sbp (oneflow.sbp.sbp or tuple of oneflow.sbp.sbp, optional): the desired sbp descriptor of returned global tensor. Default: if None, the returned tensor is local one using the argument `device`.\n        requires_grad (bool, optional): If autograd should record operations on the returned tensor. 
Default: False.\n\n    Example::\n\n        >>> import oneflow as flow\n        >>> flow.logspace(start=-10, end=10, steps=2)\n        tensor([1.0000e-10, 1.0000e+10], dtype=oneflow.float32)\n        >>> flow.logspace(start=0.1, end=1.0, steps=5)\n        tensor([ 1.2589,  2.1135,  3.5481,  5.9566, 10.0000], dtype=oneflow.float32)\n        >>> flow.logspace(start=0.1, end=1.0, steps=1)\n        tensor([1.2589], dtype=oneflow.float32)\n        >>> flow.logspace(start=2, end=2, steps=1, base=2)\n        tensor([4.], dtype=oneflow.float32)\n\n    \"\"\"\n    # TODO: Migrate to C++\n    indice = flow.linspace(\n        start=start,\n        end=end,\n        steps=steps,\n        dtype=dtype,\n        device=device,\n        placement=placement,\n        sbp=sbp,\n    )\n    res = flow.pow(base, indice)\n    res.requires_grad = requires_grad\n    return res\n"
  },
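As the `TODO` in the file notes, `logspace` is currently a thin Python composition: it computes exactly `base ** linspace(start, end, steps)` via `flow.pow`. A small sketch of that identity (numpy is used only for the comparison):

.. code-block:: python

    import numpy as np
    import oneflow as flow

    steps = 5
    exponents = flow.linspace(0.1, 1.0, steps=steps)
    manual = flow.pow(10.0, exponents)  # base ** exponents
    builtin = flow.logspace(start=0.1, end=1.0, steps=steps)

    # Both construct base ** linspace(start, end, steps), so they
    # should agree elementwise up to float rounding.
    print(np.allclose(manual.numpy(), builtin.numpy()))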
  {
    "path": "python/oneflow/nn/modules/loss.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Optional\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.nn.modules.module import Module\nfrom oneflow.nn.modules.constant import _ConstantBase\n\n\nclass _Loss(Module):\n    def __init__(self, reduction: str = \"mean\") -> None:\n        super(_Loss, self).__init__()\n        assert reduction in [\"none\", \"mean\", \"sum\"]\n        self.reduction = reduction\n\n\nclass _WeightedLoss(_Loss):\n    def __init__(\n        self, weight: Optional[Tensor] = None, reduction: str = \"mean\"\n    ) -> None:\n        super(_WeightedLoss, self).__init__(reduction=reduction)\n        self.register_buffer(\"weight\", weight)\n\n\nclass L1Loss(_Loss):\n    \"\"\"This operator computes the L1 Loss between each element in `input` and `target`.\n\n    The equation is:\n\n    if reduction = \"none\":\n\n    .. math::\n\n        output = |Target - Input|\n\n    if reduction = \"mean\":\n\n    .. math::\n\n        output = \\\\frac{1}{n}\\\\sum_{i=1}^n|Target_i - Input_i|\n\n    if reduction = \"sum\":\n\n    .. math::\n\n        output = \\\\sum_{i=1}^n|Target_i - Input_i|\n\n    Args:\n        input (oneflow.Tensor): the input Tensor.\n        target (oneflow.Tensor): The target Tensor.\n        reduction (str): The reduce type, it can be one of \"none\", \"mean\", \"sum\". Defaults to \"mean\".\n\n    Returns:\n        oneflow.Tensor: The result Tensor.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor([[1, 1, 1], [2, 2, 2], [7, 7, 7]], dtype = flow.float32)\n        >>> target = flow.tensor([[4, 4, 4], [4, 4, 4], [4, 4, 4]], dtype = flow.float32)\n        >>> m = flow.nn.L1Loss(reduction=\"none\")\n        >>> out = m(input, target)\n        >>> out\n        tensor([[3., 3., 3.],\n                [2., 2., 2.],\n                [3., 3., 3.]], dtype=oneflow.float32)\n        >>> m_mean = flow.nn.L1Loss(reduction=\"mean\")\n        >>> out = m_mean(input, target)\n        >>> out\n        tensor(2.6667, dtype=oneflow.float32)\n        >>> m_sum = flow.nn.L1Loss(reduction=\"sum\")\n        >>> out = m_sum(input, target)\n        >>> out\n        tensor(24., dtype=oneflow.float32)\n    \"\"\"\n\n    def __init__(self, reduction: str = \"mean\") -> None:\n        super(L1Loss, self).__init__(reduction)\n\n    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n        return flow._C.l1_loss(input, target, self.reduction)\n\n\nclass CrossEntropyLoss(_WeightedLoss):\n    r\"\"\"\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.CrossEntropyLoss.html.\n\n    This criterion combines :class:`~flow.nn.LogSoftmax` and :class:`~flow.nn.NLLLoss` in one single class.\n\n    It is useful when training a classification problem with `C` classes.\n    If provided, the optional argument `weight` should be a 1D Tensor assigning weight to each of the classes.\n    This is particularly useful when you have an unbalanced training set.\n\n    The `input` is expected to contain raw, unnormalized scores for each class.\n    `input` has to be a Tensor of size either :math:`(minibatch, C)` or\n    :math:`(minibatch, C, d_1, d_2, ..., d_K)`\n    with :math:`K \\geq 1` for the `K`-dimensional case (described later).\n\n    The target that this criterion expects should contain either:\n\n    - Class indices in the range :math:`[0, C)` where :math:`C` is the number of classes; if\n      `ignore_index` is specified, this loss also accepts this class index (this index\n      may not necessarily be in the class range). The unreduced (i.e. with :attr:`reduction`\n      set to ``'none'``) loss for this case can be described as:\n\n      .. math::\n          \\ell(x, y) = L = \\{l_1,\\dots,l_N\\}^\\top, \\quad\n          l_n = - w_{y_n} \\log \\frac{\\exp(x_{n,y_n})}{\\sum_{c=1}^C \\exp(x_{n,c})}\n          \\cdot \\mathbb{1}\\{y_n \\not= \\text{ignore_index}\\}\n\n      where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight,\n      :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as\n      :math:`d_1, ..., d_k` for the `K`-dimensional case. If\n      :attr:`reduction` is not ``'none'`` (default ``'mean'``), then\n\n      .. math::\n          \\ell(x, y) = \\begin{cases}\n              \\sum_{n=1}^N \\frac{1}{\\sum_{n=1}^N w_{y_n} \\cdot \\mathbb{1}\\{y_n \\not= \\text{ignore_index}\\}} l_n, &\n               \\text{if reduction} = \\text{'mean';}\\\\\n                \\sum_{n=1}^N l_n,  &\n                \\text{if reduction} = \\text{'sum'.}\n            \\end{cases}\n\n      Note that this case is equivalent to the combination of :class:`~torch.nn.LogSoftmax` and\n      :class:`~torch.nn.NLLLoss`.\n\n    - Probabilities for each class; useful when labels beyond a single class per minibatch item\n      are required, such as for blended labels, label smoothing, etc. The unreduced (i.e. 
with\n      :attr:`reduction` set to ``'none'``) loss for this case can be described as:\n\n      .. math::\n          \\ell(x, y) = L = \\{l_1,\\dots,l_N\\}^\\top, \\quad\n          l_n = - \\sum_{c=1}^C w_c \\log \\frac{\\exp(x_{n,c})}{\\sum_{i=1}^C \\exp(x_{n,i})} y_{n,c}\n\n      where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight,\n      :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as\n      :math:`d_1, ..., d_k` for the `K`-dimensional case. If\n      :attr:`reduction` is not ``'none'`` (default ``'mean'``), then\n\n      .. math::\n          \\ell(x, y) = \\begin{cases}\n              \\frac{\\sum_{n=1}^N l_n}{N}, &\n               \\text{if reduction} = \\text{'mean';}\\\\\n                \\sum_{n=1}^N l_n,  &\n                \\text{if reduction} = \\text{'sum'.}\n            \\end{cases}\n\n\n    Args:\n        weight (oneflow.Tensor, optional): a manual rescaling weight given to each class.\n            If given, has to be a Tensor of size `C`\n        ignore_index (int, optional): Specifies a target value that is ignored and does not\n            contribute to the input gradient. When ``reduction`` is ``mean``, the loss is averaged\n            over non-ignored targets. Note that ``ignore_index`` is only applicable when the target\n            contains class indices.\n        reduction (string, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will\n            be applied, ``'mean'``: the weighted mean of the output is taken,\n            ``'sum'``: the output will be summed. Default: ``'mean'``\n        label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount\n            of smoothing when computing the loss, where 0.0 means no smoothing.\n            The targets become a mixture of the original ground truth and a uniform\n            distribution as described in `Rethinking the Inception Architecture for Computer Vision <https://arxiv.org/abs/1512.00567>`_.\n            Default: :math:`0.0`.\n\n    Shape:\n        - Input: Shape :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \\geq 1`\n          in the case of `K`-dimensional loss.\n        - Target: If containing class indices, shape :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with\n          :math:`K \\geq 1` in the case of K-dimensional loss where each value should be in :math:`[0, C)`.\n          If containing class probabilities, same shape as the input and each value should be in :math:`[0, 1]`.\n        - Output: If reduction is 'none', same shape as the target. Otherwise, scalar.\n\n        where:\n\n        .. math::\n            \\begin{aligned}\n                C ={} & \\text{number of classes} \\\\\n                N ={} & \\text{batch size} \\\\\n            \\end{aligned}\n\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        \n        >>> input = flow.tensor(\n        ...    [[-0.1664078, -1.7256707, -0.14690138],\n        ...        [-0.21474946, 0.53737473, 0.99684894],\n        ...        
[-1.135804, -0.50371903, 0.7645404]], dtype=flow.float32)\n        >>> target = flow.tensor(np.array([0, 1, 2]), dtype=flow.int32)\n        >>> out = flow.nn.CrossEntropyLoss(reduction=\"none\")(input, target)\n        >>> out\n        tensor([0.8020, 1.1167, 0.3583], dtype=oneflow.float32)\n        >>> out_sum = flow.nn.CrossEntropyLoss(reduction=\"sum\")(input, target)\n        >>> out_sum\n        tensor(2.2769, dtype=oneflow.float32)\n        >>> out_mean = flow.nn.CrossEntropyLoss(reduction=\"mean\")(input, target)\n        >>> out_mean\n        tensor(0.7590, dtype=oneflow.float32)\n        >>> out_ignore_0 = flow.nn.CrossEntropyLoss(reduction=\"none\", ignore_index=0)(input, target)\n        >>> out_ignore_0\n        tensor([0.0000, 1.1167, 0.3583], dtype=oneflow.float32)\n        >>> out_label_smoothing = flow.nn.CrossEntropyLoss(reduction=\"none\", label_smoothing=0.5)(input, target)\n        >>> out_label_smoothing\n        tensor([1.0586, 1.1654, 0.8864], dtype=oneflow.float32)\n        >>> probs = flow.tensor([[ 0.99495536,  0.28255007, -0.2775054 ],\n        ...    [ 0.42397153,  0.01075112,  0.56527734],\n        ...    [ 0.72356546, -0.1304398 ,  0.4068744 ]], dtype=flow.float32)\n        >>> out = flow.nn.CrossEntropyLoss()(input, probs)\n        >>> out\n        tensor(1.3305, dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        weight: Optional[Tensor] = None,\n        ignore_index: int = -100,\n        reduction: str = \"mean\",\n        label_smoothing: float = 0.0,\n    ) -> None:\n        super(CrossEntropyLoss, self).__init__(weight, reduction)\n        self.ignore_index = ignore_index\n        self.label_smoothing = label_smoothing\n        if self.label_smoothing < 0.0 or self.label_smoothing > 1.0:\n            raise ValueError(\n                \"label_smoothing must be between 0.0 and 1.0. Got: {}\".format(label_smoothing)\n            )\n\n    def forward(self, input, target):\n        return flow._C.cross_entropy(\n            input,\n            target,\n            self.weight,\n            self.ignore_index,\n            self.reduction,\n            self.label_smoothing,\n        )\n\n\nclass BCELoss(_WeightedLoss):\n    \"\"\"This operator computes the binary cross entropy loss.\n\n    The equation is:\n\n    if reduction = \"none\":\n\n    .. math::\n\n        out = -(Target_i*log(Input_i) + (1-Target_i)*log(1-Input_i))\n\n    if reduction = \"mean\":\n\n    .. math::\n\n        out = -\\\\frac{1}{n}\\\\sum_{i=1}^n(Target_i*log(Input_i) + (1-Target_i)*log(1-Input_i))\n\n    if reduction = \"sum\":\n\n    .. math::\n\n        out = -\\\\sum_{i=1}^n(Target_i*log(Input_i) + (1-Target_i)*log(1-Input_i))\n\n    Args:\n        weight (oneflow.Tensor, optional): The manual rescaling weight to the loss. Defaults to None, which corresponds to a weight of 1.\n        reduction (str, optional): The reduce type, it can be one of \"none\", \"mean\", \"sum\". Defaults to \"mean\".\n\n    Attention:\n        The input values must be in the range (0, 1); otherwise the loss function may return `nan`.\n\n    Returns:\n        oneflow.Tensor: The result Tensor.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.Tensor(np.array([[1.2, 0.2, -0.3], [0.7, 0.6, -2]]).astype(np.float32))\n        >>> target = flow.Tensor(np.array([[0, 1, 0], [1, 0, 1]]).astype(np.float32))\n        >>> weight = flow.Tensor(np.array([[2, 2, 2], [2, 2, 2]]).astype(np.float32))\n        >>> activation = flow.nn.Sigmoid()\n        >>> sigmoid_input = activation(input)\n        >>> m = flow.nn.BCELoss(weight, reduction=\"none\")\n        >>> out = m(sigmoid_input, target)\n        >>> out\n        tensor([[2.9266, 1.1963, 1.1087],\n                [0.8064, 2.0750, 4.2539]], dtype=oneflow.float32)\n        >>> m_sum = flow.nn.BCELoss(weight, reduction=\"sum\")\n        >>> out = m_sum(sigmoid_input, target)\n        >>> out\n        tensor(12.3668, dtype=oneflow.float32)\n        >>> m_mean = flow.nn.BCELoss(weight, reduction=\"mean\")\n        >>> out = m_mean(sigmoid_input, target)\n        >>> out\n        tensor(2.0611, dtype=oneflow.float32)\n        >>> m_none = flow.nn.BCELoss()\n        >>> out = m_none(sigmoid_input, target)\n        >>> out\n        tensor(1.0306, dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(\n        self, weight: Optional[Tensor] = None, reduction: str = \"mean\"\n    ) -> None:\n        super(BCELoss, self).__init__(weight, reduction)\n\n    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n        return flow._C.binary_cross_entropy_loss(\n            input, target, self.weight, self.reduction\n        )\n\n\nclass NLLLoss(_WeightedLoss):\n    \"\"\" The negative log likelihood loss. It is useful to train a classification\n    problem with `C` classes.\n\n    The `input` given through a forward call is expected to contain\n    log-probabilities of each class. `input` has to be a Tensor of size either\n    :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)`\n    with :math:`K \\\\geq 1` for the `K`-dimensional case (described later).\n\n    Obtaining log-probabilities in a neural network is easily achieved by\n    adding a  `LogSoftmax`  layer in the last layer of your network.\n    You may use `CrossEntropyLoss` instead, if you prefer not to add an extra\n    layer.\n\n    The `target` that this loss expects should be a class index in the range :math:`[0, C-1]`\n    where `C = number of classes`;\n\n    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:\n\n    .. math::\n        \\\\ell(x, y) = L = \\\\{l_1,\\\\dots,l_N\\\\}^\\\\top, \\\\quad\n        l_n = - w_{y_n} x_{n,y_n}, \\\\quad\n        w_{c} = \\\\mathbb{1},\n\n    where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, and\n    :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``\n    (default ``'mean'``), then\n\n    .. math::\n        \\\\ell(x, y) = \\\\begin{cases}\n            \\\\sum_{n=1}^N \\\\frac{1}{N} l_n, &\n            \\\\text{if reduction} = \\\\text{`mean';}\\\\\\\\\n            \\\\sum_{n=1}^N l_n,  &\n            \\\\text{if reduction} = \\\\text{`sum'.}\n        \\\\end{cases}\n\n    Can also be used for higher dimension inputs, such as 2D images, by providing\n    an input of size :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \\\\geq 1`,\n    where :math:`K` is the number of dimensions, and a target of appropriate shape\n    (see below). 
In the case of images, it computes NLL loss per-pixel.\n\n    Args:\n        reduction (string, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will\n            be applied, ``'mean'``: the weighted mean of the output is taken,\n            ``'sum'``: the output will be summed. Default: ``'mean'``\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> input = flow.tensor(\n        ... [[-0.1664078, -1.7256707, -0.14690138],\n        ... [-0.21474946, 0.53737473, 0.99684894],\n        ... [-1.135804, -0.50371903, 0.7645404]], dtype=flow.float32)\n        >>> target = flow.tensor(np.array([0, 1, 2]), dtype=flow.int32)\n        >>> m = flow.nn.NLLLoss(reduction=\"none\")\n        >>> out = m(input, target)\n        >>> out\n        tensor([ 0.1664, -0.5374, -0.7645], dtype=oneflow.float32)\n\n        >>> m = flow.nn.NLLLoss(reduction=\"sum\")\n        >>> out = m(input, target)\n        >>> out\n        tensor(-1.1355, dtype=oneflow.float32)\n\n        >>> m = flow.nn.NLLLoss(reduction=\"mean\")\n        >>> out = m(input, target)\n        >>> out\n        tensor(-0.3785, dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        weight: Optional[Tensor] = None,\n        ignore_index: int = -100,\n        reduction: str = \"mean\",\n    ) -> None:\n        super(NLLLoss, self).__init__(weight, reduction)\n        self.ignore_index = ignore_index\n\n    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n        return flow._C.nll_loss(\n            input, target, self.weight, self.ignore_index, self.reduction\n        )\n\n\nclass KLDivLoss(_Loss):\n    \"\"\"\n    The Kullback-Leibler divergence loss measure\n\n    `Kullback-Leibler divergence`_ is a useful distance measure for continuous\n    distributions and is often useful when performing direct regression over\n    the space of (discretely sampled) continuous output distributions.\n\n    As with :class:`~torch.nn.NLLLoss`, the `input` given is expected to contain\n    *log-probabilities* and is not restricted to a 2D Tensor.\n    The targets are interpreted as *probabilities* by default, but could be considered\n    as *log-probabilities* with :attr:`log_target` set to ``True``.\n\n    This criterion expects a `target` `Tensor` of the same size as the\n    `input` `Tensor`.\n\n    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:\n\n    .. math::\n        l(x,y) = L = \\\\{ l_1,\\\\dots,l_N \\\\}, \\\\quad\n        l_n = y_n \\\\cdot \\\\left( \\\\log y_n - x_n \\\\right)\n\n    where the index :math:`N` spans all dimensions of ``input`` and :math:`L` has the same\n    shape as ``input``. If :attr:`reduction` is not ``'none'`` (default ``'mean'``), then:\n\n    .. math::\n        \\\\ell(x, y) = \\\\begin{cases}\n            \\\\operatorname{mean}(L), & \\\\text{if reduction} = \\\\text{`mean';} \\\\\\\\\n            \\\\operatorname{sum}(L),  & \\\\text{if reduction} = \\\\text{`sum'.}\n        \\\\end{cases}\n\n    In default :attr:`reduction` mode ``'mean'``, the losses are averaged for each minibatch over observations\n    **as well as** over dimensions. ``'batchmean'`` mode gives the correct KL divergence where losses\n    are averaged over batch dimension only. ``'mean'`` mode's behavior will be changed to the same as\n    ``'batchmean'`` in the next major release.\n\n    .. 
_`kullback-leibler divergence`: https://en.wikipedia.org/wiki/Kullback-Leibler_divergence\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.KLDivLoss.html.\n\n    Args:\n        reduction (string, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.\n            ``'none'``: no reduction will be applied.\n            ``'batchmean'``: the sum of the output will be divided by batchsize.\n            ``'sum'``: the output will be summed.\n            ``'mean'``: the output will be divided by the number of elements in the output.\n            Default: ``'mean'``\n        log_target (bool, optional): Specifies whether `target` is passed in the log space.\n            Default: ``False``\n\n    .. note::\n        :attr:`reduction` = ``'mean'`` doesn't return the true kl divergence value, please use\n        :attr:`reduction` = ``'batchmean'`` which aligns with KL math definition.\n        In the next major release, ``'mean'`` will be changed to be the same as ``'batchmean'``.\n\n    Shape:\n        - Input: :math:`(N, *)` where :math:`*` means, any number of additional\n          dimensions\n        - Target: :math:`(N, *)`, same shape as the input\n        - Output: scalar by default. If :attr:``reduction`` is ``'none'``, then :math:`(N, *)`,\n          the same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor([-0.9021705, 0.08798598, 1.04686249], dtype=flow.float32)\n        >>> target = flow.tensor([1.22386942, -0.89729659, 0.01615712], dtype=flow.float32)\n        >>> m = flow.nn.KLDivLoss(reduction=\"none\", log_target=False)\n        >>> out = m(input, target)\n        >>> out\n        tensor([ 1.3514,  0.0000, -0.0836], dtype=oneflow.float32)\n        >>> m = flow.nn.KLDivLoss(reduction=\"mean\", log_target=False)\n        >>> out = m(input, target)\n        >>> out\n        tensor(0.4226, dtype=oneflow.float32)\n        >>> m = flow.nn.KLDivLoss(reduction=\"sum\", log_target=True)\n        >>> out = m(input, target)\n        >>> out\n        tensor(5.7801, dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self, reduction: str = \"mean\", log_target: bool = False) -> None:\n        if reduction == \"batchmean\":\n            super(KLDivLoss, self).__init__(\"sum\")\n            self.reduction = \"batchmean\"\n        else:\n            super(KLDivLoss, self).__init__(reduction)\n\n        self.log_target = log_target\n\n    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n        return flow._C.kl_div_loss(input, target, self.log_target, self.reduction)\n\n\nclass MSELoss(_Loss):\n    \"\"\"\n    Creates a criterion that measures the mean squared error (squared L2 norm) between\n    each element in the input :math:`x` and target :math:`y`.\n\n    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:\n\n    .. math::\n        \\\\ell(x, y) = L = \\\\{l_1,\\\\dots,l_N\\\\}^\\\\top, \\\\quad\n        l_n = \\\\left( x_n - y_n \\\\right)^2,\n\n    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``\n    (default ``'mean'``), then:\n\n    .. 
math::\n        \\\\ell(x, y) =\n        \\\\begin{cases}\n            \\\\operatorname{mean}(L), &  \\\\text{if reduction} = \\\\text{`mean';}\\\\\\\\\n            \\\\operatorname{sum}(L),  &  \\\\text{if reduction} = \\\\text{`sum'.}\n        \\\\end{cases}\n\n    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total\n    of :math:`n` elements each.\n\n    The mean operation still operates over all the elements, and divides by :math:`n`.\n\n    The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.MSELoss.html.\n\n    Args:\n        reduction (string, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n            ``'mean'``: the sum of the output will be divided by the number of\n            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``\n\n    Shape:\n        - Input: :math:`(N, *)` where :math:`*` means, any number of additional\n          dimensions\n        - Target: :math:`(N, *)`, same shape as the input\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor(\n        ... [[-0.02557137, 0.03101675, 1.37493674],\n        ... [0.25599439, -1.08372561, -0.21006816]], dtype=flow.float32)\n        >>> target = flow.tensor(\n        ... [[-1.53105064, -0.68137555, 0.5931354],\n        ... [-0.49158347, 0.93673637, 0.1324141]], dtype=flow.float32)\n        >>> m = flow.nn.MSELoss(reduction=\"none\")\n        >>> out = m(input, target)\n        >>> out\n        tensor([[2.2665, 0.5075, 0.6112],\n                [0.5589, 4.0823, 0.1173]], dtype=oneflow.float32)\n        >>> m = flow.nn.MSELoss(reduction=\"mean\")\n        >>> out = m(input, target)\n        >>> out\n        tensor(1.3573, dtype=oneflow.float32)\n        >>> m = flow.nn.MSELoss(reduction=\"sum\")\n        >>> out = m(input, target)\n        >>> out\n        tensor(8.1436, dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self, reduction: str = \"mean\") -> None:\n        super(MSELoss, self).__init__(reduction)\n\n    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n        return flow._C.mse_loss(input, target, self.reduction)\n\n\nclass MarginRankingLoss(_Loss):\n    \"\"\"Creates a criterion that measures the loss given\n    inputs :math:`x1`, :math:`x2`, two 1D mini-batch `Tensors`,\n    and a label 1D mini-batch tensor :math:`y` (containing 1 or -1).\n\n    If :math:`y = 1` then it assumed the first input should be ranked higher\n    (have a larger value) than the second input, and vice-versa for :math:`y = -1`.\n\n    The loss function for each sample in the mini-batch is:\n\n    .. math::\n        \\\\text{loss}(x1, x2, y) = \\\\max(0, -y * (x1 - x2) + \\\\text{margin})\n\n    Args:\n        margin (float, optional): Has a default value of :math:`0`.\n        reduction (string, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n            ``'mean'``: the sum of the output will be divided by the number of\n            elements in the output, ``'sum'``: the output will be summed. 
Default: ``'mean'``\n\n    Shape:\n        - `x1` : :math:`(N, D)` where `N` is the batch size and `D` is the size of a sample.\n        - `x2` : :math:`(N, D)` where `N` is the batch size and `D` is the size of a sample.\n        - Target: :math:`(N)`\n        - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x1 = flow.tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=flow.float32)\n        >>> x2 = flow.tensor(np.array([[2, 2, 2], [2, 2, 2], [2, 2, 2]]), dtype=flow.float32)\n        >>> target = flow.tensor(np.array([[1, -1, 1],[-1, 1, -1], [1, 1, 1]]), dtype=flow.float32)\n        >>> m = flow.nn.MarginRankingLoss(margin =1.0, reduction=\"none\")\n        >>> out = m(x1, x2, target)\n        >>> out\n        tensor([[2., 1., 0.],\n                [3., 0., 5.],\n                [0., 0., 0.]], dtype=oneflow.float32)\n\n        >>> m = flow.nn.MarginRankingLoss(margin = 0.3, reduction=\"sum\")\n        >>> out = m(x1, x2, target)\n        >>> out\n        tensor(8.2000, dtype=oneflow.float32)\n\n        >>> m = flow.nn.MarginRankingLoss(margin = 10, reduction=\"mean\")\n        >>> out = m(x1, x2, target)\n        >>> out\n        tensor(8.3333, dtype=oneflow.float32)\n\n\n    \"\"\"\n\n    def __init__(self, margin: float = 0.0, reduction: str = \"mean\") -> None:\n        super(MarginRankingLoss, self).__init__(reduction)\n        self.margin = margin\n\n    def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor:\n        return flow._C.margin_ranking_loss(\n            input1, input2, target, self.margin, self.reduction\n        )\n\n\nclass CTCLoss(_Loss):\n    \"\"\"The Connectionist Temporal Classification loss.\n    The interface is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.CTCLoss.html.\n\n    Calculates loss between a continuous (unsegmented) time series and a target sequence. CTCLoss sums over the\n    probability of possible alignments of input to target, producing a loss value which is differentiable\n    with respect to each input node. The alignment of input to target is assumed to be \"many-to-one\", which\n    limits the length of the target sequence such that it must be :math:`\\\\leq` the input length.\n\n    Args:\n        blank (int, optional): blank label. Default :math:`0`.\n        reduction (string, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n            ``'mean'``: the output losses will be divided by the target lengths and\n            then the mean over the batch is taken. 
Default: ``'mean'``\n        zero_infinity (bool, optional):\n            Whether to zero infinite losses and the associated gradients.\n            Default: ``False``\n            Infinite losses mainly occur when the inputs are too short\n            to be aligned to the targets.\n\n    Shape:\n        - Log_probs: Tensor of size :math:`(T, N, C)`,\n          where :math:`T = \\\\text{input length}`,\n          :math:`N = \\\\text{batch size}`, and\n          :math:`C = \\\\text{number of classes (including blank)}`.\n        - Targets: Tensor of size :math:`(N, S)` or\n          :math:`(\\\\operatorname{sum}(\\\\text{target_lengths}))`,\n          where :math:`N = \\\\text{batch size}` and\n          :math:`S = \\\\text{max target length, if shape is } (N, S)`.\n          It represents the target sequences. Each element in the target\n          sequence is a class index. And the target index cannot be blank (default=0).\n          In the :math:`(N, S)` form, targets are padded to the\n          length of the longest sequence, and stacked.\n          In the :math:`(\\\\operatorname{sum}(\\\\text{target_lengths}))` form,\n          the targets are assumed to be un-padded and\n          concatenated within 1 dimension.\n        - Input_lengths: Tuple or tensor of size :math:`(N)`,\n          where :math:`N = \\\\text{batch size}`. It represents the lengths of the\n          inputs (must each be :math:`\\\\leq T`). And the lengths are specified\n          for each sequence to achieve masking under the assumption that sequences\n          are padded to equal lengths.\n        - Target_lengths: Tuple or tensor of size :math:`(N)`,\n          where :math:`N = \\\\text{batch size}`. It represents the lengths of the targets.\n          Lengths are specified for each sequence to achieve masking under the\n          assumption that sequences are padded to equal lengths. If target shape is\n          :math:`(N,S)`, target_lengths are effectively the stop index\n          :math:`s_n` for each target sequence, such that ``target_n = targets[n,0:s_n]`` for\n          each target in a batch. Lengths must each be :math:`\\\\leq S`.\n          If the targets are given as a 1d tensor that is the concatenation of individual\n          targets, the target_lengths must add up to the total length of the tensor.\n\n    Reference:\n        A. Graves et al.: Connectionist Temporal Classification:\n        Labelling Unsegmented Sequence Data with Recurrent Neural Networks:\n        https://www.cs.toronto.edu/~graves/icml_2006.pdf\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        \n        >>> log_probs = flow.tensor(\n        ...    [\n        ...        [[-1.1031, -0.7998, -1.5200], [-0.9808, -1.1363, -1.1908]],\n        ...        [[-1.2258, -1.0665, -1.0153], [-1.1135, -1.2331, -0.9671]],\n        ...        [[-1.3348, -0.6611, -1.5118], [-0.9823, -1.2355, -1.0941]],\n        ...        [[-1.3850, -1.3273, -0.7247], [-0.8235, -1.4783, -1.0994]],\n        ...        [[-0.9049, -0.8867, -1.6962], [-1.4938, -1.3630, -0.6547]],\n        ...    
], dtype=flow.float32)\n        >>> targets = flow.tensor([[1, 2, 2], [1, 2, 2]], dtype=flow.int32)\n        >>> input_lengths = flow.tensor([5, 5], dtype=flow.int32)\n        >>> target_lengths = flow.tensor([3, 3], dtype=flow.int32)\n        >>> loss_mean = flow.nn.CTCLoss()\n        >>> out = loss_mean(log_probs, targets, input_lengths, target_lengths)\n        >>> out\n        tensor(1.1376, dtype=oneflow.float32)\n        >>> loss_sum = flow.nn.CTCLoss(blank=0, reduction=\"sum\")\n        >>> out = loss_sum(log_probs, targets, input_lengths, target_lengths)\n        >>> out\n        tensor(6.8257, dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(\n        self, blank: int = 0, reduction: str = \"mean\", zero_infinity: bool = False\n    ) -> None:\n        super(CTCLoss, self).__init__(reduction)\n        self.blank = blank\n        self.zero_infinity = zero_infinity\n\n    def forward(\n        self,\n        log_probs: Tensor,\n        targets: Tensor,\n        input_lengths: Tensor,\n        target_lengths: Tensor,\n    ) -> Tensor:\n        max_target_length = 0\n        if targets.ndim == 1:\n            max_target_length = target_lengths.max().item()\n        elif targets.ndim == 2:\n            max_target_length = targets.shape[1]\n        return flow._C.ctc_loss(\n            log_probs,\n            targets,\n            input_lengths,\n            target_lengths,\n            max_target_length,\n            self.blank,\n            self.zero_infinity,\n            self.reduction,\n        )\n\n\nclass BCEWithLogitsLoss(_WeightedLoss):\n    \"\"\"This operator combines the `Sigmoid` and `BCELoss` together. For numerical stability,\n    we apply some math tricks instead of using `Sigmoid` layer with `BCELoss`.\n\n    The equation is:\n\n    if :attr:`reduction` = ``\"none\"``:\n\n    .. math::\n\n        out = -weight*[Pos\\\\_weight*y*log\\\\sigma({x}) + (1-y)*log(1-\\\\sigma(x))]\n\n    if :attr:`reduction` = ``\"mean\"``:\n\n    .. math::\n\n        out = -\\\\frac{weight}{n}\\\\sum_{i=1}^n[Pos\\\\_weight*y*log\\\\sigma({x}) + (1-y)*log(1-\\\\sigma(x))]\n\n    if :attr:`reduction` = ``\"sum\"``:\n\n    .. math::\n\n        out = -weight*\\\\sum_{i=1}^n[Pos\\\\_weight*y*log\\\\sigma({x}) + (1-y)*log(1-\\\\sigma(x))]\n\n    Args:\n        weight (Tensor, optional): The manual rescaling weight to the loss. Default: ``None``\n        size_average (bool, optional): Deprecated (see :attr:`reduction`). Default: ``True``\n        reduce (bool, optional): Deprecated (see :attr:`reduction`). Default: ``True``\n        reduction (str, optional): The reduce type, it can be one of ``\"none\"``, ``\"mean\"``, ``\"sum\"``.\n            ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided\n            by the number of elements in the output, ``'sum'``: the output will be summed. Default: ``\"mean\"``\n        pos_weight (Tensor, optional): The manual rescaling weight to the positive examples.\n            Default: ``None``\n\n    Shape:\n        - Input: :math:`(N,*)` where `*` means, any number of additional dimensions\n        - Target: :math:`(N,*)`, same shape as the input\n        - Output: scalar. If :attr:`reduction` is ``\"none\"``, then :math:`(N,*)`, same shape as input.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> input = flow.tensor([[1.2, 0.2, -0.3], [0.7, 0.6, -2], [0.7, 0.6, -2]], dtype=flow.float32)\n        >>> target = flow.tensor([[0, 1, 0], [1, 0, 1], [1, 0, 1]], dtype=flow.float32)\n        >>> weight = flow.tensor([[2, 2, 2], [2, 2, 2], [2, 2, 2]], dtype=flow.float32)\n        >>> pos_weight = flow.tensor([1.2, 1.3, 1.4], dtype=flow.float32)\n\n        >>> m = flow.nn.BCEWithLogitsLoss(weight=weight, pos_weight=pos_weight, reduction=\"none\")\n        >>> out = m(input, target)\n        >>> out\n        tensor([[2.9266, 1.5552, 1.1087],\n                [0.9676, 2.0750, 5.9554],\n                [0.9676, 2.0750, 5.9554]], dtype=oneflow.float32)\n\n        >>> m = flow.nn.BCEWithLogitsLoss(weight=weight, pos_weight=pos_weight, reduction=\"mean\")\n        >>> out = m(input, target)\n        >>> out\n        tensor(2.6207, dtype=oneflow.float32)\n\n        >>> m = flow.nn.BCEWithLogitsLoss(weight=weight, pos_weight=pos_weight, reduction=\"sum\")\n        >>> out = m(input, target)\n        >>> out\n        tensor(23.5865, dtype=oneflow.float32)\n\n\n    \"\"\"\n\n    def __init__(\n        self,\n        weight: Optional[Tensor] = None,\n        reduction: str = \"mean\",\n        pos_weight: Optional[Tensor] = None,\n    ) -> None:\n        super(BCEWithLogitsLoss, self).__init__(weight, reduction)\n        self.reduction = reduction\n        self.pos_weight = pos_weight\n\n    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n        return flow._C.binary_cross_entropy_with_logits_loss(\n            input, target, self.weight, self.pos_weight, self.reduction\n        )\n\n\nclass SmoothL1Loss(_Loss):\n    \"\"\"Creates a criterion that uses a squared term if the absolute\n    element-wise error falls below beta and an L1 term otherwise.\n    The interface is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.SmoothL1Loss.html.\n\n    It is less sensitive to outliers than :class:`torch.nn.MSELoss` and in some cases\n    prevents exploding gradients (e.g. see the paper `Fast R-CNN <https://openaccess.thecvf.com/content_iccv_2015/papers/Girshick_Fast_R-CNN_ICCV_2015_paper.pdf>`__ by Ross Girshick)..\n\n    For a batch of size :math:`N`, the unreduced loss can be described as:\n\n    .. math::\n        \\\\ell(x, y) = L = \\\\{l_1, ..., l_N\\\\}^T\n\n    with\n\n    .. math::\n        l_n = \\\\begin{cases}\n        0.5 (x_n - y_n)^2 / beta, & \\\\text{if } |x_n - y_n| < beta \\\\\\\\\n        |x_n - y_n| - 0.5 * beta, & \\\\text{otherwise }\n        \\\\end{cases}\n\n    If `reduction` is not `none`, then:\n\n    .. math::\n        \\\\ell(x, y) =\n        \\\\begin{cases}\n            \\\\operatorname{mean}(L), &  \\\\text{if reduction} = \\\\text{`mean';}\\\\\\\\\n            \\\\operatorname{sum}(L),  &  \\\\text{if reduction} = \\\\text{`sum'.}\n        \\\\end{cases}\n\n    .. note::\n        Smooth L1 loss can be seen as exactly :class:`L1Loss`, but with the :math:`|x - y| < beta`\n        portion replaced with a quadratic function such that its slope is 1 at :math:`|x - y| = beta`.\n        The quadratic segment smooths the L1 loss near :math:`|x - y| = 0`.\n\n    .. note::\n        Smooth L1 loss is closely related to :class:`HuberLoss`, being\n        equivalent to :math:`huber(x, y) / beta` (note that Smooth L1's beta hyper-parameter is\n        also known as delta for Huber). 
This leads to the following differences:\n\n        * As beta -> 0, Smooth L1 loss converges to :class:`L1Loss`, while :class:`HuberLoss`\n          converges to a constant 0 loss.\n        * As beta -> :math:`+\\\\infty`, Smooth L1 loss converges to a constant 0 loss, while\n          :class:`HuberLoss` converges to :class:`MSELoss`.\n        * For Smooth L1 loss, as beta varies, the L1 segment of the loss has a constant slope of 1.\n          For :class:`HuberLoss`, the slope of the L1 segment is beta.\n\n    Args:\n        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,\n            the losses are averaged over each loss element in the batch. Note that for\n            some losses, there are multiple elements per sample. If the field :attr:`size_average`\n            is set to ``False``, the losses are instead summed for each minibatch. Ignored\n            when :attr:`reduce` is ``False``. Default: ``True``\n        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the\n            losses are averaged or summed over observations for each minibatch depending\n            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per\n            batch element instead and ignores :attr:`size_average`. Default: ``True``\n        reduction (string, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n            ``'mean'``: the sum of the output will be divided by the number of\n            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`\n            and :attr:`reduce` are in the process of being deprecated, and in the meantime,\n            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``\n        beta (float, optional): Specifies the threshold at which to change between L1 and L2 loss.\n            The value must be non-negative. Default: 1.0\n\n    Shape:\n        - Input: :math:`(N, *)` where :math:`*` means any number of additional dimensions\n        - Target: :math:`(N, *)`; same shape as the input\n        - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`; same shape as the input\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        \n        >>> x = flow.tensor(np.array([0.1, 0.4, 0.3, 0.5, 0.9]).astype(np.float32), dtype=flow.float32)\n        >>> y = flow.tensor(np.array([0.3, 0.9, 2.5, 0.4, 0.3]).astype(np.float32), dtype=flow.float32)\n        >>> m = flow.nn.SmoothL1Loss(reduction=\"none\")\n        >>> out = m(x, y)\n        >>> out\n        tensor([0.0200, 0.1250, 1.7000, 0.0050, 0.1800], dtype=oneflow.float32)\n\n        >>> m = flow.nn.SmoothL1Loss(reduction=\"mean\")\n        >>> out = m(x, y)\n        >>> out\n        tensor(0.4060, dtype=oneflow.float32)\n\n        >>> m = flow.nn.SmoothL1Loss(reduction=\"sum\")\n        >>> out = m(x, y)\n        >>> out\n        tensor(2.0300, dtype=oneflow.float32)\n    \"\"\"\n\n    def __init__(self, reduction: str = \"mean\", beta: float = 1.0) -> None:\n        super(SmoothL1Loss, self).__init__(reduction)\n        self.beta = beta\n\n    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n        return flow._C.smooth_l1_loss(input, target, self.beta, self.reduction)\n\n\nclass CombinedMarginLoss(Module):\n    r\"\"\"The operation implements \"margin_softmax\" in InsightFace:\n    https://github.com/deepinsight/insightface/blob/master/recognition/arcface_mxnet/train.py\n    The implementation of margin_softmax in InsightFace is composed of multiple operators.\n    We fuse them for speed up.\n\n    Applies the function:\n\n    .. math::\n\n        {\\rm CombinedMarginLoss}(x_i, label) =\n        \\left\\{\\begin{matrix} \\cos(m_1\\cdot\\arccos x_i+m_2) - m_3 & {\\rm if} \\ i == label \\\\\n        x_i & {\\rm otherwise} \\end{matrix}\\right.\n\n\n    Args:\n        x (oneflow.Tensor): A Tensor\n        label (oneflow.Tensor): label with integer data type\n        m1 (float): loss m1 parameter\n        m2 (float): loss m2 parameter\n        m3 (float): loss m3 parameter\n\n    .. note::\n\n        Here are some special cases:\n\n        - when :math:`m_1=1, m_2\\neq 0, m_3=0`, CombineMarginLoss has the same parameter as `ArcFace <https://arxiv.org/abs/1801.07698>`__ .\n\n        - when :math:`m_1=1, m_2=0, m_3\\neq 0`, CombineMarginLoss has the same parameter as `CosFace (a.k.a AM-Softmax) <https://arxiv.org/abs/1801.09414>`__ .\n\n        - when :math:`m_1\\gt 1, m_2=m_3=0`, CombineMarginLoss has the same parameter as `A-Softmax <https://arxiv.org/abs/1704.08063>`__.\n\n    Returns:\n        oneflow.Tensor: A Tensor\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> np_x = np.array([[-0.7027179, 0.0230609], [-0.02721931, -0.16056311], [-0.4565852, -0.64471215]])\n        >>> np_label = np.array([0, 1, 1])\n        >>> x = flow.tensor(np_x, dtype=flow.float32)\n        >>> label = flow.tensor(np_label, dtype=flow.int32)\n        >>> loss_func = flow.nn.CombinedMarginLoss(0.3, 0.5, 0.4)\n        >>> out = loss_func(x, label)\n        >>> out\n        tensor([[-0.0423,  0.0231],\n                [-0.0272,  0.1237],\n                [-0.4566, -0.0204]], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self, m1: float = 1.0, m2: float = 0.0, m3: float = 0.0) -> None:\n        super().__init__()\n        self.m1 = m1\n        self.m2 = m2\n        self.m3 = m3\n\n    def forward(self, x: Tensor, label: Tensor) -> Tensor:\n        return flow._C.combined_margin_loss(\n            x, label, m1=self.m1, m2=self.m2, m3=self.m3\n        )\n\n\nclass TripletMarginLoss(Module):\n    r\"\"\"Creates a criterion that measures the triplet loss given an input\n    tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.\n    This is used for measuring a relative similarity between samples. A triplet\n    is composed by `a`, `p` and `n` (i.e., `anchor`, `positive examples` and `negative\n    examples` respectively). The shapes of all input tensors should be\n    :math:`(N, D)`.\n\n    The distance swap is described in detail in the paper `Learning shallow\n    convolutional feature descriptors with triplet losses <http://www.bmva.org/bmvc/2016/papers/paper119/index.html>`__ by\n    V. Balntas, E. Riba et al.\n\n    The loss function for each sample in the mini-batch is:\n\n    .. math::\n        L(a, p, n) = \\max \\{d(a_i, p_i) - d(a_i, n_i) + {\\rm margin}, 0\\}\n\n\n    where\n\n    .. math::\n        d(x_i, y_i) = \\left\\lVert {\\bf x}_i - {\\bf y}_i \\right\\rVert_p\n\n    Args:\n        margin (float, optional): Default: :math:`1`.\n        p (float, optional): The norm degree for pairwise distance. Default: :math:`2.0`.\n        swap (bool, optional): The distance swap is described in detail in the paper\n            `Learning shallow convolutional feature descriptors with triplet losses` by\n            V. Balntas, E. Riba et al. Default: ``False``.\n        reduction (string, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n            ``'mean'``: the sum of the output will be divided by the number of\n            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`\n            and :attr:`reduce` are in the process of being deprecated, and in the meantime,\n            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``\n\n    Shape:\n        - Input: :math:`(N, D)` where :math:`D` is the vector dimension.\n        - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar\n          otherwise.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> triplet_loss = flow.nn.TripletMarginLoss(margin=1.0, p=2)\n        >>> anchor = np.array([[1, -1, 1],[-1, 1, -1], [1, 1, 1]])\n        >>> positive = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n        >>> negative = np.array([[2, 2, 2], [2, 2, 2], [2, 2, 2]])\n        >>> output = triplet_loss(flow.Tensor(anchor), flow.Tensor(positive), flow.Tensor(negative))\n        >>> output\n        tensor(6.2971, dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        margin: float = 1.0,\n        p: float = 2.0,\n        eps: float = 1e-6,\n        swap: bool = False,\n        size_average=None,\n        reduce=None,\n        reduction: str = \"mean\",\n    ) -> None:\n        super().__init__()\n        self.margin = margin\n        self.p = p\n        self.eps = eps\n        self.swap = swap\n        self.reduction = reduction\n\n    def forward(self, anchor, positive, negative):\n        triplet_loss = flow._C.triplet_margin_loss(\n            anchor,\n            positive,\n            negative,\n            margin=self.margin,\n            p=self.p,\n            eps=self.eps,\n            swap=self.swap,\n            reduction=self.reduction,\n        )\n        return triplet_loss\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
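To make the TripletMarginLoss equations above concrete, here is a minimal NumPy sketch of the same computation (an illustrative reference for the documented formulas, not the ``flow._C.triplet_margin_loss`` kernel; ``triplet_margin_loss_ref`` is a made-up name):

.. code-block:: python

    import numpy as np

    def triplet_margin_loss_ref(anchor, positive, negative, margin=1.0, p=2.0, swap=False):
        # d(x, y) = ||x - y||_p, taken row-wise over the (N, D) inputs
        d_ap = np.linalg.norm(anchor - positive, ord=p, axis=1)
        d_an = np.linalg.norm(anchor - negative, ord=p, axis=1)
        if swap:
            # distance swap (Balntas et al.): also consider the positive-negative
            # distance and keep the smaller of the two negative distances
            d_an = np.minimum(d_an, np.linalg.norm(positive - negative, ord=p, axis=1))
        # L(a, p, n) = max{d(a_i, p_i) - d(a_i, n_i) + margin, 0}, then 'mean' reduction
        return np.maximum(d_ap - d_an + margin, 0.0).mean()

    anchor = np.array([[1.0, -1.0, 1.0], [-1.0, 1.0, -1.0], [1.0, 1.0, 1.0]])
    positive = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
    negative = np.array([[2.0, 2.0, 2.0]] * 3)
    print(round(triplet_margin_loss_ref(anchor, positive, negative), 4))  # 6.2971

Run on the arrays from the doctest, this reproduces the documented value 6.2971.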
  {
    "path": "python/oneflow/nn/modules/masked_select.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\n\n\ndef masked_select_op(input, mask):\n    \"\"\"\n\n    Returns a new 1-D tensor which indexes the input tensor according to the boolean mask mask which is a BoolTensor(In oneFlow BoolTensor is replaced by Int8Tensor).\n\n    The shapes of the mask tensor and the input tensor don’t need to match, but they must be broadcastable.\n\n    Args:\n        input (Tensor): the input tensor.\n        mask (Tensor): the tensor containing the binary mask to index with\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        \n        >>> input = flow.tensor(np.array([[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]), dtype=flow.float32)\n        >>> mask = input.gt(0.05)\n        >>> out = flow.masked_select(input, mask)\n        >>> out\n        tensor([0.3139, 0.3898], dtype=oneflow.float32)\n    \"\"\"\n\n    assert input.is_global == mask.is_global, (\n        f\"input tensor is %s tensor, but mask is %s tensor\"\n        % (\n            \"global\" if input.is_global else \"local\",\n            \"global\" if mask.is_global else \"local\",\n        )\n    )\n    broadcast_shape = []\n    input_shape_len = len(input.shape)\n    mask_shape_len = len(mask.shape)\n    input_shape = [input.shape[i] for i in range(input_shape_len)]\n    input_shape.reverse()\n    mask_shape = [mask.shape[i] for i in range(mask_shape_len)]\n    mask_shape.reverse()\n    for i in range(max(input_shape_len, mask_shape_len)):\n        if i < input_shape_len and i < mask_shape_len:\n            broadcast_shape.append(max(input_shape[i], mask_shape[i]))\n        elif i < input_shape_len:\n            broadcast_shape.append(input_shape[i])\n        else:\n            broadcast_shape.append(mask_shape[i])\n    broadcast_shape.reverse()\n    broadcast_input = input.expand(broadcast_shape)\n    broadcast_mask = mask.expand(broadcast_shape)\n\n    indices = flow.argwhere(broadcast_mask)\n    gather_res = flow._C.gather_nd(broadcast_input, indices)\n\n    return gather_res.flatten()\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
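The broadcast-then-gather steps in ``masked_select_op`` have a direct NumPy analogue, which can serve as a sanity check on the semantics (an illustrative sketch only; ``masked_select_ref`` is a hypothetical helper, not part of the OneFlow API):

.. code-block:: python

    import numpy as np

    def masked_select_ref(input, mask):
        # broadcast both operands to the common shape, then keep the masked
        # elements as a flat 1-D array, as masked_select_op does
        input_b, mask_b = np.broadcast_arrays(input, mask)
        return input_b[mask_b.astype(bool)]

    x = np.array([[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]])
    print(masked_select_ref(x, x > 0.05))  # [0.3139 0.3898]

Boolean indexing in NumPy already returns elements in row-major order, matching the argwhere + gather_nd + flatten pipeline above.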
  {
    "path": "python/oneflow/nn/modules/math_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport collections\nfrom typing import Optional, Sequence, Union\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import register_tensor_op\nfrom oneflow.nn.modules.module import Module\nfrom oneflow.nn.modules.utils import _check_axis\nfrom oneflow.ops.transpose_util import (\n    get_inversed_perm,\n    get_perm_when_transpose_axis_to_last_dim,\n)\n\n\ndef asin_op(input):\n    \"\"\"\n    Returns a new tensor with the arcsine of the elements of :attr:`input`.\n\n    .. math::\n        \\\\text{out}_{i} = \\\\sin^{-1}(\\\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor(np.array([-0.5,  0.8, 1.0,  -0.8]), dtype=flow.float32)\n        >>> output = flow.asin(input)\n        >>> output.shape\n        oneflow.Size([4])\n        >>> output\n        tensor([-0.5236,  0.9273,  1.5708, -0.9273], dtype=oneflow.float32)\n        >>> input1 = flow.tensor(np.array([[0.8, 1.0], [-0.6, -1.0]]), dtype=flow.float32)\n        >>> output1 = input1.asin()\n        >>> output1.shape\n        oneflow.Size([2, 2])\n        >>> output1\n        tensor([[ 0.9273,  1.5708],\n                [-0.6435, -1.5708]], dtype=oneflow.float32)\n    \"\"\"\n    return flow._C.asin(input)\n\n\ndef arcsin_op(input):\n    \"\"\"\n  \n    Alias for :func:`oneflow.asin`\n    \"\"\"\n    return flow._C.asin(input)\n\n\ndef asinh_op(input):\n    \"\"\"\n    Returns a new tensor with the inverse hyperbolic sine of the elements of :attr:`input`.\n\n    .. math::\n        \\\\text{out}_{i} = \\\\sinh^{-1}(\\\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor(np.array([2, 3, 4]), dtype=flow.float32)\n        >>> output = flow.asinh(input)\n        >>> output.shape\n        oneflow.Size([3])\n        >>> output\n        tensor([1.4436, 1.8184, 2.0947], dtype=oneflow.float32)\n\n        >>> input1 = flow.tensor(np.array([[-1, 0, -0.4], [5, 7, 0.8]]), dtype=flow.float32)\n        >>> output1 = input1.asinh()\n        >>> output1.shape\n        oneflow.Size([2, 3])\n        >>> output1\n        tensor([[-0.8814,  0.0000, -0.3900],\n                [ 2.3124,  2.6441,  0.7327]], dtype=oneflow.float32)\n\n    \"\"\"\n    return flow._C.asinh(input)\n\n\ndef arcsinh_op(input):\n    \"\"\"\n  \n    Alias for :func:`oneflow.asinh`\n    \"\"\"\n    return flow._C.asinh(input)\n\n\ndef asinh_op_tensor(input):\n    \"\"\"\n\n    See :func:`oneflow.asinh`\n    \"\"\"\n    return flow._C.asinh(input)\n\n\ndef inplace_sin_op_tensor(input):\n    \"\"\"\n    In-place version of :func:`oneflow.sin`\n    \n    \"\"\"\n    return flow._C.sin_(input)\n\n\ndef atan_op(input):\n    \"\"\"\n    Returns a new tensor with the arctangent of the elements of :attr:`input`.\n\n    .. math::\n        \\\\text{out}_{i} = \\\\tan^{-1}(\\\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n    \n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor(np.array([0.5, 0.6, 0.7]), dtype=flow.float32)\n        >>> output = flow.atan(input)\n        >>> output.shape\n        oneflow.Size([3])\n        \n    \"\"\"\n    return flow._C.atan(input)\n\n\ndef arctan_op(input):\n    \"\"\"\n    Alias for :func:`oneflow.atan`\n    \n    \"\"\"\n    return flow._C.atan(input)\n\n\ndef fmod_op(input, other):\n    \"\"\"\n    fmod(input, other, *, out=None) -> Tensor\n\n    Computes the element-wise remainder of division.\n\n    The dividend and divisor may contain both integer and floating point\n    numbers. 
The remainder has the same sign as the dividend :attr:`input`.\n\n    Supports broadcasting to a common shape, integer and float inputs.\n\n\n    Args:\n        input (Tensor): the dividend\n        other (Tensor or Scalar): the divisor\n\n    Keyword args:\n        out (Tensor, optional): the output tensor.\n\n    Example::\n\n        >>> import oneflow as flow\n        >>> flow.fmod(flow.tensor([-3., -2, -1, 1, 2, 3]), 2.)\n        tensor([-1., -0., -1.,  1.,  0.,  1.], dtype=oneflow.float32)\n        >>> flow.fmod(flow.tensor([1, 2, 3, 4, 5.]), 1.5)\n        tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000], dtype=oneflow.float32)\n        >>> flow.fmod(flow.tensor([1, 2, 3, 4., -5]), flow.tensor([4, 2, 1, 3., 1]))\n        tensor([1., 0., 0., 1., -0.], dtype=oneflow.float32)\n\n    \"\"\"\n    return flow._C.fmod(input, other)\n\n\ndef addmm(x, mat1, mat2, alpha=1, beta=1):\n    if len(x.shape) > 2 or len(mat1.shape) > 2 or len(mat2.shape) > 2:\n        raise ValueError(\"input matrices cannot have more than 2 dimensions\")\n    else:\n        return flow.mul(x, beta) + flow.mul(flow._C.matmul(mat1, mat2), alpha)\n\n\ndef addmm_op(input, mat1, mat2, alpha=1, beta=1):\n    \"\"\"addmm(beta=1, input, alpha=1, mat1, mat2, out=None) -> Tensor\n\n    Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`.\n    The matrix :attr:`input` is added to the final result.\n\n    If :attr:`mat1` is a :math:`(n \\\\times m)` tensor, :attr:`mat2` is a\n    :math:`(m \\\\times p)` tensor, then :attr:`input` must be\n    broadcastable with a :math:`(n \\\\times p)` tensor\n    and :attr:`out` will be a :math:`(n \\\\times p)` tensor.\n\n    :attr:`alpha` and :attr:`beta` are scaling factors on the matrix-matrix product between\n    :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively.\n\n    .. 
math::\n        \\\\text{out} = \\\\beta\\\\ \\\\text{input} + \\\\alpha\\\\ (\\\\text{mat1}_i \\\\mathbin{@} \\\\text{mat2}_i)\n\n    For inputs of type `float` or `double`, arguments :attr:`beta` and\n    :attr:`alpha` must be real numbers, otherwise they should be integers.\n\n    Args:\n        beta (Number, optional): multiplier for :attr:`input` (:math:`\\\\beta`)\n        input (Tensor): matrix to be added\n        alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\\\\alpha`)\n        mat1 (Tensor): the first matrix to be multiplied\n        mat2 (Tensor): the second matrix to be multiplied\n        out (Tensor, optional): the output tensor.\n\n    For example:\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> input = flow.tensor(np.array([[1,2,4],[5,11,9.1]]))\n        >>> mat1 = flow.tensor(np.array([[7.3,1.9,7.3],[10.2,1,5.5]])) \n        >>> mat2 = flow.tensor(np.array([[7.3,1.9,7.3],[10.2,1,5.5],[3.7,2.2,8.1]])) \n        >>> output = flow.addmm(input, mat1, mat2)\n        >>> output\n        tensor([[100.6800,  33.8300, 126.8700],\n                [110.0100,  43.4800, 133.6100]], dtype=oneflow.float64)\n        >>> output.shape\n        oneflow.Size([2, 3])\n\n        >>> input2 = flow.tensor(np.array([1.7]))\n        >>> mat1 = flow.tensor(np.array([[1,2],[5,9.1],[7.7,1.4]]))\n        >>> mat2 = flow.tensor(np.array([[1,2,3.7],[5,9.1,6.8]]))\n        >>> output2 = flow.addmm(input2, mat1, mat2, alpha=1, beta=2)\n        >>> output2\n        tensor([[14.4000, 23.6000, 20.7000],\n                [53.9000, 96.2100, 83.7800],\n                [18.1000, 31.5400, 41.4100]], dtype=oneflow.float64)\n        >>> output2.shape\n        oneflow.Size([3, 3])\n    \"\"\"\n    return addmm(input, mat1, mat2, alpha, beta)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
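Since ``addmm`` above reduces to out = beta * input + alpha * (mat1 @ mat2), the documented example can be checked with plain NumPy (a sketch for illustration; ``addmm_ref`` is a made-up name):

.. code-block:: python

    import numpy as np

    def addmm_ref(input, mat1, mat2, alpha=1, beta=1):
        # out = beta * input + alpha * (mat1 @ mat2); input broadcasts
        # against the (n, p) matrix product, as in addmm_op above
        return beta * input + alpha * (mat1 @ mat2)

    input2 = np.array([1.7])
    mat1 = np.array([[1.0, 2.0], [5.0, 9.1], [7.7, 1.4]])
    mat2 = np.array([[1.0, 2.0, 3.7], [5.0, 9.1, 6.8]])
    print(addmm_ref(input2, mat1, mat2, alpha=1, beta=2))
    # matches the addmm_op doctest (up to float formatting):
    # [[14.4  23.6  20.7 ]
    #  [53.9  96.21 83.78]
    #  [18.1  31.54 41.41]]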
  {
    "path": "python/oneflow/nn/modules/meshgrid.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\n\n\ndef meshgrid_op(*tensors, indexing=\"ij\"):\n    if isinstance(tensors[0], (list, tuple)):\n        return flow._C.meshgrid(tensors[0], indexing)\n    return flow._C.meshgrid(tensors, indexing)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
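``meshgrid_op`` accepts the component tensors either as separate positional arguments or packed in a single list/tuple; both spellings forward to the same functional op, with ``"ij"`` indexing by default. A usage sketch (assuming ``meshgrid_op`` is exported as ``flow.meshgrid``, as the surrounding modules are):

.. code-block:: python

    import oneflow as flow

    x = flow.tensor([1, 2, 3])
    y = flow.tensor([4, 5])
    # varargs form and sequence form are interchangeable
    grid_x, grid_y = flow.meshgrid(x, y)
    grid_x2, grid_y2 = flow.meshgrid([x, y])
    # with the default "ij" indexing both grids have shape (len(x), len(y))
    print(grid_x.shape, grid_y2.shape)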
  {
    "path": "python/oneflow/nn/modules/min_max_observer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.framework.tensor import register_tensor_op\nfrom oneflow.nn.modules.module import Module\n\n\nclass MinMaxObserver(Module):\n    \"\"\"\n    \n    Compute the quantization parameters of the input tensor.\n\n    First compute the max and min values of input tensor:\n\n    .. math::\n\n        & max\\\\_value = max(input)\n\n        & min\\\\_value = min(input)\n\n    Then compute the scale and zero_point with the following equations:\n\n        if quantization_scheme == \"symmetric\":\n\n        .. math::\n\n            & denom = 2^{quantization\\\\_to\\\\_bit - 1} - 1\n\n            & scale = max(|max\\\\_value|,|min\\\\_value|) / denom\n\n            & zero\\\\_point = 0\n\n        elif quantization_scheme == \"affine\":\n\n        .. math::\n\n            & denom = 2^{quantization\\\\_to\\\\_bit} - 1\n\n            & scale = (max\\\\_value - min\\\\_value) / denom\n\n            & zero\\\\_point = -min\\\\_value / scale\n\n    If per_layer_quantization is False, then the shape of scale and zero_point will be (input.shape[0],).\n\n    Args:\n        input(oneflow.Tensor):  the input value(s), in ``oneflow.float32``.\n        quantization_formula (str): Support \"google\" or \"cambricon\".\n        quantization_bit (int): Quantize input to uintX / intX, X can be in range [2, 8]. Defaults to 8.\n        quantization_scheme (str): \"symmetric\" or \"affine\", quantize to signed / unsigned integer. Defaults to \"symmetric\".\n        per_layer_quantization (bool): True or False, means per-layer / per-channel quantization. Defaults to True.\n\n    Returns:\n        Tuple[oneflow.Tensor, oneflow.Tensor]: The scale and zero_point of input tensor.\n\n    For example:\n\n    .. code-block:: python\n        \n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> weight = (np.random.random((2, 3, 4, 5)) - 0.5).astype(np.float32)\n        \n        >>> input_tensor = flow.tensor(\n        ...    weight, dtype=flow.float32\n        ... )\n        \n        >>> quantization_bit = 8\n        >>> quantization_scheme = \"symmetric\"\n        >>> quantization_formula = \"google\"\n        >>> per_layer_quantization = True\n\n        >>> min_max_observer = flow.nn.MinMaxObserver(quantization_formula=quantization_formula, quantization_bit=quantization_bit,\n        ... quantization_scheme=quantization_scheme, per_layer_quantization=per_layer_quantization)\n\n        >>> scale, zero_point = min_max_observer(\n        ...    
input_tensor, )\n\n    \"\"\"\n\n    def __init__(\n        self,\n        quantization_formula: str = \"google\",\n        quantization_bit: int = 8,\n        quantization_scheme: str = \"symmetric\",\n        per_layer_quantization: bool = True,\n    ) -> None:\n        super().__init__()\n        self.quantization_formula = quantization_formula\n        self.quantization_bit = quantization_bit\n        self.quantization_scheme = quantization_scheme\n        self.per_layer_quantization = per_layer_quantization\n\n    def forward(self, input):\n        return flow._C.min_max_observer(\n            input,\n            self.quantization_formula,\n            self.quantization_bit,\n            self.quantization_scheme,\n            self.per_layer_quantization,\n        )\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
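The scale/zero_point equations in the MinMaxObserver docstring are easy to evaluate directly; this NumPy sketch implements the documented \"google\" formulas for the per-layer case (an illustration of the equations, not the ``flow._C.min_max_observer`` kernel; ``min_max_observer_ref`` is a made-up name):

.. code-block:: python

    import numpy as np

    def min_max_observer_ref(x, quantization_bit=8, quantization_scheme="symmetric"):
        max_value, min_value = float(x.max()), float(x.min())
        if quantization_scheme == "symmetric":
            # denom = 2^(bit - 1) - 1; zero_point is fixed at 0
            denom = 2.0 ** (quantization_bit - 1) - 1
            return max(abs(max_value), abs(min_value)) / denom, 0.0
        # affine: denom = 2^bit - 1; zero_point shifts min_value onto 0
        denom = 2.0 ** quantization_bit - 1
        scale = (max_value - min_value) / denom
        return scale, -min_value / scale

    weight = (np.random.random((2, 3, 4, 5)) - 0.5).astype(np.float32)
    scale, zero_point = min_max_observer_ref(weight)
    print(scale, zero_point)  # scale near 0.5 / 127 and zero_point 0.0 for symmetric int8

For per-channel quantization (``per_layer_quantization=False``) the same equations are applied per slice along the first axis, which is what gives scale and zero_point the shape ``(input.shape[0],)``.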
  {
    "path": "python/oneflow/nn/modules/module.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport itertools\nfrom collections import OrderedDict, namedtuple\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n    Iterator,\n    List,\n    Optional,\n    Set,\n    Tuple,\n    TypeVar,\n    Union,\n    overload,\n)\nimport traceback\nimport functools\nimport weakref\nimport warnings\n\nimport numpy as np\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.nn.parameter import Parameter\nfrom contextlib import contextmanager\n\n\nclass _WrappedHook(object):\n    def __init__(self, hook: Callable, module: Optional[\"Module\"] = None):\n        self.hook: Callable = hook\n        functools.update_wrapper(self, hook)\n\n        self.with_module: bool = False\n\n        if module is not None:\n            self.module: weakref.ReferenceType[\"Module\"] = weakref.ref(module)\n            self.with_module = True\n\n    def __call__(self, *args: Any, **kwargs: Any) -> Any:\n        if self.with_module:\n            module = self.module()\n            if module is None:\n                raise RuntimeError(\"You are trying to call the hook of a dead Module!\")\n            return self.hook(module, *args, **kwargs)\n        return self.hook(*args, **kwargs)\n\n    def __getstate__(self) -> Dict:\n        result = {\"hook\": self.hook, \"with_module\": self.with_module}\n        if self.with_module:\n            result[\"module\"] = self.module()\n\n        return result\n\n    def __setstate__(self, state: Dict):\n        self.hook = state[\"hook\"]\n        self.with_module = state[\"with_module\"]\n\n        if self.with_module:\n            if state[\"module\"] is None:\n                raise RuntimeError(\n                    \"You are trying to revive the hook of a dead Module!\"\n                )\n            self.module = weakref.ref(state[\"module\"])\n\n\nclass _IncompatibleKeys(\n    namedtuple(\"IncompatibleKeys\", [\"missing_keys\", \"unexpected_keys\"])\n):\n    def __repr__(self):\n        if not self.missing_keys and (not self.unexpected_keys):\n            return \"<All keys matched successfully>\"\n        return super(_IncompatibleKeys, self).__repr__()\n\n    __str__ = __repr__\n\n\ndef _addindent(s_, numSpaces):\n    s = s_.split(\"\\n\")\n    if len(s) == 1:\n        return s_\n    first = s.pop(0)\n    s = [numSpaces * \" \" + line for line in s]\n    s = \"\\n\".join(s)\n    s = first + \"\\n\" + s\n    return s\n\n\nT = TypeVar(\"T\", bound=\"Module\")\n\n\nclass Module(object):\n    r\"\"\"Base class for all neural network modules.\n    \n    This class is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.Module.html.\n\n    Your models should also subclass this class.\n\n    Modules can also contain other Modules, allowing to nest them in\n    a tree structure. 
You can assign the submodules as regular attributes::\n\n        import oneflow.nn as nn\n        import oneflow.nn.functional as F\n\n        class Model(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv1 = nn.Conv2d(1, 20, 5)\n                self.conv2 = nn.Conv2d(20, 20, 5)\n\n            def forward(self, x):\n                x = F.relu(self.conv1(x))\n                return F.relu(self.conv2(x))\n\n    Submodules assigned in this way will be registered, and will have their\n    parameters converted too when you call :meth:`to`, etc.\n\n    .. note::\n        As per the example above, an ``__init__()`` call to the parent class\n        must be made before assignment on the child.\n\n    :ivar training: Boolean represents whether this module is in training or\n                    evaluation mode.\n    :vartype training: bool\n    \"\"\"\n\n    def __init__(self):\n        \"\"\"\n        Calls super().__setattr__('a', a) instead of the typical self.a = a\n        to avoid Module.__setattr__ overhead. Module's __setattr__ has special\n        handling for parameters, submodules, and buffers but simply calls into\n        super().__setattr__ for all other attributes.\n        \"\"\"\n        super().__setattr__(\"training\", True)\n        super().__setattr__(\"_parameters\", OrderedDict())\n        super().__setattr__(\"_buffers\", OrderedDict())\n        super().__setattr__(\"_non_persistent_buffers_set\", set())\n        super().__setattr__(\"_backward_hooks\", OrderedDict())\n        super().__setattr__(\"_is_full_backward_hook\", None)\n        super().__setattr__(\"_forward_hooks\", OrderedDict())\n        super().__setattr__(\"_forward_pre_hooks\", OrderedDict())\n        super().__setattr__(\"_state_dict_hooks\", OrderedDict())\n        # _state_dict_pre_hooks is required by register_state_dict_pre_hook below\n        super().__setattr__(\"_state_dict_pre_hooks\", OrderedDict())\n        super().__setattr__(\"_load_state_dict_pre_hooks\", OrderedDict())\n        super().__setattr__(\"_modules\", OrderedDict())\n        super().__setattr__(\"_is_ddp_module\", False)\n        super().__setattr__(\"_oneflow_internal_module_tensor_applied_dict__\", None)\n        super().__setattr__(\"cpg\", None)\n\n    def __getstate__(self):\n        if not self._is_ddp_module:\n            if (\n                len(self._backward_hooks) > 0\n                or len(self._forward_hooks) > 0\n                or len(self._forward_pre_hooks) > 0\n                or len(self._state_dict_hooks) > 0\n                or len(self._load_state_dict_pre_hooks) > 0\n            ):\n                warnings.warn(\"The module hooks will not be retained after serialization\")\n\n        state = self.__dict__.copy()\n        del state[\"_backward_hooks\"]\n        del state[\"_forward_hooks\"]\n        del state[\"_forward_pre_hooks\"]\n        del state[\"_state_dict_hooks\"]\n        del state[\"_load_state_dict_pre_hooks\"]\n        del state[\"_is_full_backward_hook\"]\n        del state[\"_non_persistent_buffers_set\"]\n        return state\n\n    def __setstate__(self, state):\n        self.__dict__.update(state)\n        self._backward_hooks = OrderedDict()\n        self._forward_hooks = OrderedDict()\n        self._forward_pre_hooks = OrderedDict()\n        self._state_dict_hooks = OrderedDict()\n        self._load_state_dict_pre_hooks = OrderedDict()\n        self._is_full_backward_hook = None\n        self._non_persistent_buffers_set = set()\n        if hasattr(self, \"_is_ddp_module\") and self._is_ddp_module:\n            # flow.nn.parallel.DistributedDataParallel updates the module inplace\n       
     flow.nn.parallel.DistributedDataParallel(self, broadcast_parameters=False)\n\n    def forward(self, *args, **kwargs):\n        raise NotImplementedError()\n\n    def __call__(self, *args, **kwargs):\n        if flow._oneflow_internal.lazy_mode.is_enabled():\n            warnings.warn(\n                self._shallow_repr()\n                + \" is called in a nn.Graph, but not registered into a nn.Graph.\"\n            )\n\n        full_backward_hooks, non_full_backward_hooks = [], []\n        if self._backward_hooks:\n            full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()\n\n        for hook in list(self._forward_pre_hooks.values()):\n            result = hook(self, args)\n            if result is not None:\n                if not isinstance(result, tuple):\n                    result = (result,)\n                args = result\n\n        bw_hook = None\n        if full_backward_hooks:\n            bw_hook = flow.utils.hooks.BackwardHook(self, full_backward_hooks, [])\n            args = bw_hook.setup_input_hook(args)\n\n        res = self.forward(*args, **kwargs)\n        for hook in list(self._forward_hooks.values()):\n            result = hook(self, args, res)\n            if result is not None:\n                res = result\n\n        if bw_hook is not None:\n            res = bw_hook.setup_output_hook(res)\n\n        if non_full_backward_hooks:\n            var = res\n            while not isinstance(var, Tensor):\n                if isinstance(var, dict):\n                    var = next((v for v in var.values() if isinstance(v, Tensor)))\n                else:\n                    var = var[0]\n            grad_fn = var.grad_fn\n\n            if grad_fn is not None:\n                self._maybe_warn_non_full_backward_hook(args, res, grad_fn)\n                for hook in non_full_backward_hooks:\n                    wrapper = functools.partial(hook, self)\n                    functools.update_wrapper(wrapper, hook)\n                    grad_fn.register_hook(wrapper)\n\n        return res\n\n    def add_module(self, name: str, module: Optional[\"Module\"]) -> None:\n        r\"\"\"\n        add_module(name, module)\n        \n        Adds a child module to the current module.\n\n        The module can be accessed as an attribute using the given name.\n\n        Args:\n            name (string): name of the child module. The child module can be\n                accessed from this module using the given name\n            module (Module): child module to be added to the module.\n        \"\"\"\n        if not isinstance(module, Module) and module is not None:\n            raise TypeError(\"{} is not a Module subclass\".format(type(module)))\n        elif not isinstance(name, str):\n            raise TypeError(\"module name should be a string. 
Got {}\".format(type(name)))\n        elif hasattr(self, name) and name not in self._modules:\n            raise KeyError(\"attribute '{}' already exists\".format(name))\n        elif \".\" in name:\n            raise KeyError('module name can\\'t contain \".\", got: {}'.format(name))\n        elif name == \"\":\n            raise KeyError('module name can\\'t be empty string \"\"')\n        self._modules[name] = module\n\n    def register_buffer(\n        self, name: str, tensor: Optional[Tensor], persistent: bool = True\n    ) -> None:\n        r\"\"\"\n        register_buffer(name, tensor, persistent=True)\n        \n        Adds a buffer to the module.\n\n        This is typically used to register a buffer that should not to be\n        considered a model parameter. For example, BatchNorm's ``running_mean``\n        is not a parameter, but is part of the module's state. Buffers, by\n        default, are persistent and will be saved alongside parameters. This\n        behavior can be changed by setting :attr:`persistent` to ``False``. The\n        only difference between a persistent buffer and a non-persistent buffer\n        is that the latter will not be a part of this module's\n        :attr:`state_dict`.\n\n        Buffers can be accessed as attributes using given names.\n\n        Args:\n            name (string): name of the buffer. The buffer can be accessed\n                from this module using the given name\n            tensor (Tensor or None): buffer to be registered. If ``None``, then operations\n                that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,\n                the buffer is **not** included in the module's :attr:`state_dict`.\n            persistent (bool): whether the buffer is part of this module's\n                :attr:`state_dict`.\n                \n        Example::\n\n            >>> self.register_buffer('running_mean', oneflow.zeros(num_features)) # doctest: +SKIP\n        \"\"\"\n        if \"_buffers\" not in self.__dict__:\n            raise AttributeError(\"cannot assign buffer before Module.__init__() call\")\n        elif not isinstance(name, str):\n            raise TypeError(\"buffer name should be a string. Got {}\".format(type(name)))\n        elif \".\" in name:\n            raise KeyError('buffer name can\\'t contain \".\"')\n        elif name == \"\":\n            raise KeyError('buffer name can\\'t be empty string \"\"')\n        elif hasattr(self, name) and name not in self._buffers:\n            raise KeyError(\"attribute '{}' already exists\".format(name))\n        elif tensor is not None and (not isinstance(tensor, Tensor)):\n            raise TypeError(\n                \"cannot assign '{}' object to buffer '{}' (Tensor or None required)\".format(\n                    type(tensor), name\n                )\n            )\n        else:\n            self._buffers[name] = tensor\n            if persistent:\n                self._non_persistent_buffers_set.discard(name)\n            else:\n                self._non_persistent_buffers_set.add(name)\n\n    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:\n        r\"\"\"\n        register_parameter(name, param)\n        \n        Adds a parameter to the module.\n\n        The parameter can be accessed as an attribute using given name.\n\n        Args:\n            name (string): name of the parameter. 
The parameter can be accessed\n                from this module using the given name\n            param (Parameter or None): parameter to be added to the module. If\n                ``None``, then operations that run on parameters, such as :attr:`cuda`,\n                are ignored. If ``None``, the parameter is **not** included in the\n                module's :attr:`state_dict`.\n        \"\"\"\n        if \"_parameters\" not in self.__dict__:\n            raise AttributeError(\n                \"cannot assign parameter before Module.__init__() call\"\n            )\n        elif not isinstance(name, str):\n            raise TypeError(\n                \"parameter name should be a string. Got {}\".format(type(name))\n            )\n        elif \".\" in name:\n            raise KeyError('parameter name can\\'t contain \".\"')\n        elif name == \"\":\n            raise KeyError('parameter name can\\'t be empty string \"\"')\n        elif hasattr(self, name) and name not in self._parameters:\n            raise KeyError(\"attribute '{}' already exists\".format(name))\n        if param is None:\n            self._parameters[name] = None\n        elif not isinstance(param, Parameter):\n            raise TypeError(\n                \"cannot assign '{}' object to parameter '{}' (nn.Parameter or None required)\".format(\n                    type(param), name\n                )\n            )\n        else:\n            self._parameters[name] = param\n\n    def _register_state_dict_hook(self, hook):\n        r\"\"\"These hooks will be called with arguments: `self`, `state_dict`,\n        `prefix`, `local_metadata`, after the `state_dict` of `self` is set.\n        Note that only parameters and buffers of `self` or its children are\n        guaranteed to exist in `state_dict`. The hooks may modify `state_dict`\n        inplace or return a new one.\n\n        .. note:\n            Do not use `module.state_dict()` in _register_state_dict_hook function\n        \"\"\"\n        handle = flow.utils.hooks.RemovableHandle(self._state_dict_hooks)\n        self._state_dict_hooks[handle.id] = hook\n        return handle\n\n    def _register_load_state_dict_pre_hook(\n        self, hook: Callable[..., None], with_module=False\n    ):\n        r\"\"\"These hooks will be called with arguments: `state_dict`, `prefix`,\n        `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`,\n        `error_msgs`, before loading `state_dict` into `self`. These arguments\n        are exactly the same as those of `_load_from_state_dict`.\n\n        If ``with_module`` is ``True``, then the first argument to the hook is\n        an instance of the module.\n\n        Arguments:\n            hook (Callable): Callable hook that will be invoked before\n                loading the state dict.\n            with_module (bool, optional): Whether or not to pass the module\n                instance to the hook as the first parameter.\n        \"\"\"\n        handle = flow.utils.hooks.RemovableHandle(self._load_state_dict_pre_hooks)\n        self._load_state_dict_pre_hooks[handle.id] = _WrappedHook(\n            hook, self if with_module else None\n        )\n        return handle\n\n    def register_state_dict_pre_hook(self, hook):\n        r\"\"\"These hooks will be called with arguments: ``self``, ``prefix``,\n        and ``keep_vars`` before calling ``state_dict`` on ``self``. 
The registered\n        hooks can be used to perform pre-processing before the ``state_dict``\n        call is made.\n        \"\"\"\n        handle = flow.utils.hooks.RemovableHandle(self._state_dict_pre_hooks)\n        self._state_dict_pre_hooks[handle.id] = hook\n        return handle\n\n    def __getattr__(self, name: str) -> Union[Tensor, \"Module\"]:\n        if \"_parameters\" in self.__dict__:\n            _parameters = self.__dict__[\"_parameters\"]\n            if name in _parameters:\n                return _parameters[name]\n        if \"_buffers\" in self.__dict__:\n            _buffers = self.__dict__[\"_buffers\"]\n            if name in _buffers:\n                return _buffers[name]\n        if \"_modules\" in self.__dict__:\n            modules = self.__dict__[\"_modules\"]\n            if name in modules:\n                return modules[name]\n        raise AttributeError(\n            \"'{}' object has no attribute '{}'\".format(type(self).__name__, name)\n        )\n\n    def __setattr__(self, name: str, value: Union[Tensor, \"Module\"]) -> None:\n        def remove_from(*dicts_or_sets):\n            for d in dicts_or_sets:\n                if name in d:\n                    if isinstance(d, dict):\n                        del d[name]\n                    else:\n                        d.discard(name)\n\n        params = self.__dict__.get(\"_parameters\")\n        if isinstance(value, Parameter):\n            if params is None:\n                raise AttributeError(\n                    \"cannot assign parameters before Module.__init__() call\"\n                )\n            remove_from(\n                self.__dict__,\n                self._buffers,\n                self._modules,\n                self._non_persistent_buffers_set,\n            )\n            self.register_parameter(name, value)\n        elif params is not None and name in params:\n            if value is not None:\n                raise TypeError(\n                    \"cannot assign '{}' as parameter '{}' (nn.Parameter or None expected)\".format(\n                        type(value), name\n                    )\n                )\n            self.register_parameter(name, value)\n        else:\n            modules = self.__dict__.get(\"_modules\")\n            if isinstance(value, Module):\n                if modules is None:\n                    raise AttributeError(\n                        \"cannot assign module before Module.__init__() call\"\n                    )\n                remove_from(\n                    self.__dict__,\n                    self._parameters,\n                    self._buffers,\n                    self._non_persistent_buffers_set,\n                )\n                modules[name] = value\n            elif modules is not None and name in modules:\n                if value is not None:\n                    raise TypeError(\n                        \"cannot assign '{}' as child module '{}' (nn.Module or None expected)\".format(\n                            type(value), name\n                        )\n                    )\n                modules[name] = value\n            else:\n                buffers = self.__dict__.get(\"_buffers\")\n                if buffers is not None and name in buffers:\n                    if value is not None and (not isinstance(value, Tensor)):\n                        raise TypeError(\n                            \"cannot assign '{}' as buffer '{}' (Tensor or None expected)\".format(\n                                type(value), name\n     
                       )\n                        )\n                    buffers[name] = value\n                else:\n                    object.__setattr__(self, name, value)\n\n    def __delattr__(self, name):\n        if name in self._parameters:\n            del self._parameters[name]\n        elif name in self._buffers:\n            del self._buffers[name]\n            self._non_persistent_buffers_set.discard(name)\n        elif name in self._modules:\n            del self._modules[name]\n        else:\n            super().__delattr__(name)\n\n    def _named_members(self, get_members_fn, prefix=\"\", recurse=True):\n        memo = set()\n        modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]\n        for (module_prefix, module) in modules:\n            members = get_members_fn(module)\n            for (k, v) in members:\n                if v is None or v in memo:\n                    continue\n                memo.add(v)\n                name = module_prefix + (\".\" if module_prefix else \"\") + k\n                yield (name, v)\n\n    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:\n        r\"\"\"\n        parameters(recurse=True) -> Iterator[Parameter]\n        \n        Returns an iterator over module parameters.\n\n        This is typically passed to an optimizer.\n\n        Args:\n            recurse (bool): if True, then yields parameters of this module\n                and all submodules. Otherwise, yields only parameters that\n                are direct members of this module.\n\n        Yields:\n            Parameter: module parameter\n\n        Example::\n\n            >>> for param in model.parameters(): # doctest: +SKIP\n            ...     print(type(param), param.size()) # doctest: +SKIP\n            <class 'oneflow.Tensor'> oneflow.Size([10])\n\n        \"\"\"\n        for (name, param) in self.named_parameters(recurse=recurse):\n            yield param\n\n    def named_parameters(\n        self, prefix: str = \"\", recurse: bool = True\n    ) -> Iterator[Tuple[str, Tensor]]:\n        r\"\"\"\n        named_parameters(prefix=\"\", recurse=True) -> Iterator[Tuple[str, Tensor]]\n        \n        Returns an iterator over module parameters, yielding both the\n        name of the parameter as well as the parameter itself.\n\n        Args:\n            prefix (str): prefix to prepend to all parameter names.\n            recurse (bool): if True, then yields parameters of this module\n                and all submodules. Otherwise, yields only parameters that\n                are direct members of this module.\n\n        Yields:\n            (string, Parameter): Tuple containing the name and parameter\n\n        Example::\n\n            >>> for name, param in self.named_parameters(): # doctest: +SKIP\n            ...    if name in ['bias']: # doctest: +SKIP\n            ...        print(param.size()) # doctest: +SKIP\n\n        \"\"\"\n        gen = self._named_members(\n            lambda module: module._parameters.items(), prefix=prefix, recurse=recurse\n        )\n        for elem in gen:\n            yield elem\n\n    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:\n        r\"\"\"\n        buffers(recurse=True) -> Iterator[Tensor]\n        \n        Returns an iterator over module buffers.\n\n        Args:\n            recurse (bool): if True, then yields buffers of this module\n                and all submodules. 
Otherwise, yields only buffers that\n                are direct members of this module.\n\n        Yields:\n            oneflow.Tensor: module buffer\n\n        Example::\n\n            >>> for buf in model.buffers(): # doctest: +SKIP\n            ...     print(type(buf), buf.size()) # doctest: +SKIP\n            <class 'oneflow.Tensor'> oneflow.Size([10])\n\n        \"\"\"\n        for (name, buf) in self.named_buffers(recurse=recurse):\n            yield buf\n\n    def named_buffers(\n        self, prefix: str = \"\", recurse: bool = True\n    ) -> Iterator[Tuple[str, Tensor]]:\n        r\"\"\"\n        named_buffers(prefix=\"\", recurse=True) -> Iterator[Tuple[str, Tensor]]\n        \n        Returns an iterator over module buffers, yielding both the\n        name of the buffer as well as the buffer itself.\n\n        Args:\n            prefix (str): prefix to prepend to all buffer names.\n            recurse (bool): if True, then yields buffers of this module\n                and all submodules. Otherwise, yields only buffers that\n                are direct members of this module.\n\n        Yields:\n            (string, oneflow.Tensor): Tuple containing the name and buffer\n\n        Example::\n\n            >>> for name, buf in self.named_buffers(): # doctest: +SKIP\n            ...    if name in ['running_var']: # doctest: +SKIP\n            ...        print(buf.size()) # doctest: +SKIP\n\n        \"\"\"\n        gen = self._named_members(\n            lambda module: module._buffers.items(), prefix=prefix, recurse=recurse\n        )\n        for elem in gen:\n            yield elem\n\n    def children(self) -> Iterator[\"Module\"]:\n        r\"\"\"\n        children() -> Iterator[\"Module\"]\n        \n        Returns an iterator over immediate children modules.\n\n        Yields:\n            Module: a child module\n            \n        Example::\n\n            >>> import oneflow.nn as nn\n            >>> l1 = nn.Linear(2, 2)\n            >>> l2 = nn.Linear(2, 2)\n            >>> net = nn.Sequential(l1, l2)\n            >>> for idx, m in enumerate(net.children()):\n            ...     print(idx, '->', m)\n            0 -> Linear(in_features=2, out_features=2, bias=True)\n            1 -> Linear(in_features=2, out_features=2, bias=True)\n\n        \"\"\"\n        for (name, module) in self.named_children():\n            yield module\n\n    def named_children(self) -> Iterator[Tuple[str, \"Module\"]]:\n        r\"\"\"\n        named_children() -> Iterator[Tuple[str, \"Module\"]]\n        \n        Returns an iterator over immediate children modules, yielding both\n        the name of the module as well as the module itself.\n\n        Yields:\n            (string, Module): Tuple containing a name and child module\n\n        Example::\n\n            >>> for name, module in model.named_children(): # doctest: +SKIP\n            ...     if name in ['conv4', 'conv5']: # doctest: +SKIP\n            ...         print(module) # doctest: +SKIP\n\n        \"\"\"\n        memo = set()\n        for (name, module) in self._modules.items():\n            if module is not None and module not in memo:\n                memo.add(module)\n                yield (name, module)\n\n    def modules(self) -> Iterator[\"Module\"]:\n        r\"\"\"\n        modules() -> Iterator[\"Module\"]\n        \n        Returns an iterator over all modules in the network.\n\n        Yields:\n            Module: a module in the network\n\n        Note:\n            Duplicate modules are returned only once. 
In the following\n            example, ``l`` will be returned only once.\n\n        Example::\n\n            >>> import oneflow.nn as nn\n            >>> l = nn.Linear(2, 2)\n            >>> net = nn.Sequential(l, l)\n            >>> for idx, m in enumerate(net.modules()):\n            ...     print(idx, '->', m)\n            0 -> Sequential(\n              (0): Linear(in_features=2, out_features=2, bias=True)\n              (1): Linear(in_features=2, out_features=2, bias=True)\n            )\n            1 -> Linear(in_features=2, out_features=2, bias=True)\n\n        \"\"\"\n        for (name, module) in self.named_modules():\n            yield module\n\n    def named_modules(self, memo: Optional[Set[\"Module\"]] = None, prefix: str = \"\"):\n        r\"\"\"\n        named_modules(memo=None, prefix=\"\")\n        \n        Returns an iterator over all modules in the network, yielding\n        both the name of the module as well as the module itself.\n\n        Args:\n            memo: a memo to store the set of modules already added to the result\n            prefix: a prefix that will be added to the name of the module\n\n        Yields:\n            (string, Module): Tuple of name and module\n\n        Note:\n            Duplicate modules are returned only once. In the following\n            example, ``l`` will be returned only once.\n\n        Example::\n\n            >>> import oneflow.nn as nn\n            >>> l = nn.Linear(2, 2)\n            >>> net = nn.Sequential(l, l)\n            >>> for idx, m in enumerate(net.named_modules()):\n            ...     print(idx, '->', m)\n            0 -> ('', Sequential(\n              (0): Linear(in_features=2, out_features=2, bias=True)\n              (1): Linear(in_features=2, out_features=2, bias=True)\n            ))\n            1 -> ('0', Linear(in_features=2, out_features=2, bias=True))\n\n        \"\"\"\n        if memo is None:\n            memo = set()\n        if self not in memo:\n            memo.add(self)\n            yield (prefix, self)\n            for (name, module) in self._modules.items():\n                if module is None:\n                    continue\n                submodule_prefix = prefix + (\".\" if prefix else \"\") + name\n                for m in module.named_modules(memo, submodule_prefix):\n                    yield m\n\n    def train(self: T, mode: bool = True) -> T:\n        r\"\"\"\n        train(mode=True)\n        \n        Sets the module in training mode.\n\n        This has any effect only on certain modules. See documentations of\n        particular modules for details of their behaviors in training/evaluation\n        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm1d`,\n        etc.\n\n        Args:\n            mode (bool): whether to set training mode (``True``) or evaluation\n                         mode (``False``). Default: ``True``.\n\n        Returns:\n            Module: self\n        \"\"\"\n        self.training = mode\n        for module in self.children():\n            module.train(mode)\n        return self\n\n    def eval(self: T) -> T:\n        r\"\"\"\n        eval()\n        \n        Sets the module in evaluation mode.\n\n        This has any effect only on certain modules. See documentations of\n        particular modules for details of their behaviors in training/evaluation\n        mode, if they are affected, e.g. 
:class:`Dropout`, :class:`BatchNorm1d`,\n        etc.\n\n        This is equivalent with :meth:`self.train(False) <oneflow.nn.Module.train>`.\n\n        Returns:\n            Module: self\n        \"\"\"\n        return self.train(False)\n\n    def requires_grad_(self: T, requires_grad: bool = True) -> T:\n        r\"\"\"Change if autograd should record operations on parameters in this\n        module.\n        The interface is consistent with PyTorch.\n        The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.Module.html?highlight=requires_grad_#torch.nn.Module.requires_grad_.\n\n        This method sets the parameters' :attr:`requires_grad` attributes\n        in-place.\n\n        This method is helpful for freezing part of the module for finetuning\n        or training parts of a model individually (e.g., GAN training).\n\n        Args:\n            requires_grad (bool): whether autograd should record operations on\n                                  parameters in this module. Default: ``True``.\n\n        Returns:\n            Module: self\n        \"\"\"\n        for p in self.parameters():\n            p.requires_grad_(requires_grad)\n        return self\n\n    def zero_grad(self, set_to_none: bool = False) -> None:\n        r\"\"\"\n        zero_grad(set_to_none=False)\n        \n        Sets gradients of all model parameters to zero. See similar function\n        under :class:`oneflow.optim.Optimizer` for more context.\n\n        Args:\n            set_to_none (bool): instead of setting to zero, set the grads to None.\n                See :meth:`oneflow.optim.Optimizer.zero_grad` for details.\n        \"\"\"\n        if getattr(self, \"_is_replica\", False):\n            warnings.warn(\n                \"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. \"\n                \"The parameters are copied (in a differentiable manner) from the original module. \"\n                \"This means they are not leaf nodes in autograd and so don't accumulate gradients. 
\"\n                \"If you need gradients in your forward method, consider using autograd.grad instead.\"\n            )\n\n        for p in self.parameters():\n            if p.grad is not None:\n                if set_to_none:\n                    p.grad = None\n                else:\n                    if p.grad.grad_fn is not None:\n                        p.grad.detach_()\n                    else:\n                        p.grad.requires_grad_(False)\n                    p.grad.zero_()\n\n    def _save_to_state_dict(self, destination, prefix, keep_vars):\n        for (name, param) in self._parameters.items():\n            if param is not None:\n                destination[prefix + name] = param\n        for (name, buf) in self._buffers.items():\n            if buf is not None and name not in self._non_persistent_buffers_set:\n                destination[prefix + name] = buf\n\n    def _load_from_state_dict(\n        self,\n        state_dict,\n        prefix,\n        local_metadata,\n        strict,\n        missing_keys,\n        unexpected_keys,\n        error_msgs,\n    ):\n        for hook in self._load_state_dict_pre_hooks.values():\n            hook(\n                state_dict,\n                prefix,\n                local_metadata,\n                strict,\n                missing_keys,\n                unexpected_keys,\n                error_msgs,\n            )\n        persistent_buffers = {\n            k: v\n            for (k, v) in self._buffers.items()\n            if k not in self._non_persistent_buffers_set\n        }\n        local_name_params = itertools.chain(\n            self._parameters.items(), persistent_buffers.items()\n        )\n        local_state = {k: v for (k, v) in local_name_params if v is not None}\n        for (name, param) in local_state.items():\n            key = prefix + name\n            if key in state_dict:\n                input_param = state_dict[key]\n                if tuple(input_param.shape) != tuple(param.shape):\n                    error_msgs.append(\n                        \"size mismatch for {}: copying a param with shape {} from checkpoint, the shape in current model is {}.\".format(\n                            key, input_param.shape, param.shape\n                        )\n                    )\n                    continue\n                if (\n                    isinstance(input_param, Tensor)\n                    and input_param.is_global != param.is_global\n                ):\n                    if param.is_global:\n                        help_msg = \"Maybe you need to convert the checkpoint param to global, or set global_src_rank=0 when using flow.load to load model's state_dict\"\n                    else:\n                        help_msg = \"Maybe you need to convert your model to global.\"\n                    error_msgs.append(\n                        'local / global mismatch for \"{}\":  param from checkpoint is {} tensor, but the param in current model is {} tensor. 
{}'.format(\n                            key,\n                            \"global\" if input_param.is_global else \"local\",\n                            \"global\" if param.is_global else \"local\",\n                            help_msg,\n                        )\n                    )\n                    continue\n\n                try:\n                    with flow.no_grad():\n                        param.copy_(input_param)\n                except Exception as ex:\n                    error_msgs.append(\n                        'While copying the parameter \"{}\", an exception occurred : \\n\\n{}.'.format(\n                            key,\n                            \"\".join(\n                                map(\n                                    lambda line: \"\\t\" + line,\n                                    traceback.format_exc().splitlines(True),\n                                )\n                            ),\n                        )\n                    )\n            elif strict:\n                missing_keys.append(key)\n        if strict:\n            for key in state_dict.keys():\n                if key.startswith(prefix):\n                    input_name = key[len(prefix) :]\n                    input_name = input_name.split(\".\", 1)[0]\n                    if (\n                        input_name not in self._modules\n                        and input_name not in local_state\n                    ):\n                        unexpected_keys.append(key)\n\n    def load_state_dict(\n        self,\n        state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]],\n        strict: bool = True,\n    ):\n        r\"\"\"\n        load_state_dict(state_dict, strict=True)\n        \n        Copies parameters and buffers from :attr:`state_dict` into\n        this module and its descendants. If :attr:`strict` is ``True``, then\n        the keys of :attr:`state_dict` must exactly match the keys returned\n        by this module's :meth:`~oneflow.nn.Module.state_dict` function.\n\n        Args:\n            state_dict (dict): a dict containing parameters and\n                persistent buffers.\n            strict (bool, optional): whether to strictly enforce that the keys\n                in :attr:`state_dict` match the keys returned by this module's\n                :meth:`~oneflow.nn.Module.state_dict` function. 
Default: ``True``\n\n        Returns:\n            ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:\n                * **missing_keys** is a list of str containing the missing keys\n                * **unexpected_keys** is a list of str containing the unexpected keys\n\n        Note:\n            If a parameter or buffer is registered as ``None`` and its corresponding key\n            exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a\n            ``RuntimeError``.\n        \"\"\"\n        missing_keys = []\n        unexpected_keys = []\n        error_msgs = []\n        metadata = getattr(state_dict, \"_metadata\", None)\n        state_dict = state_dict.copy()\n        if metadata is not None:\n            state_dict._metadata = metadata\n\n        def load(module, prefix=\"\"):\n            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})\n            module._load_from_state_dict(\n                state_dict,\n                prefix,\n                local_metadata,\n                True,\n                missing_keys,\n                unexpected_keys,\n                error_msgs,\n            )\n            for (name, child) in module._modules.items():\n                if child is not None:\n                    load(child, prefix + name + \".\")\n\n        load(self)\n        load = None\n        if strict:\n            if len(unexpected_keys) > 0:\n                error_msgs.insert(\n                    0,\n                    \"Unexpected key(s) in state_dict: {}. \".format(\n                        \", \".join(('\"{}\"'.format(k) for k in unexpected_keys))\n                    ),\n                )\n            if len(missing_keys) > 0:\n                error_msgs.insert(\n                    0,\n                    \"Missing key(s) in state_dict: {}. \".format(\n                        \", \".join(('\"{}\"'.format(k) for k in missing_keys))\n                    ),\n                )\n        if len(error_msgs) > 0:\n            raise RuntimeError(\n                \"Error(s) in loading state_dict for {}:\\n\\t{}\".format(\n                    self.__class__.__name__, \"\\n\\t\".join(error_msgs)\n                )\n            )\n        return _IncompatibleKeys(missing_keys, unexpected_keys)\n\n    def state_dict(\n        self, destination=None, prefix=\"\", keep_vars=False\n    ) -> Dict[str, Tensor]:\n        r\"\"\"\n        state_dict(destination=None, prefix=\"\", keep_vars=False) -> Dict[str, Tensor]\n        \n        Returns a dictionary containing a whole state of the module.\n\n        Both parameters and persistent buffers (e.g. running averages) are\n        included. Keys are corresponding parameter and buffer names.\n        Parameters and buffers set to ``None`` are not included.\n\n        Args:\n            destination (dict, optional): Deprecated. This dict is returned\n                with the module state saved in it. It should also have an\n                attribute ``_metadata: dict`` to save metadata of the module\n                state. If it's not provided, an ``OrderedDict`` is created and\n                returned. Default: ``None``\n            prefix (str, optional): a prefix added to parameter and buffer\n                names to compose the keys in dict. Default: ``''``\n            keep_vars (bool, optional): by default the :class:`~oneflow.Tensor` s\n                returned in the state dict are detached from autograd. If it's\n                set to ``True``, detaching is not performed. 
Default: ``False``\n\n        Returns:\n            dict:\n                a dictionary containing a whole state of the module\n\n        Example::\n\n            >>> import oneflow.nn as nn\n            >>> l1 = nn.Linear(2, 2)\n            >>> l2 = nn.Linear(2, 2)\n            >>> net = nn.Sequential(l1, l2)\n            >>> net.state_dict().keys()\n            odict_keys(['0.weight', '0.bias', '1.weight', '1.bias'])\n\n        \"\"\"\n        if destination is None:\n            destination = OrderedDict()\n            destination._metadata = OrderedDict()\n\n        # TODO(hujiakui): add _version for nn.Module\n        local_metadata = dict(version=1)\n        if hasattr(destination, \"_metadata\"):\n            destination._metadata[prefix[:-1]] = local_metadata\n        self._save_to_state_dict(destination, prefix, keep_vars)\n        for (name, module) in self._modules.items():\n            if module is not None:\n                module.state_dict(destination, prefix + name + \".\", keep_vars=keep_vars)\n        for hook in self._state_dict_hooks.values():\n            hook_result = hook(self, destination, prefix, local_metadata)\n            if hook_result is not None:\n                destination = hook_result\n        return destination\n\n    _grad_t = Union[Tuple[Tensor, ...], Tensor]\n\n    def register_backward_hook(\n        self, hook: Callable[[\"Module\", _grad_t, _grad_t], Union[None, Tensor]]\n    ):\n        r\"\"\"Registers a backward hook on the module.\n\n        This function is deprecated in favor of :meth:`~oneflow.nn.Module.register_full_backward_hook` and\n        the behavior of this function will change in future versions.\n\n        Returns:\n            :class:`oneflow.utils.hooks.RemovableHandle`:\n                a handle that can be used to remove the added hook by calling\n                ``handle.remove()``\n\n        \"\"\"\n        if self._is_full_backward_hook is True:\n            raise RuntimeError(\n                \"Cannot use both regular backward hooks and full backward hooks on a \"\n                \"single Module. Please use only one of them.\"\n            )\n\n        self._is_full_backward_hook = False\n\n        handle = flow.utils.hooks.RemovableHandle(self._backward_hooks)\n        self._backward_hooks[handle.id] = hook\n        return handle\n\n    def register_full_backward_hook(\n        self, hook: Callable[[\"Module\", _grad_t, _grad_t], Union[None, Tensor]],\n    ):\n        r\"\"\"Registers a backward hook on the module.\n\n        The hook will be called every time the gradients with respect to module\n        inputs are computed. The hook should have the following signature::\n\n            hook(module, grad_input, grad_output) -> TensorTuple or None\n\n        The :attr:`grad_input` and :attr:`grad_output` are :class:`oneflow.TensorTuple` that contain the gradients\n        with respect to the inputs and outputs respectively. The hook should\n        not modify its arguments, but it can optionally return a new gradient with\n        respect to the input that will be used in place of :attr:`grad_input` in\n        subsequent computations. :attr:`grad_input` will only correspond to the inputs given\n        as positional arguments and all kwarg arguments are ignored. 
Entries\n        in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor\n        arguments.\n\n        For technical reasons, when this hook is applied to a Module, its forward function will\n        receive a view of each Tensor passed to the Module. Similarly the caller will receive a view\n        of each Tensor returned by the Module's forward function.\n\n        .. warning ::\n            Modifying inputs or outputs inplace is not allowed when using backward hooks and\n            will raise an error.\n\n        Returns:\n            :class:`oneflow.utils.hooks.RemovableHandle`:\n                a handle that can be used to remove the added hook by calling\n                ``handle.remove()``\n\n        \"\"\"\n        if self._is_full_backward_hook is False:\n            raise RuntimeError(\n                \"Cannot use both regular backward hooks and full backward hooks on a \"\n                \"single Module. Please use only one of them.\"\n            )\n        self._is_full_backward_hook = True\n\n        handle = flow.utils.hooks.RemovableHandle(self._backward_hooks)\n        self._backward_hooks[handle.id] = hook\n        return handle\n\n    def _get_backward_hooks(self):\n        r\"\"\"Returns the backward hooks for use in the call function.\n        It returns two lists, one with the full backward hooks and one with the non-full\n        backward hooks.\n        \"\"\"\n        full_backward_hooks: List[Callable] = []\n        if self._is_full_backward_hook is True:\n            full_backward_hooks += self._backward_hooks.values()\n\n        non_full_backward_hooks: List[Callable] = []\n        if self._is_full_backward_hook is False:\n            non_full_backward_hooks += self._backward_hooks.values()\n\n        return full_backward_hooks, non_full_backward_hooks\n\n    def _maybe_warn_non_full_backward_hook(self, args, res, grad_fn):\n        if not isinstance(res, Tensor):\n            if not (\n                isinstance(res, tuple) and all([isinstance(r, Tensor) for r in res])\n            ):\n                warnings.warn(\n                    \"Using non-full backward hooks on a Module that does not return a \"\n                    \"single Tensor or a tuple of Tensors is deprecated and will be removed \"\n                    \"in future versions. This hook will be missing some of the grad_output. \"\n                    \"Please use register_full_backward_hook to get the documented behavior.\"\n                )\n                return\n        else:\n            res = (res,)\n\n        if not isinstance(args, Tensor):\n            if not (\n                isinstance(args, tuple) and all([isinstance(i, Tensor) for i in args])\n            ):\n                warnings.warn(\n                    \"Using non-full backward hooks on a Module that does not take as input a \"\n                    \"single Tensor or a tuple of Tensors is deprecated and will be removed \"\n                    \"in future versions. This hook will be missing some of the grad_input. 
\"\n                    \"Please use register_full_backward_hook to get the documented behavior.\"\n                )\n                return\n        else:\n            args = (args,)\n\n        # At this point we are sure that args and res are tuples of Tensors\n        out_grad_fn = {r.grad_fn for r in res if r.grad_fn is not None}\n        if len(out_grad_fn) == 0 or (\n            len(out_grad_fn) == 1 and grad_fn not in out_grad_fn\n        ):\n            warnings.warn(\n                \"Using a non-full backward hook when outputs are nested in python data structure \"\n                \"is deprecated and will be removed in future versions. This hook will be missing \"\n                \"some grad_output.\"\n            )\n        elif len(out_grad_fn) > 1:\n            warnings.warn(\n                \"Using a non-full backward hook when outputs are generated by different autograd Nodes \"\n                \"is deprecated and will be removed in future versions. This hook will be missing \"\n                \"some grad_output. Please use register_full_backward_hook to get the documented behavior.\"\n            )\n        else:\n            # At this point the grad_output part of the hook will most likely be correct\n            inputs_grad_fn = {i.grad_fn for i in args if i.grad_fn is not None}\n\n            next_functions = {grad_fn.next_functions[0][0]}\n\n            if inputs_grad_fn != next_functions:\n                warnings.warn(\n                    \"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n                    \"is deprecated and will be removed in future versions. This hook will be missing \"\n                    \"some grad_input. Please use register_full_backward_hook to get the documented \"\n                    \"behavior.\"\n                )\n\n    def register_forward_pre_hook(self, hook: Callable[..., None]):\n        r\"\"\"\n        register_forward_pre_hook(hook)\n        \n        Registers a forward pre-hook on the module.\n\n        The hook will be called every time before :func:`forward` is invoked.\n        It should have the following signature::\n\n            hook(module, input) -> None or modified input\n\n        The input contains only the positional arguments given to the module.\n        Keyword arguments are passed only to ``forward``, not to the hooks.\n        The hook can modify the input. The user can either return a tuple or a\n        single modified value from the hook. The value will be wrapped into a tuple\n        if a single value is returned (unless that value is already a tuple).\n\n        Returns:\n            :class:`oneflow.utils.hooks.RemovableHandle`:\n                a handle that can be used to remove the added hook by calling\n                ``handle.remove()``\n\n        \"\"\"\n        handle = flow.utils.hooks.RemovableHandle(self._forward_pre_hooks)\n        self._forward_pre_hooks[handle.id] = hook\n        return handle\n\n    def register_forward_hook(self, hook: Callable[..., None]):\n        r\"\"\"\n        register_forward_hook(hook)\n\n        Registers a forward hook on the module.\n\n        The hook will be called every time after :func:`forward` has computed an output.\n        It should have the following signature::\n\n            hook(module, input, output) -> None or modified output\n\n        The input contains only the positional arguments given to the module.\n        Keyword arguments are passed only to ``forward``, not to the hooks.\n        The hook can modify the output. 
It can modify the input in-place, but\n        this will not have any effect on forward since it is called after\n        :func:`forward` is called.\n\n        Returns:\n            :class:`oneflow.utils.hooks.RemovableHandle`:\n                a handle that can be used to remove the added hook by calling\n                ``handle.remove()``\n\n        \"\"\"\n        handle = flow.utils.hooks.RemovableHandle(self._forward_hooks)\n        self._forward_hooks[handle.id] = hook\n        return handle\n\n    def _apply(self, fn):\n        if not hasattr(self, \"cpg\"):\n            self.cpg = None\n        if self.cpg is not None:\n            self.cpg = None\n            warnings.warn(\n                \"deleted ContiguousParamsGroup since creating it before \"\n                \"apply operations like to() or to_global() will cause an error.\"\n            )\n\n        # A dict to store tensors that have already been applied.\n        # There is no need to apply fn multiple times to the same tensor.\n        if self._oneflow_internal_module_tensor_applied_dict__ is None:\n            self._oneflow_internal_module_tensor_applied_dict__ = dict()\n\n        for module in self.children():\n            module._oneflow_internal_module_tensor_applied_dict__ = (\n                self._oneflow_internal_module_tensor_applied_dict__\n            )\n            module._apply(fn)\n            module._oneflow_internal_module_tensor_applied_dict__ = None\n\n        def can_use_assign_copy(tensor, tensor_applied):\n            return tensor.is_local == tensor_applied.is_local\n\n        for (key, param) in self._parameters.items():\n            if param is None:\n                continue\n\n            need_apply = False\n            if param not in self._oneflow_internal_module_tensor_applied_dict__:\n                need_apply = True\n                assert isinstance(param, Parameter)\n                assert param.is_leaf\n                with flow.no_grad():\n                    param_applied = fn(param)\n                param_applied.requires_grad = param.requires_grad\n\n                if param.grad is not None:\n                    assert param.grad.is_leaf\n                    with flow.no_grad():\n                        grad_applied = fn(param.grad)\n                    grad_applied.requires_grad = param.grad.requires_grad\n                    param_applied.grad = grad_applied\n            else:\n                param_applied = self._oneflow_internal_module_tensor_applied_dict__[\n                    param\n                ]\n\n            if can_use_assign_copy(param_applied, param):\n                if need_apply:\n                    self._parameters[key].data = param_applied\n                    self._oneflow_internal_module_tensor_applied_dict__[\n                        param\n                    ] = param_applied\n                else:\n                    # The parameter's data has already been set when it can use assign copy.\n                    pass\n            else:\n                if need_apply:\n                    new_param = Parameter(param_applied, param.requires_grad)\n                    self._parameters[key] = new_param\n                    self._oneflow_internal_module_tensor_applied_dict__[\n                        param\n                    ] = new_param\n                else:\n                    self._parameters[\n                        key\n                    ] = self._oneflow_internal_module_tensor_applied_dict__[param]\n\n        for (key, buf) in self._buffers.items():\n            if buf is not None:\n                if buf not in self._oneflow_internal_module_tensor_applied_dict__:\n                    buf_applied = fn(buf)\n                    self._buffers[key] = 
buf_applied\n                    self._oneflow_internal_module_tensor_applied_dict__[\n                        buf\n                    ] = buf_applied\n                else:\n                    self._buffers[\n                        key\n                    ] = self._oneflow_internal_module_tensor_applied_dict__[buf]\n\n        self._oneflow_internal_module_tensor_applied_dict__ = None\n        return self\n\n    def apply(self: T, fn: Callable[[\"Module\"], None]) -> T:\n        r\"\"\"\n        apply(fn)\n        \n        Applies ``fn`` recursively to every submodule (as returned by ``.children()``)\n        as well as self. Typical use includes initializing the parameters of a model.\n\n        Args:\n            fn (:class:`Module` -> None): function to be applied to each submodule\n\n        Returns:\n            Module: self\n\n        Example::\n    \n            >>> import oneflow as flow\n            >>> import oneflow.nn as nn\n            >>> @flow.no_grad()\n            ... def init_weights(m):\n            ...     print(m)\n            ...     if type(m) == nn.Linear:\n            ...         m.weight.fill_(1.0)\n            ...         print(m.weight)\n            >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))\n            >>> net.apply(init_weights)\n            Linear(in_features=2, out_features=2, bias=True)\n            tensor([[1., 1.],\n                    [1., 1.]], dtype=oneflow.float32, requires_grad=True)\n            Linear(in_features=2, out_features=2, bias=True)\n            tensor([[1., 1.],\n                    [1., 1.]], dtype=oneflow.float32, requires_grad=True)\n            Sequential(\n              (0): Linear(in_features=2, out_features=2, bias=True)\n              (1): Linear(in_features=2, out_features=2, bias=True)\n            )\n            Sequential(\n              (0): Linear(in_features=2, out_features=2, bias=True)\n              (1): Linear(in_features=2, out_features=2, bias=True)\n            )\n        \"\"\"\n        if self.cpg is not None:\n            self.cpg = None\n            warnings.warn(\n                \"deleted ContiguousParamsGroup since creating it before \"\n                \"apply operations like to() or to_global() will cause an error.\"\n            )\n\n        for module in self.children():\n            module.apply(fn)\n        fn(self)\n        return self\n\n    def to_empty(self: T, *, device: Union[str, flow.device]) -> T:\n        r\"\"\"Moves the parameters and buffers to the specified device without copying storage.\n\n        Args:\n            device (:class:`oneflow.device`): the desired device of the parameters\n                and buffers in this module\n        \n        Returns:\n            Module: self\n        \"\"\"\n        return self._apply(lambda t: flow.empty_like(t, device=device))\n\n    def _to_memory_format(self, memory_format):\n        r\"\"\"Casts the parameters and buffers in this module to another memory format.\n\n        The data_format attribute should also be modified. 
\n        \n        Note:\n            This interface is unstable and may be removed in the future once the data_format\n            attribute has been removed from the module.\n\n        Args:\n            memory_format (:class:`oneflow.memory_format`): the desired memory\n                format for 4D parameters and buffers in this module (keyword\n                only argument)\n\n        Returns:\n            Module: self\n        \"\"\"\n        for module in self.children():\n            module._to_memory_format(memory_format)\n        self.to_memory_format(memory_format)\n        return self\n\n    def to_memory_format(self, memory_format) -> None:\n        pass\n\n    @overload\n    def to(\n        self: T,\n        device: Optional[Union[int, str, flow.device]] = ...,\n        dtype: Optional[flow.dtype] = ...,\n    ) -> T:\n        ...\n\n    @overload\n    def to(self: T, dtype: flow.dtype) -> T:\n        ...\n\n    @overload\n    def to(self: T, tensor: Tensor) -> T:\n        ...\n\n    def to(self, *args, **kwargs):\n        r\"\"\"Moves and/or casts the parameters and buffers.\n\n        This can be called as\n\n        .. function:: to(device=None, dtype=None)\n           :noindex:\n\n        .. function:: to(dtype)\n           :noindex:\n\n        .. function:: to(memory_format=None)\n           :noindex:\n\n        .. function:: to(tensor)\n           :noindex:\n\n        Its signature is similar to :meth:`oneflow.Tensor.to`, but only accepts\n        floating point :attr:`dtype`\\ s. In addition, this method will\n        only cast the floating point parameters and buffers to :attr:`dtype`\n        (if given). The integral parameters and buffers will be moved to\n        :attr:`device`, if that is given, but with dtypes unchanged.\n\n        See below for examples.\n\n        .. 
note::\n            This method modifies the module in-place.\n\n        Args:\n            device (:class:`oneflow.device`): the desired device of the parameters\n                and buffers in this module\n            dtype (:class:`oneflow.dtype`): the desired floating point dtype of\n                the parameters and buffers in this module\n            memory_format (:class:`oneflow.memory_format`): the desired memory\n                format for 4D parameters and buffers in this module (keyword\n                only argument)\n            tensor (oneflow.Tensor): Tensor whose dtype and device are the desired\n                dtype and device for all parameters and buffers in this module\n\n        Returns:\n            Module: self\n\n        Examples::\n\n            >>> import oneflow as flow\n            >>> import oneflow.nn as nn\n            >>> linear = nn.Linear(2, 2)\n            >>> linear.weight.device\n            device(type='cpu', index=0)\n            >>> linear.weight.dtype\n            oneflow.float32\n            >>> linear.to(flow.double)\n            Linear(in_features=2, out_features=2, bias=True)\n            >>> linear.weight.dtype\n            oneflow.float64\n            >>> gpu1 = flow.device(\"cuda:1\")\n            >>> linear.to(gpu1, dtype=flow.half)\n            Linear(in_features=2, out_features=2, bias=True)\n            >>> linear.weight.device\n            device(type='cuda', index=1)\n            >>> linear.weight.dtype\n            oneflow.float16\n            >>> cpu = flow.device(\"cpu\")\n            >>> linear.to(cpu)\n            Linear(in_features=2, out_features=2, bias=True)\n            >>> linear.weight.device\n            device(type='cpu', index=0)\n\n        \"\"\"\n\n        device = None\n        dtype = None\n        memory_format = None\n        if len(args) + len(kwargs) == 2:\n            # Consume keyword arguments first, then fill (device, dtype) from the\n            # remaining positional arguments in order. This also supports calls\n            # like to(dtype=..., device=...) or to(flow.float32, device=...),\n            # which would raise an IndexError with naive args[0]/args[1] indexing.\n            device = kwargs.pop(\"device\", None)\n            dtype = kwargs.pop(\"dtype\", None)\n            remaining = list(args)\n            if device is None and remaining:\n                device = remaining.pop(0)\n            if dtype is None and remaining:\n                dtype = remaining.pop(0)\n        elif len(args) + len(kwargs) == 1:\n            if len(args) == 1:\n                arg = args[0]\n                if isinstance(arg, Tensor):\n                    device = arg.device\n                    dtype = arg.dtype\n                elif isinstance(arg, flow.dtype):\n                    dtype = arg\n                    device = None\n                elif isinstance(arg, (flow.device, str, int)):\n                    dtype = None\n                    device = arg\n                elif isinstance(arg, flow.memory_format):\n                    memory_format = arg\n                else:\n                    raise ValueError(f\"Unsupported parameters in module.to: {arg}\")\n            else:\n                device = kwargs.pop(\"device\", None)\n                dtype = kwargs.pop(\"dtype\", None)\n                memory_format = kwargs.pop(\"memory_format\", None)\n                tensor = kwargs.pop(\"tensor\", None)\n                if tensor is not None:\n                    device = tensor.device\n                    dtype = tensor.dtype\n        else:\n            raise ValueError(\n                f\"Unsupported parameters in module.to: {args} and {kwargs}\"\n            )\n\n        if dtype is not None:\n            if not dtype.is_floating_point:\n                raise TypeError(\n                    \"nn.Module.to only accepts floating point \"\n                    \"dtypes, but got desired dtype={}\".format(dtype)\n                )\n\n        if memory_format is not None:\n            
self._to_memory_format(memory_format)\n\n        def convert(t):\n            return t.to(device, dtype if t.is_floating_point() else None)\n\n        return self._apply(convert)\n\n    def to_consistent(self, *args, **kwargs):\n        raise RuntimeError(\n            \".to_consistent has been removed, please use .to_global instead\"\n        )\n\n    def to_global(self, placement=None, sbp=None):\n        def convert(t):\n            return t.to_global(placement=placement, sbp=sbp)\n\n        return self._apply(convert)\n\n    def to_local(self):\n        def convert(t):\n            return t.to_local()\n\n        return self._apply(convert)\n\n    def cpu(self: T) -> T:\n        r\"\"\"\n        cpu()\n        \n        Moves all model parameters and buffers to the CPU.\n\n        .. note::\n            This method modifies the module in-place.\n\n        Returns:\n            Module: self\n        \"\"\"\n        return self._apply(lambda t: t.cpu())\n\n    def cuda(self: T, device: Optional[Union[int, flow.device]] = None) -> T:\n        r\"\"\"\n        cuda(device=None)\n        \n        Moves all model parameters and buffers to the GPU.\n\n        This also makes associated parameters and buffers different objects. So\n        it should be called before constructing the optimizer if the module will\n        live on GPU while being optimized.\n\n        .. note::\n            This method modifies the module in-place.\n\n        Args:\n            device (int, optional): if specified, all parameters will be\n                copied to that device\n\n        Returns:\n            Module: self\n        \"\"\"\n        return self._apply(lambda t: t.cuda(device))\n\n    def float(self: T) -> T:\n        r\"\"\"\n        float()\n        \n        Casts all floating point parameters and buffers to ``float`` datatype.\n\n        .. note::\n            This method modifies the module in-place.\n\n        Returns:\n            Module: self\n        \"\"\"\n        return self._apply(lambda t: t.float() if t.is_floating_point() else t)\n\n    def double(self: T) -> T:\n        r\"\"\"\n        double()\n        \n        Casts all floating point parameters and buffers to ``double`` datatype.\n\n        .. note::\n            This method modifies the module in-place.\n\n        Returns:\n            Module: self\n        \"\"\"\n        return self._apply(lambda t: t.double() if t.is_floating_point() else t)\n\n    def half(self: T) -> T:\n        r\"\"\"\n        half()\n        \n        Casts all floating point parameters and buffers to ``half`` datatype.\n\n        .. note::\n            This method modifies the module in-place.\n\n        Returns:\n            Module: self\n        \"\"\"\n        return self._apply(lambda t: t.half() if t.is_floating_point() else t)\n\n    def _get_name(self):\n        return self.__class__.__name__\n\n    def get_submodule(self, target: str):\n        r\"\"\"Get the submodule referenced by ``target``.\n\n        Args:\n            target (str): The name of the submodule to find.\n\n        .. 
code-block:: python\n\n            >>> from oneflow import nn\n            >>> class Net3(nn.Module):\n            >>>     def __init__(self):\n            >>>         super().__init__()\n            >>>         self.linear = nn.Linear(3, 2)\n            >>>\n            >>> class Net2(nn.Module):\n            >>>     def __init__(self):\n            >>>         super().__init__()\n            >>>         self.net3 = Net3()\n            >>>\n            >>> class Net1(nn.Module):\n            >>>     def __init__(self):\n            >>>         super().__init__()\n            >>>         self.net2 = Net2()\n            >>>\n            >>> net = Net1()\n            >>> print(net.get_submodule(\"net2.net3\"))\n            Net3(\n            (linear): Linear(in_features=3, out_features=2, bias=True)\n            )\n            >>> print(net.get_submodule(\"net2\"))\n            Net2(\n            (net3): Net3(\n                (linear): Linear(in_features=3, out_features=2, bias=True)\n                )\n            )\n\n        Returns:\n            oneflow.nn.Module: The submodule referenced by ``target``\n\n        Raises:\n            AttributeError: If the module can't reference the submodule according to ``target``\n            TypeError: If the result referenced by ``target`` is not an ``nn.Module``\n\n        \"\"\"\n        if target == \"\":\n            return self\n        curr_module_name = [self._get_name()]\n        submodule_names = target.split(\".\")\n        mod = self\n        for submodule_name in submodule_names:\n            if not hasattr(mod, submodule_name):\n                raise AttributeError(\n                    f\"`{'.'.join(curr_module_name)}` doesn't have submodule `{submodule_name}`\"\n                )\n            mod = getattr(mod, submodule_name)\n            curr_module_name.append(submodule_name)\n            if not isinstance(mod, flow.nn.Module):\n                raise TypeError(\n                    f\"`{'.'.join(curr_module_name)}` isn't an oneflow.nn.Module, but a {type(mod)}\"\n                )\n        return mod\n\n    def get_parameter(self, target: str):\n        r\"\"\"Return the parameter referenced by ``target``.\n\n        Args:\n            target (str): The name of the parameter to find.\n\n        .. 
code-block:: python\n\n            >>> from oneflow import nn\n            >>> class Net3(nn.Module):\n            >>>     def __init__(self):\n            >>>         super().__init__()\n            >>>         self.linear = nn.Linear(3, 3)\n            >>>\n            >>> class Net2(nn.Module):\n            >>>     def __init__(self):\n            >>>         super().__init__()\n            >>>         self.net3 = Net3()\n            >>>         self.linear = nn.Linear(2, 2)\n            >>>\n            >>> class Net1(nn.Module):\n            >>>     def __init__(self):\n            >>>         super().__init__()\n            >>>         self.net2 = Net2()\n            >>>         self.linear = nn.Linear(1, 1)\n            >>>\n            >>> net = Net1()\n            >>> print(net.get_parameter(\"linear.weight\").shape)\n            oneflow.Size([1, 1])\n            >>> print(net.get_parameter(\"net2.linear.weight\").shape)\n            oneflow.Size([2, 2])\n\n        Returns:\n            oneflow.nn.Parameter: The parameter referenced by ``target``\n\n        Raises:\n            AttributeError: If the module can't reference the parameter according to ``target``\n            TypeError: If the result referenced by ``target`` is not an ``nn.Parameter``\n\n        \"\"\"\n        sub_module_name, _, parameter_name = target.rpartition(\".\")\n        sub_module = self.get_submodule(sub_module_name)\n        if hasattr(sub_module, parameter_name):\n            parameter = getattr(sub_module, parameter_name)\n        else:\n            raise AttributeError(\n                f\"`{sub_module_name}` doesn't have attribute `{parameter_name}`\"\n            )\n        if not isinstance(parameter, flow.Tensor):\n            raise TypeError(\n                f\"`{target}` is not an oneflow.Tensor, but {type(parameter)}\"\n            )\n        return parameter\n\n    def extra_repr(self) -> str:\n        \"\"\"Set the extra representation of the module\n\n        To print customized extra information, you should re-implement\n        this method in your own modules. Both single-line and multi-line\n        strings are acceptable.\n        \"\"\"\n        return \"\"\n\n    def make_contiguous_params_group(self):\n        r\"\"\"Create a contiguous parameters group after the whole module has been built.\n\n        Rearranges the parameters of the model that share the same dtype and device\n        (or placement and sbp for global tensors) into a single contiguous tensor,\n        to accelerate element-wise operations on the parameters' data or gradients.\n\n        .. 
note::\n            This method should be called only after all apply operations on the\n            parameters have finished, otherwise it will cause an error.\n\n        Example::\n\n            >>> net = Network().to(device)\n            >>> net.make_contiguous_params_group()\n\n        \"\"\"\n        self.cpg = flow.nn.utils.parameters_grouping.ContiguousParamsGroup(\n            list(self.parameters()), group_on_current_buffer=False\n        )\n\n    def __repr__(self):\n        extra_lines = []\n        extra_repr = self.extra_repr()\n        if extra_repr:\n            extra_lines = extra_repr.split(\"\\n\")\n        child_lines = []\n        for (key, module) in self._modules.items():\n            mod_str = repr(module)\n            mod_str = _addindent(mod_str, 2)\n            child_lines.append(\"(\" + key + \"): \" + mod_str)\n        lines = extra_lines + child_lines\n        main_str = self._get_name() + \"(\"\n        if lines:\n            if len(extra_lines) == 1 and (not child_lines):\n                main_str += extra_lines[0]\n            else:\n                main_str += \"\\n  \" + \"\\n  \".join(lines) + \"\\n\"\n        main_str += \")\"\n        return main_str\n\n    def _shallow_repr(self):\n        extra_lines = []\n        extra_repr = self.extra_repr()\n        if extra_repr:\n            extra_lines = extra_repr.split(\"\\n\")\n        lines = extra_lines\n        main_str = self._get_name() + \"(\"\n        if lines:\n            if len(extra_lines) == 1:\n                main_str += extra_lines[0]\n            else:\n                main_str += \"\\n  \" + \"\\n  \".join(lines) + \"\\n\"\n        main_str += \")\"\n        return main_str\n"
  },
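A minimal usage sketch of the hook and state-dict APIs from ``module.py`` above. It assumes ``flow.randn`` accepts ``requires_grad`` and that forward pre-hooks receive the positional inputs as a tuple, as the docstrings suggest.

.. code-block:: python

    import oneflow as flow
    import oneflow.nn as nn

    m = nn.Linear(2, 2)

    # Forward pre-hook: may return a modified input; a single returned value
    # would be wrapped into a tuple automatically.
    def double_input(module, inp):
        return (inp[0] * 2,)

    pre_handle = m.register_forward_pre_hook(double_input)

    # Full backward hook: observes grad_input/grad_output; returning None
    # leaves the gradients unchanged.
    def report_grads(module, grad_input, grad_output):
        print("grad_input:", len(grad_input), "grad_output:", len(grad_output))

    bwd_handle = m.register_full_backward_hook(report_grads)

    x = flow.randn(1, 2, requires_grad=True)
    m(x).sum().backward()  # triggers both hooks

    pre_handle.remove()
    bwd_handle.remove()

    # Round-trip the module state; with strict=True the returned NamedTuple
    # reports missing and unexpected keys (both empty here).
    result = m.load_state_dict(m.state_dict(), strict=True)
    print(result.missing_keys, result.unexpected_keys)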
  {
    "path": "python/oneflow/nn/modules/moving_average_min_max_observer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport numpy as np\nimport oneflow as flow\nfrom oneflow.framework.tensor import register_tensor_op\nfrom oneflow.nn.modules.module import Module\n\n\nclass MovingAverageMinMaxObserver(Module):\n    \"\"\"\n    \n    Compute the quantization parameters based on the moving average of the input tensor's min and max values.\n\n    First compute the moving\\\\_max and moving\\\\_min value of input tensor:\n\n        if quantization_scheme == \"symmetric\":\n\n        .. math::\n\n            & moving\\\\_max = moving\\\\_max * momentum + |max(input)| * (1 - momentum)\n\n            & moving\\\\_min = moving\\\\_max\n\n        elif quantization_scheme == \"affine\":\n\n        .. math::\n\n            & moving\\\\_max = moving\\\\_max * momentum + max(input) * (1 - momentum)\n\n            & moving\\\\_min = moving\\\\_min * momentum + min(input) * (1 - momentum)\n\n    The moving average of min and max values are initialized as the first batch of input `Blob`'s min and max.\n\n    Then compute the scale and zero_point with the following equations:\n\n        if quantization_scheme == \"symmetric\":\n\n        .. math::\n\n            & denom = 2^{quantization\\\\_to\\\\_bit - 1} - 1\n\n            & scale = moving\\\\_max / denom\n\n            & zero\\\\_point = 0\n\n        elif quantization_scheme == \"affine\":\n\n        .. math::\n\n            & denom = 2^{quantization\\\\_to\\\\_bit} - 1\n\n            & scale = (moving\\\\_max - moving\\\\_min) / denom\n\n            & zero\\\\_point = -moving\\\\_min / scale\n\n    Note:\n        ``current_train_step`` can be directly assigned to an optimizer(eg.SGD) step.\n\n    Args:\n        input(oneflow.Tensor):  the input value(s), in ``oneflow.float32``.\n        current_train_step_tensor(oneflow.Tensor): record train step for quantionzation aware training.\n        stop_update_after_iters(int): stop record train step for quantionzation aware training when train iter greater than stop_update_after_iters.\n        quantization_formula (str): Support \"google\" or \"cambricon\".\n        quantization_bit (int): Quantize input to uintX / intX, X can be in range [2, 8]. Defaults to 8.\n        quantization_scheme (str): \"symmetric\" or \"affine\", quantize to signed / unsigned integer. Defaults to \"symmetric\".\n        momentum (float): Smoothing parameter for exponential moving average operation. Defaults to 0.95.\n\n    Returns:\n        Tuple[oneflow.Tensor, oneflow.Tensor]: The scale and zero_point of input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> weight = (np.random.random((2, 3, 4, 5)) - 0.5).astype(np.float32)\n        \n        >>> input_tensor = flow.tensor(\n        ...    weight, dtype=flow.float32\n        ... )\n\n        >>> current_train_step_tensor = flow.tensor(\n        ...   
np.zeros((1,)).astype(np.float32),\n        ...    dtype=flow.int64,\n        ... )\n        \n        >>> momentum = 0.95\n        >>> quantization_bit = 8\n        >>> quantization_scheme = \"symmetric\"\n        >>> quantization_formula = \"google\"\n\n        >>> moving_average_min_max_observer = flow.nn.MovingAverageMinMaxObserver(stop_update_after_iters=1,  \n        ...                                                                       quantization_formula=quantization_formula, quantization_bit=quantization_bit,\n        ...                                                                       quantization_scheme=quantization_scheme, momentum=momentum,\n        ...                                                                       )\n\n        >>> (scale, zero_point) = moving_average_min_max_observer(\n        ...    input_tensor,\n        ...    current_train_step_tensor,\n        ... )\n\n    \"\"\"\n\n    def __init__(\n        self,\n        stop_update_after_iters: int = 1,\n        quantization_formula: str = \"google\",\n        quantization_bit: int = 8,\n        quantization_scheme: str = \"symmetric\",\n        momentum: float = 0.95,\n    ) -> None:\n        super().__init__()\n        self.quantization_formula = quantization_formula\n        self.stop_update_after_iters = stop_update_after_iters\n        self.quantization_bit = quantization_bit\n        self.quantization_scheme = quantization_scheme\n        self.momentum = momentum\n        self.register_buffer(\"moving_max\", flow.Tensor(1))\n        self.register_buffer(\"moving_min\", flow.Tensor(1))\n        self.reset_running_stats()\n\n    def reset_running_stats(self) -> None:\n        self.moving_max.fill_(0)\n        self.moving_min.fill_(0)\n\n    def forward(self, input, current_train_step):\n        return flow._C.moving_average_min_max_observer(\n            input,\n            current_train_step,\n            self.moving_max,\n            self.moving_min,\n            self.training,\n            self.stop_update_after_iters,\n            self.quantization_formula,\n            self.quantization_bit,\n            self.quantization_scheme,\n            self.momentum,\n        )\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
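To make the update rule in the docstring above concrete, here is a small NumPy sketch of the "symmetric" branch of the formulas (the function name is illustrative, not part of the module's API):

.. code-block:: python

    import numpy as np

    def symmetric_qparams(x, moving_max, momentum=0.95, quantization_bit=8):
        # moving_max = moving_max * momentum + |max(input)| * (1 - momentum)
        moving_max = moving_max * momentum + abs(x.max()) * (1 - momentum)
        # denom = 2^(quantization_bit - 1) - 1, scale = moving_max / denom
        denom = 2 ** (quantization_bit - 1) - 1
        scale = moving_max / denom
        zero_point = 0
        return scale, zero_point, moving_max

    x = (np.random.random((2, 3)) - 0.5).astype(np.float32)
    # The moving average is initialized from the first batch's statistics.
    scale, zero_point, new_max = symmetric_qparams(x, moving_max=abs(x.max()))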
  {
    "path": "python/oneflow/nn/modules/nms.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.framework.tensor import register_tensor_op\nfrom oneflow.nn.modules.module import Module\n\n\ndef nms_op(boxes, scores, iou_threshold: float):\n    score_inds = flow.argsort(scores, dim=0, descending=True)\n    boxes = flow._C.gather(boxes, score_inds, axis=0)\n    keep = flow._C.nms(boxes, iou_threshold)\n    index = flow.squeeze(flow.argwhere(keep), dim=[1])\n    return flow._C.gather(score_inds, index, axis=0)\n"
  },
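A short sketch of how ``nms_op`` above behaves, assuming the usual ``(x1, y1, x2, y2)`` box layout; since this file does not show how the op is registered, the function is imported directly:

.. code-block:: python

    import oneflow as flow
    from oneflow.nn.modules.nms import nms_op

    # Two heavily overlapping boxes and one disjoint box.
    boxes = flow.tensor(
        [[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 10.0, 10.0], [20.0, 20.0, 30.0, 30.0]]
    )
    scores = flow.tensor([0.9, 0.8, 0.7])

    # Boxes are sorted by score before suppression, and the returned values
    # are indices into the original `boxes`, highest-scoring survivors first.
    keep = nms_op(boxes, scores, 0.5)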
  {
    "path": "python/oneflow/nn/modules/nonzero.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Optional\n\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import register_tensor_op\nfrom oneflow.nn.modules.module import Module\n\n\ndef nonzero_op(input, as_tuple=False):\n    meta_device_flag = False\n    if input.is_global:\n        if input.placement.type == \"meta\":\n            meta_device_flag = True\n    else:\n        if input.device.type == \"meta\":\n            meta_device_flag = True\n    if meta_device_flag:\n        raise RuntimeError(\n            \"Could not run nonzero with arguments from the meta backend.\"\n        )\n    if as_tuple:\n        return flow._C.nonzero(input, as_tuple)\n    else:\n        return flow._C.nonzero(input, as_tuple)[0]\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
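A quick sketch of the two ``as_tuple`` modes of ``nonzero_op`` above (imported directly, since the registration is not shown in this file):

.. code-block:: python

    import oneflow as flow
    from oneflow.nn.modules.nonzero import nonzero_op

    x = flow.tensor([[1, 0], [0, 3]])

    # as_tuple=False (default): a single (n, ndim) tensor of indices.
    print(nonzero_op(x))

    # as_tuple=True: one 1-D index tensor per input dimension.
    print(nonzero_op(x, as_tuple=True))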
  {
    "path": "python/oneflow/nn/modules/norm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\n\n\ndef norm(input, p=\"fro\", dim=None, keepdim=False, dtype=None):\n    \"\"\"\n    Returns the matrix norm or vector norm of a given tensor.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.norm.html.\n\n    .. warning::\n\n        Use :func:`oneflow.linalg.norm`, instead, or :func:`oneflow.linalg.vector_norm`\n        when computing vector norms and :func:`oneflow.linalg.matrix_norm` when\n        computing matrix norms. Note, however, the signature for these functions\n        is slightly different than the signature for oneflow.norm.\n\n    Args:\n        input (Tensor): The input tensor. Its data type must be either a floating\n            point or complex type. For complex inputs, the norm is calculated using the\n            absolute value of each element. If the input is complex and neither\n            :attr:`dtype` nor :attr:`out` is specified, the result's data type will\n            be the corresponding floating point type (e.g. float if :attr:`input` is\n            complexfloat).\n\n        p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'``\n            The following norms can be calculated:\n\n            ======  ==============  ==========================\n            ord     matrix norm     vector norm\n            ======  ==============  ==========================\n            'fro'   Frobenius norm  --\n            'nuc'   nuclear norm    --\n            Number  --              sum(abs(x)**p)**(1./p)\n            ======  ==============  ==========================\n\n            The vector norm can be calculated across any number of dimensions.\n            The corresponding dimensions of :attr:`input` are flattened into\n            one dimension, and the norm is calculated on the flattened\n            dimension.\n\n            Frobenius norm produces the same result as ``p=2`` in all cases\n            except when :attr:`dim` is a list of three or more dims, in which\n            case Frobenius norm throws an error.\n\n            Nuclear norm can only be calculated across exactly two dimensions.\n\n        dim (int, tuple of ints, list of ints, optional):\n            Specifies which dimension or dimensions of :attr:`input` to\n            calculate the norm across. If :attr:`dim` is ``None``, the norm will\n            be calculated across all dimensions of :attr:`input`. If the norm\n            type indicated by :attr:`p` does not support the specified number of\n            dimensions, an error will occur.\n        keepdim (bool, optional): whether the output tensors have :attr:`dim`\n            retained or not. Ignored if :attr:`dim` = ``None`` and\n            :attr:`out` = ``None``. Default: ``False``\n        dtype (:class:`oneflow.dtype`, optional): the desired data type of\n            returned tensor. 
If specified, the input tensor is cast to\n            :attr:`dtype` while performing the operation. Default: None.\n\n    .. note::\n        Even though ``p='fro'`` supports any number of dimensions, the true\n        mathematical definition of Frobenius norm only applies to tensors with\n        exactly two dimensions. :func:`oneflow.linalg.norm` with ``ord='fro'`` aligns\n        with the mathematical definition, since it can only be applied across\n        exactly two dimensions.\n\n    Example::\n\n        >>> import oneflow as flow\n        >>> a = flow.arange(9, dtype= flow.float) - 4\n        >>> b = a.reshape((3, 3))\n        >>> flow.norm(a)\n        tensor(7.7460, dtype=oneflow.float32)\n        >>> flow.norm(b)\n        tensor(7.7460, dtype=oneflow.float32)\n        >>> flow.norm(a, float('inf'))\n        tensor(4., dtype=oneflow.float32)\n        >>> flow.norm(b, float('inf'))\n        tensor(9., dtype=oneflow.float32)\n        >>> c = flow.tensor([[ 1, 2, 3],[-1, 1, 4]] , dtype= flow.float)\n        >>> flow.norm(c, dim=0)\n        tensor([1.4142, 2.2361, 5.0000], dtype=oneflow.float32)\n        >>> flow.norm(c, dim=1)\n        tensor([3.7417, 4.2426], dtype=oneflow.float32)\n        >>> flow.norm(c, p=1, dim=1)\n        tensor([6., 6.], dtype=oneflow.float32)\n        >>> d = flow.arange(8, dtype= flow.float).reshape(2,2,2)\n        >>> flow.norm(d, dim=(1,2))\n        tensor([ 3.7417, 11.2250], dtype=oneflow.float32)\n        >>> flow.norm(d[0, :, :]), flow.norm(d[1, :, :])\n        (tensor(3.7417, dtype=oneflow.float32), tensor(11.2250, dtype=oneflow.float32))\n    \"\"\"\n    if isinstance(p, str) or dim is not None:\n        return flow._C.norm(input=input, ord=p, dim=dim, keepdim=keepdim, dtype=dtype)\n    return flow._C.norm(\n        input=input, ord=p, dim=dim, keepdim=keepdim, dtype=dtype, for_norm=True\n    )\n"
  },
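The ``norm`` wrapper above dispatches on its arguments: the string/``dim`` path calls ``flow._C.norm`` directly, while the purely numeric no-``dim`` path additionally passes ``for_norm=True``. A quick sanity check of the default Frobenius case against its elementwise definition:

.. code-block:: python

    import oneflow as flow

    a = flow.tensor([[3.0, 4.0], [0.0, 0.0]])

    # Default p='fro' takes the string branch of norm().
    fro = flow.norm(a)

    # For matrices, the Frobenius norm is sqrt(sum(x ** 2)) over all elements.
    manual = flow.sqrt((a * a).sum())

    print(fro, manual)  # both tensor(5., dtype=oneflow.float32)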
  {
    "path": "python/oneflow/nn/modules/normalization.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport warnings\nfrom typing import Tuple, Union\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.nn import init\nfrom oneflow.nn.modules.module import Module\n\n_shape_t = Union[int, Tuple[int], flow._oneflow_internal.Size]\n\n\ndef group_norm(\n    input: Tensor,\n    num_groups: int,\n    weight: Tensor = None,\n    bias: Tensor = None,\n    eps: float = 1e-05,\n    num_channels: int = None,\n):\n    r\"\"\"Apply Group Normalization for last certain number of dimensions.\n\n    See :class:`~oneflow.nn.GroupNorm` for details.\n    \"\"\"\n    assert len(input.shape) >= 3, \"The dimensions of input tensor must larger than 2\"\n    if num_channels is None:\n        num_channels = input.shape[1]\n    assert (\n        input.shape[1] == num_channels\n    ), \"The channels of input tensor must equal num_channels\"\n\n    affine = weight is not None and bias is not None\n    if not input.is_cpu:\n        return flow._C.group_norm(input, weight, bias, affine, num_groups, eps)\n    else:\n        origin_shape = input.shape\n        reshape_to_1d = flow.reshape(input, shape=[origin_shape[0], num_groups, -1])\n        mean = flow.mean(reshape_to_1d, dim=2, keepdim=True)\n        variance = flow.var(reshape_to_1d, dim=2, unbiased=False, keepdim=True)\n        normalized = (reshape_to_1d - mean) / flow.sqrt(variance + eps)\n        normalized = flow.reshape(normalized, shape=[origin_shape[0], num_channels, -1])\n        if weight is not None:\n            normalized = normalized * weight.reshape(1, num_channels, 1)\n        if bias is not None:\n            normalized = normalized + bias.reshape(1, num_channels, 1)\n        res = flow.reshape(normalized, shape=tuple(input.shape))\n        return res\n\n\nclass GroupNorm(Module):\n    \"\"\"\n    Applies Group Normalization over a mini-batch of inputs as described in\n    the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__\n\n    .. math::\n\n        y = \\\\frac{x - \\\\mathrm{E}[x]}{ \\\\sqrt{\\\\mathrm{Var}[x] + \\\\epsilon}} * \\\\gamma + \\\\beta\n\n    The input channels are separated into :attr:`num_groups` groups, each containing\n    ``num_channels / num_groups`` channels. The mean and standard-deviation are calculated\n    separately over the each group. 
:math:`\\\\gamma` and :math:`\\\\beta` are learnable\n    per-channel affine transform parameter vectors of size :attr:`num_channels` if\n    :attr:`affine` is ``True``.\n    The standard-deviation is calculated via the biased estimator, equivalent to\n    `torch.var(input, unbiased=False)`.\n\n    This layer uses statistics computed from input data in both training and\n    evaluation modes.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.GroupNorm.html.\n\n    Args:\n        num_groups (int): number of groups to separate the channels into\n        num_channels (int): number of channels expected in input\n        eps: a value added to the denominator for numerical stability. Default: 1e-5\n        affine: a boolean value that when set to ``True``, this module\n            has learnable per-channel affine parameters initialized to ones (for weights)\n            and zeros (for biases). Default: ``True``.\n\n    Shape:\n        - Input: :math:`(N, C, *)` where :math:`C=\\\\text{num_channels}`\n        - Output: :math:`(N, C, *)` (same shape as input)\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.Tensor(np.random.randn(20, 6, 10, 10))\n        >>> # Separate 6 channels into 3 groups\n        >>> m = flow.nn.GroupNorm(3, 6)\n        >>> # Separate 6 channels into 6 groups (equivalent to InstanceNorm)\n        >>> m = flow.nn.GroupNorm(6, 6)\n        >>> # Put all 6 channels into a single group (equivalent to LayerNorm)\n        >>> m = flow.nn.GroupNorm(1, 6)\n        >>> # Activating the module\n        >>> output = m(input)\n    \n\"\"\"\n\n    def __init__(\n        self,\n        num_groups: int,\n        num_channels: int,\n        eps: float = 1e-05,\n        affine: bool = True,\n        device=None,\n        dtype=None,\n    ) -> None:\n        super().__init__()\n        assert num_groups > 0, \"The num_groups must be larger than zero\"\n        assert num_channels > 0, \"The num_channels must be larger than zero\"\n        self.num_groups = num_groups\n        self.num_channels = num_channels\n        self.eps = eps\n        self.affine = affine\n        factory_kwargs = {}\n        if device:\n            factory_kwargs[\"device\"] = device\n        if dtype:\n            factory_kwargs[\"dtype\"] = dtype\n        if self.affine:\n            self.weight = flow.nn.Parameter(\n                flow.Tensor(num_channels).to(**factory_kwargs)\n            )\n            self.bias = flow.nn.Parameter(\n                flow.Tensor(num_channels).to(**factory_kwargs)\n            )\n        else:\n            self.register_parameter(\"weight\", None)\n            self.register_parameter(\"bias\", None)\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        if self.affine:\n            flow.nn.init.ones_(self.weight)\n            flow.nn.init.zeros_(self.bias)\n\n    def forward(self, input: Tensor) -> Tensor:\n        return group_norm(\n            input, self.num_groups, self.weight, self.bias, self.eps, self.num_channels\n        )\n\n    def extra_repr(self) -> str:\n        return \"{num_groups}, {num_channels}, eps={eps}, affine={affine}\".format(\n            **self.__dict__\n        )\n\n\ndef layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):\n    assert len(input.shape) > len(\n        normalized_shape\n    ), \"Input tensor dim must be 
greater than normalized dim!\"\n    begin_norm_axis = len(input.shape) - len(normalized_shape)\n    begin_params_axis = len(input.shape) - len(normalized_shape)\n\n    elementwise_affine = True if (weight is not None and bias is not None) else False\n\n    for i in range(0, len(normalized_shape)):\n        if input.shape[i + begin_params_axis] != normalized_shape[i]:\n            raise RuntimeError(\n                f\"Given normalized_shape={normalized_shape}, expected input with shape [*, {str(normalized_shape)[1:-1]}], but got input of size {input.shape}\"\n            )\n\n    if input.is_cpu:\n        reduce_axis = []\n        for dim in range(len(input.shape)):\n            if dim >= begin_norm_axis:\n                reduce_axis.append(dim)\n        mean = input.mean(dim=reduce_axis, keepdim=True)\n        variance = input.var(dim=reduce_axis, unbiased=False, keepdim=True)\n        params_shape = input.shape[begin_params_axis:]\n        if len(mean.shape) == 1:\n            nd_params_shape = [1] * len(input.shape)\n            nd_params_shape[begin_norm_axis] = params_shape[0]\n            mean = flow.reshape(mean, shape=nd_params_shape)\n            variance = flow.reshape(variance, nd_params_shape)\n            if weight is not None and params_shape[0] == weight.nelement():\n                weight = flow.reshape(weight, shape=nd_params_shape)\n            if bias is not None and params_shape[0] == bias.nelement():\n                bias = flow.reshape(bias, shape=nd_params_shape)\n        elif len(mean.shape) == len(input.shape):\n            pass\n        else:\n            raise ValueError(\n                \"shape of mean and variance should be 1D or have the same number of axes as x\"\n            )\n        variance += eps\n        normalized = (input - mean) * variance.rsqrt()\n        if elementwise_affine:\n            normalized = normalized * weight + bias\n        return normalized\n    else:\n        if elementwise_affine:\n            res = flow._C.layer_norm_affine(\n                input,\n                weight,\n                bias,\n                begin_norm_axis=begin_norm_axis,\n                begin_params_axis=begin_params_axis,\n                epsilon=eps,\n            )\n        else:\n            res = flow._C.layer_norm(\n                input,\n                begin_norm_axis=begin_norm_axis,\n                begin_params_axis=begin_params_axis,\n                epsilon=eps,\n            )\n        return res\n\n\nclass LayerNorm(Module):\n    \"\"\"Applies Layer Normalization over a mini-batch of inputs as described in\n    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__\n\n    .. math::\n        y = \\\\frac{x - \\\\mathrm{E}[x]}{ \\\\sqrt{\\\\mathrm{Var}[x] + \\\\epsilon}} * \\\\gamma + \\\\beta\n\n    The mean and standard-deviation are calculated separately over the last\n    certain number of dimensions which have to be of the shape specified by\n    :attr:`normalized_shape`.\n    :math:`\\\\gamma` and :math:`\\\\beta` are learnable affine transform parameters of\n    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.\n    The standard-deviation is calculated via the biased estimator.\n\n    .. 
note::\n        Unlike Batch Normalization and Instance Normalization, which apply\n        scalar scale and bias for each entire channel/plane with the\n        :attr:`affine` option, Layer Normalization applies per-element scale and\n        bias with :attr:`elementwise_affine`.\n\n    This layer uses statistics computed from input data in both training and\n    evaluation modes.\n\n    Args:\n        normalized_shape (int or list or oneflow.Size): input shape from an expected input of size\n\n            .. math::\n                [* \\\\times \\\\text{normalized_shape}[0] \\\\times \\\\text{normalized_shape}[1] \\\\times \\\\ldots \\\\times \\\\text{normalized_shape}[-1]]\n\n            If a single integer is used, it is treated as a singleton list, and this module will\n            normalize over the last dimension which is expected to be of that specific size.\n\n        eps: a value added to the denominator for numerical stability. Default: 1e-5\n        elementwise_affine: a boolean value that when set to ``True``, this module\n            has learnable per-element affine parameters initialized to ones (for weights)\n            and zeros (for biases). Default: ``True``.\n\n    Shape:\n        - Input: :math:`(N, *)`\n        - Output: :math:`(N, *)` (same shape as input)\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> input_arr = np.array(\n        ...     [\n        ...         [\n        ...             [[-0.16046895, -1.03667831], [-0.34974465, 0.26505867]],\n        ...             [[-1.24111986, -0.53806001], [1.72426331, 0.43572459]],\n        ...         ],\n        ...         [\n        ...             [[-0.77390957, -0.42610624], [0.16398858, -1.35760343]],\n        ...             [[1.07541728, 0.11008703], [0.26361224, -0.48663723]],\n        ...         ],\n        ...     ],\n        ...     dtype=np.float32,\n        ... 
)\n\n        >>> x = flow.Tensor(input_arr)\n        >>> m = flow.nn.LayerNorm(2)\n        >>> y = m(x).numpy()\n        >>> y\n        array([[[[ 0.99997395, -0.99997395],\n                 [-0.999947  ,  0.999947  ]],\n        <BLANKLINE>\n                [[-0.99995965,  0.9999595 ],\n                 [ 0.99998784, -0.99998784]]],\n        <BLANKLINE>\n        <BLANKLINE>\n               [[[-0.9998348 ,  0.99983466],\n                 [ 0.9999914 , -0.9999914 ]],\n        <BLANKLINE>\n                [[ 0.9999785 , -0.9999785 ],\n                 [ 0.9999646 , -0.9999646 ]]]], dtype=float32)\n\n    \"\"\"\n\n    __constants__ = [\"normalized_shape\", \"eps\", \"elementwise_affine\"]\n    normalized_shape: Tuple[int, ...]\n    eps: float\n    elementwise_affine: bool\n\n    def __init__(\n        self,\n        normalized_shape: _shape_t,\n        eps: float = 1e-05,\n        elementwise_affine: bool = True,\n    ) -> None:\n        super(LayerNorm, self).__init__()\n        if isinstance(normalized_shape, int):\n            normalized_shape = (normalized_shape,)\n        self.normalized_shape = tuple(normalized_shape)\n        self.eps = eps\n        self.elementwise_affine = elementwise_affine\n        if self.elementwise_affine:\n            self.weight = flow.nn.Parameter(flow.Tensor(*self.normalized_shape))\n            self.bias = flow.nn.Parameter(flow.Tensor(*self.normalized_shape))\n        else:\n            self.register_parameter(\"weight\", None)\n            self.register_parameter(\"bias\", None)\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        if self.elementwise_affine:\n            init.ones_(self.weight)\n            init.zeros_(self.bias)\n\n    def forward(self, x):\n        return layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)\n\n    def extra_repr(self) -> str:\n        return \"{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}\".format(\n            **self.__dict__\n        )\n\n\nclass RMSLayerNorm(Module):\n    \"\"\"\n    Construct a layernorm module in the T5 style: no bias and no subtraction of mean.\n\n    T5 uses a layer norm which only scales and doesn't shift, also known as Root Mean Square\n    Layer Normalization (https://arxiv.org/abs/1910.07467); the variance is therefore computed\n    without subtracting the mean, and there is no bias. Additionally, the accumulation for\n    half-precision inputs is done in fp32.\n\n    Args:\n        hidden_size (int): number of features in the hidden state\n        eps: a value added to the denominator for numerical stability. Default: 1e-6\n\n    Shape:\n        - Input: :math:`(N, *)`\n        - Output: :math:`(N, *)` (same shape as input)\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> x = flow.randn(2, 4, 3)\n        >>> m = flow.nn.RMSLayerNorm(3)\n        >>> y = m(x)\n        >>> y.size()\n        oneflow.Size([2, 4, 3])\n\n    \"\"\"\n\n    def __init__(self, hidden_size, eps=1e-6):\n        warnings.warn(\n            \"nn.RMSLayerNorm has been deprecated. 
Please use nn.RMSNorm instead.\"\n        )\n\n        super().__init__()\n        self.weight = flow.nn.Parameter(flow.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        return flow._C.rms_layer_norm(hidden_states, self.weight, self.variance_epsilon)\n\n\nclass RMSNorm(Module):\n    \"\"\"Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in\n    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__\n\n    .. math::\n        y = \\\\frac{x}{\\\\mathrm{RMS}[x]} \\\\mathrm{weight},\\\\text{ where }\\\\mathrm{RMS}[x] = \\\\sqrt{\\\\frac{1}{n} \\\\sum_{i=1}^{n} x^{2}}\n\n    There is no bias and no subtraction of mean with RMS Layer Normalization,\n    and it only scales and doesn't shift.\n\n    The root mean square is calculated separately over the last\n    certain number of dimensions which have to be of the shape specified by\n    :attr:`normalized_shape`.\n    :math:`\\\\mathrm{weight}` is a learnable affine transform parameter of\n    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.\n\n    .. note::\n        Like Layer Normalization, Root Mean Square Layer Normalization applies per-element scale\n        with :attr:`elementwise_affine`.\n\n    This layer uses statistics computed from input data in both training and\n    evaluation modes.\n\n    Args:\n        normalized_shape (int or list or oneflow.Size): input shape from an expected input of size\n\n            .. math::\n                [* \\\\times \\\\text{normalized_shape}[0] \\\\times \\\\text{normalized_shape}[1] \\\\times \\\\ldots \\\\times \\\\text{normalized_shape}[-1]]\n\n            If a single integer is used, it is treated as a singleton list, and this module will\n            normalize over the last dimension which is expected to be of that specific size.\n\n        eps: a value added to the denominator for numerical stability. Default: 1e-5\n        elementwise_affine: a boolean value that when set to ``True``, this module\n            has learnable per-element affine parameters initialized to ones (for weights).\n            Default: ``True``.\n\n    Shape:\n        - Input: :math:`(N, *)`\n        - Output: :math:`(N, *)` (same shape as input)\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> input_arr = np.array(\n        ...     [\n        ...         [\n        ...             [[-0.16046895, -1.03667831], [-0.34974465, 0.26505867]],\n        ...             [[-1.24111986, -0.53806001], [1.72426331, 0.43572459]],\n        ...         ],\n        ...         [\n        ...             [[-0.77390957, -0.42610624], [0.16398858, -1.35760343]],\n        ...             [[1.07541728, 0.11008703], [0.26361224, -0.48663723]],\n        ...         ],\n        ...     ],\n        ...     dtype=np.float32,\n        ... 
)\n\n        >>> x = flow.Tensor(input_arr, device=\"cuda\")\n        >>> m = flow.nn.RMSNorm(2).to(device=\"cuda\")\n        >>> y = m(x).numpy()\n        >>> y\n        array([[[[-0.21632987, -1.3975569 ],\n                 [-1.127044  ,  0.8541454 ]],\n        <BLANKLINE>\n                [[-1.2975204 , -0.5625112 ],\n                 [ 1.3711083 ,  0.34648165]]],\n        <BLANKLINE>\n        <BLANKLINE>\n               [[[-1.2388322 , -0.6820876 ],\n                 [ 0.16959298, -1.4040003 ]],\n        <BLANKLINE>\n                [[ 1.4068495 ,  0.14401469],\n                 [ 0.6735778 , -1.2434478 ]]]], dtype=float32)\n\n    \"\"\"\n\n    __constants__ = [\"normalized_shape\", \"eps\", \"elementwise_affine\"]\n    normalized_shape: Tuple[int, ...]\n    eps: float\n    elementwise_affine: bool\n\n    def __init__(\n        self,\n        normalized_shape: _shape_t,\n        eps: float = 1e-05,\n        elementwise_affine: bool = True,\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        if isinstance(normalized_shape, int):\n            normalized_shape = (normalized_shape,)\n        self.normalized_shape = tuple(normalized_shape)\n        self.eps = eps\n        self.elementwise_affine = elementwise_affine\n        if self.elementwise_affine:\n            self.weight = flow.nn.Parameter(\n                flow.ones(*self.normalized_shape, **factory_kwargs)\n            )\n        else:\n            self.register_parameter(\"weight\", None)\n\n    def forward(self, x):\n        return flow._C.rms_norm(x, self.weight, self.normalized_shape, self.eps)\n\n    def extra_repr(self) -> str:\n        return \"{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}\".format(\n            **self.__dict__\n        )\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
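The CPU fallback in `layer_norm` above reduces over the trailing `normalized_shape` axes with a biased variance estimate, and `RMSNorm` applies the same trailing-axis reduction without subtracting the mean. A minimal numpy sketch of both reductions (the helper names are invented for this document, and agreement with the `flow._C` kernels beyond float rounding is an assumption):

```python
import numpy as np

def layer_norm_ref(x, normalized_shape, weight=None, bias=None, eps=1e-5):
    # Normalize over the trailing axes, mirroring
    # begin_norm_axis = x.ndim - len(normalized_shape) in the module code.
    axes = tuple(range(x.ndim - len(normalized_shape), x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)  # biased estimator, like unbiased=False
    y = (x - mean) / np.sqrt(var + eps)
    if weight is not None:
        y = y * weight
    if bias is not None:
        y = y + bias
    return y

def rms_norm_ref(x, normalized_shape, weight=None, eps=1e-5):
    # RMSNorm: no mean subtraction and no bias -- scale by 1/RMS only.
    axes = tuple(range(x.ndim - len(normalized_shape), x.ndim))
    rms = np.sqrt(np.mean(x * x, axis=axes, keepdims=True) + eps)
    y = x / rms
    return y * weight if weight is not None else y

x = np.random.randn(2, 2, 2, 2).astype(np.float32)
print(layer_norm_ref(x, (2,)).shape, rms_norm_ref(x, (2,)).shape)  # (2, 2, 2, 2) twice
```

With `normalized_shape=(2,)` and unit weight, `layer_norm_ref` should reproduce the ±0.9999... values in the LayerNorm doctest above up to float32 rounding.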
  {
    "path": "python/oneflow/nn/modules/numel.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\n\n\ndef numel_op(input):\n    \"\"\"\n    numel(input) -> int\n\n    Returns the total number of elements in the :attr:`input` tensor.\n\n    Args:\n        input (oneflow.Tensor): Input Tensor\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n\n        >>> a = flow.randn(1, 2, 3, 4, 5)\n        >>> flow.numel(a)\n        120\n        >>> a = flow.zeros(4,4)\n        >>> flow.numel(a)\n        16\n    \"\"\"\n    return input.numel()\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
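`numel` is simply the product of the tensor's dimensions, so it can be cross-checked against the shape directly; a small sketch (assuming a working oneflow install, and that `Tensor.shape` iterates like a tuple):

```python
import numpy as np
import oneflow as flow

a = flow.randn(1, 2, 3, 4, 5)
# numel equals the product of all dims: 1 * 2 * 3 * 4 * 5 == 120
assert flow.numel(a) == int(np.prod(tuple(a.shape))) == 120
```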
  {
    "path": "python/oneflow/nn/modules/padding.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Union\n\nimport oneflow as flow\nfrom oneflow.nn.common_types import _size_2_t, _size_4_t\nfrom oneflow.nn.modules.module import Module\nfrom oneflow.nn.modules.utils import _pair, _quadruple\n\n\nclass ReplicationPad1d(Module):\n    r\"\"\"\n    ReplicationPad1d(padding)\n\n    Pads the input tensor using replication of the input boundary.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.ReplicationPad1d.html.\n\n    For `N`-dimensional padding, use :func:`oneflow.nn.functional.pad()`.\n\n    Args:\n        padding (int, tuple): the size of the padding. If is `int`, uses the same\n            padding in all boundaries. If a 2-`tuple`, uses\n            (:math:`\\text{padding_left}`, :math:`\\text{padding_right}`)\n\n    Shape:\n        - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.\n        - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where\n\n          :math:`W_{out} = W_{in} + \\text{padding_left} + \\text{padding_right}`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> m = flow.nn.ReplicationPad1d((2, 2))\n        >>> input = flow.tensor(np.arange(18).reshape((2, 3, 3)).astype(np.float32))\n        >>> out = m(input)\n        >>> out\n        tensor([[[ 0.,  0.,  0.,  1.,  2.,  2.,  2.],\n                 [ 3.,  3.,  3.,  4.,  5.,  5.,  5.],\n                 [ 6.,  6.,  6.,  7.,  8.,  8.,  8.]],\n        <BLANKLINE>\n                [[ 9.,  9.,  9., 10., 11., 11., 11.],\n                 [12., 12., 12., 13., 14., 14., 14.],\n                 [15., 15., 15., 16., 17., 17., 17.]]], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self, padding: _size_4_t):\n        super().__init__()\n        if isinstance(padding, tuple):\n            assert len(padding) == 2, ValueError(\"Padding length must be 2\")\n            boundary = [*padding]\n        elif isinstance(padding, int):\n            boundary = _pair(padding)\n        else:\n            raise ValueError(\"padding must be in or list or tuple!\")\n        self.padding = boundary\n\n    def forward(self, x):\n        return flow._C.pad(x, pad=self.padding, mode=\"replicate\")\n\n    def extra_repr(self) -> str:\n        return \"{}\".format(self.padding)\n\n\nclass ReplicationPad2d(Module):\n    \"\"\"\n    ReplicationPad2d(padding)\n\n    Pads the input tensor using the replication of the input boundary.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.ReplicationPad2d.html.\n\n    Args:\n        padding (Union[int, tuple, list]):  the size of the padding. If is `int`, uses the same padding in all boundaries. 
If a 4-`tuple`, uses (:math:`\\mathrm{padding_{left}}`, :math:`\\mathrm{padding_{right}}`, :math:`\\mathrm{padding_{top}}`, :math:`\\mathrm{padding_{bottom}}`)\n\n    Shape:\n        - Input: :math:`(N, C, H_{\\\\text{in}}, W_{\\\\text{in}})` or :math:`(C, H_{in}, W_{in})`\n        - Output: :math:`(N, C, H_{\\\\text{out}}, W_{\\\\text{out}})` or :math:`(C, H_{out}, W_{out})` where\n\n            :math:`H_{out} = H_{in} + \\\\mathrm{padding_{top}} + \\\\mathrm{padding_{bottom}}`\n\n            :math:`W_{out} = W_{in} + \\\\mathrm{padding_{left}} + \\\\mathrm{padding_{right}}`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> m = flow.nn.ReplicationPad2d((2, 2, 1, 1))\n        >>> input = flow.tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32))\n        >>> output = m(input)\n        >>> output.shape\n        oneflow.Size([1, 2, 5, 7])\n        >>> output\n        tensor([[[[ 0.,  0.,  0.,  1.,  2.,  2.,  2.],\n                  [ 0.,  0.,  0.,  1.,  2.,  2.,  2.],\n                  [ 3.,  3.,  3.,  4.,  5.,  5.,  5.],\n                  [ 6.,  6.,  6.,  7.,  8.,  8.,  8.],\n                  [ 6.,  6.,  6.,  7.,  8.,  8.,  8.]],\n        <BLANKLINE>\n                 [[ 9.,  9.,  9., 10., 11., 11., 11.],\n                  [ 9.,  9.,  9., 10., 11., 11., 11.],\n                  [12., 12., 12., 13., 14., 14., 14.],\n                  [15., 15., 15., 16., 17., 17., 17.],\n                  [15., 15., 15., 16., 17., 17., 17.]]]], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self, padding: _size_4_t):\n        super().__init__()\n        if isinstance(padding, (tuple, list)):\n            assert len(padding) == 4, ValueError(\"Length of padding must be 4\")\n            boundary = [*padding]\n        elif isinstance(padding, int):\n            boundary = _quadruple(padding)\n        else:\n            raise ValueError(\"padding must be int or list or tuple!\")\n        self.padding = boundary\n\n    def forward(self, x):\n        return flow._C.pad(x, pad=self.padding, mode=\"replicate\")\n\n    def extra_repr(self) -> str:\n        return \"{}\".format(self.padding)\n\n\nclass ReflectionPad1d(Module):\n    \"\"\"\n    ReflectionPad1d(padding)\n\n    This operator pads the input tensor using the reflection of the input boundary.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.ReflectionPad1d.html.\n\n    Args:\n        padding (Union[int,tuple]): The size or boundary of the padding. If it is an `int`, the same padding is used on both boundaries; if a 2-`tuple`, uses :math:`(\\\\text{padding}_{\\\\text{left}}, \\\\text{padding}_{\\\\text{right}})`\n\n    Returns:\n        Tensor: Returns a new tensor which is the result of the reflection padding of the input tensor.\n\n    Shape:\n        - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.\n        - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where\n\n          :math:`W_{out} = W_{in} + \\\\text{padding_left} + \\\\text{padding_right}`\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor(np.arange(18).reshape((2, 3, 3)).astype(np.float32))\n        >>> m = flow.nn.ReflectionPad1d((2, 2))\n        >>> out = m(input)\n        >>> out\n        tensor([[[ 2.,  1.,  0.,  1.,  2.,  1.,  0.],\n                 [ 5.,  4.,  3.,  4.,  5.,  4.,  3.],\n                 [ 8.,  7.,  6.,  7.,  8.,  7.,  6.]],\n        <BLANKLINE>\n                [[11., 10.,  9., 10., 11., 10.,  9.],\n                 [14., 13., 12., 13., 14., 13., 12.],\n                 [17., 16., 15., 16., 17., 16., 15.]]], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self, padding: _size_2_t) -> None:\n        super().__init__()\n        if isinstance(padding, tuple):\n            assert len(padding) == 2, ValueError(\"Padding length must be 2\")\n            boundary = [*padding]\n        elif isinstance(padding, int):\n            boundary = _pair(padding)\n        else:\n            raise ValueError(\"padding must be int or tuple!\")\n        self.padding = boundary\n\n    def forward(self, x):\n        return flow._C.pad(x, pad=self.padding, mode=\"reflect\")\n\n    def extra_repr(self) -> str:\n        return \"{}\".format(self.padding)\n\n\nclass ReflectionPad2d(Module):\n    \"\"\"\n    ReflectionPad2d(padding)\n\n    This operator pads the input tensor using the reflection of the input boundary.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.ReflectionPad2d.html.\n\n    Args:\n        padding (Union[int,tuple]): The size or boundary of the padding. If it is an `int`, the same padding is used in all dimensions; if a 4-`tuple`, uses :math:`(\\\\text{padding}_{\\\\text{left}}, \\\\text{padding}_{\\\\text{right}}, \\\\text{padding}_{\\\\text{top}}, \\\\text{padding}_{\\\\text{bottom}} )`\n\n    Returns:\n        Tensor: Returns a new tensor which is the result of the reflection padding of the input tensor.\n\n    Shape:\n        - Input: :math:`(N, C, H_{\\\\text{in}}, W_{\\\\text{in}})` or :math:`(C, H_{in}, W_{in})`\n        - Output: :math:`(N, C, H_{\\\\text{out}}, W_{\\\\text{out}})` or :math:`(C, H_{out}, W_{out})` where\n\n          :math:`H_{\\\\text{out}} = H_{\\\\text{in}} + \\\\text{padding}_{\\\\text{top}} + \\\\text{padding}_{\\\\text{bottom}}`\n\n          :math:`W_{\\\\text{out}} = W_{\\\\text{in}} + \\\\text{padding}_{\\\\text{left}} + \\\\text{padding}_{\\\\text{right}}`\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32))\n        >>> m = flow.nn.ReflectionPad2d((2, 2, 1, 1))\n        >>> out = m(input)\n        >>> out\n        tensor([[[[ 5.,  4.,  3.,  4.,  5.,  4.,  3.],\n                  [ 2.,  1.,  0.,  1.,  2.,  1.,  0.],\n                  [ 5.,  4.,  3.,  4.,  5.,  4.,  3.],\n                  [ 8.,  7.,  6.,  7.,  8.,  7.,  6.],\n                  [ 5.,  4.,  3.,  4.,  5.,  4.,  3.]],\n        <BLANKLINE>\n                 [[14., 13., 12., 13., 14., 13., 12.],\n                  [11., 10.,  9., 10., 11., 10.,  9.],\n                  [14., 13., 12., 13., 14., 13., 12.],\n                  [17., 16., 15., 16., 17., 16., 15.],\n                  [14., 13., 12., 13., 14., 13., 12.]]]], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self, padding: _size_4_t) -> None:\n        super().__init__()\n        if isinstance(padding, tuple):\n            assert len(padding) == 4, ValueError(\"Padding length must be 4\")\n            boundary = [*padding]\n        elif isinstance(padding, int):\n            boundary = _quadruple(padding)\n        else:\n            raise ValueError(\"padding must be int or tuple!\")\n        self.padding = boundary\n\n    def forward(self, x):\n        return flow._C.pad(x, pad=self.padding, mode=\"reflect\")\n\n    def extra_repr(self) -> str:\n        return \"{}\".format(self.padding)\n\n\nclass ConstantPad1d(Module):\n    \"\"\"\n    ConstantPad1d(padding)\n\n    Pads the input tensor boundaries with a constant value.\n\n    The interface is consistent with PyTorch, and referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.ConstantPad1d.html.\n\n    For `N`-dimensional padding, use :func:`oneflow.nn.functional.pad()`.\n\n    Args:\n        padding (int, list, tuple): the size of the padding. If is `int`, uses the same\n            padding in both boundaries. If a 2-`tuple`, uses\n            (:math:`\\\\text{padding_left}`, :math:`\\\\text{padding_right}`)\n\n        value (int, float): The constant value used for padding. Defaults to 0.\n\n    Shape:\n        - Input: :math:`(N, C, W_{in})`\n        - Output: :math:`(N, C, W_{out})` where\n\n          :math:`W_{out} = W_{in} + \\\\text{padding\\\\_left} + \\\\text{padding\\\\_right}`\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> input = flow.tensor(np.arange(8).reshape(2,2,2).astype(np.float32))\n        >>> m = flow.nn.ConstantPad1d(padding=[1, 2], value=9.9999)\n        >>> output = m(input)\n        >>> output\n        tensor([[[9.9999, 0.0000, 1.0000, 9.9999, 9.9999],\n                 [9.9999, 2.0000, 3.0000, 9.9999, 9.9999]],\n        <BLANKLINE>\n                [[9.9999, 4.0000, 5.0000, 9.9999, 9.9999],\n                 [9.9999, 6.0000, 7.0000, 9.9999, 9.9999]]], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0):\n        super().__init__()\n        if isinstance(padding, (tuple, list)):\n            boundary = padding\n        elif isinstance(padding, int):\n            boundary = [padding] * 2\n        else:\n            raise ValueError(\"padding must be int or list or tuple!\")\n        self.padding = boundary\n        self.value = value\n\n    def forward(self, x):\n        return flow._C.pad(x, pad=self.padding, mode=\"constant\", value=self.value)\n\n\nclass ConstantPad2d(Module):\n    \"\"\"\n    ConstantPad2d(padding)\n\n    This operator pads the input with a constant value specified by the user.\n    The amount of padding is set with the parameter `padding`.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.ConstantPad2d.html.\n\n    Args:\n        padding (int, tuple, list):  the size of the padding.\n            If is `int`, uses the same padding in all boundaries.\n            If a 4-`tuple`, uses\n            (:math:`\\\\mathrm{padding_{left}}`, :math:`\\\\mathrm{padding_{right}}`, :math:`\\\\mathrm{padding_{top}}`, :math:`\\\\mathrm{padding_{bottom}}`)\n\n        value (int, float): The constant value used for padding. Defaults to 0.\n\n    Shape:\n        - Input: :math:`(N, C, H_{in}, W_{in})`\n        - Output: :math:`(N, C, H_{out}, W_{out})` where\n\n          :math:`H_{out} = H_{in} + \\\\mathrm{padding_{top}} + \\\\mathrm{padding_{bottom}}`\n          :math:`W_{out} = W_{in} + \\\\mathrm{padding_{left}} + \\\\mathrm{padding_{right}}`\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> m = flow.nn.ConstantPad2d((2, 2, 1, 1), 1)\n        >>> input = flow.tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32))\n        >>> output = m(input)\n        >>> output.shape\n        oneflow.Size([1, 2, 5, 7])\n        >>> output\n        tensor([[[[ 1.,  1.,  1.,  1.,  1.,  1.,  1.],\n                  [ 1.,  1.,  0.,  1.,  2.,  1.,  1.],\n                  [ 1.,  1.,  3.,  4.,  5.,  1.,  1.],\n                  [ 1.,  1.,  6.,  7.,  8.,  1.,  1.],\n                  [ 1.,  1.,  1.,  1.,  1.,  1.,  1.]],\n        <BLANKLINE>\n                 [[ 1.,  1.,  1.,  1.,  1.,  1.,  1.],\n                  [ 1.,  1.,  9., 10., 11.,  1.,  1.],\n                  [ 1.,  1., 12., 13., 14.,  1.,  1.],\n                  [ 1.,  1., 15., 16., 17.,  1.,  1.],\n                  [ 1.,  1.,  1.,  1.,  1.,  1.,  1.]]]], dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0):\n        super().__init__()\n        if isinstance(padding, (tuple, list)):\n            boundary = padding\n        elif isinstance(padding, int):\n            boundary = [padding] * 4\n        else:\n            raise ValueError(\"padding must be int or list or tuple!\")\n        self.padding = boundary\n        self.value = value\n\n    def forward(self, x):\n        return flow._C.pad(x, pad=self.padding, mode=\"constant\", value=self.value)\n\n\nclass ConstantPad3d(Module):\n    \"\"\"\n    ConstantPad3d(padding)\n    \n    Pads the input tensor boundaries with a constant value.\n    The interface is consistent with PyTorch, and referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.ConstantPad3d.html.\n\n    For `N`-dimensional padding, use :func:`flow.nn.functional.pad()`.\n\n    Args:\n        padding (int, list, tuple): the size of the padding. If is `int`, uses the same\n            padding in all boundaries. If a 6-`tuple`, uses\n            (:math:`\\\\text{padding_left}`, :math:`\\\\text{padding_right}`,\n            :math:`\\\\text{padding_top}`, :math:`\\\\text{padding_bottom}`,\n            :math:`\\\\text{padding_front}`, :math:`\\\\text{padding_back}`)\n\n        value (int, float): The constant value used for padding. 
Defaults to 0.\n\n    Shape:\n        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`\n        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` where\n\n          :math:`D_{out} = D_{in} + \\\\text{padding_front} + \\\\text{padding_back}`\n\n          :math:`H_{out} = H_{in} + \\\\text{padding_top} + \\\\text{padding_bottom}`\n\n          :math:`W_{out} = W_{in} + \\\\text{padding_left} + \\\\text{padding_right}`\n\n    Examples::\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> input = flow.tensor(np.arange(8).reshape(1,1,2,2,2).astype(np.int32))\n        >>> m = flow.nn.ConstantPad3d(padding=1, value=9)\n        >>> output = m(input)\n        >>> output\n        tensor([[[[[9, 9, 9, 9],\n                   [9, 9, 9, 9],\n                   [9, 9, 9, 9],\n                   [9, 9, 9, 9]],\n        <BLANKLINE>\n                  [[9, 9, 9, 9],\n                   [9, 0, 1, 9],\n                   [9, 2, 3, 9],\n                   [9, 9, 9, 9]],\n        <BLANKLINE>\n                  [[9, 9, 9, 9],\n                   [9, 4, 5, 9],\n                   [9, 6, 7, 9],\n                   [9, 9, 9, 9]],\n        <BLANKLINE>\n                  [[9, 9, 9, 9],\n                   [9, 9, 9, 9],\n                   [9, 9, 9, 9],\n                   [9, 9, 9, 9]]]]], dtype=oneflow.int32)\n    \"\"\"\n\n    def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0):\n        super().__init__()\n        if isinstance(padding, (tuple, list)):\n            boundary = padding\n        elif isinstance(padding, int):\n            boundary = [padding] * 6\n        else:\n            raise ValueError(\"padding must be int or list or tuple!\")\n        self.padding = boundary\n        self.value = value\n\n    def forward(self, x):\n        return flow._C.pad(x, pad=self.padding, mode=\"constant\", value=self.value)\n\n\nclass ZeroPad2d(Module):\n    \"\"\"\n    ZeroPad2d(padding)\n\n    Pads the input tensor boundaries with zero. The amount of padding is set with the parameter `padding`.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.ZeroPad2d.html.\n\n    Args:\n        padding (Union[int, tuple]):  the size of the padding. If is `int`, uses the same padding in all boundaries. If a 4-`tuple`, uses (:math:`\\\\mathrm{padding_{left}}`, :math:`\\\\mathrm{padding_{right}}`, :math:`\\\\mathrm{padding_{top}}`, :math:`\\\\mathrm{padding_{bottom}}`)\n\n    Shape:\n        - Input: :math:`(N, C, H_{in}, W_{in})`\n        - Output: :math:`(N, C, H_{out}, W_{out})` where\n\n            :math:`H_{out} = H_{in} + \\\\mathrm{padding_{top}} + \\\\mathrm{padding_{bottom}}`\n\n            :math:`W_{out} = W_{in} + \\\\mathrm{padding_{left}} + \\\\mathrm{padding_{right}}`\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> m1 = flow.nn.ZeroPad2d(2)\n        >>> m2 = flow.nn.ZeroPad2d((1,2,2,0))\n        >>> input = flow.tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32))\n        >>> output = m1(input)\n        >>> output.shape\n        oneflow.Size([1, 2, 7, 7])\n        >>> output\n        tensor([[[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.],\n                  [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],\n                  [ 0.,  0.,  0.,  1.,  2.,  0.,  0.],\n                  [ 0.,  0.,  3.,  4.,  5.,  0.,  0.],\n                  [ 0.,  0.,  6.,  7.,  8.,  0.,  0.],\n                  [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],\n                  [ 0.,  0.,  0.,  0.,  0.,  0.,  0.]],\n        <BLANKLINE>\n                 [[ 0.,  0.,  0.,  0.,  0.,  0.,  0.],\n                  [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],\n                  [ 0.,  0.,  9., 10., 11.,  0.,  0.],\n                  [ 0.,  0., 12., 13., 14.,  0.,  0.],\n                  [ 0.,  0., 15., 16., 17.,  0.,  0.],\n                  [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],\n                  [ 0.,  0.,  0.,  0.,  0.,  0.,  0.]]]], dtype=oneflow.float32)\n        >>> output = m2(input)\n        >>> output\n        tensor([[[[ 0.,  0.,  0.,  0.,  0.,  0.],\n                  [ 0.,  0.,  0.,  0.,  0.,  0.],\n                  [ 0.,  0.,  1.,  2.,  0.,  0.],\n                  [ 0.,  3.,  4.,  5.,  0.,  0.],\n                  [ 0.,  6.,  7.,  8.,  0.,  0.]],\n        <BLANKLINE>\n                 [[ 0.,  0.,  0.,  0.,  0.,  0.],\n                  [ 0.,  0.,  0.,  0.,  0.,  0.],\n                  [ 0.,  9., 10., 11.,  0.,  0.],\n                  [ 0., 12., 13., 14.,  0.,  0.],\n                  [ 0., 15., 16., 17.,  0.,  0.]]]], dtype=oneflow.float32)\n    \"\"\"\n\n    def __init__(self, padding: Union[int, tuple, list]):\n        super().__init__()\n        if isinstance(padding, (tuple, list)):\n            boundary = padding\n        elif isinstance(padding, int):\n            boundary = [padding] * 4\n        else:\n            raise ValueError(\"padding must be int or list or tuple!\")\n        self.padding = boundary\n        self.value = 0.0\n\n    def forward(self, x):\n        return flow._C.pad(x, pad=self.padding, mode=\"constant\", value=self.value)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
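All of the modules in this file lower to one pad primitive with a different `mode` string (`"replicate"`, `"reflect"`, or `"constant"`), so the boundary semantics are easiest to compare on one short row. A sketch using the modules defined above (the commented values follow the replicate/reflect/constant rules described in the docstrings, not captured output):

```python
import numpy as np
import oneflow as flow

x = flow.tensor(np.array([[[1.0, 2.0, 3.0]]], dtype=np.float32))  # shape (1, 1, 3)

rep = flow.nn.ReplicationPad1d(2)(x)    # edges repeated:      [1, 1, 1, 2, 3, 3, 3]
ref = flow.nn.ReflectionPad1d(2)(x)     # mirrored, edge once: [3, 2, 1, 2, 3, 2, 1]
con = flow.nn.ConstantPad1d(2, 9.0)(x)  # constant fill:       [9, 9, 1, 2, 3, 9, 9]

# Each result has W_out = W_in + padding_left + padding_right = 3 + 2 + 2 = 7.
print(rep.shape, ref.shape, con.shape)
```

Note that reflection padding requires the padding to be smaller than the input extent (here 2 < 3), while replication and constant padding carry no such restriction.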
  {
    "path": "python/oneflow/nn/modules/pixelshuffle.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Optional\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.nn.modules.module import Module\n\n\nclass PixelShufflev2(Module):\n    \"\"\"\n    Part of the documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.PixelShuffle.html.\n\n    Rearranges elements in a tensor of shape :math:`(*, C \\\\times r_h \\\\times r_w, H, W)`\n    to a tensor of shape :math:`(*, C, H \\\\times r_h, W \\\\times r_w)`, where r_h and r_w are upscale factors.\n\n    This is useful for implementing efficient sub-pixel convolution\n    with a stride of :math:`1/r`.\n\n    See the paper:\n    `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_\n    by Shi et. al (2016) for more details.\n\n    Args:\n        upscale_factor (int, optional): factor to increase spatial resolution by, only use when factors of height and width spatial are the same.\n\n        h_upscale_factor (int, optional): factor to increase height spatial resolution by, only one of h_upscale_factor and upscale_factor can be used.\n        w_upscale_factor (int, optional): factor to increase width spatial resolution by, only one of w_upscale_factor and upscale_factor can be used.\n\n    Shape:\n        - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions\n        - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where\n\n    if use upscale_factor:\n\n    .. math::\n        C_{out} = C_{in} \\\\div \\\\text{h_upscale_factor}^2\n\n        H_{out} = H_{in} \\\\times \\\\text{upscale_factor}\n\n        W_{out} = W_{in} \\\\times \\\\text{upscale_factor}\n\n    if use h_upscale_factor and w_upscale_factor:\n\n    .. math::\n        C_{out} = C_{in} \\\\div \\\\text{h_upscale_factor} \\\\div \\\\text{w_upscale_factor}\n\n        H_{out} = H_{in} \\\\times \\\\text{h_upscale_factor}\n\n        W_{out} = W_{in} \\\\times \\\\text{w_upscale_factor}\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> m = flow.nn.PixelShuffle(upscale_factor=2)\n        >>> x = flow.Tensor(np.random.randn(3, 4, 5, 5))\n        >>> y = m(x)\n        >>> y.shape\n        oneflow.Size([3, 1, 10, 10])\n\n        >>> m = flow.nn.PixelShuffle(h_upscale_factor=3, w_upscale_factor=4)\n        >>> x = flow.Tensor(np.random.randn(1, 24, 2, 2))\n        >>> y = m(x)\n        >>> y.shape\n        oneflow.Size([1, 2, 6, 8])\n\n    .. 
_Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:\n        https://arxiv.org/abs/1609.05158\n    \"\"\"\n\n    def __init__(\n        self,\n        upscale_factor: Optional[int] = None,\n        h_upscale_factor: Optional[int] = None,\n        w_upscale_factor: Optional[int] = None,\n    ) -> None:\n        super().__init__()\n        if upscale_factor is None:\n            assert (\n                h_upscale_factor is not None and w_upscale_factor is not None\n            ), \"h_upscale_factor and w_upscale_factor should both be specified when upscale_factor is None\"\n        else:\n            assert (\n                h_upscale_factor is None and w_upscale_factor is None\n            ), \"h_upscale_factor and w_upscale_factor should be None when upscale_factor is specified\"\n            h_upscale_factor = upscale_factor\n            w_upscale_factor = upscale_factor\n        assert (\n            h_upscale_factor > 0 and w_upscale_factor > 0\n        ), \"The scale factor of height and width must be larger than zero\"\n        self.h_upscale_factor = h_upscale_factor\n        self.w_upscale_factor = w_upscale_factor\n\n    def forward(self, input: Tensor) -> Tensor:\n        return flow._C.pixel_shuffle(\n            input, self.h_upscale_factor, self.w_upscale_factor\n        )\n\n    def extra_repr(self) -> str:\n        return f\"w_upscale_factor={self.w_upscale_factor}, h_upscale_factor={self.h_upscale_factor}\"\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
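The docstring above specifies PixelShufflev2 by its shape rule; the rearrangement itself is a reshape/transpose pair over the channel factors. A numpy reference sketch of that rule (`pixel_shuffle_ref` is a name made up here, and matching `flow._C.pixel_shuffle`'s exact element ordering is an assumption):

```python
import numpy as np

def pixel_shuffle_ref(x, r_h, r_w):
    # (N, C*r_h*r_w, H, W) -> (N, C, H*r_h, W*r_w), the rule from the docstring.
    n, c_in, h, w = x.shape
    assert c_in % (r_h * r_w) == 0, "channels must be divisible by r_h * r_w"
    c = c_in // (r_h * r_w)
    x = x.reshape(n, c, r_h, r_w, h, w)
    x = x.transpose(0, 1, 4, 2, 5, 3)  # interleave the r_h/r_w factors into H and W
    return x.reshape(n, c, h * r_h, w * r_w)

x = np.random.randn(1, 24, 2, 2)
print(pixel_shuffle_ref(x, 3, 4).shape)  # (1, 2, 6, 8), as in the module example
```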
  {
    "path": "python/oneflow/nn/modules/pooling.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Optional, Union, List\nimport os\n\nimport oneflow as flow\nfrom oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t\nfrom oneflow.nn.modules.module import Module\nfrom oneflow.nn.modules.utils import (\n    _generate_output_size,\n    _getint,\n    _pair,\n    _single,\n    _triple,\n)\n\n\nclass MaxPool1d(Module):\n    r\"\"\"Applies a 1D max pooling over an input signal composed of several input planes.\n    \n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.MaxPool1d.html.\n\n    In the simplest case, the output value of the layer with input size :math:`(N, C, L)`\n    and output :math:`(N, C, L_{out})` can be precisely described as:\n\n    .. math::\n        out(N_i, C_j, k) = \\max_{m=0, \\ldots, \\text{kernel\\_size} - 1}\n                input(N_i, C_j, stride \\times k + m)\n\n    If :attr:`padding` is non-zero, then the input is implicitly padded with minimum value on both sides\n    for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the\n    sliding window. This link has a nice visualization of the pooling parameters.\n\n    Note:\n        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding\n        or the input. Sliding windows that would start in the right padded region are ignored.\n\n    Args:\n        kernel_size: The size of the sliding window, must be > 0.\n        stride: The stride of the sliding window, must be > 0. Default value is :attr:`kernel_size`.\n        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.\n        dilation: The stride between elements within a sliding window, must be > 0.\n        return_indices: If ``True``, will return the argmax along with the max values.\n        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This\n                   ensures that every element in the input tensor is covered by a sliding window.\n\n    Shape:\n        - Input: :math:`(N, C, L_{in})`\n        - Output: :math:`(N, C, L_{out})`, where\n\n          .. math::\n              L_{out} = \\left\\lfloor \\frac{L_{in} + 2 \\times \\text{padding} - \\text{dilation}\n                    \\times (\\text{kernel_size} - 1) - 1}{\\text{stride}} + 1\\right\\rfloor\n\n    For example: \n\n    .. 
code-block:: python\n\n        import oneflow as flow\n        import numpy as np\n\n        of_maxpool1d = flow.nn.MaxPool1d(kernel_size=3, padding=1, stride=1)\n        x = flow.Tensor(np.random.randn(1, 4, 4))\n        y = of_maxpool1d(x)\n        y.shape  # oneflow.Size([1, 4, 4])\n\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: _size_1_t,\n        stride: Optional[_size_1_t] = None,\n        padding: _size_1_t = 0,\n        dilation: _size_1_t = 1,\n        return_indices: bool = False,\n        ceil_mode: bool = False,\n    ):\n        super().__init__()\n        self.kernel_size = _single(kernel_size)\n        self.stride = _single(stride) if stride is not None else self.kernel_size\n        data_format = \"NCL\"  # only support \"NCL\" for now !\n        self.channel_pos = \"channels_first\" if data_format == \"NCL\" else \"channels_last\"\n        self.dilation = _single(dilation)\n        self.padding = _single(padding)\n        self.return_indices = return_indices\n        self.ceil_mode = ceil_mode\n\n    def forward(self, x):\n        y, indices = flow._C.max_pool1d(\n            x,\n            kernel_size=self.kernel_size,\n            stride=self.stride,\n            padding=self.padding,\n            dilation=self.dilation,\n            return_indices=True,\n            ceil_mode=self.ceil_mode,\n            data_format=self.channel_pos,\n        )\n        if self.return_indices:\n            return y, indices\n        else:\n            return y\n\n    def extra_repr(self) -> str:\n        return \"kernel_size={}, stride={}, padding={}\".format(\n            self.kernel_size, self.stride, self.padding\n        )\n\n\ndef get_dhw_offset(channel_pos):\n    if channel_pos == \"channels_first\":\n        return 2\n    else:\n        return 1\n\n\ndef get_ndim_pads_list(padding, dhw_offset, ndims):\n    pads_list = []\n    for i in range(len(padding)):\n        pad = padding[i]\n        if isinstance(pad, int):\n            pad = [pad, pad]\n        elif isinstance(pad, (list, tuple)):\n            assert len(pad) == 2\n            pad = [pad[0], pad[1]]\n        else:\n            raise ValueError(\"padding must be int, list or tuple\")\n        if i in range(dhw_offset, dhw_offset + ndims):\n            pads_list.append(pad)\n        else:\n            assert pad == [0, 0]\n    return pads_list\n\n\ndef calc_pool_padding(padding, dhw_offset, ndims):\n    if isinstance(padding, str):\n        padding = \"SAME_LOWER\" if padding.upper() == \"SAME\" else padding\n        assert padding.upper() in [\"VALID\", \"SAME_LOWER\", \"SAME_UPPER\"]\n        padding_type = padding.lower()\n        ndim_pads_list = [[0, 0]] * ndims\n    elif isinstance(padding, (list, tuple)):\n        padding_type = \"customized\"\n        ndim_pads_list = get_ndim_pads_list(padding, dhw_offset, ndims)\n    else:\n        raise ValueError(\"padding must be str or a list.\")\n    return (padding_type, ndim_pads_list)\n\n\nclass MaxPool2d(Module):\n    r\"\"\"Applies a 2D max pooling over an input signal composed of several input planes.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.MaxPool2d.html.\n\n    In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`,\n    output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)`\n    can be precisely described as:\n\n    .. 
math::\n        \\begin{aligned}\n            out(N_i, C_j, h, w) ={} & \\max_{m=0, \\ldots, kH-1} \\max_{n=0, \\ldots, kW-1} \\\\\n                                    & \\text{input}(N_i, C_j, \\text{stride[0]} \\times h + m,\n                                                   \\text{stride[1]} \\times w + n)\n        \\end{aligned}\n\n    If :attr:`padding` is non-zero, then the input is implicitly minimum value padded on both sides\n    for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.\n    It is harder to describe, but this link has a nice visualization of what :attr:`dilation` does.\n\n    Note:\n        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding\n        or the input. Sliding windows that would start in the right padded region are ignored.\n\n    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:\n\n        - a single ``int`` -- in which case the same value is used for the height and width dimension\n        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,\n          and the second `int` for the width dimension\n\n    Args:\n        kernel_size: the size of the window to take a max over\n        stride: the stride of the window. Default value is :attr:`kernel_size`\n        padding: implicit minimum value padding to be added on both sides\n        dilation: a parameter that controls the stride of elements in the window\n        return_indices: if ``True``, will return the max indices along with the outputs.\n                        Useful for :class:`torch.nn.MaxUnpool2d` later\n        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape\n\n    Shape:\n        - Input: :math:`(N, C, H_{in}, W_{in})`\n        - Output: :math:`(N, C, H_{out}, W_{out})`, where\n\n          .. math::\n              H_{out} = \\left\\lfloor\\frac{H_{in} + 2 * \\text{padding[0]} - \\text{dilation[0]}\n                    \\times (\\text{kernel_size[0]} - 1) - 1}{\\text{stride[0]}} + 1\\right\\rfloor\n\n          .. math::\n              W_{out} = \\left\\lfloor\\frac{W_{in} + 2 * \\text{padding[1]} - \\text{dilation[1]}\n                    \\times (\\text{kernel_size[1]} - 1) - 1}{\\text{stride[1]}} + 1\\right\\rfloor\n\n    For example:\n\n    .. 
code-block:: python\n\n        import oneflow as flow\n        import numpy as np\n\n        m = flow.nn.MaxPool2d(kernel_size=3, padding=1, stride=1)\n        x = flow.Tensor(np.random.randn(1, 4, 4, 4))\n        y = m(x)\n        y.shape  # oneflow.Size([1, 4, 4, 4])\n\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: _size_2_t,\n        stride: Optional[_size_2_t] = None,\n        padding: _size_2_t = 0,\n        dilation: _size_2_t = 1,\n        return_indices: bool = False,\n        ceil_mode: bool = False,\n    ):\n        super().__init__()\n        self.kernel_size = _pair(kernel_size)\n        self.stride = _pair(stride) if (stride is not None) else _pair(kernel_size)\n        self.padding = _pair(padding)\n        self.dilation = _pair(dilation)\n        self.return_indices = return_indices\n        self.ceil_mode = ceil_mode\n        if os.getenv(\"ONEFLOW_ENABLE_NHWC\") == \"1\":\n            self.channel_pos = \"channels_last\"\n        else:\n            self.channel_pos = \"channels_first\"\n\n    def to_memory_format(self, memory_format) -> None:\n        if memory_format is flow.channels_last:\n            self.channel_pos = \"channels_last\"\n        elif memory_format is flow.contiguous_format:\n            self.channel_pos = \"channels_first\"\n\n    def forward(self, x):\n        ret = flow._C.max_pool2d(\n            x,\n            kernel_size=self.kernel_size,\n            stride=self.stride,\n            padding=self.padding,\n            dilation=self.dilation,\n            return_indices=self.return_indices,\n            ceil_mode=self.ceil_mode,\n            data_format=self.channel_pos,\n        )\n        # take only the pooled values when indices were not requested (the op returns them at index 0)\n        return ret if self.return_indices else ret[0]\n\n    def extra_repr(self) -> str:\n        return \"kernel_size={}, stride={}, padding={}, dilation={}\".format(\n            self.kernel_size, self.stride, self.padding, self.dilation\n        )\n\n\nclass MaxPool3d(Module):\n    r\"\"\"Applies a 3D max pooling over an input signal composed of several input planes.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.MaxPool3d.html.\n\n    In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`,\n    output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)`\n    can be precisely described as:\n\n    .. math::\n        \\begin{aligned}\n            \\text{out}(N_i, C_j, d, h, w) ={} & \\max_{k=0, \\ldots, kD-1} \\max_{m=0, \\ldots, kH-1} \\max_{n=0, \\ldots, kW-1} \\\\\n                                              & \\text{input}(N_i, C_j, \\text{stride[0]} \\times d + k,\n                                                             \\text{stride[1]} \\times h + m, \\text{stride[2]} \\times w + n)\n        \\end{aligned}\n\n    If :attr:`padding` is non-zero, then the input is implicitly minimum value padded on both sides\n    for :attr:`padding` number of points. 
:attr:`dilation` controls the spacing between the kernel points.\n    It is harder to describe, but this link has a nice visualization of what :attr:`dilation` does.\n\n    Note:\n        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding\n        or the input. Sliding windows that would start in the right padded region are ignored.\n\n    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:\n\n        - a single ``int`` -- in which case the same value is used for the depth, height and width dimension\n        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,\n          the second `int` for the height dimension and the third `int` for the width dimension\n\n    Args:\n        kernel_size: the size of the window to take a max over\n        stride: the stride of the window. Default value is :attr:`kernel_size`\n        padding: implicit minimum value padding to be added on all three sides\n        dilation: a parameter that controls the stride of elements in the window\n        return_indices: if ``True``, will return the max indices along with the outputs.\n                        Useful for :class:`torch.nn.MaxUnpool3d` later\n        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape\n\n    Shape:\n        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`\n        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where\n\n          .. math::\n              D_{out} = \\left\\lfloor\\frac{D_{in} + 2 \\times \\text{padding}[0] - \\text{dilation}[0] \\times\n                (\\text{kernel_size}[0] - 1) - 1}{\\text{stride}[0]} + 1\\right\\rfloor\n\n          .. math::\n              H_{out} = \\left\\lfloor\\frac{H_{in} + 2 \\times \\text{padding}[1] - \\text{dilation}[1] \\times\n                (\\text{kernel_size}[1] - 1) - 1}{\\text{stride}[1]} + 1\\right\\rfloor\n\n          .. math::\n              W_{out} = \\left\\lfloor\\frac{W_{in} + 2 \\times \\text{padding}[2] - \\text{dilation}[2] \\times\n                (\\text{kernel_size}[2] - 1) - 1}{\\text{stride}[2]} + 1\\right\\rfloor\n\n    For example:\n\n    .. 
code-block:: python\n\n        import oneflow as flow\n        import numpy as np\n\n        of_maxpool3d = flow.nn.MaxPool3d(kernel_size=3, padding=1, stride=1)\n        x = flow.Tensor(np.random.randn(1, 4, 4, 4, 4))\n        y = of_maxpool3d(x)\n        y.shape  # oneflow.Size([1, 4, 4, 4, 4])\n\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: _size_3_t,\n        stride: Optional[_size_3_t] = None,\n        padding: _size_3_t = 0,\n        dilation: _size_3_t = 1,\n        return_indices: bool = False,\n        ceil_mode: bool = False,\n    ):\n        super().__init__()\n        self.kernel_size = _triple(kernel_size)\n        self.stride = _triple(stride) if (stride is not None) else _triple(kernel_size)\n        data_format = \"NCDHW\"\n        self.channel_pos = (\n            \"channels_last\" if data_format == \"NDHWC\" else \"channels_first\"\n        )\n        self.dilation = _triple(dilation)\n        self.padding = _triple(padding)\n        self.return_indices = return_indices\n        self.ceil_mode = ceil_mode\n\n    def forward(self, x):\n        y, indices = flow._C.max_pool3d(\n            x,\n            kernel_size=self.kernel_size,\n            stride=self.stride,\n            padding=self.padding,\n            dilation=self.dilation,\n            return_indices=True,\n            ceil_mode=self.ceil_mode,\n            data_format=self.channel_pos,\n        )\n\n        if self.return_indices:\n            return y, indices\n        else:\n            return y\n\n    def extra_repr(self) -> str:\n        return \"kernel_size={}, stride={}, padding={}, dilation={}\".format(\n            self.kernel_size, self.stride, self.padding, self.dilation\n        )\n\n\nclass AvgPool1d(Module):\n    r\"\"\"Applies a 1D average pooling over an input signal composed of several input planes.\n    In the simplest case, the output value of the layer with input size :math:`(N, C, L)`,\n    output :math:`(N, C, L_{out})` and `kernel_size` :math:`k`\n    can be precisely described as:\n\n    .. math::\n        out(N_i, C_j, l) = \\frac{1}{k} \\sum_{m=0}^{k-1}\n                               input(N_i, C_j, stride \\times l + m)\n\n    If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points.\n    The parameters kernel_size, stride, padding can each be an int or a one-element tuple.\n\n    Note:\n        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding or the\n        input. Sliding windows that would start in the right padded region are ignored.\n\n    Args:\n        kernel_size: the size of the window.\n        stride: the stride of the window. Default value is kernel_size.\n        padding: implicit zero padding to be added on both sides.\n        ceil_mode: when True, will use ceil instead of floor to compute the output shape.\n        count_include_pad: when True, will include the zero-padding in the averaging calculation.\n\n    For example:\n\n    .. 
code-block:: python\n\n        import oneflow as flow\n        import numpy as np\n\n        m = flow.nn.AvgPool1d(kernel_size=3, padding=1, stride=1)\n        x = flow.tensor(np.random.randn(1, 4, 4))\n        y = m(x)\n        y.shape  # oneflow.Size([1, 4, 4])\n\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: _size_1_t,\n        stride: Optional[_size_1_t] = None,\n        padding: _size_1_t = 0,\n        ceil_mode: bool = False,\n        count_include_pad: bool = True,\n    ):\n        super().__init__()\n        self.kernel_size = _single(kernel_size)\n        data_format = \"NCHW\"  # only support \"NCHW\" for now !\n        self.channel_pos = (\n            \"channels_first\" if data_format == \"NCHW\" else \"channels_last\"\n        )\n        self.stride = _single(stride) if (stride is not None) else _single(kernel_size)\n        self.ceil_mode = ceil_mode\n        self.count_include_pad = count_include_pad\n        self.padding = _single(padding)\n\n    def forward(self, x):\n        return flow._C.avg_pool1d(\n            x,\n            kernel_size=self.kernel_size,\n            stride=self.stride,\n            padding=self.padding,\n            ceil_mode=self.ceil_mode,\n            count_include_pad=self.count_include_pad,\n            divisor_override=0,\n            data_format=self.channel_pos,\n        )\n\n    def extra_repr(self) -> str:\n        return (\n            \"kernel_size={kernel_size}, stride={stride}, padding={padding}\"\n            \", ceil_mode={ceil_mode}\".format(**self.__dict__)\n        )\n\n\nclass AvgPool2d(Module):\n    r\"\"\"Performs the 2d-average pooling on the input.\n\n    In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`,\n    output :math:`(N, C, H_{out}, W_{out})` and `kernel_size` :math:`(kH, kW)`\n    can be precisely described as:\n\n    .. math::\n\n        out(N_i, C_j, h, w)  = \\frac{1}{kH * kW} \\sum_{m=0}^{kH-1} \\sum_{n=0}^{kW-1}\n                               input(N_i, C_j, stride[0] \\times h + m, stride[1] \\times w + n)\n\n    Args:\n        kernel_size (Union[int, Tuple[int, int]]):  An int or list of ints that has length 1, 2. The size of the window for each dimension of the input Tensor.\n        stride (Union[int, Tuple[int, int]]): An int or list of ints that has length 1, 2. The stride of the sliding window for each dimension of the input Tensor.\n        padding (Union[int, Tuple[int, int]]): An int or list of ints that has length 1, 2. Implicit zero padding to be added on both sides.\n        ceil_mode (bool, default: False): When True, will use ceil instead of floor to compute the output shape.\n\n    For example:\n\n    .. 
code-block:: python\n\n        import oneflow as flow\n        import numpy as np\n\n        m = flow.nn.AvgPool2d(kernel_size=3, padding=1, stride=1)\n        x = flow.tensor(np.random.randn(1, 4, 4, 4))\n        y = m(x)\n        y.shape  # oneflow.Size([1, 4, 4, 4])\n\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: _size_2_t,\n        stride: Optional[_size_2_t] = None,\n        padding: _size_2_t = 0,\n        ceil_mode: bool = False,\n        count_include_pad: bool = True,\n        divisor_override: int = 0,\n    ):\n        super().__init__()\n        self.kernel_size = _pair(kernel_size)\n        self.stride = _pair(stride) if (stride is not None) else _pair(kernel_size)\n        self.ceil_mode = ceil_mode\n        self.channel_pos = \"channels_first\"\n        if os.getenv(\"ONEFLOW_ENABLE_NHWC\") == \"1\":\n            self.channel_pos = \"channels_last\"\n        self.padding = _pair(padding)\n        self.count_include_pad = count_include_pad\n        self.divisor_override = int(divisor_override)\n\n    def to_memory_format(self, memory_format) -> None:\n        if memory_format is flow.channels_last:\n            self.channel_pos = \"channels_last\"\n        elif memory_format is flow.contiguous_format:\n            self.channel_pos = \"channels_first\"\n\n    def forward(self, x):\n        return flow._C.avg_pool2d(\n            x,\n            kernel_size=self.kernel_size,\n            stride=self.stride,\n            padding=self.padding,\n            ceil_mode=self.ceil_mode,\n            count_include_pad=self.count_include_pad,\n            divisor_override=self.divisor_override,\n            data_format=self.channel_pos,\n        )\n\n    def extra_repr(self) -> str:\n        return (\n            \"kernel_size={kernel_size}, stride={stride}, padding={padding}\"\n            \", ceil_mode={ceil_mode}\".format(**self.__dict__)\n        )\n\n\nclass AvgPool3d(Module):\n    r\"\"\"Applies a 3D average pooling over an input signal composed of several input planes.\n    In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`,\n    output :math:`(N, C, D_{out}, H_{out}, W_{out})` and `kernel_size` :math:`(kD, kH, kW)`\n    can be precisely described as:\n\n    .. math::\n        out(N_i, C_j, d, h, w) = \\frac{1}{kD * kH * kW} \\sum_{k=0}^{kD-1} \\sum_{m=0}^{kH-1} \\sum_{n=0}^{kW-1}\n                               input(N_i, C_j, stride[0] \\times d + k, stride[1] \\times h + m, stride[2] \\times w + n)\n\n    If padding is non-zero, then the input is implicitly zero-padded on all three sides for padding number of points.\n\n    Note:\n        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding or the\n        input. Sliding windows that would start in the right padded region are ignored.\n\n    Args:\n        kernel_size: the size of the window.\n        stride: the stride of the window. 
Default value is kernel_size.\n        padding: implicit zero padding to be added on all three sides.\n        ceil_mode: when True, will use ceil instead of floor to compute the output shape.\n        count_include_pad: when True, will include the zero-padding in the averaging calculation.\n        divisor_override: if specified, it will be used as the divisor, otherwise ``kernel_size`` will be used.\n\n    Shape:\n        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`\n\n        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where\n\n          .. math::\n              D_{out} = \\left\\lfloor\\frac{D_{in} + 2 \\times \\text{padding}[0] - \\text{kernel\\_size}[0]}{\\text{stride}[0]} + 1\\right\\rfloor\n\n          .. math::\n              H_{out} = \\left\\lfloor\\frac{H_{in} + 2 \\times \\text{padding}[1] - \\text{kernel\\_size}[1]}{\\text{stride}[1]} + 1\\right\\rfloor\n\n          .. math::\n              W_{out} = \\left\\lfloor\\frac{W_{in} + 2 \\times \\text{padding}[2] - \\text{kernel\\_size}[2]}{\\text{stride}[2]} + 1\\right\\rfloor\n\n    For example:\n\n    .. code-block:: python\n\n        import oneflow as flow\n        import numpy as np\n\n        m = flow.nn.AvgPool3d(kernel_size=(2, 2, 2), padding=(0, 0, 0), stride=(1, 1, 1))\n        x = flow.tensor(np.random.randn(9, 7, 11, 32, 20))\n        y = m(x)\n        y.shape\n        oneflow.Size([9, 7, 10, 31, 19])\n\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: _size_3_t,\n        stride: Optional[_size_3_t] = None,\n        padding: _size_3_t = 0,\n        ceil_mode: bool = False,\n        count_include_pad: bool = True,\n        divisor_override: int = 0,\n    ):\n        super().__init__()\n        self.kernel_size = _triple(kernel_size)\n        data_format = \"NCHW\"  # only supports \"NCHW\" for now\n        self.channel_pos = (\n            \"channels_first\" if data_format == \"NCHW\" else \"channels_last\"\n        )\n        self.stride = _triple(stride) if (stride is not None) else _triple(kernel_size)\n        self.ceil_mode = ceil_mode\n        self.count_include_pad = count_include_pad\n        self.divisor_override = int(divisor_override)\n        self.padding = _triple(padding)\n\n    def forward(self, x):\n        return flow._C.avg_pool3d(\n            x,\n            kernel_size=self.kernel_size,\n            stride=self.stride,\n            padding=self.padding,\n            ceil_mode=self.ceil_mode,\n            count_include_pad=self.count_include_pad,\n            divisor_override=self.divisor_override,\n            data_format=self.channel_pos,\n        )\n\n    def extra_repr(self) -> str:\n        return (\n            \"kernel_size={kernel_size}, stride={stride}, padding={padding}\"\n            \", ceil_mode={ceil_mode}\".format(**self.__dict__)\n        )\n\n\nclass AdaptiveAvgPool1d(Module):\n    \"\"\"Applies a 1D adaptive average pooling over an input signal composed of several input planes.\n\n    The output size is H, for any input size.\n    The number of output features is equal to the number of input planes.\n\n    Args:\n        output_size: the target output size H\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n\n        >>> m = nn.AdaptiveAvgPool1d(5)\n        >>> input = flow.Tensor(np.random.randn(1, 64, 8))\n        >>> output = m(input)\n        >>> output.size()\n        oneflow.Size([1, 64, 5])\n\n    \"\"\"\n\n    def __init__(self, output_size: _size_1_t) -> None:\n        super().__init__()\n        assert output_size is not None, \"'output_size' cannot be NoneType\"\n        self.output_size = _single(output_size)\n\n    def forward(self, x):\n        assert (\n            len(x.shape) == 3 and len(self.output_size) == 1\n        ), \"the length of 'output_size' does not match the input size, 1 expected\"\n        assert isinstance(\n            self.output_size[0], int\n        ), \"numbers in 'output_size' should be integers\"\n        return flow._C.adaptive_avg_pool1d(x, output_size=self.output_size)\n\n\ndef adaptive_avg_pool1d(input, output_size):\n    \"\"\"Applies a 1D adaptive average pooling over an input signal composed of several input planes.\n\n    See :class:`oneflow.nn.AdaptiveAvgPool1d`\n\n    Args:\n        input: input tensor\n        output_size: the target output size (single integer)\n    \"\"\"\n    return AdaptiveAvgPool1d(output_size)(input)\n\n\nclass AdaptiveAvgPool2d(Module):\n    \"\"\"Applies a 2D adaptive average pooling over an input signal composed of several input planes.\n\n    The output is of size H x W, for any input size.\n    The number of output features is equal to the number of input planes.\n\n    Args:\n        output_size: the target output size of the image of the form H x W.\n                     Can be a tuple (H, W) or a single H for a square image H x H.\n                     H and W can be either an ``int``, or ``None`` which means the size will\n                     be the same as that of the input.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n\n        >>> m = nn.AdaptiveAvgPool2d((5,7))\n        >>> input = flow.Tensor(np.random.randn(1, 64, 8, 9))\n        >>> output = m(input)\n        >>> output.size()\n        oneflow.Size([1, 64, 5, 7])\n\n        >>> m = nn.AdaptiveAvgPool2d(7)\n        >>> input = flow.Tensor(np.random.randn(1, 64, 10, 9))\n        >>> output = m(input)\n        >>> output.size()\n        oneflow.Size([1, 64, 7, 7])\n\n        >>> m = nn.AdaptiveAvgPool2d((None, 7))\n        >>> input = flow.Tensor(np.random.randn(1, 64, 10, 9))\n        >>> output = m(input)\n        >>> output.size()\n        oneflow.Size([1, 64, 10, 7])\n\n    \"\"\"\n\n    def __init__(self, output_size, data_format=None) -> None:\n        super().__init__()\n        assert output_size is not None, \"'output_size' cannot be NoneType\"\n        self.output_size = _pair(output_size)\n        if data_format:\n            if data_format not in [\"channels_first\", \"channels_last\"]:\n                raise ValueError(\n                    f\"data_format must be one of ['channels_first', 'channels_last'], but got {data_format}\"\n                )\n            self.channel_pos = data_format\n        elif os.getenv(\"ONEFLOW_ENABLE_NHWC\") == \"1\":\n            self.channel_pos = \"channels_last\"\n        else:\n            self.channel_pos = \"channels_first\"\n\n    def to_memory_format(self, memory_format) -> None:\n        if memory_format is flow.channels_last:\n            self.channel_pos = \"channels_last\"\n        elif memory_format is flow.channels_first:\n            self.channel_pos = \"channels_first\"\n\n    def forward(self, x):\n        assert (\n            len(x.shape) == 4\n        ), f\"expected 4-dimensional tensor, but got {len(x.shape)}-dimensional tensor\"\n        new_output_size = _generate_output_size(x.shape, self.output_size)\n        return flow._C.adaptive_avg_pool2d(\n            x, output_size=new_output_size, data_format=self.channel_pos\n        )\n\n\ndef adaptive_avg_pool2d(input, output_size, data_format=None):\n    \"\"\"Applies a 2D adaptive average pooling over an input signal composed of several input planes.\n\n    See :class:`oneflow.nn.AdaptiveAvgPool2d`\n\n    Args:\n        input: input tensor\n        output_size: the target output size (single integer or double-integer tuple)\n        data_format: \"channels_first\" or \"channels_last\". Defaults to None, which selects the layout from the ONEFLOW_ENABLE_NHWC environment variable.\n    \"\"\"\n    return AdaptiveAvgPool2d(output_size, data_format)(input)\n\n\nclass AdaptiveAvgPool3d(Module):\n    \"\"\"Applies a 3D adaptive average pooling over an input signal composed of several input planes.\n\n    The output is of size D x H x W, for any input size.\n    The number of output features is equal to the number of input planes.\n\n    Args:\n        output_size: the target output size of the form D x H x W.\n                     Can be a tuple (D, H, W) or a single number D for a cube D x D x D.\n                     D, H and W can be either an ``int``, or ``None`` which means the size will\n                     be the same as that of the input.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n\n        >>> m = nn.AdaptiveAvgPool3d((5,7,9))\n        >>> input = flow.Tensor(np.random.randn(1, 64, 8, 9, 10))\n        >>> output = m(input)\n        >>> output.size()\n        oneflow.Size([1, 64, 5, 7, 9])\n\n        >>> m = nn.AdaptiveAvgPool3d(7)\n        >>> input = flow.Tensor(np.random.randn(1, 64, 10, 9, 8))\n        >>> output = m(input)\n        >>> output.size()\n        oneflow.Size([1, 64, 7, 7, 7])\n\n        >>> m = nn.AdaptiveAvgPool3d((7, None, None))\n        >>> input = flow.Tensor(np.random.randn(1, 64, 10, 9, 8))\n        >>> output = m(input)\n        >>> output.size()\n        oneflow.Size([1, 64, 7, 9, 8])\n\n    \"\"\"\n\n    def __init__(self, output_size) -> None:\n        super().__init__()\n        assert output_size is not None, \"'output_size' cannot be NoneType\"\n        self.output_size = _triple(output_size)\n\n    def forward(self, x):\n        assert (\n            len(x.shape) == 5\n        ), f\"expected 5-dimensional tensor, but got {len(x.shape)}-dimensional tensor\"\n        new_output_size = _generate_output_size(x.shape, self.output_size)\n        return flow._C.adaptive_avg_pool3d(x, output_size=new_output_size)\n\n\ndef adaptive_avg_pool3d(input, output_size):\n    \"\"\"Applies a 3D adaptive average pooling over an input signal composed of several input planes.\n\n    See :class:`oneflow.nn.AdaptiveAvgPool3d`\n\n    Args:\n        input: input tensor\n        output_size: the target output size (single integer or triple-integer tuple)\n    \"\"\"\n    return AdaptiveAvgPool3d(output_size)(input)\n\n\nclass _AdaptiveMaxPoolNd(Module):\n    def __init__(self, output_size, return_indices: bool = False) -> None:\n        super().__init__()\n        self.output_size = output_size\n        self.return_indices = return_indices\n\n    def extra_repr(self) -> str:\n        return \"output_size={}\".format(self.output_size)\n\n\nclass AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):\n    r\"\"\"Applies a 1D adaptive max pooling over an input signal composed of several input planes.\n\n        The documentation is referenced from:\n        https://pytorch.org/docs/1.10/generated/torch.nn.AdaptiveMaxPool1d.html.\n\n        The output size is :math:`L_{out}`, for any input size.\n        The number of output features is equal to the number of input planes.\n\n        Args:\n            output_size: the target output size :math:`L_{out}`.\n            return_indices: if ``True``, will return the indices along with the outputs.\n                            Default: ``False``\n\n        Shape:\n            - Input: :math:`(N, C, L_{in})`.\n            - Output: :math:`(N, C, L_{out})`, where :math:`L_{out}=\\text{output\\_size}`.\n\n        Examples:\n\n        .. 
code-block:: python\n\n            >>> import oneflow as flow\n            >>> # target output size of 5\n            >>> m = flow.nn.AdaptiveMaxPool1d(5)\n            >>> input = flow.randn(1, 64, 8)\n            >>> output = m(input)\n            >>> print(output.shape)\n            oneflow.Size([1, 64, 5])\n\n    \"\"\"\n\n    def forward(self, input):\n        self.output_size = _single(self.output_size)\n        assert (\n            len(input.shape) == 3 and len(self.output_size) == 1\n        ), \"the length of 'output_size' does not match the input size, 1 expected\"\n        return flow.nn.functional.adaptive_max_pool1d(\n            input, self.output_size, self.return_indices\n        )\n\n\nclass AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):\n    r\"\"\"Applies a 2D adaptive max pooling over an input signal composed of several input planes.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.AdaptiveMaxPool2d.html.\n\n    The output is of size :math:`H_{out} \\times W_{out}`, for any input size.\n    The number of output features is equal to the number of input planes.\n\n    Args:\n        output_size: the target output size of the image of the form :math:`H_{out} \\times W_{out}`.\n                     Can be a tuple :math:`(H_{out}, W_{out})` or a single :math:`H_{out}` for a\n                     square image :math:`H_{out} \\times H_{out}`. :math:`H_{out}` and :math:`W_{out}`\n                     should each be an ``int``.\n        return_indices: if ``True``, will return the indices along with the outputs.\n                        Default: ``False``\n\n    Shape:\n        - Input: :math:`(N, C, H_{in}, W_{in})`.\n        - Output: :math:`(N, C, H_{out}, W_{out})`, where\n          :math:`(H_{out}, W_{out})=\\text{output\\_size}`.\n\n    Examples:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n        >>> # target output size of 5x7\n        >>> m = nn.AdaptiveMaxPool2d((5,7))\n        >>> input = flow.randn(1, 64, 8, 9)\n        >>> output = m(input)\n        >>> print(output.shape)\n        oneflow.Size([1, 64, 5, 7])\n        >>> # target output size of 7x7 (square)\n        >>> m = nn.AdaptiveMaxPool2d(7)\n        >>> input = flow.randn(1, 64, 10, 9)\n        >>> output = m(input)\n        >>> print(output.shape)\n        oneflow.Size([1, 64, 7, 7])\n    \"\"\"\n\n    def __init__(self, output_size, return_indices=False, data_format=None) -> None:\n        super().__init__(output_size, return_indices=return_indices)\n        if data_format:\n            if data_format not in [\"channels_first\", \"channels_last\"]:\n                raise ValueError(\n                    f\"data_format must be one of ['channels_first', 'channels_last'], but got {data_format}\"\n                )\n            self.channel_pos = data_format\n        elif os.getenv(\"ONEFLOW_ENABLE_NHWC\") == \"1\":\n            self.channel_pos = \"channels_last\"\n        else:\n            self.channel_pos = \"channels_first\"\n\n    def to_memory_format(self, memory_format) -> None:\n        if memory_format is flow.channels_last:\n            self.channel_pos = \"channels_last\"\n        elif memory_format is flow.channels_first:\n            self.channel_pos = \"channels_first\"\n\n    def forward(self, input):\n        self.output_size = _pair(self.output_size)\n        assert (\n            len(input.shape) == 4\n        ), f\"expected 4-dimensional tensor, but got {len(input.shape)}-dimensional tensor\"\n        return flow.nn.functional.adaptive_max_pool2d(\n            input, self.output_size, self.return_indices, self.channel_pos\n        )\n\n\nclass AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):\n    r\"\"\"Applies a 3D adaptive max pooling over an input signal composed of several input planes.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/generated/torch.nn.AdaptiveMaxPool3d.html.\n\n    The output is of size :math:`D_{out} \\times H_{out} \\times W_{out}`, for any input size.\n    The number of output features is equal to the number of input planes.\n\n    Args:\n        output_size: the target output size of the image of the form :math:`D_{out} \\times H_{out} \\times W_{out}`.\n                     Can be a tuple :math:`(D_{out}, H_{out}, W_{out})` or a single\n                     :math:`D_{out}` for a cube :math:`D_{out} \\times D_{out} \\times D_{out}`.\n                     :math:`D_{out}`, :math:`H_{out}` and :math:`W_{out}` should each be an\n                     ``int``.\n\n        return_indices: if ``True``, will return the indices along with the outputs.\n                        Default: ``False``\n\n    Shape:\n        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`.\n        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`,\n          where :math:`(D_{out}, H_{out}, W_{out})=\\text{output\\_size}`.\n\n    Examples:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n        >>> # target output size of 5x7x9\n        >>> m = nn.AdaptiveMaxPool3d((5,7,9))\n        >>> input = flow.randn(1, 64, 8, 9, 10)\n        >>> output = m(input)\n        >>> print(output.shape)\n        oneflow.Size([1, 64, 5, 7, 9])\n        >>> # target output size of 7x7x7 (cube)\n        >>> m = nn.AdaptiveMaxPool3d(7)\n        >>> input = flow.randn(1, 64, 10, 9, 8)\n        >>> output = m(input)\n        >>> print(output.shape)\n        oneflow.Size([1, 64, 7, 7, 7])\n    \"\"\"\n\n    def forward(self, input):\n        self.output_size = _triple(self.output_size)\n        assert (\n            len(input.shape) == 5\n        ), f\"expected 5-dimensional tensor, but got {len(input.shape)}-dimensional tensor\"\n        return flow.nn.functional.adaptive_max_pool3d(\n            input, self.output_size, self.return_indices\n        )\n\n\nclass MaxUnpool1d(Module):\n    r\"\"\"Computes a partial inverse of :class:`MaxPool1d`.\n\n    :class:`MaxPool1d` is not fully invertible, since the non-maximal values are lost.\n\n    :class:`MaxUnpool1d` takes in as input the output of :class:`MaxPool1d`\n    including the indices of the maximal values and computes a partial inverse\n    in which all non-maximal values are set to zero.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.MaxUnpool1d.html.\n\n    .. note:: :class:`MaxPool1d` can map several input sizes to the same output\n              sizes. Hence, the inversion process can get ambiguous.\n              To accommodate this, you can provide the needed output size\n              as an additional argument :attr:`output_size` in the forward call.\n              See the Inputs and Example below.\n\n    Args:\n        kernel_size (int or tuple): Size of the max pooling window.\n        stride (int or tuple): Stride of the max pooling window.\n            It is set to :attr:`kernel_size` by default.\n        padding (int or tuple): Padding that was added to the input\n\n    Inputs:\n        - `input`: the input Tensor to invert\n        - `indices`: the indices given out by :class:`~oneflow.nn.MaxPool1d`\n        - `output_size` (optional): the targeted output size\n\n    Shape:\n        - Input: :math:`(N, C, H_{in})`.\n        - Output: :math:`(N, C, H_{out})`, where\n\n          .. math::\n              H_{out} = (H_{in} - 1) \\times \\text{stride}[0] - 2 \\times \\text{padding}[0] + \\text{kernel\\_size}[0]\n\n          or as given by :attr:`output_size` in the call operator\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> pool = flow.nn.MaxPool1d(2, stride=2, return_indices=True)\n        >>> unpool = flow.nn.MaxUnpool1d(2, stride=2)\n        >>> input = flow.tensor([[[1., 2, 3, 4, 5, 6, 7, 8]]])\n        >>> output, indices = pool(input)\n        >>> unpool(output, indices)\n        tensor([[[0., 2., 0., 4., 0., 6., 0., 8.]]], dtype=oneflow.float32)\n        >>> # Example showcasing the use of output_size\n        >>> input = flow.tensor([[[1., 2, 3, 4, 5, 6, 7, 8, 9]]])\n        >>> output, indices = pool(input)\n        >>> unpool(output, indices, output_size=input.size())\n        tensor([[[0., 2., 0., 4., 0., 6., 0., 8., 0.]]], dtype=oneflow.float32)\n        >>> unpool(output, indices)\n        tensor([[[0., 2., 0., 4., 0., 6., 0., 8.]]], dtype=oneflow.float32)\n\n    .. note:: When `indices` contains elements out of the `output_size` range,\n              a RuntimeError will be raised on CPU and an indeterminate\n              result will be computed on CUDA.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: _size_1_t,\n        stride: Optional[_size_1_t] = None,\n        padding: _size_1_t = 0,\n    ):\n        super().__init__()\n        self.kernel_size = kernel_size\n        self.stride = stride\n        self.padding = padding\n\n    def forward(self, x, indices, output_size=None):\n        return flow._C.max_unpool1d(\n            x, indices, self.kernel_size, self.stride, self.padding, output_size\n        )\n\n\nclass MaxUnpool2d(Module):\n    r\"\"\"Computes a partial inverse of :class:`MaxPool2d`.\n\n    :class:`MaxPool2d` is not fully invertible, since the non-maximal values are lost.\n\n    :class:`MaxUnpool2d` takes in as input the output of :class:`MaxPool2d`\n    including the indices of the maximal values and computes a partial inverse\n    in which all non-maximal values are set to zero.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.MaxUnpool2d.html.\n\n    .. note:: :class:`MaxPool2d` can map several input sizes to the same output\n              sizes. Hence, the inversion process can get ambiguous.\n              To accommodate this, you can provide the needed output size\n              as an additional argument :attr:`output_size` in the forward call.\n              See the Inputs and Example below.\n\n    Args:\n        kernel_size (int or tuple): Size of the max pooling window.\n        stride (int or tuple): Stride of the max pooling window.\n            It is set to :attr:`kernel_size` by default.\n        padding (int or tuple): Padding that was added to the input\n\n    Inputs:\n        - `input`: the input Tensor to invert\n        - `indices`: the indices given out by :class:`~oneflow.nn.MaxPool2d`\n        - `output_size` (optional): the targeted output size\n\n    Shape:\n        - Input: :math:`(N, C, H_{in}, W_{in})`.\n        - Output: :math:`(N, C, H_{out}, W_{out})`, where\n\n          .. math::\n            H_{out} = (H_{in} - 1) \\times \\text{stride[0]} - 2 \\times \\text{padding[0]} + \\text{kernel\\_size[0]}\n\n          .. math::\n            W_{out} = (W_{in} - 1) \\times \\text{stride[1]} - 2 \\times \\text{padding[1]} + \\text{kernel\\_size[1]}\n\n          or as given by :attr:`output_size` in the call operator\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> pool = flow.nn.MaxPool2d(2, stride=2, return_indices=True)\n        >>> unpool = flow.nn.MaxUnpool2d(2, stride=2)\n        >>> input = flow.tensor([[[[ 1.,  2,  3,  4],\n        ...                         [ 5,  6,  7,  8],\n        ...                         [ 9, 10, 11, 12],\n        ...                         [13, 14, 15, 16]]]])\n        >>> output, indices = pool(input)\n        >>> unpool(output, indices) # doctest: +SKIP\n        tensor([[[[ 0.,  0.,  0.,  0.],\n                [ 0.,  6.,  0.,  8.],\n                [ 0.,  0.,  0.,  0.],\n                [ 0., 14.,  0., 16.]]]], dtype=oneflow.float32)\n        >>> # specify a different output size than input size\n        >>> unpool(output, indices, output_size=flow.Size([1, 1, 5, 5])) # doctest: +SKIP\n        tensor([[[[ 0.,  0.,  0.,  0.,  0.],\n                [ 6.,  0.,  8.,  0.,  0.],\n                [ 0.,  0.,  0., 14.,  0.],\n                [16.,  0.,  0.,  0.,  0.],\n                [ 0.,  0.,  0.,  0.,  0.]]]], dtype=oneflow.float32)\n\n    .. note:: When `indices` contains elements out of the `output_size` range,\n              a RuntimeError will be raised on CPU and an indeterminate\n              result will be computed on CUDA.\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: _size_2_t,\n        stride: Optional[_size_2_t] = None,\n        padding: _size_2_t = 0,\n    ):\n        super().__init__()\n        self.kernel_size = kernel_size\n        self.stride = stride\n        self.padding = padding\n\n    def forward(self, x, indices, output_size=None):\n        return flow._C.max_unpool2d(\n            x, indices, self.kernel_size, self.stride, self.padding, output_size\n        )\n\n\nclass MaxUnpool3d(Module):\n    r\"\"\"Computes a partial inverse of :class:`MaxPool3d`.\n\n    :class:`MaxPool3d` is not fully invertible, since the non-maximal values are lost.\n    :class:`MaxUnpool3d` takes in as input the output of :class:`MaxPool3d`\n    including the indices of the maximal values and computes a partial inverse\n    in which all non-maximal values are set to zero.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.MaxUnpool3d.html.\n\n    .. note:: :class:`MaxPool3d` can map several input sizes to the same output\n              sizes. Hence, the inversion process can get ambiguous.\n              To accommodate this, you can provide the needed output size\n              as an additional argument :attr:`output_size` in the forward call.\n              See the Inputs section below.\n\n    Args:\n        kernel_size (int or tuple): Size of the max pooling window.\n        stride (int or tuple): Stride of the max pooling window.\n            It is set to :attr:`kernel_size` by default.\n        padding (int or tuple): Padding that was added to the input\n\n    Inputs:\n        - `input`: the input Tensor to invert\n        - `indices`: the indices given out by :class:`~oneflow.nn.MaxPool3d`\n        - `output_size` (optional): the targeted output size\n\n    Shape:\n        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`.\n        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where\n\n          .. math::\n              D_{out} = (D_{in} - 1) \\times \\text{stride[0]} - 2 \\times \\text{padding[0]} + \\text{kernel\\_size[0]}\n\n          .. 
math::\n              H_{out} = (H_{in} - 1) \\times \\text{stride[1]} - 2 \\times \\text{padding[1]} + \\text{kernel\\_size[1]}\n\n          .. math::\n              W_{out} = (W_{in} - 1) \\times \\text{stride[2]} - 2 \\times \\text{padding[2]} + \\text{kernel\\_size[2]}\n\n          or as given by :attr:`output_size` in the call operator\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> # pool of square window of size=3, stride=2\n        >>> pool = flow.nn.MaxPool3d(3, stride=2, return_indices=True)\n        >>> unpool = flow.nn.MaxUnpool3d(3, stride=2)\n        >>> output, indices = pool(flow.randn(20, 16, 51, 33, 15))\n        >>> unpooled_output = unpool(output, indices)\n        >>> unpooled_output.size()\n        oneflow.Size([20, 16, 51, 33, 15])\n\n    .. note:: When `indices` contains elements out of the `output_size` range,\n              a RuntimeError will be raised on CPU and an indeterminate\n              result will be computed on CUDA.\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: _size_3_t,\n        stride: Optional[_size_3_t] = None,\n        padding: _size_3_t = 0,\n    ):\n        super().__init__()\n        self.kernel_size = kernel_size\n        self.stride = stride\n        self.padding = padding\n\n    def forward(self, x, indices, output_size=None):\n        return flow._C.max_unpool3d(\n            x, indices, self.kernel_size, self.stride, self.padding, output_size\n        )\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
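The Shape formulas in the pooling docstrings above are easy to check numerically. A minimal sketch, assuming a standard OneFlow install; `expected_avg_pool3d_shape` is a hypothetical helper written for illustration, not part of the module:

.. code-block:: python

    import numpy as np
    import oneflow as flow

    def expected_avg_pool3d_shape(in_shape, kernel_size, stride, padding):
        # floor((X_in + 2*padding - kernel_size) / stride + 1) per spatial dim,
        # i.e. the AvgPool3d Shape section with ceil_mode=False.
        n, c, *spatial = in_shape
        out = [
            (x + 2 * p - k) // s + 1
            for x, k, s, p in zip(spatial, kernel_size, stride, padding)
        ]
        return (n, c, *out)

    m = flow.nn.AvgPool3d(kernel_size=(2, 2, 2), stride=(1, 1, 1), padding=(0, 0, 0))
    x = flow.tensor(np.random.randn(9, 7, 11, 32, 20), dtype=flow.float32)
    assert tuple(m(x).shape) == expected_avg_pool3d_shape(
        (9, 7, 11, 32, 20), (2, 2, 2), (1, 1, 1), (0, 0, 0)
    )  # (9, 7, 10, 31, 19), matching the docstring example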
  {
    "path": "python/oneflow/nn/modules/quantization.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.nn.modules.module import Module\n\n\nclass Quantization(Module):\n    \"\"\"\n    \n    Simulate the quantize operation in inference time.\n\n    The output will be computed as:\n\n        if quantization_scheme == \"symmetric\":\n\n        .. math::\n\n            & quant\\\\_max = 2^{quantization\\\\_to\\\\_bit - 1} - 1\n\n            & quant\\\\_min = -quant\\\\_max\n\n            & clamp(round(x / scale), quant\\\\_min, quant\\\\_max)\n\n        elif quantization_scheme == \"affine\":\n\n        .. math::\n\n            & quant\\\\_max = 2^{quantization\\\\_to\\\\_bit} - 1\n\n            & quant\\\\_min = 0\n\n            & (clamp(round(x / scale + zero\\\\_point), quant\\\\_min, quant\\\\_max) - zero\\\\_point)\n\n    Args:\n        quantization_bit (int): Quantize input to uintX / intX, X can be in range [2, 8]. Defaults to 8.\n        quantization_scheme (str): \"symmetric\" or \"affine\", quantize to signed / unsigned integer. Defaults to \"symmetric\".\n        quantization_formula (str): Support \"google\" or \"cambricon\".\n\n    Returns:\n        oneflow.Tensor: Input tensor after quantize operation.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> weight = (np.random.random((2, 3, 4, 5)) - 0.5).astype(np.float32)\n        \n        >>> input_tensor = flow.tensor(\n        ...    weight, dtype=flow.float32\n        ... )\n        \n        >>> quantization_bit = 8\n        >>> quantization_scheme = \"symmetric\"\n        >>> quantization_formula = \"google\"\n        >>> per_layer_quantization = True\n\n        >>> min_max_observer = flow.nn.MinMaxObserver(quantization_formula=quantization_formula, quantization_bit=quantization_bit,\n        ... quantization_scheme=quantization_scheme, per_layer_quantization=per_layer_quantization)\n        >>> quantization = flow.nn.Quantization(quantization_formula=quantization_formula, quantization_bit=quantization_bit, \n        ... quantization_scheme=quantization_scheme)\n\n        >>> scale, zero_point = min_max_observer(\n        ...    input_tensor,\n        ... )\n\n        >>> output_tensor = quantization(\n        ...    input_tensor,\n        ...    scale,\n        ...    zero_point,\n        ... 
)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        quantization_formula: str = \"google\",\n        quantization_bit: int = 8,\n        quantization_scheme: str = \"symmetric\",\n    ) -> None:\n        super().__init__()\n        self.quantization_formula = quantization_formula\n        self.quantization_bit = quantization_bit\n        self.quantization_scheme = quantization_scheme\n\n    def forward(self, input, scale, zero_point):\n        return flow._C.quantization(\n            input,\n            scale,\n            zero_point,\n            self.quantization_formula,\n            self.quantization_bit,\n            self.quantization_scheme,\n        )\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
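The symmetric branch of the Quantization formula above can be reproduced in plain NumPy. A minimal sketch; the per-layer scale `max(|x|) / quant_max` mirrors what the min-max observer in the example typically produces for the "google" formula, and is an assumption here rather than the module's exact implementation:

.. code-block:: python

    import numpy as np

    def symmetric_quantize(x, quantization_bit=8):
        # quant_max = 2^(quantization_bit - 1) - 1, quant_min = -quant_max,
        # then clamp(round(x / scale), quant_min, quant_max), as in the docstring.
        quant_max = 2 ** (quantization_bit - 1) - 1
        scale = np.abs(x).max() / quant_max  # assumed observer behavior
        return np.clip(np.round(x / scale), -quant_max, quant_max)

    x = (np.random.random((2, 3, 4, 5)) - 0.5).astype(np.float32)
    q = symmetric_quantize(x)
    assert q.min() >= -127 and q.max() <= 127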
  {
    "path": "python/oneflow/nn/modules/reshape.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Sequence\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import register_tensor_op\nfrom oneflow.nn.modules.module import Module\n\n\ndef _input_args_is_int(args):\n    return all((isinstance(x, int) for x in args))\n\n\ndef _input_args_is_flow_size(args):\n    return all((isinstance(x, flow.Size) for x in args)) and len(args) == 1\n\n\ndef reshape_op(input, shape: Sequence[int] = None):\n    \"\"\"This operator reshapes a Tensor.\n\n    We can set one dimension in `shape` as `-1`, the operator will infer the complete shape.\n\n    Args:\n        x: A Tensor.\n        shape: Shape of the output tensor.\n    Returns:\n        A Tensor has the same type as `x`.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> x = np.array(\n        ...    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]\n        ... ).astype(np.float32)\n        >>> input = flow.Tensor(x)\n\n        >>> y = flow.reshape(input, shape=[2, 2, 2, -1]).shape\n        >>> y\n        oneflow.Size([2, 2, 2, 2])\n\n    \"\"\"\n    return flow._C.reshape(input, shape)\n\n\ndef view_op(input, *shape):\n    if len(shape) == 1:\n        new_shape = shape[0]\n        if isinstance(new_shape, int):\n            new_shape = (new_shape,)\n    else:\n        new_shape = shape\n    return flow._C.view(input, new_shape)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
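The `-1` placeholder accepted by `reshape_op` above is resolved by dividing the element count by the product of the known extents. A minimal sketch of that inference rule; `infer_shape` is a hypothetical helper for illustration, not part of this module:

.. code-block:: python

    from functools import reduce

    def infer_shape(numel, shape):
        # The single -1 entry becomes numel / product(known extents).
        known = reduce(lambda a, b: a * b, (d for d in shape if d != -1), 1)
        assert numel % known == 0, "shape is incompatible with the element count"
        return tuple(numel // known if d == -1 else d for d in shape)

    # Mirrors the doctest in reshape_op: a 4x4 tensor reshaped to [2, 2, 2, -1].
    assert infer_shape(16, [2, 2, 2, -1]) == (2, 2, 2, 2)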
  {
    "path": "python/oneflow/nn/modules/rnn.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport math\nimport warnings\nimport numbers\nfrom typing import List, Tuple, Optional\n\nimport oneflow as flow\nfrom oneflow import nn\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.nn.utils.rnn import PackedSequence\n\n# NOTE(Liang Depeng): The implementation of rnn modules are modified from\n#                     https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py\ndef apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:\n    return tensor.index_select(dim, permutation)\n\n\nclass RNNBase(nn.Module):\n    def __init__(\n        self,\n        mode: str,\n        input_size: int,\n        hidden_size: int,\n        num_layers: int = 1,\n        bias: bool = True,\n        batch_first: bool = False,\n        dropout: float = 0.0,\n        bidirectional: bool = False,\n        proj_size: int = 0,\n        device=None,\n        dtype=None,\n    ) -> None:\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.mode = mode\n        self.input_size = input_size\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.bias = bias\n        self.batch_first = batch_first\n        self.dropout = float(dropout)\n        self.bidirectional = bidirectional\n        self.proj_size = proj_size\n        num_directions = 2 if bidirectional else 1\n\n        if (\n            not isinstance(dropout, numbers.Number)\n            or not 0 <= dropout <= 1\n            or isinstance(dropout, bool)\n        ):\n            raise ValueError(\n                \"dropout should be a number in range [0, 1] \"\n                \"representing the probability of an element being \"\n                \"zeroed\"\n            )\n        if dropout > 0 and num_layers == 1:\n            warnings.warn(\n                \"dropout option adds dropout after all but last \"\n                \"recurrent layer, so non-zero dropout expects \"\n                \"num_layers greater than 1, but got dropout={} and \"\n                \"num_layers={}\".format(dropout, num_layers)\n            )\n        if proj_size < 0:\n            raise ValueError(\n                \"proj_size should be a positive integer or zero to disable projections\"\n            )\n        if proj_size >= hidden_size:\n            raise ValueError(\"proj_size has to be smaller than hidden_size\")\n\n        if mode == \"LSTM\":\n            gate_size = 4 * hidden_size\n        elif mode == \"GRU\":\n            gate_size = 3 * hidden_size\n        elif mode == \"RNN_TANH\":\n            gate_size = hidden_size\n        elif mode == \"RNN_RELU\":\n            gate_size = hidden_size\n        else:\n            raise ValueError(\"Unrecognized RNN mode: \" + mode)\n\n        self._flat_weights_names = []\n        self._all_weights = []\n        for layer in range(num_layers):\n            for 
direction in range(num_directions):\n                real_hidden_size = proj_size if proj_size > 0 else hidden_size\n                layer_input_size = (\n                    input_size if layer == 0 else real_hidden_size * num_directions\n                )\n\n                w_ih = nn.Parameter(\n                    flow.empty((gate_size, layer_input_size), **factory_kwargs)\n                )\n                w_hh = nn.Parameter(\n                    flow.empty((gate_size, real_hidden_size), **factory_kwargs)\n                )\n                b_ih = nn.Parameter(flow.empty(gate_size, **factory_kwargs))\n                b_hh = nn.Parameter(flow.empty(gate_size, **factory_kwargs))\n                layer_params: Tuple[Tensor, ...] = ()\n                if self.proj_size == 0:\n                    if bias:\n                        layer_params = (w_ih, w_hh, b_ih, b_hh)\n                    else:\n                        layer_params = (w_ih, w_hh)\n                else:\n                    w_hr = nn.Parameter(\n                        flow.empty((proj_size, hidden_size), **factory_kwargs)\n                    )\n                    if bias:\n                        layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr)\n                    else:\n                        layer_params = (w_ih, w_hh, w_hr)\n\n                suffix = \"_reverse\" if direction == 1 else \"\"\n                param_names = [\"weight_ih_l{}{}\", \"weight_hh_l{}{}\"]\n                if bias:\n                    param_names += [\"bias_ih_l{}{}\", \"bias_hh_l{}{}\"]\n                if self.proj_size > 0:\n                    param_names += [\"weight_hr_l{}{}\"]\n                param_names = [x.format(layer, suffix) for x in param_names]\n\n                for name, param in zip(param_names, layer_params):\n                    setattr(self, name, param)\n                self._flat_weights_names.extend(param_names)\n                self._all_weights.append(param_names)\n\n        self._flat_weights = [\n            (lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn)\n            for wn in self._flat_weights_names\n        ]\n        self.reset_parameters()\n\n    def __setattr__(self, attr, value):\n        if hasattr(self, \"_flat_weights_names\") and attr in self._flat_weights_names:\n            # keep self._flat_weights up to date if you do self.weight = ...\n            idx = self._flat_weights_names.index(attr)\n            self._flat_weights[idx] = value\n        super().__setattr__(attr, value)\n\n    def to_global(self, placement=None, sbp=None):\n        def convert(t):\n            return t.to_global(placement=placement, sbp=sbp)\n\n        self = self._apply(convert)\n        self._flat_weights = [\n            (lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn)\n            for wn in self._flat_weights_names\n        ]\n        return self\n\n    def reset_parameters(self) -> None:\n        stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0\n        for weight in self.parameters():\n            nn.init.uniform_(weight, -stdv, stdv)\n\n    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:\n        expected_input_dim = 2 if batch_sizes is not None else 3\n        if input.dim() != expected_input_dim:\n            raise RuntimeError(\n                \"input must have {} dimensions, got {}\".format(\n                    expected_input_dim, input.dim()\n                )\n            )\n        if self.input_size != input.size(-1):\n    
        raise RuntimeError(\n                \"input.size(-1) must be equal to input_size. Expected {}, got {}\".format(\n                    self.input_size, input.size(-1)\n                )\n            )\n\n    def get_expected_hidden_size(\n        self, input: Tensor, batch_sizes: Optional[Tensor]\n    ) -> Tuple[int, int, int]:\n        if batch_sizes is not None:\n            mini_batch = int(batch_sizes[0])\n        else:\n            mini_batch = input.size(0) if self.batch_first else input.size(1)\n        num_directions = 2 if self.bidirectional else 1\n        if self.proj_size > 0:\n            expected_hidden_size = (\n                self.num_layers * num_directions,\n                mini_batch,\n                self.proj_size,\n            )\n        else:\n            expected_hidden_size = (\n                self.num_layers * num_directions,\n                mini_batch,\n                self.hidden_size,\n            )\n        return expected_hidden_size\n\n    def check_hidden_size(\n        self,\n        hx: Tensor,\n        expected_hidden_size: Tuple[int, int, int],\n        msg: str = \"Expected hidden size {}, got {}\",\n    ) -> None:\n        if hx.size() != expected_hidden_size:\n            raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))\n\n    def check_forward_args(\n        self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]\n    ):\n        self.check_input(input, batch_sizes)\n        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)\n\n        self.check_hidden_size(hidden, expected_hidden_size)\n\n    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]):\n        if permutation is None:\n            return hx\n        return apply_permutation(hx, permutation)\n\n    def extra_repr(self) -> str:\n        s = \"{input_size}, {hidden_size}\"\n        if self.proj_size != 0:\n            s += \", proj_size={proj_size}\"\n        if self.num_layers != 1:\n            s += \", num_layers={num_layers}\"\n        if self.bias is not True:\n            s += \", bias={bias}\"\n        if self.batch_first is not False:\n            s += \", batch_first={batch_first}\"\n        if self.dropout != 0:\n            s += \", dropout={dropout}\"\n        if self.bidirectional is not False:\n            s += \", bidirectional={bidirectional}\"\n        return s.format(**self.__dict__)\n\n    @property\n    def all_weights(self) -> List[List[nn.Parameter]]:\n        return [\n            [getattr(self, weight) for weight in weights]\n            for weights in self._all_weights\n        ]\n\n\nclass RNN(RNNBase):\n    r\"\"\"\n    Applies a multi-layer Elman RNN with :math:`\\tanh` or :math:`\\text{ReLU}` non-linearity to an input sequence.\n\n    For each element in the input sequence, each layer computes the following function:\n\n    .. 
math::\n        h_t = \\tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})\n\n    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is\n    the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the\n    previous layer at time `t-1` or the initial hidden state at time `0`.\n    If :attr:`nonlinearity` is ``'relu'``, then :math:`\\text{ReLU}` is used instead of :math:`\\tanh`.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.RNN.html.\n\n    Args:\n        input_size: The number of expected features in the input `x`\n        hidden_size: The number of features in the hidden state `h`\n        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``\n            would mean stacking two RNNs together to form a `stacked RNN`,\n            with the second RNN taking in outputs of the first RNN and\n            computing the final results. Default: 1\n        nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``\n        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.\n            Default: ``True``\n        batch_first: If ``True``, then the input and output tensors are provided\n            as `(batch, seq, feature)` instead of `(seq, batch, feature)`.\n            Note that this does not apply to hidden or cell states. See the\n            Inputs/Outputs sections below for details.  Default: ``False``\n        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each\n            RNN layer except the last layer, with dropout probability equal to\n            :attr:`dropout`. Default: 0\n        bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``\n\n    Inputs: input, h_0\n        * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or\n          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of\n          the input sequence.\n        * **h_0**: tensor of shape :math:`(D * \\text{num\\_layers}, N, H_{out})` containing the initial hidden\n          state for each element in the batch. Defaults to zeros if not provided.\n\n        where:\n\n        .. math::\n            \\begin{aligned}\n                N ={} & \\text{batch size} \\\\\n                L ={} & \\text{sequence length} \\\\\n                D ={} & 2 \\text{ if bidirectional=True otherwise } 1 \\\\\n                H_{in} ={} & \\text{input\\_size} \\\\\n                H_{out} ={} & \\text{hidden\\_size}\n            \\end{aligned}\n\n    Outputs: output, h_n\n        * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or\n          :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features\n          `(h_t)` from the last layer of the RNN, for each `t`.\n        * **h_n**: tensor of shape :math:`(D * \\text{num\\_layers}, N, H_{out})` containing the final hidden state\n          for each element in the batch.\n\n    Attributes:\n        weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,\n            of shape `(hidden_size, input_size)` for `k = 0`. 
Otherwise, the shape is\n            `(hidden_size, num_directions * hidden_size)`\n        weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,\n            of shape `(hidden_size, hidden_size)`\n        bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,\n            of shape `(hidden_size)`\n        bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,\n            of shape `(hidden_size)`\n\n    .. note::\n        All the weights and biases are initialized from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})`\n        where :math:`k = \\frac{1}{\\text{hidden\\_size}}`\n    \n    .. note::\n        For bidirectional RNNs, forward and backward are directions 0 and 1 respectively.\n        Example of splitting the output layers when ``batch_first=False``:\n        ``output.view((seq_len, batch, num_directions, hidden_size))``.\n    \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> rnn = flow.nn.RNN(10, 20, 2)\n        >>> input = flow.tensor(np.random.randn(5, 3, 10), dtype=flow.float32)\n        >>> h0 = flow.tensor(np.random.randn(2, 3, 20), dtype=flow.float32)\n        >>> output, hn = rnn(input, h0)\n        >>> output.size()\n        oneflow.Size([5, 3, 20])\n\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        if \"proj_size\" in kwargs:\n            raise ValueError(\n                \"proj_size argument is only supported for LSTM, not RNN or GRU\"\n            )\n        self.nonlinearity = kwargs.pop(\"nonlinearity\", \"tanh\")\n        if self.nonlinearity == \"tanh\":\n            mode = \"RNN_TANH\"\n        elif self.nonlinearity == \"relu\":\n            mode = \"RNN_RELU\"\n        else:\n            raise ValueError(\"Unknown nonlinearity '{}'\".format(self.nonlinearity))\n        super().__init__(mode, *args, **kwargs)\n\n    def forward(self, input, hx=None):  # noqa: F811\n        orig_input = input\n        if isinstance(orig_input, PackedSequence):\n            input = orig_input.data\n            batch_sizes = orig_input.batch_sizes\n            sorted_indices = orig_input.sorted_indices\n            unsorted_indices = orig_input.unsorted_indices\n            max_batch_size = int(batch_sizes[0])\n        else:\n            batch_sizes = None\n            is_batched = input.dim() == 3\n            batch_dim = 0 if self.batch_first else 1\n            if not is_batched:\n                input = input.unsqueeze(batch_dim)\n                if hx is not None:\n                    if hx.dim() != 2:\n                        raise RuntimeError(\n                            f\"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor\"\n                        )\n                    hx = hx.unsqueeze(1)\n            else:\n                if hx is not None and hx.dim() != 3:\n                    raise RuntimeError(\n                        f\"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor\"\n                    )\n            max_batch_size = input.size(0) if self.batch_first else input.size(1)\n            sorted_indices = None\n            unsorted_indices = None\n\n        if hx is None:\n            num_directions = 2 if self.bidirectional else 1\n            if input.is_global:\n                hx = flow.zeros(\n                    self.num_layers * num_directions,\n                    max_batch_size,\n                    self.hidden_size,\n                    dtype=input.dtype,\n                   
 sbp=input.sbp,\n                    placement=input.placement,\n                )\n            else:\n                hx = flow.zeros(\n                    self.num_layers * num_directions,\n                    max_batch_size,\n                    self.hidden_size,\n                    dtype=input.dtype,\n                    device=input.device,\n                )\n        else:\n            # Each batch of the hidden state should match the input sequence that\n            # the user believes he/she is passing in.\n            hx = self.permute_hidden(hx, sorted_indices)\n        self._flat_weights = [\n            (lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn)\n            for wn in self._flat_weights_names\n        ]\n        assert hx is not None\n        self.check_forward_args(input, hx, batch_sizes)\n        assert self.mode == \"RNN_TANH\" or self.mode == \"RNN_RELU\"\n        if batch_sizes is None:\n            if self.mode == \"RNN_TANH\":\n                result = flow._C.rnn_tanh(\n                    input,\n                    hx,\n                    self._flat_weights,\n                    self.bias,\n                    self.num_layers,\n                    self.dropout,\n                    self.training,\n                    self.bidirectional,\n                    self.batch_first,\n                )\n            else:\n                result = flow._C.rnn_relu(\n                    input,\n                    hx,\n                    self._flat_weights,\n                    self.bias,\n                    self.num_layers,\n                    self.dropout,\n                    self.training,\n                    self.bidirectional,\n                    self.batch_first,\n                )\n        else:\n            if self.mode == \"RNN_TANH\":\n                result = flow._C.rnn_tanh(\n                    input,\n                    batch_sizes,\n                    hx,\n                    self._flat_weights,\n                    self.bias,\n                    self.num_layers,\n                    self.dropout,\n                    self.training,\n                    self.bidirectional,\n                )\n            else:\n                result = flow._C.rnn_relu(\n                    input,\n                    batch_sizes,\n                    hx,\n                    self._flat_weights,\n                    self.bias,\n                    self.num_layers,\n                    self.dropout,\n                    self.training,\n                    self.bidirectional,\n                )\n\n        output = result[0]\n        hidden = result[1]\n\n        if isinstance(orig_input, PackedSequence):\n            output_packed = PackedSequence(\n                output, batch_sizes, sorted_indices, unsorted_indices\n            )\n            return output_packed, self.permute_hidden(hidden, unsorted_indices)\n\n        if not is_batched:\n            output = output.squeeze(batch_dim)\n            hidden = hidden.squeeze(1)\n\n        return output, self.permute_hidden(hidden, unsorted_indices)\n\n\nclass LSTM(RNNBase):\n    r\"\"\"\n    Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.\n\n    For each element in the input sequence, each layer computes the following function:\n\n    .. 
math::\n        \\begin{array}{ll} \\\\\n            i_t = \\sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\\\\n            f_t = \\sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\\\\n            g_t = \\tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\\\\n            o_t = \\sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\\\\n            c_t = f_t \\odot c_{t-1} + i_t \\odot g_t \\\\\n            h_t = o_t \\odot \\tanh(c_t) \\\\\n        \\end{array}\n\n    where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell\n    state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`\n    is the hidden state of the layer at time `t-1` or the initial hidden\n    state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`,\n    :math:`o_t` are the input, forget, cell, and output gates, respectively.\n    :math:`\\sigma` is the sigmoid function, and :math:`\\odot` is the Hadamard product.\n\n    In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer\n    (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by\n    dropout :math:`\\delta^{(l-1)}_t` where each :math:`\\delta^{(l-1)}_t` is a Bernoulli random\n    variable which is :math:`0` with probability :attr:`dropout`.\n\n    If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes\n    the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from\n    ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly).\n    Second, the output hidden state of each layer will be multiplied by a learnable projection\n    matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence of this, the output\n    of the LSTM network will be of a different shape as well. See Inputs/Outputs sections below for exact\n    dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/_modules/torch/nn/modules/rnn.html#LSTM.\n\n    Args:\n        input_size: The number of expected features in the input `x`\n        hidden_size: The number of features in the hidden state `h`\n        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``\n            would mean stacking two LSTMs together to form a `stacked LSTM`,\n            with the second LSTM taking in outputs of the first LSTM and\n            computing the final results. Default: 1\n        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.\n            Default: ``True``\n        batch_first: If ``True``, then the input and output tensors are provided\n            as `(batch, seq, feature)` instead of `(seq, batch, feature)`.\n            Note that this does not apply to hidden or cell states. See the\n            Inputs/Outputs sections below for details.  Default: ``False``\n        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each\n            LSTM layer except the last layer, with dropout probability equal to\n            :attr:`dropout`. Default: 0\n        bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``\n        proj_size: If ``> 0``, will use LSTM with projections of corresponding size. 
Default: 0\n\n    Inputs: input, (h_0, c_0)\n        * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or\n          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of\n          the input sequence.\n        * **h_0**: tensor of shape :math:`(D * \\text{num\\_layers}, N, H_{out})` containing the\n          initial hidden state for each element in the batch.\n          Defaults to zeros if (h_0, c_0) is not provided.\n        * **c_0**: tensor of shape :math:`(D * \\text{num\\_layers}, N, H_{cell})` containing the\n          initial cell state for each element in the batch.\n          Defaults to zeros if (h_0, c_0) is not provided.\n\n        where:\n\n        .. math::\n            \\begin{aligned}\n                N ={} & \\text{batch size} \\\\\n                L ={} & \\text{sequence length} \\\\\n                D ={} & 2 \\text{ if bidirectional=True otherwise } 1 \\\\\n                H_{in} ={} & \\text{input\\_size} \\\\\n                H_{cell} ={} & \\text{hidden\\_size} \\\\\n                H_{out} ={} & \\text{proj\\_size if } \\text{proj\\_size}>0 \\text{ otherwise hidden\\_size} \\\\\n            \\end{aligned}\n\n    Outputs: output, (h_n, c_n)\n        * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or\n          :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features\n          `(h_t)` from the last layer of the LSTM, for each `t`.\n        * **h_n**: tensor of shape :math:`(D * \\text{num\\_layers}, N, H_{out})` containing the\n          final hidden state for each element in the batch.\n        * **c_n**: tensor of shape :math:`(D * \\text{num\\_layers}, N, H_{cell})` containing the\n          final cell state for each element in the batch.\n\n    Attributes:\n        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\\text{k}^{th}` layer\n            `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`.\n            Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`\n        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\\text{k}^{th}` layer\n            `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0``\n            was specified, the shape will be `(4*hidden_size, proj_size)`.\n        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\\text{k}^{th}` layer\n            `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)`\n        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\\text{k}^{th}` layer\n            `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)`\n        weight_hr_l[k] : the learnable projection weights of the :math:`\\text{k}^{th}` layer\n            of shape `(proj_size, hidden_size)`. Only present when ``proj_size > 0`` was\n            specified.\n\n    .. note::\n        All the weights and biases are initialized from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})`\n        where :math:`k = \\frac{1}{\\text{hidden\\_size}}`\n\n    .. note::\n        For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively.\n        Example of splitting the output layers when ``batch_first=False``:\n        ``output.view(seq_len, batch, num_directions, hidden_size)``.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> rnn = flow.nn.LSTM(10, 20, 2)\n        >>> input = flow.tensor(np.random.randn(5, 3, 10), dtype=flow.float32)\n        >>> h0 = flow.tensor(np.random.randn(2, 3, 20), dtype=flow.float32)\n        >>> c0 = flow.tensor(np.random.randn(2, 3, 20), dtype=flow.float32)\n        >>> output, (hn, cn) = rnn(input, (h0, c0))\n        >>> output.size()\n        oneflow.Size([5, 3, 20])\n        \n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(\"LSTM\", *args, **kwargs)\n\n    def get_expected_cell_size(\n        self, input: Tensor, batch_sizes: Optional[Tensor]\n    ) -> Tuple[int, int, int]:\n        if batch_sizes is not None:\n            mini_batch = int(batch_sizes[0])\n        else:\n            mini_batch = input.size(0) if self.batch_first else input.size(1)\n        num_directions = 2 if self.bidirectional else 1\n        expected_hidden_size = (\n            self.num_layers * num_directions,\n            mini_batch,\n            self.hidden_size,\n        )\n        return expected_hidden_size\n\n    def check_forward_args(\n        self,\n        input: Tensor,\n        hidden: Tuple[Tensor, Tensor],\n        batch_sizes: Optional[Tensor],\n    ):\n        self.check_input(input, batch_sizes)\n        self.check_hidden_size(\n            hidden[0],\n            self.get_expected_hidden_size(input, batch_sizes),\n            \"Expected hidden[0] size {}, got {}\",\n        )\n        self.check_hidden_size(\n            hidden[1],\n            self.get_expected_cell_size(input, batch_sizes),\n            \"Expected hidden[1] size {}, got {}\",\n        )\n\n    def permute_hidden(\n        self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]\n    ) -> Tuple[Tensor, Tensor]:\n        if permutation is None:\n            return hx\n        return (\n            apply_permutation(hx[0], permutation),\n            apply_permutation(hx[1], permutation),\n        )\n\n    def forward(self, input, hx=None):\n        orig_input = input\n        batch_sizes = None\n        if isinstance(orig_input, PackedSequence):\n            input = orig_input.data\n            batch_sizes = orig_input.batch_sizes\n            sorted_indices = orig_input.sorted_indices\n            unsorted_indices = orig_input.unsorted_indices\n            max_batch_size = int(batch_sizes[0])\n        else:\n            batch_sizes = None\n            is_batched = input.dim() == 3\n            batch_dim = 0 if self.batch_first else 1\n            if not is_batched:\n                input = input.unsqueeze(batch_dim)\n            max_batch_size = input.size(0) if self.batch_first else input.size(1)\n            sorted_indices = None\n            unsorted_indices = None\n\n        if hx is None:\n            num_directions = 2 if self.bidirectional else 1\n            real_hidden_size = (\n                self.proj_size if self.proj_size > 0 else self.hidden_size\n            )\n\n            if input.is_global:\n                h_zeros = flow.zeros(\n                    self.num_layers * num_directions,\n                    max_batch_size,\n                    real_hidden_size,\n                    dtype=input.dtype,\n                    sbp=input.sbp,\n                    placement=input.placement,\n                )\n                c_zeros = flow.zeros(\n                    self.num_layers * num_directions,\n                    max_batch_size,\n                    
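# the cell state keeps hidden_size features even when proj_size > 0;\n                    # only the hidden state h is projected down to real_hidden_size above\n                    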
self.hidden_size,\n                    dtype=input.dtype,\n                    sbp=input.sbp,\n                    placement=input.placement,\n                )\n            else:\n                h_zeros = flow.zeros(\n                    self.num_layers * num_directions,\n                    max_batch_size,\n                    real_hidden_size,\n                    dtype=input.dtype,\n                    device=input.device,\n                )\n                c_zeros = flow.zeros(\n                    self.num_layers * num_directions,\n                    max_batch_size,\n                    self.hidden_size,\n                    dtype=input.dtype,\n                    device=input.device,\n                )\n            hx = (h_zeros, c_zeros)\n        else:\n            if batch_sizes is None:  # If not PackedSequence input.\n                if is_batched:\n                    if hx[0].dim() != 3 or hx[1].dim() != 3:\n                        msg = (\n                            \"For batched 3-D input, hx and cx should \"\n                            f\"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors\"\n                        )\n                        raise RuntimeError(msg)\n                else:\n                    if hx[0].dim() != 2 or hx[1].dim() != 2:\n                        msg = (\n                            \"For unbatched 2-D input, hx and cx should \"\n                            f\"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors\"\n                        )\n                        raise RuntimeError(msg)\n                    hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))\n\n            # Each batch of the hidden state should match the input sequence that\n            # the user believes they are passing in.\n            hx = self.permute_hidden(hx, sorted_indices)\n\n        self.check_forward_args(input, hx, batch_sizes)\n        self._flat_weights = [\n            getattr(self, wn, None) for wn in self._flat_weights_names\n        ]\n        if batch_sizes is None:\n            result = flow._C.lstm(\n                input,\n                hx,\n                self._flat_weights,\n                self.bias,\n                self.num_layers,\n                self.dropout,\n                self.training,\n                self.bidirectional,\n                self.batch_first,\n            )\n        else:\n            result = flow._C.lstm(\n                input,\n                batch_sizes,\n                hx,\n                self._flat_weights,\n                self.bias,\n                self.num_layers,\n                self.dropout,\n                self.training,\n                self.bidirectional,\n            )\n        output = result[0]\n        hidden = result[1:]\n        if isinstance(orig_input, PackedSequence):\n            output_packed = PackedSequence(\n                output, batch_sizes, sorted_indices, unsorted_indices\n            )\n            return output_packed, self.permute_hidden(hidden, unsorted_indices)\n        else:\n            if not is_batched:\n                output = output.squeeze(batch_dim)\n                hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))\n            return output, self.permute_hidden(hidden, unsorted_indices)\n\n\nclass GRU(RNNBase):\n    r\"\"\"\n    Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.\n\n    For each element in the input sequence, each layer computes the following
 function:\n\n    .. math::\n        \\begin{array}{ll}\n            r_t = \\sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\\\\n            z_t = \\sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\\\\n            n_t = \\tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)} + b_{hn})) \\\\\n            h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}\n        \\end{array}\n\n    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input\n    at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer\n    at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,\n    :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.\n    :math:`\\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.\n\n    In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer\n    (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by\n    dropout :math:`\\delta^{(l-1)}_t` where each :math:`\\delta^{(l-1)}_t` is a Bernoulli random\n    variable which is :math:`0` with probability :attr:`dropout`.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/_modules/torch/nn/modules/rnn.html#GRU.\n\n    Args:\n        input_size: The number of expected features in the input `x`\n        hidden_size: The number of features in the hidden state `h`\n        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``\n            would mean stacking two GRUs together to form a `stacked GRU`,\n            with the second GRU taking in outputs of the first GRU and\n            computing the final results. Default: 1\n        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.\n            Default: ``True``\n        batch_first: If ``True``, then the input and output tensors are provided\n            as `(batch, seq, feature)` instead of `(seq, batch, feature)`.\n            Note that this does not apply to hidden or cell states. See the\n            Inputs/Outputs sections below for details.  Default: ``False``\n        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each\n            GRU layer except the last layer, with dropout probability equal to\n            :attr:`dropout`. Default: 0\n        bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``\n\n    Inputs: input, h_0\n        * **input**: tensor of shape :math:`(L, N, H_{in})` when ``batch_first=False`` or\n          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of\n          the input sequence.\n        * **h_0**: tensor of shape :math:`(D * \\text{num\\_layers}, N, H_{out})` containing the initial hidden\n          state for each element in the batch. Defaults to zeros if not provided.\n\n        where:\n\n        .. math::\n            \\begin{aligned}\n                N ={} & \\text{batch size} \\\\\n                L ={} & \\text{sequence length} \\\\\n                D ={} & 2 \\text{ if bidirectional=True otherwise } 1 \\\\\n                H_{in} ={} & \\text{input\\_size} \\\\\n                H_{out} ={} & \\text{hidden\\_size}\n            \\end{aligned}\n\n    Outputs: output, h_n\n        * **output**: tensor of shape :math:`(L, N, D * H_{out})` when ``batch_first=False`` or\n          :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features\n          `(h_t)` from the last layer of the GRU, for each `t`. 
If a :class:`oneflow.nn.utils.rnn.PackedSequence` has been given as the input,\n          the output will also be a packed sequence.\n        * **h_n**: tensor of shape :math:`(D * \\text{num\\_layers}, N, H_{out})` containing the final hidden state\n          for each element in the batch.\n\n    Attributes:\n        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\\text{k}^{th}` layer\n            (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.\n            Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`\n        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\\text{k}^{th}` layer\n            (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`\n        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\\text{k}^{th}` layer\n            (b_ir|b_iz|b_in), of shape `(3*hidden_size)`\n        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\\text{k}^{th}` layer\n            (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`\n\n    .. note::\n        All the weights and biases are initialized from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})`\n        where :math:`k = \\frac{1}{\\text{hidden\\_size}}`\n\n    .. note::\n        For bidirectional GRUs, forward and backward are directions 0 and 1 respectively.\n        Example of splitting the output layers when ``batch_first=False``:\n        ``output.view(seq_len, batch, num_directions, hidden_size)``.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> rnn = flow.nn.GRU(10, 20, 2)\n        >>> input = flow.tensor(np.random.randn(5, 3, 10), dtype=flow.float32)\n        >>> h0 = flow.tensor(np.random.randn(2, 3, 20), dtype=flow.float32)\n        >>> output, hn = rnn(input, h0)\n        >>> output.size()\n        oneflow.Size([5, 3, 20])\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        if \"proj_size\" in kwargs:\n            raise ValueError(\n                \"proj_size argument is only supported for LSTM, not RNN or GRU\"\n            )\n        super().__init__(\"GRU\", *args, **kwargs)\n\n    def forward(self, input, hx=None):\n        orig_input = input\n        if isinstance(orig_input, PackedSequence):\n            input = orig_input.data\n            batch_sizes = orig_input.batch_sizes\n            sorted_indices = orig_input.sorted_indices\n            unsorted_indices = orig_input.unsorted_indices\n            max_batch_size = int(batch_sizes[0])\n        else:\n            batch_sizes = None\n            is_batched = input.dim() == 3\n            batch_dim = 0 if self.batch_first else 1\n            if not is_batched:\n                input = input.unsqueeze(batch_dim)\n                if hx is not None:\n                    if hx.dim() != 2:\n                        raise RuntimeError(\n                            f\"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor\"\n                        )\n                    hx = hx.unsqueeze(1)\n            else:\n                if hx is not None and hx.dim() != 3:\n                    raise RuntimeError(\n                        f\"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor\"\n                    )\n            max_batch_size = input.size(0) if self.batch_first else input.size(1)\n            sorted_indices = None\n            unsorted_indices = None\n\n        if hx is None:\n            num_directions = 2 if self.bidirectional else 1\n            
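# hx defaults to zeros that match the input's dtype and device/placement\n            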
if input.is_global:\n                hx = flow.zeros(\n                    self.num_layers * num_directions,\n                    max_batch_size,\n                    self.hidden_size,\n                    dtype=input.dtype,\n                    sbp=input.sbp,\n                    placement=input.placement,\n                )\n            else:\n                hx = flow.zeros(\n                    self.num_layers * num_directions,\n                    max_batch_size,\n                    self.hidden_size,\n                    dtype=input.dtype,\n                    device=input.device,\n                )\n        else:\n            # Each batch of the hidden state should match the input sequence that\n            # the user believes they are passing in.\n            hx = self.permute_hidden(hx, sorted_indices)\n\n        self.check_forward_args(input, hx, batch_sizes)\n        self._flat_weights = [\n            getattr(self, wn, None) for wn in self._flat_weights_names\n        ]\n        if batch_sizes is None:\n            result = flow._C.gru(\n                input,\n                hx,\n                self._flat_weights,\n                self.bias,\n                self.num_layers,\n                self.dropout,\n                self.training,\n                self.bidirectional,\n                self.batch_first,\n            )\n        else:\n            result = flow._C.gru(\n                input,\n                batch_sizes,\n                hx,\n                self._flat_weights,\n                self.bias,\n                self.num_layers,\n                self.dropout,\n                self.training,\n                self.bidirectional,\n            )\n        output = result[0]\n        hidden = result[1]\n\n        if isinstance(orig_input, PackedSequence):\n            output_packed = PackedSequence(\n                output, batch_sizes, sorted_indices, unsorted_indices\n            )\n            return output_packed, self.permute_hidden(hidden, unsorted_indices)\n        else:\n            if not is_batched:\n                output = output.squeeze(batch_dim)\n                hidden = hidden.squeeze(1)\n\n            return output, self.permute_hidden(hidden, unsorted_indices)\n\n\nclass RNNCellBase(nn.Module):\n    def __init__(\n        self,\n        input_size: int,\n        hidden_size: int,\n        bias: bool,\n        num_chunks: int,\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.input_size = input_size\n        self.hidden_size = hidden_size\n        self.bias = bias\n        self.weight_ih = nn.Parameter(\n            flow.empty(num_chunks * hidden_size, input_size, **factory_kwargs)\n        )\n        self.weight_hh = nn.Parameter(\n            flow.empty(num_chunks * hidden_size, hidden_size, **factory_kwargs)\n        )\n        if bias:\n            self.bias_ih = nn.Parameter(\n                flow.empty(num_chunks * hidden_size, **factory_kwargs)\n            )\n            self.bias_hh = nn.Parameter(\n                flow.empty(num_chunks * hidden_size, **factory_kwargs)\n            )\n        else:\n            self.register_parameter(\"bias_ih\", None)\n            self.register_parameter(\"bias_hh\", None)\n\n        
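# initialize all weights and biases uniformly in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)]\n        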
self.reset_parameters()\n\n    def extra_repr(self) -> str:\n        s = \"{input_size}, {hidden_size}\"\n        if \"bias\" in self.__dict__ and self.bias is not True:\n            s += \", bias={bias}\"\n        if \"nonlinearity\" in self.__dict__ and self.nonlinearity != \"tanh\":\n            s += \", nonlinearity={nonlinearity}\"\n        return s.format(**self.__dict__)\n\n    def reset_parameters(self) -> None:\n        stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0\n        for weight in self.parameters():\n            nn.init.uniform_(weight, -stdv, stdv)\n\n\nclass RNNCell(RNNCellBase):\n    r\"\"\"\n    An Elman RNN cell with tanh or ReLU non-linearity.\n\n    .. math::\n\n        h' = \\tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})\n\n    If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.RNNCell.html.\n\n    Args:\n        input_size: The number of expected features in the input `x`\n        hidden_size: The number of features in the hidden state `h`\n        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.\n            Default: ``True``\n        nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``\n\n    Inputs: input, hidden\n        - **input**: tensor containing input features\n        - **hidden**: tensor containing the initial hidden state.\n          Defaults to zero if not provided.\n\n    Outputs: h'\n        - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state\n          for each element in the batch\n\n    Shape:\n        - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where\n          :math:`H_{in}` = `input_size`.\n        - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden\n          state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.\n        - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.\n\n    Attributes:\n        weight_ih: the learnable input-hidden weights, of shape\n            `(hidden_size, input_size)`\n        weight_hh: the learnable hidden-hidden weights, of shape\n            `(hidden_size, hidden_size)`\n        bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`\n        bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)`\n\n    .. note::\n        All the weights and biases are initialized from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})`\n        where :math:`k = \\frac{1}{\\text{hidden\\_size}}`\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n\n        >>> rnn = nn.RNNCell(10, 20)\n        >>> input = flow.randn(6, 3, 10)\n        >>> hx = flow.randn(3, 20)\n        >>> hx = rnn(input[0], hx)\n        >>> hx.size()\n        oneflow.Size([3, 20])\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        hidden_size: int,\n        bias: bool = True,\n        nonlinearity: str = \"tanh\",\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super(RNNCell, self).__init__(\n            input_size, hidden_size, bias, num_chunks=1, **factory_kwargs\n        )\n        self.nonlinearity = nonlinearity\n\n    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:\n        assert input.dim() in (\n            1,\n            2,\n        ), f\"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor\"\n        is_batched = input.dim() == 2\n        if not is_batched:\n            input = input.unsqueeze(0)\n\n        if hx is None:\n            if input.is_global:\n                hx = flow.zeros(\n                    input.size(0),\n                    self.hidden_size,\n                    dtype=input.dtype,\n                    sbp=input.sbp,\n                    placement=input.placement,\n                )\n            else:\n                hx = flow.zeros(\n                    input.size(0),\n                    self.hidden_size,\n                    dtype=input.dtype,\n                    device=input.device,\n                )\n        else:\n            hx = hx.unsqueeze(0) if not is_batched else hx\n\n        if self.nonlinearity == \"tanh\":\n            ret = flow._C.rnn_tanh_cell(\n                input, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh,\n            )\n        elif self.nonlinearity == \"relu\":\n            ret = flow._C.rnn_relu_cell(\n                input, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh,\n            )\n        else:\n            raise RuntimeError(\"Unknown nonlinearity: {}\".format(self.nonlinearity))\n\n        if not is_batched:\n            ret = ret.squeeze(0)\n\n        return ret\n\n\nclass LSTMCell(RNNCellBase):\n    r\"\"\"\n    A long short-term memory (LSTM) cell.\n\n    .. math::\n\n        \\begin{array}{ll}\n        i = \\sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\\\\n        f = \\sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\\\\n        g = \\tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\\\\n        o = \\sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\\\\n        c' = f * c + i * g \\\\\n        h' = o * \\tanh(c') \\\\\n        \\end{array}\n\n    where :math:`\\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.LSTMCell.html.\n\n    Args:\n        input_size: The number of expected features in the input `x`\n        hidden_size: The number of features in the hidden state `h`\n        bias: If ``False``, then the layer does not use bias weights `b_ih` and\n            `b_hh`. 
Default: ``True``\n\n    Inputs: input, (h_0, c_0)\n        - **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features\n        - **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state\n        - **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state\n\n          If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.\n\n    Outputs: (h_1, c_1)\n        - **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state\n        - **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state\n\n    Attributes:\n        weight_ih: the learnable input-hidden weights, of shape\n            `(4*hidden_size, input_size)`\n        weight_hh: the learnable hidden-hidden weights, of shape\n            `(4*hidden_size, hidden_size)`\n        bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)`\n        bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)`\n\n    .. note::\n        All the weights and biases are initialized from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})`\n        where :math:`k = \\frac{1}{\\text{hidden\\_size}}`\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n\n        >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)\n        >>> input = flow.randn(2, 3, 10) # (time_steps, batch, input_size)\n        >>> hx = flow.randn(3, 20) # (batch, hidden_size)\n        >>> cx = flow.randn(3, 20)\n        >>> hx, cx = rnn(input[0], (hx, cx))\n        >>> hx.size()\n        oneflow.Size([3, 20])\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        hidden_size: int,\n        bias: bool = True,\n        device=None,\n        dtype=None,\n    ) -> None:\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super(LSTMCell, self).__init__(\n            input_size, hidden_size, bias, num_chunks=4, **factory_kwargs\n        )\n\n    def forward(\n        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None\n    ) -> Tuple[Tensor, Tensor]:\n        assert input.dim() in (\n            1,\n            2,\n        ), f\"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor\"\n        is_batched = input.dim() == 2\n        if not is_batched:\n            input = input.unsqueeze(0)\n\n        if hx is None:\n            if input.is_global:\n                zeros = flow.zeros(\n                    input.size(0),\n                    self.hidden_size,\n                    dtype=input.dtype,\n                    sbp=input.sbp,\n                    placement=input.placement,\n                )\n            else:\n                zeros = flow.zeros(\n                    input.size(0),\n                    self.hidden_size,\n                    dtype=input.dtype,\n                    device=input.device,\n                )\n            hx = (zeros, zeros)\n        else:\n            hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx\n\n        ret = flow._C.lstm_cell(\n            input, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh,\n        )\n\n        if not is_batched:\n            ret = (ret[0].squeeze(0), ret[1].squeeze(0))\n        return ret\n\n\nclass GRUCell(RNNCellBase):\n    r\"\"\"\n    A gated recurrent unit (GRU) cell.\n\n    .. 
math::\n\n        \\begin{array}{ll}\n        r = \\sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\\\\n        z = \\sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\\\\n        n = \\tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\\\\n        h' = (1 - z) * n + z * h\n        \\end{array}\n\n    where :math:`\\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.GRUCell.html.\n\n    Args:\n        input_size: The number of expected features in the input `x`\n        hidden_size: The number of features in the hidden state `h`\n        bias: If ``False``, then the layer does not use bias weights `b_ih` and\n            `b_hh`. Default: ``True``\n\n    Inputs: input, hidden\n        - **input** : tensor containing input features\n        - **hidden** : tensor containing the initial hidden\n          state for each element in the batch.\n          Defaults to zero if not provided.\n\n    Outputs: h'\n        - **h'** : tensor containing the next hidden state\n          for each element in the batch\n\n    Shape:\n        - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where\n          :math:`H_{in}` = `input_size`.\n        - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden\n          state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.\n        - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.\n\n    Attributes:\n        weight_ih: the learnable input-hidden weights, of shape\n            `(3*hidden_size, input_size)`\n        weight_hh: the learnable hidden-hidden weights, of shape\n            `(3*hidden_size, hidden_size)`\n        bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)`\n        bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)`\n\n    .. note::\n        All the weights and biases are initialized from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})`\n        where :math:`k = \\frac{1}{\\text{hidden\\_size}}`\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n\n        >>> rnn = nn.GRUCell(10, 20)\n        >>> input = flow.randn(6, 3, 10)\n        >>> hx = flow.randn(3, 20)\n        >>> hx = rnn(input[0], hx)\n        >>> hx.size()\n        oneflow.Size([3, 20])\n\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        hidden_size: int,\n        bias: bool = True,\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)\n\n    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:\n        assert input.dim() in (\n            1,\n            2,\n        ), f\"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor\"\n        is_batched = input.dim() == 2\n        if not is_batched:\n            input = input.unsqueeze(0)\n\n        if hx is None:\n            if input.is_global:\n                hx = flow.zeros(\n                    input.size(0),\n                    self.hidden_size,\n                    dtype=input.dtype,\n                    sbp=input.sbp,\n                    placement=input.placement,\n                )\n            else:\n                hx = flow.zeros(\n                    input.size(0),\n                    self.hidden_size,\n                    dtype=input.dtype,\n                    device=input.device,\n                )\n        else:\n            hx = hx.unsqueeze(0) if not is_batched else hx\n\n        ret = flow._C.gru_cell(\n            input, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh,\n        )\n\n        if not is_batched:\n            ret = ret.squeeze(0)\n\n        return ret\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/roll.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.framework.tensor import register_tensor_op\n\n\ndef roll_op(input, shifts, dims=None):\n    \"\"\"Roll the tensor along the given dimension(s). \n    \n    Elements that are shifted beyond the last position are re-introduced at the first position. \n    \n    If a dimension is not specified, the tensor will be flattened before rolling and then restored to the original shape.\n\n    Args:\n        input (oneflow.Tensor): the input Tensor.\n        shifts (int or tuple of ints): The number of places by which the elements of the tensor are shifted. \n                                              If shifts is a tuple, dims must be a tuple of the same size, \n                                              and each dimension will be rolled by the corresponding value.\n        dims (int or tuple of ints): Axis along which to roll.\n\n    Returns:\n        oneflow.Tensor: The result Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x = np.array([[1, 2],\n        ...               [3, 4],\n        ...               [5, 6],\n        ...               [7, 8]])\n        >>> input = flow.Tensor(x)\n        >>> input.shape\n        oneflow.Size([4, 2])\n        >>> out = flow.roll(input, 1, 0)\n        >>> out\n        tensor([[7., 8.],\n                [1., 2.],\n                [3., 4.],\n                [5., 6.]], dtype=oneflow.float32)\n        >>> input.roll(-1, 1)\n        tensor([[2., 1.],\n                [4., 3.],\n                [6., 5.],\n                [8., 7.]], dtype=oneflow.float32)\n    \"\"\"\n    return flow._C.roll(input, shifts, dims)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/scatter.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.nn.modules.module import Module\n\n__all__ = [\"scatter\", \"scatter_add\", \"scatter_nd\", \"tensor_scatter_nd_update\"]\n\n\ndef scatter(input, dim, index, src, *, reduce=None):\n    r\"\"\"This operator writes the elements specified by `index` along with the axis \n    `dim` from the `src` into the `input`.\n\n    Take a 3-D blob as example, the output is specified by:\n    \n    .. code-block:: python\n\n        input[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0\n        input[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1\n        input[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2\n\n    input, index and src (if it is a Tensor) should all have the same number of dimensions. \n    It is also required that index.shape(d) <= src.shape(d) for all dimensions d, \n    and that index.shape(d) <= input.shape(d) for all dimensions d != dim.\n    Note that index and src do not broadcast.\n\n    .. warning::\n        When indices are not unique, the behavior is non-deterministic (one of the values from src will be picked arbitrarily) \n        and the gradient will be incorrect (it will be propagated to all locations in the source that correspond to the same index)!\n    \n    .. note::\n        The backward pass is implemented only for ``src.shape == index.shape``.\n    \n    Additionally accepts an optional ``reduce`` argument that allows specification of an optional reduction operation, \n    which is applied to all values in the tensor ``src`` into ``input`` at the indicies specified in the ``index``. \n    For each value in ``src``, the reduction operation is applied to an index in ``input`` which is specified by its index in ``src`` for ``dimension != dim`` \n    and by the corresponding value in ``index`` for ``dimension = dim``.\n\n    Given a 3-D tensor and reduction using the multiplication operation, input is updated as:\n\n    .. code-block:: python\n\n        input[index[i][j][k]][j][k] *= src[i][j][k]  # if dim == 0\n        input[i][index[i][j][k]][k] *= src[i][j][k]  # if dim == 1\n        input[i][j][index[i][j][k]] *= src[i][j][k]  # if dim == 2\n\n    Reducing with the addition operation is the same as using :func:`oneflow.scatter_add()`.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.Tensor.scatter\\_.html.\n\n    Args:\n        input (Tensor): The input blob.\n        dim (int): The axis along which to index\n        index (Tensor): The index blob of elements to scatter. \n        src (Tensor or float): The source blob whose elements will be scatterd and updated to output.\n        reduce (str, optional): Reduction operation to apply, can be either ``add`` or ``multiply``.\n\n    Returns:\n        Tensor: The scatterd Tensor. \n\n    For example: \n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> input = flow.ones((3,5))*2\n        >>> index = flow.tensor(np.array([[0,1,2],[0,1,4]], ), dtype=flow.int32)\n        >>> src = flow.Tensor(np.array([[0,10,20,30,40],[50,60,70,80,90]]))\n        >>> out = flow.scatter(input, 1, index, src)\n        >>> out\n        tensor([[ 0., 10., 20.,  2.,  2.],\n                [50., 60.,  2.,  2., 70.],\n                [ 2.,  2.,  2.,  2.,  2.]], dtype=oneflow.float32)\n\n    \"\"\"\n    return flow._C.scatter(input, dim, index, src, reduce=reduce)\n\n\ndef scatter_add(input, dim, index, src):\n    r\"\"\"This operator scatters `src`, using the addition operation, into `input` according to `index` along the axis `dim`.\n\n    Take a 3-D blob as an example, the output is specified by:\n\n    .. code-block:: python\n\n        input[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0\n        input[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1\n        input[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2\n\n    Args:\n        input (Tensor): The input blob.\n        dim (int): The axis along which to index.\n        index (Tensor): The index blob of elements to scatter.\n        src (Tensor): The source blob whose elements will be scattered and added to the output.\n\n    Returns:\n        Tensor: The scattered Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> input = flow.ones((3,5))*2\n        >>> index = flow.tensor(np.array([[0,1,2],[0,1,4]], ), dtype=flow.int32)\n        >>> src = flow.Tensor(np.array([[0,10,20,30,40],[50,60,70,80,90]]))\n        >>> out = flow.scatter_add(input, 1, index, src)\n        >>> out\n        tensor([[ 2., 12., 22.,  2.,  2.],\n                [52., 62.,  2.,  2., 72.],\n                [ 2.,  2.,  2.,  2.,  2.]], dtype=oneflow.float32)\n\n    \"\"\"\n\n    assert type(src) is flow.Tensor, f\"type of src must be oneflow.Tensor, but {type(src)} given\"\n\n    return flow._C.scatter_add(input, dim, index, src)\n\n\ndef scatter_nd(index, update, shape):\n    \"\"\"This operator inserts the elements in `update` according to the `index` and creates a new Tensor.\n\n    Args:\n        index: The indices of `update`. Its type should be `flow.int`.\n        update: The update Tensor.\n        shape (Sequence[int]): The shape of the output tensor; positions not covered by `index` are filled with zeros.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> index = flow.tensor(np.array([[1], [6], [4]]), dtype=flow.int)\n        >>> update = flow.tensor(np.array([10.2, 5.1, 12.7]), dtype=flow.float)\n        >>> out = flow.scatter_nd(index, update, [8])\n        >>> out\n        tensor([ 0.0000, 10.2000,  0.0000,  0.0000, 12.7000,  0.0000,  5.1000,  0.0000],\n               dtype=oneflow.float32)\n\n    \"\"\"\n    return flow._C.scatternd(index, update, shape)\n\n\ndef tensor_scatter_nd_update(tensor, indices, updates):\n    r\"\"\"\n    This operation creates a new tensor by applying sparse updates to the input tensor.\n    This is similar to an index assignment.\n\n    This operator is very similar to :meth:`scatter_nd`, except that the updates are scattered onto an existing\n    tensor (as opposed to a zero-tensor).\n\n    Args:\n        tensor: The tensor to be updated.\n        indices: The indices of ``updates``. 
Its type should be `flow.int`.\n        updates: The update Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> tensor = flow.arange(8)\n        >>> indices = flow.tensor([[1], [3], [5]])\n        >>> updates = flow.tensor([-1, -2, -3])\n        >>> flow.tensor_scatter_nd_update(tensor, indices, updates)\n        tensor([ 0, -1,  2, -2,  4, -3,  6,  7], dtype=oneflow.int64)\n\n    \"\"\"\n    return flow._C.tensor_scatter_nd_update(tensor, indices, updates)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/slice.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Sequence, Tuple\n\nimport oneflow as flow\nfrom oneflow.ops.array_ops import parse_slice_tuple_list\n\n\ndef slice_op(input, slice_tup_list: Sequence[Tuple[int, int, int]]):\n    \"\"\"Extracts a slice from a tensor.\n    The `slice_tup_list` assigns the slice indices in each dimension, the format is (start, stop, step).\n    The operator will slice the tensor according to the `slice_tup_list`.\n\n    Args:\n        input: A `Tensor`.\n        slice_tup_list: A list of slice tuple, indicate each dimension slice (start, stop, step).\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> input = flow.Tensor(np.random.randn(3, 6, 9).astype(np.float32))\n        >>> tup_list = [[None, None, None], [0, 5, 2], [0, 6, 3]]\n        >>> y = flow.slice(input, slice_tup_list=tup_list)\n        >>> y.shape\n        oneflow.Size([3, 3, 2])\n    \"\"\"\n    (start, stop, step) = parse_slice_tuple_list(slice_tup_list, input.shape)\n    return flow._C.slice(input, start, stop, step)\n\n\ndef slice_update_op(input, update, slice_tup_list: Sequence[Tuple[int, int, int]]):\n    \"\"\"Update a slice of tensor `x`. Like `x[start:stop:step] = update`.\n\n    Args:\n        x: A `Tensor`, whose slice will be updated.\n        update: A `Tensor`, indicate the update content.\n        slice_tup_list: A list of slice tuple, indicate each dimension slice (start, stop, step).\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> input = flow.Tensor(np.array([1, 1, 1, 1, 1]).astype(np.float32))\n        >>> update = flow.Tensor(np.array([2, 3, 4]).astype(np.float32))\n        >>> flow.slice_update(input, update, slice_tup_list=[[1, 4, 1]])\n        tensor([1., 2., 3., 4., 1.], dtype=oneflow.float32)\n\n    \"\"\"\n\n    (start, stop, step) = parse_slice_tuple_list(slice_tup_list, input.shape)\n    if update.dtype != input.dtype:\n        update = update.to(dtype=input.dtype)\n    return flow._C.slice_update(input, update, start, stop, step, inplace=True)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/sparse.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nfrom typing import List, Optional, Tuple\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.nn.modules.module import Module\n\n\nclass Embedding(Module):\n    \"\"\"A simple lookup table that stores embeddings of a fixed dictionary and size.\n\n    This module is often used to store word embeddings and retrieve them using indices.\n    The input to the module is a list of indices, and the output is the corresponding\n    word embeddings.\n\n    Args:\n        num_embeddings (int): size of the dictionary of embeddings\n        embedding_dim (int): the size of each embedding vector\n        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;\n                                    therefore, the embedding vector at :attr:`padding_idx` is not updated during training,\n                                    i.e. it remains as a fixed \"pad\". For a newly constructed Embedding,\n                                    the embedding vector at :attr:`padding_idx` will default to all zeros,\n                                    but can be updated to another value to be used as the padding vector.\n        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` is renormalized to have \n                                    norm :attr:`max_norm`\n        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default :attr:`2`.\n        scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of \n                                                frequency of the words in the mini-batch. Default :attr:`False`\n\n    For example:\n\n    .. code-block:: python\n        \n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> indices = flow.tensor([[1, 2, 4, 5], [4, 3, 2, 9]], dtype=flow.int)\n        >>> m = flow.nn.Embedding(10, 3)\n        >>> y = m(indices)\n        \n    ..\n        Feature Stage of Operator [Embedding].\n        - Maintainer List [@EsdeathYZH]\n        - Current Stage [ ]\n        - Alpha Stage Check List [ ]\n          - API(Compatible with PyTorch 1.11, anything incompatible must be noted in API Doc.)[Yes]\n          - Doc(API Doc must be provided and showed normally on the web page.)[Yes]\n          - Functionality and its' Test [ ]\n            - Functionality is highly compatiable with PyTorch 1.11. 
[Yes]\n            - eager local [Yes] [@EsdeathYZH]\n              - forward [Yes]\n              - backward [Yes]\n              - gpu [Yes]\n              - cpu [Yes]\n            - graph local [ ] [@BBuf, @strint, @hjchen2]\n              - forward [Yes]\n              - backward [ ]\n              - gpu [Yes]\n              - cpu [Yes]\n          - Exception Handling\n            - Exception Message and Hint must be provided [ ]\n        - Beta Stage Check List [ ]\n          - API(High compatibility with PyTorch 1.11, shouldn't have anything incompatible for a naive reason.)[ ]\n          - Doc(Same standard as Alpha Stage)[ ]\n          - Functionality and its Test [ ]\n            - eager global [ ]\n              - forward [ ]\n              - backward [ ]\n              - gpu [ ]\n              - cpu [ ]\n            - graph global [ ]\n              - forward [ ]\n              - backward [ ]\n              - gpu [ ]\n              - cpu [ ]\n          - Performance and Scalability(Must be evaluated.)[ ]\n            - CUDA kernel [ ]\n            - CPU kernel [ ]\n            - N nodes M devices [ ]\n          - Exception Handling [ ]\n            - Exception Message and Hint must be provided [ ]\n            - Try your best to do Exception Recovery [ ]\n        - Stable Stage Check List [ ]\n          - API(Same standard as Beta Stage)[ ]\n          - Doc(Same standard as Beta Stage)[ ]\n          - Functionality and its Test [ ]\n            - fp16 and AMP [ ]\n            - NHWC [ ]\n          - Performance and Scalability(Must be evaluated.)[ ]\n          - Exception Handling [ ]\n\n    \"\"\"\n\n    def __init__(\n        self,\n        num_embeddings: int,\n        embedding_dim: int,\n        padding_idx: Optional[int] = None,\n        max_norm: Optional[float] = None,\n        norm_type: float = 2.0,\n        scale_grad_by_freq: bool = False,\n        sparse: bool = False,\n        _weight: Optional[Tensor] = None,\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.num_embeddings = num_embeddings\n        self.embedding_dim = embedding_dim\n        if padding_idx is not None:\n            if padding_idx > 0:\n                assert (\n                    padding_idx < self.num_embeddings\n                ), \"Padding_idx must be within num_embeddings\"\n            elif padding_idx < 0:\n                assert (\n                    padding_idx >= -self.num_embeddings\n                ), \"Padding_idx must be within num_embeddings\"\n                padding_idx = self.num_embeddings + padding_idx\n        self.padding_idx = padding_idx\n        self.max_norm = max_norm\n        self.norm_type = norm_type\n        self.scale_grad_by_freq = scale_grad_by_freq\n        assert sparse is False, \"sparse=True is not supported yet!\"\n        if _weight is None:\n            self.weight = flow.nn.Parameter(\n                flow.empty((num_embeddings, embedding_dim), **factory_kwargs)\n            )\n            self.reset_parameters()\n        else:\n            assert list(_weight.shape) == [\n                num_embeddings,\n                embedding_dim,\n            ], \"Shape of weight does not match num_embeddings and embedding_dim\"\n            self.weight = flow.nn.Parameter(_weight)\n        self.sparse = sparse\n\n    def reset_parameters(self) -> None:\n        if os.getenv(\"ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT\", \"0\") == \"1\":\n            return\n        
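# default initialization: weight ~ N(0, 1); the padding_idx row (if any) is zeroed below\n        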
flow.nn.init.normal_(self.weight)\n        self._fill_padding_idx_with_zero()\n\n    def _fill_padding_idx_with_zero(self) -> None:\n        if self.padding_idx is not None:\n            with flow.no_grad():\n                self.weight[self.padding_idx] = 0\n\n    def extra_repr(self) -> str:\n        s = \"{num_embeddings}, {embedding_dim}\"\n        if self.padding_idx is not None:\n            s += \", padding_idx={padding_idx}\"\n        if self.max_norm is not None:\n            s += \", max_norm={max_norm}\"\n        if self.norm_type != 2:\n            s += \", norm_type={norm_type}\"\n        if self.scale_grad_by_freq is not False:\n            s += \", scale_grad_by_freq={scale_grad_by_freq}\"\n        if self.sparse is not False:\n            s += \", sparse=True\"\n        return s.format(**self.__dict__)\n\n    def forward(self, indices):\n        if self.max_norm is not None:\n            with flow.no_grad():\n                flow._C.embedding_renorm_(\n                    self.weight, indices, self.max_norm, self.norm_type\n                )\n        if self.padding_idx is None and not self.scale_grad_by_freq:\n            return flow._C.gather(self.weight, indices, axis=0)\n        else:\n            return flow._C.embedding(\n                self.weight, indices, self.padding_idx, self.scale_grad_by_freq\n            )\n\n\ndef embedding(\n    input,\n    weight,\n    padding_idx=None,\n    max_norm=None,\n    norm_type=2.0,\n    scale_grad_by_freq=False,\n    sparse=False,\n):\n    r\"\"\"A simple lookup table that looks up embeddings in a fixed dictionary and size.\n\n    This module is often used to retrieve word embeddings using indices.\n    The input to the module is a list of indices, and the embedding matrix,\n    and the output is the corresponding word embeddings.\n\n    See :class:`oneflow.nn.Embedding` for more details.\n\n    Args:\n        input (oneflow.LongTensor): Tensor containing indices into the embedding matrix\n        weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1,\n            and number of columns equal to the embedding size\n        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;\n                                     therefore, the embedding vector at :attr:`padding_idx` is not updated during training,\n                                     i.e. it remains as a fixed \"pad\".\n        max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have \n                                    norm max_norm\n        norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.\n        scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of \n                                                frequency of the words in the mini-batch. Default False\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import oneflow.nn.functional as F\n\n        >>> # a batch of 2 samples of 4 indices each\n        >>> input = flow.tensor([[1,2,4,5],[4,3,2,9]])\n        >>> # an embedding matrix containing 10 tensors of size 3\n        >>> embedding_matrix = flow.rand(10, 3)\n        >>> output = F.embedding(input, embedding_matrix)\n        >>> output.shape\n        oneflow.Size([2, 4, 3])\n        >>> # example with padding_idx\n        >>> input = flow.tensor([[0,2,0,5]])\n        >>> output = F.embedding(input, embedding_matrix, padding_idx=0)\n        >>> output.shape\n        oneflow.Size([1, 4, 3])\n    \"\"\"\n\n    assert sparse is False, \"sparse=True is not supported yet!\"\n    if padding_idx is not None:\n        if padding_idx > 0:\n            assert padding_idx < weight.size(\n                0\n            ), \"Padding_idx must be within num_embeddings\"\n        elif padding_idx < 0:\n            assert padding_idx >= -weight.size(\n                0\n            ), \"Padding_idx must be within num_embeddings\"\n            padding_idx = weight.size(0) + padding_idx\n\n    if max_norm is not None:\n        with flow.no_grad():\n            weight = flow._C.embedding_renorm_(weight, input, max_norm, norm_type)\n\n    if padding_idx is None and not scale_grad_by_freq:\n        return flow._C.gather(weight, input, axis=0)\n    else:\n        return flow._C.embedding(weight, input, padding_idx, scale_grad_by_freq)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/sparse_softmax_cross_entropy.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\n\n\ndef sparse_softmax_cross_entropy(labels, logits):\n    \"\"\"The interface is consistent with TensorFlow.    \n    The documentation is referenced from: \n    https://www.tensorflow.org/api_docs/python/tf/nn/sparse_softmax_cross_entropy_with_logits\n    \n    Computes sparse softmax cross entropy between `logits` and `labels`.\n\n    Measures the probability error in discrete classification tasks in which the\n    classes are mutually exclusive (each entry is in exactly one class).  For\n    example, each CIFAR-10 image is labeled with one and only one label: an image\n    can be a dog or a truck, but not both.\n\n    A common use case is to have logits of shape\n    `[batch_size, num_classes]` and have labels of shape\n    `[batch_size]`, but higher dimensions are supported, in which\n    case the `dim`-th dimension is assumed to be of size `num_classes`.\n    `logits` must have the dtype of `float16`, `float32`, or `float64`, and\n    `labels` must have the dtype of `int32` or `int64`.\n\n    Args:\n        labels (Tensor): shape with [d_0, d_1, ..., d_{r-1}] (where `r` is rank of\n            `labels` and output) and dtype `int32` or `int64`. Each entry in `labels`\n            must be an index in [0, num_classes).\n        logits (Tensor): Per-label activations (typically a linear output) of shape\n            [d_0, d_1, ..., d_{r-1}, num_classes] and dtype `float16`, `float32`, or\n            `float64`. These activation energies are interpreted as unnormalized log\n            probabilities.\n\n    Returns:\n        output (Tensor): A `Tensor` of the same shape as `labels` and of the same type as `logits`\n        with the softmax cross entropy loss.\n\n    Examples::\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> np_logits = np.array(\n        ...      [\n        ...          [2.0, -5.0, 0.5, -0.1],\n        ...          [0.0, 0.0, 1.9, 1.4],\n        ...          [-100.0, 100.0, -100.0, -100.0],\n        ...      ]\n        ...  )\n        >>> np_labels = np.array([0, 3, 1])\n        >>> logits = flow.tensor(np_logits, dtype=flow.float32)\n        >>> labels = flow.tensor(np_labels, dtype=flow.int32)\n        >>> output = flow.nn.functional.sparse_softmax_cross_entropy(\n        ...     labels=labels, logits=logits\n        ... )\n        >>> output\n        tensor([ 2.9751e-01,  1.1448e+00, -1.4305e-06], dtype=oneflow.float32)\n    \"\"\"\n    return flow._C.sparse_softmax_cross_entropy(logits, labels)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/tensor_buffer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Optional, Sequence\n\nimport oneflow as flow\n\n\ndef tensor_buffer_to_tensor_op(x, dtype: flow.dtype, instance_shape: Sequence[int]):\n    \"\"\"This operator converts the Tensor's type from TensorBuffer to original type.\n    Some operator's output data type is `TensorBuffer`, you can use this operator to convert back\n    to `Tensor`.\n\n    Refer to `Concept Explanation <https://docs.oneflow.org/basics_topics/concept_explanation.html#3tensorbuffer-tensorlist>`_\n    for more about TensorBuffer.\n\n    Args:\n        x (oneflow.Tensor): The input Tensor.\n        dtype (flow.dtype): The data dtype.\n        instance_shape (Sequence[int]): The shape of each TensorBuffer instance.\n\n    Returns:\n        oneflow.Tensor: The result Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> x = np.random.randn(4, 16, 64, 64).astype(np.float32)\n        >>> x = flow.Tensor(x)\n        >>> x = flow.tensor_to_tensor_buffer(x, instance_dims=2)\n        >>> output = flow.tensor_buffer_to_tensor(x, instance_shape=(64, 64), dtype=flow.float)\n        >>> output.shape\n        oneflow.Size([4, 16, 64, 64])\n\n    \"\"\"\n    return flow._C.tensor_buffer_to_tensor(\n        x, dtype=dtype, instance_shape=instance_shape\n    )\n\n\ndef tensor_to_tensor_buffer(x, instance_dims: int):\n    \"\"\"This operator converts the Tensor's type to TensorBuffer.\n\n    Refer to `Concept Explanation <https://docs.oneflow.org/basics_topics/concept_explanation.html#3tensorbuffer-tensorlist>`_\n    for more about TensorBuffer.\n\n    Args:\n        x (oneflow.Tensor): The input Tensor.\n        instance_dims (int): The dimensions of dynamic tensor instance.\n\n    Returns:\n        oneflow.Tensor: The result Tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> x = np.random.randn(4, 16, 64, 64).astype(np.float32)\n        >>> x = flow.Tensor(x)\n        >>> x = flow.tensor_to_tensor_buffer(x, instance_dims=2)\n        >>> output = flow.tensor_buffer_to_tensor(x, instance_shape=(64, 64), dtype=flow.float)\n        >>> output.shape\n        oneflow.Size([4, 16, 64, 64])\n    \n    \"\"\"\n    return flow._C.tensor_to_tensor_buffer(x, instance_dims)\n\n\ndef gen_tensor_buffer(\n    shape: Sequence[int],\n    shape_list: Sequence[Sequence[int]],\n    value_list: Sequence[float],\n    data_type: Optional[flow.dtype] = flow.float32,\n    dynamic_out: Optional[bool] = False,\n):\n    return flow._C.gen_tensor_buffer(\n        shape, shape_list, value_list, data_type, dynamic_out\n    )\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/tensordot.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom typing import Union, List, Tuple\nimport warnings\n\n\ndef tensordot(\n    a,\n    b,\n    dims: Union[oneflow.Tensor, int, List[List[int]], Tuple[List[int]]] = 2,\n    out=None,\n):\n    if out is not None:\n        raise NotImplementedError(\n            \"tensordot with `out` parameter which is not None is not yet implemented\"\n        )\n    if not isinstance(dims, (oneflow.Tensor, int, list, tuple)):\n        raise TypeError(\n            f\"oneflow.tensordot expects dims to be one of oneflow.Tensor, int, Tuple[List[int], List[int]] or List[List[int]] containing two lists, but got {type(dims)}\"\n        )\n\n    if isinstance(dims, int):\n        return oneflow._C.tensordot(a, b, dims)\n    elif isinstance(dims, (list, tuple)):\n        assert (\n            len(dims) == 2\n        ), f\"The list/tuple of dims must contain two lists, got {len(dims)}\"\n        dim_a = list(dims[0])\n        dim_b = list(dims[1])\n    elif isinstance(dims, oneflow.Tensor):\n        warnings.warn(\n            \"tensordot doesn't support nn.Graph when the type of `dims` is oneflow.Tensor, because it needs synchronization.\"\n        )\n        if dims.numel() == 1:\n            return oneflow._C.tensordot(a, b, dims.item())\n        assert (\n            dims.dim() == 2\n        ), f\"The dims tensor must have two dimensions, got {dims.dim()}\"\n        assert (\n            len(dims) == 2 and dims.dim() == 2\n        ), f\"The dims tensor must have two rows, got {len(dims)}\"\n        dim_a = dims[0].tolist()\n        dim_b = dims[1].tolist()\n\n    return oneflow._C.tensordot(a, b, dim_a, dim_b)\n"
  },
  {
    "path": "python/oneflow/nn/modules/trigonometric_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.nn.modules.module import Module\nfrom oneflow.framework.tensor import register_tensor_op\n\n\ndef sign_op(input):\n    \"\"\"Computes the sign of Tensor.\n\n    .. math::\n\n        \\\\text{out}_{i}  = \\\\text{sgn}(\\\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x1 = flow.Tensor(np.array([-2, 0, 2]).astype(np.float32))\n        >>> out1 = flow.sign(x1)\n        >>> out1.numpy()\n        array([-1.,  0.,  1.], dtype=float32)\n        >>> x2 = flow.Tensor(np.array([-3.2, -4.5, 5.8]).astype(np.float32),device=flow.device('cuda'))\n        >>> out2 = flow.sign(x2)\n        >>> out2.numpy()\n        array([-1., -1.,  1.], dtype=float32)\n\n    \"\"\"\n    return flow._C.sign(input)\n\n\ndef sinh_op(input):\n    \"\"\"Returns a new tensor with the hyperbolic sine of the elements of :attr:`input`.\n\n    .. math::\n        \\\\text{out}_{i} = \\\\sinh(\\\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> x1 = flow.Tensor(np.array([1, 2, 3]))\n        >>> x2 = flow.Tensor(np.array([1.53123589,0.54242598,0.15117185]))\n        >>> x3 = flow.Tensor(np.array([1,0,-1]))\n\n        >>> flow.sinh(x1).numpy()\n        array([ 1.1752012,  3.6268604, 10.017875 ], dtype=float32)\n        >>> flow.sinh(x2).numpy()\n        array([2.20381  , 0.5694193, 0.1517483], dtype=float32)\n        >>> flow.sinh(x3).numpy()\n        array([ 1.1752012,  0.       , -1.1752012], dtype=float32)\n\n    \"\"\"\n    return flow._C.sinh(input)\n\n\ndef tan_op(input):\n    \"\"\"Returns  the tan value of the elements of :attr:`input`.\n\n    .. math::\n        \\\\text{out}_{i} = \\\\tan(\\\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> np_arr = np.array([-1/4*np.pi, 0, 1/4*np.pi]).astype(np.float32)\n        >>> input = flow.Tensor(np_arr)\n        >>> output = flow.tan(input)\n        >>> output\n        tensor([-1.,  0.,  1.], dtype=oneflow.float32)\n\n    \"\"\"\n    return flow._C.tan(input)\n\n\ndef acosh_op(input):\n    \"\"\"Returns a new tensor with the inverse hyperbolic cosine of the elements of :attr:`input`.\n\n    .. math::\n\n        \\\\text{out}_{i} = \\\\cosh^{-1}(\\\\text{input}_{i})\n\n    Args:\n        input (Tensor): the input tensor.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x1 = flow.Tensor(np.array([2, 3, 4]).astype(np.float32))\n        >>> out1 = flow.acosh(x1)\n        >>> out1\n        tensor([1.3170, 1.7627, 2.0634], dtype=oneflow.float32)\n        >>> x2 = flow.Tensor(np.array([1.5, 2.6, 3.7]).astype(np.float32),device=flow.device('cuda'))\n        >>> out2 = flow.acosh(x2)\n        >>> out2\n        tensor([0.9624, 1.6094, 1.9827], device='cuda:0', dtype=oneflow.float32)\n\n    \"\"\"\n    return flow._C.acosh(input)\n\n\ndef arccosh_op(input):\n    \"\"\"\n\n    See :func:`oneflow.acosh`\n\n    \"\"\"\n    return flow._C.acosh(input)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/unique.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\n\n\ndef unique_op(\n    input, sorted=True, return_inverse=False, return_counts=False, dtype=flow.int\n):\n    r\"\"\"\n    Returns the unique elements of the input tensor.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.unique.html.\n\n    Args:\n        input (Tensor): The input tensor.\n        sorted (bool): Whether to sort the unique elements in ascending order before returning as output.\n        return_inverse (bool): Whether to also return the indices for where elements in the original input ended up in the returned unique list.\n        return_counts (bool): Whether to also return the counts for each unique element.\n        dtype (flow.dtype): Dtype of the returned indices and counts.\n\n    Returns:\n        oneflow.Tensor or List of oneflow.Tensor:\n\n        - **output** (Tensor): the output list of unique scalar elements.\n\n        - **inverse_indices** (Tensor): (optional) if return_inverse is True, \n          there will be an additional returned tensor (same shape as input) representing\n          the indices for where elements in the original input map to in the output;\n          otherwise, this function will only return a single tensor.\n\n        - **counts** (Tensor): (optional) if return_counts is True, there will be an additional\n          returned tensor (same shape as output or output.size(dim), if dim was specified)\n          representing the number of occurrences for each unique value or tensor.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> x = flow.tensor([3, 1, 2, 0 ,2])\n        >>> flow.unique(x)\n        tensor([0, 1, 2, 3], dtype=oneflow.int64)\n        >>> flow.unique(x, sorted=False)\n        tensor([3, 1, 2, 0], dtype=oneflow.int64)\n        >>> results, indices = flow.unique(x, return_inverse=True)\n        >>> indices\n        tensor([3, 1, 2, 0, 2], dtype=oneflow.int32)\n        >>> results, counts = flow.unique(x, return_counts=True)\n        >>> counts\n        tensor([1, 1, 2, 1], dtype=oneflow.int32)\n        >>> results, indices = flow.unique(x, return_inverse=True, dtype=flow.long)\n        >>> indices\n        tensor([3, 1, 2, 0, 2], dtype=oneflow.int64)\n\n    \"\"\"\n    if not return_inverse and not return_counts:\n        return flow._C.unique(input, sorted, dtype=dtype)\n    else:\n        return flow._C.unique(\n            input,\n            sorted,\n            return_inverse=return_inverse,\n            return_counts=return_counts,\n            dtype=dtype,\n        )\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/upsampling.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Optional, Tuple, Union\n\nimport oneflow as flow\nfrom oneflow.nn.modules.module import Module\n\n\nclass Upsample(Module):\n    \"\"\"    \n    Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.\n\n    The input data is assumed to be of the form\n    `minibatch x channels x [optional depth] x [optional height] x width`.\n    Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor.\n\n    The algorithms available for upsampling are nearest neighbor and linear,\n    bilinear, bicubic and trilinear for 3D, 4D and 5D input Tensor,\n    respectively.\n\n    One can either give a :attr:`scale_factor` or the target output :attr:`size` to\n    calculate the output size. (You cannot give both, as it is ambiguous)\n\n    The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/_modules/torch/nn/modules/upsampling.html.\n\n    Args:\n        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional):\n            output spatial sizes\n        scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional):\n            multiplier for spatial size. Has to match input size if it is a tuple.\n        mode (str, optional): the upsampling algorithm: one of ``'nearest'``,\n            ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``.\n            Default: ``'nearest'``\n        align_corners (bool, optional): if ``True``, the corner pixels of the input\n            and output tensors are aligned, and thus preserving the values at\n            those pixels. This only has effect when :attr:`mode` is\n            ``'linear'``, ``'bilinear'``, or ``'trilinear'``. Default: ``False``\n\n    Shape:\n        - Input: :math:`(N, C, W_{in})`, :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})`\n        - Output: :math:`(N, C, W_{out})`, :math:`(N, C, H_{out}, W_{out})`\n          or :math:`(N, C, D_{out}, H_{out}, W_{out})`, where\n\n    .. math::\n        D_{out} = \\\\left\\\\lfloor D_{in} \\\\times \\\\text{scale_factor} \\\\right\\\\rfloor\n\n    .. math::\n        H_{out} = \\\\left\\\\lfloor H_{in} \\\\times \\\\text{scale_factor} \\\\right\\\\rfloor\n\n    .. math::\n        W_{out} = \\\\left\\\\lfloor W_{in} \\\\times \\\\text{scale_factor} \\\\right\\\\rfloor\n\n    .. warning::\n        With ``align_corners = True``, the linearly interpolating modes\n        (`linear`, `bilinear`, `bicubic`, and `trilinear`) don't proportionally\n        align the output and input pixels, and thus the output values can depend\n        on the input size. This was the default behavior for these modes up to\n        version 0.3.1. Since then, the default behavior is\n        ``align_corners = False``. 
See below for concrete examples on how this\n        affects the outputs.\n\n    .. note::\n        If you want downsampling/general resizing, you should use :func:`~nn.functional.interpolate`.\n\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n\n        >>> input = flow.tensor(np.arange(1, 5).reshape((1, 1, 2, 2)), dtype=flow.float32)\n        >>> input = input.to(\"cuda\")\n        >>> m = flow.nn.Upsample(scale_factor=2.0, mode=\"nearest\")\n        >>> output = m(input)\n        >>> output #doctest: +ELLIPSIS\n        tensor([[[[1., 1., 2., 2.],\n                  ...\n                  [3., 3., 4., 4.]]]], device='cuda:0', dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        size: Optional[Union[int, Tuple[int, ...]]] = None,\n        scale_factor: Optional[Union[float, Tuple[float, ...]]] = None,\n        mode: str = \"nearest\",\n        align_corners: Optional[bool] = None,\n    ):\n        super().__init__()\n        self.size = size\n        self.scale_factor = scale_factor\n        self.mode = mode\n        self.align_corners = align_corners\n\n    def forward(self, x):\n        return flow.nn.functional.interpolate(\n            x,\n            size=self.size,\n            scale_factor=self.scale_factor,\n            mode=self.mode,\n            align_corners=self.align_corners,\n        )\n\n    def extra_repr(self) -> str:\n        if self.scale_factor is not None:\n            info = \"scale_factor=\" + str(self.scale_factor)\n        else:\n            info = \"size=\" + str(self.size)\n        info += \", mode=\" + self.mode\n        return info\n\n\nclass UpsamplingNearest2d(Upsample):\n    \"\"\"Applies a 2D nearest neighbor upsampling to an input signal composed of several input\n    channels.\n\n    To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor`\n    as its constructor argument.\n\n    When :attr:`size` is given, it is the output size of the image `(h, w)`.\n\n    Args:\n        size (int or Tuple[int, int], optional): output spatial sizes\n        scale_factor (float or Tuple[float, float], optional): multiplier for\n            spatial size.\n\n    .. warning::\n        This class is deprecated in favor of :func:`~nn.functional.interpolate`.\n\n    Shape:\n        - Input: :math:`(N, C, H_{in}, W_{in})`\n        - Output: :math:`(N, C, H_{out}, W_{out})` where\n\n    .. math::\n          H_{out} = \\\\left\\\\lfloor H_{in} \\\\times \\\\text{scale_factor} \\\\right\\\\rfloor\n\n    .. math::\n          W_{out} = \\\\left\\\\lfloor W_{in} \\\\times \\\\text{scale_factor} \\\\right\\\\rfloor\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> input = flow.tensor(np.arange(1, 5).reshape((1, 1, 2, 2)), dtype=flow.float32)\n        >>> input = input.to(\"cuda\")\n        >>> m = flow.nn.UpsamplingNearest2d(scale_factor=2.0)\n        >>> output = m(input)\n        >>> output #doctest: +ELLIPSIS\n        tensor([[[[1., 1., 2., 2.],\n                  ...\n                  [3., 3., 4., 4.]]]], device='cuda:0', dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        size: Optional[Tuple[int, int]] = None,\n        scale_factor: Optional[Tuple[float, float]] = None,\n    ) -> None:\n        super(UpsamplingNearest2d, self).__init__(size, scale_factor, mode=\"nearest\")\n\n\nclass UpsamplingBilinear2d(Upsample):\n    \"\"\"Applies a 2D bilinear upsampling to an input signal composed of several input\n    channels.\n\n    To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor`\n    as its constructor argument.\n\n    When :attr:`size` is given, it is the output size of the image `(h, w)`.\n\n    Args:\n        size (int or Tuple[int, int], optional): output spatial sizes\n        scale_factor (float or Tuple[float, float], optional): multiplier for\n            spatial size.\n\n    .. warning::\n        This class is deprecated in favor of :func:`~nn.functional.interpolate`. It is\n        equivalent to ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``.\n\n    Shape:\n        - Input: :math:`(N, C, H_{in}, W_{in})`\n        - Output: :math:`(N, C, H_{out}, W_{out})` where\n\n    .. math::\n        H_{out} = \\\\left\\\\lfloor H_{in} \\\\times \\\\text{scale_factor} \\\\right\\\\rfloor\n\n    .. math::\n        W_{out} = \\\\left\\\\lfloor W_{in} \\\\times \\\\text{scale_factor} \\\\right\\\\rfloor\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        \n        >>> input = flow.tensor(np.arange(1, 5).reshape((1, 1, 2, 2)), dtype=flow.float32)\n        >>> input = input.to(\"cuda\")\n        >>> m = flow.nn.UpsamplingBilinear2d(scale_factor=2.0)\n        >>> output = m(input)\n        >>> output #doctest: +ELLIPSIS\n        tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],\n                  ...\n                  [3.0000, 3.3333, 3.6667, 4.0000]]]], device='cuda:0',\n               dtype=oneflow.float32)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        size: Optional[Tuple[int, int]] = None,\n        scale_factor: Optional[Tuple[float, float]] = None,\n    ) -> None:\n        super(UpsamplingBilinear2d, self).__init__(\n            size, scale_factor, mode=\"bilinear\", align_corners=True\n        )\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/modules/utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport collections.abc as container_abcs\nfrom itertools import repeat\nfrom typing import List\n\nimport oneflow as flow\n\n\ndef _ntuple(n):\n    def parse(x):\n        if isinstance(x, container_abcs.Iterable):\n            return tuple(x)\n        return tuple(repeat(x, n))\n\n    return parse\n\n\ndef _getint():\n    def parse(x):\n        if isinstance(x, container_abcs.Iterable):\n            return int(x[0])\n        return int(x)\n\n    return parse\n\n\n_getint = _getint()\n_single = _ntuple(1)\n_pair = _ntuple(2)\n_triple = _ntuple(3)\n_quadruple = _ntuple(4)\n\n\ndef _handle_size_arg(size):\n    if len(size) == 0:\n        return size\n    assert len(size) > 0, \"size of tensor doesn't exists\"\n    if isinstance(size[0], (list, tuple, flow.Size)):\n        assert (\n            len(size) == 1\n        ), \"shape should be specified by tuple of int size, not tuple of list\"\n        size = size[0]\n    return size\n\n\ndef _reverse_repeat_tuple(t, n):\n    \"\"\"Reverse the order of `t` and repeat each element for `n` times.\n    This can be used to translate padding arg used by Conv and Pooling modules\n    to the ones used by `F.pad`.\n    \"\"\"\n    return tuple((x for x in reversed(t) for _ in range(n)))\n\n\ndef _list_with_default(out_size, defaults):\n    if isinstance(out_size, int):\n        return out_size\n    if len(defaults) <= len(out_size):\n        raise ValueError(\n            \"Input dimension should be at least {}\".format(len(out_size) + 1)\n        )\n    return [\n        v if v is not None else d\n        for (v, d) in zip(out_size, defaults[-len(out_size) :])\n    ]\n\n\ndef _check_axis(axis, shape):\n    ndim = len(shape)\n    if axis is None:\n        axis = list(range(len(shape)))\n    if isinstance(axis, int):\n        axis = [axis]\n    assert isinstance(axis, (list, tuple)), \"Invalid axis {}\".format(axis)\n    axis = list(axis)\n    for i in range(len(axis)):\n        assert (\n            -ndim <= axis[i] <= ndim - 1\n        ), \"Dimension out of range (expected to be in range of [{}, {}], but got {})\".format(\n            -ndim, ndim - 1, axis[i]\n        )\n        if axis[i] < 0:\n            axis[i] = axis[i] + ndim\n    return axis\n\n\ndef _generate_output_size(input_size, output_size):\n    new_output_size = []\n    assert len(input_size) - 2 == len(\n        output_size\n    ), f\"the length of 'output_size' does not match the input size, {len(input_size) - 2} expected\"\n    for i in range(len(output_size)):\n        if output_size[i] is None:\n            new_output_size.append(input_size[i + 2])\n        else:\n            assert isinstance(\n                output_size[i], int\n            ), \"numbers in 'output_size' should be integer\"\n            new_output_size.append(output_size[i])\n    return tuple(new_output_size)\n"
  },
  {
    "path": "python/oneflow/nn/modules/where.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.framework.tensor import register_tensor_op\n\n\ndef where_op(condition, x=None, y=None):\n    if x is None and y is None:\n        return flow.nonzero(condition, as_tuple=True)\n\n    return flow._C.where(condition, x, y)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/optimizer/__init__.py",
    "content": ""
  },
  {
    "path": "python/oneflow/nn/optimizer/adadelta.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport collections\nimport math\nfrom typing import Callable, Dict, Iterator, List, Tuple, Union\n\nimport oneflow as flow\nfrom oneflow.optim.optimizer import Optimizer, ParamGroup\nfrom oneflow.nn.parameter import Parameter\n\n\nclass Adadelta(Optimizer):\n    r\"\"\"Implements Adadelta Optimizer. \n\n        The formula is: \n\n        .. math::\n\n            & v_{t} = v_{t-1} * rho + g_{t}^2 * (1 - rho)\n\n            & delta = \\frac{\\sqrt{u_{t-1} + \\epsilon}}{\\sqrt{v_{t} + \\epsilon}} * g_{t}\n            \n            & u_{t} = u_{t-1} * rho + delta^2*(1 - rho)\n\n            & x_{t} = x_{t-1} - lr * delta\n\n        Args:\n            params (Union[Iterator[Parameter], List[Dict]]): iterable of parameters to optimize or dicts defining parameter groups\n            lr (float, optional): The learning rate. Defaults to 0.001.\n            rho (float, optional): The decay factor of learning rate. Defaults to 0.0.\n            eps (float, optional): A small constant terms added to the denominator to improve numerical stability. Defaults to 1e-10.\n            weight_decay (float, optional): The weight decay. Defaults to 0.\n            maximize (bool, optional): maximize the params based on the objective, instead of minimizing. Defaults False.\n            contiguous_params (bool, optional): whether to use contiguous ParamGroup \n                which puts all parameters of the same type, device and group into the\n                same tensor and update them together. (default: False)\n        \n        For example: \n\n        Example 1: \n\n        .. code-block:: python\n\n            # Assume net is a custom model. \n            adadelta = flow.optim.Adadelta(net.parameters(), lr=1e-3)\n\n            for epoch in range(epochs):\n                # Read data, Compute the loss and so on. \n                # ...\n                loss.backward()\n                adadelta.step()\n                adadelta.zero_grad()\n\n        Example 2: \n\n        .. code-block:: python \n\n            # Assume net is a custom model. \n            adadelta = flow.optim.Adadelta(\n                [\n                    {\n                        \"params\": net.parameters(),\n                        \"lr\": learning_rate,\n                        \"clip_grad_max_norm\": 0.5,\n                        \"clip_grad_norm_type\": 2.0,\n                    }\n                ],\n            )\n\n            for epoch in range(epochs):\n                # Read data, Compute the loss and so on. \n                # ...\n                loss.backward()\n                adadelta.clip_grad()\n                adadelta.step()\n                adadelta.zero_grad()\n\n        If you want to use clip_grad, you can refer this example. \n\n        For more details of `clip_grad_max_norm` and `clip_grad_norm_type`, you can refer to :func:`oneflow.nn.utils.clip_grad_norm_`. 
\n        \n    \"\"\"\n\n    def __init__(\n        self,\n        params: Union[Iterator[Parameter], List[Dict]],\n        lr: float = 1.0,\n        rho: float = 0.9,\n        eps: float = 1e-6,\n        weight_decay: float = 0,\n        maximize: bool = False,\n        contiguous_params: bool = False,\n    ):\n        assert lr >= 0.0, f\"Invalid learning rate: {lr}\"\n        assert weight_decay >= 0.0, f\"Invalid weight_decay value: {weight_decay}\"\n        assert eps >= 0.0, f\"Invalid epsilon value: {eps}\"\n        assert 1.0 >= rho >= 0.0, f\"Invalid rho value: {rho}\"\n        assert (\n            not maximize\n        ), f\"In Graph Mode, weight decay has been added to Variable, it cause different result with Eager Mode when maximize = True\"\n        options = dict()\n        options[\"lr\"] = lr\n        options[\"rho\"] = rho\n        options[\"eps\"] = eps\n        options[\"maximize\"] = maximize\n        options[\"weight_decay\"] = weight_decay\n        options[\"contiguous_params\"] = contiguous_params\n        super().__init__(params, options)\n\n        for param_group in self.param_groups:\n            if param_group[\"contiguous_params\"]:\n                param_list = param_group.contiguous_parameters\n            else:\n                param_list = param_group.parameters\n\n            for param in param_list:\n                assert param.is_leaf, \"parameters must be leaf tensor\"\n                self.state[param] = dict()\n                self.state[param][\"square_avgs\"] = flow.zeros_like(param)\n                self.state[param][\"acc_deltas\"] = flow.zeros_like(param)\n\n        self._op = (\n            flow.stateful_op(\"adadelta_update\")\n            .Input(\"model\")\n            .Input(\"model_diff\")\n            .Input(\"square_avgs\")\n            .Input(\"acc_deltas\")\n            .Build()\n        )\n\n    def step(self, closure: Callable = None):\n        \"\"\"Performs a single optimization step.\n\n        Args:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        with flow.no_grad():\n            loss = None\n            if closure is not None:\n                with flow.enable_grad():\n                    loss = closure()\n\n            for param_group in self.param_groups:\n                kwargs = {\n                    \"learning_rate\": param_group[\"lr\"],\n                    \"l2\": param_group[\"weight_decay\"],\n                    \"rho\": param_group[\"rho\"],\n                    \"epsilon\": param_group[\"eps\"],\n                    \"maximize\": param_group[\"maximize\"],\n                }\n\n                if param_group[\"contiguous_params\"]:\n                    param_list = param_group.contiguous_parameters\n                else:\n                    param_list = param_group.parameters\n\n                for param in param_list:\n                    if param.grad is None:\n                        continue\n                    square_avgs_tensor = self.state[param][\"square_avgs\"]\n                    acc_deltas_tensor = self.state[param][\"acc_deltas\"]\n                    flow._C.dispatch_adadelta_update(\n                        self._op,\n                        (param, param.grad, square_avgs_tensor, acc_deltas_tensor),\n                        **kwargs,\n                    )\n\n            self.state[\"step\"] = self.state[\"step\"] + 1\n            return loss\n\n    def _generate_conf_for_graph(self, train_conf, 
vars_conf):\n        new_opt_confs = []\n        for param_group in self.param_groups:\n            assert (\n                param_group[\"contiguous_params\"] != True\n            ), \"contiguous_params cannot be used in graph\"\n\n            optimizer_conf = train_conf.optimizer_conf.add()\n\n            lr = (\n                param_group[\"initial_lr\"]\n                if \"initial_lr\" in param_group\n                else param_group[\"lr\"]\n            )\n            l2 = param_group[\"weight_decay\"]\n            rho = param_group[\"rho\"]\n            epsilon = param_group[\"eps\"]\n            maximize = param_group[\"maximize\"]\n\n            optimizer_conf.base_learning_rate = lr\n            self._generate_lr_scale_for_optim_conf(param_group, optimizer_conf)\n\n            optimizer_conf.adadelta_conf.rho = rho\n            optimizer_conf.adadelta_conf.epsilon = epsilon\n            optimizer_conf.adadelta_conf.maximize = maximize\n\n            self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf)\n\n            for param in param_group.parameters:\n                vars_conf[param].l2 = l2\n                if param.requires_grad:\n                    optimizer_conf.variable_op_names.append(vars_conf[param].name)\n\n            new_opt_confs.append(optimizer_conf)\n        return new_opt_confs\n"
  },
  {
    "path": "python/oneflow/nn/optimizer/adagrad.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport collections\nimport math\nfrom typing import Callable, Dict, Iterator, List, Tuple, Union\n\nimport oneflow as flow\nfrom oneflow.optim.optimizer import Optimizer, ParamGroup\nfrom oneflow.nn.parameter import Parameter\n\n\nclass Adagrad(Optimizer):\n    r\"\"\"Implements Adagrad Optimizer. \n\n        The formula is: \n\n        .. math:: \n\n            & S_{t} = S_{t-1} + grad \\odot grad \n            \n            & decay\\_lr = \\frac{learning\\_rate}{(1 + (train\\_step - 1) * lr\\_decay)}\n\n            & X_{t} = X_{t-1} - \\frac{decay\\_lr}{\\sqrt{S_{t} + \\epsilon}} \\odot grad\n\n        Args:\n            params (Union[Iterator[Parameter], List[Dict]]): iterable of parameters to optimize or dicts defining\n            parameter groups\n            lr (float, optional): The learning rate. Defaults to 0.001.\n            lr_decay (float, optional): The decay factor of learning rate. Defaults to 0.0.\n            weight_decay (float, optional): The weight decay. Defaults to 0.\n            initial_accumulator_value (float, optional): The initial value of S. Defaults to 0.0.\n            eps (float, optional): A small constant terms added to the denominator to improve numerical stability. Defaults to 1e-10.\n            contiguous_params (bool, optional): whether to use contiguous ParamGroup \n                which puts all parameters of the same type, device and group into the\n                same tensor and update them together. (default: False)\n        \n        For example: \n\n        Example 1: \n\n        .. code-block:: python\n\n            # Assume net is a custom model. \n            adagrad = flow.optim.Adagrad(net.parameters(), lr=1e-3)\n\n            for epoch in range(epochs):\n                # Read data, Compute the loss and so on. \n                # ...\n                loss.backward()\n                adagrad.step()\n                adagrad.zero_grad()\n\n        Example 2: \n\n        .. code-block:: python \n\n            # Assume net is a custom model. \n            adagrad = flow.optim.Adagrad(\n                [\n                    {\n                        \"params\": net.parameters(),\n                        \"lr\": learning_rate,\n                        \"clip_grad_max_norm\": 0.5,\n                        \"clip_grad_norm_type\": 2.0,\n                    }\n                ],\n            )\n\n            for epoch in range(epochs):\n                # Read data, Compute the loss and so on. \n                # ...\n                loss.backward()\n                adagrad.clip_grad()\n                adagrad.step()\n                adagrad.zero_grad()\n\n        If you want to use clip_grad, you can refer this example. \n\n        For more details of `clip_grad_max_norm` and `clip_grad_norm_type`, you can refer to :func:`oneflow.nn.utils.clip_grad_norm_`. 
\n        \n        \"\"\"\n\n    def __init__(\n        self,\n        params: Union[Iterator[Parameter], List[Dict]],\n        lr: float = 0.001,\n        lr_decay: float = 0.0,\n        weight_decay: float = 0,\n        initial_accumulator_value: float = 0.0,\n        eps: float = 1e-10,\n        contiguous_params: bool = False,\n    ):\n        assert lr >= 0.0, f\"Invalid learning rate: {lr}\"\n        assert weight_decay >= 0.0, f\"Invalid weight_decay value: {weight_decay}\"\n        assert (\n            initial_accumulator_value >= 0.0\n        ), f\"Invalid initial_accumulator_value value: {initial_accumulator_value}\"\n        assert eps >= 0.0, f\"Invalid epsilon value: {eps}\"\n\n        options = dict()\n        options[\"lr\"] = lr\n        options[\"initial_accumulator_value\"] = initial_accumulator_value\n        options[\"lr_decay\"] = lr_decay\n        options[\"weight_decay\"] = weight_decay\n        options[\"eps\"] = eps\n        options[\"contiguous_params\"] = contiguous_params\n        super().__init__(params, options)\n\n        for param_group in self.param_groups:\n            if param_group[\"contiguous_params\"]:\n                param_list = param_group.contiguous_parameters\n            else:\n                param_list = param_group.parameters\n\n            for param in param_list:\n                assert param.is_leaf, \"parameters must be leaf tensor\"\n                self.state[param] = dict()\n                self.state[param][\"sum\"] = flow.zeros_like(param).fill_(\n                    param_group[\"initial_accumulator_value\"]\n                )\n\n        self._op = (\n            flow.stateful_op(\"adagrad_update\")\n            .Input(\"model\")\n            .Input(\"model_diff\")\n            .Input(\"sum\")\n            .Build()\n        )\n\n    def step(self, closure: Callable = None):\n        \"\"\"Performs a single optimization step.\n\n        Args:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        with flow.no_grad():\n            loss = None\n            if closure is not None:\n                with flow.enable_grad():\n                    loss = closure()\n\n            for param_group in self.param_groups:\n                kwargs = {\n                    \"learning_rate\": param_group[\"lr\"],\n                    \"l2\": param_group[\"weight_decay\"],\n                    \"epsilon\": param_group[\"eps\"],\n                    \"lr_decay\": param_group[\"lr_decay\"],\n                    \"train_step_val\": self.state[\"step\"] + 1,\n                }\n\n                if param_group[\"contiguous_params\"]:\n                    param_list = param_group.contiguous_parameters\n                else:\n                    param_list = param_group.parameters\n\n                for param in param_list:\n                    if param.grad is None:\n                        continue\n                    sum_tensor = self.state[param][\"sum\"]\n                    flow._C.dispatch_adagrad_update(\n                        self._op, (param, param.grad, sum_tensor), **kwargs\n                    )\n\n            self.state[\"step\"] = self.state[\"step\"] + 1\n            return loss\n\n    def _generate_conf_for_graph(self, train_conf, vars_conf):\n        new_opt_confs = []\n        for param_group in self.param_groups:\n            assert (\n                param_group[\"contiguous_params\"] != True\n            ), \"contiguous_params cannot be used 
in graph\"\n\n            optimizer_conf = train_conf.optimizer_conf.add()\n\n            lr = (\n                param_group[\"initial_lr\"]\n                if \"initial_lr\" in param_group\n                else param_group[\"lr\"]\n            )\n            l2 = param_group[\"weight_decay\"]\n            initial_accumulator_value = param_group[\"initial_accumulator_value\"]\n            lr_decay = param_group[\"lr_decay\"]\n            epsilon = param_group[\"eps\"]\n\n            optimizer_conf.base_learning_rate = lr\n            self._generate_lr_scale_for_optim_conf(param_group, optimizer_conf)\n\n            optimizer_conf.adagrad_conf.initial_accumulator_value = (\n                initial_accumulator_value\n            )\n            optimizer_conf.adagrad_conf.lr_decay = lr_decay\n            optimizer_conf.adagrad_conf.epsilon = epsilon\n\n            self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf)\n\n            for param in param_group.parameters:\n                vars_conf[param].l2 = l2\n                if param.requires_grad:\n                    optimizer_conf.variable_op_names.append(vars_conf[param].name)\n\n            new_opt_confs.append(optimizer_conf)\n        return new_opt_confs\n"
  },
  {
    "path": "python/oneflow/nn/optimizer/adam.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport warnings\nimport math\nfrom typing import Callable, Dict, Iterator, List, Tuple, Union\n\nimport oneflow as flow\nfrom oneflow.optim.optimizer import Optimizer, ParamGroup\nfrom oneflow.nn.parameter import Parameter\n\n\nclass Adam(Optimizer):\n    \"\"\"Implements Adam algorithm.\n\n    It has been proposed in `Adam: A Method for Stochastic Optimization`_.\n    The implementation of the L2 penalty follows changes proposed in\n    `Decoupled Weight Decay Regularization`_.\n\n    This algorithm can adjust the learning rate of each parameter dynamically according to the 1st-moment estimates and the 2nd-moment estimates of gradient.\n\n    the equation of parameters updating is:\n\n    .. math::\n\n        & V_t = \\\\beta_1*V_{t-1} + (1-\\\\beta_1)*grad\n\n        & S_t = \\\\beta_2*S_{t-1} + (1-\\\\beta_2)*{grad} \\\\odot {grad}\n\n        & \\\\hat{g} = learning\\\\_rate*\\\\frac{{V_t}}{\\\\sqrt{{S_t}}+\\\\epsilon}\n\n        & param_{new} = param_{old} - \\\\hat{g}\n\n    Args:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (bool, optional): whether to use the AMSGrad variant of this algorithm. (default: False) \n        do_bias_correction (bool, optional): whether to do bias correction (default: True)\n        contiguous_params (bool, optional): whether to use contiguous ParamGroup \n            which puts all parameters of the same type, device and group into the\n            same tensor and update them together. (default: False)\n        fused (bool, optional): whether to divide all the parameters into several groups, then\n            update each group of parameters with the fused kernel. (default: False)\n\n    .. _Adam\\\\: A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n\n    .. _Decoupled Weight Decay Regularization:\n        https://arxiv.org/abs/1711.05101\n\n    For example: \n\n    Example 1: \n\n    .. code-block:: python \n\n        # Assume net is a custom model. \n        adam = flow.optim.Adam(net.parameters(), lr=1e-3)\n\n        for epoch in range(epochs):\n            # Read data, Compute the loss and so on. \n            # ...\n            loss.backward()\n            adam.step()\n            adam.zero_grad()\n\n    Example 2: \n\n    .. code-block:: python \n\n        # Assume net is a custom model. 
\n        adam = flow.optim.Adam(\n            [\n                {\n                    \"params\": net.parameters(),\n                    \"lr\": learning_rate,\n                    \"clip_grad_max_norm\": 0.5,\n                    \"clip_grad_norm_type\": 2.0,\n                }\n            ],\n        )\n\n        for epoch in range(epochs):\n            # Read data, Compute the loss and so on. \n            # ...\n            loss.backward()\n            adam.clip_grad()\n            adam.step()\n            adam.zero_grad()\n\n    If you want to use clip_grad, you can refer to this example. \n\n    For more details of `clip_grad_max_norm` and `clip_grad_norm_type`, you can refer to :func:`oneflow.nn.utils.clip_grad_norm_`. \n\n    \n    \"\"\"\n\n    def __init__(\n        self,\n        params: Union[Iterator[Parameter], List[Dict]],\n        lr: float = 0.001,\n        betas: Tuple[float, float] = (0.9, 0.999),\n        eps: float = 1e-08,\n        weight_decay: float = 0,\n        amsgrad: bool = False,\n        do_bias_correction: bool = True,\n        contiguous_params: bool = False,\n        fused: bool = False,\n    ):\n        assert lr >= 0.0, f\"Invalid learning rate: {lr}\"\n        assert eps >= 0.0, f\"Invalid epsilon value: {eps}\"\n        assert (\n            betas[0] >= 0.0 and betas[0] < 1.0\n        ), f\"Invalid beta parameter at index 0: {betas[0]}\"\n        assert (\n            betas[1] >= 0.0 and betas[1] < 1.0\n        ), f\"Invalid beta parameter at index 1: {betas[1]}\"\n        assert weight_decay >= 0.0, f\"Invalid weight_decay value: {weight_decay}\"\n        options = dict()\n        options[\"lr\"] = lr\n        options[\"eps\"] = eps\n        options[\"betas\"] = betas\n        options[\"weight_decay\"] = weight_decay\n        options[\"amsgrad\"] = amsgrad\n        options[\"bias_correction1\"] = 1.0\n        options[\"bias_correction2\"] = 1.0\n        options[\"do_bias_correction\"] = do_bias_correction\n        options[\"contiguous_params\"] = contiguous_params\n        options[\"fused\"] = fused\n        super().__init__(params, options)\n\n        for param_group in self.param_groups:\n            if param_group[\"contiguous_params\"]:\n                param_list = param_group.contiguous_parameters\n            else:\n                param_list = param_group.parameters\n\n            for param in param_list:\n                assert param.is_leaf, \"parameters must be leaf tensor\"\n                self.state[param] = dict()\n\n                if param_group[\"fused\"] and param_group[\"amsgrad\"]:\n                    warnings.warn(\"Fused Adam is not supported when amsgrad=True.\")\n                    param_group[\"fused\"] = False\n\n                if param_group[\"fused\"] and not param.is_cuda:\n                    warnings.warn(\"Fused Adam only supports CUDA parameters.\")\n                    param_group[\"fused\"] = False\n\n        self._op_with_amsgrad = (\n            flow.stateful_op(\"adam_update\")\n            .Input(\"model\")\n            .Input(\"model_diff\")\n            .Input(\"m\")\n            .Input(\"v\")\n            .Input(\"max_v\")\n            .Build()\n        )\n\n        self._op_without_amsgrad = (\n            flow.stateful_op(\"adam_update\")\n            .Input(\"model\")\n            .Input(\"model_diff\")\n            .Input(\"m\")\n            .Input(\"v\")\n            .Build()\n        )\n\n    def _single_tensor_update(self, param_group):\n        kwargs = {\n            \"learning_rate\": 
param_group[\"lr\"],\n            \"bias_correction1\": param_group[\"bias_correction1\"],\n            \"bias_correction2\": param_group[\"bias_correction2\"],\n            \"l2\": param_group[\"weight_decay\"],\n            \"beta1\": param_group[\"betas\"][0],\n            \"beta2\": param_group[\"betas\"][1],\n            \"epsilon\": param_group[\"eps\"],\n            \"do_bias_correction\": param_group[\"do_bias_correction\"],\n            \"amsgrad\": param_group[\"amsgrad\"],\n        }\n\n        if param_group[\"contiguous_params\"]:\n            param_list = param_group.contiguous_parameters\n        else:\n            param_list = param_group.parameters\n\n        for param in param_list:\n            if param.grad is None:\n                continue\n            if \"exp_avg\" not in self.state[param]:\n                self.state[param][\"exp_avg\"] = flow.zeros_like(param)\n            if \"exp_avg_sq\" not in self.state[param]:\n                self.state[param][\"exp_avg_sq\"] = flow.zeros_like(param)\n            if param_group[\"amsgrad\"]:\n                if \"max_exp_avg_sq\" not in self.state[param]:\n                    self.state[param][\"max_exp_avg_sq\"] = flow.zeros_like(param)\n\n            m_tensor = self.state[param][\"exp_avg\"]\n            v_tensor = self.state[param][\"exp_avg_sq\"]\n\n            if param_group[\"amsgrad\"]:\n                max_v_tensor = self.state[param][\"max_exp_avg_sq\"]\n                flow._C.dispatch_adam_update(\n                    self._op_with_amsgrad,\n                    (param, param.grad, m_tensor, v_tensor, max_v_tensor),\n                    **kwargs,\n                )\n            else:\n                flow._C.dispatch_adam_update(\n                    self._op_without_amsgrad,\n                    (param, param.grad, m_tensor, v_tensor),\n                    **kwargs,\n                )\n\n    def _fused_update(self, param_group):\n        param_list = []\n        param_grad_list = []\n        m_tensor_list = []\n        v_tensor_list = []\n\n        for param in param_group.parameters:\n            if param.grad is None:\n                continue\n\n            if \"exp_avg\" not in self.state[param]:\n                self.state[param][\"exp_avg\"] = flow.zeros_like(param)\n            if \"exp_avg_sq\" not in self.state[param]:\n                self.state[param][\"exp_avg_sq\"] = flow.zeros_like(param)\n            if param_group[\"amsgrad\"]:\n                if \"max_exp_avg_sq\" not in self.state[param]:\n                    self.state[param][\"max_exp_avg_sq\"] = flow.zeros_like(param)\n\n            param_list.append(param)\n            param_grad_list.append(param.grad)\n            m_tensor_list.append(self.state[param][\"exp_avg\"])\n            v_tensor_list.append(self.state[param][\"exp_avg_sq\"])\n\n        flow._C.multi_tensor_adam_update(\n            model=param_list,\n            model_diff=param_grad_list,\n            m=m_tensor_list,\n            v=v_tensor_list,\n            learning_rate_val=param_group[\"lr\"],\n            l2=param_group[\"weight_decay\"],\n            beta1=param_group[\"betas\"][0],\n            beta2=param_group[\"betas\"][1],\n            bias_correction1_val=param_group[\"bias_correction1\"],\n            bias_correction2_val=param_group[\"bias_correction2\"],\n            do_bias_correction=param_group[\"do_bias_correction\"],\n            scale=1.0,\n            weight_decay=0.0,\n            epsilon=param_group[\"eps\"],\n        )\n\n    def step(self, closure: 
Callable = None):\n        \"\"\"Performs a single optimization step.\n\n        Args:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        with flow.no_grad():\n            loss = None\n            if closure is not None:\n                with flow.enable_grad():\n                    loss = closure()\n\n            for param_group in self.param_groups:\n                if param_group[\"do_bias_correction\"]:\n                    param_group[\"bias_correction1\"] = 1.0 - math.pow(\n                        param_group[\"betas\"][0], self.state[\"step\"] + 1\n                    )\n                    param_group[\"bias_correction2\"] = 1.0 - math.pow(\n                        param_group[\"betas\"][1], self.state[\"step\"] + 1\n                    )\n\n                if param_group[\"fused\"]:\n                    self._fused_update(param_group)\n                else:\n                    self._single_tensor_update(param_group)\n\n            self.state[\"step\"] += 1\n\n            return loss\n\n    def _generate_conf_for_graph(self, train_conf, vars_conf):\n        new_opt_confs = []\n        for param_group in self.param_groups:\n            assert (\n                param_group[\"contiguous_params\"] != True\n            ), \"contiguous_params cannot be used in graph\"\n\n            optimizer_conf = train_conf.optimizer_conf.add()\n\n            lr = (\n                param_group[\"initial_lr\"]\n                if \"initial_lr\" in param_group\n                else param_group[\"lr\"]\n            )\n            l2 = param_group[\"weight_decay\"]\n            beta1 = param_group[\"betas\"][0]\n            beta2 = param_group[\"betas\"][1]\n\n            epsilon = param_group[\"eps\"]\n            do_bias_correction = param_group[\"do_bias_correction\"]\n            amsgrad = param_group[\"amsgrad\"]\n\n            optimizer_conf.base_learning_rate = lr\n            self._generate_lr_scale_for_optim_conf(param_group, optimizer_conf)\n\n            optimizer_conf.adam_conf.beta1 = beta1\n            optimizer_conf.adam_conf.beta2 = beta2\n            optimizer_conf.adam_conf.epsilon = epsilon\n            optimizer_conf.adam_conf.do_bias_correction = do_bias_correction\n            optimizer_conf.adam_conf.amsgrad = amsgrad\n\n            self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf)\n\n            for param in param_group.parameters:\n                vars_conf[param].l2 = l2\n                if param.requires_grad:\n                    optimizer_conf.variable_op_names.append(vars_conf[param].name)\n\n            new_opt_confs.append(optimizer_conf)\n        return new_opt_confs\n\n    @property\n    def support_sparse(self):\n        return True\n"
  },
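  {
    "path": "examples/optimizer_sketches/adam_bias_correction.py",
    "content": "\"\"\"Editor's sketch (hypothetical file, not part of the OneFlow API): re-derives one\nAdam step for a scalar parameter from the update equations documented in adam.py,\nassuming do_bias_correction=True, weight_decay=0 and amsgrad=False.\n\"\"\"\nimport math\n\n\ndef adam_step(param, grad, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):\n    # V_t = beta1 * V_{t-1} + (1 - beta1) * grad\n    m = beta1 * m + (1 - beta1) * grad\n    # S_t = beta2 * S_{t-1} + (1 - beta2) * grad * grad\n    v = beta2 * v + (1 - beta2) * grad * grad\n    # Adam.step() recomputes these two corrections before every update.\n    bias_correction1 = 1.0 - beta1 ** (step + 1)\n    bias_correction2 = 1.0 - beta2 ** (step + 1)\n    # param_new = param_old - lr * (m / bc1) / (sqrt(v / bc2) + eps)\n    param = param - lr * (m / bias_correction1) / (math.sqrt(v / bias_correction2) + eps)\n    return param, m, v\n\n\nif __name__ == \"__main__\":\n    p, m, v = 1.0, 0.0, 0.0\n    for step in range(3):\n        p, m, v = adam_step(p, grad=0.5, m=m, v=v, step=step)\n        print(step, p)\n"
  },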
  {
    "path": "python/oneflow/nn/optimizer/adamw.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport warnings\nimport math\nfrom typing import Callable, Dict, Iterator, List, Tuple, Union\n\nimport oneflow as flow\nfrom oneflow.optim.optimizer import Optimizer, ParamGroup\nfrom oneflow.nn.parameter import Parameter\n\n\nclass AdamW(Optimizer):\n    \"\"\"Implements AdamW algorithm.\n\n    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.\n    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.\n\n    The optimizer of the Adam-weight-decay algorithm.\n\n    (More details please refer to `Adam-weight-decay <https://www.fast.ai/2018/07/02/adam-weight-decay/>`_).\n\n    So we use Adam-weight-decay algorithm to solve this problem.\n\n    the equation of parameters updating is:\n\n    .. math::\n\n        & V_t = \\\\beta_1*V_{t-1} + (1-\\\\beta_1)*grad\n\n        & S_t = \\\\beta_2*S_{t-1} + (1-\\\\beta_2)*{grad} \\\\odot {grad}\n\n        & \\\\hat{g} = learning\\\\_rate*(\\\\frac{{V_t}}{\\\\sqrt{{S_t}}+\\\\epsilon}+\\\\lambda*param_{old})\n\n        & param_{new} = param_{old} - \\\\hat{g}\n\n    Args:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (In the equation is λ, default: 0)\n        amsgrad (bool, optional): whether to use the AMSGrad variant of this algorithm. (default: False) \n        do_bias_correction (bool, optional): whether to do bias correction (default: True)\n        contiguous_params (bool, optional): whether to use contiguous ParamGroup \n            which puts all parameters of the same type, device and group into the\n            same tensor and update them together. (default: False)\n        fused (bool, optional): whether to divide all the parameters into several groups, then\n            update each group of parameters with the fused kernel. (default: False)\n\n    .. _Adam\\\\: A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n\n    .. _Decoupled Weight Decay Regularization:\n        https://arxiv.org/abs/1711.05101\n\n    For example: \n\n    Example 1: \n\n    .. code-block:: python \n\n        # Assume net is a custom model. \n        adamw = flow.optim.AdamW(net.parameters(), lr=1e-3)\n\n        for epoch in range(epochs):\n            # Read data, Compute the loss and so on. \n            # ...\n            loss.backward()\n            adamw.step()\n            adamw.zero_grad()\n\n    Example 2: \n\n    .. code-block:: python \n\n        # Assume net is a custom model. 
\n        adamw = flow.optim.AdamW(\n            [\n                {\n                    \"params\": net.parameters(),\n                    \"lr\": learning_rate,\n                    \"clip_grad_max_norm\": 0.5,\n                    \"clip_grad_norm_type\": 2.0,\n                }\n            ],\n        )\n\n        for epoch in range(epochs):\n            # Read data, Compute the loss and so on. \n            # ...\n            loss.backward()\n            adamw.clip_grad()\n            adamw.step()\n            adamw.zero_grad()\n\n    If you want to use clip_grad, you can refer to this example. \n\n    For more details of `clip_grad_max_norm` and `clip_grad_norm_type`, you can refer to :func:`oneflow.nn.utils.clip_grad_norm_`. \n\n    \"\"\"\n\n    def __init__(\n        self,\n        params: Union[Iterator[Parameter], List[Dict]],\n        lr: float = 0.001,\n        betas: Tuple[float, float] = (0.9, 0.999),\n        eps: float = 1e-08,\n        weight_decay: float = 0,\n        amsgrad: bool = False,\n        do_bias_correction: bool = True,\n        contiguous_params: bool = False,\n        fused: bool = False,\n    ):\n        assert lr >= 0.0, f\"Invalid learning rate: {lr}\"\n        assert eps >= 0.0, f\"Invalid epsilon value: {eps}\"\n        assert (\n            betas[0] >= 0.0 and betas[0] < 1.0\n        ), f\"Invalid beta parameter at index 0: {betas[0]}\"\n        assert (\n            betas[1] >= 0.0 and betas[1] < 1.0\n        ), f\"Invalid beta parameter at index 1: {betas[1]}\"\n        assert weight_decay >= 0.0, f\"Invalid weight_decay value: {weight_decay}\"\n        options = dict()\n        options[\"lr\"] = lr\n        options[\"eps\"] = eps\n        options[\"betas\"] = betas\n        options[\"weight_decay\"] = weight_decay\n        options[\"bias_correction1\"] = 1.0\n        options[\"bias_correction2\"] = 1.0\n        options[\"do_bias_correction\"] = do_bias_correction\n        options[\"amsgrad\"] = amsgrad\n        options[\"contiguous_params\"] = contiguous_params\n        options[\"fused\"] = fused\n        super().__init__(params, options)\n\n        for param_group in self.param_groups:\n            if param_group[\"contiguous_params\"]:\n                param_list = param_group.contiguous_parameters\n            else:\n                param_list = param_group.parameters\n\n            for param in param_list:\n                assert param.is_leaf, \"parameters must be leaf tensor\"\n                self.state[param] = dict()\n\n                if param_group[\"fused\"] and param_group[\"amsgrad\"]:\n                    warnings.warn(\"Fused AdamW is not supported when amsgrad=True.\")\n                    param_group[\"fused\"] = False\n\n                if param_group[\"fused\"] and not param.is_cuda:\n                    warnings.warn(\"Fused AdamW only supports CUDA parameters.\")\n                    param_group[\"fused\"] = False\n\n        self._op_with_amsgrad = (\n            flow.stateful_op(\"adam_update\")\n            .Input(\"model\")\n            .Input(\"model_diff\")\n            .Input(\"m\")\n            .Input(\"v\")\n            .Input(\"max_v\")\n            .Build()\n        )\n        self._op_without_amsgrad = (\n            flow.stateful_op(\"adam_update\")\n            .Input(\"model\")\n            .Input(\"model_diff\")\n            .Input(\"m\")\n            .Input(\"v\")\n            .Build()\n        )\n\n    def _single_tensor_update(self, param_group):\n        kwargs = {\n            \"learning_rate\": 
param_group[\"lr\"],\n            \"bias_correction1\": param_group[\"bias_correction1\"],\n            \"bias_correction2\": param_group[\"bias_correction2\"],\n            \"weight_decay\": param_group[\"weight_decay\"],\n            \"beta1\": param_group[\"betas\"][0],\n            \"beta2\": param_group[\"betas\"][1],\n            \"epsilon\": param_group[\"eps\"],\n            \"do_bias_correction\": param_group[\"do_bias_correction\"],\n            \"amsgrad\": param_group[\"amsgrad\"],\n        }\n\n        if param_group[\"contiguous_params\"]:\n            param_list = param_group.contiguous_parameters\n        else:\n            param_list = param_group.parameters\n\n        for param in param_list:\n            if param.grad is None:\n                continue\n\n            if \"exp_avg\" not in self.state[param]:\n                self.state[param][\"exp_avg\"] = flow.zeros_like(param)\n            if \"exp_avg_sq\" not in self.state[param]:\n                self.state[param][\"exp_avg_sq\"] = flow.zeros_like(param)\n            if param_group[\"amsgrad\"]:\n                if \"max_exp_avg_sq\" not in self.state[param]:\n                    self.state[param][\"max_exp_avg_sq\"] = flow.zeros_like(param)\n            m_tensor = self.state[param][\"exp_avg\"]\n            v_tensor = self.state[param][\"exp_avg_sq\"]\n\n            if param_group[\"amsgrad\"]:\n                max_v_tensor = self.state[param][\"max_exp_avg_sq\"]\n                flow._C.dispatch_adam_update(\n                    self._op_with_amsgrad,\n                    (param, param.grad, m_tensor, v_tensor, max_v_tensor),\n                    **kwargs,\n                )\n            else:\n                flow._C.dispatch_adam_update(\n                    self._op_without_amsgrad,\n                    (param, param.grad, m_tensor, v_tensor),\n                    **kwargs,\n                )\n\n    def _fused_update(self, param_group):\n        param_list = []\n        param_grad_list = []\n        m_tensor_list = []\n        v_tensor_list = []\n\n        for param in param_group.parameters:\n            if param.grad is None:\n                continue\n\n            if \"exp_avg\" not in self.state[param]:\n                self.state[param][\"exp_avg\"] = flow.zeros_like(param)\n            if \"exp_avg_sq\" not in self.state[param]:\n                self.state[param][\"exp_avg_sq\"] = flow.zeros_like(param)\n            if param_group[\"amsgrad\"]:\n                if \"max_exp_avg_sq\" not in self.state[param]:\n                    self.state[param][\"max_exp_avg_sq\"] = flow.zeros_like(param)\n\n            param_list.append(param)\n            param_grad_list.append(param.grad)\n            m_tensor_list.append(self.state[param][\"exp_avg\"])\n            v_tensor_list.append(self.state[param][\"exp_avg_sq\"])\n\n        flow._C.multi_tensor_adam_update(\n            model=param_list,\n            model_diff=param_grad_list,\n            m=m_tensor_list,\n            v=v_tensor_list,\n            learning_rate_val=param_group[\"lr\"],\n            l2=0.0,\n            beta1=param_group[\"betas\"][0],\n            beta2=param_group[\"betas\"][1],\n            bias_correction1_val=param_group[\"bias_correction1\"],\n            bias_correction2_val=param_group[\"bias_correction2\"],\n            do_bias_correction=param_group[\"do_bias_correction\"],\n            scale=1.0,\n            weight_decay=param_group[\"weight_decay\"],\n            epsilon=param_group[\"eps\"],\n        )\n\n    def step(self, 
closure: Callable = None):\n        \"\"\"Performs a single optimization step.\n\n        Args:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        with flow.no_grad():\n            loss = None\n            if closure is not None:\n                with flow.enable_grad():\n                    loss = closure()\n\n            for param_group in self.param_groups:\n                if param_group[\"do_bias_correction\"]:\n                    param_group[\"bias_correction1\"] = 1.0 - math.pow(\n                        param_group[\"betas\"][0], self.state[\"step\"] + 1\n                    )\n                    param_group[\"bias_correction2\"] = 1.0 - math.pow(\n                        param_group[\"betas\"][1], self.state[\"step\"] + 1\n                    )\n\n                if param_group[\"fused\"]:\n                    self._fused_update(param_group)\n                else:\n                    self._single_tensor_update(param_group)\n\n            self.state[\"step\"] += 1\n            return loss\n\n    def _generate_conf_for_graph(self, train_conf, vars_conf):\n        new_opt_confs = []\n        for param_group in self.param_groups:\n            assert (\n                param_group[\"contiguous_params\"] != True\n            ), \"contiguous_params cannot be used in graph\"\n\n            optimizer_conf = train_conf.optimizer_conf.add()\n            lr = (\n                param_group[\"initial_lr\"]\n                if \"initial_lr\" in param_group\n                else param_group[\"lr\"]\n            )\n            weight_decay = param_group[\"weight_decay\"]\n            beta1 = param_group[\"betas\"][0]\n            beta2 = param_group[\"betas\"][1]\n            epsilon = param_group[\"eps\"]\n            do_bias_correction = param_group[\"do_bias_correction\"]\n            amsgrad = param_group[\"amsgrad\"]\n\n            optimizer_conf.base_learning_rate = lr\n            self._generate_lr_scale_for_optim_conf(param_group, optimizer_conf)\n\n            optimizer_conf.adam_conf.beta1 = beta1\n            optimizer_conf.adam_conf.beta2 = beta2\n            optimizer_conf.adam_conf.epsilon = epsilon\n            optimizer_conf.adam_conf.do_bias_correction = do_bias_correction\n            optimizer_conf.adam_conf.amsgrad = amsgrad\n\n            optimizer_conf.weight_decay_conf.weight_decay_rate = weight_decay\n\n            self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf)\n\n            for param in param_group.parameters:\n                if param.requires_grad:\n                    optimizer_conf.variable_op_names.append(vars_conf[param].name)\n\n            new_opt_confs.append(optimizer_conf)\n        return new_opt_confs\n\n    @property\n    def support_sparse(self):\n        \"\"\"Whether the AdamW optimizer supports sparse update.\n\n        \"\"\"\n        return True\n"
  },
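  {
    "path": "examples/optimizer_sketches/adamw_vs_l2.py",
    "content": "\"\"\"Editor's sketch (hypothetical file, not part of the OneFlow API): contrasts the\ndecoupled weight decay used by AdamW with the L2 penalty folded into the gradient\nby Adam, for one scalar parameter and a single step starting from zero moments\n(bias correction omitted for brevity).\n\"\"\"\nimport math\n\nLR, BETA1, BETA2, EPS, LAMBDA = 0.1, 0.9, 0.999, 1e-8, 0.01\n\n\ndef l2_step(param, grad):\n    # Adam-style L2: the penalty enters the gradient (grad + lambda * param),\n    # so it is rescaled by the adaptive denominator along with the gradient.\n    g = grad + LAMBDA * param\n    m = (1 - BETA1) * g\n    v = (1 - BETA2) * g * g\n    return param - LR * m / (math.sqrt(v) + EPS)\n\n\ndef decoupled_step(param, grad):\n    # AdamW: param_new = param - lr * (m / (sqrt(v) + eps) + lambda * param),\n    # i.e. the decay term bypasses the adaptive denominator entirely.\n    m = (1 - BETA1) * grad\n    v = (1 - BETA2) * grad * grad\n    return param - LR * (m / (math.sqrt(v) + EPS) + LAMBDA * param)\n\n\nif __name__ == \"__main__\":\n    # The gap between the two grows as the moment estimates accumulate.\n    print(\"adam + l2:\", l2_step(1.0, 0.5))\n    print(\"adamw    :\", decoupled_step(1.0, 0.5))\n"
  },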
  {
    "path": "python/oneflow/nn/optimizer/chained_scheduler.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom .lr_scheduler import LRScheduler\n\n\nclass ChainedScheduler(LRScheduler):\n    \"\"\"Chains list of learning rate schedulers. It takes a list of chainable learning\n    rate schedulers and performs consecutive step() functions belong to them by just\n    one call.\n\n    Args:\n        schedulers (list): List of chained schedulers.\n\n    Example:\n        >>> # Assuming optimizer uses lr = 1. for all groups\n        >>> # lr = 0.09     if step == 0\n        >>> # lr = 0.081    if step == 1\n        >>> # lr = 0.729    if step == 2\n        >>> # lr = 0.6561   if step == 3\n        >>> # lr = 0.59049  if step >= 4\n        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)\n        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)\n        >>> scheduler = ChainedScheduler([scheduler1, scheduler2])\n        >>> for _ in range(100):\n        >>>     train(...)\n        >>>     validate(...)\n        >>>     scheduler.step()\n    \"\"\"\n\n    def __init__(self, schedulers):\n        if not isinstance(schedulers, (list, tuple)) or any(\n            not isinstance(s, LRScheduler) for s in schedulers\n        ):\n            raise ValueError(\"ChainedScheduler expects a list of schedulers\")\n\n        if len(schedulers) == 0:\n            raise ValueError(\"length of list of schedulers must be greater than 0\")\n\n        opt = schedulers[0].optimizer\n\n        for i in range(1, len(schedulers)):\n            if schedulers[i].optimizer != opt:\n                raise ValueError(\n                    \"ChainedScheduler expects all schedulers to belong to the same optimizer, but \"\n                    f\"got schedulers at index {0} and {i} to be different\"\n                )\n\n        self.schedulers = list(schedulers)\n        super().__init__(optimizer=opt)\n\n    def step(self):\n        self.last_step += 1\n        lrs = self.schedulers[0].base_lrs.copy()\n        for scheduler in self.schedulers:\n            for i, lr in enumerate(lrs):\n                lrs[i] = scheduler.get_lr(lr, self.last_step)\n\n            scheduler.last_step = self.last_step\n\n        self.update_lrs(lrs)\n\n    def state_dict(self):\n        \"\"\"Returns the state of the scheduler as a :class:`dict`.\n\n        It contains an entry for every variable in self.__dict__ which\n        is not the optimizer.\n        The wrapped scheduler states will also be saved.\n        \"\"\"\n        state_dict = {\n            key: value\n            for key, value in self.__dict__.items()\n            if key not in (\"optimizer\", \"schedulers\")\n        }\n        state_dict[\"schedulers\"] = [None] * len(self.schedulers)\n        for i, s in enumerate(self.schedulers):\n            state_dict[\"schedulers\"][i] = s.state_dict()\n\n        return state_dict\n\n    def load_state_dict(self, state_dict):\n        \"\"\"Loads the schedulers state.\n\n        Args:\n            
state_dict (dict): scheduler state. Should be an object returned\n                from a call to :meth:`state_dict`.\n        \"\"\"\n        scheduler_states = state_dict.pop(\"schedulers\")\n        self.__dict__.update(state_dict)\n        # avoid side effect of calling load_state_dict twice\n        state_dict[\"schedulers\"] = scheduler_states\n\n        for i, s in enumerate(scheduler_states):\n            self.schedulers[i].load_state_dict(s)\n\n    def _generate_conf_for_graph(self, lr_conf):\n        raise NotImplementedError(\"ChainedScheduler is not supported in graph mode yet\")\n"
  },
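  {
    "path": "examples/optimizer_sketches/chained_scheduler_factors.py",
    "content": "\"\"\"Editor's sketch (hypothetical file, not part of the OneFlow API): reproduces the\nlr table in ChainedScheduler's docstring by composing the two schedulers'\nmultiplicative factors by hand (base lr 1.0, ConstantLR factor 0.1 for 2 steps\nchained with ExponentialLR gamma 0.9). The `s + 1` exponent encodes the step\nindexing implied by that table, which is an assumption, not an API guarantee.\n\"\"\"\n\nBASE_LR = 1.0\n\n\ndef constant_factor(step, factor=0.1, total_iters=2):\n    # ConstantLR: scale by `factor` until `total_iters` steps have passed.\n    return factor if step < total_iters else 1.0\n\n\ndef exponential_factor(step, gamma=0.9):\n    # ExponentialLR: scale by gamma ** step.\n    return gamma ** step\n\n\nif __name__ == \"__main__\":\n    for s in range(5):\n        lr = BASE_LR * constant_factor(s) * exponential_factor(s + 1)\n        print(f\"step {s}: lr = {lr:.5f}\")  # 0.09, 0.081, 0.729, 0.6561, 0.59049\n"
  },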
  {
    "path": "python/oneflow/nn/optimizer/constant_lr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom ...optim.optimizer import Optimizer\nfrom .lr_scheduler import LRScheduler\n\n\nclass ConstantLR(LRScheduler):\n    \"\"\"Decays the learning rate of each parameter group by a small constant factor until the\n    number of step reaches a pre-defined milestone: total_iters.\n\n    Args:\n        optimizer (Optimizer): Wrapped optimizer.\n        factor (float): The number we multiply learning rate until the milestone. Default: 1./3.\n        total_iters (int): The number of steps that the scheduler decays the learning rate.\n            Default: 5.\n        last_step (int): The last step. Default: -1.\n        verbose (bool): If ``True``, prints a message to stdout for\n            each step. Default: ``False``.\n\n    Example:\n        >>> # Assuming optimizer uses lr = 0.05 for all groups\n        >>> # lr = 0.025   if step == 0\n        >>> # lr = 0.025   if step == 1\n        >>> # lr = 0.025   if step == 2\n        >>> # lr = 0.025   if step == 3\n        >>> # lr = 0.05    if step >= 4\n        >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4)\n        >>> for step in range(100):\n        >>>     train(...)\n        >>>     validate(...)\n        >>>     scheduler.step()\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        factor: float = 1.0 / 3,\n        total_iters: int = 5,\n        last_step: int = -1,\n        verbose: bool = False,\n    ):\n        assert isinstance(optimizer, Optimizer)\n\n        if factor > 1.0 or factor < 0:\n            raise ValueError(\n                \"Constant multiplicative factor expected to be between 0 and 1.\"\n            )\n\n        self.factor = factor\n        self.total_iters = total_iters\n        super().__init__(optimizer, last_step, verbose)\n\n    def get_lr(self, base_lr, step):\n        if step < self.total_iters:\n            return base_lr * self.factor\n\n        return base_lr\n\n    def _generate_conf_for_graph(self, lr_conf):\n        lr_conf.constant_lr_conf.SetInParent()\n        constant_lr_conf = lr_conf.constant_lr_conf\n        constant_lr_conf.factor = self.factor\n        constant_lr_conf.total_iters = self.total_iters\n"
  },
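  {
    "path": "examples/optimizer_sketches/constant_lr_schedule.py",
    "content": "\"\"\"Editor's sketch (hypothetical file, not part of the OneFlow API): the closed-form\nschedule computed by ConstantLR.get_lr, reproducing the docstring table\n(base lr 0.05, factor 0.5, total_iters 4).\n\"\"\"\n\n\ndef constant_lr(base_lr, step, factor=0.5, total_iters=4):\n    # Same arithmetic as ConstantLR.get_lr: scale until the milestone,\n    # then return to the base learning rate.\n    return base_lr * factor if step < total_iters else base_lr\n\n\nif __name__ == \"__main__\":\n    print([constant_lr(0.05, s) for s in range(6)])\n    # [0.025, 0.025, 0.025, 0.025, 0.05, 0.05]\n"
  },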
  {
    "path": "python/oneflow/nn/optimizer/cosine_annealing_lr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport math\n\nfrom ...optim.optimizer import Optimizer\nfrom .lr_scheduler import LRScheduler\n\n\nclass CosineAnnealingLR(LRScheduler):\n    r\"\"\"\n    Set the learning rate of each parameter group using a cosine annealing\n    schedule, where :math:`\\eta_{max}` is set to the initial lr and\n    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:\n\n    .. math::\n        \\begin{aligned}\n            \\eta_t & = \\eta_{min} + \\frac{1}{2}(\\eta_{max} - \\eta_{min})\\left(1\n            + \\cos\\left(\\frac{T_{cur}}{T_{max}}\\pi\\right)\\right),\n            & T_{cur} \\neq (2k+1)T_{max}; \\\\\n            \\eta_{t+1} & = \\eta_{t} + \\frac{1}{2}(\\eta_{max} - \\eta_{min})\n            \\left(1 - \\cos\\left(\\frac{1}{T_{max}}\\pi\\right)\\right),\n            & T_{cur} = (2k+1)T_{max}.\n        \\end{aligned}\n\n    When last_step=-1, sets initial lr as lr. Notice that because the schedule\n    is defined recursively, the learning rate can be simultaneously modified\n    outside this scheduler by other operators. If the learning rate is set\n    solely by this scheduler, the learning rate at each step becomes:\n\n    .. math::\n        \\eta_t = \\eta_{min} + \\frac{1}{2}(\\eta_{max} - \\eta_{min})\\left(1 +\n        \\cos\\left(\\frac{T_{cur}}{T_{max}}\\pi\\right)\\right)\n\n    It has been proposed in\n    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only\n    implements the cosine annealing part of SGDR, and not the restarts.\n\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html.\n\n    Args:\n        optimizer (Optimizer): Wrapped optimizer.\n        T_max (int): Maximum number of iterations.\n        eta_min (float): Minimum learning rate. Default: 0.\n        last_step (int): The index of last epoch. Default: -1.\n        verbose (bool): If ``True``, prints a message to stdout for\n            each update. Default: ``False``.\n\n    .. _SGDR\\: Stochastic Gradient Descent with Warm Restarts:\n        https://arxiv.org/abs/1608.03983\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        T_max: int,\n        eta_min: float = 0.0,\n        last_step: int = -1,\n        verbose: bool = False,\n    ):\n        self.T_max = T_max\n        self.eta_min = eta_min\n        super().__init__(optimizer, last_step, verbose)\n\n    def get_lr(self, base_lr, step):\n        cos_decay = 0.5 * (1 + math.cos(math.pi * step / self.T_max))\n        return self.eta_min + (base_lr - self.eta_min) * cos_decay\n\n    def _generate_conf_for_graph(self, lr_conf):\n        lr_conf.cosine_annealing_conf.SetInParent()\n        cosine_annealing_conf = lr_conf.cosine_annealing_conf\n        cosine_annealing_conf.t_max = self.T_max\n        cosine_annealing_conf.eta_min = self.eta_min\n"
  },
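  {
    "path": "examples/optimizer_sketches/cosine_annealing_curve.py",
    "content": "\"\"\"Editor's sketch (hypothetical file, not part of the OneFlow API): the closed-form\ncosine annealing curve evaluated by CosineAnnealingLR.get_lr:\neta_t = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * t / T_max)).\n\"\"\"\nimport math\n\n\ndef cosine_annealing(base_lr, step, t_max=10, eta_min=0.0):\n    # cos term decays from 1 at step 0 to -1 at step T_max.\n    cos_decay = 0.5 * (1 + math.cos(math.pi * step / t_max))\n    return eta_min + (base_lr - eta_min) * cos_decay\n\n\nif __name__ == \"__main__\":\n    # lr falls from base_lr at step 0 to eta_min at step T_max.\n    for s in (0, 5, 10):\n        print(s, round(cosine_annealing(0.1, s), 6))  # 0.1, 0.05, 0.0\n"
  },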
  {
    "path": "python/oneflow/nn/optimizer/cosine_annealing_warm_restarts.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport math\nfrom ...optim.optimizer import Optimizer\nfrom .lr_scheduler import LRScheduler\n\n\nclass CosineAnnealingWarmRestarts(LRScheduler):\n    r\"\"\"Set the learning rate of each parameter group using a cosine annealing\n    schedule, where :math:`\\eta_{max}` is set to the initial lr, :math:`T_{cur}`\n    is the number of steps since the last restart and :math:`T_{i}` is the number\n    of steps between two warm restarts in SGDR:\n\n    .. math::\n        \\eta_t = \\eta_{min} + \\frac{1}{2}(\\eta_{max} - \\eta_{min})\\left(1 +\n        \\cos\\left(\\frac{T_{cur}}{T_{i}}\\pi\\right)\\right)\n\n    When :math:`T_{cur}=T_{i}`, set :math:`\\eta_t = \\eta_{min}`.\n    When :math:`T_{cur}=0` after restart, set :math:`\\eta_t=\\eta_{max}`.\n\n    It has been proposed in\n    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.\n\n    Args:\n        optimizer (Optimizer): Wrapped optimizer.\n        T_0 (int): Number of iterations for the first restart.\n        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.\n        eta_min (float, optional): Minimum learning rate. Default: 0.\n        decay_rate (float, optional): Decay rate every restarts.\n        restart_limit (int, optional): The limit of restarts. 0 indicate unlimited restarts. Default: 0.\n        last_step (int, optional): The index of last step. Default: -1.\n        verbose (bool): If ``True``, prints a message to stdout for\n            each update. Default: ``False``.\n\n    .. 
_SGDR\\: Stochastic Gradient Descent with Warm Restarts:\n        https://arxiv.org/abs/1608.03983\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        T_0: int,\n        T_mult: int = 1,\n        eta_min: float = 0.0,\n        decay_rate: float = 1.0,\n        restart_limit: int = 0,\n        last_step: int = -1,\n        verbose: bool = False,\n    ):\n        assert isinstance(optimizer, Optimizer)\n        if T_0 <= 0 or not isinstance(T_0, int):\n            raise ValueError(f\"Expected positive integer T_0, but got {T_0}\")\n\n        if T_mult < 1 or not isinstance(T_mult, int):\n            raise ValueError(f\"Expected integer T_mult >= 1, but got {T_mult}\")\n\n        self.T_0 = T_0\n        self.T_mult = T_mult\n        self.eta_min = eta_min\n        self.decay_rate = decay_rate\n        self.restart_limit = restart_limit\n\n        super().__init__(optimizer, last_step, verbose)\n\n    def get_lr(self, base_lr, step):\n        if self.T_mult > 1:\n            epoch = math.floor(\n                math.log(1 - step / self.T_0 * (1 - self.T_mult), self.T_mult)\n            )\n            epoch_steps = self.T_mult ** epoch * self.T_0\n            step_in_epoch = (\n                step - (1 - self.T_mult ** epoch) / (1 - self.T_mult) * self.T_0\n            )\n        else:\n            epoch = step // self.T_0\n            epoch_steps = self.T_0\n            step_in_epoch = step - (epoch_steps * epoch)\n\n        gamma = self.decay_rate ** epoch\n        if self.restart_limit == 0 or (\n            self.restart_limit > 0 and epoch < self.restart_limit\n        ):\n            cos_decay = 0.5 * (1 + math.cos(math.pi * step_in_epoch / epoch_steps))\n            return self.eta_min + (base_lr * gamma - self.eta_min) * cos_decay\n\n        return self.eta_min\n\n    def _generate_conf_for_graph(self, lr_conf):\n        lr_conf.cosine_annealing_warm_restarts_conf.SetInParent()\n        cosa_warm_restarts_conf = lr_conf.cosine_annealing_warm_restarts_conf\n        cosa_warm_restarts_conf.t_initial = self.T_0\n        cosa_warm_restarts_conf.t_mult = self.T_mult\n        cosa_warm_restarts_conf.eta_min = self.eta_min\n        cosa_warm_restarts_conf.decay_rate = self.decay_rate\n        cosa_warm_restarts_conf.restart_limit = self.restart_limit\n"
  },
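  {
    "path": "examples/optimizer_sketches/warm_restarts_bookkeeping.py",
    "content": "\"\"\"Editor's sketch (hypothetical file, not part of the OneFlow API): the restart\nbookkeeping inside CosineAnnealingWarmRestarts.get_lr for T_mult > 1. With cycle\nlengths T_0, T_0*T_mult, T_0*T_mult^2, ... the current cycle index is found by\ninverting the geometric series sum T_0 * (T_mult^n - 1) / (T_mult - 1).\n\"\"\"\nimport math\n\n\ndef locate_cycle(step, t_0=2, t_mult=2):\n    # Same arithmetic as get_lr: which restart cycle `step` falls in,\n    # how long that cycle is, and how far into it the step is.\n    # (Floating-point rounding can wobble at exact restart boundaries,\n    # just as in the scheduler's own computation.)\n    epoch = math.floor(math.log(1 - step / t_0 * (1 - t_mult), t_mult))\n    epoch_steps = t_mult ** epoch * t_0\n    step_in_epoch = step - (1 - t_mult ** epoch) / (1 - t_mult) * t_0\n    return epoch, epoch_steps, step_in_epoch\n\n\nif __name__ == \"__main__\":\n    # Cycles of length 2, 4, 8, ... restart at steps 0, 2, 6, 14, ...\n    for s in (0, 1, 2, 5, 6, 13, 14):\n        print(s, locate_cycle(s))\n"
  },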
  {
    "path": "python/oneflow/nn/optimizer/cosine_decay_lr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport math\n\nfrom ...optim.optimizer import Optimizer\nfrom .lr_scheduler import LRScheduler\n\n\nclass CosineDecayLR(LRScheduler):\n    \"\"\"This operator creates a Cosine decayed learning rate scheduler.\n\n    Before the decay_steps are specified by user, the learning rate will be updated as:\n\n    .. math::\n\n        & cos\\\\_decay = 0.5*(1+cos(\\\\pi*\\\\frac{current\\\\_step}{decay\\\\_steps}))\n\n        & decay\\\\_factor = (1-\\\\alpha)*cos\\\\_decay+\\\\alpha\n\n        & learning\\\\_rate = base\\\\_learning\\\\_rate*decay\\\\_factor\n\n    After the decay_steps specified by user, the learning rate will be :\n\n    .. math::\n\n        learning\\\\_rate = {base\\\\_learning\\\\_rate}*{\\\\alpha}\n\n    It has been proposed in\n    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only\n    implements the cosine annealing part of SGDR, and not the restarts.\n\n    Args:\n        optimizer(Optimizer): Wrapped optimizer.\n        decay_steps (int): The decay steps in the scheduler.\n        alpha (float, optional): The learning rate scale factor (:math:`\\\\alpha`). (default: 0.0)\n        last_step (int, optional): The index of last step. (default: -1)\n        verbose (bool, optional): If ``True``, prints a message to stdout for each update. (default: ``False``)\n\n    For example:\n\n    .. code-block:: python\n\n        import oneflow as flow\n\n        ...\n        cosine_decay_lr = flow.optim.lr_scheduler.CosineDecayLR(optimizer, decay_steps=100, alpha=0.0)\n        for epoch in range(num_epoch):\n            train(...)\n            cosine_decay_lr.step()\n\n    .. _SGDR\\\\: Stochastic Gradient Descent with Warm Restarts:\n        https://arxiv.org/abs/1608.03983\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        decay_steps: int,\n        alpha: float = 0.0,\n        last_step: int = -1,\n        verbose: bool = False,\n    ):\n        assert (\n            decay_steps > 0\n        ), f\"decay_steps must greater than zero, but got {decay_steps}\"\n        self.decay_steps = decay_steps\n        self.alpha = alpha\n        super().__init__(optimizer, last_step, verbose)\n\n    def get_lr(self, base_lr, step):\n        if step < self.decay_steps:\n            cos_decay = 0.5 * (1 + math.cos(math.pi * step / self.decay_steps))\n            decay_factor = (1 - self.alpha) * cos_decay + self.alpha\n        else:\n            decay_factor = self.alpha\n\n        return base_lr * decay_factor\n\n    def _generate_conf_for_graph(self, lr_conf):\n        # CosineDecayLR is the same as CosineDecayConf in nn.Graph\n        lr_conf.cosine_conf.SetInParent()\n        cosine_decay_conf = lr_conf.cosine_conf\n        cosine_decay_conf.decay_batches = self.decay_steps\n        cosine_decay_conf.alpha = self.alpha\n"
  },
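  {
    "path": "examples/optimizer_sketches/cosine_decay_curve.py",
    "content": "\"\"\"Editor's sketch (hypothetical file, not part of the OneFlow API): the two-phase\nrule in CosineDecayLR.get_lr -- a cosine ramp for the first decay_steps steps,\nthen a constant floor of base_lr * alpha.\n\"\"\"\nimport math\n\n\ndef cosine_decay(base_lr, step, decay_steps=100, alpha=0.1):\n    if step < decay_steps:\n        # decay_factor interpolates between 1 (step 0) and alpha (step decay_steps).\n        cos_decay = 0.5 * (1 + math.cos(math.pi * step / decay_steps))\n        decay_factor = (1 - alpha) * cos_decay + alpha\n    else:\n        decay_factor = alpha\n    return base_lr * decay_factor\n\n\nif __name__ == \"__main__\":\n    for s in (0, 50, 100, 500):\n        print(s, round(cosine_decay(0.1, s), 6))  # 0.1, 0.055, 0.01, 0.01\n"
  },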
  {
    "path": "python/oneflow/nn/optimizer/exponential_lr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom ...optim.optimizer import Optimizer\nfrom .lr_scheduler import LRScheduler\n\n\nclass ExponentialLR(LRScheduler):\n    \"\"\"\n    Decays the learning rate of each parameter group by gamma every epoch.\n    When last_epoch=-1, sets initial lr as lr.\n\n    Args:\n        optimizer (Optimizer): Wrapped optimizer.\n        gamma (float): Multiplicative factor of learning rate decay.\n        last_step (int): The index of last step. Default: -1.\n        verbose (bool): If ``True``, prints a message to stdout for\n            each update. Default: ``False``.\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        gamma: float,\n        last_step: int = -1,\n        verbose: bool = False,\n    ):\n        assert isinstance(optimizer, Optimizer)\n        if gamma <= 0.0:\n            raise ValueError(f\"'gamma' must be greater than zero, but got {gamma}\")\n\n        self.gamma = gamma\n        super().__init__(optimizer, last_step, verbose)\n\n    def get_lr(self, base_lr, step):\n        return base_lr * (self.gamma ** step)\n\n    def _generate_conf_for_graph(self, lr_conf):\n        lr_conf.step_conf.SetInParent()\n        step_conf = lr_conf.step_conf\n        step_conf.step_size = 1\n        step_conf.gamma = self.gamma\n"
  },
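  {
    "path": "examples/optimizer_sketches/exponential_lr_closed_form.py",
    "content": "\"\"\"Editor's sketch (hypothetical file, not part of the OneFlow API): ExponentialLR's\nclosed form, lr = base_lr * gamma ** step. Because get_lr recomputes from the base\nlr rather than multiplying the previous lr recursively, evaluating an arbitrary\nstep index directly gives the same value as stepping to it.\n\"\"\"\n\n\ndef exponential_lr(base_lr, step, gamma=0.9):\n    # Same arithmetic as ExponentialLR.get_lr.\n    return base_lr * gamma ** step\n\n\nif __name__ == \"__main__\":\n    print([round(exponential_lr(1.0, s), 5) for s in range(4)])\n    # [1.0, 0.9, 0.81, 0.729]\n"
  },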
  {
    "path": "python/oneflow/nn/optimizer/lamb.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Callable, Dict, Iterator, List, Union, Tuple\n\nimport math\nimport oneflow as flow\nfrom oneflow.optim.optimizer import Optimizer\nfrom oneflow.nn.parameter import Parameter\n\n\nclass LAMB(Optimizer):\n    \"\"\"Implements LAMB algorithm.\n\n    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.\n\n    The equation of parameters updating is:\n\n    .. math::\n\n        & V_t = \\\\beta_1*V_{t-1} + (1-\\\\beta_1)*grad\n\n        & S_t = \\\\beta_2*S_{t-1} + (1-\\\\beta_2)*{grad} \\\\odot {grad}\n\n        & \\\\hat{u} = \\\\frac{{V_t}}{\\\\sqrt{{S_t}}+\\\\epsilon}\n        \n        & \\\\hat{r} = learning\\\\_rate * \\\\frac{||param_{old}||_2}{||\\\\hat{u}||_2}\n\n        & param_{new} = param_{old} - \\\\hat{r} * \\\\hat{u}\n\n    Args:\n        parameters (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        adam_w_mode (bool, optional): apply L2 regularization or weight decay True for\n            decoupled weight decay (also known as AdamW) (default: True)\n        do_bias_correction (bool, optional): whether to do bias correction (default: True)\n        amsgrad (bool, optional): whether to use the AMSGrad variant of this algorithm. \n            NOT SUPPORTED now! (default: False)\n        contiguous_params (bool, optional): whether to use contiguous ParamGroup \n            which puts all parameters of the same type, device and group into the\n            same tensor and update them together. (default: False)\n        \n    .. _Large Batch Optimization for Deep Learning\\\\: Training BERT in 76 minutes:\n        https://arxiv.org/abs/1904.00962\n\n    For example:\n\n    Example 1:\n\n    .. code-block:: python\n\n        # Assume net is a custom model.\n        lamb = flow.optim.LAMB(net.parameters(), lr=1e-3)\n\n        for epoch in range(epochs):\n            # Read data, Compute the loss and so on.\n            # ...\n            loss.backward()\n            lamb.step()\n            lamb.zero_grad()\n\n    Example 2:\n\n    .. 
code-block:: python\n\n        # Assume net is a custom model.\n        lamb = flow.optim.LAMB(\n            [\n                {\n                    \"params\": net.parameters(),\n                    \"lr\": learning_rate,\n                    \"clip_grad_max_norm\": 0.5,\n                    \"clip_grad_norm_type\": 2.0,\n                }\n            ],\n        )\n\n        for epoch in range(epochs):\n            # Read data, Compute the loss and so on.\n            # ...\n            loss.backward()\n            lamb.clip_grad()\n            lamb.step()\n            lamb.zero_grad()\n\n    If you want to use clip_grad, you can refer to this example.\n\n    For more details of `clip_grad_max_norm` and `clip_grad_norm_type`, you can refer to :func:`oneflow.nn.utils.clip_grad_norm_`.\n    \"\"\"\n\n    def __init__(\n        self,\n        params: Union[Iterator[Parameter], List[Dict]],\n        lr: float = 0.001,\n        betas: Tuple[float, float] = (0.9, 0.999),\n        eps: float = 1e-08,\n        weight_decay: float = 0,\n        adam_w_mode: bool = True,\n        do_bias_correction: bool = True,\n        amsgrad: bool = False,\n        contiguous_params: bool = False,\n    ):\n        if amsgrad:\n            # TODO: supported amsgrad in Lamb\n            raise RuntimeError(\"LAMB does not support AMSGrad variant.\")\n        assert lr >= 0.0, f\"Invalid learning rate: {lr}\"\n        assert eps >= 0.0, f\"Invalid epsilon value: {eps}\"\n        assert (\n            betas[0] >= 0.0 and betas[0] < 1.0\n        ), f\"Invalid beta parameter at index 0: {betas[0]}\"\n        assert (\n            betas[1] >= 0.0 and betas[1] < 1.0\n        ), f\"Invalid beta parameter at index 1: {betas[1]}\"\n        assert weight_decay >= 0.0, f\"Invalid weight_decay value: {weight_decay}\"\n\n        options = dict()\n        options[\"lr\"] = lr\n        options[\"eps\"] = eps\n        options[\"betas\"] = betas\n        options[\"weight_decay\"] = weight_decay\n        options[\"amsgrad\"] = amsgrad\n        options[\"adam_w_mode\"] = adam_w_mode\n        options[\"bias_correction1\"] = 1.0\n        options[\"bias_correction2\"] = 1.0\n        options[\"do_bias_correction\"] = do_bias_correction\n        options[\"contiguous_params\"] = contiguous_params\n\n        super().__init__(params, options)\n\n        for param_group in self.param_groups:\n            if param_group[\"contiguous_params\"]:\n                param_list = param_group.contiguous_parameters\n            else:\n                param_list = param_group.parameters\n\n            for param in param_list:\n                assert param.is_leaf, \"parameters must be leaf tensor\"\n                self.state[param] = dict()\n\n        self._op = (\n            flow.stateful_op(\"lamb_update\")\n            .Input(\"model\")\n            .Input(\"model_diff\")\n            .Input(\"m\")\n            .Input(\"v\")\n            .Build()\n        )\n\n    def step(self, closure: Callable = None):\n        \"\"\"Performs a single optimization step.\n\n        Args:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        with flow.no_grad():\n            loss = None\n            if closure is not None:\n                with flow.enable_grad():\n                    loss = closure()\n\n            for param_group in self.param_groups:\n                if param_group[\"do_bias_correction\"]:\n                    param_group[\"bias_correction1\"] = 1.0 - 
math.pow(\n                        param_group[\"betas\"][0], self.state[\"step\"] + 1\n                    )\n                    param_group[\"bias_correction2\"] = 1.0 - math.pow(\n                        param_group[\"betas\"][1], self.state[\"step\"] + 1\n                    )\n\n                kwargs = {\n                    \"learning_rate\": param_group[\"lr\"],\n                    \"bias_correction1\": param_group[\"bias_correction1\"],\n                    \"bias_correction2\": param_group[\"bias_correction2\"],\n                    \"beta1\": param_group[\"betas\"][0],\n                    \"beta2\": param_group[\"betas\"][1],\n                    \"epsilon\": param_group[\"eps\"],\n                    \"do_bias_correction\": param_group[\"do_bias_correction\"],\n                }\n                if param_group[\"adam_w_mode\"]:\n                    kwargs[\"weight_decay\"] = param_group[\"weight_decay\"]\n                    kwargs[\"l2\"] = 0.0\n                else:\n                    kwargs[\"l2\"] = param_group[\"weight_decay\"]\n                    kwargs[\"weight_decay\"] = 0.0\n\n                if param_group[\"contiguous_params\"]:\n                    param_list = param_group.contiguous_parameters\n                else:\n                    param_list = param_group.parameters\n\n                for param in param_list:\n                    if param.grad is None:\n                        continue\n                    if \"exp_avg\" not in self.state[param]:\n                        self.state[param][\"exp_avg\"] = flow.zeros_like(param)\n                    if \"exp_avg_sq\" not in self.state[param]:\n                        self.state[param][\"exp_avg_sq\"] = flow.zeros_like(param)\n                    m_tensor = self.state[param][\"exp_avg\"]\n                    v_tensor = self.state[param][\"exp_avg_sq\"]\n\n                    flow._C.dispatch_lamb_update(\n                        self._op, (param, param.grad, m_tensor, v_tensor), **kwargs\n                    )\n\n            self.state[\"step\"] += 1\n\n            return loss\n\n    def _generate_conf_for_graph(self, train_conf, vars_conf):\n        new_opt_confs = []\n        for param_group in self.param_groups:\n            assert (\n                param_group[\"contiguous_params\"] != True\n            ), \"contiguous_params cannot be used in graph\"\n\n            optimizer_conf = train_conf.optimizer_conf.add()\n\n            lr = (\n                param_group[\"initial_lr\"]\n                if \"initial_lr\" in param_group\n                else param_group[\"lr\"]\n            )\n            adam_w_mode = param_group[\"adam_w_mode\"]\n            weight_decay = param_group[\"weight_decay\"]\n            beta1 = param_group[\"betas\"][0]\n            beta2 = param_group[\"betas\"][1]\n            do_bias_correction = param_group[\"do_bias_correction\"]\n            epsilon = param_group[\"eps\"]\n\n            optimizer_conf.base_learning_rate = lr\n            self._generate_lr_scale_for_optim_conf(param_group, optimizer_conf)\n\n            optimizer_conf.lamb_conf.beta1 = beta1\n            optimizer_conf.lamb_conf.beta2 = beta2\n            optimizer_conf.lamb_conf.epsilon = epsilon\n            optimizer_conf.lamb_conf.do_bias_correction = do_bias_correction\n\n            self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf)\n\n            if adam_w_mode:\n                optimizer_conf.weight_decay_conf.weight_decay_rate = weight_decay\n            else:\n                
optimizer_conf.weight_decay_conf.weight_decay_rate = 0.0\n\n            for param in param_group.parameters:\n                if not adam_w_mode:\n                    # Set l2 penalty as weight decay if **NOT** using adam_w_mode\n                    vars_conf[param].l2 = weight_decay\n                if param.requires_grad:\n                    optimizer_conf.variable_op_names.append(vars_conf[param].name)\n\n            new_opt_confs.append(optimizer_conf)\n        return new_opt_confs\n"
  },
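  {
    "path": "examples/optimizer_sketches/lamb_trust_ratio.py",
    "content": "\"\"\"Editor's sketch (hypothetical file, not part of the OneFlow API): the LAMB trust\nratio from the docstring equations, on plain Python lists. The adaptive Adam-style\ndirection u is rescaled by ||param|| / ||u|| so the update magnitude tracks the\nlayer's weight norm (one step from zero moments, bias correction omitted).\n\"\"\"\nimport math\n\n\ndef lamb_step(params, grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):\n    # u_hat = V_t / (sqrt(S_t) + eps), elementwise over the layer.\n    u = []\n    for g in grads:\n        m = (1 - beta1) * g\n        v = (1 - beta2) * g * g\n        u.append(m / (math.sqrt(v) + eps))\n    # r_hat = lr * ||param_old||_2 / ||u_hat||_2\n    param_norm = math.sqrt(sum(p * p for p in params))\n    u_norm = math.sqrt(sum(x * x for x in u))\n    trust_ratio = param_norm / u_norm if u_norm > 0 else 1.0\n    # param_new = param_old - r_hat * u_hat\n    return [p - lr * trust_ratio * x for p, x in zip(params, u)]\n\n\nif __name__ == \"__main__\":\n    print(lamb_step([1.0, -2.0], [0.3, 0.1]))\n"
  },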
  {
    "path": "python/oneflow/nn/optimizer/lambda_lr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport types\n\nfrom .lr_scheduler import LRScheduler\n\n\nclass LambdaLR(LRScheduler):\n    \"\"\"\n    Sets the learning rate of each parameter group to the initial lr times a given function.\n    When last_step=-1, sets initial lr as lr.\n\n    .. math::\n\n        learning\\\\_rate = base\\\\_learning\\\\_rate*lambda(last\\\\_step)\n\n    Args:\n        optimizer(Optimizer): Wrapped optimizer.\n        lr_lambda(function or list): A function which computes a multiplicative factor given an integer\n            parameter epoch, or a list of such functions, one for each group in optimizer.param_groups.\n        last_step (int, optional): The index of last step. (default: -1)\n        verbose (bool, optional): If ``True``, prints a message to stdout for each update. (default: ``False``)\n\n    For example:\n\n    .. code-block:: python\n\n        import oneflow as flow\n\n        ...\n        lambda1 = lambda step: step // 30\n        lambda2 = lambda step: 0.95 * step\n        lambda_lr = flow.optim.lr_scheduler.LambdaLR(optimizer, [lambda1, lambda2])\n        for epoch in range(num_epoch):\n            train(...)\n            lambda_lr.step()\n\n    \"\"\"\n\n    def __init__(self, optimizer, lr_lambda, last_step=-1, verbose=False):\n        if not isinstance(lr_lambda, (list, tuple)):\n            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)\n        else:\n            assert len(lr_lambda) == len(\n                optimizer.param_groups\n            ), f\"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}\"\n            self.lr_lambdas = list(lr_lambda)\n        super().__init__(optimizer, last_step, verbose)\n\n    def state_dict(self):\n        \"\"\"Returns the state of the scheduler as a :class:`dict`.\n\n        It contains an entry for every variable in self.__dict__ which\n        is not the optimizer.\n        The learning rate lambda functions will only be saved if they are callable objects\n        and not if they are functions or lambdas.\n        \"\"\"\n        state_dict = {\n            key: value\n            for (key, value) in self.__dict__.items()\n            if key not in (\"optimizer\", \"lr_lambdas\")\n        }\n        state_dict[\"lr_lambdas\"] = [None] * len(self.lr_lambdas)\n        for (idx, fn) in enumerate(self.lr_lambdas):\n            if not isinstance(fn, types.FunctionType):\n                state_dict[\"lr_lambdas\"][idx] = fn.__dict__.copy()\n        return state_dict\n\n    def load_state_dict(self, state_dict):\n        \"\"\"Loads the schedulers state.\n\n        Arguments:\n            state_dict (dict): scheduler state. 
Should be an object returned\n                from a call to :meth:`state_dict`.\n        \"\"\"\n        lr_lambdas = state_dict.pop(\"lr_lambdas\")\n        self.__dict__.update(state_dict)\n        state_dict[\"lr_lambdas\"] = lr_lambdas\n        for (idx, fn) in enumerate(lr_lambdas):\n            if fn is not None:\n                self.lr_lambdas[idx].__dict__.update(fn)\n\n    def step(self):\n        \"\"\"Performs a single learning rate schedule step.\n\n        \"\"\"\n        self.last_step += 1\n        lrs = []\n        for (lmbda, base_lr) in zip(self.lr_lambdas, self.base_lrs):\n            lrs.append(base_lr * lmbda(self.last_step))\n        self.update_lrs(lrs)\n"
  },
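  {
    "path": "examples/optimizer_sketches/lambda_lr_state_dict.py",
    "content": "\"\"\"Editor's sketch (hypothetical file, not part of the OneFlow API): why\nLambdaLR.state_dict saves callable objects but skips plain functions and lambdas --\na callable object carries restorable __dict__ state, while a lambda's behaviour\nlives in code that cannot be serialized. WarmupLambda is a hypothetical helper.\n\"\"\"\n\n\nclass WarmupLambda:\n    # A callable object: its __dict__ ({'warmup': ...}) can round-trip\n    # through state_dict / load_state_dict.\n    def __init__(self, warmup):\n        self.warmup = warmup\n\n    def __call__(self, step):\n        return min(1.0, (step + 1) / self.warmup)\n\n\nif __name__ == \"__main__\":\n    fn = WarmupLambda(warmup=100)\n    saved = fn.__dict__.copy()          # what LambdaLR.state_dict stores\n    restored = WarmupLambda(warmup=1)\n    restored.__dict__.update(saved)     # what load_state_dict replays\n    assert restored(9) == fn(9)\n    print(\"restored warmup:\", restored.warmup)\n"
  },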
  {
    "path": "python/oneflow/nn/optimizer/lbfgs.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Callable, Dict, Iterator, List, Tuple, Union\nfrom functools import reduce\nfrom oneflow.optim.optimizer import Optimizer\nfrom oneflow.nn.parameter import Parameter\nimport oneflow as flow\n\n# TODO implement quadrati_interpolate op\ndef _quadratic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):\n\n    if bounds is not None:\n        xmin_bound, xmax_bound = bounds\n    else:\n        xmin_bound, xmax_bound = (x1, x2) if x1 < x2 else (x2, x1)\n    if x1 == 0:\n        t_new = -(g1 * (x2 ** 2)) / (2 * (f2 - f1 - g1 * x2))\n    else:\n        a = -(f1 - f2 - g1 * (x1 - x2)) / ((x1 - x2) ** 2)\n        t_new = x1 - g1 / (2 * a)\n    return min(xmax_bound, max(xmin_bound, t_new))\n\n\ndef _strong_wolfe(\n    eval_closure, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25\n):\n    d_norm = d.abs().max()\n    g = g.clone()\n    f_new, g_new = eval_closure(x, t, d)\n    ls_func_evals = 1\n    gtd_new = g_new.dot(d)\n\n    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd\n    done = False\n    ls_iter = 0\n    while ls_iter < max_ls:\n        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new > f_prev):\n            search_area = [t_prev, t]\n            search_area_f = [f_prev, f_new]\n            search_area_g = [g_prev, g_new.clone()]\n            search_area_gtd = [gtd_prev, gtd_new]\n            break\n\n        if abs(gtd_new) <= -c2 * gtd:\n            search_area = [t]\n            search_area_f = [f_new]\n            search_area_g = [g_new]\n            done = True\n            break\n\n        if gtd_new >= 0:\n            search_area = [t_prev, t]\n            search_area_f = [f_prev, f_new]\n            search_area_g = [g_prev, g_new.clone()]\n            search_area_gtd = [gtd_prev, gtd_new]\n\n        min_step = t + 0.01 * (t - t_prev)\n        max_step = t * 10\n        tmp = t\n        t = _quadratic_interpolate(\n            t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step)\n        )\n        t_prev = tmp\n        f_prev = f_new\n        g_prev = g_new.clone()\n        gtd_prev = gtd_new\n        f_new, g_new = eval_closure(x, t, d)\n        ls_func_evals += 1\n        gtd_new = g_new.dot(d)\n        ls_iter += 1\n    if ls_iter == max_ls:\n        search_area = [0, t]\n        search_area_f = [f, f_new]\n        search_area_g = [g, g_new]\n\n    # zoom\n    low_pos, high_pos = (0, 1) if search_area_f[0] <= search_area_f[-1] else (1, 0)\n    while not done and ls_iter < max_ls:\n\n        if abs(search_area[1] - search_area[0]) * d_norm < tolerance_change:\n            break\n\n        t = _quadratic_interpolate(\n            search_area[0],\n            search_area_f[0],\n            search_area_gtd[0],\n            search_area[1],\n            search_area_f[1],\n            search_area_gtd[1],\n        )\n\n        f_new, g_new = eval_closure(x, t, d)\n        ls_func_evals += 1\n        
gtd_new = g_new.dot(d)\n        ls_iter += 1\n\n        if f_new > (f + c1 * t * gtd) or f_new >= search_area_f[low_pos]:\n            # Armijo condition failed or the new point is not the lowest:\n            # shrink the bracket from the high side.\n            search_area[high_pos] = t\n            search_area_f[high_pos] = f_new\n            search_area_g[high_pos] = g_new.clone()\n            search_area_gtd[high_pos] = gtd_new\n            low_pos, high_pos = (\n                (0, 1) if search_area_f[0] <= search_area_f[1] else (1, 0)\n            )\n        else:\n            if abs(gtd_new) <= -c2 * gtd:\n                # the strong Wolfe conditions are satisfied\n                done = True\n            elif gtd_new * (search_area[high_pos] - search_area[low_pos]) >= 0:\n                # the old low point becomes the new high point\n                search_area[high_pos] = search_area[low_pos]\n                search_area_f[high_pos] = search_area_f[low_pos]\n                search_area_g[high_pos] = search_area_g[low_pos]\n                search_area_gtd[high_pos] = search_area_gtd[low_pos]\n\n            # the new point becomes the low end of the bracket\n            search_area[low_pos] = t\n            search_area_f[low_pos] = f_new\n            search_area_g[low_pos] = g_new.clone()\n            search_area_gtd[low_pos] = gtd_new\n\n    t = search_area[low_pos]\n    f_new = search_area_f[low_pos]\n    g_new = search_area_g[low_pos]\n    return f_new, g_new, t, ls_func_evals\n\n\nclass LBFGS(Optimizer):\n    \"\"\"Implements the LBFGS algorithm.\n    \n    It has been proposed in `On the limited memory BFGS method for large scale optimization`_.\n    The two-loop recursion follows `Updating Quasi-Newton Matrices with Limited Storage`_.\n    \n    The strong_wolfe line search follows `Numerical_Optimization_v2`.\n    \n    This algorithm uses an estimated inverse Hessian matrix to steer its search through variable space and determine the optimal direction.\n    \n    The line search algorithm terminates with a step length that satisfies the strong Wolfe conditions.\n    \n    This optimizer only supports one parameter group.\n    \n    Args:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-3)\n        max_iter (int,optional): max iteration per step (default: 20)\n        max_eval (int,optional): max func evals per step (default: max_iter * 1.25)\n        tolerance_grad (float, optional): termination tolerance on first order optimality (default: 1e-7)\n        tolerance_change (float, optional): termination tolerance on parameter changes (default: 1e-9)\n        history_size (int,optional): parameter update history size (default: 100)\n        line_search_fn (str,optional): line search function `strong_wolfe` or None (default: None)\n        contiguous_params (bool, optional): whether to use contiguous ParamGroup \n            which puts all parameters of the same type, device and group into the\n            same tensor and update them together. (default: False)\n    .. _On the limited memory BFGS method for large scale optimization:\n        https://dl.acm.org/doi/10.5555/3112655.3112866\n            \n    .. _Updating Quasi-Newton Matrices with Limited Storage:\n        https://www.ams.org/journals/mcom/1980-35-151/S0025-5718-1980-0572855-7/S0025-5718-1980-0572855-7.pdf\n    \n    For example: \n    \n    .. code-block:: python \n    \n        # Assume net is a custom model. \n        lbfgs = flow.optim.LBFGS(net.parameters())\n        \n        for epoch in range(epochs):\n            def closure():\n                lbfgs.zero_grad()\n                # Read data, Compute the loss and so on. 
\n                loss.backward()\n                return loss\n            lbfgs.step(closure)\n                \n    \n    \"\"\"\n\n    def __init__(\n        self,\n        params: Union[Iterator[Parameter], List[Dict]],\n        lr: float = 0.001,\n        max_iter: int = 20,\n        max_eval: int = None,\n        tolerance_grad: float = 1e-7,\n        tolerance_change: float = 1e-9,\n        history_size: int = 100,\n        line_search_fn=None,\n        contiguous_params: bool = False,\n    ):\n        if max_eval is None:\n            max_eval = max_iter * 1.25\n        options = dict()\n        options[\"lr\"] = lr\n        options[\"max_iter\"] = max_iter\n        options[\"max_eval\"] = max_eval\n        options[\"tolerance_grad\"] = tolerance_grad\n        options[\"tolerance_change\"] = tolerance_change\n        options[\"history_size\"] = history_size\n        options[\"line_search_fn\"] = line_search_fn\n        options[\"contiguous_params\"] = contiguous_params\n        super().__init__(params, options)\n        assert (\n            len(self.param_groups) == 1\n        ), \"LBFGS not support parameter groups (there can be only one)\"\n        param_group = self.param_groups[0]\n        if param_group[\"contiguous_params\"]:\n            param_list = param_group.contiguous_parameters\n        else:\n            param_list = param_group.parameters\n        for param in param_list:\n            assert param.is_leaf, \"parameters must be leaf tensor\"\n        self._params = param_list\n        self._numel_cache = None\n\n    def _gather_flat_grad(self):\n        views = []\n        for p in self._params:\n            if p.grad is None:\n                view = p.new(p.numel()).zero_()\n            else:\n                view = p.grad.view(-1)\n            views.append(view)\n        return flow.cat(views, 0)\n\n    def _numel(self):\n        # get parameters total numel\n        if self._numel_cache is None:\n            self._numel_cache = reduce(\n                lambda totnumel, p: totnumel + p.numel(), self._params, 0,\n            )\n        return self._numel_cache\n\n    def _update(self, step_size, direction):\n        # update parameters\n        offset = 0\n        for p in self._params:\n            numel = p.numel()\n            p.add_(direction[offset : offset + numel].view_as(p), alpha=step_size)\n            offset += numel\n        assert offset == self._numel()\n\n    def _try_direction(self, closure, x, t, d):\n        self._update(t, d)\n        with flow.enable_grad():\n            loss = float(closure())\n        flag_grad = self._gather_flat_grad()\n        for p, data in zip(self._params, x):\n            p.copy_(data)\n        return loss, flag_grad\n\n    def step(self, closure: Callable = None):\n        \"\"\"Performs a single optimization step.\n\n        Args:\n            closure (callable): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        with flow.no_grad():\n            assert closure != None, \"closure must not be None\"\n            param_group = self.param_groups[0]\n            lr = param_group[\"lr\"]\n            max_iter = param_group[\"max_iter\"]\n            max_eval = param_group[\"max_eval\"]\n            tolerance_grad = param_group[\"tolerance_grad\"]\n            tolerance_change = param_group[\"tolerance_change\"]\n            line_search_fn = param_group[\"line_search_fn\"]\n            history_size = param_group[\"history_size\"]\n\n            state = 
self.state[self._params[0]]\n            state.setdefault(\"func_evals\", 0)\n            state.setdefault(\"n_iter\", 0)\n            with flow.enable_grad():\n                origin_loss = closure()\n            loss = float(origin_loss)\n            current_evals = 1\n            state[\"func_evals\"] += 1\n\n            flat_grad = self._gather_flat_grad()\n            if flat_grad.abs().max() <= tolerance_grad:\n                return origin_loss\n\n            # prev state\n            d = state.get(\"d\")\n            t = state.get(\"t\")\n            old_diffs = state.get(\"old_diffs\")\n            old_step_size = state.get(\"old_step_size\")\n            ro = state.get(\"ro\")\n            H_diag = state.get(\"H_diag\")\n            prev_flat_grad = state.get(\"prev_flat_grad\")\n            prev_loss = state.get(\"prev_loss\")\n\n            n_iter = 0\n\n            while n_iter < max_iter:\n                n_iter += 1\n                state[\"n_iter\"] += 1\n\n                # compute direction\n                if state[\"n_iter\"] == 1:\n                    d = flat_grad.neg()\n                    old_diffs = []\n                    old_step_size = []\n                    ro = []\n                    H_diag = 1\n                else:\n                    y = flat_grad.sub(prev_flat_grad)\n                    s = d.mul(t)\n                    ys = y.dot(s)\n                    # ys must be positive\n                    if ys > 1e-10:\n                        if len(old_diffs) == history_size:\n                            old_diffs.pop(0)\n                            old_step_size.pop(0)\n                            ro.pop(0)\n                        old_diffs.append(y)\n                        old_step_size.append(s)\n                        ro.append(1.0 / ys)\n                        H_diag = ys / y.dot(y)\n\n                    num_old = len(old_diffs)\n\n                    if \"alpha\" not in state:\n                        state[\"alpha\"] = [None] * history_size\n                    alpha = state[\"alpha\"]\n\n                    q = flat_grad.neg()\n                    for i in range(num_old - 1, -1, -1):\n                        alpha[i] = old_step_size[i].dot(q) * ro[i]\n                        q.add_(old_diffs[i], alpha=-alpha[i])\n\n                    d = q.mul(H_diag)\n                    for i in range(num_old):\n                        beta_i = old_diffs[i].dot(d) * ro[i]\n                        d.add_(old_step_size[i], alpha=alpha[i] - beta_i)\n\n                # compute step size\n                if prev_flat_grad is None:\n                    prev_flat_grad = flat_grad.clone()\n                else:\n                    prev_flat_grad.copy_(flat_grad)\n\n                prev_loss = loss\n\n                if state[\"n_iter\"] == 1:\n                    t = min(1.0, 1.0 / flat_grad.abs().sum()) * lr\n                else:\n                    t = lr\n\n                gtd = flat_grad.dot(d)\n                if gtd > -tolerance_change:\n                    break\n\n                ls_func_evals = 0\n                if line_search_fn is None:\n                    self._update(t, d)\n                    if n_iter != max_iter:\n                        with flow.enable_grad():\n                            loss = float(closure())\n                        flat_grad = self._gather_flat_grad()\n                        ls_func_evals = 1\n                else:\n                    assert (\n                        line_search_fn == \"strong_wolfe\"\n                  
  ), \"only strong_wolfe is expected\"\n                    init_param = [p.clone() for p in self._params]\n\n                    def eval_func(x, t, d):\n                        return self._try_direction(closure, x, t, d)\n\n                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(\n                        eval_func, init_param, t, d, loss, flat_grad, gtd\n                    )\n                    self._update(t, d)\n\n                current_evals += ls_func_evals\n                state[\"func_evals\"] += ls_func_evals\n\n                if n_iter == max_iter:\n                    break\n\n                if current_evals >= max_eval:\n                    break\n\n                if flat_grad.abs().max() <= tolerance_grad:\n                    break\n\n                if d.mul(t).abs().max() <= tolerance_change:\n                    break\n\n                if abs(loss - prev_loss) < tolerance_change:\n                    break\n\n            state[\"d\"] = d\n            state[\"t\"] = t\n            state[\"old_diffs\"] = old_diffs\n            state[\"old_step_size\"] = old_step_size\n            state[\"ro\"] = ro\n            state[\"prev_flat_grad\"] = prev_flat_grad\n            state[\"prev_loss\"] = prev_loss\n            state[\"H_diag\"] = H_diag\n            return origin_loss\n"
  },
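# A minimal NumPy sketch (not part of OneFlow) of the two-loop recursion that
# LBFGS.step() above uses to build its search direction from the (s, y) history;
# the names grad/s_history/y_history are illustrative assumptions.
import numpy as np

def two_loop_direction(grad, s_history, y_history):
    """Return d = -H * grad, with H the implicit L-BFGS inverse-Hessian estimate."""
    rho = [1.0 / np.dot(y, s) for s, y in zip(s_history, y_history)]
    alpha = [0.0] * len(s_history)
    q = -grad.astype(float).copy()
    # First loop: newest to oldest, mirroring `for i in range(num_old - 1, -1, -1)`.
    for i in range(len(s_history) - 1, -1, -1):
        alpha[i] = rho[i] * np.dot(s_history[i], q)
        q -= alpha[i] * y_history[i]
    # Initial scaling H_diag = ys / yy from the most recent curvature pair.
    if s_history:
        s, y = s_history[-1], y_history[-1]
        q *= np.dot(y, s) / np.dot(y, y)
    # Second loop: oldest to newest.
    for i in range(len(s_history)):
        beta = rho[i] * np.dot(y_history[i], q)
        q += (alpha[i] - beta) * s_history[i]
    return q

# With an empty history this reduces to steepest descent, d = -grad:
print(two_loop_direction(np.array([1.0, -2.0]), [], []))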
  {
    "path": "python/oneflow/nn/optimizer/linear_lr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom ...optim.optimizer import Optimizer\nfrom .lr_scheduler import LRScheduler\n\n\nclass LinearLR(LRScheduler):\n    \"\"\"Decays the learning rate of each parameter group by linearly changing small\n    multiplicative factor until the number of step reaches a pre-defined milestone: total_iters.\n\n    Args:\n        optimizer (Optimizer): Wrapped optimizer.\n        start_factor (float): The number we multiply learning rate in the first step.\n            The multiplication factor changes towards end_factor in the following steps.\n            Default: 1./3.\n        end_factor (float): The number we multiply learning rate at the end of linear changing\n            process. Default: 1.0.\n        total_iters (int): The number of iterations that multiplicative factor reaches to 1.\n            Default: 5.\n        last_step (int): The index of the last step. Default: -1.\n        verbose (bool): If ``True``, prints a message to stdout for\n            each update. Default: ``False``.\n\n    Example:\n        >>> # Assuming optimizer uses lr = 0.05 for all groups\n        >>> # lr = 0.025    if step == 0\n        >>> # lr = 0.03125  if step == 1\n        >>> # lr = 0.0375   if step == 2\n        >>> # lr = 0.04375  if step == 3\n        >>> # lr = 0.05    if step >= 4\n        >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4)\n        >>> for step in range(100):\n        >>>     train(...)\n        >>>     validate(...)\n        >>>     scheduler.step()\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        start_factor: float = 1.0 / 3,\n        end_factor: float = 1.0,\n        total_iters: int = 5,\n        last_step: int = -1,\n        verbose: bool = False,\n    ):\n        assert isinstance(optimizer, Optimizer)\n\n        if start_factor > 1.0 or start_factor < 0:\n            raise ValueError(\n                \"Starting multiplicative factor expected to be between 0 and 1.\"\n            )\n\n        if end_factor > 1.0 or end_factor < 0:\n            raise ValueError(\n                \"Ending multiplicative factor expected to be between 0 and 1.\"\n            )\n\n        self.start_factor = start_factor\n        self.end_factor = end_factor\n        self.total_iters = total_iters\n        super().__init__(optimizer, last_step, verbose)\n\n    def get_lr(self, base_lr, step):\n        if step < self.total_iters:\n            multiplier = self.start_factor + (self.end_factor - self.start_factor) * (\n                step / self.total_iters\n            )\n        else:\n            multiplier = self.end_factor\n\n        return base_lr * multiplier\n\n    def _generate_conf_for_graph(self, lr_conf):\n        lr_conf.linear_lr_conf.SetInParent()\n        linear_lr_conf = lr_conf.linear_lr_conf\n        linear_lr_conf.start_factor = self.start_factor\n        linear_lr_conf.end_factor = self.end_factor\n        
linear_lr_conf.total_iters = self.total_iters\n"
  },
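# A plain-Python spot check (no OneFlow needed; values are from the LinearLR
# docstring example above: start_factor=0.5, total_iters=4, base lr 0.05) of the
# multiplier formula used in LinearLR.get_lr:
def linear_multiplier(step, start_factor=0.5, end_factor=1.0, total_iters=4):
    if step < total_iters:
        return start_factor + (end_factor - start_factor) * step / total_iters
    return end_factor

for step in range(6):
    print(step, 0.05 * linear_multiplier(step))
# prints 0.025, 0.03125, 0.0375, 0.04375, then 0.05 from step 4 onwards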
  {
    "path": "python/oneflow/nn/optimizer/lr_scheduler.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom ...optim.optimizer import Optimizer\n\n\nclass LRScheduler(object):\n    def __init__(\n        self, optimizer: Optimizer, last_step: int = -1, verbose: bool = False\n    ):\n        if not isinstance(optimizer, Optimizer):\n            raise TypeError(f\"{type(optimizer).__name__} is not an Optimizer object\")\n\n        self.optimizer = optimizer\n        self.last_step = last_step\n        self.verbose = verbose\n        self._init_base_lrs()\n        self.step()\n\n    def state_dict(self):\n        \"\"\"Return the state of the scheduler as a :class:`dict`.\n\n        It contains an entry for every variable in self.__dict__ which\n        is not the optimizer.\n        \"\"\"\n        return {\n            key: value for (key, value) in self.__dict__.items() if key != \"optimizer\"\n        }\n\n    def load_state_dict(self, state_dict):\n        \"\"\"Load the schedulers state.\n\n        Arguments:\n            state_dict (dict): scheduler state. Should be an object returned\n                from a call to :meth:`state_dict`.\n        \"\"\"\n        self.__dict__.update(state_dict)\n\n    def get_lr(self, base_lr, step):\n        \"\"\"Compute learning rate using chainable form of the scheduler\"\"\"\n        raise NotImplementedError\n\n    def get_last_lr(self):\n        \"\"\"Return last computed learning rate by current scheduler.\"\"\"\n        return self._last_lr\n\n    def print_lr(self, group, lr):\n        \"\"\"Display the current learning rate.\"\"\"\n        print(\n            f\"Last step {self.last_step} of {type(self)} adjusting learning rate \"\n            f\"of param_groups[{group}] to {lr:.5f}\"\n        )\n\n    def step(self):\n        self.last_step += 1\n        lrs = [self.get_lr(base_lr, self.last_step) for base_lr in self.base_lrs]\n        self.update_lrs(lrs)\n\n    def update_lrs(self, lrs):\n        self._last_lr = []\n        for i, (group, lr) in enumerate(zip(self.optimizer.param_groups, lrs)):\n            group[\"lr\"] = lr\n            self._last_lr.append(lr)\n            if self.verbose:\n                self.print_lr(i, lr)\n\n    def _init_base_lrs(self):\n        if self.last_step == -1:\n            for group in self.optimizer.param_groups:\n                if \"initial_lr\" not in group:\n                    group.setdefault(\"initial_lr\", group[\"lr\"])\n        else:\n            for (i, group) in enumerate(self.optimizer.param_groups):\n                if \"initial_lr\" not in group:\n                    raise KeyError(\n                        \"param 'initial_lr' is not specified \"\n                        f\"in param_groups[{i}] when resuming an optimizer\"\n                    )\n\n        self.base_lrs = [group[\"initial_lr\"] for group in self.optimizer.param_groups]\n"
  },
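# A minimal sketch of the subclassing contract of LRScheduler above: get_lr maps
# (base_lr, step) to the new lr, and __init__ must set its hyperparameters before
# calling super().__init__, which triggers the first step(). ToyExponentialLR is
# a hypothetical example class, not part of OneFlow; the import path is assumed
# from this repository layout.
from oneflow.nn.optimizer.lr_scheduler import LRScheduler

class ToyExponentialLR(LRScheduler):
    def __init__(self, optimizer, gamma: float, last_step: int = -1, verbose: bool = False):
        self.gamma = gamma  # set before super().__init__, which calls step()
        super().__init__(optimizer, last_step, verbose)

    def get_lr(self, base_lr, step):
        # A pure function of the initial lr and the step count, so the schedule
        # can be resumed via state_dict()/load_state_dict() from last_step alone.
        return base_lr * self.gamma ** step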
  {
    "path": "python/oneflow/nn/optimizer/multiplicative_lr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport math\n\nfrom ...optim.optimizer import Optimizer\nfrom .lr_scheduler import LRScheduler\n\n\nclass MultiplicativeLR(LRScheduler):\n    \"\"\"Multiply the learning rate of each parameter group by the factor given\n    in the specified function. When last_epoch=-1, sets initial lr as lr.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiplicativeLR\n\n    Args:\n        optimizer (Optimizer): Wrapped optimizer.\n        lr_lambda (function or list): A function which computes a multiplicative\n            factor given an integer parameter epoch, or a list of such\n            functions, one for each group in optimizer.param_groups.\n        last_step (int): The index of last step. Default: -1.\n        verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``.\n\n    For example:\n\n    .. code-block:: python\n\n        import oneflow as flow\n\n        ...\n        lmbda = lambda epoch: 0.95\n        step_lr = flow.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)\n        for epoch in range(num_epoch):\n            train(...)\n            step_lr.step()\n    \"\"\"\n\n    def __init__(self, optimizer, lr_lambda, last_step=-1, verbose=False):\n        self.optimizer = optimizer\n\n        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):\n            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)\n        else:\n            if len(lr_lambda) != len(optimizer.param_groups):\n                raise ValueError(\n                    \"Expected {} lr_lambdas, but got {}\".format(\n                        len(optimizer.param_groups), len(lr_lambda)\n                    )\n                )\n            self.lr_lambdas = list(lr_lambda)\n        super().__init__(optimizer, last_step, verbose)\n\n    def state_dict(self):\n        \"\"\"Returns the state of the scheduler as a :class:`dict`.\n\n        It contains an entry for every variable in self.__dict__ which\n        is not the optimizer.\n        The learning rate lambda functions will only be saved if they are callable objects\n        and not if they are functions or lambdas.\n        \"\"\"\n        state_dict = {\n            key: value\n            for key, value in self.__dict__.items()\n            if key not in (\"optimizer\", \"lr_lambdas\")\n        }\n        state_dict[\"lr_lambdas\"] = [None] * len(self.lr_lambdas)\n\n        for idx, fn in enumerate(self.lr_lambdas):\n            if not isinstance(fn, types.FunctionType):\n                state_dict[\"lr_lambdas\"][idx] = fn.__dict__.copy()\n\n        return state_dict\n\n    def load_state_dict(self, state_dict):\n        \"\"\"Loads the schedulers state.\n\n        Args:\n            state_dict (dict): scheduler state. 
Should be an object returned\n                from a call to :meth:`state_dict`.\n        \"\"\"\n        lr_lambdas = state_dict.pop(\"lr_lambdas\")\n        self.__dict__.update(state_dict)\n        state_dict[\"lr_lambdas\"] = lr_lambdas\n\n        for idx, fn in enumerate(lr_lambdas):\n            if fn is not None:\n                self.lr_lambdas[idx].__dict__.update(fn)\n\n    def step(self):\n        \"\"\"Performs a single learning rate schedule step.\n\n        \"\"\"\n        self.last_step += 1\n        if self.last_step > 0:\n            lrs = [\n                group[\"lr\"] * lmbda(self.last_step)\n                for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)\n            ]\n        else:\n            lrs = [group[\"lr\"] for group in self.optimizer.param_groups]\n        self.update_lrs(lrs)\n"
  },
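# Unlike the get_lr-based schedulers above, MultiplicativeLR.step() multiplies
# the *current* lr in place by lmbda(step), so a constant lambda yields geometric
# decay. A plain-Python sketch of that cumulative behaviour (values illustrative):
lr = 0.1
lmbda = lambda step: 0.95
for step in range(1, 4):
    lr *= lmbda(step)
    print(step, lr)  # ~0.095, ~0.09025, ~0.0857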
  {
    "path": "python/oneflow/nn/optimizer/multistep_lr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport bisect\n\nfrom ...optim.optimizer import Optimizer\nfrom .lr_scheduler import LRScheduler\n\n\nclass MultiStepLR(LRScheduler):\n    \"\"\"\n    Decays the learning rate of each parameter group by gamma once the number of step\n    reaches one of the milestones. Notice that such decay can happen simultaneously with\n    other changes to the learning rate from outside this scheduler.When last_step=-1, sets initial lr as lr.\n\n    Args:\n        optimizer(Optimizer): Wrapped optimizer.\n        milestones(list): List of step indices. Must be increasing\n        gamma (float, optional): Multiplicative factor of learning rate decay. (default: 0.1)\n        last_step (int, optional): The index of last step. (default: -1)\n        verbose (bool, optional): If ``True``, prints a message to stdout for each update. (default: ``False``)\n\n    For example:\n\n    .. code-block:: python\n\n        import oneflow as flow\n\n        ...\n        multistep_lr = flow.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)\n        for epoch in range(num_epoch):\n            train(...)\n            multistep_lr.step()\n\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        milestones: list,\n        gamma: float = 0.1,\n        last_step: int = -1,\n        verbose: bool = False,\n    ):\n        for i in range(1, len(milestones)):\n            assert (\n                milestones[i] > milestones[i - 1]\n            ), f\"values in `list` milestone must be increasing, but got {milestones}\"\n        assert gamma > 0.0, f\"gamma must greater than zero, but got {gamma}\"\n        self.milestones = milestones\n        self.gamma = gamma\n        super().__init__(optimizer, last_step, verbose)\n\n    def get_lr(self, base_lr, step):\n        sect = bisect.bisect_right(self.milestones, step)\n        factor = self.gamma ** sect\n        return base_lr * factor\n\n    def _generate_conf_for_graph(self, lr_conf):\n        lr_conf.multi_step_conf.SetInParent()\n        multi_step_conf = lr_conf.multi_step_conf\n        for milestone in self.milestones:\n            multi_step_conf.milestones.append(milestone)\n        multi_step_conf.gamma = self.gamma\n"
  },
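# MultiStepLR.get_lr above counts, via bisect_right, how many milestones the
# current step has passed and applies gamma once per passed milestone. A small
# standalone check of that indexing (milestones and base_lr are illustrative):
import bisect

milestones, gamma, base_lr = [30, 80], 0.1, 0.05
for step in (0, 29, 30, 79, 80, 100):
    factor = gamma ** bisect.bisect_right(milestones, step)
    print(step, round(base_lr * factor, 6))
# steps 0-29 -> 0.05, steps 30-79 -> 0.005, steps >= 80 -> 0.0005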
  {
    "path": "python/oneflow/nn/optimizer/polynomial_lr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport math\n\nfrom .lr_scheduler import LRScheduler\n\n\nclass PolynomialLR(LRScheduler):\n    r\"\"\"\n    This operator creates a polynomial decayed learning rate scheduler.\n    The learning rate will be updated as follows:\n\n    If cycle is `True`, the equation is:\n\n    .. math::\n        \\begin{aligned}\n           & decay\\_batch = decay\\_batch*ceil(\\frac{current\\_batch}{decay\\_batch}) \\\\\n           & learning\\_rate = (base\\_lr-end\\_lr)*(1-\\frac{current\\_batch}{decay\\_batch})^{power}+end\\_lr\n        \\end{aligned}\n\n    If cycle is `False`, the equation is:\n\n    .. math::\n        \\begin{aligned}\n           & current\\_batch = min(decay\\_batch, current\\_batch) \\\\\n           & learning\\_rate = (base\\_lr-end\\_lr)*(1-\\frac{current\\_batch}{decay\\_batch})^{power}+end\\_lr\n        \\end{aligned}\n\n    Args:\n        optimizer (Optimizer): Wrapper optimizer.\n        decay_batch (int): The decayed steps.\n        end_learning_rate (float, optional): The final learning rate. Defaults to 0.0001.\n        power (float, optional): The power of polynomial. Defaults to 1.0.\n        cycle (bool, optional): If cycle is True, the scheduler will decay the learning rate every decay steps. Defaults to False.\n\n    For example:\n\n    .. code-block:: python\n\n        import oneflow as flow\n       \n        ... 
\n        polynomial_scheduler = flow.optim.lr_scheduler.PolynomialLR(\n            optimizer, decay_batch=5, end_learning_rate=0.00001, power=2\n            )\n\n        for epoch in range(num_epoch):\n            train(...)\n            polynomial_scheduler.step()\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer,\n        decay_batch: int,\n        end_learning_rate: float = 0.0001,\n        power: float = 1.0,\n        cycle: bool = False,\n        last_step: int = -1,\n        verbose: bool = False,\n    ):\n        assert (\n            decay_batch > 0\n        ), f\"decay_batch must greater than zero, but got {decay_batch}\"\n        self.max_decay_steps = decay_batch\n        self.end_learning_rate = end_learning_rate\n        self.power = power\n        self.cycle = cycle\n        super().__init__(optimizer, last_step, verbose)\n\n    def get_lr(self, base_lr, step):\n        decay_batch = self.max_decay_steps\n        cur_batch = step\n        if self.cycle:\n            if cur_batch == 0:\n                cur_batch = 1\n            decay_batch = decay_batch * math.ceil(cur_batch / decay_batch)\n        else:\n            cur_batch = min(cur_batch, decay_batch)\n\n        factor = (1 - cur_batch / decay_batch) ** (self.power)\n        return (base_lr - self.end_learning_rate) * factor + self.end_learning_rate\n\n    def _generate_conf_for_graph(self, lr_conf):\n        lr_conf.polynomial_conf.SetInParent()\n        polynomial_conf = lr_conf.polynomial_conf\n        polynomial_conf.decay_batches = self.max_decay_steps\n        polynomial_conf.end_learning_rate = self.end_learning_rate\n        polynomial_conf.power = self.power\n        polynomial_conf.cycle = self.cycle\n"
  },
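# A plain-Python sketch of PolynomialLR.get_lr above, showing how `cycle`
# stretches decay_batch to the next multiple instead of clamping the step
# (parameter values are illustrative):
import math

def poly_lr(base_lr, step, decay_batch, end_lr=0.00001, power=2.0, cycle=False):
    if cycle:
        step = max(step, 1)
        decay_batch = decay_batch * math.ceil(step / decay_batch)
    else:
        step = min(step, decay_batch)
    factor = (1 - step / decay_batch) ** power
    return (base_lr - end_lr) * factor + end_lr

print(poly_lr(0.1, 3, decay_batch=5))              # mid-decay
print(poly_lr(0.1, 7, decay_batch=5))              # clamped: stays at end_lr
print(poly_lr(0.1, 7, decay_batch=5, cycle=True))  # decays within a stretched cycle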
  {
    "path": "python/oneflow/nn/optimizer/reduce_lr_on_plateau.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom math import inf\nfrom ...optim.optimizer import Optimizer\n\n\nclass ReduceLROnPlateau(object):\n    \"\"\"Reduce learning rate when a metric has stopped improving.\n    Models often benefit from reducing the learning rate by a factor\n    of 2-10 once learning stagnates. This scheduler reads a metrics\n    quantity and if no improvement is seen for a 'patience' number\n    of epochs, the learning rate is reduced.\n\n    Args:\n        optimizer (Optimizer): Wrapped optimizer.\n        mode (str): One of `min`, `max`. In `min` mode, lr will\n            be reduced when the quantity monitored has stopped\n            decreasing; in `max` mode it will be reduced when the\n            quantity monitored has stopped increasing. Default: 'min'.\n        factor (float): Factor by which the learning rate will be\n            reduced. new_lr = lr * factor. Default: 0.1.\n        patience (int): Number of epochs with no improvement after\n            which learning rate will be reduced. For example, if\n            `patience = 2`, then we will ignore the first 2 epochs\n            with no improvement, and will only decrease the LR after the\n            3rd epoch if the loss still hasn't improved then.\n            Default: 10.\n        threshold (float): Threshold for measuring the new optimum,\n            to only focus on significant changes. Default: 1e-4.\n        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,\n            dynamic_threshold = best * ( 1 + threshold ) in 'max'\n            mode or best * ( 1 - threshold ) in `min` mode.\n            In `abs` mode, dynamic_threshold = best + threshold in\n            `max` mode or best - threshold in `min` mode. Default: 'rel'.\n        cooldown (int): Number of epochs to wait before resuming\n            normal operation after lr has been reduced. Default: 0.\n        min_lr (float or list): A scalar or a list of scalars. A\n            lower bound on the learning rate of all param groups\n            or each group respectively. Default: 0.\n        eps (float): Minimal decay applied to lr. If the difference\n            between new and old lr is smaller than eps, the update is\n            ignored. Default: 1e-8.\n        verbose (bool): If ``True``, prints a message to stdout for\n            each update. Default: ``False``.\n\n    For example:\n    \n    .. 
code-block:: python\n\n        optimizer = flow.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n        scheduler = flow.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')\n        for epoch in range(10):\n            train(...)\n            val_loss = validate(...)\n            # Note that step should be called after validate()\n            scheduler.step(val_loss)\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer,\n        mode=\"min\",\n        factor=0.1,\n        patience=10,\n        threshold=1e-4,\n        threshold_mode=\"rel\",\n        cooldown=0,\n        min_lr=0,\n        eps=1e-8,\n        verbose=False,\n    ):\n\n        if factor >= 1.0:\n            raise ValueError(\"Factor should be < 1.0.\")\n        self.factor = factor\n\n        # Attach optimizer\n        if not isinstance(optimizer, Optimizer):\n            raise TypeError(\"{} is not an Optimizer\".format(type(optimizer).__name__))\n        self.optimizer = optimizer\n\n        if isinstance(min_lr, list) or isinstance(min_lr, tuple):\n            if len(min_lr) != len(optimizer.param_groups):\n                raise ValueError(\n                    \"expected {} min_lrs, got {}\".format(\n                        len(optimizer.param_groups), len(min_lr)\n                    )\n                )\n            self.min_lrs = list(min_lr)\n        else:\n            self.min_lrs = [min_lr] * len(optimizer.param_groups)\n\n        self.patience = patience\n        self.verbose = verbose\n        self.cooldown = cooldown\n        self.cooldown_counter = 0\n        self.mode = mode\n        self.threshold = threshold\n        self.threshold_mode = threshold_mode\n        self.best = None\n        self.num_bad_steps = None\n        self.mode_worse = None  # the worse value for the chosen mode\n        self.eps = eps\n        self.last_step = 0\n        self._init_is_better(\n            mode=mode, threshold=threshold, threshold_mode=threshold_mode\n        )\n        self._reset()\n\n    def step(self, metrics):\n        \"\"\"Performs a single learning rate schedule step.\n\n        Arguments:\n            metrics (float): a metric quantity measuring the effect of model training.\n        \"\"\"\n        # convert `metrics` to float, in case it's a zero-dim Tensor\n        current = float(metrics)\n        self.last_step = self.last_step + 1\n\n        if self.is_better(current, self.best):\n            self.best = current\n            self.num_bad_steps = 0\n        else:\n            self.num_bad_steps += 1\n\n        if self.in_cooldown:\n            self.cooldown_counter -= 1\n            self.num_bad_steps = 0  # ignore any bad epochs in cooldown\n\n        if self.num_bad_steps > self.patience:\n            self._reduce_lr(self.last_step)\n            self.cooldown_counter = self.cooldown\n            self.num_bad_steps = 0\n\n        self._last_lr = [group[\"lr\"] for group in self.optimizer.param_groups]\n\n    @property\n    def in_cooldown(self):\n        \"\"\"Whether the learning rate scheduler is in cooldown phase. \n\n        \"\"\"\n        return self.cooldown_counter > 0\n\n    def is_better(self, a, best):\n        \"\"\"Whether the metric has improved. 
\n        \n        \"\"\"\n        if self.mode == \"min\" and self.threshold_mode == \"rel\":\n            rel_epsilon = 1.0 - self.threshold\n            return a < best * rel_epsilon\n\n        elif self.mode == \"min\" and self.threshold_mode == \"abs\":\n            return a < best - self.threshold\n\n        elif self.mode == \"max\" and self.threshold_mode == \"rel\":\n            rel_epsilon = self.threshold + 1.0\n            return a > best * rel_epsilon\n\n        else:  # mode == 'max' and epsilon_mode == 'abs':\n            return a > best + self.threshold\n\n    def state_dict(self):\n        \"\"\"Returns the state of the scheduler as a :class:`dict`.\n\n        It contains an entry for every variable in self.__dict__ which\n        is not the optimizer.\n        \"\"\"\n        return {\n            key: value for key, value in self.__dict__.items() if key != \"optimizer\"\n        }\n\n    def load_state_dict(self, state_dict):\n        \"\"\"Loads the schedulers state.\n\n        Arguments:\n            state_dict (dict): scheduler state. Should be an object returned\n                from a call to :meth:`state_dict`.\n        \"\"\"\n        self.__dict__.update(state_dict)\n        self._init_is_better(\n            mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode\n        )\n\n    def _reduce_lr(self, epoch):\n        for i, param_group in enumerate(self.optimizer.param_groups):\n            old_lr = float(param_group[\"lr\"])\n            new_lr = max(old_lr * self.factor, self.min_lrs[i])\n            if old_lr - new_lr > self.eps:\n                param_group[\"lr\"] = new_lr\n                if self.verbose:\n                    print(\n                        \"Epoch {:5d}: reducing learning rate\"\n                        \" of group {} to {:.4e}.\".format(epoch, i, new_lr)\n                    )\n\n    def _reset(self):\n        \"\"\"Resets num_bad_steps counter and cooldown counter.\"\"\"\n        self.best = self.mode_worse\n        self.cooldown_counter = 0\n        self.num_bad_steps = 0\n\n    def _init_is_better(self, mode, threshold, threshold_mode):\n        if mode not in {\"min\", \"max\"}:\n            raise ValueError(\"mode \" + mode + \" is unknown!\")\n        if threshold_mode not in {\"rel\", \"abs\"}:\n            raise ValueError(\"threshold mode \" + threshold_mode + \" is unknown!\")\n\n        if mode == \"min\":\n            self.mode_worse = inf\n        else:  # mode == 'max':\n            self.mode_worse = -inf\n\n        self.mode = mode\n        self.threshold = threshold\n        self.threshold_mode = threshold_mode\n"
  },
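# The heart of ReduceLROnPlateau above is the is_better/patience bookkeeping.
# A compressed sketch of the default mode='min', threshold_mode='rel' path with
# an illustrative loss trace that stagnates (factor=0.1, patience=2):
losses = [1.0, 0.5, 0.5, 0.5, 0.5, 0.5]
best, num_bad, patience, threshold, lr = float("inf"), 0, 2, 1e-4, 0.1
for step, cur in enumerate(losses):
    if cur < best * (1.0 - threshold):  # 'rel' improvement test
        best, num_bad = cur, 0
    else:
        num_bad += 1
    if num_bad > patience:              # reduce and reset, as in step()
        lr *= 0.1
        num_bad = 0
    print(step, cur, lr)                # lr drops to 0.01 once patience runs out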
  {
    "path": "python/oneflow/nn/optimizer/rmsprop.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport collections\nfrom typing import Callable, Dict, Iterator, List, Union\n\nimport oneflow as flow\nfrom oneflow.optim.optimizer import Optimizer, ParamGroup\nfrom oneflow.nn.parameter import Parameter\n\n\nclass RMSprop(Optimizer):\n    \"\"\"Implements RMSprop algorithm.\n\n    oot Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning\n    rate method. The original slides proposed RMSProp: Slide 29 of\n    http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf .\n\n    The original equation is as follows:\n\n    .. math::\n\n        r(w, t) = \\\\alpha r(w, t-1) + (1 - \\\\alpha)(\\\\nabla Q_{i}(w))^2\n\n        W = w - \\\\frac{\\\\eta} {\\\\\\\\sqrt{r(w,t) + \\\\epsilon}} \\\\nabla Q_{i}(w)\n\n    The first equation calculates moving average of the squared gradient for\n    each weight. Then dividing the gradient by :math:`sqrt{v(w,t)}`.\n    In some cases, adding a momentum term :math: `\\\\beta` is beneficial.\n    In our implementation, Nesterov momentum is used:\n\n    .. math::\n\n        r(w, t) = \\\\alpha r(w, t-1) + (1 - \\\\alpha)(\\\\nabla Q_{i}(w))^2\n\n        v(w, t) = \\\\beta v(w, t-1) + \\\\frac{\\\\eta} {\\\\\\\\sqrt{r(w,t) +\n            \\\\epsilon}} \\\\nabla Q_{i}(w)\n\n        w = w - v(w, t)\n\n    if centered is True:\n\n    .. math::\n\n        r(w, t) = \\\\alpha r(w, t-1) + (1 - \\\\alpha)(\\\\nabla Q_{i}(w))^2\n\n        g(w, t) = \\\\alpha g(w, t-1) + (1 - \\\\alpha)\\\\nabla Q_{i}(w)\n\n        v(w, t) = \\\\beta v(w, t-1) + \\\\frac{\\\\eta} {\\\\\\\\sqrt{r(w,t) - (g(w, t))^2 +\n            \\\\epsilon}} \\\\nabla Q_{i}(w)\n\n        w = w - v(w, t)\n\n    where, :math:`\\\\alpha` is a hyperparameter and typical values are 0.99, 0.95\n    and so on. :math:`\\\\beta` is the momentum term. :math:`\\\\epsilon` is a\n    smoothing term to avoid division by zero, usually set somewhere in range\n    from 1e-4 to 1e-8.\n\n    Args:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-2)\n        momentum (float, optional): momentum factor (default: 0, oneflow not support momenmtum > 0 now!)\n        alpha (float, optional): smoothing constant (default: 0.99)\n        eps (float, optional): term added to the denominator to improve\n            numerical stability (default: 1e-8)\n        centered (bool, optional) : if ``True``, compute the centered RMSProp,\n            the gradient is normalized by an estimation of its variance\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        contiguous_params (bool, optional): whether to use contiguous ParamGroup \n            which puts all parameters of the same type, device and group into the\n            same tensor and update them together. (default: False)\n\n    For example: \n\n    Example 1: \n\n    .. 
code-block:: python \n\n        # Assume net is a custom model. \n        rmsprop = flow.optim.RMSprop(net.parameters(), lr=1e-3)\n\n        for epoch in range(epochs):\n            # Read data, Compute the loss and so on. \n            # ...\n            loss.backward()\n            rmsprop.step()\n            rmsprop.zero_grad()\n\n    Example 2: \n\n    .. code-block:: python \n\n        # Assume net is a custom model. \n        rmsprop = flow.optim.RMSprop(\n            [\n                {\n                    \"params\": net.parameters(),\n                    \"lr\": learning_rate,\n                    \"clip_grad_max_norm\": 0.5,\n                    \"clip_grad_norm_type\": 2.0,\n                }\n            ],\n        )\n\n        for epoch in range(epochs):\n            # Read data, Compute the loss and so on. \n            # ...\n            loss.backward()\n            rmsprop.clip_grad()\n            rmsprop.step()\n            rmsprop.zero_grad()\n\n    If you want to use clip_grad, you can refer to this example. \n\n    For more details of `clip_grad_max_norm` and `clip_grad_norm_type`, you can refer to :func:`oneflow.nn.utils.clip_grad_norm_`. \n\n    \"\"\"\n\n    def __init__(\n        self,\n        params: Union[Iterator[Parameter], List[Dict]],\n        lr: float = 0.001,\n        alpha: float = 0.99,\n        eps: float = 1e-08,\n        weight_decay: float = 0,\n        momentum: float = 0.0,\n        centered: bool = False,\n        contiguous_params: bool = False,\n    ):\n        assert lr >= 0.0, f\"Invalid learning rate: {lr}\"\n        assert alpha >= 0.0, f\"Invalid alpha value: {alpha}\"\n        assert eps >= 0.0, f\"Invalid epsilon value: {eps}\"\n        assert weight_decay >= 0.0, f\"Invalid weight_decay value: {weight_decay}\"\n        assert momentum == 0.0, \"momentum greater than zero is not supported yet!\"\n        options = dict()\n        options[\"lr\"] = lr\n        options[\"alpha\"] = alpha\n        options[\"eps\"] = eps\n        options[\"weight_decay\"] = weight_decay\n        options[\"centered\"] = centered\n        options[\"contiguous_params\"] = contiguous_params\n        super().__init__(params, options)\n\n        for param_group in self.param_groups:\n            if param_group[\"contiguous_params\"]:\n                param_list = param_group.contiguous_parameters\n            else:\n                param_list = param_group.parameters\n\n            for param in param_list:\n                assert param.is_leaf, \"parameters must be leaf tensor\"\n                self.state[param] = dict()\n\n        self._centered_rmsprop = (\n            flow.stateful_op(\"rmsprop_update\")\n            .Input(\"model\")\n            .Input(\"model_diff\")\n            .Input(\"mean_square\")\n            .Input(\"mean_gradient\")\n            .Build()\n        )\n        self._rmsprop = (\n            flow.stateful_op(\"rmsprop_update\")\n            .Input(\"model\")\n            .Input(\"model_diff\")\n            .Input(\"mean_square\")\n            .Build()\n        )\n\n    def step(self, closure: Callable = None):\n        \"\"\"Performs a single optimization step.\n\n        Args:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        with flow.no_grad():\n            loss = None\n            if closure is not None:\n                with flow.enable_grad():\n                    loss = closure()\n\n            for param_group in self.param_groups:\n         
       kwargs = {\n                    \"learning_rate\": param_group[\"lr\"],\n                    \"epsilon\": param_group[\"eps\"],\n                    \"decay_rate\": param_group[\"alpha\"],\n                    \"l2\": param_group[\"weight_decay\"],\n                }\n\n                if param_group[\"contiguous_params\"]:\n                    param_list = param_group.contiguous_parameters\n                else:\n                    param_list = param_group.parameters\n\n                for param in param_list:\n                    if param.grad is None:\n                        continue\n\n                    if \"square_avg\" not in self.state[param]:\n                        self.state[param][\"square_avg\"] = flow.zeros_like(param)\n                    ms_tensor = self.state[param][\"square_avg\"]\n\n                    if param_group[\"centered\"]:\n                        if \"grad_avg\" not in self.state[param]:\n                            self.state[param][\"grad_avg\"] = flow.zeros_like(param)\n                        mg_tensor = self.state[param][\"grad_avg\"]\n                        flow._C.dispatch_rmsprop_update(\n                            self._centered_rmsprop,\n                            (param, param.grad, ms_tensor, mg_tensor),\n                            centered=True,\n                            **kwargs,\n                        )\n                    else:\n                        flow._C.dispatch_rmsprop_update(\n                            self._rmsprop, (param, param.grad, ms_tensor), **kwargs\n                        )\n            self.state[\"step\"] = self.state[\"step\"] + 1\n            return loss\n\n    def _generate_conf_for_graph(self, train_conf, vars_conf):\n        new_opt_confs = []\n        for param_group in self.param_groups:\n            assert (\n                param_group[\"contiguous_params\"] != True\n            ), \"contiguous_params cannot be used in graph\"\n\n            optimizer_conf = train_conf.optimizer_conf.add()\n\n            lr = (\n                param_group[\"initial_lr\"]\n                if \"initial_lr\" in param_group\n                else param_group[\"lr\"]\n            )\n            decay_rate = param_group[\"alpha\"]\n            centered = param_group[\"centered\"]\n            weight_decay = param_group[\"weight_decay\"]\n\n            epsilon = param_group[\"eps\"]\n\n            optimizer_conf.base_learning_rate = lr\n            self._generate_lr_scale_for_optim_conf(param_group, optimizer_conf)\n\n            optimizer_conf.rmsprop_conf.decay_rate = decay_rate\n            optimizer_conf.rmsprop_conf.centered = centered\n            optimizer_conf.rmsprop_conf.epsilon = epsilon\n\n            self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf)\n\n            # Set l2 penalty as weight decay\n            for param in param_group.parameters:\n                vars_conf[param].l2 = weight_decay\n                if param.requires_grad:\n                    optimizer_conf.variable_op_names.append(vars_conf[param].name)\n\n            new_opt_confs.append(optimizer_conf)\n        return new_opt_confs\n"
  },
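# A minimal NumPy sketch (illustrative, not the OneFlow kernel) of the
# (optionally centered) RMSProp update that dispatch_rmsprop_update applies
# above, following the docstring equations; momentum is omitted because
# momentum > 0 is rejected in __init__:
import numpy as np

def rmsprop_step(w, grad, r, g_avg, lr=1e-3, alpha=0.99, eps=1e-8, centered=False):
    r[:] = alpha * r + (1 - alpha) * grad ** 2           # moving avg of squared grad
    if centered:
        g_avg[:] = alpha * g_avg + (1 - alpha) * grad    # moving avg of grad
        denom = np.sqrt(r - g_avg ** 2 + eps)            # variance estimate
    else:
        denom = np.sqrt(r + eps)
    w -= lr * grad / denom
    return w

w = np.array([1.0, -2.0])
r, g_avg = np.zeros_like(w), np.zeros_like(w)
print(rmsprop_step(w, np.array([0.1, -0.4]), r, g_avg, centered=True))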
  {
    "path": "python/oneflow/nn/optimizer/sequential_lr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport bisect\nfrom typing import Sequence, Union\nfrom ...optim.optimizer import Optimizer\nfrom .lr_scheduler import LRScheduler\n\n\nclass SequentialLR(LRScheduler):\n    \"\"\"Receives the list of schedulers that is expected to be called sequentially during\n    optimization process and milestone points that provides exact intervals to reflect\n    which scheduler is supposed to be called at a given step.\n\n    Args:\n        optimizer (Optimizer): Wrapped optimizer.\n        schedulers (list): List of chained schedulers.\n        milestones (list): List of integers that reflects milestone points.\n        interval_rescaling (bool or list): Each scheduler has a corresponding 'interval_rescaling'.\n            If it is set to True, scheduler will start and end at the same values as it would\n            if it were the only scheduler, otherwise all schedulers share the same step.\n            Default is False for all schedulers.\n        last_step (int): The index of last step. Default: -1.\n        verbose (bool): Default: False. Print lr if is set to True.\n\n    Example:\n        >>> # Assuming optimizer uses lr = 1. for all groups\n        >>> # lr = 0.1     if step == 0\n        >>> # lr = 0.1     if step == 1\n        >>> # lr = 0.9     if step == 2\n        >>> # lr = 0.81    if step == 3\n        >>> # lr = 0.729   if step == 4\n        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)\n        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)\n        >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])\n        >>> for step in range(100):\n        >>>     train(...)\n        >>>     validate(...)\n        >>>     scheduler.step()\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        schedulers: Sequence[LRScheduler],\n        milestones: Sequence[int],\n        interval_rescaling: Union[Sequence[bool], bool] = False,\n        last_step: int = -1,\n        verbose: bool = False,\n    ):\n        assert isinstance(optimizer, Optimizer)\n        assert isinstance(schedulers, (list, tuple))\n        assert isinstance(milestones, (list, tuple))\n\n        if len(schedulers) == 0:\n            raise ValueError(\"Sequential Schedulers expects at least one scheduler\")\n\n        for i in range(len(schedulers)):\n            if schedulers[i].optimizer != optimizer:\n                raise ValueError(\n                    \"Sequential Schedulers expects all schedulers to belong to the same optimizer, but \"\n                    f\"got schedulers at index {i} to be different than the optimizer passed in.\"\n                )\n\n        if len(milestones) != len(schedulers) - 1:\n            raise ValueError(\n                f\"Sequential Schedulers expects number of schedulers provided to be one more \"\n                f\"than the number of milestone points, but got number of 
schedulers {len(schedulers)} \"\n                f\"and number of milestones {len(milestones)}\"\n            )\n\n        if isinstance(interval_rescaling, (list, tuple)):\n            if len(interval_rescaling) != len(milestones):\n                raise ValueError(\n                    \"'interval_rescaling' expects a bool or a list of bool with length equal to \"\n                    f\"the number of milestones, but got number of milestones {len(milestones)} \"\n                    f\"and the length of list of interval_rescaling {len(interval_rescaling)}\"\n                )\n\n            assert all([isinstance(r, bool) for r in interval_rescaling])\n        else:\n            assert isinstance(interval_rescaling, bool)\n            interval_rescaling = [interval_rescaling] * (len(milestones))\n\n        self.schedulers = list(schedulers)\n        self.milestones = list(milestones)\n        self.interval_rescaling = list(interval_rescaling)\n        super().__init__(optimizer, last_step, verbose)\n\n    def step(self):\n        self.last_step += 1\n        cur_step = self.last_step\n        s_i = bisect.bisect_right(self.milestones, cur_step)\n        if s_i > 0 and self.interval_rescaling[s_i - 1]:\n            cur_step = self.last_step - self.milestones[s_i - 1]\n\n        scheduler = self.schedulers[s_i]\n        scheduler.last_step = cur_step\n        lrs = [scheduler.get_lr(base_lr, cur_step) for base_lr in self.base_lrs]\n        self.update_lrs(lrs)\n\n    def state_dict(self):\n        # exclude optimizer and nested schedulers\n        state_dict = {\n            key: value\n            for key, value in self.__dict__.items()\n            if key not in (\"optimizer\", \"schedulers\")\n        }\n        state_dict[\"schedulers\"] = [None] * len(self.schedulers)\n        for i, s in enumerate(self.schedulers):\n            state_dict[\"schedulers\"][i] = s.state_dict()\n\n        return state_dict\n\n    def load_state_dict(self, state_dict):\n        scheduler_states = state_dict.pop(\"schedulers\")\n        self.__dict__.update(state_dict)\n        # avoid side effect of calling load_state_dict twice\n        state_dict[\"schedulers\"] = scheduler_states\n\n        for i, s in enumerate(scheduler_states):\n            self.schedulers[i].load_state_dict(s)\n\n    def _generate_conf_for_graph(self, lr_conf):\n        lr_conf.sequential_scheduler_conf.SetInParent()\n        seq_lr_conf = lr_conf.sequential_scheduler_conf\n\n        for scheduler in self.schedulers:\n            scheduler._generate_conf_for_graph(seq_lr_conf.schedulers.add())\n\n        for m in self.milestones:\n            seq_lr_conf.milestones.append(m)\n\n        for r in self.interval_rescaling:\n            seq_lr_conf.interval_rescaling.append(r)\n"
  },
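# SequentialLR.step() above picks the active child scheduler with bisect_right
# over the milestones and optionally rebases the step when interval_rescaling
# is set. A plain-Python sketch of that dispatch, with child schedulers reduced
# to lambdas over the (possibly rebased) step:
import bisect

factors = [lambda step: 0.1, lambda step: 0.9 ** step]  # constant, then exponential
milestones = [2]
interval_rescaling = [True]  # restart the exponential's step count at the milestone

for last_step in range(5):
    i = bisect.bisect_right(milestones, last_step)
    step = last_step - milestones[i - 1] if i > 0 and interval_rescaling[i - 1] else last_step
    print(last_step, factors[i](step))
# steps 0-1 -> 0.1; from step 2 the exponential restarts at 0.9 ** 0 = 1.0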
  {
    "path": "python/oneflow/nn/optimizer/sgd.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport warnings\nfrom typing import Callable, Dict, Iterator, List, Union\n\nimport oneflow as flow\nfrom oneflow.nn.parameter import Parameter\n\nfrom ...optim.optimizer import Optimizer, ParamGroup\n\n\nclass SGD(Optimizer):\n    \"\"\"Implements SGD algorithm.\n\n    This algorithm takes a random sample's gradient as an approximate estimate of\n    the overall gradient in small batch gradient descent.\n\n    When the momentum = 0, the equation of parameters updating is:\n\n        .. math::\n\n            param_{new} = param_{old} - learning\\\\_rate * grad\n\n    With momentum, the equation of parameters updating is:\n\n        .. math::\n\n            & V_t = \\\\beta * V_{t-1} - learning\\\\_rate * (g_t + param_{old} * weight\\\\_decay)\n\n            & param_{new} = param_{old} + V_t\n\n    Args:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-3)\n        momentum (float, optional): Momentum factor (default: 0.0)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)\n        contiguous_params (bool, optional): whether to use contiguous ParamGroup \n            which puts all parameters of the same type, device and group into the\n            same tensor and update them together. (default: False)\n        fused (bool, optional): whether to divide all the parameters into several groups, then\n            update each group of parameters with the fused kernel. (default: False)\n\n    For example: \n\n    Example 1: \n\n    .. code-block:: python \n\n        # Assume net is a custom model. \n        sgd = flow.optim.SGD(net.parameters(), lr=1e-3)\n\n        for epoch in range(epochs):\n            # Read data, Compute the loss and so on. \n            # ...\n            loss.backward()\n            sgd.step()\n            sgd.zero_grad()\n\n    Example 2: \n\n    .. code-block:: python \n\n        # Assume net is a custom model. \n        sgd = flow.optim.SGD(\n            [\n                {\n                    \"params\": net.parameters(),\n                    \"lr\": learning_rate,\n                    \"clip_grad_max_norm\": 0.5,\n                    \"clip_grad_norm_type\": 2.0,\n                }\n            ],\n        )\n\n        for epoch in range(epochs):\n            # Read data, Compute the loss and so on. \n            # ...\n            loss.backward()\n            sgd.clip_grad()\n            sgd.step()\n            sgd.zero_grad()\n\n    If you want to use clip_grad, you can refer this example. \n\n    For more details of `clip_grad_max_norm` and `clip_grad_norm_type`, you can refer to :func:`oneflow.nn.utils.clip_grad_norm_`. 
\n\n    \"\"\"\n\n    def __init__(\n        self,\n        params: Union[Iterator[Parameter], List[Dict]],\n        lr: float = 0.001,\n        momentum: float = 0.0,\n        dampening: float = 0.0,\n        weight_decay: float = 0.0,\n        nesterov: bool = False,\n        maximize: bool = False,\n        contiguous_params: bool = False,\n        fused: bool = False,\n    ):\n        assert lr >= 0.0, f\"Invalid learning rate: {lr}\"\n        assert momentum >= 0.0, f\"Invalid momentum: {momentum}\"\n        assert weight_decay >= 0.0, f\"Invalid weight_decay: {weight_decay}\"\n        if maximize:\n            warnings.warn(\n                \"Only Momentum > 0.0, param `maximize` takes effect. \", FutureWarning,\n            )\n        options = dict()\n        options[\"lr\"] = lr\n        options[\"momentum\"] = momentum\n        options[\"dampening\"] = dampening\n        options[\"weight_decay\"] = weight_decay\n        options[\"nesterov\"] = nesterov\n        options[\"maximize\"] = maximize\n        options[\"contiguous_params\"] = contiguous_params\n        options[\"fused\"] = fused\n        super().__init__(params, options)\n\n        for param_group in self.param_groups:\n            if param_group[\"contiguous_params\"]:\n                param_list = param_group.contiguous_parameters\n            else:\n                param_list = param_group.parameters\n\n            for param in param_list:\n                assert param.is_leaf, \"parameters must be leaf tensor\"\n                self.state[param] = dict()\n\n                if param_group[\"fused\"] and not param.is_cuda:\n                    warnings.warn(\"Fused SGD only support cuda parameters.\")\n                    param_group[\"fused\"] = False\n\n        self._momentum_sgd = (\n            flow.stateful_op(\"momentum_update\")\n            .Input(\"model\")\n            .Input(\"model_diff\")\n            .Input(\"momentum\")\n            .Build()\n        )\n        self._sgd = (\n            flow.stateful_op(\"sgd_update\").Input(\"model\").Input(\"model_diff\").Build()\n        )\n\n    def _single_tensor_update(self, param_group):\n        lr = param_group[\"lr\"]\n        l2 = param_group[\"weight_decay\"]\n\n        if param_group[\"contiguous_params\"]:\n            param_list = param_group.contiguous_parameters\n        else:\n            param_list = param_group.parameters\n\n        for param in param_list:\n            if param.grad is None:\n                continue\n            if param_group[\"momentum\"] == 0.0:\n                # TODO: Support param `maximize` in Naive SGD Optimizer. 
(zhengzekang)\n                flow._C.dispatch_sgd_update(\n                    self._sgd, (param, param.grad), learning_rate=lr, l2=l2\n                )\n            else:\n                if \"momentum_buf\" not in self.state[param]:\n                    self.state[param][\"momentum_buf\"] = flow.zeros_like(param)\n                momentum_buf = self.state[param][\"momentum_buf\"]\n                beta = param_group[\"momentum\"]\n                dampening = param_group[\"dampening\"]\n                nesterov = param_group[\"nesterov\"]\n                maximize = param_group[\"maximize\"]\n                flow._C.dispatch_momentum_update(\n                    self._momentum_sgd,\n                    (param, param.grad, momentum_buf),\n                    learning_rate=lr,\n                    l2=l2,\n                    beta=beta,\n                    dampening=dampening,\n                    nesterov=nesterov,\n                    maximize=maximize,\n                )\n\n    def _fused_update(self, param_group):\n        use_momentum = param_group[\"momentum\"] != 0\n        param_list = []\n        param_grad_list = []\n        if use_momentum:\n            momentum_buf_list = []\n\n        for param in param_group.parameters:\n            if param.grad is None:\n                continue\n            param_list.append(param)\n            param_grad_list.append(param.grad)\n\n            if use_momentum:\n                if \"momentum_buf\" not in self.state[param]:\n                    self.state[param][\"momentum_buf\"] = flow.zeros_like(param)\n                momentum_buf_list.append(self.state[param][\"momentum_buf\"])\n\n        if not use_momentum:\n            flow._C.multi_tensor_sgd_update(\n                model=param_list,\n                model_diff=param_grad_list,\n                scale=1.0,\n                weight_decay=param_group[\"weight_decay\"],\n                learning_rate_val=param_group[\"lr\"],\n            )\n        else:\n            flow._C.multi_tensor_momentum_update(\n                model=param_list,\n                model_diff=param_grad_list,\n                momentum_buf=momentum_buf_list,\n                scale=1.0,\n                weight_decay=param_group[\"weight_decay\"],\n                learning_rate_val=param_group[\"lr\"],\n                momentum=param_group[\"momentum\"],\n                dampening=param_group[\"dampening\"],\n                nesterov=param_group[\"nesterov\"],\n                maximize=param_group[\"maximize\"],\n            )\n\n    def step(self, closure: Callable = None):\n        \"\"\"Performs a single optimization step.\n        Args:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        with flow.no_grad():\n            loss = None\n            if closure is not None:\n                with flow.enable_grad():\n                    loss = closure()\n\n            for param_group in self.param_groups:\n                if param_group[\"fused\"]:\n                    self._fused_update(param_group)\n                else:\n                    self._single_tensor_update(param_group)\n\n        self.state[\"step\"] = self.state[\"step\"] + 1\n        return loss\n\n    def _generate_conf_for_graph(self, train_conf, vars_conf):\n        new_opt_confs = []\n        for param_group in self.param_groups:\n            assert (\n                param_group[\"contiguous_params\"] != True\n            ), \"contiguous_params cannot be 
used in graph\"\n\n            optimizer_conf = train_conf.optimizer_conf.add()\n            lr = (\n                param_group[\"initial_lr\"]\n                if \"initial_lr\" in param_group\n                else param_group[\"lr\"]\n            )\n            beta = param_group[\"momentum\"]\n            l2 = param_group[\"weight_decay\"]\n            dampening = param_group[\"dampening\"]\n            nesterov = param_group[\"nesterov\"]\n            maximize = param_group[\"maximize\"]\n\n            optimizer_conf.base_learning_rate = lr\n            self._generate_lr_scale_for_optim_conf(param_group, optimizer_conf)\n\n            if beta == 0:\n                optimizer_conf.naive_conf.SetInParent()\n            else:\n                optimizer_conf.momentum_conf.beta = beta\n                # Only Momentum Optimizer support these params.\n                optimizer_conf.momentum_conf.dampening = dampening\n                optimizer_conf.momentum_conf.nesterov = nesterov\n                optimizer_conf.momentum_conf.maximize = maximize\n\n            self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf)\n\n            for param in param_group.parameters:\n                vars_conf[param].l2 = l2\n                if param.requires_grad:\n                    optimizer_conf.variable_op_names.append(vars_conf[param].name)\n\n            new_opt_confs.append(optimizer_conf)\n        return new_opt_confs\n\n    @property\n    def support_sparse(self):\n        \"\"\"Whether SGD Optimizer support sparse update. \n\n        \"\"\"\n        return True\n"
  },
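The momentum equations in the SGD docstring above can be sanity-checked with a scalar sketch. This is an illustration only, not OneFlow's kernel; `sgd_step` and its arguments are hypothetical names.

```python
# Scalar sketch of the update rules documented in SGD above (illustrative;
# `sgd_step` is a hypothetical helper, not part of the OneFlow API).
def sgd_step(param, grad, velocity, lr=1e-3, momentum=0.0, weight_decay=0.0):
    if momentum == 0.0:
        # param_new = param_old - learning_rate * grad (L2 penalty folded in)
        return param - lr * (grad + weight_decay * param), velocity
    # V_t = beta * V_{t-1} - learning_rate * (g_t + param_old * weight_decay)
    velocity = momentum * velocity - lr * (grad + weight_decay * param)
    # param_new = param_old + V_t
    return param + velocity, velocity

p, v = 1.0, 0.0
for g in (0.5, 0.4, 0.3):  # pretend gradients from three steps
    p, v = sgd_step(p, g, v, lr=0.1, momentum=0.9)
```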
  {
    "path": "python/oneflow/nn/optimizer/step_lr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport math\n\nfrom ...optim.optimizer import Optimizer\nfrom .lr_scheduler import LRScheduler\n\n\nclass StepLR(LRScheduler):\n    \"\"\"\n    Decays the learning rate of each parameter group by gamma every step_size steps.\n    Notice that such decay can happen simultaneously with other changes to the learning\n    rate fromoutside this scheduler. When last_step=-1, sets initial lr as lr.\n\n    Args:\n        optimizer(Optimizer): Wrapped optimizer.\n        step_size (int): Period of learning rate decay.\n        gamma (float, optional): Multiplicative factor of learning rate decay. (default: 0.1)\n        last_step (int, optional): The index of last step. (default: -1)\n        verbose (bool, optional): If ``True``, prints a message to stdout for each update. (default: ``False``)\n\n    For example:\n\n    .. code-block:: python\n\n        import oneflow as flow\n\n        ...\n        step_lr = flow.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)\n        for epoch in range(num_epoch):\n            train(...)\n            step_lr.step()\n\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        step_size: int,\n        gamma: float = 0.1,\n        last_step: int = -1,\n        verbose: bool = False,\n    ):\n        assert step_size > 0, f\"step_size must greater than zero, but got {step_size}\"\n        assert gamma > 0.0, f\"gamma must greater than zero, but got {gamma}\"\n        self.step_size = step_size\n        self.gamma = gamma\n        super().__init__(optimizer, last_step, verbose)\n\n    def get_lr(self, base_lr, step):\n        step_stage = math.floor(step / self.step_size)\n        factor = self.gamma ** step_stage\n        return base_lr * factor\n\n    def _generate_conf_for_graph(self, lr_conf):\n        lr_conf.step_conf.SetInParent()\n        step_conf = lr_conf.step_conf\n        step_conf.step_size = self.step_size\n        step_conf.gamma = self.gamma\n"
  },
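StepLR.get_lr above reduces to a one-line formula: the base learning rate is multiplied by gamma once every step_size steps. A minimal plain-Python sketch follows (the `step_lr` helper is a hypothetical name, not OneFlow API).

```python
import math

# Mirrors StepLR.get_lr: decay base_lr by gamma every step_size steps.
def step_lr(base_lr, step, step_size=30, gamma=0.1):
    return base_lr * gamma ** math.floor(step / step_size)

assert step_lr(0.1, 29) == 0.1                # steps 0..29: not yet decayed
assert abs(step_lr(0.1, 30) - 0.01) < 1e-12   # decayed once at step 30
assert abs(step_lr(0.1, 60) - 0.001) < 1e-12  # decayed twice at step 60
```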
  {
    "path": "python/oneflow/nn/optimizer/swa_utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nr\"\"\"\nSwa_utils Methods are consistent with PyTorch.\nThe documentation is referenced from:\nhttps://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging.\n\"\"\"\nimport itertools\nimport math\nfrom copy import deepcopy\nimport warnings\n\nimport oneflow as flow\nfrom oneflow.nn import Module\nfrom oneflow.nn.optimizer.lr_scheduler import LRScheduler\n\n__all__ = [\"AveragedModel\", \"update_bn\", \"SWALR\"]\n\n\nclass AveragedModel(Module):\n    r\"\"\"Implements averaged model for Stochastic Weight Averaging (SWA).\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging\n\n    Stochastic Weight Averaging was proposed in `Averaging Weights Leads to\n    Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii\n    Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson\n    (UAI 2018).\n\n    AveragedModel class creates a copy of the provided module :attr:`model`\n    on the device :attr:`device` and allows to compute running averages of the\n    parameters of the :attr:`model`.\n\n    Args:\n        model (oneflow.nn.Module): model to use with SWA\n        device (oneflow.device, optional): if provided, the averaged model will be\n            stored on the :attr:`device`\n        avg_fn (function, optional): the averaging function used to update\n            parameters; the function must take in the current value of the\n            :class:`AveragedModel` parameter, the current value of :attr:`model`\n            parameter and the number of models already averaged; if None,\n            equally weighted average is used (default: None)\n        use_buffers (bool): if ``True``, it will compute running averages for\n            both the parameters and the buffers of the model. (default: ``False``)\n\n    For example:\n\n    .. 
code-block:: python\n\n        import oneflow as flow\n\n        ...\n        loader, optimizer, model, loss_fn = ...\n        swa_model = flow.optim.swa_utils.AveragedModel(model)\n        scheduler = flow.optim.lr_scheduler.CosineAnnealingLR(optimizer,\n                                             T_max=300)\n        swa_start = 160\n        swa_scheduler = flow.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)\n        for i in range(300):\n            for input, target in loader:\n                optimizer.zero_grad()\n                loss_fn(model(input), target).backward()\n                optimizer.step()\n            if i > swa_start:\n                swa_model.update_parameters(model)\n                swa_scheduler.step()\n            else:\n                scheduler.step()\n\n        # Update bn statistics for the swa_model at the end\n        flow.optim.swa_utils.update_bn(loader, swa_model)\n\n    You can also use custom averaging functions with the `avg_fn` parameter.\n    If no averaging function is provided, the default is to compute an\n    equally weighted average of the weights.\n\n    For example:\n\n    .. code-block:: python\n\n        import oneflow as flow\n\n\n        ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged: (\n                         0.1 * averaged_model_parameter + 0.9 * model_parameter)\n        swa_model = flow.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg, use_buffers=True)\n\n    .. note::\n        When using SWA with models containing Batch Normalization, you may\n        need to update the activation statistics for Batch Normalization.\n        This can be done either by using :meth:`oneflow.optim.swa_utils.update_bn`\n        or by setting :attr:`use_buffers` to `True`. The first approach updates the\n        statistics in a post-training step by passing data through the model. The\n        second does it during the parameter update phase by averaging all buffers.\n        Empirical evidence has shown that updating the statistics in normalization\n        layers increases accuracy, but you may wish to empirically test which\n        approach yields the best results in your problem.\n\n    .. note::\n        :attr:`avg_fn` is not saved in the :meth:`state_dict` of the model.\n\n    .. note::\n        When :meth:`update_parameters` is called for the first time (i.e.\n        :attr:`n_averaged` is `0`) the parameters of `model` are copied\n        to the parameters of :class:`AveragedModel`. For every subsequent\n        call of :meth:`update_parameters` the function `avg_fn` is used\n        to update the parameters.\n\n    .. _Averaging Weights Leads to Wider Optima and Better Generalization:\n        https://arxiv.org/abs/1803.05407\n    .. _There Are Many Consistent Explanations of Unlabeled Data: Why You Should\n        Average:\n        https://arxiv.org/abs/1806.05594\n    .. _SWALP: Stochastic Weight Averaging in Low-Precision Training:\n        https://arxiv.org/abs/1904.11943\n    .. 
_Stochastic Weight Averaging in Parallel: Large-Batch Training That\n        Generalizes Well:\n        https://arxiv.org/abs/2001.02312\n    \"\"\"\n\n    def __init__(self, model, device=None, avg_fn=None, use_buffers=False):\n        super(AveragedModel, self).__init__()\n        self.module = deepcopy(model)\n        if device is not None:\n            self.module = self.module.to(device)\n        self.register_buffer(\n            \"n_averaged\", flow.tensor(0, dtype=flow.long, device=device)\n        )\n        if avg_fn is None:\n\n            def avg_fn(averaged_model_parameter, model_parameter, num_averaged):\n                return averaged_model_parameter + (\n                    model_parameter - averaged_model_parameter\n                ) / (num_averaged + 1)\n\n        self.avg_fn = avg_fn\n        self.use_buffers = use_buffers\n\n    def forward(self, *args, **kwargs):\n        return self.module(*args, **kwargs)\n\n    def update_parameters(self, model):\n        self_param = (\n            itertools.chain(self.module.parameters(), self.module.buffers())\n            if self.use_buffers\n            else self.parameters()\n        )\n        model_param = (\n            itertools.chain(model.parameters(), model.buffers())\n            if self.use_buffers\n            else model.parameters()\n        )\n        for p_swa, p_model in zip(self_param, model_param):\n            device = p_swa.device\n            p_model_ = p_model.detach().to(device)\n            if self.n_averaged == 0:\n                p_swa.detach().copy_(p_model_)\n            else:\n                p_swa.detach().copy_(\n                    self.avg_fn(p_swa.detach(), p_model_, self.n_averaged.to(device))\n                )\n        if not self.use_buffers:\n            # If running averages are not applied to the buffers,\n            # keep the buffers in sync with the source model.\n            for b_swa, b_model in zip(self.module.buffers(), model.buffers()):\n                b_swa.detach().copy_(b_model.detach().to(device))\n        self.n_averaged += 1\n\n\ndef update_bn(loader, model, device=None):\n    r\"\"\"Updates BatchNorm running_mean, running_var buffers in the model.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/stable/optim.html#taking-care-of-batch-normalization\n\n    It performs one pass over data in `loader` to estimate the activation\n    statistics for BatchNorm layers in the model.\n\n    Args:\n        loader (oneflow.utils.data.DataLoader): dataset loader to compute the\n            activation statistics on. Each data batch should be either a\n            tensor, or a list/tuple whose first element is a tensor\n            containing data.\n        model (oneflow.nn.Module): model for which we seek to update BatchNorm\n            statistics.\n        device (oneflow.device, optional): If set, data will be transferred to\n            :attr:`device` before being passed into :attr:`model`.\n\n    For example:\n\n    .. code-block:: python\n\n        import oneflow as flow\n\n        loader, model = ...\n        flow.optim.swa_utils.update_bn(loader, model)\n\n    .. 
note::\n        The `update_bn` utility assumes that each data batch in :attr:`loader`\n        is either a tensor or a list or tuple of tensors; in the latter case it\n        is assumed that :meth:`model.forward()` should be called on the first\n        element of the list or tuple corresponding to the data batch.\n    \"\"\"\n    with flow.no_grad():\n        momenta = {}\n        for module in model.modules():\n            if isinstance(module, flow.nn.modules.batchnorm._BatchNorm):\n                module.running_mean = flow.zeros_like(module.running_mean)\n                module.running_var = flow.ones_like(module.running_var)\n                momenta[module] = module.momentum\n\n        if not momenta:\n            return\n\n        was_training = model.training\n        model.train()\n        for module in momenta.keys():\n            module.momentum = None\n            module.num_batches_tracked *= 0\n\n        for input in loader:\n            if isinstance(input, (list, tuple)):\n                input = input[0]\n            if device is not None:\n                input = input.to(device)\n\n            model(input)\n\n        for bn_module in momenta.keys():\n            bn_module.momentum = momenta[bn_module]\n        model.train(was_training)\n\n\nclass SWALR(LRScheduler):\n    r\"\"\"Anneals the learning rate in each parameter group to a fixed value.\n\n    The documentation is referenced from:\n    https://pytorch.org/docs/stable/optim.html#swa-learning-rate-schedules\n\n    This learning rate scheduler is meant to be used with the Stochastic Weight\n    Averaging (SWA) method (see `oneflow.optim.swa_utils.AveragedModel`).\n\n    Args:\n        optimizer (oneflow.optim.Optimizer): wrapped optimizer\n        swa_lr (float or list): the learning rate value for all param groups\n            together or separately for each group.\n        anneal_epochs (int): number of epochs in the annealing phase\n            (default: 10)\n        anneal_strategy (str): \"cos\" or \"linear\"; specifies the annealing\n            strategy: \"cos\" for cosine annealing, \"linear\" for linear annealing\n            (default: \"cos\")\n        last_epoch (int): the index of the last epoch (default: -1)\n\n    The :class:`SWALR` scheduler can be used together with other\n    schedulers to switch to a constant learning rate late in training,\n    as in the example below.\n\n    For example:\n\n    .. code-block:: python\n\n        import oneflow as flow\n\n        loader, optimizer, model, loss_fn = ...\n        lr_lambda = lambda epoch: 0.9\n        scheduler = flow.optim.lr_scheduler.MultiplicativeLR(optimizer,\n                lr_lambda=lr_lambda)\n        swa_scheduler = flow.optim.swa_utils.SWALR(optimizer,\n                anneal_strategy=\"linear\", anneal_epochs=20, swa_lr=0.05)\n        swa_start = 160\n        for i in range(300):\n            for input, target in loader:\n                optimizer.zero_grad()\n                loss_fn(model(input), target).backward()\n                optimizer.step()\n            if i > swa_start:\n                swa_scheduler.step()\n            else:\n                scheduler.step()\n\n    .. 
_Averaging Weights Leads to Wider Optima and Better Generalization:\n        https://arxiv.org/abs/1803.05407\n    \"\"\"\n\n    def __init__(\n        self, optimizer, swa_lr, anneal_epochs=10, anneal_strategy=\"cos\", last_epoch=-1\n    ):\n        swa_lrs = self._format_param(optimizer, swa_lr)\n        for swa_lr, group in zip(swa_lrs, optimizer.param_groups):\n            group[\"swa_lr\"] = swa_lr\n        if anneal_strategy not in [\"cos\", \"linear\"]:\n            raise ValueError(\n                \"anneal_strategy must be one of 'cos' or 'linear', \"\n                f\"instead got {anneal_strategy}\"\n            )\n        elif anneal_strategy == \"cos\":\n            self.anneal_func = self._cosine_anneal\n        elif anneal_strategy == \"linear\":\n            self.anneal_func = self._linear_anneal\n        if not isinstance(anneal_epochs, int) or anneal_epochs < 0:\n            raise ValueError(\n                f\"anneal_epochs must be a non-negative integer, got {anneal_epochs}\"\n            )\n        self.anneal_epochs = anneal_epochs\n        self.param_group_index = 0\n        super(SWALR, self).__init__(optimizer, last_epoch)\n\n    @staticmethod\n    def _format_param(optimizer, swa_lrs):\n        if isinstance(swa_lrs, (list, tuple)):\n            if len(swa_lrs) != len(optimizer.param_groups):\n                raise ValueError(\n                    \"swa_lr must have the same length as \"\n                    f\"optimizer.param_groups: swa_lr has {len(swa_lrs)}, \"\n                    f\"optimizer.param_groups has {len(optimizer.param_groups)}\"\n                )\n            return swa_lrs\n        else:\n            return [swa_lrs] * len(optimizer.param_groups)\n\n    @staticmethod\n    def _linear_anneal(t):\n        return t\n\n    @staticmethod\n    def _cosine_anneal(t):\n        return (1 - math.cos(math.pi * t)) / 2\n\n    @staticmethod\n    def _get_initial_lr(lr, swa_lr, alpha):\n        if alpha == 1:\n            return swa_lr\n        return (lr - alpha * swa_lr) / (1 - alpha)\n\n    def get_lr(self, base_lr, step):\n        if self.anneal_epochs == 0:\n            step = max(1, step)\n        prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))\n        prev_alpha = self.anneal_func(prev_t)\n        group = self.optimizer.param_groups[self.param_group_index]\n        prev_lr = self._get_initial_lr(group[\"lr\"], group[\"swa_lr\"], prev_alpha)\n        self.param_group_index += 1\n        if self.param_group_index == len(self.optimizer.param_groups):\n            self.param_group_index = 0\n        t = max(0, min(1, step / max(1, self.anneal_epochs)))\n        alpha = self.anneal_func(t)\n        return group[\"swa_lr\"] * alpha + prev_lr * (1 - alpha)\n"
  },
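The default `avg_fn` in AveragedModel above is an incremental arithmetic mean: after n updates, the averaged parameter equals the mean of the n snapshots seen so far. A plain-float sketch (the `swa_avg` name is hypothetical, for illustration only):

```python
# Plain-float sketch of AveragedModel's default avg_fn (illustrative only).
def swa_avg(averaged, current, num_averaged):
    return averaged + (current - averaged) / (num_averaged + 1)

avg = 0.0
for n, snapshot in enumerate([1.0, 2.0, 3.0]):  # weights after each epoch
    avg = swa_avg(avg, snapshot, n)
assert avg == 2.0  # equals mean([1.0, 2.0, 3.0])
```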
  {
    "path": "python/oneflow/nn/optimizer/warmup_lr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport numpy as np\nfrom typing import Union\nfrom ...optim.optimizer import Optimizer\nfrom .lr_scheduler import LRScheduler\nfrom .sequential_lr import SequentialLR\nfrom .constant_lr import ConstantLR\nfrom .linear_lr import LinearLR\n\n\nclass WarmupLR(SequentialLR):\n    r\"\"\"Increasing the learning rate with a small warmup factor until the number of epoch\n    reaches the warmup_iters. You can assign an optimizer or a learning rate scheduler.\n    Notice that the warmup can happen simultaneously with learning rate scheduler.\n\n    Args:\n        scheduler_or_optimizer ([type]): Wrapped learning rate scheduler or optimizer\n        warmup_factor (float, optional): The warmup factor. Defaults to 1.0/3.\n        warmup_iters (int, optional): The number of warmup steps. Defaults to 5.\n        warmup_method (str, optional): The method of warmup, you can choose \"linear\" or \"constant\".\n            In linear mode, the multiplication factor starts with warmup_factor in the first epoch and then inreases linearly to reach 1. Defaults to \"linear\".\n        last_step (int, optional): The index of the last step. Defaults to -1.\n        verbose (bool, optional): If True, it prints a message to stdout for each update step. Defaults to False.\n\n    Raises:\n        ValueError: The warmup method should be one of the \"constant\" and \"linear\"\n\n    For example:\n\n    Example 1:\n\n    .. code:: python\n\n        # lr = 0.0005    if epoch == 0\n        # lr = 0.0005    if epoch == 1\n        # lr = 0.0005    if epoch == 2\n        # lr = 0.0005    if epoch == 3\n        # lr = 0.0005    if epoch == 4\n        # lr = 0.001     if epoch >= 5\n        of_sgd = flow.optim.SGD(parameters, lr=0.001)\n        constant_warmup_lr = flow.optim.lr_scheduler.WarmUpLR(\n            of_sgd, warmup_factor=0.5, warmup_iters=5, warmup_method=\"constant\"\n        )\n        ...\n\n    Example 2:\n\n    .. code:: python\n\n        # lr = 0.0005    if epoch == 0\n        # lr = 0.0006    if epoch == 1\n        # lr = 0.0007    if epoch == 2\n        # lr = 0.0008    if epoch == 3\n        # lr = 0.0009    if epoch == 4\n        # lr = 0.001    if epoch >= 5\n        of_sgd = flow.optim.SGD(parameters, lr=0.001)\n        constant_warmup_lr = flow.optim.lr_scheduler.WarmUpLR(\n            of_sgd, warmup_factor=0.5, warmup_iters=5, warmup_method=\"linear\"\n        )\n        ...\n\n    Example 2:\n\n    .. 
code:: python\n\n        # lr = 0.0005    if epoch == 0\n        # lr = 0.00075   if epoch == 1\n        # Above is WarmUpLR, then we start CosineDecayLR\n        # lr = 0.000689  if epoch == 2\n        # lr = 0.000410  if epoch == 3\n        # ....\n        of_sgd = flow.optim.SGD(parameters, lr=0.001)\n        alpha = 0.1\n        decay_steps = 5\n        cosine_decay_lr = flow.optim.lr_scheduler.CosineDecayLR(\n            of_sgd, decay_steps=decay_steps, alpha=alpha\n        )\n        linear_warmup_cosine_lr = flow.optim.lr_scheduler.WarmUpLR(\n            cosine_decay_lr, warmup_factor=0.5, warmup_iters=2, warmup_method=\"linear\"\n        )\n        ...\n    \"\"\"\n\n    def __init__(\n        self,\n        scheduler_or_optimizer: Union[LRScheduler, Optimizer],\n        warmup_factor: float = 1.0 / 3,\n        warmup_iters: int = 5,\n        warmup_method: str = \"linear\",\n        warmup_prefix: bool = False,\n        last_step=-1,\n        verbose=False,\n    ):\n        if not isinstance(scheduler_or_optimizer, (LRScheduler, Optimizer)):\n            raise ValueError(\n                \"'scheduler_or_optimizer' must be a LRScheduler or an Optimizer, but got \"\n                f\"{type(scheduler_or_optimizer)}\"\n            )\n\n        if warmup_method not in (\"linear\", \"constant\"):\n            raise ValueError(\n                f\"'warmup_method' must be 'linear' or 'constant', but got {warmup_method}\"\n            )\n\n        if isinstance(scheduler_or_optimizer, LRScheduler):\n            opt = scheduler_or_optimizer.optimizer\n            scheduler = scheduler_or_optimizer\n        else:\n            opt = scheduler_or_optimizer\n            scheduler = None\n\n        if scheduler is None and warmup_iters == 0:\n            raise ValueError(\n                \"When 'scheduler_or_optimizer' is an optimizer warmup_iters can't be equal to 0\"\n            )\n\n        self.warmup_factor = warmup_factor\n        self.warmup_iters = warmup_iters\n        self.warmup_method = warmup_method\n        self.warmup_prefix = warmup_prefix\n        # manually init optimizer, last_step, base_lrs first\n        self.optimizer = opt\n        self.last_step = last_step\n        self.verbose = verbose\n        self._init_base_lrs()\n        warmup = self._init_warmup_scheduler(scheduler)\n        self._init_seq_scheduler(scheduler, warmup)\n\n    def _init_warmup_scheduler(self, scheduler):\n        warmup = None\n\n        if self.warmup_iters <= 0:\n            return\n\n        if self.warmup_method == \"linear\":\n            if scheduler and self.warmup_prefix is False:\n                base_lr = self.base_lrs[0]\n                if not np.isclose(self.base_lrs, base_lr).all():\n                    raise ValueError(\n                        \"The param_groups in optimizer have different warmup configs, please use different optimizers.\"\n                    )\n\n                end_lr = scheduler.get_lr(base_lr, self.warmup_iters)\n                end_factor = end_lr / base_lr\n            else:\n                end_factor = 1.0\n\n            warmup = LinearLR(\n                self.optimizer,\n                start_factor=self.warmup_factor,\n                end_factor=end_factor,\n                total_iters=self.warmup_iters,\n                last_step=self.last_step,\n                verbose=self.verbose,\n            )\n        else:  # \"constant\"\n            warmup = ConstantLR(\n                self.optimizer,\n                factor=self.warmup_factor,\n           
     total_iters=self.warmup_iters,\n                last_step=self.last_step,\n                verbose=self.verbose,\n            )\n\n        return warmup\n\n    def _init_seq_scheduler(self, scheduler, warmup):\n        if warmup and scheduler:\n            schedulers = [warmup, scheduler]\n            milestones = [self.warmup_iters]\n            interval_rescaling = [self.warmup_prefix]\n        elif warmup:\n            schedulers = [warmup]\n            milestones = []\n            interval_rescaling = []\n        elif scheduler:\n            schedulers = [scheduler]\n            milestones = []\n            interval_rescaling = []\n        else:\n            raise ValueError(\"No scheduler can work\")\n\n        super().__init__(\n            self.optimizer,\n            schedulers=schedulers,\n            milestones=milestones,\n            interval_rescaling=interval_rescaling,\n            last_step=self.last_step,\n            verbose=self.verbose,\n        )\n"
  },
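The two warmup policies described in WarmupLR's docstring can be written as a closed-form factor. Below is a sketch under the docstring's Example 2 settings; the `warmup_lr` helper is a hypothetical name (the real class delegates to LinearLR / ConstantLR).

```python
# Sketch of WarmupLR's "constant" and "linear" policies (illustrative only).
def warmup_lr(base_lr, step, warmup_factor=0.5, warmup_iters=5, method="linear"):
    if step >= warmup_iters:
        return base_lr                      # warmup finished
    if method == "constant":
        return base_lr * warmup_factor      # flat factor until warmup_iters
    alpha = step / warmup_iters             # linear: factor ramps up to 1.0
    return base_lr * (warmup_factor * (1 - alpha) + alpha)

# Reproduces Example 2: 0.0005, 0.0006, 0.0007, 0.0008, 0.0009, then 0.001
lrs = [warmup_lr(0.001, s) for s in range(6)]
```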
  {
    "path": "python/oneflow/nn/parallel/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom .distributed import DistributedDataParallel\n\n__all__ = [\"DistributedDataParallel\"]\n"
  },
  {
    "path": "python/oneflow/nn/parallel/distributed.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport warnings\nfrom collections import OrderedDict\n\nimport oneflow as flow\nfrom oneflow.support.env_var_util import parse_boolean_from_env\nfrom oneflow.framework.tensor_tuple_util import convert_to_tensor_tuple\nfrom oneflow.nn.utils.parameters_grouping import ContiguousParamsGroup\nfrom oneflow.framework.args_tree import ArgsTree\n\n\ndef allreduce_fn(module, param, use_bucket):\n    ddp_state_for_reversed_params = module._ddp_state_for_reversed_params\n\n    def allreduce_with_bucket(grad):\n        buckets = module._buckets\n        bucket_tensors = module._bucket_tensors\n        ddp_state_for_reversed_params[param][0] = True\n        for index, bucket in enumerate(buckets):\n            deleted = all(ddp_state_for_reversed_params[x][1] for x in bucket)\n            if deleted:\n                continue\n\n            assert not any(ddp_state_for_reversed_params[x][1] for x in bucket)\n\n            all_params_in_bucket_ready = all(\n                ddp_state_for_reversed_params[x][0] for x in bucket\n            )\n            if all_params_in_bucket_ready:\n                for x in bucket:\n                    ddp_state_for_reversed_params[x][1] = True\n                # NOTE(jianhao)(higher-order-grad):\n                # local allreduce doesn't have gradient function, higher-order grad may be unsupported\n                flow._C.local_all_reduce(bucket_tensors[index], inplace=True)\n            else:\n                break\n\n    def allreduce_without_bucket(grad):\n        ddp_state_for_reversed_params[param][0] = True\n        for cur_param, (ready, deleted) in ddp_state_for_reversed_params.items():\n            if deleted:\n                continue\n            if ready:\n                ddp_state_for_reversed_params[cur_param][1] = True\n                # NOTE(jianhao)(higher-order-grad): local allreduce doesn't have gradient function, higher-order grad may be unsupported\n                if cur_param is param:\n                    flow._C.local_all_reduce(grad, True)\n                else:\n                    flow._C.local_all_reduce(cur_param.grad, True)\n            else:\n                break\n\n    return allreduce_with_bucket if use_bucket else allreduce_without_bucket\n\n\ndef DistributedDataParallel(\n    module: \"flow.nn.Module\",\n    *,\n    broadcast_buffers: bool = True,\n    broadcast_parameters: bool = True,\n    bucket_size: int = 10,\n    use_bucket: bool = True,\n):\n    assert all(x.dtype == flow.float32 for x in module.parameters())\n    if use_bucket and parse_boolean_from_env(\"ONEFLOW_DISABLE_VIEW\", False):\n        warnings.warn(\n            \"because the environment variable 'ONEFLOW_DISABLE_VIEW' is set to true, so the view mechanism is disabled, and we will set use_bucket=False\"\n        )\n        use_bucket = False\n    world_size = flow.env.get_world_size()\n    if broadcast_parameters:\n        with flow.no_grad():\n 
           for x in module.parameters():\n                requires_grad = x.requires_grad\n                flow._C.comm_broadcast(x, inplace=True)\n                # TODO: fix the bug that x's requires_grad is discarded\n                # after flow._C.comm_broadcast\n                x.requires_grad_(requires_grad)\n\n    if use_bucket:\n        all_grad_size = sum([x.numel() for x in module.parameters()])\n        if all_grad_size > 0:\n            device = list(module.parameters())[0].device\n            assert all(x.device == device for x in module.parameters())\n        reversed_param_list = list(\n            reversed(\n                list([param for param in module.parameters() if param.requires_grad])\n            )\n        )\n\n        module._bucket_index = {\n            x: i // bucket_size for i, x in enumerate(reversed_param_list)\n        }\n        module._buckets = [\n            reversed_param_list[i : i + bucket_size]\n            for i in range(0, len(reversed_param_list), bucket_size)\n        ]\n\n        module._params_group = ContiguousParamsGroup(module._buckets)\n        module._bucket_tensors = module._params_group.grouped_parameters_grad\n\n    ddp_state_for_reversed_params = OrderedDict(\n        reversed([(x, [False, False]) for x in module.parameters() if x.requires_grad])\n    )\n    module._ddp_state_for_reversed_params = ddp_state_for_reversed_params\n    # The gradient should be averaged across all the nodes, so besides allreduce,\n    # a division by world_size is required.\n    # Use x * (1 / world_size) instead of x / world_size for two reasons:\n    # 1. multiplication is faster than division\n    # 2. An inplace operation is needed here (for allreduce grouping),\n    #    but we do not have an inplace division in oneflow.\n    mul_factor = 1 / world_size\n\n    def inplace_mul_and_return_none(x):\n        x.mul_(mul_factor)\n        return None\n\n    for param in module.parameters():\n        if param.requires_grad:\n            param._register_post_grad_accumulation_hook(inplace_mul_and_return_none)\n            param._register_post_grad_accumulation_hook(\n                allreduce_fn(module, param, use_bucket)\n            )\n\n    def post_forward_hook(module, input, output):\n        ddp_state_for_reversed_params = module._ddp_state_for_reversed_params\n        for state in ddp_state_for_reversed_params.values():\n            state[0], state[1] = False, False\n        output = ArgsTree(output).map_leaf(\n            lambda x: flow._C.select_top_n(\n                convert_to_tensor_tuple([x, *ddp_state_for_reversed_params.keys()]),\n                n=1,\n            )[0]\n        )\n        buffers = list(module.buffers())\n        if len(buffers) > 0:\n            flow._C.stream_touch(buffers)\n        return output\n\n    module.register_forward_hook(post_forward_hook)\n\n    if broadcast_buffers:\n\n        def pre_forward_hook(module, input):\n            with flow.no_grad():\n                buffers = list(module.buffers())\n                flow._C.comm_broadcast(buffers, inplace=True)\n\n        module.register_forward_pre_hook(pre_forward_hook)\n\n    module._is_ddp_module = True\n\n    return module\n"
  },
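DistributedDataParallel above averages gradients in two pieces: allreduce sums each gradient across ranks, and the post-grad-accumulation hook multiplies by 1 / world_size. A communication-free sketch of that arithmetic (the `ddp_average` helper is hypothetical):

```python
# What the DDP hooks compute per parameter, simulated without communication:
# allreduce is an elementwise sum over ranks; the hook then scales the sum.
def ddp_average(grads_per_rank):
    world_size = len(grads_per_rank)
    allreduced = sum(grads_per_rank)        # what local_all_reduce produces
    return allreduced * (1.0 / world_size)  # the inplace mul_ in the hook

assert abs(ddp_average([0.2, 0.4, 0.6]) - 0.4) < 1e-12
```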
  {
    "path": "python/oneflow/nn/parameter.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\n\nParameter = flow._oneflow_internal.nn.Parameter\n"
  },
  {
    "path": "python/oneflow/nn/qat/__init__.py",
    "content": ""
  },
  {
    "path": "python/oneflow/nn/qat/conv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow import nn as nn\nfrom oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t\nfrom typing import Union\n\n\ndef get_conv_fake_quantized(\n    input, input_observer, current_train_step, weight, weight_observer, fake_quantizer\n):\n    in_scale, in_zero_point = input_observer(input, current_train_step)\n    input_fake_quanted = fake_quantizer(input, in_scale, in_zero_point)\n    w_scale, w_zero_point = weight_observer(weight)\n    weight_fake_quanted = fake_quantizer(weight, w_scale, w_zero_point)\n    return input_fake_quanted, weight_fake_quanted\n\n\ndef init_conv_fake_quants(\n    self,\n    quantization_formula: str = \"google\",\n    quantization_bit: int = 8,\n    quantization_scheme: str = \"symmetric\",\n    weight_quant_per_layer: bool = True,\n    input_quant_momentum: float = 0.95,\n):\n    self.input_min_max_observer = nn.MovingAverageMinMaxObserver(\n        stop_update_after_iters=1,\n        quantization_formula=quantization_formula,\n        quantization_bit=quantization_bit,\n        quantization_scheme=quantization_scheme,\n        momentum=input_quant_momentum,\n    )\n    self.register_buffer(\"current_train_step\", flow.zeros(1, dtype=flow.int64,))\n    self.weight_min_max_observer = nn.MinMaxObserver(\n        quantization_formula=quantization_formula,\n        quantization_bit=quantization_bit,\n        quantization_scheme=quantization_scheme,\n        per_layer_quantization=weight_quant_per_layer,\n    )\n    self.fake_quantizer = nn.FakeQuantization(\n        quantization_formula=quantization_formula,\n        quantization_bit=quantization_bit,\n        quantization_scheme=quantization_scheme,\n    )\n\n\nclass QatConv1d(nn.Conv1d):\n    r\"\"\"A Conv1d module attached with `nn.MinMaxObserver`, `nn.MovingAverageMinMaxObserver` and `nn.FakeQuantization` modules for weight and input,\n    used for quantization aware training.\n\n    The parameters of QatConv1d are the same as :class:`~oneflow.nn.Conv1d` with some extra parameters for fake quantization,\n    see :class:`~oneflow.nn.MinMaxObserver`, :class:`~oneflow.nn.MovingAverageMinMaxObserver` and :class:`~oneflow.nn.FakeQuantization` for more details.\n\n    Args:\n        in_channels (int): Number of channels in the input image\n        out_channels (int): Number of channels produced by the convolution\n        kernel_size (int or tuple): Size of the convolving kernel\n        stride (int or tuple, optional): Stride of the convolution. Default: 1\n        padding (int, tuple or str, optional): Padding added to both sides of\n            the input. Default: 0\n        dilation (int or tuple, optional): Spacing between kernel\n            elements. Default: 1\n        groups (int, optional): Number of blocked connections from input\n            channels to output channels. 
Default: 1\n        bias (bool, optional): If ``True``, adds a learnable bias to the\n            output. Default: ``True``\n        padding_mode (string, optional): ``'zeros'``. Default: ``'zeros'``\n        quantization_formula (str): Support \"google\" or \"cambricon\".\n        quantization_bit (int): Quantize input to uintX / intX, X can be in range [2, 8]. Defaults to 8.\n        quantization_scheme (str): \"symmetric\" or \"affine\", quantize to signed / unsigned integer. Defaults to \"symmetric\".\n        weight_quant_per_layer (bool): True or False, means per-layer / per-channel for weight quantization. Defaults to True.\n        input_quant_momentum (float): Smoothing parameter for exponential moving average operation for input quantization. Defaults to 0.95.\n\n    Shape:\n        - Input: :math:`(N, C_{in}, L_{in})`\n        - Output: :math:`(N, C_{out}, L_{out})` where\n\n          .. math::\n              L_{out} = \\\\left\\\\lfloor\\\\frac{L_{in} + 2 \\\\times \\\\text{padding} - \\\\text{dilation}\n                        \\\\times (\\\\text{kernel\\\\_size} - 1) - 1}{\\\\text{stride}} + 1\\\\right\\\\rfloor\n\n    Attributes:\n        weight (Tensor): the learnable weights of the module of shape\n            :math:`(\\\\text{out\\\\_channels},\n            \\\\frac{\\\\text{in\\\\_channels}}{\\\\text{groups}}, \\\\text{kernel\\\\_size})`.\n            The values of these weights are sampled from\n            :math:`\\\\mathcal{U}(-\\\\sqrt{k}, \\\\sqrt{k})` where\n            :math:`k = \\\\frac{groups}{C_\\\\text{in} * \\\\text{kernel\\\\_size}}`\n        bias (Tensor):   the learnable bias of the module of shape\n            (out_channels). If :attr:`bias` is ``True``, then the values of these weights are\n            sampled from :math:`\\\\mathcal{U}(-\\\\sqrt{k}, \\\\sqrt{k})` where\n            :math:`k = \\\\frac{groups}{C_\\\\text{in} * \\\\text{kernel\\\\_size}}`\n\n    For example: \n\n    .. 
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n        \n        >>> arr = np.random.randn(20, 16, 50)\n        >>> input = flow.Tensor(arr)\n        >>> m = nn.QatConv1d(16, 33, 3, stride=2, quantization_formula=\"google\", quantization_bit=8, quantization_scheme=\"symmetric\")\n        >>> output = m(input)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: _size_1_t,\n        stride: _size_1_t = 1,\n        padding: Union[str, _size_1_t] = 0,\n        dilation: _size_1_t = 1,\n        groups: int = 1,\n        bias: bool = True,\n        padding_mode: str = \"zeros\",\n        quantization_formula: str = \"google\",\n        quantization_bit: int = 8,\n        quantization_scheme: str = \"symmetric\",\n        weight_quant_per_layer: bool = True,\n        input_quant_momentum: float = 0.95,\n    ):\n        super().__init__(\n            in_channels,\n            out_channels,\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            groups,\n            bias,\n            padding_mode,\n        )\n        self.channel_pos = \"channels_first\"\n        init_conv_fake_quants(\n            self,\n            quantization_formula=quantization_formula,\n            quantization_bit=quantization_bit,\n            quantization_scheme=quantization_scheme,\n            weight_quant_per_layer=weight_quant_per_layer,\n            input_quant_momentum=input_quant_momentum,\n        )\n\n    def forward(self, x):\n        fake_quan_input, fake_quan_weight = get_conv_fake_quantized(\n            x,\n            self.input_min_max_observer,\n            self.current_train_step,\n            self.weight,\n            self.weight_min_max_observer,\n            self.fake_quantizer,\n        )\n        return self._conv_forward(fake_quan_input, fake_quan_weight, self.bias)\n\n\nclass QatConv2d(nn.Conv2d):\n    r\"\"\"A Conv2d module attached with `nn.MinMaxObserver`, `nn.MovingAverageMinMaxObserver` and `nn.FakeQuantization` modules for weight and input,\n    used for quantization aware training.\n\n    The parameters of QatConv2d are the same as :class:`~oneflow.nn.Conv2d` with some extra parameters for fake quantization,\n    see :class:`~oneflow.nn.MinMaxObserver`, :class:`~oneflow.nn.MovingAverageMinMaxObserver` and :class:`~oneflow.nn.FakeQuantization` for more details.\n\n    Args:\n        in_channels (int): Number of channels in the input image\n        out_channels (int): Number of channels produced by the convolution\n        kernel_size (int or tuple): Size of the convolving kernel\n        stride (int or tuple, optional): Stride of the convolution. Default: 1\n        padding (int or tuple, optional): Zero-padding added to both sides of\n            the input. Default: 0\n        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1\n        groups (int, optional): Number of blocked connections from input\n            channels to output channels. Default: 1\n        bias (bool, optional): If ``True``, adds a learnable bias to the\n            output. Default: ``True``\n        padding_mode (string, optional): ``'zeros'``. Default: ``'zeros'``\n        quantization_formula (str): Support \"google\" or \"cambricon\".\n        quantization_bit (int): Quantize input to uintX / intX, X can be in range [2, 8]. 
Defaults to 8.\n        quantization_scheme (str): \"symmetric\" or \"affine\", quantize to signed / unsigned integer. Defaults to \"symmetric\".\n        weight_quant_per_layer (bool): True or False, means per-layer / per-channel for weight quantization. Defaults to True.\n        input_quant_momentum (float): Smoothing parameter for exponential moving average operation for input quantization. Defaults to 0.95.\n\n\n    Shape:\n        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`\n        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where\n\n          .. math::\n              H_{out} = \\\\left\\\\lfloor\\\\frac{H_{in}  + 2 \\\\times \\\\text{padding}[0] - \\\\text{dilation}[0]\n                        \\\\times (\\\\text{kernel_size}[0] - 1) - 1}{\\\\text{stride}[0]} + 1\\\\right\\\\rfloor\n\n          .. math::\n              W_{out} = \\\\left\\\\lfloor\\\\frac{W_{in}  + 2 \\\\times \\\\text{padding}[1] - \\\\text{dilation}[1]\n                        \\\\times (\\\\text{kernel_size}[1] - 1) - 1}{\\\\text{stride}[1]} + 1\\\\right\\\\rfloor\n\n    Attr:\n        - weight (Tensor): the learnable weights of the module of shape\n            :math:`(\\\\text{out_channels}, \\\\frac{\\\\text{in_channels}}{\\\\text{groups}},`\n            :math:`\\\\text{kernel_size[0]}, \\\\text{kernel_size[1]})`.\n            The values of these weights are sampled from\n            :math:`\\\\mathcal{U}(-\\\\sqrt{k}, \\\\sqrt{k})` where\n            :math:`k = \\\\frac{groups}{C_\\\\text{in} * \\\\prod_{i=0}^{1}\\\\text{kernel_size}[i]}`\n\n        - bias (Tensor):   the learnable bias of the module of shape\n            (out_channels). If :attr:`bias` is ``True``,\n            then the values of these weights are\n            sampled from :math:`\\\\mathcal{U}(-\\\\sqrt{k}, \\\\sqrt{k})` where\n            :math:`k = \\\\frac{groups}{C_\\\\text{in} * \\\\prod_{i=0}^{1}\\\\text{kernel_size}[i]}`\n\n    For example: \n\n    .. 
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n        \n        >>> arr = np.random.randn(20, 16, 50, 100)\n        >>> input = flow.Tensor(arr)\n        >>> m = nn.QatConv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1), quantization_formula=\"google\", quantization_bit=8, quantization_scheme=\"symmetric\")\n        >>> output = m(input)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: _size_2_t,\n        stride: _size_2_t = 1,\n        padding: Union[str, _size_2_t] = 0,\n        dilation: _size_2_t = 1,\n        groups: int = 1,\n        bias: bool = True,\n        padding_mode: str = \"zeros\",\n        quantization_formula: str = \"google\",\n        quantization_bit: int = 8,\n        quantization_scheme: str = \"symmetric\",\n        weight_quant_per_layer: bool = True,\n        input_quant_momentum: float = 0.95,\n    ):\n        super().__init__(\n            in_channels,\n            out_channels,\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            groups,\n            bias,\n            padding_mode,\n        )\n        self.channel_pos = \"channels_first\"\n        init_conv_fake_quants(\n            self,\n            quantization_formula=quantization_formula,\n            quantization_bit=quantization_bit,\n            quantization_scheme=quantization_scheme,\n            weight_quant_per_layer=weight_quant_per_layer,\n            input_quant_momentum=input_quant_momentum,\n        )\n\n    def forward(self, x):\n        fake_quan_input, fake_quan_weight = get_conv_fake_quantized(\n            x,\n            self.input_min_max_observer,\n            self.current_train_step,\n            self.weight,\n            self.weight_min_max_observer,\n            self.fake_quantizer,\n        )\n        return self._conv_forward(fake_quan_input, fake_quan_weight, self.bias)\n\n\nclass QatConv3d(nn.Conv3d):\n    r\"\"\"A Conv3d module attached with `nn.MinMaxObserver`, `nn.MovingAverageMinMaxObserver` and `nn.FakeQuantization` modules for weight and input,\n    used for quantization aware training.\n\n    The parameters of QatConv3d are the same as :class:`~oneflow.nn.Conv3d` with some extra parameters for fake quantization,\n    see :class:`~oneflow.nn.MinMaxObserver`, :class:`~oneflow.nn.MovingAverageMinMaxObserver` and :class:`~oneflow.nn.FakeQuantization` for more details.\n\n    Args:\n        in_channels (int): Number of channels in the input image\n        out_channels (int): Number of channels produced by the convolution\n        kernel_size (int or tuple): Size of the convolving kernel\n        stride (int or tuple, optional): Stride of the convolution. Default: 1\n        padding (int, tuple or str, optional): Padding added to all six sides of\n            the input. Default: 0\n        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1\n        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1\n        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``\n        padding_mode (string, optional): ``'zeros'``. Default: ``'zeros'``\n        quantization_formula (str): Support \"google\" or \"cambricon\".\n        quantization_bit (int): Quantize input to uintX / intX, X can be in range [2, 8]. 
Defaults to 8.\n        quantization_scheme (str): \"symmetric\" or \"affine\", quantize to signed / unsigned integer. Defaults to \"symmetric\".\n        weight_quant_per_layer (bool): True or False, means per-layer / per-channel for weight quantization. Defaults to True.\n        input_quant_momentum (float): Smoothing parameter for exponential moving average operation for input quantization. Defaults to 0.95.\n\n\n    Shape:\n        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`\n        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where\n\n          .. math::\n              D_{out} = \\left\\lfloor\\frac{D_{in} + 2 \\times \\text{padding}[0] - \\text{dilation}[0]\n                    \\times (\\text{kernel\\_size}[0] - 1) - 1}{\\text{stride}[0]} + 1\\right\\rfloor\n\n          .. math::\n              H_{out} = \\left\\lfloor\\frac{H_{in} + 2 \\times \\text{padding}[1] - \\text{dilation}[1]\n                    \\times (\\text{kernel\\_size}[1] - 1) - 1}{\\text{stride}[1]} + 1\\right\\rfloor\n\n          .. math::\n              W_{out} = \\left\\lfloor\\frac{W_{in} + 2 \\times \\text{padding}[2] - \\text{dilation}[2]\n                    \\times (\\text{kernel\\_size}[2] - 1) - 1}{\\text{stride}[2]} + 1\\right\\rfloor\n\n    Attributes:\n        weight (Tensor): the learnable weights of the module of shape\n                         :math:`(\\text{out\\_channels}, \\frac{\\text{in\\_channels}}{\\text{groups}},`\n                         :math:`\\text{kernel\\_size[0]}, \\text{kernel\\_size[1]}, \\text{kernel\\_size[2]})`.\n                         The values of these weights are sampled from\n                         :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where\n                         :math:`k = \\frac{groups}{C_\\text{in} * \\prod_{i=0}^{2}\\text{kernel\\_size}[i]}`\n        bias (Tensor):   the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,\n                         then the values of these weights are\n                         sampled from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where\n                         :math:`k = \\frac{groups}{C_\\text{in} * \\prod_{i=0}^{2}\\text{kernel\\_size}[i]}`\n\n    For example: \n\n    .. 
code-block:: python\n\n        >>> import numpy as np\n        >>> import oneflow as flow\n        >>> import oneflow.nn as nn\n\n        >>> arr = np.random.randn(1, 2, 5, 5, 5)\n        >>> input = flow.Tensor(arr)\n        >>> m = nn.QatConv3d(2, 4, kernel_size=3, stride=1, quantization_formula=\"google\", quantization_bit=8, quantization_scheme=\"symmetric\")\n        >>> output = m(input)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: _size_3_t,\n        stride: _size_3_t = 1,\n        padding: Union[str, _size_3_t] = 0,\n        dilation: _size_3_t = 1,\n        groups: int = 1,\n        bias: bool = True,\n        padding_mode: str = \"zeros\",\n        quantization_formula: str = \"google\",\n        quantization_bit: int = 8,\n        quantization_scheme: str = \"symmetric\",\n        weight_quant_per_layer: bool = True,\n        input_quant_momentum: float = 0.95,\n    ):\n        super().__init__(\n            in_channels,\n            out_channels,\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            groups,\n            bias,\n            padding_mode,\n        )\n        self.channel_pos = \"channels_first\"\n        init_conv_fake_quants(\n            self,\n            quantization_formula=quantization_formula,\n            quantization_bit=quantization_bit,\n            quantization_scheme=quantization_scheme,\n            weight_quant_per_layer=weight_quant_per_layer,\n            input_quant_momentum=input_quant_momentum,\n        )\n\n    def forward(self, x):\n        fake_quan_input, fake_quan_weight = get_conv_fake_quantized(\n            x,\n            self.input_min_max_observer,\n            self.current_train_step,\n            self.weight,\n            self.weight_min_max_observer,\n            self.fake_quantizer,\n        )\n        return self._conv_forward(fake_quan_input, fake_quan_weight, self.bias)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/utils/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.nn.utils.clip_grad import clip_grad_norm_, clip_grad_value_\nfrom oneflow.nn.utils.weight_norm import weight_norm\nfrom oneflow.nn.utils.weight_norm import remove_weight_norm\nfrom oneflow.nn.utils.parameters_grouping import ContiguousParamsGroup\nfrom oneflow.nn.utils.convert_parameters import (\n    parameters_to_vector,\n    vector_to_parameters,\n)\nfrom oneflow.nn.utils.skip_init import skip_init\n"
  },
  {
    "path": "python/oneflow/nn/utils/clip_grad.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport warnings\nfrom typing import Union, Iterable\n\nimport numpy as np\nimport oneflow as flow\n\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.framework.tensor import register_tensor_op\nfrom oneflow.nn.modules.module import Module\n\n\n_tensor_or_tensors = Union[Tensor, Iterable[Tensor]]\n\n\ndef clip_grad_norm_(\n    parameters: _tensor_or_tensors,\n    max_norm: float,\n    norm_type: float = 2.0,\n    fused: bool = False,\n    error_if_nonfinite: bool = False,\n) -> Tensor:\n    r\"\"\"Clips gradient norm of an iterable of parameters.\n    The norm is computed over all gradients together, as if they were\n    concatenated into a single vector.\n\n    Args:\n        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a\n            single Tensor that will have gradients normalized\n        max_norm (float or int): max norm of the gradients\n        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for\n            infinity norm.\n        error_if_nonfinite (bool): if True, an error is thrown if the total\n            norm of the gradients from :attr:``parameters`` is ``nan``,\n            ``inf``, or ``-inf``. Default: False (will switch to True in the future)\n\n    Returns:\n        Parameters after cliping gradient norm\n        Total norm of the parameters (viewed as a single vector).\n    \n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> x1 = flow.tensor(np.array([[2, 3, 4], [1.5, 2.6, 3.7]]).astype(np.float32), requires_grad=True)\n        >>> m1 = flow.nn.ReLU()\n        >>> out1 = m1(x1)\n        >>> out1 = out1.sum()\n        >>> out1.backward()\n        >>> norm1 = flow.nn.utils.clip_grad_norm_(x1, 0.6, 1.0)\n        >>> norm1\n        tensor(6., dtype=oneflow.float32)\n        >>> x1.grad\n        tensor([[0.1000, 0.1000, 0.1000],\n                [0.1000, 0.1000, 0.1000]], dtype=oneflow.float32)\n        >>> x2 = flow.tensor(np.array([[-2, -3, -4], [2.5, 0, 3.2]]).astype(np.float32), requires_grad=True)\n        >>> out2 = flow.atan(x2)\n        >>> out2 = out2.sum()\n        >>> out2.backward()\n        >>> norm2 = flow.nn.utils.clip_grad_norm_(x2, 0.5)\n        >>> norm2\n        tensor(1.0394, dtype=oneflow.float32)\n        >>> x2.grad\n        tensor([[0.0962, 0.0481, 0.0283],\n                [0.0663, 0.4810, 0.0428]], dtype=oneflow.float32)\n\n    \"\"\"\n\n    if isinstance(parameters, (Tensor, flow._oneflow_internal.Tensor)):\n        parameters = [parameters]\n    parameters = [p for p in parameters if p.grad is not None]\n    max_norm = float(max_norm)\n    norm_type = float(norm_type)\n    if len(parameters) == 0:\n        return flow.tensor(0.0)\n\n    if parameters[0].is_global:\n        assert all(\n            [p.is_global for p in parameters]\n        ), \"All parameters must be global tensor.\"\n        sbp_broadcast = [flow.sbp.broadcast for _ in parameters[0].sbp]\n        param0_placement = parameters[0].placement\n        if norm_type == float(\"inf\"):\n            norms = [\n                p.grad.detach()\n                .to_global(sbp=sbp_broadcast)\n                .abs()\n                .max()\n                .to_global(placement=param0_placement)\n                for p in parameters\n            ]\n            total_norm = norms[0] if len(norms) == 1 else flow.max(flow.stack(norms))\n        elif norm_type == float(\"-inf\"):\n            norms = [\n                p.grad.detach()\n                .to_global(sbp=sbp_broadcast)\n                .abs()\n                .min()\n                .to_global(placement=param0_placement)\n                for p in parameters\n            ]\n            total_norm = norms[0] if len(norms) == 1 else flow.min(flow.stack(norms))\n        else:\n            total_norm = flow.linalg.vector_norm(\n                flow.stack(\n                    [\n                        flow.linalg.vector_norm(\n                            p.grad.detach().to_global(sbp=sbp_broadcast), norm_type\n                        ).to_global(placement=param0_placement)\n                        for p in parameters\n                    ]\n                ),\n                norm_type,\n            )\n        if error_if_nonfinite and flow.logical_or(\n            total_norm.isnan(), total_norm.isinf()\n        ):\n            raise RuntimeError(\n                f\"The total norm of order {norm_type} for gradients from \"\n                \"`parameters` is non-finite, so it cannot be clipped. 
To disable \"\n                \"this error and scale the gradients by the non-finite norm anyway, \"\n                \"set `error_if_nonfinite=False`\"\n            )\n        clip_coef = max_norm / (total_norm + 1e-6)\n        clip_coef_clamped = clip_coef.clamp(max=1.0)\n        for p in parameters:\n            p.grad.detach().mul_(clip_coef_clamped.to_global(placement=p.placement))\n    elif fused and not error_if_nonfinite and all([p.grad.is_cuda for p in parameters]):\n        param_grad_list = []\n        for param in parameters:\n            param_grad_list.append(param.grad)\n        total_norm = flow._C.fused_clip_grad(param_grad_list, max_norm, norm_type,)\n    else:\n        device = parameters[0].grad.device\n        if norm_type == float(\"inf\"):\n            norms = [p.grad.detach().abs().max().to(device) for p in parameters]\n            total_norm = norms[0] if len(norms) == 1 else flow.max(flow.stack(norms))\n        elif norm_type == float(\"-inf\"):\n            norms = [p.grad.detach().abs().min().to(device) for p in parameters]\n            total_norm = norms[0] if len(norms) == 1 else flow.min(flow.stack(norms))\n        else:\n            total_norm = flow.linalg.vector_norm(\n                flow.stack(\n                    [\n                        flow.linalg.vector_norm(p.grad.detach(), norm_type).to(device)\n                        for p in parameters\n                    ]\n                ),\n                norm_type,\n            )\n        if error_if_nonfinite and flow.logical_or(\n            total_norm.isnan(), total_norm.isinf()\n        ):\n            raise RuntimeError(\n                f\"The total norm of order {norm_type} for gradients from \"\n                \"`parameters` is non-finite, so it cannot be clipped. To disable \"\n                \"this error and scale the gradients by the non-finite norm anyway, \"\n                \"set `error_if_nonfinite=False`\"\n            )\n        clip_coef = max_norm / (total_norm + 1e-6)\n        clip_coef_clamped = clip_coef.clamp(max=1.0)\n        for p in parameters:\n            p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device))\n    return total_norm\n\n\ndef clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float) -> None:\n    r\"\"\"Clips gradient of an iterable of parameters at specified value.\n\n    Gradients are modified in-place.\n\n    Args:\n        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a\n            single Tensor that will have gradients normalized\n        clip_value (float or int): maximum allowed value of the gradients.\n            The gradients are clipped in the range\n            :math:`\\left[\\text{-clip\\_value}, \\text{clip\\_value}\\right]`\n    \"\"\"\n    if isinstance(parameters, flow.Tensor):\n        parameters = [parameters]\n    clip_value = float(clip_value)\n    for p in filter(lambda p: p.grad is not None, parameters):\n        # TODO: Switch to inplace clamp function\n        p.grad[:] = p.grad.clamp(min=-clip_value, max=clip_value)\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
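The file above gives a doctest for ``clip_grad_norm_`` but none for ``clip_grad_value_``. The following is a minimal usage sketch (not part of the file; the toy tensors are illustrative), using only the public ``oneflow.nn.utils`` API shown above:

.. code-block:: python

    import numpy as np
    import oneflow as flow

    # A toy leaf tensor with a known gradient.
    p = flow.tensor(np.array([1.0, -2.0, 3.0], dtype=np.float32), requires_grad=True)
    loss = (p * flow.tensor([5.0, -5.0, 0.1])).sum()
    loss.backward()

    # Clamp every gradient entry into [-1.0, 1.0] in place.
    # Unlike clip_grad_norm_, clip_grad_value_ returns None.
    flow.nn.utils.clip_grad_value_(p, clip_value=1.0)
    assert float(p.grad.max()) <= 1.0 and float(p.grad.min()) >= -1.0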
  {
    "path": "python/oneflow/nn/utils/container.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport collections.abc\nimport warnings\nimport operator\nfrom collections import OrderedDict, abc as container_abcs\nfrom itertools import islice\nfrom typing import (\n    Any,\n    Iterable,\n    Iterator,\n    Mapping,\n    Optional,\n    Tuple,\n    TypeVar,\n    Union,\n    overload,\n    Generic,\n)\n\nimport oneflow as flow\nfrom oneflow.nn.modules.module import Module\n\nT = TypeVar(\"T\")\n\n\ndef get_seq(T):\n    class SequentialContainer(T):\n        @overload\n        def __init__(self, *args: T) -> None:\n            ...\n\n        @overload\n        def __init__(self, arg: \"OrderedDict[str, T]\") -> None:\n            ...\n\n        def __init__(self, *args: Any):\n            super(SequentialContainer, self).__init__()\n            if len(args) == 1 and isinstance(args[0], OrderedDict):\n                for (key, module) in args[0].items():\n                    self.add_module(key, module)\n            else:\n                for (idx, module) in enumerate(args):\n                    self.add_module(str(idx), module)\n\n        def _get_item_by_idx(self, iterator, idx):\n            \"\"\"Get the idx-th item of the iterator\"\"\"\n            size = len(self)\n            idx = operator.index(idx)\n            if not -size <= idx < size:\n                raise IndexError(\"index {} is out of range\".format(idx))\n            idx %= size\n            return next(islice(iterator, idx, None))\n\n        def __getitem__(self: T, idx) -> T:\n            if isinstance(idx, slice):\n                return self.__class__(OrderedDict(list(self._modules.items())[idx]))\n            else:\n                return self._get_item_by_idx(self._modules.values(), idx)\n\n        def __setitem__(self, idx: int, module: T) -> None:\n            key = self._get_item_by_idx(self._modules.keys(), idx)\n            return setattr(self, key, module)\n\n        def __delitem__(self, idx: Union[slice, int]) -> None:\n            if isinstance(idx, slice):\n                for key in list(self._modules.keys())[idx]:\n                    delattr(self, key)\n            else:\n                key = self._get_item_by_idx(self._modules.keys(), idx)\n                delattr(self, key)\n\n        def __len__(self) -> int:\n            return len(self._modules)\n\n        def __dir__(self):\n            keys = super(SequentialContainer, self).__dir__()\n            keys = [key for key in keys if not key.isdigit()]\n            return keys\n\n        def __iter__(self) -> Iterator[T]:\n            return iter(self._modules.values())\n\n        def forward(self, input):\n            for module in self:\n                input = module(input)\n            return input\n\n    return SequentialContainer\n\n\ndef get_list(T):\n    class ListContainer(T):\n        def __init__(self, modules: Optional[Iterable[T]] = None) -> None:\n            super(ListContainer, self).__init__()\n            if 
modules is not None:\n                self += modules\n\n        def _get_abs_string_index(self, idx):\n            \"\"\"Get the absolute index for the list of modules\"\"\"\n            idx = operator.index(idx)\n            if not -len(self) <= idx < len(self):\n                raise IndexError(\"index {} is out of range\".format(idx))\n            if idx < 0:\n                idx += len(self)\n            return str(idx)\n\n        def __getitem__(self, idx: int) -> T:\n            if isinstance(idx, slice):\n                return self.__class__(list(self._modules.values())[idx])\n            else:\n                return self._modules[self._get_abs_string_index(idx)]\n\n        def __setitem__(self, idx: int, module: T) -> None:\n            idx = self._get_abs_string_index(idx)\n            return setattr(self, str(idx), module)\n\n        def __delitem__(self, idx: Union[int, slice]) -> None:\n            if isinstance(idx, slice):\n                for k in range(len(self._modules))[idx]:\n                    delattr(self, str(k))\n            else:\n                delattr(self, self._get_abs_string_index(idx))\n            str_indices = [str(i) for i in range(len(self._modules))]\n            self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))\n\n        def __len__(self) -> int:\n            return len(self._modules)\n\n        def __iter__(self) -> Iterator[T]:\n            return iter(self._modules.values())\n\n        def __iadd__(self: T, modules: Iterable[T]) -> T:\n            return self.extend(modules)\n\n        def __dir__(self):\n            keys = super(ListContainer, self).__dir__()\n            keys = [key for key in keys if not key.isdigit()]\n            return keys\n\n        def insert(self, index: int, module: T) -> None:\n            \"\"\"Insert a given module before a given index in the list.\n    \n            Arguments:\n                index (int): index to insert.\n                module (nn.Module): module to insert\n            \"\"\"\n            for i in range(len(self._modules), index, -1):\n                self._modules[str(i)] = self._modules[str(i - 1)]\n            self._modules[str(index)] = module\n\n        def append(self: T, module: T) -> T:\n            \"\"\"Appends a given module to the end of the list.\n    \n            Arguments:\n                module (nn.Module): module to append\n            \"\"\"\n            self.add_module(str(len(self)), module)\n            return self\n\n        def extend(self: T, modules: Iterable[T]) -> T:\n            \"\"\"Appends modules from a Python iterable to the end of the list.\n    \n            Arguments:\n                modules (iterable): iterable of modules to append\n            \"\"\"\n            if not isinstance(modules, collections.abc.Iterable):\n                raise TypeError(\n                    \"ModuleList.extend should be called with an iterable, but got \"\n                    + type(modules).__name__\n                )\n            offset = len(self)\n            for (i, module) in enumerate(modules):\n                self.add_module(str(offset + i), module)\n            return self\n\n        def forward(self):\n            raise NotImplementedError()\n\n    return ListContainer\n\n\ndef get_dict(T):\n    class DictContainer(T):\n        def __init__(self, modules: Optional[Mapping[str, T]] = None) -> None:\n            super(DictContainer, self).__init__()\n            if modules is not None:\n                self.update(modules)\n\n        def 
__getitem__(self, key: str) -> T:\n            return self._modules[key]\n\n        def __setitem__(self, key: str, module: T) -> None:\n            self.add_module(key, module)\n\n        def __delitem__(self, key: str) -> None:\n            del self._modules[key]\n\n        def __len__(self) -> int:\n            return len(self._modules)\n\n        def __iter__(self) -> Iterator[str]:\n            return iter(self._modules)\n\n        def __contains__(self, key: str) -> bool:\n            return key in self._modules\n\n        def clear(self) -> None:\n            \"\"\"Remove all items from the ModuleDict.\n            \"\"\"\n            self._modules.clear()\n\n        def pop(self, key: str) -> T:\n            \"\"\"Remove key from the ModuleDict and return its module.\n    \n            Arguments:\n                key (string): key to pop from the ModuleDict\n            \"\"\"\n            v = self[key]\n            del self[key]\n            return v\n\n        def keys(self) -> Iterable[str]:\n            \"\"\"Return an iterable of the ModuleDict keys.\n            \"\"\"\n            return self._modules.keys()\n\n        def items(self) -> Iterable[Tuple[str, T]]:\n            \"\"\"Return an iterable of the ModuleDict key/value pairs.\n            \"\"\"\n            return self._modules.items()\n\n        def values(self) -> Iterable[T]:\n            \"\"\"Return an iterable of the ModuleDict values.\n            \"\"\"\n            return self._modules.values()\n\n        def update(self, modules: Mapping[str, T]) -> None:\n            if not isinstance(modules, collections.abc.Iterable):\n                raise TypeError(\n                    \"ModuleDict.update should be called with an iterable of key/value pairs, but got \"\n                    + type(modules).__name__\n                )\n            if isinstance(modules, (OrderedDict, T, collections.abc.Mapping)):\n                for (key, module) in modules.items():\n                    self[key] = module\n            else:\n                for (j, m) in enumerate(modules):\n                    if not isinstance(m, collections.abc.Iterable):\n                        raise TypeError(\n                            \"ModuleDict update sequence element #\"\n                            + str(j)\n                            + \" should be Iterable; is\"\n                            + type(m).__name__\n                        )\n                    if not len(m) == 2:\n                        raise ValueError(\n                            \"ModuleDict update sequence element #\"\n                            + str(j)\n                            + \" has length \"\n                            + str(len(m))\n                            + \"; 2 is required\"\n                        )\n                    self[m[0]] = m[1]\n\n    return DictContainer\n\n\ndef get_para_list(T):\n    class ParameterListContainer(T):\n        def __init__(self, parameters=None) -> None:\n            super(ParameterListContainer, self).__init__()\n            self._initialized = True\n            if parameters is not None:\n                self += parameters\n\n        def __setstate__(self, state):\n            state[\"_initialized\"] = False\n            super(ParameterListContainer, self).__setstate__(state)\n            self._initialized = True\n\n        def _get_abs_string_index(self, idx):\n            \"\"\"Get the absolute index for the list of modules\"\"\"\n            idx = operator.index(idx)\n            if not -len(self) <= idx < 
len(self):\n                raise IndexError(\"index {} is out of range\".format(idx))\n            if idx < 0:\n                idx += len(self)\n            return str(idx)\n\n        @overload\n        def __getitem__(self, idx: int):\n            ...\n\n        @overload\n        def __getitem__(self: T, idx: slice) -> T:\n            ...\n\n        def __getitem__(self, idx):\n            if isinstance(idx, slice):\n                return self.__class__(list(self._parameters.values())[idx])\n            else:\n                idx = self._get_abs_string_index(idx)\n                return self._parameters[str(idx)]\n\n        def __setitem__(self, idx: int, param) -> None:\n            idx = self._get_abs_string_index(idx)\n            return self.register_parameter(str(idx), param)\n\n        def __len__(self) -> int:\n            return len(self._parameters)\n\n        def __iter__(self):\n            return iter(self._parameters.values())\n\n        def __iadd__(self, parameters):\n            return self.extend(parameters)\n\n        def __dir__(self):\n            keys = super(ParameterListContainer, self).__dir__()\n            keys = [key for key in keys if not key.isdigit()]\n            return keys\n\n        def append(self: T, parameter) -> T:\n            \"\"\"Appends a given parameter at the end of the list.\n    \n            Arguments:\n    \n                parameter (nn.Parameter): parameter to append\n            \"\"\"\n            self.register_parameter(str(len(self)), parameter)\n            return self\n\n        def extend(self: T, parameters) -> T:\n            \"\"\"Appends parameters from a Python iterable to the end of the list.\n    \n            Arguments:\n    \n                parameters (iterable): iterable of parameters to append\n            \"\"\"\n            if not isinstance(parameters, collections.abc.Iterable):\n                raise TypeError(\n                    \"ParameterList.extend should be called with an iterable, but got \"\n                    + type(parameters).__name__\n                )\n            offset = len(self)\n            for (i, param) in enumerate(parameters):\n                self.register_parameter(str(offset + i), param)\n            return self\n\n        def extra_repr(self) -> str:\n            child_lines = []\n            for (k, p) in self._parameters.items():\n                size_str = \"x\".join((str(size) for size in p.size()))\n                device_str = \"\" if not p.is_cuda else \" (GPU {})\".format(p.get_device())\n                parastr = \"Parameter containing: [{} of size {}{}]\".format(\n                    type(p), size_str, device_str\n                )\n                child_lines.append(\"  (\" + str(k) + \"): \" + parastr)\n            tmpstr = \"\\n\".join(child_lines)\n            return tmpstr\n\n        def __call__(self, input):\n            raise RuntimeError(\"ParameterList should not be called.\")\n\n        def _replicate_for_data_parallel(self):\n            warnings.warn(\n                \"nn.ParameterList is being used with DataParallel but this is not supported. 
This list will appear empty for the models replicated on each GPU except the original one.\"\n            )\n            return super(ParameterListContainer, self)._replicate_for_data_parallel()\n\n    return ParameterListContainer\n\n\ndef get_para_dict(T):\n    class ParameterDictContainer(T):\n        def __init__(self, parameters=None) -> None:\n            super(ParameterDictContainer, self).__init__()\n            self._initialized = True\n            if parameters is not None:\n                self.update(parameters)\n\n        def __setstate__(self, state):\n            state[\"_initialized\"] = False\n            super(ParameterDictContainer, self).__setstate__(state)\n            self._initialized = True\n\n        def __getitem__(self, key: str):\n            return self._parameters[key]\n\n        def __setitem__(self, key: str, parameter) -> None:\n            self.register_parameter(key, parameter)\n\n        def __delitem__(self, key: str) -> None:\n            del self._parameters[key]\n\n        def __len__(self) -> int:\n            return len(self._parameters)\n\n        def __iter__(self) -> Iterator[str]:\n            return iter(self._parameters.keys())\n\n        def __contains__(self, key: str) -> bool:\n            return key in self._parameters\n\n        def clear(self) -> None:\n            \"\"\"Remove all items from the ParameterDict.\n            \"\"\"\n            self._parameters.clear()\n\n        def pop(self, key: str):\n            r\"\"\"Remove key from the ParameterDict and return its parameter.\n    \n            Args:\n    \n                key (string): key to pop from the ParameterDict\n            \"\"\"\n            v = self[key]\n            del self[key]\n            return v\n\n        def keys(self) -> Iterable[str]:\n            r\"\"\"Return an iterable of the ParameterDict keys.\n            \"\"\"\n            return self._parameters.keys()\n\n        def items(self):\n            r\"\"\"Return an iterable of the ParameterDict key/value pairs.\n            \"\"\"\n            return self._parameters.items()\n\n        def values(self):\n            r\"\"\"Return an iterable of the ParameterDict values.\n            \"\"\"\n            return self._parameters.values()\n\n        def update(self, parameters) -> None:\n            r\"\"\"Update the :class:`~flow.nn.ParameterDict` with the key-value pairs from a\n            mapping or an iterable, overwriting existing keys.\n    \n            .. 
note::\n                If :attr:`parameters` is an ``OrderedDict``, a :class:`~flow.nn.ParameterDict`, or\n                an iterable of key-value pairs, the order of new elements in it is preserved.\n\n            Args:\n                parameters (iterable): a mapping (dictionary) from string to\n                    :class:`~flow.nn.Parameter`, or an iterable of\n                    key-value pairs of type (string, :class:`~flow.nn.Parameter`)\n\n            \"\"\"\n            if not isinstance(parameters, container_abcs.Iterable):\n                raise TypeError(\n                    \"ParameterDict.update should be called with an \"\n                    \"iterable of key/value pairs, but got \" + type(parameters).__name__\n                )\n\n            if isinstance(parameters, (OrderedDict, ParameterDictContainer)):\n                for key, parameter in parameters.items():\n                    self[key] = parameter\n            elif isinstance(parameters, container_abcs.Mapping):\n                for key, parameter in sorted(parameters.items()):\n                    self[key] = parameter\n            else:\n                for j, p in enumerate(parameters):\n                    if not isinstance(p, container_abcs.Iterable):\n                        raise TypeError(\n                            \"ParameterDict update sequence element \"\n                            \"#\" + str(j) + \" should be Iterable; is \" + type(p).__name__\n                        )\n                    if not len(p) == 2:\n                        raise ValueError(\n                            \"ParameterDict update sequence element \"\n                            \"#\"\n                            + str(j)\n                            + \" has length \"\n                            + str(len(p))\n                            + \"; 2 is required\"\n                        )\n                    # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment\n                    self[p[0]] = p[1]  # type: ignore[assignment]\n\n        def extra_repr(self) -> str:\n            child_lines = []\n            for k, p in self._parameters.items():\n                size_str = \"x\".join(str(size) for size in p.size())\n                device_str = \"\" if not p.is_cuda else \" (GPU {})\".format(p.get_device())\n                parastr = \"Parameter containing: [{} of size {}{}]\".format(\n                    type(p), size_str, device_str\n                )\n                child_lines.append(\"  (\" + k + \"): \" + parastr)\n            tmpstr = \"\\n\".join(child_lines)\n            return tmpstr\n\n        def __call__(self, input):\n            raise RuntimeError(\"ParameterDict should not be called.\")\n\n        def _replicate_for_data_parallel(self):\n            warnings.warn(\n                \"nn.ParameterDict is being used with DataParallel but this is not \"\n                \"supported. This dict will appear empty for the models replicated \"\n                \"on each GPU except the original one.\"\n            )\n\n            return super(ParameterDictContainer, self)._replicate_for_data_parallel()\n\n    return ParameterDictContainer\n"
  },
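The container factories above (``get_seq``, ``get_list``, ``get_dict``, ``get_para_list``, ``get_para_dict``) parameterize container behavior over a base class instead of hard-coding ``nn.Module``. A minimal sketch of applying ``get_seq`` directly, assuming ``oneflow.nn.Module`` as the base class (the packaged ``nn.Sequential`` is presumably wired up this way elsewhere in the package):

.. code-block:: python

    import oneflow as flow
    import oneflow.nn as nn
    from oneflow.nn.utils.container import get_seq

    # Build a Sequential-style class over the Module base, then use it as usual.
    Seq = get_seq(nn.Module)
    model = Seq(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

    y = model(flow.randn(1, 4))  # sub-modules run in insertion order
    assert y.shape == (1, 2)

    head = model[0:2]  # slicing returns a new container of the same class
    assert len(head) == 2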
  {
    "path": "python/oneflow/nn/utils/convert_parameters.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom typing import Iterable, Optional\nfrom oneflow.framework.tensor import Tensor\n\n\ndef parameters_to_vector(parameters: Iterable[Tensor]) -> Tensor:\n    r\"\"\"Convert parameters to one vector\n    \n    The method is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/stable/generated/torch.nn.utils.parameters_to_vector.\n\n    Args:\n        parameters (Iterable[Tensor]): an iterator of Tensors that are the\n            parameters of a model.\n    Returns:\n        The parameters represented by a single vector\n    \"\"\"\n    # Flag for the device where the parameter is located\n    param_device = None\n\n    vec = []\n    for param in parameters:\n        # Ensure the parameters are located in the same device\n        param_device = _check_param_device(param, param_device)\n\n        vec.append(param.view(-1))\n    return flow.cat(vec)\n\n\ndef vector_to_parameters(vec: Tensor, parameters: Iterable[Tensor]) -> None:\n    r\"\"\"Convert one vector to the parameters\n\n    The method is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/stable/generated/torch.nn.utils.vector_to_parameters.\n\n    Args:\n        vec (Tensor): a single vector represents the parameters of a model.\n        parameters (Iterable[Tensor]): an iterator of Tensors that are the\n            parameters of a model.\n    \"\"\"\n    # Ensure vec of type Tensor\n    if not isinstance(vec, Tensor):\n        raise TypeError(\"expected flow.Tensor, but got: {}\".format(flow.typename(vec)))\n    # Flag for the device where the parameter is located\n    param_device = None\n\n    # Pointer for slicing the vector for each parameter\n    pointer = 0\n    for param in parameters:\n        # Ensure the parameters are located in the same device\n        param_device = _check_param_device(param, param_device)\n\n        # The length of the parameter\n        num_param = param.numel()\n        # Slice the vector, reshape it, and replace the old data of the parameter\n        param.data = vec[pointer : pointer + num_param].view_as(param).data\n\n        # Increment the pointer\n        pointer += num_param\n\n\ndef _check_param_device(param: Tensor, old_param_device: Optional[int]) -> int:\n    r\"\"\"This helper function is to check if the parameters are located\n    in the same device. Currently, the conversion between model parameters\n    and single vector form is not supported for multiple allocations,\n    e.g. 
parameters in different GPUs, or mixture of CPU/GPU.\n\n    The method is consistent with PyTorch.\n    The documentation is referenced from:\n    https://pytorch.org/docs/1.10/nn.html#utilities.\n\n    Args:\n        param ([Tensor]): a Tensor of a parameter of a model\n        old_param_device (int): the device where the first parameter of a\n                                model is allocated.\n    Returns:\n        old_param_device (int): report device for the first time\n    \"\"\"\n\n    # Meet the first parameter\n    if old_param_device is None:\n        old_param_device = param.get_device() if param.is_cuda else -1\n    else:\n        warn = False\n        if param.is_cuda:  # Check if in same GPU\n            warn = param.get_device() != old_param_device\n        else:  # Check if in CPU\n            warn = old_param_device != -1\n        if warn:\n            raise TypeError(\n                \"Found two parameters on different devices, \"\n                \"this is currently not supported.\"\n            )\n    return old_param_device\n"
  },
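A short round-trip sketch for the two helpers above (not part of the file); ``nn.Linear`` is used only as a convenient source of same-device parameters:

.. code-block:: python

    import oneflow.nn as nn
    from oneflow.nn.utils import parameters_to_vector, vector_to_parameters

    model = nn.Linear(3, 2)

    # Flatten every parameter into one 1-D tensor, in iteration order:
    # weight (2x3) first, then bias (2,).
    vec = parameters_to_vector(model.parameters())
    assert vec.numel() == 3 * 2 + 2

    # Writing a vector back slices it by each parameter's numel() and
    # reshapes each slice to the parameter's original shape.
    vector_to_parameters(vec * 0.5, model.parameters())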
  {
    "path": "python/oneflow/nn/utils/parameters_grouping.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport collections\nimport warnings\nfrom typing import Union, List\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\n\n\n_tensor_or_tensors = Union[Tensor, List[Tensor], List[List[Tensor]]]\n\n\ndef numel_in_bucket(tensor: Tensor):\n    assert flow.is_floating_point(tensor), \"params grouping only support float tensor.\"\n\n    def align(x: int, unit_size: int):\n        return (x + (unit_size - 1)) // unit_size * unit_size\n\n    # tensor memory should be align to 512 bytes for cuda operations,\n    # align size depends on floating type\n    return align(\n        tensor.numel(),\n        flow._oneflow_internal.max_alignment_size()\n        // (flow.finfo(tensor.dtype).bits // 8),\n    )\n\n\nclass ContiguousParamsGroup(object):\n    \"\"\"Arange tensors into contiguous buffer according to their group.\n\n    Args:\n        params_group_list (Iterable[Tensor] or Tensor): an iterable of Tensors or\n            a single Tensor that will be made into buffers.\n        group_on_current_buffer (bool, optional): whether to group tensors on allocated\n            buffers. (default: True)\n\n    Note:\n        The ContiguousParamsGroup is created by 2D List of Tensors, which indicates the\n        Tensors in the same 1D List should be grouped into the same Tensor buffer, otherwise\n        try to make them into a 2D List.\n\n        If group_on_current_buffer is set True but there is not any buffer created before,\n        ContiguousParamsGroup will allocate default buffers for all parameters.\n    \"\"\"\n\n    def __init__(\n        self,\n        params_group_list: _tensor_or_tensors,\n        group_on_current_buffer: bool = True,\n    ):\n        self.params_group_list = params_group_list.copy()\n\n        self._make_valid_params_group_list()\n        self._remove_no_grad_tensors()\n        self._check_tensor_position_consistency()\n\n        self.group_on_current_buffer = group_on_current_buffer\n        self.grouped_tensors = []\n        self.grouped_grads = []\n\n        if not self.group_on_current_buffer:\n            self._parameters_grouping_on_new_buffer()\n        else:\n            self._check_current_buffer()\n            self._parameters_grouping_on_current_buffer()\n\n    def _make_valid_params_group_list(self):\n        \"\"\"making params_group_list 2D List of Tensors\n        \"\"\"\n        if isinstance(self.params_group_list, Tensor):\n            warnings.warn(\"Single tensor is best not do grouping.\")\n            self.params_group_list = [[self.params_group_list]]\n        elif all([isinstance(p, Tensor) for p in self.params_group_list]):\n            self.params_group_list = [self.params_group_list]\n        elif all(\n            [\n                all([isinstance(p, Tensor) for p in params])\n                for params in self.params_group_list\n            ]\n        ):\n            pass\n        else:\n            raise 
ValueError(\"The shape of params_group_list is illegal!\")\n\n    def _remove_no_grad_tensors(self):\n        self.params_group_list = [\n            [p for p in params if p.requires_grad] for params in self.params_group_list\n        ]\n\n    def _check_tensor_position_consistency(self):\n        if all(\n            [all([p.is_global for p in params]) for params in self.params_group_list]\n        ):\n            self.is_global = True\n        elif all(\n            [all([p.is_local for p in params]) for params in self.params_group_list]\n        ):\n            self.is_global = False\n        else:\n            raise ValueError(\n                \"Parameters must be all local tensors or all global tensors for params grouping.\"\n            )\n\n    def _check_current_buffer(self):\n        \"\"\"If all tensors are not held by any buffer, try to create buffer.\n        \"\"\"\n        for params in self.params_group_list:\n            for p in params:\n                if p._ref_tensor is not None:\n                    return\n\n        warnings.warn(\"create defualt buffer for all parameters as one group.\")\n\n        self._physical_preparation = ContiguousParamsGroup(\n            self.params_group_list, group_on_current_buffer=False,\n        )\n\n    def _make_buffer_params_mapping(self):\n        buffer_params_mapping = collections.defaultdict(list)\n\n        for params in self.params_group_list:\n            for p in params:\n                if p._ref_tensor is not None:\n                    assert (\n                        p._ref_index < p._ref_tensor.numel()\n                    ), \"invalid ref tensor index.\"\n                    buffer_params_mapping[p._ref_tensor].append((p._ref_index, p))\n\n        for buffer, params_list in buffer_params_mapping.items():\n            buffer_params_mapping[buffer] = sorted(params_list, key=lambda x: x[0])\n\n        return buffer_params_mapping\n\n    def _parameters_grouping_on_new_buffer(self):\n        # Use the group in params_group_list to create default buffer.\n        # A buffer that is too large will affect the parallelism of different parameters.\n\n        params_buffer_size = {}\n        physical_params_buffer = {}\n        params_buffer_index = {}\n\n        for idx, params in enumerate(self.params_group_list):\n            for p in params:\n                if self.is_global:\n                    tensor_key = (p.dtype, p.placement, p.sbp, idx)\n                else:\n                    tensor_key = (p.dtype, p.device, idx)\n\n                params_buffer_size[tensor_key] = params_buffer_size.get(\n                    tensor_key, 0\n                ) + numel_in_bucket(p)\n\n        for tensor_key, buffer_size in params_buffer_size.items():\n            dtype = tensor_key[0]\n\n            if self.is_global:\n                placement = tensor_key[1]\n                sbp = tensor_key[2]\n                physical_param_buf = flow.zeros(\n                    buffer_size, dtype=dtype, placement=placement, sbp=sbp\n                )\n                physical_param_buf.grad = flow.zeros(\n                    buffer_size, dtype=dtype, placement=placement, sbp=sbp\n                )\n            else:\n                device = tensor_key[1]\n                physical_param_buf = flow.zeros(buffer_size, dtype=dtype, device=device)\n                physical_param_buf.grad = flow.zeros(\n                    buffer_size, dtype=dtype, device=device\n                )\n\n            self.grouped_tensors.append(physical_param_buf)\n        
    self.grouped_grads.append(physical_param_buf.grad)\n            physical_params_buffer[tensor_key] = physical_param_buf\n            params_buffer_index[tensor_key] = 0\n\n        for idx, params in enumerate(self.params_group_list):\n            for p in params:\n                if self.is_global:\n                    tensor_key = (p.dtype, p.placement, p.sbp, idx)\n                else:\n                    tensor_key = (p.dtype, p.device, idx)\n\n                param_buf = physical_params_buffer[tensor_key]\n                index = params_buffer_index[tensor_key]\n                size = p.numel()\n                shape = p.data.shape\n\n                assert index + numel_in_bucket(p) <= param_buf.numel()\n\n                param_buf[index : index + size] = p.data.detach().clone().view(-1)\n                p.data = param_buf[index : index + size].view(shape)\n                p.grad = param_buf.grad[index : index + size].view(shape)\n\n                p._ref_tensor = param_buf\n                p._ref_index = index\n\n                index += numel_in_bucket(p)\n                params_buffer_index[tensor_key] = index\n\n    def _parameters_grouping_on_current_buffer(self):\n        buffer_params_mapping = self._make_buffer_params_mapping()\n\n        if buffer_params_mapping is None or len(buffer_params_mapping) == 0:\n            warnings.warn(\n                \"Since nn.Module didn't use make_contiguous_params_group() to create \"\n                \"a contiguous module, the remapping won't make any difference for parameters. \"\n            )\n\n        params_group = []\n        for params in self.params_group_list:\n            group = set()\n            for p in params:\n                group.add(p)\n            params_group.append(group)\n\n        # handling the parameters already on allocated buffers\n        # try best to make the adjacent tensors on device into same logical buffer\n        for param_buf, params in buffer_params_mapping.items():\n            logical_buffer_start, logical_buffer_size = 0, 0\n            pre_group_index = -1\n            params_cnt = len(params)\n\n            for p_index, (_, p) in enumerate(params):\n                current_group_index = -1\n\n                for group_index, group in enumerate(params_group):\n                    if p in group:\n                        current_group_index = group_index\n                        break\n\n                if current_group_index == -1:\n                    continue\n\n                params_group[current_group_index].remove(p)\n\n                def _make_logical_buf():\n                    nonlocal logical_buffer_start, logical_buffer_size\n                    nonlocal pre_group_index, current_group_index\n\n                    pre_group_index = current_group_index\n\n                    if logical_buffer_size == 0:\n                        return\n\n                    logical_param_buf = param_buf[\n                        logical_buffer_start : logical_buffer_start\n                        + logical_buffer_size\n                    ].view(logical_buffer_size)\n                    logical_param_grad_buf = param_buf.grad[\n                        logical_buffer_start : logical_buffer_start\n                        + logical_buffer_size\n                    ].view(logical_buffer_size)\n                    logical_param_buf.grad = logical_param_grad_buf\n\n                    self.grouped_tensors.append(logical_param_buf)\n                    
self.grouped_grads.append(logical_param_grad_buf)\n\n                    logical_buffer_start += logical_buffer_size\n                    logical_buffer_size = 0\n\n                if current_group_index != pre_group_index:\n                    _make_logical_buf()\n\n                logical_buffer_size += numel_in_bucket(p)\n\n                if p_index == params_cnt - 1:\n                    _make_logical_buf()\n\n        # handling params not on any buffer\n        # however, we don't make new tensors into contiguous buffer this time\n        for group in params_group:\n            for p in group:\n                self.grouped_tensors.append(p)\n                self.grouped_grads.append(p.grad)\n\n    @property\n    def grouped_parameters(self):\n        \"\"\"the grouped contiguous parameters\n        \"\"\"\n        return self.grouped_tensors\n\n    @property\n    def grouped_parameters_grad(self):\n        \"\"\"the grouped contiguous parameters' gradient\n        \"\"\"\n        return self.grouped_grads\n"
  },
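A minimal sketch (not part of the file) of the new-buffer path of ``ContiguousParamsGroup``: with ``group_on_current_buffer=False``, every inner list is packed into one freshly allocated flat buffer per ``(dtype, device)`` key (or ``(dtype, placement, sbp)`` for global tensors), each parameter padded to its ``numel_in_bucket`` size; the model here is illustrative:

.. code-block:: python

    import oneflow.nn as nn
    from oneflow.nn.utils.parameters_grouping import ContiguousParamsGroup

    model = nn.Linear(8, 8)

    # One inner list -> (at most) one contiguous buffer per dtype/device key.
    cpg = ContiguousParamsGroup(
        [list(model.parameters())], group_on_current_buffer=False
    )

    # Parameter .data and .grad now alias slices of these flat buffers.
    for buf, grad_buf in zip(cpg.grouped_parameters, cpg.grouped_parameters_grad):
        assert buf.numel() == grad_buf.numel()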
  {
    "path": "python/oneflow/nn/utils/prune.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nr\"\"\"\nPrune Methods are consistent with PyTorch.\nThe documentation is referenced from:\nhttps://pytorch.org/docs/stable/nn.html#module-torch.nn.utils.\n\"\"\"\nimport numbers\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Iterable\nfrom typing import Tuple\nimport numpy as np\nimport oneflow as flow\n\n\nclass BasePruningMethod(ABC):\n    r\"\"\"Abstract base class for creation of new pruning techniques.\n    Provides a skeleton for customization requiring the overriding of methods\n    such as :meth:`compute_mask` and :meth:`apply`.\n    \"\"\"\n    _tensor_name: str\n\n    def __init__(self):\n        pass\n\n    def __call__(self, module, inputs):\n        r\"\"\"Multiplies the mask (stored in ``module[name + '_mask']``)\n        into the original tensor (stored in ``module[name + '_orig']``)\n        and stores the result into ``module[name]`` by using\n        :meth:`apply_mask`.\n        Args:\n            module (nn.Module): module containing the tensor to prune\n            inputs: not used.\n        \"\"\"\n        setattr(module, self._tensor_name, self.apply_mask(module))\n\n    @abstractmethod\n    def compute_mask(self, t, default_mask):\n        r\"\"\"Computes and returns a mask for the input tensor ``t``.\n        Starting from a base ``default_mask`` (which should be a mask of ones\n        if the tensor has not been pruned yet), generate a random mask to\n        apply on top of the ``default_mask`` according to the specific pruning\n        method recipe.\n        Args:\n            t (flow.Tensor): tensor representing the importance scores of the\n            parameter to prune.\n            default_mask (flow.Tensor): Base mask from previous pruning\n            iterations, that need to be respected after the new mask is\n            applied. 
Same dims as ``t``.\n        Returns:\n            mask (flow.Tensor): mask to apply to ``t``, of same dims as ``t``\n        \"\"\"\n        pass\n\n    def apply_mask(self, module):\n        r\"\"\"Simply handles the multiplication between the parameter being\n        pruned and the generated mask.\n        Fetches the mask and the original tensor from the module\n        and returns the pruned version of the tensor.\n        Args:\n            module (nn.Module): module containing the tensor to prune\n        Returns:\n            pruned_tensor (flow.Tensor): pruned version of the input tensor\n        \"\"\"\n        # to carry out the multiplication, the mask needs to have been computed,\n        # so the pruning method must know what tensor it's operating on\n        assert self._tensor_name is not None, \"Module {} has to be pruned\".format(\n            module\n        )  # this gets set in apply()\n        mask = getattr(module, self._tensor_name + \"_mask\")\n        orig = getattr(module, self._tensor_name + \"_orig\")\n        pruned_tensor = mask.to(dtype=orig.dtype) * orig\n        return pruned_tensor\n\n    @classmethod\n    def apply(cls, module, name, *args, importance_scores=None, **kwargs):\n        r\"\"\"Adds the forward pre-hook that enables pruning on the fly and\n        the reparametrization of a tensor in terms of the original tensor\n        and the pruning mask.\n        Args:\n            module (nn.Module): module containing the tensor to prune\n            name (str): parameter name within ``module`` on which pruning\n                will act.\n            args: arguments passed on to a subclass of\n                :class:`BasePruningMethod`\n            importance_scores (flow.Tensor): tensor of importance scores (of\n                same shape as module parameter) used to compute mask for pruning.\n                The values in this tensor indicate the importance of the\n                corresponding elements in the parameter being pruned.\n                If unspecified or None, the parameter will be used in its place.\n            kwargs: keyword arguments passed on to a subclass of a\n                :class:`BasePruningMethod`\n        \"\"\"\n\n        def _get_composite_method(cls, module, name, *args, **kwargs):\n            # Check if a pruning method has already been applied to\n            # `module[name]`. If so, store that in `old_method`.\n            old_method = None\n            found = 0\n            # there should technically be only 1 hook with hook.name == name\n            # assert this using `found`\n            hooks_to_remove = []\n            for k, hook in module._forward_pre_hooks.items():\n                # if it exists, take existing thing, remove hook, then\n                # go through normal thing\n                if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:\n                    old_method = hook\n                    hooks_to_remove.append(k)\n                    found += 1\n            assert (\n                found <= 1\n            ), \"Avoid adding multiple pruning hooks to the\\\n                same tensor {} of module {}. 
Use a PruningContainer.\".format(\n                name, module\n            )\n\n            for k in hooks_to_remove:\n                del module._forward_pre_hooks[k]\n\n            # Apply the new pruning method, either from scratch or on top of\n            # the previous one.\n            method = cls(*args, **kwargs)  # new pruning\n            # Have the pruning method remember what tensor it's been applied to\n            method._tensor_name = name\n\n            # combine `methods` with `old_method`, if `old_method` exists\n            if old_method is not None:  # meaning that there was a hook\n                # if the hook is already a pruning container, just add the\n                # new pruning method to the container\n                if isinstance(old_method, PruningContainer):\n                    old_method.add_pruning_method(method)\n                    method = old_method  # rename old_method --> method\n\n                # if the hook is simply a single pruning method, create a\n                # container, add the old pruning method and the new one\n                elif isinstance(old_method, BasePruningMethod):\n                    container = PruningContainer(old_method)\n                    # Have the pruning method remember the name of its tensor\n                    # setattr(container, '_tensor_name', name)\n                    container.add_pruning_method(method)\n                    method = container  # rename container --> method\n            return method\n\n        method = _get_composite_method(cls, module, name, *args, **kwargs)\n        # at this point we have no forward_pre_hooks but we could have an\n        # active reparametrization of the tensor if another pruning method\n        # had been applied (in which case `method` would be a PruningContainer\n        # and not a simple pruning method).\n\n        # Pruning is to be applied to the module's tensor named `name`,\n        # starting from the state it is found in prior to this iteration of\n        # pruning. 
The pruning mask is calculated based on importance scores.\n\n        orig = getattr(module, name)\n        if importance_scores is not None:\n            assert (\n                importance_scores.shape == orig.shape\n            ), \"importance_scores should have the same shape as parameter \\\n                {} of {}\".format(\n                name, module\n            )\n        else:\n            importance_scores = orig\n\n        # If this is the first time pruning is applied, take care of moving\n        # the original tensor to a new parameter called name + '_orig'\n        # and deleting the original parameter\n        if not isinstance(method, PruningContainer):\n            # copy `module[name]` to `module[name + '_orig']`\n            module.register_parameter(name + \"_orig\", orig)\n            # temporarily delete `module[name]`\n            del module._parameters[name]\n            default_mask = flow.ones_like(orig)  # temp\n        # If this is not the first time pruning is applied, all of the above\n        # has been done before in a previous pruning iteration, so we're good\n        # to go\n        else:\n            default_mask = getattr(module, name + \"_mask\").detach().clone()\n\n        # Use try/except because if anything goes wrong with the mask\n        # computation etc., you'd want to roll back.\n        try:\n            # get the final mask, computed according to the specific method\n            mask = method.compute_mask(importance_scores, default_mask=default_mask)\n            # reparametrize by saving mask to `module[name + '_mask']`...\n            module.register_buffer(name + \"_mask\", mask)\n            # ... and the new pruned tensor to `module[name]`\n            setattr(module, name, method.apply_mask(module))\n            # associate the pruning method to the module via a hook to\n            # compute the function before every forward() (compile by run)\n            module.register_forward_pre_hook(method)\n\n        except Exception as e:\n            if not isinstance(method, PruningContainer):\n                orig = getattr(module, name + \"_orig\")\n                module.register_parameter(name, orig)\n                del module._parameters[name + \"_orig\"]\n            raise e\n\n        return method\n\n    def prune(self, t, default_mask=None, importance_scores=None):\n        r\"\"\"Computes and returns a pruned version of input tensor ``t``\n        according to the pruning rule specified in :meth:`compute_mask`.\n        Args:\n            t (flow.Tensor): tensor to prune (of same dimensions as\n                ``default_mask``).\n            importance_scores (flow.Tensor): tensor of importance scores (of\n                same shape as ``t``) used to compute mask for pruning ``t``.\n                The values in this tensor indicate the importance of the\n                corresponding elements in the ``t`` that is being pruned.\n                If unspecified or None, the tensor ``t`` will be used in its place.\n            default_mask (flow.Tensor, optional): mask from previous pruning\n                iteration, if any. To be considered when determining what\n                portion of the tensor that pruning should act on. 
If None,\n                default to a mask of ones.\n        Returns:\n            pruned version of tensor ``t``.\n        \"\"\"\n        if importance_scores is not None:\n            assert (\n                importance_scores.shape == t.shape\n            ), \"importance_scores should have the same shape as tensor t\"\n        else:\n            importance_scores = t\n        default_mask = default_mask if default_mask is not None else flow.ones_like(t)\n        return t * self.compute_mask(importance_scores, default_mask=default_mask)\n\n    def remove(self, module):\n        r\"\"\"Removes the pruning reparameterization from a module. The pruned\n        parameter named ``name`` remains permanently pruned, and the parameter\n        named ``name+'_orig'`` is removed from the parameter list. Similarly,\n        the buffer named ``name+'_mask'`` is removed from the buffers.\n        Note:\n            Pruning itself is NOT undone or reversed!\n        \"\"\"\n        # before removing pruning from a tensor, it has to have been applied\n        assert (\n            self._tensor_name is not None\n        ), \"Module {} has to be pruned\\\n            before pruning can be removed\".format(\n            module\n        )  # this gets set in apply()\n\n        # to update module[name] to latest trained weights\n        weight = self.apply_mask(module)  # masked weights\n\n        # delete and reset\n        if hasattr(module, self._tensor_name):\n            delattr(module, self._tensor_name)\n        orig = module._parameters[self._tensor_name + \"_orig\"]\n        orig.data = weight.data\n        del module._parameters[self._tensor_name + \"_orig\"]\n        del module._buffers[self._tensor_name + \"_mask\"]\n        setattr(module, self._tensor_name, orig)\n\n\nclass PruningContainer(BasePruningMethod):\n    \"\"\"Container holding a sequence of pruning methods for iterative pruning.\n    Keeps track of the order in which pruning methods are applied and handles\n    combining successive pruning calls.\n    Accepts as argument an instance of a BasePruningMethod or an iterable of\n    them.\n    \"\"\"\n\n    def __init__(self, *args):\n        self._pruning_methods: Tuple[\"BasePruningMethod\", ...] 
= tuple()\n        if not isinstance(args, Iterable):  # only 1 item\n            self._tensor_name = args._tensor_name\n            self.add_pruning_method(args)\n        elif len(args) == 1:  # only 1 item in a tuple\n            self._tensor_name = args[0]._tensor_name\n            self.add_pruning_method(args[0])\n        else:  # manual construction from list or other iterable (or no args)\n            for method in args:\n                self.add_pruning_method(method)\n\n    def add_pruning_method(self, method):\n        r\"\"\"Adds a child pruning ``method`` to the container.\n        Args:\n            method (subclass of BasePruningMethod): child pruning method\n                to be added to the container.\n        \"\"\"\n        # check that we're adding a pruning method to the container\n        if not isinstance(method, BasePruningMethod) and method is not None:\n            raise TypeError(\n                \"{} is not a BasePruningMethod subclass\".format(type(method))\n            )\n        elif method is not None and self._tensor_name != method._tensor_name:\n            raise ValueError(\n                \"Can only add pruning methods acting on \"\n                \"the parameter named '{}' to PruningContainer {}.\".format(\n                    self._tensor_name, self\n                )\n                + \" Found '{}'\".format(method._tensor_name)\n            )\n        # if all checks passed, add to _pruning_methods tuple\n        self._pruning_methods += (method,)  # type: ignore[operator]\n\n    def __len__(self):\n        return len(self._pruning_methods)\n\n    def __iter__(self):\n        return iter(self._pruning_methods)\n\n    def __getitem__(self, idx):\n        return self._pruning_methods[idx]\n\n    def compute_mask(self, t, default_mask):\n        r\"\"\"Applies the latest ``method`` by computing the new partial masks\n        and returning its combination with the ``default_mask``.\n        The new partial mask should be computed on the entries or channels\n        that were not zeroed out by the ``default_mask``.\n        Which portions of the tensor ``t`` the new mask will be calculated from\n        depends on the ``PRUNING_TYPE`` (handled by the type handler):\n        * for 'unstructured', the mask will be computed from the raveled\n          list of nonmasked entries;\n        * for 'structured', the mask will be computed from the nonmasked\n          channels in the tensor;\n        * for 'global', the mask will be computed across all entries.\n        Args:\n            t (flow.Tensor): tensor representing the parameter to prune\n                (of same dimensions as ``default_mask``).\n            default_mask (flow.Tensor): mask from previous pruning iteration.\n        Returns:\n            mask (flow.Tensor): new mask that combines the effects\n            of the ``default_mask`` and the new mask from the current\n            pruning ``method`` (of same dimensions as ``default_mask`` and\n            ``t``).\n        \"\"\"\n\n        def _combine_masks(method, t, mask):\n            r\"\"\"\n            Args:\n                method (a BasePruningMethod subclass): pruning method\n                    currently being applied.\n                t (flow.Tensor): tensor representing the parameter to prune\n                    (of same dimensions as mask).\n                mask (flow.Tensor): mask from previous pruning iteration\n            Returns:\n                new_mask (flow.Tensor): new mask that combines the effects\n                    
of the old mask and the new mask from the current\n                    pruning method (of same dimensions as mask and t).\n            \"\"\"\n            new_mask = mask  # start off from existing mask\n            new_mask = new_mask.to(dtype=t.dtype)\n\n            # compute a slice of t onto which the new pruning method will operate\n            if method.PRUNING_TYPE == \"unstructured\":\n                # prune entries of t where the mask is 1 (i.e. not yet pruned)\n                slc = mask == 1\n\n            # for structured pruning, exclude channels that have already been\n            # entirely pruned\n            elif method.PRUNING_TYPE == \"structured\":\n                if not hasattr(method, \"dim\"):\n                    raise AttributeError(\n                        \"Pruning methods of PRUNING_TYPE \"\n                        '\"structured\" need to have the attribute `dim` defined.'\n                    )\n\n                # find the channels to keep by removing the ones that have been\n                # zeroed out already (i.e. where sum(entries) == 0)\n                n_dims = t.dim()  # \"is this a 2D tensor? 3D? ...\"\n                dim = method.dim\n                # convert negative indexing\n                if dim < 0:\n                    dim = n_dims + dim\n                # if dim is still negative after subtracting it from n_dims\n                if dim < 0:\n                    raise IndexError(\n                        \"Index is out of bounds for tensor with dimensions {}\".format(\n                            n_dims\n                        )\n                    )\n                # find channels along dim = dim that aren't already entirely zeroed out\n                keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0\n                # create slice to identify what to prune\n                slc = [slice(None)] * n_dims\n                slc[dim] = keep_channel\n
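                # e.g. for a 2D parameter with dim=0, keep_channel is a boolean\n                # vector marking the rows that still have nonzero mask entries,\n                # and slc becomes [keep_channel, slice(None)], so the new sub-mask\n                # is computed only on those surviving rows\n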
\n            elif method.PRUNING_TYPE == \"global\":\n                n_dims = len(t.shape)  # \"is this a 2D tensor? 3D? ...\"\n                slc = [slice(None)] * n_dims\n\n            else:\n                raise ValueError(\n                    \"Unrecognized PRUNING_TYPE {}\".format(method.PRUNING_TYPE)\n                )\n\n            # compute the new mask on the unpruned slice of the tensor t\n            partial_mask = method.compute_mask(t[slc], default_mask=mask[slc])\n            new_mask[slc] = partial_mask.to(dtype=new_mask.dtype)\n\n            return new_mask\n\n        method = self._pruning_methods[-1]\n        mask = _combine_masks(method, t, default_mask)\n        return mask\n\n\nclass Identity(BasePruningMethod):\n    r\"\"\"Utility pruning method that does not prune any units but generates the\n    pruning parametrization with a mask of ones.\n    \"\"\"\n\n    PRUNING_TYPE = \"unstructured\"\n\n    def compute_mask(self, t, default_mask):\n        mask = default_mask\n        return mask\n\n    @classmethod\n    def apply(cls, module, name):\n        r\"\"\"Adds the forward pre-hook that enables pruning on the fly and\n        the reparametrization of a tensor in terms of the original tensor\n        and the pruning mask.\n        Args:\n            module (nn.Module): module containing the tensor to prune\n            name (str): parameter name within ``module`` on which pruning\n                will act.\n        \"\"\"\n        return super(Identity, cls).apply(module, name)\n\n\nclass RandomUnstructured(BasePruningMethod):\n    r\"\"\"Prune (currently unpruned) units in a tensor at random.\n    Args:\n        amount (int or float): quantity of parameters to prune.\n            If ``float``, should be between 0.0 and 1.0 and represent the\n            fraction of parameters to prune. 
If ``int``, it represents the\n            absolute number of parameters to prune.\n    \"\"\"\n\n    PRUNING_TYPE = \"unstructured\"\n\n    def __init__(self, amount):\n        # Check range of validity of pruning amount\n        _validate_pruning_amount_init(amount)\n        self.amount = amount\n\n    def compute_mask(self, t, default_mask):\n        # Check that the amount of units to prune is not > than the number of\n        # parameters in t\n        tensor_size = t.nelement()\n        # Compute number of units to prune: amount if int,\n        # else amount * tensor_size\n        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)\n        # This should raise an error if the number of units to prune is larger\n        # than the number of units in the tensor\n        _validate_pruning_amount(nparams_toprune, tensor_size)\n\n        mask = default_mask.clone()\n\n        if nparams_toprune != 0:  # k=0 not supported by flow.kthvalue\n            # prob = flow.rand_like(t)\n            prob = flow.rand(t.size(), dtype=t.dtype, device=t.device)\n            topk = flow.topk(prob.view(-1), k=nparams_toprune)\n            mask.view(-1)[topk.indices] = 0\n\n        return mask\n\n    @classmethod\n    def apply(cls, module, name, amount):\n        r\"\"\"Adds the forward pre-hook that enables pruning on the fly and\n        the reparametrization of a tensor in terms of the original tensor\n        and the pruning mask.\n        Args:\n            module (nn.Module): module containing the tensor to prune\n            name (str): parameter name within ``module`` on which pruning\n                will act.\n            amount (int or float): quantity of parameters to prune.\n                If ``float``, should be between 0.0 and 1.0 and represent the\n                fraction of parameters to prune. If ``int``, it represents the\n                absolute number of parameters to prune.\n        \"\"\"\n        return super(RandomUnstructured, cls).apply(module, name, amount=amount)\n\n\nclass L1Unstructured(BasePruningMethod):\n    r\"\"\"Prune (currently unpruned) units in a tensor by zeroing out the ones\n    with the lowest L1-norm.\n    Args:\n        amount (int or float): quantity of parameters to prune.\n            If ``float``, should be between 0.0 and 1.0 and represent the\n            fraction of parameters to prune. 
If ``int``, it represents the\n            absolute number of parameters to prune.\n    \"\"\"\n\n    PRUNING_TYPE = \"unstructured\"\n\n    def __init__(self, amount):\n        # Check range of validity of pruning amount\n        _validate_pruning_amount_init(amount)\n        self.amount = amount\n\n    def compute_mask(self, t, default_mask):\n        # Check that the amount of units to prune is not > than the number of\n        # parameters in t\n        tensor_size = t.nelement()\n        # Compute number of units to prune: amount if int,\n        # else amount * tensor_size\n        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)\n        # This should raise an error if the number of units to prune is larger\n        # than the number of units in the tensor\n        _validate_pruning_amount(nparams_toprune, tensor_size)\n\n        mask = default_mask.clone()\n\n        if nparams_toprune != 0:  # k=0 not supported by flow.kthvalue\n            # largest=True --> top k; largest=False --> bottom k\n            # Prune the smallest k\n\n            topk = flow.topk(flow.abs(t).view(-1), k=nparams_toprune, largest=False)\n            # topk will have .indices and .values\n            mask.view(-1)[topk.indices] = 0\n\n        return mask\n\n    @classmethod\n    def apply(cls, module, name, amount, importance_scores=None):\n        r\"\"\"Adds the forward pre-hook that enables pruning on the fly and\n        the reparametrization of a tensor in terms of the original tensor\n        and the pruning mask.\n        Args:\n            module (nn.Module): module containing the tensor to prune\n            name (str): parameter name within ``module`` on which pruning\n                will act.\n            amount (int or float): quantity of parameters to prune.\n                If ``float``, should be between 0.0 and 1.0 and represent the\n                fraction of parameters to prune. If ``int``, it represents the\n                absolute number of parameters to prune.\n            importance_scores (flow.Tensor): tensor of importance scores (of same\n                shape as module parameter) used to compute mask for pruning.\n                The values in this tensor indicate the importance of the corresponding\n                elements in the parameter being pruned.\n                If unspecified or None, the module parameter will be used in its place.\n        \"\"\"\n        return super(L1Unstructured, cls).apply(\n            module, name, amount=amount, importance_scores=importance_scores\n        )\n\n\nclass RandomStructured(BasePruningMethod):\n    r\"\"\"Prune entire (currently unpruned) channels in a tensor at random.\n    Args:\n        amount (int or float): quantity of parameters to prune.\n            If ``float``, should be between 0.0 and 1.0 and represent the\n            fraction of parameters to prune. If ``int``, it represents the\n            absolute number of parameters to prune.\n        dim (int, optional): index of the dim along which we define\n            channels to prune. 
Default: -1.\n    \"\"\"\n\n    PRUNING_TYPE = \"structured\"\n\n    def __init__(self, amount, dim=-1):\n        # Check range of validity of amount\n        _validate_pruning_amount_init(amount)\n        self.amount = amount\n        self.dim = dim\n\n    def compute_mask(self, t, default_mask):\n        r\"\"\"Computes and returns a mask for the input tensor ``t``.\n        Starting from a base ``default_mask`` (which should be a mask of ones\n        if the tensor has not been pruned yet), generate a random mask to\n        apply on top of the ``default_mask`` by randomly zeroing out channels\n        along the specified dim of the tensor.\n        Args:\n            t (flow.Tensor): tensor representing the parameter to prune\n            default_mask (flow.Tensor): Base mask from previous pruning\n                iterations, which needs to be respected after the new mask is\n                applied. Same dims as ``t``.\n        Returns:\n            mask (flow.Tensor): mask to apply to ``t``, of same dims as ``t``\n        Raises:\n            IndexError: if ``self.dim >= len(t.shape)``\n        \"\"\"\n        # Check that tensor has structure (i.e. more than 1 dimension) such\n        # that the concept of \"channels\" makes sense\n        _validate_structured_pruning(t)\n\n        # Check that self.dim is a valid dim to index t, else raise IndexError\n        _validate_pruning_dim(t, self.dim)\n\n        # Check that the amount of channels to prune is not > than the number of\n        # channels in t along the dim to prune\n        tensor_size = t.shape[self.dim]\n        # Compute number of units to prune: amount if int,\n        # else amount * tensor_size\n        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)\n        # This should raise an error if the number of units to prune is larger\n        # than the number of units in the tensor\n        _validate_pruning_amount(nparams_toprune, tensor_size)\n\n        # Compute binary mask by initializing it to all 0s and then filling in\n        # 1s for the channels to keep, along self.dim.\n        # mask has the same shape as tensor t\n        def make_mask(t, dim, nchannels, nchannels_toprune):\n            # generate a random number in [0, 1] to associate to each channel\n            prob = flow.rand(nchannels)\n\n            # generate mask for each channel by 0ing out the channels that\n            # got assigned the k = nchannels_toprune lowest values in prob\n            # threshold = flow.kthvalue(prob, k=nchannels_toprune).values\n\n            # ---------------------------------------------------------------\n            # OneFlow does not support flow.kthvalue yet, but since kthvalue\n            # is simple to emulate, it is implemented here directly in Python\n            # via flow.sort\n\n            y, _ = flow.sort(prob)\n            threshold = y[nchannels_toprune - 1]\n            # ---------------------------------------------------------------\n            channel_mask = prob > threshold\n
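            # e.g. with prob = [0.3, 0.1, 0.9] and nchannels_toprune = 1, the\n            # threshold is 0.1 and channel_mask = [True, False, True], so only\n            # the lowest-scoring channel is zeroed out\n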
\n            mask = flow.zeros_like(t)\n            slc = [slice(None)] * len(t.shape)\n            slc[dim] = channel_mask\n            mask[slc] = 1\n            return mask\n\n        if nparams_toprune == 0:  # k=0 not supported by flow.kthvalue\n            mask = default_mask\n        else:\n            # apply the new structured mask on top of prior (potentially\n            # unstructured) mask\n            mask = make_mask(t, self.dim, tensor_size, nparams_toprune)\n            mask *= default_mask.to(dtype=mask.dtype)\n        return mask\n\n    @classmethod\n    def apply(cls, module, name, amount, dim=-1):\n        r\"\"\"Adds the forward pre-hook that enables pruning on the fly and\n        the reparametrization of a tensor in terms of the original tensor\n        and the pruning mask.\n        Args:\n            module (nn.Module): module containing the tensor to prune\n            name (str): parameter name within ``module`` on which pruning\n                will act.\n            amount (int or float): quantity of parameters to prune.\n                If ``float``, should be between 0.0 and 1.0 and represent the\n                fraction of parameters to prune. If ``int``, it represents the\n                absolute number of parameters to prune.\n            dim (int, optional): index of the dim along which we define\n                channels to prune. Default: -1.\n        \"\"\"\n        return super(RandomStructured, cls).apply(module, name, amount=amount, dim=dim)\n\n\nclass LnStructured(BasePruningMethod):\n    r\"\"\"Prune entire (currently unpruned) channels in a tensor based on their\n    L\\ ``n``-norm.\n    Args:\n        amount (int or float): quantity of channels to prune.\n            If ``float``, should be between 0.0 and 1.0 and represent the\n            fraction of parameters to prune. If ``int``, it represents the\n            absolute number of parameters to prune.\n        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid\n            entries for argument ``p`` in :func:`flow.norm`.\n        dim (int, optional): index of the dim along which we define\n            channels to prune. Default: -1.\n    \"\"\"\n\n    PRUNING_TYPE = \"structured\"\n\n    def __init__(self, amount, n, dim=-1):\n        # Check range of validity of amount\n        _validate_pruning_amount_init(amount)\n        self.amount = amount\n        self.n = n\n        self.dim = dim\n\n    def compute_mask(self, t, default_mask):\n        r\"\"\"Computes and returns a mask for the input tensor ``t``.\n        Starting from a base ``default_mask`` (which should be a mask of ones\n        if the tensor has not been pruned yet), generate a mask to apply on\n        top of the ``default_mask`` by zeroing out the channels along the\n        specified dim with the lowest L\\ ``n``-norm.\n        Args:\n            t (flow.Tensor): tensor representing the parameter to prune\n            default_mask (flow.Tensor): Base mask from previous pruning\n                iterations, which needs to be respected after the new mask is\n                applied. Same dims as ``t``.\n        Returns:\n            mask (flow.Tensor): mask to apply to ``t``, of same dims as ``t``\n        Raises:\n            IndexError: if ``self.dim >= len(t.shape)``\n        \"\"\"\n        # Check that tensor has structure (i.e. 
more than 1 dimension) such\n        # that the concept of \"channels\" makes sense\n        _validate_structured_pruning(t)\n        # Check that self.dim is a valid dim to index t, else raise IndexError\n        _validate_pruning_dim(t, self.dim)\n\n        # Check that the amount of channels to prune is not > than the number of\n        # channels in t along the dim to prune\n        tensor_size = t.shape[self.dim]\n        # Compute number of units to prune: amount if int,\n        # else amount * tensor_size\n        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)\n        nparams_tokeep = tensor_size - nparams_toprune\n        # This should raise an error if the number of units to prune is larger\n        # than the number of units in the tensor\n        _validate_pruning_amount(nparams_toprune, tensor_size)\n\n        # Structured pruning prunes entire channels so we need to know the\n        # L_n norm along each channel to then find the topk based on this\n        # metric\n        norm = _compute_norm(t, self.n, self.dim)\n        # largest=True --> top k; largest=False --> bottom k\n        # Keep the largest k channels along dim=self.dim\n\n        topk = flow.topk(norm, k=nparams_tokeep, largest=True)\n        # topk will have .indices and .values\n\n        # Compute binary mask by initializing it to all 0s and then filling in\n        # 1s wherever topk.indices indicates, along self.dim.\n        # mask has the same shape as tensor t\n        def make_mask(t, dim, indices):\n            # init mask to 0\n            mask = flow.zeros_like(t)\n            # e.g.: slc = [None, None, None], if len(t.shape) = 3\n            slc = [slice(None)] * len(t.shape)\n            # replace a None at position=dim with indices\n            # e.g.: slc = [None, None, [0, 2, 3]] if dim=2 & indices=[0,2,3]\n            slc[dim] = indices\n            # use slc to slice mask and replace all its entries with 1s\n            # e.g.: mask[:, :, [0, 2, 3]] = 1\n            mask[slc] = 1\n            return mask\n\n        if nparams_toprune == 0:  # k=0 not supported by flow.kthvalue\n            mask = default_mask\n        else:\n\n            mask = make_mask(t, self.dim, topk.indices)\n            mask *= default_mask.to(dtype=mask.dtype)\n\n        return mask\n\n    @classmethod\n    def apply(cls, module, name, amount, n, dim, importance_scores=None):\n        r\"\"\"Adds the forward pre-hook that enables pruning on the fly and\n        the reparametrization of a tensor in terms of the original tensor\n        and the pruning mask.\n        Args:\n            module (nn.Module): module containing the tensor to prune\n            name (str): parameter name within ``module`` on which pruning\n                will act.\n            amount (int or float): quantity of parameters to prune.\n                If ``float``, should be between 0.0 and 1.0 and represent the\n                fraction of parameters to prune. 
If ``int``, it represents the\n                absolute number of parameters to prune.\n            n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid\n                entries for argument ``p`` in :func:`flow.norm`.\n            dim (int): index of the dim along which we define channels to\n                prune.\n            importance_scores (flow.Tensor): tensor of importance scores (of same\n                shape as module parameter) used to compute mask for pruning.\n                The values in this tensor indicate the importance of the corresponding\n                elements in the parameter being pruned.\n                If unspecified or None, the module parameter will be used in its place.\n        \"\"\"\n        return super(LnStructured, cls).apply(\n            module,\n            name,\n            amount=amount,\n            n=n,\n            dim=dim,\n            importance_scores=importance_scores,\n        )\n\n\nclass CustomFromMask(BasePruningMethod):\n\n    PRUNING_TYPE = \"global\"\n\n    def __init__(self, mask):\n        self.mask = mask\n\n    def compute_mask(self, t, default_mask):\n        assert default_mask.shape == self.mask.shape\n        mask = default_mask * self.mask.to(dtype=default_mask.dtype)\n        return mask\n\n    @classmethod\n    def apply(cls, module, name, mask):\n        r\"\"\"Adds the forward pre-hook that enables pruning on the fly and\n        the reparametrization of a tensor in terms of the original tensor\n        and the pruning mask.\n        Args:\n            module (nn.Module): module containing the tensor to prune\n            name (str): parameter name within ``module`` on which pruning\n                will act.\n        \"\"\"\n        return super(CustomFromMask, cls).apply(module, name, mask=mask)\n\n\ndef identity(module, name):\n    r\"\"\"Applies pruning reparametrization to the tensor corresponding to the\n    parameter called ``name`` in ``module`` without actually pruning any\n    units. Modifies module in place (and also return the modified module)\n    by:\n    1) adding a named buffer called ``name+'_mask'`` corresponding to the\n       binary mask applied to the parameter ``name`` by the pruning method.\n    2) replacing the parameter ``name`` by its pruned version, while the\n       original (unpruned) parameter is stored in a new parameter named\n       ``name+'_orig'``.\n    Note:\n        The mask is a tensor of ones.\n    Args:\n        module (nn.Module): module containing the tensor to prune.\n        name (str): parameter name within ``module`` on which pruning\n                will act.\n    Returns:\n        module (nn.Module): modified (i.e. 
pruned) version of the input module\n    Examples:\n        >>> # xdoctest: +SKIP\n        >>> m = prune.identity(nn.Linear(2, 3), 'bias')\n        >>> print(m.bias_mask)\n        tensor([1., 1., 1.])\n    \"\"\"\n    Identity.apply(module, name)\n    return module\n\n\ndef random_unstructured(module, name, amount):\n    r\"\"\"Prunes tensor corresponding to parameter called ``name`` in ``module``\n    by removing the specified ``amount`` of (currently unpruned) units\n    selected at random.\n    Modifies module in place (and also return the modified module) by:\n    1) adding a named buffer called ``name+'_mask'`` corresponding to the\n       binary mask applied to the parameter ``name`` by the pruning method.\n    2) replacing the parameter ``name`` by its pruned version, while the\n       original (unpruned) parameter is stored in a new parameter named\n       ``name+'_orig'``.\n    Args:\n        module (nn.Module): module containing the tensor to prune\n        name (str): parameter name within ``module`` on which pruning\n                will act.\n        amount (int or float): quantity of parameters to prune.\n            If ``float``, should be between 0.0 and 1.0 and represent the\n            fraction of parameters to prune. If ``int``, it represents the\n            absolute number of parameters to prune.\n    Returns:\n        module (nn.Module): modified (i.e. pruned) version of the input module\n    Examples:\n        >>> # xdoctest: +SKIP\n        >>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)\n        >>> flow.sum(m.weight_mask == 0)\n        tensor(1)\n    \"\"\"\n    RandomUnstructured.apply(module, name, amount)\n    return module\n\n\ndef l1_unstructured(module, name, amount, importance_scores=None):\n    r\"\"\"Prunes tensor corresponding to parameter called ``name`` in ``module``\n    by removing the specified `amount` of (currently unpruned) units with the\n    lowest L1-norm.\n    Modifies module in place (and also return the modified module)\n    by:\n    1) adding a named buffer called ``name+'_mask'`` corresponding to the\n       binary mask applied to the parameter ``name`` by the pruning method.\n    2) replacing the parameter ``name`` by its pruned version, while the\n       original (unpruned) parameter is stored in a new parameter named\n       ``name+'_orig'``.\n    Args:\n        module (nn.Module): module containing the tensor to prune\n        name (str): parameter name within ``module`` on which pruning\n                will act.\n        amount (int or float): quantity of parameters to prune.\n            If ``float``, should be between 0.0 and 1.0 and represent the\n            fraction of parameters to prune. If ``int``, it represents the\n            absolute number of parameters to prune.\n        importance_scores (flow.Tensor): tensor of importance scores (of same\n            shape as module parameter) used to compute mask for pruning.\n            The values in this tensor indicate the importance of the corresponding\n            elements in the parameter being pruned.\n            If unspecified or None, the module parameter will be used in its place.\n    Returns:\n        module (nn.Module): modified (i.e. 
pruned) version of the input module\n    Examples:\n        >>> # xdoctest: +SKIP\n        >>> m = prune.l1_unstructured(nn.Linear(2, 3), 'weight', amount=0.2)\n        >>> m.state_dict().keys()\n        odict_keys(['bias', 'weight_orig', 'weight_mask'])\n    \"\"\"\n    L1Unstructured.apply(\n        module, name, amount=amount, importance_scores=importance_scores\n    )\n    return module\n\n\ndef random_structured(module, name, amount, dim):\n    r\"\"\"Prunes tensor corresponding to parameter called ``name`` in ``module``\n    by removing the specified ``amount`` of (currently unpruned) channels\n    along the specified ``dim`` selected at random.\n    Modifies module in place (and also return the modified module)\n    by:\n    1) adding a named buffer called ``name+'_mask'`` corresponding to the\n       binary mask applied to the parameter ``name`` by the pruning method.\n    2) replacing the parameter ``name`` by its pruned version, while the\n       original (unpruned) parameter is stored in a new parameter named\n       ``name+'_orig'``.\n    Args:\n        module (nn.Module): module containing the tensor to prune\n        name (str): parameter name within ``module`` on which pruning\n                will act.\n        amount (int or float): quantity of parameters to prune.\n            If ``float``, should be between 0.0 and 1.0 and represent the\n            fraction of parameters to prune. If ``int``, it represents the\n            absolute number of parameters to prune.\n        dim (int): index of the dim along which we define channels to prune.\n    Returns:\n        module (nn.Module): modified (i.e. pruned) version of the input module\n    Examples:\n        >>> # xdoctest: +SKIP\n        >>> m = prune.random_structured(\n        ...     nn.Linear(5, 3), 'weight', amount=3, dim=1\n        ... )\n        >>> columns_pruned = int(sum(flow.sum(m.weight, dim=0) == 0))\n        >>> print(columns_pruned)\n        3\n    \"\"\"\n    RandomStructured.apply(module, name, amount, dim)\n    return module\n\n\ndef ln_structured(module, name, amount, n, dim, importance_scores=None):\n    r\"\"\"Prunes tensor corresponding to parameter called ``name`` in ``module``\n    by removing the specified ``amount`` of (currently unpruned) channels\n    along the specified ``dim`` with the lowest L\\ ``n``-norm.\n    Modifies module in place (and also return the modified module)\n    by:\n    1) adding a named buffer called ``name+'_mask'`` corresponding to the\n       binary mask applied to the parameter ``name`` by the pruning method.\n    2) replacing the parameter ``name`` by its pruned version, while the\n       original (unpruned) parameter is stored in a new parameter named\n       ``name+'_orig'``.\n    Args:\n        module (nn.Module): module containing the tensor to prune\n        name (str): parameter name within ``module`` on which pruning\n                will act.\n        amount (int or float): quantity of parameters to prune.\n            If ``float``, should be between 0.0 and 1.0 and represent the\n            fraction of parameters to prune. 
If ``int``, it represents the\n            absolute number of parameters to prune.\n        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid\n            entries for argument ``p`` in :func:`flow.norm`.\n        dim (int): index of the dim along which we define channels to prune.\n        importance_scores (flow.Tensor): tensor of importance scores (of same\n            shape as module parameter) used to compute mask for pruning.\n            The values in this tensor indicate the importance of the corresponding\n            elements in the parameter being pruned.\n            If unspecified or None, the module parameter will be used in its place.\n    Returns:\n        module (nn.Module): modified (i.e. pruned) version of the input module\n    Examples:\n        >>> # xdoctest: +SKIP\n        >>> m = prune.ln_structured(\n        ...    nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf')\n        ... )\n    \"\"\"\n    LnStructured.apply(\n        module, name, amount, n, dim, importance_scores=importance_scores\n    )\n    return module\n\n\ndef global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):\n    r\"\"\"\n    Globally prunes tensors corresponding to all parameters in ``parameters``\n    by applying the specified ``pruning_method``.\n    Modifies modules in place by:\n    1) adding a named buffer called ``name+'_mask'`` corresponding to the\n       binary mask applied to the parameter ``name`` by the pruning method.\n    2) replacing the parameter ``name`` by its pruned version, while the\n       original (unpruned) parameter is stored in a new parameter named\n       ``name+'_orig'``.\n    Args:\n        parameters (Iterable of (module, name) tuples): parameters of\n            the model to prune in a global fashion, i.e. by aggregating all\n            weights prior to deciding which ones to prune. module must be of\n            type :class:`nn.Module`, and name must be a string.\n        pruning_method (function): a valid pruning function from this module,\n            or a custom one implemented by the user that satisfies the\n            implementation guidelines and has ``PRUNING_TYPE='unstructured'``.\n        importance_scores (dict): a dictionary mapping (module, name) tuples to\n            the corresponding parameter's importance scores tensor. The tensor\n            should be the same shape as the parameter, and is used for computing\n            mask for pruning.\n            If unspecified or None, the parameter will be used in place of its\n            importance scores.\n        kwargs: other keyword arguments such as:\n            amount (int or float): quantity of parameters to prune across the\n            specified parameters.\n            If ``float``, should be between 0.0 and 1.0 and represent the\n            fraction of parameters to prune. If ``int``, it represents the\n            absolute number of parameters to prune.\n    Raises:\n        TypeError: if ``PRUNING_TYPE != 'unstructured'``\n    Note:\n        Since global structured pruning doesn't make much sense unless the\n        norm is normalized by the size of the parameter, we now limit the\n        scope of global pruning to unstructured methods.\n    Examples:\n        >>> # xdoctest: +SKIP\n        >>> net = nn.Sequential(OrderedDict([\n        ...     ('first', nn.Linear(10, 4)),\n        ...     ('second', nn.Linear(4, 1)),\n        ... ]))\n        >>> parameters_to_prune = (\n        ...     (net.first, 'weight'),\n        ...    
 (net.second, 'weight'),\n        ... )\n        >>> prune.global_unstructured(\n        ...     parameters_to_prune,\n        ...     pruning_method=prune.L1Unstructured,\n        ...     amount=10,\n        ... )\n        >>> print(sum(flow.nn.utils.parameters_to_vector(net.buffers()) == 0))\n        tensor(10, dtype=flow.uint8)\n    \"\"\"\n    # ensure parameters is a list or generator of tuples\n    if not isinstance(parameters, Iterable):\n        raise TypeError(\"global_unstructured(): parameters is not an Iterable\")\n\n    importance_scores = importance_scores if importance_scores is not None else {}\n    if not isinstance(importance_scores, dict):\n        raise TypeError(\"global_unstructured(): importance_scores must be of type dict\")\n\n    # flatten importance scores to consider them all at once in global pruning\n    relevant_importance_scores = flow.nn.utils.parameters_to_vector(\n        [\n            importance_scores.get((module, name), getattr(module, name))\n            for (module, name) in parameters\n        ]\n    )\n    # similarly, flatten the masks (if they exist), or use a flattened vector\n    # of 1s of the same dimensions as t\n    default_mask = flow.nn.utils.parameters_to_vector(\n        [\n            getattr(module, name + \"_mask\", flow.ones_like(getattr(module, name)))\n            for (module, name) in parameters\n        ]\n    )\n\n    # use the canonical pruning methods to compute the new mask, even if the\n    # parameter is now a flattened out version of `parameters`\n    container = PruningContainer()\n    container._tensor_name = \"temp\"  # to make it match that of `method`\n    method = pruning_method(**kwargs)\n    method._tensor_name = \"temp\"  # to make it match that of `container`\n    if method.PRUNING_TYPE != \"unstructured\":\n        raise TypeError(\n            'Only \"unstructured\" PRUNING_TYPE supported for '\n            \"the `pruning_method`. 
Found method {} of type {}\".format(\n                pruning_method, method.PRUNING_TYPE\n            )\n        )\n\n    container.add_pruning_method(method)\n\n    # use the `compute_mask` method from `PruningContainer` to combine the\n    # mask computed by the new method with the pre-existing mask\n    final_mask = container.compute_mask(relevant_importance_scores, default_mask)\n\n    # Pointer for slicing the mask to match the shape of each parameter\n    pointer = 0\n    for module, name in parameters:\n\n        param = getattr(module, name)\n        # The length of the parameter\n        num_param = param.numel()\n        # Slice the mask, reshape it\n        param_mask = final_mask[pointer : pointer + num_param].view_as(param)\n        # Assign the correct pre-computed mask to each parameter and add it\n        # to the forward_pre_hooks like any other pruning method\n        custom_from_mask(module, name, mask=param_mask)\n\n        # Increment the pointer to continue slicing the final_mask\n        pointer += num_param\n\n\ndef custom_from_mask(module, name, mask):\n    r\"\"\"Prunes tensor corresponding to parameter called ``name`` in ``module``\n    by applying the pre-computed mask in ``mask``.\n    Modifies module in place (and also return the modified module)\n    by:\n    1) adding a named buffer called ``name+'_mask'`` corresponding to the\n       binary mask applied to the parameter ``name`` by the pruning method.\n    2) replacing the parameter ``name`` by its pruned version, while the\n       original (unpruned) parameter is stored in a new parameter named\n       ``name+'_orig'``.\n    Args:\n        module (nn.Module): module containing the tensor to prune\n        name (str): parameter name within ``module`` on which pruning\n            will act.\n        mask (Tensor): binary mask to be applied to the parameter.\n    Returns:\n        module (nn.Module): modified (i.e. pruned) version of the input module\n    Examples:\n        >>> # xdoctest: +SKIP\n        >>> m = prune.custom_from_mask(\n        ...     nn.Linear(5, 3), name='bias', mask=flow.tensor([0, 1, 0])\n        ... )\n        >>> print(m.bias_mask)\n        tensor([0., 1., 0.])\n    \"\"\"\n    CustomFromMask.apply(module, name, mask)\n    return module\n\n\ndef remove(module, name):\n    r\"\"\"Removes the pruning reparameterization from a module and the\n    pruning method from the forward hook. The pruned\n    parameter named ``name`` remains permanently pruned, and the parameter\n    named ``name+'_orig'`` is removed from the parameter list. 
Similarly,\n    the buffer named ``name+'_mask'`` is removed from the buffers.\n    Note:\n        Pruning itself is NOT undone or reversed!\n    Args:\n        module (nn.Module): module containing the tensor to prune\n        name (str): parameter name within ``module`` on which pruning\n            will act.\n    Examples:\n        >>> m = random_unstructured(nn.Linear(5, 7), name='weight', amount=0.2)\n        >>> m = remove(m, name='weight')\n    \"\"\"\n    for k, hook in module._forward_pre_hooks.items():\n        if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:\n            hook.remove(module)\n            del module._forward_pre_hooks[k]\n            return module\n\n    raise ValueError(\n        \"Parameter '{}' of module {} has to be pruned \"\n        \"before pruning can be removed\".format(name, module)\n    )\n\n\ndef is_pruned(module):\n    r\"\"\"Check whether ``module`` is pruned by looking for\n    ``forward_pre_hooks`` in its modules that inherit from the\n    :class:`BasePruningMethod`.\n    Args:\n        module (nn.Module): object that is either pruned or unpruned\n    Returns:\n        binary answer to whether ``module`` is pruned.\n    Examples:\n        >>> m = nn.Linear(5, 7)\n        >>> # xdoctest: +SKIP\n        >>> print(prune.is_pruned(m))\n        False\n        >>> prune.random_unstructured(m, name='weight', amount=0.2)\n        >>> print(prune.is_pruned(m))\n        True\n    \"\"\"\n    for _, submodule in module.named_modules():\n        for _, hook in submodule._forward_pre_hooks.items():\n            if isinstance(hook, BasePruningMethod):\n                return True\n    return False\n\n\ndef _validate_pruning_amount_init(amount):\n    r\"\"\"Validation helper to check the range of amount at init.\n    Args:\n        amount (int or float): quantity of parameters to prune.\n            If float, should be between 0.0 and 1.0 and represent the\n            fraction of parameters to prune. If int, it represents the\n            absolute number of parameters to prune.\n    Raises:\n        ValueError: if amount is a float not in [0, 1], or if it's a negative\n            integer.\n        TypeError: if amount is neither a float nor an integer.\n    Note:\n        This does not take into account the number of parameters in the\n        tensor to be pruned, which is known only at pruning time.\n    \"\"\"\n    if not isinstance(amount, numbers.Real):\n        raise TypeError(\n            \"Invalid type for amount: {}. Must be int or float.\".format(amount)\n        )\n\n    if (isinstance(amount, numbers.Integral) and amount < 0) or (\n        not isinstance(amount, numbers.Integral)  # so it's a float\n        and (float(amount) > 1.0 or float(amount) < 0.0)\n    ):\n        raise ValueError(\n            \"amount={} should either be a float in the \"\n            \"range [0, 1] or a non-negative integer\".format(amount)\n        )\n\n\ndef _validate_pruning_amount(amount, tensor_size):\n    r\"\"\"Validation helper to check that the amount of parameters to prune\n    is meaningful with respect to the size of the data (`tensor_size`).\n    Args:\n        amount (int or float): quantity of parameters to prune.\n            If float, should be between 0.0 and 1.0 and represent the\n            fraction of parameters to prune. 
If int, it represents the\n            absolute number of parameters to prune.\n        tensor_size (int): absolute number of parameters in the tensor\n            to prune.\n    \"\"\"\n    # TODO: consider removing this check and allowing users to specify\n    # a number of units to prune that is greater than the number of units\n    # left to prune. In this case, the tensor will just be fully pruned.\n\n    if isinstance(amount, numbers.Integral) and amount > tensor_size:\n        raise ValueError(\n            \"amount={} should be smaller than the number of \"\n            \"parameters to prune={}\".format(amount, tensor_size)\n        )\n\n\ndef _validate_structured_pruning(t):\n    r\"\"\"Validation helper to check that the tensor to be pruned is multi-\n    dimensional, such that the concept of \"channels\" is well-defined.\n    Args:\n        t (flow.Tensor): tensor representing the parameter to prune\n    Raises:\n        ValueError: if the tensor `t` is not at least 2D.\n    \"\"\"\n    shape = t.shape\n    if len(shape) <= 1:\n        raise ValueError(\n            \"Structured pruning can only be applied to \"\n            \"multidimensional tensors. Found tensor of shape \"\n            \"{} with {} dims\".format(shape, len(shape))\n        )\n\n\ndef _compute_nparams_toprune(amount, tensor_size):\n    r\"\"\"Since amount can be expressed either in absolute value or as a\n    percentage of the number of units/channels in a tensor, this utility\n    function converts the percentage to absolute value to standardize\n    the handling of pruning.\n    Args:\n        amount (int or float): quantity of parameters to prune.\n            If float, should be between 0.0 and 1.0 and represent the\n            fraction of parameters to prune. If int, it represents the\n            absolute number of parameters to prune.\n        tensor_size (int): absolute number of parameters in the tensor\n            to prune.\n    Returns:\n        int: the number of units to prune in the tensor\n    \"\"\"\n    # incorrect type already checked in _validate_pruning_amount_init\n    if isinstance(amount, numbers.Integral):\n        return amount\n    else:\n        return round(amount * tensor_size)\n\n\ndef _validate_pruning_dim(t, dim):\n    r\"\"\"\n    Args:\n        t (flow.Tensor): tensor representing the parameter to prune\n        dim (int): index of the dim along which we define channels to prune\n    \"\"\"\n    if dim >= t.dim():\n        raise IndexError(\"Invalid index {} for tensor of size {}\".format(dim, t.shape))\n\n\ndef _compute_norm(t, n, dim):\n    r\"\"\"Compute the L_n-norm across all entries in tensor `t` along all dimensions\n    except for the one identified by dim.\n    Example: if `t` is of shape, say, 3x2x4 and dim=2 (the last dim),\n    then norm will have Size [4], and each entry will represent the\n    `L_n`-norm computed using the 3x2=6 entries for each of the 4 channels.\n    Args:\n        t (flow.Tensor): tensor representing the parameter to prune\n        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid\n            entries for argument p in flow.norm\n        dim (int): dim identifying the channels to prune\n    Returns:\n        norm (flow.Tensor): L_n norm computed across all dimensions except\n            for `dim`. By construction, `norm.shape = t.shape[dim]`.\n    \"\"\"\n    # dims = all axes, except for the one identified by `dim`\n    dims = list(range(t.dim()))\n    # convert negative indexing\n    if dim < 0:\n        dim = dims[dim]\n    dims.remove(dim)\n    # norm = flow.norm(t, p=n, dim=dims)\n\n    # flow.norm in OneFlow currently only supports taking the norm of a\n    # two-dimensional tensor (unlike torch.norm, which handles multi-dimensional\n    # inputs), so we permute the channel dim to the front, flatten all remaining\n    # dims into dim 1, and compute the norm along dim=1.\n
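    # e.g. with t of shape 3x2x4 and dim=2: permute_order = [2, 0, 1], so `a`\n    # becomes shape (4, 3, 2), is reshaped to (4, 6), and flow.norm(a, p=n, dim=1)\n    # yields one norm per channel, shape (4,)\n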
\n    a = t.clone()\n    fullDims = list(range(a.dim()))\n    retainedDims = list(set(fullDims).difference(set(dims)))\n    permute_order = retainedDims + dims\n    reshape_size = 1\n    for item in retainedDims:\n        reshape_size *= a.shape[item]\n    a = a.permute(permute_order)\n    a = a.reshape(reshape_size, -1)\n    norm = flow.norm(a, p=n, dim=1)\n\n    return norm\n"
  },
  {
    "path": "python/oneflow/nn/utils/rnn.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom collections import namedtuple\nfrom typing import List, Tuple, Union, Iterable, Optional\nimport warnings\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\n\n# The implementation of rnn util is modified from: https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/rnn.py\n\n\ndef bind(optional, fn):\n    if optional is None:\n        return None\n    return fn(optional)\n\n\ndef invert_permutation(permutation: Optional[Tensor]) -> Optional[Tensor]:\n    if permutation is None:\n        return None\n    return flow.scatter(\n        flow.zeros_like(permutation),\n        0,\n        permutation,\n        flow.arange(\n            0, permutation.numel(), device=permutation.device, dtype=flow.int32\n        ),\n    )\n\n\nclass PackedSequence(object):\n    \"\"\"The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.utils.rnn.PackedSequence.html.\n    \n    Holds the data and list of :attr:`batch_sizes` of a packed sequence.\n\n    All RNN modules accept packed sequences as inputs.\n\n    Note:\n        Instances of this class should never be created manually. They are meant\n        to be instantiated by functions like :func:`pack_padded_sequence`.\n\n        Batch sizes represent the number of elements at each sequence step in\n        the batch, not the varying sequence lengths passed to\n        :func:`pack_padded_sequence`.  For instance, given data ``abc`` and ``x``\n        the :class:`PackedSequence` would contain data ``axbc`` with\n        ``batch_sizes=[2,1,1]``.\n\n    Attributes:\n        data (Tensor): Tensor containing packed sequence\n        batch_sizes (Tensor): Tensor of integers holding\n            information about the batch size at each sequence step\n        sorted_indices (Tensor, optional): Tensor of integers holding how this\n            :class:`PackedSequence` is constructed from sequences.\n        unsorted_indices (Tensor, optional): Tensor of integers holding how\n            to recover the original sequences with correct order.\n\n    .. 
note::\n        :attr:`data` can be on arbitrary device and of arbitrary dtype.\n        :attr:`sorted_indices` and :attr:`unsorted_indices` must be ``oneflow.int64``\n        tensors on the same device as :attr:`data`.\n\n        However, :attr:`batch_sizes` should always be a CPU ``oneflow.int64`` tensor.\n\n        This invariant is maintained throughout the :class:`PackedSequence` class,\n        and by all functions that construct a :class:`PackedSequence` in OneFlow\n        (i.e., they only pass in tensors conforming to this constraint).\n\n    \"\"\"\n\n    def __init__(\n        self,\n        data: Tensor,\n        batch_sizes: Optional[Tensor] = None,\n        sorted_indices: Optional[Tensor] = None,\n        unsorted_indices: Optional[Tensor] = None,\n    ):\n        self.sorted_indices = sorted_indices\n        if unsorted_indices is None:\n            unsorted_indices = invert_permutation(sorted_indices)\n        self.unsorted_indices = unsorted_indices\n\n        if batch_sizes is not None:\n            if batch_sizes.device.type != \"cpu\":\n                raise ValueError(\n                    \"batch_sizes should always be on CPU. \"\n                    \"Instances of PackedSequence should never be created manually. \"\n                    \"They should be instantiated by functions like pack_sequence \"\n                    \"and pack_padded_sequence in nn.utils.rnn \"\n                )\n            self.data = data\n            self.batch_sizes = batch_sizes\n        else:\n            assert isinstance(data, (list, tuple)) and len(data) == 2\n            self.data = data[0]\n            self.batch_sizes = data[1]\n\n    def pin_memory(self):\n        return PackedSequence(\n            self.data.pin_memory(),\n            self.batch_sizes,\n            bind(self.sorted_indices, lambda t: t.pin_memory()),\n            bind(self.unsorted_indices, lambda t: t.pin_memory()),\n        )\n\n    def cuda(self, *args, **kwargs):\n        ex = flow.tensor((), dtype=self.data.dtype, device=self.data.device).to(\n            *args, **kwargs\n        )\n        if ex.is_cuda:\n            return self.to(*args, **kwargs)\n        return self.to(*args, device=\"cuda\", **kwargs)\n\n    def cpu(self, *args, **kwargs):\n        ex = flow.tensor((), dtype=self.data.dtype, device=self.data.device).to(\n            *args, **kwargs\n        )\n        if ex.device.type == \"cpu\":\n            return self.to(*args, **kwargs)\n        return self.to(*args, device=\"cpu\", **kwargs)\n\n    def double(self):\n        return self.to(dtype=flow.double)\n\n    def float(self):\n        return self.to(dtype=flow.float)\n\n    def half(self):\n        return self.to(dtype=flow.half)\n\n    def long(self):\n        return self.to(dtype=flow.long)\n\n    def int(self):\n        return self.to(dtype=flow.int)\n\n    def short(self):\n        return self.to(dtype=flow.short)\n\n    def char(self):\n        return self.to(dtype=flow.int8)\n\n    def byte(self):\n        return self.to(dtype=flow.uint8)\n\n    def to(self, *args, **kwargs):\n        \"\"\"Performs dtype and/or device conversion on `self.data`.\n\n        It has a similar signature to :meth:`oneflow.Tensor.to`, except optional\n        arguments like `non_blocking` and `copy` should be passed as kwargs,\n        not args, or they will not apply to the index tensors.\n\n        .. 
note::\n\n            If the ``self.data`` Tensor already has the correct :class:`oneflow.dtype`\n            and :class:`oneflow.device`, then ``self`` is returned.\n            Otherwise, returns a copy with the desired configuration.\n        \"\"\"\n        data = self.data.to(*args, **kwargs)\n        if data is self.data:\n            return self\n        else:\n            kwargs = {\n                k: v\n                for k, v in filter(\n                    lambda t: t[0] != \"device\" and t[0] != \"dtype\", kwargs.items()\n                )\n            }\n            sorted_indices = bind(\n                self.sorted_indices, lambda t: t.to(data.device, **kwargs)\n            )\n            unsorted_indices = bind(\n                self.unsorted_indices, lambda t: t.to(data.device, **kwargs)\n            )\n            return PackedSequence(\n                data, self.batch_sizes, sorted_indices, unsorted_indices\n            )\n\n    @property\n    def is_cuda(self):\n        r\"\"\"Returns true if `self.data` is stored on a GPU\"\"\"\n        return self.data.is_cuda\n\n    def is_pinned(self):\n        r\"\"\"Returns true if `self.data` is stored in pinned memory\"\"\"\n        return self.data.is_pinned()\n\n\ndef pack_padded_sequence(\n    input: Tensor,\n    lengths: Tensor,\n    batch_first: bool = False,\n    enforce_sorted: bool = True,\n) -> PackedSequence:\n    \"\"\"The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.utils.rnn.pack_padded_sequence.html.\n    \n    Packs a Tensor containing padded sequences of variable length.\n\n    :attr:`input` can be of size ``T x B x *`` where `T` is the length of the\n    longest sequence (equal to ``lengths[0]``), ``B`` is the batch size, and\n    ``*`` is any number of dimensions (including 0). If ``batch_first`` is\n    ``True``, ``B x T x *`` :attr:`input` is expected.\n\n    For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is\n    ``True``, the sequences should be sorted by length in a decreasing order, i.e.\n    ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest\n    one. `enforce_sorted = True` is only necessary for ONNX export.\n\n    Note:\n        This function accepts any input that has at least two dimensions. You\n        can apply it to pack the labels, and use the output of the RNN with\n        them to compute the loss directly. A Tensor can be retrieved from\n        a :class:`PackedSequence` object by accessing its ``.data`` attribute.\n\n    Args:\n        input (Tensor): padded batch of variable length sequences.\n        lengths (Tensor or list(int)): list of sequence lengths of each batch\n            element (must be on the CPU if provided as a tensor).\n        batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *``\n            format.\n        enforce_sorted (bool, optional): if ``True``, the input is expected to\n            contain sequences sorted by length in a decreasing order. If\n            ``False``, the input will get sorted unconditionally. Only ``True``\n            is currently supported. Default: ``True``.\n\n    Returns:\n        a :class:`PackedSequence` object\n
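\n    For example (the same packed batch is unpacked again in the\n    :func:`pad_packed_sequence` example below):\n\n    .. code-block:: python\n\n        >>> from oneflow.nn.utils.rnn import pack_padded_sequence\n        >>> import oneflow as flow\n\n        >>> seq = flow.tensor([[4,5,6], [1,2,0], [3,0,0]])\n        >>> packed = pack_padded_sequence(seq, [3, 2, 1], batch_first=True, enforce_sorted=True)\n        >>> packed.data\n        tensor([4, 1, 3, 5, 2, 6], dtype=oneflow.int64)\n        >>> packed.batch_sizes\n        tensor([3, 2, 1], dtype=oneflow.int64)\n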
    \"\"\"\n    lengths = flow.as_tensor(lengths, dtype=flow.int64)\n    assert (\n        enforce_sorted == True\n    ), \"Only enforce_sorted == True is supported for now. Please sort the input by length in decreasing order.\"\n    if enforce_sorted:\n        sorted_indices = None\n    else:\n        lengths, sorted_indices = flow.sort(lengths, descending=True)\n        sorted_indices = sorted_indices.to(input.device)\n        batch_dim = 0 if batch_first else 1\n        input = input.index_select(batch_dim, sorted_indices)\n    data, batch_sizes = flow._C.pack_padded_sequence(input, lengths, batch_first)\n    return PackedSequence(data, batch_sizes, sorted_indices, None)\n\n\ndef pad_packed_sequence(\n    sequence: PackedSequence,\n    batch_first: bool = False,\n    padding_value: float = 0.0,\n    total_length: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n    \"\"\"The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.utils.rnn.pad_packed_sequence.html.\n    \n    Pads a packed batch of variable length sequences.\n\n    It is an inverse operation to :func:`pack_padded_sequence`.\n\n    The returned Tensor's data will be of size ``T x B x *``, where `T` is the length\n    of the longest sequence and `B` is the batch size. If ``batch_first`` is True,\n    the data will be transposed into ``B x T x *`` format.\n\n    .. note::\n        :attr:`total_length` is useful to implement the\n        ``pack sequence -> recurrent network -> unpack sequence`` pattern in a\n        :class:`~oneflow.nn.Module` wrapped in :class:`~oneflow.nn.DataParallel`.\n\n    Args:\n        sequence (PackedSequence): batch to pad\n        batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``\n            format.\n        padding_value (float, optional): values for padded elements.\n        total_length (int, optional): if not ``None``, the output will be padded to\n            have length :attr:`total_length`. This method will throw :class:`ValueError`\n            if :attr:`total_length` is less than the max sequence length in\n            :attr:`sequence`.\n\n    Returns:\n        Tuple of Tensor containing the padded sequence, and a Tensor\n        containing the list of lengths of each sequence in the batch.\n        Batch elements will be re-ordered as they were ordered originally when\n        the batch was passed to ``pack_padded_sequence`` or ``pack_sequence``.\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> from oneflow.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n        >>> import oneflow as flow\n\n        >>> seq = flow.tensor([[4,5,6], [1,2,0], [3,0,0]])\n        >>> lens = [3, 2, 1]\n        >>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=True)\n        >>> packed.data\n        tensor([4, 1, 3, 5, 2, 6], dtype=oneflow.int64)\n        >>> packed.batch_sizes\n        tensor([3, 2, 1], dtype=oneflow.int64)\n        >>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)\n        >>> seq_unpacked\n        tensor([[4, 5, 6],\n                [1, 2, 0],\n                [3, 0, 0]], dtype=oneflow.int64)\n        >>> lens_unpacked\n        tensor([3., 2., 1.], dtype=oneflow.float32)\n\n\n    \"\"\"\n    max_seq_length = sequence.batch_sizes.shape[0]\n    if total_length is not None:\n        if total_length < max_seq_length:\n            raise ValueError(\n                \"Expected total_length to be at least the length \"\n                \"of the longest sequence in input, but got \"\n                \"total_length={} and max sequence length being {}\".format(\n                    total_length, max_seq_length\n                )\n            )\n    else:\n        total_length = max_seq_length\n\n    batch_sizes_t = sequence.batch_sizes.contiguous()\n    assert (\n        len(batch_sizes_t.shape) == 1\n        and batch_sizes_t.device.type == \"cpu\"\n        and batch_sizes_t.dtype == flow.int64\n    ), f\"'sequence.batch_sizes' should be a 1D CPU int64 tensor, but got {len(batch_sizes_t.shape)} D {batch_sizes_t.device.type} {batch_sizes_t.dtype} tensor\"\n\n    batch_sizes = batch_sizes_t.numpy()\n    max_batch_size = int(batch_sizes[0])\n    max_real_seq_length = batch_sizes_t.shape[0]\n    max_seq_length = max_real_seq_length\n    if total_length > 0:\n        assert (\n            total_length >= max_seq_length\n        ), f\"Expected total_length to be at least the length of the longest sequence in input, but got total_length={total_length} and max sequence length being {max_seq_length}\"\n        max_seq_length = total_length\n\n    output_size = []  # == [max_seq_length, max_batch_size, *sequence.data.size()[1:]]\n    output_size.append(max_seq_length)\n    output_size.append(max_batch_size)\n    output_size = output_size + list(sequence.data.shape[1:])\n    padded_output = flow.full(\n        output_size,\n        padding_value,\n        dtype=sequence.data.dtype,\n        device=sequence.data.device,\n        requires_grad=sequence.data.requires_grad,\n    )\n    # `padded_output` is a leaf tensor; when it requires grad it must be turned\n    # into a non-leaf tensor by calling `clone` before the following in-place\n    # writes, to avoid a runtime autograd check error.\n    if padded_output.requires_grad:\n        padded_output = padded_output.clone()\n\n    # This will be modified at every iteration, but we reserve memory for it now.\n    tmp_view_size = output_size  # == [-1, -1, *sequence.data.size()[1:]]\n    lengths = flow.empty(max_batch_size)\n    data_offset = 0\n    prev_batch_size = max_batch_size\n    prev_i = 0\n    lengths_idx = max_batch_size - 1\n    for i in range(max_real_seq_length + 1):\n        batch_size = batch_sizes[i] if i != max_real_seq_length else 0\n        if batch_size != prev_batch_size:\n            l = prev_batch_size * (i - prev_i)\n            tmp_view_size[0] = i - prev_i\n            tmp_view_size[1] = prev_batch_size\n
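            # Copy the packed run covering time steps [prev_i, i): a flat slice of\n            # prev_batch_size * (i - prev_i) packed rows is viewed as\n            # (i - prev_i, prev_batch_size, ...) and written into the padded output.\n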
            padded_output[prev_i:i, 0:prev_batch_size] = sequence.data[\n                data_offset : data_offset + l\n            ].view(tmp_view_size)\n            data_offset += l\n            prev_i = i\n\n        dec = prev_batch_size - batch_size\n        if dec > 0:\n            for j in range(dec):\n                lengths[lengths_idx] = i\n                lengths_idx = lengths_idx - 1\n        prev_batch_size = batch_size\n\n    if batch_first:\n        permute_dims = [1, 0]\n        for i in range(2, padded_output.ndim):\n            permute_dims.append(i)\n        padded_output = padded_output.permute(permute_dims)\n\n    unsorted_indices = sequence.unsorted_indices\n    if unsorted_indices is not None:\n        batch_dim = 0 if batch_first else 1\n        return (\n            padded_output.index_select(batch_dim, unsorted_indices),\n            lengths[unsorted_indices],\n        )\n    return padded_output, lengths\n\n\ndef pad_sequence(\n    sequences: Union[Tensor, List[Tensor]],\n    batch_first: bool = False,\n    padding_value: float = 0.0,\n) -> Tensor:\n    \"\"\"The interface is consistent with PyTorch.\n    The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.nn.utils.rnn.pad_sequence.html.\n    \n    Pad a list of variable length Tensors with ``padding_value``\n\n    ``pad_sequence`` stacks a list of Tensors along a new dimension,\n    and pads them to equal length. For example, if the input is a list of\n    sequences with size ``L x *``, the output is of size ``T x B x *`` if\n    ``batch_first`` is ``False``, and ``B x T x *`` otherwise.\n\n    `B` is the batch size. It is equal to the number of elements in ``sequences``.\n    `T` is the length of the longest sequence.\n    `L` is the length of each sequence.\n    `*` is any number of trailing dimensions, including none.\n\n    Note:\n        This function returns a Tensor of size ``T x B x *`` or ``B x T x *``\n        where `T` is the length of the longest sequence. This function assumes\n        trailing dimensions and type of all the Tensors in sequences are the same.\n\n    Args:\n        sequences (list[Tensor]): list of variable length sequences.\n        batch_first (bool, optional): output will be in ``B x T x *`` if True, or in\n            ``T x B x *`` otherwise. Default: False.\n        padding_value (float, optional): value for padded elements. Default: 0.\n\n    Returns:\n        Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.\n        Tensor of size ``B x T x *`` otherwise\n\n    For example:\n\n    .. 
code-block:: python\n    \n        >>> from oneflow.nn.utils.rnn import pad_sequence\n        >>> import oneflow as flow\n\n        >>> a = flow.ones(25, 300)\n        >>> b = flow.ones(22, 300)\n        >>> c = flow.ones(15, 300)\n        >>> out = pad_sequence([a, b, c])\n        >>> out.size()\n        oneflow.Size([25, 3, 300])\n\n    \"\"\"\n    if isinstance(sequences, Tensor):\n        sequences = sequences.unbind(0)\n\n    # assuming the trailing dimensions and type of all the Tensors\n    # in sequences are the same, fetch those from sequences[0]\n    sequences_size = len(sequences)\n    max_size = sequences[0].shape\n    trailing_dims = max_size[1:]\n    lens = [seq.shape[0] for seq in sequences]\n    lens.sort(reverse=True)\n    max_len = lens[0]\n    out_dims = [sequences_size, max_len] if batch_first else [max_len, sequences_size]\n    out_dims = out_dims + list(trailing_dims)\n\n    out = flow.full(\n        out_dims,\n        padding_value,\n        dtype=sequences[0].dtype,\n        device=sequences[0].device,\n        requires_grad=sequences[0].requires_grad,\n    )\n    for i in range(sequences_size):\n        currseq = sequences[i]\n        length_i = currseq.shape[0]\n        # use index notation to prevent duplicate references to the tensor\n        if batch_first:\n            out[i, 0:length_i] = currseq\n        else:\n            out[0:length_i, i] = currseq\n    return out\n\n\ndef unpad_sequence(\n    padded_sequences: Tensor, lengths: Tensor, batch_first: bool = False,\n) -> List[Tensor]:\n    \"\"\"\n    Unpad a padded Tensor into a list of variable length Tensors\n\n    ``unpad_sequence`` unstacks a padded Tensor into a list of variable length Tensors.\n\n    Args:\n        padded_sequences (Tensor): padded sequences.\n        lengths (Tensor): length of original (unpadded) sequences.\n        batch_first (bool, optional): whether the batch dimension comes first. Default: False.\n\n    Returns:\n        a list of :class:`Tensor` objects\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> from oneflow.nn.utils.rnn import pad_sequence, unpad_sequence\n        >>> import oneflow as flow\n        >>> import numpy as np\n\n        >>> a = flow.ones(25, 300)\n        >>> b = flow.ones(22, 300)\n        >>> c = flow.ones(15, 300)\n        >>> sequences = [a, b, c]\n        >>> padded_sequences = pad_sequence(sequences)\n        >>> lengths = flow.as_tensor([v.size(0) for v in sequences])\n        >>> unpadded_sequences = unpad_sequence(padded_sequences, lengths)\n        >>> np.allclose(sequences[0].numpy(), unpadded_sequences[0].numpy())\n        True\n        >>> np.allclose(sequences[1].numpy(), unpadded_sequences[1].numpy())\n        True\n        >>> np.allclose(sequences[2].numpy(), unpadded_sequences[2].numpy())\n        True\n    \"\"\"\n    unpadded_sequences = []\n\n    if not batch_first:\n        padded_sequences = padded_sequences.permute((1, 0, 2))\n\n    max_length = padded_sequences.shape[1]\n    idx = flow.arange(max_length)\n\n    for seq, length in zip(padded_sequences, lengths):\n        mask = idx < length\n        unpacked_seq = seq[mask]\n        unpadded_sequences.append(unpacked_seq)\n\n    return unpadded_sequences\n\n\ndef pack_sequence(\n    sequences: List[Tensor], enforce_sorted: bool = True\n) -> PackedSequence:\n    \"\"\"Packs a list of variable length Tensors\n\n    Equivalent to calling ``pad_sequence`` followed by ``pack_padded_sequence``.\n\n    ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is\n    the length of a sequence and `*` is any number of trailing dimensions,\n    including zero.\n\n    For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted``\n    is ``True``, the sequences should be sorted in the order of decreasing length.\n    ``enforce_sorted = True`` is only necessary for ONNX export.\n\n    Args:\n        sequences (list[Tensor]): A list of sequences of decreasing length.\n        enforce_sorted (bool, optional): if ``True``, checks that the input\n            contains sequences sorted by length in a decreasing order. If\n            ``False``, this condition is not checked. Default: ``True``.\n\n    Returns:\n        a :class:`PackedSequence` object\n\n    For example:\n\n    .. code-block:: python\n    \n        >>> from oneflow.nn.utils.rnn import pack_sequence\n        >>> import oneflow as flow\n\n        >>> a = flow.tensor([1,2,3])\n        >>> b = flow.tensor([4,5])\n        >>> c = flow.tensor([6])\n        >>> packed = pack_sequence([a, b, c])\n        >>> packed.data\n        tensor([1, 4, 6, 2, 5, 3], dtype=oneflow.int64)\n        >>> packed.batch_sizes\n        tensor([3, 2, 1], dtype=oneflow.int64)\n\n    \"\"\"\n    lengths = flow.as_tensor([v.size(0) for v in sequences])\n    return pack_padded_sequence(\n        pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted\n    )\n\n\ndef unpack_sequence(packed_sequences: PackedSequence) -> List[Tensor]:\n    \"\"\"Unpacks PackedSequence into a list of variable length Tensors\n\n    ``packed_sequences`` should be a PackedSequence object.\n\n    Args:\n        packed_sequences (PackedSequence): A PackedSequence object.\n\n    Returns:\n        a list of :class:`Tensor` objects\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> from oneflow.nn.utils.rnn import pack_sequence, unpack_sequence\n        >>> import oneflow as flow\n\n        >>> a = flow.tensor([1,2,3])\n        >>> b = flow.tensor([4,5])\n        >>> c = flow.tensor([6])\n        >>> sequences = [a, b, c]\n        >>> packed_sequences = pack_sequence(sequences)\n        >>> packed_sequences.data\n        tensor([1, 4, 6, 2, 5, 3], dtype=oneflow.int64)\n        >>> packed_sequences.batch_sizes\n        tensor([3, 2, 1], dtype=oneflow.int64)\n        >>> unpacked_sequences = unpack_sequence(packed_sequences)\n        >>> unpacked_sequences\n        [tensor([1, 2, 3], dtype=oneflow.int64), tensor([4, 5], dtype=oneflow.int64), tensor([6], dtype=oneflow.int64)]\n\n    \"\"\"\n\n    padded_sequences, lengths = pad_packed_sequence(packed_sequences, batch_first=True)\n    unpacked_sequences = unpad_sequence(padded_sequences, lengths, batch_first=True)\n    return unpacked_sequences\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/nn/utils/skip_init.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport inspect\nfrom oneflow.nn.modules.module import Module\n\n\ndef skip_init(module_cls, *args, **kwargs):\n    if not issubclass(module_cls, Module):\n        raise RuntimeError(\"Expected a Module; got {}\".format(module_cls))\n    if \"device\" not in inspect.signature(module_cls).parameters:\n        raise RuntimeError(\"Module must support a 'device' arg to skip initialization\")\n\n    final_device = kwargs.pop(\"device\", \"cpu\")\n    kwargs[\"device\"] = \"meta\"\n    return module_cls(*args, **kwargs).to_empty(device=final_device)\n"
  },
  {
    "path": "python/oneflow/nn/utils/weight_norm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\nfrom typing import Any, TypeVar\nfrom oneflow.nn.modules.module import Module\n\n\ndef _norm_except_dim_0(v: Tensor):\n    output_size = [1] * v.dim()\n    output_size[0] = v.size(0)\n    return flow.linalg.norm(v.view(v.size(0), -1), ord=2, dim=1).view(*output_size)\n\n\ndef _norm_except_dim(v: Tensor, dim: int):\n    assert -v.dim() <= dim <= v.dim() - 1, \"dim out of range\"\n\n    if dim == -1:\n        return flow.linalg.norm(v, ord=\"fro\")\n    elif dim == 0:\n        return _norm_except_dim_0(v)\n    elif dim == v.dim() - 1:\n        output_size = [1] * v.dim()\n        output_size[v.dim() - 1] = v.size(v.dim() - 1)\n        return flow.linalg.norm(v.view(-1, v.size(v.dim() - 1)), ord=2, dim=0).view(\n            *output_size\n        )\n    else:\n        return flow.transpose(_norm_except_dim_0(flow.transpose(v, 0, dim)), 0, dim)\n\n\nclass WeightNorm(object):\n    name: str\n    dim: int\n\n    def __init__(self, name: str, dim: int) -> None:\n        if dim is None:\n            dim = -1\n        self.name = name\n        self.dim = dim\n\n    def compute_weight(self, module: Module) -> Any:\n        g = getattr(module, self.name + \"_g\")\n        v = getattr(module, self.name + \"_v\")\n        return v * (g / _norm_except_dim(v, self.dim))\n\n    @staticmethod\n    def apply(module, name: str, dim: int) -> \"WeightNorm\":\n        for k, hook in module._forward_pre_hooks.items():\n            if isinstance(hook, WeightNorm) and hook.name == name:\n                raise RuntimeError(\n                    \"Cannot register two weight_norm hooks on \"\n                    \"the same parameter {}\".format(name)\n                )\n\n        if dim is None:\n            dim = -1\n\n        fn = WeightNorm(name, dim)\n\n        weight = getattr(module, name)\n        del module._parameters[name]\n\n        # add g and v as new parameters and express w as g/||v|| * v\n        module.register_parameter(\n            name + \"_g\", flow.nn.Parameter(_norm_except_dim(weight, dim))\n        )\n        module.register_parameter(name + \"_v\", flow.nn.Parameter(weight))\n        setattr(module, name, fn.compute_weight(module))\n\n        # recompute weight before every forward()\n        module.register_forward_pre_hook(fn)\n\n        return fn\n\n    def remove(self, module: Module) -> None:\n        weight = self.compute_weight(module)\n        delattr(module, self.name)\n        del module._parameters[self.name + \"_g\"]\n        del module._parameters[self.name + \"_v\"]\n        setattr(module, self.name, flow.nn.Parameter(weight))\n\n    def __call__(self, module: Module, inputs: Any) -> None:\n        setattr(module, self.name, self.compute_weight(module))\n\n\nT_module = TypeVar(\"T_module\", bound=Module)\n\n\ndef weight_norm(module: T_module, name: str = \"weight\", dim: int = 0) -> 
T_module:\n    r\"\"\"Applies weight normalization to a parameter in the given module.\n\n    .. math::\n        \\mathbf{w}=g \\frac{\\mathbf{v}}{\\|\\mathbf{v}\\|}\n\n    Weight normalization is a reparameterization that decouples the magnitude\n    of a weight tensor from its direction. This replaces the parameter specified\n    by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude\n    (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``).\n    Weight normalization is implemented via a hook that recomputes the weight\n    tensor from the magnitude and direction before every :meth:`~Module.forward`\n    call.\n\n    By default, with ``dim=0``, the norm is computed independently per output\n    channel/plane. To compute a norm over the entire weight tensor, use\n    ``dim=None``.\n    \n    See https://arxiv.org/abs/1602.07868\n\n    This documentation is referenced from the PyTorch documentation:\n    https://pytorch.org/docs/1.10/generated/torch.nn.utils.weight_norm.html.\n\n    Args:\n        module (Module): containing module\n        name (str, optional): name of weight parameter\n        dim (int, optional): dimension over which to compute the norm\n\n    Returns:\n        The original module with the weight norm hook\n    \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> m = flow.nn.utils.weight_norm(flow.nn.Linear(20, 40), name='weight')\n        >>> m\n        Linear(in_features=20, out_features=40, bias=True)\n        >>> m.weight_g.size()\n        oneflow.Size([40, 1])\n        >>> m.weight_v.size()\n        oneflow.Size([40, 20])\n\n    \"\"\"\n    WeightNorm.apply(module, name, dim)\n    return module\n\n\ndef remove_weight_norm(module: T_module, name: str = \"weight\") -> T_module:\n    r\"\"\"Removes the weight normalization reparameterization from a module.\n\n    Args:\n        module (Module): containing module\n        name (str, optional): name of weight parameter\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> m = flow.nn.utils.weight_norm(flow.nn.Linear(20, 40))\n        >>> flow.nn.utils.remove_weight_norm(m)\n        Linear(in_features=20, out_features=40, bias=True)\n\n    \"\"\"\n    for k, hook in module._forward_pre_hooks.items():\n        if isinstance(hook, WeightNorm) and hook.name == name:\n            hook.remove(module)\n            del module._forward_pre_hooks[k]\n            return module\n\n    raise ValueError(\"weight_norm of '{}' not found in {}\".format(name, module))\n\n\nif __name__ == \"__main__\":\n    import doctest\n\n    doctest.testmod(raise_on_error=True)\n"
  },
  {
    "path": "python/oneflow/one_embedding.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Callable, Dict, Iterator, List, Union\nimport oneflow as flow\nfrom oneflow.nn.modules.module import Module\nfrom oneflow.optim.optimizer import Optimizer\nfrom oneflow.nn.parameter import Parameter\nimport json\nimport datetime\nfrom oneflow._oneflow_internal import OneEmbeddingHandler\nfrom oneflow._oneflow_internal import PersistentTableReader\nfrom oneflow._oneflow_internal import PersistentTableWriter\nimport numpy as np\nimport traceback\nfrom oneflow import nn\nimport oneflow.framework.graph_build_util as graph_build_util\n\n\ndef _check_initializer(initializer):\n    assert isinstance(initializer, dict)\n    assert initializer.__contains__(\"type\")\n    initializer_type = initializer[\"type\"]\n    assert initializer_type in [\"uniform\", \"normal\", \"constant\", \"trunc_normal\"]\n    if initializer_type == \"uniform\":\n        assert initializer.__contains__(\"low\")\n        assert initializer.__contains__(\"high\")\n    elif initializer_type == \"normal\":\n        assert initializer.__contains__(\"mean\")\n        assert initializer.__contains__(\"std\")\n    elif initializer_type == \"constant\":\n        assert initializer.__contains__(\"value\")\n    elif initializer_type == \"trunc_normal\":\n        assert initializer.__contains__(\"mean\")\n        assert initializer.__contains__(\"std\")\n        assert initializer.__contains__(\"a\")\n        assert initializer.__contains__(\"b\")\n    else:\n        raise NotImplementedError(\"unsupported initializer_type\")\n\n\ndef _check_cache(cache):\n    assert isinstance(cache, dict)\n    assert cache.__contains__(\"policy\")\n    assert cache[\"policy\"] in [\"lru\", \"full\"]\n    cache_memory_budget_mb = 0\n    if cache.__contains__(\"cache_memory_budget_mb\"):\n        cache_memory_budget_mb = cache[\"cache_memory_budget_mb\"]\n    capacity = 0\n    if cache.__contains__(\"capacity\"):\n        capacity = cache[\"capacity\"]\n    assert cache_memory_budget_mb > 0 or capacity > 0\n    assert cache.__contains__(\"value_memory_kind\")\n    assert cache[\"value_memory_kind\"] in [\"device\", \"host\"]\n\n\ndef _init(\n    name, embedding_dims, dtype, key_type, tables, store_options, default_initializer\n):\n    default_initializer = default_initializer or {\n        \"type\": \"normal\",\n        \"mean\": 0,\n        \"std\": 0.05,\n    }\n    key_value_store_options = {}\n    embedding_tables = {}\n    key_value_store_options[\"name\"] = name\n\n    if isinstance(embedding_dims, (list, tuple)):\n        column_dims = embedding_dims\n        embedding_dim = sum(embedding_dims)\n    else:\n        assert embedding_dims > 0\n        column_dims = [embedding_dims]\n        embedding_dim = embedding_dims\n    parallel_num = flow.env.get_world_size()\n    key_type_size = np.dtype(\n        flow.convert_oneflow_dtype_to_numpy_dtype(key_type)\n    ).itemsize\n    assert key_type_size > 0\n    
key_value_store_options[\"key_type_size\"] = key_type_size\n    value_type_size = np.dtype(\n        flow.convert_oneflow_dtype_to_numpy_dtype(dtype)\n    ).itemsize\n    assert value_type_size > 0\n    key_value_store_options[\"value_type_size\"] = value_type_size\n    key_value_store_options[\"value_type\"] = str(dtype)\n    scale_factor = store_options[\"size_factor\"]\n    storage_dim = store_options[\"storage_dim\"]\n    if storage_dim != -1:\n        key_value_store_options[\"storage_dim\"] = storage_dim\n    else:\n        key_value_store_options[\"storage_dim\"] = scale_factor * embedding_dim\n    # kv store\n    assert store_options.__contains__(\"kv_store\")\n    kv_store = store_options[\"kv_store\"]\n    assert isinstance(kv_store, dict)\n    if kv_store.__contains__(\"caches\"):\n        caches = kv_store[\"caches\"]\n        assert isinstance(caches, (dict, list, tuple))\n        if isinstance(caches, dict):\n            _check_cache(caches)\n            caches = [caches]\n        else:\n            assert len(caches) <= 2\n            for i in range(len(caches)):\n                assert isinstance(caches[i], dict)\n                _check_cache(caches[i])\n        for i in range(len(caches)):\n            if caches[i].__contains__(\"capacity\"):\n                caches[i][\"capacity\"] = caches[i][\"capacity\"] // parallel_num\n    assert kv_store.__contains__(\"persistent_table\")\n    persistent_table = kv_store[\"persistent_table\"]\n    assert isinstance(persistent_table, dict)\n    assert persistent_table.__contains__(\"path\")\n    persistent_table_path = persistent_table[\"path\"]\n    assert isinstance(persistent_table_path, (str, list, tuple))\n    if isinstance(persistent_table_path, (list, tuple)):\n        assert len(persistent_table_path) == parallel_num\n    if persistent_table.__contains__(\"physical_block_size\"):\n        assert persistent_table[\"physical_block_size\"] in [512, 4096]\n    else:\n        persistent_table[\"physical_block_size\"] = 4096\n    if persistent_table.__contains__(\"capacity_hint\"):\n        assert persistent_table[\"capacity_hint\"] >= 0\n        persistent_table[\"capacity_hint\"] = (\n            persistent_table[\"capacity_hint\"] // parallel_num\n        )\n    key_value_store_options[\"kv_store\"] = kv_store\n    # initializer\n    if tables is not None:\n        assert isinstance(tables, (list, tuple))\n        for i in range(len(tables)):\n            table = tables[i]\n            if table.__contains__(\"columns\"):\n                assert not table.__contains__(\"initializer\")\n                columns = table[\"columns\"]\n                assert len(columns) == len(column_dims)\n                for column in columns:\n                    assert isinstance(column, dict)\n                    assert column.__contains__(\"initializer\")\n                    _check_initializer(column[\"initializer\"])\n            else:\n                assert isinstance(table, dict)\n                assert table.__contains__(\"initializer\")\n                _check_initializer(table[\"initializer\"])\n                columns = []\n                for j in range(len(column_dims)):\n                    columns.append(make_column_options(table[\"initializer\"]))\n                table[\"columns\"] = columns\n                del table[\"initializer\"]\n        embedding_tables[\"tables\"] = tables\n    else:\n        assert default_initializer is not None\n        _check_initializer(default_initializer)\n        columns = []\n        for j in 
range(len(column_dims)):\n            columns.append(make_column_options(default_initializer))\n        embedding_tables[\"tables\"] = [{\"columns\": columns}]\n    embedding_tables[\"column_dims\"] = column_dims\n    key_value_store_options[\"parallel_num\"] = parallel_num\n    return embedding_dim, embedding_tables, key_value_store_options\n\n\nclass Embedding(Module):\n    def __init__(\n        self,\n        name,\n        embedding_dim,\n        dtype,\n        key_type,\n        tables,\n        store_options,\n        default_initializer=None,\n        padding_idx=None,\n        seed=0,\n    ):\n        super().__init__()\n        self.dtype = dtype\n        self.key_type = key_type\n        parallel_num = flow.env.get_world_size()\n        self.embedding_dim, embedding_tables, key_value_store_options = _init(\n            name,\n            embedding_dim,\n            dtype,\n            key_type,\n            tables,\n            store_options,\n            default_initializer,\n        )\n        self.storage_dim = key_value_store_options[\"storage_dim\"]\n        self.embedding_name = key_value_store_options[\"name\"]\n        self.seed = seed\n        self.is_full_cache = (\n            len(key_value_store_options[\"kv_store\"][\"caches\"]) > 0\n            and key_value_store_options[\"kv_store\"][\"caches\"][0][\"policy\"] == \"full\"\n        )\n        self.key_value_store_options = json.dumps(key_value_store_options)\n        self.embedding_tables = json.dumps(embedding_tables)\n        self.num_tables = len(embedding_tables[\"tables\"])\n        self.local_rank = flow.env.get_local_rank()\n        self.rank_id = flow.env.get_rank()\n        self.world_size = flow.env.get_world_size()\n        self.handler = OneEmbeddingHandler(\n            self.key_value_store_options, self.local_rank, self.rank_id, self.world_size\n        )\n\n        self.shadow = flow.nn.Parameter(flow.Tensor(1))\n        self.padding_idx = padding_idx\n        self.embedding = None\n\n    def _save_to_state_dict(self, destination, prefix, keep_vars):\n        super()._save_to_state_dict(destination, prefix, keep_vars)\n        snapshot_timestamp_tensor = flow.tensor(\n            datetime.datetime.now().timestamp(), dtype=flow.float64, device=\"cuda\"\n        )\n        # Broadcast timestamp tensor from master rank.\n        flow.comm.broadcast(snapshot_timestamp_tensor, src=0)\n        snapshot_timestamp = float(snapshot_timestamp_tensor.numpy())\n        snapshot_timestamp_datetime = datetime.datetime.fromtimestamp(\n            snapshot_timestamp\n        )\n        snapshot_timestamp_str = snapshot_timestamp_datetime.strftime(\n            \"%Y-%m-%d-%H-%M-%S-%f\"\n        )\n        self.handler.SaveSnapshot(snapshot_timestamp_str)\n        destination[prefix + \"OneEmbeddingSnapshot\"] = snapshot_timestamp_str\n        destination[\n            prefix + \"OneEmbeddingKeyValueOptions\"\n        ] = self.key_value_store_options\n\n    def _load_from_state_dict(\n        self,\n        state_dict,\n        prefix,\n        local_metadata,\n        strict,\n        missing_keys,\n        unexpected_keys,\n        error_msgs,\n    ):\n        key = prefix + \"OneEmbeddingSnapshot\"\n        if key in state_dict:\n            saved_snapshot_name = state_dict[key]\n            try:\n                self.handler.LoadSnapshot(saved_snapshot_name)\n            except Exception as ex:\n                error_msgs.append(\n                    'Loading OneEmbedding snapshot named \"{}\" failed, 
please check whether the snapshot exists'.format(\n                        saved_snapshot_name\n                    )\n                )\n\n    def save_snapshot(self, snapshot_name):\n        \"\"\"save snapshot\n\n        Args:\n            snapshot_name (str): the snapshot name; the snapshot will be saved in the \"snapshots\" dir under your configured persistent path\n    \n        For example:\n\n        .. code-block:: python\n\n            >>> import oneflow as flow\n            >>> # use an embedding created by flow.one_embedding.MultiTableEmbedding\n            >>> embedding.save_snapshot(\"my_snapshot1\")\n            >>> # a snapshot named \"my_snapshot1\" has been saved in the \"snapshots\" dir under your configured persistent path,\n            >>> # and it can be reloaded by the load_snapshot method\n        \"\"\"\n        self.handler.SaveSnapshot(snapshot_name)\n\n    def load_snapshot(self, snapshot_name):\n        \"\"\"load snapshot\n\n        Args:\n            snapshot_name (str): the snapshot name; the snapshot will be loaded from your configured persistent path\n    \n        For example:\n\n        .. code-block:: python\n\n            >>> import oneflow as flow\n            >>> # use an embedding created by flow.one_embedding.MultiTableEmbedding\n            >>> embedding.load_snapshot(\"my_snapshot1\")\n            >>> # load a snapshot named \"my_snapshot1\" from your configured persistent path\n        \"\"\"\n        self.handler.LoadSnapshot(snapshot_name)\n\n    def forward(self, ids, table_ids=None):\n        \"\"\"Embedding lookup operation\n\n        Args:\n            ids (flow.tensor): the feature ids\n            table_ids (flow.tensor, optional): the table_id of each id; must have the same shape as ids. There is no need to pass table_ids if only one table is configured, or if ids has shape (batch_size, num_tables) and each column's ids belong to the corresponding table; otherwise, you should pass table_ids.\n\n        Returns:\n            flow.tensor: the result of embedding lookup\n        \"\"\"\n        assert self.key_type == ids.dtype, \"ids dtype must equal key_type\"\n        embedding = flow._C.one_embedding_fused_lookup(\n            self.shadow,\n            ids,\n            table_ids,\n            self.dtype,\n            self.embedding_name,\n            self.storage_dim,\n            self.embedding_dim,\n            self.is_full_cache,\n            self.num_tables,\n            self.embedding_tables,\n            self.padding_idx,\n            self.seed,\n        )\n        if embedding.requires_grad and not graph_build_util.lazy_mode.is_enabled():\n            if self.embedding is not None:\n                raise ValueError(\n                    \"You are training without setting the embedding optimizer. Please add flow.one_embedding.Optimizer after the optimizer.\"\n                )\n\n            self.embedding = embedding\n            self.embedding.retain_grad()\n            self.ids = ids\n            self.table_ids = table_ids\n        return embedding\n\n    def shuffle_and_lookup(self, state_initializer):\n        embedding_grad = self.embedding.grad\n        if self.world_size > 1:\n            (\n                num_unique_matrix,\n                inverse_unique_partition_indices,\n                cur_rank_num_unique,\n                cur_rank_unique_ids,\n                cur_rank_unique_table_ids,\n                cur_rank_inverse_indices,\n            ) = flow._C.one_embedding_id_shuffle(\n                self.ids, self.table_ids, self.num_tables, 
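\n                # shuffle ids across ranks and deduplicate them; the returned indices\n                # are used below to shuffle the computed gradients back\n                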
self.embedding_name\n            )\n            unique_values = flow._C.one_embedding_lookup(\n                cur_rank_num_unique,\n                cur_rank_unique_ids,\n                cur_rank_unique_table_ids,\n                self.dtype,\n                self.dtype,\n                self.storage_dim,\n                self.embedding_dim,\n                self.embedding_name,\n                self.embedding_tables,\n                state_initializer,\n                seed=self.seed,\n            )\n            cur_rank_unique_embedding_grad = flow._C.one_embedding_embedding_gradient_shuffle(\n                embedding_grad,\n                num_unique_matrix,\n                cur_rank_inverse_indices,\n                inverse_unique_partition_indices,\n                self.embedding_name,\n            )\n        else:\n            (\n                cur_rank_num_unique,\n                cur_rank_unique_ids,\n                cur_rank_unique_table_ids,\n                inverse_indices,\n            ) = flow._C.one_embedding_unique_key_value_pair(\n                self.ids, self.table_ids, self.num_tables, self.embedding_name\n            )\n            unique_values = flow._C.one_embedding_lookup(\n                cur_rank_num_unique,\n                cur_rank_unique_ids,\n                cur_rank_unique_table_ids,\n                self.dtype,\n                self.dtype,\n                self.storage_dim,\n                self.embedding_dim,\n                self.embedding_name,\n                self.embedding_tables,\n                state_initializer,\n                seed=self.seed,\n            )\n            cur_rank_unique_embedding_grad = flow._C.unsorted_segment_sum(\n                embedding_grad,\n                inverse_indices,\n                axis=0,\n                num_segments=unique_values.shape[0],\n            )\n        self.embedding = None\n        return (\n            cur_rank_num_unique,\n            cur_rank_unique_ids,\n            unique_values,\n            cur_rank_unique_embedding_grad,\n        )\n\n    def sgd_update(self, param_group, step):\n        lr = param_group[\"lr\"]\n        l2 = param_group[\"weight_decay\"]\n        momentum = param_group[\"momentum\"]\n        (\n            cur_rank_num_unique,\n            cur_rank_unique_ids,\n            unique_values,\n            cur_rank_unique_embedding_grad,\n        ) = self.shuffle_and_lookup(\"\")\n        updated_values = flow._C.one_embedding_sgd_update(\n            cur_rank_num_unique,\n            unique_values,\n            cur_rank_unique_embedding_grad,\n            learning_rate_val=lr,\n            scale=1.0,\n            weight_decay=l2,\n            momentum=momentum,\n            line_size=self.storage_dim,\n            embedding_size=self.embedding_dim,\n            embedding_name=self.embedding_name,\n        )\n        flow._C.one_embedding_embedding_put(\n            cur_rank_num_unique,\n            cur_rank_unique_ids,\n            updated_values,\n            self.embedding_name,\n            self.storage_dim,\n        )\n\n    def adam_update(self, param_group, step):\n        line_size = self.storage_dim\n        embedding_size = self.embedding_dim\n        lr = param_group[\"lr\"]\n        # not adjust, because it has been set in optimizer's step\n        bias_correction1 = param_group[\"bias_correction1\"]\n        bias_correction2 = param_group[\"bias_correction2\"]\n        l2 = param_group[\"weight_decay\"]\n        beta1 = param_group[\"betas\"][0]\n        beta2 = 
param_group[\"betas\"][1]\n        epsilon = param_group[\"eps\"]\n        do_bias_correction = param_group[\"do_bias_correction\"]\n        amsgrad = param_group[\"amsgrad\"]\n        assert amsgrad == False, \"one_embedding's adam not support amsgrad\"\n        state_initializer = [make_constant_initializer(0), make_constant_initializer(0)]\n        (\n            cur_rank_num_unique,\n            cur_rank_unique_ids,\n            unique_values,\n            cur_rank_unique_embedding_grad,\n        ) = self.shuffle_and_lookup(json.dumps(state_initializer))\n        updated_values = flow._C.one_embedding_adam_update(\n            cur_rank_num_unique,\n            unique_values,\n            cur_rank_unique_embedding_grad,\n            learning_rate_val=lr,\n            scale=1.0,\n            weight_decay=l2,\n            beta1=beta1,\n            beta2=beta2,\n            bias_correction1_val=bias_correction1,\n            bias_correction2_val=bias_correction2,\n            epsilon=epsilon,\n            do_bias_correction=do_bias_correction,\n            line_size=line_size,\n            embedding_size=embedding_size,\n            embedding_name=self.embedding_name,\n        )\n        flow._C.one_embedding_embedding_put(\n            cur_rank_num_unique,\n            cur_rank_unique_ids,\n            updated_values,\n            self.embedding_name,\n            line_size,\n        )\n\n    def adagrad_update(self, param_group, step):\n        lr = param_group[\"lr\"]\n        l2 = param_group[\"weight_decay\"]\n        epsilon = param_group[\"eps\"]\n        lr_decay = param_group[\"lr_decay\"]\n        initial_accumulator_value = param_group[\"initial_accumulator_value\"]\n        state_initializer = [make_constant_initializer(initial_accumulator_value)]\n        (\n            cur_rank_num_unique,\n            cur_rank_unique_ids,\n            unique_values,\n            cur_rank_unique_embedding_grad,\n        ) = self.shuffle_and_lookup(json.dumps(state_initializer))\n        updated_values = flow._C.one_embedding_adagrad_update(\n            cur_rank_num_unique,\n            unique_values,\n            cur_rank_unique_embedding_grad,\n            train_step_val=step + 1,\n            learning_rate_val=lr,\n            scale=1.0,\n            weight_decay=l2,\n            lr_decay=lr_decay,\n            epsilon=epsilon,\n            line_size=self.storage_dim,\n            embedding_size=self.embedding_dim,\n            embedding_name=self.embedding_name,\n        )\n        flow._C.one_embedding_embedding_put(\n            cur_rank_num_unique,\n            cur_rank_unique_ids,\n            updated_values,\n            self.embedding_name,\n            self.storage_dim,\n        )\n\n    def ftrl_update(self, param_group, step):\n        lr = param_group[\"lr\"]\n        l2 = param_group[\"weight_decay\"]\n        lr_power = param_group[\"lr_power\"]\n        lambda1 = param_group[\"lambda1\"]\n        lambda2 = param_group[\"lambda2\"]\n        beta = param_group[\"beta\"]\n        initial_accumulator_value = param_group[\"initial_accumulator_value\"]\n        state_initializer = [\n            make_constant_initializer(initial_accumulator_value),\n            make_constant_initializer(initial_accumulator_value),\n        ]\n        (\n            cur_rank_num_unique,\n            cur_rank_unique_ids,\n            unique_values,\n            cur_rank_unique_embedding_grad,\n        ) = self.shuffle_and_lookup(json.dumps(state_initializer))\n        updated_values = 
        updated_values = flow._C.one_embedding_ftrl_update(\n            cur_rank_num_unique,\n            unique_values,\n            cur_rank_unique_embedding_grad,\n            learning_rate_val=lr,\n            scale=1.0,\n            weight_decay=l2,\n            lr_power=lr_power,\n            lambda1=lambda1,\n            lambda2=lambda2,\n            beta=beta,\n            line_size=self.storage_dim,\n            embedding_size=self.embedding_dim,\n            embedding_name=self.embedding_name,\n        )\n        flow._C.one_embedding_embedding_put(\n            cur_rank_num_unique,\n            cur_rank_unique_ids,\n            updated_values,\n            self.embedding_name,\n            self.storage_dim,\n        )\n\n\ndef make_device_mem_store_options(\n    persistent_path, capacity, size_factor=1, storage_dim=-1, physical_block_size=4096\n):\n    \"\"\"make the GPU-only store_options param of MultiTableEmbedding\n\n    Args:\n        persistent_path (str, list): persistent storage path of Embedding. If passed a str, the current rank's Embedding will be saved under path/rank_id-num_ranks. If passed a list, the list length must equal num_ranks, and each element is the path of the corresponding rank's Embedding.\n        capacity (int): total capacity of Embedding\n        size_factor (int, optional): store size factor of embedding_dim: for SGD with momentum = 0 it should be 1, with momentum > 0 it should be 2, and for Adam it should be 3. Defaults to 1.\n        storage_dim (int, optional): number of elements in embedding storage; if storage_dim is set, the size_factor param is ignored. For SGD with momentum = 0 it should be embedding_size*1, with momentum > 0 embedding_size*2, and for Adam embedding_size*3. Defaults to -1.\n        physical_block_size (int, optional): physical_block_size should be the disk sector size. Defaults to 4096.\n\n    Returns:\n        dict: GPU-only store_options param of MultiTableEmbedding\n\n    See also :func:`oneflow.one_embedding.make_cached_ssd_store_options`\n    \"\"\"\n\n    assert isinstance(persistent_path, (str, list, tuple))\n    assert capacity > 0\n    options = {\n        \"kv_store\": {\n            \"caches\": [\n                {\n                    \"policy\": \"full\",\n                    \"capacity\": int(capacity),\n                    \"value_memory_kind\": \"device\",\n                }\n            ],\n            \"persistent_table\": {\n                \"path\": persistent_path,\n                \"physical_block_size\": physical_block_size,\n                \"capacity_hint\": int(capacity),\n            },\n        },\n        \"size_factor\": size_factor,\n        \"storage_dim\": storage_dim,\n    }\n    return options\n\n\ndef make_cached_ssd_store_options(\n    cache_budget_mb,\n    persistent_path,\n    capacity=None,\n    size_factor=1,\n    storage_dim=-1,\n    physical_block_size=4096,\n    host_cache_budget_mb=0,\n):\n    \"\"\"make the store_options param of MultiTableEmbedding that uses SSD as persistent storage with GPU and host memory as cache. If cache_budget_mb > 0 and host_cache_budget_mb > 0, GPU and host memory are used as a multi-level cache.\n\n    Args:\n        cache_budget_mb (int): the per-GPU cache budget in MB.\n        persistent_path (str, list): persistent storage path of Embedding; it must be on a fast SSD because of frequent random disk access during training. If passed a str, the current rank's Embedding will be saved under path/rank_id-num_ranks. 
If passed a list, the list length must equal num_ranks, and each element is the path of the corresponding rank's Embedding.\n        capacity (int): total capacity of Embedding\n        size_factor (int, optional): store size factor of embedding_dim: for SGD with momentum = 0 it should be 1, with momentum > 0 it should be 2, and for Adam it should be 3. Defaults to 1.\n        storage_dim (int, optional): number of elements in embedding storage; if storage_dim is set, the size_factor param is ignored. For SGD with momentum = 0 it should be embedding_size*1, with momentum > 0 embedding_size*2, and for Adam embedding_size*3. Defaults to -1.\n        physical_block_size (int, optional): physical_block_size should be the disk sector size. Defaults to 4096.\n        host_cache_budget_mb (int): the per-rank host-memory cache budget in MB. Defaults to 0.\n\n    Returns:\n        dict: store_options param of MultiTableEmbedding using SSD with GPU and host cache\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow    \n        >>> store_options = flow.one_embedding.make_cached_ssd_store_options(\n        >>>     cache_budget_mb=8192, persistent_path=\"/your_path_to_ssd\", capacity=vocab_size,\n        >>> )\n        >>> # pass the store_options to the \"store_options\" param of flow.one_embedding.MultiTableEmbedding\n        >>> # ...\n    \"\"\"\n    assert isinstance(persistent_path, (str, list, tuple))\n    assert cache_budget_mb > 0 or host_cache_budget_mb > 0\n    if capacity is not None:\n        assert capacity > 0\n    else:\n        capacity = 0\n\n    cache_list = []\n    if cache_budget_mb > 0:\n        cache_list.append(\n            {\n                \"policy\": \"lru\",\n                \"cache_memory_budget_mb\": cache_budget_mb,\n                \"value_memory_kind\": \"device\",\n            }\n        )\n    if host_cache_budget_mb > 0:\n        cache_list.append(\n            {\n                \"policy\": \"lru\",\n                \"cache_memory_budget_mb\": host_cache_budget_mb,\n                \"value_memory_kind\": \"host\",\n            }\n        )\n\n    options = {\n        \"kv_store\": {\n            \"caches\": cache_list,\n            \"persistent_table\": {\n                \"path\": persistent_path,\n                \"physical_block_size\": physical_block_size,\n                \"capacity_hint\": int(capacity),\n            },\n        },\n        \"size_factor\": size_factor,\n        \"storage_dim\": storage_dim,\n    }\n    return options\n\n\ndef make_cached_host_mem_store_options(\n    cache_budget_mb,\n    persistent_path,\n    capacity,\n    size_factor=1,\n    storage_dim=-1,\n    physical_block_size=4096,\n):\n    \"\"\"make the store_options param of MultiTableEmbedding that uses host memory as storage with GPU memory as cache\n\n    Args:\n        cache_budget_mb (int): the per-GPU cache budget in MB.\n        persistent_path (str, list): persistent storage path of Embedding. If passed a str, the current rank's Embedding will be saved under path/rank_id-num_ranks. If passed a list, the list length must equal num_ranks, and each element is the path of the corresponding rank's Embedding.\n        capacity (int): total capacity of Embedding\n        size_factor (int, optional): store size factor of embedding_dim: for SGD with momentum = 0 it should be 1, with momentum > 0 it should be 2, and for Adam it should be 3. 
Defaults to 1.\n        storage_dim (int, optional): number of elements in embedding storage; if storage_dim is set, the size_factor param is ignored. For SGD with momentum = 0 it should be embedding_size*1, with momentum > 0 embedding_size*2, and for Adam embedding_size*3. Defaults to -1.\n        physical_block_size (int, optional): physical_block_size should be the disk sector size. Defaults to 4096.\n\n    Returns:\n        dict: store_options param of MultiTableEmbedding using host memory with GPU cache\n\n    See also :func:`oneflow.one_embedding.make_cached_ssd_store_options`\n    \"\"\"\n    assert isinstance(persistent_path, (str, list, tuple))\n    assert cache_budget_mb > 0\n    assert capacity > 0\n    options = {\n        \"kv_store\": {\n            \"caches\": [\n                {\n                    \"policy\": \"lru\",\n                    \"cache_memory_budget_mb\": cache_budget_mb,\n                    \"value_memory_kind\": \"device\",\n                },\n                {\n                    \"policy\": \"full\",\n                    \"capacity\": int(capacity),\n                    \"value_memory_kind\": \"host\",\n                },\n            ],\n            \"persistent_table\": {\n                \"path\": persistent_path,\n                \"physical_block_size\": physical_block_size,\n                \"capacity_hint\": int(capacity),\n            },\n        },\n        \"size_factor\": size_factor,\n        \"storage_dim\": storage_dim,\n    }\n    return options\n\n\ndef make_uniform_initializer(low=0.0, high=1.0):\n    \"\"\"make uniform initializer param of make_table_options\n\n    Args:\n        low (float): A python scalar. Lower bound of the range of random values to generate.\n        high (float): A python scalar. Upper bound of the range of random values to generate.\n\n    Returns:\n        dict: initializer param of make_table_options\n    \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> initializer = flow.one_embedding.make_uniform_initializer(low=-scale, high=scale)\n        >>> # pass the initializer to flow.one_embedding.make_table_options\n        >>> # ...\n    \"\"\"\n    return {\"type\": \"uniform\", \"low\": low, \"high\": high}\n\n\ndef make_normal_initializer(mean=0.0, std=1.0):\n    \"\"\"make normal initializer param of make_table_options\n\n    Args:\n        mean (float): A python scalar. Mean of the random values to generate.\n        std (float): A python scalar. Standard deviation of the random values to generate.\n\n    Returns:\n        dict: initializer param of make_table_options\n    \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> initializer = flow.one_embedding.make_normal_initializer(mean=0, std=0.01)\n        >>> # pass the initializer to flow.one_embedding.make_table_options\n        >>> # ...\n    \"\"\"\n    return {\"type\": \"normal\", \"mean\": mean, \"std\": std}\n\n\ndef make_constant_initializer(value):\n    \"\"\"make constant initializer param of make_table_options\n\n    Args:\n        value (float): A python scalar. The constant value to generate.\n\n    Returns:\n        dict: initializer param of make_table_options\n\n    For example:\n\n    .. 
code-block:: python\n\n        >>> import oneflow as flow\n        >>> initializer = flow.one_embedding.make_constant_initializer(value=0)\n        >>> # pass the initializer to flow.one_embedding.make_table_options\n        >>> # ...\n    \"\"\"\n    return {\"type\": \"constant\", \"value\": value}\n\n\ndef make_trunc_normal_initializer(mean=0.0, std=1.0, a=-2.0, b=2.0):\n    \"\"\"make truncated normal initializer param of make_table_options\n\n    Args:\n        mean (float): A python scalar. Mean of the random values to generate.\n        std (float): A python scalar. Standard deviation of the random values to generate.\n        a (float): A python scalar. The minimum cutoff value.\n        b (float): A python scalar. The maximum cutoff value.\n\n    Returns:\n        dict: initializer param of make_table_options\n    \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> initializer = flow.one_embedding.make_trunc_normal_initializer(mean=0, std=0.01, a=-0.02, b=0.02)\n        >>> # pass the initializer to flow.one_embedding.make_table_options\n        >>> # ...\n    \"\"\"\n    return {\"type\": \"trunc_normal\", \"mean\": mean, \"std\": std, \"a\": a, \"b\": b}\n\n\ndef make_table_options(param):\n    \"\"\"make table param of Embedding tables\n\n    Args:\n        param (dict or list): param can be an initializer or a list of column options. An initializer can be made by make_uniform_initializer, make_normal_initializer, or make_constant_initializer; column options can be made by make_column_options\n\n    Returns:\n        dict: table param of Embedding tables\n    \n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> initializer = flow.one_embedding.make_uniform_initializer(low=-scale, high=scale)\n        >>> table1 = flow.one_embedding.make_table_options(initializer)\n        >>> table2 = flow.one_embedding.make_table_options(initializer)\n        >>> tables = [table1, table2]\n        >>> # pass the tables to the \"tables\" param of flow.one_embedding.MultiTableEmbedding or flow.one_embedding.MultiTableMultiColumnEmbedding\n        >>> # ...\n        \n    \"\"\"\n    if isinstance(param, dict):\n        table = {\"initializer\": param}\n    elif isinstance(param, (list, tuple)):\n        table = {\"columns\": param}\n    else:\n        raise ValueError(\"param must be initializer or columns\")\n    return table\n\n\ndef make_column_options(initializer):\n    return {\"initializer\": initializer}\n\n\ndef make_table(param):\n    \"\"\"alias of `oneflow.one_embedding.make_table_options`\n\n    See also :func:`oneflow.one_embedding.make_table_options`\n    \"\"\"\n    return make_table_options(param)\n\n\nclass MultiTableEmbedding(Embedding):\n    r\"\"\"MultiTableEmbedding represents multiple embedding tables that share the same embedding_dim, dtype, and key_type.\n\n    Args:\n        name (str): The name of Embedding\n        embedding_dim (int): the size of each embedding vector\n        dtype (flow.dtype): the data type of embeddings\n        key_type (flow.dtype): the data type of feature ids\n        tables (list): list of table params, which can be made by flow.one_embedding.make_table_options\n        store_options (dict): store option of Embedding\n        default_initializer (dict, optional): if the tables param is None, use default_initializer to initialize the table. 
Defaults to None.\n        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;\n                                     therefore, the embedding vector at :attr:`padding_idx` is not updated during training, and\n                                     the embedding vector at :attr:`padding_idx` defaults to all zeros.\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> import oneflow.nn as nn\n        >>> # a simple example with 3 tables\n        >>> table_size_array = [39884407, 39043, 17289]\n        >>> vocab_size = sum(table_size_array)\n        >>> num_tables = len(table_size_array)\n        >>> embedding_size = 128\n        >>> scales = np.sqrt(1 / np.array(table_size_array))\n        >>> tables = [\n        >>>     flow.one_embedding.make_table_options(\n        >>>         flow.one_embedding.make_uniform_initializer(low=-scale, high=scale)\n        >>>     )\n        >>>     for scale in scales\n        >>> ]\n        >>> store_options = flow.one_embedding.make_cached_ssd_store_options(\n        >>>     cache_budget_mb=8192, persistent_path=\"/your_path_to_ssd\", capacity=vocab_size,\n        >>> )\n        >>> embedding = flow.one_embedding.MultiTableEmbedding(\n        >>>     name=\"my_embedding\",\n        >>>     embedding_dim=embedding_size,\n        >>>     dtype=flow.float,\n        >>>     key_type=flow.int64,\n        >>>     tables=tables,\n        >>>     store_options=store_options,\n        >>> )\n        >>> embedding.to(\"cuda\")\n        >>> mlp = flow.nn.FusedMLP(\n        >>>     in_features=embedding_size * num_tables,\n        >>>     hidden_features=[512, 256, 128],\n        >>>     out_features=1,\n        >>>     skip_final_activation=True,\n        >>> )\n        >>> mlp.to(\"cuda\")\n        >>>\n        >>> class TrainGraph(flow.nn.Graph):\n        >>>     def __init__(self,):\n        >>>         super().__init__()\n        >>>         self.embedding_lookup = embedding\n        >>>         self.mlp = mlp\n        >>>         self.add_optimizer(\n        >>>             flow.optim.SGD(self.embedding_lookup.parameters(), lr=0.1, momentum=0.0)\n        >>>         )\n        >>>         self.add_optimizer(\n        >>>             flow.optim.SGD(self.mlp.parameters(), lr=0.1, momentum=0.0)\n        >>>         ) \n        >>>     def build(self, ids):\n        >>>         embedding = self.embedding_lookup(ids)\n        >>>         loss = self.mlp(flow.reshape(embedding, (-1, num_tables * embedding_size)))\n        >>>         loss = loss.sum()\n        >>>         loss.backward()\n        >>>         return loss \n        >>> ids = np.random.randint(0, 1000, (100, num_tables), dtype=np.int64)\n        >>> ids_tensor = flow.tensor(ids, requires_grad=False).to(\"cuda\")\n        >>> graph = TrainGraph()\n        >>> loss = graph(ids_tensor)\n        >>> print(loss)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        name,\n        embedding_dim,\n        dtype,\n        key_type,\n        tables,\n        store_options,\n        default_initializer=None,\n        padding_idx=None,\n        seed=0,\n    ):\n        assert isinstance(embedding_dim, int)\n        super().__init__(\n            name,\n            embedding_dim,\n            dtype,\n            key_type,\n            tables,\n            store_options,\n            default_initializer,\n            padding_idx,\n            seed,\n        )\n\n\nclass 
MultiTableMultiColumnEmbedding(Embedding):\n    r\"\"\"MultiTableMultiColumnEmbedding represents multiple embedding tables that may have different embedding_dim values but share the same dtype and key_type.\n\n    Args:\n        name (str): The name of Embedding\n        embedding_dim (list): list of the sizes of the embedding vectors\n        dtype (flow.dtype): the data type of embeddings\n        key_type (flow.dtype): the data type of feature ids\n        tables (list): list of table params, each of which can be made by flow.one_embedding.make_table_options\n        store_options (dict): store option of Embedding\n        default_initializer (dict, optional): if the tables param is None, use default_initializer to initialize the tables. Defaults to None.\n        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;\n                                     therefore, the embedding vector at :attr:`padding_idx` is not updated during training\n                                     and defaults to all zeros.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> import oneflow as flow\n        >>> import numpy as np\n        >>> import oneflow.nn as nn\n        >>> # a simple example with 3 tables; every table has two columns, where the\n        >>> # first column's embedding_size is 10 and the second's is 1.\n        >>> # every table's first column is initialized with uniform(-1/sqrt(table_size), 1/sqrt(table_size)),\n        >>> # and the second column with normal(0, 1/sqrt(table_size))\n        >>> table_size_array = [39884407, 39043, 17289]\n        >>> vocab_size = sum(table_size_array)\n        >>> num_tables = len(table_size_array)\n        >>> embedding_size_list = [10, 1]\n        >>> scales = np.sqrt(1 / np.array(table_size_array))\n        >>> tables = [\n        >>>     flow.one_embedding.make_table_options(\n        >>>       [flow.one_embedding.make_column_options(\n        >>>         flow.one_embedding.make_uniform_initializer(low=-scale, high=scale)),\n        >>>        flow.one_embedding.make_column_options(\n        >>>         flow.one_embedding.make_normal_initializer(mean=0, std=scale))]\n        >>>     )\n        >>>     for scale in scales\n        >>> ]\n        >>> store_options = flow.one_embedding.make_cached_ssd_store_options(\n        >>>     cache_budget_mb=8192, persistent_path=\"/your_path_to_ssd\", capacity=vocab_size,\n        >>> )\n        >>> embedding = flow.one_embedding.MultiTableMultiColumnEmbedding(\n        >>>     name=\"my_embedding\",\n        >>>     embedding_dim=embedding_size_list,\n        >>>     dtype=flow.float,\n        >>>     key_type=flow.int64,\n        >>>     tables=tables,\n        >>>     store_options=store_options,\n        >>> )\n        >>> embedding.to(\"cuda\")\n        >>> mlp = flow.nn.FusedMLP(\n        >>>     in_features=sum(embedding_size_list) * num_tables,\n        >>>     hidden_features=[512, 256, 128],\n        >>>     out_features=1,\n        >>>     skip_final_activation=True,\n        >>> )\n        >>> mlp.to(\"cuda\")\n        >>>\n        >>> class TrainGraph(flow.nn.Graph):\n        >>>     def __init__(self,):\n        >>>         super().__init__()\n        >>>         self.embedding_lookup = embedding\n        >>>         self.mlp = mlp\n        >>>         self.add_optimizer(\n        >>>             flow.optim.SGD(self.embedding_lookup.parameters(), lr=0.1, momentum=0.0)\n        >>>         )\n        >>>         self.add_optimizer(\n        >>>             
flow.optim.SGD(self.mlp.parameters(), lr=0.1, momentum=0.0)\n        >>>         )\n        >>>     def build(self, ids):\n        >>>         embedding = self.embedding_lookup(ids)\n        >>>         loss = self.mlp(flow.reshape(embedding, (-1, num_tables * sum(embedding_size_list))))\n        >>>         loss = loss.sum()\n        >>>         loss.backward()\n        >>>         return loss\n        >>> ids = np.random.randint(0, 1000, (100, num_tables), dtype=np.int64)\n        >>> ids_tensor = flow.tensor(ids, requires_grad=False).to(\"cuda\")\n        >>> graph = TrainGraph()\n        >>> loss = graph(ids_tensor)\n        >>> print(loss)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        name,\n        embedding_dim,\n        dtype,\n        key_type,\n        tables,\n        store_options,\n        default_initializer=None,\n        padding_idx=None,\n        seed=0,\n    ):\n        if isinstance(embedding_dim, (list, tuple)):\n            for dim in embedding_dim:\n                assert isinstance(dim, int)\n        else:\n            assert isinstance(embedding_dim, int)\n\n        super().__init__(\n            name,\n            embedding_dim,\n            dtype,\n            key_type,\n            tables,\n            store_options,\n            default_initializer,\n            padding_idx,\n            seed,\n        )\n\n\nclass Ftrl(Optimizer):\n    r\"\"\"FTRL Optimizer.\n\n    The formula is:\n\n        .. math::\n                \\begin{align}\n                accumulator_{i+1} &= accumulator_{i} + grad * grad \\\\\n                sigma &= (accumulator_{i+1}^{lr\\_power} - accumulator_{i}^{lr\\_power}) / learning\\_rate \\\\\n                z_{i+1} &= z_{i} + grad - sigma * param_{i} \\\\\n                param_{i+1} &= \\begin{cases}\n                    0 & \\text{ if } |z_{i+1}| < \\lambda_1 \\\\\n                    -\\frac{z_{i+1} - sign(z_{i+1})*\\lambda_1}{\\frac{\\beta + accumulator_{i+1}^{lr\\_power}}{learning\\_rate} + \\lambda_2} & \\text{ otherwise } \\\\\n                \\end{cases}\n                \\end{align}\n\n    Example:\n\n    .. code-block:: python\n\n        # Assume net is a custom model.\n        ftrl = flow.one_embedding.Ftrl(net.parameters(), lr=1e-3)\n\n        for epoch in range(epochs):\n            # Read data, compute the loss and so on.\n            # ...\n            loss.backward()\n            ftrl.step()\n            ftrl.zero_grad()\n\n    Args:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate. Defaults to 1e-3.\n        weight_decay (float, optional): weight decay (L2 penalty). Defaults to 0.0.\n        lr_power (float, optional): learning rate decrease factor. Defaults to -0.5.\n        initial_accumulator_value (float, optional): The initial value of the accumulator. Defaults to 0.1.\n        lambda1 (float, optional): L1 regularization strength. Defaults to 0.0.\n        lambda2 (float, optional): L2 regularization strength. Defaults to 0.0.\n        beta (float, optional): The value of beta. 
Defaults to 0.0.\n    \"\"\"\n\n    def __init__(\n        self,\n        params: Union[Iterator[Parameter], List[Dict]],\n        lr: float = 0.001,\n        weight_decay: float = 0.0,\n        lr_power: float = -0.5,\n        initial_accumulator_value: float = 0.1,\n        lambda1: float = 0.0,\n        lambda2: float = 0.0,\n        beta: float = 0.0,\n    ):\n        assert lr >= 0.0, f\"Invalid learning rate: {lr}\"\n        assert weight_decay >= 0.0, f\"Invalid weight_decay value: {weight_decay}\"\n        options = dict()\n        options[\"lr\"] = lr\n        options[\"weight_decay\"] = weight_decay\n        options[\"lr_power\"] = lr_power\n        options[\"initial_accumulator_value\"] = initial_accumulator_value\n        options[\"lambda1\"] = lambda1\n        options[\"lambda2\"] = lambda2\n        options[\"beta\"] = beta\n        super().__init__(params, options)\n        for param_group in self.param_groups:\n            for param in param_group.parameters:\n                assert param.is_leaf, \"parameters must be leaf tensor\"\n                self.state[param] = dict()\n                self.state[param][\"accumulator_value\"] = flow.zeros_like(param).fill_(\n                    param_group[\"initial_accumulator_value\"]\n                )\n\n        self._op = (\n            flow.stateful_op(\"ftrl_update\")\n            .Input(\"model\")\n            .Input(\"model_diff\")\n            .Input(\"accumulate\")\n            .Input(\"z\")\n            .Build()\n        )\n\n    def step(self, closure: Callable = None):\n        \"\"\"Performs a single optimization step.\n        \n        Args:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        with flow.no_grad():\n            loss = None\n            if closure is not None:\n                loss = closure()\n\n            for param_group in self.param_groups:\n                kwargs = {\n                    \"learning_rate\": param_group[\"lr\"],\n                    \"l2\": param_group[\"weight_decay\"],\n                    \"lr_power\": param_group[\"lr_power\"],\n                    \"lambda1\": param_group[\"lambda1\"],\n                    \"lambda2\": param_group[\"lambda2\"],\n                    \"beta\": param_group[\"beta\"],\n                }\n                for param in param_group.parameters:\n                    if param.grad is None:\n                        continue\n                    if \"z\" not in self.state[param]:\n                        self.state[param][\"z\"] = flow.zeros_like(param)\n\n                    accumulate_tensor = self.state[param][\"accumulator_value\"]\n                    z_tensor = self.state[param][\"z\"]\n\n                    flow._C.dispatch_ftrl_update(\n                        self._op,\n                        (param, param.grad, accumulate_tensor, z_tensor),\n                        **kwargs,\n                    )\n\n            return loss\n\n    def _generate_conf_for_graph(self, train_conf, vars_conf):\n        new_opt_confs = []\n        for param_group in self.param_groups:\n            optimizer_conf = train_conf.optimizer_conf.add()\n\n            lr = (\n                param_group[\"initial_lr\"]\n                if \"initial_lr\" in param_group\n                else param_group[\"lr\"]\n            )\n\n            l2 = param_group[\"weight_decay\"]\n            initial_accumulator_value = param_group[\"initial_accumulator_value\"]\n            lr_power = 
param_group[\"lr_power\"]\n            lambda1 = param_group[\"lambda1\"]\n            lambda2 = param_group[\"lambda2\"]\n            beta = param_group[\"beta\"]\n\n            optimizer_conf.base_learning_rate = lr\n            self._generate_lr_scale_for_optim_conf(param_group, optimizer_conf)\n\n            optimizer_conf.ftrl_conf.initial_accumulator_value = (\n                initial_accumulator_value\n            )\n            optimizer_conf.ftrl_conf.lr_power = lr_power\n            optimizer_conf.ftrl_conf.lambda1 = lambda1\n            optimizer_conf.ftrl_conf.lambda2 = lambda2\n            optimizer_conf.ftrl_conf.beta = beta\n\n            self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf)\n\n            for param in param_group.parameters:\n                vars_conf[param].l2 = l2\n                if param.requires_grad:\n                    optimizer_conf.variable_op_names.append(vars_conf[param].name)\n\n            new_opt_confs.append(optimizer_conf)\n        return new_opt_confs\n\n    @property\n    def support_sparse(self):\n        return False\n\n\ndef make_persistent_table_reader(\n    paths, snapshot_name, key_type, value_type, storage_dim, physical_block_size=4096,\n):\n    r\"\"\"Creates a reader for reading persistent table.\n\n    Args:\n        paths (list): paths of tables to read\n        snapshot_name (str): name of the snapshot to read\n        key_type (flow.dtype): the data type of key\n        value_type (flow.dtype): the data type of value\n        storage_dim (int): number of elements in each value\n        physical_block_size (int, optional): physical_block_size should be sector size. Defaults to 4096\n    \"\"\"\n    return PersistentTableReader(\n        paths,\n        snapshot_name,\n        key_type,\n        value_type,\n        storage_dim,\n        4 * 1024,\n        physical_block_size,\n    )\n\n\ndef make_persistent_table_writer(\n    paths, snapshot_name, key_type, value_type, storage_dim, physical_block_size=4096,\n):\n    r\"\"\"Creates a writer for writing persistent table.\n\n    Args:\n        paths (list): paths of tables to write\n        snapshot_name (str): name of the snapshot to write\n        key_type (flow.dtype): the data type of key\n        value_type (flow.dtype): the data type of value\n        storage_dim (int): number of elements in each value\n        physical_block_size (int, optional): physical_block_size should be sector size. 
Defaults to 4096.\n    \"\"\"\n    return PersistentTableWriter(\n        paths,\n        snapshot_name,\n        key_type,\n        value_type,\n        storage_dim,\n        4 * 1024,\n        physical_block_size,\n    )\n\n\nclass SmartDecayAdam(flow.nn.optimizer.adam.Adam):\n    \"\"\"Implements the SmartDecayAdam algorithm.\n       The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.\n       For sparse embedding tables in OneEmbedding, it applies the SmartDecayAdam update;\n       for other models, it is the same as Adam.\n    \"\"\"\n\n    def _generate_conf_for_graph(self, train_conf, vars_conf):\n        new_opt_confs = super()._generate_conf_for_graph(train_conf, vars_conf)\n        for opt_conf in new_opt_confs:\n            opt_conf.adam_conf.smart_decay = True\n        return new_opt_confs\n\n\nclass Optimizer(Optimizer):\n    def __init__(\n        self, optimizer: Optimizer, embeddings: List[Embedding],\n    ):\n        self.optimizer = optimizer\n        self.embeddings = embeddings\n        self.param_groups = optimizer.param_groups\n        # self._default_options = optimizer._default_options\n        # self._state = optimizer._state\n        self.defaults = optimizer.defaults\n        self.state = optimizer.state\n        self.embedding_param_group_dict = {}\n        for embedding in self.embeddings:\n            for group in self.param_groups:\n                param_set = set()\n                for param in group.parameters:\n                    param_set.add(param)\n                if embedding.shadow in param_set:\n                    self.embedding_param_group_dict[embedding.embedding_name] = group\n            if embedding.embedding_name not in self.embedding_param_group_dict:\n                raise ValueError(\"embedding must be in the optimizer's param_groups\")\n\n    def step(self, closure: Callable = None):\n        step = self.optimizer.state[\"step\"]\n        for embedding in self.embeddings:\n            param_group = self.embedding_param_group_dict[embedding.embedding_name]\n            if type(self.optimizer) is flow.nn.optimizer.sgd.SGD:\n                embedding.sgd_update(param_group, step)\n            elif type(self.optimizer) is flow.nn.optimizer.adam.Adam:\n                embedding.adam_update(param_group, step)\n            elif type(self.optimizer) is flow.nn.optimizer.adagrad.Adagrad:\n                embedding.adagrad_update(param_group, step)\n            elif type(self.optimizer) is flow.one_embedding.Ftrl:\n                embedding.ftrl_update(param_group, step)\n            else:\n                raise NotImplementedError(\"only SGD, Adam, Adagrad and Ftrl are supported\")\n        self.optimizer.step()\n\n    def _generate_conf_for_graph(self, train_conf, vars_conf):\n        return self.optimizer._generate_conf_for_graph(train_conf, vars_conf)\n
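\n\n# Usage sketch (annotation added in this dump; the path, snapshot name, dtypes\n# and storage_dim below are placeholders, not values mandated by the API):\n#\n#     reader = make_persistent_table_reader(\n#         [\"/your_path_to_table\"], \"snapshot_0\", flow.int64, flow.float, 128\n#     )\n"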
  },
  {
    "path": "python/oneflow/onnx/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport warnings\n\n\ndef symbolic_opset11():\n    warnings.warn(\n        \"The oneflow.onnx.symbolic_opset11 interface is just to align the torch.onnx.symbolic_opset11 interface and has no practical significance.\"\n    )\n\n\ndef register_custom_op_symbolic(*args, **kwargs):\n    warnings.warn(\n        \"The oneflow.onnx.register_custom_op_symbolic interface is just to align the torch.onnx.register_custom_op_symbolic interface and has no practical significance.\"\n    )\n"
  },
  {
    "path": "python/oneflow/onnx/symbolic_helper.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport warnings\n\n\ndef parse_args(*args, **kwargs):\n    warnings.warn(\n        \"The oneflow.onnx.parse_args interface is just to align the torch.onnx.parse_args interface and has no practical significance.\"\n    )\n\n    def func(fn):\n        return fn\n\n    return func\n"
  },
  {
    "path": "python/oneflow/ops/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n\ndef load_library(path):\n    raise ImportError(\"load_library is not implemented\")\n"
  },
  {
    "path": "python/oneflow/ops/array_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n\ndef parse_slice_tuple_list(slice_tup_list, shape):\n    ndim = len(shape)\n    if not isinstance(slice_tup_list, (list, tuple)) or len(slice_tup_list) > ndim:\n        raise ValueError(\n            \"slice_tup_list must be a list or tuple with length less than or equal \"\n            \"to number of dimensions of input tensor\"\n        )\n\n    if len(slice_tup_list) < ndim:\n        supple_ndim = ndim - len(slice_tup_list)\n        slice_tup_list += type(slice_tup_list)([(None, None, None)] * supple_ndim)\n\n    start_list, stop_list, step_list = [], [], []\n    for (slice_tup, dim) in zip(slice_tup_list, shape):\n        if not isinstance(slice_tup, (tuple, list)) or len(slice_tup) != 3:\n            raise ValueError(\n                \"element of slice_tup_list must be a list or tuple with form (start, stop, step)\"\n            )\n\n        if not all((isinstance(elem, int) or elem is None for elem in slice_tup)):\n            raise ValueError(\"element of slice tuple must int or None\")\n\n        (start, stop, step) = slice_tup\n\n        if step is None:\n            step = 1\n\n        if step == 0:\n            raise ValueError(\"slice step can't be 0\")\n\n        if start is None:\n            start = 0 if step > 0 else dim\n\n        if stop is None:\n            stop = dim if step > 0 else -dim - 1\n\n        # start range is [-dim, dim-1]\n        start = max(min(start, dim - 1), -dim)\n        # stop range is [-dim-1, dim]\n        stop = max(min(stop, dim), -dim - 1)\n\n        reg_start = start if start >= 0 else start + dim\n        reg_stop = stop if stop >= 0 else stop + dim\n\n        if step > 0 and reg_stop < reg_start:\n            stop = start\n\n        if step < 0 and reg_start < reg_stop:\n            stop = start\n\n        start_list.append(start)\n        stop_list.append(stop)\n        step_list.append(step)\n\n    return start_list, stop_list, step_list\n"
  },
  {
    "path": "python/oneflow/ops/stateful_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nimport oneflow._oneflow_internal\nimport oneflow.framework.id_util as id_util\n\n\nclass StatefulOp(object):\n    def __init__(self, op_type_name, op_name=None):\n        if op_name is None:\n            op_name = id_util.UniqueStr(op_type_name)\n        self._builder = oneflow._oneflow_internal.one.OpBuilder(op_type_name, op_name)\n        self._op = None\n        self._op_type_name = op_type_name\n\n    @property\n    def op(self):\n        \"\"\"access the builtin op\n\n        Returns:\n            the builtin op\n        \"\"\"\n        if self._op is None:\n            self._op = self._builder.build()\n        return self._op\n\n    def Input(self, input_name, num=1):\n        \"\"\"Set input blob of op\n\n        Args:\n            input_name (str): input name of blob\n            num (int, optional) : Defaults to 1.\n\n        Returns:\n            self\n        \"\"\"\n        assert isinstance(num, int) and num >= 1\n        self._builder.input(input_name, num)\n        return self\n\n    def Output(self, output_name, num=1):\n        \"\"\"Set output blob of op\n\n        Args:\n            output_name (str): name of output blob\n            num (int, optional):  Defaults to 1.\n\n        Returns:\n            self\n        \"\"\"\n        assert isinstance(num, int) and num >= 1\n        self._builder.output(output_name, num)\n        return self\n\n    def Build(self):\n        \"\"\"Explicitly complete the construction of the builtin op\n\n        Returns:\n            the completed builtin op\n        \"\"\"\n        if self._op is None:\n            self._op = self._builder.build()\n        return self._op\n"
  },
  {
    "path": "python/oneflow/ops/transpose_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Sequence\n\n\ndef is_perm(perm: Sequence[int]) -> bool:\n    return list(range(len(perm))) == sorted(list(perm))\n\n\ndef get_perm_when_transpose_axis_to_last_dim(num_axes: int, axis: int) -> tuple:\n    axis = axis if axis >= 0 else axis + num_axes\n    assert 0 <= axis < num_axes, \"axis out of range\"\n    perm = [dim if dim < axis else dim + 1 for dim in range(num_axes - 1)]\n    perm.append(axis)\n    return tuple(perm)\n\n\ndef get_inversed_perm(perm: Sequence[int]) -> tuple:\n    assert is_perm(perm)\n    inversed_perm = [-1] * len(perm)\n    for i in range(len(perm)):\n        inversed_perm[perm[i]] = i\n    return tuple(inversed_perm)\n"
  },
  {
    "path": "python/oneflow/ops/util/__init__.py",
    "content": ""
  },
  {
    "path": "python/oneflow/ops/util/initializer_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport functools\nimport math\nfrom typing import Optional, Sequence, Union\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.core.job.initializer_conf_pb2 as initializer_conf_util\nimport oneflow.core.operator.op_conf_pb2 as op_conf_util\nimport oneflow.framework.dtype as dtype_util\n\n\ndef get_random_distribution(distribution):\n    if distribution.lower() == \"truncated_normal\":\n        return initializer_conf_util.kTruncatedNormal\n    elif distribution.lower() == \"random_normal\":\n        return initializer_conf_util.kRandomNormal\n    elif distribution.lower() == \"random_uniform\":\n        return initializer_conf_util.kRandomUniform\n    else:\n        raise ValueError(\"Invalid random_distribution\")\n\n\ndef get_data_format(data_format):\n    assert isinstance(data_format, str), \"data_format must be a string\"\n    if data_format.startswith(\"NC\"):\n        return \"channels_first\"\n    elif data_format.startswith(\"N\") and data_format.endswith(\"C\"):\n        return \"channels_last\"\n    else:\n        assert data_format == \"\", ValueError(\n            'data_format must be \"N...C\" or \"NC...\" or \"\"'\n        )\n        return \"\"\n\n\ndef calc_fan(shape, mode, data_format):\n    assert (\n        len(shape) >= 2\n    ), \"Fan in and fan out can out be computed for tensor with fewer 2 dimensions\"\n    if len(shape) == 2:\n        fan_in = shape[1]\n        fan_out = shape[0]\n    else:\n        fan_in = 1.0\n        for dim in shape[1:]:\n            fan_in *= dim\n        fan_out = shape[0]\n        if data_format == \"channels_first\":\n            for dim in shape[2:]:\n                fan_out *= dim\n        elif data_format == \"channels_last\":\n            for dim in shape[1:-1]:\n                fan_out *= dim\n        else:\n            raise NotImplementedError(\n                \"Only support 'channels_first' and 'channels_last' data format\"\n            )\n    if mode == \"fan_sum\":\n        return float(fan_in) + float(fan_out)\n    elif mode == \"fan_in\":\n        return float(fan_in)\n    elif mode == \"fan_out\":\n        return float(fan_out)\n    else:\n        raise NotImplementedError(\"Only support 'fan_in', 'fan_out' and 'fan_sum' mode\")\n\n\ndef calc_gain(nonlinearity, param=None):\n    linear_fns = [\n        \"linear\",\n        \"conv1d\",\n        \"conv2d\",\n        \"conv3d\",\n        \"conv_transpose1d\",\n        \"conv_transpose2d\",\n        \"conv_transpose3d\",\n    ]\n    if nonlinearity in linear_fns or nonlinearity == \"sigmoid\":\n        return 1\n    elif nonlinearity == \"tanh\":\n        return 5.0 / 3\n    elif nonlinearity == \"relu\":\n        return math.sqrt(2.0)\n    elif nonlinearity == \"leaky_relu\":\n        if param is None:\n            negative_slope = 0.01\n        elif (\n            not isinstance(param, bool)\n            and isinstance(param, int)\n            
or isinstance(param, float)\n        ):\n            negative_slope = param\n        else:\n            raise ValueError(\"negative_slope {} not a valid number\".format(param))\n        return math.sqrt(2.0 / (1 + negative_slope ** 2))\n    elif nonlinearity == \"selu\":\n        return 3.0 / 4\n    else:\n        raise ValueError(\"Unsupported nonlinearity {}\".format(nonlinearity))\n"
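\n\n# Worked example (annotation added in this dump): calc_gain(\"leaky_relu\", 0.2)\n# returns math.sqrt(2.0 / (1 + 0.2 ** 2)) ~= 1.3868, the recommended gain for a\n# leaky ReLU with negative_slope 0.2.\n"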
  },
  {
    "path": "python/oneflow/optim/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.nn.optimizer.adam import Adam\nfrom oneflow.nn.optimizer.adamw import AdamW\nfrom oneflow.optim.optimizer import Optimizer\nfrom oneflow.nn.optimizer.rmsprop import RMSprop\nfrom oneflow.nn.optimizer.sgd import SGD\nfrom oneflow.nn.optimizer.adagrad import Adagrad\nfrom oneflow.nn.optimizer.lamb import LAMB\nfrom oneflow.nn.optimizer.adadelta import Adadelta\nfrom oneflow.nn.optimizer.lbfgs import LBFGS\n\nfrom . import lr_scheduler\nfrom . import swa_utils\n"
  },
  {
    "path": "python/oneflow/optim/lr_scheduler.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.nn.optimizer.lr_scheduler import LRScheduler as _LRScheduler\nfrom oneflow.nn.optimizer.cosine_decay_lr import CosineDecayLR\nfrom oneflow.nn.optimizer.cosine_annealing_lr import CosineAnnealingLR\nfrom oneflow.nn.optimizer.lambda_lr import LambdaLR\nfrom oneflow.nn.optimizer.step_lr import StepLR\nfrom oneflow.nn.optimizer.multistep_lr import MultiStepLR\nfrom oneflow.nn.optimizer.exponential_lr import ExponentialLR\nfrom oneflow.nn.optimizer.reduce_lr_on_plateau import ReduceLROnPlateau\nfrom oneflow.nn.optimizer.polynomial_lr import PolynomialLR\nfrom oneflow.nn.optimizer.constant_lr import ConstantLR\nfrom oneflow.nn.optimizer.linear_lr import LinearLR\nfrom oneflow.nn.optimizer.warmup_lr import WarmupLR\nfrom oneflow.nn.optimizer.warmup_lr import WarmupLR as WarmUpLR\nfrom oneflow.nn.optimizer.cosine_annealing_warm_restarts import (\n    CosineAnnealingWarmRestarts,\n)\nfrom oneflow.nn.optimizer.chained_scheduler import ChainedScheduler\nfrom oneflow.nn.optimizer.sequential_lr import SequentialLR\nfrom oneflow.nn.optimizer.multiplicative_lr import MultiplicativeLR\n"
  },
  {
    "path": "python/oneflow/optim/optimizer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport collections\nimport warnings\nfrom copy import deepcopy\nfrom itertools import chain\nfrom typing import Any, Callable, Dict, Union\n\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.nn.graph.proxy import ProxyTensor\nfrom oneflow.nn.parameter import Parameter\nfrom oneflow.nn.utils.clip_grad import clip_grad_norm_\nfrom oneflow.nn.utils.parameters_grouping import ContiguousParamsGroup\nimport oneflow as flow\nfrom collections import defaultdict, abc as container_abcs\n\n\nclass ParamGroup(dict):\n    def __init__(\n        self, parameters: Dict[str, Any], default_options: Dict,\n    ):\n        # ParamGroup must be constructed by Dict[\"params\": parameters: List[Parameter, Tensor or ProxyTensor], \"...\": ...]\n        assert isinstance(parameters, dict) and \"params\" in parameters\n        assert not isinstance(parameters[\"params\"], (Parameter, Tensor))\n        self._parameters = list()\n\n        for p in parameters[\"params\"]:\n            if isinstance(p, (Parameter, Tensor)):\n                self._parameters.append(p)\n            elif isinstance(p, ProxyTensor):\n                # Add parameter from nn.Graph\n                self._parameters.append(p.to(Tensor))\n            else:\n                raise ValueError(\n                    \"parameters in ParamGroup must be Tensor or ProxyTensor.\"\n                )\n\n        self._options = deepcopy(default_options)\n        # rewrite options in default_options\n        for key in self._options:\n            if key in parameters:\n                self._options[key] = parameters[key]\n        # add excess keys in dict\n        for key in parameters:\n            if key not in self._options and key != \"params\":\n                self._options[key] = parameters[key]\n\n        self._enable_clip_grad = False\n        if \"clip_grad_max_norm\" in parameters and \"clip_grad_norm_type\" in parameters:\n            self._enable_clip_grad = True\n            self._options[\"clip_grad_max_norm\"] = parameters[\"clip_grad_max_norm\"]\n            self._options[\"clip_grad_norm_type\"] = parameters[\"clip_grad_norm_type\"]\n\n        self._make_options_valid()\n        self.contiguous_params = self._options.get(\"contiguous_params\", False)\n        if self.contiguous_params:\n            self.params_group = ContiguousParamsGroup([parameters[\"params\"]])\n\n        super().__init__(**self._options, params=self._parameters)\n        super().setdefault(\"contiguous_params\", False)\n        super().setdefault(\"_enable_clip_grad\", self._enable_clip_grad)\n\n    def _make_options_valid(self):\n        \"\"\"handle the conflict between optimizer options\n        \"\"\"\n        if self._options.get(\"contiguous_params\", False) and self._options.get(\n            \"fused\", False\n        ):\n            self._options[\"fused\"] = False\n\n            warnings.warn(\n                \"do not set 
\"fused has been disabled and only contiguous_params takes effect.\"\n            )\n\n    @property\n    def parameters(self):\n        return self._parameters\n\n    @property\n    def contiguous_parameters(self):\n        \"\"\"Return the contiguous parameters for fast updating\n        \"\"\"\n        return self.params_group.grouped_parameters\n\n\nclass _SourceOpOnlyResourceDependenceMode:\n    def __init__(self):\n        self.guard = None\n\n    def __enter__(self):\n        self.guard = (\n            flow._oneflow_internal.eager.SourceOpOnlyResourceDependenceModeGuard()\n        )\n\n    def __exit__(self, *args, **kwargs):\n        del self.guard\n\n\ndef _decorate_step(step):\n    def decorated_step(*args, **kwargs):\n        with _SourceOpOnlyResourceDependenceMode():\n            return step(*args, **kwargs)\n\n    return decorated_step\n\n\nclass _RequiredParameter(object):\n    \"\"\"Singleton class representing a required parameter for an Optimizer.\"\"\"\n\n    def __repr__(self):\n        return \"<required parameter>\"\n\n\nrequired = _RequiredParameter()\n\n\nclass Optimizer(object):\n    def __init__(self, parameters, options):\n        self.param_groups = list()\n        self.state = defaultdict(dict)\n        self.defaults = options\n        self.state[\"step\"] = 0\n\n        self._parse_input_parameters(parameters)\n\n        all_remat = all(\n            p.is_local and p.device.rematable\n            for pg in self.param_groups\n            for p in pg.parameters\n        )\n        all_not_remat = all(\n            not p.is_local or not p.device.rematable\n            for pg in self.param_groups\n            for p in pg.parameters\n        )\n        if not all_remat and not all_not_remat:\n            raise ValueError(\n                \"Parameters should be all on rematable device or all on non-rematable device.\"\n            )\n\n        if all_not_remat:\n            # _decorate_step makes the mutable update interleave with backward\n            # computation, producing wrong results in DTR if the original\n            # weight is used to recompute other tensors.\n            # Besides, it makes parameters remain in memory for unknown reasons\n            # even after the parameters and optimizer are no longer held by the\n            # Python interpreter.\n            self.step = _decorate_step(self.step)\n        self._state_not_saved = [\n            \"params_group\",\n            \"_parameters\",\n        ]\n\n    def add_param_group(self, param_group) -> None:\n        r\"\"\"\n\n        Add a param group to the :class:`Optimizer`'s `param_groups`.\n        This can be useful when fine tuning a pre-trained network as frozen layers can be made\n        trainable and added to the :class:`Optimizer` as training progresses.\n\n        Args:\n            param_group (dict): Specifies what Tensors should be optimized along with group\n                specific optimization options.\n\n        Example:\n\n        >>> import oneflow\n        >>> import oneflow.optim as optim\n        >>> w1 = oneflow.ones(3, 3)\n        >>> w1.requires_grad = True\n        >>> w2 = oneflow.ones(3, 3)\n        >>> w2.requires_grad = True\n        >>> o = optim.SGD([w1])\n        >>> o.param_groups[0]\n        {'lr': 0.001, 'momentum': 0.0, 'dampening': 0.0, 'weight_decay': 0.0, 'nesterov': False, 'maximize': False, 'params': [tensor([[1., 1., 1.],\n                [1., 1., 1.],\n                [1., 1., 1.]], dtype=oneflow.float32, 
requires_grad=True)]}\n        >>> o.add_param_group({'params': w2})\n        >>> o.param_groups[1]\n        {'lr': 0.001, 'momentum': 0.0, 'dampening': 0.0, 'weight_decay': 0.0, 'nesterov': False, 'maximize': False, 'params': [tensor([[1., 1., 1.],\n                [1., 1., 1.],\n                [1., 1., 1.]], dtype=oneflow.float32, requires_grad=True)]}\n\n        \"\"\"\n        assert isinstance(param_group, dict), \"param group must be a dict\"\n\n        params = param_group[\"params\"]\n        if isinstance(params, flow.Tensor):\n            param_group[\"params\"] = [params]\n        elif isinstance(params, set):\n            raise TypeError(\n                \"optimizer parameters need to be organized in ordered collections, but \"\n                \"the ordering of tensors in sets will change between runs. Please use a list instead.\"\n            )\n        else:\n            param_group[\"params\"] = list(params)\n\n        for param in param_group[\"params\"]:\n            if not isinstance(param, flow.Tensor):\n                raise TypeError(\n                    \"optimizer can only optimize Tensors, \"\n                    \"but one of the params is \" + str(type(param))\n                )\n            if not param.is_leaf:\n                raise ValueError(\"can't optimize a non-leaf Tensor\")\n\n        for name, default in self.defaults.items():\n            if default is required and name not in param_group:\n                raise ValueError(\n                    \"parameter group didn't specify a value of required optimization parameter \"\n                    + name\n                )\n            else:\n                param_group.setdefault(name, default)\n        params = param_group[\"params\"]\n        if len(params) != len(set(params)):\n            warnings.warn(\n                \"optimizer contains a parameter group with duplicate parameters; \"\n                \"in the future, this will cause an error\",\n                stacklevel=3,\n            )\n\n        param_set = set()\n        for group in self.param_groups:\n            param_set.update(set(group.parameters))\n\n        if not param_set.isdisjoint(set(param_group[\"params\"])):\n            raise ValueError(\"some parameters appear in more than one parameter group\")\n\n        self.param_groups.append(ParamGroup(param_group, self.defaults))\n\n        for param in param_group[\"params\"]:\n            assert param.is_leaf, \"parameters must be leaf tensor\"\n            self.state[param] = dict()\n\n    def load_state_dict(self, state_dict) -> None:\n        r\"\"\"\n        Load the state of the optimizer which is created by the `state_dict` function.\n\n        It is almost copied from: https://pytorch.org/docs/1.10/_modules/torch/optim/optimizer.html#Optimizer.load_state_dict.\n        \"\"\"\n\n        # Validate the state_dict\n        groups = self.param_groups\n        saved_groups = state_dict[\"param_groups\"]\n\n        if len(groups) != len(saved_groups):\n            raise ValueError(\n                \"loaded state dict has a different number of parameter groups\"\n            )\n\n        for param, saved_param in zip(groups, saved_groups):\n            # the contiguous_params property is retained in state_dict,\n            # so the contiguous_params of the state_dict and the current optimizer should match.\n            if \"contiguous_params\" in param and param[\n                \"contiguous_params\"\n            ] != saved_param.get(\"contiguous_params\", False):\n                raise ValueError(\n
                 \"loaded contiguous_params state doesn't match the optimizer\"\n                )\n\n            if param[\"contiguous_params\"]:\n                param_list = param.contiguous_parameters\n            else:\n                param_list = param.parameters\n\n            if len(param_list) != len(saved_param[\"params\"]):\n                raise ValueError(\n                    \"loaded state dict contains a parameter group \"\n                    \"that doesn't match the size of optimizer's group\"\n                )\n\n        # Update the state\n        id_map = {\n            old_id: p\n            for old_id, p in zip(\n                chain.from_iterable((g[\"params\"] for g in saved_groups)),\n                chain.from_iterable(\n                    (\n                        g.parameters\n                        if not g[\"contiguous_params\"]\n                        else g.contiguous_parameters\n                        for g in groups\n                    )\n                ),\n            )\n        }\n\n        def cast(param, value):\n            r\"\"\"Make a deep copy of value, casting all tensors to device or placement of param.\"\"\"\n            if isinstance(value, Tensor):\n                if value.is_local:\n                    value = value.to(param.device)\n                else:\n                    cpu_value_placement = flow.placement(\"cpu\", value.placement.ranks)\n                    cpu_param_placement = flow.placement(\"cpu\", param.placement.ranks)\n                    value = (\n                        value.to_global(placement=cpu_value_placement)\n                        .to_global(placement=cpu_param_placement, sbp=param.sbp)\n                        .to_global(placement=param.placement)\n                    )\n                return value\n            elif isinstance(value, dict):\n                return {k: cast(param, v) for k, v in value.items()}\n            elif isinstance(value, collections.Iterable):\n                return type(value)(cast(param, v) for v in value)\n            else:\n                return value\n\n        # Copy state assigned to params (and cast tensors to appropriate types).\n        # State that is not assigned to params is copied as is (needed for\n        # backward compatibility).\n        state = dict()\n        for k, v in state_dict[\"state\"].items():\n            if k in id_map:\n                param = id_map[k]\n                state[param] = cast(param, v)\n            else:\n                state[k] = v\n        self.state = state\n\n        # Update parameter groups, setting their 'params' value\n        def update_group(group, new_group):\n            new_group.pop(\"params\")\n            g = deepcopy(new_group)\n            group.update(g)\n            group._enable_clip_grad = g[\"_enable_clip_grad\"]\n            group._options = g\n            return group\n\n        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]\n        self.param_groups = param_groups\n\n    def state_dict(self):\n        r\"\"\"\n        Returns the state of the optimizer as a :py:class:`dict`.\n\n        It contains two entries:\n\n        * state - a dict holding current optimization state. 
Its content\n          differs between optimizer classes.\n        * param_groups - a list containing all parameter groups.\n\n        It is almost copied from: https://pytorch.org/docs/1.10/_modules/torch/optim/optimizer.html#Optimizer.state_dict.\n        \"\"\"\n\n        # Save order indices instead of Tensors\n        param_mappings = {}\n        start_index = 0\n\n        def pack_group(group):\n            if group[\"contiguous_params\"]:\n                param_list = group.contiguous_parameters\n            else:\n                param_list = group.parameters\n\n            nonlocal start_index\n            packed = {k: v for k, v in group.items() if k not in self._state_not_saved}\n            param_mappings.update(\n                {\n                    id(p): i\n                    for i, p in enumerate(param_list, start_index)\n                    if id(p) not in param_mappings\n                }\n            )\n            packed[\"params\"] = [param_mappings[id(p)] for p in param_list]\n            start_index += len(packed[\"params\"])\n            return packed\n\n        param_groups = [pack_group(g) for g in self.param_groups]\n        # Remap state to use order indices as keys\n        packed_state = {\n            (param_mappings[id(k)] if isinstance(k, Tensor) else k): v\n            for k, v in self.state.items()\n        }\n        return {\n            \"state\": packed_state,\n            \"param_groups\": param_groups,\n        }\n\n    def step(self, closure: Union[Callable, None] = None) -> Union[Tensor, None]:\n        \"\"\"Performs a single optimization step (parameter update).\n\n        Args:\n            closure (Union[Callable, None], optional): A closure that reevaluates the model and returns the loss. Optional for most optimizers.\n\n        Returns:\n            Union[Tensor, None]: The loss.\n        \"\"\"\n        raise NotImplementedError()\n\n    def clip_grad(self, error_if_nonfinite: bool = False):\n        r\"\"\"Clips the gradient norm of an iterable of parameters.\n        The norm is computed over all gradients together, as if they were concatenated into a single vector.\n\n        You can set the max_norm and norm_type.\n\n        For more details, you can refer to the documentation of each optimizer (like Adam, SGD and so on).\n\n        You can also refer to the code in :func:`oneflow.nn.utils.clip_grad_norm_`.\n\n        Args:\n            error_if_nonfinite (bool): if True, an error is thrown if the total\n                norm of the gradients from :attr:`parameters` is ``nan``,\n                ``inf``, or ``-inf``. 
Default: False (will switch to True in the future)\n\n        \"\"\"\n        for param_group in self.param_groups:\n            if param_group._enable_clip_grad:\n                clip_grad_norm_(\n                    param_group.parameters,\n                    param_group[\"clip_grad_max_norm\"],\n                    param_group[\"clip_grad_norm_type\"],\n                    error_if_nonfinite,\n                    param_group.get(\"fused\", False),\n                )\n            else:\n                warnings.warn(\n                    \"To enable clip_grad, pass the `clip_grad_max_norm` and `clip_grad_norm_type` parameters when instantiating the Optimizer.\"\n                )\n\n    def zero_grad(self, set_to_none: bool = False):\n        \"\"\"Sets the gradients of all optimized :class:`oneflow.Tensor` s to zero.\n\n        Args:\n            set_to_none (bool): instead of setting to zero, set the grads to None.\n                This will in general have lower memory footprint, and can modestly\n                improve performance. However, it changes certain behaviors.\n\n        For example:\n\n            1. When the user tries to access a gradient and perform manual ops on\n            it, a None attribute or a Tensor full of 0s will behave differently.\n\n            2. If the user requests zero_grad(set_to_none=True) followed by a\n            backward pass, grads are guaranteed to be None for params that did not\n            receive a gradient.\n\n            3. Optimizers have a different behavior if the gradient is 0 or None\n            (in one case it does the step with a gradient of 0 and in the other\n            it skips the step altogether).\n        \"\"\"\n        for param_group in self.param_groups:\n            if param_group[\"contiguous_params\"]:\n                param_list = param_group.contiguous_parameters\n            else:\n                param_list = param_group.parameters\n\n            for param in param_list:\n                param._zero_grad_(set_to_none)\n\n    def _parse_input_parameters(self, parameters):\n        \"\"\"\n        Supports the following parameter formats:\n            1. Iterator: flow.optim.SGD(module.parameters(), lr=0.1)\n            2. List[Dict]: flow.optim.SGD([{\"params\": module1.parameters()}, {\"params\": module2.parameters()}])\n            3. 
List[Parameter or Tensor]: flow.optim.SGD([module.weight, module.bias])\n        \"\"\"\n        if isinstance(parameters, collections.abc.Iterator):\n            # Iterator\n            self.param_groups.append(\n                ParamGroup({\"params\": list(parameters)}, self.defaults)\n            )\n        elif isinstance(parameters, collections.abc.Iterable):\n            # List[Dict]\n            if isinstance(parameters[0], dict):\n                for param in parameters:\n                    assert isinstance(param, dict)\n                    self.param_groups.append(ParamGroup(param, self.defaults))\n            # List[Parameter or Tensor]\n            else:\n                self.param_groups.append(\n                    ParamGroup({\"params\": parameters}, self.defaults)\n                )\n        else:\n            raise TypeError(\n                f\"params argument given to the optimizer should be an iterable of Tensors or dicts, but got {type(parameters)}\"\n            )\n\n    def _generate_grad_clip_conf_for_optim_conf(self, param_group, optimizer_conf):\n        if not param_group._enable_clip_grad:\n            return\n\n        assert \"clip_grad_max_norm\" in param_group\n        assert \"clip_grad_norm_type\" in param_group\n        max_norm = float(param_group[\"clip_grad_max_norm\"])\n        norm_type = float(param_group[\"clip_grad_norm_type\"])\n        clip_grad_norm = optimizer_conf.clip_conf.clip_by_global_norm\n        clip_grad_norm.max_norm = max_norm\n        clip_grad_norm.norm_type = norm_type\n\n    def _generate_lr_scale_for_optim_conf(self, param_group, optimizer_conf):\n        if \"lr_scale\" not in param_group:\n            return\n\n        lr_scale = float(param_group[\"lr_scale\"])\n        optimizer_conf.lr_scale = lr_scale\n\n    @property\n    def support_sparse(self):\n        \"\"\"Whether the Optimizer support sparse update. \n\n        \"\"\"\n        return False\n\n    def _check_variables_in_graph(self, vars_conf):\n        for param_group in self.param_groups:\n            for param in param_group.parameters:\n                if not param.requires_grad:\n                    continue\n\n                if param not in vars_conf:\n                    raise ValueError(\n                        f\"Parameter <{param}> is not in the corresponding nn.Graph/nn.Module.\"\n                        \" Please make sure you call the module's to(..)/to_global(...) 
method first,\"\n                        \" then add the module's parameters into an optimizer.\"\n                    )\n\n    def _check_variables_optimizer_bound(self, vars_conf):\n        for param_group in self.param_groups:\n            for param in param_group.parameters:\n                if not param.requires_grad:\n                    continue\n\n                if vars_conf[param].bound_optimizer is None:\n                    vars_conf[param].bound_optimizer = self\n                elif vars_conf[param].bound_optimizer is not self:\n                    raise ValueError(\n                        f\"<{vars_conf[param].name}> is already bound to another optimizer.\"\n                    )\n\n    def _generate_indexed_slices_optimizer_conf(self, job_conf, vars_conf):\n        if not self.support_sparse:\n            raise ValueError(f\"{self.__class__} does not support sparse updating.\")\n\n        for param_group in self.param_groups:\n            for param in param_group.parameters:\n                if not param.requires_grad:\n                    continue\n\n                sparse_opt_conf = job_conf.indexed_slices_optimizer_conf\n                sparse_variable_op_names = sparse_opt_conf.include_op_names\n                sparse_variable_op_names.op_name.append(vars_conf[param].name)\n"
  },
  {
    "path": "python/oneflow/optim/swa_utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.nn.optimizer.swa_utils import SWALR, update_bn, AveragedModel\n"
  },
  {
    "path": "python/oneflow/profiler/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport oneflow._oneflow_internal\nfrom oneflow.profiler.profiler import (\n    profile,\n    record_function,\n    ProfilerActivity,\n    ProfilerAction,\n    tensorboard_trace_handler,\n)\n\n__all__ = [\n    \"range_push\",\n    \"range_pop\",\n    \"profiler_start\",\n    \"profiler_stop\",\n    \"profile\",\n    \"record_function\",\n    \"ProfilerActivity\",\n    \"kineto_available\",\n    \"tensorboard_trace_handler\",\n    \"ProfilerAction\",\n]\n\n\ndef range_push(range_name):\n    oneflow._oneflow_internal.profiler.RangePush(range_name)\n\n\ndef range_pop():\n    oneflow._oneflow_internal.profiler.RangePop()\n\n\ndef profiler_start():\n    oneflow._oneflow_internal.profiler.ProfilerStart()\n\n\ndef profiler_stop():\n    oneflow._oneflow_internal.profiler.ProfilerStop()\n\n\ndef kineto_available():\n    return True\n"
  },
  {
    "path": "python/oneflow/profiler/events.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport json\nimport copy\nfrom enum import Enum\nfrom typing import Tuple, List, Dict\nfrom collections import OrderedDict\nfrom rich import box\nfrom rich.console import Console\nfrom rich.table import Table\nfrom oneflow.profiler.util import format_time\n\n\nclass EventType(Enum):\n    Custom = 0\n    Kernel = 1\n\n\nclass CustomEventType(Enum):\n    Default = 0\n    CudaKernel = 1\n    CudaRuntime = 2\n\n\nclass EventBase:\n    MAX_NAME_LENGTH = 55\n\n    def __init__(self, name: str, time_total: float, event_type: EventType) -> None:\n        self._name: str = name\n        self._time_total: float = time_total\n        self.count: int = 1\n        self.event_type: EventType = event_type\n\n    def update(self, event) -> None:\n        assert self.event_type == event.event_type\n        self.cpu_time_total += event.cpu_time_total\n        self.count += event.count\n\n    @property\n    def name(self):\n        if len(self._name) > self.MAX_NAME_LENGTH:\n            return self._name[: self.MAX_NAME_LENGTH - 3] + \"...\"\n        return self._name\n\n    @property\n    def cpu_time_total(self):\n        return self._time_total\n\n    @cpu_time_total.setter\n    def cpu_time_total(self, new_time):\n        self._time_total = new_time\n\n    @property\n    def cpu_time(self):\n        return self._time_total / self.count\n\n    @property\n    def cuda_time_total(self):\n        return None\n\n    @cuda_time_total.setter\n    def cuda_time_total(self, new_time):\n        pass\n\n    @property\n    def cuda_time(self):\n        if self.cuda_time_total is None:\n            return None\n        return self.cuda_time_total / self.count\n\n    def has_cuda_time(self) -> bool:\n        return self.cuda_time_total is not None\n\n    def __eq__(self, __o: object) -> bool:\n        return (\n            self.name == __o.name\n            and self.count == __o.count\n            and self.cpu_time_total == __o.cpu_time_total\n            and self.cuda_time_total == __o.cuda_time_total\n        )\n\n\nclass CustomEvent(EventBase):\n    def __init__(\n        self, name: str, time_total: float, custom_event_type: CustomEventType\n    ) -> None:\n        super().__init__(name, time_total, EventType.Custom)\n        self.custom_event_type = custom_event_type\n\n    @classmethod\n    def from_dict(cls, d: dict):\n        return cls(d.get(\"name\"), d.get(\"time\"), CustomEventType(d.get(\"custom_type\")))\n\n    @property\n    def key(self):\n        return self.name, self.custom_event_type\n\n    @property\n    def cuda_time_total(self):\n        if self.custom_event_type == CustomEventType.CudaKernel:\n            return self._time_total\n        return None\n\n    def to_dict(self):\n        device_prefix = \"cuda\" if self.has_cuda_time() else \"cpu\"\n        time_attrs = [f\"{device_prefix}_{suffix}\" for suffix in [\"time\", \"time_total\"]]\n        result = {\n            
\"name\": self.name,\n            \"count\": self.count,\n        }\n        for time_attr in time_attrs:\n            result[time_attr] = format_time(getattr(self, time_attr))\n        return result\n\n    def __eq__(self, __o: object) -> bool:\n        return (\n            super().__eq__(__o)\n            and isinstance(__o, type(self))\n            and self.custom_event_type == __o.custom_event_type\n        )\n\n\nclass KernelEvent(EventBase):\n    def __init__(\n        self,\n        name: str,\n        time_total: float,\n        memory_size: int,\n        description: Dict[str, str],\n    ) -> None:\n        super().__init__(name, time_total, EventType.Kernel)\n        self.children: List[CustomEvent] = []\n        self.memory_size = memory_size\n        self.description = description\n        self._cuda_time_total = 0.0\n        self._enable_show_input_shapes = True\n        self._enable_show_attributes = True\n\n    def add_child(self, event: CustomEvent):\n        self.children.append(event)\n        if event.has_cuda_time():\n            self._cuda_time_total += event.cuda_time\n\n    @classmethod\n    def from_dict(cls, d: dict):\n        kernel_event = cls(\n            d.get(\"name\"), d.get(\"time\"), d.get(\"memory_size\"), d.get(\"description\", {})\n        )\n        if \"children\" in d.keys():\n            children_list = d.get(\"children\")\n            if len(children_list) > 0:\n                for child_dict in children_list:\n                    kernel_event.add_child(CustomEvent.from_dict(child_dict))\n        return kernel_event\n\n    @property\n    def key(self):\n        def get_extra_keys():\n            extra_keys = []\n            if self.input_shapes != \"\" and self._enable_show_input_shapes:\n                extra_keys.append(self.description.get(\"input_shapes\")[1])\n            if self.attributes != \"\" and self._enable_show_attributes:\n                extra_keys.append(self.description.get(\"attrs\")[1])\n            return tuple(extra_keys)\n\n        if len(self.children) == 0:\n            return (self.name,) + get_extra_keys()\n        return (\n            self.name,\n            *get_extra_keys(),\n            \",\".join([x.name for x in self.children]),\n        )\n\n    @property\n    def cuda_time_total(self):\n        if self._cuda_time_total > 0.0:\n            return self._cuda_time_total\n        return None\n\n    @cuda_time_total.setter\n    def cuda_time_total(self, new_time):\n        self._cuda_time_total = new_time\n\n    @property\n    def input_shapes(self):\n        if \"input_shapes\" in self.description:\n            return self.description[\"input_shapes\"][0]\n        return \"\"\n\n    @property\n    def attributes(self):\n        if \"attrs\" in self.description:\n            return self.description[\"attrs\"][0]\n        return \"\"\n\n    @property\n    def bandwidth(self):\n        if len(self.children) > 0 and self.has_cuda_time():\n            if self.memory_size != -1:\n                return f\"{self.memory_size / (1024.0 * 1024.0 * 1024.0) / (self.cuda_time / (1000 * 1000)):.3f}GB/s\"\n        return \"\"\n\n    def to_dict(self):\n        result = {\n            \"name\": self.name,\n            \"cpu_time_total\": format_time(self.cpu_time_total),\n            \"cpu_time\": format_time(self.cpu_time),\n            \"count\": self.count,\n            \"input_shapes\": self.input_shapes,\n            \"attributes\": self.attributes,\n        }\n        if self.has_cuda_time():\n            result.update(\n     
           {\n                    \"cuda_time_total\": format_time(self.cuda_time_total),\n                    \"cuda_time\": format_time(self.cuda_time),\n                }\n            )\n\n        return result\n\n    def update(self, event):\n        assert id(self) != id(event)\n        assert isinstance(event, type(self))\n        assert len(self.children) == len(event.children)\n        assert self.has_cuda_time() == event.has_cuda_time()\n        assert self.key == event.key\n\n        super().update(event)\n        if self.has_cuda_time():\n            self.cuda_time_total += event.cuda_time_total\n\n        for i in range(len(self.children)):\n            self.children[i].update(event.children[i])\n\n    def make_children_average(self):\n        stats: Dict[Tuple[str, ...], CustomEvent] = OrderedDict()\n        for event in self.children:\n            if event.key in stats:\n                stats[event.key].update(event)\n            else:\n                stats[event.key] = copy.deepcopy(event)\n        self.children = list(stats.values())\n        self.children.sort(key=lambda x: x.name)\n\n    def __eq__(self, __o: object) -> bool:\n        return (\n            super().__eq__(__o)\n            and isinstance(__o, type(self))\n            and self.children == __o.children\n            and self.memory_size == __o.memory_size\n            and self.input_shapes == __o.input_shapes\n            and self.attributes == __o.attributes\n        )\n\n\nclass Events(list):\n    def __init__(self, events: str = \"\") -> None:\n        list.__init__([])\n        if events != \"\":\n            self.__init_events(events)\n\n    def __init_events(self, events: str):\n        events_json = json.loads(events)\n        classes = [CustomEvent, KernelEvent]\n        for event_json in events_json:\n            self.append(classes[event_json.get(\"type\")].from_dict(event_json))\n\n    def __str__(self):\n        return self.table()\n\n    def key_averages(self, group_by_input_shape=False, group_by_attributes=False):\n        stats: Dict[Tuple[str, ...], EventBase] = OrderedDict()\n\n        def deal_event(e):\n            if isinstance(e, KernelEvent):\n                e._enable_show_input_shapes = group_by_input_shape\n                e._enable_show_attributes = group_by_attributes\n\n            key = e.key\n            if key in stats:\n                stats[key].update(e)\n            else:\n                stats[key] = copy.deepcopy(e)\n\n        for event in self:\n            if isinstance(event, KernelEvent) and len(event.children) != 0:\n                event.make_children_average()\n                for event_child in event.children:\n                    deal_event(event_child)\n                event.children = []\n            deal_event(event)\n\n        results = Events()\n        results.extend(stats.values())\n        return results\n\n    def table(self):\n        has_input_shapes = any(\n            [\n                x.input_shapes != \"\" and x._enable_show_input_shapes\n                for x in self\n                if isinstance(x, KernelEvent)\n            ]\n        )\n        has_attributes = any(\n            [\n                x.attributes != \"\" and x._enable_show_attributes\n                for x in self\n                if isinstance(x, KernelEvent)\n            ]\n        )\n        has_bandwidth = any(\n            [x.bandwidth != \"\" for x in self if isinstance(x, KernelEvent)]\n        )\n        t = Table(\n            \"Name\",\n            \"CPU time total\",\n 
           \"CPU time\",\n            \"GPU time total\",\n            \"GPU time\",\n            \"Number of calls\",\n            box=box.SIMPLE,\n        )\n        field_keys = [\n            \"name\",\n            \"cpu_time_total\",\n            \"cpu_time\",\n            \"cuda_time_total\",\n            \"cuda_time\",\n            \"count\",\n        ]\n        if has_input_shapes:\n            t.add_column(\"Input shapes\")\n            field_keys.append(\"input_shapes\")\n        if has_attributes:\n            t.add_column(\"Attributes\")\n            field_keys.append(\"attributes\")\n        if has_bandwidth:\n            t.add_column(\"Bandwidth\")\n            field_keys.append(\"bandwidth\")\n\n        def build_row(data: dict):\n            return tuple(str(data.get(key, \"\")) for key in field_keys)\n\n        for item in self:\n            if isinstance(item, CustomEvent):\n                t.add_row(*build_row(item.to_dict()))\n            if isinstance(item, KernelEvent):\n                t.add_row(*build_row(item.to_dict()))\n                if len(item.children) > 0:\n                    for child in item.children:\n                        t.add_row(*build_row(child.to_dict()))\n        console = Console()\n        with console.capture() as capture:\n            console.print(t)\n        return capture.get()\n"
  },
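  {
    "path": "python/oneflow/profiler/events_example.py",
    "content": "# NOTE(editor): illustrative sketch, not part of the OneFlow source tree.\n# It feeds Events a hand-written JSON trace in the layout that\n# CustomEvent.from_dict expects: \"type\" 0 selects CustomEvent, \"time\" is the\n# total time in microseconds, and \"custom_type\" 0 is CustomEventType.Default.\nimport json\nfrom oneflow.profiler.events import Events\n\ntrace = json.dumps(\n    [\n        {\"type\": 0, \"name\": \"my_op\", \"time\": 1500.0, \"custom_type\": 0},\n        {\"type\": 0, \"name\": \"my_op\", \"time\": 500.0, \"custom_type\": 0},\n    ]\n)\nevents = Events(trace)\nprint(events.table())  # one row per event\nprint(events.key_averages().table())  # the two my_op rows merged, count == 2\n"
  },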
  {
    "path": "python/oneflow/profiler/profiler.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow._oneflow_internal\nfrom enum import Enum\nfrom typing import Optional, Iterable, Set\nfrom oneflow.profiler.events import Events\n\n\nclass ProfilerActivity(Enum):\n    CPU = 1\n    CUDA = 2\n\n\nclass ProfilerAction(Enum):\n    \"\"\"\n    Profiler actions that can be taken at the specified intervals\n    \"\"\"\n\n    NONE = 0\n    WARMUP = 1\n    RECORD = 2\n    RECORD_AND_SAVE = 3\n\n\ndef tensorboard_trace_handler():\n    raise NotImplementedError()\n\n\ndef supported_activities() -> Set[ProfilerActivity]:\n    activities = set([ProfilerActivity.CPU])\n    if oneflow.cuda.is_available():\n        activities.add(ProfilerActivity.CUDA)\n    return activities\n\n\nclass profile:\n    def __init__(\n        self,\n        activities: Optional[Iterable[ProfilerActivity]] = None,\n        record_shapes: bool = False,\n        record_attrs: bool = False,\n        record_bandwidth_for_cuda: bool = False,\n    ) -> None:\n        self.activities = set(activities) if activities else supported_activities()\n        assert (\n            len(self.activities) > 0\n        ), \"At least one ProfilerActivity must be specified.\"\n        for item in self.activities:\n            assert (\n                item in supported_activities()\n            ), f\"Unsupported ProfilerActivity {item}\"\n        self.record_shapes = record_shapes\n        self.record_attrs = record_attrs\n        if not (ProfilerActivity.CUDA in self.activities):\n            assert (\n                record_bandwidth_for_cuda == False\n            ), \"record_bandwidth_for_cuda = True can only work with cuda.\"\n        self.record_bandwidth_for_cuda = record_bandwidth_for_cuda\n        self.profile_events: Optional[Events] = None\n\n    def __enter__(self):\n        oneflow._oneflow_internal.profiler.EnableProfiler(\n            ProfilerActivity.CPU in self.activities,\n            ProfilerActivity.CUDA in self.activities,\n            self.record_shapes,\n            self.record_attrs,\n            self.record_bandwidth_for_cuda,\n        )\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        self.profile_events = Events(\n            oneflow._oneflow_internal.profiler.DisableProfilerAndReturnResult()\n        )\n\n    def __check_finish(self):\n        if self.profile_events is None:\n            raise RuntimeError(\"Profiler didn't finish running\")\n\n    def key_averages(self, group_by_input_shape=False, group_by_attributes=False):\n        self.__check_finish()\n        return self.profile_events.key_averages(\n            group_by_input_shape=group_by_input_shape,\n            group_by_attributes=group_by_attributes,\n        )\n\n    def events(self):\n        self.__check_finish()\n        return self.profile_events\n\n\nclass record_function:\n    def __init__(self, name: str) -> None:\n        self.name = name\n        self.__event_recorder_key = 
\"\"\n\n    def __enter__(self):\n        self.__event_recorder_key = oneflow._oneflow_internal.profiler.StartRecord(\n            self.name\n        )\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        oneflow._oneflow_internal.profiler.EndRecord(self.__event_recorder_key)\n"
  },
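  {
    "path": "python/oneflow/profiler/profile_example.py",
    "content": "# NOTE(editor): illustrative usage sketch, not part of the OneFlow source tree.\n# It drives the profile/record_function context managers defined above; the\n# block name and tensor sizes are arbitrary.\nimport oneflow as flow\nimport oneflow.profiler as profiler\n\nx = flow.randn(128, 128)\nwith profiler.profile(record_shapes=True) as prof:\n    with profiler.record_function(\"my-block\"):\n        y = flow.relu(flow.matmul(x, x))\n\n# key_averages() may only be called after the profile context has exited.\nprint(prof.key_averages(group_by_input_shape=True))\n"
  },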
  {
    "path": "python/oneflow/profiler/util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nUS_IN_MS = 1000.0\nUS_IN_SECOND = US_IN_MS * 1000.0\n\n\ndef format_time(time_us):\n    if time_us >= US_IN_SECOND:\n        return \"{:.3f}s\".format(time_us / US_IN_SECOND)\n    if time_us >= US_IN_MS:\n        return \"{:.3f}ms\".format(time_us / US_IN_MS)\n    return \"{:.3f}us\".format(time_us)\n"
  },
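  {
    "path": "python/oneflow/profiler/util_example.py",
    "content": "# NOTE(editor): illustrative sketch, not part of the OneFlow source tree.\n# format_time takes microseconds and picks the largest unit that fits.\nfrom oneflow.profiler.util import format_time\n\nassert format_time(12.3) == \"12.300us\"\nassert format_time(4500.0) == \"4.500ms\"\nassert format_time(2500000.0) == \"2.500s\"\n"
  },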
  {
    "path": "python/oneflow/remat/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport re\n\nimport oneflow as flow\n\n\ndef parse_size(size):\n    units = {\n        \"B\": 1,\n        \"KB\": 2 ** 10,\n        \"MB\": 2 ** 20,\n        \"GB\": 2 ** 30,\n        \"TB\": 2 ** 40,\n        \"\": 1,\n        \"KIB\": 10 ** 3,\n        \"MIB\": 10 ** 6,\n        \"GIB\": 10 ** 9,\n        \"TIB\": 10 ** 12,\n    }\n    m = re.match(r\"^([\\d\\.]+)\\s*([a-zA-Z]{0,3})$\", str(size).strip())\n    assert m is not None\n    number, unit = float(m.group(1)), m.group(2).upper()\n    return int(number * units[unit])\n\n\ndef set_budget(budget: str):\n    budget_in_bytes = parse_size(budget)\n    flow._oneflow_internal.remat.set_budget_in_bytes(budget_in_bytes)\n\n\ndef get_budget():\n    budget_in_bytes = flow._oneflow_internal.remat.budget_in_bytes()\n    return budget_in_bytes\n\n\nset_small_pieces_optimization = (\n    flow._oneflow_internal.remat.set_small_pieces_optimization\n)\nis_small_pieces_optimization_enabled = (\n    flow._oneflow_internal.remat.is_small_pieces_optimization_enabled\n)\n"
  },
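  {
    "path": "python/oneflow/remat/budget_example.py",
    "content": "# NOTE(editor): illustrative usage sketch, not part of the OneFlow source tree.\n# parse_size understands decimal (KB/MB/GB/TB) and binary (KiB/MiB/GiB/TiB)\n# suffixes; set_budget feeds the parsed byte count to the rematerialization\n# runtime.\nimport oneflow.remat as remat\n\nassert remat.parse_size(\"512MB\") == 512 * 10 ** 6\nassert remat.parse_size(\"2GiB\") == 2 * 2 ** 30\n\nremat.set_budget(\"4GiB\")  # cap rematerialization memory at 4 GiB\nprint(remat.get_budget())\n"
  },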
  {
    "path": "python/oneflow/sbp.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.distribute import split_sbp as split\nimport oneflow._oneflow_internal\n\nsbp = oneflow._oneflow_internal.sbp.sbp\nbroadcast = oneflow._oneflow_internal.sbp.broadcast()\npartial_sum = oneflow._oneflow_internal.sbp.partial_sum()\n"
  },
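  {
    "path": "python/oneflow/sbp_example.py",
    "content": "# NOTE(editor): illustrative sketch, not part of the OneFlow source tree.\n# split(n) shards a global tensor along dimension n across the placement's\n# ranks, broadcast replicates it, and partial_sum holds per-rank partial\n# results. This single-rank example only demonstrates the annotations.\nimport oneflow as flow\n\nplacement = flow.placement(\"cpu\", ranks=[0])\nx = flow.randn(4, 4)\n\ng = x.to_global(placement=placement, sbp=flow.sbp.split(0))\nb = g.to_global(placement=placement, sbp=flow.sbp.broadcast)\nprint(g.sbp, b.sbp)\n"
  },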
  {
    "path": "python/oneflow/special/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n\nfrom .special_ops import erf\nfrom .special_ops import erfc\nfrom .special_ops import erfinv\nfrom .special_ops import exp2\nfrom .special_ops import expm1\nfrom .special_ops import log1p\nfrom .special_ops import log_softmax\nfrom .special_ops import logsumexp\nfrom .special_ops import round\nfrom .special_ops import softmax\nfrom .special_ops import digamma\nfrom .special_ops import psi\nfrom .special_ops import zeta\n"
  },
  {
    "path": "python/oneflow/special/special_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n\nimport oneflow\nfrom oneflow.framework.tensor import Tensor\n\n# avoid redefine error when add_doc\n\n\ndef erf(x: Tensor):\n    return oneflow._C.erf(x)\n\n\ndef erfc(x: Tensor):\n    return oneflow._C.erfc(x)\n\n\ndef erfinv(x: Tensor):\n    return oneflow._C.erfinv(x)\n\n\ndef exp2(x: Tensor):\n    return oneflow._C.exp2(x)\n\n\ndef expm1(x: Tensor):\n    return oneflow._C.expm1(x)\n\n\ndef log1p(x: Tensor):\n    return oneflow._C.log1p(x)\n\n\ndef log_softmax(x: Tensor, dim: int):\n    return oneflow._C.log_softmax(x, dim)\n\n\ndef logsumexp(x: Tensor, dim: int, keepdim=False):\n    return oneflow._C.logsumexp(x, dim, keepdim)\n\n\ndef round(x: Tensor):\n    return oneflow._C.round(x)\n\n\ndef softmax(x: Tensor, dim: int):\n    return oneflow._C.softmax(x, dim)\n\n\ndef digamma(x: Tensor):\n    return oneflow._C.digamma(x)\n\n\ndef psi(x: Tensor):\n    return oneflow._C.digamma(x)\n\n\ndef zeta(input, other):\n    return oneflow._C.zeta(input, other)\n"
  },
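  {
    "path": "python/oneflow/special/special_ops_example.py",
    "content": "# NOTE(editor): illustrative sketch, not part of the OneFlow source tree.\n# Each wrapper above forwards to the corresponding oneflow._C kernel.\nimport oneflow as flow\n\nx = flow.tensor([0.5, 1.0, 2.0])\nprint(flow.special.erf(x))\nprint(flow.special.expm1(x))\nprint(flow.special.log_softmax(x, 0))\nprint(flow.special.psi(x))  # alias of digamma\n"
  },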
  {
    "path": "python/oneflow/support/__init__.py",
    "content": ""
  },
  {
    "path": "python/oneflow/support/async_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport threading\n\n\ndef Await(counter, func):\n    assert counter > 0\n    cond_var = threading.Condition()\n    counter_box = [counter]\n    result_list = []\n\n    def Yield(result=None):\n        result_list.append(result)\n        cond_var.acquire()\n        assert counter_box[0] > 0\n        counter_box[0] -= 1\n        cond_var.notify()\n        cond_var.release()\n\n    func(Yield)\n    cond_var.acquire()\n    while counter_box[0] > 0:\n        cond_var.wait()\n    cond_var.release()\n    return result_list\n"
  },
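  {
    "path": "python/oneflow/support/async_util_example.py",
    "content": "# NOTE(editor): illustrative sketch, not part of the OneFlow source tree.\n# Await(counter, func) blocks until Yield has been called `counter` times and\n# then returns every value that was passed to Yield.\nimport threading\n\nfrom oneflow.support.async_util import Await\n\n\ndef launch(Yield):\n    # start two workers; each reports one result through Yield\n    for i in range(2):\n        threading.Thread(target=lambda i=i: Yield(i * i)).start()\n\n\nresults = Await(2, launch)\nprint(sorted(results))  # [0, 1]\n"
  },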
  {
    "path": "python/oneflow/support/box.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n\nclass Box(object):\n    def __init__(self, *arg):\n        assert len(arg) <= 1\n        self.has_value_ = len(arg) > 0\n        self.value_ = None\n        if self.has_value_:\n            self.value_ = arg[0]\n\n    @property\n    def value(self):\n        assert self.has_value_\n        return self.value_\n\n    @property\n    def value_setter(self):\n        return lambda val: self.set_value(val)\n\n    def set_value(self, val):\n        self.value_ = val\n        self.has_value_ = True\n\n    def has_value(self):\n        return self.has_value_\n"
  },
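  {
    "path": "python/oneflow/support/box_example.py",
    "content": "# NOTE(editor): illustrative sketch, not part of the OneFlow source tree.\n# Box is a single-slot container; value_setter hands out a callable that\n# fills the slot, which is handy for passing a \"write end\" into a callback.\nfrom oneflow.support.box import Box\n\nb = Box()\nassert not b.has_value()\n\nb.value_setter(42)  # same as b.set_value(42)\nassert b.has_value() and b.value == 42\n"
  },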
  {
    "path": "python/oneflow/support/enable_if.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport inspect\n\nimport oneflow.support.traceinfo as traceinfo\n\n\ndef condition(hob_expr):\n    def Decorator(func):\n        func.__oneflow_condition_hob__ = hob_expr\n        return func\n\n    return Decorator\n\n\ndef get_condition_hob(func):\n    assert hasattr(func, \"__oneflow_condition_hob__\")\n    return func.__oneflow_condition_hob__\n\n\ndef set_condition_hob(func, hob):\n    func.__oneflow_condition_hob__ = hob\n\n\ndef unique(arg_funcs, context=None, default=None):\n    assert isinstance(arg_funcs, (list, tuple))\n    conditional_functions = []\n    for arg_func in arg_funcs:\n        if isinstance(arg_func, tuple):\n            (func, hob_expr) = arg_func\n        elif inspect.isfunction(arg_func):\n            func = arg_func\n            assert hasattr(func, \"__oneflow_condition_hob__\")\n            hob_expr = func.__oneflow_condition_hob__\n        else:\n            raise NotImplementedError\n        debug_str = func.__name__\n        if hasattr(func, \"__debug_str__\"):\n            debug_str = func.__debug_str__\n        conditional_functions.append((hob_expr, func, debug_str))\n    if default is None:\n\n        def default(get_failed_info, *args, **kwargs):\n            raise NotImplementedError(get_failed_info())\n\n    matched_func = GetMatchedFunction(default, conditional_functions, context=context)\n    if matched_func is not None:\n        return matched_func\n    return MakeDefaultFunction(default, conditional_functions, context=context)\n\n\ndef GetMatchedFunction(default, conditional_functions, context=None):\n    select_triple = (None, None, None)\n    for triple in conditional_functions:\n        if not triple[0](context):\n            continue\n        if select_triple[1] is not None:\n            return _MultiMatchedErrorFunction(\n                default, [select_triple, triple], context=context\n            )\n        select_triple = triple\n    return select_triple[1]\n\n\ndef MakeDefaultFunction(default, conditional_functions, context=None):\n    def get_failed_info(customized_prompt=None):\n        failed_info = \"no avaliable function found.\\n\"\n        for (bf, func, location) in conditional_functions:\n            prompt = location if customized_prompt is None else customized_prompt\n            failed_info += \"\\n%s: \\x1b[1;31mFAILED\\x1b[0m\\n\\t%s\\n\" % (\n                prompt,\n                bf.debug_str(context),\n            )\n        return failed_info\n\n    return lambda *args, **kwargs: default(get_failed_info, *args, **kwargs)\n\n\ndef _MultiMatchedErrorFunction(default, matched_functions, context=None):\n    def get_failed_info(customized_prompt=None):\n        failed_info = \"at least two conditional functions matched.\\n\"\n        for (bf, func, location) in matched_functions:\n            prompt = location if customized_prompt is None else customized_prompt\n            failed_info += \"\\n%s: 
\\x1b[1;31mPASSED\\x1b[0m\\n\\t%s\\n\" % (\n                prompt,\n                bf.debug_str(context),\n            )\n        return failed_info\n\n    return lambda *args, **kwargs: default(get_failed_info, *args, **kwargs)\n"
  },
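  {
    "path": "python/oneflow/support/enable_if_example.py",
    "content": "# NOTE(editor): illustrative sketch, not part of the OneFlow source tree.\n# unique() picks the single function whose attached hob condition holds for\n# the given context, and raises with a readable report otherwise.\nfrom oneflow.support.enable_if import condition, unique\nfrom oneflow.support.high_order_bool import bool_functor\n\n\n@bool_functor(\"context is positive\")\ndef is_positive(ctx):\n    return ctx > 0\n\n\n@condition(is_positive)\ndef on_positive():\n    return \"positive\"\n\n\n@condition(~is_positive)\ndef on_non_positive():\n    return \"non-positive\"\n\n\ndispatch = unique([on_positive, on_non_positive], context=5)\nprint(dispatch())  # \"positive\"\n"
  },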
  {
    "path": "python/oneflow/support/env_var_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\n\n\ndef string_to_bool(env_str):\n    if env_str.lower() in (\"1\", \"true\", \"yes\", \"on\", \"y\"):\n        return True\n    return False\n\n\ndef parse_boolean_from_env(env_var, defalut_value):\n    # This function aligns with ParseBooleanFromEnv() in oneflow/core/common/util.cpp\n    assert isinstance(env_var, str), \"env variable must be string, but got: \" + type(\n        env_var\n    )\n    assert isinstance(\n        defalut_value, bool\n    ), \"env variable defalut value must be boolean, but got: \" + type(defalut_value)\n    if os.getenv(env_var) is None:\n        return defalut_value\n    else:\n        return string_to_bool(os.getenv(env_var))\n"
  },
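  {
    "path": "python/oneflow/support/env_var_util_example.py",
    "content": "# NOTE(editor): illustrative sketch, not part of the OneFlow source tree;\n# the environment variable name below is made up for the demo.\nimport os\n\nfrom oneflow.support.env_var_util import parse_boolean_from_env\n\nos.environ[\"MY_DEMO_FLAG\"] = \"ON\"  # \"1\"/\"true\"/\"yes\"/\"on\"/\"y\" parse as True\nassert parse_boolean_from_env(\"MY_DEMO_FLAG\", False) is True\n\ndel os.environ[\"MY_DEMO_FLAG\"]  # unset variables fall back to the default\nassert parse_boolean_from_env(\"MY_DEMO_FLAG\", True) is True\n"
  },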
  {
    "path": "python/oneflow/support/func_inspect_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport inspect\nimport sys\n\nif sys.version_info > (2, 7) and sys.version_info < (3, 0):\n\n    def GetArgNameAndDefaultTuple(func):\n        \"\"\"\n      returns a dictionary of arg_name:default_values for the input function\n      \"\"\"\n        (args, varargs, keywords, defaults) = inspect.getargspec(func)\n        defaults = list(defaults) if defaults is not None else []\n        while len(defaults) < len(args):\n            defaults.insert(0, None)\n        return tuple(zip(args, defaults))\n\n\nelif sys.version_info >= (3, 0):\n\n    def GetArgNameAndDefaultTuple(func):\n        signature = inspect.signature(func)\n        return tuple(\n            [\n                (k, v.default if v.default is not inspect.Parameter.empty else None)\n                for (k, v) in signature.parameters.items()\n            ]\n        )\n\n\nelse:\n    raise NotImplementedError\n\n\ndef GetArgDefaults(func):\n    return tuple(map(lambda x: x[1], GetArgNameAndDefaultTuple(func)))\n"
  },
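  {
    "path": "python/oneflow/support/func_inspect_util_example.py",
    "content": "# NOTE(editor): illustrative sketch, not part of the OneFlow source tree.\nfrom oneflow.support.func_inspect_util import (\n    GetArgDefaults,\n    GetArgNameAndDefaultTuple,\n)\n\n\ndef f(a, b=3, c=\"x\"):\n    pass\n\n\nassert GetArgNameAndDefaultTuple(f) == ((\"a\", None), (\"b\", 3), (\"c\", \"x\"))\nassert GetArgDefaults(f) == (None, 3, \"x\")\n"
  },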
  {
    "path": "python/oneflow/support/high_order_bool.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nimport oneflow._oneflow_internal\n\n\ndef bool_functor(verbose_debug_str):\n    def Decorator(match_function):\n        return HighOrderBool(verbose_debug_str, match_function)\n\n    return Decorator\n\n\ndef hob_context_attr(attr_name):\n    def Decorator(attr_getter):\n        return HobContextAttr(attr_name, attr_getter)\n\n    return Decorator\n\n\nclass BoolFunctor(object):\n    def debug_str(self, ctx, display_result=True):\n        if hasattr(self, \"__debug_str__\"):\n            if display_result:\n                return '\"%s\"[%s]' % (self.__debug_str__, self(ctx))\n            else:\n                return '\"%s\"' % self.__debug_str__\n        return self.verbose_debug_str(ctx, display_result=display_result)\n\n    def verbose_debug_str(self, ctx, display_result=True):\n        raise NotImplementedError\n\n    def __call__(self, ctx):\n        raise NotImplementedError\n\n    def __and__(self, rhs):\n        return _AndBoolFunctor(self, rhs)\n\n    def __or__(self, rhs):\n        return _OrBoolFunctor(self, rhs)\n\n    def __invert__(self):\n        return _NotBoolFunctor(self)\n\n\nclass HighOrderBool(BoolFunctor):\n    def __init__(self, verbose_debug_str, function):\n        self.verbose_debug_str_ = verbose_debug_str\n        self.function_ = function\n\n    def verbose_debug_str(self, ctx, display_result=True):\n        if display_result:\n            return '\"%s\"[%s]' % (self.verbose_debug_str_, self.function_(ctx))\n        else:\n            return '\"%s\"' % self.verbose_debug_str_\n\n    def __call__(self, ctx):\n        return self.function_(ctx)\n\n\nalways_true = HighOrderBool(\"Always true\", lambda: True)\nalways_false = HighOrderBool(\"Always false\", lambda: False)\n\n\nclass _AndBoolFunctor(BoolFunctor):\n    def __init__(self, lhs, rhs):\n        assert isinstance(lhs, BoolFunctor)\n        assert isinstance(rhs, BoolFunctor)\n        self.lhs_ = lhs\n        self.rhs_ = rhs\n\n    def verbose_debug_str(self, ctx, display_result=True):\n        left_display = self.lhs_.debug_str(ctx, display_result)\n        display_result = display_result and self.lhs_(ctx)\n        right_display = self.rhs_.debug_str(ctx, display_result)\n        return \"(%s and %s)\" % (left_display, right_display)\n\n    def __call__(self, ctx):\n        return self.lhs_(ctx) and self.rhs_(ctx)\n\n\nclass _OrBoolFunctor(BoolFunctor):\n    def __init__(self, lhs, rhs):\n        assert isinstance(lhs, BoolFunctor)\n        assert isinstance(rhs, BoolFunctor)\n        self.lhs_ = lhs\n        self.rhs_ = rhs\n\n    def verbose_debug_str(self, ctx, display_result=True):\n        left_display = self.lhs_.debug_str(ctx, display_result)\n        display_result = display_result and (not self.lhs_(ctx))\n        right_display = self.rhs_.debug_str(ctx, display_result)\n        return \"(%s or %s)\" % (left_display, right_display)\n\n    def __call__(self, 
ctx):\n        return self.lhs_(ctx) or self.rhs_(ctx)\n\n\nclass _NotBoolFunctor(BoolFunctor):\n    def __init__(self, x):\n        assert isinstance(x, BoolFunctor)\n        self.x_ = x\n\n    def verbose_debug_str(self, ctx, display_result=True):\n        return \"(not %s)\" % self.x_.debug_str(ctx, display_result)\n\n    def __call__(self, ctx):\n        return not self.x_(ctx)\n\n\nclass HobContextGetter(object):\n    def __init__(self, attr_name, attr_getter):\n        self.attr_name_ = attr_name\n        self.attr_getter_ = attr_getter\n\n    @property\n    def attr_name(self):\n        return self.attr_name_\n\n    @property\n    def attr_getter(self):\n        return self.attr_getter_\n\n    def __eq__(self, other):\n        if not isinstance(other, HobContextGetter):\n            other = HobContextConstant(other)\n        return self._MakeHob(other, \"==\", lambda a, b: a == b)\n\n    def __ne__(self, other):\n        if not isinstance(other, HobContextGetter):\n            other = HobContextConstant(other)\n        return self._MakeHob(other, \"!=\", lambda a, b: a != b)\n\n    def __gt__(self, other):\n        if not isinstance(other, HobContextGetter):\n            other = HobContextConstant(other)\n        return self._MakeHob(other, \">\", lambda a, b: a > b)\n\n    def __ge__(self, other):\n        if not isinstance(other, HobContextGetter):\n            other = HobContextConstant(other)\n        return self._MakeHob(other, \">=\", lambda a, b: a >= b)\n\n    def __lt__(self, other):\n        if not isinstance(other, HobContextGetter):\n            other = HobContextConstant(other)\n        return self._MakeHob(other, \"<\", lambda a, b: a < b)\n\n    def __le__(self, other):\n        if not isinstance(other, HobContextGetter):\n            other = HobContextConstant(other)\n        return self._MakeHob(other, \"<=\", lambda a, b: a <= b)\n\n    def _MakeHob(self, other, cmp_str, cmp_func):\n        @bool_functor(\"%s %s %s\" % (self.attr_name, cmp_str, other.attr_name))\n        def HobHob(context):\n            return cmp_func(self.attr_getter(context), other.attr_getter(context))\n\n        return HobHob\n\n\nclass HobContextConstant(HobContextGetter):\n    def __init__(self, value):\n        HobContextGetter.__init__(self, str(value), lambda ctx: value)\n\n\nclass HobContextAttr(HobContextGetter):\n    def __init__(self, attr_name, attr_getter):\n        HobContextGetter.__init__(self, attr_name, attr_getter)\n\n    def __getattr__(self, attr_name):\n        @hob_context_attr(\"%s.%s\" % (self.attr_name, attr_name))\n        def HobCtxAttr(ctx):\n            obj = self.attr_getter(ctx)\n            return getattr(obj, attr_name)\n\n        return HobCtxAttr\n\n    def HasField(self, attr_name):\n        @bool_functor('%s.HasField(\"%s\")' % (self.attr_name, attr_name))\n        def BoolFunctor(ctx):\n            obj = self.attr_getter(ctx)\n            if hasattr(obj, \"HasField\"):\n                return obj.HasField(attr_name)\n            else:\n                return hasattr(obj, attr_name)\n\n        return BoolFunctor\n"
  },
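  {
    "path": "python/oneflow/support/high_order_bool_example.py",
    "content": "# NOTE(editor): illustrative sketch, not part of the OneFlow source tree.\n# Functors built with bool_functor compose with &, | and ~, and debug_str\n# renders the whole expression together with each sub-result.\nfrom oneflow.support.high_order_bool import bool_functor\n\n\n@bool_functor(\"value is even\")\ndef is_even(ctx):\n    return ctx % 2 == 0\n\n\n@bool_functor(\"value is small\")\ndef is_small(ctx):\n    return ctx < 10\n\n\nexpr = is_even & ~is_small\nassert not expr(4)  # even, but small\nassert expr(12)\nprint(expr.debug_str(12))\n"
  },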
  {
    "path": "python/oneflow/support/lazy.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n\nclass Lazy(object):\n    def __init__(self, get_value):\n        self.value_ = None\n        self.has_value_ = False\n        self.get_value_ = get_value\n\n    @property\n    def value(self):\n        if not self.has_value_:\n            self.value_ = self.get_value_()\n            self.has_value_ = True\n        return self.value_\n"
  },
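  {
    "path": "python/oneflow/support/lazy_example.py",
    "content": "# NOTE(editor): illustrative sketch, not part of the OneFlow source tree.\n# Lazy defers get_value() until .value is first read, then caches the result.\nfrom oneflow.support.lazy import Lazy\n\ncalls = []\nlazy = Lazy(lambda: calls.append(\"computed\") or 42)\n\nassert lazy.value == 42  # triggers the computation\nassert lazy.value == 42  # served from the cache\nassert calls == [\"computed\"]  # get_value ran exactly once\n"
  },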
  {
    "path": "python/oneflow/support/pb_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n\ndef PythonDict2PbMessage(value, msg):\n    def extend_dict(values, msg):\n        for (k, v) in values.items():\n            if type(v) is dict:\n                extend_dict(v, getattr(msg, k))\n            elif type(v) is list or type(v) is tuple:\n                extend_list_or_tuple(v, getattr(msg, k))\n            else:\n                setattr(msg, k, v)\n        else:\n            msg.SetInParent()\n\n    def extend_list_or_tuple(values, msg):\n        if len(values) == 0:\n            return\n        if type(values[0]) is dict:\n            for v in values:\n                cmd = msg.add()\n                extend_dict(v, cmd)\n        else:\n            msg.extend(values)\n\n    extend_dict(value, msg)\n    return msg\n\n\ndef MergePbMessage(dst, src):\n    assert type(dst) is type(src)\n    for field in dst.DESCRIPTOR.fields:\n        field_name = field.name\n        if field.containing_oneof is not None:\n            if dst.WhichOneof(field.containing_oneof.name) is not None:\n                continue\n            src_field_name = src.WhichOneof(field.containing_oneof.name)\n            if src_field_name is None:\n                continue\n            if field_name != src_field_name:\n                continue\n        else:\n            if dst.HasField(field_name):\n                continue\n            if not src.HasField(field_name):\n                continue\n        _MergePbMessageField(dst, src, field)\n\n\ndef _MergePbMessageField(dst, src, field):\n    if field.message_type is None:\n        setattr(dst, field.name, getattr(src, field.name))\n    else:\n        MergePbMessage(getattr(dst, field.name), getattr(src, field.name))\n"
  },
  {
    "path": "python/oneflow/support/scope_stack.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom contextlib import contextmanager\n\n\nclass ScopeStack(object):\n    def __init__(self, init=[]):\n        if not isinstance(init, list):\n            init = [init]\n        assert isinstance(init, list)\n        self.stack_ = init\n\n    def Current(self):\n        assert len(self.stack_) > 0\n        return self.stack_[0]\n\n    @contextmanager\n    def NewScope(self, scope):\n        self.stack_.insert(0, scope)\n        yield\n        self.stack_.pop(0)\n"
  },
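  {
    "path": "python/oneflow/support/scope_stack_example.py",
    "content": "# NOTE(editor): illustrative sketch, not part of the OneFlow source tree.\n# Current() always returns the innermost scope; NewScope nests via `with`.\nfrom oneflow.support.scope_stack import ScopeStack\n\nstack = ScopeStack(\"root\")\nwith stack.NewScope(\"inner\"):\n    assert stack.Current() == \"inner\"\nassert stack.Current() == \"root\"\n"
  },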
  {
    "path": "python/oneflow/support/traceinfo.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport traceback\n\n\ndef GetFrameLocationStr(depth=-1):\n    assert depth < 0\n    frame = traceback.extract_stack()[depth - 1]\n    return \"%s:%d\" % (frame[0], frame[1])\n\n\ndef GetStackInfoExcludeOneflowPythonFile():\n    import oneflow\n\n    dirname = os.path.dirname(oneflow.__file__)\n    stack_info = traceback.extract_stack()\n    filtered_stack_info = filter(\n        lambda x: x[0].startswith(dirname) == False, stack_info\n    )\n    return list(filtered_stack_info)\n"
  },
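  {
    "path": "python/oneflow/support/traceinfo_example.py",
    "content": "# NOTE(editor): illustrative sketch, not part of the OneFlow source tree.\nfrom oneflow.support.traceinfo import GetFrameLocationStr\n\n# prints \"<path-of-this-file>:<line-number-of-this-call>\"\nprint(GetFrameLocationStr())\n"
  },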
  {
    "path": "python/oneflow/sysconfig.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nfrom oneflow.framework.sysconfig import (\n    cmake_build_type,\n    get_compile_flags,\n    get_include,\n    get_lib,\n    get_link_flags,\n    get_liboneflow_link_flags,\n    has_rpc_backend_grpc,\n    has_rpc_backend_local,\n    with_cuda,\n    get_cuda_version,\n    with_rdma,\n)\n\n\nfrom oneflow._oneflow_internal.flags import (\n    with_mlir,\n    with_mlir_cuda_codegen,\n)\n"
  },
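  {
    "path": "python/oneflow/sysconfig_example.py",
    "content": "# NOTE(editor): illustrative sketch, not part of the OneFlow source tree.\n# These helpers report the paths and flags needed to compile and link C++\n# extensions against the installed oneflow package.\nimport oneflow.sysconfig as sysconfig\n\nprint(sysconfig.get_include())  # header directory\nprint(sysconfig.get_compile_flags())\nprint(sysconfig.get_link_flags())\nprint(sysconfig.with_cuda())  # whether this build has CUDA support\n"
  },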
  {
    "path": "python/oneflow/test/README.md",
    "content": "## Ops Version : Alpha\n\n\n| Op Name | Doc Test | Compatiable/Completeness Test | Exception | Performance Test |\n| ------------------------- | ------------- | ----------------------------- | --------- | ---------------- |\n| oneflow.autograd.backward | [oneflow.Tensor.backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L727)   | [unsqueeze_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unsqueeze.py#L54)   | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24)   |  |\n| oneflow.autograd.grad | [oneflow.Tensor.grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L753)   | [adagrad_clip_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_adagrad.py#L213)   | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24)   |  |\n| oneflow.autograd.no_grad |  | [no_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd_mode.py#L62)   |  |  |\n| oneflow.autograd.enable_grad |  | [enable_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd_mode.py#L50)   |  |  |\n| oneflow.autograd.set_grad_enabled |  | [set_grad_enabled](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd_mode.py#L74)   |  |  |\n| oneflow.autograd.inference_mode |  | [inference_mode](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd_mode.py#L27)   |  |  |\n| oneflow.Tensor.grad | [oneflow.Tensor.grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L753)   | [adagrad_clip_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_adagrad.py#L213)   | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24)   |  |\n| oneflow.Tensor.requires_grad | [oneflow.Tensor.requires_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L800)   | [requires_grad_tensor_inplace_and_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd.py#L170)   | 
[non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24)   |  |\n| oneflow.Tensor.is_leaf | [oneflow.Tensor.is_leaf](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L767)   |  |  |  |\n| oneflow.Tensor.backward | [oneflow.Tensor.backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L727)   | [unsqueeze_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unsqueeze.py#L54)   | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24)   |  |\n| oneflow.Tensor.detach |  | [tensor_detach](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_2.py#L91)   |  |  |\n| oneflow.Tensor.register_hook | [oneflow.Tensor.register_hook](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L833)   | [tensor_register_hook](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L446)   |  |  |\n| oneflow.Tensor.retain_grad | [oneflow.Tensor.retain_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L866)   | [retain_grad_for_leaf_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd.py#L178)   |  |  |\n| oneflow.autograd.Function.forward |  | [eye_forward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eye.py#L27)   |  |  |\n| oneflow.autograd.Function.backward | [oneflow.Tensor.backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L727)   | [unsqueeze_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unsqueeze.py#L54)   | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24)   |  |\n| oneflow.autograd.Function.apply |  | [module_apply](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L161)   |  |  |\n| oneflow.autograd.autograd_function.FunctionAutoGradCaptureState.mark_non_differentiable |  |  |  |  |\n| oneflow.autograd.autograd_function.FunctionAutoGradCaptureState.save_for_backward |  |  |  |  |\n| oneflow.autograd.autograd_function.FunctionAutoGradCaptureState.saved_tensors |  |  |  |  |\n| 
oneflow.cuda.is_available |  |  |  |  |\n| oneflow.cuda.device_count |  |  |  |  |\n| oneflow.cuda.current_device |  |  |  |  |\n| oneflow.cuda.set_device |  |  |  |  |\n| oneflow.cuda.synchronize |  |  |  |  |\n| oneflow.cuda.manual_seed_all |  |  |  |  |\n| oneflow.cuda.manual_seed |  | [generator_manual_seed](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_generator.py#L72)   |  |  |\n| oneflow.cuda.HalfTensor |  |  |  |  |\n| oneflow.cuda.FloatTensor |  |  |  |  |\n| oneflow.cuda.DoubleTensor |  |  |  |  |\n| oneflow.cuda.BoolTensor |  |  |  |  |\n| oneflow.cuda.ByteTensor |  |  |  |  |\n| oneflow.cuda.CharTensor |  |  |  |  |\n| oneflow.cuda.IntTensor |  |  |  |  |\n| oneflow.cuda.LongTensor |  |  |  |  |\n| oneflow.cuda.empty_cache |  |  |  |  |\n| oneflow.nn.functional.conv1d | [oneflow._C.conv1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/conv.py#L20)   | [conv1d_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_higher_derivative_conv.py#L128)   |  |  |\n| oneflow.nn.functional.conv2d | [oneflow._C.conv2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/conv.py#L57)   | [conv2d_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_higher_derivative_conv.py#L134)   |  | done   |\n| oneflow.nn.functional.conv3d | [oneflow._C.conv3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/conv.py#L95)   | [conv3d_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_higher_derivative_conv.py#L140)   |  |  |\n| oneflow.nn.functional.conv_transpose1d |  |  |  |  |\n| oneflow.nn.functional.conv_transpose2d |  |  |  | done   |\n| oneflow.nn.functional.conv_transpose3d |  |  |  |  |\n| oneflow.nn.functional.fold | [oneflow.nn.functional.fold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/convolution.py#L20)   | [fold_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_fold.py#L25)   |  | done   |\n| oneflow.nn.functional.unfold | [oneflow.Tensor.unfold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L563)   | [unfold_tensor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unfold_tensor.py#L30)   |  |  |\n| oneflow.nn.functional.avg_pool1d | [oneflow._C.avg_pool1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L99)   |  |  |  |\n| oneflow.nn.functional.avg_pool2d | 
[oneflow._C.avg_pool2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L129)   |  |  | done   |\n| oneflow.nn.functional.avg_pool3d | [oneflow._C.avg_pool3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L151)   |  |  |  |\n| oneflow.nn.functional.max_pool1d |  |  |  |  |\n| oneflow.nn.functional.max_pool2d |  |  |  | done   |\n| oneflow.nn.functional.max_pool3d |  |  |  |  |\n| oneflow.nn.functional.adaptive_avg_pool1d | [oneflow._C.adaptive_avg_pool1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L20)   |  |  | done   |\n| oneflow.nn.functional.adaptive_avg_pool2d | [oneflow._C.adaptive_avg_pool2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L48)   |  |  | done   |\n| oneflow.nn.functional.adaptive_avg_pool3d | [oneflow._C.adaptive_avg_pool3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L74)   |  |  | done   |\n| oneflow.nn.functional.threshold | [oneflow._C.threshold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L496)   | [softplus_threshold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L533)   |  | done   |\n| oneflow.nn.functional.relu | [oneflow.relu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L50)   | [relu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L33)   | [relu_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L29)   | done   |\n| oneflow.nn.functional.hardtanh | [oneflow._C.hardtanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L363)   | [hardtanh_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L618)   |  | done   |\n| oneflow.nn.functional.hardswish | [oneflow._C.hardswish](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L316)   | [hardswish_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L590)   |  | done   |\n| oneflow.nn.functional.relu6 |  | 
[relu6_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L73)   |  | done   |\n| oneflow.nn.functional.elu | [oneflow._C.elu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L385)   | [elu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L165)   |  | done   |\n| oneflow.nn.functional.selu | [oneflow.selu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L409)   | [selu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L754)   |  | done   |\n| oneflow.nn.functional.celu | [oneflow._C.celu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L468)   | [celu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L203)   | [celu_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L47)   | done   |\n| oneflow.nn.functional.leaky_relu | [oneflow._C.leaky_relu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L373)   |  |  | done   |\n| oneflow.nn.functional.prelu | [oneflow._C.prelu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L20)   | [prelu_4dim_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_prelu.py#L32)   | [prelu_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L38)   |  |\n| oneflow.nn.functional.glu | [oneflow._C.glu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L436)   | [glu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_glu.py#L37)   | [glu_scalar_tensor_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L57)   | done   |\n| oneflow.nn.functional.gelu | [oneflow.gelu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L74)   | 
[gelu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L253)   |  | done   |\n| oneflow.nn.functional.logsigmoid | [oneflow._C.logsigmoid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L177)   | [logsigmoid_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L484)   |  | done   |\n| oneflow.nn.functional.hardshrink | [oneflow._C.hardshrink](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L507)   | [hardshrink_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L857)   |  | done   |\n| oneflow.nn.functional.softsign | [oneflow._C.softsign](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L207)   | [softsign_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L782)   |  | done   |\n| oneflow.nn.functional.softplus | [oneflow.softplus](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L146)   | [softplus](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_softplus.py#L43)   |  | done   |\n| oneflow.nn.functional.softmax | [oneflow._C.softmax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L118)   | [softmax_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L436)   | [softmax_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L109)   | done   |\n| oneflow.nn.functional.softshrink | [oneflow._C.softshrink](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L518)   | [softshrink_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L895)   |  | done   |\n| oneflow.nn.functional.log_softmax | [oneflow._C.log_softmax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L132)   |  |  | done   |\n| oneflow.nn.functional.tanh | [oneflow.tanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L163)   | 
[tanh_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L106)   |  | done   |\n| oneflow.nn.functional.sigmoid | [oneflow.sigmoid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L338)   | [sigmoid_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L281)   | [hard_sigmoid_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L87)   | done   |\n| oneflow.nn.functional.hardsigmoid | [oneflow._C.hardsigmoid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L298)   | [hardsigmoid_inplace](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L336)   |  | done   |\n| oneflow.nn.functional.silu | [oneflow.silu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L237)   | [silu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L726)   |  | done   |\n| oneflow.nn.functional.mish | [oneflow.mish](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L267)   | [mish_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L698)   |  | done   |\n| oneflow.nn.functional.layer_norm | [oneflow.nn.functional.layer_norm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/normalization.py#L20)   | [t5_layer_norm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_t5_layernorm.py#L55)   |  |  |\n| oneflow.nn.functional.normalize | [oneflow._C.normalize](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/norm.py#L268)   | [functional_normalize](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_normalize.py#L54)   |  |  |\n| oneflow.nn.functional.linear |  | [interpolate_linear_1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_interpolate.py#L27)   |  |  |\n| oneflow.nn.functional.dropout | [oneflow._C.dropout](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/dropout.py#L20)   | 
[dropout_p01](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_dropout.py#L44)   |  | done   |\n| oneflow.nn.functional.dropout1d | [oneflow._C.dropout1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/dropout.py#L102)   | [dropout1d_p0](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_dropout.py#L309)   |  |  |\n| oneflow.nn.functional.dropout2d | [oneflow._C.dropout2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/dropout.py#L124)   | [dropout2d_p0](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_dropout.py#L316)   |  |  |\n| oneflow.nn.functional.dropout3d | [oneflow._C.dropout3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/dropout.py#L146)   | [dropout3d_p0](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_dropout.py#L323)   |  |  |\n| oneflow.nn.functional.embedding |  | [one_embedding_adagrad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_one_embedding_adagrad.py#L174)   |  |  |\n| oneflow.nn.functional.one_hot | [oneflow._C.one_hot](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/onehot.py#L20)   | [one_hot](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_one_hot.py#L27)   |  |  |\n| oneflow.nn.functional.cosine_similarity | [oneflow._C.cosine_similarity](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/distance.py#L20)   |  | [cosine_similarity_not_floating_type](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_cosine_similarity.py#L24)   | done   |\n| oneflow.nn.functional.pairwise_distance | [oneflow._C.pairwise_distance](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/distance.py#L54)   | [pairwise_distance_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_pairwise_distance.py#L27)   |  |  |\n| oneflow.nn.functional.sparse_softmax_cross_entropy |  | [eager_global_sparse_softmax_cross_entropy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_sparse_softmax_cross_entropy.py#L131)   | 
[sparse_softmax_cross_entropy_prediction_numaxes_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_sparse_softmax_cross_entropy_op.py#L23)   |  |\n| oneflow.nn.functional.cross_entropy | [oneflow._C.cross_entropy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/loss.py#L82)   | [eager_global_sparse_softmax_cross_entropy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_sparse_softmax_cross_entropy.py#L131)   | [sparse_cross_entropy_prediction_numaxes_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_sparse_cross_entropy_op.py#L23)   |  |\n| oneflow.nn.functional.l1_loss | [oneflow._C.l1_loss](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/loss.py#L130)   | [l1_loss_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_loss.py#L277)   | [smooth_l1_loss_shape_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_smooth_l1_loss_op.py#L23)   |  |\n| oneflow.nn.functional.mse_loss | [oneflow._C.mse_loss](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/loss.py#L156)   | [mse_loss_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_loss.py#L328)   |  |  |\n| oneflow.nn.functional.smooth_l1_loss | [oneflow._C.smooth_l1_loss](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/loss.py#L186)   | [smooth_l1_loss_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_loss.py#L308)   | [smooth_l1_loss_shape_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_smooth_l1_loss_op.py#L23)   |  |\n| oneflow.nn.functional.triplet_margin_loss | [oneflow._C.triplet_margin_loss](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/loss.py#L20)   |  | [triplet_margin_loss_reduce_type_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L255)   |  |\n| oneflow.nn.functional.binary_cross_entropy |  | [nn_functional_binary_cross_entropy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_loss.py#L245)   |  |  |\n| oneflow.nn.functional.binary_cross_entropy_with_logits |  | 
[nn_functional_binary_cross_entropy_with_logits](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_loss.py#L269)   |  |  |\n| oneflow.nn.functional.pad | [oneflow._C.pad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/vision.py#L20)   | [pad_1d_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_pad.py#L25)   | [pad_size_attribute_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L89)   |  |\n| oneflow.nn.functional.interpolate |  | [interpolate_linear_1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_interpolate.py#L27)   |  |  |\n| oneflow.nn.functional.upsample |  | [upsample_bilinear_align_corners](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_upsample.py#L338)   |  |  |\n| oneflow.nn.functional.grid_sample |  | [grid_sample_4d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_grid_sample.py#L31)   |  | done   |\n| oneflow.nn.functional.affine_grid |  | [affine_grid_2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_affine_grid.py#L31)   |  | done   |\n| oneflow.nn.functional.ctc_greedy_decoder | [oneflow._C.ctc_greedy_decoder](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/ctc_decode.py#L20)   | [ctc_greedy_decoder](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_ctc_greedy_decoder.py#L111)   |  |  |\n| oneflow.Tensor.new_empty | [oneflow.Tensor.new_empty](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L201)   | [new_empty](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_empty.py#L52)   |  |  |\n| oneflow.Tensor.new_ones | [oneflow.Tensor.new_ones](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L229)   | [flow_new_ones_list_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_constant.py#L107)   |  |  |\n| oneflow.Tensor.new_zeros | [oneflow.Tensor.new_zeros](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L238)   | [new_zeros](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_constant.py#L129)   |  |  |\n| oneflow.Tensor.new_tensor |  | 
[new_tensor_local_mode_with_default_args](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_new_tensor.py#L25)   |  |  |\n| oneflow.Tensor.is_cuda | [oneflow.Tensor.is_cuda](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2071)   |  |  |  |\n| oneflow.Tensor.is_global | [oneflow.Tensor.is_global](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L109)   |  |  |  |\n| oneflow.Tensor.device | [oneflow.Tensor.device](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L85)   | [non_default_device](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_randperm.py#L133)   | [device_type](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_device.py#L25)   |  |\n| oneflow.Tensor.grad | [oneflow.Tensor.grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L753)   | [adagrad_clip_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_adagrad.py#L213)   | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24)   |  |\n| oneflow.Tensor.ndim | [oneflow.Tensor.ndim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1315)   | [abs_with_ndim_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_abs.py#L34)   |  |  |\n| oneflow.Tensor.abs | [oneflow.abs](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L20)   | [abs_with_0_size_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_abs.py#L27)   |  | done   |\n| oneflow.Tensor.acos | [oneflow.acos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L509)   | [acos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L122)   |  |  |\n| oneflow.Tensor.acosh | [oneflow.acosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L535)   | [acosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L138)   |  |  |\n| oneflow.Tensor.add | 
[oneflow.add](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L41)   | [scatter_add_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_scatter_ops.py#L57)   | [add_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_binary_functor_exception.py#L27)   | done   |\n| oneflow.Tensor.add_ | [oneflow.Tensor.add_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1222)   | [scatter_add_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_scatter_ops.py#L57)   | [add_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_binary_functor_exception.py#L27)   |  |\n| oneflow.Tensor.addcdiv | [oneflow.Tensor.addcdiv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L939)   | [addcdiv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_addcdiv.py#L25)   |  | done   |\n| oneflow.Tensor.addcdiv_ | [oneflow.Tensor.addcdiv_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L946)   | [tensor_addcdiv_inplace](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_addcdiv.py#L49)   |  |  |\n| oneflow.Tensor.addcmul | [oneflow.addcmul](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1558)   | [addcmul](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_addcmul.py#L37)   |  | done   |\n| oneflow.Tensor.addcmul_ | [oneflow.Tensor.addcmul_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1236)   | [tensor_addcmul_inplace](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_addcmul.py#L50)   |  |  |\n| oneflow.Tensor.addmm | [oneflow.Tensor.addmm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1215)   | [addmm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_addmm.py#L60)   |  | done   |\n| oneflow.Tensor.all | [oneflow.Tensor.all](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1822)   | 
[flow_var_all_dim_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_var.py#L27)   |  |  |\n| oneflow.Tensor.amin | [oneflow.Tensor.amin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2167)   | [amin_with_negative_dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_amin.py#L34)   |  | done   |\n| oneflow.Tensor.amax | [oneflow.Tensor.amax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L911)   | [amax_with_negative_dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_amax.py#L35)   |  | done   |\n| oneflow.Tensor.any | [oneflow.Tensor.any](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1831)   | [any_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_reduce.py#L52)   |  |  |\n| oneflow.Tensor.arccos | [oneflow.Tensor.arccos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L664)   | [arccos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L114)   |  |  |\n| oneflow.Tensor.arccosh | [oneflow.Tensor.arccosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L678)   | [arccosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L130)   |  |  |\n| oneflow.Tensor.arcsin | [oneflow.Tensor.arcsin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1257)   | [flow_arcsin_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L230)   |  |  |\n| oneflow.Tensor.arcsinh | [oneflow.Tensor.arcsinh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1264)   | [flow_arcsinh_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L247)   |  |  |\n| oneflow.Tensor.arctan | [oneflow.Tensor.arctan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1343)   | [flow_arctan_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L274)   |  |  |\n| 
oneflow.Tensor.arctanh | [oneflow.Tensor.arctanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L685)   | [flow_arctanh_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L296)   |  |  |\n| oneflow.Tensor.argmax | [oneflow.Tensor.argmax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L692)   | [argmax_axis_negative](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_argmax.py#L29)   | [argmax_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L22)   | done   |\n| oneflow.Tensor.argmin | [oneflow.Tensor.argmin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L699)   | [argmin_axis_negative](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_argmin.py#L29)   |  |  |\n| oneflow.Tensor.argsort | [oneflow.Tensor.argsort](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L706)   | [argsort](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_argsort.py#L37)   |  | done   |\n| oneflow.Tensor.argwhere | [oneflow.Tensor.argwhere](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L713)   | [argwhere_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_argwhere.py#L50)   |  |  |\n| oneflow.Tensor.asin | [oneflow.asin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L285)   | [flow_asin_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L223)   |  |  |\n| oneflow.Tensor.asinh | [oneflow.asinh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L318)   | [flow_asinh_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L240)   |  |  |\n| oneflow.Tensor.atan | [oneflow.atan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L353)   | [flow_atan_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L267)   |  |  |\n| oneflow.Tensor.atan2 | 
[oneflow.Tensor.atan2](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L123)   | [atan2](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L155)   |  |  |\n| oneflow.Tensor.atanh | [oneflow.atanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L564)   | [flow_atanh_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L289)   |  |  |\n| oneflow.Tensor.backward | [oneflow.Tensor.backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L727)   | [unsqueeze_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unsqueeze.py#L54)   | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24)   |  |\n| oneflow.Tensor.bmm | [oneflow.Tensor.bmm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L876)   | [bmm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_bmm.py#L93)   | [bmm_exception_dim_not_right](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_bmm.py#L25)   |  |\n| oneflow.Tensor.byte | [oneflow.Tensor.byte](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2159)   | [byte](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L1234)   |  |  |\n| oneflow.Tensor.cast | [oneflow.cast](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/cast.py#L20)   | [cast_float2int](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_cast.py#L28)   | [add_broad_cast_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_binary_functor_exception.py#L37)   |  |\n| oneflow.Tensor.ceil | [oneflow.ceil](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L378)   | [ceil_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_ceil.py#L29)   |  |  |\n| oneflow.Tensor.chunk | 
[oneflow.Tensor.chunk](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L883)   | [flow_chunk_list_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_chunk.py#L46)   | [chunk_0_dim_input_exception](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_chunk.py#L25)   |  |\n| oneflow.Tensor.clamp | [oneflow.clamp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/clamp.py#L20)   | [clamp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clamp.py#L96)   |  |  |\n| oneflow.Tensor.clamp_ | [oneflow.Tensor.clamp_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1548)   | [clamp_scalar_min](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clamp.py#L47)   |  |  |\n| oneflow.Tensor.clip | [oneflow.clip](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/clamp.py#L152)   | [adagrad_clip_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_adagrad.py#L213)   |  |  |\n| oneflow.Tensor.clip_ | [oneflow.Tensor.clip_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1562)   | [adagrad_clip_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_adagrad.py#L213)   |  |  |\n| oneflow.Tensor.clone |  | [clone_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_clone.py#L24)   |  |  |\n| oneflow.Tensor.contiguous |  | [tensor_scatter_nd_update_with_non_contiguous_input](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensor_scatter_nd_update.py#L40)   |  |  |\n| oneflow.Tensor.copy_ | [oneflow.Tensor.copy_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1468)   | [copy_broadcast_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_copy.py#L30)   |  |  |\n| oneflow.Tensor.cos | [oneflow.cos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L712)   | [global_cos_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_op_higher_derivative.py#L65)   |  |  |\n| 
oneflow.Tensor.cosh | [oneflow.cosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L736)   |  |  |  |\n| oneflow.Tensor.cpu | [oneflow.Tensor.cpu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1569)   | [from_torch_cpu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_from_torch.py#L26)   |  |  |\n| oneflow.Tensor.cuda | [oneflow.Tensor.cuda](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1587)   | [cuda](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensor_ops.py#L110)   |  |  |\n| oneflow.Tensor.cumprod | [oneflow.cumprod](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1788)   | [cumprod_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cumprod.py#L25)   |  | done   |\n| oneflow.Tensor.cumsum | [oneflow.cumsum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1755)   | [cumsum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cumsum.py#L37)   |  | done   |\n| oneflow.Tensor.data |  | [swapdims_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_swapdims.py#L32)   | [normal_data_type_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L278)   |  |\n| oneflow.Tensor.dot | [oneflow.dot](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1438)   | [dot](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L903)   | [dot_shape_error_msg](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_dot.py#L24)   | done   |\n| oneflow.Tensor.detach |  | [tensor_detach](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_2.py#L91)   |  |  |\n| oneflow.Tensor.placement | [oneflow.Tensor.placement](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L95)   | [eager_boxing_with_same_placement_p_to_s1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eager_boxing.py#L3093)   | 
[multi_input_with_diff_placement](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_multi_input_with_diff_device_or_placement.py#L42)   |  |\n| oneflow.Tensor.sbp | [oneflow.Tensor.sbp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L102)   | [eager_global_cast_with_same_placement_and_sbp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eager_boxing.py#L3205)   | [get_sbp_with_invalid_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L24)   |  |\n| oneflow.Tensor.diag | [oneflow.Tensor.diag](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L932)   | [diag_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_diag.py#L26)   |  | done   |\n| oneflow.Tensor.diagonal | [oneflow.Tensor.diagonal](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1294)   | [diagonal_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_diagonal.py#L24)   | [diagonal_index_error1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L204)   | done   |\n| oneflow.Tensor.dim | [oneflow.Tensor.dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L953)   | [cosine_similartiy_module_with_nonequal_dim_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_cosine_similarity.py#L53)   | [glu_dim_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L67)   |  |\n| oneflow.Tensor.div | [oneflow.div](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L143)   | [div_grad_grad_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_higher_derivative_div.py#L26)   | [div_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_binary_functor_exception.py#L81)   | done   |\n| oneflow.Tensor.div_ | [oneflow.Tensor.div_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1116)   | 
[div_grad_grad_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_higher_derivative_div.py#L26)   | [div_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_binary_functor_exception.py#L81)   |  |\n| oneflow.Tensor.double | [oneflow.Tensor.double](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2041)   | [double](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensor_ops.py#L211)   |  |  |\n| oneflow.Tensor.dtype |  | [out_grad_with_different_dtype](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd.py#L113)   | [sparse_cross_entropy_label_dtype_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_sparse_cross_entropy_op.py#L53)   |  |\n| oneflow.Tensor.element_size | [oneflow.Tensor.element_size](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L962)   |  |  |  |\n| oneflow.Tensor.eq | [oneflow.Tensor.eq](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1011)   | [eq_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_eq.py#L25)   |  | done   |\n| oneflow.Tensor.erf | [oneflow.erf](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L763)   | [flow_erf_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_erf.py#L33)   |  | done   |\n| oneflow.Tensor.erfc | [oneflow.erfc](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L810)   | [erfc_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_erfc.py#L25)   |  | done   |\n| oneflow.Tensor.erfinv | [oneflow.Tensor.erfinv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L997)   | [flow_erfinv_with_inf_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_erfinv.py#L30)   |  | done   |\n| oneflow.Tensor.erfinv_ | [oneflow.Tensor.erfinv_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1004)   | 
[flow_erfinv_with_inf_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_erfinv.py#L30)   |  |  |\n| oneflow.Tensor.exp | [oneflow.exp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L476)   | [exp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L72)   |  |  |\n| oneflow.Tensor.expand | [oneflow.Tensor.expand](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L130)   | [expand_new_dims_broadcast](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_expand_op.py#L28)   | [expand_dim_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L78)   |  |\n| oneflow.Tensor.expand_as | [oneflow.Tensor.expand_as](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L139)   |  |  |  |\n| oneflow.Tensor.expm1 | [oneflow.expm1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L845)   | [expm1_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_expm1.py#L29)   |  | done   |\n| oneflow.Tensor.fill_ | [oneflow.Tensor.fill_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1053)   | [masked_fill_with_0dim_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_masked_fill.py#L35)   |  | done   |\n| oneflow.Tensor.flatten | [oneflow.Tensor.flatten](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L155)   | [to_global_flatten_hierarchy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cast.py#L30)   |  | done   |\n| oneflow.Tensor.flip | [oneflow.Tensor.flip](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L169)   | [image_flip](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_image_flip.py#L70)   |  | done   |\n| oneflow.Tensor.float | [oneflow.Tensor.float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2020)   | [logical_xor_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_xor.py#L37)   |  |  |\n| 
oneflow.Tensor.floor | [oneflow.floor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L100)   | [floor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_floor.py#L35)   |  | done   |\n| oneflow.Tensor.floor_ | [oneflow.floor_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L135)   | [flow_floor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_floor.py#L57)   |  |  |\n| oneflow.Tensor.floor_divide |  |  |  |  |\n| oneflow.Tensor.fmod | [oneflow.fmod](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L890)   | [flow_fmod_element_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L1021)   |  | done   |\n| oneflow.Tensor.gather | [oneflow.Tensor.gather](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1531)   | [gather_nd](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_gather_nd.py#L85)   | [gather_index_type_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L120)   | done   |\n| oneflow.Tensor.ge | [oneflow.Tensor.ge](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1062)   |  |  |  |\n| oneflow.Tensor.get_device | [oneflow.Tensor.get_device](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1069)   |  |  |  |\n| oneflow.Tensor.grad_fn | [oneflow.Tensor.grad_fn](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L760)   | [parameter_grad_fn_none](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_parameter.py#L29)   |  |  |\n| oneflow.Tensor.gt | [oneflow.Tensor.gt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1080)   |  |  | done   |\n| oneflow.Tensor.half | [oneflow.Tensor.half](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1520)   | [module_to_half](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module_to_half.py#L25)   |  |  |\n| oneflow.Tensor.in_top_k | 
[oneflow.Tensor.in_top_k](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L176)   | [in_top_k_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_in_top_k.py#L82)   | [in_top_k_num_equal_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L389)   |  |\n| oneflow.Tensor.index_select | [oneflow.Tensor.index_select](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L185)   | [index_select_by_random](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_index_select.py#L30)   | [index_select_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L330)   |  |\n| oneflow.Tensor.int | [oneflow.Tensor.int](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1978)   | [logical_xor_int](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_xor.py#L27)   | [tensordot_too_large_int_dims_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_tensordot.py#L35)   |  |\n| oneflow.Tensor.is_contiguous | [oneflow.Tensor.is_contiguous](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2062)   |  |  |  |\n| oneflow.Tensor.is_floating_point | [oneflow.is_floating_point](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/is_floating_point.py#L20)   | [is_floating_point](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_ops.py#L176)   |  |  |\n| oneflow.Tensor.is_lazy | [oneflow.Tensor.is_lazy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L116)   |  |  |  |\n| oneflow.Tensor.is_leaf | [oneflow.Tensor.is_leaf](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L767)   |  |  |  |\n| oneflow.Tensor.isinf | [oneflow.Tensor.isinf](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2152)   | [isinf](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_util_ops.py#L33)   |  |  |\n| oneflow.Tensor.isnan | 
[oneflow.Tensor.isnan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2145)   | [isnan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_util_ops.py#L24)   |  |  |\n| oneflow.Tensor.item | [oneflow.Tensor.item](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2087)   | [tensordot_single_item_tensor_dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensordot.py#L105)   |  |  |\n| oneflow.Tensor.le | [oneflow.Tensor.le](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1025)   |  |  |  |\n| oneflow.Tensor.log | [oneflow.log](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L923)   | [log](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L56)   |  |  |\n| oneflow.Tensor.log1p | [oneflow.log1p](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L455)   | [log1p_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_log1p.py#L31)   |  |  |\n| oneflow.Tensor.log2 | [oneflow.log2](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L948)   | [log2_tensor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L971)   |  |  |\n| oneflow.Tensor.logical_and | [oneflow.Tensor.logical_and](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1677)   | [logical_and](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_and.py#L58)   |  |  |\n| oneflow.Tensor.logical_or | [oneflow.Tensor.logical_or](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1687)   | [logical_or](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_or.py#L58)   |  |  |\n| oneflow.Tensor.logical_not | [oneflow.Tensor.logical_not](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L520)   | [logical_not](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_not.py#L43)   |  |  |\n| oneflow.Tensor.logical_xor | 
[oneflow.Tensor.logical_xor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1698)   | [logical_xor_int](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_xor.py#L27)   |  |  |\n| oneflow.Tensor.long | [oneflow.Tensor.long](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1999)   | [long](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensor_ops.py#L145)   |  |  |\n| oneflow.Tensor.lt | [oneflow.Tensor.lt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1018)   |  |  |  |\n| oneflow.Tensor.masked_fill | [oneflow.Tensor.masked_fill](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1708)   | [masked_fill](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_masked_fill.py#L58)   |  |  |\n| oneflow.Tensor.masked_select | [oneflow.Tensor.masked_select](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1715)   | [masked_select](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_masked_select.py#L87)   |  |  |\n| oneflow.Tensor.matmul | [oneflow.matmul](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1249)   | [fused_matmul_op](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_cublas_fused_mlp.py#L173)   | [matmul_dimension_error1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L220)   |  |\n| oneflow.Tensor.mm | [oneflow.mm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1311)   | [flow_mm_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_matmul.py#L69)   | [mm_not_2dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_mm.py#L24)   |  |\n| oneflow.Tensor.mv | [oneflow.mv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1278)   | [flow_mv_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_matmul.py#L78)   | 
[mv_not_matrix](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_mv.py#L23)   | done   |\n| oneflow.Tensor.max | [oneflow.Tensor.max](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1774)   | [moving_average_min_max_observer](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_moving_average_max_min_observer.py#L83)   |  |  |\n| oneflow.Tensor.maximum | [oneflow.maximum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L997)   | [broadcast_maximum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_maximum_minimum.py#L32)   |  |  |\n| oneflow.Tensor.median | [oneflow.median](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1019)   | [median](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_median.py#L48)   | [median_exception_dim_out_of_range](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_median.py#L25)   |  |\n| oneflow.Tensor.mean | [oneflow.Tensor.mean](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1840)   | [mean](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_mean.py#L70)   | [normalization_moving_mean_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L317)   |  |\n| oneflow.Tensor.min | [oneflow.Tensor.min](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1783)   | [moving_average_min_max_observer](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_moving_average_max_min_observer.py#L83)   |  |  |\n| oneflow.Tensor.minimum | [oneflow.minimum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L975)   | [broadcast_minimum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_maximum_minimum.py#L50)   |  |  |\n| oneflow.Tensor.mish | [oneflow.mish](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L267)   | [mish_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L698)   |  | done   
|\n| oneflow.Tensor.mul | [oneflow.mul](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L186)   | [einsum_eltwise_mul_then_reduce_sum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_einsum_eltwise_mul_then_reduce_sum.py#L40)   |  |  |\n| oneflow.Tensor.mul_ | [oneflow.Tensor.mul_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1108)   | [fused_matmul_op](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_cublas_fused_mlp.py#L173)   | [matmul_dimension_error1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L220)   |  |\n| oneflow.Tensor.narrow | [oneflow.Tensor.narrow](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L629)   | [narrow](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_narrow.py#L35)   | [narrow_dim_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L178)   |  |\n| oneflow.Tensor.ndimension |  |  |  |  |\n| oneflow.Tensor.ne | [oneflow.Tensor.ne](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1032)   | [ne](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_ne.py#L31)   |  |  |\n| oneflow.Tensor.neg | [oneflow.Tensor.neg](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1039)   | [flow_split_sizes_neg_dim_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_split.py#L63)   | [tensordot_neg_dims_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_tensordot.py#L25)   |  |\n| oneflow.Tensor.negative | [oneflow.negative](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L428)   | [argmax_axis_negative](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_argmax.py#L29)   | [repeat_interleave_negative_tensor_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_repeat_interleave.py#L58)   |  |\n| oneflow.Tensor.nelement | 
[oneflow.Tensor.nelement](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1137)   | [tensor_nelement](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L552)   |  |  |\n| oneflow.Tensor.nonzero | [oneflow.nonzero](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/nonzero.py#L20)   | [nonzero](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_nonzero.py#L51)   |  |  |\n| oneflow.Tensor.norm | [oneflow.linalg.norm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/norm.py#L160)   | [clip_grad_norm_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clip_grad.py#L50)   |  |  |\n| oneflow.Tensor.normal_ | [oneflow.Tensor.normal_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1154)   | [eager_boxing_normal_1d_exhaustive_testing](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eager_boxing_exhaustive.py#L113)   | [normal_data_type_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L278)   |  |\n| oneflow.Tensor.numel | [oneflow.Tensor.numel](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L194)   | [tensor_numel](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L558)   |  |  |\n| oneflow.Tensor.numpy | [oneflow.Tensor.numpy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1163)   | [dropout_numpy_p0](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_dropout.py#L29)   | [numpy_type](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_pad.py#L32)   |  |\n| oneflow.Tensor.permute | [oneflow.Tensor.permute](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L643)   | [einsum_batch_permute](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_einsum_batch_permute.py#L42)   |  |  |\n| oneflow.Tensor.pow | [oneflow.pow](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1132)   | 
[pow_with_scalar](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L96)   |  |  |\n| oneflow.Tensor.prod | [oneflow.Tensor.prod](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1849)   | [prod_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_reduce.py#L59)   |  |  |\n| oneflow.Tensor.reciprocal | [oneflow.reciprocal](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L226)   | [flow_reciprocal_list_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_reciprocal.py#L32)   |  |  |\n| oneflow.Tensor.register_hook | [oneflow.Tensor.register_hook](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L833)   | [tensor_register_hook](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L446)   |  |  |\n| oneflow.Tensor.relu | [oneflow.relu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L50)   | [relu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L33)   | [relu_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L29)   | done   |\n| oneflow.Tensor.repeat | [oneflow.Tensor.repeat](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1622)   | [flow_tensor_repeat_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_repeat.py#L27)   | [repeat_interleave_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_repeat_interleave.py#L25)   |  |\n| oneflow.Tensor.repeat_interleave | [oneflow.Tensor.repeat_interleave](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1631)   | [flow_int_repeat_interleave_dim_none](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_repeat_interleave.py#L29)   | [repeat_interleave_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_repeat_interleave.py#L25)   |  |\n| oneflow.Tensor.requires_grad | 
[oneflow.Tensor.requires_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L800)   | [requires_grad_tensor_inplace_and_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd.py#L170)   | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24)   |  |\n| oneflow.Tensor.requires_grad_ | [oneflow.Tensor.requires_grad_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L809)   | [requires_grad_tensor_inplace_and_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd.py#L170)   | [non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24)   |  |\n| oneflow.Tensor.reshape | [oneflow.Tensor.reshape](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1858)   | [reshape_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_reshape.py#L27)   | [reshape_like_size_match_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_reshape_like_op.py#L24)   | done   |\n| oneflow.Tensor.reshape_as | [oneflow.Tensor.reshape_as](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1865)   | [reshape_as_tensor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L1181)   |  |  |\n| oneflow.Tensor.retain_grad | [oneflow.Tensor.retain_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L866)   | [retain_grad_for_leaf_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd.py#L178)   |  |  |\n| oneflow.Tensor.roll | [oneflow.Tensor.roll](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1187)   | [roll](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_roll.py#L27)   | [roll_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L112)   |  |\n| oneflow.Tensor.round | 
[oneflow.round](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1346)   | [flow_round_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_round.py#L30)   |  |  |\n| oneflow.Tensor.rsqrt | [oneflow.rsqrt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1173)   | [rsqrt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L80)   |  |  |\n| oneflow.Tensor.selu | [oneflow.selu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L409)   | [selu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L754)   |  | done   |\n| oneflow.Tensor.shape |  | [randn_tuple_shape](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_randn.py#L62)   | [layernorm_exception_input_shape_not_match](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_layernorm.py#L25)   |  |\n| oneflow.Tensor.sigmoid | [oneflow.sigmoid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L338)   | [sigmoid_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L281)   | [hard_sigmoid_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L87)   | done   |\n| oneflow.Tensor.sign | [oneflow.sign](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L589)   | [sign_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_sign.py#L25)   |  |  |\n| oneflow.Tensor.silu | [oneflow.silu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L237)   | [silu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L726)   |  | done   |\n| oneflow.Tensor.sin | [oneflow.sin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L618)   | [global_sin_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_op_higher_derivative.py#L59)   |  |  |\n| 
oneflow.Tensor.sin_ | [oneflow.sin_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L648)   | [global_sin_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_op_higher_derivative.py#L59)   |  |  |\n| oneflow.Tensor.sinh | [oneflow.sinh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L656)   | [sinh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L23)   |  |  |\n| oneflow.Tensor.size | [oneflow.Tensor.size](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1392)   | [unsqueeze_with_0_size_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_unsqueeze.py#L62)   | [local_to_global_with_invalid_size](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L75)   |  |\n| oneflow.Tensor.softmax | [oneflow._C.softmax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L118)   | [softmax_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L436)   | [softmax_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L109)   | done   |\n| oneflow.Tensor.softplus | [oneflow.softplus](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L146)   | [softplus](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_softplus.py#L43)   |  | done   |\n| oneflow.Tensor.softsign | [oneflow._C.softsign](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L207)   | [softsign_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L782)   |  | done   |\n| oneflow.Tensor.sort | [oneflow.Tensor.sort](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1947)   | [sort](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_sort.py#L69)   |  |  |\n| oneflow.Tensor.split | [oneflow.Tensor.split](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L890)   | 
[eager_boxing_2d_special_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eager_boxing_exhaustive.py#L146)   | [local_to_global_with_invalid_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L39)   |  |\n| oneflow.Tensor.sqrt | [oneflow.sqrt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1198)   | [sqrt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L64)   |  |  |\n| oneflow.Tensor.square | [oneflow.square](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1224)   | [inv_random_square_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_inv.py#L39)   | [inv_exception_not_square_matrix](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_inv.py#L34)   |  |\n| oneflow.Tensor.squeeze | [oneflow.Tensor.squeeze](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L556)   | [squeeze](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_squeeze.py#L94)   | [squeeze_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L106)   |  |\n| oneflow.Tensor.std | [oneflow.std](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1371)   | [std_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_std.py#L26)   |  |  |\n| oneflow.Tensor.storage_offset | [oneflow.Tensor.storage_offset](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L268)   |  |  |  |\n| oneflow.Tensor.stride |  | [flow_as_strided_with_stride](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_as_stride.py#L49)   |  |  |\n| oneflow.Tensor.sum | [oneflow.Tensor.sum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1813)   | [sum_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_sum.py#L29)   | [reduce_sum_like_empty_axis_case_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_reduce_like_ops.py#L24)   |  |\n| 
oneflow.Tensor.swapaxes | [oneflow.Tensor.swapaxes](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L904)   | [swapaxes_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_swapaxes.py#L31)   |  |  |\n| oneflow.Tensor.swapdims | [oneflow.Tensor.swapdims](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L918)   | [swapdims_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_swapdims.py#L32)   |  |  |\n| oneflow.Tensor.sub | [oneflow.sub](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L246)   | [global_sub](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_sub.py#L50)   |  |  |\n| oneflow.Tensor.sub_ | [oneflow.Tensor.sub_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1123)   | [global_sub_with_0_size_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_sub.py#L56)   |  |  |\n| oneflow.Tensor.tan | [oneflow.tan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L687)   | [flow_tan_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L257)   |  |  |\n| oneflow.Tensor.tanh | [oneflow.tanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L163)   | [tanh_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L106)   |  | done   |\n| oneflow.Tensor.tile | [oneflow.tile](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tile.py#L20)   | [flow_tile_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tile.py#L27)   | [tile_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L431)   |  |\n| oneflow.Tensor.to | [oneflow.Tensor.to](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1485)   | [dummy_module_to](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module_to.py#L58)   | 
[local_to_global_with_invalid_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L39)   |  |\n| oneflow.Tensor.local_to_global | [oneflow.Tensor.local_to_global](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L286)   | [local_to_global_2d_sbp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cast.py#L85)   | [local_to_global_with_invalid_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L39)   |  |\n| oneflow.Tensor.global_to_global | [oneflow.Tensor.global_to_global](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L334)   | [cuda_global_to_global_cpu_s2b](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cast.py#L210)   | [global_to_global_with_invalid_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L51)   |  |\n| oneflow.Tensor.to_global | [oneflow.Tensor.to_global](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L381)   | [to_global_flatten_hierarchy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cast.py#L30)   | [local_to_global_with_invalid_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L39)   |  |\n| oneflow.Tensor.to_local | [oneflow.Tensor.to_local](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L473)   |  | [call_to_local_for_local_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L65)   |  |\n| oneflow.Tensor.to_consistent | [oneflow.Tensor.to_consistent](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L466)   |  |  |  |\n| oneflow.Tensor.tolist | [oneflow.Tensor.tolist](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2108)   | [tolist](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensor_ops.py#L257)   |  |  |\n| oneflow.Tensor.topk | [oneflow.Tensor.topk](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1751)   | 
[flow_topk_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L306)   |  |  |\n| oneflow.Tensor.transpose | [oneflow.Tensor.transpose](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L513)   | [einsum_matrix_transpose](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_einsum_matrix_transpose.py#L35)   |  |  |\n| oneflow.Tensor.tril | [oneflow.Tensor.tril](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1441)   | [fused_scale_tril](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_fused_scale_tril.py#L78)   |  |  |\n| oneflow.Tensor.triu | [oneflow.Tensor.triu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1448)   | [triu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_triu.py#L47)   |  |  |\n| oneflow.Tensor.type_as | [oneflow.Tensor.type_as](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1954)   | [type_as](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_ops.py#L165)   |  |  |\n| oneflow.Tensor.type | [oneflow.Tensor.type](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2192)   | [type_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_type_tensor.py#L74)   | [cosine_similarity_not_floating_type](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_cosine_similarity.py#L24)   |  |\n| oneflow.Tensor.t | [oneflow.Tensor.t](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1640)   | [global_tensor_scatter_nd_update_t](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_scatter_nd_update.py#L140)   | [t_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L439)   |  |\n| oneflow.Tensor.T | [oneflow.Tensor.t](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1640)   | [global_tensor_scatter_nd_update_t](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_scatter_nd_update.py#L140) 
  | [t_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L439)   |  |\n| oneflow.Tensor.unbind | [oneflow.Tensor.unbind](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L897)   | [unbind_flow_with_random_data1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unbind.py#L32)   | [unbind_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L248)   |  |\n| oneflow.Tensor.unfold | [oneflow.Tensor.unfold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L563)   | [unfold_tensor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unfold_tensor.py#L30)   |  |  |\n| oneflow.Tensor.uniform_ | [oneflow.Tensor.uniform_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1455)   |  |  |  |\n| oneflow.Tensor.unsqueeze | [oneflow.Tensor.unsqueeze](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L636)   | [unsqueeze](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unsqueeze.py#L68)   |  |  |\n| oneflow.Tensor.var | [oneflow.var](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1407)   | [module_to_with_var_reuse](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module_to.py#L93)   |  |  |\n| oneflow.Tensor.view | [oneflow.Tensor.view](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1881)   | [view](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_view.py#L79)   | [view_exception](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_view.py#L25)   |  |\n| oneflow.Tensor.view_as | [oneflow.Tensor.view_as](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1931)   |  |  |  |\n| oneflow.Tensor.where | [oneflow.Tensor.where](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2129)   | [where](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_where.py#L196)   |  |  |\n| oneflow.Tensor.zero_ | 
[oneflow.Tensor.zero_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2136)   | [nonzero_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_nonzero.py#L64)   |  |  |\n| oneflow.Tensor.nms | [oneflow.Tensor.nms](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1758)   | [nms](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_nms.py#L50)   |  |  |\n| oneflow.Tensor.pin_memory | [oneflow.Tensor.pin_memory](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2174)   | [tensor_pin_memory](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_pin_memory.py#L33)   |  |  |\n| oneflow.Tensor.is_pinned | [oneflow.Tensor.is_pinned](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2183)   | [tensor_is_pinned](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_pin_memory.py#L76)   |  |  |\n| oneflow.nn.Parameter |  | [ddp_with_partial_requires_grad_parameter](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_ddp.py#L225)   | [direction_parameter_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_arg_sort_op.py#L23)   |  |\n| oneflow.nn.Module | [oneflow.nn.Module.to_consistent](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/module.py#L20)   | [dummy_module](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module_to.py#L45)   |  |  |\n| oneflow.nn.Sequential |  |  |  |  |\n| oneflow.nn.ModuleList |  |  |  |  |\n| oneflow.nn.ModuleDict |  | [moduledict](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L353)   |  |  |\n| oneflow.nn.ParameterList |  |  |  |  |\n| oneflow.nn.ParameterDict |  |  |  |  |\n| oneflow.nn.Module.add_module |  |  |  |  |\n| oneflow.nn.Module.apply |  | [module_apply](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L161)   |  |  |\n| oneflow.nn.Module.buffers |  |  |  |  |\n| oneflow.nn.Module.children |  |  |  |  |\n| oneflow.nn.Module.cpu | [oneflow.Tensor.cpu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1569)   | 
[from_torch_cpu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_from_torch.py#L26)   |  |  |\n| oneflow.nn.Module.cuda | [oneflow.Tensor.cuda](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1587)   | [cuda](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensor_ops.py#L110)   |  |  |\n| oneflow.nn.Module.double | [oneflow.Tensor.double](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2041)   | [double](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensor_ops.py#L211)   |  |  |\n| oneflow.nn.Module.train |  | [train_eval](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L121)   |  |  |\n| oneflow.nn.Module.eval |  | [dropout_eval_p01](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_dropout.py#L33)   | [normalization_eval_need_moving_statistic_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L347)   |  |\n| oneflow.nn.Module.extra_repr |  |  |  |  |\n| oneflow.nn.Module.float | [oneflow.Tensor.float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2020)   | [logical_xor_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_xor.py#L37)   |  |  |\n| oneflow.nn.Module.forward |  | [eye_forward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eye.py#L27)   |  |  |\n| oneflow.nn.Module.load_state_dict |  | [load_state_dict](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L63)   |  |  |\n| oneflow.nn.Module.modules |  |  |  |  |\n| oneflow.nn.Module.named_buffers |  |  |  |  |\n| oneflow.nn.Module.named_children |  |  |  |  |\n| oneflow.nn.Module.named_modules |  |  |  |  |\n| oneflow.nn.Module.named_parameters |  |  |  |  |\n| oneflow.nn.Module.parameters |  |  |  |  |\n| oneflow.nn.Module.register_buffer |  |  |  |  |\n| oneflow.nn.Module.register_forward_hook |  |  |  |  |\n| oneflow.nn.Module.register_forward_pre_hook |  |  |  |  |\n| oneflow.nn.Module.register_parameter |  |  |  |  |\n| oneflow.nn.Module.requires_grad_ | [oneflow.Tensor.requires_grad_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L809)   | [requires_grad_tensor_inplace_and_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd.py#L170)   | 
[non_requires_grad_tensor_backward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_autograd.py#L24)   |  |\n| oneflow.nn.Module.state_dict |  | [load_state_dict](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L63)   |  |  |\n| oneflow.nn.Module.to | [oneflow.Tensor.to](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1485)   | [dummy_module_to](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module_to.py#L58)   | [local_to_global_with_invalid_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L39)   |  |\n| oneflow.nn.Module.zero_grad |  |  |  |  |\n| oneflow.nn.Conv1d | [oneflow._C.conv1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/conv.py#L20)   | [conv1d_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_higher_derivative_conv.py#L128)   |  |  |\n| oneflow.nn.Conv2d | [oneflow._C.conv2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/conv.py#L57)   | [conv2d_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_higher_derivative_conv.py#L134)   |  |  |\n| oneflow.nn.Conv3d | [oneflow._C.conv3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/conv.py#L95)   | [conv3d_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_higher_derivative_conv.py#L140)   |  |  |\n| oneflow.nn.ConvTranspose1d |  |  |  |  |\n| oneflow.nn.ConvTranspose2d |  |  |  |  |\n| oneflow.nn.ConvTranspose3d |  |  |  |  |\n| oneflow.nn.Unfold | [oneflow.Tensor.unfold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L563)   | [unfold_tensor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unfold_tensor.py#L30)   |  |  |\n| oneflow.nn.Fold | [oneflow.nn.functional.fold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/convolution.py#L20)   | [fold_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_fold.py#L25)   |  |  |\n| oneflow.nn.MaxPool1d |  | 
[maxpool1d_functional](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_maxpool.py#L28)   |  |  |\n| oneflow.nn.MaxPool2d |  | [maxpool2d_functional](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_maxpool.py#L51)   |  |  |\n| oneflow.nn.MaxPool3d |  | [maxpool3d_functional](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_maxpool.py#L75)   |  |  |\n| oneflow.nn.AdaptiveAvgPool1d |  |  |  |  |\n| oneflow.nn.AdaptiveAvgPool2d |  |  |  |  |\n| oneflow.nn.AdaptiveAvgPool3d |  |  |  |  |\n| oneflow.nn.AvgPool1d |  | [adaptive_avgpool1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_adaptive_pool.py#L39)   |  |  |\n| oneflow.nn.AvgPool2d |  | [adaptive_avgpool2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_adaptive_pool.py#L53)   |  |  |\n| oneflow.nn.AvgPool3d |  | [adaptive_avgpool3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_adaptive_pool.py#L72)   |  |  |\n| oneflow.nn.ConstantPad1d |  | [constantpad1d_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_constant_pad.py#L32)   |  |  |\n| oneflow.nn.ConstantPad2d |  | [ConstantPad2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_zeropad2d.py#L96)   |  |  |\n| oneflow.nn.ConstantPad3d |  | [constantpad3d_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_constant_pad.py#L64)   |  |  |\n| oneflow.nn.ReflectionPad1d |  |  |  |  |\n| oneflow.nn.ReflectionPad2d |  |  |  |  |\n| oneflow.nn.ReplicationPad1d |  |  |  |  |\n| oneflow.nn.ReplicationPad2d |  | [ReplicationPad2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_replication_pad.py#L104)   |  |  |\n| oneflow.nn.ZeroPad2d |  | [global_ZeroPad2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_zeropad2d.py#L37)   |  |  |\n| oneflow.nn.ELU | [oneflow._C.elu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L385)   | [elu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L165)   |  |  |\n| oneflow.nn.Hardshrink | [oneflow._C.hardshrink](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L507)   | 
[hardshrink_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L857)   |  |  |\n| oneflow.nn.Hardsigmoid | [oneflow._C.hardsigmoid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L298)   | [hardsigmoid_inplace](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L336)   |  |  |\n| oneflow.nn.Hardswish | [oneflow._C.hardswish](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L316)   | [hardswish_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L590)   |  |  |\n| oneflow.nn.Hardtanh | [oneflow._C.hardtanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L363)   | [hardtanh_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L618)   |  |  |\n| oneflow.nn.LeakyReLU |  | [leakyrelu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L656)   |  |  |\n| oneflow.nn.LogSigmoid | [oneflow._C.logsigmoid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L177)   | [logsigmoid_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L484)   |  |  |\n| oneflow.nn.PReLU | [oneflow._C.prelu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L20)   | [prelu_4dim_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_prelu.py#L32)   | [prelu_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L38)   |  |\n| oneflow.nn.ReLU | [oneflow.relu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L50)   | [relu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L33)   | [relu_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L29)   |  |\n| oneflow.nn.ReLU6 |  | 
[relu6_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L73)   |  |  |\n| oneflow.nn.SELU | [oneflow.selu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L409)   | [selu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L754)   |  |  |\n| oneflow.nn.CELU | [oneflow._C.celu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L468)   | [celu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L203)   | [celu_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L47)   |  |\n| oneflow.nn.GELU | [oneflow.gelu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L74)   | [gelu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L253)   |  |  |\n| oneflow.nn.SiLU | [oneflow.silu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L237)   | [silu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L726)   |  |  |\n| oneflow.nn.Sigmoid | [oneflow.sigmoid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L338)   | [sigmoid_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L281)   | [hard_sigmoid_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L87)   |  |\n| oneflow.nn.Mish | [oneflow.mish](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L267)   | [mish_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L698)   |  |  |\n| oneflow.nn.Softplus | [oneflow.softplus](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L146)   | [softplus](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_softplus.py#L43)   |  |  |\n| oneflow.nn.Softshrink | 
[oneflow._C.softshrink](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L518)   | [softshrink_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L895)   |  |  |\n| oneflow.nn.Softsign | [oneflow._C.softsign](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L207)   | [softsign_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L782)   |  |  |\n| oneflow.nn.Tanh | [oneflow.tanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L163)   | [tanh_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L106)   |  |  |\n| oneflow.nn.Threshold | [oneflow._C.threshold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L496)   | [softplus_threshold](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L533)   |  |  |\n| oneflow.nn.GLU | [oneflow._C.glu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L436)   | [glu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_glu.py#L37)   | [glu_scalar_tensor_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L57)   |  |\n| oneflow.nn.Softmax | [oneflow._C.softmax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L118)   | [softmax_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L436)   | [softmax_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L109)   |  |\n| oneflow.nn.LogSoftmax |  | [logsoftmax_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L460)   |  |  |\n| oneflow.nn.BatchNorm1d |  | [batchnorm1d_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_batchnorm.py#L34)   |  |  |\n| oneflow.nn.BatchNorm2d |  | 
[batchnorm2d_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_batchnorm.py#L52)   |  |  |\n| oneflow.nn.BatchNorm3d |  | [batchnorm3d_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_batchnorm.py#L70)   |  |  |\n| oneflow.nn.FusedBatchNorm1d |  |  |  |  |\n| oneflow.nn.FusedBatchNorm2d |  |  |  |  |\n| oneflow.nn.FusedBatchNorm3d |  |  |  |  |\n| oneflow.nn.GroupNorm |  | [groupnorm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_groupnorm.py#L332)   |  |  |\n| oneflow.nn.InstanceNorm1d |  | [instancenorm1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_instancenorm.py#L29)   |  |  |\n| oneflow.nn.InstanceNorm2d |  | [instancenorm2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_instancenorm.py#L71)   |  |  |\n| oneflow.nn.InstanceNorm3d |  | [instancenorm3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_instancenorm.py#L141)   |  |  |\n| oneflow.nn.LayerNorm |  | [t5_layernorm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_t5_layernorm.py#L83)   | [layernorm_exception_input_shape_not_match](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_layernorm.py#L25)   |  |\n| oneflow.nn.RMSLayerNorm |  |  |  |  |\n| oneflow.nn.RNN |  | [rnn_relu_cell](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_rnn_cell.py#L206)   |  |  |\n| oneflow.nn.LSTM |  | [lstm_cell](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_rnn_cell.py#L200)   |  |  |\n| oneflow.nn.GRU |  | [gru_cell](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_rnn_cell.py#L218)   |  |  |\n| oneflow.nn.RNNCell |  |  |  |  |\n| oneflow.nn.LSTMCell |  |  |  |  |\n| oneflow.nn.GRUCell |  |  |  |  |\n| oneflow.nn.Identity |  | [identity](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_linear.py#L113)   |  |  |\n| oneflow.nn.Linear |  | [interpolate_linear_1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_interpolate.py#L27)   |  |  |\n| oneflow.nn.Dropout | [oneflow._C.dropout](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/dropout.py#L20)   | 
[dropout_p01](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_dropout.py#L44)   |  |  |\n| oneflow.nn.Dropout1d | [oneflow._C.dropout1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/dropout.py#L102)   | [dropout1d_p0](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_dropout.py#L309)   |  |  |\n| oneflow.nn.Dropout2d | [oneflow._C.dropout2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/dropout.py#L124)   | [dropout2d_p0](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_dropout.py#L316)   |  |  |\n| oneflow.nn.Dropout3d | [oneflow._C.dropout3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/dropout.py#L146)   | [dropout3d_p0](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_dropout.py#L323)   |  |  |\n| oneflow.nn.Embedding |  | [one_embedding_adagrad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_one_embedding_adagrad.py#L174)   |  |  |\n| oneflow.nn.CosineSimilarity |  |  |  |  |\n| oneflow.nn.PairwiseDistance |  |  |  |  |\n| oneflow.nn.BCELoss |  |  |  |  |\n| oneflow.nn.BCEWithLogitsLoss |  |  |  |  |\n| oneflow.nn.CTCLoss |  |  | [ctcloss_reduction_type_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L62)   |  |\n| oneflow.nn.CombinedMarginLoss |  |  |  |  |\n| oneflow.nn.CrossEntropyLoss |  |  |  |  |\n| oneflow.nn.KLDivLoss |  |  |  |  |\n| oneflow.nn.L1Loss |  |  |  |  |\n| oneflow.nn.MSELoss |  |  |  |  |\n| oneflow.nn.MarginRankingLoss |  |  |  |  |\n| oneflow.nn.NLLLoss |  |  |  |  |\n| oneflow.nn.SmoothL1Loss |  |  |  |  |\n| oneflow.nn.TripletMarginLoss |  |  |  |  |\n| oneflow.nn.PixelShuffle |  |  |  |  |\n| oneflow.nn.Upsample |  | [upsample_bilinear_align_corners](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_upsample.py#L338)   |  |  |\n| oneflow.nn.UpsamplingBilinear2d |  | [UpsamplingBilinear2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_upsample.py#L97)   |  |  |\n| oneflow.nn.UpsamplingNearest2d |  | [UpsamplingNearest2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_upsample.py#L74)   |  |  |\n| oneflow.nn.parallel.DistributedDataParallel |  |  |  |  |\n| oneflow.nn.COCOReader |  |  |  |  |\n| oneflow.nn.CoinFlip |  |  |  |  |\n| oneflow.nn.CropMirrorNormalize |  |  |  |  |\n| oneflow.nn.OFRecordBytesDecoder |  | 
[OFRecordBytesDecoder](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_dataset.py#L351)   |  |  |\n| oneflow.nn.OFRecordImageDecoder |  |  |  |  |\n| oneflow.nn.OFRecordImageDecoderRandomCrop |  |  |  |  |\n| oneflow.nn.OFRecordRawDecoder |  |  |  |  |\n| oneflow.nn.OFRecordReader |  |  |  |  |\n| oneflow.nn.MinMaxObserver |  |  |  |  |\n| oneflow.nn.MovingAverageMinMaxObserver |  |  |  |  |\n| oneflow.nn.FakeQuantization |  |  |  |  |\n| oneflow.nn.QatConv1d |  |  |  |  |\n| oneflow.nn.QatConv2d |  |  |  |  |\n| oneflow.nn.QatConv3d |  |  |  |  |\n| oneflow.nn.utils.clip_grad_norm_ |  | [clip_grad_norm_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clip_grad.py#L50)   |  |  |\n| oneflow.nn.utils.clip_grad_value_ |  | [clip_grad_value_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clip_grad.py#L79)   |  |  |\n| oneflow.nn.utils.weight_norm |  | [weight_norm_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_weight_norm.py#L150)   |  |  |\n| oneflow.nn.utils.remove_weight_norm |  |  |  |  |\n| oneflow.nn.utils.rnn.PackedSequence |  |  |  |  |\n| oneflow.nn.utils.rnn.pack_padded_sequence |  |  |  |  |\n| oneflow.nn.utils.rnn.pad_packed_sequence |  |  |  |  |\n| oneflow.nn.utils.rnn.pad_sequence |  |  |  |  |\n| oneflow.nn.utils.rnn.pack_sequence |  |  |  |  |\n| oneflow.nn.Flatten | [oneflow.Tensor.flatten](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L155)   | [to_global_flatten_hierarchy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cast.py#L30)   |  |  |\n| oneflow.nn.Quantization |  |  |  |  |\n| oneflow.BoolTensor |  |  |  |  |\n| oneflow.ByteTensor |  |  |  |  |\n| oneflow.CharTensor |  |  |  |  |\n| oneflow.DoubleTensor |  |  |  |  |\n| oneflow.FloatTensor |  |  |  |  |\n| oneflow.HalfTensor |  |  |  |  |\n| oneflow.IntTensor |  |  |  |  |\n| oneflow.LongTensor |  |  |  |  |\n| oneflow.is_tensor |  | [ellipsis_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_indexing2.py#L900)   | [rol_align_rois_tensor_dimension_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_roi_align_op.py#L34)   |  |\n| oneflow.is_floating_point | [oneflow.is_floating_point](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/is_floating_point.py#L20)   | [is_floating_point](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_ops.py#L176)   |  |  |\n| oneflow.is_nonzero |  |  |  |  |\n| oneflow.numel | 
[oneflow.Tensor.numel](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L194)   | [tensor_numel](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L558)   |  |  |\n| oneflow.set_printoptions |  |  |  |  |\n| oneflow.tensor | [oneflow.tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L20)   | [type_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_type_tensor.py#L74)   | [call_to_local_for_local_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L65)   |  |\n| oneflow.as_tensor | [oneflow.as_tensor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/as_tensor.py#L20)   | [reshape_as_tensor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L1181)   |  |  |\n| oneflow.as_strided | [oneflow.as_strided](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1529)   | [flow_as_strided_with_stride](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_as_stride.py#L49)   |  |  |\n| oneflow.from_numpy | [oneflow.from_numpy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L55)   | [copy_to_and_from_numpy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L73)   |  |  |\n| oneflow.zeros |  | [zeros_like_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_zeros_like.py#L27)   |  |  |\n| oneflow.zeros_like | [oneflow.zeros_like](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/constant.py#L53)   | [zeros_like_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_zeros_like.py#L27)   |  |  |\n| oneflow.ones |  | [ones_like_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_ones_like.py#L27)   |  |  |\n| oneflow.ones_like | [oneflow.ones_like](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/constant.py#L20)   | 
[ones_like_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_ones_like.py#L27)   |  |  |\n| oneflow.randint_like | [oneflow._C.randint_like](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/random.py#L242)   | [consistent_randint_like](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_randint_like.py#L27)   |  |  |\n| oneflow.masked_fill | [oneflow.Tensor.masked_fill](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1708)   | [masked_fill](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_masked_fill.py#L58)   |  |  |\n| oneflow.new_ones | [oneflow.Tensor.new_ones](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L229)   | [flow_new_ones_list_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_constant.py#L107)   |  |  |\n| oneflow.arange | [oneflow.arange](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/arange.py#L20)   | [arange](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_arange.py#L63)   |  | done   |\n| oneflow.linspace |  | [global_linspace](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_linspace.py#L26)   |  |  |\n| oneflow.eye | [oneflow.eye](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1597)   | [eye_forward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eye.py#L27)   |  | done   |\n| oneflow.empty | [oneflow.empty](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/constant.py#L119)   | [slice_empty](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_slice.py#L51)   | [reduce_sum_like_empty_axis_case_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_reduce_like_ops.py#L24)   |  |\n| oneflow.empty_like | [oneflow.empty_like](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/constant.py#L160)   |  |  |  |\n| oneflow.full |  | [global_full](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_full.py#L27)   |  |  |\n| oneflow.full_like |  | 
[full_like_with_random_data_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_constant.py#L154)   |  |  |\n| oneflow.tensor_scatter_nd_update |  | [global_tensor_scatter_nd_update](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_scatter_nd_update.py#L128)   | [tensor_scatter_nd_update_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L156)   |  |\n| oneflow.logspace |  | [logspace_int_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logspace.py#L26)   |  |  |\n| oneflow.argwhere | [oneflow.Tensor.argwhere](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L713)   | [argwhere_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_argwhere.py#L50)   |  |  |\n| oneflow.atleast_1d | [oneflow.atleast_1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L272)   | [atleast_1d_with_list_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_atleast.py#L28)   |  |  |\n| oneflow.atleast_2d | [oneflow.atleast_2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L306)   | [atleast_2d_with_list_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_atleast.py#L43)   |  |  |\n| oneflow.atleast_3d | [oneflow.atleast_3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L344)   | [atleast_3d_with_list_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_atleast.py#L59)   |  |  |\n| oneflow.cat | [oneflow.cat](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L613)   | [cat_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_concat.py#L138)   |  |  |\n| oneflow.column_stack | [oneflow.column_stack](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L513)   | [column_stack_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_stack.py#L48)   |  |  |\n| oneflow.concat |  | 
[concat_with_input_0_size_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_concat.py#L164)   | [concat_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L37)   |  |\n| oneflow.chunk | [oneflow.Tensor.chunk](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L883)   | [flow_chunk_list_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_chunk.py#L46)   | [chunk_0_dim_input_exception](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_chunk.py#L25)   |  |\n| oneflow.dstack | [oneflow.dstack](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L481)   | [dstack_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_stack.py#L115)   |  |  |\n| oneflow.expand | [oneflow.Tensor.expand](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L130)   | [expand_new_dims_broadcast](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_expand_op.py#L28)   | [expand_dim_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L78)   |  |\n| oneflow.gather | [oneflow.Tensor.gather](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1531)   | [gather_nd](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_gather_nd.py#L85)   | [gather_index_type_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L120)   | done   |\n| oneflow.gather_nd | [oneflow.gather_nd](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L685)   | [gather_nd](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_gather_nd.py#L85)   |  |  |\n| oneflow.batch_gather | [oneflow.batch_gather](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L199)   | [batch_gather](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_batch_gather.py#L74)   |  |  |\n| oneflow.hsplit | 
[oneflow.hsplit](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1674)   | [flow_hsplit_vec](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_hsplit.py#L27)   |  |  |\n| oneflow.hstack | [oneflow.hstack](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L413)   | [hstack_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_stack.py#L80)   |  |  |\n| oneflow.vsplit | [oneflow.vsplit](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1714)   | [flow_vsplit_vec](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_vsplit.py#L27)   |  |  |\n| oneflow.vstack | [oneflow.vstack](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L447)   | [vstack_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_stack.py#L98)   |  |  |\n| oneflow.index_select | [oneflow.Tensor.index_select](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L185)   | [index_select_by_random](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_index_select.py#L30)   | [index_select_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L330)   |  |\n| oneflow.masked_select | [oneflow.Tensor.masked_select](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1715)   | [masked_select](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_masked_select.py#L87)   |  |  |\n| oneflow.movedim | [oneflow.movedim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1496)   | [flow_movedim_with_vector](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_movedim.py#L27)   |  |  |\n| oneflow.narrow | [oneflow.Tensor.narrow](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L629)   | [narrow](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_narrow.py#L35)   | 
[narrow_dim_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L178)   |  |\n| oneflow.nonzero | [oneflow.nonzero](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/nonzero.py#L20)   | [nonzero](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_nonzero.py#L51)   |  |  |\n| oneflow.permute | [oneflow.Tensor.permute](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L643)   | [einsum_batch_permute](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_einsum_batch_permute.py#L42)   |  |  |\n| oneflow.repeat | [oneflow.Tensor.repeat](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1622)   | [flow_tensor_repeat_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_repeat.py#L27)   | [repeat_interleave_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_repeat_interleave.py#L25)   |  |\n| oneflow.reshape | [oneflow.Tensor.reshape](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1858)   | [reshape_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_reshape.py#L27)   | [reshape_like_size_match_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_reshape_like_op.py#L24)   | done   |\n| oneflow.row_stack | [oneflow.row_stack](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L547)   | [row_stack_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_stack.py#L64)   |  |  |\n| oneflow.select | [oneflow.select](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1467)   | [flow_select](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_select.py#L28)   | [index_select_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L330)   |  |\n| oneflow.scatter |  | [global_tensor_scatter_nd_update](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_scatter_nd_update.py#L128)   | 
[tensor_scatter_nd_update_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L156)   |  |\n| oneflow.scatter_add |  | [scatter_add_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_scatter_ops.py#L57)   |  |  |\n| oneflow.scatter_nd |  | [global_tensor_scatter_nd_update](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_scatter_nd_update.py#L128)   | [tensor_scatter_nd_update_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L156)   |  |\n| oneflow.slice |  | [slice_grad_grad_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_higher_derivative_slice.py#L38)   | [slice_update_start_list_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_slice_op.py#L23)   |  |\n| oneflow.slice_update |  | [slice_update](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_slice_update.py#L120)   | [slice_update_start_list_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_slice_op.py#L23)   |  |\n| oneflow.split | [oneflow.Tensor.split](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L890)   | [eager_boxing_2d_special_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eager_boxing_exhaustive.py#L146)   | [local_to_global_with_invalid_split_axis](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_local_global_convert_error.py#L39)   |  |\n| oneflow.squeeze | [oneflow.Tensor.squeeze](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L556)   | [squeeze](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_squeeze.py#L94)   | [squeeze_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L106)   |  |\n| oneflow.stack | [oneflow.stack](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/array_ops.py#L382)   | [stack_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_stack.py#L28)   | 
[stack_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L62)   |  |\n| oneflow.swapaxes | [oneflow.Tensor.swapaxes](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L904)   | [swapaxes_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_swapaxes.py#L31)   |  |  |\n| oneflow.swapdims | [oneflow.Tensor.swapdims](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L918)   | [swapdims_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_swapdims.py#L32)   |  |  |\n| oneflow.t | [oneflow.Tensor.t](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1640)   | [global_tensor_scatter_nd_update_t](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_scatter_nd_update.py#L140)   | [t_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L439)   |  |\n| oneflow.tile | [oneflow.tile](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tile.py#L20)   | [flow_tile_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tile.py#L27)   | [tile_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L431)   |  |\n| oneflow.transpose | [oneflow.Tensor.transpose](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L513)   | [einsum_matrix_transpose](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_einsum_matrix_transpose.py#L35)   |  |  |\n| oneflow.unbind | [oneflow.Tensor.unbind](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L897)   | [unbind_flow_with_random_data1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unbind.py#L32)   | [unbind_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L248)   |  |\n| oneflow.unsqueeze | [oneflow.Tensor.unsqueeze](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L636)   | 
[unsqueeze](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_unsqueeze.py#L68)   |  |  |\n| oneflow.where | [oneflow.Tensor.where](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2129)   | [where](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_where.py#L196)   |  |  |\n| oneflow.tensor_split | [oneflow.tensor_split](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1634)   | [flow_tensor_split_vec](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensor_split.py#L27)   |  |  |\n| oneflow.seed |  | [generator_manual_seed](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_generator.py#L72)   |  |  |\n| oneflow.manual_seed |  | [generator_manual_seed](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_generator.py#L72)   |  |  |\n| oneflow.initial_seed |  |  |  |  |\n| oneflow.get_rng_state |  | [get_rng_state](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_generator.py#L137)   |  |  |\n| oneflow.set_rng_state |  | [set_rng_state](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_generator.py#L148)   |  |  |\n| oneflow.bernoulli | [oneflow.bernoulli](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/random.py#L20)   | [bernoulli](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_bernoulli.py#L56)   |  |  |\n| oneflow.normal | [oneflow._C.normal](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/random.py#L154)   | [eager_boxing_normal_1d_exhaustive_testing](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eager_boxing_exhaustive.py#L113)   | [normal_data_type_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L278)   |  |\n| oneflow.rand | [oneflow._C.rand](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/random.py#L112)   | [0d_rand](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_rand.py#L45)   |  |  |\n| oneflow.randint | [oneflow._C.randint](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/random.py#L191)   | 
[global_randint](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_randint.py#L27)   |  |  |\n| oneflow.randn | [oneflow._C.randn](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/random.py#L71)   | [randn](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_randn.py#L103)   |  |  |\n| oneflow.randperm | [oneflow._C.randperm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/random.py#L291)   | [global_randperm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_randperm.py#L26)   | [randperm_n_value_err_mes](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_randperm_op.py#L24)   |  |\n| oneflow.save |  | [save_state_dict](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L222)   |  |  |\n| oneflow.load |  | [resnet18_load_weight_compatibile](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_resnet_load_torch_weight_compatibile.py#L30)   |  |  |\n| oneflow.set_num_threads | [oneflow.set_num_threads](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/oneflow.py#L20)   |  |  |  |\n| oneflow.no_grad |  | [no_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd_mode.py#L62)   |  |  |\n| oneflow.set_grad_enabled |  | [set_grad_enabled](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd_mode.py#L74)   |  |  |\n| oneflow.enable_grad |  | [enable_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd_mode.py#L50)   |  |  |\n| oneflow.is_grad_enabled |  |  |  |  |\n| oneflow.inference_mode |  | [inference_mode](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_autograd_mode.py#L27)   |  |  |\n| oneflow.abs | [oneflow.abs](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L20)   | [abs_with_0_size_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_abs.py#L27)   |  | done   |\n| oneflow.acos | [oneflow.acos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L509)   | 
[acos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L122)   |  |  |\n| oneflow.acosh | [oneflow.acosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L535)   | [acosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L138)   |  |  |\n| oneflow.arccos | [oneflow.Tensor.arccos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L664)   | [arccos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L114)   |  |  |\n| oneflow.arccosh | [oneflow.Tensor.arccosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L678)   | [arccosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L130)   |  |  |\n| oneflow.add | [oneflow.add](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L41)   | [scatter_add_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_scatter_ops.py#L57)   | [add_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_binary_functor_exception.py#L27)   | done   |\n| oneflow.addcdiv | [oneflow.Tensor.addcdiv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L939)   | [addcdiv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_addcdiv.py#L25)   |  | done   |\n| oneflow.addcmul | [oneflow.addcmul](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1558)   | [addcmul](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_addcmul.py#L37)   |  | done   |\n| oneflow.asin | [oneflow.asin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L285)   | [flow_asin_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L223)   |  |  |\n| oneflow.asinh | [oneflow.asinh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L318)   | 
[flow_asinh_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L240)   |  |  |\n| oneflow.arcsin | [oneflow.Tensor.arcsin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1257)   | [flow_arcsin_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L230)   |  |  |\n| oneflow.arcsinh | [oneflow.Tensor.arcsinh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1264)   | [flow_arcsinh_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L247)   |  |  |\n| oneflow.atan | [oneflow.atan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L353)   | [flow_atan_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L267)   |  |  |\n| oneflow.atanh | [oneflow.atanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L564)   | [flow_atanh_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L289)   |  |  |\n| oneflow.arctan | [oneflow.Tensor.arctan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1343)   | [flow_arctan_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L274)   |  |  |\n| oneflow.arctanh | [oneflow.Tensor.arctanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L685)   | [flow_arctanh_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L296)   |  |  |\n| oneflow.atan2 | [oneflow.Tensor.atan2](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L123)   | [atan2](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L155)   |  |  |\n| oneflow.ceil | [oneflow.ceil](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L378)   | [ceil_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_ceil.py#L29)   |  |  |\n| oneflow.clamp | 
[oneflow.clamp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/clamp.py#L20)   | [clamp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clamp.py#L96)   |  |  |\n| oneflow.clamp_min | [oneflow.clamp_min](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/clamp.py#L70)   | [clamp_min_none_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clamp.py#L119)   |  |  |\n| oneflow.clamp_max | [oneflow.clamp_max](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/clamp.py#L111)   | [clamp_max_none_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clamp.py#L126)   |  |  |\n| oneflow.clip | [oneflow.clip](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/clamp.py#L152)   | [adagrad_clip_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_adagrad.py#L213)   |  |  |\n| oneflow.cos | [oneflow.cos](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L712)   | [global_cos_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_op_higher_derivative.py#L65)   |  |  |\n| oneflow.cosh | [oneflow.cosh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L736)   |  |  |  |\n| oneflow.div | [oneflow.div](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L143)   | [div_grad_grad_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_higher_derivative_div.py#L26)   | [div_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_binary_functor_exception.py#L81)   | done   |\n| oneflow.erf | [oneflow.erf](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L763)   | [flow_erf_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_erf.py#L33)   |  | done   |\n| oneflow.erfc | [oneflow.erfc](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L810)   | 
[erfc_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_erfc.py#L25)   |  | done   |\n| oneflow.erfinv | [oneflow.Tensor.erfinv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L997)   | [flow_erfinv_with_inf_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_erfinv.py#L30)   |  | done   |\n| oneflow.exp | [oneflow.exp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L476)   | [exp](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L72)   |  |  |\n| oneflow.expm1 | [oneflow.expm1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L845)   | [expm1_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_expm1.py#L29)   |  | done   |\n| oneflow.floor | [oneflow.floor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L100)   | [floor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_floor.py#L35)   |  | done   |\n| oneflow.floor_ | [oneflow.floor_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L135)   | [flow_floor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_floor.py#L57)   |  |  |\n| oneflow.fmod | [oneflow.fmod](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L890)   | [flow_fmod_element_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L1021)   |  | done   |\n| oneflow.gelu | [oneflow.gelu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L74)   | [gelu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L253)   |  | done   |\n| oneflow.log | [oneflow.log](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L923)   | [log](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L56)   |  |  |\n| oneflow.log1p | 
[oneflow.log1p](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L455)   | [log1p_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_log1p.py#L31)   |  |  |\n| oneflow.log2 | [oneflow.log2](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L948)   | [log2_tensor_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L971)   |  |  |\n| oneflow.logical_and | [oneflow.Tensor.logical_and](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1677)   | [logical_and](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_and.py#L58)   |  |  |\n| oneflow.logical_not | [oneflow.Tensor.logical_not](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L520)   | [logical_not](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_not.py#L43)   |  |  |\n| oneflow.logical_or | [oneflow.Tensor.logical_or](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1687)   | [logical_or](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_or.py#L58)   |  |  |\n| oneflow.logical_xor | [oneflow.Tensor.logical_xor](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1698)   | [logical_xor_int](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_xor.py#L27)   |  |  |\n| oneflow.mish | [oneflow.mish](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L267)   | [mish_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L698)   |  | done   |\n| oneflow.mul | [oneflow.mul](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L186)   | [einsum_eltwise_mul_then_reduce_sum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_einsum_eltwise_mul_then_reduce_sum.py#L40)   |  |  |\n| oneflow.neg | [oneflow.Tensor.neg](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1039)   | 
[flow_split_sizes_neg_dim_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_split.py#L63)   | [tensordot_neg_dims_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_tensordot.py#L25)   |  |\n| oneflow.negative | [oneflow.negative](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L428)   | [argmax_axis_negative](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_argmax.py#L29)   | [repeat_interleave_negative_tensor_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_repeat_interleave.py#L58)   |  |\n| oneflow.pow | [oneflow.pow](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1132)   | [pow_with_scalar](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L96)   |  |  |\n| oneflow.reciprocal | [oneflow.reciprocal](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L226)   | [flow_reciprocal_list_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_reciprocal.py#L32)   |  |  |\n| oneflow.round | [oneflow.round](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1346)   | [flow_round_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_round.py#L30)   |  |  |\n| oneflow.rsqrt | [oneflow.rsqrt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1173)   | [rsqrt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L80)   |  |  |\n| oneflow.selu | [oneflow.selu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L409)   | [selu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L754)   |  | done   |\n| oneflow.softmax | [oneflow._C.softmax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L118)   | [softmax_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L436)   | 
[softmax_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L109)   | done   |\n| oneflow.softplus | [oneflow.softplus](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L146)   | [softplus](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_softplus.py#L43)   |  | done   |\n| oneflow.softsign | [oneflow._C.softsign](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L207)   | [softsign_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L782)   |  | done   |\n| oneflow.silu | [oneflow.silu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L237)   | [silu_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L726)   |  | done   |\n| oneflow.sigmoid | [oneflow.sigmoid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L338)   | [sigmoid_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L281)   | [hard_sigmoid_inplace_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_activation.py#L87)   | done   |\n| oneflow.sign | [oneflow.sign](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L589)   | [sign_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_sign.py#L25)   |  |  |\n| oneflow.sin | [oneflow.sin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L618)   | [global_sin_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_op_higher_derivative.py#L59)   |  |  |\n| oneflow.sinh | [oneflow.sinh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L656)   | [sinh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L23)   |  |  |\n| oneflow.sin_ | [oneflow.sin_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L648)   | 
[global_sin_grad_grad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_op_higher_derivative.py#L59)   |  |  |\n| oneflow.sqrt | [oneflow.sqrt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1198)   | [sqrt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_math_ops.py#L64)   |  |  |\n| oneflow.square | [oneflow.square](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1224)   | [inv_random_square_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_inv.py#L39)   | [inv_exception_not_square_matrix](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_inv.py#L34)   |  |\n| oneflow.sub | [oneflow.sub](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L246)   | [global_sub](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_sub.py#L50)   |  |  |\n| oneflow.tan | [oneflow.tan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L687)   | [flow_tan_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L257)   |  |  |\n| oneflow.tanh | [oneflow.tanh](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/activation.py#L163)   | [tanh_module_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L106)   |  | done   |\n| oneflow.floor_divide |  |  |  |  |\n| oneflow.argmax | [oneflow.Tensor.argmax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L692)   | [argmax_axis_negative](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_argmax.py#L29)   | [argmax_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L22)   | done   |\n| oneflow.argmin | [oneflow.Tensor.argmin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L699)   | [argmin_axis_negative](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_argmin.py#L29)   |  |  |\n| oneflow.amax | 
[oneflow.Tensor.amax](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L911)   | [amax_with_negative_dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_amax.py#L35)   |  | done   |\n| oneflow.amin | [oneflow.Tensor.amin](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2167)   | [amin_with_negative_dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_amin.py#L34)   |  | done   |\n| oneflow.any | [oneflow.Tensor.any](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1831)   | [any_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_reduce.py#L52)   |  |  |\n| oneflow.max | [oneflow.Tensor.max](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1774)   | [moving_average_min_max_observer](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_moving_average_max_min_observer.py#L83)   |  |  |\n| oneflow.min | [oneflow.Tensor.min](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1783)   | [moving_average_min_max_observer](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_moving_average_max_min_observer.py#L83)   |  |  |\n| oneflow.mean | [oneflow.Tensor.mean](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1840)   | [mean](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_mean.py#L70)   | [normalization_moving_mean_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L317)   |  |\n| oneflow.median | [oneflow.median](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1019)   | [median](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_median.py#L48)   | [median_exception_dim_out_of_range](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_median.py#L25)   |  |\n| oneflow.prod | [oneflow.Tensor.prod](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1849)   | 
[prod_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_logical_reduce.py#L59)   |  |  |\n| oneflow.std | [oneflow.std](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1371)   | [std_flow_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_std.py#L26)   |  |  |\n| oneflow.sum | [oneflow.Tensor.sum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1813)   | [sum_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_sum.py#L29)   | [reduce_sum_like_empty_axis_case_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_reduce_like_ops.py#L24)   |  |\n| oneflow.var | [oneflow.var](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1407)   | [module_to_with_var_reuse](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module_to.py#L93)   |  |  |\n| oneflow.norm | [oneflow.linalg.norm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/norm.py#L160)   | [clip_grad_norm_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clip_grad.py#L50)   |  |  |\n| oneflow.all | [oneflow.Tensor.all](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1822)   | [flow_var_all_dim_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_var.py#L27)   |  |  |\n| oneflow.argsort | [oneflow.Tensor.argsort](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L706)   | [argsort](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_argsort.py#L37)   |  | done   |\n| oneflow.eq | [oneflow.Tensor.eq](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1011)   | [eq_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_eq.py#L25)   |  | done   |\n| oneflow.equal |  | [softmax_module_with_batch_size_equal_1024](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_activation.py#L464)   | 
[concat_dim_equal_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L44)   |  |\n| oneflow.gt | [oneflow.Tensor.gt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1080)   |  |  | done   |\n| oneflow.isinf | [oneflow.Tensor.isinf](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2152)   | [isinf](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_util_ops.py#L33)   |  |  |\n| oneflow.isnan | [oneflow.Tensor.isnan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L2145)   | [isnan](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_util_ops.py#L24)   |  |  |\n| oneflow.le | [oneflow.Tensor.le](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1025)   |  |  |  |\n| oneflow.lt | [oneflow.Tensor.lt](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1018)   |  |  |  |\n| oneflow.ne | [oneflow.Tensor.ne](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1032)   | [ne](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_ne.py#L31)   |  |  |\n| oneflow.sort | [oneflow.Tensor.sort](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1947)   | [sort](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_sort.py#L69)   |  |  |\n| oneflow.topk | [oneflow.Tensor.topk](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1751)   | [flow_topk_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_math_ops.py#L306)   |  |  |\n| oneflow.ge | [oneflow.Tensor.ge](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1062)   |  |  |  |\n| oneflow.greater | [oneflow.greater](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/comparison.py#L21)   | [greater_normal](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_greater.py#L29)   |  |  |\n| oneflow.greater_equal | 
[oneflow.greater_equal](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/comparison.py#L49)   | [greater_equal_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_greater_equal.py#L25)   |  |  |\n| oneflow.maximum | [oneflow.maximum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L997)   | [broadcast_maximum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_maximum_minimum.py#L32)   |  |  |\n| oneflow.minimum | [oneflow.minimum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L975)   | [broadcast_minimum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_maximum_minimum.py#L50)   |  |  |\n| oneflow.not_equal |  |  |  |  |\n| oneflow.hann_window | [oneflow.hann_window](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/hann_window.py#L20)   | [global_hann_window](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_hann_window.py#L26)   | [hann_window_dtype_not_support](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_hann_window.py#L25)   | done   |\n| oneflow.adaptive_avg_pool1d | [oneflow._C.adaptive_avg_pool1d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L20)   |  |  | done   |\n| oneflow.adaptive_avg_pool2d | [oneflow._C.adaptive_avg_pool2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L48)   |  |  | done   |\n| oneflow.adaptive_avg_pool3d | [oneflow._C.adaptive_avg_pool3d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/pooling.py#L74)   |  |  | done   |\n| oneflow.broadcast_like | [oneflow.broadcast_like](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/broadcast_like.py#L20)   | [broadcast_like](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_broadcast_like.py#L161)   | [broadcast_like_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L28)   |  |\n| oneflow.cast | [oneflow.cast](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/cast.py#L20)   | 
[cast_float2int](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_cast.py#L28)   | [add_broad_cast_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_binary_functor_exception.py#L37)   |  |\n| oneflow.cumprod | [oneflow.cumprod](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1788)   | [cumprod_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cumprod.py#L25)   |  | done   |\n| oneflow.cumsum | [oneflow.cumsum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1755)   | [cumsum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cumsum.py#L37)   |  | done   |\n| oneflow.diag | [oneflow.Tensor.diag](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L932)   | [diag_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_diag.py#L26)   |  | done   |\n| oneflow.diagonal | [oneflow.Tensor.diagonal](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1294)   | [diagonal_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_diagonal.py#L24)   | [diagonal_index_error1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L204)   | done   |\n| oneflow.einsum | [oneflow.einsum](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/einsum.py#L20)   | [einsum_alphaflod_usecase11](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_einsum_alphaflod_usecase11.py#L38)   |  |  |\n| oneflow.flatten | [oneflow.Tensor.flatten](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L155)   | [to_global_flatten_hierarchy](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_cast.py#L30)   |  | done   |\n| oneflow.flip | [oneflow.Tensor.flip](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L169)   | [image_flip](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_image_flip.py#L70)   |  | done   |\n| oneflow.in_top_k | 
[oneflow.Tensor.in_top_k](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L176)   | [in_top_k_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_in_top_k.py#L82)   | [in_top_k_num_equal_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L389)   |  |\n| oneflow.meshgrid | [oneflow.meshgrid](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/meshgrid.py#L20)   | [meshgrid_forawd](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_meshgrid.py#L29)   | [meshgrid_tensors_scalar_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L276)   |  |\n| oneflow.nms | [oneflow.Tensor.nms](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1758)   | [nms](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_nms.py#L50)   |  |  |\n| oneflow.roc_auc_score | [oneflow.roc_auc_score](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/roc_auc_score.py#L20)   | [roc_auc_score](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_roc_auc_score.py#L52)   |  |  |\n| oneflow.roll | [oneflow.Tensor.roll](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1187)   | [roll](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_roll.py#L27)   | [roll_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L112)   |  |\n| oneflow.searchsorted | [oneflow.searchsorted](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/searchsorted.py#L20)   |  |  |  |\n| oneflow.tensordot | [oneflow.tensordot](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensordot.py#L20)   | [tensordot_intdim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_tensordot.py#L28)   | [tensordot_neg_dims_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_tensordot.py#L25)   |  |\n| oneflow.tril | 
[oneflow.Tensor.tril](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1441)   | [fused_scale_tril](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_fused_scale_tril.py#L78)   |  |  |\n| oneflow.repeat_interleave | [oneflow.Tensor.repeat_interleave](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1631)   | [flow_int_repeat_interleave_dim_none](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_repeat_interleave.py#L29)   | [repeat_interleave_index_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_repeat_interleave.py#L25)   |  |\n| oneflow.triu | [oneflow.Tensor.triu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1448)   | [triu](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_triu.py#L47)   |  |  |\n| oneflow.addmm | [oneflow.Tensor.addmm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1215)   | [addmm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_addmm.py#L60)   |  | done   |\n| oneflow.bmm | [oneflow.Tensor.bmm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L876)   | [bmm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_bmm.py#L93)   | [bmm_exception_dim_not_right](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_bmm.py#L25)   |  |\n| oneflow.dot | [oneflow.dot](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1438)   | [dot](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/tensor/test_tensor_part_1.py#L903)   | [dot_shape_error_msg](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_dot.py#L24)   | done   |\n| oneflow.matmul | [oneflow.matmul](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1249)   | [fused_matmul_op](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_cublas_fused_mlp.py#L173)   | 
[matmul_dimension_error1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L220)   |  |\n| oneflow.mm | [oneflow.mm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1311)   | [flow_mm_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_matmul.py#L69)   | [mm_not_2dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_mm.py#L24)   |  |\n| oneflow.mv | [oneflow.mv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/math_ops.py#L1278)   | [flow_mv_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_matmul.py#L78)   | [mv_not_matrix](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_mv.py#L23)   | done   |\n| oneflow.env.all_device_placement |  |  |  |  |\n| oneflow.env.get_world_size |  |  |  |  |\n| oneflow.env.get_rank |  |  |  |  |\n| oneflow.env.get_local_rank |  |  |  |  |\n| oneflow.env.get_node_size |  |  |  |  |\n| oneflow.env.init_rdma |  |  |  |  |\n| oneflow.env.rdma_is_initialized |  |  |  |  |\n| oneflow.comm.all_reduce |  | [all_reduce_1n2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_comm_ops.py#L31)   |  |  |\n| oneflow.comm.all_gather |  | [all_gather_1n2d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_comm_ops.py#L48)   |  |  |\n| oneflow.comm.all_to_all |  | [all_to_all_1n4d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_comm_ops.py#L148)   |  |  |\n| oneflow.comm.broadcast |  | [cosine_similartiy_broadcast_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_cosine_similarity.py#L45)   | [cosine_similarity_broadcast](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_cosine_similarity.py#L34)   |  |\n| oneflow.comm.barrier |  |  |  |  |\n| oneflow.comm.gather | [oneflow.Tensor.gather](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1531)   | [gather_nd](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_gather_nd.py#L85)   | [gather_index_type_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L120)   | done   |\n| oneflow.comm.reduce |  | 
[min_reduce_random_dim](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_min.py#L28)   | [reduce_sum_like_empty_axis_case_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_reduce_like_ops.py#L24)   |  |\n| oneflow.comm.reduce_scatter |  | [reduce_scatter_1n4d](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_comm_ops.py#L167)   |  |  |\n| oneflow.comm.recv | [oneflow.comm.recv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/comm.py#L32)   | [send_recv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_comm.py#L28)   |  |  |\n| oneflow.comm.scatter |  | [global_tensor_scatter_nd_update](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_tensor_scatter_nd_update.py#L128)   | [tensor_scatter_nd_update_runtime_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L156)   |  |\n| oneflow.comm.send | [oneflow.comm.send](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/comm.py#L20)   | [send_recv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_comm.py#L28)   |  |  |\n| oneflow.linalg.norm | [oneflow.linalg.norm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/norm.py#L160)   | [clip_grad_norm_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_clip_grad.py#L50)   |  |  |\n| oneflow.linalg.vector_norm | [oneflow.linalg.vector_norm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/norm.py#L21)   | [vector_norm_only_zero_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_norm.py#L318)   |  |  |\n| oneflow.linalg.matrix_norm | [oneflow.linalg.matrix_norm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/norm.py#L88)   |  |  |  |\n| oneflow.linalg.diagonal | [oneflow.Tensor.diagonal](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1294)   | [diagonal_impl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_diagonal.py#L24)   | 
[diagonal_index_error1](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_array_functor.py#L204)   | done   |\n| oneflow.linalg.inv | [oneflow.linalg.inv](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/inv.py#L21)   | [inv_3by3_with_random_data](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_inv.py#L27)   | [inv_exception_dim_short](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_inv.py#L25)   | done   |\n| oneflow.optim.Optimizer.add_param_group |  | [sgd_add_param_group](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_add_param_group.py#L44)   | [sgd_add_param_group_not_unique](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_optim_add_param_group.py#L23)   |  |\n| oneflow.optim.Optimizer.load_state_dict |  | [load_state_dict](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L63)   |  |  |\n| oneflow.optim.Optimizer.state_dict |  | [load_state_dict](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_module.py#L63)   |  |  |\n| oneflow.optim.Optimizer.step |  | [arange_step_prarm](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_arange.py#L35)   | [slice_update_step_list_err](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_slice_op.py#L49)   |  |\n| oneflow.optim.Optimizer.zero_grad |  |  |  |  |\n| oneflow.optim.Adagrad |  | [one_embedding_adagrad](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_one_embedding_adagrad.py#L174)   |  |  |\n| oneflow.optim.Adam |  | [multi_tensor_adam_update](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_multi_tensor_adam_update.py#L157)   |  |  |\n| oneflow.optim.AdamW |  | [adamw](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_adamw.py#L244)   |  |  |\n| oneflow.optim.LAMB |  | [lamb](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_lamb.py#L157)   |  |  |\n| oneflow.optim.RMSprop |  | [rmsprop](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_optim_rmsprop.py#L228)   |  |  |\n| oneflow.optim.SGD |  | 
[one_embedding_sgd](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_one_embedding_sgd.py#L190)   | [sgd_add_param_group_not_unique](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_optim_add_param_group.py#L23)   |  |\n| oneflow.optim.lr_scheduler.CosineAnnealingLR |  |  |  |  |\n| oneflow.optim.lr_scheduler.CosineDecayLR |  |  |  |  |\n| oneflow.optim.lr_scheduler.ExponentialLR |  |  |  |  |\n| oneflow.optim.lr_scheduler.LambdaLR |  |  |  |  |\n| oneflow.optim.lr_scheduler.MultiStepLR |  |  |  |  |\n| oneflow.optim.lr_scheduler.PolynomialLR |  |  |  |  |\n| oneflow.optim.lr_scheduler.ReduceLROnPlateau |  |  |  |  |\n| oneflow.optim.lr_scheduler.StepLR |  |  |  |  |\n| oneflow.optim.lr_scheduler.ConstantLR |  |  |  |  |\n| oneflow.optim.lr_scheduler.LinearLR |  |  |  |  |\n| oneflow.optim.lr_scheduler.ChainedScheduler |  |  |  |  |\n| oneflow.optim.lr_scheduler.SequentialLR |  |  |  |  |\n| oneflow.optim.lr_scheduler.CosineAnnealingWarmRestarts |  |  |  |  |\n| oneflow.one_embedding.make_table_options |  |  |  |  |\n| oneflow.one_embedding.make_table |  |  |  |  |\n| oneflow.one_embedding.make_uniform_initializer |  |  |  |  |\n| oneflow.one_embedding.make_normal_initializer |  |  |  |  |\n| oneflow.one_embedding.make_device_mem_store_options |  |  |  |  |\n| oneflow.one_embedding.make_cached_ssd_store_options |  |  |  |  |\n| oneflow.one_embedding.make_cached_host_mem_store_options |  |  |  |  |\n| oneflow.one_embedding.MultiTableEmbedding |  |  |  |  |\n| oneflow.one_embedding.MultiTableEmbedding.forward |  | [eye_forward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eye.py#L27)   |  |  |\n| oneflow.one_embedding.MultiTableEmbedding.save_snapshot |  |  |  |  |\n| oneflow.one_embedding.MultiTableEmbedding.load_snapshot |  |  |  |  |\n| oneflow.one_embedding.MultiTableMultiColumnEmbedding |  |  |  |  |\n| oneflow.one_embedding.MultiTableMultiColumnEmbedding.forward |  | [eye_forward](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eye.py#L27)   |  |  |\n| oneflow.one_embedding.MultiTableMultiColumnEmbedding.save_snapshot |  |  |  |  |\n| oneflow.one_embedding.MultiTableMultiColumnEmbedding.load_snapshot |  |  |  |  |\n| oneflow.one_embedding.make_persistent_table_reader |  |  |  |  |\n| oneflow.one_embedding.make_persistent_table_writer |  |  |  |  |\n| oneflow.one_embedding.Ftrl |  | [ftrl](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_one_embedding_ftrl.py#L191)   |  |  |\n| oneflow.nn.init.calculate_gain |  |  |  |  |\n| oneflow.nn.init.uniform_ | [oneflow.Tensor.uniform_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1455)   |  |  |  |\n| oneflow.nn.init.normal_ | [oneflow.Tensor.normal_](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L1154)   | 
[eager_boxing_normal_1d_exhaustive_testing](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_eager_boxing_exhaustive.py#L113)   | [normal_data_type_error](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/exceptions/test_nn_functor.py#L278)   |  |\n| oneflow.nn.init.constant_ |  | [constant_global](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_constant.py#L99)   |  |  |\n| oneflow.nn.init.ones_ |  | [ones_like_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_ones_like.py#L27)   |  |  |\n| oneflow.nn.init.zeros_ |  | [zeros_like_float](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_global_zeros_like.py#L27)   |  |  |\n| oneflow.nn.init.xavier_uniform_ |  |  |  |  |\n| oneflow.nn.init.xavier_normal_ |  |  |  |  |\n| oneflow.nn.init.kaiming_uniform_ |  |  |  |  |\n| oneflow.nn.init.kaiming_normal_ |  |  |  |  |\n| oneflow.nn.init.trunc_normal_ |  |  |  |  |\n| oneflow.nn.init.orthogonal_ |  |  |  |  |\n| oneflow.nn.image.Resize |  | [image_resize_to_fixed_size](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_image_resize.py#L192)   |  |  |\n| oneflow.nn.image.batch_align |  | [image_batch_align](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_image_batch_align.py#L52)   |  |  |\n| oneflow.nn.image.decode |  | [read_decode](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_one_rec_ops.py#L78)   |  |  |\n| oneflow.nn.image.flip | [oneflow.Tensor.flip](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/tensor.py#L169)   | [image_flip](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_image_flip.py#L70)   |  | done   |\n| oneflow.nn.image.normalize | [oneflow._C.normalize](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/framework/docstr/norm.py#L268)   | [functional_normalize](https://github.com/Oneflow-Inc/oneflow/blob/5d4308ecd3c72dafe29634f964a103694e4dea5b/python/oneflow/test/../../../python/oneflow/test/modules/test_normalize.py#L54)   |  |  |\n| oneflow.utils.data.random_split |  |  |  |  |\n\n## Test Data Summary\n- OneFlow Total API Number: 771\n- Doc Test Ratio: 63.81% (492 / 771)\n- Compatible/Completeness Test Ratio: 73.80% (569 / 771)\n- Exception Test Ratio: 19.71% (152 / 771)\n- Performance Test Ratio: 15.56% (120 / 771)\n"
  },
  {
    "path": "python/oneflow/test/dataloader/data_utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport oneflow as flow\nimport flowvision as vision\nimport flowvision.transforms as transforms\n\n\ndef load_data_cifar10(\n    batch_size,\n    data_dir=\"./data-test/cifar10\",\n    download=True,\n    transform=None,\n    source_url=None,\n    num_workers=0,\n):\n    cifar10_train = vision.datasets.CIFAR10(\n        root=data_dir,\n        train=True,\n        download=download,\n        transform=transform,\n        source_url=source_url,\n    )\n    cifar10_test = vision.datasets.CIFAR10(\n        root=data_dir,\n        train=False,\n        download=download,\n        transform=transform,\n        source_url=source_url,\n    )\n\n    train_iter = flow.utils.data.DataLoader(\n        cifar10_train, batch_size=batch_size, shuffle=True, num_workers=num_workers\n    )\n    test_iter = flow.utils.data.DataLoader(\n        cifar10_test, batch_size=batch_size, shuffle=False, num_workers=num_workers\n    )\n    return train_iter, test_iter\n\n\ndef load_data_mnist(\n    batch_size, resize=None, root=\"./data/mnist\", download=True, source_url=None\n):\n    \"\"\"Download the MNIST dataset and then load into memory.\"\"\"\n    root = os.path.expanduser(root)\n    transformer = []\n    if resize:\n        transformer += [transforms.Resize(resize)]\n    transformer += [transforms.ToTensor()]\n    transformer = transforms.Compose(transformer)\n\n    mnist_train = vision.datasets.MNIST(\n        root=root,\n        train=True,\n        transform=transformer,\n        download=download,\n        source_url=source_url,\n    )\n    mnist_test = vision.datasets.MNIST(\n        root=root,\n        train=False,\n        transform=transformer,\n        download=download,\n        source_url=source_url,\n    )\n    train_iter = flow.utils.data.DataLoader(\n        mnist_train, batch_size, shuffle=True, num_workers=2\n    )\n    test_iter = flow.utils.data.DataLoader(\n        mnist_test, batch_size, shuffle=False, num_workers=2\n    )\n    return train_iter, test_iter\n\n\ndef get_fashion_mnist_dataset(\n    resize=None, root=\"./data-test/fashion-mnist\", download=True, source_url=None,\n):\n    root = os.path.expanduser(root)\n    trans = []\n    if resize:\n        trans.append(transforms.Resize(resize))\n    trans.append(transforms.ToTensor())\n    transform = transforms.Compose(trans)\n\n    mnist_train = vision.datasets.FashionMNIST(\n        root=root,\n        train=True,\n        transform=transform,\n        download=download,\n        source_url=source_url,\n    )\n    mnist_test = vision.datasets.FashionMNIST(\n        root=root,\n        train=False,\n        transform=transform,\n        download=download,\n        source_url=source_url,\n    )\n    return mnist_train, mnist_test\n\n\n# reference: http://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.10_mlp-pytorch\ndef load_data_fashion_mnist(\n    batch_size,\n    resize=None,\n    
root=\"./data-test/fashion-mnist\",\n    download=True,\n    source_url=None,\n    num_workers=0,\n):\n    \"\"\"Download the Fashion-MNIST dataset and then load into memory.\"\"\"\n    root = os.path.expanduser(root)\n    trans = []\n    if resize:\n        trans.append(transforms.Resize(resize))\n    trans.append(transforms.ToTensor())\n    transform = transforms.Compose(trans)\n\n    mnist_train = vision.datasets.FashionMNIST(\n        root=root,\n        train=True,\n        transform=transform,\n        download=download,\n        source_url=source_url,\n    )\n    mnist_test = vision.datasets.FashionMNIST(\n        root=root,\n        train=False,\n        transform=transform,\n        download=download,\n        source_url=source_url,\n    )\n\n    train_iter = flow.utils.data.DataLoader(\n        mnist_train, batch_size, shuffle=True, num_workers=num_workers\n    )\n    test_iter = flow.utils.data.DataLoader(\n        mnist_test, batch_size, shuffle=False, num_workers=num_workers\n    )\n    return train_iter, test_iter\n"
  },
  {
    "path": "python/oneflow/test/dataloader/test_cifar_dataset_multiprocess.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\n\nimport oneflow.unittest\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.optim as optim\nfrom data_utils import load_data_cifar10\nimport flowvision as vision\nimport flowvision.transforms as transforms\n\n\nclasses = (\n    \"plane\",\n    \"car\",\n    \"bird\",\n    \"cat\",\n    \"deer\",\n    \"dog\",\n    \"frog\",\n    \"horse\",\n    \"ship\",\n    \"truck\",\n)\n\n\nclass Net(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = nn.Conv2d(3, 6, 5)\n        self.pool = nn.MaxPool2d(2, 2)\n        self.conv2 = nn.Conv2d(6, 16, 5)\n        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n        self.fc2 = nn.Linear(120, 84)\n        self.fc3 = nn.Linear(84, 10)\n\n    def forward(self, x):\n        x = self.pool(flow._C.relu(self.conv1(x)))\n        x = self.pool(flow._C.relu(self.conv2(x)))\n        x = flow.flatten(x, 1)  # flatten all dimensions except batch\n        x = flow._C.relu(self.fc1(x))\n        x = flow._C.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n\n\ndef _test(test_case):\n    if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n        device = flow.device(\"cpu\")\n    else:\n        device = flow.device(\"cuda\")\n    net = Net()\n    net.to(device)\n\n    optimizer = optim.SGD(net.parameters(), lr=0.002, momentum=0.9)\n    criterion = nn.CrossEntropyLoss()\n    criterion.to(device)\n\n    transform = transforms.Compose(\n        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]\n    )\n\n    train_epoch = 1\n    batch_size = 4\n    num_workers = 4\n    data_dir = os.path.join(\n        os.getenv(\"ONEFLOW_TEST_CACHE_DIR\", \"./data-test\"), \"cifar10\"\n    )\n\n    train_iter, test_iter = load_data_cifar10(\n        batch_size=batch_size,\n        data_dir=data_dir,\n        download=True,\n        transform=transform,\n        source_url=\"https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/cifar/cifar-10-python.tar.gz\",\n        num_workers=num_workers,\n    )\n\n    final_loss = 0\n    for epoch in range(1, train_epoch + 1):  # loop over the dataset multiple times\n        running_loss = 0.0\n        for i, data in enumerate(train_iter, 1):\n            # get the inputs; data is a list of [inputs, labels]\n            inputs, labels = data\n            inputs = inputs.to(dtype=flow.float32, device=device)\n            labels = labels.to(dtype=flow.int64, device=device)\n\n            # zero the parameter gradients\n            optimizer.zero_grad()\n\n            # forward + backward + optimize\n            outputs = net(inputs)\n            loss = criterion(outputs, labels)\n            loss.backward()\n            optimizer.step()\n\n            # print statistics\n            running_loss += loss.item()\n            if i % 200 == 0:  # print every 200 mini-batches\n                final_loss = running_loss / 200\n          
      print(\"epoch: %d  step: %5d  loss: %.3f \" % (epoch, i, final_loss))\n                running_loss = 0.0\n                break\n\n    print(\"final loss : \", final_loss)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCifarDataset(flow.unittest.TestCase):\n    def test_cifar_dataset(test_case):\n        _test(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/dataloader/test_cifar_dataset_singleprocess.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\n\nimport flowvision as vision\nimport flowvision.transforms as transforms\n\nimport oneflow.unittest\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.optim as optim\nfrom data_utils import load_data_cifar10\n\n\nclasses = (\n    \"plane\",\n    \"car\",\n    \"bird\",\n    \"cat\",\n    \"deer\",\n    \"dog\",\n    \"frog\",\n    \"horse\",\n    \"ship\",\n    \"truck\",\n)\n\n\nclass Net(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = nn.Conv2d(3, 6, 5)\n        self.pool = nn.MaxPool2d(2, 2)\n        self.conv2 = nn.Conv2d(6, 16, 5)\n        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n        self.fc2 = nn.Linear(120, 84)\n        self.fc3 = nn.Linear(84, 10)\n\n    def forward(self, x):\n        x = self.pool(flow._C.relu(self.conv1(x)))\n        x = self.pool(flow._C.relu(self.conv2(x)))\n        x = flow.flatten(x, 1)  # flatten all dimensions except batch\n        x = flow._C.relu(self.fc1(x))\n        x = flow._C.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n\n\ndef _test(test_case):\n    if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n        device = flow.device(\"cpu\")\n    else:\n        device = flow.device(\"cuda\")\n    net = Net()\n    net.to(device)\n\n    optimizer = optim.SGD(net.parameters(), lr=0.002, momentum=0.9)\n    criterion = nn.CrossEntropyLoss()\n    criterion.to(device)\n\n    transform = transforms.Compose(\n        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]\n    )\n\n    train_epoch = 1\n    batch_size = 4\n    num_workers = 0\n    data_dir = os.path.join(\n        os.getenv(\"ONEFLOW_TEST_CACHE_DIR\", \"./data-test\"), \"cifar10\"\n    )\n\n    train_iter, test_iter = load_data_cifar10(\n        batch_size=batch_size,\n        data_dir=data_dir,\n        download=True,\n        transform=transform,\n        source_url=\"https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/cifar/cifar-10-python.tar.gz\",\n        num_workers=num_workers,\n    )\n\n    final_loss = 0\n    for epoch in range(1, train_epoch + 1):  # loop over the dataset multiple times\n        running_loss = 0.0\n        for i, data in enumerate(train_iter, 1):\n            # get the inputs; data is a list of [inputs, labels]\n            inputs, labels = data\n            inputs = inputs.to(dtype=flow.float32, device=device)\n            labels = labels.to(dtype=flow.int64, device=device)\n\n            # zero the parameter gradients\n            optimizer.zero_grad()\n\n            # forward + backward + optimize\n            outputs = net(inputs)\n            loss = criterion(outputs, labels)\n            loss.backward()\n            optimizer.step()\n\n            # print statistics\n            running_loss += loss.item()\n            if i % 200 == 0:  # print every 200 mini-batches\n                final_loss = running_loss / 200\n        
        print(\"epoch: %d  step: %5d  loss: %.3f \" % (epoch, i, final_loss))\n                running_loss = 0.0\n                break\n\n    print(\"final loss : \", final_loss)\n    # test_case.assertLess(final_loss, 1.50)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCifarDataset(flow.unittest.TestCase):\n    def test_cifar_dataset(test_case):\n        _test(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/dataloader/test_fashion_mnist_dataset.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport time\n\nimport oneflow.unittest\nimport oneflow as flow\nimport oneflow.nn as nn\nfrom data_utils import load_data_fashion_mnist\n\n\ndef get_fashion_mnist_labels(labels):\n    \"\"\"Get text labels for Fashion-MNIST.\"\"\"\n    text_labels = [\n        \"t-shirt\",\n        \"trouser\",\n        \"pullover\",\n        \"dress\",\n        \"coat\",\n        \"sandal\",\n        \"shirt\",\n        \"sneaker\",\n        \"bag\",\n        \"ankle boot\",\n    ]\n    return [text_labels[int(i)] for i in labels]\n\n\nclass FlattenLayer(nn.Module):\n    def __init__(self):\n        super(FlattenLayer, self).__init__()\n\n    def forward(self, x):  # x shape: (batch, *, *, ...)\n        res = x.reshape(x.shape[0], -1)\n        return res\n\n\ndef evaluate_accuracy(data_iter, net, device=None):\n    if device is None and isinstance(net, nn.Module):\n        # using net device if not specified\n        device = list(net.parameters())[0].device\n    acc_sum, n = 0.0, 0\n    net.eval()\n    with flow.no_grad():\n        for X, y in data_iter:\n            X = X.to(device=device)\n            y = y.to(device=device)\n            acc_sum += (\n                net(X.to(device)).argmax(dim=1).numpy() == y.to(device).numpy()\n            ).sum()\n            n += y.shape[0]\n    net.train()\n    return acc_sum / n\n\n\ndef _test(test_case):\n    num_inputs, num_outputs, num_hiddens = 784, 10, 256\n    net = nn.Sequential(\n        FlattenLayer(),\n        nn.Linear(num_inputs, num_hiddens),\n        nn.ReLU(),\n        nn.Linear(num_hiddens, num_outputs),\n    )\n\n    if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n        device = flow.device(\"cpu\")\n    else:\n        device = flow.device(\"cuda\")\n    net.to(device)\n\n    batch_size = 256\n    num_epochs = 1\n    data_dir = os.path.join(\n        os.getenv(\"ONEFLOW_TEST_CACHE_DIR\", \"./data-test\"), \"fashion-mnist\"\n    )\n    source_url = \"https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/\"\n    train_iter, test_iter = load_data_fashion_mnist(\n        batch_size, resize=None, root=data_dir, download=True, source_url=source_url\n    )\n    loss = nn.CrossEntropyLoss()\n    loss.to(device)\n\n    optimizer = flow.optim.SGD(net.parameters(), lr=0.1)\n    final_accuracy = 0\n    for epoch in range(num_epochs):\n        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0\n        start = time.time()\n        for X, y in train_iter:\n            X = X.to(device=device)\n            y = y.to(device=device)\n            y_hat = net(X)\n\n            l = loss(y_hat, y).sum()\n            optimizer.zero_grad()\n            l.backward()\n            optimizer.step()\n\n            train_l_sum += l.numpy()\n            train_acc_sum += (y_hat.argmax(dim=1).numpy() == y.numpy()).sum()\n            n += y.shape[0]\n            if n > 200:\n                break\n\n        test_acc = 
evaluate_accuracy(test_iter, net)\n        final_accuracy = train_acc_sum / n\n        print(\n            \"epoch %d, loss %.4f, train acc %.3f, test acc %.3f, cost %s(s)\"\n            % (\n                epoch + 1,\n                train_l_sum / n,\n                final_accuracy,\n                test_acc,\n                str(time.time() - start),\n            )\n        )\n    # test_case.assertLess(0.60, final_accuracy)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFashionMnistDataset(flow.unittest.TestCase):\n    def test_fashion_mnist_dataset(test_case):\n        _test(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/dataloader/test_lenet.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport time\nimport unittest\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\n\nfrom data_utils import load_data_fashion_mnist\n\n\n# reference: http://tangshusen.me/Dive-into-DL-PyTorch/#/chapter05_CNN/5.5_lenet\nclass LeNet(nn.Module):\n    def __init__(self):\n        super(LeNet, self).__init__()\n        self.conv = nn.Sequential(\n            nn.Conv2d(1, 6, kernel_size=5),  # in_channels, out_channels, kernel_size\n            nn.ReLU(),\n            nn.MaxPool2d(kernel_size=2, stride=2),  # kernel_size, stride\n            nn.Conv2d(6, 16, 5),\n            nn.ReLU(),\n            nn.MaxPool2d(kernel_size=2, stride=2),\n        )\n        self.fc = nn.Sequential(\n            nn.Linear(16 * 4 * 4, 120),\n            nn.ReLU(),\n            nn.Linear(120, 84),\n            nn.ReLU(),\n            nn.Linear(84, 10),\n        )\n\n    def forward(self, img):\n        feature = self.conv(img)\n        feature = feature.flatten(start_dim=1)\n        output = self.fc(feature)\n        return output\n\n\ndef evaluate_accuracy(data_iter, net, device=None):\n    if device is None and isinstance(net, nn.Module):\n        device = list(net.parameters())[0].device\n    acc_sum, n = 0.0, 0\n    net.eval()\n    with flow.no_grad():\n        for X, y in data_iter:\n            X = X.to(device=device)\n            y = y.to(device=device)\n            acc_sum += (net(X).argmax(dim=1).numpy() == y.numpy()).sum()\n            n += y.shape[0]\n    net.train()\n    return acc_sum / n\n\n\ndef _test_train_and_eval(test_case):\n    if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n        device = flow.device(\"cpu\")\n    else:\n        device = flow.device(\"cuda\")\n    net = LeNet()\n    lr, num_epochs = 0.02, 1\n    optimizer = flow.optim.SGD(net.parameters(), lr=lr, momentum=0.9)\n    net.to(device)\n\n    batch_size = 256\n    data_dir = os.path.join(\n        os.getenv(\"ONEFLOW_TEST_CACHE_DIR\", \"./data-test\"), \"fashion-mnist-lenet\"\n    )\n    source_url = \"https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/\"\n\n    train_iter, test_iter = load_data_fashion_mnist(\n        batch_size=batch_size,\n        resize=None,\n        root=data_dir,\n        download=True,\n        source_url=source_url,\n        num_workers=0,\n    )\n    loss = nn.CrossEntropyLoss()\n    loss.to(device)\n\n    final_accuracy = 0\n\n    for epoch in range(num_epochs):\n        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()\n        for X, y in train_iter:\n            X = X.to(device=device)\n            y = y.to(device=device)\n            # forward\n            y_hat = net(X)\n            l = loss(y_hat, y).sum()\n            # backward\n            l.backward()\n            optimizer.step()\n            optimizer.zero_grad()\n\n            train_l_sum += l.numpy()\n            train_acc_sum += 
(y_hat.argmax(dim=1).numpy() == y.numpy()).sum()\n            n += y.shape[0]\n            batch_count += 1\n            if batch_count == 20:\n                break\n\n        test_acc = evaluate_accuracy(test_iter, net)\n        final_accuracy = train_acc_sum / n\n        print(\n            \"epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec\"\n            % (\n                epoch + 1,\n                train_l_sum / batch_count,\n                final_accuracy,\n                test_acc,\n                time.time() - start,\n            )\n        )\n    # test_case.assertLess(0.4, final_accuracy)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLenet(flow.unittest.TestCase):\n    def test_lenet(test_case):\n        _test_train_and_eval(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/dataloader/test_mnist_dataset.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\n\nimport flowvision as vision\nimport flowvision.transforms as transforms\n\nimport oneflow.unittest\nimport oneflow as flow\nimport oneflow.nn as nn\nfrom data_utils import load_data_mnist\n\n\ndata_dir = os.path.join(\n    os.getenv(\"ONEFLOW_TEST_CACHE_DIR\", \"./data-test\"), \"mnist-dataset\"\n)\ntrain_iter, test_iter = load_data_mnist(\n    batch_size=128,\n    download=True,\n    root=data_dir,\n    source_url=\"https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/MNIST/\",\n)\n\n\ndef evaluate_accuracy(data_iter, net, device=None):\n    n_correct, n_samples = 0.0, 0\n    net.to(device)\n    net.eval()\n    with flow.no_grad():\n        for images, labels in data_iter:\n            images = images.reshape(-1, 28 * 28)\n            images = images.to(device=device)\n            labels = labels.to(device=device)\n            n_correct += (net(images).argmax(dim=1).numpy() == labels.numpy()).sum()\n            n_samples += images.shape[0]\n    net.train()\n    return n_correct / n_samples\n\n\nclass Net(nn.Module):\n    def __init__(\n        self, input_size=784, hidden_size1=128, hidden_size2=64, num_classes=10\n    ):\n        super(Net, self).__init__()\n        self.l1 = nn.Linear(input_size, hidden_size1)\n        self.relu1 = nn.ReLU()\n        self.l2 = nn.Linear(hidden_size1, hidden_size2)\n        self.relu2 = nn.ReLU()\n        self.l3 = nn.Linear(hidden_size2, num_classes)\n\n    def forward(self, x):\n        out = self.l1(x)\n        out = self.relu1(out)\n        out = self.l2(out)\n        out = self.relu2(out)\n        out = self.l3(out)\n        return out\n\n\ndef _test_train_and_eval(test_case):\n    if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n        device = flow.device(\"cpu\")\n    else:\n        device = flow.device(\"cuda\")\n\n    model = Net()\n    model.to(device)\n\n    loss = nn.CrossEntropyLoss().to(device)\n    optimizer = flow.optim.SGD(model.parameters(), lr=0.10)\n\n    num_epochs = 1\n    for epoch in range(num_epochs):\n        train_loss, n_correct, n_samples = 0.0, 0.0, 0\n        for images, labels in train_iter:\n            images = images.reshape(-1, 28 * 28)\n            images = images.to(device=device)\n            labels = labels.to(device=device)\n            features = model(images)\n            l = loss(features, labels).sum()\n            optimizer.zero_grad()\n            l.backward()\n            optimizer.step()\n\n            train_loss += l.numpy()\n            n_correct += (features.argmax(dim=1).numpy() == labels.numpy()).sum()\n            n_samples += images.shape[0]\n            if n_samples > 2000:\n                break\n\n        test_acc = evaluate_accuracy(test_iter, model, device)\n        train_acc = n_correct / n_samples\n        print(\n            \"epoch %d, train loss %.4f, train acc %.3f, test acc %.3f\"\n            % (epoch + 1, train_loss / n_samples, 
train_acc, test_acc)\n        )\n        # test_case.assertLess(0.8, test_acc)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMnistDataset(flow.unittest.TestCase):\n    def test_mnist_dataset(test_case):\n        _test_train_and_eval(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/dataloader/test_numpy_dataset.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass ScpDataset(flow.utils.data.Dataset):\n    def __init__(self, chunksize=200, dim=81, length=2000):\n        self.chunksize = chunksize\n        self.dim = dim\n        self.length = length\n\n    def __getitem__(self, index):\n        np.random.seed(index)\n        return np.random.randn(self.chunksize, self.dim)\n\n    def __len__(self):\n        return self.length\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNumpyDataset(flow.unittest.TestCase):\n    def test_numpy_dataset(test_case):\n        dataset = ScpDataset()\n        dataloader = flow.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)\n        for X in dataloader:\n            test_case.assertEqual(X.shape, flow.Size([16, 200, 81]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/dataloader/test_tensor_dataset.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\nimport oneflow.optim as optim\n\n\nclass LinearNet(nn.Module):\n    def __init__(self, n_feature):\n        super(LinearNet, self).__init__()\n        self.linear = nn.Linear(n_feature, 1)\n\n    def forward(self, x):\n        y = self.linear(x)\n        return y\n\n\n@unittest.skip(\"optimizer has a bug with 0-dim tensor\")\nclass TestTensorDataset(flow.unittest.TestCase):\n    def test_tensor_dataset(test_case):\n        num_inputs = 2\n        num_examples = 1000\n        true_w = [2, -3.4]\n        true_b = 4.2\n        net = LinearNet(num_inputs)\n        flow.nn.init.normal_(net.linear.weight, mean=0, std=0.01)\n        flow.nn.init.constant_(net.linear.bias, val=0)\n        loss = nn.MSELoss()\n        optimizer = optim.SGD(net.parameters(), lr=0.03)\n\n        features = flow.tensor(\n            np.random.normal(0, 1, (num_examples, num_inputs)), dtype=flow.float\n        )\n        labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b\n        labels += flow.tensor(\n            np.random.normal(0, 0.01, size=labels.size()), dtype=flow.float\n        )\n\n        batch_size = 10\n        dataset = flow.utils.data.TensorDataset(features, labels)\n        data_iter = flow.utils.data.DataLoader(\n            dataset, batch_size, shuffle=True, num_workers=0\n        )\n        num_epochs = 10\n        for epoch in range(1, num_epochs + 1):\n            for (X, y) in data_iter:\n                output = net(X)\n                l = loss(output, y).sum()\n                optimizer.zero_grad()\n                l.backward()\n                optimizer.step()\n            if epoch == num_epochs:\n                test_case.assertLess(l.numpy(), 0.00025)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/dataloader/test_transforms.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\n\nimport flowvision as vision\nimport flowvision.transforms as transforms\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.optim as optim\nimport oneflow.unittest\nfrom data_utils import load_data_cifar10\n\n\nclass Net(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = nn.Conv2d(3, 6, 5)\n        self.pool = nn.MaxPool2d(2, 2)\n        self.conv2 = nn.Conv2d(6, 16, 5)\n        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n        self.fc2 = nn.Linear(120, 84)\n        self.fc3 = nn.Linear(84, 10)\n\n    def forward(self, x):\n        x = self.pool(flow._C.relu(self.conv1(x)))\n        x = self.pool(flow._C.relu(self.conv2(x)))\n        x = flow.flatten(x, 1)  # flatten all dimensions except batch\n        x = flow._C.relu(self.fc1(x))\n        x = flow._C.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n\n\ndef _test(test_case):\n    if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n        device = flow.device(\"cpu\")\n    else:\n        device = flow.device(\"cuda\")\n    net = Net()\n    net.to(device)\n\n    optimizer = optim.SGD(net.parameters(), lr=0.002, momentum=0.9)\n    criterion = nn.CrossEntropyLoss()\n    criterion.to(device)\n\n    transform = transforms.Compose(\n        [\n            transforms.Pad(10),\n            transforms.RandomHorizontalFlip(p=0.5),\n            transforms.RandomVerticalFlip(p=0.5),\n            transforms.CenterCrop(32),\n            transforms.Resize([32, 32]),\n            transforms.ToTensor(),\n            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n        ]\n    )\n\n    train_epoch = 1\n    batch_size = 4\n    data_dir = os.path.join(\n        os.getenv(\"ONEFLOW_TEST_CACHE_DIR\", \"./data-test\"), \"cifar10\"\n    )\n\n    train_iter, test_iter = load_data_cifar10(\n        batch_size=batch_size,\n        data_dir=data_dir,\n        download=True,\n        transform=transform,\n        source_url=\"https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/cifar/cifar-10-python.tar.gz\",\n        num_workers=0,\n    )\n\n    final_loss = 0\n    for epoch in range(1, train_epoch + 1):  # loop over the dataset multiple times\n        running_loss = 0.0\n        for i, data in enumerate(train_iter, 1):\n            # get the inputs; data is a list of [inputs, labels]\n            inputs, labels = data\n            inputs = inputs.to(dtype=flow.float32, device=device)\n            labels = labels.to(dtype=flow.int64, device=device)\n\n            # zero the parameter gradients\n            optimizer.zero_grad()\n\n            # forward + backward + optimize\n            outputs = net(inputs)\n            loss = criterion(outputs, labels)\n            loss.backward()\n            optimizer.step()\n\n            # print statistics\n            running_loss += loss.numpy()\n            # print every 2000 mini-batches\n            if i 
% 2000 == 0:\n                final_loss = running_loss / 2000\n                print(\"epoch: %d  step: %5d  loss: %.3f \" % (epoch, i, final_loss))\n                running_loss = 0.0\n\n    print(\"final loss : \", final_loss)\n    # test_case.assertLess(final_loss, 1.79)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCifarDataset(flow.unittest.TestCase):\n    def test_cifar_dataset(test_case):\n        _test(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_activation.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport os\nimport numpy as np\nimport time\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\nclass TestActivationError(flow.unittest.TestCase):\n    def test_relu_inplace_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True)\n            x.relu_()\n        test_case.assertTrue(\n            \"a leaf Tensor that requires grad is being used in an in-place operation\"\n            in str(context.exception)\n        )\n\n    def test_prelu_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True)\n            m = flow.nn.PReLU(5)\n            y = m(x)\n        test_case.assertTrue(\n            \"num_parameters in prelu must be 1 or 4\" in str(context.exception)\n        )\n\n    def test_celu_inplace_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True)\n            m = flow.nn.CELU(alpha=1.0, inplace=True)\n            y = m(x)\n        test_case.assertTrue(\n            \"a leaf Tensor that requires grad is being used in an in-place operation\"\n            in str(context.exception)\n        )\n\n    def test_glu_scalar_tensor_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor(1.0)\n            m = flow.nn.GLU()\n            y = m(x)\n        test_case.assertTrue(\n            \"glu does not support scalars because halving size must be even\"\n            in str(context.exception)\n        )\n\n    def test_glu_dim_index_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.randn(2, 4)\n            m = flow.nn.GLU(dim=3)\n            y = m(x)\n        test_case.assertTrue(\n            \"Dimension out of range (expected to be in range of [-2, 1], but got 3)\"\n            in str(context.exception)\n        )\n\n    def test_glu_dim_even_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.randn(2, 3)\n            m = flow.nn.GLU()\n            y = m(x)\n        test_case.assertTrue(\n            \"Halving dimension must be even, but dimension 1 is size 3\"\n            in str(context.exception)\n        )\n\n    def test_hard_sigmoid_inplace_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.randn(2)\n            x.requires_grad = True\n            m = flow.nn.Hardsigmoid(inplace=True)\n            y = m(x)\n        test_case.assertTrue(\n            \"a leaf Tensor that requires grad is being used in an in-place operation\"\n    
        in str(context.exception)\n        )\n\n    def test_hard_shrink_inplace_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.randn(2)\n            x.requires_grad = True\n            m = flow.nn.Hardshrink(inplace=True)\n            y = m(x)\n        test_case.assertTrue(\n            \"a leaf Tensor that requires grad is being used in an in-place operation\"\n            in str(context.exception)\n        )\n\n    def test_softmax_index_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.randn(2, 4)\n            m = flow.nn.Softmax(dim=2)\n            y = m(x)\n        test_case.assertTrue(\n            \"Dimension out of range (expected to be in range of [-2, 1], but got 2)\"\n            in str(context.exception)\n        )\n\n    def test_soft_shrink_inplace_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.randn(2)\n            x.requires_grad = True\n            m = flow.nn.Softshrink(inplace=True)\n            y = m(x)\n        test_case.assertTrue(\n            \"a leaf Tensor that requires grad is being used in an in-place operation\"\n            in str(context.exception)\n        )\n\n    def test_soft_shrink_alpha_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.randn(2)\n            x.requires_grad = True\n            m = flow.nn.Softshrink(-0.1)\n            y = m(x)\n        test_case.assertTrue(\n            \"alpha must be greater or equal to 0, but found to be -0.1.\"\n            in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_add_n_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestAddN(flow.unittest.TestCase):\n    def test_add_n_shape_error_msg(test_case):\n        a = flow.tensor([1, 2])\n        b = flow.tensor([3, 4])\n        c = flow.tensor([[2, 2], [2, 2]])\n        with test_case.assertRaises(RuntimeError) as context:\n            flow.add(a, b, c)\n        test_case.assertTrue(\n            \"inconsistent tensor size, expected all tensor to have the same number of elements, but got\"\n            in str(context.exception)\n        )\n\n    def test_add_n_dtype_error_msg(test_case):\n        a = flow.tensor([1, 2], dtype=flow.int64)\n        b = flow.tensor([3, 4], dtype=flow.int64)\n        c = flow.tensor([2, 2], dtype=flow.float64)\n        with test_case.assertRaises(RuntimeError) as context:\n            flow.add(a, b, c)\n        test_case.assertTrue(\n            \"expected all tenser to have same type, but found\" in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_arg_sort_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestArgSort(flow.unittest.TestCase):\n    def test_direction_parameter_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.tensor([5, 10, 7, 8, 9, 1])\n            flow._C.arg_sort(x, direction=\"NONE\")\n        test_case.assertTrue(\n            \"expected the input direction parameter value is\" in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_array_functor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow.unittest\nimport oneflow as flow\n\n\nclass TestArrayError(flow.unittest.TestCase):\n    def test_argmax_index_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True)\n            y = flow.argmax(x, dim=4)\n        test_case.assertTrue(\"Dimension out of range\" in str(context.exception))\n\n    def test_broadcast_like_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((1, 0), dtype=flow.float32, requires_grad=True)\n            like = flow.ones((2, 2, 2), dtype=flow.float32, requires_grad=True)\n            y = flow.broadcast_like(x, like)\n        test_case.assertTrue(\n            \"The expanded size of the tensor\" in str(context.exception)\n        )\n\n    def test_broadcast_like_numaxes_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((2, 2, 2), dtype=flow.float32, requires_grad=True)\n            like = flow.ones((2, 2), dtype=flow.float32, requires_grad=True)\n            y = flow._C.broadcast_like(x, like)\n        print(str(context.exception))\n        test_case.assertTrue(\"The number of sizes provided\" in str(context.exception))\n\n    def test_concat_index_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True)\n            x2 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True)\n            y = flow.concat([x1, x2], dim=3)\n        test_case.assertTrue(\"Dimension out of range\" in str(context.exception))\n\n    def test_concat_dim_equal_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True)\n            x2 = flow.ones((2, 2, 2), dtype=flow.float32, requires_grad=True)\n            y = flow.concat([x1, x2])\n        test_case.assertTrue(\n            \"Tensors must have same number of dimensions\" in str(context.exception)\n        )\n\n    def test_concat_match_size_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True)\n            x2 = flow.ones((2, 3), dtype=flow.float32, requires_grad=True)\n            y = flow.concat([x1, x2])\n        test_case.assertTrue(\n            \"Sizes of tensors must match except in dimension\" in str(context.exception)\n        )\n\n    def test_stack_index_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.ones((2, 1), dtype=flow.float32, requires_grad=True)\n            x2 = flow.ones((2, 1), dtype=flow.float32, requires_grad=True)\n            y = flow.concat([x1, x2], dim=4)\n        
test_case.assertTrue(\"Dimension out of range\" in str(context.exception))\n\n    def test_stack_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.ones((2, 1), dtype=flow.float32, requires_grad=True)\n            x2 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True)\n            y = flow.stack([x1, x2])\n        test_case.assertTrue(\n            \"stack expects each tensor to be equal size\" in str(context.exception)\n        )\n\n    def test_expand_dim_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.ones((2, 1), dtype=flow.float32, requires_grad=True)\n            x2 = flow.ones((2), dtype=flow.float32, requires_grad=True)\n            y = flow.expand(x1, x2.shape)\n        test_case.assertTrue(\n            \"be greater or equal to the number of dimensions in the tensor\"\n            in str(context.exception)\n        )\n\n    def test_expand_g_shape_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True)\n            x2 = flow.ones((2, 4), dtype=flow.float32, requires_grad=True)\n            y = flow.expand(x1, x2.shape)\n        test_case.assertTrue(\n            \"The expanded size of the tensor\" in str(context.exception)\n        )\n\n    def test_expand_l_shape_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True)\n            x2 = flow.ones((2, 0), dtype=flow.float32, requires_grad=True)\n            y = flow.expand(x1, x2.shape)\n        test_case.assertTrue(\n            \"The expanded size of the tensor\" in str(context.exception)\n        )\n\n    def test_squeeze_index_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((2, 1), dtype=flow.float32, requires_grad=True)\n            y = flow.squeeze(x, dim=4)\n        test_case.assertTrue(\"Dimension out of range\" in str(context.exception))\n\n    def test_roll_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((2, 2), dtype=flow.float32, requires_grad=True)\n            y = flow.roll(x, [0, 1], [0])\n        test_case.assertTrue(\n            \"shifts and dimensions must align\" in str(context.exception)\n        )\n\n    def test_gather_index_type_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True)\n            x2 = flow.ones((2, 2), dtype=flow.float32)\n            y = flow.gather(x1, 1, x2)\n        test_case.assertTrue(\n            \"gather(): Expected dtype int32 or int64 for index\"\n            in str(context.exception)\n        )\n\n    def test_gather_dim_value_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True)\n            x2 = flow.ones((2, 2), dtype=flow.int64)\n            y = flow.gather(x1, 2, x2)\n        test_case.assertTrue(\"Dimension out of range\" in str(context.exception))\n\n    def test_gather_dim_equal_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True)\n            x2 = flow.ones((2, 2, 2), dtype=flow.int64)\n            y 
= flow.gather(x1, 1, x2)\n        test_case.assertTrue(\n            \"Index tensor must have the same number of dimensions as input tensor\"\n            in str(context.exception)\n        )\n\n    def test_gather_size_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True)\n            x2 = flow.ones((4, 2), dtype=flow.int64)\n            y = flow.gather(x1, 1, x2)\n        test_case.assertTrue(\n            \"Size does not match at dimension\" in str(context.exception)\n        )\n\n    def test_tensor_scatter_nd_update_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.arange(8, dtype=flow.float32, requires_grad=True)\n            indices = flow.tensor([[1], [3], [5]])\n            updates = flow.tensor([-1, -2, -3], dtype=flow.float64, requires_grad=True)\n            y = flow.tensor_scatter_nd_update(x, indices, updates)\n        test_case.assertTrue(\n            \"The dtype of tensor and updates must be same.\" in str(context.exception)\n        )\n\n    def test_view_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.ones((2, 3, 4), dtype=flow.float32, requires_grad=True).permute(\n                1, 0, 2\n            )\n            x2 = flow.ones((4, 6), dtype=flow.float32, requires_grad=True)\n            y = flow.view(x1, x2.shape)\n        test_case.assertTrue(\n            \"view size is not compatible with input tensor's size\"\n            in str(context.exception)\n        )\n\n    def test_narrow_dim_index_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((3, 3), dtype=flow.float32, requires_grad=True)\n            y = flow.narrow(x, 3, 0, 2)\n        test_case.assertTrue(\"Dimension out of range\" in str(context.exception))\n\n    def test_narrow_0_dim_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor(1, dtype=flow.float32, requires_grad=True)\n            y = flow.narrow(x, 0, 0, 0)\n        test_case.assertTrue(\n            \"narrow() cannot be applied to a 0-dim tensor.\" in str(context.exception)\n        )\n\n    def test_narrow_start_index_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((3, 3), dtype=flow.float32, requires_grad=True)\n            y = flow.narrow(x, 0, 4, 0)\n        test_case.assertTrue(\"Dimension out of range\" in str(context.exception))\n\n    def test_narrow_length_exceed_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((3, 3), dtype=flow.float32, requires_grad=True)\n            y = flow.narrow(x, 0, 2, 2)\n        test_case.assertTrue(\"exceeds dimension size\" in str(context.exception))\n\n    def test_diagonal_index_error1(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True)\n            y = flow.diagonal(x, 1, 3, 2)\n        test_case.assertTrue(\"Dimension out of range\" in str(context.exception))\n\n    def test_diagonal_index_error2(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True)\n            y = flow.diagonal(x, 1, 2, 3)\n        test_case.assertTrue(\"Dimension out of range\" in 
str(context.exception))\n\n    def test_diagonal_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True)\n            y = flow.diagonal(x, 1, 2, 2)\n        test_case.assertTrue(\n            \"diagonal dimensions cannot be identical\" in str(context.exception)\n        )\n\n    def test_split_index_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True)\n            y = flow.split(x, split_size_or_sections=0, dim=4)\n        test_case.assertTrue(\"Dimension out of range\" in str(context.exception))\n\n    def test_split_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True)\n            y = flow.split(x, split_size_or_sections=-1)\n        test_case.assertTrue(\n            \"split expects split_size be non-negative, but got split_size\"\n            in str(context.exception)\n        )\n\n    def test_splitwithsize_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((5, 2), dtype=flow.float32, requires_grad=True)\n            y = flow.split(x, [1, 3])\n        test_case.assertTrue(\n            \"split_with_sizes expects split_sizes to sum exactly to \"\n            in str(context.exception)\n        )\n\n    def test_unbind_index_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True)\n            y = flow.unbind(x, dim=4)\n        test_case.assertTrue(\"Dimension out of range\" in str(context.exception))\n\n    def test_chunk_index_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True)\n            y = flow.chunk(x, chunks=2, dim=4)\n        test_case.assertTrue(\"Dimension out of range\" in str(context.exception))\n\n    def test_chunk_tensor_dim_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor(1, dtype=flow.float32, requires_grad=True)\n            y = flow.chunk(x, chunks=2, dim=4)\n        test_case.assertTrue(\n            \"chunk expects at least a 1-dimensional tensor\" in str(context.exception)\n        )\n\n    def test_chunk_value_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True)\n            y = flow.chunk(x, chunks=-1, dim=4)\n        test_case.assertTrue(\n            \"chunk expects `chunks` to be greater than 0, got\" in str(context.exception)\n        )\n\n    def test_meshgrid_tensors_scalar_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.tensor([], dtype=flow.float32, requires_grad=True)\n            x2 = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True)\n            y = flow.meshgrid(x1, x2)\n        test_case.assertTrue(\n            \"Expected scalar or 1D tensor in the tensor list\" in str(context.exception)\n        )\n\n    def test_meshgrid_tensors_size_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            y = flow.meshgrid([])\n        test_case.assertTrue(\n            \"meshgrid expects a non-empty 
TensorList\" in str(context.exception)\n        )\n\n    def test_meshgrid_tensors_dtype_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.ones((2), dtype=flow.float32, requires_grad=True)\n            x2 = flow.ones((2), dtype=flow.float16, requires_grad=True)\n            y = flow.meshgrid(x1, x2)\n        test_case.assertTrue(\n            \"meshgrid expects all tensors to have the same dtype\"\n            in str(context.exception)\n        )\n\n    def test_meshgrid_tensors_placement_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.tensor(\n                [0.0, 1.0],\n                dtype=flow.float32,\n                placement=flow.placement(\"cpu\", ranks=[0]),\n                sbp=[flow.sbp.broadcast],\n            )\n            x2 = flow.tensor(\n                [0.0, 1.0],\n                dtype=flow.float32,\n                placement=flow.placement(\"cpu\", ranks=[0]),\n                sbp=[flow.sbp.broadcast],\n            ).to_local()\n            y = flow.meshgrid(x1, x2)\n        test_case.assertTrue(\n            \"meshgrid expects all tensors are global tensor\" in str(context.exception)\n        )\n\n    def test_meshgrid_indexing_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x1 = flow.ones((2), dtype=flow.float32, requires_grad=True)\n            x2 = flow.ones((2), dtype=flow.float32, requires_grad=True)\n            y = flow.meshgrid(x1, x2, indexing=\"ab\")\n        test_case.assertTrue(\n            \"meshgrid: indexing must be one of\" in str(context.exception)\n        )\n\n    def test_index_select_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor(\n                [[1, 2, 3], [4, 5, 6]], dtype=flow.float32, requires_grad=True\n            )\n            index = flow.tensor([0, 1], dtype=flow.float32)\n            y = flow.index_select(x, 1, index)\n        test_case.assertTrue(\n            \"Expected dtype int32 or int64 for index\" in str(context.exception)\n        )\n\n    def test_index_select_index_num_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor(\n                [[1, 2, 3], [4, 5, 6]], dtype=flow.float32, requires_grad=True\n            )\n            index = flow.tensor([[0]], dtype=flow.int32)\n            y = flow.index_select(x, 1, index)\n        test_case.assertTrue(\n            \"Index is supposed to be a vector\" in str(context.exception)\n        )\n\n    def test_index_select_index_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor(\n                [[1, 2, 3], [4, 5, 6]], dtype=flow.float32, requires_grad=True\n            )\n            index = flow.tensor([0], dtype=flow.int32)\n            y = flow.index_select(x, 4, index)\n        test_case.assertTrue(\"Dimension out of range\" in str(context.exception))\n\n    def test_to_device_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor(\n                [0.0, 1.0],\n                dtype=flow.float32,\n                placement=flow.placement(\"cpu\", ranks=[0]),\n                sbp=[flow.sbp.split(0)],\n            )\n            x.to(\"cpp\")\n        test_case.assertTrue(\n            \"Only string device without device id\" in str(context.exception)\n        )\n\n    def 
test_to_other_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor([0.0, 1.0], dtype=flow.float32)\n            other = flow.tensor(\n                [0.0, 1.0],\n                dtype=flow.float32,\n                placement=flow.placement(\"cpu\", ranks=[0]),\n                sbp=[flow.sbp.split(0)],\n            )\n            x.to(other)\n        test_case.assertTrue(\n            \"tensor.to(other) can only be called when tensor and other are local tensors\"\n            in str(context.exception)\n        )\n\n    def test_in_top_k_num_equal_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            target = flow.tensor([[3, 1]], dtype=flow.int32)\n            prediction = flow.tensor(\n                [[0.0, 1.0, 2.0, 3.0], [3.0, 2.0, 1.0, 0.0]], dtype=flow.float32\n            )\n            out = flow.in_top_k(target, prediction, k=1)\n        test_case.assertTrue(\n            \"The num of targets must equal the num of predictions\"\n            in str(context.exception)\n        )\n\n    def test_in_top_k_targets_dim_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            target = flow.tensor([[3, 1], [1, 3]], dtype=flow.int32)\n            prediction = flow.tensor(\n                [[0.0, 1.0, 2.0, 3.0], [3.0, 2.0, 1.0, 0.0]], dtype=flow.float32\n            )\n            out = flow.in_top_k(target, prediction, k=1)\n        test_case.assertTrue(\n            \"The dimension of targets must be 1\" in str(context.exception)\n        )\n\n    def test_in_top_k_pre_dim_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            target = flow.tensor([3, 1], dtype=flow.int32)\n            prediction = flow.tensor(\n                [[[0.0, 1.0, 2.0, 3.0]], [[3.0, 2.0, 1.0, 0.0]]], dtype=flow.float32\n            )\n            out = flow.in_top_k(target, prediction, k=1)\n        test_case.assertTrue(\n            \"The dimension of predictions must be 2\" in str(context.exception)\n        )\n\n    def test_repeat_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor([[1], [1]], dtype=flow.int32)\n            y = x.repeat(1)\n        test_case.assertTrue(\n            \"Number of dimensions of repeat dims can not be\" in str(context.exception)\n        )\n\n    def test_tile_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor([[1], [1]], dtype=flow.int32)\n            y = x.tile(-1)\n        test_case.assertTrue(\n            \"Trying to create tensor with negative dimension\" in str(context.exception)\n        )\n\n    def test_t_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor([[[1]]], dtype=flow.int32)\n            y = x.t()\n        test_case.assertTrue(\n            \"t() expects a tensor with <= 2 dimensions\" in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_autograd.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport re\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestAutograd(flow.unittest.TestCase):\n    def test_non_requires_grad_tensor_backward(test_case):\n        x = flow.ones(4, 4)\n        with test_case.assertRaises(Exception) as context:\n            x.backward()\n        test_case.assertIsNotNone(\n            re.search(\n                r\"\\nRuntimeError: element \\d of tensors does not require grad and does not have a grad_fn\",\n                str(context.exception),\n            )\n        )\n\n    def test_allow_unused(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.ones(4, 4).requires_grad_()\n            y = flow.ones(4, 4).requires_grad_()\n            z = x * x\n            dx, dy = flow.autograd.grad(z, [x, y], flow.ones_like(z))\n        test_case.assertTrue(\n            \"allow_unused=True if this is the desired behavior\"\n            in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_batch_gather_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nfrom numpy import array, dtype\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestBatchGather(flow.unittest.TestCase):\n    def test_input_tensor_dimesion_error_msg(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.tensor(1)\n            indice = flow.tensor([1])\n            flow.batch_gather(x, indice)\n        test_case.assertTrue(\n            \"The dimension of the input tensor should be greater than zero, but got\"\n            in str(context.exception)\n        )\n\n    def test_indices_dimesion_error_msg(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.tensor([1])\n            indice = flow.tensor(1)\n            flow.batch_gather(x, indice)\n        test_case.assertTrue(\n            \"The dimension of the indices tensor should be greater than zero, but got\"\n            in str(context.exception)\n        )\n\n    def test_legal_dimension_error_msg(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            x = np.random.randn(1)\n            x_tensor = flow.tensor(x)\n            indice = flow.tensor([[1, 1], [1, 1], [1, 1]])\n            flow.batch_gather(x_tensor, indice)\n        test_case.assertTrue(\n            \"The dimension of the input tensor should be greater than or equal to the dimension of the indices tensor\"\n            in str(context.exception)\n        )\n\n    def test_indice_type_error_msg(test_case):\n        with test_case.assertRaises(TypeError) as context:\n            x = np.random.randn(2)\n            x_tensor = flow.tensor(x)\n            indice = flow.tensor([1, 1], dtype=flow.float64)\n            flow.batch_gather(x_tensor, indice)\n        test_case.assertTrue(\n            \"The dtype of the indices tensor must be int32 or int64\"\n            in str(context.exception)\n        )\n\n    def test_tensor_shape_size_error_msg(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            x = np.random.randn(4, 5)\n            x_tensor = flow.tensor(x)\n            indice = flow.tensor([[1, 2], [1, 2], [1, 2]])\n            out = flow.batch_gather(x_tensor, indice)\n        test_case.assertTrue(\n            \"The size of indices tensor must match the size of input tensor\"\n            in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_bias_add_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestBiasAdd(flow.unittest.TestCase):\n    def test_b_tensor_numaxes_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.tensor([[1, 1], [2, 2]])\n            y = flow.tensor([[2, 2], [1, 1]])\n            out = flow._C.bias_add(y, x, axis=0)\n        test_case.assertTrue(\n            \"Bias tensor has to be a one-dimensional vector\" in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_binary_functor_exception.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport os\nimport numpy as np\nimport time\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestBinaryFunctorError(flow.unittest.TestCase):\n    def test_add_inplace_runtime_error(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True)\n            y = flow.ones((4, 4), dtype=flow.float32, requires_grad=True)\n            x.add_(y)\n        test_case.assertTrue(\n            \"a leaf Tensor that requires grad is being used in an in-place operation\"\n            in str(context.exception)\n        )\n\n    def test_add_broad_cast_runtime_error(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.ones((2, 3))\n            y = flow.ones((2, 4))\n            x.add_(y)\n        test_case.assertTrue(\n            \"Tensor with shape (2,3) doesn't match the broadcast shape in an inplace operation\"\n            in str(context.exception)\n        )\n\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.ones((3, 3))\n            y = flow.ones((2, 3, 3))\n            x.add_(y)\n        test_case.assertTrue(\n            \"Can not expand origin shape (2,3,3) to (3,3)\" in str(context.exception)\n        )\n\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True)\n            y = flow.ones((4, 4), dtype=flow.float32, requires_grad=True)\n            x.mul_(y)\n        test_case.assertTrue(\n            \"a leaf Tensor that requires grad is being used in an in-place operation\"\n            in str(context.exception)\n        )\n\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.ones((2, 3))\n            y = flow.ones((2, 4))\n            x.mul_(y)\n        test_case.assertTrue(\n            \"Tensor with shape (2,3) doesn't match the broadcast shape in an inplace operation\"\n            in str(context.exception)\n        )\n\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.ones((3, 3))\n            y = flow.ones((2, 3, 3))\n            x.mul_(y)\n        test_case.assertTrue(\n            \"Can not expand origin shape (2,3,3) to (3,3)\" in str(context.exception)\n        )\n\n    def test_div_inplace_runtime_error(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True)\n            y = flow.ones((4, 4), dtype=flow.float32, requires_grad=True)\n            x.div_(y)\n        test_case.assertTrue(\n            \"a leaf Tensor that requires grad is being used in an in-place operation\"\n            in str(context.exception)\n        )\n\n        with test_case.assertRaises(RuntimeError) as context:\n            
x = flow.ones((2, 3))\n            y = flow.ones((2, 4))\n            x.div_(y)\n        test_case.assertTrue(\n            \"Tensor with shape (2,3) doesn't match the broadcast shape in an inplace operation\"\n            in str(context.exception)\n        )\n\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.ones((3, 3))\n            y = flow.ones((2, 3, 3))\n            x.div_(y)\n        test_case.assertTrue(\n            \"Can not expand origin shape (2,3,3) to (3,3)\" in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_bmm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBmm(flow.unittest.TestCase):\n    def test_bmm_exception_dim_not_right(test_case):\n        x = flow.tensor((2, 2))\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = flow.bmm(x, x)\n        test_case.assertTrue(\n            \"Expected 3-dimensional tensor, but got 1-dimensional tensor for argument #1\"\n            in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_broadcast_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nbinary_ops = [\n    flow.add,\n    flow.sub,\n    flow.mul,\n    flow.div,\n    flow.min,\n    flow.minimum,\n    flow.max,\n    flow.maximum,\n    flow.fmod,\n    flow.pow,\n    flow.eq,\n    flow.ne,\n    flow.gt,\n    flow.ge,\n    flow.lt,\n    flow.le,\n    flow.logical_and,\n    flow.logical_or,\n    flow.logical_xor,\n]\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBroadcastOps(flow.unittest.TestCase):\n    def test_broadcast_binary_ops(test_case):\n        x = flow.Tensor(8, 10)\n        y = flow.Tensor(8)\n        for op in binary_ops:\n            with test_case.assertRaises(RuntimeError) as ctx:\n                op(x, y)\n            test_case.assertTrue(\n                \"The size of tensor a (10) must match the size of tensor b (8) at non-singleton dimension 1\"\n                in str(ctx.exception)\n            )\n\n    def test_broadcast_shapes(test_case):\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = flow.broadcast_shapes((2,), (3, 3), (1, 1, 1))\n        test_case.assertTrue(\n            \"input and other can't be broadcasted to a single shape.\"\n            in str(ctx.exception)\n        )\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = flow.broadcast_shapes()\n        test_case.assertTrue(\"shapes should not be empty.\" in str(ctx.exception))\n\n    def test_broadcast_tensors(test_case):\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y, z = flow.broadcast_tensors(flow.ones(2, 3), flow.ones(4, 3))\n        test_case.assertTrue(\n            \"input and other can't be broadcasted to a single shape.\"\n            in str(ctx.exception)\n        )\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = flow.broadcast_tensors()\n        test_case.assertTrue(\"tensors should not be empty.\" in str(ctx.exception))\n\n    def test_broadcast_to(test_case):\n        # see flow.expand, because broadcast_to is an alias of flow.expand\n        pass\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_chunk.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModule(flow.unittest.TestCase):\n    def test_chunk_0_dim_input_exception(test_case):\n        # torch exception and messge:\n        #\n        #   RuntimeError: chunk expects at least a 1-dimensional tensor.\n        #\n        x = flow.tensor(3.14)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = flow.chunk(x, chunks=1, dim=0)\n        test_case.assertTrue(\n            \"chunk expects at least a 1-dimensional tensor\" in str(ctx.exception)\n        )\n\n    def test_chunk_0_chunks_param_exception(test_case):\n        # torch exception and messge:\n        #\n        #   RuntimeError: chunk expects `chunks` to be greater than 0, got: 0\n        #\n        x = flow.tensor([[1, 2, 3], [4, 5, 6]])\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = flow.chunk(x, chunks=0, dim=0)\n        test_case.assertTrue(\n            \"chunk expects `chunks` to be greater than 0, got: \" in str(ctx.exception)\n        )\n\n    def test_chunk_dim_param_exception(test_case):\n        # torch exception and messge:\n        #\n        #   IndexError: Dimension out of range (expected to be in range of [-2, 1], but got -3)\n        #\n        x = flow.tensor([[1, 2, 3], [4, 5, 6]])\n        with test_case.assertRaises(IndexError) as ctx:\n            y = flow.chunk(x, chunks=2, dim=-3)\n        test_case.assertTrue(\n            \"Dimension out of range (expected to be in range of [-2, 1], but got -3)\"\n            in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_cosine_similarity.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCosineSimilarity(flow.unittest.TestCase):\n    def test_cosine_similarity_not_floating_type(test_case):\n        x = flow.randn(2, 5).to(flow.int32)\n        y = flow.randn(2, 5).to(flow.int32)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            out = flow.nn.functional.cosine_similarity(x, y, dim=1)\n        test_case.assertTrue(\n            \"expected common dtype to be floating point, yet common dtype is oneflow.int32\"\n            in str(ctx.exception)\n        )\n\n    def test_cosine_similarity_broadcast(test_case):\n        x = flow.randn(2, 5)\n        y = flow.randn(2, 4)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            out = flow.nn.functional.cosine_similarity(x, y, dim=1)\n        test_case.assertTrue(\n            \"The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1\"\n            in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_deform_conv2d_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestDeformConv(flow.unittest.TestCase):\n    def test_deform_conv2d_invalid_input_sizes(test_case):\n        input = flow.randn(2, 5, 1)\n        weight = flow.randn(2, 5, 1, 1)\n        offset = flow.randn(2, 5, 1, 1)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            out = flow.nn.functional.deform_conv2d(input, offset, weight)\n        test_case.assertTrue(\n            \"The dimension of input tensor weight must be \" in str(ctx.exception)\n        )\n\n    def test_deform_conv2d_invalid_offset_sizes(test_case):\n        input = flow.randn(2, 5, 1, 1)\n        weight = flow.randn(2, 5, 1, 1)\n        offset = flow.randn(2, 5, 1)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            out = flow.nn.functional.deform_conv2d(input, offset, weight)\n        test_case.assertTrue(\n            \"The dimension of offset tensor weight must be \" in str(ctx.exception)\n        )\n\n    def test_deform_conv2d_invalid_weight_sizes(test_case):\n        input = flow.randn(2, 5, 1, 1)\n        weight = flow.randn(2, 5, 5)\n        offset = flow.randn(2, 3, 1, 1)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            out = flow.nn.functional.deform_conv2d(input, offset, weight)\n        test_case.assertTrue(\n            \"The dimension of weight tensor weight must be \" in str(ctx.exception)\n        )\n\n    def test_deform_conv2d_invalid_mask_sizes(test_case):\n        input = flow.randn(2, 5, 1, 1)\n        weight = flow.randn(2, 4, 1, 1)\n        offset = flow.randn(2, 3, 1, 1)\n        mask = flow.randn(2, 3, 1)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            out = flow.nn.functional.deform_conv2d(input, offset, weight, mask=mask)\n        test_case.assertTrue(\n            \"The dimension of mask tensor weight must be\" in str(ctx.exception)\n        )\n\n    def test_deform_conv2d_invalid_dilation_parm(test_case):\n        input = flow.randn(4, 3, 10, 10)\n        weight = flow.randn(5, 3, 3, 3)\n        offset = flow.randn(4, 18, 8, 8)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            out = flow.nn.functional.deform_conv2d(\n                input, offset, weight, dilation=(-1, 0)\n            )\n        test_case.assertTrue(\"The dilation must be greater than\" in str(ctx.exception))\n\n    def test_deform_conv2d_invalid_pad_parm(test_case):\n        input = flow.randn(4, 3, 10, 10)\n        weight = flow.randn(5, 3, 3, 3)\n        offset = flow.randn(4, 18, 8, 8)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            out = flow.nn.functional.deform_conv2d(\n                input, offset, weight, padding=(-1, 0)\n            )\n        test_case.assertTrue(\"The pad must be greater than\" in str(ctx.exception))\n\n    def test_deform_conv2d_invalid_stride_parm(test_case):\n        
input = flow.randn(4, 3, 10, 10)\n        weight = flow.randn(5, 3, 3, 3)\n        offset = flow.randn(4, 18, 8, 8)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            out = flow.nn.functional.deform_conv2d(\n                input, offset, weight, stride=(-1, 0)\n            )\n        test_case.assertTrue(\"The stride must be greater than\" in str(ctx.exception))\n\n    def test_deform_conv2d_invalid_offset_shape(test_case):\n        input = flow.randn(4, 3, 10, 10)\n        weight = flow.randn(5, 3, 3, 3)\n        offset = flow.randn(4, 9, 8, 8)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            out = flow.nn.functional.deform_conv2d(input, offset, weight)\n        test_case.assertTrue(\n            \"The shape of the offset tensor at dimension 1 is not valid\"\n            in str(ctx.exception)\n        )\n\n    def test_deform_conv2d_invalid_batch_size(test_case):\n        input = flow.randn(4, 3, 10, 10)\n        weight = flow.randn(5, 3, 3, 3)\n        offset = flow.randn(3, 18, 8, 8)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            out = flow.nn.functional.deform_conv2d(input, offset, weight)\n        test_case.assertTrue(\"invalid batch size of offset\" in str(ctx.exception))\n\n    def test_deform_conv2d_invalid_mask_shape(test_case):\n        input = flow.randn(4, 3, 10, 10)\n        weight = flow.randn(5, 3, 3, 3)\n        offset = flow.randn(4, 18, 8, 8)\n        mask = flow.randn(4, 1, 8, 8)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            out = flow.nn.functional.deform_conv2d(input, offset, weight, mask=mask)\n        test_case.assertTrue(\"mask.shape[1] is not valid\" in str(ctx.exception))\n\n    def test_deform_conv2d_invalid_output_size(test_case):\n        input = flow.randn(4, 3, 10, 10)\n        weight = flow.randn(5, 3, 3, 3)\n        offset = flow.randn(4, 18, 8, 8)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            out = flow.nn.functional.deform_conv2d(\n                input, offset, weight, dilation=(10, 10)\n            )\n        test_case.assertTrue(\"Calculated output size too small\" in str(ctx.exception))\n\n    def test_deform_conv2d_invalid_offset_output_dims(test_case):\n        input = flow.randn(4, 3, 10, 10)\n        weight = flow.randn(5, 3, 3, 3)\n        offset = flow.randn(4, 18, 8, 8)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            out = flow.nn.functional.deform_conv2d(\n                input, offset, weight, dilation=(2, 2)\n            )\n        test_case.assertTrue(\"invalid offset output dims\" in str(ctx.exception))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_device.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport re\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.nn.functional as F\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestDevice(flow.unittest.TestCase):\n    def test_device_type(test_case):\n        with test_case.assertRaises(RuntimeError) as exp:\n            flow.device(\"xpu\")\n        test_case.assertTrue(\n            re.match(\n                \"Expected one of (.*) device type at start of device string: xpu\",\n                str(exp.exception),\n            )\n            is not None\n        )\n\n    def test_device_index(test_case):\n        # TODO(hjchen2): throw runtime error if cuda reports error\n        #     with test_case.assertRaises(RuntimeError) as exp:\n        #         device = flow.device(\"cuda:1000\")\n        #         flow.Tensor(2, 3).to(device=device)\n        #     test_case.assertTrue(\"CUDA error: invalid device ordinal\" in str(exp.exception))\n        pass\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_dot.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestDot(flow.unittest.TestCase):\n    def test_dot_shape_error_msg(test_case):\n        with test_case.assertRaises(RuntimeError) as exp:\n            a = flow.tensor([2, 3])\n            b = flow.tensor([2, 3, 4])\n            flow.dot(a, b)\n        test_case.assertTrue(\"inconsistent tensor size\" in str(exp.exception))\n\n    def test_dot_dims_error_msg(test_case):\n        with test_case.assertRaises(RuntimeError) as exp:\n            a = flow.tensor([[2, 3], [3, 4]])\n            flow.dot(a, a)\n        test_case.assertTrue(\"1D tensors expected\" in str(exp.exception))\n\n    def test_dot_dtype_error_msg(test_case):\n        with test_case.assertRaises(RuntimeError) as exp:\n            a = flow.tensor([2, 3], dtype=flow.int64)\n            b = flow.tensor([2, 3], dtype=flow.float32)\n            flow.dot(a, b)\n        test_case.assertTrue(\n            \"expected both vectors to have same dtype\" in str(exp.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_error_reported_in_thread.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport subprocess\nimport sys\nimport tempfile\nimport os\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\ndef test_error_reported_in_thread():\n    for env_name in [\"ONEFLOW_DEBUG\", \"ONEFLOW_PYTHON_STACK_GETTER\"]:\n        env = os.environ.copy()\n        env[env_name] = \"1\"\n        # Run a new process to capture the error output\n        p = subprocess.run(\n            [sys.executable, \"throw_error.py\"],\n            capture_output=True,\n            cwd=os.path.dirname(os.path.realpath(__file__)),\n            env=env,\n        )\n        assert p.returncode != 0\n        error_msg = p.stderr.decode(\"utf-8\")\n        print(error_msg)\n        assert (\n            \"\"\"File \"throw_error.py\", line 19, in g\n    flow._C.throw_error(x)\n  File \"throw_error.py\", line 23, in f\n    g(x)\n  File \"throw_error.py\", line 26, in <module>\n    f(x)\"\"\"\n            in error_msg\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\ndef test_python_stack_getter_disabled():\n    # Run a new process to capture the error output\n    p = subprocess.run(\n        [sys.executable, \"throw_error.py\"],\n        capture_output=True,\n        cwd=os.path.dirname(os.path.realpath(__file__)),\n    )\n    assert p.returncode != 0\n    error_msg = p.stderr.decode(\"utf-8\")\n    assert \"No Python stack available.\" in error_msg\n    assert \"ONEFLOW_DEBUG\" in error_msg\n    assert \"ONEFLOW_PYTHON_STACK_GETTER\" in error_msg\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_gird_sample_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow.unittest\nimport oneflow.nn\nimport oneflow as flow\nfrom oneflow.test_utils.test_util import GenArgList\nimport numpy as np\nfrom collections import OrderedDict\n\narg_dict = OrderedDict()\narg_dict[\"N\"] = [3, 4, 5]\narg_dict[\"C\"] = [4, 5, 6]\narg_dict[\"D_in\"] = [8, 11, 13]\narg_dict[\"H_in\"] = [5, 6, 7]\narg_dict[\"W_in\"] = [7, 8, 9]\narg_dict[\"D_out\"] = [13, 15, 17]\narg_dict[\"H_out\"] = [9, 10, 11]\narg_dict[\"W_out\"] = [11, 12, 13]\n\n\ndef _test_dimention_error_msg_impl(test_case, N, C, H_in, H_out):\n    inputval = oneflow.ones(N, C, H_in,)\n    grid = oneflow.ones(N, H_out, 1)\n    with test_case.assertRaises(RuntimeError) as ctx:\n        flow.nn.functional.grid_sample(\n            inputval, grid, mode=\"bilinear\", padding_mode=\"zeros\"\n        )\n    test_case.assertTrue(\"MUST be 4D or 5D input\" in str(ctx.exception))\n\n\ndef _test_4d_gird_shape_error_msg_impl(test_case, N, C, H_in, W_in, H_out, W_out):\n    inputval = oneflow.ones(N, C, H_in, W_in)\n    grid = oneflow.ones(N, H_out, W_out, 1)\n    with test_case.assertRaises(RuntimeError) as ctx:\n        flow.nn.functional.grid_sample(\n            inputval, grid, mode=\"bilinear\", padding_mode=\"zeros\"\n        )\n    test_case.assertTrue(\"Grid shape MUST (N, H_out, W_out, 2)\" in str(ctx.exception))\n\n\ndef _test_4d_grid_input_not_same_shape_error_msg_impl(\n    test_case, N, C, H_in, W_in, H_out, W_out\n):\n    inputval = oneflow.ones(N, C, H_in, W_in)\n    grid = oneflow.ones(N, H_out, W_out)\n    with test_case.assertRaises(RuntimeError) as ctx:\n        flow.nn.functional.grid_sample(\n            inputval, grid, mode=\"bilinear\", padding_mode=\"zeros\"\n        )\n    test_case.assertTrue(\n        \"Grid and input MUST have same dimention\" in str(ctx.exception)\n    )\n\n\ndef _test_5d_gird_shape_error_msg_impl(\n    test_case, N, C, D_in, H_in, W_in, D_out, H_out, W_out\n):\n    inputval = oneflow.ones(N, C, D_in, H_in, W_in)\n    grid = oneflow.ones(N, D_out, H_out, W_out, 2)\n    with test_case.assertRaises(RuntimeError) as ctx:\n        flow.nn.functional.grid_sample(\n            inputval, grid, mode=\"bilinear\", padding_mode=\"zeros\"\n        )\n    test_case.assertTrue(\"Grid shape MUST (N, H_out, W_out, 3)\" in str(ctx.exception))\n\n\ndef _test_5d_grid_input_not_same_shape_error_msg_impl(\n    test_case, N, C, D_in, H_in, W_in, D_out, H_out, W_out\n):\n    inputval = oneflow.ones(N, C, D_in, H_in, W_in)\n    grid = oneflow.ones(N, D_out, H_out, W_out)\n    with test_case.assertRaises(RuntimeError) as ctx:\n        flow.nn.functional.grid_sample(\n            inputval, grid, mode=\"bilinear\", padding_mode=\"zeros\"\n        )\n    test_case.assertTrue(\n        \"Grid and input MUST have same dimention\" in str(ctx.exception)\n    )\n\n\nclass TestGridSample(flow.unittest.TestCase):\n    def test_dimention_error_msg(test_case):\n      
for arg in GenArgList(arg_dict):\n            _test_dimension_error_msg_impl(test_case, arg[0], arg[1], arg[3], arg[6])\n\n    def test_4d_grid_shape_error_msg(test_case):\n        for arg in GenArgList(arg_dict):\n            _test_4d_grid_shape_error_msg_impl(\n                test_case, arg[0], arg[1], arg[3], arg[4], arg[6], arg[7]\n            )\n\n    def test_4d_grid_input_not_same_shape_error_msg(test_case):\n        for arg in GenArgList(arg_dict):\n            _test_4d_grid_input_not_same_shape_error_msg_impl(\n                test_case, arg[0], arg[1], arg[3], arg[4], arg[6], arg[7]\n            )\n\n    def test_5d_grid_shape_error_msg(test_case):\n        for arg in GenArgList(arg_dict):\n            _test_5d_grid_shape_error_msg_impl(test_case, *arg)\n\n    def test_5d_grid_input_not_same_shape_error_msg(test_case):\n        for arg in GenArgList(arg_dict):\n            _test_5d_grid_input_not_same_shape_error_msg_impl(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_global_branch_error_local_to_global_with_broadcast_sbp_1n2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport os\nimport numpy as np\nimport time\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n2d()\nclass TestLocalToGlobalBranchError(flow.unittest.TestCase):\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_global_branch_error_with_local_to_global(test_case):\n        try:\n            os.environ[\"ONEFLOW_TIMEOUT_SECONDS\"] = \"2\"\n            data = flow.rand(2, dtype=flow.float32)\n            placement = flow.placement(type=\"cpu\", ranks=[0, 1])\n            sbp = flow.sbp.broadcast\n            if flow.env.get_rank() == 0:\n                global_data = data.to_global(placement=placement, sbp=sbp)\n            else:\n                time.sleep(2)\n\n        except Exception as e:\n            err_msg = \"Maybe executing different code in different ranks, please check if the code is branched and operates on the global tensor\"\n            assert err_msg in str(e)\n        finally:\n            os.environ[\"ONEFLOW_TIMEOUT_SECONDS\"] = \"300\"\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_global_branch_error_local_to_global_with_broadcast_sbp_1n4d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport os\nimport numpy as np\nimport time\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestLocalToGlobalBranchError(flow.unittest.TestCase):\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_global_branch_error_with_local_to_global(test_case):\n        try:\n            os.environ[\"ONEFLOW_TIMEOUT_SECONDS\"] = \"2\"\n            data = flow.rand(2, dtype=flow.float32)\n            placement = flow.placement(type=\"cpu\", ranks=[0, 1])\n            sbp = flow.sbp.broadcast\n            if flow.env.get_rank() == 0:\n                global_data = data.to_global(placement=placement, sbp=sbp)\n            else:\n                time.sleep(2)\n\n        except Exception as e:\n            err_msg = \"Maybe executing different code in different ranks, please check if the code is branched and operates on the global tensor\"\n            assert err_msg in str(e)\n        finally:\n            os.environ[\"ONEFLOW_TIMEOUT_SECONDS\"] = \"300\"\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_global_branch_error_local_to_global_with_split_sbp.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport os\nimport numpy as np\nimport time\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n2d()\nclass TestLocalToGlobalBranchError(flow.unittest.TestCase):\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_global_branch_error_with_local_to_global(test_case):\n        try:\n            os.environ[\"ONEFLOW_TIMEOUT_SECONDS\"] = \"2\"\n            data = flow.rand(2, dtype=flow.float32)\n            placement = flow.placement.all(\"cuda\")\n            sbp = flow.sbp.split(0)\n            if flow.env.get_rank() == 0:\n                global_data = data.to_global(placement=placement, sbp=sbp)\n            else:\n                time.sleep(2)\n\n        except Exception as e:\n            err_msg = \"Maybe executing different code in different ranks, please check if the code is branched and operates on the global tensor\"\n            assert err_msg in str(e)\n        finally:\n            os.environ[\"ONEFLOW_TIMEOUT_SECONDS\"] = \"300\"\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_global_branch_error_with_global_mean.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport os\nimport numpy as np\nimport time\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n2d()\nclass TestGlobalMeanBranchError(flow.unittest.TestCase):\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_global_branch_error_global_data_mean(test_case):\n        try:\n            os.environ[\"ONEFLOW_TIMEOUT_SECONDS\"] = \"2\"\n            data = flow.rand(2, dtype=flow.float32)\n            placement = flow.placement.all(\"cuda\")\n            sbp = flow.sbp.split(0)\n            global_data = data.to_global(placement=placement, sbp=sbp)\n            if flow.env.get_rank() == 0:\n                print(data.mean())\n                print(global_data.mean())\n            else:\n                time.sleep(2)\n\n        except Exception as e:\n            err_msg = \"Maybe executing different code in different ranks, please check if the code is branched and operates on the global tensor\"\n            assert err_msg in str(e)\n        finally:\n            os.environ[\"ONEFLOW_TIMEOUT_SECONDS\"] = \"300\"\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_hann_window.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestHannWindow(flow.unittest.TestCase):\n    def test_hann_window_dtype_not_support(test_case):\n        window_length = 8\n        dtype = flow.int64\n        with test_case.assertRaises(RuntimeError) as ctx:\n            x = flow.hann_window(window_length, dtype=dtype)\n        test_case.assertTrue(\n            \"hann_window expects floating point dtypes, got: \" in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_in_top_k.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow.unittest\nimport oneflow as flow\nimport numpy as np\n\n\nclass TestInTopK(flow.unittest.TestCase):\n    def test_in_top_k_error_msg(test_case):\n        arr = np.array([1, 1])\n        targets = flow.Tensor(arr)\n        targets = flow.cast(targets, flow.float)\n        arr = np.array([[0.8, 0.6, 0.3], [0.1, 0.6, 0.4]])\n        predictions = flow.Tensor(arr)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            flow.in_top_k(targets, predictions, 1)\n        test_case.assertTrue(\n            \"targets data type must be index type\" in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_inv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestInv(flow.unittest.TestCase):\n    def test_inv_exception_dim_short(test_case):\n        x = flow.tensor((2, 2))\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = flow.linalg.inv(x)\n        test_case.assertTrue(\n            \"linalg.inv: The input tensor must be at least 2 dimensions.\"\n            in str(ctx.exception)\n        )\n\n    def test_inv_exception_not_square_matrix(test_case):\n        x = flow.randn(2, 3, 2)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = flow.linalg.inv(x)\n        test_case.assertTrue(\n            \"RuntimeError: linalg.inv: A must be batches of square matrices, but they are 3 by 2 matrices\"\n            in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_layernorm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLayerNormModule(flow.unittest.TestCase):\n    def test_layernorm_exception_input_shape_not_match(test_case):\n        x = flow.randn(2, 3)\n        m = flow.nn.LayerNorm(2)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = m(x)\n        test_case.assertTrue(\n            \"Given normalized_shape=(2,), expected input with shape [*, 2,], but got input of size oneflow.Size([2, 3])\"\n            in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_linalg.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestLinalgCross(flow.unittest.TestCase):\n    def test_cross_has_no_3_error(test_case):\n        a = flow.randn(4, 2)\n        b = flow.randn(4, 2)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            flow.cross(a, b)\n            test_case.assertTrue(\n                \"RuntimeError: no dimension of size 3 in input.\" in str(ctx.exception)\n            )\n\n    def test_linalg_cross_has_no_3_error(test_case):\n        a = flow.randn(4, 2)\n        b = flow.randn(4, 2)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            flow.linalg.cross(a, b)\n            test_case.assertTrue(\n                \"RuntimeError: the size of the specified dimension(which is -1) is not 3.\"\n                in str(ctx.exception)\n            )\n\n    def test_linalg_cross_broadcast_error(test_case):\n        a = flow.randn(4)\n        b = flow.randn(4, 2)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            flow.linalg.cross(a, b)\n            test_case.assertTrue(\n                \"RuntimeError: input and other can't be broadcasted to a single shape. [input's shape: (1,4), other's shape: (4,2)].\"\n                in str(ctx.exception)\n            )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_local_global_convert_error.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestModule(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_get_sbp_with_invalid_axis(test_case):\n        with test_case.assertRaises(RuntimeError) as ctx:\n            sbp = flow.sbp.split(-1)\n        test_case.assertTrue(\n            \"Split axis must not be negative, but got -1!\" in str(ctx.exception)\n        )\n\n        with test_case.assertRaises(RuntimeError) as ctx:\n            sbp = flow.sbp.split(7)\n        test_case.assertTrue(\n            \"Expected split axis to be less than the supported maximum axis (6), but got 7!\"\n            in str(ctx.exception)\n        )\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_local_to_global_with_invalid_split_axis(test_case):\n        x = flow.tensor([1, 2, 3, 4])\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = x.to_global(placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.split(1))\n        test_case.assertTrue(\n            \"Split axis out of range (expected to be in range of [0, 1), but got 1!\"\n            in str(ctx.exception)\n        )\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_global_to_global_with_invalid_split_axis(test_case):\n        x = flow.tensor(\n            [1, 2, 3, 4], placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.broadcast,\n        )\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = x.to_global(sbp=flow.sbp.split(1))\n        test_case.assertTrue(\n            \"Split axis out of range (expected to be in range of [0, 1), but got 1!\"\n            in str(ctx.exception)\n        )\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_call_to_local_for_local_tensor(test_case):\n        x = flow.tensor([1, 2, 3, 4])\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = x.to_local()\n        test_case.assertTrue(\n            \"Expected global tensor for to_local but got local tensor!\"\n            in str(ctx.exception)\n        )\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_local_to_global_with_invalid_size(test_case):\n        if flow.env.get_rank() == 0:\n            x = flow.Tensor(2, 4)  # size(2, 4)\n        else:\n            x = flow.Tensor(4, 4)  # size(4, 4)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = x.to_global(placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.split(0))\n        test_case.assertTrue(\n            \"Sizes of tensors in dimension 0 must be same or match balanced split distribution. 
\"\n            \"See https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/common/balanced_splitter.h \"\n            \"for details of balanced split\" in str(ctx.exception)\n        )\n\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = x.to_global(placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.split(1))\n        test_case.assertTrue(\n            \"Sizes of tensors must match except in dimension 1. Expected size 2 but got size 4 for tensor on rank 1!\"\n            in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_median.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMedian(flow.unittest.TestCase):\n    def test_median_exception_dim_out_of_range(test_case):\n        x = flow.tensor((2, 2))\n        with test_case.assertRaises(IndexError) as ctx:\n            y = flow.median(x, 1)\n        test_case.assertTrue(\n            \"Dimension out of range (expected to be in range of [-1, 0], but got 1)\"\n            in str(ctx.exception)\n        )\n\n    def test_median_exception_reduce_0dim(test_case):\n        x = flow.randn(2, 0, 2)\n        with test_case.assertRaises(IndexError) as ctx:\n            y = flow.median(x, 1)\n        test_case.assertTrue(\n            \"IndexError: Expected reduction dim 1 to have non-zero size.\"\n            in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_mm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.nn.functional as F\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMm(flow.unittest.TestCase):\n    def test_mm_not_2dim(test_case):\n        with test_case.assertRaises(Exception) as exp:\n            mat1 = flow.randn(2, 3, 3)\n            mat2 = flow.randn(3, 3)\n            out = flow.mm(mat1, mat2)\n        test_case.assertTrue(\"self must be a matrix\" in str(exp.exception))\n        with test_case.assertRaises(Exception) as exp:\n            mat1 = flow.randn(2, 3)\n            mat2 = flow.randn(3, 3, 2)\n            out = flow.mm(mat1, mat2)\n        test_case.assertTrue(\"mat2 must be a matrix\" in str(exp.exception))\n\n    def test_mm_dim_not_match(test_case):\n        with test_case.assertRaises(Exception) as exp:\n            mat1 = flow.randn(2, 3)\n            mat2 = flow.randn(4, 3)\n            out = flow.mm(mat1, mat2)\n        test_case.assertTrue(\n            \"mat1 and mat2 shapes cannot be multiplied (2x3 and 4x3)\"\n            in str(exp.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_mode.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMode(flow.unittest.TestCase):\n    def test_mode_exception_dim_out_of_range(test_case):\n        x = flow.tensor((2, 2))\n        with test_case.assertRaises(IndexError) as ctx:\n            y = flow.mode(x, 1)\n        test_case.assertTrue(\n            \"Dimension out of range (expected to be in range of [-1, 0], but got 1)\"\n            in str(ctx.exception)\n        )\n\n    def test_mode_exception_reduce_0dim(test_case):\n        x = flow.randn(2, 0, 2)\n        with test_case.assertRaises(IndexError) as ctx:\n            y = flow.mode(x, 1)\n        test_case.assertTrue(\n            \"IndexError: Expected reduction dim 1 to have non-zero size.\"\n            in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_multi_input_with_diff_device_or_placement.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestModule(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_multi_input_with_diff_device(test_case):\n        # torch exception and messge:\n        #\n        #   RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!\n        #\n        x = flow.tensor([1, 2, 3, 4])\n        y = flow.tensor([2, 4, 6, 8], device=\"cuda\")\n        with test_case.assertRaises(RuntimeError) as ctx:\n            z = flow.add(x, y)\n        test_case.assertTrue(\n            \"Expected all tensors to be on the same device, but found at least two devices\"\n            in str(ctx.exception)\n        )\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_multi_input_with_diff_placement(test_case):\n        x = flow.tensor(\n            [1, 2, 3, 4], placement=flow.placement(\"cuda\", [0]), sbp=flow.sbp.broadcast\n        )\n        y = flow.tensor(\n            [2, 4, 6, 8], placement=flow.placement(\"cuda\", [1]), sbp=flow.sbp.broadcast\n        )\n        with test_case.assertRaises(RuntimeError) as ctx:\n            z = flow.add(x, y)\n        test_case.assertTrue(\n            \"Expected all tensors to be on the same placement, but found at least two placements\"\n            in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_mv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMv(flow.unittest.TestCase):\n    def test_mv_not_matrix(test_case):\n        with test_case.assertRaises(Exception) as exp:\n            mat = flow.randn(2, 3, 3)\n            vec = flow.randn(3)\n            out = flow.mv(mat, vec)\n        test_case.assertTrue(\n            \"vector + matrix @ vector expected, got 1, 3, 1\" in str(exp.exception)\n        )\n\n    def test_mv_not_vector(test_case):\n        with test_case.assertRaises(Exception) as exp:\n            mat = flow.randn(2, 3)\n            vec = flow.randn(3, 1)\n            out = flow.mv(mat, vec)\n        test_case.assertTrue(\n            \"vector + matrix @ vector expected, got 1, 2, 2\" in str(exp.exception)\n        )\n\n    def test_mv_size_mismatch(test_case):\n        with test_case.assertRaises(Exception) as exp:\n            mat = flow.randn(2, 3)\n            vec = flow.randn(4)\n            out = flow.mv(mat, vec)\n        test_case.assertTrue(\"size mismatch\" in str(exp.exception))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_nn_functor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport re\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\nclass TestBiasAddError(flow.unittest.TestCase):\n    def test_bias_add_dimension_match_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((4, 4), dtype=flow.float32)\n            bias = flow.ones((5,), dtype=flow.float32)\n            out = flow._C.bias_add(x, bias, axis=1)\n\n        test_case.assertTrue(\n            \"The size of tensor x (4,4) must match the size of tensor b (5,) at dimension 1\"\n            in str(ctx.exception)\n        )\n\n    def test_bias_add_index_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((4, 4), dtype=flow.float32)\n            bias = flow.ones((5,), dtype=flow.float32)\n            out = flow._C.bias_add(x, bias, axis=3)\n\n        test_case.assertTrue(\n            \"Dimension out of range (expected to be in range of [-2,1], but got 3)\"\n            in str(ctx.exception)\n        )\n\n\nclass TestCrossEntropyError(flow.unittest.TestCase):\n    def test_cross_entropy_reduction_type_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((4, 4), dtype=flow.float32)\n            target = flow.ones((4, 4), dtype=flow.float32)\n            out = flow._C.cross_entropy(x, target, None, 0, \"just_test\")\n\n        test_case.assertTrue(\n            \"Reduction should be none, sum or mean.\" in str(ctx.exception)\n        )\n\n\nclass TestCTCLossError(flow.unittest.TestCase):\n    def test_ctcloss_reduction_type_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((5, 2, 3), dtype=flow.float32)\n            targets = flow.tensor([[1, 2, 2], [1, 2, 2]], dtype=flow.int32)\n            input_lengths = flow.tensor([5, 5], dtype=flow.int32)\n            target_lengths = flow.tensor([3, 3], dtype=flow.int32)\n            max_target_length = 0\n            if targets.ndim == 1:\n                max_target_length = target_lengths.max().item()\n            elif targets.ndim == 2:\n                max_target_length = targets.shape[1]\n            loss = flow._C.ctc_loss(\n                x,\n                targets,\n                input_lengths,\n                target_lengths,\n                max_target_length,\n                blank=0,\n                zero_infinity=False,\n                reduction=\"just_test\",\n            )\n        test_case.assertTrue(\n            \"Reduction should be none, sum or mean.\" in str(ctx.exception)\n        )\n\n\nclass TestPadError(flow.unittest.TestCase):\n    def test_pad_size_attribute_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((1, 1), dtype=flow.float32)\n            out = flow._C.pad(x, (1, 1, 1, 1, 1))\n        
test_case.assertTrue(\n            \"Pad size should less than or equal to input axes * 2.\"\n            in str(ctx.exception)\n        )\n\n    def test_pad_size_mod2_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((1, 1), dtype=flow.float32)\n            out = flow._C.pad(x, (1, 1, 1,))\n\n        test_case.assertTrue(\n            \"Length of pad must be even but instead it equals 3\" in str(ctx.exception)\n        )\n\n    def test_reflect_pad_size_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((1, 1, 2, 2), dtype=flow.float32)\n            out = flow._C.pad(x, (4, 4, 4, 4), mode=\"reflect\")\n\n        test_case.assertTrue(\n            \"Padding size should be less than the corresponding input dimension, but got:\"\n            in str(ctx.exception)\n        )\n\n    def test_pad_mode_error(test_case):\n        with test_case.assertRaises(NotImplementedError) as ctx:\n            x = flow.ones((1, 1, 2, 2), dtype=flow.float32)\n            out = flow._C.pad(x, (4, 4, 4, 4), mode=\"test\")\n\n        test_case.assertTrue(\n            \"Pad mode is test, but only constant, reflect and replicate are valid.\"\n            in str(ctx.exception)\n        )\n\n\nclass TestFusedMLPError(flow.unittest.TestCase):\n    def test_fuse_mlp_weight_size_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((4, 4), dtype=flow.float32)\n            bias = flow.ones((4,), dtype=flow.float32)\n            out = flow._C.fused_mlp(x, [], [bias], False)\n\n        test_case.assertTrue(\n            \"The number of weights should be greater equal than 1\" in str(ctx.exception)\n        )\n\n    def test_fuse_mlp_weight_bias_size_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((4, 4), dtype=flow.float32)\n            w1 = flow.ones((4, 4), dtype=flow.float32)\n            w2 = flow.ones((4, 4), dtype=flow.float32)\n            bias1 = flow.ones((4,), dtype=flow.float32)\n            out = flow._C.fused_mlp(x, [w1, w2], [bias1], False)\n\n        test_case.assertTrue(\n            \"The number of weights should be equal to biases\" in str(ctx.exception)\n        )\n\n    def test_fuse_mlp_weight_numaxes_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((4, 4), dtype=flow.float32)\n            w1 = flow.ones((4,), dtype=flow.float32)\n            bias1 = flow.ones((4,), dtype=flow.float32)\n            out = flow._C.fused_mlp(x, [w1,], [bias1,], False)\n        test_case.assertTrue(\"Weight's dim size should == 2\" in str(ctx.exception))\n\n    def test_fuse_mlp_bias_numaxes_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((4, 4), dtype=flow.float32)\n            w1 = flow.ones((4, 4), dtype=flow.float32)\n            bias1 = flow.ones((4, 4), dtype=flow.float32)\n            out = flow._C.fused_mlp(x, [w1,], [bias1,], False)\n        test_case.assertTrue(\"Bias's dim size should == 1\" in str(ctx.exception))\n\n    def test_fuse_mlp_bias_first_dim_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((4, 4), dtype=flow.float32)\n            w1 = flow.ones((6, 4), dtype=flow.float32)\n            bias1 = flow.ones((5), dtype=flow.float32)\n            out = flow._C.fused_mlp(x, [w1,], [bias1,], False)\n\n        test_case.assertTrue(\n            \"Bias's 
dim is not equal to weight's first dim.\" in str(ctx.exception)\n        )\n\n    def test_fuse_mlp_weight_second_dim_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((2, 4), dtype=flow.float32)\n            w1 = flow.ones((3, 6), dtype=flow.float32)\n            bias1 = flow.ones((3), dtype=flow.float32)\n            out = flow._C.fused_mlp(x, [w1,], [bias1,], False)\n\n        test_case.assertTrue(\n            \"weight's second dim should be equal to input's second dim.\"\n            in str(ctx.exception)\n        )\n\n\nclass TestL2NormalizeError(flow.unittest.TestCase):\n    def test_l2normalize_axis_error1(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((3, 3), dtype=flow.float32)\n            out = flow._C.normalize(x, dim=3, use_l2_norm_kernel=True)\n        test_case.assertTrue(\"Axis should < 2 but axis is 3 now.\" in str(ctx.exception))\n\n    def test_l2normalize_axis_error2(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((3, 3), dtype=flow.float32)\n            out = flow._C.normalize(x, dim=-3, use_l2_norm_kernel=True)\n        test_case.assertTrue(\n            \"Axis should >=0 but axis is -1 now.\" in str(ctx.exception)\n        )\n\n\nclass TestLossBaseFunctorError(flow.unittest.TestCase):\n    def test_loss_base_reduction_type_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((4, 4), dtype=flow.float32)\n            target = flow.ones((4, 4), dtype=flow.float32)\n            out = flow._C.mse_loss(x, target, \"just_test\")\n\n        test_case.assertTrue(\n            \"Reduction should be none, sum or mean.\" in str(ctx.exception)\n        )\n\n\nclass TestMatmulError(flow.unittest.TestCase):\n    def test_matmul_dimension_error1(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((), dtype=flow.float32)\n            w = flow.ones((4, 4), dtype=flow.float32)\n            out = flow._C.matmul(x, w, False, False, 1.0)\n        test_case.assertTrue(\"Tensor a's dim should >= 1\" in str(ctx.exception))\n\n    def test_matmul_dimension_error2(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((4, 4), dtype=flow.float32)\n            w = flow.ones((), dtype=flow.float32)\n            out = flow._C.matmul(x, w, False, False, 1.0)\n        test_case.assertTrue(\"Tensor b's dim should >= 1\" in str(ctx.exception))\n\n\nclass TestPixelShuffleError(flow.unittest.TestCase):\n    def test_pixel_shuffle_4D_input_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((1, 8, 4, 4, 1), dtype=flow.float32)\n            out = flow._C.pixel_shuffle(x, 2, 2)\n\n        test_case.assertTrue(\"Only Accept 4D Tensor\" in str(ctx.exception))\n\n    def test_pixel_shuffle_channel_divisble_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((1, 8, 4, 4), dtype=flow.float32)\n            out = flow._C.pixel_shuffle(x, 2, 3)\n\n        test_case.assertTrue(\n            \"The channels of input tensor must be divisible by (upscale_factor * upscale_factor) or (h_upscale_factor * w_upscale_factor)\"\n            in str(ctx.exception)\n        )\n\n\nclass TestTripletMarginLossError(flow.unittest.TestCase):\n    def test_triplet_margin_loss_reduce_type_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n    
        anchor = flow.ones((3, 3), dtype=flow.float32)\n            positive = flow.ones((3, 3), dtype=flow.float32)\n            negative = flow.ones((3, 3), dtype=flow.float32)\n\n            triplet_loss = flow._C.triplet_margin_loss(\n                anchor,\n                positive,\n                negative,\n                margin=0.001,\n                p=2,\n                eps=1e-5,\n                swap=False,\n                reduction=\"just_test\",\n            )\n\n        test_case.assertTrue(\n            \"Reduction should be none, sum or mean.\" in str(ctx.exception)\n        )\n\n\nclass TestNormalError(flow.unittest.TestCase):\n    def test_normal_data_type_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow._C.normal(mean=0.0, std=1.0, size=(3, 3), dtype=flow.int32)\n\n        test_case.assertTrue(\n            \"Only support float and double in normal().\" in str(ctx.exception)\n        )\n\n    def test_normal_out_tensor_data_type_error(test_case):\n        with test_case.assertRaises(RuntimeError) as ctx:\n            out = flow.zeros((3, 3), dtype=flow.float64)\n            x = flow._C.normal(\n                mean=0.0, std=1.0, size=(3, 3), dtype=flow.float32, out=out\n            )\n\n        test_case.assertTrue(\n            \"data type oneflow.float32 does not match data type of out parameter oneflow.float64\"\n            in str(ctx.exception)\n        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_normal_out_tensor_device_type_error(test_case):\n        with test_case.assertRaises(RuntimeError) as ctx:\n            out = flow.zeros((3, 3), dtype=flow.float32, device=\"cuda\")\n            x = flow._C.normal(\n                mean=0.0,\n                std=1.0,\n                size=(3, 3),\n                dtype=flow.float32,\n                out=out,\n                device=\"cpu\",\n            )\n\n        test_case.assertTrue(\n            \"does not match device type of out parameter\" in str(ctx.exception)\n        )\n\n\nclass TestNormalizationError(flow.unittest.TestCase):\n    def test_normalization_moving_mean_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((1, 4, 2, 2), dtype=flow.float32)\n            moving_mean = flow.ones((4,), dtype=flow.float32)\n            weight = flow.ones((4,), dtype=flow.float32)\n            bias = flow.ones((4,), dtype=flow.float32)\n\n            out = flow._C.normalization(\n                x, moving_mean, None, weight, bias, 1, 1e-5, 0.9, False\n            )\n\n        test_case.assertTrue(\n            \"Both moving_mean and moving_variance should be None or Tensor.\"\n            in str(ctx.exception)\n        )\n\n    def test_normalization_x_input_axes_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((1,), dtype=flow.float32)\n            weight = flow.ones((4,), dtype=flow.float32)\n            bias = flow.ones((4,), dtype=flow.float32)\n\n            out = flow._C.normalization(\n                x, None, None, weight, bias, 1, 1e-5, 0.9, False\n            )\n\n        test_case.assertTrue(\n            \"NumAxes of x should be greater or equal than 2.\" in str(ctx.exception)\n        )\n\n    def test_normalization_eval_need_moving_statistic_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((1, 2,), dtype=flow.float32)\n            weight = 
flow.ones((2,), dtype=flow.float32)\n            bias = flow.ones((2,), dtype=flow.float32)\n\n            out = flow._C.normalization(\n                x, None, None, weight, bias, 1, 1e-5, 0.9, False\n            )\n\n        test_case.assertTrue(\n            \"Must have moving_mean and moving_variance in eval mode.\"\n            in str(ctx.exception)\n        )\n\n\nclass TestOnehotError(flow.unittest.TestCase):\n    def test_onehot_error(test_case):\n        with test_case.assertRaises(Exception) as ctx:\n            x = flow.ones((3, 3), dtype=flow.float32)\n            out = flow._C.one_hot(x, 3, 0.9, 0)\n\n        test_case.assertTrue(\n            \"one_hot is only applicable to index tensor.\" in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_optim_add_param_group.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSgdAddParamGroup(flow.unittest.TestCase):\n    def test_sgd_add_param_group_not_unique(test_case):\n        with test_case.assertRaises(Exception) as exp:\n            w1 = flow.ones(3, 3)\n            w1.requires_grad = True\n            w2 = flow.ones(3, 3)\n            w2.requires_grad = True\n            o = flow.optim.SGD([w1])\n            o.add_param_group({\"params\": w2})\n            o.add_param_group({\"params\": w2})\n        print(str(exp.exception))\n        test_case.assertTrue(\n            \"some parameters appear in more than one parameter group\"\n            in str(exp.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_pad.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.nn.functional as F\nimport torch\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestPad(flow.unittest.TestCase):\n    def test_torch_type(test_case):\n        with test_case.assertRaises(TypeError) as exp:\n            F.pad(torch.randn(2, 2))\n        test_case.assertTrue(\n            \"pad() missing 1 required positional argument: 'pad'\" in str(exp.exception)\n        )\n\n    def test_numpy_type(test_case):\n        import numpy as np\n\n        with test_case.assertRaises(TypeError) as exp:\n            F.pad(np.random.randn(2, 2))\n        test_case.assertTrue(\n            \"pad() missing 1 required positional argument: 'pad'\" in str(exp.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_placement.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestPlacement(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n2d()\n    def test_inconsistent_placement(test_case):\n        x = flow.randn(2, 3)\n        if flow.env.get_rank() == 0:\n            placement = flow.placement(\"cpu\", [0, 1])\n        else:\n            placement = flow.placement(\"cpu\", [0])\n        sbp = flow.sbp.split(1)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            x_global = x.to_global(placement=placement, sbp=sbp)\n        test_case.assertTrue(\"Inconsistent parallel description\" in str(ctx.exception))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_randperm_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestRandpermOp(flow.unittest.TestCase):\n    def test_randperm_n_value_err_mes(test_case):\n        with test_case.assertRaises(RuntimeError) as ctx:\n            a = flow.randperm(-1)\n        test_case.assertTrue(\n            \"Trying to create tensor with negative dimension\" in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_reduce_like_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestReduceSumLikeOps(flow.unittest.TestCase):\n    def test_reduce_sum_like_empty_axis_case_err(test_case):\n        a = flow.tensor([1, 1])\n        b = flow.tensor([1, 1, 1])\n        with test_case.assertRaises(RuntimeError) as ctx:\n            flow._C.reduce_sum_like(a, b, [])\n        test_case.assertTrue(\n            \"The shape of the x tensor must be consistent to the shape of the like tensor\"\n            in str(ctx.exception)\n        )\n\n    def test_reduce_sum_like_type_err(test_case):\n        a = flow.tensor([1, 1], dtype=flow.int64)\n        b = flow.tensor([1, 1], dtype=flow.float64)\n        with test_case.assertRaises(TypeError) as ctx:\n            flow._C.reduce_sum_like(a, b, [1])\n        test_case.assertTrue(\n            \"Tensors x and like must have the same type\" in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_reduce_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestReduceOps(flow.unittest.TestCase):\n    def test_exception_dim_out_of_int_range(test_case):\n        x = flow.randn(2, 3, 4)\n        with test_case.assertRaises(IndexError) as exp:\n            flow.sum(x, 3)\n        test_case.assertTrue(\"Dimension out of range\" in str(exp.exception))\n\n    def test_exception_dim_out_of_list_range(test_case):\n        x = flow.randn(2, 3, 4)\n        with test_case.assertRaises(IndexError) as exp:\n            flow.sum(x, [-4])\n        test_case.assertTrue(\"Dimension out of range\" in str(exp.exception))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_repeat_interleave.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.nn.functional as F\nimport torch\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRepeatInterleave(flow.unittest.TestCase):\n    def test_repeat_interleave_index_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor([[1, 2], [3, 4]])\n            y = flow.repeat_interleave(x, 3, dim=4)\n        test_case.assertTrue(\n            \"Dimension out of range (expected to be in range of [-2, 1], but got 4)\"\n            in str(context.exception)\n        )\n\n    def test_repeat_interleave_tensor_shape_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor([[1, 2], [3, 4]])\n            r = flow.tensor([[1, 2], [3, 4]])\n            y = flow.repeat_interleave(x, r, dim=1)\n        test_case.assertTrue(\n            \"repeat_interleave only accept 1D vector as repeat\"\n            in str(context.exception)\n        )\n\n    def test_repeat_interleave_dtype_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor([[1, 2], [3, 4]])\n            r = flow.tensor([1.0, 2.0])\n            y = flow.repeat_interleave(x, r, dim=1)\n        test_case.assertTrue(\"repeats has to be Long tensor\" in str(context.exception))\n\n    def test_repeat_interleave_negative_tensor_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor([[1, 2], [3, 4]])\n            r = flow.tensor([1, -2])\n            y = flow.repeat_interleave(x, r, dim=1)\n        test_case.assertTrue(\"repeats can not be negative\" in str(context.exception))\n\n    def test_repeat_interleave_negative_tensor_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor([[1, 2], [3, 4]])\n            r = flow.tensor([1, 2])\n            y = flow.repeat_interleave(x, r, dim=2)\n        test_case.assertTrue(\n            \"Dimension out of range (expected to be in range of [-2, 1], but got 2)\"\n            in str(context.exception)\n        )\n\n    def test_repeat_interleave_dim_not_match_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            x = flow.tensor([[1, 2], [3, 4]])\n            r = flow.tensor([1])\n            y = flow.repeat_interleave(x, r, dim=1)\n        test_case.assertTrue(\n            \"repeats must have the same size as input along dim\"\n            in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_reshape.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModule(flow.unittest.TestCase):\n    def test_reshape_exception_invalid_dim(test_case):\n        # torch exception and messge:\n        #\n        #   RuntimeError: Invalid shape dimension -2\n        #\n        x = flow.tensor((2, 2))\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = x.reshape((-2, 4))\n        test_case.assertTrue(\"Invalid shape dimension -2\" in str(ctx.exception))\n\n    def test_reshape_exception_invalid_size(test_case):\n        # torch exception and messge:\n        #\n        #   RuntimeError: shape '[2, 3, 5]' is invalid for input of size 24\n        #\n        x = flow.arange(24).reshape(2, 3, 4)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = x.reshape((2, 3, 5))\n        test_case.assertTrue(\"is invalid for input of size 24\" in str(ctx.exception))\n\n    def test_reshape_exception_only_one_dim_infered(test_case):\n        # torch exception and messge:\n        #\n        #   RuntimeError: only one dimension can be inferred\n        #\n        x = flow.tensor((2, 2))\n        with test_case.assertRaises(RuntimeError) as ctx:\n            y = x.reshape((-1, -1))\n        test_case.assertTrue(\"only one dimension can be inferred\" in str(ctx.exception))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_reshape_like_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestReshapeLikeOp(flow.unittest.TestCase):\n    def test_reshape_like_size_match_err(test_case):\n        a = flow.tensor([1, 1])\n        b = flow.tensor([[1, 1, 1], [1, 1, 1]])\n        with test_case.assertRaises(RuntimeError) as ctx:\n            flow._C.reshape_like(a, b)\n        test_case.assertTrue(\n            \"The element number of the in tensor must be equal to the element number of the like tensor\"\n            in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_roi_align_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestRoiAlignOp(flow.unittest.TestCase):\n    def test_rol_align_x_tensor_dimension_err(test_case):\n        x = flow.randn(2, 3, 64)\n        rois = flow.randn(2, 3, 64, 64)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            flow.roi_align(x, rois, 2.0, 14, 14, 2, True)\n        test_case.assertTrue(\n            \"The dimension of x tensor must be equal to 4, but got\"\n            in str(ctx.exception)\n        )\n\n    def test_rol_align_rois_tensor_dimension_err(test_case):\n        x = flow.randn(2, 3, 64, 5)\n        rois = flow.randn(2, 3, 64, 64)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            flow.roi_align(x, rois, 2.0, 14, 14, 2, True)\n        test_case.assertTrue(\n            \"The dimension of rois tensor must be equal to 2, but got\"\n            in str(ctx.exception)\n        )\n\n    def test_rol_align_rois_tensor_size_err(test_case):\n        x = flow.randn(2, 3, 64, 5)\n        rois = flow.randn(2, 3)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            flow.roi_align(x, rois, 2.0, 14, 14, 2, True)\n        test_case.assertTrue(\n            \"The size of rois tensor must be equal to 5 at dimension 1, but got\"\n            in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_save_load.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport tempfile\n\nimport oneflow as flow\nimport oneflow.unittest\nimport torch\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSaveLoad(flow.unittest.TestCase):\n    def test_support_pytorch_with_global_src_rank(test_case):\n        conv_torch = torch.nn.Conv2d(3, 3, 3)\n        conv_flow = flow.nn.Conv2d(3, 3, 3)\n        with tempfile.NamedTemporaryFile() as f:\n            torch.save(conv_torch.state_dict(), f.name)\n            with test_case.assertRaises(ValueError) as ctx:\n                conv_flow.load_state_dict(\n                    flow.load(f.name, support_pytorch_format=False)\n                )\n        test_case.assertTrue(\"Cannot load file\" in str(ctx.exception))\n\n    def test_load_invalid_file(test_case):\n        f = tempfile.NamedTemporaryFile()\n        f.write(b\"invalid file\")\n        f.flush()\n        with test_case.assertRaises(ValueError) as ctx:\n            flow.load(f.name)\n        test_case.assertTrue(\"Cannot load file\" in str(ctx.exception))\n\n        f.close()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_saved_tensor_hooks.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestSavedTensorHooks(flow.unittest.TestCase):\n    def test_unpack_returns_non_tensor(test_case):\n        x = flow.ones(1, 2, 3).to(\"cuda\").requires_grad_()\n        y = flow.zeros(1, 2, 3).to(\"cuda\").requires_grad_()\n\n        def pack(x):\n            return x\n\n        def unpack(x):\n            return 0\n\n        with flow.autograd.graph.saved_tensors_hooks(pack, unpack):\n            z = x * y\n        with test_case.assertRaises(Exception) as exp:\n            z.sum().backward()\n        test_case.assertTrue(\n            \"unpack_hook should return a Tensor, but got `<class 'int'>` instead\"\n            in str(exp.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_slice_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nimport numpy as np\n\n\nclass TestSlice(flow.unittest.TestCase):\n    def test_slice_update_start_list_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            ref = flow.tensor([[1], [2]])\n            value = flow.tensor([[1], [2]])\n            start = [-1]\n            stop = [1]\n            step = [1]\n            flow._C.slice_update(ref, value, start, stop, step)\n        test_case.assertTrue(\n            \"The start list elements must be greater than or equal to 0, but got\"\n            in str(context.exception)\n        )\n\n    def test_slice_update_stop_list_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            ref = flow.tensor([[1], [2]])\n            value = flow.tensor([[1], [2]])\n            start = [1]\n            stop = [-1]\n            step = [1]\n            flow._C.slice_update(ref, value, start, stop, step)\n        test_case.assertTrue(\n            \"The stop list elements must be greater than or equal to 0\"\n            in str(context.exception)\n        )\n\n    def test_slice_update_step_list_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            ref = flow.tensor([[1], [2]])\n            value = flow.tensor([[1], [2]])\n            start = [1]\n            stop = [1]\n            step = [0]\n            flow._C.slice_update(ref, value, start, stop, step)\n        test_case.assertTrue(\n            \"The step list elements must be greater than 0, but got\"\n            in str(context.exception)\n        )\n\n    def test_slice_update_start_and_stop_compare_value_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            ref = flow.tensor([[1], [2]])\n            value = flow.tensor([[1], [2]])\n            start = [2]\n            stop = [1]\n            step = [1]\n            flow._C.slice_update(ref, value, start, stop, step)\n        test_case.assertTrue(\n            \"The element in start list must be less than or equal to the element in stop list at index\"\n            in str(context.exception)\n        )\n\n    def test_slice_update_turple_size_match_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            ref = flow.tensor([1, 2])\n            value = flow.tensor([1, 2])\n            start = [1, 2, 3]\n            stop = [1, 2, 3]\n            step = [1, 2, 3]\n            flow._C.slice_update(ref, value, start, stop, step)\n        test_case.assertTrue(\n            \"The size of slice tuple must be equal to the size of value tensor at dimension\"\n            in str(context.exception)\n        )\n\n    def test_slice_update_type_err(test_case):\n        with test_case.assertRaises(TypeError) as context:\n            ref = flow.tensor([1], dtype=flow.int64)\n            value = 
flow.tensor([0.545], dtype=flow.float32)\n            start = [1]\n            stop = [2]\n            step = [1]\n            flow._C.slice_update(ref, value, start, stop, step)\n        test_case.assertTrue(\n            \"Tensors ref and value must have same type\" in str(context.exception)\n        )\n\n    def test_slice_start_list_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            ref = flow.tensor([1])\n            start = [-1]\n            stop = [1]\n            step = [1]\n            flow._C.slice(ref, start, stop, step)\n        test_case.assertTrue(\n            \"The start list elements must be greater than or equal to 0, but got \"\n            in str(context.exception)\n        )\n\n    def test_slice_stop_list_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            ref = flow.tensor([1])\n            start = [1]\n            stop = [-1]\n            step = [1]\n            flow._C.slice(ref, start, stop, step)\n        test_case.assertTrue(\n            \"The stop list elements must be greater than or equal to 0, but got \"\n            in str(context.exception)\n        )\n\n    def test_slice_step_list_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            ref = flow.tensor([1])\n            start = [1]\n            stop = [1]\n            step = [-1]\n            flow._C.slice(ref, start, stop, step)\n        test_case.assertTrue(\n            \"The step list elements must be greater than 0, but got \"\n            in str(context.exception)\n        )\n\n    def test_slice_start_and_stop_compare_value_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            ref = flow.tensor([1])\n            start = [2]\n            stop = [1]\n            step = [1]\n            flow._C.slice(ref, start, stop, step)\n        test_case.assertTrue(\n            \"The element in start list must be less than or equal to the element in stop list at index \"\n            in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_smooth_l1_loss_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestSmoothL1LossError(flow.unittest.TestCase):\n    def test_smooth_l1_loss_shape_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            input = flow.randn(10)\n            target = flow.randn(11)\n            reduction = \"mean\"\n            beta = 1.0\n            flow._C.smooth_l1_loss(input, target, beta, reduction)\n        test_case.assertTrue(\"must match the size of target\" in str(context.exception))\n\n    def test_smooth_l1_loss_beta_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            input = flow.randn(10)\n            target = flow.randn(10)\n            reduction = \"mean\"\n            beta = -1.0\n            flow._C.smooth_l1_loss(input, target, beta, reduction)\n        test_case.assertTrue(\n            \"beta must be greater than or equal to 0\" in str(context.exception)\n        )\n\n    def test_smooth_l1_loss_dtype_err(test_case):\n        with test_case.assertRaises(TypeError) as context:\n            input = flow.randn(10, dtype=flow.float32)\n            target = flow.randn(10, dtype=flow.float64)\n            reduction = \"mean\"\n            beta = 1.0\n            flow._C.smooth_l1_loss(input, target, beta, reduction)\n        test_case.assertTrue(\n            \"input and target are expected to have the same dtype\"\n            in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_softmax_cross_entropy_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestSoftmaxCrossEntropyError(flow.unittest.TestCase):\n    def test_softmax_cross_entropy_prediction_numaxes_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            prediction = flow.randn(10)\n            label = flow.randn(1, 10)\n            flow._C.softmax_cross_entropy(prediction, label)\n        test_case.assertTrue(\n            \"The dimension of prediction must be greater than or equal to 2, but found\"\n            in str(context.exception)\n        )\n\n    def test_softmax_cross_entropy_prediction_shape_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            prediction = flow.randn(1, 10)\n            label = flow.randn(1, 11)\n            flow._C.softmax_cross_entropy(prediction, label)\n        test_case.assertTrue(\n            \"must match the size of prediction\" in str(context.exception)\n        )\n\n    def test_softmax_cross_entropy_dtype_err(test_case):\n        with test_case.assertRaises(TypeError) as context:\n            prediction = flow.randn(1, 10, dtype=flow.float32)\n            label = flow.randn(1, 10, dtype=flow.float64)\n            flow._C.softmax_cross_entropy(prediction, label)\n        test_case.assertTrue(\n            \"label and prediction are expected to have the same dtype, but found\"\n            in str(context.exception)\n        )\n\n    def test_softmax_cross_entropy_grad_prob_numaxes_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            dy = flow.randn(10, 5)\n            label = flow.randn(10, 10, 5)\n            prob = flow.randn(10)\n            flow._C.softmax_cross_entropy_grad(dy, label, prob)\n        test_case.assertTrue(\n            \"The dimension of prob must be greater than or equal to 2, but found \"\n            in str(context.exception)\n        )\n\n    def test_softmax_cross_entropy_grad_dy_numaxes_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            dy = flow.randn(10, 10, 5)\n            label = flow.randn(10, 10, 5)\n            prob = flow.randn(10, 10, 5)\n            flow._C.softmax_cross_entropy_grad(dy, label, prob)\n        test_case.assertTrue(\n            \"The dimension of dy is expected to be less than that of prob by 1, but found\"\n            in str(context.exception)\n        )\n\n    def test_softmax_cross_entropy_grad_dy_i_shape_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            dy = flow.randn(10, 8)\n            label = flow.randn(10, 10, 5)\n            prob = flow.randn(10, 10, 5)\n            flow._C.softmax_cross_entropy_grad(dy, label, prob)\n        test_case.assertTrue(\"must match the size of label\" in str(context.exception))\n\n    def test_softmax_cross_entropy_grad_prob_shape_err(test_case):\n        with 
test_case.assertRaises(RuntimeError) as context:\n            dy = flow.randn(10, 10)\n            label = flow.randn(10, 10, 5)\n            prob = flow.randn(10, 10, 6)\n            flow._C.softmax_cross_entropy_grad(dy, label, prob)\n        test_case.assertTrue(\"must match the size of prob\" in str(context.exception))\n\n    def test_softmax_cross_entropy_grad_label_dtype_err(test_case):\n        with test_case.assertRaises(TypeError) as context:\n            dy = flow.randn(10, 10, dtype=flow.float64)\n            label = flow.randn(10, 10, 5, dtype=flow.float32)\n            prob = flow.randn(10, 10, 5, dtype=flow.float64)\n            flow._C.softmax_cross_entropy_grad(dy, label, prob)\n        test_case.assertTrue(\n            \"label and prob are expected to have the same dtype, but found\"\n            in str(context.exception)\n        )\n\n    def test_softmax_cross_entropy_grad_dy_dtype_err(test_case):\n        with test_case.assertRaises(TypeError) as context:\n            dy = flow.randn(10, 10, dtype=flow.float32)\n            label = flow.randn(10, 10, 5, dtype=flow.float64)\n            prob = flow.randn(10, 10, 5, dtype=flow.float64)\n            flow._C.softmax_cross_entropy_grad(dy, label, prob)\n        test_case.assertTrue(\n            \"dy and prob are expected to have the same dtype, but found\"\n            in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_sparse_cross_entropy_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestSparseCrossEntropyError(flow.unittest.TestCase):\n    def test_sparse_cross_entropy_prediction_numaxes_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            prediction = flow.randn(10)\n            label = flow.randint(0, 10, (10, 10), dtype=flow.int64)\n            depth = 10\n            flow._C.sparse_cross_entropy(prediction, label, depth)\n        test_case.assertTrue(\n            \"The dimension of prediction must be greater than or equal to 2, but found\"\n            in str(context.exception)\n        )\n\n    def test_sparse_cross_entropy_label_numaxes_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            prediction = flow.randn(10, 10, 5)\n            label = flow.randint(0, 10, (10, 10, 5), dtype=flow.int64)\n            depth = 10\n            flow._C.sparse_cross_entropy(prediction, label, depth)\n        test_case.assertTrue(\n            \"The dimension of label is expected to be less than that of prediction by 1\"\n            in str(context.exception)\n        )\n\n    def test_sparse_cross_entropy_prediction_i_shape_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            prediction = flow.randn(10, 10, 5)\n            label = flow.randint(0, 10, (10, 5), dtype=flow.int64)\n            depth = 10\n            flow._C.sparse_cross_entropy(prediction, label, depth)\n        test_case.assertTrue(\" must match the size of label\" in str(context.exception))\n\n    def test_sparse_cross_entropy_label_dtype_err(test_case):\n        with test_case.assertRaises(TypeError) as context:\n            prediction = flow.randn(10, 10, 5)\n            label = flow.randn((10, 10), dtype=flow.float32)\n            depth = 10\n            flow._C.sparse_cross_entropy(prediction, label, depth)\n        test_case.assertTrue(\n            \"The dtype of label must be integer, but found\" in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_sparse_softmax_cross_entropy_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestSparseSoftmaxCrossEntropyError(flow.unittest.TestCase):\n    def test_sparse_softmax_cross_entropy_prediction_numaxes_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            prediction = flow.randn(10)\n            label = flow.randint(0, 10, (10, 10), dtype=flow.int64)\n            flow._C.sparse_softmax_cross_entropy(prediction, label)\n        test_case.assertTrue(\n            \"The dimension of prediction must be greater than or equal to 2, but found\"\n            in str(context.exception)\n        )\n\n    def test_sparse_softmax_cross_entropy_label_numaxes_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            prediction = flow.randn(10, 10, 5)\n            label = flow.randint(0, 10, (10, 10, 5), dtype=flow.int64)\n            flow._C.sparse_softmax_cross_entropy(prediction, label)\n        test_case.assertTrue(\n            \"The dimension of label is expected to be less than that of prediction by 1\"\n            in str(context.exception)\n        )\n\n    def test_sparse_softmax_cross_entropy_prediction_i_shape_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            prediction = flow.randn(10, 10, 5)\n            label = flow.randint(0, 10, (10, 9), dtype=flow.int64)\n            flow._C.sparse_softmax_cross_entropy(prediction, label)\n        test_case.assertTrue(\"must match the size of label\" in str(context.exception))\n\n    def test_sparse_softmax_cross_entropy_label_dtype_err(test_case):\n        with test_case.assertRaises(TypeError) as context:\n            prediction = flow.randn(10, 10, 5)\n            label = flow.randn(10, 10, dtype=flow.float32)\n            flow._C.sparse_softmax_cross_entropy(prediction, label)\n        test_case.assertTrue(\n            \"The dtype of label must be integer, but found \" in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_split_like_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestSplitLikeError(flow.unittest.TestCase):\n    def test_split_like_like_axes_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.randn(4, 4)\n            like = (flow.randn(2, 4, 4), flow.randn(2, 4, 4))\n            axis = 0\n            flow._C.split_like(x, like, axis)\n        test_case.assertTrue(\n            \") should be less than or equal to input (\" in str(context.exception)\n        )\n\n    def test_split_like_split_axes_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.randn(4, 4)\n            like = (flow.randn(2, 4), flow.randn(2, 4))\n            axis = 3\n            flow._C.split_like(x, like, axis)\n        test_case.assertTrue(\n            \"should be less than the dimension of like\" in str(context.exception)\n        )\n\n    def test_split_like_like_i_axes_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.randn(4, 4)\n            like = (flow.randn(2, 4), flow.randn(2))\n            axis = 0\n            flow._C.split_like(x, like, axis)\n        test_case.assertTrue(\n            \"must match the dimension of the first like\" in str(context.exception)\n        )\n\n    def test_split_like_x_i_shape_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.randn(4, 4)\n            like = (flow.randn(2, 4), flow.randn(2, 3))\n            axis = 0\n            flow._C.split_like(x, like, axis)\n        test_case.assertTrue(\"must match the size of like_i\" in str(context.exception))\n\n    def test_split_like_non_dynamic_static_dim_err(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.randn(4, 4)\n            like = (flow.randn(2, 4), flow.randn(3, 4))\n            axis = 0\n            flow._C.split_like(x, like, axis)\n        test_case.assertTrue(\n            \"shape situation, the total size of like\" in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_stft_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nimport numpy as np\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModule(flow.unittest.TestCase):\n    def test_stft_illegal_input_dim(test_case):\n        np_tensor = np.arange(1, 13, dtype=float).reshape(2, 2, 3)\n\n        with test_case.assertRaises(RuntimeError) as ctx:\n            x_flow = flow.tensor(np_tensor)\n            flow.stft(\n                x_flow,\n                n_fft=4,\n                center=True,\n                onesided=True,\n                return_complex=False,\n                normalized=False,\n            )\n        test_case.assertTrue(\"Expected a 1D or 2D tensor,but got\" in str(ctx.exception))\n\n    def test_stft_illegal_nfft(test_case):\n        np_tensor = np.arange(1, 13, dtype=float).reshape(4, 3)\n        win_tensor = np.arange(1, 5, dtype=float)\n\n        with test_case.assertRaises(RuntimeError) as ctx:\n            x_flow = flow.tensor(np_tensor)\n            flow_win = flow.tensor(win_tensor)\n\n            flow.stft(\n                x_flow,\n                n_fft=-1,\n                window=flow_win,\n                center=True,\n                onesided=True,\n                return_complex=False,\n                normalized=False,\n            )\n        test_case.assertTrue(\"Expected 0 < n_fft\" in str(ctx.exception))\n\n    def test_stft_illegal_hop_length(test_case):\n        np_tensor = np.arange(1, 13, dtype=float).reshape(4, 3)\n\n        with test_case.assertRaises(RuntimeError) as ctx:\n            x_flow = flow.tensor(np_tensor)\n\n            flow.stft(\n                x_flow,\n                n_fft=4,\n                hop_length=-1,\n                center=True,\n                onesided=True,\n                return_complex=False,\n                normalized=False,\n            )\n        test_case.assertTrue(\"Expected hop_length > 0, but got\" in str(ctx.exception))\n\n    def test_stft_illegal_win_length(test_case):\n        np_tensor = np.arange(1, 13, dtype=float).reshape(4, 3)\n\n        with test_case.assertRaises(RuntimeError) as ctx:\n            x_flow = flow.tensor(np_tensor)\n\n            flow.stft(\n                x_flow,\n                n_fft=4,\n                win_length=-1,\n                center=True,\n                onesided=True,\n                return_complex=False,\n                normalized=False,\n            )\n        test_case.assertTrue(\n            \"Expected 0 < win_length <=n_fft ,but got\" in str(ctx.exception)\n        )\n\n    def test_stft_illegal_window(test_case):\n        np_tensor = np.arange(1, 13, dtype=float).reshape(2, 6)\n        win_tensor = np.arange(1, 10, dtype=float)\n\n        with test_case.assertRaises(RuntimeError) as ctx:\n            x_flow = flow.tensor(np_tensor)\n            flow_win = flow.tensor(win_tensor)\n\n            flow.stft(\n                x_flow,\n            
    n_fft=4,\n                window=flow_win,\n                center=True,\n                onesided=True,\n                return_complex=False,\n                normalized=False,\n            )\n        test_case.assertTrue(\n            \"Expected a 1D window tensor of size equal to win_length=\"\n            in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_tensor_index.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow.unittest\nimport oneflow as flow\n\n\nclass TestTensorIndexError(flow.unittest.TestCase):\n    def test_PrepareSliceIndices_indices_amount_index_error(test_case):\n        with test_case.assertRaises(IndexError) as context:\n            x = flow.arange(16).reshape(4, 4)\n            x[0, 0, 0] = 0\n        test_case.assertTrue(\n            \"Too many indices for tensor of dimension\" in str(context.exception)\n        )\n\n    def test_PrepareSliceIndices_slice_step_runtime_error(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.tensor([0, 1, 2, 3], dtype=flow.int32)\n            s = slice(0, 2, -1)\n            y = x[s]\n        test_case.assertTrue(\"Step must be greater than zero\" in str(context.exception))\n\n    def test_ApplySelectIndexing_input_dim_runtime_error(test_case):\n        with test_case.assertRaises(RuntimeError) as context:\n            x = flow.tensor(5, dtype=flow.int32)\n            y = x[0]\n        test_case.assertTrue(\n            \"select() cannot be applied to a 0-dim tensor.\" in str(context.exception)\n        )\n\n    def test_ApplySelectIndexing_index_error(test_case):\n        with test_case.assertRaises(IndexError) as context:\n            x = flow.ones(2, 3, dtype=flow.int32)\n            y = x[3]\n        test_case.assertTrue(\n            \"Index out of range (expected to be in range of\" in str(context.exception)\n        )\n\n    def test_ApplyAdvancedIndexing_index_error(test_case):\n        with test_case.assertRaises(IndexError) as context:\n            x = flow.ones(2, 2, dtype=flow.int32)\n            index = (\n                flow.tensor(1, dtype=flow.int32),\n                flow.tensor(1, dtype=flow.int32),\n                flow.tensor(1, dtype=flow.int32),\n            )\n            y = x[index]\n        test_case.assertTrue(\n            \"Too many indices for tensor of dimension\" in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_tensordot.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.nn.functional as F\nimport torch\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTensordotError(flow.unittest.TestCase):\n    def test_tensordot_neg_dims_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            a = flow.randn(1, 2, 3)\n            b = flow.randn(1, 2, 3)\n            flow.tensordot(a, b, dims=-1)\n        test_case.assertTrue(\n            \"tensordot expects dims >= 0, but got dims=-1\" in str(context.exception)\n        )\n\n    @unittest.skip(\"PyTorch doesn't have corresponding error message\")\n    def test_tensordot_too_large_int_dims_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            a = flow.randn(1, 2, 3)\n            b = flow.randn(1, 2, 3)\n            flow.tensordot(a, b, dims=100)\n        test_case.assertTrue(\n            \"tensordot expects dims <= a.ndim which is 3, but got 100\"\n            in str(context.exception)\n        )\n\n    def test_tensordot_out_of_range_dims_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            a = flow.randn(1, 2, 3)\n            b = flow.randn(1, 2, 3)\n            flow.tensordot(a, b, dims=[[3], [2]])\n        test_case.assertTrue(\n            \"Dimension out of range (expected to be in range of [-3, 2], but got 3)\"\n            in str(context.exception)\n        )\n\n    def test_tensordot_unmatch_dims_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            a = flow.randn(1, 2, 3)\n            b = flow.randn(1, 2, 3)\n            flow.tensordot(a, b, dims=[[1], [2]])\n        test_case.assertTrue(\n            \"contracted dimensions need to match, but first has size 2 in dim 1 and second has size 3 in dim 2\"\n            in str(context.exception)\n        )\n\n    def test_tensordot_recurring_dim_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            a = flow.randn(1, 2, 3)\n            b = flow.randn(1, 2, 3)\n            flow.tensordot(a, b, dims=[[1, 1], [1, 1]])\n        test_case.assertTrue(\n            \"dim 1 appears multiple times in the list of dims\" in str(context.exception)\n        )\n\n    def test_tensordot_dims_different_length_runtime_error(test_case):\n        with test_case.assertRaises(Exception) as context:\n            a = flow.randn(1, 2, 3)\n            b = flow.randn(1, 2, 3)\n            flow.tensordot(a, b, dims=[[1], [1, 2]])\n        test_case.assertTrue(\n            \"both dimension lists should have same length\" in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_to_global_error.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport os\nimport numpy as np\nimport time\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n2d()\nclass TestToGlobalError(flow.unittest.TestCase):\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_tensor_to_consistent(self):\n        with self.assertRaises(Exception) as context:\n            data = flow.rand(2, dtype=flow.float32)\n            placement = flow.placement.all(\"cuda\")\n            sbp = flow.sbp.split(0)\n            global_data = data.to_consistent(placement=placement, sbp=sbp)\n\n        self.assertTrue(\n            \".to_consistent has been removed, please use .to_global instead\"\n            in str(context.exception)\n        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_tensor_is_global(self):\n        with self.assertRaises(Exception) as context:\n            data = flow.rand(2, dtype=flow.float32)\n            print(data.is_consistent())\n\n        self.assertTrue(\n            \".is_consistent has been removed, please use .is_global instead\"\n            in str(context.exception)\n        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_module_to_consistent(self):\n        with self.assertRaises(Exception) as context:\n            m = flow.nn.Conv2d(1, 1, 1)\n            placement = flow.placement.all(\"cuda\")\n            sbp = flow.sbp.split(0)\n            m.to_consistent(placement=placement, sbp=sbp)\n\n        self.assertTrue(\n            \".to_consistent has been removed, please use .to_global instead\"\n            in str(context.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/test_view.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModule(flow.unittest.TestCase):\n    def test_view_exception(test_case):\n        # torch exception and messge:\n        #\n        #   RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.\n        #\n        a = flow.arange(9).reshape(3, 3)\n        b = a.permute(1, 0)\n        with test_case.assertRaises(RuntimeError) as ctx:\n            print(b.view(9))\n        test_case.assertTrue(\n            \"view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.\"\n            in str(ctx.exception)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/exceptions/throw_error.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n# This file is intended to be run in\n# python/oneflow/test/exceptions/test_error_reported_in_thread.py\n\nimport oneflow as flow\n\n\ndef g(x):\n    flow._C.throw_error(x)\n\n\ndef f(x):\n    x = x.relu()\n    g(x)\n\n\nx = flow.ones(3, 3, 4)\nf(x)\n"
  },
  {
    "path": "python/oneflow/test/expensive/README.md",
    "content": "# Expensive tests\n\n- Tests requires a lot of time, memory to run.\n- Every test should have exclusive access to GPU when running\n"
  },
  {
    "path": "python/oneflow/test/expensive/_internally_replaced_utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport importlib.machinery\n\n\ndef _download_file_from_remote_location(fpath: str, url: str) -> None:\n    pass\n\n\ndef _is_remote_location_available() -> bool:\n    return False\n\n\ntry:\n    from torch.hub import load_state_dict_from_url\nexcept ImportError:\n    from torch.utils.model_zoo import load_url as load_state_dict_from_url\n\n\ndef _get_extension_path(lib_name):\n\n    lib_dir = os.path.dirname(__file__)\n    if os.name == \"nt\":\n        # Register the main torchvision library location on the default DLL path\n        import ctypes\n        import sys\n\n        kernel32 = ctypes.WinDLL(\"kernel32.dll\", use_last_error=True)\n        with_load_library_flags = hasattr(kernel32, \"AddDllDirectory\")\n        prev_error_mode = kernel32.SetErrorMode(0x0001)\n\n        if with_load_library_flags:\n            kernel32.AddDllDirectory.restype = ctypes.c_void_p\n\n        if sys.version_info >= (3, 8):\n            os.add_dll_directory(lib_dir)\n        elif with_load_library_flags:\n            res = kernel32.AddDllDirectory(lib_dir)\n            if res is None:\n                err = ctypes.WinError(ctypes.get_last_error())\n                err.strerror += f' Error adding \"{lib_dir}\" to the DLL directories.'\n                raise err\n\n        kernel32.SetErrorMode(prev_error_mode)\n\n    loader_details = (\n        importlib.machinery.ExtensionFileLoader,\n        importlib.machinery.EXTENSION_SUFFIXES,\n    )\n\n    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)\n    ext_specs = extfinder.find_spec(lib_name)\n    if ext_specs is None:\n        raise ImportError\n\n    return ext_specs.origin\n"
  },
  {
    "path": "python/oneflow/test/expensive/_test_remat.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\"\"\"\nThis file (_test_remat.py) is intended to be run inside test_remat.py\nwith correct environment variables like ONEFLOW_VM_MULTI_THREAD=0\n\"\"\"\nfrom contextlib import contextmanager\nimport os\nimport unittest\nimport functools\n\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow import nn\nimport flowvision\nimport oneflow.unittest\n\n\ndef evict(tensor):\n    flow._oneflow_internal.remat.evict(tensor)\n\n\ndef is_in_memory(tensor):\n    return flow._oneflow_internal.remat.is_in_memory(tensor)\n\n\nplaceholder_size = 0\n\n\ndef allocated_memory(device, include_test_placeholder=False):\n    if device == \"cuda\" and not flow.sysconfig.with_cuda():\n        return 0\n    return flow._oneflow_internal.remat.allocated_memory(device) - (\n        0 if include_test_placeholder else placeholder_size\n    )\n\n\ndef display(device):\n    return flow._oneflow_internal.remat.display(device)\n\n\ndef only_fbip():\n    if os.getenv(\"ONEFLOW_REMAT_COPY_ON_WRITE\") is None:\n        return lambda f: f\n    else:\n        return unittest.skip(\"\")\n\n\ndef only_copy_on_write():\n    if os.getenv(\"ONEFLOW_REMAT_COPY_ON_WRITE\") is not None:\n        return lambda f: f\n    else:\n        return unittest.skip(\"\")\n\n\ndef loss_test():\n    if os.getenv(\"ONEFLOW_REMAT_RUN_LOSS_TEST\") is not None:\n        return lambda f: f\n    else:\n        return unittest.skip(\n            \"Environment variable 'ONEFLOW_REMAT_RUN_LOSS_TEST' need to be set to run this test.\"\n        )\n\n\n@contextmanager\ndef generate_placeholder(size_mb, device):\n    global placeholder_size\n    placeholder_size = size_mb * 1024 * 1024\n    x = flow.zeros(int(placeholder_size), dtype=flow.int8, device=device)\n    flow._oneflow_internal.remat.disable_eviction(x)\n    try:\n        yield\n    finally:\n        del x\n        placeholder_size = 0\n\n\ndef memory_budget(budget_mb, device):\n    if device == \"cuda\" and not oneflow.sysconfig.with_cuda():\n        return unittest.skip(\"Skip CUDA tests on CPU build\")\n\n    def deco(f):\n        @functools.wraps(f)\n        def new_f(*args, **kwargs):\n            total_budget = flow.remat.get_budget() / 1024 / 1024\n            assert total_budget >= budget_mb, \"Not enough memory budget\"\n            remat_device = device + \"+remat\"\n            with generate_placeholder(total_budget - budget_mb, remat_device):\n                return f(*args, remat_device, **kwargs)\n\n        return new_f\n\n    return deco\n\n\nclass TestRemat(flow.unittest.TestCase):\n    @classmethod\n    def setUpClass(cls):\n        flow.remat.set_budget(\"500MB\")\n        flow.remat.set_small_pieces_optimization(False)\n\n    def setUp(self):\n        super().setUp()\n\n        assert (\n            os.getenv(\"ONEFLOW_VM_MULTI_THREAD\") is not None\n        ), \"Please set ONEFLOW_VM_MULTI_THREAD to False, 0 or OFF\"\n        # check the memory is empty at 
the beginning of every test case\n        if allocated_memory(\"cpu\") > 0:\n            print(\"allocated_memory(cpu):\", allocated_memory(\"cpu\"))\n            display(\"cpu\")\n        if allocated_memory(\"cuda\") > 0:\n            print(\"allocated_memory(cuda):\", allocated_memory(\"cuda\"))\n            display(\"cuda\")\n\n        self.assertEqual(allocated_memory(\"cpu\"), 0)\n        self.assertEqual(allocated_memory(\"cuda\"), 0)\n        flow._oneflow_internal.remat.clear_stats()\n\n    def tearDown(self):\n        super().tearDown()\n        # check the memory is empty at the end of every test case\n        self.assertEqual(allocated_memory(\"cpu\"), 0)\n        self.assertEqual(allocated_memory(\"cuda\"), 0)\n\n    @flow.unittest.skip_unless_1n1d()\n    @only_fbip()\n    @memory_budget(12, \"cpu\")\n    def test_remat_work_on_fbip_1(self, device):\n        x1 = flow.ones(1024 * 1024, device=device)  # 4MB\n        x2 = x1 * -2  # 8MB\n        x3 = x2 - 2  # 12MB\n        x2.relu_()  # 12MB\n        self.assertTrue(is_in_memory(x1))\n        self.assertTrue(is_in_memory(x2))\n        self.assertTrue(is_in_memory(x3))\n        evict(x3)\n        self.assertTrue(np.array_equal(x3.numpy(), np.ones(x3.shape) * -4))\n        evict(x2)\n        self.assertTrue(np.array_equal(x2.numpy(), np.zeros(x2.shape)))\n\n    @flow.unittest.skip_unless_1n1d()\n    @only_fbip()\n    @memory_budget(12, \"cpu\")\n    def test_remat_work_on_fbip_2(self, device):\n        x1 = flow.ones(1024 * 1024, device=device)  # 4MB\n        x2 = x1[0]\n        x3 = x2 + 2\n        evict(x3)\n        self.assertTrue(np.array_equal(x3.numpy(), np.ones(x3.shape) * 3))\n        evict(x2)\n        evict(x3)\n        self.assertTrue(np.array_equal(x3.numpy(), np.ones(x3.shape) * 3))\n        evict(x2)\n        self.assertTrue(np.array_equal(x2.numpy(), np.ones(x2.shape)))\n\n    @flow.unittest.skip_unless_1n1d()\n    @unittest.skip(\"mutation other than inplace is not supported yet\")\n    @only_fbip()\n    @memory_budget(12, \"cpu\")\n    def test_remat_work_on_fbip_3(self, device):\n        x1 = flow.ones(1024 * 1024, device=device)  # 4MB\n        x2 = x1 * -2  # 8MB\n        x1.zero_()\n        evict(x2)\n        print(x2.numpy())\n        self.assertTrue(np.array_equal(x2.numpy(), np.ones(x2.shape) * -2))\n\n    @flow.unittest.skip_unless_1n1d()\n    @only_fbip()\n    @memory_budget(12, \"cuda\")\n    def test_remat_work_on_fbip_4(self, device):\n        x1 = flow.ones(1024 * 1024, device=device)  # 4MB\n        x2 = x1 + 1\n        x2 += x1\n        x3 = x2.relu()\n        x4 = x3 + 1\n        evict(x3)\n        evict(x2)\n        evict(x1)\n        evict(x3)\n        self.assertTrue(np.array_equal(x4.numpy(), np.ones(x4.shape) * 4))\n\n    @flow.unittest.skip_unless_1n1d()\n    @memory_budget(12, \"cpu\")\n    def test_remat_work_on_simple_case_1(self, device):\n        x1 = flow.ones(1024 * 1024, device=device)  # 4MB\n        self.assertTrue(is_in_memory(x1))\n        self.assertEqual(allocated_memory(device), 4 * 1024 * 1024)\n        x2 = x1 + 2\n        self.assertEqual(allocated_memory(device), 8 * 1024 * 1024)\n        # eager eviction\n        del x1\n        self.assertEqual(allocated_memory(device), 4 * 1024 * 1024)\n        self.assertTrue(is_in_memory(x2))\n        x3 = x2 + 2\n        self.assertTrue(is_in_memory(x2))\n        x4 = x3 + 2\n        self.assertTrue(is_in_memory(x2))\n        x5 = x4 + 2\n        self.assertFalse(is_in_memory(x2))\n        self.assertTrue(is_in_memory(x3))\n        
self.assertTrue(is_in_memory(x4))\n        x6 = x5 + 2\n        self.assertFalse(is_in_memory(x2))\n        # the eviction of x2 increases the cost of x3, so x4 is evicted\n        self.assertTrue(is_in_memory(x3))\n        self.assertFalse(is_in_memory(x4))\n\n        self.assertTrue(np.array_equal(x6.numpy(), np.ones(x6.shape) * 11))\n        self.assertTrue(np.array_equal(x3.numpy(), np.ones(x3.shape) * 5))\n\n    @flow.unittest.skip_unless_1n1d()\n    @memory_budget(12, \"cpu\")\n    def test_remat_work_on_simple_case_2(self, device):\n        x1 = flow.ones(1024 * 1024, device=device)  # 4MB\n        self.assertTrue(is_in_memory(x1))\n        self.assertEqual(allocated_memory(device), 4 * 1024 * 1024)\n        x2 = x1 + 2\n        # eager eviction\n        del x1\n        self.assertTrue(is_in_memory(x2))\n        x3 = x2 + 2\n        self.assertTrue(is_in_memory(x2))\n        x4 = x3 + 2\n        self.assertTrue(is_in_memory(x2))\n        x5 = x4 + 2\n        self.assertFalse(is_in_memory(x2))\n        self.assertTrue(is_in_memory(x3))\n        self.assertTrue(is_in_memory(x4))\n        x6 = x5 + 2\n        self.assertFalse(is_in_memory(x2))\n        # the eviction of x2 increases the cost of x3, so x4 is evicted\n        self.assertTrue(is_in_memory(x3))\n        self.assertFalse(is_in_memory(x4))\n\n        self.assertTrue(np.array_equal(x6.numpy(), np.ones(x6.shape) * 11))\n        self.assertTrue(np.array_equal(x3.numpy(), np.ones(x3.shape) * 5))\n\n    @flow.unittest.skip_unless_1n1d()\n    @memory_budget(12, \"cpu\")\n    def test_remat_full_and_init_constant(self, device):\n        x1 = flow.eye(1024, 1024, device=device)\n        self.assertTrue(is_in_memory(x1))\n        self.assertEqual(allocated_memory(device), 4 * 1024 * 1024)\n\n        x2 = flow.full(x1.shape, 3.0, device=device)\n        flow.nn.init.constant_(x1, x2)  # type: ignore[arg-type]\n        del x2\n        self.assertEqual(allocated_memory(device), 4 * 1024 * 1024)\n\n        evict(x1)\n\n        self.assertTrue(np.array_equal(x1.numpy(), np.ones(x1.shape) * 3))\n\n    @flow.unittest.skip_unless_1n1d()\n    @memory_budget(12, \"cpu\")\n    def test_remat_lifecycle_of_view_tensor(self, device):\n        x1 = flow.eye(2, 3, device=device)\n        self.assertTrue(is_in_memory(x1))\n\n        x2 = flow.ones(3, device=device)\n        x3 = flow.expand(x2, (2, 3))\n        x1[:] = x3\n        del x3\n        del x2\n\n        evict(x1)\n\n        self.assertTrue(np.array_equal(x1.numpy(), np.ones(x1.shape)))\n\n    @flow.unittest.skip_unless_1n1d()\n    @memory_budget(16, \"cpu\")\n    def test_remat_init_constant_and_scalar(self, device):\n        x0 = flow.ones(1024, 1024).to(device)\n        x1 = x0 + 0\n        x2 = x1 + 1\n        flow.nn.init.constant_(x1, 5.0)  # type: ignore[arg-type]\n\n        evict(x1)\n        self.assertTrue(np.array_equal(x1.numpy(), np.ones(x1.shape) * 5))\n\n        evict(x1)\n        evict(x2)\n        self.assertTrue(np.array_equal(x2.numpy(), np.ones(x2.shape) * 2))\n\n    @flow.unittest.skip_unless_1n1d()\n    @memory_budget(80, \"cpu\")\n    def test_copy(self, device):\n        x1 = flow.ones(1)\n        x2 = x1.to(device)\n        self.assertTrue(x2.device.rematable)\n        x3 = x2.to(flow.int64)\n        self.assertTrue(x3.device.rematable)\n        x4 = x2 + 1\n        self.assertTrue(x4.device.rematable)\n\n    @flow.unittest.skip_unless_1n1d()\n    @memory_budget(80, \"cuda\")\n    def test_simple_network(self, device):\n        model = nn.Sequential(\n            
nn.Conv2d(3, 32, 3, 2, 1),\n            nn.BatchNorm2d(32),\n            nn.ReLU(inplace=False),\n            nn.Conv2d(32, 32, 3, 1, 1),\n            nn.BatchNorm2d(32),\n            nn.ReLU(inplace=False),\n            nn.Conv2d(32, 32, 3, 1, 1),\n            nn.BatchNorm2d(32),\n            nn.ReLU(inplace=False),\n            nn.Conv2d(32, 32, 3, 1, 1),\n            nn.BatchNorm2d(32),\n            nn.ReLU(inplace=False),\n        ).to(device)\n        for p in model.parameters():\n            p.grad = flow.zeros_like(p).to(device)\n        optimizer = flow.optim.SGD(model.parameters(), lr=0.1, momentum=0)\n        x = flow.ones(4, 3, 224, 224).to(device)\n        mem = allocated_memory(device)\n        for _ in range(10):\n            mem2 = allocated_memory(device)\n            self.assertEqual(mem, mem2)\n            loss = model(x).sum()\n            loss.backward()\n            del loss\n            optimizer.step()\n            optimizer.zero_grad()\n\n    def _test_resnet18(self, optimizer_fn, ddp, expected_loss):\n        flow.manual_seed(flow.env.get_rank())\n        device = \"cpu+remat\"\n\n        model = flowvision.models.resnet18().to(device)\n        if ddp:\n            model = flow.nn.parallel.DistributedDataParallel(model, use_bucket=False)\n        criterion = nn.CrossEntropyLoss().to(device)\n\n        for x in model.parameters():\n            x.grad = flow.zeros_like(x).to(device)\n        # optimizer = flow.optim.SGD(model.parameters(), lr=0.1, momentum=0)\n        optimizer = optimizer_fn(model.parameters())\n        x = flow.rand(10, 3, 224, 224).to(device)\n        target = (\n            flow.randint(low=0, high=1000, size=(x.shape[0],)).to(device).to(flow.int32)\n        )\n        # NOTE: there is a bug in the current implementation of random ops:\n        # x1 = flow.rand(5)\n        # x2 = x1 + 1\n        # del x1   <--- we cannot block the eviction of x1 here because it is controlled by the user\n        # evict(x2)\n        # recompute(x2) <-- recomputing x2 triggers the recomputation of x1 and causes inconsistency\n        flow._oneflow_internal.remat.disable_eviction(x)\n        flow._oneflow_internal.remat.disable_eviction(target)\n        ITER_NUM = 5\n        for i in range(ITER_NUM):\n            print(\"start allocated_memory(cpu):\", allocated_memory(\"cpu\"))\n            print(\n                \"recomputation num: \", flow._oneflow_internal.remat.recomputation_num()\n            )\n            output = model(x)\n            loss = criterion(output, target)\n            del output\n            print(loss.numpy().item())\n            if i == 4 and expected_loss is not None:\n                self.assertTrue(loss.numpy().item() in expected_loss)\n            loss.backward()\n            del loss\n            optimizer.step()\n            optimizer.zero_grad()\n            print(\"end allocated_memory(cpu):\", allocated_memory(\"cpu\"))\n            print(\n                \"recomputation num: \", flow._oneflow_internal.remat.recomputation_num()\n            )\n\n        # check that there are more than 10 recomputations in each iteration\n        # so that the correctness check makes sense.\n        self.assertGreater(\n            flow._oneflow_internal.remat.recomputation_num(), ITER_NUM * 10\n        )\n\n    @flow.unittest.skip_unless_1n1d()\n    @only_fbip()\n    @memory_budget(220, \"cpu\")\n    @loss_test()\n    def test_resnet18_naive_sgd(self, _):\n        # NOTE: this loss is only correct in my environment on 21\n        self._test_resnet18(\n            lambda params: flow.optim.SGD(params, lr=0.1, momentum=0),\n            False,\n            [0.6304041147232056],\n        )\n\n    @flow.unittest.skip_unless_1n2d()\n    @only_fbip()\n    @memory_budget(220, \"cpu\")\n    @loss_test()\n    def test_resnet18_naive_sgd_ddp_1n2d(self, _):\n        # 2 devices, 2 losses\n        # NOTE: these losses are only correct in my environment on 21\n        self._test_resnet18(\n            lambda params: flow.optim.SGD(params, lr=0.1, momentum=0),\n            True,\n            [1.8890058994293213, 1.8992782831192017],\n        )\n\n    @flow.unittest.skip_unless_1n1d()\n    @only_fbip()\n    @memory_budget(270, \"cpu\")\n    @loss_test()\n    def test_resnet18_momentum_sgd(self, _):\n        # NOTE: this loss is only correct in my environment on 21\n        self._test_resnet18(\n            lambda params: flow.optim.SGD(params, lr=0.1, momentum=0.9), False, None\n        )\n\n    @flow.unittest.skip_unless_1n1d()\n    @only_fbip()\n    @memory_budget(310, \"cpu\")\n    @loss_test()\n    def test_resnet18_adam(self, _):\n        # NOTE: this loss is only correct in my environment on 21\n        self._test_resnet18(lambda params: flow.optim.Adam(params, lr=0.1), False, None)\n\n    @flow.unittest.skip_unless_1n1d()\n    @only_copy_on_write()\n    @memory_budget(12, \"cpu\")\n    def test_copy_on_write(self, _):\n        x1 = flow.ones(1024 * 1024)  # 4MB\n        x2 = flow.ones(1024 * 1024)\n        x3 = x2 + 1\n        x2 += x1\n        display(\"cpu\")\n        print(f\"x1 in memory?: {is_in_memory(x1)}\")\n        print(f\"x2 in memory?: {is_in_memory(x2)}\")\n        print(f\"x3 in memory?: {is_in_memory(x3)}\")\n\n        print(f\"recompute num: {flow._oneflow_internal.remat.recomputation_num()}\")\n        print(\n            f\"forced eviction num: {flow._oneflow_internal.remat.forced_eviction_num()}\"\n        )\n        print(\n            f\"eager eviction num: {flow._oneflow_internal.remat.eager_eviction_num()}\"\n        )\n\n        print(\"-------------\")\n\n        print(x3.numpy())\n        print(f\"x1 in memory?: {is_in_memory(x1)}\")\n        print(f\"x2 in memory?: {is_in_memory(x2)}\")\n        print(f\"x3 in memory?: {is_in_memory(x3)}\")\n\n        print(f\"recompute num: {flow._oneflow_internal.remat.recomputation_num()}\")\n        print(\n            f\"forced eviction num: {flow._oneflow_internal.remat.forced_eviction_num()}\"\n        )\n        print(\n            f\"eager eviction num: {flow._oneflow_internal.remat.eager_eviction_num()}\"\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_alexnet.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport torch.nn as nn\nfrom typing import Any\n\n\n__all__ = [\"AlexNet\", \"alexnet\"]\n\n\nclass AlexNet(nn.Module):\n    def __init__(self, num_classes: int = 1000) -> None:\n        super(AlexNet, self).__init__()\n        self.features = nn.Sequential(\n            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(kernel_size=3, stride=2),\n            nn.Conv2d(64, 192, kernel_size=5, padding=2),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(kernel_size=3, stride=2),\n            nn.Conv2d(192, 384, kernel_size=3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(384, 256, kernel_size=3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(256, 256, kernel_size=3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(kernel_size=3, stride=2),\n        )\n        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))\n        self.classifier = nn.Sequential(\n            nn.Dropout(),\n            nn.Linear(256 * 6 * 6, 4096),\n            nn.ReLU(inplace=True),\n            nn.Dropout(),\n            nn.Linear(4096, 4096),\n            nn.ReLU(inplace=True),\n            nn.Linear(4096, num_classes),\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.features(x)\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n        x = self.classifier(x)\n        return x\n\n\ndef alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> AlexNet:\n    r\"\"\"AlexNet model architecture from the\n    `\"One weird trick...\" <https://arxiv.org/abs/1404.5997>`_ paper.\n    The required minimum input size of the model is 63x63.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    model = AlexNet(**kwargs)\n    return model\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_convmixer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch.nn as nn\n\n__all__ = [\"ConvMixer\", \"convmixer_768_32_relu\"]\n\n\nclass Residual(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n\n    def forward(self, x):\n        return self.fn(x) + x\n\n\ndef ConvMixer(dim, depth, kernel_size=9, patch_size=7, n_classes=1000):\n    return nn.Sequential(\n        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),\n        nn.GELU(),\n        nn.BatchNorm2d(dim),\n        *[\n            nn.Sequential(\n                Residual(\n                    nn.Sequential(\n                        nn.Conv2d(dim, dim, kernel_size, groups=dim, padding=\"same\"),\n                        nn.GELU(),\n                        nn.BatchNorm2d(dim),\n                    )\n                ),\n                nn.Conv2d(dim, dim, kernel_size=1),\n                nn.GELU(),\n                nn.BatchNorm2d(dim),\n            )\n            for i in range(depth)\n        ],\n        nn.AdaptiveAvgPool2d((1, 1)),\n        nn.Flatten(),\n        nn.Linear(dim, n_classes)\n    )\n\n\ndef convmixer_768_32_relu(pretrained: bool = False, progress: bool = True, **kwargs):\n    \"\"\"\n    Constructs the ConvMixer model with 32 depth and 768 hidden size and ReLU activation layer.\n    .. note::\n        ConvMixer model with 32 depth and 768 hidden size and ReLU activation layer from the `Patched Are All You Need? <https://openreview.net/pdf?id=TVHS5Y4dNvM>`_ paper.\n    Args:\n        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``\n        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``\n    \"\"\"\n    model = ConvMixer(768, 32, kernel_size=7, patch_size=7, n_classes=1000)\n    return model\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_convnext.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom timm.models.layers import trunc_normal_, DropPath\n\n__all__ = [\"ConvNeXt\", \"convnext_tiny\"]\n\n\nclass Block(nn.Module):\n    r\"\"\" ConvNeXt Block. There are two equivalent implementations:\n    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)\n    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back\n    We use (2) as we find it slightly faster in PyTorch\n    \n    Args:\n        dim (int): Number of input channels.\n        drop_path (float): Stochastic depth rate. Default: 0.0\n        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.\n    \"\"\"\n\n    def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6):\n        super().__init__()\n        self.dwconv = nn.Conv2d(\n            dim, dim, kernel_size=7, padding=3, groups=dim\n        )  # depthwise conv\n        self.norm = LayerNorm(dim, eps=1e-6)\n        self.pwconv1 = nn.Linear(\n            dim, 4 * dim\n        )  # pointwise/1x1 convs, implemented with linear layers\n        self.act = nn.GELU()\n        self.pwconv2 = nn.Linear(4 * dim, dim)\n        self.gamma = (\n            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)\n            if layer_scale_init_value > 0\n            else None\n        )\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n\n    def forward(self, x):\n        input = x\n        x = self.dwconv(x)\n        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)\n        x = self.norm(x)\n        x = self.pwconv1(x)\n        x = self.act(x)\n        x = self.pwconv2(x)\n        if self.gamma is not None:\n            x = self.gamma * x\n        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)\n\n        x = input + self.drop_path(x)\n        return x\n\n\nclass ConvNeXt(nn.Module):\n    r\"\"\" ConvNeXt\n        A PyTorch impl of : `A ConvNet for the 2020s`  -\n          https://arxiv.org/pdf/2201.03545.pdf\n    Args:\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]\n        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]\n        drop_path_rate (float): Stochastic depth rate. Default: 0.\n        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.\n        head_init_scale (float): Init scaling value for classifier weights and biases. 
Default: 1.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_chans=3,\n        num_classes=1000,\n        depths=[3, 3, 9, 3],\n        dims=[96, 192, 384, 768],\n        drop_path_rate=0.0,\n        layer_scale_init_value=1e-6,\n        head_init_scale=1.0,\n    ):\n        super().__init__()\n\n        self.downsample_layers = (\n            nn.ModuleList()\n        )  # stem and 3 intermediate downsampling conv layers\n        stem = nn.Sequential(\n            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),\n            LayerNorm(dims[0], eps=1e-6, data_format=\"channels_first\"),\n        )\n        self.downsample_layers.append(stem)\n        for i in range(3):\n            downsample_layer = nn.Sequential(\n                LayerNorm(dims[i], eps=1e-6, data_format=\"channels_first\"),\n                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),\n            )\n            self.downsample_layers.append(downsample_layer)\n\n        self.stages = (\n            nn.ModuleList()\n        )  # 4 feature resolution stages, each consisting of multiple residual blocks\n        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]\n        cur = 0\n        for i in range(4):\n            stage = nn.Sequential(\n                *[\n                    Block(\n                        dim=dims[i],\n                        drop_path=dp_rates[cur + j],\n                        layer_scale_init_value=layer_scale_init_value,\n                    )\n                    for j in range(depths[i])\n                ]\n            )\n            self.stages.append(stage)\n            cur += depths[i]\n\n        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer\n        self.head = nn.Linear(dims[-1], num_classes)\n\n        self.apply(self._init_weights)\n        self.head.weight.data.mul_(head_init_scale)\n        self.head.bias.data.mul_(head_init_scale)\n\n    def _init_weights(self, m):\n        if isinstance(m, (nn.Conv2d, nn.Linear)):\n            trunc_normal_(m.weight, std=0.02)\n            nn.init.constant_(m.bias, 0)\n\n    def forward_features(self, x):\n        for i in range(4):\n            x = self.downsample_layers[i](x)\n            x = self.stages[i](x)\n        return self.norm(\n            x.mean([-2, -1])\n        )  # global average pooling, (N, C, H, W) -> (N, C)\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n\nclass LayerNorm(nn.Module):\n    r\"\"\" LayerNorm that supports two data formats: channels_last (default) or channels_first. \n    The ordering of the dimensions in the inputs. 
channels_last corresponds to inputs with \n    shape (batch_size, height, width, channels) while channels_first corresponds to inputs \n    with shape (batch_size, channels, height, width).\n    \"\"\"\n\n    def __init__(self, normalized_shape, eps=1e-6, data_format=\"channels_last\"):\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(normalized_shape))\n        self.bias = nn.Parameter(torch.zeros(normalized_shape))\n        self.eps = eps\n        self.data_format = data_format\n        if self.data_format not in [\"channels_last\", \"channels_first\"]:\n            raise NotImplementedError\n        self.normalized_shape = (normalized_shape,)\n\n    def forward(self, x):\n        if self.data_format == \"channels_last\":\n            return F.layer_norm(\n                x, self.normalized_shape, self.weight, self.bias, self.eps\n            )\n        elif self.data_format == \"channels_first\":\n            u = x.mean(1, keepdim=True)\n            s = (x - u).pow(2).mean(1, keepdim=True)\n            x = (x - u) / torch.sqrt(s + self.eps)\n            x = self.weight[:, None, None] * x + self.bias[:, None, None]\n            return x\n\n\ndef convnext_tiny(pretrained=False, in_22k=False, **kwargs):\n    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)\n    return model\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_crossformer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\n__all__ = [\"CrossFormer\", \"crossformer_tiny_patch4_group7_224\"]\n\n\nclass Mlp(nn.Module):\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        act_layer=nn.GELU,\n        drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass DynamicPosBias(nn.Module):\n    def __init__(self, dim, num_heads, residual):\n        super().__init__()\n        self.residual = residual\n        self.num_heads = num_heads\n        self.pos_dim = dim // 4\n        self.pos_proj = nn.Linear(2, self.pos_dim)\n        self.pos1 = nn.Sequential(\n            nn.LayerNorm(self.pos_dim),\n            nn.ReLU(inplace=True),\n            nn.Linear(self.pos_dim, self.pos_dim),\n        )\n        self.pos2 = nn.Sequential(\n            nn.LayerNorm(self.pos_dim),\n            nn.ReLU(inplace=True),\n            nn.Linear(self.pos_dim, self.pos_dim),\n        )\n        self.pos3 = nn.Sequential(\n            nn.LayerNorm(self.pos_dim),\n            nn.ReLU(inplace=True),\n            nn.Linear(self.pos_dim, self.num_heads),\n        )\n\n    def forward(self, biases):\n        if self.residual:\n            pos = self.pos_proj(biases)  # 2Wh-1 * 2Ww-1, heads\n            pos = pos + self.pos1(pos)\n            pos = pos + self.pos2(pos)\n            pos = self.pos3(pos)\n        else:\n            pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))\n        return pos\n\n    def flops(self, N):\n        flops = N * 2 * self.pos_dim\n        flops += N * self.pos_dim * self.pos_dim\n        flops += N * self.pos_dim * self.pos_dim\n        flops += N * self.pos_dim * self.num_heads\n        return flops\n\n\nclass Attention(nn.Module):\n    r\"\"\" Multi-head self attention module with dynamic position bias.\n    Args:\n        dim (int): Number of input channels.\n        group_size (tuple[int]): The height and width of the group.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        group_size,\n        num_heads,\n        qkv_bias=True,\n        qk_scale=None,\n        attn_drop=0.0,\n        proj_drop=0.0,\n        position_bias=True,\n    ):\n\n        super().__init__()\n        self.dim = dim\n        self.group_size = group_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n        self.position_bias = position_bias\n\n        if position_bias:\n            self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)\n\n            # generate mother-set\n            position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0])\n            position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1])\n            biases = torch.stack(\n                torch.meshgrid([position_bias_h, position_bias_w])\n            )  # 2, 2Wh-1, 2Ww-1\n            biases = biases.flatten(1).transpose(0, 1).float()\n            self.register_buffer(\"biases\", biases)\n\n            # get pair-wise relative position index for each token inside the group\n            coords_h = torch.arange(self.group_size[0])\n            coords_w = torch.arange(self.group_size[1])\n            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n            relative_coords = (\n                coords_flatten[:, :, None] - coords_flatten[:, None, :]\n            )  # 2, Wh*Ww, Wh*Ww\n            relative_coords = relative_coords.permute(\n                1, 2, 0\n            ).contiguous()  # Wh*Ww, Wh*Ww, 2\n            relative_coords[:, :, 0] += self.group_size[0] - 1  # shift to start from 0\n            relative_coords[:, :, 1] += self.group_size[1] - 1\n            relative_coords[:, :, 0] *= 2 * self.group_size[1] - 1\n            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n            self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_groups*B, N, C)\n            mask: (0/-inf) mask with shape of (num_groups, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv = (\n            self.qkv(x)\n            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)\n            .permute(2, 0, 3, 1, 4)\n        )\n        q, k, v = (\n            qkv[0],\n            qkv[1],\n            qkv[2],\n        )  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = q @ k.transpose(-2, -1)\n\n        if self.position_bias:\n            pos = self.pos(self.biases)  # 2Wh-1 * 2Ww-1, heads\n            # select position bias\n            relative_position_bias = pos[self.relative_position_index.view(-1)].view(\n                self.group_size[0] * self.group_size[1],\n                self.group_size[0] * self.group_size[1],\n                -1,\n            )  # Wh*Ww,Wh*Ww,nH\n            relative_position_bias = relative_position_bias.permute(\n                2, 
0, 1\n            ).contiguous()  # nH, Wh*Ww, Wh*Ww\n            attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(\n                1\n            ).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return (\n            f\"dim={self.dim}, group_size={self.group_size}, num_heads={self.num_heads}\"\n        )\n\n    def flops(self, N):\n        # calculate flops for 1 group with token length of N\n        flops = 0\n        # qkv = self.qkv(x)\n        flops += N * self.dim * 3 * self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += N * self.dim * self.dim\n        if self.position_bias:\n            flops += self.pos.flops(N)\n        return flops\n\n\nclass CrossFormerBlock(nn.Module):\n    r\"\"\" CrossFormer Block.\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        num_heads (int): Number of attention heads.\n        group_size (int): Group size.\n        lsda_flag (int): use SDA or LDA, 0 for SDA and 1 for LDA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        input_resolution,\n        num_heads,\n        group_size=7,\n        lsda_flag=0,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        drop=0.0,\n        attn_drop=0.0,\n        drop_path=0.0,\n        act_layer=nn.GELU,\n        norm_layer=nn.LayerNorm,\n        num_patch_size=1,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.group_size = group_size\n        self.lsda_flag = lsda_flag\n        self.mlp_ratio = mlp_ratio\n        self.num_patch_size = num_patch_size\n        if min(self.input_resolution) <= self.group_size:\n            # if group size is larger than input resolution, we don't partition groups\n            self.lsda_flag = 0\n            self.group_size = min(self.input_resolution)\n\n        self.norm1 = norm_layer(dim)\n\n        self.attn = Attention(\n            dim,\n            group_size=to_2tuple(self.group_size),\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            qk_scale=qk_scale,\n            attn_drop=attn_drop,\n            proj_drop=drop,\n            position_bias=True,\n        )\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(\n            in_features=dim,\n            hidden_features=mlp_hidden_dim,\n            act_layer=act_layer,\n            drop=drop,\n        )\n\n        attn_mask = None\n        self.register_buffer(\"attn_mask\", attn_mask)\n\n    def forward(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size %d, %d, %d\" % (L, H, W)\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # group embeddings\n        G = self.group_size\n        if self.lsda_flag == 0:  # 0 for SDA\n            x = x.reshape(B, H // G, G, W // G, G, C).permute(0, 1, 3, 2, 4, 5)\n        else:  # 1 for LDA\n            x = x.reshape(B, G, H // G, G, W // G, C).permute(0, 2, 4, 1, 3, 5)\n        x = x.reshape(B * H * W // G ** 2, G ** 2, C)\n\n        # multi-head self-attention\n        x = self.attn(x, mask=self.attn_mask)  # nW*B, G*G, C\n\n        # ungroup embeddings\n        x = x.reshape(B, H // G, W // G, G, G, C)\n        if self.lsda_flag == 0:\n            x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)\n        else:\n            x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, H, W, C)\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n    def extra_repr(self) -> str:\n        return (\n            f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \"\n            f\"group_size={self.group_size}, lsda_flag={self.lsda_flag}, mlp_ratio={self.mlp_ratio}\"\n        )\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # LSDA\n        nW = H * W / self.group_size / self.group_size\n        flops += nW * self.attn.flops(self.group_size * self.group_size)\n        # mlp\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        
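# sum of the norm1, grouped LSDA attention, MLP, and norm2 terms accumulated above\n        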
return flops\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(\n        self,\n        input_resolution,\n        dim,\n        norm_layer=nn.LayerNorm,\n        patch_size=[2],\n        num_input_patch_size=1,\n    ):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reductions = nn.ModuleList()\n        self.patch_size = patch_size\n        self.norm = norm_layer(dim)\n\n        for i, ps in enumerate(patch_size):\n            if i == len(patch_size) - 1:\n                out_dim = 2 * dim // 2 ** i\n            else:\n                out_dim = 2 * dim // 2 ** (i + 1)\n            stride = 2\n            padding = (ps - stride) // 2\n            self.reductions.append(\n                nn.Conv2d(dim, out_dim, kernel_size=ps, stride=stride, padding=padding)\n            )\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) is not even.\"\n\n        x = self.norm(x)\n        x = x.view(B, H, W, C).permute(0, 3, 1, 2)\n\n        xs = []\n        for i in range(len(self.reductions)):\n            tmp_x = self.reductions[i](x).flatten(2).transpose(1, 2)\n            xs.append(tmp_x)\n        x = torch.cat(xs, dim=2)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * self.dim\n        for i, ps in enumerate(self.patch_size):\n            if i == len(self.patch_size) - 1:\n                out_dim = 2 * self.dim // 2 ** i\n            else:\n                out_dim = 2 * self.dim // 2 ** (i + 1)\n            flops += (H // 2) * (W // 2) * ps * ps * out_dim * self.dim\n        return flops\n\n\nclass Stage(nn.Module):\n    \"\"\" CrossFormer blocks for one stage.\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        group_size (int): variable G in the paper, one group has GxG embeddings\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        input_resolution,\n        depth,\n        num_heads,\n        group_size,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        drop=0.0,\n        attn_drop=0.0,\n        drop_path=0.0,\n        norm_layer=nn.LayerNorm,\n        downsample=None,\n        use_checkpoint=False,\n        patch_size_end=[4],\n        num_patch_size=None,\n    ):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList()\n        for i in range(depth):\n            lsda_flag = 0 if (i % 2 == 0) else 1\n            self.blocks.append(\n                CrossFormerBlock(\n                    dim=dim,\n                    input_resolution=input_resolution,\n                    num_heads=num_heads,\n                    group_size=group_size,\n                    lsda_flag=lsda_flag,\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop,\n                    attn_drop=attn_drop,\n                    drop_path=drop_path[i]\n                    if isinstance(drop_path, list)\n                    else drop_path,\n                    norm_layer=norm_layer,\n                    num_patch_size=num_patch_size,\n                )\n            )\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(\n                input_resolution,\n                dim=dim,\n                norm_layer=norm_layer,\n                patch_size=patch_size_end,\n                num_input_patch_size=num_patch_size,\n            )\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            # if self.use_checkpoint:\n            #     x = checkpoint.checkpoint(blk, x)\n            # else:\n            x = blk(x)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops\n\n\nclass PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: [4].\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n    \"\"\"\n\n    def __init__(\n        self, img_size=224, patch_size=[4], in_chans=3, embed_dim=96, norm_layer=None\n    ):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        # patch_size = to_2tuple(patch_size)\n        patches_resolution = [\n            img_size[0] // patch_size[0],\n            img_size[1] // patch_size[0],\n        ]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.projs = nn.ModuleList()\n        for i, ps in enumerate(patch_size):\n            if i == len(patch_size) - 1:\n                dim = embed_dim // 2 ** i\n            else:\n                dim = embed_dim // 2 ** (i + 1)\n            stride = patch_size[0]\n            padding = (ps - patch_size[0]) // 2\n            self.projs.append(\n                nn.Conv2d(in_chans, dim, kernel_size=ps, stride=stride, padding=padding)\n            )\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert (\n            H == self.img_size[0] and W == self.img_size[1]\n        ), f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        xs = []\n        for i in range(len(self.projs)):\n            tx = self.projs[i](x).flatten(2).transpose(1, 2)\n            xs.append(tx)  # B Ph*Pw C\n        x = torch.cat(xs, dim=2)\n        if self.norm is not None:\n            x = self.norm(x)\n        return x\n\n    def flops(self):\n        Ho, Wo = self.patches_resolution\n        flops = 0\n        for i, ps in enumerate(self.patch_size):\n            if i == len(self.patch_size) - 1:\n                dim = self.embed_dim // 2 ** i\n            else:\n                dim = self.embed_dim // 2 ** (i + 1)\n            flops += (\n                Ho\n                * Wo\n                * dim\n                * self.in_chans\n                * (self.patch_size[i] * self.patch_size[i])\n            )\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\nclass CrossFormer(nn.Module):\n    r\"\"\" CrossFormer\n        A PyTorch impl of : `CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention`  -\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each stage.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        group_size (int): Group size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None\n        drop_rate (float): Dropout rate. Default: 0\n        attn_drop_rate (float): Attention dropout rate. 
Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False\n    \"\"\"\n\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=[4],\n        in_chans=3,\n        num_classes=1000,\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        group_size=7,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        attn_drop_rate=0.0,\n        drop_path_rate=0.1,\n        norm_layer=nn.LayerNorm,\n        ape=False,\n        patch_norm=True,\n        use_checkpoint=False,\n        merge_size=[[2], [2], [2]],\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.num_classes = num_classes\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.mlp_ratio = mlp_ratio\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size,\n            patch_size=patch_size,\n            in_chans=in_chans,\n            embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None,\n        )\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(\n                torch.zeros(1, num_patches, embed_dim)\n            )\n            trunc_normal_(self.absolute_pos_embed, std=0.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [\n            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))\n        ]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n\n        num_patch_sizes = [len(patch_size)] + [len(m) for m in merge_size]\n        for i_layer in range(self.num_layers):\n            patch_size_end = (\n                merge_size[i_layer] if i_layer < self.num_layers - 1 else None\n            )\n            num_patch_size = num_patch_sizes[i_layer]\n            layer = Stage(\n                dim=int(embed_dim * 2 ** i_layer),\n                input_resolution=(\n                    patches_resolution[0] // (2 ** i_layer),\n                    patches_resolution[1] // (2 ** i_layer),\n                ),\n                depth=depths[i_layer],\n                num_heads=num_heads[i_layer],\n                group_size=group_size[i_layer],\n                mlp_ratio=self.mlp_ratio,\n                qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                drop=drop_rate,\n                attn_drop=attn_drop_rate,\n                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],\n                norm_layer=norm_layer,\n                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                use_checkpoint=use_checkpoint,\n                patch_size_end=patch_size_end,\n  
              num_patch_size=num_patch_size,\n            )\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features)\n        self.avgpool = nn.AdaptiveAvgPool1d(1)\n        self.head = (\n            nn.Linear(self.num_features, num_classes)\n            if num_classes > 0\n            else nn.Identity()\n        )\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=0.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def no_weight_decay(self):\n        return {\"absolute_pos_embed\"}\n\n    def no_weight_decay_keywords(self):\n        return {\"relative_position_bias_table\"}\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for layer in self.layers:\n            x = layer(x)\n\n        x = self.norm(x)  # B L C\n        x = self.avgpool(x.transpose(1, 2))  # B C 1\n        x = torch.flatten(x, 1)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n    def flops(self):\n        flops = 0\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n        flops += (\n            self.num_features\n            * self.patches_resolution[0]\n            * self.patches_resolution[1]\n            // (2 ** self.num_layers)\n        )\n        flops += self.num_features * self.num_classes\n        return flops\n\n\ndef _create_cross_former(arch, pretrained=False, progress=True, **model_kwargs):\n    model = CrossFormer(**model_kwargs)\n    return model\n\n\ndef crossformer_tiny_patch4_group7_224(pretrained=False, progress=True, **kwargs):\n    \"\"\"\n    Constructs CrossFormer-T 224x224 model.\n    .. note::\n        CrossFormer-T 224x224 model from `\"CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention\" <https://arxiv.org/pdf/2108.00154.pdf>`_.\n    Args:\n        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``\n        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``\n    For example:\n    .. code-block:: python\n        >>> import flowvision\n        >>> crossformer_tiny_patch4_group7_224 = flowvision.models.crossformer_tiny_patch4_group7_224(pretrained=False, progress=True)\n    \"\"\"\n    model_kwargs = dict(\n        img_size=224,\n        patch_size=(4, 8, 16, 32),\n        embed_dim=64,\n        depths=(1, 1, 8, 6),\n        num_heads=(2, 4, 8, 16),\n        group_size=(7, 7, 7, 7),\n        merge_size=((2, 4), (2, 4), (2, 4)),\n        drop_path_rate=0.1,\n        **kwargs,\n    )\n    return _create_cross_former(\n        \"crossformer_tiny_patch4_group7_224\",\n        pretrained=pretrained,\n        progress=progress,\n        **model_kwargs,\n    )\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_densenet.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom collections import OrderedDict\nfrom typing import Any, List, Tuple\n\n__all__ = [\n    \"DenseNet\",\n    \"densenet121\",\n]\n\n\nclass _DenseLayer(nn.Module):\n    def __init__(\n        self,\n        num_input_features: int,\n        growth_rate: int,\n        bn_size: int,\n        drop_rate: float,\n        memory_efficient: bool = False,\n    ) -> None:\n        super().__init__()\n        self.norm1: nn.BatchNorm2d\n        self.add_module(\"norm1\", nn.BatchNorm2d(num_input_features))\n        self.relu1: nn.ReLU\n        self.add_module(\"relu1\", nn.ReLU(inplace=True))\n        self.conv1: nn.Conv2d\n        self.add_module(\n            \"conv1\",\n            nn.Conv2d(\n                num_input_features,\n                bn_size * growth_rate,\n                kernel_size=1,\n                stride=1,\n                bias=False,\n            ),\n        )\n        self.norm2: nn.BatchNorm2d\n        self.add_module(\"norm2\", nn.BatchNorm2d(bn_size * growth_rate))\n        self.relu2: nn.ReLU\n        self.add_module(\"relu2\", nn.ReLU(inplace=True))\n        self.conv2: nn.Conv2d\n        self.add_module(\n            \"conv2\",\n            nn.Conv2d(\n                bn_size * growth_rate,\n                growth_rate,\n                kernel_size=3,\n                stride=1,\n                padding=1,\n                bias=False,\n            ),\n        )\n        self.drop_rate = float(drop_rate)\n        self.memory_efficient = memory_efficient\n\n    def bn_function(self, inputs: List[Tensor]) -> Tensor:\n        concated_features = torch.cat(inputs, 1)\n        bottleneck_output = self.conv1(\n            self.relu1(self.norm1(concated_features))\n        )  # noqa: T484\n        return bottleneck_output\n\n    # todo: rewrite when torchscript supports any\n    def any_requires_grad(self, input: List[Tensor]) -> bool:\n        for tensor in input:\n            if tensor.requires_grad:\n                return True\n        return False\n\n    # torchscript does not yet support *args, so we overload method\n    # allowing it to take either a List[Tensor] or single Tensor\n    def forward(self, input: Tensor) -> Tensor:  # noqa: F811\n        if isinstance(input, Tensor):\n            prev_features = [input]\n        else:\n            prev_features = input\n\n        if self.memory_efficient and self.any_requires_grad(prev_features):\n            if torch.jit.is_scripting():\n                raise Exception(\"Memory Efficient not supported in JIT\")\n\n            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)\n        else:\n            bottleneck_output = self.bn_function(prev_features)\n\n        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))\n        if self.drop_rate > 0:\n            new_features 
= F.dropout(\n                new_features, p=self.drop_rate, training=self.training\n            )\n        return new_features\n\n\nclass _DenseBlock(nn.ModuleDict):\n    _version = 2\n\n    def __init__(\n        self,\n        num_layers: int,\n        num_input_features: int,\n        bn_size: int,\n        growth_rate: int,\n        drop_rate: float,\n        memory_efficient: bool = False,\n    ) -> None:\n        super().__init__()\n        for i in range(num_layers):\n            layer = _DenseLayer(\n                num_input_features + i * growth_rate,\n                growth_rate=growth_rate,\n                bn_size=bn_size,\n                drop_rate=drop_rate,\n                memory_efficient=memory_efficient,\n            )\n            self.add_module(\"denselayer%d\" % (i + 1), layer)\n\n    def forward(self, init_features: Tensor) -> Tensor:\n        features = [init_features]\n        for name, layer in self.items():\n            new_features = layer(features)\n            features.append(new_features)\n        return torch.cat(features, 1)\n\n\nclass _Transition(nn.Sequential):\n    def __init__(self, num_input_features: int, num_output_features: int) -> None:\n        super().__init__()\n        self.add_module(\"norm\", nn.BatchNorm2d(num_input_features))\n        self.add_module(\"relu\", nn.ReLU(inplace=True))\n        self.add_module(\n            \"conv\",\n            nn.Conv2d(\n                num_input_features,\n                num_output_features,\n                kernel_size=1,\n                stride=1,\n                bias=False,\n            ),\n        )\n        self.add_module(\"pool\", nn.AvgPool2d(kernel_size=2, stride=2))\n\n\nclass DenseNet(nn.Module):\n    r\"\"\"Densenet-BC model class, based on\n    `\"Densely Connected Convolutional Networks\" <https://arxiv.org/pdf/1608.06993.pdf>`_.\n    Args:\n        growth_rate (int) - how many filters to add each layer (`k` in paper)\n        block_config (list of 4 ints) - how many layers in each pooling block\n        num_init_features (int) - the number of filters to learn in the first convolution layer\n        bn_size (int) - multiplicative factor for number of bottle neck layers\n          (i.e. bn_size * k features in the bottleneck layer)\n        drop_rate (float) - dropout rate after each dense layer\n        num_classes (int) - number of classification classes\n        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,\n          but slower. Default: *False*. 
See `\"paper\" <https://arxiv.org/pdf/1707.06990.pdf>`_.\n    \"\"\"\n\n    def __init__(\n        self,\n        growth_rate: int = 32,\n        block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),\n        num_init_features: int = 64,\n        bn_size: int = 4,\n        drop_rate: float = 0,\n        num_classes: int = 1000,\n        memory_efficient: bool = False,\n    ) -> None:\n\n        super().__init__()\n\n        # First convolution\n        self.features = nn.Sequential(\n            OrderedDict(\n                [\n                    (\n                        \"conv0\",\n                        nn.Conv2d(\n                            3,\n                            num_init_features,\n                            kernel_size=7,\n                            stride=2,\n                            padding=3,\n                            bias=False,\n                        ),\n                    ),\n                    (\"norm0\", nn.BatchNorm2d(num_init_features)),\n                    (\"relu0\", nn.ReLU(inplace=True)),\n                    (\"pool0\", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),\n                ]\n            )\n        )\n\n        # Each denseblock\n        num_features = num_init_features\n        for i, num_layers in enumerate(block_config):\n            block = _DenseBlock(\n                num_layers=num_layers,\n                num_input_features=num_features,\n                bn_size=bn_size,\n                growth_rate=growth_rate,\n                drop_rate=drop_rate,\n                memory_efficient=memory_efficient,\n            )\n            self.features.add_module(\"denseblock%d\" % (i + 1), block)\n            num_features = num_features + num_layers * growth_rate\n            if i != len(block_config) - 1:\n                trans = _Transition(\n                    num_input_features=num_features,\n                    num_output_features=num_features // 2,\n                )\n                self.features.add_module(\"transition%d\" % (i + 1), trans)\n                num_features = num_features // 2\n\n        # Final batch norm\n        self.features.add_module(\"norm5\", nn.BatchNorm2d(num_features))\n\n        # Linear layer\n        self.classifier = nn.Linear(num_features, num_classes)\n\n        # Official init from torch repo.\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.constant_(m.bias, 0)\n\n    def forward(self, x: Tensor) -> Tensor:\n        features = self.features(x)\n        out = F.relu(features, inplace=True)\n        out = F.adaptive_avg_pool2d(out, (1, 1))\n        out = torch.flatten(out, 1)\n        out = self.classifier(out)\n        return out\n\n\ndef _densenet(\n    growth_rate: int,\n    block_config: Tuple[int, int, int, int],\n    num_init_features: int,\n    progress: bool,\n    **kwargs: Any,\n) -> DenseNet:\n    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)\n    return model\n\n\ndef densenet121(progress: bool = True, **kwargs: Any) -> DenseNet:\n    r\"\"\"Densenet-121 model from\n    `\"Densely Connected Convolutional Networks\" <https://arxiv.org/pdf/1608.06993.pdf>`_.\n    The required minimum input size of the model is 29x29.\n    Args:\n        progress (bool): If True, displays a progress bar of the download to stderr\n        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,\n          but slower. Default: *False*. See `\"paper\" <https://arxiv.org/pdf/1707.06990.pdf>`_.\n    \"\"\"\n\n    return _densenet(32, (6, 12, 24, 16), 64, progress, **kwargs)\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_efficientnet.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nfrom torch import nn, Tensor\nfrom torchvision.ops import StochasticDepth\n\nimport copy\nimport math\nimport warnings\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom typing import Any, Callable, Optional, List, Sequence, Tuple, Union\n\n__all__ = [\n    \"EfficientNet\",\n    \"efficientnet_b0\",\n]\n\n\nclass SqueezeExcitation(torch.nn.Module):\n    \"\"\"\n    This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).\n    Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in in eq. 3.\n    Args:\n        input_channels (int): Number of channels in the input image\n        squeeze_channels (int): Number of squeeze channels\n        activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``\n        scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``\n    \"\"\"\n\n    def __init__(\n        self,\n        input_channels: int,\n        squeeze_channels: int,\n        activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,\n        scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,\n    ) -> None:\n        super().__init__()\n        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)\n        self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)\n        self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)\n        self.activation = activation()\n        self.scale_activation = scale_activation()\n\n    def _scale(self, input: Tensor) -> Tensor:\n        scale = self.avgpool(input)\n        scale = self.fc1(scale)\n        scale = self.activation(scale)\n        scale = self.fc2(scale)\n        return self.scale_activation(scale)\n\n    def forward(self, input: Tensor) -> Tensor:\n        scale = self._scale(input)\n        return scale * input\n\n\nclass ConvNormActivation(torch.nn.Sequential):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int = 3,\n        stride: int = 1,\n        padding: Optional[int] = None,\n        groups: int = 1,\n        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,\n        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,\n        dilation: int = 1,\n        inplace: Optional[bool] = True,\n        bias: Optional[bool] = None,\n        conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,\n    ) -> None:\n\n        if padding is None:\n            padding = (kernel_size - 1) // 2 * dilation\n        if bias is None:\n            bias = norm_layer is None\n\n        layers = [\n            conv_layer(\n                in_channels,\n                out_channels,\n                kernel_size,\n                stride,\n                padding,\n                
dilation=dilation,\n                groups=groups,\n                bias=bias,\n            )\n        ]\n\n        if norm_layer is not None:\n            layers.append(norm_layer(out_channels))\n\n        if activation_layer is not None:\n            params = {} if inplace is None else {\"inplace\": inplace}\n            layers.append(activation_layer(**params))\n        super().__init__(*layers)\n        self.out_channels = out_channels\n\n        if self.__class__ == ConvNormActivation:\n            warnings.warn(\n                \"Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead.\"\n            )\n\n\nclass Conv2dNormActivation(ConvNormActivation):\n    \"\"\"\n    Configurable block used for Convolution2d-Normalization-Activation blocks.\n    Args:\n        in_channels (int): Number of channels in the input image\n        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block\n        kernel_size (int, optional): Size of the convolving kernel. Default: 3\n        stride (int, optional): Stride of the convolution. Default: 1\n        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``\n        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1\n        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``\n        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``\n        dilation (int): Spacing between kernel elements. Default: 1\n        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``\n        bias (bool, optional): Whether to use bias in the convolution layer. 
By default, biases are included if ``norm_layer is None``.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int = 3,\n        stride: int = 1,\n        padding: Optional[int] = None,\n        groups: int = 1,\n        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,\n        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,\n        dilation: int = 1,\n        inplace: Optional[bool] = True,\n        bias: Optional[bool] = None,\n    ) -> None:\n\n        super().__init__(\n            in_channels,\n            out_channels,\n            kernel_size,\n            stride,\n            padding,\n            groups,\n            norm_layer,\n            activation_layer,\n            dilation,\n            inplace,\n            bias,\n            torch.nn.Conv2d,\n        )\n\n\ndef _make_divisible(v, divisor=8, min_value=None, round_limit=0.9):\n    min_value = min_value or divisor\n    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)\n    # Make sure that round down does not go down by more than 10%.\n    if new_v < round_limit * v:\n        new_v += divisor\n    return new_v\n\n\n@dataclass\nclass _MBConvConfig:\n    expand_ratio: float\n    kernel: int\n    stride: int\n    input_channels: int\n    out_channels: int\n    num_layers: int\n    block: Callable[..., nn.Module]\n\n    @staticmethod\n    def adjust_channels(\n        channels: int, width_mult: float, min_value: Optional[int] = None\n    ) -> int:\n        return _make_divisible(channels * width_mult, 8, min_value)\n\n\nclass MBConvConfig(_MBConvConfig):\n    # Stores information listed at Table 1 of the EfficientNet paper & Table 4 of the EfficientNetV2 paper\n    def __init__(\n        self,\n        expand_ratio: float,\n        kernel: int,\n        stride: int,\n        input_channels: int,\n        out_channels: int,\n        num_layers: int,\n        width_mult: float = 1.0,\n        depth_mult: float = 1.0,\n        block: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        input_channels = self.adjust_channels(input_channels, width_mult)\n        out_channels = self.adjust_channels(out_channels, width_mult)\n        num_layers = self.adjust_depth(num_layers, depth_mult)\n        if block is None:\n            block = MBConv\n        super().__init__(\n            expand_ratio,\n            kernel,\n            stride,\n            input_channels,\n            out_channels,\n            num_layers,\n            block,\n        )\n\n    @staticmethod\n    def adjust_depth(num_layers: int, depth_mult: float):\n        return int(math.ceil(num_layers * depth_mult))\n\n\nclass FusedMBConvConfig(_MBConvConfig):\n    # Stores information listed at Table 4 of the EfficientNetV2 paper\n    def __init__(\n        self,\n        expand_ratio: float,\n        kernel: int,\n        stride: int,\n        input_channels: int,\n        out_channels: int,\n        num_layers: int,\n        block: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        if block is None:\n            block = FusedMBConv\n        super().__init__(\n            expand_ratio,\n            kernel,\n            stride,\n            input_channels,\n            out_channels,\n            num_layers,\n            block,\n        )\n\n\nclass MBConv(nn.Module):\n    def __init__(\n        self,\n        cnf: MBConvConfig,\n        stochastic_depth_prob: float,\n        norm_layer: 
Callable[..., nn.Module],\n        se_layer: Callable[..., nn.Module] = SqueezeExcitation,\n    ) -> None:\n        super().__init__()\n\n        if not (1 <= cnf.stride <= 2):\n            raise ValueError(\"illegal stride value\")\n\n        self.use_res_connect = (\n            cnf.stride == 1 and cnf.input_channels == cnf.out_channels\n        )\n\n        layers: List[nn.Module] = []\n        activation_layer = nn.SiLU\n\n        # expand\n        expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)\n        if expanded_channels != cnf.input_channels:\n            layers.append(\n                Conv2dNormActivation(\n                    cnf.input_channels,\n                    expanded_channels,\n                    kernel_size=1,\n                    norm_layer=norm_layer,\n                    activation_layer=activation_layer,\n                )\n            )\n\n        # depthwise\n        layers.append(\n            Conv2dNormActivation(\n                expanded_channels,\n                expanded_channels,\n                kernel_size=cnf.kernel,\n                stride=cnf.stride,\n                groups=expanded_channels,\n                norm_layer=norm_layer,\n                activation_layer=activation_layer,\n            )\n        )\n\n        # squeeze and excitation\n        squeeze_channels = max(1, cnf.input_channels // 4)\n        layers.append(\n            se_layer(\n                expanded_channels,\n                squeeze_channels,\n                activation=partial(nn.SiLU, inplace=True),\n            )\n        )\n\n        # project\n        layers.append(\n            Conv2dNormActivation(\n                expanded_channels,\n                cnf.out_channels,\n                kernel_size=1,\n                norm_layer=norm_layer,\n                activation_layer=None,\n            )\n        )\n\n        self.block = nn.Sequential(*layers)\n        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, \"row\")\n        self.out_channels = cnf.out_channels\n\n    def forward(self, input: Tensor) -> Tensor:\n        result = self.block(input)\n        if self.use_res_connect:\n            result = self.stochastic_depth(result)\n            result += input\n        return result\n\n\nclass FusedMBConv(nn.Module):\n    def __init__(\n        self,\n        cnf: FusedMBConvConfig,\n        stochastic_depth_prob: float,\n        norm_layer: Callable[..., nn.Module],\n    ) -> None:\n        super().__init__()\n\n        if not (1 <= cnf.stride <= 2):\n            raise ValueError(\"illegal stride value\")\n\n        self.use_res_connect = (\n            cnf.stride == 1 and cnf.input_channels == cnf.out_channels\n        )\n\n        layers: List[nn.Module] = []\n        activation_layer = nn.SiLU\n\n        expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)\n        if expanded_channels != cnf.input_channels:\n            # fused expand\n            layers.append(\n                Conv2dNormActivation(\n                    cnf.input_channels,\n                    expanded_channels,\n                    kernel_size=cnf.kernel,\n                    stride=cnf.stride,\n                    norm_layer=norm_layer,\n                    activation_layer=activation_layer,\n                )\n            )\n\n            # project\n            layers.append(\n                Conv2dNormActivation(\n                    expanded_channels,\n                    cnf.out_channels,\n                    
kernel_size=1,\n                    norm_layer=norm_layer,\n                    activation_layer=None,\n                )\n            )\n        else:\n            layers.append(\n                Conv2dNormActivation(\n                    cnf.input_channels,\n                    cnf.out_channels,\n                    kernel_size=cnf.kernel,\n                    stride=cnf.stride,\n                    norm_layer=norm_layer,\n                    activation_layer=activation_layer,\n                )\n            )\n\n        self.block = nn.Sequential(*layers)\n        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, \"row\")\n        self.out_channels = cnf.out_channels\n\n    def forward(self, input: Tensor) -> Tensor:\n        result = self.block(input)\n        if self.use_res_connect:\n            result = self.stochastic_depth(result)\n            result += input\n        return result\n\n\nclass EfficientNet(nn.Module):\n    def __init__(\n        self,\n        inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],\n        dropout: float,\n        stochastic_depth_prob: float = 0.2,\n        num_classes: int = 1000,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n        last_channel: Optional[int] = None,\n        **kwargs: Any,\n    ) -> None:\n        \"\"\"\n        EfficientNet V1 and V2 main class\n        Args:\n            inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure\n            dropout (float): The dropout probability\n            stochastic_depth_prob (float): The stochastic depth probability\n            num_classes (int): Number of classes\n            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use\n            last_channel (int): The number of channels on the penultimate layer\n        \"\"\"\n        super().__init__()\n        if not inverted_residual_setting:\n            raise ValueError(\"The inverted_residual_setting should not be empty\")\n        elif not (\n            isinstance(inverted_residual_setting, Sequence)\n            and all([isinstance(s, _MBConvConfig) for s in inverted_residual_setting])\n        ):\n            raise TypeError(\n                \"The inverted_residual_setting should be List[MBConvConfig]\"\n            )\n\n        if \"block\" in kwargs:\n            warnings.warn(\n                \"The parameter 'block' is deprecated since 0.13 and will be removed in 0.15. 
\"\n                \"Please pass this information on 'MBConvConfig.block' instead.\"\n            )\n            if kwargs[\"block\"] is not None:\n                for s in inverted_residual_setting:\n                    if isinstance(s, MBConvConfig):\n                        s.block = kwargs[\"block\"]\n\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n\n        layers: List[nn.Module] = []\n\n        # building first layer\n        firstconv_output_channels = inverted_residual_setting[0].input_channels\n        layers.append(\n            Conv2dNormActivation(\n                3,\n                firstconv_output_channels,\n                kernel_size=3,\n                stride=2,\n                norm_layer=norm_layer,\n                activation_layer=nn.SiLU,\n            )\n        )\n\n        # building inverted residual blocks\n        total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting)\n        stage_block_id = 0\n        for cnf in inverted_residual_setting:\n            stage: List[nn.Module] = []\n            for _ in range(cnf.num_layers):\n                # copy to avoid modifications. shallow copy is enough\n                block_cnf = copy.copy(cnf)\n\n                # overwrite info if not the first conv in the stage\n                if stage:\n                    block_cnf.input_channels = block_cnf.out_channels\n                    block_cnf.stride = 1\n\n                # adjust stochastic depth probability based on the depth of the stage block\n                sd_prob = (\n                    stochastic_depth_prob * float(stage_block_id) / total_stage_blocks\n                )\n\n                stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer))\n                stage_block_id += 1\n\n            layers.append(nn.Sequential(*stage))\n\n        # building last several layers\n        lastconv_input_channels = inverted_residual_setting[-1].out_channels\n        lastconv_output_channels = (\n            last_channel if last_channel is not None else 4 * lastconv_input_channels\n        )\n        layers.append(\n            Conv2dNormActivation(\n                lastconv_input_channels,\n                lastconv_output_channels,\n                kernel_size=1,\n                norm_layer=norm_layer,\n                activation_layer=nn.SiLU,\n            )\n        )\n\n        self.features = nn.Sequential(*layers)\n        self.avgpool = nn.AdaptiveAvgPool2d(1)\n        self.classifier = nn.Sequential(\n            nn.Dropout(p=dropout, inplace=True),\n            nn.Linear(lastconv_output_channels, num_classes),\n        )\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\")\n                if m.bias is not None:\n                    nn.init.zeros_(m.bias)\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.ones_(m.weight)\n                nn.init.zeros_(m.bias)\n            elif isinstance(m, nn.Linear):\n                init_range = 1.0 / math.sqrt(m.out_features)\n                nn.init.uniform_(m.weight, -init_range, init_range)\n                nn.init.zeros_(m.bias)\n\n    def _forward_impl(self, x: Tensor) -> Tensor:\n        x = self.features(x)\n\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n\n        x = self.classifier(x)\n\n        return x\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self._forward_impl(x)\n\n\ndef 
_efficientnet(\n    inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],\n    dropout: float,\n    last_channel: Optional[int],\n    progress: bool,\n    **kwargs: Any,\n) -> EfficientNet:\n    model = EfficientNet(\n        inverted_residual_setting, dropout, last_channel=last_channel, **kwargs\n    )\n    return model\n\n\ndef _efficientnet_conf(\n    arch: str, **kwargs: Any,\n) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]:\n    inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]]\n    if arch.startswith(\"efficientnet_b\"):\n        bneck_conf = partial(\n            MBConvConfig,\n            width_mult=kwargs.pop(\"width_mult\"),\n            depth_mult=kwargs.pop(\"depth_mult\"),\n        )\n        inverted_residual_setting = [\n            bneck_conf(1, 3, 1, 32, 16, 1),\n            bneck_conf(6, 3, 2, 16, 24, 2),\n            bneck_conf(6, 5, 2, 24, 40, 2),\n            bneck_conf(6, 3, 2, 40, 80, 3),\n            bneck_conf(6, 5, 1, 80, 112, 3),\n            bneck_conf(6, 5, 2, 112, 192, 4),\n            bneck_conf(6, 3, 1, 192, 320, 1),\n        ]\n        last_channel = None\n    elif arch.startswith(\"efficientnet_v2_s\"):\n        inverted_residual_setting = [\n            FusedMBConvConfig(1, 3, 1, 24, 24, 2),\n            FusedMBConvConfig(4, 3, 2, 24, 48, 4),\n            FusedMBConvConfig(4, 3, 2, 48, 64, 4),\n            MBConvConfig(4, 3, 2, 64, 128, 6),\n            MBConvConfig(6, 3, 1, 128, 160, 9),\n            MBConvConfig(6, 3, 2, 160, 256, 15),\n        ]\n        last_channel = 1280\n    elif arch.startswith(\"efficientnet_v2_m\"):\n        inverted_residual_setting = [\n            FusedMBConvConfig(1, 3, 1, 24, 24, 3),\n            FusedMBConvConfig(4, 3, 2, 24, 48, 5),\n            FusedMBConvConfig(4, 3, 2, 48, 80, 5),\n            MBConvConfig(4, 3, 2, 80, 160, 7),\n            MBConvConfig(6, 3, 1, 160, 176, 14),\n            MBConvConfig(6, 3, 2, 176, 304, 18),\n            MBConvConfig(6, 3, 1, 304, 512, 5),\n        ]\n        last_channel = 1280\n    elif arch.startswith(\"efficientnet_v2_l\"):\n        inverted_residual_setting = [\n            FusedMBConvConfig(1, 3, 1, 32, 32, 4),\n            FusedMBConvConfig(4, 3, 2, 32, 64, 7),\n            FusedMBConvConfig(4, 3, 2, 64, 96, 7),\n            MBConvConfig(4, 3, 2, 96, 192, 10),\n            MBConvConfig(6, 3, 1, 192, 224, 19),\n            MBConvConfig(6, 3, 2, 224, 384, 25),\n            MBConvConfig(6, 3, 1, 384, 640, 7),\n        ]\n        last_channel = 1280\n    else:\n        raise ValueError(f\"Unsupported model type {arch}\")\n\n    return inverted_residual_setting, last_channel\n\n\ndef efficientnet_b0(progress: bool = True, **kwargs: Any) -> EfficientNet:\n    \"\"\"\n    Constructs an EfficientNet B0 architecture from\n    `\"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks\" <https://arxiv.org/abs/1905.11946>`_.\n    Args:\n        weights (EfficientNet_B0_Weights, optional): The pretrained weights for the model\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n\n    inverted_residual_setting, last_channel = _efficientnet_conf(\n        \"efficientnet_b0\", width_mult=1.0, depth_mult=1.0\n    )\n    return _efficientnet(\n        inverted_residual_setting, 0.2, last_channel, progress, **kwargs\n    )\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_ghostnet.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport math\n\n\n__all__ = [\"ghost_net\"]\n\n\ndef _make_divisible(v, divisor, min_value=None):\n    \"\"\"\n    This function is taken from the original tf repo.\n    It ensures that all layers have a channel number that is divisible by 8\n    It can be seen here:\n    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py\n    \"\"\"\n    if min_value is None:\n        min_value = divisor\n    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)\n    # Make sure that round down does not go down by more than 10%.\n    if new_v < 0.9 * v:\n        new_v += divisor\n    return new_v\n\n\nclass SELayer(nn.Module):\n    def __init__(self, channel, reduction=4):\n        super(SELayer, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Sequential(\n            nn.Linear(channel, channel // reduction),\n            nn.ReLU(inplace=True),\n            nn.Linear(channel // reduction, channel),\n        )\n\n    def forward(self, x):\n        b, c, _, _ = x.size()\n        y = self.avg_pool(x).view(b, c)\n        y = self.fc(y).view(b, c, 1, 1)\n        y = torch.clamp(y, 0, 1)\n        return x * y\n\n\ndef depthwise_conv(inp, oup, kernel_size=3, stride=1, relu=False):\n    return nn.Sequential(\n        nn.Conv2d(\n            inp, oup, kernel_size, stride, kernel_size // 2, groups=inp, bias=False\n        ),\n        nn.BatchNorm2d(oup),\n        nn.ReLU(inplace=True) if relu else nn.Sequential(),\n    )\n\n\nclass GhostModule(nn.Module):\n    def __init__(\n        self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True\n    ):\n        super(GhostModule, self).__init__()\n        self.oup = oup\n        init_channels = math.ceil(oup / ratio)\n        new_channels = init_channels * (ratio - 1)\n\n        self.primary_conv = nn.Sequential(\n            nn.Conv2d(\n                inp, init_channels, kernel_size, stride, kernel_size // 2, bias=False\n            ),\n            nn.BatchNorm2d(init_channels),\n            nn.ReLU(inplace=True) if relu else nn.Sequential(),\n        )\n\n        self.cheap_operation = nn.Sequential(\n            nn.Conv2d(\n                init_channels,\n                new_channels,\n                dw_size,\n                1,\n                dw_size // 2,\n                groups=init_channels,\n                bias=False,\n            ),\n            nn.BatchNorm2d(new_channels),\n            nn.ReLU(inplace=True) if relu else nn.Sequential(),\n        )\n\n    def forward(self, x):\n        x1 = self.primary_conv(x)\n        x2 = self.cheap_operation(x1)\n        out = torch.cat([x1, x2], dim=1)\n        return out[:, : self.oup, :, :]\n\n\nclass GhostBottleneck(nn.Module):\n    def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se):\n        super(GhostBottleneck, self).__init__()\n      
  assert stride in [1, 2]\n\n        self.conv = nn.Sequential(\n            # pw\n            GhostModule(inp, hidden_dim, kernel_size=1, relu=True),\n            # dw\n            depthwise_conv(hidden_dim, hidden_dim, kernel_size, stride, relu=False)\n            if stride == 2\n            else nn.Sequential(),\n            # Squeeze-and-Excite\n            SELayer(hidden_dim) if use_se else nn.Sequential(),\n            # pw-linear\n            GhostModule(hidden_dim, oup, kernel_size=1, relu=False),\n        )\n\n        if stride == 1 and inp == oup:\n            self.shortcut = nn.Sequential()\n        else:\n            self.shortcut = nn.Sequential(\n                depthwise_conv(inp, inp, kernel_size, stride, relu=False),\n                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),\n                nn.BatchNorm2d(oup),\n            )\n\n    def forward(self, x):\n        return self.conv(x) + self.shortcut(x)\n\n\nclass GhostNet(nn.Module):\n    def __init__(self, cfgs, num_classes=1000, width_mult=1.0):\n        super(GhostNet, self).__init__()\n        # setting of inverted residual blocks\n        self.cfgs = cfgs\n\n        # building first layer\n        output_channel = _make_divisible(16 * width_mult, 4)\n        layers = [\n            nn.Sequential(\n                nn.Conv2d(3, output_channel, 3, 2, 1, bias=False),\n                nn.BatchNorm2d(output_channel),\n                nn.ReLU(inplace=True),\n            )\n        ]\n        input_channel = output_channel\n\n        # building inverted residual blocks\n        block = GhostBottleneck\n        for k, exp_size, c, use_se, s in self.cfgs:\n            output_channel = _make_divisible(c * width_mult, 4)\n            hidden_channel = _make_divisible(exp_size * width_mult, 4)\n            layers.append(\n                block(input_channel, hidden_channel, output_channel, k, s, use_se)\n            )\n            input_channel = output_channel\n        self.features = nn.Sequential(*layers)\n\n        # building last several layers\n        output_channel = _make_divisible(exp_size * width_mult, 4)\n        self.squeeze = nn.Sequential(\n            nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False),\n            nn.BatchNorm2d(output_channel),\n            nn.ReLU(inplace=True),\n            nn.AdaptiveAvgPool2d((1, 1)),\n        )\n        input_channel = output_channel\n\n        output_channel = 1280\n        self.classifier = nn.Sequential(\n            nn.Linear(input_channel, output_channel, bias=False),\n            nn.BatchNorm1d(output_channel),\n            nn.ReLU(inplace=True),\n            nn.Dropout(0.2),\n            nn.Linear(output_channel, num_classes),\n        )\n\n        self._initialize_weights()\n\n    def forward(self, x):\n        x = self.features(x)\n        x = self.squeeze(x)\n        x = x.view(x.size(0), -1)\n        x = self.classifier(x)\n        return x\n\n    def _initialize_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n\ndef ghost_net(**kwargs):\n    \"\"\"\n    Constructs a GhostNet model\n    \"\"\"\n    cfgs = [\n        # k, t, c, SE, s\n        [3, 16, 16, 0, 1],\n        [3, 48, 24, 0, 2],\n        [3, 72, 24, 0, 1],\n        [5, 72, 40, 1, 2],\n        [5, 120, 40, 1, 1],\n        [3, 240, 80, 0, 
2],\n        [3, 200, 80, 0, 1],\n        [3, 184, 80, 0, 1],\n        [3, 184, 80, 0, 1],\n        [3, 480, 112, 1, 1],\n        [3, 672, 112, 1, 1],\n        [5, 672, 160, 1, 2],\n        [5, 960, 160, 0, 1],\n        [5, 960, 160, 1, 1],\n        [5, 960, 160, 0, 1],\n        [5, 960, 160, 1, 1],\n    ]\n    return GhostNet(cfgs, **kwargs)\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_googlenet.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nimport warnings\nfrom typing import Optional, Tuple, List, Callable, Any\n\n__all__ = [\"GoogLeNet\", \"googlenet\"]\n\n\nclass GoogLeNet(nn.Module):\n    __constants__ = [\"aux_logits\", \"transform_input\"]\n\n    def __init__(\n        self,\n        num_classes: int = 1000,\n        aux_logits: bool = True,\n        transform_input: bool = False,\n        init_weights: Optional[bool] = None,\n        blocks: Optional[List[Callable[..., nn.Module]]] = None,\n        dropout: float = 0.2,\n        dropout_aux: float = 0.7,\n    ) -> None:\n        super().__init__()\n        if blocks is None:\n            blocks = [BasicConv2d, Inception, InceptionAux]\n        if init_weights is None:\n            warnings.warn(\n                \"The default weight initialization of GoogleNet will be changed in future releases of \"\n                \"torchvision. If you wish to keep the old behavior (which leads to long initialization times\"\n                \" due to scipy/scipy#11299), please set init_weights=True.\",\n                FutureWarning,\n            )\n            init_weights = True\n        if len(blocks) != 3:\n            raise ValueError(f\"blocks length should be 3 instead of {len(blocks)}\")\n        conv_block = blocks[0]\n        inception_block = blocks[1]\n        inception_aux_block = blocks[2]\n\n        self.aux_logits = aux_logits\n        self.transform_input = transform_input\n\n        self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3)\n        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)\n        self.conv2 = conv_block(64, 64, kernel_size=1)\n        self.conv3 = conv_block(64, 192, kernel_size=3, padding=1)\n        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)\n\n        self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)\n        self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)\n        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)\n\n        self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)\n        self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)\n        self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)\n        self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)\n        self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)\n        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n\n        self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)\n        self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)\n\n        if aux_logits:\n            self.aux1 = inception_aux_block(512, num_classes, dropout=dropout_aux)\n            self.aux2 = inception_aux_block(528, num_classes, dropout=dropout_aux)\n        else:\n            self.aux1 = None  # 
type: ignore[assignment]\n            self.aux2 = None  # type: ignore[assignment]\n\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.dropout = nn.Dropout(p=dropout)\n        self.fc = nn.Linear(1024, num_classes)\n\n        if init_weights:\n            for m in self.modules():\n                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):\n                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)\n                elif isinstance(m, nn.BatchNorm2d):\n                    nn.init.constant_(m.weight, 1)\n                    nn.init.constant_(m.bias, 0)\n\n    def _transform_input(self, x: Tensor) -> Tensor:\n        if self.transform_input:\n            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5\n            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5\n            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5\n            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)\n        return x\n\n    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:\n        # N x 3 x 224 x 224\n        x = self.conv1(x)\n        # N x 64 x 112 x 112\n        x = self.maxpool1(x)\n        # N x 64 x 56 x 56\n        x = self.conv2(x)\n        # N x 64 x 56 x 56\n        x = self.conv3(x)\n        # N x 192 x 56 x 56\n        x = self.maxpool2(x)\n\n        # N x 192 x 28 x 28\n        x = self.inception3a(x)\n        # N x 256 x 28 x 28\n        x = self.inception3b(x)\n        # N x 480 x 28 x 28\n        x = self.maxpool3(x)\n        # N x 480 x 14 x 14\n        x = self.inception4a(x)\n        # N x 512 x 14 x 14\n        aux1: Optional[Tensor] = None\n        if self.aux1 is not None:\n            if self.training:\n                aux1 = self.aux1(x)\n\n        x = self.inception4b(x)\n        # N x 512 x 14 x 14\n        x = self.inception4c(x)\n        # N x 512 x 14 x 14\n        x = self.inception4d(x)\n        # N x 528 x 14 x 14\n        aux2: Optional[Tensor] = None\n        if self.aux2 is not None:\n            if self.training:\n                aux2 = self.aux2(x)\n\n        x = self.inception4e(x)\n        # N x 832 x 14 x 14\n        x = self.maxpool4(x)\n        # N x 832 x 7 x 7\n        x = self.inception5a(x)\n        # N x 832 x 7 x 7\n        x = self.inception5b(x)\n        # N x 1024 x 7 x 7\n\n        x = self.avgpool(x)\n        # N x 1024 x 1 x 1\n        x = torch.flatten(x, 1)\n        # N x 1024\n        x = self.dropout(x)\n        x = self.fc(x)\n        # N x 1000 (num_classes)\n        return x, aux2, aux1\n\n    def forward(self, x: Tensor):\n        x = self._transform_input(x)\n        x, aux1, aux2 = self._forward(x)\n        return x\n\n\nclass Inception(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        ch1x1: int,\n        ch3x3red: int,\n        ch3x3: int,\n        ch5x5red: int,\n        ch5x5: int,\n        pool_proj: int,\n        conv_block: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super().__init__()\n        if conv_block is None:\n            conv_block = BasicConv2d\n        self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1)\n\n        self.branch2 = nn.Sequential(\n            conv_block(in_channels, ch3x3red, kernel_size=1),\n            conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1),\n        )\n\n        self.branch3 = nn.Sequential(\n            conv_block(in_channels, ch5x5red, kernel_size=1),\n            # Here, 
kernel_size=3 instead of kernel_size=5 is a known bug.\n            # Please see https://github.com/pytorch/vision/issues/906 for details.\n            conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1),\n        )\n\n        self.branch4 = nn.Sequential(\n            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),\n            conv_block(in_channels, pool_proj, kernel_size=1),\n        )\n\n    def _forward(self, x: Tensor) -> List[Tensor]:\n        branch1 = self.branch1(x)\n        branch2 = self.branch2(x)\n        branch3 = self.branch3(x)\n        branch4 = self.branch4(x)\n\n        outputs = [branch1, branch2, branch3, branch4]\n        return outputs\n\n    def forward(self, x: Tensor) -> Tensor:\n        outputs = self._forward(x)\n        return torch.cat(outputs, 1)\n\n\nclass InceptionAux(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        num_classes: int,\n        conv_block: Optional[Callable[..., nn.Module]] = None,\n        dropout: float = 0.7,\n    ) -> None:\n        super().__init__()\n        if conv_block is None:\n            conv_block = BasicConv2d\n        self.conv = conv_block(in_channels, 128, kernel_size=1)\n\n        self.fc1 = nn.Linear(2048, 1024)\n        self.fc2 = nn.Linear(1024, num_classes)\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, x: Tensor) -> Tensor:\n        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14\n        x = F.adaptive_avg_pool2d(x, (4, 4))\n        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4\n        x = self.conv(x)\n        # N x 128 x 4 x 4\n        x = torch.flatten(x, 1)\n        # N x 2048\n        x = F.relu(self.fc1(x), inplace=True)\n        # N x 1024\n        x = self.dropout(x)\n        # N x 1024\n        x = self.fc2(x)\n        # N x 1000 (num_classes)\n\n        return x\n\n\nclass BasicConv2d(nn.Module):\n    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:\n        super().__init__()\n        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)\n        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x = self.conv(x)\n        x = self.bn(x)\n        return F.relu(x, inplace=True)\n\n\ndef googlenet(progress: bool = True, **kwargs: Any) -> GoogLeNet:\n    r\"\"\"GoogLeNet (Inception v1) model architecture from\n    `\"Going Deeper with Convolutions\" <http://arxiv.org/abs/1409.4842>`_.\n    The required minimum input size of the model is 15x15.\n    Args:\n        weights (GoogLeNet_Weights, optional): The pretrained weights for the model\n        progress (bool): If True, displays a progress bar of the download to stderr\n        aux_logits (bool): If True, adds two auxiliary branches that can improve training.\n            Default: *False* when pretrained is True otherwise *True*\n        transform_input (bool): If True, preprocesses the input according to the method with which it\n            was trained on ImageNet. Default: True if ``weights=GoogLeNet_Weights.IMAGENET1K_V1``, else False.\n    \"\"\"\n    model = GoogLeNet(**kwargs)\n    return model\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_inception_v3.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, Tensor\n\nimport warnings\nfrom typing import Callable, Any, Optional, Tuple, List\n\n__all__ = [\"Inception3\", \"inception_v3\"]\n\n\nclass Inception3(nn.Module):\n    def __init__(\n        self,\n        num_classes: int = 1000,\n        aux_logits: bool = True,\n        transform_input: bool = False,\n        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,\n        init_weights: Optional[bool] = None,\n        dropout: float = 0.5,\n    ) -> None:\n        super().__init__()\n        if inception_blocks is None:\n            inception_blocks = [\n                BasicConv2d,\n                InceptionA,\n                InceptionB,\n                InceptionC,\n                InceptionD,\n                InceptionE,\n                InceptionAux,\n            ]\n        if init_weights is None:\n            warnings.warn(\n                \"The default weight initialization of inception_v3 will be changed in future releases of \"\n                \"torchvision. If you wish to keep the old behavior (which leads to long initialization times\"\n                \" due to scipy/scipy#11299), please set init_weights=True.\",\n                FutureWarning,\n            )\n            init_weights = True\n        if len(inception_blocks) != 7:\n            raise ValueError(\n                f\"lenght of inception_blocks should be 7 instead of {len(inception_blocks)}\"\n            )\n        conv_block = inception_blocks[0]\n        inception_a = inception_blocks[1]\n        inception_b = inception_blocks[2]\n        inception_c = inception_blocks[3]\n        inception_d = inception_blocks[4]\n        inception_e = inception_blocks[5]\n        inception_aux = inception_blocks[6]\n\n        self.aux_logits = aux_logits\n        self.transform_input = transform_input\n        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)\n        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)\n        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)\n        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)\n        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)\n        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)\n        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)\n        self.Mixed_5b = inception_a(192, pool_features=32)\n        self.Mixed_5c = inception_a(256, pool_features=64)\n        self.Mixed_5d = inception_a(288, pool_features=64)\n        self.Mixed_6a = inception_b(288)\n        self.Mixed_6b = inception_c(768, channels_7x7=128)\n        self.Mixed_6c = inception_c(768, channels_7x7=160)\n        self.Mixed_6d = inception_c(768, channels_7x7=160)\n        self.Mixed_6e = inception_c(768, channels_7x7=192)\n        self.AuxLogits: Optional[nn.Module] = None\n        if aux_logits:\n            self.AuxLogits = 
inception_aux(768, num_classes)\n        self.Mixed_7a = inception_d(768)\n        self.Mixed_7b = inception_e(1280)\n        self.Mixed_7c = inception_e(2048)\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.dropout = nn.Dropout(p=dropout)\n        self.fc = nn.Linear(2048, num_classes)\n        if init_weights:\n            for m in self.modules():\n                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):\n                    stddev = float(m.stddev) if hasattr(m, \"stddev\") else 0.1  # type: ignore\n                    torch.nn.init.trunc_normal_(\n                        m.weight, mean=0.0, std=stddev, a=-2, b=2\n                    )\n                elif isinstance(m, nn.BatchNorm2d):\n                    nn.init.constant_(m.weight, 1)\n                    nn.init.constant_(m.bias, 0)\n\n    def _transform_input(self, x: Tensor) -> Tensor:\n        if self.transform_input:\n            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5\n            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5\n            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5\n            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)\n        return x\n\n    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:\n        # N x 3 x 299 x 299\n        x = self.Conv2d_1a_3x3(x)\n        # N x 32 x 149 x 149\n        x = self.Conv2d_2a_3x3(x)\n        # N x 32 x 147 x 147\n        x = self.Conv2d_2b_3x3(x)\n        # N x 64 x 147 x 147\n        x = self.maxpool1(x)\n        # N x 64 x 73 x 73\n        x = self.Conv2d_3b_1x1(x)\n        # N x 80 x 73 x 73\n        x = self.Conv2d_4a_3x3(x)\n        # N x 192 x 71 x 71\n        x = self.maxpool2(x)\n        # N x 192 x 35 x 35\n        x = self.Mixed_5b(x)\n        # N x 256 x 35 x 35\n        x = self.Mixed_5c(x)\n        # N x 288 x 35 x 35\n        x = self.Mixed_5d(x)\n        # N x 288 x 35 x 35\n        x = self.Mixed_6a(x)\n        # N x 768 x 17 x 17\n        x = self.Mixed_6b(x)\n        # N x 768 x 17 x 17\n        x = self.Mixed_6c(x)\n        # N x 768 x 17 x 17\n        x = self.Mixed_6d(x)\n        # N x 768 x 17 x 17\n        x = self.Mixed_6e(x)\n        # N x 768 x 17 x 17\n        aux: Optional[Tensor] = None\n        if self.AuxLogits is not None:\n            if self.training:\n                aux = self.AuxLogits(x)\n        # N x 768 x 17 x 17\n        x = self.Mixed_7a(x)\n        # N x 1280 x 8 x 8\n        x = self.Mixed_7b(x)\n        # N x 2048 x 8 x 8\n        x = self.Mixed_7c(x)\n        # N x 2048 x 8 x 8\n        # Adaptive average pooling\n        x = self.avgpool(x)\n        # N x 2048 x 1 x 1\n        x = self.dropout(x)\n        # N x 2048 x 1 x 1\n        x = torch.flatten(x, 1)\n        # N x 2048\n        x = self.fc(x)\n        # N x 1000 (num_classes)\n        return x, aux\n\n    def forward(self, x: Tensor):\n        x = self._transform_input(x)\n        x, aux = self._forward(x)\n        return x\n\n\nclass InceptionA(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        pool_features: int,\n        conv_block: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super().__init__()\n        if conv_block is None:\n            conv_block = BasicConv2d\n        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)\n\n        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)\n        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, 
padding=2)\n\n        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)\n        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)\n        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)\n\n        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)\n\n    def _forward(self, x: Tensor) -> List[Tensor]:\n        branch1x1 = self.branch1x1(x)\n\n        branch5x5 = self.branch5x5_1(x)\n        branch5x5 = self.branch5x5_2(branch5x5)\n\n        branch3x3dbl = self.branch3x3dbl_1(x)\n        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)\n        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)\n\n        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)\n        branch_pool = self.branch_pool(branch_pool)\n\n        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]\n        return outputs\n\n    def forward(self, x: Tensor) -> Tensor:\n        outputs = self._forward(x)\n        return torch.cat(outputs, 1)\n\n\nclass InceptionB(nn.Module):\n    def __init__(\n        self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None\n    ) -> None:\n        super().__init__()\n        if conv_block is None:\n            conv_block = BasicConv2d\n        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)\n\n        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)\n        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)\n        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)\n\n    def _forward(self, x: Tensor) -> List[Tensor]:\n        branch3x3 = self.branch3x3(x)\n\n        branch3x3dbl = self.branch3x3dbl_1(x)\n        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)\n        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)\n\n        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)\n\n        outputs = [branch3x3, branch3x3dbl, branch_pool]\n        return outputs\n\n    def forward(self, x: Tensor) -> Tensor:\n        outputs = self._forward(x)\n        return torch.cat(outputs, 1)\n\n\nclass InceptionC(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        channels_7x7: int,\n        conv_block: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super().__init__()\n        if conv_block is None:\n            conv_block = BasicConv2d\n        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)\n\n        c7 = channels_7x7\n        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)\n        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))\n        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))\n\n        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)\n        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))\n        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))\n        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))\n        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))\n\n        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)\n\n    def _forward(self, x: Tensor) -> List[Tensor]:\n        branch1x1 = self.branch1x1(x)\n\n        branch7x7 = self.branch7x7_1(x)\n        branch7x7 = self.branch7x7_2(branch7x7)\n        branch7x7 = self.branch7x7_3(branch7x7)\n\n        branch7x7dbl = self.branch7x7dbl_1(x)\n        branch7x7dbl = 
self.branch7x7dbl_2(branch7x7dbl)\n        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)\n        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)\n        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)\n\n        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)\n        branch_pool = self.branch_pool(branch_pool)\n\n        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]\n        return outputs\n\n    def forward(self, x: Tensor) -> Tensor:\n        outputs = self._forward(x)\n        return torch.cat(outputs, 1)\n\n\nclass InceptionD(nn.Module):\n    def __init__(\n        self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None\n    ) -> None:\n        super().__init__()\n        if conv_block is None:\n            conv_block = BasicConv2d\n        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)\n        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)\n\n        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)\n        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))\n        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))\n        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)\n\n    def _forward(self, x: Tensor) -> List[Tensor]:\n        branch3x3 = self.branch3x3_1(x)\n        branch3x3 = self.branch3x3_2(branch3x3)\n\n        branch7x7x3 = self.branch7x7x3_1(x)\n        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)\n        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)\n        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)\n\n        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)\n        outputs = [branch3x3, branch7x7x3, branch_pool]\n        return outputs\n\n    def forward(self, x: Tensor) -> Tensor:\n        outputs = self._forward(x)\n        return torch.cat(outputs, 1)\n\n\nclass InceptionE(nn.Module):\n    def __init__(\n        self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None\n    ) -> None:\n        super().__init__()\n        if conv_block is None:\n            conv_block = BasicConv2d\n        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)\n\n        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)\n        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))\n        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))\n\n        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)\n        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)\n        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))\n        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))\n\n        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)\n\n    def _forward(self, x: Tensor) -> List[Tensor]:\n        branch1x1 = self.branch1x1(x)\n\n        branch3x3 = self.branch3x3_1(x)\n        branch3x3 = [\n            self.branch3x3_2a(branch3x3),\n            self.branch3x3_2b(branch3x3),\n        ]\n        branch3x3 = torch.cat(branch3x3, 1)\n\n        branch3x3dbl = self.branch3x3dbl_1(x)\n        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)\n        branch3x3dbl = [\n            self.branch3x3dbl_3a(branch3x3dbl),\n            self.branch3x3dbl_3b(branch3x3dbl),\n        ]\n        branch3x3dbl = torch.cat(branch3x3dbl, 1)\n\n        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)\n  
      branch_pool = self.branch_pool(branch_pool)\n\n        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]\n        return outputs\n\n    def forward(self, x: Tensor) -> Tensor:\n        outputs = self._forward(x)\n        return torch.cat(outputs, 1)\n\n\nclass InceptionAux(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        num_classes: int,\n        conv_block: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super().__init__()\n        if conv_block is None:\n            conv_block = BasicConv2d\n        self.conv0 = conv_block(in_channels, 128, kernel_size=1)\n        self.conv1 = conv_block(128, 768, kernel_size=5)\n        self.conv1.stddev = 0.01  # type: ignore[assignment]\n        self.fc = nn.Linear(768, num_classes)\n        self.fc.stddev = 0.001  # type: ignore[assignment]\n\n    def forward(self, x: Tensor) -> Tensor:\n        # N x 768 x 17 x 17\n        x = F.avg_pool2d(x, kernel_size=5, stride=3)\n        # N x 768 x 5 x 5\n        x = self.conv0(x)\n        # N x 128 x 5 x 5\n        x = self.conv1(x)\n        # N x 768 x 1 x 1\n        # Adaptive average pooling\n        x = F.adaptive_avg_pool2d(x, (1, 1))\n        # N x 768 x 1 x 1\n        x = torch.flatten(x, 1)\n        # N x 768\n        x = self.fc(x)\n        # N x 1000\n        return x\n\n\nclass BasicConv2d(nn.Module):\n    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:\n        super().__init__()\n        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)\n        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x = self.conv(x)\n        x = self.bn(x)\n        return F.relu(x, inplace=True)\n\n\ndef inception_v3(progress: bool = True, **kwargs: Any) -> Inception3:\n    r\"\"\"Inception v3 model architecture from\n    `\"Rethinking the Inception Architecture for Computer Vision\" <http://arxiv.org/abs/1512.00567>`_.\n    The required minimum input size of the model is 75x75.\n    .. note::\n        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of\n        N x 3 x 299 x 299, so ensure your images are sized accordingly.\n    Args:\n        weights (Inception_V3_Weights, optional): The pretrained weights for the model\n        progress (bool): If True, displays a progress bar of the download to stderr\n        aux_logits (bool): If True, add an auxiliary branch that can improve training.\n            Default: *True*\n        transform_input (bool): If True, preprocesses the input according to the method with which it\n            was trained on ImageNet. Default: True if ``weights=Inception_V3_Weights.IMAGENET1K_V1``, else False.\n    \"\"\"\n    model = Inception3(**kwargs)\n    return model\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_levit.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport itertools\nfrom timm.models.vision_transformer import trunc_normal_\n\nspecification = {\n    \"LeViT_128S\": {\n        \"C\": \"128_256_384\",\n        \"D\": 16,\n        \"N\": \"4_6_8\",\n        \"X\": \"2_3_4\",\n        \"drop_path\": 0,\n        \"weights\": \"https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth\",\n    }\n}\n\n__all__ = [\"LeViT_128S\"]\n\n\ndef LeViT_128S(num_classes=1000, distillation=False, pretrained=False, fuse=False):\n    return model_factory(\n        **specification[\"LeViT_128S\"],\n        num_classes=num_classes,\n        distillation=distillation,\n        pretrained=pretrained,\n        fuse=fuse\n    )\n\n\nFLOPS_COUNTER = 0\n\n\nclass Conv2d_BN(torch.nn.Sequential):\n    def __init__(\n        self,\n        a,\n        b,\n        ks=1,\n        stride=1,\n        pad=0,\n        dilation=1,\n        groups=1,\n        bn_weight_init=1,\n        resolution=-10000,\n    ):\n        super().__init__()\n        self.add_module(\n            \"c\", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)\n        )\n        bn = torch.nn.BatchNorm2d(b)\n        torch.nn.init.constant_(bn.weight, bn_weight_init)\n        torch.nn.init.constant_(bn.bias, 0)\n        self.add_module(\"bn\", bn)\n\n        global FLOPS_COUNTER\n        output_points = (\n            (resolution + 2 * pad - dilation * (ks - 1) - 1) // stride + 1\n        ) ** 2\n        FLOPS_COUNTER += a * b * output_points * (ks ** 2) // groups\n\n    @torch.no_grad()\n    def fuse(self):\n        c, bn = self._modules.values()\n        w = bn.weight / (bn.running_var + bn.eps) ** 0.5\n        w = c.weight * w[:, None, None, None]\n        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5\n        m = torch.nn.Conv2d(\n            w.size(1) * self.c.groups,\n            w.size(0),\n            w.shape[2:],\n            stride=self.c.stride,\n            padding=self.c.padding,\n            dilation=self.c.dilation,\n            groups=self.c.groups,\n        )\n        m.weight.data.copy_(w)\n        m.bias.data.copy_(b)\n        return m\n\n\nclass Linear_BN(torch.nn.Sequential):\n    def __init__(self, a, b, bn_weight_init=1, resolution=-100000):\n        super().__init__()\n        self.add_module(\"c\", torch.nn.Linear(a, b, bias=False))\n        bn = torch.nn.BatchNorm1d(b)\n        torch.nn.init.constant_(bn.weight, bn_weight_init)\n        torch.nn.init.constant_(bn.bias, 0)\n        self.add_module(\"bn\", bn)\n\n        global FLOPS_COUNTER\n        output_points = resolution ** 2\n        FLOPS_COUNTER += a * b * output_points\n\n    @torch.no_grad()\n    def fuse(self):\n        l, bn = self._modules.values()\n        w = bn.weight / (bn.running_var + bn.eps) ** 0.5\n        w = l.weight * w[:, None]\n        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5\n     
   m = torch.nn.Linear(w.size(1), w.size(0))\n        m.weight.data.copy_(w)\n        m.bias.data.copy_(b)\n        return m\n\n    def forward(self, x):\n        l, bn = self._modules.values()\n        x = l(x)\n        return bn(x.flatten(0, 1)).reshape_as(x)\n\n\nclass BN_Linear(torch.nn.Sequential):\n    def __init__(self, a, b, bias=True, std=0.02):\n        super().__init__()\n        self.add_module(\"bn\", torch.nn.BatchNorm1d(a))\n        l = torch.nn.Linear(a, b, bias=bias)\n        trunc_normal_(l.weight, std=std)\n        if bias:\n            torch.nn.init.constant_(l.bias, 0)\n        self.add_module(\"l\", l)\n        global FLOPS_COUNTER\n        FLOPS_COUNTER += a * b\n\n    @torch.no_grad()\n    def fuse(self):\n        bn, l = self._modules.values()\n        w = bn.weight / (bn.running_var + bn.eps) ** 0.5\n        b = (\n            bn.bias\n            - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5\n        )\n        w = l.weight * w[None, :]\n        if l.bias is None:\n            b = b @ self.l.weight.T\n        else:\n            b = (l.weight @ b[:, None]).view(-1) + self.l.bias\n        m = torch.nn.Linear(w.size(1), w.size(0))\n        m.weight.data.copy_(w)\n        m.bias.data.copy_(b)\n        return m\n\n\ndef b16(n, activation, resolution=224):\n    return torch.nn.Sequential(\n        Conv2d_BN(3, n // 8, 3, 2, 1, resolution=resolution),\n        activation(),\n        Conv2d_BN(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),\n        activation(),\n        Conv2d_BN(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),\n        activation(),\n        Conv2d_BN(n // 2, n, 3, 2, 1, resolution=resolution // 8),\n    )\n\n\nclass Residual(torch.nn.Module):\n    def __init__(self, m, drop):\n        super().__init__()\n        self.m = m\n        self.drop = drop\n\n    def forward(self, x):\n        if self.training and self.drop > 0:\n            return (\n                x\n                + self.m(x)\n                * torch.rand(x.size(0), 1, 1, device=x.device)\n                .ge_(self.drop)\n                .div(1 - self.drop)\n                .detach()\n            )\n        else:\n            return x + self.m(x)\n\n\nclass Attention(torch.nn.Module):\n    def __init__(\n        self, dim, key_dim, num_heads=8, attn_ratio=4, activation=None, resolution=14\n    ):\n        super().__init__()\n        self.num_heads = num_heads\n        self.scale = key_dim ** -0.5\n        self.key_dim = key_dim\n        self.nh_kd = nh_kd = key_dim * num_heads\n        self.d = int(attn_ratio * key_dim)\n        self.dh = int(attn_ratio * key_dim) * num_heads\n        self.attn_ratio = attn_ratio\n        h = self.dh + nh_kd * 2\n        self.qkv = Linear_BN(dim, h, resolution=resolution)\n        self.proj = torch.nn.Sequential(\n            activation(),\n            Linear_BN(self.dh, dim, bn_weight_init=0, resolution=resolution),\n        )\n\n        points = list(itertools.product(range(resolution), range(resolution)))\n        N = len(points)\n        attention_offsets = {}\n        idxs = []\n        for p1 in points:\n            for p2 in points:\n                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))\n                if offset not in attention_offsets:\n                    attention_offsets[offset] = len(attention_offsets)\n                idxs.append(attention_offsets[offset])\n        self.attention_biases = torch.nn.Parameter(\n            torch.zeros(num_heads, len(attention_offsets))\n        )\n        
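# map every (query, key) position pair to the index of its shared relative-offset bias\n        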
self.register_buffer(\"attention_bias_idxs\", torch.LongTensor(idxs).view(N, N))\n\n        global FLOPS_COUNTER\n        # queries * keys\n        FLOPS_COUNTER += num_heads * (resolution ** 4) * key_dim\n        # softmax\n        FLOPS_COUNTER += num_heads * (resolution ** 4)\n        # attention * v\n        FLOPS_COUNTER += num_heads * self.d * (resolution ** 4)\n\n    @torch.no_grad()\n    def train(self, mode=True):\n        super().train(mode)\n        if mode and hasattr(self, \"ab\"):\n            del self.ab\n        else:\n            self.ab = self.attention_biases[:, self.attention_bias_idxs]\n\n    def forward(self, x):  # x (B,N,C)\n        B, N, C = x.shape\n        qkv = self.qkv(x)\n        q, k, v = qkv.view(B, N, self.num_heads, -1).split(\n            [self.key_dim, self.key_dim, self.d], dim=3\n        )\n        q = q.permute(0, 2, 1, 3)\n        k = k.permute(0, 2, 1, 3)\n        v = v.permute(0, 2, 1, 3)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale + (\n            self.attention_biases[:, self.attention_bias_idxs]\n            if self.training\n            else self.ab\n        )\n        attn = attn.softmax(dim=-1)\n        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)\n        x = self.proj(x)\n        return x\n\n\nclass Subsample(torch.nn.Module):\n    def __init__(self, stride, resolution):\n        super().__init__()\n        self.stride = stride\n        self.resolution = resolution\n\n    def forward(self, x):\n        B, N, C = x.shape\n        x = x.view(B, self.resolution, self.resolution, C)[\n            :, :: self.stride, :: self.stride\n        ].reshape(B, -1, C)\n        return x\n\n\nclass AttentionSubsample(torch.nn.Module):\n    def __init__(\n        self,\n        in_dim,\n        out_dim,\n        key_dim,\n        num_heads=8,\n        attn_ratio=2,\n        activation=None,\n        stride=2,\n        resolution=14,\n        resolution_=7,\n    ):\n        super().__init__()\n        self.num_heads = num_heads\n        self.scale = key_dim ** -0.5\n        self.key_dim = key_dim\n        self.nh_kd = nh_kd = key_dim * num_heads\n        self.d = int(attn_ratio * key_dim)\n        self.dh = int(attn_ratio * key_dim) * self.num_heads\n        self.attn_ratio = attn_ratio\n        self.resolution_ = resolution_\n        self.resolution_2 = resolution_ ** 2\n        h = self.dh + nh_kd\n        self.kv = Linear_BN(in_dim, h, resolution=resolution)\n\n        self.q = torch.nn.Sequential(\n            Subsample(stride, resolution),\n            Linear_BN(in_dim, nh_kd, resolution=resolution_),\n        )\n        self.proj = torch.nn.Sequential(\n            activation(), Linear_BN(self.dh, out_dim, resolution=resolution_)\n        )\n\n        self.stride = stride\n        self.resolution = resolution\n        points = list(itertools.product(range(resolution), range(resolution)))\n        points_ = list(itertools.product(range(resolution_), range(resolution_)))\n        N = len(points)\n        N_ = len(points_)\n        attention_offsets = {}\n        idxs = []\n        for p1 in points_:\n            for p2 in points:\n                size = 1\n                offset = (\n                    abs(p1[0] * stride - p2[0] + (size - 1) / 2),\n                    abs(p1[1] * stride - p2[1] + (size - 1) / 2),\n                )\n                if offset not in attention_offsets:\n                    attention_offsets[offset] = len(attention_offsets)\n                idxs.append(attention_offsets[offset])\n        
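# one learnable bias per head for each distinct offset between a strided query and a full-resolution key\n        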
self.attention_biases = torch.nn.Parameter(\n            torch.zeros(num_heads, len(attention_offsets))\n        )\n        self.register_buffer(\"attention_bias_idxs\", torch.LongTensor(idxs).view(N_, N))\n\n        global FLOPS_COUNTER\n        # queries * keys\n        FLOPS_COUNTER += num_heads * (resolution ** 2) * (resolution_ ** 2) * key_dim\n        # softmax\n        FLOPS_COUNTER += num_heads * (resolution ** 2) * (resolution_ ** 2)\n        # attention * v\n        FLOPS_COUNTER += num_heads * (resolution ** 2) * (resolution_ ** 2) * self.d\n\n    @torch.no_grad()\n    def train(self, mode=True):\n        super().train(mode)\n        if mode and hasattr(self, \"ab\"):\n            del self.ab\n        else:\n            self.ab = self.attention_biases[:, self.attention_bias_idxs]\n\n    def forward(self, x):\n        B, N, C = x.shape\n        k, v = (\n            self.kv(x)\n            .view(B, N, self.num_heads, -1)\n            .split([self.key_dim, self.d], dim=3)\n        )\n        k = k.permute(0, 2, 1, 3)  # BHNC\n        v = v.permute(0, 2, 1, 3)  # BHNC\n        q = (\n            self.q(x)\n            .view(B, self.resolution_2, self.num_heads, self.key_dim)\n            .permute(0, 2, 1, 3)\n        )\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale + (\n            self.attention_biases[:, self.attention_bias_idxs]\n            if self.training\n            else self.ab\n        )\n        attn = attn.softmax(dim=-1)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)\n        x = self.proj(x)\n        return x\n\n\nclass LeViT(torch.nn.Module):\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\n    \"\"\"\n\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=16,\n        in_chans=3,\n        num_classes=1000,\n        embed_dim=[192],\n        key_dim=[64],\n        depth=[12],\n        num_heads=[3],\n        attn_ratio=[2],\n        mlp_ratio=[2],\n        hybrid_backbone=None,\n        down_ops=[],\n        attention_activation=torch.nn.Hardswish,\n        mlp_activation=torch.nn.Hardswish,\n        distillation=True,\n        drop_path=0,\n    ):\n        super().__init__()\n        global FLOPS_COUNTER\n\n        self.num_classes = num_classes\n        self.num_features = embed_dim[-1]\n        self.embed_dim = embed_dim\n        self.distillation = distillation\n\n        self.patch_embed = hybrid_backbone\n\n        self.blocks = []\n        down_ops.append([\"\"])\n        resolution = img_size // patch_size\n        for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(\n            zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)\n        ):\n            for _ in range(dpth):\n                self.blocks.append(\n                    Residual(\n                        Attention(\n                            ed,\n                            kd,\n                            nh,\n                            attn_ratio=ar,\n                            activation=attention_activation,\n                            resolution=resolution,\n                        ),\n                        drop_path,\n                    )\n                )\n                if mr > 0:\n                    h = int(ed * mr)\n                    self.blocks.append(\n                        Residual(\n                            torch.nn.Sequential(\n                                Linear_BN(ed, h, resolution=resolution),\n                                mlp_activation(),\n         
                       Linear_BN(\n                                    h, ed, bn_weight_init=0, resolution=resolution\n                                ),\n                            ),\n                            drop_path,\n                        )\n                    )\n            if do[0] == \"Subsample\":\n                # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)\n                resolution_ = (resolution - 1) // do[5] + 1\n                self.blocks.append(\n                    AttentionSubsample(\n                        *embed_dim[i : i + 2],\n                        key_dim=do[1],\n                        num_heads=do[2],\n                        attn_ratio=do[3],\n                        activation=attention_activation,\n                        stride=do[5],\n                        resolution=resolution,\n                        resolution_=resolution_\n                    )\n                )\n                resolution = resolution_\n                if do[4] > 0:  # mlp_ratio\n                    h = int(embed_dim[i + 1] * do[4])\n                    self.blocks.append(\n                        Residual(\n                            torch.nn.Sequential(\n                                Linear_BN(embed_dim[i + 1], h, resolution=resolution),\n                                mlp_activation(),\n                                Linear_BN(\n                                    h,\n                                    embed_dim[i + 1],\n                                    bn_weight_init=0,\n                                    resolution=resolution,\n                                ),\n                            ),\n                            drop_path,\n                        )\n                    )\n        self.blocks = torch.nn.Sequential(*self.blocks)\n\n        # Classifier head\n        self.head = (\n            BN_Linear(embed_dim[-1], num_classes)\n            if num_classes > 0\n            else torch.nn.Identity()\n        )\n        if distillation:\n            self.head_dist = (\n                BN_Linear(embed_dim[-1], num_classes)\n                if num_classes > 0\n                else torch.nn.Identity()\n            )\n\n        self.FLOPS = FLOPS_COUNTER\n        FLOPS_COUNTER = 0\n\n    def no_weight_decay(self):\n        return {x for x in self.state_dict().keys() if \"attention_biases\" in x}\n\n    def forward(self, x):\n        x = self.patch_embed(x)\n        x = x.flatten(2).transpose(1, 2)\n        x = self.blocks(x)\n        x = x.mean(1)\n        if self.distillation:\n            x = self.head(x), self.head_dist(x)\n            if not self.training:\n                x = (x[0] + x[1]) / 2\n        else:\n            x = self.head(x)\n        return x\n\n\ndef model_factory(\n    C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse\n):\n    embed_dim = [int(x) for x in C.split(\"_\")]\n    num_heads = [int(x) for x in N.split(\"_\")]\n    depth = [int(x) for x in X.split(\"_\")]\n    act = torch.nn.Hardswish\n    model = LeViT(\n        patch_size=16,\n        embed_dim=embed_dim,\n        num_heads=num_heads,\n        key_dim=[D] * 3,\n        depth=depth,\n        attn_ratio=[2, 2, 2],\n        mlp_ratio=[2, 2, 2],\n        down_ops=[\n            # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)\n            [\"Subsample\", D, embed_dim[0] // D, 4, 2, 2],\n            [\"Subsample\", D, embed_dim[1] // D, 4, 2, 2],\n        ],\n        attention_activation=act,\n        
mlp_activation=act,\n        hybrid_backbone=b16(embed_dim[0], activation=act),\n        num_classes=num_classes,\n        drop_path=drop_path,\n        distillation=distillation,\n    )\n    return model\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_mnasnet.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport torch.nn as nn\nfrom torch import Tensor\n\nimport warnings\nfrom typing import Any, Dict, List\n\n__all__ = [\n    \"MNASNet\",\n    \"mnasnet1_0\",\n]\n\n\n# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is\n# 1.0 - tensorflow.\n_BN_MOMENTUM = 1 - 0.9997\n\n\nclass _InvertedResidual(nn.Module):\n    def __init__(\n        self,\n        in_ch: int,\n        out_ch: int,\n        kernel_size: int,\n        stride: int,\n        expansion_factor: int,\n        bn_momentum: float = 0.1,\n    ) -> None:\n        super().__init__()\n        if stride not in [1, 2]:\n            raise ValueError(f\"stride should be 1 or 2 instead of {stride}\")\n        if kernel_size not in [3, 5]:\n            raise ValueError(f\"kernel_size should be 3 or 5 instead of {kernel_size}\")\n        mid_ch = in_ch * expansion_factor\n        self.apply_residual = in_ch == out_ch and stride == 1\n        self.layers = nn.Sequential(\n            # Pointwise\n            nn.Conv2d(in_ch, mid_ch, 1, bias=False),\n            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),\n            nn.ReLU(inplace=True),\n            # Depthwise\n            nn.Conv2d(\n                mid_ch,\n                mid_ch,\n                kernel_size,\n                padding=kernel_size // 2,\n                stride=stride,\n                groups=mid_ch,\n                bias=False,\n            ),\n            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),\n            nn.ReLU(inplace=True),\n            # Linear pointwise. Note that there's no activation.\n            nn.Conv2d(mid_ch, out_ch, 1, bias=False),\n            nn.BatchNorm2d(out_ch, momentum=bn_momentum),\n        )\n\n    def forward(self, input: Tensor) -> Tensor:\n        if self.apply_residual:\n            return self.layers(input) + input\n        else:\n            return self.layers(input)\n\n\ndef _stack(\n    in_ch: int,\n    out_ch: int,\n    kernel_size: int,\n    stride: int,\n    exp_factor: int,\n    repeats: int,\n    bn_momentum: float,\n) -> nn.Sequential:\n    \"\"\"Creates a stack of inverted residuals.\"\"\"\n    if repeats < 1:\n        raise ValueError(f\"repeats should be >= 1, instead got {repeats}\")\n    # First one has no skip, because feature map size changes.\n    first = _InvertedResidual(\n        in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum\n    )\n    remaining = []\n    for _ in range(1, repeats):\n        remaining.append(\n            _InvertedResidual(\n                out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum\n            )\n        )\n    return nn.Sequential(first, *remaining)\n\n\ndef _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:\n    \"\"\"Asymmetric rounding to make `val` divisible by `divisor`. 
With default\n    bias, will round up, unless the number is no more than 10% greater than the\n    smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88.\"\"\"\n    if not 0.0 < round_up_bias < 1.0:\n        raise ValueError(\n            f\"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}\"\n        )\n    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)\n    return new_val if new_val >= round_up_bias * val else new_val + divisor\n\n\ndef _get_depths(alpha: float) -> List[int]:\n    \"\"\"Scales tensor depths as in reference MobileNet code, prefers rounding up\n    rather than down.\"\"\"\n    depths = [32, 16, 24, 40, 80, 96, 192, 320]\n    return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]\n\n\nclass MNASNet(torch.nn.Module):\n    \"\"\"MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This\n    implements the B1 variant of the model.\n    >>> model = MNASNet(1.0, num_classes=1000)\n    >>> x = torch.rand(1, 3, 224, 224)\n    >>> y = model(x)\n    >>> y.dim()\n    2\n    >>> y.nelement()\n    1000\n    \"\"\"\n\n    # Version 2 adds depth scaling in the initial stages of the network.\n    _version = 2\n\n    def __init__(\n        self, alpha: float, num_classes: int = 1000, dropout: float = 0.2\n    ) -> None:\n        super().__init__()\n        if alpha <= 0.0:\n            raise ValueError(f\"alpha should be greater than 0.0 instead of {alpha}\")\n        self.alpha = alpha\n        self.num_classes = num_classes\n        depths = _get_depths(alpha)\n        layers = [\n            # First layer: regular conv.\n            nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),\n            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),\n            nn.ReLU(inplace=True),\n            # Depthwise separable, no skip.\n            nn.Conv2d(\n                depths[0],\n                depths[0],\n                3,\n                padding=1,\n                stride=1,\n                groups=depths[0],\n                bias=False,\n            ),\n            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),\n            nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),\n            # MNASNet blocks: stacks of inverted residuals.\n            _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),\n            _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),\n            _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),\n            _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),\n            _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),\n            _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),\n            # Final mapping to classifier input.\n            nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),\n            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),\n            nn.ReLU(inplace=True),\n        ]\n        self.layers = nn.Sequential(*layers)\n        self.classifier = nn.Sequential(\n            nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes)\n        )\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n                if m.bias is not None:\n                    nn.init.zeros_(m.bias)\n            elif isinstance(m, nn.BatchNorm2d):\n                
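# reset BatchNorm affine parameters to the identity transform (unit scale, zero shift)\n                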
nn.init.ones_(m.weight)\n                nn.init.zeros_(m.bias)\n            elif isinstance(m, nn.Linear):\n                nn.init.kaiming_uniform_(\n                    m.weight, mode=\"fan_out\", nonlinearity=\"sigmoid\"\n                )\n                nn.init.zeros_(m.bias)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x = self.layers(x)\n        # Equivalent to global avgpool and removing H and W dimensions.\n        x = x.mean([2, 3])\n        return self.classifier(x)\n\n\ndef _mnasnet(alpha: float, progress: bool, **kwargs: Any) -> MNASNet:\n    model = MNASNet(alpha, **kwargs)\n    return model\n\n\ndef mnasnet1_0(progress: bool = True, **kwargs: Any) -> MNASNet:\n    r\"\"\"MNASNet with depth multiplier of 1.0 from\n    `\"MnasNet: Platform-Aware Neural Architecture Search for Mobile\"\n    <https://arxiv.org/pdf/1807.11626.pdf>`_.\n    Args:\n        weights (MNASNet1_0_Weights, optional): The pretrained weights for the model\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _mnasnet(1.0, progress, **kwargs)\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_poolformer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport torch.nn as nn\nfrom timm.models.layers import DropPath, trunc_normal_\nfrom timm.models.layers.helpers import to_2tuple\n\nimport os\nimport copy\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"\n    Patch Embedding that is implemented by a layer of conv. \n    Input: tensor in shape [B, C, H, W]\n    Output: tensor in shape [B, C, H/stride, W/stride]\n    \"\"\"\n\n    def __init__(\n        self,\n        patch_size=16,\n        stride=16,\n        padding=0,\n        in_chans=3,\n        embed_dim=768,\n        norm_layer=None,\n    ):\n        super().__init__()\n        patch_size = to_2tuple(patch_size)\n        stride = to_2tuple(stride)\n        padding = to_2tuple(padding)\n        self.proj = nn.Conv2d(\n            in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding\n        )\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n    def forward(self, x):\n        x = self.proj(x)\n        x = self.norm(x)\n        return x\n\n\nclass LayerNormChannel(nn.Module):\n    \"\"\"\n    LayerNorm only for Channel Dimension.\n    Input: tensor in shape [B, C, H, W]\n    \"\"\"\n\n    def __init__(self, num_channels, eps=1e-05):\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(num_channels))\n        self.bias = nn.Parameter(torch.zeros(num_channels))\n        self.eps = eps\n\n    def forward(self, x):\n        u = x.mean(1, keepdim=True)\n        s = (x - u).pow(2).mean(1, keepdim=True)\n        x = (x - u) / torch.sqrt(s + self.eps)\n        x = self.weight.unsqueeze(-1).unsqueeze(-1) * x + self.bias.unsqueeze(\n            -1\n        ).unsqueeze(-1)\n        return x\n\n\nclass GroupNorm(nn.GroupNorm):\n    \"\"\"\n    Group Normalization with 1 group.\n    Input: tensor in shape [B, C, H, W]\n    \"\"\"\n\n    def __init__(self, num_channels, **kwargs):\n        super().__init__(1, num_channels, **kwargs)\n\n\nclass Pooling(nn.Module):\n    \"\"\"\n    Implementation of pooling for PoolFormer\n    --pool_size: pooling size\n    \"\"\"\n\n    def __init__(self, pool_size=3):\n        super().__init__()\n        self.pool = nn.AvgPool2d(\n            pool_size, stride=1, padding=pool_size // 2, count_include_pad=False\n        )\n\n    def forward(self, x):\n        return self.pool(x) - x\n\n\nclass Mlp(nn.Module):\n    \"\"\"\n    Implementation of MLP with 1*1 convolutions.\n    Input: tensor with shape [B, C, H, W]\n    \"\"\"\n\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        act_layer=nn.GELU,\n        drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)\n        self.act = act_layer()\n        self.fc2 = 
nn.Conv2d(hidden_features, out_features, 1)\n        self.drop = nn.Dropout(drop)\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Conv2d):\n            trunc_normal_(m.weight, std=0.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass PoolFormerBlock(nn.Module):\n    \"\"\"\n    Implementation of one PoolFormer block.\n    --dim: embedding dim\n    --pool_size: pooling size\n    --mlp_ratio: mlp expansion ratio\n    --act_layer: activation\n    --norm_layer: normalization\n    --drop: dropout rate\n    --drop path: Stochastic Depth, \n        refer to https://arxiv.org/abs/1603.09382\n    --use_layer_scale, --layer_scale_init_value: LayerScale, \n        refer to https://arxiv.org/abs/2103.17239\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        pool_size=3,\n        mlp_ratio=4.0,\n        act_layer=nn.GELU,\n        norm_layer=GroupNorm,\n        drop=0.0,\n        drop_path=0.0,\n        use_layer_scale=True,\n        layer_scale_init_value=1e-5,\n    ):\n\n        super().__init__()\n\n        self.norm1 = norm_layer(dim)\n        self.token_mixer = Pooling(pool_size=pool_size)\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(\n            in_features=dim,\n            hidden_features=mlp_hidden_dim,\n            act_layer=act_layer,\n            drop=drop,\n        )\n\n        # The following two techniques are useful to train deep PoolFormers.\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.use_layer_scale = use_layer_scale\n        if use_layer_scale:\n            self.layer_scale_1 = nn.Parameter(\n                layer_scale_init_value * torch.ones((dim)), requires_grad=True\n            )\n            self.layer_scale_2 = nn.Parameter(\n                layer_scale_init_value * torch.ones((dim)), requires_grad=True\n            )\n\n    def forward(self, x):\n        if self.use_layer_scale:\n            x = x + self.drop_path(\n                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)\n                * self.token_mixer(self.norm1(x))\n            )\n            x = x + self.drop_path(\n                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x))\n            )\n        else:\n            x = x + self.drop_path(self.token_mixer(self.norm1(x)))\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n\n\ndef basic_blocks(\n    dim,\n    index,\n    layers,\n    pool_size=3,\n    mlp_ratio=4.0,\n    act_layer=nn.GELU,\n    norm_layer=GroupNorm,\n    drop_rate=0.0,\n    drop_path_rate=0.0,\n    use_layer_scale=True,\n    layer_scale_init_value=1e-5,\n):\n    \"\"\"\n    generate PoolFormer blocks for a stage\n    return: PoolFormer blocks \n    \"\"\"\n    blocks = []\n    for block_idx in range(layers[index]):\n        block_dpr = (\n            drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)\n        )\n        blocks.append(\n            PoolFormerBlock(\n                dim,\n                pool_size=pool_size,\n                mlp_ratio=mlp_ratio,\n                act_layer=act_layer,\n                norm_layer=norm_layer,\n                drop=drop_rate,\n                drop_path=block_dpr,\n        
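        # block_dpr above implements stochastic depth: the drop-path rate grows linearly with the global block index\n        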
        use_layer_scale=use_layer_scale,\n                layer_scale_init_value=layer_scale_init_value,\n            )\n        )\n    blocks = nn.Sequential(*blocks)\n\n    return blocks\n\n\nclass PoolFormer(nn.Module):\n    \"\"\"\n    PoolFormer, the main class of our model\n    --layers: [x,x,x,x], number of blocks for the 4 stages\n    --embed_dims, --mlp_ratios, --pool_size: the embedding dims, mlp ratios and \n        pooling size for the 4 stages\n    --downsamples: flags to apply downsampling or not\n    --norm_layer, --act_layer: define the types of normalization and activation\n    --num_classes: number of classes for the image classification\n    --in_patch_size, --in_stride, --in_pad: specify the patch embedding\n        for the input image\n    --down_patch_size --down_stride --down_pad: \n        specify the downsample (patch embed.)\n    --fork_feat: whether output features of the 4 stages, for dense prediction\n    --init_cfg, --pretrained: \n        for mmdetection and mmsegmentation to load pretrained weights\n    \"\"\"\n\n    def __init__(\n        self,\n        layers,\n        embed_dims=None,\n        mlp_ratios=None,\n        downsamples=None,\n        pool_size=3,\n        norm_layer=GroupNorm,\n        act_layer=nn.GELU,\n        num_classes=1000,\n        in_patch_size=7,\n        in_stride=4,\n        in_pad=2,\n        down_patch_size=3,\n        down_stride=2,\n        down_pad=1,\n        drop_rate=0.0,\n        drop_path_rate=0.0,\n        use_layer_scale=True,\n        layer_scale_init_value=1e-5,\n        fork_feat=False,\n        init_cfg=None,\n        pretrained=None,\n        **kwargs,\n    ):\n\n        super().__init__()\n\n        if not fork_feat:\n            self.num_classes = num_classes\n        self.fork_feat = fork_feat\n\n        self.patch_embed = PatchEmbed(\n            patch_size=in_patch_size,\n            stride=in_stride,\n            padding=in_pad,\n            in_chans=3,\n            embed_dim=embed_dims[0],\n        )\n\n        # set the main block in network\n        network = []\n        for i in range(len(layers)):\n            stage = basic_blocks(\n                embed_dims[i],\n                i,\n                layers,\n                pool_size=pool_size,\n                mlp_ratio=mlp_ratios[i],\n                act_layer=act_layer,\n                norm_layer=norm_layer,\n                drop_rate=drop_rate,\n                drop_path_rate=drop_path_rate,\n                use_layer_scale=use_layer_scale,\n                layer_scale_init_value=layer_scale_init_value,\n            )\n            network.append(stage)\n            if i >= len(layers) - 1:\n                break\n            if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:\n                # downsampling between two stages\n                network.append(\n                    PatchEmbed(\n                        patch_size=down_patch_size,\n                        stride=down_stride,\n                        padding=down_pad,\n                        in_chans=embed_dims[i],\n                        embed_dim=embed_dims[i + 1],\n                    )\n                )\n\n        self.network = nn.ModuleList(network)\n\n        if self.fork_feat:\n            # add a norm layer for each output\n            self.out_indices = [0, 2, 4, 6]\n            for i_emb, i_layer in enumerate(self.out_indices):\n                if i_emb == 0 and os.environ.get(\"FORK_LAST3\", None):\n                    # TODO: more elegant way\n                    
\"\"\"For RetinaNet, \`start_level=1\`. The first norm layer will not be used.\n                    cmd: \`FORK_LAST3=1 python -m torch.distributed.launch ...\`\n                    \"\"\"\n                    layer = nn.Identity()\n                else:\n                    layer = norm_layer(embed_dims[i_emb])\n                layer_name = f\"norm{i_layer}\"\n                self.add_module(layer_name, layer)\n        else:\n            # Classifier head\n            self.norm = norm_layer(embed_dims[-1])\n            self.head = (\n                nn.Linear(embed_dims[-1], num_classes)\n                if num_classes > 0\n                else nn.Identity()\n            )\n\n        self.apply(self.cls_init_weights)\n\n        self.init_cfg = copy.deepcopy(init_cfg)\n        # load pre-trained model\n        if self.fork_feat and (self.init_cfg is not None or pretrained is not None):\n            self.init_weights()\n\n    # init for classification\n    def cls_init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=0.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes):\n        self.num_classes = num_classes\n        self.head = (\n            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n        )\n\n    def forward_embeddings(self, x):\n        x = self.patch_embed(x)\n        return x\n\n    def forward_tokens(self, x):\n        outs = []\n        for idx, block in enumerate(self.network):\n            x = block(x)\n            if self.fork_feat and idx in self.out_indices:\n                norm_layer = getattr(self, f\"norm{idx}\")\n                x_out = norm_layer(x)\n                outs.append(x_out)\n        if self.fork_feat:\n            # output the features of four stages for dense prediction\n            return outs\n        # output only the features of last layer for image classification\n        return x\n\n    def forward(self, x):\n        # input embedding\n        x = self.forward_embeddings(x)\n        # through backbone\n        x = self.forward_tokens(x)\n        if self.fork_feat:\n            # output features of four stages for dense prediction\n            return x\n        x = self.norm(x)\n        cls_out = self.head(x.mean([-2, -1]))\n        # for image classification\n        return cls_out\n\n\ndef poolformer_s12(pretrained=False, **kwargs):\n    \"\"\"\n    PoolFormer-S12 model, Params: 12M\n    --layers: [x,x,x,x], numbers of layers for the four stages\n    --embed_dims, --mlp_ratios: \n        embedding dims and mlp ratios for the four stages\n    --downsamples: flags to apply downsampling or not in four blocks\n    \"\"\"\n    layers = [2, 2, 6, 2]\n    embed_dims = [64, 128, 320, 512]\n    mlp_ratios = [4, 4, 4, 4]\n    downsamples = [True, True, True, True]\n    model = PoolFormer(\n        layers,\n        embed_dims=embed_dims,\n        mlp_ratios=mlp_ratios,\n        downsamples=downsamples,\n        **kwargs,\n    )\n    return model\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_pvt.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\nfrom functools import partial\n\n__all__ = [\"pvt_tiny\"]\n\n\nclass Mlp(nn.Module):\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        act_layer=nn.GELU,\n        drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n        self,\n        dim,\n        num_heads=8,\n        qkv_bias=False,\n        qk_scale=None,\n        attn_drop=0.0,\n        proj_drop=0.0,\n        sr_ratio=1,\n    ):\n        super().__init__()\n        assert (\n            dim % num_heads == 0\n        ), f\"dim {dim} should be divided by num_heads {num_heads}.\"\n\n        self.dim = dim\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.q = nn.Linear(dim, dim, bias=qkv_bias)\n        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        self.sr_ratio = sr_ratio\n        if sr_ratio > 1:\n            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)\n            self.norm = nn.LayerNorm(dim)\n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n        q = (\n            self.q(x)\n            .reshape(B, N, self.num_heads, C // self.num_heads)\n            .permute(0, 2, 1, 3)\n        )\n\n        if self.sr_ratio > 1:\n            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)\n            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)\n            x_ = self.norm(x_)\n            kv = (\n                self.kv(x_)\n                .reshape(B, -1, 2, self.num_heads, C // self.num_heads)\n                .permute(2, 0, 3, 1, 4)\n            )\n        else:\n            kv = (\n                self.kv(x)\n                .reshape(B, -1, 2, self.num_heads, C // self.num_heads)\n                .permute(2, 0, 3, 1, 4)\n            )\n        k, v = kv[0], kv[1]\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n\n        
return x\n\n\nclass Block(nn.Module):\n    def __init__(\n        self,\n        dim,\n        num_heads,\n        mlp_ratio=4.0,\n        qkv_bias=False,\n        qk_scale=None,\n        drop=0.0,\n        attn_drop=0.0,\n        drop_path=0.0,\n        act_layer=nn.GELU,\n        norm_layer=nn.LayerNorm,\n        sr_ratio=1,\n    ):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim,\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            qk_scale=qk_scale,\n            attn_drop=attn_drop,\n            proj_drop=drop,\n            sr_ratio=sr_ratio,\n        )\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(\n            in_features=dim,\n            hidden_features=mlp_hidden_dim,\n            act_layer=act_layer,\n            drop=drop,\n        )\n\n    def forward(self, x, H, W):\n        x = x + self.drop_path(self.attn(self.norm1(x), H, W))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n\n        self.img_size = img_size\n        self.patch_size = patch_size\n        # assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \\\n        #     f\"img_size {img_size} should be divided by patch_size {patch_size}.\"\n        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]\n        self.num_patches = self.H * self.W\n        self.proj = nn.Conv2d(\n            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size\n        )\n        self.norm = nn.LayerNorm(embed_dim)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n\n        x = self.proj(x).flatten(2).transpose(1, 2)\n        x = self.norm(x)\n        H, W = H // self.patch_size[0], W // self.patch_size[1]\n\n        return x, (H, W)\n\n\nclass PyramidVisionTransformer(nn.Module):\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=16,\n        in_chans=3,\n        num_classes=1000,\n        embed_dims=[64, 128, 256, 512],\n        num_heads=[1, 2, 4, 8],\n        mlp_ratios=[4, 4, 4, 4],\n        qkv_bias=False,\n        qk_scale=None,\n        drop_rate=0.0,\n        attn_drop_rate=0.0,\n        drop_path_rate=0.0,\n        norm_layer=nn.LayerNorm,\n        depths=[3, 4, 6, 3],\n        sr_ratios=[8, 4, 2, 1],\n        num_stages=4,\n    ):\n        super().__init__()\n        self.num_classes = num_classes\n        self.depths = depths\n        self.num_stages = num_stages\n\n        dpr = [\n            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))\n        ]  # stochastic depth decay rule\n        cur = 0\n\n        for i in range(num_stages):\n            patch_embed = PatchEmbed(\n                img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),\n                patch_size=patch_size if i == 0 else 2,\n                in_chans=in_chans if i == 0 else embed_dims[i - 1],\n                embed_dim=embed_dims[i],\n            )\n            num_patches = (\n                
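# the final stage reserves one extra position for the prepended cls token\n                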
patch_embed.num_patches\n                if i != num_stages - 1\n                else patch_embed.num_patches + 1\n            )\n            pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims[i]))\n            pos_drop = nn.Dropout(p=drop_rate)\n\n            block = nn.ModuleList(\n                [\n                    Block(\n                        dim=embed_dims[i],\n                        num_heads=num_heads[i],\n                        mlp_ratio=mlp_ratios[i],\n                        qkv_bias=qkv_bias,\n                        qk_scale=qk_scale,\n                        drop=drop_rate,\n                        attn_drop=attn_drop_rate,\n                        drop_path=dpr[cur + j],\n                        norm_layer=norm_layer,\n                        sr_ratio=sr_ratios[i],\n                    )\n                    for j in range(depths[i])\n                ]\n            )\n            cur += depths[i]\n\n            setattr(self, f\"patch_embed{i + 1}\", patch_embed)\n            setattr(self, f\"pos_embed{i + 1}\", pos_embed)\n            setattr(self, f\"pos_drop{i + 1}\", pos_drop)\n            setattr(self, f\"block{i + 1}\", block)\n\n        self.norm = norm_layer(embed_dims[3])\n\n        # cls_token\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))\n\n        # classification head\n        self.head = (\n            nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()\n        )\n\n        # init weights\n        for i in range(num_stages):\n            pos_embed = getattr(self, f\"pos_embed{i + 1}\")\n            trunc_normal_(pos_embed, std=0.02)\n        trunc_normal_(self.cls_token, std=0.02)\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=0.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def no_weight_decay(self):\n        # return {'pos_embed', 'cls_token'} # has pos_embed may be better\n        return {\"cls_token\"}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=\"\"):\n        self.num_classes = num_classes\n        self.head = (\n            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n        )\n\n    def _get_pos_embed(self, pos_embed, patch_embed, H, W):\n        if H * W == self.patch_embed1.num_patches:\n            return pos_embed\n        else:\n            return (\n                F.interpolate(\n                    pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(\n                        0, 3, 1, 2\n                    ),\n                    size=(H, W),\n                    mode=\"bilinear\",\n                )\n                .reshape(1, -1, H * W)\n                .permute(0, 2, 1)\n            )\n\n    def forward_features(self, x):\n        B = x.shape[0]\n\n        for i in range(self.num_stages):\n            patch_embed = getattr(self, f\"patch_embed{i + 1}\")\n            pos_embed = getattr(self, f\"pos_embed{i + 1}\")\n            pos_drop = getattr(self, f\"pos_drop{i + 1}\")\n            block = getattr(self, f\"block{i + 1}\")\n            x, (H, W) = patch_embed(x)\n\n            if i == self.num_stages - 1:\n                cls_tokens = 
self.cls_token.expand(B, -1, -1)\n                x = torch.cat((cls_tokens, x), dim=1)\n                pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W)\n                pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1)\n            else:\n                pos_embed = self._get_pos_embed(pos_embed, patch_embed, H, W)\n\n            x = pos_drop(x + pos_embed)\n            for blk in block:\n                x = blk(x, H, W)\n            if i != self.num_stages - 1:\n                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n\n        x = self.norm(x)\n\n        return x[:, 0]\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n\n        return x\n\n\ndef pvt_tiny(pretrained=False, **kwargs):\n    model = PyramidVisionTransformer(\n        patch_size=4,\n        embed_dims=[64, 128, 320, 512],\n        num_heads=[1, 2, 5, 8],\n        mlp_ratios=[8, 8, 4, 4],\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        depths=[2, 2, 2, 2],\n        sr_ratios=[8, 4, 2, 1],\n        **kwargs,\n    )\n    return model\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_res2net.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nimport math\n\n__all__ = [\"Res2Net\", \"res2net50\"]\n\n\nclass Bottle2neck(nn.Module):\n    expansion = 4\n\n    def __init__(\n        self,\n        inplanes,\n        planes,\n        stride=1,\n        downsample=None,\n        baseWidth=26,\n        scale=4,\n        stype=\"normal\",\n    ):\n        \"\"\" Constructor\n        Args:\n            inplanes: input channel dimensionality\n            planes: output channel dimensionality\n            stride: conv stride. Replaces pooling layer.\n            downsample: None when stride = 1\n            baseWidth: basic width of conv3x3\n            scale: number of scale.\n            type: 'normal': normal set. 'stage': first block of a new stage.\n        \"\"\"\n        super(Bottle2neck, self).__init__()\n\n        width = int(math.floor(planes * (baseWidth / 64.0)))\n        self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(width * scale)\n\n        if scale == 1:\n            self.nums = 1\n        else:\n            self.nums = scale - 1\n        if stype == \"stage\":\n            self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)\n        convs = []\n        bns = []\n        for i in range(self.nums):\n            convs.append(\n                nn.Conv2d(\n                    width, width, kernel_size=3, stride=stride, padding=1, bias=False\n                )\n            )\n            bns.append(nn.BatchNorm2d(width))\n        self.convs = nn.ModuleList(convs)\n        self.bns = nn.ModuleList(bns)\n\n        self.conv3 = nn.Conv2d(\n            width * scale, planes * self.expansion, kernel_size=1, bias=False\n        )\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stype = stype\n        self.scale = scale\n        self.width = width\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        spx = torch.split(out, self.width, 1)\n        for i in range(self.nums):\n            if i == 0 or self.stype == \"stage\":\n                sp = spx[i]\n            else:\n                sp = sp + spx[i]\n            sp = self.convs[i](sp)\n            sp = self.relu(self.bns[i](sp))\n            if i == 0:\n                out = sp\n            else:\n                out = torch.cat((out, sp), 1)\n        if self.scale != 1 and self.stype == \"normal\":\n            out = torch.cat((out, spx[self.nums]), 1)\n        elif self.scale != 1 and self.stype == \"stage\":\n            out = torch.cat((out, self.pool(spx[self.nums])), 1)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = 
self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass Res2Net(nn.Module):\n    def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000):\n        self.inplanes = 64\n        super(Res2Net, self).__init__()\n        self.baseWidth = baseWidth\n        self.scale = scale\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)\n        self.bn1 = nn.BatchNorm2d(64)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n        self.avgpool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(\n                    self.inplanes,\n                    planes * block.expansion,\n                    kernel_size=1,\n                    stride=stride,\n                    bias=False,\n                ),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(\n            block(\n                self.inplanes,\n                planes,\n                stride,\n                downsample=downsample,\n                stype=\"stage\",\n                baseWidth=self.baseWidth,\n                scale=self.scale,\n            )\n        )\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(\n                block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale)\n            )\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = x.view(x.size(0), -1)\n        x = self.fc(x)\n\n        return x\n\n\ndef res2net50(pretrained=False, **kwargs):\n    \"\"\"Constructs a Res2Net-50 model.\n    Res2Net-50 refers to the Res2Net-50_26w_4s.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs)\n    return model\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_resmlp.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport torch.nn as nn\nfrom timm.models.layers import trunc_normal_, DropPath, to_2tuple\n\n\n__all__ = [\"resmlp_12\"]\n\n\nclass Mlp(nn.Module):\n    \"\"\" MLP as used in Vision Transformer, MLP-Mixer and related networks\n    \"\"\"\n\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        act_layer=nn.GELU,\n        drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        drop_probs = to_2tuple(drop)\n\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.drop1 = nn.Dropout(drop_probs[0])\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop2 = nn.Dropout(drop_probs[1])\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop1(x)\n        x = self.fc2(x)\n        x = self.drop2(x)\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" 2D Image to Patch Embedding\n    \"\"\"\n\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=16,\n        in_chans=3,\n        embed_dim=768,\n        norm_layer=None,\n        flatten=True,\n    ):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n        self.num_patches = self.grid_size[0] * self.grid_size[1]\n        self.flatten = flatten\n\n        self.proj = nn.Conv2d(\n            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size\n        )\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        assert (\n            H == self.img_size[0]\n        ), f\"Input image height ({H}) doesn't match model ({self.img_size[0]}).\"\n        assert (\n            W == self.img_size[1]\n        ), f\"Input image width ({W}) doesn't match model ({self.img_size[1]}).\"\n        x = self.proj(x)\n        if self.flatten:\n            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC\n        x = self.norm(x)\n        return x\n\n\nclass Affine(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.alpha = nn.Parameter(torch.ones(dim))\n        self.beta = nn.Parameter(torch.zeros(dim))\n\n    def forward(self, x):\n        return self.alpha * x + self.beta\n\n\nclass layers_scale_mlp_blocks(nn.Module):\n    def __init__(\n        self,\n        dim,\n        drop=0.0,\n        drop_path=0.0,\n        act_layer=nn.GELU,\n        init_values=1e-4,\n        num_patches=196,\n    ):\n        super().__init__()\n        self.norm1 = 
Affine(dim)\n        self.attn = nn.Linear(num_patches, num_patches)\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.norm2 = Affine(dim)\n        self.mlp = Mlp(\n            in_features=dim,\n            hidden_features=int(4.0 * dim),\n            act_layer=act_layer,\n            drop=drop,\n        )\n        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)\n        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)\n\n    def forward(self, x):\n        x = x + self.drop_path(\n            self.gamma_1 * self.attn(self.norm1(x).transpose(1, 2)).transpose(1, 2)\n        )\n        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n        return x\n\n\nclass resmlp_models(nn.Module):\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=16,\n        in_chans=3,\n        num_classes=1000,\n        embed_dim=768,\n        depth=12,\n        drop_rate=0.0,\n        Patch_layer=PatchEmbed,\n        act_layer=nn.GELU,\n        drop_path_rate=0.0,\n        init_scale=1e-4,\n    ):\n        super().__init__()\n\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim\n\n        self.patch_embed = Patch_layer(\n            img_size=img_size,\n            patch_size=patch_size,\n            in_chans=int(in_chans),\n            embed_dim=embed_dim,\n        )\n        num_patches = self.patch_embed.num_patches\n        dpr = [drop_path_rate for i in range(depth)]\n\n        self.blocks = nn.ModuleList(\n            [\n                layers_scale_mlp_blocks(\n                    dim=embed_dim,\n                    drop=drop_rate,\n                    drop_path=dpr[i],\n                    act_layer=act_layer,\n                    init_values=init_scale,\n                    num_patches=num_patches,\n                )\n                for i in range(depth)\n            ]\n        )\n\n        self.norm = Affine(embed_dim)\n\n        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module=\"head\")]\n        self.head = (\n            nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n        )\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=0.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=\"\"):\n        self.num_classes = num_classes\n        self.head = (\n            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n        )\n\n    def forward_features(self, x):\n        B = x.shape[0]\n\n        x = self.patch_embed(x)\n\n        for i, blk in enumerate(self.blocks):\n            x = blk(x)\n\n        x = self.norm(x)\n        x = x.mean(dim=1).reshape(B, 1, -1)\n\n        return x[:, 0]\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n\ndef resmlp_12(pretrained=False, dist=False, **kwargs):\n    model = resmlp_models(\n        patch_size=16,\n        embed_dim=384,\n        depth=12,\n        Patch_layer=PatchEmbed,\n        init_scale=0.1,\n        **kwargs,\n    )\n    return model\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_resnet.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nfrom torch import Tensor\nimport torch.nn as nn\nfrom typing import Type, Any, Callable, Union, List, Optional\n\n\n__all__ = [\n    \"ResNet\",\n    \"resnet18\",\n    \"resnet34\",\n    \"resnet50\",\n    \"resnet101\",\n    \"resnet152\",\n    \"resnext50_32x4d\",\n    \"resnext101_32x8d\",\n    \"wide_resnet50_2\",\n    \"wide_resnet101_2\",\n]\n\n\ndef conv3x3(\n    in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1\n) -> nn.Conv2d:\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(\n        in_planes,\n        out_planes,\n        kernel_size=3,\n        stride=stride,\n        padding=dilation,\n        groups=groups,\n        bias=False,\n        dilation=dilation,\n    )\n\n\ndef conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion: int = 1\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        if groups != 1 or base_width != 64:\n            raise ValueError(\"BasicBlock only supports groups=1 and base_width=64\")\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = norm_layer(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)\n    # while original implementation places the stride at the first 1x1 convolution(self.conv1)\n    # according to \"Deep residual learning for image recognition\"https://arxiv.org/abs/1512.03385.\n    # This variant is also known as ResNet V1.5 and improves accuracy according to\n    # 
https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.\n\n    expansion: int = 4\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super(Bottleneck, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.0)) * groups\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(inplanes, width)\n        self.bn1 = norm_layer(width)\n        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n        self.bn2 = norm_layer(width)\n        self.conv3 = conv1x1(width, planes * self.expansion)\n        self.bn3 = norm_layer(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(nn.Module):\n    def __init__(\n        self,\n        block: Type[Union[BasicBlock, Bottleneck]],\n        layers: List[int],\n        num_classes: int = 1000,\n        zero_init_residual: bool = False,\n        groups: int = 1,\n        width_per_group: int = 64,\n        replace_stride_with_dilation: Optional[List[bool]] = None,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super(ResNet, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\n                \"replace_stride_with_dilation should be None \"\n                \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation)\n            )\n        self.groups = groups\n        self.base_width = width_per_group\n        self.conv1 = nn.Conv2d(\n            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False\n        )\n        self.bn1 = norm_layer(self.inplanes)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(\n            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]\n        )\n        self.layer3 = self._make_layer(\n            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]\n        )\n        self.layer4 = self._make_layer(\n            block, 512, layers[3], stride=2, 
dilate=replace_stride_with_dilation[2]\n        )\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]\n                elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]\n\n    def _make_layer(\n        self,\n        block: Type[Union[BasicBlock, Bottleneck]],\n        planes: int,\n        blocks: int,\n        stride: int = 1,\n        dilate: bool = False,\n    ) -> nn.Sequential:\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(\n            block(\n                self.inplanes,\n                planes,\n                stride,\n                downsample,\n                self.groups,\n                self.base_width,\n                previous_dilation,\n                norm_layer,\n            )\n        )\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(\n                block(\n                    self.inplanes,\n                    planes,\n                    groups=self.groups,\n                    base_width=self.base_width,\n                    dilation=self.dilation,\n                    norm_layer=norm_layer,\n                )\n            )\n\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, x: Tensor) -> Tensor:\n        # See note [TorchScript super()]\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n        x = self.fc(x)\n\n        return x\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self._forward_impl(x)\n\n\ndef _resnet(\n    arch: str,\n    block: Type[Union[BasicBlock, Bottleneck]],\n    layers: List[int],\n    pretrained: bool,\n    progress: bool,\n    **kwargs: Any\n) -> ResNet:\n    model = ResNet(block, layers, **kwargs)\n    return model\n\n\ndef resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-18 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n    Args:\n        
pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet(\"resnet18\", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)\n\n\ndef resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-34 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet(\"resnet34\", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)\n\n\ndef resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-50 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet(\"resnet50\", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)\n\n\ndef resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-101 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet(\n        \"resnet101\", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs\n    )\n\n\ndef resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-152 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet(\n        \"resnet152\", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs\n    )\n\n\ndef resnext50_32x4d(\n    pretrained: bool = False, progress: bool = True, **kwargs: Any\n) -> ResNet:\n    r\"\"\"ResNeXt-50 32x4d model from\n    `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs[\"groups\"] = 32\n    kwargs[\"width_per_group\"] = 4\n    return _resnet(\n        \"resnext50_32x4d\", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs\n    )\n\n\ndef resnext101_32x8d(\n    pretrained: bool = False, progress: bool = True, **kwargs: Any\n) -> ResNet:\n    r\"\"\"ResNeXt-101 32x8d model from\n    `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs[\"groups\"] = 32\n    kwargs[\"width_per_group\"] = 8\n    return _resnet(\n        \"resnext101_32x8d\", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs\n    
)\n\n\ndef wide_resnet50_2(\n    pretrained: bool = False, progress: bool = True, **kwargs: Any\n) -> ResNet:\n    r\"\"\"Wide ResNet-50-2 model from\n    `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_.\n    The model is the same as ResNet except for the bottleneck number of channels,\n    which is twice as large in every block. The number of channels in outer 1x1\n    convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048\n    channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs[\"width_per_group\"] = 64 * 2\n    return _resnet(\n        \"wide_resnet50_2\", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs\n    )\n\n\ndef wide_resnet101_2(\n    pretrained: bool = False, progress: bool = True, **kwargs: Any\n) -> ResNet:\n    r\"\"\"Wide ResNet-101-2 model from\n    `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_.\n    The model is the same as ResNet except for the bottleneck number of channels,\n    which is twice as large in every block. The number of channels in outer 1x1\n    convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048\n    channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs[\"width_per_group\"] = 64 * 2\n    return _resnet(\n        \"wide_resnet101_2\", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs\n    )\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_rexnet.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport math\n\n\n__all__ = [\n    \"ReXNetV1\",\n    \"rexnetv1_1_0\",\n]\n\n\ndef silu(x, inplace=False):\n    return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())\n\n\nclass SiLU(nn.Module):\n    def __init__(self, inplace=True):\n        super(SiLU, self).__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        return silu(x, self.inplace)\n\n\ndef ConvBNAct(\n    out,\n    in_channels,\n    channels,\n    kernel=1,\n    stride=1,\n    pad=0,\n    num_group=1,\n    active=True,\n    relu6=False,\n):\n    out.append(\n        nn.Conv2d(\n            in_channels, channels, kernel, stride, pad, groups=num_group, bias=False\n        )\n    )\n    out.append(nn.BatchNorm2d(channels))\n    if active:\n        out.append(nn.ReLU6(inplace=True) if relu6 else nn.ReLU(inplace=True))\n\n\ndef ConvBNSiLU(out, in_channels, channels, kernel=1, stride=1, pad=0, num_group=1):\n    out.append(\n        nn.Conv2d(\n            in_channels, channels, kernel, stride, pad, groups=num_group, bias=False\n        )\n    )\n    out.append(nn.BatchNorm2d(channels))\n    out.append(SiLU(inplace=True))\n\n\nclass SE(nn.Module):\n    def __init__(self, in_channels, channels, se_ratio=12):\n        super(SE, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Sequential(\n            nn.Conv2d(in_channels, channels // se_ratio, kernel_size=1, padding=0),\n            nn.BatchNorm2d(channels // se_ratio),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(channels // se_ratio, channels, kernel_size=1, padding=0),\n            nn.Sigmoid(),\n        )\n\n    def forward(self, x):\n        y = self.avg_pool(x)\n        y = self.fc(y)\n        return x * y\n\n\nclass LinearBottleneck(nn.Module):\n    def __init__(\n        self, in_channels, channels, t, stride, use_se=True, se_ratio=12, **kwargs\n    ):\n        super(LinearBottleneck, self).__init__(**kwargs)\n        self.use_shortcut = stride == 1 and in_channels <= channels\n        self.in_channels = in_channels\n        self.out_channels = channels\n\n        out = []\n        if t != 1:\n            dw_channels = in_channels * t\n            ConvBNSiLU(out, in_channels=in_channels, channels=dw_channels)\n        else:\n            dw_channels = in_channels\n\n        ConvBNAct(\n            out,\n            in_channels=dw_channels,\n            channels=dw_channels,\n            kernel=3,\n            stride=stride,\n            pad=1,\n            num_group=dw_channels,\n            active=False,\n        )\n\n        if use_se:\n            out.append(SE(dw_channels, dw_channels, se_ratio))\n\n        out.append(nn.ReLU6())\n        ConvBNAct(\n            out, in_channels=dw_channels, channels=channels, active=False, relu6=True\n        )\n        self.out = nn.Sequential(*out)\n\n    def forward(self, x):\n        out = 
self.out(x)\n        if self.use_shortcut:\n            out[:, 0 : self.in_channels] += x\n\n        return out\n\n\nclass ReXNetV1(nn.Module):\n    def __init__(\n        self,\n        input_ch=16,\n        final_ch=180,\n        width_mult=1.0,\n        depth_mult=1.0,\n        classes=1000,\n        use_se=True,\n        se_ratio=12,\n        dropout_ratio=0.2,\n        bn_momentum=0.9,\n    ):\n        super(ReXNetV1, self).__init__()\n\n        layers = [1, 2, 2, 3, 3, 5]\n        strides = [1, 2, 2, 2, 1, 2]\n        use_ses = [False, False, True, True, True, True]\n\n        layers = [math.ceil(element * depth_mult) for element in layers]\n        strides = sum(\n            [\n                [element] + [1] * (layers[idx] - 1)\n                for idx, element in enumerate(strides)\n            ],\n            [],\n        )\n        if use_se:\n            use_ses = sum(\n                [[element] * layers[idx] for idx, element in enumerate(use_ses)], []\n            )\n        else:\n            use_ses = [False] * sum(layers[:])\n        ts = [1] * layers[0] + [6] * sum(layers[1:])\n\n        self.depth = sum(layers[:]) * 3\n        stem_channel = 32 / width_mult if width_mult < 1.0 else 32\n        inplanes = input_ch / width_mult if width_mult < 1.0 else input_ch\n\n        features = []\n        in_channels_group = []\n        channels_group = []\n\n        # The following channel configuration is a simple instance to make each layer become an expand layer.\n        for i in range(self.depth // 3):\n            if i == 0:\n                in_channels_group.append(int(round(stem_channel * width_mult)))\n                channels_group.append(int(round(inplanes * width_mult)))\n            else:\n                in_channels_group.append(int(round(inplanes * width_mult)))\n                inplanes += final_ch / (self.depth // 3 * 1.0)\n                channels_group.append(int(round(inplanes * width_mult)))\n\n        ConvBNSiLU(\n            features,\n            3,\n            int(round(stem_channel * width_mult)),\n            kernel=3,\n            stride=2,\n            pad=1,\n        )\n\n        for block_idx, (in_c, c, t, s, se) in enumerate(\n            zip(in_channels_group, channels_group, ts, strides, use_ses)\n        ):\n            features.append(\n                LinearBottleneck(\n                    in_channels=in_c,\n                    channels=c,\n                    t=t,\n                    stride=s,\n                    use_se=se,\n                    se_ratio=se_ratio,\n                )\n            )\n\n        pen_channels = int(1280 * width_mult)\n        ConvBNSiLU(features, c, pen_channels)\n\n        features.append(nn.AdaptiveAvgPool2d(1))\n        self.features = nn.Sequential(*features)\n        self.output = nn.Sequential(\n            nn.Dropout(dropout_ratio), nn.Conv2d(pen_channels, classes, 1, bias=True)\n        )\n\n    def extract_features(self, x):\n        return self.features[:-1](x)\n\n    def forward(self, x):\n        x = self.features(x)\n        x = self.output(x).flatten(1)\n        return x\n\n\ndef _create_rexnetv1(arch, pretrained=False, progress=True, **model_kwargs):\n    model = ReXNetV1(**model_kwargs)\n    return model\n\n\ndef rexnetv1_1_0(pretrained=False, progress=True, **kwargs):\n    \"\"\"\n    Constructs the ReXNet model with width multiplier of 1.0.\n    .. 
note::\n        ReXNet model with width multiplier of 1.0 from the `Rethinking Channel Dimensions for Efficient Model Design <https://arxiv.org/pdf/2007.00992.pdf>`_ paper.\n    Args:\n        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``\n        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``\n    \"\"\"\n    model_kwargs = dict(width_mult=1.0, **kwargs)\n    return _create_rexnetv1(\n        \"rexnetv1_1_0\", pretrained=pretrained, progress=progress, **model_kwargs\n    )\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_rexnetv1_lite.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport torch.nn as nn\nfrom math import ceil\n\n__all__ = [\n    \"ReXNetV1_lite\",\n    \"rexnet_lite_1_0\",\n]\n\n\ndef _make_divisible(channel_size, divisor=None, min_value=None):\n    \"\"\"\n    This function is taken from the original tf repo.\n    It ensures that all layers have a channel number that is divisible by 8\n    It can be seen here:\n    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py\n    \"\"\"\n    if not divisor:\n        return channel_size\n\n    if min_value is None:\n        min_value = divisor\n    new_channel_size = max(\n        min_value, int(channel_size + divisor / 2) // divisor * divisor\n    )\n    # Make sure that round down does not go down by more than 10%.\n    if new_channel_size < 0.9 * channel_size:\n        new_channel_size += divisor\n    return new_channel_size\n\n\ndef _add_conv(\n    out,\n    in_channels,\n    channels,\n    kernel=1,\n    stride=1,\n    pad=0,\n    num_group=1,\n    active=True,\n    relu6=True,\n    bn_momentum=0.1,\n    bn_eps=1e-5,\n):\n    out.append(\n        nn.Conv2d(\n            in_channels, channels, kernel, stride, pad, groups=num_group, bias=False\n        )\n    )\n    out.append(nn.BatchNorm2d(channels, momentum=bn_momentum, eps=bn_eps))\n    if active:\n        out.append(nn.ReLU6(inplace=True) if relu6 else nn.ReLU(inplace=True))\n\n\nclass LinearBottleneck(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        channels,\n        t,\n        kernel_size=3,\n        stride=1,\n        bn_momentum=0.1,\n        bn_eps=1e-5,\n        **kwargs\n    ):\n        super(LinearBottleneck, self).__init__(**kwargs)\n        self.conv_shortcut = None\n        self.use_shortcut = stride == 1 and in_channels <= channels\n        self.in_channels = in_channels\n        self.out_channels = channels\n        out = []\n        if t != 1:\n            dw_channels = in_channels * t\n            _add_conv(\n                out,\n                in_channels=in_channels,\n                channels=dw_channels,\n                bn_momentum=bn_momentum,\n                bn_eps=bn_eps,\n            )\n        else:\n            dw_channels = in_channels\n\n        _add_conv(\n            out,\n            in_channels=dw_channels,\n            channels=dw_channels * 1,\n            kernel=kernel_size,\n            stride=stride,\n            pad=(kernel_size // 2),\n            num_group=dw_channels,\n            bn_momentum=bn_momentum,\n            bn_eps=bn_eps,\n        )\n\n        _add_conv(\n            out,\n            in_channels=dw_channels,\n            channels=channels,\n            active=False,\n            bn_momentum=bn_momentum,\n            bn_eps=bn_eps,\n        )\n\n        self.out = nn.Sequential(*out)\n\n    def forward(self, x):\n        out = self.out(x)\n\n        if self.use_shortcut:\n            out[:, 0 : 
self.in_channels] += x\n        return out\n\n\nclass ReXNetV1_lite(nn.Module):\n    def __init__(\n        self,\n        fix_head_stem=False,\n        divisible_value=8,\n        input_ch=16,\n        final_ch=164,\n        multiplier=1.0,\n        classes=1000,\n        dropout_ratio=0.2,\n        bn_momentum=0.1,\n        bn_eps=1e-5,\n        kernel_conf=\"333333\",\n    ):\n        super(ReXNetV1_lite, self).__init__()\n\n        layers = [1, 2, 2, 3, 3, 5]\n        strides = [1, 2, 2, 2, 1, 2]\n        kernel_sizes = [int(element) for element in kernel_conf]\n\n        strides = sum(\n            [\n                [element] + [1] * (layers[idx] - 1)\n                for idx, element in enumerate(strides)\n            ],\n            [],\n        )\n        ts = [1] * layers[0] + [6] * sum(layers[1:])\n        kernel_sizes = sum(\n            [[element] * layers[idx] for idx, element in enumerate(kernel_sizes)], []\n        )\n        self.num_convblocks = sum(layers[:])\n\n        features = []\n        inplanes = input_ch / multiplier if multiplier < 1.0 else input_ch\n        first_channel = 32 / multiplier if multiplier < 1.0 or fix_head_stem else 32\n        first_channel = _make_divisible(\n            int(round(first_channel * multiplier)), divisible_value\n        )\n\n        in_channels_group = []\n        channels_group = []\n\n        _add_conv(\n            features,\n            3,\n            first_channel,\n            kernel=3,\n            stride=2,\n            pad=1,\n            bn_momentum=bn_momentum,\n            bn_eps=bn_eps,\n        )\n\n        for i in range(self.num_convblocks):\n            inplanes_divisible = _make_divisible(\n                int(round(inplanes * multiplier)), divisible_value\n            )\n            if i == 0:\n                in_channels_group.append(first_channel)\n                channels_group.append(inplanes_divisible)\n            else:\n                in_channels_group.append(inplanes_divisible)\n                inplanes += final_ch / (self.num_convblocks - 1 * 1.0)\n                inplanes_divisible = _make_divisible(\n                    int(round(inplanes * multiplier)), divisible_value\n                )\n                channels_group.append(inplanes_divisible)\n\n        for block_idx, (in_c, c, t, k, s) in enumerate(\n            zip(in_channels_group, channels_group, ts, kernel_sizes, strides)\n        ):\n            features.append(\n                LinearBottleneck(\n                    in_channels=in_c,\n                    channels=c,\n                    t=t,\n                    kernel_size=k,\n                    stride=s,\n                    bn_momentum=bn_momentum,\n                    bn_eps=bn_eps,\n                )\n            )\n\n        pen_channels = (\n            int(1280 * multiplier) if multiplier > 1 and not fix_head_stem else 1280\n        )\n        _add_conv(features, c, pen_channels, bn_momentum=bn_momentum, bn_eps=bn_eps)\n\n        self.features = nn.Sequential(*features)\n        self.avgpool = nn.AdaptiveAvgPool2d(1)\n\n        self.output = nn.Sequential(\n            nn.Conv2d(pen_channels, 1024, 1, bias=True),\n            nn.BatchNorm2d(1024, momentum=bn_momentum, eps=bn_eps),\n            nn.ReLU6(inplace=True),\n            nn.Dropout(dropout_ratio),\n            nn.Conv2d(1024, classes, 1, bias=True),\n        )\n\n    def forward(self, x):\n        x = self.features(x)\n        x = self.avgpool(x)\n        x = self.output(x).flatten(1)\n        return x\n\n\ndef 
_create_rexnet_lite(arch, pretrained=False, progress=True, **model_kwargs):\n    model = ReXNetV1_lite(**model_kwargs)\n    return model\n\n\ndef rexnet_lite_1_0(pretrained=False, progress=True, **kwargs):\n    \"\"\"\n    Constructs the ReXNet-lite model with width multiplier of 1.0.\n    .. note::\n        ReXNet-lite model with width multiplier of 1.0 from the `Rethinking Channel Dimensions for Efficient Model Design <https://arxiv.org/pdf/2007.00992.pdf>`_ paper.\n    Args:\n        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``\n        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``\n    For example:\n    .. code-block:: python\n        >>> model = rexnet_lite_1_0(pretrained=False, progress=True)\n    \"\"\"\n    model_kwargs = dict(multiplier=1.0, **kwargs)\n    return _create_rexnet_lite(\n        \"rexnet_lite_1_0\", pretrained=pretrained, progress=progress, **model_kwargs\n    )\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_senet.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom __future__ import print_function, division, absolute_import\nfrom collections import OrderedDict\nimport math\nimport torch.nn as nn\n\n__all__ = [\"SENet\", \"senet154\"]\n\n\nclass SEModule(nn.Module):\n    def __init__(self, channels, reduction):\n        super(SEModule, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0)\n        self.relu = nn.ReLU(inplace=True)\n        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0)\n        self.sigmoid = nn.Sigmoid()\n\n    def forward(self, x):\n        module_input = x\n        x = self.avg_pool(x)\n        x = self.fc1(x)\n        x = self.relu(x)\n        x = self.fc2(x)\n        x = self.sigmoid(x)\n        return module_input * x\n\n\nclass Bottleneck(nn.Module):\n    \"\"\"\n    Base class for bottlenecks that implements `forward()` method.\n    \"\"\"\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out = self.se_module(out) + residual\n        out = self.relu(out)\n\n        return out\n\n\nclass SEBottleneck(Bottleneck):\n    \"\"\"\n    Bottleneck for SENet154.\n    \"\"\"\n\n    expansion = 4\n\n    def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None):\n        super(SEBottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes * 2)\n        self.conv2 = nn.Conv2d(\n            planes * 2,\n            planes * 4,\n            kernel_size=3,\n            stride=stride,\n            padding=1,\n            groups=groups,\n            bias=False,\n        )\n        self.bn2 = nn.BatchNorm2d(planes * 4)\n        self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * 4)\n        self.relu = nn.ReLU(inplace=True)\n        self.se_module = SEModule(planes * 4, reduction=reduction)\n        self.downsample = downsample\n        self.stride = stride\n\n\nclass SEResNetBottleneck(Bottleneck):\n    \"\"\"\n    ResNet bottleneck with a Squeeze-and-Excitation module. 
It follows Caffe\n    implementation and uses `stride=stride` in `conv1` and not in `conv2`\n    (the latter is used in the torchvision implementation of ResNet).\n    \"\"\"\n\n    expansion = 4\n\n    def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None):\n        super(SEResNetBottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(\n            inplanes, planes, kernel_size=1, bias=False, stride=stride\n        )\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(\n            planes, planes, kernel_size=3, padding=1, groups=groups, bias=False\n        )\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * 4)\n        self.relu = nn.ReLU(inplace=True)\n        self.se_module = SEModule(planes * 4, reduction=reduction)\n        self.downsample = downsample\n        self.stride = stride\n\n\nclass SEResNeXtBottleneck(Bottleneck):\n    \"\"\"\n    ResNeXt bottleneck type C with a Squeeze-and-Excitation module.\n    \"\"\"\n\n    expansion = 4\n\n    def __init__(\n        self,\n        inplanes,\n        planes,\n        groups,\n        reduction,\n        stride=1,\n        downsample=None,\n        base_width=4,\n    ):\n        super(SEResNeXtBottleneck, self).__init__()\n        width = math.floor(planes * (base_width / 64)) * groups\n        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, stride=1)\n        self.bn1 = nn.BatchNorm2d(width)\n        self.conv2 = nn.Conv2d(\n            width,\n            width,\n            kernel_size=3,\n            stride=stride,\n            padding=1,\n            groups=groups,\n            bias=False,\n        )\n        self.bn2 = nn.BatchNorm2d(width)\n        self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * 4)\n        self.relu = nn.ReLU(inplace=True)\n        self.se_module = SEModule(planes * 4, reduction=reduction)\n        self.downsample = downsample\n        self.stride = stride\n\n\nclass SENet(nn.Module):\n    def __init__(\n        self,\n        block,\n        layers,\n        groups,\n        reduction,\n        dropout_p=0.2,\n        inplanes=128,\n        input_3x3=True,\n        downsample_kernel_size=3,\n        downsample_padding=1,\n        num_classes=1000,\n    ):\n        \"\"\"\n        Parameters\n        ----------\n        block (nn.Module): Bottleneck class.\n            - For SENet154: SEBottleneck\n            - For SE-ResNet models: SEResNetBottleneck\n            - For SE-ResNeXt models:  SEResNeXtBottleneck\n        layers (list of ints): Number of residual blocks for 4 layers of the\n            network (layer1...layer4).\n        groups (int): Number of groups for the 3x3 convolution in each\n            bottleneck block.\n            - For SENet154: 64\n            - For SE-ResNet models: 1\n            - For SE-ResNeXt models:  32\n        reduction (int): Reduction ratio for Squeeze-and-Excitation modules.\n            - For all models: 16\n        dropout_p (float or None): Drop probability for the Dropout layer.\n            If `None` the Dropout layer is not used.\n            - For SENet154: 0.2\n            - For SE-ResNet models: None\n            - For SE-ResNeXt models: None\n        inplanes (int):  Number of input channels for layer1.\n            - For SENet154: 128\n            - For SE-ResNet models: 64\n            - For SE-ResNeXt 
models: 64\n        input_3x3 (bool): If `True`, use three 3x3 convolutions instead of\n            a single 7x7 convolution in layer0.\n            - For SENet154: True\n            - For SE-ResNet models: False\n            - For SE-ResNeXt models: False\n        downsample_kernel_size (int): Kernel size for downsampling convolutions\n            in layer2, layer3 and layer4.\n            - For SENet154: 3\n            - For SE-ResNet models: 1\n            - For SE-ResNeXt models: 1\n        downsample_padding (int): Padding for downsampling convolutions in\n            layer2, layer3 and layer4.\n            - For SENet154: 1\n            - For SE-ResNet models: 0\n            - For SE-ResNeXt models: 0\n        num_classes (int): Number of outputs in `last_linear` layer.\n            - For all models: 1000\n        \"\"\"\n        super(SENet, self).__init__()\n        self.inplanes = inplanes\n        if input_3x3:\n            layer0_modules = [\n                (\"conv1\", nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False)),\n                (\"bn1\", nn.BatchNorm2d(64)),\n                (\"relu1\", nn.ReLU(inplace=True)),\n                (\"conv2\", nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)),\n                (\"bn2\", nn.BatchNorm2d(64)),\n                (\"relu2\", nn.ReLU(inplace=True)),\n                (\"conv3\", nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False)),\n                (\"bn3\", nn.BatchNorm2d(inplanes)),\n                (\"relu3\", nn.ReLU(inplace=True)),\n            ]\n        else:\n            layer0_modules = [\n                (\n                    \"conv1\",\n                    nn.Conv2d(\n                        3, inplanes, kernel_size=7, stride=2, padding=3, bias=False\n                    ),\n                ),\n                (\"bn1\", nn.BatchNorm2d(inplanes)),\n                (\"relu1\", nn.ReLU(inplace=True)),\n            ]\n        # To preserve compatibility with Caffe weights `ceil_mode=True`\n        # is used instead of `padding=1`.\n        layer0_modules.append((\"pool\", nn.MaxPool2d(3, stride=2, ceil_mode=True)))\n        self.layer0 = nn.Sequential(OrderedDict(layer0_modules))\n        self.layer1 = self._make_layer(\n            block,\n            planes=64,\n            blocks=layers[0],\n            groups=groups,\n            reduction=reduction,\n            downsample_kernel_size=1,\n            downsample_padding=0,\n        )\n        self.layer2 = self._make_layer(\n            block,\n            planes=128,\n            blocks=layers[1],\n            stride=2,\n            groups=groups,\n            reduction=reduction,\n            downsample_kernel_size=downsample_kernel_size,\n            downsample_padding=downsample_padding,\n        )\n        self.layer3 = self._make_layer(\n            block,\n            planes=256,\n            blocks=layers[2],\n            stride=2,\n            groups=groups,\n            reduction=reduction,\n            downsample_kernel_size=downsample_kernel_size,\n            downsample_padding=downsample_padding,\n        )\n        self.layer4 = self._make_layer(\n            block,\n            planes=512,\n            blocks=layers[3],\n            stride=2,\n            groups=groups,\n            reduction=reduction,\n            downsample_kernel_size=downsample_kernel_size,\n            downsample_padding=downsample_padding,\n        )\n        self.avg_pool = nn.AvgPool2d(7, stride=1)\n        self.dropout = nn.Dropout(dropout_p) if dropout_p is not 
None else None\n        self.last_linear = nn.Linear(512 * block.expansion, num_classes)\n\n    def _make_layer(\n        self,\n        block,\n        planes,\n        blocks,\n        groups,\n        reduction,\n        stride=1,\n        downsample_kernel_size=1,\n        downsample_padding=0,\n    ):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(\n                    self.inplanes,\n                    planes * block.expansion,\n                    kernel_size=downsample_kernel_size,\n                    stride=stride,\n                    padding=downsample_padding,\n                    bias=False,\n                ),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(\n            block(self.inplanes, planes, groups, reduction, stride, downsample)\n        )\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes, groups, reduction))\n\n        return nn.Sequential(*layers)\n\n    def features(self, x):\n        x = self.layer0(x)\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n        return x\n\n    def logits(self, x):\n        x = self.avg_pool(x)\n        if self.dropout is not None:\n            x = self.dropout(x)\n        x = x.view(x.size(0), -1)\n        x = self.last_linear(x)\n        return x\n\n    def forward(self, x):\n        x = self.features(x)\n        x = self.logits(x)\n        return x\n\n\ndef senet154(num_classes=1000, pretrained=\"imagenet\"):\n    model = SENet(\n        SEBottleneck,\n        [3, 8, 12, 3],\n        groups=64,\n        reduction=16,\n        dropout_p=0.2,\n        num_classes=num_classes,\n    )\n    return model\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_shufflenetv2.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport torch.nn as nn\nfrom torch import Tensor\n\nfrom typing import Callable, Any, List\n\n__all__ = [\n    \"ShuffleNetV2\",\n    \"shufflenet_v2_x2_0\",\n]\n\n\ndef channel_shuffle(x: Tensor, groups: int) -> Tensor:\n    batchsize, num_channels, height, width = x.size()\n    channels_per_group = num_channels // groups\n\n    # reshape\n    x = x.view(batchsize, groups, channels_per_group, height, width)\n\n    x = torch.transpose(x, 1, 2).contiguous()\n\n    # flatten\n    x = x.view(batchsize, -1, height, width)\n\n    return x\n\n\nclass InvertedResidual(nn.Module):\n    def __init__(self, inp: int, oup: int, stride: int) -> None:\n        super().__init__()\n\n        if not (1 <= stride <= 3):\n            raise ValueError(\"illegal stride value\")\n        self.stride = stride\n\n        branch_features = oup // 2\n        if (self.stride == 1) and (inp != branch_features << 1):\n            raise ValueError(\n                f\"Invalid combination of stride {stride}, inp {inp} and oup {oup} values. If stride == 1 then inp should be equal to oup // 2 << 1.\"\n            )\n\n        if self.stride > 1:\n            self.branch1 = nn.Sequential(\n                self.depthwise_conv(\n                    inp, inp, kernel_size=3, stride=self.stride, padding=1\n                ),\n                nn.BatchNorm2d(inp),\n                nn.Conv2d(\n                    inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False\n                ),\n                nn.BatchNorm2d(branch_features),\n                nn.ReLU(inplace=True),\n            )\n        else:\n            self.branch1 = nn.Sequential()\n\n        self.branch2 = nn.Sequential(\n            nn.Conv2d(\n                inp if (self.stride > 1) else branch_features,\n                branch_features,\n                kernel_size=1,\n                stride=1,\n                padding=0,\n                bias=False,\n            ),\n            nn.BatchNorm2d(branch_features),\n            nn.ReLU(inplace=True),\n            self.depthwise_conv(\n                branch_features,\n                branch_features,\n                kernel_size=3,\n                stride=self.stride,\n                padding=1,\n            ),\n            nn.BatchNorm2d(branch_features),\n            nn.Conv2d(\n                branch_features,\n                branch_features,\n                kernel_size=1,\n                stride=1,\n                padding=0,\n                bias=False,\n            ),\n            nn.BatchNorm2d(branch_features),\n            nn.ReLU(inplace=True),\n        )\n\n    @staticmethod\n    def depthwise_conv(\n        i: int,\n        o: int,\n        kernel_size: int,\n        stride: int = 1,\n        padding: int = 0,\n        bias: bool = False,\n    ) -> nn.Conv2d:\n        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)\n\n    def 
forward(self, x: Tensor) -> Tensor:\n        if self.stride == 1:\n            x1, x2 = x.chunk(2, dim=1)\n            out = torch.cat((x1, self.branch2(x2)), dim=1)\n        else:\n            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)\n\n        out = channel_shuffle(out, 2)\n\n        return out\n\n\nclass ShuffleNetV2(nn.Module):\n    def __init__(\n        self,\n        stages_repeats: List[int],\n        stages_out_channels: List[int],\n        num_classes: int = 1000,\n        inverted_residual: Callable[..., nn.Module] = InvertedResidual,\n    ) -> None:\n        super().__init__()\n        if len(stages_repeats) != 3:\n            raise ValueError(\"expected stages_repeats as list of 3 positive ints\")\n        if len(stages_out_channels) != 5:\n            raise ValueError(\"expected stages_out_channels as list of 5 positive ints\")\n        self._stage_out_channels = stages_out_channels\n\n        input_channels = 3\n        output_channels = self._stage_out_channels[0]\n        self.conv1 = nn.Sequential(\n            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),\n            nn.BatchNorm2d(output_channels),\n            nn.ReLU(inplace=True),\n        )\n        input_channels = output_channels\n\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n\n        # Static annotations for mypy\n        self.stage2: nn.Sequential\n        self.stage3: nn.Sequential\n        self.stage4: nn.Sequential\n        stage_names = [f\"stage{i}\" for i in [2, 3, 4]]\n        for name, repeats, output_channels in zip(\n            stage_names, stages_repeats, self._stage_out_channels[1:]\n        ):\n            seq = [inverted_residual(input_channels, output_channels, 2)]\n            for i in range(repeats - 1):\n                seq.append(inverted_residual(output_channels, output_channels, 1))\n            setattr(self, name, nn.Sequential(*seq))\n            input_channels = output_channels\n\n        output_channels = self._stage_out_channels[-1]\n        self.conv5 = nn.Sequential(\n            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),\n            nn.BatchNorm2d(output_channels),\n            nn.ReLU(inplace=True),\n        )\n\n        self.fc = nn.Linear(output_channels, num_classes)\n\n    def _forward_impl(self, x: Tensor) -> Tensor:\n        # See note [TorchScript super()]\n        x = self.conv1(x)\n        x = self.maxpool(x)\n        x = self.stage2(x)\n        x = self.stage3(x)\n        x = self.stage4(x)\n        x = self.conv5(x)\n        x = x.mean([2, 3])  # globalpool\n        x = self.fc(x)\n        return x\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self._forward_impl(x)\n\n\ndef _shufflenetv2(progress: bool, *args: Any, **kwargs: Any) -> ShuffleNetV2:\n    model = ShuffleNetV2(*args, **kwargs)\n    return model\n\n\ndef shufflenet_v2_x2_0(progress: bool = True, **kwargs: Any) -> ShuffleNetV2:\n    \"\"\"\n    Constructs a ShuffleNetV2 with 2.0x output channels, as described in\n    `\"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design\"\n    <https://arxiv.org/abs/1807.11164>`_.\n    Args:\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _shufflenetv2(progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_squeezenet.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.init as init\n\nfrom typing import Any\n\n__all__ = [\"SqueezeNet\", \"squeezenet1_1\"]\n\n\nclass Fire(nn.Module):\n    def __init__(\n        self,\n        inplanes: int,\n        squeeze_planes: int,\n        expand1x1_planes: int,\n        expand3x3_planes: int,\n    ) -> None:\n        super().__init__()\n        self.inplanes = inplanes\n        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)\n        self.squeeze_activation = nn.ReLU(inplace=True)\n        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)\n        self.expand1x1_activation = nn.ReLU(inplace=True)\n        self.expand3x3 = nn.Conv2d(\n            squeeze_planes, expand3x3_planes, kernel_size=3, padding=1\n        )\n        self.expand3x3_activation = nn.ReLU(inplace=True)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.squeeze_activation(self.squeeze(x))\n        return torch.cat(\n            [\n                self.expand1x1_activation(self.expand1x1(x)),\n                self.expand3x3_activation(self.expand3x3(x)),\n            ],\n            1,\n        )\n\n\nclass SqueezeNet(nn.Module):\n    def __init__(\n        self, version: str = \"1_0\", num_classes: int = 1000, dropout: float = 0.5\n    ) -> None:\n        super().__init__()\n        self.num_classes = num_classes\n        if version == \"1_0\":\n            self.features = nn.Sequential(\n                nn.Conv2d(3, 96, kernel_size=7, stride=2),\n                nn.ReLU(inplace=True),\n                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),\n                Fire(96, 16, 64, 64),\n                Fire(128, 16, 64, 64),\n                Fire(128, 32, 128, 128),\n                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),\n                Fire(256, 32, 128, 128),\n                Fire(256, 48, 192, 192),\n                Fire(384, 48, 192, 192),\n                Fire(384, 64, 256, 256),\n                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),\n                Fire(512, 64, 256, 256),\n            )\n        elif version == \"1_1\":\n            self.features = nn.Sequential(\n                nn.Conv2d(3, 64, kernel_size=3, stride=2),\n                nn.ReLU(inplace=True),\n                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),\n                Fire(64, 16, 64, 64),\n                Fire(128, 16, 64, 64),\n                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),\n                Fire(128, 32, 128, 128),\n                Fire(256, 32, 128, 128),\n                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),\n                Fire(256, 48, 192, 192),\n                Fire(384, 48, 192, 192),\n                Fire(384, 64, 256, 256),\n                Fire(512, 64, 256, 256),\n            )\n        else:\n            # FIXME: Is this 
needed? SqueezeNet should only be called from the\n            # FIXME: squeezenet1_x() functions\n            # FIXME: This checking is not done for the other models\n            raise ValueError(\n                f\"Unsupported SqueezeNet version {version}: 1_0 or 1_1 expected\"\n            )\n\n        # Final convolution is initialized differently from the rest\n        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)\n        self.classifier = nn.Sequential(\n            nn.Dropout(p=dropout),\n            final_conv,\n            nn.ReLU(inplace=True),\n            nn.AdaptiveAvgPool2d((1, 1)),\n        )\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                if m is final_conv:\n                    init.normal_(m.weight, mean=0.0, std=0.01)\n                else:\n                    init.kaiming_uniform_(m.weight)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.features(x)\n        x = self.classifier(x)\n        return torch.flatten(x, 1)\n\n\ndef _squeezenet(version: str, progress: bool, **kwargs: Any,) -> SqueezeNet:\n    model = SqueezeNet(version, **kwargs)\n    return model\n\n\ndef squeezenet1_1(progress: bool = True, **kwargs: Any) -> SqueezeNet:\n    r\"\"\"SqueezeNet 1.1 model from the `official SqueezeNet repo\n    <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.\n    SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters\n    than SqueezeNet 1.0, without sacrificing accuracy.\n    The required minimum input size of the model is 17x17.\n    Args:\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _squeezenet(\"1_1\", progress, **kwargs)\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_swin_transformer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\nimport torch.nn as nn\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\n\nclass Mlp(nn.Module):\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        act_layer=nn.GELU,\n        drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = (\n        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    )\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(\n        B, H // window_size, W // window_size, window_size, window_size, -1\n    )\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        window_size,\n        num_heads,\n        qkv_bias=True,\n        qk_scale=None,\n        attn_drop=0.0,\n        proj_drop=0.0,\n    ):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)\n        )  # 2*Wh-1 * 2*Ww-1, nH\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = (\n            coords_flatten[:, :, None] - coords_flatten[:, None, :]\n        )  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(\n            1, 2, 0\n        ).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        trunc_normal_(self.relative_position_bias_table, std=0.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv = (\n            self.qkv(x)\n            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)\n            .permute(2, 0, 3, 1, 4)\n        )\n        q, k, v = (\n            qkv[0],\n            qkv[1],\n            qkv[2],\n        )  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = q @ k.transpose(-2, -1)\n\n        relative_position_bias = self.relative_position_bias_table[\n            self.relative_position_index.view(-1)\n        ].view(\n            self.window_size[0] * self.window_size[1],\n            self.window_size[0] * self.window_size[1],\n            -1,\n        )  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(\n            2, 0, 1\n        ).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(\n                1\n            ).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = 
self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}\"\n\n    def flops(self, N):\n        # calculate flops for 1 window with token length of N\n        flops = 0\n        # qkv = self.qkv(x)\n        flops += N * self.dim * 3 * self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += N * self.dim * self.dim\n        return flops\n\n\nclass SwinTransformerBlock(nn.Module):\n    r\"\"\" Swin Transformer Block.\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        input_resolution,\n        num_heads,\n        window_size=7,\n        shift_size=0,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        drop=0.0,\n        attn_drop=0.0,\n        drop_path=0.0,\n        act_layer=nn.GELU,\n        norm_layer=nn.LayerNorm,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        if min(self.input_resolution) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert (\n            0 <= self.shift_size < self.window_size\n        ), \"shift_size must be in the range [0, window_size)\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim,\n            window_size=to_2tuple(self.window_size),\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            qk_scale=qk_scale,\n            attn_drop=attn_drop,\n            proj_drop=drop,\n        )\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(\n            in_features=dim,\n            hidden_features=mlp_hidden_dim,\n            act_layer=act_layer,\n            drop=drop,\n        )\n\n        if self.shift_size > 0:\n            # calculate attention mask for SW-MSA\n            H, W = self.input_resolution\n            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n            h_slices = (\n                slice(0, 
-self.window_size),\n                slice(-self.window_size, -self.shift_size),\n                slice(-self.shift_size, None),\n            )\n            w_slices = (\n                slice(0, -self.window_size),\n                slice(-self.window_size, -self.shift_size),\n                slice(-self.shift_size, None),\n            )\n            cnt = 0\n            for h in h_slices:\n                for w in w_slices:\n                    img_mask[:, h, w, :] = cnt\n                    cnt += 1\n\n            mask_windows = window_partition(\n                img_mask, self.window_size\n            )  # nW, window_size, window_size, 1\n            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n            attn_mask = attn_mask.masked_fill(\n                attn_mask != 0, float(-100.0)\n            ).masked_fill(attn_mask == 0, float(0.0))\n        else:\n            attn_mask = None\n\n        self.register_buffer(\"attn_mask\", attn_mask)\n\n    def forward(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_x = torch.roll(\n                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)\n            )\n        else:\n            shifted_x = x\n\n        # partition windows\n        x_windows = window_partition(\n            shifted_x, self.window_size\n        )  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(\n            -1, self.window_size * self.window_size, C\n        )  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA\n        attn_windows = self.attn(\n            x_windows, mask=self.attn_mask\n        )  # nW*B, window_size*window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(\n                shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)\n            )\n        else:\n            x = shifted_x\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n    def extra_repr(self) -> str:\n        return (\n            f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \"\n            f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n        )\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # W-MSA/SW-MSA\n        nW = H * W / self.window_size / self.window_size\n        flops += nW * self.attn.flops(self.window_size * self.window_size)\n        # mlp\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): 
Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\"\n\n        x = x.view(B, H, W, C)\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * self.dim\n        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n        return flops\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
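(Stored but not exercised in this copy; BasicLayer.forward always calls the blocks directly.) 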
Default: False.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        input_resolution,\n        depth,\n        num_heads,\n        window_size,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        drop=0.0,\n        attn_drop=0.0,\n        drop_path=0.0,\n        norm_layer=nn.LayerNorm,\n        downsample=None,\n        use_checkpoint=False,\n    ):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList(\n            [\n                SwinTransformerBlock(\n                    dim=dim,\n                    input_resolution=input_resolution,\n                    num_heads=num_heads,\n                    window_size=window_size,\n                    shift_size=0 if (i % 2 == 0) else window_size // 2,\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop,\n                    attn_drop=attn_drop,\n                    drop_path=drop_path[i]\n                    if isinstance(drop_path, list)\n                    else drop_path,\n                    norm_layer=norm_layer,\n                )\n                for i in range(depth)\n            ]\n        )\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(\n                input_resolution, dim=dim, norm_layer=norm_layer\n            )\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            x = blk(x)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops\n\n\nclass PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n    \"\"\"\n\n    def __init__(\n        self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None\n    ):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [\n            img_size[0] // patch_size[0],\n            img_size[1] // patch_size[1],\n        ]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(\n            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size\n        )\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert (\n            H == self.img_size[0] and W == self.img_size[1]\n        ), f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C\n        if self.norm is not None:\n            x = self.norm(x)\n        return x\n\n    def flops(self):\n        Ho, Wo = self.patches_resolution\n        flops = (\n            Ho\n            * Wo\n            * self.embed_dim\n            * self.in_chans\n            * (self.patch_size[0] * self.patch_size[1])\n        )\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\nclass SwinTransformer(nn.Module):\n    r\"\"\" Swin Transformer\n        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -\n          https://arxiv.org/pdf/2103.14030\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each Swin Transformer layer.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None\n        drop_rate (float): Dropout rate. Default: 0\n        attn_drop_rate (float): Attention dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False\n    \"\"\"\n\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=4,\n        in_chans=3,\n        num_classes=1000,\n        embed_dim=96,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        window_size=7,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        attn_drop_rate=0.0,\n        drop_path_rate=0.1,\n        norm_layer=nn.LayerNorm,\n        ape=False,\n        patch_norm=True,\n        use_checkpoint=False,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.num_classes = num_classes\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.mlp_ratio = mlp_ratio\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size,\n            patch_size=patch_size,\n            in_chans=in_chans,\n            embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None,\n        )\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(\n                torch.zeros(1, num_patches, embed_dim)\n            )\n            trunc_normal_(self.absolute_pos_embed, std=0.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [\n            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))\n        ]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(\n                dim=int(embed_dim * 2 ** i_layer),\n                input_resolution=(\n                    patches_resolution[0] // (2 ** i_layer),\n                    patches_resolution[1] // (2 ** i_layer),\n                ),\n                depth=depths[i_layer],\n                num_heads=num_heads[i_layer],\n                window_size=window_size,\n                mlp_ratio=self.mlp_ratio,\n                qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                drop=drop_rate,\n                attn_drop=attn_drop_rate,\n                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],\n                norm_layer=norm_layer,\n                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                use_checkpoint=use_checkpoint,\n            )\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features)\n        self.avgpool = nn.AdaptiveAvgPool1d(1)\n        self.head = (\n            nn.Linear(self.num_features, num_classes)\n            if num_classes > 0\n            else nn.Identity()\n        )\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=0.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def no_weight_decay(self):\n        return 
{\"absolute_pos_embed\"}\n\n    def no_weight_decay_keywords(self):\n        return {\"relative_position_bias_table\"}\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for layer in self.layers:\n            x = layer(x)\n\n        x = self.norm(x)  # B L C\n        x = self.avgpool(x.transpose(1, 2))  # B C 1\n        x = torch.flatten(x, 1)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n    def flops(self):\n        flops = 0\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n        flops += (\n            self.num_features\n            * self.patches_resolution[0]\n            * self.patches_resolution[1]\n            // (2 ** self.num_layers)\n        )\n        flops += self.num_features * self.num_classes\n        return flops\n\n\ndef _create_swin_transformer(arch, pretrained=False, progress=True, **model_kwargs):\n    model = SwinTransformer(**model_kwargs)\n    return model\n\n\ndef swin_tiny_patch4_window7_224(pretrained=False, progress=True, **kwargs):\n    \"\"\"\n    Constructs Swin-T 224x224 model trained on ImageNet-1k.\n    .. note::\n        Swin-T 224x224 model from `\"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows\" <https://arxiv.org/pdf/2103.14030>`_.\n    Args:\n        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``\n        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``\n    For example:\n    .. code-block:: python\n        >>> import flowvision\n        >>> swin_tiny_patch4_window7_224 = flowvision.models.swin_tiny_patch4_window7_224(pretrained=False, progress=True)\n    \"\"\"\n    model_kwargs = dict(\n        img_size=224,\n        patch_size=4,\n        window_size=7,\n        embed_dim=96,\n        depths=(2, 2, 6, 2),\n        num_heads=(3, 6, 12, 24),\n        drop_path_rate=0.2,\n        **kwargs,\n    )\n    return _create_swin_transformer(\n        \"swin_tiny_patch4_window7_224\",\n        pretrained=pretrained,\n        progress=progress,\n        **model_kwargs,\n    )\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytorch_uniformer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom collections import OrderedDict\nimport torch\nimport torch.nn as nn\nfrom functools import partial\nfrom timm.models.layers import trunc_normal_, DropPath, to_2tuple\n\nlayer_scale = False\ninit_value = 1e-6\n\n\nclass Mlp(nn.Module):\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        act_layer=nn.GELU,\n        drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass CMlp(nn.Module):\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        act_layer=nn.GELU,\n        drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)\n        self.act = act_layer()\n        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n        self,\n        dim,\n        num_heads=8,\n        qkv_bias=False,\n        qk_scale=None,\n        attn_drop=0.0,\n        proj_drop=0.0,\n    ):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        qkv = (\n            self.qkv(x)\n            .reshape(B, N, 3, self.num_heads, C // self.num_heads)\n            .permute(2, 0, 3, 1, 4)\n        )\n        q, k, v = (\n            qkv[0],\n            qkv[1],\n            qkv[2],\n        )  # make torchscript happy (cannot use tensor as tuple)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return 
x\n\n\nclass CBlock(nn.Module):\n    def __init__(\n        self,\n        dim,\n        num_heads,\n        mlp_ratio=4.0,\n        qkv_bias=False,\n        qk_scale=None,\n        drop=0.0,\n        attn_drop=0.0,\n        drop_path=0.0,\n        act_layer=nn.GELU,\n        norm_layer=nn.LayerNorm,\n    ):\n        super().__init__()\n        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)\n        self.norm1 = nn.BatchNorm2d(dim)\n        self.conv1 = nn.Conv2d(dim, dim, 1)\n        self.conv2 = nn.Conv2d(dim, dim, 1)\n        self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.norm2 = nn.BatchNorm2d(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = CMlp(\n            in_features=dim,\n            hidden_features=mlp_hidden_dim,\n            act_layer=act_layer,\n            drop=drop,\n        )\n\n    def forward(self, x):\n        x = x + self.pos_embed(x)\n        x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x)))))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n\n\nclass SABlock(nn.Module):\n    def __init__(\n        self,\n        dim,\n        num_heads,\n        mlp_ratio=4.0,\n        qkv_bias=False,\n        qk_scale=None,\n        drop=0.0,\n        attn_drop=0.0,\n        drop_path=0.0,\n        act_layer=nn.GELU,\n        norm_layer=nn.LayerNorm,\n    ):\n        super().__init__()\n        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim,\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            qk_scale=qk_scale,\n            attn_drop=attn_drop,\n            proj_drop=drop,\n        )\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(\n            in_features=dim,\n            hidden_features=mlp_hidden_dim,\n            act_layer=act_layer,\n            drop=drop,\n        )\n        global layer_scale\n        self.ls = layer_scale\n        if self.ls:\n            global init_value\n            print(f\"Use layer_scale: {layer_scale}, init_values: {init_value}\")\n            self.gamma_1 = nn.Parameter(\n                init_value * torch.ones((dim)), requires_grad=True\n            )\n            self.gamma_2 = nn.Parameter(\n                init_value * torch.ones((dim)), requires_grad=True\n            )\n\n    def forward(self, x):\n        x = x + self.pos_embed(x)\n        B, N, H, W = x.shape\n        x = x.flatten(2).transpose(1, 2)\n        if self.ls:\n            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))\n            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n        else:\n            x = x + self.drop_path(self.attn(self.norm1(x)))\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n        x = x.transpose(1, 2).reshape(B, N, H, W)\n        return x\n\n\nclass head_embedding(nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super(head_embedding, self).__init__()\n\n        self.proj = nn.Sequential(\n            nn.Conv2d(\n                
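# NOTE (comment added for clarity, not in the original): this two-stage conv stem\n                # downsamples 4x overall (two stride-2 3x3 convs with BatchNorm + GELU in\n                # between), matching the stride of the patch_size=4 PatchEmbed stem used\n                # when conv_stem=False.\n                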
in_channels,\n                out_channels // 2,\n                kernel_size=(3, 3),\n                stride=(2, 2),\n                padding=(1, 1),\n            ),\n            nn.BatchNorm2d(out_channels // 2),\n            nn.GELU(),\n            nn.Conv2d(\n                out_channels // 2,\n                out_channels,\n                kernel_size=(3, 3),\n                stride=(2, 2),\n                padding=(1, 1),\n            ),\n            nn.BatchNorm2d(out_channels),\n        )\n\n    def forward(self, x):\n        x = self.proj(x)\n        return x\n\n\nclass middle_embedding(nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super(middle_embedding, self).__init__()\n\n        self.proj = nn.Sequential(\n            nn.Conv2d(\n                in_channels,\n                out_channels,\n                kernel_size=(3, 3),\n                stride=(2, 2),\n                padding=(1, 1),\n            ),\n            nn.BatchNorm2d(out_channels),\n        )\n\n    def forward(self, x):\n        x = self.proj(x)\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n        self.norm = nn.LayerNorm(embed_dim)\n        self.proj = nn.Conv2d(\n            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size\n        )\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert (\n            H == self.img_size[0] and W == self.img_size[1]\n        ), f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x)\n        B, C, H, W = x.shape\n        x = x.flatten(2).transpose(1, 2)\n        x = self.norm(x)\n        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n        return x\n\n\nclass UniFormer(nn.Module):\n    \"\"\" UniFormer\n    A PyTorch impl of the UniFormer image-classification backbone, which stacks\n    convolutional (CBlock) stages followed by self-attention (SABlock) stages.\n    \"\"\"\n\n    def __init__(\n        self,\n        depth=[3, 4, 8, 3],\n        img_size=224,\n        in_chans=3,\n        num_classes=1000,\n        embed_dim=[64, 128, 320, 512],\n        head_dim=64,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        qk_scale=None,\n        representation_size=None,\n        drop_rate=0.0,\n        attn_drop_rate=0.0,\n        drop_path_rate=0.0,\n        norm_layer=None,\n        conv_stem=False,\n    ):\n        \"\"\"\n        Args:\n            depth (list): depth of each stage\n            img_size (int, tuple): input image size\n            in_chans (int): number of input channels\n            num_classes (int): number of classes for classification head\n            embed_dim (list): embedding dimension of each stage\n            head_dim (int): head dimension\n            mlp_ratio (int): ratio of mlp hidden dim to embedding dim\n            qkv_bias (bool): enable bias for qkv if True\n            qk_scale (float): override default qk scale of head_dim ** -0.5 if set\n            representation_size (Optional[int]): enable and set 
representation layer (pre-logits) to this value if set\n            drop_rate (float): dropout rate\n            attn_drop_rate (float): attention dropout rate\n            drop_path_rate (float): stochastic depth rate\n            norm_layer (nn.Module): normalization layer\n            conv_stem (bool): whether use overlapped patch stem\n        \"\"\"\n        super().__init__()\n        self.num_classes = num_classes\n        self.num_features = (\n            self.embed_dim\n        ) = embed_dim  # num_features for consistency with other models\n        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)\n        if conv_stem:\n            self.patch_embed1 = head_embedding(\n                in_channels=in_chans, out_channels=embed_dim[0]\n            )\n            self.patch_embed2 = middle_embedding(\n                in_channels=embed_dim[0], out_channels=embed_dim[1]\n            )\n            self.patch_embed3 = middle_embedding(\n                in_channels=embed_dim[1], out_channels=embed_dim[2]\n            )\n            self.patch_embed4 = middle_embedding(\n                in_channels=embed_dim[2], out_channels=embed_dim[3]\n            )\n        else:\n            self.patch_embed1 = PatchEmbed(\n                img_size=img_size,\n                patch_size=4,\n                in_chans=in_chans,\n                embed_dim=embed_dim[0],\n            )\n            self.patch_embed2 = PatchEmbed(\n                img_size=img_size // 4,\n                patch_size=2,\n                in_chans=embed_dim[0],\n                embed_dim=embed_dim[1],\n            )\n            self.patch_embed3 = PatchEmbed(\n                img_size=img_size // 8,\n                patch_size=2,\n                in_chans=embed_dim[1],\n                embed_dim=embed_dim[2],\n            )\n            self.patch_embed4 = PatchEmbed(\n                img_size=img_size // 16,\n                patch_size=2,\n                in_chans=embed_dim[2],\n                embed_dim=embed_dim[3],\n            )\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n        dpr = [\n            x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))\n        ]  # stochastic depth decay rule\n        num_heads = [dim // head_dim for dim in embed_dim]\n        self.blocks1 = nn.ModuleList(\n            [\n                CBlock(\n                    dim=embed_dim[0],\n                    num_heads=num_heads[0],\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop_rate,\n                    attn_drop=attn_drop_rate,\n                    drop_path=dpr[i],\n                    norm_layer=norm_layer,\n                )\n                for i in range(depth[0])\n            ]\n        )\n        self.blocks2 = nn.ModuleList(\n            [\n                CBlock(\n                    dim=embed_dim[1],\n                    num_heads=num_heads[1],\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop_rate,\n                    attn_drop=attn_drop_rate,\n                    drop_path=dpr[i + depth[0]],\n                    norm_layer=norm_layer,\n                )\n                for i in range(depth[1])\n            ]\n        )\n        self.blocks3 = nn.ModuleList(\n            [\n                SABlock(\n                    dim=embed_dim[2],\n                    num_heads=num_heads[2],\n             
       mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop_rate,\n                    attn_drop=attn_drop_rate,\n                    drop_path=dpr[i + depth[0] + depth[1]],\n                    norm_layer=norm_layer,\n                )\n                for i in range(depth[2])\n            ]\n        )\n        self.blocks4 = nn.ModuleList(\n            [\n                SABlock(\n                    dim=embed_dim[3],\n                    num_heads=num_heads[3],\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop_rate,\n                    attn_drop=attn_drop_rate,\n                    drop_path=dpr[i + depth[0] + depth[1] + depth[2]],\n                    norm_layer=norm_layer,\n                )\n                for i in range(depth[3])\n            ]\n        )\n        self.norm = nn.BatchNorm2d(embed_dim[-1])\n\n        # Representation layer\n        if representation_size:\n            self.num_features = representation_size\n            self.pre_logits = nn.Sequential(\n                OrderedDict(\n                    [\n                        (\"fc\", nn.Linear(embed_dim[-1], representation_size)),\n                        (\"act\", nn.Tanh()),\n                    ]\n                )\n            )\n        else:\n            self.pre_logits = nn.Identity()\n\n        # Classifier head\n        self.head = (\n            nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()\n        )\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=0.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def no_weight_decay(self):\n        return {\"pos_embed\", \"cls_token\"}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=\"\"):\n        self.num_classes = num_classes\n        # embed_dim is a per-stage list, so index the last stage for the head width\n        self.head = (\n            nn.Linear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()\n        )\n\n    def forward_features(self, x):\n        x = self.patch_embed1(x)\n        x = self.pos_drop(x)\n        for blk in self.blocks1:\n            x = blk(x)\n        x = self.patch_embed2(x)\n        for blk in self.blocks2:\n            x = blk(x)\n        x = self.patch_embed3(x)\n        for blk in self.blocks3:\n            x = blk(x)\n        x = self.patch_embed4(x)\n        for blk in self.blocks4:\n            x = blk(x)\n        x = self.norm(x)\n        x = self.pre_logits(x)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = x.flatten(2).mean(-1)\n        x = self.head(x)\n        return x\n\n\ndef uniformer_small(pretrained=True, **kwargs):\n    model = UniFormer(\n        depth=[3, 4, 8, 3],\n        embed_dim=[64, 128, 320, 512],\n        head_dim=64,\n        mlp_ratio=4,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        **kwargs,\n    )\n    return model\n"
  },
  {
    "path": "python/oneflow/test/expensive/pytroch_mlp_mixer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport math\nimport torch\nimport torch.nn as nn\nfrom timm.models.layers import DropPath, lecun_normal_, to_2tuple\n\nfrom functools import partial\nfrom typing import Callable\n\n\nclass Mlp(nn.Module):\n    \"\"\" MLP as used in Vision Transformer, MLP-Mixer and related networks\n    \"\"\"\n\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        act_layer=nn.GELU,\n        drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        drop_probs = to_2tuple(drop)\n\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.drop1 = nn.Dropout(drop_probs[0])\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop2 = nn.Dropout(drop_probs[1])\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop1(x)\n        x = self.fc2(x)\n        x = self.drop2(x)\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" 2D Image to Patch Embedding\n    \"\"\"\n\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=16,\n        in_chans=3,\n        embed_dim=768,\n        norm_layer=None,\n        flatten=True,\n    ):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n        self.num_patches = self.grid_size[0] * self.grid_size[1]\n        self.flatten = flatten\n\n        self.proj = nn.Conv2d(\n            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size\n        )\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        assert (\n            H == self.img_size[0],\n            f\"Input image height ({H}) doesn't match model ({self.img_size[0]}).\",\n        )\n        assert (\n            W == self.img_size[1],\n            f\"Input image width ({W}) doesn't match model ({self.img_size[1]}).\",\n        )\n        x = self.proj(x)\n        if self.flatten:\n            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC\n        x = self.norm(x)\n        return x\n\n\ndef named_apply(\n    fn: Callable, module: nn.Module, name=\"\", depth_first=True, include_root=False\n) -> nn.Module:\n    if not depth_first and include_root:\n        fn(module=module, name=name)\n    for child_name, child_module in module.named_children():\n        child_name = \".\".join((name, child_name)) if name else child_name\n        named_apply(\n            fn=fn,\n            module=child_module,\n            name=child_name,\n            depth_first=depth_first,\n           
 include_root=True,\n        )\n    if depth_first and include_root:\n        fn(module=module, name=name)\n    return module\n\n\nclass GatedMlp(nn.Module):\n    \"\"\" MLP as used in gMLP\n    \"\"\"\n\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        act_layer=nn.GELU,\n        gate_layer=None,\n        drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        drop_probs = to_2tuple(drop)\n\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.drop1 = nn.Dropout(drop_probs[0])\n        if gate_layer is not None:\n            assert hidden_features % 2 == 0\n            self.gate = gate_layer(hidden_features)\n            hidden_features = (\n                hidden_features // 2\n            )  # FIXME base reduction on gate property?\n        else:\n            self.gate = nn.Identity()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop2 = nn.Dropout(drop_probs[1])\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop1(x)\n        x = self.gate(x)\n        x = self.fc2(x)\n        x = self.drop2(x)\n        return x\n\n\nclass MixerBlock(nn.Module):\n    \"\"\" Residual Block w/ token mixing and channel MLPs\n    Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        seq_len,\n        mlp_ratio=(0.5, 4.0),\n        mlp_layer=Mlp,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        act_layer=nn.GELU,\n        drop=0.0,\n        drop_path=0.0,\n    ):\n        super().__init__()\n        tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]\n        self.norm1 = norm_layer(dim)\n        self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop)\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop)\n\n    def forward(self, x):\n        x = x + self.drop_path(\n            self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2)\n        )\n        x = x + self.drop_path(self.mlp_channels(self.norm2(x)))\n        return x\n\n\nclass Affine(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.alpha = nn.Parameter(torch.ones((1, 1, dim)))\n        self.beta = nn.Parameter(torch.zeros((1, 1, dim)))\n\n    def forward(self, x):\n        return torch.addcmul(self.beta, self.alpha, x)\n\n\nclass ResBlock(nn.Module):\n    \"\"\" Residual MLP block w/ LayerScale and Affine 'norm'\n    Based on: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        seq_len,\n        mlp_ratio=4,\n        mlp_layer=Mlp,\n        norm_layer=Affine,\n        act_layer=nn.GELU,\n        init_values=1e-4,\n        drop=0.0,\n        drop_path=0.0,\n    ):\n        super().__init__()\n        channel_dim = int(dim * mlp_ratio)\n        self.norm1 = norm_layer(dim)\n        self.linear_tokens = nn.Linear(seq_len, seq_len)\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        
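# NOTE (comment added for clarity, not in the original): ls1/ls2 below are\n        # ResMLP's per-channel LayerScale parameters; the small init_values default\n        # (1e-4) keeps both residual branches close to identity early in training.\n        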
self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, drop=drop)\n        self.ls1 = nn.Parameter(init_values * torch.ones(dim))\n        self.ls2 = nn.Parameter(init_values * torch.ones(dim))\n\n    def forward(self, x):\n        x = x + self.drop_path(\n            self.ls1 * self.linear_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2)\n        )\n        x = x + self.drop_path(self.ls2 * self.mlp_channels(self.norm2(x)))\n        return x\n\n\nclass SpatialGatingUnit(nn.Module):\n    \"\"\" Spatial Gating Unit\n    Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050\n    \"\"\"\n\n    def __init__(self, dim, seq_len, norm_layer=nn.LayerNorm):\n        super().__init__()\n        gate_dim = dim // 2\n        self.norm = norm_layer(gate_dim)\n        self.proj = nn.Linear(seq_len, seq_len)\n\n    def init_weights(self):\n        # special init for the projection gate, called as override by base model init\n        nn.init.normal_(self.proj.weight, std=1e-6)\n        nn.init.ones_(self.proj.bias)\n\n    def forward(self, x):\n        u, v = x.chunk(2, dim=-1)\n        v = self.norm(v)\n        v = self.proj(v.transpose(-1, -2))\n        return u * v.transpose(-1, -2)\n\n\nclass SpatialGatingBlock(nn.Module):\n    \"\"\" Residual Block w/ Spatial Gating\n    Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        seq_len,\n        mlp_ratio=4,\n        mlp_layer=GatedMlp,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        act_layer=nn.GELU,\n        drop=0.0,\n        drop_path=0.0,\n    ):\n        super().__init__()\n        channel_dim = int(dim * mlp_ratio)\n        self.norm = norm_layer(dim)\n        sgu = partial(SpatialGatingUnit, seq_len=seq_len)\n        self.mlp_channels = mlp_layer(\n            dim, channel_dim, act_layer=act_layer, gate_layer=sgu, drop=drop\n        )\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n\n    def forward(self, x):\n        x = x + self.drop_path(self.mlp_channels(self.norm(x)))\n        return x\n\n\nclass MlpMixer(nn.Module):\n    def __init__(\n        self,\n        num_classes=1000,\n        img_size=224,\n        in_chans=3,\n        patch_size=16,\n        num_blocks=8,\n        embed_dim=512,\n        mlp_ratio=(0.5, 4.0),\n        block_layer=MixerBlock,\n        mlp_layer=Mlp,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        act_layer=nn.GELU,\n        drop_rate=0.0,\n        drop_path_rate=0.0,\n        nlhb=False,\n        stem_norm=False,\n        global_pool=\"avg\",\n    ):\n        super().__init__()\n        self.num_classes = num_classes\n        self.global_pool = global_pool\n        self.num_features = (\n            self.embed_dim\n        ) = embed_dim  # num_features for consistency with other models\n        self.grad_checkpointing = False\n\n        self.stem = PatchEmbed(\n            img_size=img_size,\n            patch_size=patch_size,\n            in_chans=in_chans,\n            embed_dim=embed_dim,\n            norm_layer=norm_layer if stem_norm else None,\n        )\n        # FIXME drop_path (stochastic depth scaling rule or all the same?)\n        self.blocks = nn.Sequential(\n            *[\n                block_layer(\n                    embed_dim,\n                    self.stem.num_patches,\n                    mlp_ratio,\n                    mlp_layer=mlp_layer,\n                    norm_layer=norm_layer,\n                    
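# NOTE (comment added for clarity, not in the original): MixerBlock treats\n                    # mlp_ratio as a (token, channel) tuple of hidden-dim ratios, while\n                    # ResBlock and SpatialGatingBlock expect a scalar ratio.\n                    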
act_layer=act_layer,\n                    drop=drop_rate,\n                    drop_path=drop_path_rate,\n                )\n                for _ in range(num_blocks)\n            ]\n        )\n        self.norm = norm_layer(embed_dim)\n        self.head = (\n            nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()\n        )\n\n        self.init_weights(nlhb=nlhb)\n\n    def init_weights(self, nlhb=False):\n        head_bias = -math.log(self.num_classes) if nlhb else 0.0\n        named_apply(\n            partial(_init_weights, head_bias=head_bias), module=self\n        )  # depth-first\n\n    def group_matcher(self, coarse=False):\n        return dict(\n            stem=r\"^stem\",  # stem and embed\n            blocks=[(r\"^blocks\\.(\\d+)\", None), (r\"^norm\", (99999,))],\n        )\n\n    def set_grad_checkpointing(self, enable=True):\n        self.grad_checkpointing = enable\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=None):\n        self.num_classes = num_classes\n        if global_pool is not None:\n            assert global_pool in (\"\", \"avg\")\n            self.global_pool = global_pool\n        self.head = (\n            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n        )\n\n    def forward_features(self, x):\n        x = self.stem(x)\n        x = self.blocks(x)\n        x = self.norm(x)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        if self.global_pool == \"avg\":\n            x = x.mean(dim=1)\n        x = self.head(x)\n        return x\n\n\ndef _init_weights(module: nn.Module, name: str, head_bias: float = 0.0, flax=False):\n    \"\"\" Mixer weight initialization (trying to match Flax defaults)\n    \"\"\"\n    if isinstance(module, nn.Linear):\n        if name.startswith(\"head\"):\n            nn.init.zeros_(module.weight)\n            nn.init.constant_(module.bias, head_bias)\n        else:\n            if flax:\n                # Flax defaults\n                lecun_normal_(module.weight)\n                if module.bias is not None:\n                    nn.init.zeros_(module.bias)\n            else:\n                # like MLP init in vit (my original init)\n                nn.init.xavier_uniform_(module.weight)\n                if module.bias is not None:\n                    if \"mlp\" in name:\n                        nn.init.normal_(module.bias, std=1e-6)\n                    else:\n                        nn.init.zeros_(module.bias)\n    elif isinstance(module, nn.Conv2d):\n        lecun_normal_(module.weight)\n        if module.bias is not None:\n            nn.init.zeros_(module.bias)\n    elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):\n        nn.init.ones_(module.weight)\n        nn.init.zeros_(module.bias)\n    elif hasattr(module, \"init_weights\"):\n        # NOTE if a parent module contains init_weights method, it can override the init of the\n        # child modules as this will be called in depth-first order.\n        module.init_weights()\n\n\ndef mixer_s32_224(pretrained=False, **kwargs):\n    \"\"\" Mixer-S/32 224x224\n    Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601\n    \"\"\"\n    model_args = dict(patch_size=32, num_blocks=8, embed_dim=512, **kwargs)\n    model = MlpMixer(**model_args)\n    return model\n"
  },
  {
    "path": "python/oneflow/test/expensive/resnet50_model.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Any, Callable, List, Optional, Type, Union\n\nimport oneflow as flow\nimport oneflow.nn as nn\nfrom oneflow import Tensor\n\n\nclass FakeBN(nn.Module):\n    \"\"\"Common base of _InstanceNorm and _BatchNorm\"\"\"\n\n    def __init__(\n        self,\n        num_features: int,\n        eps: float = 1e-05,\n        momentum: float = 0.1,\n        affine: bool = True,\n        track_running_stats: bool = True,\n    ) -> None:\n        super().__init__()\n        self.num_features = num_features\n        self.eps = eps\n        self.momentum = momentum\n        self.affine = affine\n        self.track_running_stats = track_running_stats\n        if self.affine:\n            self.weight = flow.nn.Parameter(flow.Tensor(num_features))\n            self.bias = flow.nn.Parameter(flow.Tensor(num_features))\n        else:\n            self.register_parameter(\"weight\", None)\n            self.register_parameter(\"bias\", None)\n        if self.track_running_stats:\n            self.register_buffer(\"running_mean\", flow.Tensor(num_features))\n            self.register_buffer(\"running_var\", flow.Tensor(num_features))\n        else:\n            self.register_parameter(\"running_mean\", None)\n            self.register_parameter(\"running_var\", None)\n\n    def forward(self, input):\n        return flow._C.identity(input)\n\n\ndef conv3x3(\n    in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1\n) -> nn.Conv2d:\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(\n        in_planes,\n        out_planes,\n        kernel_size=3,\n        stride=stride,\n        padding=dilation,\n        groups=groups,\n        bias=False,\n        dilation=dilation,\n    )\n\n\ndef conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion: int = 1\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        if groups != 1 or base_width != 64:\n            raise ValueError(\"BasicBlock only supports groups=1 and base_width=64\")\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.relu = nn.ReLU()\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = norm_layer(planes)\n        self.downsample = 
downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n        out = self.conv2(out)\n        out = self.bn2(out)\n        if self.downsample is not None:\n            identity = self.downsample(x)\n        out += identity\n        out = self.relu(out)\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion: int = 4\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super(Bottleneck, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.0)) * groups\n        self.conv1 = conv1x1(inplanes, width)\n        self.bn1 = norm_layer(width)\n        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n        self.bn2 = norm_layer(width)\n        self.conv3 = conv1x1(width, planes * self.expansion)\n        self.bn3 = norm_layer(planes * self.expansion)\n        self.relu = nn.ReLU()\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n        out = self.conv3(out)\n        out = self.bn3(out)\n        if self.downsample is not None:\n            identity = self.downsample(x)\n        out += identity\n        out = self.relu(out)\n        return out\n\n\nclass ResNet(nn.Module):\n    def __init__(\n        self,\n        block: Type[Union[BasicBlock, Bottleneck]],\n        layers: List[int],\n        num_classes: int = 1000,\n        zero_init_residual: bool = False,\n        groups: int = 1,\n        width_per_group: int = 64,\n        replace_stride_with_dilation: Optional[List[bool]] = None,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n    ) -> None:\n        super(ResNet, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\n                \"replace_stride_with_dilation should be None or a 3-element tuple, got {}\".format(\n                    replace_stride_with_dilation\n                )\n            )\n        self.groups = groups\n        self.base_width = width_per_group\n        self.conv1 = nn.Conv2d(\n            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False\n        )\n        self.bn1 = norm_layer(self.inplanes)\n        self.relu = nn.ReLU()\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(\n            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]\n        )\n        self.layer3 = self._make_layer(\n            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]\n        
)\n        self.layer4 = self._make_layer(\n            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]\n        )\n        self.avgpool = nn.AvgPool2d((7, 7))\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)\n                elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)\n\n    def _make_layer(\n        self,\n        block: Type[Union[BasicBlock, Bottleneck]],\n        planes: int,\n        blocks: int,\n        stride: int = 1,\n        dilate: bool = False,\n    ) -> nn.Sequential:\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n        layers = []\n        layers.append(\n            block(\n                self.inplanes,\n                planes,\n                stride,\n                downsample,\n                self.groups,\n                self.base_width,\n                previous_dilation,\n                norm_layer,\n            )\n        )\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(\n                block(\n                    self.inplanes,\n                    planes,\n                    groups=self.groups,\n                    base_width=self.base_width,\n                    dilation=self.dilation,\n                    norm_layer=norm_layer,\n                )\n            )\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, x: Tensor) -> Tensor:\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n        x = self.avgpool(x)\n        x = flow.flatten(x, 1)\n        x = self.fc(x)\n        return x\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self._forward_impl(x)\n\n\ndef _resnet(\n    arch: str,\n    block: Type[Union[BasicBlock, Bottleneck]],\n    layers: List[int],\n    **kwargs: Any\n) -> ResNet:\n    model = ResNet(block, layers, **kwargs)\n    return model\n\n\ndef resnet50(**kwargs: Any) -> ResNet:\n    \"\"\"ResNet-50\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n    \"\"\"\n    return _resnet(\"resnet50\", Bottleneck, [3, 4, 6, 3], **kwargs)\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_compatibility.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.test_utils.oneflow_pytorch_compatibility import *\nimport os\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test gpu cases\")\nclass TestApiCompatibility(flow.unittest.TestCase):\n    def test_alexnet_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, \"pytorch_alexnet.py\", \"alexnet\", \"cuda\", 16, 224\n        )\n\n    def test_resnet50_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, \"pytorch_resnet.py\", \"resnet50\", \"cuda\", 16, 224\n        )\n\n    @unittest.skipIf(\n        os.environ[\"ONEFLOW_CI\"] == \"1\",\n        \"always get error: 'Check failed: cudnnConvolutionBackwardFilter'\",\n    )\n    def test_convmixer_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, \"pytorch_convmixer.py\", \"convmixer_768_32_relu\", \"cuda\", 4, 224\n        )\n\n    def test_densenet_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, \"pytorch_densenet.py\", \"densenet121\", \"cuda\", 8, 224\n        )\n\n    def test_ghostnet_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, \"pytorch_ghostnet.py\", \"ghost_net\", \"cuda\", 16, 224\n        )\n\n    def test_googlenet_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, \"pytorch_googlenet.py\", \"googlenet\", \"cuda\", 8, 224\n        )\n\n    def test_inception_v3_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, \"pytorch_inception_v3.py\", \"inception_v3\", \"cuda\", 4, 299\n        )\n\n    def test_mnasnet_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, \"pytorch_mnasnet.py\", \"mnasnet1_0\", \"cuda\", 16, 224\n        )\n\n    # def test_rexnet_compatibility(test_case):\n    #     do_test_train_loss_oneflow_pytorch(\n    #         test_case, \"pytorch_rexnet.py\", \"rexnetv1_1_0\", \"cuda\", 16, 224\n    #     )\n\n    # TODO(): support non-contiguous inplace add\n    # def test_rexnetv1_lite_compatibility(test_case):\n    #     do_test_train_loss_oneflow_pytorch(\n    #         test_case, \"pytorch_rexnetv1_lite.py\", \"rexnet_lite_1_0\", \"cuda\", 16, 224\n    #     )\n\n    # def test_res2net_compatibility(test_case):\n    #     do_test_train_loss_oneflow_pytorch(\n    #         test_case, \"pytorch_res2net.py\", \"res2net50\", \"cuda\", 16, 224\n    #     )\n\n    def test_shufflenetv2_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, \"pytorch_shufflenetv2.py\", \"shufflenet_v2_x2_0\", \"cuda\", 16, 224\n        )\n\n    def test_squeezenet_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, 
\"pytorch_squeezenet.py\", \"squeezenet1_1\", \"cuda\", 16, 224\n        )\n\n    @unittest.skipIf(\n        os.environ[\"ONEFLOW_CI\"] == \"1\",\n        \"always get error: 'Check failed: cudnnConvolutionBackwardFilter'\",\n    )\n    def test_convnext_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, \"pytorch_convnext.py\", \"convnext_tiny\", \"cuda\", 8, 224\n        )\n\n    # def test_crossformer_compatibility(test_case):\n    #     do_test_train_loss_oneflow_pytorch(\n    #         test_case,\n    #         \"pytorch_crossformer.py\",\n    #         \"crossformer_tiny_patch4_group7_224\",\n    #         \"cuda\",\n    #         8,\n    #         224,\n    #     )\n\n    # def test_efficientnet_compatibility(test_case):\n    #     do_test_train_loss_oneflow_pytorch(\n    #         test_case, \"pytorch_efficientnet.py\", \"efficientnet_b0\", \"cuda\", 8, 224,\n    #     )\n\n    def test_levit_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, \"pytorch_levit.py\", \"LeViT_128S\", \"cuda\", 8, 224,\n        )\n\n    # def test_mlp_mixer_compatibility(test_case):\n    #     do_test_train_loss_oneflow_pytorch(\n    #         test_case, \"pytroch_mlp_mixer.py\", \"mixer_s32_224\", \"cuda\", 8, 224,\n    #     )\n\n    def test_poolformer_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, \"pytorch_poolformer.py\", \"poolformer_s12\", \"cuda\", 8, 224,\n        )\n\n    def test_pvt_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, \"pytorch_pvt.py\", \"pvt_tiny\", \"cuda\", 8, 224,\n        )\n\n    def test_resmlp_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, \"pytorch_resmlp.py\", \"resmlp_12\", \"cuda\", 8, 224,\n        )\n\n    def test_uniformer_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, \"pytorch_uniformer.py\", \"uniformer_small\", \"cuda\", 8, 224,\n        )\n\n    # TODO(): support non-contiguous inplace add\n    # def test_swin_transformer_compatibility(test_case):\n    #     do_test_train_loss_oneflow_pytorch(\n    #         test_case,\n    #         \"pytorch_swin_transformer.py\",\n    #         \"swin_tiny_patch4_window7_224\",\n    #         \"cuda\",\n    #         8,\n    #         224,\n    #     )\n\n    def test_senet_compatibility(test_case):\n        do_test_train_loss_oneflow_pytorch(\n            test_case, \"pytorch_senet.py\", \"senet154\", \"cuda\", 2, 224,\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_conv3d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestConv3DModule(flow.unittest.TestCase):\n    @autotest(n=3)\n    def test_nn_functional_conv3d(test_case):\n        flow.backends.cuda.matmul.allow_tf32 = True\n        device = random_device()\n        img = torch.ones((1, 3, 16, 16, 16), requires_grad=True).to(device)\n        kernel = torch.ones((6, 3, 3, 3, 3), requires_grad=True).to(device)\n        y = torch.nn.functional.conv3d(img, kernel)\n        return y\n\n    @autotest(n=10, rtol=1e-3, atol=1e-4)\n    def test_conv3d_with_random_data(test_case):\n        flow.backends.cuda.matmul.allow_tf32 = True\n        channels = random(1, 6)\n        m = torch.nn.Conv3d(\n            in_channels=channels,\n            out_channels=random(1, 6),\n            kernel_size=random(1, 3),\n            stride=random() | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(1, 5) | nothing(),\n            groups=random(1, 5) | nothing(),\n            padding_mode=constant(\"zeros\") | nothing(),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=5, dim0=2, dim1=channels).to(device)\n        y = m(x)\n        return y\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @autotest(n=5, check_allclose=False, rtol=1e-3)\n    def test_conv3d_group_with_random_data(test_case):\n        flow.backends.cuda.matmul.allow_tf32 = True\n        channels = 720  # lcm(1, 2, 3, 4, 5, 6)\n        m = torch.nn.Conv3d(\n            in_channels=channels,\n            out_channels=channels,\n            kernel_size=random(1, 4),\n            stride=random() | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(1, 5) | nothing(),\n            groups=random(1, 7),\n            padding_mode=constant(\"zeros\") | nothing(),\n        )\n        m.train(random())\n\n        device = random_device()\n        m.to(device)\n        m.pytorch.to(\"cuda\")\n        x = random_tensor(ndim=5, dim1=channels).to(device)\n        x.pytorch = x.pytorch.to(\"cuda\")\n        y = m(x)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_convtranspose.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\n\n\ndef _test_convtranspose1d_bias_false(test_case, device):\n    np_arr = np.array([[[0.35356437, -0.95761778, 0.19567713]]])\n    weight = np.ones((1, 2, 3))\n    test_out_data = np.array(\n        [\n            [\n                [0.35356438, -0.6040534, -0.40837622, -0.7619406, 0.19567713],\n                [0.35356438, -0.6040534, -0.40837622, -0.7619406, 0.19567713],\n            ]\n        ]\n    )\n    test_out_grad = np.array([[[6.0, 6.0, 6.0]]])\n    input_flow = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    m_f = nn.ConvTranspose1d(1, 2, 3, stride=1, bias=False)\n    m_f.weight.data = flow.tensor(weight, dtype=flow.float32)\n    m_f = m_f.to(device)\n    out_flow = m_f(input_flow)\n    test_case.assertTrue(np.allclose(out_flow.numpy(), test_out_data, 1e-03, 1e-05))\n\n    out_flow = out_flow.sum()\n    out_flow.backward()\n    test_case.assertTrue(\n        np.allclose(input_flow.grad.numpy(), test_out_grad, 1e-06, 1e-06)\n    )\n\n\ndef _test_convtranspose1d_bias_true(test_case, device):\n    np_arr = np.array([[[0.54925832, -0.64144184, 0.15213189]]])\n    weight = np.ones((1, 2, 3))\n    bias = np.array([0.16849578, 0.1509564])\n    test_out_data = np.array(\n        [\n            [\n                [0.71775407, 0.07631224, 0.22844413, -0.32081416, 0.32062766],\n                [0.7002147, 0.05877288, 0.21090476, -0.3383535, 0.3030883],\n            ]\n        ]\n    )\n    test_out_grad = np.array([[[6.0, 6.0, 6.0]]])\n\n    input_flow = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    m_f = nn.ConvTranspose1d(1, 2, 3, stride=1, bias=True)\n    m_f.weight.data = flow.tensor(weight, dtype=flow.float32)\n    m_f.bias = nn.Parameter(flow.Tensor(bias))\n    m_f = m_f.to(device)\n    out_flow = m_f(input_flow)\n    test_case.assertTrue(np.allclose(out_flow.numpy(), test_out_data, 1e-02, 1e-05))\n    out_flow = out_flow.sum()\n    out_flow.backward()\n    test_case.assertTrue(\n        np.allclose(input_flow.grad.numpy(), test_out_grad, 1e-06, 1e-06)\n    )\n\n\ndef _test_convtranspose1d_group_bias_false(test_case, device):\n    np_arr = np.array(\n        [[[0.38072484, -0.01421228, -0.6512485], [-0.05744093, 2.47079971, 0.17573214]]]\n    )\n    weight = np.ones((2, 1, 3))\n    test_out_data = np.array(\n        [\n            [\n                [0.38072485, 0.36651257, -0.28473592, -0.66546077, -0.6512485],\n                [-0.05744093, 2.4133587, 2.5890908, 2.6465318, 0.17573214],\n            ]\n        ]\n    )\n    test_out_grad = np.array([[[3.0, 3.0, 3.0], [3.0, 3.0, 
3.0]]])\n    input_flow = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    m_f = nn.ConvTranspose1d(2, 2, 3, stride=1, groups=2, bias=False)\n    m_f.weight.data = flow.tensor(weight, dtype=flow.float32)\n    m_f = m_f.to(device)\n    out_flow = m_f(input_flow)\n    test_case.assertTrue(np.allclose(out_flow.numpy(), test_out_data, 1e-06, 1e-06))\n    out_flow = out_flow.sum()\n    out_flow.backward()\n    test_case.assertTrue(\n        np.allclose(input_flow.grad.numpy(), test_out_grad, 1e-06, 1e-06)\n    )\n\n\ndef _test_convtranspose1d_group_bias_true(test_case, device):\n    np_arr = np.array(\n        [\n            [\n                [-0.77808793, 0.99824008, 0.57340066],\n                [1.46278707, -0.65234252, -1.13087643],\n            ],\n            [\n                [0.76053973, 0.62332447, -1.17157106],\n                [0.60291466, -0.0472167, 0.89986403],\n            ],\n        ]\n    )\n    weight = np.ones((2, 1, 3))\n    bias = np.array([0.32546719, 0.14995032])\n    test_out_data = np.array(\n        [\n            [\n                [-0.45262071, 0.54561937, 1.11902, 1.897108, 0.89886785],\n                [1.6127374, 0.96039486, -0.1704815, -1.6332686, -0.9809261],\n            ],\n            [\n                [1.0860069, 1.7093314, 0.5377604, -0.22277936, -0.8461038],\n                [0.75286496, 0.70564824, 1.6055121, 1.0025976, 1.0498143],\n            ],\n        ]\n    )\n    test_out_grad = np.array(\n        [[[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]], [[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]]]\n    )\n    input_flow = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    m_f = nn.ConvTranspose1d(2, 2, 3, stride=1, groups=2, bias=True)\n    m_f.weight.data = flow.tensor(weight, dtype=flow.float32)\n    m_f.bias = nn.Parameter(flow.Tensor(bias))\n    m_f = m_f.to(device)\n    out_flow = m_f(input_flow)\n    test_case.assertTrue(np.allclose(out_flow.numpy(), test_out_data, 1e-06, 1e-06))\n    out_flow = out_flow.sum()\n    out_flow.backward()\n    test_case.assertTrue(\n        np.allclose(input_flow.grad.numpy(), test_out_grad, 1e-06, 1e-06)\n    )\n\n\ndef _test_convtranspose1d_group_large_out_channel(test_case, device):\n    np_arr = np.array(\n        [\n            [\n                [2.00934643, 1.5782626, -1.59060988],\n                [-1.70463546, 1.30170714, -1.04025804],\n            ],\n            [\n                [0.60327536, 1.26085986, -0.58499662],\n                [-0.48145872, -1.64391469, -0.09332249],\n            ],\n        ]\n    )\n    weight = np.ones((2, 3, 3))\n    test_out_data = np.array(\n        [\n            [\n                [2.0093465, 3.587609, 1.9969991, -0.01234734, -1.5906099],\n                [2.0093465, 3.587609, 1.9969991, -0.01234734, -1.5906099],\n                [2.0093465, 3.587609, 1.9969991, -0.01234734, -1.5906099],\n                [-1.7046355, -0.40292835, -1.4431864, 0.2614491, -1.040258],\n                [-1.7046355, -0.40292835, -1.4431864, 0.2614491, -1.040258],\n                [-1.7046355, -0.40292835, -1.4431864, 0.2614491, -1.040258],\n            ],\n            [\n                [0.60327536, 1.8641353, 1.2791386, 0.6758632, -0.58499664],\n                [0.60327536, 1.8641353, 1.2791386, 0.6758632, -0.58499664],\n                [0.60327536, 1.8641353, 1.2791386, 0.6758632, -0.58499664],\n                [-0.48145872, -2.1253734, -2.2186959, -1.7372372, -0.09332249],\n                
[-0.48145872, -2.1253734, -2.2186959, -1.7372372, -0.09332249],\n                [-0.48145872, -2.1253734, -2.2186959, -1.7372372, -0.09332249],\n            ],\n        ]\n    )\n    test_out_grad = np.array(\n        [[[9.0, 9.0, 9.0], [9.0, 9.0, 9.0]], [[9.0, 9.0, 9.0], [9.0, 9.0, 9.0]]]\n    )\n    input_flow = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    m_f = nn.ConvTranspose1d(2, 6, 3, stride=1, groups=2, bias=False)\n    m_f.weight.data = flow.tensor(weight, dtype=flow.float32)\n    m_f = m_f.to(device)\n    out_flow = m_f(input_flow)\n    test_case.assertTrue(np.allclose(out_flow.numpy(), test_out_data, 1e-06, 1e-06))\n    out_flow = out_flow.sum()\n    out_flow.backward()\n    test_case.assertTrue(\n        np.allclose(input_flow.grad.numpy(), test_out_grad, 1e-06, 1e-06)\n    )\n\n\ndef _test_convtranspose1d_group_large_in_channel(test_case, device):\n    np_arr = np.array(\n        [\n            [\n                [-0.3939792, -0.34989742, 0.15775536],\n                [0.927185, 0.25040535, -1.22738067],\n                [-0.2187831, -0.24346108, -0.07109655],\n                [-1.55353756, -0.37241986, 0.59579139],\n            ],\n            [\n                [-0.01818884, -1.34408642, 1.31260516],\n                [0.52124192, 0.52142919, 1.40499944],\n                [0.7410308, 1.93069512, 0.25694943],\n                [-0.30531658, 0.24990326, -0.9493729],\n            ],\n        ]\n    )\n    weight = np.ones((4, 1, 3))\n    test_out_data = np.array(\n        [\n            [\n                [0.5332058, 0.43371373, -0.6359115, -1.1691173, -1.0696253],\n                [-1.7723207, -2.3882017, -1.8635068, -0.09118611, 0.52469486],\n            ],\n            [\n                [0.50305307, -0.31960416, 2.3980005, 1.8949474, 2.7176046],\n                [0.43571424, 2.6163127, 1.9238893, 1.488175, -0.69242346],\n            ],\n        ]\n    )\n    test_out_grad = np.array(\n        [\n            [[3.0, 3.0, 3.0], [3.0, 3.0, 3.0], [3.0, 3.0, 3.0], [3.0, 3.0, 3.0]],\n            [[3.0, 3.0, 3.0], [3.0, 3.0, 3.0], [3.0, 3.0, 3.0], [3.0, 3.0, 3.0]],\n        ]\n    )\n    input_flow = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    m_f = nn.ConvTranspose1d(4, 2, 3, stride=1, groups=2, bias=False)\n    m_f.weight.data = flow.tensor(weight, dtype=flow.float32)\n    m_f = m_f.to(device)\n    out_flow = m_f(input_flow)\n    test_case.assertTrue(np.allclose(out_flow.numpy(), test_out_data, 1e-06, 1e-06))\n    out_flow = out_flow.sum()\n    out_flow.backward()\n    test_case.assertTrue(\n        np.allclose(input_flow.grad.numpy(), test_out_grad, 1e-06, 1e-06)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestConvTranspose(flow.unittest.TestCase):\n    def test_ConvTranspose1d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_convtranspose1d_bias_false,\n            _test_convtranspose1d_bias_true,\n            _test_convtranspose1d_group_bias_false,\n            _test_convtranspose1d_group_bias_true,\n            _test_convtranspose1d_group_large_out_channel,\n            _test_convtranspose1d_group_large_in_channel,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5, rtol=1e-2)\n    def test_ConvTranspose1d_(test_case):\n        channels = random(1, 6)\n        m = 
torch.nn.ConvTranspose1d(\n            in_channels=channels,\n            out_channels=random(1, 20),\n            kernel_size=random(1, 4),\n            stride=random() | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(1, 5) | nothing(),\n            groups=random(1, 5) | nothing(),\n            padding_mode=constant(\"zeros\") | nothing(),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=3, dim1=channels).to(device)\n        y = m(x)\n        return y\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test gpu cases\")\n    @autotest(n=5)\n    def test_deconv1d_group_with_random_data(test_case):\n        channels = 720  # a multiple of lcm(1, 2, 3, 4, 5, 6) = 60\n        m = torch.nn.ConvTranspose1d(\n            in_channels=channels,\n            out_channels=channels,\n            kernel_size=random(1, 4),\n            stride=random() | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(1, 5) | nothing(),\n            groups=random(1, 7),\n            padding_mode=constant(\"zeros\") | nothing(),\n        )\n        m.train(random())\n\n        device = random_device()\n        m.to(device)\n        m.pytorch.to(\"cuda\")\n        x = random_tensor(ndim=3, dim1=channels).to(device)\n        x.pytorch = x.pytorch.to(\"cuda\")\n        y = m(x)\n        return y\n\n    @autotest(n=5, rtol=1e-2)\n    def test_ConvTranspose3d_(test_case):\n        channels = random(1, 2)\n        m = torch.nn.ConvTranspose3d(\n            in_channels=channels,\n            out_channels=random(1, 2),\n            kernel_size=random(1, 2),\n            stride=random() | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(1, 5) | nothing(),\n            groups=1,\n            padding_mode=constant(\"zeros\") | nothing(),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=5, dim1=channels).to(device)\n        y = m(x)\n        return y\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test gpu cases\")\n    @autotest(n=5)\n    def test_deconv3d_group_with_random_data(test_case):\n        channels = 120  # a multiple of lcm(1, 2, 3, 4, 5) = 60\n        m = torch.nn.ConvTranspose3d(\n            in_channels=channels,\n            out_channels=channels,\n            kernel_size=random(1, 4),\n            stride=random() | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(1, 5) | nothing(),\n            groups=random(1, 6),\n            padding_mode=constant(\"zeros\") | nothing(),\n        )\n        m.train(random())\n\n        device = random_device()\n        m.to(device)\n        m.pytorch.to(\"cuda\")\n        x = random_tensor(ndim=5, dim1=channels).to(device)\n        x.pytorch = x.pytorch.to(\"cuda\")\n        y = m(x)\n        return y\n\n    @autotest(n=3, auto_backward=False)\n    @unittest.skip(\"TODO: functional_conv_transpose might output incorrect result\")\n    def test_functional_conv_transpose1d(test_case):\n        device = random_device()\n        channels = random(1, 6)\n        img = random_tensor(ndim=3, dim1=channels).to(device)\n        kernel = random_tensor(ndim=3, dim0=channels).to(device)\n        y = torch.nn.functional.conv_transpose1d(img, kernel)\n        return y\n\n    @autotest(n=3, auto_backward=False)\n    @unittest.skip(\"TODO: 
functional_conv_transpose might output incorrect result\")\n    def test_functional_conv_transpose2d(test_case):\n        device = random_device()\n        channels = random(1, 6)\n        img = random_tensor(ndim=4, dim1=channels).to(device)\n        kernel = random_tensor(ndim=4, dim0=channels).to(device)\n        y = torch.nn.functional.conv_transpose2d(img, kernel)\n        return y\n\n    @autotest(n=3, auto_backward=False)\n    @unittest.skip(\"TODO: functional_conv_transpose might output incorrect result\")\n    def test_functional_conv_transpose3d(test_case):\n        device = random_device()\n        channels = random(1, 6)\n        img = random_tensor(ndim=5, dim1=channels).to(device)\n        kernel = random_tensor(ndim=5, dim0=channels).to(device)\n        y = torch.nn.functional.conv_transpose3d(img, kernel)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_dynamic_allocation_gradient_shuffle.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\n\n# dynamic memory allocation can't be tested in unittest\nos.environ[\"ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION\"] = \"1\"\nimport unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgDict\nimport numpy as np\nimport oneflow as flow\n\n\ndef round_half_away_from_zero(x):\n    sign = np.sign(x)\n    abs_val = np.abs(x)\n    abs_val += 0.5\n    floor_val = np.floor(abs_val)\n    out = floor_val * sign\n    return out\n\n\ndef _test_embedding_gradient_shuffle(test_case, enable_quantize, fp16, embedding_size):\n    batch_size = 512\n    num_tables = 26\n    ids = np.random.randint(0, 1000, (batch_size, num_tables), dtype=np.int64)\n    enable_quantized_comm = enable_quantize and embedding_size < 1025\n    if enable_quantized_comm:\n        np_tolerance = 0.5\n        os.environ[\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\"] = \"1\"\n        ids = np.arange(batch_size * num_tables, dtype=np.int64)\n        np.random.shuffle(ids)\n    else:\n        if fp16:\n            np_tolerance = 1e-2\n        else:\n            np_tolerance = 1e-3\n        os.environ[\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\"] = \"0\"\n\n    table_ids = (\n        ids % num_tables\n    )  # same id must have same table id, so in this case get table_ids from ids\n    embedding_grad = np.random.uniform(\n        low=-1, high=1, size=(batch_size, num_tables, embedding_size)\n    ).astype(np.float32)\n    ids_tensor = flow.tensor(ids, requires_grad=False).to(\"cuda\")\n    table_ids_tensor = flow.tensor(table_ids.astype(np.int32), requires_grad=False).to(\n        \"cuda\"\n    )\n    embedding_grad_tensor = flow.tensor(embedding_grad, requires_grad=False).to(\"cuda\")\n\n    class TestGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, ids, table_ids, embedding_grad):\n            (\n                num_unique_matrix,\n                inverse_unique_partition_indices,\n                _,\n                cur_rank_unique_ids,\n                _,\n                cur_rank_inverse_indices,\n            ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables, \"test\")\n            if fp16:\n                embedding_grad = flow.cast(embedding_grad, flow.float16)\n            cur_rank_unique_embedding_grad = flow._C.one_embedding_embedding_gradient_shuffle(\n                embedding_grad,\n                num_unique_matrix,\n                cur_rank_inverse_indices,\n                inverse_unique_partition_indices,\n                \"test\",\n            )\n            if fp16:\n                cur_rank_unique_embedding_grad = flow.cast(\n                    cur_rank_unique_embedding_grad, flow.float32\n                )\n            return (\n                cur_rank_unique_embedding_grad,\n                flow.cast(cur_rank_unique_ids, flow.int32),\n                
flow.cast(cur_rank_inverse_indices, flow.int32),\n                flow.cast(inverse_unique_partition_indices, flow.int32),\n            )\n\n    graph = TestGraph()\n    (\n        cur_rank_unique_embedding_grad,\n        cur_rank_unique_ids,\n        cur_rank_inverse_indices,\n        inverse_unique_partition_indices,\n    ) = graph(ids_tensor, table_ids_tensor, embedding_grad_tensor)\n    np_unique_ids, np_inverse = np.unique(ids, return_inverse=True)\n    np_num_unique = np_unique_ids.size\n    np_cur_rank_unique_embedding_grad = np.zeros(\n        cur_rank_unique_embedding_grad.shape, dtype=np.float32\n    ).reshape(-1, embedding_size)\n\n    embedding_grad = embedding_grad.reshape(-1, embedding_size)\n    if fp16:\n        embedding_grad = embedding_grad.astype(np.float16)\n    for k in range(np_num_unique):\n        np_data = sum(embedding_grad[np.where(ids.flatten() == np_unique_ids[k])[0]])\n        # Quantize Embedding Gradient.\n        if enable_quantized_comm:\n            abs_max_factor = np.max(np.abs(np_data))\n            int8_factor = np.full(abs_max_factor.shape, 127.0, dtype=np.float32)\n            quantize_factor = int8_factor / abs_max_factor\n            np_data = np_data * quantize_factor\n            np_data = round_half_away_from_zero(np_data)\n            np_data = np_data.astype(np.int8)\n            np_data = np_data.astype(np.float32)\n            dequantize_factor = abs_max_factor / int8_factor\n            np_data = np_data * dequantize_factor\n\n        np_cur_rank_unique_embedding_grad[k, :] = np_data\n\n    reversed_ids = cur_rank_unique_ids[cur_rank_inverse_indices][\n        inverse_unique_partition_indices\n    ]\n    test_case.assertTrue(np.array_equal(reversed_ids.numpy(), ids))\n    of_cur_rank_embedding_grad = cur_rank_unique_embedding_grad[\n        cur_rank_inverse_indices\n    ][inverse_unique_partition_indices]\n    of_cur_rank_embedding_grad = flow.reshape(\n        of_cur_rank_embedding_grad, (-1, embedding_size)\n    )\n    np_cur_rank_embedding_grad = np_cur_rank_unique_embedding_grad[np_inverse]\n    if fp16:\n        np_cur_rank_embedding_grad = np_cur_rank_embedding_grad.astype(np.float32)\n\n    test_case.assertTrue(\n        np.allclose(\n            of_cur_rank_embedding_grad.numpy().flatten(),\n            np_cur_rank_embedding_grad.flatten(),\n            atol=np_tolerance,\n            rtol=np_tolerance,\n        )\n    )\n\n\ndef _test_unique_key_value(test_case, has_table_id, num_tables):\n    batch_size = 128\n    ids = np.random.randint(0, 1000, (batch_size, num_tables), dtype=np.int64)\n    if has_table_id:\n        table_ids = (\n            ids % num_tables\n        )  # same id must have same table id, so in this case get table_ids from ids\n        table_ids_tensor = flow.tensor(\n            table_ids.astype(np.int32), requires_grad=False\n        ).to(\"cuda\")\n    else:\n        table_ids_tensor = None\n    ids_tensor = flow.tensor(ids, requires_grad=False).to(\"cuda\")\n\n    class TestGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, ids, table_ids):\n            (\n                num_unique,\n                unique_ids,\n                unique_table_ids,\n                inverse_indices,\n            ) = flow._C.one_embedding_unique_key_value_pair(ids, table_ids, num_tables)\n            return (\n                flow.cast(num_unique, flow.int32),\n                flow.cast(unique_ids, flow.int32),\n                flow.cast(unique_table_ids, 
flow.int32),\n                flow.cast(inverse_indices, flow.int32),\n            )\n\n    graph = TestGraph()\n    (num_unique, unique_ids, unique_table_ids, inverse_indices,) = graph(\n        ids_tensor, table_ids_tensor\n    )\n    np_unique_ids, np_inverse = np.unique(ids, return_inverse=True)\n    np_num_unique = np_unique_ids.size\n    test_case.assertTrue(np.array_equal(np_num_unique, num_unique[0]))\n    reversed_ids = unique_ids[inverse_indices]\n    test_case.assertTrue(np.array_equal(reversed_ids.numpy(), ids))\n    if has_table_id:\n        reversed_table_ids = unique_table_ids[inverse_indices]\n        test_case.assertTrue(np.array_equal(reversed_table_ids.numpy(), table_ids))\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test gpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass DataShuffleTestCase(flow.unittest.TestCase):\n    def test_embedding_gradient_shuffle(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"enable_quantize\"] = [True, False]\n        arg_dict[\"fp16\"] = [True, False]\n        arg_dict[\"embedding_size\"] = [128, 17]\n        for kwargs in GenArgDict(arg_dict):\n            _test_embedding_gradient_shuffle(test_case, **kwargs)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_einsum.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\n\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestEinsum(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_einsum_matrix_transpose(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim0=random(1, 6), dim1=random(1, 6),).to(device)\n        z = torch.einsum(\"ij->ji\", x)\n        return z\n\n    @autotest(n=5)\n    def test_einsum_eltwise_multiply(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device)\n        y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device)\n        z = torch.einsum(\"ij,ij->ij\", x, y)\n        return z\n\n    @autotest(n=5)\n    def test_einsum_get_diagonal(test_case):\n        device = random_device()\n        dim = random(1, 6)\n        x = random_tensor(ndim=2, dim0=dim, dim1=dim,).to(device)\n        z = torch.einsum(\"ii->i\", x)\n        return z\n\n    @autotest(n=5)\n    def test_einsum_batch_permute(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=5,\n            dim0=random(1, 6),\n            dim1=random(1, 6),\n            dim2=random(1, 6),\n            dim3=random(1, 6),\n            dim4=random(1, 6),\n        ).to(device)\n        z = torch.einsum(\"...ij->...ji\", x)\n        return z\n\n    @autotest(n=5)\n    def test_einsum_reduce_sum(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim0=random(1, 6), dim1=random(1, 6),).to(device)\n        z = torch.einsum(\"ij->\", x)\n        return z\n\n    @autotest(n=5)\n    def test_einsum_matrix_column_sum(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim0=random(1, 6), dim1=random(1, 6),).to(device)\n        z = torch.einsum(\"ij->j\", x)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-4)\n    def test_einsum_matrix_vector_multiply(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device)\n        y = random_tensor(ndim=1, dim0=dim1,).to(device)\n        # NOTE(Liang Depeng): the same as 'ik,k->i'\n        z = torch.einsum(\"ik,k\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_matmul(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        dim2 = random(1, 6)\n        x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device)\n        y = random_tensor(ndim=2, dim0=dim1, dim1=dim2,).to(device)\n        # NOTE(Liang Depeng): the same as 'ik,kj->ij'\n        z = torch.einsum(\"ik,kj\", x, y)\n        return z\n\n    @autotest(n=5)\n    def 
test_einsum_vector_inner_product(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        x = random_tensor(ndim=1, dim0=dim0,).to(device)\n        y = random_tensor(ndim=1, dim0=dim0,).to(device)\n        # NOTE(Liang Depeng): the same as 'i,i->'\n        z = torch.einsum(\"i,i\", x, y)\n        return z\n\n    @autotest(n=5)\n    def test_einsum_eltwise_mul_then_reduce_sum(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device)\n        y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device)\n        # NOTE(Liang Depeng): the same as 'ij,ij->'\n        z = torch.einsum(\"ij,ij\", x, y)\n        return z\n\n    @autotest(n=5)\n    def test_einsum_vector_outer_product(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=random(1, 6),).to(device)\n        y = random_tensor(ndim=1, dim0=random(1, 6),).to(device)\n        # NOTE(Liang Depeng): the same as 'i,j->ij'\n        z = torch.einsum(\"i,j\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2)\n    def test_einsum_batch_matmul(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1,).to(device)\n        y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device)\n        z = torch.einsum(\"ijk,ikl->ijl\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_tensor_contraction(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=random(1, 6), dim1=dim0, dim2=dim1, dim3=random(1, 6),\n        ).to(device)\n        y = random_tensor(\n            ndim=5,\n            dim0=random(1, 6),\n            dim1=random(1, 6),\n            dim2=dim0,\n            dim3=random(1, 6),\n            dim4=dim1,\n        ).to(device)\n        z = torch.einsum(\"pqrs,tuqvr->pstuv\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_bilinear_transformation(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        dim2 = random(1, 6)\n        x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device)\n        y = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim1, dim2=dim2,).to(device)\n        w = random_tensor(ndim=2, dim0=dim0, dim1=dim2,).to(device)\n        z = torch.einsum(\"ik,jkl,il->ij\", x, y, w)\n        return z\n\n    @autotest(n=20, auto_backward=False, check_graph=True)\n    def test_einsum_0_size_tensor(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=random(1, 6), dim1=0, dim2=random(1, 6),).to(\n            device\n        )\n        z = torch.einsum(\"ijk\", x)\n        return z\n\n    @unittest.skip(\"skip for now, because it failed 20 times in the past week\")\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_tensor_contraction2(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=random(1, 6), dim1=dim0, dim2=random(1, 6), dim3=random(1, 6),\n        ).to(device)\n        y = random_tensor(ndim=2, dim0=dim0, dim1=random(1, 6),).to(device)\n        z = torch.einsum(\"b n h w, n d -> b d h w\", x, y)\n        return z\n\n    @autotest(n=5)\n    def 
test_einsum_eltwise_mul_sum_row(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device)\n        y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device)\n        z = torch.einsum(\"n d, n d -> n\", x, y)\n        return z\n\n    @unittest.skip(\"skip for now, because it failed 20 times in the past week\")\n    @autotest(n=5, rtol=1e-2, atol=1e-4)\n    def test_einsum_matmul2(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        x = random_tensor(ndim=2, dim0=random(1, 6), dim1=dim0,).to(device)\n        y = random_tensor(ndim=2, dim0=random(1, 6), dim1=dim0,).to(device)\n        z = torch.einsum(\"i d, j d -> i j\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-3)\n    def test_einsum_attention(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        dim2 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2,\n        ).to(device)\n        y = random_tensor(\n            ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2,\n        ).to(device)\n        z = torch.einsum(\"b h i d, b h j d -> b h i j\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-3)\n    def test_einsum_batch_matmul2(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        dim2 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2\n        ).to(device)\n        y = random_tensor(\n            ndim=4, dim0=dim0, dim1=dim1, dim2=dim2, dim3=random(1, 6)\n        ).to(device)\n        z = torch.einsum(\"b h i j, b h j d -> b h i d\", x, y)\n        return z\n\n    @unittest.skip(\"skip for now, because it failed 28 times in the past week\")\n    @autotest(n=5, rtol=1e-2)\n    def test_einsum_batch_matrix_vector_multiply(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        dim2 = random(1, 6)\n        x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=dim2,).to(device)\n        y = random_tensor(\n            ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2,\n        ).to(device)\n        z = torch.einsum(\"b i d, b i j d -> b i j\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-4)\n    def test_einsum_batch_matmul3(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6), dim3=dim1,\n        ).to(device)\n        y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1,).to(device)\n        z = torch.einsum(\"b x i d, b j d -> b x i j\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2)\n    def test_einsum_batch_matmul4(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6), dim3=dim1,\n        ).to(device)\n        y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device)\n        z = torch.einsum(\"b x i j, b j d -> b x i d\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-4)\n    def test_einsum_alphaflod_usecase1(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = 
random(1, 6)\n        x = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim0, dim2=dim1,).to(device)\n        y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device)\n        z = torch.einsum(\"hij, ijc->ihc\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_alphaflod_usecase2(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device)\n        y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device)\n        z = torch.einsum(\"rac,rab->rbc\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2)\n    def test_einsum_alphaflod_usecase3(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device)\n        y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device)\n        z = torch.einsum(\"ra,rab->rb\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-4)\n    def test_einsum_alphaflod_usecase4(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim0, dim2=dim1,).to(device)\n        y = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim0, dim2=dim1,).to(device)\n        z = torch.einsum(\"qhc,khc->qkh\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_alphaflod_usecase5(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        x = random_tensor(ndim=2, dim0=random(1, 6), dim1=dim0,).to(device)\n        y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6),).to(\n            device\n        )\n        z = torch.einsum(\"nm, mrc->nrc\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_alphaflod_usecase6(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1,).to(device)\n        y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1,).to(device)\n        z = torch.einsum(\"abc,adc->bdc\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_alphaflod_usecase7(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=random(1, 6), dim1=dim0, dim2=dim1, dim3=random(1, 6),\n        ).to(device)\n        y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device)\n        z = torch.einsum(\"dceb,cef->dbf\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_alphaflod_usecase8(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6),).to(\n            device\n        )\n        y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6),).to(\n            device\n        )\n        z = torch.einsum(\"acb,ade->dceb\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_alphaflod_usecase9(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        x = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), 
dim2=dim0,).to(\n            device\n        )\n        y = random_tensor(ndim=2, dim0=dim0, dim1=random(1, 6),).to(device)\n        z = torch.einsum(\"qkc,ch->hqk\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-4)\n    def test_einsum_alphaflod_usecase10(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        dim2 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2,\n        ).to(device)\n        y = random_tensor(\n            ndim=4, dim0=dim0, dim1=dim2, dim2=dim1, dim3=random(1, 6)\n        ).to(device)\n        z = torch.einsum(\"bhqk,bkhc->bqhc\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_alphaflod_usecase11(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        x = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0,).to(\n            device\n        )\n        y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6),).to(\n            device\n        )\n        z = torch.einsum(\"bqa,ahc->bqhc\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-4)\n    def test_einsum_ellipsis_usecase1(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        x = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0,).to(\n            device\n        )\n        y = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0,).to(\n            device\n        )\n        z = torch.einsum(\"...lc, ...c -> ...l\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2)\n    def test_einsum_ellipsis_usecase2(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim0, dim2=dim1,).to(device)\n        y = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim0, dim2=dim1).to(device)\n        z = torch.einsum(\"...lc, ...lc -> ...l\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_ellipsis_usecase3(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        x = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0,).to(\n            device\n        )\n        y = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0).to(\n            device\n        )\n        z = torch.einsum(\"...id,...jd->...ij\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-4)\n    def test_einsum_ellipsis_usecase4(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=random(1, 6), dim1=dim0, dim2=random(1, 6), dim3=dim1\n        ).to(device)\n        y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6)).to(device)\n        z = torch.einsum(\"...klm,kmn->...kln\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_ellipsis_usecase5(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0, dim3=random(1, 6)\n        ).to(device)\n        y = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0).to(\n            device\n        )\n        z = torch.einsum(\"...ikl, ...jk -> ...ijl\", x, y)\n        return z\n\n    
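# Note on the ellipsis tests in this class: \"...\" in an einsum equation stands for any\n    # number of leading broadcast dimensions shared by the operands; e.g.\n    # torch.einsum(\"...l,...l->...\", x, y) below reduces only the last axis and is\n    # equivalent to (x * y).sum(-1).\n\n    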
@autotest(n=5, rtol=1e-2, atol=1e-4)\n    def test_einsum_ellipsis_usecase6(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        x = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0).to(\n            device\n        )\n        y = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0).to(\n            device\n        )\n        z = torch.einsum(\"...l,...l->...\", x, y)\n        return z\n\n    @autotest(n=5)\n    def test_einsum_ellipsis_usecase7(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        dim2 = random(1, 6)\n        x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=dim2).to(device)\n        y = random_tensor(\n            ndim=4, dim0=dim0, dim1=dim1, dim2=dim2, dim3=random(1, 6)\n        ).to(device)\n        z = torch.einsum(\"ijk,ijk...->ij...\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_other_usecase1(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        dim2 = random(1, 6)\n        x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1).to(device)\n        y = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim1, dim2=dim2).to(device)\n        w = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim2).to(device)\n        z = torch.einsum(\"bxi,oij,byj->boxy\", x, y, w)\n        return z\n\n    @autotest(n=5)\n    def test_einsum_other_usecase2(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=random(1, 6)\n        ).to(device)\n        y = random_tensor(\n            ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=random(1, 6)\n        ).to(device)\n        z = torch.einsum(\"ijac,ijkp->ijakcp\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_other_usecase3(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=dim0, dim1=random(1, 6), dim2=dim1, dim3=random(1, 6)\n        ).to(device)\n        y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1).to(device)\n        z = torch.einsum(\"cdij,cbi->cdbj\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-4)\n    def test_einsum_fastfold_usecase1(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        dim2 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2\n        ).to(device)\n        y = random_tensor(\n            ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2\n        ).to(device)\n        z = torch.einsum(\"bsid,bsjd->bijd\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_fastfold_usecase2(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=random(1, 6)\n        ).to(device)\n        y = random_tensor(\n            ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=random(1, 6)\n        ).to(device)\n        z = torch.einsum(\"bsid,bsje->bijde\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def 
test_einsum_openfold_usecase1(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0, dim3=random(1, 6)\n        ).to(device)\n        y = random_tensor(\n            ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0, dim3=random(1, 6)\n        ).to(device)\n        z = torch.einsum(\"...bac,...dae->...bdce\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-4)\n    def test_einsum_openfold_usecase2(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=random(1, 6), dim1=dim0, dim2=random(1, 6), dim3=dim1\n        ).to(device)\n        y = random_tensor(\n            ndim=4, dim0=random(1, 6), dim1=dim0, dim2=random(1, 6), dim3=dim1\n        ).to(device)\n        z = torch.einsum(\"...abc,...adc->...bdc\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-4)\n    def test_einsum_openfold_usecase3(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0, dim3=dim1\n        ).to(device)\n        y = random_tensor(\n            ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0, dim3=dim1\n        ).to(device)\n        z = torch.einsum(\"...qhd,...khd->...hqk\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_openfold_usecase4(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        dim1 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=random(1, 6), dim1=dim0, dim2=dim1, dim3=random(1, 6)\n        ).to(device)\n        y = random_tensor(\n            ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim1, dim3=dim0\n        ).to(device)\n        z = torch.einsum(\"...vhf,...qhv->...qhf\", x, y)\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_einsum_openfold_usecase5(test_case):\n        device = random_device()\n        dim0 = random(1, 6)\n        x = random_tensor(\n            ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=random(1, 6), dim3=dim0\n        ).to(device)\n        y = random_tensor(ndim=2, dim0=dim0, dim1=random(1, 6)).to(device)\n        z = torch.einsum(\"...ij,jk->ik\", x, y)\n        return z\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_global_tensor_offload.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\n\n# NOTE(Li Xiang): This variable controls the mem comparison method of the tensor offload test.\n#  1: Strictly test, compare mem changes according to tensor size.\n#  2: Loose test, compare mem changes before and after offload;\n#  3: Execute only offload, skip mem check.\noffload_tensor_test_mem_mode = 3\n\n\ndef _test_global_tensor_offload_d2h(test_case, input, tensor_mem):\n    test_case.assertTrue(not input.is_offloaded())\n    flow.cuda.empty_cache()\n    if input.placement == oneflow.placement(type=\"cuda\", ranks=[0, 1]):\n        flow._oneflow_internal.CudaSynchronize(0)\n        flow._oneflow_internal.CudaSynchronize(1)\n    elif input.placement == oneflow.placement(type=\"cuda\", ranks=[0, 1, 2, 3]):\n        flow._oneflow_internal.CudaSynchronize(0)\n        flow._oneflow_internal.CudaSynchronize(1)\n        flow._oneflow_internal.CudaSynchronize(2)\n        flow._oneflow_internal.CudaSynchronize(3)\n\n    flow._oneflow_internal.eager.ClusterSync()\n    before_used = flow._oneflow_internal.GetCUDAMemoryUsed()\n    before_id = id(input)\n    print(\"cuda\", before_used)\n\n    input.offload()\n    test_case.assertTrue(input.is_offloaded())\n    test_case.assertEqual(input.placement.type, \"cuda\")\n    after_used = flow._oneflow_internal.GetCUDAMemoryUsed()\n    after_id = id(input)\n    print(\"cuda to cpu\", after_used)\n    # Check global_tensor_mem cuda memory released\n    if offload_tensor_test_mem_mode == 1:\n        # NOTE(Li Xiang): In the case of 4 gpus, the memory usage of the tensor sometimes has a 2MB error.\n        if input.placement == oneflow.placement(type=\"cuda\", ranks=[0, 1, 2, 3]):\n            test_case.assertTrue(\n                ((before_used - after_used) == tensor_mem)\n                or ((before_used - after_used) == (tensor_mem - 2))\n            )\n            return\n        test_case.assertTrue((before_used - after_used) == tensor_mem)\n    elif offload_tensor_test_mem_mode == 2:\n        test_case.assertTrue(before_used > after_used)\n    elif offload_tensor_test_mem_mode == 3:\n        print(\n            \"Device:\",\n            flow.env.get_rank(),\n            \". 
cuda mem change value:\",\n            before_used - after_used,\n        )\n    test_case.assertEqual(before_id, after_id)\n\n\ndef _test_global_tensor_load_h2d(test_case, input, tensor_mem):\n    test_case.assertTrue(input.is_offloaded())\n\n    if input.placement == oneflow.placement(type=\"cuda\", ranks=[0, 1]):\n        flow._oneflow_internal.CudaSynchronize(0)\n        flow._oneflow_internal.CudaSynchronize(1)\n    elif input.placement == oneflow.placement(type=\"cuda\", ranks=[0, 1, 2, 3]):\n        flow._oneflow_internal.CudaSynchronize(0)\n        flow._oneflow_internal.CudaSynchronize(1)\n        flow._oneflow_internal.CudaSynchronize(2)\n        flow._oneflow_internal.CudaSynchronize(3)\n\n    flow._oneflow_internal.eager.ClusterSync()\n    before_used = flow._oneflow_internal.GetCUDAMemoryUsed()\n    before_id = id(input)\n\n    input.load()\n    test_case.assertTrue(not input.is_offloaded())\n    test_case.assertEqual(input.placement.type, \"cuda\")\n    after_used = flow._oneflow_internal.GetCUDAMemoryUsed()\n    after_id = id(input)\n    print(\"cpu to cuda\", after_used)\n    # Check global_tensor_mem cuda memory allocated\n    if offload_tensor_test_mem_mode == 1:\n        # NOTE(Li Xiang): In the case of 4 gpus, the memory usage of the tensor sometimes has a 2MB error.\n        if input.placement == oneflow.placement(type=\"cuda\", ranks=[0, 1, 2, 3]):\n            test_case.assertTrue(\n                ((after_used - before_used) == tensor_mem)\n                or ((after_used - before_used) == (tensor_mem - 2))\n            )\n            return\n        test_case.assertTrue((after_used - before_used) == tensor_mem)\n    elif offload_tensor_test_mem_mode == 2:\n        test_case.assertTrue(after_used > before_used)\n    elif offload_tensor_test_mem_mode == 3:\n        print(\n            \"Device:\",\n            flow.env.get_rank(),\n            \". 
cuda mem change value:\",\n            after_used - before_used,\n        )\n    test_case.assertEqual(before_id, after_id)\n\n\ndef _get_specific_global_tensor_mem(placement, sbp, tensor):\n    size_tensor = tensor.clone().detach().to_local()\n    cnt_size = size_tensor.element_size() * flow.numel(size_tensor)\n    return cnt_size / 1024 / 1024\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGlobalTensorOffload(flow.unittest.TestCase):\n    @globaltest\n    @flow.unittest.skip_unless_1n2d()\n    def test_global_tensor_offload_and_load_2d(test_case):\n        for i in range(5):\n            placement = flow.placement(\"cuda\", ranks=[0, 1])\n            for sbp in all_sbp(placement, max_dim=2, except_partial_sum=True):\n                input = flow.randn(\n                    1024, 1024, 100, dtype=flow.float32, placement=placement, sbp=sbp\n                )\n                data = input.numpy()\n                tensor_mem = _get_specific_global_tensor_mem(placement, sbp, input)\n                _test_global_tensor_offload_d2h(test_case, input, tensor_mem)\n                _test_global_tensor_load_h2d(test_case, input, tensor_mem)\n                test_case.assertTrue(\n                    np.allclose(input.numpy(), data, rtol=0.0001, atol=0.0001)\n                )\n\n    @globaltest\n    @flow.unittest.skip_unless_1n4d()\n    def test_global_tensor_offload_and_load_4d(test_case):\n        for i in range(5):\n            placement = flow.placement(\"cuda\", ranks=[0, 1, 2, 3])\n            for sbp in all_sbp(placement, max_dim=2, except_partial_sum=True):\n                input = flow.randn(\n                    1024, 1024, 10, dtype=flow.float32, placement=placement, sbp=sbp\n                )\n                data = input.numpy()\n                tensor_mem = _get_specific_global_tensor_mem(placement, sbp, input)\n                _test_global_tensor_offload_d2h(test_case, input, tensor_mem)\n                _test_global_tensor_load_h2d(test_case, input, tensor_mem)\n                test_case.assertTrue(\n                    np.allclose(input.numpy(), data, rtol=0.0001, atol=0.0001)\n                )\n\n    @globaltest\n    @flow.unittest.skip_unless_1n2d()\n    def test_global_tensor_offload_and_load_2d_cpu_mem(test_case):\n        flow.cuda.empty_cache()\n        for i in range(5):\n            placement = flow.placement(\"cuda\", ranks=[0, 1])\n            for sbp in all_sbp(placement, max_dim=2, except_partial_sum=True):\n                input = flow.randn(\n                    1024, 1024, 100, dtype=flow.float32, placement=placement, sbp=sbp\n                )\n\n                before_used = flow._oneflow_internal.GetCPUMemoryUsed()\n                before_id = id(input)\n                input.offload()\n                after_used = flow._oneflow_internal.GetCPUMemoryUsed()\n                after_id = id(input)\n                if offload_tensor_test_mem_mode == 2:\n                    test_case.assertTrue(after_used > before_used)\n                elif offload_tensor_test_mem_mode == 3:\n                    print(\"cpu mem change value:\", after_used - before_used)\n                test_case.assertEqual(before_id, after_id)\n\n                cur_used = flow._oneflow_internal.GetCPUMemoryUsed()\n                before_id = id(input)\n                input.load()\n                after_used = flow._oneflow_internal.GetCPUMemoryUsed()\n                after_id = id(input)\n                if offload_tensor_test_mem_mode == 2:\n       
             test_case.assertTrue(after_used < cur_used)\n                elif offload_tensor_test_mem_mode == 3:\n                    print(\"cpu mem change value:\", cur_used - after_used)\n                test_case.assertEqual(before_id, after_id)\n\n    @globaltest\n    @flow.unittest.skip_unless_1n2d()\n    def test_global_param_offload_and_load(test_case):\n        def load_eager_model(model):\n            for param in model.parameters():\n                if param.is_offloaded():\n                    param.load()\n                    test_case.assertTrue(not param.is_offloaded())\n\n        def offload_eager_model(model):\n            for param in model.parameters():\n                if not param.is_offloaded():\n                    param.offload()\n                    test_case.assertTrue(param.is_offloaded())\n\n        class Model(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.n_layer = 1\n\n                layer_list = list()\n\n                for _ in range(self.n_layer):\n                    layer_list.append(nn.Linear(768, 4096))\n\n                self.layers = nn.Sequential(*layer_list)\n\n            def forward(self, x):\n                return self.layers(x)\n\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        model0 = Model().cuda()\n        model0.to_global(placement=placement, sbp=flow.sbp.broadcast)\n        BZ = 128\n        dataset = [flow.rand((BZ, 768), dtype=flow.float32) for _ in range(128)]\n\n        with flow.no_grad():\n            for idx, x in enumerate(dataset):\n                print(f\"iter {idx} begin\")\n                x = x.cuda()\n                x = x.to_global(placement=placement, sbp=flow.sbp.broadcast)\n                load_eager_model(model0)\n                y0 = model0(x)\n                offload_eager_model(model0)\n                print(f\"iter {idx} end\")\n                if idx == 1:\n                    break\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_graph_multi_graph_v2.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\nimport time\nimport tempfile\nimport multiprocessing\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _reset_session():\n    # Close session to avoid the buffer name duplicate error.\n    oneflow.framework.session_context.TryCloseDefaultSession()\n    time.sleep(5)\n    flow.framework.session_context.NewDefaultSession(flow._oneflow_global_unique_env)\n\n\ndef _with_new_session(fn):\n    def new_fn(*args, **kwargs):\n        # Avoid Singleton value duplication such as buffer names.\n        # saved and loaded graph runtime share the same buffer names(job names).\n        print(\n            \"function \",\n            fn.__name__,\n            \" session reset to avoid Singleton value duplication ...\",\n        )\n        _reset_session()\n        out = fn(*args, **kwargs)\n        _reset_session()\n        return out\n\n    return new_fn\n\n\ndef _test_linear_multi_graph_share(test_case, device, with_reshape):\n    linear = flow.nn.Linear(3, 8, False)\n    linear = linear.to(device)\n    np_weight = np.ones((3, 8)).astype(np.float32)\n    np_weight.fill(2.3)\n    flow.nn.init.constant_(linear.weight, 2.3)\n\n    class LinearReshapeModule(flow.nn.Module):\n        def __init__(self, lin, with_r):\n            super().__init__()\n            self.linear = lin\n            self.with_reshape = with_r\n\n        def forward(self, x):\n            y = self.linear(x)\n            if with_reshape:\n                assert len(y.shape) == 2\n                return flow.reshape(y, (y.shape[1], y.shape[0]))\n            else:\n                return y\n\n    linear_reshape = LinearReshapeModule(linear, with_reshape)\n\n    class LinearGraph(flow.nn.Graph):\n        @flow.nn.Graph.with_dynamic_input_shape(size=4)\n        def __init__(self, lin, with_r):\n            super().__init__()\n            self.my_linear = LinearReshapeModule(lin, with_r)\n\n        def build(self, x):\n            return self.my_linear(x)\n\n    linear_g = LinearGraph(linear, with_reshape)\n    input_arr = np.array(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n            [0.35217619, -0.67095644, -1.58943879],\n            [0.08086036, -1.81075924, 1.20752494],\n            [0.8901075, -0.49976737, -1.07153746],\n            [-0.44872912, -1.07275683, 0.06256855],\n            [-0.22556897, 0.74798368, 0.90416439],\n            [0.48339456, -2.32742195, -0.59321527],\n        ],\n        dtype=np.float32,\n    )\n    x = flow.tensor(input_arr, device=device)\n    of_lazy_out = linear_g(x)\n    of_eager_out = linear_reshape(x)\n    test_case.assertTrue(np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy()))\n\n    input_arr1 = np.array(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n            
[0.35217619, -0.67095644, -1.58943879],\n            [0.08086036, -1.81075924, 1.20752494],\n        ],\n        dtype=np.float32,\n    )\n    x1 = flow.tensor(input_arr1, device=device)\n    of_lazy_out1 = linear_g(x1)\n    of_eager_out1 = linear_reshape(x1)\n    test_case.assertTrue(np.array_equal(of_lazy_out1.numpy(), of_eager_out1.numpy()))\n\n    input_arr2 = np.array(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n        ],\n        dtype=np.float32,\n    )\n    x2 = flow.tensor(input_arr2, device=device)\n    of_lazy_out2 = linear_g(x2)\n    of_eager_out2 = linear_reshape(x2)\n    test_case.assertTrue(np.array_equal(of_lazy_out2.numpy(), of_eager_out2.numpy()))\n\n    of_lazy_out2 = linear_g(x2)\n    of_eager_out2 = linear_reshape(x2)\n    test_case.assertTrue(np.array_equal(of_lazy_out2.numpy(), of_eager_out2.numpy()))\n\n\ndef _get_state_dict_tensor_size(sd):\n    from oneflow.framework.args_tree import ArgsTree\n\n    def _get_tensor_mem(input):\n        # if input.dim() == 0:\n        #     return 2\n        cnt_size = input.element_size() * flow.numel(input)\n        return cnt_size\n\n    args_tree = ArgsTree(sd, False)\n\n    size = 0\n    for arg in args_tree.iter_nodes():\n        if isinstance(arg, flow.Tensor):\n            size += _get_tensor_mem(arg)\n        else:\n            continue\n    return size\n\n\n@_with_new_session\ndef _test_linear_multi_graph_save(return_dict, device, with_reshape, with_eager):\n    linear = flow.nn.Linear(3, 8, False)\n    linear = linear.to(device)\n    np_weight = np.ones((3, 8)).astype(np.float32)\n    np_weight.fill(2.3)\n    flow.nn.init.constant_(linear.weight, 2.3)\n\n    class LinearReshapeModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.linear = linear\n\n        def forward(self, x):\n            y = self.linear(x)\n            if with_reshape:\n                assert len(y.shape) == 2\n                return flow.reshape(y, (y.shape[1], y.shape[0]))\n            else:\n                return y\n\n    linear_reshape = LinearReshapeModule()\n\n    class LinearGraph(flow.nn.Graph):\n        @flow.nn.Graph.with_dynamic_input_shape(size=3)\n        def __init__(self):\n            super().__init__(enable_get_runtime_state_dict=True)\n            self.my_linear = linear_reshape\n\n        def build(self, x):\n            return self.my_linear(x)\n\n    linear_g = LinearGraph()\n\n    input_arr = np.array(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n            [0.35217619, -0.67095644, -1.58943879],\n            [0.08086036, -1.81075924, 1.20752494],\n            [0.8901075, -0.49976737, -1.07153746],\n            [-0.44872912, -1.07275683, 0.06256855],\n            [-0.22556897, 0.74798368, 0.90416439],\n            [0.48339456, -2.32742195, -0.59321527],\n        ],\n        dtype=np.float32,\n    )\n    x = flow.tensor(input_arr, device=device)\n    of_lazy_out = linear_g(x)\n    of_eager_out = linear_reshape(x)\n    test_case0 = np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy())\n    return_dict[\"save0\"] = test_case0\n\n    input_arr1 = np.array(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n            [0.35217619, -0.67095644, -1.58943879],\n            [0.08086036, -1.81075924, 1.20752494],\n        ],\n        dtype=np.float32,\n    )\n    x1 = 
flow.tensor(input_arr1, device=device)\n    of_lazy_out1 = linear_g(x1)\n    of_eager_out1 = linear_reshape(x1)\n    test_case1 = np.array_equal(of_lazy_out1.numpy(), of_eager_out1.numpy())\n    return_dict[\"save1\"] = test_case1\n\n    input_arr2 = np.array(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n        ],\n        dtype=np.float32,\n    )\n    x2 = flow.tensor(input_arr2, device=device)\n    of_lazy_out2 = linear_g(x2)\n    of_eager_out2 = linear_reshape(x2)\n    test_case2 = np.array_equal(of_lazy_out2.numpy(), of_eager_out2.numpy())\n    return_dict[\"save2\"] = test_case2\n\n    input_arr3 = np.array([[-0.94630778, -0.83378579, -0.87060891],], dtype=np.float32,)\n    x3 = flow.tensor(input_arr3, device=device)\n    of_lazy_out3 = linear_g(x3)\n    of_eager_out3 = linear_reshape(x3)\n    test_case3 = np.array_equal(of_lazy_out3.numpy(), of_eager_out3.numpy())\n    return_dict[\"save3\"] = test_case3\n\n    of_lazy_out1 = linear_g(x1)\n    test_case1 = np.array_equal(of_lazy_out1.numpy(), of_eager_out1.numpy())\n    return_dict[\"save4\"] = test_case1\n\n    state_dict = linear_g.runtime_state_dict(with_eager=with_eager)\n    print(\"====> saved graphs\", state_dict.keys())\n    return state_dict\n\n\n@_with_new_session\ndef _test_linear_multi_graph_load(\n    return_dict, device, with_reshape, state_dict, with_new_input\n):\n    linear = flow.nn.Linear(3, 8, False)\n    linear = linear.to(device)\n    np_weight = np.ones((3, 8)).astype(np.float32)\n    np_weight.fill(2.3)\n    flow.nn.init.constant_(linear.weight, 2.3)\n\n    class LinearReshapeModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.linear = linear\n\n        def forward(self, x):\n            y = self.linear(x)\n            if with_reshape:\n                assert len(y.shape) == 2\n                return flow.reshape(y, (y.shape[1], y.shape[0]))\n            else:\n                return y\n\n    linear_reshape = LinearReshapeModule()\n\n    class LinearGraph(flow.nn.Graph):\n        @flow.nn.Graph.with_dynamic_input_shape(size=20)\n        def __init__(self):\n            super().__init__(debug_v_level=0)\n            self.my_linear = linear_reshape\n\n        def build(self, x):\n            return self.my_linear(x)\n\n    linear_g = LinearGraph()\n    print(\"====> load\")\n    linear_g.load_runtime_state_dict(state_dict)\n    print(\"====> load finish\")\n\n    input_arr = np.array(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n            [0.35217619, -0.67095644, -1.58943879],\n            [0.08086036, -1.81075924, 1.20752494],\n            [0.8901075, -0.49976737, -1.07153746],\n            [-0.44872912, -1.07275683, 0.06256855],\n            [-0.22556897, 0.74798368, 0.90416439],\n            [0.48339456, -2.32742195, -0.59321527],\n        ],\n        dtype=np.float32,\n    )\n    x = flow.tensor(input_arr, device=device)\n    of_lazy_out = linear_g(x)\n    of_eager_out = linear_reshape(x)\n    test_case0 = np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy())\n    return_dict[\"load0\"] = test_case0\n\n    input_arr1 = np.array(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n            [0.35217619, -0.67095644, -1.58943879],\n            [0.08086036, -1.81075924, 1.20752494],\n        ],\n        dtype=np.float32,\n    )\n    x1 = 
flow.tensor(input_arr1, device=device)\n    of_lazy_out1 = linear_g(x1)\n    of_eager_out1 = linear_reshape(x1)\n    test_case1 = np.array_equal(of_lazy_out1.numpy(), of_eager_out1.numpy())\n    return_dict[\"load1\"] = test_case1\n\n    if with_new_input:\n        # The following section is for testing the new input shape after completing the load.\n        input_arr2 = np.array(\n            [\n                [-0.94630778, -0.83378579, -0.87060891],\n                [2.0289922, -0.28708987, -2.18369248],\n                [0.08086036, -1.81075924, 1.20752494],\n            ],\n            dtype=np.float32,\n        )\n        x2 = flow.tensor(input_arr2, device=device)\n        of_lazy_out2 = linear_g(x2)\n        of_eager_out2 = linear_reshape(x2)\n        test_case2 = np.array_equal(of_lazy_out2.numpy(), of_eager_out2.numpy())\n        return_dict[\"load2\"] = test_case2\n\n\ndef _graph_save(return_dict, filename, with_eager):\n    state_dict = _test_linear_multi_graph_save(\n        return_dict, flow.device(\"cuda:0\"), True, with_eager,\n    )\n    print(\n        f\"state_dict(with_eager={with_eager}) tensors size \",\n        _get_state_dict_tensor_size(state_dict),\n    )\n    flow.save(state_dict, filename)\n    print(\"====> save process done\")\n\n\ndef _graph_load(return_dict, filename):\n    state_dict_loaded = flow.load(filename)\n    # load with nn.Graph\n    _test_linear_multi_graph_load(\n        return_dict, flow.device(\"cuda\"), True, state_dict_loaded, True\n    )\n    print(\"====> load process done\")\n\n\ndef _graph_load_to_another_device(return_dict, filename):\n    state_dict_loaded = flow.load(filename)\n    new_state_dict = flow.nn.Graph.runtime_state_dict_to(\n        state_dict_loaded, flow.device(\"cuda:1\")\n    )\n    # load with nn.Graph\n    _test_linear_multi_graph_load(\n        return_dict, flow.device(\"cuda:1\"), True, new_state_dict, False\n    )\n    print(\"====> load process done\")\n\n\ndef _test_linear_multi_graph_save_load_gpu(test_case, with_eager):\n    # A graph runtime state dict\n    with tempfile.NamedTemporaryFile() as f:\n        # Save a graph\n        manager = multiprocessing.Manager()\n        return_dict = manager.dict()\n        save_p = multiprocessing.get_context(\"spawn\").Process(\n            target=_graph_save, args=(return_dict, f.name, with_eager),\n        )\n        save_p.start()\n        save_p.join()\n\n        # Resume a graph from a graph runtime state dict\n        load_p = multiprocessing.get_context(\"spawn\").Process(\n            target=_graph_load, args=(return_dict, f.name)\n        )\n        load_p.start()\n        load_p.join()\n\n        # test_case can't be passed into sub process, so we check with return_dict.\n        # Reference: https://stackoverflow.com/questions/52225003/writing-to-multiple-files-using-multiprocessing-error-typeerror-cannot-seria\n        for (key, check_value) in return_dict.items():\n            test_case.assertTrue(check_value, key + \" failed.\")\n\n\ndef _test_load_to_another_device(test_case, with_eager):\n    # A graph runtime state dict\n    with tempfile.NamedTemporaryFile() as f:\n        # Save a graph\n        manager = multiprocessing.Manager()\n        return_dict = manager.dict()\n        save_p = multiprocessing.get_context(\"spawn\").Process(\n            target=_graph_save, args=(return_dict, f.name, with_eager),\n        )\n        save_p.start()\n        save_p.join()\n        print(save_p)\n\n        # Resume a graph from a graph runtime state dict\n        
load_p = multiprocessing.get_context(\"spawn\").Process(\n            target=_graph_load_to_another_device, args=(return_dict, f.name)\n        )\n        load_p.start()\n        load_p.join()\n        print(load_p)\n\n        # test_case can't be passed into sub process, so we check with return_dict.\n        # Reference: https://stackoverflow.com/questions/52225003/writing-to-multiple-files-using-multiprocessing-error-typeerror-cannot-seria\n        for (key, check_value) in return_dict.items():\n            test_case.assertTrue(check_value, key + \" failed.\")\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestLinearMultiGraph(oneflow.unittest.TestCase):\n    def test_linear_multi_graph_share_gpu(test_case):\n        _test_linear_multi_graph_share(test_case, flow.device(\"cuda\"), False)\n\n    def test_linear_reshape_multi_graph_share_gpu(test_case):\n        _test_linear_multi_graph_share(test_case, flow.device(\"cuda\"), True)\n\n    def test_linear_multi_graph_save_load_gpu_with_share(test_case):\n        _test_linear_multi_graph_save_load_gpu(test_case, True)\n\n    def test_linear_multi_graph_save_load_gpu_with_share_without_eager(test_case):\n        _test_linear_multi_graph_save_load_gpu(test_case, False)\n\n    def test_load_to_another_device(test_case):\n        _test_load_to_another_device(test_case, False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_id_shuffle.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\n\n# dynamic memory allocation can't be tested in unittest\nos.environ[\"ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION\"] = \"0\"\nimport unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgDict\nimport numpy as np\nimport oneflow as flow\n\n\ndef _test_id_shuffle(test_case, has_table_id, num_tables):\n    batch_size = 512\n    ids = np.random.randint(0, 1000, (batch_size, num_tables), dtype=np.int64)\n    if has_table_id:\n        table_ids = (\n            ids % num_tables\n        )  # same id must have same table id, so in this case get table_ids from ids\n        table_ids_tensor = flow.tensor(\n            table_ids.astype(np.int32), requires_grad=False\n        ).to(\"cuda\")\n    else:\n        table_ids_tensor = None\n    ids_tensor = flow.tensor(ids, requires_grad=False).to(\"cuda\")\n\n    class TestGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, ids, table_ids):\n            (\n                num_unique_matrix,\n                inverse_unique_partition_indices,\n                cur_rank_num_unique,\n                cur_rank_unique_ids,\n                cur_rank_unique_table_ids,\n                cur_rank_inverse_indices,\n            ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables, \"test\")\n            return (\n                flow.cast(num_unique_matrix, flow.int32),\n                flow.cast(inverse_unique_partition_indices, flow.int32),\n                flow.cast(cur_rank_num_unique, flow.int32),\n                flow.cast(cur_rank_unique_ids, flow.int32),\n                flow.cast(cur_rank_unique_table_ids, flow.int32),\n                flow.cast(cur_rank_inverse_indices, flow.int32),\n            )\n\n    graph = TestGraph()\n    (\n        num_unique_matrix,\n        inverse_unique_partition_indices,\n        cur_rank_num_unique,\n        cur_rank_unique_ids,\n        cur_rank_unique_table_ids,\n        cur_rank_inverse_indices,\n    ) = graph(ids_tensor, table_ids_tensor)\n    np_unique_ids, np_inverse = np.unique(ids, return_inverse=True)\n    np_num_unique = np_unique_ids.size\n    test_case.assertTrue(np.array_equal(np_num_unique, num_unique_matrix[0]))\n    test_case.assertTrue(np.array_equal(np_num_unique, cur_rank_num_unique[0]))\n    reversed_ids = cur_rank_unique_ids[cur_rank_inverse_indices][\n        inverse_unique_partition_indices\n    ]\n    test_case.assertTrue(np.array_equal(reversed_ids.numpy(), ids))\n    if has_table_id:\n        reversed_table_ids = cur_rank_unique_table_ids[cur_rank_inverse_indices][\n            inverse_unique_partition_indices\n        ]\n        test_case.assertTrue(np.array_equal(reversed_table_ids.numpy(), table_ids))\n    # when has_table_id=False, we can not test table ids because in this case same ids not lead to same table id\n\n\ndef round_half_away_from_zero(x):\n 
   sign = np.sign(x)\n    abs_val = np.abs(x)\n    abs_val += 0.5\n    floor_val = np.floor(abs_val)\n    out = floor_val * sign\n    return out\n\n\ndef embedding_shuffle_quantize(np_data, np_dtype):\n    # When use float16, ComputeType is set to as Float.\n    np_reduce_data = np_data.astype(np.float32)\n    abs_max_factor = np.max(np.abs(np_reduce_data), axis=2)\n    abs_max_factor = np.expand_dims(abs_max_factor, axis=2)\n    transport_quantize_factor = abs_max_factor.astype(np_dtype)\n    int8_factor = np.ones(abs_max_factor.shape, dtype=np.float32) * 127.0\n    int8_factor = int8_factor.astype(np.float32)\n    quantize_factor = int8_factor / abs_max_factor\n\n    # Covert to Compute Type.\n    np_data.astype(np.float32)\n    np_data = np_data * quantize_factor\n    np_data = round_half_away_from_zero(np_data)\n    np_data = np_data.astype(np.int8)\n\n    # Covert to Compute Type.\n    np_data = np_data.astype(np.float32)\n    dequantize_factor = transport_quantize_factor.astype(np.float32) / int8_factor\n    np_data = np_data * dequantize_factor\n    np_data = np_data.astype(np_dtype)\n    return np_data\n\n\ndef _test_embedding_shuffle(test_case, dtype, enable_quantize):\n    batch_size = 512\n    num_tables = 26\n    embedding_size = 128\n    ids = np.random.randint(0, 1000, (batch_size, num_tables), dtype=np.int64)\n\n    enable_quantized_comm = enable_quantize and embedding_size < 1025\n    if enable_quantized_comm:\n        os.environ[\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\"] = \"1\"\n    else:\n        os.environ[\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\"] = \"0\"\n\n    table_ids = (\n        ids % num_tables\n    )  # same id must have same table id, so in this case get table_ids from ids\n    if dtype == flow.float16:\n        np_dtype = np.float16\n    else:\n        np_dtype = np.float32\n    data = np.random.rand(1000, embedding_size).astype(np_dtype)\n\n    ids_tensor = flow.tensor(ids, requires_grad=False).to(\"cuda\")\n    table_ids_tensor = flow.tensor(table_ids.astype(np.int32), requires_grad=False).to(\n        \"cuda\"\n    )\n    data_tensor = flow.tensor(data, requires_grad=False).to(\"cuda\")\n\n    class TestGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, ids, table_ids, data):\n            (\n                num_unique_matrix,\n                inverse_unique_partition_indices,\n                _,\n                cur_rank_unique_ids,\n                _,\n                cur_rank_inverse_indices,\n            ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables, \"test\")\n            unique_embeddings = flow._C.gather(data, cur_rank_unique_ids, axis=0)\n            embeddings = flow._C.one_embedding_embedding_shuffle(\n                unique_embeddings,\n                num_unique_matrix,\n                cur_rank_inverse_indices,\n                inverse_unique_partition_indices,\n                \"test\",\n            )\n            return embeddings\n\n    graph = TestGraph()\n    embeddings = graph(ids_tensor, table_ids_tensor, data_tensor)\n    np_embeddings = data[ids]\n\n    # Quantized numpy embedding.\n    if enable_quantized_comm:\n        np_embeddings = embedding_shuffle_quantize(np_embeddings, np_dtype)\n    test_case.assertTrue(\n        np.allclose(embeddings.numpy(), np_embeddings, atol=1e-4, rtol=1e-4)\n    )\n\n\ndef _test_embedding_gradient_shuffle(test_case, enable_quantize, fp16, embedding_size):\n    batch_size = 512\n    num_tables = 26\n    ids = 
np.random.randint(0, 1000, (batch_size, num_tables), dtype=np.int64)\n    enable_quantized_comm = enable_quantize and embedding_size < 1025\n    if enable_quantized_comm:\n        np_tolerance = 0.5\n        os.environ[\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\"] = \"1\"\n        ids = np.arange(batch_size * num_tables, dtype=np.int64)\n        np.random.shuffle(ids)\n    else:\n        if fp16:\n            np_tolerance = 1e-2\n        else:\n            np_tolerance = 1e-4\n        os.environ[\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\"] = \"0\"\n\n    table_ids = (\n        ids % num_tables\n    )  # same id must have same table id, so in this case get table_ids from ids\n    embedding_grad = np.random.uniform(\n        low=-1, high=1, size=(batch_size, num_tables, embedding_size)\n    ).astype(np.float32)\n    ids_tensor = flow.tensor(ids, requires_grad=False).to(\"cuda\")\n    table_ids_tensor = flow.tensor(table_ids.astype(np.int32), requires_grad=False).to(\n        \"cuda\"\n    )\n    embedding_grad_tensor = flow.tensor(embedding_grad, requires_grad=False).to(\"cuda\")\n\n    class TestGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, ids, table_ids, embedding_grad):\n            (\n                num_unique_matrix,\n                inverse_unique_partition_indices,\n                _,\n                cur_rank_unique_ids,\n                _,\n                cur_rank_inverse_indices,\n            ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables, \"test\")\n            if fp16:\n                embedding_grad = flow.cast(embedding_grad, flow.float16)\n            cur_rank_unique_embedding_grad = flow._C.one_embedding_embedding_gradient_shuffle(\n                embedding_grad,\n                num_unique_matrix,\n                cur_rank_inverse_indices,\n                inverse_unique_partition_indices,\n                \"test\",\n            )\n            if fp16:\n                cur_rank_unique_embedding_grad = flow.cast(\n                    cur_rank_unique_embedding_grad, flow.float32\n                )\n            return (\n                cur_rank_unique_embedding_grad,\n                flow.cast(cur_rank_unique_ids, flow.int32),\n                flow.cast(cur_rank_inverse_indices, flow.int32),\n                flow.cast(inverse_unique_partition_indices, flow.int32),\n            )\n\n    graph = TestGraph()\n    (\n        cur_rank_unique_embedding_grad,\n        cur_rank_unique_ids,\n        cur_rank_inverse_indices,\n        inverse_unique_partition_indices,\n    ) = graph(ids_tensor, table_ids_tensor, embedding_grad_tensor)\n    np_unique_ids, np_inverse = np.unique(ids, return_inverse=True)\n    np_num_unique = np_unique_ids.size\n    np_cur_rank_unique_embedding_grad = np.zeros(\n        cur_rank_unique_embedding_grad.shape, dtype=np.float32\n    ).reshape(-1, embedding_size)\n\n    embedding_grad = embedding_grad.reshape(-1, embedding_size)\n    if fp16:\n        embedding_grad = embedding_grad.astype(np.float16)\n    for k in range(np_num_unique):\n        np_data = sum(embedding_grad[np.where(ids.flatten() == np_unique_ids[k])[0]])\n        # Quantize Embedding Gradient.\n        if enable_quantized_comm:\n            abs_max_factor = np.max(np.abs(np_data))\n            int8_factor = np.full(abs_max_factor.shape, 127.0, dtype=np.float32)\n            quantize_factor = int8_factor / abs_max_factor\n            np_data = np_data * quantize_factor\n            np_data = 
round_half_away_from_zero(np_data)\n            np_data = np_data.astype(np.int8)\n            np_data = np_data.astype(np.float32)\n            dequantize_factor = abs_max_factor / int8_factor\n            np_data = np_data * dequantize_factor\n\n        np_cur_rank_unique_embedding_grad[k, :] = np_data\n\n    reversed_ids = cur_rank_unique_ids[cur_rank_inverse_indices][\n        inverse_unique_partition_indices\n    ]\n    test_case.assertTrue(np.array_equal(reversed_ids.numpy(), ids))\n    of_cur_rank_embedding_grad = cur_rank_unique_embedding_grad[\n        cur_rank_inverse_indices\n    ][inverse_unique_partition_indices]\n    of_cur_rank_embedding_grad = flow.reshape(\n        of_cur_rank_embedding_grad, (-1, embedding_size)\n    )\n    np_cur_rank_embedding_grad = np_cur_rank_unique_embedding_grad[np_inverse]\n    if fp16:\n        np_cur_rank_embedding_grad = np_cur_rank_embedding_grad.astype(np.float32)\n\n    test_case.assertTrue(\n        np.allclose(\n            of_cur_rank_embedding_grad.numpy().flatten(),\n            np_cur_rank_embedding_grad.flatten(),\n            atol=np_tolerance,\n            rtol=np_tolerance,\n        )\n    )\n\n\ndef _test_unique_key_value(test_case, has_table_id, num_tables):\n    batch_size = 128\n    ids = np.random.randint(0, 1000, (batch_size, num_tables), dtype=np.int64)\n    if has_table_id:\n        table_ids = (\n            ids % num_tables\n        )  # same id must have same table id, so in this case get table_ids from ids\n        table_ids_tensor = flow.tensor(\n            table_ids.astype(np.int32), requires_grad=False\n        ).to(\"cuda\")\n    else:\n        table_ids_tensor = None\n    ids_tensor = flow.tensor(ids, requires_grad=False).to(\"cuda\")\n\n    class TestGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, ids, table_ids):\n            (\n                num_unique,\n                unique_ids,\n                unique_table_ids,\n                inverse_indices,\n            ) = flow._C.one_embedding_unique_key_value_pair(\n                ids, table_ids, num_tables, \"test\"\n            )\n            return (\n                flow.cast(num_unique, flow.int32),\n                flow.cast(unique_ids, flow.int32),\n                flow.cast(unique_table_ids, flow.int32),\n                flow.cast(inverse_indices, flow.int32),\n            )\n\n    graph = TestGraph()\n    (num_unique, unique_ids, unique_table_ids, inverse_indices,) = graph(\n        ids_tensor, table_ids_tensor\n    )\n    np_unique_ids, np_inverse = np.unique(ids, return_inverse=True)\n    np_num_unique = np_unique_ids.size\n    test_case.assertTrue(np.array_equal(np_num_unique, num_unique[0]))\n    reversed_ids = unique_ids[inverse_indices]\n    test_case.assertTrue(np.array_equal(reversed_ids.numpy(), ids))\n    if has_table_id:\n        reversed_table_ids = unique_table_ids[inverse_indices]\n        test_case.assertTrue(np.array_equal(reversed_table_ids.numpy(), table_ids))\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass DataShuffleTestCase(flow.unittest.TestCase):\n    def test_id_shuffle(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"has_table_id\"] = [True, False]\n        arg_dict[\"num_tables\"] = [1, 26]\n        for kwargs in GenArgDict(arg_dict):\n            _test_id_shuffle(test_case, **kwargs)\n\n    def test_embedding_shuffle(test_case):\n        arg_dict = OrderedDict()\n        
arg_dict[\"dtype\"] = [flow.float32, flow.float16]\n        arg_dict[\"enable_quantize\"] = [True, False]\n\n        for kwargs in GenArgDict(arg_dict):\n            _test_embedding_shuffle(test_case, **kwargs)\n\n    def test_embedding_gradient_shuffle(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"enable_quantize\"] = [True, False]\n        arg_dict[\"fp16\"] = [True, False]\n        arg_dict[\"embedding_size\"] = [128, 17]\n        for kwargs in GenArgDict(arg_dict):\n            _test_embedding_gradient_shuffle(test_case, **kwargs)\n\n    def test_unique_key_value(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"has_table_id\"] = [True, False]\n        arg_dict[\"num_tables\"] = [13, 26, 1]\n        for kwargs in GenArgDict(arg_dict):\n            _test_unique_key_value(test_case, **kwargs)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_id_shuffle_global.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\n\n# dynamic memory allocation can't be tested in unittest\nos.environ[\"ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION\"] = \"0\"\nimport unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgDict\nimport numpy as np\nimport oneflow as flow\n\nparallel_num = 2\nmax_id = 1000\n\n\ndef get_tensors(batch_size, num_tables):\n    placement = flow.placement(type=\"cuda\", ranks=list(range(parallel_num)))\n    ids = np.random.randint(0, max_id, (batch_size, num_tables), dtype=np.int64)\n    ids_tensor = flow.tensor(ids, requires_grad=False).to_global(\n        placement=placement, sbp=flow.sbp.split(0)\n    )\n    table_ids = (\n        ids % num_tables\n    )  # same id must have same table id, so in this case get table_ids from ids\n    table_ids_tensor = flow.tensor(\n        table_ids.astype(np.int32), requires_grad=False\n    ).to_global(placement=placement, sbp=flow.sbp.split(0))\n    return ids_tensor, table_ids_tensor\n\n\ndef _test_id_shuffle(test_case, has_table_id, num_tables):\n    batch_size = int(1024 / parallel_num)\n    placement = flow.placement(type=\"cuda\", ranks=list(range(parallel_num)))\n\n    class TestGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, ids, table_ids):\n            (\n                num_unique_matrix,\n                inverse_unique_partition_indices,\n                cur_rank_num_unique,\n                cur_rank_unique_ids,\n                cur_rank_unique_table_ids,\n                cur_rank_inverse_indices,\n            ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables, \"test\")\n            return (\n                flow.cast(num_unique_matrix, flow.int32),\n                flow.cast(inverse_unique_partition_indices, flow.int32),\n                flow.cast(cur_rank_num_unique, flow.int32),\n                flow.cast(cur_rank_unique_ids, flow.int32),\n                flow.cast(cur_rank_unique_table_ids, flow.int32),\n                flow.cast(cur_rank_inverse_indices, flow.int32),\n            )\n\n    graph = TestGraph()\n    for i in range(10):\n        ids_tensor, table_ids_tensor = get_tensors(batch_size, num_tables)\n        if not has_table_id:\n            table_ids_tensor = None\n        graph(ids_tensor, table_ids_tensor)\n    (\n        num_unique_matrix,\n        inverse_unique_partition_indices,\n        local_cur_rank_num_unique,\n        cur_rank_unique_ids,\n        cur_rank_unique_table_ids,\n        cur_rank_inverse_indices,\n    ) = graph(ids_tensor, table_ids_tensor)\n    cur_rank_num_unique = local_cur_rank_num_unique.to_local().to_global(\n        placement=placement, sbp=flow.sbp.split(0)\n    )\n    cur_rank_num_unique_list = []\n    cur_rank_unique_ids_list = []\n    cur_rank_unique_table_ids_list = []\n    cur_rank_num_ids = batch_size * num_tables * parallel_num\n    for i in 
range(parallel_num):\n        num_unique_i = cur_rank_num_unique.numpy()[i]\n        unique_ids_i = cur_rank_unique_ids.numpy()[\n            cur_rank_num_ids * i : cur_rank_num_ids * (i + 1)\n        ]\n        unique_table_ids_i = cur_rank_unique_table_ids.numpy()[\n            cur_rank_num_ids * i : cur_rank_num_ids * (i + 1)\n        ]\n        cur_rank_num_unique_list.append(num_unique_i)\n        cur_rank_unique_ids_list.append(np.array(unique_ids_i[0:num_unique_i]))\n        cur_rank_unique_table_ids_list.append(\n            np.array(unique_table_ids_i[0:num_unique_i])\n        )\n\n    global_ids = ids_tensor.numpy()\n    np_unique_ids, np_unique_index, np_inverse = np.unique(\n        global_ids, return_index=True, return_inverse=True\n    )\n    np_num_unique = np_unique_ids.size\n    # test num unique\n    test_case.assertTrue(\n        np.array_equal(np_num_unique, np.array(cur_rank_num_unique_list).sum())\n    )\n    # test unique ids\n    unique_ids = np.concatenate(cur_rank_unique_ids_list)\n    unique_ids.sort()\n    np_unique_ids.sort()\n    test_case.assertTrue(np.array_equal(unique_ids, np_unique_ids))\n    if has_table_id:\n        # test unique table ids\n        unique_table_ids = np.concatenate(cur_rank_unique_table_ids_list)\n        unique_table_ids.sort()\n        global_table_ids = table_ids_tensor.numpy()\n        np_unique_table_ids = global_table_ids.flatten()[np_unique_index]\n        np_unique_table_ids.sort()\n        test_case.assertTrue(np.array_equal(unique_table_ids, np_unique_table_ids))\n\n\ndef round_half_away_from_zero(x):\n    sign = np.sign(x)\n    abs_val = np.abs(x)\n    abs_val += 0.5\n    floor_val = np.floor(abs_val)\n    out = floor_val * sign\n    return out\n\n\ndef embedding_shuffle_quantize(np_data, np_dtype):\n\n    # When float16 is used, the compute type is still float32.\n    np_reduce_data = np_data.astype(np.float32)\n    abs_max_factor = np.max(np.abs(np_reduce_data), axis=2)\n    abs_max_factor = np.expand_dims(abs_max_factor, axis=2)\n    transport_quantize_factor = abs_max_factor.astype(np_dtype)\n    int8_factor = np.ones(abs_max_factor.shape, dtype=np.float32) * 127.0\n    int8_factor = int8_factor.astype(np.float32)\n    quantize_factor = int8_factor / abs_max_factor\n\n    # Convert to the compute type (note: astype returns a new array).\n    np_data = np_data.astype(np.float32)\n    np_data = np_data * quantize_factor\n    np_data = round_half_away_from_zero(np_data)\n    np_data = np_data.astype(np.int8)\n\n    # Convert back to the compute type for dequantization.\n    np_data = np_data.astype(np.float32)\n    dequantize_factor = transport_quantize_factor.astype(np.float32) / int8_factor\n    np_data = np_data * dequantize_factor\n    np_data = np_data.astype(np_dtype)\n    return np_data\n\n\ndef _test_embedding_shuffle(test_case, dtype, enable_quantize):\n    batch_size = int(1024 / parallel_num)\n    placement = flow.placement(type=\"cuda\", ranks=list(range(parallel_num)))\n    num_tables = 26\n    embedding_size = 128\n    enable_quantized_comm = enable_quantize and embedding_size < 1025\n    if enable_quantized_comm:\n        os.environ[\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\"] = \"1\"\n    else:\n        os.environ[\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\"] = \"0\"\n\n    if dtype == flow.float16:\n        np_dtype = np.float16\n    else:\n        np_dtype = np.float32\n    data = np.random.rand(max_id, embedding_size).astype(np_dtype)\n    data_tensor = flow.tensor(data, requires_grad=False).to_global(\n        placement=placement, sbp=flow.sbp.broadcast()\n    )\n\n    class 
TestGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, ids, table_ids, data):\n            (\n                num_unique_matrix,\n                inverse_unique_partition_indices,\n                _,\n                cur_rank_unique_ids,\n                _,\n                cur_rank_inverse_indices,\n            ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables, \"test\")\n            unique_embeddings = flow._C.gather(data, cur_rank_unique_ids, axis=0)\n            embeddings = flow._C.one_embedding_embedding_shuffle(\n                unique_embeddings,\n                flow._C.identity(num_unique_matrix),\n                flow._C.identity(cur_rank_inverse_indices),\n                flow._C.identity(inverse_unique_partition_indices),\n                \"test\",\n            )\n            return embeddings\n\n    graph = TestGraph()\n    for i in range(10):\n        ids_tensor, table_ids_tensor = get_tensors(batch_size, num_tables)\n        graph(ids_tensor, table_ids_tensor, data_tensor)\n    embeddings = graph(ids_tensor, table_ids_tensor, data_tensor)\n    global_ids = ids_tensor.numpy()\n    global_data = data_tensor.numpy()\n    np_embeddings = global_data[global_ids]\n\n    # Quantized numpy embedding.\n    if enable_quantized_comm:\n        np_embeddings = embedding_shuffle_quantize(np_embeddings, np_dtype)\n\n    test_case.assertTrue(np.array_equal(embeddings.numpy(), np_embeddings))\n\n\ndef _test_embedding_gradient_shuffle(test_case, enable_quantize, fp16, embedding_size):\n    np_tolerance = 0\n    batch_size = int(1024 / parallel_num)\n    placement = flow.placement(type=\"cuda\", ranks=list(range(parallel_num)))\n    num_tables = 26\n    enable_quantized_comm = enable_quantize and embedding_size < 1025\n    if enable_quantized_comm:\n        np_tolerance = 0.5\n        os.environ[\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\"] = \"1\"\n    else:\n        if fp16:\n            np_tolerance = 1e-2\n        else:\n            np_tolerance = 1e-4\n        os.environ[\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\"] = \"0\"\n    embedding_grad = np.random.rand(batch_size, num_tables, embedding_size).astype(\n        np.float32\n    )\n    embedding_grad_tensor = flow.tensor(embedding_grad, requires_grad=False).to_global(\n        placement=placement, sbp=flow.sbp.split(0)\n    )\n\n    class TestGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, ids, table_ids, embedding_grad):\n            (\n                num_unique_matrix,\n                inverse_unique_partition_indices,\n                cur_rank_num_unique,\n                cur_rank_unique_ids,\n                _,\n                cur_rank_inverse_indices,\n            ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables, \"test\")\n            if fp16:\n                embedding_grad = flow.cast(embedding_grad, flow.float16)\n            cur_rank_unique_embedding_grad = flow._C.one_embedding_embedding_gradient_shuffle(\n                embedding_grad,\n                num_unique_matrix,\n                cur_rank_inverse_indices,\n                inverse_unique_partition_indices,\n                \"test\",\n            )\n            if fp16:\n                cur_rank_unique_embedding_grad = flow.cast(\n                    cur_rank_unique_embedding_grad, flow.float32\n                )\n            return (\n                cur_rank_unique_embedding_grad,\n                
flow.cast(cur_rank_num_unique, flow.int32),\n                cur_rank_unique_ids,\n            )\n\n    graph = TestGraph()\n    for i in range(10):\n        ids_tensor, table_ids_tensor = get_tensors(batch_size, num_tables)\n        graph(ids_tensor, table_ids_tensor, embedding_grad_tensor)\n    ids_tensor, table_ids_tensor = get_tensors(batch_size, num_tables)\n    (\n        cur_rank_unique_embedding_grad,\n        local_cur_rank_num_unique,\n        cur_rank_unique_ids,\n    ) = graph(ids_tensor, table_ids_tensor, embedding_grad_tensor)\n    cur_rank_num_unique = local_cur_rank_num_unique.to_local().to_global(\n        placement=placement, sbp=flow.sbp.split(0)\n    )\n    global_ids = ids_tensor.numpy()\n    global_embedding_grad = embedding_grad_tensor.numpy()\n    np_unique_ids = np.unique(global_ids)\n    np_num_unique = np_unique_ids.size\n    np_cur_rank_unique_embedding_grad = np.zeros((max_id, embedding_size))\n    if fp16:\n        global_embedding_grad = global_embedding_grad.astype(np.float16)\n    for k in range(np_num_unique):\n        unique_id = np_unique_ids[k]\n        np_data = sum(\n            global_embedding_grad.reshape(-1, embedding_size)[\n                np.where(global_ids.flatten() == unique_id)[0]\n            ]\n        )\n        # Quantize Embedding Gradient.\n        if enable_quantized_comm:\n            abs_max_factor = np.max(np.abs(np_data))\n            int8_factor = np.full(abs_max_factor.shape, 127.0, dtype=np.float32)\n            quantize_factor = int8_factor / abs_max_factor\n            np_data = np_data * quantize_factor\n            np_data = round_half_away_from_zero(np_data)\n            np_data = np_data.astype(np.int8)\n            np_data = np_data.astype(np.float32)\n            dequantize_factor = abs_max_factor / int8_factor\n            np_data = np_data * dequantize_factor\n\n        np_cur_rank_unique_embedding_grad[unique_id, :] = np_data\n        if fp16:\n            np_cur_rank_unique_embedding_grad = np_cur_rank_unique_embedding_grad.astype(\n                np.float32\n            )\n\n    cur_rank_num_ids = batch_size * num_tables * parallel_num\n    of_unique_embedding_grad = np.zeros((max_id, embedding_size))\n    for i in range(parallel_num):\n        num_unique_i = cur_rank_num_unique.numpy()[i]\n        unique_ids_i = cur_rank_unique_ids.numpy()[\n            cur_rank_num_ids * i : cur_rank_num_ids * (i + 1)\n        ]\n        unique_embedding_grad_i = cur_rank_unique_embedding_grad.numpy()[\n            cur_rank_num_ids * i : cur_rank_num_ids * (i + 1)\n        ]\n        for j in range(num_unique_i):\n            unique_id = unique_ids_i[j]\n            of_unique_embedding_grad[unique_id, :] = unique_embedding_grad_i[j, :]\n\n    test_case.assertTrue(\n        np.allclose(\n            of_unique_embedding_grad,\n            np_cur_rank_unique_embedding_grad,\n            atol=np_tolerance,\n            rtol=np_tolerance,\n        ),\n    )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass DataShuffleTestCase(flow.unittest.TestCase):\n    def test_id_shuffle(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"has_table_id\"] = [True, False]\n        arg_dict[\"num_tables\"] = [1, 26]\n        for kwargs in GenArgDict(arg_dict):\n            _test_id_shuffle(test_case, **kwargs)\n\n    def test_embedding_shuffle(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"dtype\"] = [flow.float32, flow.float16]\n        
arg_dict[\"enable_quantize\"] = [True, False]\n\n        for kwargs in GenArgDict(arg_dict):\n            _test_embedding_shuffle(test_case, **kwargs)\n\n    def test_embedding_gradient_shuffle(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"enable_quantize\"] = [True, False]\n        arg_dict[\"fp16\"] = [True, False]\n        arg_dict[\"embedding_size\"] = [128, 17]\n        for kwargs in GenArgDict(arg_dict):\n            _test_embedding_gradient_shuffle(test_case, **kwargs)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_layernorm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\ninput_arr = np.array(\n    [\n        [\n            [[-0.16046895, -1.03667831], [-0.34974465, 0.26505867]],\n            [[-1.24111986, -0.53806001], [1.72426331, 0.43572459]],\n        ],\n        [\n            [[-0.77390957, -0.42610624], [0.16398858, -1.35760343]],\n            [[1.07541728, 0.11008703], [0.26361224, -0.48663723]],\n        ],\n    ],\n    dtype=np.float32,\n)\n\n\ndef _test_layernorm(test_case, device):\n    output = np.array(\n        [\n            [\n                [[-0.0544118, -1.0509688], [-0.2696846, 0.4295622]],\n                [[-1.2834904, -0.4838651], [2.0891891, 0.6236691]],\n            ],\n            [\n                [[-0.8555527, -0.3554582], [0.493019, -1.694826]],\n                [[1.8035311, 0.4155158], [0.6362644, -0.4424936]],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device))\n    m = flow.nn.LayerNorm(x.size()[1:]).to(device=flow.device(device))\n    y = m(x)\n    test_case.assertTrue(np.allclose(y.numpy(), output, 1e-05, 1e-05))\n\n\ndef _test_layernorm_v2(test_case, device):\n    output = np.array(\n        [\n            [\n                [[0.3406544, -1.5249983], [-0.0623574, 1.2467014]],\n                [[-1.2004623, -0.5688803], [1.4634399, 0.3059027]],\n            ],\n            [\n                [[-0.3180245, 0.3122248], [1.3815271, -1.3757277]],\n                [[1.497291, -0.2341234], [0.0412391, -1.3044068]],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device))\n    m = flow.nn.LayerNorm([2, 2], eps=1e-05).to(device=flow.device(device))\n    y = m(x)\n    test_case.assertTrue(np.allclose(y.numpy(), output, 1e-05, 1e-05))\n\n\ndef _test_layernorm_v3(test_case, device):\n    output = np.array(\n        [\n            [\n                [[0.999974, -0.999974], [-0.999947, 0.999947]],\n                [[-0.9999595, 0.9999595], [0.999988, -0.999988]],\n            ],\n            [\n                [[-0.9998344, 0.9998341], [0.9999914, -0.9999914]],\n                [[0.9999787, -0.9999787], [0.9999645, -0.9999645]],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device))\n    m = flow.nn.LayerNorm(2, elementwise_affine=True).to(device=flow.device(device))\n    y = m(x)\n    test_case.assertTrue(np.allclose(y.numpy(), output, 1e-05, 1e-05))\n\n\ndef _test_layernorm_backward(test_case, device):\n    output = np.array(\n        [\n            [\n                [[-0.0544118, -1.0509688], [-0.2696846, 
0.4295622]],\n                [[-1.2834904, -0.4838651], [2.0891891, 0.6236691]],\n            ],\n            [\n                [[-0.8555527, -0.3554582], [0.493019, -1.694826]],\n                [[1.8035311, 0.4155158], [0.6362644, -0.4424936]],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    x = flow.tensor(\n        input_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    m = flow.nn.LayerNorm(x.size()[1:]).to(device=flow.device(device))\n    y = m(x)\n    z = y.sum()\n    z.backward()\n    test_case.assertTrue(\n        np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-05, 1e-05)\n    )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestLayerNorm(flow.unittest.TestCase):\n    def test_layernorm(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_layernorm,\n            _test_layernorm_v2,\n            _test_layernorm_v3,\n            _test_layernorm_backward,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=20, auto_backward=True, rtol=1e-3, atol=1e-3)\n    def test_layernorm_with_random_data_warp(test_case):\n        device = \"cuda\"\n        channel = random(1, 32).to(int)\n        height = random(1, 2).to(int)\n        width = random(1, 1024).to(int)\n\n        def get_random_norm_shape():\n            begin_axis = random(1, 3).to(int).value()\n            return tuple((channel.value(), height.value(), width.value())[begin_axis:])\n\n        m = torch.nn.LayerNorm(\n            normalized_shape=get_random_norm_shape(),\n            elementwise_affine=random().to(bool),\n        ).to(device)\n        x = random_tensor(ndim=4, dim1=channel, dim2=height, dim3=width).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=10, auto_backward=True, rtol=1e-3, atol=1e-3)\n    def test_layernorm_with_random_data_shared_mem(test_case):\n        device = \"cuda\"\n        channel = random(1, 32).to(int)\n        height = random(1, 2).to(int)\n        width = random(1024, 8192).to(int)\n\n        def get_random_norm_shape():\n            begin_axis = random(1, 3).to(int).value()\n            return tuple((channel.value(), height.value(), width.value())[begin_axis:])\n\n        m = torch.nn.LayerNorm(\n            normalized_shape=get_random_norm_shape(),\n            elementwise_affine=random().to(bool),\n        ).to(device)\n        x = random_tensor(ndim=4, dim1=channel, dim2=height, dim3=width).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5, auto_backward=True, rtol=1e-3, atol=1e-3)\n    def test_layernorm_with_random_data_uncached(test_case):\n        device = \"cuda\"\n        channel = random(1, 32).to(int)\n        height = random(1, 2).to(int)\n        width = random(8192, 32768).to(int)\n\n        def get_random_norm_shape():\n            begin_axis = random(1, 3).to(int).value()\n            return tuple((channel.value(), height.value(), width.value())[begin_axis:])\n\n        m = torch.nn.LayerNorm(\n            normalized_shape=get_random_norm_shape(),\n            elementwise_affine=random().to(bool),\n        ).to(device)\n        x = random_tensor(ndim=4, dim1=channel, dim2=height, dim3=width).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=10, auto_backward=True, rtol=1e-3, atol=1e-3)\n    def 
test_layernorm_without_affine(test_case):\n        device = random_device()\n        channel = random(1, 32).to(int)\n        height = random(1, 2).to(int)\n        width = random(8192, 32768).to(int)\n\n        def get_random_norm_shape():\n            begin_axis = random(1, 3).to(int).value()\n            return tuple((channel.value(), height.value(), width.value())[begin_axis:])\n\n        m = torch.nn.LayerNorm(normalized_shape=get_random_norm_shape()).to(device)\n        x = random_tensor(ndim=4, dim1=channel, dim2=height, dim3=width).to(device)\n        y = m(x)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_oneembedding.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\n\nimport unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgDict\nimport numpy as np\nimport oneflow as flow\nimport oneflow.nn as nn\nimport tempfile\nimport hashlib\n\n\nclass OneEmbedding(nn.Module):\n    def __init__(\n        self,\n        test_id,\n        embedding_vec_size,\n        persistent_path,\n        table_size_array,\n        size_factor,\n    ):\n        assert table_size_array is not None\n        vocab_size = sum(table_size_array)\n\n        scales = np.sqrt(1 / np.array(table_size_array))\n        tables = [\n            flow.one_embedding.make_table(\n                flow.one_embedding.make_uniform_initializer(low=-scale, high=scale)\n            )\n            for scale in scales\n        ]\n        store_options = flow.one_embedding.make_device_mem_store_options(\n            persistent_path=persistent_path,\n            capacity=vocab_size,\n            size_factor=size_factor,\n        )\n\n        super(OneEmbedding, self).__init__()\n        self.one_embedding = flow.one_embedding.MultiTableEmbedding(\n            f\"oneembedding_{test_id}\",\n            embedding_dim=embedding_vec_size,\n            dtype=flow.float,\n            key_type=flow.int64,\n            tables=tables,\n            store_options=store_options,\n        )\n\n    def forward(self, ids):\n        return self.one_embedding.forward(ids)\n\n\nclass TestModule(nn.Module):\n    def __init__(\n        self,\n        test_id,\n        embedding_vec_size,\n        persistent_path,\n        table_size_array,\n        size_factor,\n    ):\n        super(TestModule, self).__init__()\n        self.embedding = OneEmbedding(\n            test_id, embedding_vec_size, persistent_path, table_size_array, size_factor\n        )\n        self.mlp = nn.Linear(embedding_vec_size, 1)\n\n    def forward(self, inputs) -> flow.Tensor:\n        embedding = self.embedding(inputs)\n        logits = self.mlp(embedding).mean(dim=1)\n        return logits\n\n\nclass TrainGraph(flow.nn.Graph):\n    def __init__(\n        self, module, loss, optimizer, amp=False,\n    ):\n        super(TrainGraph, self).__init__()\n        self.module = module\n        self.loss = loss\n        self.add_optimizer(optimizer)\n        if amp:\n            self.config.enable_amp(True)\n\n    def build(self, labels, features):\n        logits = self.module(features.to(\"cuda\"))\n        loss = self.loss(logits, labels.to(\"cuda\"))\n        reduce_loss = flow.mean(loss)\n        reduce_loss.backward()\n        return reduce_loss.to(\"cpu\")\n\n\ndef _test_one_embedding(\n    test_case, batch_size, table_size_array, embedding_size, test_opt\n):\n    test_str = str([batch_size, table_size_array, embedding_size, test_opt])\n    test_hash = hashlib.sha256(test_str.encode(\"utf-8\")).hexdigest()\n\n    def np_to_global(np):\n        t = flow.from_numpy(np)\n        
return t.to_global(placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.split(0))\n\n    with tempfile.TemporaryDirectory() as persistent_path:\n        size_factor = 3 if test_opt == \"Adam\" else 1\n        module = TestModule(\n            test_hash, embedding_size, persistent_path, table_size_array, size_factor\n        )\n        module.to_global(flow.placement.all(\"cuda\"), flow.sbp.broadcast)\n\n        if test_opt == \"Adam\":\n            opt = flow.optim.Adam(module.parameters(), lr=0.1)\n        elif test_opt == \"SGD\":\n            opt = flow.optim.SGD(module.parameters(), lr=0.1)\n        else:\n            assert False\n\n        loss = flow.nn.BCEWithLogitsLoss(reduction=\"none\").to(\"cuda\")\n\n        train_graph = TrainGraph(module, loss, opt)\n\n        module.train()\n        for step in range(1, 101):\n            labels = np.random.randint(2, size=(batch_size, 1)).astype(np.float32)\n            features = np.random.randint(\n                sum(table_size_array), size=(batch_size, len(table_size_array))\n            )\n            labels = np_to_global(labels)\n            features = np_to_global(features)\n            loss = train_graph(labels, features)\n            test_case.assertFalse(np.isnan(loss.numpy()))\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass OneEmbeddingTestCase(flow.unittest.TestCase):\n    def test_one_embedding(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"batch_size\"] = [32, 4096]\n        arg_dict[\"table_size_array\"] = [\n            [32, 65536, 100, 7],\n            [32768, 10000, 17, 3, 686],\n        ]\n        arg_dict[\"embedding_size\"] = [128, 17]\n        arg_dict[\"test_opt\"] = [\"SGD\", \"Adam\"]\n        for kwargs in GenArgDict(arg_dict):\n            _test_one_embedding(test_case, **kwargs)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_oneembedding_padding_idx.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\n\nimport unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgDict\nimport numpy as np\nimport oneflow as flow\nimport oneflow.nn as nn\nimport tempfile\nimport hashlib\nimport random\n\n\nclass OneEmbedding(nn.Module):\n    def __init__(\n        self,\n        test_id,\n        embedding_vec_size,\n        persistent_path,\n        table_size_array,\n        size_factor,\n        padding_idx,\n    ):\n        assert table_size_array is not None\n        vocab_size = sum(table_size_array)\n\n        scales = np.sqrt(1 / np.array(table_size_array))\n        tables = [\n            flow.one_embedding.make_table(\n                flow.one_embedding.make_uniform_initializer(low=-scale, high=scale)\n            )\n            for scale in scales\n        ]\n        store_options = flow.one_embedding.make_device_mem_store_options(\n            persistent_path=persistent_path,\n            capacity=vocab_size,\n            size_factor=size_factor,\n        )\n\n        super(OneEmbedding, self).__init__()\n        self.one_embedding = flow.one_embedding.MultiTableEmbedding(\n            f\"oneembedding_{test_id}\",\n            embedding_dim=embedding_vec_size,\n            dtype=flow.float,\n            key_type=flow.int64,\n            tables=tables,\n            store_options=store_options,\n            padding_idx=padding_idx,\n        )\n\n    def forward(self, ids):\n        return self.one_embedding.forward(ids)\n\n\nclass TestModule(nn.Module):\n    def __init__(\n        self,\n        test_id,\n        embedding_vec_size,\n        persistent_path,\n        table_size_array,\n        size_factor,\n        padding_idx,\n    ):\n        super(TestModule, self).__init__()\n        self.embedding = OneEmbedding(\n            test_id,\n            embedding_vec_size,\n            persistent_path,\n            table_size_array,\n            size_factor,\n            padding_idx=padding_idx,\n        )\n\n    def forward(self, inputs) -> flow.Tensor:\n        embedding = self.embedding(inputs)\n        return embedding\n\n\nclass TrainGraph(flow.nn.Graph):\n    def __init__(\n        self, module, loss, optimizer, amp=False,\n    ):\n        super(TrainGraph, self).__init__()\n        self.module = module\n        self.loss = loss\n        self.add_optimizer(optimizer)\n        if amp:\n            self.config.enable_amp(True)\n\n    def build(self, labels, features):\n        embedding = self.module(features.to(\"cuda\"))\n        reduce_loss = flow.mean(embedding)\n        reduce_loss.backward()\n        return embedding.to(\"cpu\")\n\n\ndef _test_one_embedding_padding_idx(\n    test_case, batch_size, table_size_array, embedding_size, test_opt, padding_idx\n):\n    test_str = str([batch_size, table_size_array, embedding_size, test_opt])\n    test_hash = hashlib.sha256(test_str.encode(\"utf-8\")).hexdigest()\n\n    def 
np_to_global(np_arr):\n        t = flow.from_numpy(np_arr)\n        return t.to_global(placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.split(0))\n\n    with tempfile.TemporaryDirectory() as persistent_path:\n        size_factor = 3 if test_opt == \"Adam\" else 1\n        module = TestModule(\n            test_hash,\n            embedding_size,\n            persistent_path,\n            table_size_array,\n            size_factor,\n            padding_idx,\n        )\n        module.to_global(flow.placement.all(\"cuda\"), flow.sbp.broadcast)\n\n        if test_opt == \"Adam\":\n            opt = flow.optim.Adam(module.parameters(), lr=0.1)\n        elif test_opt == \"SGD\":\n            opt = flow.optim.SGD(module.parameters(), lr=0.1)\n        else:\n            assert False\n\n        loss = flow.nn.BCEWithLogitsLoss(reduction=\"none\").to(\"cuda\")\n\n        train_graph = TrainGraph(module, loss, opt)\n\n        module.train()\n\n        padding_num = random.randint(0, batch_size - 1)\n        labels = np.random.randint(2, size=(batch_size, 1)).astype(np.float32)\n        padding_feature = np.full(\n            (len(table_size_array)), fill_value=padding_idx\n        ).astype(np.int64)\n\n        features = np.random.randint(\n            sum(table_size_array), size=(batch_size, len(table_size_array))\n        )\n        padding_feature_idx = np.random.randint(batch_size, size=(padding_num,))\n        for i in range(padding_num):\n            idx = int(padding_feature_idx[i])\n            features[idx] = padding_feature\n\n        labels = np_to_global(labels)\n        features = np_to_global(features)\n        embedding_val = train_graph(labels, features)\n        for i in range(padding_feature_idx.size):\n            idx = int(padding_feature_idx[i])\n            test_case.assertTrue(\n                np.array_equal(\n                    embedding_val[idx].numpy(),\n                    np.zeros((len(table_size_array), embedding_size), dtype=np.float32),\n                )\n            )\n\n        # Run inference again to check that the embeddings at padding_idx were not updated.\n        embedding_val = train_graph(labels, features)\n        for i in range(padding_feature_idx.size):\n            idx = int(padding_feature_idx[i])\n            test_case.assertTrue(\n                np.array_equal(\n                    embedding_val[idx].numpy(),\n                    np.zeros((len(table_size_array), embedding_size), dtype=np.float32),\n                )\n            )\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass OneEmbeddingWithPaddingIdxTestCase(flow.unittest.TestCase):\n    def test_one_embedding_padding_idx(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"batch_size\"] = [32]\n        arg_dict[\"table_size_array\"] = [\n            [32, 64, 32, 32],\n        ]\n        arg_dict[\"embedding_size\"] = [12]\n        arg_dict[\"test_opt\"] = [\"SGD\"]\n        arg_dict[\"padding_idx\"] = [2]\n        os.environ[\"ONEFLOW_TIMEOUT_SECONDS\"] = \"300\"\n        for kwargs in GenArgDict(arg_dict):\n            _test_one_embedding_padding_idx(test_case, **kwargs)\n        os.environ[\"ONEFLOW_TIMEOUT_SECONDS\"] = \"90\"\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_permute.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nfrom random import shuffle\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_permute_impl(test_case, device):\n    input = flow.tensor(\n        np.random.randn(2, 6, 5, 3),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out1 = flow.permute(input, (1, 0, 2, 3))\n    np_out = input.numpy().transpose((1, 0, 2, 3))\n    test_case.assertTrue(np.array_equal(of_out1.numpy().flatten(), np_out.flatten()))\n    of_out = of_out1.sum()\n    of_out.backward()\n    np_grad = np.ones((2, 6, 5, 3))\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 0.0001, 0.0001))\n\n\ndef _test_tensor_permute_impl(test_case, device):\n    input = flow.tensor(\n        np.random.randn(2, 6, 5, 3),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out1 = input.permute(1, 0, 2, 3)\n    of_out2 = input.permute(*(1, 0, 2, 3))\n    of_out3 = input.permute((1, 0, 2, 3))\n    of_out4 = input.permute([1, 0, 2, 3])\n    np_out = input.numpy().transpose((1, 0, 2, 3))\n    test_case.assertTrue(np.array_equal(of_out1.numpy().flatten(), np_out.flatten()))\n    test_case.assertTrue(np.array_equal(of_out2.numpy().flatten(), np_out.flatten()))\n    test_case.assertTrue(np.array_equal(of_out3.numpy().flatten(), np_out.flatten()))\n    test_case.assertTrue(np.array_equal(of_out4.numpy().flatten(), np_out.flatten()))\n    of_out = of_out1.sum()\n    of_out.backward()\n    np_grad = np.ones((2, 6, 5, 3))\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 0.0001, 0.0001))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestPermute(flow.unittest.TestCase):\n    def test_permute(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_permute_impl(test_case, *arg)\n            _test_tensor_permute_impl(test_case, *arg)\n\n    @autotest(n=10, check_graph=False)\n    def test_torch_permute4d_with_random_data(test_case):\n        device = random_device()\n        ndim = 4\n        permute_list = [0, 1, 2, 3]\n        shuffle(permute_list)\n        x = random_tensor(ndim=ndim, dim0=random().to(int)).to(device)\n        y = torch.permute(x, dims=permute_list)\n        return y\n\n    @unittest.skip(\"pytorch 1.9.0 exist not torch.permute api\")\n    @autotest(n=10)\n    def test_torch_permute4d_with_random_0dim_data(test_case):\n        device = random_device()\n        permute_list = [0, 1, 2, 3]\n        shuffle(permute_list)\n        x = random_tensor(ndim=0).to(device)\n        y = torch.permute(x, dims=permute_list)\n        return y\n\n    @autotest(n=10, check_graph=True)\n  
  def test_permute5d_tensor_with_random_data(test_case):\n        device = random_device()\n        ndim = 5\n        permute_list = [0, 1, 2, 3, 4]\n        shuffle(permute_list)\n        x = random_tensor(\n            ndim=ndim,\n            dim0=random(1, 16).to(int),\n            dim1=random(1, 33).to(int),\n            dim2=random(1, 64).to(int),\n            dim3=random(45, 67).to(int),\n            dim4=random(1, 64).to(int),\n        ).to(device)\n        y = x.permute(permute_list)\n        return y\n\n    @autotest(n=10, check_graph=True)\n    def test_permute4d_tensor_with_random_data(test_case):\n        device = random_device()\n        ndim = 4\n        permute_list = [0, 1, 2, 3]\n        shuffle(permute_list)\n        x = random_tensor(\n            ndim=ndim,\n            dim0=random(1, 7).to(int),\n            dim1=random(1, 15).to(int),\n            dim2=random(1, 9).to(int),\n            dim3=random(1, 19).to(int),\n        ).to(device)\n        y = x.permute(permute_list)\n        return y\n\n    @autotest(n=10, check_graph=True)\n    def test_permute4d_tensor_with_stride(test_case):\n        device = random_device()\n        ndim = 4\n        permute_list1 = [0, 1, 2, 3]\n        shuffle(permute_list1)\n        x = random_tensor(\n            ndim=ndim,\n            dim0=random(1, 7).to(int),\n            dim1=random(1, 15).to(int),\n            dim2=random(1, 9).to(int),\n            dim3=random(1, 19).to(int),\n        ).to(device)\n        y = x.permute(permute_list1)\n        permute_list2 = [0, 1, 2, 3]\n        shuffle(permute_list2)\n        z = y.permute(permute_list2)\n        return z\n\n    @autotest(n=5, check_graph=True)\n    def test_permute3d_tensor_with_random_data(test_case):\n        device = random_device()\n        ndim = 3\n        permute_list = [0, 1, 2]\n        shuffle(permute_list)\n        x = random_tensor(\n            ndim=ndim,\n            dim0=random(1, 18).to(int),\n            dim1=random(1, 78).to(int),\n            dim2=random(1, 99).to(int),\n        ).to(device)\n        y = x.permute(permute_list)\n        return y\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_permute4d_tensor_bool_with_random_data(test_case):\n        device = random_device()\n        ndim = 4\n        permute_list = [0, 1, 2, 3]\n        shuffle(permute_list)\n        x = random_tensor(\n            ndim=ndim,\n            dim0=random(1, 7).to(int),\n            dim1=random(1, 15).to(int),\n            dim2=random(1, 9).to(int),\n            dim3=random(1, 19).to(int),\n        ).to(device=device, dtype=torch.bool)\n        y = x.permute(permute_list)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_remat.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport subprocess\nimport sys\nimport os\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestRemat(flow.unittest.TestCase):\n    def test_remat_in_single_threaded_vm(test_case):\n        env = os.environ.copy()\n        env[\"ONEFLOW_VM_MULTI_THREAD\"] = \"0\"\n        p = subprocess.run(\n            [sys.executable, \"_test_remat.py\"],\n            cwd=os.path.dirname(os.path.realpath(__file__)),\n            env=env,\n        )\n        test_case.assertEqual(p.returncode, 0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_resnet50_with_bn.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\n\nfrom resnet50_model import resnet50\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestResNet50(flow.unittest.TestCase):\n    def test_resnet50_with_batchnorm(test_case):\n        batch_size = 32\n        color_space = \"RGB\"\n        height = 224\n        width = 224\n        output_layout = \"NCHW\"\n        rgb_mean = [123.68, 116.779, 103.939]\n        rgb_std = [58.393, 57.12, 57.375]\n        record_reader = flow.nn.OFRecordReader(\n            flow.unittest.dataset_dir(\"imagenette/ofrecord\"),\n            batch_size=batch_size,\n            data_part_num=1,\n            part_name_suffix_length=5,\n            shuffle_after_epoch=False,\n        )\n        record_image_decoder = flow.nn.OFRecordImageDecoder(\n            \"encoded\", color_space=color_space\n        )\n        record_label_decoder = flow.nn.OFRecordRawDecoder(\n            \"class/label\", shape=(), dtype=flow.int32\n        )\n        resize = flow.nn.image.Resize(\n            resize_side=\"shorter\", keep_aspect_ratio=True, target_size=256\n        )\n        crop_mirror_normal = flow.nn.CropMirrorNormalize(\n            color_space=color_space,\n            output_layout=output_layout,\n            crop_h=height,\n            crop_w=width,\n            crop_pos_y=0.5,\n            crop_pos_x=0.5,\n            mean=rgb_mean,\n            std=rgb_std,\n            output_dtype=flow.float,\n        )\n        res50_module = resnet50(\n            replace_stride_with_dilation=[False, False, False],\n            norm_layer=flow.nn.BatchNorm2d,\n        )\n        res50_module.train()\n        res50_module.load_state_dict(\n            flow.load(flow.unittest.dataset_dir(\"imagenette/resnet50_models\"))\n        )\n        of_corss_entropy = flow.nn.CrossEntropyLoss()\n        res50_module.to(\"cuda\")\n        of_corss_entropy.to(\"cuda\")\n        learning_rate = 0.001\n        mom = 0.9\n        of_sgd = flow.optim.SGD(\n            res50_module.parameters(), lr=learning_rate, momentum=mom\n        )\n        errors = 0.0\n        for b in range(100):\n            val_record = record_reader()\n            label = record_label_decoder(val_record)\n            image_raw_buffer = record_image_decoder(val_record)\n            image = resize(image_raw_buffer)[0]\n            image = crop_mirror_normal(image)\n            image = image.to(\"cuda\")\n            label = label.to(\"cuda\")\n            logits = res50_module(image)\n            loss = of_corss_entropy(logits, label)\n            loss.backward()\n            of_sgd.step()\n            of_sgd.zero_grad()\n            l = loss.numpy()\n        test_case.assertTrue(l < 3.5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_resnet50_without_bn.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\n\nimport numpy as np\nfrom resnet50_model import FakeBN, resnet50\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestResNet50(flow.unittest.TestCase):\n    def test_resnet50_without_batchnorm(test_case):\n        batch_size = 32\n        color_space = \"RGB\"\n        height = 224\n        width = 224\n        output_layout = \"NCHW\"\n        rgb_mean = [123.68, 116.779, 103.939]\n        rgb_std = [58.393, 57.12, 57.375]\n        record_reader = flow.nn.OFRecordReader(\n            flow.unittest.dataset_dir(\"imagenette/ofrecord\"),\n            batch_size=batch_size,\n            data_part_num=1,\n            part_name_suffix_length=5,\n            shuffle_after_epoch=False,\n        )\n        record_image_decoder = flow.nn.OFRecordImageDecoder(\n            \"encoded\", color_space=color_space\n        )\n        record_label_decoder = flow.nn.OFRecordRawDecoder(\n            \"class/label\", shape=(), dtype=flow.int32\n        )\n        resize = flow.nn.image.Resize(\n            resize_side=\"shorter\", keep_aspect_ratio=True, target_size=256\n        )\n        crop_mirror_normal = flow.nn.CropMirrorNormalize(\n            color_space=color_space,\n            output_layout=output_layout,\n            crop_h=height,\n            crop_w=width,\n            crop_pos_y=0.5,\n            crop_pos_x=0.5,\n            mean=rgb_mean,\n            std=rgb_std,\n            output_dtype=flow.float,\n        )\n        res50_module = resnet50(\n            replace_stride_with_dilation=[False, False, False], norm_layer=FakeBN\n        )\n        res50_module.train()\n        res50_module.load_state_dict(\n            flow.load(flow.unittest.dataset_dir(\"resnet50_wo_bn_weights_for_ci\"))\n        )\n        of_corss_entropy = flow.nn.CrossEntropyLoss()\n        res50_module.to(\"cuda\")\n        of_corss_entropy.to(\"cuda\")\n        learning_rate = 0.001\n        mom = 0.9\n        of_sgd = flow.optim.SGD(\n            res50_module.parameters(), lr=learning_rate, momentum=mom\n        )\n        gt_of_losses = [\n            49.83235168457031,\n            36.34172821044922,\n            23.585250854492188,\n            15.628865242004395,\n            9.552209854125977,\n            8.11514663696289,\n            6.364114284515381,\n            6.442500114440918,\n            4.439807891845703,\n            4.024901866912842,\n            4.7038373947143555,\n            4.253284454345703,\n            4.5806169509887695,\n            4.158677577972412,\n            3.0066077709198,\n            4.611920356750488,\n            4.46696138381958,\n            2.9725658893585205,\n            3.2383458614349365,\n            3.605447292327881,\n            3.8676259517669678,\n            3.2477705478668213,\n       
     2.9191272258758545,\n            3.162745475769043,\n            3.0127673149108887,\n            2.615905284881592,\n            2.7866411209106445,\n            3.471228837966919,\n            2.9467897415161133,\n            3.3623316287994385,\n        ]\n        for b in range(len(gt_of_losses)):\n            val_record = record_reader()\n            label = record_label_decoder(val_record)\n            image_raw_buffer = record_image_decoder(val_record)\n            image = resize(image_raw_buffer)[0]\n            image = crop_mirror_normal(image)\n            image = image.to(\"cuda\")\n            label = label.to(\"cuda\")\n            logits = res50_module(image)\n            loss = of_cross_entropy(logits, label)\n            loss.backward()\n            of_sgd.step()\n            of_sgd.zero_grad()\n            loss_np = loss.numpy()\n            test_case.assertTrue(\n                np.allclose(loss_np.item(), gt_of_losses[b], rtol=1e-2, atol=1e-3)\n            )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_rnn.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\n\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRNNModules(flow.unittest.TestCase):\n    @autotest(n=5, check_graph=True, rtol=1e-2, atol=1e-3)\n    def test_rnn(test_case):\n        device = random_device()\n        batch_size = random(1, 6)\n        time_steps = random(1, 6)\n        num_layers = random(1, 6).to(int)\n        input_size = random(2, 6).to(int)\n        hidden_size = random(2, 6).to(int)\n        m = torch.nn.RNN(\n            input_size,\n            hidden_size,\n            num_layers=num_layers,\n            nonlinearity=\"tanh\",\n            bias=random().to(bool),\n            batch_first=random().to(bool),\n            dropout=0,\n            bidirectional=random().to(bool),\n        ).to(device)\n        input = random_tensor(\n            ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size\n        ).to(device)\n        out = m(input)\n        return out[0]\n\n    @autotest(n=5, check_graph=True, rtol=1e-2)\n    def test_lstm(test_case):\n        device = random_device()\n        batch_size = random(1, 6)\n        time_steps = random(1, 6)\n        num_layers = random(1, 6).to(int)\n        input_size = random(2, 6).to(int)\n        hidden_size = random(2, 6).to(int)\n        proj_size = random(2, 6).to(int)\n        m = torch.nn.LSTM(\n            input_size=input_size,\n            hidden_size=hidden_size,\n            num_layers=num_layers,\n            bias=random().to(bool),\n            batch_first=random().to(bool),\n            dropout=0,\n            bidirectional=random().to(bool),\n            proj_size=proj_size,\n        ).to(device)\n        input = random_tensor(\n            ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size\n        ).to(device)\n        out = m(input)\n        return out[0]\n\n    @autotest(n=5, check_graph=True, rtol=1e-2)\n    def test_gru(test_case):\n        device = random_device()\n        batch_size = random(1, 6)\n        time_steps = random(1, 6)\n        num_layers = random(1, 6).to(int)\n        input_size = random(2, 6).to(int)\n        hidden_size = random(2, 6).to(int)\n        m = torch.nn.GRU(\n            input_size=input_size,\n            hidden_size=hidden_size,\n            num_layers=num_layers,\n            bias=random().to(bool),\n            batch_first=random().to(bool),\n            dropout=0,\n            bidirectional=random().to(bool),\n        ).to(device)\n        input = random_tensor(\n            ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size\n        ).to(device)\n        out = m(input)\n        return out[0]\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_rnn_cell.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\n\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRNN(flow.unittest.TestCase):\n    @autotest(n=5, check_graph=True, rtol=1e-2, atol=1e-3)\n    def test_rnn_tanh_cell(test_case):\n        device = random_device()\n        batch_size = random(1, 6)\n        time_steps = random(1, 6)\n        input_size = random(1, 6) * 2\n        hidden_size = random(1, 6) * 2\n        m = torch.nn.RNNCell(\n            input_size=input_size,\n            hidden_size=hidden_size,\n            bias=random().to(bool),\n            nonlinearity=\"tanh\",\n        ).to(device)\n        input = random_tensor(\n            ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size\n        ).to(device)\n        hx = random_tensor(ndim=2, dim0=batch_size, dim1=hidden_size).to(device)\n        for i in range(time_steps.to(int).value()):\n            hx = m(input[i], hx)\n        return hx\n\n    @autotest(n=5, check_graph=True)\n    def test_rnn_relu_cell(test_case):\n        device = random_device()\n        batch_size = random(1, 6)\n        time_steps = random(1, 6)\n        input_size = random(1, 6) * 2\n        hidden_size = random(1, 6) * 2\n        m = torch.nn.RNNCell(\n            input_size=input_size,\n            hidden_size=hidden_size,\n            bias=random().to(bool),\n            nonlinearity=\"relu\",\n        ).to(device)\n        input = random_tensor(\n            ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size\n        ).to(device)\n        hx = random_tensor(ndim=2, dim0=batch_size, dim1=hidden_size).to(device)\n        for i in range(time_steps.to(int).value()):\n            hx = m(input[i], hx)\n        return hx\n\n    @unittest.skip(\"skip for now, becase it failed 4 times in past week\")\n    @autotest(n=5, check_graph=True, rtol=1e-2)\n    def test_lstm_cell(test_case):\n        device = random_device()\n        batch_size = random(1, 6)\n        time_steps = random(1, 6)\n        input_size = random(1, 6) * 2\n        hidden_size = random(1, 6) * 2\n        has_bias = random().to(bool)\n        cx_requires_grad = random().to(bool)\n        m = torch.nn.LSTMCell(\n            input_size=input_size, hidden_size=hidden_size, bias=has_bias,\n        ).to(device)\n        input = random_tensor(\n            ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size\n        ).to(device)\n        hx = random_tensor(\n            ndim=2, dim0=batch_size, dim1=hidden_size, requires_grad=False\n        ).to(device)\n        cx = random_tensor(\n            ndim=2, dim0=batch_size, dim1=hidden_size, requires_grad=cx_requires_grad\n        ).to(device)\n\n        for i in range(time_steps.to(int).value()):\n            res = m(input[i], (hx, cx))\n            hx = res[0]\n            cx = res[1]\n        return res[0]\n\n    @autotest(n=5, 
check_graph=True, rtol=1e-2)\n    def test_gru_cell(test_case):\n        device = random_device()\n        batch_size = random(1, 6)\n        time_steps = random(1, 6)\n        input_size = random(1, 6) * 2\n        hidden_size = random(1, 6) * 2\n        has_bias = random().to(bool)\n        m = torch.nn.GRUCell(\n            input_size=input_size, hidden_size=hidden_size, bias=has_bias\n        ).to(device)\n        input = random_tensor(\n            ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size\n        ).to(device)\n        hx = random_tensor(ndim=2, dim0=batch_size, dim1=hidden_size).to(device)\n        for i in range(time_steps.to(int).value()):\n            hx = m(input[i], hx)\n        return hx\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_rnn_pack_sequence.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport random\nimport numpy as np\nfrom collections import OrderedDict\nimport torch\nimport torch.nn.utils.rnn as torch_rnn_utils\n\nimport oneflow as flow\nimport oneflow.nn.utils.rnn as flow_rnn_utils\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_rnn_pack_sequence(test_case, device):\n    l = [\"tanh\", \"relu\"]\n    input_size = random.randint(10, 1000)\n    hidden_size = random.randint(10, 1000)\n    num_layers = random.randint(1, 6)\n    nonlinearity = l[0 if num_layers <= 3 else 1]\n    grad_tol = 1e-4\n    if nonlinearity == \"relu\":\n        grad_tol = 100\n    bias = random.randint(-10, 10) <= 0\n    batch_first = False\n    dropout = 0\n    bidirectional = random.randint(-10, 10) <= 0\n\n    rnn_torch = torch.nn.RNN(\n        input_size=input_size,\n        hidden_size=hidden_size,\n        num_layers=num_layers,\n        nonlinearity=nonlinearity,\n        bias=bias,\n        batch_first=batch_first,\n        dropout=dropout,\n        bidirectional=bidirectional,\n    )\n\n    rnn_flow = flow.nn.RNN(\n        input_size=input_size,\n        hidden_size=hidden_size,\n        num_layers=num_layers,\n        nonlinearity=nonlinearity,\n        bias=bias,\n        batch_first=batch_first,\n        dropout=dropout,\n        bidirectional=bidirectional,\n    )\n\n    torch_state_dict = rnn_torch.state_dict()\n    new_dict = {}\n    for k, v in torch_state_dict.items():\n        new_dict[k] = v.detach().numpy()\n    rnn_flow.load_state_dict(new_dict)\n\n    rnn_flow = rnn_flow.to(device)\n    rnn_torch = rnn_torch.to(device)\n\n    max_seq_len = random.randint(10, 50)\n    batch_size = random.randint(10, 50)\n    lengths = []\n    lengths.append(max_seq_len)\n    for i in range(batch_size - 1):\n        lengths.append(random.randint(1, max_seq_len))\n    lengths.sort(reverse=True)\n\n    sequences = []\n    for i in range(batch_size):\n        sequences.append(flow.rand(lengths[i], input_size).to(device))\n\n    x_flow = flow_rnn_utils.pack_sequence(sequences)\n    torch_inputs = [torch.tensor(ft.numpy(), device=device) for ft in sequences]\n    x_torch = torch_rnn_utils.pack_sequence(torch_inputs)\n\n    out_torch, hid_torch = rnn_torch(x_torch)\n    out_flow, hid_flow = rnn_flow(x_flow)\n\n    z_torch = out_torch.data.sum()\n    z_torch.backward()\n    z_flow = out_flow.data.sum()\n    z_flow.backward()\n\n    test_case.assertTrue(\n        np.allclose(\n            out_torch.data.cpu().detach().numpy(),\n            out_flow.data.cpu().detach().numpy(),\n            atol=1e-5,\n        )\n    )\n\n    test_case.assertTrue(\n        np.allclose(\n            hid_torch.cpu().detach().numpy(),\n            hid_flow.cpu().detach().numpy(),\n            atol=1e-5,\n        )\n    )\n\n    all_weights = rnn_torch.all_weights\n    torch_params = []\n    for ls in all_weights:\n        for l in ls:\n       
     torch_params.append(l)\n    all_weights = rnn_flow.all_weights\n    flow_params = []\n    for ls in all_weights:\n        for l in ls:\n            flow_params.append(l)\n\n    for i in range(len(flow_params)):\n        torch_np = torch_params[i].grad.cpu().numpy()\n        flow_np = flow_params[i].grad.cpu().numpy()\n        test_case.assertTrue(np.allclose(torch_np, flow_np, atol=grad_tol))\n\n\ndef _test_lstm_pack_sequence(test_case, device):\n    input_size = random.randint(10, 1000)\n    hidden_size = random.randint(12, 1000)\n    num_layers = random.randint(1, 6)\n    bias = random.randint(-10, 10) <= 0\n    batch_first = False\n    dropout = 0\n    bidirectional = random.randint(-10, 10) <= 0\n    proj_size = random.randint(0, hidden_size - 1)\n\n    lstm_torch = torch.nn.LSTM(\n        input_size=input_size,\n        hidden_size=hidden_size,\n        num_layers=num_layers,\n        bias=bias,\n        batch_first=batch_first,\n        dropout=dropout,\n        bidirectional=bidirectional,\n        proj_size=proj_size,\n    )\n\n    lstm_flow = flow.nn.LSTM(\n        input_size=input_size,\n        hidden_size=hidden_size,\n        num_layers=num_layers,\n        bias=bias,\n        batch_first=batch_first,\n        dropout=dropout,\n        bidirectional=bidirectional,\n        proj_size=proj_size,\n    )\n\n    torch_state_dict = lstm_torch.state_dict()\n    new_dict = {}\n    for k, v in torch_state_dict.items():\n        new_dict[k] = v.detach().numpy()\n    lstm_flow.load_state_dict(new_dict)\n\n    lstm_flow = lstm_flow.to(device)\n    lstm_torch = lstm_torch.to(device)\n\n    max_seq_len = random.randint(10, 50)\n    batch_size = random.randint(10, 50)\n    lengths = []\n    lengths.append(max_seq_len)\n    for i in range(batch_size - 1):\n        lengths.append(random.randint(1, max_seq_len))\n    lengths.sort(reverse=True)\n\n    sequences = []\n    for i in range(batch_size):\n        sequences.append(flow.rand(lengths[i], input_size).to(device))\n\n    x_flow = flow_rnn_utils.pack_sequence(sequences)\n    torch_inputs = [torch.tensor(ft.numpy(), device=device) for ft in sequences]\n    x_torch = torch_rnn_utils.pack_sequence(torch_inputs)\n\n    out_torch, hid_torch = lstm_torch(x_torch)\n    out_flow, hid_flow = lstm_flow(x_flow)\n\n    z_torch = out_torch.data.sum()\n    z_torch.backward()\n    z_flow = out_flow.data.sum()\n    z_flow.backward()\n\n    test_case.assertTrue(\n        np.allclose(\n            out_torch.data.cpu().detach().numpy(),\n            out_flow.data.cpu().detach().numpy(),\n            atol=1e-5,\n        )\n    )\n\n    test_case.assertTrue(\n        np.allclose(\n            hid_torch[0].cpu().detach().numpy(),\n            hid_flow[0].cpu().detach().numpy(),\n            atol=1e-5,\n        )\n    )\n\n    test_case.assertTrue(\n        np.allclose(\n            hid_torch[1].cpu().detach().numpy(),\n            hid_flow[1].cpu().detach().numpy(),\n            atol=1e-5,\n        )\n    )\n\n    all_weights = lstm_torch.all_weights\n    torch_params = []\n    for ls in all_weights:\n        for l in ls:\n            torch_params.append(l)\n    all_weights = lstm_flow.all_weights\n    flow_params = []\n    for ls in all_weights:\n        for l in ls:\n            flow_params.append(l)\n\n    for i in range(len(flow_params)):\n        torch_np = torch_params[i].grad.cpu().numpy()\n        flow_np = flow_params[i].grad.cpu().numpy()\n        test_case.assertTrue(np.allclose(torch_np, flow_np, atol=1e-4))\n\n\ndef 
_test_gru_pack_sequence(test_case, device):\n    input_size = random.randint(10, 1000)\n    hidden_size = random.randint(10, 1000)\n    num_layers = random.randint(1, 6)\n    grad_tol = 1e-4\n    bias = random.randint(-10, 10) <= 0\n    batch_first = False\n    dropout = 0\n    bidirectional = random.randint(-10, 10) <= 0\n\n    gru_torch = torch.nn.GRU(\n        input_size=input_size,\n        hidden_size=hidden_size,\n        num_layers=num_layers,\n        bias=bias,\n        batch_first=batch_first,\n        dropout=dropout,\n        bidirectional=bidirectional,\n    )\n\n    gru_flow = flow.nn.GRU(\n        input_size=input_size,\n        hidden_size=hidden_size,\n        num_layers=num_layers,\n        bias=bias,\n        batch_first=batch_first,\n        dropout=dropout,\n        bidirectional=bidirectional,\n    )\n\n    torch_state_dict = gru_torch.state_dict()\n    new_dict = {}\n    for k, v in torch_state_dict.items():\n        new_dict[k] = v.detach().numpy()\n    gru_flow.load_state_dict(new_dict)\n\n    gru_flow = gru_flow.to(device)\n    gru_torch = gru_torch.to(device)\n\n    max_seq_len = random.randint(10, 50)\n    batch_size = random.randint(10, 50)\n    lengths = []\n    lengths.append(max_seq_len)\n    for i in range(batch_size - 1):\n        lengths.append(random.randint(1, max_seq_len))\n    lengths.sort(reverse=True)\n\n    sequences = []\n    for i in range(batch_size):\n        sequences.append(flow.rand(lengths[i], input_size).to(device))\n\n    x_flow = flow_rnn_utils.pack_sequence(sequences)\n    torch_inputs = [torch.tensor(ft.numpy(), device=device) for ft in sequences]\n    x_torch = torch_rnn_utils.pack_sequence(torch_inputs)\n\n    out_torch, hid_torch = gru_torch(x_torch)\n    out_flow, hid_flow = gru_flow(x_flow)\n\n    z_torch = out_torch.data.sum()\n    z_torch.backward()\n    z_flow = out_flow.data.sum()\n    z_flow.backward()\n\n    test_case.assertTrue(\n        np.allclose(\n            out_torch.data.cpu().detach().numpy(),\n            out_flow.data.cpu().detach().numpy(),\n            atol=1e-5,\n        )\n    )\n\n    test_case.assertTrue(\n        np.allclose(\n            hid_torch.cpu().detach().numpy(),\n            hid_flow.cpu().detach().numpy(),\n            atol=1e-5,\n        )\n    )\n\n    all_weights = gru_torch.all_weights\n    torch_params = []\n    for ls in all_weights:\n        for l in ls:\n            torch_params.append(l)\n    all_weights = gru_flow.all_weights\n    flow_params = []\n    for ls in all_weights:\n        for l in ls:\n            flow_params.append(l)\n\n    for i in range(len(flow_params)):\n        torch_np = torch_params[i].grad.cpu().numpy()\n        flow_np = flow_params[i].grad.cpu().numpy()\n        test_case.assertTrue(np.allclose(torch_np, flow_np, atol=grad_tol))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRNNModules(flow.unittest.TestCase):\n    def test_rnn(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_rnn_pack_sequence,\n            _test_lstm_pack_sequence,\n            _test_gru_pack_sequence,\n        ]\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for i in range(5):\n            for arg in GenArgList(arg_dict):\n                arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
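The three tests above share one setup step: cloning torch's freshly initialized weights into the oneflow module so both frameworks start from identical parameters. A minimal sketch of that pattern (the hyperparameter values here are illustrative, not from the tests):

    import torch
    import oneflow as flow

    rnn_torch = torch.nn.RNN(input_size=8, hidden_size=16, num_layers=2)
    rnn_flow = flow.nn.RNN(input_size=8, hidden_size=16, num_layers=2)

    # As in the tests, torch parameters are detached to numpy and loaded
    # into the oneflow module of the same architecture.
    numpy_state = {k: v.detach().numpy() for k, v in rnn_torch.state_dict().items()}
    rnn_flow.load_state_dict(numpy_state)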
  {
    "path": "python/oneflow/test/expensive/test_rnn_utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport random\nimport numpy as np\nfrom collections import OrderedDict\nimport torch\nimport torch.nn.utils.rnn as torch_rnn_utils\n\nimport oneflow as flow\nimport oneflow.nn.utils.rnn as flow_rnn_utils\n\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_rnn_utils_pack_padded_sequence(test_case, device):\n    input_size = random.randint(10, 150)\n    max_seq_len = random.randint(10, 300)\n    batch_size = random.randint(10, 300)\n    requires_grad = np.random.rand() > 0.5\n    padded_inputs = np.zeros((max_seq_len, batch_size, input_size))\n    lengths = []\n    lengths.append(max_seq_len)\n    for i in range(batch_size - 1):\n        lengths.append(random.randint(1, max_seq_len))\n    lengths.sort(reverse=True)\n\n    for i in range(batch_size):\n        padded_inputs[0 : lengths[i], i : i + 1, :] = i + 1\n\n    inputs = flow.from_numpy(padded_inputs).to(device)\n    inputs.requires_grad = requires_grad\n    flow_res = flow_rnn_utils.pack_padded_sequence(inputs, lengths)\n\n    torch_inputs = torch.from_numpy(padded_inputs).to(device)\n    torch_inputs.requires_grad = requires_grad\n    torch_res = torch_rnn_utils.pack_padded_sequence(torch_inputs, lengths)\n\n    test_case.assertTrue(\n        np.allclose(\n            torch_res.batch_sizes.cpu().detach().numpy(),\n            flow_res.batch_sizes.cpu().detach().numpy(),\n            atol=1e-8,\n        )\n    )\n\n    test_case.assertTrue(\n        np.allclose(\n            torch_res.data.cpu().detach().numpy(),\n            flow_res.data.cpu().detach().numpy(),\n            atol=1e-8,\n        )\n    )\n\n    torch_seq_unpacked, torch_lens_unpacked = torch_rnn_utils.pad_packed_sequence(\n        torch_res, batch_first=False\n    )\n    flow_seq_unpacked, flow_lens_unpacked = flow_rnn_utils.pad_packed_sequence(\n        flow_res, batch_first=False\n    )\n\n    if requires_grad:\n        torch_seq_unpacked.sum().backward()\n        flow_seq_unpacked.sum().backward()\n\n    test_case.assertTrue(\n        np.allclose(\n            torch_seq_unpacked.cpu().detach().numpy(),\n            flow_seq_unpacked.cpu().detach().numpy(),\n            atol=1e-8,\n        )\n    )\n\n    test_case.assertTrue(\n        np.allclose(\n            torch_lens_unpacked.cpu().detach().numpy(),\n            flow_lens_unpacked.cpu().detach().numpy(),\n            atol=1e-8,\n        )\n    )\n\n    if requires_grad:\n        test_case.assertTrue(\n            np.allclose(inputs.grad.cpu().numpy(), torch_inputs.grad.cpu().numpy())\n        )\n\n\ndef _test_rnn_utils_pad_sequence(test_case, device):\n    input_size = random.randint(10, 150)\n    max_seq_len = random.randint(20, 300)\n    batch_size = random.randint(20, 300)\n    lengths = []\n    lengths.append(max_seq_len)\n    for i in range(batch_size - 1):\n        lengths.append(random.randint(1, max_seq_len))\n    
lengths.sort(reverse=True)\n\n    sequences = []\n    for i in range(batch_size):\n        sequences.append(flow.rand(lengths[i], input_size).to(device))\n\n    flow_res = flow_rnn_utils.pad_sequence(sequences)\n\n    torch_inputs = [torch.tensor(ft.numpy(), device=device) for ft in sequences]\n    torch_res = torch_rnn_utils.pad_sequence(torch_inputs)\n\n    test_case.assertTrue(\n        np.allclose(\n            torch_res.cpu().detach().numpy(),\n            flow_res.cpu().detach().numpy(),\n            atol=1e-8,\n        )\n    )\n\n\ndef _test_rnn_utils_pack_sequence(test_case, device):\n    input_size = random.randint(10, 150)\n    max_seq_len = random.randint(20, 300)\n    batch_size = random.randint(20, 300)\n    lengths = []\n    lengths.append(max_seq_len)\n    for i in range(batch_size - 1):\n        lengths.append(random.randint(1, max_seq_len))\n    lengths.sort(reverse=True)\n\n    sequences = []\n    for i in range(batch_size):\n        sequences.append(flow.rand(lengths[i], input_size).to(device))\n\n    flow_res = flow_rnn_utils.pack_sequence(sequences)\n\n    torch_inputs = [torch.tensor(ft.numpy(), device=device) for ft in sequences]\n    torch_res = torch_rnn_utils.pack_sequence(torch_inputs)\n\n    test_case.assertTrue(\n        np.allclose(\n            torch_res.batch_sizes.cpu().detach().numpy(),\n            flow_res.batch_sizes.cpu().detach().numpy(),\n            atol=1e-8,\n        )\n    )\n\n    test_case.assertTrue(\n        np.allclose(\n            torch_res.data.cpu().detach().numpy(),\n            flow_res.data.cpu().detach().numpy(),\n            atol=1e-8,\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRNNUtils(flow.unittest.TestCase):\n    def test_rnn_utils_pack_padded_sequence(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for i in range(10):\n            for arg in GenArgList(arg_dict):\n                _test_rnn_utils_pack_padded_sequence(test_case, *arg[0:])\n\n    def test_rnn_utils_pad_sequence(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for i in range(10):\n            for arg in GenArgList(arg_dict):\n                _test_rnn_utils_pad_sequence(test_case, *arg[0:])\n\n    def test_rnn_utils_pack_sequence(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for i in range(10):\n            for arg in GenArgList(arg_dict):\n                _test_rnn_utils_pack_sequence(test_case, *arg[0:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
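The assertions above rely on the layout of a packed sequence: `data` concatenates the active time steps across the batch, and `batch_sizes[t]` counts how many sequences are still running at step t. A minimal sketch, assuming sequences already sorted by length in descending order:

    import oneflow as flow
    import oneflow.nn.utils.rnn as rnn_utils

    padded = flow.ones((3, 2, 4))   # (max_seq_len=3, batch=2, features=4)
    lengths = [3, 1]                # per-sequence lengths, descending
    packed = rnn_utils.pack_padded_sequence(padded, lengths)
    # batch_sizes should be [2, 1, 1]: both sequences are active at step 0,
    # only the longer one at steps 1 and 2.
    seq, lens = rnn_utils.pad_packed_sequence(packed)  # round-trips to (3, 2, 4)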
  {
    "path": "python/oneflow/test/expensive/test_sqrt_square_sum.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLinalgVectorNorm2D(flow.unittest.TestCase):\n    @autotest(n=2, auto_backward=False, check_graph=True, rtol=0.5, atol=0.5)\n    def test_sqrt_sum_with_cpu_random_data(test_case):\n        device = cpu_device()\n        x = random_tensor(ndim=4, dim1=3, dim2=4, dim3=5, requires_grad=False).to(\n            device\n        )\n        y = torch.linalg.norm(x)\n        return y\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @autotest(n=2, auto_backward=False, check_graph=True)\n    def test_sqrt_sum_with_cuda_random_data(test_case):\n        device = gpu_device()\n        x = random_tensor(ndim=4, dim1=10, dim2=10, dim3=10, requires_grad=False).to(\n            device\n        )\n        y = torch.linalg.norm(x)\n        return y\n\n    @autotest(n=2, auto_backward=False, check_graph=True, rtol=0.5, atol=0.5)\n    def test_scalar_print_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dim1=3, dim2=4, dim3=5, requires_grad=False).to(\n            device\n        )\n        y = torch.linalg.norm(x)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_tensor_offload.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\n\n# NOTE(Li Xiang): This variable controls the mem comparison method of the tensor offload test.\n#  1: Strictly test, compare mem changes according to tensor size.\n#  2: Loose test, compare mem changes before and after offload;\n#  3: Execute only offload, skip mem check.\noffload_tensor_test_mem_mode = 3\n\n\ndef _test_tensor_offload_d2h(test_case, input, tensor_mem):\n    print(\"\\n- test offload cuda mem use\")\n    test_case.assertTrue(not input.is_offloaded())\n\n    before_used = flow._oneflow_internal.GetCUDAMemoryUsed()\n    print(\"  - before \", before_used)\n    before_id = id(input)\n\n    input.offload()\n    test_case.assertTrue(input.is_offloaded())\n    test_case.assertEqual(input.device, flow.device(\"cuda\"))\n    after_used = flow._oneflow_internal.GetCUDAMemoryUsed()\n    after_id = id(input)\n    print(\"  - after \", after_used)\n    change_as_expected = (before_used - after_used) == tensor_mem\n    # Check tensor_mem cuda memory released\n    if offload_tensor_test_mem_mode == 1:\n        test_case.assertTrue(change_as_expected)\n    elif offload_tensor_test_mem_mode == 2:\n        if tensor_mem != 0:\n            test_case.assertTrue(before_used > after_used)\n    print(\"  - tensor size \", tensor_mem)\n    print(\"  - change \", after_used - before_used)\n    print(\"  - change as expected \", change_as_expected)\n    test_case.assertEqual(before_id, after_id)\n\n\ndef _test_tensor_load_h2d(test_case, input, tensor_mem):\n    print(\"\\n- test load cuda mem use\")\n    test_case.assertTrue(input.is_offloaded())\n\n    before_used = flow._oneflow_internal.GetCUDAMemoryUsed()\n    print(\"  - before \", before_used)\n    before_id = id(input)\n\n    input.load()\n    test_case.assertTrue(not input.is_offloaded())\n    test_case.assertEqual(input.device, flow.device(\"cuda\"))\n    after_used = flow._oneflow_internal.GetCUDAMemoryUsed()\n    after_id = id(input)\n    print(\"  - after \", after_used)\n    # Check tensor_mem cuda memory allocated\n    change_as_expected = (after_used - before_used) == tensor_mem\n    if offload_tensor_test_mem_mode == 1:\n        test_case.assertTrue(change_as_expected)\n    elif offload_tensor_test_mem_mode == 2:\n        if tensor_mem != 0:\n            test_case.assertTrue(after_used > before_used)\n    print(\"  - tensor size \", tensor_mem)\n    print(\"  - change \", after_used - before_used)\n    print(\"  - change as expected \", change_as_expected)\n    test_case.assertEqual(before_id, after_id)\n\n\ndef _get_tensor_mem(input):\n    if input.dim() == 0:\n        return 2\n    cnt_size = input.element_size() * flow.numel(input)\n    return cnt_size / 1024 / 1024\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu 
cases\")\nclass TestTensorOffload(flow.unittest.TestCase):\n    def test_tensor_offload_and_load_float32(test_case):\n        flow.cuda.empty_cache()\n        input = flow.tensor(\n            np.random.randn(1024, 1024, 100),\n            dtype=flow.float32,\n            device=flow.device(\"cuda\"),\n        )\n        data = input.numpy()\n\n        for i in range(3):\n            input_tensor_mem = _get_tensor_mem(input)\n            # test tensor offload\n            _test_tensor_offload_d2h(test_case, input, input_tensor_mem)\n\n            # data = input.numpy() will raise error here\n\n            # test tensor load\n            _test_tensor_load_h2d(test_case, input, input_tensor_mem)\n\n        # test data after tensor load\n        test_case.assertTrue(np.allclose(input.numpy(), data, rtol=0.0001, atol=0.0001))\n\n    def test_tensor_offload_and_load_float16(test_case):\n        flow.cuda.empty_cache()\n        input = flow.tensor(\n            np.random.randn(20, 1024, 1024),\n            dtype=flow.float16,\n            device=flow.device(\"cuda\"),\n        )\n        data = input.numpy()\n\n        for i in range(3):\n            input_tensor_mem = _get_tensor_mem(input)\n            # test tensor offload\n            _test_tensor_offload_d2h(test_case, input, input_tensor_mem)\n\n            # data = input.numpy() will raise error here\n\n            # test tensor load\n            _test_tensor_load_h2d(test_case, input, input_tensor_mem)\n\n        # test data after tensor load\n        test_case.assertTrue(np.allclose(input.numpy(), data, rtol=0.0001, atol=0.0001))\n\n    def test_tensor_offload_and_load_int64(test_case):\n        flow.cuda.empty_cache()\n        input = flow.tensor(\n            np.random.randn(20, 1024, 1024),\n            dtype=flow.int64,\n            device=flow.device(\"cuda\"),\n        )\n        data = input.numpy()\n\n        for i in range(3):\n            input_tensor_mem = _get_tensor_mem(input)\n            # test tensor offload\n            _test_tensor_offload_d2h(test_case, input, input_tensor_mem)\n\n            # data = input.numpy() will raise error here\n\n            # test tensor load\n            _test_tensor_load_h2d(test_case, input, input_tensor_mem)\n\n        # test data after tensor load\n        test_case.assertTrue(np.allclose(input.numpy(), data, rtol=0.0001, atol=0.0001))\n\n    @unittest.skip(\"0 dim tensor is unstable in CI container mem tests.\")\n    def test_tensor_offload_and_load_0dim(test_case):\n        flow.cuda.empty_cache()\n        input = flow.tensor(\n            np.random.randint(1, 10), dtype=flow.float16, device=flow.device(\"cuda\"),\n        )\n        data = input.numpy()\n\n        for i in range(3):\n            input_tensor_mem = _get_tensor_mem(input)\n            # test tensor offload\n            _test_tensor_offload_d2h(test_case, input, input_tensor_mem)\n\n            # data = input.numpy() will raise error here\n\n            # test tensor load\n            _test_tensor_load_h2d(test_case, input, input_tensor_mem)\n\n        # test data after tensor load\n        test_case.assertTrue(np.allclose(input.numpy(), data, rtol=0.0001, atol=0.0001))\n\n    def test_tensor_offload_and_load_0size(test_case):\n        flow.cuda.empty_cache()\n        input = flow.tensor(\n            np.random.randn(0, 1024, 1024),\n            dtype=flow.float16,\n            device=flow.device(\"cuda\"),\n        )\n        data = input.numpy()\n\n        for i in range(3):\n            input_tensor_mem = 0\n        
    # test tensor offload\n            _test_tensor_offload_d2h(test_case, input, input_tensor_mem)\n\n            # data = input.numpy() will raise error here\n\n            # test tensor load\n            _test_tensor_load_h2d(test_case, input, input_tensor_mem)\n\n        # test data after tensor load\n        test_case.assertTrue(np.allclose(input.numpy(), data, rtol=0.0001, atol=0.0001))\n\n    def test_tensor_offload_and_load_cpu_mem(test_case):\n        input = flow.tensor(\n            np.random.randn(1024, 1024, 100),\n            dtype=flow.float32,\n            device=flow.device(\"cuda\"),\n        )\n\n        before_used = flow._oneflow_internal.GetCPUMemoryUsed()\n        before_id = id(input)\n        input.offload()\n        after_used = flow._oneflow_internal.GetCPUMemoryUsed()\n        after_id = id(input)\n        if offload_tensor_test_mem_mode == 2:\n            test_case.assertTrue(after_used > before_used)\n        elif offload_tensor_test_mem_mode == 3:\n            print(\"cpu mem change value:\", after_used - before_used)\n        test_case.assertEqual(before_id, after_id)\n\n        cur_used = flow._oneflow_internal.GetCPUMemoryUsed()\n        before_id = id(input)\n        input.load()\n        after_used = flow._oneflow_internal.GetCPUMemoryUsed()\n        after_id = id(input)\n        if offload_tensor_test_mem_mode == 2:\n            test_case.assertTrue(after_used < cur_used)\n        elif offload_tensor_test_mem_mode == 3:\n            print(\"cpu mem change value:\", cur_used - after_used)\n        test_case.assertEqual(before_id, after_id)\n\n    def test_param_offload(test_case):\n        def load_eager_model(model):\n            for param in model.parameters():\n                print(\"\\n- test param load cuda mem use\")\n                test_case.assertTrue(param.is_offloaded())\n                before_used = flow._oneflow_internal.GetCUDAMemoryUsed()\n                print(\"  - before \", before_used)\n                param.load()\n                after_used = flow._oneflow_internal.GetCUDAMemoryUsed()\n                print(\"  - after \", after_used)\n                tensor_mem = _get_tensor_mem(param)\n                change_as_expected = (after_used - before_used) == tensor_mem\n                print(\"  - tensor size \", tensor_mem)\n                print(\"  - change \", after_used - before_used)\n                print(\"  - change as expected \", change_as_expected)\n                test_case.assertTrue(not param.is_offloaded())\n\n        def offload_eager_model(model):\n            for param in model.parameters():\n                print(\"\\n- test param offload cuda mem use\")\n                test_case.assertTrue(not param.is_offloaded())\n                before_used = flow._oneflow_internal.GetCUDAMemoryUsed()\n                print(\"  - before \", before_used)\n                param.offload()\n                after_used = flow._oneflow_internal.GetCUDAMemoryUsed()\n                print(\"  - after \", after_used)\n                tensor_mem = _get_tensor_mem(param)\n                change_as_expected = (before_used - after_used) == tensor_mem\n                print(\"  - tensor size \", tensor_mem)\n                print(\"  - change \", after_used - before_used)\n                print(\"  - change as expected \", change_as_expected)\n                test_case.assertTrue(param.is_offloaded())\n\n        class Model(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.n_layer = 1\n\n   
             layer_list = list()\n\n                for _ in range(self.n_layer):\n                    # Too small to seem mem change\n                    layer_list.append(nn.Linear(768, 4096))\n                    # Big enough to seem mem change\n                    layer_list.append(nn.Linear(4096, 4096))\n\n                self.layers = nn.Sequential(*layer_list)\n\n            def forward(self, x):\n                return self.layers(x)\n\n        model0 = Model().cuda()\n        BZ = 128\n        dataset = [flow.rand((BZ, 768), dtype=flow.float32) for _ in range(128)]\n\n        with flow.no_grad():\n            for idx, x in enumerate(dataset):\n                print(f\"iter {idx} begin\")\n                x = x.cuda()\n\n                if idx != 0:\n                    # no need to load at first iter\n                    load_eager_model(model0)\n                y0 = model0(x)\n                offload_eager_model(model0)\n\n                print(f\"iter {idx} end\")\n                if idx == 1:\n                    break\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
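A minimal usage sketch of the offload API exercised above, assuming an ordinary CUDA tensor: the tensor object itself is unchanged (same id, device still reported as cuda), only its storage moves between device and host.

    import oneflow as flow

    t = flow.rand(1024, 1024).cuda()
    assert not t.is_offloaded()
    t.offload()   # releases the CUDA storage, keeps a host copy
    assert t.is_offloaded() and t.device == flow.device("cuda")
    t.load()      # re-allocates on CUDA and copies the data back
    assert not t.is_offloaded()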
  {
    "path": "python/oneflow/test/expensive/test_tensor_str.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow import tensor\nimport oneflow\n\n\ndef _test_local_tensor_str(test_case, device):\n    # int dtype\n    x = flow.tensor([[1, 2, 3], [4, 5, -6]], device=flow.device(device))\n    tensor_str = str(x)\n    test_case.assertTrue(\"3\" in tensor_str)\n    test_case.assertTrue(\"5\" in tensor_str)\n    test_case.assertTrue(\"-6\" in tensor_str)\n    test_case.assertTrue(\"2\" in str(x[0][1]))\n    test_case.assertTrue(np.allclose(eval(tensor_str).numpy(), x.numpy()))\n\n    # empty\n    x = flow.tensor([], device=flow.device(device))\n    tensor_str = str(x)\n    test_case.assertTrue(\"[]\" in tensor_str)\n    test_case.assertTrue(np.allclose(eval(tensor_str).numpy(), x.numpy()))\n\n    # scientific representation int_mode(val == np.ceil(val))\n    x = flow.tensor(\n        [[1, 2, 3], [4, 5, 600000]], device=flow.device(device), dtype=flow.float64\n    )\n    tensor_str = str(x)\n    test_case.assertTrue(\"6.0000e+05\" in tensor_str)\n    test_case.assertTrue(np.allclose(eval(tensor_str).numpy(), x.numpy()))\n\n    # int_mode\n    x = flow.tensor(\n        [[1.0, 2.0, 3.0], [4.0, 5, 60]], device=flow.device(device), dtype=flow.float64\n    )\n    tensor_str = str(x)\n    test_case.assertTrue(\"4.\" in tensor_str)\n    test_case.assertTrue(\"60.\" in tensor_str)\n    test_case.assertTrue(np.allclose(eval(tensor_str).numpy(), x.numpy()))\n\n    # float dtype\n    x = flow.tensor(\n        [[1.3, 2.4, 3.5], [-4.6, 5, 60]], device=flow.device(device), dtype=flow.float64\n    )\n    tensor_str = str(x)\n    test_case.assertTrue(\"3.5000\" in tensor_str)\n    test_case.assertTrue(\"-4.6000\" in tensor_str)\n    test_case.assertTrue(\"60.0000\" in tensor_str)\n    test_case.assertTrue(np.allclose(eval(tensor_str).numpy(), x.numpy()))\n\n    # scientific representation float dtype\n    x = flow.tensor(\n        [[1.3, 2.4, 3.5], [-4.6, 5, 60000000]],\n        device=flow.device(device),\n        dtype=flow.float64,\n    )\n    tensor_str = str(x)\n    test_case.assertTrue(\"2.4000e+00\" in tensor_str)\n    test_case.assertTrue(\"3.5000e+00\" in tensor_str)\n    test_case.assertTrue(\"-4.6000e+00\" in tensor_str)\n    test_case.assertTrue(\"6.0000e+07\" in tensor_str)\n    test_case.assertTrue(np.allclose(eval(tensor_str).numpy(), x.numpy()))\n\n    # summarized data float dtype\n    x = flow.tensor(\n        np.ones((100, 100, 100)), device=flow.device(device), dtype=flow.float64\n    )\n    tensor_str = str(x)\n    test_case.assertTrue(\"1\" in tensor_str)\n    test_case.assertTrue(\"...\" in tensor_str)\n\n\ndef _test_global_tensor_str(test_case, device):\n    placement = flow.placement(device, range(1))\n    # split global tensor\n    x = flow.ones((10, 10), placement=placement, 
sbp=[flow.sbp.split(0)])\n    tensor_str = str(x)\n    test_case.assertTrue(\"1.\" in tensor_str)\n\n    # broadcast global tensor\n    x = flow.ones((10, 10), placement=placement, sbp=[flow.sbp.broadcast])\n    tensor_str = str(x)\n    test_case.assertTrue(\"1.\" in tensor_str)\n\n    # partial_sum global tensor\n    x = flow.ones((10, 10), placement=placement, sbp=[flow.sbp.partial_sum])\n    tensor_str = str(x)\n    test_case.assertTrue(\"1.\" in tensor_str)\n\n    # summarized global tensor\n    x = flow.ones((100, 100), placement=placement, sbp=[flow.sbp.split(0)])\n    tensor_str = str(x)\n    test_case.assertTrue(\"1.\" in tensor_str)\n    test_case.assertTrue(\"...\" in tensor_str)\n\n    # empty global tensor\n    x = flow.ones((0, 10), placement=placement, sbp=[flow.sbp.split(0)])\n    tensor_str = str(x)\n    test_case.assertTrue(\"[]\" in tensor_str)\n\n\ndef _test_global_tensor_str_2d(test_case, device):\n    placement = flow.placement(device, range(2))\n    x = flow.ones((10, 10), placement=placement, sbp=[flow.sbp.split(0)])\n    tensor_str = str(x)\n    test_case.assertTrue(\"1.\" in tensor_str)\n\n    x = flow.ones((10, 10), placement=placement, sbp=[flow.sbp.broadcast])\n    tensor_str = str(x)\n    test_case.assertTrue(\"1.\" in tensor_str)\n    # TODO: x[0][0].to(\"cuda\") has bug\n    # test_case.assertTrue(\"1.\" in str(x[0][0]))\n\n    x = flow.ones((10, 10), placement=placement, sbp=[flow.sbp.partial_sum])\n    tensor_str = str(x)\n    test_case.assertTrue(\"1.\" in tensor_str)\n\n    x = flow.ones((100, 100), placement=placement, sbp=[flow.sbp.split(0)])\n    tensor_str = str(x)\n    test_case.assertTrue(\"1.\" in tensor_str)\n    # TODO: this test has bug\n    # test_case.assertTrue(\"...\" in tensor_str)\n\n    x = flow.ones((100, 100), placement=placement, sbp=[flow.sbp.split(1)])\n    tensor_str = str(x)\n    test_case.assertTrue(\"1.\" in tensor_str)\n    # TODO: this test has bug\n    # test_case.assertTrue(\"...\" in tensor_str)\n\n    x = flow.ones(\n        (10, 10), placement=flow.placement(device, ranks=[0]), sbp=[flow.sbp.broadcast]\n    )\n    tensor_str = str(x)\n    test_case.assertTrue(\"1.\" in tensor_str)\n\n    x = flow.ones((2, 5), placement=placement, sbp=[flow.sbp.split(0)])\n    tensor_str = str(x)\n    test_case.assertTrue(\"1.\" in tensor_str)\n\n\ndef _test_nd_sbp_tensor_str(test_case, device, sbp0, sbp1):\n    placement = flow.placement(type=device, ranks=[[0, 1], [2, 3]])\n    sbp = [sbp0, sbp1]\n    x = flow.ones((20, 20), placement=placement, sbp=sbp)\n    tensor_str = str(x)\n    test_case.assertTrue(str(sbp0) in tensor_str)\n    test_case.assertTrue(str(sbp1) in tensor_str)\n\n\nclass TestTensorStrModule(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    @unittest.skip(\"TODO: fengwei, this often fails\")\n    def test_local_tensor_str_1n1d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_local_tensor_str,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n1d()\n    @unittest.skip(\"TODO: fengwei, this often fails\")\n    def test_global_tensor_str_1n1d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_global_tensor_str,\n        ]\n        arg_dict[\"device\"] = [\"cuda\"]\n        for arg in 
GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n2d()\n    def test_tensor_str_1n2d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_global_tensor_str_2d,\n        ]\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n4d()\n    def test_nd_sbp_tensor_str(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_nd_sbp_tensor_str,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n\n        sbp_arg_dict = OrderedDict()\n        sbp_list = [\n            flow.sbp.broadcast,\n            flow.sbp.split(0),\n            flow.sbp.partial_sum,\n        ]\n        sbp_arg_dict[\"sbp0\"] = sbp_list\n        sbp_arg_dict[\"sbp1\"] = sbp_list\n        for arg in GenArgList(arg_dict):\n            for sbp in GenArgList(sbp_arg_dict):\n                arg[0](test_case, *(arg[1:] + sbp[:]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/expensive/test_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport itertools\nimport os\nfrom collections import OrderedDict\nfrom collections.abc import Iterable\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef GenCartesianProduct(sets):\n    assert isinstance(sets, Iterable)\n    for set in sets:\n        assert isinstance(set, Iterable)\n        if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            if \"cuda\" in set:\n                set.remove(\"cuda\")\n    return itertools.product(*sets)\n\n\ndef GenArgList(arg_dict):\n    assert isinstance(arg_dict, OrderedDict)\n    assert all([isinstance(x, list) for x in arg_dict.values()])\n    sets = [arg_set for (_, arg_set) in arg_dict.items()]\n    return GenCartesianProduct(sets)\n\n\ndef GenArgDict(arg_dict):\n    return [dict(zip(arg_dict.keys(), x)) for x in GenArgList(arg_dict)]\n\n\nclass Args:\n    def __init__(self, flow_args, tf_args=None):\n        super().__init__()\n        if tf_args is None:\n            tf_args = flow_args\n        self.flow_args = flow_args\n        self.tf_args = tf_args\n\n    def __str__(self):\n        return \"flow_args={} tf_args={}\".format(self.flow_args, self.tf_args)\n\n    def __repr__(self):\n        return self.__str__()\n\n\ntype_name_to_flow_type = {\n    \"float16\": flow.float16,\n    \"float32\": flow.float32,\n    \"double\": flow.double,\n    \"int8\": flow.int8,\n    \"int32\": flow.int32,\n    \"int64\": flow.int64,\n    \"uint8\": flow.uint8,\n}\ntype_name_to_np_type = {\n    \"float16\": np.float16,\n    \"float32\": np.float32,\n    \"double\": np.float64,\n    \"int8\": np.int8,\n    \"int32\": np.int32,\n    \"int64\": np.int64,\n    \"uint8\": np.uint8,\n}\n\n\ndef FlattenArray(input_array):\n    output_array = list()\n    for x in np.nditer(input_array):\n        output_array.append(x.tolist())\n    return output_array\n\n\ndef Array2Numpy(input_array, target_shape):\n    return np.array(input_array).reshape(target_shape, order=\"C\")\n\n\ndef Index2Coordinate(idx, tensor_shape):\n    coordinate = []\n    tmp = idx\n    for i in range(len(tensor_shape) - 1, -1, -1):\n        axis_size = tensor_shape[i]\n        coor = tmp % axis_size\n        coordinate.insert(0, int(coor))\n        tmp = (tmp - coor) / axis_size\n    return coordinate\n\n\ndef Coordinate2Index(coordinate, tensor_shape):\n    if len(coordinate) != len(tensor_shape):\n        raise \"wrong coordinate or shape\"\n    idx = 0\n    for (i, coor) in enumerate(coordinate):\n        size_at_axis = coor\n        for j in range(i + 1, len(tensor_shape)):\n            size_at_axis *= tensor_shape[j]\n        idx += size_at_axis\n    return idx\n\n\ndef generate_graph(func):\n    class Graph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, *args):\n            return func(*args)\n\n    return Graph()\n"
  },
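A short usage sketch of the helpers above: `GenArgList` expands an OrderedDict of option lists into the Cartesian product of argument tuples, and `GenCartesianProduct` drops "cuda" entries when ONEFLOW_TEST_CPU_ONLY is set. The option values here are illustrative:

    from collections import OrderedDict

    arg_dict = OrderedDict()
    arg_dict["device"] = ["cpu", "cuda"]
    arg_dict["dtype"] = ["float32", "int64"]
    # Yields ("cpu", "float32"), ("cpu", "int64"), ("cuda", "float32"), ("cuda", "int64")
    for args in GenArgList(arg_dict):
        print(args)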
  {
    "path": "python/oneflow/test/gen_ops_process.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport subprocess\nimport glob\nimport re\n\n\ndef get_api(rst_dir):\n    \"\"\"\n    Extract operator names from rst files.\n\n    `currentmodule` is not regarded as operators.\n    `autoclass` and `automodule` are regarded as operators in the absence of `members`.\n    \"\"\"\n    op_files = glob.glob(rst_dir + \"/*.rst\")\n    op_files.remove(rst_dir + \"/graph.rst\")\n    op_files.remove(rst_dir + \"/index.rst\")\n    api_list = []\n    api_str = \"\"\n    for op_file in op_files:\n        with open(op_file, \"r\") as f:\n            line = f.readline()\n            pre = \"\"\n            while line:\n                skip = False\n                if \".. currentmodule::\" in line:\n                    pre = line.strip().replace(\".. currentmodule::\", \"\") + \".\"\n                elif \".. autofunction::\" in line:\n                    if \"oneflow\" not in line:\n                        api_str += pre\n                    api_str += line.replace(\".. autofunction::\", \"\")\n                elif (\n                    \".. autosummary::\" in line\n                    or \".. autoclass::\" in line\n                    or \":toctree:\" in line\n                    or \":nosignatures:\" in line\n                    or \":template:\" in line\n                ):\n                    if \":nosignatures:\" in line:\n                        line = f.readline()\n                        if \":template:\" in line:\n                            line = f.readline()\n                        line = f.readline()\n                        while line and len(line.replace(\" \", \"\")) > 1:\n                            if \"oneflow\" not in line:\n                                api_str += pre\n                            api_str += line\n                            line = f.readline()\n                elif \".. automodule::\" in line:\n                    pre_a = line.replace(\".. 
automodule:: \", \"\")\n                    line = f.readline()\n                    skip = True\n                    if \":members:\" in line and len(line) > 14:\n                        pre_a = pre_a.strip() + \".\"\n                        api_str += pre_a + line.replace(\":members:\", \"\")\n                        line = f.readline()\n                        while (\n                            line and \":\" not in line and len(line.replace(\" \", \"\")) > 1\n                        ):\n                            api_str += pre_a + line\n                            line = f.readline()\n                if not skip:\n                    line = f.readline()\n\n    api_list = api_str.strip().replace(\" \", \"\").replace(\",\", \"\").split(\"\\n\")\n    return api_list\n\n\ndef get_profile_func(path):\n    \"\"\"\n    Iterate through files under `path` to find out all operator names,\n    and update code links to file_func_map_list by file_func_map.\n    \"\"\"\n    files = os.listdir(path)\n    commit_bytes = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"])\n    commit_str = commit_bytes.decode(\"utf-8\").replace(\"\\n\", \"\")\n    result_profile_func_list = []\n    for file in files:\n        if file != \"log\" and not os.path.isdir(file) and file.find(\"__pycache__\") == -1:\n            f = open(os.path.join(path, file))\n            last_line = \"\"\n            iter_f = iter(f)\n            line_num = 1\n            for line in iter_f:\n                line = line.strip()\n                match = re.fullmatch(r\"^@profile\\((.+)\\)$\", line)\n                if match:\n                    tem_profile = match.group(1)\n                    tem_profile_name = tem_profile.split(\".\")[-1]\n                    result_profile_func_list.append(tem_profile_name)\n\n    return result_profile_func_list\n\n\ndef get_test_func(path):\n    \"\"\"\n    Iterate through files under `path` to find out all operator names,\n    and update code links to file_func_map_list by file_func_map.\n    \"\"\"\n    files = os.listdir(path)\n    commit_bytes = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"])\n    commit_str = commit_bytes.decode(\"utf-8\").replace(\"\\n\", \"\")\n    result_func_list = []\n    for file in files:\n        if file != \"log\" and not os.path.isdir(file) and file.find(\"__pycache__\") == -1:\n            f = open(os.path.join(path, file))\n            last_line = \"\"\n            iter_f = iter(f)\n            line_num = 1\n            for line in iter_f:\n                line = line.strip()\n                rem = re.match(\"def .*?(test_.*)\\(test_case.*\", line)\n                if rem and \"#\" not in line:\n                    func_name = rem.group(1).replace(\"_test_\", \"\").replace(\"test_\", \"\")\n                    result_func_list.append(func_name)\n                    file_func_map[func_name] = (\n                        f\" [{func_name}](\"\n                        + \"https://github.com/Oneflow-Inc/oneflow/blob/\"\n                        + commit_str\n                        + \"/python/oneflow/test/\"\n                        + path\n                        + \"/\"\n                        + file\n                        + f\"#L{line_num}) \"\n                    )\n                elif last_line.startswith(\"add_docstr\"):\n                    result_func_list.append(line[0:-1])\n                    file_func_map[line[0:-1]] = (\n                        f\" [{line[0:-1]}](\"\n                        + 
\"https://github.com/Oneflow-Inc/oneflow/blob/\"\n                        + commit_str\n                        + \"/python/oneflow/test/\"\n                        + path\n                        + \"/\"\n                        + file\n                        + f\"#L{line_num}) \"\n                    )\n                last_line = line\n                line_num += 1\n    return result_func_list\n\n\ndef pure_match(x, y):\n    \"\"\"\n    Check whether x contains y.\n\n    The purpose of identifying \".\" is to accurately match operator documents.\n    For example, if we make pos = x.find(y) while y = clip_, either oneflow.Tensor.clip or oneflow.Tensor.clip_ is right.\n\n    Besides, identifying \"_\" is important.\n    For example, if we make pos = x.find(y) while y = squeeze, either test of squeeze or unsqueeze is right.\n    \"\"\"\n    x = x.lower()\n    y = y.lower()\n    pos = -1\n    if \".\" in x:\n        x = x.split(\".\")\n        for i in x:\n            if i == y:\n                pos = 1\n                break\n    elif \"_\" in y:\n        pos = x.find(y)\n    else:\n        x = x.split(\"_\")\n        for i in x:\n            if i == y:\n                pos = 1\n                break\n    return pos != -1\n\n\ndef match_test_func(func, func_list):\n    \"\"\"\n    func: operator name\n    func_list: names of all operators\n\n    Check whether func_list contains func. If yes, return matching content, or else return \"\".\n    \"\"\"\n    match_res = \"\"\n    for i in range(len(func_list)):\n        if pure_match(func_list[i], func):\n            match_res = func_list[i]\n            break\n    return match_res\n\n\nif __name__ == \"__main__\":\n    api_list = get_api(\"../../../docs/source\")\n    dir_list = [\n        [\"../../../python/oneflow/framework/docstr\"],\n        [\"../../../python/oneflow/test/modules\", \"../../../python/oneflow/test/tensor\"],\n        [\"../../../python/oneflow/test/exceptions\"],\n    ]\n    num_cols = 4\n    test_func_list = list()\n    test_profile_list = list()\n    file_func_map = dict()\n    file_func_map_list = []\n\n    for i in range(0, len(dir_list)):\n        tmp_func_list = list()\n        tmp_profile_list = list()\n        file_func_map = dict()\n        for path in dir_list[i]:\n            tmp_func_list.extend(get_test_func(path))\n            tmp_profile_list.extend(get_profile_func(path))\n        test_func_list.append(tmp_func_list)\n        test_profile_list.extend(tmp_profile_list)\n        file_func_map_list.append(file_func_map)\n\n    result_list = []\n    result_list.append(f\"## Ops Version : Alpha\")\n    result_list.append(f\"\")\n    result_list.append(f\"\")\n    table_head = f\"| Op Name | Doc Test | Compatiable/Completeness Test | Exception | Performance Test |\"\n    result_list.append(table_head)\n    result_list.append(\n        f\"| ------------------------- | ------------- | ----------------------------- | --------- | ---------------- |\"\n    )\n\n    cnt0 = 0  # the number of doc_test\n    cnt1 = 0  # the number of compatiable_completeness_test\n    cnt2 = 0  # the number of exception_test\n    cnt3 = 0  # the number of profile_test\n\n    for name in api_list:\n        table_line = f\"| {name} |\"\n        name = name.split(\".\")[-1]\n        for i in range(3):\n            match_name = match_test_func(name, test_func_list[i])\n            if match_name != \"\":\n                if i == 0:\n                    cnt0 += 1\n                elif i == 1:\n                    cnt1 += 1\n                
else:\n                    cnt2 += 1\n                table_line += file_func_map_list[i][match_name]\n            table_line += \"  |\"\n        if name in test_profile_list:\n            table_line += \" done \"\n            cnt3 += 1\n        table_line += \"  |\"\n\n        result_list.append(table_line)\n\n    doc_test_ratio = cnt0 / len(api_list)\n    compatiable_completeness_test_ratio = cnt1 / len(api_list)\n    exception_test_ratio = cnt2 / len(api_list)\n    performance_test_ratio = cnt3 / len(api_list)\n\n    result_list.append(f\"## Test Data Summary\")\n    result_list.append(f\"- OneFlow Total API Number: {len(api_list)}\")\n    result_list.append(\n        f\"- Doc Test Ratio: {100*doc_test_ratio:.2f}% ({cnt0} / {len(api_list)})\"\n    )\n    result_list.append(\n        f\"- Compatiable/Completeness Test Ratio: {100*compatiable_completeness_test_ratio:.2f}% ({cnt1} / {len(api_list)})\"\n    )\n    result_list.append(\n        f\"- Exception Test Ratio: {100*exception_test_ratio:.2f}% ({cnt2} / {len(api_list)})\"\n    )\n    result_list.append(\n        f\"- Performance Test Ratio: {100*performance_test_ratio:.2f}% ({cnt3} / {len(api_list)})\"\n    )\n    f = open(\"./README.md\", \"w\")\n    for line in result_list:\n        f.write(line + \"\\n\")\n    f.close()\n"
  },
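A few hypothetical calls illustrating the matching rules implemented by `pure_match` above: dotted names are compared segment by segment, so a plain `clip` does not match `clip_`, and underscore-separated test names are compared the same way, so `squeeze` does not match `unsqueeze`.

    assert pure_match("oneflow.Tensor.clip_", "clip_")
    assert not pure_match("oneflow.Tensor.clip", "clip_")
    assert pure_match("tensor_squeeze", "squeeze")
    assert not pure_match("tensor_unsqueeze", "squeeze")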
  {
    "path": "python/oneflow/test/graph/alexnet_model.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nimport oneflow.nn as nn\n\nfrom typing import Any\n\n\n__all__ = [\"AlexNet\", \"alexnet\"]\n\n\nclass AlexNet(nn.Module):\n    def __init__(self, num_classes: int = 1000) -> None:\n        super(AlexNet, self).__init__()\n        self.features = nn.Sequential(\n            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(kernel_size=3, stride=2),\n            nn.Conv2d(64, 192, kernel_size=5, padding=2),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(kernel_size=3, stride=2),\n            nn.Conv2d(192, 384, kernel_size=3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(384, 256, kernel_size=3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(256, 256, kernel_size=3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(kernel_size=3, stride=2),\n        )\n        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))\n        self.classifier = nn.Sequential(\n            nn.Dropout(),\n            nn.Linear(256 * 6 * 6, 4096),\n            nn.ReLU(inplace=True),\n            nn.Dropout(),\n            nn.Linear(4096, 4096),\n            nn.ReLU(inplace=True),\n            nn.Linear(4096, num_classes),\n        )\n\n    def forward(self, x: flow.Tensor) -> flow.Tensor:\n        x = self.features(x)\n        x = self.avgpool(x)\n        x = flow.flatten(x, 1)\n        x = self.classifier(x)\n        return x\n\n\ndef alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> AlexNet:\n    r\"\"\"AlexNet model architecture from the\n    `\"One weird trick...\" <https://arxiv.org/abs/1404.5997>`_ paper.\n    The required minimum input size of the model is 63x63.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    model = AlexNet(**kwargs)\n    return model\n"
  },
  {
    "path": "python/oneflow/test/graph/ofrecord_data_utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\n\nimport os\n\n\nclass OFRecordDataLoader(flow.nn.Module):\n    def __init__(\n        self,\n        ofrecord_root: str = \"./ofrecord\",\n        mode: str = \"train\",  # \"val\"\n        dataset_size: int = 9469,\n        batch_size: int = 1,\n    ):\n        super().__init__()\n        channel_last = False\n        output_layout = \"NHWC\" if channel_last else \"NCHW\"\n        self.train_record_reader = flow.nn.OFRecordReader(\n            ofrecord_root,\n            batch_size=batch_size,\n            data_part_num=1,\n            part_name_suffix_length=5,\n            random_shuffle=True if mode == \"train\" else False,\n            shuffle_after_epoch=True if mode == \"train\" else False,\n        )\n        self.record_label_decoder = flow.nn.OFRecordRawDecoder(\n            \"class/label\", shape=(), dtype=flow.int32\n        )\n\n        color_space = \"RGB\"\n        height = 224\n        width = 224\n\n        self.record_image_decoder = (\n            flow.nn.OFRecordImageDecoderRandomCrop(\"encoded\", color_space=color_space)\n            if mode == \"train\"\n            else flow.nn.OFRecordImageDecoder(\"encoded\", color_space=color_space)\n        )\n\n        self.resize = (\n            flow.nn.image.Resize(target_size=[height, width])\n            if mode == \"train\"\n            else flow.nn.image.Resize(\n                resize_side=\"shorter\", keep_aspect_ratio=True, target_size=256\n            )\n        )\n\n        self.flip = flow.nn.CoinFlip(batch_size=batch_size) if mode == \"train\" else None\n\n        rgb_mean = [123.68, 116.779, 103.939]\n        rgb_std = [58.393, 57.12, 57.375]\n        self.crop_mirror_norm = (\n            flow.nn.CropMirrorNormalize(\n                color_space=color_space,\n                output_layout=output_layout,\n                mean=rgb_mean,\n                std=rgb_std,\n                output_dtype=flow.float,\n            )\n            if mode == \"train\"\n            else flow.nn.CropMirrorNormalize(\n                color_space=color_space,\n                output_layout=output_layout,\n                crop_h=height,\n                crop_w=width,\n                crop_pos_y=0.5,\n                crop_pos_x=0.5,\n                mean=rgb_mean,\n                std=rgb_std,\n                output_dtype=flow.float,\n            )\n        )\n\n        self.batch_size = batch_size\n        self.dataset_size = dataset_size\n\n    def __len__(self):\n        return self.dataset_size // self.batch_size\n\n    def forward(self):\n        train_record = self.train_record_reader()\n        label = self.record_label_decoder(train_record)\n        image_raw_buffer = self.record_image_decoder(train_record)\n        image = self.resize(image_raw_buffer)[0]\n        rng = self.flip() if self.flip != None else None\n        image = self.crop_mirror_norm(image, rng)\n\n        return 
image, label\n"
  },
  {
    "path": "python/oneflow/test/graph/optimizer_test_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport numpy as np\n\n\ndef clip_grad_norm_np(np_grad, max_norm, norm_type):\n    max_norm = float(max_norm)\n    norm_type = float(norm_type)\n    if norm_type == float(\"inf\"):\n        total_norm = np.max(np.abs(np_grad))\n    if norm_type == float(\"-inf\"):\n        total_norm = np.min(np.abs(np_grad))\n    elif norm_type == 0:\n        total_norm = np.sum(np.stack([np.sum(np_grad != 0)]) != 0)\n    else:\n        total_norm = np_grad\n        for i in range(np_grad.ndim, 0, -1):\n            total_norm = np.linalg.norm(total_norm, norm_type, axis=i - 1)\n    clip_coef = max_norm / (total_norm + 1e-6)\n    if clip_coef < 1:\n        np_grad = np_grad * clip_coef\n    return total_norm, np_grad\n"
  },
  {
    "path": "python/oneflow/test/graph/test_alexnet_auto_parallel.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport time\nimport unittest\nimport argparse\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom alexnet_model import alexnet\nimport flowvision as vision\nimport flowvision.transforms as transforms\n\n\ndef load_data_fashion_mnist(\n    batch_size,\n    resize=None,\n    root=\"./data-test/fashion-mnist\",\n    download=True,\n    source_url=None,\n    num_workers=0,\n):\n    \"\"\"Download the Fashion-MNIST dataset and then load into memory.\"\"\"\n    root = os.path.expanduser(root)\n    trans = []\n    if resize:\n        trans.append(transforms.Resize(resize))\n    trans.append(transforms.ToTensor())\n    transform = transforms.Compose(trans)\n\n    mnist_train = vision.datasets.FashionMNIST(\n        root=root,\n        train=True,\n        transform=transform,\n        download=download,\n        source_url=source_url,\n    )\n    mnist_test = vision.datasets.FashionMNIST(\n        root=root,\n        train=False,\n        transform=transform,\n        download=download,\n        source_url=source_url,\n    )\n\n    train_iter = flow.utils.data.DataLoader(\n        mnist_train, batch_size, shuffle=True, num_workers=num_workers\n    )\n    test_iter = flow.utils.data.DataLoader(\n        mnist_test, batch_size, shuffle=False, num_workers=num_workers\n    )\n    return train_iter, test_iter\n\n\ndef _parse_args():\n    parser = argparse.ArgumentParser(\"flags for train alexnet\")\n    parser.add_argument(\n        \"--load_checkpoint\", type=str, default=\"\", help=\"load checkpoint\"\n    )\n    parser.add_argument(\n        \"--ofrecord_path\",\n        type=str,\n        default=flow.unittest.dataset_dir(\"imagenette/ofrecord\"),\n        help=\"dataset path\",\n    )\n    # training hyper-parameters\n    parser.add_argument(\n        \"--learning_rate\", type=float, default=0.02, help=\"learning rate\"\n    )\n    parser.add_argument(\"--mom\", type=float, default=0.9, help=\"momentum\")\n    parser.add_argument(\"--epochs\", type=int, default=1, help=\"training epochs\")\n    parser.add_argument(\"--batch_size\", type=int, default=128, help=\"val batch size\")\n\n    return parser.parse_known_args()\n\n\ndef _test_alexnet_graph(test_case, args, placement, sbp):\n    data_dir = os.path.join(\n        os.getenv(\"ONEFLOW_TEST_CACHE_DIR\", \"./data-test\"), \"fashion-mnist-lenet\"\n    )\n    source_url = \"https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/\"\n    train_iter, test_iter = load_data_fashion_mnist(\n        batch_size=args.batch_size,\n        root=data_dir,\n        download=True,\n        source_url=source_url,\n        num_workers=0,\n        resize=(112, 112),\n    )\n\n    # oneflow init\n    start_t = time.time()\n    alexnet_module = alexnet(num_classes=10)\n    end_t = time.time()\n    print(\"init time : {}\".format(end_t - start_t))\n\n    
alexnet_module.to_global(placement, sbp)\n\n    of_cross_entropy = flow.nn.CrossEntropyLoss().to_global(placement, sbp)\n\n    of_sgd = flow.optim.SGD(\n        alexnet_module.parameters(), lr=args.learning_rate, momentum=args.mom\n    )\n\n    class AlexNetGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.alexnet = alexnet_module\n            self.cross_entropy = of_cross_entropy\n            self.add_optimizer(of_sgd)\n            self.config.enable_auto_parallel(True)\n            self.config.enable_auto_parallel_ignore_user_sbp_config(True)\n            self.config.enable_auto_parallel_trunk_algo(True)\n            self.config.enable_auto_parallel_sbp_collector(True)\n\n        def build(self, image, label):\n            logits = self.alexnet(image)\n            loss = self.cross_entropy(logits, label)\n            loss.backward()\n            return loss\n\n    alexnet_graph = AlexNetGraph()\n\n    class AlexNetEvalGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.alexnet = alexnet_module\n            self.config.enable_auto_parallel(True)\n            self.config.enable_auto_parallel_ignore_user_sbp_config(True)\n            self.config.enable_auto_parallel_trunk_algo(True)\n            self.config.enable_auto_parallel_sbp_collector(True)\n\n        def build(self, image):\n            with flow.no_grad():\n                logits = self.alexnet(image)\n                predictions = logits.softmax()\n            return predictions\n\n    alexnet_eval_graph = AlexNetEvalGraph()\n\n    of_losses = []\n    print_interval = 20\n\n    acc = 0.0\n    for epoch in range(args.epochs):\n        alexnet_module.train()\n\n        for i, (image, label) in enumerate(train_iter):\n            # oneflow graph train\n            if image.shape[0] != args.batch_size:\n                # drop last batch\n                break\n            start_t = time.time()\n            image = image.to_global(placement, sbp).expand(args.batch_size, 3, 112, 112)\n            label = label.to_global(placement, sbp)\n            loss = alexnet_graph(image, label)\n            end_t = time.time()\n            if i % print_interval == 0:\n                l = loss.numpy()\n                of_losses.append(l)\n                if flow.env.get_rank() == 0:\n                    print(\n                        \"epoch {} train iter {}/{} oneflow loss {}, train time : {}\".format(\n                            epoch, i, len(train_iter), l, end_t - start_t\n                        )\n                    )\n                # Stop after 20 iters to save time\n                break\n        if flow.env.get_rank() == 0:\n            print(\"epoch %d train done, start validation\" % epoch)\n\n        alexnet_module.eval()\n        correct_of = 0.0\n        total_of = 0.0\n        for image, label in test_iter:\n            # oneflow graph eval\n            if image.shape[0] != args.batch_size:\n                # drop last batch\n                break\n            start_t = time.time()\n            image = image.to_global(placement, sbp).expand(args.batch_size, 3, 112, 112)\n            predictions = alexnet_eval_graph(image)\n            of_predictions = predictions.numpy()\n            clsidxs = np.argmax(of_predictions, axis=1)\n\n            label_nd = label.numpy()\n\n            for i in range(args.batch_size):\n                total_of += 1\n                if clsidxs[i] == label_nd[i]:\n                    correct_of += 1\n     
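       # timing for this eval step ends here; top-1 accuracy is averaged after the loop\n     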
       end_t = time.time()\n        acc = correct_of / total_of\n\n        if flow.env.get_rank() == 0:\n            print(\"epoch %d, oneflow top1 val acc: %f\" % (epoch, acc))\n    #  test_case.assertTrue(acc > 0.50)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestAlexnetAutoParallel(oneflow.unittest.TestCase):\n    def test_alexnet_auto_parallel_1d_sbp(test_case):\n        args, unknown_args = _parse_args()\n        placement = flow.placement.all(\"cuda\")\n        sbp = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n        _test_alexnet_graph(test_case, args, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_alexnet_graph.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport argparse\nimport numpy as np\nimport os\nimport time\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom alexnet_model import alexnet\nfrom ofrecord_data_utils import OFRecordDataLoader\n\n\ndef _parse_args():\n    parser = argparse.ArgumentParser(\"flags for train alexnet\")\n    parser.add_argument(\n        \"--save_checkpoint_path\",\n        type=str,\n        default=\"./checkpoints\",\n        help=\"save checkpoint root dir\",\n    )\n    parser.add_argument(\n        \"--load_checkpoint\", type=str, default=\"\", help=\"load checkpoint\"\n    )\n    parser.add_argument(\n        \"--ofrecord_path\",\n        type=str,\n        default=flow.unittest.dataset_dir(\"imagenette/ofrecord\"),\n        help=\"dataset path\",\n    )\n    parser.add_argument(\n        \"--train_dataset_size\", type=int, default=400, help=\"train_dataset size\"\n    )\n    parser.add_argument(\n        \"--val_dataset_size\", type=int, default=40, help=\"val_dataset size\"\n    )\n    # training hyper-parameters\n    parser.add_argument(\n        \"--learning_rate\", type=float, default=0.001, help=\"learning rate\"\n    )\n    parser.add_argument(\"--mom\", type=float, default=0.9, help=\"momentum\")\n    parser.add_argument(\"--epochs\", type=int, default=1, help=\"training epochs\")\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"train batch size\"\n    )\n    parser.add_argument(\"--val_batch_size\", type=int, default=4, help=\"val batch size\")\n    parser.add_argument(\"--device\", type=str, default=\"cuda\", help=\"device\")\n\n    return parser.parse_known_args()\n\n\ndef _test_alexnet_graph_repr(test_case, args):\n    train_data_loader = OFRecordDataLoader(\n        ofrecord_root=args.ofrecord_path,\n        mode=\"train\",\n        dataset_size=args.train_dataset_size,\n        batch_size=args.train_batch_size,\n    )\n\n    alexnet_module = alexnet()\n    alexnet_module.to(args.device)\n\n    of_cross_entropy = flow.nn.CrossEntropyLoss()\n    of_cross_entropy.to(args.device)\n\n    of_sgd = flow.optim.SGD(\n        alexnet_module.parameters(), lr=args.learning_rate, momentum=args.mom\n    )\n\n    class AlexNetGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.alexnet = alexnet_module\n            self.cross_entropy = of_cross_entropy\n            self.add_optimizer(of_sgd)\n\n        def build(self, image, label):\n            logits = self.alexnet(image)\n            loss = self.cross_entropy(logits, label)\n            loss.backward()\n            return loss\n\n    alexnet_graph = AlexNetGraph()\n\n    print(\"repr(alexnet_graph) before run: \\n\", repr(alexnet_graph))\n\n    # debug graph build\n    alexnet_graph.debug(1, op_repr_with_py_stack=True, max_py_stack_depth=4)\n\n    alexnet_module.train()\n    image, label = train_data_loader()\n    image = 
image.to(args.device)\n    label = label.to(args.device)\n    loss = alexnet_graph(image, label)\n\n    print(\"repr(alexnet_graph) after run: \\n\", repr(alexnet_graph))\n\n\ndef _test_alexnet_graph(test_case, args):\n    train_data_loader = OFRecordDataLoader(\n        ofrecord_root=args.ofrecord_path,\n        mode=\"train\",\n        dataset_size=args.train_dataset_size,\n        batch_size=args.train_batch_size,\n    )\n    val_data_loader = OFRecordDataLoader(\n        ofrecord_root=args.ofrecord_path,\n        mode=\"val\",\n        dataset_size=args.val_dataset_size,\n        batch_size=args.val_batch_size,\n    )\n\n    # oneflow init\n    start_t = time.time()\n    alexnet_module = alexnet()\n    end_t = time.time()\n    print(\"init time : {}\".format(end_t - start_t))\n\n    alexnet_module.to(args.device)\n\n    of_cross_entropy = flow.nn.CrossEntropyLoss()\n    of_cross_entropy.to(args.device)\n\n    of_sgd = flow.optim.SGD(\n        alexnet_module.parameters(), lr=args.learning_rate, momentum=args.mom\n    )\n\n    class AlexNetGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.train_data_loader = train_data_loader\n            self.alexnet = alexnet_module\n            self.cross_entropy = of_cross_entropy\n            self.add_optimizer(of_sgd)\n\n        def build(self):\n            image, label = self.train_data_loader()\n            image = image.to(args.device)\n            label = label.to(args.device)\n            logits = self.alexnet(image)\n            loss = self.cross_entropy(logits, label)\n            loss.backward()\n            return loss\n\n    alexnet_graph = AlexNetGraph()\n\n    class AlexNetEvalGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.val_data_loader = val_data_loader\n            self.alexnet = alexnet_module\n\n        def build(self):\n            with flow.no_grad():\n                image, label = self.val_data_loader()\n                image = image.to(args.device)\n                logits = self.alexnet(image)\n                predictions = logits.softmax()\n            return predictions, label\n\n    alexnet_eval_graph = AlexNetEvalGraph()\n\n    of_losses = []\n    all_samples = len(val_data_loader) * args.val_batch_size\n    print_interval = 10\n\n    for epoch in range(args.epochs):\n        alexnet_module.train()\n\n        for b in range(len(train_data_loader)):\n            # oneflow graph train\n            start_t = time.time()\n            loss = alexnet_graph()\n            end_t = time.time()\n            if b % print_interval == 0:\n                l = loss.numpy()\n                of_losses.append(l)\n                print(\n                    \"epoch {} train iter {} oneflow loss {}, train time : {}\".format(\n                        epoch, b, l, end_t - start_t\n                    )\n                )\n        print(\"epoch %d train done, start validation\" % epoch)\n\n        alexnet_module.eval()\n        correct_of = 0.0\n        for b in range(len(val_data_loader)):\n\n            start_t = time.time()\n            predictions, label = alexnet_eval_graph()\n            of_predictions = predictions.numpy()\n            clsidxs = np.argmax(of_predictions, axis=1)\n\n            label_nd = label.numpy()\n            for i in range(args.val_batch_size):\n                if clsidxs[i] == label_nd[i]:\n                    correct_of += 1\n            end_t = time.time()\n\n        print(\"epoch %d, oneflow top1 val acc: 
%f\" % (epoch, correct_of / all_samples))\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestAlexnetGraph(oneflow.unittest.TestCase):\n    def test_alexnet_graph_repr(test_case):\n        args, unknown_args = _parse_args()\n        args.device = \"cuda\"\n        _test_alexnet_graph_repr(test_case, args)\n\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    def test_alexnet_graph_gpu(test_case):\n        args, unknown_args = _parse_args()\n        args.device = \"cuda\"\n        _test_alexnet_graph(test_case, args)\n\n    def test_alexnet_graph_cpu(test_case):\n        args, unknown_args = _parse_args()\n        args.device = \"cpu\"\n        args.train_batch_size = 40\n        _test_alexnet_graph(test_case, args)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_comb1to2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nfrom oneflow import nn\nimport os\nimport numpy as np\n\nimport oneflow.unittest\n\n\nclass _TestModuleDiffHierarchy(nn.Module):\n    def forward(self, x):\n        sbp_1ds = [\n            flow.sbp.broadcast,\n            flow.sbp.partial_sum,\n            flow.sbp.split(0),\n            flow.sbp.split(1),\n            flow.sbp.split(2),\n        ]\n\n        for sbp1 in sbp_1ds:\n\n            for sbp2 in sbp_1ds:\n                for sbp3 in sbp_1ds:\n                    # (2, 2) -> 4\n                    x = x.to_global(\n                        placement=flow.placement(type=\"cuda\", ranks=np.array(range(4))),\n                        sbp=[sbp1],\n                    )\n                    # 4 -> (2, 2)\n                    x = x.to_global(\n                        placement=flow.placement(\n                            type=\"cuda\", ranks=np.array(range(4)).reshape(2, 2)\n                        ),\n                        sbp=[sbp2, sbp3],\n                    )\n\n        return x\n\n\nclass _TestModuleDiffPlacement(nn.Module):\n    def forward(self, x):\n        sbp_1ds = [\n            flow.sbp.broadcast,\n            flow.sbp.partial_sum,\n            flow.sbp.split(0),\n            flow.sbp.split(1),\n            flow.sbp.split(2),\n        ]\n        for sbp1 in sbp_1ds:\n            for sbp2 in sbp_1ds:\n                for sbp3 in sbp_1ds:\n                    # (2, 2) -> 3\n                    # 4 is not divisible by 3\n                    x = x.to_global(\n                        placement=flow.placement(type=\"cuda\", ranks=np.array(range(3))),\n                        sbp=[sbp1],\n                    )\n                    # 3 -> (2, 2)\n                    x = x.to_global(\n                        placement=flow.placement(\n                            type=\"cuda\", ranks=np.array(range(4)).reshape(2, 2)\n                        ),\n                        sbp=[sbp2, sbp3],\n                    )\n\n        return x\n\n\nclass _TestGraph(nn.Graph):\n    def __init__(self, model):\n        super().__init__()\n        self.model = model\n\n    def build(self, x):\n        x = self.model(x)\n        return x\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestLazyAllSbpCombinationTesting(flow.unittest.TestCase):\n    def test_lazy_boxing_2d_all_combination(test_case):\n        os.environ[\"ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK\"] = \"0\"\n        os.environ[\"ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION\"] = \"0\"\n\n        x = flow.ones(\n            4,\n            12,\n            4,\n            sbp=[flow.sbp.broadcast, flow.sbp.broadcast],\n            placement=flow.placement(\n                type=\"cuda\", ranks=np.array(range(4)).reshape(2, 2)\n            ),\n        )\n\n        
flow.boxing.nccl.enable_use_compute_stream(False)\n\n        model_diff_hierarchy = _TestModuleDiffHierarchy()\n        graph_diff_hierarchy = _TestGraph(model_diff_hierarchy)\n        y = graph_diff_hierarchy(x)\n\n        model_diff_placement = _TestModuleDiffPlacement()\n        graph_diff_placement = _TestGraph(model_diff_placement)\n        z = graph_diff_placement(x)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_comb2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nfrom oneflow import nn\nimport os\nimport numpy as np\n\nimport oneflow.unittest\n\n\nclass _TestModule(nn.Module):\n    def forward(self, x):\n        sbp_1ds = [\n            flow.sbp.broadcast,\n            flow.sbp.partial_sum,\n            flow.sbp.split(0),\n            flow.sbp.split(1),\n            flow.sbp.split(2),\n        ]\n        y = x\n\n        for sbp1 in sbp_1ds:\n            for sbp2 in sbp_1ds:\n\n                for sbp3 in sbp_1ds:\n                    # in this case, use intra group boxing\n                    if sbp1 == sbp3:\n                        continue\n                    for sbp4 in sbp_1ds:\n                        # (2, 2) -> (2, 2)\n                        x = x.to_global(sbp=[sbp1, sbp2])\n                        x = x.to_global(sbp=[sbp3, sbp4])\n\n        return x\n\n\nclass _TestGraph(nn.Graph):\n    def __init__(self, model):\n        super().__init__()\n        self.model = model\n\n    def build(self, x):\n        x = self.model(x)\n        return x\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestLazyAllSbpCombinationTesting(flow.unittest.TestCase):\n    def test_lazy_boxing_2d_all_combination(test_case):\n        os.environ[\"ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK\"] = \"0\"\n        os.environ[\"ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION\"] = \"0\"\n\n        model = _TestModule()\n        graph = _TestGraph(model)\n\n        flow.boxing.nccl.enable_use_compute_stream(False)\n\n        x = flow.ones(\n            4,\n            4,\n            4,\n            sbp=[flow.sbp.broadcast, flow.sbp.broadcast],\n            placement=flow.placement(\n                type=\"cuda\", ranks=np.array(range(4)).reshape(2, 2)\n            ),\n        )\n        y = graph(x)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_forward_graph.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\n\nimport oneflow\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestForwardGraph(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    def test_forward_graph(test_case):\n        class SubModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.weight = flow.nn.Parameter(flow.Tensor(6, 6))\n                self.relu = flow.nn.ReLU()\n\n            def forward(self, x, y):\n                x = oneflow._C.matmul(x, self.weight)\n                x = self.relu(x)\n                y = self.relu(y)\n                return (x, y)\n\n        class CustomModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.layer = SubModule()\n                self.register_buffer(\"dummy_buff\", flow.Tensor(6, 8))\n\n            def forward(self, x, y):\n                (x, y) = self.layer(x, y)\n                x = oneflow._C.flatten(x, 1)\n                x = oneflow._C.matmul(x, self.dummy_buff)\n                return (x, y)\n\n        class CustomGraph(flow.nn.Graph):\n            def __init__(self, module):\n                super().__init__()\n                self.m = module\n\n            def build(self, x, y):\n                out = self.m(x, y)\n                return out\n\n        m = CustomModule()\n        m.to(\"cuda\")\n        g = CustomGraph(m)\n        x = flow.Tensor(6, 6)\n        flow.nn.init.uniform_(x, a=-1.0, b=1.0)\n        x = x.to(\"cuda\")\n        y = flow.Tensor(10, 10)\n        flow.nn.init.uniform_(y, a=-1.0, b=1.0)\n        y = y.to(\"cuda\")\n        print(repr(g))\n        (z, a) = g._compile(x, y)\n        test_case.assertEqual(z.shape, (6, 8))\n        test_case.assertEqual(z.is_lazy, False)\n        test_case.assertEqual(a.shape, (10, 10))\n        test_case.assertEqual(a.is_lazy, False)\n        print(\"graph proto: \", g._graph_proto)\n\n    def test_add_backward(test_case):\n        linear = flow.nn.Linear(3, 8)\n        flow.nn.init.constant_(linear.weight, 2.068758)\n        flow.nn.init.constant_(linear.bias, 0.23)\n        of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9)\n\n        class GraphAddBackward(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linear = linear\n                self.add_optimizer(of_sgd)\n\n            def build(self, x):\n                out = self.linear(x)\n                out = out.mean()\n                out.backward()\n                return out\n\n        g_with_b = GraphAddBackward()\n        x = flow.ones(8, 3)\n        out = g_with_b(x)\n        print(\"graph proto: \", g_with_b._graph_proto)\n        
print(repr(g_with_b))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_free_tensor_not_in_job.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.nn as nn\n\n\ndef get_bn_graph():\n    model = nn.BatchNorm1d(6)\n    model.eval()\n    model.to_global(flow.placement.all(\"cpu\"), flow.sbp.broadcast)\n\n    class Testgraph(flow.nn.Graph):\n        def __init__(self, model):\n            super(Testgraph, self).__init__()\n            self.module = model\n\n        def build(self, x):\n            return self.module(x)\n\n    test_graph = Testgraph(model)\n    return test_graph\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFreeTensorNotInJob(flow.unittest.TestCase):\n    def test_free_tensor_not_in_job(test_case):\n        x = flow.randn(1, 6, 2).to_global(\n            placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.split(0)\n        )\n        y = get_bn_graph()(x)\n        test_case.assertEqual(y.size(), (1, 6, 2))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_fx_fuse.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nimport oneflow.nn as nn\nimport numpy as np\nimport unittest\nfrom oneflow.test_utils.automated_test_util import *\nimport numpy as np\nimport copy\nfrom typing import Dict, Any, Tuple\n\n\ndef _fuse_conv_bn_eval(conv, bn):\n    \"\"\"\n    Given a conv Module `A` and an batch_norm module `B`, returns a conv\n    module `C` such that C(x) == B(A(x)) in inference mode.\n    \"\"\"\n    assert not (conv.training or bn.training), \"Fusion only for eval!\"\n    fused_conv = copy.deepcopy(conv)\n\n    fused_conv.weight, fused_conv.bias = _fuse_conv_bn_weights(\n        fused_conv.weight,\n        fused_conv.bias,\n        bn.running_mean,\n        bn.running_var,\n        bn.eps,\n        bn.weight,\n        bn.bias,\n    )\n\n    return fused_conv\n\n\ndef _fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):\n    if conv_b is None:\n        conv_b = flow.zeros_like(bn_rm)\n    if bn_w is None:\n        bn_w = flow.ones_like(bn_rm)\n    if bn_b is None:\n        bn_b = flow.zeros_like(bn_rm)\n    bn_var_rsqrt = flow.rsqrt(bn_rv + bn_eps)\n\n    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape(\n        [-1] + [1] * (len(conv_w.shape) - 1)\n    )\n    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b\n\n    return flow.nn.Parameter(conv_w), flow.nn.Parameter(conv_b)\n\n\ndef _parent_name(target: str) -> Tuple[str, str]:\n    \"\"\"\n    Splits a qualname into parent path and last atom.\n    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)\n    \"\"\"\n    *parent, name = target.rsplit(\".\", 1)\n    return parent[0] if parent else \"\", name\n\n\ndef _replace_node_module(\n    node: flow.fx.Node, modules: Dict[str, Any], new_module: flow.nn.Module\n):\n    assert isinstance(node.target, str)\n    parent_name, name = _parent_name(node.target)\n    setattr(modules[parent_name], name, new_module)\n\n\ndef _fx_fuse(model: flow.nn.Module) -> flow.nn.Module:\n    model = copy.deepcopy(model)\n    # The first step of most FX passes is to symbolically trace our model to\n    # obtain a `GraphModule`. This is a representation of our original model\n    # that is functionally identical to our original model, except that we now\n    # also have a graph representation of our forward pass.\n    fx_model: flow.fx.GraphModule = flow.fx.symbolic_trace(model)\n    modules = dict(fx_model.named_modules())\n\n    # The primary representation for working with FX are the `Graph` and the\n    # `Node`. Each `GraphModule` has a `Graph` associated with it - this\n    # `Graph` is also what generates `GraphModule.code`.\n    # The `Graph` itself is represented as a list of `Node` objects. 
Thus, to\n    # iterate through all of the operations in our graph, we iterate over each\n    # `Node` in our `Graph`.\n    for node in fx_model.graph.nodes:\n        # The FX IR contains several types of nodes, which generally represent\n        # call sites to modules, functions, or methods. The type of node is\n        # determined by `Node.op`.\n        if (\n            node.op != \"call_module\"\n        ):  # If our current node isn't calling a Module then we can ignore it.\n            continue\n        # For call sites, `Node.target` represents the module/function/method\n        # that's being called. Here, we check `Node.target` to see if it's a\n        # batch norm module, and then check `Node.args[0].target` to see if the\n        # input `Node` is a convolution.\n        if (\n            type(modules[node.target]) is nn.BatchNorm2d\n            and type(modules[node.args[0].target]) is nn.Conv2d\n        ):\n            if len(node.args[0].users) > 1:  # Output of conv is used by other nodes\n                continue\n            conv = modules[node.args[0].target]\n            bn = modules[node.target]\n            fused_conv = _fuse_conv_bn_eval(conv, bn)\n            _replace_node_module(node.args[0], modules, fused_conv)\n            # As we've folded the batch nor into the conv, we need to replace all uses\n            # of the batch norm with the conv.\n            node.replace_all_uses_with(node.args[0])\n            # Now that all uses of the batch norm have been replaced, we can\n            # safely remove the batch norm.\n            fx_model.graph.erase_node(node)\n    fx_model.graph.lint()\n    # After we've modified our graph, we need to recompile our graph in order\n    # to keep the generated code in sync.\n    fx_model.recompile()\n    return fx_model\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestConvBnFuse(flow.unittest.TestCase):\n    def test_fuse(test_case):\n        class WrappedBatchNorm(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.mod = nn.BatchNorm2d(1)\n\n            def forward(self, x):\n                return self.mod(x)\n\n        class M(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv1 = nn.Conv2d(1, 1, 1)\n                self.bn1 = nn.BatchNorm2d(1)\n                self.conv2 = nn.Conv2d(1, 1, 1)\n                self.nested = nn.Sequential(nn.BatchNorm2d(1), nn.Conv2d(1, 1, 1),)\n                self.wrapped = WrappedBatchNorm()\n\n            def forward(self, x):\n                x = self.conv1(x)\n                x = self.bn1(x)\n                x = self.conv2(x)\n                x = self.nested(x)\n                x = self.wrapped(x)\n                return x\n\n        model = M()\n\n        model.eval()\n\n        fused_model = _fx_fuse(model)\n        for i in range(10):\n            inp = flow.randn(5, 1, 32, 32)\n            test_case.assertTrue(\n                np.allclose(fused_model(inp).numpy(), model(inp).numpy(), atol=1e-6)\n            )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_fx_replace_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom oneflow.fx import symbolic_trace, replace_pattern\nfrom oneflow.test_utils.automated_test_util import *\nimport unittest\n\n\nclass M(flow.nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x, w1, w2):\n        val1 = flow.neg(w1)\n        m1 = flow.cat([val1, w2]).sum()\n        val2 = flow.neg(w1)\n        m2 = flow.cat([val2, w2]).sum()\n        return x + flow.max(m1) + flow.max(m2)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestReplaceOps(flow.unittest.TestCase):\n    def test_pattern(test_case):\n        traced = symbolic_trace(M())\n\n        def pattern(a1, a2):\n            val1 = flow.neg(a1)\n            return flow.cat([val1, a2]).sum()\n\n        def replacement(w1, w2):\n            return flow.stack([w1, w2])\n\n        replace_pattern(traced, pattern, replacement)\n\n        test_case.assertTrue(\"cat\" not in traced.code)\n        test_case.assertTrue(\"neg\" not in traced.code)\n        test_case.assertTrue(\"stack\" in traced.code)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_fx_symbolic_trace_module.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nimport oneflow.nn as nn\nimport numpy as np\nimport unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\nclass AlexNet(nn.Module):\n    def __init__(self, num_classes: int = 1000) -> None:\n        super(AlexNet, self).__init__()\n        self.features = nn.Sequential(\n            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(kernel_size=3, stride=2),\n            nn.Conv2d(64, 192, kernel_size=5, padding=2),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(kernel_size=3, stride=2),\n            nn.Conv2d(192, 384, kernel_size=3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(384, 256, kernel_size=3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(256, 256, kernel_size=3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(kernel_size=3, stride=2),\n        )\n        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))\n        self.classifier = nn.Sequential(\n            nn.Dropout(),\n            nn.Linear(256 * 6 * 6, 4096),\n            nn.ReLU(inplace=True),\n            nn.Dropout(),\n            nn.Linear(4096, 4096),\n            nn.ReLU(inplace=True),\n            nn.Linear(4096, num_classes),\n        )\n\n    def forward(self, x: flow.Tensor) -> flow.Tensor:\n        x = self.features(x)\n        x = self.avgpool(x)\n        x = flow.flatten(x, 1)\n        x = self.classifier(x)\n        return x\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAlexNet(flow.unittest.TestCase):\n    def test_alexnet(test_case):\n        m = AlexNet()\n        m = m.eval()\n        gm: flow.fx.GraphModule = flow.fx.symbolic_trace(m)\n        for i in range(5):\n            input = flow.randn(1, 3, 224, 224)\n            test_case.assertTrue(\n                np.allclose(gm(input).numpy(), m(input).numpy(), equal_nan=True)\n            )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_gbc1to2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport oneflow\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport time\nimport os\n\n\ndef _test_general_basic_communication_1d_to_2d(test_case, src_nd_sbp, dst_nd_sbp):\n    # can not process p in dst\n    if flow.sbp.partial_sum() in dst_nd_sbp:\n        return\n\n    # input\n    placement_x = flow.placement(\"cuda\", ranks=[0, 1, 2])\n    placement_y = flow.placement(\"cuda\", ranks=[[3, 0], [1, 2]])\n    local_np = np.arange(4 * 14).reshape(4, 14)\n    x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement_x)\n\n    # check eager boxing\n    eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement_y)\n    test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy()))\n\n    # check graph boxing\n    flow.boxing.nccl.enable_use_compute_stream(False)\n\n    class TestGeneralBasicCommunicationGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, x):\n            y = x.to_global(sbp=dst_nd_sbp, placement=placement_y)\n            return y\n\n    graph = TestGeneralBasicCommunicationGraph()\n    y = graph(x)\n    out_np = y.numpy()\n    in_np = x.numpy()\n    test_case.assertTrue(np.array_equal(out_np, in_np))\n\n\ndef gen_nd_sbp_1d():\n    sbp_list = [\n        flow.sbp.partial_sum(),\n        flow.sbp.broadcast(),\n        flow.sbp.split(0),\n        flow.sbp.split(1),\n    ]\n    return sbp_list\n\n\ndef gen_nd_sbp_2d():\n    nd_sbp_list = []\n    for sbp0 in gen_nd_sbp_1d():\n        for sbp1 in gen_nd_sbp_1d():\n            nd_sbp_list.append([sbp0, sbp1])\n    return nd_sbp_list\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGeneralBasicCommunication(flow.unittest.TestCase):\n    def test_general_basic_communication(test_case):\n        os.environ[\"ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK\"] = \"0\"\n        os.environ[\"ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION\"] = \"1\"\n\n        arg_dict = OrderedDict()\n        arg_dict[\"src_nd_sbp\"] = gen_nd_sbp_1d()\n        arg_dict[\"dst_nd_sbp\"] = gen_nd_sbp_2d()\n        for arg in GenArgList(arg_dict):\n            _test_general_basic_communication_1d_to_2d(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_gbc2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport oneflow\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport time\nimport os\n\n\ndef _test_general_basic_communication_same_placement(test_case, src_nd_sbp, dst_nd_sbp):\n    # can not process p in dst\n    if flow.sbp.partial_sum() in dst_nd_sbp:\n        return\n\n    # skip src == dst\n    if src_nd_sbp == dst_nd_sbp:\n        return\n\n    # in this case, use intra group boxing\n    if src_nd_sbp[0] == dst_nd_sbp[0]:\n        return\n\n    # in this case, use inter group boxing\n    if (\n        src_nd_sbp[1] == dst_nd_sbp[1]\n        and src_nd_sbp[0] != src_nd_sbp[1]\n        and dst_nd_sbp[0] != dst_nd_sbp[1]\n    ):\n        return\n\n    # input\n    placement = flow.placement(\"cuda\", ranks=[[0, 1], [2, 3]])\n    local_np = np.arange(4 * 5).reshape(4, 5)\n    x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement)\n\n    # check eager boxing\n    eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement)\n    test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy()))\n\n    # check graph boxing\n    flow.boxing.nccl.enable_use_compute_stream(False)\n\n    class TestGeneralBasicCommunicationGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, x):\n            y = x.to_global(sbp=dst_nd_sbp, placement=placement)\n            return y\n\n    graph = TestGeneralBasicCommunicationGraph()\n    y = graph(x)\n    out_np = y.numpy()\n    in_np = x.numpy()\n    test_case.assertTrue(np.array_equal(out_np, in_np))\n\n\ndef gen_nd_sbp():\n    sbp_list = [\n        flow.sbp.partial_sum(),\n        flow.sbp.broadcast(),\n        flow.sbp.split(0),\n        flow.sbp.split(1),\n    ]\n    nd_sbp_list = []\n    for sbp0 in sbp_list:\n        for sbp1 in sbp_list:\n            nd_sbp_list.append([sbp0, sbp1])\n    return nd_sbp_list\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGeneralBasicCommunication(flow.unittest.TestCase):\n    def test_general_basic_communication(test_case):\n        os.environ[\"ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK\"] = \"0\"\n        os.environ[\"ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION\"] = \"1\"\n\n        arg_dict = OrderedDict()\n        arg_dict[\"src_nd_sbp\"] = gen_nd_sbp()\n        arg_dict[\"dst_nd_sbp\"] = gen_nd_sbp()\n        for arg in GenArgList(arg_dict):\n            _test_general_basic_communication_same_placement(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_gbc2to1d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport oneflow\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport time\nimport os\n\n\ndef _test_general_basic_communication_2d_to_1d(test_case, src_nd_sbp, dst_nd_sbp):\n    # can not process p in dst\n    if flow.sbp.partial_sum() == dst_nd_sbp:\n        return\n\n    # input\n    placement_x = flow.placement(\"cuda\", ranks=[[0, 1], [2, 3]])\n    placement_y = flow.placement(\"cuda\", ranks=[0, 3, 4])\n    local_np = np.arange(13 * 5).reshape(13, 5)\n    x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement_x)\n\n    # check eager boxing\n    eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement_y)\n    test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy()))\n\n    # check graph boxing\n    flow.boxing.nccl.enable_use_compute_stream(False)\n\n    class TestGeneralBasicCommunicationGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, x):\n            y = x.to_global(sbp=dst_nd_sbp, placement=placement_y)\n            return y\n\n    graph = TestGeneralBasicCommunicationGraph()\n    y = graph(x)\n    out_np = y.numpy()\n    in_np = x.numpy()\n    test_case.assertTrue(np.array_equal(out_np, in_np))\n\n\ndef gen_nd_sbp_1d():\n    sbp_list = [\n        flow.sbp.partial_sum(),\n        flow.sbp.broadcast(),\n        flow.sbp.split(0),\n        flow.sbp.split(1),\n    ]\n    return sbp_list\n\n\ndef gen_nd_sbp_2d():\n    nd_sbp_list = []\n    for sbp0 in gen_nd_sbp_1d():\n        for sbp1 in gen_nd_sbp_1d():\n            nd_sbp_list.append([sbp0, sbp1])\n    return nd_sbp_list\n\n\n@flow.unittest.skip_unless_2n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGeneralBasicCommunication(flow.unittest.TestCase):\n    def test_general_basic_communication(test_case):\n        os.environ[\"ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK\"] = \"0\"\n        os.environ[\"ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION\"] = \"1\"\n\n        arg_dict = OrderedDict()\n        arg_dict[\"src_nd_sbp\"] = gen_nd_sbp_2d()\n        arg_dict[\"dst_nd_sbp\"] = gen_nd_sbp_1d()\n        for arg in GenArgList(arg_dict):\n            _test_general_basic_communication_2d_to_1d(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_gbc2to2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport oneflow\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport time\nimport os\n\n\ndef _test_general_basic_communication_2d_to_2d(test_case, src_nd_sbp, dst_nd_sbp):\n    # can not process p in dst\n    if flow.sbp.partial_sum() in dst_nd_sbp:\n        return\n\n    if dst_nd_sbp[0] == dst_nd_sbp[1] and src_nd_sbp[0] == src_nd_sbp[1]:\n        return\n\n    # input\n    placement_x = flow.placement(\"cuda\", ranks=[[0, 1], [2, 3]])\n    placement_y = flow.placement(\"cuda\", ranks=[[0, 3, 4], [2, 5, 6]])\n    local_np = np.arange(12 * 12).reshape(12, 12)\n    x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement_x)\n\n    # check eager boxing\n    eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement_y)\n    test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy()))\n\n    # check graph boxing\n    flow.boxing.nccl.enable_use_compute_stream(False)\n\n    class TestGeneralBasicCommunicationGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, x):\n            y = x.to_global(sbp=dst_nd_sbp, placement=placement_y)\n            return y\n\n    graph = TestGeneralBasicCommunicationGraph()\n    y = graph(x)\n    out_np = y.numpy()\n    in_np = x.numpy()\n    test_case.assertTrue(np.array_equal(out_np, in_np))\n\n\ndef gen_nd_sbp():\n    sbp_list = [\n        flow.sbp.partial_sum(),\n        flow.sbp.broadcast(),\n        flow.sbp.split(0),\n        flow.sbp.split(1),\n    ]\n    nd_sbp_list = []\n    for sbp0 in sbp_list:\n        for sbp1 in sbp_list:\n            nd_sbp_list.append([sbp0, sbp1])\n    return nd_sbp_list\n\n\n@flow.unittest.skip_unless_2n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGeneralBasicCommunication(flow.unittest.TestCase):\n    def test_general_basic_communication(test_case):\n        os.environ[\"ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK\"] = \"0\"\n        os.environ[\"ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION\"] = \"1\"\n\n        arg_dict = OrderedDict()\n        arg_dict[\"src_nd_sbp\"] = gen_nd_sbp()\n        arg_dict[\"dst_nd_sbp\"] = gen_nd_sbp()\n        for arg in GenArgList(arg_dict):\n            _test_general_basic_communication_2d_to_2d(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\nfrom threading import Thread\n\nimport numpy as np\n\nimport oneflow\nimport oneflow as flow\nfrom oneflow.nn.graph import GraphModule, GraphTensor\nimport oneflow.framework.graph_build_util as graph_build_util\nimport oneflow.framework.scope_util as scope_util\nimport oneflow.unittest\n\n\nclass SubModule(flow.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = flow.nn.Conv2d(1, 1, 5)\n        self.relu = flow.nn.ReLU()\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.relu(x)\n        return x\n\n\nclass CustomModule(flow.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.layer = SubModule()\n        self.fc1 = flow.nn.Linear(36, 4)\n        self.register_buffer(\"dummy_buff\", flow.Tensor(1, 4))\n\n    def forward(self, x):\n        x = self.layer(x)\n        x = oneflow._C.flatten(x, 1)\n        x = self.fc1(x) + self.dummy_buff\n        return x\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestGraph(flow.unittest.TestCase):\n    def test_add_nested_module(test_case):\n        x = flow.Tensor(1, 1, 10, 10)\n        flow.nn.init.uniform_(x, a=-1.0, b=1.0)\n        m = CustomModule()\n        y = m(x)\n\n        class CustomGraphNestedModule(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.m = m\n\n            def build(self, x):\n                return self.m(x)\n\n        g = CustomGraphNestedModule()\n        test_case.assertTrue(isinstance(g.m, flow.nn.graph.Proxy))\n        test_case.assertEqual(g.m.to(GraphModule).type, \"MODULE\")\n        test_case.assertEqual(g.m.to(GraphModule).name, \"m\")\n        test_case.assertTrue(isinstance(g.m.dummy_buff, flow.nn.graph.Proxy))\n        test_case.assertEqual(g.m.dummy_buff.to(GraphTensor).type, \"BUFFER\")\n        test_case.assertTrue(isinstance(g.m.layer.conv1, flow.nn.graph.Proxy))\n        test_case.assertEqual(g.m.layer.conv1.to(GraphModule).name, \"conv1\")\n        test_case.assertEqual(g.m.layer.conv1.to(GraphModule).name_prefix, \"m.layer.\")\n        test_case.assertTrue(isinstance(g.m.layer.conv1.weight, flow.nn.graph.Proxy))\n        test_case.assertEqual(g.m.layer.conv1.weight.to(GraphTensor).type, \"PARAMETER\")\n        g.m.layer.conv1.to(GraphModule)._is_executing_forward = True\n        test_case.assertTrue(isinstance(g.m.layer.conv1.weight, flow.Tensor))\n        g.m.layer.conv1.to(GraphModule)._is_executing_forward = False\n        test_case.assertEqual(g.m.layer.conv1.kernel_size, (5, 5))\n        z = g.build(x)\n        test_case.assertTrue(np.array_equal(y.numpy(), z.numpy()))\n\n    def test_graph_name(test_case):\n        class ACustomGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n\n            def 
build(self, x):\n                return x\n\n        class BCustomGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n\n            def build(self, x):\n                return x\n\n        class CBCustomGraph(BCustomGraph):\n            def __init__(self):\n                super().__init__()\n\n        def create_graph(cnt):\n            a = ACustomGraph()\n            test_case.assertEqual(a.name, \"ACustomGraph_\" + str(cnt))\n            b = BCustomGraph()\n            test_case.assertEqual(b.name, \"BCustomGraph_\" + str(cnt))\n            cb = CBCustomGraph()\n            test_case.assertEqual(cb.name, \"CBCustomGraph_\" + str(cnt))\n\n        flow.nn.Graph._child_init_cnt.clear()\n        for i in range(0, 3):\n            create_graph(i)\n        flow.nn.Graph._child_init_cnt.clear()\n        for i in range(0, 3):\n            create_graph(i)\n\n    def test_graph_build_ctx(test_case):\n        test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False)\n        with graph_build_util.lazy_mode.guard(True):\n            test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), True)\n            with graph_build_util.lazy_mode.guard(False):\n                test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False)\n            test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), True)\n        test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False)\n\n        class CustomGraphGraphBuildCtx(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n\n            def build(self, x):\n                test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), True)\n                import oneflow.framework.session_context as session_ctx\n                from oneflow.framework.multi_client_session import MultiClientSession\n\n                session = session_ctx.GetDefaultSession()\n                test_case.assertEqual(type(session), MultiClientSession)\n                import oneflow.framework.scope_util as scope_util\n\n                scope = scope_util.current_scope()\n                scope_proto = graph_build_util.scope_to_proto(scope)\n                test_case.assertEqual(session.id, scope_proto.session_id)\n                test_case.assertEqual(\n                    oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName(),\n                    self.name,\n                )\n                return x\n\n        g = CustomGraphGraphBuildCtx()\n        test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False)\n        data = np.array([2.0, 1.0, 0.0, -1.0, -2.0])\n        x = flow.tensor(data, dtype=flow.float32)\n        g._compile(x)\n        print(\"graph proto\", g._graph_proto)\n        test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False)\n\n    def test_block_scope(test_case):\n        class SubModule0(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv1 = flow.nn.Conv2d(1, 1, 5)\n\n            def forward(self, x):\n                scope = scope_util.current_scope()\n                scope_proto = graph_build_util.scope_to_proto(scope)\n                ck_bool = scope_proto.attr_name2attr_value[\"checkpointing\"].at_bool\n                test_case.assertEqual(ck_bool, True)\n                stage_int = scope_proto.attr_name2attr_value[\n                    \"pipeline_stage_id_hint\"\n                ].at_int64\n                test_case.assertEqual(stage_int, 
0)\n                out = self.conv1(x)\n                weight = self.conv1.weight\n                test_case.assertTrue(weight.is_lazy)\n                return out\n\n        class SubModule1(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.fc1 = flow.nn.Linear(36, 4, False)\n                self.register_buffer(\"dummy_buff\", flow.Tensor(1, 4))\n\n            def forward(self, x):\n                scope = scope_util.current_scope()\n                scope_proto = graph_build_util.scope_to_proto(scope)\n                test_case.assertEqual(\n                    scope_proto.parent_scope_symbol_id,\n                    self.to(flow.nn.graph.GraphModule).prev_scope.symbol_id,\n                )\n                ck_bool = scope_proto.attr_name2attr_value[\"checkpointing\"]\n                test_case.assertEqual(ck_bool.WhichOneof(\"value\"), None)\n                stage_int = scope_proto.attr_name2attr_value[\n                    \"pipeline_stage_id_hint\"\n                ].at_int64\n                test_case.assertEqual(stage_int, 1)\n                name = (\n                    self.to(flow.nn.graph.GraphModule).name_prefix\n                    + self.to(flow.nn.graph.GraphModule).name\n                )\n                prefixes = []\n                for prefix in scope_proto.scope_op_name_prefixes:\n                    prefixes.append(prefix)\n                name_in_scope = \".\".join(prefixes)\n                test_case.assertEqual(name, name_in_scope)\n                b = self.dummy_buff\n                dummy_buff_scope_proto = graph_build_util.scope_to_proto(\n                    self._buffers[\"dummy_buff\"].to(flow.nn.graph.GraphTensor).scope\n                )\n                test_case.assertEqual(\n                    dummy_buff_scope_proto.parent_scope_symbol_id, scope.symbol_id\n                )\n                x = self.fc1(x)\n                return x + b\n\n        class CustomModule1(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.layer0 = SubModule0()\n                self.layer1 = SubModule1()\n\n            def forward(self, x, y):\n                print(\"x0: \", x.shape)\n                x = self.layer0(x)\n                print(\"x1: \", x.shape)\n                print(\"y0: \", y.shape)\n                y = self.layer1(y)\n                print(\"y1: \", y.shape)\n                return (x, y)\n\n        m = CustomModule1()\n\n        class CustomGraphBlockScope(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.m = m\n                self.m.layer0.to(GraphModule).set_stage(stage_id=0)\n                self.m.layer0.to(GraphModule).activation_checkpointing = True\n                self.m.layer1.to(GraphModule).set_stage(stage_id=1)\n\n            def build(self, x, y):\n                return self.m(x, y)\n\n        g = CustomGraphBlockScope()\n        print(g)\n        x = np.ones((1, 1, 10, 10))\n        x = flow.tensor(x, dtype=flow.float32)\n        y = np.ones((16, 36))\n        y = flow.tensor(y, dtype=flow.float32)\n        g._compile(x, y)\n\n    def test_create_optimizer_in_graph(test_case):\n        device = \"cuda\"\n        linear = flow.nn.Linear(3, 8)\n        linear = linear.to(device)\n        flow.nn.init.constant_(linear.weight, 2.068758)\n        flow.nn.init.constant_(linear.bias, 0.23)\n\n        class OptCreatedInGraph(flow.nn.Graph):\n            def __init__(self):\n 
               super().__init__()\n                self.linear = linear\n                # create optimizer in nn.Graph and add parameters from ProxyModule\n                self.add_optimizer(\n                    flow.optim.SGD(self.linear.parameters(), lr=0.001, momentum=0.9)\n                )\n\n            def build(self, x):\n                out = self.linear(x)\n                out = out.sum()\n                out.backward()\n                return out\n\n        g = OptCreatedInGraph()\n        print(g)\n\n    def test_graph_in_subthread(test_case):\n        class TinyGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n\n            def build(self, input):\n                return input + 1\n\n        def f():\n            tiny_graph = TinyGraph()\n            input = flow.randn(1, 4)\n            return tiny_graph(input)\n\n        f()\n\n        new_thread = Thread(target=f)\n\n        new_thread.start()\n        new_thread.join()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_activation_checkpoint.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport re\nimport os\nimport unittest\n\nimport numpy as np\n\nimport oneflow\nimport oneflow as flow\nimport oneflow.framework.graph_build_util as graph_build_util\nimport oneflow.framework.scope_util as scope_util\nimport oneflow.unittest\nfrom oneflow.nn.graph import GraphModule\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestGraphActivationCheckpoint(flow.unittest.TestCase):\n    def test_activation_checkpoint(test_case):\n        loss_fn = flow.nn.MSELoss(reduction=\"sum\")\n        model = flow.nn.Sequential(flow.nn.Linear(3, 4), flow.nn.Linear(4, 4))\n        model1 = flow.nn.Sequential(flow.nn.Linear(4, 1), flow.nn.Flatten(0, 1))\n\n        class SubModule0(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.model = model\n\n            def forward(self, x):\n                scope = scope_util.current_scope()\n                scope_proto = graph_build_util.scope_to_proto(scope)\n                ck_bool = scope_proto.attr_name2attr_value[\"checkpointing\"].at_bool\n                test_case.assertEqual(ck_bool, True)\n                out = self.model(x)\n                return out\n\n        class SubModule1(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.model = model1\n\n            def forward(self, x):\n                scope = scope_util.current_scope()\n                scope_proto = graph_build_util.scope_to_proto(scope)\n                ck_bool = scope_proto.attr_name2attr_value[\"checkpointing\"].at_bool\n                test_case.assertEqual(ck_bool, True)\n                out = self.model(x)\n                return out\n\n        optimizer = flow.optim.SGD(model.parameters(), lr=1e-6)\n\n        class LinearTrainGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.model = SubModule0()\n                self.model1 = SubModule1()\n                self.loss_fn = loss_fn\n                # Add an optimizer\n                self.add_optimizer(optimizer)\n                self.model.to(GraphModule).activation_checkpointing = True\n                self.model1.to(GraphModule).activation_checkpointing = True\n\n            def build(self, x, y):\n                y_pred = self.model(x)\n                y_pred = self.model1(y_pred)\n                loss = self.loss_fn(y_pred, y)\n                loss.backward()\n                return loss\n\n        linear_graph = LinearTrainGraph()\n        x = flow.randn(10, 3)\n        y = flow.randn(10)\n        linear_graph._compile(x, y)\n\n        graph_proto = linear_graph._full_graph_proto\n        for op in graph_proto.net.op:\n            # Check flatten gradient operator take checkpoiting as input\n            if re.search(\"flatten.*grad\", op.name, re.I) 
is not None:\n                find_check_point = False\n                for value in op.user_conf.input.values():\n                    if (\n                        re.search(\"Sys-Checkpointing-Fake-Fw-Op\", str(value), re.I)\n                        is not None\n                    ):\n                        find_check_point = True\n                        print(value)\n                test_case.assertTrue(find_check_point)\n            # Check that an identity op has been inserted and that the first fake op of a segment has an identity grad as its ctrl-in op\n            if (\n                re.search(\n                    \"Sys-Checkpointing-Fake-Fw-Op_model.model.0-matmul*\", op.name, re.I,\n                )\n                is not None\n            ):\n                find_ctrl = False\n                for name in op.ctrl_in_op_name:\n                    if re.search(\"identity\", str(name), re.I) is not None:\n                        find_ctrl = True\n                        print(name)\n                test_case.assertTrue(find_ctrl)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_arange.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestArangeGraph(oneflow.unittest.TestCase):\n    def test_arange_graph(test_case):\n        of_eager_out = flow.arange(start=0, end=100, step=3, device=flow.device(\"cuda\"))\n\n        class ArangeGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n\n            def build(self):\n                return flow.arange(start=0, end=100, step=3, device=flow.device(\"cuda\"))\n\n        arange_g = ArangeGraph()\n        of_lazy_out = arange_g()\n        test_case.assertTrue(\n            np.allclose(of_eager_out.numpy(), of_lazy_out.numpy(), 1e-05, 1e-05)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_asymmetric_io.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestGlobalAsymmetricGraph(oneflow.unittest.TestCase):\n    def test_global_asymmetric_graph_gpu(test_case):\n        Broadcast = [flow.sbp.broadcast]\n        Placement_rank_0 = flow.placement(\"cuda\", ranks=[0])\n        Placement_rank_1 = flow.placement(\"cuda\", ranks=[1])\n\n        class MyGlobalAsymmetricModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear1 = flow.nn.Linear(3, 8, False)\n                self.linear2 = flow.nn.Linear(8, 7, False)\n                self.linear1.to_global(placement=Placement_rank_0, sbp=Broadcast)\n                self.linear2.to_global(placement=Placement_rank_1, sbp=Broadcast)\n                flow.nn.init.ones_(self.linear1.weight)\n                flow.nn.init.constant_(self.linear2.weight, 2.3)\n\n            def forward(self, x, y):\n                out0 = x + y\n                out1 = self.linear1(out0)\n                out1 = out1.to_global(placement=Placement_rank_1, sbp=Broadcast)\n                out2 = self.linear2(out1)\n                return out2\n\n        class MyLocalModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear1 = flow.nn.Linear(3, 8, False)\n                self.linear2 = flow.nn.Linear(8, 7, False)\n                flow.nn.init.ones_(self.linear1.weight)\n                flow.nn.init.constant_(self.linear2.weight, 2.3)\n\n            def forward(self, x, y):\n                # print(\"local_x in rank : \", flow.env.get_rank(), \" is : \", x)\n                # print(\"local_y in rank : \", flow.env.get_rank(), \" is : \", y)\n                out0 = x + y\n                out1 = self.linear1(out0)\n                out2 = self.linear2(out1)\n                return out2\n\n        my_local_module = MyLocalModule()\n        np_x = np.random.randn(5, 3)\n        np_y = np.ones(3)\n        local_x = flow.tensor(np_x, dtype=flow.float32)\n        global_x = local_x.to_global(\n            placement=flow.placement(\"cuda\", ranks=[0, 1]), sbp=Broadcast\n        )\n        local_x = global_x.to_local().to(\"cpu\")\n        local_y = flow.tensor(np_y, dtype=flow.float32)\n        local_out = my_local_module(local_x, local_y)\n        # print(\"eager_local_out: \", local_out)\n\n        my_module = MyGlobalAsymmetricModule()\n        x = local_x.to_global(placement=Placement_rank_0, sbp=Broadcast)\n        y = local_y.to_global(placement=Placement_rank_0, sbp=Broadcast)\n\n        class MyAsymmetricGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.my_net = my_module\n\n            def build(self, x, 
y):\n                return self.my_net(x, y)\n\n        my_g = MyAsymmetricGraph()\n        graph_out = my_g(x, y)\n        test_case.assertTrue(graph_out.placement == Placement_rank_1)\n        graph_local_out = graph_out.to_local()\n        # NOTE(chengcheng): numpy() MUST be called on each rank to sync the correct input copy\n        graph_local_out_np = graph_local_out.numpy()\n        # print(\"graph_local_out in rank \", flow.env.get_rank(),  \" is : \", graph_local_out)\n        if flow.env.get_rank() == 0:\n            test_case.assertTrue(graph_local_out.shape.numel() == 0)\n            test_case.assertTrue(graph_local_out_np.size == np.array([]).size)\n        elif flow.env.get_rank() == 1:\n            test_case.assertTrue(\n                np.allclose(\n                    graph_local_out.numpy(), local_out.numpy(), atol=1e-4, rtol=1e-4\n                )\n            )\n        else:\n            test_case.assertTrue(False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_block.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport types\nimport warnings\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\nimport oneflow.framework.graph_build_util as graph_build_util\nimport oneflow.framework.scope_util as scope_util\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestGraphBlock(flow.unittest.TestCase):\n    def test_module_has_custom_func(test_case):\n        class CustomModuleHasFunc(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.data_mem = 10\n\n            def forward(self, x):\n                return self._custom_func(x)\n\n            def _custom_func(self, x):\n                test_case.assertEqual(self.data_mem, 10)\n                return x\n\n        class CustomGraphHasFunc(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.m = CustomModuleHasFunc()\n\n            def build(self, x):\n                return self.m(x)\n\n        g = CustomGraphHasFunc()\n        x = np.ones((10, 10))\n        x = flow.tensor(x, dtype=flow.float32)\n        out = g(x)\n        test_case.assertTrue(np.array_equal(x.numpy(), out.numpy()))\n\n    def test_module_has_special_attr(test_case):\n        class CustomModuleHasSpecialAttr(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.config = 1\n                self.name = \"test_name\"\n\n            def forward(self, x):\n                test_case.assertEqual(self.config, 1)\n                test_case.assertEqual(self.name, \"test_name\")\n                test_case.assertEqual(self.to(nn.graph.GraphModule).name, \"m\")\n                return x\n\n        class CustomGraphHasSpecialAttr(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.m = CustomModuleHasSpecialAttr()\n\n            def build(self, x):\n                return self.m(x)\n\n        g = CustomGraphHasSpecialAttr()\n        x = np.ones((10, 10))\n        x = flow.tensor(x, dtype=flow.float32)\n        out = g(x)\n        test_case.assertTrue(np.array_equal(x.numpy(), out.numpy()))\n\n    def test_block_with_parameter(test_case):\n        device = \"cuda\"\n        linear = flow.nn.Linear(3, 8)\n        linear = linear.to(device)\n        flow.nn.init.constant_(linear.weight, 2.068758)\n        flow.nn.init.constant_(linear.bias, 0.23)\n        of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9)\n\n        x = flow.tensor(\n            [\n                [-0.94630778, -0.83378579, -0.87060891],\n                [2.0289922, -0.28708987, -2.18369248],\n                [0.35217619, -0.67095644, -1.58943879],\n                [0.08086036, -1.81075924, 1.20752494],\n                [0.8901075, 
-0.49976737, -1.07153746],\n                [-0.44872912, -1.07275683, 0.06256855],\n                [-0.22556897, 0.74798368, 0.90416439],\n                [0.48339456, -2.32742195, -0.59321527],\n            ],\n            dtype=flow.float32,\n            device=device,\n            requires_grad=False,\n        )\n\n        class CustomModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = linear\n\n            def forward(self, x):\n                return self._forward_impl(x)\n\n            def _forward_impl(self, x):\n                test_case.assertTrue(isinstance(self.linear, flow.nn.graph.Proxy))\n                return self.linear(x)\n\n        class LinearTrainGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.m = CustomModule()\n                self.add_optimizer(of_sgd)\n\n            def build(self, x):\n                out = self.m(x)\n                out = out.sum()\n                out.backward()\n                test_case.assertTrue(self.m.linear.weight.is_lazy)\n                return out\n\n        linear_t_g = LinearTrainGraph()\n\n        linear_t_g(x)\n\n    def test_block_get_class_in_forward(test_case):\n        device = \"cuda\"\n        linear = flow.nn.Linear(3, 8)\n        linear = linear.to(device)\n        flow.nn.init.constant_(linear.weight, 2.068758)\n        flow.nn.init.constant_(linear.bias, 0.23)\n        of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9)\n\n        x = flow.tensor(\n            [\n                [-0.94630778, -0.83378579, -0.87060891],\n                [2.0289922, -0.28708987, -2.18369248],\n                [0.35217619, -0.67095644, -1.58943879],\n                [0.08086036, -1.81075924, 1.20752494],\n                [0.8901075, -0.49976737, -1.07153746],\n                [-0.44872912, -1.07275683, 0.06256855],\n                [-0.22556897, 0.74798368, 0.90416439],\n                [0.48339456, -2.32742195, -0.59321527],\n            ],\n            dtype=flow.float32,\n            device=device,\n            requires_grad=False,\n        )\n\n        class CustomModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = linear\n\n            def forward(self, x):\n                return self._forward_impl(x)\n\n            def _forward_impl(self, x):\n                test_case.assertTrue(isinstance(self.linear, flow.nn.Module))\n                test_case.assertTrue(isinstance(self.linear, flow.nn.Linear))\n                return self.linear(x)\n\n        class LinearTrainGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.m = CustomModule()\n                test_case.assertTrue(isinstance(self.m.linear, flow.nn.graph.Proxy))\n                self.add_optimizer(of_sgd)\n\n            def build(self, x):\n                test_case.assertTrue(isinstance(self.m.linear, flow.nn.Module))\n                test_case.assertTrue(isinstance(self.m.linear, flow.nn.Linear))\n                out = self.m(x)\n                out = out.sum()\n                out.backward()\n                test_case.assertTrue(self.m.linear.weight.is_lazy)\n                return out\n\n        linear_t_g = LinearTrainGraph()\n        test_case.assertTrue(isinstance(linear_t_g.m.linear, flow.nn.graph.Proxy))\n\n        linear_t_g(x)\n\n    def 
test_block_with_not_registered_module(test_case):\n        device = \"cuda\"\n        linear = flow.nn.Linear(3, 8)\n        linear = linear.to(device)\n        flow.nn.init.constant_(linear.weight, 2.068758)\n        flow.nn.init.constant_(linear.bias, 0.23)\n\n        x = flow.tensor(\n            [\n                [-0.94630778, -0.83378579, -0.87060891],\n                [2.0289922, -0.28708987, -2.18369248],\n                [0.35217619, -0.67095644, -1.58943879],\n                [0.08086036, -1.81075924, 1.20752494],\n                [0.8901075, -0.49976737, -1.07153746],\n                [-0.44872912, -1.07275683, 0.06256855],\n                [-0.22556897, 0.74798368, 0.90416439],\n                [0.48339456, -2.32742195, -0.59321527],\n            ],\n            dtype=flow.float32,\n            device=device,\n            requires_grad=False,\n        )\n\n        class CustomModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.dict = {\"lin\": linear}\n\n            def forward(self, x):\n                return self._forward_impl(x)\n\n            def _forward_impl(self, x):\n                test_case.assertTrue(isinstance(self.dict[\"lin\"], flow.nn.Module))\n                return self.dict[\"lin\"](x)\n\n        class LinearTrainGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.m = CustomModule()\n\n            def build(self, x):\n                out = self.m(x)\n                out = out.sum()\n                return out\n\n        linear_t_g = LinearTrainGraph()\n\n        with warnings.catch_warnings(record=True) as w:\n            # The following warning will be printed:\n            #     UserWarning: Linear(in_features=3, out_features=8, bias=True) is called in a nn.Graph, but not registered into a nn.Graph.\n            linear_t_g(x)\n\n            test_case.assertTrue(len(w) == 1)\n            test_case.assertTrue(issubclass(w[-1].category, UserWarning))\n            test_case.assertTrue(\n                \"is called in a nn.Graph, but not registered into a nn.Graph\"\n                in str(w[-1].message)\n            )\n\n    def test_block_with_seq_container(test_case):\n        class SubModule0(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = flow.nn.Linear(10, 10, False)\n\n            def forward(self, x):\n                if graph_build_util.lazy_mode.is_enabled():\n                    scope = scope_util.current_scope()\n                    scope_proto = graph_build_util.scope_to_proto(scope)\n                    ck_bool = scope_proto.attr_name2attr_value[\"checkpointing\"].at_bool\n                    test_case.assertEqual(ck_bool, True)\n                out = self.linear(x)\n                return out\n\n        list_of_m = [SubModule0() for i in range(3)]\n\n        class SeqModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linears = flow.nn.Sequential(*list_of_m)\n\n            def forward(self, x):\n                x = self.linears(x)\n                return x\n\n        class SeqGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linears = flow.nn.Sequential(*list_of_m)\n                self.linears.to(nn.graph.GraphModule).activation_checkpointing = True\n\n            def build(self, x):\n                x = self.linears(x)\n                return x\n\n       
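 # SeqModule (eager) and SeqGraph share the same list_of_m modules, so their outputs should match exactly\n       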
 seq_m = SeqModule()\n        seq_g = SeqGraph()\n\n        input = flow.tensor(np.random.randn(4, 10), dtype=flow.float32)\n        output_m = seq_m(input)\n        output_g = seq_g(input)\n\n        # print(seq_g)\n        test_case.assertTrue(np.array_equal(output_m.numpy(), output_g.numpy()))\n\n    def test_block_with_list_container(test_case):\n        class SubModule0(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = flow.nn.Linear(10, 10, False)\n\n            def forward(self, x):\n                if graph_build_util.lazy_mode.is_enabled():\n                    scope = scope_util.current_scope()\n                    scope_proto = graph_build_util.scope_to_proto(scope)\n                    ck_bool = scope_proto.attr_name2attr_value[\"checkpointing\"].at_bool\n                    test_case.assertEqual(ck_bool, True)\n                out = self.linear(x)\n                return out\n\n        list_of_m = [SubModule0() for i in range(3)]\n\n        class ModuleListModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linears = flow.nn.ModuleList(list_of_m)\n\n            def forward(self, x):\n                for i, _ in enumerate(self.linears):\n                    x = self.linears[i](x)\n                return x\n\n        class ModuleListGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linears = flow.nn.ModuleList(list_of_m)\n                # NOTE: ModuleList doesn't have config.\n                # self.linears.to(GraphModule).activation_checkpointing = True\n                for i, _ in enumerate(self.linears):\n                    self.linears[i].to(\n                        nn.graph.GraphModule\n                    ).activation_checkpointing = True\n\n            def build(self, x):\n                # ModuleList can act as an iterable, or be indexed using ints\n                for i, _ in enumerate(self.linears):\n                    x = self.linears[i](x)\n\n                return x\n\n        module_list_m = ModuleListModule()\n        module_list_g = ModuleListGraph()\n\n        input = flow.tensor(np.random.randn(4, 10), dtype=flow.float32)\n        output_m = module_list_m(input)\n        output_g = module_list_g(input)\n\n        # print(module_list_g)\n        test_case.assertTrue(np.array_equal(output_m.numpy(), output_g.numpy()))\n\n    def test_module_list_slice(test_case):\n        class ModuleListSlice(nn.Module):\n            def __init__(self,):\n                super().__init__()\n                linear1 = nn.Linear(5, 5, bias=False)\n                linear2 = nn.Linear(5, 5, bias=False)\n                linear3 = nn.Linear(5, 5, bias=False)\n                self.modulelist = nn.ModuleList([linear1, linear2, linear3])\n\n            def forward(self, x):\n                sliced_m = self.modulelist[1:]\n                test_case.assertEqual(len(sliced_m), 2)\n                y = sliced_m[1](x)\n                return y\n\n        class GraphModuleListSlice(nn.Graph):\n            def __init__(self, m):\n                super().__init__()\n                self.m = m\n\n            def build(self, x):\n                return self.m(x)\n\n        in_tensor = flow.randn(5, 5)\n\n        m = ModuleListSlice()\n        eager_out = m(in_tensor)\n\n        g = GraphModuleListSlice(m)\n        graph_out = g(in_tensor)\n\n        test_case.assertTrue(np.array_equal(eager_out.numpy(), 
graph_out.numpy()))\n\n    def test_block_with_dict_container(test_case):\n        class SubModule0(flow.nn.Module):\n            def __init__(self, out):\n                super().__init__()\n                self.linear = flow.nn.Linear(10, out, False)\n\n            def forward(self, x):\n                if graph_build_util.lazy_mode.is_enabled():\n                    scope = scope_util.current_scope()\n                    scope_proto = graph_build_util.scope_to_proto(scope)\n                    ck_bool = scope_proto.attr_name2attr_value[\"checkpointing\"].at_bool\n                    test_case.assertEqual(ck_bool, True)\n                out = self.linear(x)\n                return out\n\n        dict_of_m = {\"0\": SubModule0(10), \"1\": SubModule0(6)}\n\n        class ModuleDictModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linears = flow.nn.ModuleDict(dict_of_m)\n\n            def forward(self, x):\n                x = self.linears[\"0\"](x)\n                x = self.linears[\"1\"](x)\n                return x\n\n        class ModuleDictGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linears = flow.nn.ModuleDict(dict_of_m)\n\n                # NOTE: ModuleDict doesn't have config.\n                # self.linears.to(GraphModule).activation_checkpointing = True\n                for k, _ in self.linears.items():\n                    self.linears[k].to(\n                        nn.graph.GraphModule\n                    ).activation_checkpointing = True\n\n            def build(self, x):\n                # ModuleDict can act as an iterable, or get using key\n                x = self.linears[\"0\"](x)\n                x = self.linears[\"1\"](x)\n                return x\n\n        module_dict_m = ModuleDictModule()\n        module_dict_g = ModuleDictGraph()\n\n        input = flow.tensor(np.random.randn(4, 10), dtype=flow.float32)\n        output_m = module_dict_m(input)\n        output_g = module_dict_g(input)\n\n        # print(module_dict_g)\n        test_case.assertTrue(np.array_equal(output_m.numpy(), output_g.numpy()))\n\n    def test_block_with_dict_container_nto1(test_case):\n        class SubModule0(flow.nn.Module):\n            def __init__(self, out):\n                super().__init__()\n                self.linear = flow.nn.Linear(10, out, False)\n\n            def forward(self, x):\n                if graph_build_util.lazy_mode.is_enabled():\n                    scope = scope_util.current_scope()\n                    scope_proto = graph_build_util.scope_to_proto(scope)\n                    ck_bool = scope_proto.attr_name2attr_value[\"checkpointing\"].at_bool\n                    test_case.assertEqual(ck_bool, True)\n                out = self.linear(x)\n                return out\n\n        sub_m = SubModule0(10)\n        dict_of_m = {\"0\": sub_m, \"1\": sub_m}\n\n        class ModuleDictModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linears = flow.nn.ModuleDict(dict_of_m)\n\n            def forward(self, x):\n                x = self.linears[\"0\"](x)\n                x = self.linears[\"1\"](x)\n                return x\n\n        class ModuleDictGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linears = flow.nn.ModuleDict(dict_of_m)\n\n                # NOTE: ModuleDict doesn't have config.\n                # 
self.linears.to(GraphModule).activation_checkpointing = True\n                for k, _ in self.linears.items():\n                    self.linears[k].to(\n                        nn.graph.GraphModule\n                    ).activation_checkpointing = True\n\n            def build(self, x):\n                # ModuleDict can act as an iterable, or get using key\n                x = self.linears[\"0\"](x)\n                x = self.linears[\"1\"](x)\n                return x\n\n        module_dict_m = ModuleDictModule()\n        module_dict_g = ModuleDictGraph()\n\n        input = flow.tensor(np.random.randn(4, 10), dtype=flow.float32)\n        output_m = module_dict_m(input)\n        output_g = module_dict_g(input)\n        print(module_dict_g)\n\n        # print(module_dict_g)\n        test_case.assertTrue(np.array_equal(output_m.numpy(), output_g.numpy()))\n\n    def test_block_with_para_list_container(test_case):\n        list_of_p = [flow.nn.Parameter(flow.randn(10, 10)) for i in range(2)]\n\n        class ParaListModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.params = flow.nn.ParameterList(list_of_p)\n\n            def forward(self, x):\n                for i, _ in enumerate(self.params):\n                    x = flow._C.matmul(x, self.params[i])\n                return x\n\n        class ParaListGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.params = flow.nn.ParameterList(list_of_p)\n\n            def build(self, x):\n                for i, _ in enumerate(self.params):\n                    x = flow._C.matmul(x, self.params[i])\n                return x\n\n        para_list_m = ParaListModule()\n        para_list_g = ParaListGraph()\n        # print(para_list_g)\n\n        input = flow.tensor(np.random.randn(4, 10), dtype=flow.float32)\n        output_m = para_list_m(input)\n        # print(output_m)\n        output_g = para_list_g(input)\n\n        # print(para_list_g)\n        test_case.assertTrue(np.array_equal(output_m.numpy(), output_g.numpy()))\n\n    def test_block_with_para_dict_container(test_case):\n        dict_of_p = {\n            \"0\": flow.nn.Parameter(flow.randn(10, 3)),\n            \"1\": flow.nn.Parameter(flow.randn(10, 10)),\n        }\n\n        class ParaDictModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.params = flow.nn.ParameterDict(dict_of_p)\n\n            def forward(self, x):\n                x = flow._C.matmul(x, self.params[\"0\"])\n                return x\n\n        class ParaDictGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.params = flow.nn.ParameterDict(dict_of_p)\n\n            def build(self, x):\n                x = flow._C.matmul(x, self.params[\"0\"])\n                return x\n\n        para_dict_m = ParaDictModule()\n        para_dict_g = ParaDictGraph()\n        # print(para_dict_g)\n\n        input = flow.tensor(np.random.randn(4, 10), dtype=flow.float32)\n        output_m = para_dict_m(input)\n        # print(output_m)\n        output_g = para_dict_g(input)\n\n        # print(para_dict_g)\n        test_case.assertTrue(np.array_equal(output_m.numpy(), output_g.numpy()))\n\n    def test_mixin_module(test_case):\n        class ModuleMixin(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self._dtype = flow.float32\n\n            
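# ModuleMixin stores _dtype and exposes it via the read-only property below; MixedModule.forward checks it\n            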
@property\n            def dtype(self):\n                return self._dtype\n\n        class ConfigMixin:\n            def hello_from_cfg(self):\n                return \"hello_from_cfg\"\n\n            @property\n            def property_from_cfg(self):\n                return 128\n\n        class MixedModule(ModuleMixin, ConfigMixin):\n            def __init__(self):\n                super().__init__()\n\n            def forward(self, x):\n                test_case.assertEqual(self.dtype, flow.float32)\n                test_case.assertEqual(self.hello_from_cfg(), \"hello_from_cfg\")\n                test_case.assertEqual(self.property_from_cfg, 128)\n                return x\n\n        mixedm = MixedModule()\n\n        class GraphConfigMixin(object):\n            @property\n            def hello_from_graph(self):\n                return \"hello_from_gcfg\"\n\n            def mixin_get_name(self):\n                return self.name\n\n        class MixinGraph(flow.nn.Graph, GraphConfigMixin):\n            def __init__(self):\n                super().__init__()\n                self.m = mixedm\n\n            def build(self, x):\n                test_case.assertEqual(self.hello_from_graph, \"hello_from_gcfg\")\n                test_case.assertEqual(self.mixin_get_name(), self.name)\n                return self.m(x)\n\n        g = MixinGraph()\n        x = np.ones((10, 10))\n        x = flow.tensor(x, dtype=flow.float32)\n        out = g(x)\n        test_case.assertTrue(np.array_equal(x.numpy(), out.numpy()))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_buffer_limit.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport time\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_graph_buffer_limit(test_case):\n    class StageLayerModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.linear1 = flow.nn.Linear(10, 8, False)\n            self.linear2 = flow.nn.Linear(8, 10, False)\n            flow.nn.init.constant_(self.linear1.weight, 0.023)\n            flow.nn.init.constant_(self.linear2.weight, 1.23)\n\n        def forward(self, x):\n            out0 = self.linear1(x)\n            out0 = out0 + 1.0\n            out0 = out0 * 2.0\n            out1 = self.linear2(out0)\n            return out1\n\n    P0 = flow.placement(\"cuda\", ranks=[0])\n    P1 = flow.placement(\"cuda\", ranks=[1])\n    PT = flow.placement(\"cuda\", ranks=[0, 1])\n    B = flow.sbp.broadcast\n\n    class PipelineModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.layer_0 = StageLayerModule()\n            self.layer_1 = StageLayerModule()\n            self.layer_0.to_global(P0, B)\n            self.layer_1.to_global(P1, B)\n\n        def forward(self, x):\n            # stage 0\n            in0 = x.to_global(P0, B)\n            out0 = self.layer_0(in0)\n            # stage 1\n            in1 = out0.to_global(P1, B)\n            out1 = self.layer_1(in1)\n            return out1\n\n    pp_m = PipelineModule()\n    pp_m.eval()\n\n    class PipelineGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.pp_m = pp_m\n\n        def build(self, x):\n            return self.pp_m(x)\n\n    pp_g = PipelineGraph()\n\n    for i in range(500):\n        x = flow.randn(16, 10)\n        x = x.to_global(P0, B)\n        out = pp_g(x)\n        # print(out.to_local().mean())\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestGraphPipelineBufferLimit(oneflow.unittest.TestCase):\n    def test_graph_buffer_limit(test_case):\n        _test_graph_buffer_limit(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_clip_grad_norm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow.nn.graph import GraphModule\nimport oneflow.unittest\n\n\nclass MyModule1(flow.nn.Module):\n    def __init__(self, param):\n        super().__init__()\n        self.param = flow.nn.Parameter(param)\n\n    def forward(self, input):\n        x = flow._C.matmul(input, self.param, transpose_b=True)\n        return flow._C.gelu(x)\n\n\nclass MyModule2(flow.nn.Module):\n    def __init__(self, param):\n        super().__init__()\n        self.param = flow.nn.Parameter(param)\n\n    def forward(self, input, target):\n        x = flow._C.matmul(input, self.param)\n        loss = flow._C.sparse_softmax_cross_entropy(x, target)\n        return loss.mean()\n        # return loss\n\n\ndef _make_optimizer(params, norm_type, max_norm):\n    return flow.optim.SGD(\n        [\n            {\n                \"params\": params,\n                \"lr\": 1.0,\n                \"momentum\": 0.0,\n                \"clip_grad_max_norm\": max_norm,\n                \"clip_grad_norm_type\": norm_type,\n            },\n        ]\n    )\n\n\nclass MyGraph(flow.nn.Graph):\n    def __init__(self, module1, module2, optimizer=None, acc=1):\n        super().__init__()\n\n        self.m1 = module1\n        self.m2 = module2\n\n        if (\n            module1.param.is_global\n            and module2.param.is_global\n            and module1.param.placement != module2.param.placement\n        ):\n            self.m1.to(GraphModule).set_stage(0)\n            self.m2.to(GraphModule).set_stage(1)\n\n        if optimizer is not None:\n            self.add_optimizer(optimizer)\n\n        if acc > 1:\n            self.config.set_gradient_accumulation_steps(acc)\n\n    def build(self, input, target):\n        x = self.m1(input)\n        if x.is_global and target.is_global and x.placement != target.placement:\n            x = x.to_global(placement=target.placement)\n        loss = self.m2(x, target)\n        loss.backward()\n        return loss\n\n\nclass TensorGenerator(object):\n    def __init__(\n        self, batch_size=8, feat1=10, feat2=8, device=\"cuda\", parallel_mode=None\n    ):\n        input = flow.randn(batch_size, feat1).to(device)\n        param1 = flow.randn(feat2, feat1).to(device)\n        param2 = flow.randn(feat2, feat1).to(device)\n        target = flow.randint(0, 10, (batch_size,)).to(device)\n\n        ranks = np.array(range(flow.env.get_world_size()))\n        placement = flow.placement(device, ranks)\n        self.input = input.to_global(placement, sbp=flow.sbp.broadcast)\n        self.param1 = param1.to_global(placement, sbp=flow.sbp.broadcast)\n        self.param2 = param2.to_global(placement, sbp=flow.sbp.broadcast)\n        self.target = target.to_global(placement, sbp=flow.sbp.broadcast)\n\n        self.input_sbp = None\n        self.target_sbp = None\n        self.param1_sbp = None\n        
self.param2_sbp = None\n        self.placement1 = None\n        self.placement2 = None\n\n        if parallel_mode is not None:\n            assert isinstance(parallel_mode, str) or isinstance(\n                parallel_mode, (list, tuple)\n            )\n\n            if isinstance(parallel_mode, str):\n                parallel_mode = [parallel_mode]\n\n            assert all(p.upper() in (\"DP\", \"MP\", \"PP\") for p in parallel_mode)\n            assert len(parallel_mode) > 0 and len(parallel_mode) <= 2\n\n            self.input_sbp = []\n            self.target_sbp = []\n            self.param1_sbp = []\n            self.param2_sbp = []\n\n            has_pp = False\n\n            for p in parallel_mode:\n                if p == \"DP\":\n                    self.input_sbp.append(flow.sbp.split(0))\n                    self.target_sbp.append(flow.sbp.split(0))\n                    self.param1_sbp.append(flow.sbp.broadcast())\n                    self.param2_sbp.append(flow.sbp.broadcast())\n                elif p == \"MP\":\n                    self.input_sbp.append(flow.sbp.broadcast())\n                    self.target_sbp.append(flow.sbp.broadcast())\n                    self.param1_sbp.append(flow.sbp.split(0))\n                    self.param2_sbp.append(flow.sbp.split(0))\n                elif p == \"PP\":\n                    ranks = ranks.reshape(2, -1)\n                    self.placement1 = flow.placement(device, ranks[0])\n                    self.placement2 = flow.placement(device, ranks[1])\n                    has_pp = True\n                else:\n                    raise ValueError\n\n            if len(parallel_mode) > 1 and not has_pp:\n                ranks = ranks.reshape(2, -1)\n                self.placement1 = flow.placement(device, ranks)\n                self.placement2 = flow.placement(device, ranks)\n\n            if len(self.input_sbp) == 0:\n                self.input_sbp = None\n\n            if len(self.target_sbp) == 0:\n                self.target_sbp = None\n\n            if len(self.param1_sbp) == 0:\n                self.param1_sbp = None\n\n            if len(self.param2_sbp) == 0:\n                self.param2_sbp = None\n\n    def local_input(self):\n        return self.input.to_local()\n\n    def local_target(self):\n        return self.target.to_local()\n\n    def local_param1(self):\n        return self.param1.clone().to_local()\n\n    def local_param2(self):\n        return self.param2.clone().to_local()\n\n    def global_input(self):\n        if self.input_sbp is None and self.placement1 is None:\n            return self.input\n\n        return self.input.to_global(placement=self.placement1, sbp=self.input_sbp)\n\n    def global_target(self):\n        if self.target_sbp is None and self.placement2 is None:\n            return self.target\n\n        return self.target.to_global(placement=self.placement2, sbp=self.target_sbp)\n\n    def global_param1(self):\n        if self.param1_sbp is None and self.placement1 is None:\n            return self.param1.clone()\n\n        return self.param1.to_global(placement=self.placement1, sbp=self.param1_sbp)\n\n    def global_param2(self):\n        if self.param2_sbp is None and self.placement2 is None:\n            return self.param2.clone()\n\n        return self.param2.to_global(placement=self.placement2, sbp=self.param2_sbp)\n\n\ndef _compare_with_eager(\n    test_case,\n    *,\n    batch_size=8,\n    acc=1,\n    norm_type=2.0,\n    max_norm=1.0,\n    device=\"cuda\",\n    parallel_mode=None,\n    
rtol=1e-03,\n    atol=1e-05,\n):\n    gen = TensorGenerator(\n        batch_size=batch_size, device=device, parallel_mode=parallel_mode\n    )\n\n    # eager\n    m1 = MyModule1(gen.local_param1())\n    m2 = MyModule2(gen.local_param2())\n    opt = _make_optimizer([m1.param, m2.param], norm_type, max_norm)\n    x = m1(gen.local_input())\n    loss = m2(x, gen.local_target())\n    opt.zero_grad()\n    loss.backward()\n    opt.clip_grad()\n    opt.step()\n\n    loss_a = loss.numpy()\n    grad1_a = m1.param.numpy()\n    grad2_a = m2.param.numpy()\n\n    # graph\n    graph_m1 = MyModule1(gen.global_param1())\n    graph_m2 = MyModule2(gen.global_param2())\n    opt = _make_optimizer([graph_m1.param, graph_m2.param], norm_type, max_norm)\n    graph = MyGraph(graph_m1, graph_m2, opt, acc)\n    graph_loss = graph(gen.global_input(), gen.global_target())\n\n    # debug\n    # rank = flow.env.get_rank()\n    # print(\"\")\n    # print(f\"[rank{rank}] eager local loss: {loss}\")\n\n    # print(\n    #     f\"[rank{rank}] graph_loss placement: {graph_loss.placement}, sbp: {graph_loss.sbp}\"\n    # )\n    # print(f\"[rank{rank}] graph_loss: {graph_loss}\")\n\n    # local_loss = graph_loss.to_local()\n    # print(f\"[rank{rank}] local_loss.numel(): {local_loss.numel()}\")\n    # print(f\"[rank{rank}] local_loss: {local_loss}\")\n\n    if acc > 1 and graph_loss.numel() == acc:\n        graph_loss = graph_loss.mean()\n\n    if parallel_mode is None:\n        loss_b = graph_loss.numpy()\n        grad1_b = graph.m1.to(flow.nn.Module).param.numpy()\n        grad2_b = graph.m2.to(flow.nn.Module).param.numpy()\n    else:\n        ranks = np.array(range(flow.env.get_world_size()))\n        placement = flow.placement(device, ranks)\n        loss_b = graph_loss.to_global(placement, flow.sbp.broadcast).to_local().numpy()\n        grad1_b = graph.m1.to(flow.nn.Module).param.to_global(\n            placement, flow.sbp.broadcast\n        )\n        grad1_b = grad1_b.to_local().numpy()\n        grad2_b = graph.m2.to(flow.nn.Module).param.to_global(\n            placement, flow.sbp.broadcast\n        )\n        grad2_b = grad2_b.to_local().numpy()\n\n    # compare\n    test_case.assertTrue(\n        np.allclose(loss_a, loss_b, rtol=rtol, atol=atol), f\"{loss_a} vs. 
{loss_b}\"\n    )\n    test_case.assertTrue(\n        np.allclose(grad1_a, grad1_b, rtol=rtol, atol=atol),\n        f\"\\n{grad1_a}\\nvs.\\n{grad1_b}\",\n    )\n    test_case.assertTrue(\n        np.allclose(grad2_a, grad2_b, rtol=rtol, atol=atol),\n        f\"\\n{grad2_a}\\nvs.\\n{grad2_b}\",\n    )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGraphClipGradNorm(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_local(test_case):\n        _compare_with_eager(test_case)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_acc(test_case):\n        _compare_with_eager(test_case, batch_size=8, acc=8)\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_dp(test_case):\n        _compare_with_eager(test_case, parallel_mode=\"DP\")\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_mp(test_case):\n        _compare_with_eager(test_case, parallel_mode=\"MP\")\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_pp(test_case):\n        _compare_with_eager(test_case, parallel_mode=\"PP\")\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_pp_acc(test_case):\n        _compare_with_eager(test_case, batch_size=8, acc=8, parallel_mode=\"PP\")\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_dp_mp(test_case):\n        _compare_with_eager(test_case, parallel_mode=[\"DP\", \"MP\"])\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_mp_pp(test_case):\n        _compare_with_eager(test_case, parallel_mode=[\"MP\", \"PP\"])\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_dp_pp(test_case):\n        _compare_with_eager(test_case, parallel_mode=[\"DP\", \"PP\"])\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_mp_pp_acc(test_case):\n        _compare_with_eager(test_case, batch_size=8, acc=8, parallel_mode=[\"MP\", \"PP\"])\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_dp_pp_acc(test_case):\n        _compare_with_eager(test_case, batch_size=8, acc=4, parallel_mode=[\"DP\", \"PP\"])\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGraphClipGradNormInf(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_local(test_case):\n        _compare_with_eager(test_case, norm_type=float(\"inf\"))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_acc(test_case):\n        _compare_with_eager(\n            test_case, batch_size=8, acc=8, norm_type=-float(\"inf\"), atol=1e-6\n        )\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_dp(test_case):\n        _compare_with_eager(\n            test_case,\n            norm_type=float(\"inf\"),\n            max_norm=2.0,\n            parallel_mode=\"DP\",\n            atol=1e-6,\n        )\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_mp(test_case):\n        _compare_with_eager(\n            test_case,\n            norm_type=-float(\"inf\"),\n            max_norm=3.0,\n            parallel_mode=\"MP\",\n            atol=1e-6,\n        )\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_pp(test_case):\n        _compare_with_eager(\n            test_case,\n            norm_type=float(\"inf\"),\n            max_norm=4.0,\n            parallel_mode=\"PP\",\n            atol=1e-6,\n        )\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_pp_acc(test_case):\n        _compare_with_eager(\n            test_case,\n            batch_size=8,\n            acc=8,\n            norm_type=-float(\"inf\"),\n            max_norm=5.0,\n            parallel_mode=\"PP\",\n           
 atol=1e-6,\n        )\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_dp_mp(test_case):\n        _compare_with_eager(\n            test_case,\n            norm_type=float(\"inf\"),\n            max_norm=1.1,\n            parallel_mode=[\"DP\", \"MP\"],\n            atol=1e-6,\n        )\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_mp_pp(test_case):\n        _compare_with_eager(\n            test_case,\n            norm_type=-float(\"inf\"),\n            max_norm=1.2,\n            parallel_mode=[\"MP\", \"PP\"],\n            atol=1e-6,\n        )\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_dp_pp(test_case):\n        _compare_with_eager(\n            test_case,\n            norm_type=float(\"inf\"),\n            max_norm=1.3,\n            parallel_mode=[\"DP\", \"PP\"],\n            atol=1e-6,\n        )\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_mp_pp_acc(test_case):\n        _compare_with_eager(\n            test_case,\n            batch_size=8,\n            acc=8,\n            norm_type=float(\"inf\"),\n            max_norm=2.1,\n            parallel_mode=[\"MP\", \"PP\"],\n            atol=1e-6,\n        )\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_dp_pp_acc(test_case):\n        _compare_with_eager(\n            test_case,\n            batch_size=8,\n            acc=4,\n            norm_type=-float(\"inf\"),\n            max_norm=2.2,\n            parallel_mode=[\"DP\", \"PP\"],\n            atol=1e-6,\n        )\n\n\nif __name__ == \"__main__\":\n    # flow.manual_seed(0)\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_copy.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestCopyGraph(oneflow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    def test_copy_graph(test_case):\n        linear = flow.nn.Linear(3, 8, False)\n        input_arr = np.array(\n            [\n                [-0.94630778, -0.83378579, -0.87060891],\n                [2.0289922, -0.28708987, -2.18369248],\n                [0.35217619, -0.67095644, -1.58943879],\n                [0.08086036, -1.81075924, 1.20752494],\n                [0.8901075, -0.49976737, -1.07153746],\n                [-0.44872912, -1.07275683, 0.06256855],\n                [-0.22556897, 0.74798368, 0.90416439],\n                [0.48339456, -2.32742195, -0.59321527],\n            ],\n            dtype=np.float32,\n        )\n        np_weight = np.ones((3, 8)).astype(np.float32)\n        np_weight.fill(2.3)\n        x = flow.tensor(input_arr)\n        flow.nn.init.constant_(linear.weight, 2.3)\n        of_eager_out = linear(x)\n        np_out = np.matmul(input_arr, np_weight)\n        test_case.assertTrue(np.allclose(of_eager_out.numpy(), np_out, 1e-05, 1e-05))\n\n        class LinearGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.my_linear = linear.to(flow.device(\"cuda\"))\n\n            def build(self, x):\n                x = x.to(flow.device(\"cuda\"))\n                return self.my_linear(x)\n\n        linear_g = LinearGraph()\n        of_lazy_out = linear_g(x)\n        test_case.assertTrue(\n            np.allclose(of_lazy_out.numpy(), of_eager_out.numpy(), 1e-05, 1e-05)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_debug.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport sys\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.nn.graph import GraphModule\n\n\nrank = flow.env.get_rank()\n\n\ndef _graph_debug(test_case, v_level=0, ranks=None, max_py_stack_depth=2):\n    class DebugGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = flow.nn.Linear(3, 3)\n\n        def build(self, x):\n            return x\n\n    d_g = DebugGraph()\n    d_g.debug(v_level, ranks=ranks, max_py_stack_depth=max_py_stack_depth)\n\n    if ranks is None:\n        rank_list = [0]\n    elif isinstance(ranks, int):\n        rank_list = [ranks]\n    elif isinstance(ranks, list):\n        rank_list = ranks\n\n    if (\n        -1 in rank_list or rank in rank_list\n    ) and v_level >= 0:  # v_level == -1 means debug mode is closed\n        test_case.assertTrue(d_g._debug)\n        test_case.assertTrue(d_g.m.to(GraphModule)._debug)\n        print(f\"ranks {ranks} rank {rank} debug is opened.\")\n    else:\n        test_case.assertTrue(not d_g._debug)\n        test_case.assertTrue(not d_g.m.to(GraphModule)._debug)\n        print(f\"ranks {ranks} rank {rank} debug is closed.\")\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n4d()\nclass TestGraphDebug(oneflow.unittest.TestCase):\n    def test_graph_debug_rank_null(test_case):\n        _graph_debug(test_case)\n\n    def test_graph_debug_rank_0(test_case):\n        _graph_debug(test_case, ranks=0)\n\n    def test_graph_debug_rank_1(test_case):\n        _graph_debug(test_case, ranks=1)\n\n    def test_graph_debug_rank_1_and_2(test_case):\n        _graph_debug(test_case, ranks=[1, 2])\n\n    def test_graph_debug_rank_all(test_case):\n        _graph_debug(test_case, ranks=-1)\n\n    def test_graph_debug_mode_closed(test_case):\n        _graph_debug(test_case, v_level=-1)\n\n    def test_graph_debug_mode_opened(test_case):\n        _graph_debug(test_case, v_level=0)\n\n    def test_graph_debug_max_py_stack_depth_2(test_case):\n        _graph_debug(test_case, max_py_stack_depth=2)\n\n    def test_graph_debug_max_py_stack_depth_8(test_case):\n        _graph_debug(test_case, max_py_stack_depth=8)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_depend.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport numpy as np\nimport unittest\n\n# used to observe operator optimization and execution order manually\n# import os\n# os.environ[\"ONEFLOW_DEBUG_MODE\"] = \"1\"\n# os.environ[\"GLOG_v\"] = \"3\"\n# os.environ[\"ENABLE_LOGICAL_CHAIN\"] = \"true\"\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\n\n# NOTE: nn.functional.depend() behaves differently in the two modes\n# in EAGER mode, the OP has no effect. That is, the first paramerter\n# and output are the same tensor (like \"y=x\" in python), while the\n# second paramerter will be ignore.\n\n\ndef _build_graph_and_test(TestModel, in_data, test_case):\n\n    model = TestModel()\n    y_eager = model(in_data)\n\n    class TestGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.model = model\n\n        def build(self, x):\n            return self.model(x)\n\n    graph = TestGraph()\n    # used to observe operator optimization and execution order manually\n    # graph.debug(3)\n    y_lazy = graph(in_data)\n    test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestDependGraph(oneflow.unittest.TestCase):\n    def test_depend_graph_case0(test_case):\n        class TestModel_0(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = nn.Linear(128, 128)\n\n            def forward(self, x):\n                # to ensure \"x * 2\" be executed before \"self.linear(x)\" in graph mode\n                # base use case\n                x1 = x * 2\n                x = nn.functional.depend(x, x1)\n                x2 = self.linear(x)\n                return x2\n\n        x = flow.randn([1, 128], dtype=flow.float32)\n        _build_graph_and_test(TestModel_0, x, test_case)\n\n    def test_depend_graph_case1(test_case):\n        class TestModel_1(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = nn.Linear(128, 128)\n\n            def forward(self, x):\n                # to ensure \"x * 2\" and \"x + 2\" be executed before \"self.linear(x)\" in graph mode\n                # test multiple continuous nn.functional.depend() in a logical chain\n                x1 = x * 2\n                x2 = x + 2\n                x = nn.functional.depend(x, x1)\n                x = nn.functional.depend(x, x2)\n                x3 = self.linear(x)\n                return x3\n\n        x = flow.randn([1, 128], dtype=flow.float32)\n        _build_graph_and_test(TestModel_1, x, test_case)\n\n    def test_depend_graph_case2(test_case):\n        class TestModel_2(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = nn.Linear(128, 128)\n\n            def forward(self, x):\n                # to ensure \"x * 2\" and \"x + 2\" be executed before 
\"self.linear(x)\" in graph mode\n                # some users may code like this\n                x1 = x * 2\n                x2 = x + 2\n                x2 = nn.functional.depend(x2, x1)\n                x = nn.functional.depend(x, x2)\n                x3 = self.linear(x)\n                return x3\n\n        x = flow.randn([1, 128], dtype=flow.float32)\n        _build_graph_and_test(TestModel_2, x, test_case)\n\n    def test_depend_graph_case3(test_case):\n        class TestModel_3(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = nn.Linear(128, 128)\n\n            def forward(self, x):\n                # to ensure \"x * 2\", \"x + 2\" and \"x -2\" be executed before \"self.linear(x)\" in graph mode\n                # a combination of above cases\n                x1 = x * 2\n                x2 = x + 2\n                x3 = x - 2\n                x = nn.functional.depend(x, x1)\n                x2 = nn.functional.depend(x2, x3)\n                x = nn.functional.depend(x, x2)\n                x3 = self.linear(x)\n                return x3\n\n        x = flow.randn([1, 128], dtype=flow.float32)\n        _build_graph_and_test(TestModel_3, x, test_case)\n\n    def test_depend_graph_case4(test_case):\n        class TestModel_4(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = nn.Linear(128, 128)\n\n            def forward(self, x):\n                # the depend OP do nothing and it should be pruned from graph correctly\n                x1 = x * 2\n                x2 = nn.functional.depend(x, x1)\n                x3 = self.linear(x)\n                return x3\n\n        x = flow.randn([1, 128], dtype=flow.float32)\n        _build_graph_and_test(TestModel_4, x, test_case)\n\n    def test_depend_graph_case5(test_case):\n        class TestModel_5(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear0 = nn.Linear(128, 128)\n                self.linear1 = nn.Linear(128, 128)\n\n            def forward(self, x):\n                # to ensure \"x * 2\" be executed before \"self.linear0(x)\" and\n                # \"self.linear1(x)\" in graph mode\n                # to test the case that depend OP connect to more than one OPs\n                x1 = x * 2\n                x = nn.functional.depend(x, x1)\n                x2 = self.linear0(x)\n                x3 = self.linear1(x)\n                return x2 + x3\n\n        x = flow.randn([1, 128], dtype=flow.float32)\n        _build_graph_and_test(TestModel_5, x, test_case)\n\n    def test_depend_graph_case6(test_case):\n        class TestModel_6(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = nn.Linear(128, 128)\n\n            def forward(self, x):\n                # to ensure \"x - 2\" be executed before \"self.linear(x)\" in graph mode\n                # to test the case that the OP connects to Depend OP also connects to other OPs\n                x1 = x * 2\n                x2 = x1 - 2\n                x3 = nn.functional.depend(x2, x1)\n                x4 = self.linear(x3)\n                x5 = x2 + x4\n                return x5\n\n        x = flow.randn([1, 128], dtype=flow.float32)\n        _build_graph_and_test(TestModel_6, x, test_case)\n\n    def test_depend_graph_case7(test_case):\n        class TestModel_7(nn.Module):\n            def __init__(self):\n                super().__init__()\n\n            def 
forward(self, x):\n                # to ensure \"mp_values * 2\" be executed before \"max_pool1d\" in graph mode\n                # to test the case that OPs have mutiple outputs connect to depend OP\n                x1 = x + 2\n                mp_values, mp_indices = nn.functional.max_pool1d(\n                    x, kernel_size=2, return_indices=True\n                )\n                mp_values = nn.functional.depend(mp_values, x1)\n                mp_values = mp_values * 2\n                return mp_values + mp_indices.to(flow.float32)\n\n        x = flow.randn([1, 2, 3], dtype=flow.float32)\n        _build_graph_and_test(TestModel_7, x, test_case)\n\n    def test_depend_graph_case8(test_case):\n        class TestModel_1(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = nn.Linear(128, 128)\n\n            def forward(self, x):\n                # to ensure \"x * 2\" and \"x + 2\" be executed before \"self.linear(x)\" in graph mode\n                # to test the case that inputting mutiple depend tensors at a time\n                x1 = x * 2\n                x2 = x + 2\n                x = nn.functional.depend(x, [x1, x2])\n                x3 = self.linear(x)\n                return x3\n\n        x = flow.randn([1, 128], dtype=flow.float32)\n        _build_graph_and_test(TestModel_1, x, test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_eye.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nimport random\nimport oneflow as flow\nimport oneflow.unittest\nfrom test_util import generate_graph\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestEyeGraph(oneflow.unittest.TestCase):\n    def test_eye_graph(test_case):\n        n = random.randint(1, 10)\n        m = random.randint(1, 10)\n\n        eye_fn = lambda: flow.eye(n, m)\n        y_eager = eye_fn()\n        eye_graph = generate_graph(eye_fn)\n        y_lazy = eye_graph()\n        test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy()))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_free_eager_tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestGraphWithEagerTensorCaught(oneflow.unittest.TestCase):\n    def test_eager_tensor_forward_graph(test_case):\n        class MyModuleWithEagerTensorForward(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = flow.nn.Linear(3, 8, False)\n\n            def forward(self, x):\n                y0 = self.linear(x)\n                eager_t = flow.tensor([1.0], dtype=y0.dtype, device=y0.device)\n                out = y0 + eager_t\n                return out\n\n        my_net_module = MyModuleWithEagerTensorForward()\n        flow.nn.init.constant_(my_net_module.linear.weight, 2.3)\n        x = np.random.randn(5, 3)\n        x = flow.tensor(x, dtype=flow.float32)\n\n        class GraphEagerTensorCaught(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.my_net = my_net_module\n\n            def build(self, x):\n                return self.my_net(x)\n\n        my_g = GraphEagerTensorCaught()\n        graph_out = my_g(x)\n        eager_out = my_net_module(x)\n        test_case.assertTrue(\n            np.allclose(graph_out.numpy(), eager_out.numpy(), atol=1e-4, rtol=1e-4)\n        )\n\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    def test_eager_tensor_to(test_case):\n        class EagerTensorToModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n\n            def forward(self):\n                # test free eager tensor to\n                t = flow.tensor([1.0], dtype=flow.float32).to(\"cuda\")\n                return t\n\n        e_m = EagerTensorToModule()\n\n        class EagerTensorToGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.e_m = e_m\n\n            def build(self):\n                return self.e_m()\n\n        e_g = EagerTensorToGraph()\n        graph_out = e_g()\n        eager_out = e_m()\n        test_case.assertTrue(\n            np.allclose(graph_out.numpy(), eager_out.numpy(), atol=1e-4, rtol=1e-4)\n        )\n\n    def test_two_graph_caught_same_free_eager_tensor(test_case):\n        np_x = np.random.randn(5, 3)\n        np_y = np.random.randn(5, 3)\n        x = flow.tensor(np_x, dtype=flow.float32)\n        y = flow.tensor(np_y, dtype=flow.float32)\n\n        class GraphAdd(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n\n            def build(self):\n                return x + y\n\n        class GraphMul(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n\n            def build(self):\n                return x * 
y\n\n        g_add = GraphAdd()\n        g_mul = GraphMul()\n\n        add_out = g_add()\n        mul_out = g_mul()\n        test_case.assertTrue(\n            np.allclose(add_out.numpy(), np_x + np_y, atol=1e-4, rtol=1e-4)\n        )\n        test_case.assertTrue(\n            np.allclose(mul_out.numpy(), np_x * np_y, atol=1e-4, rtol=1e-4)\n        )\n\n    def test_graph_return_free_eager_tensor(test_case):\n        np_x = np.random.randn(5, 3)\n        x = flow.tensor(np_x, dtype=flow.float32)\n\n        class GraphReturnEager(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n\n            def build(self):\n                # Return free eager tensor\n                return x\n\n        g_return_eager = GraphReturnEager()\n\n        # Run first time\n        ret_eager_out = g_return_eager()\n        test_case.assertTrue(\n            np.allclose(ret_eager_out.numpy(), np_x, atol=1e-4, rtol=1e-4)\n        )\n\n        # Run second time\n        ret_eager_out1 = g_return_eager()\n        test_case.assertTrue(\n            np.allclose(ret_eager_out1.numpy(), np_x, atol=1e-4, rtol=1e-4)\n        )\n\n    def test_graph_return_inplace_free_eager_tensor(test_case):\n        np_x = np.random.randn(5, 3)\n        x = flow.tensor(np_x, dtype=flow.float32)\n\n        class GraphInplaceReturnEager(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n\n            def build(self):\n                # x is free eager tensor\n                # mul_ is inplace scalar mul\n                # Input and output of mul_ are both tensor x\n                # After lazy interpretr, tensor x's name will be the ouput lbn of mul_\n                x.mul_(2)\n                # Here will return the output of mul_\n                return x\n\n        g_return_eager = GraphInplaceReturnEager()\n\n        # Run first time\n        ret_eager_out = g_return_eager()\n        # x in ouput changed\n        # So nn.Graph simulate inplace in nn.Graph.build().\n        test_case.assertTrue(\n            np.allclose(ret_eager_out.numpy(), np_x * 2, atol=1e-4, rtol=1e-4)\n        )\n        # x has not changed\n        # So nn.Graph inplace will not change free eager tensor.\n        test_case.assertTrue(np.allclose(x.numpy(), np_x, atol=1e-4, rtol=1e-4))\n\n        # Run second time\n        ret_eager_out = g_return_eager()\n        test_case.assertTrue(\n            np.allclose(ret_eager_out.numpy(), np_x * 2, atol=1e-4, rtol=1e-4)\n        )\n        test_case.assertTrue(np.allclose(x.numpy(), np_x, atol=1e-4, rtol=1e-4))\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass GlobalFreeEagerTensorGraphTestCase(oneflow.unittest.TestCase):\n    def test_global_eager_tensor_to(test_case):\n        rank = flow.env.get_rank()\n        placement = flow.placement(\"cpu\", ranks=[0, 1])\n        t_l = flow.tensor([1.0, 2.0], dtype=flow.float32)\n        t = t_l.to_global(placement=placement, sbp=flow.sbp.broadcast)\n\n        class GlobalEagerTensorToModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n\n            def forward(self):\n                # test free eager tensor to\n                nonlocal t\n                t = t.to(\"cuda\")\n                return t\n\n        e_m = GlobalEagerTensorToModule()\n\n        class GlobalEagerTensorToGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                
self.e_m = e_m\n\n            def build(self):\n                return self.e_m()\n\n        e_g = GlobalEagerTensorToGraph()\n        graph_out = e_g().to_local()\n        print(\"g \", graph_out.numpy())\n        test_case.assertTrue(\n            np.allclose(graph_out.numpy(), t_l.numpy(), atol=1e-4, rtol=1e-4)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_grad_acc.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport os\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_grad_acc_graph(test_case, device):\n    def get_linear_sgd():\n        linear = flow.nn.Linear(3, 8)\n        linear = linear.to(device)\n        flow.nn.init.constant_(linear.weight, 2.068758)\n        flow.nn.init.constant_(linear.bias, 1.23)\n        of_sgd = flow.optim.SGD(linear.parameters(), lr=0.01, momentum=0.9)\n        return linear, of_sgd\n\n    x = flow.tensor(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n            [0.35217619, -0.67095644, -1.58943879],\n            [0.08086036, -1.81075924, 1.20752494],\n            [0.8901075, -0.49976737, -1.07153746],\n            [-0.44872912, -1.07275683, 0.06256855],\n            [-0.22556897, 0.74798368, 0.90416439],\n            [0.48339456, -2.32742195, -0.59321527],\n        ],\n        device=device,\n        requires_grad=False,\n    )\n\n    free_one = flow.tensor([1.0], device=device, requires_grad=False)\n    eager_linear, eager_sgd = get_linear_sgd()\n    eager_out_list = []\n    eager_weight_list = []\n    for i in range(12):\n        index = (i % 4) * 2\n        input = x[index : (index + 2)]  # NOTE(chengcheng): unpack x by slice\n        # print(\"i = \", i, \" input = \", input)\n        of_out = eager_linear(input)\n        of_out += free_one  # Test free eager tensor\n        one = flow.ones(of_out.shape, dtype=of_out.dtype, device=of_out.device)\n        of_out += one\n        of_out = flow.reshape(of_out, shape=[-1])\n        of_out = of_out.sum()\n        loss = of_out * 0.25  # NOTE(chengcheng): scale loss by grad acc\n        loss.backward()\n        if (i + 1) % 4 == 0:\n            eager_sgd.step()\n            eager_sgd.zero_grad()\n            eager_weight_list.append(eager_linear.weight.numpy())\n            # print(\"of_eager_weight in step: \", i,\n            #      \" weight = \", eager_linear.weight.numpy())\n\n        # print(\"of_eager_out : \", of_out.numpy())\n        eager_out_list.append(of_out.numpy())\n\n    graph_linear, graph_sgd = get_linear_sgd()\n    graph_out_list = []\n    graph_weight_list = []\n\n    class LinearTrainGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.linear = graph_linear\n            self.add_optimizer(graph_sgd)\n            self.config.set_gradient_accumulation_steps(4)\n\n        def build(self, x):\n            out = self.linear(x)\n            out += free_one  # Test free eager tensor\n            one = flow.ones(out.shape, dtype=out.dtype, device=out.device)\n            out += one\n            out = flow.reshape(out, shape=[-1])\n            # print(\"out.shape: \", out.shape)\n            loss = out.sum()\n            loss.backward()\n            return out, loss\n\n    linear_t_g = LinearTrainGraph()\n    for 
i in range(3):\n        # NOTE(chengcheng): Graph runs 1 step for 1 mini-batch (4 micro-batches)\n        non_scalar_out, of_out = linear_t_g(x)\n        # print(\"of_lazy_out : \", of_out.numpy())\n\n        graph_out_list.append(of_out.numpy())\n        graph_weight_list.append(graph_linear.weight.numpy())\n        # print(\"of_lazy_weight in step: \", i,\n        #       \" weight = \", graph_linear.weight.numpy())\n\n    for i in range(3):\n        test_case.assertTrue(np.allclose(eager_weight_list[i], graph_weight_list[i]))\n        for j in range(4):\n            test_case.assertTrue(\n                eager_out_list[i * 4 + j].item() == graph_out_list[i][j]\n            )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestGradAccGraph(oneflow.unittest.TestCase):\n    def test_grad_acc_graph_gpu(test_case):\n        _test_grad_acc_graph(test_case, flow.device(\"cuda\"))\n\n    def test_grad_acc_graph_cpu(test_case):\n        _test_grad_acc_graph(test_case, flow.device(\"cpu\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_image_gpu_decoder.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass OFRecordDataLoader(flow.nn.Module):\n    def __init__(self):\n        super().__init__()\n        batch_size = 4\n        image_size = 224\n        self.train_record_reader = flow.nn.OFRecordReader(\n            flow.unittest.dataset_dir(\"imagenette/ofrecord\"),\n            batch_size=batch_size,\n            data_part_num=1,\n            part_name_suffix_length=5,\n            random_shuffle=True,\n            shuffle_after_epoch=True,\n            # placement=flow.placement(\"cpu\", ranks=[0]),\n            # sbp=[flow.sbp.broadcast]\n        )\n\n        self.record_label_decoder = flow.nn.OFRecordRawDecoder(\n            \"class/label\", shape=(), dtype=flow.int32\n        )\n        self.bytes_decoder = flow.nn.OFRecordBytesDecoder(\"encoded\")\n        self.image_gpu_decoder = flow.nn.OFRecordImageGpuDecoderRandomCropResize(\n            target_width=image_size, target_height=image_size, num_workers=3\n        )\n\n        color_space = \"RGB\"\n        output_layout = \"NHWC\"\n\n        self.flip = flow.nn.CoinFlip(\n            batch_size=batch_size,\n            # placement=flow.placement(\"cpu\", ranks=[0]),\n            # sbp=[flow.sbp.broadcast]\n        )\n\n        rgb_mean = [123.68, 116.779, 103.939]\n        rgb_std = [58.393, 57.12, 57.375]\n        self.crop_mirror_norm = flow.nn.CropMirrorNormalize(\n            color_space=color_space,\n            output_layout=output_layout,\n            mean=rgb_mean,\n            std=rgb_std,\n            output_dtype=flow.float,\n        )\n\n    def forward(self) -> (flow.Tensor, flow.Tensor):\n        train_record = self.train_record_reader()\n        label = self.record_label_decoder(train_record)\n        encoded = self.bytes_decoder(train_record)\n        image = self.image_gpu_decoder(encoded)\n        rng = self.flip()\n        if image.is_cuda:\n            rng = rng.to(\"cuda\")\n        image = self.crop_mirror_norm(image, rng)\n        return image, label\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestImageGpuDecoderGraph(oneflow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    def test_image_gpu_decoder_graph(test_case):\n        cc_reader = OFRecordDataLoader()\n\n        class GraphReader(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.my_reader = cc_reader\n\n            def build(self):\n                return self.my_reader()\n\n        reader_g = GraphReader()\n        image, label = reader_g()\n        print(image.shape)\n        print(label)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_inplace_add.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_graph_lazy_inplace(test_case, x, y):\n    class LazyInplaceAdd(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, x, y):\n            x += y\n            return x\n\n    z = LazyInplaceAdd()(x, y)\n    test_case.assertTrue(np.allclose(z.numpy(), (x + y).numpy(), 1e-05, 1e-05))\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestLocalInplace(oneflow.unittest.TestCase):\n    def test_graph_inplace_gpu(test_case):\n        x = flow.randn(10, 10, device=flow.device(\"cuda\"))\n        y = flow.ones(10, device=flow.device(\"cuda\"))\n        _test_graph_lazy_inplace(test_case, x, y)\n\n    def test_graph_inplace_cpu(test_case):\n        x = flow.randn(10, 10, device=flow.device(\"cpu\"))\n        y = flow.ones(10, device=flow.device(\"cpu\"))\n        _test_graph_lazy_inplace(test_case, x, y)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestGlobalInplace(oneflow.unittest.TestCase):\n    def test_graph_inplace_gpu(test_case):\n        x = flow.randn(\n            10,\n            10,\n            placement=flow.placement(\"cuda\", ranks=[0, 1]),\n            sbp=flow.sbp.split(1),\n        )\n        y = flow.ones(\n            10, placement=flow.placement(\"cuda\", ranks=[0, 1]), sbp=flow.sbp.broadcast\n        )\n        _test_graph_lazy_inplace(test_case, x, y)\n\n    def test_graph_inplace_cpu(test_case):\n        x = flow.randn(\n            10, 10, placement=flow.placement(\"cpu\", ranks=[0, 1]), sbp=flow.sbp.split(1)\n        )\n        y = flow.ones(\n            10, placement=flow.placement(\"cpu\", ranks=[0, 1]), sbp=flow.sbp.broadcast\n        )\n        _test_graph_lazy_inplace(test_case, x, y)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_io_check.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport warnings\nfrom collections import OrderedDict\nfrom dataclasses import dataclass, fields\nfrom typing import Any, Tuple\nfrom collections import OrderedDict\nimport os\nimport unittest\nimport sys\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.framework.tensor import Tensor, TensorTuple\nfrom oneflow.framework.args_tree import ArgsTree\nfrom oneflow.nn.graph import GraphModule\n\n\nclass BaseOutput(OrderedDict):\n    def __post_init__(self):\n        class_fields = fields(self)\n\n        # Safety and consistency checks\n        if not len(class_fields):\n            raise ValueError(f\"{self.__class__.__name__} has no fields.\")\n\n        first_field = getattr(self, class_fields[0].name)\n        other_fields_are_none = all(\n            getattr(self, field.name) is None for field in class_fields[1:]\n        )\n\n        if other_fields_are_none and isinstance(first_field, dict):\n            for key, value in first_field.items():\n                self[key] = value\n        else:\n            for field in class_fields:\n                v = getattr(self, field.name)\n                if v is not None:\n                    self[field.name] = v\n\n    def __delitem__(self, *args, **kwargs):\n        raise Exception(\n            f\"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.\"\n        )\n\n    def setdefault(self, *args, **kwargs):\n        raise Exception(\n            f\"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.\"\n        )\n\n    def pop(self, *args, **kwargs):\n        raise Exception(\n            f\"You cannot use ``pop`` on a {self.__class__.__name__} instance.\"\n        )\n\n    def update(self, *args, **kwargs):\n        raise Exception(\n            f\"You cannot use ``update`` on a {self.__class__.__name__} instance.\"\n        )\n\n    def __getitem__(self, k):\n        if isinstance(k, str):\n            inner_dict = {k: v for (k, v) in self.items()}\n            if (\n                self.__class__.__name__\n                in [\"StableDiffusionPipelineOutput\", \"ImagePipelineOutput\"]\n                and k == \"sample\"\n            ):\n                warnings.warn(\n                    \"The keyword 'samples' is deprecated and will be removed in version 0.4.0. 
Please use `.images` or\"\n                    \" `'images'` instead.\",\n                    DeprecationWarning,\n                )\n                return inner_dict[\"images\"]\n            return inner_dict[k]\n        else:\n            return self.to_tuple()[k]\n\n    def __setattr__(self, name, value):\n        if name in self.keys() and value is not None:\n            # Don't call self.__setitem__ to avoid recursion errors\n            super().__setitem__(name, value)\n        super().__setattr__(name, value)\n\n    def __setitem__(self, key, value):\n        # Will raise a KeyException if needed\n        super().__setitem__(key, value)\n        # Don't call self.__setattr__ to avoid recursion errors\n        super().__setattr__(key, value)\n\n    def to_tuple(self) -> Tuple[Any]:\n        \"\"\"\n        Convert self to a tuple containing all the attributes/keys that are not `None`.\n        \"\"\"\n        return tuple(self[k] for k in self.keys())\n\n\n@dataclass\nclass CustomDataClass(BaseOutput):\n    sample: flow.Tensor\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestGraphIOCheck(flow.unittest.TestCase):\n    def test_io_node(test_case):\n        x = np.ones((2, 2))\n        x = flow.tensor(x, dtype=flow.float32)\n\n        t2 = np.ones((2, 2))\n        t2 = flow.tensor(t2, dtype=flow.float32)\n        t3 = np.ones((2, 2))\n        t3 = flow.tensor(t3, dtype=flow.float32)\n        lt0 = list()\n        lt0.append(t2)\n        lt0.append(t3)\n\n        t4 = np.ones((2, 2))\n        t4 = flow.tensor(t4, dtype=flow.float32)\n\n        t4 = np.ones((2, 2))\n        t4 = flow.tensor(t4, dtype=flow.float32)\n\n        def fn(*args, **kwargs):\n            inp = (args, kwargs)\n            print(\"origin: \", inp)\n\n            args_tree = ArgsTree(inp, True, \"Graph_0\", None)\n\n            for (name, arg) in args_tree.iter_named_nodes():\n                print(name, repr(arg))\n\n            def leaf_fn(arg):\n                if isinstance(arg.value(), str):\n                    return \"mapped_str\"\n                return arg.value()\n\n            m_v = args_tree.map_leaf(leaf_fn)\n            print(\"mapped:\", m_v)\n            return m_v[0], m_v[1]\n\n        ret = fn(None, 1, \"test_str\", x, lt0, {\"t\": t4, \"l\": lt0}, kw=t4)\n        print(ret)\n        test_case.assertEqual(ret[0][2], \"mapped_str\")\n        test_case.assertEqual(id(ret[1][\"kw\"]), id(t4))\n\n    def test_io_node_with_simple_tuple_or_list_input(self):\n        x = np.ones((2, 2))\n        x = flow.tensor(x, dtype=flow.float32)\n\n        t2 = np.ones((2, 2))\n        t2 = flow.tensor(t2, dtype=flow.float32)\n        t3 = np.ones((2, 2))\n        t3 = flow.tensor(t3, dtype=flow.float32)\n        t4 = np.ones((2, 2))\n        t4 = flow.tensor(t4, dtype=flow.float32)\n        t5 = np.ones((2, 2))\n        t5 = flow.tensor(t4, dtype=flow.float32)\n        t6 = np.ones((2, 2))\n        t6 = flow.tensor(t4, dtype=flow.float32)\n\n        input_tuple = (x, t2, t3, t4)\n        input_list = [t5, t6]\n\n        def fn(args):\n            print(\"origin: \", args)\n\n            args_tree = ArgsTree(args, False)\n\n            for arg in args_tree.iter_nodes():\n                print(repr(arg))\n\n            def leaf_fn(value):\n                if isinstance(value, Tensor) and not value.is_contiguous():\n                    value.contiguous_()\n                return value\n\n            m_v = 
args_tree.map_tuple_leaf(leaf_fn)\n            print(\"mapped:\", m_v)\n            return m_v\n\n        # input tuple\n        ret = fn(input_tuple)\n        print(ret)\n        self.assertTrue(isinstance(ret, tuple))\n        self.assertEqual(id(ret[0]), id(x))\n        self.assertEqual(id(ret[1]), id(t2))\n        self.assertEqual(id(ret[2]), id(t3))\n        self.assertEqual(id(ret[3]), id(t4))\n\n        # input list\n        ret = fn(input_list)\n        print(ret)\n        self.assertTrue(isinstance(ret, list))\n        self.assertEqual(id(ret[0]), id(t5))\n        self.assertEqual(id(ret[1]), id(t6))\n\n    def test_custom_class(test_case):\n        x = np.ones((2, 2))\n        x = flow.tensor(x, dtype=flow.float32)\n        ordered_d = CustomDataClass(sample=x)\n\n        def fn(*args, **kwargs):\n            inp = (args, kwargs)\n            print(\"origin: \", inp)\n\n            args_tree = ArgsTree(inp, True, \"Graph_0\", None)\n\n            for (name, arg) in args_tree.iter_named_nodes():\n                print(name, repr(arg))\n\n            def leaf_fn(arg):\n                if isinstance(arg.value(), dict):\n                    return \"replaced\"\n                return arg.value()\n\n            m_v = args_tree.map_leaf(leaf_fn)\n            print(\"mapped:\", m_v)\n            return m_v[0], m_v[1]\n\n        ret = fn(ordered_d)\n        print(ret)\n\n    def test_non_tensor_types_of_module(test_case):\n        class CustomModuleIOCheck(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n\n            def forward(self, t, lt, n, i, s, **kwargs):\n                return t, lt, n, i, s, kwargs\n\n        class CustomGraphIOCheck(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.m = CustomModuleIOCheck()\n                self.m.to(GraphModule).activation_checkpointing = True\n\n            def build(self, t, lt, n, **kwargs):\n                rt, rlt, n, ri, rs, dic = self.m(t, lt, n, 1, \"2\", **kwargs)\n                return t, lt, n, dic\n\n        g = CustomGraphIOCheck()\n\n        x = flow.tensor(np.random.randn(1,), dtype=flow.float32)\n\n        t2 = flow.tensor(np.random.randn(1,), dtype=flow.float32)\n        t3 = flow.tensor(np.random.randn(1,), dtype=flow.float32)\n        lt0 = list()\n        lt0.append(t2)\n        lt0.append(t3)\n        t7 = flow.tensor(np.random.randn(1,), dtype=flow.float32)\n        dic2 = {\"kw2\": t7}\n        lt0.append(dic2)\n\n        t4 = flow.tensor(np.random.randn(1,), dtype=flow.float32)\n        t5 = flow.tensor(np.random.randn(1,), dtype=flow.float32)\n        t6 = flow.tensor(np.random.randn(1,), dtype=flow.float32)\n        lt1 = list()\n        lt1.append(t5)\n        lt1.append(t6)\n\n        ot, olt, on, odic = g(x, lt0, None, kw0=t4, kw1=lt1)\n        # print(g)\n        test_case.assertTrue(np.array_equal(x.numpy(), ot.numpy()))\n\n        test_case.assertTrue(isinstance(olt, list))\n        test_case.assertTrue(isinstance(olt[0], Tensor))\n        test_case.assertTrue(np.array_equal(olt[0].numpy(), lt0[0].numpy()))\n        test_case.assertTrue(isinstance(olt[1], Tensor))\n        test_case.assertTrue(np.array_equal(olt[1].numpy(), lt0[1].numpy()))\n        test_case.assertTrue(isinstance(olt[2], dict))\n        test_case.assertTrue(\n            np.array_equal(olt[2][\"kw2\"].numpy(), lt0[2][\"kw2\"].numpy())\n        )\n\n        test_case.assertTrue(on is None)\n        test_case.assertTrue(isinstance(odic, dict))\n       
 test_case.assertTrue(np.array_equal(odic[\"kw0\"].numpy(), t4.numpy()))\n        test_case.assertTrue(np.array_equal(odic[\"kw1\"][0].numpy(), t5.numpy()))\n        test_case.assertTrue(np.array_equal(odic[\"kw1\"][1].numpy(), t6.numpy()))\n\n    def test_graph_return_size_0_tuple(test_case):\n        def test_output(input, output_type):\n            print(input)\n            input = (input,)\n            print(input)\n\n            class CustomModule(flow.nn.Module):\n                def __init__(self):\n                    super().__init__()\n\n                def forward(self, t):\n                    return t[0]\n\n            class CustomGraphCheck1Ret(flow.nn.Graph):\n                def __init__(self):\n                    super().__init__()\n                    self.m = CustomModule()\n\n                def build(self, t):\n                    rt = self.m(t)\n                    return rt\n\n            model = CustomModule()\n            graph = CustomGraphCheck1Ret()\n\n            model_out = model(input)\n            graph_out = graph(input)\n\n            if output_type is None:\n                test_case.assertTrue(model_out is output_type)\n                test_case.assertTrue(graph_out is output_type)\n            else:\n                test_case.assertTrue(isinstance(model_out, output_type))\n                test_case.assertTrue(isinstance(graph_out, output_type))\n\n        x = np.ones((1, 10))\n        x = flow.tensor(x, dtype=flow.float32)\n\n        # test size 1 tuple\n        x_tuple = (x,)\n        test_output(x_tuple, tuple)\n\n        # test size 1 list\n        x_list = [\n            x,\n        ]\n        test_output(x_list, list)\n\n        # test tensor\n        test_output(x, Tensor)\n\n    def test_graph_return_dict_tuple(test_case):\n        def test_output(input):\n            print(input)\n\n            class CustomModule(flow.nn.Module):\n                def __init__(self):\n                    super().__init__()\n\n                def forward(self, t):\n                    return {\"output\": t}\n\n            class CustomGraphCheck1Ret(flow.nn.Graph):\n                def __init__(self):\n                    super().__init__()\n                    self.m = CustomModule()\n\n                def build(self, t):\n                    rt = self.m(t)\n                    return rt\n\n            model = CustomModule()\n            graph = CustomGraphCheck1Ret()\n\n            model_out = model(input)\n            graph_out = graph(input)\n\n            test_case.assertTrue(isinstance(model_out, dict))\n            test_case.assertTrue(isinstance(graph_out, dict))\n            test_case.assertEqual(len(model_out), 1)\n            test_case.assertEqual(len(graph_out), 1)\n            test_case.assertTrue(\"output\" in model_out)\n            test_case.assertTrue(\"output\" in graph_out)\n            test_case.assertTrue(\n                np.array_equal(model_out[\"output\"].numpy(), graph_out[\"output\"].numpy())\n            )\n\n        x = np.ones((1, 10))\n        x = flow.tensor(x, dtype=flow.float32)\n\n        # test tensor\n        test_output(x)\n\n    def test_graph_outputs_buffer(test_case):\n        class CustomModuleIOCheck(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n\n            def forward(self, t, tp, lt, n, i, s):\n                return t, tp, lt, n, i, s\n\n        class CustomGraphIOCheck1(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                
self.config.set_outputs_buffer_size(5)\n                self.m = CustomModuleIOCheck()\n\n            def build(self, t, tp, lt, n):\n                rt, rtp, rlt, n, ri, rs = self.m(t, tp, lt, n, 1, \"2\")\n                return t, tp, lt, n\n\n        g = CustomGraphIOCheck1()\n\n        x = np.ones((10, 10))\n        x = flow.tensor(x, dtype=flow.float32)\n\n        y = np.ones((10, 10))\n        y = flow.tensor(y, dtype=flow.float32)\n\n        # IO with TensorTuple cannot pass this test,\n        # because its tensor item's id is weird.\n        # t0 = np.ones((10, 10))\n        # t0 = flow.tensor(t0, dtype=flow.float32)\n        # t1 = np.ones((10, 10))\n        # t1 = flow.tensor(t1, dtype=flow.float32)\n        # tp0 = TensorTuple()\n        # tp0.append(t0)\n        # tp0.append(t1)\n\n        t2 = np.ones((10, 10))\n        t2 = flow.tensor(t2, dtype=flow.float32)\n        t3 = np.ones((10, 10))\n        t3 = flow.tensor(t3, dtype=flow.float32)\n        lt0 = list()\n        lt0.append(t2)\n        lt0.append(t3)\n\n        # Check there is no duplicated tensor in the outputs buffer and outputs.\n        out_id_dic = dict()\n        out_tensor_holder = dict()\n\n        def check_id_and_add(t, name):\n            if t is not None:\n                tid = id(t)\n                assert (\n                    tid not in out_id_dic\n                ), f\"tid {tid}, now name {name}, inserted name {out_id_dic[tid]}\"\n                test_case.assertTrue(tid not in out_id_dic)\n                out_id_dic[tid] = name\n                # It seems that a Python id may be re-used; hold it to avoid gc re-using it.\n                # ref: https://stackoverflow.com/questions/52096582/how-unique-is-pythons-id\n                out_tensor_holder[name] = t\n\n        def call_and_check(idx):\n            # ot, otp, olt, on = g(x, tp0, lt0, None)\n            ot, otp, olt, on = g(x, y, lt0, None)\n            if idx == 0:\n                test_case.assertEqual(len(g._outputs_tensor_tuple_buffer), 5)\n                for b_idx, buffer in enumerate(g._outputs_tensor_tuple_buffer):\n                    for i_idx, item in enumerate(buffer):\n                        check_id_and_add(\n                            item, \"buffer_\" + str(b_idx) + \"_\" + str(i_idx)\n                        )\n\n            test_case.assertTrue(np.array_equal(x.numpy(), ot.numpy()))\n            check_id_and_add(ot, \"ot_\" + str(idx))\n\n            # test_case.assertTrue(isinstance(otp, TensorTuple))\n            # check_id_and_add(otp, \"otp_\" + str(idx))\n            # test_case.assertTrue(isinstance(otp[0], Tensor))\n            # check_id_and_add(otp[0], \"otp0_\" + str(idx))\n            # test_case.assertTrue(np.array_equal(otp[0].numpy(), tp0[0].numpy()))\n            # test_case.assertTrue(isinstance(otp[1], Tensor))\n            # check_id_and_add(otp[1], \"otp1_\" + str(idx))\n            # test_case.assertTrue(np.array_equal(otp[1].numpy(), tp0[1].numpy()))\n\n            test_case.assertTrue(isinstance(otp, Tensor))\n            check_id_and_add(otp, \"otp_\" + str(idx))\n            test_case.assertTrue(np.array_equal(y.numpy(), otp.numpy()))\n\n            test_case.assertTrue(isinstance(olt, list))\n            check_id_and_add(olt, \"olt_\" + str(idx))\n            test_case.assertTrue(isinstance(olt[0], Tensor))\n            check_id_and_add(olt[0], \"olt0_\" + str(idx))\n            test_case.assertTrue(np.array_equal(olt[0].numpy(), lt0[0].numpy()))\n            check_id_and_add(olt[1], \"olt1_\" + str(idx))\n            test_case.assertTrue(np.array_equal(olt[1].numpy(), lt0[1].numpy()))\n\n            test_case.assertTrue(on is None)\n\n        for i in range(15):\n            call_and_check(i)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_linear.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_linear_graph(test_case, device):\n    linear = flow.nn.Linear(3, 8, False)\n    linear = linear.to(device)\n    input_arr = np.array(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n            [0.35217619, -0.67095644, -1.58943879],\n            [0.08086036, -1.81075924, 1.20752494],\n            [0.8901075, -0.49976737, -1.07153746],\n            [-0.44872912, -1.07275683, 0.06256855],\n            [-0.22556897, 0.74798368, 0.90416439],\n            [0.48339456, -2.32742195, -0.59321527],\n        ],\n        dtype=np.float32,\n    )\n    np_weight = np.ones((3, 8)).astype(np.float32)\n    np_weight.fill(2.3)\n    x = flow.tensor(input_arr, device=device)\n    flow.nn.init.constant_(linear.weight, 2.3)\n    of_eager_out = linear(x)\n    np_out = np.matmul(input_arr, np_weight)\n    test_case.assertTrue(np.allclose(of_eager_out.numpy(), np_out, 1e-05, 1e-05))\n\n    class LinearGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.my_linear = linear\n\n        def build(self, x):\n            return self.my_linear(x)\n\n    linear_g = LinearGraph()\n    linear_g.debug(0)\n    of_lazy_out = linear_g(x)\n    test_case.assertTrue(np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy()))\n\n\ndef _test_linear_graph_func(test_case, device):\n    linear = flow.nn.Linear(3, 8, False)\n    linear = linear.to(device)\n    input_arr = np.array(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n            [0.35217619, -0.67095644, -1.58943879],\n            [0.08086036, -1.81075924, 1.20752494],\n            [0.8901075, -0.49976737, -1.07153746],\n            [-0.44872912, -1.07275683, 0.06256855],\n            [-0.22556897, 0.74798368, 0.90416439],\n            [0.48339456, -2.32742195, -0.59321527],\n        ],\n        dtype=np.float32,\n    )\n    np_weight = np.ones((3, 8)).astype(np.float32)\n    np_weight.fill(2.3)\n    x = flow.tensor(input_arr, device=device)\n    flow.nn.init.constant_(linear.weight, 2.3)\n    of_eager_out = linear(x)\n    np_out = np.matmul(input_arr, np_weight)\n    test_case.assertTrue(np.allclose(of_eager_out.numpy(), np_out, 1e-05, 1e-05))\n\n    @flow.nn.Graph.trace\n    def linear_func(x):\n        return linear(x)\n\n    of_lazy_out = linear_func(x)\n    test_case.assertTrue(np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy()))\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestLinearGraph(oneflow.unittest.TestCase):\n    def test_linear_graph_gpu(test_case):\n        _test_linear_graph(test_case, flow.device(\"cuda\"))\n\n    def test_linear_graph_cpu(test_case):\n    
    _test_linear_graph(test_case, flow.device(\"cpu\"))\n\n    def test_linear_graph_func_gpu(test_case):\n        _test_linear_graph_func(test_case, flow.device(\"cuda\"))\n\n    def test_linear_graph_func_cpu(test_case):\n        _test_linear_graph_func(test_case, flow.device(\"cpu\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_linear_train.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport os\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_linear_train_graph(test_case, device):\n    def train_with_module(iter_num=3):\n        linear = flow.nn.Linear(3, 8)\n        linear = linear.to(device)\n        flow.nn.init.constant_(linear.weight, 2.068758)\n        flow.nn.init.constant_(linear.bias, 0.23)\n        of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9)\n\n        x = flow.tensor(\n            [\n                [-0.94630778, -0.83378579, -0.87060891],\n                [2.0289922, -0.28708987, -2.18369248],\n                [0.35217619, -0.67095644, -1.58943879],\n                [0.08086036, -1.81075924, 1.20752494],\n                [0.8901075, -0.49976737, -1.07153746],\n                [-0.44872912, -1.07275683, 0.06256855],\n                [-0.22556897, 0.74798368, 0.90416439],\n                [0.48339456, -2.32742195, -0.59321527],\n            ],\n            dtype=flow.float32,\n            device=device,\n            requires_grad=False,\n        )\n\n        def one_iter():\n            of_out = linear(x)\n            of_out = of_out.sum()\n\n            of_out.backward()\n            of_sgd.step()\n            of_sgd.zero_grad()\n\n            return of_out.numpy(), linear.weight.numpy()\n\n        check_list = []\n        for i in range(iter_num):\n            check_list.append(one_iter())\n        return check_list\n\n    def train_with_graph(iter_num=3):\n        linear = flow.nn.Linear(3, 8)\n        linear = linear.to(device)\n        flow.nn.init.constant_(linear.weight, 2.068758)\n        flow.nn.init.constant_(linear.bias, 0.23)\n        of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9)\n\n        x = flow.tensor(\n            [\n                [-0.94630778, -0.83378579, -0.87060891],\n                [2.0289922, -0.28708987, -2.18369248],\n                [0.35217619, -0.67095644, -1.58943879],\n                [0.08086036, -1.81075924, 1.20752494],\n                [0.8901075, -0.49976737, -1.07153746],\n                [-0.44872912, -1.07275683, 0.06256855],\n                [-0.22556897, 0.74798368, 0.90416439],\n                [0.48339456, -2.32742195, -0.59321527],\n            ],\n            dtype=flow.float32,\n            device=device,\n            requires_grad=False,\n        )\n\n        class LinearTrainGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linear = linear\n                self.add_optimizer(of_sgd)\n\n            def build(self, x):\n                out = self.linear(x)\n                out = out.sum()\n                out.backward()\n                return out\n\n        linear_t_g = LinearTrainGraph()\n\n        def one_iter():\n            of_graph_out = linear_t_g(x)\n            print(linear_t_g.linear)\n            return (\n               
 of_graph_out.numpy(),\n                linear_t_g.linear.weight.to(flow.Tensor).numpy(),\n            )\n\n        check_list = []\n        for i in range(iter_num):\n            check_list.append(one_iter())\n        return check_list\n\n    iter_num = 3\n    module_check_list = train_with_module(iter_num)\n    graph_check_list = train_with_graph(iter_num)\n    for i in range(iter_num):\n        # check equal on loss\n        test_case.assertTrue(\n            np.array_equal(module_check_list[i][0], graph_check_list[i][0])\n        )\n        # check equal on weight\n        test_case.assertTrue(\n            np.array_equal(module_check_list[i][1], graph_check_list[i][1])\n        )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestLinearTrainGraph(oneflow.unittest.TestCase):\n    def test_linear_train_graph_gpu(test_case):\n        _test_linear_train_graph(test_case, flow.device(\"cuda\"))\n\n    def test_linear_train_graph_cpu(test_case):\n        _test_linear_train_graph(test_case, flow.device(\"cpu\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
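The test above exercises OneFlow's core static-graph training pattern: attach a module and an optimizer to a `flow.nn.Graph` subclass, trace `backward()` inside `build`, and call the graph instance like a function. A minimal runnable sketch of that pattern (the module, input shape, and hyperparameters are illustrative, not taken from the test):

```python
import oneflow as flow

linear = flow.nn.Linear(3, 8)
sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9)

class TrainGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.linear = linear     # module whose parameters will be updated
        self.add_optimizer(sgd)  # the optimizer step runs inside the compiled graph

    def build(self, x):
        loss = self.linear(x).sum()
        loss.backward()          # backward is traced into the graph, not run eagerly
        return loss

graph = TrainGraph()
x = flow.ones(8, 3)
loss = graph(x)  # the first call compiles the execution plan; later calls reuse it
print(loss.numpy())
```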
  {
    "path": "python/oneflow/test/graph/test_graph_loss.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom collections import OrderedDict\nfrom test_util import GenArgList\n\n\nshapes = {2: (128, 8), 3: (16, 8, 64), 4: (16, 8, 32, 32), 5: (16, 8, 16, 16, 16)}\n\n\ndef compare_loss(device_type, dim, reduction, cls, data_generator):\n    x, y = data_generator(dim, device_type)\n    f = cls(reduction=reduction).to(device_type)\n    z_eager = f(x, y)\n\n    class CurrentGraph(flow.nn.Graph):\n        def __init__(self) -> None:\n            super().__init__()\n            self.f = f\n\n        def build(self, x, y):\n            return self.f(x, y)\n\n    f_g = CurrentGraph()\n    z_lazy = f_g(x, y)\n    assert np.allclose(z_eager.numpy(), z_lazy.numpy(), rtol=1.0e-5, atol=1.0e-5)\n\n\ndef generate_necessity_default(dim: int, device: str):\n    shape = shapes[dim]\n    x_np = np.random.uniform(0, 1, shape)\n    y_np = np.random.uniform(0, 1, shape)\n    x = flow.tensor(x_np, dtype=flow.float32, device=device)\n    y = flow.tensor(y_np, dtype=flow.float32, device=device)\n    return x, y\n\n\ndef generate_necessity_for_cross_entropy_or_nll_loss(dim: int, device: str):\n    shape = shapes[dim]\n    y_shape = (shape[0],) if dim == 2 else (shape[0], *shape[2:])\n    x_np = np.random.uniform(0, 1, shape)\n    y_np = np.random.randint(0, shape[1], y_shape)\n    x = flow.tensor(x_np, dtype=flow.float32, device=device)\n    y = flow.tensor(y_np, dtype=flow.int32, device=device)\n    return x, y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestKLDivLossGraph(oneflow.unittest.TestCase):\n    def test_kl_div_loss_graph(testcase):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_type\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"dim\"] = [2, 3, 4, 5]\n        arg_dict[\"reduction\"] = [\"sum\", \"mean\"]\n        arg_dict[\"cls\"] = [flow.nn.KLDivLoss]\n        arg_dict[\"data_generator\"] = [generate_necessity_default]\n        for arg in GenArgList(arg_dict):\n            compare_loss(*arg)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSmoothL1LossGraph(oneflow.unittest.TestCase):\n    def test_smooth_l1_loss_graph(testcase):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_type\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"dim\"] = [2, 3, 4, 5]\n        arg_dict[\"reduction\"] = [\"sum\", \"mean\"]\n        arg_dict[\"cls\"] = [flow.nn.SmoothL1Loss]\n        arg_dict[\"data_generator\"] = [generate_necessity_default]\n        for arg in GenArgList(arg_dict):\n            compare_loss(*arg)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBCELossOrWithLogitsGraph(flow.unittest.TestCase):\n    def test_bce_loss_graph(testcase):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_type\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"dim\"] = [2, 3, 4, 5]\n        arg_dict[\"reduction\"] = [\"sum\", \"mean\"]\n        arg_dict[\"cls\"] = [flow.nn.BCELoss, 
flow.nn.BCEWithLogitsLoss]\n        arg_dict[\"data_generator\"] = [generate_necessity_default]\n        for arg in GenArgList(arg_dict):\n            compare_loss(*arg)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCrossEntropyOrNllLossGraph(flow.unittest.TestCase):\n    def test_cross_entropy_loss_or_nll_loss_graph(testcase):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_type\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"dim\"] = [2, 3, 4, 5]\n        arg_dict[\"reduction\"] = [\"sum\", \"mean\"]\n        arg_dict[\"cls\"] = [flow.nn.CrossEntropyLoss, flow.nn.NLLLoss]\n        arg_dict[\"data_generator\"] = [generate_necessity_for_cross_entropy_or_nll_loss]\n        for arg in GenArgList(arg_dict):\n            compare_loss(*arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
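`compare_loss` above boils down to a simple recipe: evaluate a loss module eagerly, wrap the very same module instance in a `flow.nn.Graph`, and check that the two paths agree numerically. A condensed sketch (loss choice, shapes, and tolerances are illustrative):

```python
import numpy as np
import oneflow as flow

loss_fn = flow.nn.SmoothL1Loss(reduction="mean")
x = flow.randn(16, 8)
y = flow.randn(16, 8)

z_eager = loss_fn(x, y)  # eager execution

class LossGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.f = loss_fn  # reuse the same module instance

    def build(self, x, y):
        return self.f(x, y)

z_lazy = LossGraph()(x, y)  # lazy (compiled) execution of the same module
assert np.allclose(z_eager.numpy(), z_lazy.numpy(), rtol=1e-5, atol=1e-5)
```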
  {
    "path": "python/oneflow/test/graph/test_graph_lr_scale.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\nclass _Block(flow.nn.Module):\n    def __init__(self, feats, device=None, placement=None):\n        super().__init__()\n        ones = flow.ones(feats)\n        if placement is not None:\n            ones = ones.to_global(placement=placement, sbp=flow.sbp.broadcast())\n        elif device is not None:\n            ones = ones.to(device)\n        self.param = flow.nn.Parameter(ones)\n\n    def forward(self, x):\n        return x + self.param\n\n\nclass _MyModule(flow.nn.Module):\n    def __init__(self, feats, depth, device=None, placement=None):\n        super().__init__()\n        self.layers = flow.nn.ModuleList(\n            [\n                _Block(feats=feats, device=device, placement=placement)\n                for i in range(depth)\n            ]\n        )\n\n    def forward(self, x):\n        for layer in self.layers:\n            x = layer(x)\n        return x\n\n\nclass _MyGraph(flow.nn.Graph):\n    def __init__(self, model, optimizer, lr_scheduler):\n        super().__init__()\n        self.m = model\n        self.add_optimizer(optimizer, lr_sch=lr_scheduler)\n\n    def build(self, input):\n        out = self.m(input)\n        out.sum().backward()\n        return out\n\n\ndef _lrs_param_groups(model, base_scale):\n    param_groups = []\n    for i, layer in enumerate(model.layers):\n        this_scale = base_scale ** (i + 1)\n        param_group = {\"params\": layer.parameters(), \"lr_scale\": this_scale}\n        param_groups.append(param_group)\n\n    return param_groups\n\n\ndef _rand_input(shape, device=None, placement=None, requires_grad=False):\n    input = flow.tensor(np.random.rand(*shape).astype(np.float32))\n    if placement is not None:\n        input = input.to_global(placement=placement, sbp=flow.sbp.split(0))\n    elif device is not None:\n        input = input.to(device)\n    if requires_grad:\n        input.requires_grad_()\n    return input\n\n\ndef _test_lrs(test_case, **kwargs):\n    verbose = kwargs.pop(\"verbose\", False)\n    if verbose:\n        print(f\"#### kwargs={kwargs}\")\n\n    batch_size = kwargs.pop(\"batch_size\", 4)\n    feats = kwargs.pop(\"feats\", 768)\n    depth = kwargs.pop(\"depth\", 3)\n    lr = kwargs.pop(\"lr\", 1.0)\n    base_scale = kwargs.pop(\"base_scale\", 0.1)\n    device_type = kwargs.pop(\"device_type\", \"cuda\")\n    placement = kwargs.pop(\"placement\", None)\n    graph_mode = kwargs.pop(\"graph_mode\", True)\n\n    model = _MyModule(feats=feats, depth=depth, device=device_type, placement=placement)\n    param_groups = _lrs_param_groups(model, base_scale=base_scale)\n    optimizer = flow.optim.SGD(param_groups, lr=lr)\n    lr_scheduler = flow.optim.lr_scheduler.ConstantLR(\n        optimizer, factor=1.0, total_iters=100\n    
)\n    model_graph = _MyGraph(model, optimizer, lr_scheduler)\n\n    input = _rand_input(\n        (batch_size, feats), device=device_type, placement=placement, requires_grad=True\n    )\n    t_params = []\n\n    if graph_mode:\n        for i in range(depth):\n            origin_p = model.layers[i].param.numpy()\n            init_grad = float(batch_size * flow.env.get_world_size())\n            t_params.append(origin_p - float(init_grad) * lr * (base_scale ** (i + 1)))\n        ret = model_graph(input)\n    else:\n        for i in range(depth):\n            origin_p = model.layers[i].param.numpy()\n            init_grad = float(batch_size * flow.env.get_world_size())\n            t_params.append(origin_p - float(init_grad) * lr)\n\n        optimizer.zero_grad()\n        ret = model(input)\n        ret.sum().backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n    if verbose:\n        print(\"#### input\")\n        print(input)\n        # sync\n        np_ret = ret.numpy()\n        print(\"#### ret\")\n        print(np_ret)\n\n        for i in range(depth):\n            np_param = model.layers[i].param.numpy()\n            print(f\"#### layer{i} param\")\n            print(np_param)\n\n        print(\"#### grad\")\n        print(input.grad)\n\n    for i in range(depth):\n        np_param = model.layers[i].param.numpy()\n        t_param = t_params[i]\n        test_case.assertTrue(\n            np.allclose(np_param, t_param), f\"\\n{np_param}\\n vs. \\n{t_param}\"\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass LRScaleTest(flow.unittest.TestCase):\n    def test_lr_scale(self):\n        arg_dict = OrderedDict()\n        arg_dict[\"batch_size\"] = [2, 4]\n        arg_dict[\"feats\"] = [10, 13]\n        arg_dict[\"depth\"] = [3, 4]\n        arg_dict[\"lr\"] = [1.0, 0.1]\n        arg_dict[\"base_scale\"] = [0.1, 0.2]\n        arg_dict[\"device_type\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"is_global\"] = [True, False]\n        arg_dict[\"graph_mode\"] = [True, False]\n\n        for arg in GenArgDict(arg_dict):\n            is_global = arg.pop(\"is_global\", True)\n            if is_global:\n                device_type = arg.pop(\"device_type\", \"cuda\")\n                arg[\"placement\"] = flow.placement.all(device_type)\n\n            # arg[\"verbose\"] = True\n            _test_lrs(self, **arg)\n\n\n@flow.unittest.skip_unless_1n2d()\nclass LRScaleParallelTest(flow.unittest.TestCase):\n    def test_lr_scale_parallel(self):\n        arg_dict = OrderedDict()\n        arg_dict[\"batch_size\"] = [2, 4]\n        arg_dict[\"feats\"] = [5, 10]\n        arg_dict[\"depth\"] = [3, 4]\n        arg_dict[\"lr\"] = [1.0, 0.1]\n        arg_dict[\"base_scale\"] = [0.1, 0.2]\n        arg_dict[\"device_type\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"graph_mode\"] = [True, False]\n\n        for arg in GenArgDict(arg_dict):\n            device_type = arg.pop(\"device_type\", \"cuda\")\n            arg[\"placement\"] = flow.placement.all(device_type)\n            # arg[\"verbose\"] = True\n            _test_lrs(self, **arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
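The key construct in the test above is an SGD parameter-group dict carrying an extra `"lr_scale"` entry, so layer `i` trains at `lr * base_scale ** (i + 1)`; the assertions show the scale taking effect when the optimizer is attached to a graph. A sketch of just the group construction (the two-layer model here is hypothetical):

```python
import oneflow as flow

layers = flow.nn.ModuleList([flow.nn.Linear(4, 4) for _ in range(2)])

base_scale = 0.1
param_groups = [
    # layer i gets an effective learning rate of lr * base_scale ** (i + 1)
    {"params": layer.parameters(), "lr_scale": base_scale ** (i + 1)}
    for i, layer in enumerate(layers)
]
optimizer = flow.optim.SGD(param_groups, lr=1.0)
```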
  {
    "path": "python/oneflow/test/graph/test_graph_lr_scheduler.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport os\nimport numpy as np\nimport glob\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass MyModule(flow.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.param = flow.nn.Parameter(flow.ones(3, 4))\n\n    def forward(self, input):\n        return self.param + input\n\n\nclass MyGraph(flow.nn.Graph):\n    def __init__(self, module, optimizer, lr_scheduler):\n        super().__init__()\n        self.m = module\n        self.add_optimizer(optimizer, lr_sch=lr_scheduler)\n\n    def build(self, input):\n        out = self.m(input)\n        out.mean().backward()\n        return out\n\n\ndef _rand_input():\n    return flow.Tensor(np.random.rand(3, 4).astype(np.float32))\n\n\ndef _get_graph_lrs_from_log(log_path):\n    lines = []\n    with open(log_path, \"rt\") as f:\n        for line in f:\n            lines.append(line.strip())\n\n    lines = lines[1:]\n    lrs = []\n    for i, line in enumerate(lines):\n        step, lr = line.split(\",\")\n        assert int(step) == i\n        lrs.append(float(lr))\n\n    return lrs\n\n\nclass _DebugMode(object):\n    def __enter__(self):\n        os.environ[\"ONEFLOW_DEBUG_MODE\"] = \"True\"\n\n    def __exit__(self, type, value, traceback):\n        del os.environ[\"ONEFLOW_DEBUG_MODE\"]\n\n\ndef _compare_graph_lr_scheduler_with_eager(test_case, **kwargs):\n    lr_scheduler_class = kwargs.pop(\"lr_scheduler\", None)\n    base_lr = kwargs.pop(\"base_lr\", None)\n    iters = kwargs.pop(\"iters\", None)\n    rtol = kwargs.pop(\"rtol\", 1e-05)\n    atol = kwargs.pop(\"atol\", 1e-08)\n\n    if \"warmup_method\" in kwargs:\n        warmup_method = kwargs.pop(\"warmup_method\", \"linear\")\n        warmup_iters = kwargs.pop(\"warmup_iters\", 5)\n        warmup_factor = kwargs.pop(\"warmup_factor\", 0.1)\n        warmup_prefix = kwargs.pop(\"warmup_prefix\", False)\n        need_warmup = True\n    else:\n        need_warmup = False\n\n    assert base_lr is not None and iters is not None\n\n    module = MyModule()\n    optimizer = flow.optim.SGD([module.param], lr=base_lr)\n    lr_scheduler = (\n        lr_scheduler_class(optimizer, **kwargs) if lr_scheduler_class else None\n    )\n\n    if need_warmup:\n        lr_scheduler = flow.optim.lr_scheduler.WarmupLR(\n            lr_scheduler or optimizer,\n            warmup_factor=warmup_factor,\n            warmup_iters=warmup_iters,\n            warmup_method=warmup_method,\n            warmup_prefix=warmup_prefix,\n        )\n\n    graph = MyGraph(module, optimizer, lr_scheduler)\n\n    with _DebugMode():\n        for _ in range(iters + 1):\n            ret = graph(_rand_input())\n            ret.numpy()  # sync for graph finishing\n\n    pid = os.getpid()\n    lr_log_file = glob.glob(f\"log/*/{pid}-train_step2lr.csv\")[0]\n    lrs = _get_graph_lrs_from_log(lr_log_file)\n    lrs = lrs[:iters]\n\n    optimizer.zero_grad(set_to_none=True)\n  
  eager_lrs = [lr_scheduler.get_last_lr()[0]]\n    for _ in range(iters):\n        ret = module(_rand_input())\n        ret.numpy()\n        optimizer.step()\n        lr_scheduler.step()\n        eager_lrs.append(lr_scheduler.get_last_lr()[0])\n\n    eager_lrs = eager_lrs[:iters]\n\n    test_case.assertTrue(\n        np.allclose(lrs, eager_lrs, rtol=rtol, atol=atol),\n        f\"\\ngraph_lrs: {lrs}\\nvs.\\neager_lrs: {eager_lrs}\",\n    )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestGraphLRSchedulerWithEager(flow.unittest.TestCase):\n    def test_constant_lr(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=10,\n            lr_scheduler=flow.optim.lr_scheduler.ConstantLR,\n            factor=0.1,\n            total_iters=10,\n        )\n\n    def test_linear_lr(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=20,\n            lr_scheduler=flow.optim.lr_scheduler.LinearLR,\n            start_factor=0.1,\n            end_factor=1.0,\n            total_iters=10,\n        )\n\n    def test_linear_lr_end_factor(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=20,\n            lr_scheduler=flow.optim.lr_scheduler.LinearLR,\n            start_factor=0.1,\n            end_factor=0.9,\n            total_iters=10,\n        )\n\n    def test_step_lr(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=30,\n            lr_scheduler=flow.optim.lr_scheduler.StepLR,\n            step_size=10,\n            gamma=0.1,\n        )\n\n    def test_multi_step_lr(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=20,\n            lr_scheduler=flow.optim.lr_scheduler.MultiStepLR,\n            milestones=[5, 15],\n            gamma=0.2,\n        )\n\n    def test_polynomial_lr(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=20,\n            lr_scheduler=flow.optim.lr_scheduler.PolynomialLR,\n            decay_batch=20,\n            end_learning_rate=1e-5,\n            power=2.0,\n            atol=1e-5,\n        )\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.01,\n            iters=20,\n            lr_scheduler=flow.optim.lr_scheduler.PolynomialLR,\n            decay_batch=20,\n            end_learning_rate=1e-4,\n            power=1.0,\n            cycle=True,\n        )\n\n    def test_exponential_lr(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=10,\n            lr_scheduler=flow.optim.lr_scheduler.ExponentialLR,\n            gamma=0.5,\n            atol=1e-5,\n        )\n\n    def test_cosine_decay_lr(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=20,\n            lr_scheduler=flow.optim.lr_scheduler.CosineDecayLR,\n            decay_steps=10,\n            alpha=1e-3,\n            atol=1e-5,\n        )\n\n    def test_cosine_annealing_lr(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=20,\n            
lr_scheduler=flow.optim.lr_scheduler.CosineAnnealingLR,\n            T_max=10,\n            eta_min=1e-4,\n            atol=1e-5,\n        )\n\n    def test_linear_warmup_cosine_annealing_lr(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=20,\n            lr_scheduler=flow.optim.lr_scheduler.CosineAnnealingLR,\n            T_max=20,\n            eta_min=1e-5,\n            warmup_method=\"linear\",\n            warmup_factor=0.1,\n            warmup_iters=5,\n            warmup_prefix=False,\n            atol=1e-5,\n        )\n\n    def test_linear_warmup_prefix_cosine_annealing_lr(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=20,\n            lr_scheduler=flow.optim.lr_scheduler.CosineAnnealingLR,\n            T_max=20,\n            eta_min=1e-5,\n            warmup_method=\"linear\",\n            warmup_factor=0.1,\n            warmup_iters=5,\n            warmup_prefix=True,\n            atol=1e-5,\n        )\n\n    def test_linear_warmup_multistep_lr(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=20,\n            lr_scheduler=flow.optim.lr_scheduler.MultiStepLR,\n            milestones=[10, 15],\n            gamma=0.1,\n            warmup_method=\"linear\",\n            warmup_factor=0.1,\n            warmup_iters=5,\n        )\n\n    def test_constant_warmup_cosine_decay_lr(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=20,\n            lr_scheduler=flow.optim.lr_scheduler.CosineDecayLR,\n            decay_steps=20,\n            alpha=1e-3,\n            warmup_method=\"constant\",\n            warmup_factor=0.1,\n            warmup_iters=5,\n            atol=1e-5,\n        )\n\n    def test_constant_warmup_prefix_cosine_decay_lr(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=20,\n            lr_scheduler=flow.optim.lr_scheduler.CosineDecayLR,\n            decay_steps=20,\n            alpha=1e-3,\n            warmup_method=\"constant\",\n            warmup_factor=0.1,\n            warmup_iters=5,\n            warmup_prefix=True,\n            atol=1e-5,\n        )\n\n    def test_only_warmup(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=10,\n            lr_scheduler=None,\n            warmup_method=\"linear\",\n            warmup_factor=0.1,\n            warmup_iters=5,\n        )\n\n    def test_warmup_iters_equal_to_zero(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=10,\n            lr_scheduler=flow.optim.lr_scheduler.StepLR,\n            step_size=3,\n            gamma=0.5,\n            warmup_method=\"linear\",\n            warmup_iters=0,\n        )\n\n    def test_cosine_annealing_warm_restarts(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=50,\n            lr_scheduler=flow.optim.lr_scheduler.CosineAnnealingWarmRestarts,\n            T_0=10,\n            T_mult=1,\n            eta_min=0.01,\n            atol=1e-5,\n        )\n\n    def test_cosine_annealing_warm_restarts_mult_2(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=70,\n     
       lr_scheduler=flow.optim.lr_scheduler.CosineAnnealingWarmRestarts,\n            T_0=10,\n            T_mult=2,\n            eta_min=0.01,\n            atol=1e-5,\n        )\n\n    def test_cosine_annealing_warm_restarts_limit(self):\n        _compare_graph_lr_scheduler_with_eager(\n            self,\n            base_lr=0.1,\n            iters=50,\n            lr_scheduler=flow.optim.lr_scheduler.CosineAnnealingWarmRestarts,\n            T_0=10,\n            T_mult=2,\n            eta_min=0.01,\n            decay_rate=0.5,\n            restart_limit=2,\n            atol=1e-5,\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
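All of the warmup cases above follow one composition rule: build the base scheduler first, then wrap either it or the bare optimizer in `flow.optim.lr_scheduler.WarmupLR`. A sketch of that wiring (the concrete values are illustrative; no backward pass is needed just to observe the schedule):

```python
import oneflow as flow

param = flow.nn.Parameter(flow.ones(3, 4))
optimizer = flow.optim.SGD([param], lr=0.1)

base = flow.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-5)
scheduler = flow.optim.lr_scheduler.WarmupLR(
    base,                 # or pass `optimizer` directly for a warmup-only schedule
    warmup_factor=0.1,    # start at 10% of the base lr
    warmup_iters=5,
    warmup_method="linear",
    warmup_prefix=False,  # True makes the base schedule start after warmup ends
)

for _ in range(10):
    optimizer.step()      # no gradients here, so this is a no-op update
    scheduler.step()
    print(scheduler.get_last_lr()[0])
```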
  {
    "path": "python/oneflow/test/graph/test_graph_lr_with_warmup.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport math\nimport unittest\nimport os\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.nn.parameter import Parameter\n\n\ndef _test_linear_graph_train_with_lr_sch(\n    test_case, iter_num, device, get_opt_and_lr_sch\n):\n    def train_with_module(iter_num=3):\n        linear = flow.nn.Linear(3, 8)\n        linear = linear.to(device)\n        flow.nn.init.constant_(linear.weight, -0.68758)\n        flow.nn.init.constant_(linear.bias, 0.23)\n\n        opt, lr_sch = get_opt_and_lr_sch(linear.parameters())\n\n        x = flow.tensor(\n            [\n                [-0.94630778, -0.83378579, -0.87060891],\n                [2.0289922, -0.28708987, -2.18369248],\n                [0.35217619, -0.67095644, -1.58943879],\n                [0.08086036, -1.81075924, 1.20752494],\n                [0.8901075, -0.49976737, -1.07153746],\n                [-0.44872912, -1.07275683, 0.06256855],\n                [-0.22556897, 0.74798368, 0.90416439],\n                [0.48339456, -2.32742195, -0.59321527],\n            ],\n            dtype=flow.float32,\n            device=device,\n            requires_grad=False,\n        )\n\n        def one_iter():\n            of_out = linear(x)\n            of_out = of_out.sum()\n\n            of_out.backward()\n            opt.step()\n            if lr_sch is not None:\n                lr_sch.step()\n            opt.zero_grad()\n\n            return of_out.numpy(), linear.weight.numpy()\n\n        check_list = []\n        for i in range(iter_num):\n            check_list.append(one_iter())\n        return check_list\n\n    def train_with_graph(iter_num=3):\n        linear = flow.nn.Linear(3, 8)\n        linear = linear.to(device)\n        flow.nn.init.constant_(linear.weight, -0.68758)\n        flow.nn.init.constant_(linear.bias, 0.23)\n\n        opt, lr_sch = get_opt_and_lr_sch(linear.parameters())\n\n        x = flow.tensor(\n            [\n                [-0.94630778, -0.83378579, -0.87060891],\n                [2.0289922, -0.28708987, -2.18369248],\n                [0.35217619, -0.67095644, -1.58943879],\n                [0.08086036, -1.81075924, 1.20752494],\n                [0.8901075, -0.49976737, -1.07153746],\n                [-0.44872912, -1.07275683, 0.06256855],\n                [-0.22556897, 0.74798368, 0.90416439],\n                [0.48339456, -2.32742195, -0.59321527],\n            ],\n            dtype=flow.float32,\n            device=device,\n            requires_grad=False,\n        )\n\n        class LinearTrainGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linear = linear\n                if lr_sch is None:\n                    self.add_optimizer(opt)\n                else:\n                    self.add_optimizer(opt, lr_sch=lr_sch)\n\n            def build(self, x):\n                out = self.linear(x)\n                
out = out.sum()\n                out.backward()\n                return out\n\n        linear_t_g = LinearTrainGraph()\n\n        def one_iter():\n            of_graph_out = linear_t_g(x)\n            return (\n                of_graph_out.numpy(),\n                linear_t_g.linear.weight.to(flow.Tensor).numpy(),\n            )\n\n        check_list = []\n        for i in range(iter_num):\n            check_list.append(one_iter())\n        return check_list\n\n    module_check_list = train_with_module(iter_num)\n    graph_check_list = train_with_graph(iter_num)\n    for i in range(iter_num):\n        # check equal on loss\n        test_case.assertTrue(\n            np.allclose(\n                module_check_list[i][0],\n                graph_check_list[i][0],\n                rtol=0.00001,\n                atol=0.00001,\n            )\n        )\n        # check equal on weight\n        test_case.assertTrue(\n            np.allclose(\n                module_check_list[i][1],\n                graph_check_list[i][1],\n                rtol=0.00001,\n                atol=0.00001,\n            )\n        )\n\n\ndef _sgd_cosine_fn(parameters):\n    of_sgd = flow.optim.SGD(parameters, lr=0.001)\n    alpha = 0.5\n    decay_steps = 10\n    cosine_decay_lr = flow.optim.lr_scheduler.CosineDecayLR(\n        of_sgd, decay_steps=decay_steps, alpha=alpha\n    )\n    return of_sgd, cosine_decay_lr\n\n\ndef _sgd_cosine_constant_fn(parameters):\n    of_sgd = flow.optim.SGD(parameters, lr=0.001)\n    alpha = 0.5\n    decay_steps = 10\n    cosine_decay_lr = flow.optim.lr_scheduler.CosineDecayLR(\n        of_sgd, decay_steps=decay_steps, alpha=alpha\n    )\n    constant_warmup_cosine_lr = flow.optim.lr_scheduler.WarmUpLR(\n        cosine_decay_lr, warmup_factor=0.5, warmup_iters=5, warmup_method=\"constant\"\n    )\n    return of_sgd, constant_warmup_cosine_lr\n\n\ndef _sgd_constant_fn(parameters):\n    of_sgd = flow.optim.SGD(parameters, lr=0.001)\n    alpha = 0.5\n    steps = 10\n    constant_warmup_lr = flow.optim.lr_scheduler.WarmUpLR(\n        of_sgd, warmup_factor=0.5, warmup_iters=5, warmup_method=\"constant\"\n    )\n    return of_sgd, constant_warmup_lr\n\n\ndef _sgd_cosine_linear_fn(parameters):\n    of_sgd = flow.optim.SGD(parameters, lr=0.001)\n    alpha = 0.5\n    decay_steps = 10\n    cosine_decay_lr = flow.optim.lr_scheduler.CosineDecayLR(\n        of_sgd, decay_steps=decay_steps, alpha=alpha\n    )\n    linear_warmup_cosine_lr = flow.optim.lr_scheduler.WarmUpLR(\n        cosine_decay_lr, warmup_factor=0.5, warmup_iters=5, warmup_method=\"linear\"\n    )\n    return of_sgd, linear_warmup_cosine_lr\n\n\ndef _sgd_linear_fn(parameters):\n    of_sgd = flow.optim.SGD(parameters, lr=0.001)\n    alpha = 0.5\n    steps = 10\n    linear_warmup_lr = flow.optim.lr_scheduler.WarmUpLR(\n        of_sgd, warmup_factor=0.5, warmup_iters=5, warmup_method=\"linear\"\n    )\n    return of_sgd, linear_warmup_lr\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestLinearGraphTrainWithCosineLrScheduler(flow.unittest.TestCase):\n    def test_graph_cosine(test_case):\n        _test_linear_graph_train_with_lr_sch(\n            test_case, 21, flow.device(\"cuda\"), _sgd_cosine_fn\n        )\n        _test_linear_graph_train_with_lr_sch(\n            test_case, 21, flow.device(\"cpu\"), _sgd_cosine_fn\n        )\n\n    def test_graph_cosine_constant(test_case):\n        _test_linear_graph_train_with_lr_sch(\n            test_case, 21, 
flow.device(\"cuda\"), _sgd_cosine_constant_fn\n        )\n        _test_linear_graph_train_with_lr_sch(\n            test_case, 21, flow.device(\"cpu\"), _sgd_cosine_constant_fn\n        )\n\n    def test_graph_constant(test_case):\n        _test_linear_graph_train_with_lr_sch(\n            test_case, 21, flow.device(\"cuda\"), _sgd_constant_fn\n        )\n        _test_linear_graph_train_with_lr_sch(\n            test_case, 21, flow.device(\"cpu\"), _sgd_constant_fn\n        )\n\n    def test_graph_cosine_linear(test_case):\n        _test_linear_graph_train_with_lr_sch(\n            test_case, 21, flow.device(\"cuda\"), _sgd_cosine_linear_fn\n        )\n        _test_linear_graph_train_with_lr_sch(\n            test_case, 21, flow.device(\"cpu\"), _sgd_cosine_linear_fn\n        )\n\n    def test_graph_linear(test_case):\n        _test_linear_graph_train_with_lr_sch(\n            test_case, 21, flow.device(\"cuda\"), _sgd_linear_fn\n        )\n        _test_linear_graph_train_with_lr_sch(\n            test_case, 21, flow.device(\"cpu\"), _sgd_linear_fn\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_lrs.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport math\nimport unittest\nimport os\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.nn.parameter import Parameter\n\n\ndef _test_linear_graph_train_with_lr_sch(\n    test_case, iter_num, device, get_opt_and_lr_sch\n):\n    def train_with_module(iter_num=3):\n        linear = flow.nn.Linear(3, 8)\n        linear = linear.to(device)\n        flow.nn.init.constant_(linear.weight, -0.68758)\n        flow.nn.init.constant_(linear.bias, 0.23)\n\n        opt, lr_sch = get_opt_and_lr_sch(linear.parameters())\n\n        x = flow.tensor(\n            [\n                [-0.94630778, -0.83378579, -0.87060891],\n                [2.0289922, -0.28708987, -2.18369248],\n                [0.35217619, -0.67095644, -1.58943879],\n                [0.08086036, -1.81075924, 1.20752494],\n                [0.8901075, -0.49976737, -1.07153746],\n                [-0.44872912, -1.07275683, 0.06256855],\n                [-0.22556897, 0.74798368, 0.90416439],\n                [0.48339456, -2.32742195, -0.59321527],\n            ],\n            dtype=flow.float32,\n            device=device,\n            requires_grad=False,\n        )\n\n        def one_iter():\n            of_out = linear(x)\n            of_out = of_out.sum()\n\n            of_out.backward()\n            opt.step()\n            if lr_sch is not None:\n                lr_sch.step()\n            opt.zero_grad()\n\n            return of_out.numpy(), linear.weight.numpy()\n\n        check_list = []\n        for i in range(iter_num):\n            check_list.append(one_iter())\n        return check_list\n\n    def train_with_graph(iter_num=3):\n        linear = flow.nn.Linear(3, 8)\n        linear = linear.to(device)\n        flow.nn.init.constant_(linear.weight, -0.68758)\n        flow.nn.init.constant_(linear.bias, 0.23)\n\n        opt, lr_sch = get_opt_and_lr_sch(linear.parameters())\n\n        x = flow.tensor(\n            [\n                [-0.94630778, -0.83378579, -0.87060891],\n                [2.0289922, -0.28708987, -2.18369248],\n                [0.35217619, -0.67095644, -1.58943879],\n                [0.08086036, -1.81075924, 1.20752494],\n                [0.8901075, -0.49976737, -1.07153746],\n                [-0.44872912, -1.07275683, 0.06256855],\n                [-0.22556897, 0.74798368, 0.90416439],\n                [0.48339456, -2.32742195, -0.59321527],\n            ],\n            dtype=flow.float32,\n            device=device,\n            requires_grad=False,\n        )\n\n        class LinearTrainGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linear = linear\n                if lr_sch is None:\n                    self.add_optimizer(opt)\n                else:\n                    self.add_optimizer(opt, lr_sch=lr_sch)\n\n            def build(self, x):\n                out = self.linear(x)\n                
out = out.sum()\n                out.backward()\n                return out\n\n        linear_t_g = LinearTrainGraph()\n\n        def one_iter():\n            of_graph_out = linear_t_g(x)\n            return (\n                of_graph_out.numpy(),\n                linear_t_g.linear.weight.to(flow.Tensor).numpy(),\n            )\n\n        check_list = []\n        for i in range(iter_num):\n            check_list.append(one_iter())\n        return check_list\n\n    module_check_list = train_with_module(iter_num)\n    graph_check_list = train_with_graph(iter_num)\n    for i in range(iter_num):\n        # check equal on loss\n        test_case.assertTrue(\n            np.allclose(\n                module_check_list[i][0],\n                graph_check_list[i][0],\n                rtol=0.00001,\n                atol=0.00001,\n            )\n        )\n        # check equal on weight\n        test_case.assertTrue(\n            np.allclose(\n                module_check_list[i][1],\n                graph_check_list[i][1],\n                rtol=0.00001,\n                atol=0.00001,\n            )\n        )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestGraphLRs(flow.unittest.TestCase):\n    def test_step_lr(test_case):\n        def _lr_fn(parameters):\n            of_sgd = flow.optim.SGD(parameters, lr=0.001)\n\n            step_lr = flow.optim.lr_scheduler.StepLR(of_sgd, step_size=7, gamma=0.1)\n            return of_sgd, step_lr\n\n        _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device(\"cuda\"), _lr_fn)\n        _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device(\"cpu\"), _lr_fn)\n\n    def test_multistep_lr(test_case):\n        def _lr_fn(parameters):\n            of_sgd = flow.optim.SGD(parameters, lr=0.001)\n\n            multistep_lr = flow.optim.lr_scheduler.MultiStepLR(\n                of_sgd, milestones=[10, 15], gamma=0.1\n            )\n            return of_sgd, multistep_lr\n\n        _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device(\"cuda\"), _lr_fn)\n        _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device(\"cpu\"), _lr_fn)\n\n    @unittest.skip(\"skip for now, because it failed 6 times in the past week\")\n    def test_cosine_annealing_lr(test_case):\n        def _lr_fn(parameters):\n            of_sgd = flow.optim.SGD(parameters, lr=0.001)\n\n            lr = flow.optim.lr_scheduler.CosineAnnealingLR(\n                of_sgd, T_max=5, eta_min=0.0001\n            )\n            return of_sgd, lr\n\n        _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device(\"cuda\"), _lr_fn)\n        _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device(\"cpu\"), _lr_fn)\n\n    def test_polynomial_lr(test_case):\n        def _lr_fn(parameters):\n            of_sgd = flow.optim.SGD(parameters, lr=0.001)\n\n            lr = flow.optim.lr_scheduler.PolynomialLR(\n                of_sgd, decay_batch=10, end_learning_rate=0.00001, power=2, cycle=True\n            )\n            return of_sgd, lr\n\n        _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device(\"cuda\"), _lr_fn)\n\n        _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device(\"cpu\"), _lr_fn)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_masked_fill.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\nimport random\n\nimport oneflow as flow\nfrom oneflow import nn\nimport oneflow.unittest\nfrom test_util import generate_graph\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMaskedFillGraph(flow.unittest.TestCase):\n    def test_masked_fill_graph(test_case):\n        k = random.randint(1, 10)\n        model = nn.Sequential(nn.Linear(k, k))\n        optimizer = flow.optim.SGD(model.parameters(), lr=1e-3)\n        loss_fn = nn.MSELoss()\n\n        class MaskedFillGraph(flow.nn.Graph):\n            def __init__(self,):\n                super().__init__()\n                self.model = model\n                self.loss_fn = loss_fn\n                self.add_optimizer(optimizer)\n\n            def build(self, input, mask):\n                output = self.model(input)\n                output = flow.masked_fill(output, mask > 0.5, 0.5)\n                loss = self.loss_fn(output, input)\n                loss.backward()\n                return loss\n\n        input = flow.randn(k, k).requires_grad_()\n        mask = flow.randn(k, k)\n        model = MaskedFillGraph()\n        return model(input, mask)\n\n    def test_masked_fill_by_generate_graph(test_case):\n        k = random.randint(1, 10)\n        input = flow.randn(k, k)\n        mask = flow.randn(k, k)\n\n        masked_fill_fn = lambda: flow.masked_fill(input, mask > 0.5, 0.5)\n        y_eager = masked_fill_fn()\n        masked_fill_graph = generate_graph(masked_fill_fn)\n        y_lazy = masked_fill_graph()\n        test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy()))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
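`flow.masked_fill(x, mask, value)` returns a tensor equal to `x` except that entries where `mask` is true are replaced by `value`; the graph tests above only check that the op behaves identically when traced. A minimal eager sketch:

```python
import oneflow as flow

x = flow.randn(3, 3)
mask = flow.randn(3, 3) > 0.5       # boolean mask, true roughly 30% of the time
y = flow.masked_fill(x, mask, 0.5)  # masked entries become 0.5, the rest copy x
```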
  {
    "path": "python/oneflow/test/graph/test_graph_nccl_logical_fusion.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nfrom oneflow import nn\nimport os\nimport numpy as np\n\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGraphNcclLogicalFusion(flow.unittest.TestCase):\n    def test_graph_nccl_fusion_1d(test_case):\n        x_list = []\n        local_np = np.arange(4 * 8, dtype=float).reshape(4, 8)\n        P1d = flow.placement(\"cuda\", ranks=[0, 1, 2, 3])\n        B = flow.sbp.broadcast()\n        S0 = flow.sbp.split(0)\n        S1 = flow.sbp.split(1)\n        P = flow.sbp.partial_sum()\n\n        in_0 = (\n            flow.tensor(local_np / 4.0)\n            .to(flow.device(\"cuda\"))\n            .to_global(sbp=P, placement=P1d)\n        )\n\n        flow.boxing.nccl.enable_use_compute_stream(True)\n\n        class TestNcclFusion1DGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n\n            def build(self, x):\n                # fuse group 0:\n                x0 = x * 0.5\n                y0 = x0.to_global(sbp=B, placement=P1d)  # P->B\n\n                x1 = x * 1.0\n                y1 = x1.to_global(sbp=S0, placement=P1d)  # P->S0\n\n                x2 = x * 2.0\n                y2 = x2.to_global(sbp=S1, placement=P1d)  # P->S1\n\n                x3 = x * 3.0\n                y3 = x3.to_global(sbp=S1, placement=P1d)  # P->S1\n\n                x4 = x * 4.0\n                y4 = x4.to_global(sbp=S0, placement=P1d)  # P->S0\n\n                # fuse group 1:\n                x5 = y1 * 5.0\n                y5 = x5.to_global(sbp=B, placement=P1d)  # S0->B\n\n                x6 = y2 * (6.0 / 2.0)\n                y6 = x6.to_global(sbp=B, placement=P1d)  # S1->B\n\n                x7 = y3 * (9.0 / 3.0)\n                y7 = x7.to_global(sbp=S0, placement=P1d)  # S1->S0\n\n                x8 = y4 * (8.0 / 4.0)\n                y8 = x8.to_global(sbp=S1, placement=P1d)  # S0->S1\n\n                y = y0 + y1 + y2 + y3 + y4 + y5 + y6 + y7 + y8\n                return y, y0, y1, y2, y3, y4, y5, y6, y7, y8\n\n        graph = TestNcclFusion1DGraph()\n        out, out_0, out_1, out_2, out_3, out_4, out_5, out_6, out_7, out_8 = graph(in_0)\n        test_case.assertTrue(np.array_equal(out_0.numpy(), local_np * 0.5))\n        test_case.assertTrue(np.array_equal(out_1.numpy(), local_np * 1.0))\n        test_case.assertTrue(np.array_equal(out_2.numpy(), local_np * 2.0))\n        test_case.assertTrue(np.array_equal(out_3.numpy(), local_np * 3.0))\n        test_case.assertTrue(np.array_equal(out_4.numpy(), local_np * 4.0))\n        test_case.assertTrue(np.array_equal(out_5.numpy(), local_np * 5.0))\n        test_case.assertTrue(np.array_equal(out_6.numpy(), local_np * 6.0))\n        test_case.assertTrue(np.array_equal(out_7.numpy(), local_np * 9.0))\n        
test_case.assertTrue(np.array_equal(out_8.numpy(), local_np * 8.0))\n        flow.boxing.nccl.enable_use_compute_stream(False)\n\n    def test_graph_nccl_fusion_2d(test_case):\n        x_list = []\n        local_np = np.arange(4 * 8, dtype=float).reshape(4, 8)\n        P2d = flow.placement(\"cuda\", ranks=[[0, 1], [2, 3]])\n        B = flow.sbp.broadcast()\n        S0 = flow.sbp.split(0)\n        S1 = flow.sbp.split(1)\n        P = flow.sbp.partial_sum()\n\n        in_BP = (\n            flow.tensor(local_np / 2.0)\n            .to(flow.device(\"cuda\"))\n            .to_global(sbp=(B, P), placement=P2d)\n        )\n        in_PB = (\n            flow.tensor(local_np / 2.0)\n            .to(flow.device(\"cuda\"))\n            .to_global(sbp=(P, B), placement=P2d)\n        )\n        in_S0P = in_BP.to_global(sbp=(S0, P), placement=P2d)\n        in_PS0 = in_PB.to_global(sbp=(P, S0), placement=P2d)\n\n        flow.boxing.nccl.enable_use_compute_stream(True)\n\n        class TestNcclFusion2DGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n\n            def build(self, x, xsd1):\n                # fuse group 0:\n                x0 = x * 0.5\n                y0 = x0.to_global(sbp=(S0, B), placement=P2d)  # same dim0 P->B\n\n                x1 = x * 1.0\n                y1 = x1.to_global(sbp=(S0, B), placement=P2d)  # same dim0 P->B\n\n                xss0 = x.to_global(sbp=(S0, S0), placement=P2d)\n                xss1 = x.to_global(sbp=(S0, S1), placement=P2d)\n                x2 = xss0 * 2.0\n                y2 = x2.to_global(sbp=(S0, B), placement=P2d)  # same dim0 S0->B\n\n                x3 = xss1 * 3.0\n                y3 = x3.to_global(sbp=(S0, B), placement=P2d)  # same dim0 S1->B\n\n                x4 = xss0 * 4.0\n                y4 = x4.to_global(sbp=(S0, S1), placement=P2d)  # same dim0 S0->S1\n\n                x5 = xss1 * 5.0\n                y5 = x5.to_global(sbp=(S0, S0), placement=P2d)  # same dim0 S1->S0\n\n                x6 = xsd1 * 6.0\n                y6 = x6.to_global(sbp=(B, S0), placement=P2d)  # same dim1 P-> B\n\n                x7 = xsd1 * 7.0\n                y7 = x7.to_global(sbp=(B, S0), placement=P2d)  # same dim1 P-> B\n\n                y = y0 + y1 + y2 + y3 + y4 + y5 + y6 + y7\n                return y, y0, y1, y2, y3, y4, y5, y6, y7\n\n        graph = TestNcclFusion2DGraph()\n        out, out_0, out_1, out_2, out_3, out_4, out_5, out_6, out_7 = graph(\n            in_S0P, in_PS0\n        )\n        test_case.assertTrue(np.array_equal(out_0.numpy(), local_np * 0.5))\n        test_case.assertTrue(np.array_equal(out_1.numpy(), local_np * 1.0))\n        test_case.assertTrue(np.array_equal(out_2.numpy(), local_np * 2.0))\n        test_case.assertTrue(np.array_equal(out_3.numpy(), local_np * 3.0))\n        test_case.assertTrue(np.array_equal(out_4.numpy(), local_np * 4.0))\n        test_case.assertTrue(np.array_equal(out_5.numpy(), local_np * 5.0))\n        test_case.assertTrue(np.array_equal(out_6.numpy(), local_np * 6.0))\n        test_case.assertTrue(np.array_equal(out_7.numpy(), local_np * 7.0))\n        flow.boxing.nccl.enable_use_compute_stream(False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
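Every `to_global` in the fusion tests is a layout transform between SBP signatures (`broadcast`, `split`, `partial_sum`), and `flow.boxing.nccl.enable_use_compute_stream(True)` is what allows the runtime to group adjacent transforms into fused NCCL calls. A sketch of a single P->B transform on four ranks (this assumes a 4-device launch, matching the test's `skip_unless_1n4d` guard):

```python
import oneflow as flow

placement = flow.placement("cuda", ranks=[0, 1, 2, 3])
P = flow.sbp.partial_sum()
B = flow.sbp.broadcast()

flow.boxing.nccl.enable_use_compute_stream(True)  # permit fusion on the compute stream

# each rank holds a partial summand; P -> B is realized as one all-reduce
x = flow.ones(4, 8).to(flow.device("cuda")).to_global(sbp=P, placement=placement)
y = x.to_global(sbp=B, placement=placement)

flow.boxing.nccl.enable_use_compute_stream(False)
```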
  {
    "path": "python/oneflow/test/graph/test_graph_non_contiguous_tensors.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport oneflow.unittest\nimport oneflow as flow\nimport numpy as np\n\n\nclass ModuleTest(flow.nn.Module):\n    def __init__(self, contiguous: bool, device):\n        super().__init__()\n        if contiguous:\n            self.weight = flow.nn.Parameter(flow.ones(4, 3, device=device))\n        else:\n            self.weight = flow.nn.Parameter(\n                flow.ones(3, 4, device=device).transpose(0, 1)\n            )\n\n    def forward(self, input):\n        res = flow.matmul(input, self.weight)\n        return res\n\n\ndef _test_graph_non_contiguous_tensors(test_case, device):\n    bias = flow.tensor(\n        [[1, 2, 3], [3, 4, 5], [7, 7, 7],], dtype=flow.float32, device=device\n    )\n\n    free_eager_bias_contiguous = bias\n    free_eager_bias_non_contiguous = bias.transpose(0, 1).contiguous().transpose(0, 1)\n    test_case.assertTrue(free_eager_bias_contiguous.is_contiguous())\n    test_case.assertFalse(free_eager_bias_non_contiguous.is_contiguous())\n\n    class GraphTestContiguousTensors(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.model = ModuleTest(True, device)\n\n        def build(self, input):\n            res = self.model(input) + free_eager_bias_contiguous\n            return res\n\n    class GraphTestNonContiguousTensors(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.model = ModuleTest(False, device)\n\n        def build(self, input):\n            res = self.model(input) + free_eager_bias_non_contiguous\n            return res\n\n    graph_contiguous_tensors = GraphTestContiguousTensors()\n    graph_non_contiguous_tensors = GraphTestNonContiguousTensors()\n\n    test_case.assertTrue(\n        graph_contiguous_tensors.model.weight.to(flow.Tensor).is_contiguous()\n    )\n    test_case.assertFalse(\n        graph_non_contiguous_tensors.model.weight.to(flow.Tensor).is_contiguous()\n    )\n\n    inp = flow.tensor(\n        [[1, 2, 3], [4, 5, 6], [3, 3, 3], [7, 8, 8]], dtype=flow.float32, device=device\n    )\n\n    non_contiguous_input = inp.transpose(0, 1)\n    test_case.assertFalse(non_contiguous_input.is_contiguous())\n\n    contiguous_input = non_contiguous_input.contiguous()\n    test_case.assertTrue(contiguous_input.is_contiguous())\n\n    contiguous_graph_output = graph_contiguous_tensors(contiguous_input)\n    non_contiguous_graph_output = graph_non_contiguous_tensors(non_contiguous_input)\n    test_case.assertTrue(\n        np.array_equal(\n            contiguous_graph_output.numpy(), non_contiguous_graph_output.numpy()\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGraphNonContiguousTensor(oneflow.unittest.TestCase):\n    def test_graph_non_contiguous_tensors_cpu(test_case):\n        _test_graph_non_contiguous_tensors(test_case, flow.device(\"cpu\"))\n\n    
@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_graph_non_contiguous_tensors_gpu(test_case):\n        _test_graph_non_contiguous_tensors(test_case, flow.device(\"cuda\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
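The contiguity checks above rest on the usual stride rules: `transpose` returns a strided view without copying, so the view is non-contiguous, while `.contiguous()` materializes a packed row-major copy. A minimal sketch:

```python
import oneflow as flow

x = flow.ones(3, 4)
assert x.is_contiguous()

t = x.transpose(0, 1)  # a view with swapped strides, no data copy
assert not t.is_contiguous()

c = t.contiguous()     # packed row-major copy of the same values
assert c.is_contiguous()
```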
  {
    "path": "python/oneflow/test/graph/test_graph_normal_inplace.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport numpy as np\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\n_fn_param_local = {\n    \"normal\": lambda data: flow.normal(\n        size=data.shape, mean=0.0, std=1.0, out=data\n    ),  # NOTE(lixiang): source op that can be inplaced.\n}\n\n\n_fn_param_global = {\n    \"normal\": lambda data, placement, sbp: flow.normal(\n        size=data.shape, mean=0.0, std=1.0, out=data, placement=placement, sbp=sbp,\n    ),\n}\n\n\ndef _test_data_local(test_case, device, fn):\n\n    data_1 = flow.zeros([16, 64, 128, 128]).to(device)\n    data_2 = flow.zeros([16, 64, 128, 128]).to(device)\n\n    class NormalGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self):\n            fn(data_1).to(device)\n            return data_1\n\n    model = NormalGraph()\n    lazy_x = model()\n    fn(data_2)\n\n    test_case.assertTrue(lazy_x.numpy().sum() != 0)\n    test_case.assertTrue(data_2.numpy().sum() != 0)\n\n\ndef _test_data_global(test_case, data_1, data_2, placement, sbp, fn):\n    class GlobalNormalGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self):\n            flow.manual_seed(233)\n            fn(data_1, placement, sbp)\n            return data_1\n\n    model = GlobalNormalGraph()\n    lazy_x = model()\n\n    flow.manual_seed(233)\n    fn(data_2, placement, sbp)\n\n    test_case.assertTrue(\n        np.array_equal(lazy_x.to_local().numpy(), data_2.to_local().numpy())\n    )\n\n\nclass TestNormalOpInplaceData(flow.unittest.TestCase):\n    @oneflow.unittest.skip_unless_1n1d()\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_normal_op_data_local_with_eager_and_lazy(test_case):\n\n        for device in [\"cuda\", \"cpu\"]:\n            for _, fn in _fn_param_local.items():\n                _test_data_local(test_case, device, fn=fn)\n\n    @unittest.skipIf(True, \"refactor eager random to align pytorch\")\n    @globaltest\n    def test_normal_op_data_consistent_with_eager_and_lazy(test_case):\n\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2, except_partial_sum=True):\n\n                data_1 = flow.empty([8, 64, 128, 128]).to_global(\n                    placement=placement, sbp=sbp\n                )\n                data_2 = flow.empty([8, 64, 128, 128]).to_global(\n                    placement=placement, sbp=sbp\n                )\n\n                for _, fn in _fn_param_global.items():\n                    _test_data_global(test_case, data_1, data_2, placement, sbp, fn=fn)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
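The in-place behavior under test comes from `flow.normal` accepting an `out=` tensor: the sampler writes into an existing buffer instead of allocating a fresh result, which is what lets the graph version mutate a captured free tensor. A minimal local sketch:

```python
import oneflow as flow

buf = flow.zeros([2, 3])
flow.normal(size=buf.shape, mean=0.0, std=1.0, out=buf)  # fills buf in place
assert buf.numpy().sum() != 0.0  # almost surely nonzero after sampling
```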
  {
    "path": "python/oneflow/test/graph/test_graph_ofrecord_reader.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass OFRecordDataLoader(flow.nn.Module):\n    def __init__(self):\n        super().__init__()\n        batch_size = 4\n        self.train_record_reader = flow.nn.OFRecordReader(\n            flow.unittest.dataset_dir(\"imagenette/ofrecord\"),\n            batch_size=batch_size,\n            data_part_num=1,\n            part_name_suffix_length=5,\n            random_shuffle=True,\n            shuffle_after_epoch=True,\n            # placement=flow.placement(\"cpu\", ranks=[0]),\n            # sbp=[flow.sbp.broadcast]\n        )\n\n        self.record_label_decoder = flow.nn.OFRecordRawDecoder(\n            \"class/label\", shape=(), dtype=flow.int32\n        )\n\n        color_space = \"RGB\"\n        output_layout = \"NHWC\"\n        self.record_image_decoder = flow.nn.OFRecordImageDecoderRandomCrop(\n            \"encoded\", color_space=color_space\n        )\n\n        self.resize = flow.nn.image.Resize(target_size=[224, 224])\n\n        self.flip = flow.nn.CoinFlip(\n            batch_size=batch_size,\n            # placement=flow.placement(\"cpu\", ranks=[0]),\n            # sbp=[flow.sbp.broadcast]\n        )\n\n        rgb_mean = [123.68, 116.779, 103.939]\n        rgb_std = [58.393, 57.12, 57.375]\n        self.crop_mirror_norm = flow.nn.CropMirrorNormalize(\n            color_space=color_space,\n            output_layout=output_layout,\n            mean=rgb_mean,\n            std=rgb_std,\n            output_dtype=flow.float,\n        )\n\n    def forward(self) -> (flow.Tensor, flow.Tensor):\n        train_record = self.train_record_reader()\n        label = self.record_label_decoder(train_record)\n        image_raw_buffer = self.record_image_decoder(train_record)\n        image = self.resize(image_raw_buffer)[0]\n        rng = self.flip()\n        image = self.crop_mirror_norm(image, rng)\n        return image, label\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestOFRecordReaderGraph(oneflow.unittest.TestCase):\n    def test_ofrecord_reader_graph(test_case):\n        cc_reader = OFRecordDataLoader()\n\n        class GraphReader(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.my_reader = cc_reader\n\n            def build(self):\n                return self.my_reader()\n\n        reader_g = GraphReader()\n        image, label = reader_g()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_optim_adadelta.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nimport copy\n\nfrom test_util import GenArgList\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\n\n\ndef compare_with_numpy_adadelta(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    rho,\n    eps,\n    maximize,\n    weight_decay,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    class CustomModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.para0 = flow.nn.Parameter(\n                flow.Tensor(init_value, device=flow.device(device))\n            )\n\n        def forward(self, mask):\n            return self.para0 * mask\n\n    simp_module = CustomModule()\n    simp_module.to(device)\n    simp_module.train()\n\n    adadelta0 = flow.optim.Adadelta(\n        [\n            {\n                \"params\": simp_module.parameters(),\n                \"lr\": learning_rate,\n                \"weight_decay\": weight_decay,\n            }\n        ],\n        rho=rho,\n        eps=eps,\n        maximize=maximize,\n    )\n\n    class CustomAdadeltaGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(adadelta0)\n\n        def build(self, mask_tensor):\n            loss = flow.sum(self.m(mask_tensor))\n            loss.backward()\n            return loss\n\n    of_res_list = []\n    adadelta_graph = CustomAdadeltaGraph()\n\n    for i in range(train_iters):\n        mask_tensor = flow.tensor(\n            random_grad_seq[i],\n            dtype=flow.float32,\n            requires_grad=False,\n            device=flow.device(device),\n        )\n        adadelta_x = adadelta_graph(mask_tensor)\n\n        of_res_list.append(copy.copy(simp_module.para0.numpy()))\n\n    np_res_list = []\n\n    def train_by_numpy():\n        x = init_value\n        square_avgs = np.zeros_like(x)\n        acc_deltas = np.zeros_like(x)\n\n        def np_train_one_iter(grad):\n            grad = grad if not maximize else -grad\n            grad = grad + weight_decay * x\n            new_square_avgs = square_avgs * rho + (1.0 - rho) * grad * grad\n            std = np.sqrt(new_square_avgs + eps)\n            delta = np.sqrt(acc_deltas + eps) / std * grad\n            new_acc_deltas = acc_deltas * rho + delta * delta * (1 - rho)\n            param = x - learning_rate * delta\n            return (param, new_square_avgs, new_acc_deltas)\n\n        for i in range(1, train_iters + 1):\n            (x, square_avgs, acc_deltas) = np_train_one_iter(random_grad_seq[i - 1])\n            np_res_list.append(x)\n        return x\n\n    train_by_numpy()\n\n    
test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=1e-4, atol=1e-4))\n\n\ndef compare_with_numpy_adadelta_clip_grad(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    rho,\n    eps,\n    maximize,\n    weight_decay,\n    clip_grad_max_norm,\n    clip_grad_norm_type,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    class CustomModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.para0 = flow.nn.Parameter(\n                flow.tensor(init_value, device=flow.device(device))\n            )\n\n        def forward(self, mask):\n            return self.para0 * mask\n\n    simp_module = CustomModule()\n    simp_module.to(device)\n    simp_module.train()\n\n    adadelta0 = flow.optim.Adadelta(\n        [\n            {\n                \"params\": simp_module.parameters(),\n                \"lr\": learning_rate,\n                \"weight_decay\": weight_decay,\n                \"clip_grad_max_norm\": clip_grad_max_norm,\n                \"clip_grad_norm_type\": clip_grad_norm_type,\n            }\n        ],\n        rho=rho,\n        eps=eps,\n        maximize=maximize,\n    )\n\n    class CustomAdadeltaGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(adadelta0)\n\n        def build(self, mask_tensor):\n            loss = flow.sum(self.m(mask_tensor))\n            loss.backward()\n            return loss\n\n    of_res_list = []\n    adadelta_graph = CustomAdadeltaGraph()\n\n    for i in range(train_iters):\n        mask_tensor = flow.tensor(\n            random_grad_seq[i], requires_grad=False, device=flow.device(device)\n        )\n        adadelta_x = adadelta_graph(mask_tensor)\n\n        of_res_list.append(copy.copy(simp_module.para0.numpy()))\n\n    np_res_list = []\n\n    def train_by_numpy():\n        x = init_value\n        square_avgs = np.zeros_like(x)\n        acc_deltas = np.zeros_like(x)\n\n        def np_train_one_iter(grad):\n            total_norm, grad = clip_grad_norm_np(\n                grad, clip_grad_max_norm, clip_grad_norm_type\n            )\n            grad = grad if not maximize else -grad\n            grad = grad + weight_decay * x\n            new_square_avgs = square_avgs * rho + (1.0 - rho) * grad * grad\n            std = np.sqrt(new_square_avgs + eps)\n            delta = np.sqrt(acc_deltas + eps) / std * grad\n            new_acc_deltas = acc_deltas * rho + delta * delta * (1 - rho)\n            param = x - learning_rate * delta\n            return (param, new_square_avgs, new_acc_deltas)\n\n        for i in range(1, train_iters + 1):\n            (x, square_avgs, acc_deltas) = np_train_one_iter(random_grad_seq[i - 1])\n            np_res_list.append(x)\n        return x\n\n    train_by_numpy()\n    test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=1e-4, atol=1e-4))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAdadelta(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, because it failed 8 times in the past week\")\n    def test_adadelta(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"rho\"] = [0.9]\n        arg_dict[\"eps\"] = [1e-6]\n        arg_dict[\"maximize\"] = [False]\n        arg_dict[\"weight_decay\"] = [0.1]\n\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adadelta(test_case, *arg)\n\n    def test_adadelta_clip_grad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"rho\"] = [0.9]\n        arg_dict[\"eps\"] = [1e-6]\n        arg_dict[\"maximize\"] = [False]\n        arg_dict[\"weight_decay\"] = [0.1]\n        arg_dict[\"clip_grad_max_norm\"] = [1.0]\n        arg_dict[\"clip_grad_norm_type\"] = [2.0]\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adadelta_clip_grad(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_optim_adagrad.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nimport copy\n\nfrom test_util import GenArgList\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\n\n\ndef compare_with_numpy_adagrad(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    lr_decay,\n    weight_decay,\n    initial_accumulator_value,\n    eps,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    class CustomModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.para0 = flow.nn.Parameter(\n                flow.Tensor(init_value, device=flow.device(device))\n            )\n\n        def forward(self, mask):\n            return self.para0 * mask\n\n    simp_module = CustomModule()\n    simp_module.to(device)\n    simp_module.train()\n\n    adam0 = flow.optim.Adagrad(\n        [\n            {\n                \"params\": simp_module.parameters(),\n                \"lr\": learning_rate,\n                \"eps\": eps,\n                \"weight_decay\": weight_decay,\n            }\n        ],\n        lr_decay=lr_decay,\n        initial_accumulator_value=initial_accumulator_value,\n    )\n\n    class CustomAdagradGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(adam0)\n\n        def build(self, mask_tensor):\n            loss = flow.sum(self.m(mask_tensor))\n            loss.backward()\n            return loss\n\n    of_res_list = []\n    adagrad_graph = CustomAdagradGraph()\n\n    for i in range(train_iters):\n        mask_tensor = flow.tensor(\n            random_grad_seq[i], requires_grad=False, device=flow.device(device)\n        )\n        adagrad_x = adagrad_graph(mask_tensor)\n\n        of_res_list.append(copy.copy(simp_module.para0.numpy()))\n\n    np_res_list = []\n\n    def train_by_numpy():\n        x = init_value\n        st = np.ones_like(x) * initial_accumulator_value\n\n        def train_one_iter(iter, grad):\n            grad = grad + weight_decay * x\n            lr = learning_rate / (1 + (iter - 1) * lr_decay)\n            s = st + grad * grad\n            param = x - lr / (np.sqrt(s) + eps) * grad\n            return (param, s)\n\n        for i in range(1, train_iters + 1):\n            (x, st) = train_one_iter(i, random_grad_seq[i - 1])\n            np_res_list.append(x)\n        return x\n\n    train_by_numpy()\n    test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=0.001, atol=0.001))\n\n\ndef compare_with_numpy_adagrad_clip_grad(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    lr_decay,\n    weight_decay,\n    
initial_accumulator_value,\n    eps,\n    clip_grad_max_norm,\n    clip_grad_norm_type,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    class CustomModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.para0 = flow.nn.Parameter(\n                flow.Tensor(init_value, device=flow.device(device))\n            )\n\n        def forward(self, mask):\n            return self.para0 * mask\n\n    simp_module = CustomModule()\n    simp_module.to(device)\n    simp_module.train()\n\n    adam0 = flow.optim.Adagrad(\n        [\n            {\n                \"params\": simp_module.parameters(),\n                \"lr\": learning_rate,\n                \"eps\": eps,\n                \"weight_decay\": weight_decay,\n                \"clip_grad_max_norm\": clip_grad_max_norm,\n                \"clip_grad_norm_type\": clip_grad_norm_type,\n            }\n        ],\n        lr_decay=lr_decay,\n        initial_accumulator_value=initial_accumulator_value,\n    )\n\n    class CustomAdagradGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(adam0)\n\n        def build(self, mask_tensor):\n            loss = flow.sum(self.m(mask_tensor))\n            loss.backward()\n            return loss\n\n    of_res_list = []\n    adagrad_graph = CustomAdagradGraph()\n\n    for i in range(train_iters):\n        mask_tensor = flow.tensor(\n            random_grad_seq[i], requires_grad=False, device=flow.device(device)\n        )\n        adagrad_x = adagrad_graph(mask_tensor)\n\n        of_res_list.append(copy.copy(simp_module.para0.numpy()))\n\n    np_res_list = []\n\n    def train_by_numpy():\n        x = init_value\n        st = np.ones_like(x) * initial_accumulator_value\n\n        def np_train_one_iter(iter, grad):\n            norm, grad = clip_grad_norm_np(\n                grad, clip_grad_max_norm, clip_grad_norm_type\n            )\n            grad = grad + weight_decay * x\n            lr = learning_rate / (1 + (iter - 1) * lr_decay)\n            s = st + grad * grad\n            param = x - lr / (np.sqrt(s) + eps) * grad\n\n            return (param, s)\n\n        for i in range(1, train_iters + 1):\n            (x, st) = np_train_one_iter(i, random_grad_seq[i - 1])\n            np_res_list.append(x)\n\n        return x\n\n    train_by_numpy()\n\n    test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=0.001, atol=0.001))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAdagrad(flow.unittest.TestCase):\n    def test_adagrad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\"]\n        if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            arg_dict[\"device\"] = [\"cpu\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"lr_decay\"] = [0.9, 0.75]\n        arg_dict[\"weight_decay\"] = [0.0, 0.1]\n        arg_dict[\"initial_accumulator_value\"] = [1.0, 2.1]\n        arg_dict[\"eps\"] = [1e-08, 1e-07]\n\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adagrad(test_case, *arg)\n\n    def test_adagrad_clip_grad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\"]\n        if 
os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            arg_dict[\"device\"] = [\"cpu\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"lr_decay\"] = [0.9, 0.75]\n        arg_dict[\"weight_decay\"] = [0.0, 0.9]\n        arg_dict[\"initial_accumulator_value\"] = [1.0, 2.1]\n        arg_dict[\"eps\"] = [1e-8]\n        arg_dict[\"clip_grad_max_norm\"] = [1.0]\n        arg_dict[\"clip_grad_norm_type\"] = [2.0]\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adagrad_clip_grad(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_optim_adam.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nimport copy\n\nfrom test_util import GenArgList\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\n\n\ndef compare_with_numpy_adam(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    betas,\n    weight_decay,\n    eps,\n    do_bias_correction,\n    amsgrad,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    class CustomModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.para0 = flow.nn.Parameter(\n                flow.Tensor(init_value, device=flow.device(device))\n            )\n\n        def forward(self, mask):\n            return self.para0 * mask\n\n    simp_module = CustomModule()\n    simp_module.to(device)\n    simp_module.train()\n\n    adam0 = flow.optim.Adam(\n        [\n            {\n                \"params\": simp_module.parameters(),\n                \"lr\": learning_rate,\n                \"betas\": betas,\n                \"eps\": eps,\n                \"weight_decay\": weight_decay,\n                \"do_bias_correction\": do_bias_correction,\n                \"amsgrad\": amsgrad,\n            }\n        ]\n    )\n\n    class CustomAdamGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(adam0)\n\n        def build(self, mask_tensor):\n            loss = flow.sum(self.m(mask_tensor))\n            loss.backward()\n            return loss\n\n    of_res_list = []\n    adam_graph = CustomAdamGraph()\n\n    for i in range(train_iters):\n        mask_tensor = flow.tensor(\n            random_grad_seq[i],\n            dtype=flow.float32,\n            requires_grad=False,\n            device=flow.device(device),\n        )\n        adam_x = adam_graph(mask_tensor)\n\n        of_res_list.append(copy.copy(simp_module.para0.numpy()))\n\n    np_res_list = []\n\n    def train_by_numpy():\n        x = init_value\n        vt = np.zeros_like(x)\n        st = np.zeros_like(x)\n        max_st = np.zeros_like(x)\n        beta1 = betas[0]\n        beta2 = betas[1]\n\n        def np_train_one_iter(step, grad):\n            grad = grad + weight_decay * x\n\n            bias_correction1 = 1.0\n            bias_correction2 = 1.0\n\n            if do_bias_correction:\n                bias_correction1 = 1.0 - np.power(beta1, step)\n                bias_correction2 = 1.0 - np.power(beta2, step)\n\n            v = beta1 * vt + (1 - beta1) * grad\n            s = beta2 * st + (1 - beta2) * grad * grad\n            max_s = np.zeros_like(x)\n\n            if amsgrad:\n                max_s = np.maximum(s, max_st)\n                denom 
= np.sqrt(max_s) / np.sqrt(bias_correction2) + eps\n            else:\n                denom = np.sqrt(s) / np.sqrt(bias_correction2) + eps\n\n            param = x - ((learning_rate / bias_correction1) * v / denom)\n            return (param, v, s, max_s)\n\n        for i in range(1, train_iters + 1):\n            (x, vt, st, max_st) = np_train_one_iter(i, random_grad_seq[i - 1])\n            np_res_list.append(x)\n        return x\n\n    train_by_numpy()\n\n    test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=0.001, atol=0.001))\n\n\ndef compare_with_numpy_adam_clip_grad(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    betas,\n    weight_decay,\n    eps,\n    do_bias_correction,\n    amsgrad,\n    clip_grad_max_norm,\n    clip_grad_norm_type,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    class CustomModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.para0 = flow.nn.Parameter(\n                flow.tensor(init_value, device=flow.device(device))\n            )\n\n        def forward(self, mask):\n            return self.para0 * mask\n\n    simp_module = CustomModule()\n    simp_module.to(device)\n    simp_module.train()\n\n    adam0 = flow.optim.Adam(\n        [\n            {\n                \"params\": simp_module.parameters(),\n                \"lr\": learning_rate,\n                \"betas\": betas,\n                \"eps\": eps,\n                \"weight_decay\": weight_decay,\n                \"do_bias_correction\": do_bias_correction,\n                \"amsgrad\": amsgrad,\n                \"clip_grad_max_norm\": clip_grad_max_norm,\n                \"clip_grad_norm_type\": clip_grad_norm_type,\n            }\n        ]\n    )\n\n    class CustomAdamGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(adam0)\n\n        def build(self, mask_tensor):\n            loss = flow.sum(self.m(mask_tensor))\n            loss.backward()\n            return loss\n\n    of_res_list = []\n    adam_graph = CustomAdamGraph()\n\n    for i in range(train_iters):\n        mask_tensor = flow.tensor(\n            random_grad_seq[i], requires_grad=False, device=flow.device(device)\n        )\n        adam_x = adam_graph(mask_tensor)\n\n        of_res_list.append(copy.copy(simp_module.para0.numpy()))\n\n    np_res_list = []\n\n    def train_by_numpy():\n        x = init_value\n        vt = np.zeros_like(x)\n        st = np.zeros_like(x)\n        max_st = np.zeros_like(x)\n        beta1 = betas[0]\n        beta2 = betas[1]\n\n        def np_train_one_iter(step, grad):\n            total_norm, grad = clip_grad_norm_np(\n                grad, clip_grad_max_norm, clip_grad_norm_type\n            )\n            grad = grad + weight_decay * x\n\n            bias_correction1 = 1.0\n            bias_correction2 = 1.0\n\n            if do_bias_correction:\n                bias_correction1 = 1.0 - np.power(beta1, step)\n                bias_correction2 = 1.0 - np.power(beta2, step)\n\n            v = beta1 * vt + (1 - beta1) * grad\n            s = beta2 * st + (1 - beta2) * grad * grad\n            max_s = np.zeros_like(x)\n\n            if amsgrad:\n                max_s = np.maximum(s, max_st)\n                denom = np.sqrt(max_s) / 
np.sqrt(bias_correction2) + eps\n            else:\n                denom = np.sqrt(s) / np.sqrt(bias_correction2) + eps\n\n            param = x - ((learning_rate / bias_correction1) * v / denom)\n            return (param, v, s, max_s)\n\n        for i in range(1, train_iters + 1):\n            (x, vt, st, max_st) = np_train_one_iter(i, random_grad_seq[i - 1])\n            np_res_list.append(x)\n        return x\n\n    train_by_numpy()\n    test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=1e-3, atol=1e-3))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAdam(flow.unittest.TestCase):\n    def test_adam(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"betas\"] = [(0.99, 0.9)]\n        arg_dict[\"weight_decay\"] = [0.001, 0.0]\n        arg_dict[\"eps\"] = [1e-8]\n        arg_dict[\"do_bias_correction\"] = [True, False]\n        arg_dict[\"amsgrad\"] = [True, False]\n\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adam(test_case, *arg)\n\n    def test_adam_clip_grad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"betas\"] = [(0.99, 0.9)]\n        arg_dict[\"weight_decay\"] = [0.0, 0.9]\n        arg_dict[\"eps\"] = [1e-8]\n        arg_dict[\"do_bias_correction\"] = [True, False]\n        arg_dict[\"amsgrad\"] = [True, False]\n        arg_dict[\"clip_grad_max_norm\"] = [1.0]\n        arg_dict[\"clip_grad_norm_type\"] = [2.0]\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adam_clip_grad(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_optim_adamw.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nimport copy\n\nfrom test_util import GenArgList\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\n\n\ndef compare_with_numpy_adamw(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    betas,\n    weight_decay,\n    eps,\n    do_bias_correction,\n    amsgrad,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    class CustomModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.para0 = flow.nn.Parameter(\n                flow.tensor(init_value, device=flow.device(device))\n            )\n\n        def forward(self, mask):\n            return self.para0 * mask\n\n    simp_module = CustomModule()\n    simp_module.to(device)\n    simp_module.train()\n\n    adamw0 = flow.optim.AdamW(\n        [\n            {\n                \"params\": simp_module.parameters(),\n                \"lr\": learning_rate,\n                \"betas\": betas,\n                \"weight_decay\": weight_decay,\n                \"do_bias_correction\": do_bias_correction,\n                \"amsgrad\": amsgrad,\n            }\n        ],\n        do_bias_correction=do_bias_correction,\n        amsgrad=amsgrad,\n    )\n\n    class CustomAdamWGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(adamw0)\n\n        def build(self, mask_tensor):\n            loss = flow.sum(self.m(mask_tensor))\n            loss.backward()\n            return loss\n\n    of_res_list = []\n    adamw_graph = CustomAdamWGraph()\n    for i in range(train_iters):\n        mask_tensor = flow.tensor(\n            random_grad_seq[i], requires_grad=False, device=flow.device(device)\n        )\n        adamw_x = adamw_graph(mask_tensor)\n        of_res_list.append(copy.copy(simp_module.para0.numpy()))\n\n    np_res_list = []\n\n    def train_by_numpy():\n        x = init_value\n        vt = np.zeros_like(x)\n        st = np.zeros_like(x)\n        max_st = np.zeros_like(x)\n        beta1 = betas[0]\n        beta2 = betas[1]\n\n        def np_train_one_iter(step, grad):\n            v = beta1 * vt + (1 - beta1) * grad\n            s = beta2 * st + (1 - beta2) * grad * grad\n\n            bias_correction1 = 1.0\n            bias_correction2 = 1.0\n\n            if do_bias_correction:\n                bias_correction1 = 1.0 - np.power(beta1, step)\n                bias_correction2 = 1.0 - np.power(beta2, step)\n\n            max_s = np.zeros_like(x)\n            if amsgrad:\n                max_s = np.maximum(s, max_st)\n                denom = np.sqrt(max_s) / np.sqrt(bias_correction2) + eps\n     
       else:\n                denom = np.sqrt(s) / np.sqrt(bias_correction2) + eps\n\n            lr = learning_rate / bias_correction1 / denom\n            g = lr * v + learning_rate * weight_decay * x\n            param = x - g\n            return (param, v, s, max_s)\n\n        for i in range(1, train_iters + 1):\n            (x, vt, st, max_st) = np_train_one_iter(i, random_grad_seq[i - 1])\n            np_res_list.append(x)\n\n    train_by_numpy()\n\n    test_case.assertTrue(np.allclose(np_res_list, of_res_list, rtol=1e-4, atol=1e-4))\n\n\ndef compare_with_numpy_adamw_clip_grad(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    betas,\n    weight_decay,\n    eps,\n    do_bias_correction,\n    amsgrad,\n    clip_grad_max_norm,\n    clip_grad_norm_type,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    class CustomModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.para0 = flow.nn.Parameter(\n                flow.tensor(init_value, device=flow.device(device))\n            )\n\n        def forward(self, mask):\n            return self.para0 * mask\n\n    simp_module = CustomModule()\n    simp_module.to(device)\n    simp_module.train()\n\n    adamw0 = flow.optim.AdamW(\n        [\n            {\n                \"params\": simp_module.parameters(),\n                \"lr\": learning_rate,\n                \"betas\": betas,\n                \"eps\": eps,\n                \"weight_decay\": weight_decay,\n                \"clip_grad_max_norm\": clip_grad_max_norm,\n                \"clip_grad_norm_type\": clip_grad_norm_type,\n            }\n        ],\n        do_bias_correction=do_bias_correction,\n        amsgrad=amsgrad,\n    )\n\n    class CustomAdamWGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(adamw0)\n\n        def build(self, mask_tensor):\n            loss = flow.sum(self.m(mask_tensor))\n            loss.backward()\n            return loss\n\n    of_res_list = []\n    adamw_graph = CustomAdamWGraph()\n    for i in range(train_iters):\n        mask_tensor = flow.tensor(\n            random_grad_seq[i], requires_grad=False, device=flow.device(device)\n        )\n        adamw_x = adamw_graph(mask_tensor)\n        of_res_list.append(copy.copy(simp_module.para0.numpy()))\n\n    np_res_list = []\n\n    def train_by_numpy():\n        x = init_value\n        vt = np.zeros_like(x)\n        st = np.zeros_like(x)\n        max_st = np.zeros_like(x)\n\n        beta1 = betas[0]\n        beta2 = betas[1]\n\n        def np_train_one_iter(step, grad):\n            total_norm, grad = clip_grad_norm_np(\n                grad, clip_grad_max_norm, clip_grad_norm_type\n            )\n            v = beta1 * vt + (1 - beta1) * grad\n            s = beta2 * st + (1 - beta2) * grad * grad\n\n            bias_correction1 = 1.0\n            bias_correction2 = 1.0\n\n            if do_bias_correction:\n                bias_correction1 = 1.0 - np.power(beta1, step)\n                bias_correction2 = 1.0 - np.power(beta2, step)\n\n            max_s = np.zeros_like(x)\n            if amsgrad:\n                max_s = np.maximum(s, max_st)\n                denom = np.sqrt(max_s) / np.sqrt(bias_correction2) + eps\n            else:\n                denom 
= np.sqrt(s) / np.sqrt(bias_correction2) + eps\n\n            lr = learning_rate / bias_correction1 / denom\n            g = lr * v + learning_rate * weight_decay * x\n            param = x - g\n            return (param, v, s, max_s)\n\n        for i in range(1, train_iters + 1):\n            (x, vt, st, max_st) = np_train_one_iter(i, random_grad_seq[i - 1])\n            np_res_list.append(x)\n\n    train_by_numpy()\n\n    test_case.assertTrue(np.allclose(np_res_list, of_res_list, rtol=1e-4, atol=1e-4))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAdamW(flow.unittest.TestCase):\n    def test_adamw(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"betas\"] = [(0.99, 0.9)]\n        arg_dict[\"weight_decay\"] = [1e-3, 0.0]\n        arg_dict[\"eps\"] = [1e-8]\n        arg_dict[\"do_bias_correction\"] = [True, False]\n        arg_dict[\"amsgrad\"] = [True, False]\n\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adamw(test_case, *arg)\n\n    def test_adamw_clip_grad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"betas\"] = [(0.99, 0.9)]\n        arg_dict[\"weight_decay\"] = [0.0, 0.9]\n        arg_dict[\"eps\"] = [1e-8]\n        arg_dict[\"do_bias_correction\"] = [True, False]\n        arg_dict[\"amsgrad\"] = [True, False]\n        arg_dict[\"clip_grad_max_norm\"] = [1.0]\n        arg_dict[\"clip_grad_norm_type\"] = [2.0]\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adamw_clip_grad(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_optim_ftrl.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nimport copy\n\nfrom test_util import GenArgList\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\nfrom oneflow.one_embedding import Ftrl\n\n\ndef compare_with_numpy_ftrl(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    weight_decay,\n    lr_power,\n    initial_accumulator_value,\n    lambda1,\n    lambda2,\n    beta,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    class CustomModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.para0 = flow.nn.Parameter(\n                flow.Tensor(init_value, device=flow.device(device))\n            )\n\n        def forward(self, mask):\n            return self.para0 * mask\n\n    simp_module = CustomModule()\n    simp_module.to(device)\n    simp_module.train()\n\n    ftrl = Ftrl(\n        [\n            {\n                \"params\": simp_module.parameters(),\n                \"lr\": learning_rate,\n                \"weight_decay\": weight_decay,\n                \"lr_power\": lr_power,\n                \"initial_accumulator_value\": initial_accumulator_value,\n                \"lambda1\": lambda1,\n                \"lambda2\": lambda2,\n                \"beta\": beta,\n            }\n        ]\n    )\n\n    class CustomftrlGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(ftrl)\n\n        def build(self, mask_tensor):\n            loss = flow.sum(self.m(mask_tensor))\n            loss.backward()\n            return loss\n\n    of_res_list = []\n    ftrl_graph = CustomftrlGraph()\n\n    for i in range(train_iters):\n        mask_tensor = flow.tensor(\n            random_grad_seq[i], requires_grad=False, device=flow.device(device)\n        )\n        ftrl_x = ftrl_graph(mask_tensor)\n\n        of_res_list.append(copy.copy(simp_module.para0.numpy()))\n\n    np_res_list = []\n\n    def train_by_numpy():\n        x = init_value\n        accum = np.zeros_like(x)\n        accum.fill(initial_accumulator_value)\n        z_arr = np.zeros_like(x)\n\n        def np_train_one_iter(grad):\n            grad = grad + weight_decay * x\n\n            new_accum = accum + grad * grad\n            sigma = (\n                np.power(new_accum, lr_power) - np.power(accum, lr_power)\n            ) / learning_rate\n            new_z_val = z_arr + grad - sigma * x\n\n            update_val = (np.sign(new_z_val) * lambda1 - new_z_val) / (\n                (beta + np.power(new_accum, lr_power)) / learning_rate + lambda2\n            )\n            param = np.where(np.abs(new_z_val) < lambda1, 
0.0, update_val)\n            return (param, new_accum, new_z_val)\n\n        for i in range(1, train_iters + 1):\n            (x, accum, z_arr) = np_train_one_iter(random_grad_seq[i - 1])\n            np_res_list.append(x)\n        return x\n\n    train_by_numpy()\n    test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=1e-4, atol=1e-4))\n\n\ndef compare_with_numpy_ftrl_clip_grad(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    weight_decay,\n    lr_power,\n    initial_accumulator_value,\n    lambda1,\n    lambda2,\n    beta,\n    clip_grad_max_norm,\n    clip_grad_norm_type,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    class CustomModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.para0 = flow.nn.Parameter(\n                flow.Tensor(init_value, device=flow.device(device))\n            )\n\n        def forward(self, mask):\n            return self.para0 * mask\n\n    simp_module = CustomModule()\n    simp_module.to(device)\n    simp_module.train()\n\n    ftrl = Ftrl(\n        [\n            {\n                \"params\": simp_module.parameters(),\n                \"lr\": learning_rate,\n                \"weight_decay\": weight_decay,\n                \"lr_power\": lr_power,\n                \"initial_accumulator_value\": initial_accumulator_value,\n                \"lambda1\": lambda1,\n                \"lambda2\": lambda2,\n                \"beta\": beta,\n                \"clip_grad_max_norm\": clip_grad_max_norm,\n                \"clip_grad_norm_type\": clip_grad_norm_type,\n            }\n        ]\n    )\n\n    class CustomftrlGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(ftrl)\n\n        def build(self, mask_tensor):\n            loss = flow.sum(self.m(mask_tensor))\n            loss.backward()\n            return loss\n\n    of_res_list = []\n    ftrl_graph = CustomftrlGraph()\n\n    for i in range(train_iters):\n        mask_tensor = flow.tensor(\n            random_grad_seq[i], requires_grad=False, device=flow.device(device)\n        )\n        ftrl_x = ftrl_graph(mask_tensor)\n\n        of_res_list.append(copy.copy(simp_module.para0.numpy()))\n\n    np_res_list = []\n\n    def train_by_numpy():\n        x = init_value\n        accum = np.zeros_like(x)\n        accum.fill(initial_accumulator_value)\n        z_arr = np.zeros_like(x)\n\n        def np_train_one_iter(grad):\n            norm, grad = clip_grad_norm_np(\n                grad, clip_grad_max_norm, clip_grad_norm_type\n            )\n            grad = grad + weight_decay * x\n\n            new_accum = accum + grad * grad\n            sigma = (\n                np.power(new_accum, lr_power) - np.power(accum, lr_power)\n            ) / learning_rate\n            new_z_val = z_arr + grad - sigma * x\n\n            update_val = (np.sign(new_z_val) * lambda1 - new_z_val) / (\n                (beta + np.power(new_accum, lr_power)) / learning_rate + lambda2\n            )\n            param = np.where(np.abs(new_z_val) < lambda1, 0.0, update_val)\n            return (param, new_accum, new_z_val)\n\n        for i in range(1, train_iters + 1):\n            (x, accum, z_arr) = np_train_one_iter(random_grad_seq[i - 1])\n            
np_res_list.append(x)\n        return x\n\n    train_by_numpy()\n\n    test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=1e-4, atol=1e-4))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFtrl(flow.unittest.TestCase):\n    def test_ftrl(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\"]\n        if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            arg_dict[\"device\"] = [\"cpu\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"weight_decay\"] = [0.9, 0.0]\n        arg_dict[\"lr_power\"] = [-0.5, 0.5]\n        arg_dict[\"initial_accumulator_value\"] = [0.1, 0.05]\n        arg_dict[\"lambda1\"] = [0.01]\n        arg_dict[\"lambda2\"] = [0.0, 0.01]\n        arg_dict[\"beta\"] = [1.0]\n\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_ftrl(test_case, *arg)\n\n    def test_ftrl_clip_grad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\"]\n        if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            arg_dict[\"device\"] = [\"cpu\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"weight_decay\"] = [0.9, 0.0]\n        arg_dict[\"lr_power\"] = [-0.5, 0.5]\n        arg_dict[\"initial_accumulator_value\"] = [0.1, 0.05]\n        arg_dict[\"lambda1\"] = [0.01]\n        arg_dict[\"lambda2\"] = [0.0, 0.01]\n        arg_dict[\"beta\"] = [1.0]\n        arg_dict[\"clip_grad_max_norm\"] = [1.0]\n        arg_dict[\"clip_grad_norm_type\"] = [2.0]\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_ftrl_clip_grad(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_optim_lamb.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\n\nfrom test_util import GenArgList\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\n\n\ndef compare_with_numpy_lamb(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    betas,\n    weight_decay,\n    eps,\n    do_bias_correction,\n    adam_w_mode,\n    clip_grad_max_norm,\n    clip_grad_norm_type,\n):\n\n    np.random.seed(1000)\n\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    class CustomModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.param = flow.nn.Parameter(\n                flow.Tensor(init_value, device=flow.device(device))\n            )\n\n        def forward(self, mask):\n            return self.param * mask\n\n    simp_module = CustomModule()\n    simp_module.to(device)\n    simp_module.train()\n\n    optim_kwargs = {\n        \"params\": simp_module.parameters(),\n        \"lr\": learning_rate,\n        \"betas\": betas,\n        \"eps\": eps,\n        \"weight_decay\": weight_decay,\n        \"adam_w_mode\": adam_w_mode,\n        \"do_bias_correction\": do_bias_correction,\n    }\n\n    if clip_grad_max_norm != -1:\n        optim_kwargs[\"clip_grad_max_norm\"] = clip_grad_max_norm\n        optim_kwargs[\"clip_grad_norm_type\"] = clip_grad_norm_type\n\n    lamb_optim = flow.optim.LAMB([optim_kwargs])\n\n    class CustomLambGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(lamb_optim)\n\n        def build(self, mask_tensor):\n            loss = flow.sum(self.m(mask_tensor))\n            loss.backward()\n            return loss\n\n    lamb_graph = CustomLambGraph()\n\n    for i in range(train_iters):\n        mask_tensor = flow.tensor(\n            random_grad_seq[i],\n            dtype=flow.float32,\n            requires_grad=False,\n            device=flow.device(device),\n        )\n        lamb_graph(mask_tensor)\n\n    of_res = simp_module.param.numpy()\n\n    def train_by_numpy():\n        x = init_value\n        mt = np.zeros_like(x)\n        vt = np.zeros_like(x)\n        beta1 = betas[0]\n        beta2 = betas[1]\n        if adam_w_mode:\n            l2 = 0\n            wd = weight_decay\n        else:\n            l2 = weight_decay\n            wd = 0\n\n        def np_train_one_iter(step, grad):\n            if clip_grad_max_norm != -1:\n                _, grad = clip_grad_norm_np(\n                    grad, clip_grad_max_norm, clip_grad_norm_type\n                )\n\n            grad = grad + l2 * x\n\n            bias_correction1 = 1.0\n            bias_correction2 = 1.0\n\n         
   if do_bias_correction:\n                bias_correction1 = 1.0 - np.power(beta1, step + 1)\n                bias_correction2 = 1.0 - np.power(beta2, step + 1)\n\n            m = beta1 * mt + (1 - beta1) * grad\n            v = beta2 * vt + (1 - beta2) * grad * grad\n\n            denom = np.sqrt(v) / np.sqrt(bias_correction2) + eps\n\n            adam_diff = m / bias_correction1 / denom\n\n            w_norm = np.linalg.norm(x, ord=2)\n            g_norm = np.linalg.norm(adam_diff, ord=2)\n            if w_norm > 0 and g_norm > 0:\n                trust_ratio = w_norm / g_norm\n            else:\n                trust_ratio = 1.0\n\n            param = x - learning_rate * trust_ratio * (adam_diff + wd * x)\n            return (param, m, v)\n\n        for i in range(train_iters):\n            (x, mt, vt) = np_train_one_iter(i, random_grad_seq[i])\n        return x\n\n    np_res = train_by_numpy()\n\n    test_case.assertTrue(\n        np.allclose(of_res.flatten(), np_res.flatten(), rtol=1e-3, atol=1e-3)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLamb(flow.unittest.TestCase):\n    def test_lamb(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\"]\n        if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            arg_dict[\"device\"] = [\"cpu\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [0.1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"betas\"] = [(0.99, 0.9)]\n        arg_dict[\"weight_decay\"] = [0.001, 0.1]\n        arg_dict[\"eps\"] = [1e-8, 1e-6]\n        arg_dict[\"do_bias_correction\"] = [True, False]\n        arg_dict[\"adam_w_mode\"] = [True, False]\n        # NOTE(l1aoxingyu): max_norm = -1 means no clip grad\n        # nn.Graph only support `clip_grad_max_norm == 1.0` and `clip_grad_norm_type == 2.0`\n        arg_dict[\"clip_grad_max_norm\"] = [-1, 1.0]\n        arg_dict[\"clip_grad_norm_type\"] = [2.0]\n\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_lamb(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_optim_rmsprop.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport copy\n\nfrom test_util import GenArgList\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\n\n\n@flow.unittest.skip_unless_1n1d()\ndef compare_with_numpy_rmsprop(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    momentum,\n    train_iters,\n    alpha,\n    eps,\n    weight_decay,\n    centered,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    class CustomModel(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.param0 = flow.nn.Parameter(\n                flow.tensor(init_value, device=flow.device(device))\n            )\n\n        def forward(self, mask):\n            return self.param0 * mask\n\n    simp_module = CustomModel()\n    simp_module.to(flow.device(device))\n    simp_module.train()\n\n    rmsprop0 = flow.optim.RMSprop(\n        [\n            {\n                \"params\": simp_module.parameters(),\n                \"lr\": learning_rate,\n                \"alpha\": alpha,\n                \"eps\": eps,\n                \"weight_decay\": weight_decay,\n                \"momentum\": momentum,\n                \"centered\": centered,\n            }\n        ]\n    )\n\n    class CustomRMSpropGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(rmsprop0)\n\n        def build(self, mask_tensor):\n            loss = flow.sum(self.m(mask_tensor))\n            loss.backward()\n            return loss\n\n    of_res_list = []\n    rmsprop_graph = CustomRMSpropGraph()\n\n    for i in range(train_iters):\n        mask_tensor = flow.tensor(\n            random_grad_seq[i], requires_grad=False, device=flow.device(device)\n        )\n        rmsprop_x = rmsprop_graph(mask_tensor)\n\n        of_res_list.append(copy.copy(simp_module.param0.numpy()))\n\n    np_res_list = []\n\n    def train_by_numpy():\n        x = init_value\n        r = np.zeros_like(x)\n        v = np.zeros_like(x)\n        g = np.zeros_like(x)\n\n        def np_train_one_iter(grad):\n            # ref to: ../modules/test_optim_rmsprop.py -> train_by_numpy()\n\n            # weight decay is equivalent to l2 penalty\n            grad = grad + weight_decay * x\n            r_ = alpha * r + (1 - alpha) * grad * grad\n            if centered:\n                g_ = alpha * g + (1 - alpha) * grad\n                v_ = momentum * v + learning_rate / np.sqrt(r_ - g_ * g_ + eps) * grad\n            else:\n                g_ = g\n                v_ = momentum * v + learning_rate / np.sqrt(r_ + eps) * grad\n            param = x - v_\n            return (param, r_, g_, 
v_)\n\n        for i in range(train_iters):\n            (x, r, g, v) = np_train_one_iter(random_grad_seq[i])\n            np_res_list.append(x)\n        return x\n\n    train_by_numpy()\n\n    test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=1e-3, atol=1e-3))\n\n\n@flow.unittest.skip_unless_1n1d()\ndef compare_with_numpy_rmsprop_clip_grad(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    momentum,\n    train_iters,\n    alpha,\n    eps,\n    weight_decay,\n    centered,\n    clip_grad_max_norm,\n    clip_grad_norm_type,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    class CustomModel(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.param0 = flow.nn.Parameter(\n                flow.tensor(init_value, device=flow.device(device))\n            )\n\n        def forward(self, mask):\n            return self.param0 * mask\n\n    simp_module = CustomModel()\n    simp_module.to(flow.device(device))\n    simp_module.train()\n\n    rmsprop0 = flow.optim.RMSprop(\n        [\n            {\n                \"params\": simp_module.parameters(),\n                \"lr\": learning_rate,\n                \"alpha\": alpha,\n                \"eps\": eps,\n                \"weight_decay\": weight_decay,\n                \"momentum\": momentum,\n                \"centered\": centered,\n                \"clip_grad_max_norm\": clip_grad_max_norm,\n                \"clip_grad_norm_type\": clip_grad_norm_type,\n            }\n        ]\n    )\n\n    class CustomRMSpropGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(rmsprop0)\n\n        def build(self, mask_tensor):\n            loss = flow.sum(self.m(mask_tensor))\n            loss.backward()\n            return loss\n\n    of_res_list = []\n    rmsprop_graph = CustomRMSpropGraph()\n\n    for i in range(train_iters):\n        mask_tensor = flow.tensor(\n            random_grad_seq[i], requires_grad=False, device=flow.device(device)\n        )\n        rmsprop_x = rmsprop_graph(mask_tensor)\n\n        of_res_list.append(copy.copy(simp_module.param0.numpy()))\n\n    np_res_list = []\n\n    def train_by_numpy():\n        x = init_value\n        r = np.zeros_like(x)\n        v = np.zeros_like(x)\n        g = np.zeros_like(x)\n\n        def np_train_one_iter(grad):\n            norm, grad = clip_grad_norm_np(\n                grad, clip_grad_max_norm, clip_grad_norm_type\n            )\n            # weight decay is equivalent to l2 penalty\n            grad = grad + weight_decay * x\n            r_ = alpha * r + (1 - alpha) * grad * grad\n            if centered:\n                g_ = alpha * g + (1 - alpha) * grad\n                v_ = momentum * v + learning_rate / np.sqrt(r_ - g_ * g_ + eps) * grad\n            else:\n                g_ = g\n                v_ = momentum * v + learning_rate / np.sqrt(r_ + eps) * grad\n            param = x - v_\n            return (param, r_, g_, v_)\n\n        for i in range(train_iters):\n            (x, r, g, v) = np_train_one_iter(random_grad_seq[i])\n            np_res_list.append(x)\n        return x\n\n    train_by_numpy()\n\n    test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=1e-3, atol=1e-3))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass 
TestRMSprop(flow.unittest.TestCase):\n    def test_rmsprop(test_case):\n        args_dict = OrderedDict()\n        args_dict[\"device\"] = [\"cuda\"]\n        if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            args_dict[\"device\"] = [\"cpu\"]\n        args_dict[\"x_shape\"] = [(1,), (10,)]\n        args_dict[\"learning_rate\"] = [1]\n        args_dict[\"momentum\"] = [0.0]  # not supported momentum > 0\n        args_dict[\"train_iters\"] = [10]\n        args_dict[\"alpha\"] = [0.9]\n        args_dict[\"eps\"] = [1e-8, 1e-5]\n        args_dict[\"weight_decay\"] = [0.1, 0.9]\n        args_dict[\"centered\"] = [False, True]\n\n        for args in GenArgList(args_dict):\n            compare_with_numpy_rmsprop(test_case, *args)\n\n    def test_rmsprop_clip_grad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1]\n        arg_dict[\"momentum\"] = [0.0]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"alpha\"] = [0.9, 0.99]\n        arg_dict[\"eps\"] = [1e-08, 1e-05]\n        arg_dict[\"weight_decay\"] = [0.0, 0.9]\n        arg_dict[\"centered\"] = [False, True]\n        arg_dict[\"clip_grad_max_norm\"] = [1.0]\n        arg_dict[\"clip_grad_norm_type\"] = [2.0]\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_rmsprop_clip_grad(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_optim_sgd.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nimport copy\n\nfrom test_util import GenArgList\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\n\n\ndef compare_with_numpy_sgd(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    momentum,\n    dampening,\n    nesterov,\n    maximize,\n    weight_decay,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    class CustomModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.para0 = flow.nn.Parameter(\n                flow.tensor(init_value, device=flow.device(device))\n            )\n\n        def forward(self, mask):\n            return self.para0 * mask\n\n    simp_module = CustomModule()\n    simp_module.to(device)\n    simp_module.train()\n\n    sgd0 = flow.optim.SGD(\n        [\n            {\n                \"params\": simp_module.parameters(),\n                \"lr\": learning_rate,\n                \"weight_decay\": weight_decay,\n            }\n        ],\n        momentum=momentum,\n        dampening=dampening,\n        nesterov=nesterov,\n        maximize=maximize,\n    )\n\n    class CustomSGDGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(sgd0)\n\n        def build(self, mask_tensor):\n            loss = flow.sum(self.m(mask_tensor))\n            loss.backward()\n            return loss\n\n    of_res_list = []\n    sgd_graph = CustomSGDGraph()\n    for i in range(train_iters):\n        mask_tensor = flow.tensor(\n            random_grad_seq[i], requires_grad=False, device=flow.device(device)\n        )\n        sgd_x = sgd_graph(mask_tensor)\n        of_res_list.append(copy.copy(simp_module.para0.numpy()))\n\n    np_res_list = []\n\n    def train_by_numpy():\n        x = init_value\n        vt = np.zeros_like(x)\n\n        def np_train_one_iter(grad):\n            grad = grad + weight_decay * x\n            if momentum > 0.0:\n                next_momentum = momentum * vt + (1 - dampening) * grad\n                v = next_momentum\n\n                if nesterov:\n                    grad += momentum * next_momentum\n                else:\n                    grad = next_momentum\n\n                alpha = -learning_rate\n                if maximize:\n                    alpha = learning_rate\n                next_model = x + alpha * grad\n                param = next_model\n            else:\n                v = learning_rate * grad\n                param = x - v\n            return (param, v)\n\n        for i in range(train_iters):\n            (x, vt) = np_train_one_iter(random_grad_seq[i])\n            
np_res_list.append(x)\n\n    train_by_numpy()\n    test_case.assertTrue(np.allclose(np_res_list, of_res_list, rtol=1e-3, atol=1e-3))\n\n\ndef compare_with_numpy_sgd_clip_grad(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    momentum,\n    dampening,\n    nesterov,\n    maximize,\n    weight_decay,\n    clip_grad_max_norm,\n    clip_grad_norm_type,\n    train_iters,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    class CustomModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.para0 = flow.nn.Parameter(\n                flow.tensor(init_value, device=flow.device(device))\n            )\n\n        def forward(self, mask):\n            return self.para0 * mask\n\n    simp_module = CustomModule()\n    simp_module.to(device)\n    simp_module.train()\n\n    sgd0 = flow.optim.SGD(\n        [\n            {\n                \"params\": simp_module.parameters(),\n                \"lr\": learning_rate,\n                \"weight_decay\": weight_decay,\n                \"clip_grad_max_norm\": clip_grad_max_norm,\n                \"clip_grad_norm_type\": clip_grad_norm_type,\n            }\n        ],\n        momentum=momentum,\n        dampening=dampening,\n        nesterov=nesterov,\n        maximize=maximize,\n    )\n\n    class CustomSGDGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(sgd0)\n\n        def build(self, mask_tensor):\n            loss = flow.sum(self.m(mask_tensor))\n            loss.backward()\n            return loss\n\n    of_res_list = []\n    sgd_graph = CustomSGDGraph()\n    for i in range(train_iters):\n        mask_tensor = flow.tensor(\n            random_grad_seq[i], requires_grad=False, device=flow.device(device)\n        )\n        sgd_x = sgd_graph(mask_tensor)\n        of_res_list.append(copy.copy(simp_module.para0.numpy()))\n\n    np_res_list = []\n\n    def train_by_numpy():\n        x = init_value\n        vt = np.zeros_like(x)\n\n        def np_train_one_iter(grad):\n            norm, grad = clip_grad_norm_np(\n                grad, clip_grad_max_norm, clip_grad_norm_type\n            )\n            grad = grad + weight_decay * x\n            if momentum > 0.0:\n                next_momentum = momentum * vt + (1 - dampening) * grad\n                v = next_momentum\n\n                if nesterov:\n                    grad += momentum * next_momentum\n                else:\n                    grad = next_momentum\n\n                alpha = -learning_rate\n                if maximize:\n                    alpha = learning_rate\n                next_model = x + alpha * grad\n                param = next_model\n            else:\n                v = learning_rate * grad\n                param = x - v\n            return (param, v)\n\n        for i in range(train_iters):\n            (x, vt) = np_train_one_iter(random_grad_seq[i])\n            np_res_list.append(x)\n\n    train_by_numpy()\n    for np_res, of_res in zip(np_res_list, of_res_list):\n        test_case.assertTrue(np.allclose(np_res, of_res, rtol=0.001, atol=0.001))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGraphSGD(flow.unittest.TestCase):\n    def test_sgd(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n      
  arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"momentum\"] = [0.9, 0.8]\n        arg_dict[\"dampening\"] = [0.0, 0.9]\n        arg_dict[\"nesterov\"] = [True, False]\n        arg_dict[\"maximize\"] = [True, False]\n        arg_dict[\"weight_decay\"] = [0.001, 0.0]\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_sgd(test_case, *arg)\n\n    def test_sgd_with_clip_grad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 0.1]\n        arg_dict[\"momentum\"] = [0.0, 0.9]\n        arg_dict[\"dampening\"] = [0.0, 0.9]\n        arg_dict[\"nesterov\"] = [True, False]\n        arg_dict[\"maximize\"] = [True, False]\n        arg_dict[\"weight_decay\"] = [0.0, 0.9]\n        arg_dict[\"clip_grad_max_norm\"] = [1.0]\n        arg_dict[\"clip_grad_norm_type\"] = [2.0]\n        arg_dict[\"train_iters\"] = [10]\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_sgd_clip_grad(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_optimizer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\n\nimport numpy as np\n\nimport oneflow\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGraphOptimizer(flow.unittest.TestCase):\n    def test_optimizer(test_case):\n        class CustomModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.para0 = flow.nn.Parameter(flow.Tensor(10, 4))\n\n            def forward(self, x):\n                x = flow._C.matmul(x, self.para0)\n                return x\n\n        m = CustomModule()\n        learning_rate = 0.1\n        momentum = 0.2\n        weight_decay = 0.7\n        sgd0 = flow.optim.SGD(\n            [\n                {\n                    \"params\": [m.para0],\n                    \"lr\": learning_rate,\n                    \"momentum\": momentum,\n                    \"weight_decay\": weight_decay,\n                }\n            ]\n        )\n        cosine_lr = flow.optim.lr_scheduler.CosineDecayLR(\n            sgd0, decay_steps=100, alpha=0.1\n        )\n\n        class CustomGraph0(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.m = m\n                self.add_optimizer(sgd0)\n\n            def build(self, x):\n                out = self.m(x)\n                out = out.mean()\n                out.backward()\n                return out\n\n        g = CustomGraph0()\n\n        x = flow.Tensor(4, 10)\n        flow.nn.init.uniform_(x, a=-1.0, b=1.0)\n        z = g._compile(x)\n        print(\"repr(g): \\n\", repr(g))\n        print(\"g.config.proto: \\n\", g.config.proto)\n        print(\"graph proto: \\n\", g._graph_proto)\n\n    def test_multi_optimizer_conf(test_case):\n        class CustomModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.para0 = flow.nn.Parameter(flow.Tensor(1, 4))\n                self.para1 = flow.nn.Parameter(flow.Tensor(1, 4))\n                self.para2 = flow.nn.Parameter(flow.Tensor(1, 4))\n                self.para2.requires_grad_(False)\n                self.para3 = flow.nn.Parameter(flow.Tensor(1, 4))\n                self.para4 = flow.nn.Parameter(flow.Tensor(1, 4))\n\n            def forward(self, x):\n                x = flow._C.matmul(self.para0, x)\n                y = flow._C.matmul(self.para3, x)\n                return x, y\n\n        m = CustomModule()\n        learning_rate = 0.1\n        momentum = 0.2\n        sgd0 = flow.optim.SGD(\n            [\n                {\n                    \"params\": [m.para0, m.para1, m.para2],\n                    \"lr\": learning_rate,\n                    \"momentum\": momentum,\n                    \"weight_decay\": 0.3,\n                }\n            ]\n        )\n        sgd1 = flow.optim.SGD(\n            [\n                {\n                    \"params\": [m.para3],\n     
               \"lr\": learning_rate,\n                    \"momentum\": momentum,\n                    \"weight_decay\": 0.4,\n                },\n                {\n                    \"params\": [m.para4],\n                    \"lr\": learning_rate,\n                    \"momentum\": 0.9,\n                    \"weight_decay\": 0.5,\n                },\n            ]\n        )\n        cosine_lr0 = flow.optim.lr_scheduler.CosineDecayLR(\n            sgd0, decay_steps=10, alpha=0.01\n        )\n        constant_warmup_cosine_lr0 = flow.optim.lr_scheduler.WarmUpLR(\n            cosine_lr0, warmup_factor=0.5, warmup_iters=5, warmup_method=\"constant\"\n        )\n        cosine_lr1 = flow.optim.lr_scheduler.CosineDecayLR(\n            sgd1, decay_steps=100, alpha=0.1\n        )\n        linear_warmup_cosine_lr1 = flow.optim.lr_scheduler.WarmUpLR(\n            cosine_lr1, warmup_factor=0.5, warmup_iters=5, warmup_method=\"linear\"\n        )\n\n        class CustomGraph0(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.m = m\n                self.add_optimizer(sgd0, lr_sch=constant_warmup_cosine_lr0)\n                self.add_optimizer(sgd1, lr_sch=linear_warmup_cosine_lr1)\n\n            def build(self, x):\n                out0, out1 = self.m(x)\n                out0.backward()\n                out1.backward()\n                return out0, out1\n\n        g = CustomGraph0()\n        x = flow.Tensor(4, 10)\n        flow.nn.init.uniform_(x, a=-1.0, b=1.0)\n        g._filter_states()\n        g._generate_config_proto()\n        print(\"repr(g): \\n\", repr(g))\n        print(\"g.config.proto: \\n\", g.config.proto)\n\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    def test_optimizer_with_clip_grad(test_case):\n        class CustomModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.para0 = flow.nn.Parameter(flow.Tensor(10, 4))\n\n            def forward(self, x):\n                x = flow._C.matmul(x, self.para0)\n                return x\n\n        m = CustomModule()\n        learning_rate = 0.1\n        momentum = 0.2\n        scale = 0.3\n        weight_decay = 0.7\n        clip_grad_max_norm = 1.0\n        clip_grad_norm_type = 2.0\n\n        sgd0 = flow.optim.SGD(\n            [\n                {\n                    \"params\": [m.para0],\n                    \"lr\": learning_rate,\n                    \"momentum\": momentum,\n                    \"scale\": scale,\n                    \"weight_decay\": weight_decay,\n                    \"clip_grad_max_norm\": clip_grad_max_norm,\n                    \"clip_grad_norm_type\": clip_grad_norm_type,\n                }\n            ]\n        )\n\n        class CustomGraph0(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.m = m\n                self.add_optimizer(sgd0)\n\n            def build(self, x):\n                out = self.m(x)\n                out = out.sum()\n                out.backward()\n                return out\n\n        g = CustomGraph0()\n        x = flow.Tensor(4, 10)\n        flow.nn.init.uniform_(x, a=-1.0, b=1.0)\n        z = g._compile(x)\n        print(\"repr(g): \\n\", repr(g))\n        print(\"g.config.proto: \\n\", g.config.proto)\n        print(\"graph proto: \\n\", g._graph_proto)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_pipeline.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport sys\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.nn.graph import GraphModule\n\n\nrank = flow.env.get_rank()\n\n\nclass OFRecordDataLoader(flow.nn.Module):\n    def __init__(\n        self,\n        ofrecord_root: str = \"./ofrecord\",\n        mode: str = \"train\",  # \"val\"\n        dataset_size: int = 9469,\n        batch_size: int = 1,\n        placement=None,\n        sbp=None,\n    ):\n        super().__init__()\n        channel_last = False\n        output_layout = \"NHWC\" if channel_last else \"NCHW\"\n        self.train_record_reader = flow.nn.OFRecordReader(\n            ofrecord_root + \"/\" + mode,\n            batch_size=batch_size,\n            data_part_num=40,\n            part_name_suffix_length=5,\n            random_shuffle=False,\n            shuffle_after_epoch=False,\n            placement=placement,\n            sbp=sbp,\n            random_seed=0,\n        )\n        self.record_label_decoder = flow.nn.OFRecordRawDecoder(\n            \"class/label\", shape=(), dtype=flow.int32\n        )\n\n        color_space = \"RGB\"\n        height = 22\n        width = 22\n\n        self.record_image_decoder = flow.nn.OFRecordImageDecoder(\n            \"encoded\", color_space=color_space\n        )\n\n        self.resize = flow.nn.image.Resize(target_size=[height, width])\n\n        self.batch_size = batch_size\n        self.dataset_size = dataset_size\n\n    def __len__(self):\n        return self.dataset_size // self.batch_size\n\n    def forward(self):\n        train_record = self.train_record_reader()\n        label = self.record_label_decoder(train_record)\n        image_raw_buffer = self.record_image_decoder(train_record)\n        image = self.resize(image_raw_buffer)[0]\n        image = flow.flatten(image.to(flow.float32), start_dim=1)\n\n        return image, label\n\n\ndef _train_with_graph(iter_num=3):\n    B = [flow.sbp.broadcast]\n    P0 = flow.placement(\"cuda\", ranks=[0])\n    P1 = flow.placement(\"cuda\", ranks=[1])\n    P2 = flow.placement(\"cuda\", ranks=[2])\n    P3 = flow.placement(\"cuda\", ranks=[3])\n\n    train_data_loader = OFRecordDataLoader(\n        ofrecord_root=flow.unittest.dataset_dir(\"ImageNet/ofrecord\"),\n        mode=\"validation\",\n        dataset_size=400,\n        batch_size=4,\n        placement=P0,\n        sbp=B,\n    )\n\n    def _get_ppm_and_opt():\n        class StageModule(flow.nn.Module):\n            def __init__(self, *linear_args):\n                super().__init__()\n                self.linear = flow.nn.Linear(*linear_args)\n                flow.nn.init.constant_(self.linear.weight, 0.00023)\n\n            def forward(self, input):\n                out = self.linear(input)\n                return out\n\n        class PipelineModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n           
     # Initialize each module and move it to the right placement of its pipeline stage.\n                self.stage_0_m = StageModule(1452, 8, False).to_global(\n                    placement=P0, sbp=B\n                )\n                self.stage_1_m = StageModule(8, 8, False).to_global(placement=P1, sbp=B)\n                self.stage_2_m = StageModule(8, 8, False).to_global(placement=P2, sbp=B)\n                self.stage_3_m = StageModule(8, 1, False).to_global(placement=P3, sbp=B)\n\n            def forward(self, image):\n                out = self.stage_0_m(image)\n                # Move the tensor between pipeline stages.\n                out = out.to_global(placement=P1, sbp=B)\n                out = self.stage_1_m(out)\n                out = out.to_global(placement=P2, sbp=B)\n                out = self.stage_2_m(out)\n                out = out.to_global(placement=P3, sbp=B)\n                out = self.stage_3_m(out)\n                return out\n\n        pp_m = PipelineModule()\n        sgd = flow.optim.SGD(pp_m.parameters(), lr=0.0001)\n        return pp_m, sgd\n\n    pp_m, sgd = _get_ppm_and_opt()\n\n    class PipelineGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.train_data_loader = train_data_loader\n            self.pp_m = pp_m\n            # Set each module's stage id to hint the graph to prepare the right number of buffers for the pipeline.\n            self.pp_m.stage_0_m.to(GraphModule).set_stage(0)\n            self.pp_m.stage_1_m.to(GraphModule).set_stage(1)\n            self.pp_m.stage_2_m.to(GraphModule).set_stage(2)\n            self.pp_m.stage_3_m.to(GraphModule).set_stage(3)\n\n            self.pp_m.stage_0_m.to(GraphModule).activation_checkpointing = True\n            self.pp_m.stage_1_m.to(GraphModule).activation_checkpointing = True\n            self.pp_m.stage_2_m.to(GraphModule).activation_checkpointing = True\n            self.pp_m.stage_3_m.to(GraphModule).activation_checkpointing = True\n\n            self.mseloss = flow.nn.MSELoss(\"sum\")\n            self.add_optimizer(sgd)\n            # Let the graph do gradient accumulation; pipeline execution depends on gradient accumulation.\n            self.config.set_gradient_accumulation_steps(4)\n\n        def build(self):\n            image, label = self.train_data_loader()\n\n            # The dataloader's outputs are on host memory, so move them to device 0.\n            image = image.to_global(placement=P0, sbp=B)\n            pp_m.train()\n            out = self.pp_m(image)\n\n            # The dataloader's outputs are on host memory, so move them to device 3.\n            label = label.to_global(placement=P3, sbp=B)\n            loss = self.mseloss(out, label.to(dtype=flow.float32))\n            loss.backward()\n\n            # Returning image and label is just for reusing the data in the eager test.\n            image = image.to_global(placement=P3, sbp=B)\n            return loss, image, label\n\n    pp_g = PipelineGraph()\n\n    def one_iter(iter_idx):\n        loss, image, label = pp_g()\n        if rank == 3:\n            # loss on the other ranks is a 0-size tensor\n            loss = loss.to_local()\n            loss_np = loss.numpy()\n            print(\"loss numpy \\n\", loss_np)\n            image = image.to_local().numpy()\n            label = label.to_local().numpy()\n            return loss, image, label\n\n    check_list = []\n    data_list = []\n    for i in range(iter_num):\n        out = one_iter(i)\n        if rank == 3:\n            check_list.append(out[0])\n  
          data_list.append((out[1], out[2]))\n    return check_list, data_list\n\n\ndef _train_with_module(iter_num=3, data=None):\n    class DataModule(flow.nn.Module):\n        def __init__(self, data):\n            super().__init__()\n            self.data_list = []\n            self.idx = 0\n            for pair in data:\n                for i in range(4):\n                    s = i * 4\n                    e = s + 4\n                    micro_batch_image = pair[0][s:e]\n                    micro_batch_label = pair[1][s:e]\n                    self.data_list.append(\n                        (\n                            flow.Tensor(micro_batch_image).to(\"cuda:3\"),\n                            flow.Tensor(micro_batch_label).to(\"cuda:3\"),\n                        )\n                    )\n\n        def forward(self):\n            image = self.data_list[self.idx][0]\n            label = self.data_list[self.idx][1]\n            self.idx += 1\n            return image, label\n\n    class TrainModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.linear = flow.nn.Linear(1452, 8, False)\n            flow.nn.init.constant_(self.linear.weight, 0.00023)\n            self.linear.to(\"cuda:3\")\n            self.linear1 = flow.nn.Linear(8, 8, False)\n            flow.nn.init.constant_(self.linear1.weight, 0.00023)\n            self.linear1.to(\"cuda:3\")\n            self.linear2 = flow.nn.Linear(8, 8, False)\n            flow.nn.init.constant_(self.linear2.weight, 0.00023)\n            self.linear2.to(\"cuda:3\")\n            self.linear3 = flow.nn.Linear(8, 1, False)\n            flow.nn.init.constant_(self.linear3.weight, 0.00023)\n            self.linear3.to(\"cuda:3\")\n            self.mseloss = flow.nn.MSELoss(\"sum\")\n\n        def forward(self, image, label):\n            out = self.linear(image)\n            out = self.linear1(out)\n            out = self.linear2(out)\n            out = self.linear3(out)\n            loss = self.mseloss(out, label)\n            return loss\n\n    if rank == 3:\n        data_m = DataModule(data)\n        train_m = TrainModule()\n        sgd = flow.optim.SGD(train_m.parameters(), lr=0.0001)\n\n        def one_iter(iter_idx):\n            if rank == 3:\n                image, label = data_m()\n                loss = train_m(image, label)\n\n                loss_np = loss.numpy()\n                print(\"eager loss numpy \\n\", loss_np)\n\n                loss = loss * 0.25\n                loss.backward()\n                if iter_idx % 4 == 3:\n                    print(f\"iter index: {iter_idx}\")\n                    # eager gradient accumulation\n                    sgd.step()\n                    sgd.zero_grad()\n                return loss_np\n\n        check_list = []\n        for i in range(iter_num):\n            check_list.append(one_iter(i))\n        return check_list\n\n\ndef _test_graph_pipeline(test_case):\n    iter_num = 3\n    graph_check_list, data = _train_with_graph(iter_num)\n    module_check_list = _train_with_module(iter_num * 4, data)\n\n    if rank == 3:\n        for i in range(iter_num * 4):\n            # check that the losses are equal\n            test_case.assertTrue(\n                np.array_equal(module_check_list[i], graph_check_list[i // 4][i % 4])\n            )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n4d()\nclass TestGraphPipeline(oneflow.unittest.TestCase):\n    def test_graph_pipeline(test_case):\n        
_test_graph_pipeline(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_pipeline_delay.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport time\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.nn.graph import GraphModule\n\n\ndef _test_graph_pipeline_delay_output(test_case):\n    class StageLayerModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.linear1 = flow.nn.Linear(10, 8, False)\n            self.linear2 = flow.nn.Linear(8, 10)\n            flow.nn.init.constant_(self.linear1.weight, 0.023)\n            flow.nn.init.constant_(self.linear2.weight, 1.23)\n\n        def forward(self, x):\n            out0 = self.linear1(x)\n            out0 = out0 + 1.0\n            out0 = out0 * 2.0\n            out1 = self.linear2(out0)\n            return out1\n\n    P0 = flow.placement(\"cuda\", ranks=[0])\n    P1 = flow.placement(\"cuda\", ranks=[1])\n    B = flow.sbp.broadcast\n\n    class PipelineModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.layer_0 = StageLayerModule()\n            self.layer_1 = StageLayerModule()\n            self.layer_0.to_global(P0, B)\n            self.layer_1.to_global(P1, B)\n\n        def forward(self, x):\n            # stage 0\n            in0 = x.to_global(P0, B)\n            out0 = self.layer_0(in0)\n            # stage 1\n            in1 = out0.to_global(P1, B)\n            out1 = self.layer_1(in1)\n            return out1\n\n    pp_m = PipelineModule()\n    pp_m.train()\n    of_sgd = flow.optim.SGD(pp_m.parameters(), lr=0.001)\n\n    class PipelineGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.pp_m = pp_m\n            self.pp_m.layer_0.to(GraphModule).stage_id = 0\n            self.pp_m.layer_1.to(GraphModule).stage_id = 1\n            self.config.set_gradient_accumulation_steps(4)\n            self.add_optimizer(of_sgd)\n\n        def build(self, x, y):\n            pp_out = self.pp_m(x)\n            loss = pp_out.mean()\n            loss.backward()\n            y = x + y\n            free_out = y.to_global(P1, B)\n            return loss, free_out\n\n    pp_g = PipelineGraph()\n    rank = flow.env.get_rank()\n    for i in range(3):\n        x = flow.randn(16, 10)\n        y = flow.randn(16, 10)\n        x = x.to_global(P0, B)\n        y = y.to_global(P0, B)\n        if rank == 1:\n            time.sleep(2)\n        loss_pack_4, free_out = pp_g(x, y)\n        if rank == 1:\n            # NOTE(chengcheng): Before Oneflow-Inc/oneflow#6221 fix src/dst tick order with input/output,\n            #   this case use sleep in rank 1 will expose this BUG:\n            #   free_out is output only on rank 1, but NOT control in rank 1 src/dst tick, so if manual sleep\n            #   on rank 1, free out pull callback must exec before rank 1 src tick exec, so will meet BUG of\n            #   output_kernel buffer status empty.\n            #   After this PR fix, 
this test case ensure that src/dst tick and input/output cb exec order on\n            #   each rank is as expected.\n            time.sleep(2)\n            print(\n                \"rank: \",\n                rank,\n                \"packed loss with 4 micro-batch = \",\n                loss_pack_4.to_local(),\n            )\n            print(\n                \"rank: \",\n                rank,\n                \"packed image with 4 micro-batch = \",\n                free_out.to_local(),\n            )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestGraphPipelineDelayOutput(oneflow.unittest.TestCase):\n    def test_graph_pipeline_delay_output(test_case):\n        _test_graph_pipeline_delay_output(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_random_seed.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport numpy as np\nimport unittest\nimport inspect\nimport types\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\n\n\ndef _inspect_rand_op_and_args(rand_op, **kwargs):\n    if inspect.isclass(rand_op) and issubclass(rand_op, nn.Module):\n        init_method_signature = inspect.signature(rand_op.__init__)\n\n        module_init_args = dict()\n        for arg_name in list(init_method_signature.parameters.keys())[1:]:\n            if arg_name in kwargs:\n                module_init_args[arg_name] = kwargs.pop(arg_name)\n\n        module_instance = rand_op(**module_init_args)\n        return module_instance, kwargs\n\n    if isinstance(rand_op, types.BuiltinFunctionType):\n        return rand_op, kwargs\n\n    if inspect.isfunction(rand_op):\n        return rand_op, kwargs\n\n    raise ValueError(f\"invalid rand_op {rand_op}, type: {type(rand_op)}\")\n\n\n# y1 = rand_op1(x)\n# y2 = rand_op2(x)\n# rand_op1 and rand_op2 should have different seed in graph, lead to different result\ndef _test_rand_op_in_graph(test_case, rand_op, input=None, **kwargs):\n    rand_op1, kwargs1 = _inspect_rand_op_and_args(rand_op, **kwargs)\n    rand_op2, kwargs2 = _inspect_rand_op_and_args(rand_op, **kwargs)\n\n    class TestGraphWithoutInput(nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.rand_op1 = rand_op1\n            self.rand_op2 = rand_op2\n\n        def build(self):\n            y1 = self.rand_op1(**kwargs1)\n            y2 = self.rand_op2(**kwargs2)\n            return y1, y2\n\n    class TestGraph(nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.rand_op1 = rand_op1\n            self.rand_op2 = rand_op2\n\n        def build(self, x):\n            x1 = x\n            x2 = x.clone()\n            y1 = self.rand_op1(x1, **kwargs1)\n            y2 = self.rand_op2(x2, **kwargs2)\n            return y1, y2\n\n    if input is None:\n        graph = TestGraphWithoutInput()\n        rand_result1, rand_result2 = graph()\n    else:\n        graph = TestGraph()\n        rand_result1, rand_result2 = graph(input)\n\n    if isinstance(rand_result1, (list, tuple)):\n        rand_result1 = rand_result1[0]\n    if isinstance(rand_result2, (list, tuple)):\n        rand_result2 = rand_result2[0]\n\n    test_case.assertFalse(\n        np.allclose(rand_result1.numpy(), rand_result2.numpy()),\n        f\"\\ninput:\\n{input}\\nrand_result1:\\n{rand_result1}\\nrand_result2:\\n{rand_result2}\",\n    )\n\n\ndef _get_shape_and_device_from_args(pop_device=False, **kwargs):\n    if \"size\" in kwargs:\n        shape = kwargs[\"size\"]\n    elif \"shape\" in kwargs:\n        shape = kwargs[\"shape\"]\n    elif \"n\" in kwargs:\n        shape = (kwargs[\"n\"],)\n    else:\n        raise ValueError(f\"can't parse shape from kwargs {kwargs}\")\n\n    device = \"cpu\"\n    if \"device\" in kwargs:\n       
 device = kwargs[\"device\"]\n\n    return shape, device\n\n\n# Test FRB (Forward Recomputation Backpropagation)\n# y = rand_op(x) * w\n# dw = fake_rand_op(x) * dy\n# (y * w).backward() will result in dy == w\n# so dw == y demands rand_op(x) == fake_rand_op(x)\n# in an activation checkpointing graph,\n# the fake_rand_op in backward should produce the same result as the rand_op in forward\ndef _test_rand_op_in_FRB(test_case, rand_op, input=None, **kwargs):\n    rand_op, kwargs = _inspect_rand_op_and_args(rand_op, **kwargs)\n\n    class CheckpointActivationModule(nn.Module):\n        def __init__(self, weight, is_src_rand=False):\n            super().__init__()\n            self.rand_op = rand_op\n            self.is_src_rand = is_src_rand\n            self.weight = weight\n            self.param = nn.Parameter(flow.zeros(*weight.shape))\n\n        def forward(self, x):\n            weight = self.param - self.weight\n            if self.is_src_rand:\n                y = self.rand_op(**kwargs) + x\n            else:\n                y = self.rand_op(x, **kwargs)\n            if isinstance(y, (tuple, list)):\n                y = y[0]\n            return y * weight\n\n    class TestGraph(nn.Graph):\n        def __init__(self, model):\n            super().__init__()\n            self.model = model\n            self.model.to(nn.graph.GraphModule).activation_checkpointing = True\n            self.add_optimizer(flow.optim.SGD(self.model.parameters(), lr=1.0))\n\n        def build(self, x):\n            y = self.model(x)\n            (y * self.model.weight).sum().backward()\n            return y\n\n    if input is None:\n        shape, device = _get_shape_and_device_from_args(**kwargs)\n        x = flow.randn(*shape).to(device)\n        weight = flow.randn(*shape).to(device)\n        model = CheckpointActivationModule(weight, True).to(device)\n        graph = TestGraph(model)\n    else:\n        x = input\n        weight = flow.randn(*input.shape).to(input.device)\n        model = CheckpointActivationModule(weight, False).to(input.device)\n        graph = TestGraph(model)\n\n    y = graph(x)\n\n    test_case.assertTrue(\n        np.allclose(y.numpy(), model.param.numpy()),\n        f\"\\nx=\\n{x.numpy()}\\nweight=\\n{weight.numpy()}\\ny=\\n{y.numpy()}\\ndweight=\\n{model.param.numpy()}\",\n    )\n\n\ndef _test_split_rand_op_in_graph(test_case, rand_op, input=None, **kwargs):\n    rand_op, kwargs = _inspect_rand_op_and_args(rand_op, **kwargs)\n\n    class TestGraph(nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.rand_op = rand_op\n\n        def build(self, x):\n            x = x.to_global(sbp=flow.sbp.split(0))\n            y = self.rand_op(x, **kwargs)\n            return y\n\n    class TestGraphWithoutInput(nn.Graph):\n        def __init__(self, placement):\n            super().__init__()\n            self.rand_op = rand_op\n            self.placement = placement\n\n        def build(self):\n            y = self.rand_op(placement=self.placement, sbp=flow.sbp.split(0), **kwargs)\n            return y\n\n    ranks = np.array(range(flow.env.get_world_size()))\n    if input is None:\n        device = kwargs.pop(\"device\", None)\n        placement = flow.placement(device, ranks)\n        graph = TestGraphWithoutInput(placement)\n        y_global = graph()\n    else:\n        x = flow.concat([input, input], dim=0)\n        placement = flow.placement(input.device.type, ranks)\n        # local to broadcast global\n        x_global = x.to_global(placement=placement, 
sbp=flow.sbp.broadcast(), copy=True)\n        graph = TestGraph()\n        y_global = graph(x_global)\n\n    if isinstance(y_global, (list, tuple)):\n        y_global = y_global[0]\n\n    y_global = y_global.to_global(placement=placement, sbp=flow.sbp.broadcast())\n    half = y_global.shape[0] // 2\n    first_half = y_global[0:half]\n    second_half = y_global[half:]\n    test_case.assertFalse(np.allclose(first_half.numpy(), second_half.numpy()))\n\n\ndef _test_broadcast_rand_op_in_graph(test_case, rand_op, input=None, **kwargs):\n    rand_op, kwargs = _inspect_rand_op_and_args(rand_op, **kwargs)\n\n    class TestGraph(nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.rand_op = rand_op\n\n        def build(self, x):\n            y = self.rand_op(x, **kwargs)\n            return y\n\n    class TestGraphWithoutInput(nn.Graph):\n        def __init__(self, placement):\n            super().__init__()\n            self.rand_op = rand_op\n            self.placement = placement\n\n        def build(self):\n            y = self.rand_op(placement=placement, sbp=flow.sbp.broadcast(), **kwargs)\n            return y\n\n    ranks = np.array(range(flow.env.get_world_size()))\n    if input is None:\n        device = kwargs.pop(\"device\", None)\n        placement = flow.placement(device, ranks)\n        graph = TestGraphWithoutInput(placement)\n        y_global = graph()\n    else:\n        placement = flow.placement(input.device.type, ranks)\n        # local to broadcast global\n        x = input\n        x_global = x.to_global(placement=placement, sbp=flow.sbp.broadcast(), copy=True)\n        graph = TestGraph()\n        y_global = graph(x_global)\n\n    if isinstance(y_global, (list, tuple)):\n        y_local = y_global[0].to_local()\n    else:\n        y_local = y_global.to_local()\n\n    y_all_ranks = y_local.to_global(placement=placement, sbp=flow.sbp.split(0))\n    y_allgather = y_all_ranks.to_global(sbp=flow.sbp.broadcast())\n    half = y_allgather.shape[0] // 2\n    first_half = y_allgather[0:half]\n    second_half = y_allgather[half:]\n    test_case.assertTrue(np.allclose(first_half.numpy(), second_half.numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRandOpInGraph(oneflow.unittest.TestCase):\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_usual_rand_op(self):\n        for device in (\"cpu\", \"cuda\"):\n            x = flow.randn(4, 16, device=device)\n            _test_rand_op_in_graph(self, nn.Dropout, x, p=0.5)\n            _test_rand_op_in_graph(self, flow._C.rrelu, x, training=True)\n            _test_rand_op_in_graph(self, nn.init.uniform_, x)\n            _test_rand_op_in_graph(self, flow._C.exponential_, x)\n\n            x1 = flow.rand(4, 16, device=device)\n            _test_rand_op_in_graph(\n                self, flow.multinomial, x1, num_samples=16, replacement=True\n            )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_source_rand_op(self):\n        shape = (4, 16)\n        for device in (\"cpu\", \"cuda\"):\n            _test_rand_op_in_graph(self, flow.rand, size=shape, device=device)\n            _test_rand_op_in_graph(\n                self, flow.normal, mean=0.0, std=1.0, size=shape, device=device\n            )\n            _test_rand_op_in_graph(\n                self, flow.randint, low=0, high=10, size=shape, device=device\n            )\n            _test_rand_op_in_graph(self, flow.randperm, n=32, 
device=device)\n\n    def test_bernoulli(self):\n        x1 = flow.randn(4, 16)\n        _test_rand_op_in_graph(self, flow.bernoulli, x1, p=0.5)\n        x2 = flow.rand(4, 16)\n        _test_rand_op_in_graph(self, flow.bernoulli, x2)\n\n    @unittest.skip(\"skip for now, because it failed 4 times in the past week\")\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_random_mask_like(self):\n        x = flow.randn(4, 16, 128, 128).to(\"cuda\")\n        _test_rand_op_in_graph(\n            self,\n            flow._C.fused_scale_tril_softmax_mask_scale,\n            x,\n            p=0.1,\n            diagonal=2,\n            tril_scale_value=-1000,\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRandOpInFRB(oneflow.unittest.TestCase):\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_usual_rand_op(self):\n        for device in (\"cpu\", \"cuda\"):\n            x = flow.randn(4, 16, device=device)\n            _test_rand_op_in_FRB(self, nn.Dropout, x, p=0.5)\n            _test_rand_op_in_FRB(self, flow._C.rrelu, x, training=True)\n            _test_rand_op_in_FRB(self, nn.init.uniform_, x)\n            _test_rand_op_in_FRB(self, flow._C.exponential_, x)\n\n            x1 = flow.rand(4, 16, device=device)\n            _test_rand_op_in_FRB(\n                self, flow.multinomial, x1, num_samples=16, replacement=True\n            )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_source_rand_op(self):\n        shape = (4, 16)\n        for device in (\"cpu\", \"cuda\"):\n            _test_rand_op_in_FRB(self, flow.rand, size=shape, device=device)\n            _test_rand_op_in_FRB(\n                self, flow.normal, mean=0.0, std=1.0, size=shape, device=device\n            )\n            _test_rand_op_in_FRB(\n                self, flow.randint, low=0, high=10, size=shape, device=device\n            )\n            _test_rand_op_in_FRB(self, flow.randperm, n=32, device=device)\n\n    def test_bernoulli(self):\n        x1 = flow.randn(4, 16)\n        _test_rand_op_in_FRB(self, flow.bernoulli, x1, p=0.5)\n        x2 = flow.rand(4, 16)\n        _test_rand_op_in_FRB(self, flow.bernoulli, x2)\n\n    @unittest.skip(\"skip for now, because it failed 4 times in the past week\")\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_random_mask_like(self):\n        x = flow.randn(4, 16, 128, 128).to(\"cuda\")\n        _test_rand_op_in_FRB(\n            self,\n            flow._C.fused_scale_tril_softmax_mask_scale,\n            x,\n            p=0.1,\n            diagonal=0,\n            tril_scale_value=-1000,\n        )\n\n\n@flow.unittest.skip_unless_1n2d()\nclass TestGlobalRandInGraph(oneflow.unittest.TestCase):\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_usual_rand_op_with_split(self):\n        x = flow.randn(2, 4, device=\"cuda\")\n        _test_split_rand_op_in_graph(self, nn.Dropout, x, p=0.5)\n        _test_split_rand_op_in_graph(self, flow._C.rrelu, x, training=True)\n        _test_split_rand_op_in_graph(self, nn.init.uniform_, x)\n        _test_split_rand_op_in_graph(self, flow._C.exponential_, x)\n\n        x1 = flow.rand(2, 8, device=\"cuda\")\n        _test_split_rand_op_in_graph(\n            self, flow.multinomial, x1, num_samples=8, replacement=True\n        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu 
cases\")\n    def test_usual_rand_op_with_broadcast(self):\n        x = flow.randn(2, 4, device=\"cuda\")\n        _test_broadcast_rand_op_in_graph(self, nn.Dropout, x, p=0.5)\n        _test_broadcast_rand_op_in_graph(self, flow._C.rrelu, x, training=True)\n        _test_broadcast_rand_op_in_graph(self, nn.init.uniform_, x)\n        _test_broadcast_rand_op_in_graph(self, flow._C.exponential_, x)\n\n        x1 = flow.rand(2, 8, device=\"cuda\")\n        _test_broadcast_rand_op_in_graph(\n            self, flow.multinomial, x1, num_samples=8, replacement=True\n        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_source_rand_op_with_split(self):\n        shape = (4, 4)\n        _test_split_rand_op_in_graph(self, flow.rand, size=shape, device=\"cuda\")\n        _test_split_rand_op_in_graph(\n            self, flow.normal, mean=0.0, std=1.0, size=shape, device=\"cuda\"\n        )\n        _test_split_rand_op_in_graph(\n            self, flow.randint, low=0, high=10, size=shape, device=\"cuda\"\n        )\n        _test_split_rand_op_in_graph(self, flow.randperm, n=32, device=\"cuda\")\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_source_rand_op_with_broadcast(self):\n        shape = (4, 4)\n        _test_broadcast_rand_op_in_graph(self, flow.rand, size=shape, device=\"cuda\")\n        _test_broadcast_rand_op_in_graph(\n            self, flow.normal, mean=0.0, std=1.0, size=shape, device=\"cuda\"\n        )\n        _test_broadcast_rand_op_in_graph(\n            self, flow.randint, low=0, high=10, size=shape, device=\"cuda\"\n        )\n        _test_broadcast_rand_op_in_graph(self, flow.randperm, n=32, device=\"cuda\")\n\n    def test_bernoulli_with_split(self):\n        x1 = flow.randn(2, 8)\n        _test_split_rand_op_in_graph(self, flow.bernoulli, x1, p=0.5)\n        x2 = flow.rand(2, 8)\n        _test_split_rand_op_in_graph(self, flow.bernoulli, x2)\n\n    def test_bernoulli_with_broadcast(self):\n        x1 = flow.randn(2, 8)\n        _test_broadcast_rand_op_in_graph(self, flow.bernoulli, x1, p=0.5)\n        x2 = flow.rand(2, 8)\n        _test_broadcast_rand_op_in_graph(self, flow.bernoulli, x2)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_random_mask_like_with_split(self):\n        x = flow.randn(2, 16, 64).to(\"cuda\")\n        _test_split_rand_op_in_graph(\n            self,\n            flow._C.fused_scale_tril_softmax_mask_scale,\n            x,\n            p=0.1,\n            diagonal=0,\n            tril_scale_value=-1000,\n        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_random_mask_like_with_broadcast(self):\n        x = flow.randn(2, 16, 64).to(\"cuda\")\n        _test_broadcast_rand_op_in_graph(\n            self,\n            flow._C.fused_scale_tril_softmax_mask_scale,\n            x,\n            p=0.2,\n            diagonal=1,\n            tril_scale_value=-100,\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_relu.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestReluGraph(oneflow.unittest.TestCase):\n    def test_relu_graph(test_case):\n        data = np.array([2.0, 1.0, 0.0, -1.0, -2.0])\n        x = flow.tensor(data, dtype=flow.float32)\n\n        MyRelu = flow.nn.ReLU()\n        y_eager = MyRelu(x)\n        # print(\"eager out :\", y_eager)\n\n        class ReluGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.cc_relu = MyRelu\n\n            def build(self, x):\n                return self.cc_relu(x)\n\n        relu_g = ReluGraph()\n        y_lazy = relu_g(x)\n        # print(f\"type of lazy y: {type(y_lazy)}\")\n        # print(f\"lazy y shape: {y_lazy.shape}, data: {y_lazy}\")\n        test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy()))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_reshape_acc.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport os\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.nn.graph import GraphModule\n\n\ndef _test_graph_reshape_acc(test_case):\n    class StageLayerModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.linear1 = flow.nn.Linear(10, 8, False)\n            self.linear2 = flow.nn.Linear(8, 10, False)\n            flow.nn.init.constant_(self.linear1.weight, 0.023)\n            flow.nn.init.constant_(self.linear2.weight, 1.23)\n\n        def forward(self, x):\n            out0 = self.linear1(x)\n            out0 = flow.reshape(out0, (-1, 2, 4))\n            out0 = out0 + 1.0\n            out0 = out0 * 2.0\n            out0 = flow.reshape(out0, (-1, 8))\n            out1 = self.linear2(out0)\n            return out1\n\n    P0 = flow.placement(\"cuda\", ranks=[0])\n    P1 = flow.placement(\"cuda\", ranks=[1])\n    B = flow.sbp.broadcast\n\n    class PipelineModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.layer_0 = StageLayerModule()\n            self.layer_1 = StageLayerModule()\n            self.layer_0.to_global(P0, B)\n            self.layer_1.to_global(P1, B)\n\n        def forward(self, x):\n            # stage 0\n            x = flow.flatten(x, start_dim=1)\n            in0 = x.to_global(P0, B)\n            out0 = self.layer_0(in0)\n            # stage 1\n            in1 = out0.to_global(P1, B)\n            out1 = self.layer_1(in1)\n            return out1\n\n    pp_m = PipelineModule()\n    pp_m.train()\n    sgd = flow.optim.SGD(pp_m.parameters(), lr=0.001)\n\n    class PipelineGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.pp_m = pp_m\n            self.pp_m.layer_0.to(GraphModule).set_stage(0)\n            self.pp_m.layer_1.to(GraphModule).set_stage(1)\n            self.loss_fn = flow.nn.CrossEntropyLoss()\n            self.config.set_gradient_accumulation_steps(2)\n            self.add_optimizer(sgd)\n\n        def build(self, x, y):\n            out = self.pp_m(x)\n            y = y.to_global(P1, B)\n            loss = self.loss_fn(out, y)\n            loss.backward()\n            return loss\n\n    pp_g = PipelineGraph()\n\n    for i in range(20):\n        x = flow.randn(6, 2, 5)\n        y = flow.randint(0, 10, (6,))\n        x = x.to_global(P0, B)\n        y = y.to_global(P1, B)\n        out = pp_g(x, y)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestGraphReshapeAcc(oneflow.unittest.TestCase):\n    def test_graph_reshape_acc(test_case):\n        _test_graph_reshape_acc(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_reuse_var.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n2d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGraphResueVar(flow.unittest.TestCase):\n    def test_graph_reuse_var(test_case):\n        rank = flow.env.get_rank()\n        P = flow.placement(\"cuda\", ranks=[0, 1])\n        B = flow.sbp.broadcast\n\n        class ReuseVarModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear1 = flow.nn.Linear(2, 2)\n                self.linear2 = flow.nn.Linear(2, 2)\n                # Reuse parameter\n                self.linear2.weight = self.linear1.weight\n\n            def forward(self, x):\n                # Allow user to call parameter outside it's module.\n                self.linear1.weight\n                x = self.linear1(x)\n                x = self.linear2(x)\n                return x\n\n        reuse_var_m = ReuseVarModule()\n        reuse_var_m.to_global(placement=P, sbp=B)\n        of_sgd = flow.optim.SGD(reuse_var_m.parameters(), lr=0.001, momentum=0.9)\n\n        class ReuseVarGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.reuse_var_m = reuse_var_m\n                self.add_optimizer(of_sgd)\n\n            def build(self, x):\n                x = self.reuse_var_m(x)\n                loss = x.sum()\n                loss.backward()\n                return loss\n\n        x = flow.randint(0, 1, (2, 2), placement=P, sbp=B, dtype=flow.float32)\n        reuse_var_g = ReuseVarGraph()\n        loss = reuse_var_g(x)\n\n        # check lazy tensor builder\n        block = reuse_var_g.reuse_var_m\n        test_case.assertEqual(\n            block.linear1.weight.lazy_origin_builder().name,\n            \"reuse_var_m.linear1.weight\",\n        )\n        test_case.assertEqual(\n            block.linear1.weight.lazy_origin_builder().name,\n            block.linear2.weight.lazy_origin_builder().name,\n        )\n\n        # check optimizer's variable list\n        var_list = [\n            \"reuse_var_m.linear1.weight\",\n            \"reuse_var_m.linear1.bias\",\n            \"reuse_var_m.linear2.bias\",\n        ]\n        var_list_in_conf = reuse_var_g._graph_proto.job_conf.train_conf.optimizer_conf[\n            0\n        ].variable_op_names\n        test_case.assertEqual(len(var_list_in_conf), 3)\n        for idx in range(3):\n            test_case.assertEqual(var_list[idx], var_list_in_conf[idx])\n            if rank == 0:\n                print(var_list_in_conf[idx])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_save_load.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport os\nimport numpy as np\nimport tempfile\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_linear_graph_save_load(test_case, device):\n    def train_with_graph(call_cnt=0, state_dict_file=None, last_state_dict=None):\n        linear = flow.nn.Linear(3, 8)\n        linear = linear.to(device)\n        flow.nn.init.constant_(linear.weight, 2.068758)\n        flow.nn.init.constant_(linear.bias, 0.23)\n        of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9)\n\n        x = flow.tensor(\n            [\n                [-0.94630778, -0.83378579, -0.87060891],\n                [2.0289922, -0.28708987, -2.18369248],\n                [0.35217619, -0.67095644, -1.58943879],\n                [0.08086036, -1.81075924, 1.20752494],\n                [0.8901075, -0.49976737, -1.07153746],\n                [-0.44872912, -1.07275683, 0.06256855],\n                [-0.22556897, 0.74798368, 0.90416439],\n                [0.48339456, -2.32742195, -0.59321527],\n            ],\n            dtype=flow.float32,\n            device=device,\n            requires_grad=False,\n        )\n\n        class LinearTrainGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linear = linear\n                self.add_optimizer(of_sgd)\n\n            def build(self, x):\n                out = self.linear(x)\n                out = out.sum()\n                out.backward()\n                return out\n\n        linear_t_g = LinearTrainGraph()\n        if call_cnt == 1:\n            state_dict = flow.load(state_dict_file)\n            linear_t_g.load_state_dict(state_dict)\n            # Check state in module has been loaded.\n            test_case.assertTrue(\n                np.array_equal(state_dict[\"linear\"][\"weight\"].numpy(), linear.weight)\n            )\n            test_case.assertTrue(\n                np.array_equal(state_dict[\"linear\"][\"bias\"].numpy(), linear.bias)\n            )\n        # Get state dict before compile is allowed.\n        init_state_dict = linear_t_g.state_dict()\n\n        of_graph_out = linear_t_g(x)\n        iter0_state_dict = linear_t_g.state_dict()\n        if call_cnt == 1:\n            # Check additional variable state initialized in job has been loaded.\n            cur_train_step = iter0_state_dict[\"System-Train-TrainStep\"].numpy()[0]\n            test_case.assertEqual(3, cur_train_step)\n            test_case.assertTrue(\n                cur_train_step == last_state_dict[\"System-Train-TrainStep\"].numpy()[0]\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"linear\"][\"weight\"].numpy(),\n                    last_state_dict[\"linear\"][\"weight\"].numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                 
   iter0_state_dict[\"linear\"][\"bias\"].numpy(),\n                    last_state_dict[\"linear\"][\"bias\"].numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"linear.weight-momentum\"].numpy(),\n                    last_state_dict[\"linear.weight-momentum\"].numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"linear.bias-momentum\"].numpy(),\n                    last_state_dict[\"linear.bias-momentum\"].numpy(),\n                )\n            )\n\n        of_graph_out = linear_t_g(x)\n        of_graph_out.numpy()\n        iter1_state_dict = linear_t_g.state_dict()\n        if call_cnt == 0:\n            flow.save(iter1_state_dict, state_dict_file)\n\n        if call_cnt == 0:\n            of_graph_out = linear_t_g(x)\n            iter2_state_dict = linear_t_g.state_dict()\n            of_graph_out.numpy()\n            return iter2_state_dict\n\n    with tempfile.NamedTemporaryFile(prefix=\"graph_save_load_local\") as f:\n        iter2_state_dict = train_with_graph(0, f.name)\n        train_with_graph(1, f.name, iter2_state_dict)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestLinearGraphSaveLoad(oneflow.unittest.TestCase):\n    def test_linear_graph_save_load_gpu(test_case):\n        _test_linear_graph_save_load(test_case, flow.device(\"cuda\"))\n\n    def _test_linear_graph_save_load_cpu(test_case):\n        _test_linear_graph_save_load(test_case, flow.device(\"cpu\"))\n\n\ndef _test_linear_graph_save_load_global(test_case, device):\n    P = flow.placement(\"cuda\", ranks=[0, 1])\n    B = flow.sbp.broadcast\n    S = flow.sbp.split(0)\n\n    def train_with_graph(call_cnt=0, state_dict_file=None, last_state_dict=None):\n        linear = flow.nn.Linear(3, 8)\n        linear = linear.to(device)\n        flow.nn.init.constant_(linear.weight, 2.068758)\n        flow.nn.init.constant_(linear.bias, 0.23)\n        linear.to_global(placement=P, sbp=B)\n        of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9)\n\n        x = flow.tensor(\n            [\n                [-0.94630778, -0.83378579, -0.87060891],\n                [2.0289922, -0.28708987, -2.18369248],\n                [0.35217619, -0.67095644, -1.58943879],\n                [0.08086036, -1.81075924, 1.20752494],\n                [0.8901075, -0.49976737, -1.07153746],\n                [-0.44872912, -1.07275683, 0.06256855],\n                [-0.22556897, 0.74798368, 0.90416439],\n                [0.48339456, -2.32742195, -0.59321527],\n            ],\n            dtype=flow.float32,\n            device=device,\n            requires_grad=False,\n        )\n        x = x.to_global(placement=P, sbp=S)\n\n        class LinearTrainGraphGlobal(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linear = linear\n                self.add_optimizer(of_sgd)\n\n            def build(self, x):\n                out = self.linear(x)\n                out = out.sum()\n                out.backward()\n                return out\n\n        linear_t_g = LinearTrainGraphGlobal()\n        if call_cnt == 1:\n            state_dict = flow.load(state_dict_file, global_src_rank=0)\n            linear_t_g.load_state_dict(state_dict)\n            # Check state in module has been loaded.\n            # Tensors in 
state dict are saved to rank 0, so they need to be broadcast to ranks 0 and 1 before checking.\n            test_case.assertTrue(\n                np.array_equal(\n                    state_dict[\"linear\"][\"weight\"]\n                    .to_global(placement=P, sbp=B)\n                    .to_local()\n                    .numpy(),\n                    linear.weight.to_local().numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    state_dict[\"linear\"][\"bias\"]\n                    .to_global(placement=P, sbp=B)\n                    .to_local()\n                    .numpy(),\n                    linear.bias.to_local().numpy(),\n                )\n            )\n        # Getting the state dict before compilation is allowed.\n        init_state_dict = linear_t_g.state_dict()\n\n        of_graph_out = linear_t_g(x)\n        iter0_state_dict = linear_t_g.state_dict()\n        if call_cnt == 1:\n            # Check that additional variable state initialized in the job has been loaded.\n            # TrainStep's placement is only on rank 0, so it needs to be broadcast to ranks 0 and 1 before checking.\n            cur_train_step = (\n                iter0_state_dict[\"System-Train-TrainStep\"]\n                .to_global(placement=P, sbp=B)\n                .to_local()\n                .numpy()[0]\n            )\n            test_case.assertEqual(3, cur_train_step)\n            test_case.assertTrue(\n                cur_train_step\n                == last_state_dict[\"System-Train-TrainStep\"]\n                .to_global(placement=P, sbp=B)\n                .to_local()\n                .numpy()[0]\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"linear\"][\"weight\"].to_local().numpy(),\n                    last_state_dict[\"linear\"][\"weight\"].to_local().numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"linear\"][\"bias\"].to_local().numpy(),\n                    last_state_dict[\"linear\"][\"bias\"].to_local().numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"linear.weight-momentum\"].to_local().numpy(),\n                    last_state_dict[\"linear.weight-momentum\"].to_local().numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"linear.bias-momentum\"].to_local().numpy(),\n                    last_state_dict[\"linear.bias-momentum\"].to_local().numpy(),\n                )\n            )\n\n        of_graph_out = linear_t_g(x)\n        of_graph_out.numpy()\n        iter1_state_dict = linear_t_g.state_dict()\n        if call_cnt == 0:\n            flow.save(iter1_state_dict, state_dict_file, global_dst_rank=0)\n\n        if call_cnt == 0:\n            of_graph_out = linear_t_g(x)\n            of_graph_out.numpy()\n            iter2_state_dict = linear_t_g.state_dict()\n            return iter2_state_dict\n\n    with tempfile.NamedTemporaryFile(prefix=\"graph_save_load_global\") as f:\n        iter2_state_dict = train_with_graph(0, f.name)\n        train_with_graph(1, f.name, iter2_state_dict)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass 
TestLinearGraphSaveLoadGlobal(oneflow.unittest.TestCase):\n    def test_linear_graph_save_load_gpu(test_case):\n        _test_linear_graph_save_load_global(test_case, flow.device(\"cuda\"))\n\n    def _test_linear_graph_save_load_cpu(test_case):\n        _test_linear_graph_save_load_global(test_case, flow.device(\"cpu\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_save_load_global_b_s.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\nimport tempfile\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.nn.graph import GraphModule\n\n\ndef _test_linear_graph_save_load_global_broadcast(\n    test_case, model_tensor_placement, model_file_placement\n):\n    \"\"\"Data parallelism on 2 ranks.\n    \"\"\"\n    B = flow.sbp.broadcast\n    S0 = flow.sbp.split(0)\n\n    def train_with_graph(call_cnt=0, state_dict_file=None, last_state_dict=None):\n        linear = flow.nn.Linear(3, 8)\n        linear = linear.to(flow.device(model_tensor_placement.type))\n        flow.nn.init.constant_(linear.weight, 2.068758)\n        flow.nn.init.constant_(linear.bias, 0.23)\n        linear.to_global(placement=model_tensor_placement, sbp=B)\n        of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9)\n\n        x = flow.tensor(\n            [\n                [-0.94630778, -0.83378579, -0.87060891],\n                [2.0289922, -0.28708987, -2.18369248],\n                [0.35217619, -0.67095644, -1.58943879],\n                [0.08086036, -1.81075924, 1.20752494],\n                [0.8901075, -0.49976737, -1.07153746],\n                [-0.44872912, -1.07275683, 0.06256855],\n                [-0.22556897, 0.74798368, 0.90416439],\n                [0.48339456, -2.32742195, -0.59321527],\n            ],\n            dtype=flow.float32,\n            device=flow.device(model_tensor_placement.type),\n            requires_grad=False,\n        )\n        x = x.to_global(placement=model_tensor_placement, sbp=S0)\n\n        class LinearTrainGraphGlobal(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linear = linear\n                self.add_optimizer(of_sgd)\n\n            def build(self, x):\n                out = self.linear(x)\n                out = out.sum()\n                out.backward()\n                return out\n\n        linear_t_g = LinearTrainGraphGlobal()\n        cur_rank = flow.env.get_rank()\n        if call_cnt == 1:\n            if cur_rank in model_file_placement.ranks:\n                local_state_dict = flow.load(state_dict_file)\n            else:\n                local_state_dict = None\n\n            global_state_dict = flow.utils.global_view.to_global(\n                local_state_dict, placement=model_file_placement, sbp=B\n            )\n            linear_t_g.load_state_dict(global_state_dict)\n\n            if cur_rank == 0:  # Ignore None on rank 1\n                # Check state in module has been loaded.\n                test_case.assertTrue(\n                    np.array_equal(\n                        global_state_dict[\"linear\"][\"weight\"].to_local().numpy(),\n                        linear.weight.to_local().numpy(),\n                    )\n                )\n                test_case.assertTrue(\n                    np.array_equal(\n                        
global_state_dict[\"linear\"][\"bias\"].to_local().numpy(),\n                        linear.bias.to_local().numpy(),\n                    )\n                )\n        # Get state dict before compile is allowed.\n        init_state_dict = linear_t_g.state_dict()\n\n        of_graph_out = linear_t_g(x)\n        iter0_state_dict = linear_t_g.state_dict()\n\n        # Load the model and check\n        if call_cnt == 1:\n            # Check additional variable state initialized in job has been loaded.\n            # TrainStep's placement is only on rank 0, so it needs to be broadcast to all ranks before check.\n            cur_train_step = (\n                iter0_state_dict[\"System-Train-TrainStep\"]\n                .to_global(placement=model_tensor_placement, sbp=B)\n                .to_local()\n                .numpy()[0]\n            )\n            test_case.assertEqual(3, cur_train_step)\n            test_case.assertTrue(\n                cur_train_step\n                == last_state_dict[\"System-Train-TrainStep\"]\n                .to_global(placement=model_tensor_placement, sbp=B)\n                .to_local()\n                .numpy()[0]\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"linear\"][\"weight\"].to_local().numpy(),\n                    last_state_dict[\"linear\"][\"weight\"].to_local().numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"linear\"][\"bias\"].to_local().numpy(),\n                    last_state_dict[\"linear\"][\"bias\"].to_local().numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"linear.weight-momentum\"].to_local().numpy(),\n                    last_state_dict[\"linear.weight-momentum\"].to_local().numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"linear.bias-momentum\"].to_local().numpy(),\n                    last_state_dict[\"linear.bias-momentum\"].to_local().numpy(),\n                )\n            )\n\n        of_graph_out = linear_t_g(x)\n        of_graph_out.numpy()\n        iter1_state_dict = linear_t_g.state_dict()\n\n        # Save the model\n        if call_cnt == 0:\n            # Transfer the state dict to model_file_placement\n            model_file_state_dict = flow.utils.global_view.to_global(\n                iter1_state_dict, placement=model_file_placement, sbp=B\n            )\n\n            # Get the local component and save it on model_file_placement's rank(s)\n            if cur_rank in model_file_placement.ranks:\n                iter1_local_dict = flow.utils.global_view.to_local(\n                    model_file_state_dict\n                )\n                flow.save(iter1_local_dict, state_dict_file)\n\n            of_graph_out = linear_t_g(x)\n            of_graph_out.numpy()\n            iter2_state_dict = linear_t_g.state_dict()\n            return iter2_state_dict\n\n    rank_id = flow.env.get_rank()\n    with tempfile.NamedTemporaryFile(\n        prefix=\"graph_save_load_global_\" + str(rank_id)\n    ) as f:\n        iter2_state_dict = train_with_graph(0, f.name)\n        train_with_graph(1, f.name, iter2_state_dict)\n\n\ndef _test_graph_save_load_global_split_2(\n    test_case, model_tensor_placement, model_file_placement\n):\n    \"\"\"Pipeline parallelism 
on 2 ranks.\n    \"\"\"\n    P0 = flow.placement(model_tensor_placement.type, ranks=[0])\n    P1 = flow.placement(model_tensor_placement.type, ranks=[1])\n    BROADCAST = flow.sbp.broadcast\n\n    def get_sbp(state_dict, tensor):\n        if tensor is state_dict[\"System-Train-TrainStep\"]:\n            return BROADCAST\n        if tensor is state_dict[\"module_pipeline\"][\"m_stage1.linear.weight\"]:\n            return flow.sbp.split(1)\n        if tensor is state_dict[\"module_pipeline\"][\"m_stage1.linear.bias\"]:\n            return BROADCAST\n        return flow.sbp.split(0)\n\n    class Stage0Module(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.linear = flow.nn.Linear(16, 8)\n            self.relu = flow.nn.ReLU()\n\n        def forward(self, x):\n            return self.relu(self.linear(x))\n\n    class Stage1Module(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.linear = flow.nn.Linear(8, 1)\n\n        def forward(self, x):\n            return self.linear(x)\n\n    class PipelineModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.m_stage0 = Stage0Module()\n            self.m_stage1 = Stage1Module()\n\n            self.m_stage0.to_global(placement=P0, sbp=BROADCAST)\n            self.m_stage1.to_global(placement=P1, sbp=BROADCAST)\n\n        def forward(self, x):\n            out_stage0 = self.m_stage0(x)\n            in_stage1 = out_stage0.to_global(placement=P1, sbp=BROADCAST)\n            out_stage1 = self.m_stage1(in_stage1)\n            return out_stage1\n\n    class PipelineGraph(flow.nn.Graph):\n        def __init__(self, module_pipeline):\n            super().__init__()\n            self.module_pipeline = module_pipeline\n            self.module_pipeline.m_stage0.to(GraphModule).set_stage(0, P0)\n            self.module_pipeline.m_stage1.to(GraphModule).set_stage(1, P1)\n            self.config.set_gradient_accumulation_steps(2)\n            self.add_optimizer(\n                flow.optim.SGD(self.module_pipeline.parameters(), lr=0.001)\n            )\n\n        def build(self, x):\n            out = self.module_pipeline(x)\n            out = out.sum()\n            out.backward()\n            return out\n\n    def train_with_graph(call_cnt=0, state_dict_file=None, last_state_dict=None):\n        # A fixed input with shape [2, 16]\n        x = flow.tensor(\n            [\n                [\n                    0.4286,\n                    0.7402,\n                    0.4161,\n                    0.6103,\n                    0.7394,\n                    1.1330,\n                    -0.2311,\n                    -0.1013,\n                    0.8537,\n                    0.9757,\n                    -0.9842,\n                    0.3839,\n                    -0.5551,\n                    -0.8832,\n                    0.7820,\n                    0.7421,\n                ],\n                [\n                    -0.1581,\n                    -1.0319,\n                    1.8430,\n                    0.3576,\n                    0.7288,\n                    -0.6912,\n                    0.9966,\n                    1.0840,\n                    -1.1760,\n                    1.5683,\n                    -0.2098,\n                    -1.6439,\n                    -2.7049,\n                    0.1949,\n                    1.6377,\n                    0.0745,\n                ],\n            ],\n            dtype=flow.float32,\n        
    placement=P0,\n            sbp=BROADCAST,\n        )\n\n        module_pipeline = PipelineModule()\n        graph_model = PipelineGraph(module_pipeline)\n        cur_rank = flow.env.get_rank()\n\n        if call_cnt == 1:\n            if cur_rank in model_file_placement.ranks:\n                local_state_dict = flow.load(state_dict_file)\n            else:\n                local_state_dict = None\n\n            # test sbp_for_special_keys\n            global_state_dict = flow.utils.global_view.to_global(\n                local_state_dict, placement=model_file_placement, sbp=get_sbp,\n            )\n            graph_model.load_state_dict(global_state_dict)\n\n            if cur_rank == 0:\n                test_case.assertTrue(\n                    np.array_equal(\n                        global_state_dict[\"module_pipeline\"][\"m_stage0.linear.weight\"]\n                        .to_local()\n                        .numpy(),\n                        module_pipeline.m_stage0.linear.weight.to_local().numpy()[\n                            :4\n                        ],  # The first half of shape (8, 16)\n                    )\n                )\n                test_case.assertTrue(\n                    np.array_equal(\n                        global_state_dict[\"module_pipeline\"][\"m_stage0.linear.bias\"]\n                        .to_local()\n                        .numpy(),\n                        module_pipeline.m_stage0.linear.bias.to_local().numpy()[\n                            :4\n                        ],  # The first half of shape (8,)\n                    )\n                )\n            if cur_rank == 1:\n                test_case.assertTrue(\n                    np.array_equal(\n                        global_state_dict[\"module_pipeline\"][\"m_stage1.linear.weight\"]\n                        .to_local()\n                        .numpy(),\n                        module_pipeline.m_stage1.linear.weight.to_local().numpy()[\n                            :, 4:\n                        ],  # The second half of shape (1, 8)\n                    )\n                )\n                test_case.assertTrue(\n                    np.array_equal(\n                        global_state_dict[\"module_pipeline\"][\"m_stage1.linear.bias\"]\n                        .to_local()\n                        .numpy(),\n                        module_pipeline.m_stage1.linear.bias.to_local().numpy(),\n                    )\n                )\n\n        graph_model(x)\n        iter0_state_dict = graph_model.state_dict()\n\n        if call_cnt == 1:\n            # TrainStep\n            cur_train_step = (\n                iter0_state_dict[\"System-Train-TrainStep\"]\n                .to_global(placement=model_tensor_placement, sbp=BROADCAST)\n                .to_local()\n                .numpy()[0]\n            )\n            test_case.assertEqual(3, cur_train_step)\n            test_case.assertTrue(\n                cur_train_step\n                == last_state_dict[\"System-Train-TrainStep\"]\n                .to_global(placement=model_tensor_placement, sbp=BROADCAST)\n                .to_local()\n                .numpy()[0]\n            )\n\n            # Weight & bias\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"module_pipeline\"][\"m_stage0.linear.weight\"]\n                    
.to_local()\n                    .numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"module_pipeline\"][\"m_stage0.linear.bias\"]\n                    .to_local()\n                    .numpy(),\n                    last_state_dict[\"module_pipeline\"][\"m_stage0.linear.bias\"]\n                    .to_local()\n                    .numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"module_pipeline\"][\"m_stage1.linear.weight\"]\n                    .to_local()\n                    .numpy(),\n                    last_state_dict[\"module_pipeline\"][\"m_stage1.linear.weight\"]\n                    .to_local()\n                    .numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"module_pipeline\"][\"m_stage1.linear.bias\"]\n                    .to_local()\n                    .numpy(),\n                    last_state_dict[\"module_pipeline\"][\"m_stage1.linear.bias\"]\n                    .to_local()\n                    .numpy(),\n                )\n            )\n\n        graph_model(x)\n        iter1_state_dict = graph_model.state_dict()\n\n        if call_cnt == 0:\n            model_file_state_dict = flow.utils.global_view.to_global(\n                iter1_state_dict, placement=model_file_placement, sbp=get_sbp,\n            )\n            if flow.env.get_rank() in model_file_placement.ranks:\n                flow.save(\n                    flow.utils.global_view.to_local(model_file_state_dict),\n                    state_dict_file,\n                )\n\n            graph_model(x)\n            iter2_state_dict = graph_model.state_dict()\n            return iter2_state_dict\n\n    rank_id = flow.env.get_rank()\n    with tempfile.NamedTemporaryFile(\n        prefix=\"graph_save_load_global_\" + str(rank_id)\n    ) as f:\n        iter2_state_dict = train_with_graph(0, f.name)\n        train_with_graph(1, f.name, iter2_state_dict)\n\n\ndef _test_graph_save_load_global_split_4(\n    test_case, model_tensor_placement, model_file_placement\n):\n    \"\"\"Pipeline parallelism on 4 ranks.\n    \"\"\"\n    P0 = flow.placement(model_tensor_placement.type, ranks=[0])\n    P1 = flow.placement(model_tensor_placement.type, ranks=[1])\n    P2 = flow.placement(model_tensor_placement.type, ranks=[2])\n    P3 = flow.placement(model_tensor_placement.type, ranks=[3])\n    BROADCAST = flow.sbp.broadcast\n\n    def get_sbp(state_dict, tensor):\n        if tensor is state_dict[\"System-Train-TrainStep\"]:\n            return BROADCAST\n        if tensor is state_dict[\"module_pipeline\"][\"m_stage3.linear.weight\"]:\n            return flow.sbp.split(1)\n        if tensor is state_dict[\"module_pipeline\"][\"m_stage3.linear.bias\"]:\n            return BROADCAST\n        return flow.sbp.split(0)\n\n    class Stage0Module(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.linear = flow.nn.Linear(16, 8)\n            self.relu = flow.nn.ReLU()\n\n        def forward(self, x):\n            return self.relu(self.linear(x))\n\n    class Stage1Module(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.linear = flow.nn.Linear(8, 4)\n            self.relu = flow.nn.ReLU()\n\n        def forward(self, x):\n            return 
self.relu(self.linear(x))\n\n    class Stage2Module(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.linear = flow.nn.Linear(4, 2)\n            self.relu = flow.nn.ReLU()\n\n        def forward(self, x):\n            return self.relu(self.linear(x))\n\n    class Stage3Module(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.linear = flow.nn.Linear(2, 1)\n\n        def forward(self, x):\n            return self.linear(x)\n\n    class PipelineModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.m_stage0 = Stage0Module()\n            self.m_stage1 = Stage1Module()\n            self.m_stage2 = Stage2Module()\n            self.m_stage3 = Stage3Module()\n\n            self.m_stage0.to_global(placement=P0, sbp=BROADCAST)\n            self.m_stage1.to_global(placement=P1, sbp=BROADCAST)\n            self.m_stage2.to_global(placement=P2, sbp=BROADCAST)\n            self.m_stage3.to_global(placement=P3, sbp=BROADCAST)\n\n        def forward(self, x):\n            out_stage0 = self.m_stage0(x)\n\n            in_stage1 = out_stage0.to_global(placement=P1, sbp=BROADCAST)\n            out_stage1 = self.m_stage1(in_stage1)\n\n            in_stage2 = out_stage1.to_global(placement=P2, sbp=BROADCAST)\n            out_stage2 = self.m_stage2(in_stage2)\n\n            in_stage3 = out_stage2.to_global(placement=P3, sbp=BROADCAST)\n            out_stage3 = self.m_stage3(in_stage3)\n\n            return out_stage3\n\n    class PipelineGraph(flow.nn.Graph):\n        def __init__(self, module_pipeline):\n            super().__init__()\n            self.module_pipeline = module_pipeline\n            self.module_pipeline.m_stage0.to(GraphModule).set_stage(0, P0)\n            self.module_pipeline.m_stage1.to(GraphModule).set_stage(1, P1)\n            self.module_pipeline.m_stage2.to(GraphModule).set_stage(2, P2)\n            self.module_pipeline.m_stage3.to(GraphModule).set_stage(3, P3)\n            self.config.set_gradient_accumulation_steps(2)\n            self.add_optimizer(\n                flow.optim.SGD(self.module_pipeline.parameters(), lr=0.001)\n            )\n\n        def build(self, x):\n            out = self.module_pipeline(x)\n            out = out.sum()\n            out.backward()\n            return out\n\n    def train_with_graph(call_cnt=0, state_dict_file=None, last_state_dict=None):\n        # A fixed input with shape [2, 16]\n        x = flow.tensor(\n            [\n                [\n                    0.4286,\n                    0.7402,\n                    0.4161,\n                    0.6103,\n                    0.7394,\n                    1.1330,\n                    -0.2311,\n                    -0.1013,\n                    0.8537,\n                    0.9757,\n                    -0.9842,\n                    0.3839,\n                    -0.5551,\n                    -0.8832,\n                    0.7820,\n                    0.7421,\n                ],\n                [\n                    -0.1581,\n                    -1.0319,\n                    1.8430,\n                    0.3576,\n                    0.7288,\n                    -0.6912,\n                    0.9966,\n                    1.0840,\n                    -1.1760,\n                    1.5683,\n                    -0.2098,\n                    -1.6439,\n                    -2.7049,\n                    0.1949,\n                    1.6377,\n                    0.0745,\n           
     ],\n            ],\n            dtype=flow.float32,\n            placement=P0,\n            sbp=BROADCAST,\n        )\n\n        module_pipeline = PipelineModule()\n        graph_model = PipelineGraph(module_pipeline)\n        cur_rank = flow.env.get_rank()\n\n        if call_cnt == 1:\n            if cur_rank in model_file_placement.ranks:\n                local_state_dict = flow.load(state_dict_file)\n            else:\n                local_state_dict = None\n\n            # test sbp_for_special_keys\n            global_state_dict = flow.utils.global_view.to_global(\n                local_state_dict, placement=model_file_placement, sbp=get_sbp,\n            )\n            graph_model.load_state_dict(global_state_dict)\n\n            if cur_rank == 0:\n                test_case.assertTrue(\n                    np.array_equal(\n                        global_state_dict[\"module_pipeline\"][\"m_stage0.linear.weight\"]\n                        .to_local()\n                        .numpy(),\n                        module_pipeline.m_stage0.linear.weight.to_local().numpy()[\n                            :4\n                        ],  # The first half of shape (8, 16)\n                    )\n                )\n                test_case.assertTrue(\n                    np.array_equal(\n                        global_state_dict[\"module_pipeline\"][\"m_stage0.linear.bias\"]\n                        .to_local()\n                        .numpy(),\n                        module_pipeline.m_stage0.linear.bias.to_local().numpy()[\n                            :4\n                        ],  # The first half of shape (8,)\n                    )\n                )\n            if cur_rank == 1:\n                test_case.assertTrue(\n                    np.array_equal(\n                        global_state_dict[\"module_pipeline\"][\"m_stage1.linear.weight\"]\n                        .to_local()\n                        .numpy(),\n                        module_pipeline.m_stage1.linear.weight.to_local().numpy()[\n                            2:, :\n                        ],  # The second half of shape (4, 8)\n                    )\n                )\n                test_case.assertTrue(\n                    np.array_equal(\n                        global_state_dict[\"module_pipeline\"][\"m_stage1.linear.bias\"]\n                        .to_local()\n                        .numpy(),\n                        module_pipeline.m_stage1.linear.bias.to_local().numpy()[\n                            2:\n                        ],  # The second half of shape (4,)\n                    )\n                )\n\n        graph_model(x)\n        iter0_state_dict = graph_model.state_dict()\n\n        if call_cnt == 1:\n            # TrainStep\n            cur_train_step = (\n                iter0_state_dict[\"System-Train-TrainStep\"]\n                .to_global(placement=model_tensor_placement, sbp=BROADCAST)\n                .to_local()\n                .numpy()[0]\n            )\n            test_case.assertEqual(3, cur_train_step)\n            test_case.assertTrue(\n                cur_train_step\n                == last_state_dict[\"System-Train-TrainStep\"]\n                .to_global(placement=model_tensor_placement, sbp=BROADCAST)\n                .to_local()\n                .numpy()[0]\n            )\n\n            # Weight & bias\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"module_pipeline\"][\"m_stage0.linear.weight\"]\n                    
.to_local()\n                    .numpy(),\n                    last_state_dict[\"module_pipeline\"][\"m_stage0.linear.weight\"]\n                    .to_local()\n                    .numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"module_pipeline\"][\"m_stage0.linear.bias\"]\n                    .to_local()\n                    .numpy(),\n                    last_state_dict[\"module_pipeline\"][\"m_stage0.linear.bias\"]\n                    .to_local()\n                    .numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"module_pipeline\"][\"m_stage1.linear.weight\"]\n                    .to_local()\n                    .numpy(),\n                    last_state_dict[\"module_pipeline\"][\"m_stage1.linear.weight\"]\n                    .to_local()\n                    .numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"module_pipeline\"][\"m_stage1.linear.bias\"]\n                    .to_local()\n                    .numpy(),\n                    last_state_dict[\"module_pipeline\"][\"m_stage1.linear.bias\"]\n                    .to_local()\n                    .numpy(),\n                )\n            )\n\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"module_pipeline\"][\"m_stage2.linear.weight\"]\n                    .to_local()\n                    .numpy(),\n                    last_state_dict[\"module_pipeline\"][\"m_stage2.linear.weight\"]\n                    .to_local()\n                    .numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"module_pipeline\"][\"m_stage2.linear.bias\"]\n                    .to_local()\n                    .numpy(),\n                    last_state_dict[\"module_pipeline\"][\"m_stage2.linear.bias\"]\n                    .to_local()\n                    .numpy(),\n                )\n            )\n\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"module_pipeline\"][\"m_stage3.linear.weight\"]\n                    .to_local()\n                    .numpy(),\n                    last_state_dict[\"module_pipeline\"][\"m_stage3.linear.weight\"]\n                    .to_local()\n                    .numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    iter0_state_dict[\"module_pipeline\"][\"m_stage3.linear.bias\"]\n                    .to_local()\n                    .numpy(),\n                    last_state_dict[\"module_pipeline\"][\"m_stage3.linear.bias\"]\n                    .to_local()\n                    .numpy(),\n                )\n            )\n\n        graph_model(x)\n        iter1_state_dict = graph_model.state_dict()\n\n        if call_cnt == 0:\n            model_file_state_dict = flow.utils.global_view.to_global(\n                iter1_state_dict, placement=model_file_placement, sbp=get_sbp,\n            )\n            if flow.env.get_rank() in model_file_placement.ranks:\n                flow.save(\n                    flow.utils.global_view.to_local(model_file_state_dict),\n                    state_dict_file,\n                )\n\n            
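# One more training step after saving, so the returned state dict is one step ahead of the checkpoint on disk.\n            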
graph_model(x)\n            iter2_state_dict = graph_model.state_dict()\n            return iter2_state_dict\n\n    rank_id = flow.env.get_rank()\n    with tempfile.NamedTemporaryFile(\n        prefix=\"graph_save_load_global_\" + str(rank_id)\n    ) as f:\n        iter2_state_dict = train_with_graph(0, f.name)\n        train_with_graph(1, f.name, iter2_state_dict)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestGraphSaveLoadGlobal2d(oneflow.unittest.TestCase):\n    def test_linear_graph_save_load_gpu_1_broadcast(test_case):\n        _test_linear_graph_save_load_global_broadcast(\n            test_case,\n            model_tensor_placement=flow.placement(\"cuda\", ranks=[0, 1]),\n            model_file_placement=flow.placement(\"cpu\", ranks=[0]),\n        )\n\n    def test_linear_graph_save_load_cpu_1_broadcast(test_case):\n        _test_linear_graph_save_load_global_broadcast(\n            test_case,\n            model_tensor_placement=flow.placement(\"cpu\", ranks=[0, 1]),\n            model_file_placement=flow.placement(\"cpu\", ranks=[0]),\n        )\n\n    def test_graph_save_load_gpu_2_split(test_case):\n        _test_graph_save_load_global_split_2(\n            test_case,\n            model_tensor_placement=flow.placement(\"cuda\", ranks=[0, 1]),\n            model_file_placement=flow.placement(\"cpu\", ranks=[0, 1]),\n        )\n\n    @unittest.skip(\"skip for now, because it failed 2 times in the past week\")\n    def test_graph_save_load_cpu_2_split(test_case):\n        _test_graph_save_load_global_split_2(\n            test_case,\n            model_tensor_placement=flow.placement(\"cpu\", ranks=[0, 1]),\n            model_file_placement=flow.placement(\"cpu\", ranks=[0, 1]),\n        )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n4d()\nclass TestGraphSaveLoadGlobal4d(oneflow.unittest.TestCase):\n    def test_graph_save_load_gpu_2_split_2_none(test_case):\n        _test_graph_save_load_global_split_4(\n            test_case,\n            model_tensor_placement=flow.placement(\"cuda\", ranks=[0, 1, 2, 3]),\n            model_file_placement=flow.placement(\"cpu\", ranks=[0, 1]),\n        )\n\n    @unittest.skip(\"skip for now, because it failed 24 times in the past week\")\n    def test_graph_save_load_cpu_2_split_2_none(test_case):\n        _test_graph_save_load_global_split_4(\n            test_case,\n            model_tensor_placement=flow.placement(\"cpu\", ranks=[0, 1, 2, 3]),\n            model_file_placement=flow.placement(\"cpu\", ranks=[0, 1]),\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_scalar.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_scalar_graph(test_case, device):\n    x = flow.tensor(3.0, device=device)\n\n    class MyModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.weight = flow.nn.Parameter(flow.tensor(5.0, device=device))\n\n        def forward(self, x):\n            return x * self.weight + 1.0\n\n    my_module = MyModule()\n    of_eager_out = my_module(x)\n\n    class ScalarGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = my_module\n\n        def build(self, x):\n            return self.m(x)\n\n    scalar_g = ScalarGraph()\n    of_lazy_out = scalar_g(x)\n    test_case.assertTrue(np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy()))\n\n\ndef _test_scalar_train_graph(test_case, device):\n    class MyModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.weight = flow.nn.Parameter(flow.tensor(5.0, device=device))\n\n        def forward(self, x):\n            return x * self.weight + 1.0\n\n    my_module = MyModule()\n    of_sgd = flow.optim.SGD(my_module.parameters(), lr=0.001, momentum=0.9)\n    eager_out_list = []\n    for i in range(3):\n        x = flow.tensor(i * 1.0, device=device, requires_grad=False)\n        of_eager_out = my_module(x)\n        of_eager_out.backward()\n        of_sgd.step()\n        of_sgd.zero_grad()\n        eager_out_list.append(of_eager_out)\n\n    lazy_module = MyModule()\n\n    class ScalarTrainGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = lazy_module\n            of_sgd = flow.optim.SGD(lazy_module.parameters(), lr=0.001, momentum=0.9)\n            # self.m = MyModule()\n            # of_sgd = flow.optim.SGD(self.m.parameters(), lr=0.001, momentum=0.9)\n            self.add_optimizer(of_sgd)\n\n        def build(self, x):\n            loss = self.m(x)\n            loss.backward()\n            return loss\n\n    lazy_out_list = []\n    scalar_g = ScalarTrainGraph()\n    for i in range(3):\n        x = flow.tensor(i * 1.0, device=device)\n        of_lazy_out = scalar_g(x)\n        lazy_out_list.append(of_lazy_out)\n\n    for i in range(3):\n        test_case.assertTrue(\n            np.array_equal(lazy_out_list[i].numpy(), eager_out_list[i].numpy())\n        )\n\n\ndef _test_scalar_global_train_graph(test_case, placement):\n    sbp_b = flow.sbp.broadcast\n\n    class MyModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.weight = flow.nn.Parameter(flow.tensor(5.0))\n\n        def forward(self, x):\n            return x * self.weight + 1.0\n\n    my_module = MyModule()\n\n    of_sgd = flow.optim.SGD(my_module.parameters(), lr=0.001, momentum=0.9)\n    eager_out_list = []\n    for i in 
range(3):\n        x = flow.tensor(i * 1.0, requires_grad=False)\n        of_eager_out = my_module(x)\n        of_eager_out.backward()\n        of_sgd.step()\n        of_sgd.zero_grad()\n        eager_out_list.append(of_eager_out)\n\n    lazy_module = MyModule()\n    lazy_module.to_global(placement=placement, sbp=sbp_b)\n\n    class ScalarTrainGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = lazy_module\n            of_sgd = flow.optim.SGD(lazy_module.parameters(), lr=0.001, momentum=0.9)\n            self.add_optimizer(of_sgd)\n\n        def build(self, x):\n            loss = self.m(x)\n            loss.backward()\n            return loss\n\n    lazy_out_list = []\n    scalar_g = ScalarTrainGraph()\n    for i in range(3):\n        x = flow.tensor(i * 1.0, requires_grad=False)\n        x = x.to_global(placement=placement, sbp=sbp_b)\n        of_lazy_out = scalar_g(x)\n        lazy_out_list.append(of_lazy_out)\n    for i in range(3):\n        test_case.assertTrue(\n            np.array_equal(\n                lazy_out_list[i].to_local().numpy(), eager_out_list[i].numpy()\n            )\n        )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestScalarGraph(oneflow.unittest.TestCase):\n    def test_scalar_graph_gpu(test_case):\n        _test_scalar_graph(test_case, flow.device(\"cuda\"))\n\n    def test_scalar_graph_cpu(test_case):\n        _test_scalar_graph(test_case, flow.device(\"cpu\"))\n\n    def test_scalar_train_graph_gpu(test_case):\n        _test_scalar_train_graph(test_case, flow.device(\"cuda\"))\n\n    def test_scalar_train_graph_cpu(test_case):\n        _test_scalar_train_graph(test_case, flow.device(\"cpu\"))\n\n    def test_scalar_global_train_graph_gpu(test_case):\n        _test_scalar_global_train_graph(test_case, flow.placement(\"cuda\", ranks=[0]))\n\n    def test_scalar_global_train_graph_cpu(test_case):\n        _test_scalar_global_train_graph(test_case, flow.placement(\"cpu\", ranks=[0]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_separate_compile.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport contextlib\nimport os\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow import nn\nimport oneflow.unittest\n\n\n@contextlib.contextmanager\ndef modified_environ(*remove, **update):\n    \"\"\"\n    From: https://stackoverflow.com/questions/2059482/temporarily-modify-the-current-processs-environment\n    Temporarily updates the ``os.environ`` dictionary in-place.\n\n    The ``os.environ`` dictionary is updated in-place so that the modification\n    is sure to work in all situations.\n\n    :param remove: Environment variables to remove.\n    :param update: Dictionary of environment variables and values to add/update.\n    \"\"\"\n    env = os.environ\n    update = update or {}\n    remove = remove or []\n\n    # List of environment variables being updated or removed.\n    stomped = (set(update.keys()) | set(remove)) & set(env.keys())\n    # Environment variables and values to restore on exit.\n    update_after = {k: env[k] for k in stomped}\n    # Environment variables and values to remove on exit.\n    remove_after = frozenset(k for k in update if k not in env)\n\n    try:\n        env.update(update)\n        [env.pop(k, None) for k in remove]\n        yield\n    finally:\n        env.update(update_after)\n        [env.pop(k) for k in remove_after]\n\n\ndef run_testcase_with_sep_compile(test_case_cls):\n    new_cls = type(\"SeparationCompile_\" + test_case_cls.__name__, (test_case_cls,), {})\n    with modified_environ(\n        ONEFLOW_LAZY_COMPILE_MODE=\"rank_per_process\", ENABLE_LOGICAL_CHAIN=\"1\"\n    ):\n        assert os.environ.get(\"ONEFLOW_LAZY_COMPILE_MODE\") == \"rank_per_process\"\n        assert os.environ.get(\"ENABLE_LOGICAL_CHAIN\") == \"1\"\n        flow.boxing.nccl.enable_use_compute_stream(True)\n        unittest.TextTestRunner().run(\n            unittest.TestLoader().loadTestsFromTestCase(new_cls)\n        )\n\n\ndef _get_comb1to2d_test():\n    class _TestModuleDiffHierarchy(nn.Module):\n        def forward(self, x):\n            sbp_1ds = [\n                flow.sbp.broadcast,\n                flow.sbp.partial_sum,\n                flow.sbp.split(0),\n                flow.sbp.split(1),\n                flow.sbp.split(2),\n            ]\n\n            for sbp1 in sbp_1ds:\n\n                for sbp2 in sbp_1ds:\n                    for sbp3 in sbp_1ds:\n                        # (2, 2) -> 4\n                        x = x.to_global(\n                            placement=flow.placement(\n                                type=\"cuda\", ranks=np.array(range(4))\n                            ),\n                            sbp=[sbp1],\n                        )\n                        # 4 -> (2, 2)\n                        x = x.to_global(\n                            placement=flow.placement(\n                                type=\"cuda\", ranks=np.array(range(4)).reshape(2, 2)\n                            ),\n          
                  sbp=[sbp2, sbp3],\n                        )\n\n            return x\n\n    class _TestModuleDiffPlacement(nn.Module):\n        def forward(self, x):\n            sbp_1ds = [\n                flow.sbp.broadcast,\n                flow.sbp.partial_sum,\n                flow.sbp.split(0),\n                flow.sbp.split(1),\n                flow.sbp.split(2),\n            ]\n            for sbp1 in sbp_1ds:\n                for sbp2 in sbp_1ds:\n                    for sbp3 in sbp_1ds:\n                        # (2, 2) -> 3\n                        # 4 is not divisible by 3\n                        x = x.to_global(\n                            placement=flow.placement(\n                                type=\"cuda\", ranks=np.array(range(3))\n                            ),\n                            sbp=[sbp1],\n                        )\n                        # 3 -> (2, 2)\n                        x = x.to_global(\n                            placement=flow.placement(\n                                type=\"cuda\", ranks=np.array(range(4)).reshape(2, 2)\n                            ),\n                            sbp=[sbp2, sbp3],\n                        )\n\n            return x\n\n    class _TestGraph(nn.Graph):\n        def __init__(self, model):\n            super().__init__()\n            self.model = model\n\n        def build(self, x):\n            x = self.model(x)\n            return x\n\n    @flow.unittest.skip_unless_1n4d()\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    class TestSepCompileLazyAllSbpCombinationTesting(flow.unittest.TestCase):\n        def test_lazy_boxing_2d_all_combination_diff_hierarchy(test_case):\n            os.environ[\"ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK\"] = \"0\"\n            os.environ[\"ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION\"] = \"0\"\n\n            x = flow.ones(\n                4,\n                12,\n                4,\n                sbp=[flow.sbp.broadcast, flow.sbp.broadcast],\n                placement=flow.placement(\n                    type=\"cuda\", ranks=np.array(range(4)).reshape(2, 2)\n                ),\n            )\n            model_diff_hierarchy = _TestModuleDiffHierarchy()\n            graph_diff_hierarchy = _TestGraph(model_diff_hierarchy)\n            y = graph_diff_hierarchy(x)\n\n        def test_lazy_boxing_2d_all_combination_diff_placement(test_case):\n            os.environ[\"ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK\"] = \"0\"\n            os.environ[\"ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION\"] = \"0\"\n\n            x = flow.ones(\n                4,\n                12,\n                4,\n                sbp=[flow.sbp.broadcast, flow.sbp.broadcast],\n                placement=flow.placement(\n                    type=\"cuda\", ranks=np.array(range(4)).reshape(2, 2)\n                ),\n            )\n            model_diff_placement = _TestModuleDiffPlacement()\n            graph_diff_placement = _TestGraph(model_diff_placement)\n            z = graph_diff_placement(x)\n            test_case.assertTrue(np.allclose(x.numpy(), z.numpy(), 1e-05, 1e-05))\n\n    return TestSepCompileLazyAllSbpCombinationTesting\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n4d()\nclass TestSeparationCompile(oneflow.unittest.TestCase):\n    def test_test_alexnet_auto_parallel(test_case):\n        from test_alexnet_auto_parallel import TestAlexnetAutoParallel\n\n        
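# Re-run the imported test case with rank-per-process separate compilation enabled via run_testcase_with_sep_compile.\n        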
run_testcase_with_sep_compile(TestAlexnetAutoParallel)\n\n    def _test_comb1to2d(test_case):\n        run_testcase_with_sep_compile(_get_comb1to2d_test())\n\n    def test_graph_zero(test_case):\n        from test_graph_zero import TestLinearTrainGraph2DWithZeRO\n\n        run_testcase_with_sep_compile(TestLinearTrainGraph2DWithZeRO)\n\n    def test_graph_clip_grad_norm(test_case):\n        from test_graph_clip_grad_norm import TestGraphClipGradNorm\n\n        run_testcase_with_sep_compile(TestGraphClipGradNorm)\n\n    def test_graph_pipeline_grad_acc_and_activation_checkpointing(test_case):\n        from test_graph_pipeline import TestGraphPipeline\n\n        run_testcase_with_sep_compile(TestGraphPipeline)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_session_env_destruct.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\nlinear = flow.nn.Linear(3, 8, False)\ninput_arr = np.random.randn(8, 3).astype(np.float32)\nnp_weight = np.ones((3, 8)).astype(np.float32)\nnp_weight.fill(2.3)\nx = flow.tensor(input_arr)\nflow.nn.init.constant_(linear.weight, 2.3)\nof_eager_out = linear(x)\nnp_out = np.matmul(input_arr, np_weight)\nassert np.allclose(of_eager_out.numpy(), np_out, 1e-05, 1e-05)\n\n\nclass LinearGraphDestruct(flow.nn.Graph):\n    def __init__(self):\n        super().__init__()\n        self.my_linear = linear\n\n    def build(self, x):\n        return self.my_linear(x)\n\n\nlinear_g_d = LinearGraphDestruct()\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLinearGraphDestruct(oneflow.unittest.TestCase):\n    def test_linear_graph_destruct(test_case):\n        of_lazy_out = linear_g_d(x)\n        assert np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy())\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_session_env_destruct1.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\nlinear = flow.nn.Linear(3, 8, False)\ninput_arr = np.random.randn(8, 3).astype(np.float32)\nnp_weight = np.ones((3, 8)).astype(np.float32)\nnp_weight.fill(2.3)\nx = flow.tensor(input_arr)\nflow.nn.init.constant_(linear.weight, 2.3)\nof_eager_out = linear(x)\nnp_out = np.matmul(input_arr, np_weight)\nassert np.allclose(of_eager_out.numpy(), np_out, 1e-05, 1e-05)\n\n\nclass LinearGraphDestruct1(flow.nn.Graph):\n    def __init__(self):\n        super().__init__()\n        self.my_linear = linear\n\n    def build(self, x):\n        return self.my_linear(x)\n\n\n# test graph destruction when graph is not compiled\nlinear_g_d_not_compiled = LinearGraphDestruct1()\nprint(\"test graph destruction when graph is not compiled\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_sparse_optimizer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass MyModule(flow.nn.Module):\n    def __init__(self, placement=None, sbp=None):\n        super().__init__()\n        w = flow.randn(10, 10, placement=placement, sbp=sbp)\n        self.weight = flow.nn.Parameter(w)\n\n    def forward(self, input):\n        return flow._C.gather(self.weight, input, 0)\n\n\nclass MyGraph(flow.nn.Graph):\n    def __init__(self, module):\n        super().__init__()\n        self.m = module\n        sgd = flow.optim.SGD(module.parameters(), lr=1e-3)\n        self.add_optimizer(sgd, is_sparse=True)\n\n    def build(self, input):\n        result = self.m(input)\n        result.mean().backward()\n\n\ndef _rand_input(placement=None, sbp=None):\n    generator = flow.Generator()\n    generator.manual_seed(0)\n    return flow.randint(0, 10, (8,), generator=generator, placement=placement, sbp=sbp)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass GraphSparseOptimizerTest(oneflow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 6 times in past week\")\n    def test(test_case):\n        PLC = flow.placement(\"cuda\", ranks=[0])\n        SBP = flow.sbp.broadcast\n        m = MyModule(PLC, SBP)\n        graph = MyGraph(m)\n        graph._compile(_rand_input(PLC, SBP))\n\n        sparse_optimizer_found = False\n        for op in graph._full_graph_proto.net.op:\n            # print(\"==>\", op.name)\n            if op.HasField(\"user_conf\"):\n                # print(\"  -->\", op.user_conf.op_type_name)\n                if op.user_conf.op_type_name == \"indexed_slices_sgd_update\":\n                    sparse_optimizer_found = True\n                    break\n\n        test_case.assertTrue(sparse_optimizer_found)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_sparse_softmax_cross_entropy.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass CrossEntropyModule(flow.nn.Module):\n    def __init__(self, pred):\n        super().__init__()\n        if pred.is_global:\n            self.param = flow.nn.Parameter(\n                flow.zeros(\n                    *pred.shape,\n                    dtype=pred.dtype,\n                    placement=pred.placement,\n                    sbp=pred.sbp,\n                )\n            )\n        else:\n            self.param = flow.nn.Parameter(\n                flow.zeros(*pred.shape, dtype=pred.dtype, device=pred.device)\n            )\n\n    def forward(self, pred, label):\n        pred = pred + self.param\n        loss = flow._C.sparse_softmax_cross_entropy(pred, label)\n        return loss.mean()\n\n\nclass CrossEntropyGraph(flow.nn.Graph):\n    def __init__(self, module):\n        super().__init__()\n        self.m = module\n        self.add_optimizer(flow.optim.SGD([module.param], lr=1.0, momentum=0.0))\n\n    def build(self, pred, label):\n        loss = self.m(pred, label)\n        loss.backward()\n        return loss\n\n\ndef _compare_with_nn_cross_entropy_loss(\n    test_case, pred, label, pred_sbp=None, label_sbp=None\n):\n    if pred.is_global:\n        assert label.is_global\n        pred_ = pred.to_local().detach().clone()\n        label_ = label.to_local()\n    else:\n        pred_ = pred.detach().clone()\n        label_ = label\n\n    pred_.requires_grad = True\n    cross_entropy_loss = flow.nn.CrossEntropyLoss()\n    loss = cross_entropy_loss(pred_, label_)\n    loss.backward()\n\n    if pred_sbp is not None:\n        pred = pred.to_global(sbp=pred_sbp)\n\n    if label_sbp is not None:\n        label = label.to_global(sbp=label_sbp)\n\n    cross_entropy_module = CrossEntropyModule(pred)\n    cross_entropy_graph = CrossEntropyGraph(cross_entropy_module)\n    graph_loss = cross_entropy_graph(pred, label)\n\n    loss_a = loss.numpy()\n    grad_a = pred_.grad.numpy()\n    if graph_loss.is_local:\n        loss_b = graph_loss.numpy()\n        grad_b = -cross_entropy_module.param.numpy()\n    else:\n        graph_loss = graph_loss.to_global(\n            sbp=[flow.sbp.broadcast()] * len(graph_loss.sbp)\n        )\n        loss_b = graph_loss.to_local().numpy()\n        pred_grad = cross_entropy_module.param.to_global(\n            sbp=[flow.sbp.broadcast()] * len(cross_entropy_module.param.sbp)\n        )\n        grad_b = -pred_grad.to_local().numpy()\n\n    test_case.assertTrue(np.allclose(loss_a, loss_b), f\"{loss_a} vs. 
{loss_b}\")\n    test_case.assertTrue(np.allclose(grad_a, grad_b), f\"\\n{grad_a}\\nvs.\\n{grad_b}\")\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestSparseSoftmaxCrossEntropyGraph(oneflow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_local(test_case):\n        pred = flow.randn(8, 10).to(\"cuda\")\n        label = flow.randint(0, 10, (8,)).to(\"cuda\")\n        _compare_with_nn_cross_entropy_loss(test_case, pred, label)\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_data_split(test_case):\n        pred = flow.randn(8, 10)\n        label = flow.randint(0, 10, (8,))\n        placement = flow.placement(\"cuda\", list(range(flow.env.get_world_size())))\n        pred = pred.to_global(placement=placement, sbp=flow.sbp.broadcast())\n        label = label.to_global(placement=placement, sbp=flow.sbp.broadcast())\n        _compare_with_nn_cross_entropy_loss(\n            test_case, pred, label, flow.sbp.split(0), flow.sbp.split(0)\n        )\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_model_split(test_case):\n        pred = flow.randn(8, 10)\n        label = flow.randint(0, 10, (8,))\n        placement = flow.placement(\"cuda\", list(range(flow.env.get_world_size())))\n        pred = pred.to_global(placement=placement, sbp=flow.sbp.broadcast())\n        label = label.to_global(placement=placement, sbp=flow.sbp.broadcast())\n        _compare_with_nn_cross_entropy_loss(\n            test_case, pred, label, flow.sbp.split(1), flow.sbp.broadcast()\n        )\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_2d_split(test_case):\n        pred = flow.randn(8, 10)\n        label = flow.randint(0, 10, (8,))\n        placement = flow.placement(\n            \"cuda\", np.array(range(flow.env.get_world_size())).reshape(2, 2)\n        )\n        pred = pred.to_global(\n            placement=placement, sbp=[flow.sbp.broadcast(), flow.sbp.broadcast()]\n        )\n        label = label.to_global(\n            placement=placement, sbp=[flow.sbp.broadcast(), flow.sbp.broadcast()]\n        )\n        _compare_with_nn_cross_entropy_loss(\n            test_case,\n            pred,\n            label,\n            [flow.sbp.split(0), flow.sbp.split(1)],\n            [flow.sbp.split(0), flow.sbp.broadcast()],\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_tensor_clone.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestTensorCloneGraph(oneflow.unittest.TestCase):\n    def test_tensor_clone_graph(test_case):\n        class TensorCloneGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n\n            def build(self, x):\n                y = x.clone()\n                x += x\n                return x, y\n\n        x = flow.randn(3, 4)\n        res = TensorCloneGraph()(x)\n        test_case.assertTrue(len(res) == 2)\n        test_case.assertTrue(np.allclose(res[0], res[1] * 2, 1e-05, 1e-05))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_tensor_detach.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestTensorDetachGraph(oneflow.unittest.TestCase):\n    def test_tensor_detach_graph(test_case):\n        class TensorDetachGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n\n            def build(self, x):\n                x += x\n                y = x.detach()\n                return x, y\n\n        x = flow.randn(3, 4)\n        res = TensorDetachGraph()(x)\n        test_case.assertTrue(len(res) == 2)\n        test_case.assertTrue(np.allclose(res[0], res[1], 1e-05, 1e-05))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_with_global.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport os\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.nn.graph import GraphModule\nimport oneflow.utils.global_view as global_view\nfrom oneflow.utils.global_view import global_mode\n\n\ndef _test_linear_train_graph_with_ddp(test_case):\n    def train_with_graph(iter_num=1):\n        PC = flow.placement(\"cpu\", ranks=[0, 1])\n        P = flow.placement(\"cuda\", ranks=[0, 1])\n        B = flow.sbp.broadcast\n        S0 = flow.sbp.split(0)\n\n        linear_dp = flow.nn.Linear(800, 400, bias=False)\n        linear_dp = linear_dp.to_global(placement=P, sbp=B)\n        flow.nn.init.constant_(linear_dp.weight, 2.068758)\n\n        of_sgd = flow.optim.SGD(\n            [{\"params\": linear_dp.parameters()}], lr=0.001, momentum=0.9,\n        )\n\n        x = flow.ones((6, 800), placement=PC, sbp=S0)\n\n        class LinearTrainGraphWithDDP(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linear_dp = linear_dp\n                self.add_optimizer(of_sgd)\n\n            def build(self, x):\n                x = x.to_global(placement=P)\n                out = self.linear_dp(x)\n                loss = out.sum()\n                loss.backward()\n                return out\n\n        class LinearEvalGraphWithDDP(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linear_dp = linear_dp\n\n            def build(self, x):\n                x = x.to_global(placement=P)\n                out = self.linear_dp(x)\n                return out\n\n        linear_t_g = LinearTrainGraphWithDDP()\n        # linear_t_g.debug(1)\n        linear_e_g = LinearEvalGraphWithDDP()\n        # linear_e_g.debug(1)\n\n        result_check_list = []\n\n        def one_train_iter(iter_cnt=0):\n            out = linear_t_g(x)\n            result_check_list.append(out)\n\n            # if iter_cnt == 0:\n            #     if flow.env.get_rank() == 0:\n            #         import traceback\n\n            #         try:\n            #             print(linear_t_g)\n            #         except:\n            #             print(traceback.format_exc())\n\n        def one_eval_iter(iter_cnt=0):\n            out = linear_e_g(x)\n            result_check_list.append(out)\n\n        for i in range(iter_num):\n            one_train_iter(i)\n\n        # In evaluation graph, paramters's sbp are flow.sbp.split(0).\n        # But their consumer will consum them as flow.sbp.broadcast.\n        one_eval_iter()\n\n        return result_check_list\n\n    def train_with_graph_ddp(iter_num=1):\n        PC = flow.placement(\"cpu\", ranks=[0, 1])\n        P = flow.placement(\"cuda\", ranks=[0, 1])\n        B = flow.sbp.broadcast\n        S0 = flow.sbp.split(0)\n\n        linear_dp = flow.nn.Linear(800, 400, bias=False)\n        linear_dp = 
linear_dp.to_global(placement=P, sbp=B)\n        flow.nn.init.constant_(linear_dp.weight, 2.068758)\n\n        of_sgd = flow.optim.SGD(\n            [{\"params\": linear_dp.parameters()}], lr=0.001, momentum=0.9,\n        )\n\n        with global_mode(True, placement=PC, sbp=S0):\n            x = flow.ones((6, 800), placement=PC, sbp=S0)\n\n        class LinearTrainGraphWithDDP(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linear_dp = linear_dp\n                self.add_optimizer(of_sgd)\n\n            def build(self, x):\n                # This is ok\n                # x = x.to(\"cuda\")\n\n                # This is ok\n                # x = x.to_global(placement=P)\n\n                # This is not ok\n                # x = x.to(device)\n\n                with global_mode(True, placement=P, sbp=B):\n                    # Test global tensor to device\n                    device = self.linear_dp.weight.device\n\n                    x = x.to(device)\n\n                    out = self.linear_dp(x)\n\n                    # Test randn source op\n                    sample = flow.randn(out.shape, device=\"cpu\").to(device)\n                    out = out + sample * 100\n\n                # Test disable global_mode while passing placement and sbp\n                with global_mode(False, placement=P, sbp=B):\n                    out = out - sample * 100\n                    cur_global_mode = global_view.current_global_mode()\n                    test_case.assertFalse(cur_global_mode.is_enabled)\n\n                loss = out.sum()\n                loss.backward()\n                return out\n\n        class LinearEvalGraphWithDDP(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linear_dp = linear_dp\n\n            def build(self, x):\n                with global_mode(True, placement=P, sbp=B):\n                    device = self.linear_dp.weight.device\n\n                    x = x.to(device)\n\n                    out = self.linear_dp(x)\n\n                    # Test randn source op\n                    sample = flow.randn(out.shape, device=\"cpu\").to(device)\n                    out = out + sample * 100\n                    out = out - sample * 100\n\n                return out\n\n        linear_t_g = LinearTrainGraphWithDDP()\n        # linear_t_g.debug(1)\n        linear_e_g = LinearEvalGraphWithDDP()\n        # linear_e_g.debug(1)\n\n        result_check_list = []\n\n        def one_train_iter(iter_cnt=0):\n            out = linear_t_g(x)\n            result_check_list.append(out)\n\n            # if iter_cnt == 0:\n            #     if flow.env.get_rank() == 0:\n            #         import traceback\n\n            #         try:\n            #             print(linear_t_g)\n            #         except:\n            #             print(traceback.format_exc())\n\n        def one_eval_iter(iter_cnt=0):\n            out = linear_e_g(x)\n            result_check_list.append(out)\n\n        for i in range(iter_num):\n            one_train_iter(i)\n\n        # In evaluation graph, parameters' sbp are flow.sbp.split(0).\n        # But their consumer will consume them as flow.sbp.broadcast.\n        one_eval_iter()\n\n        return result_check_list\n\n    iter_num = 2\n    graph_check_list = train_with_graph(iter_num)\n    graph_ddp_check_list = train_with_graph_ddp(iter_num)\n    test_case.assertEqual(len(graph_check_list), iter_num + 1)\n    
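# The plain-graph and global_mode paths should produce identical outputs\n    # iteration by iteration; the element-wise comparison below checks this.\n    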
test_case.assertEqual(len(graph_ddp_check_list), iter_num + 1)\n    for i in range(iter_num + 1):\n        test_case.assertTrue(\n            np.allclose(\n                graph_check_list[i].numpy(),\n                graph_ddp_check_list[i].numpy(),\n                rtol=1e-5,\n                atol=1e-5,\n            ),\n            f\"current index {i} \\n base {graph_check_list[i].numpy()} \\n ddp {graph_ddp_check_list[i].numpy()} \\n diff {graph_ddp_check_list[i].numpy() - graph_check_list[i].numpy()}\",\n        )\n\n\ndef _test_global_mode(test_case):\n    P = flow.placement(\"cuda\", ranks=[0, 1])\n    B = flow.sbp.broadcast\n\n    class GlobalModeGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self):\n            with global_mode(True, placement=P, sbp=B):\n                # Test global mode meta data\n                cur_global_mode = global_view.current_global_mode()\n                test_case.assertTrue(cur_global_mode.is_enabled)\n                test_case.assertEqual(cur_global_mode.placement, P)\n                test_case.assertEqual(cur_global_mode.sbp[0], B)\n\n                # Test global mode source op\n                randn_out = flow.randn((2, 2))\n                rand_out = flow.rand((2, 2))\n                randint_out = flow.randint(-100, 100, (2, 2))\n                randperm_out = flow.randperm(5)\n                arange_out = flow.arange(10)\n                empty_out = flow.empty((1, 2))\n                tensor_out = flow.tensor([[1, 2, 4, 5], [4, 3, 2, 9]], dtype=flow.int)\n                hann_window_out = flow.hann_window(8, dtype=flow.float)\n\n            test_case.assertTrue(not global_view.current_global_mode().is_enabled)\n\n            return {\n                \"randn_out\": randn_out,\n                \"rand_out\": rand_out,\n                \"randint_out\": randint_out,\n                \"randperm_out\": randperm_out,\n                \"arange_out\": arange_out,\n                \"empty_out\": empty_out,\n                \"tensor_out\": tensor_out,\n                \"hann_window_out\": hann_window_out,\n            }\n\n    global_graph = GlobalModeGraph()\n    out = global_graph()\n    for k, v in out.items():\n        test_case.assertEqual(v.is_global, True, k)\n        test_case.assertEqual(v.placement, P, k)\n        test_case.assertEqual(v.sbp[0], B, k)\n\n\ndef _test_global_mode_with_default_placement_and_sbp(test_case):\n    # create a tensor with broadcast sbp and placement on rank 0\n    a = flow.randn(\n        (1, 8), sbp=flow.sbp.broadcast, placement=flow.placement(\"cuda\", ranks=[0])\n    )\n    # enter global mode with broadcast sbp and placement on 2 GPUs\n    with global_mode(\n        True,\n        placement=flow.placement(type=\"cuda\", ranks=[0, 1]),\n        sbp=flow.sbp.broadcast,\n    ):\n        # check tensor placement and sbp\n        test_case.assertTrue(a.placement == flow.placement(\"cuda\", ranks=[0]))\n        test_case.assertTrue(a.sbp == (flow.sbp.broadcast,))\n        # check tensor print\n        print(a)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestLinearTrainGraphWithDDP(oneflow.unittest.TestCase):\n    def test_linear_train_graph_with_ddp(test_case):\n        _test_linear_train_graph_with_ddp(test_case)\n\n    @unittest.skip(\"skip for now, because it failed 4 times in the past week\")\n    def test_global_mode(test_case):\n        _test_global_mode(test_case)\n        
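# Also verify that a pre-existing tensor keeps its own placement/sbp when\n        # global_mode is entered with different defaults.\n        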
_test_global_mode_with_default_placement_and_sbp(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_graph_zero.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport os\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.nn.graph import GraphModule\n\n\ndef _test_linear_train_graph_with_zero(test_case, zero_stage=1):\n    def train_with_graph(iter_num=1):\n        P = flow.placement(\"cuda\", ranks=[0, 1])\n        B = flow.sbp.broadcast\n        S0 = flow.sbp.split(0)\n\n        linear_dp = flow.nn.Linear(800, 400, bias=False)\n        linear_dp = linear_dp.to_global(placement=P, sbp=B)\n        flow.nn.init.constant_(linear_dp.weight, 2.068758)\n\n        linear_mp = flow.nn.Linear(400, 500, bias=False)\n        linear_mp = linear_mp.to_global(placement=P, sbp=S0)\n        flow.nn.init.constant_(linear_mp.weight, 2.068758)\n\n        of_sgd = flow.optim.SGD(\n            [{\"params\": linear_dp.parameters()}, {\"params\": linear_mp.parameters()}],\n            lr=0.001,\n            momentum=0.9,\n        )\n        grad_scaler = flow.amp.StaticGradScaler(200)\n\n        x = flow.randint(1, 100, (6, 800), dtype=flow.float32, placement=P, sbp=S0)\n\n        class LinearTrainGraphWithZeRO(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linear_dp = linear_dp\n                self.linear_mp = linear_mp\n                self.add_optimizer(of_sgd)\n\n                self.config.enable_amp(True)\n                self.set_grad_scaler(grad_scaler)\n                self.config.enable_zero(\n                    True, stage=zero_stage, shard_min_size=1, shard_restore_level=0,\n                )\n                self.debug(2)\n\n            def build(self, x):\n                out = self.linear_dp(x)\n                out = out.to_global(placement=P, sbp=B)\n                out = self.linear_mp(out)\n                loss = out.sum()\n                loss.backward()\n                return out\n\n        class LinearEvalGraphWithZeRO(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linear_dp = linear_dp\n                self.linear_mp = linear_mp\n\n                self.config.enable_amp(True)\n\n            def build(self, x):\n                out = self.linear_dp(x)\n                out = out.to_global(placement=P, sbp=B)\n                out = self.linear_mp(out)\n                return out\n\n        linear_t_g = LinearTrainGraphWithZeRO()\n        linear_t_g.debug(1)\n        linear_e_g = LinearEvalGraphWithZeRO()\n        linear_e_g.debug(1)\n\n        def one_train_iter():\n            out = linear_t_g(x)\n            if flow.env.get_rank() == 0:\n                import traceback\n\n                try:\n                    print(linear_t_g)\n                except:\n                    print(traceback.format_exc())\n\n        def one_eval_iter():\n            out = linear_e_g(x)\n\n        for i in range(iter_num):\n            one_train_iter()\n\n        # After pass 
rewrite in training graph, parameters' sbp has been\n        # changed from flow.sbp.broadcast to flow.sbp.split(0)\n        test_case.assertEqual(linear_dp.weight.sbp[0], S0)\n        test_case.assertEqual(linear_mp.weight.sbp[0], S0)\n\n        # In evaluation graph, parameters' sbp are flow.sbp.split(0).\n        # But their consumer will consume them as flow.sbp.broadcast.\n        one_eval_iter()\n\n    iter_num = 1\n    graph_check_list = train_with_graph(iter_num)\n\n\ndef _test_linear_train_graph_2d_with_zero(test_case, zero_stage=1):\n    def train_with_graph(iter_num=1):\n        P = flow.placement(\"cuda\", ranks=[[0, 1], [2, 3]])\n        B = flow.sbp.broadcast\n        S0 = flow.sbp.split(0)\n        S1 = flow.sbp.split(1)\n\n        def get_mixed_linear():\n            linear_dp_mp = flow.nn.Linear(800, 400, bias=False)\n            linear_dp_mp = linear_dp_mp.to_global(placement=P, sbp=[B, S0])\n            flow.nn.init.constant_(linear_dp_mp.weight, 1.068758)\n\n            linear_mp_dp = flow.nn.Linear(800, 400, bias=False)\n            linear_mp_dp = linear_mp_dp.to_global(placement=P, sbp=[S0, B])\n            flow.nn.init.constant_(linear_mp_dp.weight, 1.068758)\n\n            class MixedLinear(flow.nn.Module):\n                def __init__(self):\n                    super().__init__()\n                    self.dp_mp = linear_dp_mp\n                    self.mp_dp = linear_mp_dp\n\n                def forward(self, x):\n                    x = self.dp_mp(x)\n                    x = flow.relu(x)\n                    x = self.mp_dp(x)\n                    x = flow.relu(x)\n                    return x\n\n            return MixedLinear()\n\n        mixed_linear0 = get_mixed_linear()\n        mixed_linear1 = get_mixed_linear()\n\n        of_sgd = flow.optim.SGD(\n            [\n                {\"params\": mixed_linear0.parameters()},\n                {\"params\": mixed_linear1.parameters()},\n            ],\n            lr=0.001,\n            momentum=0.9,\n        )\n        grad_scaler = flow.amp.StaticGradScaler(200)\n\n        x = flow.rand((2, 800), dtype=flow.float32, placement=P, sbp=[S0, B])\n\n        class LinearTrainGraph2DWithZeRO(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.mixed_linear0 = mixed_linear0\n                self.mixed_linear0.to(GraphModule).activation_checkpointing = True\n                self.mixed_linear1 = mixed_linear1\n                self.mixed_linear1.to(GraphModule).activation_checkpointing = True\n                self.add_optimizer(of_sgd)\n\n                self.config.enable_amp(True)\n                self.set_grad_scaler(grad_scaler)\n                self.config.enable_zero(\n                    True, stage=zero_stage, shard_min_size=1, shard_restore_level=1,\n                )\n\n            def build(self, x):\n                out = self.mixed_linear0(x)\n                out = self.mixed_linear1(out)\n                loss = out.mean()\n                loss.backward()\n                return loss\n\n        class LinearEvalGraph2DWithZeRO(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.mixed_linear0 = mixed_linear0\n                self.mixed_linear1 = mixed_linear1\n\n                self.config.enable_amp(True)\n\n            def build(self, x):\n                out = self.mixed_linear0(x)\n                out = self.mixed_linear1(out)\n                return out\n\n        linear_t_g = 
LinearTrainGraph2DWithZeRO()\n        linear_e_g = LinearEvalGraph2DWithZeRO()\n\n        def one_train_iter():\n            out = linear_t_g(x)\n            # if flow.env.get_rank() == 0:\n            #    print(linear_t_g)\n\n        def one_eval_iter():\n            out = linear_e_g(x)\n\n        for i in range(iter_num):\n            one_train_iter()\n\n        for state in linear_t_g._state():\n            test_case.assertEqual(\n                state.to(flow.Tensor).sbp,\n                (oneflow.sbp.split(dim=0), oneflow.sbp.split(dim=0)),\n            )\n\n        # In evaluation graph, parameters' sbp are flow.sbp.split(0).\n        # But their consumer will consume them as flow.sbp.broadcast.\n        one_eval_iter()\n\n    iter_num = 1\n    graph_check_list = train_with_graph(iter_num)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestLinearTrainGraphWithZeRO(oneflow.unittest.TestCase):\n    def test_linear_train_graph_with_zero_1(test_case):\n        _test_linear_train_graph_with_zero(test_case, 1)\n\n    def test_linear_train_graph_with_zero_2(test_case):\n        _test_linear_train_graph_with_zero(test_case, 2)\n\n    def test_linear_train_graph_with_zero_3(test_case):\n        _test_linear_train_graph_with_zero(test_case, 3)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n4d()\nclass TestLinearTrainGraph2DWithZeRO(oneflow.unittest.TestCase):\n    def test_linear_train_graph_2d_with_zero_3(test_case):\n        _test_linear_train_graph_2d_with_zero(test_case, 3)\n\n    def test_linear_train_graph_2d_with_zero_2(test_case):\n        _test_linear_train_graph_2d_with_zero(test_case, 2)\n\n    def test_linear_train_graph_2d_with_zero_1(test_case):\n        _test_linear_train_graph_2d_with_zero(test_case, 1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_input_op_expr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\n\nimport numpy as np\nfrom google.protobuf import text_format\n\nimport oneflow\nimport oneflow as flow\nimport oneflow._oneflow_internal\nimport oneflow._oneflow_internal._C as _C\nimport oneflow.framework.c_api_util as c_api_util\nimport oneflow.framework.session_context as session_ctx\nimport oneflow.unittest\nfrom oneflow.framework.multi_client_session import MultiClientSession\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFeedInputTensor(unittest.TestCase):\n    def test_feed_input_tensor(test_case):\n        x = flow.Tensor(1, 1, 10, 10)\n        flow.nn.init.uniform_(x, a=-1.0, b=1.0)\n        session = session_ctx.GetDefaultSession()\n        test_case.assertTrue(isinstance(session, MultiClientSession))\n        session.TryInit()\n        with oneflow._oneflow_internal.lazy_mode.guard(True):\n            oneflow._oneflow_internal.JobBuildAndInferCtx_Open(\n                \"cc_test_input_op_expr_job\"\n            )\n            job_conf = oneflow.core.job.job_conf_pb2.JobConfigProto()\n            job_conf.job_name = \"cc_test_input_op_expr_job\"\n            job_conf.predict_conf.SetInParent()\n            c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)\n            op_name = \"cc_Input_0\"\n            input_conf = oneflow.core.operator.op_conf_pb2.FeedInputOpConf()\n            input_conf.in_0 = \"EagerTensorInput\"\n            input_conf.out_0 = \"out_0\"\n            input_conf_str = text_format.MessageToString(input_conf)\n            input_op = oneflow._oneflow_internal.one.FeedInputOpExpr(\n                op_name, input_conf_str, [\"in_0\"], [\"out_0\"]\n            )\n            out_tensor = _C.dispatch_feed_input(input_op, x)\n            test_case.assertEqual(out_tensor.shape, (1, 1, 10, 10))\n            test_case.assertTrue(out_tensor.is_lazy)\n            test_case.assertTrue(out_tensor.is_local)\n            oneflow._oneflow_internal.JobBuildAndInferCtx_Close()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_long_add_n_pass.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport argparse\nimport numpy as np\nimport os\nimport time\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_long_add_n_graph(test_case, device):\n    input_arr = np.array(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n            [0.35217619, -0.67095644, -1.58943879],\n            [0.08086036, -1.81075924, 1.20752494],\n            [0.8901075, -0.49976737, -1.07153746],\n            [-0.44872912, -1.07275683, 0.06256855],\n            [-0.22556897, 0.74798368, 0.90416439],\n            [0.48339456, -2.32742195, -0.59321527],\n        ],\n        dtype=np.float32,\n    )\n    x0 = flow.tensor(input_arr, device=device)\n    x1 = flow.tensor(input_arr, device=device)\n    x2 = flow.tensor(input_arr, device=device)\n    x3 = flow.tensor(input_arr, device=device)\n    x4 = flow.tensor(input_arr, device=device)\n    x5 = flow.tensor(input_arr, device=device)\n    x6 = flow.tensor(input_arr, device=device)\n    x7 = flow.tensor(input_arr, device=device)\n    x8 = flow.tensor(input_arr, device=device)\n    x9 = flow.tensor(input_arr, device=device)\n\n    class AddNGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self):\n            # Deprecated `temp = x0 + x0` to avoid unstable test\n            # enable this after fix https://github.com/Oneflow-Inc/oneflow/issues/9431\n            # temp = x0 + x0\n            temp = x0\n            temp = temp + x1  # test add_n1(add_n0(...), ...)\n            temp = temp + temp  # test add_n1(add_n0(...), add_n0(...))\n            temp = temp + x2\n            temp = temp + x3\n            temp = temp + x4\n            temp = temp + x5\n            temp = temp + x6\n            temp = temp + x7\n            other_add_n = x8 + x9\n            temp = temp + other_add_n  # test add_n2(add_n0(), add_n1())\n            return temp\n\n    add_n_g = AddNGraph()\n    of_lazy_out = add_n_g()\n    test_case.assertTrue(np.allclose(input_arr * 12, of_lazy_out.numpy(), 1e-05, 1e-05))\n\n\ndef _test_add_n_consume_multi_add_n_graph(test_case, device):\n    input_arr = np.array(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n            [0.35217619, -0.67095644, -1.58943879],\n            [0.08086036, -1.81075924, 1.20752494],\n            [0.8901075, -0.49976737, -1.07153746],\n            [-0.44872912, -1.07275683, 0.06256855],\n            [-0.22556897, 0.74798368, 0.90416439],\n            [0.48339456, -2.32742195, -0.59321527],\n        ],\n        dtype=np.float32,\n    )\n    x0 = flow.tensor(input_arr, device=device)\n\n    class AddNGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self):\n            temp = x0 + x0\n            temp = temp + temp\n            
return temp\n\n    add_n_g = AddNGraph()\n    of_lazy_out = add_n_g()\n    test_case.assertTrue(np.allclose(input_arr * 4, of_lazy_out.numpy(), 1e-05, 1e-05))\n\n\n@unittest.skip(\"fail on ci\")\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestLongAddNGraph(oneflow.unittest.TestCase):\n    def test_add_n(test_case):\n        device = \"cuda\"\n        _test_long_add_n_graph(test_case, device)\n\n    def test_consume_multi_add_n(test_case):\n        device = \"cuda\"\n        _test_add_n_consume_multi_add_n_graph(test_case, device)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_modify_module_forward.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport argparse\nimport numpy as np\nimport os\nimport time\nimport unittest\nfrom types import MethodType\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow import nn\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestModifyForwardOfModule(oneflow.unittest.TestCase):\n    def test_modify_forward(test_case):\n        def forward2(self, x):\n            return x + 1\n\n        class Model1(nn.Module):\n            def __init__(self):\n                super().__init__()\n\n            def forward(self, x):\n                return x\n\n        class ForwardModifiedGraph(nn.Graph):\n            def __init__(self, model):\n                super().__init__()\n                self.model = model\n                self.model.eval()\n\n            def build(self, x):\n                return self.model(x)\n\n        test_model = Model1()\n        test_model.forward = MethodType(forward2, test_model)\n        eval_graph_model1 = ForwardModifiedGraph(model=test_model)\n\n        input_tensor = flow.tensor([0.0], requires_grad=True)\n\n        eager_out = test_model(input_tensor)\n        graph_out = eval_graph_model1(input_tensor)\n        test_case.assertTrue(np.array_equal(graph_out.numpy(), eager_out.numpy()))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_multi_client_session.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\n\nimport oneflow\nimport oneflow as flow\nimport oneflow.framework.session_context as session_ctx\nimport oneflow.unittest\nfrom oneflow.framework.multi_client_session import MultiClientSession\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMultiClientSession(unittest.TestCase):\n    def test_case1(self):\n        sess = session_ctx.GetDefaultSession()\n        self.assertTrue(isinstance(sess, MultiClientSession))\n        sess.TryInit()\n        self.assertEqual(sess.status, sess.Status.INITED)\n\n    def test_case2(self):\n        print(\"test_case2\")\n        sess = session_ctx.GetDefaultSession()\n        self.assertTrue(isinstance(sess, MultiClientSession))\n        sess.TryInit()\n        self.assertEqual(sess.status, sess.Status.INITED)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_multi_graph.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestMultiGraph(oneflow.unittest.TestCase):\n    def test_multi_graph(test_case):\n        relu_data = np.array([2.0, 1.0, 0.0, -1.0, -2.0])\n        relu_in = flow.tensor(relu_data, dtype=flow.float32)\n\n        MyRelu = flow.nn.ReLU()\n        relu_out_eager = MyRelu(relu_in)\n\n        class ReluGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.cc_relu = MyRelu\n\n            def build(self, x):\n                return self.cc_relu(x)\n\n        relu_g = ReluGraph()\n        relu_out_lazy = relu_g(relu_in)\n        test_case.assertTrue(\n            np.array_equal(relu_out_lazy.numpy(), relu_out_eager.numpy())\n        )\n\n        linear = flow.nn.Linear(3, 8, False)\n        linear = linear.to(flow.device(\"cuda\"))\n        input_arr = np.array(\n            [\n                [-0.94630778, -0.83378579, -0.87060891],\n                [2.0289922, -0.28708987, -2.18369248],\n                [0.35217619, -0.67095644, -1.58943879],\n                [0.08086036, -1.81075924, 1.20752494],\n                [0.8901075, -0.49976737, -1.07153746],\n                [-0.44872912, -1.07275683, 0.06256855],\n                [-0.22556897, 0.74798368, 0.90416439],\n                [0.48339456, -2.32742195, -0.59321527],\n            ],\n            dtype=np.float32,\n        )\n        np_weight = np.ones((3, 8)).astype(np.float32)\n        np_weight.fill(2.3)\n        linear_in = flow.tensor(input_arr, device=flow.device(\"cuda\"))\n        flow.nn.init.constant_(linear.weight, 2.3)\n        linear_out_eager = linear(linear_in)\n        np_out = np.matmul(input_arr, np_weight)\n        test_case.assertTrue(\n            np.allclose(linear_out_eager.numpy(), np_out, 1e-05, 1e-05)\n        )\n\n        class LinearGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.my_linear = linear\n\n            def build(self, x):\n                return self.my_linear(x)\n\n        linear_g = LinearGraph()\n        linear_out_lazy = linear_g(linear_in)\n        test_case.assertTrue(\n            np.array_equal(linear_out_lazy.numpy(), linear_out_eager.numpy())\n        )\n\n        relu_out_lazy = relu_g(relu_in)\n        linear_out_lazy = linear_g(linear_in)\n        test_case.assertTrue(\n            np.array_equal(relu_out_eager.numpy(), relu_out_lazy.numpy())\n        )\n        test_case.assertTrue(\n            np.array_equal(linear_out_eager.numpy(), linear_out_lazy.numpy())\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_multi_tensor_adam_update_with_cast.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nimport copy\nimport os\n\nfrom test_util import GenArgList\n\nimport oneflow as flow\n\n\ndef compare_with_numpy_adam(\n    test_case,\n    device,\n    x_shape,\n    tensor_num,\n    learning_rate,\n    train_iters,\n    betas,\n    weight_decay,\n    eps,\n    do_bias_correction,\n    amsgrad,\n):\n    random_weight_seq = []\n    init_value_seq = []\n\n    for _ in range(train_iters):\n        random_grad_seq_per_iter = []\n        for i in range(tensor_num):\n            random_grad_seq_per_iter.append(\n                np.random.uniform(size=x_shape).astype(np.float32)\n            )\n        random_weight_seq.append(random_grad_seq_per_iter)\n\n    for i in range(tensor_num):\n        init_value_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n\n    class CustomModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.add_parameters()\n\n        def add_parameters(self) -> None:\n            for idx in range(tensor_num):\n                self.register_parameter(\n                    f\"param_{idx}\",\n                    flow.nn.Parameter(\n                        flow.tensor(init_value_seq[idx], device=flow.device(device))\n                    ),\n                )\n\n        def param(self, i):\n            return getattr(self, f\"param_{i}\")\n\n        def forward(self, mask_list):\n            out = 0\n            for idx in range(tensor_num):\n                out += flow._C.matmul(self.param(idx), mask_list[idx])\n\n            return out\n\n    simp_module = CustomModule()\n    simp_module.to(device)\n    simp_module.train()\n\n    adam0 = flow.optim.Adam(\n        [\n            {\n                \"params\": simp_module.parameters(),\n                \"lr\": learning_rate,\n                \"betas\": betas,\n                \"eps\": eps,\n                \"weight_decay\": weight_decay,\n            },\n        ],\n        do_bias_correction=do_bias_correction,\n        amsgrad=amsgrad,\n    )\n\n    class CustomAdamGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(adam0)\n            self.config.enable_amp(True)\n            self.config.allow_fuse_model_update_ops(True)\n            self.config.enable_multi_tensor_update(True)\n            self.config.enable_fused_model_update_cast(True)\n\n        def build(self, mask_tensor_list):\n            loss = flow.sum(self.m(mask_tensor_list))\n            loss.backward()\n            return loss\n\n    of_res_list = []\n    adam_graph = CustomAdamGraph()\n    for i in range(train_iters):\n        mask_tensor_list = []\n        for idx in range(tensor_num):\n            mask_tensor_list.append(\n                flow.tensor(\n                    random_weight_seq[i][idx],\n              
      dtype=flow.float32,\n                    requires_grad=False,\n                    device=flow.device(device),\n                )\n            )\n        adam_x = adam_graph(mask_tensor_list)\n        of_res_list.append([])\n        for idx in range(tensor_num):\n            of_res_list[i].append(copy.copy(simp_module.param(idx).numpy()))\n\n    np_res_list = []\n\n    def train_by_numpy():\n        x = init_value_seq\n        m = []\n        v = []\n        for idx in range(tensor_num):\n            m.append(np.zeros_like(x[idx]))\n            v.append(np.zeros_like(x[idx]))\n        beta1 = betas[0]\n        beta2 = betas[1]\n\n        ones = np.ones(x_shape).astype(np.float32)\n\n        def train_one_iter(step, weight):\n            for i in range(tensor_num):\n                transposed_weight = np.transpose(weight[i], (1, 0))\n                grad = np.matmul(ones, transposed_weight)\n                grad = grad + weight_decay * x[i]\n\n                bias_correction1 = 1.0\n                bias_correction2 = 1.0\n\n                if do_bias_correction:\n                    bias_correction1 = 1.0 - np.power(beta1, step)\n                    bias_correction2 = 1.0 - np.power(beta2, step)\n\n                m[i] = beta1 * m[i] + (1 - beta1) * grad\n                v[i] = beta2 * v[i] + (1 - beta2) * grad * grad\n                denom = np.sqrt(v[i]) / np.sqrt(bias_correction2) + eps\n\n                x[i] = x[i] - ((learning_rate / bias_correction1) * m[i] / denom)\n            return (x, m, v)\n\n        for i in range(1, train_iters + 1):\n            x, m, v = train_one_iter(i, random_weight_seq[i - 1])\n            np_res_list.append(copy.copy(x))\n\n    train_by_numpy()\n    for i in range(tensor_num):\n        test_case.assertTrue(\n            np.allclose(np_res_list[i], of_res_list[i], rtol=1e-3, atol=1e-3)\n        )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestMultiTensorAdam(flow.unittest.TestCase):\n    def test_multi_tensor_adam(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\"]\n        arg_dict[\"x_shape\"] = [(4, 4)]\n        arg_dict[\"tensor_num\"] = [4, 6]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"betas\"] = [(0.99, 0.9)]\n        arg_dict[\"weight_decay\"] = [0.0, 1e-3]\n        arg_dict[\"eps\"] = [1e-5]\n        arg_dict[\"do_bias_correction\"] = [True, False]\n        arg_dict[\"amsgrad\"] = [False]  # Multi tensor update does not support amsgrad\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adam(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_multi_tensor_sgd_update_with_cast.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nimport copy\nimport os\n\nfrom test_util import GenArgList\n\nimport oneflow as flow\n\n\ndef compare_with_numpy_sgd(\n    test_case, device, x_shape, tensor_num, learning_rate, train_iters, weight_decay\n):\n    random_weight_seq = []\n    init_value_seq = []\n\n    for _ in range(train_iters):\n        random_grad_seq_per_iter = []\n        for i in range(tensor_num):\n            random_grad_seq_per_iter.append(\n                np.random.uniform(size=x_shape).astype(np.float32)\n            )\n        random_weight_seq.append(random_grad_seq_per_iter)\n\n    for i in range(tensor_num):\n        init_value_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n\n    class CustomModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.add_parameters()\n\n        def add_parameters(self) -> None:\n            for idx in range(tensor_num):\n                self.register_parameter(\n                    f\"param_{idx}\",\n                    flow.nn.Parameter(\n                        flow.tensor(init_value_seq[idx], device=flow.device(device))\n                    ),\n                )\n\n        def param(self, i):\n            return getattr(self, f\"param_{i}\")\n\n        def forward(self, mask_list):\n            out = 0\n            for idx in range(tensor_num):\n                out += flow._C.matmul(self.param(idx), mask_list[idx])\n\n            return out\n\n    simp_module = CustomModule()\n    simp_module.to(device)\n    simp_module.train()\n\n    sgd0 = flow.optim.SGD(\n        [\n            {\n                \"params\": simp_module.parameters(),\n                \"lr\": learning_rate,\n                \"weight_decay\": weight_decay,\n            }\n        ],\n    )\n\n    class CustomSGDGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = simp_module\n            self.add_optimizer(sgd0)\n            self.config.enable_amp(True)\n            self.config.allow_fuse_model_update_ops(True)\n            self.config.enable_multi_tensor_update(True)\n            self.config.enable_fused_model_update_cast(True)\n\n        def build(self, mask_tensor_list):\n            loss = flow.sum(self.m(mask_tensor_list))\n            loss.backward()\n            return loss\n\n    of_res_list = []\n    sgd_graph = CustomSGDGraph()\n    for i in range(train_iters):\n        mask_tensor_list = []\n        for idx in range(tensor_num):\n            mask_tensor_list.append(\n                flow.tensor(\n                    random_weight_seq[i][idx],\n                    dtype=flow.float32,\n                    requires_grad=False,\n                    device=flow.device(device),\n                )\n            )\n        sgd_x = sgd_graph(mask_tensor_list)\n        of_res_list.append([])\n        
for idx in range(tensor_num):\n            of_res_list[i].append(copy.copy(simp_module.param(idx).numpy()))\n\n    np_res_list = []\n\n    def train_by_numpy():\n        x = init_value_seq\n        ones = np.ones(x_shape).astype(np.float32)\n\n        def train_one_iter(weight):\n            for i in range(tensor_num):\n                transposed_weight = np.transpose(weight[i], (1, 0))\n                grad = np.matmul(ones, transposed_weight)\n                grad = grad + weight_decay * x[i]\n                x[i] = x[i] - learning_rate * grad\n            return x\n\n        for i in range(train_iters):\n            x = train_one_iter(random_weight_seq[i])\n            np_res_list.append(copy.copy(x))\n\n    train_by_numpy()\n    # both result lists are indexed by training iteration, so compare every iteration\n    for i in range(train_iters):\n        test_case.assertTrue(\n            np.allclose(np_res_list[i], of_res_list[i], rtol=1e-3, atol=1e-3)\n        )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestMultiTensorSGD(flow.unittest.TestCase):\n    def test_multi_tensor_sgd(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\"]\n        arg_dict[\"x_shape\"] = [(4, 4)]\n        arg_dict[\"tensor_num\"] = [4, 6]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"weight_decay\"] = [0.0, 1e-3]\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_sgd(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_nccl_logical_send_recv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport oneflow\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport time\nimport os\n\n\ndef _test_nccl_logical_send_recv_2d(test_case, src_nd_sbp, dst_nd_sbp):\n    # can not process p in dst\n    if flow.sbp.partial_sum() in dst_nd_sbp:\n        return\n\n    # skip src == dst\n    if src_nd_sbp == dst_nd_sbp:\n        return\n\n    # in this case, use intra group boxing\n    if src_nd_sbp[0] == dst_nd_sbp[0]:\n        return\n\n    # in this case, use inter group boxing\n    if (\n        src_nd_sbp[1] == dst_nd_sbp[1]\n        and src_nd_sbp[0] != src_nd_sbp[1]\n        and dst_nd_sbp[0] != dst_nd_sbp[1]\n    ):\n        return\n\n    # input\n    placement = flow.placement(\"cuda\", ranks=[[0, 1], [2, 3]])\n    local_np = np.arange(4 * 4 * 4).reshape(4, 4, 4)\n    x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement)\n\n    # check eager boxing\n    eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement)\n    test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy()))\n\n    # check graph boxing\n    flow.boxing.nccl.enable_use_compute_stream(True)\n\n    class TestNcclLogicalSendRecv2DGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, x):\n            y = x.to_global(sbp=dst_nd_sbp, placement=placement)\n            return y\n\n    graph = TestNcclLogicalSendRecv2DGraph()\n    # graph.debug()\n    y = graph(x)\n    out_np = y.numpy()\n    in_np = x.numpy()\n    # if flow.env.get_rank() == 0:\n    #    print(\"src sbp \", src_nd_sbp, \", dst sbp \", dst_nd_sbp)\n    #    equal = np.array_equal(out_np, in_np)\n    #    if not equal:\n    #        print(\"in \", in_np)\n    #        print(\"out \", out_np)\n    test_case.assertTrue(np.array_equal(out_np, in_np))\n\n    flow.boxing.nccl.enable_use_compute_stream(False)\n\n\ndef gen_2d_sbp():\n    sbp_list = [\n        flow.sbp.partial_sum(),\n        flow.sbp.broadcast(),\n        flow.sbp.split(0),\n        flow.sbp.split(1),\n        flow.sbp.split(2),\n    ]\n    nd_sbp_list = []\n    for sbp0 in sbp_list:\n        for sbp1 in sbp_list:\n            nd_sbp_list.append([sbp0, sbp1])\n    return nd_sbp_list\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestNcclLogicalSendRecv2D(flow.unittest.TestCase):\n    def test_nccl_logical_send_recv_2d(test_case):\n        os.environ[\"ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK\"] = \"1\"\n        arg_dict = OrderedDict()\n        arg_dict[\"src_nd_sbp\"] = gen_2d_sbp()\n        arg_dict[\"dst_nd_sbp\"] = gen_2d_sbp()\n        for arg in GenArgList(arg_dict):\n            _test_nccl_logical_send_recv_2d(test_case, *arg)\n\n\ndef _test_nccl_logical_send_recv_1d(test_case, src_nd_sbp, 
dst_nd_sbp):\n    # can not process p in dst\n    if flow.sbp.partial_sum() in dst_nd_sbp:\n        return\n\n    # skip src == dst\n    if src_nd_sbp == dst_nd_sbp:\n        return\n\n    # input\n    placement = flow.placement(\"cuda\", ranks=[0, 1])\n    local_np = np.arange(2 * 2 * 2).reshape(2, 2, 2)\n    x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement)\n\n    # check eager boxing\n    eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement)\n    test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy()))\n\n    # check graph boxing\n    flow.boxing.nccl.enable_use_compute_stream(True)\n\n    class TestNcclLogicalSendRecv1DGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, x):\n            y = x.to_global(sbp=dst_nd_sbp, placement=placement)\n            return y\n\n    graph = TestNcclLogicalSendRecv1DGraph()\n    # graph.debug(0)\n    y = graph(x)\n    out_np = y.numpy()\n    in_np = x.numpy()\n    # if flow.env.get_rank() == 0:\n    #    print(\"src sbp \", src_nd_sbp, \", dst sbp \", dst_nd_sbp)\n    #    print(graph)\n    #    equal = np.array_equal(out_np, in_np)\n    #    if not equal:\n    #        print(\"in \", in_np)\n    #        print(\"out \", out_np)\n    #    print(\"====================\")\n    test_case.assertTrue(np.array_equal(out_np, in_np))\n\n\ndef gen_1d_sbp():\n    sbp_list = [\n        flow.sbp.partial_sum(),\n        flow.sbp.broadcast(),\n        flow.sbp.split(0),\n        flow.sbp.split(1),\n        flow.sbp.split(2),\n    ]\n    nd_sbp_list = []\n    for sbp0 in sbp_list:\n        nd_sbp_list.append(\n            [sbp0,]\n        )\n    return nd_sbp_list\n\n\n@flow.unittest.skip_unless_1n2d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestNcclLogicalSendRecv1D(flow.unittest.TestCase):\n    def test_nccl_logical_send_recv_1d(test_case):\n        os.environ[\"ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK\"] = \"1\"\n        arg_dict = OrderedDict()\n        arg_dict[\"src_nd_sbp\"] = gen_1d_sbp()\n        arg_dict[\"dst_nd_sbp\"] = gen_1d_sbp()\n        for arg in GenArgList(arg_dict):\n            _test_nccl_logical_send_recv_1d(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_neq_device_process_num.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\n\nimport numpy as np\n\nimport oneflow\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.sysconfig\nfrom oneflow.nn.graph import GraphModule\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGraphNeqDeviceProcessNum(flow.unittest.TestCase):\n    def test_graph_process_num_greater_than_device(test_case):\n        # NOTE(chengcheng): this test case is ONLY for 1n8d in 4d env.\n        if not (flow.env.get_node_size() == 1 and flow.env.get_world_size() == 8):\n            return\n        if not oneflow.sysconfig.has_rpc_backend_grpc():\n            return\n\n        BATCH_SIZE = 64\n        BROADCAST = [flow.sbp.broadcast]\n        P0 = flow.placement(\"cpu\", ranks=[0, 1, 2, 3])\n        P1 = flow.placement(\"cpu\", ranks=[4, 5, 6, 7])\n\n        class Stage0Module(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.flatten = flow.nn.Flatten()\n                self.linear0 = flow.nn.Linear(28 * 28, 512)\n                self.relu0 = flow.nn.ReLU()\n\n            def forward(self, x):\n                out = self.flatten(x)\n                out = self.linear0(out)\n                out = self.relu0(out)\n                return out\n\n        class Stage1Module(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear1 = flow.nn.Linear(512, 512)\n                self.relu1 = flow.nn.ReLU()\n                self.linear2 = flow.nn.Linear(512, 10)\n                self.relu2 = flow.nn.ReLU()\n\n            def forward(self, x):\n                out = self.linear1(x)\n                out = self.relu1(out)\n                out = self.linear2(out)\n                out = self.relu2(out)\n                return out\n\n        class PipelineModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.m_stage0 = Stage0Module()\n                self.m_stage1 = Stage1Module()\n\n                self.m_stage0.to_global(placement=P0, sbp=BROADCAST)\n                self.m_stage1.to_global(placement=P1, sbp=BROADCAST)\n\n            def forward(self, x):\n                out_stage0 = self.m_stage0(x)\n                in_stage1 = out_stage0.to_global(placement=P1, sbp=flow.sbp.split(0))\n                out_stage1 = self.m_stage1(in_stage1)\n                return out_stage1\n\n        module_pipeline = PipelineModule()\n        sgd = flow.optim.SGD(module_pipeline.parameters(), lr=0.001)\n\n        class PipelineGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.module_pipeline = module_pipeline\n                self.module_pipeline.m_stage0.to(GraphModule).set_stage(0)\n                self.module_pipeline.m_stage1.to(GraphModule).set_stage(1)\n                self.loss_fn = 
flow.nn.CrossEntropyLoss(reduction=\"none\")\n                self.config.set_gradient_accumulation_steps(2)\n                self.add_optimizer(sgd)\n\n            def build(self, x, y):\n                out = self.module_pipeline(x)\n                loss = self.loss_fn(out, y).sum()\n                loss = loss.to_global(placement=P1, sbp=BROADCAST)\n                loss.backward()\n                return loss\n\n        graph_pipeline = PipelineGraph()\n        graph_pipeline.debug(1)\n\n        x = flow.randn(BATCH_SIZE, 1, 28, 28)\n        x = x.to_global(P0, sbp=flow.sbp.split(0))\n        y = flow.randint(0, 10, (BATCH_SIZE, 1))\n        y = y.to_global(P1, sbp=flow.sbp.split(0))\n\n        for i in range(2):\n            loss = graph_pipeline(x, y)\n            print(\">>>>>>>\", flow.env.get_rank(), loss.to_local().numpy(), flush=True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_oneflow_compiler.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nimport torch\nfrom oneflow.framework.infer_compiler import compile_from_torch, register\nfrom oneflow.framework.infer_compiler.with_oneflow_compile import (\n    DualModule,\n    DualModuleList,\n)\n\n\nclass TorchModule(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linears = torch.nn.ModuleList([torch.nn.Linear(10, 10) for _ in range(10)])\n\n    def forward(self, x):\n        for i, l in enumerate(self.linears):\n            x = self.linears[i // 2](x) + l(x)\n        return x\n\n\nclass FlowModule(flow.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linears = flow.nn.ModuleList([flow.nn.Linear(10, 10) for _ in range(10)])\n\n    def forward(self, x):\n        for i, l in enumerate(self.linears):\n            x = self.linears[i // 2](x) + l(x)\n        return x\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestOneflowInferCompiler(flow.unittest.TestCase):\n    def setUp(self):\n        os.environ[\"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\"] = \"1\"\n\n    def test_compile_from_torch(test_case):\n        register(torch2oflow_class_map={TorchModule: FlowModule})\n\n        m = TorchModule().to(\"cuda\")\n        x = torch.randn(2, 10).to(\"cuda\")\n\n        y_torch = m(x)\n        m = compile_from_torch(m)\n        y_flow = m(x)\n        test_case.assertTrue(\n            np.allclose(y_torch.detach().cpu(), y_flow.detach().cpu(), 1e-03, 1e-03)\n        )\n        test_case.assertIsInstance(m.linears, DualModuleList)\n\n        x = getattr(m.linears, \"1\")\n        test_case.assertIsInstance(x, DualModule)\n\n        x.bias = None\n        setattr(m.linears, \"2\", x)\n        test_case.assertIsNone(m.linears[2].bias)\n        test_case.assertIsNone(m.linears._torch_modules[2].bias)\n        test_case.assertIsNone(m.linears._oneflow_modules[2].bias)\n\n        m.linears[3] = x\n        test_case.assertIsNone(m.linears[3].bias)\n        test_case.assertIsNone(m.linears._torch_modules[3].bias)\n        test_case.assertIsNone(m.linears._oneflow_modules[3].bias)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_optimization_conf.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport oneflow.framework.session_context as session_ctx\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.framework.config_util as config_util\nimport oneflow.framework.attr_util as attr_util\nimport random\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestGraphWithSysConf(flow.unittest.TestCase):\n    def test_graph_config(test_case):\n        flow.boxing.enable_fusion(True)\n\n        flow.boxing.nccl.set_fusion_threshold_mbytes(800)\n        flow.boxing.nccl.set_fusion_max_ops_num(10)\n        flow.boxing.nccl.allow_fuse_all_reduce(True)\n        flow.boxing.nccl.allow_fuse_reduce_scatter(True)\n        flow.boxing.nccl.allow_fuse_all_gather(True)\n        flow.boxing.nccl.allow_fuse_reduce(True)\n        flow.boxing.nccl.allow_fuse_broadcast(True)\n        flow.boxing.nccl.allow_fuse_mixed_ops(True)\n        flow.boxing.nccl.enable_use_buffer_to_fuse_all_reduce(True)\n        flow.boxing.nccl.set_stream_num(3)\n        flow.boxing.nccl.enable_all_to_all(True)\n        flow.boxing.nccl.enable_use_compute_stream(True)\n        flow.boxing.nccl.disable_group_boxing_by_dst_parallel(True)\n\n        flow.backends.cudnn.set_reserved_mem_mbytes(1000)\n        flow.backends.cudnn.enable_fused_normalization_add_relu(True)\n        flow.backends.cudnn.enable_conv_heuristic_search_algo(False)\n\n        flow.utils.load_library(\"\")\n\n        class CustomGraphSysConf(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                # amp\n                self.config.enable_amp(True)\n                grad_scaler = flow.amp.GradScaler(\n                    init_scale=3000,\n                    growth_factor=2.0,\n                    backoff_factor=0.5,\n                    growth_interval=1000,\n                )\n                self.set_grad_scaler(grad_scaler)\n\n                self.config.allow_fuse_model_update_ops(True)\n                self.config.allow_fuse_add_to_output(True)\n                self.config.set_gradient_accumulation_steps(100)\n                self.config.allow_fuse_cast_scale(True)\n                self.config.enable_zero(True)\n                self.config.enable_cudnn_conv_heuristic_search_algo(False)\n\n            def build(self, x):\n                return x\n\n        g = CustomGraphSysConf()\n\n        print(\"optimization conf: \\n\", g._optimization_conf_proto)\n        test_case.assertTrue(g._optimization_conf_proto.nccl_use_compute_stream)\n        g._generate_config_proto()\n        print(\"graph conf: \\n\", g._config_proto)\n\n        # Test the resource config update eagerly\n        # Note: this tests all the apis in oneflow.framework.config_util automatically\n        def test_resource_config_update_apis_eagerly_automatically():\n            attrs_and_values_to_check = []\n         
   num_api_tested = 0\n\n            for api in config_util.api_attrs_and_type.keys():\n                attrs, type_ = config_util.api_attrs_and_type[api]\n                if type_ is int:\n                    attr_value = random.randint(0, 9999)\n                    attrs_and_values_to_check.append((attrs, attr_value))\n                elif type_ is bool:\n                    attr_value = random.choice([True, False])\n                    attrs_and_values_to_check.append((attrs, attr_value))\n                else:\n                    raise TypeError(\"Unsupported type!\")\n\n                api(attr_value)\n                num_api_tested += 1\n\n            # check all the attributes are set correctly\n            for (attrs, expected_attr_value) in attrs_and_values_to_check:\n                current_attr_value = attr_util.get_nested_attribute(\n                    g._optimization_conf_proto, attrs\n                )\n                test_case.assertTrue(\n                    current_attr_value == expected_attr_value,\n                    str(attrs)\n                    + \" : \"\n                    + str(current_attr_value)\n                    + \" vs \"\n                    + str(expected_attr_value),\n                )\n\n            print(\"number of APIs tested: \" + str(num_api_tested))\n\n        # save the resource config before running random resource api tests\n        session = session_ctx.GetDefaultSession()\n        prev_resource_config = session.resource\n\n        for i in range(5):\n            test_resource_config_update_apis_eagerly_automatically()\n\n        print(\"optimization conf after session init: \\n\", g._optimization_conf_proto)\n\n        # restore the resource config\n        session.update_resource_eagerly(prev_resource_config)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_output_op_expr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\n\nimport numpy as np\nfrom google.protobuf import text_format\n\nimport oneflow\nimport oneflow as flow\nimport oneflow._oneflow_internal\nimport oneflow._oneflow_internal._C as _C\nimport oneflow.framework.c_api_util as c_api_util\nimport oneflow.framework.session_context as session_ctx\nimport oneflow.unittest\nfrom oneflow.framework.multi_client_session import MultiClientSession\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFetchOutputTensor(unittest.TestCase):\n    def test_fetch_output_tensor(test_case):\n        x = flow.Tensor(1, 1, 10, 10)\n        flow.nn.init.uniform_(x, a=-1.0, b=1.0)\n        session = session_ctx.GetDefaultSession()\n        test_case.assertTrue(isinstance(session, MultiClientSession))\n        session.TryInit()\n        with oneflow._oneflow_internal.lazy_mode.guard(True):\n            oneflow._oneflow_internal.JobBuildAndInferCtx_Open(\n                \"cc_test_output_op_expr_job\"\n            )\n            job_conf = oneflow.core.job.job_conf_pb2.JobConfigProto()\n            job_conf.job_name = \"cc_test_output_op_expr_job\"\n            job_conf.predict_conf.SetInParent()\n            c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)\n            input_conf = oneflow.core.operator.op_conf_pb2.FeedInputOpConf()\n            input_conf.in_0 = \"EagerTensorInput\"\n            input_conf.out_0 = \"out_0\"\n            input_conf_str = text_format.MessageToString(input_conf)\n            input_op = oneflow._oneflow_internal.one.FeedInputOpExpr(\n                \"cc_Input_0\", input_conf_str, [\"in_0\"], [\"out_0\"]\n            )\n            output_conf = oneflow.core.operator.op_conf_pb2.FetchOutputOpConf()\n            output_conf.in_0 = \"LazyTensorInput\"\n            output_conf.out_0 = \"out_0\"\n            output_conf_str = text_format.MessageToString(output_conf)\n            output_op = oneflow._oneflow_internal.one.FetchOutputOpExpr(\n                \"cc_Output_0\", output_conf_str, [\"in_0\"], [\"out_0\"]\n            )\n            lazy_tensor = _C.dispatch_feed_input(input_op, x)\n            test_case.assertEqual(lazy_tensor.shape, (1, 1, 10, 10))\n            test_case.assertTrue(lazy_tensor.is_lazy)\n            test_case.assertTrue(lazy_tensor.is_local)\n            eager_tensor = _C.dispatch_fetch_output(output_op, lazy_tensor)\n            test_case.assertEqual(eager_tensor.shape, (1, 1, 10, 10))\n            test_case.assertTrue(not eager_tensor.is_lazy)\n            test_case.assertTrue(eager_tensor.is_local)\n            oneflow._oneflow_internal.JobBuildAndInferCtx_Close()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_run_global_graph_by_vm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport os\nimport oneflow as flow\nimport oneflow.unittest\nimport numpy as np\nfrom test_run_graph_by_vm import RunGraphByVmEnv, Graph\nfrom test_graph_ofrecord_reader import OFRecordDataLoader\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGlobalInterpreter(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n2d()\n    def test_data_parallel_run_by_vm(test_case):\n        with RunGraphByVmEnv():\n\n            class DataParallelMul(flow.nn.Module):\n                def __init__(self, placement) -> None:\n                    super().__init__()\n                    self.w = flow.randn(\n                        5, 8, placement=placement, sbp=flow.sbp.broadcast\n                    )\n\n                def forward(self, x):\n                    return flow.matmul(x, self.w)\n\n            placement = flow.placement(\"cuda\", [0, 1])\n\n            m = DataParallelMul(placement).eval()\n            g = Graph(m)\n\n            input = flow.randn(4, 5, placement=placement, sbp=flow.sbp.split(0))\n            graph_output = g(input)\n            eager_output = m(input)\n\n            test_case.assertTrue(graph_output.sbp == eager_output.sbp)\n            test_case.assertTrue(graph_output.shape == eager_output.shape)\n            test_case.assertTrue(graph_output.placement == eager_output.placement)\n            test_case.assertTrue(np.allclose(graph_output, eager_output))\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_module_parallel_run_by_vm(test_case):\n        with RunGraphByVmEnv():\n\n            class ModuleParallelMul(flow.nn.Module):\n                def __init__(self, placement) -> None:\n                    super().__init__()\n                    self.w = flow.randn(\n                        5, 8, placement=placement, sbp=flow.sbp.split(1)\n                    )\n\n                def forward(self, x):\n                    return flow.matmul(x, self.w)\n\n            placement = flow.placement(\"cuda\", [0, 1])\n            m = ModuleParallelMul(placement).eval()\n            g = Graph(m)\n\n            input = flow.randn(4, 5, placement=placement, sbp=flow.sbp.broadcast)\n            graph_output = g(input)\n            eager_output = m(input)\n\n            test_case.assertTrue(graph_output.sbp == eager_output.sbp)\n            test_case.assertTrue(graph_output.shape == eager_output.shape)\n            test_case.assertTrue(graph_output.placement == eager_output.placement)\n            test_case.assertTrue(np.allclose(graph_output, eager_output))\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_boxing_data_parallel_run_by_vm(test_case):\n        with RunGraphByVmEnv():\n            flow.boxing.nccl.enable_use_compute_stream(False)\n\n            class BoxingModuleParallelMul(flow.nn.Module):\n                def __init__(self, placement) -> None:\n                    
super().__init__()\n                    self.w1 = flow.randn(\n                        5, 8, placement=placement, sbp=flow.sbp.split(1)\n                    )\n                    self.w2 = flow.randn(\n                        8, 6, placement=placement, sbp=flow.sbp.split(1)\n                    )\n\n                def forward(self, x):\n                    x = flow.matmul(x, self.w1)\n                    x = flow.matmul(x, self.w2)\n                    return x\n\n            placement = flow.placement(\"cuda\", [0, 1])\n            m = BoxingModuleParallelMul(placement).eval()\n            g = Graph(m)\n\n            input = flow.randn(4, 5, placement=placement, sbp=flow.sbp.broadcast)\n            graph_output = g(input)\n            eager_output = m(input)\n\n            test_case.assertTrue(graph_output.sbp == eager_output.sbp)\n            test_case.assertTrue(graph_output.shape == eager_output.shape)\n            test_case.assertTrue(graph_output.placement == eager_output.placement)\n            test_case.assertTrue(np.allclose(graph_output, eager_output))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_empty_inputs(test_case):\n        with RunGraphByVmEnv():\n\n            class GraphReader(flow.nn.Graph):\n                def __init__(self):\n                    super().__init__()\n                    self.my_reader = OFRecordDataLoader()\n\n                def build(self):\n                    return self.my_reader()\n\n            reader_g = GraphReader()\n            image, label = reader_g()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_run_graph_by_vm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport oneflow as flow\nimport numpy as np\n\n\nclass EnvVar(object):\n    def __init__(self, env_list: dict):\n        self.env_list = env_list\n\n    def __enter__(self):\n        os.environ.update(self.env_list)\n\n    def __exit__(self, *args):\n        for key in self.env_list.keys():\n            if key in os.environ.keys():\n                os.environ.pop(key)\n\n\nclass RunGraphByVmEnv(EnvVar):\n    def __init__(self):\n        super().__init__(\n            {\n                \"ONEFLOW_RUN_GRAPH_BY_VM\": \"1\",\n                \"ONEFLOW_MLIR_ENABLE_ROUND_TRIP\": \"1\",\n                \"ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION\": \"1\",\n            }\n        )\n\n\nclass Graph(flow.nn.Graph):\n    def __init__(self, m):\n        super().__init__()\n        self.m = m\n\n    def build(self, x):\n        return self.m(x)\n\n\nclass M(flow.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.w = flow.nn.Parameter(flow.randn(4))\n\n    def forward(self, x):\n        # these broadcast_sub and cast ops will be\n        # eliminated by nn.Graph\n        w1 = self.w - self.w - self.w\n        x = x * w1.to(flow.float32)\n        return x\n\n\ndef test_run_graph_by_vm(capsys):\n    with RunGraphByVmEnv():\n        m = M().eval()\n        g = Graph(m)\n\n        input = flow.randn(4)\n        graph_output = g(input)\n        eager_output = m(input)\n        assert graph_output.shape == (4,)\n        assert np.allclose(graph_output, eager_output)\n\n        input = flow.randn(3, 4)\n        graph_output = g(input)\n        eager_output = m(input)\n        assert graph_output.shape == (3, 4)\n        assert np.allclose(graph_output, eager_output)\n\n        # Test the optimization in graph works.\n        # broadcast_sub and cast ops are pruned.\n        print(g)\n        assert \"broadcast_sub\" not in capsys.readouterr().out\n        assert \"cast\" not in capsys.readouterr().out\n        assert \"broadcast_mul\" not in capsys.readouterr().out\n"
  },
  {
    "path": "python/oneflow/test/graph/test_to_global.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport os\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nx = np.array(\n    [\n        [\n            0.21490018,\n            0.22043167,\n            0.1605895,\n            0.25424683,\n            0.12975895,\n            0.49967155,\n            0.04753795,\n            0.7518577,\n            0.38964537,\n            0.01955934,\n        ],\n        [\n            0.16392729,\n            0.41410774,\n            0.05424517,\n            0.7668146,\n            0.08050849,\n            0.5763975,\n            0.42364502,\n            0.4950619,\n            0.9608427,\n            0.11889187,\n        ],\n    ]\n)\n\ny = np.array(\n    [\n        [\n            0.9903706,\n            0.11213686,\n            0.29525927,\n            0.79380244,\n            0.70357895,\n            0.6950597,\n            0.52552456,\n            0.32304054,\n            0.6997739,\n            0.15671141,\n        ],\n        [\n            0.76867193,\n            0.59983397,\n            0.07774717,\n            0.07815815,\n            0.30385414,\n            0.7366552,\n            0.4607681,\n            0.40554753,\n            0.8290172,\n            0.8405671,\n        ],\n        [\n            0.8900324,\n            0.5274955,\n            0.80989295,\n            0.71331054,\n            0.8076364,\n            0.94833183,\n            0.04778554,\n            0.23992656,\n            0.57683426,\n            0.81757474,\n        ],\n    ]\n)\n\n\nclass MyModule1(flow.nn.Module):\n    def __init__(self, weight):\n        assert isinstance(weight, flow._oneflow_internal.Tensor)\n        super().__init__()\n        self.weight = flow.nn.Parameter(weight)\n        self.activation = flow.nn.ReLU()\n\n    def forward(self, x):\n        # print(f\"x shape: {x.shape}, placement: {x.placement}, sbp: {x.sbp}\")\n        # print(\n        #     f\"weight shape: {self.weight.shape}, placement: {self.weight.placement}, sbp: {self.weight.sbp}\"\n        # )\n        y = flow._C.matmul(x, self.weight, transpose_b=True)\n        # print(f\"y shape: {y.shape}, placement: {y.placement}, sbp: {y.sbp}\")\n        if y.is_global:\n            y = y.to_global(sbp=flow.sbp.broadcast)\n            # print(f\"post y shape: {y.shape}, placement: {y.placement}, sbp: {y.sbp}\")\n        return self.activation(y)\n\n\nclass MyModule2(flow.nn.Module):\n    def __init__(self, weight):\n        assert isinstance(weight, flow._oneflow_internal.Tensor)\n        super().__init__()\n        self.weight = flow.nn.Parameter(weight)\n        self.activation = flow.nn.ReLU()\n\n    def forward(self, x):\n        # print(f\"weight shape: {self.weight.shape}, placement: {self.weight.placement}, sbp: {self.weight.sbp}\")\n        if self.weight.is_global:\n            y = self.weight.to_global(grad_sbp=flow.sbp.broadcast)\n        z = flow._C.matmul(y, x, 
transpose_b=True)\n        out = self.activation(z).sum()\n        if self.weight.is_global:\n            out = out.to_global(sbp=flow.sbp.broadcast)\n        return out\n\n\nclass MyModule3(flow.nn.Module):\n    def __init__(self, transpose_a=False, transpose_b=False):\n        super().__init__()\n        self.activation = flow.nn.ReLU()\n        self.transpose_a = transpose_a\n        self.transpose_b = transpose_b\n\n    def forward(self, x, y):\n        z = flow._C.matmul(x, y, self.transpose_a, self.transpose_b)\n        if z.is_global:\n            z = z.to_global(sbp=flow.sbp.broadcast)\n        return self.activation(z)\n\n\nclass GlobalToModule(flow.nn.Module):\n    def __init__(self, device=\"cuda\"):\n        super().__init__()\n        self.device = device\n\n    def forward(self, x):\n        return x.to(self.device)\n\n\nclass FreeTensorModule(flow.nn.Module):\n    def __init__(self, shape, placement, sbp):\n        super().__init__()\n        self.shape = shape\n        self.placement = placement\n        self.sbp = sbp\n\n    def forward(self, x):\n        y = flow.ones(\n            self.shape, dtype=flow.float32, placement=self.placement, sbp=self.sbp\n        )\n        return flow._C.matmul(x, y, transpose_b=True)\n\n\nclass ToPlacementModule(flow.nn.Module):\n    def __init__(self, placement):\n        super().__init__()\n        self.placement = placement\n\n    def forward(self, x):\n        return x.to_global(placement=self.placement)\n\n\nclass MyGraph(flow.nn.Graph):\n    def __init__(self, module, optimizer=None):\n        super().__init__()\n        self.module = module\n        if optimizer is not None:\n            self.add_optimizer(optimizer)\n\n    def build(self, *arg):\n        y = self.module(*arg)\n        if self.config.training:\n            y.backward()\n        return y\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass ToGlobalGraphTestCase(oneflow.unittest.TestCase):\n    # @unittest.skipIf(True, \"\")\n    def test_fwd_P2B(test_case):\n        \"\"\" compare eager fwd and lazy fwd\n        \"\"\"\n        rank = flow.env.get_rank()\n        # pid = os.getpid()\n        # print(f\"[{pid}][{rank}] ToGlobalGraphTestCase.test_fwd_P2B\")\n\n        local_x = flow.tensor(x, dtype=flow.float32, device=flow.device(f\"cuda:{rank}\"))\n        local_y = flow.tensor(y, dtype=flow.float32, device=flow.device(f\"cuda:{rank}\"))\n\n        z = flow._C.matmul(\n            flow.cat([local_x, local_x], dim=1),\n            flow.cat([local_y, local_y], dim=1),\n            transpose_b=True,\n        )\n        z = flow._C.relu(z)\n        # print(f\"z shape: {z.shape}, device: {z.device}\")\n        # print(z.numpy())\n\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        sbp = flow.sbp.split(1)\n        c_x = local_x.to_global(placement=placement, sbp=sbp)\n        c_y = local_y.to_global(placement=placement, sbp=sbp)\n\n        # print(f\"c_x shape: {c_x.shape}, placement: {c_x.placement}, sbp: {c_x.sbp}\")\n        # print(f\"c_y shape: {c_y.shape}, placement: {c_y.placement}, sbp: {c_y.sbp}\")\n\n        m = MyModule1(c_y)\n        g = MyGraph(m)\n\n        g_z = g(c_x)\n        # print(f\"g_z shape: {g_z.shape}, placement: {g_z.placement}, sbp: {g_z.sbp}\")\n        # print(g_z.to_local().numpy())\n        test_case.assertTrue(np.allclose(z.numpy(), g_z.to_local().numpy()))\n\n    # @unittest.skipIf(True, \"\")\n    def test_bwd_P2B(test_case):\n        \"\"\" compare 
eager bwd and lazy bwd\n        \"\"\"\n        rank = flow.env.get_rank()\n        # pid = os.getpid()\n        # print(f\"[{pid}][{rank}] ToGlobalGraphTestCase.test_bwd_P2B\")\n\n        local_x = flow.tensor(x, dtype=flow.float32, device=flow.device(f\"cuda:{rank}\"))\n        local_y = flow.tensor(y, dtype=flow.float32, device=flow.device(f\"cuda:{rank}\"))\n\n        z = flow._C.matmul(\n            local_y, flow.cat([local_x, local_x], dim=0), transpose_b=True,\n        )\n        z = flow._C.relu(z)\n        z = z.sum()\n\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        c_x = local_x.to_global(placement=placement, sbp=flow.sbp.split(0))\n        c_y = local_y.to_global(placement=placement, sbp=flow.sbp.broadcast)\n\n        m = MyModule2(c_y)\n        optimizer = flow.optim.SGD(m.parameters(), lr=1.0)\n        g = MyGraph(m, optimizer)\n\n        g_z = g(c_x)\n        # print(f\"g_z shape: {g_z.shape}, placement: {g_z.placement}, sbp: {g_z.sbp}\")\n        test_case.assertTrue(g_z.is_global)\n        test_case.assertTrue(g_z.sbp[0] == flow.sbp.broadcast)\n        # S(1) -> B not supported yet\n        # c_z = g_z.to_global(sbp=flow.sbp.broadcast)\n        # print(f\"c_z shape: {c_z.shape}, placement: {c_z.placement}, sbp: {c_z.sbp}\")\n        test_case.assertTrue(np.allclose(z.numpy(), g_z.to_local().numpy()))\n\n        e_y = c_y.detach()\n        # print(f\"e_y shape: {e_y.shape}, placement: {e_y.placement}, sbp: {e_y.sbp}\")\n        e_m = MyModule2(e_y)\n        e_z = e_m(c_x)\n        # print(f\"e_z shape: {e_z.shape}, placement: {e_z.placement}, sbp: {e_z.sbp}\")\n        e_z.backward()\n\n        test_case.assertTrue(\n            np.allclose(c_y.to_local().numpy(), e_y.to_local().numpy())\n        )\n\n    # @unittest.skipIf(True, \"\")\n    def test_multi_graph(test_case):\n        \"\"\" compare two lazy fwd\n        \"\"\"\n        rank = flow.env.get_rank()\n        # pid = os.getpid()\n        # print(f\"[{pid}][{rank}] ToGlobalGraphTestCase.test_multi_graph\")\n\n        local_x = flow.tensor(x, dtype=flow.float32, device=flow.device(f\"cuda:{rank}\"))\n        local_y = flow.tensor(y, dtype=flow.float32, device=flow.device(f\"cuda:{rank}\"))\n\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        x1 = local_x.to_global(placement=placement, sbp=flow.sbp.broadcast)\n        y1 = local_y.to_global(placement=placement, sbp=flow.sbp.broadcast)\n        # B * B -> B -> B\n        m1 = MyModule3(transpose_b=True)\n        g1 = MyGraph(m1)\n\n        slice_obj = slice(\n            int(rank * local_x.shape[0] / 2), int((rank + 1) * local_x.shape[0] / 2)\n        )\n        x2 = local_x[slice_obj, :]\n        x2 = x2.to_global(placement=placement, sbp=flow.sbp.split(0))\n        y2 = local_y.to_global(placement=placement, sbp=flow.sbp.broadcast)\n        # S(0) * B -> S(0) -> B\n        m2 = MyModule3(transpose_b=True)\n        g2 = MyGraph(m2)\n\n        x3 = local_x[\n            :, int(rank * local_x.shape[1] / 2) : int((rank + 1) * local_x.shape[1] / 2)\n        ]\n        x3 = x3.to_global(placement=placement, sbp=flow.sbp.split(1))\n        y3 = local_y[\n            :, int(rank * local_y.shape[1] / 2) : int((rank + 1) * local_y.shape[1] / 2)\n        ]\n        y3 = y3.to_global(placement=placement, sbp=flow.sbp.split(1))\n        # S(1) * S(0) -> P -> B\n        m3 = MyModule3(transpose_b=True)\n        g3 = MyGraph(m3)\n\n        z1 = g1(x1, y1)\n        # print(f\"z1 shape: {z1.shape}, placement: {z1.placement}, sbp: {z1.sbp}\")\n  
      # print(z1.to_local().numpy())\n        z2 = g2(x2, y2)\n        # print(f\"z2 shape: {z2.shape}, placement: {z2.placement}, sbp: {z2.sbp}\")\n        # print(z2.to_local().numpy())\n        z3 = g3(x3, y3)\n        # print(f\"z3 shape: {z3.shape}, placement: {z3.placement}, sbp: {z3.sbp}\")\n        # print(z3.to_local().numpy())\n\n        test_case.assertTrue(np.allclose(z1.to_local().numpy(), z2.to_local().numpy()))\n        test_case.assertTrue(np.allclose(z1.to_local().numpy(), z3.to_local().numpy()))\n\n    # @unittest.skipIf(True, \"\")\n    def test_global_to(test_case):\n        c_x = flow.ones(\n            (4, 3), placement=flow.placement(\"cpu\", ranks=[0, 1]), sbp=flow.sbp.split(0)\n        )\n\n        global_to = GlobalToModule(\"cuda\")\n        g_global_to = MyGraph(global_to)\n\n        e = global_to(c_x)\n        test_case.assertTrue(e.is_cuda)\n        test_case.assertTrue(e.is_global)\n        test_case.assertTrue(e.sbp[0] == flow.sbp.split(0))\n\n        g = g_global_to(c_x)\n        test_case.assertTrue(g.is_cuda)\n        test_case.assertTrue(g.is_global)\n        test_case.assertTrue(g.sbp[0] == flow.sbp.split(0))\n\n        test_case.assertTrue(np.allclose(e.to_local().numpy(), g.to_local().numpy()))\n\n    # @unittest.skipIf(True, \"\")\n    def test_free_tensor_to_global(test_case):\n        local_x = flow.tensor(x, dtype=flow.float32, device=\"cpu\")\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        c_x = local_x.to_global(placement, flow.sbp.split(0))\n\n        m = FreeTensorModule((3, 10), placement, flow.sbp.broadcast)\n        g = MyGraph(m)\n\n        eager_out = m(c_x)\n        test_case.assertTrue(eager_out.is_cuda)\n        test_case.assertTrue(eager_out.is_global)\n        test_case.assertTrue(eager_out.sbp[0] == flow.sbp.split(0))\n\n        graph_out = g(c_x)\n        test_case.assertTrue(graph_out.is_cuda)\n        test_case.assertTrue(graph_out.is_global)\n        test_case.assertTrue(graph_out.sbp[0] == flow.sbp.split(0))\n\n        test_case.assertTrue(\n            np.allclose(eager_out.to_local().numpy(), graph_out.to_local().numpy())\n        )\n\n    # @unittest.skipIf(True, \"\")\n    def test_to_placement(test_case):\n        rank = flow.env.get_rank()\n        # pid = os.getpid()\n        # print(f\"[{pid}][{rank}] ToGlobalGraphTestCase.test_to_placement\")\n\n        if rank == 0:\n            x = flow.ones((2, 3), dtype=flow.float32)\n        elif rank == 1:\n            x = flow.empty(tuple())\n        else:\n            raise ValueError\n\n        c_x = x.to_global(\n            placement=flow.placement(\"cpu\", ranks=[0]), sbp=flow.sbp.broadcast\n        )\n        # print(f\"c_x shape: {c_x.shape}, placement: {c_x.placement}, sbp: {c_x.sbp}\")\n\n        p1 = flow.placement(\"cpu\", ranks=[0, 1])\n        m1 = ToPlacementModule(p1)\n        g1 = MyGraph(m1)\n        y1 = g1(c_x)\n\n        # print(f\"y1 shape: {y1.shape}, placement: {y1.placement}, sbp: {y1.sbp}\")\n        test_case.assertTrue(y1.placement == p1)\n        test_case.assertTrue(y1.sbp[0] == flow.sbp.broadcast)\n        test_case.assertTrue(y1.to_local().numpy().mean() == 1.0)\n\n        p2 = flow.placement(\"cuda\", ranks=[0, 1])\n        m2 = ToPlacementModule(p2)\n        g2 = MyGraph(m2)\n        y2 = g2(y1)\n\n        # print(f\"y2 shape: {y2.shape}, placement: {y2.placement}, sbp: {y2.sbp}\")\n        test_case.assertTrue(y2.placement == p2)\n        test_case.assertTrue(y2.sbp[0] == flow.sbp.broadcast)\n        
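# the all-ones values should survive the cpu -> cuda placement change\n        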
test_case.assertTrue(y2.to_local().numpy().mean() == 1.0)\n\n    # @unittest.skipIf(True, \"\")\n    def test_to_dtype(test_case):\n        x = flow.ones((2, 3), dtype=flow.int32, device=\"cpu\")\n\n        placement = flow.placement(\"cpu\", ranks=[0, 1])\n        c_x = flow.ones(\n            (2, 3), dtype=flow.int32, placement=placement, sbp=flow.sbp.broadcast\n        )\n\n        class CastModule(flow.nn.Module):\n            def __init__(self, dtype):\n                super().__init__()\n                self.dtype = dtype\n\n            def forward(self, x):\n                return x.to(dtype=self.dtype)\n\n        m = CastModule(flow.float32)\n        g = MyGraph(m)\n\n        e_x = m(x)\n        e_c_x = m(c_x)\n        # NOTE(chengcheng):\n        #   There are two bugs in this test script:\n        #   1. the first call and the second call input tensor metas are NOT the same\n        #   2. nn.Graph does NOT support local input with multi-rank yet.\n        # g_x = g(x)\n        g_c_x = g(c_x)\n\n        test_case.assertTrue(e_x.dtype == flow.float32)\n        # test_case.assertTrue(g_x.dtype == flow.float32)\n        test_case.assertTrue(e_c_x.dtype == flow.float32)\n        test_case.assertTrue(g_c_x.dtype == flow.float32)\n\n\nclass MyModule5(flow.nn.Module):\n    def __init__(self, transpose_a=False, transpose_b=False, sbp=[]):\n        super().__init__()\n        self.transpose_a = transpose_a\n        self.transpose_b = transpose_b\n        self.sbp = sbp\n\n    def forward(self, x, y):\n        z = flow._C.matmul(x, y, self.transpose_a, self.transpose_b)\n        assert z.is_global\n        assert len(z.sbp) == len(self.sbp)\n        return z.to_global(sbp=self.sbp)\n\n\n@unittest.skipIf(True, \"\")\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n4d()\nclass ToGlobal2DGraphTestCase(oneflow.unittest.TestCase):\n    def test_matmul(test_case):\n        placement = flow.placement(\"cuda\", ranks=[[0, 1], [2, 3]])\n        x = flow.ones(\n            (4, 6), placement=placement, sbp=[flow.sbp.split(0), flow.sbp.split(1)]\n        )\n        y = flow.ones(\n            (4, 6), placement=placement, sbp=[flow.sbp.broadcast, flow.sbp.split(1)]\n        )\n        z = flow._C.matmul(x, y, transpose_b=True)\n        print(f\"z shape: {z.shape}, placement: {z.placement}, sbp: {z.sbp}\")\n\n        # m = MyModule5(transpose_b=True, sbp=[flow.sbp.split(0), flow.sbp.broadcast])\n        # z = m(x, y)\n        # print(f\"z shape: {z.shape}, placement: {z.placement}, sbp: {z.sbp}\")\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestLazy1dTo2dGlobal(flow.unittest.TestCase):\n    def test_lazy_1d_to_2d_sbp(test_case):\n        P_1d = flow.placement(\n            device_type=\"cuda\", device_ids={0: range(4)}, hierarchy=(4,)\n        )\n        P_2d = flow.placement(\n            device_type=\"cuda\", device_ids={0: range(4)}, hierarchy=(2, 2)\n        )\n        B = flow.sbp.broadcast\n\n        class Test1dTo2dModule(flow.nn.Module):\n            def forward(self, x):\n                return x.to_global(placement=P_2d, sbp=[B, B])\n\n        class Test1dTo2dGraph(flow.nn.Graph):\n            def __init__(self, model):\n                super().__init__()\n                self.model = model\n\n            def build(self, x):\n                return self.model(x)\n\n        class Test2dTo1dModule(flow.nn.Module):\n            def forward(self, x):\n                return 
x.to_global(placement=P_1d, sbp=[B])\n\n        class Test2dTo1dGraph(flow.nn.Graph):\n            def __init__(self, model):\n                super().__init__()\n                self.model = model\n\n            def build(self, x):\n                return self.model(x)\n\n        model_1d_to_2d = Test1dTo2dModule()\n        graph_1d_to_2d = Test1dTo2dGraph(model_1d_to_2d)\n\n        x = flow.zeros(4, 4, 4, 4, sbp=[B, B], placement=P_2d)\n        x = x.to_global(placement=P_1d, sbp=[B])\n        test_case.assertTrue(x.sbp == (B,))\n        test_case.assertTrue(x.placement == P_1d)\n        y = graph_1d_to_2d(x)\n        test_case.assertTrue(y.sbp == (B, B))\n        test_case.assertTrue(y.placement == P_2d)\n\n        model_2d_to_1d = Test2dTo1dModule()\n        graph_2d_to_1d = Test2dTo1dGraph(model_2d_to_1d)\n        z = graph_2d_to_1d(y)\n        test_case.assertTrue(z.sbp == x.sbp)\n        test_case.assertTrue(z.placement == x.placement)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_tvm_frontend_dependency_on_graph.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport re\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom alexnet_model import alexnet\n\n\nclass TvmFrontedGraph(flow.nn.Graph):\n    def __init__(self, module):\n        super().__init__()\n        self.m = module\n\n    def build(self, x):\n        out = self.m(x)\n        return out\n\n\ndef parse_attr(attr):\n    # Parse node_attr\n    attrs = {}\n    for a in attr:\n        attr_str = str(attr[a])\n\n        if attr_str[0:7] == \"at_list\":\n            attr_str_ = attr_str.split(\" \")[0]\n\n            if attr_str_ == \"at_list_float\":\n                attrs[a] = tuple(attr[a].at_list_float.val)\n            elif attr_str_ == \"at_list_int32\":\n                attrs[a] = tuple(attr[a].at_list_int32.val)\n            elif attr_str_ == \"at_list_int64\":\n                attrs[a] = tuple(attr[a].at_list_int64.val)\n\n        elif attr_str.split(\":\")[0] == \"at_string\":\n            attrs[a] = attr[a].at_string\n\n        elif attr_str.split(\" \")[0] == \"at_shape\":\n            attrs[a] = tuple(list(attr[a].at_shape.dim))\n\n        else:\n            attr_str_ = attr_str.split(\":\")[0]\n            if attr_str_ == \"at_bool\":\n                attrs[a] = attr[a].at_bool\n            elif attr_str_ == \"at_double\":\n                attrs[a] = attr[a].at_double\n            elif attr_str_ == \"at_float\":\n                attrs[a] = attr[a].at_float\n            elif attr_str_ == \"at_int32\":\n                attrs[a] = attr[a].at_int32\n            elif attr_str_ == \"at_int64\":\n                attrs[a] = attr[a].at_int64\n\n    return attrs\n\n\ndef is_user_op(node):\n    # Determine if the the node is the intermediate variables of graph\n    return node.WhichOneof(\"op_type\") == \"user_conf\"\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestConvertDependency(flow.unittest.TestCase):\n    def test_get_params(test_case):\n        class ConvModel(flow.nn.Module):\n            def __init__(self):\n                super(ConvModel, self).__init__()\n                self.conv = flow.nn.Conv2d(3, 64, kernel_size=11, bias=False)\n\n            def forward(self, x):\n                x = self.conv(x)\n                return x\n\n        model = ConvModel().state_dict()\n        for layer_name in model:\n            layer_path = os.path.join(layer_name, \"out\")\n            test_case.assertEqual(layer_path, \"conv.weight/out\")\n\n    def test_infos_of_nodes(test_case):\n        alexnet_module = alexnet()\n        alexnet_graph = TvmFrontedGraph(alexnet_module)\n        if not alexnet_graph._is_compiled:\n            alexnet_graph._compile(flow.rand(1, 3, 224, 224))\n        graph_str = repr(alexnet_graph)\n        if not alexnet_graph._is_compiled:\n            alexnet_graph._compile(flow.rand(shape_input))\n\n        
size_where = 2\n        if \"cuda\" in graph_str:\n            size_where = 3\n\n        p_size = re.compile(r\"size=\\(.*?\\)\", re.S)\n        p_type = re.compile(r\"(dtype=.*?)[,|\\)]\", re.S)\n        types = [\"INPUT\", \"PARAMETER\", \"BUFFER\", \"OUTPUT\"]\n        num_nodes = {}\n\n        for t in types:\n            data = re.finditer(t + \":.*\", graph_str)\n            cnt = 0\n            for i in data:\n                cnt += 1\n                attrs = i.group().split(\":\")\n                size_strs = re.findall(p_size, attrs[size_where])\n                type_strs = re.findall(p_type, attrs[size_where])\n                test_case.assertEqual(size_strs != [], True)\n                test_case.assertEqual(type_strs != [], True)\n\n                size_attr = size_strs[0].replace(\"size=\", \"\")\n                type_attr = type_strs[0].replace(\"dtype=\", \"\").replace(\")\", \"\")\n                if size_attr[-2] == \",\":\n                    size_attr = size_attr.replace(\",\", \"\")\n                if type_attr[-1] == \",\":\n                    type_attr = type_attr.replace(\",\", \"\")\n                    test_case.assertEqual(type_attr, \"oneflow.float32\")\n\n                data_size = tuple(map(int, size_attr[1:-1].split(\", \")))\n                if cnt == 1 and t == \"PARAMETER\":\n                    test_case.assertEqual(data_size, (64, 3, 11, 11))\n                elif cnt == 15 and t == \"PARAMETER\":\n                    test_case.assertEqual(data_size, (1000, 4096))\n            num_nodes[t] = cnt\n\n        test_case.assertEqual(num_nodes[\"INPUT\"] != 0, True)\n        test_case.assertEqual(num_nodes[\"BUFFER\"], 0)\n        test_case.assertEqual(num_nodes[\"PARAMETER\"], 16)\n        test_case.assertEqual(num_nodes[\"OUTPUT\"] != 0, True)\n\n        # get graph proto, if you don't _compile the graph, the _graph_proto will be None\n        graph_input = re.search(r\"INPUT:.*\", graph_str).group().split(\":\")\n        shape_input = tuple(\n            map(\n                int,\n                re.findall(p_size, graph_input[size_where])[0]\n                .replace(\"size=\", \"\")[1:-1]\n                .split(\", \"),\n            )\n        )\n        graph_proto = alexnet_graph._graph_proto\n\n        nodes = {}\n        for op in graph_proto.net.op:\n            nodes[op.name] = op\n\n        op_names = []\n        op_attrs = []\n        for node_name in nodes:\n            node = nodes[node_name]\n            if is_user_op(node):\n                op_name = node.user_conf.op_type_name\n                op_attr = parse_attr(node.user_conf.attr)\n                op_names.append(op_name)\n                op_attrs.append(op_attr)\n\n        test_case.assertEqual(op_names[0], \"conv2d\")\n        test_case.assertEqual(op_names[1], \"bias_add\")\n        test_case.assertEqual(op_names[2], \"relu\")\n\n        kernel_size = op_attrs[0].get(\"kernel_size\", None)\n        strides = op_attrs[0].get(\"strides\", None)\n        padding_before = op_attrs[0].get(\"padding_before\", None)\n        test_case.assertEqual(kernel_size, (11, 11))\n        test_case.assertEqual(strides, (4, 4))\n        test_case.assertEqual(padding_before, (2, 2))\n\n        node_input_list = []\n        node_output_list = []\n        for node_name in nodes:\n            node = nodes[node_name]\n            if is_user_op(node) and node.user_conf.op_type_name == \"conv2d\":\n                for input_name in node.user_conf.input:\n                    node_input_paths = 
getattr(node.user_conf.input[input_name], \"s\")\n                    for i in node_input_paths:\n                        node_input = i.split(\"/\")[0]\n                        print(node_input)\n                        node_input_list.append(node_input)\n                for output_name in node.user_conf.output:\n                    node_output_paths = getattr(node.user_conf.output[output_name], \"s\")\n                    for node_output_path in node_output_paths:\n                        node_output_name = node_output_path.split(\"/\")[0]\n                        print(node_output_name)\n                        node_output_list.append(node_output_name)\n\n        test_case.assertEqual(\"_TvmFrontedGraph_1_input.0.0_2\" in node_input_list, True)\n        test_case.assertEqual(\"m.features.0.weight\" in node_input_list, True)\n        test_case.assertEqual(\"m.features.5-max_pool_2d-7\" in node_input_list, True)\n        test_case.assertEqual(\"m.features.0-conv2d-0\" in node_output_list, True)\n        test_case.assertEqual(\"m.features.6-conv2d-8\" in node_output_list, True)\n\n    def test_buffer_convert_dependence(test_case):\n        class SubModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.fc1 = flow.nn.Linear(36, 4, False)\n                self.register_buffer(\"dummy_buff\", flow.Tensor(1, 4))\n\n            def forward(self, x):\n                x = self.fc1(x)\n                x += self.dummy_buff\n                return x\n\n        sub_module = SubModule()\n        sub_graph = TvmFrontedGraph(sub_module)\n        graph_str = repr(sub_graph)\n\n        size_where = 2\n        if \"cuda\" in graph_str:\n            size_where = 3\n\n        p_size = re.compile(r\"size=\\(.*?\\)\", re.S)\n        p_type = re.compile(r\"dtype=.*?,\", re.S)\n        num_nodes = {}\n\n        data = re.finditer(\"BUFFER:.*\", graph_str)\n        for i in data:\n            attrs = i.group().split(\":\")\n            size_strs = re.findall(p_size, attrs[size_where])\n            size_attr = size_strs[0].replace(\"size=\", \"\")\n            data_size = tuple(map(int, size_attr[1:-1].split(\", \")))\n            test_case.assertEqual(data_size, (1, 4))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_user_op_expr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nfrom google.protobuf import text_format\nimport os\n\nimport oneflow\nimport oneflow as flow\nimport oneflow._oneflow_internal\nimport oneflow._oneflow_internal._C as _C\nimport oneflow.framework.c_api_util as c_api_util\nimport oneflow.framework.session_context as session_ctx\nimport oneflow.unittest\nfrom oneflow.framework.multi_client_session import MultiClientSession\n\n\ndef _get_c_tensor(t):\n    if isinstance(t, oneflow._oneflow_internal.Tensor):\n        return t\n    else:\n        raise NotImplementError\n\n\ndef _test_user_op_graph(test_case, is_cuda):\n    x0 = flow.tensor(np.random.rand(20, 30), dtype=flow.float32)\n    weight0 = flow.tensor(np.random.rand(30, 50), dtype=flow.float32)\n    x1 = flow.tensor(np.random.rand(50, 70), dtype=flow.float32)\n\n    if is_cuda:\n        x0 = x0.to(device=flow.device(\"cuda\"))\n        weight0 = weight0.to(device=flow.device(\"cuda\"))\n        x1 = x1.to(device=flow.device(\"cuda\"))\n\n    # NOTE(chengcheng): this tiny net is:\n    #    x0 * weight0 -> out0\n    #    relu(out0) -> y0\n    #    y0 * x1 -> out1\n    #    relu(out1) -> y1\n\n    session = session_ctx.GetDefaultSession()\n    test_case.assertTrue(isinstance(session, MultiClientSession))\n    session.TryInit()\n\n    with oneflow._oneflow_internal.lazy_mode.guard(True):\n\n        oneflow._oneflow_internal.JobBuildAndInferCtx_Open(\n            \"cc_test_user_op_expr_job_with_cuda\" + str(is_cuda)\n        )\n        job_conf = oneflow.core.job.job_conf_pb2.JobConfigProto()\n        job_conf.job_name = \"cc_test_user_op_expr_job_with_cuda\" + str(is_cuda)\n        job_conf.predict_conf.SetInParent()\n        c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)\n\n        x0_conf = oneflow.core.operator.op_conf_pb2.FeedInputOpConf()\n        x0_conf.in_0 = \"in_0\"\n        x0_conf.out_0 = \"out_0\"\n        x0_conf_str = text_format.MessageToString(x0_conf)\n        x0_op = oneflow._oneflow_internal.one.FeedInputOpExpr(\n            \"cc_Input_0\", x0_conf_str, [\"in_0\"], [\"out_0\"]\n        )\n\n        x1_conf = oneflow.core.operator.op_conf_pb2.FeedInputOpConf()\n        x1_conf.in_0 = \"in_0\"\n        x1_conf.out_0 = \"out_0\"\n        x1_conf_str = text_format.MessageToString(x1_conf)\n        x1_op = oneflow._oneflow_internal.one.FeedInputOpExpr(\n            \"cc_Input_1\", x1_conf_str, [\"in_0\"], [\"out_0\"]\n        )\n\n        weight0_conf = oneflow.core.operator.op_conf_pb2.FeedVariableOpConf()\n        weight0_conf.in_0 = \"in_0\"\n        weight0_conf.out_0 = \"out_0\"\n        weight0_conf_str = text_format.MessageToString(weight0_conf)\n        weight0_op = oneflow._oneflow_internal.one.FeedVariableOpExpr(\n            \"cc_Variable_0\", weight0_conf_str, [\"in_0\"], [\"out_0\"]\n        )\n        output_conf = oneflow.core.operator.op_conf_pb2.FetchOutputOpConf()\n        output_conf.in_0 = 
\"in_0\"\n        output_conf.out_0 = \"out_0\"\n        output_conf_str = text_format.MessageToString(output_conf)\n        output_op = oneflow._oneflow_internal.one.FetchOutputOpExpr(\n            \"cc_Output_0\", output_conf_str, [\"in_0\"], [\"out_0\"]\n        )\n\n        x0_lazy_tensor = _C.dispatch_feed_input(x0_op, x0)\n        x1_lazy_tensor = _C.dispatch_feed_input(x1_op, x1)\n        weight0_lazy_tensor = _C.dispatch_feed_input(weight0_op, weight0)\n\n        test_case.assertEqual(x0_lazy_tensor.shape, (20, 30))\n        test_case.assertTrue(x0_lazy_tensor.is_lazy)\n\n        test_case.assertEqual(weight0_lazy_tensor.shape, (30, 50))\n        test_case.assertTrue(weight0_lazy_tensor.is_lazy)\n        test_case.assertEqual(x1_lazy_tensor.shape, (50, 70))\n        test_case.assertTrue(x1_lazy_tensor.is_lazy)\n\n        out0 = flow._C.matmul(x0_lazy_tensor, weight0_lazy_tensor)\n        test_case.assertEqual(out0.shape, (20, 50))\n        test_case.assertTrue(out0.is_lazy)\n\n        y0 = flow._C.relu(out0)\n        test_case.assertEqual(y0.shape, (20, 50))\n        test_case.assertTrue(y0.is_lazy)\n\n        out1 = flow._C.matmul(y0, x1_lazy_tensor)\n        test_case.assertEqual(out1.shape, (20, 70))\n        test_case.assertTrue(out1.is_lazy)\n\n        y1 = flow._C.relu(out1)\n        test_case.assertEqual(y1.shape, (20, 70))\n        test_case.assertTrue(y1.is_lazy)\n\n        eager_output = _C.dispatch_fetch_output(output_op, y1)\n        test_case.assertEqual(eager_output.shape, (20, 70))\n        test_case.assertTrue(not eager_output.is_lazy)\n\n        if is_cuda:\n            test_case.assertTrue(x0_lazy_tensor.is_cuda)\n            test_case.assertTrue(x1_lazy_tensor.is_cuda)\n            test_case.assertTrue(weight0_lazy_tensor.is_cuda)\n            test_case.assertTrue(out0.is_cuda)\n            test_case.assertTrue(y0.is_cuda)\n            test_case.assertTrue(out1.is_cuda)\n            test_case.assertTrue(y1.is_cuda)\n\n        oneflow._oneflow_internal.JobBuildAndInferCtx_Close()\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestUserOpGraph(unittest.TestCase):\n    def test_user_op_graph_cpu(test_case):\n        _test_user_op_graph(test_case, False)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_user_op_graph_gpu(test_case):\n        _test_user_op_graph(test_case, True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport itertools\nimport os\nfrom collections import OrderedDict\nfrom collections.abc import Iterable\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef GenCartesianProduct(sets):\n    assert isinstance(sets, Iterable)\n    for set in sets:\n        assert isinstance(set, Iterable)\n        if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            if \"cuda\" in set:\n                set.remove(\"cuda\")\n    return itertools.product(*sets)\n\n\ndef GenArgList(arg_dict):\n    assert isinstance(arg_dict, OrderedDict)\n    assert all([isinstance(x, list) for x in arg_dict.values()])\n    sets = [arg_set for (_, arg_set) in arg_dict.items()]\n    return GenCartesianProduct(sets)\n\n\ndef GenArgDict(arg_dict):\n    return [dict(zip(arg_dict.keys(), x)) for x in GenArgList(arg_dict)]\n\n\nclass Args:\n    def __init__(self, flow_args, tf_args=None):\n        super().__init__()\n        if tf_args is None:\n            tf_args = flow_args\n        self.flow_args = flow_args\n        self.tf_args = tf_args\n\n    def __str__(self):\n        return \"flow_args={} tf_args={}\".format(self.flow_args, self.tf_args)\n\n    def __repr__(self):\n        return self.__str__()\n\n\ntype_name_to_flow_type = {\n    \"float16\": flow.float16,\n    \"float32\": flow.float32,\n    \"double\": flow.double,\n    \"int8\": flow.int8,\n    \"int32\": flow.int32,\n    \"int64\": flow.int64,\n    \"uint8\": flow.uint8,\n}\ntype_name_to_np_type = {\n    \"float16\": np.float16,\n    \"float32\": np.float32,\n    \"double\": np.float64,\n    \"int8\": np.int8,\n    \"int32\": np.int32,\n    \"int64\": np.int64,\n    \"uint8\": np.uint8,\n}\n\n\ndef FlattenArray(input_array):\n    output_array = list()\n    for x in np.nditer(input_array):\n        output_array.append(x.tolist())\n    return output_array\n\n\ndef Array2Numpy(input_array, target_shape):\n    return np.array(input_array).reshape(target_shape, order=\"C\")\n\n\ndef Index2Coordinate(idx, tensor_shape):\n    coordinate = []\n    tmp = idx\n    for i in range(len(tensor_shape) - 1, -1, -1):\n        axis_size = tensor_shape[i]\n        coor = tmp % axis_size\n        coordinate.insert(0, int(coor))\n        tmp = (tmp - coor) / axis_size\n    return coordinate\n\n\ndef Coordinate2Index(coordinate, tensor_shape):\n    if len(coordinate) != len(tensor_shape):\n        raise \"wrong coordinate or shape\"\n    idx = 0\n    for (i, coor) in enumerate(coordinate):\n        size_at_axis = coor\n        for j in range(i + 1, len(tensor_shape)):\n            size_at_axis *= tensor_shape[j]\n        idx += size_at_axis\n    return idx\n\n\ndef generate_graph(func):\n    class Graph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, *args):\n            return func(*args)\n\n    return Graph()\n"
  },
  {
    "path": "python/oneflow/test/graph/test_variable_op_expr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\n\nimport numpy as np\nfrom google.protobuf import text_format\n\nimport oneflow\nimport oneflow as flow\nimport oneflow._oneflow_internal\nimport oneflow._oneflow_internal._C as _C\nimport oneflow.framework.c_api_util as c_api_util\nimport oneflow.framework.session_context as session_ctx\nimport oneflow.unittest\nfrom oneflow.framework.multi_client_session import MultiClientSession\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFeedVariableTensor(unittest.TestCase):\n    def test_feed_var_tensor(test_case):\n        x = flow.Tensor(1, 1, 10, 10)\n        flow.nn.init.uniform_(x, a=-1.0, b=1.0)\n        session = session_ctx.GetDefaultSession()\n        test_case.assertTrue(isinstance(session, MultiClientSession))\n        session.TryInit()\n        with oneflow._oneflow_internal.lazy_mode.guard(True):\n            oneflow._oneflow_internal.JobBuildAndInferCtx_Open(\n                \"cc_test_variable_op_expr_job\"\n            )\n            job_conf = oneflow.core.job.job_conf_pb2.JobConfigProto()\n            job_conf.job_name = \"cc_test_variable_op_expr_job\"\n            job_conf.predict_conf.SetInParent()\n            c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)\n            op_name = \"cc_Variable_0\"\n            var_conf = oneflow.core.operator.op_conf_pb2.FeedVariableOpConf()\n            var_conf.in_0 = \"EagerTensorInput\"\n            var_conf.out_0 = \"out_0\"\n            var_conf_str = text_format.MessageToString(var_conf)\n            var_op = oneflow._oneflow_internal.one.FeedVariableOpExpr(\n                op_name, var_conf_str, [\"in_0\"], [\"out_0\"]\n            )\n            out_tensor = _C.dispatch_feed_variable(var_op, x, l2=0)\n            test_case.assertEqual(out_tensor.shape, (1, 1, 10, 10))\n            test_case.assertTrue(out_tensor.is_lazy)\n            test_case.assertTrue(out_tensor.is_local)\n            oneflow._oneflow_internal.JobBuildAndInferCtx_Close()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/misc/mock_example.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport torch\n\nprint(torch.__file__.find(\"mock_torch\") != -1)\n\n\ndef f():\n    return torch.__package__\n"
  },
  {
    "path": "python/oneflow/test/misc/test_autograd_functional.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom packaging import version\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import torch\nfrom oneflow.test_utils.automated_test_util import random_tensor\nfrom oneflow.test_utils.automated_test_util import autotest\n\n\ndef _func_tensor(x):\n    return x.exp().sum(dim=1)\n\n\ndef _func_scalar(x):\n    return x.exp().sum()\n\n\ndef _func_multi_tensor(x, y):\n    return (x.exp() + y.pow(2)).sum(dim=1)\n\n\ndef _func_multi_scalar(x, y):\n    return (x.exp() + y.pow(2)).sum()\n\n\ndef _func_scalar2tensor(x):\n    return (x, x ** 2, x ** 3)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAutogradFunctional(flow.unittest.TestCase):\n    @autotest(n=1, check_graph=False)\n    def test_vjp(test_case):\n        inputs = random_tensor(ndim=2, dim0=5, dim1=5)\n        v = random_tensor(ndim=1, dim0=5)\n        result_tensor = torch.autograd.functional.vjp(_func_tensor, inputs, v)\n        result_scalar = torch.autograd.functional.vjp(_func_scalar, inputs)\n\n        inputs = (\n            random_tensor(ndim=2, dim0=5, dim1=5),\n            random_tensor(ndim=2, dim0=5, dim1=5),\n        )\n        result_tensors = torch.autograd.functional.vjp(_func_multi_tensor, inputs, v)\n        result_scalars = torch.autograd.functional.vjp(_func_multi_scalar, inputs)\n\n    @autotest(n=1, check_graph=False)\n    def test_jvp(test_case):\n        inputs = random_tensor(ndim=2, dim0=5, dim1=5)\n        v = random_tensor(ndim=2, dim0=5, dim1=5)\n        result_tensor = torch.autograd.functional.jvp(_func_tensor, inputs, v)\n\n        inputs = (\n            random_tensor(ndim=2, dim0=5, dim1=5),\n            random_tensor(ndim=2, dim0=5, dim1=5),\n        )\n        v = (\n            random_tensor(ndim=2, dim0=5, dim1=5),\n            random_tensor(ndim=2, dim0=5, dim1=5),\n        )\n        result_tensors = torch.autograd.functional.jvp(_func_multi_tensor, inputs, v)\n\n        inputs = random_tensor(1)\n        result_scalar2tensor = torch.autograd.functional.jvp(\n            _func_scalar2tensor, inputs\n        )\n\n    @autotest(n=1, check_graph=False)\n    def test_vhp(test_case):\n        inputs = random_tensor(ndim=2, dim0=5, dim1=5)\n        v = random_tensor(ndim=2, dim0=5, dim1=5)\n        result_tensor = torch.autograd.functional.vhp(_func_scalar, inputs, v)\n\n        inputs = (\n            random_tensor(ndim=2, dim0=5, dim1=5),\n            random_tensor(ndim=2, dim0=5, dim1=5),\n        )\n        v = (\n            random_tensor(ndim=2, dim0=5, dim1=5),\n            random_tensor(ndim=2, dim0=5, dim1=5),\n        )\n        result_tensors = torch.autograd.functional.vhp(_func_multi_scalar, inputs, v)\n\n    @autotest(n=1, check_graph=False)\n    def test_hvp(test_case):\n        inputs = random_tensor(ndim=2, dim0=5, dim1=5)\n        v = random_tensor(ndim=2, dim0=5, dim1=5)\n        result_tensor = 
torch.autograd.functional.hvp(_func_scalar, inputs, v)\n\n        inputs = (\n            random_tensor(ndim=2, dim0=5, dim1=5),\n            random_tensor(ndim=2, dim0=5, dim1=5),\n        )\n        v = (\n            random_tensor(ndim=2, dim0=5, dim1=5),\n            random_tensor(ndim=2, dim0=5, dim1=5),\n        )\n        result_tensors = torch.autograd.functional.hvp(_func_multi_scalar, inputs, v)\n\n    # TODO: \"'jacobian' and 'hessian' has no strategy parameter in PyTorch before '1.11.0'\"\n    @autotest(n=1, check_graph=False)\n    def test_jacobian(test_case):\n        inputs = random_tensor(ndim=2, dim0=5, dim1=5)\n        if version.parse(torch.pytorch.__version__) < version.parse(\"1.11.0\"):\n            result_tensor = torch.autograd.functional.jacobian(\n                _func_tensor, inputs, vectorize=False\n            )\n        else:\n            result_tensor = torch.autograd.functional.jacobian(\n                _func_tensor, inputs, vectorize=False, strategy=\"reverse-mode\"\n            )\n\n        inputs = (\n            random_tensor(ndim=2, dim0=5, dim1=5),\n            random_tensor(ndim=2, dim0=5, dim1=5),\n        )\n        if version.parse(torch.pytorch.__version__) < version.parse(\"1.11.0\"):\n            result_tensors = torch.autograd.functional.jacobian(\n                _func_multi_scalar, inputs, vectorize=False\n            )\n        else:\n            result_tensors = torch.autograd.functional.jacobian(\n                _func_multi_scalar, inputs, vectorize=False, strategy=\"reverse-mode\"\n            )\n\n    @autotest(n=1, check_graph=False)\n    def test_hessian(test_case):\n        inputs = random_tensor(ndim=2, dim0=5, dim1=5)\n        if version.parse(torch.pytorch.__version__) < version.parse(\"1.11.0\"):\n            result_tensor = torch.autograd.functional.hessian(\n                _func_scalar, inputs, vectorize=False,\n            )\n        else:\n            result_tensor = torch.autograd.functional.hessian(\n                _func_scalar,\n                inputs,\n                vectorize=False,\n                outer_jacobian_strategy=\"reverse-mode\",\n            )\n\n        inputs = (\n            random_tensor(ndim=2, dim0=5, dim1=5),\n            random_tensor(ndim=2, dim0=5, dim1=5),\n        )\n        if version.parse(torch.pytorch.__version__) < version.parse(\"1.11.0\"):\n            result_tensors = torch.autograd.functional.hessian(\n                _func_multi_scalar, inputs, vectorize=False,\n            )\n        else:\n            result_tensors = torch.autograd.functional.hessian(\n                _func_multi_scalar,\n                inputs,\n                vectorize=False,\n                outer_jacobian_strategy=\"reverse-mode\",\n            )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/misc/test_distributed_env_vars.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestDistributedEnvVars(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_default(test_case):\n        test_case.assertFalse(\"MASTER_ADDR\" in os.environ)\n        test_case.assertFalse(\"MASTER_PORT\" in os.environ)\n        test_case.assertFalse(\"WORLD_SIZE\" in os.environ)\n        test_case.assertFalse(\"RANK\" in os.environ)\n        test_case.assertFalse(\"LOCAL_RANK\" in os.environ)\n        test_case.assertEqual(flow.distributed.get_world_size(), 1)\n        test_case.assertEqual(flow.distributed.get_rank(), 0)\n        test_case.assertEqual(flow.distributed.get_local_rank(), 0)\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_1n2d(test_case):\n        test_case.assertEqual(os.environ[\"MASTER_ADDR\"], \"127.0.0.1\")\n        test_case.assertEqual(os.environ[\"WORLD_SIZE\"], \"2\")\n        test_case.assertTrue(os.environ[\"RANK\"] in [\"0\", \"1\"])\n        test_case.assertTrue(os.environ[\"LOCAL_RANK\"] in [\"0\", \"1\"])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/misc/test_empty_cache.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestEmptyCache(flow.unittest.TestCase):\n    def test_cuda_to_cpu_empty_cache(test_case):\n        if flow._oneflow_internal.flags.with_cuda():\n\n            x = flow.randn(512, 3, 512, 512).to(\"cuda\")\n            used_mem1 = flow._oneflow_internal.GetCUDAMemoryUsed()\n\n            x = x.cpu()\n            used_mem2 = flow._oneflow_internal.GetCUDAMemoryUsed()\n\n            flow.cuda.empty_cache()\n            used_mem3 = flow._oneflow_internal.GetCUDAMemoryUsed()\n            test_case.assertTrue((used_mem3 < used_mem1) and (used_mem3 < used_mem2))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/misc/test_env_cuda.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport oneflow as flow\nfrom oneflow.test_utils.automated_test_util.generators import nothing, oneof, random\nfrom oneflow.test_utils.automated_test_util import torch\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestEnv(flow.unittest.TestCase):\n    def test_get_device_count(test_case):\n        test_case.assertEqual(flow.cuda.device_count(), 2)\n\n    def test_current_device_idx(test_case):\n        test_case.assertEqual(flow.cuda.current_device(), flow.env.get_rank())\n\n    def test_cuda_is_available(test_case):\n        test_case.assertEqual(flow.cuda.is_available(), True)\n\n    def test_cuda_synchronize(test_case):\n        flow.cuda.synchronize()\n        flow.cuda.synchronize(\"cuda\")\n        flow.cuda.synchronize(\"cuda:0\")\n        flow.cuda.synchronize(\"cuda:1\")\n        flow.cuda.synchronize(0)\n        flow.cuda.synchronize(1)\n        flow.cuda.synchronize(flow.device(\"cuda:0\"))\n        flow.cuda.synchronize(flow.device(\"cuda:1\"))\n\n        with test_case.assertRaisesRegex(ValueError, \"Expected a cuda device, but\"):\n            flow.cuda.synchronize(flow.device(\"cpu\"))\n\n        with test_case.assertRaisesRegex(ValueError, \"Expected a cuda device, but\"):\n            flow.cuda.synchronize(\"cpu\")\n\n    def test_cuda_get_device_name(test_case):\n        return torch.cuda.get_device_name(oneof(0, nothing()))\n\n    def test_cuda_get_device_capability(test_case):\n        return torch.cuda.get_device_capability(oneof(0, nothing()))\n\n    def test_cuda_mem_get_info(test_case):\n        device_idx = random(0, flow.cuda.device_count()).to(int).value()\n        return torch.cuda.mem_get_info(device_idx)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/misc/test_manual_seed_api.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\n\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestManualSeedApi(flow.unittest.TestCase):\n    def test_cuda_manual_seed_all(test_case):\n        flow.cuda.manual_seed_all(20)\n        x = flow.randn(2, 4, device=\"cuda:0\")\n        y = flow.randn(2, 4, device=\"cuda:1\")\n        test_case.assertTrue(np.allclose(x.numpy(), y.numpy()))\n\n    def test_cuda_manual_seed(test_case):\n        flow.cuda.manual_seed(30)\n        device = flow.device(\"cuda\", flow.cuda.current_device())\n        x = flow.randn(2, 4, device=device)\n        tensor_list = [flow.zeros((2, 4), dtype=flow.int32) for _ in range(2)]\n        flow.comm.all_gather(tensor_list, x)\n        test_case.assertTrue(\n            np.allclose(tensor_list[0].numpy(), tensor_list[1].numpy())\n        )\n\n    def test_manual_seed(test_case):\n        flow.manual_seed(40)\n        x = flow.randn(2, 4, device=\"cuda:0\")\n        y = flow.randn(2, 4, device=\"cuda:1\")\n        test_case.assertTrue(np.allclose(x.numpy(), y.numpy()))\n\n    def test_set_get_rng_state(test_case):\n        x = flow.ByteTensor(5000)\n        flow.set_rng_state(x)\n        y = flow.get_rng_state()\n        test_case.assertTrue(np.allclose(x.numpy(), y.numpy()))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/misc/test_mock_diffusers.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\"\"\"\nIf some modules import torch internally, \nflow.mock_torch.disable() should be able to restore the original torch within these modules.\n\"\"\"\n\n\nclass TestMock(flow.unittest.TestCase):\n    def test_mock_diffusers(test_case):\n\n        flow.mock_torch.enable(lazy=True)\n        from diffusers import UNet2DConditionModel\n\n        torch_module = UNet2DConditionModel.__dict__[\"forward\"].__globals__[\"torch\"]\n\n        flow.mock_torch.disable()\n        from diffusers import UNet2DConditionModel\n\n        torch_module = UNet2DConditionModel.__dict__[\"forward\"].__globals__[\"torch\"]\n\n        # check whether the torch module is the original torch\n        test_case.assertFalse(isinstance(torch_module, flow.mock_torch.ModuleWrapper))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/misc/test_mock_scope.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.mock_torch as mock\n\n\"\"\"\nenable & disable mode hold a dict[str, ModuleType] like sys.modules, the keys start with 'torch'.\nThe two modes don't interfere with each other, sys.modules and global scope are replaced on switch.\n\"\"\"\n\n\nwith mock.enable():\n    import torch\n    import torch.nn\n    import torch.version\nwith mock.disable():\n    import torch\n    import torch.nn\n    import torch.version\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMock(flow.unittest.TestCase):\n    def test_with(test_case):\n        with mock.enable():\n            test_case.assertEqual(torch.__package__, \"oneflow\")\n            test_case.assertEqual(torch.nn.__package__, \"oneflow.nn\")\n            test_case.assertEqual(torch.version.__version__, flow.__version__)\n        with mock.disable():\n            test_case.assertEqual(torch.__package__, \"torch\")\n            test_case.assertEqual(torch.nn.__package__, \"torch.nn\")\n            test_case.assertEqual(torch.version.__version__, torch.__version__)\n\n    def test_simple(test_case):\n        mock.enable()\n        test_case.assertEqual(torch.__package__, \"oneflow\")\n        test_case.assertEqual(torch.nn.__package__, \"oneflow.nn\")\n        test_case.assertEqual(torch.version.__version__, flow.__version__)\n\n        mock.disable()\n\n        test_case.assertEqual(torch.__package__, \"torch\")\n        test_case.assertEqual(torch.nn.__package__, \"torch.nn\")\n        test_case.assertEqual(torch.version.__version__, torch.__version__)\n\n    def test_import_from(test_case):\n        mock.enable()\n        from torch import nn\n        from torch.version import __version__\n\n        test_case.assertEqual(nn.__package__, \"oneflow.nn\")\n        test_case.assertEqual(__version__, flow.__version__)\n\n        mock.disable()\n        from torch import nn\n        from torch.version import __version__\n\n        test_case.assertEqual(nn.__package__, \"torch.nn\")\n        test_case.assertEqual(__version__, torch.__version__)\n\n    def test_error(test_case):\n        mock.enable()\n        with test_case.assertRaises(ImportError) as context:\n            from torch import noexist\n        test_case.assertTrue(\n            \"cannot import name 'noexist' from 'oneflow'\" in str(context.exception)\n        )\n        with test_case.assertRaises(ModuleNotFoundError) as context:\n            import torch.noexist\n        test_case.assertTrue(\n            \"oneflow.noexist is not implemented\" in str(context.exception)\n        )\n        mock.disable()\n        with test_case.assertRaises(ImportError) as context:\n            from torch import noexist\n        test_case.assertTrue(\n            \"cannot import name 'noexist' from 'torch'\" in str(context.exception)\n        )\n        with 
test_case.assertRaises(ModuleNotFoundError) as context:\n            import torch.noexist\n        test_case.assertTrue(\n            \"No module named 'torch.noexist'\" in str(context.exception)\n        )\n\n    def test_nested_with(test_case):\n        with mock.enable():\n            test_case.assertEqual(torch.__package__, \"oneflow\")\n            with mock.disable():\n                test_case.assertEqual(torch.__package__, \"torch\")\n            test_case.assertEqual(torch.__package__, \"oneflow\")\n        with mock.disable():\n            test_case.assertEqual(torch.__package__, \"torch\")\n            with mock.enable():\n                test_case.assertEqual(torch.__package__, \"oneflow\")\n            test_case.assertEqual(torch.__package__, \"torch\")\n\n    def test_noop_disable(test_case):\n        with mock.disable():\n            import torch\n\n            test_case.assertEqual(torch.__package__, \"torch\")\n\n    @unittest.skip(\"skip for now, because it failed 2 times in the past week\")\n    def test_3rd_party(test_case):\n        with mock.enable():\n            from mock_example import f\n\n            test_case.assertEqual(f(), \"oneflow\")\n\n    def test_env_var(test_case):\n        os.environ[\"ONEFLOW_DISABLE_MOCK_TORCH\"] = \"1\"\n\n        with mock.enable():\n            import torch\n\n            test_case.assertEqual(torch.__package__, \"torch\")\n\n        os.environ[\"ONEFLOW_DISABLE_MOCK_TORCH\"] = \"0\"\n\n    def test_dummy_obj_fallback(test_case):\n        with mock.enable(lazy=True):\n            from torch import not_exist\n\n            test_case.assertEqual(not_exist.__name__, \"oneflow.not_exist\")\n            x = not_exist.x\n            test_case.assertEqual(x.__name__, \"oneflow.not_exist.x\")\n\n    def test_mock_torchvision(test_case):\n        with mock.enable(lazy=True):\n            import torchvision\n\n            model = torchvision.models.resnet18(pretrained=False)\n            test_case.assertEqual(len(list(model.parameters())), 62)\n\n    def test_mock_lazy_for_loop(test_case):\n        with mock.enable(lazy=True):\n            import torch\n\n            # Test no infinite loop\n            for _ in torch.not_exist:\n                pass\n\n    def test_mock_lazy_in_if(test_case):\n        with mock.enable(lazy=True):\n            import torch\n\n            if torch.not_exist:\n                test_case.assertTrue(False)\n\n    def test_hazard_list(test_case):\n        with mock.enable():\n            import sys\n            import safetensors\n        test_case.assertTrue(\"safetensors._safetensors_rust\" in sys.modules)\n        import safetensors\n\n    def test_isinstance(test_case):\n        with mock.enable(lazy=True):\n            import torch\n\n            test_case.assertFalse(isinstance(int, torch._six.string_class))\n\n    def test_with_statement(test_case):\n        with mock.enable(lazy=True):\n            with test_case.assertRaises(RuntimeError) as context:\n                import torch.noexist\n\n                with torch.noexist:\n                    pass\n            test_case.assertTrue(\n                '\"oneflow.noexist\" is a dummy object, and does not support \"with\" statement.'\n                in str(context.exception)\n            )\n\n    def test_setattr(test_case):\n        with mock.enable():\n            import torch\n\n            torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward\n            test_case.assertEqual(\n                torch.nn.Linear_forward_before_lora, 
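\n                # The mock wrapper lets callers attach new attributes to the wrapped\n                # torch modules (as LoRA-style patches do); the saved alias must stay\n                # equal to the original function.\n                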
torch.nn.Linear.forward\n            )\n\n    def test_hasattr_and_getattr_in_lazy_mode(test_case):\n        with mock.enable(lazy=True):\n            test_case.assertFalse(hasattr(torch, \"not_exist\"))\n            test_case.assertFalse(hasattr(torch.nn.functional, \"not_exist\"))\n            test_case.assertTrue(isinstance(torch.not_exist, mock.DummyModule))\n            test_case.assertTrue(\n                isinstance(torch.nn.functional.not_exist, mock.DummyModule)\n            )\n\n            import torch.nn.functional as F\n\n            test_case.assertFalse(hasattr(F, \"scaled_dot_product_attention\"))\n            test_case.assertFalse(\n                hasattr(torch.nn.functional, \"scaled_dot_product_attention\")\n            )\n\n    def test_mock_extra_dict(test_case):\n        with mock.enable(lazy=True, extra_dict={\"torchvision\": \"flowvision\"}):\n            import torchvision\n\n            test_case.assertEqual(torchvision.models.__package__, \"flowvision.models\")\n\n\n# MUST use pytest to run this test\ndef test_verbose(capsys):\n    with mock.enable(lazy=True, verbose=True):\n        import torch.not_exist\n\n        captured = capsys.readouterr()\n        assert \"oneflow.not_exist is not found in oneflow\" in captured.out\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/misc/test_np_dtype_converter.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestNpDtypeConverter(flow.unittest.TestCase):\n    def test_np_dtype_converter(test_case):\n        for flow_dtype in flow.dtypes():\n            if flow_dtype in [flow.record, flow.tensor_buffer, flow.bfloat16]:\n                continue\n            np_dtype = flow.convert_oneflow_dtype_to_numpy_dtype(flow_dtype)\n            test_case.assertEqual(\n                flow.framework.dtype.convert_numpy_dtype_to_oneflow_dtype(np_dtype),\n                flow_dtype,\n            )\n\n            # Test whether dtype conversion works with arr.dtype\n            np_arr = np.array([1, 2], dtype=np_dtype)\n            test_case.assertEqual(np_arr.dtype, np_dtype)\n            flow_tensor = flow.tensor([1, 2], dtype=flow_dtype)\n            test_case.assertEqual(flow_tensor.dtype, flow_dtype)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/misc/test_placement.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestPlacement(flow.unittest.TestCase):\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_placement_all_cuda(test_case):\n        placement = flow.placement.all(\"cuda\")\n        test_case.assertEqual(placement.type, \"cuda\")\n        # assertEqual fails to compare lists\n        test_case.assertTrue(\n            list(placement.ranks) == list(range(flow.env.get_world_size()))\n        )\n\n    @unittest.skip(\"skip for now, becase it failed 10 times in past week\")\n    def test_placement_all_cpu(test_case):\n        placement = flow.placement.all(\"cpu\")\n        test_case.assertEqual(placement.type, \"cpu\")\n        # assertEqual fails to compare lists\n        test_case.assertTrue(\n            list(placement.ranks) == list(range(flow.env.get_world_size()))\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/misc/test_pybind11_caster.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestPybind11Caster(flow.unittest.TestCase):\n    def test_optional(test_case):\n        test_case.assertEqual(\n            flow._oneflow_internal.test_api.increase_if_not_none(1), 2\n        )\n        test_case.assertEqual(\n            flow._oneflow_internal.test_api.increase_if_not_none(None), None\n        )\n\n    def test_maybe(test_case):\n        test_case.assertEqual(flow._oneflow_internal.test_api.divide(6, 2), 3)\n\n    def test_maybe_void(test_case):\n        flow._oneflow_internal.test_api.throw_if_zero(1)\n\n    def test_return_maybe_shared_ptr(test_case):\n        a1 = flow._oneflow_internal.test_api.get_singleton_a()\n        x1 = a1.get_x()\n        a1.inc_x()\n\n        a2 = flow._oneflow_internal.test_api.get_singleton_a()\n        x2 = a2.get_x()\n\n        test_case.assertEqual(id(a1), id(a2))\n        test_case.assertEqual(x1 + 1, x2)\n\n    def test_pass_optional_shared_ptr(test_case):\n        a1 = flow._oneflow_internal.test_api.get_singleton_a()\n        x1 = a1.get_x()\n        a1.inc_x()\n\n        a2 = flow._oneflow_internal.test_api.increase_x_of_a_if_not_none(a1)\n        x2 = a2.get_x()\n\n        test_case.assertEqual(id(a1), id(a2))\n        test_case.assertEqual(x1 + 2, x2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/image_test_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport random\n\nimport cv2\nimport numpy as np\nimport PIL\n\nimport oneflow as flow\n\nglobal_coco_dict = dict()\ndefault_coco_anno_file = flow.unittest.dataset_dir(\n    \"mscoco_2017/annotations/instances_val2017.json\"\n)\ndefault_coco_image_dir = flow.unittest.dataset_dir(\"mscoco_2017/val2017\")\n\n\ndef get_coco(anno_file):\n    global global_coco_dict\n    if anno_file not in global_coco_dict:\n        from pycocotools.coco import COCO\n\n        global_coco_dict[anno_file] = COCO(anno_file)\n    return global_coco_dict[anno_file]\n\n\ndef random_sample_images_from_coco(\n    anno_file=default_coco_anno_file, image_dir=default_coco_image_dir, batch_size=2\n):\n    image_files = []\n    image_ids = []\n    batch_group_id = -1\n    coco = get_coco(anno_file)\n    img_ids = coco.getImgIds()\n    while len(image_files) < batch_size:\n        rand_img_id = random.choice(img_ids)\n        img_h = coco.imgs[rand_img_id][\"height\"]\n        img_w = coco.imgs[rand_img_id][\"width\"]\n        group_id = int(img_h / img_w)\n        if batch_group_id == -1:\n            batch_group_id = group_id\n        if group_id != batch_group_id:\n            continue\n        image_files.append(os.path.join(image_dir, coco.imgs[rand_img_id][\"file_name\"]))\n        image_ids.append(rand_img_id)\n    assert len(image_files) == len(image_ids)\n    return (image_files, image_ids)\n\n\ndef read_images_by_cv(image_files, dtype, channels=3):\n    np_dtype = flow.convert_oneflow_dtype_to_numpy_dtype(dtype)\n    images = [cv2.imread(image_file).astype(np_dtype) for image_file in image_files]\n    assert all((isinstance(image, np.ndarray) for image in images))\n    assert all((image.ndim == 3 for image in images))\n    assert all((image.shape[2] == channels for image in images))\n    return images\n\n\ndef read_images_by_pil(image_files, dtype, channels=3):\n    image_objs = [PIL.Image.open(image_file) for image_file in image_files]\n    images = []\n    np_dtype = flow.convert_oneflow_dtype_to_numpy_dtype(dtype)\n    for im in image_objs:\n        bands = im.getbands()\n        band = \"\".join(bands)\n        if band == \"RGB\":\n            images.append(np.asarray(im).astype(np_dtype)[:, :, ::-1])\n        elif band == \"L\":\n            gs_image = np.asarray(im).astype(np_dtype)\n            gs_image_shape = gs_image.shape\n            assert len(gs_image_shape) == 2\n            gs_image = gs_image.reshape(gs_image_shape + (1,))\n            gs_image = np.broadcast_to(gs_image, shape=gs_image_shape + (3,))\n            images.append(gs_image)\n        elif band == \"BGR\":\n            images.append(np.asarray(im).astype(np_dtype))\n        else:\n            raise NotImplementedError\n    assert all((isinstance(image, np.ndarray) for image in images))\n    assert all((image.ndim == 3 for image in images))\n    assert all((image.shape[2] == channels for image in images))\n  
  return images\n\n\ndef infer_images_static_shape(images, channels=3):\n    image_shapes = [image.shape for image in images]\n    assert all((image.ndim == 3 for image in images))\n    assert all((image.shape[2] == channels for image in images))\n    image_shapes = np.asarray(image_shapes)\n    max_h = np.max(image_shapes[:, 0]).item()\n    max_w = np.max(image_shapes[:, 1]).item()\n    image_static_shape = (len(images), max_h, max_w, channels)\n    group_ids = []\n    aspect_ratio_list = []\n    for image_shape in image_shapes:\n        (h, w) = image_shape[0:2]\n        if h < w:\n            group_id = 0\n            aspect_ratio = h / w\n        else:\n            group_id = 1\n            aspect_ratio = w / h\n        group_ids.append(group_id)\n        aspect_ratio_list.append(aspect_ratio)\n    assert all((group_id == group_ids[0] for group_id in group_ids))\n    return (image_static_shape, aspect_ratio_list)\n\n\ndef compute_keep_aspect_ratio_resized_size(\n    target_size, min_size, max_size, aspect_ratio, resize_side\n):\n    if resize_side == \"shorter\":\n        min_res_size = target_size\n        max_res_size = int(round(min_res_size / aspect_ratio))\n        if max_size is not None and max_res_size > max_size:\n            max_res_size = max_size\n            min_res_size = int(round(max_res_size * aspect_ratio))\n    elif resize_side == \"longer\":\n        max_res_size = target_size\n        min_res_size = int(round(max_res_size * aspect_ratio))\n        if min_size is not None and min_res_size < min_size:\n            min_res_size = min_size\n            max_res_size = int(round(min_res_size / aspect_ratio))\n    else:\n        raise NotImplementedError\n    return (min_res_size, max_res_size)\n\n\ndef infer_keep_aspect_ratio_resized_images_static_shape(\n    target_size,\n    min_size,\n    max_size,\n    aspect_ratio_list,\n    resize_side=\"shorter\",\n    channels=3,\n):\n    resized_size_list = []\n    for aspect_ratio in aspect_ratio_list:\n        resized_size_list.append(\n            compute_keep_aspect_ratio_resized_size(\n                target_size, min_size, max_size, aspect_ratio, resize_side\n            )\n        )\n    (res_min_size, res_max_size) = max(\n        resized_size_list, key=lambda size: size[0] * size[1]\n    )\n    return (res_min_size, res_max_size, channels)\n"
  },
  {
    "path": "python/oneflow/test/modules/optimizer_test_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport numpy as np\n\n\ndef clip_grad_norm_np(np_grad, max_norm, norm_type):\n    np_grad_is_list = True\n    if isinstance(np_grad, np.ndarray):\n        np_grad_is_list = False\n        np_grad = [np_grad]\n\n    max_norm = float(max_norm)\n    norm_type = float(norm_type)\n    if norm_type == float(\"inf\"):\n        total_norm = np.max(np.abs(np_grad))\n    elif norm_type == float(\"-inf\"):\n        total_norm = np.min(np.abs(np_grad))\n    else:\n        norms = np_grad\n        total_norm = []\n        for i, norm in enumerate(norms):\n            for j in range(np_grad[i].ndim, 0, -1):\n                norm = np.linalg.norm(norm, norm_type, axis=j - 1)\n            total_norm.append(norm)\n        total_norm = np.linalg.norm(np.array(total_norm, dtype=np.float32), norm_type)\n\n    clip_coef = max_norm / (total_norm + 1e-6)\n    if clip_coef < 1:\n        for grad in np_grad:\n            grad *= clip_coef\n\n    if not np_grad_is_list:\n        np_grad = np_grad[0]\n    return total_norm, np_grad\n"
  },
  {
    "path": "python/oneflow/test/modules/save_load_test_data/3x3_i3o3_conv2d/tensor_3/meta",
    "content": "shape {\n  dim: 3\n  dim: 3\n  dim: 3\n  dim: 3\n}\ndata_type: kFloat\n"
  },
  {
    "path": "python/oneflow/test/modules/save_load_test_data/3x3_i3o3_conv2d/tensor_4/meta",
    "content": "shape {\n  dim: 3\n}\ndata_type: kFloat\n"
  },
  {
    "path": "python/oneflow/test/modules/save_load_test_data/3x3_i3o3_conv2d/tensor_4/out",
    "content": "w)I\u0011"
  },
  {
    "path": "python/oneflow/test/modules/save_load_test_data/3x3_i3o3_conv2d_params/tensor_5/meta",
    "content": "shape {\n  dim: 3\n  dim: 3\n  dim: 3\n  dim: 3\n}\ndata_type: kFloat\n"
  },
  {
    "path": "python/oneflow/test/modules/save_load_test_data/3x3_i3o3_conv2d_params/tensor_6/meta",
    "content": "shape {\n  dim: 3\n}\ndata_type: kFloat\n"
  },
  {
    "path": "python/oneflow/test/modules/save_load_test_data/3x3_i3o3_conv2d_params/tensor_6/out",
    "content": "w)I\u0011"
  },
  {
    "path": "python/oneflow/test/modules/sync_batchnorm_test_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\n\nONEREC_URL = (\n    \"https://oneflow-public.oss-cn-beijing.aliyuncs.com/sync_bn_test_datas.tar.gz\"\n)\nMD5 = \"537ff00fb47be8be90df75f47a883b76\"\n\n\ndef md5(fname):\n    import hashlib\n\n    hash_md5 = hashlib.md5()\n    with open(fname, \"rb\") as f:\n        for chunk in iter(lambda: f.read(4096), b\"\"):\n            hash_md5.update(chunk)\n    result = hash_md5.hexdigest()\n    print(\"md5\", fname, result)\n    return result\n\n\ndef download_file(out_path: str, url):\n    import requests\n    from tqdm import tqdm\n\n    resp = requests.get(url=url, stream=True)\n    MB = 1024 ** 2\n    size = int(resp.headers[\"Content-Length\"]) / MB\n    print(\"File size: %.4f MB, downloading...\" % size)\n    with open(out_path, \"wb\") as f:\n        for data in tqdm(\n            iterable=resp.iter_content(MB), total=size, unit=\"m\", desc=out_path\n        ):\n            f.write(data)\n        print(\"Done!\")\n\n\ndef ensure_datas():\n    import os\n    import pathlib\n\n    data_dir = os.path.join(\n        os.getenv(\"ONEFLOW_TEST_CACHE_DIR\", \"./data-test\"), \"sync_bn\"\n    )\n    file_path = pathlib.Path(data_dir) / ONEREC_URL.split(\"/\")[-1]\n    absolute_file_path = str(file_path.absolute())\n\n    if flow.env.get_rank() == 0:\n        file_path.parent.mkdir(parents=True, exist_ok=True)\n        if file_path.exists():\n            if MD5 != md5(absolute_file_path):\n                file_path.unlink()\n                download_file(absolute_file_path, ONEREC_URL)\n        else:\n            download_file(str(absolute_file_path), ONEREC_URL)\n        assert MD5 == md5(absolute_file_path)\n\n        import tarfile\n\n        my_tar = tarfile.open(str(absolute_file_path))\n        my_tar.extractall(data_dir)  # specify which folder to extract to\n        my_tar.close()\n\n    flow.comm.barrier()\n\n    return os.path.join(\n        os.getenv(\"ONEFLOW_TEST_CACHE_DIR\", \"./data-test\"),\n        \"sync_bn\",\n        \"sync_bn_test_datas\",\n    )\n"
  },
  {
    "path": "python/oneflow/test/modules/test_0_dim_tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nimport oneflow as flow\n\n\nfrom oneflow.test_utils.test_util import GenArgList\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_0_dim_tensor(test_case, device):\n    scalar = 9.999\n    input_np = np.array(scalar)\n    input = flow.tensor(input_np, device=device)\n\n    test_case.assertEqual(input.numel(), 1)\n    test_case.assertEqual(input.ndimension(), 0)\n\n    x1 = flow.tensor(np.array(2), dtype=flow.float32, device=device)\n    x2 = flow.tensor(np.array(3), dtype=flow.float32, device=device)\n    y1 = x1 * x2\n    y2 = x1 + x2\n    test_case.assertEqual(y1.numpy(), 6.0)\n    test_case.assertEqual(y2.numpy(), 5.0)\n\n\ndef _test_scalar_mul(test_case, device):\n    for dim in range(5):\n        test_case.assertEqual(\n            np.ones([2] * dim).sum(), flow.ones([2] * dim, device=device).sum().numpy()\n        )\n\n\ndef _test_slice(test_case, device):\n    x = flow.tensor(np.arange(10), device=device)\n    for i in range(x.numel()):\n        scalar_i = x[i]\n        test_case.assertEqual(i, scalar_i.numpy())\n        test_case.assertEqual(scalar_i.numel(), 1)\n        test_case.assertEqual(scalar_i.ndimension(), 0)\n\n\ndef _test_slice_backward(test_case, device):\n    np_grad = np.zeros(10)\n    x = flow.tensor(np.arange(10).astype(np.float32), device=device, requires_grad=True)\n    for i in range(x.numel()):\n        y = x[i]\n        z = y.sum()\n        z.backward()\n        np_grad[i] = 1\n        test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad, 1e-04, 1e-04))\n\n    x2 = flow.tensor(\n        np.arange(100).astype(np.float32), device=device, requires_grad=True\n    )\n    y2 = x2[1:100]\n    z2 = y2.sum()\n    z2.backward()\n    np_grad2 = np.ones(100)\n    np_grad2[0] = 0\n    test_case.assertTrue(np.allclose(x2.grad.numpy(), np_grad2, 1e-04, 1e-04))\n\n\ndef _test_slice_scalar_graph(test_case, device):\n    x = flow.tensor(3.0, device=device)\n\n    class MyModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.weight = flow.nn.Parameter(\n                flow.tensor([1.0, 2.0, 3.0, 4.0], device=device)\n            )\n\n        def forward(self, x):\n            return x * self.weight[3]\n\n    my_module = MyModule()\n    of_eager_out = my_module(x)\n\n    class ScalarGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = my_module\n\n        def build(self, x):\n            return self.m(x)\n\n    scalar_g = ScalarGraph()\n    of_lazy_out = scalar_g(x)\n    test_case.assertTrue(np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy()))\n\n\ndef _test_slice_scalar_train_graph(test_case, device):\n    class MyModule(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.weight = flow.nn.Parameter(\n            
    flow.tensor([1.0, 2.0, 3.0, 4.0], device=device)\n            )\n\n        def forward(self, x):\n            return x * self.weight[3] + 1.0\n\n    my_module = MyModule()\n    of_sgd = flow.optim.SGD(my_module.parameters(), lr=0.001, momentum=0.9)\n    eager_out_list = []\n    for i in range(3):\n        x = flow.tensor(i * 1.0, device=device, requires_grad=False)\n        of_eager_out = my_module(x)\n        of_eager_out.backward()\n        of_sgd.step()\n        of_sgd.zero_grad()\n        eager_out_list.append(of_eager_out)\n\n    lazy_module = MyModule()\n\n    class ScalarTrainGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.m = lazy_module\n            of_sgd = flow.optim.SGD(lazy_module.parameters(), lr=0.001, momentum=0.9)\n            self.add_optimizer(of_sgd)\n\n        def build(self, x):\n            loss = self.m(x)\n            loss.backward()\n            return loss\n\n    lazy_out_list = []\n    scalar_g = ScalarTrainGraph()\n    for i in range(3):\n        x = flow.tensor(i * 1.0, device=device)\n        of_lazy_out = scalar_g(x)\n        lazy_out_list.append(of_lazy_out)\n\n    for i in range(3):\n        test_case.assertTrue(\n            np.array_equal(lazy_out_list[i].numpy(), eager_out_list[i].numpy())\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestZeroDimensionTensor(flow.unittest.TestCase):\n    def test_0_dim_tensor(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_0_dim_tensor,\n            _test_scalar_mul,\n            _test_slice,\n            _test_slice_backward,\n            _test_slice_scalar_graph,\n            _test_slice_scalar_train_graph,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_TripletMarginLoss.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTripletMarginLoss(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    @autotest(n=10)\n    def test_triplet_marginloss_with_random_data(test_case):\n        margin = random().to(float)\n        p = random().to(float)\n        swap = random_bool()\n        reduction = oneof(\"none\", \"sum\", \"mean\", nothing())\n        m = torch.nn.TripletMarginLoss(\n            margin=margin, p=p, swap=swap, reduction=reduction\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        shape = random_tensor(ndim=2, dim0=random(1, 8)).pytorch.shape\n        anchor = random_tensor(len(shape), *shape).to(device)\n        pos = random_tensor(len(shape), *shape).to(device)\n        neg = random_tensor(len(shape), *shape).to(device)\n        y = m(anchor, pos, neg)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_abs.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAbsModule(flow.unittest.TestCase):\n    @autotest(n=5, check_graph=True)\n    def test_abs_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.abs(x)\n        return y\n\n    @autotest(n=5, check_graph=True)\n    def test_abs_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.abs(x)\n        return y\n\n    @profile(torch.abs)\n    def profile_abs(test_case):\n        torch.abs(torch.ones(1, 128, 28, 28))\n        torch.abs(torch.ones(16, 128, 28, 28))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_activation.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch as pytorch\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom scipy import special\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestReLUModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_relu_module_with_random_data(test_case):\n        m = torch.nn.ReLU()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_relu_module_with_0dim_data(test_case):\n        m = torch.nn.ReLU()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5, auto_backward=False)\n    def test_relu_module_with_0_size_data(test_case):\n        m = torch.nn.ReLU()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(4, 2, 3, 0, 3).to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.relu)\n    def profile_relu(test_case):\n        torch.nn.functional.relu(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.relu(torch.ones(1, 128, 28, 28), inplace=True)\n        torch.nn.functional.relu(torch.ones(16, 128, 28, 28))\n        torch.nn.functional.relu(torch.ones(16, 128, 28, 28), inplace=True)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestReLU6Module(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_relu6_module_with_random_data(test_case):\n        m = torch.nn.ReLU6()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_relu6_module_with_0dim_data(test_case):\n        m = torch.nn.ReLU6()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5, auto_backward=False)\n    def test_relu6_module_with_0_size_data(test_case):\n        m = torch.nn.ReLU6()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(4, 2, 3, 0, 3).to(device)\n        y = m(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTanh(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_tanh_module_with_random_data(test_case):\n        m = torch.nn.Tanh()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_tanh_module_with_0dim_data(test_case):\n        m = 
torch.nn.Tanh()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5, auto_backward=False)\n    def test_tanh_module_with_0_size_data(test_case):\n        m = torch.nn.Tanh()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(4, 2, 3, 0, 3).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_flow_tanh_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.tanh(x)\n        return y\n\n    @autotest(n=5)\n    def test_flow_tanh_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.tanh(x)\n        return y\n\n    @autotest(n=5, auto_backward=False)\n    def test_flow_tanh_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 3, 0, 3).to(device)\n        y = torch.tanh(x)\n        return y\n\n    @profile(torch.nn.functional.tanh)\n    def profile_tanh(test_case):\n        torch.nn.functional.tanh(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.tanh(torch.ones(16, 128, 28, 28))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestELUModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_elu_module_with_random_data(test_case):\n        m = torch.nn.ELU(alpha=random() | nothing())\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_elu_module_with_0dim_data(test_case):\n        m = torch.nn.ELU(alpha=random() | nothing())\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5, auto_backward=False)\n    def test_elu_module_with_0_size_data(test_case):\n        m = torch.nn.ELU(alpha=random() | nothing())\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(4, 2, 3, 0, 3).to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.elu)\n    def profile_elu(test_case):\n        torch.nn.functional.elu(torch.ones(1, 128, 28, 28), 1.0)\n        torch.nn.functional.elu(torch.ones(16, 128, 28, 28), 1.0)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCELUModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_celu_module_with_random_data(test_case):\n        m = torch.nn.CELU(alpha=random() | nothing())\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_celu_module_with_0dim_data(test_case):\n        m = torch.nn.CELU(alpha=random() | nothing())\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5, auto_backward=False)\n    def test_celu_module_with_0_size_data(test_case):\n        m = torch.nn.CELU(alpha=random() | nothing())\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(4, 2, 3, 0, 3).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=10)\n    def 
test_inplace_celu_module(test_case):\n        m = torch.nn.CELU(alpha=random() | nothing(), inplace=True)\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = x + 0.001\n        m(y)\n        return y\n\n    @profile(torch.nn.functional.celu)\n    def profile_celu(test_case):\n        torch.nn.functional.celu(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.celu(torch.ones(1, 128, 28, 28), inplace=True)\n        torch.nn.functional.celu(torch.ones(16, 128, 28, 28))\n        torch.nn.functional.celu(torch.ones(16, 128, 28, 28), inplace=True)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGelu(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_gelu_module_with_random_data(test_case):\n        m = torch.nn.GELU()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_gelu_module_with_0dim_data(test_case):\n        m = torch.nn.GELU()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.gelu)\n    def profile_gelu(test_case):\n        torch.nn.functional.gelu(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.gelu(torch.ones(16, 128, 28, 28))\n\n\n@unittest.skipIf(\n    tuple(map(int, pytorch.__version__.split(\".\")[:2])) < (1, 12),\n    f\"need pytorch version >= 1.12, got {pytorch.__version__}\",\n)\n@flow.unittest.skip_unless_1n1d()\nclass TestFastGelu(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_fast_gelu(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.nn.functional.gelu(x, approximate=\"tanh\")\n        return y\n\n    @autotest(n=5, atol=1e-2, rtol=1e-2)\n    def test_fast_gelu_fp16(test_case):\n        x = random_tensor().to(device=gpu_device(), dtype=torch.float16)\n        y = torch.nn.functional.gelu(x, approximate=\"tanh\")\n        return y\n\n    @autotest(n=5)\n    def test_fast_gelu_scalar(test_case):\n        x = random_tensor(ndim=0).to(device=random_device())\n        y = torch.nn.functional.gelu(x, approximate=\"tanh\")\n        return y\n\n    @profile(torch.nn.functional.gelu)\n    def profile_fast_gelu(test_case):\n        torch.nn.functional.gelu(torch.ones(1, 128, 28, 28), approximate=\"tanh\")\n        torch.nn.functional.gelu(torch.ones(16, 128, 28, 28), approximate=\"tanh\")\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSigmoidModule(flow.unittest.TestCase):\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @autotest(n=5, atol=1e-3, check_dtype=True)\n    def test_sigmoid_flow_with_half_data(test_case):\n        device = gpu_device()\n        x = random_tensor().to(device=device, dtype=torch.float16)\n        y = torch.sigmoid(x)\n        return y\n\n    @autotest(n=5)\n    def test_sigmoid_module_with_random_data(test_case):\n        m = torch.nn.Sigmoid()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_sigmoid_module_with_0dim_data(test_case):\n        m = torch.nn.Sigmoid()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def 
test_sigmoid_flow_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.sigmoid(x)\n        return y\n\n    @autotest(n=5)\n    def test_sigmoid_flow_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.sigmoid(x)\n        return y\n\n    @autotest(n=5)\n    def test_sigmoid_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.sigmoid()\n        return y\n\n    @autotest(n=5)\n    def test_sigmoid_tensor_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = x.sigmoid()\n        return y\n\n    @profile(torch.sigmoid)\n    def profile_sigmoid(test_case):\n        torch.sigmoid(torch.ones(1, 128, 28, 28))\n        torch.sigmoid(torch.ones(16, 128, 28, 28))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestHardsigmoidModule(flow.unittest.TestCase):\n    def test_hardsigmoid_inplace(test_case):\n        def np_hardsigmoid(input):\n            input_shape = input.shape\n            input = input.flatten()\n            elem_cnt = input.size\n            _zero = np.zeros_like(input)\n            for i in range(elem_cnt):\n                if input[i] >= 3:\n                    _zero[i] = 1\n                elif input[i] <= -3:\n                    _zero[i] = 0\n                else:\n                    _zero[i] = input[i] / 6 + 0.5\n            np_hsigmoid_out = np.reshape(_zero, newshape=input_shape)\n            return np.array(np_hsigmoid_out)\n\n        def test_hardsigmoid_inplace_impl(test_case, shape, device):\n            x = flow.tensor(\n                np.random.randn(*shape),\n                dtype=flow.float32,\n                device=flow.device(device),\n                requires_grad=True,\n            )\n            x_inplace = x + 1\n            np_out = np_hardsigmoid(x_inplace.numpy())\n\n            id_old = id(x_inplace)\n            y_inplace = flow.nn.functional.hardsigmoid(x_inplace, inplace=True)\n\n            test_case.assertEqual(id_old, id(y_inplace))\n            test_case.assertTrue(np.allclose(y_inplace.numpy(), np_out, 1e-5, 1e-5))\n\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            test_hardsigmoid_inplace_impl(test_case, *arg)\n\n    @autotest(n=5)\n    def test_hardsigmoid_module_with_random_data(test_case):\n        m = torch.nn.Hardsigmoid()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_hardsigmoid_module_with_0dim_data(test_case):\n        m = torch.nn.Hardsigmoid()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_functional_hardsigmoid_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.nn.functional.hardsigmoid(x, random_bool())\n        return y\n\n    @autotest(n=5)\n    def test_functional_hardsigmoid_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.nn.functional.hardsigmoid(x, 
random_bool())\n        return y\n\n    @profile(torch.nn.functional.hardsigmoid)\n    def profile_hardsigmoid(test_case):\n        torch.nn.functional.hardsigmoid(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.hardsigmoid(torch.ones(1, 128, 28, 28), inplace=True)\n        torch.nn.functional.hardsigmoid(torch.ones(16, 128, 28, 28))\n        torch.nn.functional.hardsigmoid(torch.ones(16, 128, 28, 28), inplace=True)\n\n\ndef do_test_softmax(batch_size: int, log_softmax: bool = False):\n    num_dims = random(low=1, high=5).to(int)\n    m = torch.nn.Softmax(dim=random(low=0, high=num_dims).to(int) | nothing())\n    if log_softmax:\n        m = torch.nn.LogSoftmax(dim=random(low=0, high=num_dims).to(int) | nothing())\n    m.train(random())\n    device = random_device()\n    m.to(device)\n    x = (\n        random_tensor(ndim=num_dims).to(device)\n        if batch_size < 0\n        else random_tensor(ndim=num_dims, dim0=batch_size).to(device)\n    )\n    y = m(x)\n    return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSoftmax(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_softmax_module_with_random_data(test_case):\n        return do_test_softmax(batch_size=-1, log_softmax=False)\n\n    @autotest(n=5)\n    def test_softmax_module_with_batch_size_equal_1024(test_case):\n        return do_test_softmax(batch_size=1024, log_softmax=False)\n\n    @autotest(n=5, check_graph=True)\n    def test_softmax_module_with_batch_size_equal_5120(test_case):\n        return do_test_softmax(batch_size=5120, log_softmax=False)\n\n    @autotest(n=2, check_graph=True)\n    def test_softmax_module_with_batch_size_equal_10240(test_case):\n        return do_test_softmax(batch_size=10240, log_softmax=False)\n\n    @profile(torch.nn.functional.softmax)\n    def profile_softmax(test_case):\n        torch.nn.functional.softmax(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.softmax(torch.ones(16, 128, 28, 28))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLogSoftmaxModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_logsoftmax_module_with_random_data(test_case):\n        return do_test_softmax(batch_size=-1, log_softmax=True)\n\n    @autotest(n=5)\n    def test_softmax_module_with_batch_size_equal_1024(test_case):\n        return do_test_softmax(batch_size=1024, log_softmax=True)\n\n    @autotest(n=5, check_graph=True)\n    def test_softmax_module_with_batch_size_equal_5120(test_case):\n        return do_test_softmax(batch_size=5120, log_softmax=True)\n\n    @autotest(n=2, check_graph=True)\n    def test_softmax_module_with_batch_size_equal_10240(test_case):\n        return do_test_softmax(batch_size=10240, log_softmax=True)\n\n    @profile(torch.nn.functional.log_softmax)\n    def profile_logsoftmax(test_case):\n        torch.nn.functional.log_softmax(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.log_softmax(torch.ones(16, 128, 28, 28))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLogSigmoidModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_logsigmoid_module_with_random_data(test_case):\n        m = torch.nn.LogSigmoid()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_logsigmoid_module_with_0dim_data(test_case):\n        m = torch.nn.LogSigmoid()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = 
m(x)\n        return y\n\n    @profile(torch.nn.functional.logsigmoid)\n    def profile_logsigmoid(test_case):\n        torch.nn.functional.logsigmoid(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.logsigmoid(torch.ones(16, 128, 28, 28))\n\n\ndef numpy_softplus(x, beta, threshold):\n    return np.where(\n        x * beta > threshold, x, 1.0 / beta * np.log(1.0 + np.exp(beta * x))\n    )\n\n\ndef _test_softplus(test_case, device):\n    m = flow.nn.Softplus()\n    arr = np.random.randn(2, 3, 4, 5)\n    np_out = numpy_softplus(arr, 1.0, 20)\n    x = flow.tensor(arr, device=flow.device(device))\n    of_out = m(x)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_softplus_beta(test_case, device):\n    m = flow.nn.Softplus(beta=1.11)\n    arr = np.random.randn(2, 3, 4, 5)\n    np_out = numpy_softplus(arr, 1.11, 20)\n    x = flow.tensor(arr, device=flow.device(device))\n    of_out = m(x)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_softplus_threshold(test_case, device):\n    m = flow.nn.Softplus(beta=1.11, threshold=1.55)\n    arr = np.random.randn(2, 3, 4, 5)\n    np_out = numpy_softplus(arr, 1.11, 1.55)\n    x = flow.tensor(arr, device=flow.device(device))\n    of_out = m(x)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_softplus_backward(test_case, device):\n    m = flow.nn.Softplus()\n    arr = np.array([1.0, 2.0, 21.0, 20.0, 4.0])\n    x = flow.tensor(arr, device=flow.device(device), requires_grad=True)\n    of_out = m(x)\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = [0.7310585786300049, 0.8807970779778824, 1.0, 1.0, 0.9820137900379085]\n    test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSoftplusModule(flow.unittest.TestCase):\n    def test_softplus(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_softplus,\n            _test_softplus_beta,\n            _test_softplus_threshold,\n            _test_softplus_backward,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @unittest.skip(\"PyTorch softplus backward has a bug\")\n    @autotest(n=5)\n    def test_softplus_module_with_random_data(test_case):\n        m = torch.nn.Softplus(beta=random() | nothing(), threshold=random() | nothing())\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.softplus)\n    def profile_softplus(test_case):\n        torch.nn.functional.softplus(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.softplus(torch.ones(16, 128, 28, 28))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestHardswishModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_hardswish_module_with_random_data(test_case):\n        m = torch.nn.Hardswish()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_hardswish_module_with_0dim_data(test_case):\n        m = torch.nn.Hardswish()\n        m.train(random())\n        device = random_device()\n        
m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.hardswish)\n    def profile_hardswish(test_case):\n        torch.nn.functional.hardswish(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.hardswish(torch.ones(16, 128, 28, 28))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestHardtanhModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_hardtanh_module_with_random_data(test_case):\n        m = torch.nn.Hardtanh(\n            min_val=random().to(float) | nothing(),\n            max_val=random().to(float) | nothing(),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=4).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_hardtanh_module_with_0dim_data(test_case):\n        m = torch.nn.Hardtanh(\n            min_val=random().to(float) | nothing(),\n            max_val=random().to(float) | nothing(),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.hardtanh)\n    def profile_hardtanh(test_case):\n        torch.nn.functional.hardtanh(\n            torch.ones(1, 128, 28, 28), min_val=-1.0, max_val=1.0\n        )\n        torch.nn.functional.hardtanh(\n            torch.ones(16, 128, 28, 28), min_val=-1.0, max_val=1.0\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLeakyReLUModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_leakyrelu_module_with_random_data(test_case):\n        m = torch.nn.LeakyReLU(negative_slope=random() | nothing())\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_leakyrelu_module_with_inplace_arg(test_case):\n        m = torch.nn.LeakyReLU(\n            negative_slope=random() | nothing(), inplace=random().to(bool) | nothing()\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_leakyrelu_module_with_0dim_data(test_case):\n        m = torch.nn.LeakyReLU(negative_slope=random() | nothing())\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.leaky_relu)\n    def profile_leaky_relu(test_case):\n        torch.nn.functional.leaky_relu(torch.ones(1, 128, 28, 28), 0.1)\n        torch.nn.functional.leaky_relu(torch.ones(1, 128, 28, 28), 0.1, inplace=True)\n        torch.nn.functional.leaky_relu(torch.ones(16, 128, 28, 28), 0.1)\n        torch.nn.functional.leaky_relu(torch.ones(16, 128, 28, 28), 0.1, inplace=True)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMishModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_mish_module_with_random_data(test_case):\n        m = torch.nn.Mish()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_mish_module_with_0dim_data(test_case):\n        m = torch.nn.Mish()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x 
= random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.mish)\n    def profile_mish(test_case):\n        torch.nn.functional.mish(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.mish(torch.ones(16, 128, 28, 28))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSiluModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_silu_module_with_random_data(test_case):\n        m = torch.nn.SiLU()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_silu_module_with_0dim_data(test_case):\n        m = torch.nn.SiLU()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.silu)\n    def profile_silu(test_case):\n        torch.nn.functional.silu(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.silu(torch.ones(16, 128, 28, 28))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSeluModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_selu_module_with_random_data(test_case):\n        m = torch.nn.SELU()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_selu_module_with_0dim_data(test_case):\n        m = torch.nn.SELU()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.selu)\n    def profile_selu(test_case):\n        torch.nn.functional.selu(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.selu(torch.ones(16, 128, 28, 28))\n\n\n@unittest.skip(\"still has errors in CI tests\")\nclass TestSoftsignModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_softsign_module_with_random_data(test_case):\n        m = torch.nn.Softsign()\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    # profiling prints 'Ran 1 test in 0.000s' and returns a blank table\n    @profile(torch.nn.functional.softsign)\n    def profile_softsign(test_case):\n        torch.nn.functional.softsign(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.softsign(torch.ones(16, 128, 28, 28))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestReluFunction(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_relu_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim1=3).to(device)\n        y = torch.relu(x)\n        return y\n\n    @autotest(n=5)\n    def test_flow_relu_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.relu(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRelu6Function(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_nn_functional_relu6_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim1=3).to(device)\n        y = torch.nn.functional.relu6(x)\n        return y\n\n    @autotest(n=5)\n    def test_flow_nn_functional_relu6_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = 
torch.nn.functional.relu6(x)\n        return y\n\n    @profile(torch.nn.functional.relu6)\n    def profile_relu6(test_case):\n        torch.nn.functional.relu6(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.relu6(torch.ones(16, 128, 28, 28))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLogSigmoidFunction(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_nn_functional_logsigmoid_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim1=3).to(device)\n        y = torch.nn.functional.logsigmoid(x)\n        return y\n\n    @autotest(n=5)\n    def test_flow_nn_functional_logsigmoid_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.nn.functional.logsigmoid(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestHardshrinkModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_hardshrink_module_with_random_data(test_case):\n        m = torch.nn.Hardshrink(lambd=random() | nothing())\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_hardshrink_module_with_0dim_data(test_case):\n        m = torch.nn.Hardshrink(lambd=random() | nothing())\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5, auto_backward=False)\n    def test_hardshrink_module_with_0_size_data(test_case):\n        m = torch.nn.Hardshrink(lambd=random() | nothing())\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(4, 2, 3, 0, 3).to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.hardshrink)\n    def profile_hardshrink(test_case):\n        torch.nn.functional.hardshrink(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.hardshrink(torch.ones(16, 128, 28, 28))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSoftshrinkModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_softshrink_module_with_random_data(test_case):\n        m = torch.nn.Softshrink(alpha=random() | nothing())\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_softshrink_module_with_0dim_data(test_case):\n        m = torch.nn.Softshrink(alpha=random() | nothing())\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5, auto_backward=False)\n    def test_softshrink_module_with_0_size_data(test_case):\n        m = torch.nn.Softshrink(alpha=random() | nothing())\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(4, 2, 3, 0, 3).to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.softshrink)\n    def profile_softshrink(test_case):\n        torch.nn.functional.softshrink(torch.ones(1, 128, 28, 28))\n        torch.nn.functional.softshrink(torch.ones(16, 128, 28, 28))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestThresholdModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_threshold_module_with_random_data(test_case):\n        m = 
torch.nn.Threshold(\n            threshold=random() | nothing(), value=random() | nothing()\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_threshold_module_with_0dim_data(test_case):\n        m = torch.nn.Threshold(\n            threshold=random() | nothing(), value=random() | nothing()\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5, auto_backward=False)\n    def test_threshold_module_with_0_size_data(test_case):\n        m = torch.nn.Threshold(\n            threshold=random() | nothing(), value=random() | nothing()\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(4, 2, 3, 0, 3).to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.threshold)\n    def profile_threshold(test_case):\n        torch.nn.functional.threshold(\n            torch.ones(1, 128, 28, 28), threshold=0.1, value=20\n        )\n        torch.nn.functional.threshold(\n            torch.ones(16, 128, 28, 28), threshold=0.1, value=20\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_adaptive_max_pool.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.nn.common_types import _size_1_t\nfrom packaging import version\nimport torch as torch_original\nfrom typing import Union, Tuple\nimport numpy as np\n\n\nfrom oneflow.test_utils.automated_test_util import *\n\nNoneType = type(None)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAdaptiveMaxPool(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    @autotest(n=5)\n    def test_adaptive_maxpool1d(test_case):\n        m = torch.nn.AdaptiveMaxPool1d(output_size=random().to(_size_1_t))\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=3).to(device)\n        y = m(x)\n        return y\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_adaptive_maxpool2d_manually(test_case):\n        def _test_adaptive_max_pool_nd(input_shape, output_shape, m1, m2):\n            input_np = np.random.rand(2, 3, *input_shape)\n            input_pt = torch_original.tensor(\n                input_np, device=\"cuda\", requires_grad=True\n            )\n            input_of = flow.tensor(input_np, device=\"cuda\", requires_grad=True)\n\n            m_pt = m1(output_shape, True)\n            m_of = m2(output_shape, True)\n\n            output_pt = m_pt(input_pt)\n            output_of = m_of(input_of)\n\n            sum_pt = torch_original.sum(output_pt[0])\n            sum_of = flow.sum(output_of[0])\n\n            sum_pt.backward()\n            sum_of.backward()\n\n            test_case.assertTrue(\n                np.array_equal(\n                    output_pt[0].detach().cpu().numpy(),\n                    output_of[0].detach().cpu().numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(\n                    output_pt[1].detach().cpu().numpy(),\n                    output_of[1].detach().cpu().numpy(),\n                )\n            )\n            test_case.assertTrue(\n                np.array_equal(input_pt.grad.cpu().numpy(), input_of.grad.cpu().numpy())\n            )\n\n        _test_adaptive_max_pool_nd(\n            (10, 11),\n            (3, 4),\n            torch_original.nn.AdaptiveMaxPool2d,\n            flow.nn.AdaptiveMaxPool2d,\n        )\n        _test_adaptive_max_pool_nd(\n            (10, 11, 12),\n            (3, 4, 5),\n            torch_original.nn.AdaptiveMaxPool3d,\n            flow.nn.AdaptiveMaxPool3d,\n        )\n\n    @profile(torch.nn.functional.adaptive_max_pool1d)\n    def profile_adaptive_max_pool1d(test_case):\n        torch.nn.functional.adaptive_max_pool1d(torch.ones(1, 64, 8), 5)\n\n    @profile(torch.nn.functional.adaptive_max_pool2d)\n    def profile_adaptive_max_pool2d(test_case):\n        torch.nn.functional.adaptive_max_pool2d(torch.ones(1, 64, 10, 9), 7)\n 
       torch.nn.functional.adaptive_max_pool2d(torch.ones(1, 64, 8, 9), (5, 7))\n\n    @profile(torch.nn.functional.adaptive_max_pool3d)\n    def profile_adaptive_max_pool3d(test_case):\n        torch.nn.functional.adaptive_max_pool3d(torch.ones(1, 64, 8, 9, 10), (5, 7, 9))\n        torch.nn.functional.adaptive_max_pool3d(torch.ones(1, 64, 10, 9, 8), 7)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_adaptive_pool.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.nn.common_types import _size_1_t\nfrom packaging import version\nimport torch as torch_original\nfrom typing import Union, Tuple\n\n\nfrom oneflow.test_utils.automated_test_util import *\n\nNoneType = type(None)\n# Not the same as those in PyTorch because 'output_size' cannot be NoneType (even in 'torch.nn.AdaptiveAvgPoolXd')\n_size_2_opt_t_not_none = Union[int, Tuple[Union[int, NoneType], Union[int, NoneType]]]\n_size_3_opt_t_not_none = Union[\n    int, Tuple[Union[int, NoneType], Union[int, NoneType], Union[int, NoneType]]\n]\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAdaptiveAvgPool(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_adaptive_avgpool1d(test_case):\n        m = torch.nn.AdaptiveAvgPool1d(output_size=random().to(_size_1_t))\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=3).to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.adaptive_avg_pool1d)\n    def profile_adaptive_avg_pool1d(test_case):\n        torch.nn.functional.adaptive_avg_pool1d(torch.ones(1, 64, 8), 5)\n\n    @autotest(n=5)\n    def test_adaptive_avgpool2d(test_case):\n        m = torch.nn.AdaptiveAvgPool2d(output_size=random().to(_size_2_opt_t_not_none))\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=4).to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.adaptive_avg_pool2d)\n    def profile_adaptive_avg_pool2d(test_case):\n        torch.nn.functional.adaptive_avg_pool2d(torch.ones(1, 64, 10, 9), 7)\n        torch.nn.functional.adaptive_avg_pool2d(torch.ones(1, 64, 8, 9), (5, 7))\n\n    @unittest.skipIf(\n        version.parse(torch_original.__version__) < version.parse(\"1.10.0\"),\n        \"GPU version 'nn.AdaptiveAvgPool3d' has a bug in PyTorch before '1.10.0'\",\n    )\n    @autotest(n=5)\n    def test_adaptive_avgpool3d(test_case):\n        m = torch.nn.AdaptiveAvgPool3d(output_size=random().to(_size_3_opt_t_not_none))\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=5).to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.adaptive_avg_pool3d)\n    def profile_adaptive_avg_pool3d(test_case):\n        torch.nn.functional.adaptive_avg_pool3d(torch.ones(1, 64, 8, 9, 10), (5, 7, 9))\n        torch.nn.functional.adaptive_avg_pool3d(torch.ones(1, 64, 10, 9, 8), 7)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAdaptiveAvgPoolFunctional(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_adaptive_avgpool1d_functional(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3).to(device)\n        return torch.nn.functional.adaptive_avg_pool1d(x, output_size=random().to(int))\n\n   
 @autotest(n=5)\n    def test_adaptive_avgpool2d_functional(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        return torch.nn.functional.adaptive_avg_pool2d(x, output_size=random().to(int))\n\n    @unittest.skipIf(\n        version.parse(torch_original.__version__) < version.parse(\"1.10.0\"),\n        \"GPU version 'nn.AdaptiveAvgPool3d' has a bug in PyTorch before '1.10.0'\",\n    )\n    @autotest(n=5)\n    def test_adaptive_avgpool3d_functional(test_case):\n        device = random_device()\n        x = random_tensor(ndim=5).to(device)\n        return torch.nn.functional.adaptive_avg_pool3d(x, output_size=random().to(int))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_adaptive_pool_fp16.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.nn.common_types import _size_1_t\nfrom packaging import version\nimport torch as torch_original\nfrom typing import Union, Tuple\nfrom oneflow.test_utils.automated_test_util import *\n\nNoneType = type(None)\n_size_2_opt_t_not_none = Union[int, Tuple[Union[int, NoneType], Union[int, NoneType]]]\n_size_3_opt_t_not_none = Union[\n    int, Tuple[Union[int, NoneType], Union[int, NoneType], Union[int, NoneType]]\n]\n\n\n@flow.unittest.skip_unless_1n1d()\nclass Test_CpuFp16_AdaptiveAvgPool(flow.unittest.TestCase):\n    @autotest(n=5, rtol=0.01, atol=0.01)\n    def test_adaptive_avgpool1d(test_case):\n        m = torch.nn.AdaptiveAvgPool1d(output_size=random().to(_size_1_t))\n        m.train(random())\n        device = \"cpu\"\n        m.to(device)\n        x = random_tensor(ndim=3).to(device)\n        x = x.clone().half()\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.adaptive_avg_pool1d)\n    def profile_adaptive_avg_pool1d(test_case):\n        return torch.nn.functional.adaptive_avg_pool1d(torch.ones(1, 64, 8).half(), 5)\n\n    @autotest(n=5, rtol=0.01, atol=0.01)\n    def test_adaptive_avgpool2d(test_case):\n        m = torch.nn.AdaptiveAvgPool2d(output_size=random().to(_size_2_opt_t_not_none))\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=4).to(device)\n        x = x.half()\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.adaptive_avg_pool2d)\n    def profile_adaptive_avg_pool2d(test_case):\n        torch.nn.functional.adaptive_avg_pool2d(torch.ones(1, 64, 10, 9).half(), 7)\n        torch.nn.functional.adaptive_avg_pool2d(torch.ones(1, 64, 8, 9).half(), (5, 7))\n\n    @unittest.skipIf(\n        version.parse(torch_original.__version__) < version.parse(\"1.10.0\"),\n        \"GPU version 'nn.AdaptiveAvgPool3d' has a bug in PyTorch before '1.10.0'\",\n    )\n    @autotest(n=5, rtol=0.01, atol=0.01)\n    def test_adaptive_avgpool3d(test_case):\n        m = torch.nn.AdaptiveAvgPool3d(output_size=random().to(_size_3_opt_t_not_none))\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=5).to(device)\n        x = x.half()\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.adaptive_avg_pool3d)\n    def profile_adaptive_avg_pool3d(test_case):\n        torch.nn.functional.adaptive_avg_pool3d(\n            torch.ones(1, 64, 8, 9, 10).half(), (5, 7, 9)\n        )\n        torch.nn.functional.adaptive_avg_pool3d(torch.ones(1, 64, 10, 9, 8).half(), 7)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass Test_CpuFp16_AdaptiveAvgPoolFunctional(flow.unittest.TestCase):\n    @autotest(n=5, rtol=0.01, atol=0.01)\n    def test_adaptive_avgpool1d_functional(test_case):\n        device = random_device()\n        x = 
random_tensor(ndim=3).to(device)\n        x = x.half()\n        return torch.nn.functional.adaptive_avg_pool1d(x, output_size=random().to(int))\n\n    @autotest(n=5, rtol=0.01, atol=0.01)\n    def test_adaptive_avgpool2d_functional(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        x = x.half()\n        return torch.nn.functional.adaptive_avg_pool2d(x, output_size=random().to(int))\n\n    @autotest(n=5, rtol=0.01, atol=0.01)\n    def test_adaptive_avgpool3d_functional(test_case):\n        device = random_device()\n        x = random_tensor(ndim=5).to(device)\n        x = x.half()\n        return torch.nn.functional.adaptive_avg_pool3d(x, output_size=random().to(int))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_add.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport torch as torch_original\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_add_forward(test_case, shape, device):\n    x = flow.tensor(np.random.randn(*shape), device=flow.device(device))\n    y = flow.tensor(np.random.randn(*shape), device=flow.device(device))\n    of_out = flow.add(x, y)\n    np_out = np.add(x.numpy(), y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    x = 5\n    y = flow.tensor(np.random.randn(*shape), device=flow.device(device))\n    of_out = flow.add(x, y)\n    np_out = np.add(x, y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    x = flow.tensor(np.random.randn(*shape), device=flow.device(device))\n    y = 5\n    of_out = flow.add(x, y)\n    np_out = np.add(x.numpy(), y)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    x = flow.tensor(np.random.randn(*shape), device=flow.device(device))\n    y = flow.tensor(np.array([5.0]), device=flow.device(device))\n    of_out = flow.add(x, y)\n    np_out = np.add(x.numpy(), y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    x = flow.tensor(np.random.randn(1, 1), device=flow.device(device))\n    y = flow.tensor(np.random.randn(*shape), device=flow.device(device))\n    of_out = flow.add(x, y)\n    np_out = np.add(x.numpy(), y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n\n\ndef _test_add_backward(test_case, shape, device):\n    x = 5\n    y = flow.tensor(\n        np.random.randn(*shape), requires_grad=True, device=flow.device(device)\n    )\n    of_out = flow.add(x, y).sum()\n    of_out.backward()\n    test_case.assertTrue(\n        np.allclose(y.grad.numpy(), np.ones(shape=shape), 0.0001, 0.0001)\n    )\n\n\ndef _test_inplace_add(test_case, shape, device):\n    np_x = np.random.randn(*shape)\n    of_x = flow.tensor(\n        np_x, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_x_inplace = of_x + 1\n    id_old = id(of_x_inplace)\n    of_x_inplace.add_(5)\n    test_case.assertEqual(id_old, id(of_x_inplace))\n    np_out = np_x + 1 + 5\n    test_case.assertTrue(np.allclose(of_x_inplace.numpy(), np_out, 1e-05, 1e-05))\n    of_x_inplace = of_x_inplace.sum()\n    of_x_inplace.backward()\n    test_case.assertTrue(np.allclose(of_x.grad.numpy(), np.ones(shape), 1e-05, 1e-05))\n\n    of_x = flow.tensor(\n        np_x, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_y = flow.tensor(\n        np.random.randn(*shape),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=False,\n    )\n    of_x_inplace = of_x + 
1\n    id_old = id(of_x_inplace)\n    of_x_inplace.add_(of_y)\n    test_case.assertEqual(id_old, id(of_x_inplace))\n    np_out = np_x + 1 + of_y.numpy()\n    test_case.assertTrue(np.allclose(of_x_inplace.numpy(), np_out, 1e-05, 1e-05))\n    of_x_inplace = of_x_inplace.sum()\n    of_x_inplace.backward()\n    test_case.assertTrue(np.allclose(of_x.grad.numpy(), np.ones(shape), 1e-05, 1e-05))\n\n    of_x = flow.tensor(\n        np_x, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_y = flow.tensor(\n        np.random.randn(*shape),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=False,\n    )\n    of_x_inplace = of_x + 1\n    id_old = id(of_x_inplace)\n    of_x_inplace += of_y\n    test_case.assertEqual(id_old, id(of_x_inplace))\n    np_out = np_x + 1 + of_y.numpy()\n    test_case.assertTrue(np.allclose(of_x_inplace.numpy(), np_out, 1e-05, 1e-05))\n    of_x_inplace = of_x_inplace.sum()\n    of_x_inplace.backward()\n    test_case.assertTrue(np.allclose(of_x.grad.numpy(), np.ones(shape), 1e-05, 1e-05))\n\n    of_x = flow.tensor(\n        np_x, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_y = flow.tensor(\n        np.array([5.0]),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=False,\n    )\n    of_x_inplace = of_x + 1\n    id_old = id(of_x_inplace)\n    of_x_inplace.add_(of_y)\n    test_case.assertEqual(id_old, id(of_x_inplace))\n    np_out = np_x + 6\n    test_case.assertTrue(np.allclose(of_x_inplace.numpy(), np_out, 1e-05, 1e-05))\n    of_x_inplace = of_x_inplace.sum()\n    of_x_inplace.backward()\n    test_case.assertTrue(np.allclose(of_x.grad.numpy(), np.ones(shape), 1e-05, 1e-05))\n\n    of_x = flow.tensor(\n        np_x, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    np_y = np.random.randn(*shape[:-1], 1)\n    of_y = flow.tensor(\n        np_y, dtype=flow.float32, device=flow.device(device), requires_grad=False\n    )\n    of_x_inplace = of_x + 1\n    id_old = id(of_x_inplace)\n    of_x_inplace.add_(of_y)\n    test_case.assertEqual(id_old, id(of_x_inplace))\n    np_out = np_x + 1 + np_y\n    test_case.assertTrue(np.allclose(of_x_inplace.numpy(), np_out, 1e-05, 1e-05))\n    of_x_inplace = of_x_inplace.sum()\n    of_x_inplace.backward()\n    test_case.assertTrue(np.allclose(of_x.grad.numpy(), np.ones(shape), 1e-05, 1e-05))\n\n\ndef _test_inplace_add_with_type_promotion(test_case, shape, device):\n    x = flow.tensor(\n        np.random.randn(*shape), device=flow.device(device), dtype=flow.float16\n    )\n    y = flow.tensor(\n        np.random.randn(*shape), device=flow.device(device), dtype=flow.float32\n    )\n    x += y\n    test_case.assertTrue(x.dtype == flow.float16)\n\n\ndef _test_inplace_add_0_size_tensor(test_case, shape, device):\n    x = flow.randn(0, 256, device=device)\n    y = flow.randn(1, 256, device=device)\n    x += y\n    test_case.assertEqual(x.size(), (0, 256))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAddModule(flow.unittest.TestCase):\n    def test_add(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_add_forward,\n            _test_add_backward,\n            _test_inplace_add,\n            _test_inplace_add_with_type_promotion,\n            _test_inplace_add_0_size_tensor,\n        ]\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in 
GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=10, include_complex=True)\n    def test_0_size_add(test_case):\n        device = random_device()\n        x = random_tensor(2, 0, 3).to(device)\n        y = random_tensor(2, 1, 3).to(device)\n        out = x + y\n        return out\n\n    @autotest(n=6, auto_backward=False, include_complex=True)\n    def test_0dim_inplace_add(test_case):\n        device = random_device()\n        x = random_tensor(2, 2, 3, requires_grad=False).to(device)\n        y = random_tensor(1, 10).to(device)\n        x += y.mean()\n        return x\n\n    @autotest(n=10, include_complex=True)\n    def test_0dim_two_inplace_add(test_case):\n        device = random_device()\n        x = random_tensor(2, 2, 3).to(device).mean()\n        y = random_tensor(2, 2, 3).to(device)\n        x += y.mean()\n        return x\n\n    @autotest(n=6, include_complex=True)\n    def test_add_with_alpha(test_case):\n        device = random_device()\n        x1 = random_tensor(2, 2, 3).to(device).mean()\n        x2 = random_tensor(2, 2, 3).to(device).mean()\n        x3 = random_tensor(2, 2, 3).to(device).mean()\n        y = random_tensor(2, 2, 3).to(device)\n        s = random().to(float)\n        alpha = random().to(float)\n        z1 = torch.add(x1, y, alpha=alpha)\n        z2 = torch.add(x2, s, alpha=alpha)\n        z3 = torch.add(s, x3, alpha=alpha)\n        return z1, z2, z3\n\n    @autotest(auto_backward=False)\n    def test_bool_add(test_case):\n        device = random_device()\n        x = random_tensor(2, 1, 3).to(device, torch.bool)\n        y = random_tensor(2, 1, 3).to(device, torch.bool)\n        out = x + y\n        return out\n\n    @autotest(auto_backward=False)\n    def test_0shape_bool_add(test_case):\n        device = random_device()\n        x = random_tensor(2, 0, 3).to(device, torch.bool)\n        y = random_tensor(2, 1, 3).to(device, torch.bool)\n        out = x + y\n        return out\n\n    @autotest(n=3, auto_backward=False)\n    def test_0dim_bool_inplace_add(test_case):\n        device = random_device()\n        x = random_tensor(2, 2, 3, requires_grad=False).to(device, torch.bool)\n        y = random_tensor(1, 10).to(device)\n        x += y.mean().to(torch.bool)\n        return x\n\n    @autotest(auto_backward=False)\n    def test_0dim_two_bool_inplace_add(test_case):\n        device = random_device()\n        x = random_tensor(2, 2, 3).to(device).mean().to(torch.bool)\n        y = random_tensor(2, 2, 3).to(device)\n        x += y.mean().to(torch.bool)\n        return x\n\n    @autotest(n=6, include_complex=True)\n    def test_add_with_alpha_0dim(test_case):\n        device = random_device()\n        x1 = random_tensor(ndim=0).to(device).mean()\n        x2 = random_tensor(ndim=0).to(device).mean()\n        x3 = random_tensor(ndim=0).to(device).mean()\n        y = random_tensor(ndim=0).to(device)\n        s = random().to(float)\n        alpha = random().to(float)\n        z1 = torch.add(x1, y, alpha=alpha)\n        z2 = torch.add(x2, s, alpha=alpha)\n        z3 = torch.add(s, x3, alpha=alpha)\n        return z1, z2, z3\n\n    @profile(torch.add)\n    def profile_add(test_case):\n        torch.add(torch.ones(100), 20)\n        torch.add(torch.ones(100), torch.ones(100, 1), alpha=10)\n\n    @autotest(n=6, include_complex=True)\n    def test_non_contiguous_inplace_add(test_case):\n        device = random_device()\n        x = random_tensor(2, 2, 4).to(device)\n        y = x + 1\n        y = y[:, 1:3]\n        y += random_tensor(2, 2, 
2).to(device)\n        return y\n\n    @autotest(n=10, include_complex=True)\n    def test_scalar_add_with_random_devices(test_case):\n        x1_device = random_device()\n        x2_device = random_device()\n        x1 = random_tensor(2, 2, 3).to(x1_device).mean()\n        x2 = random_tensor(2, 2, 3).to(x2_device)\n        y = x1 + x2\n        return y\n
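\n    # Illustrative sketch (not part of the original suite): broadcasting on\n    # both operands at once, complementing the one-sided broadcast cases\n    # above; assumes random_tensor/autotest behave as in those cases.\n    @autotest(n=5)\n    def test_add_broadcast_both_operands(test_case):\n        device = random_device()\n        x = random_tensor(3, 2, 1, 3).to(device)\n        y = random_tensor(3, 1, 4, 3).to(device)\n        out = x + y\n        return out\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"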
  },
  {
    "path": "python/oneflow/test/modules/test_addcdiv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAddcdiv(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_addcdiv(test_case):\n        device = random_device()\n        ndim = random(2, 4).to(int).value()\n        shape = [random(2, 4) for i in range(ndim)]\n        input = random_tensor(ndim, *shape).to(device)\n        tensor1 = random_tensor(ndim, *shape).to(device)\n        tensor2 = random_tensor(ndim, *shape).to(device)\n        value = random(2, 4).to(int)\n        output = torch.addcdiv(input, tensor1, tensor2, value=value)\n        return output\n\n    @autotest(n=5)\n    def test_tensor_addcdiv(test_case):\n        device = random_device()\n        ndim = random(2, 4).to(int).value()\n        shape = [random(2, 4) for i in range(ndim)]\n        input = random_tensor(ndim, *shape).to(device)\n        tensor1 = random_tensor(ndim, *shape).to(device)\n        tensor2 = random_tensor(ndim, *shape).to(device)\n        value = random(2, 4).to(int)\n        output = input.addcdiv(tensor1, tensor2, value=value)\n        return output\n\n    @autotest(n=5)\n    def test_tensor_addcdiv_inplace(test_case):\n        device = random_device()\n        ndim = random(2, 4).to(int).value()\n        shape = [random(2, 4) for i in range(ndim)]\n        input = random_tensor(ndim, *shape).to(device)\n        input = input + 1.0\n        tensor1 = random_tensor(ndim, *shape).to(device)\n        tensor2 = random_tensor(ndim, *shape).to(device)\n        value = random(2, 4).to(int)\n        input.addcdiv_(tensor1, tensor2, value=value)\n        return input\n\n    @profile(torch.addcdiv)\n    def profile_addcdiv(test_case):\n        t = torch.ones(1, 3)\n        t1 = torch.ones(3, 1)\n        t2 = torch.ones(1, 3)\n        torch.addcdiv(t, t1, t2, value=0.1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_addcmul.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAddcmul(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_addcmul(test_case):\n        device = random_device()\n        ndim = random(low=2).to(int).value()\n        shape = [random(low=2, high=4) for i in range(ndim)]\n\n        input = random_tensor(len(shape), *shape).to(device)\n        tensor1 = random_tensor(len(shape), *shape).to(device)\n        tensor2 = random_tensor(len(shape), *shape).to(device)\n        value = random(3, 6).to(int)\n        output = torch.addcmul(input, tensor1, tensor2, value=value)\n        return output\n\n    @autotest(check_graph=True)\n    def test_tensor_addcmul(test_case):\n        device = random_device()\n        ndim = random(low=2).to(int).value()\n        shape = [random(low=2, high=4) for i in range(ndim)]\n\n        input = random_tensor(len(shape), *shape).to(device)\n        tensor1 = random_tensor(len(shape), *shape).to(device)\n        tensor2 = random_tensor(len(shape), *shape).to(device)\n        value = random(3, 6).to(int)\n        output = input.addcmul(tensor1, tensor2, value=value)\n        return output\n\n    @autotest(check_graph=True)\n    def test_tensor_addcmul_inplace(test_case):\n        device = random_device()\n        ndim = random(low=2).to(int).value()\n        shape = [random(low=2, high=4) for i in range(ndim)]\n\n        input = random_tensor(len(shape), *shape).to(device)\n        input = input + 1.0\n        tensor1 = random_tensor(len(shape), *shape).to(device)\n        tensor2 = random_tensor(len(shape), *shape).to(device)\n        value = random(3, 6).to(int)\n        input.addcmul_(tensor1, tensor2, value=value)\n        return input\n\n    @profile(torch.addcmul)\n    def profile_addcmul(test_case):\n        input = torch.ones(100, 12, 13)\n        tensor1 = torch.ones(100, 12, 13)\n        tensor2 = torch.ones(100, 12, 13)\n        torch.addcmul(input, tensor1, tensor2, value=2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_addmm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_addmm(test_case, shape, alpha, beta, device):\n    mat1 = np.random.randn(*shape)\n    mat2 = np.random.randn(*shape)\n    input = np.random.randn(*shape)\n    mat1_tensor = flow.tensor(mat1, dtype=flow.float32, device=flow.device(device))\n    mat2_tensor = flow.tensor(mat2, dtype=flow.float32, device=flow.device(device))\n    input_tensor = flow.tensor(input, dtype=flow.float32, device=flow.device(device))\n    of_out = flow.addmm(input_tensor, mat1_tensor, mat2_tensor, alpha, beta)\n    np_out = np.add(beta * input, alpha * np.matmul(mat1, mat2))\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_addmm_backward(test_case, shape, alpha, beta, device):\n    mat1 = np.random.randn(*shape)\n    mat2 = np.random.randn(*shape)\n    input = np.random.randn(*shape)\n    mat1_tensor = flow.tensor(mat1, dtype=flow.float32, device=flow.device(device))\n    mat2_tensor = flow.tensor(mat2, dtype=flow.float32, device=flow.device(device))\n    input_tensor = flow.tensor(\n        input, dtype=flow.float32, requires_grad=True, device=flow.device(device)\n    )\n    of_out = flow.addmm(input_tensor, mat1_tensor, mat2_tensor, alpha, beta).sum()\n    of_out.backward()\n    np_grad_out = np.ones_like(input) * beta\n    test_case.assertTrue(\n        np.allclose(input_tensor.grad.numpy(), np_grad_out, 1e-05, 1e-05)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAddmm(flow.unittest.TestCase):\n    def test_addmm(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"function_test\"] = [_test_addmm, _test_addmm_backward]\n        arg_dict[\"shape\"] = [(3, 3)]\n        arg_dict[\"alpha\"] = [4, 1.2, -3.7]\n        arg_dict[\"beta\"] = [1.5, 4, -2]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_addmm_flow_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=2, dim0=2, dim1=3).to(device)\n        mat1 = random_tensor(ndim=2, dim0=2, dim1=4).to(device)\n        mat2 = random_tensor(ndim=2, dim0=4, dim1=3).to(device)\n        y = torch.addmm(\n            input,\n            mat1,\n            mat2,\n            beta=random().to(float) | nothing(),\n            alpha=random().to(float) | nothing(),\n        )\n        return y\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_addmm_broadcast_flow_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=2, dim0=1, dim1=1).to(device)\n        mat1 = random_tensor(ndim=2, dim0=2, dim1=4).to(device)\n   
     mat2 = random_tensor(ndim=2, dim0=4, dim1=3).to(device)\n        y = torch.addmm(\n            input,\n            mat1,\n            mat2,\n            beta=random().to(float) | nothing(),\n            alpha=random().to(float) | nothing(),\n        )\n        return y\n\n    @profile(torch.addmm)\n    def profile_addmm(test_case):\n        input = torch.ones(2, 3)\n        mat1 = torch.ones(2, 3)\n        mat2 = torch.ones(3, 3)\n        torch.addmm(input, mat1, mat2)\n        torch.addmm(input, mat1, mat2, alpha=1, beta=2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_affine_grid.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom random import randint\nfrom random import choice\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAffineGrid(flow.unittest.TestCase):\n    def test_affine_grid_2d(test_case):\n        input = flow.tensor(np.arange(1.0, 7).reshape((1, 2, 3)), dtype=flow.float32)\n        output = flow.nn.functional.affine_grid(\n            input, flow.Size([1, 1, 2, 2]), align_corners=True\n        )\n        groundtruth = np.array([[[[0.0, -3.0], [2.0, 5.0]], [[4.0, 7.0], [6.0, 15.0]]]])\n        test_case.assertTrue(\n            np.allclose(output.numpy(), groundtruth, rtol=1e-3, atol=1e-4)\n        )\n\n        output = flow.nn.functional.affine_grid(\n            input, flow.Size([1, 1, 2, 2]), align_corners=False\n        )\n        groundtruth = np.array([[[[1.5, 1.5], [2.5, 5.5]], [[3.5, 6.5], [4.5, 10.5]]]])\n        test_case.assertTrue(\n            np.allclose(output.numpy(), groundtruth, rtol=1e-3, atol=1e-4)\n        )\n\n    def test_affine_grid_3d(test_case):\n        input = flow.tensor(np.arange(1.0, 13).reshape((1, 3, 4)), dtype=flow.float32)\n        output = flow.nn.functional.affine_grid(\n            input, flow.Size([1, 1, 2, 2, 2]), align_corners=True\n        )\n        groundtruth = np.array(\n            [\n                [\n                    [\n                        [[-2.0, -10.0, -18.0], [0.0, 0.0, 0.0]],\n                        [[2.0, 2.0, 2.0], [4.0, 12.0, 20.0]],\n                    ],\n                    [\n                        [[4.0, 4.0, 4.0], [6.0, 14.0, 22.0]],\n                        [[8.0, 16.0, 24.0], [10.0, 26.0, 42.0]],\n                    ],\n                ]\n            ]\n        )\n        test_case.assertTrue(\n            np.allclose(output.numpy(), groundtruth, rtol=1e-3, atol=1e-4)\n        )\n\n        output = flow.nn.functional.affine_grid(\n            input, flow.Size([1, 1, 2, 2, 2]), align_corners=False\n        )\n        groundtruth = np.array(\n            [\n                [\n                    [\n                        [[1.0, -1.0, -3.0], [2.0, 4.0, 6.0]],\n                        [[3.0, 5.0, 7.0], [4.0, 10.0, 16.0]],\n                    ],\n                    [\n                        [[4.0, 6.0, 8.0], [5.0, 11.0, 17.0]],\n                        [[6.0, 12.0, 18.0], [7.0, 17.0, 27.0]],\n                    ],\n                ]\n            ]\n        )\n        test_case.assertTrue(\n            np.allclose(output.numpy(), groundtruth, rtol=1e-3, atol=1e-4)\n        )\n\n    @autotest(n=5, rtol=1e-03, atol=1e-04, check_allclose=False, check_graph=True)\n    def test_flow_affine_grid_2d_with_random_data(test_case):\n        N = randint(1, 8)\n        C = randint(1, 8)\n        H = randint(1, 8)\n        W = randint(1, 8)\n        device = 
random_device()\n        align_corners = choice([True, False])\n        theta = random_tensor(ndim=3, dim0=N, dim1=2, dim2=3).to(device)\n        output = torch.nn.functional.affine_grid(\n            theta, (N, C, H, W), align_corners=align_corners\n        ).to(device)\n        return output\n\n    @autotest(rtol=1e-03, atol=1e-03, check_allclose=False, check_graph=True)\n    def test_flow_affine_grid_3d_with_random_data(test_case):\n        N = randint(1, 8)\n        C = randint(1, 8)\n        D = randint(1, 8)\n        H = randint(1, 8)\n        W = randint(1, 8)\n        device = random_device()\n        align_corners = choice([True, False])\n        theta = random_tensor(ndim=3, dim0=N, dim1=3, dim2=4).to(device)\n        output = torch.nn.functional.affine_grid(\n            theta, (N, C, D, H, W), align_corners=align_corners\n        ).to(device)\n        return output\n\n    @profile(torch.nn.functional.affine_grid)\n    def profile_affine_grid(test_case):\n        input = torch.tensor(np.arange(1.0, 7).reshape((1, 2, 3)), dtype=torch.float32)\n        torch.nn.functional.affine_grid(\n            input, torch.Size([1, 1, 2, 2]), align_corners=True\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_allclose.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\nrtol = 1e-3\n\n\ndef _perturbate(x):\n    shape = x.oneflow.shape\n    device = x.device\n    diff = (\n        random_tensor(len(shape), *shape, low=-1, high=1, requires_grad=False).to(\n            device\n        )\n        * rtol\n        * 2\n    )\n    return x + diff\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAllClose(flow.unittest.TestCase):\n    @autotest(n=10, auto_backward=False, check_graph=False)\n    def test_allclose_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(requires_grad=False).to(device)\n        x2 = _perturbate(x1)\n        y = torch.allclose(x1, x2, rtol=rtol)\n        return y\n\n    @autotest(n=10, auto_backward=False, check_graph=False)\n    def test_allclose_with_0dim_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(requires_grad=False).to(device)\n        x2 = _perturbate(x1)\n        y = torch.allclose(x1, x2, rtol=rtol)\n        return y\n\n    @autotest(n=10, auto_backward=False, check_graph=False)\n    def test_tensor_allclose_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(requires_grad=False).to(device)\n        x2 = _perturbate(x1)\n        y = x1.allclose(x2, rtol=rtol)\n        return y\n\n    @autotest(n=10, auto_backward=False, check_graph=False)\n    def test_allclose_broadcast(test_case):\n        device = random_device()\n        x1 = random_tensor(2, 2, 8, requires_grad=False).to(device)\n        x2 = _perturbate(x1[:, :1])\n        y = torch.allclose(x1, x2, rtol=rtol)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_allreduce.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestAllReduce(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n2d()\n    def test_all_reduce(test_case):\n        arr_rank1 = np.array([1, 2])\n        arr_rank2 = np.array([3, 4])\n        if flow.env.get_rank() == 0:\n            x = flow.Tensor(arr_rank1)\n        elif flow.env.get_rank() == 1:\n            x = flow.Tensor(arr_rank2)\n        else:\n            raise ValueError\n        x = x.to(\"cuda\")\n        y = flow._C.local_all_reduce(x)\n        test_case.assertTrue(np.allclose(y.numpy(), arr_rank1 + arr_rank2))\n\n    @flow.unittest.skip_unless_2n2d()\n    def test_all_reduce_2nodes(test_case):\n        np_arr = np.array([1, 2])\n        x = flow.Tensor(np_arr * (flow.env.get_rank() + 1))\n        x = x.to(\"cuda\")\n        y = flow._C.local_all_reduce(x)\n        test_case.assertTrue(np.allclose(y.numpy(), np_arr * 10))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_amax.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport numpy as np\n\n\ndef __check(test_case, input, dim, keepdim, device):\n    of_out = flow.amax(input, dim=dim, keepdim=keepdim)\n    if type(dim) is tuple:\n        if len(dim) == 0:\n            dim = None\n    np_out = np.amax(input.numpy(), axis=dim, keepdims=keepdim)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, rtol=0.0001, atol=1e-05,))\n\n\ndef _test_amax_with_negative_dim(test_case, device):\n    input = flow.tensor(\n        np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device)\n    )\n    dim = random(-4, 0).to(int).value()\n    keepdim = random_bool().value()\n    __check(test_case, input, dim, keepdim, device)\n\n\ndef _test_amax_with_positive_dim(test_case, device):\n    input = flow.tensor(\n        np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device)\n    )\n    dim = random(0, 4).to(int).value()\n    keepdim = random_bool().value()\n    __check(test_case, input, dim, keepdim, device)\n\n\ndef _test_amax_with_multiple_axes(test_case, device):\n    input = flow.tensor(\n        np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device)\n    )\n    axes = set()\n    num_axes = random(1, 4).to(int).value()\n    for _ in range(num_axes):\n        axes.add(random(0, 4).to(int).value())\n    keepdim = random_bool().value()\n    __check(test_case, input, tuple(axes), keepdim, device)\n\n\ndef _test_amax_with_empty_dim(test_case, device):\n    input = flow.tensor(\n        np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device)\n    )\n    keepdim = random_bool().value()\n    __check(test_case, input, None, keepdim, device)\n\n\ndef _test_amax_keepdim(test_case, device):\n    input = flow.tensor(\n        np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device)\n    )\n    dim = random(-4, 4).to(int).value()\n    keepdim = True\n    __check(test_case, input, dim, keepdim, device)\n\n\ndef _test_amax_not_keepdim(test_case, device):\n    input = flow.tensor(\n        np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device)\n    )\n    dim = random(-4, 4).to(int).value()\n    keepdim = False\n    __check(test_case, input, dim, keepdim, device)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAmax(flow.unittest.TestCase):\n    def test_amax(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_amax_with_negative_dim,\n            _test_amax_with_positive_dim,\n            _test_amax_with_multiple_axes,\n            _test_amax_with_empty_dim,\n            _test_amax_keepdim,\n            _test_amax_not_keepdim,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in 
GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5)\n    def test_amax_with_random_data_single_dim(test_case):\n        device = random_device()\n        ndim = random(1, 6).to(int)\n        x = random_tensor(ndim=ndim).to(device)\n        y = torch.amax(x, dim=random(0, ndim), keepdim=random().to(bool))\n        return y\n\n    @autotest(n=5)\n    def test_amax_with_random_data_empty_dim(test_case):\n        device = random_device()\n        ndim = random(1, 6).to(int)\n        x = random_tensor(ndim=ndim).to(device)\n        y = torch.amax(x, dim=None, keepdim=random().to(bool))\n        return y\n\n    @autotest(n=5)\n    def test_amax_with_random_data_multi_dims(test_case):\n        device = random_device()\n        ndim = random(2, 6).to(int)\n        x = random_tensor(ndim=ndim).to(device)\n        dim = set()\n        for _ in range(random(1, ndim).to(int).value()):\n            dim.add(random(0, ndim).to(int).value())\n        y = torch.amax(x, dim=tuple(dim), keepdim=random().to(bool))\n        return y\n\n    @profile(torch.amax)\n    def profile_amax(test_case):\n        input1 = torch.ones(4, 4)\n        input2 = torch.ones(100, 100)\n        torch.amax(input1, 1)\n        torch.amax(input1, 1, True)\n        torch.amax(input2, 1)\n        torch.amax(input2, 1, True)\n
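\n    # Illustrative sketch (not part of the original suite): mirrors the eager\n    # _test_amax_with_negative_dim case above with a fixed dim of -1 under the\n    # autotest framework.\n    @autotest(n=5)\n    def test_amax_with_random_data_negative_dim(test_case):\n        device = random_device()\n        ndim = random(1, 6).to(int)\n        x = random_tensor(ndim=ndim).to(device)\n        y = torch.amax(x, dim=-1, keepdim=random().to(bool))\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"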
  },
  {
    "path": "python/oneflow/test/modules/test_amin.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport numpy as np\n\n\ndef __check(test_case, input, dim, keepdim, device):\n    of_out = flow.amin(input, dim=dim, keepdim=keepdim)\n    if type(dim) is tuple:\n        if len(dim) == 0:\n            dim = None\n    np_out = np.amin(input.numpy(), axis=dim, keepdims=keepdim)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, rtol=0.0001, atol=1e-05,))\n\n\ndef _test_amin_with_negative_dim(test_case, device):\n    input = flow.tensor(\n        np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device)\n    )\n    dim = random(-4, 0).to(int).value()\n    keepdim = random_bool().value()\n    __check(test_case, input, dim, keepdim, device)\n\n\ndef _test_amin_with_positive_dim(test_case, device):\n    input = flow.tensor(\n        np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device)\n    )\n    dim = random(0, 4).to(int).value()\n    keepdim = random_bool().value()\n    __check(test_case, input, dim, keepdim, device)\n\n\ndef _test_amin_with_multiple_axes(test_case, device):\n    input = flow.tensor(\n        np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device)\n    )\n    axes = set()\n    num_axes = random(1, 4).to(int).value()\n    for _ in range(num_axes):\n        axes.add(random(0, 4).to(int).value())\n    keepdim = random_bool().value()\n    __check(test_case, input, tuple(axes), keepdim, device)\n\n\ndef _test_amin_with_empty_dim(test_case, device):\n    input = flow.tensor(\n        np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device)\n    )\n    keepdim = random_bool().value()\n    __check(test_case, input, None, keepdim, device)\n\n\ndef _test_amin_keepdim(test_case, device):\n    input = flow.tensor(\n        np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device)\n    )\n    dim = random(-4, 4).to(int).value()\n    keepdim = True\n    __check(test_case, input, dim, keepdim, device)\n\n\ndef _test_amin_not_keepdim(test_case, device):\n    input = flow.tensor(\n        np.random.randn(3, 5, 6, 8), dtype=flow.float32, device=flow.device(device)\n    )\n    dim = random(-4, 4).to(int).value()\n    keepdim = False\n    __check(test_case, input, dim, keepdim, device)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAmin(flow.unittest.TestCase):\n    def test_amin(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_amin_with_negative_dim,\n            _test_amin_with_positive_dim,\n            _test_amin_with_multiple_axes,\n            _test_amin_with_empty_dim,\n            _test_amin_keepdim,\n            _test_amin_not_keepdim,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in 
GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5)\n    def test_amin_with_random_data_single_dim(test_case):\n        device = random_device()\n        ndim = random(1, 6).to(int)\n        x = random_tensor(ndim=ndim).to(device)\n        y = torch.amin(x, dim=random(0, ndim), keepdim=random().to(bool))\n        return y\n\n    @autotest(n=5)\n    def test_amin_with_random_data_empty_dim(test_case):\n        device = random_device()\n        ndim = random(1, 6).to(int)\n        x = random_tensor(ndim=ndim).to(device)\n        y = torch.amin(x, dim=None, keepdim=random().to(bool))\n        return y\n\n    @autotest(n=5)\n    def test_amin_with_random_data_multi_dims(test_case):\n        device = random_device()\n        ndim = random(2, 6).to(int)\n        x = random_tensor(ndim=ndim).to(device)\n        dim = set()\n        for _ in range(random(1, ndim).to(int).value()):\n            dim.add(random(0, ndim).to(int).value())\n        y = torch.amin(x, dim=tuple(dim), keepdim=random().to(bool))\n        return y\n\n    @profile(torch.amin)\n    def profile_amin(test_case):\n        input1 = torch.ones(4, 4)\n        input2 = torch.ones(100, 100)\n        torch.amin(input1, 1)\n        torch.amin(input1, 1, True)\n        torch.amin(input2, 1)\n        torch.amin(input2, 1, True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_arange.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_arange(test_case, device):\n    np_out = np.arange(13, dtype=np.float32)\n    of_out = flow.arange(13, device=device, dtype=flow.float32)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n    np_out = np.arange(13, dtype=np.float16)\n    of_out = flow.arange(13, device=device, dtype=flow.float16)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_arange_step_prarm(test_case, device):\n    np_out = np.arange(0, 20, 2)\n    of_out = flow.arange(0, 20, step=2, device=device)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_arange_more_params(test_case, device):\n    np_out = np.arange(0, 100, 3)\n    of_out = flow.arange(start=0, end=100, step=3, device=device)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_arange_backward(test_case, device):\n    x = flow.arange(13, dtype=flow.float32, device=device)\n    x.requires_grad = True\n    y = x.sum()\n    y.backward()\n    test_case.assertTrue(np.allclose(x.grad.numpy(), np.ones(13), 1e-05, 1e-05))\n\n    x = flow.arange(13, dtype=flow.float16, device=device)\n    x.requires_grad = True\n    y = x.sum()\n    y.backward()\n    test_case.assertTrue(np.allclose(x.grad.numpy(), np.ones(13), 1e-05, 1e-05))\n\n\ndef _test_arange_input_tensor_type(test_case, device):\n    x = flow.tensor([[1, 2], [3, 4]], dtype=flow.int64).to(device)\n    y = flow.arange(start=flow.min(x), end=flow.max(x), device=device)\n    test_case.assertTrue(np.allclose(y.numpy(), np.arange(1, 4)))\n\n    x = flow.tensor([[1, 2], [3, 4]], dtype=flow.int64).to(device)\n    y = flow.arange(\n        start=flow.min(x), end=flow.max(x), device=device, dtype=flow.float16\n    )\n    test_case.assertTrue(np.allclose(y.numpy(), np.arange(1, 4)))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestArange(flow.unittest.TestCase):\n    def test_arange(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"function_test\"] = [\n            _test_arange,\n            _test_arange_step_prarm,\n            _test_arange_more_params,\n            _test_arange_backward,\n            _test_arange_input_tensor_type,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=10, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True)\n    def test_arange_with_random_data(test_case):\n        start = random().to(int)\n        end = start + random().to(int)\n        step = random(1, end - start + 1).to(int)\n        x = torch.arange(start=start, end=end, 
step=step)\n        device = random_device()\n        x = x.to(device)\n        return x\n\n    @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True)\n    def test_arange_with_float_delta(test_case):\n        start = random().to(int)\n        end = start + random().to(int)\n        step = random(1, end - start + 1).to(float)\n        x = torch.arange(start=start, end=end, step=step)\n        device = random_device()\n        x = x.to(device)\n        return x\n\n    @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True)\n    def test_arange_input_float_scalar_tensor(test_case):\n        start = random().to(float)\n        end = start + random().to(float)\n        x = torch.arange(start=torch.tensor(start), end=torch.tensor(end))\n        device = random_device()\n        x = x.to(device)\n        return x\n\n    @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True)\n    def test_arange_input_float16_scalar_tensor(test_case):\n        start = random().to(float)\n        end = start + random().to(float)\n        start, end = torch.tensor(start).half(), torch.tensor(end).half()\n        x = torch.arange(start=start, end=end)\n        device = random_device()\n        x = x.to(device)\n        return x\n\n    def test_global_naive(test_case):\n        placement = flow.placement(\"cpu\", ranks=[0])\n        sbp = (flow.sbp.broadcast,)\n        x = flow.arange(start=0, end=10, step=1, placement=placement, sbp=sbp)\n        test_case.assertEqual(x.sbp, sbp)\n        test_case.assertEqual(x.placement, placement)\n\n    @profile(torch.arange)\n    def profile_arange(test_case):\n        torch.arange(5)\n        torch.arange(100000)\n        torch.arange(1, 4)\n        torch.arange(1, 2.5, 0.5)\n
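\n    # Illustrative sketch (not part of the original suite): exercises an\n    # explicit dtype argument on top of the float32/float16 eager checks\n    # above; assumes flow.arange accepts dtype as in those checks.\n    @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True)\n    def test_arange_with_dtype(test_case):\n        start = random().to(int)\n        end = start + random().to(int)\n        x = torch.arange(start=start, end=end, dtype=torch.float64)\n        device = random_device()\n        x = x.to(device)\n        return x\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"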
  },
  {
    "path": "python/oneflow/test/modules/test_argmax.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_argmax_axis_negative(test_case, device):\n    input = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    axis = -1\n    of_out = flow.argmax(input, dim=axis)\n    np_out = np.argmax(input.numpy(), axis=axis)\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\ndef _test_tensor_argmax(test_case, device):\n    input = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    axis = 0\n    of_out = input.argmax(dim=axis)\n    np_out = np.argmax(input.numpy(), axis=axis)\n    test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_out.shape))\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\ndef _test_argmax_axis_postive(test_case, device):\n    input = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    axis = 1\n    of_out = flow.argmax(input, dim=axis)\n    np_out = np.argmax(input.numpy(), axis=axis)\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\ndef _test_argmax_keepdims(test_case, device):\n    input = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    axis = 0\n    of_out = input.argmax(axis, True)\n    np_out = np.argmax(input.numpy(), axis=axis)\n    np_out = np.expand_dims(np_out, axis=axis)\n    test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_out.shape))\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\ndef _test_argmax_dim_equal_none(test_case, device):\n    input = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = input.argmax()\n    np_out = np.argmax(input.numpy().flatten(), axis=0)\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestArgmax(flow.unittest.TestCase):\n    def test_argmax(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_argmax_axis_negative,\n            _test_tensor_argmax,\n            _test_argmax_axis_postive,\n            _test_argmax_keepdims,\n            _test_argmax_dim_equal_none,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True)\n    def test_argmax_with_random_data(test_case):\n        device = 
random_device()\n        ndim = random(1, 6).to(int)\n        x = random_tensor(ndim=ndim).to(device)\n        y = torch.argmax(x, dim=random(0, ndim).to(int), keepdim=random().to(bool))\n        return y\n\n    @profile(torch.argmax)\n    def profile_argmax(test_case):\n        torch.argmax(torch.ones(100000))\n        torch.argmax(torch.ones(1000000))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_argmin.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_argmin_axis_negative(test_case, device):\n    input = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    axis = -1\n    of_out = flow.argmin(input, dim=axis)\n    np_out = np.argmin(input.numpy(), axis=axis)\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\ndef _test_tensor_argmin(test_case, device):\n    input = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    axis = 0\n    of_out = input.argmin(dim=axis)\n    np_out = np.argmin(input.numpy(), axis=axis)\n    test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_out.shape))\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\ndef _test_argmin_axis_postive(test_case, device):\n    input = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    axis = 1\n    of_out = flow.argmin(input, dim=axis)\n    np_out = np.argmin(input.numpy(), axis=axis)\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\ndef _test_argmin_keepdims(test_case, device):\n    input = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    axis = 0\n    of_out = input.argmin(axis, True)\n    np_out = np.argmin(input.numpy(), axis=axis)\n    np_out = np.expand_dims(np_out, axis=axis)\n    test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_out.shape))\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\ndef _test_argmin_dim_equal_none(test_case, device):\n    input = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = input.argmin()\n    np_out = np.argmin(input.numpy().flatten(), axis=0)\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestArgmin(flow.unittest.TestCase):\n    def test_argmin(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_argmin_axis_negative,\n            _test_tensor_argmin,\n            _test_argmin_axis_postive,\n            _test_argmin_keepdims,\n            _test_argmin_dim_equal_none,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True)\n    def test_argmin_with_random_data(test_case):\n        device = 
random_device()\n        ndim = random(1, 6).to(int)\n        x = random_tensor(ndim=ndim).to(device)\n        y = torch.argmin(x, dim=random(0, ndim).to(int), keepdim=random().to(bool))\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_argsort.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList, type_name_to_flow_type\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_argsort(test_case, data_shape, axis, descending, data_type, device):\n    input = flow.tensor(\n        np.random.randn(*data_shape),\n        dtype=type_name_to_flow_type[data_type],\n        device=flow.device(device),\n    )\n    np_input = -input.numpy() if descending else input.numpy()\n    if axis is not None:\n        of_out = flow.argsort(input, dim=axis, descending=descending)\n        np_out = np.argsort(np_input, axis=axis)\n    else:\n        of_out = flow.argsort(input, descending=descending)\n        np_out = np.argsort(np_input)\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\ndef _test_tensor_argsort(test_case, data_shape, axis, descending, data_type, device):\n    input = flow.tensor(\n        np.random.randn(*data_shape),\n        dtype=type_name_to_flow_type[data_type],\n        device=flow.device(device),\n    )\n    np_input = -input.numpy() if descending else input.numpy()\n    if axis is not None:\n        of_out = input.argsort(dim=axis, descending=descending)\n        np_out = np.argsort(np_input, axis=axis)\n    else:\n        of_out = input.argsort(descending=descending)\n        np_out = np.argsort(np_input)\n    test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_out.shape))\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestArgsort(flow.unittest.TestCase):\n    def test_argsort(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_argsort, _test_tensor_argsort]\n        arg_dict[\"data_shape\"] = [(2, 6, 5, 4), (3, 4, 8)]\n        arg_dict[\"axis\"] = [-1, 0, 2, None]\n        arg_dict[\"descending\"] = [True, False]\n        arg_dict[\"data_type\"] = [\"double\", \"float32\", \"int32\"]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_argsort_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = torch.argsort(\n            x, dim=random(low=-4, high=4).to(int), descending=random_bool()\n        )\n        return y\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_argsort_bool_with_random_data(test_case):\n        x = random_tensor(ndim=4).to(\"cpu\", torch.bool)\n        y = torch.argsort(\n            x, dim=random(low=-4, high=4).to(int), descending=random_bool()\n        )\n        return y\n\n    @profile(torch.argsort)\n    def profile_argsort(test_case):\n    
    torch.argsort(torch.ones(10, 10), dim=1)\n        torch.argsort(torch.ones(1000, 1000), dim=1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_argwhere.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom packaging import version\nfrom oneflow.test_utils.automated_test_util import *\nimport torch as torch_original\n\n\ndef _test_argwhere(test_case, shape, device):\n    np_input = np.random.randn(*shape)\n    input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device))\n    of_out = flow.argwhere(input)\n    np_out = np.argwhere(np_input)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_out.shape))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestArgwhere(flow.unittest.TestCase):\n    def test_argwhere(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_argwhere]\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6), (2, 3, 0, 4)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @unittest.skip(\"pytorch do not have argwhere fn/module yet!\")\n    @autotest(n=5, rtol=1e-5, atol=1e-5)\n    def test_argwhere_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(2, 5).to(int)).to(device)\n        y = torch.argwhere(x)\n        return y\n\n    has_pytorch_1_11 = version.parse(torch_original.__version__) >= version.parse(\n        \"1.11.0\"\n    )\n\n    @unittest.skipIf(\n        not has_pytorch_1_11, \"torch.argwhere only exists in PyTorch >= 1.11.0\"\n    )\n    @profile(torch.argwhere if has_pytorch_1_11 else None)\n    def profile_argwhere(test_case):\n        torch.argwhere(torch.ones(3, 3, 100, 100))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_as_strided.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nfrom random import shuffle\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAsStrided(flow.unittest.TestCase):\n    @autotest(n=10)\n    def test_flow_AsStrided(test_case):\n        device = random_device()\n        ndim = np.random.randint(3, 6)\n        dim0 = np.random.randint(2, 4)\n        dim1 = np.random.randint(2, 4)\n        dim2 = np.random.randint(2, 4)\n        dim3 = np.random.randint(2, 4)\n        dim4 = np.random.randint(2, 4)\n        if ndim == 3:\n            x = random_tensor(3, dim0, dim1, dim2)\n        elif ndim == 4:\n            x = random_tensor(4, dim0, dim1, dim2, dim3)\n        elif ndim == 5:\n            x = random_tensor(5, dim0, dim1, dim2, dim3, dim4)\n        x = x.to(device)\n        storage_offset = random(0, 3).to(int)\n        z = torch.as_strided(x, (2, 2, 3), (1, 1, 2), storage_offset)\n        return z\n\n    @autotest(n=5)\n    def test_tensor_as_strided(test_case):\n        device = random_device()\n        ndim = np.random.randint(3, 6)\n        dim0 = np.random.randint(2, 4)\n        dim1 = np.random.randint(2, 4)\n        dim2 = np.random.randint(2, 4)\n        dim3 = np.random.randint(2, 4)\n        dim4 = np.random.randint(2, 4)\n        if ndim == 3:\n            x = random_tensor(3, dim0, dim1, dim2)\n        elif ndim == 4:\n            x = random_tensor(4, dim0, dim1, dim2, dim3)\n        elif ndim == 5:\n            x = random_tensor(5, dim0, dim1, dim2, dim3, dim4)\n        x = x.to(device)\n        storage_offset = random(0, 3).to(int)\n        y = x.as_strided((2, 2, 3), (1, 1, 2), storage_offset)\n        return y\n\n    @autotest(n=10)\n    def test_flow_as_strided_tensor_method(test_case):\n        device = random_device()\n        ndim = np.random.randint(3, 6)\n        x = random_tensor(ndim, *[np.random.randint(2, 4) for _ in range(ndim)])\n        x = x.to(device)\n        storage_offset = random(0, 3).to(int)\n        z = x.as_strided((2, 2, 3), (1, 1, 2), storage_offset)\n        return z\n\n    @autotest(n=10)\n    def test_flow_as_strided_with_stride(test_case):\n        device = random_device()\n        dim0 = np.random.randint(2, 4)\n        dim1 = np.random.randint(2, 4)\n        dim2 = np.random.randint(2, 4)\n        dim3 = np.random.randint(2, 4)\n        x = random_tensor(4, dim0, dim1, dim2, dim3)\n        x = x.to(device)\n        storage_offset = random(0, 3).to(int)\n        perm = [0, 1, 2, 3]\n        shuffle(perm)\n        y = x.permute(perm)\n        z = torch.as_strided(y, (2, 2, 3), (1, 1, 2), storage_offset)\n        return z\n\n    @autotest(n=5, auto_backward=False)\n    def test_flow_as_strided_bool(test_case):\n        device = random_device()\n        ndim = np.random.randint(3, 6)\n        dim0 = np.random.randint(2, 4)\n        dim1 = 
np.random.randint(2, 4)\n        dim2 = np.random.randint(2, 4)\n        dim3 = np.random.randint(2, 4)\n        dim4 = np.random.randint(2, 4)\n        if ndim == 3:\n            x = random_tensor(3, dim0, dim1, dim2)\n        elif ndim == 4:\n            x = random_tensor(4, dim0, dim1, dim2, dim3)\n        elif ndim == 5:\n            x = random_tensor(5, dim0, dim1, dim2, dim3, dim4)\n        x = x.to(device)\n        x = x.to(torch.bool)\n        storage_offset = random(0, 3).to(int)\n        z = torch.as_strided(x, (2, 2, 3), (1, 1, 2), storage_offset)\n        return z\n\n    @autotest(n=5, auto_backward=False)\n    def test_flow_as_strided_int8(test_case):\n        device = random_device()\n        ndim = np.random.randint(3, 6)\n        dim0 = np.random.randint(2, 4)\n        dim1 = np.random.randint(2, 4)\n        dim2 = np.random.randint(2, 4)\n        dim3 = np.random.randint(2, 4)\n        dim4 = np.random.randint(2, 4)\n        if ndim == 3:\n            x = random_tensor(3, dim0, dim1, dim2)\n        elif ndim == 4:\n            x = random_tensor(4, dim0, dim1, dim2, dim3)\n        elif ndim == 5:\n            x = random_tensor(5, dim0, dim1, dim2, dim3, dim4)\n        x = x.to(device)\n        x = x.to(torch.int8)\n        storage_offset = random(0, 3).to(int)\n        z = torch.as_strided(x, (2, 2, 3), (1, 1, 2), storage_offset)\n        return z\n\n    @autotest(n=5, auto_backward=False)\n    def test_flow_as_strided_uint8(test_case):\n        device = random_device()\n        ndim = np.random.randint(3, 6)\n        dim0 = np.random.randint(2, 4)\n        dim1 = np.random.randint(2, 4)\n        dim2 = np.random.randint(2, 4)\n        dim3 = np.random.randint(2, 4)\n        dim4 = np.random.randint(2, 4)\n        if ndim == 3:\n            x = random_tensor(3, dim0, dim1, dim2)\n        elif ndim == 4:\n            x = random_tensor(4, dim0, dim1, dim2, dim3)\n        elif ndim == 5:\n            x = random_tensor(5, dim0, dim1, dim2, dim3, dim4)\n        x = x.to(device)\n        x = x.to(torch.uint8)\n        storage_offset = random(0, 3).to(int)\n        z = torch.as_strided(x, (2, 2, 3), (1, 1, 2), storage_offset)\n        return z\n\n    @profile(torch.as_strided)\n    def profile_as_strided(test_case):\n        input = torch.ones(10, 10, 128, 128)\n        torch.as_strided(input, (10, 3, 128, 128), (1, 1, 1, 1))\n        torch.as_strided(input, (10, 3, 128, 128), (1, 1, 1, 1), 1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_as_tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport random\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\n\nnumpy_dtype_to_oneflow_dtype_dict = {\n    np.int32: flow.int32,\n    np.int64: flow.int64,\n    np.int8: flow.int8,\n    np.uint8: flow.uint8,\n    np.float64: flow.float64,\n    np.float32: flow.float32,\n    np.float16: flow.float16,\n}\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test gpu cases\")\nclass TestAsTensor(flow.unittest.TestCase):\n    def test_tensor_type(test_case):\n        x = flow.randn(2, 3)\n        y = flow.as_tensor(x)\n        y[0] = 2.0\n        test_case.assertTrue(np.array_equal(x.numpy(), y.numpy()))\n        test_case.assertTrue(np.array_equal(id(x), id(y)))\n\n        x = flow.randn(2, 3)\n        x = x.to(\"cuda\")\n        y = flow.as_tensor(x)\n        y[0] = 2.0\n        test_case.assertTrue(np.array_equal(x.numpy(), y.numpy()))\n        test_case.assertTrue(np.array_equal(id(x), id(y)))\n\n        x = flow.randn(2, 3)\n        y = flow.as_tensor(x, device=flow.device(\"cuda:0\"))\n        test_case.assertTrue(id(x) != id(y))\n\n        for dtype in [\n            flow.float64,\n            flow.float16,\n            flow.int64,\n            flow.int32,\n            flow.int8,\n            flow.uint8,\n        ]:\n            x = flow.randn(2, 3)\n            y = flow.as_tensor(x, dtype=dtype)\n            test_case.assertTrue(id(x) != id(y))\n\n    def test_numpy_type(test_case):\n        for device in [flow.device(\"cpu\"), flow.device(\"cuda:0\"), None]:\n            for np_dtype in [\n                np.float64,\n                np.float32,\n                np.float16,\n                np.int64,\n                np.int32,\n                np.int8,\n                np.uint8,\n            ]:\n                for flow_dtype in [\n                    flow.float64,\n                    flow.float16,\n                    flow.int64,\n                    flow.int32,\n                    flow.int8,\n                    flow.uint8,\n                ]:\n                    np_arr = np.ones((2, 3), dtype=np_dtype)\n                    try:\n                        tensor = flow.as_tensor(np_arr, dtype=flow_dtype)\n                        if numpy_dtype_to_oneflow_dtype_dict[\n                            np_arr.dtype\n                        ] == flow_dtype and device is not flow.device(\"cuda:0\"):\n                            tensor[0][0] += 1.0\n                            test_case.assertTrue(np.array_equal(np_arr, tensor.numpy()))\n                        else:\n                            test_case.assertTrue(np.array_equal(np_arr, tensor.numpy()))\n                    except Exception as e:\n                        # Ignore cast or kernel mismatch error in test example\n                        pass\n\n    def test_other_type(test_case):\n        for device in 
[flow.device(\"cpu\"), flow.device(\"cuda:0\"), None]:\n            for np_dtype in [\n                np.float64,\n                np.float32,\n                np.float16,\n                np.int64,\n                np.int32,\n                np.int8,\n                np.uint8,\n            ]:\n                for flow_dtype in [\n                    flow.float64,\n                    flow.float16,\n                    flow.int64,\n                    flow.int32,\n                    flow.int8,\n                    flow.uint8,\n                ]:\n                    # tuple\n                    np_arr = (1.0, 2.0, 3.0)\n                    try:\n                        tensor = flow.as_tensor(np_arr, dtype=flow_dtype)\n                        test_case.assertTrue(np.array_equal(np_arr, tensor.numpy()))\n                    except Exception as e:\n                        # Ignore cast or kernel mismatch error in test example\n                        pass\n                    # tuple\n                    np_arr = [1.0, 2.0, 3.0]\n                    try:\n                        tensor = flow.as_tensor(np_arr, dtype=flow_dtype)\n                        test_case.assertTrue(np.array_equal(np_arr, tensor.numpy()))\n                    except Exception as e:\n                        # Ignore cast or kernel mismatch error in test example\n                        pass\n                    # scalar\n                    np_arr = 4.0\n                    try:\n                        tensor = flow.as_tensor(np_arr, dtype=flow_dtype)\n                        test_case.assertTrue(np.array_equal(np_arr, tensor.numpy()))\n                    except Exception as e:\n                        # Ignore cast or kernel mismatch error in test example\n                        pass\n\n    def test_numpy_dtype_bug(test_case):\n        test_case.assertEqual(flow.as_tensor([1.0]).dtype, flow.float32)\n        x = np.random.randn(10)\n        y1 = flow.as_tensor(x, dtype=flow.int64)\n        y2 = flow.as_tensor(x, dtype=flow.float64)\n        test_case.assertEqual(y1.dtype, flow.int64)\n        test_case.assertEqual(y2.dtype, flow.float64)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_asyncs_thread.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLocalThread(flow.unittest.TestCase):\n    def test_stream(test_case):\n        with flow.asyncs.thread(flow.asyncs.Thread()):\n            test_case.assertEqual(flow.ones(1)[0], 1)\n\n\n@flow.unittest.skip_unless_1n2d()\nclass TestGlobalThread(flow.unittest.TestCase):\n    def test_cpu_stream(test_case):\n        threads = [flow.asyncs.Thread() for i in range(7)]\n        iter_and_threads = [(i, threads[i % 7]) for i in range(30)]\n        for i, thread in iter_and_threads:\n            with flow.asyncs.thread(thread):\n                placement = flow.placement(\"cpu\", [0, 1])\n                tensor = flow.ones(2, placement=placement, sbp=flow.sbp.split(0))\n                test_case.assertEqual(tensor[0], 1)\n                test_case.assertEqual(tensor[1], 1)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_cuda_stream(test_case):\n        threads = [flow.asyncs.Thread() for i in range(7)]\n        iter_and_threads = [(i, threads[i % 7]) for i in range(200)]\n        tensors = []\n        dim = 0\n        for i, thread in iter_and_threads:\n            dim += 1\n            with flow.asyncs.thread(thread):\n                placement = flow.placement(\"cuda\", [0, 1])\n                ones = flow.ones(2 * dim, placement=placement, sbp=flow.sbp.split(0))\n                tensors.append(ones.to_global(sbp=flow.sbp.broadcast) + i)\n        for i, tensor in enumerate(tensors):\n            test_case.assertEqual(tensor[0], 1 + i)\n            test_case.assertEqual(tensor[int(tensor.shape[0] / 2)], 1 + i)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_atleast.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAtLeast(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_atleast_1d_with_list_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = random_tensor(ndim=2).to(device)\n        out = torch.atleast_1d([x, y])\n        return out\n\n    @autotest(n=5)\n    def test_atleast_1d_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(low=0, high=3).to(int)).to(device)\n        out = torch.atleast_1d(x)\n        return out\n\n    @autotest(n=5)\n    def test_atleast_2d_with_list_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = random_tensor(ndim=1).to(device)\n        z = random_tensor(ndim=3).to(device)\n        out = torch.atleast_2d([x, y, z])\n        return out\n\n    @autotest(n=5)\n    def test_atleast_2d_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(low=0, high=4).to(int)).to(device)\n        out = torch.atleast_2d(x)\n        return out\n\n    @autotest(n=5)\n    def test_atleast_3d_with_list_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = random_tensor(ndim=1).to(device)\n        z = random_tensor(ndim=2).to(device)\n        p = random_tensor(ndim=4).to(device)\n        out = torch.atleast_3d([x, y, z, p])\n        return out\n\n    @autotest(n=5)\n    def test_atleast_3d_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(low=0, high=5).to(int)).to(device)\n        out = torch.atleast_3d(x)\n        return out\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_auto_to_global.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport os\n\nfrom oneflow.test_utils.automated_test_util.torch_flow_dual_object import globaltest\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_auto_to_global(test_case, device):\n    os.environ[\"ONEFLOW_ENABLE_GLOBAL_INPUTS_WITH_INCONSISTENT_PLACEMENT\"] = \"true\"\n    x = flow.ones(\n        (2, 2),\n        sbp=[flow.sbp.broadcast, flow.sbp.broadcast],\n        placement=flow.placement(device, ranks=[[0], [1]]),\n    )\n    y = flow.zeros(\n        (2, 2),\n        sbp=[flow.sbp.broadcast, flow.sbp.broadcast],\n        placement=flow.placement(device, ranks=[[2], [3]]),\n    )\n    z = x + y\n    test_case.assertTrue(np.array_equal(x.numpy(), z.numpy()))\n    test_case.assertEqual(y.placement, z.placement)\n    os.environ[\"ONEFLOW_ENABLE_GLOBAL_INPUTS_WITH_INCONSISTENT_PLACEMENT\"] = \"false\"\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestAutoToGlobal(flow.unittest.TestCase):\n    @globaltest\n    @flow.unittest.skip_unless_1n4d()\n    def test_auto_to_global(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_auto_to_global(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_autograd.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport torch as original_torch\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_autograd_backward(test_case, shape, device):\n    np_input = np.random.rand(*shape)\n    of_input = flow.tensor(\n        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_out = of_input ** 2\n    of_out_sum = of_out.sum()\n    of_out_sum.backward()\n    test_case.assertTrue(\n        np.allclose(of_input.grad.numpy(), np_input * 2, 0.0001, 0.0001)\n    )\n    of_input = flow.tensor(\n        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_out = of_input ** 2\n    of_out_sum = of_out.sum()\n    of_out_sum.backward(flow.ones_like(of_out_sum) * 3)\n    test_case.assertTrue(\n        np.allclose(of_input.grad.numpy(), np_input * 6, 0.0001, 0.0001)\n    )\n    of_input = flow.tensor(\n        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_out = of_input ** 2\n    of_out_sum = of_out.sum()\n    of_out_sum.backward(retain_graph=True)\n    of_out_sum.backward(retain_graph=True)\n    test_case.assertTrue(\n        np.allclose(of_input.grad.numpy(), np_input * 4, 0.0001, 0.0001)\n    )\n\n\ndef _test_autograd_grad(test_case, shape, device):\n    np_input = np.random.rand(*shape)\n    of_input = flow.tensor(\n        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_out = of_input ** 2\n    of_out_sum = of_out.sum()\n    grad = flow.autograd.grad(of_out_sum, of_input)[0]\n    test_case.assertTrue(of_input.grad is None)\n    test_case.assertTrue(np.allclose(grad.numpy(), np_input * 2, 0.0001, 0.0001))\n    of_input = flow.tensor(\n        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_out = of_input ** 2\n    of_out_sum = of_out.sum()\n    grad = flow.autograd.grad(of_out_sum, of_input, flow.ones_like(of_out_sum) * 3)[0]\n    test_case.assertTrue(np.allclose(grad.numpy(), np_input * 6, 0.0001, 0.0001))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAutograd(flow.unittest.TestCase):\n    def test_autograd_interface(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"case\"] = [_test_autograd_backward, _test_autograd_grad]\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4, 5)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=10, auto_backward=True, rtol=1e-3, atol=1e-3, check_graph=True)\n    def test_accumulate_grad(test_case):\n        device = random_device()\n        ndim = random(1, 4).to(int)\n        x = random_tensor(ndim=ndim, requires_grad=True).to(device)\n        y 
= random_tensor(ndim=ndim, requires_grad=True).to(device)\n        return x / (x + y)\n\n    @autotest(n=10, auto_backward=True, rtol=1e-3, atol=1e-3, check_graph=True)\n    def test_0dim_accumulate_grad(test_case):\n        device = random_device()\n        ndim = 0\n        x = random_tensor(ndim=ndim, requires_grad=True).to(device)\n        y = random_tensor(ndim=ndim, requires_grad=True).to(device)\n        return x / (x + y)\n\n    @autotest(n=10, auto_backward=True, rtol=1e-3, atol=1e-3, check_graph=True)\n    def test_scalar_leaf_tensor_backward(test_case):\n        device = random_device()\n        ndim = 0\n        x = random_tensor(ndim=ndim, requires_grad=True).to(device)\n        return x\n\n    @autotest(n=1, auto_backward=False, check_graph=False)\n    def test_out_grad_with_different_dtype(test_case):\n        x = random_tensor(ndim=2, requires_grad=True)\n        y = x.sum()\n        y.backward(torch.tensor(False))\n        return x.grad\n\n    @autotest(n=10, auto_backward=False, check_graph=False)\n    def test_grad_grad(test_case):\n        device = random_device()\n        ndim = random(1, 4).to(int)\n        x = random_tensor(ndim=ndim, requires_grad=True).to(device)\n        y = x * x * x\n        x_grad = torch.autograd.grad(\n            outputs=y,\n            inputs=x,\n            grad_outputs=torch.ones_like(y),\n            create_graph=True,\n            retain_graph=True,\n        )[0]\n        x_grad_grad = torch.autograd.grad(\n            outputs=x_grad, inputs=x, grad_outputs=torch.ones_like(x_grad)\n        )[0]\n        return x_grad_grad\n\n    @autotest(n=10, auto_backward=False, rtol=1e-3, atol=1e-3, check_graph=False)\n    def test_autograd_multiple_times(test_case):\n        device = random_device()\n        ndim = random(1, 4).to(int).value()\n        dims = [random(0, 10).to(int) for _ in range(ndim)]\n        x = random_tensor(ndim, *dims, requires_grad=True)\n        x1 = x.to(device)\n        y = random_tensor(ndim, *dims, requires_grad=True)\n        y1 = y.to(device)\n        z = x1 + y1\n\n        for _ in range(10):\n            z.sum().backward()\n        return (x.grad, y.grad)\n\n    def test_autograd_set_acc_grad_and_backward(test_case):\n        for _ in range(5):\n            ndim = 2\n            dims = [random(1, 5).to(int).value() for _ in range(ndim)]\n            x = torch.randn(*dims).requires_grad_()\n            np_arr = np.random.rand(*dims)\n            init_grad = torch.tensor(np_arr).to(x.dtype)\n            x.pytorch.grad = init_grad.pytorch\n            x.oneflow.grad = init_grad.oneflow\n\n            x.sum().backward()\n            test_case.assertTrue(\n                np.allclose(\n                    x.grad.oneflow.numpy(), x.grad.pytorch.cpu().detach().numpy()\n                )\n            )\n\n    @autotest(n=1, check_graph=False)\n    def test_requires_grad_tensor_inplace_and_backward(test_case):\n        random_shape = [random(1, 10).to(int) for _ in range(4)]\n        x = random_tensor(4, *random_shape, requires_grad=False)\n        y = random_tensor(4, *random_shape, requires_grad=True)\n        x += y\n        return x\n\n    @autotest(n=1, check_graph=False)\n    def test_retain_grad_for_leaf_tensor(test_case):\n        random_shape = [random(1, 10).to(int) for _ in range(4)]\n        x = random_tensor(4, *random_shape, requires_grad=True)\n        y = x * 2\n        x.retain_grad()\n        return y\n\n    @autotest(n=1, auto_backward=False, check_graph=False)\n    def 
test_run_backward_and_grad_for_same_tensor(test_case):\n        random_shape = [random(1, 10).to(int) for _ in range(4)]\n        x = random_tensor(4, *random_shape, requires_grad=True)\n        y = x ** 2\n        y.sum().backward()\n        test_case.assertTrue(\n            np.allclose(x.grad.oneflow.numpy(), x.grad.pytorch.numpy())\n        )\n\n        y = x ** 2\n        x_grad = torch.autograd.grad(y.sum(), x)[0]\n        test_case.assertTrue(\n            np.allclose(x_grad.oneflow.numpy(), x_grad.pytorch.numpy())\n        )\n        test_case.assertTrue(\n            np.allclose(x.grad.oneflow.numpy(), x_grad.oneflow.numpy())\n        )\n\n    @autotest(n=1, auto_backward=False, check_graph=False)\n    def test_no_grad_domain_call_backward(test_case):\n        random_shape = [random(1, 10).to(int).value() for _ in range(4)]\n        with flow.no_grad():\n            x = flow.rand(*random_shape).requires_grad_()\n            with flow.enable_grad():\n                y = x * 2\n            flow.autograd.backward(y, flow.ones_like(y))\n        test_case.assertTrue(np.array_equal(x.grad.numpy(), np.full(random_shape, 2.0)))\n\n    @autotest(n=1, auto_backward=False, check_graph=False)\n    def test_acc_grad_inplace_update(test_case):\n        random_shape = [random(1, 5).to(int).value() for _ in range(4)]\n        x = flow.rand(*random_shape).requires_grad_()\n        y = flow.rand(*random_shape).requires_grad_()\n\n        z = x / (x + y)\n        z.sum().backward()\n        id_x_grad = id(x.grad)\n        id_y_grad = id(y.grad)\n\n        z = x / (x + y)\n        z.sum().backward()\n        test_case.assertEqual(id_x_grad, id(x.grad))\n        test_case.assertEqual(id_y_grad, id(y.grad))\n\n    def test_autograd_grad_allow_unused(test_case):\n        shape = [2, 4]\n        device = random_device()\n        x = random_tensor(len(shape), *shape, requires_grad=True).to(device)\n        z = random_tensor(len(shape), *shape, requires_grad=True).to(device)\n        y = x * x\n\n        np_arr = np.random.rand(*y.oneflow.shape)\n        init_grad = torch.tensor(np_arr).requires_grad_().to(device)\n        dx_and_dz = torch.autograd.grad(\n            y,\n            [x, z],\n            init_grad,\n            retain_graph=True,\n            create_graph=True,\n            allow_unused=True,\n        )\n        test_case.assertTrue(\n            np.allclose(\n                dx_and_dz[0].oneflow.detach().numpy(),\n                dx_and_dz[0].pytorch.detach().cpu().numpy(),\n            )\n        )\n        test_case.assertTrue(\n            dx_and_dz[1].oneflow is None and dx_and_dz[1].pytorch is None\n        )\n\n        np_arr = np.random.rand(*y.oneflow.shape)\n        init_grad_grad = torch.tensor(np_arr).requires_grad_().to(device)\n        ddx = torch.autograd.grad(\n            dx_and_dz[0],\n            x,\n            init_grad_grad,\n            retain_graph=True,\n            create_graph=True,\n            allow_unused=True,\n        )[0]\n        test_case.assertTrue(\n            np.allclose(\n                ddx.oneflow.detach().numpy(), ddx.pytorch.detach().cpu().numpy(),\n            )\n        )\n\n        np_arr = np.random.rand(*y.oneflow.shape)\n        init_grad_grad_grad = torch.tensor(np_arr).requires_grad_().to(device)\n        dddx = torch.autograd.grad(\n            ddx,\n            x,\n            init_grad_grad_grad,\n            retain_graph=True,\n            create_graph=True,\n            
allow_unused=True,\n        )[0]\n        test_case.assertTrue(dddx.oneflow is None and dddx.pytorch is None)\n\n    def test_autograd_is_grads_batched(test_case):\n        x = flow.randn(2, 2, requires_grad=True)\n\n        out = x.clone()  # Size([2, 2])\n        batched_grad = flow.arange(3).expand(2, 2, 3).transpose(0, 2)  # Size([3, 2, 2])\n        (grad,) = flow.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)\n        test_case.assertTrue(\n            np.array_equal(\n                grad.cpu().detach().numpy(),\n                flow.arange(3)\n                .expand(2, 2, 3)\n                .transpose(0, 2)\n                .to(dtype=grad.dtype)\n                .numpy(),\n            )\n        )\n\n        # Detect shape mismatch\n        grad_out = flow.ones(2, 2)\n        with test_case.assertRaisesRegex(\n            RuntimeError, \"If `is_grads_batched=True`, we interpret the first\"\n        ):\n            flow.autograd.grad(\n                outputs=out,\n                grad_outputs=(grad_out,),\n                inputs=(x,),\n                is_grads_batched=True,\n            )\n\n        # TODO: ReduceSum backward not support broadcast grad with shape (3, ) to (3, 2, 2)\n        #  # Scalar outputs\n        #  out = x.sum()  # Size([])\n        #  batched_grad = flow.arange(3)  # Size([3])\n        #  (grad,) = flow.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)\n        #  test_case.assertTrue(\n        #      np.array_equal(\n        #          grad.cpu().detach().numpy(),\n        #          flow.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype).numpy(),\n        #      )\n        #  )\n\n        # We consider scalar and sized-1 to be a mismatch. This is consistent with current non-batched behavior.\n        grad_out = flow.ones(2).unsqueeze(1)\n        with test_case.assertRaisesRegex(\n            RuntimeError, \"If `is_grads_batched=True`, we interpret the first\"\n        ):\n            flow.autograd.grad(\n                outputs=out,\n                grad_outputs=(grad_out,),\n                inputs=(x,),\n                is_grads_batched=True,\n            )\n\n    def test_autograd_grad_none_list(test_case):\n        x = flow.randn(10, 10, requires_grad=True)\n        y = flow.randn(10, 10, requires_grad=True)\n        merge = flow.cat([x, y], dim=0)\n        s_x, s_y = flow.split(merge, 10, dim=0)\n        s_x_sum = s_x.sum()\n        s_y_sum = s_y.sum()\n\n        (grad_x, grad_y) = flow.autograd.grad((s_x_sum, s_y_sum), (x, y), (None, None))\n        test_case.assertTrue(\n            np.array_equal(grad_x.numpy(), np.ones(x.shape).astype(np.float32),)\n        )\n        test_case.assertTrue(\n            np.array_equal(grad_y.numpy(), np.ones(y.shape).astype(np.float32),)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_autograd_function.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport re\n\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow import autograd\n\n\nclass TestAutogradFunction(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_simple_input(test_case):\n        class MyReLU(autograd.Function):\n            @staticmethod\n            def forward(ctx, x):\n                y = x.clamp(min=0.0, max=None)\n                ctx.save_for_backward(x)\n                return y\n\n            @staticmethod\n            def backward(ctx, y_grad):\n                x_grad = y_grad.clone()\n                (x,) = ctx.saved_tensors\n                x_grad[x < 0] = 0\n                return x_grad\n\n        np_arr = np.random.randn(4, 5)\n        a = flow.tensor(np_arr).requires_grad_()\n        # forward\n        b = MyReLU.apply(a)\n        test_case.assertTrue(np.allclose(b.numpy(), np_arr.clip(min=0.0)))\n        # backward\n        b.sum().backward()\n        np_grad = np.ones((4, 5))\n        np_grad[np_arr < 0] = 0.0\n        test_case.assertTrue(np.allclose(a.grad.numpy(), np_grad))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_multi_input(test_case):\n        class MyMatMul(autograd.Function):\n            @staticmethod\n            def forward(ctx, x, y):\n                z = x * y\n                ctx.save_for_backward(x, y)\n                return z\n\n            @staticmethod\n            def backward(ctx, z_grad):\n                x, y = ctx.saved_tensors\n                x_grad = y * z_grad\n                y_grad = x * z_grad\n                return x_grad, y_grad\n\n        np_arr0 = np.random.randn(4, 5)\n        np_arr1 = np.random.randn(4, 5)\n        a = flow.tensor(np_arr0).requires_grad_()\n        b = flow.tensor(np_arr1).requires_grad_()\n        # forward\n        c = MyMatMul().apply(a, b)\n        test_case.assertTrue(np.allclose(c.numpy(), np_arr0 * np_arr1))\n        # backward\n        c.sum().backward()\n        test_case.assertTrue(np.allclose(a.grad.numpy(), np_arr1))\n        test_case.assertTrue(np.allclose(b.grad.numpy(), np_arr0))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_non_differentiable_interface(test_case):\n        class MyModule(autograd.Function):\n            @staticmethod\n            def forward(ctx, x, y):\n                mul_res = x * y\n                add_res = x + y\n                ctx.save_for_backward(x, y)\n                ctx.mark_non_differentiable(add_res)\n                return mul_res, add_res\n\n            @staticmethod\n            def backward(ctx, mul_grad, add_grad=None):\n                x, y = ctx.saved_tensors\n                x_grad = y * mul_grad\n                y_grad = x * mul_grad\n                return x_grad, y_grad\n\n        np_arr0 = np.random.randn(4, 5)\n        np_arr1 = np.random.randn(4, 5)\n        a = flow.tensor(np_arr0).requires_grad_()\n        b = 
flow.tensor(np_arr1).requires_grad_()\n        # forward\n        c, d = MyModule().apply(a, b)\n        test_case.assertTrue(np.allclose(c.numpy(), np_arr0 * np_arr1))\n        test_case.assertFalse(d.requires_grad)\n        test_case.assertTrue(d.grad_fn is None)\n        # backward\n        c.sum().backward()\n        test_case.assertTrue(np.allclose(a.grad.numpy(), np_arr1))\n        test_case.assertTrue(np.allclose(b.grad.numpy(), np_arr0))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_partial_inputs_requires_grad(test_case):\n        class MyModule(autograd.Function):\n            @staticmethod\n            def forward(ctx, x, y, z):\n                return x + y + z\n\n            @staticmethod\n            def backward(ctx, out_grad):\n                return None, out_grad, None\n\n        x = flow.randn(4, 5)\n        y = flow.randn(4, 5).requires_grad_()\n        z = flow.randn(4, 5)\n        # forward\n        res = MyModule.apply(x, y, z)\n        test_case.assertTrue(\n            np.allclose(res.numpy(), x.numpy() + y.numpy() + z.numpy())\n        )\n        # backward\n        res.sum().backward()\n        test_case.assertIsNone(x.grad)\n        test_case.assertTrue(np.allclose(y.grad.numpy(), np.ones((4, 5))))\n        test_case.assertIsNone(z.grad)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_dynamic_attr_for_ctx(test_case):\n        class MyModule(autograd.Function):\n            @staticmethod\n            def forward(ctx, x):\n                ctx.scale = 2.0\n                return x * ctx.scale\n\n            @staticmethod\n            def backward(ctx, out_grad):\n                return out_grad * ctx.scale\n\n        x = flow.randn(4, 5).requires_grad_()\n        # forward\n        res = MyModule.apply(x)\n        test_case.assertTrue(np.allclose(res.numpy(), x.numpy() * 2.0))\n        # backward\n        res.sum().backward()\n        test_case.assertTrue(np.allclose(x.grad.numpy(), np.ones((4, 5)) * 2.0))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_backward_error_message(test_case):\n        class MyModule(autograd.Function):\n            @staticmethod\n            def forward(ctx, x, y, z):\n                return x + y + z\n\n            @staticmethod\n            def backward(ctx, out_grad):\n                return None, out_grad\n\n        x = flow.randn(4, 5)\n        y = flow.randn(4, 5).requires_grad_()\n        z = flow.randn(4, 5)\n        res = MyModule.apply(x, y, z)\n        with test_case.assertRaises(Exception) as exp:\n            res.sum().backward()\n        test_case.assertIsNotNone(\n            re.search(\n                r\"RuntimeError: function MyModule returned an incorrect number of gradients \\(expected \\d, got \\d\\)\",\n                str(exp.exception),\n            )\n        )\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_graph_test_multi_input(test_case):\n        class MyMul(autograd.Function):\n            @staticmethod\n            def forward(ctx, x, y):\n                z = x * y\n                ctx.save_for_backward(x, y)\n                return z\n\n            @staticmethod\n            def backward(ctx, z_grad):\n                x, y = ctx.saved_tensors\n                x_grad = 2 * y * z_grad\n                y_grad = 3 * x * z_grad\n                return x_grad, y_grad\n\n        class MyAdd(autograd.Function):\n            @staticmethod\n            def forward(ctx, x, y):\n                return 2 * x + y\n\n            @staticmethod\n            def backward(ctx, z_grad):\n       
         x_grad = z_grad\n                y_grad = 2 * z_grad\n                return x_grad, y_grad\n\n        model = flow.nn.Linear(5, 4, bias=False)\n        model.train()\n\n        class MyGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.model = model\n                optimizer = flow.optim.SGD(self.model.parameters())\n                self.add_optimizer(optimizer)\n\n            def build(self, x, y):\n                x.retain_grad()\n                y.retain_grad()\n                self.model.weight.retain_grad()\n                z = MyMul().apply(x, y)\n                z = MyAdd().apply(z, self.model.weight)\n                z.sum().backward()\n                return z, x.grad, y.grad, self.model.weight.grad\n\n        np_arr0 = np.random.randn(4, 5).astype(np.float32)\n        np_arr1 = np.random.randn(4, 5).astype(np.float32)\n        np_arr2 = np.random.randn(4, 5).astype(np.float32)\n        a = flow.tensor(np_arr0).requires_grad_()\n        b = flow.tensor(np_arr1).requires_grad_()\n        model.weight.copy_(np_arr2)\n\n        c, a_grad, b_grad, w_grad = MyGraph()(a, b)\n        test_case.assertTrue(np.allclose(c.numpy(), 2 * np_arr0 * np_arr1 + np_arr2))\n        test_case.assertTrue(np.allclose(a_grad.numpy(), 2 * np_arr1))\n        test_case.assertTrue(np.allclose(b_grad.numpy(), 3 * np_arr0))\n        test_case.assertTrue(np.allclose(w_grad.numpy(), 2 * np.ones_like(np_arr2)))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_autograd_function_memory(test_case):\n        global_ctx = None\n\n        class MyModule(autograd.Function):\n            @staticmethod\n            def forward(ctx, x):\n                z = x.clone()\n                ctx.save_for_backward(z)\n                nonlocal global_ctx\n                global_ctx = ctx\n                return z\n\n            @staticmethod\n            def backward(ctx, out_grad):\n                (x,) = ctx.saved_tensors\n                return x\n\n        x = flow.randn(5, 5).requires_grad_()\n        res = MyModule.apply(x)\n        test_case.assertTrue(global_ctx._is_data_valid())\n        res.sum().backward()\n\n        # ensure that global_ctx is released\n        test_case.assertFalse(global_ctx._is_data_valid())\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_autograd_mode.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestAutogradMode(oneflow.unittest.TestCase):\n    def test_grad_mode(test_case):\n        test_case.assertTrue(flow.is_grad_enabled())\n\n    def test_inference_mode(test_case):\n        with flow.inference_mode(True):\n            test_case.assertFalse(flow.is_grad_enabled())\n        test_case.assertTrue(flow.is_grad_enabled())\n\n        @flow.inference_mode(True)\n        def func():\n            test_case.assertFalse(flow.is_grad_enabled())\n\n        func()\n        test_case.assertTrue(flow.is_grad_enabled())\n\n        with flow.inference_mode(False):\n            test_case.assertTrue(flow.is_grad_enabled())\n        test_case.assertTrue(flow.is_grad_enabled())\n\n        @flow.inference_mode(False)\n        def func():\n            test_case.assertTrue(flow.is_grad_enabled())\n\n        func()\n        test_case.assertTrue(flow.is_grad_enabled())\n\n    def test_enable_grad(test_case):\n        with flow.enable_grad():\n            test_case.assertTrue(flow.is_grad_enabled())\n        test_case.assertTrue(flow.is_grad_enabled())\n\n        @flow.enable_grad()\n        def func():\n            test_case.assertTrue(flow.is_grad_enabled())\n\n        func()\n        test_case.assertTrue(flow.is_grad_enabled())\n\n    def test_no_grad(test_case):\n        with flow.no_grad():\n            test_case.assertFalse(flow.is_grad_enabled())\n        test_case.assertTrue(flow.is_grad_enabled())\n\n        @flow.no_grad()\n        def func():\n            test_case.assertFalse(flow.is_grad_enabled())\n\n        func()\n        test_case.assertTrue(flow.is_grad_enabled())\n\n    def test_set_grad_enabled(test_case):\n        def assert_grad_mode(mode):\n            if mode:\n                test_case.assertTrue(flow.is_grad_enabled())\n            else:\n                test_case.assertFalse(flow.is_grad_enabled())\n\n        def get_decorater_func_with_mode(mode):\n            @flow.set_grad_enabled(mode)\n            def func():\n                assert_grad_mode(mode)\n\n            return func\n\n        def get_decorater_context_func_with_mode(dec_mode, ctx_mode):\n            @flow.set_grad_enabled(dec_mode)\n            def func():\n                assert_grad_mode(dec_mode)\n                with flow.set_grad_enabled(ctx_mode):\n                    assert_grad_mode(ctx_mode)\n                assert_grad_mode(dec_mode)\n\n            return func\n\n        flow.set_grad_enabled(False)\n        assert_grad_mode(False)\n\n        with flow.set_grad_enabled(True):\n            assert_grad_mode(True)\n            flow.set_grad_enabled(False)\n            assert_grad_mode(False)\n            func = get_decorater_func_with_mode(True)\n            func()\n        assert_grad_mode(False)\n\n        flow.set_grad_enabled(True)\n        assert_grad_mode(True)\n\n        with 
flow.set_grad_enabled(False):\n            assert_grad_mode(False)\n            flow.set_grad_enabled(True)\n            assert_grad_mode(True)\n            func = get_decorator_func_with_mode(False)\n            func()\n        assert_grad_mode(True)\n\n        get_decorator_context_func_with_mode(True, True)()\n        get_decorator_context_func_with_mode(True, False)()\n        get_decorator_context_func_with_mode(False, True)()\n        get_decorator_context_func_with_mode(False, False)()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_avgpool.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nfrom oneflow.test_utils.automated_test_util.generators import constant, random_bool\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAvgPoolingModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_avgpool1d_with_random_data(test_case):\n        m = torch.nn.AvgPool1d(\n            kernel_size=random(4, 6),\n            stride=random(1, 3) | nothing(),\n            padding=random(1, 3) | nothing(),\n            ceil_mode=random(),\n            count_include_pad=random(),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=3, dim2=random(20, 22)).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_avgpool2d_with_random_data(test_case):\n        m = torch.nn.AvgPool2d(\n            kernel_size=random(4, 6),\n            stride=random(1, 3) | nothing(),\n            padding=random(1, 3) | nothing(),\n            ceil_mode=random(),\n            count_include_pad=random(),\n            divisor_override=random().to(int),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=4, dim2=random(20, 22), dim3=random(20, 22)).to(device)\n        y = m(x)\n        return y\n\n    # TODO:(zhaoluyang) this test case has probability to fail in backward\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @autotest(n=5, rtol=0.001, atol=0.001, auto_backward=False)\n    def test_avgpool2d_with_half_data(test_case):\n        m = torch.nn.AvgPool2d(\n            kernel_size=random(4, 6),\n            stride=random(1, 3) | nothing(),\n            padding=random(1, 3) | nothing(),\n            ceil_mode=random(),\n            count_include_pad=random(),\n            divisor_override=random().to(int),\n        )\n        m.train(random())\n        device = gpu_device()\n        m.to(device)\n        x = (\n            random_tensor(\n                ndim=4, dim2=random(20, 22), dim3=random(20, 22), requires_grad=False\n            )\n            .to(device)\n            .to(torch.float16)\n        )\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_avgpool3d_with_random_data(test_case):\n        m = torch.nn.AvgPool3d(\n            kernel_size=random(4, 6),\n            stride=random(1, 3) | nothing(),\n            padding=random(1, 3) | nothing(),\n            ceil_mode=random(),\n            count_include_pad=random(),\n            divisor_override=random().to(int),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(\n            ndim=5, dim2=random(20, 22), dim3=random(20, 22), dim4=random(20, 22)\n        ).to(device)\n        y = m(x)\n        return 
y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAvgPoolingFunctional(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_avgpool1d_functional(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3, dim2=random(20, 22)).to(device)\n        y = torch.nn.functional.avg_pool1d(\n            x,\n            kernel_size=random(1, 6).to(int),\n            stride=random(1, 3).to(int) | nothing(),\n            padding=random(1, 3).to(int),\n            ceil_mode=random_bool(),\n            count_include_pad=random_bool(),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_avgpool2d_functional(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dim2=random(20, 22), dim3=random(20, 22)).to(device)\n        y = torch.nn.functional.avg_pool2d(\n            x,\n            kernel_size=random(1, 6).to(int),\n            stride=random(1, 3).to(int) | nothing(),\n            padding=random(1, 3).to(int),\n            ceil_mode=random_bool(),\n            count_include_pad=random_bool(),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_avgpool3d_functional(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=5, dim2=random(20, 22), dim3=random(20, 22), dim4=random(20, 22)\n        ).to(device)\n        y = torch.nn.functional.avg_pool3d(\n            x,\n            kernel_size=random(1, 6).to(int),\n            stride=random(1, 3).to(int) | nothing(),\n            padding=random(1, 3).to(int),\n            ceil_mode=random_bool(),\n            count_include_pad=random_bool(),\n        )\n        return y\n\n    @profile(torch.nn.functional.avg_pool2d)\n    def profile_avgpool2d(test_case):\n        torch.nn.functional.avg_pool2d(\n            torch.ones(1, 128, 28, 28), kernel_size=3, padding=1\n        )\n        torch.nn.functional.avg_pool2d(\n            torch.ones(1, 128, 28, 28), kernel_size=3, stride=2, padding=1\n        )\n        torch.nn.functional.avg_pool2d(\n            torch.ones(16, 128, 28, 28), kernel_size=3, padding=1\n        )\n        torch.nn.functional.avg_pool2d(\n            torch.ones(16, 128, 28, 28), kernel_size=3, stride=2, padding=1\n        )\n        torch.nn.functional.avg_pool2d(\n            torch.ones(16, 128, 28, 28),\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            ceil_mode=True,\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_baddbmm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBaddBmmModule(flow.unittest.TestCase):\n    @autotest(n=5, rtol=1e-4, atol=1e-3)\n    def test_baddbmm_with_torch(test_case):\n        device = random_device()\n        input = random_tensor(ndim=3, dim0=2, dim1=4, dim2=4).to(device)\n        batch1 = random_tensor(ndim=3, dim0=2, dim1=4, dim2=3).to(device)\n        batch2 = random_tensor(ndim=3, dim0=2, dim1=3, dim2=4).to(device)\n        y = torch.baddbmm(input, batch1, batch2, beta=2.0, alpha=1.2)\n        return y\n\n    @autotest(n=5, rtol=1e-4, atol=1e-3)\n    def test_baddbmm_in_sd2_with_torch(test_case):\n        device = random_device()\n        input = random_tensor(ndim=3, dim0=2, dim1=2, dim2=2, requires_grad=False).to(\n            device\n        )\n        batch1 = random_tensor(ndim=3, dim0=2, dim1=2, dim2=2).to(device)\n        batch2 = random_tensor(ndim=3, dim0=2, dim1=2, dim2=2).to(device)\n        y = torch.baddbmm(input, batch1, batch2, beta=0.0, alpha=1.2)\n        return y\n\n    @autotest(n=5, rtol=1e-4, atol=1e-3)\n    def test_baddbmm_no_attr_with_torch(test_case):\n        device = random_device()\n        input = random_tensor(ndim=3, dim0=2, dim1=4, dim2=4).to(device)\n        batch1 = random_tensor(ndim=3, dim0=2, dim1=4, dim2=3).to(device)\n        batch2 = random_tensor(ndim=3, dim0=2, dim1=3, dim2=4).to(device)\n        y = torch.baddbmm(input, batch1, batch2)\n        return y\n\n    @autotest(n=5, rtol=1e-4, atol=1e-3)\n    def test_baddbmm_broadcast_with_torch(test_case):\n        device = random_device()\n        input = random_tensor(ndim=1, dim0=4).to(device)\n        batch1 = random_tensor(ndim=3, dim0=2, dim1=4, dim2=3).to(device)\n        batch2 = random_tensor(ndim=3, dim0=2, dim1=3, dim2=4).to(device)\n        y = torch.baddbmm(input, batch1, batch2, beta=-1.98, alpha=1.34)\n        return y\n\n    @profile(torch.baddbmm)\n    def profile_baddbmm(test_case):\n        input = torch.ones(10, 100, 100)\n        batch1 = torch.ones(10, 100, 100)\n        batch2 = torch.ones(10, 100, 100)\n        torch.bmm(input, batch1, batch2, beta=-1.98, alpha=1.34)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_batch_gather.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport os\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_batch_gather(test_case, shape, device):\n    # for example: shape = (3, 2, 2)\n    x = np.random.randn(*shape)\n    x_tensor = flow.Tensor(x).to(device)\n    x_tensor.requires_grad = True\n    batchsize = x.shape[0]\n    init_index = np.array(\n        [np.random.randint(batchsize) for i in range(batchsize)]\n    ).astype(np.int64)\n\n    batch_gather_index = flow.tensor(init_index).to(device)\n    batch_gather_out = flow.batch_gather(x_tensor, batch_gather_index)\n\n    x_tensor_gather = flow.Tensor(x).to(device)\n    x_tensor_gather.requires_grad = True\n    reshaped_shape = [batchsize]  # reshaped_shape = [3]\n    for i in range(len(x.shape) - 1):\n        reshaped_shape.append(1)  # reshaped_shape = [3] -> [3, 1, 1]\n\n    gather_index = np.reshape(init_index, reshaped_shape)\n    gather_index = np.broadcast_to(gather_index, shape).astype(\n        np.int64\n    )  # [3, 1, 1] -> [3, 2, 2]\n    gather_index = flow.tensor(gather_index).to(device)\n    gather_out = flow.gather(x_tensor_gather, 0, gather_index)\n    total_out = batch_gather_out.sum() + gather_out.sum()\n    total_out.backward()\n\n    test_case.assertTrue(\n        np.allclose(batch_gather_out.numpy(), gather_out.numpy(), atol=1e-4, rtol=1e-4)\n    )\n\n    test_case.assertTrue(\n        np.allclose(\n            x_tensor.grad.numpy(), x_tensor_gather.grad.numpy(), atol=1e-4, rtol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            x_tensor.grad.numpy(), x_tensor_gather.grad.numpy(), atol=1e-4, rtol=1e-4,\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBatchGather(flow.unittest.TestCase):\n    def test_batch_gather(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_batch_gather]\n        arg_dict[\"shape\"] = [(3, 2, 2), (3, 2, 4, 2), (3, 3, 4, 2, 2), (4, 2)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_batchnorm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBatchNormModule(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 11 times in past week\")\n    @autotest(\n        auto_backward=True, rtol=1e-3, atol=1e-3, check_grad_use_random_data=False\n    )\n    def test_batchnorm1d_module_with_random_data(test_case):\n        device = random_device()\n        channel = random(1, 4).to(int)\n        m = torch.nn.BatchNorm1d(\n            num_features=channel,\n            track_running_stats=random().to(bool),\n            affine=random().to(bool),\n        ).to(device)\n        m.train(random())\n        x = random_tensor(\n            ndim=3, dim0=random(1, 4), dim1=channel, requires_grad=True\n        ).to(device)\n        y = m(x)\n        return y\n\n    @autotest(\n        auto_backward=True, rtol=1e-3, atol=1e-3, check_grad_use_random_data=False\n    )\n    def test_batchnorm2d_module_with_random_data(test_case):\n        device = random_device()\n        channel = random(1, 4).to(int)\n        m = torch.nn.BatchNorm2d(\n            num_features=channel,\n            track_running_stats=random().to(bool),\n            affine=random().to(bool),\n        ).to(device)\n        m.train(random())\n        x = random_tensor(\n            ndim=4, dim0=random(1, 4), dim1=channel, requires_grad=True\n        ).to(device)\n        y = m(x)\n        return y\n\n    @autotest(\n        auto_backward=True, rtol=1e-3, atol=1e-3, check_grad_use_random_data=False\n    )\n    def test_batchnorm3d_module_with_random_data(test_case):\n        device = random_device()\n        channel = random(1, 4).to(int)\n        m = torch.nn.BatchNorm3d(\n            num_features=channel,\n            track_running_stats=random().to(bool),\n            affine=random().to(bool),\n        ).to(device)\n        m.train(random())\n        x = random_tensor(ndim=5, dim1=channel, requires_grad=True).to(device)\n        y = m(x)\n        return y\n\n    @autotest(rtol=1e-3, atol=1e-3, check_grad_use_random_data=False)\n    def test_functional_batchnorm_with_random_data(test_case):\n        device = random_device()\n        channel = random(1, 4).to(int)\n        x = random_tensor(ndim=5, dim1=channel, requires_grad=True).to(device)\n        running_mean = random_tensor(ndim=1, dim0=channel, requires_grad=False)\n        running_var = random_tensor(ndim=1, dim0=channel, low=0.0, requires_grad=False)\n        weight = random_tensor(ndim=1, dim0=channel)\n        bias = random_tensor(ndim=1, dim0=channel)\n        result = torch.nn.functional.batch_norm(\n            input=x,\n            running_mean=running_mean,\n            running_var=running_var,\n   
         weight=weight,\n            bias=bias,\n            training=random_bool(),\n        )\n        return result\n\n    @autotest(rtol=1e-3, atol=1e-3, auto_backward=False, check_graph=False)\n    def test_batchnorm2d_module_with_half_random_data(test_case):\n        device = random_device()\n        channel = random(1, 4).to(int)\n        m = torch.nn.BatchNorm2d(\n            num_features=channel,\n            track_running_stats=random().to(bool),\n            affine=random().to(bool),\n        ).to(device)\n        m.train(random())\n        m.half()\n        x = random_tensor(\n            ndim=4, dim0=random(1, 4), dim1=channel, requires_grad=True\n        ).to(device)\n        # Tensor.half() is not in-place, so rebind x to the half-precision copy.\n        x = x.half()\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.batch_norm)\n    def profile_batchnorm(test_case):\n        input = torch.ones(16, 128, 28, 28)\n        running_mean = torch.randn(128)\n        running_var = torch.randn(128)\n        weight = torch.randn(128)\n        bias = torch.randn(128)\n        torch.nn.functional.batch_norm(\n            input, running_mean, running_var, weight, bias, True\n        )\n        torch.nn.functional.batch_norm(\n            input, running_mean, running_var, weight, bias, False\n        )\n        torch.nn.functional.batch_norm(\n            input, running_mean, running_var, None, None, True\n        )\n        torch.nn.functional.batch_norm(\n            input, running_mean, running_var, None, None, False\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_batchnorm_add_relu.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport os\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_bn_add_relu(test_case, device, batch, channel, height, width):\n    weight_numpy = np.random.randn(channel)\n    bias_numpy = np.random.randn(channel)\n\n    fused_x = np.random.randn(batch, channel, height, width)\n    fused_x_tensor = flow.Tensor(fused_x).to(device)\n    fused_x_tensor.requires_grad = True\n\n    fused_addend = np.random.randn(batch, channel, height, width)\n    fused_addend_tensor = flow.Tensor(fused_addend).to(device)\n    fused_addend_tensor.requires_grad = True\n\n    fused_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device))\n    fused_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device))\n\n    fused_bn = flow.nn.FusedBatchNorm2d(channel).to(device)\n    fused_bn.weight = fused_weight_tensor\n    fused_bn.bias = fused_bias_tensor\n    fused_out = fused_bn(fused_x_tensor, fused_addend_tensor)\n\n    origin_x_tensor = flow.Tensor(fused_x).to(device)\n    origin_x_tensor.requires_grad = True\n\n    origin_addend_tensor = flow.Tensor(fused_addend).to(device)\n    origin_addend_tensor.requires_grad = True\n\n    origin_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device))\n    origin_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device))\n\n    origin_batch_norm = flow.nn.BatchNorm2d(channel).to(device)\n    origin_batch_norm.weight = origin_weight_tensor\n    origin_batch_norm.bias = origin_bias_tensor\n\n    origin_out = origin_batch_norm(origin_x_tensor) + origin_addend_tensor\n    origin_out = flow.nn.functional.relu(origin_out)\n\n    total_out = fused_out + origin_out\n    total_out_sum = total_out.sum()\n\n    total_out_sum.backward()\n\n    # test output.\n    test_case.assertTrue(\n        np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4)\n    )\n    # test input grad.\n    test_case.assertTrue(\n        np.allclose(\n            fused_x_tensor.grad.numpy(),\n            origin_x_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_addend_tensor.grad.numpy(),\n            origin_addend_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    # test weight and bias grad.\n    test_case.assertTrue(\n        np.allclose(\n            fused_weight_tensor.grad.numpy(),\n            origin_weight_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_bias_tensor.grad.numpy(),\n            origin_bias_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    # test running mean and running 
variance.\n    test_case.assertTrue(\n        np.allclose(\n            fused_bn.running_mean.numpy(),\n            origin_batch_norm.running_mean.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_bn.running_var.numpy(),\n            origin_batch_norm.running_var.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n\n\ndef _test_bn_relu(test_case, device, batch, channel, height, width):\n    weight_numpy = np.random.randn(channel)\n    bias_numpy = np.random.randn(channel)\n\n    fused_x = np.random.randn(batch, channel, height, width)\n    fused_x_tensor = flow.Tensor(fused_x).to(device)\n    fused_x_tensor.requires_grad = True\n\n    fused_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device))\n    fused_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device))\n\n    fused_bn = flow.nn.FusedBatchNorm2d(channel).to(device)\n    fused_bn.weight = fused_weight_tensor\n    fused_bn.bias = fused_bias_tensor\n    fused_out = fused_bn(fused_x_tensor, None)\n\n    origin_x_tensor = flow.Tensor(fused_x).to(device)\n    origin_x_tensor.requires_grad = True\n\n    origin_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device))\n    origin_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device))\n\n    origin_batch_norm = flow.nn.BatchNorm2d(channel).to(device)\n    origin_batch_norm.weight = origin_weight_tensor\n    origin_batch_norm.bias = origin_bias_tensor\n\n    origin_out = origin_batch_norm(origin_x_tensor)\n    origin_out = flow.nn.functional.relu(origin_out)\n\n    total_out = fused_out + origin_out\n    total_out_sum = total_out.sum()\n\n    total_out_sum.backward()\n\n    # test output.\n    test_case.assertTrue(\n        np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4)\n    )\n    # test input grad.\n    test_case.assertTrue(\n        np.allclose(\n            fused_x_tensor.grad.numpy(),\n            origin_x_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n\n    # test weight and bias grad.\n    test_case.assertTrue(\n        np.allclose(\n            fused_weight_tensor.grad.numpy(),\n            origin_weight_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_bias_tensor.grad.numpy(),\n            origin_bias_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    # test running mean and running variance.\n    test_case.assertTrue(\n        np.allclose(\n            fused_bn.running_mean.numpy(),\n            origin_batch_norm.running_mean.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_bn.running_var.numpy(),\n            origin_batch_norm.running_var.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n\n\ndef _test_bn_relu_track_running_states_false(\n    test_case, device, batch, channel, height, width\n):\n    weight_numpy = np.random.randn(channel)\n    bias_numpy = np.random.randn(channel)\n\n    fused_x = np.random.randn(batch, channel, height, width)\n    fused_x_tensor = flow.Tensor(fused_x).to(device)\n    fused_x_tensor.requires_grad = True\n\n    fused_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device))\n    fused_bias_tensor = 
flow.nn.Parameter(flow.Tensor(bias_numpy).to(device))\n\n    fused_bn = flow.nn.FusedBatchNorm2d(channel, track_running_stats=False).to(device)\n    fused_bn.weight = fused_weight_tensor\n    fused_bn.bias = fused_bias_tensor\n    fused_out = fused_bn(fused_x_tensor, None)\n\n    origin_x_tensor = flow.Tensor(fused_x).to(device)\n    origin_x_tensor.requires_grad = True\n\n    origin_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device))\n    origin_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device))\n\n    origin_batch_norm = flow.nn.BatchNorm2d(channel, track_running_stats=False).to(\n        device\n    )\n    origin_batch_norm.weight = origin_weight_tensor\n    origin_batch_norm.bias = origin_bias_tensor\n\n    origin_out = origin_batch_norm(origin_x_tensor)\n    origin_out = flow.nn.functional.relu(origin_out)\n\n    total_out = fused_out + origin_out\n    total_out_sum = total_out.sum()\n\n    total_out_sum.backward()\n\n    # test output.\n    test_case.assertTrue(\n        np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4)\n    )\n    # test input grad.\n    test_case.assertTrue(\n        np.allclose(\n            fused_x_tensor.grad.numpy(),\n            origin_x_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    # test weight and bias grad.\n    test_case.assertTrue(\n        np.allclose(\n            fused_weight_tensor.grad.numpy(),\n            origin_weight_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_bias_tensor.grad.numpy(),\n            origin_bias_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    # When track_running_stats is False, running_mean and running_var are set to None.\n    test_case.assertIsNone(fused_bn.running_mean)\n    test_case.assertIsNone(origin_batch_norm.running_mean)\n    test_case.assertIsNone(fused_bn.running_var)\n    test_case.assertIsNone(origin_batch_norm.running_var)\n\n\ndef _test_bn_add_relu_track_running_states_false(\n    test_case, device, batch, channel, height, width\n):\n    weight_numpy = np.random.randn(channel)\n    bias_numpy = np.random.randn(channel)\n\n    fused_x = np.random.randn(batch, channel, height, width)\n    fused_x_tensor = flow.Tensor(fused_x).to(device)\n    fused_x_tensor.requires_grad = True\n\n    fused_addend = np.random.randn(batch, channel, height, width)\n    fused_addend_tensor = flow.Tensor(fused_addend).to(device)\n    fused_addend_tensor.requires_grad = True\n\n    fused_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device))\n    fused_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device))\n\n    fused_bn = flow.nn.FusedBatchNorm2d(channel, track_running_stats=False).to(device)\n    fused_bn.weight = fused_weight_tensor\n    fused_bn.bias = fused_bias_tensor\n    fused_out = fused_bn(fused_x_tensor, fused_addend_tensor)\n\n    origin_x_tensor = flow.Tensor(fused_x).to(device)\n    origin_x_tensor.requires_grad = True\n\n    origin_addend_tensor = flow.Tensor(fused_addend).to(device)\n    origin_addend_tensor.requires_grad = True\n\n    origin_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device))\n    origin_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device))\n\n    origin_batch_norm = flow.nn.BatchNorm2d(channel, track_running_stats=False).to(\n        device\n    )\n    
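# Unfused reference path: BatchNorm2d output plus the addend, then relu.\n    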
origin_batch_norm.weight = origin_weight_tensor\n    origin_batch_norm.bias = origin_bias_tensor\n\n    origin_out = origin_batch_norm(origin_x_tensor) + origin_addend_tensor\n    origin_out = flow.nn.functional.relu(origin_out)\n\n    total_out = fused_out + origin_out\n    total_out_sum = total_out.sum()\n\n    total_out_sum.backward()\n\n    # test output.\n    test_case.assertTrue(\n        np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4)\n    )\n    # test input grad.\n    test_case.assertTrue(\n        np.allclose(\n            fused_x_tensor.grad.numpy(),\n            origin_x_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_addend_tensor.grad.numpy(),\n            origin_addend_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    # test weight and bias grad.\n    test_case.assertTrue(\n        np.allclose(\n            fused_weight_tensor.grad.numpy(),\n            origin_weight_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_bias_tensor.grad.numpy(),\n            origin_bias_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    # When track_running_stats is False, running_mean and running_var are set to None.\n    test_case.assertIsNone(fused_bn.running_mean)\n    test_case.assertIsNone(origin_batch_norm.running_mean)\n    test_case.assertIsNone(fused_bn.running_var)\n    test_case.assertIsNone(origin_batch_norm.running_var)\n\n\ndef _test_bn_add_relu_eval(test_case, device, batch, channel, height, width):\n    weight_numpy = np.random.randn(channel)\n    bias_numpy = np.random.randn(channel)\n\n    fused_x = np.random.randn(batch, channel, height, width)\n    fused_x_tensor = flow.Tensor(fused_x).to(device)\n\n    fused_addend = np.random.randn(batch, channel, height, width)\n    fused_addend_tensor = flow.Tensor(fused_addend).to(device)\n\n    fused_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device))\n    fused_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device))\n\n    fused_bn = flow.nn.FusedBatchNorm2d(channel).to(device)\n    fused_bn.eval()\n    fused_bn.weight = fused_weight_tensor\n    fused_bn.bias = fused_bias_tensor\n    fused_out = fused_bn(fused_x_tensor, fused_addend_tensor)\n\n    origin_x_tensor = flow.Tensor(fused_x).to(device)\n\n    origin_addend_tensor = flow.Tensor(fused_addend).to(device)\n\n    origin_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device))\n    origin_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device))\n\n    origin_batch_norm = flow.nn.BatchNorm2d(channel).to(device)\n    origin_batch_norm.eval()\n    origin_batch_norm.weight = origin_weight_tensor\n    origin_batch_norm.bias = origin_bias_tensor\n\n    origin_out = origin_batch_norm(origin_x_tensor) + origin_addend_tensor\n    origin_out = flow.nn.functional.relu(origin_out)\n\n    # test output.\n    test_case.assertTrue(\n        np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4)\n    )\n\n\ndef _test_bn_relu_eval(test_case, device, batch, channel, height, width):\n    weight_numpy = np.random.randn(channel)\n    bias_numpy = np.random.randn(channel)\n\n    fused_x = np.random.randn(batch, channel, height, width)\n    fused_x_tensor = flow.Tensor(fused_x).to(device)\n\n    
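# Eval-mode check: only forward outputs are compared, so no backward pass is run.\n    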
fused_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device))\n    fused_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device))\n\n    fused_bn = flow.nn.FusedBatchNorm2d(channel).to(device)\n    fused_bn.eval()\n    fused_bn.weight = fused_weight_tensor\n    fused_bn.bias = fused_bias_tensor\n    fused_out = fused_bn(fused_x_tensor)\n\n    origin_x_tensor = flow.Tensor(fused_x).to(device)\n\n    origin_weight_tensor = flow.nn.Parameter(flow.Tensor(weight_numpy).to(device))\n    origin_bias_tensor = flow.nn.Parameter(flow.Tensor(bias_numpy).to(device))\n\n    origin_batch_norm = flow.nn.BatchNorm2d(channel).to(device)\n    origin_batch_norm.eval()\n    origin_batch_norm.weight = origin_weight_tensor\n    origin_batch_norm.bias = origin_bias_tensor\n\n    origin_out = origin_batch_norm(origin_x_tensor)\n    origin_out = flow.nn.functional.relu(origin_out)\n\n    # test output.\n    test_case.assertTrue(\n        np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test gpu cases\")\nclass TestBnAddRelu(flow.unittest.TestCase):\n    def test_bn_add_relu2d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_bn_add_relu,\n            _test_bn_relu,\n            _test_bn_relu_track_running_states_false,\n            _test_bn_add_relu_track_running_states_false,\n            _test_bn_add_relu_eval,\n            _test_bn_relu_eval,\n        ]\n        # The class-level skipIf already excludes CPU-only runs, so only cuda is tested.\n        arg_dict[\"device\"] = [\"cuda\"]\n        arg_dict[\"batch\"] = [1, 2, 8]\n        arg_dict[\"channel\"] = [4, 6]\n        arg_dict[\"height\"] = [6, 8]\n        arg_dict[\"width\"] = [12, 8]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_bernoulli.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_bernoulli(test_case, shape, p, dtype):\n    input_arr = np.ones(shape)\n    x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(\"cpu\"))\n    if p is None:\n        y = flow.bernoulli(x, dtype=dtype)\n    else:\n        y = flow.bernoulli(x, p=p, dtype=dtype)\n    test_case.assertTrue(y.dtype == dtype)\n    if p == 1 or p is None:\n        test_case.assertTrue(np.allclose(y.numpy(), x.numpy()))\n    elif p == 0:\n        test_case.assertTrue(np.allclose(y.numpy(), np.zeros(shape)))\n\n\ndef _test_bernoulli_with_generator(test_case, shape):\n    generator = flow.Generator()\n    generator.manual_seed(0)\n    x = flow.tensor(\n        np.random.rand(*shape), dtype=flow.float32, device=flow.device(\"cpu\")\n    )\n    y_1 = flow.bernoulli(x, generator=generator)\n    generator.manual_seed(0)\n    y_2 = flow.bernoulli(x, generator=generator)\n    test_case.assertTrue(np.allclose(y_1.numpy(), y_2.numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBernoulli(flow.unittest.TestCase):\n    def test_bernoulli(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_functions\"] = [_test_bernoulli]\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]\n        arg_dict[\"p\"] = [None, 0, 1]\n        arg_dict[\"dtype\"] = [flow.float32, flow.int64]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @unittest.skip(\"bernoulli has bug\")\n    @autotest(auto_backward=False)\n    def test_flow_bernoulli_with_random_data(test_case):\n        input = random_tensor(ndim=1).to(\"cpu\")\n        return torch.bernoulli(input)\n\n    \"\"\"\n    @profile(torch.bernoulli) \n    def profile_bernoulli(test_case):\n        torch.bernoulli(torch.ones(3, 3))\n        torch.bernoulli(torch.zeros(3, 3))\n    \"\"\"\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_binary_math_ops_dtype.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom itertools import product\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef get_dtype_str(dtype):\n    return str(dtype).split(\".\")[-1]\n\n\ndtype_list = [\n    torch.int8,\n    torch.int32,\n    torch.int64,\n    torch.float16,\n    torch.float32,\n    torch.float64,\n]\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBinaryMathOpsDtype(flow.unittest.TestCase):\n    @autotest(n=2, auto_backward=False, check_graph=False)\n    def test_binary_math_ops_dtype(test_case):\n        device = random_device()\n\n        for x1_dtype, x2_dtype in product(dtype_list, dtype_list):\n            x1 = random_tensor(2, 2, 3, requires_grad=False).to(device).to(x1_dtype)\n            x2 = random_tensor(2, 2, 3, requires_grad=False).to(device).to(x2_dtype)\n\n            for op in [\"+\", \"-\", \"*\", \"/\"]:\n                y = eval(f\"x1 {op} x2\")\n                test_case.assertEqual(\n                    get_dtype_str(y.oneflow.dtype), get_dtype_str(y.pytorch.dtype)\n                )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_bincount.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBinCount(flow.unittest.TestCase):\n    @autotest(n=5, auto_backward=False, check_graph=False)\n    def test_bincount(test_case):\n        device = random_device()\n        x = random_tensor(1, 100, low=0, high=65536, dtype=int).to(device)\n        result = torch.bincount(x)\n        return result\n\n    @autotest(n=5, auto_backward=False, check_graph=False)\n    def test_bincount_weight(test_case):\n        device = random_device()\n        x = random_tensor(1, 100, low=0, high=65536, dtype=int).to(device)\n        weight = random_tensor(1, 100).to(device)\n        return torch.bincount(x, weights=weight)\n\n    @autotest(n=5, auto_backward=False, check_graph=False)\n    def test_bincount_minlength(test_case):\n        device = random_device()\n        x = random_tensor(1, 100, low=0, high=65536, dtype=int).to(device)\n        weight = random_tensor(1, 100).to(device)\n        minlength = random(1, 200).to(int)\n        return torch.bincount(x, weights=weight, minlength=minlength)\n\n    @autotest(n=5, auto_backward=False, check_graph=False)\n    def test_bincount_0element(test_case):\n        device = random_device()\n        x = random_tensor(1, 0, low=0, high=65536, dtype=int).to(device)\n        weight = random_tensor(1, 0).to(device)\n        minlength = random(1, 200).to(int)\n        return torch.bincount(x, weights=weight, minlength=minlength)\n\n    @profile(torch.bincount)\n    def profile_bincount(test_case):\n        torch.bincount(torch.ones(4096).int())\n        torch.bincount(torch.ones(65536).int())\n        torch.bincount(torch.arange(4096).int())\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_bitwise.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nimport oneflow as flow\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_bitwise_op(test_case, op):\n    device = random_device()\n    dims_kwargs = {\n        \"ndim\": 4,\n        \"dim0\": random(low=4, high=8).to(int),\n        \"dim1\": random(low=4, high=8).to(int),\n        \"dim2\": random(low=4, high=8).to(int),\n        \"dim3\": random(low=4, high=8).to(int),\n    }\n    # TODO(WangYi): oneflow doesn't support conversion between uint8 and int8\n    # So, use \"index\" instead of \"int\" in `random_dtype`\n    x_dtype = random_dtype([\"index\", \"bool\", \"unsigned\"])\n    y_dtype = random_dtype([\"index\", \"bool\", \"unsigned\"])\n    x = random_tensor(dtype=int, **dims_kwargs,).to(device).to(x_dtype)\n    y = random_tensor(dtype=int, **dims_kwargs,).to(device).to(y_dtype)\n    bool_tensor = random_tensor(low=-1, high=1, **dims_kwargs,).to(device) > 0\n    return op(op(x, y), bool_tensor)\n\n\ndef _test_scalar_bitwise(test_case, op):\n    device = random_device()\n    dtype = random_dtype([\"int\", \"bool\", \"unsigned\"])\n    x = (\n        random_tensor(\n            ndim=4,\n            dim0=random(low=4, high=8).to(int),\n            dim1=random(low=4, high=8).to(int),\n            dim2=random(low=4, high=8).to(int),\n            dim3=random(low=4, high=8).to(int),\n            dtype=int,\n        )\n        .to(device)\n        .to(dtype)\n    )\n    scalar = random(low=-10, high=10).to(int)\n    bool_scalar = random_bool()\n    result = op(op(x, scalar), bool_scalar)\n    return result\n\n\n# Bitwise ops only accept integral dtype,\n# so auto_backward isn't necessary\n@flow.unittest.skip_unless_1n1d()\nclass TestBitwiseAndModule(flow.unittest.TestCase):\n    @autotest(n=10, auto_backward=False)\n    def test_bitwise_and(test_case):\n        return _test_bitwise_op(test_case, torch.bitwise_and)\n\n    @autotest(n=10, auto_backward=False)\n    def test_scalar_bitwise_and(test_case):\n        return _test_scalar_bitwise(test_case, torch.bitwise_and,)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBitwiseOrModule(flow.unittest.TestCase):\n    @autotest(n=10, auto_backward=False)\n    def test_bitwise_or(test_case):\n        return _test_bitwise_op(test_case, torch.bitwise_or)\n\n    @autotest(n=10, auto_backward=False)\n    def test_scalar_bitwise_or(test_case):\n        return _test_scalar_bitwise(test_case, torch.bitwise_or,)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBitwiseXorModule(flow.unittest.TestCase):\n    @autotest(n=10, auto_backward=False)\n    def test_bitwise_xor(test_case):\n        return _test_bitwise_op(test_case, torch.bitwise_xor)\n\n    @autotest(n=10, auto_backward=False)\n    def test_scalar_bitwise_xor(test_case):\n        return _test_scalar_bitwise(test_case, 
torch.bitwise_xor,)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBitwiseNotModule(flow.unittest.TestCase):\n    @autotest(n=10, auto_backward=False)\n    def test_bitwise_not(test_case):\n        device = random_device()\n        # TODO(WangYi): oneflow doesn't support conversion between uint8 and int8\n        # So, use \"index\" instead of \"int\" in `random_dtype`\n        dtype = random_dtype([\"index\", \"bool\", \"unsigned\"])\n        x = (\n            random_tensor(\n                ndim=4,\n                dim0=random(low=4, high=8).to(int),\n                dim1=random(low=4, high=8).to(int),\n                dim2=random(low=4, high=8).to(int),\n                dim3=random(low=4, high=8).to(int),\n                dtype=int,\n                high=10,\n            )\n            .to(device)\n            .to(dtype)\n        )\n        return torch.bitwise_not(x)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_bmm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_bmm(test_case, device):\n    input1 = flow.tensor(\n        np.random.randn(10, 3, 4), dtype=flow.float32, device=flow.device(device)\n    )\n    input2 = flow.tensor(\n        np.random.randn(10, 4, 5), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.bmm(input1, input2)\n    np_out = np.matmul(input1.numpy(), input2.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_bmm_backward(test_case, device):\n    input1 = flow.tensor(\n        [\n            [\n                [-0.0036776792258024216, 1.9946473836898804, -0.423959881067276],\n                [1.0892143249511719, 0.04005361348390579, -0.27883127331733704],\n            ],\n            [\n                [-0.970306396484375, 0.017771577462553978, 0.019596196711063385],\n                [0.27402883768081665, -0.8192587494850159, -0.3135920464992523],\n            ],\n        ],\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    input2 = flow.tensor(\n        [\n            [\n                [1.118346929550171, -0.930071234703064],\n                [1.1238232851028442, 1.373764157295227],\n                [0.17178462445735931, -1.1010534763336182],\n            ],\n            [\n                [0.6694859862327576, 0.9250285029411316],\n                [-1.0835869312286377, 0.4192655086517334],\n                [1.2616937160491943, 0.33809131383895874],\n            ],\n        ],\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out = flow.bmm(input1, input2)\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = [\n        [\n            [0.18827569484710693, 2.4975874423980713, -0.9292688369750977],\n            [0.18827569484710693, 2.4975874423980713, -0.9292688369750977],\n        ],\n        [\n            [1.5945144891738892, -0.6643214225769043, 1.5997850894927979],\n            [1.5945144891738892, -0.6643214225769043, 1.5997850894927979],\n        ],\n    ]\n    test_case.assertTrue(\n        np.allclose(input1.grad.numpy(), np_grad, atol=1e-05, rtol=1e-05)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModule(flow.unittest.TestCase):\n    def test_bmm(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_bmm, _test_bmm_backward]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(check_graph=True, rtol=1e-4, atol=1e-3)\n    def test_bmm_with_torch(test_case):\n        device = random_device()\n        mat1 = 
random_tensor(ndim=3, dim0=2, dim1=4, dim2=3).to(device)\n        mat2 = random_tensor(ndim=3, dim0=2, dim1=3, dim2=4).to(device)\n        y = torch.bmm(mat1, mat2,)\n        return y\n\n    @profile(torch.bmm)\n    def profile_bmm(test_case):\n        mat1 = torch.ones(10, 100, 100)\n        mat2 = torch.ones(10, 100, 100)\n        torch.bmm(mat1, mat2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_broadcast_like.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_broadcast_like(test_case, device):\n    input = flow.tensor(\n        np.ones(shape=(3, 1, 1), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    like_tensor = flow.tensor(\n        np.ones(shape=(3, 3, 3), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = flow.broadcast_like(input, like_tensor, broadcast_axes=(1, 2))\n    np_out = np.ones(shape=(3, 3, 3))\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_broadcast_like_one(test_case, device):\n    input = flow.tensor(\n        np.ones(shape=(1, 1), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    like_tensor = flow.tensor(\n        np.ones(shape=(1, 2, 3), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = flow.broadcast_like(input, like_tensor)\n    np_out = np.ones(shape=(1, 2, 3))\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_broadcast_like_different_dim(test_case, device):\n    input = flow.tensor(\n        np.ones(shape=(3, 1), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    like_tensor = flow.tensor(\n        np.ones(shape=(2, 3, 4), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = flow.broadcast_like(input, like_tensor)\n    np_out = np.ones(shape=(2, 3, 4))\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_broadcast_like_different_dim_with_input_axisvec(test_case, device):\n    input = flow.tensor(\n        np.ones(shape=(1, 5, 6), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    like_tensor = flow.tensor(\n        np.ones(shape=(1, 5, 6, 1, 6), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = flow.broadcast_like(input, like_tensor, broadcast_axes=(3, 4))\n    np_out = np.ones(shape=(1, 5, 6, 1, 6))\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_broadcast_like_3dim(test_case, device):\n    input = flow.tensor(\n        np.ones(shape=(1, 3, 2), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    like_tensor = flow.tensor(\n        np.ones(shape=(3, 3, 2), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = flow.broadcast_like(input, like_tensor, broadcast_axes=(0,))\n    np_out = np.ones(shape=(3, 3, 2))\n    
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_broadcast_like_4dim(test_case, device):\n    input = flow.tensor(\n        np.ones(shape=(1, 3, 2, 1), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    like_tensor = flow.tensor(\n        np.ones(shape=(3, 3, 2, 3), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = flow.broadcast_like(input, like_tensor, broadcast_axes=(0, 3))\n    np_out = np.ones(shape=(3, 3, 2, 3))\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_broadcast_like_empty_axisvec(test_case, device):\n    input = flow.tensor(\n        np.ones(shape=(1), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    like_tensor = flow.tensor(\n        np.ones(shape=(2, 3, 4), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = flow.broadcast_like(input, like_tensor)\n    np_out = np.ones(shape=(2, 3, 4))\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_broadcast_like_backward(test_case, device):\n    input = flow.tensor(\n        np.ones(shape=(3, 1, 1), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    like_tensor = flow.tensor(\n        np.ones(shape=(3, 3, 3), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out = flow.broadcast_like(input, like_tensor, broadcast_axes=(1, 2))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = [[[9.0]], [[9.0]], [[9.0]]]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBroadCastLike(flow.unittest.TestCase):\n    def test_broadcast_like(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_broadcast_like,\n            _test_broadcast_like_one,\n            _test_broadcast_like_different_dim,\n            _test_broadcast_like_different_dim_with_input_axisvec,\n            _test_broadcast_like_3dim,\n            _test_broadcast_like_4dim,\n            _test_broadcast_like_empty_axisvec,\n            _test_broadcast_like_backward,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_broadcast_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport torch as ori_torch\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nbinary_ops = [\n    torch.add,\n    torch.sub,\n    torch.mul,\n    torch.div,\n    torch.min,\n    torch.minimum,\n    torch.max,\n    torch.maximum,\n    torch.fmod,\n    torch.pow,\n    torch.eq,\n    torch.ne,\n    torch.gt,\n    torch.ge,\n    torch.lt,\n    torch.le,\n    torch.logical_and,\n    torch.logical_or,\n    torch.logical_xor,\n]\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBroadcastOps(flow.unittest.TestCase):\n    @autotest(n=5, auto_backward=False)\n    def test_broadcast_elementwise(test_case):\n        op_idx = random(low=0, high=len(binary_ops)).to(int).value()\n        op = binary_ops[op_idx]\n        device = random_device()\n        x = random_tensor(ndim=4, dim0=2, dim1=2, dim2=3, dim3=4).to(device)\n        y = random_tensor(ndim=4, dim0=1, dim1=2, dim2=3, dim3=1).to(device)\n        out = op(x, y)\n        return out\n\n    @autotest(n=5, auto_backward=False)\n    def test_broadcast_matrix_row(test_case):\n        op_idx = random(low=0, high=len(binary_ops)).to(int).value()\n        op = binary_ops[op_idx]\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3).to(device)\n        y = random_tensor(ndim=2, dim0=2, dim1=3).to(device)\n        out = op(x, y)\n        return out\n\n    @autotest(n=5, auto_backward=False)\n    def test_broadcast_matrix_col(test_case):\n        op_idx = random(low=0, high=len(binary_ops)).to(int).value()\n        op = binary_ops[op_idx]\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3).to(device)\n        y = random_tensor(ndim=3, dim0=2, dim1=2, dim2=1).to(device)\n        out = op(x, y)\n        return out\n\n    @autotest(n=5, auto_backward=False)\n    def test_cpu_scalar_tensor_auto_cast(test_case):\n        def check_output(test_case, output):\n            of_res = output.oneflow\n            torch_res = output.pytorch\n            # NOTE: torch's device has no device index bug oneflow has.\n            #       e.g. 
torch gets \"cpu\" but oneflow gets \"cpu:0\"\n            test_case.assertTrue(str(torch_res.device) in str(of_res.device))\n            test_case.assertTrue(\n                np.allclose(of_res.numpy(), torch_res.detach().cpu().numpy())\n            )\n\n        op_idx = random(low=0, high=len(binary_ops)).to(int).value()\n        op = binary_ops[op_idx]\n        device = random_device()\n        x = torch.tensor(1.0)\n        y = random_tensor(ndim=2, dim0=2, dim1=2).to(device)\n\n        out = op(x, y)\n        check_output(test_case, out)\n\n        out = op(y, x)\n        check_output(test_case, out)\n\n    @autotest(n=30, auto_backward=False)\n    def test_broadcast_scalar(test_case):\n        op_idx = random(low=0, high=len(binary_ops)).to(int).value()\n        op = binary_ops[op_idx]\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3).to(device)\n        out = op(x, 1)\n        return out\n\n    @profile(torch.add)\n    def profile_broadcast_matrix_row(test_case):\n        input0 = torch.ones(256, 1024)\n        input1 = torch.ones(1024)\n        torch.add(input0, input1)\n\n    @profile(torch.add)\n    def profile_broadcast_matrix_col(test_case):\n        input0 = torch.ones(1024, 256)\n        input1 = torch.ones(1024, 1)\n        torch.add(input0, input1)\n\n    @profile(torch.add)\n    def profile_broadcast_elementwise(test_case):\n        input0 = torch.ones(256, 1024)\n        input1 = torch.ones(256, 1024)\n        torch.add(input0, input1)\n\n    @profile(torch.add)\n    def profile_broadcast_scalar(test_case):\n        input0 = torch.ones(256, 1024)\n        torch.add(input0, 1)\n\n    @profile(torch.add)\n    def profile_broadcast_general(test_case):\n        input0 = torch.ones(2, 64, 8, 16, 16, 4)\n        input1 = torch.ones(64, 8, 1, 16, 1)\n        torch.add(input0, input1)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBroadcastOpsOther(flow.unittest.TestCase):\n    def test_broadcast_shapes(test_case):\n        shapes = (2,), (3, 1), (1, 1, 1)\n        test_case.assertTrue(\n            flow.broadcast_shapes(*shapes), ori_torch.broadcast_shapes(*shapes),\n        )\n\n    @autotest(n=3)\n    def test_broadcast_tensors(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim0=1, dim1=4).to(device=device)\n        y = random_tensor(ndim=2, dim0=3, dim1=1).to(device=device)\n        return torch.broadcast_tensors(x, y)\n\n    def test_broadcast_to(test_case):\n        # see flow.expand, because broadcast_to is an alias of flow.expand\n        pass\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_cast.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nfrom random import shuffle\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch as torch_original\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_cast_float2int(test_case, device, shape):\n    np_arr = np.random.randn(*shape).astype(np.float32)\n    input = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    output = flow.cast(input, flow.int8)\n    np_out = np_arr.astype(np.int8)\n    test_case.assertTrue(np.array_equal(output.numpy(), np_out))\n\n\ndef _test_cast_int2float(test_case, device, shape):\n    np_arr = np.random.randn(*shape).astype(np.int8)\n    input = flow.tensor(np_arr, dtype=flow.int8, device=flow.device(device))\n    output = flow.cast(input, flow.float32)\n    np_out = np_arr.astype(np.float32)\n    test_case.assertTrue(np.array_equal(output.numpy(), np_out))\n\n\ndef _test_cast_bool2int16(test_case, device, shape):\n    np_arr = np.random.randn(*shape).astype(np.float32)\n    input = flow.tensor(np_arr, dtype=flow.bool, device=flow.device(device))\n    output = flow.cast(input, flow.int16)\n    np_out = np_arr.astype(bool).astype(np.int16)\n    test_case.assertTrue(np.array_equal(output.numpy(), np_out))\n\n\ndef _test_cast_with_non_contiguous_input(test_case, device, shape):\n    np_arr = np.random.randn(*shape).astype(np.int8)\n    permute_dims = np.arange(len(shape)).tolist()\n    shuffle(permute_dims)\n    input = flow.tensor(np_arr, dtype=flow.int8, device=flow.device(device)).permute(\n        permute_dims\n    )\n    output = flow.cast(input, flow.float32)\n    np_out = np_arr.astype(np.float32).transpose(permute_dims)\n    test_case.assertTrue(np.array_equal(output.numpy(), np_out))\n    test_case.assertTrue(input.stride() == output.stride())\n\n\ndef _test_cast_backward(test_case, device, shape):\n    np_arr = np.random.randn(*shape).astype(np.float32)\n    x = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    y = flow.cast(x, flow.float64)\n    z = y.sum()\n    z.backward()\n    np_out = np_arr.astype(np.float64)\n    test_case.assertTrue(np.array_equal(x.grad.numpy(), np.ones(shape=shape)))\n\n\ndef random_expand(x, ndim, expand_size):\n    dim_size = [1,] * ndim\n    random_index = random(0, ndim).to(int).value()\n    dim_size[random_index] = expand_size\n    return x.expand(*dim_size)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCast(flow.unittest.TestCase):\n    def test_cast(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_cast_float2int,\n            _test_cast_int2float,\n            _test_cast_bool2int16,\n            _test_cast_backward,\n            # _test_cast_with_non_contiguous_input,\n        ]\n        arg_dict[\"device\"] = 
[\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_cast_with_0_size_data(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_cast_float2int,\n            _test_cast_int2float,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(2, 3, 0, 5)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5)\n    def test_cast_with_strided_input(test_case):\n        device = random_device()\n        x = random_tensor()\n        x = x.to(dtype=torch.float32, device=device)\n        perm_list = [0, 1, 2, 3]\n        shuffle(perm_list)\n        x = x.permute(perm_list)\n        y = x.to(dtype=torch.float64, device=device)\n        return y\n\n    @autotest(n=5)\n    def test_cast_with_expanded_input(test_case):\n        device = random_device()\n        random_expand_size = random(1, 6).to(int).value()\n        x = random_tensor(ndim=5, dim0=1, dim1=1, dim2=1, dim3=1, dim4=1)\n        x = x.to(dtype=torch.float32, device=device)\n        perm_list = [0, 1, 2, 3, 4]\n        shuffle(perm_list)\n        x = x.permute(perm_list)\n        y = random_expand(x, ndim=5, expand_size=random_expand_size)\n        z = y.to(dtype=torch.float64, device=device)\n        return z\n\n    @autotest(n=5)\n    def test_cast_with_expanded_input_2(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=5)\n        a = x.to(dtype=torch.float32, device=device)\n        b = a.expand((4, 5))\n        c = b.to(dtype=torch.double, device=device)\n        return c\n\n    @autotest(n=5)\n    def test_cast_with_squeezed_input(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.squeeze(x, random(1, 3).to(int))\n        z = y.to(dtype=torch.double, device=device)\n        return z\n\n    @autotest(n=5, auto_backward=False)\n    def test_cast_with_sliced_input(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=20)\n        y = random_tensor(ndim=1, dim0=7)\n        x = x.to(dtype=torch.float32, device=device)\n        y = y.to(device=device)\n        rows = x * 10\n        cols = y\n        a = rows.reshape(20, 1) + cols\n        b = a[:, :1]\n        c = b.to(torch.int)\n        return c\n\n    @autotest(n=5, auto_backward=False)\n    # NOTE:if set auto_backward=True, both oneflow and pytorch will raise RuntimeError:\n    # element 0 of tensors does not require grad and does not have a grad_fn\n    def test_cast_with_scalar_input(test_case):\n        device = random_device()\n        x = torch.tensor(3.14, device=device)\n        y = x.to(dtype=torch.float64, device=device)\n        z = y.to(dtype=torch.int8, device=device)\n        return z\n\n    @autotest(n=5, auto_backward=True, include_complex=False, atol=1e-5, rtol=1e-5)\n    def test_cast_with_complex_float2complex(test_case):\n        device = random_device()\n        x = random_tensor().to(dtype=torch.float32, device=device)\n        y = x.to(torch.complex64)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_ceil.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCeilModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_ceil_flow_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor().to(device)\n        y = torch.ceil(input)\n        return y\n\n    @autotest(n=5)\n    def test_ceil_flow_with_random_0d_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=0).to(device)\n        y = torch.ceil(input)\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_ceil_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 1, 0, 3).to(device)\n        y = torch.ceil(x)\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_ceil_with_0shape_0d_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.ceil(x)\n        return y\n\n    @profile(torch.ceil)\n    def profile_ceil(test_case):\n        torch.ceil(torch.ones(4))\n        torch.ceil(torch.ones(100000))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_check_meta_consistency.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport os\n\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\n\n@flow.unittest.skip_unless_1n2d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGlobalCastModule_1n2d(flow.unittest.TestCase):\n    def test_check_meta_consistency(test_case):\n        if os.getenv(\"RANK\") == \"0\":\n            x = flow.ones((16, 16), device=flow.device(\"cuda\"), dtype=flow.int32)\n        else:\n            x = flow.zeros((1,), device=flow.device(\"cuda\"), dtype=flow.float)\n        placement = flow.placement(\"cuda\", ranks=[0])\n        sbp = (flow.sbp.broadcast,)\n        y = x.to_global(placement=placement, sbp=sbp)\n        y.check_meta_consistency()\n        y = y.to_global(sbp=flow.sbp.split(0))\n        y.check_meta_consistency()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_checkpointing.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport os\n\nimport oneflow as flow\nimport oneflow.unittest\nimport numpy as np\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestCheckpointing(flow.unittest.TestCase):\n    def test_checkpointing(test_case):\n        relu_forward_num = 0\n        relu_backward_num = 0\n\n        class MyReLU(flow.autograd.Function):\n            @staticmethod\n            def forward(ctx, x):\n                nonlocal relu_forward_num\n                relu_forward_num += 1\n                y = flow.relu(x)\n                ctx.save_for_backward(y)\n                return y\n\n            @staticmethod\n            def backward(ctx, dy):\n                nonlocal relu_backward_num\n                relu_backward_num += 1\n                y = ctx.saved_tensors[0]\n                return dy * (y > 0)\n\n        class M(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv1 = flow.nn.Conv2d(3, 3, 3)\n                self.conv2 = flow.nn.Conv2d(3, 3, 3)\n\n            def forward(self, x):\n                x = self.conv1(x)\n                if checkpointing:\n                    x = flow.utils.checkpoint.checkpoint(MyReLU.apply, x)\n                else:\n                    x = MyReLU.apply(x)\n                x = self.conv2(x)\n                return x\n\n        x1 = flow.randn(1, 3, 8, 16).requires_grad_()\n        x2 = x1.detach().clone().requires_grad_()\n\n        m = M()\n\n        checkpointing = True\n        y1 = m(x1)\n        y1.sum().backward()\n\n        checkpointing = False\n        y2 = m(x2)\n        y2.sum().backward()\n\n        test_case.assertTrue(np.array_equal(y1, y2))\n        test_case.assertTrue(np.array_equal(x1.grad, x2.grad))\n        test_case.assertEqual(relu_forward_num, 3)\n        test_case.assertEqual(relu_backward_num, 2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_chunk.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nfrom random import shuffle\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestChunk(flow.unittest.TestCase):\n    @autotest(n=5, check_graph=True)\n    def test_flow_chunk_list_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(\n            ndim=4,\n            dim1=random(low=4, high=8).to(int),\n            dim2=random(low=4, high=8).to(int),\n            dim3=random(low=4, high=8).to(int),\n        ).to(device)\n        y = torch.chunk(x, chunks=random(low=1, high=5).to(int), dim=dim)\n        z = torch.cat(y, dim=dim)\n        return z\n\n    @autotest(n=10)\n    def test_flow_chunk_list_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(\n            ndim=4,\n            dim1=random(low=4, high=8).to(int),\n            dim2=random(low=4, high=8).to(int),\n            dim3=random(low=4, high=8).to(int),\n        ).to(device)\n        permute_list = [0, 1, 2, 3]\n        shuffle(permute_list)\n        y = x.permute(permute_list)\n        z = torch.chunk(y, chunks=random(low=1, high=5).to(int), dim=dim)\n        return torch.cat(z, dim=dim)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_chunk_list_with_stride(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(\n            ndim=4,\n            dim1=random(low=4, high=8).to(int),\n            dim2=random(low=4, high=8).to(int),\n            dim3=random(low=4, high=8).to(int),\n        ).to(device)\n        perm = [0, 1, 2, 3]\n        shuffle(perm)\n        y = x.permute(perm)\n        z = torch.chunk(y, chunks=random(low=1, high=5).to(int), dim=dim)\n        return torch.cat(z, dim=dim)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_chunk_list_bool_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(\n            ndim=4,\n            dim1=random(low=4, high=8).to(int),\n            dim2=random(low=4, high=8).to(int),\n            dim3=random(low=4, high=8).to(int),\n        ).to(device, torch.bool)\n        y = torch.chunk(x, chunks=random(low=1, high=5).to(int), dim=dim)\n        z = torch.cat(y, dim=dim)\n        return z\n\n    @autotest(n=5, check_graph=True)\n    def test_flow_chunk_list_with_random_data_negative_dim(test_case):\n        device = random_device()\n        dim = random(1, 3).to(int)\n        x = random_tensor(\n            ndim=4,\n            dim0=random(low=4, high=8).to(int),\n            dim1=random(low=4, high=8).to(int),\n            dim2=random(low=4, high=8).to(int),\n 
           dim3=random(low=4, high=8).to(int),\n        ).to(device)\n        y = torch.chunk(x, chunks=4, dim=-1)\n        z = torch.cat(y, dim=-1)\n        return z\n\n    @profile(torch.chunk)\n    def profile_chunk(test_case):\n        torch.chunk(torch.ones(16), 4)\n        torch.chunk(torch.ones(100000), 5)\n        torch.chunk(torch.ones(100, 100), 5, dim=1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_clamp.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_clamp(test_case, shape, device, dtype):\n    input = flow.tensor(\n        np.random.randn(*shape), dtype=dtype, device=flow.device(device)\n    )\n    of_out = flow.clamp(input, 0.1, 0.5)\n    np_out = np.clip(input.numpy(), 0.1, 0.5)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_tensor_clamp(test_case, shape, device, dtype):\n    input = flow.tensor(\n        np.random.randn(*shape), dtype=dtype, device=flow.device(device)\n    )\n    of_out = input.clamp(0.1, 0.5)\n    np_out = np.clip(input.numpy(), 0.1, 0.5)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_clamp_scalar_min(test_case, shape, device, dtype):\n    input = flow.tensor(\n        np.random.randn(*shape), dtype=dtype, device=flow.device(device)\n    )\n    of_out = flow.clamp(input, 0.1, None)\n    np_out = np.clip(input.numpy(), 0.1, None)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_clamp_scalar_max(test_case, shape, device, dtype):\n    input = flow.tensor(\n        np.random.randn(*shape), dtype=dtype, device=flow.device(device)\n    )\n    of_out = flow.clamp(input, None, 0.5)\n    np_out = np.clip(input.numpy(), None, 0.5)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_clamp_integral(test_case, shape, device, dtype):\n    input = flow.tensor(np.random.randint(3, 10, shape), device=flow.device(device)).to(\n        dtype\n    )\n    of_out = flow.clamp(input, 1, 5)\n    np_out = np.clip(input.numpy(), 1, 5)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _numpy_clamp_grad(arr, min, max):\n    grad = np.zeros_like(arr)\n    grad[arr.clip(min, max) == arr] += 1\n    return grad\n\n\ndef _test_clamp_backward(test_case, shape, device, dtype):\n    x = flow.tensor(\n        np.random.randn(*shape),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    y = flow.clamp(x, 0.1, 0.5).sum()\n    y.backward()\n    test_case.assertTrue(\n        np.allclose(\n            x.grad.numpy(), _numpy_clamp_grad(x.numpy(), 0.1, 0.5), 1e-05, 1e-05\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestClampModule(flow.unittest.TestCase):\n    def test_clamp(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"fun\"] = [\n            _test_clamp,\n            _test_tensor_clamp,\n            _test_clamp_scalar_min,\n            _test_clamp_scalar_max,\n            _test_clamp_backward,\n        ]\n        arg_dict[\"shape\"] = [(2,), (2, 3), (2, 4, 5, 6)]\n        arg_dict[\"device\"] 
= [\"cpu\", \"cuda\"]\n        arg_dict[\"dtype\"] = [flow.float16, flow.float, flow.double]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n        arg_dict[\"fun\"] = [\n            _test_clamp_integral,\n        ]\n        arg_dict[\"dtype\"] = [flow.int8, flow.int, flow.long]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5)\n    def test_clamp_flow_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor().to(device)\n        y = torch.clamp(input, min=random().to(float), max=random().to(float))\n        return y\n\n    @autotest(n=5)\n    def test_clamp_min_none_flow_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor().to(device)\n        y = torch.clamp(input, min=random().to(float), max=random().to(float))\n        return y\n\n    @autotest(n=5)\n    def test_clamp_max_none_flow_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor().to(device)\n        y = torch.clamp(\n            input, min=random().to(float), max=random().to(float) | nothing()\n        )\n        return y\n\n    @profile(torch.clamp)\n    def profile_clamp(test_case):\n        torch.clamp(torch.ones(4), -1, 2)\n        torch.clamp(torch.ones(100000), -1, 2)\n\n    @autotest(n=5)\n    def test_clip_flow_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor().to(device)\n        y = torch.clip(input, min=random().to(float), max=random().to(float))\n        return y\n\n    @autotest(n=5)\n    def test_clip_min_none_flow_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor().to(device)\n        y = torch.clip(input, min=random().to(float), max=random().to(float))\n        return y\n\n    @autotest(n=5)\n    def test_clip_max_none_flow_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor().to(device)\n        y = torch.clip(\n            input, min=random().to(float), max=random().to(float) | nothing()\n        )\n        return y\n\n    @profile(torch.clip)\n    def profile_clip(test_case):\n        torch.clip(torch.ones(4), -1, 2)\n        torch.clip(torch.ones(100000), -1, 2)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_clamp_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 1, 0, 3).to(device)\n        y = torch.clamp(x, min=random().to(float), max=random().to(float))\n        return y\n\n\ndef _test_clamp_min(test_case, shape, device):\n    input = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.clamp_min(input, 0.1)\n    np_out = np.clip(input.numpy(), 0.1, None)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_clamp_min_integral(test_case, shape, device):\n    input = flow.tensor(np.random.randint(3, 10, shape), device=flow.device(device))\n    of_out = flow.clamp_min(input, 1)\n    np_out = np.clip(input.numpy(), 1, None)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_clamp_min_backward(test_case, shape, device):\n    x = flow.tensor(\n        np.random.randn(*shape),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    y = flow.clamp_min(x, 0.1).sum()\n    y.backward()\n    
test_case.assertTrue(\n        np.allclose(\n            x.grad.numpy(), _numpy_clamp_grad(x.numpy(), 0.1, None), 1e-05, 1e-05\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestClampMinModule(flow.unittest.TestCase):\n    def test_clamp_min(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"fun\"] = [\n            _test_clamp_min,\n            _test_clamp_min_integral,\n            _test_clamp_min_backward,\n        ]\n        arg_dict[\"shape\"] = [(2,), (2, 3), (2, 4, 5, 6)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5)\n    def test_clamp_min_flow_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor().to(device)\n        y = torch.clamp_min(input, min=random().to(float))\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_clamp_min_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 1, 0, 3).to(device)\n        y = torch.clamp_min(x, min=random().to(float))\n        return y\n\n\ndef _test_clamp_max(test_case, shape, device):\n    input = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.clamp_max(input, 0.5)\n    np_out = np.clip(input.numpy(), None, 0.5)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_clamp_max_integral(test_case, shape, device):\n    input = flow.tensor(np.random.randint(3, 10, shape), device=flow.device(device))\n    of_out = flow.clamp_max(input, 1)\n    np_out = np.clip(input.numpy(), None, 1)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_clamp_max_backward(test_case, shape, device):\n    x = flow.tensor(\n        np.random.randn(*shape),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    y = flow.clamp_max(x, 0.5).sum()\n    y.backward()\n    test_case.assertTrue(\n        np.allclose(\n            x.grad.numpy(), _numpy_clamp_grad(x.numpy(), None, 0.5), 1e-05, 1e-05\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestClampMaxModule(flow.unittest.TestCase):\n    def test_clamp_min(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"fun\"] = [\n            _test_clamp_max,\n            _test_clamp_max_integral,\n            _test_clamp_max_backward,\n        ]\n        arg_dict[\"shape\"] = [(2,), (2, 3), (2, 4, 5, 6)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5)\n    def test_clamp_max_flow_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor().to(device)\n        y = torch.clamp_max(input, max=random().to(float))\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_clamp_max_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 1, 0, 3).to(device)\n        y = torch.clamp_max(x, max=random().to(float))\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_clip_grad.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _clip_grad_norm_np(input, max_norm, norm_type):\n    np_out = np.maximum(0, input)\n    np_grad = np.array(np_out > 0, dtype=np.float32)\n    max_norm = float(max_norm)\n    norm_type = float(norm_type)\n    input = [input]\n    if len(input) == 0:\n        return 0, 0\n    if norm_type == float(\"inf\"):\n        total_norm = np.max(np.abs(np_grad))\n    if norm_type == float(\"-inf\"):\n        total_norm = np.min(np.abs(np_grad))\n    elif norm_type == 0:\n        total_norm = np.sum(np.stack([np.sum(np_grad != 0)]) != 0)\n    else:\n        total_norm = np_grad\n        for i in range(np_grad.ndim, 0, -1):\n            total_norm = np.linalg.norm(total_norm, norm_type, axis=i - 1)\n    clip_coef = max_norm / (total_norm + 1e-6)\n    if clip_coef < 1:\n        np_grad = np.dot(np_grad, clip_coef)\n    return total_norm, np_grad\n\n\ndef _test_clip_grad_norm_impl(test_case, shape, device, max_norm, norm_type, fused):\n    np_input = np.random.rand(*shape)\n    of_input = flow.tensor(\n        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    m = flow.nn.ReLU()\n    of_out = m(of_input)\n    of_out = of_out.sum()\n    of_out.backward()\n    of_total_norm = flow.nn.utils.clip_grad_norm_(of_input, max_norm, norm_type, fused)\n    np_total_norm, np_grad = _clip_grad_norm_np(np_input, max_norm, norm_type)\n    test_case.assertTrue(\n        np.allclose(of_total_norm.numpy(), np_total_norm, 1e-4, 1e-4, equal_nan=True)\n    )\n    test_case.assertTrue(\n        np.allclose(of_input.grad.numpy(), np_grad, 1e-4, 1e-4, equal_nan=True)\n    )\n\n\ndef _clip_grad_value_np(input, clip_value):\n    np_out = np.maximum(0, input)\n    np_grad = np.array(np_out > 0, dtype=np.float32)\n    clip_value = float(clip_value)\n    if len(input) == 0:\n        return 0, 0\n    np_grad = np.clip(np_grad, -clip_value, clip_value)\n    return np_grad\n\n\ndef _test_clip_grad_value_impl(test_case, shape, device, clip_value):\n    np_input = np.random.rand(*shape)\n    of_input = flow.tensor(\n        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    m = flow.nn.ReLU()\n    of_out = m(of_input)\n    of_out = of_out.sum()\n    of_out.backward()\n    flow.nn.utils.clip_grad_value_(of_input, clip_value)\n    of_grad = of_input.grad.numpy()\n    np_grad = _clip_grad_value_np(np_input, clip_value)\n    test_case.assertTrue(np.allclose(of_grad, np_grad, 1e-4, 1e-4, equal_nan=True))\n\n\nclass ReluGraph(flow.nn.Graph):\n    def __init__(self, clip_value) -> None:\n        super().__init__()\n        self.clip_value = clip_value\n\n    def build(self, x):\n        flow.nn.utils.clip_grad_value_(x, self.clip_value)\n        return x\n\n\ndef 
_test_graph_clip_grad_value_impl(test_case, shape, device, clip_value):\n    np_input = np.random.rand(*shape)\n    of_input = flow.tensor(\n        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_eager_out = of_input\n    flow.nn.utils.clip_grad_value_(of_eager_out, clip_value)\n    relu_graph = ReluGraph(clip_value)\n    of_graph_out = relu_graph(of_input)\n    test_case.assertTrue(\n        np.allclose(\n            of_eager_out.numpy(), of_graph_out.numpy(), 1e-4, 1e-4, equal_nan=True\n        )\n    )\n\n\ndef _test_clip_grad_norm_global_impl(\n    test_case, shape, sbp, placement, max_norm, norm_type\n):\n    of_input = flow.rand(\n        *shape, dtype=flow.float32, sbp=sbp, placement=placement, requires_grad=True\n    )\n    np_input = of_input.to_global(sbp=flow.sbp.broadcast).to_local().numpy()\n\n    m = flow.nn.ReLU()\n    of_out = m(of_input)\n    of_out = of_out.sum()\n    of_out.backward()\n    of_total_norm = flow.nn.utils.clip_grad_norm_(\n        of_input, max_norm, norm_type\n    ).to_local()\n    np_total_norm, np_grad = _clip_grad_norm_np(np_input, max_norm, norm_type)\n    test_case.assertTrue(\n        np.allclose(of_total_norm.numpy(), np_total_norm, 1e-4, 1e-4, equal_nan=True)\n    )\n    test_case.assertTrue(\n        np.allclose(\n            of_input.grad.to_global(sbp=flow.sbp.broadcast).to_local().numpy(),\n            np_grad,\n            1e-4,\n            1e-4,\n            equal_nan=True,\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestClipGrad(flow.unittest.TestCase):\n    def test_clip_grad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"max_norm\"] = [0, 0.5, 1.0]\n        arg_dict[\"norm_type\"] = [\"inf\", \"-inf\", 0.0, 1.0, 2.0, 3.5]\n        arg_dict[\"fused\"] = [False, True]\n        for arg in GenArgList(arg_dict):\n            _test_clip_grad_norm_impl(test_case, *arg)\n\n    def test_clip_value(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"clip_value\"] = [0, 0.5, 1.0]\n        for arg in GenArgList(arg_dict):\n            _test_clip_grad_value_impl(test_case, *arg)\n            _test_graph_clip_grad_value_impl(test_case, *arg)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestClipGradGlobal(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n2d()\n    def test_clip_grad_global(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(2, 4), (2, 4, 3), (2, 4, 5, 6)]\n        arg_dict[\"sbp\"] = [flow.sbp.broadcast, flow.sbp.split(0), flow.sbp.split(1)]\n        arg_dict[\"placement\"] = [\n            flow.placement.all(\"cpu\"),\n            flow.placement.all(\"cuda\"),\n        ]\n        arg_dict[\"max_norm\"] = [0, 0.5, 1.0]\n        arg_dict[\"norm_type\"] = [\"inf\", \"-inf\", 0.0, 1.0, 2.0, 3.5]\n        for arg in GenArgList(arg_dict):\n            _test_clip_grad_norm_global_impl(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_clone.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestClone(flow.unittest.TestCase):\n    @autotest(n=3)\n    def test_clone_with_random_data(test_case):\n        x = random_tensor()\n        y = torch.clone(x)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_coco_reader.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport os\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass COCODataLoader(flow.nn.Module):\n    def __init__(\n        self,\n        anno_file=flow.unittest.dataset_dir(\n            \"mscoco_2017/annotations/instances_val2017.json\"\n        ),\n        image_dir=flow.unittest.dataset_dir(\"mscoco_2017/val2017\"),\n        batch_size=2,\n        device=None,\n        placement=None,\n        sbp=None,\n    ):\n        super().__init__()\n        self.coco_reader = flow.nn.COCOReader(\n            annotation_file=anno_file,\n            image_dir=image_dir,\n            batch_size=batch_size,\n            shuffle=True,\n            random_seed=12345,\n            stride_partition=True,\n            device=device,\n            placement=placement,\n            sbp=sbp,\n        )\n        self.image_decoder = flow.nn.image.decode(dtype=flow.float32)\n        self.resize = flow.nn.image.Resize(target_size=[224, 224], dtype=flow.float32)\n\n    def forward(self):\n        outputs = self.coco_reader()\n        # decode images\n        image = self.image_decoder(outputs[0])\n        fixed_image = self.resize(image)[0]\n        image_id = outputs[1]\n        image_size = outputs[2]\n        return fixed_image, image_id, image_size\n\n\nclass DataLoaderGraph(flow.nn.Graph):\n    def __init__(self, loader):\n        super().__init__()\n        self.loader_ = loader\n\n    def build(self):\n        return self.loader_()\n\n\n@flow.unittest.skip_unless_1n2d()\nclass COCODataLoaderDistributedTestCase(oneflow.unittest.TestCase):\n    def test_case1(test_case):\n        rank = flow.env.get_rank()\n        # pid = os.getpid()\n        # print(f\"[{pid}][{rank}] COCODataLoaderDistributedTestCase.test_case1\")\n\n        eager_coco_loader = COCODataLoader(\n            batch_size=2, device=flow.device(\"cpu\", rank)\n        )\n\n        global_coco_loader = COCODataLoader(\n            batch_size=4,\n            placement=flow.placement(\"cpu\", ranks=[0, 1]),\n            sbp=[flow.sbp.split(0)],\n        )\n        coco_loader_graph = DataLoaderGraph(global_coco_loader)\n        # coco_loader_graph.debug()\n\n        iteration = 1\n        for i in range(iteration):\n            image, image_id, image_size = eager_coco_loader()\n\n            # print(f\"image: {image.numpy().mean()} \")\n            # print(f\"image_id: {image_id.numpy()}\")\n            # print(f\"image_size: {image_size.numpy()}\")\n\n            g_image, g_image_id, g_image_size = coco_loader_graph()\n\n            # print(f\"{'-' * 20} rank {rank} iter {i} complete {'-' * 20}\")\n            test_case.assertTrue(np.allclose(image.numpy(), g_image.to_local().numpy()))\n            test_case.assertTrue(\n                np.allclose(image_id.numpy(), g_image_id.to_local().numpy())\n            )\n            test_case.assertTrue(\n                
np.allclose(image_size.numpy(), g_image_size.to_local().numpy())\n            )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_coin_flip.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_coin_flip_impl(test_case, batch_size, random_seed, probability, device):\n    m = flow.nn.CoinFlip(batch_size, random_seed, probability, device)\n    x = m()\n    test_case.assertEqual(x.shape[0], batch_size)\n    device = flow.device(device)\n    test_case.assertEqual(x.device, device)\n\n\nclass TestCoinFlipModule(flow.unittest.TestCase):\n    def test_coin_flip(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"batch_size\"] = [1, 2, 50]\n        arg_dict[\"random_seed\"] = [None, 1, -1]\n        arg_dict[\"probability\"] = [0.0, 0.5, 1.0]\n        # TODO: CoinFlip support cuda kernel\n        #  arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"device\"] = [\"cpu\"]\n\n        for arg in GenArgDict(arg_dict):\n            _test_coin_flip_impl(test_case, **arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_comb2to2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nfrom oneflow import nn\nimport os\nimport numpy as np\n\nimport oneflow.unittest\n\n\nflow.boxing.nccl.enable_use_compute_stream(False)\n\n\nclass _TestModuleDiffHierarchy(nn.Module):\n    def forward(self, x):\n        sbp_1ds = [\n            flow.sbp.broadcast,\n            flow.sbp.partial_sum,\n            flow.sbp.split(0),\n            flow.sbp.split(1),\n        ]\n\n        for sbp1 in sbp_1ds:\n            for sbp2 in sbp_1ds:\n\n                for sbp3 in sbp_1ds:\n                    for sbp4 in sbp_1ds:\n                        # (3, 2) -> (2, 3)\n                        x = x.to_global(\n                            placement=flow.placement(\n                                type=\"cuda\", ranks=np.array(range(6)).reshape(2, 3)\n                            ),\n                            sbp=[sbp1, sbp2],\n                        )\n                        # (2, 3) -> (3, 2)\n                        x = x.to_global(\n                            placement=flow.placement(\n                                type=\"cuda\", ranks=np.array(range(6)).reshape(3, 2)\n                            ),\n                            sbp=[sbp3, sbp4],\n                        )\n\n        return x\n\n\nclass _TestModuleDiffPlacement(nn.Module):\n    def forward(self, x):\n        sbp_1ds = [\n            flow.sbp.broadcast,\n            flow.sbp.partial_sum,\n            flow.sbp.split(0),\n            flow.sbp.split(1),\n        ]\n\n        for sbp1 in sbp_1ds:\n            for sbp2 in sbp_1ds:\n\n                for sbp3 in sbp_1ds:\n                    for sbp4 in sbp_1ds:\n                        # (3, 2) -> (2, 2)\n                        x = x.to_global(\n                            placement=flow.placement(\n                                type=\"cuda\", ranks=np.array(range(4)).reshape(2, 2)\n                            ),\n                            sbp=[sbp1, sbp2],\n                        )\n                        # (2, 2) -> (3, 2)\n                        x = x.to_global(\n                            placement=flow.placement(\n                                type=\"cuda\", ranks=np.array(range(6)).reshape(3, 2)\n                            ),\n                            sbp=[sbp3, sbp4],\n                        )\n\n        return x\n\n\nclass _TestGraph(nn.Graph):\n    def __init__(self, model):\n        super().__init__()\n        self.model = model\n\n    def build(self, x):\n        x = self.model(x)\n        return x\n\n\n@flow.unittest.skip_unless_2n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestLazyAllSbpCombinationTesting(flow.unittest.TestCase):\n    def test_lazy_boxing_2d_all_combination(test_case):\n        os.environ[\"ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK\"] = \"0\"\n        os.environ[\"ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION\"] = 
\"0\"\n\n        x = flow.ones(\n            12,\n            12,\n            sbp=[flow.sbp.broadcast, flow.sbp.broadcast],\n            placement=flow.placement(\n                type=\"cuda\", ranks=np.array(range(6)).reshape(3, 2)\n            ),\n        )\n\n        model_diff_hierarchy = _TestModuleDiffHierarchy()\n        graph_diff_hierarchy = _TestGraph(model_diff_hierarchy)\n        y = graph_diff_hierarchy(x)\n\n        model_diff_placement = _TestModuleDiffPlacement()\n        graph_diff_placement = _TestGraph(model_diff_placement)\n        z = graph_diff_placement(x)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_combined_margin_loss.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\n\n\ndef _scatter_add_numpy(src, dim, index, outshape):\n    output = np.zeros(outshape)\n    for srcidx in range(0, src.size):\n        outcoord = np.unravel_index(srcidx, src.shape)\n        outcoord = [*outcoord]\n        outcoord[dim] = index[np.unravel_index(srcidx, index.shape)]\n        output_offset = np.ravel_multi_index(outcoord, outshape)\n        output[np.unravel_index(output_offset, outshape)] += src[\n            np.unravel_index(srcidx, src.shape)\n        ]\n    return output\n\n\ndef _np_one_hot(indices, depth):\n    return np.eye(depth)[indices.reshape(-1)]\n\n\ndef _np_gather_with_batch_dims(params, indices, axis):\n    batch_dims = 1\n    result = []\n    for p, i in zip(params, indices):\n        r = np.take_along_axis(p, i, axis - batch_dims)\n        result.append(r)\n    return np.stack(result)\n\n\ndef _np_gather_with_batch_dims_grad(params, indices, axis, output):\n    batch_dims = 1\n    result = []\n    for p, i, o in zip(params, indices, output):\n        r = _scatter_add_numpy(np.ones_like(o), axis - batch_dims, i, p.shape)\n        result.append(r)\n    return np.stack(result)\n\n\ndef _np_combined_margin_loss(np_input, np_label, m1, m2, m3):\n    class_num = np_input.shape[1]\n    if m1 != 1.0 or m2 != 0.0 or m3 != 0.0:\n        if m1 == 1.0 and m2 == 0.0:\n            gt_one_hot = _np_one_hot(np_label, class_num) * m3\n            np_input = np_input - gt_one_hot\n        else:\n            np_label_expand = np.reshape(np_label, (np_label.shape[0], 1))\n            zy = _np_gather_with_batch_dims(np_input, np_label_expand, 0)\n            cos_t = zy * 1\n            t = np.arccos(cos_t)\n            if m1 != 1.0:\n                t = t * m1\n            if m2 > 0.0:\n                t = t + m2\n            body = np.cos(t)\n            if m3 > 0.0:\n                body = body - m3\n            new_zy = body\n            diff = new_zy - zy\n            gt_one_hot = _np_one_hot(np_label, class_num)\n            body = gt_one_hot * diff\n            np_input = np_input + body\n    return np_input\n\n\ndef _np_combined_margin_loss_grad(np_input, np_label, m1, m2, m3):\n    class_num = np_input.shape[1]\n    if m1 != 1.0 or m2 != 0.0 or m3 != 0.0:\n        if m1 == 1.0 and m2 == 0.0:\n            result = np.ones(np_input.shape)\n        else:\n            np_label_expand = np.reshape(np_label, (np_label.shape[0], 1))\n            zy = _np_gather_with_batch_dims(np_input, np_label_expand, 0)\n            dzy = _np_gather_with_batch_dims_grad(np_input, np_label_expand, 0, zy)\n            cos_t = zy * 1\n            t = np.arccos(cos_t)\n            dt = -1 / np.sqrt((1 - cos_t * cos_t)) * dzy\n            if m1 != 1.0:\n                t = t * m1\n                dt = dt * m1\n            
if m2 > 0.0:\n                t = t + m2\n            body = np.cos(t)\n            dbody = -np.sin(t) * dt\n            if m3 > 0.0:\n                body = body - m3\n            new_zy = body\n            diff = new_zy - zy\n            ddiff = dbody - dzy\n            gt_one_hot = _np_one_hot(np_label, class_num)\n            body = gt_one_hot * diff\n            dbody = gt_one_hot * ddiff\n            np_input = np_input + body\n            result = np.ones(np_input.shape) + dbody\n    else:\n        result = np.ones(np_input.shape)\n    return result\n\n\ndef _test_combined_margin_loss(\n    test_case, device_type, input_shape, label_shape, data_type, m1, m2, m3\n):\n    assert device_type in [\"cpu\", \"cuda\"]\n    np_x = np.random.uniform(low=-1, high=1, size=input_shape).astype(np.float32)\n    np_labels = np.random.randint(0, input_shape[1], size=(*label_shape,)).astype(\n        np.int32\n    )\n    x = flow.tensor(np_x, device=device_type, dtype=data_type, requires_grad=True)\n    labels = flow.tensor(np_labels, device=device_type, dtype=flow.int32)\n    loss_func = flow.nn.CombinedMarginLoss(m1, m2, m3).to(flow.device(device_type))\n    output = loss_func(x, labels)\n    output.sum().backward()\n\n    output_ref = _np_combined_margin_loss(np_x, np_labels, m1, m2, m3)\n    test_case.assertTrue(np.allclose(output.numpy(), output_ref, rtol=1e-5, atol=1e-5))\n    input_grad_ref = _np_combined_margin_loss_grad(np_x, np_labels, m1, m2, m3)\n    test_case.assertTrue(\n        np.allclose(x.grad.numpy(), input_grad_ref, rtol=1e-4, atol=1e-4)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCombinedMarginLoss(flow.unittest.TestCase):\n    def test_combined_margin_loss(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_combined_margin_loss]\n        arg_dict[\"device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"input_shape\"] = [(64, 1000)]\n        arg_dict[\"label_shape\"] = [(64,)]\n        arg_dict[\"data_type\"] = [flow.float32]\n        arg_dict[\"m1\"] = [0.3, 1.0]\n        arg_dict[\"m2\"] = [0.5, 0.0]\n        arg_dict[\"m3\"] = [0.4, 0.0]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_comm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom threading import Thread\n\nimport numpy as np\nimport os\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestComm(flow.unittest.TestCase):\n    def _test_send_recv(test_case, x0, src, dst):\n        rank = flow.env.get_rank()\n        if rank == src:\n            x1 = x0\n            flow.comm.send(x1, dst)\n\n            x2 = x0\n            flow.comm.send(x2, dst)\n        elif rank == dst:\n            x1 = flow.comm.recv(src)\n            test_case.assertTrue(np.array_equal(x1.numpy(), x0.numpy()))\n            test_case.assertEqual(x1.device, x0.device)\n\n            x2 = flow.zeros_like(x0)\n            flow.comm.recv(src, out=x2)\n            test_case.assertTrue(np.array_equal(x2.numpy(), x0.numpy()))\n            test_case.assertEqual(x2.device, x0.device)\n        else:\n            # do nothing\n            pass\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_send_recv_2_devices(test_case):\n        x0 = flow.tensor([[1, 2]])\n        test_case._test_send_recv(x0, 0, 1)\n        x0 = x0.to(\"cuda\")\n        test_case._test_send_recv(x0, 1, 0)\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_send_recv_4_devices(test_case):\n        x0 = flow.tensor([[1, 2]])\n        test_case._test_send_recv(x0, 3, 1)\n        x0 = x0.to(\"cuda\")\n        test_case._test_send_recv(x0, 0, 3)\n\n    def _test_send_recv_without_sending_meta(test_case, x0, src, dst):\n        rank = flow.env.get_rank()\n        if rank == src:\n            x1 = x0\n            flow.comm.send(x1, dst, send_meta=False)\n\n            x2 = x0\n            flow.comm.send(x2, dst, send_meta=False)\n        elif rank == dst:\n            x1 = flow.comm.recv(src, shape=x0.shape, dtype=x0.dtype, device=x0.device)\n            test_case.assertTrue(np.array_equal(x1.numpy(), x0.numpy()))\n\n            x2 = flow.zeros_like(x0)\n            flow.comm.recv(\n                src, shape=x0.shape, dtype=x0.dtype, device=x0.device, out=x2\n            )\n            test_case.assertTrue(np.array_equal(x2.numpy(), x0.numpy()))\n        else:\n            # do nothing\n            pass\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_send_recv_without_sending_meta_2_devices(test_case):\n        x0 = flow.tensor([[1, 2]])\n        test_case._test_send_recv_without_sending_meta(x0, 1, 0)\n        x0 = x0.to(\"cuda\")\n        test_case._test_send_recv_without_sending_meta(x0, 0, 1)\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_send_recv_without_sending_meta_4_devices(test_case):\n        x0 = flow.tensor([[1, 2]])\n        test_case._test_send_recv_without_sending_meta(x0, 2, 3)\n        x0 = x0.to(\"cuda\")\n        test_case._test_send_recv_without_sending_meta(x0, 3, 1)\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_comm_in_thread(test_case):\n        def 
threaded_function():\n            rank = flow.env.get_rank()\n            rev = flow.framework.check_point_v2._broadcast_py_object(rank, 0)\n            test_case.assertEqual(rev, 0)\n\n            x = flow.tensor([rank, rank + 1]).to_global(\n                placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.split(0)\n            )\n            test_case.assertTrue(np.array_equal(x.numpy(), np.array([0, 1, 1, 2])))\n            x = flow.tensor([rank, rank + 1])\n            flow.comm.all_reduce(x)\n            test_case.assertTrue(np.array_equal(x.numpy(), np.array([1, 3])))\n\n        thread = Thread(target=threaded_function)\n        thread.start()\n        thread.join()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_comm_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport numpy as np\nimport unittest\nimport os\n\nimport oneflow as flow\nimport oneflow.unittest\n\nimport torch\nimport torch.distributed as dist\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestAllReduce(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n2d()\n    def test_all_reduce_1n2d(test_case):\n        np_arr = np.array([[1, 2], [3, 4]])\n        tensor = flow.tensor(np_arr, device=\"cuda\")\n        flow.comm.all_reduce(tensor)\n        test_case.assertTrue(np.allclose(tensor.numpy(), np_arr * 2))\n\n    @flow.unittest.skip_unless_2n2d()\n    def test_all_reduce_2n2d(test_case):\n        np_arr = np.array([[1, 2], [3, 4]])\n        tensor = flow.tensor(np_arr, device=\"cuda\")\n        flow.comm.all_reduce(tensor)\n        test_case.assertTrue(np.allclose(tensor.numpy(), np_arr * 4))\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestAllGather(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n2d()\n    def test_all_gather_into_tensor_1n2d(test_case):\n        device = \"cuda\"\n        tensor_in = (\n            flow.tensor([[1, 2, 3], [4, 5, 6]], dtype=flow.int64, device=device)\n            + flow.env.get_rank() * 6\n        )\n        tensor_out = flow.zeros(4, 3, dtype=flow.int64, device=device)\n        flow.comm.all_gather_into_tensor(tensor_out, tensor_in)\n        test_case.assertTrue(\n            np.allclose(\n                tensor_out.numpy(),\n                np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),\n            )\n        )\n\n        tensor_out2 = flow.zeros(2, 3, 2, dtype=flow.int64, device=device)\n        flow.comm.all_gather_into_tensor(tensor_out2, tensor_in)\n        test_case.assertTrue(\n            np.allclose(\n                tensor_out2.numpy(),\n                np.array([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]),\n            )\n        )\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_all_gather_1n2d(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array([[2, 3], [4, 5]])\n        elif flow.env.get_rank() == 1:\n            np_arr = np.array([[1, 2], [3, 4]])\n        input = flow.tensor(np_arr, device=\"cuda\", dtype=flow.int32)\n        tensor_list = [flow.zeros(np_arr.shape, dtype=flow.int32) for _ in range(2)]\n        flow.comm.all_gather(tensor_list, input)\n        test_case.assertTrue(\n            np.allclose(tensor_list[0].numpy(), np.array([[2, 3], [4, 5]]))\n        )\n        test_case.assertTrue(\n            np.allclose(tensor_list[1].numpy(), np.array([[1, 2], [3, 4]]))\n        )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestBroadCast(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n2d()\n    def test_broadcast_1n2d(test_case):\n        if flow.env.get_rank() == 0:\n            
np_arr = np.array([[1, 2], [3, 4]])\n        elif flow.env.get_rank() == 1:\n            np_arr = np.array([[4, 5], [6, 7]])\n        tensor = flow.tensor(np_arr, device=\"cuda\", dtype=flow.int32)\n        flow.comm.broadcast(tensor, 1)\n        test_case.assertTrue(np.allclose(tensor.numpy(), np.array([[4, 5], [6, 7]])))\n\n        tensor = flow.tensor(np_arr, device=\"cuda\", dtype=flow.int32)\n        flow.comm.broadcast(tensor, 0)\n        test_case.assertTrue(np.allclose(tensor.numpy(), np.array([[1, 2], [3, 4]])))\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestScatter(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n4d()\n    def test_scatter_1n4d(test_case):\n        output = flow.tensor([[1, 2], [3, 4]], device=\"cuda\")\n        if flow.env.get_rank() == 1:\n            tensor_list = [\n                flow.tensor([[5, 6], [7, 8]], device=\"cuda\") + i for i in range(4)\n            ]\n            flow.comm.scatter(output, tensor_list, src=1)\n            test_case.assertTrue(\n                np.allclose(output.numpy(), np.array([[6, 7], [8, 9]]))\n            )\n        else:\n            flow.comm.scatter(output, src=1)\n            test_case.assertTrue(\n                np.allclose(\n                    output.numpy(), np.array([[5, 6], [7, 8]]) + flow.env.get_rank()\n                )\n            )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGather(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n4d()\n    def test_gather_1n4d(test_case):\n        np_arr = np.array([[1, 2], [3, 4]])\n        if flow.env.get_rank() == 1:\n            input = flow.tensor(\n                np_arr + flow.env.get_rank(), device=\"cuda\", dtype=flow.int32\n            )\n            tensor_list = [flow.zeros(np_arr.shape, dtype=flow.int32) for _ in range(4)]\n            flow.comm.gather(input, gather_list=tensor_list, dst=1)\n            for i in range(4):\n                test_case.assertTrue(\n                    np.allclose(tensor_list[i].numpy(), np.array([[1, 2], [3, 4]]) + i)\n                )\n        else:\n            input = flow.tensor(\n                np_arr + flow.env.get_rank(), device=\"cuda\", dtype=flow.int32\n            )\n            flow.comm.gather(input, dst=1)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestReduce(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n2d()\n    def test_reduce_1n2d(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array([[1, 2], [3, 4]])\n        elif flow.env.get_rank() == 1:\n            np_arr = np.array([[4, 5], [6, 7]])\n        tensor = flow.tensor(np_arr, device=\"cuda\", dtype=flow.int32)\n        flow.comm.reduce(tensor, 0)\n        if flow.env.get_rank() == 0:\n            test_case.assertTrue(\n                np.allclose(tensor.numpy(), np.array([[5, 7], [9, 11]]))\n            )\n        else:\n            test_case.assertTrue(\n                np.allclose(tensor.numpy(), np.array([[4, 5], [6, 7]]))\n            )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestAllToAll(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n4d()\n    def test_all_to_all_1n4d(test_case):\n        input_list = [\n            flow.tensor([0, 1], device=\"cuda\") + i * 2 + flow.env.get_rank() * 8\n            for i in range(4)\n        ]\n        output_list = [flow.tensor([0, 1], device=\"cuda\") 
for _ in range(4)]\n        flow.comm.all_to_all(output_list, input_list)\n        for i in range(len(output_list)):\n            test_case.assertTrue(\n                np.allclose(\n                    output_list[i].numpy(),\n                    input_list[i].numpy() + (i - flow.env.get_rank()) * 6,\n                )\n            )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestReduceScatter(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n4d()\n    def test_reduce_scatter_1n4d(test_case):\n        output = flow.tensor([[0, 0], [0, 0]], device=\"cuda\")\n        tensor_list = [\n            flow.tensor([[1, 2], [3, 4]], device=\"cuda\") + flow.env.get_rank() + i\n            for i in range(4)\n        ]\n        flow.comm.reduce_scatter(output, tensor_list)\n        test_case.assertTrue(\n            np.allclose(output.numpy(), tensor_list[0].numpy() * 4 + 6)\n        )\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_reduce_scatter_tensor_1n2d(test_case):\n        tensor_in = flow.tensor(\n            [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],\n            dtype=flow.int64,\n            device=\"cuda\",\n        )\n        tensor_out = flow.zeros(2, 3, dtype=flow.int64, device=\"cuda\")\n        flow.comm.reduce_scatter_tensor(tensor_out, tensor_in)\n        if flow.env.get_rank() == 0:\n            test_case.assertTrue(\n                np.allclose(tensor_out.numpy(), np.array([[2, 4, 6], [8, 10, 12]]),)\n            )\n        else:\n            test_case.assertTrue(\n                np.allclose(tensor_out.numpy(), np.array([[14, 16, 18], [20, 22, 24]]),)\n            )\n        tensor_in2 = tensor_in.reshape(2, 3, 2)\n        tensor_out2 = flow.zeros(2, 3, dtype=flow.int64, device=\"cuda\")\n        flow.comm.reduce_scatter_tensor(tensor_out2, tensor_in2)\n        if flow.env.get_rank() == 0:\n            test_case.assertTrue(\n                np.allclose(tensor_out2.numpy(), np.array([[2, 4, 6], [8, 10, 12]]),)\n            )\n        else:\n            test_case.assertTrue(\n                np.allclose(\n                    tensor_out2.numpy(), np.array([[14, 16, 18], [20, 22, 24]]),\n                )\n            )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestDocs(flow.unittest.TestCase):\n    def test_docs(test_case):\n        oneflow.framework.unittest.check_multi_rank_docstr(oneflow.comm.comm_ops)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_concat.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_concat_origin(test_case, device):\n    input1 = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    input2 = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.cat([input1, input2], dim=0)\n    np_out = np.concatenate((input1.numpy(), input2.numpy()), axis=0)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_concat_with_empty_input(test_case, device):\n    input1 = flow.Tensor().to(flow.device(device))\n    input2 = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out1 = flow.cat([input1, input2], dim=0)\n    of_out2 = flow.cat([input2, input1], dim=0)\n    of_out3 = flow.cat([input1, input2, input1, input1], dim=0)\n\n    torch_input1 = torch.Tensor().to(torch.device(device))\n    torch_input2 = torch.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=torch.float32, device=torch.device(device)\n    )\n    torch_out1 = torch.cat((torch_input1, torch_input2), 0)\n    torch_out2 = torch.cat((torch_input2, torch_input1), 0)\n    torch_out3 = torch.cat((torch_input1, torch_input2, torch_input1, torch_input1), 0)\n\n    test_case.assertTrue(\n        np.array_equal(of_out1.numpy(), torch_out1.detach().cpu().numpy())\n    )\n    test_case.assertTrue(\n        np.array_equal(of_out2.numpy(), torch_out2.detach().cpu().numpy())\n    )\n    test_case.assertTrue(\n        np.array_equal(of_out3.numpy(), torch_out3.detach().cpu().numpy())\n    )\n    test_case.assertTrue(\n        np.array_equal(of_out1.numpy(), torch_out2.detach().cpu().numpy())\n    )\n    test_case.assertTrue(\n        np.array_equal(of_out1.numpy(), torch_out3.detach().cpu().numpy())\n    )\n\n\ndef _test_concat_with_axis_one(test_case, device):\n    input1 = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    input2 = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.cat([input1, input2], dim=1)\n    np_out = np.concatenate((input1.numpy(), input2.numpy()), axis=1)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_concat_with_three_tensor(test_case, device):\n    input1 = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    input2 = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    input3 = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, 
device=flow.device(device)\n    )\n    of_out = flow.cat([input1, input2, input3], dim=1)\n    np_out = np.concatenate((input1.numpy(), input2.numpy(), input3.numpy()), axis=1)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_concat_with_three_tensor_backward(test_case, device):\n    input1 = flow.tensor(\n        np.random.randn(2, 6, 5, 3),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    input2 = flow.tensor(\n        np.random.randn(2, 6, 5, 3),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    input3 = flow.tensor(\n        np.random.randn(2, 6, 5, 3),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out = flow.cat([input1, input2, input3], dim=1)\n    of_out = of_out.sum()\n    of_out.backward()\n    test_case.assertTrue(\n        np.allclose(input1.grad.numpy(), np.ones((2, 6, 5, 3)), 0.0001, 0.0001)\n    )\n    test_case.assertTrue(\n        np.allclose(input2.grad.numpy(), np.ones((2, 6, 5, 3)), 0.0001, 0.0001)\n    )\n    test_case.assertTrue(\n        np.allclose(input3.grad.numpy(), np.ones((2, 6, 5, 3)), 0.0001, 0.0001)\n    )\n\n\ndef _test_concat_grad_and_no_grad(test_case, device):\n    input1 = flow.tensor(\n        np.random.randn(2, 6, 5, 3),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    input2 = flow.tensor(\n        np.random.randn(2, 6, 5, 3),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=False,\n    )\n    of_out = flow.cat([input1, input2], dim=1)\n    of_out = of_out.sum()\n    of_out.backward()\n    test_case.assertTrue(\n        np.allclose(input1.grad.numpy(), np.ones((2, 6, 5, 3)), 0.0001, 0.0001)\n    )\n\n\ndef _test_concat_single_input_type(test_case, device):\n    torch_list = [torch.Tensor([1, 1, 9, 1])]\n    torch_list = [t.to(dtype=torch.int64, device=device) for t in torch_list]\n\n    flow_list = [flow.Tensor([1, 1, 9, 1])]\n    flow_list = [t.to(dtype=flow.int64, device=device) for t in flow_list]\n    flow_cat_list = flow.cat(flow_list)\n    test_case.assertTrue(flow_cat_list.dtype is oneflow.int64)\n\n\ndef _test_concat_grad_fn_name(test_case, device):\n    x1 = flow.randn(2, 3, requires_grad=True, device=device)\n    x2 = flow.randn(2, 3, requires_grad=True, device=device)\n    cat = flow.cat([x1, x2], dim=1)\n    grad_fn_name = cat.grad_fn.name()\n    test_case.assertEqual(grad_fn_name, \"catBackward\")\n    test_case.assertEqual(cat.grad_fn.next_functions[0][0].name(), \"accumulategrad\")\n    next_fn = cat.grad_fn.next_functions[0]\n    test_case.assertTrue(\n        np.allclose(next_fn[0].variable.numpy(), x1.numpy(), 0.0001, 0.0001)\n    )\n    next_fn = cat.grad_fn.next_functions[1]\n    test_case.assertTrue(\n        np.allclose(next_fn[0].variable.numpy(), x2.numpy(), 0.0001, 0.0001)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModule(flow.unittest.TestCase):\n    def test_concat(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_concat_origin,\n            _test_concat_with_empty_input,\n            _test_concat_with_axis_one,\n            _test_concat_with_three_tensor,\n            _test_concat_with_three_tensor_backward,\n            _test_concat_grad_and_no_grad,\n            _test_concat_single_input_type,\n            _test_concat_grad_fn_name,\n        ]\n        
arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5)\n    def test_cat_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim0=random(), dim1=random()).to(device)\n        return torch.cat((x, x, x), random(0, 2).to(int))\n\n    @autotest(n=5, check_graph=True, check_dtype=True)\n    def test_cat_with_diff_dtypes(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim0=random(), dim1=random()).to(device).float()\n        y = x.int()\n        z = x.double()\n        return torch.cat((x, y, z), random(0, 2).to(int))\n\n    @autotest(n=1, check_graph=True, check_dtype=True)\n    def test_cat_with_diff_dtype_corner_case(test_case):\n        device = random_device()\n        input_list = list()\n        x = random_tensor(ndim=2, dim0=random(), dim1=random()).to(device)\n        y = x.int()\n        for i in range(128):\n            input_list.append(x)\n        for j in range(128, 257):\n            input_list.append(y)\n        return torch.cat(tuple(input_list), random(0, 2).to(int))\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_concat_with_input_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 3, 2, 4).to(device)\n        y = random_tensor(4, 2, 3, random(0, 3), 4).to(device)\n        z = torch.cat((x, y), dim=2)\n        return z\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_concat_with_output_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 0, 2, 4).to(device)\n        y = random_tensor(4, 2, 0, 2, 4).to(device)\n        dim = random(0, 4).to(int).value()\n        z = torch.cat((x, y), dim=dim)\n        return z\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_cat_bool_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim0=random(), dim1=random()).to(device, torch.bool)\n        return torch.cat((x, x, x), random(0, 2).to(int))\n\n    @autotest(n=5, check_graph=True)\n    def test_cat_only_one_tensor(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 3, random(0, 3)).to(device)\n        return torch.cat((x,), 0)\n\n    @profile(torch.cat)\n    def profile_cat(test_case):\n        input = torch.ones(100, 100)\n        torch.cat((input, input), dim=0)\n        torch.cat((input, input), dim=1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_constant.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\n\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_different_dtype(test_case, device, shape):\n    y1 = flow.ones(shape, dtype=flow.int32, device=flow.device(device))\n    test_case.assertTrue(np.array_equal(np.ones(shape, dtype=np.int32), y1.numpy()))\n    y2 = flow.ones(shape, dtype=flow.uint8, device=flow.device(device))\n    test_case.assertTrue(np.array_equal(np.ones(shape, dtype=np.uint8), y2.numpy()))\n    y3 = flow.ones(shape, dtype=flow.float64, device=flow.device(device))\n    test_case.assertTrue(np.array_equal(np.ones(shape, dtype=np.float64), y3.numpy()))\n    y4 = flow.ones(shape, dtype=flow.short, device=flow.device(device))\n    test_case.assertTrue(np.array_equal(np.ones(shape, dtype=np.short), y4.numpy()))\n    y5 = flow.ones(shape, dtype=flow.int16, device=flow.device(device))\n    test_case.assertTrue(np.array_equal(np.ones(shape, dtype=np.int16), y5.numpy()))\n    y6 = flow.ones(shape, dtype=flow.char, device=flow.device(device))\n    test_case.assertTrue(np.array_equal(np.ones(shape, dtype=np.int8), y6.numpy()))\n    y7 = flow.ones(shape, dtype=flow.int8, device=flow.device(device))\n    test_case.assertTrue(np.array_equal(np.ones(shape, dtype=np.int8), y7.numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestConstantModule(flow.unittest.TestCase):\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_flow_zeros_list_with_random_data(test_case):\n        device = random_device()\n        y1 = torch.zeros(random().to(int)).to(device)\n        y2 = torch.zeros(random().to(int), random().to(int)).to(device)\n        y3 = torch.zeros(random().to(int), random().to(int), random().to(int)).to(\n            device\n        )\n        y4 = torch.zeros(\n            random().to(int), random().to(int), random().to(int), random().to(int)\n        ).to(device)\n        return y1, y2, y3, y4\n\n    @profile(torch.zeros)\n    def profile_zeros(test_case):\n        torch.zeros(2, 3)\n        torch.zeros(32, 3, 128, 128)\n        torch.zeros(1000, 1000)\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_flow_ones_list_with_random_data(test_case):\n        device = random_device()\n        y1 = torch.ones(random().to(int)).to(device)\n        y2 = torch.ones(random().to(int), random().to(int)).to(device)\n        y3 = torch.ones(random().to(int), random().to(int), random().to(int)).to(device)\n        y4 = torch.ones(\n            random().to(int), random().to(int), random().to(int), random().to(int)\n        ).to(device)\n        return y1, y2, y3, y4\n\n    @profile(torch.ones)\n    def profile_ones(test_case):\n        torch.ones(2, 3)\n        torch.ones(32, 3, 128, 128)\n        torch.ones(1000, 1000)\n\n    
@autotest(auto_backward=False, check_graph=True)\n    def test_flow_zeros_like_list_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.zeros_like(x)\n        return y\n\n    @profile(torch.zeros_like)\n    def profile_zeros_like(test_case):\n        input1 = torch.ones(32, 3, 128, 128)\n        input2 = torch.ones(1000, 1000)\n        input3 = torch.ones(2, 3)\n        torch.zeros_like(input1)\n        torch.zeros_like(input2)\n        torch.zeros_like(input3)\n\n    @autotest(auto_backward=True, check_graph=True)\n    def test_flow_zeros_like_list_with_random_data_and_requires_grad(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.zeros_like(x, requires_grad=True)\n        return y\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_flow_zeros_like_list_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.zeros_like(x)\n        return y\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_flow_ones_like_list_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.ones_like(x)\n        return y\n\n    @profile(torch.ones_like)\n    def profile_ones_like(test_case):\n        input1 = torch.ones(32, 3, 128, 128)\n        input2 = torch.ones(1000, 1000)\n        input3 = torch.ones(2, 3)\n        torch.ones_like(input1)\n        torch.ones_like(input2)\n        torch.ones_like(input3)\n\n    @autotest(auto_backward=True, check_graph=True)\n    def test_flow_ones_like_list_with_random_data_and_requires_grad(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.ones_like(x, requires_grad=True)\n        return y\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_flow_ones_like_list_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.ones_like(x)\n        return y\n\n    @autotest(auto_backward=True, check_graph=True)\n    def test_flow_new_ones_list_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.new_ones(\n            (random().to(int), random().to(int), random().to(int)),\n            device=device.value(),\n            requires_grad=constant(True),\n        )\n        return y\n\n    @profile(torch.Tensor.new_ones)\n    def profile_new_ones(test_case):\n        x = torch.Tensor(np.ones((1, 2, 3)))\n        x.new_ones((2, 3))\n        x.new_ones((32, 3, 128, 128))\n        x.new_ones((1000, 1000, 1000, 1000))\n\n    @unittest.skip(\"skip for now, because it failed 10 times in the past week\")\n    @autotest(auto_backward=True, check_graph=True)\n    def test_flow_new_ones_list_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = x.new_ones(\n            (random().to(int), random().to(int), random().to(int)),\n            device=device.value(),\n            requires_grad=constant(True),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_new_zeros(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.new_zeros(\n            (random().to(int), random().to(int), random().to(int)),\n            device=device.value(),\n            requires_grad=constant(True),\n        )\n        return y\n\n    @profile(torch.Tensor.new_zeros)\n    def profile_new_zeros(test_case):\n        x = torch.Tensor(np.ones((1, 2, 3)))\n        x.new_zeros((2, 3))\n        x.new_zeros((32, 3, 128, 128))\n        x.new_zeros((1000, 1000, 1000, 1000))\n\n    @autotest(n=5)\n    def test_new_full(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.new_full(\n            (random().to(int), random().to(int), random().to(int)),\n            random().to(float).value(),\n            device=device.value(),\n            requires_grad=constant(True),\n        )\n        return y\n\n    @autotest(n=5, auto_backward=False)\n    def test_new_full_with_scalar(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.new_full([], random().to(int))\n        return y\n\n    @autotest(n=5, auto_backward=False)\n    def test_full_with_scalar(test_case):\n        device = random_device()\n        y = torch.full([], random().to(int), device=device)\n        return y\n\n    @autotest(n=10, auto_backward=True)\n    def test_full_with_random_data_int(test_case):\n        device = random_device()\n        shape = random_tensor(low=1, high=6, requires_grad=False).pytorch.shape\n        y = torch.full(shape, 2.0, requires_grad=True)\n        return y\n\n    @autotest(n=5)\n    def test_full_with_random_data_numpy_scalar(test_case):\n        device = random_device()\n        shape = random_tensor(low=1, high=6, requires_grad=False).pytorch.shape\n        y = torch.full(shape, np.array([2.0])[0], device=device, requires_grad=True)\n        return y\n\n    @autotest(n=5)\n    def test_full_with_scalar_tensor(test_case):\n        device = random_device()\n        shape = random_tensor(low=0, high=6, requires_grad=False).pytorch.shape\n        y = torch.full(\n            shape,\n            torch.tensor(2.0, requires_grad=random().to(bool)),\n            device=device,\n            requires_grad=True,\n        )\n        return y\n\n    @profile(torch.full)\n    def profile_full_with_scalar_tensor(test_case):\n        torch.full((2, 3), torch.tensor(3.141592))\n        torch.full((64, 3, 128, 128), torch.tensor(3.141592))\n        torch.full((1000, 1000), torch.tensor(3.141592))\n\n    @profile(torch.full)\n    def profile_full(test_case):\n        torch.full((2, 3), 3.141592)\n        torch.full((64, 3, 128, 128), 3.141592)\n        torch.full((1000, 1000), 3.141592)\n\n    @autotest(n=10, auto_backward=True)\n    def test_full_with_random_data_float(test_case):\n        device = random_device()\n        shape = random_tensor(low=1, high=6, requires_grad=False).pytorch.shape\n        y = torch.full(shape, 2.0, requires_grad=True)\n        return y\n\n    @autotest(n=10, auto_backward=True)\n    def test_full_like_with_random_data_float(test_case):\n        device = random_device()\n        x = random_tensor(low=1, high=6, requires_grad=False).to(device)\n        y = torch.full_like(x, 2.0, requires_grad=True)\n        return y\n\n    @profile(torch.full_like)\n    def profile_full_like(test_case):\n        torch.full_like(torch.ones(2, 3), 3.141592)\n        torch.full_like(torch.ones(64, 3, 128, 128), 3.141592)\n        torch.full_like(torch.ones(1000, 1000), 3.141592)\n\n    def test_cast(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_different_dtype,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 3, 
4, 5), (2, 0, 4)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_constant_pad.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\n\nfrom random import choice\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.nn.common_types import _size_2_t, _size_4_t, _size_6_t\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestConstantPad1d(flow.unittest.TestCase):\n    @autotest(n=10, rtol=0.001, atol=0.001, include_complex=True)\n    def test_constantpad1d_with_random_data(test_case):\n        m = torch.nn.ConstantPad1d(\n            padding=random(1, 6).to(_size_2_t), value=random().to(float)\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=3, dim1=random(1, 6), dim2=random(1, 6)).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=10, rtol=0.001, atol=0.001, auto_backward=False)\n    def test_constantpad1d_with_random_int_data(test_case):\n        dtype = choice([int, bool])\n        value = random(0, 2).to(bool) if dtype is bool else random().to(int)\n        m = torch.nn.ConstantPad1d(padding=random(1, 6).to(_size_2_t), value=value)\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=3, dim1=random(1, 6), dim2=random(1, 6), dtype=int).to(\n            device\n        )\n        if dtype is bool:\n            x = x.bool()\n        y = m(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestConstantPad2d(flow.unittest.TestCase):\n    @autotest(n=10, rtol=0.001, atol=0.001, include_complex=True)\n    def test_constantpad2d_with_random_data(test_case):\n        m = torch.nn.ConstantPad2d(\n            padding=random(1, 6).to(_size_4_t), value=random().to(float)\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(\n            ndim=4, dim1=random(1, 6), dim2=random(1, 6), dim3=random(1, 6)\n        ).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=10, rtol=0.001, atol=0.001, auto_backward=False)\n    def test_constantpad2d_with_random_int_data(test_case):\n        dtype = choice([int, bool])\n        value = random(0, 2).to(bool) if dtype is bool else random().to(int)\n        m = torch.nn.ConstantPad2d(padding=random(1, 6).to(_size_4_t), value=value,)\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(\n            ndim=4, dim1=random(1, 6), dim2=random(1, 6), dim3=random(1, 6)\n        ).to(device)\n        if dtype is bool:\n            x = x.bool()\n        y = m(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestConstantPad3d(flow.unittest.TestCase):\n    @autotest(n=10, rtol=0.001, atol=0.001, include_complex=True)\n    def test_constantpad3d_with_random_data(test_case):\n        m = torch.nn.ConstantPad3d(\n            
padding=random(1, 6).to(_size_6_t), value=random().to(float)\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(\n            ndim=5,\n            dim1=random(1, 6),\n            dim2=random(1, 6),\n            dim3=random(1, 6),\n            dim4=random(1, 6),\n        ).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=10, rtol=0.001, atol=0.001, auto_backward=False)\n    def test_constantpad3d_with_random_int_data(test_case):\n        dtype = choice([bool, int])\n        value = random(0, 2).to(bool) if dtype is bool else random().to(int)\n        m = torch.nn.ConstantPad3d(padding=random(1, 6).to(_size_6_t), value=value,)\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(\n            ndim=5,\n            dim1=random(1, 6),\n            dim2=random(1, 6),\n            dim3=random(1, 6),\n            dim4=random(1, 6),\n        ).to(device)\n        if dtype is bool:\n            x = x.bool()\n        y = m(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFunctionalConstantPad2d(flow.unittest.TestCase):\n    @autotest(n=10, rtol=0.001, atol=0.001, check_graph=True, include_complex=True)\n    def test_functional_constantpad2d(test_case):\n        device = random_device()\n        padding = random(-1, 6).to(_size_4_t)\n        value = random().to(float)\n        x = random_tensor(\n            ndim=4,\n            dim0=random(1, 6),\n            dim1=random(1, 6),\n            dim2=random(2, 6),\n            dim3=random(2, 6),\n        ).to(device)\n        y = torch.nn.functional.pad(x, pad=padding, mode=\"constant\", value=value)\n        return y\n\n    @autotest(n=10, rtol=0.001, atol=0.001, check_graph=True, auto_backward=False)\n    def test_functional_constantpad2d_int_data(test_case):\n        dtype = choice([bool, int])\n        device = random_device()\n        padding = random(-1, 6).to(_size_4_t)\n        value = random(0, 2).to(bool) if dtype is bool else random().to(int)\n        x = random_tensor(\n            ndim=4,\n            dim0=random(1, 6),\n            dim1=random(1, 6),\n            dim2=random(2, 6),\n            dim3=random(2, 6),\n        ).to(device)\n        if dtype is bool:\n            x = x.bool()\n        y = torch.nn.functional.pad(x, pad=padding, mode=\"constant\", value=value)\n        return y\n\n    @profile(torch.nn.functional.pad)\n    def profile_pad(test_case):\n        tensor = torch.ones(32, 3, 128, 128)\n        pad = (1, 1)\n        torch.nn.functional.pad(tensor, pad)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_contiguous.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nfrom random import shuffle\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\nimport oneflow.unittest\nimport oneflow as flow\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestContiguous(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_transpose_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int))\n        z = y.contiguous()\n        return z\n\n    @autotest(n=5, auto_backward=False)\n    def test_transpose_with_bool_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, requires_grad=False).to(device).to(torch.bool)\n        y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int))\n        z = y.contiguous()\n        return z\n\n    @autotest(n=5, auto_backward=False)\n    def test_transpose_with_int_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, requires_grad=False).to(device).to(torch.int)\n        y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int))\n        z = y.contiguous()\n        return z\n\n    @autotest(n=5, auto_backward=False)\n    def test_contiguous_with_half_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, requires_grad=False).to(device).to(torch.float16)\n        y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int))\n        z = y.contiguous()\n        return z\n\n    @autotest(n=10, check_graph=True)\n    def test_permute2d_tensor_with_random_data(test_case):\n        device = random_device()\n        ndim = 2\n        permute_list = [0, 1]\n        shuffle(permute_list)\n        x = random_tensor(\n            ndim=ndim, dim0=random(1, 32).to(int), dim1=random(1, 59).to(int),\n        ).to(device)\n        y = x.permute(permute_list)\n        z = y.contiguous()\n        return z\n\n    @autotest(n=10, check_graph=True)\n    def test_permute3d_tensor_with_random_data(test_case):\n        device = random_device()\n        ndim = 3\n        permute_list = [0, 1, 2]\n        shuffle(permute_list)\n        x = random_tensor(\n            ndim=ndim,\n            dim0=random(1, 7).to(int),\n            dim1=random(1, 15).to(int),\n            dim2=random(1, 9).to(int),\n        ).to(device)\n        y = x.permute(permute_list)\n        z = y.contiguous()\n        return z\n\n    @autotest(n=10, check_graph=True)\n    def test_permute4d_tensor_with_random_data(test_case):\n        device = random_device()\n        ndim = 4\n        permute_list = [0, 1, 2, 3]\n        shuffle(permute_list)\n        x = random_tensor(\n            ndim=ndim,\n            dim0=random(1, 7).to(int),\n            
dim1=random(1, 15).to(int),\n            dim2=random(1, 9).to(int),\n            dim3=random(1, 19).to(int),\n        ).to(device)\n        y = x.permute(permute_list)\n        z = y.contiguous()\n        return z\n\n    @profile(torch.Tensor.contiguous)\n    def profile_contiguous(test_case):\n        x = torch.ones(32, 3, 128, 128)\n        x.contiguous()\n\n\ndef _test_inplace_contiguous(test_case, device):\n    arr = np.random.randn(4, 5, 6, 7).astype(np.float32)\n    input = flow.tensor(arr, device=device)\n    x = input.permute(0, 3, 2, 1)  # x is a non-contiguous tensor\n    test_case.assertFalse(x.is_contiguous())\n    # y1 is the result of the out-of-place contiguous()\n    y1 = x.contiguous()\n    # y2 is the result of the in-place contiguous_()\n    y2 = x.contiguous_()\n    test_case.assertTrue(np.array_equal(y1.cpu().numpy(), y2.cpu().numpy()))\n    test_case.assertTrue(id(x) != id(y1))\n    test_case.assertTrue(id(x) == id(y2))\n    test_case.assertTrue(x.is_contiguous())\n    test_case.assertTrue(y1.is_contiguous())\n    test_case.assertTrue(y2.is_contiguous())\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestInplaceContiguous(flow.unittest.TestCase):\n    def test_inplace_contiguous(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_inplace_contiguous,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_conv1d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\nimport torch as torch_original\nfrom packaging import version\n\n\ndef _test_conv1d_bias_false(test_case, device):\n    np_arr = np.array([[[1.28795946, -0.2921792, 0.20338029, 0.78604293, -1.89607573]]])\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    weight = np.array(\n        [\n            [[0.10197904, 0.3372305, -0.25743008]],\n            [[0.27720425, -0.52435774, -0.38381988]],\n            [[0.56016803, -0.10063095, -0.10760903]],\n        ]\n    )\n    m = nn.Conv1d(1, 3, 3, stride=1, bias=False)\n    m.weight = flow.nn.Parameter(flow.Tensor(weight))\n    m = m.to(device)\n    output = m(input)\n    np_out = np.array(\n        [\n            [\n                [-0.01954307, -0.16356121, 0.77392507],\n                [0.43217283, -0.48933625, 0.37196174],\n                [0.72899038, -0.2687211, 0.23886177],\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06))\n    output = output.sum()\n    output.backward()\n    np_grad = np.array(\n        [[[0.93935132, 0.65159315, -0.09726584, -1.03661716, -0.74885899]]]\n    )\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06))\n\n\ndef _test_conv1d_bias_true(test_case, device):\n    np_arr = np.array(\n        [\n            [\n                [0.90499806, -1.11683071, 0.71605605, -0.56754625, 0.61944169],\n                [-0.31317389, -0.26271924, 0.95579433, 0.52468461, 1.48926127],\n            ]\n        ]\n    )\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    weight = np.array(\n        [\n            [\n                [0.01997352, 0.23834395, 0.00526353],\n                [-0.04861857, -0.22751901, -0.06725175],\n            ],\n            [\n                [0.13344523, -0.35202524, 0.15168799],\n                [-0.25714493, -0.17459838, 0.28768948],\n            ],\n            [\n                [0.10671382, -0.28205597, -0.39752254],\n                [0.36393702, 0.07843742, -0.33898622],\n            ],\n            [\n                [0.20485674, 0.04222689, -0.1898618],\n                [0.22519711, -0.15910202, -0.35057363],\n            ],\n        ]\n    )\n    bias = np.array([0.01012857, 0.38912651, -0.01600273, -0.3883304])\n    m = nn.Conv1d(2, 4, 3, stride=1, bias=True)\n    m.weight = flow.nn.Parameter(flow.Tensor(weight))\n    m.bias = flow.nn.Parameter(flow.Tensor(bias))\n    m = m.to(device)\n    np_out = np.array(\n        [\n            [\n                [-0.22349545, -0.08447243, -0.37358052],\n      
          [1.4130373, -0.04644597, 0.86949122],\n                [-0.34765026, -0.31004351, -0.14158708],\n                [-0.74985039, -0.87430149, -0.77354753],\n            ]\n        ]\n    )\n    output = m(input)\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06))\n    output = output.sum()\n    output.backward()\n    np_grad = np.array(\n        [\n            [\n                [0.4649893, 0.11147892, -0.3189539, -0.78394318, -0.43043283],\n                [0.28337064, -0.19941133, -0.66853344, -0.95190406, -0.46912211],\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06))\n\n\ndef _test_conv1d_dilation(test_case, device):\n    np_arr = np.array(\n        [[[-0.43016902, 1.74619496, -0.57338119, 0.25563857, 0.12575546]]]\n    )\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    weight = np.array(\n        [\n            [[-0.35057205, -0.31304273, 0.46250814]],\n            [[-0.40786612, 0.36518192, 0.46280444]],\n            [[-0.00921835, -0.38710043, 0.47566161]],\n        ]\n    )\n    m = nn.Conv1d(1, 3, 3, stride=1, bias=False)\n    m.weight = flow.nn.Parameter(flow.Tensor(weight))\n    m = m.to(device)\n    output = m(input)\n    np_out = np.array(\n        [\n            [\n                [-0.66102189, -0.31443936, 0.17914855],\n                [0.54776692, -0.8032915, 0.38541752],\n                [-0.94472277, 0.32745653, -0.03385513],\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06))\n    output = output.sum()\n    output.backward()\n    np_grad = np.array(\n        [[[-0.76765651, -1.10261774, 0.29835641, 1.06601286, 1.40097415]]]\n    )\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06))\n\n\ndef _test_conv1d_stride(test_case, device):\n    np_arr = np.array(\n        [[[-1.01312506, -0.40687919, 1.5985316, 0.53594196, -1.89935565]]]\n    )\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    weight = np.array(\n        [\n            [[0.5751484, 0.26589182, -0.026546]],\n            [[-0.10313249, -0.20797005, -0.48268208]],\n            [[-0.22216944, -0.14962578, 0.57433963]],\n        ]\n    )\n    m = nn.Conv1d(1, 3, 3, stride=2, bias=False)\n    m.weight = flow.nn.Parameter(flow.Tensor(weight))\n    m = m.to(device)\n    output = m(input)\n    np_out = np.array(\n        [\n            [\n                [-0.73331773, 1.11231577],\n                [-0.58247775, 0.64046454],\n                [1.20406508, -1.5262109],\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06))\n    output = output.sum()\n    output.backward()\n    np_grad = np.array(\n        [[[0.24984647, -0.09170401, 0.31495798, -0.09170401, 0.06511152]]]\n    )\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06))\n\n\ndef _test_conv1d_group_bias_true(test_case, device):\n    np_arr = np.array(\n        [\n            [\n                [1.48566079, 0.54937589, 0.62353903, -0.94114172, -0.60260266],\n                [0.61150503, -0.50289607, 1.41735041, -1.85877609, -1.04875529],\n            ]\n        ]\n    )\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    weight = np.array(\n        [\n            [[0.25576305, 0.40814576, 
-0.05900212]],\n            [[-0.24829513, 0.42756805, -0.01354307]],\n            [[0.44658303, 0.46889144, 0.41060263]],\n            [[0.30083328, -0.5221613, 0.12215579]],\n        ]\n    )\n    bias = np.array([-0.03368823, -0.4212504, -0.42130581, -0.17434336])\n    m = nn.Conv1d(2, 4, 3, groups=2, stride=1, bias=True)\n    m.weight = flow.nn.Parameter(flow.Tensor(weight))\n    m.bias = flow.nn.Parameter(flow.Tensor(bias))\n    m = m.to(device)\n    np_out = np.array(\n        [\n            [\n                [0.53372419, 0.41684598, -0.22277816],\n                [-0.56368178, -0.27830642, -0.97031319],\n                [0.19794616, -0.74452549, -1.09052706],\n                [0.44534814, -1.29277706, 1.09451222],\n            ]\n        ]\n    )\n    output = m(input)\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06))\n    output = output.sum()\n    output.backward()\n    np_grad = np.array(\n        [\n            [\n                [0.00746793, 0.84318173, 0.77063656, 0.76316863, -0.07254519],\n                [0.74741632, 0.69414645, 1.22690487, 0.47948855, 0.53275841],\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06))\n\n\ndef _test_conv1d_group_large_out_bias_true(test_case, device):\n    np_arr = np.array(\n        [\n            [\n                [2.17964911, 0.91623521, 1.24746692, 0.73605931, -0.23738743],\n                [-0.70412433, 0.10727754, 1.0207864, -0.09711888, -1.10814202],\n            ]\n        ]\n    )\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    weight = np.array(\n        [\n            [[-0.207307473, 0.12856324, 0.371991515]],\n            [[-0.416422307, 3.26921181e-05, -0.385845661]],\n            [[-0.182592362, 0.143281639, 0.419321984]],\n            [[-0.27117458, 0.0421470925, 0.377335936]],\n            [[0.546190619, -0.211819887, -0.29785803]],\n            [[0.334832489, 0.255918801, -0.0556600206]],\n        ]\n    )\n    bias = np.array(\n        [-0.56865668, 0.17631066, -0.43992457, -0.24307285, -0.53672957, -0.52927947]\n    )\n    m = nn.Conv1d(2, 6, 3, groups=2, stride=1, bias=True)\n    m.weight = flow.nn.Parameter(flow.Tensor(weight))\n    m.bias = flow.nn.Parameter(flow.Tensor(bias))\n    m = m.to(device)\n    np_out = np.array(\n        [\n            [\n                [-0.43867296, -0.32441288, -0.82094181],\n                [-1.21264362, -0.48919463, -0.25154343],\n                [-0.18354186, -0.11983716, -0.66178048],\n                [0.33756858, -0.26578707, -0.9421193],\n                [-1.2480886, -0.66543078, 0.37145507],\n                [-0.79440582, -0.22671542, -0.15066233],\n            ]\n        ]\n    )\n    output = m(input)\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06))\n    output = output.sum()\n    output.backward()\n    np_grad = np.array(\n        [\n            [\n                [-0.8063221, -0.53444451, -0.12897667, 0.6773454, 0.40546784],\n                [0.6098485, 0.69609451, 0.71991241, 0.1100639, 0.02381789],\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06))\n\n\ndef _test_conv1d_group_large_in_bias_true(test_case, device):\n    np_arr = np.array(\n        [\n            [\n                [0.7382921, 0.3227571, -0.73204273, -0.01697334, 1.72585976],\n                [0.52866709, 0.28417364, 1.12931311, 1.73048413, -0.60748184],\n       
         [0.43222603, 0.7882517, -0.62105948, 0.10097823, 0.81639361],\n                [0.36671457, 0.24468753, -0.5824874, -0.74464536, -0.38901371],\n            ]\n        ]\n    )\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    weight = np.array(\n        [\n            [\n                [-0.29574063, -0.31176069, 0.17234495],\n                [0.06092392, 0.30691007, -0.36685407],\n            ],\n            [\n                [0.26149744, 0.07149458, 0.3209756],\n                [0.18960869, -0.37148297, -0.13602243],\n            ],\n        ]\n    )\n    bias = np.array([-0.35048512, -0.0093792])\n    m = nn.Conv1d(4, 2, 3, groups=2, stride=1, bias=True)\n    m.weight = flow.nn.Parameter(flow.Tensor(weight))\n    m.bias = flow.nn.Parameter(flow.Tensor(bias))\n    m = m.to(device)\n    np_out = np.array(\n        [[[-1.09048378, -0.49156523, 0.99150705], [0.01852397, 0.54882324, 0.31657016]]]\n    )\n    output = m(input)\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06))\n    output = output.sum()\n    output.backward()\n    np_grad = np.array(\n        [\n            [\n                [-0.29574063, -0.60750133, -0.43515638, -0.13941574, 0.17234495],\n                [0.06092392, 0.36783397, 0.0009799, -0.059944, -0.36685407],\n                [0.26149744, 0.33299202, 0.65396762, 0.39247018, 0.3209756],\n                [0.18960869, -0.18187428, -0.31789672, -0.50750542, -0.13602243],\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06))\n\n\ndef _test_conv1d_compilcate(test_case, device):\n    np_arr = np.array(\n        [\n            [\n                [-1.00674784, 0.51784992, 0.39896572, 0.11018554, 0.91136694],\n                [1.95886874, 0.89779067, 0.4748213, 0.33313531, -0.49350029],\n                [-0.19280219, 0.04023677, 1.66438103, -0.83563608, 0.15925731],\n                [1.49166429, 1.45189261, -1.86512125, 0.34329697, 0.20413807],\n            ]\n        ]\n    )\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    weight = np.array(\n        [\n            [\n                [-0.36045218, 0.37349278, 0.04565236],\n                [0.0242328, -0.09459515, -0.30684742],\n            ],\n            [\n                [-0.30345008, -0.1196513, -0.26765293],\n                [0.09876197, 0.03346226, 0.2748405],\n            ],\n            [\n                [-0.37798449, 0.00242459, -0.34125558],\n                [-0.05174343, -0.10443231, 0.09526101],\n            ],\n            [\n                [0.34196907, -0.32667893, 0.40264183],\n                [0.38025281, 0.26807079, -0.09074812],\n            ],\n        ]\n    )\n    bias = np.array([-0.03499984, -0.21616256, 0.13312563, -0.24104381])\n    m = nn.Conv1d(4, 4, 3, groups=2, stride=2, padding=2, dilation=2, bias=True)\n    m.weight = flow.nn.Parameter(flow.Tensor(weight))\n    m.bias = flow.nn.Parameter(flow.Tensor(bias))\n    m = m.to(device)\n    np_out = np.array(\n        [\n            [\n                [-0.72379637, 0.67248386, 0.21977007],\n                [-0.00643994, -0.1286152, -0.41589433],\n                [-0.76877236, 0.29273134, -0.42040929],\n                [1.0612179, -0.73787093, -0.37839717],\n            ]\n        ]\n    )\n    output = m(input)\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06))\n    output = 
output.sum()\n    output.backward()\n    np_grad = np.array(\n        [\n            [\n                [-0.41006082, 0.0, -0.63206136, 0.0, 0.03184089],\n                [0.06186188, 0.0, 0.02985496, 0.0, -0.09313981],\n                [-0.36026976, 0.0, -0.2988835, 0.0, -0.26286808],\n                [0.49214786, 0.0, 0.49666074, 0.0, 0.16815135],\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestConv1d(flow.unittest.TestCase):\n    def test_conv1d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_conv1d_bias_true,\n            _test_conv1d_bias_false,\n            _test_conv1d_dilation,\n            _test_conv1d_stride,\n            _test_conv1d_group_bias_true,\n            _test_conv1d_group_large_out_bias_true,\n            _test_conv1d_group_large_in_bias_true,\n            _test_conv1d_compilcate,\n        ]\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @unittest.skip(\"skip for now, because it failed 8 times in the past week\")\n    @autotest(n=3)\n    def test_nn_functional_conv1d(test_case):\n        device = random_device()\n        img = torch.ones((1, 3, 224), requires_grad=True).to(device)\n        kernel = torch.ones((3, 1, 3), requires_grad=True).to(device)\n        y = torch.nn.functional.conv1d(img, kernel, groups=3)\n        return y\n\n    @unittest.skipIf(\n        version.parse(torch_original.__version__) <= version.parse(\"1.13.0\"),\n        \"conv module doesn't support unbatched input in PyTorch before '1.13.0'\",\n    )\n    @autotest(n=3)\n    def test_nn_functional_conv1d_2dinput(test_case):\n        device = random_device()\n        img = torch.ones((3, 224), requires_grad=True).to(device)\n        kernel = torch.ones((3, 1, 3), requires_grad=True).to(device)\n        y = torch.nn.functional.conv1d(img, kernel, groups=3)\n        return y\n\n    @profile(torch.nn.functional.conv1d)\n    def profile_conv1d(test_case):\n        inputs = torch.ones(40, 16, 30)\n        weight_16c = torch.ones(20, 16, 5)\n        weight_16c_4g = torch.ones(20, 4, 5)\n        weight_3k_16c = torch.ones(20, 16, 3)\n        weight_1k_16c = torch.ones(20, 16, 1)\n        torch.nn.functional.conv1d(inputs, weight_16c)\n        torch.nn.functional.conv1d(inputs, weight_16c, bias=torch.ones(20))\n        torch.nn.functional.conv1d(inputs, weight_16c, bias=torch.ones(20), padding=2)\n        torch.nn.functional.conv1d(\n            inputs, weight_16c, bias=torch.ones(20), padding=2, stride=2\n        )\n        torch.nn.functional.conv1d(inputs, weight_16c_4g, groups=4)\n        torch.nn.functional.conv1d(inputs, weight_16c_4g, bias=torch.ones(20), groups=4)\n        torch.nn.functional.conv1d(\n            inputs, weight_16c_4g, bias=torch.ones(20), groups=4, stride=4\n        )\n        torch.nn.functional.conv1d(\n            inputs, weight_16c_4g, bias=torch.ones(20), groups=4, padding=2\n        )\n        torch.nn.functional.conv1d(inputs, weight_3k_16c)\n        torch.nn.functional.conv1d(inputs, weight_3k_16c, bias=torch.ones(20))\n        torch.nn.functional.conv1d(\n            inputs, weight_3k_16c, bias=torch.ones(20), padding=1\n        )\n        torch.nn.functional.conv1d(\n            inputs, weight_3k_16c, bias=torch.ones(20), padding=1, stride=2\n        )\n        torch.nn.functional.conv1d(inputs, weight_1k_16c)\n        torch.nn.functional.conv1d(inputs, weight_1k_16c, bias=torch.ones(20))\n        torch.nn.functional.conv1d(inputs, weight_1k_16c, bias=torch.ones(20), stride=2)\n\n    @autotest(n=5, atol=1e-3)\n    def test_conv1d_with_random_data(test_case):\n        channels = random(1, 6)\n        m = torch.nn.Conv1d(\n            in_channels=channels,\n            out_channels=random(1, 20),\n            kernel_size=random(1, 4),\n            stride=random() | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(1, 5) | nothing(),\n            groups=random(1, 5) | nothing(),\n            padding_mode=constant(\"zeros\") | nothing(),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=3, dim1=channels).to(device)\n        y = m(x)\n        return y\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @autotest(n=5, check_allclose=False)\n    def test_conv1d_group_with_random_data(test_case):\n        channels = 720  # a common multiple of every groups value in [1, 6]\n        m = torch.nn.Conv1d(\n            in_channels=channels,\n            out_channels=channels,\n            kernel_size=random(1, 4),\n            stride=random() | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(1, 5) | nothing(),\n            groups=random(1, 7),\n            padding_mode=constant(\"zeros\") | nothing(),\n        )\n        m.train(random())\n\n        device = random_device()\n        m.to(device)\n        m.pytorch.to(\"cuda\")\n        x = random_tensor(ndim=3, dim1=channels).to(device)\n        x.pytorch = x.pytorch.to(\"cuda\")\n        y = m(x)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_conv2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport os\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\nimport torch as torch_original\nfrom packaging import version\n\ntest_conv2d_weight = np.array(\n    [\n        [\n            [\n                [0.8586049675941467, -0.2279418259859085, 0.2013147622346878],\n                [0.35005471110343933, 0.5360521078109741, 1.5194443464279175],\n                [1.9040879011154175, -1.5734431743621826, -0.14007866382598877],\n            ]\n        ],\n        [\n            [\n                [0.29670074582099915, 1.3111951351165771, 0.5035904049873352],\n                [-1.1894450187683105, -0.5502137541770935, -1.591875672340393],\n                [-1.1081947088241577, 0.07872020453214645, -0.9185634255409241],\n            ]\n        ],\n        [\n            [\n                [-0.7457143664360046, -1.2080862522125244, 1.8140212297439575],\n                [-1.5227429866790771, -2.515244960784912, -1.3549325466156006],\n                [-0.9574840068817139, -0.7248556613922119, 1.1119636297225952],\n            ]\n        ],\n    ]\n)\ntest_conv2d_data = np.array(\n    [\n        [\n            [\n                [\n                    1.1630785465240479,\n                    0.4838046133518219,\n                    0.299563467502594,\n                    0.15302546322345734,\n                    -1.168814778327942,\n                ],\n                [\n                    1.5580710172653198,\n                    -0.5459445714950562,\n                    -2.3556296825408936,\n                    0.5414402484893799,\n                    2.678506374359131,\n                ],\n                [\n                    1.2546343803405762,\n                    -0.5487740635871887,\n                    -0.6810643672943115,\n                    -0.13531559705734253,\n                    0.37723132967948914,\n                ],\n                [\n                    0.41016456484794617,\n                    0.5712682008743286,\n                    -2.757962703704834,\n                    1.0762799978256226,\n                    -0.6141325235366821,\n                ],\n                [\n                    1.830764889717102,\n                    -1.1468064785003662,\n                    0.053837940096855164,\n                    -2.5074806213378906,\n                    -0.5916498899459839,\n                ],\n            ]\n        ]\n    ]\n)\ntest_conv2d_data_grad = np.array(\n    [\n        [\n            [\n                [\n                    0.4095913469791412,\n                    0.2847584038972855,\n                    2.803684800863266,\n                    2.3940934538841248,\n                    2.5189263969659805,\n                ],\n                [\n             
       -1.9525419473648071,\n                    -4.606781497597694,\n                    -3.51521897315979,\n                    -1.562677025794983,\n                    1.0915625244379044,\n                ],\n                [\n                    -2.1141327619552612,\n                    -6.987950943410397,\n                    -5.84306687861681,\n                    -3.7289341166615486,\n                    1.1448840647935867,\n                ],\n                [\n                    -2.5237241089344025,\n                    -7.272709347307682,\n                    -8.646751679480076,\n                    -6.123027570545673,\n                    -1.3740423321723938,\n                ],\n                [\n                    -0.1615908145904541,\n                    -2.381169445812702,\n                    -2.32784790545702,\n                    -2.1662570908665657,\n                    0.0533215403556824,\n                ],\n            ]\n        ]\n    ]\n)\ntest_conv2d_weight_grad = np.array(\n    [\n        [\n            [\n                [0.6277393400669098, -2.7888944894075394, -0.2910575419664383],\n                [-3.095237225294113, -4.835702538490295, -1.8706469237804413],\n                [-1.0139376372098923, -6.076017692685127, -5.780256435275078],\n            ]\n        ],\n        [\n            [\n                [0.6277393400669098, -2.7888944894075394, -0.2910575419664383],\n                [-3.095237225294113, -4.835702538490295, -1.8706469237804413],\n                [-1.0139376372098923, -6.076017692685127, -5.780256435275078],\n            ]\n        ],\n        [\n            [\n                [0.6277393400669098, -2.7888944894075394, -0.2910575419664383],\n                [-3.095237225294113, -4.835702538490295, -1.8706469237804413],\n                [-1.0139376372098923, -6.076017692685127, -5.780256435275078],\n            ]\n        ],\n    ]\n)\ntest_conv2d_output = np.array(\n    [\n        [\n            [\n                [0.9699610471725464, -0.20758534967899323, 2.3857712745666504],\n                [0.3666309118270874, 4.690882682800293, -8.203354835510254],\n                [2.6072847843170166, -1.9033538103103638, 2.331153154373169],\n            ],\n            [\n                [2.519343852996826, 2.3757898807525635, -1.6613528728485107],\n                [0.5777544379234314, -3.5739502906799316, 5.349126815795898],\n                [0.729295015335083, 1.5791023969650269, 3.7627718448638916],\n            ],\n            [\n                [-0.27685487270355225, 6.446267127990723, -2.762883424758911],\n                [-8.25644588470459, 9.616064071655273, 8.005367279052734],\n                [-0.6944921016693115, 3.866114854812622, 4.788446426391602],\n            ],\n        ]\n    ]\n)\ntest_conv2d_with_bias_weight = np.array(\n    [\n        [\n            [\n                [1.8271433115005493, -1.0446699857711792, 1.0062190294265747],\n                [0.5174201130867004, -0.806931734085083, 1.3769007921218872],\n                [0.205885112285614, 0.9943519234657288, -0.23580588400363922],\n            ]\n        ],\n        [\n            [\n                [0.29881811141967773, -1.9982075691223145, 0.3511354625225067],\n                [-0.7644741535186768, 1.2594351768493652, -0.9629734754562378],\n                [0.5080506205558777, 0.7561734318733215, 1.6839302778244019],\n            ]\n        ],\n        [\n            [\n                [1.2573646306991577, 0.13123232126235962, 1.6403018236160278],\n                
[-1.2138012647628784, 2.399970531463623, -0.38509097695350647],\n                [-0.9878040552139282, 0.9585888385772705, -1.4976465702056885],\n            ]\n        ],\n    ]\n)\ntest_conv2d_with_bias_bias = np.array(\n    [0.6605162620544434, -0.18903568387031555, -0.27302607893943787]\n)\ntest_conv2d_with_bias_data = np.array(\n    [\n        [\n            [\n                [\n                    -0.47827261686325073,\n                    -1.1739492416381836,\n                    -0.7921845316886902,\n                    0.9321041703224182,\n                    -3.1557741165161133,\n                ],\n                [\n                    2.1935296058654785,\n                    -0.5385921001434326,\n                    -0.8611332774162292,\n                    -1.881519079208374,\n                    -0.7205708026885986,\n                ],\n                [\n                    -0.35601571202278137,\n                    -0.15963983535766602,\n                    1.797447681427002,\n                    0.19594945013523102,\n                    -1.7376397848129272,\n                ],\n                [\n                    0.047347065061330795,\n                    0.14580930769443512,\n                    0.32604914903640747,\n                    0.4578782916069031,\n                    -0.8942581415176392,\n                ],\n                [\n                    0.49383941292762756,\n                    -0.9043426513671875,\n                    -1.2140793800354004,\n                    2.1564064025878906,\n                    1.0938222408294678,\n                ],\n            ]\n        ]\n    ]\n)\ntest_conv2d_with_bias_output = np.array(\n    [\n        [\n            [\n                [-0.05607491731643677, -0.185230553150177, -3.8808679580688477],\n                [6.861937046051025, -2.3341472148895264, -0.5597308874130249],\n                [1.8299254179000854, -2.770848274230957, 2.1958212852478027],\n            ],\n            [\n                [2.9348952770233154, 4.117504119873047, -6.278541088104248],\n                [0.2638452351093292, 3.998856782913208, 2.612290620803833],\n                [-1.9891828298568726, -1.6476304531097412, 3.39066219329834],\n            ],\n            [\n                [-8.44466781616211, 0.5747121572494507, -8.501373291015625],\n                [-0.036642804741859436, -0.23458999395370483, -2.370849370956421],\n                [2.8372013568878174, -2.987276077270508, 1.8382092714309692],\n            ],\n        ]\n    ]\n)\ntest_conv2d_group_weight = np.array(\n    [\n        [\n            [\n                [-0.7248556613922119, 1.1119636297225952, -0.47827261686325073],\n                [-1.1739492416381836, -0.7921845316886902, 0.9321041703224182],\n                [-3.1557741165161133, 2.1935296058654785, -0.5385921001434326],\n            ]\n        ],\n        [\n            [\n                [-0.8611332774162292, -1.881519079208374, -0.7205708026885986],\n                [-0.35601571202278137, -0.15963983535766602, 1.797447681427002],\n                [0.19594945013523102, -1.7376397848129272, 0.047347065061330795],\n            ]\n        ],\n    ]\n)\ntest_conv2d_group_data_grad = np.array(\n    [\n        [\n            [\n                [\n                    -0.7248556613922119,\n                    0.3871079683303833,\n                    -0.0911646485328674,\n                    0.6336910128593445,\n                    -0.4782726168632507,\n                ],\n                [\n                    
-1.8988049030303955,\n                    -1.5790258049964905,\n                    -1.125194251537323,\n                    0.7736106514930725,\n                    0.4538315534591675,\n                ],\n                [\n                    -5.054579019546509,\n                    -2.5412703156471252,\n                    -2.6260308623313904,\n                    2.4285481572151184,\n                    -0.0847605466842651,\n                ],\n                [\n                    -4.329723358154297,\n                    -2.9283782839775085,\n                    -2.534866213798523,\n                    1.794857144355774,\n                    0.3935120701789856,\n                ],\n                [\n                    -3.1557741165161133,\n                    -0.9622445106506348,\n                    -1.5008366107940674,\n                    1.654937505722046,\n                    -0.5385921001434326,\n                ],\n            ],\n            [\n                [\n                    -0.8611332774162292,\n                    -2.7426523566246033,\n                    -3.463223159313202,\n                    -2.6020898818969727,\n                    -0.7205708026885986,\n                ],\n                [\n                    -1.2171489894390106,\n                    -3.2583079040050507,\n                    -2.1814310252666473,\n                    -0.9642820358276367,\n                    1.0768768787384033,\n                ],\n                [\n                    -1.0211995393037796,\n                    -4.799998238682747,\n                    -3.6757742948830128,\n                    -2.654574755579233,\n                    1.1242239437997341,\n                ],\n                [\n                    -0.1600662618875504,\n                    -2.0573458820581436,\n                    -0.2125511355698109,\n                    -0.0524848736822605,\n                    1.8447947464883327,\n                ],\n                [\n                    0.195949450135231,\n                    -1.5416903346776962,\n                    -1.4943432696163654,\n                    -1.6902927197515965,\n                    0.0473470650613308,\n                ],\n            ],\n        ]\n    ]\n)\ntest_conv2d_group_weight_grad = np.array(\n    [\n        [\n            [\n                [0.6277393400669098, -2.7888944894075394, -0.2910575419664383],\n                [-3.095237225294113, -4.835702538490295, -1.8706469237804413],\n                [-1.0139376372098923, -6.076017692685127, -5.780256435275078],\n            ]\n        ],\n        [\n            [\n                [3.30740749835968, -0.7220746576786041, -3.660933956503868],\n                [0.5273916646838188, -2.631059892475605, -7.6207195818424225],\n                [-3.5466641262173653, -8.214546449482441, -11.031560003757477],\n            ]\n        ],\n    ]\n)\ntest_conv2d_group_data = np.array(\n    [\n        [\n            [\n                [\n                    1.1630785465240479,\n                    0.4838046133518219,\n                    0.299563467502594,\n                    0.15302546322345734,\n                    -1.168814778327942,\n                ],\n                [\n                    1.5580710172653198,\n                    -0.5459445714950562,\n                    -2.3556296825408936,\n                    0.5414402484893799,\n                    2.678506374359131,\n                ],\n                [\n                    1.2546343803405762,\n                    
-0.5487740635871887,\n                    -0.6810643672943115,\n                    -0.13531559705734253,\n                    0.37723132967948914,\n                ],\n                [\n                    0.41016456484794617,\n                    0.5712682008743286,\n                    -2.757962703704834,\n                    1.0762799978256226,\n                    -0.6141325235366821,\n                ],\n                [\n                    1.830764889717102,\n                    -1.1468064785003662,\n                    0.053837940096855164,\n                    -2.5074806213378906,\n                    -0.5916498899459839,\n                ],\n            ],\n            [\n                [\n                    0.8586049675941467,\n                    -0.2279418259859085,\n                    0.2013147622346878,\n                    0.35005471110343933,\n                    0.5360521078109741,\n                ],\n                [\n                    1.5194443464279175,\n                    1.9040879011154175,\n                    -1.5734431743621826,\n                    -0.14007866382598877,\n                    0.29670074582099915,\n                ],\n                [\n                    1.3111951351165771,\n                    0.5035904049873352,\n                    -1.1894450187683105,\n                    -0.5502137541770935,\n                    -1.591875672340393,\n                ],\n                [\n                    -1.1081947088241577,\n                    0.07872020453214645,\n                    -0.9185634255409241,\n                    -0.7457143664360046,\n                    -1.2080862522125244,\n                ],\n                [\n                    1.8140212297439575,\n                    -1.5227429866790771,\n                    -2.515244960784912,\n                    -1.3549325466156006,\n                    -0.9574840068817139,\n                ],\n            ],\n        ]\n    ]\n)\ntest_conv2d_group_output = np.array(\n    [\n        [\n            [\n                [-8.836943626403809, 3.2316627502441406, 6.994439601898193],\n                [-0.8386597037315369, -9.857108116149902, 13.68197250366211],\n                [-13.020713806152344, 7.310227870941162, -3.3760271072387695],\n            ],\n            [\n                [-4.803101539611816, 1.026240587234497, 0.5452112555503845],\n                [-6.839838027954102, 2.0195930004119873, 0.11328654736280441],\n                [0.393694669008255, 4.987061023712158, 3.297354221343994],\n            ],\n        ]\n    ]\n)\ntest_conv2d_padding_weight = np.array(\n    [\n        [\n            [\n                [0.8586049675941467, -0.2279418259859085, 0.2013147622346878],\n                [0.35005471110343933, 0.5360521078109741, 1.5194443464279175],\n                [1.9040879011154175, -1.5734431743621826, -0.14007866382598877],\n            ]\n        ]\n    ]\n)\ntest_conv2d_padding_data = np.array(\n    [\n        [\n            [\n                [\n                    1.1630785465240479,\n                    0.4838046133518219,\n                    0.299563467502594,\n                    0.15302546322345734,\n                    -1.168814778327942,\n                ],\n                [\n                    1.5580710172653198,\n                    -0.5459445714950562,\n                    -2.3556296825408936,\n                    0.5414402484893799,\n                    2.678506374359131,\n                ],\n                [\n                    1.2546343803405762,\n       
             -0.5487740635871887,\n                    -0.6810643672943115,\n                    -0.13531559705734253,\n                    0.37723132967948914,\n                ],\n                [\n                    0.41016456484794617,\n                    0.5712682008743286,\n                    -2.757962703704834,\n                    1.0762799978256226,\n                    -0.6141325235366821,\n                ],\n                [\n                    1.830764889717102,\n                    -1.1468064785003662,\n                    0.053837940096855164,\n                    -2.5074806213378906,\n                    -0.5916498899459839,\n                ],\n            ]\n        ]\n    ]\n)\ntest_conv2d_padding_data_grad = np.array(\n    [\n        [\n            [\n                [\n                    3.237529069185257,\n                    3.237529069185257,\n                    3.237529069185257,\n                    3.237529069185257,\n                    3.237529069185257,\n                ],\n                [\n                    3.428095132112503,\n                    3.428095132112503,\n                    3.428095132112503,\n                    3.428095132112503,\n                    3.428095132112503,\n                ],\n                [\n                    3.428095132112503,\n                    3.428095132112503,\n                    3.428095132112503,\n                    3.428095132112503,\n                    3.428095132112503,\n                ],\n                [\n                    3.428095132112503,\n                    3.428095132112503,\n                    3.428095132112503,\n                    3.428095132112503,\n                    3.428095132112503,\n                ],\n                [\n                    2.596117228269577,\n                    2.596117228269577,\n                    2.596117228269577,\n                    2.596117228269577,\n                    2.596117228269577,\n                ],\n            ]\n        ]\n    ]\n)\ntest_conv2d_padding_weight_grad = np.array(\n    [\n        [\n            [\n                [1.7594299167394638, 1.7594299167394638, 1.7594299167394638],\n                [-0.6019042432308197, -0.6019042432308197, -0.6019042432308197],\n                [-1.532561555504799, -1.532561555504799, -1.532561555504799],\n            ]\n        ]\n    ]\n)\ntest_conv2d_padding_output = np.array(\n    [\n        [\n            [\n                [\n                    1.5489805936813354,\n                    -1.0164761543273926,\n                    5.277345657348633,\n                    3.153532028198242,\n                    -7.301508903503418,\n                    -3.7565059661865234,\n                    4.690962314605713,\n                ],\n                [\n                    2.425799608230591,\n                    -2.0592665672302246,\n                    0.9699610471725464,\n                    -0.20758534967899323,\n                    2.3857712745666504,\n                    1.1719579696655273,\n                    0.6523551940917969,\n                ],\n                [\n                    2.1625545024871826,\n                    -1.3517316579818726,\n                    0.3666309118270874,\n                    4.690882682800293,\n                    -8.203354835510254,\n                    3.0248217582702637,\n                    1.2624683380126953,\n                ],\n                [\n                    0.6193475723266602,\n                    -2.0285415649414062,\n                    
2.6072847843170166,\n                    -1.9033538103103638,\n                    2.331153154373169,\n                    -3.998155355453491,\n                    -1.0176407098770142,\n                ],\n                [\n                    2.8643176555633545,\n                    -0.7396122217178345,\n                    -0.2253415733575821,\n                    -2.846742630004883,\n                    -4.961236476898193,\n                    -0.1308247298002243,\n                    -0.7344070672988892,\n                ],\n            ]\n        ]\n    ]\n)\ntest_conv2d_stride_weight = np.array(\n    [\n        [\n            [\n                [0.8586049675941467, -0.2279418259859085, 0.2013147622346878],\n                [0.35005471110343933, 0.5360521078109741, 1.5194443464279175],\n                [1.9040879011154175, -1.5734431743621826, -0.14007866382598877],\n            ]\n        ]\n    ]\n)\ntest_conv2d_stride_data = np.array(\n    [\n        [\n            [\n                [\n                    1.1630785465240479,\n                    0.4838046133518219,\n                    0.299563467502594,\n                    0.15302546322345734,\n                    -1.168814778327942,\n                ],\n                [\n                    1.5580710172653198,\n                    -0.5459445714950562,\n                    -2.3556296825408936,\n                    0.5414402484893799,\n                    2.678506374359131,\n                ],\n                [\n                    1.2546343803405762,\n                    -0.5487740635871887,\n                    -0.6810643672943115,\n                    -0.13531559705734253,\n                    0.37723132967948914,\n                ],\n                [\n                    0.41016456484794617,\n                    0.5712682008743286,\n                    -2.757962703704834,\n                    1.0762799978256226,\n                    -0.6141325235366821,\n                ],\n                [\n                    1.830764889717102,\n                    -1.1468064785003662,\n                    0.053837940096855164,\n                    -2.5074806213378906,\n                    -0.5916498899459839,\n                ],\n            ]\n        ]\n    ]\n)\ntest_conv2d_stride_data_grad = np.array(\n    [\n        [\n            [\n                [\n                    0.5360521078109741,\n                    1.5194443464279175,\n                    0.3500547111034393,\n                    0.5360521078109741,\n                    1.5194443464279175,\n                ],\n                [\n                    -1.8013850003480911,\n                    0.061236098408699,\n                    2.762692868709564,\n                    -1.8013850003480911,\n                    0.061236098408699,\n                ],\n                [\n                    0.5360521078109741,\n                    1.5194443464279175,\n                    0.3500547111034393,\n                    0.5360521078109741,\n                    1.5194443464279175,\n                ],\n                [\n                    -1.8013850003480911,\n                    0.061236098408699,\n                    2.762692868709564,\n                    -1.8013850003480911,\n                    0.061236098408699,\n                ],\n                [\n                    0.5360521078109741,\n                    1.5194443464279175,\n                    0.3500547111034393,\n                    0.5360521078109741,\n                    1.5194443464279175,\n                ],\n  
          ]\n        ]\n    ]\n)\ntest_conv2d_stride_weight_grad = np.array(\n    [\n        [\n            [\n                [-5.1135923862457275, 3.5859558284282684, 2.089697480201721],\n                [-0.3276629596948624, 1.7587070614099503, -2.5950092673301697],\n                [-5.1135923862457275, 3.5859558284282684, 2.089697480201721],\n            ]\n        ]\n    ]\n)\ntest_conv2d_stride_output = np.array(\n    [\n        [\n            [\n                [-1.0164761543273926, -7.301508903503418],\n                [-1.3517316579818726, -8.203354835510254],\n                [-0.7396122217178345, -4.961236476898193],\n            ]\n        ]\n    ]\n)\ntest_conv2d_kernel_weight = np.array(\n    [\n        [\n            [\n                [\n                    -0.9574840068817139,\n                    -0.7248556613922119,\n                    1.1119636297225952,\n                    -0.47827261686325073,\n                    -1.1739492416381836,\n                ],\n                [\n                    -0.7921845316886902,\n                    0.9321041703224182,\n                    -3.1557741165161133,\n                    2.1935296058654785,\n                    -0.5385921001434326,\n                ],\n                [\n                    -0.8611332774162292,\n                    -1.881519079208374,\n                    -0.7205708026885986,\n                    -0.35601571202278137,\n                    -0.15963983535766602,\n                ],\n            ]\n        ]\n    ]\n)\ntest_conv2d_kernel_data = np.array(\n    [\n        [\n            [\n                [\n                    1.1630785465240479,\n                    0.4838046133518219,\n                    0.299563467502594,\n                    0.15302546322345734,\n                    -1.168814778327942,\n                    1.5580710172653198,\n                    -0.5459445714950562,\n                ],\n                [\n                    -2.3556296825408936,\n                    0.5414402484893799,\n                    2.678506374359131,\n                    1.2546343803405762,\n                    -0.5487740635871887,\n                    -0.6810643672943115,\n                    -0.13531559705734253,\n                ],\n                [\n                    0.37723132967948914,\n                    0.41016456484794617,\n                    0.5712682008743286,\n                    -2.757962703704834,\n                    1.0762799978256226,\n                    -0.6141325235366821,\n                    1.830764889717102,\n                ],\n                [\n                    -1.1468064785003662,\n                    0.053837940096855164,\n                    -2.5074806213378906,\n                    -0.5916498899459839,\n                    0.8586049675941467,\n                    -0.2279418259859085,\n                    0.2013147622346878,\n                ],\n                [\n                    0.35005471110343933,\n                    0.5360521078109741,\n                    1.5194443464279175,\n                    1.9040879011154175,\n                    -1.5734431743621826,\n                    -0.14007866382598877,\n                    0.29670074582099915,\n                ],\n                [\n                    1.3111951351165771,\n                    0.5035904049873352,\n                    -1.1894450187683105,\n                    -0.5502137541770935,\n                    -1.591875672340393,\n                    -1.1081947088241577,\n                    
0.07872020453214645,\n                ],\n                [\n                    -0.9185634255409241,\n                    -0.7457143664360046,\n                    -1.2080862522125244,\n                    1.8140212297439575,\n                    -1.5227429866790771,\n                    -2.515244960784912,\n                    -1.3549325466156006,\n                ],\n            ]\n        ]\n    ]\n)\ntest_conv2d_kernel_data_grad = np.array(\n    [\n        [\n            [\n                [\n                    -0.9574840068817139,\n                    -1.6823396682739258,\n                    -0.5703760385513306,\n                    -0.0911646485328674,\n                    -0.5402582287788391,\n                    -1.6522218585014343,\n                    -1.1739492416381836,\n                ],\n                [\n                    -1.749668538570404,\n                    -1.5424200296401978,\n                    -3.586230516433716,\n                    -0.121304988861084,\n                    -2.0410948395729065,\n                    0.0027156472206116,\n                    -1.7125413417816162,\n                ],\n                [\n                    -2.6108018159866333,\n                    -4.285072386264801,\n                    -7.049453675746918,\n                    -3.079410582780838,\n                    -3.2773211896419525,\n                    -0.5129399001598358,\n                    -1.8721811771392822,\n                ],\n                [\n                    -2.6108018159866333,\n                    -4.285072386264801,\n                    -7.049453675746918,\n                    -3.079410582780838,\n                    -3.2773211896419525,\n                    -0.5129399001598358,\n                    -1.8721811771392822,\n                ],\n                [\n                    -2.6108018159866333,\n                    -4.285072386264801,\n                    -7.049453675746918,\n                    -3.079410582780838,\n                    -3.2773211896419525,\n                    -0.5129399001598358,\n                    -1.8721811771392822,\n                ],\n                [\n                    -1.6533178091049194,\n                    -2.6027327179908752,\n                    -6.479077637195587,\n                    -2.9882459342479706,\n                    -2.7370629608631134,\n                    1.1392819583415985,\n                    -0.6982319355010986,\n                ],\n                [\n                    -0.8611332774162292,\n                    -2.7426523566246033,\n                    -3.463223159313202,\n                    -2.958105593919754,\n                    -1.236226350069046,\n                    -0.5156555473804474,\n                    -0.159639835357666,\n                ],\n            ]\n        ]\n    ]\n)\ntest_conv2d_kernel_weight_grad = np.array(\n    [\n        [\n            [\n                [\n                    2.974529668688774,\n                    4.548736393451691,\n                    1.1672898679971695,\n                    -1.499158263206482,\n                    0.1862268149852753,\n                ],\n                [\n                    1.6534235626459122,\n                    2.3762744814157486,\n                    -1.448018729686737,\n                    -5.2917241007089615,\n                    -2.278435029089451,\n                ],\n                [\n                    -2.083257421851158,\n                    -2.23808591067791,\n                    -5.749193429946899,\n                  
  -7.540486767888069,\n                    -6.306201495230198,\n                ],\n            ]\n        ]\n    ]\n)\ntest_conv2d_kernel_output = np.array(\n    [\n        [\n            [\n                [-3.5647754669189453, -4.234736919403076, 1.4046944379806519],\n                [-0.6964312791824341, 16.42838478088379, -9.649789810180664],\n                [4.312150478363037, -6.283960819244385, -4.8443922996521],\n                [-2.772286891937256, -4.483709812164307, 12.315184593200684],\n                [7.39893913269043, 1.305102825164795, -2.049992561340332],\n            ]\n        ]\n    ]\n)\ntest_conv2d_dilation_weight = np.array(\n    [\n        [\n            [\n                [-0.9574840068817139, -0.7248556613922119, 1.1119636297225952],\n                [-0.47827261686325073, -1.1739492416381836, -0.7921845316886902],\n                [0.9321041703224182, -3.1557741165161133, 2.1935296058654785],\n            ]\n        ]\n    ]\n)\ntest_conv2d_dilation_data = np.array(\n    [\n        [\n            [\n                [\n                    1.1630785465240479,\n                    0.4838046133518219,\n                    0.299563467502594,\n                    0.15302546322345734,\n                    -1.168814778327942,\n                    1.5580710172653198,\n                    -0.5459445714950562,\n                ],\n                [\n                    -2.3556296825408936,\n                    0.5414402484893799,\n                    2.678506374359131,\n                    1.2546343803405762,\n                    -0.5487740635871887,\n                    -0.6810643672943115,\n                    -0.13531559705734253,\n                ],\n                [\n                    0.37723132967948914,\n                    0.41016456484794617,\n                    0.5712682008743286,\n                    -2.757962703704834,\n                    1.0762799978256226,\n                    -0.6141325235366821,\n                    1.830764889717102,\n                ],\n                [\n                    -1.1468064785003662,\n                    0.053837940096855164,\n                    -2.5074806213378906,\n                    -0.5916498899459839,\n                    0.8586049675941467,\n                    -0.2279418259859085,\n                    0.2013147622346878,\n                ],\n                [\n                    0.35005471110343933,\n                    0.5360521078109741,\n                    1.5194443464279175,\n                    1.9040879011154175,\n                    -1.5734431743621826,\n                    -0.14007866382598877,\n                    0.29670074582099915,\n                ],\n                [\n                    1.3111951351165771,\n                    0.5035904049873352,\n                    -1.1894450187683105,\n                    -0.5502137541770935,\n                    -1.591875672340393,\n                    -1.1081947088241577,\n                    0.07872020453214645,\n                ],\n                [\n                    -0.9185634255409241,\n                    -0.7457143664360046,\n                    -1.2080862522125244,\n                    1.8140212297439575,\n                    -1.5227429866790771,\n                    -2.515244960784912,\n                    -1.3549325466156006,\n                ],\n            ]\n        ]\n    ]\n)\ntest_conv2d_dilation_data_grad = np.array(\n    [\n        [\n            [\n                [\n                    -0.9574840068817139,\n                    0.0,\n  
                  0.0,\n                    -0.7248556613922119,\n                    0.0,\n                    0.0,\n                    1.1119636297225952,\n                ],\n                [\n                    -0.9574840068817139,\n                    0.0,\n                    0.0,\n                    -0.7248556613922119,\n                    0.0,\n                    0.0,\n                    1.1119636297225952,\n                ],\n                [\n                    -1.4357566237449646,\n                    0.0,\n                    0.0,\n                    -1.8988049030303955,\n                    0.0,\n                    0.0,\n                    0.319779098033905,\n                ],\n                [\n                    -0.4782726168632507,\n                    0.0,\n                    0.0,\n                    -1.1739492416381836,\n                    0.0,\n                    0.0,\n                    -0.7921845316886902,\n                ],\n                [\n                    0.4538315534591675,\n                    0.0,\n                    0.0,\n                    -4.329723358154297,\n                    0.0,\n                    0.0,\n                    1.4013450741767883,\n                ],\n                [\n                    0.9321041703224182,\n                    0.0,\n                    0.0,\n                    -3.1557741165161133,\n                    0.0,\n                    0.0,\n                    2.1935296058654785,\n                ],\n                [\n                    0.9321041703224182,\n                    0.0,\n                    0.0,\n                    -3.1557741165161133,\n                    0.0,\n                    0.0,\n                    2.1935296058654785,\n                ],\n            ]\n        ]\n    ]\n)\ntest_conv2d_dilation_weight_grad = np.array(\n    [\n        [\n            [\n                [-0.8153198063373566, -1.3503028601408005, 1.1495047211647034],\n                [-0.4195204377174377, -1.4455246925354004, 2.328780397772789],\n                [0.7426864206790924, 3.1678953766822815, -0.979511596262455],\n            ]\n        ]\n    ]\n)\ntest_conv2d_dilation_output = np.array(\n    [[[[-5.2563982009887695], [5.410353183746338], [-8.517012596130371]]]]\n)\n\n\ndef _test_conv2d(\n    test_case, conv, data, weight, output, bias=None, device=\"cuda\",\n):\n    to_device = flow.device(device)\n    x = flow.tensor(data, dtype=flow.float32, device=to_device)\n    conv.weight = flow.nn.Parameter(flow.Tensor(weight))\n    if bias is not None:\n        conv.bias = flow.nn.Parameter(flow.Tensor(bias))\n    conv.to(to_device)\n    of_out = conv(x)\n    test_case.assertTrue(np.allclose(of_out.numpy(), output, rtol=1e-4, atol=1e-8))\n\n\ndef _test_conv2d_backward(\n    test_case,\n    conv,\n    data,\n    weight,\n    data_grad,\n    weight_grad,\n    bias=None,\n    device=\"cuda\",\n    data_rtol=1e-4,\n    data_atol=1e-8,\n    weight_rtol=1e-4,\n    weight_atol=1e-8,\n):\n    to_device = flow.device(device)\n    x = flow.tensor(data, dtype=flow.float32, device=to_device, requires_grad=True)\n    conv.weight = flow.nn.Parameter(flow.Tensor(weight), requires_grad=True)\n    if bias is not None:\n        conv.bias = flow.nn.Parameter(flow.Tensor(bias))\n    conv.to(to_device)\n    of_out = conv(x)\n    of_out.sum().backward()\n    test_case.assertTrue(\n        np.allclose(x.grad.numpy(), data_grad, rtol=data_rtol, atol=data_atol)\n    )\n    test_case.assertTrue(\n        np.allclose(\n            
conv.weight.grad.numpy(), weight_grad, rtol=weight_rtol, atol=weight_atol\n        )\n    )\n\n\ndef _test_conv2d_large_in_channel(test_case, device):\n    np_arr = np.array(\n        [\n            [\n                [\n                    [\n                        0.6206631238581714,\n                        -1.1225329393404626,\n                        0.8407155480700242,\n                        -0.6845162855236345,\n                    ],\n                    [\n                        -0.5186484633906412,\n                        0.10420735184519186,\n                        -0.1711568947473012,\n                        0.5168640476046483,\n                    ],\n                    [\n                        -0.12429464919764661,\n                        0.050277779246134253,\n                        -1.0144501797426606,\n                        -2.184600444658526,\n                    ],\n                    [\n                        0.28918126931309923,\n                        -0.822872663244595,\n                        0.44019150436683663,\n                        -1.0247720130825562,\n                    ],\n                ],\n                [\n                    [\n                        0.7786504412818226,\n                        -0.7501839068078657,\n                        -0.8187283189941765,\n                        -1.1116653569170698,\n                    ],\n                    [\n                        0.18085524152316743,\n                        -1.3461349607476678,\n                        1.142505437476448,\n                        -0.000649619704040145,\n                    ],\n                    [\n                        0.03160672782674317,\n                        -0.006318157449953413,\n                        1.2218487782604377,\n                        0.15903027907930234,\n                    ],\n                    [\n                        1.5857011815642381,\n                        0.6656477116332891,\n                        -0.04036621813223574,\n                        -0.3427168687988546,\n                    ],\n                ],\n                [\n                    [\n                        -1.1774346070102524,\n                        1.6195241269303395,\n                        -0.36185552303441965,\n                        -1.1382193113192487,\n                    ],\n                    [\n                        0.08061907334568702,\n                        1.5025447613238763,\n                        -1.1591348706634745,\n                        1.6449050139676873,\n                    ],\n                    [\n                        1.1539915649822392,\n                        -2.414624939646017,\n                        0.3056063774849572,\n                        1.1920089257083162,\n                    ],\n                    [\n                        0.7623012858982319,\n                        -0.01685314742940813,\n                        -1.096666898224702,\n                        -0.4406476137098582,\n                    ],\n                ],\n                [\n                    [\n                        0.9383797282214235,\n                        -1.1075876842796508,\n                        -0.4420913825139058,\n                        -1.0736097610655628,\n                    ],\n                    [\n                        -0.3101376466546291,\n                        1.6578227745160954,\n                        -0.6225454278031398,\n                        0.6831188620748697,\n                    
],\n                    [\n                        0.00743800968372913,\n                        -0.8089158949698473,\n                        2.08084287836801,\n                        0.721204366332351,\n                    ],\n                    [\n                        0.5694701823297723,\n                        0.031519314469744895,\n                        -0.5041680957766629,\n                        -0.4738588233094669,\n                    ],\n                ],\n            ]\n        ]\n    )\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    weight = np.array(\n        [\n            [\n                [\n                    [0.06456436216831207, -0.10852358490228653, -0.21638715267181396],\n                    [-0.2279110550880432, 0.1476770043373108, 0.19457484781742096],\n                    [0.05026858672499657, 0.10818571597337723, 0.02056501805782318],\n                ],\n                [\n                    [0.205095112323761, 0.1488947868347168, -0.2344113141298294],\n                    [0.1684819906949997, -0.21986986696720123, 0.1082606166601181],\n                    [-0.1528974026441574, 0.17120417952537537, 0.01954500749707222],\n                ],\n            ],\n            [\n                [\n                    [-0.09441672265529633, -0.03644559532403946, -0.22235223650932312],\n                    [-0.1771145612001419, 0.08043312281370163, 0.06938580423593521],\n                    [0.054393064230680466, -0.05483492836356163, 0.23438701033592224],\n                ],\n                [\n                    [0.22666795551776886, 0.0874653309583664, 0.07092718034982681],\n                    [0.08883464336395264, -0.052362944930791855, -0.1720171570777893],\n                    [0.10441060364246368, 0.011952142231166363, -0.0894528403878212],\n                ],\n            ],\n        ]\n    )\n    m = flow.nn.Conv2d(4, 2, 3, groups=2, bias=False)\n    m.weight = flow.nn.Parameter(flow.Tensor(weight), requires_grad=True)\n    m = m.to(device)\n    output = m(input)\n    np_out = [\n        [\n            [\n                [0.7666134238243103, -0.3961866497993469],\n                [-0.656266987323761, -1.1613956689834595],\n            ],\n            [\n                [0.3077264130115509, -0.42817503213882446],\n                [-0.5761325359344482, 0.1300736665725708],\n            ],\n        ]\n    ]\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-3, 1e-3))\n    output = output.sum()\n    output.backward()\n    np_grad = [\n        [\n            [\n                [\n                    0.06456436216831207,\n                    -0.04395922273397446,\n                    -0.3249107301235199,\n                    -0.21638715267181396,\n                ],\n                [\n                    -0.16334669291973114,\n                    -0.12419328093528748,\n                    0.017341122031211853,\n                    -0.021812304854393005,\n                ],\n                [\n                    -0.17764246463775635,\n                    0.07822024822235107,\n                    0.47100257873535156,\n                    0.21513986587524414,\n                ],\n                [\n                    0.05026858672499657,\n                    0.1584542989730835,\n                    0.128750741481781,\n                    0.02056501805782318,\n                ],\n            ],\n            [\n                [\n                    0.205095112323761,\n   
                 0.3539898991584778,\n                    -0.08551652729511261,\n                    -0.2344113141298294,\n                ],\n                [\n                    0.3735771179199219,\n                    0.30260205268859863,\n                    -0.19712577760219574,\n                    -0.1261506974697113,\n                ],\n                [\n                    0.015584588050842285,\n                    -0.03308109939098358,\n                    0.07913993299007416,\n                    0.12780562043190002,\n                ],\n                [\n                    -0.1528974026441574,\n                    0.018306776881217957,\n                    0.1907491832971573,\n                    0.01954500749707222,\n                ],\n            ],\n            [\n                [\n                    -0.09441672265529633,\n                    -0.13086232542991638,\n                    -0.258797824382782,\n                    -0.22235223650932312,\n                ],\n                [\n                    -0.27153128385543823,\n                    -0.22754377126693726,\n                    -0.10897888988256454,\n                    -0.1529664397239685,\n                ],\n                [\n                    -0.12272149324417114,\n                    -0.09712330251932144,\n                    0.32937100529670715,\n                    0.30377280712127686,\n                ],\n                [\n                    0.054393064230680466,\n                    -0.00044186413288116455,\n                    0.1795520782470703,\n                    0.23438701033592224,\n                ],\n            ],\n            [\n                [\n                    0.22666795551776886,\n                    0.31413328647613525,\n                    0.1583925187587738,\n                    0.07092718034982681,\n                ],\n                [\n                    0.3155025839805603,\n                    0.35060498118400574,\n                    -0.06598758697509766,\n                    -0.1010899767279625,\n                ],\n                [\n                    0.19324524700641632,\n                    0.1528344452381134,\n                    -0.301880806684494,\n                    -0.2614699900150299,\n                ],\n                [\n                    0.10441060364246368,\n                    0.11636274307966232,\n                    -0.07750070095062256,\n                    -0.0894528403878212,\n                ],\n            ],\n        ]\n    ]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-3, 1e-3))\n\n\ndef _test_conv2d_large_out_channel(test_case, device):\n    np_arr = np.array(\n        [\n            [\n                [\n                    [0.56573248, -0.19689320, -0.67875558, 0.34328273, 0.31964567],\n                    [-1.33715475, 0.33422229, -1.27643383, 0.37904647, 0.35891593],\n                    [0.84579802, 2.12729621, -0.51423287, 0.61297560, -1.31156564],\n                    [-0.71047139, 1.02679253, -0.76686019, -0.72969633, 0.73425150],\n                    [-0.13592879, -1.03207183, -0.22554775, 0.74148071, 0.96601510],\n                ],\n                [\n                    [0.51595992, 0.49624804, 0.91145641, 0.49247262, 0.41002217],\n                    [-1.08001196, 1.55497086, -0.81963140, -0.45511565, -0.60269165],\n                    [0.05563145, -0.94318372, -1.17058158, -0.73568577, 0.57810956],\n                    [-0.40260276, -0.10309298, 1.12378800, -0.23510537, -0.73893374],\n          
          [-0.52712536, -0.00717016, -1.85051966, -1.50790560, 1.38335907],\n                ],\n            ]\n        ]\n    )\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    weight = np.array(\n        [\n            [\n                [\n                    [-0.19489679, -0.32377058, 0.21736273],\n                    [0.04095296, -0.21552679, -0.14626531],\n                    [-0.19359522, -0.00742865, -0.19832158],\n                ]\n            ],\n            [\n                [\n                    [0.29926914, 0.00931164, 0.26197660],\n                    [0.27611443, -0.15439281, -0.19027126],\n                    [-0.28909120, 0.30367029, -0.05168664],\n                ]\n            ],\n            [\n                [\n                    [-0.03155736, 0.17610769, 0.22111714],\n                    [0.22790670, -0.32897446, -0.03260243],\n                    [-0.10274851, -0.06903386, -0.19438276],\n                ]\n            ],\n            [\n                [\n                    [-0.24573688, -0.06723209, -0.21363299],\n                    [-0.02136187, -0.24994437, -0.18691199],\n                    [0.12189507, 0.29469389, 0.03398871],\n                ]\n            ],\n        ]\n    )\n    m = flow.nn.Conv2d(2, 4, 3, groups=2, bias=False)\n    m.weight = flow.nn.Parameter(flow.Tensor(weight), requires_grad=True)\n    m = m.to(device)\n    output = m(input)\n    np_out = np.array(\n        [\n            [\n                [\n                    [-0.21170563, 0.03652292, 0.25926736],\n                    [-0.19168918, 0.49044561, 0.25099146],\n                    [-1.02489340, 0.25361472, -0.51828313],\n                ],\n                [\n                    [0.23977707, -0.56090075, -0.19285655],\n                    [-0.17167747, 0.24558367, -0.30935860],\n                    [-0.33303234, 1.52472734, -0.49013454],\n                ],\n                [\n                    [-0.17137986, 1.21333742, 0.18988736],\n                    [0.31785482, -0.12121570, -0.18676008],\n                    [-0.10680684, -0.30298883, 0.41809759],\n                ],\n                [\n                    [-0.87821335, -0.51665992, -0.44061098],\n                    [0.74804580, 0.53107250, 0.50418228],\n                    [-0.00512899, -0.36455840, -0.23643512],\n                ],\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-3, 1e-3))\n    output = output.sum()\n    output.backward()\n    np_grad = np.array(\n        [\n            [\n                [\n                    [0.10437235, -0.21008658, 0.26925275, 0.16488039, 0.47933933],\n                    [0.42143974, -0.26293880, -0.12013602, -0.54157579, 0.14280275],\n                    [-0.06124666, -0.44938356, -0.55658901, -0.49534237, -0.10720548],\n                    [-0.16561902, -0.23929697, -0.82584178, -0.66022277, -0.58654481],\n                    [-0.48268640, -0.18644476, -0.43645298, 0.04623342, -0.25000823],\n                ],\n                [\n                    [-0.27729425, -0.16841865, -0.16093449, 0.11635975, 0.00748415],\n                    [-0.07074942, -0.54079264, -0.75282294, -0.68207347, -0.21203026],\n                    [-0.05160286, -0.29598606, -0.66841042, -0.61680746, -0.37242430],\n                    [0.22569139, -0.12756741, -0.50747585, -0.73316729, -0.37990844],\n                    [0.01914656, 0.24480659, 0.08441254, 0.06526598, -0.16039404],\n     
           ],\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-3, 1e-3))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestConv2d(flow.unittest.TestCase):\n    def test_conv2d_default_init(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            device = arg[0]\n            conv = flow.nn.Conv2d(1, 1, (3, 3), bias=True).to(flow.device(device))\n            test_case.assertTrue(\n                not np.allclose(\n                    conv.weight.numpy(), np.zeros((1, 1, 3, 3)), rtol=1e-9, atol=1e-10\n                )\n            )\n            test_case.assertTrue(\n                not np.allclose(\n                    conv.bias.numpy(), np.zeros((1,)), rtol=1e-9, atol=1e-10\n                )\n            )\n            conv = flow.nn.Conv2d(\n                1, 1, (3, 3), bias=True, device=device, dtype=flow.float32\n            )\n            test_case.assertTrue(\n                not np.allclose(\n                    conv.weight.numpy(), np.zeros((1, 1, 3, 3)), rtol=1e-9, atol=1e-10\n                )\n            )\n            test_case.assertTrue(\n                not np.allclose(\n                    conv.bias.numpy(), np.zeros((1,)), rtol=1e-9, atol=1e-10\n                )\n            )\n            conv = flow.nn.Conv2d(\n                1, 1, (3, 3), bias=True, device=device, dtype=flow.float16\n            )\n            test_case.assertTrue(\n                not np.allclose(\n                    conv.weight.numpy(), np.zeros((1, 1, 3, 3)), rtol=1e-9, atol=1e-10\n                )\n            )\n            test_case.assertTrue(\n                not np.allclose(\n                    conv.bias.numpy(), np.zeros((1,)), rtol=1e-9, atol=1e-10\n                )\n            )\n\n    @unittest.skip(\"skip for now, because it failed 8 times in the past week\")\n    @autotest(n=3)\n    def test_nn_functional_conv2d(test_case):\n        device = random_device()\n        img = torch.ones((1, 3, 224, 224), requires_grad=True).to(device)\n        kernel = torch.ones((3, 1, 3, 3), requires_grad=True).to(device)\n        y = torch.nn.functional.conv2d(input=img, weight=kernel, groups=3)\n        return y\n\n    @unittest.skipIf(\n        version.parse(torch_original.__version__) <= version.parse(\"1.13.0\"),\n        \"conv module doesn't support unbatched input in PyTorch before '1.13.0'\",\n    )\n    @autotest(n=3)\n    def test_nn_functional_conv2d_3dinput(test_case):\n        device = random_device()\n        img = torch.ones((3, 224, 224), requires_grad=True).to(device)\n        kernel = torch.ones((3, 1, 3, 3), requires_grad=True).to(device)\n        y = torch.nn.functional.conv2d(input=img, weight=kernel, groups=3)\n        return y\n\n    def test_conv2d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            device = arg[0]\n            conv = flow.nn.Conv2d(1, 3, (3, 3), bias=False).to(flow.device(device))\n            _test_conv2d(\n                test_case,\n                conv,\n                test_conv2d_data,\n                test_conv2d_weight,\n                test_conv2d_output,\n                device=device,\n            )\n\n    def test_conv2d_backward(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        os.environ[\"ONEFLOW_ENABLE_NHWC\"] = \"0\"
\n        for arg in GenArgList(arg_dict):\n            device = arg[0]\n            conv = flow.nn.Conv2d(1, 3, (3, 3), bias=False).to(flow.device(device))\n            _test_conv2d_backward(\n                test_case,\n                conv,\n                test_conv2d_data,\n                test_conv2d_weight,\n                test_conv2d_data_grad,\n                test_conv2d_weight_grad,\n                device=device,\n            )\n\n    # bias grad is not yet supported, so only the forward result is checked here\n    def test_conv2d_with_bias(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            device = arg[0]\n            conv = flow.nn.Conv2d(1, 3, (3, 3), bias=True).to(flow.device(device))\n            _test_conv2d(\n                test_case,\n                conv,\n                test_conv2d_with_bias_data,\n                test_conv2d_with_bias_weight,\n                test_conv2d_with_bias_output,\n                bias=test_conv2d_with_bias_bias,\n                device=device,\n            )\n\n    def test_conv2d_group(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            device = arg[0]\n            conv = flow.nn.Conv2d(2, 2, (3, 3), groups=2, bias=False).to(\n                flow.device(device)\n            )\n            _test_conv2d(\n                test_case,\n                conv,\n                test_conv2d_group_data,\n                test_conv2d_group_weight,\n                test_conv2d_group_output,\n                device=device,\n            )\n\n    def test_conv2d_group_backward(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            device = arg[0]\n            conv = flow.nn.Conv2d(2, 2, (3, 3), groups=2, bias=False).to(\n                flow.device(device)\n            )\n            _test_conv2d_backward(\n                test_case,\n                conv,\n                test_conv2d_group_data,\n                test_conv2d_group_weight,\n                test_conv2d_group_data_grad,\n                test_conv2d_group_weight_grad,\n                device=device,\n            )\n\n    def test_conv2d_padding(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            device = arg[0]\n            conv = flow.nn.Conv2d(1, 1, (3, 3), padding=(1, 2), bias=False).to(\n                flow.device(device)\n            )\n            _test_conv2d(\n                test_case,\n                conv,\n                test_conv2d_padding_data,\n                test_conv2d_padding_weight,\n                test_conv2d_padding_output,\n                device=device,\n            )\n\n    def test_conv2d_padding_backward(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            device = arg[0]\n            conv = flow.nn.Conv2d(1, 1, (3, 3), padding=(1, 2), bias=False).to(\n                flow.device(device)\n            )\n            _test_conv2d_backward(\n                test_case,\n                conv,\n                test_conv2d_padding_data,\n                test_conv2d_padding_weight,\n                test_conv2d_padding_data_grad,\n                test_conv2d_padding_weight_grad,\n                device=device,
\n                weight_atol=1e-3,\n            )\n\n    def test_conv2d_stride(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            device = arg[0]\n            conv = flow.nn.Conv2d(\n                1, 1, (3, 3), padding=(1, 1), stride=(2, 3), bias=False\n            ).to(flow.device(device))\n            _test_conv2d(\n                test_case,\n                conv,\n                test_conv2d_stride_data,\n                test_conv2d_stride_weight,\n                test_conv2d_stride_output,\n                device=device,\n            )\n\n    def test_conv2d_stride_backward(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            device = arg[0]\n            conv = flow.nn.Conv2d(\n                1, 1, (3, 3), padding=(1, 1), stride=(2, 3), bias=False\n            ).to(flow.device(device))\n            _test_conv2d_backward(\n                test_case,\n                conv,\n                test_conv2d_stride_data,\n                test_conv2d_stride_weight,\n                test_conv2d_stride_data_grad,\n                test_conv2d_stride_weight_grad,\n                device=device,\n            )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_conv2d_kernel(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            device = arg[0]\n            conv = flow.nn.Conv2d(1, 1, (3, 5), bias=False).to(flow.device(device))\n            conv.to(flow.device(\"cuda\"))\n            _test_conv2d(\n                test_case,\n                conv,\n                test_conv2d_kernel_data,\n                test_conv2d_kernel_weight,\n                test_conv2d_kernel_output,\n                device=device,\n            )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_conv2d_kernel_backward(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            device = arg[0]\n            conv = flow.nn.Conv2d(1, 1, (3, 5), bias=False).to(flow.device(device))\n            conv.to(flow.device(\"cuda\"))\n            _test_conv2d_backward(\n                test_case,\n                conv,\n                test_conv2d_kernel_data,\n                test_conv2d_kernel_weight,\n                test_conv2d_kernel_data_grad,\n                test_conv2d_kernel_weight_grad,\n                device=device,\n            )\n\n    def test_conv2d_dilation(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            device = arg[0]\n            conv = flow.nn.Conv2d(1, 1, (3, 3), dilation=(2, 3), bias=False).to(\n                flow.device(device)\n            )\n            _test_conv2d(\n                test_case,\n                conv,\n                test_conv2d_dilation_data,\n                test_conv2d_dilation_weight,\n                test_conv2d_dilation_output,\n                device=device,\n            )\n\n    def test_conv2d_dilation_backward(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            device = arg[0]
\n            conv = flow.nn.Conv2d(1, 1, (3, 3), dilation=(2, 3), bias=False).to(\n                flow.device(device)\n            )\n            _test_conv2d_backward(\n                test_case,\n                conv,\n                test_conv2d_dilation_data,\n                test_conv2d_dilation_weight,\n                test_conv2d_dilation_data_grad,\n                test_conv2d_dilation_weight_grad,\n                device=device,\n            )\n\n    def test_large_in_channel_group_conv(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_conv2d_large_in_channel,\n        ]\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_large_out_channel_group_conv(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_conv2d_large_out_channel,\n        ]\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5, rtol=1e-2, atol=1e-2)\n    def test_conv2d_with_random_data(test_case):\n        channels = random(1, 6)\n        m = torch.nn.Conv2d(\n            in_channels=channels,\n            out_channels=random(1, 20),\n            kernel_size=random(1, 4),\n            stride=random(1, 4) | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(1, 3) | nothing(),\n            groups=random(1, 5) | nothing(),\n            padding_mode=constant(\"zeros\") | nothing(),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=4, dim1=channels).to(device)\n        y = m(x)\n        return y\n\n    @unittest.skipIf(\n        version.parse(torch_original.__version__) <= version.parse(\"1.13.0\"),\n        \"conv module doesn't support unbatched input in PyTorch before '1.13.0'\",\n    )\n    @autotest(n=5, rtol=1e-3, atol=1e-3)\n    def test_conv2d_auto_squeeze_with_random_data(test_case):\n        channels = random(1, 6)\n        m = torch.nn.Conv2d(\n            in_channels=channels,\n            out_channels=random(1, 20),\n            kernel_size=random(1, 4),\n            stride=random() | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(1, 5) | nothing(),\n            groups=random(1, 5) | nothing(),\n            padding_mode=constant(\"zeros\") | nothing(),\n            bias=random_bool(),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=3, dim0=channels).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5, check_graph=False)\n    def test_conv2d_0size_with_random_data(test_case):\n        channels = random(1, 6)\n        m = torch.nn.Conv2d(\n            in_channels=channels,\n            out_channels=random(1, 20),\n            kernel_size=random(1, 4),\n            stride=random() | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(1, 5) | nothing(),\n            groups=random(1, 5) | nothing(),\n            padding_mode=constant(\"zeros\") | nothing(),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=4, dim0=0, dim1=channels).to(device)\n        y = m(x)\n        return y\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")
\n    @autotest(n=5, check_allclose=False)\n    def test_conv2d_group_with_random_data(test_case):\n        channels = 720  # divisible by 1, 2, 3, 4, 5 and 6, i.e. every possible groups value below\n        m = torch.nn.Conv2d(\n            in_channels=channels,\n            out_channels=channels,\n            kernel_size=random(1, 4),\n            stride=random() | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(1, 5) | nothing(),\n            groups=random(1, 7),\n            padding_mode=constant(\"zeros\") | nothing(),\n        )\n        m.train(random())\n\n        device = random_device()\n        m.to(device)\n        m.pytorch.to(\"cuda\")\n        x = random_tensor(ndim=4, dim1=channels).to(device)\n        x.pytorch = x.pytorch.to(\"cuda\")\n        y = m(x)\n        return y\n\n    @unittest.skip(\"skip for now, because it failed 6 times in the past week\")\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_conv2d_NHWC_with_random_data(test_case):\n        in_channels = np.random.randint(6, 33)\n        out_channels = np.random.randint(32, 66)\n        kernel_size = np.random.randint(1, 5)\n        stride = np.random.randint(1, 2)\n        padding = np.random.randint(1, 3)\n        dilation = np.random.randint(1, 3)\n        spatial = np.random.randint(6, 64)\n\n        np_x = np.random.randn(4, in_channels, spatial, spatial).astype(np.float32)\n        np_weight = np.random.randn(\n            out_channels, in_channels, kernel_size, kernel_size\n        ).astype(np.float32)\n        np_bias = np.random.randn(out_channels).astype(np.float32)\n\n        flow_nchw_input = flow.tensor(\n            np_x, device=\"cuda\", dtype=flow.float32, requires_grad=True\n        )\n        flow_nchw_weights = flow.nn.Parameter(\n            flow.tensor(\n                np_weight, device=\"cuda\", dtype=flow.float32, requires_grad=True\n            )\n        )\n        flow_nchw_bias = flow.nn.Parameter(\n            flow.tensor(np_bias, device=\"cuda\", dtype=flow.float32, requires_grad=True)\n        )\n\n        flow_nchw_conv = flow.nn.Conv2d(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n            dilation=dilation,\n        ).to(\"cuda\")\n        flow_nchw_conv.weight = flow_nchw_weights\n        flow_nchw_conv.bias = flow_nchw_bias\n\n        flow_nchw_out = flow_nchw_conv(flow_nchw_input)\n\n        os.environ[\"ONEFLOW_ENABLE_NHWC\"] = \"1\"\n        flow_nhwc_input = flow.tensor(\n            np_x, device=\"cuda\", dtype=flow.float32, requires_grad=True\n        )\n        flow_nhwc_permuted_input = flow.permute(flow_nhwc_input, (0, 2, 3, 1))\n        flow_nhwc_weights = flow.tensor(\n            np_weight, device=\"cuda\", dtype=flow.float32, requires_grad=True\n        )\n        flow_nhwc_permuted_weights = flow.nn.Parameter(\n            flow.permute(flow_nhwc_weights, (0, 2, 3, 1))\n        )\n        flow_nhwc_bias = flow.nn.Parameter(\n            flow.tensor(np_bias, device=\"cuda\", dtype=flow.float32, requires_grad=True)\n        )\n\n        flow_nhwc_conv = flow.nn.Conv2d(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n            dilation=dilation,\n        ).to(\"cuda\")\n        flow_nhwc_conv.weight = flow_nhwc_permuted_weights
\n        flow_nhwc_conv.bias = flow_nhwc_bias\n\n        flow_nhwc_out = flow_nhwc_conv(flow_nhwc_permuted_input)\n        flow_nhwc_permuted_out = flow.permute(flow_nhwc_out, (0, 3, 1, 2))\n\n        test_case.assertTrue(\n            np.allclose(\n                flow_nchw_out.numpy(),\n                flow_nhwc_permuted_out.numpy(),\n                rtol=1e-4,\n                atol=1e-4,\n            )\n        )\n\n        total_out = flow_nchw_out + flow_nhwc_permuted_out\n\n        total_out = total_out.sum()\n        total_out.backward()\n        test_case.assertTrue(\n            np.allclose(\n                flow_nchw_weights.grad.numpy(),\n                np.transpose(flow_nhwc_permuted_weights.grad.numpy(), (0, 3, 1, 2)),\n                rtol=1e-3,\n                atol=1e-4,\n            )\n        )\n        test_case.assertTrue(\n            np.allclose(\n                flow_nchw_input.grad.numpy(),\n                flow_nhwc_input.grad.numpy(),\n                rtol=1e-4,\n                atol=1e-4,\n            )\n        )\n        os.environ[\"ONEFLOW_ENABLE_NHWC\"] = \"0\"\n\n    @profile(torch.nn.functional.conv2d)\n    def profile_conv2d(test_case):\n        input = torch.ones(8, 128, 28, 28)\n        weight_128c = torch.ones(128, 128, 3, 3)\n        weight_128c_2g = torch.ones(128, 64, 3, 3)\n        weight_1x1_128c = torch.ones(128, 128, 1, 1)\n        weight_5x5_128c = torch.ones(128, 128, 5, 5)\n        bias = torch.ones(128)\n        torch.nn.functional.conv2d(input, weight_128c, padding=1)\n        torch.nn.functional.conv2d(input, weight_128c_2g, groups=2, padding=1)\n        torch.nn.functional.conv2d(input, weight_128c, padding=1, stride=2)\n        torch.nn.functional.conv2d(input, weight_128c, bias=bias, padding=1)\n        torch.nn.functional.conv2d(input, weight_128c, bias=bias, padding=1, stride=2)\n        torch.nn.functional.conv2d(input, weight_1x1_128c)\n        torch.nn.functional.conv2d(input, weight_1x1_128c, stride=2)\n        torch.nn.functional.conv2d(input, weight_1x1_128c, bias=bias)\n        torch.nn.functional.conv2d(input, weight_1x1_128c, bias=bias, stride=2)\n        torch.nn.functional.conv2d(input, weight_5x5_128c, padding=2)\n        torch.nn.functional.conv2d(input, weight_5x5_128c, padding=2, stride=2)\n        torch.nn.functional.conv2d(input, weight_5x5_128c, bias=bias, padding=2)\n        torch.nn.functional.conv2d(\n            input, weight_5x5_128c, bias=bias, padding=2, stride=2\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_copy.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch as ori_torch\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass Test_Copy_module(flow.unittest.TestCase):\n    def test_copy_broadcast_tensor(test_case):\n        torch_base_grid = ori_torch.zeros(1, 2, 2, 3)\n        flow_base_grid = flow.zeros(1, 2, 2, 3)\n        torch_x_grid = ori_torch.ones(2)\n        flow_x_grid = flow.ones(2)\n        torch_base_grid[..., 0].copy_(torch_x_grid)\n        flow_base_grid[..., 0].copy_(flow_x_grid)\n        test_case.assertTrue(\n            np.allclose(torch_base_grid.numpy(), flow_base_grid.numpy())\n        )\n\n    def test_non_contiguous_sliced_tensor_copy(test_case):\n        torch_tensor = torch.arange(24, dtype=torch.float32).reshape(1, 2, 3, 4)\n        flow_tensor = flow.arange(24, dtype=flow.float32).reshape(1, 2, 3, 4)\n        torch_copy = torch.tensor([3.1415])\n        flow_copy = flow.tensor([3.1415])\n        torch_tensor[:, 1:2, 1:2, ::2].copy_(torch_copy)\n        flow_tensor[:, 1:2, 1:2, ::2].copy_(flow_copy)\n        test_case.assertTrue(np.allclose(flow_tensor.numpy(), torch_tensor.numpy()))\n\n    def test_non_contiguous_permuted_tensor_copy(test_case):\n        torch_tensor = torch.arange(24, dtype=torch.float32).reshape(1, 2, 3, 4)\n        flow_tensor = flow.arange(24, dtype=flow.float32).reshape(1, 2, 3, 4)\n        torch_copy = torch.tensor([3.1415])\n        flow_copy = flow.tensor([3.1415])\n        torch_tensor.permute(0, 2, 1, 3).copy_(torch_copy)\n        flow_tensor.permute(0, 2, 1, 3).copy_(flow_copy)\n        test_case.assertTrue(np.allclose(flow_tensor.numpy(), torch_tensor.numpy()))\n\n    def test_copy_fp16(test_case):\n        x = flow.tensor([1, 2], dtype=flow.float16)\n        a = np.array([0, 9], dtype=np.float16)\n        x.copy_(a)\n        test_case.assertTrue(np.array_equal(x.numpy(), a))\n\n    def test_tensor_inplace_copy_with_diff_dtype(test_case):\n        x = flow.randn(4, 12).to(flow.int)\n        y = flow.randn(4, 12)\n        y.copy_(x)\n        test_case.assertTrue(np.array_equal(y.numpy(), x.numpy()))\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_tensor_inplace_copy_with_diff_dtype_and_device(test_case):\n        x = flow.randn(4, 12).to(flow.int)\n        y = flow.randn(4, 12).to(\"cuda\")\n        y.copy_(x)\n        test_case.assertTrue(np.array_equal(y.numpy(), x.numpy()))\n\n    @unittest.skip(\"skip for now, becase it failed 6 times in past week\")\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_global_tensor_inplace_copy_with_diff_dtype_and_device(test_case):\n        x = (\n            flow.randn(4, 12)\n            .to(flow.int)\n            
.to_global(placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.broadcast)\n        )\n        y = flow.randn(4, 12).to_global(\n            placement=flow.placement.all(\"cuda\"), sbp=flow.sbp.broadcast\n        )\n        y.copy_(x)\n        test_case.assertTrue(np.array_equal(y.numpy(), x.numpy()))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_cosine_similarity.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCosineSimilarity(flow.unittest.TestCase):\n    @autotest(n=3)\n    def test_cosine_similartiy_module_with_random_data(test_case):\n        device = random_device()\n        a = random_tensor(ndim=2, dim0=10, dim1=128).to(device)\n        b = random_tensor(ndim=2, dim0=10, dim1=128).to(device)\n        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6).to(device)\n        cos.train(random())\n        output = cos(a, b)\n        return output\n\n    @autotest(n=3)\n    def test_cosine_similartiy_functional_with_random_data(test_case):\n        device = random_device()\n        a = random_tensor(ndim=2, dim0=10, dim1=128).to(device)\n        b = random_tensor(ndim=2, dim0=10, dim1=128).to(device)\n        output = torch.nn.functional.cosine_similarity(a, b, dim=1, eps=1e-6)\n        return output\n\n    @unittest.skip(\"skip for now, becase it failed 4 times in past week\")\n    @autotest(n=3)\n    def test_cosine_similartiy_broadcast_with_random_data(test_case):\n        device = random_device()\n        a = random_tensor(ndim=2, dim0=10, dim1=128).to(device)\n        b = random_tensor(ndim=2, dim0=1, dim1=128).to(device)\n        output = torch.nn.functional.cosine_similarity(a, b, dim=1, eps=1e-6)\n        return output\n\n    @autotest(n=3)\n    def test_cosine_similartiy_module_with_nonequal_dim_data(test_case):\n        device = random_device()\n        a = random_tensor(ndim=2, dim0=10, dim1=128).to(device)\n        b = random_tensor(ndim=3, dim0=10, dim1=10, dim2=128).to(device)\n        cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6).to(device)\n        cos.train(random())\n        output = cos(a, b)\n        return output\n\n    @unittest.skip(\n        reason=\"https://github.com/Oneflow-Inc/oneflow/issues/8881#issuecomment-1229682453\"\n    )\n    @profile(torch.nn.functional.cosine_similarity)\n    def profile_cosine_similarity(test_case):\n        input1 = torch.ones(100, 128)\n        input2 = torch.ones(100, 128)\n        torch.nn.functional.cosine_similarity(input1, input2)\n        torch.nn.functional.cosine_similarity(input1, input2, dim=0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_ctc_greedy_decoder.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nninf = -float(\"inf\")\n\n\ndef log_softmax(logits, axis=0):\n    max_value = np.max(logits, axis, keepdims=True)\n    exp = np.exp(logits - max_value)\n    exp_sum = np.sum(exp, axis, keepdims=True)\n    dist = exp / exp_sum\n    return np.log(dist)\n\n\ndef np_ctc_greedy_decoder(log_probs, input_lengths, merge_repeated=True):\n    blank_label = log_probs.shape[2] - 1\n    decodes = np.zeros(\n        (log_probs.shape[1], log_probs.shape[0]), dtype=input_lengths.dtype\n    )\n    neg_sum_logits = np.zeros((input_lengths.size, 1), dtype=log_probs.dtype)\n    for b in range(input_lengths.size):\n        input_length = input_lengths[b]\n        prev_indices = -1\n        t_dec = 0\n        for t in range(input_length):\n            max_indice = np.argmax(log_probs[t, b, :])\n            neg_sum_logits[b, 0] -= log_probs[t, b, max_indice]\n            if max_indice != blank_label and (\n                not (merge_repeated and max_indice == prev_indices)\n            ):\n                decodes[b, t_dec] = max_indice\n                t_dec += 1\n            prev_indices = max_indice\n    return (decodes, neg_sum_logits)\n\n\ndef compare_with_np(\n    device_type, data_type, max_input_length, batch_size, num_classes, merge_repeated,\n):\n    assert data_type in [\"float32\", \"double\"]\n    assert device_type in [\"cpu\", \"cuda\"]\n    assert merge_repeated in [False, True]\n\n    log_probs = np.random.random(\n        size=(max_input_length, batch_size, num_classes)\n    ).astype(np.float32)\n    log_probs = log_softmax(log_probs, axis=2)\n    input_lengths = np.random.randint(\n        max_input_length / 2, high=max_input_length, size=(batch_size,), dtype=np.int64\n    )\n    (np_decoded, np_neg_sum_logits) = np_ctc_greedy_decoder(\n        log_probs, input_lengths, merge_repeated\n    )\n\n    log_probs = flow.tensor(\n        log_probs,\n        dtype=flow.float32,\n        requires_grad=False,\n        device=flow.device(device_type),\n    )\n\n    input_lengths = flow.tensor(\n        input_lengths,\n        dtype=flow.int64,\n        requires_grad=False,\n        device=flow.device(device_type),\n    )\n\n    (of_decoded, of_neg_sum_logits) = flow.nn.functional.ctc_greedy_decoder(\n        log_probs, input_lengths, merge_repeated\n    )\n    np.allclose(of_decoded.numpy(), np_decoded, atol=1e-05)\n    np.allclose(of_neg_sum_logits.numpy(), np_neg_sum_logits, atol=1e-05)\n\n\ndef gen_arg_list():\n    arg_dict = OrderedDict()\n    arg_dict[\"device_type\"] = [\"cpu\", \"cuda\"]\n    arg_dict[\"data_type\"] = [\"float32\"]\n    arg_dict[\"max_input_length\"] = [20]\n    arg_dict[\"batch_size\"] = [4]\n    arg_dict[\"num_classes\"] = [5]\n    arg_dict[\"merge_repeated\"] = [False, True]\n   
 return GenArgList(arg_dict)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCTCGreedyDecoder1n1d(flow.unittest.TestCase):\n    def test_ctc_greedy_decoder(test_case):\n        for arg in gen_arg_list():\n            compare_with_np(*arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_ctc_loss.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCTCLoss1n1d(flow.unittest.TestCase):\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    # This test case can always success out of ci container, but will get error in ci container for unknown reason: error:\n    # 'oneflow.ctc_loss' op attribute 'blank' failed to satisfy constraint: 32-bit signed integer attribute\n    # loc(\"-\":0:0): error: Failed to run round-trip passes\n    @autotest(n=5, check_graph=False)\n    def test_ctc_loss_with_diff_device_input(test_case):\n        log_probs = torch.tensor(\n            [\n                [[-1.1031, -0.7998, -1.5200], [-0.9808, -1.1363, -1.1908]],\n                [[-1.2258, -1.0665, -1.0153], [-1.1135, -1.2331, -0.9671]],\n                [[-1.3348, -0.6611, -1.5118], [-0.9823, -1.2355, -1.0941]],\n                [[-1.3850, -1.3273, -0.7247], [-0.8235, -1.4783, -1.0994]],\n                [[-0.9049, -0.8867, -1.6962], [-1.4938, -1.3630, -0.6547]],\n            ],\n            dtype=torch.float32,\n            requires_grad=True,\n        )\n        targets = torch.tensor([[1, 2, 2], [1, 2, 2]], dtype=torch.int32, device=\"cuda\")\n        input_lengths = torch.tensor([5, 5], dtype=torch.int32)\n        target_lengths = torch.tensor([3, 3], dtype=torch.int32)\n        loss_mean = torch.nn.CTCLoss(reduction=oneof(\"mean\", \"none\", \"sum\", nothing()))\n        out = loss_mean(log_probs, targets, input_lengths, target_lengths)\n        return out\n\n    @unittest.skip(\"skip for now, becase it failed 10 times in past week\")\n    @autotest(n=5, check_graph=False)\n    def test_ctc_loss_functional(test_case):\n        device_random = random_device()\n        log_probs = random_tensor(ndim=3, dim0=5, dim1=2, dim2=3).to(device_random)\n        targets = random_tensor(ndim=2, dim0=2, dim1=3, low=1, high=3, dtype=int).to(\n            device_random\n        )\n        input_lengths = torch.tensor([5, 5], dtype=torch.int32)\n        target_lengths = torch.tensor([3, 3], dtype=torch.int32)\n        out = torch.nn.functional.ctc_loss(\n            log_probs,\n            targets,\n            input_lengths,\n            target_lengths,\n            reduction=oneof(\"mean\", \"none\", \"sum\", nothing()),\n        )\n        return out\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_cublas_fused_mlp.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\n\n\ndef _matmul_bias_relu(x, weight, bias, skip_activation):\n    out = flow._C.bias_add(flow._C.matmul(x, weight, transpose_b=True), bias, axis=1)\n    if not skip_activation:\n        out = flow._C.relu(out)\n    return out\n\n\ndef _test_fused_matmul_bias_add_relu(\n    test_case,\n    batchsize,\n    in_feature,\n    hidden_size_list,\n    out_feature,\n    skip_final_activation,\n    dtype,\n    device,\n):\n    x = np.random.uniform(low=-1, high=1, size=(batchsize, in_feature))\n\n    fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True)\n    naive_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True)\n\n    fused_weight_list = []\n    naive_weight_list = []\n    fused_bias_list = []\n    naive_bias_list = []\n\n    hidden_num = len(hidden_size_list)\n\n    if hidden_num != 0:\n        np_first_weight = np.random.uniform(\n            low=-1, high=1, size=(hidden_size_list[0], in_feature)\n        )\n        np_first_bias = np.random.uniform(low=-1, high=1, size=hidden_size_list[0])\n\n        fused_weight_list.append(\n            flow.tensor(np_first_weight, dtype=dtype, device=device, requires_grad=True)\n        )\n        fused_bias_list.append(\n            flow.tensor(np_first_bias, dtype=dtype, device=device, requires_grad=True)\n        )\n        naive_weight_list.append(\n            flow.tensor(np_first_weight, dtype=dtype, device=device, requires_grad=True)\n        )\n        naive_bias_list.append(\n            flow.tensor(np_first_bias, dtype=dtype, device=device, requires_grad=True)\n        )\n\n    for idx in range(1, hidden_num):\n        np_weight = np.random.uniform(\n            low=-1, high=1, size=(hidden_size_list[idx], hidden_size_list[idx - 1])\n        )\n        np_bias = np.random.uniform(low=-1, high=1, size=hidden_size_list[idx])\n\n        fused_weight_list.append(\n            flow.tensor(np_weight, dtype=dtype, device=device, requires_grad=True)\n        )\n        fused_bias_list.append(\n            flow.tensor(np_bias, dtype=dtype, device=device, requires_grad=True)\n        )\n        naive_weight_list.append(\n            flow.tensor(np_weight, dtype=dtype, device=device, requires_grad=True)\n        )\n        naive_bias_list.append(\n            flow.tensor(np_bias, dtype=dtype, device=device, requires_grad=True)\n        )\n\n    np_final_weight = np.random.uniform(low=-1, high=1, size=(out_feature, in_feature))\n\n    if hidden_num != 0:\n        np_final_weight = np.random.uniform(\n            low=-1, high=1, size=(out_feature, hidden_size_list[-1])\n        )\n\n    np_final_bias = np.random.uniform(low=-1, high=1, size=(out_feature))\n\n    fused_weight_list.append(\n        flow.tensor(np_final_weight, dtype=dtype, device=device, 
requires_grad=True)\n    )\n    fused_bias_list.append(\n        flow.tensor(np_final_bias, dtype=dtype, device=device, requires_grad=True)\n    )\n    naive_weight_list.append(\n        flow.tensor(np_final_weight, dtype=dtype, device=device, requires_grad=True)\n    )\n    naive_bias_list.append(\n        flow.tensor(np_final_bias, dtype=dtype, device=device, requires_grad=True)\n    )\n\n    fused_out = flow._C.fused_mlp(\n        fused_x,\n        fused_weight_list,\n        fused_bias_list,\n        skip_final_activation=skip_final_activation,\n    )\n\n    naive_out = _matmul_bias_relu(\n        naive_x,\n        naive_weight_list[0],\n        naive_bias_list[0],\n        False if hidden_num != 0 else skip_final_activation,\n    )\n\n    for idx in range(1, hidden_num + 1):\n        if idx == hidden_num:\n            naive_out = _matmul_bias_relu(\n                naive_out,\n                naive_weight_list[idx],\n                naive_bias_list[idx],\n                skip_final_activation,\n            )\n        else:\n            naive_out = _matmul_bias_relu(\n                naive_out, naive_weight_list[idx], naive_bias_list[idx], False\n            )\n\n    total_out = fused_out.sum() + naive_out.sum()\n    total_out.backward()\n\n    # Test output equality\n    test_case.assertTrue(\n        np.allclose(fused_out.numpy(), naive_out.numpy(), atol=1e-4, rtol=1e-4)\n    )\n    # Test weight grad equality\n    for idx in range(hidden_num + 1):\n        test_case.assertTrue(\n            np.allclose(\n                fused_weight_list[idx].grad.numpy(),\n                naive_weight_list[idx].grad.numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n        test_case.assertTrue(\n            np.allclose(\n                fused_bias_list[idx].grad.numpy(),\n                naive_bias_list[idx].grad.numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n    # Test dx equality\n    test_case.assertTrue(\n        np.allclose(fused_x.grad.numpy(), naive_x.grad.numpy(), atol=1e-4, rtol=1e-4)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFusedMatmulBiasAddRelu(flow.unittest.TestCase):\n    def test_fused_matmul_op(test_case):\n        args_dict = OrderedDict()\n        args_dict[\"test_fun\"] = [_test_fused_matmul_bias_add_relu]\n        args_dict[\"batchsize\"] = [1, 2, 4]\n        args_dict[\"in_feature\"] = [96, 128]\n        args_dict[\"hidden_size_list\"] = [[256, 512], [256], [96, 144], []]\n        args_dict[\"out_feature\"] = [512, 1024, 288, 1]\n        args_dict[\"skip_final_activation\"] = [True, False]\n        args_dict[\"dtype\"] = [flow.float32, flow.float64]\n        args_dict[\"device\"] = [\"cuda\", \"cpu\"]\n\n        for arg in GenArgList(args_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_cum_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nimport torch as ori_torch\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCumOp(flow.unittest.TestCase):\n    @autotest(n=5, check_graph=True)\n    def test_cumsum(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        dim = random(0, x.ndim.pytorch).to(int)\n        z = torch.cumsum(x, dim)\n        return z\n\n    @autotest(n=5, check_graph=True)\n    def test_cumprod(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        dim = random(0, x.ndim.pytorch).to(int)\n        y = torch.cumprod(x, dim)\n        return y\n\n    def test_cumop_with_dtype(test_case):\n        x = flow.tensor([2, 3, 4])\n        cumsum_res = flow.cumsum(x, dim=0, dtype=flow.float)\n        cumprod_res = flow.cumprod(x, dim=0, dtype=flow.float)\n        test_case.assertEqual(cumsum_res.dtype, flow.float)\n        test_case.assertEqual(cumprod_res.dtype, flow.float)\n\n    @autotest(n=5, check_graph=True)\n    def test_cumsum(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        dim = random(0, x.ndim.pytorch).to(int)\n        y = x.cumsum(dim)\n        return y\n\n    @autotest(n=5, check_graph=True)\n    def test_cumprod_with_user_dy(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        dim = random(0, x.ndim.pytorch).to(int)\n        y = torch.cumprod(x, dim)\n        z = y * 2\n        return z\n\n    def test_cumprod_with_zero(test_case):\n        np_arr = np.ones((5, 5))\n        np_arr_grad = np_arr\n        np_arr[2][3] = 0\n        np_arr[4][3] = 0\n        of_tensor = flow.tensor(np_arr, dtype=flow.float, requires_grad=True)\n        of_res = of_tensor.cumprod(dim=0)\n        of_res.backward(flow.tensor(np_arr_grad, dtype=flow.float))\n\n        torch_tensor = ori_torch.tensor(\n            np_arr, dtype=ori_torch.float, requires_grad=True\n        )\n        torch_res = torch_tensor.cumprod(dim=0)\n        torch_res.backward(ori_torch.tensor(np_arr_grad, dtype=ori_torch.float))\n        test_case.assertTrue(\n            np.allclose(\n                of_tensor.grad.numpy(),\n                torch_tensor.grad.numpy(),\n                rtol=0.0001,\n                atol=1e-05,\n            )\n        )\n\n    def test_cumsum_graph_backward(test_case):\n        class CustomizedModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.layer = flow.nn.Linear(5, 5)\n\n            def forward(self, input):\n                layer_out = self.layer(input)\n                loss = flow.cumsum(layer_out, -1)\n                loss = loss.sum()\n                loss.backward()\n                return 
loss\n\n        class TestCumsum(flow.nn.Graph):\n            def __init__(self) -> None:\n                super().__init__()\n                self.my_module = CustomizedModule()\n                self.add_optimizer(\n                    flow.optim.SGD(self.my_module.parameters(), lr=0.1, momentum=0.0)\n                )\n\n            def build(self, ids):\n                loss = self.my_module(ids)\n                return loss\n\n        ids = np.random.randint(0, 10, (5, 5), dtype=np.int64)\n        ids_tensor = flow.tensor(ids, dtype=flow.float, requires_grad=False)\n        graph = TestCumsum()\n        loss = graph(ids_tensor)\n\n    @profile(torch.cumsum)\n    def profile_cumsum(test_case):\n        input = torch.ones(100, 1280)\n        torch.cumsum(input, dim=0)\n        torch.cumsum(input, dim=1)\n\n    @profile(torch.cumprod)\n    def profile_cumprod(test_case):\n        input = torch.ones(100, 1280)\n        torch.cumprod(input, dim=0)\n        torch.cumprod(input, dim=1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_dataset.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport math\nimport os\nimport unittest\n\nimport cv2\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestOFRecordModule(flow.unittest.TestCase):\n    def test_record(test_case):\n        batch_size = 1\n        color_space = \"RGB\"\n        height = 224\n        width = 224\n        output_layout = \"NCHW\"\n        rgb_mean = [123.68, 116.779, 103.939]\n        rgb_std = [58.393, 57.12, 57.375]\n        record_reader = flow.nn.OFRecordReader(\n            flow.unittest.dataset_dir(\"imagenette/ofrecord\"),\n            batch_size=batch_size,\n            data_part_num=1,\n            part_name_suffix_length=5,\n            shuffle_after_epoch=False,\n        )\n        record_image_decoder = flow.nn.OFRecordImageDecoder(\n            \"encoded\", color_space=color_space\n        )\n        record_label_decoder = flow.nn.OFRecordRawDecoder(\n            \"class/label\", shape=(), dtype=flow.int32\n        )\n        resize = flow.nn.image.Resize(\n            resize_side=\"shorter\", keep_aspect_ratio=True, target_size=256\n        )\n        crop_mirror_normal = flow.nn.CropMirrorNormalize(\n            color_space=color_space,\n            output_layout=output_layout,\n            crop_h=height,\n            crop_w=width,\n            crop_pos_y=0.5,\n            crop_pos_x=0.5,\n            mean=rgb_mean,\n            std=rgb_std,\n            output_dtype=flow.float,\n        )\n        val_record = record_reader()\n        label = record_label_decoder(val_record)\n        image_raw_buffer = record_image_decoder(val_record)\n        image_raw_buffer_nd = image_raw_buffer.numpy()\n        gt_np = cv2.imread(\n            flow.unittest.dataset_dir(\"imagenette/ofrecord/gt_tensor_buffer_image.png\")\n        )\n        test_case.assertTrue(np.array_equal(image_raw_buffer_nd[0], gt_np))\n        image = resize(image_raw_buffer)[0]\n        resized_image_raw_buffer_nd = image.numpy()\n        gt_np = cv2.imread(\n            flow.unittest.dataset_dir(\n                \"imagenette/ofrecord/gt_tensor_buffer_resized_image.png\"\n            )\n        )\n        test_case.assertTrue(np.array_equal(resized_image_raw_buffer_nd[0], gt_np))\n        image = crop_mirror_normal(image)\n        image_np = image.numpy()\n        image_np = np.squeeze(image_np)\n        image_np = np.transpose(image_np, (1, 2, 0))\n        image_np = image_np * rgb_std + rgb_mean\n        image_np = cv2.cvtColor(np.float32(image_np), cv2.COLOR_RGB2BGR)\n        image_np = image_np.astype(np.uint8)\n        gt_np = cv2.imread(\n            flow.unittest.dataset_dir(\"imagenette/ofrecord/gt_val_image.png\")\n        )\n        test_case.assertEqual(label.numpy(), 5)\n        test_case.assertTrue(np.array_equal(image_np, gt_np))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGlobalOFRecordModule(flow.unittest.TestCase):\n    
def test_global_record(test_case):\n        batch_size = 1\n        color_space = \"RGB\"\n        height = 224\n        width = 224\n        output_layout = \"NCHW\"\n        rgb_mean = [123.68, 116.779, 103.939]\n        rgb_std = [58.393, 57.12, 57.375]\n        record_reader = flow.nn.OfrecordReader(\n            flow.unittest.dataset_dir(\"imagenette/ofrecord\"),\n            batch_size=batch_size,\n            data_part_num=1,\n            part_name_suffix_length=5,\n            shuffle_after_epoch=False,\n            placement=flow.placement(\"cpu\", ranks=[0]),\n            sbp=[flow.sbp.split(0)],\n        )\n        record_image_decoder = flow.nn.OFRecordImageDecoder(\n            \"encoded\", color_space=color_space\n        )\n        record_label_decoder = flow.nn.OfrecordRawDecoder(\n            \"class/label\", shape=(), dtype=flow.int32\n        )\n        resize = flow.nn.image.Resize(\n            resize_side=\"shorter\", keep_aspect_ratio=True, target_size=256\n        )\n        flip = flow.nn.CoinFlip(\n            batch_size=batch_size,\n            placement=flow.placement(\"cpu\", ranks=[0]),\n            sbp=[flow.sbp.split(0)],\n        )\n        crop_mirror_normal = flow.nn.CropMirrorNormalize(\n            color_space=color_space,\n            output_layout=output_layout,\n            crop_h=height,\n            crop_w=width,\n            crop_pos_y=0.5,\n            crop_pos_x=0.5,\n            mean=rgb_mean,\n            std=rgb_std,\n            output_dtype=flow.float,\n        )\n        rng = flip()\n        val_record = record_reader()\n        label = record_label_decoder(val_record)\n        image_raw_buffer = record_image_decoder(val_record)\n        image_raw_buffer_nd = image_raw_buffer.to_local().numpy()\n        gt_np = cv2.imread(\n            flow.unittest.dataset_dir(\"imagenette/ofrecord/gt_tensor_buffer_image.png\")\n        )\n        test_case.assertTrue(np.array_equal(image_raw_buffer_nd[0], gt_np))\n        image = resize(image_raw_buffer)[0]\n        resized_image_raw_buffer_nd = image.to_local().numpy()\n        gt_np = cv2.imread(\n            flow.unittest.dataset_dir(\n                \"imagenette/ofrecord/gt_tensor_buffer_resized_image.png\"\n            )\n        )\n        test_case.assertTrue(np.array_equal(resized_image_raw_buffer_nd[0], gt_np))\n        image = crop_mirror_normal(image)\n        image_np = image.to_local().numpy()\n        image_np = np.squeeze(image_np)\n        image_np = np.transpose(image_np, (1, 2, 0))\n        image_np = image_np * rgb_std + rgb_mean\n        image_np = cv2.cvtColor(np.float32(image_np), cv2.COLOR_RGB2BGR)\n        image_np = image_np.astype(np.uint8)\n        gt_np = cv2.imread(\n            flow.unittest.dataset_dir(\"imagenette/ofrecord/gt_val_image.png\")\n        )\n        test_case.assertEqual(label.to_local().numpy(), 5)\n        test_case.assertTrue(np.array_equal(image_np, gt_np))\n\n\ncoco_dict = dict()\n\n\ndef _coco(anno_file):\n    global coco_dict\n    if anno_file not in coco_dict:\n        from pycocotools.coco import COCO\n\n        coco_dict[anno_file] = COCO(anno_file)\n    return coco_dict[anno_file]\n\n\ndef _get_coco_image_samples(anno_file, image_dir, image_ids):\n    coco = _coco(anno_file)\n    category_id_to_contiguous_id_map = _get_category_id_to_contiguous_id_map(coco)\n    (image, image_size) = _read_images_with_cv(coco, image_dir, image_ids)\n    bbox = _read_bbox(coco, image_ids)\n    label = _read_label(coco, image_ids, category_id_to_contiguous_id_map)\n 
   img_segm_poly_list = _read_segm_poly(coco, image_ids)\n    (poly, poly_index) = _segm_poly_list_to_tensor(img_segm_poly_list)\n    samples = []\n    for (im, ims, b, l, p, pi) in zip(image, image_size, bbox, label, poly, poly_index):\n        samples.append(\n            dict(image=im, image_size=ims, bbox=b, label=l, poly=p, poly_index=pi)\n        )\n    return samples\n\n\ndef _get_category_id_to_contiguous_id_map(coco):\n    return {v: i + 1 for (i, v) in enumerate(coco.getCatIds())}\n\n\ndef _read_images_with_cv(coco, image_dir, image_ids):\n    image_files = [\n        os.path.join(image_dir, coco.imgs[img_id][\"file_name\"]) for img_id in image_ids\n    ]\n    image_size = [\n        (coco.imgs[img_id][\"height\"], coco.imgs[img_id][\"width\"])\n        for img_id in image_ids\n    ]\n    return (\n        [cv2.imread(image_file).astype(np.single) for image_file in image_files],\n        image_size,\n    )\n\n\ndef _bbox_convert_from_xywh_to_xyxy(bbox, image_h, image_w):\n    (x, y, w, h) = bbox\n    (x1, y1) = (x, y)\n    x2 = x1 + max(w - 1, 0)\n    y2 = y1 + max(h - 1, 0)\n    x1 = min(max(x1, 0), image_w - 1)\n    y1 = min(max(y1, 0), image_h - 1)\n    x2 = min(max(x2, 0), image_w - 1)\n    y2 = min(max(y2, 0), image_h - 1)\n    if x1 >= x2 or y1 >= y2:\n        return None\n    return [x1, y1, x2, y2]\n\n\ndef _read_bbox(coco, image_ids):\n    img_bbox_list = []\n    for img_id in image_ids:\n        anno_ids = coco.getAnnIds(imgIds=[img_id])\n        assert len(anno_ids) > 0, \"image with id {} has no anno\".format(img_id)\n        image_h = coco.imgs[img_id][\"height\"]\n        image_w = coco.imgs[img_id][\"width\"]\n        bbox_list = []\n        for anno_id in anno_ids:\n            anno = coco.anns[anno_id]\n            if anno[\"iscrowd\"] != 0:\n                continue\n            bbox = anno[\"bbox\"]\n            assert isinstance(bbox, list)\n            bbox_ = _bbox_convert_from_xywh_to_xyxy(bbox, image_h, image_w)\n            if bbox_ is not None:\n                bbox_list.append(bbox_)\n        bbox_array = np.array(bbox_list, dtype=np.single)\n        img_bbox_list.append(bbox_array)\n    return img_bbox_list\n\n\ndef _read_label(coco, image_ids, category_id_to_contiguous_id_map):\n    img_label_list = []\n    for img_id in image_ids:\n        anno_ids = coco.getAnnIds(imgIds=[img_id])\n        assert len(anno_ids) > 0, \"image with id {} has no anno\".format(img_id)\n        label_list = []\n        for anno_id in anno_ids:\n            anno = coco.anns[anno_id]\n            if anno[\"iscrowd\"] != 0:\n                continue\n            cate_id = anno[\"category_id\"]\n            assert isinstance(cate_id, int)\n            label_list.append(category_id_to_contiguous_id_map[cate_id])\n        label_array = np.array(label_list, dtype=np.int32)\n        img_label_list.append(label_array)\n    return img_label_list\n\n\ndef _read_segm_poly(coco, image_ids):\n    img_segm_poly_list = []\n    for img_id in image_ids:\n        anno_ids = coco.getAnnIds(imgIds=[img_id])\n        assert len(anno_ids) > 0, \"img {} has no anno\".format(img_id)\n        segm_poly_list = []\n        for anno_id in anno_ids:\n            anno = coco.anns[anno_id]\n            if anno[\"iscrowd\"] != 0:\n                continue\n            segm = anno[\"segmentation\"]\n            assert isinstance(segm, list)\n            assert len(segm) > 0, str(len(segm))\n            assert all([len(poly) > 0 for poly in segm]), str(\n                [len(poly) for poly in segm]\n            
)\n            segm_poly_list.append(segm)\n        img_segm_poly_list.append(segm_poly_list)\n    return img_segm_poly_list\n\n\ndef _segm_poly_list_to_tensor(img_segm_poly_list):\n    poly_array_list = []\n    poly_index_array_list = []\n    for (img_idx, segm_poly_list) in enumerate(img_segm_poly_list):\n        img_poly_elem_list = []\n        img_poly_index_list = []\n        for (obj_idx, poly_list) in enumerate(segm_poly_list):\n            for (poly_idx, poly) in enumerate(poly_list):\n                img_poly_elem_list.extend(poly)\n                for (pt_idx, pt) in enumerate(poly):\n                    if pt_idx % 2 == 0:\n                        img_poly_index_list.append([pt_idx / 2, poly_idx, obj_idx])\n        img_poly_array = np.array(img_poly_elem_list, dtype=np.single).reshape(-1, 2)\n        assert img_poly_array.size > 0, segm_poly_list\n        poly_array_list.append(img_poly_array)\n        img_poly_index_array = np.array(img_poly_index_list, dtype=np.int32)\n        assert img_poly_index_array.size > 0, segm_poly_list\n        poly_index_array_list.append(img_poly_index_array)\n    return (poly_array_list, poly_index_array_list)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCocoReader(flow.unittest.TestCase):\n    def test_coco_reader(test_case):\n        anno_file = flow.unittest.dataset_dir(\n            \"mscoco_2017/annotations/instances_val2017.json\"\n        )\n        image_dir = flow.unittest.dataset_dir(\"mscoco_2017/val2017\")\n        num_iterations = 10\n        coco_reader = flow.nn.COCOReader(\n            annotation_file=anno_file,\n            image_dir=image_dir,\n            batch_size=2,\n            shuffle=True,\n            stride_partition=True,\n        )\n        image_decoder = flow.nn.image.decode(dtype=flow.float)\n        for i in range(num_iterations):\n            (\n                image,\n                image_id,\n                image_size,\n                gt_bbox,\n                gt_label,\n                gt_segm,\n                gt_segm_index,\n            ) = coco_reader()\n            decoded_image = image_decoder(image)\n            image_list = decoded_image.numpy()\n            image_id = image_id.numpy()\n            image_size = image_size.numpy()\n            bbox_list = gt_bbox.numpy()\n            label_list = gt_label.numpy()\n            segm_list = gt_segm.numpy()\n            segm_index_list = gt_segm_index.numpy()\n            samples = _get_coco_image_samples(anno_file, image_dir, image_id)\n            for (i, sample) in enumerate(samples):\n                test_case.assertTrue(np.array_equal(image_list[i], sample[\"image\"]))\n                test_case.assertTrue(\n                    np.array_equal(image_size[i], sample[\"image_size\"])\n                )\n                test_case.assertTrue(np.allclose(bbox_list[i], sample[\"bbox\"]))\n                cur_label = label_list[i]\n                if len(cur_label.shape) == 0:\n                    cur_label = np.array([cur_label])\n                test_case.assertTrue(np.array_equal(cur_label, sample[\"label\"]))\n                test_case.assertTrue(np.allclose(segm_list[i], sample[\"poly\"]))\n                test_case.assertTrue(\n                    np.array_equal(segm_index_list[i], sample[\"poly_index\"])\n                )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestOFRecordBytesDecoder(flow.unittest.TestCase):\n    def test_OFRecordBytesDecoder(test_case):\n        batch_size = 16\n        record_reader = flow.nn.OFRecordReader(\n           
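# read raw OFRecord batches; the encoded image bytes are decoded manually below\n           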
 flow.unittest.dataset_dir(\"imagenette/ofrecord\"),\n            batch_size=batch_size,\n            part_name_suffix_length=5,\n        )\n        val_record = record_reader()\n\n        bytesdecoder_img = flow.nn.OFRecordBytesDecoder(\"encoded\")\n\n        image_raw_buffer = bytesdecoder_img(val_record)\n\n        image_raw_buffer_nd = image_raw_buffer.numpy()[0]\n        gt_np = cv2.imread(\n            flow.unittest.dataset_dir(\"imagenette/ofrecord/gt_tensor_buffer_image.png\")\n        )\n        img = cv2.imdecode(image_raw_buffer_nd, cv2.IMREAD_COLOR)\n        test_case.assertTrue(np.array_equal(img, gt_np))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_ddp.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\n\n# Test import from oneflow.nn.parallel.distributed\nfrom oneflow.nn.parallel.distributed import DistributedDataParallel\nfrom oneflow.nn.parallel import DistributedDataParallel as ddp\nfrom oneflow.test_utils.test_util import GenCartesianProduct\nimport oneflow.unittest\n\nimport numpy as np\nimport os\n\n\ndef np_allclose_with_shape(a, b, *args, **kwargs):\n    if a.shape != b.shape:\n        return False\n    return np.allclose(a, b, *args, **kwargs)\n\n\ntest_device = [\"cpu\"] if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\") else [\"cpu\", \"cuda\"]\n\n\n@flow.unittest.skip_unless_1n2d()\nclass TestDDP(flow.unittest.TestCase):\n    def _test_ddp_basic(test_case, dev_type):\n        class Mul(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.w = flow.nn.Parameter(flow.Tensor([1, 1]))\n\n            def forward(self, x):\n                return x * self.w\n\n        rank = flow.env.get_rank()\n        if rank == 0:\n            x = flow.Tensor([1, 1])\n        elif rank == 1:\n            x = flow.Tensor([2, 2])\n        else:\n            raise ValueError()\n\n        x = x.to(dev_type)\n        m = Mul().to(dev_type)\n        m = ddp(m)\n        y = m(x)\n        y.sum().backward()\n\n        test_case.assertTrue(\n            np_allclose_with_shape(m.w.grad.numpy(), np.array([1.5, 1.5]))\n        )\n\n    def test_ddp_basic(test_case):\n        for dev_type in test_device:\n            test_case._test_ddp_basic(dev_type)\n\n    def _test_ddp_multiple_buckets(test_case, dev_type, use_bucket):\n        class Mul(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                for i in range(10):\n                    self.register_parameter(\n                        f\"w{i}\", flow.nn.Parameter(flow.Tensor([i % 2 + 1, i % 2 + 1]))\n                    )\n\n            def forward(self, x):\n                for i in range(10):\n                    x = x * getattr(self, f\"w{i}\")\n                return x\n\n        rank = flow.env.get_rank()\n        if rank == 0:\n            x = flow.Tensor([1, 1])\n        elif rank == 1:\n            x = flow.Tensor([2, 2])\n        else:\n            raise ValueError()\n\n        x = x.to(dev_type)\n        m = Mul().to(dev_type)\n        m = ddp(m, bucket_size=3, use_bucket=use_bucket)\n\n        y = m(x)\n        y.sum().backward()\n\n        for i in range(10):\n            test_case.assertTrue(\n                np_allclose_with_shape(\n                    getattr(m, f\"w{i}\").grad.numpy(),\n                    np.array([48, 48]) if i % 2 == 0 else np.array([24, 24]),\n                )\n            )\n\n    def test_ddp_multiple_buckets(test_case):\n        for dev_type, use_bucket in GenCartesianProduct((test_device, [True, False])):\n            test_case._test_ddp_multiple_buckets(dev_type, 
use_bucket)\n\n    def _test_ddp_with_unused_param(test_case, dev_type):\n        class Model(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.w = flow.nn.Parameter(flow.Tensor([1]))\n                self.used_only_in_rank0 = flow.nn.Parameter(flow.Tensor([2]))\n                self.unused_in_all_ranks = flow.nn.Parameter(flow.Tensor([3]))\n\n            def forward(self, x):\n                x = x * self.w\n                if flow.env.get_rank() == 0:\n                    x = x * self.used_only_in_rank0\n                return x\n\n        rank = flow.env.get_rank()\n        if rank == 0:\n            x = flow.Tensor([1])\n        elif rank == 1:\n            x = flow.Tensor([2])\n        else:\n            raise ValueError()\n\n        x = x.to(dev_type)\n        m = Model().to(dev_type)\n        m = ddp(m, bucket_size=2)\n        y = m(x)\n        y.backward()\n\n        test_case.assertTrue(np_allclose_with_shape(m.w.grad.numpy(), np.array([2])))\n        test_case.assertTrue(\n            np_allclose_with_shape(m.used_only_in_rank0.grad.numpy(), np.array([0.5]))\n        )\n        test_case.assertTrue(\n            np_allclose_with_shape(m.unused_in_all_ranks.grad.numpy(), np.array([0]))\n        )\n\n    def test_ddp_with_unused_param(test_case):\n        for dev_type in test_device:\n            test_case._test_ddp_with_unused_param(dev_type)\n\n    def _test_out_of_order_execution(test_case, dev_type):\n        class Model(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.w1 = flow.nn.Parameter(flow.Tensor([1]))\n                self.w2 = flow.nn.Parameter(flow.Tensor([2]))\n                self.w3 = flow.nn.Parameter(flow.Tensor([3]))\n\n            def forward(self, x):\n                if flow.env.get_rank() == 0:\n                    x *= self.w1\n                    x *= self.w2\n                    x *= self.w3\n                else:\n                    x *= self.w3\n                    x *= self.w2\n                    x *= self.w1\n                return x\n\n        rank = flow.env.get_rank()\n        if rank == 0:\n            x = flow.Tensor([1])\n        elif rank == 1:\n            x = flow.Tensor([2])\n        else:\n            raise ValueError()\n\n        x = x.to(dev_type)\n        m = Model().to(dev_type)\n        m = ddp(m, bucket_size=1)\n        y = m(x)\n        y.backward()\n\n        test_case.assertTrue(np_allclose_with_shape(m.w1.grad.numpy(), np.array([9])))\n        test_case.assertTrue(np_allclose_with_shape(m.w2.grad.numpy(), np.array([4.5])))\n        test_case.assertTrue(np_allclose_with_shape(m.w3.grad.numpy(), np.array([3])))\n\n    def test_out_of_order_execution(test_case):\n        for dev_type in test_device:\n            test_case._test_out_of_order_execution(dev_type)\n\n    def _test_ddp_with_partial_requires_grad_parameter(test_case, dev_type):\n        class Model(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.w1 = flow.nn.Parameter(flow.Tensor([1]), requires_grad=False)\n                self.w2 = flow.nn.Parameter(flow.Tensor([2]))\n                self.w3 = flow.nn.Parameter(flow.Tensor([3]))\n\n            def forward(self, x):\n                if flow.env.get_rank() == 0:\n                    x *= self.w1\n                    x *= self.w2\n                    x *= self.w3\n                else:\n                    x *= self.w3\n                    x *= 
self.w2\n                    x *= self.w1\n                return x\n\n        rank = flow.env.get_rank()\n        if rank == 0:\n            x = flow.Tensor([1])\n        elif rank == 1:\n            x = flow.Tensor([2])\n        else:\n            raise ValueError()\n\n        x = x.to(dev_type)\n        m = Model().to(dev_type)\n        m = ddp(m, bucket_size=1)\n        y = m(x)\n        y.backward()\n\n        test_case.assertTrue(np_allclose_with_shape(m.w2.grad.numpy(), np.array([4.5])))\n        test_case.assertTrue(np_allclose_with_shape(m.w3.grad.numpy(), np.array([3])))\n\n    def test_ddp_with_partial_requires_grad_parameter(test_case):\n        for dev_type in test_device:\n            test_case._test_ddp_with_partial_requires_grad_parameter(dev_type)\n\n    def _test_ddp_two_iters(test_case, dev_type):\n        class Mul(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.w = flow.nn.Parameter(flow.Tensor([1, 1]))\n\n            def forward(self, x):\n                return x * self.w\n\n        rank = flow.env.get_rank()\n        if rank == 0:\n            x = flow.Tensor([1, 1])\n        elif rank == 1:\n            x = flow.Tensor([2, 2])\n        else:\n            raise ValueError()\n\n        x = x.to(dev_type)\n        m = Mul().to(dev_type)\n        m = ddp(m)\n\n        for _ in range(2):\n            y = m(x)\n            y.sum().backward()\n\n        test_case.assertTrue(np_allclose_with_shape(m.w.grad.numpy(), np.array([3, 3])))\n\n    def test_ddp_two_iters(test_case):\n        for dev_type in test_device:\n            test_case._test_ddp_two_iters(dev_type)\n\n    def _test_broadcast_buffer(test_case, dev_type):\n        rank = flow.env.get_rank()\n\n        class CustomModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.register_buffer(\"buf\", flow.tensor([1, 2]) * (rank + 1))\n\n            def forward(self, x):\n                res = self.buf + x\n                self.buf.copy_(x)\n                return res\n\n        x = flow.tensor([2, 3]) * (rank + 1)\n        x = x.to(dev_type)\n\n        m = CustomModule()\n        m = m.to(dev_type)\n        m = ddp(m)\n\n        y1 = m(x)\n        y2 = m(x)\n\n        m = CustomModule()\n        m = m.to(dev_type)\n        m = ddp(m, broadcast_buffers=False)\n\n        y3 = m(x)\n        y4 = m(x)\n\n        if rank == 0:\n            test_case.assertTrue(np_allclose_with_shape(y1.numpy(), np.array([3, 5])))\n            test_case.assertTrue(np_allclose_with_shape(y2.numpy(), np.array([4, 6])))\n            test_case.assertTrue(np_allclose_with_shape(y3.numpy(), np.array([3, 5])))\n            test_case.assertTrue(np_allclose_with_shape(y4.numpy(), np.array([4, 6])))\n        elif rank == 1:\n            test_case.assertTrue(np_allclose_with_shape(y1.numpy(), np.array([5, 8])))\n            test_case.assertTrue(np_allclose_with_shape(y2.numpy(), np.array([6, 9])))\n            test_case.assertTrue(np_allclose_with_shape(y3.numpy(), np.array([6, 10])))\n            test_case.assertTrue(np_allclose_with_shape(y4.numpy(), np.array([8, 12])))\n        else:\n            raise ValueError()\n\n    def test_broadcast_buffer(test_case):\n        for dev_type in test_device:\n            test_case._test_broadcast_buffer(dev_type)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_ddp_multi_outputs.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport os\nimport oneflow as flow\nfrom oneflow.nn.parallel import DistributedDataParallel as ddp\nimport oneflow.unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgDict\n\ntrain_x = [\n    flow.tensor([[1, 2], [2, 3]], dtype=flow.float32),\n    flow.tensor([[4, 6], [3, 1]], dtype=flow.float32),\n]\n\ntrain_float32 = [\n    flow.tensor([[1, 2], [2, 3]], dtype=flow.float32),\n    flow.tensor([[4, 6], [3, 1]], dtype=flow.float32),\n]\n\ntrain_int32 = [\n    flow.tensor([[8], [13]], dtype=flow.int32),\n    flow.tensor([[26], [9]], dtype=flow.int32),\n]\n\n\nclass Model(flow.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.lr = 0.01\n        self.iter_count = 10\n        self.w1 = flow.nn.Parameter(flow.tensor([[0], [0]], dtype=flow.float32))\n        self.w2 = flow.nn.Parameter(flow.tensor([[0], [0]], dtype=flow.float32))\n\n    def forward(self, x, label):\n        if flow.env.get_rank() == 0:\n            x1 = flow.matmul(x, self.w1)\n        else:\n            x1 = flow.matmul(x, self.w2)\n        return ([x1, label + 1], label + 2)\n\n\ndef train(test_case, train_x, device, output, requires_grad):\n    m = Model().to(device)\n    m = ddp(m)\n    loss = flow.nn.MSELoss(reduction=\"sum\")\n    optimizer = flow.optim.SGD(m.parameters(), m.lr)\n\n    for i in range(0, m.iter_count):\n        rank = flow.env.get_rank()\n\n        x = train_x[rank].clone().to(device)\n        y = output[rank].clone().to(device)\n        y.requires_grad = requires_grad\n        (y_pred, y_add_1), y_add_2 = m(x, y)\n        test_case.assertEqual(y_add_1.requires_grad, y.requires_grad)\n        test_case.assertEqual(y_add_2.requires_grad, y.requires_grad)\n        l = loss(y_pred, y)\n        l.backward()\n        optimizer.step()\n        optimizer.zero_grad()\n\n\ntest_device = [\"cpu\"] if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\") else [\"cpu\", \"cuda\"]\n\n\n@flow.unittest.skip_unless_1n2d()\nclass TestDdpMultmpleOutputs(flow.unittest.TestCase):\n    def test_outputs_float32(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = test_device\n        arg_dict[\"output\"] = [train_float32]\n        arg_dict[\"requires_grad\"] = [True, False]\n        for arg in GenArgDict(arg_dict):\n            train(test_case, train_x, **arg)\n\n    def test_outputs_int32(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = test_device\n        arg_dict[\"output\"] = [train_int32]\n        arg_dict[\"requires_grad\"] = [False]\n        for arg in GenArgDict(arg_dict):\n            train(test_case, train_x, **arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_deconv2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\nimport torch as torch_original\nfrom packaging import version\n\n\ndef _test_deconv_bias_false(test_case, device):\n    np_arr = np.array(\n        [\n            [\n                [\n                    [0.2735021114349365, -1.3842310905456543],\n                    [1.058540940284729, -0.03388553857803345],\n                ]\n            ]\n        ]\n    )\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    weight = np.array(\n        [\n            [\n                [\n                    [0.06456436216831207, -0.10852358490228653, -0.21638715267181396],\n                    [-0.2279110550880432, 0.1476770043373108, 0.19457484781742096],\n                    [0.05026858672499657, 0.10818571597337723, 0.02056501805782318],\n                ],\n                [\n                    [0.205095112323761, 0.1488947868347168, -0.2344113141298294],\n                    [0.1684819906949997, -0.21986986696720123, 0.1082606166601181],\n                    [-0.1528974026441574, 0.17120417952537537, 0.01954500749707222],\n                ],\n            ]\n        ]\n    )\n    m = nn.ConvTranspose2d(1, 2, 3, stride=1, bias=False)\n    m.weight = flow.nn.Parameter(flow.Tensor(weight))\n    m = m.to(device)\n    output = m(input)\n    np_out = np.array(\n        [\n            [\n                [\n                    [\n                        0.01765848882496357,\n                        -0.1190534234046936,\n                        0.09103937447071075,\n                        0.2995298206806183,\n                    ],\n                    [\n                        0.006009865552186966,\n                        0.2388070970773697,\n                        -0.37657976150512695,\n                        -0.26200416684150696,\n                    ],\n                    [\n                        -0.22750461101531982,\n                        0.12405071407556534,\n                        0.056831881403923035,\n                        -0.035060010850429535,\n                    ],\n                    [\n                        0.053211357444524765,\n                        0.11281562596559525,\n                        0.0181029811501503,\n                        -0.0006968567031435668,\n                    ],\n                ],\n                [\n                    [\n                        0.05609394609928131,\n                        -0.24317599833011627,\n                        -0.27021679282188416,\n                        0.32447943091392517,\n                    ],\n                    [\n                        0.26318174600601196,\n      
                  -0.14269141852855682,\n                        0.08078087121248245,\n                        -0.14191456139087677,\n                    ],\n                    [\n                        0.13652732968330383,\n                        0.020019691437482834,\n                        -0.10959184169769287,\n                        -0.03072327747941017,\n                    ],\n                    [\n                        -0.16184815764427185,\n                        0.1864076405763626,\n                        0.014887845143675804,\n                        -0.0006622931105084717,\n                    ],\n                ],\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06))\n    output = output.sum()\n    output.backward()\n    np_grad = [\n        [\n            [\n                [0.24731683731079102, 0.24731683731079102],\n                [0.24731683731079102, 0.24731683731079102],\n            ]\n        ]\n    ]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06))\n\n\ndef _test_deconv_bias_true(test_case, device):\n    np_arr = np.array(\n        [\n            [\n                [\n                    [0.2735021114349365, -1.3842310905456543],\n                    [1.058540940284729, -0.03388553857803345],\n                ]\n            ]\n        ]\n    )\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    weight = np.array(\n        [\n            [\n                [\n                    [0.06456436216831207, -0.10852358490228653, -0.21638715267181396],\n                    [-0.2279110550880432, 0.1476770043373108, 0.19457484781742096],\n                    [0.05026858672499657, 0.10818571597337723, 0.02056501805782318],\n                ],\n                [\n                    [0.205095112323761, 0.1488947868347168, -0.2344113141298294],\n                    [0.1684819906949997, -0.21986986696720123, 0.1082606166601181],\n                    [-0.1528974026441574, 0.17120417952537537, 0.01954500749707222],\n                ],\n            ]\n        ]\n    )\n    bias = np.array([0.06456436216831207, -0.10852358490228653])\n    m = nn.ConvTranspose2d(1, 2, 3, stride=1)\n    m.weight = flow.nn.Parameter(flow.Tensor(weight))\n    m.bias = flow.nn.Parameter(flow.Tensor(bias))\n    m = m.to(device)\n    output = m(input)\n    np_out = [\n        [\n            [\n                [\n                    0.0822228491306305,\n                    -0.05448906123638153,\n                    0.15560373663902283,\n                    0.36409419775009155,\n                ],\n                [\n                    0.07057422399520874,\n                    0.30337145924568176,\n                    -0.3120154142379761,\n                    -0.19743980467319489,\n                ],\n                [\n                    -0.16294024884700775,\n                    0.188615083694458,\n                    0.12139624357223511,\n                    0.029504351317882538,\n                ],\n                [\n                    0.11777572333812714,\n                    0.17737999558448792,\n                    0.08266734331846237,\n                    0.06386750191450119,\n                ],\n            ],\n            [\n                [\n                    -0.05242963880300522,\n                    -0.3516995906829834,\n                    -0.3787403702735901,\n                    0.21595585346221924,\n                ],\n  
              [\n                    0.15465816855430603,\n                    -0.25121501088142395,\n                    -0.027742713689804077,\n                    -0.2504381537437439,\n                ],\n                [\n                    0.028003744781017303,\n                    -0.088503897190094,\n                    -0.2181154191493988,\n                    -0.139246866106987,\n                ],\n                [\n                    -0.2703717350959778,\n                    0.07788405567407608,\n                    -0.09363573789596558,\n                    -0.10918587446212769,\n                ],\n            ],\n        ]\n    ]\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06))\n    output = output.sum()\n    output.backward()\n    np_grad = [\n        [\n            [\n                [0.24731683731079102, 0.24731683731079102],\n                [0.24731683731079102, 0.24731683731079102],\n            ]\n        ]\n    ]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06))\n\n\ndef _test_deconv_group_bias_false(test_case, device):\n    np_arr = np.array(\n        [\n            [\n                [\n                    [-2.0125174206754517, 1.9917882689443576],\n                    [0.13146748727936577, -0.5356457374181375],\n                ],\n                [\n                    [1.020683505853394, 1.2900643048299678],\n                    [-0.549010560600543, 0.8088391626901512],\n                ],\n            ]\n        ]\n    )\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    m = nn.ConvTranspose2d(2, 2, 3, stride=1, groups=2, bias=False)\n    weight = np.array(\n        [\n            [\n                [\n                    [0.06456436216831207, -0.10852358490228653, -0.21638715267181396],\n                    [-0.2279110550880432, 0.1476770043373108, 0.19457484781742096],\n                    [0.05026858672499657, 0.10818571597337723, 0.02056501805782318],\n                ]\n            ],\n            [\n                [\n                    [0.205095112323761, 0.1488947868347168, -0.2344113141298294],\n                    [0.1684819906949997, -0.21986986696720123, 0.1082606166601181],\n                    [-0.1528974026441574, 0.17120417952537537, 0.01954500749707222],\n                ]\n            ],\n        ]\n    )\n    m.weight = flow.nn.Parameter(flow.Tensor(weight))\n    m = m.to(device)\n    output = m(input)\n    np_out = np.array(\n        [\n            [\n                [\n                    [\n                        -0.12993690371513367,\n                        0.34700414538383484,\n                        0.219326913356781,\n                        -0.43099740147590637,\n                    ],\n                    [\n                        0.4671630859375,\n                        -0.8000040054321289,\n                        -0.06776165962219238,\n                        0.5034587383270264,\n                    ],\n                    [\n                        -0.13112929463386536,\n                        0.02389305830001831,\n                        0.12057329714298248,\n                        -0.06326202303171158,\n                    ],\n                    [\n                        0.00660868501290679,\n                        -0.012703249230980873,\n                        -0.05524558573961258,\n                        -0.011015564203262329,\n                    ],\n                ],\n           
     [\n                    [\n                        0.20933720469474792,\n                        0.4165603518486023,\n                        -0.04717591404914856,\n                        -0.3024056851863861,\n                    ],\n                    [\n                        0.059367403388023376,\n                        0.07707919180393219,\n                        0.07597976922988892,\n                        -0.049937888979911804,\n                    ],\n                    [\n                        -0.24855825304985046,\n                        0.2344835251569748,\n                        0.003538096323609352,\n                        0.11277973651885986,\n                    ],\n                    [\n                        0.08394229412078857,\n                        -0.21766230463981628,\n                        0.12774622440338135,\n                        0.015808766707777977,\n                    ],\n                ],\n            ]\n        ]\n    )\n\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06))\n    output = output.sum()\n    output.backward()\n    np_grad = [\n        [\n            [\n                [0.03301373869180679, 0.03301373869180679],\n                [0.03301373869180679, 0.03301373869180679],\n            ],\n            [\n                [0.21430310606956482, 0.21430310606956482],\n                [0.21430310606956482, 0.21430310606956482],\n            ],\n        ]\n    ]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06))\n\n\ndef _test_deconv_group_bias_true(test_case, device):\n    np_arr = np.array(\n        [\n            [\n                [\n                    [-2.0125174206754517, 1.9917882689443576],\n                    [0.13146748727936577, -0.5356457374181375],\n                ],\n                [\n                    [1.020683505853394, 1.2900643048299678],\n                    [-0.549010560600543, 0.8088391626901512],\n                ],\n            ]\n        ]\n    )\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    m = nn.ConvTranspose2d(2, 2, 3, stride=1, groups=2)\n    weight = np.array(\n        [\n            [\n                [\n                    [0.06456436216831207, -0.10852358490228653, -0.21638715267181396],\n                    [-0.2279110550880432, 0.1476770043373108, 0.19457484781742096],\n                    [0.05026858672499657, 0.10818571597337723, 0.02056501805782318],\n                ]\n            ],\n            [\n                [\n                    [0.205095112323761, 0.1488947868347168, -0.2344113141298294],\n                    [0.1684819906949997, -0.21986986696720123, 0.1082606166601181],\n                    [-0.1528974026441574, 0.17120417952537537, 0.01954500749707222],\n                ]\n            ],\n        ]\n    )\n    m.weight = flow.nn.Parameter(flow.Tensor(weight))\n    bias = np.array([0.06456436216831207, -0.10852358490228653])\n    m.bias = flow.nn.Parameter(flow.Tensor(bias))\n    m = m.to(device)\n    output = m(input)\n    np_out = [\n        [\n            [\n                [\n                    -0.0653725415468216,\n                    0.4115685224533081,\n                    0.2838912606239319,\n                    -0.3664330244064331,\n                ],\n                [\n                    0.5317274332046509,\n                    -0.735439658164978,\n                    -0.00319729745388031,\n                    0.5680230855941772,\n 
               ],\n                [\n                    -0.06656493246555328,\n                    0.08845742046833038,\n                    0.18513765931129456,\n                    0.0013023391366004944,\n                ],\n                [\n                    0.0711730495095253,\n                    0.05186111479997635,\n                    0.009318776428699493,\n                    0.053548797965049744,\n                ],\n            ],\n            [\n                [\n                    0.1008136197924614,\n                    0.30803677439689636,\n                    -0.1556994915008545,\n                    -0.41092926263809204,\n                ],\n                [\n                    -0.04915618151426315,\n                    -0.03144439309835434,\n                    -0.032543815672397614,\n                    -0.15846148133277893,\n                ],\n                [\n                    -0.3570818305015564,\n                    0.12595993280410767,\n                    -0.10498549044132233,\n                    0.004256151616573334,\n                ],\n                [\n                    -0.024581290781497955,\n                    -0.3261858820915222,\n                    0.019222639501094818,\n                    -0.0927148163318634,\n                ],\n            ],\n        ]\n    ]\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06))\n    output = output.sum()\n    output.backward()\n    np_grad = [\n        [\n            [\n                [0.03301373869180679, 0.03301373869180679],\n                [0.03301373869180679, 0.03301373869180679],\n            ],\n            [\n                [0.21430310606956482, 0.21430310606956482],\n                [0.21430310606956482, 0.21430310606956482],\n            ],\n        ]\n    ]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06))\n\n\ndef _test_deconv_group_large_out_channel(test_case, device):\n    np_arr = np.array(\n        [\n            [\n                [\n                    [-2.0125174206754517, 1.9917882689443576],\n                    [0.13146748727936577, -0.5356457374181375],\n                ],\n                [\n                    [1.020683505853394, 1.2900643048299678],\n                    [-0.549010560600543, 0.8088391626901512],\n                ],\n            ]\n        ]\n    )\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    m = nn.ConvTranspose2d(2, 6, 3, stride=1, groups=2, bias=False)\n    weight = np.array(\n        [\n            [\n                [\n                    [0.05271657928824425, -0.08860913664102554, -0.17667937278747559],\n                    [-0.18608860671520233, 0.12057777494192123, 0.1588696986436844],\n                    [0.04104413092136383, 0.08833327144384384, 0.016791267320513725],\n                ],\n                [\n                    [0.16745945811271667, 0.1215720921754837, -0.19139604270458221],\n                    [0.13756497204303741, -0.17952299118041992, 0.08839442580938339],\n                    [-0.12484020739793777, 0.13978762924671173, 0.015958432108163834],\n                ],\n                [\n                    [-0.07709092646837234, -0.029757702723145485, -0.18154984712600708],\n                    [-0.14461342990398407, 0.06567336618900299, 0.05665326863527298],\n                    [0.04441174864768982, -0.04477253183722496, 0.191376194357872],\n                ],\n            ],\n            [\n     
           [\n                    [0.1850736141204834, 0.07141514122486115, 0.05791180208325386],\n                    [0.07253318279981613, -0.042754165828228, -0.14045141637325287],\n                    [0.08525089919567108, 0.009758883155882359, -0.07303793728351593],\n                ],\n                [\n                    [-0.005451973062008619, 0.1499139368534088, 0.16706342995166779],\n                    [-0.05473465472459793, 0.02753184549510479, -0.06856250017881393],\n                    [0.03629609942436218, -0.06238799914717674, -0.041715867817401886],\n                ],\n                [\n                    [0.15021666884422302, -0.10501708835363388, 0.04741475358605385],\n                    [-0.16011257469654083, 0.1280348002910614, 0.11050418764352798],\n                    [-0.10031674802303314, 0.1449088454246521, -0.16990724205970764],\n                ],\n            ],\n        ]\n    )\n    m.weight = flow.nn.Parameter(flow.Tensor(weight))\n    m = m.to(device)\n    output = m(input)\n    np_out = np.array(\n        [\n            [\n                [\n                    [\n                        -0.10609303414821625,\n                        0.28332769870758057,\n                        0.17907968163490295,\n                        -0.3519079089164734,\n                    ],\n                    [\n                        0.3814370930194855,\n                        -0.653200626373291,\n                        -0.055327147245407104,\n                        0.41107234358787537,\n                    ],\n                    [\n                        -0.10706663131713867,\n                        0.019508585333824158,\n                        0.09844768047332764,\n                        -0.05165322124958038,\n                    ],\n                    [\n                        0.005395968910306692,\n                        -0.010372160002589226,\n                        -0.04510783404111862,\n                        -0.00899417046457529,\n                    ],\n                ],\n                [\n                    [\n                        -0.3370150923728943,\n                        0.08887782692909241,\n                        0.6273337602615356,\n                        -0.38122040033340454,\n                    ],\n                    [\n                        -0.25483641028404236,\n                        0.561577320098877,\n                        -0.6257490515708923,\n                        0.27858346700668335,\n                    ],\n                    [\n                        0.26932841539382935,\n                        -0.6272678375244141,\n                        0.35409244894981384,\n                        -0.015562277287244797,\n                    ],\n                    [\n                        -0.01641242951154709,\n                        0.08524765074253082,\n                        -0.0727786272764206,\n                        -0.008548066020011902,\n                    ],\n                ],\n                [\n                    [\n                        0.15514683723449707,\n                        -0.09366090595722198,\n                        0.3061012029647827,\n                        -0.3616088628768921,\n                    ],\n                    [\n                        0.28090208768844604,\n                        -0.38282686471939087,\n                        0.008863434195518494,\n                        0.21008771657943726,\n                    ],\n                    [\n                        
-0.10839138925075531,\n                        0.2646597623825073,\n                        -0.5020549297332764,\n                        0.35083478689193726,\n                    ],\n                    [\n                        0.005838701035827398,\n                        -0.029675094410777092,\n                        0.04914196580648422,\n                        -0.10250984132289886,\n                    ],\n                ],\n                [\n                    [\n                        0.18890158832073212,\n                        0.3116491138935089,\n                        0.15123975276947021,\n                        0.074709951877594,\n                    ],\n                    [\n                        -0.027573950588703156,\n                        0.16042113304138184,\n                        -0.17254289984703064,\n                        -0.1343500316143036,\n                    ],\n                    [\n                        0.047192707657814026,\n                        0.20208004117012024,\n                        -0.01943095773458481,\n                        -0.20782624185085297,\n                    ],\n                    [\n                        -0.04680364578962326,\n                        0.06359653919935226,\n                        0.04799196869134903,\n                        -0.05907594412565231,\n                    ],\n                ],\n                [\n                    [\n                        -0.005564738996326923,\n                        0.1459812968969345,\n                        0.3639175295829773,\n                        0.21552257239818573,\n                    ],\n                    [\n                        -0.05287356674671173,\n                        -0.12922403216362,\n                        -0.0049260929226875305,\n                        0.04667740315198898,\n                    ],\n                    [\n                        0.06709674000740051,\n                        -0.0762409120798111,\n                        -0.06315286457538605,\n                        -0.10927218943834305,\n                    ],\n                    [\n                        -0.019926942884922028,\n                        0.06360937654972076,\n                        -0.027559401467442513,\n                        -0.03374142572283745,\n                    ],\n                ],\n                [\n                    [\n                        0.1533236801624298,\n                        0.08659995347261429,\n                        -0.08708333969116211,\n                        0.06116808205842972,\n                    ],\n                    [\n                        -0.24589480459690094,\n                        0.10328409075737,\n                        0.16698980331420898,\n                        0.1809084266424179,\n                    ],\n                    [\n                        -0.014488153159618378,\n                        -0.18130677938461304,\n                        0.056411802768707275,\n                        -0.1298111528158188,\n                    ],\n                    [\n                        0.05507495626807213,\n                        -0.1606965959072113,\n                        0.21048882603645325,\n                        -0.13742762804031372,\n                    ],\n                ],\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06))\n    output = output.sum()\n    output.backward()\n    np_grad = [\n        [\n            [\n        
        [0.0822635293006897, 0.0822635293006897],\n                [0.0822635293006897, 0.0822635293006897],\n            ],\n            [\n                [0.4193778932094574, 0.4193778932094574],\n                [0.4193778932094574, 0.4193778932094574],\n            ],\n        ]\n    ]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06))\n\n\ndef _test_deconv_group_large_in_channel(test_case, device):\n    np_arr = [\n        [\n            [\n                [0.6393764315295867, 0.3890587560476374],\n                [0.8467359871201484, 0.24046160407703143],\n            ],\n            [\n                [0.23352071016856402, 0.6760713653927521],\n                [0.061939453383917376, 0.13541973098624682],\n            ],\n            [\n                [0.7524804920779914, 0.34366296030931365],\n                [0.4961502482687954, 0.38175448164636205],\n            ],\n            [\n                [0.01867975512238773, 0.12599156959160163],\n                [0.2658608593205851, 0.6184459583178925],\n            ],\n        ]\n    ]\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    m = nn.ConvTranspose2d(4, 2, 3, stride=1, groups=2, bias=False)\n    weight = np.array(\n        [\n            [\n                [\n                    [0.09130779653787613, -0.15347552299499512, -0.30601766705513],\n                    [-0.32231491804122925, 0.2088468372821808, 0.27517038583755493],\n                    [0.07109051942825317, 0.1529977172613144, 0.02908332832157612],\n                ]\n            ],\n            [\n                [\n                    [0.2900483012199402, 0.21056903898715973, -0.33150768280029297],\n                    [0.23826952278614044, -0.31094294786453247, 0.15310363471508026],\n                    [-0.21622958779335022, 0.24211928248405457, 0.0276408139616251],\n                ]\n            ],\n            [\n                [\n                    [-0.13352541625499725, -0.051541853696107864, -0.3144535720348358],\n                    [-0.2504778206348419, 0.11374961584806442, 0.09812634438276291],\n                    [0.07692340761423111, -0.0775483027100563, 0.33147329092025757],\n                ]\n            ],\n            [\n                [\n                    [0.3205569088459015, 0.12369465827941895, 0.1003061905503273],\n                    [0.1256311535835266, -0.07405238598585129, -0.24326899647712708],\n                    [0.14765889942646027, 0.016902882605791092, -0.12650541961193085],\n                ]\n            ],\n        ]\n    )\n    m.weight = flow.nn.Parameter(flow.Tensor(weight))\n    m = m.to(device)\n    np_out = np.array(\n        [\n            [\n                [\n                    [\n                        0.12611234188079834,\n                        0.1826610565185547,\n                        -0.19042569398880005,\n                        -0.34318169951438904,\n                    ],\n                    [\n                        -0.05516064167022705,\n                        0.04093143343925476,\n                        -0.2053149938583374,\n                        0.0920882523059845,\n                    ],\n                    [\n                        -0.2631978690624237,\n                        0.14817529916763306,\n                        0.4988565742969513,\n                        0.11690345406532288,\n                    ],\n                    [\n                        0.04680176079273224,\n       
                 0.13235820829868317,\n                        0.09591575711965561,\n                        0.010736535303294659,\n                    ],\n                ],\n                [\n                    [\n                        -0.09448734670877457,\n                        -0.04197392612695694,\n                        -0.2368750274181366,\n                        -0.09542831033468246,\n                    ],\n                    [\n                        -0.1671580672264099,\n                        0.16854587197303772,\n                        0.02652890235185623,\n                        -0.05493755638599396,\n                    ],\n                    [\n                        -0.030232630670070648,\n                        0.0058259665966033936,\n                        0.20417997241020203,\n                        -0.015012085437774658,\n                    ],\n                    [\n                        0.07742229104042053,\n                        0.0867031067609787,\n                        0.11167682707309723,\n                        0.048304662108421326,\n                    ],\n                ],\n            ]\n        ]\n    )\n    output = m(input)\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06))\n    output = output.sum()\n    output.backward()\n    np_grad = [\n        [\n            [\n                [0.046688467264175415, 0.046688467264175415],\n                [0.046688467264175415, 0.046688467264175415],\n            ],\n            [\n                [0.30307042598724365, 0.30307042598724365],\n                [0.30307042598724365, 0.30307042598724365],\n            ],\n            [\n                [-0.20727425813674927, -0.20727425813674927],\n                [-0.20727425813674927, -0.20727425813674927],\n            ],\n            [\n                [0.3909238576889038, 0.3909238576889038],\n                [0.3909238576889038, 0.3909238576889038],\n            ],\n        ]\n    ]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-06, 1e-06))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestDeconv2d(flow.unittest.TestCase):\n    def test_deconv2d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_deconv_bias_false,\n            _test_deconv_bias_true,\n            _test_deconv_group_bias_false,\n            _test_deconv_group_bias_true,\n            _test_deconv_group_large_out_channel,\n            _test_deconv_group_large_in_channel,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_deconv2d_with_random_data(test_case):\n        channels = random(1, 6)\n        m = torch.nn.ConvTranspose2d(\n            in_channels=channels,\n            out_channels=random(1, 20),\n            kernel_size=random(1, 4),\n            stride=random() | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(1, 5) | nothing(),\n            groups=random(1, 5) | nothing(),\n            padding_mode=constant(\"zeros\") | nothing(),\n            bias=random_bool(),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=4, dim1=channels).to(device)\n        y = m(x)\n        return y\n\n    @unittest.skipIf(\n        version.parse(torch_original.__version__) <= version.parse(\"1.13.0\"),\n        \"deconv 
module doesn't support unbatched input in PyTorch before '1.13.0'\",\n    )\n    @autotest(n=5)\n    def test_deconv2d_auto_squeeze_with_random_data(test_case):\n        channels = random(1, 6)\n        m = torch.nn.ConvTranspose2d(\n            in_channels=channels,\n            out_channels=random(1, 20),\n            kernel_size=random(1, 4),\n            stride=random() | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(1, 5) | nothing(),\n            groups=random(1, 5) | nothing(),\n            padding_mode=constant(\"zeros\") | nothing(),\n            bias=random_bool(),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=3, dim0=channels).to(device)\n        y = m(x)\n        return y\n\n    @autotest(check_graph=False)\n    def test_deconv2d_0size_with_random_data(test_case):\n        channels = random(1, 6)\n        m = torch.nn.ConvTranspose2d(\n            in_channels=channels,\n            out_channels=random(1, 20),\n            kernel_size=random(1, 4),\n            stride=random() | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(1, 5) | nothing(),\n            groups=random(1, 5) | nothing(),\n            padding_mode=constant(\"zeros\") | nothing(),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=4, dim0=0, dim1=channels).to(device)\n        y = m(x)\n        return y\n\n    @unittest.skip(\n        \"Likely to fail. This case should run on CPU once the problem is solved.\"\n    )\n    @autotest(n=30, check_graph=False, rtol=1e-2, atol=1e-4)\n    def test_deconv2d_group_with_random_data(test_case):\n        channels = 720  # lcm(1, 2, 3, 4, 5, 6)\n        m = torch.nn.ConvTranspose2d(\n            in_channels=channels,\n            out_channels=channels,\n            kernel_size=random(1, 4),\n            stride=random() | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(1, 5) | nothing(),\n            groups=random(1, 7),\n            padding_mode=constant(\"zeros\") | nothing(),\n        )\n        m.train(random())\n\n        device = random_device()\n        m.to(device)\n        m.pytorch.to(\"cuda\")\n        x = random_tensor(ndim=4, dim1=channels).to(device)\n        x.pytorch = x.pytorch.to(\"cuda\")\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.conv_transpose2d)\n    def profile_conv_transpose2d(test_case):\n        inputs = torch.ones(16, 128, 128, 128)\n        weights_4x4_64c = torch.ones(128, 64, 4, 4)\n        weights_6x6_64c = torch.ones(128, 64, 6, 6)\n        weights_8x8_64c = torch.ones(128, 64, 8, 8)\n        torch.nn.functional.conv_transpose2d(\n            inputs, weights_4x4_64c, stride=2, padding=1\n        )\n        torch.nn.functional.conv_transpose2d(\n            inputs, weights_4x4_64c, stride=2, padding=1, bias=torch.ones(64)\n        )\n        torch.nn.functional.conv_transpose2d(\n            inputs, weights_6x6_64c, stride=3, padding=2, output_padding=1\n        )\n        torch.nn.functional.conv_transpose2d(\n            inputs,\n            weights_6x6_64c,\n            stride=3,\n            padding=2,\n            bias=torch.ones(64),\n            output_padding=1,\n        )\n        torch.nn.functional.conv_transpose2d(\n            inputs, weights_8x8_64c, stride=4, padding=2\n        )\n        
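# Same 8x8-kernel, stride-4 case as above, now with a per-output-channel bias of ones.\n        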
torch.nn.functional.conv_transpose2d(\n            inputs, weights_8x8_64c, stride=4, padding=2, bias=torch.ones(64)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_default_dtype.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n_source_op_list = [\n    flow.ones,\n    flow.zeros,\n    flow.rand,\n    flow.randn,\n    flow.empty,\n    flow.Tensor,\n]\n\n\nclass TestDefaultDTypeInferface(oneflow.unittest.TestCase):\n    def test_set_default_dtype(test_case):\n        flow.set_default_dtype(flow.float32)\n        test_case.assertEqual(flow.get_default_dtype(), flow.float32)\n\n        flow.set_default_dtype(flow.float64)\n        test_case.assertEqual(flow.get_default_dtype(), flow.float64)\n\n        for op in _source_op_list:\n            x = op((2, 3))\n            test_case.assertEqual(x.dtype, flow.float64)\n            x = op(2, 3)\n            test_case.assertEqual(x.dtype, flow.float64)\n\n        with test_case.assertRaises(Exception) as ctx:\n            flow.set_default_dtype(flow.int32)\n        test_case.assertTrue(\n            \"only floating-point types are supported as the default type\"\n            in str(ctx.exception)\n        )\n\n    def test_set_default_tensor_type(test_case):\n        flow.set_default_dtype(flow.float32)\n        test_case.assertEqual(flow.get_default_dtype(), flow.float32)\n\n        # set default tensor type by TensorType\n        flow.set_default_tensor_type(flow.DoubleTensor)\n        test_case.assertEqual(flow.get_default_dtype(), flow.float64)\n        for op in _source_op_list:\n            x = op((2, 3))\n            test_case.assertEqual(x.dtype, flow.float64)\n            x = op(2, 3)\n            test_case.assertEqual(x.dtype, flow.float64)\n\n        # set default tensor type by TensorType string\n        flow.set_default_tensor_type(\"oneflow.FloatTensor\")\n        test_case.assertEqual(flow.get_default_dtype(), flow.float32)\n        for op in _source_op_list:\n            x = op((2, 3))\n            test_case.assertEqual(x.dtype, flow.float32)\n\n    def test_behavior_for_oneflow_tensor(test_case):\n        # float32 scope\n        flow.set_default_dtype(flow.float32)\n        test_case.assertEqual(flow.get_default_dtype(), flow.float32)\n\n        x = flow.tensor([1.0, 2])\n        test_case.assertEqual(x.dtype, flow.float32)\n\n        # float64 scope\n        flow.set_default_dtype(flow.float64)\n        test_case.assertEqual(flow.get_default_dtype(), flow.float64)\n\n        x = flow.tensor([1.0, 2])\n        test_case.assertEqual(x.dtype, flow.float64)\n\n        # no affect for int type\n        x = flow.tensor((2, 3))\n        test_case.assertEqual(x.dtype, flow.int64)\n\n        # no affect for numpy array input\n        nd_arr = np.array([1, 2, 3]).astype(np.float32)\n        x = flow.tensor(nd_arr)\n        test_case.assertEqual(x.dtype, flow.float32)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_deform_conv2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torchvision.ops\nimport torch\n\nimport oneflow as flow\nfrom oneflow.test_utils.automated_test_util.torch_flow_dual_object import random_tensor\nfrom oneflow.test_utils.test_util import GenArgList\nimport oneflow.unittest\n\n\ndef GetRandomData(max_batch_sz):\n    batch_sz = max_batch_sz\n    n_weight_grps = np.random.randint(1, 2)\n    n_offset_grps = np.random.randint(1, 2)\n    n_out_channels = n_offset_grps * np.random.randint(1, 15)\n    n_in_channels = n_offset_grps * np.random.randint(1, 15)\n\n    random_stride_h = np.random.randint(1, 5)\n    random_stride_w = np.random.randint(1, 5)\n    random_pad_h = np.random.randint(0, 3)\n    random_pad_w = np.random.randint(0, 3)\n    random_dilation_h = np.random.randint(1, 3)\n    random_dilation_w = np.random.randint(1, 3)\n    random_in_h = np.random.randint(5, 30)\n    random_in_w = np.random.randint(5, 30)\n\n    # BUG(yzm): Now use the rectangular convolution kernel is not aligned with PyTorch\n    # NOTE: Modify the following program after alignment using a rectangular convolution kernel\n    random_kernel_h = np.random.randint(1, 11)\n    random_kernel_w = random_kernel_h\n    # random_kernel_w=np.random.randint(1, 11)\n\n    stride = (random_stride_h, random_stride_w)\n    pad = (random_pad_h, random_pad_w)\n    dilation = (random_dilation_h, random_dilation_w)\n\n    return (\n        batch_sz,\n        n_out_channels,\n        n_in_channels,\n        n_weight_grps,\n        n_offset_grps,\n        stride,\n        pad,\n        dilation,\n        random_kernel_h,\n        random_kernel_w,\n        random_in_h,\n        random_in_w,\n    )\n\n\ndef GetFunArgs(device, max_batch_size):\n    out_w = 0\n    out_h = 0\n    while out_w <= 0 or out_h <= 0:\n        (\n            batch_sz,\n            n_out_channels,\n            n_in_channels,\n            n_weight_grps,\n            n_offset_grps,\n            stride,\n            pad,\n            dilation,\n            random_kernel_h,\n            random_kernel_w,\n            random_in_h,\n            random_in_w,\n        ) = GetRandomData(max_batch_size)\n        stride_h, stride_w = stride\n        pad_h, pad_w = pad\n        dil_h, dil_w = dilation\n        weight_h, weight_w = (random_kernel_h, random_kernel_w)\n        in_h, in_w = (random_in_h, random_in_w)\n        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1\n        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1\n\n    input_dims = [batch_sz, n_in_channels, in_h, in_w]\n    offset_dims = [batch_sz, 2 * n_offset_grps * weight_h * weight_w, out_h, out_w]\n    mask_dims = [batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w]\n    weight_dims = [n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w]\n\n    input = random_tensor(4, 
*input_dims).to(device)\n    offset = random_tensor(4, *offset_dims).to(device)\n    mask = random_tensor(4, *mask_dims).to(device)\n    weight = random_tensor(4, *weight_dims).to(device)\n    bias_dims = [n_out_channels]\n    bias = random_tensor(1, *bias_dims).to(device)\n    return input, weight, offset, mask, bias, stride, pad, dilation\n\n\ndef _test_deform_conv2d_forward(\n    test_case, input, weight, offset, mask, bias, stride, padding, dilation,\n):\n    torch_input = input.pytorch\n    torch_weight = weight.pytorch\n    torch_offset = offset.pytorch\n    torch_mask = mask.pytorch\n    torch_bias = bias.pytorch\n\n    torch_out = torchvision.ops.deform_conv2d(\n        torch_input,\n        torch_offset,\n        torch_weight,\n        stride=stride,\n        padding=padding,\n        dilation=dilation,\n        mask=torch_mask,\n        bias=torch_bias,\n    )\n\n    flow_input = input.oneflow\n    flow_weight = weight.oneflow\n    flow_offset = offset.oneflow\n    flow_mask = mask.oneflow\n    flow_bias = bias.oneflow\n\n    flow_out = oneflow.nn.functional.deform_conv2d(\n        flow_input,\n        flow_offset,\n        flow_weight,\n        stride=stride,\n        padding=padding,\n        dilation=dilation,\n        mask=flow_mask,\n        bias=flow_bias,\n    )\n    test_case.assertTrue(\n        np.allclose(\n            flow_out.numpy(), torch_out.detach().cpu().numpy(), rtol=1e-2, atol=1e-2\n        )\n    )\n\n\ndef _test_deform_conv2d_backward(\n    test_case, input, weight, offset, mask, bias, stride, padding, dilation\n):\n    torch_input = input.pytorch.detach().requires_grad_()\n    torch_weight = weight.pytorch.detach().requires_grad_()\n    torch_offset = offset.pytorch.detach().requires_grad_()\n    torch_mask = mask.pytorch.detach().requires_grad_()\n    torch_bias = bias.pytorch.detach().requires_grad_()\n\n    torch_out = torchvision.ops.deform_conv2d(\n        torch_input,\n        torch_offset,\n        torch_weight,\n        stride=stride,\n        padding=padding,\n        dilation=dilation,\n        mask=torch_mask,\n        bias=torch_bias,\n    )\n    torch_out.sum().backward()\n\n    flow_input = input.oneflow.detach().requires_grad_()\n    flow_weight = weight.oneflow.detach().requires_grad_()\n    flow_offset = offset.oneflow.detach().requires_grad_()\n    flow_mask = mask.oneflow.detach().requires_grad_()\n    flow_bias = bias.oneflow.detach().requires_grad_()\n\n    flow_out = oneflow.nn.functional.deform_conv2d(\n        flow_input,\n        flow_offset,\n        flow_weight,\n        stride=stride,\n        padding=padding,\n        dilation=dilation,\n        mask=flow_mask,\n        bias=flow_bias,\n    )\n    flow_out.sum().backward()\n    test_case.assertTrue(\n        np.allclose(\n            flow_input.grad.numpy(),\n            torch_input.grad.cpu().numpy(),\n            rtol=1e-2,\n            atol=1e-2,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            flow_weight.grad.numpy(),\n            torch_weight.grad.cpu().numpy(),\n            rtol=1e-2,\n            atol=1e-2,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            flow_offset.grad.numpy(),\n            torch_offset.grad.cpu().numpy(),\n            rtol=1e-2,\n            atol=1e-2,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            flow_mask.grad.numpy(), torch_mask.grad.cpu().numpy(), rtol=1e-2, atol=1e-2\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            
flow_bias.grad.numpy(), torch_bias.grad.cpu().numpy(), rtol=1e-5, atol=1e-5\n        )\n    )\n\n\ndef _test_forward_and_backward(test_case, device):\n    max_batch_size = 40\n    for batch_size in range(1, max_batch_size):\n        input, weight, offset, mask, bias, stride, padding, dilation = GetFunArgs(\n            device, batch_size\n        )\n        _test_deform_conv2d_forward(\n            test_case, input, weight, offset, mask, bias, stride, padding, dilation\n        )\n        _test_deform_conv2d_backward(\n            test_case, input, weight, offset, mask, bias, stride, padding, dilation\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestDeformConv2d(flow.unittest.TestCase):\n    def test_deform_conv2d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_forward_and_backward]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_det.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport re\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef det_random_device():\n    cuda_version = flow._oneflow_internal.flags.cuda_version()\n    if cuda_version < 11000:  # cuSOLVER is only supported in CUDA 11.0 and above\n        return cpu_device()\n    else:\n        return random_device()\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLinalgDet(flow.unittest.TestCase):\n    @autotest(n=5, rtol=1e-2, auto_backward=False)\n    def test_det_3by3_with_random_data(test_case):\n        device = det_random_device()\n        x = random_tensor(ndim=2, dim0=3, dim1=3, low=-1).to(device)\n        return torch.linalg.det(x)\n\n    @autotest(n=5, rtol=1e-2, auto_backward=False)\n    def test_det_batch_3by3_with_random_data(test_case):\n        device = det_random_device()\n        x = random_tensor(ndim=3, dim0=random(), dim1=3, dim2=3, low=-1).to(device)\n        return torch.linalg.det(x)\n\n    @autotest(n=5, rtol=1e-2, auto_backward=False)\n    def test_det_random_square_with_random_data(test_case):\n        device = det_random_device()\n        square_dim = random()\n        x = random_tensor(ndim=4, dim2=square_dim, dim3=square_dim, low=-1).to(device)\n        return torch.linalg.det(x)\n\n    @profile(torch.linalg.det)\n    def profile_linalg_det(test_case):\n        torch.linalg.det(torch.randn(1, 32, 4, 4))\n        torch.linalg.det(torch.randn(16, 32, 4, 4))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_diag.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch as ori_torch\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass Test_Diag_module(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_diag_one_dim(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=random()).to(device)\n        return torch.diag(x)\n\n    @autotest(n=5)\n    def test_diag_other_dim(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim0=random(), dim1=random()).to(device)\n        return torch.diag(x)\n\n    @autotest(auto_backward=False)\n    def test_diag_one_dim(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=random()).to(device, torch.bool)\n        return torch.diag(x)\n\n    def test_diag_0size_tensor(test_case):\n        torch_tensor = ori_torch.empty(0).diag()\n        flow_tensor = flow.empty(0).diag()\n        test_case.assertTrue(\n            np.array_equal(list(torch_tensor.shape), list(flow_tensor.shape))\n        )\n        torch_tensor = ori_torch.empty(0, 0).diag()\n        flow_tensor = flow.empty(0, 0).diag()\n        test_case.assertTrue(\n            np.array_equal(list(torch_tensor.shape), list(flow_tensor.shape))\n        )\n        torch_tensor = ori_torch.empty(0, 3).diag()\n        flow_tensor = flow.empty(0, 3).diag()\n        test_case.assertTrue(\n            np.array_equal(list(torch_tensor.shape), list(flow_tensor.shape))\n        )\n\n    @profile(torch.diag)\n    def profile_diag(test_case):\n        torch.diag(torch.ones(1000))\n        torch.diag(torch.ones(128, 128))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_diagonal.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nimport oneflow as flow\n\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestDiagonal(flow.unittest.TestCase):\n    @autotest(n=10, check_graph=True)\n    def test_flow_diagonal_with_random_data(test_case):\n        device = random_device()\n        offset = random(-5, 5).to(int)\n        dim1 = random(-4, 4).to(int)\n        dim2 = random(-4, 4).to(int)\n\n        x = random_tensor(\n            ndim=4,\n            dim1=random(4, 6),\n            dim2=random(4, 6),\n            dim3=random(4, 6),\n            dim4=random(4, 6),\n        ).to(device)\n        z = torch.diagonal(x, offset, dim1, dim2)\n        return z\n\n    @autotest(auto_backward=False, n=10, check_graph=True)\n    def test_flow_diagonal_with_random_data(test_case):\n        device = random_device()\n        offset = random(-5, 5).to(int)\n        dim1 = random(-4, 4).to(int)\n        dim2 = random(-4, 4).to(int)\n\n        x = random_tensor(\n            ndim=4,\n            dim1=random(4, 6),\n            dim2=random(4, 6),\n            dim3=random(4, 6),\n            dim4=random(4, 6),\n        ).to(device, torch.bool)\n        z = torch.diagonal(x, offset, dim1, dim2)\n        return z\n\n    @profile(torch.diagonal)\n    def profile_diagonal(test_case):\n        input1 = torch.ones(128, 128)\n        input2 = torch.ones(16, 10, 128, 128)\n        torch.diagonal(input1, 0)\n        torch.diagonal(input1, 1)\n        torch.diagonal(input2, offset=-1, dim1=1, dim2=2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_div.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch as torch_original\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_div_impl(test_case, shape, device):\n    x = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    y = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.div(x, y)\n    np_out = np.divide(x.numpy(), y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    x = 5\n    y = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.div(x, y)\n    np_out = np.divide(x, y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    x = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    y = 5\n    of_out = flow.div(x, y)\n    np_out = np.divide(x.numpy(), y)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    x = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    y = flow.tensor(\n        np.random.randn(1, 1), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.div(x, y)\n    np_out = np.divide(x.numpy(), y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    x = flow.tensor(np.array([5.0]), dtype=flow.float32, device=flow.device(device))\n    y = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.div(x, y)\n    np_out = np.divide(x.numpy(), y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    x = flow.tensor(\n        np.random.randn(*shape),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    y = flow.tensor(\n        np.array([5.0]),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out = flow.div(x, y)\n    np_out = np.divide(x.numpy(), y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad_x = np.full(shape, 0.2)\n    test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad_x, 0.0001, 0.0001))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestDiv(flow.unittest.TestCase):\n    def test_div(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in 
GenArgList(arg_dict):\n            _test_div_impl(test_case, *arg)\n\n    @autotest(n=10, auto_backward=False, check_graph=True, include_complex=True)\n    def test_random_dim_div(test_case):\n        device = random_device()\n        dim0 = random(low=1, high=4).to(int)\n        dim1 = random(low=1, high=4).to(int)\n        x = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device)\n        y = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device)\n        z = x / y\n        return z\n\n    @autotest(n=10, auto_backward=False, check_graph=True, include_complex=True)\n    def test_random_dim_scalar_div(test_case):\n        device = random_device()\n        dim0 = random(low=1, high=4).to(int)\n        dim1 = random(low=1, high=4).to(int)\n        x = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device)\n        y = random_tensor(ndim=0).to(device)\n        z = x / y\n        return z\n\n    @autotest(n=10, auto_backward=False, check_graph=True, include_complex=True)\n    def test_0_size_div(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 1, 0, 3).to(device)\n        y = random_tensor(4, 2, 1, 0, 3).to(device)\n        z = x / y\n        return z\n\n    @autotest(n=10, auto_backward=False, check_graph=True, include_complex=True)\n    def test_0dim_div(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = random_tensor(ndim=0).to(device)\n        z = x / y\n        return z\n\n    @autotest(n=10, include_complex=True)\n    def test_non_contiguous_inplace_div(test_case):\n        device = random_device()\n        x = random_tensor(2, 2, 4).to(device)\n        y = x + 1\n        y = y[:, 1:3]\n        y /= random_tensor(2, 2, 2).to(device)\n        return y\n\n    @autotest(n=3, check_graph=False)\n    def test_int_dtype_inplace_div(test_case):\n        num_elems = 20\n        flow_out = flow.arange(num_elems) / num_elems\n        torch_out = torch.arange(num_elems) / num_elems\n        test_case.assertTrue(np.allclose(flow_out.numpy(), torch_out.numpy()))\n\n    @autotest(n=5, include_complex=True)\n    def test_scalar_div_with_random_devices(test_case):\n        x1_device = random_device()\n        x2_device = random_device()\n        x1 = random_tensor(2, 2, 3).to(x1_device).mean()\n        x2 = random_tensor(2, 2, 3).to(x2_device)\n        y = x1 / x2\n        return y\n\n    @profile(torch.div)\n    def profile_div(test_case):\n        input1 = torch.ones(16, 10, 128, 128)\n        input2 = torch.ones(16, 10, 128, 128)\n        torch.div(input1, input2)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestDivRoundmode(flow.unittest.TestCase):\n    @autotest(n=3)\n    def test_random_dim_div_floor(test_case):\n        device = random_device()\n        dim0 = random(low=1, high=4).to(int)\n        dim1 = random(low=1, high=4).to(int)\n        x = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device)\n        y = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device)\n        z = torch.div(x, y, rounding_mode=\"floor\")\n        return z\n\n    @autotest(n=3)\n    def test_random_dim_div_trunc(test_case):\n        device = random_device()\n        dim0 = random(low=1, high=4).to(int)\n        dim1 = random(low=1, high=4).to(int)\n        x = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device)\n        y = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device)\n        z = torch.div(x, y, rounding_mode=\"trunc\")\n        return z\n\n    @autotest(n=3)\n    def test_scalar_div_mode_floor(test_case):\n        
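# A Python scalar numerator divided by a random tensor, with floor rounding.\n        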
device = random_device()\n        x1 = random(low=1, high=5).to(float)\n        x2 = random_tensor(2, 2, 3).to(device)\n        y = torch.div(x1, x2, rounding_mode=\"floor\")\n        return y\n\n    @autotest(n=3)\n    def test_scalar_div_mode_trunc(test_case):\n        device = random_device()\n        x1 = random(low=1, high=5).to(float)\n        x2 = random_tensor(2, 2, 3).to(device)\n        y = torch.div(x1, x2, rounding_mode=\"trunc\")\n        return y\n\n    @autotest(n=3)\n    def test_scalar_div_mode_floor2(test_case):\n        device = random_device()\n        x1 = random(low=1, high=5).to(float)\n        x2 = random_tensor(2, 2, 3).to(device)\n        y = torch.div(x2, x1, rounding_mode=\"floor\")\n        return y\n\n    @autotest(n=3)\n    def test_scalar_div_mode_trunc2(test_case):\n        device = random_device()\n        x1 = random(low=1, high=5).to(float)\n        x2 = random_tensor(2, 2, 3).to(device)\n        y = torch.div(x2, x1, rounding_mode=\"trunc\")\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_dlpack.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport random\nimport unittest\nimport os\n\nimport torch\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenCartesianProduct\n\n\ntest_device_args = (\n    [(\"cpu\",)]\n    if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\")\n    else [(\"cpu\",), (\"cuda\", 0), (\"cuda\", 1)]\n)\ntest_args = list(\n    GenCartesianProduct((test_device_args, [(torch, flow), (flow, torch)]))\n)\n\n\ndef are_tensors_equal(a, b):\n    def are_devices_equal(a, b):\n        if a.type == \"cuda\" and b.type == \"cuda\":\n            return a.index == b.index\n        else:\n            return a.type == b.type\n\n    return (\n        np.array_equal(a.cpu().numpy(), b.cpu().numpy())\n        and are_devices_equal(a.device, b.device)\n        and a.shape == b.shape\n        and a.stride() == b.stride()\n        and a.cpu().numpy().dtype == b.cpu().numpy().dtype\n    )\n\n\n@flow.unittest.skip_unless_1n2d()\nclass TestPack(flow.unittest.TestCase):\n    def test_same_data(test_case):\n        for device_args, (m1, m2) in test_args:\n            tensor1 = m1.randn(3, 4, 5, device=m1.device(*device_args))\n            tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1))\n            test_case.assertTrue(are_tensors_equal(tensor1, tensor2))\n            test_case.assertEqual(tensor2.storage_offset(), 0)\n\n            tensor2[1:2, 2:3, 3:4] = random.random()\n            # NOTE: OneFlow operations are asynchoronously executed,\n            # so we need to synchronize explicitly here.\n            flow._oneflow_internal.eager.Sync()\n            test_case.assertTrue(are_tensors_equal(tensor1, tensor2))\n\n    def test_use_ops(test_case):\n        for device_args, (m1, m2) in test_args:\n            tensor1 = m1.randn(3, 4, 5, device=m1.device(*device_args))\n            tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1))\n            res1 = tensor1 ** 2\n            res2 = tensor2 ** 2\n            test_case.assertTrue(np.allclose(res1.cpu().numpy(), res2.cpu().numpy()))\n\n    def test_more_dtype(test_case):\n        # PyTorch bfloat16 tensor doesn't support .numpy() method\n        # so we can't test it\n        # torch.bfloat16, flow.bfloat16\n        dtypes = [\"float64\", \"float32\", \"float16\", \"int64\", \"int32\", \"int8\", \"uint8\"]\n\n        for device_args, (m1, m2) in test_args:\n            for dtype in dtypes:\n                tensor1 = m1.ones(\n                    (2, 3), dtype=getattr(m1, dtype), device=m1.device(*device_args)\n                )\n                tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1))\n                test_case.assertTrue(are_tensors_equal(tensor1, tensor2))\n\n    def test_non_contiguous_input(test_case):\n        for device_args, (m1, m2) in test_args:\n            tensor1 = (\n                m1.randn(2, 3, 4, 5).permute(2, 0, 3, 1).to(m1.device(*device_args))\n            )\n            tensor2 = 
m2.from_dlpack(m1.to_dlpack(tensor1))\n            test_case.assertTrue(are_tensors_equal(tensor1, tensor2))\n\n    def test_scalar_tensor(test_case):\n        for device_args, (m1, m2) in test_args:\n            tensor1 = m1.tensor(5).to(m1.device(*device_args))\n            tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1))\n            test_case.assertTrue(are_tensors_equal(tensor1, tensor2))\n\n    def test_0_size_tensor(test_case):\n        for device_args, (m1, m2) in test_args:\n            tensor1 = m1.tensor([]).to(m1.device(*device_args))\n            tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1))\n            test_case.assertTrue(are_tensors_equal(tensor1, tensor2))\n\n    def test_lifecycle(test_case):\n        for device_args, (m1, m2) in test_args:\n            tensor1 = m1.randn(2, 3, 4, 5).to(m1.device(*device_args))\n            tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1))\n            value = tensor1.cpu().numpy()\n            del tensor2\n            if device_args[0] == \"cuda\":\n                m2.cuda.synchronize()\n                # actually release the cuda memory\n                m2.cuda.empty_cache()\n            test_case.assertTrue(np.array_equal(tensor1.cpu().numpy(), value))\n\n            tensor1 = m1.randn(2, 3, 4, 5).to(m1.device(*device_args))\n            tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1))\n            value = tensor2.cpu().numpy()\n            del tensor1\n            if device_args[0] == \"cuda\":\n                m1.cuda.synchronize()\n                m1.cuda.empty_cache()\n            test_case.assertTrue(np.array_equal(tensor2.cpu().numpy(), value))\n\n    def test_subview(test_case):\n        for device_args, (m1, m2) in test_args:\n            tensor1 = m1.randn(3, 4, 5, device=m1.device(*device_args))\n            tensor1 = tensor1[1:, :, ::2]\n            tensor2 = m2.from_dlpack(m1.to_dlpack(tensor1))\n            test_case.assertTrue(are_tensors_equal(tensor1, tensor2))\n            test_case.assertEqual(tensor2.storage_offset(), 0)\n\n            tensor2[1:2, ::2, 3:4] = random.random()\n            test_case.assertTrue(are_tensors_equal(tensor1, tensor2))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_dot.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestDot(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_dot(test_case):\n        device = random_device()\n        k = random(10, 100)\n        x = random_tensor(ndim=1, dim0=k).to(device)\n        y = random_tensor(ndim=1, dim0=k).to(device)\n        z = torch.dot(x, y)\n        return z\n\n    @profile(torch.dot)\n    def profile_dot(test_case):\n        input1 = torch.ones(10000)\n        input2 = torch.ones(10000)\n        torch.dot(input1, input2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_dropout.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef do_test_dropout_numpy_p0(test_case, shape, device, dtype):\n    np_x = np.random.randn(*shape).astype(dtype)\n    np_one_mask = np.ones_like(np_x)\n    x_tensor = flow.tensor(np_x, requires_grad=True, device=device)\n    out = flow._C.dropout(x_tensor, p=0.0)\n    test_case.assertTrue(np.allclose(out.numpy(), np_x, atol=1e-5, rtol=1e-5))\n    out_sum = out.sum()\n    out_sum.backward()\n    test_case.assertTrue(\n        np.allclose(x_tensor.grad.numpy(), np_one_mask, atol=1e-5, rtol=1e-5)\n    )\n\n\ndef do_test_dropout_numpy_p1(test_case, shape, device, dtype):\n    np_x = np.random.randn(*shape).astype(dtype)\n    np_zero_mask = np.zeros_like(np_x)\n    x_tensor = flow.tensor(np_x, requires_grad=True, device=device)\n    out = flow._C.dropout(x_tensor, p=1.0)\n    test_case.assertTrue(np.allclose(out.numpy(), np_zero_mask, atol=1e-5, rtol=1e-5))\n    out_sum = out.sum()\n    out_sum.backward()\n    test_case.assertTrue(\n        np.allclose(x_tensor.grad.numpy(), np_zero_mask, atol=1e-5, rtol=1e-5)\n    )\n\n\ndef do_test_dropout_numpy_fp16_p0(test_case, shape):\n    np_x = np.random.randn(*shape).astype(np.float32)\n    np_x_fp16 = np_x.astype(np.float16)\n    x_tensor = flow.tensor(np_x, requires_grad=True, device=\"cuda\")\n    x_tensor_fp16 = flow.cast(x_tensor, flow.float16)\n    np_one_mask = np.ones_like(np_x)\n    out = flow._C.dropout(x_tensor_fp16, p=0.0)\n    out_fp32 = flow.cast(out, flow.float32)\n    test_case.assertTrue(np.allclose(out_fp32.numpy(), np_x_fp16, atol=1e-5, rtol=1e-5))\n    out_sum = out_fp32.sum()\n    out_sum.backward()\n    test_case.assertTrue(\n        np.allclose(x_tensor.grad.numpy(), np_one_mask, atol=1e-5, rtol=1e-5)\n    )\n\n\ndef do_test_dropout_numpy_fp16_p1(test_case, shape):\n    np_x = np.random.randn(*shape).astype(np.float32)\n    x_tensor = flow.tensor(np_x, requires_grad=True, device=\"cuda\")\n    x_tensor_fp16 = flow.cast(x_tensor, flow.float16)\n    np_zero_mask = np.zeros_like(np_x)\n    out = flow._C.dropout(x_tensor_fp16, p=1.0)\n    out_fp32 = flow.cast(out, flow.float32)\n    test_case.assertTrue(\n        np.allclose(out_fp32.numpy(), np_zero_mask, atol=1e-5, rtol=1e-5)\n    )\n    out_sum = out_fp32.sum()\n    out_sum.backward()\n    test_case.assertTrue(\n        np.allclose(x_tensor.grad.numpy(), np_zero_mask, atol=1e-5, rtol=1e-5)\n    )\n\n\ndef do_test_dropout_addend_numpy_p0(test_case, shape, device, dtype):\n    np_x = np.random.randn(*shape).astype(dtype)\n    np_addend = np.random.randn(*shape).astype(dtype)\n    np_one_mask = np.ones_like(np_x)\n    x_tensor = flow.tensor(np_x, requires_grad=True, device=device)\n    addend_tensor = 
flow.tensor(np_addend, requires_grad=True, device=device)\n    DropoutModule = flow.nn.Dropout(p=0.0)\n    out = DropoutModule(x_tensor, addend_tensor)\n    test_case.assertTrue(\n        np.allclose(out.numpy(), np_x + np_addend, atol=1e-5, rtol=1e-5)\n    )\n    out_sum = out.sum()\n    out_sum.backward()\n    test_case.assertTrue(\n        np.allclose(x_tensor.grad.numpy(), np_one_mask, atol=1e-5, rtol=1e-5)\n    )\n    test_case.assertTrue(\n        np.allclose(addend_tensor.grad.numpy(), np_one_mask, atol=1e-5, rtol=1e-5)\n    )\n\n\ndef do_test_dropout_addend_numpy_p1(test_case, shape, device, dtype):\n    np_x = np.random.randn(*shape).astype(dtype)\n    np_addend = np.random.randn(*shape).astype(dtype)\n    np_one_mask = np.ones_like(np_x)\n    np_zero_mask = np.zeros_like(np_x)\n    x_tensor = flow.tensor(np_x, requires_grad=True, device=device)\n    addend_tensor = flow.tensor(np_addend, requires_grad=True, device=device)\n    DropoutModule = flow.nn.Dropout(p=1.0)\n    out = DropoutModule(x_tensor, addend_tensor)\n    test_case.assertTrue(np.allclose(out.numpy(), np_addend, atol=1e-5, rtol=1e-5))\n    out_sum = out.sum()\n    out_sum.backward()\n    test_case.assertTrue(\n        np.allclose(x_tensor.grad.numpy(), np_zero_mask, atol=1e-5, rtol=1e-5)\n    )\n    test_case.assertTrue(\n        np.allclose(addend_tensor.grad.numpy(), np_one_mask, atol=1e-5, rtol=1e-5)\n    )\n\n\ndef do_test_dropout_addend_numpy_fp16_p0(test_case, shape):\n    np_x = np.random.randn(*shape).astype(np.float32)\n    np_x_fp16 = np_x.astype(np.float16)\n    np_addend = np.random.randn(*shape).astype(np.float32)\n    np_addend_fp16 = np_addend.astype(np.float16)\n    x_tensor = flow.tensor(np_x, requires_grad=True, device=\"cuda\")\n    x_tensor_fp16 = flow.cast(x_tensor, flow.float16)\n    addend_tensor = flow.tensor(np_addend, requires_grad=True, device=\"cuda\")\n    addend_tensor_fp16 = flow.cast(addend_tensor, flow.float16)\n    np_one_mask = np.ones_like(np_x)\n    DropoutModule = flow.nn.Dropout(p=0.0)\n    out = DropoutModule(x_tensor_fp16, addend_tensor_fp16)\n    out_fp32 = flow.cast(out, flow.float32)\n    test_case.assertTrue(\n        np.allclose(out_fp32.numpy(), np_x_fp16 + np_addend_fp16, atol=1e-5, rtol=1e-5)\n    )\n    out_sum = out_fp32.sum()\n    out_sum.backward()\n    test_case.assertTrue(\n        np.allclose(x_tensor.grad.numpy(), np_one_mask, atol=1e-5, rtol=1e-5)\n    )\n    test_case.assertTrue(\n        np.allclose(addend_tensor.grad.numpy(), np_one_mask, atol=1e-5, rtol=1e-5)\n    )\n\n\ndef do_test_dropout_addend_numpy_fp16_p1(test_case, shape):\n    np_x = np.random.randn(*shape).astype(np.float32)\n    np_addend = np.random.randn(*shape).astype(np.float32)\n    np_addend_fp16 = np_addend.astype(np.float16)\n    x_tensor = flow.tensor(np_x, requires_grad=True, device=\"cuda\")\n    x_tensor_fp16 = flow.cast(x_tensor, flow.float16)\n    addend_tensor = flow.tensor(np_addend, requires_grad=True, device=\"cuda\")\n    addend_tensor_fp16 = flow.cast(addend_tensor, flow.float16)\n    np_zero_mask = np.zeros_like(np_x)\n    np_one_mask = np.ones_like(np_x)\n    DropoutModule = flow.nn.Dropout(p=1.0)\n    out = DropoutModule(x_tensor_fp16, addend_tensor_fp16)\n    out_fp32 = flow.cast(out, flow.float32)\n    test_case.assertTrue(\n        np.allclose(out_fp32.numpy(), np_addend_fp16, atol=1e-5, rtol=1e-5)\n    )\n    out_sum = out_fp32.sum()\n    out_sum.backward()\n    test_case.assertTrue(\n        np.allclose(x_tensor.grad.numpy(), np_zero_mask, atol=1e-5, rtol=1e-5)\n    
)\n    test_case.assertTrue(\n        np.allclose(addend_tensor.grad.numpy(), np_one_mask, atol=1e-5, rtol=1e-5)\n    )\n\n\ndef fixed_cpu_seed_dropout_test(test_case):\n    gen1 = flow.Generator()\n    gen1.manual_seed(5)\n    dropped_array1 = np.array(\n        [\n            [0.000000, 0.000000, 1.333333],\n            [1.333333, 0.000000, 1.333333],\n            [1.333333, 1.333333, 1.333333],\n        ]\n    ).astype(np.float32)\n    dropout1 = flow.nn.Dropout(p=0.25, generator=gen1)\n    x = flow.ones((3, 3), dtype=flow.float32)\n    out1 = dropout1(x)\n    test_case.assertTrue(\n        np.allclose(out1.numpy(), dropped_array1, atol=1e-4, rtol=1e-4)\n    )\n    gen2 = flow.Generator()\n    gen2.manual_seed(7)\n    dropout2 = flow.nn.Dropout(p=0.5, generator=gen2)\n    dropped_array2 = np.array(\n        [[0.0, 0.0, 2.0], [0.0, 0.0, 2.0], [2.0, 0.0, 2.0]]\n    ).astype(np.float32)\n    out2 = dropout2(x)\n    test_case.assertTrue(\n        np.allclose(out2.numpy(), dropped_array2, atol=1e-4, rtol=1e-4)\n    )\n\n\ndef fixed_gpu_seed_dropout_test(test_case):\n    gen1 = flow.Generator()\n    gen1.manual_seed(5)\n    dropped_array1 = np.array(\n        [[1.2500, 0.0000, 1.2500], [1.2500, 1.2500, 1.2500], [1.2500, 1.2500, 1.2500]]\n    ).astype(np.float32)\n    dropout1 = flow.nn.Dropout(p=0.2, generator=gen1).to(\"cuda\")\n    x = flow.ones((3, 3), dtype=flow.float32).to(\"cuda\")\n    out1 = dropout1(x)\n    test_case.assertTrue(\n        np.allclose(out1.numpy(), dropped_array1, atol=1e-4, rtol=1e-4)\n    )\n    gen2 = flow.Generator()\n    gen2.manual_seed(7)\n    dropout2 = flow.nn.Dropout(p=0.7, generator=gen2).to(\"cuda\")\n    dropped_array2 = np.array(\n        [\n            [3.333333, 3.333333, 0.000000],\n            [0.000000, 0.000000, 0.000000],\n            [0.000000, 0.000000, 0.000000],\n        ]\n    ).astype(np.float32)\n    out2 = dropout2(x)\n    test_case.assertTrue(\n        np.allclose(out2.numpy(), dropped_array2, atol=1e-4, rtol=1e-4)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModule(flow.unittest.TestCase):\n    def test_dropout_numpy_case(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [do_test_dropout_numpy_p0, do_test_dropout_numpy_p1]\n        arg_dict[\"shape\"] = [[4], [4, 3], [4, 127, 256], [2, 1024, 1024]]\n        arg_dict[\"device\"] = [\"cuda\"]\n        if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            arg_dict[\"device\"] = [\"cpu\"]\n        arg_dict[\"dtype\"] = [np.float32, np.float64]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_dropout_fp16_numpy_case(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            do_test_dropout_numpy_fp16_p0,\n            do_test_dropout_numpy_fp16_p1,\n        ]\n        arg_dict[\"shape\"] = [[4, 127, 256], [5, 63, 49], [7, 32, 64], [16, 512, 512]]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_dropout_addend_numpy_case(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            do_test_dropout_addend_numpy_p0,\n            do_test_dropout_addend_numpy_p1,\n        ]\n        arg_dict[\"shape\"] = [[4, 47, 156], [5, 33, 65], [3, 132, 94], [9, 256, 63]]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"dtype\"] = [np.float32, np.float64]\n        for arg in GenArgList(arg_dict):\n  
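          # each arg is (impl_fn, shape, device, dtype); call the impl with the rest\n  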
          arg[0](test_case, *arg[1:])\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_dropout_addend_fp16_numpy_case(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            do_test_dropout_addend_numpy_fp16_p0,\n            do_test_dropout_addend_numpy_fp16_p1,\n        ]\n        arg_dict[\"shape\"] = [[2, 44, 66], [1, 2, 7], [5, 32, 74], [8, 125, 63]]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_cpu_fixed_dropout(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            fixed_cpu_seed_dropout_test,\n        ]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_gpu_fixed_dropout(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            fixed_gpu_seed_dropout_test,\n        ]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case)\n\n    @autotest(n=5)\n    def test_dropout_p0(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device)\n        m = torch.nn.Dropout(p=0, inplace=False)\n        return m(x)\n\n    @unittest.skipIf(True, \"PyTorch 1.10.0 does not have the Dropout1d module\")\n    @autotest(n=5)\n    def test_dropout1d_p0(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(2, 4), dim0=random(1, 8)).to(device)\n        m = torch.nn.Dropout1d(p=0, inplace=False)\n        return m(x)\n\n    @autotest(n=5)\n    def test_dropout2d_p0(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device)\n        m = torch.nn.Dropout2d(p=0, inplace=False)\n        return m(x)\n\n    @unittest.skipIf(\n        True,\n        \"this passes with PyTorch 1.13.0, but fails with PyTorch 1.10.0 because some non-leaf tensors don't have grad\",\n    )\n    @autotest(n=5)\n    def test_dropout3d_p0(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device)\n        m = torch.nn.Dropout3d(p=0, inplace=False)\n        return m(x)\n\n    @autotest(n=5)\n    def test_dropout_p1(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device)\n        m = torch.nn.Dropout(p=1.0, inplace=False)\n        return m(x)\n\n    @unittest.skipIf(True, \"PyTorch 1.10.0 does not have the Dropout1d module\")\n    @autotest(n=5)\n    def test_dropout1d_p1(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(2, 4), dim0=random(1, 8)).to(device)\n        m = torch.nn.Dropout1d(p=1.0, inplace=False)\n        return m(x)\n\n    @unittest.skip(\"skip for now, because it failed 8 times in the past week\")\n    @autotest(n=5)\n    def test_dropout2d_p1(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device)\n        m = torch.nn.Dropout2d(p=1.0, inplace=False)\n        return m(x)\n\n    @unittest.skipIf(\n        True,\n        \"this passes with PyTorch 1.13.0, but fails with PyTorch 1.10.0 because some non-leaf tensors don't have grad\",\n    )\n    @autotest(n=5)\n    def test_dropout3d_p1(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device)\n        m = torch.nn.Dropout3d(p=1.0, 
inplace=False)\n        return m(x)\n\n    @unittest.skipIf(True, \"PyTorch 1.10.0 does not have the Dropout1d module\")\n    @autotest(n=5)\n    def test_functional_dropout1d_p1(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(2, 4), dim0=random(1, 8)).to(device)\n        return torch.nn.functional.dropout1d(x, p=1.0)\n\n    @autotest(n=5)\n    def test_functional_dropout2d_p1(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device)\n        return torch.nn.functional.dropout2d(x, p=1.0)\n\n    @unittest.skipIf(\n        True,\n        \"this passes with PyTorch 1.13.0, but fails with PyTorch 1.10.0 because some non-leaf tensors don't have grad\",\n    )\n    @autotest(n=5)\n    def test_functional_dropout3d_p1(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device)\n        return torch.nn.functional.dropout3d(x, p=1.0)\n\n    @autotest(n=5, check_graph=False)\n    def test_dropout_eval(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device)\n        m = torch.nn.Dropout(p=1.0, inplace=False)\n        m.eval()\n        return m(x)\n\n    @unittest.skipIf(True, \"PyTorch 1.10.0 does not have the Dropout1d module\")\n    @autotest(n=5, check_graph=False)\n    def test_dropout1d_eval(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(2, 4), dim0=random(1, 8)).to(device)\n        m = torch.nn.Dropout1d(p=1.0, inplace=False)\n        m.eval()\n        return m(x)\n\n    @autotest(n=5, check_graph=False)\n    def test_dropout2d_eval(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device)\n        m = torch.nn.Dropout2d(p=1.0, inplace=False)\n        m.eval()\n        return m(x)\n\n    @unittest.skipIf(\n        True,\n        \"this passes with PyTorch 1.13.0, but fails with PyTorch 1.10.0 because some non-leaf tensors don't have grad\",\n    )\n    @autotest(n=5, check_graph=False)\n    def test_dropout3d_eval(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device)\n        m = torch.nn.Dropout3d(p=1.0, inplace=False)\n        m.eval()\n        return m(x)\n\n    @autotest(n=5, check_graph=False)\n    def test_0dim_dropout_eval(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        m = torch.nn.Dropout(p=1.0, inplace=False)\n        m.eval()\n        return m(x)\n\n    @profile(torch.nn.functional.dropout)\n    def profile_dropout(test_case):\n        input = torch.ones(100, 128)\n        torch.nn.functional.dropout(input, p=0.3)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestDropoutOnNonDefaultDevice(flow.unittest.TestCase):\n    def test_non_default_device(test_case):\n        x = flow.tensor([2, 3], dtype=flow.float, device=\"cuda:1\")\n        y = flow._C.dropout(x)\n        test_case.assertEqual(y.device, flow.device(\"cuda:1\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_dynamic_allocation_gradient_shuffle_shuffle_global.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\n\n# dynamic memory allocation can't be tested in unittest\nos.environ[\"ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION\"] = \"1\"\nimport unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgDict\nimport numpy as np\nimport oneflow as flow\n\nparallel_num = 2\nmax_id = 1000\n\n\ndef get_tensors(batch_size, num_tables):\n    placement = flow.placement(type=\"cuda\", ranks=list(range(parallel_num)))\n    ids = np.random.randint(0, max_id, (batch_size, num_tables), dtype=np.int64)\n    ids_tensor = flow.tensor(ids, requires_grad=False).to_global(\n        placement=placement, sbp=flow.sbp.split(0)\n    )\n    table_ids = (\n        ids % num_tables\n    )  # same id must have same table id, so in this case get table_ids from ids\n    table_ids_tensor = flow.tensor(\n        table_ids.astype(np.int32), requires_grad=False\n    ).to_global(placement=placement, sbp=flow.sbp.split(0))\n    return ids_tensor, table_ids_tensor\n\n\ndef round_half_away_from_zero(x):\n    sign = np.sign(x)\n    abs_val = np.abs(x)\n    abs_val += 0.5\n    floor_val = np.floor(abs_val)\n    out = floor_val * sign\n    return out\n\n\ndef _test_embedding_gradient_shuffle(test_case, enable_quantize, fp16, embedding_size):\n    np_tolerance = 0\n    batch_size = int(1024 / parallel_num)\n    placement = flow.placement(type=\"cuda\", ranks=list(range(parallel_num)))\n    num_tables = 26\n    enable_quantized_comm = enable_quantize and embedding_size < 1025\n    if enable_quantized_comm:\n        np_tolerance = 0.5\n        os.environ[\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\"] = \"1\"\n    else:\n        if fp16:\n            np_tolerance = 1e-2\n        else:\n            np_tolerance = 1e-4\n        os.environ[\"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM\"] = \"0\"\n    embedding_grad = np.random.rand(batch_size, num_tables, embedding_size).astype(\n        np.float32\n    )\n    embedding_grad_tensor = flow.tensor(embedding_grad, requires_grad=False).to_global(\n        placement=placement, sbp=flow.sbp.split(0)\n    )\n\n    class TestGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self, ids, table_ids, embedding_grad):\n            (\n                num_unique_matrix,\n                inverse_unique_partition_indices,\n                cur_rank_num_unique,\n                cur_rank_unique_ids,\n                _,\n                cur_rank_inverse_indices,\n            ) = flow._C.one_embedding_id_shuffle(ids, table_ids, num_tables, \"test\")\n            if fp16:\n                embedding_grad = flow.cast(embedding_grad, flow.float16)\n            cur_rank_unique_embedding_grad = flow._C.one_embedding_embedding_gradient_shuffle(\n                embedding_grad,\n                num_unique_matrix,\n                cur_rank_inverse_indices,\n                
inverse_unique_partition_indices,\n                \"test\",\n            )\n            if fp16:\n                cur_rank_unique_embedding_grad = flow.cast(\n                    cur_rank_unique_embedding_grad, flow.float32\n                )\n            return (\n                cur_rank_unique_embedding_grad,\n                flow.cast(cur_rank_num_unique, flow.int32),\n                cur_rank_unique_ids,\n            )\n\n    graph = TestGraph()\n    for i in range(10):\n        ids_tensor, table_ids_tensor = get_tensors(batch_size, num_tables)\n        graph(ids_tensor, table_ids_tensor, embedding_grad_tensor)\n    ids_tensor, table_ids_tensor = get_tensors(batch_size, num_tables)\n    (\n        cur_rank_unique_embedding_grad,\n        local_cur_rank_num_unique,\n        cur_rank_unique_ids,\n    ) = graph(ids_tensor, table_ids_tensor, embedding_grad_tensor)\n    cur_rank_num_unique = local_cur_rank_num_unique.to_local().to_global(\n        placement=placement, sbp=flow.sbp.split(0)\n    )\n    global_ids = ids_tensor.numpy()\n    global_embedding_grad = embedding_grad_tensor.numpy()\n    np_unique_ids = np.unique(global_ids)\n    np_num_unique = np_unique_ids.size\n    np_cur_rank_unique_embedding_grad = np.zeros((max_id, embedding_size))\n    if fp16:\n        global_embedding_grad = global_embedding_grad.astype(np.float16)\n    for k in range(np_num_unique):\n        unique_id = np_unique_ids[k]\n        np_data = sum(\n            global_embedding_grad.reshape(-1, embedding_size)[\n                np.where(global_ids.flatten() == unique_id)[0]\n            ]\n        )\n        # Quantize Embedding Gradient.\n        if enable_quantized_comm:\n            abs_max_factor = np.max(np.abs(np_data))\n            int8_factor = np.full(abs_max_factor.shape, 127.0, dtype=np.float32)\n            quantize_factor = int8_factor / abs_max_factor\n            np_data = np_data * quantize_factor\n            np_data = round_half_away_from_zero(np_data)\n            np_data = np_data.astype(np.int8)\n            np_data = np_data.astype(np.float32)\n            dequantize_factor = abs_max_factor / int8_factor\n            np_data = np_data * dequantize_factor\n\n        np_cur_rank_unique_embedding_grad[unique_id, :] = np_data\n        if fp16:\n            np_cur_rank_unique_embedding_grad = np_cur_rank_unique_embedding_grad.astype(\n                np.float32\n            )\n\n    cur_rank_num_ids = batch_size * num_tables * parallel_num\n    of_unique_embedding_grad = np.zeros((max_id, embedding_size))\n    for i in range(parallel_num):\n        num_unique_i = cur_rank_num_unique.numpy()[i]\n        unique_ids_i = cur_rank_unique_ids.numpy()[\n            cur_rank_num_ids * i : cur_rank_num_ids * (i + 1)\n        ]\n        unique_embedding_grad_i = cur_rank_unique_embedding_grad.numpy()[\n            cur_rank_num_ids * i : cur_rank_num_ids * (i + 1)\n        ]\n        for j in range(num_unique_i):\n            unique_id = unique_ids_i[j]\n            of_unique_embedding_grad[unique_id, :] = unique_embedding_grad_i[j, :]\n\n    test_case.assertTrue(\n        np.allclose(\n            of_unique_embedding_grad,\n            np_cur_rank_unique_embedding_grad,\n            atol=np_tolerance,\n            rtol=np_tolerance,\n        ),\n    )\n\n\n# FIXME: restore this test after upgrading CUDA driver\n@unittest.skip(\"CUDA driver version of CI machine is insufficient for this test\")\n# @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu 
cases\")\n@flow.unittest.skip_unless_1n2d()\nclass DataShuffleTestCase(flow.unittest.TestCase):\n    def test_embedding_gradient_shuffle(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"enable_quantize\"] = [True, False]\n        arg_dict[\"fp16\"] = [True, False]\n        arg_dict[\"embedding_size\"] = [128, 17]\n        for kwargs in GenArgDict(arg_dict):\n            _test_embedding_gradient_shuffle(test_case, **kwargs)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_eager_boxing.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport os\n\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_eager_boxing_with_non_overlapping_placement_p_to_s1(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1])\n    x = tensor.to_global(placement, flow.sbp.partial_sum)\n    new_placement = flow.placement(out_device, ranks=[2, 3])\n    y = x.to_global(new_placement, flow.sbp.split(1))\n    test_case.assertEqual(y.placement, new_placement)\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[6, 16], [9, 17], [7, 13], [12, 16],], dtype=np.float32,),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[15, 27], [19, 5], [11, 9], [15, 4],], dtype=np.float32,),\n            )\n        )\n\n\ndef _test_eager_boxing_with_non_overlapping_placement_b_to_s1(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1])\n    x = tensor.to_global(placement, flow.sbp.broadcast)\n    new_placement = flow.placement(out_device, ranks=[2, 3])\n 
   y = x.to_global(new_placement, flow.sbp.split(1))\n    test_case.assertEqual(y.placement, new_placement)\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[4, 6], [6, 8], [3, 7], [6, 8],], dtype=np.float32,),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[5, 20], [9, 0], [5, 0], [9, 0],], dtype=np.float32,),\n            )\n        )\n\n\ndef _test_eager_boxing_with_non_overlapping_placement_s0_to_s1(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    new_placement = flow.placement(out_device, ranks=[2, 3])\n    y = x.to_global(new_placement, flow.sbp.split(1))\n    test_case.assertEqual(y.placement, new_placement)\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [[4, 6], [6, 8], [3, 7], [6, 8], [2, 10], [3, 9], [4, 6], [6, 8],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [\n                        [5, 20],\n                        [9, 0],\n                        [5, 0],\n                        [9, 0],\n                        [10, 7],\n                        [10, 5],\n                        [6, 9],\n                        [6, 4],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_non_overlapping_placement_s1_to_s1(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    
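# S(0) -> S(1) across disjoint placements: the 4x4 locals on ranks 0 and 1\n    # stack into an 8x4 global tensor, and re-splitting on axis 1 over ranks\n    # [2, 3] hands each destination rank an 8x2 column slice.\n    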
placement = flow.placement(in_device, ranks=[0, 1])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = x.to_global(placement, flow.sbp.split(1))\n    new_placement = flow.placement(out_device, ranks=[2, 3])\n    z = y.to_global(new_placement, flow.sbp.split(1))\n    test_case.assertEqual(z.placement, new_placement)\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [[4, 6], [6, 8], [3, 7], [6, 8], [2, 10], [3, 9], [4, 6], [6, 8],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [5, 20],\n                        [9, 0],\n                        [5, 0],\n                        [9, 0],\n                        [10, 7],\n                        [10, 5],\n                        [6, 9],\n                        [6, 4],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_non_overlapping_placement_s1_to_s0(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = x.to_global(placement, flow.sbp.split(1))\n    new_placement = flow.placement(out_device, ranks=[2, 3])\n    z = y.to_global(new_placement, flow.sbp.split(0))\n    test_case.assertEqual(z.placement, new_placement)\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_non_overlapping_placement_s1_to_b(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        
)\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = x.to_global(placement, flow.sbp.split(1))\n    new_placement = flow.placement(out_device, ranks=[2, 3])\n    z = y.to_global(new_placement, flow.sbp.broadcast)\n    test_case.assertEqual(z.placement, new_placement)\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6, 5, 20],\n                        [6, 8, 9, 0],\n                        [3, 7, 5, 0],\n                        [6, 8, 9, 0],\n                        [2, 10, 10, 7],\n                        [3, 9, 10, 5],\n                        [4, 6, 6, 9],\n                        [6, 8, 6, 4],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6, 5, 20],\n                        [6, 8, 9, 0],\n                        [3, 7, 5, 0],\n                        [6, 8, 9, 0],\n                        [2, 10, 10, 7],\n                        [3, 9, 10, 5],\n                        [4, 6, 6, 9],\n                        [6, 8, 6, 4],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_non_overlapping_placement_s1_to_p(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = x.to_global(placement, flow.sbp.split(1))\n    new_placement = flow.placement(out_device, ranks=[2, 3])\n    z = y.to_global(new_placement, flow.sbp.partial_sum)\n    test_case.assertEqual(z.placement, new_placement)\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6, 0, 0],\n                        [6, 8, 0, 0],\n                        [3, 7, 0, 0],\n                        [6, 8, 0, 
0],\n                        [2, 10, 0, 0],\n                        [3, 9, 0, 0],\n                        [4, 6, 0, 0],\n                        [6, 8, 0, 0],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [0, 0, 5, 20],\n                        [0, 0, 9, 0],\n                        [0, 0, 5, 0],\n                        [0, 0, 9, 0],\n                        [0, 0, 10, 7],\n                        [0, 0, 10, 5],\n                        [0, 0, 6, 9],\n                        [0, 0, 6, 4],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_overlapping_placement_p_to_s1(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.partial_sum)\n    new_placement = flow.placement(out_device, ranks=[2, 3])\n    y = x.to_global(new_placement, flow.sbp.split(1))\n    test_case.assertEqual(y.placement, new_placement)\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[15, 20], [16, 19], [13, 16], [15, 23],], dtype=np.float32,),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[20, 35], [28, 10], [20, 11], [20, 12],], dtype=np.float32,),\n            )\n        )\n\n\ndef _test_eager_boxing_with_overlapping_placement_b_to_s1(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.broadcast)\n    
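# the source placement {0, 1, 3} and the target placement {2, 3} overlap on\n    # rank 3, exercising the boxing path where one rank is both sender and\n    # receiver.\n    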
new_placement = flow.placement(out_device, ranks=[2, 3])\n    y = x.to_global(new_placement, flow.sbp.split(1))\n    test_case.assertEqual(y.placement, new_placement)\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[4, 6], [6, 8], [3, 7], [6, 8],], dtype=np.float32,),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[5, 20], [9, 0], [5, 0], [9, 0],], dtype=np.float32,),\n            )\n        )\n\n\ndef _test_eager_boxing_with_overlapping_placement_s0_to_s1(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    new_placement = flow.placement(out_device, ranks=[2, 3])\n    y = x.to_global(new_placement, flow.sbp.split(1))\n    test_case.assertEqual(y.placement, new_placement)\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6],\n                        [6, 8],\n                        [3, 7],\n                        [6, 8],\n                        [2, 10],\n                        [3, 9],\n                        [4, 6],\n                        [6, 8],\n                        [9, 4],\n                        [7, 2],\n                        [6, 3],\n                        [3, 7],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [\n                        [5, 20],\n                        [9, 0],\n                        [5, 0],\n                        [9, 0],\n                        [10, 7],\n                        [10, 5],\n                        [6, 9],\n                        [6, 4],\n                        [5, 8],\n                        [9, 5],\n                        [9, 2],\n                        [5, 8],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_overlapping_placement_s1_to_s1(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [\n                [4, 6, 5, 20, 8, 9],\n                [6, 8, 9, 0, 4, 6],\n                [3, 7, 5, 0, 3, 5],\n                [6, 8, 9, 0, 8, 7],\n            ],\n    
        dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [\n                [2, 10, 10, 7, 10, 3],\n                [3, 9, 10, 5, 5, 6],\n                [4, 6, 6, 9, 8, 6],\n                [6, 8, 6, 4, 5, 3],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [\n                [9, 6, 5, 8, 3, 6],\n                [4, 9, 7, 0, 2, 1],\n                [2, 5, 7, 9, 4, 8],\n                [6, 8, 10, 0, 4, 9],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [\n                [9, 4, 5, 8, 9, 6],\n                [7, 2, 9, 5, 4, 1],\n                [6, 3, 9, 2, 5, 2],\n                [3, 7, 5, 8, 9, 3],\n            ],\n            dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = x.to_global(placement, flow.sbp.split(1))\n    new_placement = flow.placement(out_device, ranks=[2, 3])\n    z = y.to_global(new_placement, flow.sbp.split(1))\n    test_case.assertEqual(z.placement, new_placement)\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6, 5],\n                        [6, 8, 9],\n                        [3, 7, 5],\n                        [6, 8, 9],\n                        [2, 10, 10],\n                        [3, 9, 10],\n                        [4, 6, 6],\n                        [6, 8, 6],\n                        [9, 4, 5],\n                        [7, 2, 9],\n                        [6, 3, 9],\n                        [3, 7, 5],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [20, 8, 9],\n                        [0, 4, 6],\n                        [0, 3, 5],\n                        [0, 8, 7],\n                        [7, 10, 3],\n                        [5, 5, 6],\n                        [9, 8, 6],\n                        [4, 5, 3],\n                        [8, 9, 6],\n                        [5, 4, 1],\n                        [2, 5, 2],\n                        [8, 9, 3],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_overlapping_placement_s1_to_s0(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [\n                [4, 6, 5, 20, 8, 9],\n                [6, 8, 9, 0, 4, 6],\n                [3, 7, 5, 0, 3, 5],\n                [6, 8, 9, 0, 8, 7],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [\n                [2, 10, 10, 7, 10, 3],\n                [3, 9, 10, 5, 5, 6],\n                [4, 6, 6, 9, 8, 6],\n                [6, 8, 6, 4, 5, 3],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            
[\n                [9, 6, 5, 8, 3, 6],\n                [4, 9, 7, 0, 2, 1],\n                [2, 5, 7, 9, 4, 8],\n                [6, 8, 10, 0, 4, 9],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [\n                [9, 4, 5, 8, 9, 6],\n                [7, 2, 9, 5, 4, 1],\n                [6, 3, 9, 2, 5, 2],\n                [3, 7, 5, 8, 9, 3],\n            ],\n            dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = x.to_global(placement, flow.sbp.split(1))\n    new_placement = flow.placement(out_device, ranks=[2, 3])\n    z = y.to_global(new_placement, flow.sbp.split(0))\n    test_case.assertEqual(z.placement, new_placement)\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6, 5, 20, 8, 9],\n                        [6, 8, 9, 0, 4, 6],\n                        [3, 7, 5, 0, 3, 5],\n                        [6, 8, 9, 0, 8, 7],\n                        [2, 10, 10, 7, 10, 3],\n                        [3, 9, 10, 5, 5, 6],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6, 6, 9, 8, 6],\n                        [6, 8, 6, 4, 5, 3],\n                        [9, 4, 5, 8, 9, 6],\n                        [7, 2, 9, 5, 4, 1],\n                        [6, 3, 9, 2, 5, 2],\n                        [3, 7, 5, 8, 9, 3],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_overlapping_placement_s1_to_b(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [\n                [4, 6, 5, 20, 8, 9],\n                [6, 8, 9, 0, 4, 6],\n                [3, 7, 5, 0, 3, 5],\n                [6, 8, 9, 0, 8, 7],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [\n                [2, 10, 10, 7, 10, 3],\n                [3, 9, 10, 5, 5, 6],\n                [4, 6, 6, 9, 8, 6],\n                [6, 8, 6, 4, 5, 3],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [\n                [9, 6, 5, 8, 3, 6],\n                [4, 9, 7, 0, 2, 1],\n                [2, 5, 7, 9, 4, 8],\n                [6, 8, 10, 0, 4, 9],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [\n                [9, 4, 5, 8, 9, 6],\n                [7, 2, 9, 5, 4, 1],\n                [6, 3, 9, 2, 5, 2],\n                [3, 7, 5, 8, 9, 3],\n            ],\n            dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = 
x.to_global(placement, flow.sbp.split(1))\n    new_placement = flow.placement(out_device, ranks=[2, 3])\n    z = y.to_global(new_placement, flow.sbp.broadcast)\n    test_case.assertEqual(z.placement, new_placement)\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6, 5, 20, 8, 9],\n                        [6, 8, 9, 0, 4, 6],\n                        [3, 7, 5, 0, 3, 5],\n                        [6, 8, 9, 0, 8, 7],\n                        [2, 10, 10, 7, 10, 3],\n                        [3, 9, 10, 5, 5, 6],\n                        [4, 6, 6, 9, 8, 6],\n                        [6, 8, 6, 4, 5, 3],\n                        [9, 4, 5, 8, 9, 6],\n                        [7, 2, 9, 5, 4, 1],\n                        [6, 3, 9, 2, 5, 2],\n                        [3, 7, 5, 8, 9, 3],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6, 5, 20, 8, 9],\n                        [6, 8, 9, 0, 4, 6],\n                        [3, 7, 5, 0, 3, 5],\n                        [6, 8, 9, 0, 8, 7],\n                        [2, 10, 10, 7, 10, 3],\n                        [3, 9, 10, 5, 5, 6],\n                        [4, 6, 6, 9, 8, 6],\n                        [6, 8, 6, 4, 5, 3],\n                        [9, 4, 5, 8, 9, 6],\n                        [7, 2, 9, 5, 4, 1],\n                        [6, 3, 9, 2, 5, 2],\n                        [3, 7, 5, 8, 9, 3],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_overlapping_placement_s1_to_p(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [\n                [4, 6, 5, 20, 8, 9],\n                [6, 8, 9, 0, 4, 6],\n                [3, 7, 5, 0, 3, 5],\n                [6, 8, 9, 0, 8, 7],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [\n                [2, 10, 10, 7, 10, 3],\n                [3, 9, 10, 5, 5, 6],\n                [4, 6, 6, 9, 8, 6],\n                [6, 8, 6, 4, 5, 3],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [\n                [9, 6, 5, 8, 3, 6],\n                [4, 9, 7, 0, 2, 1],\n                [2, 5, 7, 9, 4, 8],\n                [6, 8, 10, 0, 4, 9],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [\n                [9, 4, 5, 8, 9, 6],\n                [7, 2, 9, 5, 4, 1],\n                [6, 3, 9, 2, 5, 2],\n                [3, 7, 5, 8, 9, 3],\n            ],\n            dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = x.to_global(placement, flow.sbp.split(1))\n    new_placement = flow.placement(out_device, ranks=[2, 3])\n    z = y.to_global(new_placement, flow.sbp.partial_sum)\n    
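# z is partial_sum on ranks [2, 3]: rank 2 keeps columns 0-1, rank 3 keeps columns 2-5, and their element-wise sum reproduces the full 12x6 global tensor\n    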
test_case.assertEqual(z.placement, new_placement)\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6, 0, 0, 0, 0],\n                        [6, 8, 0, 0, 0, 0],\n                        [3, 7, 0, 0, 0, 0],\n                        [6, 8, 0, 0, 0, 0],\n                        [2, 10, 0, 0, 0, 0],\n                        [3, 9, 0, 0, 0, 0],\n                        [4, 6, 0, 0, 0, 0],\n                        [6, 8, 0, 0, 0, 0],\n                        [9, 4, 0, 0, 0, 0],\n                        [7, 2, 0, 0, 0, 0],\n                        [6, 3, 0, 0, 0, 0],\n                        [3, 7, 0, 0, 0, 0],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [0, 0, 5, 20, 8, 9],\n                        [0, 0, 9, 0, 4, 6],\n                        [0, 0, 5, 0, 3, 5],\n                        [0, 0, 9, 0, 8, 7],\n                        [0, 0, 10, 7, 10, 3],\n                        [0, 0, 10, 5, 5, 6],\n                        [0, 0, 6, 9, 8, 6],\n                        [0, 0, 6, 4, 5, 3],\n                        [0, 0, 5, 8, 9, 6],\n                        [0, 0, 9, 5, 4, 1],\n                        [0, 0, 9, 2, 5, 2],\n                        [0, 0, 5, 8, 9, 3],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_in_placement_contain_out_placement_p_to_s1(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.partial_sum)\n    new_placement = flow.placement(out_device, ranks=[1, 3])\n    y = x.to_global(new_placement, flow.sbp.split(1))\n    test_case.assertEqual(y.placement, new_placement)\n    if flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[15, 20], [16, 19], [13, 16], [15, 23],], dtype=np.float32,),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[20, 35], [28, 10], [20, 11], [20, 12],], dtype=np.float32,),\n            )\n        )\n\n\ndef _test_eager_boxing_with_in_placement_contain_out_placement_b_to_s1(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 
0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.broadcast)\n    new_placement = flow.placement(out_device, ranks=[1, 3])\n    y = x.to_global(new_placement, flow.sbp.split(1))\n    test_case.assertEqual(y.placement, new_placement)\n    if flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[4, 6], [6, 8], [3, 7], [6, 8],], dtype=np.float32,),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[5, 20], [9, 0], [5, 0], [9, 0],], dtype=np.float32,),\n            )\n        )\n\n\ndef _test_eager_boxing_with_in_placement_contain_out_placement_s0_to_s1(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    new_placement = flow.placement(out_device, ranks=[1, 3])\n    y = x.to_global(new_placement, flow.sbp.split(1))\n    test_case.assertEqual(y.placement, new_placement)\n    if flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6],\n                        [6, 8],\n                        [3, 7],\n                        [6, 8],\n                        [2, 10],\n                        [3, 9],\n                        [4, 6],\n                        [6, 8],\n                        [9, 4],\n                        [7, 2],\n                        [6, 3],\n                        [3, 7],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n  
                  [\n                        [5, 20],\n                        [9, 0],\n                        [5, 0],\n                        [9, 0],\n                        [10, 7],\n                        [10, 5],\n                        [6, 9],\n                        [6, 4],\n                        [5, 8],\n                        [9, 5],\n                        [9, 2],\n                        [5, 8],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_in_placement_contain_out_placement_s1_to_s1(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 2, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.broadcast)\n    y = x.to_global(placement, flow.sbp.split(1))\n    new_placement = flow.placement(out_device, ranks=[1, 3])\n    z = y.to_global(new_placement, flow.sbp.split(1))\n    test_case.assertEqual(z.placement, new_placement)\n    if flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array([[4, 6], [6, 8], [3, 7], [6, 8],], dtype=np.float32,),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array([[5, 20], [9, 0], [5, 0], [9, 0],], dtype=np.float32,),\n            )\n        )\n\n\ndef _test_eager_boxing_with_in_placement_contain_out_placement_s1_to_s0(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 2, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.broadcast)\n    y = x.to_global(placement, flow.sbp.split(1))\n    new_placement = flow.placement(out_device, ranks=[1, 3])\n    z = y.to_global(new_placement, flow.sbp.split(0))\n    test_case.assertEqual(z.placement, new_placement)\n    if flow.env.get_rank() == 1:\n        
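# split(0) over ranks [1, 3] halves the four rows of the broadcast source: rank 1 checks rows 0-1, rank 3 checks rows 2-3\n        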
test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array([[4, 6, 5, 20], [6, 8, 9, 0],], dtype=np.float32,),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array([[3, 7, 5, 0], [6, 8, 9, 0],], dtype=np.float32,),\n            )\n        )\n\n\ndef _test_eager_boxing_with_in_placement_contain_out_placement_s1_to_p(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 2, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.broadcast)\n    y = x.to_global(placement, flow.sbp.split(1))\n    new_placement = flow.placement(out_device, ranks=[1, 3])\n    z = y.to_global(new_placement, flow.sbp.partial_sum)\n    test_case.assertEqual(z.placement, new_placement)\n    if flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [[4, 6, 5, 0], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [[0, 0, 0, 20], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_in_placement_contain_out_placement_s1_to_b(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 2, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.broadcast)\n    y = x.to_global(placement, flow.sbp.split(1))\n    new_placement = flow.placement(out_device, ranks=[1, 3])\n    z = y.to_global(new_placement, flow.sbp.broadcast)\n    test_case.assertEqual(z.placement, 
new_placement)\n    if flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_out_placement_contain_in_placement_p_to_s1(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.partial_sum)\n    new_placement = flow.placement(out_device, ranks=[0, 1, 2, 3])\n    y = x.to_global(new_placement, flow.sbp.split(1))\n    test_case.assertEqual(y.placement, new_placement)\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[15], [16], [13], [15],], dtype=np.float32,),\n            )\n        )\n    if flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[20], [19], [16], [23],], dtype=np.float32,),\n            )\n        )\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[20], [28], [20], [20],], dtype=np.float32,),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[35], [10], [11], [12],], dtype=np.float32,),\n            )\n        )\n\n\ndef _test_eager_boxing_with_out_placement_contain_in_placement_b_to_s1(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], 
dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.broadcast)\n    new_placement = flow.placement(out_device, ranks=[0, 1, 2, 3])\n    y = x.to_global(new_placement, flow.sbp.split(1))\n    test_case.assertEqual(y.placement, new_placement)\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[4], [6], [3], [6],], dtype=np.float32,),\n            )\n        )\n    if flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[6], [8], [7], [8],], dtype=np.float32,),\n            )\n        )\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[5], [9], [5], [9],], dtype=np.float32,),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array([[20], [0], [0], [0],], dtype=np.float32,),\n            )\n        )\n\n\ndef _test_eager_boxing_with_out_placement_contain_in_placement_s0_to_s1(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    new_placement = flow.placement(out_device, ranks=[0, 1, 2, 3])\n    y = x.to_global(new_placement, flow.sbp.split(1))\n    test_case.assertEqual(y.placement, new_placement)\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [[4], [6], [3], [6], [2], [3], [4], [6], [9], [7], [6], [3],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [[6], [8], [7], [8], [10], [9], [6], [8], [4], [2], [3], [7],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [[5], [9], [5], [9], [10], [10], [6], [6], [5], [9], [9], [5],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() 
== 3:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [[20], [0], [0], [0], [7], [5], [9], [4], [8], [5], [2], [8],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_out_placement_contain_in_placement_s1_to_b(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [\n                [4, 6, 5, 20, 8, 9],\n                [6, 8, 9, 0, 4, 6],\n                [3, 7, 5, 0, 3, 5],\n                [6, 8, 9, 0, 8, 7],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [\n                [2, 10, 10, 7, 10, 3],\n                [3, 9, 10, 5, 5, 6],\n                [4, 6, 6, 9, 8, 6],\n                [6, 8, 6, 4, 5, 3],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [\n                [9, 6, 5, 8, 3, 6],\n                [4, 9, 7, 0, 2, 1],\n                [2, 5, 7, 9, 4, 8],\n                [6, 8, 10, 0, 4, 9],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [\n                [9, 4, 5, 8, 9, 6],\n                [7, 2, 9, 5, 4, 1],\n                [6, 3, 9, 2, 5, 2],\n                [3, 7, 5, 8, 9, 3],\n            ],\n            dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = x.to_global(placement, flow.sbp.split(1))\n    new_placement = flow.placement(out_device, ranks=[0, 1, 2, 3])\n    z = y.to_global(new_placement, flow.sbp.broadcast)\n    test_case.assertEqual(z.placement, new_placement)\n    test_case.assertTrue(\n        np.array_equal(\n            z.to_local().numpy(),\n            np.array(\n                [\n                    [4, 6, 5, 20, 8, 9],\n                    [6, 8, 9, 0, 4, 6],\n                    [3, 7, 5, 0, 3, 5],\n                    [6, 8, 9, 0, 8, 7],\n                    [2, 10, 10, 7, 10, 3],\n                    [3, 9, 10, 5, 5, 6],\n                    [4, 6, 6, 9, 8, 6],\n                    [6, 8, 6, 4, 5, 3],\n                    [9, 4, 5, 8, 9, 6],\n                    [7, 2, 9, 5, 4, 1],\n                    [6, 3, 9, 2, 5, 2],\n                    [3, 7, 5, 8, 9, 3],\n                ],\n                dtype=np.float32,\n            ),\n        )\n    )\n\n\ndef _test_eager_boxing_with_out_placement_contain_in_placement_s1_to_p(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [\n                [4, 6, 5, 20, 8, 9],\n                [6, 8, 9, 0, 4, 6],\n                [3, 7, 5, 0, 3, 5],\n                [6, 8, 9, 0, 8, 7],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [\n                [2, 10, 10, 7, 10, 3],\n                [3, 9, 10, 5, 5, 6],\n                [4, 6, 6, 9, 8, 6],\n                [6, 8, 6, 4, 5, 3],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [\n                [9, 6, 5, 8, 3, 6],\n              
  [4, 9, 7, 0, 2, 1],\n                [2, 5, 7, 9, 4, 8],\n                [6, 8, 10, 0, 4, 9],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [\n                [9, 4, 5, 8, 9, 6],\n                [7, 2, 9, 5, 4, 1],\n                [6, 3, 9, 2, 5, 2],\n                [3, 7, 5, 8, 9, 3],\n            ],\n            dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = x.to_global(placement, flow.sbp.split(1))\n    new_placement = flow.placement(out_device, ranks=[0, 1, 2, 3])\n    z = y.to_global(new_placement, flow.sbp.partial_sum)\n    test_case.assertEqual(z.placement, new_placement)\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6, 0, 0, 0, 0],\n                        [6, 8, 0, 0, 0, 0],\n                        [3, 7, 0, 0, 0, 0],\n                        [6, 8, 0, 0, 0, 0],\n                        [2, 10, 0, 0, 0, 0],\n                        [3, 9, 0, 0, 0, 0],\n                        [4, 6, 0, 0, 0, 0],\n                        [6, 8, 0, 0, 0, 0],\n                        [9, 4, 0, 0, 0, 0],\n                        [7, 2, 0, 0, 0, 0],\n                        [6, 3, 0, 0, 0, 0],\n                        [3, 7, 0, 0, 0, 0],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    elif flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [0, 0, 5, 20, 0, 0],\n                        [0, 0, 9, 0, 0, 0],\n                        [0, 0, 5, 0, 0, 0],\n                        [0, 0, 9, 0, 0, 0],\n                        [0, 0, 10, 7, 0, 0],\n                        [0, 0, 10, 5, 0, 0],\n                        [0, 0, 6, 9, 0, 0],\n                        [0, 0, 6, 4, 0, 0],\n                        [0, 0, 5, 8, 0, 0],\n                        [0, 0, 9, 5, 0, 0],\n                        [0, 0, 9, 2, 0, 0],\n                        [0, 0, 5, 8, 0, 0],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    elif flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [0, 0, 0, 0, 0, 0],\n                        [0, 0, 0, 0, 0, 0],\n                        [0, 0, 0, 0, 0, 0],\n                        [0, 0, 0, 0, 0, 0],\n                        [0, 0, 0, 0, 0, 0],\n                        [0, 0, 0, 0, 0, 0],\n                        [0, 0, 0, 0, 0, 0],\n                        [0, 0, 0, 0, 0, 0],\n                        [0, 0, 0, 0, 0, 0],\n                        [0, 0, 0, 0, 0, 0],\n                        [0, 0, 0, 0, 0, 0],\n                        [0, 0, 0, 0, 0, 0],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    else:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                   
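# rank 3 (the else branch) retains its split(1) slice, columns 4-5, as its partial-sum term\n                   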
 [\n                        [0, 0, 0, 0, 8, 9],\n                        [0, 0, 0, 0, 4, 6],\n                        [0, 0, 0, 0, 3, 5],\n                        [0, 0, 0, 0, 8, 7],\n                        [0, 0, 0, 0, 10, 3],\n                        [0, 0, 0, 0, 5, 6],\n                        [0, 0, 0, 0, 8, 6],\n                        [0, 0, 0, 0, 5, 3],\n                        [0, 0, 0, 0, 9, 6],\n                        [0, 0, 0, 0, 4, 1],\n                        [0, 0, 0, 0, 5, 2],\n                        [0, 0, 0, 0, 9, 3],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_out_placement_contain_in_placement_s1_to_s0(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [\n                [4, 6, 5, 20, 8, 9],\n                [6, 8, 9, 0, 4, 6],\n                [3, 7, 5, 0, 3, 5],\n                [6, 8, 9, 0, 8, 7],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [\n                [2, 10, 10, 7, 10, 3],\n                [3, 9, 10, 5, 5, 6],\n                [4, 6, 6, 9, 8, 6],\n                [6, 8, 6, 4, 5, 3],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [\n                [9, 6, 5, 8, 3, 6],\n                [4, 9, 7, 0, 2, 1],\n                [2, 5, 7, 9, 4, 8],\n                [6, 8, 10, 0, 4, 9],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [\n                [9, 4, 5, 8, 9, 6],\n                [7, 2, 9, 5, 4, 1],\n                [6, 3, 9, 2, 5, 2],\n                [3, 7, 5, 8, 9, 3],\n            ],\n            dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = x.to_global(placement, flow.sbp.split(1))\n    new_placement = flow.placement(out_device, ranks=[0, 1, 2, 3])\n    z = y.to_global(new_placement, flow.sbp.split(0))\n    test_case.assertEqual(z.placement, new_placement)\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [[4, 6, 5, 20, 8, 9], [6, 8, 9, 0, 4, 6], [3, 7, 5, 0, 3, 5],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    elif flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [[6, 8, 9, 0, 8, 7], [2, 10, 10, 7, 10, 3], [3, 9, 10, 5, 5, 6],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    elif flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [[4, 6, 6, 9, 8, 6], [6, 8, 6, 4, 5, 3], [9, 4, 5, 8, 9, 6],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    elif flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [[7, 2, 
9, 5, 4, 1], [6, 3, 9, 2, 5, 2], [3, 7, 5, 8, 9, 3],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_out_placement_contain_in_placement_s1_to_s1(\n    test_case, in_device, out_device\n):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [\n                [4, 6, 5, 20, 8, 9, 5, 20],\n                [6, 8, 9, 0, 4, 6, 9, 0],\n                [3, 7, 5, 0, 3, 5, 0, 3],\n                [6, 8, 9, 0, 8, 7, 8, 9],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [\n                [2, 10, 10, 7, 10, 3, 10, 7],\n                [3, 9, 10, 5, 5, 6, 9, 10],\n                [4, 6, 6, 9, 8, 6, 6, 9],\n                [6, 8, 6, 4, 5, 3, 8, 6],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [\n                [9, 6, 5, 8, 3, 6, 8, 3],\n                [4, 9, 7, 0, 2, 1, 9, 7],\n                [2, 5, 7, 9, 4, 8, 5, 7],\n                [6, 8, 10, 0, 4, 9, 8, 10],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [\n                [9, 4, 5, 8, 9, 6, 5, 8],\n                [7, 2, 9, 5, 4, 1, 7, 2],\n                [6, 3, 9, 2, 5, 2, 9, 2],\n                [3, 7, 5, 8, 9, 3, 7, 5],\n            ],\n            dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = x.to_global(placement, flow.sbp.split(1))\n    new_placement = flow.placement(out_device, ranks=[0, 1, 2, 3])\n    z = y.to_global(new_placement, flow.sbp.split(1))\n    test_case.assertEqual(z.placement, new_placement)\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [[4, 6], [6, 8], [3, 7], [6, 8], [2, 10], [3, 9], [4, 6], [6, 8],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    elif flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [5, 20],\n                        [9, 0],\n                        [5, 0],\n                        [9, 0],\n                        [10, 7],\n                        [10, 5],\n                        [6, 9],\n                        [6, 4],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    elif flow.env.get_rank() == 2:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [[8, 9], [4, 6], [3, 5], [8, 7], [10, 3], [5, 6], [8, 6], [5, 3],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    elif flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [5, 20],\n                        [9, 0],\n                        [0, 3],\n                        [8, 9],\n                        [10, 7],\n                   
     [9, 10],\n                        [6, 9],\n                        [8, 6],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_same_placement_p_to_s1(test_case, in_device, out_device):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [\n                [4, 6, 5, 20, 8, 9],\n                [6, 8, 9, 0, 4, 6],\n                [3, 7, 5, 0, 3, 5],\n                [6, 8, 9, 0, 8, 7],\n                [6, 8, 9, 0, 4, 6],\n                [6, 8, 6, 4, 5, 3],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [\n                [2, 10, 10, 7, 10, 3],\n                [3, 9, 10, 5, 5, 6],\n                [4, 6, 6, 9, 8, 6],\n                [6, 8, 6, 4, 5, 3],\n                [4, 9, 7, 0, 2, 1],\n                [6, 3, 9, 2, 5, 2],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [\n                [9, 6, 5, 8, 3, 6],\n                [4, 9, 7, 0, 2, 1],\n                [2, 5, 7, 9, 4, 8],\n                [6, 8, 10, 0, 4, 9],\n                [6, 3, 9, 2, 5, 2],\n                [2, 5, 7, 9, 4, 8],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [\n                [9, 4, 5, 8, 9, 6],\n                [7, 2, 9, 5, 4, 1],\n                [6, 3, 9, 2, 5, 2],\n                [3, 7, 5, 8, 9, 3],\n                [7, 2, 9, 5, 4, 1],\n                [4, 9, 7, 0, 2, 1],\n            ],\n            dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.partial_sum)\n    y = x.to_global(placement, flow.sbp.split(1))\n    test_case.assertEqual(y.placement, placement)\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [[15, 20], [16, 19], [13, 16], [15, 23], [17, 19], [16, 20],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [[20, 35], [28, 10], [20, 11], [20, 12], [25, 5], [22, 6],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [[27, 18], [13, 13], [16, 13], [22, 13], [10, 8], [12, 6],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_same_placement_b_to_s1(test_case, in_device, out_device):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [\n                [4, 6, 5, 20, 8, 9],\n                [6, 8, 9, 0, 4, 6],\n                [3, 7, 5, 0, 3, 5],\n                [6, 8, 9, 0, 8, 7],\n                [6, 8, 9, 0, 4, 6],\n                [6, 8, 6, 4, 5, 3],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n           
 [\n                [2, 10, 10, 7, 10, 3],\n                [3, 9, 10, 5, 5, 6],\n                [4, 6, 6, 9, 8, 6],\n                [6, 8, 6, 4, 5, 3],\n                [4, 9, 7, 0, 2, 1],\n                [6, 3, 9, 2, 5, 2],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [\n                [9, 6, 5, 8, 3, 6],\n                [4, 9, 7, 0, 2, 1],\n                [2, 5, 7, 9, 4, 8],\n                [6, 8, 10, 0, 4, 9],\n                [6, 3, 9, 2, 5, 2],\n                [2, 5, 7, 9, 4, 8],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [\n                [9, 4, 5, 8, 9, 6],\n                [7, 2, 9, 5, 4, 1],\n                [6, 3, 9, 2, 5, 2],\n                [3, 7, 5, 8, 9, 3],\n                [7, 2, 9, 5, 4, 1],\n                [4, 9, 7, 0, 2, 1],\n            ],\n            dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.broadcast)\n    y = x.to_global(placement, flow.sbp.split(1))\n    test_case.assertEqual(y.placement, placement)\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [[4, 6], [6, 8], [3, 7], [6, 8], [6, 8], [6, 8],], dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [[5, 20], [9, 0], [5, 0], [9, 0], [9, 0], [6, 4],],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [[8, 9], [4, 6], [3, 5], [8, 7], [4, 6], [5, 3],], dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_same_placement_s0_to_s1(test_case, in_device, out_device):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [\n                [4, 6, 5, 20, 8, 9],\n                [6, 8, 9, 0, 4, 6],\n                [3, 7, 5, 0, 3, 5],\n                [6, 8, 9, 0, 8, 7],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [\n                [2, 10, 10, 7, 10, 3],\n                [3, 9, 10, 5, 5, 6],\n                [4, 6, 6, 9, 8, 6],\n                [6, 8, 6, 4, 5, 3],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [\n                [9, 6, 5, 8, 3, 6],\n                [4, 9, 7, 0, 2, 1],\n                [2, 5, 7, 9, 4, 8],\n                [6, 8, 10, 0, 4, 9],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [\n                [9, 4, 5, 8, 9, 6],\n                [7, 2, 9, 5, 4, 1],\n                [6, 3, 9, 2, 5, 2],\n                [3, 7, 5, 8, 9, 3],\n            ],\n            dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, 
device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = x.to_global(placement, flow.sbp.split(1))\n    test_case.assertEqual(y.placement, placement)\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6],\n                        [6, 8],\n                        [3, 7],\n                        [6, 8],\n                        [2, 10],\n                        [3, 9],\n                        [4, 6],\n                        [6, 8],\n                        [9, 4],\n                        [7, 2],\n                        [6, 3],\n                        [3, 7],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [\n                        [5, 20],\n                        [9, 0],\n                        [5, 0],\n                        [9, 0],\n                        [10, 7],\n                        [10, 5],\n                        [6, 9],\n                        [6, 4],\n                        [5, 8],\n                        [9, 5],\n                        [9, 2],\n                        [5, 8],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                y.to_local().numpy(),\n                np.array(\n                    [\n                        [8, 9],\n                        [4, 6],\n                        [3, 5],\n                        [8, 7],\n                        [10, 3],\n                        [5, 6],\n                        [8, 6],\n                        [5, 3],\n                        [9, 6],\n                        [4, 1],\n                        [5, 2],\n                        [9, 3],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_same_placement_s1_to_s1(test_case, in_device, out_device):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [\n                [4, 6, 5, 20, 8, 9],\n                [6, 8, 9, 0, 4, 6],\n                [3, 7, 5, 0, 3, 5],\n                [6, 8, 9, 0, 8, 7],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [\n                [2, 10, 10, 7, 10, 3],\n                [3, 9, 10, 5, 5, 6],\n                [4, 6, 6, 9, 8, 6],\n                [6, 8, 6, 4, 5, 3],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [\n                [9, 6, 5, 8, 3, 6],\n                [4, 9, 7, 0, 2, 1],\n                [2, 5, 7, 9, 4, 8],\n                [6, 8, 10, 0, 4, 9],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [\n                [9, 4, 5, 8, 9, 6],\n                [7, 2, 9, 5, 4, 1],\n                [6, 3, 9, 2, 5, 2],\n                [3, 7, 5, 8, 9, 3],\n            ],\n            dtype=np.float32,\n       
 )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = x.to_global(placement, flow.sbp.split(1))\n    z = y.to_global(placement, flow.sbp.split(1))\n    test_case.assertEqual(z.placement, placement)\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6],\n                        [6, 8],\n                        [3, 7],\n                        [6, 8],\n                        [2, 10],\n                        [3, 9],\n                        [4, 6],\n                        [6, 8],\n                        [9, 4],\n                        [7, 2],\n                        [6, 3],\n                        [3, 7],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [5, 20],\n                        [9, 0],\n                        [5, 0],\n                        [9, 0],\n                        [10, 7],\n                        [10, 5],\n                        [6, 9],\n                        [6, 4],\n                        [5, 8],\n                        [9, 5],\n                        [9, 2],\n                        [5, 8],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [8, 9],\n                        [4, 6],\n                        [3, 5],\n                        [8, 7],\n                        [10, 3],\n                        [5, 6],\n                        [8, 6],\n                        [5, 3],\n                        [9, 6],\n                        [4, 1],\n                        [5, 2],\n                        [9, 3],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_same_placement_s1_to_s0(test_case, in_device, out_device):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [\n                [4, 6, 5, 20, 8, 9],\n                [6, 8, 9, 0, 4, 6],\n                [3, 7, 5, 0, 3, 5],\n                [6, 8, 9, 0, 8, 7],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [\n                [2, 10, 10, 7, 10, 3],\n                [3, 9, 10, 5, 5, 6],\n                [4, 6, 6, 9, 8, 6],\n                [6, 8, 6, 4, 5, 3],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [\n                [9, 6, 5, 8, 3, 6],\n                [4, 9, 7, 0, 2, 1],\n                [2, 5, 7, 9, 4, 8],\n                [6, 8, 10, 0, 4, 9],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [\n                [9, 4, 5, 8, 9, 6],\n                [7, 2, 9, 5, 4, 1],\n   
             [6, 3, 9, 2, 5, 2],\n                [3, 7, 5, 8, 9, 3],\n            ],\n            dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = x.to_global(placement, flow.sbp.split(1))\n    z = y.to_global(placement, flow.sbp.split(0))\n    test_case.assertEqual(z.placement, placement)\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6, 5, 20, 8, 9],\n                        [6, 8, 9, 0, 4, 6],\n                        [3, 7, 5, 0, 3, 5],\n                        [6, 8, 9, 0, 8, 7],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [2, 10, 10, 7, 10, 3],\n                        [3, 9, 10, 5, 5, 6],\n                        [4, 6, 6, 9, 8, 6],\n                        [6, 8, 6, 4, 5, 3],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [9, 4, 5, 8, 9, 6],\n                        [7, 2, 9, 5, 4, 1],\n                        [6, 3, 9, 2, 5, 2],\n                        [3, 7, 5, 8, 9, 3],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_same_placement_s1_to_p(test_case, in_device, out_device):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [\n                [4, 6, 5, 20, 8, 9],\n                [6, 8, 9, 0, 4, 6],\n                [3, 7, 5, 0, 3, 5],\n                [6, 8, 9, 0, 8, 7],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [\n                [2, 10, 10, 7, 10, 3],\n                [3, 9, 10, 5, 5, 6],\n                [4, 6, 6, 9, 8, 6],\n                [6, 8, 6, 4, 5, 3],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [\n                [9, 6, 5, 8, 3, 6],\n                [4, 9, 7, 0, 2, 1],\n                [2, 5, 7, 9, 4, 8],\n                [6, 8, 10, 0, 4, 9],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [\n                [9, 4, 5, 8, 9, 6],\n                [7, 2, 9, 5, 4, 1],\n                [6, 3, 9, 2, 5, 2],\n                [3, 7, 5, 8, 9, 3],\n            ],\n            dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = x.to_global(placement, flow.sbp.split(1))\n    z = y.to_global(placement, flow.sbp.partial_sum)\n    test_case.assertEqual(z.placement, placement)\n    if 
flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6, 0, 0, 0, 0],\n                        [6, 8, 0, 0, 0, 0],\n                        [3, 7, 0, 0, 0, 0],\n                        [6, 8, 0, 0, 0, 0],\n                        [2, 10, 0, 0, 0, 0],\n                        [3, 9, 0, 0, 0, 0],\n                        [4, 6, 0, 0, 0, 0],\n                        [6, 8, 0, 0, 0, 0],\n                        [9, 4, 0, 0, 0, 0],\n                        [7, 2, 0, 0, 0, 0],\n                        [6, 3, 0, 0, 0, 0],\n                        [3, 7, 0, 0, 0, 0],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [0, 0, 5, 20, 0, 0],\n                        [0, 0, 9, 0, 0, 0],\n                        [0, 0, 5, 0, 0, 0],\n                        [0, 0, 9, 0, 0, 0],\n                        [0, 0, 10, 7, 0, 0],\n                        [0, 0, 10, 5, 0, 0],\n                        [0, 0, 6, 9, 0, 0],\n                        [0, 0, 6, 4, 0, 0],\n                        [0, 0, 5, 8, 0, 0],\n                        [0, 0, 9, 5, 0, 0],\n                        [0, 0, 9, 2, 0, 0],\n                        [0, 0, 5, 8, 0, 0],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [0, 0, 0, 0, 8, 9],\n                        [0, 0, 0, 0, 4, 6],\n                        [0, 0, 0, 0, 3, 5],\n                        [0, 0, 0, 0, 8, 7],\n                        [0, 0, 0, 0, 10, 3],\n                        [0, 0, 0, 0, 5, 6],\n                        [0, 0, 0, 0, 8, 6],\n                        [0, 0, 0, 0, 5, 3],\n                        [0, 0, 0, 0, 9, 6],\n                        [0, 0, 0, 0, 4, 1],\n                        [0, 0, 0, 0, 5, 2],\n                        [0, 0, 0, 0, 9, 3],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\ndef _test_eager_boxing_with_same_placement_s1_to_b(test_case, in_device, out_device):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [\n                [4, 6, 5, 20, 8, 9],\n                [6, 8, 9, 0, 4, 6],\n                [3, 7, 5, 0, 3, 5],\n                [6, 8, 9, 0, 8, 7],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [\n                [2, 10, 10, 7, 10, 3],\n                [3, 9, 10, 5, 5, 6],\n                [4, 6, 6, 9, 8, 6],\n                [6, 8, 6, 4, 5, 3],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [\n                [9, 6, 5, 8, 3, 6],\n                [4, 9, 7, 0, 2, 1],\n                [2, 5, 7, 9, 4, 8],\n                [6, 8, 10, 0, 4, 9],\n            ],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [\n                [9, 4, 5, 8, 9, 
6],\n                [7, 2, 9, 5, 4, 1],\n                [6, 3, 9, 2, 5, 2],\n                [3, 7, 5, 8, 9, 3],\n            ],\n            dtype=np.float32,\n        )\n    device = flow.device(in_device)\n    tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    placement = flow.placement(in_device, ranks=[0, 1, 3])\n    x = tensor.to_global(placement, flow.sbp.split(0))\n    y = x.to_global(placement, flow.sbp.split(1))\n    z = y.to_global(placement, flow.sbp.broadcast)\n    test_case.assertEqual(z.placement, placement)\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6, 5, 20, 8, 9],\n                        [6, 8, 9, 0, 4, 6],\n                        [3, 7, 5, 0, 3, 5],\n                        [6, 8, 9, 0, 8, 7],\n                        [2, 10, 10, 7, 10, 3],\n                        [3, 9, 10, 5, 5, 6],\n                        [4, 6, 6, 9, 8, 6],\n                        [6, 8, 6, 4, 5, 3],\n                        [9, 4, 5, 8, 9, 6],\n                        [7, 2, 9, 5, 4, 1],\n                        [6, 3, 9, 2, 5, 2],\n                        [3, 7, 5, 8, 9, 3],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 1:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6, 5, 20, 8, 9],\n                        [6, 8, 9, 0, 4, 6],\n                        [3, 7, 5, 0, 3, 5],\n                        [6, 8, 9, 0, 8, 7],\n                        [2, 10, 10, 7, 10, 3],\n                        [3, 9, 10, 5, 5, 6],\n                        [4, 6, 6, 9, 8, 6],\n                        [6, 8, 6, 4, 5, 3],\n                        [9, 4, 5, 8, 9, 6],\n                        [7, 2, 9, 5, 4, 1],\n                        [6, 3, 9, 2, 5, 2],\n                        [3, 7, 5, 8, 9, 3],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n    if flow.env.get_rank() == 3:\n        test_case.assertTrue(\n            np.array_equal(\n                z.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6, 5, 20, 8, 9],\n                        [6, 8, 9, 0, 4, 6],\n                        [3, 7, 5, 0, 3, 5],\n                        [6, 8, 9, 0, 8, 7],\n                        [2, 10, 10, 7, 10, 3],\n                        [3, 9, 10, 5, 5, 6],\n                        [4, 6, 6, 9, 8, 6],\n                        [6, 8, 6, 4, 5, 3],\n                        [9, 4, 5, 8, 9, 6],\n                        [7, 2, 9, 5, 4, 1],\n                        [6, 3, 9, 2, 5, 2],\n                        [3, 7, 5, 8, 9, 3],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\ndef _test_eager_boxing_b_to_s(\n    test_case, shape, device_type, in_device_list, out_device_list, out_split_axis\n):\n    np_arr = np.random.uniform(-1e-05, 1e-05, shape)\n    # use cuda to avoid slice boxing here\n    placement_with_all_cuda_device = flow.placement.all(\"cuda\")\n\n    x = flow.tensor(np_arr, device=\"cuda\", dtype=flow.float32)\n    x = x.to_global(placement_with_all_cuda_device, 
flow.sbp.broadcast)\n\n    placement = flow.placement(device_type, in_device_list)\n    y = x.to_global(placement, flow.sbp.broadcast)\n    new_placement = flow.placement(device_type, out_device_list)\n    z = y.to_global(new_placement, flow.sbp.split(out_split_axis))\n\n    if flow.env.get_rank() in out_device_list:\n        idx = out_device_list.index(flow.env.get_rank())\n        step = int(shape[out_split_axis] / len(out_device_list))\n        if out_split_axis == 0:\n            test_case.assertTrue(\n                np.allclose(\n                    z.to_local().numpy(),\n                    x.to_local().numpy()[idx * step : (idx + 1) * step],\n                    1e-5,\n                    1e-5,\n                )\n            )\n        elif out_split_axis == 1:\n            test_case.assertTrue(\n                np.allclose(\n                    z.to_local().numpy(),\n                    x.to_local().numpy()[..., idx * step : (idx + 1) * step],\n                    1e-5,\n                    1e-5,\n                )\n            )\n        else:\n            raise ValueError(\n                \"only test case with out_split_axis == 0 or out_split_axis == 1\"\n            )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\ndef _test_eager_boxing_s_to_b(\n    test_case, shape, device_type, in_device_list, out_device_list, in_split_axis\n):\n    np_arr = np.random.uniform(-1e-05, 1e-05, shape)\n    # use cuda to avoid slice boxing here\n    placement_with_all_cuda_device = flow.placement.all(\"cuda\")\n\n    x = flow.tensor(np_arr, device=\"cuda\", dtype=flow.float32)\n    x = x.to_global(placement_with_all_cuda_device, flow.sbp.broadcast)\n\n    placement = flow.placement(device_type, in_device_list)\n    y = x.to_global(placement, flow.sbp.broadcast)\n\n    y = y.to_global(placement, flow.sbp.split(in_split_axis))\n\n    new_placement = flow.placement(device_type, out_device_list)\n    z = y.to_global(new_placement, flow.sbp.broadcast)\n\n    if flow.env.get_rank() in out_device_list:\n        test_case.assertTrue(\n            np.allclose(z.to_local().numpy(), x.to_local().numpy(), 1e-5, 1e-5,)\n        )\n    test_case.assertEqual(z.placement, new_placement)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\ndef _test_eager_boxing_p_to_s(\n    test_case, shape, device_type, in_device_list, out_device_list, out_split_axis\n):\n    np_arr = np.random.uniform(-1e-05, 1e-05, shape)\n    # use cuda to avoid slice boxing here\n    placement_with_all_cuda_device = flow.placement.all(\"cuda\")\n\n    x = flow.tensor(np_arr, device=\"cuda\", dtype=flow.float32)\n    x = x.to_global(placement_with_all_cuda_device, flow.sbp.broadcast)\n\n    placement = flow.placement(device_type, in_device_list)\n    y = x.to_global(placement, flow.sbp.broadcast)\n    y = y.to_global(placement, flow.sbp.partial_sum)\n    new_placement = flow.placement(device_type, out_device_list)\n    z = y.to_global(new_placement, flow.sbp.split(out_split_axis))\n\n    if flow.env.get_rank() in out_device_list:\n        idx = out_device_list.index(flow.env.get_rank())\n        step = int(shape[out_split_axis] / len(out_device_list))\n        if out_split_axis == 0:\n            test_case.assertTrue(\n                np.allclose(\n                    z.to_local().numpy(),\n                    x.to_local().numpy()[idx * step : (idx + 1) * step],\n                    1e-5,\n                    1e-5,\n                )\n            )\n        elif out_split_axis == 1:\n            
test_case.assertTrue(\n                np.allclose(\n                    z.to_local().numpy(),\n                    x.to_local().numpy()[..., idx * step : (idx + 1) * step],\n                    1e-5,\n                    1e-5,\n                )\n            )\n        else:\n            raise ValueError(\n                \"only test case with out_split_axis == 0 or out_split_axis == 1\"\n            )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\ndef _test_eager_boxing_p_to_b(\n    test_case, shape, device_type, in_device_list, out_device_list\n):\n    np_arr = np.random.uniform(-1e-05, 1e-05, shape)\n    # use cuda to avoid slice boxing here\n    placement_with_all_cuda_device = flow.placement.all(\"cuda\")\n\n    x = flow.tensor(np_arr, device=\"cuda\", dtype=flow.float32)\n    x = x.to_global(placement_with_all_cuda_device, flow.sbp.broadcast)\n\n    placement = flow.placement(device_type, in_device_list)\n    y = x.to_global(placement, flow.sbp.broadcast)\n    y = y.to_global(placement, flow.sbp.partial_sum)\n\n    new_placement = flow.placement(device_type, out_device_list)\n    z = y.to_global(new_placement, flow.sbp.broadcast)\n\n    if flow.env.get_rank() in out_device_list:\n        test_case.assertTrue(\n            np.allclose(z.to_local().numpy(), x.to_local().numpy(), 1e-5, 1e-5,)\n        )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\ndef _test_eager_naive_boxing_s_to_s(\n    test_case,\n    device_type,\n    shape,\n    in_device_list,\n    out_device_list,\n    in_split_axis,\n    out_split_axis,\n):\n    np_arr = np.random.uniform(-1e-05, 1e-05, shape)\n    placement_with_all_devices = flow.placement.all(device_type)\n\n    x = flow.tensor(np_arr, device=device_type, dtype=flow.float32)\n\n    x = x.to_global(placement_with_all_devices, flow.sbp.broadcast)\n\n    placement = flow.placement(device_type, in_device_list)\n    y = x.to_global(placement, flow.sbp.broadcast)\n    y = y.to_global(placement, flow.sbp.split(in_split_axis))\n\n    new_placement = flow.placement(device_type, out_device_list)\n    z = y.to_global(new_placement, flow.sbp.split(out_split_axis))\n\n    if flow.env.get_rank() in out_device_list:\n        idx = out_device_list.index(flow.env.get_rank())\n        step = int(shape[out_split_axis] / len(out_device_list))\n        if out_split_axis == 0:\n            test_case.assertTrue(\n                np.allclose(\n                    z.to_local().numpy(),\n                    x.to_local().numpy()[idx * step : (idx + 1) * step],\n                    1e-5,\n                    1e-5,\n                )\n            )\n        elif out_split_axis == 1:\n            test_case.assertTrue(\n                np.allclose(\n                    z.to_local().numpy(),\n                    x.to_local().numpy()[..., idx * step : (idx + 1) * step],\n                    1e-5,\n                    1e-5,\n                )\n            )\n        else:\n            raise ValueError(\n                \"only test case with out_split_axis == 0 or out_split_axis == 1\"\n            )\n    test_case.assertEqual(z.placement, new_placement)\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestEagerBoxingWithNonOverlappingPlacement(flow.unittest.TestCase):\n    def test_eager_boxing_with_non_overlapping_placement_p_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            
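# GenArgList expands the OrderedDict into the Cartesian product of its value\n            # lists, so each arg below is one (in_device, out_device) pair; the same\n            # pattern repeats for every test method in these classes.\n            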
_test_eager_boxing_with_non_overlapping_placement_p_to_s1(test_case, *arg)\n\n    def test_eager_boxing_with_non_overlapping_placement_b_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_non_overlapping_placement_b_to_s1(test_case, *arg)\n\n    def test_eager_boxing_with_non_overlapping_placement_s0_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_non_overlapping_placement_s0_to_s1(test_case, *arg)\n\n    def test_eager_boxing_with_non_overlapping_placement_s1_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_non_overlapping_placement_s1_to_s1(test_case, *arg)\n\n    def test_eager_boxing_with_non_overlapping_placement_s1_to_s0(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_non_overlapping_placement_s1_to_s0(test_case, *arg)\n\n    def test_eager_boxing_with_non_overlapping_placement_s1_to_b(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_non_overlapping_placement_s1_to_b(test_case, *arg)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_eager_boxing_with_non_overlapping_placement_s1_to_p(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_non_overlapping_placement_s1_to_p(test_case, *arg)\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestEagerBoxingWithOverlappingPlacement(flow.unittest.TestCase):\n    def test_eager_boxing_with_overlapping_placement_p_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_overlapping_placement_p_to_s1(test_case, *arg)\n\n    def test_eager_boxing_with_overlapping_placement_b_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_overlapping_placement_b_to_s1(test_case, *arg)\n\n    def test_eager_boxing_with_overlapping_placement_s0_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_overlapping_placement_s0_to_s1(test_case, *arg)\n\n    def test_eager_boxing_with_overlapping_placement_s1_to_s1(test_case):\n        arg_dict = 
OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_overlapping_placement_s1_to_s1(test_case, *arg)\n\n    def test_eager_boxing_with_overlapping_placement_s1_to_s0(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_overlapping_placement_s1_to_s0(test_case, *arg)\n\n    def test_eager_boxing_with_overlapping_placement_s1_to_b(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_overlapping_placement_s1_to_b(test_case, *arg)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_eager_boxing_with_overlapping_placement_s1_to_p(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_overlapping_placement_s1_to_p(test_case, *arg)\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestEagerBoxingWithInPlacementContainOutPlacement(flow.unittest.TestCase):\n    def test_eager_boxing_with_in_placement_contain_out_placement_p_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_in_placement_contain_out_placement_p_to_s1(\n                test_case, *arg\n            )\n\n    def test_eager_boxing_with_in_placement_contain_out_placement_b_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_in_placement_contain_out_placement_b_to_s1(\n                test_case, *arg\n            )\n\n    def test_eager_boxing_with_in_placement_contain_out_placement_s0_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_in_placement_contain_out_placement_s0_to_s1(\n                test_case, *arg\n            )\n\n    def test_eager_boxing_with_in_placement_contain_out_placement_s1_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_in_placement_contain_out_placement_s1_to_s1(\n                test_case, *arg\n            )\n\n    def test_eager_boxing_with_in_placement_contain_out_placement_s1_to_s0(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_in_placement_contain_out_placement_s1_to_s0(\n                test_case, *arg\n            )\n\n    
@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_eager_boxing_with_in_placement_contain_out_placement_s1_to_p(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_in_placement_contain_out_placement_s1_to_p(\n                test_case, *arg\n            )\n\n    def test_eager_boxing_with_in_placement_contain_out_placement_s1_to_b(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_in_placement_contain_out_placement_s1_to_b(\n                test_case, *arg\n            )\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestEagerBoxingWithOutPlacementContainInPlacement(flow.unittest.TestCase):\n    def test_eager_boxing_with_out_placement_contain_in_placement_p_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_out_placement_contain_in_placement_p_to_s1(\n                test_case, *arg\n            )\n\n    def test_eager_boxing_with_out_placement_contain_in_placement_b_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_out_placement_contain_in_placement_b_to_s1(\n                test_case, *arg\n            )\n\n    def test_eager_boxing_with_out_placement_contain_in_placement_s0_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_out_placement_contain_in_placement_s0_to_s1(\n                test_case, *arg\n            )\n\n    def test_eager_boxing_with_out_placement_contain_in_placement_s1_to_b(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_out_placement_contain_in_placement_s1_to_b(\n                test_case, *arg\n            )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_eager_boxing_with_out_placement_contain_in_placement_s1_to_p(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_out_placement_contain_in_placement_s1_to_p(\n                test_case, *arg\n            )\n\n    def test_eager_boxing_with_out_placement_contain_in_placement_s1_to_s0(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_out_placement_contain_in_placement_s1_to_s0(\n                test_case, *arg\n            )\n\n    def 
test_eager_boxing_with_out_placement_contain_in_placement_s1_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_out_placement_contain_in_placement_s1_to_s1(\n                test_case, *arg\n            )\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestEagerBoxingWithSameInOutPlacement(flow.unittest.TestCase):\n    def test_eager_boxing_with_same_placement_s0_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_same_placement_s0_to_s1(test_case, *arg)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_eager_boxing_with_same_placement_p_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_same_placement_p_to_s1(test_case, *arg)\n\n    def test_eager_boxing_with_same_placement_b_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_same_placement_b_to_s1(test_case, *arg)\n\n    def test_eager_boxing_with_same_placement_s1_to_s1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_same_placement_s1_to_s1(test_case, *arg)\n\n    def test_eager_boxing_with_same_placement_s1_to_s0(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_same_placement_s1_to_s0(test_case, *arg)\n\n    def test_eager_boxing_with_same_placement_s1_to_p(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_same_placement_s1_to_p(test_case, *arg)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_eager_boxing_with_same_placement_s1_to_b(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_with_same_placement_s1_to_b(test_case, *arg)\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestEagerBoxingBToS(flow.unittest.TestCase):\n    def test_eager_boxing_b_to_s(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(12, 12), (18, 24)]\n        arg_dict[\"device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"in_device_list\"] = [[0, 1], [1, 2, 3]]\n        arg_dict[\"out_device_list\"] = [[2, 3], [0, 1, 3]]\n        arg_dict[\"out_split_axis\"] = [0, 1]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_b_to_s(test_case, 
*arg)\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestEagerBoxingPToS(flow.unittest.TestCase):\n    def test_eager_boxing_p_to_s(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(12, 12), (18, 24)]\n        arg_dict[\"device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"in_device_list\"] = [[0, 1], [1, 2, 3]]\n        arg_dict[\"out_device_list\"] = [[2, 3], [0, 1, 3]]\n        arg_dict[\"out_split_axis\"] = [0, 1]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_p_to_s(test_case, *arg)\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestEagerBoxingSToB(flow.unittest.TestCase):\n    def test_eager_boxing_s_to_b(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(12, 12), (12, 18, 24)]\n        arg_dict[\"device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"in_device_list\"] = [[0, 1], [1, 2, 3]]\n        arg_dict[\"out_device_list\"] = [[2, 3], [0, 1, 3]]\n        arg_dict[\"in_split_axis\"] = [0, 1]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_s_to_b(test_case, *arg)\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestEagerBoxingPToB(flow.unittest.TestCase):\n    def test_eager_boxing_p_to_b(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(12, 12), (12, 18, 24)]\n        arg_dict[\"device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"in_device_list\"] = [[0, 1], [1, 2, 3]]\n        arg_dict[\"out_device_list\"] = [[2, 3], [0, 1, 3]]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_p_to_b(test_case, *arg)\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestEagerNaiveBoxingSToS(flow.unittest.TestCase):\n    def test_eager_naive_boxing_s_to_s(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(12, 12), (18, 24)]\n        arg_dict[\"in_device_list\"] = [[0, 1], [1, 2, 3]]\n        arg_dict[\"out_device_list\"] = [[1], [3], [2, 3], [0, 1, 3]]\n        arg_dict[\"in_split_axis\"] = [0, 1]\n        arg_dict[\"out_split_axis\"] = [0, 1]\n        for arg in GenArgList(arg_dict):\n            _test_eager_naive_boxing_s_to_s(test_case, *arg)\n\n\n@flow.unittest.skip_unless_1n2d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestEagerGlobalCastWithSamePlacementAndSBP(flow.unittest.TestCase):\n    def test_eager_global_cast_with_same_placement_and_sbp(test_case):\n        x = np.ones((4, 8), dtype=np.int32)\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        y = flow.tensor(\n            x,\n            dtype=flow.float32,\n            placement=placement,\n            sbp=[flow.sbp.split(0)],\n            requires_grad=False,\n        )\n        z = y.to_global(placement=placement, sbp=[flow.sbp.split(0)])\n        test_case.assertEqual(y.global_id(), z.global_id())\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestEagerGlobalCast1DTo2DSBP(flow.unittest.TestCase):\n    def test_eager_global_cast_1d_to_2d_sbp(test_case):\n        x = np.ones((4, 8), dtype=np.int32)\n        placement1 = flow.placement(\"cuda\", ranks=[0, 1, 2, 3])\n        placement2 = flow.placement(\"cuda\", ranks=[[0, 1], [2, 3]])\n        y = flow.tensor(\n            x,\n            dtype=flow.float32,\n            placement=placement1,\n            sbp=[flow.sbp.split(0)],\n            requires_grad=False,\n        )\n        z = 
y.to_global(\n            placement=placement2, sbp=[flow.sbp.broadcast, flow.sbp.split(0)]\n        )\n        test_case.assertEqual(z.placement, placement2)\n        test_case.assertTrue(\n            np.array_equal(z.to_local().numpy(), np.ones((2, 8), dtype=np.int32),)\n        )\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestEagerGlobalCast2DTo1DSBP(flow.unittest.TestCase):\n    def test_eager_global_cast_2d_to_1d_sbp(test_case):\n        x = np.ones((4, 8), dtype=np.int32)\n        placement1 = flow.placement(\"cuda\", ranks=[0, 1, 2, 3])\n        placement2 = flow.placement(\"cuda\", ranks=[[0, 1], [2, 3]])\n        y = flow.tensor(\n            x,\n            dtype=flow.float32,\n            placement=placement2,\n            sbp=[flow.sbp.broadcast, flow.sbp.split(0)],\n            requires_grad=False,\n        )\n        z = y.to_global(placement=placement1, sbp=[flow.sbp.split(0)])\n        test_case.assertEqual(z.placement, placement1)\n        test_case.assertTrue(\n            np.array_equal(z.to_local().numpy(), np.ones((1, 8), dtype=np.int32),)\n        )\n\n\ndef _test_eager_global_cast_1d_uneven_split(test_case, device_type, shape):\n    np_arr = np.random.uniform(-1e-05, 1e-05, shape)\n    placement = flow.placement(device_type, range(flow.env.get_world_size()))\n    x = flow.tensor(\n        np_arr, dtype=flow.float32, device=device_type, requires_grad=False,\n    )\n    x = x.to_global(placement=placement, sbp=[flow.sbp.broadcast])\n    # B To S(0)\n    y = x.to_global(placement=placement, sbp=[flow.sbp.split(0)])\n    from oneflow.framework import balanced_splitter\n\n    s0_balanced_ranges = balanced_splitter.BalancedRanges(\n        shape[0], flow.env.get_world_size()\n    )\n    s0_range_of_this_rank = s0_balanced_ranges[flow.env.get_rank()]\n    test_case.assertEqual(y.placement, placement)\n    test_case.assertTrue(\n        np.array_equal(\n            y.to_local().numpy(),\n            x.to_local().numpy()[s0_range_of_this_rank[0] : s0_range_of_this_rank[1]],\n        )\n    )\n\n    # S(0) To S(1)\n    z = y.to_global(placement=placement, sbp=[flow.sbp.split(1)])\n    s1_balanced_ranges = balanced_splitter.BalancedRanges(\n        shape[1], flow.env.get_world_size()\n    )\n    s1_range_of_this_rank = s1_balanced_ranges[flow.env.get_rank()]\n\n    test_case.assertEqual(z.placement, placement)\n    test_case.assertTrue(\n        np.allclose(\n            z.to_local().numpy(),\n            x.to_local().numpy()[\n                ..., s1_range_of_this_rank[0] : s1_range_of_this_rank[1]\n            ],\n        )\n    )\n\n    # S(1) To B\n    w = z.to_global(placement=placement, sbp=[flow.sbp.broadcast])\n    test_case.assertEqual(w.placement, placement)\n\n    test_case.assertTrue(np.allclose(w.to_local().numpy(), x.to_local().numpy()))\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestEagerGlobalCastOneDUnevenSplit(flow.unittest.TestCase):\n    def test_eager_global_cast_1d_uneven_split(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(25, 33), (13, 17)]\n        for arg in GenArgList(arg_dict):\n            _test_eager_global_cast_1d_uneven_split(test_case, *arg)\n\n\ndef _test_eager_global_n_dim_reduce(test_case, device_type, src_sbp, dst_sbp):\n    np.random.seed(10)\n    
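# The fixed seed above makes every rank construct the identical np_arr, which\n    # the broadcast-back comparison at the end of this helper relies on.\n    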
np_arr = np.random.uniform(-1e-05, 1e-05, (16, 32))\n    placement0 = flow.placement(device_type, ranks=[[0]])\n    placement1 = flow.placement(device_type, ranks=[[0, 1], [2, 3]])\n\n    # oneflow.placement(type=\"cuda\", ranks=[[0]])\n    # (src_sbp, src_sbp)\n    x = flow.tensor(\n        np_arr, placement=placement0, sbp=[src_sbp, src_sbp], requires_grad=False,\n    )\n\n    # oneflow.placement(type=\"cuda\", ranks=[[0,1],[2,3]])\n    # (dst_sbp, dst_sbp)\n    y = x.to_global(placement=placement1, sbp=[dst_sbp, dst_sbp])\n\n    z = y.to_global(placement=placement1, sbp=[flow.sbp.broadcast, flow.sbp.broadcast])\n    test_case.assertEqual(z.placement, placement1)\n\n    test_case.assertTrue(np.allclose(z.to_local().numpy(), np_arr))\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestEagerGlobalCastNDimReduceBoxing(flow.unittest.TestCase):\n    def test_eager_global_n_dim_reduce(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"src_sbp\"] = [flow.sbp.broadcast, flow.sbp.split(0), flow.sbp.split(1)]\n        arg_dict[\"dst_sbp\"] = [flow.sbp.broadcast, flow.sbp.split(0), flow.sbp.split(1)]\n        for arg in GenArgList(arg_dict):\n            _test_eager_global_n_dim_reduce(test_case, *arg)\n\n\ndef _test_eager_global_with_0_size_data(\n    test_case,\n    shape,\n    in_device_type,\n    out_device_type,\n    in_device_list,\n    out_device_list,\n    in_sbp,\n    out_sbp,\n):\n    in_placement = flow.placement(in_device_type, in_device_list)\n    out_placement = flow.placement(out_device_type, out_device_list)\n    x = flow.Tensor(*shape, placement=in_placement, sbp=in_sbp)\n    y = x.to_global(out_placement, out_sbp)\n\n    test_case.assertEqual(y.placement, out_placement)\n    test_case.assertEqual(y.sbp, out_sbp)\n    test_case.assertEqual(y.size(), shape)\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestEagerGlobalWith0SizeData(flow.unittest.TestCase):\n    def test_eager_global_with_0_size_data(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(8, 0, 4), (5, 0, 7)]\n        arg_dict[\"in_device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"in_device_list\"] = [[0, 1], [1, 2, 3], [0, 1, 2, 3]]\n        arg_dict[\"out_device_list\"] = [[1], [3], [2, 3], [0, 1, 3], [0, 1, 2, 3]]\n        arg_dict[\"in_sbp\"] = [\n            (flow.sbp.split(0),),\n            (flow.sbp.split(2),),\n            (flow.sbp.broadcast,),\n            (flow.sbp.partial_sum,),\n        ]\n        arg_dict[\"out_sbp\"] = [\n            (flow.sbp.split(0),),\n            (flow.sbp.split(2),),\n            (flow.sbp.broadcast,),\n            (flow.sbp.partial_sum,),\n        ]\n        for arg in GenArgList(arg_dict):\n            _test_eager_global_with_0_size_data(test_case, *arg)\n\n\ndef _test_eager_boxing_one_to_n_with_diff_dim(\n    test_case, in_device_type, out_device_type\n):\n    x = flow.tensor(\n        [1, 2, 3, 4],\n        sbp=flow.sbp.broadcast,\n        placement=flow.placement(in_device_type, ranks=[0]),\n    )\n    y = x.to_global(\n        sbp=[flow.sbp.broadcast, flow.sbp.split(0)],\n        placement=flow.placement(out_device_type, ranks=[[0, 1], [2, 3]]),\n    )\n\n    rank = flow.env.get_rank()\n    if rank == 0 or rank == 2:\n        test_case.assertTrue(np.array_equal(y.to_local().numpy(), np.array([1, 2]),))\n    elif rank == 1 or rank == 3:\n    
    test_case.assertTrue(np.array_equal(y.to_local().numpy(), np.array([3, 4]),))\n\n\ndef _test_eager_boxing_n_to_one_with_diff_dim(\n    test_case, in_device_type, out_device_type\n):\n    x = flow.tensor(\n        [1, 2, 3, 4],\n        sbp=[flow.sbp.broadcast, flow.sbp.split(0)],\n        placement=flow.placement(in_device_type, ranks=[[0, 1], [2, 3]]),\n    )\n    y = x.to_global(\n        sbp=flow.sbp.broadcast, placement=flow.placement(out_device_type, ranks=[0])\n    )\n\n    rank = flow.env.get_rank()\n    if rank == 0:\n        test_case.assertTrue(\n            np.array_equal(y.to_local().numpy(), np.array([1, 2, 3, 4]),)\n        )\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestEagerBoxingOneToNWithDiffDim(flow.unittest.TestCase):\n    def test_eager_boxing_one_to_n_with_diff_dim(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device_type\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_one_to_n_with_diff_dim(test_case, *arg)\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestEagerBoxingNToOneWithDiffDim(flow.unittest.TestCase):\n    def test_eager_boxing_n_to_one_with_diff_dim(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device_type\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_n_to_one_with_diff_dim(test_case, *arg)\n\n\ndef _test_asymmetric_mix_1d_2d_eager_boxing_with_random_placement(\n    test_case,\n    in_sbp,\n    out_sbp,\n    shape,\n    in_device_type,\n    out_device_type,\n    in_device_list,\n    out_device_list,\n):\n    if not isinstance(in_sbp, tuple):\n        in_sbp = (in_sbp,)\n    if not isinstance(out_sbp, tuple):\n        out_sbp = (out_sbp,)\n    in_placement = flow.placement(type=in_device_type, ranks=in_device_list)\n    out_placement = flow.placement(type=out_device_type, ranks=out_device_list)\n    np_arr = np.random.uniform(-1e-05, 1e-05, shape)\n    x = flow.tensor(\n        np_arr, dtype=flow.float32, device=in_device_type, requires_grad=False,\n    )\n    x = x.to_global(in_placement, in_sbp)\n    y = x.to_global(out_placement, out_sbp)\n    test_case.assertTrue(y.sbp == out_sbp)\n    test_case.assertTrue(y.placement == out_placement)\n    test_case.assertTrue(np.allclose(x.numpy(), y.numpy()))\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestEagerBoxingAsymmetricMix1d2dWithRandomPlacement(flow.unittest.TestCase):\n    def test_eager_boxing_asymmetric_mix_1d_2d_with_random_placement(test_case):\n        arg_dict = OrderedDict()\n        sbp_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(12, 24), (17, 13, 19)]\n\n        arg_dict[\"in_device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"in_device_list\"] = [\n            [2],\n            [0, 1],\n            [1, 2, 3],\n            [0, 1, 2, 3],\n            [[0, 1, 2, 3]],\n            [[0, 1], [2, 3]],\n        ]\n        arg_dict[\"out_device_list\"] = [\n            [1],\n            [3],\n            [2, 3],\n            [0, 1, 3],\n            [0, 1, 2, 3],\n            [[2], [3]],\n            [[0, 1], [2, 3]],\n        ]\n        sbp_1d = [\n            flow.sbp.split(0),\n            flow.sbp.split(1),\n            flow.sbp.broadcast,\n            
flow.sbp.partial_sum,\n        ]\n        sbp_dict[\"in_sbp_1d\"] = sbp_1d\n        sbp_dict[\"out_sbp_1d\"] = sbp_1d\n\n        import itertools\n\n        sbp_2d = list(itertools.product(sbp_1d, sbp_1d))\n        sbp_dict[\"in_sbp_2d\"] = sbp_2d\n        sbp_dict[\"out_sbp_2d\"] = sbp_2d\n\n        is_2d_device_list = lambda x: isinstance(x[0], list)\n\n        for arg in GenArgList(arg_dict):\n            in_device_list = arg[-2]\n            out_device_list = arg[-1]\n\n            is_in_2d_n_device_list = is_2d_device_list(in_device_list)\n            is_out_2d_n_device_list = is_2d_device_list(out_device_list)\n            if is_in_2d_n_device_list and is_out_2d_n_device_list:\n                for in_sbp in sbp_dict[\"in_sbp_2d\"]:\n                    for out_sbp in sbp_dict[\"out_sbp_2d\"]:\n                        _test_asymmetric_mix_1d_2d_eager_boxing_with_random_placement(\n                            test_case, in_sbp, out_sbp, *arg\n                        )\n            elif is_in_2d_n_device_list and not is_out_2d_n_device_list:\n                for in_sbp in sbp_dict[\"in_sbp_2d\"]:\n                    for out_sbp in sbp_dict[\"out_sbp_1d\"]:\n                        _test_asymmetric_mix_1d_2d_eager_boxing_with_random_placement(\n                            test_case, in_sbp, out_sbp, *arg\n                        )\n            elif not is_in_2d_n_device_list and is_out_2d_n_device_list:\n                for in_sbp in sbp_dict[\"in_sbp_1d\"]:\n                    for out_sbp in sbp_dict[\"out_sbp_2d\"]:\n                        _test_asymmetric_mix_1d_2d_eager_boxing_with_random_placement(\n                            test_case, in_sbp, out_sbp, *arg\n                        )\n            elif not is_in_2d_n_device_list and not is_out_2d_n_device_list:\n                for in_sbp in sbp_dict[\"in_sbp_1d\"]:\n                    for out_sbp in sbp_dict[\"out_sbp_1d\"]:\n                        _test_asymmetric_mix_1d_2d_eager_boxing_with_random_placement(\n                            test_case, in_sbp, out_sbp, *arg\n                        )\n            else:\n                raise NotImplementedError\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestEagerBoxing2DLocalToGlobalWithBalancedSplitSize(flow.unittest.TestCase):\n    def test_eager_boxing_2d_local_to_global_with_balanced_size(test_case):\n        placement = flow.placement(type=\"cpu\", ranks=np.arange(4).reshape((2, 2)))\n        sbp = (flow.sbp.split(0), flow.sbp.split(1))\n        x = flow.tensor(np.arange(25).reshape((5, 5)), placement=placement, sbp=sbp)\n        y = x.to_local()\n        z = y.to_global(placement=placement, sbp=sbp)\n\n        test_case.assertEqual(z.placement, placement)\n        test_case.assertEqual(z.sbp, sbp)\n        test_case.assertEqual(z.size(), (5, 5))\n        test_case.assertTrue(\n            np.allclose(z.numpy(), np.arange(25).reshape((5, 5)), 1e-5, 1e-5)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_eager_boxing_exhaustive.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport itertools\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\n\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_eager_boxing_normal_1d_exhaustive_testing(\n    test_case, shape, in_device, out_device, in_device_list, out_device_list\n):\n    sbps = [\n        flow.sbp.split(0),\n        flow.sbp.split(1),\n        flow.sbp.broadcast,\n        flow.sbp.partial_sum,\n    ]\n    in_placement = flow.placement(type=in_device, ranks=in_device_list)\n    out_placement = flow.placement(type=out_device, ranks=out_device_list)\n    rand_tensor = random_tensor(len(shape), *shape, requires_grad=False).oneflow\n    for elem in itertools.product(sbps, sbps):\n        x = rand_tensor.to_global(placement=in_placement, sbp=elem[0])\n        y = x.to_global(placement=out_placement, sbp=elem[1])\n        test_case.assertTrue(np.allclose(y.numpy(), x.numpy(), 1e-3, 1e-3))\n\n\ndef _test_eager_boxing_symmetric_2d_exhaustive_testing(\n    test_case, in_device, out_device\n):\n    sbps = [\n        flow.sbp.split(0),\n        flow.sbp.split(1),\n        flow.sbp.broadcast,\n        flow.sbp.partial_sum,\n    ]\n    nd_sbps = itertools.product(\n        itertools.product(sbps, sbps), itertools.product(sbps, sbps)\n    )\n    shape = (8, 8, 16)\n    in_placement = flow.placement(type=in_device, ranks=[[0, 1], [2, 3]])\n    out_placement = flow.placement(type=out_device, ranks=[[0, 1], [2, 3]])\n    rand_tensor = random_tensor(len(shape), *shape, requires_grad=False).oneflow\n    for elem in nd_sbps:\n        x = rand_tensor.to_global(placement=in_placement, sbp=elem[0])\n        y = x.to_global(placement=out_placement, sbp=elem[1])\n        test_case.assertTrue(np.allclose(y.numpy(), x.numpy(), 1e-3, 1e-3))\n\n\ndef _test_eager_boxing_1d_special_split_axis(\n    test_case, in_device, out_device, in_device_list, out_device_list\n):\n    sbps = [\n        flow.sbp.split(2),\n        flow.sbp.split(3),\n        flow.sbp.broadcast,\n        flow.sbp.partial_sum,\n    ]\n    shape = (4, 4, 5, 7)\n    in_placement = flow.placement(type=in_device, ranks=in_device_list)\n    out_placement = flow.placement(type=out_device, ranks=out_device_list)\n    rand_tensor = random_tensor(len(shape), *shape, requires_grad=False).oneflow\n    for elem in itertools.product(sbps, sbps):\n        x = rand_tensor.to_global(placement=in_placement, sbp=elem[0])\n        y = x.to_global(placement=out_placement, sbp=elem[1])\n        test_case.assertTrue(np.allclose(y.numpy(), x.numpy(), 1e-3, 1e-3))\n\n\ndef _test_eager_boxing_2d_special_split_axis(test_case, in_device, out_device):\n    sbps = [\n        flow.sbp.split(2),\n        flow.sbp.split(4),\n        flow.sbp.broadcast,\n        flow.sbp.partial_sum,\n    ]\n    nd_sbps = itertools.product(\n        
itertools.product(sbps, sbps), itertools.product(sbps, sbps)\n    )\n    shape = (4, 8, 4, 8, 4)\n    in_placement = flow.placement(type=in_device, ranks=[[0, 1], [2, 3]])\n    out_placement = flow.placement(type=out_device, ranks=[[0, 1], [2, 3]])\n    rand_tensor = random_tensor(len(shape), *shape, requires_grad=False).oneflow\n    for elem in nd_sbps:\n        x = rand_tensor.to_global(placement=in_placement, sbp=elem[0])\n        y = x.to_global(placement=out_placement, sbp=elem[1])\n        test_case.assertTrue(np.allclose(y.numpy(), x.numpy(), 1e-3, 1e-3))\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestEagerBoxingSymmetricExhaustiveTesting(flow.unittest.TestCase):\n    @globaltest\n    def test_eager_boxing_normal_1d_exhaustive_testing(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(4, 4), (6, 8), (5, 7)]\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"in_device_list\"] = [[0, 1], [1, 2, 3], [0, 1, 2, 3]]\n        arg_dict[\"out_device_list\"] = [[0, 1, 3], [0, 1, 2, 3]]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_normal_1d_exhaustive_testing(test_case, *arg)\n\n    @globaltest\n    def test_eager_boxing_symmetric_2d_exhaustive_testing(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_symmetric_2d_exhaustive_testing(test_case, *arg)\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestEagerBoxingSpecialSplitAxisExhaustiveTesting(flow.unittest.TestCase):\n    @globaltest\n    def test_eager_boxing_1d_special_split_axis(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"in_device_list\"] = [[0, 1], [1, 2, 3], [0, 1, 2, 3]]\n        arg_dict[\"out_device_list\"] = [[0, 1, 3], [0, 1, 2, 3]]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_1d_special_split_axis(test_case, *arg)\n\n    @globaltest\n    def test_eager_boxing_2d_special_split_axis(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"in_device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"out_device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_eager_boxing_2d_special_split_axis(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_empty.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\n\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_local_empty(test_case, shape, dtype, device, requires_grad):\n    x = flow.empty(\n        shape,\n        dtype=dtype,\n        device=flow.device(device),\n        requires_grad=requires_grad if dtype == flow.float32 else False,\n    )\n    test_case.assertFalse(x.is_global)\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.dtype, dtype)\n    test_case.assertEqual(x.device, flow.device(device))\n    if dtype == flow.float32:\n        test_case.assertEqual(x.requires_grad, requires_grad)\n    empty_like_x = flow.empty_like(\n        x,\n        dtype=dtype,\n        device=flow.device(device),\n        requires_grad=requires_grad if dtype == flow.float32 else False,\n    )\n    test_case.assertFalse(empty_like_x.is_global)\n    test_case.assertEqual(empty_like_x.shape, flow.Size(shape))\n    test_case.assertEqual(empty_like_x.dtype, dtype)\n    test_case.assertEqual(empty_like_x.device, flow.device(device))\n    if dtype == flow.float32:\n        test_case.assertEqual(empty_like_x.requires_grad, requires_grad)\n\n\ndef _test_new_empty(test_case, shape, dtype, device, requires_grad):\n    x = flow.empty(shape, dtype=dtype, device=flow.device(device))\n    y = x.new_empty(\n        shape,\n        dtype=dtype,\n        device=flow.device(device),\n        requires_grad=requires_grad if dtype == flow.float32 else False,\n    )\n    test_case.assertFalse(y.is_global)\n    test_case.assertEqual(y.shape, flow.Size(shape))\n    test_case.assertEqual(y.dtype, dtype)\n    test_case.assertEqual(y.device, flow.device(device))\n    if dtype == flow.float32:\n        test_case.assertEqual(y.requires_grad, requires_grad)\n\n    y = x.new_empty(*shape)\n    test_case.assertFalse(y.is_global)\n    test_case.assertEqual(y.shape, flow.Size(shape))\n    test_case.assertEqual(y.dtype, x.dtype)\n    test_case.assertEqual(y.device, x.device)\n    test_case.assertFalse(y.requires_grad)\n\n\ndef _test_local_empty_strided(test_case, shape, stride, dtype, device, requires_grad):\n    x = flow.empty_strided(\n        shape,\n        stride,\n        dtype=dtype,\n        device=flow.device(device),\n        requires_grad=requires_grad,\n    )\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.stride(), stride)\n    test_case.assertEqual(x.dtype, dtype)\n    test_case.assertEqual(x.device, flow.device(device))\n    if dtype == flow.float32:\n        test_case.assertEqual(x.requires_grad, requires_grad)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestEmptyOp(flow.unittest.TestCase):\n    def test_local_empty(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]\n        arg_dict[\"dtype\"] = [flow.float32, flow.float16, 
flow.int32]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"requires_grad\"] = [True, False]\n        for arg in GenArgDict(arg_dict):\n            _test_local_empty(test_case, **arg)\n            _test_new_empty(test_case, **arg)\n\n    def test_local_empty_strided(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 6), (2, 3, 12, 4)]\n        arg_dict[\"stride\"] = [(1, 2), (1, 2, 3), (2, 4, 5, 1)]\n        arg_dict[\"dtype\"] = [flow.float32, flow.float16, flow.int32]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"requires_grad\"] = [True, False]\n        for arg in GenArgDict(arg_dict):\n            _test_local_empty_strided(test_case, **arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_eq.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestEq(flow.unittest.TestCase):\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_eq_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(3, 2, 0, 3).to(device)\n        y = random_tensor(3, 2, 0, 3).to(device)\n        z = torch.eq(x, y)\n        return z\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_eq_with_0shape_0d_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = random_tensor(ndim=0).to(device)\n        z = torch.eq(x, y)\n        return z\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_eq_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        y = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        return torch.eq(x, oneof(y, random().to(int), random().to(float)))\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_tensor_eq_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        y = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        return x.eq(oneof(y, random().to(int), random().to(float)))\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_eq_with_random_0d_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x = random_tensor(ndim=0, requires_grad=False).to(device)\n        y = random_tensor(ndim=0, requires_grad=False).to(device)\n        return torch.eq(x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_eq_with_same_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        return torch.eq(x, x)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_eq_bool_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x = random_tensor(len(shape), *shape, requires_grad=False).to(\n            device=device, dtype=torch.bool\n        )\n        y = random_tensor(len(shape), *shape, requires_grad=False).to(\n            device=device, dtype=torch.bool\n        )\n        return torch.eq(x, y)\n\n    @autotest(n=5, 
auto_backward=False, check_graph=True)\n    def test_flow_eq_with_same_random_0d_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0, requires_grad=False).to(device)\n        return torch.eq(x, x)\n\n    @profile(torch.eq)\n    def profile_eq(test_case):\n        input1 = torch.ones(1000, 1280)\n        input2 = torch.ones(1000, 1280)\n        torch.eq(input1, input2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_equal.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch as torch_original\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestEqual(flow.unittest.TestCase):\n    @autotest(n=5, auto_backward=False, check_graph=False, include_complex=True)\n    def test_eq_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(3, 2, 0, 3).to(device)\n        y = random_tensor(3, 2, 0, 3).to(device)\n        z = torch.equal(x, y)\n        return z\n\n    @autotest(n=5, auto_backward=False, check_graph=False)\n    def test_equal_with_0shape_0d_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = random_tensor(ndim=0).to(device)\n        z = torch.equal(x, y)\n        return z\n\n    @autotest(n=5, auto_backward=False, check_graph=False)\n    def test_flow_equal_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        y = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        return torch.equal(x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=False)\n    def test_flow_tensor_equal_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        y = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        return x.equal(y)\n\n    @autotest(n=5, auto_backward=False, check_graph=False)\n    def test_flow_equal_with_random_0d_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x = random_tensor(ndim=0, requires_grad=False).to(device)\n        y = random_tensor(ndim=0, requires_grad=False).to(device)\n        return torch.equal(x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=False)\n    def test_flow_equal_with_same_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        return torch.equal(x, x)\n\n    @autotest(n=5, auto_backward=False, check_graph=False, include_complex=True)\n    def test_flow_equal_complex_with_same_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to(\n            device\n        )\n        return torch.equal(x, x)\n\n    @autotest(n=5, auto_backward=False, check_graph=False)\n    def test_flow_equal_bool_with_random_data(test_case):\n       
 device = random_device()\n        shape = random_tensor().oneflow.shape\n        x = random_tensor(len(shape), *shape, requires_grad=False).to(\n            device=device, dtype=torch.bool\n        )\n        y = random_tensor(len(shape), *shape, requires_grad=False).to(\n            device=device, dtype=torch.bool\n        )\n        return torch.equal(x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=False, include_complex=True)\n    def test_flow_equal_complex_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to(\n            device=device\n        )\n        y = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to(\n            device=device\n        )\n        return torch.equal(x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=False, include_complex=True)\n    def test_flow_not_equal_complex_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to(\n            device=device\n        )\n        y = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to(\n            device=device\n        )\n        return torch.not_equal(x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=False)\n    def test_flow_equal_with_same_random_0d_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0, requires_grad=False).to(device)\n        return torch.equal(x, x)\n\n    @profile(torch.equal)\n    def profile_equal(test_case):\n        input1 = torch.ones(1000, 1280)\n        input2 = torch.ones(1000, 1280)\n        torch.equal(input1, input2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_erf.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom scipy import special\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestErfModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_erf_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.erf(x)\n        return y\n\n    @autotest(n=5)\n    def test_flow_erf_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.erf(x)\n        return y\n\n    @profile(torch.erf)\n    def profile_erf(test_case):\n        torch.erf(torch.ones(100000))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_erfc.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom scipy import special\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestErfcModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_erfc_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.erfc(x)\n        return y\n\n    @autotest(n=5)\n    def test_flow_erfc_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.erfc(x)\n        return y\n\n    @profile(torch.erfc)\n    def profile_erfc(test_case):\n        torch.erfc(torch.ones(100000))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_erfinv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom scipy import special\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_flow_erfinv_with_inf_data(test_case, device):\n    x = flow.tensor(np.ones((5, 5)), dtype=flow.float32, device=flow.device(device))\n    of_out = flow.erfinv(x)\n    np_out = np.full((5, 5), np.inf)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_flow_erfinv_with_nan_data(test_case, device):\n    x = flow.tensor(\n        np.arange(2, 22).reshape(4, 5), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.erfinv(x)\n    np_out = np.full((4, 5), np.nan)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out, equal_nan=True))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestErfinvModule(flow.unittest.TestCase):\n    def test_flow_erfinv(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_flow_erfinv_with_inf_data,\n            _test_flow_erfinv_with_nan_data,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(check_graph=True, auto_backward=False)\n    def test_flow_erfinv_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(requires_grad=False).to(device)\n        y = torch.erfinv(x)\n        return y\n\n    @profile(torch.erfinv)\n    def profile_erfinv(test_case):\n        torch.erfinv(torch.ones(100000))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_expand.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow.test_utils.test_util import GenArgList\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _np_get_expand(input_shape, expand_size):\n    input = np.random.random(size=input_shape).astype(np.float32)\n\n    input_stride = [1]\n    for i in range(len(input_shape) - 2, -1, -1):\n        input_stride.insert(0, input_stride[0] * input_shape[i + 1])\n    # calculate the output shape and stride\n    new_size = []\n    new_stride = []\n    diff = len(expand_size) - len(input_shape)\n    for i in range(len(expand_size) - 1, -1, -1):\n        if i >= diff:\n            if expand_size[i] == -1 or expand_size[i] == input_shape[i - diff]:\n                new_size.insert(0, input_shape[i - diff])\n                new_stride.insert(0, input_stride[i - diff])\n            else:\n                assert expand_size[i] >= 1 and input_shape[i - diff] == 1\n                new_size.insert(0, expand_size[i])\n                new_stride.insert(0, 0)\n        else:\n            assert expand_size[i] >= 1\n            new_size.insert(0, expand_size[i])\n            if expand_size[i] == 1:\n                new_stride.insert(0, new_stride[0])\n            else:\n                new_stride.insert(0, 0)\n\n    gout = np.random.random(size=tuple(new_size)).astype(np.float32)\n\n    out_stride = [1]\n    for i in range(len(new_size) - 2, -1, -1):\n        out_stride.insert(0, out_stride[0] * new_size[i + 1])\n\n    gin = np.zeros(input_shape).flatten()\n    out = np.zeros(np.product(new_size))\n\n    def getOffset(i_offset, stride, expand_stride, n):\n        remain = i_offset\n        o_offset = 0\n        for i in range(n):\n            idx = int(remain / stride[i])\n            o_offset += idx * expand_stride[i]\n            remain = remain - idx * stride[i]\n        return o_offset\n\n    in_flatten = input.flatten()\n    gout_flatten = gout.flatten()\n    num_elem = np.product(new_size)\n    dims = len(new_size)\n\n    for i in range(num_elem):\n        offset = getOffset(i, out_stride, new_stride, dims)\n        gin[offset] += gout_flatten[i]\n        out[i] = in_flatten[offset]\n\n    return input, gout, out.reshape(tuple(new_size)), gin.reshape(input_shape)\n\n\ndef _test_expand_new_dims(test_case, device):\n    input_shape = (1, 4, 1, 32)\n    expand_dim = [2, 1, 2, 4, 2, 32]\n    input, gout, out_np, gin_np = _np_get_expand(input_shape, expand_dim)\n    of_input = flow.tensor(\n        input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_out = of_input.expand(2, 1, 2, 4, 2, 32)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), out_np))\n\n\ndef _test_expand_same_dim(test_case, device):\n    input_shape = (2, 4, 1, 32)\n    expand_dim = [2, 4, 2, 32]\n    input, gout, out_np, gin_np = _np_get_expand(input_shape, 
expand_dim)\n    of_input = flow.tensor(input, dtype=flow.float32, device=flow.device(device))\n    of_out = of_input.expand(2, 4, 2, 32)\n\n    test_case.assertTrue(np.array_equal(of_out.numpy(), out_np))\n\n\ndef _test_expand_same_dim_negative(test_case, device):\n    input_shape = (1, 6, 5, 3)\n    expand_dim = [4, -1, 5, 3]\n    input, gout, out_np, gin_np = _np_get_expand(input_shape, expand_dim)\n    of_input = flow.tensor(input, dtype=flow.float32, device=flow.device(device))\n    of_out = of_input.expand(4, -1, 5, 3)\n\n    test_case.assertTrue(np.array_equal(of_out.numpy(), out_np))\n\n\ndef _test_expand_same_int(test_case, device):\n    input_shape = (2, 4, 1, 32)\n    expand_dim = [2, 4, 2, 32]\n    input, gout, out_np, gin_np = _np_get_expand(input_shape, expand_dim)\n    of_input = flow.tensor(input, dtype=flow.int, device=flow.device(device))\n    of_out = of_input.expand(2, 4, 2, 32)\n\n    test_case.assertTrue(np.array_equal(of_out.numpy(), out_np.astype(np.int32)))\n\n\ndef _test_expand_flow_size(test_case, device):\n    input_shape = (2, 4, 1, 32)\n    expand_dim = flow.Size([2, 4, 2, 32])\n    input, gout, out_np, gin_np = _np_get_expand(input_shape, expand_dim)\n    of_input = flow.tensor(input, dtype=flow.int, device=flow.device(device))\n    of_out = of_input.expand(expand_dim)\n\n    test_case.assertTrue(np.array_equal(of_out.numpy(), out_np.astype(np.int32)))\n\n\ndef _test_expand_backward_same_dim(test_case, device):\n    input = np.array(\n        [\n            [\n                [[0.9876952171325684]],\n                [[0.8772538304328918]],\n                [[0.9200366735458374]],\n                [[0.2810221314430237]],\n            ],\n            [\n                [[0.3037724494934082]],\n                [[0.7783719897270203]],\n                [[0.08884672075510025]],\n                [[0.17156553268432617]],\n            ],\n        ]\n    )\n    of_input = flow.tensor(\n        input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_out = of_input.expand(2, 4, 2, 1)\n    of_out.sum().backward()\n    np_grad = [\n        [[[2.0]], [[2.0]], [[2.0]], [[2.0]]],\n        [[[2.0]], [[2.0]], [[2.0]], [[2.0]]],\n    ]\n    test_case.assertTrue(np.array_equal(of_input.grad.numpy(), np_grad))\n\n\ndef _test_expand_backward(test_case, device):\n    input = np.array(\n        [\n            [\n                [[0.8981702327728271, 0.5372866988182068]],\n                [[0.45116370916366577, 0.8656941056251526]],\n                [[0.8811476230621338, 0.5552017688751221]],\n                [[0.6291894316673279, 0.5786571502685547]],\n            ]\n        ]\n    )\n    of_input = flow.tensor(\n        input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_out = of_input.expand(2, 1, 2, 4, 2, 2)\n    of_out.sum().backward()\n    np_grad = [[[[8.0, 8.0]], [[8.0, 8.0]], [[8.0, 8.0]], [[8.0, 8.0]]]]\n    test_case.assertTrue(np.array_equal(of_input.grad.numpy(), np_grad))\n\n\ndef random_expand(x, ndim, expand_size):\n    dim_size = [1,] * ndim\n    random_index = random(0, ndim).to(int).value()\n    dim_size[random_index] = expand_size\n    return x.expand(*dim_size)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestExpand(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_flow_tensor_expand_with_random_data(test_case):\n        random_expand_size = random(1, 6).to(int).value()\n        x = random_tensor(ndim=5, dim0=1, dim1=1, dim2=1, dim3=1, dim4=1)\n        return 
random_expand(x, ndim=5, expand_size=random_expand_size)\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_flow_tensor_expand_bool_with_random_data(test_case):\n        random_expand_size = random(1, 6).to(int).value()\n        x = random_tensor(ndim=5, dim0=1, dim1=1, dim2=1, dim3=1, dim4=1).to(torch.bool)\n        return random_expand(x, ndim=5, expand_size=random_expand_size)\n\n    def test_expand_compare_with_numpy(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_expand_new_dims,\n            _test_expand_same_dim,\n            _test_expand_same_dim_negative,\n            _test_expand_same_int,\n            _test_expand_flow_size,\n            _test_expand_backward,\n            _test_expand_backward_same_dim,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5, auto_backward=False)\n    def test_flow_expand_with_0_size(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim1=1).to(device)\n        return x.expand([0, 3])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_expand_stride.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\n\n\nimport oneflow as flow\nimport torch\n\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _cmp_expand_stride(\n    test_case, input_shape, expand_shape, device=\"cuda\", verbose=False,\n):\n    input = np.random.randn(*input_shape)\n    torch_x = torch.tensor(input, dtype=torch.float32, device=device)\n    torch_y = torch_x.expand(*expand_shape)\n\n    x = flow.tensor(input, dtype=flow.float32, device=device)\n    y = x.expand(*expand_shape)\n\n    if verbose:\n        print(\"\")\n        print(f\" eager (view::Expand) (device={device}) \".center(50, \"=\"))\n        print(f\" {input_shape} -> {expand_shape} \".center(50, \"*\"))\n        print(f\"x: shape={x.shape}, stride={x.stride()}\")\n        print(f\"y: shape={y.shape}, stride={y.stride()}\")\n        print(f\"torch_y: shape={torch_y.shape}, stride={torch_y.stride()}\")\n        print(\" input \".center(50, \"-\"))\n        print(input)\n        print(\" y \".center(50, \"-\"))\n        print(y)\n        print(\" torch_y \".center(50, \"-\"))\n        print(torch_y)\n\n    test_case.assertTrue(np.array_equal(y.stride(), torch_y.stride()))\n    test_case.assertTrue(np.array_equal(y.numpy(), torch_y.detach().cpu().numpy()))\n\n\ndef _cmp_expand_non_contiguous_stride(\n    test_case, input_shape, perm, expand_shape, device=\"cuda\", verbose=False,\n):\n    input = np.random.randn(*input_shape).astype(np.float32)\n    x = flow.tensor(input, device=device)\n    y = x.permute(*perm)\n    z = y.expand(*expand_shape)\n\n    torch_x = torch.tensor(input, device=device)\n    torch_y = torch_x.permute(*perm)\n    torch_z = torch_y.expand(*expand_shape)\n\n    if verbose:\n        print(\"\")\n        print(f\" non_contiguous (device={device}) \".center(50, \"-\"))\n        print(f\" {input_shape}, {perm} -> {expand_shape} \".center(50, \"-\"))\n        print(f\"x: shape={x.shape}, stride={x.stride()}\")\n        print(f\"y: shape={y.shape}, stride={y.stride()}\")\n        print(f\"z: shape={z.shape}, stride={z.stride()}\")\n        print(f\"torch_y: shape={torch_y.shape}, stride={torch_y.stride()}\")\n        print(f\"torch_z: shape={torch_z.shape}, stride={torch_z.stride()}\")\n        print(\" input \".center(50, \"-\"))\n        print(input)\n        print(\" z \".center(50, \"-\"))\n        print(z)\n        print(\" torch_z \".center(50, \"-\"))\n        print(torch_z)\n\n    test_case.assertTrue(np.array_equal(z.stride(), torch_z.stride()))\n    test_case.assertTrue(np.array_equal(z.numpy(), torch_z.detach().cpu().numpy()))\n\n\ndef _cmp_lazy_expand_stride(\n    test_case, input_shape, expand_shape, device=\"cuda\", verbose=False,\n):\n    input = np.random.randn(*input_shape)\n    torch_x = torch.tensor(input, dtype=torch.float32, device=device)\n    torch_y = torch_x.expand(*expand_shape).contiguous()\n    # oneflow 
lazy must do this contiguous\n\n    class MyGraph(flow.nn.Graph):\n        def __init__(self, expand_shape):\n            super().__init__()\n            self.expand_shape = expand_shape\n\n        def build(self, x):\n            return x.expand(*self.expand_shape)\n\n    expand_graph = MyGraph(expand_shape)\n    x = flow.tensor(input, dtype=flow.float32, device=device)\n    y = expand_graph(x)\n\n    squeeze_y_stride = []\n    for d, s in zip(y.shape, y.stride()):\n        if d != 1:\n            squeeze_y_stride.append(s)\n\n    squeeze_torch_y_stride = []\n    for d, s in zip(torch_y.shape, torch_y.stride()):\n        if d != 1:\n            squeeze_torch_y_stride.append(s)\n\n    if verbose:\n        print(\"\")\n        print(f\" lazy (expand op/kernel) (device={device}) \".center(50, \"=\"))\n        print(f\" {input_shape} -> {expand_shape} \".center(50, \"*\"))\n        print(f\"x: shape={x.shape}, stride={x.stride()}\")\n        print(f\"y: shape={y.shape}, stride={y.stride()}\")\n        print(f\"torch_y: shape={torch_y.shape}, stride={torch_y.stride()}\")\n        print(f\"squeeze_y_stride={squeeze_y_stride}\")\n        print(f\"squeeze_torch_y_stride={squeeze_torch_y_stride}\")\n        print(\" input \".center(50, \"-\"))\n        print(input)\n        print(\" y \".center(50, \"-\"))\n        print(y)\n        print(\" torch_y \".center(50, \"-\"))\n        print(torch_y)\n\n    test_case.assertTrue(np.array_equal(squeeze_y_stride, squeeze_torch_y_stride))\n    test_case.assertTrue(np.array_equal(y.numpy(), torch_y.detach().cpu().numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass ExpandStrideTestCase(flow.unittest.TestCase):\n    test_shape_tuple_list = [\n        ((1, 2), (2, 2)),\n        ((1, 2), (1, 1, 2)),\n        ((1, 2), (1, 2, 2)),\n        ((1, 2), (2, 1, 2)),\n        ((1, 2), (2, 2, 2)),\n        ((1, 2), (1, 1, 1, 2)),\n        ((1, 2), (1, 2, 1, 2)),\n        ((1, 2), (2, 1, 1, 2)),\n        ((1, 2), (2, 2, 1, 2)),\n        ((1, 2), (2, 2, 2, 2)),\n        ((2, 1), (2, 2)),\n        ((2, 1), (1, 2, 1)),\n        ((2, 1), (1, 2, 2)),\n        ((2, 1), (2, 2, 1)),\n        ((2, 1), (2, 2, 2)),\n        ((2, 1), (1, 1, 2, 1)),\n        ((2, 1), (1, 2, 2, 1)),\n        ((2, 1), (2, 2, 2, 1)),\n        ((2, 1), (2, 2, 2, 2)),\n        ((2, 2), (1, 2, 2)),\n        ((2, 2), (2, 2, 2)),\n        ((2, 2), (1, 1, 2, 2)),\n        ((2, 2), (1, 2, 2, 2)),\n        ((2, 2), (2, 1, 2, 2)),\n        ((2, 2), (2, 2, 2, 2)),\n        ((2, 1, 4), (2, 2, 2, 4)),\n        ((2, 1, 3), (2, 1, -1, -1, -1)),\n        ((2, 1, 3), (1, 2, -1, -1, -1)),\n        ((2, 1, 3), (2, 2, -1, -1, -1)),\n        ((2, 1, 3), (2, 1, -1, 2, 3)),\n        ((2, 1, 3), (1, 2, 2, 2, -1)),\n        ((2, 1, 3), (2, 2, 2, 2, 3)),\n        ((2, 3, 4), (1, 2, -1, -1, -1)),\n        ((2, 3, 4), (2, 1, -1, -1, -1)),\n        ((2, 3, 4), (2, 2, -1, -1, -1)),\n        ((), (1,)),\n        ((), (2,)),\n        ((), (1, 2)),\n        ((), (2, 1)),\n        ((), (2, 2)),\n    ]\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_on_cpu(self):\n        arg_dict = OrderedDict()\n        arg_dict[\"verbose\"] = [False]\n        arg_dict[\"device\"] = [\"cpu\"]\n        arg_dict[\"shapes\"] = self.test_shape_tuple_list\n        for kwargs in GenArgDict(arg_dict):\n            assert \"shapes\" in kwargs\n            input_shape, expand_shape = kwargs.pop(\"shapes\")\n            _cmp_expand_stride(self, input_shape, expand_shape, **kwargs)\n\n    def 
test_stride(self):\n        arg_dict = OrderedDict()\n        arg_dict[\"verbose\"] = [False]\n        arg_dict[\"device\"] = [\"cuda\"]\n        arg_dict[\"shapes\"] = self.test_shape_tuple_list\n        for kwargs in GenArgDict(arg_dict):\n            assert \"shapes\" in kwargs\n            input_shape, expand_shape = kwargs.pop(\"shapes\")\n            _cmp_expand_stride(self, input_shape, expand_shape, **kwargs)\n\n    def test_non_contiguous_stride(self):\n        arg_dict = OrderedDict()\n        arg_dict[\"verbose\"] = [False]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shapes\"] = [\n            ((2, 1, 3), (0, 2, 1), (1, 2, -1, -1, -1)),\n            ((2, 1, 3), (0, 2, 1), (2, 1, -1, -1, -1)),\n            ((2, 1, 3), (0, 2, 1), (2, 3, -1, -1, -1)),\n            ((2, 1, 3), (0, 2, 1), (1, 2, -1, -1, 2)),\n            ((2, 1, 3), (0, 2, 1), (2, 1, -1, -1, 2)),\n            ((2, 1, 3), (0, 2, 1), (2, 3, -1, -1, 2)),\n            ((2, 3, 4), (0, 2, 1), (1, 2, -1, -1, -1)),\n            ((2, 3, 4), (0, 2, 1), (2, 1, -1, -1, -1)),\n            ((2, 3, 4), (0, 2, 1), (2, 2, -1, -1, -1)),\n        ]\n        for kwargs in GenArgDict(arg_dict):\n            assert \"shapes\" in kwargs\n            input_shape, perm, expand_shape = kwargs.pop(\"shapes\")\n            _cmp_expand_non_contiguous_stride(\n                self, input_shape, perm, expand_shape, **kwargs\n            )\n\n    def test_lazy(self):\n        arg_dict = OrderedDict()\n        arg_dict[\"verbose\"] = [False]\n        arg_dict[\"device\"] = [\"cuda\"]\n        arg_dict[\"shapes\"] = self.test_shape_tuple_list\n        for kwargs in GenArgDict(arg_dict):\n            assert \"shapes\" in kwargs\n            input_shape, expand_shape = kwargs.pop(\"shapes\")\n            _cmp_lazy_expand_stride(self, input_shape, expand_shape, **kwargs)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_expm1.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_expm1_impl(test_case, device, shape):\n    x = flow.tensor(\n        np.random.randn(*shape),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out = flow.expm1(x)\n    np_out = np.expm1(x.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    of_out = of_out.sum()\n    of_out.backward()\n    test_case.assertTrue(np.allclose(x.grad.numpy(), np.exp(x.numpy()), 0.0001, 0.0001))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestExpm1Module(flow.unittest.TestCase):\n    def test_expm1(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_expm1_impl]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(1,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5, check_graph=True)\n    def test_expm1_flow_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor().to(device)\n        y = torch.expm1(input)\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_expm1_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 1, 0, 3).to(device)\n        y = torch.expm1(x)\n        return y\n\n    @autotest(n=5, check_graph=True)\n    def test_expm1_flow_with_0dim_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=0).to(device)\n        y = torch.expm1(input)\n        return y\n\n    @profile(torch.expm1)\n    def profile_expm1(test_case):\n        torch.expm1(torch.ones(100000))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_eye.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\n\n\ndef _test_eye_forward(test_case, device, n, m):\n    output = flow.eye(n, m, device=device)\n    np_out = np.eye(n, m)\n    test_case.assertTrue(np.array_equal(output.numpy(), np_out))\n\n\ndef _test_eye_backward(test_case, device, n, m):\n    x = flow.eye(n, m, device=device)\n    x.requires_grad = True\n    y = x.sum()\n    y.backward()\n    test_case.assertTrue(np.array_equal(x.grad.numpy(), np.ones([n, m])))\n\n\ndef _test_eye_with_1n2d(test_case, n, m, device):\n    placement = flow.placement(device, range(2))\n    x = flow.eye(n, m, placement=placement, sbp=flow.sbp.broadcast)\n    test_case.assertTrue(x.placement, placement)\n    test_case.assertTrue(x.sbp, flow.sbp.broadcast)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestEye(flow.unittest.TestCase):\n    def test_eye(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_eye_forward,\n            _test_eye_backward,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"n\"] = [4, 3, 2]\n        arg_dict[\"m\"] = [4, 3, 2]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(check_graph=True)\n    def test_eye_with_random_data(test_case):\n        n = random(low=1, high=5).to(int)\n        m = random(low=1, high=5).to(int)\n        x = torch.eye(n=n, m=m, device=random_device())\n        x.oneflow.requires_grad = True\n        x.pytorch.requires_grad = True\n        return x\n\n    @autotest(check_graph=True, auto_backward=False)\n    def test_eye_with_random_data(test_case):\n        n = random(low=0, high=1).to(int)\n        m = random(low=0, high=2).to(int)\n        x = torch.eye(n=n, m=m, device=random_device())\n        return x\n\n    @autotest(check_graph=True)\n    def test_eye_bool_with_random_data(test_case):\n        n = random().to(int)\n        m = random().to(int)\n        x = torch.eye(n=n, m=m)\n        device = random_device()\n        x.to(device=device, dtype=torch.bool)\n        x = random_tensor().to(device)\n        return x\n\n    @autotest(check_graph=True, auto_backward=False)\n    def test_eye_with_0dim_data(test_case):\n        n = random().to(int)\n        m = random().to(int)\n        x = torch.eye(n=n, m=m)\n        device = random_device()\n        x.to(device)\n        x = random_tensor(ndim=0).to(device)\n        return x\n\n    @profile(torch.eye)\n    def profile_eye(test_case):\n        torch.eye(1000)\n        torch.eye(100, 1280)\n\n\n@flow.unittest.skip_unless_1n2d()\nclass TestGlobalEye(flow.unittest.TestCase):\n    def test_eye_with_1n2d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_eye_with_1n2d]\n 
       arg_dict[\"n\"] = [4, 3, 2]\n        arg_dict[\"m\"] = [4, 3, 2]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fake_quantization.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport math\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.test_util import (\n    GenArgList,\n    type_name_to_flow_type,\n    type_name_to_np_type,\n)\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef gen_quant_scale_for_min_max_symmetric(weight, quantization_bit):\n    weight_max = np.max(np.abs(weight))\n    denominator = 2.0 ** (quantization_bit - 1) - 1\n    return (weight_max / denominator, 0)\n\n\ndef gen_quant_scale_for_min_max_affine(weight, quantization_bit):\n    weight_max = np.max(weight)\n    weight_min = np.min(weight)\n    denominator = 2.0 ** quantization_bit - 1\n    scale = (weight_max - weight_min) / denominator\n    zero_point = -np.round(weight_min / scale)\n    return (scale, zero_point)\n\n\ndef gen_quant_scale_for_min_max_cambricon(weight, quantization_bit):\n    weight_max = np.max(np.abs(weight))\n    scale = math.floor(math.log2(weight_max)) - (quantization_bit - 2)\n    return (scale, 0)\n\n\ndef product(tu):\n    return np.prod(tu).astype(np.int32).item()\n\n\ndef fake_quant_per_layer_symmetric(input, quantization_bit, scale):\n    upper_bound = 2.0 ** (quantization_bit - 1) - 1\n    lower_bound = -upper_bound\n    return np.clip(np.rint(input / scale), lower_bound, upper_bound) * scale\n\n\ndef fake_quant_per_layer_affine(input, quantization_bit, scale, zero_point):\n    upper_bound = 2.0 ** quantization_bit - 1\n    lower_bound = 0\n    return (\n        np.clip(np.rint(input / scale + zero_point), lower_bound, upper_bound)\n        - zero_point\n    ) * scale\n\n\ndef fake_quant_per_layer_cambricon(input, quantization_bit, shift):\n    upper_bound = 2.0 ** (quantization_bit - 1) - 1\n    lower_bound = -upper_bound\n    scale = 2 ** shift\n    return np.clip(np.rint(input / scale), lower_bound, upper_bound) * scale\n\n\ndef _check_fake_quantize(\n    test_case,\n    input,\n    input_diff_of,\n    out_of,\n    quantization_bit,\n    quantization_scheme,\n    quantization_formula,\n    per_layer_quantization,\n):\n    if per_layer_quantization or quantization_formula == \"cambricon\":\n        outer_num = 1\n        inner_num = product(input.shape[0:])\n    else:\n        outer_num = input.shape[0]\n        inner_num = product(input.shape[1:])\n    scale_np = np.zeros((outer_num,))\n    zero_point_np = np.zeros((outer_num,))\n    out_np = np.zeros((inner_num * outer_num,))\n    input_flatten = input.flatten()\n    input_diff_np = np.full((inner_num * outer_num,), 1.0 / (inner_num * outer_num))\n    if quantization_formula == \"google\":\n        if quantization_scheme == \"symmetric\":\n            for c in range(outer_num):\n                (scale_np[c], zero_point_np[c]) = gen_quant_scale_for_min_max_symmetric(\n                    input_flatten[c * inner_num : (c 
+ 1) * inner_num], quantization_bit\n                )\n                out = fake_quant_per_layer_symmetric(\n                    input_flatten[c * inner_num : (c + 1) * inner_num],\n                    quantization_bit,\n                    scale_np[c],\n                )\n                out_np[c * inner_num : (c + 1) * inner_num] = out\n        else:\n            for c in range(outer_num):\n                (scale_np[c], zero_point_np[c]) = gen_quant_scale_for_min_max_affine(\n                    input_flatten[c * inner_num : (c + 1) * inner_num], quantization_bit\n                )\n                out = fake_quant_per_layer_affine(\n                    input_flatten[c * inner_num : (c + 1) * inner_num],\n                    quantization_bit,\n                    scale_np[c],\n                    zero_point_np[c],\n                )\n                out_np[c * inner_num : (c + 1) * inner_num] = out\n    else:\n        (scale_np[0], zero_point_np[0]) = gen_quant_scale_for_min_max_cambricon(\n            input_flatten, quantization_bit\n        )\n        out_np = fake_quant_per_layer_cambricon(\n            input_flatten, quantization_bit, scale_np[0]\n        )\n    rmse = np.sqrt(np.mean((out_of - out_np) ** 2))\n    assert rmse <= 1.0, \"fake_quantization op has bug!\"\n    test_case.assertTrue(np.allclose(input_diff_of, input_diff_np, rtol=0.001))\n\n\ndef _run_test_fake_quantize(\n    test_case,\n    device_type,\n    dtype,\n    in_shape,\n    quantization_bit,\n    quantization_scheme,\n    quantization_formula,\n    per_layer_quantization,\n):\n    input = (np.random.random(in_shape) - 0.5).astype(type_name_to_np_type[dtype])\n    input_tensor = flow.tensor(\n        input, dtype=flow.float32, requires_grad=True, device=flow.device(device_type)\n    )\n    min_max_observer = flow.nn.MinMaxObserver(\n        quantization_formula=quantization_formula,\n        quantization_bit=quantization_bit,\n        quantization_scheme=quantization_scheme,\n        per_layer_quantization=per_layer_quantization,\n    )\n    (scale, zero_point) = min_max_observer(input_tensor)\n    fake_quantization = flow.nn.FakeQuantization(\n        quantization_formula=quantization_formula,\n        quantization_bit=quantization_bit,\n        quantization_scheme=quantization_scheme,\n    )\n    output_tensor = fake_quantization(input_tensor, scale, zero_point)\n    y = output_tensor.mean()\n    y.backward()\n\n    out = output_tensor.numpy()\n    input_diff = input_tensor.grad.numpy()\n    _check_fake_quantize(\n        test_case,\n        input,\n        input_diff.flatten(),\n        out.flatten(),\n        quantization_bit,\n        quantization_scheme,\n        quantization_formula,\n        per_layer_quantization,\n    )\n\n\nclass TestFakeQuantize(flow.unittest.TestCase):\n    def test_fake_quantize(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_case\"] = [test_case]\n        arg_dict[\"device_type\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"dtype\"] = [\"float32\", \"double\"]\n        arg_dict[\"in_shape\"] = [(9, 40, 20, 10)]\n        arg_dict[\"quantization_bit\"] = [8, 2]\n        arg_dict[\"quantization_scheme\"] = [\"symmetric\", \"affine\"]\n        arg_dict[\"quantization_formula\"] = [\"google\"]\n        arg_dict[\"per_layer_quantization\"] = [True, False]\n        for arg in GenArgList(arg_dict):\n            # the cambricon formula only supports per-layer quantization\n            if arg[-2] == \"cambricon\" and not arg[-1]:\n                continue\n            _run_test_fake_quantize(*arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fft.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch as torch_original\nfrom packaging import version\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\nfrom oneflow.test_utils.automated_test_util import *\nimport os\n\n\ndef is_cufft_available():\n    if flow.cuda.is_available():\n        (major, _minor) = flow.cuda.get_device_capability()\n        return major >= 7\n    else:\n        return False\n\n\ndef is_complex_dtype(dtype):\n    if hasattr(dtype, \"pytorch\") and hasattr(dtype, \"oneflow\"):\n        # is DualObject\n        return dtype.pytorch.is_complex\n    else:\n        return dtype in [\n            flow.complex64,\n            flow.complex128,\n            torch_original.complex64,\n            torch_original.complex128,\n            torch.pytorch.complex64,\n            torch.pytorch.complex128,\n        ]\n\n\ndef gen_params_1d_fft(lower_n_dims=1, upper_n_dims=5):\n    num_dims = np.random.randint(lower_n_dims, upper_n_dims)\n    shape = [np.random.randint(1, 5) * 2 for _ in range(num_dims)]\n\n    if np.random.randint(2) == 1:\n        dim = np.random.randint(low=-num_dims, high=num_dims - 1)\n    else:\n        dim = -1\n\n    norm = np.random.choice([\"backward\", \"forward\", \"ortho\", None])\n\n    if np.random.randint(2) == 1:\n        n = None\n    else:\n        n = np.random.randint(low=1, high=shape[dim] * 2)\n\n    params = {\n        \"num_dims\": num_dims,\n        \"shape\": shape,\n        \"n\": n,\n        \"dim\": dim,\n        \"norm\": norm,\n    }\n    return params\n\n\ndef gen_params_2d_fft(lower_n_dims=2, upper_n_dims=5):\n    num_dims = np.random.randint(lower_n_dims, upper_n_dims)\n    shape = [np.random.randint(1, 5) * 2 for _ in range(num_dims)]\n    len_fft_dim = np.random.randint(low=1, high=3)\n\n    total_dims_range = np.arange(num_dims)\n    if np.random.randint(2) == 1:\n        dims = np.random.choice(\n            total_dims_range, size=len_fft_dim, replace=False\n        ).tolist()\n    else:\n        dims = (-2, -1)\n\n    norm = np.random.choice([\"backward\", \"forward\", \"ortho\", None])\n    len_fft_dim = len(dims)\n    if np.random.randint(2) == 1 and dims is not None:\n        n = []\n        for i in range(len_fft_dim):\n            n_ = (\n                np.random.randint(low=1, high=2 * shape[i])\n                if np.random.randint(2) == 1\n                else -1\n            )\n            n.append(n_)\n    else:\n        n = None\n\n    params = {\n        \"num_dims\": num_dims,\n        \"shape\": shape,\n        \"n\": n,\n        \"dim\": dims,\n        \"norm\": norm,\n    }\n    return params\n\n\ndef gen_params_nd_fft(lower_n_dims=2, upper_n_dims=5):\n    num_dims = np.random.randint(lower_n_dims, upper_n_dims)\n    shape = [np.random.randint(1, 5) * 2 for _ in range(num_dims)]\n    len_fft_dim = 
np.random.randint(low=1, high=num_dims + 1)\n\n    total_dims_range = np.arange(num_dims)\n    if np.random.randint(2) == 1:\n        dims = np.random.choice(\n            total_dims_range, size=len_fft_dim, replace=False\n        ).tolist()\n    else:\n        dims = None\n\n    norm = np.random.choice([\"backward\", \"forward\", \"ortho\", None])\n\n    if np.random.randint(2) == 1:\n        n = None\n    else:\n        n = []\n        len_fft_dim = (\n            len(dims)\n            if dims is not None\n            else np.random.randint(low=1, high=num_dims + 1)\n        )\n        for i in range(len_fft_dim):\n            n_ = (\n                np.random.randint(low=1, high=2 * shape[i])\n                if np.random.randint(2) == 1\n                else -1\n            )\n            n.append(n_)\n\n    params = {\n        \"num_dims\": num_dims,\n        \"shape\": shape,\n        \"n\": n,\n        \"dim\": dims,\n        \"norm\": norm,\n    }\n    return params\n\n\ndef _test_fft(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"1d\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"1d\"][\"upper_n_dims\"]\n    params = gen_params_1d_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape)\n    if is_complex_dtype(x.dtype):\n        # test fft_c2c\n        dtype = test_case.dtype_dict[\"complex\"]\n        x = x.to(device=device, dtype=dtype)\n    else:\n        # test fft_r2c\n        dtype = test_case.dtype_dict[\"real\"]\n        x = x.to(device=device, dtype=dtype)\n    y = torch.fft.fft(x, n, dim, norm)\n    return y\n\n\ndef _test_ifft(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"1d\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"1d\"][\"upper_n_dims\"]\n    params = gen_params_1d_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape)\n    if is_complex_dtype(x.dtype):\n        # test fft_c2c\n        dtype = test_case.dtype_dict[\"complex\"]\n        x = x.to(device=device, dtype=dtype)\n    else:\n        # test fft_r2c\n        dtype = test_case.dtype_dict[\"real\"]\n        x = x.to(device=device, dtype=dtype)\n\n    y = torch.fft.ifft(x, n, dim, norm)\n\n    return y\n\n\ndef _test_rfft(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"1d\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"1d\"][\"upper_n_dims\"]\n    params = gen_params_1d_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n\n    dtype = test_case.dtype_dict[\"real\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype)\n    y = torch.fft.rfft(x, n, dim, norm)\n\n    return y\n\n\ndef _test_irfft(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = 
cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"1d\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"1d\"][\"upper_n_dims\"]\n    params = gen_params_1d_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n    dtype = test_case.dtype_dict[\"complex\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype)\n    y = torch.fft.irfft(x, n, dim, norm)\n\n    return y\n\n\ndef _test_hfft(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"1d\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"1d\"][\"upper_n_dims\"]\n    params = gen_params_1d_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n    dtype = test_case.dtype_dict[\"complex\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype)\n    y = torch.fft.hfft(x, n, dim, norm)\n\n    return y\n\n\ndef _test_ihfft(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"1d\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"1d\"][\"upper_n_dims\"]\n    params = gen_params_1d_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n    dtype = test_case.dtype_dict[\"real\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype)\n    y = torch.fft.ihfft(x, n, dim, norm)\n\n    return y\n\n\ndef _test_fft2(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"2d\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"2d\"][\"upper_n_dims\"]\n    params = gen_params_2d_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape)\n    if is_complex_dtype(x.dtype):\n        # test fft_c2c\n        dtype = test_case.dtype_dict[\"complex\"]\n        x = x.to(device=device, dtype=dtype)\n    else:\n        # test fft_r2c\n        dtype = test_case.dtype_dict[\"real\"]\n        x = x.to(device=device, dtype=dtype)\n    y = torch.fft.fft2(x, n, dim, norm)\n\n    return y\n\n\ndef _test_ifft2(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"2d\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"2d\"][\"upper_n_dims\"]\n    params = gen_params_2d_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape)\n    if is_complex_dtype(x.dtype):\n        # test fft_c2c\n        dtype = test_case.dtype_dict[\"complex\"]\n        x = x.to(device=device, dtype=dtype)\n    else:\n        # test fft_r2c\n        dtype = 
test_case.dtype_dict[\"real\"]\n        x = x.to(device=device, dtype=dtype)\n\n    y = torch.fft.ifft2(x, n, dim, norm)\n\n    return y\n\n\ndef _test_rfft2(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"2d\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"2d\"][\"upper_n_dims\"]\n    params = gen_params_2d_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n\n    dtype = test_case.dtype_dict[\"real\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype)\n    y = torch.fft.rfft2(x, n, dim, norm)\n\n    return y\n\n\ndef _test_irfft2(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"2d\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"2d\"][\"upper_n_dims\"]\n    params = gen_params_2d_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n    dtype = test_case.dtype_dict[\"complex\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype)\n    y = torch.fft.irfft2(x, n, dim, norm)\n\n    return y\n\n\ndef _test_hfft2(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"2d\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"2d\"][\"upper_n_dims\"]\n    params = gen_params_2d_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n    dtype = test_case.dtype_dict[\"complex\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype)\n    y = torch.fft.hfft2(x, n, dim, norm)\n\n    return y\n\n\ndef _test_ihfft2(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"2d\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"2d\"][\"upper_n_dims\"]\n    params = gen_params_2d_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n    dtype = test_case.dtype_dict[\"real\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype)\n    y = torch.fft.ihfft2(x, n, dim, norm)\n\n    return y\n\n\ndef _test_fftn(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"nd\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"nd\"][\"upper_n_dims\"]\n    params = gen_params_nd_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape)\n    if is_complex_dtype(x.dtype):\n        # test fft_c2c\n        dtype = test_case.dtype_dict[\"complex\"]\n        x = x.to(device=device, dtype=dtype)\n    else:\n       
 # test fft_r2c\n        dtype = test_case.dtype_dict[\"real\"]\n        x = x.to(device=device, dtype=dtype)\n    y = torch.fft.fftn(x, n, dim, norm)\n\n    return y\n\n\ndef _test_ifftn(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"nd\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"nd\"][\"upper_n_dims\"]\n    params = gen_params_nd_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape)\n    if is_complex_dtype(x.dtype):\n        # test fft_c2c\n        dtype = test_case.dtype_dict[\"complex\"]\n        x = x.to(device=device, dtype=dtype)\n    else:\n        # test fft_r2c\n        dtype = test_case.dtype_dict[\"real\"]\n        x = x.to(device=device, dtype=dtype)\n\n    y = torch.fft.ifftn(x, n, dim, norm)\n\n    return y\n\n\ndef _test_rfftn(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"nd\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"nd\"][\"upper_n_dims\"]\n    params = gen_params_nd_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n\n    dtype = test_case.dtype_dict[\"real\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype)\n    y = torch.fft.rfftn(x, n, dim, norm)\n\n    return y\n\n\ndef _test_irfftn(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"nd\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"nd\"][\"upper_n_dims\"]\n    params = gen_params_nd_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n    dtype = test_case.dtype_dict[\"complex\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype)\n    y = torch.fft.irfftn(x, n, dim, norm)\n\n    return y\n\n\ndef _test_hfftn(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"nd\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"nd\"][\"upper_n_dims\"]\n    params = gen_params_nd_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = params[\"norm\"]\n    dtype = test_case.dtype_dict[\"complex\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype)\n    y = torch.fft.hfftn(x, n, dim, norm)\n\n    return y\n\n\ndef _test_ihfftn(test_case):\n    if is_cufft_available():\n        device = random_device()\n    else:\n        device = cpu_device()\n\n    lower_n_dims = test_case.ndims_dict[\"nd\"][\"lower_n_dims\"]\n    upper_n_dims = test_case.ndims_dict[\"nd\"][\"upper_n_dims\"]\n    params = gen_params_nd_fft(lower_n_dims, upper_n_dims)\n\n    num_dims = params[\"num_dims\"]\n    shape = params[\"shape\"]\n    n = params[\"n\"]\n    dim = params[\"dim\"]\n    norm = 
params[\"norm\"]\n    dtype = test_case.dtype_dict[\"real\"]\n\n    x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype)\n    y = torch.fft.ihfftn(x, n, dim, norm)\n\n    return y\n\n\n# NOTE: skip for multi-nodes and multi-devices now, because it failed in ci randomly\n@flow.unittest.skip_unless_1n1d()\nclass TestComplex64Fft(flow.unittest.TestCase):\n    def setUp(test_case):\n        # should override by other data type of complex\n        test_case.ndims_dict = {\n            \"1d\": {\"lower_n_dims\": 1, \"upper_n_dims\": 5},\n            \"2d\": {\"lower_n_dims\": 2, \"upper_n_dims\": 5},\n            \"nd\": {\"lower_n_dims\": 1, \"upper_n_dims\": 5},\n        }\n\n        test_case.dtype_dict = {\"real\": torch.float32, \"complex\": torch.complex64}\n\n        test_case.rtol = 1e-5\n        test_case.atol = 1e-5\n        if os.environ[\"ONEFLOW_CI\"] == \"1\":\n            test_case.rtol = 1e-2\n            test_case.atol = 1e-2\n        test_case.initTestFft()\n\n    def initTestFft(test_case):\n        test_case.test_fft = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol,\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=True,\n        )(_test_fft)\n\n        test_case.test_ifft = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol,\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=True,\n        )(_test_ifft)\n\n        test_case.test_rfft = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol,\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=False,\n        )(_test_rfft)\n\n        test_case.test_irfft = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol,\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=True,\n        )(_test_irfft)\n\n        test_case.test_hfft = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol,\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=True,\n        )(_test_hfft)\n\n        test_case.test_ihfft = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol,\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=False,\n        )(_test_ihfft)\n\n        test_case.test_fft2 = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol,\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=True,\n        )(_test_fft2)\n\n        test_case.test_ifft2 = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol,\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=True,\n        )(_test_ifft2)\n\n        test_case.test_rfft2 = autotest(\n            n=5,\n            auto_backward=True,\n            
rtol=test_case.rtol,\n            atol=test_case.atol,\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=False,\n        )(_test_rfft2)\n\n        test_case.test_irfft2 = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol\n            * 100,  # NOTE: multi-dim fft_c2r amplifies the numerical accuracy error\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=True,\n        )(_test_irfft2)\n\n        test_case.test_hfft2 = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol\n            * 100,  # NOTE: multi-dim fft_c2r amplifies the numerical accuracy error\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=True,\n        )(_test_hfft2)\n\n
        test_case.test_ihfft2 = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol,\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=False,\n        )(_test_ihfft2)\n\n        test_case.test_fftn = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol * 1e2,  # NOTE: ND transforms accumulate larger numerical error\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=True,\n        )(_test_fftn)\n\n        test_case.test_ifftn = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol * 1e2,\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=True,\n        )(_test_ifftn)\n\n
        test_case.test_rfftn = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol * 1e2,\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=False,\n        )(_test_rfftn)\n\n        test_case.test_irfftn = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol\n            * 1e2,  # NOTE: multi-dim fft_c2r amplifies the numerical accuracy error\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=True,\n        )(_test_irfftn)\n\n        test_case.test_hfftn = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol\n            * 1e2,  # NOTE: multi-dim fft_c2r amplifies the numerical accuracy error\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=True,\n        )(_test_hfftn)\n\n
        test_case.test_ihfftn = autotest(\n            n=5,\n            auto_backward=True,\n            rtol=test_case.rtol,\n            atol=test_case.atol * 1e2,\n            check_graph=False,\n            check_grad_use_random_data=True,\n            include_complex=False,\n        )(_test_ihfftn)\n\n    def test_1d_fft(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            test_case.test_fft,\n            test_case.test_ifft,\n            test_case.test_rfft,\n            test_case.test_irfft,\n
test_case.test_hfft,\n            test_case.test_ihfft,\n        ]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_2d_fft_except_hfft2(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            test_case.test_fft2,\n            test_case.test_ifft2,\n            test_case.test_rfft2,\n            test_case.test_irfft2,\n        ]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @unittest.skipIf(\n        version.parse(torch_original.__version__) < version.parse(\"1.11.0\"),\n        \"module 'torch.fft' has no attribute 'hfft2' or 'ihfft2' before '1.11.0'\",\n    )\n    def test_2d_fft_hfft2(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [test_case.test_hfft2, test_case.test_ihfft2]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n
    def test_nd_fft_except_hfftn(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            test_case.test_fftn,\n            test_case.test_ifftn,\n            test_case.test_rfftn,\n            test_case.test_irfftn,\n        ]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @unittest.skipIf(\n        version.parse(torch_original.__version__) < version.parse(\"1.11.0\"),\n        \"module 'torch.fft' has no attribute 'hfftn' or 'ihfftn' before '1.11.0'\",\n    )\n    def test_nd_fft_hfftn(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [test_case.test_hfftn, test_case.test_ihfftn]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\n# NOTE: skip multi-node and multi-device runs for now because they fail randomly in CI\n@flow.unittest.skip_unless_1n1d()\nclass TestComplex128Fft(TestComplex64Fft):\n    def setUp(test_case):\n        # overrides the complex64 setup with float64/complex128 dtypes\n        test_case.ndims_dict = {\n            \"1d\": {\"lower_n_dims\": 1, \"upper_n_dims\": 5},\n            \"2d\": {\"lower_n_dims\": 2, \"upper_n_dims\": 5},\n            \"nd\": {\"lower_n_dims\": 1, \"upper_n_dims\": 5},\n        }\n\n        test_case.dtype_dict = {\"real\": torch.float64, \"complex\": torch.complex128}\n\n        test_case.rtol = 1e-7\n        test_case.atol = 1e-7\n        test_case.initTestFft()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_flatten.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_flatten(test_case, device):\n    m = flow.nn.Flatten()\n    x = flow.Tensor(32, 2, 5, 5, device=flow.device(device))\n    flow.nn.init.uniform_(x)\n    y = m(x)\n    test_case.assertTrue(y.shape == flow.Size((32, 50)))\n    test_case.assertTrue(np.array_equal(y.numpy().flatten(), x.numpy().flatten()))\n    y2 = flow.flatten(x, start_dim=2)\n    test_case.assertTrue(y2.shape == flow.Size((32, 2, 25)))\n    test_case.assertTrue(np.array_equal(y2.numpy().flatten(), x.numpy().flatten()))\n    y3 = x.flatten(start_dim=1)\n    test_case.assertTrue(y3.shape == flow.Size((32, 50)))\n    test_case.assertTrue(np.array_equal(y3.numpy().flatten(), x.numpy().flatten()))\n    y4 = x.flatten(start_dim=1, end_dim=2)\n    test_case.assertTrue(y4.shape == flow.Size((32, 10, 5)))\n    test_case.assertTrue(np.array_equal(y4.numpy().flatten(), x.numpy().flatten()))\n    y5 = flow.flatten(x)\n    test_case.assertTrue(y5.shape == flow.Size((1600,)))\n    test_case.assertTrue(np.array_equal(y5.numpy().flatten(), x.numpy().flatten()))\n\n\ndef _test_flatten_backward(test_case, device):\n    m = flow.nn.Flatten().to(flow.device(device))\n    x = flow.Tensor(2, 3, 4, 5, device=flow.device(device))\n    x.requires_grad = True\n    flow.nn.init.uniform_(x)\n    y = m(x)\n    z = y.sum()\n    z.backward()\n    test_case.assertTrue(np.array_equal(np.ones(shape=(2, 3, 4, 5)), x.grad.numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFlattenModule(flow.unittest.TestCase):\n    def test_cast(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_flatten, _test_flatten_backward]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5)\n    def test_flatten_module_with_random_data(test_case):\n        m = torch.nn.Flatten(\n            start_dim=random(1, 6) | nothing(), end_dim=random(1, 6) | nothing()\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_flatten_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.flatten(\n            x,\n            start_dim=random(1, 6).to(int) | nothing(),\n            end_dim=random(1, 6).to(int) | nothing(),\n        )\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flatten_bool_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device=device, dtype=torch.bool)\n        y = 
torch.flatten(\n            x,\n            start_dim=random(1, 6).to(int) | nothing(),\n            end_dim=random(1, 6).to(int) | nothing(),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_flatten_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.flatten(\n            x,\n            start_dim=random(1, 6).to(int) | nothing(),\n            end_dim=random(1, 6).to(int) | nothing(),\n        )\n        return y\n\n    @profile(torch.flatten)\n    def profile_flatten(test_case):\n        torch.flatten(torch.ones(1000, 1000))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_flip.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFlip(flow.unittest.TestCase):\n    @autotest(check_graph=True, check_allclose=False)\n    def test_flow_flip_list_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int)\n        ).to(device)\n        y = torch.flip(x, constant([0, 1, 2]))\n        return y\n\n    @autotest(n=5)\n    def test_flow_flip_tuple_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int)\n        ).to(device)\n        y = torch.flip(x, constant((0, 1, 2)))\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_flip_bool_tuple_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int)\n        ).to(device=device, dtype=torch.bool)\n        y = torch.flip(x, constant((0, 1, 2)))\n        return y\n\n    def test_flow_flip_list_lastdim_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int)\n        ).to(device)\n        y = torch.flip(x, [-1,])\n        return y\n\n    @profile(torch.flip)\n    def profile_flip(test_case):\n        torch.flip(torch.ones(100, 100, 100), [0, 1])\n        torch.flip(torch.ones(1, 100000), [-1,])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_floor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_floor(test_case, shape, device):\n    np_input = np.random.randn(*shape)\n    of_input = flow.tensor(\n        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_out = flow.floor(of_input)\n    np_out = np.floor(np_input)\n    test_case.assertTrue(\n        np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True)\n    )\n    of_out = of_out.sum()\n    of_out.backward()\n    np_out_grad = np.zeros_like(of_out, dtype=np.float32)\n    test_case.assertTrue(\n        np.allclose(of_input.grad.numpy(), np_out_grad, 0.0001, 0.0001, equal_nan=True)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFloor(flow.unittest.TestCase):\n    def test_floor(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(2,), (2, 3), (2, 4, 5, 6)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_floor(test_case, *arg)\n\n    @autotest(check_graph=True)\n    def test_flow_floor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.floor(x)\n        return y\n\n    @autotest(check_graph=True)\n    def test_flow_floor_inplace_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x + 1\n        y.floor_()\n        return y\n\n    @autotest(check_graph=True)\n    def test_flow_floor_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.floor(x)\n        return y\n\n    @profile(torch.floor)\n    def profile_floor(test_case):\n        torch.floor(torch.ones(100, 100, 100))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fmod.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport random as rd\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nimport torch as torch_original\nfrom packaging import version\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFmodModule(flow.unittest.TestCase):\n    # other.grad in torch.fmod(input, other) was not implemented before pytorch 1.11.0\n    grad_implemented = version.parse(torch_original.__version__) >= version.parse(\n        \"1.11.0\"\n    )\n\n    @autotest(n=1, auto_backward=grad_implemented)\n    def test_flow_fmod_element_with_random_data(test_case):\n        device = random_device()\n        dim1 = random().to(int)\n        dim2 = random().to(int)\n        input = random_tensor(ndim=3, dim1=dim1, dim2=dim2).to(device)\n        other = random_tensor(ndim=3, dim1=dim1, dim2=dim2).to(device)\n        return torch.fmod(input, other)\n\n    @autotest(n=1, auto_backward=grad_implemented)\n    def test_flow_fmod_element_with_0dim_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=0).to(device)\n        other = random_tensor(ndim=0).to(device)\n        return torch.fmod(input, other)\n\n    @autotest(n=1, auto_backward=grad_implemented)\n    def test_flow_fmod_broadcast_with_random_data(test_case):\n        device = random_device()\n        dim1 = random().to(int)\n        dim2 = random().to(int)\n        input = random_tensor(ndim=3, dim1=constant(1), dim2=dim2).to(device)\n        other = random_tensor(ndim=3, dim1=dim1, dim2=constant(1)).to(device)\n        return torch.fmod(input, other)\n\n    @autotest(n=1, auto_backward=True)\n    def test_flow_fmod_scalar_with_random_data(test_case):\n        device = random_device()\n        dim1 = random().to(int)\n        dim2 = random().to(int)\n        input = random_tensor(ndim=3, dim1=dim1, dim2=dim2).to(device)\n        other = 3\n        return torch.fmod(input, other)\n\n    @autotest(n=1, auto_backward=True)\n    def test_fmod_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 1, 0, 3).to(device)\n        y = torch.fmod(x, 2)\n        return y\n\n    @profile(torch.fmod)\n    def profile_fmod(test_case):\n        torch.fmod(torch.ones(100, 100, 100), 1)\n        torch.fmod(torch.ones(100, 100, 100), -0.5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fold.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.nn.common_types import _size_2_t\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFold(flow.unittest.TestCase):\n    @autotest(n=3, auto_backward=True, rtol=1e-4, atol=1e-4)\n    def test_fold_with_random_data_1(test_case):\n        m = torch.nn.Fold(\n            output_size=constant((4, 4)),\n            kernel_size=constant(3),\n            dilation=constant(1),\n            padding=constant(1),\n            stride=constant(1),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(\n            ndim=3, dim0=constant(2), dim1=constant(36), dim2=constant(16)\n        ).to(device)\n        y = m(x)\n        func_y = torch.nn.functional.fold(\n            x,\n            output_size=constant((4, 4)),\n            kernel_size=constant(3),\n            dilation=constant(1),\n            padding=constant(1),\n            stride=constant(1),\n        )\n        return y, func_y\n\n    @autotest(n=3, auto_backward=True, rtol=1e-4, atol=1e-4)\n    def test_fold_with_random_data_2(test_case):\n        m = torch.nn.Fold(\n            output_size=constant((4, 4)),\n            kernel_size=constant(3),\n            dilation=constant(1),\n            padding=constant(0),\n            stride=constant(1),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(\n            ndim=3, dim0=constant(2), dim1=constant(36), dim2=constant(4)\n        ).to(device)\n        y = m(x)\n        func_y = torch.nn.functional.fold(\n            x,\n            output_size=constant((4, 4)),\n            kernel_size=constant(3),\n            dilation=constant(1),\n            padding=constant(0),\n            stride=constant(1),\n        )\n        return y, func_y\n\n    @autotest(n=3, auto_backward=True, rtol=1e-4, atol=1e-4)\n    def test_fold_with_random_data_3(test_case):\n        m = torch.nn.Fold(\n            output_size=constant((8, 8)),\n            kernel_size=constant(3),\n            dilation=constant(1),\n            padding=constant(1),\n            stride=constant(2),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(\n            ndim=3, dim0=constant(2), dim1=constant(72), dim2=constant(16)\n        ).to(device)\n        y = m(x)\n        func_y = torch.nn.functional.fold(\n            x,\n            output_size=constant((8, 8)),\n            kernel_size=constant(3),\n            dilation=constant(1),\n            padding=constant(1),\n            stride=constant(2),\n        )\n        return y, func_y\n\n    @autotest(n=3, auto_backward=True, rtol=1e-4, atol=1e-4)\n    def test_fold_with_random_data_4(test_case):\n        m = torch.nn.Fold(\n     
       output_size=constant((8, 8)),\n            kernel_size=constant(3),\n            dilation=constant(2),\n            padding=constant(1),\n            stride=constant(2),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(\n            ndim=3, dim0=constant(2), dim1=constant(9), dim2=constant(9)\n        ).to(device)\n        y = m(x)\n        func_y = torch.nn.functional.fold(\n            x,\n            output_size=constant((8, 8)),\n            kernel_size=constant(3),\n            dilation=constant(2),\n            padding=constant(1),\n            stride=constant(2),\n        )\n        return y, func_y\n\n    @profile(torch.nn.functional.fold)\n    def profile_fold(test_case):\n        x = torch.ones(128, 128, 4)\n        torch.nn.functional.fold(x, output_size=(4, 4), kernel_size=(2, 2), stride=2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fork_sub_process.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nfrom multiprocessing.pool import Pool\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\n\n\ndef _test_fork_sub_process(id):\n    print(\"\\nchild process:%s start! process id: %d\" % (id, os.getpid()))\n    import oneflow as flow\n\n    x = flow.tensor(np.ones((4, 16)), device=\"cpu\")\n    y = flow.tensor(np.ones((16)), device=\"cpu\")\n    z = x + y\n    assert np.array_equal(z.numpy(), np.ones((4, 16)) * 2)\n    print(\"%s child process done! process id: %d.\" % (id, os.getpid()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestForkSubProcess(flow.unittest.TestCase):\n    def test_fork_sub_process(test_case):\n        flow._oneflow_internal.eager.Sync()\n        print(\"=============main process start=============\")\n        # process pool\n        num_process = 4\n        p = Pool(num_process)\n        async_res = []\n        for i in range(num_process):  # create n child processes\n            # put it to pool\n            async_res.append(p.apply_async(_test_fork_sub_process, args=(i,)))\n        p.close()\n        p.join()\n        for i in range(num_process):\n            test_case.assertTrue(async_res[i].successful())\n\n        print(\"=============main process done!=============\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_frac.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFrac(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_frac(test_case):\n        device = random_device()\n        ndim = random(2, 4).to(int).value()\n        shape = [random(2, 4) for i in range(ndim)]\n        input = random_tensor(ndim, *shape).to(device)\n        output = torch.frac(input)\n        return output\n\n    @autotest(n=5)\n    def test_tensor_frac(test_case):\n        device = random_device()\n        ndim = random(2, 4).to(int).value()\n        shape = [random(2, 4) for i in range(ndim)]\n        input = random_tensor(ndim, *shape).to(device)\n        output = input.frac()\n        return output\n\n    @autotest(n=5)\n    def test_tensor_frac_inplace(test_case):\n        device = random_device()\n        ndim = random(2, 4).to(int).value()\n        shape = [random(2, 4) for i in range(ndim)]\n        input = random_tensor(ndim, *shape).to(device)\n        input = input + 1.0\n        input.frac_()\n        return input\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_from_numpy.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport random\nimport unittest\n\nimport torch\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFromNumpy(flow.unittest.TestCase):\n    def test_same_data(test_case):\n        np_arr = np.random.randn(3, 4, 5)\n        tensor = flow.from_numpy(np_arr)\n        test_case.assertTrue(np.array_equal(np_arr, tensor.numpy()))\n        test_case.assertEqual(tensor.size(), (3, 4, 5))\n        test_case.assertEqual(tensor.stride(), (20, 5, 1))\n        test_case.assertEqual(tensor.storage_offset(), 0)\n\n        np_arr[1:2, 2:3, 3:4] = random.random()\n        test_case.assertTrue(np.array_equal(np_arr, tensor.numpy()))\n\n    def test_use_ops(test_case):\n        np_arr = np.random.randn(3, 4, 5)\n        tensor = flow.from_numpy(np_arr)\n        res = tensor ** 2\n        test_case.assertTrue(np.allclose(np_arr ** 2, res.numpy()))\n\n    def test_more_dtype(test_case):\n        for dtype in [\n            np.float64,\n            np.float32,\n            np.float16,\n            np.int64,\n            np.int32,\n            np.int8,\n            np.uint8,\n        ]:\n            np_arr = np.ones((2, 3), dtype=dtype)\n            tensor = flow.from_numpy(np_arr)\n            # TODO(wyg): oneflow.float16 do not support to copy from tensor to numpy\n            if tensor.dtype not in [flow.float16]:\n                test_case.assertTrue(np.array_equal(np_arr, tensor.numpy()))\n\n    def test_non_contiguous_input(test_case):\n        np_arr = np.random.randn(2, 3, 4, 5).transpose(2, 0, 3, 1)\n        flow_tensor = flow.from_numpy(np_arr)\n        torch_tensor = torch.from_numpy(np_arr)\n        test_case.assertTrue(flow_tensor.shape == torch_tensor.shape)\n        test_case.assertTrue(flow_tensor.stride() == torch_tensor.stride())\n        test_case.assertTrue(\n            flow_tensor.is_contiguous() == torch_tensor.is_contiguous()\n        )\n        test_case.assertTrue(np.array_equal(flow_tensor.numpy(), torch_tensor.numpy()))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_from_torch.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nimport os\n\nimport oneflow as flow\nimport oneflow.unittest\nimport torch\n\n\ndef torch_device_to_flow(device):\n    if device.type == \"cpu\":\n        return flow.device(\"cpu\")\n    elif device.type == \"cuda\":\n        return flow.device(\"cuda\", device.index)\n    else:\n        raise NotImplementedError(\"Unsupported device type: {}\".format(device.type))\n\n\nclass TestFromTroch(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_from_torch_cpu(test_case):\n        torch_t = torch.rand(5, 3, 3)\n        numpy_from_torch = torch_t.numpy()\n\n        # NOTE: torch and numpy shared the same memory.\n        test_case.assertEqual(\n            torch_t.data_ptr(), numpy_from_torch.__array_interface__[\"data\"][0]\n        )\n        numpy_from_torch[0][0] = [1, 2, 3]\n        test_case.assertTrue(\n            np.allclose(torch_t.numpy(), numpy_from_torch, rtol=0.001, atol=0.001)\n        )\n\n        # NOTE: oneflow and numpy shared the same memory,\n        #   so oneflow and torch cpu tensor shared the same memory,\n        #   which means oneflow can use torch's cpu tensor without cost.\n        flow_t = flow.utils.tensor.from_torch(torch_t)\n\n        test_case.assertTrue(\n            np.allclose(torch_t.numpy(), flow_t.numpy(), rtol=0.001, atol=0.001)\n        )\n        test_case.assertEqual(torch_t.numpy().dtype, flow_t.numpy().dtype)\n\n    # NOTE: For the case of 0 size tensor, no memory addresses are compared.\n    #  Because the address of 0 size tensor is random at this time.\n    @flow.unittest.skip_unless_1n1d()\n    def test_from_torch_cpu_with_0_size_data(test_case):\n        torch_t = torch.rand(5, 0, 3)\n\n        flow_t = flow.utils.tensor.from_torch(torch_t)\n\n        test_case.assertTrue(\n            np.allclose(torch_t.numpy(), flow_t.numpy(), rtol=0.001, atol=0.001)\n        )\n        test_case.assertEqual(torch_t.numpy().dtype, flow_t.numpy().dtype)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_from_torch_cpu_with_0dim_data(test_case):\n        torch_t = torch.tensor(5)\n        numpy_from_torch = torch_t.numpy()\n\n        test_case.assertEqual(\n            torch_t.data_ptr(), numpy_from_torch.__array_interface__[\"data\"][0]\n        )\n\n        flow_t = flow.utils.tensor.from_torch(torch_t)\n\n        test_case.assertTrue(\n            np.allclose(torch_t.numpy(), flow_t.numpy(), rtol=0.001, atol=0.001)\n        )\n        test_case.assertEqual(torch_t.numpy().dtype, flow_t.numpy().dtype)\n\n    @flow.unittest.skip_unless_1n2d()\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_from_torch_gpu(test_case):\n        for device in [torch.device(\"cuda\", 0), torch.device(\"cuda\", 1)]:\n            torch_t = torch.tensor([1, 2]).to(device)\n\n            flow_t = flow.utils.tensor.from_torch(torch_t)\n\n            
test_case.assertTrue(np.array_equal(torch_t.cpu().numpy(), flow_t.numpy()))\n            test_case.assertEqual(torch_t.cpu().numpy().dtype, flow_t.numpy().dtype)\n            test_case.assertEqual(torch_device_to_flow(torch_t.device), flow_t.device)\n\n            # Test oneflow tensor and pytorch tensor share the data\n            torch_t[0] = 5\n            test_case.assertTrue(np.array_equal(torch_t.cpu().numpy(), flow_t.numpy()))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_functional_docstr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport inspect\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _run_functional_doctest(\n    test_case,\n    globs=None,\n    verbose=None,\n    optionflags=0,\n    raise_on_error=True,\n    module=flow,\n):\n    import doctest\n\n    parser = doctest.DocTestParser()\n    if raise_on_error:\n        runner = doctest.DebugRunner(verbose=verbose, optionflags=optionflags)\n    else:\n        runner = doctest.DocTestRunner(verbose=verbose, optionflags=optionflags)\n    r = inspect.getmembers(module)\n    for (name, fun) in r:\n        if fun.__doc__ is not None:\n            test = parser.get_doctest(fun.__doc__, {}, __name__, __file__, 0)\n            try:\n                runner.run(test)\n            except doctest.DocTestFailure as e:\n                print(f\"\\nGot error result in the docstring of {name}\")\n                print(f\"got output: {e.got}\")\n                raise e\n            except doctest.UnexpectedException as e:\n                print(f\"\\nGot UnexpectedException in the docstring of {name}\")\n                raise e.exc_info[1]\n\n    if not raise_on_error:\n        test_case.assertEqual(\n            runner.failures,\n            0,\n            f\"{runner.summarize()}, please turn on raise_on_error to see more details\",\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestFunctionalDocstrModule(flow.unittest.TestCase):\n    def test_functional_docstr(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"module\"] = [flow, flow.Tensor, flow.sbp, flow.env, flow.nn.functional]\n        for arg in GenArgList(arg_dict):\n            _run_functional_doctest(\n                test_case, raise_on_error=True, verbose=True, module=arg[0]\n            )\n\n\nif __name__ == \"__main__\":\n    flow.set_printoptions(linewidth=80)\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_functional_scalar_tensor_param.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFunctionalWithScalarTensorParam(flow.unittest.TestCase):\n    # NOTE: graph mode not support dynamic scalar tensor parameter\n    @autotest(n=2, auto_backward=False, check_graph=False)\n    def test_scalar_tensor_transfer_to_scalar(test_case):\n        device = random_device()\n        min = torch.tensor(0.0)\n        max = torch.tensor(0.5)\n        x = random_tensor(ndim=2, dim0=2, dim1=3).to(device)\n        return x.clamp(min=min, max=max)\n\n    @autotest(n=2, auto_backward=False, check_graph=False)\n    def test_scalar_tensor_transfer_to_double(test_case):\n        device = random_device()\n        threshold = torch.tensor(0.5).to(device)\n        x = random_tensor(ndim=2, dim0=2, dim1=3).to(device)\n        return torch.nn.functional.threshold(x, threshold=threshold, value=0.5)\n\n    @autotest(n=2, auto_backward=False, check_graph=False)\n    def test_scalar_tensor_transfer_to_int(test_case):\n        device = random_device()\n        start_dim = torch.tensor(1).to(device)\n        end_dim = torch.tensor(3).to(device)\n        x = random_tensor(4, *(2, 3, 4, 5)).to(device)\n        return x.flatten(start_dim=start_dim, end_dim=end_dim)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_attention_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nimport math\nimport itertools\nimport os\n\nimport oneflow as flow\n\n\ndef _ref(\n    query,\n    key,\n    value,\n    num_heads,\n    attn_mask_type=\"none\",\n    attn_bias=None,\n    causal_diagonal_offset=0,\n    query_seq_len=None,\n    key_seq_len=None,\n):\n    query = query.permute(0, 2, 1, 3)\n    key = key.permute(0, 2, 3, 1)\n    value = value.permute(0, 2, 1, 3)\n    scores = flow.matmul(query, key) / math.sqrt(query.shape[-1])\n    if attn_mask_type == \"causal_from_bottom_right\":\n        causal_diagonal_offset += key.shape[-1] - query.shape[-2]\n    if (\n        attn_mask_type == \"causal_from_top_left\"\n        or attn_mask_type == \"causal_from_bottom_right\"\n    ):\n        causal_mask = flow.triu(\n            flow.ones(\n                scores.shape[-2], scores.shape[-1], dtype=flow.bool, device=\"cuda\"\n            ),\n            causal_diagonal_offset + 1,\n        )\n        scores = flow.masked_fill(scores, causal_mask, float(\"-inf\"))\n    if attn_bias is not None:\n        scores = scores + attn_bias\n    if query_seq_len is not None:\n        scores = flow.masked_fill(\n            scores,\n            flow.arange(scores.shape[-2], device=query_seq_len.device).view(\n                1, 1, scores.shape[-2], 1\n            )\n            >= query_seq_len.view(scores.shape[0], 1, 1, 1),\n            float(\"-inf\"),\n        )\n    if key_seq_len is not None:\n        scores = flow.masked_fill(\n            scores,\n            flow.arange(scores.shape[-1], device=key_seq_len.device).view(\n                1, 1, 1, scores.shape[-1]\n            )\n            >= key_seq_len.view(scores.shape[0], 1, 1, 1),\n            float(\"-inf\"),\n        )\n    attn = flow.softmax(scores, dim=-1)\n    out = flow.matmul(attn, value)\n    out = out.permute(0, 2, 1, 3)\n    out = out.reshape(out.shape[0], out.shape[1], -1)\n    return out\n\n\ndef _to_layout(ts, layout, tensor_index, seq_len=None):\n    if layout == \"BMHK\":\n        return ts[tensor_index]\n    elif layout == \"BM(HK)\":\n        return ts[tensor_index].view(\n            ts[tensor_index].shape[0], ts[tensor_index].shape[1], -1\n        )\n    elif layout == \"MB(HK)\":\n        return (\n            ts[tensor_index]\n            .view(ts[tensor_index].shape[0], ts[tensor_index].shape[1], -1)\n            .transpose(0, 1)\n        )\n    elif layout == \"BHMK\":\n        return ts[tensor_index].transpose(1, 2)\n    elif layout == \"MBHK\":\n        return ts[tensor_index].transpose(0, 1)\n    elif layout == \"BM(H3K)\":\n        return flow.stack(ts, -2).view(ts[0].shape[0], ts[0].shape[1], -1)\n    elif layout == \"MB(H3K)\":\n        return (\n            flow.stack(ts, -2).view(ts[0].shape[0], ts[0].shape[1], -1).transpose(0, 1)\n        )\n    elif layout 
== \"BM(H2K)\":\n        return flow.stack(ts[1:], -2).view(ts[1].shape[0], ts[1].shape[1], -1)\n    elif layout == \"MB(H2K)\":\n        return (\n            flow.stack(ts[1:], -2)\n            .view(ts[1].shape[0], ts[1].shape[1], -1)\n            .transpose(0, 1)\n        )\n    elif layout == \"(BM)HK\":\n        t = ts[tensor_index]\n        if seq_len is None:\n            return t.view(-1, t.shape[-2], t.shape[-1])\n        mask = flow.arange(t.shape[1], device=t.device).view(\n            1, t.shape[1]\n        ) < seq_len.view(t.shape[0], 1)\n        return flow.masked_select(\n            t, mask.view(mask.shape[0], mask.shape[1], 1, 1)\n        ).view(-1, t.shape[-2], t.shape[-1])\n    elif layout == \"(BM)(HK)\":\n        t = ts[tensor_index]\n        if seq_len is None:\n            return t.view(-1, t.shape[-2] * t.shape[-1])\n        mask = flow.arange(t.shape[1], device=t.device).view(\n            1, t.shape[1]\n        ) < seq_len.view(t.shape[0], 1)\n        return flow.masked_select(\n            t, mask.view(mask.shape[0], mask.shape[1], 1, 1)\n        ).view(-1, t.shape[-2] * t.shape[-1])\n    elif layout == \"(BM)(H2K)\":\n        t = flow.stack(ts[1:], -2)\n        if seq_len is None:\n            return t.view(t.shape[0] * t.shape[1], -1)\n        mask = flow.arange(t.shape[1], device=t.device).view(\n            1, t.shape[1]\n        ) < seq_len.view(t.shape[0], 1)\n        return flow.masked_select(\n            t, mask.view(mask.shape[0], mask.shape[1], 1, 1, 1)\n        ).view(-1, t.shape[-3] * t.shape[-2] * t.shape[-1])\n    elif layout == \"(BM)(H3K)\":\n        t = flow.stack(ts, -2)\n        if seq_len is None:\n            return t.view(t.shape[0] * t.shape[1], -1)\n        mask = flow.arange(t.shape[1], device=t.device).view(\n            1, t.shape[1]\n        ) < seq_len.view(t.shape[0], 1)\n        return flow.masked_select(\n            t, mask.view(mask.shape[0], mask.shape[1], 1, 1, 1)\n        ).view(-1, t.shape[-3] * t.shape[-2] * t.shape[-1])\n    else:\n        raise NotImplementedError\n\n\ndef _fused_mha(\n    query,\n    key,\n    value,\n    num_heads,\n    attn_mask_type=\"none\",\n    attn_bias=None,\n    causal_diagonal_offset=0,\n    query_layout=\"BM(HK)\",\n    key_layout=\"BM(HK)\",\n    value_layout=\"BM(HK)\",\n    output_layout=\"MB(HK)\",\n    query_seq_len=None,\n    key_seq_len=None,\n    use_kv_seq_len=False,\n):\n    batch_size = query.shape[0]\n    query_max_seq_len = query.shape[1]\n    query_head_size = query.shape[-1]\n    key_max_seq_len = key.shape[1]\n    ts = [query, key, value]\n    query = _to_layout(ts, query_layout, 0, query_seq_len)\n    if use_kv_seq_len:\n        key = _to_layout(ts, key_layout, 1)\n        value = _to_layout(ts, value_layout, 2)\n    else:\n        key = _to_layout(ts, key_layout, 1, key_seq_len)\n        value = _to_layout(ts, value_layout, 2, key_seq_len)\n    if query_seq_len is not None:\n        query_seq_start = (\n            flow.cumsum(flow.pad(query_seq_len, (1, 0)), dim=-1)\n            .to(flow.int32)\n            .to(query.device)\n        )\n    else:\n        query_seq_start = None\n        query_max_seq_len = None\n    if key_seq_len is not None:\n        if use_kv_seq_len:\n            key_seq_start = flow.arange(\n                0,\n                key_max_seq_len * (batch_size + 1),\n                key_max_seq_len,\n                dtype=flow.int32,\n                device=key_seq_len.device,\n            )\n        else:\n            key_seq_start = (\n                
flow.cumsum(flow.pad(key_seq_len, (1, 0)), dim=-1)\n                .to(flow.int32)\n                .to(query.device)\n            )\n    else:\n        key_seq_start = None\n        key_max_seq_len = None\n    if attn_bias is not None and attn_bias.shape[-1] % 8 != 0:\n        pad = 8 - attn_bias.shape[-1] % 8\n        attn_bias = flow.pad(attn_bias, (0, pad), \"constant\", 0)\n    output = flow._C.fused_multi_head_attention_inference_v2(\n        query=query,\n        key=key,\n        value=value,\n        query_head_size=query_head_size,\n        attn_mask_type=attn_mask_type,\n        attn_bias=attn_bias,\n        causal_diagonal_offset=causal_diagonal_offset,\n        query_layout=query_layout,\n        key_layout=key_layout,\n        value_layout=value_layout,\n        output_layout=output_layout,\n        query_seq_start=query_seq_start,\n        key_seq_start=key_seq_start,\n        key_seq_len=key_seq_len.to(flow.int32).to(\"cuda\") if use_kv_seq_len else None,\n        query_max_seq_len=query_max_seq_len,\n        key_max_seq_len=key_max_seq_len,\n    )\n    if output_layout == \"BM(HK)\" or output_layout == \"(BM)(HK)\":\n        return output\n    elif output_layout == \"MB(HK)\":\n        return output.transpose(0, 1)\n    else:\n        raise NotImplementedError\n\n\ndef _test_fused_attention_concat_past_key_value(\n    test_case,\n    dtype,\n    b,\n    past_m,\n    m,\n    h,\n    k,\n    past_key_layout,\n    past_value_layout,\n    key_layout,\n    value_layout,\n):\n    if past_m > 0:\n        past_key = flow.randn((b, past_m, h, k), device=\"cuda\", dtype=flow.float,).to(\n            dtype\n        )\n        past_value = flow.randn((b, past_m, h, k), device=\"cuda\", dtype=flow.float,).to(\n            dtype\n        )\n    else:\n        past_key = None\n        past_value = None\n    key = flow.randn((b, m, h, k), device=\"cuda\", dtype=flow.float,).to(dtype)\n    value = flow.randn((b, m, h, k), device=\"cuda\", dtype=flow.float,).to(dtype)\n\n    (\n        fused_concated_key,\n        fused_concated_value,\n    ) = flow._C.fused_attention_concat_past_key_value(\n        past_key=_to_layout([past_key, past_key, past_value], past_key_layout, 1),\n        past_key_layout=past_key_layout,\n        past_value=_to_layout([past_key, past_key, past_value], past_value_layout, 2),\n        past_value_layout=past_value_layout,\n        key=_to_layout([key, key, value], key_layout, 1),\n        key_layout=key_layout,\n        value=_to_layout([key, key, value], value_layout, 2),\n        value_layout=value_layout,\n        key_head_size=k,\n    )\n    if past_m > 0:\n        concated_key = flow.cat([past_key, key], dim=1)\n        concated_value = flow.cat([past_value, value], dim=1)\n    else:\n        concated_key = key\n        concated_value = value\n    ref_concated_key = _to_layout(\n        [concated_key, concated_key, concated_value], past_key_layout, 1\n    )\n    ref_concated_value = _to_layout(\n        [concated_key, concated_key, concated_value], past_value_layout, 2\n    )\n    test_case.assertTrue(\n        np.array_equal(fused_concated_key.numpy(), ref_concated_key.numpy())\n    )\n    test_case.assertTrue(\n        np.array_equal(fused_concated_value.numpy(), ref_concated_value.numpy())\n    )\n\n\ndef _test_fused_multi_head_attention_inference(\n    test_case,\n    batch_size,\n    num_heads,\n    query_seq_len,\n    kv_seq_len,\n    query_head_size,\n    value_head_size,\n    dtype,\n    attn_mask_type=\"none\",\n    causal_diagonal_offset=0,\n    
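# (added note) layout strings follow the convention implemented by _to_layout\n    # above, e.g. \"BM(HK)\" means [batch, seq, num_heads * head_size] flattened\n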
query_layout=\"BM(HK)\",\n    key_layout=\"BM(HK)\",\n    value_layout=\"BM(HK)\",\n    output_layout=\"BM(HK)\",\n):\n    query = flow.randn(\n        (batch_size, query_seq_len, num_heads, query_head_size),\n        device=\"cuda\",\n        dtype=flow.float,\n    ).to(dtype)\n    key = flow.randn(\n        (batch_size, kv_seq_len, num_heads, query_head_size),\n        device=\"cuda\",\n        dtype=flow.float,\n    ).to(dtype)\n    value = flow.randn(\n        (batch_size, kv_seq_len, num_heads, value_head_size),\n        device=\"cuda\",\n        dtype=flow.float,\n    ).to(dtype)\n\n    fused_out = _fused_mha(\n        query,\n        key,\n        value,\n        num_heads,\n        attn_mask_type=attn_mask_type,\n        causal_diagonal_offset=causal_diagonal_offset,\n        query_layout=query_layout,\n        key_layout=key_layout,\n        value_layout=value_layout,\n        output_layout=output_layout,\n    ).numpy()\n    ref_out = _ref(\n        query,\n        key,\n        value,\n        num_heads,\n        attn_mask_type=attn_mask_type,\n        causal_diagonal_offset=causal_diagonal_offset,\n    ).numpy()\n\n    test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2))\n\n\ndef _test_fused_multi_head_attention_inference_with_attn_bias(\n    test_case,\n    batch_size,\n    num_heads,\n    query_seq_len,\n    kv_seq_len,\n    query_head_size,\n    value_head_size,\n    dtype,\n    attn_mask_type=\"none\",\n):\n\n    query = flow.randn(\n        (batch_size, query_seq_len, num_heads, query_head_size),\n        device=\"cuda\",\n        dtype=flow.float,\n    ).to(dtype)\n    key = flow.randn(\n        (batch_size, kv_seq_len, num_heads, query_head_size),\n        device=\"cuda\",\n        dtype=flow.float,\n    ).to(dtype)\n    value = flow.randn(\n        (batch_size, kv_seq_len, num_heads, value_head_size),\n        device=\"cuda\",\n        dtype=flow.float,\n    ).to(dtype)\n\n    attn_bias = flow.randn((kv_seq_len,), device=\"cuda\", dtype=flow.float).to(dtype)\n    ref_out = _ref(\n        query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type\n    ).numpy()\n    fused_out = _fused_mha(\n        query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type\n    ).numpy()\n    test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2))\n\n    attn_bias = flow.randn(\n        (query_seq_len, kv_seq_len), device=\"cuda\", dtype=flow.float\n    ).to(dtype)\n    ref_out = _ref(\n        query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type\n    ).numpy()\n    fused_out = _fused_mha(\n        query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type\n    ).numpy()\n    test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2))\n\n    attn_bias = flow.randn(\n        (num_heads, query_seq_len, kv_seq_len), device=\"cuda\", dtype=flow.float\n    ).to(dtype)\n    ref_out = _ref(\n        query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type\n    ).numpy()\n    fused_out = _fused_mha(\n        query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type\n    ).numpy()\n    test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2))\n\n    attn_bias = flow.randn(\n        (batch_size, num_heads, query_seq_len, kv_seq_len),\n        device=\"cuda\",\n        dtype=flow.float,\n    ).to(dtype)\n    ref_out = _ref(\n        query, key, value, num_heads, attn_bias=attn_bias, 
attn_mask_type=attn_mask_type\n    ).numpy()\n    fused_out = _fused_mha(\n        query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type\n    ).numpy()\n    test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2))\n\n    attn_bias = flow.randn(\n        (num_heads, 1, kv_seq_len), device=\"cuda\", dtype=flow.float\n    ).to(dtype)\n    ref_out = _ref(\n        query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type\n    ).numpy()\n    fused_out = _fused_mha(\n        query, key, value, num_heads, attn_bias=attn_bias, attn_mask_type=attn_mask_type\n    ).numpy()\n    test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2))\n\n\ndef _test_fused_multi_head_attention_inference_variable_length(\n    test_case,\n    batch_size,\n    num_heads,\n    query_seq_len,\n    kv_seq_len,\n    query_head_size,\n    value_head_size,\n    dtype,\n    query_layout,\n    key_layout,\n    value_layout,\n    use_kv_seq_len,\n    attn_mask_type=\"none\",\n    causal_diagonal_offset=0,\n):\n    query = flow.randn(\n        (batch_size, query_seq_len, num_heads, query_head_size),\n        device=\"cuda\",\n        dtype=flow.float,\n    ).to(dtype)\n    key = flow.randn(\n        (batch_size, kv_seq_len, num_heads, query_head_size),\n        device=\"cuda\",\n        dtype=flow.float,\n    ).to(dtype)\n    value = flow.randn(\n        (batch_size, kv_seq_len, num_heads, value_head_size),\n        device=\"cuda\",\n        dtype=flow.float,\n    ).to(dtype)\n\n    query_seq_len_t = flow.randint(\n        low=1,\n        high=query.shape[1],\n        size=(query.shape[0],),\n        device=\"cuda\",\n        dtype=flow.int32,\n    )\n    key_seq_len_t = flow.randint(\n        low=1, high=key.shape[1], size=(key.shape[0],), device=\"cuda\", dtype=flow.int32\n    )\n\n    fused_out = _fused_mha(\n        query,\n        key,\n        value,\n        num_heads,\n        attn_mask_type=attn_mask_type,\n        causal_diagonal_offset=causal_diagonal_offset,\n        query_layout=query_layout,\n        key_layout=key_layout,\n        value_layout=value_layout,\n        output_layout=\"(BM)(HK)\",\n        query_seq_len=query_seq_len_t,\n        key_seq_len=key_seq_len_t,\n        use_kv_seq_len=use_kv_seq_len,\n    )\n    ref_out = _ref(\n        query,\n        key,\n        value,\n        num_heads,\n        attn_mask_type=attn_mask_type,\n        causal_diagonal_offset=causal_diagonal_offset,\n        query_seq_len=query_seq_len_t,\n        key_seq_len=key_seq_len_t,\n    )\n    ref_out = ref_out.view(batch_size, query_seq_len, num_heads, value_head_size)\n    ref_out = _to_layout([ref_out], \"(BM)HK\", 0, seq_len=query_seq_len_t)\n    ref_out = ref_out.view(ref_out.shape[0], -1)\n\n    test_case.assertTrue(\n        np.allclose(ref_out.numpy(), fused_out.numpy(), atol=1e-2, rtol=1e-2)\n    )\n\n\n@unittest.skipIf(True, \"skip test\")\n@flow.unittest.skip_unless_1n1d()\nclass TestFusedMultiHeadAttentionInference(flow.unittest.TestCase):\n    def test_multi_head_attention_inference(test_case):\n        # test_case,batch_size, num_heads,query_seq_len, kv_seq_len,query_head_size,value_head_size,dtype\n        _test_fused_multi_head_attention_inference(\n            test_case, 2, 8, 4096, 4096, 40, 40, flow.float16\n        )\n        _test_fused_multi_head_attention_inference(\n            test_case, 2, 8, 4096, 77, 40, 40, flow.float16\n        )\n        _test_fused_multi_head_attention_inference(\n            test_case, 2, 8, 
1024, 1024, 80, 80, flow.float16\n        )\n        _test_fused_multi_head_attention_inference(\n            test_case, 2, 8, 1024, 77, 80, 80, flow.float16\n        )\n        _test_fused_multi_head_attention_inference(\n            test_case, 2, 8, 256, 256, 160, 160, flow.float16\n        )\n        _test_fused_multi_head_attention_inference(\n            test_case, 2, 8, 256, 77, 160, 160, flow.float16\n        )\n\n        _test_fused_multi_head_attention_inference(\n            test_case, 2, 8, 4096, 4096, 40, 40, flow.float\n        )\n        _test_fused_multi_head_attention_inference(\n            test_case, 2, 8, 4096, 77, 40, 40, flow.float\n        )\n        _test_fused_multi_head_attention_inference(\n            test_case, 2, 8, 1024, 1024, 80, 80, flow.float\n        )\n        _test_fused_multi_head_attention_inference(\n            test_case, 2, 8, 1024, 77, 80, 80, flow.float\n        )\n        _test_fused_multi_head_attention_inference(\n            test_case, 2, 8, 256, 256, 160, 160, flow.float\n        )\n        _test_fused_multi_head_attention_inference(\n            test_case, 2, 8, 256, 77, 160, 160, flow.float\n        )\n        _test_fused_multi_head_attention_inference(\n            test_case,\n            1,\n            8,\n            4,\n            8,\n            16,\n            16,\n            flow.float,\n            attn_mask_type=\"causal_from_top_left\",\n            causal_diagonal_offset=4,\n        )\n\n    def test_multi_head_attention_inference_with_attn_bias(test_case):\n        # test_case,batch_size, num_heads,query_seq_len, kv_seq_len,query_head_size,value_head_size,dtype\n        _test_fused_multi_head_attention_inference_with_attn_bias(\n            test_case, 2, 8, 4096, 4096, 40, 40, flow.float16\n        )\n        _test_fused_multi_head_attention_inference_with_attn_bias(\n            test_case, 2, 8, 4096, 4096, 40, 40, flow.float\n        )\n        _test_fused_multi_head_attention_inference_with_attn_bias(\n            test_case, 2, 8, 4096, 4096, 40, 40, flow.float16, \"causal_from_top_left\"\n        )\n        _test_fused_multi_head_attention_inference_with_attn_bias(\n            test_case, 2, 8, 4096, 4096, 40, 40, flow.float, \"causal_from_bottom_right\"\n        )\n        _test_fused_multi_head_attention_inference_with_attn_bias(\n            test_case, 2, 8, 4096, 80, 40, 40, flow.float16\n        )\n        _test_fused_multi_head_attention_inference_with_attn_bias(\n            test_case, 2, 8, 4096, 80, 40, 40, flow.float\n        )\n        _test_fused_multi_head_attention_inference_with_attn_bias(\n            test_case, 2, 8, 4096, 80, 40, 40, flow.float16, \"causal_from_top_left\"\n        )\n        _test_fused_multi_head_attention_inference_with_attn_bias(\n            test_case, 2, 8, 80, 4096, 40, 40, flow.float16, \"causal_from_bottom_right\"\n        )\n        _test_fused_multi_head_attention_inference_with_attn_bias(\n            test_case, 2, 8, 4096, 80, 40, 40, flow.float, \"causal_from_top_left\"\n        )\n        _test_fused_multi_head_attention_inference_with_attn_bias(\n            test_case, 2, 8, 4096, 77, 40, 40, flow.float, \"causal_from_top_left\"\n        )\n\n    def test_multi_head_attention_inference_with_layout(test_case):\n        layouts = [\n            \"BM(HK)\",\n            \"BMHK\",\n            \"MBHK\",\n            \"BHMK\",\n            \"MB(HK)\",\n            \"BM(H3K)\",\n            \"BM(H2K)\",\n            \"MB(H3K)\",\n            \"MB(H2K)\",\n        ]\n        for 
query_layout, key_layout, value_layout in itertools.product(\n            layouts, layouts, layouts\n        ):\n            if query_layout == \"BM(H2K)\" or query_layout == \"MB(H2K)\":\n                continue\n            _test_fused_multi_head_attention_inference(\n                test_case,\n                2,\n                8,\n                256,\n                256,\n                160,\n                160,\n                flow.float16,\n                query_layout=query_layout,\n                key_layout=key_layout,\n                value_layout=value_layout,\n            )\n\n    def test_multi_head_attention_inference_with_output_layout(test_case):\n        layouts = [\n            \"BM(HK)\",\n            \"MB(HK)\",\n        ]\n        for output_layout in layouts:\n            _test_fused_multi_head_attention_inference(\n                test_case,\n                2,\n                8,\n                256,\n                256,\n                160,\n                160,\n                flow.float16,\n                output_layout=output_layout,\n            )\n            _test_fused_multi_head_attention_inference(\n                test_case,\n                1,\n                8,\n                256,\n                256,\n                160,\n                160,\n                flow.float16,\n                output_layout=output_layout,\n            )\n\n    def test_multi_head_attention_inference_variable_length(test_case):\n        # test_case,batch_size, num_heads,query_seq_len, kv_seq_len,query_head_size,value_head_size,dtype\n        layouts = [\"(BM)HK\", \"(BM)(HK)\", \"(BM)(H2K)\", \"(BM)(H3K)\"]\n        for (\n            query_layout,\n            key_layout,\n            value_layout,\n            use_kv_seq_len,\n        ) in itertools.product(layouts, layouts, layouts, (False, True)):\n            if query_layout == \"(BM)(H2K)\":\n                continue\n            _test_fused_multi_head_attention_inference_variable_length(\n                test_case,\n                2,\n                8,\n                16,\n                16,\n                40,\n                40,\n                flow.float16,\n                query_layout=query_layout,\n                key_layout=key_layout,\n                value_layout=value_layout,\n                use_kv_seq_len=use_kv_seq_len,\n            )\n            if (\n                query_layout == \"(BM)(H3K)\"\n                or key_layout == \"(BM)(H3K)\"\n                or value_layout == \"(BM)(H3K)\"\n            ):\n                continue\n            _test_fused_multi_head_attention_inference_variable_length(\n                test_case,\n                2,\n                8,\n                16,\n                32,\n                40,\n                40,\n                flow.float16,\n                query_layout=query_layout,\n                key_layout=key_layout,\n                value_layout=value_layout,\n                use_kv_seq_len=use_kv_seq_len,\n            )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestFusedAttentionConcatPastKeyValue(flow.unittest.TestCase):\n    def test_fused_attention_concat_past_key_value(test_case):\n        kv_layouts = [\n            \"BM(HK)\",\n            \"BMHK\",\n            \"MBHK\",\n            \"BHMK\",\n            \"MB(HK)\",\n            \"BM(H3K)\",\n            # \"BM(H2K)\",\n            # \"MB(H3K)\",\n            \"MB(H2K)\",\n        
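    # the layouts commented out above are not exercised by this test\n        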
]\n\n        past_layouts = [\n            \"BM(HK)\",\n            \"BMHK\",\n            # \"MBHK\",\n            # \"BHMK\",\n            \"MB(HK)\",\n        ]\n\n        types = [flow.float16]\n        for (\n            past_key_layout,\n            past_value_layout,\n            key_layout,\n            value_layout,\n            dtype,\n        ) in itertools.product(\n            past_layouts, past_layouts, kv_layouts, kv_layouts, types\n        ):\n            _test_fused_attention_concat_past_key_value(\n                test_case,\n                dtype,\n                1,\n                127,\n                1,\n                40,\n                128,\n                past_key_layout=past_key_layout,\n                past_value_layout=past_value_layout,\n                key_layout=key_layout,\n                value_layout=value_layout,\n            )\n        _test_fused_attention_concat_past_key_value(\n            test_case,\n            flow.float,\n            1,\n            0,\n            1,\n            40,\n            128,\n            past_key_layout=\"BMHK\",\n            past_value_layout=\"BMHK\",\n            key_layout=\"BMHK\",\n            value_layout=\"BMHK\",\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_bias_add_dropout.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport os\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_fused_bias_add_dropout(test_case, shape, axis, drop_prob):\n    x = np.random.randn(*shape)\n    bias = np.random.randn(shape[axis])\n    # fused version only support in GPU\n    fused_x_tensor = flow.Tensor(x).to(\"cuda\")\n    fused_x_tensor.requires_grad = True\n    fused_bias_tensor = flow.Tensor(bias).to(\"cuda\")\n    fused_bias_tensor.requires_grad = True\n    fused_out = flow._C.fused_bias_add_dropout(\n        fused_x_tensor, fused_bias_tensor, p=drop_prob, axis=axis\n    )\n\n    origin_x_tensor = flow.Tensor(x).to(\"cuda\")\n    origin_x_tensor.requires_grad = True\n    origin_bias_tensor = flow.Tensor(bias).to(\"cuda\")\n    origin_bias_tensor.requires_grad = True\n\n    origin_dropout = flow.nn.Dropout(p=drop_prob)\n    origin_out = origin_dropout(\n        flow._C.bias_add(origin_x_tensor, origin_bias_tensor, axis=axis)\n    )\n\n    total_out = fused_out.sum() + origin_out.sum()\n    total_out.backward()\n\n    test_case.assertTrue(\n        np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4)\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_x_tensor.grad.numpy(),\n            origin_x_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_bias_tensor.grad.numpy(),\n            origin_bias_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test gpu cases\")\nclass TestFusedBiasAddDropout(flow.unittest.TestCase):\n    def test_fuse_bias_add_dropout(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_fused_bias_add_dropout]\n        arg_dict[\"shape\"] = [(16, 64, 72), (32, 16, 48)]\n        arg_dict[\"axis\"] = [0, 1, 2, -1, -2, -3]\n        arg_dict[\"drop_prob\"] = [0.0, 1.0]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_bias_add_gelu.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport os\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_fused_bias_add_gelu(test_case, channel, axis):\n    x = np.random.randn(4, channel, 8, 10)\n    bias = np.random.randn(channel)\n    # fused version only support in GPU\n    fused_x_tensor = flow.Tensor(x).to(\"cuda\")\n    fused_x_tensor.requires_grad = True\n    fused_bias_tensor = flow.Tensor(bias).to(\"cuda\")\n    fused_bias_tensor.requires_grad = True\n    fused_out = flow._C.fused_bias_add_gelu(\n        fused_x_tensor, fused_bias_tensor, axis=axis\n    )\n\n    origin_x_tensor = flow.Tensor(x).to(\"cuda\")\n    origin_x_tensor.requires_grad = True\n    origin_bias_tensor = flow.Tensor(bias).to(\"cuda\")\n    origin_bias_tensor.requires_grad = True\n    origin_out = flow.gelu(\n        flow._C.bias_add(origin_x_tensor, origin_bias_tensor, axis=axis)\n    )\n\n    total_out = fused_out.sum() + origin_out.sum()\n    total_out.backward()\n\n    test_case.assertTrue(\n        np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4)\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_x_tensor.grad.numpy(),\n            origin_x_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_bias_tensor.grad.numpy(),\n            origin_bias_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test gpu cases\")\nclass TestFusedBiasAddGelu(flow.unittest.TestCase):\n    def test_gather(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_fused_bias_add_gelu]\n        arg_dict[\"channel\"] = [2, 4, 6, 8]\n        arg_dict[\"axis\"] = [1]\n\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_bias_add_scale_mask_softmax_dropout.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport os\nimport numpy as np\nfrom collections import OrderedDict\n\nimport torch\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _torch_bias_add_scale_mask_softmax_dropout(x, bias, mask, fill, scale, p):\n    masked = (x + bias) * mask * scale\n    unmask = (1 - mask.int()).bool()\n    masked.masked_fill_(unmask, fill)\n    softmax_y = torch.nn.functional.softmax(masked, dim=-1)\n    y = torch.nn.functional.dropout(softmax_y, p)\n    return y, softmax_y\n\n\ndef _test_bias_add_fused_scale_mask_softmax_dropout(\n    test_case,\n    input_shape,\n    bias_shape,\n    mask_shape,\n    input_dtype=flow.float32,\n    mask_dtype=flow.bool,\n    fill=-10000,\n    scale=1.0,\n    p=0.0,\n    device=\"cuda\",\n):\n    # print(f\"{'=' * 40} test case {'=' * 40}\")\n    # print(f\"input_shap={input_shape}\")\n    # print(f\"bias_shape={bias_shape}\")\n    # print(f\"mask_shape={mask_shape}\")\n    # print(f\"input_dtype={input_dtype}\")\n    # print(f\"mask_dtype={mask_dtype}\")\n    # print(f\"fill={fill}\")\n    # print(f\"scale={scale}\")\n    # print(f\"p={p}\")\n\n    np_input = np.random.randn(*input_shape).astype(np.float32)\n    np_bias = np.random.randn(*bias_shape).astype(np.float32)\n    np_mask = np.random.randint(0, 2, size=mask_shape).astype(np.int32)\n    np_rand_init_grad = np.random.randn(*input_shape).astype(np.float32)\n\n    torch_input = torch.tensor(np_input).to(device=device)\n    torch_bias = torch.tensor(np_bias).to(device=device)\n    torch_mask = torch.tensor(np_mask).to(device=device).bool()\n    torch_rand_init_grad = torch.tensor(np_rand_init_grad).to(device=device)\n    torch_input.requires_grad_(True)\n    torch_bias.requires_grad_(True)\n    torch_output, torch_softmax_output = _torch_bias_add_scale_mask_softmax_dropout(\n        torch_input, torch_bias, torch_mask, fill, scale, p\n    )\n    (torch_output * torch_rand_init_grad).sum().backward()\n    torch_input_grad = torch_input.grad.detach().cpu()\n    torch_bias_grad = torch_bias.grad.detach().cpu()\n    torch_output = torch_output.detach().cpu()\n    torch_softmax_output = torch_softmax_output.detach().cpu()\n\n    input = flow.tensor(np_input, dtype=input_dtype, device=device)\n    bias = flow.tensor(np_bias, dtype=input_dtype, device=device)\n    mask = flow.tensor(np_mask, dtype=mask_dtype, device=device)\n    rand_init_grad = flow.tensor(np_rand_init_grad, dtype=input_dtype, device=device)\n    input.requires_grad_(True)\n    bias.requires_grad_(True)\n    output, softmax_output = flow._C.fused_bias_add_scale_mask_softmax_dropout(\n        input, bias, mask, fill_value=fill, scale=scale, p=p,\n    )\n    (output * rand_init_grad).sum().backward()\n    input_grad = input.grad.detach().cpu()\n    bias_grad = bias.grad.detach().cpu()\n    output = output.to(dtype=flow.float32, device=\"cpu\")\n    
softmax_output = softmax_output.to(dtype=flow.float32, device=\"cpu\")\n\n    def compare(a, b, a_name, b_name, atol=1e-5, rtol=1e-8):\n        test_case.assertTrue(\n            np.allclose(a.numpy(), b.numpy(), atol=atol, rtol=rtol),\n            f\"\\n{a_name}:\\n{a.numpy()}\\n{'-' * 80}\\n{b_name}:\\n{b.numpy()}\\n{'*' * 80}\\ndiff:\\n{a.numpy() - b.numpy()}\\n{a_name} vs. {b_name} max_diff:\\n{np.max(np.abs(a.numpy() - b.numpy()))}\",\n        )\n\n    if input_dtype == flow.float16:\n        compare(output, torch_output, \"output\", \"torch_output\", atol=1e-3, rtol=1e-2)\n        compare(\n            softmax_output,\n            torch_softmax_output,\n            \"softmax_output\",\n            \"torch_softmax_output\",\n            atol=1e-3,\n            rtol=1e-2,\n        )\n        compare(\n            input_grad,\n            torch_input_grad,\n            \"input_grad\",\n            \"torch_input_grad\",\n            atol=1e-2,\n            rtol=1e-2,\n        )\n        compare(\n            bias_grad,\n            torch_bias_grad,\n            \"bias_grad\",\n            \"torch_bias_grad\",\n            atol=1e-2,\n            rtol=1e-2,\n        )\n    else:\n        compare(output, torch_output, \"output\", \"torch_output\")\n        compare(\n            softmax_output,\n            torch_softmax_output,\n            \"softmax_output\",\n            \"torch_softmax_output\",\n        )\n        compare(input_grad, torch_input_grad, \"input_grad\", \"torch_input_grad\")\n        compare(bias_grad, torch_bias_grad, \"bias_grad\", \"torch_bias_grad\")\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test gpu cases\")\nclass TestFusedBiasAddScaleMaskSoftmaxDropout(flow.unittest.TestCase):\n    def test_real_case(test_case):\n        args_dict = OrderedDict()\n        args_dict[\"input_shape\"] = [[4, 12, 8, 8]]\n        args_dict[\"bias_shape\"] = [[1, 12, 8, 8]]\n        args_dict[\"mask_shape\"] = [[4, 1, 1, 8]]\n        args_dict[\"input_dtype\"] = [flow.float16, flow.float32]\n        args_dict[\"mask_dtype\"] = [flow.bool]\n        args_dict[\"fill\"] = [-10000.0]\n        args_dict[\"scale\"] = [1.0, 2.0, 4.0]\n        args_dict[\"p\"] = [0.0, 1.0]\n\n        for kwarg in GenArgDict(args_dict):\n            _test_bias_add_fused_scale_mask_softmax_dropout(test_case, **kwarg)\n\n    def test_different_broadcast_dim(test_case):\n        _test_bias_add_fused_scale_mask_softmax_dropout(\n            test_case, [4, 2, 3], [1, 2, 3], [4, 1, 3]\n        )\n\n    def test_same_broadcast_dim(test_case):\n        _test_bias_add_fused_scale_mask_softmax_dropout(\n            test_case, [4, 2, 3], [1, 2, 3], [1, 2, 3]\n        )\n\n    def test_broadcast_bias(test_case):\n        _test_bias_add_fused_scale_mask_softmax_dropout(\n            test_case, [4, 2, 3], [1, 1, 3], [4, 2, 3]\n        )\n\n    def test_broadcast_mask(test_case):\n        _test_bias_add_fused_scale_mask_softmax_dropout(\n            test_case, [4, 2, 3], [4, 2, 3], [4, 1, 3]\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_center.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef torch_center(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2):\n    return (\n        (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2\n    ) / 4\n\n\ndef _test_fused_get_center_dist_impl(test_case, device, shape):\n    def compare(a, b, rtol=1e-5, atol=1e-5):\n        test_case.assertTrue(\n            np.allclose(\n                a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol\n            ),\n            f\"\\na\\n{a.detach().cpu().numpy()}\\n{'-' * 80}\\nb:\\n{b.detach().cpu().numpy()}\\n{'*' * 80}\\ndiff:\\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}\",\n        )\n\n    x = []\n    torch_x = []\n    for _ in range(8):\n        tmp = flow.tensor(\n            np.random.randn(*shape),\n            dtype=flow.float32,\n            device=flow.device(device),\n            requires_grad=True,\n        )\n        x.append(tmp)\n        torch_x.append(\n            torch.tensor(\n                tmp.numpy(),\n                dtype=torch.float32,\n                device=torch.device(device),\n                requires_grad=True,\n            )\n        )\n    b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2 = (\n        x[0],\n        x[1],\n        x[2],\n        x[3],\n        x[4],\n        x[5],\n        x[6],\n        x[7],\n    )\n    (\n        torch_b1_x1,\n        torch_b1_x2,\n        torch_b2_x1,\n        torch_b2_x2,\n        torch_b1_y1,\n        torch_b1_y2,\n        torch_b2_y1,\n        torch_b2_y2,\n    ) = (\n        torch_x[0],\n        torch_x[1],\n        torch_x[2],\n        torch_x[3],\n        torch_x[4],\n        torch_x[5],\n        torch_x[6],\n        torch_x[7],\n    )\n    rho2 = flow._C.fused_get_center_dist(\n        b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2\n    )\n    torch_rho2 = torch_center(\n        torch_b1_x1,\n        torch_b1_x2,\n        torch_b2_x1,\n        torch_b2_x2,\n        torch_b1_y1,\n        torch_b1_y2,\n        torch_b2_y1,\n        torch_b2_y2,\n    )\n    compare(rho2, torch_rho2)\n\n    rho2.sum().backward()\n    torch_rho2.sum().backward()\n    compare(b1_x1.grad, torch_b1_x1.grad)\n    compare(b1_x2.grad, torch_b1_x2.grad)\n    compare(b2_x1.grad, torch_b2_x1.grad)\n    compare(b2_x2.grad, torch_b2_x2.grad)\n    compare(b1_y1.grad, torch_b1_y1.grad)\n    compare(b1_y2.grad, torch_b1_y2.grad)\n    compare(b2_y1.grad, torch_b2_y1.grad)\n    compare(b2_y2.grad, torch_b2_y2.grad)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGetCenterDistModule(flow.unittest.TestCase):\n    def test_fused_get_center_dist(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_fused_get_center_dist_impl]\n        
arg_dict[\"device\"] = [\"cuda\"]\n        arg_dict[\"shape\"] = [(583, 1), (759, 1), (1234, 1)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_codegeex_qkv_reshape.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_codegeex_qkv_reshape_impl(test_case, device, shape, num_attention_heads):\n    query = flow.randn(shape).to(\"cuda\")\n    key = flow.randn(shape).to(\"cuda\")\n    value = flow.randn(shape).to(\"cuda\")\n    new_shape = (\n        shape[0],\n        shape[1],\n        num_attention_heads,\n        shape[2] / num_attention_heads,\n    )\n    new_query = query.view(new_shape)\n    new_query = new_query.contiguous()\n    new_key = key.view(new_shape)\n    new_key = new_key.contiguous()\n    new_value = value.view(new_shape)\n    new_value = new_value.contiguous()\n    (\n        fused_new_query,\n        fused_new_key,\n        fused_new_value,\n    ) = flow._C.fused_codegeex_qkv_reshape(query, key, value, num_attention_heads)\n\n    def compare(a, b, rtol=1e-5, atol=1e-5):\n        test_case.assertTrue(\n            np.allclose(\n                a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol\n            ),\n            f\"\\na\\n{a.detach().cpu().numpy()}\\n{'-' * 80}\\nb:\\n{b.detach().cpu().numpy()}\\n{'*' * 80}\\ndiff:\\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}\",\n        )\n\n    compare(new_query, fused_new_query)\n    compare(new_key, fused_new_key)\n    compare(new_value, fused_new_value)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFusedCodegeexQkvReshapeModule(flow.unittest.TestCase):\n    def test_codegeex_qkv_reshape(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_codegeex_qkv_reshape_impl]\n        arg_dict[\"device\"] = [\"cuda\"]\n        arg_dict[\"shape\"] = [(32, 8, 16), (32, 8, 32)]\n        arg_dict[\"num_attention_heads\"] = [(4), (8)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_cross_interaction.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport os\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\n\n\ndef _test_fused_cross_feature_interaction_v1(\n    test_case, batchsize, in_feature, dtype, device,\n):\n    x = np.random.uniform(low=-1, high=1, size=(batchsize, in_feature))\n    weight = np.random.uniform(low=-1, high=1, size=(1, in_feature))\n    bias = np.random.uniform(low=-1, high=1, size=(in_feature))\n    x0 = np.random.uniform(low=-1, high=1, size=(batchsize, in_feature))\n\n    fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True)\n    naive_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True)\n    fused_weight = flow.tensor(weight, dtype=dtype, device=device, requires_grad=True)\n    naive_weight = flow.tensor(weight, dtype=dtype, device=device, requires_grad=True)\n    fused_bias = flow.tensor(bias, dtype=dtype, device=device, requires_grad=True)\n    naive_bias = flow.tensor(bias, dtype=dtype, device=device, requires_grad=True)\n    fused_x0 = flow.tensor(x0, dtype=dtype, device=device, requires_grad=True)\n    naive_x0 = flow.tensor(x0, dtype=dtype, device=device, requires_grad=True)\n\n    fused_out = flow._C.fused_cross_feature_interaction(\n        fused_x, fused_weight, fused_x0, fused_bias, \"vector\"\n    )\n\n    naive_out = (\n        flow._C.matmul(naive_x, naive_weight, transpose_b=True) * naive_x0 + naive_bias\n    ) + naive_x\n\n    total_out = fused_out.sum() + naive_out.sum()\n    total_out.backward()\n\n    test_case.assertTrue(\n        np.allclose(fused_out.numpy(), naive_out.numpy(), atol=1e-4, rtol=1e-4)\n    )\n    test_case.assertTrue(\n        np.allclose(fused_x.grad.numpy(), naive_x.grad.numpy(), atol=1e-4, rtol=1e-4,)\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_weight.grad.numpy(), naive_weight.grad.numpy(), atol=1e-4, rtol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(fused_x0.grad.numpy(), naive_x0.grad.numpy(), atol=1e-4, rtol=1e-4,)\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_bias.grad.numpy(), naive_bias.grad.numpy(), atol=1e-4, rtol=1e-4,\n        )\n    )\n\n\ndef _test_fused_cross_feature_interaction_v2(\n    test_case, batchsize, in_feature, dtype, device,\n):\n    x = np.random.uniform(low=-1, high=1, size=(batchsize, in_feature))\n    weight = np.random.uniform(low=-1, high=1, size=(in_feature, in_feature))\n    bias = np.random.uniform(low=-1, high=1, size=(in_feature))\n    x0 = np.random.uniform(low=-1, high=1, size=(batchsize, in_feature))\n\n    fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True)\n    naive_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True)\n    fused_weight = flow.tensor(weight, dtype=dtype, device=device, requires_grad=True)\n    naive_weight = flow.tensor(weight, dtype=dtype, 
device=device, requires_grad=True)\n    fused_bias = flow.tensor(bias, dtype=dtype, device=device, requires_grad=True)\n    naive_bias = flow.tensor(bias, dtype=dtype, device=device, requires_grad=True)\n    fused_x0 = flow.tensor(x0, dtype=dtype, device=device, requires_grad=True)\n    naive_x0 = flow.tensor(x0, dtype=dtype, device=device, requires_grad=True)\n\n    fused_out = flow._C.fused_cross_feature_interaction(\n        fused_x, fused_weight, fused_x0, fused_bias, \"matrix\"\n    )\n\n    naive_out = (\n        flow._C.bias_add(\n            flow._C.matmul(naive_x, naive_weight, transpose_b=True), naive_bias, axis=1\n        )\n        * naive_x0\n        + naive_x\n    )\n\n    total_out = fused_out.sum() + naive_out.sum()\n    total_out.backward()\n\n    test_case.assertTrue(\n        np.allclose(fused_out.numpy(), naive_out.numpy(), atol=1e-4, rtol=1e-4)\n    )\n    test_case.assertTrue(\n        np.allclose(fused_x.grad.numpy(), naive_x.grad.numpy(), atol=1e-4, rtol=1e-4,)\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_weight.grad.numpy(), naive_weight.grad.numpy(), atol=1e-4, rtol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(fused_x0.grad.numpy(), naive_x0.grad.numpy(), atol=1e-4, rtol=1e-4,)\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_bias.grad.numpy(), naive_bias.grad.numpy(), atol=1e-4, rtol=1e-4,\n        )\n    )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestFusedCrossFeatureInteraction(flow.unittest.TestCase):\n    def test_fused_cross_feature_interaction_v1(test_case):\n        args_dict = OrderedDict()\n        args_dict[\"test_fun\"] = [_test_fused_cross_feature_interaction_v1]\n        args_dict[\"batchsize\"] = [1, 2, 4]\n        args_dict[\"in_feature\"] = [32, 64, 96, 128]\n        args_dict[\"dtype\"] = [flow.float32]\n        args_dict[\"device\"] = [\"cuda\"]\n\n        for arg in GenArgList(args_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_fused_cross_feature_interaction_v2(test_case):\n        args_dict = OrderedDict()\n        args_dict[\"test_fun\"] = [_test_fused_cross_feature_interaction_v2]\n        args_dict[\"batchsize\"] = [1, 2, 4]\n        args_dict[\"in_feature\"] = [32, 64, 96, 128]\n        args_dict[\"dtype\"] = [flow.float32]\n        args_dict[\"device\"] = [\"cuda\"]\n\n        for arg in GenArgList(args_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_dot_feature_interaction.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgDict\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nimport os\n\n\ndef _test_fused_dot_feature_interaction(\n    test_case,\n    embedding_size,\n    self_interaction=False,\n    output_concat=True,\n    output_padding=0,\n    dtype=flow.float32,\n    device_type=\"cuda\",\n):\n    batch_size = 100\n    dims = 26\n    if dtype == flow.float16:\n        np_dtype = np.float16\n    else:\n        np_dtype = np.float32\n    feature_0_np = np.random.rand(batch_size, embedding_size).astype(np_dtype)\n    feature_1_np = np.random.rand(batch_size, 26, embedding_size).astype(np_dtype)\n    feature_0_tensor = flow.tensor(feature_0_np, device=\"cuda\", requires_grad=True)\n    feature_1_tensor = flow.tensor(feature_1_np, device=\"cuda\", requires_grad=True)\n    if self_interaction:\n        offset = 1\n    else:\n        offset = 0\n    li = flow.tensor([i for i in range(27) for j in range(i + offset)])\n    lj = flow.tensor([j for i in range(27) for j in range(i + offset)])\n    T = flow.cat(\n        [\n            flow.reshape(feature_0_tensor, (batch_size, 1, embedding_size)),\n            feature_1_tensor,\n        ],\n        dim=1,\n    )\n    Z = flow.matmul(T, T, transpose_b=True)\n    # gather_nd not support half, so cast to float32\n    Z = flow.cast(Z, flow.float32)\n    Zflat = Z[:, li, lj]\n    Zflat = flow.cast(Zflat, dtype)\n    if output_concat:\n        R = flow.cat([feature_0_tensor, Zflat], dim=1)\n    else:\n        R = Zflat\n    if output_padding != 0:\n        padding_tensor = flow.tensor(\n            np.zeros((batch_size, output_padding)).astype(np_dtype),\n            device=\"cuda\",\n            requires_grad=False,\n        )\n        R = flow.cat([R, padding_tensor], dim=1)\n    loss = R.sum()\n    loss.backward()\n\n    fused_feature_0_tensor = flow.tensor(\n        feature_0_np, device=\"cuda\", requires_grad=True\n    )\n    fused_feature_1_tensor = flow.tensor(\n        feature_1_np, device=\"cuda\", requires_grad=True\n    )\n    if output_concat:\n        output_concat_tensor = fused_feature_0_tensor\n    else:\n        output_concat_tensor = None\n    fused_R = flow._C.fused_dot_feature_interaction(\n        [\n            fused_feature_0_tensor.reshape(batch_size, 1, embedding_size),\n            fused_feature_1_tensor,\n        ],\n        output_concat=output_concat_tensor,\n        self_interaction=self_interaction,\n        output_padding=output_padding,\n        pooling=\"none\",\n    )\n    fused_loss = fused_R.sum()\n    fused_loss.backward()\n    test_case.assertTrue(\n        np.allclose(\n            feature_0_tensor.grad.numpy(),\n            fused_feature_0_tensor.grad.numpy(),\n            rtol=1e-3,\n            atol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            
feature_1_tensor.grad.numpy(),\n            fused_feature_1_tensor.grad.numpy(),\n            rtol=1e-3,\n            atol=1e-4,\n        )\n    )\n\n    test_case.assertTrue(np.allclose(fused_R.numpy(), R.numpy(), rtol=1e-3, atol=1e-3))\n\n\ndef _test_fused_dot_feature_interaction_pooling_sum(\n    test_case, dtype, feature_dims, embedding_size, device_type=\"cuda\",\n):\n    batch_size = 100\n    if dtype == flow.float16:\n        np_dtype = np.float16\n    else:\n        np_dtype = np.float32\n\n    feature_tensor_list = []\n    fused_feature_tensor_list = []\n    for dim in feature_dims:\n        feature_np = np.random.uniform(-1, 1, (batch_size, dim, embedding_size)).astype(\n            np_dtype\n        )\n        feature_tensor = flow.tensor(feature_np, device=\"cuda\", requires_grad=True)\n        feature_tensor_list.append(feature_tensor)\n        fused_feature_tensor = flow.tensor(\n            feature_np, device=\"cuda\", requires_grad=True\n        )\n        fused_feature_tensor_list.append(fused_feature_tensor)\n\n    concat = flow.cat(feature_tensor_list, dim=1,)\n    if dtype == flow.float16:\n        concat = flow.cast(concat, flow.float)\n    sum_then_square = flow.sum(concat, dim=1) ** 2\n    square_then_sum = flow.sum(concat ** 2, dim=1)\n    bi_interaction = (sum_then_square - square_then_sum) * 0.5\n    if dtype == flow.float16:\n        bi_interaction = flow.cast(bi_interaction, flow.float16)\n    R = flow.sum(bi_interaction, dim=-1, keepdim=True)\n    loss = R.sum()\n    loss.backward()\n\n    fused_R = flow._C.fused_dot_feature_interaction(\n        fused_feature_tensor_list, pooling=\"sum\",\n    )\n    fused_loss = fused_R.sum()\n    fused_loss.backward()\n    if dtype == flow.float16:\n        tol = 1e-2\n    else:\n        tol = 1e-3\n    for i in range(len(feature_dims)):\n        test_case.assertTrue(\n            np.allclose(\n                feature_tensor_list[i].grad.numpy(),\n                fused_feature_tensor_list[i].grad.numpy(),\n                rtol=1e-3,\n                atol=1e-3,\n            )\n        )\n    test_case.assertTrue(np.allclose(fused_R.numpy(), R.numpy(), rtol=tol, atol=tol))\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass FusedDotFeatureInteractionTestCase(flow.unittest.TestCase):\n    def test_fused_dot_feature_interaction(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"embedding_size\"] = [128, 127, 16, 15]\n        arg_dict[\"self_interaction\"] = [False, True]\n        arg_dict[\"output_concat\"] = [True, False]\n        arg_dict[\"output_padding\"] = [1, 0]\n        arg_dict[\"dtype\"] = [flow.float16, flow.float32]\n        for kwargs in GenArgDict(arg_dict):\n            _test_fused_dot_feature_interaction(test_case, **kwargs)\n\n    def test_fused_dot_feature_interaction_pooling_sum(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"dtype\"] = [flow.float16, flow.float32]\n        arg_dict[\"feature_dims\"] = [[39], [13, 26], [1, 10, 3]]\n        arg_dict[\"embedding_size\"] = [16, 11, 12]\n        for kwargs in GenArgDict(arg_dict):\n            _test_fused_dot_feature_interaction_pooling_sum(test_case, **kwargs)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_gelu_mul.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport numpy as np\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_fused_fast_gelu_mul(test_case, shape, dtype=flow.float32):\n    x = flow.randn(*shape).to(dtype=dtype, device=\"cuda\").requires_grad_(True)\n    multiplier = flow.randn(*shape).to(dtype=dtype, device=\"cuda\").requires_grad_(True)\n    y = flow.nn.functional.gelu(x, approximate=\"tanh\") * multiplier\n    y.mean().backward()\n    x_grad = x.grad.detach().cpu()\n    m_grad = multiplier.grad.detach().cpu()\n    y = y.detach().cpu()\n\n    fused_x = x.detach().clone().requires_grad_(True)\n    fused_multiplier = multiplier.detach().clone().requires_grad_(True)\n    fused_y = flow._C.fused_fast_gelu_mul(fused_x, fused_multiplier)\n    fused_y.mean().backward()\n    fused_x_grad = fused_x.grad.detach().cpu()\n    fused_m_grad = fused_multiplier.grad.detach().cpu()\n    fused_y = fused_y.detach().cpu()\n\n    def compare(a, b, rtol=1e-5, atol=1e-8):\n        test_case.assertTrue(\n            np.allclose(a.numpy(), b.numpy(), rtol=rtol, atol=atol),\n            f\"\\na\\n{a.numpy()}\\n{'-' * 80}\\nb:\\n{b.numpy()}\\n{'*' * 80}\\ndiff:\\n{a.numpy() - b.numpy()}\",\n        )\n\n    # print(f\"\\n{'=' * 20} shape={shape} dtype={dtype} {'=' * 20}\")\n    if dtype == flow.float16:\n        compare(fused_y, y, 1e-2, 1e-3)\n        compare(fused_x_grad, x_grad, 1e-4, 1e-3)\n        compare(fused_m_grad, m_grad, 1e-4, 1e-3)\n    else:\n        compare(fused_y, y)\n        compare(fused_x_grad, x_grad)\n        compare(fused_m_grad, m_grad)\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test gpu cases\")\nclass TestFusedFastGeluMul(flow.unittest.TestCase):\n    def test_fused_fast_gelu_mul(test_case):\n        args_dict = OrderedDict()\n        args_dict[\"shape\"] = [[5], [7, 10], [4, 2, 3], [8, 3, 16, 16]]\n        args_dict[\"dtype\"] = [flow.float16, flow.float32]\n        for kwarg in GenArgDict(args_dict):\n            _test_fused_fast_gelu_mul(test_case, **kwarg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_get_boundding_boxes_coord.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_get_boundding_boxes_coord_impl(test_case, device, shape):\n    x = []\n    torch_x = []\n    for _ in range(8):\n        tmp = flow.tensor(\n            np.random.randn(*shape),\n            dtype=flow.float32,\n            device=flow.device(device),\n            requires_grad=True,\n        )\n        x.append(tmp)\n        torch_x.append(\n            torch.tensor(\n                tmp.numpy(),\n                dtype=torch.float32,\n                device=torch.device(device),\n                requires_grad=True,\n            )\n        )\n    x1, y1, w1, h1, x2, y2, w2, h2 = x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]\n    (\n        b1_x1,\n        b1_x2,\n        b1_y1,\n        b1_y2,\n        b2_x1,\n        b2_x2,\n        b2_y1,\n        b2_y2,\n    ) = flow._C.fused_get_boundding_boxes_coord(x1, y1, w1, h1, x2, y2, w2, h2)\n    torch_x1, torch_y1, torch_w1, torch_h1, torch_x2, torch_y2, torch_w2, torch_h2 = (\n        torch_x[0],\n        torch_x[1],\n        torch_x[2],\n        torch_x[3],\n        torch_x[4],\n        torch_x[5],\n        torch_x[6],\n        torch_x[7],\n    )\n    torch_w1_, torch_h1_, torch_w2_, torch_h2_ = (\n        torch_w1 / 2,\n        torch_h1 / 2,\n        torch_w2 / 2,\n        torch_h2 / 2,\n    )\n    torch_b1_x1, torch_b1_x2, torch_b1_y1, torch_b1_y2 = (\n        torch_x1 - torch_w1_,\n        torch_x1 + torch_w1_,\n        torch_y1 - torch_h1_,\n        torch_y1 + torch_h1_,\n    )\n    torch_b2_x1, torch_b2_x2, torch_b2_y1, torch_b2_y2 = (\n        torch_x2 - torch_w2_,\n        torch_x2 + torch_w2_,\n        torch_y2 - torch_h2_,\n        torch_y2 + torch_h2_,\n    )\n\n    def compare(a, b, rtol=1e-5, atol=1e-8):\n        test_case.assertTrue(\n            np.allclose(\n                a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol\n            ),\n            f\"\\na\\n{a.detach().cpu().numpy()}\\n{'-' * 80}\\nb:\\n{b.detach().cpu().numpy()}\\n{'*' * 80}\\ndiff:\\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}\",\n        )\n\n    compare(b1_x1, torch_b1_x1)\n    compare(b1_x2, torch_b1_x2)\n    compare(b1_y1, torch_b1_y1)\n    compare(b1_y2, torch_b1_y2)\n    compare(b2_x1, torch_b2_x1)\n    compare(b2_x2, torch_b2_x2)\n    compare(b2_y1, torch_b2_y1)\n    compare(b2_y2, torch_b2_y2)\n    res = (\n        (b1_x1 + 2 * b1_x2 + b1_y1 + b1_y2 + b2_x1 + b2_x2 + b2_y1 + b2_y2) * 2\n    ).sum()\n    torch_res = (\n        (\n            torch_b1_x1\n            + 2 * torch_b1_x2\n            + torch_b1_y1\n            + torch_b1_y2\n            + torch_b2_x1\n            + torch_b2_x2\n            + torch_b2_y1\n            + torch_b2_y2\n        )\n        * 2\n    ).sum()\n    
res.sum().backward()\n    torch_res.sum().backward()\n    compare(x1.grad, torch_x1.grad)\n    compare(y1.grad, torch_y1.grad)\n    compare(w1.grad, torch_w1.grad)\n    compare(h1.grad, torch_h1.grad)\n    compare(x2.grad, torch_x2.grad)\n    compare(y2.grad, torch_y2.grad)\n    compare(w2.grad, torch_w2.grad)\n    compare(h2.grad, torch_h2.grad)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGetBounddingBoxesCoordModule(flow.unittest.TestCase):\n    def test_get_boundding_boxes_coord(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_get_boundding_boxes_coord_impl]\n        arg_dict[\"device\"] = [\"cuda\"]\n        arg_dict[\"shape\"] = [(583, 1), (759, 1), (1234, 1)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_get_ciou_diagonal_angle.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport math\nimport numpy as np\nimport torch\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef torch_get_ciou_diagonal_angle(w1, h1, w2, h2, eps=1e-8):\n    return (4 / math.pi ** 2) * torch.pow(\n        torch.atan(w2 / (h2 + eps)) - torch.atan(w1 / (h1 + eps)), 2\n    )\n\n\ndef _test_fused_get_ciou_diagonal_angle_impl(test_case, device, shape):\n    def compare(a, b, rtol=1e-5, atol=1e-5):\n        test_case.assertTrue(\n            np.allclose(\n                a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol\n            ),\n            f\"\\na\\n{a.detach().cpu().numpy()}\\n{'-' * 80}\\nb:\\n{b.detach().cpu().numpy()}\\n{'*' * 80}\\ndiff:\\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}\",\n        )\n\n    x = []\n    torch_x = []\n    for _ in range(4):\n        tmp = flow.tensor(\n            np.random.randn(*shape),\n            dtype=flow.float32,\n            device=flow.device(device),\n            requires_grad=True,\n        )\n        x.append(tmp)\n        torch_x.append(\n            torch.tensor(\n                tmp.numpy(),\n                dtype=torch.float32,\n                device=torch.device(device),\n                requires_grad=True,\n            )\n        )\n    w1, h1, w2, h2 = (\n        x[0],\n        x[1],\n        x[2],\n        x[3],\n    )\n    (torch_w1, torch_h1, torch_w2, torch_h2,) = (\n        torch_x[0],\n        torch_x[1],\n        torch_x[2],\n        torch_x[3],\n    )\n    v = flow._C.fused_get_ciou_diagonal_angle(w1, h1, w2, h2, eps=1e-8)\n    torch_v = torch_get_ciou_diagonal_angle(torch_w1, torch_h1, torch_w2, torch_h2,)\n    compare(v, torch_v)\n\n    v.sum().backward()\n    torch_v.sum().backward()\n    compare(w1.grad, torch_w1.grad)\n    compare(h1.grad, torch_h1.grad)\n    compare(w2.grad, torch_w2.grad)\n    compare(h2.grad, torch_h2.grad)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGetCiouDiagonalAngle(flow.unittest.TestCase):\n    def test_fused_get_ciou_diagonal_angle(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_fused_get_ciou_diagonal_angle_impl]\n        arg_dict[\"device\"] = [\"cuda\"]\n        arg_dict[\"shape\"] = [(583, 1), (759, 1), (1234, 1)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_get_ciou_result.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_get_ciou_result_impl(test_case, device, shape):\n    eps = 1e-7\n    x = []\n    torch_x = []\n    for _ in range(4):\n        tmp = flow.tensor(\n            np.random.uniform(0, 1, shape),\n            dtype=flow.float32,\n            device=flow.device(device),\n            requires_grad=True,\n        )\n        x.append(tmp)\n        torch_x.append(\n            torch.tensor(\n                tmp.numpy(),\n                dtype=torch.float32,\n                device=torch.device(device),\n                requires_grad=True,\n            )\n        )\n    v, iou, rho2, c2 = x[0], x[1], x[2], x[3]\n    y = flow._C.fused_get_ciou_result(v, iou, rho2, c2, eps)[0]\n    torch_v, torch_iou, torch_rho2, torch_c2 = (\n        torch_x[0],\n        torch_x[1],\n        torch_x[2],\n        torch_x[3],\n    )\n    with torch.no_grad():\n        torch_alpha = torch_v / (torch_v - torch_iou + (1.0 + eps))\n    torch_y = torch_iou - (torch_rho2 / torch_c2 + torch_v * torch_alpha)\n\n    def compare(a, b, rtol=1e-5, atol=1e-5):\n        test_case.assertTrue(\n            np.allclose(\n                a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol\n            ),\n            f\"\\na\\n{a.detach().cpu().numpy()}\\n{'-' * 80}\\nb:\\n{b.detach().cpu().numpy()}\\n{'*' * 80}\\ndiff:\\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}\",\n        )\n\n    compare(y, torch_y)\n\n    res = y.sum()\n    torch_res = torch_y.sum()\n    res.backward()\n    torch_res.backward()\n    compare(v.grad, torch_v.grad)\n    compare(iou.grad, torch_iou.grad)\n    compare(rho2.grad, torch_rho2.grad)\n    compare(c2.grad, torch_c2.grad)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGetCiouResultModule(flow.unittest.TestCase):\n    def test_get_ciou_result(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_get_ciou_result_impl]\n        arg_dict[\"device\"] = [\"cuda\"]\n        arg_dict[\"shape\"] = [(492), (691, 1), (1162, 1)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_get_convex_diagonal_squared.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef torch_fused_get_convex_diagonal_squared(\n    b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, eps\n):\n    cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)\n    ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)\n    c2 = cw ** 2 + ch ** 2 + eps\n    return c2\n\n\ndef _test_fused_get_convex_diagonal_squared_impl(test_case, device, shape):\n    def compare(a, b, rtol=1e-5, atol=1e-5):\n        test_case.assertTrue(\n            np.allclose(\n                a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol\n            ),\n            f\"\\na\\n{a.detach().cpu().numpy()}\\n{'-' * 80}\\nb:\\n{b.detach().cpu().numpy()}\\n{'*' * 80}\\ndiff:\\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}\",\n        )\n\n    eps = 1e-8\n    x = []\n    torch_x = []\n    for _ in range(8):\n        tmp = flow.tensor(\n            np.random.randn(*shape),\n            dtype=flow.float32,\n            device=flow.device(device),\n            requires_grad=True,\n        )\n        x.append(tmp)\n        torch_x.append(\n            torch.tensor(\n                tmp.numpy(),\n                dtype=torch.float32,\n                device=torch.device(device),\n                requires_grad=True,\n            )\n        )\n    b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2 = (\n        x[0],\n        x[1],\n        x[2],\n        x[3],\n        x[4],\n        x[5],\n        x[6],\n        x[7],\n    )\n    (\n        torch_b1_x1,\n        torch_b1_x2,\n        torch_b2_x1,\n        torch_b2_x2,\n        torch_b1_y1,\n        torch_b1_y2,\n        torch_b2_y1,\n        torch_b2_y2,\n    ) = (\n        torch_x[0],\n        torch_x[1],\n        torch_x[2],\n        torch_x[3],\n        torch_x[4],\n        torch_x[5],\n        torch_x[6],\n        torch_x[7],\n    )\n    c2 = flow._C.fused_get_convex_diagonal_squared(\n        b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, eps\n    )\n    torch_c2 = torch_fused_get_convex_diagonal_squared(\n        torch_b1_x1,\n        torch_b1_x2,\n        torch_b2_x1,\n        torch_b2_x2,\n        torch_b1_y1,\n        torch_b1_y2,\n        torch_b2_y1,\n        torch_b2_y2,\n        eps,\n    )\n    compare(c2, torch_c2)\n\n    c2.sum().backward()\n    torch_c2.sum().backward()\n    compare(b1_x1.grad, torch_b1_x1.grad)\n    compare(b1_x2.grad, torch_b1_x2.grad)\n    compare(b2_x1.grad, torch_b2_x1.grad)\n    compare(b2_x2.grad, torch_b2_x2.grad)\n    compare(b1_y1.grad, torch_b1_y1.grad)\n    compare(b1_y2.grad, torch_b1_y2.grad)\n    compare(b2_y1.grad, torch_b2_y1.grad)\n    compare(b2_y2.grad, torch_b2_y2.grad)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass 
TestGetConvexDiagonalSquaredModule(flow.unittest.TestCase):\n    def test_fused_get_convex_diagonal_squared(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_fused_get_convex_diagonal_squared_impl]\n        arg_dict[\"device\"] = [\"cuda\"]\n        arg_dict[\"shape\"] = [(583, 1), (759, 1), (1234, 1)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_get_intersection_area.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef torch_get_intersection_area(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2):\n    return (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * (\n        torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)\n    ).clamp(0)\n\n\ndef _test_fused_get_intersection_area_impl(test_case, device, shape):\n    def compare(a, b, rtol=1e-5, atol=1e-5):\n        test_case.assertTrue(\n            np.allclose(\n                a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol\n            ),\n            f\"\\na\\n{a.detach().cpu().numpy()}\\n{'-' * 80}\\nb:\\n{b.detach().cpu().numpy()}\\n{'*' * 80}\\ndiff:\\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}\\n\",\n        )\n\n    x = []\n    torch_x = []\n    for _ in range(8):\n        tmp = flow.tensor(\n            np.random.randn(*shape),\n            dtype=flow.float32,\n            device=flow.device(device),\n            requires_grad=True,\n        )\n        x.append(tmp)\n        torch_x.append(\n            torch.tensor(\n                tmp.numpy(),\n                dtype=torch.float32,\n                device=torch.device(device),\n                requires_grad=True,\n            )\n        )\n    b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2 = (\n        x[0],\n        x[1],\n        x[2],\n        x[3],\n        x[4],\n        x[5],\n        x[6],\n        x[7],\n    )\n    (\n        torch_b1_x1,\n        torch_b1_x2,\n        torch_b2_x1,\n        torch_b2_x2,\n        torch_b1_y1,\n        torch_b1_y2,\n        torch_b2_y1,\n        torch_b2_y2,\n    ) = (\n        torch_x[0],\n        torch_x[1],\n        torch_x[2],\n        torch_x[3],\n        torch_x[4],\n        torch_x[5],\n        torch_x[6],\n        torch_x[7],\n    )\n    inter = flow._C.fused_get_intersection_area(\n        b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2\n    )\n    torch_inter = torch_get_intersection_area(\n        torch_b1_x1,\n        torch_b1_x2,\n        torch_b2_x1,\n        torch_b2_x2,\n        torch_b1_y1,\n        torch_b1_y2,\n        torch_b2_y1,\n        torch_b2_y2,\n    )\n    compare(inter, torch_inter)\n\n    inter.sum().backward()\n    torch_inter.sum().backward()\n    compare(b1_x1.grad, torch_b1_x1.grad)\n    compare(b1_x2.grad, torch_b1_x2.grad)\n    compare(b2_x1.grad, torch_b2_x1.grad)\n    compare(b2_x2.grad, torch_b2_x2.grad)\n    compare(b1_y1.grad, torch_b1_y1.grad)\n    compare(b1_y2.grad, torch_b1_y2.grad)\n    compare(b2_y1.grad, torch_b2_y1.grad)\n    compare(b2_y2.grad, torch_b2_y2.grad)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGetIntersectionAreaModule(flow.unittest.TestCase):\n    def test_fused_get_inter_intersection_area(test_case):\n   
     arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_fused_get_intersection_area_impl]\n        arg_dict[\"device\"] = [\"cuda\"]\n        arg_dict[\"shape\"] = [(583, 1), (759, 1), (1234, 1)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_get_iou.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_get_iou_impl(test_case, device, shape):\n    eps = 1e-7\n    x = []\n    torch_x = []\n    for _ in range(5):\n        tmp = flow.tensor(\n            np.random.uniform(0, 1, shape),\n            dtype=flow.float64,\n            device=flow.device(device),\n            requires_grad=True if (_ < 2 or _ > 3) else False,\n        )\n        x.append(tmp)\n        torch_x.append(\n            torch.tensor(\n                tmp.numpy(),\n                dtype=torch.float64,\n                device=torch.device(device),\n                requires_grad=True if (_ < 2 or _ > 3) else False,\n            )\n        )\n    w1, h1, w2, h2, inter = x[0], x[1], x[2], x[3], x[4]\n    iou = flow._C.fused_get_iou(w1, h1, w2, h2, inter, eps)\n    torch_w1, torch_h1, torch_w2, torch_h2, torch_inter = (\n        torch_x[0],\n        torch_x[1],\n        torch_x[2],\n        torch_x[3],\n        torch_x[4],\n    )\n    torch_iou = torch_inter / (\n        torch_w1 * torch_h1 + torch_w2 * torch_h2 - torch_inter + eps\n    )\n\n    def compare(a, b, rtol=1e-5, atol=1e-5, w1=w1, h1=h1, w2=w2, h2=h2, inter=inter):\n        test_case.assertTrue(\n            np.allclose(\n                a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol\n            ),\n            f\"\\na\\n{a.detach().cpu().numpy()}\\n{'-' * 80}\\nb:\\n{b.detach().cpu().numpy()}\\n{'*' * 80}\\ndiff:\\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}\",\n        )\n\n    compare(iou, torch_iou)\n\n    res = iou.sum()\n    torch_res = torch_iou.sum()\n    res.backward()\n    torch_res.backward()\n    compare(w1.grad, torch_w1.grad)\n    compare(h1.grad, torch_h1.grad)\n    compare(inter.grad, torch_inter.grad)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGetIouModule(flow.unittest.TestCase):\n    def test_get_iou(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_get_iou_impl]\n        arg_dict[\"device\"] = [\"cuda\"]\n        arg_dict[\"shape\"] = [(492), (691, 1), (1162, 1)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_glu.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport time\nimport datetime\nimport numpy as np\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\ntest_dualgemm_impt = False\n\n\nclass Glu(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(\n        self,\n        x: flow.Tensor,\n        w: flow.Tensor,\n        b: flow.Tensor = None,\n        v: flow.Tensor = None,\n        c: flow.Tensor = None,\n        split_mode: bool = False,\n        activation: str = \"none\",\n    ) -> flow.Tensor:\n        # matmul\n        matmul_wx = flow._C.matmul(\n            input=x, other=w, transpose_a=False, transpose_b=True\n        )\n        if split_mode:\n            matmul_vx = flow._C.matmul(\n                input=x, other=v, transpose_a=False, transpose_b=True\n            )\n\n        # add bias\n        if b != None:\n            matmul_wx_b = flow._C.add(input=matmul_wx, other=b)\n            if split_mode:\n                matmul_vx_c = flow._C.add(input=matmul_vx, other=c)\n        else:\n            matmul_wx_b = matmul_wx\n            if split_mode:\n                matmul_vx_c = matmul_vx\n\n        # chunk\n        if split_mode:\n            hidden_state = matmul_wx_b\n            gate = matmul_vx_c\n        else:\n            hidden_state, gate = matmul_wx_b.chunk(2, dim=-1)\n\n        # activation and element-wise product\n        if activation == \"none\":\n            return hidden_state * gate\n        elif activation == \"sigmoid\":\n            return hidden_state * flow.sigmoid(gate)\n        elif activation == \"relu\":\n            return hidden_state * flow.relu(gate)\n        elif activation == \"gelu\":\n            return hidden_state * flow.gelu(gate)\n        elif activation == \"fast_gelu\":\n            return hidden_state * flow._C.fast_gelu(gate)\n        elif activation == \"silu\":\n            return hidden_state * flow.silu(gate)\n\n\ndef tensor_builder(params: dict, dtype=flow.float32, is_split_mode=True):\n    # config test data\n    m = params[\"m\"]\n    n = params[\"n\"]\n    k = params[\"k\"]\n\n    # generate random input\n    x = np.random.randn(2, m, k) / 100\n    y_nor = np.random.randn(2, m, n)\n    if is_split_mode:\n        w = np.random.randn(n, k) / 100  # transpose\n        b = np.random.randn(n) / 100\n        v = np.random.randn(n, k) / 100  # transpose\n        c = np.random.randn(n) / 100\n    else:\n        w = np.random.randn(n * 2, k) / 100  # transpose\n        b = np.random.randn(n * 2) / 100\n\n    # transfer to gpu memory\n    tensor_x = flow.FloatTensor(x).to(dtype=dtype, device=\"cuda\")\n    tensor_y_nor = flow.FloatTensor(y_nor).to(dtype=dtype, device=\"cuda\")\n    tensor_w = flow.FloatTensor(w).to(dtype=dtype, device=\"cuda\").requires_grad_(True)\n    tensor_b = 
flow.FloatTensor(b).to(dtype=dtype, device=\"cuda\").requires_grad_(True)\n    if is_split_mode:\n        tensor_v = (\n            flow.FloatTensor(v).to(dtype=dtype, device=\"cuda\").requires_grad_(True)\n        )\n        tensor_c = (\n            flow.FloatTensor(c).to(dtype=dtype, device=\"cuda\").requires_grad_(True)\n        )\n\n    if is_split_mode:\n        return tensor_x, tensor_w, tensor_b, tensor_v, tensor_c, tensor_y_nor\n    else:\n        return tensor_x, tensor_w, tensor_b, tensor_y_nor\n\n\ndef compare_result(test_case, a, b, rtol=1e-5, atol=1e-8):\n    test_case.assertTrue(\n        np.allclose(a.numpy(), b.numpy(), rtol=rtol, atol=atol),\n        f\"\\na\\n{a.numpy()}\\n{'-' * 80}\\nb:\\n{b.numpy()}\\n{'*' * 80}\\ndiff:\\n{a.numpy() - b.numpy()}\",\n    )\n\n\ndef _test_fused_glu(test_case, params: dict, dtype=flow.float32):\n    print(f\"========== Start Testing ==========\")\n    print(f\"weight tensor: merged\")\n    print(f'tensor shape: m={params[\"m\"]}, n={params[\"n\"]}, k={params[\"k\"]}')\n    print(f'activation: {params[\"act\"]}')\n    print(f\"dtype: {dtype}\")\n\n    flow_module = Glu()\n    x, w, b, y_nor = tensor_builder(params=params, dtype=dtype, is_split_mode=False)\n\n    # forward\n    y = flow_module.forward(x=x, w=w, b=b, split_mode=False, activation=params[\"act\"])\n\n    # backward\n    y.sum().backward()\n\n    # copy back to cpu memory\n    w_grad = w.grad.detach().cpu()\n    b_grad = b.grad.detach().cpu()\n    y = y.detach().cpu()\n\n    fused_x = x.detach().clone()\n    fused_w = w.detach().clone().requires_grad_(True)\n    fused_b = b.detach().clone().requires_grad_(True)\n\n    # forward\n    fused_y = flow._C.fused_glu(\n        x=fused_x, w=fused_w, b=fused_b, v=None, c=None, activation=params[\"act\"]\n    )\n\n    # backward\n    fused_y.sum().backward()\n\n    # copy back to cpu memory\n    fused_w_grad = fused_w.grad.detach().cpu()\n    fused_b_grad = fused_b.grad.detach().cpu()\n    fused_y = fused_y.detach().cpu()\n\n    if dtype == flow.float16:\n        compare_result(test_case, fused_y, y, 1e-2, 1e-3)\n        compare_result(test_case, fused_w_grad, w_grad, 1e-2, 1e-1)\n        compare_result(test_case, fused_b_grad, b_grad, 1e-2, 1e-1)\n    else:\n        compare_result(test_case, fused_y, y)\n        compare_result(test_case, fused_w_grad, w_grad, 1e-5, 1e-2)\n        compare_result(test_case, fused_b_grad, b_grad, 1e-5, 1e-2)\n    print(f\"============== PASSED =============\")\n    print(\"\\n\")\n\n\ndef _test_fused_glu_without_bias(test_case, params: dict, dtype=flow.float32):\n    print(f\"========== Start Testing ==========\")\n    print(f\"weight tensor: merged\")\n    print(f\"no bias\")\n    print(f'tensor shape: m={params[\"m\"]}, n={params[\"n\"]}, k={params[\"k\"]}')\n    print(f'activation: {params[\"act\"]}')\n    print(f\"dtype: {dtype}\")\n\n    flow_module = Glu()\n    x, w, b, y_nor = tensor_builder(params=params, dtype=dtype, is_split_mode=False)\n\n    # forward\n    y = flow_module.forward(x=x, w=w, split_mode=False, activation=params[\"act\"])\n\n    # backward\n    y.sum().backward()\n\n    # copy back to cpu memory\n    w_grad = w.grad.detach().cpu()\n    y = y.detach().cpu()\n\n    fused_x = x.detach().clone()\n    fused_w = w.detach().clone().requires_grad_(True)\n\n    # forward\n    fused_y = flow._C.fused_glu(\n        x=fused_x, w=fused_w, b=None, v=None, c=None, activation=params[\"act\"]\n    )\n\n    # backward\n    fused_y.sum().backward()\n\n    # copy back to cpu memory\n    fused_w_grad 
= fused_w.grad.detach().cpu()\n    fused_y = fused_y.detach().cpu()\n\n    if dtype == flow.float16:\n        compare_result(test_case, fused_y, y, 1e-2, 1e-3)\n        compare_result(test_case, fused_w_grad, w_grad, 1e-2, 1e-1)\n    else:\n        compare_result(test_case, fused_y, y)\n        compare_result(test_case, fused_w_grad, w_grad, 1e-5, 1e-2)\n    print(f\"============== PASSED =============\")\n    print(\"\\n\")\n\n\ndef _test_fused_glu_split(test_case, params: dict, dtype=flow.float32):\n    print(f\"========== Start Testing ==========\")\n    print(f\"weight tensor: splited\")\n    print(f'tensor shape: m={params[\"m\"]}, n={params[\"n\"]}, k={params[\"k\"]}')\n    print(f'activation: {params[\"act\"]}')\n    print(f\"dtype: {dtype}\")\n\n    flow_module = Glu()\n    x, w, b, v, c, y_nor = tensor_builder(\n        params=params, dtype=dtype, is_split_mode=True\n    )\n\n    # forward\n    y = flow_module.forward(\n        x=x, w=w, b=b, v=v, c=c, split_mode=True, activation=params[\"act\"]\n    )\n\n    # backward\n    y.sum().backward()\n\n    # copy back to cpu memory\n    w_grad = w.grad.detach().cpu()\n    b_grad = b.grad.detach().cpu()\n    v_grad = v.grad.detach().cpu()\n    c_grad = c.grad.detach().cpu()\n    y = y.detach().cpu()\n\n    fused_x = x.detach().clone()\n    fused_w = w.detach().clone().requires_grad_(True)\n    fused_b = b.detach().clone().requires_grad_(True)\n    fused_v = v.detach().clone().requires_grad_(True)\n    fused_c = c.detach().clone().requires_grad_(True)\n\n    # forward\n    fused_y = flow._C.fused_glu(\n        x=fused_x, w=fused_w, b=fused_b, v=fused_v, c=fused_c, activation=params[\"act\"]\n    )\n\n    # backward\n    fused_y.sum().backward()\n\n    fused_w_grad = fused_w.grad.detach().cpu()\n    fused_b_grad = fused_b.grad.detach().cpu()\n    fused_v_grad = fused_v.grad.detach().cpu()\n    fused_c_grad = fused_c.grad.detach().cpu()\n    fused_y = fused_y.detach().cpu()\n\n    if dtype == flow.float16:\n        compare_result(test_case, fused_y, y, 1e-2, 1e-3)\n        compare_result(test_case, fused_w_grad, w_grad, 1e-2, 1e-1)\n        compare_result(test_case, fused_b_grad, b_grad, 1e-2, 1e-1)\n        compare_result(test_case, fused_v_grad, v_grad, 1e-2, 1e-1)\n        compare_result(test_case, fused_c_grad, c_grad, 1e-2, 1e-1)\n    else:\n        compare_result(test_case, fused_y, y)\n        compare_result(test_case, fused_w_grad, w_grad, 1e-5, 1e-2)\n        compare_result(test_case, fused_b_grad, b_grad, 1e-5, 1e-2)\n        compare_result(test_case, fused_v_grad, v_grad, 1e-5, 1e-2)\n        compare_result(test_case, fused_c_grad, c_grad, 1e-5, 1e-2)\n    print(f\"============== PASSED =============\")\n    print(\"\\n\")\n\n\ndef _test_fused_glu_split_without_bias(test_case, params: dict, dtype=flow.float32):\n    print(f\"========== Start Testing ==========\")\n    print(f\"weight tensor: splited\")\n    print(f\"no bias\")\n    print(f'tensor shape: m={params[\"m\"]}, n={params[\"n\"]}, k={params[\"k\"]}')\n    print(f'activation: {params[\"act\"]}')\n    print(f\"dtype: {dtype}\")\n\n    flow_module = Glu()\n    x, w, b, v, c, y_nor = tensor_builder(\n        params=params, dtype=dtype, is_split_mode=True\n    )\n\n    # forward\n    y = flow_module.forward(x=x, w=w, v=v, split_mode=True, activation=params[\"act\"])\n\n    # backward\n    y.sum().backward()\n\n    # copy back to cpu memory\n    w_grad = w.grad.detach().cpu()\n    v_grad = v.grad.detach().cpu()\n    y = y.detach().cpu()\n\n    fused_x = x.detach().clone()\n   
 fused_w = w.detach().clone().requires_grad_(True)\n    fused_v = v.detach().clone().requires_grad_(True)\n\n    # forward\n    fused_y = flow._C.fused_glu(\n        x=fused_x, w=fused_w, b=None, v=fused_v, c=None, activation=params[\"act\"]\n    )\n\n    # backward\n    fused_y.sum().backward()\n\n    fused_w_grad = fused_w.grad.detach().cpu()\n    fused_v_grad = fused_v.grad.detach().cpu()\n    fused_y = fused_y.detach().cpu()\n\n    if dtype == flow.float16:\n        compare_result(test_case, fused_y, y, 1e-2, 1e-3)\n        compare_result(test_case, fused_w_grad, w_grad, 1e-2, 1e-1)\n        compare_result(test_case, fused_v_grad, v_grad, 1e-2, 1e-1)\n    else:\n        compare_result(test_case, fused_y, y)\n        compare_result(test_case, fused_w_grad, w_grad, 1e-5, 1e-2)\n        compare_result(test_case, fused_v_grad, v_grad, 1e-5, 1e-2)\n    print(f\"============== PASSED =============\")\n    print(\"\\n\")\n\n\n# @flow.unittest.skip_unless_1n1d()\n# @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@unittest.skipIf(True, \"CI test taking too long.\")\nclass TestFusedGlu(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 4 times in past week\")\n    def test_gather(test_case):\n        arg_dict = OrderedDict()\n        # set up test functions\n        arg_dict[\"test_fun\"] = [\n            _test_fused_glu,\n            _test_fused_glu_split,\n            _test_fused_glu_without_bias,\n            _test_fused_glu_split_without_bias,\n        ]\n\n        # set up env valuable if necessary\n        if not test_dualgemm_impt:\n            os.environ[\"ONEFLOW_KERNEL_GLU_ENABLE_DUAL_GEMM_IMPL\"] = \"false\"\n        else:\n            os.environ[\"ONEFLOW_KERNEL_GLU_ENABLE_DUAL_GEMM_IMPL\"] = \"true\"\n\n        # set up profiling functions\n        if not test_dualgemm_impt:\n            arg_dict[\"params\"] = [\n                # m=256, k=1280, n=5120\n                {\"m\": 256, \"k\": 1280, \"n\": 5120, \"act\": \"none\"},\n                {\"m\": 256, \"k\": 1280, \"n\": 5120, \"act\": \"sigmoid\"},\n                {\"m\": 256, \"k\": 1280, \"n\": 5120, \"act\": \"relu\"},\n                {\"m\": 256, \"k\": 1280, \"n\": 5120, \"act\": \"gelu\"},\n                {\"m\": 256, \"k\": 1280, \"n\": 5120, \"act\": \"fast_gelu\"},\n                {\"m\": 256, \"k\": 1280, \"n\": 5120, \"act\": \"silu\"},\n                # m=1024, k=640, n=2560\n                {\"m\": 1024, \"k\": 640, \"n\": 2560, \"act\": \"none\"},\n                {\"m\": 1024, \"k\": 640, \"n\": 2560, \"act\": \"sigmoid\"},\n                {\"m\": 1024, \"k\": 640, \"n\": 2560, \"act\": \"relu\"},\n                {\"m\": 1024, \"k\": 640, \"n\": 2560, \"act\": \"gelu\"},\n                {\"m\": 1024, \"k\": 640, \"n\": 2560, \"act\": \"fast_gelu\"},\n                {\"m\": 1024, \"k\": 640, \"n\": 2560, \"act\": \"silu\"},\n                # m=4096, k=320, n=1280\n                # {\"m\": 4096, \"k\": 320, \"n\": 1280, \"act\": \"none\"},\n                # {\"m\": 4096, \"k\": 320, \"n\": 1280, \"act\": \"sigmoid\"},\n                # {\"m\": 4096, \"k\": 320, \"n\": 1280, \"act\": \"relu\"},\n                # {\"m\": 4096, \"k\": 320, \"n\": 1280, \"act\": \"gelu\"},\n                # {\"m\": 4096, \"k\": 320, \"n\": 1280, \"act\": \"fast_gelu\"},\n                # {\"m\": 4096, \"k\": 320, \"n\": 1280, \"act\": \"silu\"},\n                # m=2560, k=12800, n=51200\n                # {\"m\": 2560, \"k\": 1280, \"n\": 5120, \"act\": 
\"none\"},\n                # {\"m\": 2560, \"k\": 1280, \"n\": 5120, \"act\": \"sigmoid\"},\n                # {\"m\": 2560, \"k\": 1280, \"n\": 5120, \"act\": \"relu\"},\n                # {\"m\": 2560, \"k\": 1280, \"n\": 5120, \"act\": \"gelu\"},\n                # {\"m\": 2560, \"k\": 1280, \"n\": 5120, \"act\": \"fast_gelu\"},\n                # {\"m\": 2560, \"k\": 1280, \"n\": 5120, \"act\": \"silu\"},\n            ]\n        else:\n            arg_dict[\"params\"] = [\n                # m=256, k=1280, n=5120\n                {\"m\": 256, \"k\": 1280, \"n\": 5120, \"act\": \"fast_gelu\"},\n                # m=1024, k=640, n=2560\n                {\"m\": 1024, \"k\": 640, \"n\": 2560, \"act\": \"fast_gelu\"},\n                # m=4096, k=320, n=1280\n                {\"m\": 4096, \"k\": 320, \"n\": 1280, \"act\": \"fast_gelu\"},\n                # m=2560, k=12800, n=51200\n                {\"m\": 2560, \"k\": 1280, \"n\": 5120, \"act\": \"fast_gelu\"},\n            ]\n\n        if not test_dualgemm_impt:\n            arg_dict[\"dtype\"] = [flow.float16, flow.float32]\n        else:\n            arg_dict[\"dtype\"] = [flow.float16]\n\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_matmul_bias.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport numpy as np\n\n\ndef _matmul_bias(x, weight, bias, add_to_output):\n    return flow._C.add(\n        flow._C.bias_add(\n            flow._C.matmul(x, weight, transpose_b=True), bias, axis=len(x.shape) - 1\n        ),\n        add_to_output,\n    )\n\n\ndef _test_fused_matmul_add_bias(\n    test_case, batchsize, in_feature, out_feature, _add_to_output, dtype, device,\n):\n    add_to_output = np.zeros((*batchsize, out_feature))\n    if _add_to_output:\n        add_to_output = np.random.uniform(\n            low=-1, high=1, size=(*batchsize, out_feature)\n        )\n    x = np.random.uniform(low=-1, high=1, size=(*batchsize, in_feature))\n    weight = np.random.uniform(low=-1, high=1, size=(out_feature, in_feature))\n    bias = np.random.uniform(low=-1, high=1, size=(out_feature))\n\n    naive_x = flow.tensor(x, dtype=dtype, requires_grad=True)\n    naive_weight = flow.tensor(weight, dtype=dtype, requires_grad=True)\n    naive_bias = flow.tensor(bias, dtype=dtype, requires_grad=True)\n    naive_add_to_output = flow.tensor(add_to_output, dtype=dtype, requires_grad=True)\n\n    fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True)\n    fused_weight = flow.tensor(weight, dtype=dtype, device=device, requires_grad=True)\n    fused_bias = flow.tensor(bias, dtype=dtype, device=device, requires_grad=True)\n    fused_add_to_output = None\n    if _add_to_output:\n        fused_add_to_output = flow.tensor(\n            add_to_output, dtype=dtype, device=device, requires_grad=False\n        )\n\n    navie_y = _matmul_bias(naive_x, naive_weight, naive_bias, naive_add_to_output)\n    fused_y = flow._C.fused_matmul_bias(\n        fused_x, fused_weight, fused_bias, fused_add_to_output\n    )\n\n    y = navie_y.sum() + fused_y.sum()\n    y.backward()\n\n    # TODO: relative error might be too high...\n    # Test output equality\n    if _add_to_output:\n        test_case.assertTrue(\n            np.allclose(navie_y.numpy(), fused_y.numpy(), atol=5e-2, rtol=1e-4)\n        )\n    else:\n        test_case.assertTrue(\n            np.allclose(navie_y.numpy(), fused_y.numpy(), atol=5e-2, rtol=1e-4)\n        )\n\n    # Test grad equality\n    test_case.assertTrue(\n        np.allclose(naive_x.grad.numpy(), fused_x.grad.numpy(), atol=5e-2, rtol=1e-4)\n    )\n\n    test_case.assertTrue(\n        np.allclose(\n            naive_weight.grad.numpy(), fused_weight.grad.numpy(), atol=5e-2, rtol=1e-4\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            naive_bias.grad.numpy(), fused_bias.grad.numpy(), atol=1e-4, rtol=1e-4\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFusedMatmulBiasAddRelu(flow.unittest.TestCase):\n    def test_fused_matmul_op(test_case):\n        args_dict = OrderedDict()\n        
args_dict[\"test_fun\"] = [_test_fused_matmul_add_bias]\n        args_dict[\"batchsize\"] = [\n            (1,),\n            (4,),\n            (8,),\n            (2, 4),\n            (2, 4, 8),\n            (2, 4, 4, 4, 8),\n        ]\n        args_dict[\"in_feature\"] = [96, 128]\n        args_dict[\"out_feature\"] = [512, 1024, 288, 1]\n        args_dict[\"_add_to_output\"] = [True]\n        args_dict[\"dtype\"] = [flow.float32, flow.float64]\n        args_dict[\"device\"] = [\"cuda\", \"cpu\"]\n\n        for arg in GenArgList(args_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_matmul_bias_add_relu_dropout.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\n\n\ndef _matmul_bias_relu(x, weight, bias, skip_activate):\n    # We do not add dropout in unittest, cause its result is random.\n    out = flow._C.bias_add(flow._C.matmul(x, weight, transpose_b=True), bias, axis=1)\n    if not skip_activate:\n        out = flow._C.relu(out)\n    return out\n\n\ndef _test_fused_matmul_bias_add_relu_dropout(\n    test_case,\n    batchsize,\n    in_feature,\n    hidden_size_list,\n    out_feature,\n    skip_final_activation,\n    dtype,\n    device,\n):\n    x = np.random.uniform(low=-1, high=1, size=(batchsize, in_feature))\n\n    fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True)\n    naive_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True)\n\n    fused_weight_list = []\n    naive_weight_list = []\n    fused_bias_list = []\n    naive_bias_list = []\n\n    hidden_num = len(hidden_size_list)\n\n    if hidden_num != 0:\n        np_first_weight = np.random.uniform(\n            low=-1, high=1, size=(hidden_size_list[0], in_feature)\n        )\n        np_first_bias = np.random.uniform(low=-1, high=1, size=hidden_size_list[0])\n\n        fused_weight_list.append(\n            flow.tensor(np_first_weight, dtype=dtype, device=device, requires_grad=True)\n        )\n        fused_bias_list.append(\n            flow.tensor(np_first_bias, dtype=dtype, device=device, requires_grad=True)\n        )\n        naive_weight_list.append(\n            flow.tensor(np_first_weight, dtype=dtype, device=device, requires_grad=True)\n        )\n        naive_bias_list.append(\n            flow.tensor(np_first_bias, dtype=dtype, device=device, requires_grad=True)\n        )\n\n    for idx in range(1, hidden_num):\n        np_weight = np.random.uniform(\n            low=-1, high=1, size=(hidden_size_list[idx], hidden_size_list[idx - 1])\n        )\n        np_bias = np.random.uniform(low=-1, high=1, size=hidden_size_list[idx])\n\n        fused_weight_list.append(\n            flow.tensor(np_weight, dtype=dtype, device=device, requires_grad=True)\n        )\n        fused_bias_list.append(\n            flow.tensor(np_bias, dtype=dtype, device=device, requires_grad=True)\n        )\n        naive_weight_list.append(\n            flow.tensor(np_weight, dtype=dtype, device=device, requires_grad=True)\n        )\n        naive_bias_list.append(\n            flow.tensor(np_bias, dtype=dtype, device=device, requires_grad=True)\n        )\n\n    np_final_weight = np.random.uniform(low=-1, high=1, size=(out_feature, in_feature))\n\n    if hidden_num != 0:\n        np_final_weight = np.random.uniform(\n            low=-1, high=1, size=(out_feature, hidden_size_list[-1])\n        )\n\n    np_final_bias = np.random.uniform(low=-1, high=1, size=(out_feature))\n\n    
fused_weight_list.append(\n        flow.tensor(np_final_weight, dtype=dtype, device=device, requires_grad=True)\n    )\n    fused_bias_list.append(\n        flow.tensor(np_final_bias, dtype=dtype, device=device, requires_grad=True)\n    )\n    naive_weight_list.append(\n        flow.tensor(np_final_weight, dtype=dtype, device=device, requires_grad=True)\n    )\n    naive_bias_list.append(\n        flow.tensor(np_final_bias, dtype=dtype, device=device, requires_grad=True)\n    )\n\n    fused_out = flow._C.fused_matmul_bias_add_relu_dropout(\n        fused_x,\n        fused_weight_list,\n        fused_bias_list,\n        # We do not add dropout in unittest, cause its result is random.\n        dropout_rate_list=[0.0] * len(fused_weight_list),\n        skip_final_activation=skip_final_activation,\n    )\n\n    naive_out = _matmul_bias_relu(\n        naive_x,\n        naive_weight_list[0],\n        naive_bias_list[0],\n        False if hidden_num != 0 else skip_final_activation,\n    )\n\n    for idx in range(1, hidden_num + 1):\n        if idx == hidden_num:\n            naive_out = _matmul_bias_relu(\n                naive_out,\n                naive_weight_list[idx],\n                naive_bias_list[idx],\n                skip_final_activation,\n            )\n        else:\n            naive_out = _matmul_bias_relu(\n                naive_out, naive_weight_list[idx], naive_bias_list[idx], False\n            )\n\n    total_out = fused_out.sum() + naive_out.sum()\n    total_out.backward()\n\n    test_case.assertTrue(\n        np.allclose(fused_out.numpy(), naive_out.numpy(), atol=1e-4, rtol=1e-4)\n    )\n\n    # Test weight grad equality\n    for idx in range(hidden_num + 1):\n        test_case.assertTrue(\n            np.allclose(\n                fused_weight_list[idx].grad.numpy(),\n                naive_weight_list[idx].grad.numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n        test_case.assertTrue(\n            np.allclose(\n                fused_bias_list[idx].grad.numpy(),\n                naive_bias_list[idx].grad.numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n    # Test dx equality\n    test_case.assertTrue(\n        np.allclose(fused_x.grad.numpy(), naive_x.grad.numpy(), atol=1e-4, rtol=1e-4)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFusedMatmulBiasAddReluDropout(flow.unittest.TestCase):\n    def test_fused_matmul_bias_add_relu_dropout(test_case):\n        args_dict = OrderedDict()\n        args_dict[\"test_func\"] = [_test_fused_matmul_bias_add_relu_dropout]\n        args_dict[\"batchsize\"] = [1, 2, 4]\n        args_dict[\"in_feature\"] = [96, 128, 64]\n        args_dict[\"hidden_size_list\"] = [[256, 512], [400, 400, 400, 400], [17, 33, 79]]\n        args_dict[\"out_feature\"] = [512, 400, 1024, 1]\n        args_dict[\"skip_final_activation\"] = [False]\n        args_dict[\"dtype\"] = [flow.float32]\n        args_dict[\"device\"] = [\"cuda\"]\n\n        for arg in GenArgList(args_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_rotary_embedding.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport numpy as np\nimport math\n\n\ndef plane_shuffle(x):\n    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]\n    return np.concatenate((-x2, x1), axis=-1)\n\n\ndef shuffle_adjacent_two_elem(x):\n    y = x.copy()\n    for i in range(x.shape[-1] // 2):\n        y[..., 2 * i] = -x[..., 2 * i + 1]\n        y[..., 2 * i + 1] = x[..., 2 * i]\n    return y\n\n\ndef parseDims(dims, x_layout):\n    B = 1\n    M = 1\n    H = 1\n    K = 1\n    merged_dims = dims\n    if x_layout == \"BHMK\":\n        B = dims[0]\n        H = dims[1]\n        M = dims[2]\n        K = dims[3]\n        merged_dims = dims  # no merge\n    elif x_layout == \"BMHK\":\n        B = dims[0]\n        M = dims[1]\n        H = dims[2]\n        K = dims[3]\n        merged_dims = dims\n    elif x_layout == \"MBHK\":\n        B = dims[1]\n        M = dims[0]\n        H = dims[2]\n        K = dims[3]\n        merged_dims = dims\n    elif x_layout == \"BM(HK)\":\n        B = dims[0]\n        M = dims[1]\n        H = dims[2]\n        K = dims[3]\n        merged_dims = [dims[0], dims[1], dims[2] * dims[3]]\n    elif x_layout == \"MB(HK)\":\n        B = dims[1]\n        M = dims[0]\n        H = dims[2]\n        K = dims[3]\n        merged_dims = [dims[0], dims[1], dims[2] * dims[3]]\n    elif x_layout == \"BM(H3K)\":\n        B = dims[0]\n        M = dims[1]\n        H = dims[2]\n        K = dims[3]\n        merged_dims = [dims[0], dims[1], 3 * dims[2] * dims[3]]\n    elif x_layout == \"MB(H3K)\":\n        B = dims[1]\n        M = dims[0]\n        H = dims[2]\n        K = dims[3]\n        merged_dims = [dims[0], dims[1], 3 * dims[2] * dims[3]]\n\n    return B, M, H, K, merged_dims\n\n\n# all cos&sin are by default in x_layout (B, H, M, K), in which H is 1\ndef naive_embedding(\n    x,\n    cos,\n    sin,\n    x_layout,\n    B,\n    M,\n    H,\n    K,\n    dims,\n    merged_dims,\n    rotary_size,\n    rotary_ndims,\n    mode,\n):\n    naive_out = None\n    if mode == \"plane\":\n        if rotary_ndims == 2:\n            y1 = plane_shuffle(x[..., : rotary_size // 2])\n            y2 = plane_shuffle(x[..., rotary_size // 2 : rotary_size])\n            y3 = x[..., rotary_size:]\n            y = np.concatenate((y1, y2, y3), axis=-1)\n        else:\n            y1 = plane_shuffle(x[..., :rotary_size])\n            y2 = x[..., rotary_size:]\n            y = np.concatenate((y1, y2), axis=-1)\n    else:\n        y = shuffle_adjacent_two_elem(x)\n\n    if x_layout == \"BHMK\":\n        naive_out = x * cos + y * sin\n    elif x_layout == \"BMHK\":\n        naive_out = x.reshape(dims) * cos.reshape([B, M, 1, K]) + y.reshape(\n            dims\n        ) * sin.reshape(\n            [B, M, 1, K]\n        )  # un-merge\n    elif x_layout == \"MBHK\" or x_layout == \"MB(HK)\":\n  
      naive_out = x.reshape(dims) * cos.transpose([2, 0, 1, 3]).reshape(\n            [M, B, 1, K]\n        ) + y.reshape(dims) * sin.transpose([2, 0, 1, 3]).reshape(\n            [M, B, 1, K]\n        )  # un-merge\n    elif x_layout == \"BM(HK)\":\n        naive_out = x.reshape(dims) * cos.reshape([B, M, 1, K]) + y.reshape(\n            dims\n        ) * sin.reshape(\n            [B, M, 1, K]\n        )  # un-merge\n    elif x_layout == \"BM(H3K)\":\n        out0 = x[..., 0, :].reshape(dims) * cos.reshape([B, M, 1, K]) + y[\n            ..., 0, :\n        ].reshape(dims) * sin.reshape([B, M, 1, K])\n        out1 = x[..., 1, :].reshape(dims) * cos.reshape([B, M, 1, K]) + y[\n            ..., 1, :\n        ].reshape(dims) * sin.reshape([B, M, 1, K])\n        out2 = x[..., 2, :].reshape(dims) * cos.reshape([B, M, 1, K]) + y[\n            ..., 2, :\n        ].reshape(dims) * sin.reshape([B, M, 1, K])\n\n        naive_out = np.concatenate((out0, out1, out2), axis=-1)\n    elif x_layout == \"MB(H3K)\":\n        out0 = x[..., 0, :].reshape(dims) * cos.transpose([2, 0, 1, 3]).reshape(\n            [M, B, 1, K]\n        ) + y[..., 0, :].reshape(dims) * sin.transpose([2, 0, 1, 3]).reshape(\n            [M, B, 1, K]\n        )\n        out1 = x[..., 1, :].reshape(dims) * cos.transpose([2, 0, 1, 3]).reshape(\n            [M, B, 1, K]\n        ) + y[..., 1, :].reshape(dims) * sin.transpose([2, 0, 1, 3]).reshape(\n            [M, B, 1, K]\n        )\n        out2 = x[..., 2, :].reshape(dims) * cos.transpose([2, 0, 1, 3]).reshape(\n            [M, B, 1, K]\n        ) + y[..., 2, :].reshape(dims) * sin.transpose([2, 0, 1, 3]).reshape(\n            [M, B, 1, K]\n        )\n\n        naive_out = np.concatenate((out0, out1, out2), axis=-1)\n\n    return naive_out\n\n\n# this assume that rotary_ndims is by default 1\ndef _test_without_position(\n    test_case, x_layout, mode, base, rotary_size, dims, rotary_ndims, dtype, device\n):\n    B, M, H, K, merged_dims = parseDims(dims, x_layout)\n\n    np.random.seed(3124)\n\n    x = np.random.uniform(low=-1, high=1, size=(*merged_dims,))\n    naive_cos = np.array(\n        [\n            [\n                [\n                    math.cos(\n                        m\n                        * (\n                            (1 / base)\n                            ** (\n                                2\n                                * ((i % (rotary_size / rotary_ndims)) // 2)\n                                / (rotary_size / rotary_ndims)\n                            )\n                        )\n                    )\n                    for i in range(K)\n                ]\n                for m in range(M)\n            ]\n            for b in range(B)\n        ]\n    ).reshape(B, 1, M, K)\n    naive_sin = np.array(\n        [\n            [\n                [\n                    math.sin(\n                        m\n                        * (\n                            (1 / base)\n                            ** (\n                                2\n                                * ((i % (rotary_size / rotary_ndims)) // 2)\n                                / (rotary_size / rotary_ndims)\n                            )\n                        )\n                    )\n                    for i in range(K)\n                ]\n                for m in range(M)\n            ]\n            for b in range(B)\n        ]\n    ).reshape(B, 1, M, K)\n\n    naive_cos[..., rotary_size:] = 1\n    naive_sin[..., rotary_size:] = 0\n\n    naive_x = x\n    if x_layout == 
\"BM(HK)\" or x_layout == \"BM(H2K)\" or x_layout == \"BM(H3K)\":\n        naive_x = x.reshape([B, M, H, -1, K])\n    elif x_layout == \"MB(HK)\" or x_layout == \"MB(H2K)\" or x_layout == \"MB(H3K)\":\n        naive_x = x.reshape([M, B, H, -1, K])\n\n    naive_out = naive_embedding(\n        naive_x,\n        naive_cos,\n        naive_sin,\n        x_layout,\n        B,\n        M,\n        H,\n        K,\n        dims,\n        merged_dims,\n        rotary_size,\n        rotary_ndims,\n        mode,\n    )\n\n    fused_cos = np.array(\n        [\n            [\n                math.cos(\n                    m\n                    * (\n                        (1 / base)\n                        ** (\n                            2\n                            * ((i % (rotary_size // rotary_ndims)) // 2)\n                            / (rotary_size / rotary_ndims)\n                        )\n                    )\n                )\n                for i in range(rotary_size // rotary_ndims)\n            ]\n            for m in range(M)\n        ]\n    ).reshape(M, rotary_size // rotary_ndims)\n    fused_sin = np.array(\n        [\n            [\n                math.sin(\n                    m\n                    * (\n                        (1 / base)\n                        ** (\n                            2\n                            * ((i % (rotary_size // rotary_ndims)) // 2)\n                            / (rotary_size // rotary_ndims)\n                        )\n                    )\n                )\n                for i in range(rotary_size // rotary_ndims)\n            ]\n            for m in range(M)\n        ]\n    ).reshape(M, rotary_size // rotary_ndims)\n    fused_x = flow.tensor(x, dtype=dtype, device=device)\n    fused_cos = flow.tensor(fused_cos, dtype=dtype, device=device)\n    fused_sin = flow.tensor(fused_sin, dtype=dtype, device=device)\n\n    if x_layout == \"BM(H3K)\":\n        out0 = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=fused_cos,\n            sin=fused_sin,\n            position_ids=None,\n            x_layout=x_layout,\n            output_layout=\"BMHK\",\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n            tensor_index=0,\n        )\n        out1 = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=fused_cos,\n            sin=fused_sin,\n            position_ids=None,\n            x_layout=x_layout,\n            output_layout=\"BMHK\",\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n            tensor_index=1,\n        )\n        out2 = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=fused_cos,\n            sin=fused_sin,\n            position_ids=None,\n            x_layout=x_layout,\n            output_layout=\"BMHK\",\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n            tensor_index=2,\n        )\n\n        fused_out = np.concatenate((out0, out1, out2), axis=-1)\n    else:\n        fused_out = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=fused_cos,\n            sin=fused_sin,\n            position_ids=None,\n            x_layout=x_layout,\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n        ).numpy()\n\n    test_case.assertTrue(\n        np.allclose(\n            
naive_out.reshape(merged_dims),\n            fused_out.reshape(merged_dims),\n            atol=5e-2,\n            rtol=5e-3,\n        )\n    )\n\n\n# this assume that rotary_ndims is by default 1\ndef _test_without_position_sinuous(\n    test_case, x_layout, mode, base, rotary_size, dims, rotary_ndims, dtype, device\n):\n    B, M, H, K, merged_dims = parseDims(dims, x_layout)\n\n    x = np.random.uniform(low=-1, high=1, size=(*merged_dims,))\n    naive_cos = np.array(\n        [\n            [\n                [\n                    math.cos(\n                        m\n                        * (\n                            (1 / base)\n                            ** (\n                                2\n                                * ((i % (rotary_size // rotary_ndims)) // 2)\n                                / (rotary_size // rotary_ndims)\n                            )\n                        )\n                    )\n                    for i in range(K)\n                ]\n                for m in range(M)\n            ]\n            for b in range(B)\n        ]\n    ).reshape(B, 1, M, K)\n    naive_sin = np.array(\n        [\n            [\n                [\n                    math.sin(\n                        m\n                        * (\n                            (1 / base)\n                            ** (\n                                2\n                                * ((i % (rotary_size // rotary_ndims)) // 2)\n                                / (rotary_size // rotary_ndims)\n                            )\n                        )\n                    )\n                    for i in range(K)\n                ]\n                for m in range(M)\n            ]\n            for b in range(B)\n        ]\n    ).reshape(B, 1, M, K)\n\n    naive_cos[..., rotary_size:] = 1\n    naive_sin[..., rotary_size:] = 0\n\n    naive_x = x\n    if x_layout == \"BM(HK)\" or x_layout == \"BM(H2K)\" or x_layout == \"BM(H3K)\":\n        naive_x = x.reshape([B, M, H, -1, K])\n    elif x_layout == \"MB(HK)\" or x_layout == \"MB(H2K)\" or x_layout == \"MB(H3K)\":\n        naive_x = x.reshape([M, B, H, -1, K])\n\n    naive_out = naive_embedding(\n        naive_x,\n        naive_cos,\n        naive_sin,\n        x_layout,\n        B,\n        M,\n        H,\n        K,\n        dims,\n        merged_dims,\n        rotary_size,\n        rotary_ndims,\n        mode,\n    )\n\n    fused_x = flow.tensor(x, dtype=dtype, device=device)\n\n    if x_layout == \"BM(H3K)\":\n        out0 = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=None,\n            sin=None,\n            position_ids=None,\n            x_layout=x_layout,\n            output_layout=\"BMHK\",\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n            tensor_index=0,\n        )\n        out1 = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=None,\n            sin=None,\n            position_ids=None,\n            x_layout=x_layout,\n            output_layout=\"BMHK\",\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n            tensor_index=1,\n        )\n        out2 = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=None,\n            sin=None,\n            position_ids=None,\n            x_layout=x_layout,\n            output_layout=\"BMHK\",\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n           
 mode=mode,\n            tensor_index=2,\n        )\n\n        fused_out = np.concatenate((out0, out1, out2), axis=-1)\n    else:\n        fused_out = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=None,\n            sin=None,\n            position_ids=None,\n            x_layout=x_layout,\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n        ).numpy()\n\n    test_case.assertTrue(\n        np.allclose(\n            naive_out.reshape(merged_dims),\n            fused_out.reshape(merged_dims),\n            atol=5e-2,\n            rtol=5e-3,\n        )\n    )\n\n\ndef _test_with_position_sinuous(\n    test_case, x_layout, mode, base, rotary_size, dims, rotary_ndims, dtype, device\n):\n    B, M, H, K, merged_dims = parseDims(dims, x_layout)\n\n    np.random.seed(3124)\n\n    x = np.random.uniform(low=-1, high=1, size=(*merged_dims,))\n\n    position_ids = np.random.randint(2 * M, size=(B, rotary_ndims, M), dtype=np.int64)\n\n    naive_cos = np.array(\n        [\n            [\n                [\n                    math.cos(\n                        position_ids[b, i // ((rotary_size) // rotary_ndims), m]\n                        * (\n                            (1 / base)\n                            ** (\n                                2\n                                * ((i % (rotary_size // rotary_ndims)) // 2)\n                                / (rotary_size // rotary_ndims)\n                            )\n                        )\n                    )\n                    if i < rotary_size\n                    else 1\n                    for i in range(K)\n                ]\n                for m in range(M)\n            ]\n            for b in range(B)\n        ]\n    ).reshape(B, 1, M, K)\n\n    naive_sin = np.array(\n        [\n            [\n                [\n                    math.sin(\n                        position_ids[b, i // ((rotary_size) // rotary_ndims), m]\n                        * (\n                            (1 / base)\n                            ** (\n                                2\n                                * ((i % (rotary_size // rotary_ndims)) // 2)\n                                / (rotary_size // rotary_ndims)\n                            )\n                        )\n                    )\n                    if i < rotary_size\n                    else 0\n                    for i in range(K)\n                ]\n                for m in range(M)\n            ]\n            for b in range(B)\n        ]\n    ).reshape(B, 1, M, K)\n\n    naive_cos[..., rotary_size:] = 1\n    naive_sin[..., rotary_size:] = 0\n\n    naive_x = x\n    if x_layout == \"BM(HK)\" or x_layout == \"BM(H2K)\" or x_layout == \"BM(H3K)\":\n        naive_x = x.reshape([B, M, H, -1, K])\n    elif x_layout == \"MB(HK)\" or x_layout == \"MB(H2K)\" or x_layout == \"MB(H3K)\":\n        naive_x = x.reshape([M, B, H, -1, K])\n\n    naive_out = naive_embedding(\n        naive_x,\n        naive_cos,\n        naive_sin,\n        x_layout,\n        B,\n        M,\n        H,\n        K,\n        dims,\n        merged_dims,\n        rotary_size,\n        rotary_ndims,\n        mode,\n    )\n\n    fused_cos = np.array(\n        [\n            [\n                math.cos(\n                    m\n                    * (\n                        (1 / base)\n                        ** (\n                            2\n                            * ((i % (rotary_size // rotary_ndims)) // 2)\n                
            / (rotary_size // rotary_ndims)\n                        )\n                    )\n                )\n                for i in range(rotary_size // rotary_ndims)\n            ]\n            for m in range(2 * M)\n        ]\n    )\n    fused_sin = np.array(\n        [\n            [\n                math.sin(\n                    m\n                    * (\n                        (1 / base)\n                        ** (\n                            2\n                            * ((i % (rotary_size // rotary_ndims)) // 2)\n                            / (rotary_size // rotary_ndims)\n                        )\n                    )\n                )\n                for i in range(rotary_size // rotary_ndims)\n            ]\n            for m in range(2 * M)\n        ]\n    )\n\n    fused_x = flow.tensor(x, dtype=dtype, device=device)\n    fused_cos = flow.tensor(fused_cos, dtype=dtype, device=device)\n    fused_sin = flow.tensor(fused_sin, dtype=dtype, device=device)\n    fused_position_ids = flow.tensor(position_ids, dtype=flow.int32, device=device)\n\n    if x_layout == \"BM(H3K)\":\n        out0 = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=fused_cos,\n            sin=fused_sin,\n            position_ids=fused_position_ids,\n            x_layout=x_layout,\n            output_layout=\"BMHK\",\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n            tensor_index=0,\n        )\n        out1 = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=fused_cos,\n            sin=fused_sin,\n            position_ids=fused_position_ids,\n            x_layout=x_layout,\n            output_layout=\"BMHK\",\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n            tensor_index=1,\n        )\n        out2 = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=fused_cos,\n            sin=fused_sin,\n            position_ids=fused_position_ids,\n            x_layout=x_layout,\n            output_layout=\"BMHK\",\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n            tensor_index=2,\n        )\n\n        fused_out = np.concatenate((out0, out1, out2), axis=-1)\n    else:\n        fused_out = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=fused_cos,\n            sin=fused_sin,\n            position_ids=fused_position_ids,\n            x_layout=x_layout,\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n        ).numpy()\n\n    test_case.assertTrue(\n        np.allclose(\n            naive_out.reshape(merged_dims),\n            fused_out.reshape(merged_dims),\n            atol=5e-2,\n            rtol=5e-3,\n        )\n    )\n\n\ndef _test_with_position(\n    test_case, x_layout, mode, base, rotary_size, dims, rotary_ndims, dtype, device\n):\n    B, M, H, K, merged_dims = parseDims(dims, x_layout)\n\n    x = np.random.uniform(low=-1, high=1, size=(*merged_dims,))\n\n    position_ids = np.random.randint(2 * M, size=(B, rotary_ndims, M), dtype=int)\n\n    naive_cos = np.array(\n        [\n            [\n                [\n                    math.cos(\n                        position_ids[b, i // ((rotary_size) // rotary_ndims), m]\n                        * (\n                            (1 / base)\n                            ** (\n                            
    2\n                                * ((i % (rotary_size / rotary_ndims)) // 2)\n                                / (rotary_size / rotary_ndims)\n                            )\n                        )\n                    )\n                    if i < rotary_size\n                    else 1\n                    for i in range(K)\n                ]\n                for m in range(M)\n            ]\n            for b in range(B)\n        ]\n    ).reshape(B, 1, M, K)\n\n    naive_sin = np.array(\n        [\n            [\n                [\n                    math.sin(\n                        position_ids[b, i // ((rotary_size) // rotary_ndims), m]\n                        * (\n                            (1 / base)\n                            ** (\n                                2\n                                * ((i % (rotary_size / rotary_ndims)) // 2)\n                                / (rotary_size / rotary_ndims)\n                            )\n                        )\n                    )\n                    if i < rotary_size\n                    else 0\n                    for i in range(K)\n                ]\n                for m in range(M)\n            ]\n            for b in range(B)\n        ]\n    ).reshape(B, 1, M, K)\n\n    naive_x = x\n    if x_layout == \"BM(HK)\" or x_layout == \"BM(H2K)\" or x_layout == \"BM(H3K)\":\n        naive_x = x.reshape([B, M, H, -1, K])\n    elif x_layout == \"MB(HK)\" or x_layout == \"MB(H2K)\" or x_layout == \"MB(H3K)\":\n        naive_x = x.reshape([M, B, H, -1, K])\n\n    naive_out = naive_embedding(\n        naive_x,\n        naive_cos,\n        naive_sin,\n        x_layout,\n        B,\n        M,\n        H,\n        K,\n        dims,\n        merged_dims,\n        rotary_size,\n        rotary_ndims,\n        mode,\n    )\n\n    fused_x = flow.tensor(x, dtype=dtype, device=device)\n    fused_position_ids = flow.tensor(position_ids, dtype=flow.int32, device=device)\n\n    if x_layout == \"BM(H3K)\":\n        out0 = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=None,\n            sin=None,\n            position_ids=fused_position_ids,\n            x_layout=x_layout,\n            output_layout=\"BMHK\",\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n            tensor_index=0,\n        )\n        out1 = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=None,\n            sin=None,\n            position_ids=fused_position_ids,\n            x_layout=x_layout,\n            output_layout=\"BMHK\",\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n            tensor_index=1,\n        )\n        out2 = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=None,\n            sin=None,\n            position_ids=fused_position_ids,\n            x_layout=x_layout,\n            output_layout=\"BMHK\",\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n            tensor_index=2,\n        )\n\n        fused_out = np.concatenate((out0, out1, out2), axis=-1)\n    else:\n        fused_out = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=None,\n            sin=None,\n            position_ids=fused_position_ids,\n            x_layout=x_layout,\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n        ).numpy()\n\n    
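# note: the fused kernel runs in the test dtype (float16 in these tests) while the reference is float64 numpy, hence the loose tolerances\n    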
test_case.assertTrue(\n        np.allclose(\n            naive_out.reshape(merged_dims),\n            fused_out.reshape(merged_dims),\n            atol=5e-2,\n            rtol=5e-3,\n        )\n    )\n\n\ndef _test_plane(\n    test_case, x_layout, mode, base, rotary_size, dims, rotary_ndims, dtype, device\n):\n    B, M, H, K, merged_dims = parseDims(dims, x_layout)\n\n    np.random.seed(3124)\n\n    x = np.random.uniform(low=-1, high=1, size=(*merged_dims,))\n\n    position_ids = np.random.randint(2 * M, size=(B, rotary_ndims, M), dtype=int)\n\n    naive_cos = np.array(\n        [\n            [\n                [\n                    math.cos(\n                        position_ids[b, i // ((rotary_size) // rotary_ndims), m]\n                        * (\n                            1\n                            / (\n                                base\n                                ** (\n                                    2\n                                    * (i % (rotary_size // (2 * rotary_ndims)))\n                                    / (rotary_size / rotary_ndims)\n                                )\n                            )\n                        )\n                    )\n                    if i < rotary_size\n                    else 1\n                    for i in range(K)\n                ]\n                for m in range(M)\n            ]\n            for b in range(B)\n        ]\n    ).reshape(B, 1, M, K)\n\n    naive_sin = np.array(\n        [\n            [\n                [\n                    math.sin(\n                        position_ids[b, i // ((rotary_size) // rotary_ndims), m]\n                        * (\n                            1\n                            / (\n                                base\n                                ** (\n                                    2\n                                    * (i % (rotary_size // (2 * rotary_ndims)))\n                                    / (rotary_size / rotary_ndims)\n                                )\n                            )\n                        )\n                    )\n                    if i < rotary_size\n                    else 0\n                    for i in range(K)\n                ]\n                for m in range(M)\n            ]\n            for b in range(B)\n        ]\n    ).reshape(B, 1, M, K)\n\n    naive_x = x\n    if x_layout == \"BM(HK)\" or x_layout == \"BM(H2K)\" or x_layout == \"BM(H3K)\":\n        naive_x = x.reshape([B, M, H, -1, K])\n    elif x_layout == \"MB(HK)\" or x_layout == \"MB(H2K)\" or x_layout == \"MB(H3K)\":\n        naive_x = x.reshape([M, B, H, -1, K])\n\n    naive_out = naive_embedding(\n        naive_x,\n        naive_cos,\n        naive_sin,\n        x_layout,\n        B,\n        M,\n        H,\n        K,\n        dims,\n        merged_dims,\n        rotary_size,\n        rotary_ndims,\n        mode,\n    )\n\n    fused_x = flow.tensor(x, dtype=dtype, device=device)\n    fused_position_ids = flow.tensor(position_ids, dtype=flow.int32, device=device)\n\n    if x_layout == \"MB(H3K)\":\n        out0 = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=None,\n            sin=None,\n            position_ids=fused_position_ids,\n            x_layout=x_layout,\n            output_layout=\"MBHK\",\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n            tensor_index=0,\n        )\n        out1 = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            
cos=None,\n            sin=None,\n            position_ids=fused_position_ids,\n            x_layout=x_layout,\n            output_layout=\"MBHK\",\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n            tensor_index=1,\n        )\n        out2 = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=None,\n            sin=None,\n            position_ids=fused_position_ids,\n            x_layout=x_layout,\n            output_layout=\"MBHK\",\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n            tensor_index=2,\n        )\n\n        fused_out = np.concatenate((out0, out1, out2), axis=-1)\n    else:\n        fused_out = flow._C.fused_apply_rotary_emb(\n            fused_x,\n            cos=None,\n            sin=None,\n            position_ids=fused_position_ids,\n            x_layout=x_layout,\n            k_size=K,\n            base=base,\n            rotary_size=rotary_size,\n            mode=mode,\n        ).numpy()\n\n    test_case.assertTrue(\n        np.allclose(\n            naive_out.reshape(merged_dims),\n            fused_out.reshape(merged_dims),\n            atol=5e-2,\n            rtol=5e-3,\n        )\n    )\n\n\n\"\"\"\n1. If cos & sin are given, base is not used.\n2. If cos & sin are not given, any form of x_layout from which the dimension of k cannot be inferred, e.g. BM(HK), is not allowed.\n3. If position_ids is given, the M of cos & sin may differ from the M of x.\n4. If position_ids is not given, the number of rotary dimensions defaults to 1.\n\"\"\"\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFusedRotaryEmbedding(flow.unittest.TestCase):\n    # because of rule no. 2, kernels without cos & sin cannot work with certain x_layouts\n    def test_fused_rotary_embedding_op_plane(test_case):\n        args_dict = OrderedDict()\n        args_dict[\"test_fun\"] = [_test_plane]\n        args_dict[\"x_layout\"] = [\"MB(H3K)\", \"MB(HK)\"]\n        args_dict[\"mode\"] = [\"plane\"]\n        args_dict[\"base\"] = [1e1]\n        args_dict[\"rotary_size\"] = [4, 8]\n        args_dict[\"dims\"] = [(3, 2, 5, 8)]\n        args_dict[\"rotary_ndims\"] = [2, 1]\n        # args_dict[\"rotary_size\"] = [48]\n        # args_dict[\"dims\"] = [(32, 2048, 32, 64)]\n        args_dict[\"dtype\"] = [flow.float16]\n        args_dict[\"device\"] = [\"cuda\"]\n\n        for arg in GenArgList(args_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_fused_rotary_embedding_op_interval_2d(test_case):\n        args_dict = OrderedDict()\n        args_dict[\"test_fun\"] = [_test_with_position, _test_with_position_sinuous]\n        args_dict[\"x_layout\"] = [\"BMHK\"]\n        args_dict[\"mode\"] = [\"interval\"]\n        args_dict[\"base\"] = [1e1]\n        args_dict[\"rotary_size\"] = [4]\n        args_dict[\"dims\"] = [(3, 2, 5, 8)]\n        args_dict[\"rotary_ndims\"] = [2]\n        # args_dict[\"rotary_size\"] = [48]\n        # args_dict[\"dims\"] = [(32, 2048, 32, 64)]\n        args_dict[\"dtype\"] = [flow.float16]\n        args_dict[\"device\"] = [\"cuda\"]\n\n        for arg in GenArgList(args_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_fused_rotary_embedding_op_interval_1d(test_case):\n        args_dict = OrderedDict()\n        args_dict[\"test_fun\"] = [\n            _test_without_position_sinuous,\n            _test_without_position,\n            _test_with_position,\n            _test_with_position_sinuous,\n 
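           # 1-D rotary case (rule no. 4): exercises the helpers both with and without precomputed cos/sin and position_ids\n 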
       ]\n        args_dict[\"x_layout\"] = [\"BMHK\"]\n        args_dict[\"mode\"] = [\"interval\"]\n        args_dict[\"base\"] = [1e1]\n        args_dict[\"rotary_size\"] = [4]\n        args_dict[\"dims\"] = [(3, 2, 5, 8)]\n        args_dict[\"rotary_ndims\"] = [1]\n        # args_dict[\"rotary_size\"] = [48]\n        # args_dict[\"dims\"] = [(32, 2048, 32, 64)]\n        args_dict[\"dtype\"] = [flow.float16]\n        args_dict[\"device\"] = [\"cuda\"]\n\n        for arg in GenArgList(args_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_scale_mask_bias_softmax.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nimport os\nfrom typing import List\nimport time\n\nimport oneflow as flow\n\n\ndef timing(fn):\n    def wrapper(*args, **kwargs):\n        if args[-1] or kwargs.get(\"inplace\"):\n            return fn(*args, **kwargs)\n        for _ in range(10):\n            fn(*args, **kwargs)\n        flow.cuda.synchronize()\n        start = time.perf_counter()\n        for _ in range(10):\n            fn(*args, **kwargs)\n        flow.cuda.synchronize()\n        print(f\"{fn.__name__}:{time.perf_counter() - start}\")\n        return fn(*args, **kwargs)\n\n    return wrapper\n\n\ndef permute_final_dims(tensor: flow.Tensor, inds: List[int]):\n    zero_index = -1 * len(inds)\n    first_inds = list(range(len(tensor.shape[:zero_index])))\n    return tensor.permute(first_inds + [zero_index + i for i in inds])\n\n\n@timing\ndef _fused_op(x, v, scale, mask, bias, inplace=False):\n    out = flow._C.fused_scale_mask_bias_softmax(x, mask, bias, scale, inplace=inplace)\n    out = flow.matmul(out, v)\n    return out\n\n\n@timing\ndef _ref_op(x, v, scale, mask, bias=None, inplace=False):\n    x = x * scale + mask + bias if bias is not None else x * scale + mask\n    out = flow.softmax(x, dim=-1)\n    out = flow.matmul(out, v)\n    return out\n\n\ndef _test_fused_scale_mask_bias_softmax(\n    test_case,\n    N=512,\n    S=128,\n    D=128,\n    h=8,\n    d=32,\n    mode=\"row\",\n    ensemble_batch=8,\n    inplace=False,\n):\n    x = flow.randn(N, S, D, requires_grad=True).cuda()  # N, S, D\n    w3 = [flow.randn(D, h * d, requires_grad=True).cuda() for _ in range(3)]  # D, h*d*3\n    mask = flow.randn(N, S, requires_grad=False).cuda()  # N, S\n    bias = None\n    scale = 1 / (d ** 0.5)\n    if mode in [\"row\", \"triangular_start\", \"triangular_end\"]:\n        bias = flow.randn(1, h, S, S, requires_grad=True).cuda()  # 1, h, S, S\n        bias.retain_grad()\n        mask = mask[:, None, None, :]\n    if mode == \"ensemble\":\n        x = flow.randn(ensemble_batch, N, S, D, requires_grad=True).cuda()  # N, S, D\n        bias = flow.randn(\n            ensemble_batch, 1, h, S, S, requires_grad=True\n        ).cuda()  # E, 1, h, S, S\n        bias.retain_grad()\n        mask = flow.randn(ensemble_batch, N, 1, 1, S, requires_grad=False).cuda()\n    if mode == \"col\" or mode == \"global_col\":\n        N, S = S, N\n        x = x.transpose(-2, -3)  # S, N, D\n        mask = mask.transpose(-1, -2)\n        if mode == \"col\":\n            mask = mask[..., None, None, :]  # S, 1, 1, N\n    q, k, v = [flow.matmul(x, w) for w in w3]  # N, S, h * d\n    if mode == \"template\":\n        n_templ = 4\n        x = flow.randn(S, S, 1, D, requires_grad=True).cuda()\n        k = v = flow.randn(S, S, n_templ, D, requires_grad=True).cuda()  # N, S, D\n        mask = flow.randn(1, 1, 1, 
1, n_templ).cuda()\n        q, k, v = [flow.matmul(x_, w) for x_, w in zip([x, k, v], w3)]\n\n    q, k, v = [\n        permute_final_dims(a.view(*a.shape[:-1], h, d), (0, 2, 1, 3)) for a in [q, k, v]\n    ]  # N, h, S, d\n\n    if mode == \"global_col\":\n        w_q = flow.randn(D, h * d, requires_grad=True).cuda()  # D, h*d\n        w_kv = flow.randn(D, d * 2, requires_grad=True).cuda()  # D, h*d*2\n        q = flow.sum(x * mask.unsqueeze(-1), dim=-2) / (\n            flow.sum(mask, dim=-1)[..., None] + 1e-9\n        )  # [N, D]\n        mask = mask[..., :, None, :]  # N,1,S\n        q = flow.matmul(q, w_q).view(*q.shape[:-1], h, d)  # N, h, d\n        k, v = flow.matmul(x, w_kv).chunk(2, dim=-1)  # N, S, d\n    qk = flow.matmul(q, k.transpose(-1, -2))\n\n    # general op\n    x.retain_grad()\n    out1 = _ref_op(qk, v, scale, mask, bias, inplace)\n    out1.sum().backward(retain_graph=True)\n    grad_x1 = x.grad.clone()\n    grad_bias1 = bias.grad.clone() if bias is not None else None\n\n    # clear the accumulated grads so the fused backward pass is checked on its own\n    x.grad.zero_()\n    if bias is not None:\n        bias.grad.zero_()\n\n    # fused op\n    out2 = _fused_op(qk, v, scale, mask, bias, inplace)\n    out2.sum().backward()\n    grad_x2 = x.grad\n    grad_bias2 = bias.grad if bias is not None else None\n    test_case.assertTrue(np.allclose(out1, out2, atol=2e-3, rtol=1e-5))\n    test_case.assertTrue(np.allclose(grad_x1, grad_x2, atol=5e-3, rtol=1e-5))\n\n    if bias is not None:\n        test_case.assertTrue(np.allclose(grad_bias1, grad_bias2, atol=5e-4, rtol=1e-5))\n\n\n@unittest.skipIf(True, \"skip test for fused_scale_mask_bias_softmax.\")\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test gpu cases\")\nclass TestFusedMsaSoftmax(flow.unittest.TestCase):\n    def test_fused_msa_softmax(test_case):\n        # different mask shape for each mode\n        _test_fused_scale_mask_bias_softmax(test_case, 16, 128, 64, 8, 32, \"row\")\n        _test_fused_scale_mask_bias_softmax(test_case, 16, 128, 64, 8, 32, \"col\")\n        _test_fused_scale_mask_bias_softmax(\n            test_case, 16, 128, 64, 8, 32, \"triangular_start\"\n        )\n        _test_fused_scale_mask_bias_softmax(\n            test_case, 16, 128, 64, 8, 32, \"triangular_end\"\n        )\n        _test_fused_scale_mask_bias_softmax(test_case, 16, 128, 64, 8, 32, \"template\")\n        _test_fused_scale_mask_bias_softmax(test_case, 16, 128, 64, 8, 32, \"global_col\")\n\n        _test_fused_scale_mask_bias_softmax(test_case, 16, 128, 64, 8, 32, \"ensemble\")\n\n        _test_fused_scale_mask_bias_softmax(\n            test_case, 16, 128, 64, 8, 32, \"row\", inplace=True\n        )\n        _test_fused_scale_mask_bias_softmax(\n            test_case, 16, 128, 64, 8, 32, \"col\", inplace=True\n        )\n        _test_fused_scale_mask_bias_softmax(\n            test_case, 128, 128, 64, 8, 32, \"triangular_start\", inplace=True\n        )\n        _test_fused_scale_mask_bias_softmax(\n            test_case, 16, 128, 64, 8, 32, \"triangular_end\", inplace=True\n        )\n        _test_fused_scale_mask_bias_softmax(\n            test_case, 16, 128, 64, 8, 32, \"template\", inplace=True\n        )\n        _test_fused_scale_mask_bias_softmax(\n            test_case, 16, 128, 64, 8, 32, \"global_col\", inplace=True\n        )\n        _test_fused_scale_mask_bias_softmax(\n            test_case, 16, 128, 64, 8, 32, \"ensemble\", inplace=True\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_scale_mask_softmax.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport os\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_fused_scale_mask_softmax(\n    test_case, batch_size, num_heads, seq_length, fill_value, scale_value, broadcast_dim\n):\n    x = np.random.randn(batch_size, num_heads, seq_length, seq_length).astype(\n        np.float32\n    )\n    mask_size = [batch_size, num_heads, seq_length, seq_length]\n    if broadcast_dim:\n        mask_size[broadcast_dim] = 1\n\n    mask = np.random.randint(0, 2, size=mask_size, dtype=bool)\n    fused_x_tensor = flow.tensor(x, dtype=flow.float32).to(\"cuda\")\n    fused_mask_tensor = flow.tensor(mask, dtype=flow.bool).to(\"cuda\")\n    fused_x_tensor.requires_grad = True\n\n    fused_out = flow._C.fused_scale_mask_softmax(\n        fused_x_tensor, fused_mask_tensor, fill_value=fill_value, scale=scale_value,\n    )\n\n    origin_x_tensor = flow.tensor(x).to(\"cuda\")\n    origin_mask_tensor = flow.tensor(mask, dtype=flow.float32).to(\"cuda\")\n    origin_x_tensor.requires_grad = True\n    origin_out = flow.mul(\n        origin_x_tensor, origin_mask_tensor\n    ) * scale_value + fill_value * (1.0 - origin_mask_tensor)\n    origin_out = flow.softmax(origin_out, dim=-1)\n\n    total_out = fused_out.sum() + origin_out.sum()\n    total_out.backward()\n\n    test_case.assertTrue(\n        np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4)\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_x_tensor.grad.numpy(),\n            origin_x_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test gpu cases\")\nclass TestFusedScaleMaskSoftmax(flow.unittest.TestCase):\n    def test_fused_op(test_case):\n        args_dict = OrderedDict()\n        args_dict[\"test_fun\"] = [_test_fused_scale_mask_softmax]\n        args_dict[\"batch_size\"] = [4, 8, 16]\n        args_dict[\"num_heads\"] = [1, 4, 8]\n        args_dict[\"seq_length\"] = [16, 32, 64]\n        args_dict[\"fill_value\"] = [-10000.0]\n        args_dict[\"scale_value\"] = [1.0, 2.0, 4.0]\n        args_dict[\"broadcast_dim\"] = [None, 0, 1, 2]\n\n        for arg in GenArgList(args_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_scale_mask_softmax_dropout.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n\nimport unittest\nfrom collections import OrderedDict\nimport os\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_fused_scale_mask_softmax_dropout(\n    test_case,\n    batch_size,\n    num_heads,\n    seq_length,\n    fill_value,\n    scale_value,\n    broadcast_dim,\n    p,\n):\n    x = np.random.randn(batch_size, num_heads, seq_length, seq_length)\n    mask_size = [batch_size, num_heads, seq_length, seq_length]\n    if broadcast_dim:\n        mask_size[broadcast_dim] = 1\n    mask = np.random.randint(0, 2, size=mask_size, dtype=bool)\n\n    fused_x_tensor = flow.tensor(x, dtype=flow.float32).to(\"cuda\")\n    fused_mask_tensor = flow.tensor(mask, dtype=flow.bool).to(\"cuda\")\n    fused_x_tensor.requires_grad = True\n\n    # if mask is zero, fill it\n    fused_out = flow._C.fused_scale_mask_softmax_dropout(\n        fused_x_tensor,\n        fused_mask_tensor,\n        fill_value=fill_value,\n        scale=scale_value,\n        p=p,\n    )[0]\n\n    origin_x_tensor = flow.tensor(x, dtype=flow.float32).to(\"cuda\")\n    origin_mask_tensor = flow.tensor(mask, dtype=flow.float32).to(\"cuda\")\n    origin_x_tensor.requires_grad = True\n    origin_out = flow.mul(\n        origin_x_tensor, origin_mask_tensor\n    ) * scale_value + fill_value * (1.0 - origin_mask_tensor)\n    origin_out = flow.softmax(origin_out, dim=-1)\n    origin_out = flow._C.dropout(origin_out, p=p)\n\n    total_out = fused_out.sum() + origin_out.sum()\n    total_out.backward()\n\n    test_case.assertTrue(\n        np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4)\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_x_tensor.grad.numpy(),\n            origin_x_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test gpu cases\")\nclass TestFusedScaleMaskSoftmaxDropout(flow.unittest.TestCase):\n    def test_fused_op(test_case):\n        args_dict = OrderedDict()\n        args_dict[\"test_fun\"] = [_test_fused_scale_mask_softmax_dropout]\n        args_dict[\"batch_size\"] = [4, 8, 16]\n        args_dict[\"num_heads\"] = [1, 4, 8]\n        args_dict[\"seq_length\"] = [8, 16, 32, 64]\n        args_dict[\"fill_value\"] = [-10000.0]\n        args_dict[\"scale_value\"] = [1.0, 2.0, 4.0]\n        args_dict[\"broadcast_dim\"] = [None, 0, 1, 2]\n        args_dict[\"p\"] = [0.0, 1.0]\n\n        for arg in GenArgList(args_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_scale_tril.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport os\nimport numpy as np\nfrom collections import OrderedDict\n\nfrom oneflow.test_utils.test_util import GenArgDict\n\nimport oneflow as flow\n\n\ndef _np_tril(x, diagonal, fill_value, scale):\n    if int(fill_value) == 0:\n        return np.tril(x, diagonal) * scale\n\n    upper = np.empty(x.shape)\n    upper.fill(fill_value)\n    upper = np.triu(upper, diagonal + 1)\n\n    return np.tril(x, diagonal) * scale + upper\n\n\ndef _test_fused_scale_tril(\n    test_case,\n    shape,\n    diagonal=0,\n    fill_value=0,\n    scale=1,\n    dtype=flow.float32,\n    device_type=\"cuda\",\n):\n    if dtype is flow.int32 and not isinstance(scale, int):\n        return\n\n    if dtype is flow.int32:\n        x = np.random.randint(0, 10, shape)\n        y_grad = np.random.randint(0, 10, shape)\n    else:\n        x = np.random.rand(*shape)\n        y_grad = np.random.rand(*shape)\n\n    y = _np_tril(x, diagonal, fill_value, scale)\n    x_grad = _np_tril(y_grad, diagonal, 0, scale)\n\n    flow_x = flow.tensor(\n        x, device=flow.device(device_type), dtype=dtype, requires_grad=True\n    )\n    flow_y = flow._C.fused_scale_tril(flow_x, diagonal, fill_value, scale)\n    flow_y_grad = flow.tensor(y_grad, device=flow.device(device_type), dtype=dtype)\n    flow_y.backward(flow_y_grad)\n\n    flow_y_np = flow_y.numpy()\n    test_case.assertTrue(np.allclose(flow_y_np, y.astype(flow_y_np.dtype)))\n\n    flow_x_grad_np = flow_x.grad.numpy()\n    test_case.assertTrue(\n        np.allclose(flow_x_grad_np, x_grad.astype(flow_x_grad_np.dtype))\n    )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass FusedScaleTrilTestCase(flow.unittest.TestCase):\n    def test_fused_scale_tril(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(5, 5), (4, 6)]\n        arg_dict[\"diagonal\"] = [-1, 0, 1]\n        arg_dict[\"fill_value\"] = [-1, 0, 1]\n        arg_dict[\"scale\"] = [-2.3, 0.7, 2]\n        arg_dict[\"dtype\"] = [flow.float32]\n        for kwargs in GenArgDict(arg_dict):\n            _test_fused_scale_tril(test_case, **kwargs)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_self_attention.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_fused_self_attention(test_case, batch_size, seq_len, num_heads, head_size):\n    hidden_size = num_heads * 3 * head_size\n\n    x = np.random.randn(seq_len, batch_size, hidden_size)\n    fused_input = flow.Tensor(x).to(\"cuda\")\n    fused_input.requires_grad = True\n    (fused_qmk, fused_v) = flow._C.fused_self_attention(\n        fused_input, head_size=head_size, alpha=1.0,\n    )\n    fused_atten = flow.matmul(fused_qmk, fused_v)\n    fused_atten_sum = fused_atten.sum()\n\n    origin_input = flow.Tensor(x).to(\"cuda\")\n    origin_input.requires_grad = True\n    reshape_input = flow.reshape(origin_input, (seq_len, batch_size, -1, 3 * head_size))\n\n    origin_q = flow.slice(\n        reshape_input,\n        slice_tup_list=[\n            [None, None, None],\n            [None, None, None],\n            [None, None, None],\n            [0, head_size, 1],\n        ],\n    ).permute(1, 2, 0, 3)\n    origin_k = flow.slice(\n        reshape_input,\n        slice_tup_list=[\n            [None, None, None],\n            [None, None, None],\n            [None, None, None],\n            [head_size, 2 * head_size, 1],\n        ],\n    ).permute(1, 2, 0, 3)\n    origin_v = flow.slice(\n        reshape_input,\n        slice_tup_list=[\n            [None, None, None],\n            [None, None, None],\n            [None, None, None],\n            [2 * head_size, 3 * head_size, 1],\n        ],\n    ).permute(1, 2, 0, 3)\n\n    origin_k = origin_k.transpose(2, 3)\n    origin_qmk = flow.matmul(origin_q, origin_k)\n    origin_atten = flow.matmul(origin_qmk, origin_v)\n    origin_atten_sum = origin_atten.sum()\n\n    total_sum = fused_atten_sum + origin_atten_sum\n    total_sum.backward()\n\n    test_case.assertTrue(\n        np.allclose(fused_atten.numpy(), origin_atten.numpy(), atol=1e-4, rtol=1e-4)\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_input.grad.numpy(), origin_input.grad.numpy(), atol=1e-4, rtol=1e-4,\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestFusedSelfAttention(flow.unittest.TestCase):\n    def _test_fused_self_attention(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_fused_self_attention]\n        arg_dict[\"batch_size\"] = [1, 4, 6, 8]\n        arg_dict[\"seq_len\"] = [5, 10, 12]\n        arg_dict[\"num_heads\"] = [4, 8, 16]\n        arg_dict[\"head_size\"] = [16, 32, 64]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_tril_softmax_mask_scale.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport os\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_fused_tril_softmax_mask_scale(\n    test_case, seq_length, channel, p, diagonal, tril_scale_value\n):\n    x = np.random.randn(4, seq_length, channel)\n    # fused version only support in GPU\n    fused_x_tensor = flow.Tensor(x).to(\"cuda\")\n    fused_x_tensor.requires_grad = True\n    fused_out = flow._C.fused_scale_tril_softmax_mask_scale(\n        fused_x_tensor, p=p, diagonal=diagonal, tril_scale_value=tril_scale_value\n    )[\n        0\n    ]  # The second output is softmax_y\n\n    origin_x_tensor = flow.Tensor(x).to(\"cuda\")\n    origin_x_tensor.requires_grad = True\n    origin_out = flow.tril(origin_x_tensor, diagonal)\n    origin_out = origin_out * tril_scale_value\n    origin_out = flow.softmax(origin_out, dim=-1)\n    origin_out = flow._C.dropout(origin_out, p=p)\n\n    total_out = fused_out.sum() + origin_out.sum()\n    total_out.backward()\n\n    test_case.assertTrue(\n        np.allclose(fused_out.numpy(), origin_out.numpy(), atol=1e-4, rtol=1e-4)\n    )\n    test_case.assertTrue(\n        np.allclose(\n            fused_x_tensor.grad.numpy(),\n            origin_x_tensor.grad.numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test gpu cases\")\nclass TestFusedTrilSoftmaxMaskScale(flow.unittest.TestCase):\n    def test_fused_tril_softmax_dropout(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_fused_tril_softmax_mask_scale]\n        arg_dict[\"seq_length\"] = [10, 20]\n        arg_dict[\"channel\"] = [20, 30]\n        arg_dict[\"p\"] = [0.0, 1.0]\n        arg_dict[\"diagonal\"] = [0, 1, 2]\n        arg_dict[\"tril_scale_value\"] = [2, 4, 10]\n\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_fused_weighted_sum.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nimport math\nimport random\nimport os\n\nimport oneflow as flow\n\n\ndef _ref(inputs, weights, alpha, init_grad, device, dtype):\n    inputs = [flow.tensor(t).to(device).to(dtype) for t in inputs]\n    for t in inputs:\n        t.requires_grad = True\n    init_grad = flow.tensor(init_grad).to(device).to(dtype)\n    out = inputs[0] * weights[0]\n    for i, w in zip(inputs[1:], weights[1:]):\n        out += i * w\n    out = out * alpha\n    out.backward(init_grad)\n    return out, [t.grad for t in inputs]\n\n\ndef _fused_weighted_sum(inputs, weights, alpha, init_grad, device, dtype):\n    inputs = [flow.tensor(t).to(device).to(dtype) for t in inputs]\n    for t in inputs:\n        t.requires_grad = True\n    init_grad = flow.tensor(init_grad).to(device).to(dtype)\n    out = flow._C.fused_weighted_sum(inputs, weights, alpha)\n    out.backward(init_grad)\n    return out, [t.grad for t in inputs]\n\n\ndef _test_fused_weighted_sum(test_case, shape, n, device, dtype):\n    inputs = [np.random.randn(*shape) for _ in range(n)]\n    init_grad = np.random.randn(*shape)\n    weights = [random.random() for _ in range(n)]\n    alpha = random.random()\n    out, grads = _fused_weighted_sum(inputs, weights, alpha, init_grad, device, dtype)\n    ref, ref_grads = _ref(inputs, weights, alpha, init_grad, device, dtype)\n    test_case.assertTrue(np.allclose(ref, out, atol=1e-5, rtol=1e-5))\n    for (grad, ref_grad) in zip(grads, ref_grads):\n        test_case.assertTrue(np.allclose(ref_grad, grad, atol=1e-5, rtol=1e-5))\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestFusedWeightedSum(flow.unittest.TestCase):\n    def test_fused_weighted_sum(test_case):\n        _test_fused_weighted_sum(test_case, (1024, 1024), 1, \"cuda\", flow.float32)\n        _test_fused_weighted_sum(test_case, (1024, 1024), 3, \"cuda\", flow.float32)\n        _test_fused_weighted_sum(test_case, (1024, 1024), 8, \"cuda\", flow.float32)\n        _test_fused_weighted_sum(test_case, (1024, 1024), 11, \"cuda\", flow.float32)\n        _test_fused_weighted_sum(test_case, (1024, 1024), 21, \"cuda\", flow.float32)\n        _test_fused_weighted_sum(test_case, (1024, 1024), 1, \"cpu\", flow.float32)\n        _test_fused_weighted_sum(test_case, (1024, 1024), 3, \"cpu\", flow.float32)\n        _test_fused_weighted_sum(test_case, (1024, 1024), 8, \"cpu\", flow.float32)\n        _test_fused_weighted_sum(test_case, (1024, 1024), 11, \"cpu\", flow.float32)\n        _test_fused_weighted_sum(test_case, (1024, 1024), 21, \"cpu\", flow.float32)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_gather.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _scatter_add_numpy(src, dim, index, outshape):\n    output = np.zeros(outshape)\n    for srcidx in range(0, src.size):\n        outcoord = np.unravel_index(srcidx, src.shape)\n        outcoord = [*outcoord]\n        outcoord[dim] = index[np.unravel_index(srcidx, index.shape)]\n        output_offset = np.ravel_multi_index(outcoord, outshape)\n        output[np.unravel_index(output_offset, outshape)] += src[\n            np.unravel_index(srcidx, src.shape)\n        ]\n    return output\n\n\ndef _test_gather(test_case, device):\n    input = np.array([[1, 2], [3, 4]])\n    index = np.array([[0, 0], [1, 0]])\n    np_out = np.take_along_axis(input, index, 0)\n    output = flow.gather(\n        flow.tensor(input, dtype=flow.float32, device=flow.device(device)),\n        0,\n        flow.tensor(index, dtype=flow.int64, device=flow.device(device)),\n    )\n    test_case.assertTrue(np.array_equal(output.numpy(), np_out))\n\n\ndef _test_gather_tensor_function(test_case, device):\n    input = np.array([[1, 2], [3, 4]])\n    index = np.array([[0, 0], [1, 0]])\n    np_out = np.take_along_axis(input, index, 1)\n    input = flow.tensor(input, dtype=flow.float32, device=flow.device(device))\n    index = flow.tensor(index, dtype=flow.int64, device=flow.device(device))\n    output = input.gather(1, index)\n    test_case.assertTrue(np.array_equal(output.numpy(), np_out))\n\n\ndef _test_gather_random_array(test_case, device):\n    input = np.random.randn(3, 4, 3, 5)\n    index = np.random.choice(np.arange(3), size=180, replace=True).reshape((3, 4, 3, 5))\n    np_out = np.take_along_axis(input, index, 1)\n    output = flow.gather(\n        flow.tensor(input, dtype=flow.float32, device=flow.device(device)),\n        1,\n        flow.tensor(index, dtype=flow.int64, device=flow.device(device)),\n    )\n    test_case.assertTrue(np.allclose(output.numpy(), np_out))\n    np_out2 = np.take_along_axis(input, index, 2)\n    output2 = flow.gather(\n        flow.tensor(input, dtype=flow.float32, device=flow.device(device)),\n        2,\n        flow.tensor(index, dtype=flow.int64, device=flow.device(device)),\n    )\n    test_case.assertTrue(np.allclose(output2.numpy(), np_out2))\n    np_out3 = np.take_along_axis(input, index, 3)\n    output3 = flow.gather(\n        flow.tensor(input, dtype=flow.float32, device=flow.device(device)),\n        3,\n        flow.tensor(index, dtype=flow.int64, device=flow.device(device)),\n    )\n    test_case.assertTrue(np.allclose(output3.numpy(), np_out3))\n\n\ndef _test_gather_backward(test_case, device):\n    input = np.array([[1, 2], [3, 4]])\n    index = np.array([[0, 0], [1, 0]])\n    np_out = np.take_along_axis(input, index, 
0)\n    np_grad = _scatter_add_numpy(np.ones_like(np_out), 0, index, input.shape)\n    of_input = flow.tensor(\n        input, dtype=flow.float32, requires_grad=True, device=flow.device(device)\n    )\n    output = flow.gather(\n        of_input, 0, flow.tensor(index, dtype=flow.int64, device=flow.device(device)),\n    )\n    out_sum = output.sum()\n    out_sum.backward()\n    test_case.assertTrue(np.array_equal(output.numpy(), np_out))\n    test_case.assertTrue(np.array_equal(of_input.grad.numpy(), np_grad))\n\n\ndef _test_gather_index_0dim_tensor(test_case, device):\n    input = flow.ones(1).to(device)\n    input.requires_grad = True\n    index = flow.tensor(0).to(device)\n    output = flow.gather(input, 0, index)\n    test_case.assertTrue(np.array_equal(output.numpy(), 1.0))\n    output.sum().backward()\n    test_case.assertTrue(np.array_equal(input.grad.numpy(), [1.0]))\n\n\ndef _test_gather_input_index_0dim_tensor(test_case, device):\n    input = flow.tensor(1.0).to(device)\n    input.requires_grad = True\n    index = flow.tensor(0).to(device)\n    output = flow.gather(input, 0, index)\n    test_case.assertTrue(np.array_equal(output.numpy(), 1.0))\n    output.sum().backward()\n    test_case.assertTrue(np.array_equal(input.grad.numpy(), 1.0))\n\n\ndef _test_gather_input_0dim_tensor(test_case, device):\n    input = flow.tensor(1.0).to(device)\n    input.requires_grad = True\n    index = flow.tensor([0]).to(device)\n    output = flow.gather(input, 0, index)\n    test_case.assertTrue(np.array_equal(output.numpy(), [1.0]))\n    output.sum().backward()\n    test_case.assertTrue(np.array_equal(input.grad.numpy(), 1.0))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGather(flow.unittest.TestCase):\n    def test_gather(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_gather,\n            _test_gather_tensor_function,\n            _test_gather_random_array,\n            _test_gather_backward,\n            _test_gather_index_0dim_tensor,\n            _test_gather_input_index_0dim_tensor,\n            _test_gather_input_0dim_tensor,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5)\n    def test_flow_gather_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=4, dim0=3, dim1=3, dim2=4, dim3=5).to(device)\n        dim = random(-4, 4).to(int)\n        index = random_tensor(\n            ndim=4,\n            dim1=random(1, 3).to(int),\n            dim2=random(1, 4).to(int),\n            dim3=random(1, 5).to(int),\n            low=0,\n            high=3,\n            dtype=int,\n        ).to(device)\n        return torch.gather(input, dim, index)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_gather_bool_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=4, dim0=3, dim1=3, dim2=4, dim3=5).to(\n            device=device, dtype=torch.bool\n        )\n        dim = random(0, 4).to(int)\n        index = random_tensor(\n            ndim=4,\n            dim1=random(1, 3).to(int),\n            dim2=random(1, 4).to(int),\n            dim3=random(1, 5).to(int),\n            low=0,\n            high=3,\n            dtype=int,\n        ).to(device)\n        return torch.gather(input, dim, index)\n\n    @profile(torch.gather)\n    def profile_gather(test_case):\n        t = torch.ones(1000, 1000)\n        
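# profiling workload: with an all-ones index, every output element gathers t[i][1] along dim 1\n        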
torch.gather(t, 1, torch.ones(1000, 1000, dtype=torch.int64))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_gather_nd.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_gather_nd(test_case, device):\n    input = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n    indices = np.array([[0], [2]])\n    np_out = np.array([[1, 2, 3], [7, 8, 9]])\n    output = flow.gather_nd(\n        flow.tensor(input, dtype=flow.float, device=flow.device(device)),\n        flow.tensor(indices, dtype=flow.int, device=flow.device(device)),\n    )\n    test_case.assertTrue(np.array_equal(output.numpy(), np_out))\n\n\ndef _test_gather_nd_t(test_case, device):\n    input = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n    indices = np.array([[0, 2], [2, 1]])\n    np_out = np.array([3, 8])\n    output = flow.gather_nd(\n        flow.tensor(input, dtype=flow.float, device=flow.device(device)),\n        flow.tensor(indices, dtype=flow.int, device=flow.device(device)),\n    )\n    test_case.assertTrue(np.array_equal(output.numpy(), np_out))\n\n\ndef _test_gather_nd_backward(test_case, device):\n    input = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n    indices = np.array([[0], [2]])\n    np_out = np.array([[1, 2, 3], [7, 8, 9]])\n    np_grad = np.array([[1, 1, 1], [0, 0, 0], [1, 1, 1]])\n    of_input = flow.tensor(\n        input, requires_grad=True, dtype=flow.float, device=flow.device(device)\n    )\n    output = flow.gather_nd(\n        of_input, flow.tensor(indices, dtype=flow.int, device=flow.device(device))\n    )\n    out_sum = output.sum()\n    out_sum.backward()\n    test_case.assertTrue(np.array_equal(output.numpy(), np_out))\n    test_case.assertTrue(np.array_equal(of_input.grad.numpy(), np_grad))\n\n\ndef _test_gather_nd_backward_t(test_case, device):\n    input = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n    indices = np.array([[0, 2], [2, 1]])\n    np_out = np.array([3, 8])\n    np_grad = np.array([[0, 0, 1], [0, 0, 0], [0, 1, 0]])\n    of_input = flow.tensor(\n        input, requires_grad=True, dtype=flow.float, device=flow.device(device)\n    )\n    output = flow.gather_nd(\n        of_input, flow.tensor(indices, dtype=flow.int, device=flow.device(device))\n    )\n    out_sum = output.sum()\n    out_sum.backward()\n    test_case.assertTrue(np.array_equal(output.numpy(), np_out))\n    test_case.assertTrue(np.array_equal(of_input.grad.numpy(), np_grad))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGather_nd(flow.unittest.TestCase):\n    def test_gather_nd(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_gather_nd,\n            _test_gather_nd_t,\n            _test_gather_nd_backward,\n            _test_gather_nd_backward_t,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n   
 unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_gelu_approximate.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport math\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\nimport torch\n\n\nclass NewGELUActivation(torch.nn.Module):\n    \"\"\"\n    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see\n    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415\n    \"\"\"\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return (\n            0.5\n            * input\n            * (\n                1.0\n                + torch.tanh(\n                    math.sqrt(2.0 / math.pi)\n                    * (input + 0.044715 * torch.pow(input, 3.0))\n                )\n            )\n        )\n\n\ndef _test_gelu_approximate(test_case, device):\n    torch_gelu = NewGELUActivation()\n    x = np.random.randn(2, 4, 3)\n    torch_x = torch.tensor(x, requires_grad=True, device=torch.device(device))\n    oneflow_x = flow.tensor(x, requires_grad=True, device=flow.device(device))\n    torch_y = torch_gelu(torch_x)\n    oneflow_y = flow._C.gelu_with_approximate(oneflow_x, \"tanh\")\n    test_case.assertTrue(np.allclose(torch_y.detach().cpu().numpy(), oneflow_y.numpy()))\n    torch_y_sum = torch_y.sum()\n    torch_y_sum.backward()\n    oneflow_y_sum = oneflow_y.sum()\n    oneflow_y_sum.backward()\n    test_case.assertTrue(\n        np.allclose(torch_x.grad.cpu().numpy(), oneflow_x.grad.numpy())\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModule(flow.unittest.TestCase):\n    def test_gelu_approximate(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_gelu_approximate]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_generator.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestGenerator(flow.unittest.TestCase):\n    def test_different_devices(test_case):\n        auto_gen = flow.Generator(device=\"auto\")\n        cpu_gen = flow.Generator(device=\"cpu\")\n        test_case.assertTrue(auto_gen.initial_seed() == cpu_gen.initial_seed())\n        with test_case.assertRaises(RuntimeError) as context:\n            flow.Generator(device=\"invalid\")\n        if not os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            cuda_gen = flow.Generator(device=\"cuda\")\n            test_case.assertTrue(auto_gen.initial_seed() == cuda_gen.initial_seed())\n\n    def test_generator_manual_seed(test_case):\n        generator = flow.Generator()\n        generator.manual_seed(1)\n        test_case.assertTrue(generator.initial_seed() == 1)\n        generator.manual_seed(2)\n        test_case.assertTrue(generator.initial_seed() == 2)\n\n    def test_generator_in_dropout(test_case):\n        tgt = flow.ones(2000000)\n        output = flow._C.dropout(\n            tgt, p=0.1, training=True, generator=flow.Generator(), addend=None\n        )\n        output.numpy()\n        if not os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            output = flow._C.dropout(\n                tgt.cuda(), 0.1, training=True, generator=flow.Generator(), addend=None\n            )\n            output.numpy()\n\n\nclass TestDefaultGenerator(flow.unittest.TestCase):\n    def test_different_devices(test_case):\n        auto_gen = flow.Generator(device=\"auto\")\n        cpu_gen = flow.default_generator\n        with test_case.assertRaises(RuntimeError) as context:\n            flow.Generator(device=\"invalid\")\n\n        flow.Generator(device=\"cpu:1000\")\n        if not os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            with test_case.assertRaises(\n                oneflow._oneflow_internal.exception.Exception\n            ) as context:\n                flow.Generator(device=\"cuda:1000\")\n            cuda_gen = flow.Generator(device=\"cuda\")\n            cuda0_gen = flow.Generator(device=\"cuda:0\")\n\n    def test_generator_manual_seed(test_case):\n        cpu_gen = flow.default_generator\n        auto_gen = flow.Generator(device=\"auto\")\n        test_gens = [cpu_gen, auto_gen]\n        if not os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            cuda_gen = flow.Generator(device=\"cuda\")\n            cuda0_gen = flow.Generator(device=\"cuda:0\")\n            test_gens += [cuda_gen, cuda0_gen]\n        for seed in [1, 2]:\n            for gen in test_gens:\n                gen.manual_seed(seed)\n                test_case.assertTrue(gen.initial_seed() == seed)\n\n    def test_generator_seed(test_case):\n        cpu_gen = flow.default_generator\n        auto_gen = flow.Generator(device=\"auto\")\n        test_gens = [auto_gen, cpu_gen]\n        if not 
os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            cuda_gen = flow.Generator(device=\"cuda\")\n            cuda0_gen = flow.Generator(device=\"cuda:0\")\n            test_gens += [cuda_gen, cuda0_gen]\n        for gen in test_gens:\n            seed = gen.seed()\n            test_case.assertTrue(seed == gen.initial_seed())\n\n    def test_generator_getstate(test_case):\n        auto_gen = flow.Generator(device=\"auto\")\n        state = auto_gen.get_state()\n        cpu_gen = flow.Generator(device=\"cpu\")\n        state = cpu_gen.get_state()\n        if not os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            cuda_gen = flow.Generator(device=\"cuda\")\n            state = cuda_gen.get_state()\n\n    @unittest.skip(\"the curandstate is no longer used by normal kernel\")\n    def test_generator_setstate(test_case):\n        cpu_gen = flow.default_generator\n        flow.randn(100, 100, dtype=flow.float32, device=\"cpu\", generator=cpu_gen)\n        if not os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            cuda_gen = flow.Generator(\"cuda\")\n            flow.randn(100, 100, dtype=flow.float32, device=\"cuda\", generator=cuda_gen)\n        state = cpu_gen.get_state()\n        flow.randn(100, 100, dtype=flow.float32, device=\"cpu\", generator=cpu_gen)\n        if not os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            cuda_state = cuda_gen.get_state()\n            flow.randn(100, 100, dtype=flow.float32, device=\"cuda\", generator=cuda_gen)\n\n        new_state = cpu_gen.get_state()\n        test_case.assertTrue(not np.allclose(new_state.numpy(), state.numpy()))\n\n        cpu_gen.set_state(state)\n        new_state = cpu_gen.get_state()\n        test_case.assertTrue(np.allclose(new_state.numpy(), state.numpy()))\n\n        if not os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            new_cuda_state = cuda_gen.get_state()\n            test_case.assertTrue(\n                not np.allclose(new_cuda_state.numpy(), cuda_state.numpy())\n            )\n\n            cuda_gen.set_state(cuda_state)\n            new_cuda_state = cuda_gen.get_state()\n            test_case.assertTrue(\n                np.allclose(new_cuda_state.numpy(), cuda_state.numpy())\n            )\n\n    def test_get_rng_state(test_case):\n        cpu_gen = flow.default_generator\n        state = cpu_gen.get_state()\n        rng_state = flow.get_rng_state()\n        test_case.assertTrue(np.allclose(state.numpy(), rng_state.numpy()))\n\n        flow.randn(100, 100, dtype=flow.float32, device=\"cpu\", generator=cpu_gen)\n        state = cpu_gen.get_state()\n        rng_state = flow.get_rng_state()\n        test_case.assertTrue(np.allclose(state.numpy(), rng_state.numpy()))\n\n    def test_set_rng_state(test_case):\n        flow.randn(100, 100)\n        state = flow.get_rng_state()\n        flow.randn(100, 100)\n\n        new_state = flow.get_rng_state()\n        test_case.assertTrue(not np.allclose(new_state.numpy(), state.numpy()))\n\n        flow.set_rng_state(state)\n        new_state = flow.get_rng_state()\n        test_case.assertTrue(np.allclose(new_state.numpy(), state.numpy()))\n\n        if not os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            flow.randn(100, 100).to(\"cuda\")\n            state = flow.cuda.get_rng_state()\n            flow.randn(100, 100).to(\"cuda\")\n            new_state = flow.cuda.get_rng_state()\n            test_case.assertTrue(np.allclose(new_state.numpy(), state.numpy()))\n\n            states = flow.cuda.get_rng_state_all()\n            before0 = flow.cuda.FloatTensor(100, device=0).normal_()\n      
      before1 = flow.cuda.FloatTensor(100, device=1).normal_()\n            flow.cuda.set_rng_state_all(states)\n            after0 = flow.cuda.FloatTensor(100, device=0).normal_()\n            after1 = flow.cuda.FloatTensor(100, device=1).normal_()\n            test_case.assertTrue(np.allclose(before0.numpy(), after0.numpy()))\n            test_case.assertTrue(np.allclose(before1.numpy(), after1.numpy()))\n\n    # NOTE: according to https://github.com/Oneflow-Inc/oneflow/pull/9102#discussion_r973811389\n    # tensor init function fallback to `flow.default_generator.seed()`, and this test will be normal while tensor init functions reconstructed.(using op/kernel)\n    # @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @unittest.skipIf(True, \"tensor init functions need to be reconstructed!\")\n    def test_tensor_init(test_case):\n        flow.manual_seed(0)\n        x = flow.ones(2)\n        x.uniform_()\n\n        flow.manual_seed(0)\n        y = flow.ones(2).to(\"cuda\")\n        y.uniform_()\n\n        test_case.assertTrue(np.allclose(x.numpy(), y.numpy()))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_0_dim_tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_0_dim_tensor(test_case, placement, sbp):\n    x1 = random_tensor(0).to_global(placement=placement, sbp=sbp)\n    x2 = random_tensor(0).to_global(placement=placement, sbp=sbp)\n    y1 = x1 * x2\n    y2 = x1 + x2\n    return y1 + y2\n\n\n@autotest(n=1, check_graph=True)\ndef _test_1dim_slice(test_case, placement, sbp):\n    x = random_tensor(1, random(1, 4) * 8).to_global(placement=placement, sbp=sbp)\n    return x[5]\n\n\nclass TestZeroDimensionTensor(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    @globaltest\n    def test_0_dim_tensor(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=0):\n                _test_0_dim_tensor(test_case, placement, sbp)\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_1dim_slice(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_TripletMarginLoss.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=1, check_graph=True)\ndef _test_global_triplet_marginloss_with_random_data(test_case, placement, sbp):\n    margin = random().to(float)\n    p = random().to(float)\n    swap = random_bool()\n    reduction = oneof(\"none\", \"sum\", \"mean\", nothing())\n    m = torch.nn.TripletMarginLoss(margin=margin, p=p, swap=swap, reduction=reduction)\n    m.train(random())\n    anchor = random_tensor(2, 8, 16).to_global(placement, sbp)\n    pos = random_tensor(2, 8, 16).to_global(placement, sbp)\n    neg = random_tensor(2, 8, 16).to_global(placement, sbp)\n    y = m(anchor, pos, neg)\n    return y\n\n\nclass TestGlobalTripletMarginLoss(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 4 times in past week\")\n    @globaltest\n    def test_global_triplet_marginloss_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_triplet_marginloss_with_random_data(\n                    test_case, placement, sbp\n                )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_abs.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow.unittest\n\n\n@autotest(n=1, check_graph=True)\ndef _test_abs_with_ndim_data(test_case, ndim, placement, sbp):\n    dims = [random(1, 3) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    y = torch.abs(x)\n    return y\n\n\nclass TestAbsModule(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    @globaltest\n    def test_abs_with_ndim_data(test_case):\n        for placement in all_placement():\n            ndim = random(0, 4).to(int).value()\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_abs_with_ndim_data(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_activation.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef build_module(act_type):\n    if act_type == \"relu\":\n        return torch.nn.ReLU()\n    elif act_type == \"relu6\":\n        return torch.nn.ReLU6()\n    elif act_type == \"tanh\":\n        return torch.nn.Tanh()\n    elif act_type == \"elu\":\n        return torch.nn.ELU(alpha=random())\n    elif act_type == \"celu\":\n        return torch.nn.CELU(alpha=random())\n    elif act_type == \"gelu\":\n        return torch.nn.GELU()\n    elif act_type == \"sigmoid\":\n        return torch.nn.Sigmoid()\n    elif act_type == \"hardsigmoid\":\n        return torch.nn.Hardsigmoid()\n    elif act_type == \"hardshrink\":\n        return torch.nn.Hardshrink(lambd=random())\n    elif act_type == \"logsigmoid\":\n        return torch.nn.LogSigmoid()\n    elif act_type == \"hardswish\":\n        return torch.nn.Hardswish()\n    elif act_type == \"hardtanh\":\n        return torch.nn.Hardtanh(\n            min_val=random().to(float), max_val=random().to(float),\n        )\n    elif act_type == \"leakyrelu\":\n        return torch.nn.LeakyReLU(negative_slope=random())\n    elif act_type == \"mish\":\n        return torch.nn.Mish()\n    elif act_type == \"silu\":\n        return torch.nn.SiLU()\n    elif act_type == \"selu\":\n        return torch.nn.SELU()\n    elif act_type == \"threshold\":\n        return torch.nn.Threshold(threshold=random(), value=random())\n    elif act_type == \"softplus\":\n        return torch.nn.Softplus()\n    elif act_type == \"softshrink\":\n        return torch.nn.Softshrink()\n    else:\n        raise ValueError(\"activation type %s is not support\" % act_type)\n\n\n@autotest(n=1, check_graph=False)\ndef _test_activation_module_with_random_data(test_case, act_type, ndim, placement, sbp):\n    m = build_module(act_type)\n    m.train(random())\n    dims = [random(1, 3) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    y = m(x)\n    return y\n\n\n@autotest(n=1, check_graph=False)\ndef _test_activation_module_with_0dim_data(test_case, act_type, placement, sbp):\n    m = build_module(act_type)\n    m.train(random())\n    x = random_tensor(ndim=0).to_global(placement=placement, sbp=sbp)\n    y = m(x)\n    return y\n\n\n@autotest(n=1, check_graph=False)\ndef _test_activation_module_with_0_size_data(\n    test_case, act_type, ndim, zerodim, placement, sbp\n):\n    m = build_module(act_type)\n    m.train(random())\n    dims = [random(1, 3) * 8 for i in range(ndim)]\n    dims[zerodim] = 0\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    y = m(x)\n    return y\n\n\n@globaltest\ndef _test_activation_module(test_case, act_type):\n    for placement in all_placement():\n        ndim = random(1, 4).to(int).value()\n        for sbp in all_sbp(placement, 
max_dim=ndim):\n            _test_activation_module_with_random_data(\n                test_case, act_type, ndim, placement, sbp\n            )\n        # Skip gelu 0 size test since \"Floating point exception\" maybe encountered in PyTorch.\n        if act_type != \"gelu\":\n            zerodim = random(0, ndim).to(int).value()\n            valid_split_axis = [i for i in range(ndim) if i != zerodim]\n            for sbp in all_sbp(\n                placement, max_dim=ndim, valid_split_axis=valid_split_axis\n            ):\n                _test_activation_module_with_0_size_data(\n                    test_case, act_type, ndim, zerodim, placement, sbp\n                )\n        for sbp in all_sbp(placement, max_dim=0):\n            _test_activation_module_with_0dim_data(test_case, act_type, placement, sbp)\n\n\nclass TestReLUModule(flow.unittest.TestCase):\n    def test_relu_module(test_case):\n        _test_activation_module(test_case, \"relu\")\n\n\nclass TestReLU6Module(flow.unittest.TestCase):\n    def test_relu6_module(test_case):\n        _test_activation_module(test_case, \"relu6\")\n\n\nclass TestTanh(flow.unittest.TestCase):\n    def test_tanh_module(test_case):\n        _test_activation_module(test_case, \"tanh\")\n\n\nclass TestELUModule(flow.unittest.TestCase):\n    def test_elu_module(test_case):\n        _test_activation_module(test_case, \"elu\")\n\n\nclass TestCELUModule(flow.unittest.TestCase):\n    def test_celu_module(test_case):\n        _test_activation_module(test_case, \"celu\")\n\n\nclass TestGelu(flow.unittest.TestCase):\n    def test_gelu_module(test_case):\n        _test_activation_module(test_case, \"gelu\")\n\n\nclass TestSigmoidModule(flow.unittest.TestCase):\n    def test_sigmoid_module(test_case):\n        _test_activation_module(test_case, \"sigmoid\")\n\n\nclass TestHardsigmoidModule(flow.unittest.TestCase):\n    def test_hardsigmoid_module(test_case):\n        _test_activation_module(test_case, \"hardsigmoid\")\n\n\nclass TestHardshrinkModule(flow.unittest.TestCase):\n    def test_hardshrink_module(test_case):\n        _test_activation_module(test_case, \"hardshrink\")\n\n\nclass TestLogSigmoidModule(flow.unittest.TestCase):\n    def test_logsigmoid_module(test_case):\n        _test_activation_module(test_case, \"logsigmoid\")\n\n\nclass TestHardswishModule(flow.unittest.TestCase):\n    def test_hardswish_module(test_case):\n        _test_activation_module(test_case, \"hardswish\")\n\n\nclass TestHardtanhModule(flow.unittest.TestCase):\n    def test_hardtanh_module(test_case):\n        _test_activation_module(test_case, \"hardtanh\")\n\n\nclass TestLeakyReLUModule(flow.unittest.TestCase):\n    def test_leakyrelu_module(test_case):\n        _test_activation_module(test_case, \"leakyrelu\")\n\n\nclass TestMishModule(flow.unittest.TestCase):\n    def test_mish_module(test_case):\n        _test_activation_module(test_case, \"mish\")\n\n\nclass TestSiluModule(flow.unittest.TestCase):\n    def test_silu_module(test_case):\n        _test_activation_module(test_case, \"silu\")\n\n\nclass TestSeluModule(flow.unittest.TestCase):\n    def test_selu_module(test_case):\n        _test_activation_module(test_case, \"selu\")\n\n\nclass TestThresholdModule(flow.unittest.TestCase):\n    def test_threshold_module(test_case):\n        _test_activation_module(test_case, \"threshold\")\n\n\nclass TestSoftplusModule(flow.unittest.TestCase):\n    def test_softplus_module(test_case):\n        _test_activation_module(test_case, \"softplus\")\n\n\nclass 
TestSoftshrinkModule(flow.unittest.TestCase):\n    def test_softshrink_module(test_case):\n        _test_activation_module(test_case, \"softshrink\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_adaptive_pool.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom packaging import version\nimport unittest\nfrom typing import Union, Tuple\nimport torch as torch_original\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.nn.common_types import _size_1_t\nfrom oneflow.test_utils.automated_test_util import *\n\nNoneType = type(None)\n# Not the same as those in PyTorch because 'output_size' cannot be NoneType (even in 'torch.nn.AdaptiveAvgPoolXd')\n_size_2_opt_t_not_none = Union[int, Tuple[Union[int, NoneType], Union[int, NoneType]]]\n_size_3_opt_t_not_none = Union[\n    int, Tuple[Union[int, NoneType], Union[int, NoneType], Union[int, NoneType]]\n]\n\n\n@autotest(n=1, check_graph=True)\ndef _test_adaptive_avgpoolnd(test_case, ndim, pool_size, placement, sbp):\n    dims = [random(1, 3) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    if pool_size == 1:\n        m = torch.nn.AdaptiveAvgPool1d(output_size=random().to(_size_1_t))\n    elif pool_size == 2:\n        m = torch.nn.AdaptiveAvgPool2d(output_size=random().to(_size_2_opt_t_not_none))\n    elif pool_size == 3:\n        m = torch.nn.AdaptiveAvgPool3d(output_size=random().to(_size_3_opt_t_not_none))\n    else:\n        raise ValueError(\"pool size should be 1, 2 or 3, but got %d\" % pool_size)\n    m.train(random())\n    y = m(x)\n    return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_adaptive_avgpoolnd_functional(test_case, ndim, pool_size, placement, sbp):\n    dims = [random(1, 3) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    if pool_size == 1:\n        return torch.nn.functional.adaptive_avg_pool1d(x, output_size=random().to(int))\n    elif pool_size == 2:\n        return torch.nn.functional.adaptive_avg_pool2d(x, output_size=random().to(int))\n    elif pool_size == 3:\n        return torch.nn.functional.adaptive_avg_pool3d(x, output_size=random().to(int))\n\n\nclass TestAdaptiveAvgPool(flow.unittest.TestCase):\n    @globaltest\n    def test_adaptive_avgpool(test_case):\n        for placement in all_placement():\n            ndim = 3\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_adaptive_avgpoolnd(test_case, ndim, 1, placement, sbp)\n                _test_adaptive_avgpoolnd_functional(test_case, ndim, 1, placement, sbp)\n\n            ndim = 4\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_adaptive_avgpoolnd(test_case, ndim, 2, placement, sbp)\n                _test_adaptive_avgpoolnd_functional(test_case, ndim, 2, placement, sbp)\n\n            # GPU version 'nn.AdaptiveAvgPool3d' has a bug in PyTorch before '1.10.0'\n            if (\n                version.parse(torch_original.__version__) < version.parse(\"1.10.0\")\n                and placement.type == \"cuda\"\n            ):\n                continue\n            ndim = 5\n            for sbp in all_sbp(placement, 
max_dim=2):\n                _test_adaptive_avgpoolnd(test_case, ndim, 3, placement, sbp)\n                _test_adaptive_avgpoolnd_functional(test_case, ndim, 3, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_add.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_add_with_alpha(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x1 = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp).mean()\n    x2 = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp).mean()\n    x3 = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp).mean()\n    y = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    s = random().to(float)\n    alpha = random().to(float)\n    z1 = torch.add(x1, y, alpha=alpha)\n    z2 = torch.add(x2, s, alpha=alpha)\n    z3 = torch.add(s, x3, alpha=alpha)\n    return z1, z2, z3\n\n\n@autotest(n=1, check_graph=True)\ndef _test_add_with_0size(test_case, ndim, zerodim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    dims[zerodim] = 1\n    x1 = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    dims[zerodim] = 0\n    x2 = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    return torch.add(x1, x2)\n\n\nclass TestAddModule(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    @globaltest\n    def test_add_with_alpha(test_case):\n        ndim = random(1, 4).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_add_with_alpha(test_case, ndim, placement, sbp)\n            zerodim = random(0, ndim).to(int).value()\n            valid_split_axis = [i for i in range(ndim) if i != zerodim]\n            for sbp in all_sbp(\n                placement, max_dim=ndim, valid_split_axis=valid_split_axis\n            ):\n                _test_add_with_0size(test_case, ndim, zerodim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_addcdiv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_addcdiv(test_case, ndim, placement, sbp):\n    shape = [random(2, 4) * 8 for i in range(ndim)]\n    input = random_tensor(ndim, *shape).to_global(placement=placement, sbp=sbp)\n    tensor1 = random_tensor(ndim, *shape).to_global(placement=placement, sbp=sbp)\n    tensor2 = random_tensor(ndim, *shape).to_global(placement=placement, sbp=sbp)\n    value = random(2, 4).to(int)\n    output = torch.addcdiv(input, tensor1, tensor2, value=value)\n    return output\n\n\nclass TestModule(flow.unittest.TestCase):\n    @globaltest\n    def test_addcdiv(test_case):\n        ndim = random(2, 4).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_addcdiv(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_addcmul.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=1, check_graph=True)\ndef _test_addcmul(test_case, ndim, placement, sbp):\n    shape = [random(low=2, high=3) * 8 for i in range(ndim)]\n\n    input = random_tensor(ndim, *shape).to_global(placement=placement, sbp=sbp)\n    tensor1 = random_tensor(len(shape), *shape).to_global(placement=placement, sbp=sbp)\n    tensor2 = random_tensor(len(shape), *shape).to_global(placement=placement, sbp=sbp)\n    value = random(3, 6).to(int)\n    output = torch.addcmul(input, tensor1, tensor2, value=value)\n    return output\n\n\nclass TestModule(flow.unittest.TestCase):\n    @globaltest\n    def test_addcmul(test_case):\n        ndim = random(low=2, high=5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_addcmul(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_addmm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_addmm_with_random_data(test_case, placement, sbp):\n    m = random(1, 3) * 8\n    n = random(1, 3) * 8\n    k = random(1, 3) * 8\n    input = random_tensor(ndim=2, dim0=m, dim1=n).to_global(\n        placement=placement, sbp=sbp\n    )\n    mat1 = random_tensor(ndim=2, dim0=m, dim1=k).to_global(placement=placement, sbp=sbp)\n    mat2 = random_tensor(ndim=2, dim0=k, dim1=n).to_global(placement=placement, sbp=sbp)\n    y = torch.addmm(\n        input, mat1, mat2, beta=random().to(float), alpha=random().to(float),\n    )\n    return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_addmm_broadcast_with_random_data(test_case, placement, sbp):\n    m = random(1, 3) * 8\n    n = random(1, 3) * 8\n    k = random(1, 3) * 8\n    input = random_tensor(ndim=2, dim0=1, dim1=1).to_global(\n        placement=placement, sbp=[flow.sbp.broadcast for _ in range(len(sbp))]\n    )\n    mat1 = random_tensor(ndim=2, dim0=m, dim1=k).to_global(placement=placement, sbp=sbp)\n    mat2 = random_tensor(ndim=2, dim0=k, dim1=n).to_global(placement=placement, sbp=sbp)\n    y = torch.addmm(\n        input, mat1, mat2, beta=random().to(float), alpha=random().to(float),\n    )\n    return y\n\n\nclass TestAddmm(flow.unittest.TestCase):\n    @globaltest\n    def test_addmm(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_addmm_with_random_data(test_case, placement, sbp)\n                _test_addmm_broadcast_with_random_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_affine_grid.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, rtol=1e-03, atol=1e-04, check_graph=True)\ndef _test_affine_grid_2d_with_random_data(test_case, placement, sbp):\n    N = random(1, 3).to(int).value() * 8\n    C = random(1, 8).to(int).value()\n    H = random(1, 8).to(int).value()\n    W = random(1, 8).to(int).value()\n    align_corners = oneof(True, False).value()\n    dims = [N, 2, 3]\n\n    theta = random_tensor(3, *dims).to_global(placement=placement, sbp=sbp)\n    output = torch.nn.functional.affine_grid(\n        theta, (N, C, H, W), align_corners=align_corners\n    )\n    return output\n\n\n@autotest(n=1, rtol=1e-03, atol=1e-04, check_graph=True)\ndef _test_affine_grid_3d_with_random_data(test_case, placement, sbp):\n    N = random(1, 3).to(int) * 8\n    C = random(1, 8).to(int)\n    D = random(1, 8).to(int)\n    H = random(1, 8).to(int)\n    W = random(1, 8).to(int)\n    align_corners = oneof(True, False)\n    dims = [N, 3, 4]\n\n    theta = random_tensor(3, *dims).to_global(placement=placement, sbp=sbp)\n    output = torch.nn.functional.affine_grid(\n        theta, (N, C, D, H, W), align_corners=align_corners\n    )\n    return output\n\n\nclass TestAffineGrid(flow.unittest.TestCase):\n    @globaltest\n    def test_affine_grid(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_affine_grid_2d_with_random_data(test_case, placement, sbp)\n                _test_affine_grid_3d_with_random_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_argmax.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_argmax_with_random_data(test_case, ndim, placement, sbp):\n    dims = [8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    y = torch.argmax(x, dim=random(0, ndim).to(int), keepdim=random().to(bool))\n    return y\n\n\n@unittest.skip(\"TODO: sometimes global TestArgmax fails on 2-GPU runs\")\nclass TestArgmax(flow.unittest.TestCase):\n    @globaltest\n    def test_argmax(test_case):\n        ndim = random(1, 3).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_argmax_with_random_data(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_argmin.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_argmin_with_random_data(test_case, ndim, placement, sbp):\n    dims = [random(1, 3) * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    y = torch.argmin(x, dim=random(0, ndim).to(int), keepdim=random().to(bool))\n    return y\n\n\n@unittest.skip(\"TODO: sometimes global TestArgmin fails on 2-GPU runs\")\nclass TestArgmin(flow.unittest.TestCase):\n    @globaltest\n    def test_argmin(test_case):\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_argmin_with_random_data(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_argsort.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_argsort_with_random_data(test_case, ndim, placement, sbp):\n    dims = [random(1, 3) * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    y = torch.argsort(\n        x, dim=random(low=-ndim, high=ndim).to(int), descending=random_bool()\n    )\n    return y\n\n\n@unittest.skip(\"argsort has bug not found at now.\")\nclass TestArgsort(flow.unittest.TestCase):\n    @globaltest\n    def test_argsort(test_case):\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_argsort_with_random_data(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_argwhere.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport torch as torch_ori\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_argwhere_with_random_data(test_case, ndim, placement, sbp):\n    dims = [random(1, 3) * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    # PyTorch has no argwhere before v1.11, so we use nonzero instead of argwhere for PyTorch\n    # y = torch.argwhere(x)\n    y = x.clone()\n    y.oneflow = flow.argwhere(x.oneflow)\n    y.pytorch = torch_ori.nonzero(x.pytorch)\n    return y\n\n\nclass TestArgwhere(flow.unittest.TestCase):\n    @globaltest\n    def test_argwhere(test_case):\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_argwhere_with_random_data(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_atleast.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_atleast1d_with_random_data(test_case, placement, sbp):\n    x = random_tensor(ndim=1, dim0=8).to_global(placement, sbp)\n    y = random_tensor(ndim=2, dim0=8).to_global(placement, sbp)\n    out = torch.atleast_1d([x, y])\n    return out\n\n\n@autotest(n=1, check_graph=True)\ndef _test_atleast2d_with_random_data(test_case, placement, sbp):\n    x = random_tensor(ndim=1, dim0=8).to_global(placement, sbp)\n    y = random_tensor(ndim=2, dim0=8).to_global(placement, sbp)\n    z = random_tensor(ndim=3, dim0=8).to_global(placement, sbp)\n    out = torch.atleast_2d([x, y, z])\n    return out\n\n\n@autotest(n=1, check_graph=True)\ndef _test_atleast3d_with_random_data(test_case, placement, sbp):\n    x = random_tensor(ndim=1, dim0=8).to_global(placement, sbp)\n    y = random_tensor(ndim=2, dim0=8).to_global(placement, sbp)\n    z = random_tensor(ndim=3, dim0=8).to_global(placement, sbp)\n    p = random_tensor(ndim=4, dim0=8).to_global(placement, sbp)\n    out = torch.atleast_3d([x, y, z, p])\n    return out\n\n\nclass TestAtLeastModule(flow.unittest.TestCase):\n    @globaltest\n    def test_atleast1d_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_atleast1d_with_random_data(test_case, placement, sbp)\n\n    @globaltest\n    def test_atleast2d_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_atleast2d_with_random_data(test_case, placement, sbp)\n\n    @globaltest\n    def test_atleast3d_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_atleast3d_with_random_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_avgpool.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_avgpool1d_with_random_data(test_case, placement, sbp):\n    m = torch.nn.AvgPool1d(\n        kernel_size=random(4, 6),\n        stride=random(1, 3),\n        padding=random(1, 3),\n        ceil_mode=random(),\n        count_include_pad=random(),\n    )\n    m.train(random())\n    m.to_global(placement=placement, sbp=sbp)\n    ndim = 3\n    dims = [random(1, 3) * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    y = m(x)\n    return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_avgpool2d_with_random_data(test_case, placement, sbp):\n    m = torch.nn.AvgPool2d(\n        kernel_size=random(4, 6),\n        stride=random(1, 3),\n        padding=random(1, 3),\n        ceil_mode=random(),\n        count_include_pad=random(),\n        divisor_override=random().to(int),\n    )\n    m.train(random())\n    m.to_global(placement=placement, sbp=sbp)\n    ndim = 4\n    dims = [random(1, 3) * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    y = m(x)\n    return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_avgpool3d_with_random_data(test_case, placement, sbp):\n    m = torch.nn.AvgPool3d(\n        kernel_size=random(4, 6),\n        stride=random(1, 3),\n        padding=random(1, 3),\n        ceil_mode=random(),\n        count_include_pad=random(),\n        divisor_override=random().to(int),\n    )\n    m.train(random())\n    m.to_global(placement=placement, sbp=sbp)\n    ndim = 5\n    dims = [random(1, 3) * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    y = m(x)\n    return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_functional_avgpool1d_with_random_data(test_case, placement, sbp):\n    ndim = 3\n    dims = [random(1, 3) * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    y = torch.nn.functional.avg_pool1d(\n        x,\n        kernel_size=random(1, 6).to(int),\n        stride=random(1, 3).to(int),\n        padding=random(1, 3).to(int),\n        ceil_mode=random_bool(),\n        count_include_pad=random_bool(),\n    )\n    return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_functional_avgpool2d_with_random_data(test_case, placement, sbp):\n    ndim = 4\n    dims = [random(1, 3) * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    y = torch.nn.functional.avg_pool2d(\n        x,\n        kernel_size=random(1, 6).to(int),\n        stride=random(1, 3).to(int),\n        padding=random(1, 3).to(int),\n        ceil_mode=random_bool(),\n        count_include_pad=random_bool(),\n    )\n    return y\n\n\n@autotest(n=1, 
check_graph=True)\ndef _test_functional_avgpool3d_with_random_data(test_case, placement, sbp):\n    ndim = 5\n    dims = [random(1, 3) * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    y = torch.nn.functional.avg_pool3d(\n        x,\n        kernel_size=random(1, 6).to(int),\n        stride=random(1, 3).to(int),\n        padding=random(1, 3).to(int),\n        ceil_mode=random_bool(),\n        count_include_pad=random_bool(),\n    )\n    return y\n\n\nclass TestAvgPoolingModule(flow.unittest.TestCase):\n    @globaltest\n    def test_avg_pooling(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_avgpool1d_with_random_data(test_case, placement, sbp)\n                _test_functional_avgpool1d_with_random_data(test_case, placement, sbp)\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_avgpool2d_with_random_data(test_case, placement, sbp)\n                _test_functional_avgpool2d_with_random_data(test_case, placement, sbp)\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_avgpool3d_with_random_data(test_case, placement, sbp)\n                _test_functional_avgpool3d_with_random_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_batch_gather.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.automated_test_util.util import broadcast\n\n\ndef _test_batch_gather(test_case, ndim, placement, sbp):\n    dims = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dims, requires_grad=True)\n    local_x = flow.tensor(x.pytorch.detach().cpu().numpy(), requires_grad=True)\n    global_x = x.oneflow.to_global(placement=placement, sbp=sbp)\n    global_x.retain_grad()\n\n    indices_ndim = random(1, ndim + 1).to(int).value()\n    indices_dims = [dims[i] for i in range(indices_ndim)]\n    indices_dims[-1] = random(1, dims[indices_ndim - 1]).to(int).value()\n    indices = np.random.choice(dims[indices_ndim - 1], indices_dims)\n    indices = broadcast(indices)\n    local_indices = flow.tensor(indices)\n    global_indices = local_indices.to_global(\n        placement=placement, sbp=[flow.sbp.broadcast for _ in range(len(sbp))]\n    )\n\n    global_out = flow.batch_gather(global_x, global_indices)\n    global_out.sum().backward()\n    local_out = flow.batch_gather(local_x, local_indices)\n    local_out.sum().backward()\n    test_case.assertTrue(\n        np.allclose(\n            global_x.grad.detach().cpu().numpy(),\n            local_x.grad.detach().cpu().numpy(),\n            atol=1e-5,\n            rtol=1e-5,\n        )\n    )\n\n\nclass TestBatchGather(flow.unittest.TestCase):\n    @globaltest\n    def test_batch_gather(test_case):\n        ndim = 2\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_batch_gather(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_bincount.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=False, auto_backward=False)\ndef _test_bincount(test_case, placement, sbp):\n    x = random_tensor(1, 64, low=0, dtype=int).to_global(placement=placement, sbp=sbp)\n    weight = random_tensor(1, 64).to_global(placement=placement, sbp=sbp)\n    minlength = random(1, 100).to(int)\n    return torch.bincount(x, weight, minlength)\n\n\nclass TestBinCountModule(flow.unittest.TestCase):\n    @globaltest\n    def test_bincount(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, valid_split_axis=0):\n                _test_bincount(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_bitwise.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, auto_backward=False)\ndef _test_bitwise_ops_with_random_data(test_case, op, placement, sbp):\n    x = random_tensor(ndim=1, dim0=8, dtype=int).to_global(placement, sbp)\n    y = random_tensor(ndim=1, dim0=8, dtype=int).to_global(placement, sbp)\n    out = op(x, y)\n    return out\n\n\n@autotest(n=1, auto_backward=False)\ndef _test_bitwise_not_with_random_data(test_case, placement, sbp):\n    x = random_tensor(ndim=1, dim0=8, dtype=int).to_global(placement, sbp)\n    return torch.bitwise_not(x)\n\n\nclass TestBitwiseModule(flow.unittest.TestCase):\n    @globaltest\n    def test_bitwise_and_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_bitwise_ops_with_random_data(\n                    test_case, torch.bitwise_and, placement, sbp\n                )\n                _test_bitwise_ops_with_random_data(\n                    test_case, torch.bitwise_or, placement, sbp\n                )\n                _test_bitwise_ops_with_random_data(\n                    test_case, torch.bitwise_xor, placement, sbp\n                )\n                _test_bitwise_not_with_random_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_broadcase_like.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_broadcast_like(test_case, placement, sbp):\n    like_shape = [8] * 4\n    like = random_tensor(4, *like_shape).to_global(\n        placement, random_sbp(placement, max_dim=4)\n    )\n    x = random_tensor(2, *(8, 8)).to_global(placement, sbp)\n    # oneflow\n    of_y = flow.broadcast_like(x.oneflow, like.oneflow)\n    # pytorch\n    torch_y = x.pytorch.broadcast_to(like_shape)\n\n    test_case.assertTrue(np.allclose(of_y.numpy(), torch_y.detach().cpu().numpy()))\n\n\ndef _test_broadcast_like_expand_dims(test_case, placement, sbp):\n    like_shape = [8] * 4\n    like = random_tensor(4, *like_shape).to_global(\n        placement, random_sbp(placement, max_dim=4)\n    )\n    x = random_tensor(2, *(8, 8)).to_global(placement, sbp)\n    # oneflow\n    of_y = flow.broadcast_like(x.oneflow, like.oneflow, [1, 3])\n    # pytorch\n    torch_y = x.pytorch.view(8, 1, 8, 1).broadcast_to(like_shape)\n\n    test_case.assertTrue(np.allclose(of_y.numpy(), torch_y.detach().cpu().numpy()))\n\n\nclass TestGlobalBroadcaseLike(flow.unittest.TestCase):\n    @globaltest\n    def test_broadcase_like(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_broadcast_like(test_case, placement, sbp)\n                _test_broadcast_like_expand_dims(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_broadcast_matmul.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_flow_tensor_global_broadcast_matmul_with_random_data(\n    test_case, placement, x_sbp, y_sbp\n):\n    batch_dim = random(1, 6) * 8\n    k = random(1, 6) * 4\n    x = random_tensor(ndim=3, dim0=batch_dim, dim2=k).to_global(\n        placement=placement, sbp=x_sbp\n    )\n    y = random_tensor(ndim=2, dim0=k).to_global(placement=placement, sbp=y_sbp)\n    return x.matmul(y)\n\n\n@autotest(n=1, check_graph=True)\ndef _test_flow_tensor_global_x_broadcast_y_matmul(test_case, placement, x_sbp, y_sbp):\n    batch_dim = random(1, 6) * 8\n    k = random(1, 6) * 4\n    x = random_tensor(ndim=2, dim1=k).to_global(placement=placement, sbp=x_sbp)\n    y = random_tensor(ndim=3, dim0=batch_dim, dim1=k).to_global(\n        placement=placement, sbp=y_sbp\n    )\n\n    return x.matmul(y)\n\n\n@autotest(n=1, check_graph=True, rtol=1e-3, atol=1e-4)\ndef _test_flow_tensor_global_broadcast_matmul_with_same_dims(\n    test_case, placement, x_sbp, y_sbp\n):\n    k = random(1, 6) * 8\n    batch_dim = random(1, 6) * 8\n    x = random_tensor(ndim=3, dim0=batch_dim, dim1=4, dim2=k).to_global(\n        placement=placement, sbp=x_sbp\n    )\n    y = random_tensor(ndim=3, dim0=batch_dim, dim1=k, dim2=4).to_global(\n        placement=placement, sbp=y_sbp\n    )\n    return x.matmul(y)\n\n\nclass TestGlobalBroadcastMatmulModule(flow.unittest.TestCase):\n    @globaltest\n    def test_global_broadcast_matmul_with_random_data(test_case):\n        for placement in all_placement():\n            for x_sbp in all_sbp(placement, max_dim=2, valid_split_axis=[0]):\n                for y_sbp in all_sbp(placement, max_dim=2, except_split=True):\n                    _test_flow_tensor_global_broadcast_matmul_with_random_data(\n                        test_case, placement, x_sbp, y_sbp\n                    )\n\n    @globaltest\n    def test_global_x_broadcast_y_matmul(test_case):\n        for placement in all_placement():\n            for x_sbp in all_sbp(placement, max_dim=2, except_split=True):\n                for y_sbp in all_sbp(placement, max_dim=2, valid_split_axis=[0]):\n                    _test_flow_tensor_global_x_broadcast_y_matmul(\n                        test_case, placement, x_sbp, y_sbp\n                    )\n\n    @globaltest\n    def test_global_broadcast_matmul_with_same_dims(test_case):\n        for placement in all_placement():\n            for x_sbp in all_sbp(placement, max_dim=2):\n                for y_sbp in all_sbp(placement, max_dim=2):\n                    _test_flow_tensor_global_broadcast_matmul_with_same_dims(\n                        test_case, placement, x_sbp, y_sbp\n                    )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_broadcast_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1)\ndef _test_global_broadcast_tensors(\n    test_case, input_shape, other_shape, placement, x_sbp, y_sbp\n):\n    x = random_tensor(len(input_shape), *input_shape).to_global(\n        placement=placement, sbp=x_sbp\n    )\n    y = random_tensor(len(other_shape), *other_shape).to_global(\n        placement=placement, sbp=y_sbp\n    )\n    return torch.broadcast_tensors(x, y)\n\n\nclass TestGlobalBroadcastOps(flow.unittest.TestCase):\n    # flow.broadcast_shapes's input are shapes, so it can't be tested in global mode\n    # flow.broadcast_to is an alias of flow.expand, so its global tests are same as flow.expand's\n\n    @globaltest\n    def test_global_tensors(test_case):\n        shapes = [((2, 2), (2, 2, 2)), ((1, 2), (3, 1))]\n        for input_shape, other_shape in shapes:\n            for placement in all_placement():\n                for x_sbp in all_sbp(\n                    placement,\n                    max_dim=2,\n                    valid_split_axis=[x for x in input_shape if x != 1],\n                ):\n                    for y_sbp in all_sbp(\n                        placement,\n                        max_dim=2,\n                        valid_split_axis=[y for y in other_shape if y != 1],\n                    ):\n                        _test_global_broadcast_tensors(\n                            test_case, input_shape, other_shape, placement, x_sbp, y_sbp\n                        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_cast.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport os\n\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow import nn\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow import Tensor\nfrom oneflow.framework.args_tree import ArgsTree\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestGlobalCastModule_1n4d(flow.unittest.TestCase):\n    def test_to_global_flatten_hierarchy(test_case):\n        x = flow.ones((4, 4), dtype=flow.int32)\n        sbp = (flow.sbp.partial_sum,)\n        y = x.to_global(\n            placement=flow.placement(\"cpu\", ranks=[[0, 1], [2, 3]]),\n            sbp=(flow.sbp.partial_sum, flow.sbp.partial_sum),\n        )\n        placement = flow.placement(\"cpu\", ranks=[0, 1, 2, 3])\n        y = y.to_global(placement=placement, sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (4, 4))\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_to_global_flatten_hierarchy_cpu_to_gpu(test_case):\n        x = flow.ones((4, 4), dtype=flow.int32)\n        sbp = (flow.sbp.partial_sum,)\n        y = x.to_global(\n            placement=flow.placement(\"cpu\", ranks=[[0, 1], [2, 3]]),\n            sbp=(flow.sbp.partial_sum, flow.sbp.partial_sum),\n        )\n        placement = flow.placement(\"cuda\", ranks=[0, 1, 2, 3])\n        y = y.to_global(placement=placement, sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (4, 4))\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_to_global_flatten_hierarchy_gpu_to_cpu(test_case):\n        x = flow.ones((4, 4), dtype=flow.int32)\n        sbp = (flow.sbp.partial_sum,)\n        y = x.to_global(\n            placement=flow.placement(\"cuda\", ranks=[[0, 1], [2, 3]]),\n            sbp=(flow.sbp.partial_sum, flow.sbp.partial_sum),\n        )\n        placement = flow.placement(\"cpu\", ranks=[0, 1, 2, 3])\n        y = y.to_global(placement=placement, sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (4, 4))\n\n    def test_to_global_broadcast_shape_dtype(test_case):\n        if int(os.getenv(\"RANK\")) < 2:\n            x = flow.ones((4, 4), dtype=flow.int32)\n        else:\n            x = flow.zeros((1,), dtype=flow.float)\n        placement = flow.placement(\"cpu\", ranks=[0, 1])\n        sbp = (flow.sbp.split(0),)\n        y = x.to_global(placement=placement, sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (8, 4))\n        test_case.assertEqual(y.dtype, 
flow.int32)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_local_to_global_2d_sbp(test_case):\n        x = flow.ones((4, 4), device=flow.device(\"cuda\"), dtype=flow.int32)\n        placement = flow.placement(\"cuda\", ranks=[[0, 1], [2, 3]])\n        sbp = (flow.sbp.split(0), flow.sbp.partial_sum)\n        y = x.to_global(placement=placement, sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (8, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_local_to_global_sp_2_bb(test_case):\n        x = flow.ones((4, 4), device=flow.device(\"cuda\"), dtype=flow.int32)\n        placement = flow.placement(\"cuda\", ranks=[[0, 1], [2, 3]])\n        sbp = (flow.sbp.split(0), flow.sbp.partial_sum)\n        y = x.to_global(placement=placement, sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (8, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n        y = y.to_global(sbp=(flow.sbp.broadcast, flow.sbp.broadcast))\n        test_case.assertEqual(y.sbp, (flow.sbp.broadcast, flow.sbp.broadcast))\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (8, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n        z = y.to_local()\n        test_case.assertTrue(\n            np.array_equal(z.numpy(), np.ones((8, 4), dtype=np.int32) * 2)\n        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_local_to_global_ps0_2_s0s0(test_case):\n        x = flow.ones((4, 4), device=flow.device(\"cuda\"), dtype=flow.int32)\n        x = x * int(os.getenv(\"RANK\"))\n        placement = flow.placement(\"cuda\", ranks=[[0, 1], [2, 3]])\n        sbp = (flow.sbp.partial_sum, flow.sbp.split(0))\n        y = x.to_global(placement=placement, sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (8, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n        sbp = (flow.sbp.split(0), flow.sbp.split(0))\n        y = y.to_global(sbp=sbp)\n        z = y.to_local()\n        if int(os.getenv(\"RANK\")) < 2:\n            scale = 2\n        else:\n            scale = 4\n        test_case.assertTrue(\n            np.array_equal(z.numpy(), np.ones((2, 4), dtype=np.int32) * scale)\n        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_local_to_global_s0p_2_s0s0(test_case):\n        x = flow.ones((4, 4), device=flow.device(\"cuda\"), dtype=flow.int32)\n        x = x * int(os.getenv(\"RANK\"))\n        placement = flow.placement(\"cuda\", ranks=[[0, 1], [2, 3]])\n        sbp = (flow.sbp.split(0), flow.sbp.partial_sum)\n        y = x.to_global(placement=placement, sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (8, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n        sbp = (flow.sbp.split(0), flow.sbp.split(0))\n        y = y.to_global(sbp=sbp)\n        z = y.to_local()\n        if int(os.getenv(\"RANK\")) < 2:\n            scale = 1\n        else:\n            scale = 5\n        
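# mesh rows: ranks [0, 1] hold global rows 0-3 (partials 0 + 1 = 1), ranks [2, 3] hold rows 4-7 (partials 2 + 3 = 5)\n        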
test_case.assertTrue(\n            np.array_equal(z.numpy(), np.ones((2, 4), dtype=np.int32) * scale)\n        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_to_global_loop_broadcast_shape_dtype(test_case):\n        if int(os.getenv(\"RANK\")) < 2:\n            x = flow.ones((4, 4), device=flow.device(\"cuda\"), dtype=flow.int32)\n            a = flow.ones((4, 4), device=flow.device(\"cpu\"), dtype=flow.int32)\n        else:\n            x = flow.zeros((1,), dtype=flow.float)\n            a = flow.zeros((4, 4), device=flow.device(\"cpu\"), dtype=flow.int32)\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        sbp = (flow.sbp.split(0),)\n        for i in range(1000):\n            if i % 100 == 0:\n                print(i)\n            y = x.to_global(placement=placement, sbp=sbp)\n            b = a.to_global(placement=placement, sbp=flow.sbp.broadcast)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (8, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n\n\n@flow.unittest.skip_unless_1n2d()\nclass TestGlobalCastModule_1n2d(flow.unittest.TestCase):\n    def test_to_global_broadcast_shape_dtype(test_case):\n        if os.getenv(\"RANK\") == \"0\":\n            x = flow.ones((4, 4), dtype=flow.int32)\n        else:\n            x = flow.zeros((1,), dtype=flow.float)\n        placement = flow.placement(\"cpu\", ranks=[0])\n        sbp = (flow.sbp.broadcast,)\n        y = x.to_global(placement=placement, sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (4, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n\n    def test_local_to_global_broadcast_data(test_case):\n        if int(os.getenv(\"RANK\")) == 0:\n            x = flow.ones((4, 4), dtype=flow.int32)\n        else:\n            x = flow.zeros((4, 4), dtype=flow.int32)\n        placement = flow.placement(\"cpu\", ranks=[0, 1])\n        sbp = (flow.sbp.broadcast,)\n        y = x.to_global(placement=placement, sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (4, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n        z = y.to_local()\n        test_case.assertTrue(np.array_equal(z.numpy(), np.ones((4, 4), dtype=np.int32)))\n\n    def test_cuda_global_to_global_cpu_s2b(test_case):\n        x = flow.ones((4, 4), device=flow.device(\"cpu\"), dtype=flow.int32)\n        placement = flow.placement(\"cpu\", ranks=[0, 1])\n        y = x.to_global(placement=placement, sbp=flow.sbp.split(0))\n        sbp = (flow.sbp.broadcast,)\n        y = y.to_global(sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (8, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n        z = y.to_local()\n        test_case.assertTrue(np.array_equal(z.numpy(), np.ones((8, 4), dtype=np.int32)))\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_cuda_global_to_global_s2b(test_case):\n        x = flow.ones((4, 4), device=flow.device(\"cuda\"), dtype=flow.int32)\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        y = x.to_global(placement=placement, sbp=flow.sbp.split(0))\n        sbp = (flow.sbp.broadcast,)\n       
 y = y.to_global(sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (8, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n        z = y.to_local()\n        test_case.assertTrue(np.array_equal(z.numpy(), np.ones((8, 4), dtype=np.int32)))\n\n    def test_cuda_global_to_global_cpu_s2p(test_case):\n        x = flow.ones((4, 4), device=flow.device(\"cpu\"), dtype=flow.int32)\n        placement = flow.placement(\"cpu\", ranks=[0, 1])\n        y = x.to_global(placement=placement, sbp=flow.sbp.split(0))\n        sbp = (flow.sbp.partial_sum,)\n        y = y.to_global(sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (8, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n        z = y.to_local()\n        if int(os.getenv(\"RANK\")) == 0:\n            test_case.assertTrue(\n                np.array_equal(\n                    z.numpy(),\n                    np.concatenate(\n                        (\n                            np.ones((4, 4), dtype=np.int32),\n                            np.zeros((4, 4), dtype=np.int32),\n                        ),\n                        axis=0,\n                    ),\n                )\n            )\n        else:\n            test_case.assertTrue(\n                np.array_equal(\n                    z.numpy(),\n                    np.concatenate(\n                        (\n                            np.zeros((4, 4), dtype=np.int32),\n                            np.ones((4, 4), dtype=np.int32),\n                        ),\n                        axis=0,\n                    ),\n                )\n            )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_cuda_global_to_global_s2p(test_case):\n        x = flow.ones((4, 4), device=flow.device(\"cuda\"), dtype=flow.int32)\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        y = x.to_global(placement=placement, sbp=flow.sbp.split(0))\n        sbp = (flow.sbp.partial_sum,)\n        y = y.to_global(sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (8, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n        z = y.to_local()\n        if int(os.getenv(\"RANK\")) == 0:\n            test_case.assertTrue(\n                np.array_equal(\n                    z.numpy(),\n                    np.concatenate(\n                        (\n                            np.ones((4, 4), dtype=np.int32),\n                            np.zeros((4, 4), dtype=np.int32),\n                        ),\n                        axis=0,\n                    ),\n                )\n            )\n        else:\n            test_case.assertTrue(\n                np.array_equal(\n                    z.numpy(),\n                    np.concatenate(\n                        (\n                            np.zeros((4, 4), dtype=np.int32),\n                            np.ones((4, 4), dtype=np.int32),\n                        ),\n                        axis=0,\n                    ),\n                )\n            )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_cuda_global_to_global_b2p(test_case):\n        x = flow.ones((4, 4), device=flow.device(\"cuda\"), dtype=flow.int32)\n        placement = 
flow.placement(\"cuda\", ranks=[0, 1])\n        y = x.to_global(placement=placement, sbp=flow.sbp.broadcast)\n        sbp = (flow.sbp.partial_sum,)\n        y = y.to_global(sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (4, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n        z = y.to_local()\n        if int(os.getenv(\"RANK\")) == 0:\n            test_case.assertTrue(\n                np.array_equal(z.numpy(), np.ones((4, 4), dtype=np.int32))\n            )\n        else:\n            test_case.assertTrue(\n                np.array_equal(z.numpy(), np.zeros((4, 4), dtype=np.int32))\n            )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_cuda_global_to_global_b2s(test_case):\n        x = flow.ones((4, 4), device=flow.device(\"cuda\"), dtype=flow.int32)\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        y = x.to_global(placement=placement, sbp=flow.sbp.broadcast)\n        sbp = (flow.sbp.split(0),)\n        y = y.to_global(sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (4, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n        z = y.to_local()\n        test_case.assertTrue(np.array_equal(z.numpy(), np.ones((2, 4), dtype=np.int32)))\n\n    def test_cuda_global_to_global_cpu_p2s(test_case):\n        x = flow.ones((4, 4), device=flow.device(\"cpu\"), dtype=flow.int32)\n        placement = flow.placement(\"cpu\", ranks=[0, 1])\n        y = x.to_global(placement=placement, sbp=flow.sbp.partial_sum)\n        sbp = (flow.sbp.split(0),)\n        y = y.to_global(sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (4, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n        z = y.to_local()\n        test_case.assertTrue(\n            np.array_equal(z.numpy(), np.ones((2, 4), dtype=np.int32) * 2)\n        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_cuda_global_to_global_p2s(test_case):\n        x = flow.ones((4, 4), device=flow.device(\"cuda\"), dtype=flow.int32)\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        y = x.to_global(placement=placement, sbp=flow.sbp.partial_sum)\n        sbp = (flow.sbp.split(0),)\n        y = y.to_global(sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (4, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n        z = y.to_local()\n        test_case.assertTrue(\n            np.array_equal(z.numpy(), np.ones((2, 4), dtype=np.int32) * 2)\n        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_cuda_global_to_global_cuda_h2d(test_case):\n        x = flow.ones((4, 4), device=flow.device(\"cpu\"), dtype=flow.int32)\n        placement = flow.placement(\"cpu\", ranks=[0, 1])\n        cuda_placement = flow.placement(\"cuda\", ranks=[0, 1])\n        y = x.to_global(placement=placement, sbp=flow.sbp.partial_sum)\n        y = y.to_global(placement=cuda_placement, sbp=flow.sbp.partial_sum)\n        test_case.assertEqual(y.sbp, (flow.sbp.partial_sum,))\n        test_case.assertEqual(y.placement, cuda_placement)\n        
test_case.assertEqual(tuple(y.shape), (4, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n        z = y.to_local()\n        test_case.assertTrue(np.array_equal(z.numpy(), np.ones((4, 4), dtype=np.int32)))\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_cuda_global_to_global_cpu_p2b(test_case):\n        x = flow.ones((4, 4), device=flow.device(\"cpu\"), dtype=flow.int32)\n        placement = flow.placement(\"cpu\", ranks=[0, 1])\n        cuda_placement = flow.placement(\"cuda\", ranks=[0, 1])\n        y = x.to_global(placement=placement, sbp=flow.sbp.partial_sum)\n        y = y.to_global(placement=cuda_placement, sbp=flow.sbp.partial_sum)\n        sbp = (flow.sbp.broadcast,)\n        y = y.to_global(placement=cuda_placement, sbp=sbp)\n        y = y.to_global(placement=placement, sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (4, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n        z = y.to_local()\n        test_case.assertTrue(\n            np.array_equal(z.numpy(), np.ones((4, 4), dtype=np.int32) * 2)\n        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_cuda_global_to_global_p2b(test_case):\n        x = flow.ones((4, 4), device=flow.device(\"cuda\"), dtype=flow.int32)\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        y = x.to_global(placement=placement, sbp=flow.sbp.partial_sum)\n        sbp = (flow.sbp.broadcast,)\n        y = y.to_global(sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (4, 4))\n        test_case.assertEqual(y.dtype, flow.int32)\n        z = y.to_local()\n        test_case.assertTrue(\n            np.array_equal(z.numpy(), np.ones((4, 4), dtype=np.int32) * 2)\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGlobalCastModule_1n1d(flow.unittest.TestCase):\n    def test_to_global(test_case):\n        x = flow.ones((4, 4))\n        placement = flow.placement(\"cpu\", ranks=[0])\n        sbp = (flow.sbp.broadcast,)\n        y = x.to_global(placement=placement, sbp=sbp)\n        test_case.assertEqual(y.sbp, sbp)\n        test_case.assertEqual(y.placement, placement)\n        test_case.assertEqual(tuple(y.shape), (4, 4))\n\n\ndef _test_cpu_p2b_with_random_parameter(test_case, device_list):\n    gen_float = np.random.random\n    gen_int = np.random.randint\n    dtype_list = [\n        flow.uint8,\n        flow.int8,\n        flow.int32,\n        flow.int64,\n        flow.float32,\n        flow.float64,\n        flow.double,\n    ]\n\n    def choose_shape_and_dtype(seed):\n        rng = np.random.default_rng(seed)\n        kdtype = rng.integers(low=1, high=len(dtype_list), size=1)\n        ndim = rng.integers(low=1, high=4, size=1)\n        shape = rng.integers(low=1, high=10, size=ndim)\n        return kdtype, shape\n\n    for _ in range(10):\n        seed = flow.tensor(gen_int(1, 1000, 1))\n        seed = seed.to_global(\n            placement=flow.placement.all(seed.device.type), sbp=flow.sbp.broadcast,\n        )\n        seed = int(seed.to_local().numpy())\n        kdtype, shape = choose_shape_and_dtype(seed)\n        if kdtype <= 3:\n            np_arr = gen_int(1, 10, shape)\n        else:\n            np_arr = gen_float(shape)\n        tensor = flow.tensor(np_arr, device=\"cpu\", 
dtype=dtype_list[int(kdtype)])\n        cpu_tensor = tensor.to_global(\n            placement=flow.placement(\"cpu\", device_list), sbp=flow.sbp.partial_sum\n        )\n        cpu_tensor = cpu_tensor.to_global(sbp=flow.sbp.broadcast)\n        tensor = tensor.to(\"cuda\")\n        cuda_tensor = tensor.to_global(\n            placement=flow.placement(\"cuda\", device_list), sbp=flow.sbp.partial_sum\n        )\n        cuda_tensor = cuda_tensor.to_global(sbp=flow.sbp.broadcast)\n        test_case.assertTrue(\n            np.allclose(cpu_tensor.to_local().numpy(), cuda_tensor.to_local().numpy())\n        )\n\n\ndef _test_cpu_s2b_with_random_parameter(test_case, device_list):\n    gen_float = np.random.random\n    gen_int = np.random.randint\n    dtype_list = [\n        flow.uint8,\n        flow.int8,\n        flow.int32,\n        flow.int64,\n        flow.float32,\n        flow.float64,\n        flow.double,\n    ]\n\n    def choose_shape_and_dtype(seed):\n        rng = np.random.default_rng(seed)\n        kdtype = rng.integers(low=1, high=len(dtype_list), size=1)\n        ndim = rng.integers(low=1, high=4, size=1)\n        shape = rng.integers(low=1, high=10, size=ndim)\n        return kdtype, shape\n\n    for _ in range(10):\n        seed = flow.tensor(gen_int(1, 1000, 1))\n        seed = seed.to_global(\n            placement=flow.placement.all(seed.device.type), sbp=flow.sbp.broadcast,\n        )\n        seed = int(seed.to_local().numpy())\n        kdtype, shape = choose_shape_and_dtype(seed)\n        if kdtype <= 3:\n            np_arr = gen_int(1, 10, shape)\n        else:\n            np_arr = gen_float(shape)\n        tensor = flow.tensor(np_arr, device=\"cpu\", dtype=dtype_list[int(kdtype)])\n        cpu_tensor = tensor.to_global(\n            placement=flow.placement(\"cpu\", device_list), sbp=flow.sbp.split(0)\n        )\n        cpu_tensor = cpu_tensor.to_global(sbp=flow.sbp.broadcast)\n        tensor = tensor.to(\"cuda\")\n        cuda_tensor = tensor.to_global(\n            placement=flow.placement(\"cuda\", device_list), sbp=flow.sbp.split(0)\n        )\n        cuda_tensor = cuda_tensor.to_global(sbp=flow.sbp.broadcast)\n        test_case.assertTrue(\n            np.allclose(cpu_tensor.to_local().numpy(), cuda_tensor.to_local().numpy())\n        )\n\n\ndef _test_cpu_p2s_with_random_parameter(test_case, device_list):\n    gen_float = np.random.random\n    gen_int = np.random.randint\n    dtype_list = [\n        flow.uint8,\n        flow.int8,\n        flow.int32,\n        flow.int64,\n        flow.float32,\n        flow.float64,\n        flow.double,\n    ]\n\n    def choose_shape_and_dtype(seed):\n        rng = np.random.default_rng(seed)\n        kdtype = rng.integers(low=1, high=len(dtype_list), size=1)\n        ndim = rng.integers(low=1, high=4, size=1)\n        shape = list(rng.integers(low=1, high=5, size=1) * 12) + list(\n            rng.integers(low=1, high=10, size=ndim - 1)\n        )\n        return kdtype, shape\n\n    for _ in range(10):\n        seed = flow.tensor(gen_int(1, 1000, 1))\n        seed = seed.to_global(\n            placement=flow.placement.all(seed.device.type), sbp=flow.sbp.broadcast,\n        )\n        seed = int(seed.to_local().numpy())\n        kdtype, shape = choose_shape_and_dtype(seed)\n        if kdtype <= 3:\n            np_arr = gen_int(1, 10, shape)\n        else:\n            np_arr = gen_float(shape)\n        tensor = flow.tensor(np_arr, device=\"cpu\", dtype=dtype_list[int(kdtype)])\n        cpu_tensor = tensor.to_global(\n            
placement=flow.placement(\"cpu\", device_list), sbp=flow.sbp.partial_sum\n        )\n        cpu_tensor = cpu_tensor.to_global(sbp=flow.sbp.split(0))\n        tensor = tensor.to(\"cuda\")\n        cuda_tensor = tensor.to_global(\n            placement=flow.placement(\"cuda\", device_list), sbp=flow.sbp.partial_sum\n        )\n        cuda_tensor = cuda_tensor.to_global(sbp=flow.sbp.split(0))\n        test_case.assertTrue(\n            np.allclose(cpu_tensor.to_local().numpy(), cuda_tensor.to_local().numpy())\n        )\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGlobalCast(flow.unittest.TestCase):\n    def test_cpu_local_tensor_to_gpu_placement(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array([4, 6, 7, 8], dtype=np.float32)\n        else:\n            np_arr = np.array([0, 0, 0, 0], dtype=np.float32)\n        tensor = flow.tensor(np_arr, dtype=flow.float32)\n        placement = flow.placement(\"cuda\", [0, 1, 2, 3])\n        device = flow.device(\"cuda\")\n        global_tensor = tensor.to_global(placement, flow.sbp.broadcast)\n        test_case.assertEqual(global_tensor.to_local().device, device)\n        test_case.assertEqual(global_tensor.placement, placement)\n        test_case.assertTrue(\n            np.array_equal(\n                global_tensor.to_local().numpy(),\n                np.array([4, 6, 7, 8], dtype=np.float32),\n            )\n        )\n\n    def test_cpu_p2b_with_random_parameter(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_list\"] = [[0, 1], [1, 2, 3], [0, 1, 2, 3]]\n        for arg in GenArgList(arg_dict):\n            _test_cpu_p2b_with_random_parameter(test_case, *arg)\n\n    def test_cpu_s2b_with_random_parameter(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_list\"] = [[0, 1], [1, 2, 3], [0, 1, 2, 3]]\n        for arg in GenArgList(arg_dict):\n            _test_cpu_s2b_with_random_parameter(test_case, *arg)\n\n    def test_cpu_p2s_with_random_parameter(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_list\"] = [[0, 1], [1, 2, 3], [0, 1, 2, 3]]\n        for arg in GenArgList(arg_dict):\n            _test_cpu_p2s_with_random_parameter(test_case, *arg)\n\n    def test_local_to_global_with_wrong_device(test_case):\n        np_arr = np.array([4, 6], dtype=np.float32)\n        tensor = flow.tensor(\n            np_arr,\n            device=flow.device(\"cuda:%d\" % ((flow.env.get_rank() + 1) % 4)),\n            dtype=flow.float32,\n        )\n        placement = flow.placement(\"cuda\", ranks=[0, 1, 2, 3])\n        device = flow.device(\"cuda\")\n        global_tensor = tensor.to_global(placement, flow.sbp.broadcast)\n        local_tensor = global_tensor.to_local()\n        test_case.assertEqual(local_tensor.device, device)\n        test_case.assertEqual(global_tensor.placement, placement)\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGlobalCast_S2S(flow.unittest.TestCase):\n    def test_global_to_global_s0_to_s1(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array(\n                [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                dtype=np.float32,\n            )\n        else:\n            np_arr = np.array(\n                [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                dtype=np.float32,\n            )\n    
    device = flow.device(\"cuda\")\n        tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        split0_tensor = tensor.to_global(placement, flow.sbp.split(0))\n        split1_tensor = split0_tensor.to_global(placement, flow.sbp.split(1))\n        if flow.env.get_rank() == 0:\n            test_case.assertTrue(\n                np.array_equal(\n                    split1_tensor.to_local().numpy(),\n                    np.array(\n                        [\n                            [4.0, 6.0],\n                            [6.0, 2.0],\n                            [3.0, 7.0],\n                            [6.0, 8.0],\n                            [2.0, 10.0],\n                            [3.0, 9.0],\n                            [4.0, 6.0],\n                            [6.0, 8.0],\n                        ],\n                        dtype=np.float32,\n                    ),\n                )\n            )\n        elif flow.env.get_rank() == 1:\n            test_case.assertTrue(\n                np.array_equal(\n                    split1_tensor.to_local().numpy(),\n                    np.array(\n                        [\n                            [5.0, 20.0],\n                            [5.0, 7.0],\n                            [5.0, 4.0],\n                            [9.0, 4.0],\n                            [10.0, 7.0],\n                            [10.0, 5.0],\n                            [6.0, 9.0],\n                            [6.0, 4.0],\n                        ],\n                        dtype=np.float32,\n                    ),\n                )\n            )\n\n    def test_global_to_global_s1_to_s0(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array(\n                [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                dtype=np.float32,\n            )\n        else:\n            np_arr = np.array(\n                [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                dtype=np.float32,\n            )\n        device = flow.device(\"cuda\")\n        tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        split_tensor = tensor.to_global(placement, flow.sbp.split(0))\n        split1_tensor = split_tensor.to_global(placement, flow.sbp.split(1))\n        split0_tensor = split1_tensor.to_global(placement, flow.sbp.split(0))\n        if flow.env.get_rank() == 0:\n            test_case.assertTrue(\n                np.array_equal(\n                    split0_tensor.to_local().numpy(),\n                    np.array(\n                        [\n                            [4.0, 6.0, 5.0, 20.0],\n                            [6.0, 2.0, 5.0, 7.0],\n                            [3.0, 7.0, 5.0, 4.0],\n                            [6.0, 8.0, 9.0, 4.0],\n                        ],\n                        dtype=np.float32,\n                    ),\n                )\n            )\n        elif flow.env.get_rank() == 1:\n            test_case.assertTrue(\n                np.array_equal(\n                    split0_tensor.to_local().numpy(),\n                    np.array(\n                        [\n                            [2.0, 10.0, 10.0, 7.0],\n                            [3.0, 9.0, 10.0, 5.0],\n                            [4.0, 6.0, 6.0, 9.0],\n                            [6.0, 8.0, 6.0, 4.0],\n                        ],\n                        
dtype=np.float32,\n                    ),\n                )\n            )\n\n    def test_global_to_global_s0_to_s1_cpu(test_case):\n        np_arr = np.random.randn(4, 12)\n\n        cuda_device = flow.device(\"cuda\")\n        cuda_tensor = flow.tensor(np_arr, device=cuda_device, dtype=flow.float32)\n        cuda_placement = flow.placement(\"cuda\", ranks=[1, 3])\n        cuda_split0_tensor = cuda_tensor.to_global(cuda_placement, flow.sbp.split(0))\n        cuda_split1_tensor = cuda_split0_tensor.to_global(\n            cuda_placement, flow.sbp.split(1)\n        )\n\n        cpu_device = flow.device(\"cpu\")\n        cpu_tensor = flow.tensor(np_arr, device=cpu_device, dtype=flow.float32)\n        cpu_placement = flow.placement(\"cpu\", ranks=[1, 3])\n        cpu_split0_tensor = cpu_tensor.to_global(cpu_placement, flow.sbp.split(0))\n        cpu_split1_tensor = cpu_split0_tensor.to_global(\n            cpu_placement, flow.sbp.split(1)\n        )\n\n        # only ranks 1 and 3 are in the placement, so only they hold local shards to compare\n        if flow.env.get_rank() == 1 or flow.env.get_rank() == 3:\n            test_case.assertTrue(\n                np.array_equal(\n                    cuda_split1_tensor.to_local().numpy(),\n                    cpu_split1_tensor.to_local().numpy(),\n                )\n            )\n\n    def test_global_to_global_s1_to_s0_cpu(test_case):\n        np_arr = np.random.randn(4, 12)\n\n        cuda_device = flow.device(\"cuda\")\n        cuda_tensor = flow.tensor(np_arr, device=cuda_device, dtype=flow.float32)\n        cuda_placement = flow.placement(\"cuda\", ranks=[0, 1])\n        cuda_split_tensor = cuda_tensor.to_global(cuda_placement, flow.sbp.split(0))\n        cuda_split1_tensor = cuda_split_tensor.to_global(\n            cuda_placement, flow.sbp.split(1)\n        )\n        cuda_split0_tensor = cuda_split1_tensor.to_global(\n            cuda_placement, flow.sbp.split(0)\n        )\n\n        cpu_device = flow.device(\"cpu\")\n        cpu_tensor = flow.tensor(np_arr, device=cpu_device, dtype=flow.float32)\n        cpu_placement = flow.placement(\"cpu\", ranks=[0, 1])\n        cpu_split_tensor = cpu_tensor.to_global(cpu_placement, flow.sbp.split(0))\n        cpu_split1_tensor = cpu_split_tensor.to_global(cpu_placement, flow.sbp.split(1))\n        cpu_split0_tensor = cpu_split1_tensor.to_global(\n            cpu_placement, flow.sbp.split(0)\n        )\n\n        if flow.env.get_rank() == 0 or flow.env.get_rank() == 1:\n            test_case.assertTrue(\n                np.array_equal(\n                    cuda_split0_tensor.to_local().numpy(),\n                    cpu_split0_tensor.to_local().numpy(),\n                )\n            )\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGlobalCast_XToB(flow.unittest.TestCase):\n    def test_global_to_global_btb_gpu_to_gpu(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array(\n                [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]],\n                dtype=np.float32,\n            )\n        elif flow.env.get_rank() == 1:\n            np_arr = np.array(\n                [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                dtype=np.float32,\n            )\n        elif flow.env.get_rank() == 2:\n            np_arr = np.array(\n                [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]],\n                dtype=np.float32,\n            )\n        elif flow.env.get_rank() == 3:\n            np_arr = np.array(\n                [[9, 4, 
5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]],\n                dtype=np.float32,\n            )\n        device = flow.device(\"cuda\")\n        tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        global_tensor = tensor.to_global(placement, flow.sbp.broadcast)\n        new_placement = flow.placement(\"cuda\", ranks=[0, 1, 2])\n        broadcast_tensor = global_tensor.to_global(new_placement, flow.sbp.broadcast)\n        test_case.assertEqual(broadcast_tensor.placement, new_placement)\n        if flow.env.get_rank() != 3:\n            test_case.assertTrue(\n                np.array_equal(\n                    broadcast_tensor.to_local().numpy(),\n                    np.array(\n                        [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]],\n                        dtype=np.float32,\n                    ),\n                )\n            )\n\n    def test_global_to_global_stb_gpu_to_gpu(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array(\n                [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]],\n                dtype=np.float32,\n            )\n        elif flow.env.get_rank() == 1:\n            np_arr = np.array(\n                [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                dtype=np.float32,\n            )\n        elif flow.env.get_rank() == 2:\n            np_arr = np.array(\n                [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]],\n                dtype=np.float32,\n            )\n        elif flow.env.get_rank() == 3:\n            np_arr = np.array(\n                [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]],\n                dtype=np.float32,\n            )\n        device = flow.device(\"cuda\")\n        tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n        placement = flow.placement(\"cuda\", ranks=[0, 1, 2])\n        global_tensor = tensor.to_global(placement, flow.sbp.split(0))\n        new_placement = flow.placement(\"cuda\", ranks=[0, 1, 2, 3])\n        broadcast_tensor = global_tensor.to_global(new_placement, flow.sbp.broadcast)\n        test_case.assertEqual(broadcast_tensor.placement, new_placement)\n        test_case.assertTrue(\n            np.array_equal(\n                broadcast_tensor.to_local().numpy(),\n                np.array(\n                    [\n                        [4, 6, 5, 20],\n                        [6, 8, 9, 0],\n                        [3, 7, 5, 0],\n                        [6, 8, 9, 0],\n                        [2, 10, 10, 7],\n                        [3, 9, 10, 5],\n                        [4, 6, 6, 9],\n                        [6, 8, 6, 4],\n                        [9, 6, 5, 8],\n                        [4, 9, 7, 0],\n                        [2, 5, 7, 9],\n                        [6, 8, 10, 0],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n    def test_global_to_global_ptb_gpu_to_gpu(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array(\n                [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]],\n                dtype=np.float32,\n            )\n        elif flow.env.get_rank() == 1:\n            np_arr = np.array(\n                [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                dtype=np.float32,\n            )\n        elif flow.env.get_rank() == 2:\n            np_arr = 
np.array(\n                [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]],\n                dtype=np.float32,\n            )\n        elif flow.env.get_rank() == 3:\n            np_arr = np.array(\n                [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]],\n                dtype=np.float32,\n            )\n        device = flow.device(\"cuda\")\n        tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n        placement = flow.placement(\"cuda\", ranks=[0, 1, 2])\n        global_tensor = tensor.to_global(placement, flow.sbp.partial_sum)\n        new_placement = flow.placement(\"cuda\", ranks=[0, 1, 2, 3])\n        broadcast_tensor = global_tensor.to_global(new_placement, flow.sbp.broadcast)\n        test_case.assertEqual(broadcast_tensor.placement, new_placement)\n        test_case.assertTrue(\n            np.array_equal(\n                broadcast_tensor.to_local().numpy(),\n                np.array(\n                    [\n                        [15, 22, 20, 35],\n                        [13, 26, 26, 5],\n                        [9, 18, 18, 18],\n                        [18, 24, 25, 4],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGlobalCast_1ToN(flow.unittest.TestCase):\n    def test_global_to_global_1tob(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array(\n                [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                dtype=np.float32,\n            )\n        else:\n            np_arr = np.array(\n                [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                dtype=np.float32,\n            )\n        device = flow.device(\"cuda\")\n        tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n        placement = flow.placement(\"cuda\", ranks=[0])\n        global_tensor = tensor.to_global(placement, flow.sbp.split(0))\n        new_placement = flow.placement(\"cuda\", ranks=[0, 1])\n        broadcast_tensor = global_tensor.to_global(new_placement, flow.sbp.broadcast)\n        test_case.assertEqual(broadcast_tensor.placement, new_placement)\n        if flow.env.get_rank() < 2:\n            test_case.assertTrue(\n                np.array_equal(\n                    broadcast_tensor.to_local().numpy(),\n                    np.array(\n                        [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                        dtype=np.float32,\n                    ),\n                )\n            )\n\n    def test_global_to_global_1top(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array(\n                [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                dtype=np.float32,\n            )\n        else:\n            np_arr = np.array(\n                [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                dtype=np.float32,\n            )\n        device = flow.device(\"cuda\")\n        tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n        placement = flow.placement(\"cuda\", [0])\n        global_tensor = tensor.to_global(placement, flow.sbp.split(0))\n        new_placement = flow.placement(\"cuda\", ranks=[0, 1])\n        partial_sum_tensor = global_tensor.to_global(\n            new_placement, flow.sbp.partial_sum\n        )\n        
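# 1-to-n with partial_sum: rank 0 keeps the full data and the new rank 1 is zero-filled, so the partials still sum to the original\n        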
test_case.assertEqual(partial_sum_tensor.placement, new_placement)\n        if flow.env.get_rank() == 0:\n            test_case.assertTrue(\n                np.array_equal(\n                    partial_sum_tensor.to_local().numpy(),\n                    np.array(\n                        [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                        dtype=np.float32,\n                    ),\n                )\n            )\n        elif flow.env.get_rank() == 1:\n            test_case.assertTrue(\n                np.array_equal(\n                    partial_sum_tensor.to_local().numpy(),\n                    np.array(\n                        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],\n                        dtype=np.float32,\n                    ),\n                )\n            )\n\n    def test_global_to_global_1tos(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array(\n                [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                dtype=np.float32,\n            )\n        else:\n            np_arr = np.array(\n                [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                dtype=np.float32,\n            )\n        device = flow.device(\"cuda\")\n        tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n        placement = flow.placement(\"cuda\", ranks=[0])\n        global_tensor = tensor.to_global(placement, flow.sbp.split(0))\n        new_placement = flow.placement(\"cuda\", ranks=[0, 1, 2, 3])\n        split_tensor = global_tensor.to_global(new_placement, flow.sbp.split(0))\n        test_case.assertEqual(split_tensor.placement, new_placement)\n        if flow.env.get_rank() == 0:\n            test_case.assertTrue(\n                np.array_equal(\n                    split_tensor.to_local().numpy(),\n                    np.array([[4, 6, 5, 20]], dtype=np.float32,),\n                )\n            )\n        elif flow.env.get_rank() == 1:\n            test_case.assertTrue(\n                np.array_equal(\n                    split_tensor.to_local().numpy(),\n                    np.array([[6, 2, 5, 7]], dtype=np.float32,),\n                )\n            )\n        elif flow.env.get_rank() == 2:\n            test_case.assertTrue(\n                np.array_equal(\n                    split_tensor.to_local().numpy(),\n                    np.array([[3, 7, 5, 4]], dtype=np.float32,),\n                )\n            )\n        elif flow.env.get_rank() == 3:\n            test_case.assertTrue(\n                np.array_equal(\n                    split_tensor.to_local().numpy(),\n                    np.array([[6, 8, 9, 4]], dtype=np.float32,),\n                )\n            )\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGlobalCast_NTo1(flow.unittest.TestCase):\n    def test_global_to_global_bt1(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array(\n                [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                dtype=np.float32,\n            )\n        else:\n            np_arr = np.array(\n                [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                dtype=np.float32,\n            )\n        device = flow.device(\"cuda\")\n        tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        global_tensor = 
tensor.to_global(placement, flow.sbp.broadcast)\n        new_placement = flow.placement(\"cuda\", ranks=[0])\n        broadcast_tensor = global_tensor.to_global(new_placement, flow.sbp.broadcast)\n        test_case.assertEqual(broadcast_tensor.placement, new_placement)\n        if flow.env.get_rank() == 0:\n            test_case.assertTrue(\n                np.array_equal(\n                    broadcast_tensor.to_local().numpy(),\n                    np.array(\n                        [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                        dtype=np.float32,\n                    ),\n                )\n            )\n\n    def test_global_to_global_st1(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array(\n                [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                dtype=np.float32,\n            )\n        else:\n            np_arr = np.array(\n                [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                dtype=np.float32,\n            )\n        device = flow.device(\"cuda\")\n        tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        global_tensor = tensor.to_global(placement, flow.sbp.split(0))\n        new_placement = flow.placement(\"cuda\", ranks=[0])\n        broadcast_tensor = global_tensor.to_global(new_placement, flow.sbp.broadcast)\n        test_case.assertEqual(broadcast_tensor.placement, new_placement)\n        if flow.env.get_rank() == 0:\n            test_case.assertTrue(\n                np.array_equal(\n                    broadcast_tensor.to_local().numpy(),\n                    np.array(\n                        [\n                            [4, 6, 5, 20],\n                            [6, 2, 5, 7],\n                            [3, 7, 5, 4],\n                            [6, 8, 9, 4],\n                            [2, 10, 10, 7],\n                            [3, 9, 10, 5],\n                            [4, 6, 6, 9],\n                            [6, 8, 6, 4],\n                        ],\n                        dtype=np.float32,\n                    ),\n                )\n            )\n\n    def test_global_to_global_pt1(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array(\n                [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                dtype=np.float32,\n            )\n        else:\n            np_arr = np.array(\n                [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                dtype=np.float32,\n            )\n        device = flow.device(\"cuda\")\n        tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n        placement = flow.placement(\"cuda\", ranks=[0, 1])\n        global_tensor = tensor.to_global(placement, flow.sbp.partial_sum)\n        new_placement = flow.placement(\"cuda\", ranks=[0])\n        broadcast_tensor = global_tensor.to_global(new_placement, flow.sbp.broadcast)\n        test_case.assertEqual(broadcast_tensor.placement, new_placement)\n        if flow.env.get_rank() == 0:\n            test_case.assertTrue(\n                np.array_equal(\n                    broadcast_tensor.to_local().numpy(),\n                    np.array(\n                        [\n                            [6, 16, 15, 27],\n                            [9, 11, 15, 12],\n                            [7, 13, 11, 13],\n                            [12, 16, 15, 8],\n          
              ],\n                        dtype=np.float32,\n                    ),\n                )\n            )\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGlobalCast_1To1(flow.unittest.TestCase):\n    def test_global_to_global_1to1_gpu_to_gpu(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array(\n                [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                dtype=np.float32,\n            )\n        else:\n            np_arr = np.array(\n                [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                dtype=np.float32,\n            )\n        device = flow.device(\"cuda\")\n        local_tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n        placement = flow.placement(\"cuda\", ranks=[3])\n        x = local_tensor.to_global(placement, flow.sbp.split(0))\n        new_placement = flow.placement(\"cuda\", ranks=[2])\n        y = x.to_global(new_placement, flow.sbp.broadcast)\n        test_case.assertEqual(y.placement, new_placement)\n        if flow.env.get_rank() == 2:\n            test_case.assertTrue(\n                np.array_equal(\n                    y.to_local().numpy(),\n                    np.array(\n                        [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                        dtype=np.float32,\n                    ),\n                )\n            )\n\n    def test_global_to_global_1to1_cpu_to_cpu(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array(\n                [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                dtype=np.float32,\n            )\n        else:\n            np_arr = np.array(\n                [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                dtype=np.float32,\n            )\n        device = flow.device(\"cpu\")\n        local_tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n        placement = flow.placement(\"cpu\", ranks=[0])\n        x = local_tensor.to_global(placement, flow.sbp.split(0))\n        new_placement = flow.placement(\"cpu\", ranks=[2])\n        y = x.to_global(new_placement, flow.sbp.broadcast)\n        test_case.assertEqual(y.placement, new_placement)\n        if flow.env.get_rank() == 2:\n            test_case.assertTrue(\n                np.array_equal(\n                    y.to_local().numpy(),\n                    np.array(\n                        [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                        dtype=np.float32,\n                    ),\n                )\n            )\n\n    def test_global_to_global_1to1_gpu_to_cpu(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array(\n                [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                dtype=np.float32,\n            )\n        else:\n            np_arr = np.array(\n                [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                dtype=np.float32,\n            )\n        device = flow.device(\"cuda\")\n        local_tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n        placement = flow.placement(\"cuda\", ranks=[0])\n        x = local_tensor.to_global(placement, flow.sbp.split(0))\n        new_placement = flow.placement(\"cpu\", ranks=[3])\n        y = x.to_global(new_placement, flow.sbp.broadcast)\n        
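# 1-to-1 transfer across device types: the shard held on cuda rank 0 is sent to cpu rank 3\n        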
test_case.assertEqual(y.placement, new_placement)\n        if flow.env.get_rank() == 3:\n            test_case.assertTrue(\n                np.array_equal(\n                    y.to_local().numpy(),\n                    np.array(\n                        [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                        dtype=np.float32,\n                    ),\n                )\n            )\n\n    def test_global_to_global_1to1_cpu_to_gpu(test_case):\n        if flow.env.get_rank() == 0:\n            np_arr = np.array(\n                [[4, 6, 5, 20], [6, 2, 5, 7], [3, 7, 5, 4], [6, 8, 9, 4]],\n                dtype=np.float32,\n            )\n        else:\n            np_arr = np.array(\n                [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                dtype=np.float32,\n            )\n        device = flow.device(\"cpu\")\n        local_tensor = flow.tensor(np_arr, device=device, dtype=flow.float32)\n        placement = flow.placement(\"cpu\", ranks=[1])\n        x = local_tensor.to_global(placement, flow.sbp.split(0))\n        new_placement = flow.placement(\"cuda\", ranks=[3])\n        y = x.to_global(new_placement, flow.sbp.broadcast)\n        test_case.assertEqual(y.placement, new_placement)\n        if flow.env.get_rank() == 3:\n            test_case.assertTrue(\n                np.array_equal(\n                    y.to_local().numpy(),\n                    np.array(\n                        [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n                        dtype=np.float32,\n                    ),\n                )\n            )\n\n\nclass GraphTestModel(nn.Graph):\n    def __init__(self, model):\n        super().__init__()\n        self.model = model\n\n    def build(self, x):\n        return self.model(x)\n\n\n@flow.unittest.skip_unless_1n2d()\nclass TestToGlobalAndLocal(flow.unittest.TestCase):\n    placement = flow.placement(\"cpu\", ranks=[0, 1])\n    sbp = None\n    model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))\n    local_graph_model = GraphTestModel(model)\n    global_graph_model = None\n\n    def __all_global(test_case, input, placement, sbp):\n        if type(input) == Tensor:\n            test_case.assertTrue(input.is_global)\n            # check placement\n            test_case.assertEqual(placement.type, input.placement.type)\n            test_case.assertListEqual(\n                list(placement.ranks), list(input.placement.ranks)\n            )\n            # check sbp\n            test_case.assertTupleEqual(sbp, input.sbp)\n        elif isinstance(input, (dict, tuple, list)):\n            node_tree = ArgsTree(input)\n            for node in node_tree.iter_nodes():\n                if isinstance(node, Tensor):\n                    test_case.assertTrue(node.is_global)\n                    # check placement\n                    test_case.assertEqual(placement.type, node.placement.type)\n                    test_case.assertListEqual(\n                        list(placement.ranks), list(node.placement.ranks)\n                    )\n                    # check sbp\n                    test_case.assertTupleEqual(sbp, node.sbp)\n\n    def __all_local(test_case, input):\n        if type(input) == Tensor:\n            test_case.assertFalse(input.is_global)\n        elif isinstance(input, (dict, tuple, list)):\n            node_tree = ArgsTree(input)\n            for node in node_tree.iter_nodes():\n                if isinstance(node, Tensor):\n                    
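# every tensor leaf in the (possibly nested) input must be a local tensor\n                    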
test_case.assertFalse(node.is_global)\n\n    def _test_any_input(test_case):\n        tensor = flow.zeros((3, 4))\n        tensor_list = [flow.tensor([1, 2, 3]), flow.randn((2, 3, 4))]\n        tensor_tuple = (flow.zeros((2, 2)), flow.ones((2, 3)), flow.randn((3, 5)))\n        tensor_dict = {\"tensor\": tensor, \"tensor_lt\": tensor_list}\n        random_combination = [\n            None,\n            1,\n            \"test_str\",\n            tensor,\n            tensor_list,\n            tensor_tuple,\n            tensor_dict,\n        ]\n\n        inputs = [\n            None,\n            100,\n            \"test_str\",\n            tensor,\n            tensor_list,\n            tensor_tuple,\n            tensor_dict,\n            random_combination,\n        ]\n        global_inputs = []\n        for i in inputs:\n            ret = flow.utils.global_view.to_global(\n                i,\n                placement=TestToGlobalAndLocal.placement,\n                sbp=TestToGlobalAndLocal.sbp,\n            )\n            test_case.__all_global(\n                ret,\n                placement=TestToGlobalAndLocal.placement,\n                sbp=TestToGlobalAndLocal.sbp,\n            )\n            global_inputs.append(ret)\n\n        for i in global_inputs:\n            ret = flow.utils.global_view.to_local(i)\n            test_case.__all_local(ret)\n\n    def _test_any_input_get_sbp_func(test_case):\n        def __get_sbp(input, tensor):\n            return TestToGlobalAndLocal.sbp\n\n        tensor = flow.zeros((3, 4))\n        tensor_list = [flow.tensor([1, 2, 3]), flow.randn((2, 3, 4))]\n        tensor_tuple = (flow.zeros((2, 2)), flow.ones((2, 3)), flow.randn((3, 5)))\n        tensor_dict = {\"tensor\": tensor, \"tensor_lt\": tensor_list}\n        random_combination = [\n            None,\n            1,\n            \"test_str\",\n            tensor,\n            tensor_list,\n            tensor_tuple,\n            tensor_dict,\n        ]\n\n        inputs = [\n            None,\n            100,\n            \"test_str\",\n            tensor,\n            tensor_list,\n            tensor_tuple,\n            tensor_dict,\n            random_combination,\n        ]\n        global_inputs = []\n        for i in inputs:\n            ret = flow.utils.global_view.to_global(\n                i, placement=TestToGlobalAndLocal.placement, sbp=__get_sbp,\n            )\n            test_case.__all_global(\n                ret,\n                placement=TestToGlobalAndLocal.placement,\n                sbp=TestToGlobalAndLocal.sbp,\n            )\n            global_inputs.append(ret)\n\n        for i in global_inputs:\n            ret = flow.utils.global_view.to_local(i)\n            test_case.__all_local(ret)\n\n    def _test_tensor_to_global(test_case):\n        local_tensor = flow.ones((3, 4))\n\n        # local tensor -> global tensor\n        global_tensor = flow.utils.global_view.to_global(\n            local_tensor,\n            placement=TestToGlobalAndLocal.placement,\n            sbp=TestToGlobalAndLocal.sbp,\n        )\n        test_case.assertTrue(global_tensor.is_global)\n\n        # global tensor -> global tensor\n        global_tensor = flow.utils.global_view.to_global(\n            global_tensor,\n            placement=TestToGlobalAndLocal.placement,\n            sbp=TestToGlobalAndLocal.sbp,\n        )\n        test_case.assertTrue(global_tensor.is_global)\n\n        # passing no placement and sbp\n        with test_case.assertRaises(ValueError):\n            global_tensor = 
flow.utils.global_view.to_global(\n                local_tensor, placement=None, sbp=None\n            )\n\n        # wrong sbp type\n        with test_case.assertRaises(TypeError):\n            global_tensor = flow.utils.global_view.to_global(\n                local_tensor,\n                placement=TestToGlobalAndLocal.placement,\n                sbp=(TestToGlobalAndLocal.sbp, 0),\n            )\n\n    def _test_tensor_to_local(test_case):\n        # global tensor -> local tensor\n        global_tensor = flow.ones(\n            (3, 4),\n            placement=TestToGlobalAndLocal.placement,\n            sbp=TestToGlobalAndLocal.sbp,\n        )\n        local_tensor = flow.utils.global_view.to_local(global_tensor)\n        test_case.assertFalse(local_tensor.is_global)\n\n    def __test_state_dict_to_global(test_case, local_state_dict):\n        # local state dict -> global state dict\n        global_state_dict = flow.utils.global_view.to_global(\n            local_state_dict,\n            placement=TestToGlobalAndLocal.placement,\n            sbp=TestToGlobalAndLocal.sbp,\n        )\n        test_case.__all_global(\n            global_state_dict,\n            placement=TestToGlobalAndLocal.placement,\n            sbp=TestToGlobalAndLocal.sbp,\n        )\n\n        # global state dict -> global state dict\n        global_state_dict = flow.utils.global_view.to_global(\n            global_state_dict,\n            placement=TestToGlobalAndLocal.placement,\n            sbp=TestToGlobalAndLocal.sbp,\n        )\n        test_case.__all_global(\n            global_state_dict,\n            placement=TestToGlobalAndLocal.placement,\n            sbp=TestToGlobalAndLocal.sbp,\n        )\n\n    def __test_state_dict_to_local(test_case, global_state_dict):\n        # global state dict -> local state dict\n        local_state_dict = flow.utils.global_view.to_local(global_state_dict)\n        test_case.__all_local(local_state_dict)\n\n        # local input: to_local should only display a warning\n        local_state_dict = flow.utils.global_view.to_local(local_state_dict)\n\n    def _test_eager_state_dict(test_case):\n        test_case.__test_state_dict_to_global(TestToGlobalAndLocal.model.state_dict())\n        global_model = TestToGlobalAndLocal.model.to_global(\n            placement=TestToGlobalAndLocal.placement, sbp=TestToGlobalAndLocal.sbp\n        )\n        test_case.__test_state_dict_to_local(global_model.state_dict())\n\n    def _test_graph_state_dict(test_case):\n        test_case.__test_state_dict_to_global(\n            TestToGlobalAndLocal.local_graph_model.state_dict()\n        )\n        test_case.__test_state_dict_to_local(\n            TestToGlobalAndLocal.global_graph_model.state_dict()\n        )\n\n    def test_to_global_local(test_case):\n        sbp_types = [\n            (flow.sbp.broadcast,),\n            (flow.sbp.split(0),),\n            (flow.sbp.partial_sum,),\n        ]\n        for sbp in sbp_types:\n            TestToGlobalAndLocal.sbp = sbp\n            TestToGlobalAndLocal.global_graph_model = GraphTestModel(\n                TestToGlobalAndLocal.model.to_global(\n                    placement=TestToGlobalAndLocal.placement, sbp=sbp\n                )\n            )\n            test_case._test_any_input()\n            test_case._test_any_input_get_sbp_func()\n            test_case._test_tensor_to_global()\n            test_case._test_tensor_to_local()\n            test_case._test_eager_state_dict()\n            test_case._test_graph_state_dict()\n\n\nif __name__ == \"__main__\":\n    
unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_chunk.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_chunk(test_case, ndim, placement, sbp):\n    dims = [random(1, 3).to(int) * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    dim = random(-ndim, ndim).to(int)\n    chunks = random(low=1, high=4).to(int)\n    y = torch.chunk(x, chunks=chunks, dim=dim)\n    z = torch.cat(y, dim=dim)\n    return z\n\n\nclass TestModule(flow.unittest.TestCase):\n    @globaltest\n    def test_chunk(test_case):\n        for placement in all_placement():\n            ndim = random(1, 4).to(int).value()\n            for sbp in all_sbp(placement, max_dim=min(ndim, 2)):\n                _test_chunk(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_clone.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef do_test_clone_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    z = y.clone()\n    return z\n\n\nclass TestCloneConsistent(flow.unittest.TestCase):\n    @globaltest\n    def test_clone(test_case):\n        # random ndim in range [1,4]\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                do_test_clone_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_coin_flip.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_global_coin_flip(\n    test_case, batch_size, random_seed, probability, placement, sbp\n):\n    m = flow.nn.CoinFlip(\n        batch_size, random_seed, probability, placement=placement, sbp=sbp\n    )\n    x = m()\n\n    test_case.assertEqual(x.shape[0], batch_size)\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\ndef _test_graph_coin_flip(\n    test_case, batch_size, random_seed, probability, placement, sbp\n):\n    class GlobalCoinFlipGraph(flow.nn.Graph):\n        def __init__(self,):\n            super().__init__()\n            self.m = flow.nn.CoinFlip(\n                batch_size, random_seed, probability, placement=placement, sbp=sbp\n            )\n\n        def build(self):\n            return self.m()\n\n    model = GlobalCoinFlipGraph()\n    x = model()\n\n    test_case.assertEqual(x.shape[0], batch_size)\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\nclass TestCoinFlipGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_coin_flip_global(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"batch_size\"] = [8, 64]\n        arg_dict[\"random_seed\"] = [None, 1, -1]\n        arg_dict[\"probability\"] = [0.0, 0.5, 1.0]\n        for args in GenArgDict(arg_dict):\n            for placement in all_placement():\n                # TODO: CoinFlip support cuda kernel\n                if placement.type == \"cuda\":\n                    continue\n\n                for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True):\n                    _test_global_coin_flip(\n                        test_case, **args, placement=placement, sbp=sbp\n                    )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n2d()\n    def test_coin_flip_graph(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"batch_size\"] = [8]\n        arg_dict[\"random_seed\"] = [None, 1, -1]\n        arg_dict[\"probability\"] = [0.0, 0.5, 1.0]\n        arg_dict[\"placement\"] = [\n            # 1d\n            flow.placement(\"cpu\", ranks=[0, 1]),\n            # TODO: CoinFlip support cuda kernel\n            #  flow.placement(\"cuda\", ranks=[0, 1]),\n            # 2d\n            flow.placement(\"cpu\", ranks=[[0, 1],]),\n            # TODO: CoinFlip support cuda kernel\n            #  flow.placement(\"cuda\", ranks=[[0, 1],]),\n        ]\n        for args in GenArgDict(arg_dict):\n            placement = args[\"placement\"]\n            for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True):\n                _test_graph_coin_flip(test_case, **args, 
sbp=sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_concat.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_cat_with_random_data(test_case, placement, sbp):\n    x = random_tensor(ndim=2, dim0=8, dim1=8).to_global(placement=placement, sbp=sbp)\n    return torch.cat((x, x), random(0, 2).to(int))\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_concat_with_input_0_size_data(test_case, placement, sbp):\n    x = random_tensor(4, 8, 8, 2, 4).to_global(placement=placement, sbp=sbp)\n    y = random_tensor(4, 8, 8, random(0, 3) * 8, 4).to_global(\n        placement=placement, sbp=sbp\n    )\n    z = torch.cat((x, y), dim=2)\n    return z\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_concat_with_output_0_size_data(test_case, placement, sbp):\n    x = random_tensor(4, 8, 8, 0, 4).to_global(placement=placement, sbp=sbp)\n    y = random_tensor(4, 8, 8, 0, 4).to_global(placement=placement, sbp=sbp)\n    z = torch.cat((x, y), dim=2)\n    return z\n\n\n@autotest(n=1, check_graph=True)\ndef _test_cat_only_one_tensor(test_case, placement, sbp):\n    x = random_tensor(4, 8, 8, random(1, 3) * 8, 8).to_global(\n        placement=placement, sbp=sbp\n    )\n    return torch.cat((x,), 0)\n\n\nclass TestModule(flow.unittest.TestCase):\n    @globaltest\n    def test_cat_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_cat_with_random_data(test_case, placement, sbp)\n\n    @globaltest\n    def test_cat_only_one_tensor(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_cat_only_one_tensor(test_case, placement, sbp)\n\n    @globaltest\n    def test_concat_with_input_0_size_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_concat_with_input_0_size_data(test_case, placement, sbp)\n\n    @globaltest\n    def test_concat_with_output_0_size_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_concat_with_output_0_size_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_constant.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_global_new_full(test_case, shape, full_value, placement, sbp):\n\n    np_res = np.full(shape, full_value)\n    x = flow.ones(shape)\n    y = x.new_full(shape, full_value, placement=placement, sbp=sbp)\n\n    test_case.assertEqual(y.shape, flow.Size(shape))\n    test_case.assertEqual(y.sbp, sbp)\n    test_case.assertEqual(y.placement, placement)\n\n    y = y.to_global(\n        placement=placement,\n        sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))],\n    ).to_local()\n    test_case.assertTrue(np.array_equal(y.numpy(), np_res))\n\n\ndef _test_global_graph_new_full(test_case, shape, full_value, placement, sbp):\n\n    np_res = np.full(shape, full_value)\n\n    class GlobalNewFullGraph(flow.nn.Graph):\n        def __init__(self,):\n            super().__init__()\n\n        def build(self,):\n            x = flow.ones(shape)\n            y = x.new_full(shape, full_value, placement=placement, sbp=sbp)\n            return y\n\n    model = GlobalNewFullGraph()\n    y = model()\n\n    test_case.assertEqual(y.shape, flow.Size(shape))\n    test_case.assertEqual(y.sbp, sbp)\n    test_case.assertEqual(y.placement, placement)\n\n    y = y.to_global(\n        placement=placement,\n        sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))],\n    ).to_local()\n    test_case.assertTrue(np.array_equal(y.numpy(), np_res))\n\n\ndef _test_global_constant(test_case, func, shape, placement, sbp):\n    func2 = None\n    if func == \"ones\":\n        func = flow.ones\n        np_res = np.ones(shape)\n    elif func == \"zeros\":\n        func = flow.zeros\n        np_res = np.zeros(shape)\n    elif func == \"new_zeros\":\n        func = flow.zeros\n        np_res = np.zeros(shape)\n        func2 = flow.new_zeros\n    else:\n        raise NotImplementedError\n\n    x = func(*shape, placement=placement, sbp=sbp)\n    if func2:\n        x = func2(x)\n\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n    x = x.to_global(\n        placement=placement,\n        sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))],\n    ).to_local()\n    test_case.assertTrue(np.array_equal(x.numpy(), np_res))\n\n\ndef _test_graph_constant(test_case, func, shape, placement, sbp):\n    func2 = None\n    if func == \"ones\":\n        func = flow.ones\n        np_res = np.ones(shape)\n    elif func == \"zeros\":\n        func = flow.zeros\n        np_res = np.zeros(shape)\n    elif func == \"new_zeros\":\n        func = flow.zeros\n        np_res = np.zeros(shape)\n        func2 = flow.new_zeros\n    else:\n        raise 
NotImplementedError\n\n    class GlobalConstantGraph(flow.nn.Graph):\n        def __init__(self,):\n            super().__init__()\n\n        def build(self):\n            x = func(*shape, placement=placement, sbp=sbp)\n            if func2:\n                x = func2(x)\n            return x\n\n    model = GlobalConstantGraph()\n    x = model()\n\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n    x = x.to_global(\n        placement=placement,\n        sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))],\n    ).to_local()\n    test_case.assertTrue(np.array_equal(x.numpy(), np_res))\n\n\nclass TestConstantGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_constant_global(test_case):\n        shapes = [(8,), (8, 8,), (8, 8, 8)]\n        functions = [\n            \"ones\",\n            \"zeros\",\n            \"new_zeros\",\n        ]\n        for func in functions:\n            for shape in shapes:\n                for placement in all_placement():\n                    for sbp in all_sbp(\n                        placement, max_dim=len(shape), except_partial_sum=True\n                    ):\n                        _test_global_constant(test_case, func, shape, placement, sbp)\n\n        full_values = [2, 3, 4]\n        for full_value in full_values:\n            for shape in shapes:\n                for placement in all_placement():\n                    for sbp in all_sbp(placement, max_dim=len(shape),):\n                        _test_global_new_full(\n                            test_case, shape, full_value, placement, sbp\n                        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n2d()\n    def test_constant_graph(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"func\"] = [\"ones\", \"zeros\", \"new_zeros\"]\n        arg_dict[\"shape\"] = [(8,), (8, 8,), (8, 8, 8)]\n        arg_dict[\"placement\"] = [\n            # 1d\n            flow.placement(\"cpu\", ranks=[0, 1]),\n            flow.placement(\"cuda\", ranks=[0, 1]),\n            # 2d\n            flow.placement(\"cpu\", ranks=[[0, 1],]),\n            flow.placement(\"cuda\", ranks=[[0, 1],]),\n        ]\n\n        for args in GenArgDict(arg_dict):\n            func = args[\"func\"]\n            shape = args[\"shape\"]\n            placement = args[\"placement\"]\n            for sbp in all_sbp(placement, max_dim=len(shape), except_partial_sum=True):\n                _test_graph_constant(test_case, func, shape, placement, sbp)\n        full_values = [2, 3, 4]\n        shapes = [(8,), (8, 8,), (8, 8, 8)]\n        for full_value in full_values:\n            for shape in shapes:\n                for placement in all_placement():\n                    for sbp in all_sbp(placement, max_dim=len(shape)):\n                        _test_global_graph_new_full(\n                            test_case, shape, full_value, placement, sbp\n                        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_ctc_loss.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nimport torch\nfrom oneflow.test_utils.automated_test_util.generators import *\nfrom oneflow.test_utils.automated_test_util.torch_flow_dual_object import globaltest\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef log_softmax(logits, axis=0):\n    max_value = np.max(logits, axis, keepdims=True)\n    exp = np.exp(logits - max_value)\n    exp_sum = np.sum(exp, axis, keepdims=True)\n    dist = exp / exp_sum\n    return np.log(dist)\n\n\ndef _compare_torch_and_oneflow(\n    test_case,\n    torch_ctc_loss,\n    flow_ctc_loss,\n    placement,\n    module_sbp,\n    in_sbp,\n    max_input_length,\n    batch_size,\n    num_classes,\n    max_target_length,\n):\n    log_probs = np.random.random(\n        size=(max_input_length, batch_size, num_classes)\n    ).astype(np.float32)\n    log_probs = log_softmax(log_probs, axis=2)\n    targets = np.random.randint(\n        1, high=num_classes, size=(batch_size, max_target_length), dtype=np.int32\n    )\n    input_lengths = np.random.randint(\n        max_input_length / 2, high=max_input_length, size=(batch_size,), dtype=np.int32\n    )\n    target_lengths = np.random.randint(\n        max_target_length / 2,\n        high=max_target_length,\n        size=(batch_size,),\n        dtype=np.int32,\n    )\n\n    log_probs_torch = torch.tensor(log_probs, dtype=torch.float32, requires_grad=True)\n    targets_torch = torch.tensor(targets, dtype=torch.int32)\n    input_lengths_torch = torch.tensor(input_lengths, dtype=torch.int32)\n    target_lengths_torch = torch.tensor(target_lengths, dtype=torch.int32)\n\n    log_probs_flow = (\n        flow.tensor(log_probs, dtype=flow.float32, requires_grad=True)\n        .to_global(flow.placement.all(\"cpu\"), flow.sbp.broadcast)\n        .to_global(placement=placement, sbp=in_sbp)\n    )\n    targets_flow = (\n        flow.tensor(targets, dtype=flow.int32)\n        .to_global(flow.placement.all(\"cpu\"), flow.sbp.broadcast)\n        .to_global(placement=placement, sbp=in_sbp)\n    )\n    input_lengths_flow = (\n        flow.tensor(input_lengths, dtype=flow.int32)\n        .to_global(flow.placement.all(\"cpu\"), flow.sbp.broadcast)\n        .to_global(placement=placement, sbp=in_sbp)\n    )\n    target_lengths_flow = (\n        flow.tensor(target_lengths, dtype=flow.int32)\n        .to_global(flow.placement.all(\"cpu\"), flow.sbp.broadcast)\n        .to_global(placement=placement, sbp=in_sbp)\n    )\n\n    out_torch = torch_ctc_loss(\n        log_probs_torch, targets_torch, input_lengths_torch, target_lengths_torch\n    )\n    out_flow = flow_ctc_loss(\n        log_probs_flow, targets_flow, input_lengths_flow, target_lengths_flow\n    )\n\n    # check forward\n    local_output = out_flow.to_global(\n        placement=placement,\n        sbp=[flow.sbp.broadcast for _ in 
range(len(placement.ranks.shape))],\n    ).to_local()\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.allclose(\n                out_torch.cpu().detach().numpy(),\n                local_output.numpy(),\n                rtol=1e-05,\n                atol=1e-05,\n            )\n        )\n\n    # check backward: compare the gradients w.r.t. log_probs, not the inputs themselves\n    log_probs_flow.retain_grad()  # non-leaf global tensor, so its grad must be retained explicitly\n    out_torch.sum().backward()\n    out_flow.sum().backward()\n    local_x_grad = log_probs_flow.grad.to_global(\n        placement=placement,\n        sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))],\n    ).to_local()\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.allclose(\n                log_probs_torch.grad.cpu().numpy(),\n                local_x_grad.numpy(),\n                rtol=1e-05,\n                atol=1e-05,\n            )\n        )\n\n\ndef _test_ctc_loss_impl(\n    test_case,\n    placement,\n    module_sbp,\n    in_sbp,\n    max_input_length,\n    batch_size,\n    num_classes,\n    max_target_length,\n    blank,\n    reduction,\n    zero_infinity,\n):\n    torch_ctc_loss = torch.nn.CTCLoss(\n        blank=blank, reduction=reduction, zero_infinity=zero_infinity\n    )\n    flow_ctc_loss = flow.nn.CTCLoss(\n        blank=blank, reduction=reduction, zero_infinity=zero_infinity\n    )\n    _compare_torch_and_oneflow(\n        test_case,\n        torch_ctc_loss,\n        flow_ctc_loss,\n        placement,\n        module_sbp,\n        in_sbp,\n        max_input_length,\n        batch_size,\n        num_classes,\n        max_target_length,\n    )\n\n\n@flow.unittest.skip_unless_1n2d()\n@unittest.skip(\"skip for now, because it segfaults several times in CI\")\nclass TestCTCLossGlobal(oneflow.unittest.TestCase):\n    @globaltest\n    def test_ctc_loss_global(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"max_input_length\"] = [20]\n        arg_dict[\"batch_size\"] = [4]\n        arg_dict[\"num_classes\"] = [5]\n        arg_dict[\"max_target_length\"] = [10]\n        arg_dict[\"blank\"] = [0, 4]\n        arg_dict[\"reduction\"] = [\"mean\", \"none\"]\n        arg_dict[\"zero_infinity\"] = [False, True]\n\n        module_sbp = flow.sbp.broadcast\n        for args in GenArgDict(arg_dict):\n            for placement in all_placement():\n                for in_sbp in all_sbp(placement):\n                    _test_ctc_loss_impl(\n                        test_case, placement, module_sbp, in_sbp, **args\n                    )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_cumprod.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, auto_backward=True, check_graph=True)\ndef _test_cumprod_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    dim = random(0, ndim).to(int).value()\n    z = torch.cumprod(y, dim)\n    return z\n\n\nclass TestCumprodGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_cumprod(test_case):\n        # random ndim in range [1,4]\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=min(2, ndim)):\n                _test_cumprod_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_cumsum.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, auto_backward=True, check_graph=True)\ndef _test_cumsum_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    dim = random(0, ndim).to(int).value()\n    z = torch.cumsum(x, dim)\n    return z\n\n\n@unittest.skip(\"This fails in multi-gpu\")\nclass TestCumsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_cumsum(test_case):\n        # random ndim in range [1,4]\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=min(2, ndim)):\n                _test_cumsum_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_deconv2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True, rtol=1e-2, atol=1e-3)\ndef _test_deconv2d_impl(test_case, placement, input_sbp):\n    ndim = 4\n    in_channels = random(1, 5).to(int).value() * 8\n    groups = random(1, 4).to(int).value()\n    out_channels = groups * 8\n    kernel_size = random(1, 4).to(int).value()\n    stride = random(1, 5).to(int).value()\n    padding = random(1, 3).to(int).value()\n    dilation = random(1, 5).to(int).value()\n    padding_mode = constant(\"zeros\")\n    m = torch.nn.ConvTranspose2d(\n        in_channels=in_channels,\n        out_channels=out_channels,\n        kernel_size=kernel_size,\n        stride=stride,\n        padding=padding,\n        dilation=dilation,\n        groups=groups,\n        padding_mode=padding_mode,\n        bias=False,\n    )\n    m.train(random())\n\n    weight_sbp = random_sbp(placement, max_dim=2, except_partial_sum=True)\n    m.weight = torch.nn.Parameter(\n        m.weight.to_global(placement=placement, sbp=weight_sbp)\n    )\n\n    if m.bias is not None:\n        bias_sbp = random_sbp(placement, max_dim=1)\n        m.bias = torch.nn.Parameter(m.bias.to_global(placement=placement, sbp=bias_sbp))\n\n    batch = random(1, 3).to(int).value() * 8\n    height = random(1, 5).to(int).value() * 8\n    width = random(1, 5).to(int).value() * 8\n    nchw = [batch, in_channels, height, width]\n    x = random_tensor(ndim, *nchw).to_global(placement=placement, sbp=input_sbp)\n    y = m(x)\n    return y\n\n\nclass TestDeconv2dGlobal(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    @globaltest\n    def test_deconv2d(test_case):\n        for placement in all_placement():\n            for input_sbp in all_sbp(placement, max_dim=2):\n                _test_deconv2d_impl(test_case, placement, input_sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_deform_conv2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nimport torch as pytorch\nimport torchvision\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_deform_conv2d(test_case, placement):\n    input_sbp = random_sbp(placement, max_dim=4)\n    input_dims = [8, 8, 8, 8]\n    input = random_tensor(4, *input_dims).to_global(placement=placement, sbp=input_sbp)\n\n    offset_sbp = random_sbp(placement, max_dim=2)\n    offset_dims = [8, 32, 5, 5]\n    offset = random_tensor(4, *offset_dims).to_global(\n        placement=placement, sbp=offset_sbp\n    )\n\n    mask_sbp = random_sbp(placement, max_dim=2)\n    mask_dims = [8, 4 * 4, 5, 5]\n    mask = random_tensor(4, *mask_dims).to_global(placement=placement, sbp=mask_sbp)\n\n    weight_sbp = random_sbp(placement, max_dim=2)\n    weight_dims = [8, 8, 4, 4]\n    weight = random_tensor(4, *weight_dims).to_global(\n        placement=placement, sbp=weight_sbp\n    )\n\n    bias_sbp = random_sbp(placement, max_dim=1)\n    bias_dims = [8]\n    bias = random_tensor(1, *bias_dims).to_global(placement=placement, sbp=bias_sbp)\n\n    flow_input = input.oneflow.detach().requires_grad_()\n    torch_input = input.pytorch.detach().requires_grad_()\n    flow_offset = offset.oneflow.detach().requires_grad_()\n    torch_offset = offset.pytorch.detach().requires_grad_()\n    flow_weight = weight.oneflow.detach().requires_grad_()\n    torch_weight = weight.pytorch.detach().requires_grad_()\n    flow_mask = mask.oneflow.detach().requires_grad_()\n    torch_mask = mask.pytorch.detach().requires_grad_()\n    flow_bias = bias.oneflow.detach().requires_grad_()\n    torch_bias = bias.pytorch.detach().requires_grad_()\n\n    torch_out = torchvision.ops.deform_conv2d(\n        torch_input, torch_offset, torch_weight, mask=torch_mask, bias=torch_bias\n    )\n    flow_out = oneflow.nn.functional.deform_conv2d(\n        flow_input, flow_offset, flow_weight, mask=flow_mask, bias=flow_bias\n    )\n\n    # compare forward\n    test_case.assertTrue(\n        np.allclose(\n            flow_out.numpy(), torch_out.detach().cpu().numpy(), rtol=1e-04, atol=1e-4\n        )\n    )\n\n    # compare backward\n    flow_out.sum().backward()\n    torch_out.sum().backward()\n\n    flow_input_grad = flow_input.grad\n    torch_input_grad = torch_input.grad.detach().cpu()\n    flow_weight_grad = flow_weight.grad\n    torch_weight_grad = torch_weight.grad.detach().cpu()\n    flow_offset_grad = flow_offset.grad\n    torch_offset_grad = torch_offset.grad.detach().cpu()\n    flow_mask_grad = flow_mask.grad\n    torch_mask_grad = torch_mask.grad.detach().cpu()\n    flow_bias_grad = flow_bias.grad\n    torch_bias_grad = torch_bias.grad.detach().cpu()\n\n    test_case.assertTrue(\n        np.allclose(\n            flow_input_grad.numpy(), torch_input_grad.numpy(), rtol=1e-04, atol=1e-4\n        )\n    )\n    test_case.assertTrue(\n     
   np.allclose(\n            flow_weight_grad.numpy(), torch_weight_grad.numpy(), rtol=1e-04, atol=1e-4\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            flow_offset_grad.numpy(), torch_offset_grad.numpy(), rtol=1e-04, atol=1e-4\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            flow_mask_grad.numpy(), torch_mask_grad.numpy(), rtol=1e-04, atol=1e-4\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            flow_bias_grad.numpy(), torch_bias_grad.numpy(), rtol=1e-04, atol=1e-4\n        )\n    )\n\n\nclass TestGlobalDeformConv2d(flow.unittest.TestCase):\n    @globaltest\n    def test_deform_conv2d(test_case):\n        for placement in all_placement():\n            for count in range(5):\n                _test_deform_conv2d(test_case, placement)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_det.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport re\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef det_all_placement():\n    cuda_version = flow._oneflow_internal.flags.cuda_version()\n    if cuda_version < 11000:  # cuSOLVER is only supported in CUDA 11.0 and above\n        return all_cpu_placement()\n    else:\n        # FIXME: remove this after fixing the bug of cuda global det\n        return all_cpu_placement()\n        # return all_placement()\n\n\n@autotest(n=1, check_graph=False, auto_backward=\"auto\")\ndef _test_det(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim - 2)]\n    square_dim = 8\n    dim_list.extend([square_dim] * 2)\n    x = (\n        random_tensor(ndim, *dim_list, low=-1)\n        .to(torch.double)\n        .to_global(placement, sbp)\n    )\n    return torch.linalg.det(x)\n\n\nclass TestDet(flow.unittest.TestCase):\n    @globaltest\n    def test_det(test_case):\n        ndim = random(2, 5).to(int).value()\n        for placement in det_all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_det(test_case, placement, sbp, ndim)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_diag.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef do_test_diag_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    return torch.diag(y)\n\n\nclass TestDiagGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_diag(test_case):\n        # random ndim in range [1,2]\n        ndim = random(1, 3).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                do_test_diag_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_diagonal.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, auto_backward=True, check_graph=True)\ndef _test_diagonal_impl(test_case, placement, sbp):\n    offset = random(-5, 5).to(int).value()\n    dim1 = random(-4, 4).to(int).value()\n    dim2 = random(-4, 4).to(int).value()\n\n    x = random_tensor(\n        ndim=4,\n        dim0=random(1, 4) * 8,\n        dim1=random(1, 4) * 8,\n        dim2=random(1, 4) * 8,\n        dim3=random(1, 4) * 8,\n    )\n    y = x.to_global(placement=placement, sbp=sbp)\n    z = torch.diagonal(y, offset, dim1, dim2)\n    return z\n\n\n@unittest.skip(\"TODO: fix this test\")\nclass TestDiagonalGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_diagonal(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4):\n                _test_diagonal_impl(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_div.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef do_test_div_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    x = x.to_global(placement=placement, sbp=sbp)\n    y = random_tensor(ndim, *dims)\n    y = y.to_global(placement=placement, sbp=sbp)\n\n    z = torch.div(x, y)\n    return z\n\n\nclass TestDivGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_div(test_case):\n        # random ndim in range [1,4]\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                do_test_div_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_dot.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef do_test_dot_impl(test_case, placement, sbp):\n    k = random(100, 1000) * 8\n    x = random_tensor(ndim=1, dim0=k).to_global(placement=placement, sbp=sbp)\n    y = random_tensor(ndim=1, dim0=k).to_global(placement=placement, sbp=sbp)\n    z = torch.dot(x, y)\n    return z\n\n\nclass TestDotGlobal(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 4 times in past week\")\n    @globaltest\n    def test_dot(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                do_test_dot_impl(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_dropout.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, auto_backward=True, check_graph=True, atol=1e-5, rtol=1e-5)\ndef _test_dropout_p01(test_case, placement, sbp, ndim, p):\n    dims = [random(1, 5) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    m = torch.nn.Dropout(p=p, inplace=False)\n    return m(x)\n\n\n@autotest(n=1, auto_backward=True, check_graph=True, atol=1e-5, rtol=1e-5)\ndef _test_dropout_eval_p01(test_case, placement, sbp, ndim, p):\n    dims = [random(1, 5) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    m = torch.nn.Dropout(p=p, inplace=False)\n    m.eval()\n    return m(x)\n\n\nclass TestDropoutGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_dropout_p01(test_case):\n        # random ndim in range [1,3]\n        ndim = random(1, 4).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=min(2, ndim)):\n                _test_dropout_p01(test_case, placement, sbp, ndim, p=0.0)\n                _test_dropout_p01(test_case, placement, sbp, ndim, p=1.0)\n\n    @globaltest\n    def test_dropout_eval(test_case):\n        # random ndim in range [1,3]\n        ndim = random(1, 4).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=min(2, ndim)):\n                _test_dropout_eval_p01(test_case, placement, sbp, ndim, 0.0)\n                _test_dropout_eval_p01(test_case, placement, sbp, ndim, 1.0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_alphaflod_usecase1.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True, atol=1e-2)\ndef _test_einsum_alphaflod_usecase1(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    x = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=dim0, dim2=dim1,)\n    y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"hij, ijc->ihc\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_alphaflod_usecase1(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=3):\n                _test_einsum_alphaflod_usecase1(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_alphaflod_usecase10.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_alphaflod_usecase10(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    dim2 = random(1, 3) * 8\n    x = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8, dim3=dim2,)\n    y = random_tensor(ndim=4, dim0=dim0, dim1=dim2, dim2=dim1, dim3=random(1, 3) * 8)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"bhqk,bkhc->bqhc\", g_x, g_y)\n    return z\n\n\n@unittest.skipIf(True, \"skip this test temporarily\")\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_alphaflod_usecase10(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4):\n                _test_einsum_alphaflod_usecase10(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_alphaflod_usecase11.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True, rtol=1e-3, atol=1e-4)\ndef _test_einsum_alphaflod_usecase11(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    x = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8, dim2=dim0,)\n    y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"bqa,ahc->bqhc\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_alphaflod_usecase11(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=3):\n                _test_einsum_alphaflod_usecase11(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_alphaflod_usecase2.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True, atol=1e-2)\ndef _test_einsum_alphaflod_usecase2(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,)\n    y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"rac,rab->rbc\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_alphaflod_usecase2(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=3):\n                _test_einsum_alphaflod_usecase2(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_alphaflod_usecase3.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_alphaflod_usecase3(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,)\n    y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"ra,rab->rb\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 4 times in past week\")\n    @globaltest\n    def test_einsum_alphaflod_usecase3(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_einsum_alphaflod_usecase3(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_alphaflod_usecase4.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True, atol=1e-2)\ndef _test_einsum_alphaflod_usecase4(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    x = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=dim0, dim2=dim1,)\n    y = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=dim0, dim2=dim1,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"qhc,khc->qkh\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_alphaflod_usecase4(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=3):\n                _test_einsum_alphaflod_usecase4(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_alphaflod_usecase5.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_alphaflod_usecase5(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    x = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=dim0,)\n    y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"nm, mrc->nrc\", g_x, g_y)\n    return z\n\n\n@unittest.skip(\"this case fails in multi gpu. TODO: depeng, shenghang\")\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_alphaflod_usecase5(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_einsum_alphaflod_usecase5(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_alphaflod_usecase6.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True, atol=1e-2)\ndef _test_einsum_alphaflod_usecase6(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=dim1,)\n    y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=dim1,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"abc,adc->bdc\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_alphaflod_usecase6(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=3):\n                _test_einsum_alphaflod_usecase6(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_alphaflod_usecase7.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True, rtol=1e-3, atol=1e-4)\ndef _test_einsum_alphaflod_usecase7(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    x = random_tensor(\n        ndim=4, dim0=random(1, 3) * 8, dim1=dim0, dim2=dim1, dim3=random(1, 3) * 8,\n    )\n    y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"dceb,cef->dbf\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_alphaflod_usecase7(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=3):\n                _test_einsum_alphaflod_usecase7(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_alphaflod_usecase8.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True, rtol=1e-3, atol=1e-4)\ndef _test_einsum_alphaflod_usecase8(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8,)\n    y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"acb,ade->dceb\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_alphaflod_usecase8(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=3):\n                _test_einsum_alphaflod_usecase8(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_alphaflod_usecase9.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True, rtol=1e-3, atol=1e-4)\ndef _test_einsum_alphaflod_usecase9(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    x = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8, dim2=dim0,)\n    y = random_tensor(ndim=2, dim0=dim0, dim1=random(1, 3) * 8,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"qkc,ch->hqk\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_alphaflod_usecase9(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_einsum_alphaflod_usecase9(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_attention.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_attention(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    dim2 = random(1, 3) * 8\n    x = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8, dim3=dim2,)\n    y = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8, dim3=dim2,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"b h i d, b h j d -> b h i j\", g_x, g_y)\n    return z\n\n\n@unittest.skipIf(True, \"skip this test temporarily\")\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_attention(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4):\n                _test_einsum_attention(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_batch_matmul.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True, atol=1e-2)\ndef _test_einsum_batch_matmul(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=dim1,)\n    y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"ijk,ikl->ijl\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_batch_matmul(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=3):\n                _test_einsum_batch_matmul(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_batch_matmul2.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True, atol=1e-2)\ndef _test_einsum_batch_matmul2(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    dim2 = random(1, 3) * 8\n    x = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8, dim3=dim2)\n    y = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=dim2, dim3=random(1, 3) * 8)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"b h i j, b h j d -> b h i d\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_batch_matmul2(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4):\n                _test_einsum_batch_matmul2(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_batch_matmul3.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True, atol=1e-2)\ndef _test_einsum_batch_matmul3(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    x = random_tensor(\n        ndim=4, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8, dim3=dim1,\n    )\n    y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=dim1,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"b x i d, b j d -> b x i j\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_batch_matmul3(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=3):\n                _test_einsum_batch_matmul3(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_batch_matmul4.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True, atol=1e-2)\ndef _test_einsum_batch_matmul4(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    x = random_tensor(\n        ndim=4, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8, dim3=dim1,\n    )\n    y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"b x i j, b j d -> b x i d\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_batch_matmul4(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=3):\n                _test_einsum_batch_matmul4(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_batch_matrix_vector_multiply.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_batch_matrix_vector_multiply(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    dim2 = random(1, 3) * 8\n    x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=dim2,)\n    y = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8, dim3=dim2,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"b i d, b i j d -> b i j\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 28 times in past week\")\n    @globaltest\n    def test_einsum_batch_matrix_vector_multiply(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=3):\n                _test_einsum_batch_matrix_vector_multiply(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_batch_permute.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_batch_permute(test_case, placement, sbp):\n    x = random_tensor(\n        ndim=5,\n        dim0=random(1, 3) * 8,\n        dim1=random(1, 3) * 8,\n        dim2=random(1, 3) * 8,\n        dim3=random(1, 3) * 8,\n        dim4=random(1, 3) * 8,\n    )\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"...ij->...ji\", g_x)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_batch_permute(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=5):\n                _test_einsum_batch_permute(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_bilinear_transformation.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_bilinear_transformation(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    dim2 = random(1, 3) * 8\n    x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,)\n    y = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=dim1, dim2=dim2,)\n    w = random_tensor(ndim=2, dim0=dim0, dim1=dim2,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    g_w = w.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"ik,jkl,il->ij\", g_x, g_y, g_w)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_bilinear_transformation(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_einsum_bilinear_transformation(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_eltwise_mul_sum_row.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_eltwise_mul_sum_row(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,)\n    y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"n d, n d -> n\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_eltwise_mul_sum_row(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_einsum_eltwise_mul_sum_row(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_eltwise_mul_then_reduce_sum.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_eltwise_mul_then_reduce_sum(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,)\n    y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    # NOTE(Liang Depeng): the same as 'ij,ij->'\n    z = torch.einsum(\"ij,ij\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_eltwise_mul_then_reduce_sum(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_einsum_eltwise_mul_then_reduce_sum(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_eltwise_multiply.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_eltwise_multiply(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,)\n    y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"ij,ij->ij\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_eltwise_multiply(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_einsum_eltwise_multiply(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_get_diagonal.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_get_diagonal(test_case, placement, sbp):\n    dim = random(1, 3) * 8\n    x = random_tensor(ndim=2, dim0=dim, dim1=dim,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"ii->i\", g_x)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_get_diagonal(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_einsum_get_diagonal(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_matmul.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True, rtol=1e-3)\ndef _test_einsum_matmul(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    dim2 = random(1, 3) * 8\n    x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,)\n    y = random_tensor(ndim=2, dim0=dim1, dim1=dim2,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    # NOTE(Liang Depeng): the same as 'ik,kj->ij'\n    z = torch.einsum(\"ik,kj\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    @unittest.skip(\"skip for now, becase it fails several times in CI\")\n    def test_einsum_matmul(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_einsum_matmul(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_matmul2.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_matmul2(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    x = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=dim0,)\n    y = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=dim0,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"i d, j d -> i j\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 4 times in past week\")\n    @globaltest\n    def test_einsum_matmul2(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_einsum_matmul2(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_matrix_column_sum.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True, rtol=1e-3)\ndef _test_einsum_matrix_column_sum(test_case, placement, sbp):\n    x = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"ij->j\", g_x)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 8 times in past week\")\n    @globaltest\n    def test_einsum_matrix_column_sum(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_einsum_matrix_column_sum(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_matrix_transpose.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_matrix_transpose(test_case, placement, sbp):\n    x = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"ij->ji\", g_x)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_matrix_transpose(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_einsum_matrix_transpose(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_matrix_vector_multiply.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_matrix_vector_multiply(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,)\n    y = random_tensor(ndim=1, dim0=dim1,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    # NOTE(Liang Depeng): the same as 'ik,k->i'\n    z = torch.einsum(\"ik,k\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_matrix_vector_multiply(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_einsum_matrix_vector_multiply(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_reduce_sum.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_reduce_sum(test_case, placement, sbp):\n    x = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"ij->\", g_x)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_reduce_sum(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_einsum_reduce_sum(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_tensor_contraction.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n# The rtol is too large caused by the expansion of random tensor range\n# of #9534. It should be checked again in the future.\n@autotest(n=1, check_graph=True, rtol=5e-1, atol=1e-3)\ndef _test_einsum_tensor_contraction(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    dim1 = random(1, 3) * 8\n    x = random_tensor(\n        ndim=4, dim0=random(1, 3) * 8, dim1=dim0, dim2=dim1, dim3=random(1, 3) * 8,\n    )\n    y = random_tensor(\n        ndim=5,\n        dim0=random(1, 3) * 8,\n        dim1=random(1, 3) * 8,\n        dim2=dim0,\n        dim3=random(1, 3) * 8,\n        dim4=dim1,\n    )\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"pqrs,tuqvr->pstuv\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_tensor_contraction(test_case):\n        for placement in all_placement():\n            if len(np.array(placement.ranks).shape) > 1 and all(\n                dim != 1 for dim in np.array(placement.ranks).shape\n            ):\n                print(\n                    f\"[{flow.env.get_rank()}] skip TestEinsumConsistent.test_einsum_tensor_contraction with {placement}\"\n                )\n                continue\n\n            for sbp in all_sbp(placement, max_dim=4):\n                _test_einsum_tensor_contraction(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_tensor_contraction2.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True, rtol=1e-3, atol=1e-4)\ndef _test_einsum_tensor_contraction2(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    x = random_tensor(\n        ndim=4,\n        dim0=random(1, 3) * 8,\n        dim1=dim0,\n        dim2=random(1, 3) * 8,\n        dim3=random(1, 3) * 8,\n    )\n    y = random_tensor(ndim=2, dim0=dim0, dim1=random(1, 3) * 8,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    z = torch.einsum(\"b n h w, n d -> b d h w\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 10 times in past week\")\n    @globaltest\n    def test_einsum_tensor_contraction2(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_einsum_tensor_contraction2(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_vector_inner_product.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_vector_inner_product(test_case, placement, sbp):\n    dim0 = random(1, 3) * 8\n    x = random_tensor(ndim=1, dim0=dim0,)\n    y = random_tensor(ndim=1, dim0=dim0,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    # NOTE(Liang Depeng): the same as 'i,i->'\n    z = torch.einsum(\"i,i\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_vector_inner_product(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_einsum_vector_inner_product(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_einsum_vector_outer_product.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_einsum_vector_outer_product(test_case, placement, sbp):\n    x = random_tensor(ndim=1, dim0=random(1, 3) * 8,)\n    y = random_tensor(ndim=1, dim0=random(1, 3) * 8,)\n    g_x = x.to_global(placement=placement, sbp=sbp)\n    g_y = y.to_global(placement=placement, sbp=sbp)\n    # NOTE(Liang Depeng): the same as 'i,j->ij'\n    z = torch.einsum(\"i,j\", g_x, g_y)\n    return z\n\n\nclass TestEinsumGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_einsum_vector_outer_product(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_einsum_vector_outer_product(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_empty.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_global_empty(test_case, func, shape, placement, sbp):\n    func2 = None\n    if func == \"empty\":\n        func = flow.empty\n    elif func == \"new_empty\":\n        func = flow.empty\n        func2 = flow.new_empty\n    elif func == \"empty_like\":\n        func = flow.empty\n        func2 = flow.empty_like\n    else:\n        raise NotImplementedError\n\n    x = func(*shape, placement=placement, sbp=sbp)\n    if func2:\n        if func2.__name__ == \"new_empty_op\":\n            x = func2(x, size=shape)\n        elif func2.__name__ == \"empty_like_op\":\n            x = func2(x)\n        else:\n            raise NotImplementedError\n\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\ndef _test_graph_empty(test_case, func, shape, placement, sbp):\n    func2 = None\n    if func == \"empty\":\n        func = flow.empty\n    elif func == \"new_empty\":\n        func = flow.empty\n        func2 = flow.new_empty\n    elif func == \"empty_like\":\n        func = flow.empty\n        func2 = flow.empty_like\n    else:\n        raise NotImplementedError\n\n    class GlobalEmptyGraph(flow.nn.Graph):\n        def __init__(self,):\n            super().__init__()\n\n        def build(self):\n            x = func(*shape, placement=placement, sbp=sbp)\n            if func2:\n                if func2.__name__ == \"new_empty_op\":\n                    x = func2(x, size=shape)\n                elif func2.__name__ == \"empty_like_op\":\n                    x = func2(x)\n                else:\n                    raise NotImplementedError\n            return x\n\n    model = GlobalEmptyGraph()\n    x = model()\n\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\nclass TestEmptyGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_empty_global(test_case):\n        shapes = [(8,), (8, 8,), (8, 8, 8)]\n        functions = [\n            \"empty\",\n            \"new_empty\",\n            \"empty_like\",\n        ]\n        for func in functions:\n            for shape in shapes:\n                for placement in all_placement():\n                    for sbp in all_sbp(\n                        placement, max_dim=len(shape), except_partial_sum=True\n                    ):\n                        _test_global_empty(test_case, func, shape, placement, sbp)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n2d()\n    def test_empty_graph(test_case):\n        arg_dict = OrderedDict()\n        
arg_dict[\"func\"] = [\"empty\", \"new_empty\", \"empty_like\"]\n        arg_dict[\"shape\"] = [(8,), (8, 8,), (8, 8, 8)]\n        arg_dict[\"placement\"] = [\n            # 1d\n            flow.placement(\"cpu\", ranks=[0, 1]),\n            flow.placement(\"cuda\", ranks=[0, 1]),\n            # 2d\n            flow.placement(\"cpu\", ranks=[[0, 1],]),\n            flow.placement(\"cuda\", ranks=[[0, 1],]),\n        ]\n        for args in GenArgDict(arg_dict):\n            func = args[\"func\"]\n            shape = args[\"shape\"]\n            placement = args[\"placement\"]\n            for sbp in all_sbp(placement, max_dim=len(shape), except_partial_sum=True):\n                _test_graph_empty(test_case, func, shape, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_eq.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef do_test_eq_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    x = x.to_global(placement=placement, sbp=sbp)\n    y = random_tensor(ndim, *dims)\n    y = y.to_global(placement=placement, sbp=sbp)\n\n    z = torch.eq(x, y)\n    return z\n\n\nclass TestEqGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_eq(test_case):\n        # random ndim in range [1,4]\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                do_test_eq_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_erf.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef do_test_erf_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    z = torch.erf(y)\n    return z\n\n\nclass TestErfGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_erf(test_case):\n        # random ndim in range [1,4]\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                do_test_erf_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_erfc.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, auto_backward=True, rtol=1e-3, atol=1e-3, check_graph=True)\ndef do_test_erfc_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 3) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    z = torch.erfc(y)\n    return z\n\n\nclass TestErfcGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_erfc(test_case):\n        # random ndim in range [1,4]\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                do_test_erfc_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_expand_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nimport torch\n\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_global_expand(\n    test_case,\n    input_shape,\n    expand_shape,\n    device=\"cuda\",\n    sbp=flow.sbp.broadcast,\n    verbose=False,\n):\n    # random input\n    input = np.random.randn(*input_shape)\n    if isinstance(input, np.ndarray):\n        input = input.astype(np.float32)\n\n    # torch computation\n    torch_x = torch.tensor(input, requires_grad=True)\n    torch_y = torch_x.expand(*expand_shape)\n    torch_y.sum().backward()\n\n    # oneflow computation\n    placement = flow.placement(device, np.array(range(flow.env.get_world_size())))\n    x = flow.tensor(input, requires_grad=True)\n    global_x = x.to_global(placement=placement, sbp=flow.sbp.broadcast)\n    if global_x.sbp != sbp:\n        global_x = global_x.to_global(sbp=sbp, grad_sbp=flow.sbp.broadcast)\n    y = global_x.expand(*expand_shape)\n    y.sum().backward()\n\n    y_b = y.to_global(sbp=flow.sbp.broadcast)\n\n    if flow.env.get_rank() == 0:\n        out_a = y_b.to_local().numpy()\n        out_b = torch_y.detach().cpu().numpy()\n        grad_a = x.grad.numpy()\n        grad_b = torch_x.grad.cpu().numpy()\n\n        if verbose:\n            print(\"\")\n            print(f\"{'=' * 10} {input_shape} -> {expand_shape} {'=' * 10}\")\n            print(f\"{'=' * 10} {device}, {sbp} {'=' * 10}\")\n            print(f\"{'-' * 20} compare out {'-' * 20}\")\n            print(out_a)\n            print(\"*\" * 20)\n            print(out_b)\n            print(\"\")\n            print(f\"{'-' * 20} compare grad {'-' * 20}\")\n            print(grad_a)\n            print(\"*\" * 20)\n            print(grad_b)\n\n        test_case.assertTrue(np.array_equal(out_a, out_b))\n        test_case.assertTrue(np.array_equal(grad_a, grad_b))\n\n\n@flow.unittest.skip_unless_1n2d()\nclass ExpandGlobalTestCase(oneflow.unittest.TestCase):\n    def test_global_expand(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"verbose\"] = [False]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"sbp\"] = [flow.sbp.split(0), flow.sbp.broadcast()]\n        arg_dict[\"shapes\"] = [\n            ((2, 2), (2, 2, 2)),\n            ((2, 1, 3), (2, 1, -1, -1, -1)),\n            ((2, 1, 3), (1, 2, -1, -1, -1)),\n            ((2, 1, 3), (2, 1, -1, 2, 3)),\n            ((2, 1, 3), (1, 2, 2, 2, -1)),\n        ]\n        for kwargs in GenArgDict(arg_dict):\n            assert \"shapes\" in kwargs\n            input_shape, expand_shape = kwargs.pop(\"shapes\")\n            _test_global_expand(test_case, input_shape, expand_shape, **kwargs)\n\n    def test_split_expand(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"verbose\"] = [False]\n        arg_dict[\"device\"] = [\"cuda\"]\n        
arg_dict[\"sbp\"] = [flow.sbp.split(0)]\n        arg_dict[\"shapes\"] = [\n            ((2,), (1, 2)),\n            ((2,), (2, 2)),\n        ]\n        for kwargs in GenArgDict(arg_dict):\n            assert \"shapes\" in kwargs\n            input_shape, expand_shape = kwargs.pop(\"shapes\")\n            _test_global_expand(test_case, input_shape, expand_shape, **kwargs)\n\n    def test_broadcast_scalar_expand(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"verbose\"] = [False]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"sbp\"] = [flow.sbp.broadcast()]\n        arg_dict[\"shapes\"] = [\n            ((), (1,)),\n            ((), (2,)),\n            ((), (1, 1)),\n            ((), (1, 2)),\n            ((), (2, 1)),\n            ((), (2, 2)),\n            ((), (2, 1, 2)),\n        ]\n        for kwargs in GenArgDict(arg_dict):\n            assert \"shapes\" in kwargs\n            input_shape, expand_shape = kwargs.pop(\"shapes\")\n            _test_global_expand(test_case, input_shape, expand_shape, **kwargs)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n\n# ONEFLOW_TEST_DEVICE_NUM=2 python3 -m oneflow.distributed.launch --nproc_per_node 2 test_global_expand_op.py\n"
  },
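A note on the `expand_shape` tuples in the test above: `-1` keeps the size of the corresponding existing dimension (aligned from the right), new leading dimensions may be prepended, and size-1 dimensions are broadcast. A local sketch of the same semantics, independent of any placement:

```python
import oneflow as flow

x = flow.randn(2, 1, 3)

# Two new leading dims; the three -1 entries keep x's own dims.
y = x.expand(2, 1, -1, -1, -1)
print(y.shape)  # (2, 1, 2, 1, 3)

# A size-1 dim can also be broadcast to a larger size.
z = x.expand(-1, 4, -1)
print(z.shape)  # (2, 4, 3)
```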
  {
    "path": "python/oneflow/test/modules/test_global_expm1.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef do_test_expm1_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    z = torch.expm1(y)\n    return z\n\n\nclass TestExpm1Global(flow.unittest.TestCase):\n    @globaltest\n    def test_expm1(test_case):\n        # random ndim in range [1,4]\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                do_test_expm1_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_eye.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef do_test_eye_impl(test_case, placement, sbp):\n    n = random(1, 5).to(int).value() * 8\n    m = random(1, 5).to(int).value() * 8\n    x = torch.eye(n, m)\n    x.oneflow = flow.tensor(\n        x.pytorch.cpu().detach().numpy(),\n        requires_grad=x.pytorch.requires_grad,\n        placement=placement,\n        sbp=sbp,\n    )\n    return x\n\n\nclass TestEyeGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_eye(test_case):\n        shape = random_tensor().shape\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                do_test_eye_impl(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_fill.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_fill_(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    value = random().to(float)\n    y = x + 1\n    y.fill_(value)\n    return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_fill_tensor_(test_case, ndim, placement, sbp):\n    dims = [random(2, 4) * 8 for i in range(ndim)]\n    x = (\n        random_tensor(ndim, *dims)\n        .to_global(placement=placement, sbp=sbp)\n        .requires_grad_()\n    )\n    value = (\n        torch.tensor(1.0)\n        .to_global(placement=placement, sbp=[flow.sbp.broadcast for _ in sbp])\n        .requires_grad_()\n    )\n    y = x + 1\n    y.oneflow = y.oneflow.to_global(placement, sbp)\n    y.fill_(value)\n    return y\n\n\nclass TestFillModule(flow.unittest.TestCase):\n    @globaltest\n    def test_fill_(test_case):\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_fill_(test_case, ndim, placement, sbp)\n                _test_fill_tensor_(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_flatten.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef do_test_flatten_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    start_dim = random(0, ndim).to(int).value()\n    end_dim = random(start_dim, ndim).to(int).value()\n\n    z = torch.flatten(x, start_dim, end_dim)\n    return z\n\n\nclass TestFlattenGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_flatten(test_case):\n        # random ndim in range [1,4]\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                do_test_flatten_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_flip.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_flip_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    new_dim = random(0, ndim).to(int).value()\n    z = torch.flip(y, constant([i for i in range(new_dim)]))\n    return z\n\n\nclass TestFlipGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_flip(test_case):\n        # random ndim in range [1,4]\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_flip_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_floor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef do_test_floor_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    z = torch.floor(y)\n    return z\n\n\nclass TestFloorGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_floor(test_case):\n        # random ndim in range [1,4]\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                do_test_floor_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_fmod.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport torch as torch_original\nfrom packaging import version\n\n# other.grad in torch.fmod(input, other) was not implemented before pytorch 1.11.0\ngrad_implemented = version.parse(torch_original.__version__) >= version.parse(\"1.11.0\")\n\n\n@autotest(n=1, auto_backward=grad_implemented, check_graph=True)\ndef do_test_fmod_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    x = x.to_global(placement=placement, sbp=sbp)\n    y = random_tensor(ndim, *dims)\n    y = y.to_global(placement=placement, sbp=sbp)\n\n    z = torch.fmod(x, y)\n    return z\n\n\nclass TestFmodGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_fmod(test_case):\n        # random ndim in range [1,5]\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                do_test_fmod_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_fold.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_fold_impl(test_case, placement, sbp):\n    ndim = 3\n    dims = [random(1, 4).to(int).value() * 8 for i in range(ndim)]\n    m = torch.nn.Fold(\n        output_size=constant(((dims[2] // 4) * 2, 4 * 2)),\n        kernel_size=constant(2),\n        dilation=constant(1),\n        padding=constant(0),\n        stride=constant(2),\n    )\n    m.train(random())\n\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    y = m(x)\n    func_y = torch.nn.functional.fold(\n        x,\n        output_size=constant(((dims[2] // 4) * 2, 4 * 2)),\n        kernel_size=constant(2),\n        dilation=constant(1),\n        padding=constant(0),\n        stride=constant(2),\n    )\n    return y, func_y\n\n\nclass TestFold(flow.unittest.TestCase):\n    @globaltest\n    def test_fold(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=3):\n                _test_fold_impl(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
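The constants in `_test_fold_impl` above are not arbitrary: `nn.Fold` requires the number of sliding blocks implied by `output_size`, `kernel_size`, and `stride` to equal the input's last dimension `L` (here `dims[2]`), and the channel dimension to be divisible by the product of the kernel sizes. A quick check of that bookkeeping under the test's parameters (kernel 2, stride 2, no padding or dilation):

```python
# dims are multiples of 8, so every division below is exact.
for d2 in (8, 16, 24):                 # possible dims[2] == L (block count)
    out_h, out_w = (d2 // 4) * 2, 4 * 2
    blocks_h = (out_h - 2) // 2 + 1    # floor((out - kernel) / stride) + 1
    blocks_w = (out_w - 2) // 2 + 1
    assert blocks_h * blocks_w == d2   # Fold needs L == blocks_h * blocks_w
```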
  {
    "path": "python/oneflow/test/modules/test_global_frac.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=False)\ndef _test_frac(test_case, ndim, placement, sbp):\n    shape = [random(2, 4) * 8 for i in range(ndim)]\n    input = random_tensor(ndim, *shape).to_global(placement=placement, sbp=sbp)\n    output = torch.frac(input)\n    return output\n\n\nclass TestModule(flow.unittest.TestCase):\n    @globaltest\n    def test_frac(test_case):\n        ndim = random(2, 4).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_frac(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_full.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_global_full(test_case, shape, placement, sbp):\n    x = flow.full(shape, 1.0, placement=placement, sbp=sbp)\n\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\ndef _test_global_full_tensor_scalar(test_case, shape, placement, sbp):\n    scalar_sbp = [flow.sbp.broadcast for _ in range(len(placement.ranks.shape))]\n    x1 = flow.tensor(1.0, placement=placement, sbp=scalar_sbp)\n    x2 = flow.full(shape, x1, placement=placement, sbp=sbp)\n    test_case.assertEqual(x2.shape, flow.Size(shape))\n    test_case.assertEqual(x2.sbp, sbp)\n    test_case.assertEqual(x2.placement, placement)\n\n\ndef _test_graph_full(test_case, shape, placement, sbp):\n    class GlobalFullGraph(flow.nn.Graph):\n        def __init__(self,):\n            super().__init__()\n\n        def build(self):\n            x = flow.full(shape, 1.0, placement=placement, sbp=sbp)\n            return x\n\n    model = GlobalFullGraph()\n    x = model()\n\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\ndef _test_graph_full_tensor_scalar(test_case, shape, placement, sbp):\n    class GlobalFullGraph2(flow.nn.Graph):\n        def __init__(self,):\n            super().__init__()\n\n        def build(self):\n            x = flow.full(\n                shape,\n                flow.tensor(\n                    1.0,\n                    placement=placement,\n                    sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))],\n                ),\n                placement=placement,\n                sbp=sbp,\n            )\n            return x\n\n    model = GlobalFullGraph2()\n    x = model()\n\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\nclass TestFullGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_full_global(test_case):\n        shapes = [(8,), (8, 8,), (8, 8, 8)]\n        for shape in shapes:\n            for placement in all_placement():\n                for sbp in all_sbp(\n                    placement, max_dim=len(shape), except_partial_sum=True\n                ):\n                    _test_global_full(test_case, shape, placement, sbp)\n                    _test_global_full_tensor_scalar(test_case, shape, placement, sbp)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n2d()\n    def test_full_graph(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [[8], [8, 
8], [8, 8, 8]]\n        arg_dict[\"placement\"] = [\n            # 1d\n            flow.placement(\"cpu\", ranks=[0, 1]),\n            flow.placement(\"cuda\", ranks=[0, 1]),\n            # 2d\n            flow.placement(\"cpu\", ranks=[[0, 1],]),\n            flow.placement(\"cuda\", ranks=[[0, 1],]),\n        ]\n        for args in GenArgDict(arg_dict):\n            shape = args[\"shape\"]\n            placement = args[\"placement\"]\n            for sbp in all_sbp(placement, max_dim=len(shape), except_partial_sum=True):\n                _test_graph_full(test_case, shape, placement, sbp)\n                _test_graph_full_tensor_scalar(test_case, shape, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
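`_test_global_full_tensor_scalar` above highlights a detail of scalar global tensors: a 0-dim tensor has no axis to split, so its sbp must be broadcast, and one sbp entry is needed per axis of the placement's rank hierarchy, which is what `len(placement.ranks.shape)` counts. A sketch of that construction, assuming a launch with at least two ranks (e.g. `--nproc_per_node 2`):

```python
import oneflow as flow

for ranks in ([0, 1], [[0, 1]]):  # 1-D and 2-D rank hierarchies
    p = flow.placement("cpu", ranks=ranks)
    # One broadcast entry per hierarchy axis.
    scalar_sbp = [flow.sbp.broadcast for _ in range(len(p.ranks.shape))]
    scalar = flow.tensor(1.0, placement=p, sbp=scalar_sbp)
    print(p.ranks.shape, scalar.sbp)  # (2,) vs. (1, 2), matching sbp lengths
```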
  {
    "path": "python/oneflow/test/modules/test_global_full_like.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_global_full_like(test_case, shape, placement, sbp):\n    x_ = flow.randn(shape)\n    x = flow.full_like(x_, 1.0, placement=placement, sbp=sbp)\n\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\ndef _test_graph_full_like(test_case, shape, placement, sbp):\n    class GlobalFullLikeGraph(flow.nn.Graph):\n        def __init__(self,):\n            super().__init__()\n\n        def build(self):\n            x_ = flow.randn(shape)\n            x = flow.full_like(x_, 1.0, placement=placement, sbp=sbp)\n            return x\n\n    model = GlobalFullLikeGraph()\n    x = model()\n\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\nclass TestFillLikeGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_full_like_global(test_case):\n        shapes = [(8,), (8, 8,), (8, 8, 8)]\n        for shape in shapes:\n            for placement in all_placement():\n                for sbp in all_sbp(\n                    placement, max_dim=len(shape), except_partial_sum=True\n                ):\n                    _test_global_full_like(test_case, shape, placement, sbp)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n2d()\n    def test_full_like_graph(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [[8], [8, 8], [8, 8, 8]]\n        arg_dict[\"placement\"] = [\n            # 1d\n            flow.placement(\"cpu\", ranks=[0, 1]),\n            flow.placement(\"cuda\", ranks=[0, 1]),\n            # 2d\n            flow.placement(\"cpu\", ranks=[[0, 1],]),\n            flow.placement(\"cuda\", ranks=[[0, 1],]),\n        ]\n        for args in GenArgDict(arg_dict):\n            shape = args[\"shape\"]\n            placement = args[\"placement\"]\n            for sbp in all_sbp(placement, max_dim=len(shape), except_partial_sum=True):\n                _test_graph_full_like(test_case, shape, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_greater.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=10, auto_backward=False, check_graph=True)\ndef _test_greater_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x1 = random_tensor(ndim, *dims)\n    x2 = x1.to_global(placement=placement, sbp=sbp)\n    y1 = random_tensor(ndim, *dims)\n    y2 = y1.to_global(placement=placement, sbp=sbp)\n\n    z = torch.gt(x2, y2)\n    return z\n\n\n@unittest.skip(\"TODO: houjiang, yushun. this test might fail\")\nclass TestGreaterGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_greater(test_case):\n        # random ndim in range [1,4]\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_greater_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_greater_equal.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef do_test_greater_equal_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x1 = random_tensor(ndim, *dims)\n    x1 = x1.to_global(placement=placement, sbp=sbp)\n    x2 = random_tensor(ndim, *dims)\n    x2 = x2.to_global(placement=placement, sbp=sbp)\n\n    z = torch.ge(x1, x2)\n    return z\n\n\nclass TestGreaterEqualGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_greater_equal(test_case):\n        # random ndim in range [1,4]\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                do_test_greater_equal_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_grid_sample.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=1, rtol=1e-03, atol=1e-04, check_graph=True)\ndef _test_flow_grid_sample_cudnn(test_case, placement, sbp):\n    # cudnn only support 4D input, with mode = 'bilinear' && padding_mode = 'zeros' && align_corners\n    N = random(1, 3).to(int) * 8\n    C = random(1, 3).to(int) * 8\n    in_H = random(1, 8).to(int)\n    in_W = random(1, 8).to(int)\n    out_H = random(1, 8).to(int)\n    out_W = random(1, 8).to(int)\n    mode = \"bilinear\"\n    padding_mode = \"zeros\"\n    align_corners = True\n    theta = random_tensor(ndim=3, dim0=N, dim1=2, dim2=3).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=1)\n    )\n    grid = torch.nn.functional.affine_grid(\n        theta, (N, C, out_H, out_W), align_corners=align_corners\n    )\n    input = random_tensor(ndim=4, dim0=N, dim1=C, dim2=in_H, dim3=in_W).to_global(\n        placement=placement, sbp=sbp\n    )\n    output = torch.nn.functional.grid_sample(\n        input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners,\n    )\n    return output\n\n\n# This test may fail due to using ::floor in backward\n# floor(1.99999988) = 1 and floor(2.000000) = 2, then select differente images pixel\n@autotest(\n    n=1,\n    auto_backward=False,\n    rtol=1e-03,\n    atol=1e-04,\n    check_graph=True,\n    check_allclose=False,\n)\ndef _test_flow_grid_sample_4d(test_case, placement, sbp):\n    N = random(1, 3).to(int) * 8\n    C = random(1, 3).to(int) * 8\n    in_H = random(1, 8).to(int)\n    in_W = random(1, 8).to(int)\n    out_H = random(1, 8).to(int)\n    out_W = random(1, 8).to(int)\n    mode = oneof(\"bilinear\", \"nearest\", \"bicubic\")\n    padding_mode = oneof(\"zeros\", \"border\", \"reflection\")\n    align_corners = oneof(True, False)\n    theta = random_tensor(ndim=3, dim0=N, dim1=2, dim2=3).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=1)\n    )\n    grid = torch.nn.functional.affine_grid(\n        theta, (N, C, out_H, out_W), align_corners=align_corners\n    )\n    input = random_tensor(ndim=4, dim0=N, dim1=C, dim2=in_H, dim3=in_W).to_global(\n        placement=placement, sbp=sbp\n    )\n    output = torch.nn.functional.grid_sample(\n        input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners,\n    )\n    return output\n\n\n@autotest(n=1, auto_backward=False, rtol=1e-03, atol=1e-03, check_graph=True)\ndef _test_flow_grid_sample_5d(test_case, placement, sbp):\n    N = random(1, 3).to(int) * 8\n    C = random(1, 3).to(int) * 8\n    in_D = random(1, 8).to(int)\n    in_H = random(1, 8).to(int)\n    in_W = random(1, 8).to(int)\n    out_D = random(1, 8).to(int)\n    out_H = random(1, 8).to(int)\n    out_W = random(1, 8).to(int)\n    mode = oneof(\"bilinear\", \"nearest\")\n   
 padding_mode = oneof(\"zeros\", \"border\", \"reflection\")\n    align_corners = oneof(True, False)\n    theta = random_tensor(ndim=3, dim0=N, dim1=3, dim2=4).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=1)\n    )\n    grid = torch.nn.functional.affine_grid(\n        theta, (N, C, out_D, out_H, out_W), align_corners=align_corners\n    )\n    input = random_tensor(\n        ndim=5, dim0=N, dim1=C, dim2=in_D, dim3=in_H, dim4=in_W\n    ).to_global(placement=placement, sbp=sbp)\n    output = torch.nn.functional.grid_sample(\n        input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners,\n    )\n    return output\n\n\nclass TestGridSample(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it may fail in CI\")\n    @globaltest\n    def test_grid_sample(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                if placement.type == \"cuda\":\n                    _test_flow_grid_sample_cudnn(test_case, placement, sbp)\n                _test_flow_grid_sample_4d(test_case, placement, sbp)\n                _test_flow_grid_sample_5d(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
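The `check_allclose=False` on `_test_flow_grid_sample_4d` above exists because the backward pass converts grid coordinates to pixel indices with `floor`, so a rounding disagreement of about 1e-7 between frameworks can flip the selected pixel at an integer boundary, exactly as the comment's example says:

```python
import math

# Coordinates that agree to ~7 decimal places can land on different pixels.
print(math.floor(1.99999988))  # 1
print(math.floor(2.0))         # 2
```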
  {
    "path": "python/oneflow/test/modules/test_global_groupnorm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=False, atol=1e-3, rtol=1e-3)\ndef _test_global_group_norm(test_case, placement, input_sbp, affine):\n    if placement.type == \"cpu\":\n        return\n    batch_size = 4\n    channel_size = 8\n    num_groups = 2\n    m = torch.nn.GroupNorm(\n        num_groups=num_groups, num_channels=channel_size, affine=affine\n    )\n    m.train(random())\n    m.to_global(\n        placement=placement, sbp=[flow.sbp.broadcast] * len(placement.ranks.shape)\n    )\n    x = random_tensor(\n        ndim=4,\n        dim0=batch_size,\n        dim1=channel_size,\n        dim2=random(4, 16),\n        dim3=random(4, 16),\n    ).to_global(placement=placement, sbp=input_sbp)\n    y = m(x)\n    return y\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGroupNormModule(flow.unittest.TestCase):\n    @globaltest\n    def test_global_group_norm_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4):\n                _test_global_group_norm(test_case, placement, sbp, True)\n                _test_global_group_norm(test_case, placement, sbp, False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_gru_cell.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n# NOTE(lixiang): Do not check the graph for the time being, because ci will report \"The action has timed out\".\n@autotest(n=1, check_graph=\"ValidatedFalse\")\ndef _test_gru_cell(test_case, placement, sbp):\n    batch_size = random(2, 3) * 8\n    time_steps = random(2, 3) * 8\n    input_size = random(2, 3) * 8\n    hidden_size = random(2, 3) * 8\n    has_bias = random().to(bool)\n    m = torch.nn.GRUCell(input_size=input_size, hidden_size=hidden_size, bias=has_bias,)\n\n    weight_sbp = random_sbp(placement, max_dim=2, except_partial_sum=True)\n    m.weight_ih = torch.nn.Parameter(\n        m.weight_ih.to_global(placement=placement, sbp=weight_sbp)\n    )\n    m.weight_hh = torch.nn.Parameter(\n        m.weight_hh.to_global(placement=placement, sbp=weight_sbp)\n    )\n    if m.bias_ih is not None:\n        # bias is 1-d tensor\n        bias_sbp = random_sbp(placement, max_dim=1, except_partial_sum=True)\n        m.bias_ih = torch.nn.Parameter(\n            m.bias_ih.to_global(placement=placement, sbp=bias_sbp)\n        )\n        m.bias_hh = torch.nn.Parameter(\n            m.bias_hh.to_global(placement=placement, sbp=bias_sbp)\n        )\n\n    input_sbp = random_sbp(placement, max_dim=3, valid_split_axis=1)\n    input = random_tensor(\n        ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size\n    ).to_global(placement=placement, sbp=input_sbp)\n    hx = random_tensor(\n        ndim=2, dim0=batch_size, dim1=hidden_size, requires_grad=False\n    ).to_global(placement=placement, sbp=sbp)\n\n    for i in range(time_steps.to(int).value()):\n        hx = m(input[i], hx)\n\n    return hx\n\n\nclass TestRNNCellGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_gru_cell(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_gru_cell(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_hann_window.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom collections import OrderedDict\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_global_hann_window(test_case, placement, sbp):\n    x = flow.hann_window(8, placement=placement, sbp=sbp)\n\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\ndef _test_graph_hann_window(test_case, placement, sbp):\n    class GlobalHannWindowGraph(flow.nn.Graph):\n        def __init__(self,):\n            super().__init__()\n\n        def build(self):\n            x = flow.hann_window(8, placement=placement, sbp=sbp)\n            return x\n\n    model = GlobalHannWindowGraph()\n    x = model()\n\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\nclass TestHannWindowGlobal(flow.unittest.TestCase):\n    # TODO(wyg): It will be infer all broadcast sbp when 1n1d,\n    #            slice_update will get error when doing inplace operator.\n    #            Remove this judgement after refactor sbp infer method in Operator class.\n    @globaltest\n    def test_hann_window_global(test_case):\n        for placement in all_placement():\n            if placement.ranks.size == 1:\n                continue\n            for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True):\n                _test_global_hann_window(test_case, placement, sbp)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n2d()\n    def test_hann_window_graph(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"placement\"] = [\n            # 1d\n            flow.placement(\"cpu\", ranks=[0, 1]),\n            flow.placement(\"cuda\", ranks=[0, 1]),\n            # 2d\n            flow.placement(\"cpu\", ranks=[[0, 1],]),\n            flow.placement(\"cuda\", ranks=[[0, 1],]),\n        ]\n        for args in GenArgDict(arg_dict):\n            placement = args[\"placement\"]\n            for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True):\n                _test_graph_hann_window(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_higher_derivative_activation.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nimport torch as pytorch_origin\nimport oneflow as oneflow_origin\nfrom collections import defaultdict\n\n\ndef _assert_true(test_case, value1, value2):\n    test_case.assertTrue(\n        np.allclose(\n            value1.detach().cpu().numpy(),\n            value2.detach().numpy(),\n            rtol=1e-05,\n            atol=1e-05,\n        )\n    )\n\n\ndef _test_activation_grad_grad_impl(test_case, op_name, placement, *args, **kwargs):\n    x = random_tensor(ndim=2, low=-5, dim0=8, dim1=8).to_global(\n        placement=placement, sbp=random_sbp(placement=placement, max_dim=2)\n    )\n    y = eval(f\"torch.nn.functional.{op_name}\")(x, *args, **kwargs)\n\n    x_shape = x.oneflow.shape\n    init_grad_x = random_tensor(len(x_shape), *x_shape).to_global(\n        placement=placement, sbp=random_sbp(placement=placement, max_dim=2)\n    )\n    init_grad_y = random_tensor(len(x_shape), *x_shape).to_global(\n        placement=placement, sbp=random_sbp(placement=placement, max_dim=2)\n    )\n\n    dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0]\n    _assert_true(test_case, dx.pytorch, dx.oneflow)\n\n    ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x)\n    ddx, ddy = ddx_ddy[0], ddx_ddy[1]\n    _assert_true(test_case, ddx.pytorch, ddx.oneflow)\n    _assert_true(test_case, ddy.pytorch, ddy.oneflow)\n\n\ndef _test_prelu_activation_grad_grad_impl(\n    test_case, op_name, placement, *args, **kwargs\n):\n    x = random_tensor(ndim=2, low=-5, dim0=8, dim1=8).to_global(\n        placement=placement, sbp=random_sbp(placement=placement, max_dim=2)\n    )\n    a = random_tensor(ndim=1, dim0=x.oneflow.shape[1]).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=1)\n    )\n    y = torch.nn.functional.prelu(x, a)\n\n    x_shape = x.oneflow.shape\n    a_shape = a.oneflow.shape\n    init_grad_x = random_tensor(len(x_shape), *x_shape).to_global(\n        placement=placement, sbp=random_sbp(placement=placement, max_dim=2)\n    )\n    init_grad_y = random_tensor(len(x_shape), *x_shape).to_global(\n        placement=placement, sbp=random_sbp(placement=placement, max_dim=2)\n    )\n    init_grad_a = random_tensor(len(a_shape), *a_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=1)\n    )\n\n    dx_and_da = torch.autograd.grad(y, [x, a], init_grad_y, True, True)\n    dx, da = dx_and_da[0], dx_and_da[1]\n    _assert_true(test_case, dx.pytorch, dx.oneflow)\n    _assert_true(test_case, da.pytorch, da.oneflow)\n\n    ddx_dda_ddy = torch.autograd.grad(\n        dx_and_da, [dx, da, init_grad_y], [init_grad_x, init_grad_a], True, True\n    )\n    ddx, dda, ddy = ddx_dda_ddy[0], ddx_dda_ddy[1], ddx_dda_ddy[2]\n    _assert_true(test_case, ddx.pytorch, ddx.oneflow)\n    
_assert_true(test_case, dda.pytorch, dda.oneflow)\n    _assert_true(test_case, ddy.pytorch, ddy.oneflow)\n\n\ndef _test_hardswish_activation_grad_grad_impl(\n    test_case, op_name, placement, *args, **kwargs\n):\n    x = random_tensor(ndim=2, low=-1, dim0=8, dim1=8).to_global(\n        placement=placement, sbp=random_sbp(placement=placement, max_dim=2)\n    )\n    y = torch.nn.functional.hardswish(x, *args, **kwargs)\n\n    x_shape = x.oneflow.shape\n    init_grad_x = random_tensor(len(x_shape), *x_shape).to_global(\n        placement=placement, sbp=random_sbp(placement=placement, max_dim=2)\n    )\n    init_grad_y = random_tensor(len(x_shape), *x_shape).to_global(\n        placement=placement, sbp=random_sbp(placement=placement, max_dim=2)\n    )\n\n    dx_pytorch = pytorch_origin.autograd.grad(\n        y.pytorch, x.pytorch, init_grad_y.pytorch\n    )[0]\n    dx_oneflow = oneflow_origin.autograd.grad(\n        y.oneflow, x.oneflow, init_grad_y.oneflow, True, True\n    )[0]\n    _assert_true(test_case, dx_pytorch, dx_oneflow)\n\n    ddx, ddy = flow.autograd.grad(\n        dx_oneflow, [x.oneflow, init_grad_y.oneflow], init_grad_x.oneflow\n    )\n    x, dx, init_grad_x, init_grad_y = (\n        x.oneflow,\n        dx_oneflow,\n        init_grad_x.oneflow,\n        init_grad_y.oneflow,\n    )\n\n    zeros_grad = flow.zeros_like(x).to_global(placement=placement, sbp=x.sbp)\n    manual_ddx = flow.where(\n        flow.logical_and(x > -3.0, x < 3.0), 1.0 / 3.0 * init_grad_x * init_grad_y, zeros_grad\n    )\n    manual_ddy = dx / init_grad_y * init_grad_x\n    _assert_true(test_case, manual_ddx, ddx)\n    _assert_true(test_case, manual_ddy, ddy)\n\n\ndef _test_hardsigmoid_activation_grad_grad_impl(\n    test_case, op_name, placement, *args, **kwargs\n):\n    x = random_tensor(ndim=2, low=-1, dim0=8, dim1=8).to_global(\n        placement=placement, sbp=random_sbp(placement=placement, max_dim=2)\n    )\n    y = torch.nn.functional.hardsigmoid(x, *args, **kwargs)\n\n    x_shape = x.oneflow.shape\n    init_grad_x = random_tensor(len(x_shape), *x_shape).to_global(\n        placement=placement, sbp=random_sbp(placement=placement, max_dim=2)\n    )\n    init_grad_y = random_tensor(len(x_shape), *x_shape).to_global(\n        placement=placement, sbp=random_sbp(placement=placement, max_dim=2)\n    )\n\n    dx_pytorch = pytorch_origin.autograd.grad(\n        y.pytorch, x.pytorch, init_grad_y.pytorch\n    )[0]\n    dx_oneflow = oneflow_origin.autograd.grad(\n        y.oneflow, x.oneflow, init_grad_y.oneflow, True, True\n    )[0]\n    _assert_true(test_case, dx_pytorch, dx_oneflow)\n\n    ddx, ddy = flow.autograd.grad(\n        dx_oneflow, [x.oneflow, init_grad_y.oneflow], init_grad_x.oneflow\n    )\n    x, dx, init_grad_x, init_grad_y = (\n        x.oneflow,\n        dx_oneflow,\n        init_grad_x.oneflow,\n        init_grad_y.oneflow,\n    )\n    manual_ddx = flow.zeros_like(x)\n    manual_ddy = dx / init_grad_y * init_grad_x\n    _assert_true(test_case, manual_ddx, ddx)\n    _assert_true(test_case, manual_ddy, ddy)\n\n\nclass TestActivationHigherDerivative(flow.unittest.TestCase):\n    @globaltest\n    def test_activation_grad_grad(test_case):\n        op_args = defaultdict(list)\n        op_kwargs = defaultdict(dict)\n\n        # the parameter name is not the same in pytorch and oneflow\n        op_args[\"leaky_relu\"] = [random(-1, 1).to(float)]\n\n        # some ops only support kwargs, e.g. celu in oneflow\n        op_kwargs[\"hardtanh\"] = {\n            \"min_val\": random(-5, -1).to(float),\n            \"max_val\": 
random(1, 5).to(float),\n        }\n        op_kwargs[\"elu\"] = {\"alpha\": random(0, 10).to(float)}\n        op_kwargs[\"celu\"] = {\"alpha\": random(0, 10).to(float)}\n        op_kwargs[\"threshold\"] = {\n            \"threshold\": random().to(float),\n            \"value\": random().to(float),\n        }\n        op_kwargs[\"softplus\"] = {\n            \"beta\": random().to(float),\n            \"threshold\": random().to(float),\n        }\n\n        op_names = [\n            \"mish\",\n            \"gelu\",\n            \"silu\",\n            \"selu\",\n            \"softsign\",\n            \"hardsigmoid\",\n            \"hardswish\",\n            \"relu\",\n            \"elu\",\n            \"celu\",\n            \"prelu\",\n            \"hardshrink\",\n            \"softshrink\",\n            \"leaky_relu\",\n            \"hardtanh\",\n            \"softplus\",\n            \"threshold\",\n        ]\n        for op_name in op_names:\n            try:\n                functor = eval(f\"_test_{op_name}_activation_grad_grad_impl\")\n            except:\n                functor = _test_activation_grad_grad_impl\n\n            print(f\"| {op_name:-^60} |\")\n            for placement in all_placement():\n                for i in range(1):\n                    functor(\n                        test_case,\n                        op_name,\n                        placement,\n                        *op_args[op_name],\n                        **op_kwargs[op_name],\n                    )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
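For reference, the closed forms checked in `_test_hardswish_activation_grad_grad_impl` follow from the piecewise definition of hardswish; writing g_x and g_y for `init_grad_x` and `init_grad_y` (my notation, not the file's):

```latex
h(x) = x\cdot\frac{\mathrm{relu6}(x+3)}{6}
     = \begin{cases} 0 & x \le -3,\\ \tfrac{x(x+3)}{6} & -3 < x < 3,\\ x & x \ge 3, \end{cases}
\qquad
h''(x) = \begin{cases} \tfrac{1}{3} & -3 < x < 3,\\ 0 & \text{otherwise.} \end{cases}
```

The first backward returns dx = h'(x) g_y, so differentiating the contraction ⟨dx, g_x⟩ gives ddx = h''(x) g_y g_x (the `flow.where` branch, 1/3 · g_x · g_y inside the interval and zero outside) and ddy = h'(x) g_x = (dx / g_y) g_x, matching `manual_ddx` and `manual_ddy`.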
  {
    "path": "python/oneflow/test/modules/test_global_higher_derivative_conv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nimport torch as pytorch_origin\nimport oneflow as oneflow_origin\n\n\ndef _test_convnd_grad_grad_impl(test_case, ndim, placement):\n    x_shape = [8, 8] + [5 for _ in range(ndim)]\n    w_shape = [8, 8] + [3 for _ in range(ndim)]\n    y_shape = [8, 8] + [3 for _ in range(ndim)]\n\n    x = random_tensor(len(x_shape), *x_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    w = random_tensor(len(w_shape), *w_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    init_grad_x = random_tensor(len(x_shape), *x_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    init_grad_w = random_tensor(len(w_shape), *w_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    init_grad_y = random_tensor(len(y_shape), *y_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n\n    y = eval(f\"torch.nn.functional.conv{ndim}d\")(\n        x, w, stride=1, padding=0, groups=1, dilation=1\n    )\n\n    dx = torch.autograd.grad(\n        outputs=y,\n        inputs=x,\n        grad_outputs=init_grad_y,\n        create_graph=True,\n        retain_graph=True,\n    )[0]\n\n    test_case.assertTrue(\n        np.allclose(\n            dx.pytorch.detach().cpu().numpy(),\n            dx.oneflow.detach().numpy(),\n            rtol=1e-5,\n            atol=1e-2,\n        )\n    )\n\n    dw = torch.autograd.grad(\n        outputs=y,\n        inputs=w,\n        grad_outputs=init_grad_y,\n        create_graph=True,\n        retain_graph=True,\n    )[0]\n    test_case.assertTrue(\n        np.allclose(\n            dw.pytorch.detach().cpu().numpy(),\n            dw.oneflow.detach().numpy(),\n            rtol=1e-5,\n            atol=1e-5,\n        )\n    )\n\n    # torch.autograd.grad in autotest does not support inputs/outpus/grad_outputs as a list\n    # so use the original pytorch/oneflow module\n    ddx_pytorch, ddw_pytorch = pytorch_origin.autograd.grad(\n        outputs=[dx.pytorch, dw.pytorch],\n        inputs=[x.pytorch, w.pytorch],\n        grad_outputs=[init_grad_x.pytorch, init_grad_w.pytorch],\n        create_graph=True,\n        retain_graph=True,\n    )\n    ddx_oneflow, ddw_oneflow = oneflow_origin.autograd.grad(\n        outputs=[dx.oneflow, dw.oneflow],\n        inputs=[x.oneflow, w.oneflow],\n        grad_outputs=[init_grad_x.oneflow, init_grad_w.oneflow],\n        create_graph=True,\n        retain_graph=True,\n    )\n\n    test_case.assertTrue(\n        np.allclose(\n            ddw_pytorch.detach().cpu().numpy(),\n            ddw_oneflow.detach().numpy(),\n            rtol=1e-5,\n            atol=1e-5,\n        )\n    )\n    test_case.assertTrue(\n   
     np.allclose(\n            ddx_pytorch.detach().cpu().numpy(),\n            ddx_oneflow.detach().numpy(),\n            rtol=1e-5,\n            atol=1e-2,\n        )\n    )\n\n    dgrad_dx = torch.autograd.grad(\n        outputs=dx,\n        inputs=init_grad_y,\n        grad_outputs=init_grad_x,\n        create_graph=True,\n        retain_graph=True,\n    )[0]\n    test_case.assertTrue(\n        np.allclose(\n            dgrad_dx.pytorch.detach().cpu().numpy(),\n            dgrad_dx.oneflow.detach().numpy(),\n            rtol=1e-4,\n            atol=1e-2,\n        )\n    )\n\n    dgrad_dw = torch.autograd.grad(\n        outputs=dw,\n        inputs=init_grad_y,\n        grad_outputs=init_grad_w,\n        create_graph=True,\n        retain_graph=True,\n    )[0]\n    test_case.assertTrue(\n        np.allclose(\n            dgrad_dw.pytorch.detach().cpu().numpy(),\n            dgrad_dw.oneflow.detach().numpy(),\n            rtol=1e-4,\n            atol=1e-2,\n        )\n    )\n\n\nclass TestGlobalConvHigherDerivative(flow.unittest.TestCase):\n    @globaltest\n    def test_conv1d_grad_grad(test_case):\n        for placement in all_placement():\n            for i in range(5):\n                _test_convnd_grad_grad_impl(test_case, ndim=1, placement=placement)\n\n    @globaltest\n    def test_conv2d_grad_grad(test_case):\n        for placement in all_placement():\n            for i in range(5):\n                _test_convnd_grad_grad_impl(test_case, ndim=2, placement=placement)\n\n    @globaltest\n    def test_conv3d_grad_grad(test_case):\n        for placement in all_placement():\n            for i in range(5):\n                _test_convnd_grad_grad_impl(test_case, ndim=3, placement=placement)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
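A note on the double-backward pattern used by the entry above may help readers: first-order grads are taken with create_graph=True so they remain differentiable, then those grads are differentiated again. A minimal plain-PyTorch sketch of that pattern (shapes and variable names here are illustrative, not taken from the test):

import torch

x = torch.randn(1, 2, 5, requires_grad=True)  # (N, C_in, L)
w = torch.randn(3, 2, 3, requires_grad=True)  # (C_out, C_in, K)
y = torch.nn.functional.conv1d(x, w)

gy = torch.randn_like(y)  # plays the role of init_grad_y
# create_graph=True keeps dx and dw differentiable for the second pass
dx, dw = torch.autograd.grad(y, (x, w), gy, create_graph=True)

gx, gw = torch.randn_like(dx), torch.randn_like(dw)
# second-order grads: dx depends on w and dw depends on x, so both inputs are reachable
ddx, ddw = torch.autograd.grad((dx, dw), (x, w), (gx, gw))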
  {
    "path": "python/oneflow/test/modules/test_global_higher_derivative_div.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_global_div_grad_grad_impl(test_case, placement):\n    x_shape = [8, 8, 8, 8]\n    y_shape = [8, 8]\n    if random_bool().value():\n        x_shape, y_shape = y_shape, x_shape\n    x = random_tensor(len(x_shape), *x_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    y = random_tensor(len(y_shape), *y_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    z = torch.div(x, y)\n    init_grad_z = random_tensor(len(z.oneflow.shape), *z.oneflow.shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n\n    dx_and_dy = torch.autograd.grad(z, [x, y], init_grad_z, True, True)\n    test_case.assertTrue(\n        np.allclose(\n            dx_and_dy.pytorch[0].detach().cpu().numpy(),\n            dx_and_dy.oneflow[0].detach().numpy(),\n            rtol=1e-4,\n            atol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            dx_and_dy.pytorch[1].detach().cpu().numpy(),\n            dx_and_dy.oneflow[1].detach().numpy(),\n            rtol=1e-3,\n            atol=1e-4,\n        )\n    )\n\n    ddx_and_ddy_and_ddz = torch.autograd.grad(\n        dx_and_dy, [x, y, init_grad_z], [init_grad_x, init_grad_y], True, True\n    )\n    test_case.assertTrue(\n        np.allclose(\n            ddx_and_ddy_and_ddz.pytorch[0].detach().cpu().numpy(),\n            ddx_and_ddy_and_ddz.oneflow[0].detach().numpy(),\n            rtol=1e-3,\n            atol=1e-3,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            ddx_and_ddy_and_ddz.pytorch[1].detach().cpu().numpy(),\n            ddx_and_ddy_and_ddz.oneflow[1].detach().numpy(),\n            rtol=1e-2,\n            atol=1e-3,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            ddx_and_ddy_and_ddz.pytorch[2].detach().cpu().numpy(),\n            ddx_and_ddy_and_ddz.oneflow[2].detach().numpy(),\n            rtol=1e-3,\n            atol=1e-3,\n        )\n    )\n\n\nclass TestGlobalDivHigherDerivative(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 22 times in past week\")\n    @globaltest\n    def test_global_div_grad_grad(test_case):\n        for placement in all_placement():\n            for i in range(1):\n                _test_global_div_grad_grad_impl(test_case, placement)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
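For reference, the second derivatives exercised by the entry above have simple closed forms: for z = x / y, dz/dy = -x / y**2 and d2z/dy2 = 2 * x / y**3. A small self-contained plain-PyTorch sketch with illustrative values:

import torch

x = torch.tensor([1.5, -2.0])
y = torch.tensor([2.0, 4.0], requires_grad=True)
z = (x / y).sum()

(gy,) = torch.autograd.grad(z, y, create_graph=True)  # -x / y**2
(hy,) = torch.autograd.grad(gy.sum(), y)              # 2 * x / y**3
print(torch.allclose(gy, -x / y.detach() ** 2))       # True
print(torch.allclose(hy, 2 * x / y.detach() ** 3))    # True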
  {
    "path": "python/oneflow/test/modules/test_global_higher_derivative_loss.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _assert_true(test_case, value1, value2, name=\"\"):\n    is_equal = np.allclose(\n        value1.detach().cpu().numpy(), value2.detach().numpy(), rtol=1e-03, atol=1e-03,\n    )\n    test_case.assertTrue(is_equal, f\"{name} is not equal.\" if name else \"\")\n\n\ndef generate_grads_for_variables(variables):\n    if isinstance(variables, list):\n        shape_and_sbp = [(i.oneflow.shape, i.oneflow.sbp) for i in variables]\n        placement = variables[0].oneflow.placement\n    elif hasattr(variables, \"pytorch\"):\n        shape_and_sbp = [(i.shape, i.sbp) for i in variables.oneflow]\n        placement = variables.oneflow[0].placement\n    else:\n        assert False\n    grads = [\n        random_tensor(\n            len(shape), *shape, requires_grad=random_bool().value()\n        ).to_global(placement=placement, sbp=sbp)\n        for shape, sbp in shape_and_sbp\n    ]\n    return grads\n\n\ndef calculate_and_compare_loss(test_case, input, target, model, order=2):\n    output = model(input, target)\n    _assert_true(test_case, output.pytorch, output.oneflow, \"output\")\n\n    init_inputs = [input, target]\n    grad_inputs = [output]\n    grad_outputs = []\n    for i in range(order):\n        inputs = [\n            var for var in [*init_inputs, *grad_outputs] if var.pytorch.requires_grad\n        ]\n        outputs = grad_inputs\n        grad_outputs = generate_grads_for_variables(outputs)\n        if i == order - 1:\n            grad_inputs = torch.autograd.grad(outputs, inputs, grad_outputs)\n        else:\n            grad_inputs = torch.autograd.grad(outputs, inputs, grad_outputs, True, True)\n        for j in range(len(inputs)):\n            _assert_true(\n                test_case,\n                grad_inputs[j].pytorch,\n                grad_inputs[j].oneflow,\n                f\"{i}-grad_inputs[{j}]\",\n            )\n\n\ndef generate_necessity_for_default_loss(placement):\n    shape = [8, 8]\n    ndim = len(shape)\n    input_requires_grad = True\n    target_requires_grad = random_bool().value()\n    return (\n        random_tensor(ndim, *shape, low=0, requires_grad=input_requires_grad).to_global(\n            placement=placement, sbp=random_sbp(placement, max_dim=2)\n        ),\n        random_tensor(\n            ndim, *shape, low=0, requires_grad=target_requires_grad\n        ).to_global(placement=placement, sbp=random_sbp(placement, max_dim=2)),\n    )\n\n\ndef generate_necessity_for_nll_loss(placement):\n    ndim = 2\n    num_classes = 8\n    batch_size = 8\n    ignore_index = oneof(random(0, num_classes).to(int).value(), -100).value()\n    extra_dim = [random().to(int) for _ in range(ndim - 2)]\n    return (\n        random_tensor(ndim, batch_size, num_classes).to_global(\n            placement=placement, 
sbp=random_sbp(placement, max_dim=2)\n        ),\n        random_tensor(\n            ndim - 1,\n            batch_size,\n            low=0,\n            high=num_classes,\n            dtype=int,\n            requires_grad=False,\n        ).to_global(placement=placement, sbp=random_sbp(placement, max_dim=1)),\n        random_tensor(1, num_classes, low=0, high=3, requires_grad=False).to_global(\n            placement=placement, sbp=random_sbp(placement, except_split=True)\n        ),\n        ignore_index,\n    )\n\n\ndef generate_necessity_for_bce_loss(placement):\n    ndim = 3\n    num_classes = 2\n    batch_size = 8\n    extra_dim = [random().to(int) for _ in range(ndim - 2)]\n    input_requires_grad = True\n    target_requires_grad = False\n    return (\n        random_tensor(\n            ndim,\n            batch_size,\n            num_classes,\n            low=0,\n            high=1,\n            *extra_dim,\n            requires_grad=input_requires_grad,\n        ).to_global(placement=placement, sbp=random_sbp(placement, max_dim=1)),\n        random_tensor(\n            ndim,\n            batch_size,\n            num_classes,\n            *extra_dim,\n            low=0,\n            high=num_classes,\n            requires_grad=target_requires_grad,\n        ).to_global(placement=placement, sbp=random_sbp(placement, max_dim=1)),\n        random_tensor(\n            ndim,\n            batch_size,\n            num_classes,\n            *extra_dim,\n            low=0,\n            high=3,\n            requires_grad=False,\n        ).to_global(placement=placement, sbp=random_sbp(placement, max_dim=1)),\n        random_tensor(1, 1, low=1, high=3, requires_grad=False,).to_global(\n            placement=placement, sbp=random_sbp(placement, except_split=True)\n        ),\n    )\n\n\ndef _test_smooth_l1_loss_grad_grad_impl(test_case, placement):\n    x, y = generate_necessity_for_default_loss(placement)\n\n    m = torch.nn.SmoothL1Loss(\n        reduction=oneof(\"none\", \"sum\", \"mean\", nothing()), beta=oneof(0.0, 0.5, 1)\n    )\n\n    calculate_and_compare_loss(test_case, x, y, m)\n\n\ndef _test_kl_div_loss_grad_grad_impl(test_case, placement):\n    x, y = generate_necessity_for_default_loss(placement)\n\n    m = torch.nn.KLDivLoss(\n        reduction=oneof(\"none\", \"sum\", \"mean\", nothing()),\n        log_target=oneof(True, False),\n    )\n\n    calculate_and_compare_loss(test_case, x, y, m)\n\n\ndef _test_bce_loss_grad_grad_impl(test_case, placement, with_logits=False):\n    x, y, weight, pos_weight = generate_necessity_for_bce_loss(placement)\n\n    if with_logits:\n        weight = weight if random_bool().value() else None\n        has_pos_weight = random_bool().value()\n        pos_weight = pos_weight if has_pos_weight else nothing()\n        m = torch.nn.BCEWithLogitsLoss(\n            weight=weight,\n            pos_weight=pos_weight,\n            reduction=oneof(\"none\", \"sum\", \"mean\"),\n        )\n        if has_pos_weight:\n            y = y.detach().clone().requires_grad_(False)\n    else:\n        m = torch.nn.BCELoss(\n            weight=(weight if random_bool().value() else None),\n            reduction=oneof(\"none\", \"sum\", \"mean\"),\n        )\n\n    calculate_and_compare_loss(test_case, x, y, m)\n\n\ndef _test_nll_loss_grad_grad_impl(test_case, placement):\n    (x, y, weight, ignore_index) = generate_necessity_for_nll_loss(placement)\n\n    m = torch.nn.NLLLoss(\n        weight=(weight if random_bool().value() else None),\n        reduction=oneof(\"none\", 
\"sum\", \"mean\", nothing()),\n        ignore_index=ignore_index,\n    )\n\n    calculate_and_compare_loss(test_case, x, y, m)\n\n\nclass TestGlobalLossHigherDerivative(flow.unittest.TestCase):\n    @globaltest\n    def test_smooth_l1_loss_grad_grad(test_case):\n        for placement in all_placement():\n            _test_smooth_l1_loss_grad_grad_impl(test_case, placement)\n\n    @globaltest\n    def test_kl_div_loss_grad_grad(test_case):\n        for placement in all_placement():\n            _test_kl_div_loss_grad_grad_impl(test_case, placement)\n\n    @globaltest\n    def test_nll_loss_grad_grad(test_case):\n        for placement in all_placement():\n            _test_nll_loss_grad_grad_impl(test_case, placement)\n\n    @globaltest\n    def test_bce_loss_grad_grad(test_case):\n        for placement in all_placement():\n            _test_bce_loss_grad_grad_impl(test_case, placement)\n\n    @globaltest\n    def test_bce_with_logits_loss_grad_grad(test_case):\n        for placement in all_placement():\n            _test_bce_loss_grad_grad_impl(test_case, placement, with_logits=True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
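As a sanity anchor for the loss entry above: inside the |x - t| < beta region, SmoothL1Loss equals 0.5 * (x - t)**2 / beta, so its elementwise second derivative is the constant 1 / beta. A minimal plain-PyTorch sketch (values chosen so every element stays in the quadratic region; not part of the test suite):

import torch

beta = 0.5
x = torch.zeros(4, requires_grad=True)
t = torch.full((4,), 0.1)  # |x - t| < beta everywhere
loss = torch.nn.SmoothL1Loss(reduction="sum", beta=beta)(x, t)

(g,) = torch.autograd.grad(loss, x, create_graph=True)  # (x - t) / beta
(h,) = torch.autograd.grad(g.sum(), x)                  # 1 / beta per element
print(torch.allclose(h, torch.full_like(h, 1.0 / beta)))  # True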
  {
    "path": "python/oneflow/test/modules/test_global_higher_derivative_matmul.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nimport torch as pytorch_origin\nimport oneflow as oneflow_origin\n\n\ndef _test_broadcast_matmul_grad_b_grad_impl(test_case, placement):\n    broadcast_dims = [np.random.randint(1, 5) * 8 for _ in range(2)]\n    m = np.random.randint(1, 5) * 8\n    n = np.random.randint(1, 5) * 8\n    k = np.random.randint(1, 5) * 8\n\n    a_shape = broadcast_dims + [m, k]\n    b_shape = [k, n]\n    y_shape = broadcast_dims + [m, n]\n\n    a = random_tensor(len(a_shape), *a_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    b = random_tensor(len(b_shape), *b_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    init_grad_a = random_tensor(len(a_shape), *a_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    init_grad_b = random_tensor(len(b_shape), *b_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    init_grad_y = random_tensor(len(y_shape), *y_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n\n    y = torch.matmul(a, b)\n\n    da = torch.autograd.grad(\n        outputs=y,\n        inputs=a,\n        grad_outputs=init_grad_y,\n        create_graph=True,\n        retain_graph=True,\n    )[0]\n\n    test_case.assertTrue(\n        np.allclose(\n            da.pytorch.detach().cpu().numpy(),\n            da.oneflow.detach().numpy(),\n            rtol=1e-5,\n            atol=1e-2,\n        )\n    )\n\n    db = torch.autograd.grad(\n        outputs=y,\n        inputs=b,\n        grad_outputs=init_grad_y,\n        create_graph=True,\n        retain_graph=True,\n    )[0]\n    test_case.assertTrue(\n        np.allclose(\n            db.pytorch.detach().cpu().numpy(),\n            db.oneflow.detach().numpy(),\n            rtol=1e-3,\n            atol=1e-4,\n        )\n    )\n\n    # torch.autograd.grad in autotest does not support inputs/outpus/grad_outputs as a list\n    # so use the original pytorch/oneflow module\n    dda_pytorch, ddb_pytorch = pytorch_origin.autograd.grad(\n        outputs=[da.pytorch, db.pytorch],\n        inputs=[a.pytorch, b.pytorch],\n        grad_outputs=[init_grad_a.pytorch, init_grad_b.pytorch],\n        create_graph=True,\n        retain_graph=True,\n    )\n    dda_oneflow, ddb_oneflow = oneflow_origin.autograd.grad(\n        outputs=[da.oneflow, db.oneflow],\n        inputs=[a.oneflow, b.oneflow],\n        grad_outputs=[init_grad_a.oneflow, init_grad_b.oneflow],\n        create_graph=True,\n        retain_graph=True,\n    )\n\n    test_case.assertTrue(\n        np.allclose(\n            ddb_pytorch.detach().cpu().numpy(),\n            ddb_oneflow.detach().numpy(),\n            rtol=1e-3,\n            
atol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            dda_pytorch.detach().cpu().numpy(),\n            dda_oneflow.detach().numpy(),\n            rtol=1e-5,\n            atol=1e-2,\n        )\n    )\n\n    dgrad_da = torch.autograd.grad(\n        outputs=da,\n        inputs=init_grad_y,\n        grad_outputs=init_grad_a,\n        create_graph=True,\n        retain_graph=True,\n    )[0]\n    test_case.assertTrue(\n        np.allclose(\n            dgrad_da.pytorch.detach().cpu().numpy(),\n            dgrad_da.oneflow.detach().numpy(),\n            rtol=1e-5,\n            atol=1e-2,\n        )\n    )\n\n    dgrad_db = torch.autograd.grad(\n        outputs=db,\n        inputs=init_grad_y,\n        grad_outputs=init_grad_b,\n        create_graph=True,\n        retain_graph=True,\n    )[0]\n    test_case.assertTrue(\n        np.allclose(\n            dgrad_db.pytorch.detach().cpu().numpy(),\n            dgrad_db.oneflow.detach().numpy(),\n            rtol=1e-5,\n            atol=1e-2,\n        )\n    )\n\n\nclass TestGlobalMatmulHigherDerivative(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 32 times in past week\")\n    @globaltest\n    def test_broadcast_matmul_grad_b_grad(test_case):\n        for placement in all_placement():\n            for i in range(5):\n                _test_broadcast_matmul_grad_b_grad_impl(test_case, placement=placement)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_higher_derivative_neg.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _global_neg_grad_grad_impl(test_case, placement, sbp):\n    x = flow.randn(8, 8).to_global(placement=placement, sbp=sbp).requires_grad_(True)\n    init_grad = (\n        flow.randn(8, 8).to_global(placement=placement, sbp=sbp).requires_grad_(True)\n    )\n    init_grad_grad = (\n        flow.randn(8, 8).to_global(placement=placement, sbp=sbp).requires_grad_(True)\n    )\n\n    y = x.neg()\n    x_grad = flow.autograd.grad(y, x, init_grad, create_graph=True)[0]\n    test_case.assertTrue(np.allclose(-init_grad, x_grad.detach().numpy()))\n\n    dgrad = flow.autograd.grad(x_grad, init_grad, init_grad_grad, create_graph=True)[0]\n    test_case.assertTrue(np.allclose(-init_grad_grad, dgrad.detach().numpy(),))\n\n\nclass TestGlobalNegHigherDerivative(flow.unittest.TestCase):\n    @globaltest\n    def test_global_neg_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_neg_grad_grad_impl(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_higher_derivative_pool.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _check_equal(test_case, lhs, rhs, name=\"\", rtol=1e-5, atol=1e-5):\n    is_equal = np.allclose(\n        lhs.detach().cpu().numpy(),\n        rhs.detach().cpu().numpy(),\n        rtol=rtol,\n        atol=atol,\n        equal_nan=True,\n    )\n    test_case.assertTrue(is_equal, f\"{name} is not equal\" if name else \"\")\n\n\ndef _test_avg_pool_grad_grad_impl(test_case, placement, ndim):\n    x_shape = [8, 8] + [5] * ndim\n\n    m = eval(f\"torch.nn.AvgPool{ndim}d\")(kernel_size=random(2, 5).to(int))\n\n    x = random_tensor(len(x_shape), *x_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    y = m(x)\n    _check_equal(test_case, y.pytorch, y.oneflow, \"y\")\n\n    init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n\n    dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0]\n    _check_equal(test_case, dx.pytorch, dx.oneflow, \"dx\")\n\n    ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x, True, True)\n    ddx, ddy = ddx_ddy[0], ddx_ddy[1]\n    _check_equal(test_case, ddx.pytorch, ddx.oneflow, \"ddx\")\n    _check_equal(test_case, ddy.pytorch, ddy.oneflow, \"ddy\")\n\n\ndef _test_max_pool_grad_grad_impl(test_case, placement, ndim):\n    x_shape = [8, 8] + [5] * ndim\n\n    m = eval(f\"torch.nn.MaxPool{ndim}d\")(kernel_size=random(2, 5).to(int))\n\n    x = random_tensor(len(x_shape), *x_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n\n    y = m(x)\n    _check_equal(test_case, y.pytorch, y.oneflow, \"y\")\n\n    init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n\n    dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0]\n    _check_equal(test_case, dx.pytorch, dx.oneflow, \"dx\")\n\n    ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x, True, True)\n    ddx, ddy = ddx_ddy[0], ddx_ddy[1]\n    _check_equal(test_case, ddx.pytorch, ddx.oneflow, \"ddx\")\n    _check_equal(test_case, ddy.pytorch, ddy.oneflow, \"ddy\")\n\n\ndef _test_adaptive_pool_grad_grad_impl(test_case, placement, ndim, mode):\n    x_shape = [8, 8] + [5] * ndim\n\n    m = eval(f\"torch.nn.Adaptive{mode.title()}Pool{ndim}d\")(\n        output_size=random(2, 5).to(int)\n    )\n\n    x = random_tensor(len(x_shape), 
*x_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    y = m(x)\n    _check_equal(test_case, y.pytorch, y.oneflow, \"y\")\n\n    init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n\n    dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0]\n    _check_equal(test_case, dx.pytorch, dx.oneflow, \"dx\")\n\n    ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x, True, True)\n    ddx, ddy = ddx_ddy[0], ddx_ddy[1]\n\n    _check_equal(test_case, ddx.pytorch, ddx.oneflow, \"ddx\")\n    _check_equal(test_case, ddy.pytorch, ddy.oneflow, \"ddy\")\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGlobalPoolHigherDerivative(flow.unittest.TestCase):\n    @globaltest\n    def test_max_pool_1d_grad_grad(test_case):\n        for placement in all_placement():\n            _test_max_pool_grad_grad_impl(test_case, placement, 1)\n\n    @globaltest\n    def test_max_pool_2d_grad_grad(test_case):\n        for placement in all_placement():\n            _test_max_pool_grad_grad_impl(test_case, placement, 2)\n\n    @globaltest\n    def test_max_pool_3d_grad_grad(test_case):\n        for placement in all_placement():\n            _test_max_pool_grad_grad_impl(test_case, placement, 3)\n\n    @globaltest\n    def test_avg_pool_1d_grad_grad(test_case):\n        for placement in all_placement():\n            _test_avg_pool_grad_grad_impl(test_case, placement, ndim=1)\n\n    @globaltest\n    def test_avg_pool_2d_grad_grad(test_case):\n        for placement in all_placement():\n            _test_avg_pool_grad_grad_impl(test_case, placement, ndim=2)\n\n    @globaltest\n    def test_avg_pool_3d_grad_grad(test_case):\n        for placement in all_placement():\n            _test_avg_pool_grad_grad_impl(test_case, placement, ndim=3)\n\n    @globaltest\n    def test_adaptive_avg_pool_1d_grad_grad(test_case):\n        for placement in all_placement():\n            _test_adaptive_pool_grad_grad_impl(test_case, placement, ndim=1, mode=\"avg\")\n\n    @globaltest\n    def test_adaptive_avg_pool_2d_grad_grad(test_case):\n        for placement in all_placement():\n            _test_adaptive_pool_grad_grad_impl(test_case, placement, ndim=2, mode=\"avg\")\n\n    @globaltest\n    def test_adaptive_avg_pool_3d_grad_grad(test_case):\n        for placement in all_placement():\n            _test_adaptive_pool_grad_grad_impl(test_case, placement, ndim=3, mode=\"avg\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_higher_derivative_pow.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _check_equal(test_case, lhs, rhs, rtol=1e-3, atol=1e-3):\n    is_equal = np.allclose(\n        lhs.detach().cpu().numpy(),\n        rhs.detach().cpu().numpy(),\n        rtol=rtol,\n        atol=atol,\n        equal_nan=True,\n    )\n    test_case.assertTrue(is_equal)\n\n\ndef _test_global_pow_grad_grad_impl(test_case, placement):\n    x_shape, y_shape = [([8, 8], [8, 8]), ([8, 8, 8], [8, 8]), ([8, 8], [8, 8, 8]),][\n        random(1, 3).to(int).value()\n    ]\n\n    x = random_tensor(len(x_shape), *x_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    y = random_tensor(len(y_shape), *y_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n\n    z = torch.pow(x, y)\n    _check_equal(test_case, z.pytorch, z.oneflow)\n\n    init_grad_z = random_tensor(len(z.oneflow.shape), *z.oneflow.shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n\n    dx_and_dy = torch.autograd.grad(z, [x, y], init_grad_z, True, True)\n    _check_equal(test_case, dx_and_dy.pytorch[0], dx_and_dy.oneflow[0])\n    _check_equal(test_case, dx_and_dy.pytorch[1], dx_and_dy.oneflow[1])\n\n    ddx_ddy_ddz = torch.autograd.grad(\n        dx_and_dy, [x, y, init_grad_z], [init_grad_x, init_grad_y]\n    )\n    _check_equal(test_case, ddx_ddy_ddz.pytorch[0], ddx_ddy_ddz.oneflow[0])\n    _check_equal(test_case, ddx_ddy_ddz.pytorch[1], ddx_ddy_ddz.oneflow[1])\n    _check_equal(test_case, ddx_ddy_ddz.pytorch[2], ddx_ddy_ddz.oneflow[2])\n\n\nclass TestGlobalPowHigherDerivative(flow.unittest.TestCase):\n    @globaltest\n    def test_global_pow_grad_grad(test_case):\n        for placement in all_placement():\n            for i in range(5):\n                _test_global_pow_grad_grad_impl(test_case, placement)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_higher_derivative_scalar_pow.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _check_equal(test_case, lhs, rhs, rtol=1e-4, atol=1e-4, name=\"\"):\n    is_equal = np.allclose(\n        lhs.detach().cpu().numpy(),\n        rhs.detach().cpu().numpy(),\n        rtol=rtol,\n        atol=atol,\n        equal_nan=True,\n    )\n    test_case.assertTrue(is_equal, f\"{name} is not equal\")\n\n\ndef _test_scalar_pow_grad_grad_impl(test_case, placement, reverse=False):\n    x_shape = [8, 8]\n    y = random().to(float if random_bool().value() else int).value()\n\n    x = random_tensor(len(x_shape), *x_shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    z = torch.pow(x, y) if not reverse else torch.pow(y, x)\n\n    init_grad_z = random_tensor(len(z.oneflow.shape), *z.oneflow.shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n\n    dx = torch.autograd.grad(z, x, init_grad_z, True, True)[0]\n    _check_equal(test_case, dx.pytorch, dx.oneflow, name=\"dx\")\n\n    ddx_and_ddz = torch.autograd.grad(dx, [x, init_grad_z], init_grad_x, True, True)\n    _check_equal(test_case, ddx_and_ddz.pytorch[0], ddx_and_ddz.oneflow[0], name=\"ddx\")\n    _check_equal(test_case, ddx_and_ddz.pytorch[1], ddx_and_ddz.oneflow[1], name=\"ddz\")\n\n\nclass TestGlobalScalarPowHigherDerivative(flow.unittest.TestCase):\n    @globaltest\n    def test_global_scalar_pow_grad_grad(test_case):\n        for placement in all_placement():\n            for i in range(10):\n                _test_scalar_pow_grad_grad_impl(test_case, placement)\n\n    @globaltest\n    def test_global_scalar_reverse_pow_grad_grad(test_case):\n        for placement in all_placement():\n            for i in range(10):\n                _test_scalar_pow_grad_grad_impl(test_case, placement, reverse=True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
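The scalar-pow case likewise has an easy closed form to check against: for z = x**p with a scalar exponent p, d2z/dx2 = p * (p - 1) * x**(p - 2). A minimal plain-PyTorch sketch with illustrative values (not part of the test suite):

import torch

p = 3.0
x = torch.tensor([2.0, 3.0], requires_grad=True)
z = x.pow(p).sum()

(g,) = torch.autograd.grad(z, x, create_graph=True)  # p * x**(p - 1)
(h,) = torch.autograd.grad(g.sum(), x)               # p * (p - 1) * x**(p - 2)
print(torch.allclose(h, p * (p - 1) * x.detach().pow(p - 2)))  # True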
  {
    "path": "python/oneflow/test/modules/test_global_higher_derivative_slice.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _global_slice_grad_grad_impl(test_case, placement, sbp):\n    x = (\n        random_tensor(ndim=3, dim0=8, dim1=8, dim2=8)\n        .to_global(placement=placement, sbp=sbp)\n        .requires_grad_(True)\n    )\n    init_grad = (\n        random_tensor(ndim=3, dim0=8, dim1=8, dim2=4)\n        .to_global(placement=placement, sbp=sbp)\n        .requires_grad_(True)\n    )\n    init_grad_grad = (\n        random_tensor(ndim=3, dim0=8, dim1=8, dim2=8)\n        .to_global(placement=placement, sbp=sbp)\n        .requires_grad_(True)\n    )\n\n    y = x[:, :, 2:6]\n\n    x_grad = torch.autograd.grad(y, x, init_grad, create_graph=True)[0]\n    test_case.assertTrue(\n        np.allclose(\n            x_grad.pytorch.detach().cpu().numpy(), x_grad.oneflow.detach().numpy()\n        )\n    )\n\n    dgrad = torch.autograd.grad(x_grad, init_grad, init_grad_grad, create_graph=False)[\n        0\n    ]\n    test_case.assertTrue(\n        np.allclose(\n            dgrad.pytorch.detach().cpu().numpy(), dgrad.oneflow.detach().numpy(),\n        )\n    )\n\n\nclass TestGlobalSliceHigherDerivative(flow.unittest.TestCase):\n    @globaltest\n    def test_global_slice_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_slice_grad_grad_impl(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_higher_derivative_softmax.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _assert_true(test_case, value1, value2):\n    test_case.assertTrue(\n        np.allclose(\n            value1.detach().cpu().numpy(),\n            value2.detach().cpu().numpy(),\n            rtol=1e-05,\n            atol=1e-05,\n        )\n    )\n\n\ndef _test_global_softmax_grad_grad_impl(test_case, op_name, placement, sbp):\n    ndim = 2\n    data = random_tensor(ndim=ndim, dim0=8, dim1=8)\n\n    for dim in range(ndim):\n        x = (\n            data.detach()\n            .clone()\n            .requires_grad_()\n            .to_global(placement=placement, sbp=sbp)\n        )\n        m = eval(f\"torch.nn.{op_name}\")(dim)\n        y = m(x)\n        _assert_true(test_case, y.pytorch, y.oneflow)\n\n        x_shape = x.oneflow.shape\n        init_grad_x = random_tensor(len(x_shape), *x_shape).to_global(\n            placement=placement, sbp=sbp\n        )\n        init_grad_y = random_tensor(len(x_shape), *x_shape).to_global(\n            placement=placement, sbp=sbp\n        )\n\n        dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0]\n        _assert_true(test_case, dx.pytorch, dx.oneflow)\n\n        ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x)\n        ddx, ddy = ddx_ddy[0], ddx_ddy[1]\n        _assert_true(test_case, ddx.pytorch, ddx.oneflow)\n        _assert_true(test_case, ddy.pytorch, ddy.oneflow)\n\n\nclass TestGlobalSoftmaxHigherDerivative(flow.unittest.TestCase):\n    @globaltest\n    def test_global_softmax_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_softmax_grad_grad_impl(\n                    test_case, op_name=\"Softmax\", placement=placement, sbp=sbp\n                )\n\n    @globaltest\n    def test_global_logsoftmax_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_softmax_grad_grad_impl(\n                    test_case, op_name=\"LogSoftmax\", placement=placement, sbp=sbp\n                )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_inv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_inv(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim - 2)]\n    square_dim = 8\n    dim_list.extend([square_dim] * 2)\n    x = (\n        random_tensor(ndim, *dim_list, low=-1)\n        .to(torch.double)\n        .to_global(placement, sbp)\n    )\n    return torch.linalg.inv(x)\n\n\nclass TestInv(flow.unittest.TestCase):\n    @globaltest\n    def test_inv(test_case):\n        ndim = random(2, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_inv(test_case, placement, sbp, ndim)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_lerp.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=False)\ndef _test_lerp(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim - 2)]\n    square_dim = 8\n    dim_list.extend([square_dim] * 2)\n    start = random_tensor(ndim, *dim_list).to(torch.double).to_global(placement, sbp)\n    end = random_tensor(ndim, *dim_list).to(torch.double).to_global(placement, sbp)\n    weight = random_tensor(ndim, *dim_list).to(torch.double).to_global(placement, sbp)\n    return torch.lerp(start, end, weight)\n\n\nclass TestLerp(flow.unittest.TestCase):\n    @globaltest\n    def test_lerp(test_case):\n        ndim = random(2, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_lerp(test_case, placement, sbp, ndim)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
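torch.lerp computes start + weight * (end - start); the global test above only varies placement and sbp around that formula. A one-line local check in plain PyTorch (illustrative only):

import torch

s, e, w = torch.randn(3), torch.randn(3), torch.randn(3)
print(torch.allclose(torch.lerp(s, e, w), s + w * (e - s)))  # True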
  {
    "path": "python/oneflow/test/modules/test_global_linalg_cross.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1)\ndef _test_linalg_cross(test_case, index_size_equal_3, ndim, placement, sbp):\n    shape = [random(1, 4).to(int) * 8 for i in range(ndim)]\n    shape[index_size_equal_3] = 3\n    x = random_tensor(ndim, *shape)\n    x = x.to_global(placement=placement, sbp=sbp)\n    y = random_tensor(ndim, *shape)\n    y = y.to_global(placement=placement, sbp=sbp)\n    return torch.cross(\n        x, y, dim=index_size_equal_3\n    )  # TODO(peihong): will convert to torch.linalg.cross when PyTorch in ci is upgraded to 1.11\n\n\nclass TestLinalgCrossGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_linalg_cross(test_case):\n        ndim = random(2, 5).to(int).value()\n        index_size_equal_3 = random(0, ndim).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(\n                placement,\n                max_dim=ndim,\n                valid_split_axis=[i for i in range(ndim) if i != index_size_equal_3],\n            ):\n                _test_linalg_cross(test_case, index_size_equal_3, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_linear.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=False)\ndef _test_linear_with_random_data(test_case, placement, input_sbp):\n    input_size = 8\n    m = torch.nn.Linear(in_features=input_size, out_features=8, bias=random())\n    m.train(random())\n    weight_sbp = random_sbp(placement, max_dim=2, except_partial_sum=True)\n    m.weight = torch.nn.Parameter(\n        m.weight.to_global(placement=placement, sbp=weight_sbp)\n    )\n    if m.bias is not None:\n        # bias is 1-d tensor\n        bias_sbp = random_sbp(placement, max_dim=1, except_partial_sum=True)\n        m.bias = torch.nn.Parameter(m.bias.to_global(placement=placement, sbp=bias_sbp))\n    x = random_tensor(ndim=2, dim0=input_size, dim1=8).to_global(\n        placement=placement, sbp=input_sbp\n    )\n    y = m(x)\n    return y\n\n\nclass TestLinearModule(flow.unittest.TestCase):\n    @globaltest\n    def test_linear_with_random_data(test_case):\n        for placement in all_placement():\n            for input_sbp in all_sbp(placement, max_dim=2):\n                _test_linear_with_random_data(test_case, placement, input_sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_linspace.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom collections import OrderedDict\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_global_linspace(test_case, placement, sbp):\n    x = flow.linspace(start=-10, end=10, steps=8, placement=placement, sbp=sbp)\n\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\ndef _test_graph_linspace(test_case, start, end, steps, placement, sbp):\n    class GlobalLinspaceGraph(flow.nn.Graph):\n        def __init__(self,):\n            super().__init__()\n\n        def build(self):\n            x = flow.linspace(start, end, steps, placement=placement, sbp=sbp)\n            return x\n\n    model = GlobalLinspaceGraph()\n    x = model()\n\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\nclass TestLinspaceGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_linspace_global(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True):\n                _test_global_linspace(test_case, placement, sbp)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n2d()\n    def test_linspace_graph(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"start\"] = [-2, 0, 2]\n        arg_dict[\"end\"] = [4, 8, 16]\n        arg_dict[\"steps\"] = [8, 16, 24]\n        arg_dict[\"placement\"] = [\n            # 1d\n            flow.placement(\"cpu\", ranks=[0, 1]),\n            flow.placement(\"cuda\", ranks=[0, 1]),\n            # 2d\n            flow.placement(\"cpu\", ranks=[[0, 1],]),\n            flow.placement(\"cuda\", ranks=[[0, 1],]),\n        ]\n        for args in GenArgDict(arg_dict):\n            start = args[\"start\"]\n            end = args[\"end\"]\n            steps = args[\"steps\"]\n            placement = args[\"placement\"]\n            for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True):\n                _test_graph_linspace(test_case, start, end, steps, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_logspace.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nfrom collections import OrderedDict\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_global_logspace(test_case, placement, sbp):\n    x = flow.logspace(start=-10, end=10, steps=8, placement=placement, sbp=sbp)\n\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\ndef _test_graph_logspace(test_case, start, end, steps, placement, sbp):\n    class GlobalLogspaceGraph(flow.nn.Graph):\n        def __init__(self,):\n            super().__init__()\n\n        def build(self):\n            x = flow.logspace(start, end, steps, placement=placement, sbp=sbp)\n            return x\n\n    model = GlobalLogspaceGraph()\n    x = model()\n\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\nclass TestLogspaceGlobal(flow.unittest.TestCase):\n    # TODO(wyg): It will be infer all broadcast sbp when 1n1d,\n    #            slice_update will get error when doing inplace operator.\n    #            Remove this judgement after refactor sbp infer method in Operator class.\n    @globaltest\n    def test_logspace_global(test_case):\n        for placement in all_placement():\n            if placement.ranks.size == 1:\n                continue\n            for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True):\n                _test_global_logspace(test_case, placement, sbp)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n2d()\n    def test_logspace_graph(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"start\"] = [-2, 0, 2]\n        arg_dict[\"end\"] = [4, 8, 16]\n        arg_dict[\"steps\"] = [8, 16, 24]\n        arg_dict[\"placement\"] = [\n            # 1d\n            flow.placement(\"cpu\", ranks=[0, 1]),\n            flow.placement(\"cuda\", ranks=[0, 1]),\n            # 2d\n            flow.placement(\"cpu\", ranks=[[0, 1],]),\n            flow.placement(\"cuda\", ranks=[[0, 1],]),\n        ]\n        for args in GenArgDict(arg_dict):\n            start = args[\"start\"]\n            end = args[\"end\"]\n            steps = args[\"steps\"]\n            placement = args[\"placement\"]\n            for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True):\n                _test_graph_logspace(test_case, start, end, steps, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_lstm_cell.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n# NOTE(lixiang): Do not check the graph for the time being, because ci will report \"The action has timed out\".\n@autotest(n=1, check_graph=\"ValidatedFalse\")\ndef _test_lstm_cell(test_case, placement, sbp):\n    batch_size = random(2, 3) * 8\n    time_steps = random(2, 3) * 8\n    input_size = random(2, 3) * 8\n    hidden_size = random(2, 3) * 8\n    has_bias = random().to(bool)\n    cx_requires_grad = random().to(bool)\n    m = torch.nn.LSTMCell(\n        input_size=input_size, hidden_size=hidden_size, bias=has_bias,\n    )\n\n    weight_sbp = random_sbp(placement, max_dim=2, except_partial_sum=True)\n    m.weight_ih = torch.nn.Parameter(\n        m.weight_ih.to_global(placement=placement, sbp=weight_sbp)\n    )\n    m.weight_hh = torch.nn.Parameter(\n        m.weight_hh.to_global(placement=placement, sbp=weight_sbp)\n    )\n    if m.bias_ih is not None:\n        # bias is 1-d tensor\n        bias_sbp = random_sbp(placement, max_dim=1, except_partial_sum=True)\n        m.bias_ih = torch.nn.Parameter(\n            m.bias_ih.to_global(placement=placement, sbp=bias_sbp)\n        )\n        m.bias_hh = torch.nn.Parameter(\n            m.bias_hh.to_global(placement=placement, sbp=bias_sbp)\n        )\n\n    input_sbp = random_sbp(placement, max_dim=3, valid_split_axis=1)\n    input = random_tensor(\n        ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size\n    ).to_global(placement=placement, sbp=input_sbp)\n    hx = random_tensor(\n        ndim=2, dim0=batch_size, dim1=hidden_size, requires_grad=False\n    ).to_global(placement=placement, sbp=sbp)\n    cx = random_tensor(\n        ndim=2, dim0=batch_size, dim1=hidden_size, requires_grad=cx_requires_grad\n    ).to_global(placement=placement, sbp=sbp)\n\n    for i in range(time_steps.to(int).value()):\n        res = m(input[i], (hx, cx))\n        hx = res[0]\n        cx = res[1]\n    return res[0]\n\n\nclass TestRNNCellGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_lstm_cell(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_lstm_cell(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_masked_fill.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_masked_fill(test_case, placement, sbp):\n    k1 = random().to(int).value() * 8\n    k2 = random().to(int).value() * 8\n    input = random_tensor(ndim=2, dim0=k1, dim1=k2).to_global(placement, sbp)\n    mask = random_tensor(ndim=2, dim0=k1, dim1=k2).to_global(placement, sbp)\n    value = random().to(float)\n    return input.masked_fill(mask > 0.5, value)\n\n\n@autotest(n=1, check_graph=True)\ndef _test_masked_fill_with_0dim_data(test_case, placement, sbp):\n    input = random_tensor(ndim=0).to_global(placement, sbp)\n    mask = random_tensor(ndim=0).to_global(placement, sbp)\n    value = random().to(float)\n    return input.masked_fill(mask > 0.5, value)\n\n\n@autotest(n=1, check_graph=True)\ndef _test_masked_fill_with_broadcast_way(test_case, placement, sbp):\n    k1 = random().to(int).value() * 8\n    k2 = random().to(int).value() * 8\n    input = random_tensor(ndim=2, dim0=k1, dim1=k2, dim2=1, dim3=k2).to_global(\n        placement, sbp\n    )\n    mask = random_tensor(ndim=2, dim0=k1, dim1=k2, dim2=k1, dim3=1).to_global(\n        placement, sbp\n    )\n    value = random().to(float)\n    return input.masked_fill(mask > 0.5, value)\n\n\nclass TestMaskedFill(flow.unittest.TestCase):\n    @globaltest\n    def test_masked_fill(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_masked_fill(test_case, placement, sbp)\n                # TODO() : fail at tensor slice\n                # _test_masked_fill_with_0dim_data(test_case, placement, sbp)\n                _test_masked_fill_with_broadcast_way(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_masked_select.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n# Not check graph because of one reason:\n# Reason 1, The implementation of the masked_select op calls argwhere with the lazy tensor as an argument, but lazy tensor can not be applied to argwhere.\n# Please refer to File \"python/oneflow/nn/modules/masked_select.py\", line 54, in masked_select_op.\n@autotest(n=1, check_graph=\"ValidatedFalse\")\ndef _test_masked_select(test_case, placement, sbp):\n    k1 = random(1, 2).to(int).value() * 8\n    k2 = random(1, 2).to(int).value() * 8\n    input = random_tensor(ndim=2, dim0=k1, dim1=k2).to_global(placement, sbp)\n    mask = input.ge(0.5)\n    return torch.masked_select(input, mask)\n\n\n# Not check graph because of one reason:\n# Reason 1, The implementation of the masked_select op calls argwhere with the lazy tensor as an argument, but lazy tensor can not be applied to argwhere.\n# Please refer to File \"python/oneflow/nn/modules/masked_select.py\", line 54, in masked_select_op.\n@autotest(n=1, check_graph=\"ValidatedFalse\")\ndef _test_masked_select_broadcast(test_case, placement, input_sbp, mask_sbp):\n    k1 = random(1, 2).to(int).value() * 8\n    k2 = random(1, 2).to(int).value() * 8\n    input = random_tensor(ndim=4, dim0=k1, dim1=k2, dim2=1, dim3=k2).to_global(\n        placement, input_sbp\n    )\n    mask = random_tensor(ndim=4, dim0=k1, dim1=k2, dim2=k1, dim3=1).to_global(\n        placement, mask_sbp\n    )\n    return torch.masked_select(input, mask > 0.5)\n\n\nclass TestMaskedSelect(flow.unittest.TestCase):\n    @globaltest\n    def test_masked_select(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_masked_select(test_case, placement, sbp)\n\n    @globaltest\n    def test_masked_select_broadcast(test_case):\n        for placement in all_placement():\n            for input_sbp in all_sbp(placement, valid_split_axis=[0, 1, 3]):\n                for mask_sbp in all_sbp(placement, max_dim=3):\n                    _test_masked_select_broadcast(\n                        test_case, placement, input_sbp, mask_sbp\n                    )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_math_op_higher_derivative.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _global_math_op_grad_grad_impl(test_case, op_name, placement, sbp):\n    x = (\n        random_tensor(2, dim0=8, dim1=8, low=-2, high=2)\n        .to_global(placement=placement, sbp=sbp)\n        .requires_grad_(True)\n    )\n    y = eval(f\"torch.{op_name}\")(x)\n    init_grad = random_tensor(2, 8, 8).to_global(placement, sbp).requires_grad_()\n\n    x_grad = torch.autograd.grad(y, x, init_grad, create_graph=True)[0]\n    test_case.assertTrue(\n        np.allclose(\n            x_grad.pytorch.detach().cpu().numpy(),\n            x_grad.oneflow.detach().numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n            equal_nan=True,\n        )\n    )\n\n    x_grad_grad = torch.autograd.grad(x_grad, x, init_grad, retain_graph=True)[0]\n    test_case.assertTrue(\n        np.allclose(\n            x_grad_grad.pytorch.detach().cpu().numpy(),\n            x_grad_grad.oneflow.detach().numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n            equal_nan=True,\n        )\n    )\n\n    init_grad_grad = random_tensor(2, 8, 8).to_global(placement, sbp).requires_grad_()\n    dgrad = torch.autograd.grad(x_grad, init_grad, init_grad_grad, retain_graph=True)[0]\n    test_case.assertTrue(\n        np.allclose(\n            dgrad.pytorch.detach().cpu().numpy(),\n            dgrad.oneflow.detach().numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n            equal_nan=True,\n        )\n    )\n\n\nclass TestGlobalMathOpHigherDerivative(flow.unittest.TestCase):\n    @globaltest\n    def test_global_sin_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"sin\", placement, sbp)\n\n    @globaltest\n    def test_global_cos_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"cos\", placement, sbp)\n\n    @globaltest\n    def test_global_tan_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"tan\", placement, sbp)\n\n    @globaltest\n    def test_global_sinh_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"sinh\", placement, sbp)\n\n    @globaltest\n    def test_global_cosh_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"cosh\", placement, sbp)\n\n    @globaltest\n    def 
test_global_tanh_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"tanh\", placement, sbp)\n\n    @globaltest\n    def test_global_asin_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"asin\", placement, sbp)\n\n    @globaltest\n    def test_global_acos_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"acos\", placement, sbp)\n\n    @globaltest\n    def test_global_atan_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"atan\", placement, sbp)\n\n    @globaltest\n    def test_global_asinh_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"asinh\", placement, sbp)\n\n    @globaltest\n    def test_global_acosh_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"acosh\", placement, sbp)\n\n    @globaltest\n    def test_global_atanh_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"atanh\", placement, sbp)\n\n    @globaltest\n    def test_global_erf_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"erf\", placement, sbp)\n\n    @globaltest\n    def test_global_erfc_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"erfc\", placement, sbp)\n\n    @globaltest\n    def test_global_exp_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"exp\", placement, sbp)\n\n    @globaltest\n    def test_global_exp2_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"exp2\", placement, sbp)\n\n    @globaltest\n    def test_global_expm1_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"expm1\", placement, sbp)\n\n    @globaltest\n    def test_global_log_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"log\", placement, sbp)\n\n    @globaltest\n    def test_global_logsigmoid_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(\n                    test_case, \"nn.functional.logsigmoid\", placement, sbp\n                )\n\n    
@globaltest\n    def test_global_log2_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"log2\", placement, sbp)\n\n    @globaltest\n    def test_global_log1p_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"log1p\", placement, sbp)\n\n    @globaltest\n    def test_global_reciprocal_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"reciprocal\", placement, sbp)\n\n    @globaltest\n    def test_global_rsqrt_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"rsqrt\", placement, sbp)\n\n    @globaltest\n    def test_global_sqrt_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"sqrt\", placement, sbp)\n\n    @globaltest\n    def test_global_square_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"square\", placement, sbp)\n\n    @globaltest\n    def test_global_sigmoid_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"sigmoid\", placement, sbp)\n\n    @globaltest\n    def test_global_abs_grad_grad(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _global_math_op_grad_grad_impl(test_case, \"abs\", placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_math_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1)\ndef _test_sinh(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list).to_global(placement, sbp)\n    y = torch.sinh(x)\n    return y\n\n\n@autotest(n=1)\ndef _test_sin(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list).to_global(placement, sbp)\n    y = torch.sin(x)\n    return y\n\n\n@autotest(n=1)\ndef _test_inplace_sin(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list).to_global(placement, sbp)\n    y = x + 1\n    y.sin_()\n    return y\n\n\n@autotest(n=1)\ndef _test_cos(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list).to_global(placement, sbp)\n    y = torch.cos(x)\n    return y\n\n\n@autotest(n=1)\ndef _test_log(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list).to_global(placement, sbp)\n    y = torch.log(x)\n    return y\n\n\n@autotest(n=1)\ndef _test_sqrt(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list).to_global(placement, sbp)\n    y = torch.sqrt(x)\n    return y\n\n\n@autotest(n=1)\ndef _test_exp(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list).to_global(placement, sbp)\n    y = torch.exp(x)\n    return y\n\n\n@autotest(n=1)\ndef _test_exp2(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list).to_global(placement, sbp)\n    y = torch.exp2(x)\n    return y\n\n\n@autotest(n=1)\ndef _test_rsqrt(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list).to_global(placement, sbp)\n    y = torch.rsqrt(x)\n    return y\n\n\n@autotest(n=1)\ndef _test_square(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list).to_global(placement, sbp)\n    y = torch.square(x)\n    return y\n\n\n@autotest(n=1)\ndef _test_pow_with_scalar(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list).to_global(placement, sbp)\n    y = random().to(float)\n    z = torch.pow(x, y)\n    return z\n\n\n@autotest(n=1, auto_backward=False)\ndef 
_test_floordiv_with_scalar(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    # The random value is restricted to positive numbers because of an error in pytorch 1.10.0.\n    # Please remove the value range restriction after updating the pytorch version of CI to 1.13.\n    x = random_tensor(ndim, *dim_list, low=0, high=10).to_global(placement, sbp)\n    y = random().to(float)\n    z = torch.floor_divide(x, y)\n    return z\n\n\n@autotest(n=1)\ndef _test_arccos(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list, low=-1, high=1).to_global(placement, sbp)\n    y = torch.arccos(x)\n    return y\n\n\n@autotest(n=1)\ndef _test_acos(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list, low=-1, high=1).to_global(placement, sbp)\n    y = torch.acos(x)\n    return y\n\n\n@autotest(n=1)\ndef _test_arccosh(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list, low=2, high=3).to_global(placement, sbp)\n    y = torch.arccosh(x)\n    return y\n\n\n@autotest(n=1)\ndef _test_acosh(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list, low=2, high=3).to_global(placement, sbp)\n    y = torch.acosh(x)\n    return y\n\n\n@autotest(n=1, auto_backward=False)\ndef _test_floordiv(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    # The random value is restricted to positive numbers because of an error in pytorch 1.10.0.\n    # Please remove the value range restriction after updating the pytorch version of CI to 1.13.\n    x = random_tensor(ndim, *dim_list, low=0, high=10).to_global(placement, sbp)\n    y = random_tensor(ndim, *dim_list, low=1, high=10).to_global(placement, sbp)\n    z = torch.floor_divide(x, y)\n    return z\n\n\n@autotest(n=1)\ndef _test_atan2(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list).to_global(placement, sbp)\n    y = random_tensor(ndim, *dim_list).to_global(placement, sbp)\n    z = torch.atan2(x, y)\n    return z\n\n\n@autotest(n=1)\ndef _test_digamma(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list, low=0, high=10).to_global(placement, sbp)\n    y = torch.digamma(x)\n    return y\n\n\nclass TestMathOps(flow.unittest.TestCase):\n    @globaltest\n    def test_math_ops(test_case):\n        ndim = random(1, 3).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_sinh(test_case, placement, sbp, ndim)\n                _test_sin(test_case, placement, sbp, ndim)\n                _test_inplace_sin(test_case, placement, sbp, ndim)\n                _test_cos(test_case, placement, sbp, ndim)\n                _test_log(test_case, placement, sbp, ndim)\n                _test_sqrt(test_case, placement, sbp, ndim)\n                _test_exp(test_case, placement, sbp, ndim)\n                _test_exp2(test_case, placement, sbp, ndim)\n                _test_rsqrt(test_case, placement, sbp, ndim)\n                _test_square(test_case, placement, sbp, 
ndim)\n                _test_pow_with_scalar(test_case, placement, sbp, ndim)\n                _test_floordiv_with_scalar(test_case, placement, sbp, ndim)\n                _test_arccos(test_case, placement, sbp, ndim)\n                _test_acos(test_case, placement, sbp, ndim)\n                _test_arccosh(test_case, placement, sbp, ndim)\n                _test_acosh(test_case, placement, sbp, ndim)\n                _test_digamma(test_case, placement, sbp, ndim)\n\n                _test_floordiv(test_case, placement, sbp, ndim)\n                _test_atan2(test_case, placement, sbp, ndim)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_matmul.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_matmul(test_case, placement, x_sbp, y_sbp):\n    x = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement=placement, sbp=x_sbp)\n    y = random_tensor(ndim=2, dim0=16, dim1=8).to_global(placement=placement, sbp=y_sbp)\n    return torch.matmul(x, y)\n\n\n@autotest(n=1, check_graph=True)\ndef _test_tensor_broadcast_matmul(test_case, placement, x_sbp, y_sbp):\n    x = random_tensor(ndim=3, dim0=8, dim1=8, dim2=16).to_global(\n        placement=placement, sbp=x_sbp\n    )\n    y = random_tensor(ndim=2, dim0=16, dim1=8).to_global(placement=placement, sbp=y_sbp)\n    return x.matmul(y)\n\n\nclass TestMatMulModule(flow.unittest.TestCase):\n    @globaltest\n    def test_matmul(test_case):\n        for placement in all_placement():\n            for x_sbp in all_sbp(placement, max_dim=2):\n                for y_sbp in all_sbp(placement, max_dim=2):\n                    _test_matmul(test_case, placement, x_sbp, y_sbp)\n\n    @globaltest\n    def test_broadcast_matmul(test_case):\n        for placement in all_placement():\n            for x_sbp in all_sbp(placement, valid_split_axis=[0, 1, 2, 3]):\n                for y_sbp in all_sbp(placement, valid_split_axis=[0, 1]):\n                    _test_tensor_broadcast_matmul(test_case, placement, x_sbp, y_sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_max.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _np_max(shape, dim, keepdims):\n    # np array result\n    input_arr = np.random.randn(*shape)\n    np_out = np.amax(input_arr, axis=dim, keepdims=keepdims)\n    np_out_grad = np.zeros_like(input_arr)\n    if dim == None:\n        arg_max = np.argmax(input_arr)\n        np.put(np_out_grad, arg_max, 1)\n    else:\n        arg_max = np.expand_dims(np.argmax(input_arr, axis=dim), axis=dim)\n        np.put_along_axis(np_out_grad, arg_max, 1, axis=dim)\n\n    return np_out, np_out_grad, input_arr\n\n\ndef _test_max(\n    test_case, placement, sbp, np_out, np_out_grad, input_arr, shape, dim, keepdims\n):\n    # of result\n    global_x = flow.tensor(\n        input_arr,\n        dtype=flow.float32,\n        requires_grad=True,\n        placement=flow.placement.all(\"cpu\"),\n        sbp=flow.sbp.broadcast,\n    )\n    if dim is None:\n        of_out = flow.max(global_x)\n    else:\n        of_out = flow.max(global_x, dim, keepdims)[0]\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n\n    test_case.assertTrue(\n        np.allclose(global_x.grad.numpy(), np_out_grad, 0.0001, 0.0001)\n    )\n\n\nclass TestMaxModule(flow.unittest.TestCase):\n    # backward formula is different from one of torch.\n    @globaltest\n    def test_eager_global_max(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_max]\n        arg_dict[\"shape\"] = [(8,), (8, 8), (8, 8, 8, 8)]\n        arg_dict[\"dim\"] = [None, 0, -1]\n        arg_dict[\"keepdims\"] = [False, True]\n        for arg in GenArgList(arg_dict):\n            np_out, np_out_grad, input_arr = _np_max(*arg[1:])\n            np_out = (\n                flow.tensor(np_out)\n                .to_global(placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.broadcast,)\n                .numpy()\n            )\n            np_out_grad = (\n                flow.tensor(np_out_grad)\n                .to_global(placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.broadcast,)\n                .numpy()\n            )\n            input_arr = (\n                flow.tensor(input_arr)\n                .to_global(placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.broadcast,)\n                .numpy()\n            )\n            for placement in all_placement():\n                for sbp in all_sbp(placement, max_dim=len(*arg[1:2])):\n                    arg[0](\n                        test_case,\n                        placement,\n                        sbp,\n                        np_out,\n                        np_out_grad,\n                        input_arr,\n                        *arg[1:]\n             
       )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_maximum_minimum.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nimport torch as torch_original\nfrom packaging import version\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(\n    n=5,\n    auto_backward=(\n        version.parse(torch_original.__version__) >= version.parse(\"1.10.2\")\n    ),\n    check_graph=True,\n)\ndef _test_broadcast_maximum(test_case, placement, x_sbp, y_sbp):\n    x = random_tensor(ndim=5, dim0=8, dim1=8, dim2=8, dim3=1, dim4=8).to_global(\n        placement, x_sbp\n    )\n    y = random_tensor(ndim=5, dim0=8, dim1=8, dim2=1, dim3=8, dim4=1).to_global(\n        placement, y_sbp\n    )\n    z = torch.maximum(x, y)\n    return z\n\n\n@autotest(\n    n=5,\n    auto_backward=(\n        version.parse(torch_original.__version__) >= version.parse(\"1.10.2\")\n    ),\n    check_graph=True,\n)\ndef _test_broadcast_minimum(test_case, placement, x_sbp, y_sbp):\n    x = random_tensor(ndim=5, dim0=8, dim1=8, dim2=8, dim3=1, dim4=8).to_global(\n        placement, x_sbp\n    )\n    y = random_tensor(ndim=5, dim0=8, dim1=8, dim2=1, dim3=8, dim4=1).to_global(\n        placement, y_sbp\n    )\n    z = torch.minimum(x, y)\n    return z\n\n\n@autotest(\n    n=5,\n    auto_backward=(\n        version.parse(torch_original.__version__) >= version.parse(\"1.10.2\")\n    ),\n    check_graph=True,\n)\ndef _test_maximum_with_same_input(test_case, placement, sbp):\n    x = random_tensor(ndim=4, dim0=8, dim1=8, dim2=8, dim3=8).to_global(placement, sbp)\n    y = x.detach().clone()\n    y.requires_grad = True\n    z = torch.maximum(x, y)\n    return z\n\n\n@autotest(\n    n=5,\n    auto_backward=(\n        version.parse(torch_original.__version__) >= version.parse(\"1.10.2\")\n    ),\n    check_graph=True,\n)\ndef _test_minimum_with_same_input(test_case, placement, sbp):\n    x = random_tensor(ndim=4, dim0=8, dim1=8, dim2=8, dim3=8).to_global(placement, sbp)\n    y = x.detach().clone()\n    y.requires_grad = True\n    z = torch.minimum(x, y)\n    return z\n\n\nclass TestMaximumMinimumOps(flow.unittest.TestCase):\n    @globaltest\n    def test_maximum_minimum_with_same_input(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_maximum_with_same_input(test_case, placement, sbp)\n                _test_minimum_with_same_input(test_case, placement, sbp)\n\n    @globaltest\n    def test_broadcast_maximum_minimum(test_case):\n        for placement in all_placement():\n            for x_sbp in all_sbp(placement, valid_split_axis=[0, 1, 2, 4]):\n                for y_sbp in all_sbp(placement, valid_split_axis=[0, 1, 3]):\n                    _test_broadcast_maximum(test_case, placement, x_sbp, y_sbp)\n                    _test_broadcast_minimum(test_case, placement, x_sbp, y_sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_maxpool.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nfrom pkg_resources import packaging\nimport oneflow as flow\nimport torch as ori_torch\nimport oneflow.unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t\n\n\n@autotest(n=1, check_graph=True)\ndef _test_maxpool1d_functional(test_case, placement, sbp):\n    return_indices = random().to(bool).value()\n    dim0 = random(1, 4).to(int).value() * 8\n    dim1 = random(1, 4).to(int).value() * 8\n    x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(20, 22)).to_global(\n        placement, sbp\n    )\n    y = torch.nn.functional.max_pool1d(\n        x,\n        kernel_size=random(4, 6).to(int),\n        stride=random(1, 3).to(int),\n        padding=random(1, 3).to(int),\n        dilation=random(2, 4).to(int),\n        ceil_mode=random().to(bool),\n        return_indices=return_indices,\n    )\n    if return_indices:\n        return y[0]\n    else:\n        return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_maxpool2d_functional(test_case, placement, sbp):\n    return_indices = random().to(bool).value()\n    dim0 = random(1, 4).to(int).value() * 8\n    dim1 = random(1, 4).to(int).value() * 8\n    x = random_tensor(\n        ndim=4, dim0=dim0, dim1=dim1, dim2=random(20, 22), dim3=random(20, 22)\n    ).to_global(placement, sbp)\n    y = torch.nn.functional.max_pool2d(\n        x,\n        kernel_size=random(4, 6).to(int),\n        stride=random(1, 3).to(int),\n        padding=random(1, 3).to(int),\n        dilation=random(2, 4).to(int),\n        ceil_mode=random().to(bool),\n        return_indices=return_indices,\n    )\n\n    if return_indices:\n        return y[0]\n    else:\n        return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_maxpool3d_functional(test_case, placement, sbp):\n    return_indices = random().to(bool).value()\n    dim0 = random(high=4).to(int).value() * 8\n    dim1 = random(high=4).to(int).value() * 8\n    x = random_tensor(\n        ndim=5,\n        dim0=dim0,\n        dim1=dim1,\n        dim2=random(10, 12),\n        dim3=random(10, 12),\n        dim4=random(10, 12),\n    ).to_global(placement, sbp)\n    y = torch.nn.functional.max_pool3d(\n        x,\n        kernel_size=random(4, 6).to(int),\n        stride=random(1, 3).to(int),\n        padding=random(1, 3).to(int),\n        dilation=random(2, 4).to(int),\n        ceil_mode=random().to(bool),\n        return_indices=return_indices,\n    )\n\n    if return_indices:\n        return y[0]\n    else:\n        return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_maxpool1d(test_case, placement, sbp):\n    return_indices = random().to(bool).value()\n    dim0 = random(1, 4).to(int).value() * 8\n    dim1 = random(1, 4).to(int).value() * 8\n    m = torch.nn.MaxPool1d(\n        
kernel_size=random(4, 6).to(_size_1_t),\n        stride=random(1, 3).to(_size_1_t),\n        padding=random(1, 3).to(_size_1_t),\n        dilation=random(2, 4).to(_size_1_t),\n        ceil_mode=random(),\n        return_indices=return_indices,\n    )\n    m.train(random())\n    x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(20, 22)).to_global(\n        placement, sbp\n    )\n    y = m(x)\n    if return_indices:\n        return y[0]\n    else:\n        return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_maxpool2d(test_case, placement, sbp):\n    return_indices = random().to(bool).value()\n    dim0 = random(1, 3).to(int).value() * 8\n    dim1 = random(1, 3).to(int).value() * 8\n    m = torch.nn.MaxPool2d(\n        kernel_size=random(4, 6).to(_size_2_t),\n        stride=random(1, 3).to(_size_2_t),\n        padding=random(1, 3).to(_size_2_t),\n        dilation=random(2, 4).to(_size_2_t),\n        ceil_mode=random(),\n        return_indices=return_indices,\n    )\n    m.train(random())\n    x = random_tensor(\n        ndim=4, dim0=dim0, dim1=dim1, dim2=random(20, 22), dim3=random(20, 22)\n    ).to_global(placement, sbp)\n    y = m(x)\n    if return_indices:\n        return y[0]\n    else:\n        return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_maxpool3d(test_case, placement, sbp):\n    return_indices = random().to(bool).value()\n    dim0 = random(high=4).to(int).value() * 8\n    dim1 = random(high=4).to(int).value() * 8\n    m = torch.nn.MaxPool3d(\n        kernel_size=random(4, 6).to(_size_3_t),\n        stride=random(1, 3).to(_size_3_t),\n        padding=random(1, 3).to(_size_3_t),\n        dilation=random(2, 4).to(_size_3_t),\n        ceil_mode=random(),\n        return_indices=return_indices,\n    )\n    m.train(random())\n    x = random_tensor(\n        ndim=5,\n        dim0=dim0,\n        dim1=dim1,\n        dim2=random(10, 12),\n        dim3=random(10, 12),\n        dim4=random(10, 12),\n    ).to_global(placement, sbp)\n    y = m(x)\n\n    if return_indices:\n        return y[0]\n    else:\n        return y\n\n\ndef _test_maxpool2d_channel_last(\n    test_case, placement, sbp, shape, kernel_size, stride, padding, dilation, ceil_mode\n):\n    os.environ[\"ONEFLOW_ENABLE_NHWC\"] = \"1\"\n\n    tensor = random_tensor(len(shape), *shape, requires_grad=False).to_global(\n        placement, sbp\n    )\n    # oneflow result\n    x1 = tensor.oneflow\n    m1 = flow.nn.MaxPool2d(\n        kernel_size=kernel_size,\n        stride=stride,\n        padding=padding,\n        dilation=dilation,\n        ceil_mode=ceil_mode,\n    )\n    y1 = m1(x1)\n\n    # pytorch result\n    x2 = tensor.pytorch.permute(0, 3, 1, 2).to(placement.type)\n    m2 = ori_torch.nn.MaxPool2d(\n        kernel_size=kernel_size,\n        stride=stride,\n        padding=padding,\n        dilation=dilation,\n        ceil_mode=ceil_mode,\n    )\n    y2 = m2(x2).permute(0, 2, 3, 1)\n    os.environ[\"ONEFLOW_ENABLE_NHWC\"] = \"0\"  # restore the default NCHW layout\n\n    # This check should be enabled after updating to torch 1.13.\n    # test_case.assertTrue(\n    #     np.allclose(y1.detach().cpu().numpy(), y2.detach().cpu().numpy(), 1e-4, 1e-4)\n    # )\n\n\nclass TestMaxPool(flow.unittest.TestCase):\n    @globaltest\n    def test_maxpool(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_maxpool1d_functional(test_case, placement, sbp)\n                _test_maxpool2d_functional(test_case, placement, sbp)\n                _test_maxpool3d_functional(test_case, 
placement, sbp)\n                _test_maxpool1d(test_case, placement, sbp)\n                _test_maxpool2d(test_case, placement, sbp)\n                _test_maxpool3d(test_case, placement, sbp)\n\n    @globaltest\n    @unittest.skipIf(\n        packaging.version.parse(ori_torch.__version__)\n        == packaging.version.parse(\"1.10.0\"),\n        \"skip when pytorch version == 1.10.0\",\n    )\n    # NOTE: pytorch maxpool2d nhwc has a bug in version 1.10.0, so it is skipped in CI.\n    # Detail: https://github.com/pytorch/pytorch/pull/76597\n    def test_maxpool2d_channel_last(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_maxpool2d_channel_last]\n        arg_dict[\"shape\"] = [(1, 16, 16, 3), (2, 224, 224, 3)]\n        arg_dict[\"kernel_size\"] = [3, (2, 3)]\n        arg_dict[\"stride\"] = [1, (1, 2)]\n        arg_dict[\"padding\"] = [0, (0, 1)]\n        arg_dict[\"dilation\"] = [1, 2]\n        arg_dict[\"ceil_mode\"] = [True, False]\n        for arg in GenArgList(arg_dict):\n            for placement in all_placement():\n                for sbp in all_sbp(placement, valid_split_axis=[1, 2]):\n                    arg[0](test_case, placement, sbp, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_maxunpool.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t\n\n# y = pool(x), z = unpool(y, indices), pool_input_shape is x.shape, pool_output_shape is y.shape.\n# When `output_size` in unpool() is empty, the op will calculate the output size according to\n# kernel_size, stride and padding. But when index in indices is outside the range required\n# by output_size calculated by unpool op, the value of result and related grad will be unknown.\n# To avoid the problem, this function calculate the output_size which will not cause unknown problems.\ndef _get_valid_output_size(\n    pool_input_shape, pool_output_shape, kernel_size, stride, padding\n):\n    def convert_data(data, i, dst_data=None):\n        if not isinstance(data, (list, int)):\n            return dst_data\n        if isinstance(data, list):\n            return data[i]\n        return data\n\n    _, _, *pool_input_hwd_shape = pool_input_shape.pytorch\n    batch_size, num_channels, *pool_out_hwd_shape = pool_output_shape.pytorch\n    unpool_output_shape = [batch_size, num_channels]\n    for i, (pool_input_size, pool_output_size) in enumerate(\n        zip(pool_input_hwd_shape, pool_out_hwd_shape)\n    ):\n        kernel_size_value = convert_data(kernel_size.value(), i)\n        stride_value = convert_data(stride.value(), i, kernel_size_value)\n        padding_value = convert_data(padding.value(), i, 0)\n        unpool_output_size = max(\n            pool_input_size,\n            (pool_output_size - 1) * stride_value\n            - 2 * padding_value\n            + kernel_size_value,\n        )\n        unpool_output_shape.append(unpool_output_size)\n    return torch.Size(unpool_output_shape)\n\n\ndef _test_module_unpoolnd(test_case, placement, sbp, n):\n    device = random_device()\n    dim0 = random(high=4).to(int).value() * 8\n    dim1 = random(high=4).to(int).value() * 8\n    if n == 1:\n        _size_n_t = _size_1_t\n        MaxPoolNd = torch.nn.MaxPool1d\n        MaxUnpoolNd = torch.nn.MaxUnpool1d\n        x = random_tensor(\n            ndim=3, dim0=dim0, dim1=dim1, dim2=random(20, 31), requires_grad=False\n        ).to_global(placement=placement, sbp=sbp)\n    elif n == 2:\n        _size_n_t = _size_2_t\n        MaxPoolNd = torch.nn.MaxPool2d\n        MaxUnpoolNd = torch.nn.MaxUnpool2d\n        x = random_tensor(\n            ndim=4,\n            dim0=dim0,\n            dim1=dim1,\n            dim2=random(10, 21),\n            dim3=random(10, 21),\n            requires_grad=False,\n        ).to_global(placement=placement, sbp=sbp)\n    elif n == 3:\n        _size_n_t = _size_3_t\n        MaxPoolNd = torch.nn.MaxPool3d\n        MaxUnpoolNd = torch.nn.MaxUnpool3d\n        x = random_tensor(\n            ndim=5,\n            dim0=dim0,\n            dim1=dim1,\n     
       dim2=random(10, 14),\n            dim3=random(10, 14),\n            dim4=random(10, 14),\n            requires_grad=False,\n        ).to_global(placement=placement, sbp=sbp)\n\n    kernel_size = random(4, 6).to(_size_n_t)\n    stride = random(1, 3).to(_size_n_t)\n    padding = random(1, 3).to(_size_n_t)\n    m = MaxPoolNd(\n        kernel_size=kernel_size, stride=stride, padding=padding, return_indices=True,\n    )\n    m.train(random())\n    m.to(device)\n    y = m(x)\n    pooling_results = y[0]\n    indices = y[1]\n    pooling_results.requires_grad_()\n    output_size = _get_valid_output_size(\n        x.shape, pooling_results.shape, kernel_size, stride, padding\n    )\n    unpool_module = MaxUnpoolNd(\n        kernel_size=kernel_size, stride=stride, padding=padding,\n    )\n    result = unpool_module(pooling_results, indices, output_size=output_size)\n    return result\n\n\ndef _test_functional_unpoolnd(test_case, placement, sbp, n):\n    device = random_device()\n    dim0 = random(high=4).to(int).value() * 8\n    dim1 = random(high=4).to(int).value() * 8\n\n    if n == 1:\n        _size_n_t = _size_1_t\n        MaxPoolNd = torch.nn.MaxPool1d\n        max_unpool_nd = torch.nn.functional.max_unpool1d\n        x = random_tensor(\n            ndim=3, dim0=dim0, dim1=dim1, dim2=random(20, 31), requires_grad=False\n        ).to_global(placement=placement, sbp=sbp)\n    elif n == 2:\n        _size_n_t = _size_2_t\n        MaxPoolNd = torch.nn.MaxPool2d\n        max_unpool_nd = torch.nn.functional.max_unpool2d\n        x = random_tensor(\n            ndim=4,\n            dim0=dim0,\n            dim1=dim1,\n            dim2=random(10, 21),\n            dim3=random(10, 21),\n            requires_grad=False,\n        ).to_global(placement=placement, sbp=sbp)\n    elif n == 3:\n        _size_n_t = _size_3_t\n        MaxPoolNd = torch.nn.MaxPool3d\n        max_unpool_nd = torch.nn.functional.max_unpool3d\n        x = random_tensor(\n            ndim=5,\n            dim0=dim0,\n            dim1=dim1,\n            dim2=random(10, 14),\n            dim3=random(10, 14),\n            dim4=random(10, 14),\n            requires_grad=False,\n        ).to_global(placement=placement, sbp=sbp)\n\n    kernel_size = random(4, 6).to(_size_n_t)\n    stride = random(1, 3).to(_size_n_t)\n    padding = random(1, 3).to(_size_n_t)\n    m = MaxPoolNd(\n        kernel_size=kernel_size, stride=stride, padding=padding, return_indices=True,\n    )\n    m.train(random())\n    m.to(device)\n    y = m(x)\n    pooling_results = y[0]\n    indices = y[1]\n    pooling_results.requires_grad_()\n    output_size = _get_valid_output_size(\n        x.shape, pooling_results.shape, kernel_size, stride, padding\n    )\n    return max_unpool_nd(\n        pooling_results,\n        indices,\n        kernel_size=kernel_size,\n        stride=stride,\n        padding=padding,\n        output_size=output_size,\n    )\n\n\nclass TestMaxPool(flow.unittest.TestCase):\n    @globaltest\n    def test_maxpool(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_functional_unpoolnd(test_case, placement, sbp, 1)\n                _test_functional_unpoolnd(test_case, placement, sbp, 2)\n                _test_functional_unpoolnd(test_case, placement, sbp, 3)\n                _test_module_unpoolnd(test_case, placement, sbp, 1)\n                _test_module_unpoolnd(test_case, placement, sbp, 2)\n                _test_module_unpoolnd(test_case, placement, sbp, 3)\n\n\nif 
__name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_mean.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_mean(test_case, placement, sbp, ndim):\n    dim = random(1, ndim).to(int).value()\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list, dtype=float).to_global(placement, sbp)\n    return torch.mean(x, dim)\n\n\nclass TestMean(flow.unittest.TestCase):\n    @globaltest\n    def test_mean(test_case):\n        ndim = random(2, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_mean(test_case, placement, sbp, ndim)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_median.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport torch\nfrom functools import reduce\nimport operator\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_median(test_case, placement, sbp, ndim):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list).to_global(placement, sbp)\n    return torch.median(x)\n\n\n@autotest(n=1, check_graph=True)\ndef _test_median_with_indices(test_case, placement, sbp, ndim):\n    dim = random(1, ndim).to(int).value()\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = choice_tensor(\n        reduce(operator.mul, dim_list, 1),\n        dim_list,\n        replace=False,\n        dtype=float,\n        requires_grad=True,\n    ).to_global(placement, sbp)\n    return torch.median(x, dim)\n\n\nclass TestMedian(flow.unittest.TestCase):\n    @globaltest\n    def test_median(test_case):\n        ndim = random(2, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_median(test_case, placement, sbp, ndim)\n                _test_median_with_indices(test_case, placement, sbp, ndim)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_meshgrid.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=5, auto_backward=False, check_graph=True)\ndef _test_meshgrid(test_case, placement):\n    x_sbp = random_sbp(placement, max_dim=1)\n    x = random_tensor(ndim=1, dim0=8, requires_grad=False).to_global(placement, x_sbp)\n    y_sbp = random_sbp(placement, max_dim=1)\n    y = random_tensor(ndim=1, dim0=8, requires_grad=False).to_global(placement, y_sbp)\n    res = torch.meshgrid(x, y)\n    return res[0], res[1]\n\n\nclass TestMeshGrid(flow.unittest.TestCase):\n    @globaltest\n    def test_meshgrid(test_case):\n        for placement in all_placement():\n            _test_meshgrid(test_case, placement)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_min.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _np_min(shape, dim, keepdims):\n    # np array result\n    input_arr = np.random.randn(*shape)\n    np_out = np.amin(input_arr, axis=dim, keepdims=keepdims)\n    np_out_grad = np.zeros_like(input_arr)\n    if dim == None:\n        arg_min = np.argmin(input_arr)\n        np.put(np_out_grad, arg_min, 1)\n    else:\n        arg_min = np.expand_dims(np.argmin(input_arr, axis=dim), axis=dim)\n        np.put_along_axis(np_out_grad, arg_min, 1, axis=dim)\n\n    return np_out, np_out_grad, input_arr\n\n\ndef _test_min(\n    test_case, placement, sbp, np_out, np_out_grad, input_arr, shape, dim, keepdims\n):\n    # of result\n    global_x = flow.tensor(\n        input_arr,\n        dtype=flow.float32,\n        requires_grad=True,\n        placement=flow.placement.all(\"cpu\"),\n        sbp=flow.sbp.broadcast,\n    )\n    if dim is None:\n        of_out = flow.min(global_x)\n    else:\n        of_out = flow.min(global_x, dim, keepdims)[0]\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n\n    test_case.assertTrue(\n        np.allclose(global_x.grad.numpy(), np_out_grad, 0.0001, 0.0001)\n    )\n\n\nclass TestMinModule(flow.unittest.TestCase):\n    # backward formula is different from one of torch.\n    @unittest.skip(\"skip for now, becase it failed 8 times in past week\")\n    @globaltest\n    def test_eager_global_min(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_min]\n        arg_dict[\"shape\"] = [(8,), (8, 8), (8, 8, 8, 8)]\n        arg_dict[\"dim\"] = [None, 0, -1]\n        arg_dict[\"keepdims\"] = [False, True]\n        for arg in GenArgList(arg_dict):\n            np_out, np_out_grad, input_arr = _np_min(*arg[1:])\n            np_out = (\n                flow.tensor(np_out)\n                .to_global(placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.broadcast,)\n                .numpy()\n            )\n            np_out_grad = (\n                flow.tensor(np_out_grad)\n                .to_global(placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.broadcast,)\n                .numpy()\n            )\n            input_arr = (\n                flow.tensor(input_arr)\n                .to_global(placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.broadcast,)\n                .numpy()\n            )\n            for placement in all_placement():\n                for sbp in all_sbp(placement, max_dim=len(*arg[1:2])):\n                    arg[0](\n                        test_case,\n                        placement,\n                        sbp,\n                        np_out,\n                        np_out_grad,\n      
                  input_arr,\n                        *arg[1:]\n                    )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_min_max_observer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nimport oneflow as flow\nfrom collections import OrderedDict\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.nn.modules import min_max_observer\nfrom oneflow.test_utils.test_util import GenArgList\nfrom test_min_max_observer import _check_min_max_observer\n\n\ndef _run_test_min_max_observer(\n    test_case,\n    placement,\n    sbp,\n    weight_shape,\n    quantization_bit,\n    quantization_scheme,\n    quantization_formula,\n    per_layer_quantization,\n):\n    weight = random_tensor(\n        len(weight_shape), *weight_shape, low=-0.5, high=0.5\n    ).to_global(placement, sbp)\n    of_weight = weight.oneflow\n    np_weight = of_weight.numpy()\n\n    min_max_observer = flow.nn.MinMaxObserver(\n        quantization_formula=quantization_formula,\n        quantization_bit=quantization_bit,\n        quantization_scheme=quantization_scheme,\n        per_layer_quantization=per_layer_quantization,\n    )\n    scale, zero_point = min_max_observer(of_weight)\n    _check_min_max_observer(\n        test_case,\n        np_weight,\n        scale.numpy(),\n        zero_point.numpy(),\n        quantization_bit,\n        quantization_scheme,\n        quantization_formula,\n        per_layer_quantization,\n    )\n\n\nclass TestMinMaxObserver(flow.unittest.TestCase):\n    @globaltest\n    def test_min_max_observer(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"weight_shape\"] = [(9, 48, 24, 10)]\n        arg_dict[\"quantization_bit\"] = [8, 2]\n        arg_dict[\"quantization_scheme\"] = [\"symmetric\", \"affine\"]\n        arg_dict[\"quantization_formula\"] = [\"google\"]\n        arg_dict[\"per_layer_quantization\"] = [True, False]\n        for arg in GenArgList(arg_dict):\n            for placement in all_placement():\n                for sbp in all_sbp(placement, valid_split_axis=[1, 2]):\n                    _run_test_min_max_observer(test_case, placement, sbp, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_movedim.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_movedim(test_case, placement, sbp):\n    x = random_tensor(\n        ndim=4,\n        dim1=random(1, 3) * 8,\n        dim2=random(1, 3) * 8,\n        dim3=random(1, 3) * 8,\n        dim4=random(1, 3) * 8,\n    ).to_global(placement, sbp)\n    z = torch.movedim(x, (0, 1), (2, 3))\n    return z\n\n\nclass TestMovedim(flow.unittest.TestCase):\n    @globaltest\n    def test_movedim(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement):\n                _test_movedim(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_moving_average_max_min_observer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nimport oneflow as flow\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgList\nfrom test_moving_average_min_max_observer import _check_moving_average_min_max_observer\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _run_test_moving_average_min_max_observer(\n    test_case,\n    placement,\n    sbp,\n    device_type,\n    dtype,\n    activation_shape,\n    quantization_bit,\n    quantization_scheme,\n    quantization_formula,\n    momentum,\n):\n    moving_max_np = np.zeros((1,))\n    moving_min_np = np.zeros((1,))\n\n    current_train_step_tensor = flow.tensor(\n        np.zeros((1,)).astype(np.float32),\n        dtype=flow.int64,\n        placement=placement,\n        sbp=sbp,\n    )\n    for i in range(10):\n        of_activation = (\n            random_tensor(len(activation_shape), *activation_shape, low=-0.5, high=0.5)\n            .to_global(placement, sbp)\n            .oneflow\n        )\n        np_activation = of_activation.numpy()\n\n        moving_average_min_max_observer = flow.nn.MovingAverageMinMaxObserver(\n            quantization_formula=quantization_formula,\n            stop_update_after_iters=1,\n            quantization_bit=quantization_bit,\n            quantization_scheme=quantization_scheme,\n            momentum=momentum,\n        )\n        moving_average_min_max_observer = moving_average_min_max_observer.to_global(\n            placement, sbp\n        )\n        (scale, zero_point) = moving_average_min_max_observer(\n            of_activation, current_train_step_tensor\n        )\n        _check_moving_average_min_max_observer(\n            test_case,\n            np_activation,\n            scale.numpy(),\n            zero_point.numpy(),\n            moving_max_np,\n            moving_min_np,\n            quantization_bit,\n            quantization_scheme,\n            quantization_formula,\n            momentum,\n        )\n\n\nclass TestMovingAverageMinMaxObserver(flow.unittest.TestCase):\n    @globaltest\n    def test_moving_average_min_max_observer(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"dtype\"] = [\"float32\", \"double\"]\n        arg_dict[\"activation_shape\"] = [(9, 48, 24, 10)]\n        arg_dict[\"quantization_bit\"] = [8, 2]\n        arg_dict[\"quantization_scheme\"] = [\"symmetric\", \"affine\"]\n        arg_dict[\"quantization_formula\"] = [\"google\"]\n        arg_dict[\"momentum\"] = [0.95]\n        for arg in GenArgList(arg_dict):\n            for placement in all_placement():\n                for sbp in all_sbp(placement, valid_split_axis=[1, 2]):\n                    _run_test_moving_average_min_max_observer(\n                        test_case, placement, sbp, *arg\n                    )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_mul.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_broadcast_mul(test_case, placement, sbp):\n    x = random_tensor(ndim=3, dim0=16, dim1=8, dim2=24).to_global(placement, sbp)\n    y_sbp = random_sbp(placement, max_dim=2)\n    y = random_tensor(ndim=2, dim0=8, dim1=24).to_global(placement, y_sbp)\n    z = torch.mul(x, y)\n    return z\n\n\n@autotest(n=1, check_graph=True)\ndef _test_mul_with_scalar(test_case, ndim, placement, sbp):\n    dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dim_list).to_global(placement, sbp)\n    y = 2\n    return torch.mul(x, y)\n\n\nclass TestMulModule(flow.unittest.TestCase):\n    @globaltest\n    def test_broadcast_mul(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=3):\n                _test_broadcast_mul(test_case, placement, sbp)\n\n    @globaltest\n    def test_mul_with_scalar(test_case):\n        ndim = random(1, 4).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_mul_with_scalar(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_mv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_mv(test_case, placement, sbp):\n    dim = random(1, 6)\n    mat = random_tensor(2, dim1=dim).to_global(placement=placement, sbp=sbp)\n    vec = random_tensor(1, dim0=dim).to_global(placement=placement, sbp=sbp)\n    return torch.mv(mat, vec)\n\n\nclass TestMvModule(flow.unittest.TestCase):\n    @globaltest\n    def test_mv(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement):\n                _test_mv(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_nansum.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=1, check_graph=False)\ndef _test_global_nansum_against_pytorch(test_case, placement, sbp):\n    x = random_tensor(4, 8, 16, 8, 24).to_global(placement, sbp)\n    mask = x < 0\n    x = x.masked_fill(mask, float(\"nan\"))\n    y = torch.nansum(x)\n    return y\n\n\n@autotest(n=1, check_graph=False)\ndef _test_global_nansum_with_0_size_tensor(test_case, placement, sbp):\n    x = random_tensor(4, 8, 16, 0, 24).to_global(placement, sbp)\n    mask = torch.ones_like(x).bool()\n    x = x.masked_fill(mask, float(\"nan\"))\n    y = torch.nansum(x, dim=random(0, 3).to(int))\n    return y\n\n\nclass TestGlobalNanSumModule(flow.unittest.TestCase):\n    @globaltest\n    def test_global_nansum_against_pytorch(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4):\n                _test_global_nansum_against_pytorch(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_nansum_with_0_size_tensor(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4, valid_split_axis=[0, 1, 3]):\n                _test_global_nansum_with_0_size_tensor(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_narrow.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_narrow(test_case, ndim, placement, sbp):\n    dims = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    dim = random(-ndim, ndim).to(int).value()\n    start = random(0, dims[dim]).to(int).value()\n    length = random(1, dims[dim] - start + 1).to(int).value()\n\n    return torch.narrow(x, dim=dim, start=start, length=length)\n\n\nclass TestNarrow(flow.unittest.TestCase):\n    @globaltest\n    def test_narrow(test_case):\n        for placement in all_placement():\n            ndim = random(1, 4).to(int).value()\n            for sbp in all_sbp(placement, max_dim=min(ndim, 2)):\n                _test_narrow(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_ne.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_ne(test_case, placement, sbp):\n    x1 = random_tensor(ndim=2, dim0=8, dim1=8).to_global(placement, sbp)\n    x2 = random_tensor(ndim=2, dim0=8, dim1=8).to_global(placement, sbp)\n    return torch.ne(x1, x2)\n\n\nclass TestNe(flow.unittest.TestCase):\n    @globaltest\n    def test_ne(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_ne(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_negative.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_negative(test_case, placement, sbp, ndim):\n    shape = [8 for _ in range(ndim)]\n    x = random_tensor(ndim, *shape).to_global(placement, sbp)\n    return torch.negative(x)\n\n\nclass TestNegative(flow.unittest.TestCase):\n    @globaltest\n    def test_negative(test_case):\n        ndim = random(2, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_negative(test_case, placement, sbp, ndim)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_nms.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom test_nms import create_tensors_with_iou\nfrom test_nms import nms_np\n\n\ndef _test_nms(test_case, placement, sbp):\n    iou = 0.5\n    boxes, scores = create_tensors_with_iou(800, iou)\n\n    global_boxes = flow.tensor(boxes, dtype=flow.float32).to_global(\n        placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.broadcast\n    )\n    np_boxes = global_boxes.numpy()\n    global_boxes = global_boxes.to_global(placement=placement, sbp=sbp)\n\n    global_scores = flow.tensor(scores, dtype=flow.float32).to_global(\n        placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.broadcast\n    )\n    np_scores = global_scores.numpy()\n    global_scores = global_scores.to_global(placement=placement, sbp=sbp)\n\n    keep_np = nms_np(np_boxes, np_scores, iou)\n\n    keep = flow.nms(global_boxes, global_scores, iou)\n    test_case.assertTrue(np.allclose(keep.numpy(), keep_np))\n\n\nclass TestNMS(flow.unittest.TestCase):\n    @globaltest\n    def test_nms(test_case):\n        for placement in all_placement():\n            # TODO: nms only has cuda kernel at now.\n            if placement.type == \"cpu\":\n                continue\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_nms(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_normal.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgList, type_name_to_flow_type\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\n\n\ndef _test_global_normal(\n    test_case, placement, sbp, mean, std, shape, dtype, requires_grad\n):\n    dtype = type_name_to_flow_type[dtype]\n    x = flow.normal(\n        mean,\n        std,\n        shape,\n        placement=placement,\n        sbp=sbp,\n        dtype=dtype,\n        requires_grad=requires_grad,\n    )\n\n    test_case.assertEqual(x.shape, shape)\n    test_case.assertEqual(x.dtype, dtype)\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n    test_case.assertEqual(x.requires_grad, requires_grad)\n\n\nclass TestNormalGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_normal_global(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"mean\"] = [-1, 0, 1]\n        arg_dict[\"std\"] = [1, 2, 8]\n        arg_dict[\"shape\"] = [(8, 8), (8, 8, 8), (8, 8, 8, 8)]\n        arg_dict[\"dtype\"] = [\"float32\", \"double\"]\n        arg_dict[\"requires_grad\"] = [True, False]\n        for arg in GenArgList(arg_dict):\n            for placement in all_placement():\n                for sbp in all_sbp(\n                    placement, max_dim=len(arg[2]), except_partial_sum=True\n                ):\n                    _test_global_normal(test_case, placement, sbp, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_normalize.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_functional_normalize(test_case, placement, sbp):\n    ndim = random(low=2, high=5).to(int).value()\n    shape = [random(low=2, high=3) * 8 for i in range(ndim)]\n    x = random_tensor(len(shape), *shape).to_global(placement=placement, sbp=sbp)\n    dim = random(low=0, high=ndim).to(int).value()\n    y = torch.nn.functional.normalize(x, oneof(2, 3, 4), dim, 1e-12)\n    return y\n\n\nclass TestModule(flow.unittest.TestCase):\n    @globaltest\n    def test_normalize_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_functional_normalize(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_nozero.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n# Not check graph because of one reason:\n# Reason 1, lazy tensor cannot call numpy(), tensor.numpy() is not allowed to called in nn.Graph.build(*args) or called by lazy tensor.\n# Please refer to File \"python/oneflow/nn/modules/nonzero.py\", line 29, in nonzero_op.\n@autotest(n=1, auto_backward=False, check_graph=\"ValidatedFalse\")\ndef _test_nonzero(test_case, placement, sbp, ndim):\n    shape = [8 for _ in range(ndim)]\n    x = random_tensor(ndim, *shape).to_global(placement, sbp)\n    return torch.nonzero(x)\n\n\nclass TestNonZero(flow.unittest.TestCase):\n    @globaltest\n    def test_nonzero(test_case):\n        ndim = random(2, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_nonzero(test_case, placement, sbp, ndim)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_ones_like.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_ones_like_float(test_case, placement, sbp, shape, device):\n    x = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    x = x.to_global(placement=placement, sbp=sbp)\n    y = flow.ones_like(x, placement=placement, sbp=sbp)\n    test_case.assertTrue(y.dtype is flow.float32)\n    test_case.assertTrue(y.shape == x.shape)\n    test_case.assertTrue(y.placement == placement)\n    y_numpy = np.ones(x.numpy().shape)\n    print(\"y_numpy: \", y_numpy)\n    print(\"y.numpy()\", y.numpy())\n\n    test_case.assertTrue(np.array_equal(y.numpy(), y_numpy))\n\n\ndef _test_ones_like_int(test_case, placement, sbp, shape, device):\n    x = flow.tensor(np.random.randn(*shape), dtype=flow.int, device=flow.device(device))\n    x = x.to_global(placement=placement, sbp=sbp)\n    y = flow.ones_like(x, dtype=flow.int, placement=placement, sbp=sbp)\n    test_case.assertTrue(y.dtype is flow.int)\n    test_case.assertTrue(y.shape == x.shape)\n    test_case.assertTrue(y.placement == placement)\n    y_numpy = np.ones(x.numpy().shape)\n    test_case.assertTrue(np.array_equal(y.numpy(), y_numpy))\n\n\nclass TestModule(flow.unittest.TestCase):\n    @unittest.skip(\"TODO: global ones_like test will fail!\")\n    @globaltest\n    def test_ones_like(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_ones_like_float, _test_ones_like_int]\n        arg_dict[\"shape\"] = [(8, 8), (8, 8, 4), (8, 8, 5, 6)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            for placement in all_placement():\n                for sbp in all_sbp(placement, max_dim=2):\n                    arg[0](test_case, placement, sbp, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_pad.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow.unittest\n\n\n@autotest(n=1, check_graph=True)\ndef _test_pad_1d_impl(test_case, placement, sbp):\n    pad = [random(0, 5).to(int) for i in range(2)]\n    x = random_tensor(\n        ndim=3, dim0=8, dim1=random(2, 8).to(int) * 8, dim2=random(2, 8).to(int) * 8\n    ).to_global(placement=placement, sbp=sbp)\n    y = torch.nn.functional.pad(x, pad, mode=oneof(\"constant\", \"reflect\", \"replicate\"))\n    return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_pad_2d_impl(test_case, placement, sbp):\n    pad = [random(0, 5).to(int) for i in range(4)]\n    x = random_tensor(\n        ndim=4,\n        dim0=8,\n        dim1=8,\n        dim2=random(2, 8).to(int) * 8,\n        dim3=random(2, 8).to(int) * 8,\n    ).to_global(placement=placement, sbp=sbp)\n    y = torch.nn.functional.pad(x, pad, mode=oneof(\"constant\", \"reflect\", \"replicate\"))\n    return y\n\n\nclass TestPad(flow.unittest.TestCase):\n    @globaltest\n    def test_pad_1d(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_pad_1d_impl(test_case, placement, sbp)\n                _test_pad_2d_impl(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_partical_fc.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestParitalFC(flow.unittest.TestCase):\n    @globaltest\n    def test_parital_fc(test_case):\n        placement = flow.placement.all(\"cuda\")\n        w = flow.randn(5000, 128, placement=placement, sbp=flow.sbp.split(0))\n        label = flow.randint(\n            0, 5000, (512,), placement=placement, sbp=flow.sbp.split(0)\n        )\n        num_sample = 500\n        out = flow.distributed_partial_fc_sample(w, label, num_sample)\n        test_case.assertTrue(out[0].shape == flow.Size([512]))\n        test_case.assertTrue(out[1].shape == flow.Size([500]))\n        test_case.assertTrue(out[2].shape == flow.Size([500, 128]))\n\n        w = flow.randn(5000, 128, placement=placement, sbp=flow.sbp.broadcast)\n        label = flow.randint(\n            0, 5000, (512,), placement=placement, sbp=flow.sbp.split(0)\n        )\n        num_sample = 500\n        out = flow.distributed_partial_fc_sample(w, label, num_sample)\n        test_case.assertTrue(out[0].shape == flow.Size([512]))\n        test_case.assertTrue(out[1].shape == flow.Size([500]))\n        test_case.assertTrue(out[2].shape == flow.Size([500, 128]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_permute.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_permute4d_tensor_with_random_data(test_case, placement, sbp):\n    ndim = 4\n    permute_list = [1, 2, 3, 0]\n    x = random_tensor(\n        ndim=ndim, dim0=8, dim1=8, dim2=random(2, 8).to(int), dim3=random(2, 8).to(int),\n    ).to_global(placement=placement, sbp=sbp)\n    y = x.permute(permute_list)\n    return y\n\n\nclass TestModule(flow.unittest.TestCase):\n    @globaltest\n    def test_permute4d_tensor_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_permute4d_tensor_with_random_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_rand.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_global_rand(test_case, shape, placement, sbp):\n    x = flow.rand(*shape, placement=placement, sbp=sbp)\n\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\ndef _test_graph_rand(test_case, shape, placement, sbp):\n    class GlobalRandGraph(flow.nn.Graph):\n        def __init__(self,):\n            super().__init__()\n\n        def build(self):\n            x = flow.rand(*shape, placement=placement, sbp=sbp)\n            return x\n\n    model = GlobalRandGraph()\n    x = model()\n\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\nclass TestRandGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_rand_global(test_case):\n        shapes = [(8,), (8, 8,), (8, 8, 8)]\n        for shape in shapes:\n            for placement in all_placement():\n                for sbp in all_sbp(\n                    placement, max_dim=len(shape), except_partial_sum=True\n                ):\n                    _test_global_rand(test_case, shape, placement, sbp)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n2d()\n    def test_rand_graph(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(8,), (8, 8,), (8, 8, 8)]\n        arg_dict[\"placement\"] = [\n            # 1d\n            flow.placement(\"cpu\", ranks=[0, 1]),\n            flow.placement(\"cuda\", ranks=[0, 1]),\n            # 2d\n            flow.placement(\"cpu\", ranks=[[0, 1],]),\n            flow.placement(\"cuda\", ranks=[[0, 1],]),\n        ]\n        for args in GenArgDict(arg_dict):\n            shape = args[\"shape\"]\n            placement = args[\"placement\"]\n            for sbp in all_sbp(placement, max_dim=len(shape), except_partial_sum=True):\n                _test_graph_rand(test_case, shape, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_randint.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_global_randint(test_case, shape, placement, sbp, dtype):\n    x = flow.randint(1, 10, shape, placement=placement, sbp=sbp, dtype=dtype)\n\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n    test_case.assertEqual(x.dtype, dtype)\n\n\ndef _test_graph_randint(test_case, shape, placement, sbp, dtype):\n    class GlobalRandintGraph(flow.nn.Graph):\n        def __init__(self,):\n            super().__init__()\n\n        def build(self):\n            x = flow.randint(1, 10, shape, placement=placement, sbp=sbp, dtype=dtype)\n            return x\n\n    model = GlobalRandintGraph()\n    x = model()\n\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n    test_case.assertEqual(x.dtype, dtype)\n\n\nclass TestRandintGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_randint_global(test_case):\n        shapes = [(8,), (8, 8,), (8, 8, 8)]\n        dtypes = [\n            flow.uint8,\n            flow.int8,\n            flow.int32,\n            flow.int64,\n            flow.float32,\n            flow.float64,\n        ]\n        for shape in shapes:\n            for placement in all_placement():\n                for sbp in all_sbp(\n                    placement, max_dim=len(shape), except_partial_sum=True\n                ):\n                    for dtype in dtypes:\n                        _test_global_randint(test_case, shape, placement, sbp, dtype)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n2d()\n    def test_randint_graph(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(8,), (8, 8,), (8, 8, 8)]\n        arg_dict[\"dtype\"] = [\n            flow.uint8,\n            flow.int32,\n            flow.float32,\n        ]\n        arg_dict[\"placement\"] = [\n            # 1d\n            flow.placement(\"cpu\", ranks=[0, 1]),\n            flow.placement(\"cuda\", ranks=[0, 1]),\n            # 2d\n            flow.placement(\"cpu\", ranks=[[0, 1],]),\n            flow.placement(\"cuda\", ranks=[[0, 1],]),\n        ]\n        for args in GenArgDict(arg_dict):\n            shape = args[\"shape\"]\n            placement = args[\"placement\"]\n            dtype = args[\"dtype\"]\n            for sbp in all_sbp(placement, max_dim=len(shape), except_partial_sum=True):\n                _test_graph_randint(test_case, shape, placement, sbp, dtype)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_randint_like.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_consistent_randint_like(test_case, shape, placement, sbp, dtype):\n    x_ = flow.randint(1, 10, shape)\n    x = flow.randint_like(x_, 1, 10, placement=placement, sbp=sbp, dtype=dtype)\n\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n    test_case.assertEqual(x.dtype, dtype)\n\n\ndef _test_graph_randint_like(test_case, shape, placement, sbp, dtype):\n    class ConsistentRandIntLikeGraph(flow.nn.Graph):\n        def __init__(self,):\n            super().__init__()\n\n        def build(self):\n            x_ = flow.randint(1, 10, shape)\n            x = flow.randint_like(x_, 1, 10, placement=placement, sbp=sbp, dtype=dtype)\n            return x\n\n    model = ConsistentRandIntLikeGraph()\n    x = model()\n\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n    test_case.assertEqual(x.dtype, dtype)\n\n\nclass TestRandIntLikeConsistent(flow.unittest.TestCase):\n    @globaltest\n    def test_randint_like_consistent(test_case):\n        shapes = [(8,), (8, 8,), (8, 8, 8)]\n        dtypes = [\n            flow.uint8,\n            flow.int8,\n            flow.int32,\n            flow.int64,\n            flow.float32,\n            flow.float64,\n        ]\n        for shape in shapes:\n            for placement in all_placement():\n                for sbp in all_sbp(\n                    placement, max_dim=len(shape), except_partial_sum=True\n                ):\n                    for dtype in dtypes:\n                        _test_consistent_randint_like(\n                            test_case, shape, placement, sbp, dtype\n                        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n2d()\n    def test_randint_like_graph(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(8,), (8, 8,), (8, 8, 8)]\n        arg_dict[\"dtype\"] = [\n            flow.uint8,\n            flow.int32,\n            flow.float32,\n        ]\n        arg_dict[\"placement\"] = [\n            # 1d\n            flow.placement(\"cpu\", ranks=[0, 1]),\n            flow.placement(\"cuda\", ranks=[0, 1]),\n            # 2d\n            flow.placement(\"cpu\", ranks=[[0, 1],]),\n            flow.placement(\"cuda\", ranks=[[0, 1],]),\n        ]\n        for args in GenArgDict(arg_dict):\n            shape = args[\"shape\"]\n            placement = args[\"placement\"]\n            dtype = args[\"dtype\"]\n            for sbp in all_sbp(placement, max_dim=len(shape), 
except_partial_sum=True):\n                _test_graph_randint_like(test_case, shape, placement, sbp, dtype)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_randn.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport numpy as np\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_global_randn(test_case, shape, placement, sbp):\n    x1 = flow.randn(*shape, placement=placement, sbp=sbp)\n    x2 = flow.randn(*shape, placement=placement, sbp=sbp)\n    test_case.assertTrue(not np.allclose(x1.numpy(), x2.numpy(), atol=1e-4, rtol=1e-4))\n    test_case.assertEqual(x1.shape, flow.Size(shape))\n    test_case.assertEqual(x1.sbp, sbp)\n    test_case.assertEqual(x1.placement, placement)\n\n\ndef _test_different_dtype(test_case, shape, placement, sbp):\n    x1 = flow.randn(*shape, dtype=flow.float32, placement=placement, sbp=sbp)\n    x2 = flow.randn(*shape, dtype=flow.float64, placement=placement, sbp=sbp)\n    test_case.assertTrue(not np.allclose(x1.numpy(), x2.numpy(), atol=1e-4, rtol=1e-4))\n    test_case.assertEqual(x1.shape, flow.Size(shape))\n\n\ndef _test_backward(test_case, shape, placement, sbp):\n    x = flow.randn(*shape, placement=placement, sbp=sbp, requires_grad=True)\n    y = x.sum()\n    y.backward()\n    test_case.assertTrue(\n        np.allclose(np.ones(shape), x.grad.numpy(), atol=1e-4, rtol=1e-4)\n    )\n\n\ndef _test_with_generator(test_case, shape, placement, sbp):\n    gen = flow.Generator()\n    gen.manual_seed(0)\n    y1 = flow.randn(*shape, placement=placement, sbp=sbp, generator=gen)\n    gen.manual_seed(0)\n    y2 = flow.randn(*shape, placement=placement, sbp=sbp, generator=gen)\n    test_case.assertTrue(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4))\n\n\ndef _test_randn_tuple_shape(test_case, shape, placement, sbp):\n    y1 = flow.randn(*shape, placement=placement, sbp=sbp)\n    y2 = flow.randn(*shape, placement=placement, sbp=sbp)\n\n    test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy()))\n    test_case.assertTrue(shape == y1.shape)\n\n\ndef _test_graph_randn(test_case, shape, placement, sbp):\n    class GlobalRandnGraph(flow.nn.Graph):\n        def __init__(self,):\n            super().__init__()\n\n        def build(self):\n            x = flow.randn(*shape, placement=placement, sbp=sbp)\n            return x\n\n    model = GlobalRandnGraph()\n    x = model()\n\n    test_case.assertEqual(x.shape, flow.Size(shape))\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\nclass TestRandnGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_randn_global(test_case):\n        shapes = [(8,), (8, 8,), (8, 8, 8)]\n        for shape in shapes:\n            for placement in all_placement():\n                for sbp in all_sbp(\n                    placement, max_dim=len(shape), except_partial_sum=True\n                ):\n                    _test_global_randn(test_case, shape, placement, sbp)\n                    
_test_different_dtype(test_case, shape, placement, sbp)\n                    _test_backward(test_case, shape, placement, sbp)\n                    _test_with_generator(test_case, shape, placement, sbp)\n                    _test_randn_tuple_shape(test_case, shape, placement, sbp)\n\n    @flow.unittest.skip_unless_1n2d()\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @globaltest\n    def test_randn_graph(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(8,), (8, 8,), (8, 8, 8)]\n        arg_dict[\"placement\"] = [\n            # 1d\n            flow.placement(\"cpu\", ranks=[0, 1]),\n            flow.placement(\"cuda\", ranks=[0, 1]),\n            # 2d\n            flow.placement(\"cpu\", ranks=[[0, 1],]),\n            flow.placement(\"cuda\", ranks=[[0, 1],]),\n        ]\n        for args in GenArgDict(arg_dict):\n            shape = args[\"shape\"]\n            placement = args[\"placement\"]\n            for sbp in all_sbp(placement, max_dim=len(shape), except_partial_sum=True):\n                _test_graph_randn(test_case, shape, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_random_op_data.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport numpy as np\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\n_fn_param = {\n    \"normal\": lambda shape, placement, sbp: flow.normal(\n        size=shape, mean=0.0, std=1.0, placement=placement, sbp=sbp\n    ),\n    \"rand\": lambda shape, placement, sbp: flow.rand(\n        size=shape, placement=placement, sbp=sbp\n    ),\n    \"randint\": lambda shape, placement, sbp: flow.randint(\n        low=0, high=2, size=shape, placement=placement, sbp=sbp\n    ),\n    \"randn\": lambda shape, placement, sbp: flow.randn(\n        size=shape, placement=placement, sbp=sbp\n    ),\n}\n\n\ndef _test_data_consistent(test_case, shape, placement, sbp, fn):\n    # lazy result\n    class GlobalRandnGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self):\n            flow.manual_seed(233)\n            x = fn(shape, placement, sbp)\n            return x\n\n    model = GlobalRandnGraph()\n    lazy_x = model()\n\n    # eager result\n    flow.manual_seed(233)\n    eager_x = fn(shape, placement, sbp)\n\n    test_case.assertTrue(\n        np.array_equal(lazy_x.to_local().numpy(), eager_x.to_local().numpy())\n    )\n\n    # different data\n    eager_x2 = fn(shape, placement, sbp)\n\n    test_case.assertFalse(\n        np.array_equal(eager_x.to_local().numpy(), eager_x2.to_local().numpy())\n    )\n\n\nclass TestGlobalRandomOpData(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 4 times in past week\")\n    @globaltest\n    def test_random_op_data_consistent_with_eager_and_lazy(test_case):\n        shape = (8, 8)\n\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2, except_partial_sum=True):\n                for _, fn in _fn_param.items():\n                    _test_data_consistent(test_case, shape, placement, sbp, fn=fn)\n\n    @globaltest\n    @oneflow.unittest.skip_unless_1n4d()\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_random_op_data_correctness(test_case):\n        shape = (8, 8)\n        sbp = [flow.sbp.split(0), flow.sbp.broadcast]\n\n        for device in [\"cpu\", \"cuda\"]:\n            placement = flow.placement(device, [[0, 1], [2, 3]])\n\n            for _, fn in _fn_param.items():\n                flow.manual_seed(233)\n                local_tensor = fn(shape, placement, sbp).to_local().cpu()\n\n                # broadcast local data for each rank\n                rank_to_tensor = [\n                    local_tensor\n                    if rank_id == flow.env.get_rank()\n                    else flow.empty(local_tensor.shape, dtype=local_tensor.dtype)\n                    for rank_id in range(4)\n                ]\n           
     for rank_id in range(4):\n                    flow.comm.broadcast(rank_to_tensor[rank_id], rank_id)\n\n                np_local = [x.numpy() for x in rank_to_tensor]\n                # rank0 == rank1\n                test_case.assertTrue(np.array_equal(np_local[0], np_local[1]))\n                # rank2 == rank3\n                test_case.assertTrue(np.array_equal(np_local[2], np_local[3]))\n                # rank0 != rank2\n                test_case.assertFalse(np.array_equal(np_local[0], np_local[2]))\n                # rank1 != rank3\n                test_case.assertFalse(np.array_equal(np_local[1], np_local[3]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_randperm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport numpy as np\nfrom oneflow.test_utils.automated_test_util import *\n\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _test_global_randperm(test_case, N, placement, sbp, dtype):\n    x = flow.randperm(N, placement=placement, sbp=sbp, dtype=dtype)\n    # TODO:Synchronously get a global random seed, and then each rank sets its own seed in manual_seeds\n    test_case.assertEqual(x.dtype, dtype)\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\ndef _test_graph_randperm(test_case, N, placement, sbp, dtype):\n    class GlobalRandpermGraph(flow.nn.Graph):\n        def __init__(self,):\n            super().__init__()\n\n        def build(self):\n            x = flow.randperm(N, placement=placement, sbp=sbp, dtype=dtype)\n            return x\n\n    model = GlobalRandpermGraph()\n    x = model()\n    y1 = x.to_global(placement=placement, sbp=sbp)\n    y1_np_sort = np.sort(y1.numpy())\n    y2 = np.arange(N)\n    test_case.assertTrue(np.allclose(y1_np_sort, y2, atol=1e-4, rtol=1e-4))\n    test_case.assertEqual(x.dtype, dtype)\n    test_case.assertEqual(x.sbp, sbp)\n    test_case.assertEqual(x.placement, placement)\n\n\n@unittest.skip(\"This fails in multi-gpu\")\nclass TestRandpermGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_randperm_global(test_case):\n        RandNs = [i for i in range(10, 50, 10)]\n        # TODO support uint8,int8,int64,float32,float64,data type test\n        Dtypes = [\n            flow.int32,\n        ]\n        for N in RandNs:\n            for placement in all_placement():\n                for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True):\n                    for dtype in Dtypes:\n                        _test_global_randperm(test_case, N, placement, sbp, dtype)\n\n    @flow.unittest.skip_unless_1n2d()\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @globaltest\n    def test_randperm_graph(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"N\"] = [i for i in range(10, 50, 10)]\n        arg_dict[\"placement\"] = [\n            # 1d\n            flow.placement(\"cpu\", ranks=[0, 1]),\n            flow.placement(\"cuda\", ranks=[0, 1]),\n            # 2d\n            flow.placement(\"cpu\", ranks=[[0, 1],]),\n            flow.placement(\"cuda\", ranks=[[0, 1],]),\n        ]\n        arg_dict[\"dtype\"] = [\n            flow.uint8,\n            flow.int8,\n            flow.int32,\n            flow.int64,\n            flow.float32,\n            flow.float64,\n        ]\n        for args in GenArgDict(arg_dict):\n            N = args[\"N\"]\n            placement = args[\"placement\"]\n            dtype = args[\"dtype\"]\n            for sbp in all_sbp(placement, max_dim=1, except_partial_sum=True):\n                
_test_graph_randperm(test_case, N, placement, sbp, dtype)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_reciprocal.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_reciprocal_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    z = torch.reciprocal(y)\n    return z\n\n\nclass TestReciprocalGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_reciprocal(test_case):\n        # random ndim in range [1,4]\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_reciprocal_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_reflection_pad2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=False)\ndef _test_reflection_pad2d_impl(test_case, padding, placement, sbp):\n    m = torch.nn.ReflectionPad2d(padding=padding)\n    dims = [random(2, 4) * 8 for _ in range(4)]\n    x = random_tensor(4, *dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    z = m(y)\n    return z\n\n\nclass TestReflectionPad2dGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_reflection_pad2d(test_case):\n        padding = [\n            (2, 2, 1, 1),\n            1,\n            (1, 0, 1, 0),\n            (0, 1, 0, 1),\n        ]\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4):\n                for pad in padding:\n                    _test_reflection_pad2d_impl(test_case, pad, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_repeat.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_repeat_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 4).to(int).value() * 8 for _ in range(ndim)]\n    repeat_size = [random(1, 3).to(int).value() for _ in range(ndim)]\n    x = random_tensor(ndim, *dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    z = y.repeat(repeat_size)\n    return z\n\n\nclass TestRepeatGlobal(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    @globaltest\n    def test_repeat(test_case):\n        # random ndim in range [1,3]\n        ndim = random(1, 4).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_repeat_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_replication_pad2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=False)\ndef _test_replication_pad2d_impl(test_case, padding, placement, sbp):\n    m = torch.nn.ReplicationPad2d(padding=padding)\n    dims = [random(2, 4) * 8 for _ in range(4)]\n    x = random_tensor(4, *dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    z = m(y)\n    return z\n\n\nclass TestReplicationPad2dGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_replication_pad2d(test_case):\n        padding = [\n            (2, 2, 1, 1),\n            1,\n            (1, 0, 1, 0),\n            (0, 1, 0, 1),\n        ]\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4):\n                for pad in padding:\n                    _test_replication_pad2d_impl(test_case, pad, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_reshape.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_reshape_impl(test_case, pair, placement, sbp):\n    shape, to_shape = pair\n    x = random_tensor(len(shape), *shape)\n    y = x.to_global(placement=placement, sbp=sbp)\n    z = y.reshape(to_shape)\n    return z\n\n\ndef _test_reshape_like_impl(test_case, pair, placement, in_sbp, like_sbp):\n    shape, to_shape = pair\n\n    nd_arr = np.random.rand(*shape)\n    np_out = nd_arr.reshape(to_shape)\n\n    x = flow.tensor(nd_arr)\n    like = flow.empty(to_shape)\n    y = x.to_global(flow.placement.all(\"cpu\"), flow.sbp.broadcast).to_global(\n        placement=placement, sbp=in_sbp\n    )\n    like = like.to_global(flow.placement.all(\"cpu\"), flow.sbp.broadcast).to_global(\n        placement=placement, sbp=like_sbp\n    )\n    z = flow._C.reshape_like(y, like)\n    local_z = z.to_global(\n        placement, sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))]\n    ).to_local()\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(np.array_equal(np_out, local_z.numpy()))\n\n\nclass TestReshapeGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_reshape(test_case):\n        shape_pairs = [\n            ((8, 16), (8 * 16,)),\n            ((8, 16), (8 * 4, 4)),\n            ((8, 16, 24), (64, 6, 8)),\n            ((8, 16), (64, 1, -1)),\n            ((8, 16), (-1,)),\n        ]\n        for pair in shape_pairs:\n            for placement in all_placement():\n                for sbp in all_sbp(placement, max_dim=len(pair[0])):\n                    _test_reshape_impl(test_case, pair, placement, sbp)\n\n    @globaltest\n    def test_reshape_like(test_case):\n        shape_pairs = [\n            ((8, 16), (8 * 16,)),\n            ((8, 16), (8 * 2, 8)),\n            ((8, 16, 24), (64, 48)),\n        ]\n        for pair in shape_pairs:\n            for placement in all_placement():\n                for in_sbp in all_sbp(placement, max_dim=len(pair[0])):\n                    for like_sbp in all_sbp(placement, max_dim=len(pair[1])):\n                        _test_reshape_like_impl(\n                            test_case, pair, placement, in_sbp, like_sbp\n                        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_rnn.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nimport torch\nfrom oneflow.test_utils.automated_test_util.generators import *\nfrom oneflow.test_utils.automated_test_util.torch_flow_dual_object import globaltest\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef _compare_torch_and_oneflow(\n    test_case, m_torch, m_flow, placement, module_sbp, in_sbp, input_size\n):\n    torch_state_dict = m_torch.state_dict()\n    new_dict = {}\n    for k, v in torch_state_dict.items():\n        new_dict[k] = v.detach().numpy()\n    m_flow.load_state_dict(new_dict)\n\n    m_flow = m_flow.to_global(flow.placement.all(\"cpu\"), flow.sbp.broadcast).to_global(\n        placement=placement, sbp=[module_sbp for _ in range(len(placement.ranks.shape))]\n    )\n\n    x = np.random.rand(32, 16, input_size).astype(np.float32)\n    x_torch = torch.tensor(x, dtype=torch.float32, requires_grad=True)\n    x_flow = (\n        flow.tensor(x, dtype=flow.float32, requires_grad=True)\n        .to_global(flow.placement.all(\"cpu\"), flow.sbp.broadcast)\n        .to_global(placement=placement, sbp=in_sbp)\n    )\n\n    out_torch, hid_torch = m_torch(x_torch)\n    out_flow, hid_flow = m_flow(x_flow)\n\n    # check forward\n    local_output = out_flow.to_global(\n        placement=placement,\n        sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))],\n    ).to_local()\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.allclose(\n                out_torch.cpu().detach().numpy(),\n                local_output.numpy(),\n                rtol=1e-05,\n                atol=1e-05,\n            )\n        )\n\n    # check backward\n    out_torch.sum().backward()\n    out_flow.sum().backward()\n    local_x_grad = x_flow.to_global(\n        placement=placement,\n        sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))],\n    ).to_local()\n    if flow.env.get_rank() == 0:\n        test_case.assertTrue(\n            np.allclose(\n                x_torch.cpu().detach().numpy(),\n                local_x_grad.numpy(),\n                rtol=1e-05,\n                atol=1e-05,\n            )\n        )\n\n\ndef _test_rnn_impl(\n    test_case,\n    placement,\n    module_sbp,\n    in_sbp,\n    input_size,\n    hidden_size,\n    num_layers,\n    nonlinearity,\n    bias,\n    batch_first,\n    dropout,\n    bidirectional,\n):\n    rnn_torch = torch.nn.RNN(\n        input_size=input_size,\n        hidden_size=hidden_size,\n        num_layers=num_layers,\n        nonlinearity=nonlinearity,\n        bias=bias,\n        batch_first=batch_first,\n        dropout=dropout,\n        bidirectional=bidirectional,\n    )\n    rnn_flow = flow.nn.RNN(\n        input_size=input_size,\n        hidden_size=hidden_size,\n        num_layers=num_layers,\n        nonlinearity=nonlinearity,\n        
bias=bias,\n        batch_first=batch_first,\n        dropout=dropout,\n        bidirectional=bidirectional,\n    )\n    _compare_torch_and_oneflow(\n        test_case, rnn_torch, rnn_flow, placement, module_sbp, in_sbp, input_size\n    )\n\n\ndef _test_lstm_impl(\n    test_case,\n    placement,\n    module_sbp,\n    in_sbp,\n    input_size,\n    hidden_size,\n    num_layers,\n    bias,\n    batch_first,\n    dropout,\n    bidirectional,\n    proj_size,\n):\n    lstm_torch = torch.nn.LSTM(\n        input_size=input_size,\n        hidden_size=hidden_size,\n        num_layers=num_layers,\n        bias=bias,\n        batch_first=batch_first,\n        dropout=dropout,\n        bidirectional=bidirectional,\n        proj_size=proj_size,\n    )\n    lstm_flow = flow.nn.LSTM(\n        input_size=input_size,\n        hidden_size=hidden_size,\n        num_layers=num_layers,\n        bias=bias,\n        batch_first=batch_first,\n        dropout=dropout,\n        bidirectional=bidirectional,\n        proj_size=proj_size,\n    )\n    _compare_torch_and_oneflow(\n        test_case, lstm_torch, lstm_flow, placement, module_sbp, in_sbp, input_size\n    )\n\n\ndef _test_gru_impl(\n    test_case,\n    placement,\n    module_sbp,\n    in_sbp,\n    input_size,\n    hidden_size,\n    num_layers,\n    bias,\n    batch_first,\n    dropout,\n    bidirectional,\n):\n    gru_torch = torch.nn.GRU(\n        input_size=input_size,\n        hidden_size=hidden_size,\n        num_layers=num_layers,\n        bias=bias,\n        batch_first=batch_first,\n        dropout=dropout,\n        bidirectional=bidirectional,\n    )\n    gru_flow = flow.nn.GRU(\n        input_size=input_size,\n        hidden_size=hidden_size,\n        num_layers=num_layers,\n        bias=bias,\n        batch_first=batch_first,\n        dropout=dropout,\n        bidirectional=bidirectional,\n    )\n    _compare_torch_and_oneflow(\n        test_case, gru_torch, gru_flow, placement, module_sbp, in_sbp, input_size\n    )\n\n\nclass TestRNNGlobal(oneflow.unittest.TestCase):\n    @globaltest\n    def test_rnn(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"input_size\"] = [\n            1,\n        ]\n        arg_dict[\"hidden_size\"] = [\n            1,\n        ]\n        arg_dict[\"num_layers\"] = [\n            1,\n        ]\n        arg_dict[\"nonlinearity\"] = [\"tanh\", \"relu\"]\n        arg_dict[\"bias\"] = [True, False]\n        arg_dict[\"batch_first\"] = [True, False]\n        arg_dict[\"dropout\"] = [\n            0,\n        ]\n        arg_dict[\"bidirectional\"] = [True, False]\n\n        module_sbp = flow.sbp.broadcast\n        for args in GenArgDict(arg_dict):\n            for placement in all_placement():\n                for in_sbp in all_sbp(placement, max_dim=3, valid_split_axis=1):\n                    _test_rnn_impl(test_case, placement, module_sbp, in_sbp, **args)\n\n    @globaltest\n    def test_lstm(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"input_size\"] = [\n            1,\n        ]\n        arg_dict[\"hidden_size\"] = [\n            2,\n        ]\n        arg_dict[\"num_layers\"] = [\n            1,\n        ]\n        arg_dict[\"bias\"] = [True, False]\n        arg_dict[\"batch_first\"] = [True, False]\n        arg_dict[\"dropout\"] = [\n            0,\n        ]\n        arg_dict[\"bidirectional\"] = [True, False]\n        arg_dict[\"proj_size\"] = [0, 1]\n\n        module_sbp = flow.sbp.broadcast\n        for args in GenArgDict(arg_dict):\n            for placement in 
all_placement():\n                for in_sbp in all_sbp(placement, max_dim=3, valid_split_axis=1):\n                    _test_lstm_impl(test_case, placement, module_sbp, in_sbp, **args)\n\n    @globaltest\n    def test_gru(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"input_size\"] = [\n            1,\n        ]\n        arg_dict[\"hidden_size\"] = [\n            1,\n        ]\n        arg_dict[\"num_layers\"] = [\n            1,\n        ]\n        arg_dict[\"bias\"] = [True, False]\n        arg_dict[\"batch_first\"] = [True, False]\n        arg_dict[\"dropout\"] = [\n            0,\n        ]\n        arg_dict[\"bidirectional\"] = [True, False]\n\n        module_sbp = flow.sbp.broadcast\n        for args in GenArgDict(arg_dict):\n            for placement in all_placement():\n                for in_sbp in all_sbp(placement, max_dim=3, valid_split_axis=1):\n                    _test_gru_impl(test_case, placement, module_sbp, in_sbp, **args)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_rnn_cell.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=False)\ndef _test_rnn_relu_cell(test_case, placement, sbp):\n    batch_size = random(2, 3) * 8\n    time_steps = random(2, 3) * 8\n    input_size = random(2, 3) * 8\n    hidden_size = random(2, 3) * 8\n    has_bias = random().to(bool)\n    m = torch.nn.RNNCell(\n        input_size=input_size,\n        hidden_size=hidden_size,\n        bias=has_bias,\n        nonlinearity=\"relu\",\n    )\n\n    weight_sbp = random_sbp(placement, max_dim=2, except_partial_sum=True)\n    m.weight_ih = torch.nn.Parameter(\n        m.weight_ih.to_global(placement=placement, sbp=weight_sbp)\n    )\n    m.weight_hh = torch.nn.Parameter(\n        m.weight_hh.to_global(placement=placement, sbp=weight_sbp)\n    )\n    if m.bias_ih is not None:\n        # bias is 1-d tensor\n        bias_sbp = random_sbp(placement, max_dim=1, except_partial_sum=True)\n        m.bias_ih = torch.nn.Parameter(\n            m.bias_ih.to_global(placement=placement, sbp=bias_sbp)\n        )\n        m.bias_hh = torch.nn.Parameter(\n            m.bias_hh.to_global(placement=placement, sbp=bias_sbp)\n        )\n\n    input_sbp = random_sbp(placement, max_dim=3, valid_split_axis=1)\n    input = random_tensor(\n        ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size\n    ).to_global(placement=placement, sbp=input_sbp)\n    hx = random_tensor(\n        ndim=2, dim0=batch_size, dim1=hidden_size, requires_grad=False\n    ).to_global(placement=placement, sbp=sbp)\n\n    for i in range(time_steps.to(int).value()):\n        hx = m(input[i], hx)\n\n    return hx\n\n\n@autotest(n=1, check_graph=False)\ndef _test_rnn_tanh_cell(test_case, placement, sbp):\n    batch_size = random(2, 3) * 8\n    time_steps = random(2, 3) * 8\n    input_size = random(2, 3) * 8\n    hidden_size = random(2, 3) * 8\n    has_bias = random().to(bool)\n    m = torch.nn.RNNCell(\n        input_size=input_size,\n        hidden_size=hidden_size,\n        bias=has_bias,\n        nonlinearity=\"tanh\",\n    )\n\n    weight_sbp = random_sbp(placement, max_dim=2, except_partial_sum=True)\n    m.weight_ih = torch.nn.Parameter(\n        m.weight_ih.to_global(placement=placement, sbp=weight_sbp)\n    )\n    m.weight_hh = torch.nn.Parameter(\n        m.weight_hh.to_global(placement=placement, sbp=weight_sbp)\n    )\n    if m.bias_ih is not None:\n        # bias is 1-d tensor\n        bias_sbp = random_sbp(placement, max_dim=1, except_partial_sum=True)\n        m.bias_ih = torch.nn.Parameter(\n            m.bias_ih.to_global(placement=placement, sbp=bias_sbp)\n        )\n        m.bias_hh = torch.nn.Parameter(\n            m.bias_hh.to_global(placement=placement, sbp=bias_sbp)\n        )\n\n    input_sbp = random_sbp(placement, max_dim=3, valid_split_axis=1)\n    input = random_tensor(\n  
      ndim=3, dim0=time_steps, dim1=batch_size, dim2=input_size\n    ).to_global(placement=placement, sbp=input_sbp)\n    hx = random_tensor(\n        ndim=2, dim0=batch_size, dim1=hidden_size, requires_grad=False\n    ).to_global(placement=placement, sbp=sbp)\n\n    for i in range(time_steps.to(int).value()):\n        hx = m(input[i], hx)\n\n    return hx\n\n\n@unittest.skip(\"TODO(depeng): fails often on 4 GPUs\")\nclass TestRNNCellGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_rnn_relu_cell(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_rnn_relu_cell(test_case, placement, sbp)\n\n    @globaltest\n    def test_rnn_tanh_cell(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_rnn_tanh_cell(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_roi_align.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nimport torch as pytorch\nimport torchvision\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _get_np_rois():\n    random_img_idx = np.asarray(\n        [random(0, 2).to(int).value() for _ in range(200)]\n    ).reshape((200, 1))\n    random_box_idx = np.asarray(\n        [random(0, 64 * 64).to(float).value() for _ in range(400)]\n    ).reshape((200, 2))\n\n    def get_h_w(idx1, idx2):\n        if idx1 > idx2:\n            idx1, idx2 = idx2, idx1\n        h1 = idx1 // 64\n        w1 = idx1 % 64\n        h2 = idx2 // 64\n        w2 = idx2 % 64\n        return [x / 2 for x in [h1, w1, h2, w2]]\n\n    zipped = zip(random_box_idx[:, 0], random_box_idx[:, 1])\n    concated = [get_h_w(idx1, idx2) for (idx1, idx2) in zipped]\n    concated = np.array(concated)\n    rois = np.hstack((random_img_idx, concated)).astype(np.float32)\n    return rois\n\n\ndef _test_roi_align(test_case, placement, rois_sbp):\n    dims = [8, 8, 64, 64]\n    x = random_tensor(4, *dims).to_global(\n        placement=placement,\n        sbp=[flow.sbp.broadcast for _ in range(len(placement.ranks.shape))],\n    )\n    x.oneflow = x.oneflow.detach().requires_grad_()\n    x.pytorch = x.pytorch.detach().requires_grad_()\n\n    def get_h_w(idx1, idx2):\n        if idx1 > idx2:\n            idx1, idx2 = idx2, idx1\n        h1 = idx1 // 64\n        w1 = idx1 % 64\n        h2 = idx2 // 64\n        w2 = idx2 % 64\n        return [x / 2 for x in [h1, w1, h2, w2]]\n\n    np_rois = _get_np_rois()\n    of_rois = (\n        flow.tensor(np_rois, dtype=flow.float)\n        .to_global(placement=flow.placement.all(\"cpu\"), sbp=[flow.sbp.broadcast,])\n        .to_global(placement, rois_sbp)\n    )\n    torch_rois = pytorch.tensor(np_rois)\n\n    of_out = flow.roi_align(x.oneflow, of_rois, 2.0, 14, 14, 2, True)\n    torch_out = torchvision.ops.roi_align(\n        x.pytorch,\n        torch_rois,\n        spatial_scale=2.0,\n        output_size=[14, 14],\n        sampling_ratio=2,\n        aligned=True,\n    )\n\n    # compare output\n    of_local = of_out.to_global(\n        placement=flow.placement.all(\"cpu\"), sbp=[flow.sbp.broadcast,]\n    ).to_local()\n    test_case.assertTrue(\n        np.allclose(\n            of_local.numpy(), torch_out.detach().cpu().numpy(), rtol=1e-04, atol=1e-4\n        )\n    )\n\n    # compare backward\n    of_out.sum().backward()\n    torch_out.sum().backward()\n    of_input_grad = x.oneflow.grad.to_global(\n        placement=flow.placement.all(\"cpu\"), sbp=[flow.sbp.broadcast,]\n    ).to_local()\n    torch_input_grad = x.pytorch.grad.detach().cpu()\n    test_case.assertTrue(\n        np.allclose(\n            of_input_grad.numpy(), torch_input_grad.numpy(), rtol=1e-04, atol=1e-4\n        )\n    )\n\n\ndef _test_roi_align_in_fixed_data_impl(test_case, placement, sbp):\n    from test_roi_align 
import input_np, rois_np, input_grad_np\n\n    input = (\n        flow.tensor(input_np, dtype=flow.float32)\n        .to_global(flow.placement.all(\"cpu\"), [flow.sbp.broadcast,])\n        .to_global(placement, sbp)\n        .requires_grad_()\n    )\n    rois = (\n        flow.tensor(rois_np, dtype=flow.float32)\n        .to_global(flow.placement.all(\"cpu\"), [flow.sbp.broadcast,])\n        .to_global(\n            placement, [flow.sbp.broadcast for _ in range(len(placement.ranks.shape))]\n        )\n    )\n    of_out = flow.roi_align(input, rois, 2.0, 5, 5, 2, True)\n    of_out.sum().backward()\n    test_case.assertTrue(\n        np.allclose(input.grad.numpy(), input_grad_np, rtol=1e-04, atol=1e-4)\n    )\n\n\nclass TestGlobalRoiAlign(flow.unittest.TestCase):\n    # TODO(wyg): There is a bug in pytorch-1.9.0, torchvision-0.10.0 and python3.7.10.\n    #            Re-enable this test after updating the version of pytorch in CI.\n\n    #  @globaltest\n    #  def test_global_roi_align(test_case):\n    #      for placement in all_placement():\n    #          # TODO: roi_align only supports GPU\n    #          if placement.type == \"cpu\":\n    #              continue\n    #          for rois_sbp in all_sbp(placement, max_dim=0, except_partial_sum=True):\n    #              _test_roi_align(test_case, placement, rois_sbp)\n\n    def test_global_roi_align_in_fixed_data(test_case):\n        for placement in all_placement():\n            # TODO: roi_align only supports GPU\n            if placement.type == \"cpu\":\n                continue\n            for rois_sbp in all_sbp(placement, max_dim=0, except_partial_sum=True):\n                _test_roi_align_in_fixed_data_impl(test_case, placement, rois_sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_roll.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_roll_impl(test_case, placement, sbp):\n    shifts = (\n        random(-100, 100).to(int).value(),\n        random(-100, 100).to(int).value(),\n        random(-100, 100).to(int).value(),\n        random(-100, 100).to(int).value(),\n    )\n    dims = (0, 1, 2, 3)\n    x_dims = [random(2, 4) * 8 for _ in range(4)]\n    x = random_tensor(4, *x_dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    z = torch.roll(y, shifts, dims)\n    return z\n\n\nclass TestRollGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_roll(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4):\n                _test_roll_impl(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_round.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=False)\ndef _test_round_impl(test_case, ndim, placement, sbp):\n    x_dims = [random(2, 4) * 8 for _ in range(ndim)]\n    x = random_tensor(ndim, *x_dims)\n    y = x.to_global(placement=placement, sbp=sbp)\n    z = torch.round(y)\n    return z\n\n\nclass TestRoundGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_round(test_case):\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_round_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_scatter_nd.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_scatter_nd(test_case, placement, sbp):\n    indices = (\n        flow.tensor(np.array([[1], [6], [4]]), dtype=flow.int)\n        .to_global(flow.placement.all(\"cpu\"), [flow.sbp.broadcast,])\n        .to_global(placement, sbp)\n    )\n    update = (\n        flow.tensor(np.array([10.2, 5.1, 12.7]), dtype=flow.float)\n        .to_global(flow.placement.all(\"cpu\"), [flow.sbp.broadcast,])\n        .to_global(placement, sbp)\n        .requires_grad_()\n    )\n    output = flow.scatter_nd(indices, update, [8])\n\n    # forward\n    of_local = output.to_global(\n        flow.placement.all(\"cpu\"), [flow.sbp.broadcast,]\n    ).to_local()\n    np_out = np.array([0.0, 10.2, 0.0, 0.0, 12.7, 0.0, 5.1, 0.0])\n    test_case.assertTrue(np.allclose(of_local.numpy(), np_out, 1e-4, 1e-4))\n\n    # backward\n    output.sum().backward()\n    of_grad_local = update.grad.to_global(\n        flow.placement.all(\"cpu\"), [flow.sbp.broadcast,]\n    ).to_local()\n    test_case.assertTrue(np.allclose(of_grad_local.numpy(), np.ones((3)), 1e-4, 1e-4))\n\n\nclass TestScatterNd(flow.unittest.TestCase):\n    @globaltest\n    def test_scatter_nd(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, except_partial_sum=True, except_split=True):\n                _test_scatter_nd(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_scatter_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=10, auto_backward=True, check_graph=True)\ndef _test_scatter_random_data(test_case, placement):\n    input = random_tensor(ndim=2, dim0=2, dim1=2).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    src = random_tensor(ndim=2, dim0=2, dim1=2).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    index = (\n        torch.tensor(np.array([[0, 1], [1, 0]]), dtype=torch.int64)\n        .to_global(flow.placement.all(\"cpu\"), [flow.sbp.broadcast,])\n        .to_global(placement, sbp=random_sbp(placement, max_dim=2),)\n    )\n    dim = random(0, 2).to(int).value()\n    return torch.scatter(input, dim, index, src)\n\n\n@autotest(n=10, auto_backward=True, check_graph=True)\ndef _test_scatter_scalar_random_data(test_case, placement):\n    input = random_tensor(ndim=2, dim0=2, dim1=2).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    index = (\n        torch.tensor(np.array([[0, 1], [1, 0]]), dtype=torch.int64)\n        .to_global(flow.placement.all(\"cpu\"), [flow.sbp.broadcast,])\n        .to_global(placement, sbp=random_sbp(placement, max_dim=2),)\n    )\n    dim = random(0, 2).to(int).value()\n    return torch.scatter(input, dim, index, 3.14)\n\n\n@autotest(n=10, auto_backward=True, check_graph=True)\ndef _test_scatter_add_random_data(test_case, placement):\n    input = random_tensor(ndim=2, dim0=2, dim1=2).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    src = random_tensor(ndim=2, dim0=2, dim1=2).to_global(\n        placement=placement, sbp=random_sbp(placement, max_dim=2)\n    )\n    index = (\n        torch.tensor(np.array([[0, 1], [1, 0]]), dtype=torch.int64)\n        .to_global(flow.placement.all(\"cpu\"), [flow.sbp.broadcast,])\n        .to_global(placement, sbp=random_sbp(placement, max_dim=2),)\n    )\n    dim = random(0, 2).to(int).value()\n    return torch.scatter_add(input, dim, index, src)\n\n\n@flow.unittest.skip_unless_1n2d()\nclass TestScatterOps(flow.unittest.TestCase):\n    @globaltest\n    def test_scatter_ops(test_case):\n        for placement in all_placement():\n            _test_scatter_random_data(test_case, placement)\n            _test_scatter_scalar_random_data(test_case, placement)\n            _test_scatter_add_random_data(test_case, placement)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_searchsorted.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, auto_backward=False, check_graph=False)\ndef _test_search_sorted(test_case, placement, sbp, ndim):\n    dims = [random(1, 3) * 8 for _ in range(ndim)]\n    sorted_sequence = random_tensor(ndim, *dims).to_global(placement, sbp)\n    values = random_tensor(ndim, *dims).to_global(placement, sbp)\n    y = torch.searchsorted(\n        sorted_sequence, values, out_int32=oneof(True, False), right=oneof(True, False),\n    )\n    return y\n\n\nclass TestSearchSorted_Global(flow.unittest.TestCase):\n    @globaltest\n    def test_search_sorted(test_case):\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_search_sorted(test_case, placement, sbp, ndim)\n\n\n@autotest(n=1, auto_backward=False, check_graph=False)\ndef _test_search_sorted_scalar(test_case, placement, sbp):\n    dim0 = [random(1, 3) * 8]\n    sorted_sequence = random_tensor(1, *dim0).to_global(placement, sbp)\n    values = 5\n    y = torch.searchsorted(\n        sorted_sequence, values, out_int32=oneof(True, False), right=oneof(True, False),\n    )\n    return y\n\n\nclass TestSearchSortedScalar_Global(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 8 times in past week\")\n    @globaltest\n    def test_search_sorted_scalar(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_search_sorted_scalar(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_sign.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow.unittest\n\n\n@autotest(n=1, check_graph=True)\ndef _test_sign_impl(test_case, ndim, placement, sbp):\n    dims = [random(1, 3) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims).to_global(placement=placement, sbp=sbp)\n    y = torch.sign(x)\n    return y\n\n\nclass TestSign(flow.unittest.TestCase):\n    @globaltest\n    def test_sign(test_case):\n        for placement in all_placement():\n            ndim = random(1, 4).to(int).value()\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_sign_impl(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_slice.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _check_forward_and_backward(test_case, input, of_out, torch_out):\n    # compare forward\n    test_case.assertTrue(\n        np.array_equal(of_out.numpy(), torch_out.cpu().detach().numpy())\n    )\n\n    # compare backward\n    of_out.sum().backward()\n    torch_out.sum().backward()\n    torch_grad_local = input.pytorch.grad.cpu().detach()\n    test_case.assertTrue(\n        np.array_equal(input.oneflow.grad.numpy(), torch_grad_local.numpy())\n    )\n\n\ndef _test_slice_random_data(test_case, placement, sbp):\n    dims = [random(1, 2) * 8 for _ in range(2)]\n    input = random_tensor(2, *dims)\n    x = input.to_global(placement=placement, sbp=sbp)\n    slice_tup_list = [[None, None, None], [0, 5, 2]]\n    of_out = flow.slice(x.oneflow, slice_tup_list=slice_tup_list)\n    torch_out = x.pytorch[:, 0:5:2]\n\n    _check_forward_and_backward(test_case, input, of_out, torch_out)\n\n\ndef _test_slice_empty(test_case, placement, sbp):\n    dims = [random(1, 2) * 8 for _ in range(2)]\n    input = random_tensor(2, *dims)\n    x = input.to_global(placement=placement, sbp=sbp)\n    slice_tup_list = [[3, 3, 1], [None, None, None]]\n    of_out = flow.slice(x.oneflow, slice_tup_list=slice_tup_list)\n    torch_out = x.pytorch[3:3:1, :]\n\n    _check_forward_and_backward(test_case, input, of_out, torch_out)\n\n\ndef _test_slice_1dim(test_case, placement, sbp):\n    dims = [random(1, 2) * 8 for _ in range(2)]\n    input = random_tensor(2, *dims)\n    x = input.to_global(placement=placement, sbp=sbp)\n    of_out = x.oneflow[2]\n    torch_out = x.pytorch[2]\n\n    _check_forward_and_backward(test_case, input, of_out, torch_out)\n\n\ndef _test_negative_index(test_case, placement, sbp):\n    dims = [random(1, 2) * 8 for _ in range(2)]\n    input = random_tensor(2, *dims)\n    x = input.to_global(placement=placement, sbp=sbp)\n    of_out = x.oneflow[-1:-6:1, :]\n    torch_out = x.pytorch[-1:-6:1, :]\n\n    _check_forward_and_backward(test_case, input, of_out, torch_out)\n\n\ndef _test_slice_ellipsis_type(test_case, placement, sbp):\n    dims = [random(1, 2) * 8 for _ in range(2)]\n    input = random_tensor(2, *dims)\n    x = input.to_global(placement=placement, sbp=sbp)\n    of_out = x.oneflow[..., :]\n    torch_out = x.pytorch[..., :]\n\n    _check_forward_and_backward(test_case, input, of_out, torch_out)\n\n\ndef _test_slice_with_bool(test_case, placement, sbp):\n    x = random_tensor(2, 8, 8).oneflow > 0.5\n    x_numpy = x.detach().cpu().numpy()\n\n    x = x.to_global(placement=placement, sbp=sbp)\n    y = flow.slice(x, slice_tup_list=[[0, 1, 1]])\n\n    test_case.assertTrue(np.array_equal(y.numpy(), x_numpy[0:1:1]))\n\n\n@autotest(\n    n=2, auto_backward=False, check_graph=True,\n)\ndef _test_slice_with_grad(test_case, placement):\n    sbp 
= random_sbp(placement, max_dim=2).value()\n\n    # out_sbp\n    sbp_map = {\n        flow.sbp.broadcast: flow.sbp.broadcast,\n        flow.sbp.split(0): flow.sbp.split(0),\n        flow.sbp.split(1): flow.sbp.partial_sum(),\n        flow.sbp.partial_sum: flow.sbp.partial_sum(),\n    }\n    assert sbp is not None\n    out_sbp = tuple([sbp_map[in_sbp] for in_sbp in sbp])\n\n    x = random_tensor(2, 8, 16, requires_grad=True).oneflow\n    x_numpy = x.detach().cpu().numpy()\n\n    class SliceWithGrad(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.input_grad = flow.nn.Parameter(flow.zeros(8, 16))\n\n        def forward(self, input):\n            x = input + self.input_grad\n            x = x.to_global(placement, sbp)\n            return x[:, :8]\n\n    slice_with_grad_m = SliceWithGrad().to_global(\n        placement, [flow.sbp.broadcast,] * len(sbp)\n    )\n\n    of_sgd = flow.optim.SGD(slice_with_grad_m.parameters(), lr=1.0, momentum=0.0)\n\n    class SliceTrainGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.module = slice_with_grad_m\n            self.add_optimizer(of_sgd)\n\n        def build(self, x):\n            out = self.module(x)\n            test_case.assertEqual(\n                out.sbp,\n                out_sbp,\n                f\"input sbp is {sbp}, but output sbp is {out.sbp} with placement: {placement}\",\n            )\n            z = out.sum()\n            z.backward()\n            return out\n\n    graph = SliceTrainGraph()\n\n    input = x.to_global(placement=placement, sbp=sbp)\n    y = graph(input)\n\n    # output\n    test_case.assertTrue(np.array_equal(y.numpy(), x_numpy[:, :8]))\n    # input_grad\n    x_grad_np = np.zeros((8, 16))\n    x_grad_np[:, :8] = 1\n    test_case.assertTrue(\n        np.array_equal(-graph.module.input_grad.to(flow.Tensor).numpy(), x_grad_np)\n    )\n\n\nclass TestSlice(flow.unittest.TestCase):\n    @globaltest\n    def test_slice(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_slice_random_data(test_case, placement, sbp)\n                _test_slice_empty(test_case, placement, sbp)\n                _test_slice_1dim(test_case, placement, sbp)\n                _test_negative_index(test_case, placement, sbp)\n                _test_slice_ellipsis_type(test_case, placement, sbp)\n                _test_slice_with_bool(test_case, placement, sbp)\n\n    @unittest.skip(\"skip for now, because it failed 12 times in the past week\")\n    @globaltest\n    def test_graph_slice(test_case):\n        for placement in all_placement():\n            _test_slice_with_grad(test_case, placement)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_slice_update.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_slice_update(test_case, placement, sbp):\n    input = random_tensor(2, 8, 16, requires_grad=True).oneflow\n    value = random_tensor(2, 8, 8, requires_grad=True).oneflow\n    x = (input + 0).to_global(\n        placement=placement, sbp=sbp\n    )  # add 0 to change to non-leaf tensor\n    y = value.to_global(placement, sbp=sbp)\n    x[:, :8] = y\n\n    ref_np = input.detach().cpu().numpy()\n    value_np = value.detach().cpu().numpy()\n\n    # forward\n    ref_np[:, :8] = value_np\n    test_case.assertTrue(x.sbp == sbp)\n    test_case.assertTrue(np.array_equal(x.numpy(), ref_np))\n\n    # backward\n    x.sum().backward()\n    # ref grad\n    ref_grad_np = np.ones((8, 16))\n    ref_grad_np[:, :8] = 0\n    test_case.assertTrue(np.array_equal(input.grad.numpy(), ref_grad_np))\n    # value grad\n    value_grad_np = np.ones((8, 8))\n    test_case.assertTrue(np.array_equal(value.grad.numpy(), value_grad_np))\n\n\ndef _test_graph_slice_update(test_case, placement, sbp):\n    ref = random_tensor(2, 8, 16, requires_grad=True).oneflow\n    value = random_tensor(2, 8, 8, requires_grad=True).oneflow\n\n    class SliceUpdateWithGrad(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.ref_grad = flow.nn.Parameter(flow.zeros(8, 16))\n            self.value_grad = flow.nn.Parameter(flow.zeros(8, 8))\n\n        def forward(self, ref, value):\n            x = ref + self.ref_grad\n            y = value + self.value_grad\n            x = x.to_global(placement, sbp)\n            y = y.to_global(placement, sbp)\n            x[:, :8] = y\n            return x\n\n    slice_update_with_grad_m = SliceUpdateWithGrad().to_global(\n        placement, [flow.sbp.broadcast,] * len(sbp)\n    )\n\n    of_sgd = flow.optim.SGD(slice_update_with_grad_m.parameters(), lr=1.0, momentum=0.0)\n\n    class SliceUpdateTrainGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.module = slice_update_with_grad_m\n            self.add_optimizer(of_sgd)\n\n        def build(self, x, y):\n            out = self.module(x, y)\n            z = out.sum()\n            z.backward()\n            return out\n\n    graph = SliceUpdateTrainGraph()\n\n    x = ref.to_global(placement=placement, sbp=sbp)\n    y = value.to_global(placement=placement, sbp=sbp)\n    z = graph(x, y)\n\n    test_case.assertTrue(z.sbp == sbp)\n\n    ref_np = ref.detach().cpu().numpy()\n    value_np = value.detach().cpu().numpy()\n\n    # forward\n    ref_np[:, :8] = value_np\n    test_case.assertTrue(np.array_equal(z.numpy(), ref_np))\n\n    # backward\n    # ref grad\n    ref_grad = np.ones((8, 16))\n    ref_grad[:, :8] = 0\n    test_case.assertTrue(\n        
np.array_equal(-graph.module.ref_grad.to(flow.Tensor).numpy(), ref_grad)\n    )\n    # value grad\n    value_grad = np.ones((8, 8))\n    test_case.assertTrue(\n        np.array_equal(-graph.module.value_grad.to(flow.Tensor).numpy(), value_grad)\n    )\n\n\nclass TestGlobalSliceUpdate(flow.unittest.TestCase):\n    @globaltest\n    def test_slice_update(test_case):\n        for placement in all_placement():\n            for _ in range(2):\n                sbp = random_sbp(placement, max_dim=2).value()\n                _test_slice_update(test_case, placement, sbp)\n                _test_graph_slice_update(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_sort.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_sort_impl(test_case, placement):\n    sbp = random_sbp(placement, max_dim=4)\n    x_dims = [random(2, 4) * 8 for _ in range(4)]\n    x = random_tensor(4, *x_dims)\n    dim = random(0, 4).to(int).value()\n    descending = random().to(bool).value()\n\n    y = x.to_global(placement=placement, sbp=sbp)\n    sort_result = torch.sort(y, dim=dim, descending=descending)\n    value = sort_result[0]\n    return value\n\n\nclass TestSortGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_sort(test_case):\n        for placement in all_placement():\n            _test_sort_impl(test_case, placement)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_sparse.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=1, check_graph=False)\ndef _test_embedding(test_case, ndim, placement, sbp):\n    emb_size = random() * 8\n    emb_dim = random() * 8\n    emb_shape = [emb_size, emb_dim]\n    idx_shape = [random(high=4) * 8 for i in range(ndim)]\n\n    weight = random_tensor(2, *emb_shape)\n    indices = random_tensor(\n        len(idx_shape), *idx_shape, low=0, high=emb_size, dtype=int\n    ).to_global(placement=placement, sbp=sbp)\n\n    embedding = torch.nn.Embedding(emb_size, emb_dim, _weight=weight).to_global(\n        placement=placement, sbp=sbp\n    )\n\n    output = embedding(indices)\n    return output\n\n\nclass TestEmbedding(flow.unittest.TestCase):\n    @globaltest\n    def test_embedding(test_case):\n        ndim = 2\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                _test_embedding(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_sparse_softmax_cross_entropy.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport os\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList, type_name_to_flow_type\n\nfrom oneflow.test_utils.automated_test_util.generators import *\nfrom oneflow.test_utils.automated_test_util.torch_flow_dual_object import globaltest\n\n\ndef _compare_eager_global_with_torch(\n    placement, logits_sbp, labels_sbp, data_type, label_type, batch_size, num_classes,\n):\n    data_type = type_name_to_flow_type[data_type]\n    label_type = type_name_to_flow_type[label_type]\n    np_labels = np.random.randint(0, num_classes, size=(batch_size,)).astype(np.int32)\n    np_logits = np.random.random((batch_size, num_classes)).astype(np.float32)\n    if flow.env.get_rank() == 0:\n        torch_logits = torch.tensor(np_logits, dtype=torch.float32, requires_grad=True)\n        torch_labels = torch.tensor(np_labels, dtype=torch.int64)\n        torch_output = torch.nn.functional.cross_entropy(\n            torch_logits, torch_labels, reduction=\"none\"\n        )\n        torch_output.sum().backward()\n\n    of_logits = flow.tensor(np_logits, dtype=data_type, requires_grad=True).to_global(\n        flow.placement.all(\"cpu\"), flow.sbp.broadcast\n    )\n    of_logits = of_logits.to_global(placement, logits_sbp)\n\n    of_logits.retain_grad()\n\n    of_labels = flow.tensor(np_labels, dtype=label_type).to_global(\n        flow.placement.all(\"cpu\"), flow.sbp.broadcast\n    )\n    of_labels = of_labels.to_global(placement, labels_sbp)\n\n    of_output = flow.nn.functional.sparse_softmax_cross_entropy(\n        labels=of_labels, logits=of_logits\n    )\n    of_output.sum().backward()\n    of_logits_grad = of_logits.grad.to_global(\n        flow.placement.all(\"cpu\"), flow.sbp.broadcast\n    )\n    of_logits_grad = of_logits_grad.to_local()\n    of_output = of_output.to_global(flow.placement.all(\"cpu\"), flow.sbp.broadcast)\n    of_output = of_output.to_local()\n\n    if flow.env.get_rank() == 0:\n        assert np.allclose(\n            of_output.numpy(), torch_output.detach().numpy(), rtol=1e-03, atol=1e-04\n        )\n        assert np.allclose(\n            of_logits_grad.numpy(), torch_logits.grad, rtol=1e-03, atol=1e-04\n        )\n\n\ndef _compare_lazy_global_with_torch(\n    placement, logits_sbp, labels_sbp, data_type, label_type, batch_size, num_classes,\n):\n    data_type = type_name_to_flow_type[data_type]\n    label_type = type_name_to_flow_type[label_type]\n    np_labels = np.random.randint(0, num_classes, size=(batch_size,)).astype(np.int32)\n    np_logits = np.random.random((batch_size, num_classes)).astype(np.float32)\n    if flow.env.get_rank() == 0:\n        torch_logits = torch.tensor(np_logits, dtype=torch.float32, requires_grad=True)\n        torch_labels = torch.tensor(np_labels, dtype=torch.int64)\n        torch_output = 
torch.nn.functional.cross_entropy(\n            torch_logits, torch_labels, reduction=\"none\"\n        )\n\n    class MyModule(flow.nn.Graph):\n        def __init__(self):\n            super(MyModule, self).__init__()\n\n        # nn.Graph does not support getting input.grad\n        def build(self, logits, labels):\n            output = flow.nn.functional.sparse_softmax_cross_entropy(\n                labels=labels, logits=logits\n            )\n            return output\n\n    of_logits = flow.tensor(np_logits, dtype=data_type, requires_grad=True).to_global(\n        flow.placement.all(\"cpu\"), flow.sbp.broadcast\n    )\n    of_logits = of_logits.to_global(placement, logits_sbp)\n\n    of_labels = flow.tensor(np_labels, dtype=label_type).to_global(\n        flow.placement.all(\"cpu\"), flow.sbp.broadcast\n    )\n    of_labels = of_labels.to_global(placement, labels_sbp)\n    graph = MyModule()\n    of_output = graph(of_logits, of_labels)\n    of_output = of_output.to_global(\n        placement=flow.placement.all(\"cpu\"), sbp=[flow.sbp.broadcast]\n    )\n    of_output = of_output.to_local()\n\n    flow._oneflow_internal.eager.multi_client.Sync()\n\n    if flow.env.get_rank() == 0:\n        assert np.allclose(\n            of_output.numpy(), torch_output.detach().numpy(), rtol=1e-03, atol=1e-04\n        )\n\n\nclass TestGlobalSparseSoftmaxCrossEntropyWithLogits(flow.unittest.TestCase):\n    @globaltest\n    def test_eager_global_sparse_softmax_cross_entropy(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"data_type\"] = [\"float32\", \"double\"]\n        arg_dict[\"label_type\"] = [\"int32\", \"int64\"]\n        arg_dict[\"batch_size\"] = [64]\n        arg_dict[\"num_classes\"] = [1024]\n        for arg in GenArgList(arg_dict):\n            for placement in all_placement():\n                for logits_sbp in all_sbp(placement, max_dim=2):\n                    for labels_sbp in all_sbp(placement, max_dim=1):\n                        _compare_eager_global_with_torch(\n                            placement, logits_sbp, labels_sbp, *arg\n                        )\n\n    # TODO: Too many streams will cause bugs; re-enable graph mode after this is solved\n    # @globaltest\n    # def test_lazy_global_sparse_softmax_cross_entropy(test_case):\n    #     arg_dict = OrderedDict()\n    #     arg_dict[\"data_type\"] = [\"float32\", \"double\"]\n    #     arg_dict[\"label_type\"] = [\"int32\", \"int64\"]\n    #     arg_dict[\"batch_size\"] = [64]\n    #     arg_dict[\"num_classes\"] = [1024]\n    #     for arg in GenArgList(arg_dict):\n    #         for placement in all_placement():\n    #             for logits_sbp in all_sbp(placement, max_dim=2):\n    #                 for labels_sbp in all_sbp(placement, max_dim=1):\n    #                     _compare_lazy_global_with_torch(\n    #                         placement, logits_sbp, labels_sbp, *arg\n    #                     )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_split.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_flow_split_with_random_data(test_case, placement, sbp):\n    k0 = random(2, 6) * 8\n    k1 = random(2, 6) * 8\n    k2 = random(2, 6) * 8\n    rand_dim = random(0, 3).to(int)\n    x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to_global(\n        placement=placement, sbp=sbp\n    )\n    res = torch.split(x, 2, dim=rand_dim)\n    return torch.cat(res, rand_dim)\n\n\n@autotest(n=2, check_graph=True)\ndef _test_flow_split_sizes_with_random_data(test_case, placement, sbp):\n    k0 = random(2, 6) * 8\n    k1 = 16\n    k2 = random(2, 6) * 8\n    x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to_global(\n        placement=placement, sbp=sbp\n    )\n    res = torch.split(x, [6, 3, 4, 3], dim=1)\n    return torch.cat(res, dim=1)\n\n\n@autotest(n=2, check_graph=True)\ndef _test_flow_split_sizes_neg_dim_with_random_data(test_case, placement, sbp):\n    k0 = random(2, 6) * 8\n    k1 = 16\n    k2 = random(2, 6) * 8\n    x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to_global(\n        placement=placement, sbp=sbp\n    )\n    res = torch.split(x, [6, 3, 4, 3], dim=-2)\n    return torch.cat(res, dim=1)\n\n\nclass TestGlobalSplitModule(flow.unittest.TestCase):\n    @globaltest\n    def test_flow_split_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_flow_split_with_random_data(test_case, placement, sbp)\n\n    @globaltest\n    def test_flow_split_sizes_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_flow_split_sizes_with_random_data(test_case, placement, sbp)\n\n    @globaltest\n    def test_flow_split_sizes_neg_dim_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_flow_split_sizes_neg_dim_with_random_data(\n                    test_case, placement, sbp\n                )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_sqrt_square_sum.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True, rtol=0.5, atol=0.5)\ndef _test_sqrt_sum_with_cpu_random_data(test_case, placement, sbp):\n    x = random_tensor(ndim=4, dim0=8, dim1=32, dim2=40, dim3=64).to_global(\n        placement=placement, sbp=sbp\n    )\n    y = torch.linalg.norm(x)\n    return y\n\n\n@autotest(n=1, check_graph=True, rtol=0.5, atol=0.5)\ndef _test_scalar_random_data(test_case, placement, sbp):\n    x = random_tensor(ndim=4, dim0=8, dim1=24, dim2=16, dim3=40).to_global(\n        placement=placement, sbp=sbp\n    )\n    y = torch.linalg.norm(x)\n    return y\n\n\nclass TestGlobalLinalgVectorNorm2D(flow.unittest.TestCase):\n    @globaltest\n    def test_sqrt_sum_with_cpu_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_sqrt_sum_with_cpu_random_data(test_case, placement, sbp)\n\n    @globaltest\n    def test_scalar_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_scalar_random_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_squeeze.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=1, check_graph=True)\ndef _test_squeeze_1d_input(test_case, placement, sbp):\n    x = random_tensor(1, 16, dtype=float).to_global(placement, sbp)\n    y = torch.squeeze(x)\n    return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_flow_squeeze_with_random_data(test_case, placement, sbp):\n    x = random_tensor(2, 8, 16).to_global(placement, sbp)\n    y = torch.squeeze(x, random(0, 2).to(int))\n    return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_squeeze_with_0_size_data(test_case, placement, sbp):\n    x = random_tensor(3, 8, 16, 0).to_global(placement, sbp)\n    y = torch.squeeze(x)\n    return y\n\n\nclass TestGlobalSqueeze(flow.unittest.TestCase):\n    @globaltest\n    def test_squeeze_1d_input(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_squeeze_1d_input(test_case, placement, sbp)\n\n    @globaltest\n    def test_flow_squeeze_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_flow_squeeze_with_random_data(test_case, placement, sbp)\n\n    @globaltest\n    def test_squeeze_with_0_size_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_squeeze_with_0_size_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_stack.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=2, check_graph=True)\ndef _test_stack_with_random_data(test_case, placement, sbp):\n    x = random_tensor(ndim=4, dim0=8, dim1=16, dim2=24, dim3=8).to_global(\n        placement, sbp\n    )\n    y = random_tensor(ndim=4, dim0=8, dim1=16, dim2=24, dim3=8).to_global(\n        placement, sbp\n    )\n    out = torch.stack((x, y), dim=random(low=-5, high=5).to(int))\n    return out\n\n\n@unittest.skip(\"backward of stack with random diff has bug.\")\nclass TestStackModule(flow.unittest.TestCase):\n    @globaltest\n    def test_stack_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4):\n                _test_stack_with_random_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_stateful_kernel_with_cache.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_global_stateful_kernel_with_inpersistent_state(test_case, placement, sbp):\n    x = (\n        flow.arange(64)\n        .reshape(8, 8)\n        .to_global(flow.placement.all(\"cpu\"), flow.sbp.broadcast)\n    )\n    x = x.to_global(placement, sbp)\n    y = x[0:3, 0:1]\n    y_np = np.array([[0], [8], [16]])\n    test_case.assertTrue(np.array_equal(y.numpy(), y_np,))\n    x = x.to_global(flow.placement.all(\"cpu\"), sbp=flow.sbp.split(1))\n    y = x[0:3, 0:1]\n    test_case.assertTrue(np.array_equal(y.numpy(), y_np,))\n\n\nclass TestStatefulKernelWithInpersistentState(flow.unittest.TestCase):\n    @globaltest\n    def test_global_stateful_kernel_with_inpersistent_state(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_stateful_kernel_with_inpersistent_state(\n                    test_case, placement, sbp\n                )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_std.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=1, auto_backward=True, check_graph=True)\ndef _test_global_std_flow_with_random_data(test_case, placement, sbp):\n    dim = random(low=0, high=4).to(int)\n    x = random_tensor(\n        ndim=4,\n        dim0=random(1, 4) * 8,\n        dim1=random(1, 4) * 8,\n        dim2=random(1, 4) * 8,\n        dim3=random(1, 4) * 8,\n    ).to_global(placement, sbp)\n    z = torch.std(x, dim=dim, unbiased=random().to(bool), keepdim=random().to(bool),)\n    return z\n\n\n@autotest(n=1, auto_backward=True, check_graph=True)\ndef _test_global_std_tensor_with_random_data(test_case, placement, sbp):\n    dim = random(low=0, high=4).to(int)\n    x = random_tensor(\n        ndim=4,\n        dim0=random(1, 4) * 8,\n        dim1=random(1, 4) * 8,\n        dim2=random(1, 4) * 8,\n        dim3=random(1, 4) * 8,\n    ).to_global(placement, sbp)\n    z = x.std(dim=dim, keepdim=random().to(bool),)\n    return z\n\n\nclass TestGlobalStd(flow.unittest.TestCase):\n    @globaltest\n    def test_global_std_flow_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_std_flow_with_random_data(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_std_tensor_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_std_tensor_with_random_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_sub.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_global_sub(test_case, placement, sbp):\n    x = random_tensor(2, 8, 8).to_global(placement=placement, sbp=sbp)\n    y = random_tensor(2, 8, 8).to_global(placement=placement, sbp=sbp)\n    out1 = x - y\n    out2 = x - 2\n    out3 = 2 - x\n    out4 = torch.sub(x, y)\n    return out1, out2, out3, out4\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_global_sub_with_0_size_data(test_case, placement, sbp):\n    device = random_device()\n    x = random_tensor(2, 0, 8).to_global(placement=placement, sbp=sbp)\n    out1 = x - 2\n    out2 = 2 - x\n    return out1, out2\n\n\nclass TestGlobalSubModule(flow.unittest.TestCase):\n    @globaltest\n    def test_global_sub(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_sub(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_sub_with_0_size_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2, valid_split_axis=1):\n                _test_global_sub_with_0_size_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_sum.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=1, check_graph=True, rtol=1e-3)\ndef _test_global_sum_against_pytorch(test_case, placement, sbp):\n    x = random_tensor(4, 8, 16, 8, 24).to_global(placement, sbp)\n    y = torch.sum(x)\n    return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_global_sum_with_0_size_tensor(test_case, placement, sbp):\n    x = random_tensor(4, 8, 16, 0, 24).to_global(placement, sbp)\n    y = torch.sum(x, dim=random(0, 3).to(int))\n    return y\n\n\nclass TestGlobalSumModule(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    @globaltest\n    def test_global_sum_against_pytorch(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4):\n                _test_global_sum_against_pytorch(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_sum_with_0_size_tensor(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4, valid_split_axis=[0, 1, 3]):\n                _test_global_sum_with_0_size_tensor(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_tensor_new.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=False, auto_backward=False)\ndef _test_tensor_new(test_case, placement, sbp):\n    x = random_tensor(1, 64).to_global(placement=placement, sbp=sbp).oneflow\n    y = x.new()\n    test_case.assertTrue(x.dtype == y.dtype)\n    for x_sbp, y_sbp in zip(x.sbp, y.sbp):\n        test_case.assertTrue(x_sbp == y_sbp)\n    test_case.assertTrue(x.placement == y.placement)\n\n    y = x.new(1, 2, 3)\n    test_case.assertTrue(list(y.shape) == [1, 2, 3])\n    test_case.assertTrue(x.dtype == y.dtype)\n    for x_sbp, y_sbp in zip(x.sbp, y.sbp):\n        test_case.assertTrue(x_sbp == y_sbp)\n    test_case.assertTrue(x.placement == y.placement)\n\n    y = x.new([1, 2, 3])\n    test_case.assertTrue(list(y.shape) == [3])\n    test_case.assertTrue(x.dtype == y.dtype)\n    for x_sbp, y_sbp in zip(x.sbp, y.sbp):\n        test_case.assertTrue(x_sbp == y_sbp)\n    test_case.assertTrue(x.placement == y.placement)\n\n\nclass TestTensorNew(flow.unittest.TestCase):\n    @globaltest\n    def test_tensor_new(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, valid_split_axis=0):\n                _test_tensor_new(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_tensor_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_type_as(test_case, shape, src_dtype, tgt_dtype, placement, sbp):\n    np_input = np.random.rand(*shape)\n    input = flow.tensor(np_input, dtype=src_dtype).to_global(placement, sbp)\n    target = flow.tensor(np_input, dtype=tgt_dtype).to_global(placement, sbp)\n    input = input.type_as(target)\n    test_case.assertEqual(input.dtype, target.dtype)\n\n\ndef _test_local_to_global_type_as(\n    test_case, shape, src_dtype, tgt_dtype, placement, sbp\n):\n    np_input = np.random.rand(*shape)\n    input = random_tensor(ndim=len(shape)).oneflow.to_local()\n    target = flow.tensor(np_input, dtype=tgt_dtype).to_global(placement, sbp)\n    input = input.type_as(target)\n    test_case.assertEqual(input.dtype, target.dtype)\n    test_case.assertEqual(input.placement, target.placement)\n    test_case.assertEqual(input.sbp, target.sbp)\n\n\ndef _test_global_to_local_type_as(\n    test_case, shape, src_dtype, tgt_dtype, placement, sbp\n):\n    np_input = np.random.rand(*shape)\n    input = flow.tensor(np_input, dtype=tgt_dtype).to_global(placement, sbp)\n    target = random_tensor(ndim=len(shape)).to(random_device()).oneflow.to_local()\n    input = input.type_as(target)\n    test_case.assertEqual(input.dtype, target.dtype)\n    test_case.assertEqual(input.device, target.device)\n\n\ndef _test_is_floating_point(test_case, shape, dtype, placement, sbp):\n    np_input = np.random.rand(*shape)\n    input = flow.tensor(np_input, dtype=dtype).to_global(placement, sbp)\n    output = input.is_floating_point()\n    if input.dtype in (flow.float, flow.float16, flow.float32, flow.double):\n        test_case.assertEqual(output, True)\n    else:\n        test_case.assertEqual(output, False)\n\n\n@autotest(n=1, check_graph=True)\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\ndef _test_global_cuda(test_case, placement, sbp):\n    x = random_tensor(2, 8, 16).to_global(placement, sbp)\n    x = x.cuda()\n    y = x.sum()\n    return y\n\n\nclass TestGlobalCuda(flow.unittest.TestCase):\n    @globaltest\n    def test_global_cuda(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_cuda(test_case, placement, sbp)\n\n\n@autotest(n=1, check_graph=True)\ndef _test_global_cpu(test_case, placement, sbp):\n    x = random_tensor(2, 8, 16).to_global(placement, sbp)\n    x = x.cpu()\n    y = x.sum()\n    return y\n\n\n# PyTorch error if open auto_backward:\n# element 0 of tensors does not require grad and does not have a grad_fn\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_global_long(test_case, placement, sbp):\n    x = 
random_tensor(2, 8, 16, requires_grad=True).to_global(placement, sbp)\n    y = x.long()\n    test_case.assertFalse(y.oneflow.requires_grad)\n    return y\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_global_int(test_case, placement, sbp):\n    x = random_tensor(2, 8, 16, requires_grad=True).to_global(placement, sbp)\n    y = x.int()\n    test_case.assertFalse(y.oneflow.requires_grad)\n    return y\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_global_float(test_case, placement, sbp):\n    x = random_tensor(2, 8, 16, dtype=int).to_global(placement, sbp)\n    y = x.float()\n    return y\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_global_double(test_case, placement, sbp):\n    x = random_tensor(2, 8, 16, dtype=int).to_global(placement, sbp)\n    y = x.double()\n    return y\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_global_item(test_case, placement, sbp):\n    x = random_tensor(ndim=1, dim0=1, dtype=int).to_global(placement, sbp)\n    y = torch.tensor(x.item())\n    return y\n\n\n@autotest(n=1, auto_backward=False, check_graph=False)\ndef _test_global_tolist(test_case, placement, sbp):\n    x = random_tensor(ndim=4, dim0=8, dim1=16, dim2=24, dim3=32, dtype=int).to_global(\n        placement, sbp\n    )\n    y = torch.tensor(x.tolist())\n    return y\n\n\nclass TestGlobalTensorOps(flow.unittest.TestCase):\n    @globaltest\n    def test_global_cpu(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_cpu(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_long(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_long(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_int(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_int(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_float(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_float(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_double(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_double(test_case, placement, sbp)\n\n    @unittest.skip(\"TODO: sometimes global item will result to segment fault!\")\n    @globaltest\n    def test_global_item(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1, except_split=True):\n                _test_global_item(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_tolist(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4):\n                _test_global_tolist(test_case, placement, sbp)\n\n    @globaltest\n    def test_type_as(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(8, 16), (8, 16, 24), (8, 16, 24, 32)]\n        arg_dict[\"src_dtype\"] = [flow.int64, flow.int32, flow.float32, flow.float64]\n        arg_dict[\"tgt_dtype\"] = [flow.int64, flow.int32, flow.float32, flow.float64]\n        for arg in GenArgList(arg_dict):\n            for placement in all_placement():\n                for sbp in all_sbp(placement, max_dim=len(arg[0])):\n              
      _test_type_as(test_case, *arg, placement, sbp)\n                    _test_local_to_global_type_as(test_case, *arg, placement, sbp)\n                    _test_global_to_local_type_as(test_case, *arg, placement, sbp)\n\n    @globaltest\n    def test_is_floating_point(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(8, 16), (8, 16, 24), (8, 16, 24, 32)]\n        arg_dict[\"dtype\"] = [\n            # flow.uint8, nccl don't support uint8\n            flow.int8,\n            flow.int32,\n            flow.int64,\n            flow.float32,\n            flow.float64,\n            flow.double,\n            flow.float,\n            flow.int,\n        ]\n        for arg in GenArgList(arg_dict):\n            for placement in all_placement():\n                for sbp in all_sbp(placement, max_dim=len(arg[0])):\n                    _test_is_floating_point(test_case, *arg, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_tensor_scatter_nd_update.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\nclass TensorScatterNdUpdate(flow.nn.Graph):\n    def __init__(self):\n        super(TensorScatterNdUpdate, self).__init__()\n\n    def build(self, origin, indices, update):\n        return flow.tensor_scatter_nd_update(origin, indices, update)\n\n\ndef _test_global_tensor_scatter_nd_update(test_case, placement, sbp, check_graph=True):\n    origin = random_tensor(1, 16, requires_grad=False).to_global(placement, sbp)\n    indices = choice_tensor(16, (8, 1), replace=False).to_global(\n        placement, [flow.sbp.broadcast for _ in range(len(placement.ranks.shape))]\n    )\n    update = random_tensor(1, 8, requires_grad=False).to_global(\n        placement, [flow.sbp.broadcast for _ in range(len(placement.ranks.shape))]\n    )\n\n    np_origin = origin.oneflow.numpy()\n    np_indices = indices.oneflow.numpy().reshape(8)\n    np_update = update.oneflow.numpy()\n\n    if check_graph:\n        tensor_scatter_nd_update = TensorScatterNdUpdate()\n        output = tensor_scatter_nd_update(\n            origin.oneflow, indices.oneflow, update.oneflow\n        )\n    else:\n        output = flow.tensor_scatter_nd_update(\n            origin.oneflow, indices.oneflow, update.oneflow\n        )\n\n    np_origin[np_indices] = np_update\n\n    test_case.assertTrue(np.allclose(output.numpy(), np_origin, 0.0001, 0.0001))\n\n\ndef _test_global_tensor_scatter_nd_update_t(\n    test_case, placement, sbp, check_graph=True\n):\n\n    origin = random_tensor(2, 16, 4, requires_grad=False).to_global(placement, sbp)\n    indices = choice_tensor(16, (8, 1), replace=False).to_global(\n        placement, [flow.sbp.broadcast for _ in range(len(placement.ranks.shape))]\n    )\n    update = random_tensor(2, 8, 4, requires_grad=False).to_global(\n        placement, [flow.sbp.broadcast for _ in range(len(placement.ranks.shape))]\n    )\n\n    np_origin = origin.oneflow.numpy()\n    np_indices = indices.oneflow.numpy().reshape(8)\n    np_update = update.oneflow.numpy()\n\n    if check_graph:\n        tensor_scatter_nd_update = TensorScatterNdUpdate()\n        output = tensor_scatter_nd_update(\n            origin.oneflow, indices.oneflow, update.oneflow\n        )\n    else:\n        output = flow.tensor_scatter_nd_update(\n            origin.oneflow, indices.oneflow, update.oneflow\n        )\n\n    np_origin[np_indices] = np_update\n\n    test_case.assertTrue(np.allclose(output.numpy(), np_origin, 0.0001, 0.0001))\n\n\ndef _test_eager_global_tensor_scatter_nd_update_backward(test_case, placement, sbp):\n    origin = random_tensor(1, 16,).to_global(placement, sbp)\n    origin.retain_grad()\n    indices = choice_tensor(16, (8, 1), replace=False).to_global(\n        placement, [flow.sbp.broadcast for _ in 
range(len(placement.ranks.shape))]\n    )\n    update = random_tensor(1, 8).to_global(\n        placement, [flow.sbp.broadcast for _ in range(len(placement.ranks.shape))]\n    )\n    update.retain_grad()\n\n    np_origin = origin.oneflow.numpy()\n    np_indices = indices.oneflow.numpy().reshape(8)\n    np_update = update.oneflow.numpy()\n\n    np_update_grad = np.ones(8)\n    np_origin_grad = np.ones(16)\n    np_origin_grad[np_indices] = np.zeros(8)\n\n    output = flow.tensor_scatter_nd_update(\n        origin.oneflow, indices.oneflow, update.oneflow\n    )\n    out_sum = output.sum()\n    out_sum.backward()\n\n    np_origin[np_indices] = np_update\n\n    test_case.assertTrue(np.allclose(output.numpy(), np_origin, 0.0001, 0.0001))\n    test_case.assertTrue(np.allclose(update.oneflow.grad.numpy(), np_update_grad))\n    test_case.assertTrue(np.allclose(origin.oneflow.grad.numpy(), np_origin_grad))\n\n\nclass TestTensorScatterNdUpdate(flow.unittest.TestCase):\n    @globaltest\n    def test_global_tensor_scatter_nd_update(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_global_tensor_scatter_nd_update(\n                    test_case, placement, sbp, False\n                )  # eager global\n                # skip lazy test\n                # _test_global_tensor_scatter_nd_update(\n                #     test_case, placement, sbp, True\n                # )  # nn graph\n\n    @globaltest\n    def test_global_tensor_scatter_nd_update_t(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_global_tensor_scatter_nd_update_t(\n                    test_case, placement, sbp, False\n                )  # eager global\n                # skip lazy test\n                # _test_global_tensor_scatter_nd_update_t(\n                #     test_case, placement, sbp, True\n                # )  # nn graph\n\n    @globaltest\n    def test_global_tensor_scatter_nd_update_backward(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_eager_global_tensor_scatter_nd_update_backward(\n                    test_case, placement, sbp\n                )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_tensordot.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True, atol=1e-3)\ndef _test_global_tensordot_against_pytorch(test_case, ndim, placement, sbp):\n    k = random(1, 2) * 8\n    tensordot_dim = random(0, ndim + 1).to(int)\n\n    x = random_tensor(ndim=ndim, dim0=k, dim1=k, dim2=k, dim3=k).to_global(\n        placement=placement, sbp=sbp\n    )\n    y = random_tensor(ndim=ndim, dim0=k, dim1=k, dim2=k, dim3=k).to_global(\n        placement=placement, sbp=sbp\n    )\n    z = torch.tensordot(x, y, dims=tensordot_dim)\n    return z\n\n\nclass TestTensorDotGlobal(flow.unittest.TestCase):\n    @globaltest\n    def test_tensordot(test_case):\n        for placement in all_placement():\n            for ndim in range(1, 4):\n                for sbp in all_sbp(placement, max_dim=ndim):\n                    _test_global_tensordot_against_pytorch(\n                        test_case, ndim, placement, sbp\n                    )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_tile.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_global_flow_tile_with_random_data(test_case, placement, sbp):\n    x = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)\n    reps = (\n        random(1, 5).to(int) * 8,\n        random(1, 5).to(int) * 8,\n        random(1, 5).to(int) * 8,\n    )\n    z = torch.tile(x, reps)\n    return z\n\n\n@autotest(n=1, check_graph=True)\ndef _test_global_flow_tensor_tile_with_random_data(test_case, placement, sbp):\n    x = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)\n    reps = (\n        random(1, 5).to(int) * 8,\n        random(1, 5).to(int) * 8,\n        random(1, 5).to(int) * 8,\n    )\n    y = x.tile(reps)\n    return y\n\n\nclass TestGlobalTile(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed in 10 retry\")\n    @globaltest\n    def test_global_flow_tile_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_flow_tile_with_random_data(test_case, placement, sbp)\n\n    @unittest.skip(\"skip for now, becase it failed in 10 retry\")\n    @globaltest\n    def test_global_flow_tensor_tile_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_flow_tensor_tile_with_random_data(\n                    test_case, placement, sbp\n                )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_transpose.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_global_transpose(test_case, placement, sbp):\n    input = flow.tensor(np.random.randn(8, 16, 8, 16), dtype=flow.float32).to_global(\n        flow.placement.all(\"cpu\"), flow.sbp.broadcast\n    )\n    input = input.to_global(placement, sbp)\n    of_out = flow.transpose(input, 0, 1)\n    np_out = input.numpy().transpose((1, 0, 2, 3))\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\ndef _test_global_tensor_transpose(test_case, placement, sbp):\n    input = flow.tensor(np.random.randn(8, 16, 8, 16), dtype=flow.float32).to_global(\n        flow.placement.all(\"cpu\"), flow.sbp.broadcast\n    )\n    input = input.to_global(placement, sbp)\n    of_out = input.transpose(0, 1)\n    np_out = input.numpy().transpose((1, 0, 2, 3))\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\ndef _test_global_tranpose_negative_dim(test_case, placement, sbp):\n    input = flow.tensor(np.random.randn(8, 16, 8, 16), dtype=flow.float32).to_global(\n        flow.placement.all(\"cpu\"), flow.sbp.broadcast\n    )\n    input = input.to_global(placement, sbp)\n    of_out = flow.transpose(input, -4, -3)\n    np_out = input.numpy().transpose((1, 0, 2, 3))\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\ndef _test_global_transpose_backward(test_case, placement, sbp):\n    x = flow.tensor(\n        np.random.randn(8, 16, 8, 16), dtype=flow.float32, requires_grad=True,\n    ).to_global(flow.placement.all(\"cpu\"), flow.sbp.broadcast)\n    x = x.to_global(placement, sbp)\n    x.retain_grad()\n    y = flow.transpose(x, 0, 1).sum()\n    y.backward()\n    test_case.assertTrue(\n        np.allclose(x.grad.numpy(), np.ones((8, 16, 8, 16)), 1e-05, 1e-05)\n    )\n\n\ndef _test_global_transpose_backward_v2(test_case, placement, sbp):\n    x = flow.tensor(\n        np.random.randn(8, 16, 8, 16), dtype=flow.float32, requires_grad=True,\n    ).to_global(flow.placement.all(\"cpu\"), flow.sbp.broadcast)\n    x = x.to_global(placement, sbp)\n    x.retain_grad()\n    y = flow.transpose(x, 3, 1).sum()\n    y.backward()\n    test_case.assertTrue(\n        np.allclose(x.grad.numpy(), np.ones((8, 16, 8, 16)), 1e-05, 1e-05)\n    )\n\n\n@autotest(n=1, check_graph=True)\ndef _test_global_transpose_flow_with_random_data(test_case, placement, sbp):\n    x = random_tensor(4, 8, 16, 24, 8).to_global(placement, sbp)\n    y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int))\n    return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_global_transpose_with_0_size_data(test_case, placement, sbp):\n    device = random_device()\n    x = 
random_tensor(4, 8, 16, 0, 8).to_global(placement, sbp)\n    y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int))\n    return y\n\n\nclass TestGlobalTranspose(flow.unittest.TestCase):\n    @globaltest\n    def test_global_transpose(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"fun\"] = [\n            _test_global_transpose,\n            _test_global_tensor_transpose,\n            _test_global_tranpose_negative_dim,\n            _test_global_transpose_backward,\n            _test_global_transpose_backward_v2,\n        ]\n        for arg in GenArgList(arg_dict):\n            for placement in all_placement():\n                for sbp in all_sbp(placement, max_dim=4):\n                    arg[0](test_case, placement, sbp)\n\n    @globaltest\n    def test_global_transpose_flow_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4):\n                _test_global_transpose_flow_with_random_data(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_transpose_with_0_size_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4, valid_split_axis=[0, 1, 3]):\n                _test_global_transpose_with_0_size_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_tril.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=2, check_graph=True)\ndef _test_global_tril_without_diag(test_case, placement, sbp):\n    x = random_tensor(\n        ndim=4,\n        dim0=random(1, 3).to(int) * 8,\n        dim1=random(1, 3).to(int) * 8,\n        dim2=random(1, 3).to(int) * 8,\n        dim3=random(1, 3).to(int) * 8,\n    ).to_global(placement, sbp)\n    y = torch.tril(x)\n    y = torch.exp(y)\n\n    return y\n\n\n@autotest(n=2, check_graph=True)\ndef _test_global_tril_with_diag(test_case, placement, sbp):\n    diagonal = random(-3, 3).to(int)\n    x = random_tensor(\n        ndim=4,\n        dim0=random(1, 4).to(int) * 8,\n        dim1=random(1, 4).to(int) * 8,\n        dim2=random(1, 4).to(int) * 8,\n        dim3=random(1, 4).to(int) * 8,\n    ).to_global(placement, sbp)\n    y = torch.tril(x, diagonal)\n    y = torch.exp(y)\n\n    return y\n\n\nclass TestGlobalTril(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    @globaltest\n    def test_global_tril_without_diag(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4):\n                _test_global_tril_without_diag(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_tril_with_diag(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4):\n                _test_global_tril_with_diag(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_triu.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=2, check_graph=True)\ndef _test_global_triu_without_diag(test_case, placement, sbp):\n    x = random_tensor(\n        ndim=4,\n        dim0=random(1, 3).to(int) * 8,\n        dim1=random(1, 3).to(int) * 8,\n        dim2=2,\n        dim3=4,\n    ).to_global(placement, sbp)\n    y = torch.triu(x)\n    y = torch.exp(y)\n\n    return y\n\n\n@autotest(n=2, check_graph=True)\ndef _test_global_triu_with_diag(test_case, placement, sbp):\n    diagonal = random(-3, 3).to(int)\n    x = random_tensor(\n        ndim=4,\n        dim0=random(1, 3).to(int) * 8,\n        dim1=random(1, 3).to(int) * 8,\n        dim2=2,\n        dim3=4,\n    ).to_global(placement, sbp)\n    y = torch.triu(x, diagonal)\n    y = torch.exp(y)\n\n    return y\n\n\nclass TestGlobalTriu(flow.unittest.TestCase):\n    @globaltest\n    def test_global_triu_without_diag(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_triu_without_diag(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_triu_with_diag(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_triu_with_diag(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_unbind.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n# TODO: the test is dependent on global select op(global tensor->stride())\n@unittest.skip(\"global select op is not currently supported\")\n@autotest(n=1, check_graph=True)\ndef _test_unbind(test_case, placement, sbp):\n    dim_size = random(1, 3).to(int).value() * 8\n    rand_dim = random(0, 3).to(int).value()\n    x = random_tensor(ndim=3, dim0=dim_size, dim1=dim_size, dim2=dim_size).to_global(\n        placement, sbp\n    )\n    return torch.unbind(x, dim=rand_dim)\n\n\nclass TestUnbind(flow.unittest.TestCase):\n    @globaltest\n    def test_unbind(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=3):\n                _test_unbind(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_unfold.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.nn.common_types import _size_2_t\n\n\n@autotest(n=1, check_graph=True)\ndef _test_unfold_with_random_data(test_case, placement, sbp):\n    ndim = 4\n    dims = [random(1, 3).to(int).value() * 8 for i in range(ndim)]\n    m = torch.nn.Unfold(\n        kernel_size=random(1, 3).to(_size_2_t),\n        dilation=random(1, 2).to(_size_2_t),\n        padding=random(0, 1).to(_size_2_t),\n        stride=random(1, 2).to(_size_2_t),\n    )\n    m.train(random())\n\n    x = random_tensor(ndim, *dims).to_global(placement, sbp)\n    y = m(x)\n    func_y = torch.nn.functional.unfold(\n        x,\n        kernel_size=random(1, 3).to(_size_2_t),\n        dilation=random(1, 2).to(_size_2_t),\n        padding=random(0, 1).to(_size_2_t),\n        stride=random(1, 2).to(_size_2_t),\n    )\n    return y, func_y\n\n\nclass TestUnfold(flow.unittest.TestCase):\n    @globaltest\n    def test_unfold_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=4):\n                _test_unfold_with_random_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_unfold_tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\nimport numpy as np\n\n\n@autotest(n=1, auto_backward=True, check_graph=True)\ndef _test_global_unfold_tensor_with_random_data(test_case, placement, sbp):\n    ndim = 4\n    dim = random(0, ndim).to(int).value()\n    x = random_tensor(\n        ndim=ndim,\n        dim0=random(1, 3).to(int) * 8,\n        dim1=random(1, 3).to(int) * 8,\n        dim2=4,\n        dim3=4,\n    ).to_global(placement, sbp)\n    high = x.oneflow.size()[dim]\n    size = random(1, high).to(int).value()\n    step = random(1, high).to(int).value()\n    y = x.unfold(dim, size, step)\n    return y\n\n\nclass TestGlobalUnfoldTensor(flow.unittest.TestCase):\n    @globaltest\n    def test_global_unfold_tensor_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_unfold_tensor_with_random_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_unique.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport torch as torch_ori\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_unique_unsorted(test_case, placement, sbp):\n    input = random_tensor(ndim=1, dim0=64, high=20).to_global(\n        placement=placement, sbp=sbp\n    )\n    oneflow_output = flow.unique(\n        input.oneflow, sorted=False, return_inverse=True, return_counts=True\n    )\n    torch_output = torch_ori.unique(\n        input.pytorch, sorted=False, return_inverse=True, return_counts=True\n    )\n\n    oneflow_result, oneflow_indices, oneflow_counts = oneflow_output\n    torch_result, torch_indices, torch_counts = torch_output\n\n    test_case.assertTrue(\n        np.allclose(\n            np.sort(oneflow_result.to_local().numpy()),\n            np.sort(torch_result.detach().cpu().numpy()),\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            oneflow_result[oneflow_indices].numpy(),\n            torch_result[torch_indices].detach().cpu().numpy(),\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            oneflow_counts.numpy()[np.argsort(oneflow_result.numpy())],\n            torch_counts.detach()\n            .cpu()\n            .numpy()[np.argsort(torch_result.detach().cpu().numpy())],\n        )\n    )\n\n\ndef _test_unique_sorted(test_case, placement, sbp):\n    input = random_tensor(ndim=1, dim0=64, high=20).to_global(\n        placement=placement, sbp=sbp\n    )\n    oneflow_output = flow.unique(\n        input.oneflow, sorted=True, return_inverse=True, return_counts=True\n    )\n    torch_output = torch_ori.unique(\n        input.pytorch, sorted=True, return_inverse=True, return_counts=True\n    )\n\n    oneflow_result, oneflow_indices, oneflow_counts = oneflow_output\n    torch_result, torch_indices, torch_counts = torch_output\n\n    test_case.assertTrue(\n        np.allclose(\n            oneflow_result.to_local().numpy(), torch_result.detach().cpu().numpy(),\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(oneflow_indices.numpy(), torch_indices.detach().cpu().numpy(),)\n    )\n    test_case.assertTrue(\n        np.allclose(oneflow_counts.numpy(), torch_counts.detach().cpu().numpy(),)\n    )\n\n\nclass TestUniqueModule(flow.unittest.TestCase):\n    @globaltest\n    def test_unique(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement):\n                _test_unique_unsorted(test_case, placement, sbp)\n                _test_unique_sorted(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_unsqueeze.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=1, check_graph=True)\ndef _test_flow_unsqueeze_with_random_data(test_case, placement, sbp):\n    x = random_tensor(2, 8, 16).to_global(placement, sbp)\n    y = torch.unsqueeze(x, random(0, 3).to(int))\n    return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_tensor_unsqueeze_with_random_data(test_case, placement, sbp):\n    x = random_tensor(2, 8, 16).to_global(placement, sbp)\n    y = x.unsqueeze(random(0, 3).to(int))\n    return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_unsqueeze_with_0_size_data(test_case, placement, sbp):\n    x = random_tensor(3, 8, 16, 0).to_global(placement, sbp)\n    y = torch.unsqueeze(x, random(0, 4).to(int))\n    return y\n\n\nclass TestGlobalUnsqueeze(flow.unittest.TestCase):\n    @globaltest\n    def test_flow_unsqueeze_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_flow_unsqueeze_with_random_data(test_case, placement, sbp)\n\n    @globaltest\n    def test_tensor_unsqueeze_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_tensor_unsqueeze_with_random_data(test_case, placement, sbp)\n\n    @globaltest\n    def test_unsqueeze_with_0_size_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_unsqueeze_with_0_size_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_upsample.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=1, auto_backward=True, check_graph=True)\ndef _test_global_upsample2d_nearest(test_case, placement, sbp):\n    x = random_tensor(ndim=3, dim0=8, dim1=16).to_global(placement, sbp)\n    print(x)\n    m = torch.nn.Upsample(scale_factor=random().to(int), mode=\"nearest\",)\n    y = m(x)\n    return y\n\n\n@autotest(n=1, auto_backward=True, check_graph=True)\ndef _test_global_upsample2d_linear(test_case, placement, sbp):\n    x = random_tensor(ndim=3, dim0=8, dim1=16).to_global(placement, sbp)\n    m = torch.nn.Upsample(\n        scale_factor=random().to(int), mode=\"linear\", align_corners=random_bool(),\n    )\n    y = m(x)\n    return y\n\n\n@autotest(n=1, auto_backward=True, check_graph=True)\ndef _test_global_upsample2d_bilinear(test_case, placement, sbp):\n    x = random_tensor(ndim=4, dim0=8, dim1=16).to_global(placement, sbp)\n    m = torch.nn.Upsample(\n        scale_factor=random().to(int), mode=\"bilinear\", align_corners=random_bool(),\n    )\n    y = m(x)\n    return y\n\n\n@autotest(n=1, auto_backward=True, check_graph=True)\ndef _test_global_upsample2d_bicubic(test_case, placement, sbp):\n    x = random_tensor(ndim=4, dim0=8, dim1=16).to_global(placement, sbp)\n    m = torch.nn.Upsample(\n        scale_factor=random().to(int), mode=\"bicubic\", align_corners=random_bool(),\n    )\n    y = m(x)\n    return y\n\n\n@autotest(n=1, auto_backward=True, check_graph=True)\ndef _test_global_upsample2d_trilinear(test_case, placement, sbp):\n    x = random_tensor(ndim=5, dim0=8, dim1=16).to_global(placement, sbp)\n    m = torch.nn.Upsample(\n        scale_factor=random().to(int), mode=\"trilinear\", align_corners=random_bool(),\n    )\n    y = m(x)\n    return y\n\n\nclass TestGlobalUpsample2d(flow.unittest.TestCase):\n    @unittest.skip(\n        \"The nearest interpolate operation in pytorch has bug, https://github.com/pytorch/pytorch/issues/65200\"\n    )\n    @globaltest\n    def test_global_upsample2d_nearest(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_global_upsample2d_nearest(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_upsample2d_linear(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_global_upsample2d_linear(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_upsample2d_bilinear(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_global_upsample2d_bilinear(test_case, placement, sbp)\n\n    @globaltest\n    def 
test_global_upsample2d_bicubic(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_global_upsample2d_bicubic(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_upsample2d_trilinear(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_global_upsample2d_trilinear(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_var.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nfrom oneflow.test_utils.automated_test_util.generators import random\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_flow_global_var_all_dim_with_random_data(test_case, placement, sbp):\n    x = random_tensor(\n        ndim=2, dim0=random(1, 3).to(int) * 8, dim1=random(1, 3).to(int) * 8,\n    ).to_global(placement, sbp)\n    y = torch.var(x)\n    return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_flow_global_var_one_dim_with_random_data(test_case, placement, sbp):\n    x = random_tensor(\n        ndim=2, dim0=random(1, 3).to(int) * 8, dim1=random(1, 3).to(int) * 8,\n    ).to_global(placement, sbp)\n    y = torch.var(\n        x,\n        dim=random(low=0, high=2).to(int),\n        unbiased=random().to(bool),\n        keepdim=random().to(bool),\n    )\n    return y\n\n\n@autotest(n=1, auto_backward=True, check_graph=True)\ndef _test_flow_var_0_size_data_with_random_data(test_case, placement, sbp):\n    x = random_tensor(3, 8, 0, 8).to_global(placement, sbp)\n    y = torch.var(\n        x,\n        dim=random(low=0, high=3).to(int),\n        unbiased=random().to(bool),\n        keepdim=random().to(bool),\n    )\n    return y\n\n\nclass TestVar(flow.unittest.TestCase):\n    @globaltest\n    def test_flow_global_var_all_dim_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_flow_global_var_all_dim_with_random_data(\n                    test_case, placement, sbp\n                )\n\n    @globaltest\n    def test_flow_global_var_one_dim_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_flow_global_var_one_dim_with_random_data(\n                    test_case, placement, sbp\n                )\n\n    @globaltest\n    def test_flow_var_0_size_data_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2, valid_split_axis=[0]):\n                _test_flow_var_0_size_data_with_random_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_vector_matrix_product.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=1, check_graph=True)\ndef _test_vector_matrix_product(test_case, placement, sbp):\n    dim = random(1, 6)\n    vec = random_tensor(1, dim0=dim).to_global(placement=placement, sbp=sbp)\n    mat = random_tensor(2, dim0=dim, dim1=constant(4)).to_global(\n        placement=placement, sbp=sbp\n    )\n    return torch.matmul(vec, mat)\n\n\nclass TestGlobalVectorMatrixProduct(flow.unittest.TestCase):\n    @globaltest\n    def test_vector_matrix_product(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement):\n                _test_vector_matrix_product(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_view.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=1, check_graph=True)\ndef _test_global_view(test_case, placement, sbp):\n    x = random_tensor(ndim=2, dim0=8, dim1=32).to_global(placement, sbp)\n    y = x.view(8, 8, 2, -1)\n    return y\n\n\n@autotest(n=1, check_graph=True)\ndef _test_global_view_size(test_case, placement, sbp):\n    x = random_tensor(ndim=2, dim0=8, dim1=32).to_global(placement, sbp)\n    shape = torch.Size([8, 8, 2, -1])\n    y = x.view(shape)\n    return y\n\n\nclass TestGlobalView(flow.unittest.TestCase):\n    @globaltest\n    def test_global_view(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_view(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_view_size(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_view_size(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_weight_norm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=1, check_graph=False)\ndef _test_global_weight_norm_with_random_data(test_case, placement, sbp):\n    dim = random(-2, 2).to(int).value()\n    liner_model_torch = torch.nn.Linear(8, 16).to_global(placement, sbp)\n    m = torch.nn.utils.weight_norm(liner_model_torch, name=\"weight\", dim=dim)\n    return m.weight_g, m.weight_v\n\n\nclass TestGlobalWeightNorm(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 6 times in past week\")\n    @globaltest\n    def test_global_weight_norm_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_global_weight_norm_with_random_data(test_case, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_where.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=1, check_graph=True)\ndef _test_global_where(test_case, placement, sbp):\n    x = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)\n    y = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)\n    condition = random_tensor(ndim=2, dim0=8, dim1=16, high=2, dtype=int).to_global(\n        placement, sbp\n    )\n\n    condition = condition.to(torch.bool)\n\n    z = torch.where(condition, x, y)\n    return z\n\n\n@autotest(n=1, check_graph=True)\ndef _test_global_where_broadcast(test_case, placement, sbp):\n    x = random_tensor(ndim=3, dim0=8, dim1=16, dim2=1).to_global(placement, sbp)\n    y = random_tensor(ndim=3, dim0=8, dim1=16, dim2=8).to_global(placement, sbp)\n    condition = random_tensor(\n        ndim=3, dim0=8, dim1=16, dim2=1, high=2, dtype=int\n    ).to_global(placement, sbp)\n\n    condition = condition.to(torch.bool)\n\n    z = torch.where(condition, x, y)\n    return z\n\n\n@autotest(n=1, check_graph=True)\ndef _test_global_where_scalar(test_case, placement, sbp):\n    x = random_tensor(ndim=0).to_global(placement, sbp)\n    y = random_tensor(ndim=0).to_global(placement, sbp)\n    condition = random_tensor(ndim=0, high=2, dtype=int).to_global(placement, sbp)\n\n    condition = condition.to(torch.bool)\n\n    z = torch.where(condition, x, y)\n    return z\n\n\n# Close auto_backward because pytorch raise error:\n# PyTorch error: element 0 of tensors does not require grad and does not have a grad_fn\n# Not check graph because of one reason:\n# Reason 1, lazy tensor cannot call .numpy(), tensor.numpy() is not allowed to called in nn.Graph.build(*args) or called by lazy tensor.\n# Please refer to File \"python/oneflow/nn/modules/nonzero.py\", line 29, in nonzero_op. 
Because nonzero_op is called by where.\n@autotest(n=1, auto_backward=False, check_graph=\"ValidatedFalse\")\ndef _test_where_x_y_none(test_case, placement, sbp):\n    condition = random_tensor(ndim=2, dim0=8, dim1=8, low=-1, high=1).to_global(\n        placement, sbp\n    )\n    y = torch.where(condition)\n    return y[0], y[1]\n\n\n@autotest(n=1, check_graph=True)\ndef _test_global_where_tensor_with_0dim_data(test_case, placement, sbp):\n    cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)\n    x = random_tensor(ndim=0).to_global(placement, sbp)\n    y = random_tensor(ndim=0).to_global(placement, sbp)\n    return torch.where(cond > 0, x, y)\n\n\n@autotest(n=1, check_graph=True)\ndef _test_flow_where_tensor_broadcast_with_random_data(test_case, placement, sbp):\n    cond = random_tensor(ndim=3, dim0=8, dim1=16, dim2=8).to_global(placement, sbp)\n    x = random_tensor(ndim=3, dim0=8, dim1=1, dim2=8).to_global(placement, sbp)\n    y = random_tensor(ndim=3, dim0=8, dim1=16, dim2=1).to_global(placement, sbp)\n    return torch.where(cond > 0, x, y)\n\n\n@autotest(n=1, check_graph=True)\ndef _test_flow_where_scalar_x_with_random_data(test_case, placement, sbp):\n    cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)\n    x = random().to(float)\n    y = (\n        random_tensor(ndim=2, dim0=8, dim1=16, dtype=float)\n        .to_global(placement, sbp)\n        .to(torch.float64)\n    )\n    return torch.where(cond > 0, x, y)\n\n\n@autotest(n=1, check_graph=True)\ndef _test_flow_where_scalar_x_broadcast_with_random_data(test_case, placement, sbp):\n    cond = random_tensor(ndim=2, dim0=1, dim1=16).to_global(placement, sbp)\n    x = random().to(float)\n    y = (\n        random_tensor(ndim=2, dim0=8, dim1=1, dtype=float)\n        .to_global(placement, sbp)\n        .to(torch.float64)\n    )\n    return torch.where(cond > 0, x, y)\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_flow_where_scalar_x_int_with_random_data(test_case, placement, sbp):\n    cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)\n    x = random().to(int)\n    y = random_tensor(ndim=2, dim0=8, dim1=16, dtype=int).to_global(placement, sbp)\n    return torch.where(cond > 0, x, y)\n\n\n@autotest(n=1, check_graph=True)\ndef _test_flow_where_scalar_y_with_random_data(test_case, placement, sbp):\n    cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)\n    x = (\n        random_tensor(ndim=2, dim0=8, dim1=16, dtype=float)\n        .to_global(placement, sbp)\n        .to(torch.float64)\n    )\n    y = random().to(float)\n    return torch.where(cond > 0, x, y)\n\n\n@autotest(n=1, check_graph=True)\ndef _test_flow_where_scalar_y_broadcast_with_random_data(test_case, placement, sbp):\n    cond = random_tensor(ndim=2, dim0=1, dim1=16).to_global(placement, sbp)\n    x = (\n        random_tensor(ndim=2, dim0=8, dim1=1, dtype=float)\n        .to_global(placement, sbp)\n        .to(torch.float64)\n    )\n    y = random().to(float)\n    return torch.where(cond > 0, x, y)\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_flow_where_scalar_y_int_with_random_data(test_case, placement, sbp):\n    cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)\n    x = random_tensor(ndim=2, dim0=8, dim1=16, dtype=int).to_global(placement, sbp)\n    y = random().to(int)\n    return torch.where(cond > 0, x, y)\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef 
_test_flow_where_tensor_bool_with_random_data(test_case, placement, sbp):\n    cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)\n    x = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp).to(torch.bool)\n    y = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp).to(torch.bool)\n    return torch.where(cond > 0, x, y)\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_flow_where_tensor_broadcast_bool_with_random_data(test_case, placement, sbp):\n    cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)\n    x = random_tensor(ndim=2, dim0=1, dim1=16).to_global(placement, sbp).to(torch.bool)\n    y = random_tensor(ndim=2, dim0=8, dim1=1).to_global(placement, sbp).to(torch.bool)\n    return torch.where(cond > 0, x, y)\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_flow_where_scalar_x_bool_with_random_data(test_case, placement, sbp):\n    cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)\n    x = random().to(bool)\n    y = (\n        random_tensor(ndim=2, dim0=8, dim1=16, dtype=float)\n        .to_global(placement, sbp)\n        .to(torch.bool)\n    )\n    return torch.where(cond > 0, x, y)\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_flow_where_scalar_x_broadcast_bool_with_random_data(\n    test_case, placement, sbp\n):\n    cond = random_tensor(ndim=2, dim0=1, dim1=16).to_global(placement, sbp)\n    x = random().to(bool)\n    y = (\n        random_tensor(ndim=2, dim0=8, dim1=1, dtype=float)\n        .to_global(placement, sbp)\n        .to(torch.bool)\n    )\n    return torch.where(cond > 0, x, y)\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_flow_where_scalar_y_bool_with_random_data(test_case, placement, sbp):\n    cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)\n    x = (\n        random_tensor(ndim=2, dim0=8, dim1=16, dtype=float)\n        .to_global(placement, sbp)\n        .to(torch.bool)\n    )\n    y = random().to(bool)\n    return torch.where(cond > 0, x, y)\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_flow_where_scalar_y_broadcast_bool_with_random_data(\n    test_case, placement, sbp\n):\n    cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)\n    x = (\n        random_tensor(ndim=2, dim0=8, dim1=1, dtype=float)\n        .to_global(placement, sbp)\n        .to(torch.bool)\n    )\n    y = random().to(bool)\n    return torch.where(cond > 0, x, y)\n\n\n@autotest(n=1, auto_backward=False, check_graph=True)\ndef _test_flow_where_scalar_xy_bool_with_random_data(test_case, placement, sbp):\n    cond = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)\n    x = random().to(bool)\n    y = random().to(bool)\n    return torch.where(cond > 0, x, y)\n\n\nclass TestGlobalWhere(flow.unittest.TestCase):\n    @globaltest\n    def test_global_where(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_where(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_where_broadcast(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_global_where_broadcast(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_where_scalar(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, except_split=True):\n                
_test_global_where_scalar(test_case, placement, sbp)\n\n    @globaltest\n    def test_where_x_y_none(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_where_x_y_none(test_case, placement, sbp)\n\n    @globaltest\n    def test_global_where_tensor_with_0dim_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, except_split=True):\n                _test_global_where_tensor_with_0dim_data(test_case, placement, sbp)\n\n    @globaltest\n    def test_flow_where_tensor_broadcast_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=1):\n                _test_flow_where_tensor_broadcast_with_random_data(\n                    test_case, placement, sbp\n                )\n\n    @globaltest\n    def test_flow_where_scalar_x_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, except_split=True):\n                _test_flow_where_scalar_x_with_random_data(test_case, placement, sbp)\n\n    @globaltest\n    def test_flow_where_scalar_x_broadcast_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, except_split=True):\n                _test_flow_where_scalar_x_broadcast_with_random_data(\n                    test_case, placement, sbp\n                )\n\n    @globaltest\n    def test_flow_where_scalar_x_int_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, except_split=True):\n                _test_flow_where_scalar_x_int_with_random_data(\n                    test_case, placement, sbp\n                )\n\n    @globaltest\n    def test_flow_where_scalar_y_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, except_split=True):\n                _test_flow_where_scalar_y_with_random_data(test_case, placement, sbp)\n\n    @globaltest\n    def test_flow_where_scalar_y_broadcast_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, except_split=True):\n                _test_flow_where_scalar_y_broadcast_with_random_data(\n                    test_case, placement, sbp\n                )\n\n    @globaltest\n    def test_flow_where_scalar_y_int_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, except_split=True):\n                _test_flow_where_scalar_y_int_with_random_data(\n                    test_case, placement, sbp\n                )\n\n    @globaltest\n    def test_flow_where_tensor_bool_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=2):\n                _test_flow_where_tensor_bool_with_random_data(test_case, placement, sbp)\n\n    @globaltest\n    def test_flow_where_tensor_broadcast_bool_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, except_split=True):\n                _test_flow_where_tensor_broadcast_bool_with_random_data(\n                    test_case, placement, sbp\n                )\n\n    @globaltest\n    def test_flow_where_scalar_x_bool_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, except_split=True):\n           
     _test_flow_where_scalar_x_bool_with_random_data(\n                    test_case, placement, sbp\n                )\n\n    @globaltest\n    def test_flow_where_scalar_x_broadcast_bool_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, except_split=True):\n                _test_flow_where_scalar_x_broadcast_bool_with_random_data(\n                    test_case, placement, sbp\n                )\n\n    @globaltest\n    def test_flow_where_scalar_y_bool_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, except_split=True):\n                _test_flow_where_scalar_y_bool_with_random_data(\n                    test_case, placement, sbp\n                )\n\n    @globaltest\n    def test_flow_where_scalar_y_broadcast_bool_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, except_split=True):\n                _test_flow_where_scalar_y_broadcast_bool_with_random_data(\n                    test_case, placement, sbp\n                )\n\n    @globaltest\n    def test_flow_where_scalar_xy_bool_with_random_data(test_case):\n        for placement in all_placement():\n            for sbp in all_sbp(placement, except_split=True):\n                _test_flow_where_scalar_xy_bool_with_random_data(\n                    test_case, placement, sbp\n                )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_zeropad2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=1, check_graph=True)\ndef _test_global_ZeroPad2d(test_case, placement, sbp, padding):\n    x = random_tensor(ndim=4, dim0=8, dim1=16, dim2=8, dim3=8,).to_global(\n        placement, sbp\n    )\n    m = torch.nn.ZeroPad2d(padding)\n    y = m(x)\n    return y\n\n\nclass TestGlobalZeroPad2dModule(flow.unittest.TestCase):\n    @globaltest\n    def test_global_ZeroPad2d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"padding\"] = [2, (1, 1, 2, 2)]\n        for arg in GenArgList(arg_dict):\n            for placement in all_placement():\n                for sbp in all_sbp(placement, max_dim=4):\n                    _test_global_ZeroPad2d(test_case, placement, sbp, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_global_zeros_like.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_zeros_like_float(test_case, placement, sbp, shape, device):\n    x = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    x = x.to_global(placement=placement, sbp=sbp)\n    y = flow.zeros_like(x, placement=placement, sbp=sbp)\n    test_case.assertTrue(y.dtype is flow.float32)\n    test_case.assertTrue(y.shape == x.shape)\n    test_case.assertTrue(y.placement == placement)\n    y_numpy = np.zeros(x.numpy().shape)\n    test_case.assertTrue(np.array_equal(y.numpy(), y_numpy))\n\n\ndef _test_zeros_like_int(test_case, placement, sbp, shape, device):\n    x = flow.tensor(np.random.randn(*shape), dtype=flow.int, device=flow.device(device))\n    x = x.to_global(placement=placement, sbp=sbp)\n    y = flow.zeros_like(x, dtype=flow.int, placement=placement, sbp=sbp)\n    test_case.assertTrue(y.dtype is flow.int)\n    test_case.assertTrue(y.shape == x.shape)\n    test_case.assertTrue(y.placement == placement)\n    y_numpy = np.zeros(x.numpy().shape)\n    test_case.assertTrue(np.array_equal(y.numpy(), y_numpy))\n\n\nclass TestModule(flow.unittest.TestCase):\n    @globaltest\n    def test_zeros_like(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_zeros_like_float, _test_zeros_like_int]\n        arg_dict[\"shape\"] = [(8, 8), (8, 8, 4), (8, 8, 5, 6)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            for placement in all_placement():\n                for sbp in all_sbp(placement, max_dim=2):\n                    arg[0](test_case, placement, sbp, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_glu.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGluModule(flow.unittest.TestCase):\n    @autotest(n=5, check_graph=True)\n    def test_glu_module_with_random_data(test_case):\n        device = random_device()\n        dim = random(-3, 3).to(int)\n        m = torch.nn.functional.glu\n        x = random_tensor(ndim=3, dim0=2, dim1=4, dim2=6).to(device)\n        y = m(x, dim)\n        return y\n\n    @autotest(n=5, check_graph=True)\n    def test_glu_module_with_random_data(test_case):\n        device = random_device()\n        m = torch.nn.GLU()\n        m.train(random())\n        m.to(device)\n        x = random_tensor(ndim=3, dim0=2, dim1=4, dim2=6).to(device)\n        y = m(x)\n        return y\n\n    @profile(torch.nn.functional.glu)\n    def profile_glu(test_case):\n        input = torch.ones(1000, 1000)\n        torch.nn.functional.glu(input)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_gpt_data_loader.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport os\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass GPTDataLoader(flow.nn.Module):\n    def __init__(\n        self,\n        data_file_prefix=flow.unittest.dataset_dir(\n            \"Megatron-LM/dummy/gpt_sample_dataset_text_document\"\n        ),\n        seq_length=1024,\n        num_samples=648,\n        batch_size=8,\n        shuffle=True,\n        random_seed=12345,\n        device=None,\n        placement=None,\n        sbp=None,\n    ):\n        super().__init__()\n        self.loader_ = flow.nn.GPTIndexedBinDataReader(\n            data_file_prefix=data_file_prefix,\n            seq_length=seq_length,\n            num_samples=num_samples,\n            batch_size=batch_size,\n            shuffle=shuffle,\n            random_seed=random_seed,\n            device=device,\n            placement=placement,\n            sbp=sbp,\n        )\n\n    def forward(self):\n        return self.loader_()\n\n\nclass DataLoaderGraph(flow.nn.Graph):\n    def __init__(self, loader):\n        super().__init__()\n        self.loader_ = loader\n\n    def build(self):\n        return self.loader_()\n\n\n@unittest.skipIf(\n    os.getenv(\"ONEFLOW_TEST_GITHUB_HOSTED\"),\n    \"/dataset not available on GitHub hosted servers\",\n)\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass GPTDataLoaderDistributedTestCase(oneflow.unittest.TestCase):\n    def test_case1(test_case):\n        rank = flow.env.get_rank()\n        # print(\n        #     f\"GPTDataLoaderDistributedTestCase.test_case1 on rank {rank} {os.getpid()}\"\n        # )\n        eager_gpt_loader = GPTDataLoader(batch_size=4, device=flow.device(\"cpu\", rank))\n\n        global_gpt_loader = GPTDataLoader(\n            batch_size=8,\n            placement=flow.placement(\"cpu\", ranks=[0, 1]),\n            sbp=[flow.sbp.split(0)],\n        )\n        gpt_loader_graph = DataLoaderGraph(global_gpt_loader)\n\n        iteration = 2\n        for i in range(iteration):\n            tokens = eager_gpt_loader()\n            # print(\n            #     f\"rank {rank} tokens: {tokens.shape}, {tokens.dtype}, device: {tokens.device}\"\n            #     f\"\\n{tokens.numpy()}\"\n            # )\n\n            g_tokens = gpt_loader_graph()\n            # print(\n            #     f\"rank {rank} graph output tokens: {g_tokens.shape}, {g_tokens.dtype}\"\n            #     f\", placement: {g_tokens.placement}\"\n            #     f\"\\n{g_tokens.to_local().numpy()}\"\n            # )\n\n            # print(f\"{'-' * 20} rank {rank} iter {i} complete {'-' * 20}\")\n            test_case.assertTrue(\n                np.allclose(tokens.numpy(), g_tokens.to_local().numpy())\n            )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_greater.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_greater_normal(test_case, device):\n    input1 = flow.tensor(\n        np.array([1, 1, 4]).astype(np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    input2 = flow.tensor(\n        np.array([1, 2, 3]).astype(np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = flow.gt(input1, input2)\n    np_out = np.greater(input1.numpy(), input2.numpy())\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_greater_symbol(test_case, device):\n    input1 = flow.tensor(\n        np.array([1, 1, 4]).astype(np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    input2 = flow.tensor(\n        np.array([1, 2, 3]).astype(np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = input1 > input2\n    np_out = np.greater(input1.numpy(), input2.numpy())\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_greater_int_scalar(test_case, device):\n    np_arr = np.random.randn(2, 3, 4, 5)\n    input1 = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    input2 = 1\n    of_out = input1 > input2\n    np_out = np.greater(np_arr, input2)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_greater_int_tensor_int_scalar(test_case, device):\n    np_arr = np.random.randint(2, size=(2, 3, 4, 5))\n    input1 = flow.tensor(np_arr, dtype=flow.int, device=flow.device(device))\n    input2 = 1\n    of_out = input1 > input2\n    np_out = np.greater(np_arr, input2)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_greater_float_scalar(test_case, device):\n    np_arr = np.random.randn(3, 2, 5, 7)\n    input1 = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    input2 = 2.3\n    of_out = input1 > input2\n    np_out = np.greater(np_arr, input2)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGreater(flow.unittest.TestCase):\n    def test_greater(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_greater_normal,\n            _test_greater_symbol,\n            _test_greater_int_scalar,\n            _test_greater_int_tensor_int_scalar,\n            _test_greater_float_scalar,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_greater_with_random_data(test_case):\n        device 
= random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        x2 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        y = torch.gt(x1, oneof(x2, random().to(int), random().to(float)))\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_tensor_inplace_greater_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        x2 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        x1.gt_(oneof(x2, random().to(int), random().to(float)))\n        return x1\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_tensor_greater_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        x2 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        y1 = x1.gt(oneof(x2, random().to(int), random().to(float)))\n        y2 = x1 > x2\n        return (y1, y2)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_greater_with_0_size_data(test_case):\n        device = random_device()\n        x1 = random_tensor(4, 2, 3, 0, 5).to(device)\n        x2 = random_tensor(4, 2, 3, 0, 5).to(device)\n        y1 = torch.gt(x1, x2)\n        y2 = x1 > x2\n        return (y1, y2)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_greater_bool_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(len(shape), *shape, requires_grad=False).to(\n            device=device, dtype=torch.bool\n        )\n        x2 = random_tensor(len(shape), *shape, requires_grad=False).to(\n            device=device, dtype=torch.bool\n        )\n        y = torch.gt(x1, oneof(x2, random().to(int), random().to(float)))\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_greater_with_0dim_data(test_case):\n        device = random_device()\n        x1 = random_tensor(ndim=0).to(device)\n        x2 = random_tensor(ndim=0).to(device)\n        y1 = torch.gt(x1, x2)\n        y2 = x1 > x2\n        return (y1, y2)\n\n    @profile(torch.gt)\n    def profile_gt(test_case):\n        input = torch.ones(1000, 1000)\n        other = torch.ones(1000, 1000)\n        torch.gt(input, other)\n        torch.gt(input, 0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_greater_equal.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_greater_equal_normal(test_case, device):\n    input1 = flow.tensor(\n        np.array([1, 1, 4]).astype(np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    input2 = flow.tensor(\n        np.array([1, 2, 3]).astype(np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = flow.ge(input1, input2)\n    np_out = np.greater_equal(input1.numpy(), input2.numpy())\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_greater_equal_symbol(test_case, device):\n    input1 = flow.tensor(\n        np.array([1, 1, 4]).astype(np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    input2 = flow.tensor(\n        np.array([1, 2, 3]).astype(np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = input1 >= input2\n    np_out = np.greater_equal(input1.numpy(), input2.numpy())\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_greater_equal_int_scalar(test_case, device):\n    np_arr = np.random.randn(2, 3, 4, 5)\n    input1 = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    input2 = 1\n    of_out = input1 >= input2\n    np_out = np.greater_equal(np_arr, input2)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_greater_equal_int_tensor_int_scalr(test_case, device):\n    np_arr = np.random.randint(2, size=(2, 3, 4, 5))\n    input1 = flow.tensor(np_arr, dtype=flow.int, device=flow.device(device))\n    input2 = 1\n    of_out = input1 >= input2\n    np_out = np.greater_equal(np_arr, input2)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_greater_equal_float_scalar(test_case, device):\n    np_arr = np.random.randn(3, 2, 5, 7)\n    input1 = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    input2 = 2.3\n    of_out = input1 >= input2\n    np_out = np.greater_equal(np_arr, input2)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGreaterEqual(flow.unittest.TestCase):\n    def test_greter_equal(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_greater_equal_normal,\n            _test_greater_equal_symbol,\n            _test_greater_equal_int_scalar,\n            _test_greater_equal_int_tensor_int_scalr,\n            _test_greater_equal_float_scalar,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_grid_sample.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom random import randint\nfrom random import choice\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGridSample(flow.unittest.TestCase):\n    def test_grid_sample_4d(test_case):\n        input = flow.tensor(\n            np.arange(1.0, 11).reshape((1, 1, 2, 5)), dtype=flow.float32\n        )\n        np_grid = np.array(\n            [\n                [[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]],\n                [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]],\n            ]\n        ).reshape(1, 2, 5, 2)\n        grid = flow.tensor(np_grid, dtype=flow.float32)\n        groundtruth = np.reshape(\n            np.array([[0.0, 8.0, 5.0, 7.0, 9.0], [1.0, 8.0, 5.0, 8.0, 0.0]]),\n            (1, 1, 2, 5),\n        )\n        output = flow.nn.functional.grid_sample(\n            input, grid, mode=\"nearest\", padding_mode=\"zeros\", align_corners=True\n        )\n        test_case.assertTrue(\n            np.allclose(output.numpy(), groundtruth, rtol=1e-3, atol=1e-4)\n        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @autotest(rtol=1e-03, atol=1e-04, check_graph=True)\n    def test_flow_grid_sample_cudnn_with_random_data(test_case):\n        # cudnn only support 4D input, with mode = 'bilinear' && padding_mode = 'zeros' && align_corners\n        N = randint(1, 8)\n        C = randint(1, 8)\n        in_H = randint(1, 8)\n        in_W = randint(1, 8)\n        out_H = randint(1, 8)\n        out_W = randint(1, 8)\n        device = \"cuda\"\n        mode = \"bilinear\"\n        padding_mode = \"zeros\"\n        align_corners = True\n        theta = random_tensor(ndim=3, dim0=N, dim1=2, dim2=3).to(device)\n        grid = torch.nn.functional.affine_grid(\n            theta, (N, C, out_H, out_W), align_corners=align_corners\n        ).to(device)\n        input = random_tensor(ndim=4, dim0=N, dim1=C, dim2=in_H, dim3=in_W).to(device)\n        output = torch.nn.functional.grid_sample(\n            input,\n            grid,\n            mode=mode,\n            padding_mode=padding_mode,\n            align_corners=align_corners,\n        )\n        return output\n\n    # This test may fail due to using ::floor in backward\n    # floor(1.99999988) = 1 and floor(2.000000) = 2, then select differente images pixel\n    @autotest(\n        auto_backward=False,\n        rtol=1e-03,\n        atol=1e-04,\n        check_graph=True,\n        check_allclose=False,\n    )\n    def test_flow_grid_sample_4d_with_random_data(test_case):\n        N = randint(1, 8)\n        C = randint(1, 8)\n        in_H = randint(1, 8)\n        in_W = randint(1, 8)\n        out_H = randint(1, 8)\n        out_W = randint(1, 8)\n        device = random_device()\n        mode = 
choice([\"bilinear\", \"nearest\", \"bicubic\"])\n        padding_mode = choice([\"zeros\", \"border\", \"reflection\"])\n        align_corners = choice([True, False])\n        theta = random_tensor(ndim=3, dim0=N, dim1=2, dim2=3).to(device)\n        grid = torch.nn.functional.affine_grid(\n            theta, (N, C, out_H, out_W), align_corners=align_corners\n        ).to(device)\n        input = random_tensor(ndim=4, dim0=N, dim1=C, dim2=in_H, dim3=in_W).to(device)\n        output = torch.nn.functional.grid_sample(\n            input,\n            grid,\n            mode=mode,\n            padding_mode=padding_mode,\n            align_corners=align_corners,\n        )\n        return output\n\n    @autotest(auto_backward=False, rtol=1e-03, atol=1e-03, check_graph=True)\n    def test_flow_grid_sample_5d_with_random_data(test_case):\n        N = randint(1, 8)\n        C = randint(1, 8)\n        in_D = randint(1, 8)\n        in_H = randint(1, 8)\n        in_W = randint(1, 8)\n        out_D = randint(1, 8)\n        out_H = randint(1, 8)\n        out_W = randint(1, 8)\n        device = random_device()\n        mode = choice([\"bilinear\", \"nearest\"])\n        padding_mode = choice([\"zeros\", \"border\", \"reflection\"])\n        align_corners = choice([True, False])\n        theta = random_tensor(ndim=3, dim0=N, dim1=3, dim2=4).to(device)\n        grid = torch.nn.functional.affine_grid(\n            theta, (N, C, out_D, out_H, out_W), align_corners=align_corners\n        ).to(device)\n        input = random_tensor(\n            ndim=5, dim0=N, dim1=C, dim2=in_D, dim3=in_H, dim4=in_W\n        ).to(device)\n        output = torch.nn.functional.grid_sample(\n            input,\n            grid,\n            mode=mode,\n            padding_mode=padding_mode,\n            align_corners=align_corners,\n        )\n        return output\n\n    @profile(torch.nn.functional.grid_sample)\n    def profile_grid_sample(test_case):\n        input = torch.ones(32, 3, 128, 128)\n        grid = torch.ones(32, 64, 64, 2)\n        torch.nn.functional.grid_sample(input, grid)\n        torch.nn.functional.grid_sample(input, grid, align_corners=True)\n        torch.nn.functional.grid_sample(input, grid, mode=\"nearest\", align_corners=True)\n        torch.nn.functional.grid_sample(input, grid, mode=\"bicubic\", align_corners=True)\n        torch.nn.functional.grid_sample(input, grid, padding_mode=\"border\")\n        torch.nn.functional.grid_sample(input, grid, padding_mode=\"reflection\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_grouped_matmul_bias.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nimport math\nimport os\n\nimport oneflow as flow\n\n\ndef _ref(xs, weights, biases):\n    if biases is None:\n        return [\n            flow._C.matmul(x, w, transpose_a=False, transpose_b=True)\n            for x, w in zip(xs, weights)\n        ]\n    else:\n        return [\n            flow._C.matmul(x, w, transpose_a=False, transpose_b=True) + b\n            for x, w, b in zip(xs, weights, biases)\n        ]\n\n\ndef _grouped(xs, weights, biases):\n    if biases is None:\n        return flow._C.grouped_matmul(xs, weights)\n    else:\n        return flow._C.grouped_matmul_bias(xs, weights, biases)\n\n\ndef _test_grouped_matmul_bias(test_case, dtype, problems, bias):\n\n    xs = [\n        flow.randn((m, k), device=\"cuda\", dtype=dtype) / 10.0 for (m, n, k) in problems\n    ]\n    ws = [\n        flow.randn((n, k), device=\"cuda\", dtype=dtype) / 10.0 for (m, n, k) in problems\n    ]\n    bs = [flow.randn((n), device=\"cuda\", dtype=dtype) / 10.0 for (m, n, k) in problems]\n\n    ref_out = _ref(xs, ws, bs if bias else None)\n    grouped_out = _grouped(xs, ws, bs if bias else None)\n    for (ref_y, grouped_y) in zip(ref_out, grouped_out):\n        test_case.assertTrue(\n            np.allclose(ref_y.numpy(), grouped_y.numpy(), atol=1e-2, rtol=1e-2)\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestGroupedMatmulBias(flow.unittest.TestCase):\n    def test_grouped_matmul_bias(test_case):\n        problems = [(2, 1280, 1280)] * 12 + [(2, 1280, 640)] * 4 + [(2, 1280, 320)] * 5\n        _test_grouped_matmul_bias(test_case, flow.float16, problems, True)\n        _test_grouped_matmul_bias(test_case, flow.float16, problems, False)\n        problems = (\n            [(2 * 77, 768, 1280)] * 6\n            + [(2 * 77, 768, 640)] * 5\n            + [(2 * 77, 768, 320)] * 5\n        )\n        _test_grouped_matmul_bias(test_case, flow.float16, problems, True)\n        _test_grouped_matmul_bias(test_case, flow.float16, problems, False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_groupnorm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_groupnorm(test_case, device):\n    input_arr = np.array(\n        [\n            [\n                [\n                    [-0.8791, 0.2553, 0.7403, -0.2859],\n                    [0.8006, -1.7701, -0.9617, 0.1705],\n                    [0.2842, 1.7825, 0.3365, -0.8525],\n                ],\n                [\n                    [0.7332, -0.0737, 0.7245, -0.6551],\n                    [1.4461, -0.1827, 0.9737, -2.1571],\n                    [0.4657, 0.7244, 0.3378, 0.1775],\n                ],\n            ],\n            [\n                [\n                    [1.8896, 1.8686, 0.1896, 0.9817],\n                    [-0.0671, 1.5569, 1.1449, 0.0086],\n                    [-0.9468, -0.0124, 1.3227, -0.6567],\n                ],\n                [\n                    [-0.8472, 1.3012, -1.1065, 0.9348],\n                    [1.0346, 1.5703, 0.2419, -0.7048],\n                    [0.6957, -0.4523, -0.8819, 1.0164],\n                ],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    output = np.array(\n        [\n            [\n                [\n                    [-1.0548115, 0.18125379, 0.7097197, -0.4084487],\n                    [0.77542377, -2.0256634, -1.1448141, 0.08885399],\n                    [0.21274385, 1.845322, 0.26973096, -1.0258276],\n                ],\n                [\n                    [0.7019834, -0.17723128, 0.6925037, -0.81073654],\n                    [1.4787737, -0.2959999, 0.96403706, -2.4473464],\n                    [0.4105099, 0.69239473, 0.2711475, 0.09648134],\n                ],\n            ],\n            [\n                [\n                    [1.5438884, 1.5218256, -0.24213786, 0.5900453],\n                    [-0.5118278, 1.1943525, 0.76150376, -0.43229714],\n                    [-1.4360437, -0.4543598, 0.94830114, -1.1312639],\n                ],\n                [\n                    [-1.3314037, 0.9257132, -1.6038253, 0.54077196],\n                    [0.6456222, 1.2084305, -0.18719131, -1.1817979],\n                    [0.28957263, -0.91652036, -1.3678597, 0.6265012],\n                ],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device))\n    m = flow.nn.GroupNorm(num_groups=1, num_channels=2).to(device=flow.device(device))\n    y = m(x)\n    test_case.assertTrue(np.allclose(y.numpy(), output, 1e-03, 1e-03))\n\n\ndef _test_groupnorm_3d(test_case, device):\n    input_arr = np.array(\n        [\n            [\n                [\n                    [\n                        [1.04569761, 0.22863248, 1.42439335, 1.62249689],\n                        [-0.80578825, 
-0.27276461, 1.04556507, 0.56864134],\n                        [-1.24085419, -1.23960097, 0.33451416, -1.84820402],\n                    ],\n                    [\n                        [-1.511261, 1.06157517, -0.26715858, -1.32888141],\n                        [1.17976881, -0.07931171, 0.33910684, -1.93458573],\n                        [-1.72659647, 0.79049652, 0.39102785, -1.16264882],\n                    ],\n                ],\n                [\n                    [\n                        [0.30067973, -1.2912226, -0.61508225, 0.56454001],\n                        [0.87074187, -1.69257376, 0.36119148, -0.31014289],\n                        [0.20776964, 1.26195488, -1.37122193, -0.17945234],\n                    ],\n                    [\n                        [-0.31112407, -0.80682631, 0.8233194, 0.6384975],\n                        [0.57617527, 0.45505028, 1.68286151, -1.09590744],\n                        [-1.18127546, -1.07529277, 0.52779943, 1.21755926],\n                    ],\n                ],\n            ],\n            [\n                [\n                    [\n                        [-0.12832351, 1.05625455, -0.23253249, -0.64747611],\n                        [-0.00738123, -1.41390089, -1.92664144, -0.21427625],\n                        [-0.94631219, -0.86493989, 0.21026905, 0.24989732],\n                    ],\n                    [\n                        [1.3859182, 1.72002107, 0.50091892, 1.04198896],\n                        [0.71694594, 1.66417023, -1.63030052, 0.77182641],\n                        [0.71545083, 1.96458366, -1.99031931, 1.3196714],\n                    ],\n                ],\n                [\n                    [\n                        [1.80091702, 0.02834973, 0.82259214, -1.05597501],\n                        [-0.58212207, 0.44205949, -0.14740003, -0.994508],\n                        [1.14678114, -0.39196097, 1.2554798, -0.41829324],\n                    ],\n                    [\n                        [-1.0153903, -0.25755713, -1.81756333, -1.06781159],\n                        [1.79680841, -1.9107133, -0.64325796, -1.94640775],\n                        [1.30671156, 1.20445339, -1.26262901, -0.79494188],\n                    ],\n                ],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    output = np.array(\n        [\n            [\n                [\n                    [\n                        [1.0670303, 0.3324034, 1.4075173, 1.5856332],\n                        [-0.5976489, -0.11840499, 1.0669112, 0.6381069],\n                        [-0.9888186, -0.9876919, 0.42760208, -1.5348896],\n                    ],\n                    [\n                        [-1.2319425, 1.0813059, -0.11336456, -1.0679643],\n                        [1.1875744, 0.05552938, 0.43173137, -1.6125557],\n                        [-1.4255517, 0.8375778, 0.4784138, -0.9185038],\n                    ],\n                ],\n                [\n                    [\n                        [0.3447361, -1.3750811, -0.6446106, 0.62979853],\n                        [0.9606047, -1.8086823, 0.41011015, -0.3151683],\n                        [0.24436034, 1.3832531, -1.4615086, -0.17397629],\n                    ],\n                    [\n                        [-0.31622827, -0.8517619, 0.9093717, 0.7096987],\n                        [0.6423687, 0.51151085, 1.8379811, -1.1640717],\n                        [-1.2562994, -1.1418006, 0.59010565, 1.3352901],\n                    ],\n                ],\n            ],\n            [\n               
 [\n                    [\n                        [-0.23265934, 0.8016156, -0.32364592, -0.6859402],\n                        [-0.12706259, -1.3551185, -1.802801, -0.30770612],\n                        [-0.946859, -0.8758114, 0.06297152, 0.09757163],\n                    ],\n                    [\n                        [1.0894505, 1.3811613, 0.3167428, 0.78916013],\n                        [0.50535965, 1.3323971, -1.5440607, 0.55327666],\n                        [0.50405425, 1.5946931, -1.8583992, 1.0316093],\n                    ],\n                ],\n                [\n                    [\n                        [1.7506906, 0.19012147, 0.8893728, -0.7645185],\n                        [-0.3473382, 0.5543517, 0.03539129, -0.71040297],\n                        [1.174789, -0.17992027, 1.2704874, -0.20310321],\n                    ],\n                    [\n                        [-0.7287877, -0.06159106, -1.4350212, -0.7749395],\n                        [1.7470733, -1.5170306, -0.40116227, -1.548456],\n                        [1.3155918, 1.2255636, -0.9464568, -0.53470486],\n                    ],\n                ],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device))\n    m = flow.nn.GroupNorm(num_groups=2, num_channels=2, affine=False).to(\n        device=flow.device(device)\n    )\n    y = m(x)\n    test_case.assertTrue(np.allclose(y.numpy(), output, 1e-03, 1e-03))\n\n\ndef _test_groupnorm_backward(test_case, device):\n    input_arr = np.array(\n        [\n            [\n                [\n                    [-0.8791, 0.2553, 0.7403, -0.2859],\n                    [0.8006, -1.7701, -0.9617, 0.1705],\n                    [0.2842, 1.7825, 0.3365, -0.8525],\n                ],\n                [\n                    [0.7332, -0.0737, 0.7245, -0.6551],\n                    [1.4461, -0.1827, 0.9737, -2.1571],\n                    [0.4657, 0.7244, 0.3378, 0.1775],\n                ],\n            ],\n            [\n                [\n                    [1.8896, 1.8686, 0.1896, 0.9817],\n                    [-0.0671, 1.5569, 1.1449, 0.0086],\n                    [-0.9468, -0.0124, 1.3227, -0.6567],\n                ],\n                [\n                    [-0.8472, 1.3012, -1.1065, 0.9348],\n                    [1.0346, 1.5703, 0.2419, -0.7048],\n                    [0.6957, -0.4523, -0.8819, 1.0164],\n                ],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    x = flow.tensor(\n        input_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    m = flow.nn.GroupNorm(num_groups=1, num_channels=2).to(device=flow.device(device))\n    y = m(x)\n    z = y.sum()\n    z.backward()\n    test_case.assertTrue(\n        np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-03, 1e-03)\n    )\n\n\ndef _test_groupnorm_backward_fp16(test_case, device):\n    input_arr = np.array(\n        [\n            [\n                [\n                    [-0.8791, 0.2553, 0.7403, -0.2859],\n                    [0.8006, -1.7701, -0.9617, 0.1705],\n                    [0.2842, 1.7825, 0.3365, -0.8525],\n                ],\n                [\n                    [0.7332, -0.0737, 0.7245, -0.6551],\n                    [1.4461, -0.1827, 0.9737, -2.1571],\n                    [0.4657, 0.7244, 0.3378, 0.1775],\n                ],\n            ],\n            [\n                [\n                    [1.8896, 1.8686, 0.1896, 0.9817],\n                    
[-0.0671, 1.5569, 1.1449, 0.0086],\n                    [-0.9468, -0.0124, 1.3227, -0.6567],\n                ],\n                [\n                    [-0.8472, 1.3012, -1.1065, 0.9348],\n                    [1.0346, 1.5703, 0.2419, -0.7048],\n                    [0.6957, -0.4523, -0.8819, 1.0164],\n                ],\n            ],\n        ],\n        dtype=np.float16,\n    )\n    x = flow.tensor(\n        input_arr, dtype=flow.float16, device=flow.device(device), requires_grad=True\n    )\n    m = (\n        flow.nn.GroupNorm(num_groups=1, num_channels=2)\n        .to(device=flow.device(device))\n        .to(flow.float16)\n    )\n    y = m(x)\n    z = y.sum()\n    z.backward()\n    test_case.assertTrue(\n        np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-03, 1e-03)\n    )\n\n\ndef _test_groupnorm_backward_3d(test_case, device):\n    input_arr = np.array(\n        [\n            [\n                [\n                    [\n                        [1.04569761, 0.22863248, 1.42439335, 1.62249689],\n                        [-0.80578825, -0.27276461, 1.04556507, 0.56864134],\n                        [-1.24085419, -1.23960097, 0.33451416, -1.84820402],\n                    ],\n                    [\n                        [-1.511261, 1.06157517, -0.26715858, -1.32888141],\n                        [1.17976881, -0.07931171, 0.33910684, -1.93458573],\n                        [-1.72659647, 0.79049652, 0.39102785, -1.16264882],\n                    ],\n                ],\n                [\n                    [\n                        [0.30067973, -1.2912226, -0.61508225, 0.56454001],\n                        [0.87074187, -1.69257376, 0.36119148, -0.31014289],\n                        [0.20776964, 1.26195488, -1.37122193, -0.17945234],\n                    ],\n                    [\n                        [-0.31112407, -0.80682631, 0.8233194, 0.6384975],\n                        [0.57617527, 0.45505028, 1.68286151, -1.09590744],\n                        [-1.18127546, -1.07529277, 0.52779943, 1.21755926],\n                    ],\n                ],\n            ],\n            [\n                [\n                    [\n                        [-0.12832351, 1.05625455, -0.23253249, -0.64747611],\n                        [-0.00738123, -1.41390089, -1.92664144, -0.21427625],\n                        [-0.94631219, -0.86493989, 0.21026905, 0.24989732],\n                    ],\n                    [\n                        [1.3859182, 1.72002107, 0.50091892, 1.04198896],\n                        [0.71694594, 1.66417023, -1.63030052, 0.77182641],\n                        [0.71545083, 1.96458366, -1.99031931, 1.3196714],\n                    ],\n                ],\n                [\n                    [\n                        [1.80091702, 0.02834973, 0.82259214, -1.05597501],\n                        [-0.58212207, 0.44205949, -0.14740003, -0.994508],\n                        [1.14678114, -0.39196097, 1.2554798, -0.41829324],\n                    ],\n                    [\n                        [-1.0153903, -0.25755713, -1.81756333, -1.06781159],\n                        [1.79680841, -1.9107133, -0.64325796, -1.94640775],\n                        [1.30671156, 1.20445339, -1.26262901, -0.79494188],\n                    ],\n                ],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    x = flow.tensor(\n        input_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    m = flow.nn.GroupNorm(num_groups=2, num_channels=2, 
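 # pure normalization (affine=False): each group of y sums to ~0,\n        # so z = y.sum() is nearly constant and x.grad is expected to be ~0\n        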
affine=False).to(\n        device=flow.device(device)\n    )\n    y = m(x)\n    z = y.sum()\n    z.backward()\n    test_case.assertTrue(\n        np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-03, 1e-03)\n    )\n\n\ndef _test_groupnorm_backward_3d_fp16(test_case, device):\n    input_arr = np.array(\n        [\n            [\n                [\n                    [\n                        [1.04569761, 0.22863248, 1.42439335, 1.62249689],\n                        [-0.80578825, -0.27276461, 1.04556507, 0.56864134],\n                        [-1.24085419, -1.23960097, 0.33451416, -1.84820402],\n                    ],\n                    [\n                        [-1.511261, 1.06157517, -0.26715858, -1.32888141],\n                        [1.17976881, -0.07931171, 0.33910684, -1.93458573],\n                        [-1.72659647, 0.79049652, 0.39102785, -1.16264882],\n                    ],\n                ],\n                [\n                    [\n                        [0.30067973, -1.2912226, -0.61508225, 0.56454001],\n                        [0.87074187, -1.69257376, 0.36119148, -0.31014289],\n                        [0.20776964, 1.26195488, -1.37122193, -0.17945234],\n                    ],\n                    [\n                        [-0.31112407, -0.80682631, 0.8233194, 0.6384975],\n                        [0.57617527, 0.45505028, 1.68286151, -1.09590744],\n                        [-1.18127546, -1.07529277, 0.52779943, 1.21755926],\n                    ],\n                ],\n            ],\n            [\n                [\n                    [\n                        [-0.12832351, 1.05625455, -0.23253249, -0.64747611],\n                        [-0.00738123, -1.41390089, -1.92664144, -0.21427625],\n                        [-0.94631219, -0.86493989, 0.21026905, 0.24989732],\n                    ],\n                    [\n                        [1.3859182, 1.72002107, 0.50091892, 1.04198896],\n                        [0.71694594, 1.66417023, -1.63030052, 0.77182641],\n                        [0.71545083, 1.96458366, -1.99031931, 1.3196714],\n                    ],\n                ],\n                [\n                    [\n                        [1.80091702, 0.02834973, 0.82259214, -1.05597501],\n                        [-0.58212207, 0.44205949, -0.14740003, -0.994508],\n                        [1.14678114, -0.39196097, 1.2554798, -0.41829324],\n                    ],\n                    [\n                        [-1.0153903, -0.25755713, -1.81756333, -1.06781159],\n                        [1.79680841, -1.9107133, -0.64325796, -1.94640775],\n                        [1.30671156, 1.20445339, -1.26262901, -0.79494188],\n                    ],\n                ],\n            ],\n        ],\n        dtype=np.float16,\n    )\n    x = flow.tensor(\n        input_arr, dtype=flow.float16, device=flow.device(device), requires_grad=True\n    )\n    m = (\n        flow.nn.GroupNorm(num_groups=2, num_channels=2, affine=False)\n        .to(device=flow.device(device))\n        .to(flow.float16)\n    )\n    y = m(x)\n    z = y.sum()\n    z.backward()\n    test_case.assertTrue(\n        np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-03, 1e-03)\n    )\n\n\ndef _test_groupnorm_nhwc(test_case, shape, num_groups):\n    (n, c, h, w) = shape\n    x = flow.tensor(\n        np.random.uniform(low=0.0, high=1.0, size=shape).astype(np.float32)\n    ).to(\"cuda\")\n    gamma = flow.tensor(\n        np.random.uniform(low=0.0, high=1.0, size=(c)).astype(np.float32)\n  
  ).to(\"cuda\")\n    beta = flow.tensor(\n        np.random.uniform(low=0.0, high=1.0, size=(c)).astype(np.float32)\n    ).to(\"cuda\")\n    y = flow._C.group_norm(x, gamma, beta, True, num_groups, 1e-5)\n    x_nhwc = x.permute(0, 2, 3, 1).contiguous()\n    y_nhwc = flow._C.group_norm(\n        x_nhwc, gamma, beta, True, num_groups, 1e-5, \"channels_last\"\n    )\n    test_case.assertTrue(\n        np.allclose(y_nhwc.permute(0, 3, 1, 2).numpy(), y, 1e-03, 1e-03)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGroupNorm(flow.unittest.TestCase):\n    def test_groupnorm(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_groupnorm,\n            _test_groupnorm_3d,\n            _test_groupnorm_backward,\n            _test_groupnorm_backward_3d,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_groupnorm_grad_fp16(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_groupnorm_backward_fp16,\n            _test_groupnorm_backward_3d_fp16,\n        ]\n        # cpu test will raise error: var only support floating point dtypes\n        # https://github.com/Oneflow-Inc/oneflow/issues/9559\n        # arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"device\"] = [\"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(rtol=1e-03, atol=1e-03, check_graph=True)\n    def test_group_norm_with_random_data(test_case):\n        channels = random(5, 20)\n        m = torch.nn.GroupNorm(\n            num_groups=random(1, 5),\n            num_channels=channels,\n            eps=random(0, 1) | nothing(),\n            affine=random(),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=4, dim1=channels).to(device)\n        y = m(x)\n        return y\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_groupnorm_nhwc(test_case):\n        _test_groupnorm_nhwc(test_case, (16, 64, 128, 128), 32)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_groupwise_quantization.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nimport math\nimport os\n\nimport oneflow as flow\n\n\ndef _pack_int8_to_int4(x):\n    np_x = x.numpy()\n    l = np_x[..., 0::2]\n    r = np_x[..., 1::2]\n    l = np.left_shift(l, 4)\n    if x.dtype is flow.int8:\n        r = np.bitwise_and(r, np.int8(0xF))\n    packed = flow.tensor(np.bitwise_or(l, r), device=x.device)\n    return packed\n\n\ndef _unpack_int4_to_int8(x):\n    np_x = x.numpy()\n    l = np.right_shift(np_x, 4).reshape(x.shape + (1,))\n    r = np.right_shift(np.left_shift(np_x, 4), 4).reshape(x.shape + (1,))\n    unpacked = np.concatenate((l, r), -1).reshape(x.shape[0:-1] + (x.shape[-1] * 2,))\n    unpacked = flow.tensor(unpacked, device=x.device)\n    return unpacked\n\n\ndef _quantize(num_bits, symmetric, x, group_dim, group_size, quant_type):\n    x_float = x.float()\n    x_reshaped = x_float.reshape(\n        x.shape[:group_dim]\n        + (x.shape[group_dim] // group_size, group_size)\n        + x.shape[group_dim + 1 :]\n    )\n    if symmetric:\n        signed_max = float(2 ** (num_bits - 1)) - 1\n        offset = signed_max if quant_type is flow.uint8 else 0.0\n        scale_float = (\n            x_reshaped.abs().max(dim=group_dim + 1, keepdim=True).values / signed_max\n        )\n        quantized = (\n            flow.round(x_reshaped / scale_float + offset)\n            .reshape(x.shape)\n            .to(quant_type)\n        )\n        if num_bits == 4:\n            quantized = _pack_int8_to_int4(quantized)\n        return (quantized, scale_float.squeeze(group_dim + 1).to(x.dtype), None)\n    else:\n        unsigned_max = float(2 ** num_bits) - 1\n        mn = x_reshaped.min(dim=group_dim + 1, keepdim=True).values\n        mx = x_reshaped.max(dim=group_dim + 1, keepdim=True).values\n        scale_float = (mx - mn) / unsigned_max\n        quantized = (\n            flow.round((x_reshaped - mn) / scale_float).reshape(x.shape).to(flow.uint8)\n        )\n        if num_bits == 4:\n            quantized = _pack_int8_to_int4(quantized)\n        return (\n            quantized,\n            scale_float.squeeze(group_dim + 1).to(x.dtype),\n            mn.squeeze(group_dim + 1).to(x.dtype),\n        )\n\n\ndef _dequantize_ref(num_bits, symmetric, quantized, scale, zero, group_dim, group_size):\n    if num_bits == 4:\n        quantized = _unpack_int4_to_int8(quantized)\n    scale_reshaped = scale.unsqueeze(group_dim + 1)\n    quantized_reshaped = quantized.reshape(\n        quantized.shape[:group_dim]\n        + (quantized.shape[group_dim] // group_size, group_size)\n        + quantized.shape[group_dim + 1 :]\n    )\n    if symmetric:\n        offset = (\n            float(2 ** (num_bits - 1)) - 1 if quantized.dtype is flow.uint8 else 0.0\n        )\n        dequantized = (quantized_reshaped.to(scale.dtype) - offset) * 
scale_reshaped\n    else:\n        zero_reshaped = zero.unsqueeze(group_dim + 1)\n        dequantized = (\n            zero_reshaped + quantized_reshaped.to(scale.dtype) * scale_reshaped\n        )\n    return dequantized.reshape(quantized.shape)\n\n\ndef _dequantize(num_bits, symmetric, x, scale, zero, group_dim, group_size):\n    return flow._C.groupwise_dequantize(\n        x,\n        scale=scale,\n        zero=zero,\n        group_dim=group_dim,\n        group_size=group_size,\n        num_bits=num_bits,\n        symmetric=symmetric,\n    )\n\n\ndef _test_dequantize(test_case, num_bits, shape, group_dim, group_size):\n\n    for dtype in [flow.float, flow.float16]:\n        x = flow.randn(shape, device=\"cuda\", dtype=flow.float,).to(dtype)\n        for symmetric in [True, False]:\n            for quant_type in [flow.int8, flow.uint8] if symmetric else [flow.uint8]:\n                quantized, scale, zero = _quantize(\n                    num_bits, symmetric, x, group_dim, group_size, quant_type\n                )\n                dequantized = _dequantize(\n                    num_bits, symmetric, quantized, scale, zero, group_dim, group_size\n                )\n                dequantized_ref = _dequantize_ref(\n                    num_bits, symmetric, quantized, scale, zero, group_dim, group_size,\n                )\n                test_case.assertTrue(\n                    np.allclose(dequantized_ref, dequantized, atol=1e-2, rtol=1e-2)\n                )\n\n\ndef _test_fused_linear(test_case, num_bits, m, k, n, group_dim, group_size):\n    for dtype in [flow.float16, flow.float]:\n        x = flow.randn((m, k), device=\"cuda\", dtype=flow.float,).to(dtype) / 10\n        w = flow.randn((n, k), device=\"cuda\", dtype=flow.float,).to(dtype) / 10\n        b = flow.randn((n), device=\"cuda\", dtype=flow.float,).to(dtype) / 10\n\n        for symmetric in [True, False]:\n            for quant_type in [flow.int8, flow.uint8] if symmetric else [flow.uint8]:\n                w_quantized, w_scale, w_zero = _quantize(\n                    num_bits, symmetric, w, group_dim, group_size, quant_type\n                )\n\n                fused_out = flow._C.fused_linear_with_groupwise_quantized_weight(\n                    x=x,\n                    w=w_quantized,\n                    w_scale=w_scale,\n                    w_zero=w_zero,\n                    b=b,\n                    num_bits=num_bits,\n                    symmetric=symmetric,\n                    group_dim=group_dim,\n                    group_size=group_size,\n                )\n                ref = (\n                    flow.matmul(\n                        x,\n                        _dequantize(\n                            num_bits,\n                            symmetric,\n                            w_quantized,\n                            w_scale,\n                            w_zero,\n                            group_dim,\n                            group_size,\n                        ).t(),\n                    )\n                    + b\n                )\n\n                test_case.assertTrue(np.allclose(ref, fused_out, atol=1e-2, rtol=1e-2))\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestGroupWiseQuantization(flow.unittest.TestCase):\n    def test_dequantize(test_case):\n        _test_dequantize(test_case, 8, (128, 256), 0, 128)\n        _test_dequantize(test_case, 8, (64, 128, 256), 0, 64)\n        _test_dequantize(test_case, 8, 
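 # remaining args: shape, group_dim, group_size\n        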
(64, 128, 256), 1, 128)\n        _test_dequantize(test_case, 8, (64, 128, 256), 2, 256)\n        _test_dequantize(test_case, 8, (63, 127, 255), 0, 63)\n        _test_dequantize(test_case, 8, (63, 127, 255), 1, 127)\n        _test_dequantize(test_case, 8, (63, 127, 255), 2, 255)\n        _test_dequantize(test_case, 8, (128, 256), 1, 256 // 4)\n        _test_dequantize(test_case, 8, (128, 256), 0, 128 // 4)\n        _test_dequantize(test_case, 8, (64, 128, 256), 0, 64 // 4)\n        _test_dequantize(test_case, 8, (64, 128, 256), 1, 128 // 4)\n        _test_dequantize(test_case, 8, (64, 128, 256), 2, 256 // 4)\n\n        _test_dequantize(test_case, 4, (128, 256), 1, 256)\n        _test_dequantize(test_case, 4, (128, 256), 0, 128)\n        _test_dequantize(test_case, 4, (64, 128, 256), 0, 64)\n        _test_dequantize(test_case, 4, (64, 128, 256), 1, 128)\n        _test_dequantize(test_case, 4, (64, 128, 256), 2, 256)\n        _test_dequantize(test_case, 4, (128, 256), 1, 256 // 4)\n        _test_dequantize(test_case, 4, (128, 256), 0, 128 // 4)\n        _test_dequantize(test_case, 4, (64, 128, 256), 0, 64 // 4)\n        _test_dequantize(test_case, 4, (64, 128, 256), 1, 128 // 4)\n        _test_dequantize(test_case, 4, (64, 128, 256), 2, 256 // 4)\n\n    def test_fused_linear(test_case):\n        _test_fused_linear(test_case, 8, 1, 64, 128, 0, 128)\n        _test_fused_linear(test_case, 8, 1, 64, 128, 1, 64)\n        _test_fused_linear(test_case, 8, 16, 64, 128, 0, 128)\n        _test_fused_linear(test_case, 8, 16, 64, 128, 1, 64)\n        _test_fused_linear(test_case, 8, 1, 63, 127, 0, 127)\n        _test_fused_linear(test_case, 8, 1, 63, 127, 1, 63)\n        _test_fused_linear(test_case, 8, 1, 256, 512, 0, 64)\n        _test_fused_linear(test_case, 8, 1, 256, 512, 1, 64)\n        _test_fused_linear(test_case, 4, 1, 256, 512, 0, 512)\n        _test_fused_linear(test_case, 4, 1, 256, 512, 1, 256)\n        _test_fused_linear(test_case, 4, 1, 256, 512, 0, 64)\n        _test_fused_linear(test_case, 4, 1, 256, 512, 1, 64)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_gumbel_softmax.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList, type_name_to_flow_type\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.nn.functional as F\nimport oneflow.unittest\n\n\ndef _test_gumbel_softmax(test_case, tau, dim, device, dtype):\n    dtype = type_name_to_flow_type[dtype]\n    x = flow.tensor(np.random.randn(20, 32), dtype=dtype, device=flow.device(device),)\n    y_soft = F.gumbel_softmax(x, tau=tau, dim=dim)\n    y_hard = F.gumbel_softmax(x, tau=tau, dim=dim, hard=True)\n    test_case.assertEqual(x.shape, y_soft.shape)\n    test_case.assertEqual(x.shape, y_hard.shape)\n    test_case.assertEqual(x.dtype, y_soft.dtype)\n    test_case.assertEqual(x.dtype, y_hard.dtype)\n\n\ndef _test_gumbel_softmax_hard(test_case, tau, dim, device, dtype):\n    dtype = type_name_to_flow_type[dtype]\n    x = flow.tensor(np.random.randn(45, 23), dtype=dtype, device=flow.device(device),)\n    y_hard = F.gumbel_softmax(x, tau=tau, dim=dim, hard=True)\n    test_case.assertEqual(y_hard.min(), 0)\n    if dim == -1:\n        test_case.assertEqual(y_hard.sum().item(), 45)\n    elif dim == 0:\n        test_case.assertEqual(y_hard.sum().item(), 23)\n\n\ndef _test_gumbel_softmax_backward(test_case, tau, dim, device, dtype):\n    dtype = type_name_to_flow_type[dtype]\n    x_np = np.random.rand(10, 10)\n    x_soft = flow.tensor(\n        x_np, dtype=dtype, device=flow.device(device), requires_grad=True,\n    )\n    x_hard = flow.tensor(\n        x_np, dtype=dtype, device=flow.device(device), requires_grad=True,\n    )\n    y_soft = F.gumbel_softmax(x_soft, tau, dim=dim)\n    y_hard = F.gumbel_softmax(x_hard, tau, dim=dim, hard=False)\n\n    y_soft.mean().backward()\n    y_hard.mean().backward()\n\n    np.testing.assert_allclose(\n        x_hard.grad.numpy(), x_soft.grad.numpy(), rtol=1e-5, atol=1e-5, verbose=True\n    )\n\n\ndef _test_gumbel_softmax_half(test_case, tau, dim, device):\n    x = flow.tensor(np.random.randn(20, 32), device=flow.device(device),).to(\n        flow.float16\n    )\n    y_soft = F.gumbel_softmax(x, tau=tau, dim=dim)\n    y_hard = F.gumbel_softmax(x, tau=tau, dim=dim, hard=True)\n    test_case.assertEqual(x.shape, y_soft.shape)\n    test_case.assertEqual(x.shape, y_hard.shape)\n    test_case.assertEqual(x.dtype, y_soft.dtype)\n    test_case.assertEqual(x.dtype, y_hard.dtype)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestGumbelSoftmaxModule(flow.unittest.TestCase):\n    @autotest()\n    def test_gumbel_softmax(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"fun\"] = [\n            _test_gumbel_softmax,\n            _test_gumbel_softmax_hard,\n            _test_gumbel_softmax_backward,\n        ]\n        arg_dict[\"tau\"] = [1, 2, 0.5]\n        arg_dict[\"dim\"] = [0, -1]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n       
 arg_dict[\"dtype\"] = [\"float32\", \"double\"]\n\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest()\n    def test_leakyrelu_module_with_half_random_data(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"fun\"] = [\n            _test_gumbel_softmax_half,\n        ]\n        arg_dict[\"tau\"] = [1, 2, 0.5]\n        arg_dict[\"dim\"] = [0, -1]\n        arg_dict[\"device\"] = [\"cuda\"]\n\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_hann_window.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestHannWindow(flow.unittest.TestCase):\n    @autotest(n=1, auto_backward=False, check_graph=True)\n    def test_hann_window(test_case):\n        device = random_device()\n        window_length = random(1, 8).to(int).value()\n        periodic = random_bool().value()\n        output = torch.hann_window(window_length, periodic, device=device)\n        return output\n\n    def test_hann_window_global(test_case):\n        placement = flow.placement(\"cpu\", ranks=[0])\n        sbp = (flow.sbp.broadcast,)\n        window_length = random(1, 8).to(int).value()\n        periodic = random_bool().value()\n        output = flow.hann_window(window_length, periodic, placement=placement, sbp=sbp)\n        test_case.assertEqual(output.sbp, sbp)\n        test_case.assertEqual(output.placement, placement)\n\n    def test_hann_window_dtype(test_case):\n        device = random_device().value()\n        window_length = random(1, 8).to(int).value()\n        periodic = random_bool().value()\n        dtype = flow.float64\n        output = flow.hann_window(window_length, periodic, device=device, dtype=dtype)\n        test_case.assertEqual(output.dtype, dtype)\n\n    @profile(torch.hann_window)\n    def profile_hann_window(test_case):\n        torch.hann_window(128000, periodic=True)\n        torch.hann_window(128001, periodic=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_higher_derivative_activation.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nimport torch as pytorch_origin\nimport oneflow as oneflow_origin\nfrom collections import defaultdict\n\n\ndef _assert_true(test_case, value1, value2):\n    test_case.assertTrue(\n        np.allclose(\n            value1.detach().cpu().numpy(),\n            value2.detach().numpy(),\n            rtol=1e-05,\n            atol=1e-05,\n        )\n    )\n\n\ndef _test_activation_grad_grad_impl(test_case, op_name, *args, **kwargs):\n    x = random_tensor(ndim=2, low=-5)\n    y = eval(f\"torch.nn.functional.{op_name}\")(x, *args, **kwargs)\n\n    x_shape = x.oneflow.shape\n    init_grad_x = random_tensor(len(x_shape), *x_shape)\n    init_grad_y = random_tensor(len(x_shape), *x_shape)\n\n    dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0]\n    _assert_true(test_case, dx.pytorch, dx.oneflow)\n\n    ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x)\n    ddx, ddy = ddx_ddy[0], ddx_ddy[1]\n    _assert_true(test_case, ddx.pytorch, ddx.oneflow)\n    _assert_true(test_case, ddy.pytorch, ddy.oneflow)\n\n\ndef _test_prelu_activation_grad_grad_impl(test_case, op_name, *args, **kwargs):\n    x = random_tensor(ndim=2, low=-5)\n    a = random_tensor(ndim=1, dim0=x.oneflow.shape[1])\n    y = torch.nn.functional.prelu(x, a)\n\n    x_shape = x.oneflow.shape\n    a_shape = a.oneflow.shape\n    init_grad_x = random_tensor(len(x_shape), *x_shape)\n    init_grad_y = random_tensor(len(x_shape), *x_shape)\n    init_grad_a = random_tensor(len(a_shape), *a_shape)\n\n    dx_and_da = torch.autograd.grad(y, [x, a], init_grad_y, True, True)\n    dx, da = dx_and_da[0], dx_and_da[1]\n    _assert_true(test_case, dx.pytorch, dx.oneflow)\n    _assert_true(test_case, da.pytorch, da.oneflow)\n\n    ddx_dda_ddy = torch.autograd.grad(\n        dx_and_da, [dx, da, init_grad_y], [init_grad_x, init_grad_a]\n    )\n    ddx, dda, ddy = ddx_dda_ddy[0], ddx_dda_ddy[1], ddx_dda_ddy[2]\n    _assert_true(test_case, ddx.pytorch, ddx.oneflow)\n    _assert_true(test_case, dda.pytorch, dda.oneflow)\n    _assert_true(test_case, ddy.pytorch, ddy.oneflow)\n\n\ndef _test_hardswish_activation_grad_grad_impl(test_case, op_name, *args, **kwargs):\n    x = random_tensor(ndim=2, low=-1, dim1=4)\n    y = torch.nn.functional.hardswish(x, *args, **kwargs)\n\n    x_shape = x.oneflow.shape\n    init_grad_x = random_tensor(len(x_shape), *x_shape)\n    init_grad_y = random_tensor(len(x_shape), *x_shape)\n\n    dx_pytorch = pytorch_origin.autograd.grad(\n        y.pytorch, x.pytorch, init_grad_y.pytorch\n    )[0]\n    dx_oneflow = oneflow_origin.autograd.grad(\n        y.oneflow, x.oneflow, init_grad_y.oneflow, True, True\n    )[0]\n    _assert_true(test_case, dx_pytorch, dx_oneflow)\n\n    ddx, ddy = flow.autograd.grad(\n        dx_oneflow, [x.oneflow, init_grad_y.oneflow], 
init_grad_x.oneflow\n    )\n    x, dx, init_grad_x, init_grad_y = (\n        x.oneflow,\n        dx_oneflow,\n        init_grad_x.oneflow,\n        init_grad_y.oneflow,\n    )\n    # hardswish'' is 1/3 on (-3, 3) and 0 elsewhere\n    manual_ddx = flow.where(\n        ((x > -3.0) & (x < 3.0)), 1.0 / 3.0 * init_grad_x * init_grad_y, flow.tensor(0.0)\n    )\n    manual_ddy = dx / init_grad_y * init_grad_x\n    _assert_true(test_case, manual_ddx, ddx)\n    _assert_true(test_case, manual_ddy, ddy)\n\n\ndef _test_hardsigmoid_activation_grad_grad_impl(test_case, op_name, *args, **kwargs):\n    x = random_tensor(ndim=2, low=-1, dim1=4)\n    y = torch.nn.functional.hardsigmoid(x, *args, **kwargs)\n\n    x_shape = x.oneflow.shape\n    init_grad_x = random_tensor(len(x_shape), *x_shape)\n    init_grad_y = random_tensor(len(x_shape), *x_shape)\n\n    dx_pytorch = pytorch_origin.autograd.grad(\n        y.pytorch, x.pytorch, init_grad_y.pytorch\n    )[0]\n    dx_oneflow = oneflow_origin.autograd.grad(\n        y.oneflow, x.oneflow, init_grad_y.oneflow, True, True\n    )[0]\n    _assert_true(test_case, dx_pytorch, dx_oneflow)\n\n    ddx, ddy = flow.autograd.grad(\n        dx_oneflow, [x.oneflow, init_grad_y.oneflow], init_grad_x.oneflow\n    )\n    x, dx, init_grad_x, init_grad_y = (\n        x.oneflow,\n        dx_oneflow,\n        init_grad_x.oneflow,\n        init_grad_y.oneflow,\n    )\n    # hardsigmoid is piecewise linear, so its second derivative is 0 almost everywhere\n    manual_ddx = flow.zeros_like(x)\n    manual_ddy = dx / init_grad_y * init_grad_x\n    _assert_true(test_case, manual_ddx, ddx)\n    _assert_true(test_case, manual_ddy, ddy)\n\n\nclass TestActivationHigherDerivative(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, because it failed 8 times in the past week\")\n    def test_activation_grad_grad(test_case):\n        op_args = defaultdict(list)\n        op_kwargs = defaultdict(dict)\n\n        # parameter names differ between pytorch and oneflow\n        op_args[\"leaky_relu\"] = [random(-1, 1).to(float)]\n\n        # some ops only accept kwargs, e.g. celu in oneflow\n        op_kwargs[\"hardtanh\"] = {\n            \"min_val\": random(-5, -1).to(float),\n            \"max_val\": random(1, 5).to(float),\n        }\n        op_kwargs[\"elu\"] = {\"alpha\": random(0, 1).to(float)}\n        op_kwargs[\"celu\"] = {\"alpha\": random(0, 1).to(float)}\n        op_kwargs[\"threshold\"] = {\n            \"threshold\": random().to(float),\n            \"value\": random().to(float),\n        }\n        op_kwargs[\"softplus\"] = {\n            \"beta\": random().to(float),\n            \"threshold\": random().to(float),\n        }\n\n        op_names = [\n            \"gelu\",\n            \"mish\",\n            \"silu\",\n            \"selu\",\n            \"softsign\",\n            \"hardsigmoid\",\n            \"hardswish\",\n            \"relu\",\n            \"elu\",\n            \"celu\",\n            \"prelu\",\n            \"hardshrink\",\n            \"softshrink\",\n            \"leaky_relu\",\n            \"hardtanh\",\n            \"softplus\",\n            \"threshold\",\n        ]\n        for op_name in op_names:\n            try:\n                functor = eval(f\"_test_{op_name}_activation_grad_grad_impl\")\n            except NameError:\n                # no specialized impl for this op, fall back to the generic one\n                functor = _test_activation_grad_grad_impl\n\n            print(f\"| {op_name:-^60} |\")\n            for i in range(10):\n                functor(test_case, op_name, *op_args[op_name], **op_kwargs[op_name])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_higher_derivative_conv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nimport torch as pytorch_origin\nimport oneflow as oneflow_origin\n\n\ndef _test_convnd_grad_grad_impl(test_case, ndim, rtol=1e-4, atol=1e-5):\n    minibatch = np.random.randint(1, 5)\n    groups = np.random.randint(1, 5)\n    in_channels = np.random.randint(1, 5) * groups\n    out_channels = in_channels * np.random.randint(1, 5)\n    padding = np.random.randint(1, 3)\n    stride = np.random.randint(1, 3)\n    dilation = np.random.randint(1, 3)\n\n    x_shape = [minibatch, in_channels] + [np.random.randint(8, 12) for i in range(ndim)]\n    w_shape = [out_channels, in_channels // groups] + [\n        np.random.randint(2, 5) for i in range(ndim)\n    ]\n\n    x = random_tensor(len(x_shape), *x_shape)\n    w = random_tensor(len(w_shape), *w_shape)\n    init_grad_x = random_tensor(len(x_shape), *x_shape)\n    init_grad_w = random_tensor(len(w_shape), *w_shape)\n\n    y = eval(f\"torch.nn.functional.conv{ndim}d\")(\n        x, w, stride=stride, padding=padding, groups=groups, dilation=dilation\n    )\n    init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape)\n\n    dx = torch.autograd.grad(\n        outputs=y,\n        inputs=x,\n        grad_outputs=init_grad_y,\n        create_graph=True,\n        retain_graph=True,\n    )[0]\n    test_case.assertTrue(\n        np.allclose(\n            dx.pytorch.detach().cpu().numpy(),\n            dx.oneflow.detach().numpy(),\n            rtol=rtol,\n            atol=atol,\n        )\n    )\n\n    dw = torch.autograd.grad(\n        outputs=y,\n        inputs=w,\n        grad_outputs=init_grad_y,\n        create_graph=True,\n        retain_graph=True,\n    )[0]\n    test_case.assertTrue(\n        np.allclose(\n            dw.pytorch.detach().cpu().numpy(),\n            dw.oneflow.detach().numpy(),\n            rtol=rtol,\n            atol=atol,\n        )\n    )\n\n    # torch.autograd.grad in autotest does not support inputs/outpus/grad_outputs as a list\n    # so use the original pytorch/oneflow module\n    ddx_pytorch, ddw_pytorch = pytorch_origin.autograd.grad(\n        outputs=[dx.pytorch, dw.pytorch],\n        inputs=[x.pytorch, w.pytorch],\n        grad_outputs=[init_grad_x.pytorch, init_grad_w.pytorch],\n        create_graph=True,\n        retain_graph=True,\n    )\n    ddx_oneflow, ddw_oneflow = oneflow_origin.autograd.grad(\n        outputs=[dx.oneflow, dw.oneflow],\n        inputs=[x.oneflow, w.oneflow],\n        grad_outputs=[init_grad_x.oneflow, init_grad_w.oneflow],\n        create_graph=True,\n        retain_graph=True,\n    )\n\n    test_case.assertTrue(\n        np.allclose(\n            ddw_pytorch.detach().cpu().numpy(),\n            ddw_oneflow.detach().numpy(),\n            rtol=rtol,\n            atol=atol,\n        )\n    )\n    test_case.assertTrue(\n        
np.allclose(\n            ddx_pytorch.detach().cpu().numpy(),\n            ddx_oneflow.detach().numpy(),\n            rtol=rtol,\n            atol=atol,\n        )\n    )\n\n    dgrad_dx = torch.autograd.grad(\n        outputs=dx,\n        inputs=init_grad_y,\n        grad_outputs=init_grad_x,\n        create_graph=True,\n        retain_graph=True,\n    )[0]\n    test_case.assertTrue(\n        np.allclose(\n            dgrad_dx.pytorch.detach().cpu().numpy(),\n            dgrad_dx.oneflow.detach().numpy(),\n            rtol=rtol,\n            atol=atol,\n        )\n    )\n\n    dgrad_dw = torch.autograd.grad(\n        outputs=dw,\n        inputs=init_grad_y,\n        grad_outputs=init_grad_w,\n        create_graph=True,\n        retain_graph=True,\n    )[0]\n    test_case.assertTrue(\n        np.allclose(\n            dgrad_dw.pytorch.detach().cpu().numpy(),\n            dgrad_dw.oneflow.detach().numpy(),\n            rtol=rtol,\n            atol=atol,\n        )\n    )\n\n\nclass TestConvHigherDerivative(flow.unittest.TestCase):\n    def test_conv1d_grad_grad(test_case):\n        _test_convnd_grad_grad_impl(test_case, 1)\n\n    def test_conv2d_grad_grad(test_case):\n        _test_convnd_grad_grad_impl(test_case, 2)\n\n    def test_conv3d_grad_grad(test_case):\n        _test_convnd_grad_grad_impl(test_case, 3, atol=1e-3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_higher_derivative_div.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nfrom numpy.random import randint\n\n\ndef _test_div_grad_grad_impl(test_case):\n    y_shape = [randint(2, 5) for _ in range(randint(0, 6))]\n    x_shape = [randint(2, 5) for _ in range(randint(0, 6 - len(y_shape)))] + y_shape\n    if random_bool().value():\n        x_shape, y_shape = y_shape, x_shape\n\n    x = random_tensor(len(x_shape), *x_shape).requires_grad_(True)\n    y = random_tensor(len(y_shape), *y_shape).requires_grad_(True)\n    z = torch.div(x, y)\n\n    init_grad_z = random_tensor(len(z.oneflow.shape), *z.oneflow.shape)\n    init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape)\n    init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape)\n\n    dx_and_dy = torch.autograd.grad(z, [x, y], init_grad_z, True, True)\n    test_case.assertTrue(\n        np.allclose(\n            dx_and_dy.pytorch[0].detach().cpu().numpy(),\n            dx_and_dy.oneflow[0].detach().numpy(),\n            rtol=1e-4,\n            atol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            dx_and_dy.pytorch[1].detach().cpu().numpy(),\n            dx_and_dy.oneflow[1].detach().numpy(),\n            rtol=1e-4,\n            atol=1e-4,\n        )\n    )\n\n    ddx_and_ddy_and_ddz = torch.autograd.grad(\n        dx_and_dy, [x, y, init_grad_z], [init_grad_x, init_grad_y], True, True\n    )\n    test_case.assertTrue(\n        np.allclose(\n            ddx_and_ddy_and_ddz.pytorch[0].detach().cpu().numpy(),\n            ddx_and_ddy_and_ddz.oneflow[0].detach().numpy(),\n            rtol=1e-3,\n            atol=1e-3,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            ddx_and_ddy_and_ddz.pytorch[1].detach().cpu().numpy(),\n            ddx_and_ddy_and_ddz.oneflow[1].detach().numpy(),\n            rtol=1e-3,\n            atol=1e-3,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            ddx_and_ddy_and_ddz.pytorch[2].detach().cpu().numpy(),\n            ddx_and_ddy_and_ddz.oneflow[2].detach().numpy(),\n            rtol=1e-3,\n            atol=1e-3,\n        )\n    )\n\n\nclass TestDivHigherDerivative(flow.unittest.TestCase):\n    def test_div_grad_grad(test_case):\n        for i in range(10):\n            _test_div_grad_grad_impl(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_higher_derivative_loss.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _assert_true(test_case, value1, value2, name=\"\"):\n    is_equal = np.allclose(\n        value1.detach().cpu().numpy(), value2.detach().numpy(), rtol=1e-04, atol=1e-04,\n    )\n    test_case.assertTrue(is_equal, f\"{name} is not equal.\" if name else \"\")\n\n\ndef generate_grads_for_variables(variables):\n    if isinstance(variables, list):\n        variables_shape = [i.pytorch.shape for i in variables]\n        device = torch.device(str(variables[0].pytorch.device))\n    elif hasattr(variables, \"pytorch\"):\n        variables_shape = [i.shape for i in variables.pytorch]\n        device = torch.device(str(variables.pytorch[0].device))\n    else:\n        assert False\n\n    grads = [\n        random_tensor(len(shape), *shape, requires_grad=True).to(device)\n        for shape in variables_shape\n    ]\n    return grads\n\n\ndef calculate_and_compare_loss(test_case, input, target, model, order=2):\n    output = model(input, target)\n    _assert_true(test_case, output.pytorch, output.oneflow)\n\n    init_inputs = [input, target]\n    grad_inputs = [output]\n    grad_outputs = []\n    for i in range(order):\n        inputs = [\n            var for var in [*init_inputs, *grad_outputs] if var.pytorch.requires_grad\n        ]\n        outputs = grad_inputs\n        grad_outputs = generate_grads_for_variables(outputs)\n        if i == order - 1:\n            grad_inputs = torch.autograd.grad(outputs, inputs, grad_outputs)\n        else:\n            grad_inputs = torch.autograd.grad(outputs, inputs, grad_outputs, True, True)\n        for j in range(len(inputs)):\n            _assert_true(\n                test_case,\n                grad_inputs[j].pytorch,\n                grad_inputs[j].oneflow,\n                f\"{i}-grad_inputs[{j}]\",\n            )\n\n\ndef generate_necessity_for_default_loss():\n    ndim = random(2, 6).to(int).value()\n    device = random_device()\n    shape = [random().to(int) for _ in range(ndim)]\n    input_requires_grad = True\n    target_requires_grad = random_bool().value()\n    return (\n        random_tensor(ndim, *shape, requires_grad=input_requires_grad, low=0).to(\n            device\n        ),\n        random_tensor(ndim, *shape, requires_grad=target_requires_grad, low=0).to(\n            device\n        ),\n    )\n\n\ndef generate_necessity_for_nll_loss():\n    ndim = random(2, 6).to(int).value()\n    device = random_device()\n    num_classes = random(low=2).to(int)\n    batch_size = random(low=2, high=5).to(int)\n    ignore_index = (\n        random(0, num_classes).to(int) | nothing()\n        if num_classes.value() > 2\n        else nothing()\n    )\n    extra_dim = [random().to(int) for _ in range(ndim - 2)]\n    return (\n        random_tensor(ndim, batch_size, num_classes, 
*extra_dim).to(device),\n        random_tensor(\n            ndim - 1,\n            batch_size,\n            *extra_dim,\n            low=0,\n            high=num_classes,\n            dtype=int,\n            requires_grad=False,\n        ).to(device),\n        random_tensor(1, num_classes, low=0, high=3, requires_grad=False).to(device),\n        ignore_index,\n    )\n\n\ndef generate_necessity_for_bce_loss():\n    ndim = random(2, 6).to(int).value()\n    device = random_device()\n    num_classes = 2\n    batch_size = random(low=2, high=5).to(int)\n    extra_dim = [random().to(int) for _ in range(ndim - 2)]\n    input_requires_grad = True\n    target_requires_grad = False\n    return (\n        random_tensor(\n            ndim,\n            batch_size,\n            num_classes,\n            *extra_dim,\n            requires_grad=input_requires_grad,\n            low=0,\n            high=1,\n        ).to(device),\n        random_tensor(\n            ndim,\n            batch_size,\n            num_classes,\n            *extra_dim,\n            low=0,\n            high=num_classes,\n            requires_grad=target_requires_grad,\n        ).to(device),\n        random_tensor(\n            ndim,\n            batch_size,\n            num_classes,\n            *extra_dim,\n            low=0,\n            high=3,\n            requires_grad=False,\n        ).to(device),\n        random_tensor(\n            1,\n            oneof(extra_dim[-1] if ndim > 2 else num_classes, 1).value(),\n            low=1,\n            high=3,\n            requires_grad=False,\n        ).to(device),\n    )\n\n\ndef _test_smooth_l1_loss_grad_grad_impl(test_case):\n    x, y = generate_necessity_for_default_loss()\n\n    m = torch.nn.SmoothL1Loss(\n        reduction=oneof(\"none\", \"sum\", \"mean\", nothing()), beta=oneof(0.0, 0.5, 1)\n    )\n    m.to(x.device)\n\n    calculate_and_compare_loss(test_case, x, y, m)\n\n\ndef _test_kl_div_loss_grad_grad_impl(test_case):\n    x, y = generate_necessity_for_default_loss()\n\n    m = torch.nn.KLDivLoss(\n        reduction=oneof(\"none\", \"sum\", \"mean\", nothing()),\n        log_target=oneof(True, False),\n    )\n    m.to(x.device)\n\n    calculate_and_compare_loss(test_case, x, y, m)\n\n\ndef _test_bce_loss_grad_grad_impl(test_case, with_logits=False):\n    x, y, weight, pos_weight = generate_necessity_for_bce_loss()\n\n    if with_logits:\n        weight = oneof(weight, nothing())\n        has_pos_weight = random_bool().value()\n        pos_weight = pos_weight if has_pos_weight else nothing()\n        m = torch.nn.BCEWithLogitsLoss(\n            weight=weight,\n            pos_weight=pos_weight,\n            reduction=oneof(\"none\", \"sum\", \"mean\"),\n        )\n        if has_pos_weight:\n            y = y.detach().clone().requires_grad_(False)\n    else:\n        m = torch.nn.BCELoss(\n            weight=oneof(weight, nothing()), reduction=oneof(\"none\", \"sum\", \"mean\"),\n        )\n    m.to(x.device)\n\n    calculate_and_compare_loss(test_case, x, y, m)\n\n\ndef _test_nll_loss_grad_grad_impl(test_case):\n    (x, y, weight, ignore_index) = generate_necessity_for_nll_loss()\n    m = torch.nn.NLLLoss(\n        weight=oneof(weight, nothing()),\n        reduction=oneof(\"none\", \"sum\", \"mean\"),\n        ignore_index=ignore_index,\n    )\n    m.to(x.device)\n\n    calculate_and_compare_loss(test_case, x, y, m)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLossHigherDerivative(flow.unittest.TestCase):\n    def test_smooth_l1_loss_grad_grad(test_case):\n        
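# fresh random shapes/configs are drawn on every iteration\n        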
for i in range(5):\n            _test_smooth_l1_loss_grad_grad_impl(test_case)\n\n    def test_kl_div_loss_grad_grad(test_case):\n        for i in range(5):\n            _test_kl_div_loss_grad_grad_impl(test_case)\n\n    @unittest.skip(\"skip for now, because it failed 8 times in the past week\")\n    def test_nll_loss_grad_grad(test_case):\n        for i in range(5):\n            _test_nll_loss_grad_grad_impl(test_case)\n\n    def test_bce_loss_grad_grad(test_case):\n        for i in range(5):\n            _test_bce_loss_grad_grad_impl(test_case)\n\n    def test_bce_with_logits_loss_grad_grad(test_case):\n        for i in range(5):\n            _test_bce_loss_grad_grad_impl(test_case, with_logits=True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_higher_derivative_matmul.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nimport torch as pytorch_origin\nimport oneflow as oneflow_origin\n\n\nclass TestMatmulHigherDerivative(flow.unittest.TestCase):\n    def test_broadcast_matmul_grad_b_grad(test_case):\n        broadcast_dims = [\n            np.random.randint(2, 10) for _ in range(np.random.randint(1, 3))\n        ]\n        m = np.random.randint(2, 10)\n        n = np.random.randint(2, 10)\n        k = np.random.randint(2, 10)\n\n        shape_a = broadcast_dims + [m, k]\n        shape_b = [k, n]\n        shape_y = broadcast_dims + [m, n]\n\n        a = random_tensor(len(shape_a), *shape_a).requires_grad_(True)\n        b = random_tensor(len(shape_b), *shape_b).requires_grad_(True)\n\n        y = torch.matmul(a, b)\n\n        init_grad_a = random_tensor(len(shape_a), *shape_a).requires_grad_(True)\n        init_grad_b = random_tensor(len(shape_b), *shape_b).requires_grad_(True)\n        init_grad_y = random_tensor(len(shape_y), *shape_y).requires_grad_(True)\n\n        da = torch.autograd.grad(\n            outputs=y,\n            inputs=a,\n            grad_outputs=init_grad_y,\n            create_graph=True,\n            retain_graph=True,\n        )[0]\n        test_case.assertTrue(\n            np.allclose(\n                da.pytorch.detach().cpu().numpy(),\n                da.oneflow.detach().numpy(),\n                rtol=1e-4,\n                atol=1e-5,\n            )\n        )\n\n        db = torch.autograd.grad(\n            outputs=y,\n            inputs=b,\n            grad_outputs=init_grad_y,\n            create_graph=True,\n            retain_graph=True,\n        )[0]\n        test_case.assertTrue(\n            np.allclose(\n                db.pytorch.detach().cpu().numpy(),\n                db.oneflow.detach().numpy(),\n                rtol=1e-4,\n                atol=1e-5,\n            )\n        )\n\n        # torch.autograd.grad in autotest does not support inputs/outpus/grad_outputs as a list\n        # so use the original pytorch/oneflow module\n        dda_pytorch, ddb_pytorch = pytorch_origin.autograd.grad(\n            outputs=[da.pytorch, db.pytorch],\n            inputs=[a.pytorch, b.pytorch],\n            grad_outputs=[init_grad_a.pytorch, init_grad_b.pytorch],\n            create_graph=True,\n            retain_graph=True,\n        )\n        dda_oneflow, ddb_oneflow = oneflow_origin.autograd.grad(\n            outputs=[da.oneflow, db.oneflow],\n            inputs=[a.oneflow, b.oneflow],\n            grad_outputs=[init_grad_a.oneflow, init_grad_b.oneflow],\n            create_graph=True,\n            retain_graph=True,\n        )\n\n        test_case.assertTrue(\n            np.allclose(\n                ddb_pytorch.detach().cpu().numpy(),\n                ddb_oneflow.detach().numpy(),\n                rtol=1e-4,\n   
             atol=1e-5,\n            )\n        )\n        test_case.assertTrue(\n            np.allclose(\n                dda_pytorch.detach().cpu().numpy(),\n                dda_oneflow.detach().numpy(),\n                rtol=1e-4,\n                atol=1e-5,\n            )\n        )\n\n        dgrad_da = torch.autograd.grad(\n            outputs=da,\n            inputs=init_grad_y,\n            grad_outputs=init_grad_a,\n            create_graph=True,\n            retain_graph=True,\n        )[0]\n        test_case.assertTrue(\n            np.allclose(\n                dgrad_da.pytorch.detach().cpu().numpy(),\n                dgrad_da.oneflow.detach().numpy(),\n                rtol=1e-4,\n                atol=1e-5,\n            )\n        )\n\n        dgrad_db = torch.autograd.grad(\n            outputs=db,\n            inputs=init_grad_y,\n            grad_outputs=init_grad_b,\n            create_graph=True,\n            retain_graph=True,\n        )[0]\n        test_case.assertTrue(\n            np.allclose(\n                dgrad_db.pytorch.detach().cpu().numpy(),\n                dgrad_db.oneflow.detach().numpy(),\n                rtol=1e-4,\n                atol=1e-5,\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_higher_derivative_neg.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\nclass TestNegHigherDerivative(flow.unittest.TestCase):\n    def test_neg_grad_grad(test_case):\n        x = random_tensor(ndim=2).requires_grad_(True)\n        y = torch.neg(x)\n        np_arr = np.random.rand(*x.oneflow.shape)\n        init_grad = torch.tensor(np_arr).requires_grad_()\n\n        x_grad = torch.autograd.grad(y, x, init_grad, create_graph=True)[0]\n        test_case.assertTrue(\n            np.allclose(\n                x_grad.pytorch.detach().cpu().numpy(), x_grad.oneflow.detach().numpy()\n            )\n        )\n\n        init_grad_grad = torch.tensor(np_arr).requires_grad_()\n        dgrad = torch.autograd.grad(\n            x_grad, init_grad, init_grad_grad, create_graph=False\n        )[0]\n        test_case.assertTrue(\n            np.allclose(\n                dgrad.pytorch.detach().cpu().numpy(), dgrad.oneflow.detach().numpy(),\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_higher_derivative_pool.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _check_equal(test_case, lhs, rhs, name=\"\", rtol=1e-5, atol=1e-5):\n    is_equal = np.allclose(\n        lhs.detach().cpu().numpy(),\n        rhs.detach().cpu().numpy(),\n        rtol=rtol,\n        atol=atol,\n        equal_nan=True,\n    )\n    test_case.assertTrue(is_equal, f\"{name} is not equal\" if name else \"\")\n\n\ndef _test_avg_pool_grad_grad_impl(test_case, ndim):\n    device = random_device()\n    minibatch = random(1, 5).to(int).value()\n    channels = random(1, 5).to(int).value()\n    padding = random(0, 3).to(int).value()\n    ceil_mode = random_bool().value()\n    count_include_pad = random_bool().value()\n    divisor_override = random().to(int).value()\n    kernel_size = random(4, 6).to(int).value()\n    stride = random(1, 3).to(int).value()\n    x_shape = [minibatch, channels] + [\n        random(8, 12).to(int).value() for i in range(ndim)\n    ]\n\n    kwargs = {\n        \"kernel_size\": kernel_size,\n        \"stride\": oneof(stride, nothing()),\n        \"padding\": oneof(padding, nothing()),\n        \"ceil_mode\": ceil_mode,\n        \"count_include_pad\": count_include_pad,\n    }\n    if ndim != 1:\n        kwargs[\"divisor_override\"] = divisor_override\n\n    m = eval(f\"torch.nn.AvgPool{ndim}d\")(**kwargs)\n    m.to(device)\n\n    x = random_tensor(len(x_shape), *x_shape).to(device)\n    y = m(x)\n    _check_equal(test_case, y.pytorch, y.oneflow, \"y\")\n\n    init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).to(device)\n    init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to(device)\n\n    dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0]\n    _check_equal(test_case, dx.pytorch, dx.oneflow, \"dx\")\n\n    ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x, True, True)\n    ddx, ddy = ddx_ddy[0], ddx_ddy[1]\n    _check_equal(test_case, ddx.pytorch, ddx.oneflow, \"ddx\")\n    _check_equal(test_case, ddy.pytorch, ddy.oneflow, \"ddy\")\n\n\ndef _test_max_pool_grad_grad_impl(test_case, ndim):\n    device = random_device()\n    minibatch = random(1, 5).to(int).value()\n    channels = random(1, 5).to(int).value()\n    padding = random(0, 3).to(int).value()\n    dilation = random(1, 3).to(int).value()\n    ceil_mode = random_bool().value()\n    return_indices = random_bool().value()\n    kernel_size = random(4, 6).to(int).value()\n    stride = random(1, 3).to(int).value()\n    x_shape = [minibatch, channels] + [\n        random(10, 12).to(int).value() for i in range(ndim)\n    ]\n\n    m = eval(f\"torch.nn.MaxPool{ndim}d\")(\n        kernel_size=kernel_size,\n        stride=oneof(stride, nothing()),\n        padding=oneof(padding, nothing()),\n        dilation=oneof(dilation, nothing()),\n        ceil_mode=ceil_mode,\n        
return_indices=return_indices,\n    )\n    m.to(device)\n\n    x = random_tensor(len(x_shape), *x_shape).to(device)\n    if return_indices:\n        y_and_indices = m(x)\n        y, indices = y_and_indices[0], y_and_indices[1]\n    else:\n        y = m(x)\n    _check_equal(test_case, y.pytorch, y.oneflow, \"y\")\n\n    init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).to(device)\n    init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to(device)\n\n    dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0]\n    _check_equal(test_case, dx.pytorch, dx.oneflow, \"dx\")\n\n    ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x, True, True)\n    ddx, ddy = ddx_ddy[0], ddx_ddy[1]\n    _check_equal(test_case, ddx.pytorch, ddx.oneflow, \"ddx\")\n    _check_equal(test_case, ddy.pytorch, ddy.oneflow, \"ddy\")\n\n\ndef _test_adaptive_pool_grad_grad_impl(test_case, ndim, mode):\n    device = random_device()\n    x_shape = [random(5, 10).to(int).value() for i in range(2 + ndim)]\n    output_size = [random(2, 1 + x_shape[2 + i]).to(int).value() for i in range(ndim)]\n\n    m = eval(f\"torch.nn.Adaptive{mode.title()}Pool{ndim}d\")(output_size)\n    m.to(device)\n\n    x = random_tensor(len(x_shape), *x_shape).to(device)\n    y = m(x)\n    _check_equal(test_case, y.pytorch, y.oneflow, \"y\")\n\n    init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).to(device)\n    init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape).to(device)\n\n    dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0]\n    _check_equal(test_case, dx.pytorch, dx.oneflow, \"dx\")\n\n    ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x, True, True)\n    ddx, ddy = ddx_ddy[0], ddx_ddy[1]\n\n    _check_equal(test_case, ddx.pytorch, ddx.oneflow, \"ddx\")\n    _check_equal(test_case, ddy.pytorch, ddy.oneflow, \"ddy\")\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestPoolHigherDerivative(flow.unittest.TestCase):\n    def test_max_pool_1d_grad_grad(test_case):\n        _test_max_pool_grad_grad_impl(test_case, 1)\n\n    def test_max_pool_2d_grad_grad(test_case):\n        _test_max_pool_grad_grad_impl(test_case, 2)\n\n    def test_max_pool_3d_grad_grad(test_case):\n        _test_max_pool_grad_grad_impl(test_case, 3)\n\n    def test_avg_pool_1d_grad_grad(test_case):\n        _test_avg_pool_grad_grad_impl(test_case, ndim=1)\n\n    def test_avg_pool_2d_grad_grad(test_case):\n        _test_avg_pool_grad_grad_impl(test_case, ndim=2)\n\n    def test_avg_pool_3d_grad_grad(test_case):\n        _test_avg_pool_grad_grad_impl(test_case, ndim=3)\n\n    def test_adaptive_avg_pool_1d_grad_grad(test_case):\n        _test_adaptive_pool_grad_grad_impl(test_case, ndim=1, mode=\"avg\")\n\n    def test_adaptive_avg_pool_2d_grad_grad(test_case):\n        _test_adaptive_pool_grad_grad_impl(test_case, ndim=2, mode=\"avg\")\n\n    def test_adaptive_avg_pool_3d_grad_grad(test_case):\n        _test_adaptive_pool_grad_grad_impl(test_case, ndim=3, mode=\"avg\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_higher_derivative_pow.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _check_equal(test_case, lhs, rhs, rtol=1e-3, atol=1e-3):\n    is_equal = np.allclose(\n        lhs.detach().cpu().numpy(),\n        rhs.detach().cpu().numpy(),\n        rtol=rtol,\n        atol=atol,\n        equal_nan=True,\n    )\n    test_case.assertTrue(is_equal)\n\n\ndef _test_pow_grad_grad_impl(test_case):\n    y_shape = [random().to(int).value() for _ in range(random().to(int).value())]\n    x_shape = y_shape[random(0, 5).to(int).value() :]\n    if random_bool().value():\n        x_shape, y_shape = y_shape, x_shape\n\n    # The range limit should be removed after solving issue #9908\n    x = random_tensor(len(x_shape), *x_shape, low=0, high=1)\n    y = random_tensor(len(y_shape), *y_shape, low=0, high=1)\n\n    z = torch.pow(x, y)\n    _check_equal(test_case, z.pytorch, z.oneflow)\n\n    init_grad_z = random_tensor(len(z.oneflow.shape), *z.oneflow.shape)\n    init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape)\n    init_grad_y = random_tensor(len(y.oneflow.shape), *y.oneflow.shape)\n\n    dx_and_dy = torch.autograd.grad(z, [x, y], init_grad_z, True, True)\n    _check_equal(test_case, dx_and_dy.pytorch[0], dx_and_dy.oneflow[0])\n    _check_equal(test_case, dx_and_dy.pytorch[1], dx_and_dy.oneflow[1])\n\n    ddx_ddy_ddz = torch.autograd.grad(\n        dx_and_dy, [x, y, init_grad_z], [init_grad_x, init_grad_y]\n    )\n    _check_equal(test_case, ddx_ddy_ddz.pytorch[0], ddx_ddy_ddz.oneflow[0])\n    _check_equal(test_case, ddx_ddy_ddz.pytorch[1], ddx_ddy_ddz.oneflow[1])\n    _check_equal(test_case, ddx_ddy_ddz.pytorch[2], ddx_ddy_ddz.oneflow[2])\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestPowHigherDerivative(flow.unittest.TestCase):\n    def test_pow_grad_grad(test_case):\n        for i in range(10):\n            _test_pow_grad_grad_impl(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_higher_derivative_scalar_pow.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _check_equal(test_case, lhs, rhs, rtol=1e-4, atol=1e-4, name=\"\"):\n    is_equal = np.allclose(\n        lhs.detach().cpu().numpy(),\n        rhs.detach().cpu().numpy(),\n        rtol=rtol,\n        atol=atol,\n        equal_nan=True,\n    )\n    test_case.assertTrue(is_equal, f\"{name} is not equal\")\n\n\ndef _test_scalar_pow_grad_grad_impl(test_case, reverse=False):\n    x_shape = [random().to(int).value() for _ in range(random().to(int).value())]\n    y = random().to(float if random_bool().value() else int).value()\n\n    x = random_tensor(len(x_shape), *x_shape)\n    z = torch.pow(x, y) if not reverse else torch.pow(y, x)\n\n    init_grad_z = random_tensor(len(z.oneflow.shape), *z.oneflow.shape)\n    init_grad_x = random_tensor(len(x.oneflow.shape), *x.oneflow.shape)\n\n    dx = torch.autograd.grad(z, x, init_grad_z, True, True)[0]\n    _check_equal(test_case, dx.pytorch, dx.oneflow, name=\"dx\")\n\n    ddx_and_ddz = torch.autograd.grad(dx, [x, init_grad_z], init_grad_x, True, True)\n    _check_equal(test_case, ddx_and_ddz.pytorch[0], ddx_and_ddz.oneflow[0], name=\"ddx\")\n    _check_equal(test_case, ddx_and_ddz.pytorch[1], ddx_and_ddz.oneflow[1], name=\"ddz\")\n\n\nclass TestScalarPowHigherDerivative(flow.unittest.TestCase):\n    def test_scalar_pow_grad_grad(test_case):\n        for i in range(10):\n            _test_scalar_pow_grad_grad_impl(test_case)\n\n    def test_scalar_reverse_pow_grad_grad(test_case):\n        for i in range(10):\n            _test_scalar_pow_grad_grad_impl(test_case, reverse=True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_higher_derivative_slice.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef random_index(dim):\n    start = np.random.choice(list(range(dim)))\n    stop = np.random.choice(list(range(1, dim + 1)))\n    if start >= stop:\n        start, stop = stop - 1, start + 1\n    step = np.random.randint(1, dim)\n    return f\"{start}:{stop}:{step}\"\n\n\ndef random_slice(dim_vec):\n    slice_index = \", \".join(random_index(dim) for dim in dim_vec)\n    return slice_index\n\n\ndef _test_slice_grad_grad_impl(test_case):\n    ndim = np.random.randint(2, 5)\n    x_shape = [np.random.randint(3, 8) for _ in range(ndim)]\n    x = random_tensor(len(x_shape), *x_shape).requires_grad_(True)\n\n    slice_index = random_slice(x_shape)\n    y = eval(f\"x[{slice_index}]\")\n\n    init_grad = random_tensor(len(y.oneflow.shape), *y.oneflow.shape).requires_grad_()\n    x_grad = torch.autograd.grad(y, x, init_grad, create_graph=True)[0]\n    test_case.assertTrue(\n        np.allclose(\n            x_grad.pytorch.detach().cpu().numpy(), x_grad.oneflow.detach().numpy()\n        )\n    )\n\n    init_grad_grad = random_tensor(\n        len(x_grad.oneflow.shape), *x_grad.oneflow.shape\n    ).requires_grad_()\n    dgrad = torch.autograd.grad(x_grad, init_grad, init_grad_grad, create_graph=False)[\n        0\n    ]\n    test_case.assertTrue(\n        np.allclose(\n            dgrad.pytorch.detach().cpu().numpy(), dgrad.oneflow.detach().numpy(),\n        )\n    )\n\n\nclass TestSliceHigherDerivative(flow.unittest.TestCase):\n    def test_slice_grad_grad(test_case):\n        for i in range(10):\n            _test_slice_grad_grad_impl(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_higher_derivative_softmax.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _assert_true(test_case, value1, value2):\n    test_case.assertTrue(\n        np.allclose(\n            value1.detach().cpu().numpy(),\n            value2.detach().cpu().numpy(),\n            rtol=1e-05,\n            atol=1e-05,\n        )\n    )\n\n\ndef _test_softmax_grad_grad_impl(test_case, op_name):\n    ndim = random(low=2).to(int).value()\n    data = random_tensor(ndim=ndim)\n\n    for dim in range(ndim):\n        x = data.detach().clone().requires_grad_()\n        m = eval(f\"torch.nn.{op_name}\")(dim)\n        y = m(x)\n        _assert_true(test_case, y.pytorch, y.oneflow)\n\n        x_shape = x.oneflow.shape\n        init_grad_x = random_tensor(len(x_shape), *x_shape)\n        init_grad_y = random_tensor(len(x_shape), *x_shape)\n\n        dx = torch.autograd.grad(y, x, init_grad_y, True, True)[0]\n        _assert_true(test_case, dx.pytorch, dx.oneflow)\n\n        ddx_ddy = torch.autograd.grad(dx, [x, init_grad_y], init_grad_x)\n        ddx, ddy = ddx_ddy[0], ddx_ddy[1]\n        _assert_true(test_case, ddx.pytorch, ddx.oneflow)\n        _assert_true(test_case, ddy.pytorch, ddy.oneflow)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSoftmaxHigherDerivative(flow.unittest.TestCase):\n    def test_softmax_grad_grad(test_case):\n        _test_softmax_grad_grad_impl(test_case, op_name=\"Softmax\")\n\n    def test_logsoftmax_grad_grad(test_case):\n        _test_softmax_grad_grad_impl(test_case, op_name=\"LogSoftmax\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_host_memory_input.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow import nn\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestHostMemory(oneflow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_host_memory(test_case):\n        x = flow.ones(2, 3, device=\"cuda\")\n        scalar = flow.Tensor([3.0], device=\"cuda\")\n\n        y = x + scalar\n        out = y + scalar + y\n\n        class HostMemoryInputGraph(nn.Graph):\n            def __init__(self):\n                super(HostMemoryInputGraph, self).__init__()\n\n            def build(self, x, scalar):\n                a = flow._C.host_scalar_add_by_tensor(x, scalar.cpu())\n                b = flow._C.host_scalar_add_by_tensor(a, scalar)\n                return a + b\n\n        graph = HostMemoryInputGraph()\n        lazy_out = graph(x, scalar)\n\n        test_case.assertTrue(np.array_equal(out.numpy(), lazy_out.numpy()))\n\n        a = flow._C.host_scalar_add_by_tensor(x, scalar.cpu())\n        b = flow._C.host_scalar_add_by_tensor(a, scalar)\n        eager_out = a + b\n        test_case.assertTrue(np.array_equal(out.numpy(), eager_out.numpy()))\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_host_memory_1n2d(test_case):\n        x = flow.ones(\n            2, 3, placement=flow.placement(\"cuda\", [0, 1]), sbp=flow.sbp.broadcast\n        )\n        scalar = flow.Tensor(\n            [3.0], placement=flow.placement(\"cuda\", [0, 1]), sbp=flow.sbp.broadcast\n        )\n\n        y = x + scalar\n        out = y + scalar + y\n\n        class HostMemoryInputGraph(nn.Graph):\n            def __init__(self):\n                super(HostMemoryInputGraph, self).__init__()\n\n            def build(self, x, scalar):\n                a = flow._C.host_scalar_add_by_tensor(x, scalar.cpu())\n                b = flow._C.host_scalar_add_by_tensor(a, scalar)\n                return a + b\n\n        graph = HostMemoryInputGraph()\n        lazy_out = graph(x, scalar)\n\n        test_case.assertTrue(np.array_equal(out.numpy(), lazy_out.numpy()))\n\n        a = flow._C.host_scalar_add_by_tensor(x, scalar.cpu())\n        b = flow._C.host_scalar_add_by_tensor(a, scalar)\n        eager_out = a + b\n        test_case.assertTrue(np.array_equal(out.numpy(), eager_out.numpy()))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_hsplit.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom random import shuffle\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestHsplitVec(flow.unittest.TestCase):\n    @autotest(n=5, check_graph=True)\n    def test_flow_hsplit_vec(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4,\n            dim0=random(3, 6),\n            dim1=random(3, 6),\n            dim2=random(3, 6),\n            dim3=random(3, 6),\n        ).to(device)\n        z = torch.hsplit(x, (1, 2))\n        return z\n\n    @autotest(n=5)\n    def test_flow_hsplit_vec_with_stride(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4,\n            dim0=random(3, 6),\n            dim1=random(3, 6),\n            dim2=random(3, 6),\n            dim3=random(3, 6),\n        ).to(device)\n        perm = [0, 1, 2, 3]\n        shuffle(perm)\n        y = x.permute(perm)\n        z = torch.hsplit(y, (1, 2))\n        return z\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestHsplitInt(flow.unittest.TestCase):\n    @autotest(n=10, check_graph=True)\n    def test_flow_hsplit_int(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4, dim0=random(3, 6), dim1=12, dim2=random(3, 6), dim3=random(3, 6),\n        ).to(device)\n        split = oneof(2, 4, 6)\n        z = torch.hsplit(x, split)\n        return z\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_hub.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@unittest.skip(reason=\"network fluctuations can cause downloads to fail!\")\nclass TestHub(flow.unittest.TestCase):\n    def test_hub_list_api(test_case):\n        entrypoints = flow.hub.list(\"OneFlow-Inc/vision\", force_reload=False)\n        test_case.assertEqual(\"alexnet\" in entrypoints, True)\n        test_case.assertEqual(\"densenet121\" in entrypoints, True)\n\n    def test_hub_help_api(test_case):\n        help_info = flow.hub.help(\"Oneflow-Inc/vision\", \"resnet18\", force_reload=False)\n        print(help_info)\n\n    def test_hub_load_api(test_case):\n        repo = \"Oneflow-Inc/vision\"\n        model = flow.hub.load(repo, \"resnet18\", pretrained=True)\n        x = flow.randn(1, 3, 224, 224)\n        y = model(x)\n        test_case.assertTrue(np.array_equal(y.size(), (1, 1000)))\n\n    def test_hub_download_url_to_file__api(test_case):\n        flow.hub.download_url_to_file(\n            \"https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ResNet/resnet18.zip\",\n            \"/tmp/temporary_file\",\n        )\n\n    def test_hub_load_state_dict_from_url_api(test_case):\n        state_dict = flow.hub.load_state_dict_from_url(\n            \"https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ResNet/resnet18.zip\"\n        )\n        test_case.assertEqual(\"layer3.1.bn2.bias\" in state_dict.keys(), True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_image_batch_align.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport operator\nimport unittest\nfrom functools import reduce\n\nimport cv2\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _read_images_by_cv(image_files):\n    images = [cv2.imread(image_file).astype(np.single) for image_file in image_files]\n    return images\n\n\ndef _get_images_static_shape(images):\n    image_shapes = [image.shape for image in images]\n    image_static_shape = np.amax(image_shapes, axis=0)\n    assert isinstance(\n        image_static_shape, np.ndarray\n    ), \"image_shapes: {}, image_static_shape: {}\".format(\n        str(image_shapes), str(image_static_shape)\n    )\n    image_static_shape = image_static_shape.tolist()\n    image_static_shape.insert(0, len(image_shapes))\n    return image_static_shape\n\n\ndef _roundup(x, n):\n    return int((x + n - 1) / n) * n\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestImageBatchAlign(flow.unittest.TestCase):\n    def test_image_batch_align(test_case):\n        image_files = [\n            flow.unittest.dataset_dir(\"mscoco_2017/val2017/000000000139.jpg\"),\n            flow.unittest.dataset_dir(\"mscoco_2017/val2017/000000000632.jpg\"),\n            flow.unittest.dataset_dir(\"mscoco_2017/val2017/000000000785.jpg\"),\n            flow.unittest.dataset_dir(\"mscoco_2017/val2017/000000001000.jpg\"),\n        ]\n        alignment = 16\n        images = _read_images_by_cv(image_files)\n        image_shape = _get_images_static_shape(images)\n        assert len(image_shape) == 4\n        aligned_image_shape = [\n            image_shape[0],\n            _roundup(image_shape[1], alignment),\n            _roundup(image_shape[2], alignment),\n            image_shape[3],\n        ]\n        image_batch_aligner = flow.nn.image.batch_align(\n            shape=aligned_image_shape[1:], dtype=flow.float, alignment=alignment\n        )\n        images_np_arr_static = np.zeros(image_shape, dtype=np.float32)\n        for (idx, np_arr) in enumerate(images):\n            images_np_arr_static[idx, : np_arr.shape[0], : np_arr.shape[1], :] = np_arr\n        input = flow.tensor(\n            images_np_arr_static, dtype=flow.float, device=flow.device(\"cpu\")\n        )\n        images_buffer = flow.tensor_to_tensor_buffer(input, instance_dims=3)\n        of_aligned_image = image_batch_aligner(images_buffer).numpy()\n        test_case.assertTrue(\n            np.array_equal(aligned_image_shape, of_aligned_image.shape)\n        )\n        empty_image_array = np.zeros(aligned_image_shape, np.float32)\n        for (empty_image, image) in zip(empty_image_array, images):\n            empty_image[0 : image.shape[0], 0 : image.shape[1], :] = image\n        test_case.assertTrue(np.array_equal(of_aligned_image, empty_image_array))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_image_decode.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport cv2\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestImageDecode(flow.unittest.TestCase):\n    def test_image_decode(test_case):\n        images = [\n            flow.unittest.dataset_dir(\"mscoco_2017/val2017/000000000139.jpg\"),\n            flow.unittest.dataset_dir(\"mscoco_2017/val2017/000000000632.jpg\"),\n        ]\n        image_files = [open(im, \"rb\") for im in images]\n        images_bytes = [imf.read() for imf in image_files]\n        static_shape = (len(images_bytes), max([len(bys) for bys in images_bytes]))\n        for imf in image_files:\n            imf.close()\n        image_decoder = flow.nn.image.decode(color_space=\"BGR\")\n        images_np_arr = [\n            np.frombuffer(bys, dtype=np.byte).reshape(1, -1) for bys in images_bytes\n        ]\n        images_np_arr_static = np.zeros(static_shape, dtype=np.int8)\n        for (idx, np_arr) in enumerate(images_np_arr):\n            images_np_arr_static[idx, : np_arr.shape[1]] = np_arr\n        input = flow.tensor(\n            images_np_arr_static, dtype=flow.int8, device=flow.device(\"cpu\")\n        )\n        images_buffer = flow.tensor_to_tensor_buffer(input, instance_dims=1)\n        decoded_images_buffer = image_decoder(images_buffer)\n        of_decoded_images = decoded_images_buffer.numpy()\n        cv2_images = [cv2.imread(image) for image in images]\n        cv2_decoded_images = [np.array(image) for image in cv2_images]\n        for (of_decoded_image, cv2_decoded_image) in zip(\n            of_decoded_images, cv2_decoded_images\n        ):\n            test_case.assertTrue(len(of_decoded_image.shape) == 3)\n            test_case.assertTrue(len(cv2_decoded_image.shape) == 3)\n            test_case.assertTrue(np.allclose(of_decoded_image, cv2_decoded_image))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_image_flip.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport cv2\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _of_image_flip(images, image_static_shape, flip_code):\n    image_tensors = flow.tensor(images, dtype=flow.float, device=flow.device(\"cpu\"))\n    image_tensor_buffer = flow.tensor_to_tensor_buffer(image_tensors, instance_dims=3)\n    flip_images = flow.nn.image.flip()(image_tensor_buffer, flip_code)\n    return flip_images.numpy()\n\n\ndef _read_images_by_cv(image_files):\n    images = [cv2.imread(image_file).astype(np.single) for image_file in image_files]\n    return [np.expand_dims(image, axis=0) for image in images]\n\n\ndef _get_images_static_shape(images):\n    image_shapes = [image.shape for image in images]\n    image_static_shape = np.amax(image_shapes, axis=0)\n    assert isinstance(\n        image_static_shape, np.ndarray\n    ), \"image_shapes: {}, image_static_shape: {}\".format(\n        str(image_shapes), str(image_static_shape)\n    )\n    image_static_shape = image_static_shape.tolist()\n    assert image_static_shape[0] == 1, str(image_static_shape)\n    image_static_shape[0] = len(image_shapes)\n    return image_static_shape\n\n\ndef _compare_image_flip_with_cv(test_case, image_files):\n    images = _read_images_by_cv(image_files)\n    assert all([len(image.shape) == 4 for image in images])\n    image_static_shape = _get_images_static_shape(images)\n    image_paddings = np.zeros(tuple(image_static_shape))\n    for (idx, image) in enumerate(images):\n        image_paddings[\n            idx, : image.shape[1], : image.shape[2], : image.shape[3]\n        ] = image\n    flip_code = flow.ones(image_static_shape[0], dtype=flow.int8)\n    flip_images = _of_image_flip(image_paddings, image_static_shape, flip_code)\n    for (image, flip_image) in zip(image_paddings, flip_images):\n        exp_flip_image = cv2.flip(image.squeeze(), 1)\n        test_case.assertTrue(np.allclose(exp_flip_image, flip_image))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestImageFlip(flow.unittest.TestCase):\n    def test_image_flip(test_case):\n        _compare_image_flip_with_cv(\n            test_case,\n            [\n                flow.unittest.dataset_dir(\"mscoco_2017/val2017/000000000139.jpg\"),\n                flow.unittest.dataset_dir(\"mscoco_2017/val2017/000000000632.jpg\"),\n            ],\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_image_normalize.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport cv2\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _of_image_normalize(images, image_static_shape, std, mean):\n    image_zeros = np.zeros(tuple(image_static_shape))\n    for (idx, image) in enumerate(images):\n        image_zeros[idx, : image.shape[1], : image.shape[2], : image.shape[3]] = image\n    image_tensors = flow.tensor(\n        image_zeros, dtype=flow.float, device=flow.device(\"cpu\")\n    )\n    image_tensor_buffer = flow.tensor_to_tensor_buffer(image_tensors, instance_dims=3)\n    image_normalizer = flow.nn.image.normalize(std, mean)\n    norm_images = image_normalizer(image_tensor_buffer)\n    return norm_images.numpy()\n\n\ndef _read_images_by_cv(image_files):\n    images = [cv2.imread(image_file).astype(np.single) for image_file in image_files]\n    return [np.expand_dims(image, axis=0) for image in images]\n\n\ndef _get_images_static_shape(images):\n    image_shapes = [image.shape for image in images]\n    image_static_shape = np.amax(image_shapes, axis=0)\n    assert isinstance(\n        image_static_shape, np.ndarray\n    ), \"image_shapes: {}, image_static_shape: {}\".format(\n        str(image_shapes), str(image_static_shape)\n    )\n    image_static_shape = image_static_shape.tolist()\n    assert image_static_shape[0] == 1, str(image_static_shape)\n    image_static_shape[0] = len(image_shapes)\n    return image_static_shape\n\n\ndef _compare_image_normalize(test_case, image_files, std, mean):\n    images = _read_images_by_cv(image_files)\n    assert all([len(image.shape) == 4 for image in images])\n    image_static_shape = _get_images_static_shape(images)\n    norm_images = _of_image_normalize(images, image_static_shape, std, mean)\n    std_array = np.array(std).reshape(1, 1, 1, -1)\n    mean_array = np.array(mean).reshape(1, 1, 1, -1)\n    for (image, norm_image) in zip(images, norm_images):\n        np_norm_image = np.squeeze((image - mean_array) / std_array, axis=0)\n        norm_image = norm_image[\n            : np_norm_image.shape[0], : np_norm_image.shape[1], : np_norm_image.shape[2]\n        ]\n        test_case.assertTrue(np.allclose(np_norm_image, norm_image))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestImageNormalize(flow.unittest.TestCase):\n    def test_image_normalize(test_case):\n        _compare_image_normalize(\n            test_case,\n            [\n                flow.unittest.dataset_dir(\"mscoco_2017/val2017/000000000139.jpg\"),\n                flow.unittest.dataset_dir(\"mscoco_2017/val2017/000000000632.jpg\"),\n            ],\n            (102.9801, 115.9465, 122.7717),\n            (1.0, 1.0, 1.0),\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_image_resize.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport cv2\nimport image_test_util\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\n\n\ndef _of_image_resize(\n    image_list,\n    dtype=flow.float32,\n    origin_dtype=flow.float32,\n    channels=3,\n    keep_aspect_ratio=False,\n    target_size=None,\n    min_size=None,\n    max_size=None,\n    resize_side=\"shorter\",\n    interpolation_type=\"bilinear\",\n):\n    assert isinstance(image_list, (list, tuple))\n    assert all((isinstance(image, np.ndarray) for image in image_list))\n    assert all((image.ndim == 3 for image in image_list))\n    assert all((image.shape[2] == channels for image in image_list))\n    res_image_list = []\n    res_size_list = []\n    res_scale_list = []\n    image_resize_module = nn.image.Resize(\n        target_size=target_size,\n        min_size=min_size,\n        max_size=max_size,\n        keep_aspect_ratio=keep_aspect_ratio,\n        resize_side=resize_side,\n        dtype=dtype,\n        interpolation_type=interpolation_type,\n        channels=channels,\n    )\n    for image in image_list:\n        tensor_dtype = dtype if keep_aspect_ratio else origin_dtype\n        input = flow.tensor(\n            np.expand_dims(image, axis=0), dtype=tensor_dtype, device=flow.device(\"cpu\")\n        )\n        image_buffer = flow.tensor_to_tensor_buffer(input, instance_dims=3)\n        (res_image, scale, new_size) = image_resize_module(image_buffer)\n        res_image = res_image.numpy()\n        scale = scale.numpy()\n        if not keep_aspect_ratio:\n            new_size = np.asarray([(target_size, target_size)])\n        else:\n            new_size = new_size.numpy()\n        res_image_list.append(res_image[0])\n        res_size_list.append(new_size[0])\n        res_scale_list.append(scale[0])\n    return (res_image_list, res_scale_list, res_size_list)\n\n\ndef _get_resize_size_and_scale(\n    w,\n    h,\n    target_size,\n    min_size=None,\n    max_size=None,\n    keep_aspect_ratio=True,\n    resize_side=\"shorter\",\n):\n    if keep_aspect_ratio:\n        assert isinstance(target_size, int)\n        aspect_ratio = float(min((w, h))) / float(max((w, h)))\n        (\n            min_res_size,\n            max_res_size,\n        ) = image_test_util.compute_keep_aspect_ratio_resized_size(\n            target_size, min_size, max_size, aspect_ratio, resize_side\n        )\n        if w < h:\n            res_w = min_res_size\n            res_h = max_res_size\n        else:\n            res_w = max_res_size\n            res_h = min_res_size\n    else:\n        assert isinstance(target_size, (list, tuple))\n        assert len(target_size) == 2\n        assert all((isinstance(size, int) for size in target_size))\n        (res_w, res_h) = target_size\n    scale_w = res_w / w\n    scale_h = res_h / h\n    return ((res_w, res_h), (scale_w, scale_h))\n\n\ndef _cv_image_resize(\n    
image_list,\n    target_size,\n    keep_aspect_ratio=True,\n    min_size=None,\n    max_size=None,\n    resize_side=\"shorter\",\n    interpolation=cv2.INTER_LINEAR,\n    dtype=np.float32,\n):\n    res_image_list = []\n    res_size_list = []\n    res_scale_list = []\n    for image in image_list:\n        (h, w) = image.shape[:2]\n        (new_size, scale) = _get_resize_size_and_scale(\n            w, h, target_size, min_size, max_size, keep_aspect_ratio, resize_side\n        )\n        res_image_list.append(\n            cv2.resize(image.squeeze(), new_size, interpolation=interpolation).astype(\n                dtype\n            )\n        )\n        res_size_list.append(new_size)\n        res_scale_list.append(scale)\n    return (res_image_list, res_scale_list, res_size_list)\n\n\ndef _test_image_resize_with_cv(\n    test_case,\n    image_files,\n    target_size,\n    min_size=None,\n    max_size=None,\n    keep_aspect_ratio=True,\n    resize_side=\"shorter\",\n    dtype=flow.float32,\n    origin_dtype=None,\n):\n    if origin_dtype is None:\n        origin_dtype = dtype\n    image_list = image_test_util.read_images_by_cv(image_files, origin_dtype)\n    (of_res_images, of_scales, of_new_sizes) = _of_image_resize(\n        image_list=image_list,\n        dtype=dtype,\n        origin_dtype=origin_dtype,\n        keep_aspect_ratio=keep_aspect_ratio,\n        target_size=target_size,\n        min_size=min_size,\n        max_size=max_size,\n        resize_side=resize_side,\n    )\n    (cv_res_images, cv_scales, cv_new_sizes) = _cv_image_resize(\n        image_list=image_list,\n        target_size=target_size,\n        keep_aspect_ratio=keep_aspect_ratio,\n        min_size=min_size,\n        max_size=max_size,\n        resize_side=resize_side,\n        dtype=flow.convert_oneflow_dtype_to_numpy_dtype(dtype),\n    )\n    for (\n        of_res_image,\n        cv_res_image,\n        of_scale,\n        cv_scale,\n        of_new_size,\n        cv_new_size,\n    ) in zip(\n        of_res_images, cv_res_images, of_scales, cv_scales, of_new_sizes, cv_new_sizes\n    ):\n        test_case.assertTrue(np.allclose(of_res_image, cv_res_image))\n        test_case.assertTrue(np.allclose(of_scale, cv_scale))\n        test_case.assertTrue(np.allclose(of_new_size, cv_new_size))\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(\n    not flow.unittest.env.eager_execution_enabled(),\n    \".numpy() doesn't work in lazy mode\",\n)\nclass TestImageResize(flow.unittest.TestCase):\n    def test_image_resize_to_fixed_size(test_case):\n        (image_files, _) = image_test_util.random_sample_images_from_coco()\n        _test_image_resize_with_cv(\n            test_case, image_files, target_size=(224, 224), keep_aspect_ratio=False\n        )\n\n    def test_image_resize_shorter_to_target_size(test_case):\n        (image_files, _) = image_test_util.random_sample_images_from_coco()\n        _test_image_resize_with_cv(\n            test_case,\n            image_files,\n            target_size=800,\n            keep_aspect_ratio=True,\n            resize_side=\"shorter\",\n        )\n\n    def test_image_resize_longer_to_target_size(test_case):\n        (image_files, _) = image_test_util.random_sample_images_from_coco()\n        _test_image_resize_with_cv(\n            test_case,\n            image_files,\n            target_size=1000,\n            keep_aspect_ratio=True,\n            resize_side=\"longer\",\n        )\n\n    def test_image_resize_shorter_to_target_size_with_max_size(test_case):\n        (image_files, 
_) = image_test_util.random_sample_images_from_coco()\n        _test_image_resize_with_cv(\n            test_case,\n            image_files,\n            target_size=800,\n            max_size=1333,\n            keep_aspect_ratio=True,\n            resize_side=\"shorter\",\n        )\n\n    def test_image_resize_longer_to_target_size_with_min_size(test_case):\n        (image_files, _) = image_test_util.random_sample_images_from_coco()\n        _test_image_resize_with_cv(\n            test_case,\n            image_files,\n            target_size=1000,\n            min_size=600,\n            keep_aspect_ratio=True,\n            resize_side=\"longer\",\n        )\n\n    def test_image_resize_to_fixed_size_with_dtype_uint8(test_case):\n        (image_files, _) = image_test_util.random_sample_images_from_coco()\n        _test_image_resize_with_cv(\n            test_case,\n            image_files,\n            target_size=(1000, 1000),\n            keep_aspect_ratio=False,\n            dtype=flow.uint8,\n        )\n\n    def test_image_reisze_shorter_to_target_size_with_max_size_with_dtype_uint8(\n        test_case,\n    ):\n        (image_files, _) = image_test_util.random_sample_images_from_coco()\n        _test_image_resize_with_cv(\n            test_case,\n            image_files,\n            target_size=1000,\n            max_size=1600,\n            keep_aspect_ratio=True,\n            resize_side=\"shorter\",\n            dtype=flow.uint8,\n        )\n\n    def test_image_resize_uint8_to_float(test_case):\n        (image_files, _) = image_test_util.random_sample_images_from_coco()\n        _test_image_resize_with_cv(\n            test_case,\n            image_files,\n            target_size=(1000, 1000),\n            keep_aspect_ratio=False,\n            dtype=flow.float32,\n            origin_dtype=flow.uint8,\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_in_top_k.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _topk_np(input, k, dim: int = -1, largest: bool = True, _sorted: bool = True):\n    in_dims = input.shape\n    out_dims = list(in_dims)\n    num_axes = len(input.shape)\n    if dim < 0:\n        dim = dim + num_axes\n    n = in_dims[dim]\n    if k > n:\n        k = n\n    out_dims[dim] = k\n    out_dims = tuple(out_dims)\n    prev_dims = 1\n    next_dims = 1\n    for i in range(dim):\n        prev_dims *= in_dims[i]\n    for i in range(dim + 1, len(in_dims)):\n        next_dims *= in_dims[i]\n    input_flat = input.reshape((prev_dims, n, next_dims))\n    values_ref = np.ndarray(shape=(prev_dims, k, next_dims), dtype=input.dtype)\n    values_ref.fill(0)\n    indices_ref = np.ndarray(shape=(prev_dims, k, next_dims), dtype=np.int64)\n    indices_ref.fill(-1)\n    for i in range(prev_dims):\n        for j in range(next_dims):\n            kv = []\n            for x in range(n):\n                val = input_flat[i, x, j]\n                y = x * next_dims + i * in_dims[dim] * next_dims + j\n                kv.append((val, x, y))\n            cnt = 0\n            for (val, x, y) in sorted(kv, key=lambda x: (x[0], -x[1]), reverse=largest):\n                values_ref[i, cnt, j] = val\n                indices_ref[i, cnt, j] = x\n                cnt += 1\n                if cnt >= k or cnt >= n:\n                    break\n    values_ref = values_ref.reshape(out_dims)\n    indices_ref = indices_ref.reshape(out_dims)\n    return (values_ref, indices_ref)\n\n\ndef _in_top_k_np(targets, predictions, k):\n    assert (\n        targets.shape[0] == predictions.shape[0]\n    ), \"The num of targets must equal the num of predictions\"\n    assert len(targets.shape) == 1, \"The dimension of targets must be 1\"\n    assert len(predictions.shape) == 2, \"The dimension of predictions must be 2\"\n    results = np.zeros_like(targets, dtype=np.int8)\n    for i in range(len(results)):\n        (_, indices_topk) = _topk_np(predictions[i], k)\n        if targets[i] in indices_topk:\n            results[i] = 1\n    return results\n\n\ndef _test_in_top_k_impl(test_case, shape, k, device):\n    np_targets = np.random.randint(0, shape[1], size=shape[0])\n    np_predictions = np.random.rand(*shape)\n    of_targets = flow.tensor(\n        np_targets, dtype=flow.int32, device=flow.device(device), requires_grad=False\n    )\n    of_predictions = flow.tensor(\n        np_predictions,\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out = flow.in_top_k(of_targets, of_predictions, k)\n    np_out = _in_top_k_np(np_targets, np_predictions, k)\n    test_case.assertTrue(\n        np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001, equal_nan=True)\n    
)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestInTopK(flow.unittest.TestCase):\n    def test_in_top_k(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(2, 3), (3, 4), (5, 6)]\n        arg_dict[\"k\"] = [1, 2, 5]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_in_top_k_impl(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_index_add.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport numpy as np\nimport torch as torch_origin\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\nimport unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_index_add(test_case, device):\n    torch_origin_x = torch_origin.ones(5, 3).to(device)\n    torch_origin_t = torch_origin.tensor(\n        [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch_origin.float\n    ).to(device)\n    torch_origin_index = torch_origin.tensor([0, 4, 2]).to(device)\n    torch_origin_y = torch_origin.index_add(\n        torch_origin_x, 0, torch_origin_index, torch_origin_t\n    )\n    torch_origin_y_alpha = torch_origin.index_add(\n        torch_origin_x, 0, torch_origin_index, torch_origin_t, alpha=-1\n    )\n\n    flow_x = flow.ones(5, 3).to(device)\n    flow_t = flow.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=flow.float).to(device)\n    flow_index = flow.tensor([0, 4, 2]).to(device)\n    flow_y = flow.index_add(flow_x, 0, flow_index, flow_t)\n    flow_y_alpha = flow.index_add(flow_x, 0, flow_index, flow_t, alpha=-1)\n    test_case.assertTrue(\n        np.allclose(torch_origin_y.cpu().numpy(), flow_y.cpu().numpy(), 1e-05, 1e-05)\n    )\n    test_case.assertTrue(\n        np.allclose(\n            torch_origin_y_alpha.cpu().numpy(), flow_y_alpha.cpu().numpy(), 1e-05, 1e-05\n        )\n    )\n\n    # check inplace\n    torch_origin_x.index_add_(0, torch_origin_index, torch_origin_t)\n    flow_x.index_add_(0, flow_index, flow_t)\n    test_case.assertTrue(\n        np.allclose(torch_origin_y.cpu().numpy(), flow_y.cpu().numpy(), 1e-05, 1e-05)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestIndexAdd(flow.unittest.TestCase):\n    def test_index_add(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_index_add]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @profile(torch.index_add)\n    def profile_index_add(test_case):\n        torch.index_add(\n            torch.ones(50, 30),\n            0,\n            torch.arange(30),\n            torch.arange(1, 901, dtype=torch.float32).reshape(30, 30),\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_index_select.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n\nimport oneflow as flow\nimport oneflow.unittest\n\nimport unittest\n\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestIndexSelect(flow.unittest.TestCase):\n    @autotest()\n    def test_index_select_by_random(test_case):\n        device = random_device()\n\n        # test 4 dimensions tensor\n        dim = random(0, 4).to(int)\n\n        tensor_dim = []\n        for i in range(0, 4):\n            tensor_dim.append(random(2, 6).to(int).value())\n\n        index = random_tensor(\n            ndim=1,\n            dim0=random(1, 10).to(int),\n            low=0,\n            high=tensor_dim[dim.value()],\n            dtype=int,\n        ).to(device)\n\n        x = random_tensor(\n            ndim=4,\n            dim0=tensor_dim[0],\n            dim1=tensor_dim[1],\n            dim2=tensor_dim[2],\n            dim3=tensor_dim[3],\n        ).to(device)\n\n        y = torch.index_select(x, dim, index)\n\n        return y\n\n    @autotest(auto_backward=False)\n    def test_index_select_bool_by_random(test_case):\n        device = random_device()\n\n        # test 4 dimensions tensor\n        dim = random(0, 4).to(int)\n\n        tensor_dim = []\n        for i in range(0, 4):\n            tensor_dim.append(random(2, 6).to(int).value())\n\n        index = random_tensor(\n            ndim=1,\n            dim0=random(1, 10).to(int),\n            low=0,\n            high=tensor_dim[dim.value()],\n            dtype=int,\n        ).to(device)\n\n        x = random_tensor(\n            ndim=4,\n            dim0=tensor_dim[0],\n            dim1=tensor_dim[1],\n            dim2=tensor_dim[2],\n            dim3=tensor_dim[3],\n        ).to(device=device, dtype=torch.bool)\n\n        y = torch.index_select(x, dim, index)\n\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_info.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow.unittest\n\n\ndef _test_finfo(test_case, dtype):\n    # test finfo without input params\n    if dtype is None:\n        finfo = torch.finfo()\n    else:\n        finfo = torch.finfo(dtype)\n    torch_finfo = finfo.pytorch\n    flow_finfo = finfo.oneflow\n    test_case.assertEqual(torch_finfo.max, flow_finfo.max)\n    test_case.assertEqual(torch_finfo.min, flow_finfo.min)\n    test_case.assertEqual(torch_finfo.bits, flow_finfo.bits)\n    test_case.assertEqual(torch_finfo.eps, flow_finfo.eps)\n    test_case.assertEqual(torch_finfo.tiny, flow_finfo.tiny)\n    test_case.assertEqual(torch_finfo.resolution, flow_finfo.resolution)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTypeInfo(flow.unittest.TestCase):\n    def test_iinfo(test_case):\n        for dtype in [torch.uint8, torch.int8, torch.int32, torch.int64]:\n            iinfo = torch.iinfo(dtype)\n            # checker not implemented for type <class 'torch.iinfo'> and <class 'oneflow.iinfo'>\n            # so return all fields as a tuple\n            return iinfo.max, iinfo.min, iinfo.bits\n\n    def test_finfo(test_case):\n        for dtype in [None, torch.half, torch.bfloat16, torch.float, torch.double]:\n            _test_finfo(test_case, dtype)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_initializer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow.unittest\n\n\nclass DataChecker:\n    check_list = [\n        \"mean\",\n        \"std\",\n        \"min\",\n        \"max\",\n        \"value\",\n        \"lambda_func\",\n    ]\n\n    def __init__(self, **kwargs):\n        self.checkers = {}\n        for key in self.check_list:\n            if key in kwargs:\n                self.checkers[key] = kwargs[key]\n\n    def __call__(self, test_case, tensor):\n        for func in [\"mean\", \"std\"]:\n            if func in self.checkers:\n                of_res = eval(f\"tensor.{func}\")().numpy()\n                checker_res = self.checkers[func]\n                test_case.assertTrue(\n                    np.allclose(of_res, checker_res, rtol=1e-1, atol=1e-1),\n                    f\"{func} not equal, {of_res} vs {checker_res}\",\n                )\n\n        if \"min\" in self.checkers:\n            test_case.assertTrue(np.all(tensor.numpy() >= self.checkers[\"min\"]))\n\n        if \"max\" in self.checkers:\n            test_case.assertTrue(np.all(tensor.numpy() <= self.checkers[\"max\"]))\n\n        if \"value\" in self.checkers:\n            test_case.assertTrue(np.all(tensor.numpy() == self.checkers[\"value\"]))\n\n        if \"lambda_func\" in self.checkers:\n            test_case.assertTrue(\n                np.allclose(\n                    tensor.numpy(),\n                    self.checkers[\"lambda_func\"](tensor.shape),\n                    rtol=1e-4,\n                    atol=1e-4,\n                )\n            )\n\n\n# NOTE(wyg): register initializers to this list\ncheck_func_list = [\n    # oneflow.nn.init.normal_\n    {\n        \"func\": flow.nn.init.normal_,\n        \"params\": {\"mean\": 0.0, \"std\": 1.0},\n        \"checker\": DataChecker(mean=0.0, std=1.0),\n    },\n    # oneflow.nn.init.xavier_normal_\n    {\n        \"func\": flow.nn.init.xavier_normal_,\n        \"params\": {\"gain\": 1.0},\n        \"checker\": DataChecker(mean=0.0, std=0.0625),\n    },\n    # oneflow.nn.init.kaiming_normal_\n    {\n        \"func\": flow.nn.init.kaiming_normal_,\n        \"params\": {\"mode\": \"fan_in\"},\n        \"checker\": DataChecker(mean=0.0, std=0.0883883476),\n    },\n    {\n        \"func\": flow.nn.init.kaiming_normal_,\n        \"params\": {\"mode\": \"fan_out\"},\n        \"checker\": DataChecker(mean=0.0, std=0.0883883476),\n    },\n    {\n        \"func\": flow.nn.init.kaiming_normal_,\n        \"params\": {\"mode\": \"fan_in\", \"a\": 2.0, \"nonlinearity\": \"leaky_relu\"},\n        \"checker\": DataChecker(mean=0.0, std=0.0395284708),\n    },\n    {\n        \"func\": flow.nn.init.kaiming_normal_,\n        \"params\": {\"mode\": \"fan_in\", \"a\": 2.0, \"nonlinearity\": \"linear\"},\n        \"checker\": DataChecker(mean=0.0, std=0.0625),\n    },\n    # 
oneflow.nn.init.trunc_normal_\n    {\n        \"func\": flow.nn.init.trunc_normal_,\n        \"params\": {\"mean\": 0.0, \"std\": 1.0, \"a\": -5.0, \"b\": 5.0},\n        \"checker\": DataChecker(min=-5.0, max=5.0),\n    },\n    # oneflow.nn.init.uniform_\n    {\n        \"func\": flow.nn.init.uniform_,\n        \"params\": {\"a\": 0.0, \"b\": 1.0},\n        \"checker\": DataChecker(min=0.0, max=1.0, mean=0.5, std=0.28849875926971436),\n    },\n    # oneflow.nn.init.xavier_uniform_\n    {\n        \"func\": flow.nn.init.xavier_uniform_,\n        \"params\": {\"gain\": 1.0},\n        \"checker\": DataChecker(\n            min=-0.10825317547305482, max=0.10825317547305482, mean=0.0, std=0.0625\n        ),\n    },\n    # oneflow.nn.init.kaiming_uniform_\n    {\n        \"func\": flow.nn.init.kaiming_uniform_,\n        \"params\": {\"mode\": \"fan_in\"},\n        \"checker\": DataChecker(\n            min=-0.15309310892394865, max=0.15309310892394865, mean=0.0, std=0.0883883476\n        ),\n    },\n    {\n        \"func\": flow.nn.init.kaiming_uniform_,\n        \"params\": {\"mode\": \"fan_out\"},\n        \"checker\": DataChecker(\n            min=-0.15309310892394865, max=0.15309310892394865, mean=0.0, std=0.0883883476\n        ),\n    },\n    {\n        \"func\": flow.nn.init.kaiming_uniform_,\n        \"params\": {\"mode\": \"fan_in\", \"a\": 2.0, \"nonlinearity\": \"leaky_relu\"},\n        \"checker\": DataChecker(\n            min=-0.06846531968814576,\n            max=0.06846531968814576,\n            mean=0.0,\n            std=0.0395284708,\n        ),\n    },\n    {\n        \"func\": flow.nn.init.kaiming_uniform_,\n        \"params\": {\"mode\": \"fan_in\", \"a\": 2.0, \"nonlinearity\": \"linear\"},\n        \"checker\": DataChecker(\n            min=-0.10825317547305482, max=0.10825317547305482, mean=0.0, std=0.0625\n        ),\n    },\n    # oneflow.nn.init.eye_\n    {\n        \"func\": flow.nn.init.eye_,\n        \"params\": {},\n        \"checker\": DataChecker(lambda_func=lambda size: np.eye(*size)),\n    },\n]\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestInitializer(flow.unittest.TestCase):\n    def test_initializer(test_case):\n        default_shape = (256, 256)\n        for device in [\"cpu\", \"cuda\"]:\n            for check_func in check_func_list:\n                tensor = flow.empty(*default_shape, device=flow.device(device))\n                check_func[\"func\"](tensor, **check_func[\"params\"])\n                try:\n                    check_func[\"checker\"](test_case, tensor)\n                except AssertionError as e:\n                    print(\n                        f\"Failed: {check_func['func'].__name__} {check_func['params']}\"\n                    )\n                    raise e\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_instancenorm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_instancenorm1d(test_case, device):\n    input_arr = np.array(\n        [\n            [\n                [-0.1091, 2.0041, 0.885, -0.0412],\n                [-1.2055, 0.7442, 2.33, 1.2411],\n                [-1.2466, 0.3667, 1.2267, 0.3043],\n            ],\n            [\n                [-0.2484, -1.1407, 0.3352, 0.6687],\n                [-0.2975, -0.0227, -0.2302, -0.3762],\n                [-0.7759, -0.6789, 1.1444, 1.8077],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    output_arr = np.array(\n        [\n            [\n                [-0.9262, 1.5395, 0.2337, -0.847],\n                [-1.5486, -0.026, 1.2125, 0.3621],\n                [-1.5807, 0.2287, 1.1933, 0.1587],\n            ],\n            [\n                [-0.2215, -1.5212, 0.6285, 1.1143],\n                [-0.5016, 1.5917, 0.011, -1.1011],\n                [-1.0207, -0.9346, 0.6833, 1.2719],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    m = flow.nn.InstanceNorm1d(num_features=3, eps=1e-05, momentum=0.1).to(\n        device=flow.device(device)\n    )\n    x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device))\n    y = m(x)\n    test_case.assertTrue(np.allclose(y.numpy(), output_arr, rtol=1e-3, atol=1e-3))\n    m.eval()\n    y = m(x)\n    test_case.assertTrue(np.allclose(y.numpy(), output_arr, rtol=1e-3, atol=1e-3))\n\n\ndef _test_instancenorm2d(test_case, device):\n    input_arr = np.array(\n        [\n            [\n                [\n                    [-0.8791, 0.2553, 0.7403, -0.2859],\n                    [0.8006, -1.7701, -0.9617, 0.1705],\n                    [0.2842, 1.7825, 0.3365, -0.8525],\n                ],\n                [\n                    [0.7332, -0.0737, 0.7245, -0.6551],\n                    [1.4461, -0.1827, 0.9737, -2.1571],\n                    [0.4657, 0.7244, 0.3378, 0.1775],\n                ],\n            ],\n            [\n                [\n                    [1.8896, 1.8686, 0.1896, 0.9817],\n                    [-0.0671, 1.5569, 1.1449, 0.0086],\n                    [-0.9468, -0.0124, 1.3227, -0.6567],\n                ],\n                [\n                    [-0.8472, 1.3012, -1.1065, 0.9348],\n                    [1.0346, 1.5703, 0.2419, -0.7048],\n                    [0.6957, -0.4523, -0.8819, 1.0164],\n                ],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    output = np.array(\n        [\n            [\n                [\n                    [-0.9155, 0.31, 0.8339, -0.2747],\n                    [0.8991, -1.8781, -1.0048, 0.2183],\n                    [0.3412, 1.9598, 0.3977, -0.8868],\n                ],\n                
[\n                    [0.586, -0.3169, 0.5763, -0.9675],\n                    [1.3837, -0.4389, 0.8551, -2.6483],\n                    [0.2867, 0.5761, 0.1435, -0.0358],\n                ],\n            ],\n            [\n                [\n                    [1.374, 1.3515, -0.4466, 0.4017],\n                    [-0.7215, 1.0177, 0.5765, -0.6405],\n                    [-1.6636, -0.663, 0.7669, -1.353],\n                ],\n                [\n                    [-1.1583, 1.1444, -1.4363, 0.7516],\n                    [0.8586, 1.4328, 0.009, -1.0057],\n                    [0.4954, -0.7351, -1.1955, 0.8391],\n                ],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    m = flow.nn.InstanceNorm2d(num_features=2, eps=1e-05, momentum=0.1).to(\n        device=flow.device(device)\n    )\n    x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device))\n    y = m(x)\n    test_case.assertTrue(np.allclose(y.numpy(), output, 0.0001, 0.0001))\n    m.eval()\n    y = m(x)\n    test_case.assertTrue(np.allclose(y.numpy(), output, 0.0001, 0.0001))\n\n\ndef _test_instancenorm3d(test_case, device):\n    input_arr = np.array(\n        [\n            [\n                [\n                    [\n                        [1.04569761, 0.22863248, 1.42439335, 1.62249689],\n                        [-0.80578825, -0.27276461, 1.04556507, 0.56864134],\n                        [-1.24085419, -1.23960097, 0.33451416, -1.84820402],\n                    ],\n                    [\n                        [-1.511261, 1.06157517, -0.26715858, -1.32888141],\n                        [1.17976881, -0.07931171, 0.33910684, -1.93458573],\n                        [-1.72659647, 0.79049652, 0.39102785, -1.16264882],\n                    ],\n                ],\n                [\n                    [\n                        [0.30067973, -1.2912226, -0.61508225, 0.56454001],\n                        [0.87074187, -1.69257376, 0.36119148, -0.31014289],\n                        [0.20776964, 1.26195488, -1.37122193, -0.17945234],\n                    ],\n                    [\n                        [-0.31112407, -0.80682631, 0.8233194, 0.6384975],\n                        [0.57617527, 0.45505028, 1.68286151, -1.09590744],\n                        [-1.18127546, -1.07529277, 0.52779943, 1.21755926],\n                    ],\n                ],\n            ],\n            [\n                [\n                    [\n                        [-0.12832351, 1.05625455, -0.23253249, -0.64747611],\n                        [-0.00738123, -1.41390089, -1.92664144, -0.21427625],\n                        [-0.94631219, -0.86493989, 0.21026905, 0.24989732],\n                    ],\n                    [\n                        [1.3859182, 1.72002107, 0.50091892, 1.04198896],\n                        [0.71694594, 1.66417023, -1.63030052, 0.77182641],\n                        [0.71545083, 1.96458366, -1.99031931, 1.3196714],\n                    ],\n                ],\n                [\n                    [\n                        [1.80091702, 0.02834973, 0.82259214, -1.05597501],\n                        [-0.58212207, 0.44205949, -0.14740003, -0.994508],\n                        [1.14678114, -0.39196097, 1.2554798, -0.41829324],\n                    ],\n                    [\n                        [-1.0153903, -0.25755713, -1.81756333, -1.06781159],\n                        [1.79680841, -1.9107133, -0.64325796, -1.94640775],\n                        [1.30671156, 1.20445339, -1.26262901, -0.79494188],\n      
              ],\n                ],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    output_arr = np.array(\n        [\n            [\n                [\n                    [\n                        [1.067, 0.3324, 1.4075, 1.5856],\n                        [-0.5976, -0.1184, 1.0669, 0.6381],\n                        [-0.9888, -0.9877, 0.4276, -1.5349],\n                    ],\n                    [\n                        [-1.2319, 1.0813, -0.1134, -1.068],\n                        [1.1876, 0.0555, 0.4317, -1.6126],\n                        [-1.4256, 0.8376, 0.4784, -0.9185],\n                    ],\n                ],\n                [\n                    [\n                        [0.3447, -1.3751, -0.6446, 0.6298],\n                        [0.9606, -1.8087, 0.4101, -0.3152],\n                        [0.2444, 1.3833, -1.4615, -0.174],\n                    ],\n                    [\n                        [-0.3162, -0.8518, 0.9094, 0.7097],\n                        [0.6424, 0.5115, 1.838, -1.1641],\n                        [-1.2563, -1.1418, 0.5901, 1.3353],\n                    ],\n                ],\n            ],\n            [\n                [\n                    [\n                        [-0.2327, 0.8016, -0.3236, -0.6859],\n                        [-0.1271, -1.3551, -1.8028, -0.3077],\n                        [-0.9469, -0.8758, 0.063, 0.0976],\n                    ],\n                    [\n                        [1.0895, 1.3812, 0.3167, 0.7892],\n                        [0.5054, 1.3324, -1.5441, 0.5533],\n                        [0.5041, 1.5947, -1.8584, 1.0316],\n                    ],\n                ],\n                [\n                    [\n                        [1.7507, 0.1901, 0.8894, -0.7645],\n                        [-0.3473, 0.5544, 0.0354, -0.7104],\n                        [1.1748, -0.1799, 1.2705, -0.2031],\n                    ],\n                    [\n                        [-0.7288, -0.0616, -1.435, -0.7749],\n                        [1.7471, -1.517, -0.4012, -1.5485],\n                        [1.3156, 1.2256, -0.9465, -0.5347],\n                    ],\n                ],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    m = flow.nn.InstanceNorm3d(num_features=2, eps=1e-05, momentum=0.1).to(\n        device=flow.device(device)\n    )\n    x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device))\n    y = m(x)\n    test_case.assertTrue(np.allclose(y.numpy(), output_arr, 0.0001, 0.0001))\n    m.eval()\n    y = m(x)\n    test_case.assertTrue(np.allclose(y.numpy(), output_arr, 0.0001, 0.0001))\n\n\ndef _test_instancenorm1d_backward(test_case, device):\n    input_arr = np.array(\n        [\n            [\n                [-0.1091, 2.0041, 0.885, -0.0412],\n                [-1.2055, 0.7442, 2.33, 1.2411],\n                [-1.2466, 0.3667, 1.2267, 0.3043],\n            ],\n            [\n                [-0.2484, -1.1407, 0.3352, 0.6687],\n                [-0.2975, -0.0227, -0.2302, -0.3762],\n                [-0.7759, -0.6789, 1.1444, 1.8077],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    m = flow.nn.InstanceNorm1d(num_features=2, eps=1e-05, momentum=0.1).to(\n        device=flow.device(device)\n    )\n    x = flow.tensor(input_arr, device=flow.device(device), requires_grad=True)\n    y = m(x)\n    z = y.sum()\n    z.backward()\n    test_case.assertTrue(\n        np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-05, 1e-05)\n    )\n\n\ndef 
_test_instancenorm2d_backward(test_case, device):\n    input_arr = np.array(\n        [\n            [\n                [\n                    [-0.8791, 0.2553, 0.7403, -0.2859],\n                    [0.8006, -1.7701, -0.9617, 0.1705],\n                    [0.2842, 1.7825, 0.3365, -0.8525],\n                ],\n                [\n                    [0.7332, -0.0737, 0.7245, -0.6551],\n                    [1.4461, -0.1827, 0.9737, -2.1571],\n                    [0.4657, 0.7244, 0.3378, 0.1775],\n                ],\n            ],\n            [\n                [\n                    [1.8896, 1.8686, 0.1896, 0.9817],\n                    [-0.0671, 1.5569, 1.1449, 0.0086],\n                    [-0.9468, -0.0124, 1.3227, -0.6567],\n                ],\n                [\n                    [-0.8472, 1.3012, -1.1065, 0.9348],\n                    [1.0346, 1.5703, 0.2419, -0.7048],\n                    [0.6957, -0.4523, -0.8819, 1.0164],\n                ],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    m = flow.nn.InstanceNorm2d(num_features=2, eps=1e-05, momentum=0.1).to(\n        device=flow.device(device)\n    )\n    x = flow.tensor(input_arr, device=flow.device(device), requires_grad=True)\n    y = m(x)\n    z = y.sum()\n    z.backward()\n    test_case.assertTrue(\n        np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-05, 1e-05)\n    )\n\n\ndef _test_instancenorm3d_backward(test_case, device):\n    input_arr = np.array(\n        [\n            [\n                [\n                    [\n                        [1.04569761, 0.22863248, 1.42439335, 1.62249689],\n                        [-0.80578825, -0.27276461, 1.04556507, 0.56864134],\n                        [-1.24085419, -1.23960097, 0.33451416, -1.84820402],\n                    ],\n                    [\n                        [-1.511261, 1.06157517, -0.26715858, -1.32888141],\n                        [1.17976881, -0.07931171, 0.33910684, -1.93458573],\n                        [-1.72659647, 0.79049652, 0.39102785, -1.16264882],\n                    ],\n                ],\n                [\n                    [\n                        [0.30067973, -1.2912226, -0.61508225, 0.56454001],\n                        [0.87074187, -1.69257376, 0.36119148, -0.31014289],\n                        [0.20776964, 1.26195488, -1.37122193, -0.17945234],\n                    ],\n                    [\n                        [-0.31112407, -0.80682631, 0.8233194, 0.6384975],\n                        [0.57617527, 0.45505028, 1.68286151, -1.09590744],\n                        [-1.18127546, -1.07529277, 0.52779943, 1.21755926],\n                    ],\n                ],\n            ],\n            [\n                [\n                    [\n                        [-0.12832351, 1.05625455, -0.23253249, -0.64747611],\n                        [-0.00738123, -1.41390089, -1.92664144, -0.21427625],\n                        [-0.94631219, -0.86493989, 0.21026905, 0.24989732],\n                    ],\n                    [\n                        [1.3859182, 1.72002107, 0.50091892, 1.04198896],\n                        [0.71694594, 1.66417023, -1.63030052, 0.77182641],\n                        [0.71545083, 1.96458366, -1.99031931, 1.3196714],\n                    ],\n                ],\n                [\n                    [\n                        [1.80091702, 0.02834973, 0.82259214, -1.05597501],\n                        [-0.58212207, 0.44205949, -0.14740003, -0.994508],\n                        [1.14678114, 
-0.39196097, 1.2554798, -0.41829324],\n                    ],\n                    [\n                        [-1.0153903, -0.25755713, -1.81756333, -1.06781159],\n                        [1.79680841, -1.9107133, -0.64325796, -1.94640775],\n                        [1.30671156, 1.20445339, -1.26262901, -0.79494188],\n                    ],\n                ],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    m = flow.nn.InstanceNorm3d(num_features=2, eps=1e-05, momentum=0.1).to(\n        device=flow.device(device)\n    )\n    x = flow.tensor(input_arr, device=flow.device(device), requires_grad=True)\n    y = m(x)\n    z = y.sum()\n    z.backward()\n    test_case.assertTrue(\n        np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-05, 1e-05)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestInstanceNorm(flow.unittest.TestCase):\n    def test_instancenorm(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_instancenorm1d,\n            _test_instancenorm2d,\n            _test_instancenorm3d,\n            _test_instancenorm1d_backward,\n            _test_instancenorm2d_backward,\n            _test_instancenorm3d_backward,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    # NOTE: in the following test cases, setting track_running_stats=True will fail!\n    # it could be a bug to be fixed in nn.InstanceNorm\n    @autotest(n=5, auto_backward=True, rtol=1e-3, atol=1e-3, check_graph=True)\n    def test_instancenorm_with_random_data(test_case):\n        height = random(1, 6).to(int)\n        width = random(1, 6).to(int)\n        m = torch.nn.InstanceNorm1d(\n            num_features=height,\n            eps=random().to(float) | nothing(),\n            momentum=random().to(float) | nothing(),\n            affine=random().to(bool),\n            track_running_stats=False,\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=3, dim1=height, dim2=width).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5, rtol=1e-3, atol=1e-3)\n    def test_instancenorm_with_random_data2(test_case):\n        channel = random(1, 6).to(int)\n        height = random(1, 6).to(int)\n        width = random(1, 6).to(int)\n        m = torch.nn.InstanceNorm2d(\n            num_features=channel,\n            eps=random().to(float) | nothing(),\n            momentum=random().to(float) | nothing(),\n            affine=random().to(bool),\n            track_running_stats=False,\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=4, dim1=channel, dim2=height, dim3=width).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5, rtol=1e-3, atol=1e-3)\n    def test_instancenorm_with_random_data3(test_case):\n        channel = random(1, 6).to(int)\n        depth = random(1, 6).to(int)\n        height = random(1, 6).to(int)\n        width = random(1, 6).to(int)\n        m = torch.nn.InstanceNorm3d(\n            num_features=channel,\n            eps=random().to(float) | nothing(),\n            momentum=random().to(float) | nothing(),\n            affine=random().to(bool),\n            track_running_stats=False,\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=5, dim1=channel, dim2=depth, dim3=height, 
dim4=width).to(\n            device\n        )\n        y = m(x)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_interpolate.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_interpolate_linear_1d(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 5).reshape((1, 1, 4)),\n        device=flow.device(device),\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode=\"linear\")\n    np_out = [[[1.0, 1.25, 1.75, 2.25, 2.75, 3.25, 3.75, 4.0]]]\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = [[[2.0, 2.0, 2.0, 2.0]]]\n    test_case.assertTrue(np.allclose(np_grad, input.grad.numpy(), 0.0001, 0.0001))\n    input.grad = None\n    of_out = flow.nn.functional.interpolate(\n        input, scale_factor=2.0, mode=\"linear\", align_corners=True\n    )\n    np_out = [\n        [\n            [\n                1.0,\n                1.4285714626312256,\n                1.8571429252624512,\n                2.2857141494750977,\n                2.7142856121063232,\n                3.142857074737549,\n                3.5714285373687744,\n                4.0,\n            ]\n        ]\n    ]\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = [\n        [\n            [\n                1.7142856121063232,\n                2.2857141494750977,\n                2.2857143878936768,\n                1.7142856121063232,\n            ]\n        ]\n    ]\n    test_case.assertTrue(np.allclose(np_grad, input.grad.numpy(), 0.0001, 0.0001))\n\n\ndef _test_interpolate_nearest_1d(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 5).reshape((1, 1, 4)),\n        device=flow.device(device),\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode=\"nearest\")\n    np_out = [[[1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]]]\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = [[[2.0, 2.0, 2.0, 2.0]]]\n    test_case.assertTrue(np.allclose(np_grad, input.grad.numpy(), 0.0001, 0.0001))\n\n\ndef _test_interpolate_nearest_2d(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 5).reshape((1, 1, 2, 2)),\n        device=flow.device(device),\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode=\"nearest\")\n    np_out = np.array(\n        [\n            [\n                [\n                    [1.0, 1.0, 2.0, 2.0],\n                    [1.0, 1.0, 2.0, 2.0],\n                    [3.0, 3.0, 4.0, 4.0],\n                    [3.0, 3.0, 4.0, 
4.0],\n                ]\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = [[[[4.0, 4.0], [4.0, 4.0]]]]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\ndef _test_interpolate_nearest_3d(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 9).reshape((1, 1, 2, 2, 2)),\n        device=flow.device(device),\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode=\"nearest\")\n    np_out = np.array(\n        [\n            [\n                [\n                    [\n                        [1.0, 1.0, 2.0, 2.0],\n                        [1.0, 1.0, 2.0, 2.0],\n                        [3.0, 3.0, 4.0, 4.0],\n                        [3.0, 3.0, 4.0, 4.0],\n                    ],\n                    [\n                        [1.0, 1.0, 2.0, 2.0],\n                        [1.0, 1.0, 2.0, 2.0],\n                        [3.0, 3.0, 4.0, 4.0],\n                        [3.0, 3.0, 4.0, 4.0],\n                    ],\n                    [\n                        [5.0, 5.0, 6.0, 6.0],\n                        [5.0, 5.0, 6.0, 6.0],\n                        [7.0, 7.0, 8.0, 8.0],\n                        [7.0, 7.0, 8.0, 8.0],\n                    ],\n                    [\n                        [5.0, 5.0, 6.0, 6.0],\n                        [5.0, 5.0, 6.0, 6.0],\n                        [7.0, 7.0, 8.0, 8.0],\n                        [7.0, 7.0, 8.0, 8.0],\n                    ],\n                ]\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = [[[[[8.0, 8.0], [8.0, 8.0]], [[8.0, 8.0], [8.0, 8.0]]]]]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\ndef _test_interpolate_bilinear_2d(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 5).reshape((1, 1, 2, 2)),\n        device=flow.device(device),\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode=\"bilinear\")\n    np_out = np.array(\n        [\n            [\n                [\n                    [1.0, 1.25, 1.75, 2.0],\n                    [1.5, 1.75, 2.25, 2.5],\n                    [2.5, 2.75, 3.25, 3.5],\n                    [3.0, 3.25, 3.75, 4.0],\n                ]\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = [[[[4.0, 4.0], [4.0, 4.0]]]]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\ndef _test_interpolate_bicubic_2d(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 5).reshape((1, 1, 2, 2)).astype(np.float32),\n        device=flow.device(device),\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode=\"bicubic\")\n    np_out = np.array(\n        [\n            [\n                [\n                    [0.68359375, 1.015625, 1.5625, 1.89453125],\n                    [1.34765625, 1.6796875, 2.2265625, 2.55859375],\n                    [2.44140625, 2.7734375, 3.3203125, 3.65234375],\n                    [3.10546875, 3.4375, 3.984375, 4.31640625],\n              
  ]\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = [[[[4.0, 4.0], [4.0, 4.0]]]]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\ndef _test_interpolate_bicubic_same_dim_2d(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 5).reshape((1, 1, 2, 2)).astype(np.float32),\n        device=flow.device(device),\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    of_out = flow.nn.functional.interpolate(input, scale_factor=1.0, mode=\"bicubic\")\n    np_out = [[[[1.0, 2.0], [3.0, 4.0]]]]\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = [[[[1.0, 1.0], [1.0, 1.0]]]]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\ndef _test_interpolate_trilinear_3d(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 9).reshape((1, 1, 2, 2, 2)),\n        device=flow.device(device),\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode=\"trilinear\")\n    np_out = np.array(\n        [\n            [\n                [\n                    [\n                        [1.0, 1.25, 1.75, 2.0],\n                        [1.5, 1.75, 2.25, 2.5],\n                        [2.5, 2.75, 3.25, 3.5],\n                        [3.0, 3.25, 3.75, 4.0],\n                    ],\n                    [\n                        [2.0, 2.25, 2.75, 3.0],\n                        [2.5, 2.75, 3.25, 3.5],\n                        [3.5, 3.75, 4.25, 4.5],\n                        [4.0, 4.25, 4.75, 5.0],\n                    ],\n                    [\n                        [4.0, 4.25, 4.75, 5.0],\n                        [4.5, 4.75, 5.25, 5.5],\n                        [5.5, 5.75, 6.25, 6.5],\n                        [6.0, 6.25, 6.75, 7.0],\n                    ],\n                    [\n                        [5.0, 5.25, 5.75, 6.0],\n                        [5.5, 5.75, 6.25, 6.5],\n                        [6.5, 6.75, 7.25, 7.5],\n                        [7.0, 7.25, 7.75, 8.0],\n                    ],\n                ]\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = [[[[[8.0, 8.0], [8.0, 8.0]], [[8.0, 8.0], [8.0, 8.0]]]]]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\ndef _test_interpolate_trilinear_3d_align_corners(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 9).reshape((1, 1, 2, 2, 2)),\n        device=flow.device(device),\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    of_out = flow.nn.functional.interpolate(\n        input, scale_factor=2.0, mode=\"trilinear\", align_corners=True\n    )\n    np_out = np.array(\n        [\n            [\n                [\n                    [\n                        [1.0, 1.3333332538604736, 1.6666667461395264, 2.0],\n                        [\n                            1.6666666269302368,\n                            2.0,\n                            2.3333334922790527,\n                            2.6666665077209473,\n                        ],\n                        [\n                            2.3333332538604736,\n                            
2.6666665077209473,\n                            3.0,\n                            3.3333334922790527,\n                        ],\n                        [3.0, 3.3333332538604736, 3.6666667461395264, 4.0],\n                    ],\n                    [\n                        [\n                            2.3333334922790527,\n                            2.6666665077209473,\n                            3.0,\n                            3.3333332538604736,\n                        ],\n                        [3.0, 3.3333330154418945, 3.6666665077209473, 4.0],\n                        [\n                            3.6666665077209473,\n                            4.0,\n                            4.333333492279053,\n                            4.6666669845581055,\n                        ],\n                        [4.333333492279053, 4.666666030883789, 5.0, 5.3333330154418945],\n                    ],\n                    [\n                        [3.6666667461395264, 4.0, 4.333333492279053, 4.666666507720947],\n                        [4.333333492279053, 4.666666507720947, 5.0, 5.3333330154418945],\n                        [5.0, 5.333333492279053, 5.6666669845581055, 6.0],\n                        [\n                            5.6666669845581055,\n                            6.0,\n                            6.333333492279053,\n                            6.6666669845581055,\n                        ],\n                    ],\n                    [\n                        [5.0, 5.3333330154418945, 5.666666507720947, 6.0],\n                        [\n                            5.666666507720947,\n                            5.999999523162842,\n                            6.3333330154418945,\n                            6.666666507720947,\n                        ],\n                        [6.333333492279053, 6.666666030883789, 7.0, 7.333333492279053],\n                        [7.0, 7.3333330154418945, 7.6666669845581055, 8.0],\n                    ],\n                ]\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = [\n        [\n            [\n                [[7.999999523162842, 8.0], [7.999999523162842, 8.0]],\n                [[8.0, 8.0], [8.0, 8.0]],\n            ]\n        ]\n    ]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\ndef _test_interpolate_area_1d(test_case, device):\n    input = flow.tensor(\n        np.array(\n            [\n                [\n                    [\n                        0.05580734834074974,\n                        -0.6875145435333252,\n                        -1.654430866241455,\n                        -0.6225992441177368,\n                        0.10183599591255188,\n                        0.05019790679216385,\n                        -1.2537643909454346,\n                        0.14907236397266388,\n                    ]\n                ]\n            ]\n        ),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out_1 = flow.nn.functional.interpolate(input, size=4, mode=\"area\")\n    of_out_2 = flow.nn.functional.interpolate(input, scale_factor=0.5, mode=\"area\")\n    np_out = np.array(\n        [\n            [\n                [\n                    -0.3158535957336426,\n                    -1.1385149955749512,\n                    0.07601694762706757,\n                    -0.5523459911346436,\n             
   ]\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out_1.numpy(), np_out, 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(of_out_2.numpy(), np_out, 1e-05, 1e-05))\n    of_out_1 = of_out_1.sum()\n    of_out_1.backward()\n    np_grad = np.array([[[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]]])\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\ndef _test_interpolate_area_2d(test_case, device):\n    input = flow.tensor(\n        np.array(\n            [\n                [\n                    [\n                        [\n                            0.10039155930280685,\n                            0.04879157617688179,\n                            -1.0515470504760742,\n                            0.9466001987457275,\n                        ],\n                        [\n                            0.45375481247901917,\n                            0.23611211776733398,\n                            1.343685269355774,\n                            0.3979687988758087,\n                        ],\n                        [\n                            0.05580734834074974,\n                            -0.6875145435333252,\n                            -1.654430866241455,\n                            -0.6225992441177368,\n                        ],\n                        [\n                            0.10183599591255188,\n                            0.05019790679216385,\n                            -1.2537643909454346,\n                            0.14907236397266388,\n                        ],\n                    ]\n                ]\n            ]\n        ),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out_1 = flow.nn.functional.interpolate(input, size=(2, 2), mode=\"area\")\n    of_out_2 = flow.nn.functional.interpolate(input, scale_factor=0.5, mode=\"area\")\n    np_out = np.array(\n        [\n            [\n                [\n                    [0.20976251363754272, 0.4091767966747284],\n                    [-0.1199183315038681, -0.8454304933547974],\n                ]\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out_1.numpy(), np_out, 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(of_out_2.numpy(), np_out, 1e-05, 1e-05))\n    of_out_1 = of_out_1.sum()\n    of_out_1.backward()\n    np_grad = np.array(\n        [\n            [\n                [\n                    [0.25, 0.25, 0.25, 0.25],\n                    [0.25, 0.25, 0.25, 0.25],\n                    [0.25, 0.25, 0.25, 0.25],\n                    [0.25, 0.25, 0.25, 0.25],\n                ]\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\ndef _test_interpolate_area_3d(test_case, device):\n    input = flow.tensor(\n        np.array(\n            [\n                [\n                    [\n                        [\n                            [\n                                -1.077571799600885,\n                                -0.7804538890365837,\n                                -1.2627538752119443,\n                                0.9993507145120477,\n                            ],\n                            [\n                                2.0222532489157516,\n                                1.103451377699465,\n                                -0.4377324754879578,\n                                1.890491810587517,\n                            ],\n                            [\n                
                -0.5593861899064654,\n                                -0.4949520241526519,\n                                -0.18536721363519787,\n                                -0.6098969866775772,\n                            ],\n                            [\n                                -1.6536215260171816,\n                                -1.0392583540436786,\n                                0.3686776597613967,\n                                -0.5356882834951805,\n                            ],\n                        ],\n                        [\n                            [\n                                -1.2617900664449953,\n                                -1.4390921091631532,\n                                0.20654399652431357,\n                                0.8186472101906713,\n                            ],\n                            [\n                                -0.3033378863400014,\n                                -0.8173269764076293,\n                                -0.3767515097625614,\n                                -0.11021655039337777,\n                            ],\n                            [\n                                -0.22977043608192885,\n                                1.2717196366649905,\n                                -0.4790851297878291,\n                                -1.4495369404727856,\n                            ],\n                            [\n                                -1.2802093286977783,\n                                -0.11184514806663474,\n                                1.7022167087210984,\n                                -1.7354837287725355,\n                            ],\n                        ],\n                        [\n                            [\n                                2.4706497991773606,\n                                -0.6549702631973298,\n                                -0.9318107079571676,\n                                1.4652904271682428,\n                            ],\n                            [\n                                1.1419864234341397,\n                                1.389909081086008,\n                                0.9657841900525568,\n                                -0.8563114264976619,\n                            ],\n                            [\n                                0.19515087084250754,\n                                -0.37808457398571094,\n                                0.2938625398496183,\n                                0.9279930510353327,\n                            ],\n                            [\n                                -0.9374118277994007,\n                                0.3341831730452431,\n                                -0.2792542765303833,\n                                0.38029090707066726,\n                            ],\n                        ],\n                        [\n                            [\n                                0.5918686659736041,\n                                -0.7870631089938902,\n                                -0.9534344874245392,\n                                0.31341612954718795,\n                            ],\n                            [\n                                0.7509029444145228,\n                                -0.9299288398562323,\n                                -0.7343054052782476,\n                                -0.8806481590696694,\n                            ],\n                            [\n                                -0.4707853016353985,\n              
                  0.12253641652645629,\n                                0.5088022039832846,\n                                0.520391789327562,\n                            ],\n                            [\n                                -0.0861300651163632,\n                                0.30291348404866386,\n                                -0.6268565873680123,\n                                -0.27469204305759976,\n                            ],\n                        ],\n                    ]\n                ]\n            ]\n        ),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out_1 = flow.nn.functional.interpolate(input, size=(2, 2, 2), mode=\"area\")\n    of_out_2 = flow.nn.functional.interpolate(input, scale_factor=0.5, mode=\"area\")\n    np_out = np.array(\n        [\n            [\n                [\n                    [\n                        [-0.3192335125472539, 0.2159474151198386],\n                        [-0.5121654212876662, -0.3655204892948264],\n                    ],\n                    [\n                        [0.4966693377547728, -0.2015024299324123],\n                        [-0.11470347800925032, 0.18131719803880864],\n                    ],\n                ]\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out_1.numpy(), np_out, 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(of_out_2.numpy(), np_out, 1e-05, 1e-05))\n    of_out_1 = of_out_1.sum()\n    of_out_1.backward()\n    np_grad = np.array(\n        [\n            [\n                [\n                    [\n                        [0.125, 0.125, 0.125, 0.125],\n                        [0.125, 0.125, 0.125, 0.125],\n                        [0.125, 0.125, 0.125, 0.125],\n                        [0.125, 0.125, 0.125, 0.125],\n                    ],\n                    [\n                        [0.125, 0.125, 0.125, 0.125],\n                        [0.125, 0.125, 0.125, 0.125],\n                        [0.125, 0.125, 0.125, 0.125],\n                        [0.125, 0.125, 0.125, 0.125],\n                    ],\n                    [\n                        [0.125, 0.125, 0.125, 0.125],\n                        [0.125, 0.125, 0.125, 0.125],\n                        [0.125, 0.125, 0.125, 0.125],\n                        [0.125, 0.125, 0.125, 0.125],\n                    ],\n                    [\n                        [0.125, 0.125, 0.125, 0.125],\n                        [0.125, 0.125, 0.125, 0.125],\n                        [0.125, 0.125, 0.125, 0.125],\n                        [0.125, 0.125, 0.125, 0.125],\n                    ],\n                ]\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\ndef _test_interpolate_output_size_arg_with_scalar(test_case, device):\n    mode = \"bicubic\"\n    x = flow.Tensor(8, 32, 64).to(device)\n\n    window = 16\n    t = x.shape[2]\n    x = x[:, None]\n\n    np_center = np.random.randint(window, t - window, (1,))[0]\n    np_warped = np.random.randint(np_center - window, np_center + window, (1,))[0] + 1\n\n    center = flow.tensor(np_center)\n    warped = flow.tensor(np_warped)\n\n    res = flow.nn.functional.interpolate(\n        x[:, :, :center], (warped, x.shape[3]), mode=mode, align_corners=False\n    )\n    test_case.assertTrue(np.array_equal(res.size()[0], 8))\n    test_case.assertTrue(np.array_equal(res.size()[1], 1))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass 
TestInterpolate(flow.unittest.TestCase):\n    def test_interpolate(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_interpolate_linear_1d,\n            _test_interpolate_nearest_1d,\n            _test_interpolate_nearest_2d,\n            _test_interpolate_nearest_3d,\n            _test_interpolate_bilinear_2d,\n            _test_interpolate_bicubic_2d,\n            _test_interpolate_bicubic_same_dim_2d,\n            _test_interpolate_trilinear_3d,\n            _test_interpolate_trilinear_3d_align_corners,\n            _test_interpolate_area_1d,\n            _test_interpolate_area_2d,\n            _test_interpolate_area_3d,\n            _test_interpolate_output_size_arg_with_scalar,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            for i in range(100):\n                arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_inv.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport time\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLinalgInv(flow.unittest.TestCase):\n    @unittest.skip(\"TODO: peihong, fix this test\")\n    @autotest(n=5, rtol=1e-2)\n    def test_inv_3by3_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim0=3, dim1=3, low=-1).to(device)\n        return torch.linalg.inv(x)\n\n    @autotest(n=5, rtol=1e-2)\n    def test_inv_batch_3by3_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=random(), dim1=3, dim2=3, low=-1).to(device)\n        return torch.linalg.inv(x)\n\n    @autotest(n=5, rtol=1e-2)\n    def test_inv_random_square_with_random_data(test_case):\n        device = random_device()\n        square_dim = random()\n        x = random_tensor(ndim=4, dim2=square_dim, dim3=square_dim, low=-1).to(device)\n        return torch.linalg.inv(x)\n\n    @profile(torch.linalg.inv)\n    def profile_linalg_inv(test_case):\n        torch.linalg.inv(torch.randn(1, 32, 4, 4))\n        torch.linalg.inv(torch.randn(16, 32, 4, 4))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_isclose.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\nrtol = 1e-3\n\n\ndef _perturbate(x):\n    shape = x.oneflow.shape\n    device = x.device\n    diff = (\n        random_tensor(len(shape), *shape, low=-1, high=1, requires_grad=False).to(\n            device\n        )\n        * rtol\n        * 2\n    )\n    return x + diff\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestIsClose(flow.unittest.TestCase):\n    @autotest(n=10, auto_backward=False, check_graph=False)\n    def test_isclose_with_random_data(test_case):\n        device = random_device()\n        x1 = random_tensor(requires_grad=False).to(device)\n        x2 = _perturbate(x1)\n        y = torch.isclose(x1, x2, rtol=rtol)\n        return y\n\n    @autotest(n=10, auto_backward=False, check_graph=False)\n    def test_isclose_with_0dim_data(test_case):\n        device = random_device()\n        x1 = random_tensor(requires_grad=False).to(device)\n        x2 = _perturbate(x1)\n        y = torch.isclose(x1, x2, rtol=rtol)\n        return y\n\n    @autotest(n=10, auto_backward=False, check_graph=False)\n    def test_tensor_isclose_with_random_data(test_case):\n        device = random_device()\n        x1 = random_tensor(requires_grad=False).to(device)\n        x2 = _perturbate(x1)\n        y = x1.isclose(x2, rtol=rtol)\n        return y\n\n    @autotest(n=10, auto_backward=False, check_graph=False)\n    def test_isclose_broadcast(test_case):\n        device = random_device()\n        shape = random_tensor(2, 2, 4).oneflow.shape\n        x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        x2 = _perturbate(x1[:, :1])\n        y = torch.isclose(x1, x2, rtol=rtol)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_jit_script_api.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_jit_script_api(test_case):\n    @flow.jit.script\n    def add2(x):\n        return x + x\n\n    x = flow.randn(2, 3)\n    y = add2(x)\n    test_case.assertTrue(x.size(), y.size())\n\n\ndef _test_jit_ignore_api(test_case):\n    @flow.jit.ignore\n    def add2(x):\n        return x + x\n\n    x = flow.randn(2, 3)\n    y = add2(x)\n    test_case.assertTrue(x.size(), y.size())\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestJitScriptApi(flow.unittest.TestCase):\n    def test_jit_script(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_jit_script_api, _test_jit_ignore_api]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_layer_norm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport numpy as np\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nimport torch\n\n\ndef _layer_norm(x, normalized_shape, weight=None, bias=None, eps=1e-6):\n    begin_norm_axis = len(x.shape) - len(normalized_shape)\n    begin_params_axis = len(x.shape) - len(normalized_shape)\n\n    if weight is not None and bias is not None:\n        return flow._C.layer_norm_affine(\n            x,\n            weight,\n            bias,\n            begin_norm_axis=begin_norm_axis,\n            begin_params_axis=begin_params_axis,\n            epsilon=eps,\n        )\n    else:\n        return flow._C.layer_norm(\n            x,\n            begin_norm_axis=begin_norm_axis,\n            begin_params_axis=begin_params_axis,\n            epsilon=eps,\n        )\n\n\ndef _test_layer_norm(\n    test_case,\n    shape,\n    normalized_shape,\n    affine=True,\n    eps=1e-6,\n    dtype=flow.float32,\n    device=\"cuda\",\n    backward=True,\n):\n    np_x = np.random.randn(*shape).astype(np.float32)\n    if affine:\n        np_weight = np.random.randn(*normalized_shape).astype(np.float32)\n        np_bias = np.random.randn(*normalized_shape).astype(np.float32)\n\n    # torch process\n    torch_dtype = torch.float16 if dtype is flow.float16 else torch.float32\n    torch_x = torch.tensor(np_x).to(device=device, dtype=torch_dtype)\n    if backward:\n        torch_x.requires_grad_(True)\n    torch_weight = None\n    torch_bias = None\n    if affine:\n        torch_weight = torch.tensor(np_weight).to(device=device, dtype=torch_dtype)\n        torch_bias = torch.tensor(np_bias).to(device=device, dtype=torch_dtype)\n        if backward:\n            torch_weight.requires_grad_(True)\n            torch_bias.requires_grad_(True)\n    torch_y = torch.nn.functional.layer_norm(\n        torch_x, normalized_shape, torch_weight, torch_bias, eps\n    )\n\n    if backward:\n        np_rand_init_grad = np.random.randn(*tuple(torch_y.shape)).astype(np.float32)\n        torch_rand_init_grad = torch.tensor(np_rand_init_grad).to(\n            device=device, dtype=torch_dtype\n        )\n        (torch_y * torch_rand_init_grad).sum().backward()\n\n        torch_x_grad = torch_x.grad.detach().cpu().numpy()\n        if affine:\n            torch_weight_grad = torch_weight.grad.detach().cpu().numpy()\n            torch_bias_grad = torch_bias.grad.detach().cpu().numpy()\n\n    torch_y = torch_y.detach().cpu().numpy()\n\n    # oneflow process\n    x = flow.tensor(np_x).to(device=device, dtype=dtype)\n    if backward:\n        x.requires_grad_(True)\n    weight = None\n    bias = None\n    if affine:\n        weight = flow.tensor(np_weight).to(device=device, dtype=dtype)\n        bias = flow.tensor(np_bias).to(device=device, dtype=dtype)\n        if backward:\n            weight.requires_grad_(True)\n            bias.requires_grad_(True)\n    y = _layer_norm(x, normalized_shape, 
weight, bias, eps)\n\n    if backward:\n        # reuse np_rand_init_grad from the torch pass above so both frameworks\n        # receive the same upstream gradient\n        rand_init_grad = flow.tensor(np_rand_init_grad).to(device=device, dtype=dtype)\n        (y * rand_init_grad).sum().backward()\n\n        x_grad = x.grad.detach().cpu().numpy()\n        if affine:\n            weight_grad = weight.grad.detach().cpu().numpy()\n            bias_grad = bias.grad.detach().cpu().numpy()\n\n    y = y.detach().cpu().numpy()\n\n    def compare(a, b, a_name, b_name, atol=1e-5, rtol=1e-8):\n        test_case.assertTrue(\n            np.allclose(a, b, atol=atol, rtol=rtol),\n            f\"\\n{'=' * 80}\"\n            f\"\\n{a_name}:\"\n            f\"\\n{a}\"\n            f\"\\n{'-' * 80}\"\n            f\"\\n{b_name}:\"\n            f\"\\n{b}\"\n            f\"\\n{'-' * 80}\"\n            f\"\\ndiff:\"\n            f\"\\n{a - b}\"\n            f\"\\n{'*' * 80}\"\n            f\"\\nshape={shape}\"\n            f\"\\nnormalized_shape={normalized_shape}\"\n            f\"\\naffine={affine}\"\n            f\"\\ndtype={dtype}\"\n            f\"\\ndevice={device}\"\n            f\"\\n{a_name} vs. {b_name} max abs diff: {np.max(np.abs(a - b))}\",\n        )\n\n    if dtype is flow.float16:\n        compare(y, torch_y, \"y\", \"torch_y\", 1e-2, 1e-2)\n        if backward:\n            compare(x_grad, torch_x_grad, \"x_grad\", \"torch_x_grad\", 1e-2, 1e-2)\n            if affine:\n                compare(\n                    weight_grad,\n                    torch_weight_grad,\n                    \"weight_grad\",\n                    \"torch_weight_grad\",\n                    1e-2,\n                    1e-2,\n                )\n                compare(\n                    bias_grad,\n                    torch_bias_grad,\n                    \"bias_grad\",\n                    \"torch_bias_grad\",\n                    1e-2,\n                    1e-2,\n                )\n    else:\n        compare(y, torch_y, \"y\", \"torch_y\")\n        if backward:\n            compare(x_grad, torch_x_grad, \"x_grad\", \"torch_x_grad\")\n            if affine:\n                compare(\n                    weight_grad, torch_weight_grad, \"weight_grad\", \"torch_weight_grad\",\n                )\n                compare(\n                    bias_grad, torch_bias_grad, \"bias_grad\", \"torch_bias_grad\",\n                )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestLayerNorm(flow.unittest.TestCase):\n    def test_no_affine(test_case):\n        _test_layer_norm(\n            test_case, shape=[4, 16], normalized_shape=[16], affine=False,\n        )\n\n    def test_warp_impl(test_case):\n        _test_layer_norm(\n            test_case, shape=[32, 1024], normalized_shape=[1024], dtype=flow.float16,\n        )\n        _test_layer_norm(test_case, shape=[16, 512], normalized_shape=[512])\n        _test_layer_norm(test_case, shape=[15, 512], normalized_shape=[512])\n        _test_layer_norm(test_case, shape=[16, 511], normalized_shape=[511])\n        _test_layer_norm(test_case, shape=[13, 499], normalized_shape=[499])\n\n    def test_block_smem_impl(test_case):\n        _test_layer_norm(\n            test_case, shape=[16, 2048], normalized_shape=[2048], dtype=flow.float16,\n        )\n        _test_layer_norm(test_case, shape=[8, 1536], normalized_shape=[1536])\n        _test_layer_norm(test_case, shape=[8, 2048], normalized_shape=[2048])\n        _test_layer_norm(test_case, 
shape=[7, 1536], normalized_shape=[1536])\n        _test_layer_norm(test_case, shape=[8, 1533], normalized_shape=[1533])\n        _test_layer_norm(test_case, shape=[7, 1533], normalized_shape=[1533])\n\n    def test_block_uncached_impl(test_case):\n        _test_layer_norm(\n            test_case,\n            shape=[16, 1024 * 1024],\n            normalized_shape=[1024 * 1024],\n            dtype=flow.float16,\n        )\n        _test_layer_norm(\n            test_case, shape=[8, 1024], normalized_shape=[1024], dtype=flow.double\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_lerp.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLerp(flow.unittest.TestCase):\n    @autotest(check_graph=False)\n    def test_lerp_with_broadcast_data(test_case):\n        device = random_device()\n        start = random_tensor(ndim=2, dim0=3, dim1=1).to(device)\n        end = random_tensor(ndim=2, dim0=1, dim1=3).to(device)\n        weight = random_tensor(ndim=1, dim0=1).to(device)\n        return torch.lerp(start, end, weight)\n\n    @autotest()\n    def test_lerp_with_random_data(test_case):\n        device = random_device()\n        start = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device)\n        end = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device)\n        weight = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device)\n        return torch.lerp(\n            start, end, oneof(weight, random().to(int), random().to(float))\n        )\n\n    @autotest()\n    def test_tesnor_lerp_with_random_data(test_case):\n        device = random_device()\n        start = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device)\n        end = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device)\n        weight = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device)\n        return start.lerp(end, oneof(weight, random().to(int), random().to(float)))\n\n    @autotest()\n    def test_tesnor_inplace_lerp_with_random_data(test_case):\n        device = random_device()\n        start = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device) + 0.01\n        end = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device) + 0.01\n        weight = random_tensor(ndim=3, dim0=3, dim1=4, dim2=5).to(device) + 0.01\n        return start.lerp_(end, oneof(weight, random().to(int), random().to(float)))\n\n    @profile(torch.lerp)\n    def profile_lerp(test_case):\n        torch.lerp(\n            torch.randn(1, 32, 4, 4), torch.randn(1, 32, 4, 4), torch.randn(1, 32, 4, 4)\n        )\n        torch.lerp(\n            torch.randn(8, 32, 4, 4), torch.randn(8, 32, 4, 4), torch.randn(8, 32, 4, 4)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_less.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_less_normal(test_case, device):\n    input1 = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    input2 = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.lt(input1, input2)\n    np_out = np.less(input1.numpy(), input2.numpy())\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_less_symbol(test_case, device):\n    input1 = flow.tensor(\n        np.array([1, 1, 4]).astype(np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    input2 = flow.tensor(\n        np.array([1, 2, 3]).astype(np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = input1 < input2\n    np_out = np.less(input1.numpy(), input2.numpy())\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_less_int_scalar(test_case, device):\n    np_arr = np.random.randn(2, 3, 4, 5)\n    input1 = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    input2 = 1\n    of_out = input1 < input2\n    np_out = np.less(np_arr, input2)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_less_int_tensor_int_scalr(test_case, device):\n    np_arr = np.random.randint(2, size=(2, 3, 4, 5))\n    input1 = flow.tensor(np_arr, dtype=flow.int, device=flow.device(device))\n    input2 = 1\n    of_out = input1 < input2\n    np_out = np.less(np_arr, input2)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_less_float_scalar(test_case, device):\n    np_arr = np.random.randn(3, 2, 5, 7)\n    input1 = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    input2 = 2.3\n    of_out = input1 < input2\n    np_out = np.less(np_arr, input2)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLess(flow.unittest.TestCase):\n    def test_less(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_less_normal,\n            _test_less_symbol,\n            _test_less_int_scalar,\n            _test_less_int_tensor_int_scalr,\n            _test_less_float_scalar,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_less_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(len(shape), *shape, 
requires_grad=False).to(device)\n        x2 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        y = torch.lt(x1, oneof(x2, random().to(int).to(float)))\n        return y\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_less_with_0dim_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(ndim=0).to(device)\n        x2 = random_tensor(ndim=0).to(device)\n        y = torch.lt(x1, oneof(x2, random().to(int).to(float)))\n        return y\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_tensor_less_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        x2 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        y1 = x1.lt(oneof(x2, random().to(int), random().to(float)))\n        y2 = x1 < x2\n        return (y1, y2)\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_less_bool_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(len(shape), *shape, requires_grad=False).to(\n            device=device, dtype=torch.bool\n        )\n        x2 = random_tensor(len(shape), *shape, requires_grad=False).to(\n            device=device, dtype=torch.bool\n        )\n        y = torch.lt(x1, oneof(x2, random().to(int).to(float)))\n        return y\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_tensor_less_with_0dim_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(ndim=0).to(device)\n        x2 = random_tensor(ndim=0).to(device)\n        y1 = x1.lt(oneof(x2, random().to(int), random().to(float)))\n        y2 = x1 < x2\n        return (y1, y2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_less_equal.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_less_equal_normal(test_case, device):\n    input1 = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    input2 = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.le(input1, input2)\n    np_out = np.less_equal(input1.numpy(), input2.numpy())\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_less_equal_symbol(test_case, device):\n    input1 = flow.tensor(\n        np.array([1, 1, 4]).astype(np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    input2 = flow.tensor(\n        np.array([1, 2, 3]).astype(np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = input1 <= input2\n    np_out = np.less_equal(input1.numpy(), input2.numpy())\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_less_equal_int_scalar(test_case, device):\n    np_arr = np.random.randn(2, 3, 4, 5)\n    input1 = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    input2 = 1\n    of_out = input1 <= input2\n    np_out = np.less_equal(np_arr, input2)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_less_equal_int_tensor_int_scalr(test_case, device):\n    np_arr = np.random.randint(2, size=(2, 3, 4, 5))\n    input1 = flow.tensor(np_arr, dtype=flow.int, device=flow.device(device))\n    input2 = 1\n    of_out = input1 <= input2\n    np_out = np.less_equal(np_arr, input2)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_less_equal_float_scalar(test_case, device):\n    np_arr = np.random.randn(3, 2, 5, 7)\n    input1 = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    input2 = 2.3\n    of_out = input1 <= input2\n    np_out = np.less_equal(np_arr, input2)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLessEqual(flow.unittest.TestCase):\n    def test_less_equal(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_less_equal_normal,\n            _test_less_equal_symbol,\n            _test_less_equal_int_scalar,\n            _test_less_equal_int_tensor_int_scalr,\n            _test_less_equal_float_scalar,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_linalg_cross.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLinalgCross(flow.unittest.TestCase):\n    # TODO(peihong): PyTorch 1.10 has no torch.linalg.cross, so uncomment the below code when PyTorch in ci is upgraded to 1.11.\n    # @autotest(n=5)\n    # def test_linalg_cross_with_random_data(test_case):\n    #     device = random_device()\n    #     ndim = np.random.randint(2, 6)\n    #     shape = list(np.random.randint(16, size=ndim))\n    #     index = np.random.randint(ndim)\n    #     shape[index] = 3\n\n    #     x = random_tensor(ndim, *shape).to(device)\n    #     y = random_tensor(ndim, *shape).to(device)\n    #     return torch.linalg.cross(x, y, dim=index)\n\n    # @autotest(n=10)\n    # def test_linalg_cross_with_random_data_broadcast(test_case):\n    #     device = random_device()\n    #     ndim = np.random.randint(3, 6)\n    #     shape = list(np.random.randint(16, size=ndim))\n    #     indexes = list(np.random.choice(ndim, 3))\n    #     shape[indexes[0]] = 3\n    #     x_shape = shape\n    #     y_shape = shape[:]\n    #     x_shape[indexes[1]] = 1\n    #     y_shape[indexes[2]] = 1\n\n    #     x = random_tensor(ndim, *x_shape).to(device)\n    #     y = random_tensor(ndim, *y_shape).to(device)\n    #     return torch.linalg.cross(x, y, dim=indexes[0])\n\n    # @autotest(n=1)\n    # def test_linalg_cross_with_random_data_broadcast_different_num_axes(test_case):\n    #     device = random_device()\n    #     x = random_tensor(4, 4, 5, 3, 5).to(device)\n    #     y = random_tensor(3, 1, 3, 5).to(device)\n    #     return torch.linalg.cross(x, y, dim=2)\n\n    # @autotest(n=5)\n    # def test_linalg_cross_with_random_data_default_dim(test_case):\n    #     device = random_device()\n    #     ndim = np.random.randint(2, 6)\n    #     shape = list(np.random.randint(16, size=ndim))\n    #     index = np.random.randint(ndim)\n    #     shape[index] = 3\n\n    #     x = random_tensor(ndim, *shape).to(device)\n    #     y = random_tensor(ndim, *shape).to(device)\n    #     return torch.linalg.cross(x, y)\n\n    @autotest(n=5)\n    def test_cross_with_random_data_default_dim(test_case):\n        device = random_device()\n        ndim = np.random.randint(2, 6)\n        shape = list(np.random.randint(16, size=ndim))\n        index = np.random.randint(ndim)\n        shape[index] = 3\n\n        x = random_tensor(ndim, *shape).to(device)\n        y = random_tensor(ndim, *shape).to(device)\n        return torch.cross(x, y)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_linear.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_linear_no_bias(test_case, device):\n    linear = flow.nn.Linear(3, 8, False)\n    linear = linear.to(device)\n    input_arr = np.array(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n            [0.35217619, -0.67095644, -1.58943879],\n            [0.08086036, -1.81075924, 1.20752494],\n            [0.8901075, -0.49976737, -1.07153746],\n            [-0.44872912, -1.07275683, 0.06256855],\n            [-0.22556897, 0.74798368, 0.90416439],\n            [0.48339456, -2.32742195, -0.59321527],\n        ],\n        dtype=np.float32,\n    )\n    np_weight = np.ones((3, 8)).astype(np.float32)\n    np_weight.fill(2.3)\n    x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device))\n    flow.nn.init.constant_(linear.weight, 2.3)\n    of_out = linear(x)\n    np_out = np.matmul(input_arr, np_weight)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_linear_with_bias(test_case, device):\n    linear = flow.nn.Linear(3, 8)\n    linear = linear.to(device)\n    input_arr = np.array(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n            [0.35217619, -0.67095644, -1.58943879],\n            [0.08086036, -1.81075924, 1.20752494],\n            [0.8901075, -0.49976737, -1.07153746],\n            [-0.44872912, -1.07275683, 0.06256855],\n            [-0.22556897, 0.74798368, 0.90416439],\n            [0.48339456, -2.32742195, -0.59321527],\n        ],\n        dtype=np.float32,\n    )\n    np_weight = np.ones((3, 8)).astype(np.float32)\n    np_weight.fill(2.068758)\n    np_bias = np.ones(8)\n    np_bias.fill(0.23)\n    x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device))\n    flow.nn.init.constant_(linear.weight, 2.068758)\n    flow.nn.init.constant_(linear.bias, 0.23)\n    of_out = linear(x)\n    np_out = np.matmul(input_arr, np_weight)\n    np_out += np_bias\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_linear_3_dimension_input(test_case, device):\n    input_arr = np.random.randn(2, 3, 4)\n    x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device))\n    linear = flow.nn.Linear(4, 5, True)\n    linear = linear.to(device)\n    flow.nn.init.constant_(linear.weight, 5.6)\n    flow.nn.init.constant_(linear.bias, 0.78)\n    of_out = linear(x)\n    np_weight = np.ones((4, 5)).astype(np.float32)\n    np_weight.fill(5.6)\n    np_bias = np.ones(5)\n    np_bias.fill(0.78)\n    np_out = np.matmul(input_arr, np_weight)\n    np_out += np_bias\n    
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_linear_4_dimension_input(test_case, device):\n    input_arr = np.random.randn(4, 5, 6, 7)\n    x = flow.tensor(input_arr, dtype=flow.float32, device=flow.device(device))\n    linear = flow.nn.Linear(7, 3, False)\n    linear = linear.to(device)\n    flow.nn.init.constant_(linear.weight, 11.3)\n    of_out = linear(x)\n    np_weight = np.ones((7, 3)).astype(np.float32)\n    np_weight.fill(11.3)\n    np_out = np.matmul(input_arr, np_weight)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_identity(test_case, device):\n    linear = flow.nn.Identity(54, unused_argument1=0.1, unused_argument2=False)\n    linear = linear.to(device)\n    x = flow.tensor(\n        np.random.rand(2, 3, 4, 5), dtype=flow.float32, device=flow.device(device)\n    )\n    y = linear(x)\n    test_case.assertTrue(np.array_equal(x.numpy(), y.numpy()))\n\n\ndef _test_linear_backward_with_bias(test_case, device):\n    linear = flow.nn.Linear(3, 8)\n    linear = linear.to(device)\n    x = flow.tensor(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n            [0.35217619, -0.67095644, -1.58943879],\n            [0.08086036, -1.81075924, 1.20752494],\n            [0.8901075, -0.49976737, -1.07153746],\n            [-0.44872912, -1.07275683, 0.06256855],\n            [-0.22556897, 0.74798368, 0.90416439],\n            [0.48339456, -2.32742195, -0.59321527],\n        ],\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    flow.nn.init.constant_(linear.weight, 2.068758)\n    flow.nn.init.constant_(linear.bias, 0.23)\n    of_out = linear(x)\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = np.array(\n        [\n            [16.5501, 16.5501, 16.5501],\n            [16.5501, 16.5501, 16.5501],\n            [16.5501, 16.5501, 16.5501],\n            [16.5501, 16.5501, 16.5501],\n            [16.5501, 16.5501, 16.5501],\n            [16.5501, 16.5501, 16.5501],\n            [16.5501, 16.5501, 16.5501],\n            [16.5501, 16.5501, 16.5501],\n        ]\n    )\n    test_case.assertTrue(np.allclose(np_grad, x.grad.numpy(), 0.0001, 0.0001))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLinear(flow.unittest.TestCase):\n    def test_linear_forward(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_linear_no_bias,\n            _test_linear_with_bias,\n            _test_linear_3_dimension_input,\n            _test_linear_4_dimension_input,\n            _test_identity,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_linear_backward(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_linear_backward_with_bias]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5, rtol=1e-2)\n    def test_linear_with_random_data(test_case):\n        input_size = random()\n        m = torch.nn.Linear(\n            in_features=input_size, out_features=random(), bias=random() | nothing()\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=2, dim1=input_size).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5, 
rtol=1e-2, atol=1e-4)\n    def test_linear_with_device_and_dtype(test_case):\n        input_size = random()\n        device = random_device()\n        m = torch.nn.Linear(\n            in_features=input_size,\n            out_features=random(),\n            bias=random() | nothing(),\n            device=device,\n            dtype=torch.float,\n        )\n        m.train(random())\n        m.to(device)\n        x = random_tensor(ndim=2, dim1=input_size).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5, rtol=1e-3)\n    def test_nn_functional_linear_with_random_data(test_case):\n        input_size = random()\n        device = random_device()\n        x = random_tensor(ndim=2, dim1=input_size).to(device)\n        weight = random_tensor(ndim=2, dim1=input_size).to(device)\n        y = torch.nn.functional.linear(x, weight)\n        return y\n\n    @autotest(n=5, rtol=1e-2)\n    def test_nn_functional_bias_linear_with_random_data(test_case):\n        input_size = random()\n        bias_size = random()\n        device = random_device()\n        x = random_tensor(ndim=2, dim1=input_size).to(device)\n        weight = random_tensor(ndim=2, dim0=bias_size, dim1=input_size).to(device)\n        bias = random_tensor(ndim=1, dim0=bias_size).to(device)\n        y = torch.nn.functional.linear(x, weight, bias)\n        return y\n\n    @autotest(n=5)\n    def test_identity_with_random_data(test_case):\n        m = torch.nn.Identity(\n            x=random().to(int),\n            unused_argument1=random().to(float),\n            unused_argument2=random().to(float),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor().to(device)\n        y = m(x)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_linspace.py",
    "content": "\"\"\"\r\nCopyright 2020 The OneFlow Authors. All rights reserved.\r\n\r\nLicensed under the Apache License, Version 2.0 (the \"License\");\r\nyou may not use this file except in compliance with the License.\r\nYou may obtain a copy of the License at\r\n\r\n    http://www.apache.org/licenses/LICENSE-2.0\r\n\r\nUnless required by applicable law or agreed to in writing, software\r\ndistributed under the License is distributed on an \"AS IS\" BASIS,\r\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\nSee the License for the specific language governing permissions and\r\nlimitations under the License.\r\n\"\"\"\r\n\r\nimport unittest\r\nfrom collections import OrderedDict\r\nimport numpy as np\r\n\r\nfrom oneflow.test_utils.test_util import GenArgList\r\n\r\nimport oneflow as flow\r\nimport oneflow.unittest\r\n\r\nfrom oneflow.test_utils.automated_test_util import *\r\n\r\n\r\n@flow.unittest.skip_unless_1n1d()\r\nclass TestLinspace(flow.unittest.TestCase):\r\n    @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True)\r\n    def test_linspace_int_with_random_data(test_case):\r\n        start = random().to(int)\r\n        end = start + random().to(int)\r\n        steps = random(0, end - start).to(int)\r\n        x = torch.linspace(start=start, end=end, steps=steps)\r\n        device = random_device()\r\n        x.to(device)\r\n        return x\r\n\r\n    @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True)\r\n    def test_linspace_float_with_random_data(test_case):\r\n        start = random()\r\n        end = start + random()\r\n        steps = random(0, end - start).to(int)\r\n        x = torch.linspace(start=start, end=end, steps=steps)\r\n        device = random_device()\r\n        x.to(device)\r\n        return x\r\n\r\n    @autotest(n=5, auto_backward=False)\r\n    def test_linspace_with_scalar_tensor_as_params(test_case):\r\n        start = random_tensor(2, 3, 4, requires_grad=False).mean()\r\n        end = start + random_tensor(2, 3, 4, requires_grad=False).mean()\r\n        steps = random(0, 10).to(int)\r\n        y = torch.linspace(start=start, end=end, steps=steps)\r\n        return y\r\n\r\n    def test_global_naive(test_case):\r\n        placement = flow.placement(\"cpu\", ranks=[0])\r\n        sbp = (flow.sbp.broadcast,)\r\n        x = flow.linspace(start=0, end=10, steps=2, placement=placement, sbp=sbp)\r\n        test_case.assertEqual(x.sbp, sbp)\r\n        test_case.assertEqual(x.placement, placement)\r\n\r\n    def test_linspace_in_transformer_bug(test_case):\r\n        drop_path_rate = 0.1\r\n        depths = [2, 2, 6, 2]\r\n        flow_res = flow.linspace(0, drop_path_rate, sum(depths))\r\n        torch_res = np.array(\r\n            [\r\n                0.0000,\r\n                0.0091,\r\n                0.0182,\r\n                0.0273,\r\n                0.0364,\r\n                0.0455,\r\n                0.0545,\r\n                0.0636,\r\n                0.0727,\r\n                0.0818,\r\n                0.0909,\r\n                0.1000,\r\n            ]\r\n        )\r\n        test_case.assertTrue(np.allclose(flow_res.numpy(), torch_res, atol=1e-4))\r\n        drop_path_rate = 0.2\r\n        depths = [2, 2, 6, 2]\r\n        flow_res = flow.linspace(0, drop_path_rate, sum(depths))\r\n        torch_res = np.array(\r\n            [\r\n                0.0000,\r\n                0.0182,\r\n                0.0364,\r\n                0.0545,\r\n                0.0727,\r\n          
      0.0909,\r\n                0.1091,\r\n                0.1273,\r\n                0.1455,\r\n                0.1636,\r\n                0.1818,\r\n                0.2000,\r\n            ]\r\n        )\r\n        test_case.assertTrue(np.allclose(flow_res.numpy(), torch_res, atol=1e-4))\r\n        drop_path_rate = 0.3\r\n        depths = [2, 2, 18, 2]\r\n        flow_res = flow.linspace(0, drop_path_rate, sum(depths))\r\n        torch_res = np.array(\r\n            [\r\n                0.0000,\r\n                0.0130,\r\n                0.0261,\r\n                0.0391,\r\n                0.0522,\r\n                0.0652,\r\n                0.0783,\r\n                0.0913,\r\n                0.1043,\r\n                0.1174,\r\n                0.1304,\r\n                0.1435,\r\n                0.1565,\r\n                0.1696,\r\n                0.1826,\r\n                0.1957,\r\n                0.2087,\r\n                0.2217,\r\n                0.2348,\r\n                0.2478,\r\n                0.2609,\r\n                0.2739,\r\n                0.2870,\r\n                0.3000,\r\n            ]\r\n        )\r\n        test_case.assertTrue(np.allclose(flow_res.numpy(), torch_res, atol=1e-4))\r\n        drop_path_rate = 0.1\r\n        depths = [2, 2, 18, 2]\r\n        flow_res = flow.linspace(0, drop_path_rate, sum(depths))\r\n        torch_res = np.array(\r\n            [\r\n                0.0000,\r\n                0.0043,\r\n                0.0087,\r\n                0.0130,\r\n                0.0174,\r\n                0.0217,\r\n                0.0261,\r\n                0.0304,\r\n                0.0348,\r\n                0.0391,\r\n                0.0435,\r\n                0.0478,\r\n                0.0522,\r\n                0.0565,\r\n                0.0609,\r\n                0.0652,\r\n                0.0696,\r\n                0.0739,\r\n                0.0783,\r\n                0.0826,\r\n                0.0870,\r\n                0.0913,\r\n                0.0957,\r\n                0.1000,\r\n            ]\r\n        )\r\n        test_case.assertTrue(np.allclose(flow_res.numpy(), torch_res, atol=1e-4))\r\n        drop_path_rate = 0.5\r\n        depths = [2, 2, 18, 2]\r\n        flow_res = flow.linspace(0, drop_path_rate, sum(depths))\r\n        torch_res = np.array(\r\n            [\r\n                0.0000,\r\n                0.0217,\r\n                0.0435,\r\n                0.0652,\r\n                0.0870,\r\n                0.1087,\r\n                0.1304,\r\n                0.1522,\r\n                0.1739,\r\n                0.1957,\r\n                0.2174,\r\n                0.2391,\r\n                0.2609,\r\n                0.2826,\r\n                0.3043,\r\n                0.3261,\r\n                0.3478,\r\n                0.3696,\r\n                0.3913,\r\n                0.4130,\r\n                0.4348,\r\n                0.4565,\r\n                0.4783,\r\n                0.5000,\r\n            ]\r\n        )\r\n        test_case.assertTrue(np.allclose(flow_res.numpy(), torch_res, atol=1e-4))\r\n\r\n    def test_linspace_start_equal_end_bug(test_case):\r\n        flow_res = flow.linspace(0, 0.0, 12).numpy()\r\n        torch_res = np.array(\r\n            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\r\n        )\r\n        test_case.assertTrue(np.allclose(flow_res, torch_res, atol=1e-4))\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    unittest.main()\r\n"
  },
  {
    "path": "python/oneflow/test/modules/test_log1p.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLog1pModule(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_log1p_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return torch.log1p(x)\n\n    @autotest(check_graph=True)\n    def test_log1p_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        return torch.log1p(x)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_logaddexp.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLogAddExpModule(flow.unittest.TestCase):\n    @autotest(n=3, check_graph=True)\n    def test_log_add_exp_against_pytorch(test_case):\n        device = random_device()\n        dim1 = random(1, 5)\n        dim2 = random(1, 5)\n        x = random_tensor(2, dim1, dim2).to(device)\n        y = random_tensor(2, dim1, dim2).to(device)\n        z = torch.logaddexp(x, y)\n        return z\n\n    @autotest(n=3, check_graph=True)\n    def test_log_add_exp_with_0dim_tensor(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = random_tensor(2, random(1, 5), random(1, 5)).to(device)\n        z = torch.logaddexp(x, y)\n        return z\n\n    @autotest(n=3, check_graph=True)\n    def test_tensor_log_add_exp_against_pytorch(test_case):\n        device = random_device()\n        dim1 = random(1, 5)\n        dim2 = random(1, 5)\n        x = random_tensor(2, dim1, dim2).to(device)\n        y = random_tensor(2, dim1, dim2).to(device)\n        z = x.logaddexp(y)\n        return z\n\n    @autotest(n=3, check_graph=True)\n    def test_tensor_log_add_exp_with_0dim_tensor(test_case):\n        device = random_device()\n        y = random_tensor(ndim=0).to(device)\n        x = random_tensor(2, random(1, 5), random(1, 5)).to(device)\n        z = x.logaddexp(y)\n        return z\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_logical_and.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_logical_and(test_case, shape, dtype, device):\n    np_input = np.random.randint(3, size=shape)\n    np_other = np.random.randint(3, size=shape)\n    input = flow.tensor(np_input, dtype=dtype, device=flow.device(device))\n    other = flow.tensor(np_other, dtype=dtype, device=flow.device(device))\n    of_out = flow.logical_and(input, other)\n    np_out = np.logical_and(np_input, np_other)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n    x = torch.ones(3).byte()\n    y = torch.ones(3).byte()\n\n    z = (x & ~y).bool()\n    test_case.assertTrue(np.array_equal(z.numpy(), [False, False, False]))\n\n\ndef _test_tensor_logical_and(test_case, shape, dtype, device):\n    np_input = np.random.randint(3, size=shape)\n    np_other = np.random.randint(3, size=shape)\n    input = flow.tensor(np_input, dtype=dtype, device=flow.device(device))\n    other = flow.tensor(np_other, dtype=dtype, device=flow.device(device))\n    of_out = input.logical_and(other)\n    np_out = np.logical_and(np_input, np_other)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_tensor_scalar_logical_and(test_case, shape, scalar, dtype, device):\n    np_input = np.random.randint(3, size=shape)\n    input = flow.tensor(np_input, dtype=dtype, device=flow.device(device))\n    of_out = input.logical_and(scalar)\n    np_out = np.logical_and(np_input, scalar)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLogicalAndModule(flow.unittest.TestCase):\n    def test_logical_and(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_logical_and,\n            _test_tensor_logical_and,\n        ]\n        arg_dict[\"shape\"] = [(2, 3), (2, 4, 5)]\n        arg_dict[\"dtype\"] = [flow.float32, flow.int32]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_scalar_logical_and(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_tensor_scalar_logical_and]\n        arg_dict[\"shape\"] = [(2, 3), (2, 4, 5)]\n        arg_dict[\"scalar\"] = [1, 0]\n        arg_dict[\"dtype\"] = [flow.float32, flow.int32]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_logical_and_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        
x2 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        y = torch.logical_and(x1, x2)\n        return y\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_logical_and_bool_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(len(shape), *shape, requires_grad=False).to(\n            device=device, dtype=torch.bool\n        )\n        x2 = random_tensor(len(shape), *shape, requires_grad=False).to(\n            device=device, dtype=torch.bool\n        )\n        y = torch.logical_and(x1, x2)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_logical_not.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nimport oneflow as flow\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_logical_not(test_case, shape, device):\n    np_input = np.random.randint(3, size=shape)\n    input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device))\n    of_out = flow.logical_not(input)\n    np_out = np.logical_not(np_input)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_tensor_logical_not(test_case, shape, device):\n    np_input = np.random.randint(3, size=shape)\n    input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device))\n    of_out = input.logical_not()\n    np_out = np.logical_not(np_input)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLogicalNotModule(flow.unittest.TestCase):\n    def test_logical_not(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_logical_not,\n            _test_tensor_logical_not,\n        ]\n        arg_dict[\"shape\"] = [(2, 3), (2, 4, 5)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_logical_not_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        y = torch.logical_not(x1)\n        return y\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_logical_not_bool_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(len(shape), *shape, requires_grad=False).to(\n            device=device, dtype=torch.bool\n        )\n        y = torch.logical_not(x1)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_logical_or.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_logical_or(test_case, shape, device):\n    np_input = np.random.randint(3, size=shape)\n    np_other = np.random.randint(3, size=shape)\n    input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device))\n    other = flow.tensor(np_other, dtype=flow.float32, device=flow.device(device))\n    of_out = flow.logical_or(input, other)\n    np_out = np.logical_or(np_input, np_other)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_tensor_logical_or(test_case, shape, device):\n    np_input = np.random.randint(3, size=shape)\n    np_other = np.random.randint(3, size=shape)\n    input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device))\n    other = flow.tensor(np_other, dtype=flow.float32, device=flow.device(device))\n    of_out = input.logical_or(other)\n    np_out = np.logical_or(np_input, np_other)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_tensor_scalar_logical_or(test_case, shape, scalar, dtype, device):\n    np_input = np.random.randint(3, size=shape)\n    input = flow.tensor(np_input, dtype=dtype, device=flow.device(device))\n    of_out = input.logical_or(scalar)\n    np_out = np.logical_or(np_input, scalar)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLogicalOrModule(flow.unittest.TestCase):\n    def test_logical_or(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_logical_or,\n            _test_tensor_logical_or,\n        ]\n        arg_dict[\"shape\"] = [(2, 3), (2, 4, 5)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_scalar_logical_or(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_tensor_scalar_logical_or]\n        arg_dict[\"shape\"] = [(2, 3), (2, 4, 5)]\n        arg_dict[\"scalar\"] = [1, 0]\n        arg_dict[\"dtype\"] = [flow.float32, flow.int32]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_logical_or_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        x2 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        y = torch.logical_or(x1, x2)\n        return y\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def 
test_logical_or_bool_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(len(shape), *shape, requires_grad=False).to(\n            device=device, dtype=torch.bool\n        )\n        x2 = random_tensor(len(shape), *shape, requires_grad=False).to(\n            device=device, dtype=torch.bool\n        )\n        y = torch.logical_or(x1, x2)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_logical_reduce.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLogicalReduce(flow.unittest.TestCase):\n    @autotest(n=5, auto_backward=False)\n    def test_sum_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device)\n        return torch.sum(x, dim)\n\n    @autotest(n=5, auto_backward=False)\n    def test_mean_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device)\n        return torch.mean(x, dim)\n\n    @autotest(n=5, auto_backward=False)\n    def test_all_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device)\n        return torch.all(x, dim)\n\n    @autotest(n=5, auto_backward=False)\n    def test_any_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device)\n        return torch.any(x, dim)\n\n    @autotest(n=5, auto_backward=False)\n    def test_prod_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device)\n        return torch.prod(x, dim)\n\n    @autotest(n=5, auto_backward=False)\n    def test_sum_keepdim_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device)\n        return torch.sum(x, dim, keepdim=True)\n\n    @autotest(n=5, auto_backward=False)\n    def test_mean_keepdim_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device)\n        return torch.mean(x, dim, keepdim=True)\n\n    @autotest(n=5, auto_backward=False)\n    def test_all_keepdim_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device)\n        return torch.all(x, dim, keepdim=True)\n\n    @autotest(n=5, auto_backward=False)\n    def test_any_keepdim_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device)\n        return torch.any(x, dim, keepdim=True)\n\n    @autotest(n=5, 
auto_backward=False)\n    def test_prod_keepdim_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device)\n        return torch.prod(x, dim, keepdim=True)\n\n    @autotest(n=5, auto_backward=False)\n    def test_scalar_reduce_sum_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device)\n        return torch.sum(x)\n\n    @autotest(n=5, auto_backward=False)\n    def test_scalar_reduce_mean_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device)\n        return torch.mean(x)\n\n    @autotest(n=5, auto_backward=False)\n    def test_scalar_reduce_all_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device)\n        return torch.all(x)\n\n    @autotest(n=5, auto_backward=False)\n    def test_scalar_reduce_any_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device)\n        return torch.any(x)\n\n    @autotest(n=5, auto_backward=False)\n    def test_scalar_reduce_prod_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(device)\n        return torch.prod(x)\n\n    @autotest(n=5, auto_backward=False)\n    def test_all_bool_input_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(\n            device, dtype=torch.bool\n        )\n        return torch.all(x, dim)\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_max_bool_input_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(\n            device, dtype=torch.bool\n        )\n        return torch.max(x, dim)\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_min_bool_input_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(\n            device, dtype=torch.bool\n        )\n        return torch.min(x, dim)\n\n    @autotest(n=5, auto_backward=False)\n    def test_any_bool_input_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(\n            device, dtype=torch.bool\n        )\n        return torch.any(x, dim)\n\n    @autotest(n=5, auto_backward=False)\n    def test_reduce_all_0dim_tensor(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0, requires_grad=False).to(device)\n        return torch.all(x)\n\n    @autotest(n=5, auto_backward=False)\n    def test_reduce_all_0size_tensor(test_case):\n        device = random_device()\n        x = torch.empty(0, 2).to(device)\n        return torch.all(x)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_logical_xor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_logical_xor_int(test_case, shape, device):\n    np_input = np.random.randint(-2, 4, size=shape)\n    np_other = np.random.randint(-2, 4, size=shape)\n    input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device))\n    other = flow.tensor(np_other, dtype=flow.float32, device=flow.device(device))\n    of_out = flow.logical_xor(input, other)\n    np_out = np.logical_xor(np_input, np_other)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_logical_xor_float(test_case, shape, device):\n    np_input = np.random.uniform(low=-5, high=5, size=shape)\n    np_other = np.random.uniform(low=-5, high=5, size=shape)\n    input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device))\n    other = flow.tensor(np_other, dtype=flow.float32, device=flow.device(device))\n    of_out = flow.logical_xor(input, other)\n    np_out = np.logical_xor(np_input, np_other)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_tensor_logical_xor_int(test_case, shape, device):\n    np_input = np.random.randint(-2, 4, size=shape)\n    np_other = np.random.randint(-2, 4, size=shape)\n    input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device))\n    other = flow.tensor(np_other, dtype=flow.float32, device=flow.device(device))\n    of_out = input.logical_xor(other)\n    np_out = np.logical_xor(np_input, np_other)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_tensor_logical_xor_float(test_case, shape, device):\n    np_input = np.random.uniform(low=-5, high=5, size=shape)\n    np_other = np.random.uniform(low=-5, high=5, size=shape)\n    input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device))\n    other = flow.tensor(np_other, dtype=flow.float32, device=flow.device(device))\n    of_out = input.logical_xor(other)\n    np_out = np.logical_xor(np_input, np_other)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_tensor_scalar_logical_xor(test_case, shape, scalar, dtype, device):\n    np_input = np.random.randint(3, size=shape)\n    input = flow.tensor(np_input, dtype=dtype, device=flow.device(device))\n    of_out = input.logical_xor(scalar)\n    np_out = np.logical_xor(np_input, scalar)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLogicalXorModule(flow.unittest.TestCase):\n    def test_logical_xor(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_logical_xor_int,\n            _test_tensor_logical_xor_int,\n            _test_logical_xor_float,\n            
_test_tensor_logical_xor_float,\n        ]\n        arg_dict[\"shape\"] = [(2, 3), (2, 4, 5)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_scalar_logical_xor(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_tensor_scalar_logical_xor]\n        arg_dict[\"shape\"] = [(2, 3), (2, 4, 5)]\n        arg_dict[\"scalar\"] = [1, 0]\n        arg_dict[\"dtype\"] = [flow.float32, flow.int32]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_logical_xor_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        x2 = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        y = torch.logical_xor(x1, x2)\n        return y\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_logical_xor_bool_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x1 = random_tensor(len(shape), *shape, requires_grad=False).to(\n            device=device, dtype=torch.bool\n        )\n        x2 = random_tensor(len(shape), *shape, requires_grad=False).to(\n            device=device, dtype=torch.bool\n        )\n        y = torch.logical_xor(x1, x2)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_logspace.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLogspace(flow.unittest.TestCase):\n    @autotest(n=5, auto_backward=False)\n    def test_logspace_int_with_random_data(test_case):\n        start = random().to(int)\n        end = start + random().to(int)\n        steps = random(0, end - start).to(int)\n        x = torch.logspace(start=start, end=end, steps=steps)\n        device = random_device()\n        x.to(device)\n        return x\n\n    @autotest(n=5, auto_backward=False)\n    def test_logspace_float_with_random_data(test_case):\n        start = random()\n        end = start + random()\n        steps = random(0, end - start).to(int)\n        x = torch.logspace(start=start, end=end, steps=steps)\n        device = random_device()\n        x.to(device)\n        return x\n\n    @autotest(n=5, auto_backward=False)\n    def test_logspace_with_random_base(test_case):\n        start = random()\n        end = start + random()\n        steps = random(0, end - start).to(int)\n        base = random(1, 4).to(float)\n        x = torch.logspace(start=start, end=end, steps=steps, base=base)\n        device = random_device()\n        x.to(device)\n        return x\n\n    def test_global_naive(test_case):\n        placement = flow.placement(\"cpu\", ranks=[0])\n        sbp = (flow.sbp.broadcast,)\n        x = flow.logspace(start=0, end=10, steps=2, placement=placement, sbp=sbp)\n        test_case.assertEqual(x.sbp, sbp)\n        test_case.assertEqual(x.placement, placement)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_logsumexp.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLogSumExpModule(flow.unittest.TestCase):\n    @autotest(n=3, check_graph=True)\n    def test_log_sum_exp_against_pytorch(test_case):\n        device = random_device()\n        x = random_tensor(4, random(0, 5), 2).to(device)\n        y = torch.logsumexp(x, dim=np.random.randint(0, 3))\n        return y\n\n    @unittest.skipIf(True, \"pytorch-1.10.0 dose not support big_value of logsumexp\")\n    @autotest(n=3, auto_backward=False, check_graph=True)\n    def test_log_sum_exp_with_big_value(test_case):\n        device = random_device()\n        x = torch.tensor([100, 200]).to(device)\n        y = torch.logsumexp(x, dim=0)\n        return y\n\n    @autotest(n=3, auto_backward=False, check_graph=True)\n    def test_log_sum_exp_with_0_size_tensor(test_case):\n        device = random_device()\n        x = random_tensor(4, 4, 3, 0, 2).to(device)\n        y = torch.logsumexp(x, dim=np.random.randint(0, 3))\n        return y\n\n    @autotest(n=3, auto_backward=False, check_graph=True)\n    def test_log_sum_exp_with_0dim_tensor(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.logsumexp(x, dim=0)\n        return y\n\n    @autotest(n=3, check_graph=True)\n    def test_tensor_log_sum_exp_against_pytorch(test_case):\n        device = random_device()\n        x = random_tensor(4, random(0, 5), 2).to(device)\n        y = x.logsumexp(dim=np.random.randint(0, 3))\n        return y\n\n    @unittest.skipIf(True, \"pytorch-1.10.0 dose not support big_value of logsumexp\")\n    @autotest(n=3, auto_backward=False, check_graph=True)\n    def test_tensor_log_sum_exp_with_big_value(test_case):\n        device = random_device()\n        x = torch.tensor([100, 200]).to(device)\n        y = x.logsumexp(dim=0)\n        return y\n\n    @autotest(n=3, auto_backward=False, check_graph=True)\n    def test_tensor_log_sum_exp_with_0_size_tensor(test_case):\n        device = random_device()\n        x = random_tensor(4, 4, 3, 0, 2).to(device)\n        y = x.logsumexp(dim=np.random.randint(0, 3))\n        return y\n\n    @autotest(n=3, auto_backward=False, check_graph=True)\n    def test_tensor_log_sum_exp_with_0dim_tensor(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = x.logsumexp(dim=0)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_loss.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nimport torch as torch_original\nfrom packaging import version\n\n\ndef generate_necessity_for_cross_entropy_or_nll_loss(dim: int, prob: bool = False):\n    if dim > 5 or dim < 2:\n        raise ValueError(\"dim should be less than 5 or greater than 1. \")\n    device = random_device()\n    num_classes = random(low=2).to(int)\n    batch_size = random(low=10, high=100).to(int)\n    ignore_index = (\n        random(0, num_classes).to(int) | nothing()\n        if num_classes.value() > 2 and not prob\n        else nothing()\n    )\n    extra_dim = [random().to(int) for _ in range(dim - 2)]\n\n    if prob:\n        target_tensor = random_tensor(\n            dim, batch_size, num_classes, *extra_dim, requires_grad=False,\n        ).to(device)\n    else:\n        target_tensor = random_tensor(\n            dim - 1,\n            batch_size,\n            *extra_dim,\n            low=0,\n            high=num_classes,\n            dtype=int,\n            requires_grad=False,\n        ).to(device)\n    return (\n        random_tensor(dim, batch_size, num_classes, *extra_dim).to(device),\n        target_tensor,\n        random_tensor(1, num_classes, low=0, high=3, requires_grad=False).to(device),\n        ignore_index,\n        device,\n    )\n\n\ndef generate_necessity_for_bce_loss(dim: int):\n    if dim > 5 or dim < 2:\n        raise ValueError(\"dim should be less than 6 or greater than 1. 
\")\n    device = random_device()\n    num_classes = random(low=3).to(int)\n    batch_size = random(low=10, high=100).to(int)\n    extra_dim = [random().to(int) for _ in range(dim - 2)]\n    return (\n        random_tensor(dim, batch_size, num_classes, low=0, high=1, *extra_dim).to(\n            device\n        ),\n        random_tensor(\n            dim,\n            batch_size,\n            num_classes,\n            *extra_dim,\n            low=0,\n            high=num_classes,\n            requires_grad=False,\n        ).to(device),\n        random_tensor(\n            dim, batch_size, num_classes, *extra_dim, low=0, high=3, requires_grad=False\n        ).to(device),\n        random_tensor(\n            1,\n            extra_dim[-1] if dim > 2 else num_classes,\n            low=1,\n            high=3,\n            requires_grad=False,\n        ).to(device),\n        device,\n    )\n\n\ndef _test_cross_entropy_loss(dim: int, prob: bool = False):\n    (\n        x,\n        target,\n        weight,\n        ignore_index,\n        device,\n    ) = generate_necessity_for_cross_entropy_or_nll_loss(dim, prob)\n    m = torch.nn.CrossEntropyLoss(\n        reduction=oneof(\"none\", \"sum\", \"mean\", nothing()),\n        ignore_index=ignore_index,\n        weight=oneof(weight, nothing()),\n        # TODO(wangyi): PyTorch under 1.12 has bug here, which returns wrong result when ignore_index >= 0 and label_smoothing > 0\n        label_smoothing=random(low=0, high=1)\n        if version.parse(torch_original.__version__) >= version.parse(\"1.12.0\")\n        else 0,\n    )\n    m.train(random())\n    m.to(device)\n\n    y = m(x, target)\n    return y\n\n\ndef _test_nn_functional_cross_entropy_loss(dim: int, prob: bool):\n    (\n        x,\n        target,\n        weight,\n        ignore_index,\n        device,\n    ) = generate_necessity_for_cross_entropy_or_nll_loss(dim, prob)\n    y1 = torch.nn.functional.cross_entropy(x, target)\n    y2 = torch.nn.functional.cross_entropy(x, target, weight)\n    return y1 + y2\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCrossEntropyLossModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_cross_entropy_loss_with_random_data_dim_2(test_case):\n        return _test_cross_entropy_loss(2, prob=False)\n\n    @autotest(n=5)\n    def test_cross_entropy_loss_with_random_data_dim_3(test_case):\n        return _test_cross_entropy_loss(3, prob=False)\n\n    @autotest(n=5)\n    def test_cross_entropy_loss_with_random_data_dim_4(test_case):\n        return _test_cross_entropy_loss(4, prob=False)\n\n    @autotest(n=5)\n    def test_cross_entropy_loss_with_random_data_dim_5(test_case):\n        return _test_cross_entropy_loss(5, prob=False)\n\n    @autotest(n=5)\n    def test_nn_functional_cross_entropy_with_random_data_dim(test_case):\n        dim = random(2, 6).to(int).value()\n        return _test_nn_functional_cross_entropy_loss(dim, prob=False)\n\n    @unittest.skip(\"skip for now, becase it failed 3 times in past week\")\n    @autotest(n=5)\n    def test_cross_entropy_prob_loss_with_random_data_dim_2(test_case):\n        return _test_cross_entropy_loss(2, prob=True)\n\n    @autotest(n=5, rtol=1e-3)\n    def test_cross_entropy_prob_loss_with_random_data_dim_3(test_case):\n        return _test_cross_entropy_loss(3, prob=True)\n\n    @unittest.skip(\"skip for now, becase it failed 4 times in past week\")\n    @autotest(n=5)\n    def test_cross_entropy_prob_loss_with_random_data_dim_4(test_case):\n        return _test_cross_entropy_loss(4, prob=True)\n\n    
@unittest.skip(\"skip for now, becase it failed 6 times in past week\")\n    @autotest(n=5)\n    def test_cross_entropy_prob_loss_with_random_data_dim_5(test_case):\n        return _test_cross_entropy_loss(5, prob=True)\n\n    @autotest(n=5)\n    def test_nn_functional_prob_cross_entropy_with_random_data_dim(test_case):\n        dim = random(2, 6).to(int).value()\n        return _test_nn_functional_cross_entropy_loss(dim, prob=True)\n\n\ndef _test_nll_loss(dim=int):\n    (\n        x,\n        target,\n        weight,\n        ignore_index,\n        device,\n    ) = generate_necessity_for_cross_entropy_or_nll_loss(dim)\n    m = torch.nn.NLLLoss(\n        weight=oneof(weight, nothing()),\n        reduction=oneof(\"none\", \"sum\", \"mean\", nothing()),\n        ignore_index=ignore_index,\n    )\n    m.train(random())\n    m.to(device)\n\n    y = m(x, target)\n    return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNLLLossModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_nll_loss_with_random_data_dim_2(test_case):\n        return _test_nll_loss(2)\n\n    @autotest(n=5)\n    def test_nll_loss_with_random_data_dim_3(test_case):\n        return _test_nll_loss(3)\n\n    @autotest(n=5)\n    def test_nll_loss_with_random_data_dim_4(test_case):\n        return _test_nll_loss(4)\n\n    @autotest(n=5)\n    def test_nll_loss_with_random_data_dim_5(test_case):\n        return _test_nll_loss(5)\n\n\ndef _test_bce_loss(dim=int, with_logits: bool = False):\n    x, target, weight, pos_weight, device = generate_necessity_for_bce_loss(dim)\n\n    m = torch.nn.BCELoss(\n        weight=oneof(weight, nothing()),\n        reduction=oneof(\"none\", \"sum\", \"mean\", nothing()),\n    )\n    pos_weight_for_testing_broadcast = random_tensor(\n        1, 1, low=1, high=3, requires_grad=False,\n    ).to(device)\n    if with_logits:\n        m = torch.nn.BCEWithLogitsLoss(\n            weight=oneof(weight, nothing()),\n            pos_weight=oneof(pos_weight, pos_weight_for_testing_broadcast, nothing()),\n            reduction=oneof(\"none\", \"sum\", \"mean\", nothing()),\n        )\n    m.train(random())\n    m.to(device)\n\n    y = m(x, target)\n    return y\n\n\ndef _test_nn_functional_binary_cross_entropy(dim=int):\n    (x, target, weight, pos_weight, device) = generate_necessity_for_bce_loss(dim)\n    y = torch.nn.functional.binary_cross_entropy(\n        x,\n        target,\n        weight=oneof(weight, nothing()),\n        reduction=oneof(\"none\", \"sum\", \"mean\", nothing()),\n        pos_weight=oneof(pos_weight, nothing()),\n    )\n    return y\n\n\ndef _test_nn_functional_binary_cross_entropy_with_logits(dim=int):\n    (x, target, weight, pos_weight, device) = generate_necessity_for_bce_loss(dim)\n    y = torch.nn.functional.binary_cross_entropy_with_logits(\n        x,\n        target,\n        weight=oneof(weight, nothing()),\n        reduction=oneof(\"none\", \"sum\", \"mean\", nothing()),\n    )\n    return y\n\n\ndef _test_nn_functional_binary_cross_entropy_with_logits_different_dtype_float_first(\n    test_case, shape, reduction, device\n):\n    def compare(a, b):\n        test_case.assertTrue(\n            np.allclose(\n                a.detach().cpu().numpy(),\n                b.detach().cpu().numpy(),\n                rtol=1e-5,\n                atol=1e-5,\n            )\n        )\n\n    arr = np.random.randn(*shape)\n\n    flow_pred_mask = flow.Tensor(arr).float().to(device)\n    flow_pred_mask.requires_grad = True\n    flow_gt_mask = 
flow.Tensor(arr).double().to(device)\n    flow_loss = flow.nn.functional.binary_cross_entropy_with_logits(\n        flow_pred_mask, flow_gt_mask, reduction=reduction\n    )\n    flow_loss.sum().backward()\n    torch_pred_mask = torch_original.Tensor(arr).float().to(device)\n    torch_pred_mask.requires_grad = True\n    torch_gt_mask = torch_original.Tensor(arr).double().to(device)\n    torch_loss = torch_original.nn.functional.binary_cross_entropy_with_logits(\n        torch_pred_mask, torch_gt_mask, reduction=reduction\n    )\n    torch_loss.sum().backward()\n    compare(flow_loss, torch_loss)\n    compare(flow_pred_mask.grad.data, torch_pred_mask.grad.data)\n\n\ndef _test_nn_functional_binary_cross_entropy_with_logits_different_dtype_double_first(\n    test_case, shape, reduction, device\n):\n    def compare(a, b):\n        test_case.assertTrue(\n            np.allclose(\n                a.detach().cpu().numpy(),\n                b.detach().cpu().numpy(),\n                rtol=1e-5,\n                atol=1e-5,\n            )\n        )\n\n    arr = np.random.randn(*shape)\n\n    flow_pred_mask = flow.Tensor(arr).double().to(device)\n    flow_pred_mask.requires_grad = True\n    flow_gt_mask = flow.Tensor(arr).float().to(device)\n    flow_loss = flow.nn.functional.binary_cross_entropy_with_logits(\n        flow_pred_mask, flow_gt_mask, reduction=reduction\n    )\n    flow_loss.sum().backward()\n    torch_pred_mask = torch_original.Tensor(arr).double().to(device)\n    torch_pred_mask.requires_grad = True\n    torch_gt_mask = torch_original.Tensor(arr).float().to(device)\n    torch_loss = torch_original.nn.functional.binary_cross_entropy_with_logits(\n        torch_pred_mask, torch_gt_mask, reduction=reduction\n    )\n    torch_loss.sum().backward()\n    compare(flow_loss, torch_loss)\n    compare(flow_pred_mask.grad.data, torch_pred_mask.grad.data)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBCELossModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_bce_loss_with_random_data_dim_2(test_case):\n        return _test_bce_loss(2)\n\n    @autotest(n=5)\n    def test_bce_loss_with_random_data_dim_3(test_case):\n        return _test_bce_loss(3)\n\n    @autotest(n=5)\n    def test_bce_loss_with_random_data_dim_4(test_case):\n        return _test_bce_loss(4)\n\n    @autotest(n=5)\n    def test_bce_loss_with_random_data_dim_5(test_case):\n        return _test_bce_loss(5)\n\n    @autotest(n=5)\n    def test_nn_functional_binary_cross_entropy(test_case):\n        dim = random(2, 6).to(int).value()\n        return _test_nn_functional_binary_cross_entropy(dim)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBCEWithLogitsLossModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_bce_with_logits_loss_with_random_data_dim_2(test_case):\n        return _test_bce_loss(2, True)\n\n    @autotest(n=5)\n    def test_bce_with_logits_loss_with_random_data_dim_3(test_case):\n        return _test_bce_loss(3, True)\n\n    @autotest(n=5)\n    def test_bce_with_logits_loss_with_random_data_dim_4(test_case):\n        return _test_bce_loss(4, True)\n\n    @autotest(n=5)\n    def test_bce_with_logits_loss_with_random_data_dim_5(test_case):\n        return _test_bce_loss(5, True)\n\n    @autotest(n=5)\n    def test_nn_functional_binary_cross_entropy_with_logits(test_case):\n        dim = random(2, 6).to(int).value()\n        return _test_nn_functional_binary_cross_entropy_with_logits(dim)\n\n    @autotest(n=5)\n    def 
test_nn_functional_binary_cross_entropy_with_logits_different_dtype(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"fun\"] = [\n            _test_nn_functional_binary_cross_entropy_with_logits_different_dtype_float_first,\n            _test_nn_functional_binary_cross_entropy_with_logits_different_dtype_double_first,\n        ]\n        arg_dict[\"shape\"] = [(24, 16, 80), (42, 160), (4, 54, 32, 56)]\n        arg_dict[\"reduction\"] = [\"sum\", \"mean\", \"none\"]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestL1LossModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_l1_loss_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n\n        x = random_tensor(len(shape), *shape).to(device)\n        target = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n\n        m = torch.nn.L1Loss(reduction=oneof(\"none\", \"sum\", \"mean\", nothing()))\n        m.train(random())\n        m.to(device)\n\n        y = m(x, target)\n        return y\n\n    @autotest(n=5)\n    def _test_nn_functional_l1_loss(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n\n        x = random_tensor(len(shape), *shape).to(device)\n        target = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n\n        y = torch.nn.functional.l1_loss(\n            x, target, reduction=oneof(\"none\", \"sum\", \"mean\", nothing())\n        )\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSmoothL1LossModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_smooth_l1_loss_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n\n        x = random_tensor(len(shape), *shape).to(device)\n        target = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n\n        m = torch.nn.SmoothL1Loss(\n            reduction=oneof(\"none\", \"sum\", \"mean\", nothing()), beta=oneof(0, 0.5, 1)\n        )\n        m.train(random())\n        m.to(device)\n\n        y = m(x, target)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMSELossModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_mse_loss_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n\n        x = random_tensor(len(shape), *shape).to(device)\n        target = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n\n        m = torch.nn.MSELoss(reduction=oneof(\"none\", \"sum\", \"mean\", nothing()))\n        m.train(random())\n        m.to(device)\n\n        y = m(x, target)\n        return y\n\n    @autotest(n=5)\n    def _test_nn_functional_mse_loss(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n\n        x = random_tensor(len(shape), *shape).to(device)\n        target = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n\n        y = torch.nn.functional.mse_loss(\n            x, target, reduction=oneof(\"none\", \"sum\", \"mean\", nothing())\n        )\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestKLDivLossModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_kldiv_loss_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n\n        x = 
random_tensor(len(shape), low=0, *shape).to(device)\n        target = random_tensor(len(shape), low=0, *shape, requires_grad=False).to(\n            device\n        )\n\n        m = torch.nn.KLDivLoss(\n            reduction=oneof(\"none\", \"sum\", \"mean\", \"batchmean\", nothing()),\n            log_target=oneof(True, False, nothing()),\n        )\n        m.train(random())\n        m.to(device)\n\n        y = m(x, target)\n        return y\n\n    @autotest(n=5)\n    def test_nn_functional_kl_div(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x = random_tensor(len(shape), low=0, *shape).to(device)\n        target = random_tensor(len(shape), low=0, *shape, requires_grad=False).to(\n            device\n        )\n        y = torch.nn.functional.kl_div(\n            x,\n            target,\n            reduction=oneof(\"none\", \"sum\", \"mean\", \"batchmean\", nothing()),\n            log_target=oneof(True, False, nothing()),\n        )\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMarginRankingLossModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_margin_ranking_loss_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n\n        x1 = random_tensor(len(shape), *shape).to(device)\n        x2 = random_tensor(len(shape), *shape).to(device)\n        target = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n\n        m = torch.nn.MarginRankingLoss(\n            margin=oneof(0.0, 0.3, 10),\n            reduction=oneof(\"none\", \"sum\", \"mean\", nothing()),\n        )\n        m.train(random())\n        m.to(device)\n\n        y = m(x1, x2, target)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_loss_global.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef get_sbp(device: str):\n    return flow.placement.all(device), flow.sbp.split(0)\n\n\nshapes = {2: (128, 8), 3: (16, 8, 64), 4: (16, 8, 32, 32), 5: (16, 8, 16, 16, 16)}\n\n\ndef compare_loss(device_type, dim, reduction, cls, data_generator):\n    x, y, x1, y1 = data_generator(dim, device_type, *get_sbp(device_type))\n    reduce_loss_func = cls(reduction=reduction).to(device_type)\n    none_loss_func = cls(reduction=\"none\").to(device_type)\n\n    loss_mean = reduce_loss_func(x, y)\n    loss_none = (\n        flow.mean(none_loss_func(x1, y1))\n        if reduction == \"mean\"\n        else flow.sum(none_loss_func(x1, y1))\n    )\n\n    loss_mean.backward()\n    loss_none.backward()\n\n    assert np.allclose(\n        loss_none.to_local().numpy(),\n        loss_mean.to_local().numpy(),\n        rtol=1e-05,\n        atol=1e-05,\n    )\n    assert np.allclose(loss_none.numpy(), loss_mean.numpy(), rtol=1e-05, atol=1e-05,)\n    assert np.allclose(\n        x.grad.to_local().numpy(), x1.grad.to_local().numpy(), rtol=1e-05, atol=1e-05,\n    )\n\n\ndef generate_necessity_default(dim: int, device: str, placement, sbp):\n    shape = shapes[dim]\n    x_np = np.random.uniform(0, 1, shape)\n    y_np = np.random.uniform(0, 1, shape)\n\n    def f(x, requires_grad):\n        t = flow.tensor(x, device=device, requires_grad=requires_grad).to_global(\n            placement=placement, sbp=[sbp]\n        )\n        if requires_grad:\n            t.retain_grad()\n        return t\n\n    return f(x_np, True), f(y_np, False), f(x_np, True), f(y_np, False)\n\n\ndef generate_necessity_for_cross_entropy_or_nll_loss(\n    dim: int, device: str, placement, sbp\n):\n    shape = shapes[dim]\n    y_shape = (shape[0],) if dim == 2 else (shape[0], *shape[2:])\n    x_np = np.random.uniform(0, 1, shape)\n    y_np = np.random.randint(0, shape[1], y_shape)\n\n    def f(x, requires_grad):\n        t = flow.tensor(x, device=device, requires_grad=requires_grad).to_global(\n            placement=placement, sbp=[sbp]\n        )\n        if requires_grad:\n            t.retain_grad()\n        return t\n\n    return f(x_np, True), f(y_np, False), f(x_np, True), f(y_np, False)\n\n\nclass TestBCELossOrWithLogitsConsistent(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n2d()\n    def test_bce_loss(testcase):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_type\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"dim\"] = [2, 3, 4, 5]\n        arg_dict[\"reduction\"] = [\"sum\", \"mean\"]\n        arg_dict[\"cls\"] = [flow.nn.BCELoss, flow.nn.BCEWithLogitsLoss]\n        arg_dict[\"data_generator\"] = [generate_necessity_default]\n        for arg in GenArgList(arg_dict):\n            compare_loss(*arg)\n\n\nclass 
TestCrossEntropyOrNllLossConsistent(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n2d()\n    def test_cross_entropy_loss_or_nll_loss(testcase):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_type\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"dim\"] = [2, 3, 4, 5]\n        arg_dict[\"reduction\"] = [\"sum\", \"mean\"]\n        arg_dict[\"cls\"] = [flow.nn.CrossEntropyLoss, flow.nn.NLLLoss]\n        arg_dict[\"data_generator\"] = [generate_necessity_for_cross_entropy_or_nll_loss]\n        for arg in GenArgList(arg_dict):\n            compare_loss(*arg)\n\n\nclass TestKLDivLossConsistent(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n2d()\n    def test_kl_div_loss(testcase):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_type\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"dim\"] = [2, 3, 4, 5]\n        arg_dict[\"reduction\"] = [\"sum\", \"mean\"]\n        arg_dict[\"cls\"] = [flow.nn.KLDivLoss]\n        arg_dict[\"data_generator\"] = [generate_necessity_default]\n        for arg in GenArgList(arg_dict):\n            compare_loss(*arg)\n\n\nclass TestSmoothL1LossConsistent(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n2d()\n    def test_smooth_l1_loss(testcase):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_type\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"dim\"] = [2, 3, 4, 5]\n        arg_dict[\"reduction\"] = [\"sum\", \"mean\"]\n        arg_dict[\"cls\"] = [flow.nn.SmoothL1Loss]\n        arg_dict[\"data_generator\"] = [generate_necessity_default]\n        for arg in GenArgList(arg_dict):\n            compare_loss(*arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_lr_scheduler.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport math\nimport random\nimport tempfile\nimport unittest\nimport numpy as np\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\nimport torch\nfrom oneflow.nn.parameter import Parameter\n\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\ndef compare_with_torch_reduce_lr(\n    test_case, mode, factor, patience, threshold, threshold_mode, cooldown, min_lr, eps,\n):\n    optimizer_flow = flow.optim.SGD(\n        [{\"params\": [Parameter(flow.Tensor([1.0]))]},],\n        lr=TestLrScheduler.base_lr,\n        momentum=0.9,\n    )\n\n    optimizer_torch = torch.optim.SGD(\n        [{\"params\": [torch.nn.Parameter(torch.Tensor([1.0]))]},],\n        lr=TestLrScheduler.base_lr,\n        momentum=0.9,\n    )\n\n    scheduler_flow = flow.optim.lr_scheduler.ReduceLROnPlateau(\n        optimizer_flow,\n        mode,\n        factor,\n        patience,\n        threshold,\n        threshold_mode,\n        cooldown,\n        min_lr,\n        eps,\n    )\n    scheduler_troch = torch.optim.lr_scheduler.ReduceLROnPlateau(\n        optimizer_torch,\n        mode,\n        factor,\n        patience,\n        threshold,\n        threshold_mode,\n        cooldown,\n        min_lr,\n        eps,\n    )\n    val_loss = 0.1\n    for epoch in range(15):\n        val_loss += (random.random() - 0.5) / 10\n        scheduler_flow.step(val_loss)\n        scheduler_troch.step(val_loss)\n        for (lr1, lr2) in zip(scheduler_flow._last_lr, scheduler_troch._last_lr):\n            test_case.assertAlmostEqual(lr1, lr2, places=5)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLrScheduler(flow.unittest.TestCase):\n    base_lr = 1.0\n\n    def test_cosine_decay_lr(test_case):\n        optimizer = flow.optim.SGD(\n            [{\"params\": [Parameter(flow.Tensor([1.0]))]}], lr=TestLrScheduler.base_lr\n        )\n\n        def cosine_decay_lr_step(base_lr, current_step, decay_steps, alpha):\n            if current_step < decay_steps:\n                cos_decay = 0.5 * (1 + math.cos(math.pi * current_step / decay_steps))\n                decay_factor = (1 - alpha) * cos_decay + alpha\n                return base_lr * decay_factor\n            else:\n                return base_lr * alpha\n\n        alpha = 0.5\n        decay_steps = 10\n        cosine_decay_lr = flow.optim.lr_scheduler.CosineDecayLR(\n            optimizer, decay_steps=decay_steps, alpha=alpha\n        )\n        for i in range(1, 21):\n            cosine_decay_lr.step()\n            new_lr = cosine_decay_lr_step(\n                TestLrScheduler.base_lr, i, decay_steps, alpha\n            )\n            test_case.assertAlmostEqual(\n                cosine_decay_lr.get_last_lr()[0], new_lr, places=4\n            )\n\n    def test_cosine_annealing_lr(test_case):\n        optimizer = flow.optim.SGD(\n            [{\"params\": [Parameter(flow.Tensor([1.0]))]}], lr=TestLrScheduler.base_lr\n     
   )\n\n        def cosine_annealing_lr_step(base_lr, current_step, last_lr, T_max, eta_min):\n            if (current_step - 1 - T_max) % (2 * T_max) == 0:\n                return (\n                    last_lr\n                    + (TestLrScheduler.base_lr - eta_min)\n                    * (1 - math.cos(math.pi / T_max))\n                    / 2\n                )\n            else:\n                return (1 + math.cos(math.pi * current_step / T_max)) / (\n                    1 + math.cos(math.pi * (current_step - 1) / T_max)\n                ) * (last_lr - eta_min) + eta_min\n\n        T_max = 20\n        eta_min = 0.5\n        cosine_annealing_lr = flow.optim.lr_scheduler.CosineAnnealingLR(\n            optimizer, T_max=T_max, eta_min=eta_min\n        )\n        numpy_last_lr = TestLrScheduler.base_lr\n        for i in range(1, 101):\n            cosine_annealing_lr.step()\n            numpy_last_lr = cosine_annealing_lr_step(\n                TestLrScheduler.base_lr, i, numpy_last_lr, T_max, eta_min\n            )\n            test_case.assertAlmostEqual(\n                cosine_annealing_lr.get_last_lr()[0], numpy_last_lr, places=4\n            )\n\n    def test_step_lr(test_case):\n        optimizer = flow.optim.SGD(\n            [{\"params\": [Parameter(flow.Tensor([1.0]))]}], lr=TestLrScheduler.base_lr\n        )\n\n        def step_lr_step(base_lr, current_step, step_size, gamma):\n            return base_lr * gamma ** (current_step // step_size)\n\n        gamma = 0.1\n        step_size = 5\n        step_lr = flow.optim.lr_scheduler.StepLR(\n            optimizer, step_size=step_size, gamma=gamma\n        )\n        for i in range(1, 21):\n            step_lr.step()\n            new_lr = step_lr_step(TestLrScheduler.base_lr, i, step_size, gamma)\n            test_case.assertAlmostEqual(step_lr.get_last_lr()[0], new_lr, places=5)\n\n    def test_multistep_lr(test_case):\n        optimizer = flow.optim.SGD(\n            [{\"params\": [Parameter(flow.Tensor([1.0]))]}], lr=TestLrScheduler.base_lr\n        )\n\n        def multistep_lr_step(base_lr, current_step, milestones, gamma):\n            count = 0\n            for step in milestones:\n                if current_step >= step:\n                    count += 1\n            return base_lr * gamma ** count\n\n        gamma = 0.1\n        milestones = [5, 11, 15]\n        multistep_lr = flow.optim.lr_scheduler.MultiStepLR(\n            optimizer, milestones=milestones, gamma=gamma\n        )\n        for i in range(1, 18):\n            multistep_lr.step()\n            new_lr = multistep_lr_step(TestLrScheduler.base_lr, i, milestones, gamma)\n            test_case.assertAlmostEqual(multistep_lr.get_last_lr()[0], new_lr, places=5)\n\n    def test_exponential_lr(test_case):\n        optimizer = flow.optim.SGD(\n            [{\"params\": [Parameter(flow.Tensor([1.0]))]}], lr=TestLrScheduler.base_lr\n        )\n\n        def exponential_lr_step(base_lr, current_step, gamma):\n            return base_lr * gamma ** current_step\n\n        gamma = 0.1\n        exponential_lr = flow.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)\n        for i in range(1, 21):\n            exponential_lr.step()\n            new_lr = exponential_lr_step(TestLrScheduler.base_lr, i, gamma)\n            test_case.assertAlmostEqual(\n                exponential_lr.get_last_lr()[0], new_lr, places=5\n            )\n\n    def test_lambda_lr(test_case):\n        optimizer = flow.optim.SGD(\n            [\n                {\"params\": 
[Parameter(flow.Tensor([1.0]))]},\n                {\"params\": [Parameter(flow.Tensor([1.0]))]},\n            ],\n            lr=TestLrScheduler.base_lr,\n        )\n        lambdas = [lambda step: step // 30, lambda step: 0.95 * step]\n\n        def lambda_lr_step(base_lrs, current_step):\n            return [\n                base_lr * lmbda(current_step)\n                for (base_lr, lmbda) in zip(base_lrs, lambdas)\n            ]\n\n        lambda_lr = flow.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambdas)\n        for i in range(1, 21):\n            lambda_lr.step()\n            new_lrs = lambda_lr_step(lambda_lr.base_lrs, i)\n            for (lr1, lr2) in zip(lambda_lr.get_last_lr(), new_lrs):\n                test_case.assertAlmostEqual(lr1, lr2, places=5)\n\n    def test_polynomial_lr(test_case):\n        optimizer = flow.optim.SGD(\n            [{\"params\": [Parameter(flow.Tensor([1.0]))]}], lr=TestLrScheduler.base_lr\n        )\n\n        def polynomial_lr_step(base_lr, end_lr, step, decay_steps, power, cycle):\n            if cycle:\n                if step == 0:\n                    step = 1\n                decay_steps = decay_steps * math.ceil(step / decay_steps)\n            step = min(step, decay_steps)\n            return (base_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr\n\n        decay_steps = 100\n        end_learning_rate = 1e-5\n        power = 2\n        cycle = True\n        poly_decay_lr = flow.optim.lr_scheduler.PolynomialLR(\n            optimizer, decay_steps, end_learning_rate, power, cycle\n        )\n        # step(0) will be invoked in LRScheduler.__init__\n        new_lr = polynomial_lr_step(\n            TestLrScheduler.base_lr, end_learning_rate, 0, decay_steps, power, cycle\n        )\n        test_case.assertAlmostEqual(poly_decay_lr.get_last_lr()[0], new_lr, places=4)\n        for i in range(1, 21):\n            poly_decay_lr.step()\n            new_lr = polynomial_lr_step(\n                TestLrScheduler.base_lr, end_learning_rate, i, decay_steps, power, cycle\n            )\n            test_case.assertAlmostEqual(\n                poly_decay_lr.get_last_lr()[0], new_lr, places=4\n            )\n\n        # run again with cycle disabled to cover the other branch of polynomial_lr_step\n        cycle = False\n        poly_decay_lr = flow.optim.lr_scheduler.PolynomialLR(\n            optimizer, decay_steps, end_learning_rate, power, cycle\n        )\n        for i in range(1, 21):\n            poly_decay_lr.step()\n            new_lr = polynomial_lr_step(\n                TestLrScheduler.base_lr, end_learning_rate, i, decay_steps, power, cycle\n            )\n            test_case.assertAlmostEqual(\n                poly_decay_lr.get_last_lr()[0], new_lr, places=4\n            )\n\n    def test_reduce_lr_on_plateau(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"mode\"] = [\"min\", \"max\"]\n        arg_dict[\"factor\"] = [0.1, 0.3]\n        arg_dict[\"patience\"] = [2, 5]\n        arg_dict[\"threshold\"] = [1e-3, 1e-5]\n        arg_dict[\"threshold_mode\"] = [\"rel\", \"abs\"]\n        arg_dict[\"cooldown\"] = [0, 1]\n        arg_dict[\"min_lr\"] = [0, 1e-3]\n        arg_dict[\"eps\"] = [1e-5, 1e-8]\n        for arg in GenArgDict(arg_dict):\n            compare_with_torch_reduce_lr(test_case, **arg)\n\n    def test_warmup_scheduler_save_and_load(test_case):\n        param = flow.nn.Parameter(flow.ones(3, 4))\n\n        optimizer = flow.optim.SGD([param])\n        cosine_scheduler = flow.optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)\n        lr_scheduler = flow.optim.lr_scheduler.WarmUpLR(\n 
           cosine_scheduler, warmup_factor=0.1, warmup_iters=5, warmup_method=\"linear\",\n        )\n        for _ in range(random.randint(1, 10)):\n            lr_scheduler.step()\n        # save\n        with tempfile.NamedTemporaryFile() as f:\n            flow.save(lr_scheduler.state_dict(), f.name)\n            state_dict = flow.load(f.name)\n\n        # load\n        param2 = flow.nn.Parameter(flow.ones(3, 4))\n        optimizer2 = flow.optim.SGD([param2])\n        cosine_scheduler2 = flow.optim.lr_scheduler.CosineAnnealingLR(optimizer2, 50)\n        lr_scheduler2 = flow.optim.lr_scheduler.WarmUpLR(\n            cosine_scheduler2,\n            warmup_factor=0.5,\n            warmup_iters=10,\n            warmup_method=\"linear\",\n        )\n        lr_scheduler2.load_state_dict(state_dict)\n\n        # compare warm up scheduler\n        for attr in [\"warmup_iters\", \"warmup_factor\", \"warmup_method\", \"last_step\"]:\n            test_case.assertEqual(\n                getattr(lr_scheduler, attr), getattr(lr_scheduler2, attr)\n            )\n        # compare cosine_annealing_lr\n        for attr in [\"T_max\", \"eta_min\", \"last_step\"]:\n            test_case.assertEqual(\n                getattr(cosine_scheduler, attr), getattr(cosine_scheduler2, attr)\n            )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass WarmupLRTestCase(flow.unittest.TestCase):\n    def test_only_warmup(test_case):\n        param = flow.nn.Parameter(flow.ones(3, 4))\n        optimizer = flow.optim.SGD([param], lr=0.001)\n        warmup_lr = flow.optim.lr_scheduler.WarmupLR(\n            optimizer, warmup_factor=0.5, warmup_iters=5, warmup_method=\"linear\"\n        )\n        expected_lrs = [\n            0.0005,\n            0.0006,\n            0.0007,\n            0.0008,\n            0.0009,\n            0.001,\n            0.001,\n            0.001,\n            0.001,\n            0.001,\n        ]\n        lrs = [warmup_lr.get_last_lr()[0]]\n        for _ in range(len(expected_lrs)):\n            optimizer.step()\n            warmup_lr.step()\n            lrs.append(warmup_lr.get_last_lr()[0])\n\n        lrs = lrs[:-1]\n\n        test_case.assertTrue(\n            np.allclose(lrs, expected_lrs),\n            f\"\\nexpected_lrs: {expected_lrs}\\nvs.\\ncalculated lrs: {lrs}\",\n        )\n\n    def test_warmup_iters_0_exp_lr(test_case):\n        lr = 0.1\n        gamma = 0.9\n        param = flow.nn.Parameter(flow.ones(3, 4))\n        optimizer = flow.optim.SGD([param], lr)\n        exp_lr = flow.optim.lr_scheduler.ExponentialLR(optimizer, gamma)\n        warmup_lr = flow.optim.lr_scheduler.WarmupLR(\n            exp_lr, warmup_factor=0.5, warmup_iters=0, warmup_method=\"linear\"\n        )\n        iters = 10\n        lrs = [warmup_lr.get_last_lr()[0]]\n        for _ in range(iters):\n            warmup_lr.step()\n            lrs.append(warmup_lr.get_last_lr()[0])\n\n        lrs = lrs[:-1]\n        expected_lrs = [lr * pow(gamma, i) for i in range(iters)]\n        test_case.assertTrue(\n            np.allclose(lrs, expected_lrs),\n            f\"\\nexpected_lrs: {expected_lrs}\\nvs.\\ncalculated lrs: {lrs}\",\n        )\n\n    def test_linear_warmup_exp_lr(test_case):\n        lr = 0.1\n        gamma = 0.9\n        param = flow.nn.Parameter(flow.ones(3, 4))\n        optimizer = flow.optim.SGD([param], lr)\n        exp_lr = flow.optim.lr_scheduler.ExponentialLR(optimizer, gamma)\n        warmup_lr = flow.optim.lr_scheduler.WarmupLR(\n            exp_lr, warmup_factor=0.5, warmup_iters=5, 
warmup_method=\"linear\"\n        )\n        expected_lrs = [\n            0.05,\n            0.0518098,\n            0.0536196,\n            0.0554294,\n            0.0572392,\n            0.059049,\n            0.0531441,\n            0.04782969,\n            0.043046721,\n            0.0387420489,\n        ]\n\n        lrs = [warmup_lr.get_last_lr()[0]]\n        for _ in range(len(expected_lrs)):\n            warmup_lr.step()\n            lrs.append(warmup_lr.get_last_lr()[0])\n\n        lrs = lrs[:-1]\n        test_case.assertTrue(\n            np.allclose(lrs, expected_lrs),\n            f\"\\nexpected_lrs: {expected_lrs}\\nvs.\\ncalculated lrs: {lrs}\",\n        )\n\n    def test_linear_warmup_prefix_exp_lr(test_case):\n        lr = 0.1\n        gamma = 0.9\n        param = flow.nn.Parameter(flow.ones(3, 4))\n        optimizer = flow.optim.SGD([param], lr)\n        exp_lr = flow.optim.lr_scheduler.ExponentialLR(optimizer, gamma)\n        warmup_lr = flow.optim.lr_scheduler.WarmupLR(\n            exp_lr,\n            warmup_factor=0.5,\n            warmup_iters=5,\n            warmup_method=\"linear\",\n            warmup_prefix=True,\n        )\n        expected_lrs = [\n            0.05,\n            0.06,\n            0.07,\n            0.08,\n            0.09,\n            0.1,\n            0.09,\n            0.081,\n            0.0729,\n            0.06561,\n        ]\n\n        lrs = [warmup_lr.get_last_lr()[0]]\n        for _ in range(len(expected_lrs)):\n            warmup_lr.step()\n            lrs.append(warmup_lr.get_last_lr()[0])\n\n        lrs = lrs[:-1]\n        test_case.assertTrue(\n            np.allclose(lrs, expected_lrs),\n            f\"\\nexpected_lrs: {expected_lrs}\\nvs.\\ncalculated lrs: {lrs}\",\n        )\n\n    def test_constant_warmup_cosine_annealing(test_case):\n        lr = 0.1\n        param = flow.nn.Parameter(flow.ones(3, 4))\n        optimizer = flow.optim.SGD([param], lr)\n        cos_annl_lr = flow.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)\n        warmup_lr = flow.optim.lr_scheduler.WarmupLR(\n            cos_annl_lr, warmup_factor=0.5, warmup_iters=5, warmup_method=\"constant\",\n        )\n\n        expected_lrs = [\n            0.05,\n            0.05,\n            0.05,\n            0.05,\n            0.05,\n            0.05,\n            0.03454915028125264,\n            0.020610737385376353,\n            0.009549150281252635,\n            0.002447174185242324,\n            0.0,\n            0.0024471741852423235,\n            0.009549150281252666,\n            0.020610737385376433,\n            0.034549150281252786,\n            0.050000000000000225,\n            0.06545084971874766,\n            0.079389262614624,\n            0.09045084971874778,\n            0.09755282581475812,\n            0.1,\n        ]\n\n        lrs = [warmup_lr.get_last_lr()[0]]\n        for _ in range(len(expected_lrs)):\n            warmup_lr.step()\n            lrs.append(warmup_lr.get_last_lr()[0])\n\n        lrs = lrs[:-1]\n        test_case.assertTrue(\n            np.allclose(lrs, expected_lrs),\n            f\"\\nexpected_lrs: {expected_lrs}\\nvs.\\ncalculated lrs: {lrs}\",\n        )\n\n    def test_linear_warmup_cosine_annealing(test_case):\n        param = flow.nn.Parameter(flow.ones(3, 4))\n        optimizer = flow.optim.SGD([param], lr=0.1)\n        cos_annl_lr = flow.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)\n        warmup_lr = flow.optim.lr_scheduler.WarmupLR(\n            cos_annl_lr, warmup_factor=0.1, 
warmup_iters=5, warmup_method=\"linear\",\n        )\n\n        expected_lrs = [\n            0.01,\n            0.025071068,\n            0.040142136,\n            0.055213203,\n            0.070284271,\n            0.085355339,\n            0.079389263,\n            0.072699525,\n            0.06545085,\n            0.057821723,\n            0.05,\n            0.042178277,\n            0.03454915,\n            0.027300475,\n            0.020610737,\n            0.014644661,\n            0.00954915,\n            0.005449674,\n            0.002447174,\n            0.000615583,\n        ]\n\n        lrs = [warmup_lr.get_last_lr()[0]]\n        for _ in range(len(expected_lrs)):\n            warmup_lr.step()\n            lrs.append(warmup_lr.get_last_lr()[0])\n\n        lrs = lrs[:-1]\n\n        test_case.assertTrue(\n            np.allclose(lrs, expected_lrs),\n            f\"\\nexpected_lrs: {expected_lrs}\\nvs.\\ncalculated lrs: {lrs}\",\n        )\n\n    def test_linear_warmup_prefix_cosine_annealing(test_case):\n        param = flow.nn.Parameter(flow.ones(3, 4))\n        optimizer = flow.optim.SGD([param], lr=0.1)\n        cos_annl_lr = flow.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)\n        warmup_lr = flow.optim.lr_scheduler.WarmupLR(\n            cos_annl_lr,\n            warmup_factor=0.1,\n            warmup_iters=5,\n            warmup_method=\"linear\",\n            warmup_prefix=True,\n        )\n\n        expected_lrs = [\n            0.01,\n            0.028,\n            0.046,\n            0.064,\n            0.082,\n            0.1,\n            0.099384417,\n            0.097552826,\n            0.094550326,\n            0.09045085,\n            0.085355339,\n            0.079389263,\n            0.072699525,\n            0.06545085,\n            0.057821723,\n            0.05,\n            0.042178277,\n            0.03454915,\n            0.027300475,\n            0.020610737,\n        ]\n\n        lrs = [warmup_lr.get_last_lr()[0]]\n        for _ in range(len(expected_lrs)):\n            warmup_lr.step()\n            lrs.append(warmup_lr.get_last_lr()[0])\n\n        lrs = lrs[:-1]\n\n        test_case.assertTrue(\n            np.allclose(lrs, expected_lrs),\n            f\"\\nexpected_lrs: {expected_lrs}\\nvs.\\ncalculated lrs: {lrs}\",\n        )\n\n    def test_linear_warmup_multistep_lr(test_case):\n        param = flow.nn.Parameter(flow.ones(3, 4))\n        optimizer = flow.optim.SGD([param], lr=0.001)\n        multistep_lr = flow.optim.lr_scheduler.MultiStepLR(optimizer, [10])\n        warmup_lr = flow.optim.lr_scheduler.WarmupLR(\n            multistep_lr, warmup_factor=0.5, warmup_iters=5, warmup_method=\"linear\",\n        )\n        expected_lrs = [\n            0.0005,\n            0.0006,\n            0.0007,\n            0.0008,\n            0.0009,\n            0.001,\n            0.001,\n            0.001,\n            0.001,\n            0.001,\n            0.0001,\n            0.0001,\n            0.0001,\n            0.0001,\n            0.0001,\n            0.0001,\n            0.0001,\n            0.0001,\n            0.0001,\n            0.0001,\n        ]\n        lrs = [warmup_lr.get_last_lr()[0]]\n        for _ in range(len(expected_lrs)):\n            optimizer.step()\n            warmup_lr.step()\n            lrs.append(warmup_lr.get_last_lr()[0])\n\n        lrs = lrs[:-1]\n\n        test_case.assertTrue(\n            np.allclose(lrs, expected_lrs),\n            f\"\\nexpected_lrs: {expected_lrs}\\nvs.\\ncalculated lrs: {lrs}\",\n 
       )\n\n    def test_linear_warmup_prefix_multistep_lr(test_case):\n        param = flow.nn.Parameter(flow.ones(3, 4))\n        optimizer = flow.optim.SGD([param], lr=0.1)\n        multistep_lr = flow.optim.lr_scheduler.MultiStepLR(\n            optimizer, milestones=[5, 10]\n        )\n        warmup_lr = flow.optim.lr_scheduler.WarmupLR(\n            multistep_lr,\n            warmup_factor=0.1,\n            warmup_iters=5,\n            warmup_method=\"linear\",\n            warmup_prefix=True,\n        )\n\n        expected_lrs = [\n            0.01,\n            0.028,\n            0.046,\n            0.064,\n            0.082,\n            0.1,\n            0.1,\n            0.1,\n            0.1,\n            0.1,\n            0.01,\n            0.01,\n            0.01,\n            0.01,\n            0.01,\n            0.001,\n            0.001,\n            0.001,\n            0.001,\n            0.001,\n        ]\n\n        lrs = [warmup_lr.get_last_lr()[0]]\n        for _ in range(len(expected_lrs)):\n            warmup_lr.step()\n            lrs.append(warmup_lr.get_last_lr()[0])\n\n        lrs = lrs[:-1]\n\n        test_case.assertTrue(\n            np.allclose(lrs, expected_lrs),\n            f\"\\nexpected_lrs: {expected_lrs}\\nvs.\\ncalculated lrs: {lrs}\",\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass ConstantLRTestCase(flow.unittest.TestCase):\n    def test(test_case):\n        param = flow.nn.Parameter(flow.ones(3, 4))\n        optimizer = flow.optim.SGD([param], lr=0.01)\n        constant_lr = flow.optim.lr_scheduler.ConstantLR(optimizer, 0.1, 10)\n        expected_lrs = [\n            0.001,\n            0.001,\n            0.001,\n            0.001,\n            0.001,\n            0.001,\n            0.001,\n            0.001,\n            0.001,\n            0.001,\n            0.01,\n            0.01,\n            0.01,\n            0.01,\n            0.01,\n            0.01,\n            0.01,\n            0.01,\n            0.01,\n            0.01,\n        ]\n        lrs = [constant_lr.get_last_lr()[0]]\n        for _ in range(len(expected_lrs)):\n            constant_lr.step()\n            lrs.append(constant_lr.get_last_lr()[0])\n\n        lrs = lrs[:-1]\n        test_case.assertTrue(\n            np.allclose(lrs, expected_lrs),\n            f\"\\nexpected_lrs: {expected_lrs}\\nvs.\\ncalculated lrs: {lrs}\",\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass LinearLRTestCase(flow.unittest.TestCase):\n    def test(test_case):\n        param = flow.nn.Parameter(flow.ones(3, 4))\n        optimizer = flow.optim.SGD([param], lr=0.1)\n        linear_lr = flow.optim.lr_scheduler.LinearLR(optimizer, 0.1, 1, 10)\n        expected_lrs = [\n            0.01,\n            0.019,\n            0.028,\n            0.037,\n            0.046,\n            0.055,\n            0.064,\n            0.073,\n            0.082,\n            0.091,\n            0.1,\n            0.1,\n            0.1,\n            0.1,\n            0.1,\n            0.1,\n            0.1,\n            0.1,\n            0.1,\n            0.1,\n            0.1,\n        ]\n        lrs = [linear_lr.get_last_lr()[0]]\n        for _ in range(len(expected_lrs)):\n            linear_lr.step()\n            lrs.append(linear_lr.get_last_lr()[0])\n\n        lrs = lrs[:-1]\n        test_case.assertTrue(\n            np.allclose(lrs, expected_lrs),\n            f\"\\nexpected_lrs: {expected_lrs}\\nvs.\\ncalculated lrs: {lrs}\",\n        )\n\n    def test_end_factor(test_case):\n        param = 
flow.nn.Parameter(flow.ones(3, 4))\n        optimizer = flow.optim.SGD([param], lr=0.1)\n        linear_lr = flow.optim.lr_scheduler.LinearLR(optimizer, 0.1, 0.9, 10)\n        expected_lrs = [\n            0.01,\n            0.018,\n            0.026,\n            0.034,\n            0.042,\n            0.05,\n            0.058,\n            0.066,\n            0.074,\n            0.082,\n            0.09,\n            0.09,\n            0.09,\n            0.09,\n            0.09,\n            0.09,\n            0.09,\n            0.09,\n            0.09,\n            0.09,\n            0.09,\n        ]\n        lrs = [linear_lr.get_last_lr()[0]]\n        for _ in range(len(expected_lrs)):\n            linear_lr.step()\n            lrs.append(linear_lr.get_last_lr()[0])\n\n        lrs = lrs[:-1]\n        test_case.assertTrue(\n            np.allclose(lrs, expected_lrs),\n            f\"\\nexpected_lrs: {expected_lrs}\\nvs.\\ncalculated lrs: {lrs}\",\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass ChainedSchedulerTestCase(flow.unittest.TestCase):\n    def test(test_case):\n        param = flow.nn.Parameter(flow.ones(3, 4))\n        opt = flow.optim.SGD([param], lr=1)\n        s1 = flow.optim.lr_scheduler.ConstantLR(opt, factor=0.1, total_iters=3)\n        s2 = flow.optim.lr_scheduler.ExponentialLR(opt, gamma=0.9)\n        scheduler = flow.optim.lr_scheduler.ChainedScheduler([s1, s2])\n\n        expected_lrs = [0.1, 0.09, 0.081, 0.729, 0.6561, 0.59049]\n        lrs = [scheduler.get_last_lr()[0]]\n        for _ in range(len(expected_lrs)):\n            scheduler.step()\n            lrs.append(scheduler.get_last_lr()[0])\n\n        lrs = lrs[: len(expected_lrs)]\n        test_case.assertTrue(\n            np.allclose(lrs, expected_lrs),\n            f\"\\nexpected_lrs: {expected_lrs}\\nvs.\\ncalculated lrs: {lrs}\",\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass CosineAnnealingWarmRestartsTestCase(flow.unittest.TestCase):\n    def test_mult_1(test_case):\n        param = flow.nn.Parameter(flow.ones(3, 4))\n        optimizer = flow.optim.SGD([param], lr=0.1)\n        cosa_r_lr = flow.optim.lr_scheduler.CosineAnnealingWarmRestarts(\n            optimizer, T_0=10, eta_min=0.01,\n        )\n        # fmt: off\n        expected_lrs = [0.1, 0.09779754323328192, 0.09140576474687263, 0.08145033635316129, 0.06890576474687264, 0.05500000000000001, 0.04109423525312737, 0.028549663646838717, 0.01859423525312737, 0.012202456766718092, 0.1, 0.09779754323328192, 0.09140576474687263, 0.08145033635316129, 0.06890576474687264, 0.05500000000000001, 0.04109423525312737, 0.028549663646838717, 0.01859423525312737, 0.012202456766718092, 0.1, 0.09779754323328192, 0.09140576474687263, 0.08145033635316129, 0.06890576474687264, 0.05500000000000001, 0.04109423525312737, 0.028549663646838717, 0.01859423525312737, 0.012202456766718092, 0.1, 0.09779754323328192, 0.09140576474687263, 0.08145033635316129, 0.06890576474687264, 0.05500000000000001, 0.04109423525312737, 0.028549663646838717, 0.01859423525312737, 0.012202456766718092, 0.1, 0.09779754323328192, 0.09140576474687263, 0.08145033635316129, 0.06890576474687264, 0.05500000000000001, 0.04109423525312737, 0.028549663646838717, 0.01859423525312737, 0.012202456766718092]\n        # fmt: on\n        lrs = [cosa_r_lr.get_last_lr()[0]]\n        for _ in range(len(expected_lrs)):\n            cosa_r_lr.step()\n            lrs.append(cosa_r_lr.get_last_lr()[0])\n\n        lrs = lrs[: len(expected_lrs)]\n        test_case.assertTrue(\n            
np.allclose(lrs, expected_lrs),\n            f\"\\nexpected_lrs: {expected_lrs}\\nvs.\\ncalculated lrs: {lrs}\",\n        )\n\n    def test_mult_2(test_case):\n        param = flow.nn.Parameter(flow.ones(3, 4))\n        optimizer = flow.optim.SGD([param], lr=0.1)\n        cosa_r_lr = flow.optim.lr_scheduler.CosineAnnealingWarmRestarts(\n            optimizer, T_0=10, T_mult=2, eta_min=0.01,\n        )\n        # fmt: off\n        expected_lrs = [0.1, 0.09779754323328192, 0.09140576474687263, 0.08145033635316129, 0.06890576474687264, 0.05500000000000001, 0.04109423525312737, 0.028549663646838717, 0.01859423525312737, 0.012202456766718092, 0.1, 0.0994459753267812, 0.09779754323328192, 0.09509529358847656, 0.09140576474687263, 0.08681980515339464, 0.08145033635316129, 0.07542957248827961, 0.06890576474687264, 0.0620395509268104, 0.05500000000000001, 0.04796044907318963, 0.04109423525312737, 0.034570427511720396, 0.028549663646838717, 0.023180194846605363, 0.01859423525312737, 0.014904706411523451, 0.012202456766718092, 0.010554024673218806, 0.1, 0.09986128001799077, 0.0994459753267812, 0.09875664641789544, 0.09779754323328192, 0.0965745789630079, 0.09509529358847656, 0.09336880739593416, 0.09140576474687263, 0.0892182684520014, 0.08681980515339464, 0.08422516217485827, 0.08145033635316129, 0.0785124354122177, 0.07542957248827961, 0.07222075445642905, 0.06890576474687264, 0.06550504137351576, 0.0620395509268104, 0.05853065930775304, 0.05500000000000001, 0.05146934069224699, 0.04796044907318963, 0.04449495862648427, 0.04109423525312737, 0.03777924554357097, 0.034570427511720396, 0.031487564587782305, 0.028549663646838717, 0.02577483782514174, 0.023180194846605363, 0.02078173154799861, 0.01859423525312737, 0.016631192604065852, 0.014904706411523451, 0.013425421036992097, 0.012202456766718092, 0.011243353582104555, 0.010554024673218806, 0.010138719982009242]\n        # fmt: on\n        lrs = [cosa_r_lr.get_last_lr()[0]]\n        for _ in range(len(expected_lrs)):\n            cosa_r_lr.step()\n            lrs.append(cosa_r_lr.get_last_lr()[0])\n\n        lrs = lrs[: len(expected_lrs)]\n        test_case.assertTrue(\n            np.allclose(lrs, expected_lrs),\n            f\"\\nexpected_lrs: {expected_lrs}\\nvs.\\ncalculated lrs: {lrs}\",\n        )\n\n    def test_mult_2_decay_half_limit_2(test_case):\n        param = flow.nn.Parameter(flow.ones(3, 4))\n        optimizer = flow.optim.SGD([param], lr=0.1)\n        cosa_r_lr = flow.optim.lr_scheduler.CosineAnnealingWarmRestarts(\n            optimizer, T_0=10, T_mult=2, decay_rate=0.5, restart_limit=2, eta_min=0.01,\n        )\n        # fmt: off\n        expected_lrs = [0.1, 0.09779754323328192, 0.09140576474687263, 0.08145033635316129, 0.06890576474687264, 0.05500000000000001, 0.04109423525312737, 0.028549663646838717, 0.01859423525312737, 0.012202456766718092, 0.05, 0.04975376681190276, 0.04902113032590308, 0.04782013048376736, 0.04618033988749895, 0.044142135623730955, 0.04175570504584947, 0.03907980999479094, 0.03618033988749895, 0.03312868930080462, 0.03, 0.02687131069919539, 0.023819660112501053, 0.020920190005209068, 0.018244294954150538, 0.01585786437626905, 0.013819660112501053, 0.012179869516232645, 0.01097886967409693, 0.010246233188097247, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]\n        # fmt: on\n        lrs = [cosa_r_lr.get_last_lr()[0]]\n        for _ in range(len(expected_lrs)):\n            cosa_r_lr.step()\n            
lrs.append(cosa_r_lr.get_last_lr()[0])\n\n        lrs = lrs[: len(expected_lrs)]\n        test_case.assertTrue(\n            np.allclose(lrs, expected_lrs),\n            f\"\\nexpected_lrs: {expected_lrs}\\nvs.\\ncalculated lrs: {lrs}\",\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_masked_fill.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMaskedFill(flow.unittest.TestCase):\n    @autotest(n=3)\n    def test_flow_masked_fill_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        input = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        mask = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        value = random().to(float)\n        return input.masked_fill(mask > 0.5, value)\n\n    @autotest(n=3)\n    def test_flow_masked_fill_with_0dim_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=0).to(device)\n        mask = random_tensor(ndim=0).to(device)\n        value = random().to(float)\n        return input.masked_fill(mask > 0, value)\n\n    @autotest(n=3)\n    def test_flow_masked_fill_broadcast_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        input = random_tensor(ndim=2, dim0=1, dim1=k2).to(device)\n        mask = random_tensor(ndim=2, dim0=k1, dim1=1).to(device)\n        value = random().to(float)\n        return input.masked_fill(mask > 0.5, value)\n\n    @autotest(n=3)\n    def test_flow_masked_fill_int_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        input = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        mask = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        value = random().to(int)\n        return input.masked_fill(mask > 0.5, value)\n\n    @autotest(auto_backward=False, n=3)\n    def test_flow_masked_fill_bool_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        input = random_tensor(ndim=2, dim0=k1, dim1=k2).to(\n            device=device, dtype=torch.bool\n        )\n        mask = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        value = random().to(bool)\n        return input.masked_fill(mask > 0.5, value)\n\n    @autotest(auto_backward=False, n=3)\n    def test_flow_masked_fill_inplace_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=2, dim0=10, dim1=20).to(device).clone()\n        mask = random_tensor(ndim=2, dim0=10, dim1=20).to(device)\n        value = random().to(float)\n        input.masked_fill_(mask > 0.5, value)\n        return input\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_masked_select.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_masked_select(test_case, device):\n    x = flow.tensor(\n        np.array([[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    mask = x.gt(0.05)\n    of_out = flow.masked_select(x, mask)\n    np_out = np.array([0.3139, 0.3898])\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = np.array([[0, 1], [1, 0], [0, 0]])\n    test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\ndef _test_masked_select_broadcast(test_case, device):\n    x = flow.tensor(\n        np.array([[[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    mask = flow.tensor(\n        np.array(\n            [\n                [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]],\n                [[1.0, 0], [1.0, 1.0], [0.0, 1.0]],\n                [[1.0, 1.0], [0.0, 1.0], [1.0, 1.0]],\n            ]\n        ),\n        dtype=flow.int8,\n        device=flow.device(device),\n    )\n    of_out = flow.masked_select(x, mask)\n    np_out = [\n        -0.462,\n        0.3898,\n        -0.7197,\n        -0.1657,\n        -0.462,\n        0.3898,\n        -0.7197,\n        -0.1657,\n        -0.462,\n        0.3139,\n        -0.7197,\n        0.0478,\n        -0.1657,\n    ]\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = [[[3.0, 1.0], [2.0, 3.0], [1.0, 3.0]]]\n    test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\ndef _test_masked_select_input_zero(test_case, device):\n    x = flow.tensor(\n        [[26, 14, 18, 14, 5, 18, 5, 18, 4, 18, 15, 18, 22, 18, 0]],\n        device=flow.device(device),\n        dtype=flow.int64,\n    )\n    f_mask = flow.tensor(\n        [\n            [\n                True,\n                True,\n                True,\n                True,\n                True,\n                True,\n                True,\n                True,\n                True,\n                True,\n                True,\n                True,\n                True,\n                True,\n                True,\n            ]\n        ],\n        device=flow.device(device),\n        dtype=flow.bool,\n    )\n    y = x.masked_select(f_mask)\n    test_case.assertTrue(\n        np.allclose(\n            y.numpy(),\n            [26, 14, 18, 14, 5, 18, 5, 18, 4, 18, 15, 18, 22, 18, 0],\n            1e-05,\n            1e-05,\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass 
TestMaskedSelect(flow.unittest.TestCase):\n    def test_masked_select(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_masked_select,\n            _test_masked_select_broadcast,\n            _test_masked_select_input_zero,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_masked_select_broadcast(test_case):\n        x = flow.ones(2, 3, 3)\n        mask = flow.triu(flow.ones(3, 3), 1)\n        flow_res = flow.masked_select(x, mask)\n        np_res = [1, 1, 1, 1, 1, 1]\n        test_case.assertTrue(np.allclose(flow_res.numpy(), np_res, 1e-05, 1e-05))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_math_op_higher_derivative.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_math_op_grad_grad_impl(test_case, op_name):\n    x = random_tensor(ndim=2, low=-2, high=2).requires_grad_(True)\n    y = eval(f\"torch.{op_name}\")(x)\n    np_arr = np.random.rand(*x.oneflow.shape)\n    init_grad = torch.tensor(np_arr).requires_grad_()\n\n    x_grad = torch.autograd.grad(y, x, init_grad, retain_graph=True, create_graph=True)[\n        0\n    ]\n    test_case.assertTrue(\n        np.allclose(\n            x_grad.pytorch.detach().cpu().numpy(),\n            x_grad.oneflow.detach().numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n            equal_nan=True,\n        )\n    )\n\n    x_grad_grad = torch.autograd.grad(x_grad, x, init_grad, retain_graph=True)[0]\n    test_case.assertTrue(\n        np.allclose(\n            x_grad_grad.pytorch.detach().cpu().numpy(),\n            x_grad_grad.oneflow.detach().numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n            equal_nan=True,\n        )\n    )\n\n    init_grad_grad = torch.tensor(np_arr).requires_grad_()\n    dgrad = torch.autograd.grad(x_grad, init_grad, init_grad_grad, retain_graph=True)[0]\n    test_case.assertTrue(\n        np.allclose(\n            dgrad.pytorch.detach().cpu().numpy(),\n            dgrad.oneflow.detach().numpy(),\n            atol=1e-4,\n            rtol=1e-4,\n            equal_nan=True,\n        )\n    )\n\n\nclass TestMathOpHigherDerivative(flow.unittest.TestCase):\n    def test_sin_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"sin\")\n\n    def test_cos_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"cos\")\n\n    def test_tan_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"tan\")\n\n    def test_sinh_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"sinh\")\n\n    def test_cosh_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"cosh\")\n\n    def test_tanh_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"tanh\")\n\n    def test_asin_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"asin\")\n\n    def test_acos_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"acos\")\n\n    def test_atan_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"atan\")\n\n    def test_asinh_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"asinh\")\n\n    def test_acosh_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"acosh\")\n\n    def test_atanh_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"atanh\")\n\n    def test_erf_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"erf\")\n\n    def test_erfc_grad_grad(test_case):\n     
   _test_math_op_grad_grad_impl(test_case, \"erfc\")\n\n    def test_exp_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"exp\")\n\n    def test_exp2_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"exp2\")\n\n    def test_expm1_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"expm1\")\n\n    def test_log_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"log\")\n\n    def test_logsigmoid_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"nn.functional.logsigmoid\")\n\n    def test_log2_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"log2\")\n\n    def test_log1p_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"log1p\")\n\n    def test_reciprocal_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"reciprocal\")\n\n    def test_rsqrt_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"rsqrt\")\n\n    def test_sqrt_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"sqrt\")\n\n    def test_square_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"square\")\n\n    def test_sigmoid_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"sigmoid\")\n\n    def test_abs_grad_grad(test_case):\n        _test_math_op_grad_grad_impl(test_case, \"abs\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_math_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nfrom oneflow.test_utils.test_util import (\n    GenArgList,\n    type_name_to_flow_type,\n    type_name_to_np_type,\n)\n\nimport torch as torch_original\nfrom packaging import version\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSinh(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_sinh_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.sinh(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSin(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_sin_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.sin()\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestInplaceSin(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_inplace_sin_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x + 1  # transform to non-leaf tensor\n        y.sin_()\n        return y\n\n\ndef _test_cos(test_case, shape, device):\n    input = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.cos(input)\n    np_out = np.cos(input.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_cos_backward(test_case, shape, device):\n    x = flow.tensor(\n        np.random.randn(*shape),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    y = flow.cos(x)\n    z = y.sum()\n    z.backward()\n    np_grad = -np.sin(x.numpy())\n    test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCos(flow.unittest.TestCase):\n    def test_cos(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_cos, _test_cos_backward]\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLogModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_log_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return torch.log(x)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSqrt(flow.unittest.TestCase):\n    @autotest(n=10, include_complex=True)\n    def test_sqrt_flow_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        z = torch.sqrt(x)\n        return z\n\n    @autotest(n=10, include_complex=True)\n    
def test_sqrt_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        z = x.sqrt()\n        return z\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestExp(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_exp_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.exp(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestExp2(flow.unittest.TestCase):\n    @autotest(n=5, auto_backward=\"auto\")\n    def test_flow_exp2_with_random_data(test_case):\n        device = random_device()\n        x_dtype = random_dtype([\"arithmetic\"])\n        x = random_tensor().to(device).to(x_dtype)\n        y = torch.exp2(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRsqrt(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_rsqrt_flow_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        z = torch.rsqrt(x)\n        return z\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSquare(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_square_flow_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        z = torch.square(x)\n        return z\n\n    @autotest(n=5)\n    def test_square_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        z = x.square()\n        return z\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestPow(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_pow_float_scalar_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = random().to(float)\n        return torch.pow(x, y)\n\n    @autotest(n=5)\n    def test_pow_int_scalar_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = random().to(int)\n        return torch.pow(x, y)\n\n    @autotest(n=10)\n    def test_reverse_pow_int_scalar_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = random().to(int)\n        return torch.pow(y, x)\n\n    @autotest(n=10)\n    def test_symbolic_reverse_pow_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = random().to(int)\n        return y ** x\n\n    @autotest(n=5)\n    def test_pow_elementwise_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim1=2).to(device)\n        y = random_tensor(ndim=2, dim1=2).to(device)\n        return torch.pow(x, y)\n\n    @autotest(n=5)\n    def test_pow_broadcast_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim1=2).to(device)\n        y = random_tensor(ndim=2, dim1=1).to(device)\n        return torch.pow(x, y)\n\n    @autotest(n=5)\n    def test_pow_broadcast_with_random_data_reverse(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim1=1).to(device)\n        y = random_tensor(ndim=2, dim1=2).to(device)\n        return torch.pow(x, y)\n\n    @autotest(n=5)\n    def test_scalar_pow_with_random_devices(test_case):\n        x1_device = random_device()\n        x2_device = random_device()\n        x1 = random_tensor(2, 2, 3).to(x1_device).mean()\n        x2 = random_tensor(2, 2, 3).to(x2_device)\n        y = torch.pow(x1, x2)\n        return y\n
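\n# `scalar ** tensor` dispatches to the tensor's __rpow__ and behaves like\n# torch.pow(scalar, tensor), the path exercised by the reverse-pow tests above.\n# A minimal illustrative sketch (values made up, eager mode assumed):\n#\n#   x = flow.tensor([1.0, 2.0, 3.0])\n#   2 ** x  # same as flow.pow(2, x) -> tensor([2., 4., 8.])\n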
\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAsin(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_asin_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-0.5, high=0.5).to(device)\n        y = torch.asin(x)\n        return y\n\n    @autotest(n=5)\n    def test_flow_arcsin_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-0.5, high=0.5).to(device)\n        y = torch.arcsin(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAsinh(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_asinh_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.asinh(x)\n        return y\n\n    @autotest(n=5)\n    def test_flow_arcsinh_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.arcsinh(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTan(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_tan_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.tan(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAtan(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_atan_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.atan(x)\n        return y\n\n    @autotest(n=5)\n    def test_flow_arctan_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.arctan(x)\n        return y\n\n    @autotest(n=5)\n    def test_flow_atan2_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim1=3).to(device)\n        y = random_tensor(ndim=2, dim1=3).to(device)\n        z = torch.atan2(x, y)\n        return z\n\n    @autotest(n=5)\n    def test_flow_atan2_with_1elem_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=1).to(device)\n        y = random_tensor(ndim=3, dim1=random(1, 6).to(int)).to(device)\n        z = torch.atan2(x, y)\n        return z\n\n    @autotest(n=5)\n    def test_flow_atanh_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-0.5, high=0.5).to(device)\n        y = torch.atanh(x)\n        return y\n\n    @autotest(n=5)\n    def test_flow_arctanh_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-0.5, high=0.5).to(device)\n        y = torch.arctanh(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTopk(flow.unittest.TestCase):\n    @autotest(auto_backward=False)\n    def test_flow_topk_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dim1=8, dim2=9, dim3=10).to(device)\n        y = torch.topk(\n            x,\n            random(low=1, high=8).to(int),\n            dim=random(low=1, high=4).to(int),\n            largest=random_bool(),\n            sorted=constant(True),\n        )\n        return y[0], y[1]\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTopkReturnValues(flow.unittest.TestCase):\n    @autotest(auto_backward=False)\n    def test_flow_topk_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dim1=8, dim2=9, dim3=10).to(device)\n        result = torch.topk(\n
            x,\n            random(low=1, high=8).to(int),\n            dim=random(low=1, high=4).to(int),\n            largest=random_bool(),\n            sorted=constant(True),\n        )\n        return result.values, result.indices\n\n\n# NOTE: this class must not reuse the name TestPow; a duplicate class name would\n# shadow the earlier definition and silently skip its tests.\n@flow.unittest.skip_unless_1n1d()\nclass TestPowScalar(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_pow_scalar_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = random().to(float)\n        return torch.pow(x, y)\n\n    @autotest(n=5)\n    def test_pow_elementwise_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim1=2).to(device)\n        y = random_tensor(ndim=2, dim1=2).to(device)\n        return torch.pow(x, y)\n\n    @unittest.skip(\"broadcast is not supported currently\")\n    @autotest(n=5)\n    def test_pow_broadcast_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim1=2).to(device)\n        y = random_tensor(ndim=2, dim1=1).to(device)\n        return torch.pow(x, y)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestArccos(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_arccos_flow_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-1, high=1).to(device)\n        y = torch.arccos(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAcos(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_acos_flow_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-1, high=1).to(device)\n        y = torch.acos(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestArccosh(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_arccosh_flow_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=2, high=3).to(device)\n        y = torch.arccosh(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAcosh(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_acosh_flow_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=2, high=3).to(device)\n        y = torch.acosh(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAtan2(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_atan2_with_random_data(test_case):\n        device = random_device()\n        x1 = random_tensor(ndim=1, dim0=1).to(device)\n        x2 = random_tensor(ndim=1, dim0=1).to(device)\n        y = torch.atan2(x1, x2)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMinimum(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_elementwise_minimum_with_random_data(test_case):\n        device = random_device()\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        x = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        y = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        return torch.minimum(x, y)\n\n    @autotest(n=5)\n    def test_flow_broadcast_minimum_with_random_data(test_case):\n        device = random_device()\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        k3 = random(2, 6)\n        x = random_tensor(ndim=3, dim0=k1, dim1=1, dim2=1).to(device)\n        y = random_tensor(ndim=3, dim0=1, dim1=k2, dim2=k3).to(device)\n        return torch.minimum(x, y)\n\n
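# torch.minimum / torch.maximum broadcast their operands with numpy-style rules\n# before the elementwise comparison. A minimal illustrative sketch (values made\n# up):\n#\n#   a = flow.tensor([[1.0], [4.0]])  # shape (2, 1)\n#   b = flow.tensor([2.0, 3.0])      # shape (2,)\n#   flow.minimum(a, b)               # shape (2, 2): [[1., 1.], [2., 3.]]\n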
\n@flow.unittest.skip_unless_1n1d()\nclass TestMaximum(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_elementwise_maximum_with_random_data(test_case):\n        device = random_device()\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        x = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        y = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        return torch.maximum(x, y)\n\n    @autotest(n=5)\n    def test_flow_broadcast_maximum_with_random_data(test_case):\n        device = random_device()\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        k3 = random(2, 6)\n        x = random_tensor(ndim=3, dim0=k1, dim1=1, dim2=1).to(device)\n        y = random_tensor(ndim=3, dim0=1, dim1=k2, dim2=k3).to(device)\n        return torch.maximum(x, y)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFloorDiv(flow.unittest.TestCase):\n    @autotest(auto_backward=False)\n    def test_elementwise_floordiv_random_data(test_case):\n        device = random_device()\n        # The random values are narrowed to positive numbers because of an error in pytorch 1.10.0.\n        # Please remove the value range restriction after updating the pytorch version of CI to 1.13.\n        x = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3, low=0, high=10).to(\n            device\n        )\n        y = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3, low=1, high=10).to(\n            device\n        )\n\n        return torch.floor_divide(x, y)\n\n    @autotest(auto_backward=False)\n    def test_tensor_floordiv_scalar_random_data(test_case):\n        device = random_device()\n        # The random values are narrowed to positive numbers because of an error in pytorch 1.10.0.\n        # Please remove the value range restriction after updating the pytorch version of CI to 1.13.\n        x = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3, low=0, high=10).to(\n            device\n        )\n        y = random().to(int)\n        return torch.floor_divide(x, y)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFmod(flow.unittest.TestCase):\n    # other.grad in torch.fmod(input, other) was not implemented before pytorch 1.11.0\n    grad_implemented = version.parse(torch_original.__version__) >= version.parse(\n        \"1.11.0\"\n    )\n\n    @autotest(auto_backward=grad_implemented)\n    def test_elementwise_fmod_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3).to(device)\n        y = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3).to(device)\n\n        return torch.fmod(x, y)\n\n    @autotest(n=5, auto_backward=grad_implemented)\n    def test_flow_broadcast_fmod_with_random_data(test_case):\n        device = random_device()\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        k3 = random(2, 6)\n        x = random_tensor(ndim=3, dim0=k1, dim1=1, dim2=1).to(device)\n        y = random_tensor(ndim=3, dim0=1, dim1=k2, dim2=k3).to(device)\n        return torch.fmod(x, y)\n\n    @autotest(auto_backward=grad_implemented)\n    def test_tensor_fmod_scalar_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3).to(device)\n        y = random().to(int)\n        return torch.fmod(x, y)\n\n\n# NOTE: distinct class name again; a third TestPow definition would shadow the classes above.\n@flow.unittest.skip_unless_1n1d()\nclass TestPow4dData(flow.unittest.TestCase):\n    @autotest(auto_backward=False)\n    def test_elementwise_pow_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3).to(device)\n        y = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, 
dim3=3).to(device)\n\n        return torch.pow(x, y)\n\n    @autotest(n=5)\n    def test_flow_broadcast_pow_with_random_data(test_case):\n        device = random_device()\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        k3 = random(2, 6)\n        x = random_tensor(ndim=3, dim0=k1, dim1=1, dim2=1).to(device)\n        y = random_tensor(ndim=3, dim0=1, dim1=k2, dim2=k3).to(device)\n        return torch.pow(x, y)\n\n    @autotest(auto_backward=False)\n    def test_tensor_pow_scalar_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dim0=2, dim1=4, dim2=8, dim3=3).to(device)\n        y = random().to(int)\n        return torch.pow(x, y)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAbsModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_abs_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return torch.abs(x)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCoshModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_cosh_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return torch.cosh(x)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLgammaModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_lgamma_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return torch.lgamma(x)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLog2Module(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_log2_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return torch.log2(x)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLog10Module(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_log10_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return torch.log10(x)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestDigammaModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_digamma_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return torch.digamma(x)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_matmul.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport torch as torch_original\nimport oneflow as flow\nimport oneflow.unittest\nimport torch as torch_original\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModule(flow.unittest.TestCase):\n    @autotest(check_graph=True, rtol=1e-2, atol=1e-3, include_complex=True)\n    def test_flow_matmul_with_random_data(test_case):\n        device = random_device()\n        k = random(1, 6)\n        x = random_tensor(ndim=2, dim1=k).to(device)\n        y = random_tensor(ndim=2, dim0=k).to(device)\n        z = torch.matmul(x, y)\n        return z\n\n    @autotest(check_graph=True, rtol=1e-2, atol=1e-4)\n    def test_flow_tensor_matmul_with_random_data_allow_tf32(test_case):\n        flow.backends.cuda.matmul.allow_tf32 = True\n        torch_original.backends.cuda.matmul.allow_tf32 = True\n        device = random_device()\n        k = random(1, 6)\n        x = random_tensor(ndim=2, dim1=k).to(device)\n        y = random_tensor(ndim=2, dim0=k).to(device)\n        ret = x.matmul(y)\n        flow.backends.cuda.matmul.allow_tf32 = False\n        torch_original.backends.cuda.matmul.allow_tf32 = False\n        return ret\n\n    @autotest(check_graph=True, rtol=1e-2, atol=1e-4)\n    def test_flow_tensor_matmul_with_random_data(test_case):\n        device = random_device()\n        k = random(1, 6)\n        x = random_tensor(ndim=2, dim1=k).to(device)\n        y = random_tensor(ndim=2, dim0=k).to(device)\n        return x.matmul(y)\n\n    @autotest(n=5, check_graph=False)\n    def test_flow_tensor_matmul_with_random_int_data(test_case):\n        x = np.random.randint(10, 21, size=5)\n        y = np.random.randint(1, 14, size=(5, 4))\n        torch_x = torch.from_numpy(x).to(torch.int)\n        torch_y = torch.from_numpy(y).to(torch.int)\n        torch_output_numpy = torch_x.matmul(torch_y).numpy()\n        flow_x = flow.tensor(x).to(flow.int)\n        flow_y = flow.tensor(y).to(flow.int)\n        flow_output_numpy = flow_x.matmul(flow_y).numpy()\n        test_case.assertTrue(\n            np.allclose(flow_output_numpy, torch_output_numpy, 1e-05, 1e-05)\n        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @autotest(n=5, check_graph=False)\n    def test_flow_tensor_matmul_with_random_fp16_data(test_case):\n        x = np.random.rand(3, 5)\n        y = np.random.rand(5, 4)\n        torch_x = torch.from_numpy(x).to(device=gpu_device(), dtype=torch.float16)\n        torch_y = torch.from_numpy(y).to(device=gpu_device(), dtype=torch.float16)\n        torch_output_numpy = torch_x.matmul(torch_y).cpu().numpy()\n        flow_x = flow.tensor(x).to(device=\"cuda\", dtype=flow.float16)\n        flow_y = flow.tensor(y).to(device=\"cuda\", dtype=flow.float16)\n        flow_output_numpy = flow_x.matmul(flow_y).cpu().numpy()\n        test_case.assertTrue(\n       
     np.allclose(flow_output_numpy, torch_output_numpy, 1e-05, 1e-05)\n        )\n\n    @autotest(n=5, check_graph=True, rtol=1e-2, atol=1e-3)\n    def test_flow_tensor_broadcast_matmul_with_random_data(test_case):\n        device = random_device()\n        k = random(1, 6)\n        x = random_tensor(ndim=4, dim3=k).to(device)\n        y = random_tensor(ndim=2, dim0=k).to(device)\n        return x.matmul(y)\n\n    @autotest(n=10, check_graph=True, rtol=1e-2, atol=1e-3, include_complex=True)\n    def test_flow_tensor_x_broadcast_y_matmul(test_case):\n        device = random_device()\n        k = random(1, 6)\n        x = random_tensor(ndim=2, dim1=k).to(device)\n        y = random_tensor(ndim=4, dim2=k).to(device)\n        return x.matmul(y)\n\n    @autotest(n=10, check_graph=True, rtol=1e-2, atol=1e-4, include_complex=True)\n    def test_flow_tensor_broadcast_matmul_with_same_dims(test_case):\n        device = random_device()\n        k = random(1, 6)\n        x = random_tensor(ndim=4, dim1=1, dim3=k).to(device)\n        y = random_tensor(ndim=4, dim0=1, dim2=k).to(device)\n        return x.matmul(y)\n\n    @autotest(check_graph=True, rtol=1e-2, atol=1e-3, include_complex=True)\n    def test_flow_mm_with_random_data(test_case):\n        device = random_device()\n        k = random(1, 6)\n        x = random_tensor(ndim=2, dim1=k).to(device)\n        y = random_tensor(ndim=2, dim0=k).to(device)\n        z = torch.mm(x, y)\n        return z\n\n    @autotest(n=10, check_graph=True, include_complex=True)\n    def test_flow_mv_with_random_data(test_case):\n        device = random_device()\n        k = random(1, 6)\n        x = random_tensor(ndim=2, dim1=k).to(device)\n        y = random_tensor(ndim=1, dim0=k).to(device)\n        z = torch.mv(x, y)\n        return z\n\n    @profile(torch.mv)\n    def profile_mv(test_case):\n        torch.mv(torch.ones(32, 64), torch.ones(64))\n\n    @autotest(n=10, check_graph=True, rtol=1e-2, atol=1e-4, include_complex=True)\n    def test_flow_vector_matrix_product_with_random_data(test_case):\n        device = random_device()\n        k = random(1, 6)\n        x = random_tensor(ndim=1, dim0=k).to(device)\n        y = random_tensor(ndim=2, dim0=k).to(device)\n        z = torch.matmul(x, y)\n        return z\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_max.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_scalar_max(test_case, device):\n    y = flow.max(flow.tensor(1.0, device=device), flow.tensor([2], device=device))\n    test_case.assertTrue(np.allclose(y.numpy(), [2], 1e-05, 1e-05))\n    y = flow.max(flow.tensor(1.0, device=device), flow.tensor(2, device=device))\n    test_case.assertTrue(np.allclose(y.numpy(), [2], 1e-05, 1e-05))\n    y = flow.max(flow.tensor([1.0], device=device), flow.tensor(2, device=device))\n    test_case.assertTrue(np.allclose(y.numpy(), [2], 1e-05, 1e-05))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMaxModule(flow.unittest.TestCase):\n    def test_scalar_max(test_case):\n        _test_scalar_max(test_case, \"cpu\")\n\n    @autotest(n=5, check_allclose=False, check_graph=True)\n    def test_max_reduce_random_dim(test_case):\n        device = random_device()\n        ndim = random().to(int).value()\n        x = random_tensor(ndim=ndim, dim0=random(1, 8))\n        y = x.to(device)\n        dim = random(-ndim, ndim).to(int).value()\n        keep_dims = random_bool().value()\n        y = torch.max(x, dim=dim, keepdim=keep_dims)\n\n        # pytorch result is an instance of class 'torch.return_types.max', but oneflow is tuple\n        test_case.assertTrue(\n            np.allclose(\n                y.oneflow[0].detach().cpu().numpy(),\n                y.pytorch.values.detach().cpu().numpy(),\n                rtol=0.0001,\n                atol=1e-05,\n            )\n        )\n        test_case.assertTrue(\n            np.allclose(\n                y.oneflow[1].detach().cpu().numpy(),\n                y.pytorch.indices.detach().cpu().numpy(),\n                rtol=0.0001,\n                atol=1e-05,\n            )\n        )\n\n        y.oneflow[0].sum().backward()\n        y.pytorch.values.sum().backward()\n        test_case.assertTrue(\n            np.allclose(\n                x.oneflow.grad.detach().cpu().numpy(),\n                x.pytorch.grad.detach().cpu().numpy(),\n                rtol=0.0001,\n                atol=1e-05,\n            )\n        )\n\n    @autotest(n=5, check_graph=True)\n    def test_max_reduce_all_dim(test_case):\n        device = random_device()\n        ndim = random().to(int).value()\n        x = random_tensor(ndim=ndim, dim0=random(1, 8)).to(device)\n        return torch.max(x)\n\n    @autotest(n=5, check_graph=True)\n    def test_max_elementwise(test_case):\n        device = random_device()\n        ndim = random().to(int).value()\n        dims = [random(1, 8) for _ in range(ndim)]\n        x = random_tensor(ndim, *dims).to(device)\n        y = random_tensor(ndim, *dims).to(device)\n        return torch.max(x, y)\n\n    @autotest(n=5, check_graph=True, check_dtype=True)\n    def test_max_elementwise_dtype_promotion(test_case):\n        device = 
\n    @autotest(n=5, check_graph=True, check_dtype=True)\n    def test_max_elementwise_dtype_promotion(test_case):\n        device = random_device()\n        ndim = random().to(int).value()\n        dims = [random(1, 8) for _ in range(ndim)]\n        x = random_tensor(ndim, *dims, dtype=float).to(device)\n        y = random_tensor(ndim, *dims, dtype=int).to(device)\n        return torch.max(x, y)\n\n    @autotest(n=5, check_graph=True, check_dtype=True)\n    def test_max_broadcast_dtype_promotion(test_case):\n        device = random_device()\n        ndim = random().to(int).value()\n        dims = [random(1, 8) for _ in range(ndim)]\n        b_dims = [1 for _ in range(ndim)]\n        x = random_tensor(ndim, *dims, dtype=float).to(device)\n        y = random_tensor(ndim, *b_dims, dtype=int).to(device)\n        return torch.max(x, y)\n\n    @autotest(n=3, auto_backward=True, check_graph=True)\n    def test_max_with_diff_size(test_case):\n        x = random_tensor(3, 1, 1, 4)\n        y = random_tensor(2, 1, 4)\n        return torch.max(x, y)\n\n    @autotest(n=3, auto_backward=False)\n    def test_max_return_type(test_case):\n        x = random_tensor(3, 4)\n        result = x.max(1)\n        return result.values, result.indices\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_maxpool.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nfrom pkg_resources import packaging\n\nimport numpy as np\nimport torch as pytorch\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t\n\n\ndef _test_maxpool2d_channel_last(\n    test_case, device, shape, kernel_size, stride, padding, dilation, ceil_mode\n):\n    os.environ[\"ONEFLOW_ENABLE_NHWC\"] = \"1\"\n    arr = np.random.randn(*shape)\n    x1 = flow.tensor(arr, dtype=flow.float64, device=device)\n    m1 = flow.nn.MaxPool2d(\n        kernel_size=kernel_size,\n        stride=stride,\n        padding=padding,\n        dilation=dilation,\n        ceil_mode=ceil_mode,\n    )\n    y1 = m1(x1)\n\n    x2 = pytorch.tensor(arr.transpose(0, 3, 1, 2), dtype=pytorch.float64, device=device)\n    m2 = pytorch.nn.MaxPool2d(\n        kernel_size=kernel_size,\n        stride=stride,\n        padding=padding,\n        dilation=dilation,\n        ceil_mode=ceil_mode,\n    )\n    y2 = m2(x2).permute(0, 2, 3, 1)\n    os.environ[\"ONEFLOW_ENABLE_NHWC\"] = \"0\"\n    # The test fails with pytorch 1.10 but success with pytorch1.13. 
    # The test fails with pytorch 1.10 but succeeds with pytorch 1.13. It should\n    # be re-enabled after updating CI to pytorch 1.13.\n    # test_case.assertTrue(\n    #     np.allclose(y1.detach().cpu().numpy(), y2.detach().cpu().numpy(), 1e-4, 1e-4)\n    # )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMaxPooling(flow.unittest.TestCase):\n    @autotest(n=5, auto_backward=True, check_graph=True)\n    def test_maxpool1d_with_random_data(test_case):\n        return_indices = random().to(bool).value()\n        m = torch.nn.MaxPool1d(\n            kernel_size=random(4, 6).to(_size_1_t),\n            stride=random(1, 3).to(_size_1_t) | nothing(),\n            padding=random(1, 3).to(_size_1_t) | nothing(),\n            dilation=random(2, 4).to(_size_1_t) | nothing(),\n            ceil_mode=random(),\n            return_indices=return_indices,\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=3, dim2=random(20, 22)).to(device)\n        y = m(x)\n\n        # NOTE(lixiang): When return_indices=True, maxpool1d returns the max indices along with the outputs;\n        #   the y[1] indices tensor has no grad_fn and cannot be backpropagated, so only y[0] is verified here.\n        if return_indices:\n            return y[0]\n        else:\n            return y\n\n    @autotest(n=5)\n    def test_maxpool1d_with_2d_input_tensor(test_case):\n        return_indices = random().to(bool).value()\n        m = torch.nn.MaxPool1d(\n            kernel_size=random(4, 6).to(_size_1_t),\n            stride=random(1, 3).to(_size_1_t) | nothing(),\n            padding=random(1, 3).to(_size_1_t) | nothing(),\n            dilation=random(2, 4).to(_size_1_t) | nothing(),\n            ceil_mode=random(),\n            return_indices=return_indices,\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=2, dim1=random(20, 22)).to(device)\n        y = m(x)\n        if return_indices:\n            return y[0]\n        else:\n            return y\n\n    @autotest(n=10, auto_backward=True, check_graph=True)\n    def test_maxpool2d_with_random_data(test_case):\n        return_indices = random().to(bool).value()\n        m = torch.nn.MaxPool2d(\n            kernel_size=random(4, 6).to(_size_2_t),\n            stride=random(1, 3).to(_size_2_t) | nothing(),\n            padding=random(1, 3).to(_size_2_t) | nothing(),\n            dilation=random(2, 4).to(_size_2_t) | nothing(),\n            ceil_mode=random(),\n            return_indices=return_indices,\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=4, dim2=random(20, 22), dim3=random(20, 22)).to(device)\n        y = m(x)\n\n        # NOTE(lixiang): When return_indices=True, maxpool2d returns the max indices along with the outputs;\n        #   the y[1] indices tensor has no grad_fn and cannot be backpropagated, so only y[0] is verified here.\n        if return_indices:\n            return y[0]\n        else:\n            return y\n
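\n    # A generator like random(1, 3).to(_size_2_t) | nothing() either yields a\n    # value for the kwarg or omits the kwarg entirely, so the layer default is\n    # exercised too. A hypothetical pair of draws:\n    #\n    #   torch.nn.MaxPool2d(kernel_size=(4, 5), stride=(1, 2))  # stride drawn\n    #   torch.nn.MaxPool2d(kernel_size=(5, 5))                 # stride omitted\n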
random_device()\n        m.to(device)\n        x = random_tensor(ndim=3, dim1=random(20, 22), dim2=random(20, 22)).to(device)\n        y = m(x)\n        if return_indices:\n            return y[0]\n        else:\n            return y\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @autotest(n=5, auto_backward=False)\n    def test_maxpool2d_with_half_data(test_case):\n        return_indices = random().to(bool).value()\n        m = torch.nn.MaxPool2d(\n            kernel_size=random(4, 6).to(_size_2_t),\n            stride=random(1, 3).to(_size_2_t) | nothing(),\n            padding=random(1, 3).to(_size_2_t) | nothing(),\n            dilation=random(2, 4).to(_size_2_t) | nothing(),\n            ceil_mode=random(),\n            return_indices=return_indices,\n        )\n        m.train(random())\n        device = gpu_device()\n        m.to(device)\n        x = (\n            random_tensor(ndim=4, dim2=random(20, 22), dim3=random(20, 22))\n            .to(device)\n            .to(torch.float16)\n        )\n        y = m(x)\n        if return_indices:\n            return y[0]\n        else:\n            return y\n\n    @autotest(n=5, auto_backward=True, check_graph=True)\n    def test_maxpool3d_with_random_data(test_case):\n        return_indices = random().to(bool).value()\n        m = torch.nn.MaxPool3d(\n            kernel_size=random(4, 6).to(_size_3_t),\n            stride=random(1, 3).to(_size_3_t) | nothing(),\n            padding=random(1, 3).to(_size_3_t) | nothing(),\n            dilation=random(2, 4).to(_size_3_t) | nothing(),\n            ceil_mode=random(),\n            return_indices=return_indices,\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(\n            ndim=5, dim2=random(20, 22), dim3=random(20, 22), dim4=random(20, 22)\n        ).to(device)\n        y = m(x)\n\n        # NOTE(lixiang): When return_indices=True, maxpool3d will return the max indices along with the outputs,\n        #   the y[1] tensor has no grad_fn and cannot be backpropagated, so only y[0] is verified here.\n        if return_indices:\n            return y[0]\n        else:\n            return y\n\n    @autotest(n=5)\n    def test_maxpool3d_with_4d_input_tensor(test_case):\n        return_indices = random().to(bool).value()\n        m = torch.nn.MaxPool3d(\n            kernel_size=random(4, 6).to(_size_3_t),\n            stride=random(1, 3).to(_size_3_t) | nothing(),\n            padding=random(1, 3).to(_size_3_t) | nothing(),\n            dilation=random(2, 4).to(_size_3_t) | nothing(),\n            ceil_mode=random(),\n            return_indices=return_indices,\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(\n            ndim=4, dim1=random(20, 22), dim2=random(20, 22), dim3=random(20, 22)\n        ).to(device)\n        y = m(x)\n        if return_indices:\n            return y[0]\n        else:\n            return y\n\n    @unittest.skipIf(\n        packaging.version.parse(pytorch.__version__)\n        == packaging.version.parse(\"1.10.0\"),\n        \"skip when pytorch version == 1.10.0\",\n    )\n    # NOTE: pytorch maxpool2d nhwc has a bug in version 1.10.0, so skip it in CI.\n    # detail: https://github.com/pytorch/pytorch/pull/76597\n    def test_maxpool2d_channel_last(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_maxpool2d_channel_last]\n        arg_dict[\"device\"] = 
[\"cuda\"]\n        # CPU pool is very slow, so don't run it with CUDA\n        if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            arg_dict[\"device\"] = [\"cpu\"]\n        arg_dict[\"shape\"] = [(3, 14, 27, 3), (5, 9, 14, 10), (2, 224, 224, 3)]\n        arg_dict[\"kernel_size\"] = [3, (2, 3), (3, 4)]\n        arg_dict[\"stride\"] = [1, (1, 2), 2]\n        arg_dict[\"padding\"] = [0, (0, 1)]\n        arg_dict[\"dilation\"] = [1, (1, 2), 2]\n        arg_dict[\"ceil_mode\"] = [True, False]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMaxPoolingFunctional(flow.unittest.TestCase):\n    @autotest(n=5, auto_backward=True, check_graph=True)\n    def test_maxpool1d_with_random_data(test_case):\n        return_indices = random().to(bool).value()\n        device = random_device()\n        x = random_tensor(ndim=3, dim2=random(20, 22)).to(device)\n        y = torch.nn.functional.max_pool1d(\n            x,\n            kernel_size=random(4, 6).to(int),\n            stride=random(1, 3).to(int) | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(2, 4).to(int) | nothing(),\n            ceil_mode=random().to(bool),\n            return_indices=return_indices,\n        )\n\n        # NOTE(lixiang): When return_indices=False, maxpool1d will return the max indices along with the outputs,\n        #   y[1] tensor has no grad_fn and cannot be backward, so only y[0] is verified here.\n        if return_indices:\n            return y[0]\n        else:\n            return y\n\n    @autotest(n=5, auto_backward=True, check_graph=True)\n    def test_maxpool2d_with_random_data(test_case):\n        return_indices = random().to(bool).value()\n        device = random_device()\n        x = random_tensor(ndim=4, dim2=random(20, 22), dim3=random(20, 22)).to(device)\n        y = torch.nn.functional.max_pool2d(\n            x,\n            kernel_size=random(4, 6).to(int),\n            stride=random(1, 3).to(int) | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(2, 4).to(int) | nothing(),\n            ceil_mode=random().to(bool),\n            return_indices=return_indices,\n        )\n\n        # NOTE(lixiang): When return_indices=False, maxpool2d will return the max indices along with the outputs,\n        #   y[1] tensor has no grad_fn and cannot be backward, so only y[0] is verified here.\n        if return_indices:\n            return y[0]\n        else:\n            return y\n\n    @autotest(auto_backward=True, check_graph=True)\n    def test_maxpool3d_with_random_data(test_case):\n        return_indices = random().to(bool).value()\n        device = random_device()\n        x = random_tensor(\n            ndim=5, dim2=random(20, 22), dim3=random(20, 22), dim4=random(20, 22)\n        ).to(device)\n        y = torch.nn.functional.max_pool3d(\n            x,\n            kernel_size=random(4, 6).to(int),\n            stride=random(1, 3).to(int) | nothing(),\n            padding=random(1, 3).to(int) | nothing(),\n            dilation=random(2, 4).to(int) | nothing(),\n            ceil_mode=random().to(bool),\n            return_indices=return_indices,\n        )\n\n        # NOTE(lixiang): When return_indices=False, maxpool3d will return the max indices along with the outputs,\n        #   y[1] tensor has no grad_fn and cannot be backward, so only y[0] is verified here.\n        if return_indices:\n            return y[0]\n        else:\n       
     return y\n\n    @profile(torch.nn.functional.max_pool2d)\n    def profile_maxpool2d(test_case):\n        torch.nn.functional.max_pool2d(\n            torch.ones(1, 128, 28, 28), kernel_size=3, padding=1\n        )\n        torch.nn.functional.max_pool2d(\n            torch.ones(1, 128, 28, 28), kernel_size=3, stride=2, padding=1\n        )\n        torch.nn.functional.max_pool2d(\n            torch.ones(16, 128, 28, 28), kernel_size=3, padding=1\n        )\n        torch.nn.functional.max_pool2d(\n            torch.ones(16, 128, 28, 28), kernel_size=3, stride=2, padding=1\n        )\n        torch.nn.functional.max_pool2d(\n            torch.ones(16, 128, 28, 28),\n            kernel_size=3,\n            stride=2,\n            padding=1,\n            ceil_mode=True,\n        )\n        # torch.nn.functional.max_pool2d(torch.ones(16, 128, 28, 28), kernel_size=3, dilation=2, padding=2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_maxunpool.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport random as random_util\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t\n\n# y = pool(x), z = unpool(y, indices), pool_input_shape is x.shape, pool_output_shape is y.shape.\n# When `output_size` in unpool() is empty, the op will calculate the output size according to\n# kernel_size, stride and padding. But when index in indices is outside the range required\n# by output_size calculated by unpool op, the value of result and related grad will be unknown.\n# To avoid the problem, this function calculate the output_size which will not cause unknown problems.\ndef _get_valid_output_size(\n    pool_input_shape, pool_output_shape, kernel_size, stride, padding\n):\n    def convert_data(data, i, dst_data=None):\n        if not isinstance(data, (list, int)):\n            return dst_data\n        if isinstance(data, list):\n            return data[i]\n        return data\n\n    _, _, *pool_input_hwd_shape = pool_input_shape.pytorch\n    batch_size, num_channels, *pool_out_hwd_shape = pool_output_shape.pytorch\n    unpool_output_shape = [batch_size, num_channels]\n    for i, (pool_input_size, pool_output_size) in enumerate(\n        zip(pool_input_hwd_shape, pool_out_hwd_shape)\n    ):\n        kernel_size_value = convert_data(kernel_size.value(), i)\n        stride_value = convert_data(stride.value(), i, kernel_size_value)\n        padding_value = convert_data(padding.value(), i, 0)\n        unpool_output_size = max(\n            pool_input_size,\n            (pool_output_size - 1) * stride_value\n            - 2 * padding_value\n            + kernel_size_value,\n        )\n        unpool_output_shape.append(unpool_output_size)\n    return torch.Size(unpool_output_shape)\n\n\ndef _test_module_unpoolnd(test_case, n):\n    device = random_device()\n    if n == 1:\n        _size_n_t = _size_1_t\n        MaxPoolNd = torch.nn.MaxPool1d\n        MaxUnpoolNd = torch.nn.MaxUnpool1d\n        x = random_tensor(ndim=3, dim2=random(20, 31), requires_grad=False).to(device)\n    elif n == 2:\n        _size_n_t = _size_2_t\n        MaxPoolNd = torch.nn.MaxPool2d\n        MaxUnpoolNd = torch.nn.MaxUnpool2d\n        x = random_tensor(\n            ndim=4, dim2=random(20, 31), dim3=random(20, 31), requires_grad=False\n        ).to(device)\n    elif n == 3:\n        _size_n_t = _size_3_t\n        MaxPoolNd = torch.nn.MaxPool3d\n        MaxUnpoolNd = torch.nn.MaxUnpool3d\n        x = random_tensor(\n            ndim=5,\n            dim2=random(20, 31),\n            dim3=random(20, 31),\n            dim4=random(20, 31),\n            requires_grad=False,\n        ).to(device)\n\n    kernel_size = random(4, 6).to(_size_n_t)\n    stride = random(1, 3).to(_size_n_t) | nothing()\n    padding = random(1, 3).to(_size_n_t) | nothing()\n    m = MaxPoolNd(\n        
kernel_size=kernel_size, stride=stride, padding=padding, return_indices=True,\n    )\n    m.train(random())\n    m.to(device)\n    y = m(x)\n    pooling_results_dtype = random_util.choice(\n        [torch.int, torch.long, torch.float, torch.double]\n    )\n    indices_dtype = random_util.choice([torch.int, torch.long])\n    pooling_results = y[0].to(pooling_results_dtype)\n    indices = y[1].to(indices_dtype)\n    pooling_results.requires_grad_()\n    output_size = _get_valid_output_size(\n        x.shape, pooling_results.shape, kernel_size, stride, padding\n    )\n    unpool_module = MaxUnpoolNd(\n        kernel_size=kernel_size, stride=stride, padding=padding,\n    )\n    result = unpool_module(pooling_results, indices, output_size=output_size)\n    return result\n\n\ndef _test_functional_unpoolnd(test_case, n):\n    device = random_device()\n\n    if n == 1:\n        _size_n_t = _size_1_t\n        MaxPoolNd = torch.nn.MaxPool1d\n        max_unpool_nd = torch.nn.functional.max_unpool1d\n        x = random_tensor(ndim=3, dim2=random(20, 31), requires_grad=False).to(device)\n    elif n == 2:\n        _size_n_t = _size_2_t\n        MaxPoolNd = torch.nn.MaxPool2d\n        max_unpool_nd = torch.nn.functional.max_unpool2d\n        x = random_tensor(\n            ndim=4, dim2=random(20, 31), dim3=random(20, 31), requires_grad=False\n        ).to(device)\n    elif n == 3:\n        _size_n_t = _size_3_t\n        MaxPoolNd = torch.nn.MaxPool3d\n        max_unpool_nd = torch.nn.functional.max_unpool3d\n        x = random_tensor(\n            ndim=5,\n            dim2=random(20, 31),\n            dim3=random(20, 31),\n            dim4=random(20, 31),\n            requires_grad=False,\n        ).to(device)\n\n    kernel_size = random(4, 6).to(_size_n_t)\n    stride = random(1, 3).to(_size_n_t) | nothing()\n    padding = random(1, 3).to(_size_n_t) | nothing()\n    m = MaxPoolNd(\n        kernel_size=kernel_size, stride=stride, padding=padding, return_indices=True,\n    )\n    m.train(random())\n    m.to(device)\n    y = m(x)\n    pooling_results_dtype = random_util.choice(\n        [torch.int, torch.long, torch.float, torch.double]\n    )\n    indices_dtype = random_util.choice([torch.int, torch.long])\n    pooling_results = y[0].to(pooling_results_dtype)\n    indices = y[1].to(indices_dtype)\n    pooling_results.requires_grad_()\n    output_size = _get_valid_output_size(\n        x.shape, pooling_results.shape, kernel_size, stride, padding\n    )\n    return max_unpool_nd(\n        pooling_results,\n        indices,\n        kernel_size=kernel_size,\n        stride=stride,\n        padding=padding,\n        output_size=output_size,\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMaxUnpooling(flow.unittest.TestCase):\n    @autotest(n=3, check_graph=False)\n    def test_max_unpool1d_with_random_data(test_case):\n        return _test_module_unpoolnd(test_case, 1)\n\n    @autotest(n=3, check_graph=False)\n    def test_functional_max_unpool1d_with_random_data(test_case):\n        return _test_functional_unpoolnd(test_case, 1)\n\n    @autotest(n=3, check_graph=False)\n    def test_max_unpool2d_with_random_data(test_case):\n        return _test_module_unpoolnd(test_case, 2)\n\n    @autotest(n=3, check_graph=False)\n    def test_functional_max_unpool2d_with_random_data(test_case):\n        return _test_functional_unpoolnd(test_case, 2)\n\n    @autotest(n=3, check_graph=False)\n    def test_max_unpool3d_with_random_data(test_case):\n        return _test_module_unpoolnd(test_case, 3)\n\n    
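# Illustrative note (added, not in the original test): with kernel_size=4, stride=2, padding=1\n    # and a pooled dim of size 10, _get_valid_output_size returns max(pool_input_size,\n    # (10 - 1) * 2 - 2 * 1 + 4) = max(pool_input_size, 20), keeping every pooling index inside\n    # the unpooled output.\n\n    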
@autotest(n=3, check_graph=False)\n    def test_functional_max_unpool3d_with_random_data(test_case):\n        return _test_functional_unpoolnd(test_case, 3)\n\n    @profile(torch.nn.functional.max_unpool1d)\n    def profile_max_unpool1d(test_case):\n        max_pool_results = torch.randn(1, 32, 64)\n        max_pool_indices = torch.arange(64).expand(1, 32, 64)\n        torch.nn.functional.max_unpool1d(max_pool_results, max_pool_indices, 2)\n\n        max_pool_results = torch.randn(32, 32, 64)\n        max_pool_indices = torch.arange(64).expand(32, 32, 64)\n        torch.nn.functional.max_unpool1d(max_pool_results, max_pool_indices, 2)\n\n    @profile(torch.nn.functional.max_unpool2d)\n    def profile_max_unpool2d(test_case):\n        max_pool_results = torch.randn(1, 16, 32, 32)\n        max_pool_indices = torch.arange(32).expand(1, 16, 32, 32)\n        torch.nn.functional.max_unpool2d(max_pool_results, max_pool_indices, 2)\n\n        max_pool_results = torch.randn(32, 16, 32, 32)\n        max_pool_indices = torch.arange(32).expand(32, 16, 32, 32)\n        torch.nn.functional.max_unpool2d(max_pool_results, max_pool_indices, 2)\n\n    @profile(torch.nn.functional.max_unpool3d)\n    def profile_max_unpool3d(test_case):\n        max_pool_results = torch.randn(1, 4, 32, 32, 32)\n        max_pool_indices = torch.arange(32).expand(1, 4, 32, 32, 32)\n        torch.nn.functional.max_unpool3d(max_pool_results, max_pool_indices, 2)\n\n        max_pool_results = torch.randn(16, 4, 32, 32, 32)\n        max_pool_indices = torch.arange(32).expand(16, 4, 32, 32, 32)\n        torch.nn.functional.max_unpool3d(max_pool_results, max_pool_indices, 2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_mean.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_mean(test_case, shape, device):\n    input = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.mean(input, dim=1)\n    np_out = np.mean(input.numpy(), axis=1)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    input = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.mean(input, dim=0)\n    np_out = np.mean(input.numpy(), axis=0)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n\n\ndef _test_mean_negative_dim(test_case, shape, device):\n    if len(shape) < 4:\n        shape = (2, 3, 4, 5)\n    input = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.mean(input, dim=(-2, -1, -3))\n    np_out = np.mean(input.numpy(), axis=(-2, -1, -3))\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n\n\ndef _test_mean_backward(test_case, shape, device):\n    np_arr = np.random.randn(*shape)\n    x = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    y = flow.mean(x, dim=1)\n    z = y.sum()\n    z.backward()\n    np_grad = np.zeros(shape=np_arr.shape)\n    np_grad[:] = 1 / x.size(1)\n    test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMean(flow.unittest.TestCase):\n    def test_mean(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_mean,\n            _test_mean_negative_dim,\n            _test_mean_backward,\n        ]\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(check_graph=True)\n    def test_mean_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float).to(device)\n        return torch.mean(x, dim)\n\n    @autotest(n=5)\n    def test_mean_with_scalar_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dtype=float).to(device).mean()\n        y = x.mean(-1)\n        return y\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @autotest(n=5, atol=1e-3)\n    def test_mean_with_float16_data(test_case):\n        device = gpu_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float).to(device=device, 
dtype=torch.float16)\n        return torch.mean(x, dim)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_median.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMedianModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_median_reduce_all_dim(test_case):\n        device = random_device()\n        ndim = random(1, 4).to(int).value()\n        x = random_tensor(ndim=ndim, dim0=random(1, 4)).to(device)\n        return torch.median(x)\n\n    @autotest(n=5)\n    def test_median_reduce_one_dim(test_case):\n        device = random_device()\n        ndim = random(low=2).to(int).value()\n        reduce_dim = random(high=ndim).to(int).value()\n        x = random_tensor(ndim).to(device)\n        return torch.median(x, reduce_dim)\n\n    @autotest(n=5)\n    def test_median_reduce_one_dim_keepdim(test_case):\n        device = random_device()\n        ndim = random(low=2).to(int).value()\n        reduce_dim = random(high=ndim).to(int).value()\n        x = random_tensor(ndim).to(device)\n        return torch.median(x, reduce_dim, True)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_median_0size(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3, dim1=0, requires_grad=False).to(device)\n        return torch.median(x)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_median_reduce_one_dim_0size(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3, dim1=0, requires_grad=False).to(device)\n        return torch.median(x, 0)\n\n    @autotest(n=5, auto_backward=False)\n    def test_median_return_type(test_case):\n        x = random_tensor(3, 4)\n        result = x.median(1)\n        return result.values, result.indices\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_meshgrid.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_meshgrid_forawd(test_case, device, indexing):\n    input1 = flow.tensor(\n        np.array([1, 2, 3]), dtype=flow.float32, device=flow.device(device)\n    )\n    input2 = flow.tensor(\n        np.array([4, 5, 6]), dtype=flow.float32, device=flow.device(device)\n    )\n    (np_x, np_y) = np.meshgrid(input1.numpy(), input2.numpy(), indexing=indexing)\n    (of_x, of_y) = flow.meshgrid(input1, input2, indexing=indexing)\n    test_case.assertTrue(np.allclose(of_x.numpy(), np_x, 0.0001, 0.0001))\n\n\ndef _test_meshgrid_forawd_scalar(test_case, device, indexing):\n    input1 = flow.tensor(np.array(1.0), dtype=flow.float32, device=flow.device(device))\n    input2 = flow.tensor(np.array(2.0), dtype=flow.float32, device=flow.device(device))\n    (np_x, np_y) = np.meshgrid(input1.numpy(), input2.numpy(), indexing=indexing)\n    (of_x, of_y) = flow.meshgrid(input1, input2, indexing=indexing)\n    test_case.assertTrue(np.allclose(of_x.numpy(), np_x, 0.0001, 0.0001))\n\n\ndef _test_meshgrid_forawd_3tensor(test_case, device, indexing):\n    input1 = flow.tensor(\n        np.array([1, 2, 3]), dtype=flow.float32, device=flow.device(device)\n    )\n    input2 = flow.tensor(\n        np.array([4, 5, 6]), dtype=flow.float32, device=flow.device(device)\n    )\n    input3 = flow.tensor(\n        np.array([7, 8, 9]), dtype=flow.float32, device=flow.device(device)\n    )\n    (np_x, np_y, np_z) = np.meshgrid(\n        input1.numpy(), input2.numpy(), input3.numpy(), indexing=indexing\n    )\n    (of_x, of_y, of_z) = flow.meshgrid(input1, input2, input3, indexing=indexing)\n    test_case.assertTrue(np.allclose(of_x.numpy(), np_x, 0.0001, 0.0001))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMeshGridModule(flow.unittest.TestCase):\n    def test_meshgrid(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_meshgrid_forawd,\n            _test_meshgrid_forawd_scalar,\n            _test_meshgrid_forawd_3tensor,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"indexing\"] = [\"ij\", \"xy\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(auto_backward=False, check_graph=True)\n    @unittest.skip(\"pytorch 1.9.0 exist not indexing\")\n    def test_meshgrid_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=3, requires_grad=False).to(device)\n        y = random_tensor(ndim=1, dim0=3, requires_grad=False).to(device)\n        res = torch.meshgrid(x, y)\n        return res[0], res[1]\n\n    @autotest(auto_backward=False)\n    def test_meshgrid_with_0dim_data(test_case):\n  
      device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = random_tensor(ndim=0).to(device)\n        res = torch.meshgrid(x, y)\n\n    @autotest(auto_backward=True)\n    @unittest.skip(\"pytorch 1.9.0 does not support the indexing argument\")\n    def test_meshgrid_with_random_data_xy(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=random(1, 6)).to(device)\n        y = random_tensor(ndim=1, dim0=random(1, 6)).to(device)\n        res = torch.meshgrid(x, y, indexing=\"xy\")\n        return torch.cat((res[0], res[1]), 0)\n\n    @autotest(auto_backward=True)\n    @unittest.skip(\"pytorch 1.9.0 does not support the indexing argument\")\n    def test_meshgrid_with_random_data_size(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=random(1, 6)).to(device)\n        res = torch.meshgrid(x, indexing=\"xy\")\n        return res[0]\n\n    @autotest(n=3)\n    def test_meshgrid_tuple_list_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=random(1, 6)).to(device)\n        y = random_tensor(ndim=1, dim0=random(1, 6)).to(device)\n        res1 = torch.meshgrid((x, y))\n        res2 = torch.meshgrid([x, y])\n        return torch.cat((res1[0], res1[1], res2[0], res2[1]), 0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_min.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMinModule(flow.unittest.TestCase):\n    @autotest(n=5, check_allclose=False, check_graph=True)\n    def test_min_reduce_random_dim(test_case):\n        device = random_device()\n        ndim = random().to(int).value()\n        x = random_tensor(ndim=ndim, dim0=random(1, 8))\n        y = x.to(device)\n        dim = random(-ndim, ndim).to(int).value()\n        keep_dims = random_bool().value()\n        y = torch.min(x, dim=dim, keepdim=keep_dims)\n\n        # pytorch result is an instance of class 'torch.return_types.min', but oneflow is tuple\n        test_case.assertTrue(\n            np.allclose(\n                y.oneflow[0].detach().cpu().numpy(),\n                y.pytorch.values.detach().cpu().numpy(),\n                rtol=0.0001,\n                atol=1e-05,\n            )\n        )\n        test_case.assertTrue(\n            np.allclose(\n                y.oneflow[1].detach().cpu().numpy(),\n                y.pytorch.indices.detach().cpu().numpy(),\n                rtol=0.0001,\n                atol=1e-05,\n            )\n        )\n\n        y.oneflow[0].sum().backward()\n        y.pytorch.values.sum().backward()\n        test_case.assertTrue(\n            np.allclose(\n                x.oneflow.grad.detach().cpu().numpy(),\n                x.pytorch.grad.detach().cpu().numpy(),\n                rtol=0.0001,\n                atol=1e-05,\n            )\n        )\n\n    @autotest(n=5, check_graph=True)\n    def test_min_reduce_all_dim(test_case):\n        device = random_device()\n        ndim = random().to(int).value()\n        x = random_tensor(ndim=ndim, dim0=random(1, 8)).to(device)\n        return torch.min(x)\n\n    @autotest(n=5, check_graph=True)\n    def test_min_elementwise(test_case):\n        device = random_device()\n        ndim = random().to(int).value()\n        dims = [random(1, 8) for _ in range(ndim)]\n        x = random_tensor(ndim, *dims).to(device)\n        y = random_tensor(ndim, *dims).to(device)\n        return torch.min(x, y)\n\n    @autotest(n=5, check_graph=True, check_dtype=True)\n    def test_min_elementwise_dtype_promotion(test_case):\n        device = random_device()\n        ndim = random().to(int).value()\n        dims = [random(1, 8) for _ in range(ndim)]\n        x = random_tensor(ndim, *dims, dtype=float).to(device)\n        y = random_tensor(ndim, *dims, dtype=int).to(device)\n        return torch.min(x, y)\n\n    @autotest(n=5, check_graph=True, check_dtype=True)\n    def test_min_broadcast_dtype_promotion(test_case):\n        device = random_device()\n        ndim = random().to(int).value()\n        dims = [random(1, 8) for _ in range(ndim)]\n        b_dims = [1 for _ in range(ndim)]\n        x = random_tensor(ndim, *dims, 
dtype=float).to(device)\n        y = random_tensor(ndim, *b_dims, dtype=int).to(device)\n        return torch.min(x, y)\n\n    @autotest(n=3, auto_backward=False)\n    def test_min_return_type(test_case):\n        x = random_tensor(3, 4)\n        result = x.min(1)\n        return result.values, result.indices\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_min_max_observer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport math\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.nn.modules import min_max_observer\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.test_util import (\n    GenArgList,\n    type_name_to_flow_type,\n    type_name_to_np_type,\n)\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef gen_quant_scale_for_min_max_symmetric(weight, quantization_bit):\n    weight_max = np.max(np.abs(weight))\n    denominator = 2.0 ** (quantization_bit - 1) - 1\n    return (weight_max / denominator, 0)\n\n\ndef gen_quant_scale_for_min_max_affine(weight, quantization_bit):\n    weight_max = np.max(weight)\n    weight_min = np.min(weight)\n    denominator = 2.0 ** quantization_bit - 1\n    scale = (weight_max - weight_min) / denominator\n    zero_point = -np.round(weight_min / scale)\n    return (scale, zero_point)\n\n\ndef gen_quant_scale_for_min_max_cambricon(weight, quantization_bit):\n    weight_max = np.max(np.abs(weight))\n    scale = math.floor(math.log2(weight_max)) - (quantization_bit - 2)\n    return (scale, 0)\n\n\ndef product(tu):\n    return np.prod(tu).astype(np.int32).item()\n\n\ndef _check_min_max_observer(\n    test_case,\n    weight,\n    scale_of,\n    zero_point_of,\n    quantization_bit,\n    quantization_scheme,\n    quantization_formula,\n    per_layer_quantization,\n):\n    if per_layer_quantization or quantization_formula == \"cambricon\":\n        outer_num = 1\n        inner_num = product(weight.shape[0:])\n    else:\n        outer_num = weight.shape[0]\n        inner_num = product(weight.shape[1:])\n    scale_np = np.zeros((outer_num,))\n    zero_point_np = np.zeros((outer_num,))\n    weight_flatten = weight.flatten()\n    if quantization_formula == \"google\":\n        if quantization_scheme == \"symmetric\":\n            for c in range(outer_num):\n                (scale_np[c], zero_point_np[c]) = gen_quant_scale_for_min_max_symmetric(\n                    weight_flatten[c * inner_num : (c + 1) * inner_num],\n                    quantization_bit,\n                )\n        else:\n            for c in range(outer_num):\n                (scale_np[c], zero_point_np[c]) = gen_quant_scale_for_min_max_affine(\n                    weight_flatten[c * inner_num : (c + 1) * inner_num],\n                    quantization_bit,\n                )\n    else:\n        (scale_np[0], zero_point_np[0]) = gen_quant_scale_for_min_max_cambricon(\n            weight_flatten, quantization_bit\n        )\n    test_case.assertTrue(np.allclose(scale_of, scale_np, rtol=0.001))\n\n    rmse = np.sqrt(np.mean((zero_point_of - zero_point_np) ** 2))\n    assert rmse <= 1.0, \"min_max_observer op zero_point calculate has bug!\"\n\n\ndef _run_test_min_max_observer(\n    test_case,\n    device_type,\n    weight_shape,\n    quantization_bit,\n    
quantization_scheme,\n    quantization_formula,\n    per_layer_quantization,\n):\n    weight = (np.random.random(weight_shape) - 0.5).astype(np.float32)\n    tensor_weight = flow.tensor(\n        weight, device=flow.device(device_type), dtype=flow.float32\n    )\n    min_max_observer = flow.nn.MinMaxObserver(\n        quantization_formula=quantization_formula,\n        quantization_bit=quantization_bit,\n        quantization_scheme=quantization_scheme,\n        per_layer_quantization=per_layer_quantization,\n    )\n    scale, zero_point = min_max_observer(tensor_weight)\n    _check_min_max_observer(\n        test_case,\n        weight,\n        scale.numpy(),\n        zero_point.numpy(),\n        quantization_bit,\n        quantization_scheme,\n        quantization_formula,\n        per_layer_quantization,\n    )\n\n\nclass TestMinMaxObserver(flow.unittest.TestCase):\n    def test_min_max_observer(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_case\"] = [test_case]\n        arg_dict[\"device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"weight_shape\"] = [(9, 40, 20, 10)]\n        arg_dict[\"quantization_bit\"] = [8, 2]\n        arg_dict[\"quantization_scheme\"] = [\"symmetric\", \"affine\"]\n        arg_dict[\"quantization_formula\"] = [\"google\"]\n        arg_dict[\"per_layer_quantization\"] = [True, False]\n        for arg in GenArgList(arg_dict):\n            if arg[-2] == \"cambricon\" and arg[-1] == False:\n                continue\n            _run_test_min_max_observer(*arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_mock.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMockModule(flow.unittest.TestCase):\n    def test_mock_device(test_case):\n        device = flow.device(\"mock\")\n        test_case.assertEqual(device.type, \"mock\")\n\n    def test_mock_placement(test_case):\n        placement = flow.placement(\"mock\", [0])\n        test_case.assertEqual(placement.type, \"mock\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_mode.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModeModule(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_mode_reduce_one_dim(test_case):\n        device = cpu_device()\n        ndim = random(low=2).to(int).value()\n        reduce_dim = random(high=ndim).to(int).value()\n        x = random_tensor(ndim).to(device)\n        return torch.mode(x, reduce_dim)\n\n    @autotest(n=5)\n    def test_mode_reduce_one_dim_keepdim(test_case):\n        device = cpu_device()\n        ndim = random(low=2).to(int).value()\n        reduce_dim = random(high=ndim).to(int).value()\n        x = random_tensor(ndim).to(device)\n        return torch.mode(x, reduce_dim, True)\n\n    @autotest(n=5, auto_backward=False, check_graph=False)\n    def test_mode_0size(test_case):\n        device = cpu_device()\n        x = random_tensor(ndim=3, dim1=0, requires_grad=False).to(device)\n        return torch.mode(x)\n\n    @autotest(n=5, auto_backward=False, check_graph=False)\n    def test_mode_reduce_one_dim_0size(test_case):\n        device = cpu_device()\n        x = random_tensor(ndim=3, dim1=0, requires_grad=False).to(device)\n        return torch.mode(x, 0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_module.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport math\nimport warnings\nimport tempfile\nimport unittest\nfrom itertools import repeat\nfrom typing import Tuple, Union, List\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\nfrom oneflow._oneflow_internal import TensorTuple\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef np_relu(np_arr):\n    return np.where(np_arr > 0, np_arr, 0)\n\n\ndef _test_hooks(test_case, backward_register_fn):\n    module = nn.Sigmoid()\n    input = flow.ones(5, 5, requires_grad=True)\n\n    counter = {\"forwards\": 0, \"backwards\": 0}\n\n    def fw_hook(inc, h_module, input, output):\n        test_case.assertTrue(isinstance(input, tuple))\n        test_case.assertTrue(isinstance(output, flow.Tensor))\n        test_case.assertTrue(h_module is module)\n        test_case.assertTrue(flow.equal(input[0], flow.ones(5, 5)))\n        test_case.assertTrue(\n            flow.equal(output, flow.empty(5, 5).fill_(1 / (1 + 1 / math.e)))\n        )\n        counter[\"forwards\"] += inc\n\n    def bw_hook(inc, h_module, grad_input, grad_output):\n        test_case.assertTrue(isinstance(grad_input, TensorTuple))\n        test_case.assertTrue(isinstance(grad_output, TensorTuple))\n        test_case.assertTrue(h_module is module)\n        test_case.assertTrue(flow.equal(grad_output[0], flow.ones(5, 5) * 2))\n        counter[\"backwards\"] += inc\n\n    test_fwd = module.register_forward_hook(lambda *args: fw_hook(1, *args))\n\n    module(input)\n    module(input)\n    test_case.assertEqual(counter[\"forwards\"], 2)\n    test_case.assertEqual(counter[\"backwards\"], 0)\n\n    test_bwd = getattr(module, backward_register_fn)(lambda *args: bw_hook(1, *args))\n\n    output = module(input)\n    test_case.assertEqual(counter[\"forwards\"], 3)\n    test_case.assertEqual(counter[\"backwards\"], 0)\n\n    output.backward(flow.ones(5, 5) * 2, retain_graph=True)\n    test_case.assertEqual(counter[\"forwards\"], 3)\n    test_case.assertEqual(counter[\"backwards\"], 1)\n\n    output.backward(flow.ones(5, 5) * 2, retain_graph=True)\n    test_case.assertEqual(counter[\"forwards\"], 3)\n    test_case.assertEqual(counter[\"backwards\"], 2)\n\n    test2_fwd = module.register_forward_hook(lambda *args: fw_hook(2, *args))\n\n    output = module(input)\n    test_case.assertEqual(counter[\"forwards\"], 6)\n    test_case.assertEqual(counter[\"backwards\"], 2)\n\n    test2_bwd = getattr(module, backward_register_fn)(lambda *args: bw_hook(2, *args))\n    module(input).backward(flow.ones(5, 5) * 2)\n    test_case.assertEqual(counter[\"forwards\"], 9)\n    test_case.assertEqual(counter[\"backwards\"], 5)\n\n    test2_bwd.remove()\n\n    module(input).backward(flow.ones(5, 5) * 2)\n    test_case.assertEqual(counter[\"forwards\"], 12)\n    test_case.assertEqual(counter[\"backwards\"], 6)\n\n    
test2_fwd.remove()\n\n    module(input).backward(flow.ones(5, 5) * 2)\n    test_case.assertEqual(counter[\"forwards\"], 13)\n    test_case.assertEqual(counter[\"backwards\"], 7)\n\n    test_fwd.remove()\n    test_bwd.remove()\n\n\ndef _test_module_forward_preforward_hook_removable(test_case):\n    module = nn.Sigmoid()\n\n    def removable_hook(m, input):\n        nonlocal handle\n        handle.remove()\n        return input\n\n    def removable_hook_2(m, input):\n        nonlocal handle_2\n        handle_2.remove()\n        return input\n\n    handle = module.register_forward_pre_hook(removable_hook)\n    handle_2 = module.register_forward_pre_hook(removable_hook_2)\n\n    # make sure hook register is successful\n    test_case.assertEqual(len(handle.hooks_dict_ref()), 2)\n    test_case.assertEqual(len(handle_2.hooks_dict_ref()), 2)\n\n    input = flow.randn(2, 2)\n    output = module(input)\n    test_case.assertTrue(flow.equal(flow.sigmoid(input), output))\n\n    # make sure hook removal is successful\n    test_case.assertFalse(handle.id in handle.hooks_dict_ref())\n    test_case.assertFalse(handle_2.id in handle.hooks_dict_ref())\n    test_case.assertEqual(len(handle.hooks_dict_ref()), 0)\n    test_case.assertEqual(len(handle_2.hooks_dict_ref()), 0)\n\n\ndef _test_module_forward_forward_hook_removable(test_case):\n    module = nn.Sigmoid()\n\n    def removable_hook(m, input, output):\n        nonlocal handle\n        handle.remove()\n        return output\n\n    def removable_hook_2(m, input, output):\n        nonlocal handle_2\n        handle_2.remove()\n        return output\n\n    handle = module.register_forward_hook(removable_hook)\n    handle_2 = module.register_forward_hook(removable_hook_2)\n\n    # make sure hook register is successful\n    test_case.assertEqual(len(handle.hooks_dict_ref()), 2)\n    test_case.assertEqual(len(handle_2.hooks_dict_ref()), 2)\n\n    input = flow.randn(2, 2)\n    output = module(input)\n    test_case.assertTrue(flow.equal(flow.sigmoid(input), output))\n\n    # make sure hook removal is successful\n    test_case.assertFalse(handle.id in handle.hooks_dict_ref())\n    test_case.assertFalse(handle_2.id in handle.hooks_dict_ref())\n    test_case.assertEqual(len(handle.hooks_dict_ref()), 0)\n    test_case.assertEqual(len(handle_2.hooks_dict_ref()), 0)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestModule(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_nested_module(test_case):\n        class CustomModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.relu = flow.nn.ReLU()\n\n            def forward(self, x):\n                return self.relu(x)\n\n        m = CustomModule()\n        x = flow.Tensor(2, 3)\n        flow.nn.init.uniform_(x, a=-1.0, b=1.0)\n        y = m(x)\n        test_case.assertTrue(np.array_equal(np_relu(x.numpy()), y.numpy()))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_relu(test_case):\n        relu = flow.nn.ReLU()\n        x = flow.Tensor(2, 3)\n        flow.nn.init.uniform_(x, a=-1.0, b=1.0)\n        y = relu(x)\n        test_case.assertTrue(np.array_equal(np_relu(x.numpy()), y.numpy()))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_load_state_dict(test_case):\n        class CustomModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.w = flow.nn.Parameter(flow.Tensor(2, 3))\n\n            def forward(self, x):\n     
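           # Added note: forward ignores x and just exposes w, so the loaded weights can be read back.\n     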
           return self.w\n\n        m = CustomModule()\n        ones = np.ones((2, 3), dtype=np.float32)\n        m.load_state_dict({\"w\": ones})\n        x = flow.Tensor(2, 3)\n        y = m(x).numpy()\n        test_case.assertTrue(np.array_equal(y, ones))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_state_dict(test_case):\n        class CustomModule(flow.nn.Module):\n            def __init__(self, param1, param2):\n                super().__init__()\n                self.param1 = param1\n                self.param2 = param2\n\n        tensor0 = flow.nn.Parameter(flow.Tensor(2, 3))\n        tensor1 = flow.nn.Parameter(flow.Tensor(2, 3))\n        sub_module = CustomModule(tensor0, tensor1)\n        m = CustomModule(tensor1, sub_module)\n        state_dict = m.state_dict()\n        test_case.assertEqual(\n            state_dict,\n            {\"param2.param1\": tensor0, \"param2.param2\": tensor1, \"param1\": tensor1},\n        )\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_parameter(test_case):\n        shape = (3, 4)\n        t = flow.Tensor(*shape)\n        p = flow.nn.Parameter(t)\n        test_case.assertEqual(type(p), flow.nn.Parameter)\n        test_case.assertEqual(p.shape, shape)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_module_forward(test_case):\n        class CustomModule(flow.nn.Module):\n            def __init__(self, w):\n                super().__init__()\n                self.w = w\n\n            def forward(self, x):\n                return x + self.w\n\n        m = CustomModule(5)\n        test_case.assertEqual(m(1), 6)\n        m = CustomModule(4)\n        test_case.assertEqual(m(3), 7)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_train_eval(test_case):\n        m = flow.nn.Module()\n        test_case.assertEqual(m.training, True)\n        m.train()\n        test_case.assertEqual(m.training, True)\n        m.eval()\n        test_case.assertEqual(m.training, False)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_module_setattr(test_case):\n        class CustomModule(flow.nn.Module):\n            def __init__(self, param1, param2):\n                super().__init__()\n                self.param1 = param1\n                self.param2 = param2\n\n        param0 = flow.nn.Parameter(flow.Tensor(2, 3))\n        param1 = flow.nn.Parameter(flow.Tensor(2, 3))\n        param2 = CustomModule(param0, param1)\n        m = CustomModule(param1, param2)\n        params = list(m.parameters())\n        test_case.assertEqual(len(params), 2)\n\n        test_case.assertTrue(\n            np.allclose(params[0].numpy(), param1.numpy(), atol=1e-4, rtol=1e-4)\n        )\n        test_case.assertTrue(\n            np.allclose(params[1].numpy(), param0.numpy(), atol=1e-4, rtol=1e-4)\n        )\n        children = list(m.children())\n        test_case.assertEqual(len(children), 1)\n        child = children[0]\n        test_case.assertEqual(child, param2)\n        child_params = list(child.parameters())\n\n        test_case.assertEqual(len(child_params), 2)\n        test_case.assertTrue(np.allclose(child_params[0].numpy(), param0.numpy()))\n        test_case.assertTrue(np.allclose(child_params[1].numpy(), param1.numpy()))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_module_apply(test_case):\n        class CustomModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.modules = flow.nn.Module()\n\n        global module_num\n        module_num = 0\n\n        def get_module_num(m):\n    
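        # Added note: apply() visits the root module and its single child module, so this runs twice.\n    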
        global module_num\n            module_num += 1\n\n        net = CustomModule()\n        net.apply(get_module_num)\n        test_case.assertEqual(module_num, 2)\n\n    @flow.unittest.skip_unless_1n1d()\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_module_cpu_cuda(test_case):\n        class CustomModule(flow.nn.Module):\n            def __init__(self, param1, param2):\n                super().__init__()\n                self.param1 = param1\n                self.param2 = param2\n\n        tensor0 = flow.nn.Parameter(flow.Tensor(2, 3, device=flow.device(\"cpu\")))\n        tensor1 = flow.nn.Parameter(flow.Tensor(2, 3, device=flow.device(\"cpu\")))\n        sub_module = CustomModule(tensor0, tensor1)\n        m = CustomModule(tensor1, sub_module)\n        m.cuda()\n        state_dict = m.state_dict()\n        test_case.assertEqual(state_dict[\"param2.param1\"].device, flow.device(\"cuda:0\"))\n        test_case.assertEqual(state_dict[\"param2.param2\"].device, flow.device(\"cuda:0\"))\n\n        m.cpu()\n        state_dict = m.state_dict()\n        test_case.assertEqual(state_dict[\"param2.param1\"].device, flow.device(\"cpu\"))\n        test_case.assertEqual(state_dict[\"param2.param2\"].device, flow.device(\"cpu\"))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_module_float_double(test_case):\n        class CustomModule(flow.nn.Module):\n            def __init__(self, param1, param2):\n                super().__init__()\n                self.param1 = param1\n                self.param2 = param2\n\n        tensor0 = flow.nn.Parameter(flow.Tensor(2, 3).to(dtype=flow.float64))\n        tensor1 = flow.nn.Parameter(flow.Tensor(2, 3).to(dtype=flow.float64))\n        m = CustomModule(tensor0, tensor1)\n        m = m.float()\n        state_dict = m.state_dict()\n        test_case.assertEqual(state_dict[\"param1\"].dtype, flow.float32)\n        test_case.assertEqual(state_dict[\"param2\"].dtype, flow.float32)\n\n        m = m.double()\n        state_dict = m.state_dict()\n        test_case.assertEqual(state_dict[\"param1\"].dtype, flow.float64)\n        test_case.assertEqual(state_dict[\"param2\"].dtype, flow.float64)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_moduledict(test_case):\n        class ModuleDict(nn.Module):\n            def __init__(self):\n                super(ModuleDict, self).__init__()\n                self.choices = nn.ModuleDict(\n                    {\"conv\": nn.Conv2d(10, 10, 3), \"pool\": nn.MaxPool2d(3)}\n                )\n                self.activations = nn.ModuleDict(\n                    {\"relu\": nn.ReLU(), \"prelu\": nn.PReLU()}\n                )\n\n            def forward(self, x, choice, act):\n                x = self.choices[choice](x)\n                x = self.activations[act](x)\n                return x\n\n        model = ModuleDict()\n        input = flow.tensor(np.random.randn(4, 10, 32, 32), dtype=flow.float32)\n        output = model(input, \"conv\", \"relu\")\n        test_case.assertEqual(output.shape, flow.Size([4, 10, 30, 30]))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_module_submodule(test_case):\n        class CustomSubModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.param = flow.nn.Linear(2, 3)\n\n        class CustomModule(flow.nn.Module):\n            def __init__(self) -> None:\n                super().__init__()\n                self.linear = CustomSubModule()\n\n        m = 
CustomModule()\n        test_case.assertTrue(\n            isinstance(m.get_submodule(\"linear.param\"), flow.nn.Linear)\n        )\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_module_get_parameter(test_case):\n        class CustomModule(flow.nn.Module):\n            def __init__(self, param1, param2):\n                super().__init__()\n                self.param1 = param1\n                self.param2 = param2\n\n        tensor0 = flow.nn.Parameter(flow.Tensor(2, 3).to(dtype=flow.float32))\n        tensor1 = flow.nn.Parameter(flow.Tensor(2, 3).to(dtype=flow.float32))\n        m = CustomModule(tensor0, tensor1)\n        test_case.assertTrue(m.get_parameter(\"param1\") is tensor0)\n        test_case.assertTrue(m.get_parameter(\"param2\") is tensor1)\n\n    def test_module_delattr(test_case):\n        class ConvBNModule(nn.Module):\n            def __init__(self):\n                super(ConvBNModule, self).__init__()\n                self.conv = nn.Conv2d(1, 2, 1, 1)\n                self.bn = nn.BatchNorm2d(2)\n\n            def forward(self, x):\n                return self.bn(self.conv(x))\n\n        m = ConvBNModule()\n        delattr(m, \"bn\")\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_hooks_register(test_case):\n        for hook in [\"register_backward_hook\", \"register_full_backward_hook\"]:\n            _test_hooks(test_case, hook)\n        _test_module_forward_preforward_hook_removable(test_case)\n        _test_module_forward_forward_hook_removable(test_case)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_register_state_dict_hook_hook(test_case):\n        destination_check = None\n\n        def state_dict_hook(module, destination, prefix, local_metadata):\n            for submodule_name, submodule in module.named_modules():\n                for attr_name, attr in submodule.__dict__.items():\n                    if isinstance(attr, torch.Tensor):\n                        mod_prefix = prefix + submodule_name\n                        key = mod_prefix + (\".\" if mod_prefix else \"\") + attr_name\n                        destination[key] = attr\n            nonlocal destination_check\n            destination_check = destination\n\n        class CustomModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = nn.Linear(10, 5)\n                self._register_state_dict_hook(state_dict_hook)\n\n            def forward(self, x):\n                x = self.linear(x)\n                return x\n\n        m = CustomModule()\n        test_case.assertEqual(destination_check, None)\n        state_dict = m.state_dict()\n        test_case.assertEqual(destination_check, state_dict)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_full_backward_hook(test_case):\n        hook_triggered = False\n\n        def hook(_, grad_input, grad_output):\n            nonlocal hook_triggered\n            hook_triggered = True\n            test_case.assertEqual(len(grad_input), 1)\n            test_case.assertEqual(len(grad_output), 1)\n            test_case.assertTrue(np.array_equal(grad_input[0].numpy(), [1, 0]))\n            test_case.assertTrue(np.array_equal(grad_output[0].numpy(), [1, 1]))\n\n        m = flow.nn.ReLU()\n        m.register_full_backward_hook(hook)\n\n        x0 = flow.tensor([1.0, -1], requires_grad=True)\n        x = x0 + 1\n        y = m(x)\n        y.sum().backward()\n        test_case.assertTrue(hook_triggered)\n        test_case.assertTrue(np.array_equal(x0.grad, [1, 0]))\n\n    
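# Added note: a full backward hook may return a tuple of replacement grad_input values; the\n    # test below returns (flow.tensor([1, 1]),), overriding ReLU's masked gradient [1, 0] at x0.\n    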
@flow.unittest.skip_unless_1n1d()\n    def test_full_backward_hook_with_return_value(test_case):\n        hook_triggered = False\n\n        def hook(_, grad_input, grad_output):\n            nonlocal hook_triggered\n            hook_triggered = True\n            test_case.assertEqual(len(grad_input), 1)\n            test_case.assertEqual(len(grad_output), 1)\n            test_case.assertTrue(np.array_equal(grad_input[0].numpy(), [1, 0]))\n            test_case.assertTrue(np.array_equal(grad_output[0].numpy(), [1, 1]))\n            return (flow.tensor([1, 1]),)\n\n        m = flow.nn.ReLU()\n        m.register_full_backward_hook(hook)\n\n        x0 = flow.tensor([1.0, -1], requires_grad=True)\n        x = x0 + 1\n        y = m(x)\n        y.sum().backward()\n        test_case.assertTrue(hook_triggered)\n        test_case.assertTrue(np.array_equal(x0.grad, [1, 1]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_module_to.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\ndummy_val = np.random.randn(2, 3)\nin_val = np.full((2, 3), -2)\ncpu0_device = flow.device(\"cpu\")\nif os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n    gpu0_device = cpu0_device\nelse:\n    gpu0_device = flow.device(\"cuda\")\n\n\nclass DummyModule(flow.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.register_buffer(\"dummy_buf\", flow.Tensor(dummy_val))\n        self.dummy_para = flow.nn.Parameter(flow.Tensor(dummy_val))\n        self.register_buffer(\"dummy_buf_int\", flow.Tensor(dummy_val).to(flow.int32))\n\n    def forward(self, x):\n        return self.dummy_para * x + self.dummy_buf\n\n\ndef _test_dummy_module(test_case):\n    m = DummyModule()\n    test_case.assertEqual(m.dummy_buf.device, cpu0_device)\n    test_case.assertEqual(m.dummy_para.device, cpu0_device)\n    input = flow.Tensor(in_val)\n    output = m(input)\n    test_case.assertTrue(np.allclose(output.numpy(), -dummy_val, 0.0001, 0.0001))\n    test_case.assertEqual(m.dummy_buf.grad, None)\n    test_case.assertEqual(m.dummy_para.grad, None)\n    test_case.assertEqual(input.device, cpu0_device)\n    test_case.assertEqual(output.device, cpu0_device)\n\n\ndef _test_dummy_module_to(test_case):\n    m = DummyModule()\n    test_case.assertEqual(m.dummy_buf.device, cpu0_device)\n    test_case.assertEqual(m.dummy_para.device, cpu0_device)\n    m.to(gpu0_device)\n    test_case.assertEqual(m.dummy_buf.device, gpu0_device)\n    test_case.assertTrue(m.dummy_buf.is_leaf)\n    test_case.assertTrue(not m.dummy_buf.requires_grad)\n    test_case.assertEqual(m.dummy_para.device, gpu0_device)\n    test_case.assertTrue(m.dummy_para.is_leaf)\n    test_case.assertTrue(m.dummy_para.requires_grad)\n    input = flow.Tensor(in_val).to(gpu0_device)\n    output = m(input)\n    test_case.assertTrue(np.allclose(output.numpy(), -dummy_val, 0.0001, 0.0001))\n    test_case.assertEqual(m.dummy_buf.grad, None)\n    test_case.assertEqual(m.dummy_para.grad, None)\n    test_case.assertEqual(input.device, gpu0_device)\n    test_case.assertEqual(output.device, gpu0_device)\n    output_grad = flow.ones((2, 3)).to(gpu0_device)\n    output.backward(output_grad)\n    test_case.assertEqual(output_grad.device, gpu0_device)\n    test_case.assertEqual(m.dummy_buf.grad, None)\n    test_case.assertTrue(np.allclose(m.dummy_para.grad.numpy(), in_val, 0.0001, 0.0001))\n    test_case.assertEqual(m.dummy_para.grad.device, gpu0_device)\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestModuleTo(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, because it failed 4 times in the past week\")\n    def test_module_to_device(test_case):\n        arg_dict = 
OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_dummy_module, _test_dummy_module_to]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_module_to_dtype(test_case):\n        m = DummyModule()\n        m.to(flow.float64)\n        test_case.assertEqual(m.dummy_buf.dtype, flow.float64)\n        test_case.assertEqual(m.dummy_para.dtype, flow.float64)\n        test_case.assertEqual(m.dummy_buf_int.dtype, flow.int32)\n\n    def test_module_to_tensor(test_case):\n        m = DummyModule()\n        m.to(flow.zeros(1, dtype=flow.float16, device=\"cuda\"))\n        test_case.assertEqual(m.dummy_buf.dtype, flow.float16)\n        test_case.assertEqual(m.dummy_para.dtype, flow.float16)\n        test_case.assertEqual(m.dummy_buf_int.dtype, flow.int32)\n        test_case.assertEqual(m.dummy_buf.device.type, \"cuda\")\n        test_case.assertEqual(m.dummy_para.device.type, \"cuda\")\n        test_case.assertEqual(m.dummy_buf_int.device.type, \"cuda\")\n\n    def test_module_to_with_var_reuse(test_case):\n        class ReuseVarModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear1 = flow.nn.Linear(3, 4)\n                self.linear2 = flow.nn.Linear(3, 4)\n                self.linear2.weight = self.linear1.weight\n\n        reuse_var_m = ReuseVarModule()\n\n        test_case.assertTrue(reuse_var_m.linear1.weight is reuse_var_m.linear2.weight)\n        test_case.assertEqual(reuse_var_m.linear1.weight.device, cpu0_device)\n\n        test_case.assertTrue(reuse_var_m.linear1.bias is not reuse_var_m.linear2.bias)\n        test_case.assertEqual(reuse_var_m.linear1.bias.device, cpu0_device)\n\n        reuse_var_m.to(gpu0_device)\n\n        test_case.assertTrue(reuse_var_m.linear1.weight is reuse_var_m.linear2.weight)\n        test_case.assertEqual(reuse_var_m.linear1.weight.device, gpu0_device)\n\n        test_case.assertTrue(reuse_var_m.linear1.bias is not reuse_var_m.linear2.bias)\n        test_case.assertEqual(reuse_var_m.linear1.bias.device, gpu0_device)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_module_to_global_or_local.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n2d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestModuleToGlobalOrLocal(flow.unittest.TestCase):\n    def test_module_to_global(test_case):\n        rank = flow.env.get_rank()\n        P = flow.placement(\"cuda\", ranks=[0, 1])\n        B = flow.sbp.broadcast\n\n        class ReuseVarModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear1 = flow.nn.Linear(3, 4)\n                self.linear2 = flow.nn.Linear(3, 4)\n                self.linear2.weight = self.linear1.weight\n\n        reuse_var_m = ReuseVarModule()\n\n        test_case.assertTrue(reuse_var_m.linear1.weight is reuse_var_m.linear2.weight)\n        test_case.assertEqual(\n            reuse_var_m.linear1.weight.device, flow.device(\"cpu\", rank)\n        )\n\n        test_case.assertTrue(reuse_var_m.linear1.bias is not reuse_var_m.linear2.bias)\n        test_case.assertEqual(reuse_var_m.linear1.bias.device, flow.device(\"cpu\", rank))\n\n        reuse_var_m.to_global(placement=P, sbp=B)\n\n        test_case.assertTrue(reuse_var_m.linear1.weight is reuse_var_m.linear2.weight)\n        test_case.assertEqual(reuse_var_m.linear1.weight.placement, P)\n        test_case.assertEqual(reuse_var_m.linear1.weight.sbp[0], B)\n\n        test_case.assertTrue(reuse_var_m.linear1.bias is not reuse_var_m.linear2.bias)\n        test_case.assertEqual(reuse_var_m.linear1.bias.placement, P)\n        test_case.assertEqual(reuse_var_m.linear1.bias.sbp[0], B)\n\n    def test_module_to_local(test_case):\n        rank = flow.env.get_rank()\n        device = \"cuda\"\n        P = flow.placement(device, ranks=[0, 1])\n        B = flow.sbp.broadcast\n        S = flow.sbp.split(0)\n\n        class ToLocalModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = flow.nn.Linear(3, 4, False)\n\n        to_local_m = ToLocalModule()\n        flow.nn.init.uniform_(to_local_m.linear.weight)\n\n        to_local_m.to_global(placement=P, sbp=B)\n        origin_w_np = to_local_m.linear.weight.numpy()\n\n        to_local_m.to_global(placement=P, sbp=S)\n        test_case.assertTrue(\n            np.array_equal(to_local_m.linear.weight.numpy(), origin_w_np)\n        )\n\n        # When weight SBP is split(0)\n        to_local_m.to_local()\n        test_case.assertTrue(to_local_m.linear.weight.is_local)\n        if rank == 0:\n            test_case.assertTrue(\n                np.array_equal(to_local_m.linear.weight.numpy(), origin_w_np[:2])\n            )\n        elif rank == 1:\n            test_case.assertTrue(\n                np.array_equal(to_local_m.linear.weight.numpy(), origin_w_np[2:])\n            )\n\n        # local to global from split(0)\n        to_local_m.to_global(placement=P, sbp=S)\n        test_case.assertTrue(\n            np.array_equal(to_local_m.linear.weight.numpy(), origin_w_np)\n        )\n\n        # When weight SBP is broadcast\n        to_local_m.to_global(placement=P, sbp=B)\n        test_case.assertTrue(not to_local_m.linear.weight.is_local)\n        test_case.assertTrue(\n            np.array_equal(to_local_m.linear.weight.numpy(), origin_w_np)\n        )\n\n        # When weight SBP is broadcast\n        to_local_m.to_local()\n        test_case.assertTrue(to_local_m.linear.weight.is_local)\n        test_case.assertTrue(\n            np.array_equal(to_local_m.linear.weight.numpy(), origin_w_np)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_module_to_half.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModuleToHalf(flow.unittest.TestCase):\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_module_to_half(test_case):\n        input = flow.randn(10, 10).to(flow.float16).cuda()\n        model = flow.nn.Linear(10, 20).half().cuda()\n        output = model(input)\n        test_case.assertEqual(output.dtype, flow.float16)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_movedim.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom random import shuffle\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMovedim(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_flow_movedim_with_vector(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4,\n            dim1=random(3, 6),\n            dim2=random(3, 6),\n            dim3=random(3, 6),\n            dim4=random(3, 6),\n        ).to(device)\n        z = torch.movedim(x, (0, 1), (2, 3))\n        return z\n\n    @autotest(n=10)\n    def test_flow_movedim_with_stride(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4,\n            dim1=random(3, 6),\n            dim2=random(3, 6),\n            dim3=random(3, 6),\n            dim4=random(3, 6),\n        ).to(device)\n        perm = [0, 1, 2, 3]\n        shuffle(perm)\n        y = x.permute(perm)\n        z = torch.movedim(y, (0, 1), (2, 3))\n        return z\n\n    @autotest(check_graph=True)\n    def test_flow_movedim_with_int(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4,\n            dim1=random(3, 6),\n            dim2=random(3, 6),\n            dim3=random(3, 6),\n            dim4=random(3, 6),\n        ).to(device)\n        z = torch.movedim(x, 0, 3)\n        return z\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_moving_average_min_max_observer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport math\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.test_util import (\n    GenArgList,\n    type_name_to_flow_type,\n    type_name_to_np_type,\n)\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef gen_quant_scale_for_moving_average_min_max_symmetric(\n    activation, quantization_bit, momentum, moving_max, moving_min\n):\n    activation_max = np.max(np.abs(activation))\n    denominator = 2.0 ** (quantization_bit - 1) - 1\n    if moving_max[0] == 0:\n        moving_max[0] = activation_max\n    else:\n        moving_max[0] = moving_max[0] * momentum + activation_max * (1 - momentum)\n    moving_min[0] = moving_max[0]\n    return (moving_max[0] / denominator, 0)\n\n\ndef gen_quant_scale_for_moving_average_min_max_affine(\n    activation, quantization_bit, momentum, moving_max, moving_min\n):\n    activation_max = np.max(activation)\n    activation_min = np.min(activation)\n    denominator = 2.0 ** quantization_bit - 1\n    if moving_max[0] == 0:\n        moving_max[0] = activation_max\n    else:\n        moving_max[0] = moving_max[0] * momentum + activation_max * (1 - momentum)\n    if moving_min[0] == 0:\n        moving_min[0] = activation_min\n    else:\n        moving_min[0] = moving_min[0] * momentum + activation_min * (1 - momentum)\n    scale = (moving_max[0] - moving_min[0]) / denominator\n    zero_point = -np.round(moving_min[0] / scale)\n    return (scale, zero_point)\n\n\ndef gen_quant_scale_for_moving_average_min_max_cambricon(\n    activation, quantization_bit, momentum, moving_max, moving_min\n):\n    activation_max = np.max(np.abs(activation))\n    if moving_max[0] == 0:\n        moving_max[0] = activation_max\n    else:\n        moving_max[0] = moving_max[0] * momentum + activation_max * (1 - momentum)\n    moving_min[0] = moving_max[0]\n    return (math.floor(math.log2(moving_max[0])) - (quantization_bit - 2), 0)\n\n\ndef _check_moving_average_min_max_observer(\n    test_case,\n    activation,\n    scale_of,\n    zero_point_of,\n    moving_max_np,\n    moving_min_np,\n    quantization_bit,\n    quantization_scheme,\n    quantization_formula,\n    momentum,\n):\n    if quantization_formula == \"google\":\n        if quantization_scheme == \"symmetric\":\n            (\n                scale_np,\n                zero_point_np,\n            ) = gen_quant_scale_for_moving_average_min_max_symmetric(\n                activation.flatten(),\n                quantization_bit,\n                momentum,\n                moving_max_np,\n                moving_min_np,\n            )\n        else:\n            (\n                scale_np,\n                zero_point_np,\n            ) = gen_quant_scale_for_moving_average_min_max_affine(\n                activation.flatten(),\n                
quantization_bit,\n                momentum,\n                moving_max_np,\n                moving_min_np,\n            )\n    else:\n        (\n            scale_np,\n            zero_point_np,\n        ) = gen_quant_scale_for_moving_average_min_max_cambricon(\n            activation.flatten(),\n            quantization_bit,\n            momentum,\n            moving_max_np,\n            moving_min_np,\n        )\n    test_case.assertTrue(np.allclose(scale_of[0], scale_np, rtol=0.001))\n\n    rmse = np.sqrt(np.mean((zero_point_of[0] - zero_point_np) ** 2))\n    assert (\n        rmse <= 1.0\n    ), \"moving_average_min_max_observer op zero_point calculate has bug!\"\n\n\ndef _run_test_moving_average_min_max_observer(\n    test_case,\n    device_type,\n    dtype,\n    activation_shape,\n    quantization_bit,\n    quantization_scheme,\n    quantization_formula,\n    momentum,\n):\n    moving_max_np = np.zeros((1,))\n    moving_min_np = np.zeros((1,))\n    current_train_step_tensor = flow.tensor(\n        np.zeros((1,)).astype(np.float32),\n        dtype=flow.int64,\n        device=flow.device(device_type),\n    )\n    for i in range(10):\n        activation = (np.random.random(activation_shape) - 0.5).astype(\n            type_name_to_np_type[dtype]\n        )\n        activation_tensor = flow.tensor(\n            activation, dtype=flow.float32, device=flow.device(device_type)\n        )\n        moving_average_min_max_observer = flow.nn.MovingAverageMinMaxObserver(\n            stop_update_after_iters=1,\n            quantization_formula=quantization_formula,\n            quantization_bit=quantization_bit,\n            quantization_scheme=quantization_scheme,\n            momentum=momentum,\n        )\n        moving_average_min_max_observer = moving_average_min_max_observer.to(\n            device_type\n        )\n        (scale, zero_point) = moving_average_min_max_observer(\n            activation_tensor, current_train_step_tensor\n        )\n        _check_moving_average_min_max_observer(\n            test_case,\n            activation,\n            scale.numpy(),\n            zero_point.numpy(),\n            moving_max_np,\n            moving_min_np,\n            quantization_bit,\n            quantization_scheme,\n            quantization_formula,\n            momentum,\n        )\n\n\nclass TestMovingAverageMinMaxObserver(flow.unittest.TestCase):\n    def test_moving_average_min_max_observer(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_case\"] = [test_case]\n        arg_dict[\"device_type\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"dtype\"] = [\"float32\", \"double\"]\n        arg_dict[\"activation_shape\"] = [(9, 40, 20, 10)]\n        arg_dict[\"quantization_bit\"] = [8, 2]\n        arg_dict[\"quantization_scheme\"] = [\"symmetric\", \"affine\"]\n        arg_dict[\"quantization_formula\"] = [\"google\"]\n        arg_dict[\"momentum\"] = [0.95]\n        for arg in GenArgList(arg_dict):\n            _run_test_moving_average_min_max_observer(*arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_mul.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch as torch_original\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_mul_impl(test_case, device):\n    x = flow.tensor(\n        np.random.randn(2, 3),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    y = flow.tensor(\n        np.random.randn(2, 3),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out = flow.mul(x, y)\n    np_out = np.multiply(x.numpy(), y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad_x = y.numpy()\n    np_grad_y = x.numpy()\n    test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad_x, 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(y.grad.numpy(), np_grad_y, 1e-05, 1e-05))\n    x = 5\n    y = flow.tensor(\n        np.random.randn(2, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.mul(x, y)\n    np_out = np.multiply(x, y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    x = flow.tensor(\n        np.random.randn(2, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    y = 5\n    of_out = flow.mul(x, y)\n    np_out = np.multiply(x.numpy(), y)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    x = flow.tensor(\n        np.random.randn(1, 1),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    y = flow.tensor(\n        np.random.randn(2, 3),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out = flow.mul(x, y)\n    np_out = np.multiply(x.numpy(), y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    test_case.assertTrue(np.allclose(x.grad.numpy(), np.sum(y.numpy()), 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(y.grad.numpy(), x.numpy(), 1e-05, 1e-05))\n    x = flow.tensor(\n        np.random.randn(1, 1),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    y = flow.tensor(\n        np.random.randn(2, 3, 4),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out = flow.mul(x, y)\n    np_out = np.multiply(x.numpy(), y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    test_case.assertTrue(np.allclose(x.grad.numpy(), np.sum(y.numpy()), 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(y.grad.numpy(), 
x.numpy(), 1e-05, 1e-05))\n    x = flow.tensor(\n        np.random.randn(1, 1),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    y = flow.tensor(\n        np.random.randn(2, 3, 4, 5),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out = flow.mul(x, y)\n    np_out = np.multiply(x.numpy(), y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    test_case.assertTrue(np.allclose(x.grad.numpy(), np.sum(y.numpy()), 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(y.grad.numpy(), x.numpy(), 1e-05, 1e-05))\n\n\ndef inplace_mul_tensors_helper(test_case, device, arr_0, arr_y):\n    of_x = flow.tensor(\n        arr_0, dtype=flow.float32, device=flow.device(device), requires_grad=True,\n    )\n    of_inplace_x = of_x + 1\n    of_y = flow.tensor(\n        arr_y, dtype=flow.float32, device=flow.device(device), requires_grad=True,\n    )\n    id_inplace_x = id(of_inplace_x)\n    of_inplace_x.mul_(of_y)\n    test_case.assertTrue(\n        np.allclose(of_inplace_x.numpy(), np.multiply(arr_0 + 1, arr_y), 1e-05, 1e-05)\n    )\n    test_case.assertTrue(id_inplace_x == id(of_inplace_x))\n    of_inplace_x = of_inplace_x.sum()\n    of_inplace_x.backward()\n    test_case.assertTrue(np.allclose(arr_y, of_x.grad.numpy(), 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(arr_0 + 1, of_y.grad.numpy(), 1e-05, 1e-05))\n\n\ndef _test_inplace_mul_tensors(test_case, device):\n    arr_0 = np.random.rand(3, 5)\n    arr_y = np.random.rand(3, 5)\n    inplace_mul_tensors_helper(test_case, device, arr_0, arr_y)\n\n\ndef _test_inplace_mul_scalar(test_case, device):\n    arr = np.random.rand(2, 3, 4)\n    of_x = flow.tensor(\n        arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    y = 3.25\n    of_inplace_x = of_x + 1\n    id_x_before = id(of_inplace_x)\n    of_inplace_x.mul_(y)\n    test_case.assertTrue(id_x_before == id(of_inplace_x))\n    test_case.assertTrue(np.allclose(of_inplace_x.numpy(), np.multiply(arr + 1, y)))\n\n    of_x = flow.tensor(\n        arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_inplace_x = of_x + 1\n    of_inplace_x_id_before = id(of_inplace_x)\n    of_inplace_x.mul_(y)\n    test_case.assertTrue(of_inplace_x_id_before == id(of_inplace_x))\n    test_case.assertTrue(\n        np.allclose(of_inplace_x.numpy(), np.multiply(arr + 1, y), 1e-05, 1e-05)\n    )\n    of_inplace_x = of_inplace_x.sum()\n    of_inplace_x.backward()\n    test_case.assertTrue(\n        np.allclose(np.full(arr.shape, y), of_x.grad.numpy(), 1e-05, 1e-05)\n    )\n\n\ndef _test_mul_inplace_0size_tensor(test_case, device):\n    targets = flow.randn((0, 6), device=flow.device(device))\n    height, width = 640, 640\n    targets[:, 2:] *= flow.tensor(\n        (width, height, width, height), device=flow.device(device)\n    )\n    test_case.assertTrue(np.array_equal(targets.size(), (0, 6)))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMulModule(flow.unittest.TestCase):\n    def test_mul(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_mul_impl,\n            _test_inplace_mul_tensors,\n            _test_inplace_mul_scalar,\n            _test_mul_inplace_0size_tensor,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, 
*arg[1:])\n\n    @autotest(check_graph=True, include_complex=True)\n    def test_broadcast_mul(test_case):\n        device = random_device()\n        x_0 = random_tensor(ndim=3, dim0=4, dim1=2, dim2=3).to(device)\n        y = random_tensor(ndim=2, dim0=2, dim1=3).to(device)\n        x = x_0 + 1\n        x.mul_(y)\n        return x\n\n    @autotest(n=6, include_complex=True)\n    def test_non_contiguous_inplace_mul(test_case):\n        device = random_device()\n        x = random_tensor(2, 2, 4).to(device)\n        y = x + 1\n        y = y[:, 1:3]\n        y *= random_tensor(2, 2, 2).to(device)\n        return y\n\n    @autotest(n=10, include_complex=True)\n    def test_scalar_mul_with_random_devices(test_case):\n        x1_device = random_device()\n        x2_device = random_device()\n        x1 = random_tensor(2, 2, 3).to(x1_device).mean()\n        x2 = random_tensor(2, 2, 3).to(x2_device)\n        y = x1 * x2\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_multi_tensor_yolov5_weight_update.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_multi_tensor_weight_update_impl(test_case, device, shape, n, d):\n    def compare(a, b, rtol=1e-5, atol=1e-5):\n        test_case.assertTrue(\n            np.allclose(\n                a.detach().cpu().numpy(), b.detach().cpu().numpy(), rtol=rtol, atol=atol\n            ),\n            f\"\\na\\n{a.detach().cpu().numpy()}\\n{'-' * 80}\\nb:\\n{b.detach().cpu().numpy()}\\n{'*' * 80}\\ndiff:\\n{a.detach().cpu().numpy() - b.detach().cpu().numpy()}\",\n        )\n\n    weight = []\n    torch_weight = []\n    weight_update = []\n    torch_weight_update = []\n    for _ in range(n):\n        tmp = flow.tensor(\n            np.random.randn(*shape),\n            dtype=flow.float32,\n            device=flow.device(device),\n            requires_grad=False,\n        )\n        weight.append(tmp)\n        torch_weight.append(\n            torch.tensor(\n                tmp.numpy(),\n                dtype=torch.float32,\n                device=torch.device(device),\n                requires_grad=False,\n            )\n        )\n        tmp = flow.tensor(\n            np.random.randn(*shape),\n            dtype=flow.float32,\n            device=flow.device(device),\n            requires_grad=False,\n        )\n        weight_update.append(tmp)\n        torch_weight_update.append(\n            torch.tensor(\n                tmp.numpy(),\n                dtype=torch.float32,\n                device=torch.device(device),\n                requires_grad=False,\n            )\n        )\n    for i, v in enumerate(torch_weight):\n        v = v * d\n        v = v + (1 - d) * torch_weight_update[i]\n        torch_weight[i] = v\n\n    flow._C.multi_tensor_yolov5_weight_update(weight, weight_update, d)\n    for i in range(n):\n        compare(weight[i], torch_weight[i])\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMultiTensorWeightUpdateModule(flow.unittest.TestCase):\n    def test_multi_tensor_weight_update(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_multi_tensor_weight_update_impl]\n        arg_dict[\"device\"] = [\"cuda\"]\n        arg_dict[\"shape\"] = [(20, 1), (30, 1), (55, 1)]\n        arg_dict[\"n\"] = [5, 10, 292]\n        arg_dict[\"d\"] = [0.22, 0.5]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_multinomial.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport random\nimport numpy as np\nfrom collections import OrderedDict\nimport torch\n\nimport oneflow as flow\n\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_multinomial(test_case, device, seed, replacement, dtype):\n    n_dists = random.randint(8, 64)\n    n_categories = random.randint(8, 64)\n    num_samples = random.randint(4, n_categories)\n\n    weights_torch = torch.rand(\n        n_dists,\n        n_categories,\n        device=device,\n        dtype=torch.float32 if dtype == \"float\" else torch.float64,\n    )\n    weights_oneflow = flow.tensor(\n        weights_torch.cpu().numpy(),\n        device=device,\n        dtype=flow.float32 if dtype == \"float\" else flow.float64,\n    )\n\n    torch.manual_seed(seed)\n    flow.manual_seed(seed)\n\n    torch_res = torch.multinomial(\n        weights_torch, num_samples, replacement=replacement, generator=None\n    )\n    flow_res = flow.multinomial(\n        weights_oneflow, num_samples, replacement=replacement, generator=None\n    )\n\n    test_case.assertTrue(\n        np.allclose(torch_res.cpu().numpy(), flow_res.cpu().numpy(), atol=1e-8,)\n    )\n\n    torch_gen = torch.Generator(device=device)\n    torch_gen.manual_seed(seed)\n    oneflow_gen = flow.Generator(device=device)\n    oneflow_gen.manual_seed(seed)\n\n    torch_res = torch.multinomial(\n        weights_torch, num_samples, replacement=replacement, generator=torch_gen\n    )\n    flow_res = flow.multinomial(\n        weights_oneflow, num_samples, replacement=replacement, generator=oneflow_gen\n    )\n\n    test_case.assertTrue(\n        np.allclose(torch_res.cpu().numpy(), flow_res.cpu().numpy(), atol=1e-8,)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestMultinomial(flow.unittest.TestCase):\n    def test_multinomial(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"seed\"] = [0, 2, 4]\n        arg_dict[\"replacement\"] = [True, False]\n        arg_dict[\"dtype\"] = [\"double\", \"float\"]\n        for arg in GenArgList(arg_dict):\n            _test_multinomial(test_case, *arg[0:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_nansum.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNanSumModule(flow.unittest.TestCase):\n    @autotest(n=5, check_graph=True)\n    def test_nansum_without_nan(test_case):\n        device = random_device()\n        x = random_tensor(4, random(0, 5), 2).to(device)\n        y = torch.nansum(x)\n        return y\n\n    @autotest(n=5, check_graph=True)\n    def test_nansum_with_partial_nan(test_case):\n        device = random_device()\n        x = random_tensor(4, random(0, 5), 2).to(device)\n        mask = x < 0\n        x = x.masked_fill(mask, float(\"nan\"))\n        y = torch.nansum(x)\n        return y\n\n    @autotest(n=5, check_graph=True)\n    def test_nansum_with_total_nan(test_case):\n        device = random_device()\n        x = random_tensor(4, random(0, 5), 2).to(device)\n        mask = torch.ones_like(x).bool()\n        x = x.masked_fill(mask, float(\"nan\"))\n        y = torch.nansum(x)\n        return y\n\n    @autotest(n=5, check_graph=True)\n    def test_nansum_with_partial_nan_dims(test_case):\n        device = random_device()\n        x = random_tensor(4, random(0, 5), 2).to(device)\n        mask = x < 0\n        x = x.masked_fill(mask, float(\"nan\"))\n        y = torch.nansum(x, dim=random(0, 4).to(int))\n        return y\n\n    @autotest(n=5, check_graph=True)\n    def test_nansum_with_total_nan_dims(test_case):\n        device = random_device()\n        x = random_tensor(4, random(0, 5), 2).to(device)\n        mask = torch.ones_like(x).bool()\n        x = x.masked_fill(mask, float(\"nan\"))\n        y = torch.nansum(x, dim=random(0, 4).to(int))\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_sum_with_0_size_tensor(test_case):\n        device = random_device()\n        x = random_tensor(4, 4, 3, 0, 2).to(device)\n        y = torch.nansum(x, dim=np.random.randint(0, 3))\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_sum_with_0dim_tensor(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.nansum(x)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_narrow.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nfrom random import shuffle\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNarrow(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_flow_narrow_start_with_random_data(test_case):\n        k0 = random(2, 6)\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        rand_dim = random(0, 3).to(int)\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device)\n        return torch.narrow(x, dim=rand_dim, start=2, length=1)\n\n    @autotest(check_graph=True)\n    def test_flow_narrow_length_with_random_data(test_case):\n        k0 = random(2, 6)\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        rand_dim = random(0, 3).to(int)\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device)\n        return torch.narrow(x, dim=rand_dim, start=0, length=2)\n\n    @autotest(n=10, check_graph=True)\n    def test_flow_narrow_with_stride(test_case):\n        k0 = random(2, 6)\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        rand_dim = random(0, 3).to(int)\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device)\n        perm = [0, 1, 2]\n        shuffle(perm)\n        x = x.permute(perm)\n        y = torch.narrow(x, dim=rand_dim, start=0, length=2)\n        return y\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_flow_narrow_start_bool_with_random_data(test_case):\n        k0 = random(2, 6)\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        rand_dim = random(0, 3).to(int)\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(\n            device=device, dtype=torch.bool\n        )\n        return torch.narrow(x, dim=rand_dim, start=2, length=1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_ne.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_ne(test_case, shape, device):\n    arr1 = np.random.randn(*shape)\n    arr2 = np.random.randn(*shape)\n    input = flow.tensor(arr1, dtype=flow.float32, device=flow.device(device))\n    other = flow.tensor(arr2, dtype=flow.float32, device=flow.device(device))\n    of_out = flow.ne(input, other)\n    of_out2 = flow.not_equal(input, other)\n    np_out = np.not_equal(arr1, arr2)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n    test_case.assertTrue(np.array_equal(of_out2.numpy(), np_out))\n    test_case.assertTrue(input != None)\n    test_case.assertTrue(None != input)\n\n\ndef _test_tensor_ne_operator(test_case, shape, device):\n    arr1 = np.random.randn(*shape)\n    arr2 = np.random.randn(*shape)\n    input = flow.tensor(arr1, dtype=flow.float32, device=flow.device(device))\n    other = flow.tensor(arr2, dtype=flow.float32, device=flow.device(device))\n    of_out = input.ne(other)\n    np_out = np.not_equal(arr1, arr2)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_ne_int(test_case, shape, device):\n    arr = np.random.randn(*shape)\n    input = flow.tensor(arr, dtype=flow.float32, device=flow.device(device))\n    num = 1\n    of_out = flow.ne(input, num)\n    np_out = np.not_equal(arr, num)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_tensor_ne_operator_int(test_case, shape, device):\n    arr = np.random.randn(*shape)\n    input = flow.tensor(arr, dtype=flow.float32, device=flow.device(device))\n    num = 1\n    of_out = input.ne(num)\n    np_out = np.not_equal(arr, num)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_ne_float(test_case, shape, device):\n    arr = np.random.randn(*shape)\n    input = flow.tensor(arr, dtype=flow.float32, device=flow.device(device))\n    num = 1.0\n    of_out = flow.ne(input, num)\n    np_out = np.not_equal(arr, num)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_tensor_ne_operator_float(test_case, shape, device):\n    arr = np.random.randn(*shape)\n    input = flow.tensor(arr, dtype=flow.float32, device=flow.device(device))\n    num = 1.0\n    of_out = input.ne(num)\n    np_out = np.not_equal(arr, num)\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNe(flow.unittest.TestCase):\n    def test_ne(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_func\"] = [\n            _test_ne,\n            _test_tensor_ne_operator,\n            _test_ne_int,\n            _test_tensor_ne_operator_int,\n            _test_ne_float,\n            
_test_tensor_ne_operator_float,\n        ]\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_ne_with_0_size_data(test_case):\n        device = random_device()\n        x1 = random_tensor(4, 2, 3, 0, 5).to(device)\n        x2 = random_tensor(4, 2, 3, 0, 5).to(device)\n        y1 = torch.ne(x1, x2)\n        y2 = torch.ne(x1, 2)\n        y3 = torch.ne(x1, 2.0)\n        return (y1, y2, y3)\n\n    @autotest(n=5, auto_backward=False)\n    def test_ne_with_0dim_data(test_case):\n        device = random_device()\n        x1 = random_tensor(ndim=0).to(device)\n        x2 = random_tensor(ndim=0).to(device)\n        y1 = torch.ne(x1, x2)\n        y2 = torch.ne(x1, 2)\n        y3 = torch.ne(x1, 2.0)\n        return (y1, y2, y3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_negative.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNegativeModule(flow.unittest.TestCase):\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_negative_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 3, 0, 5).to(device)\n        y1 = torch.negative(x)\n        y2 = torch.neg(x)\n        y3 = -x\n        return (y1, y2, y3)\n\n    @autotest()\n    def test_tensor_negative_with_random_data(test_case):\n        x = random_tensor().to(random_device())\n        return x.negative()\n\n    @autotest()\n    def test_negative_with_random_data(test_case):\n        x = random_tensor().to(random_device())\n        z = torch.negative(x)\n        return z\n\n    @autotest()\n    def test_neg_with_random_data(test_case):\n        x = random_tensor().to(random_device())\n        z = torch.neg(x)\n        return z\n\n    @autotest()\n    def test_tensor_negative_with_0dim_data(test_case):\n        x = random_tensor(ndim=0).to(random_device())\n        return x.negative()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_nll_loss.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport numpy as np\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@autotest(n=5)\ndef _test_nll_loss(\n    test_case, has_weight=False, split_batch_dim=False, split_class_dim=False\n):\n    N = random(1, 4) * 2\n    C = random(1, 10) * 2\n    ndim = random(2, 5).to(int).value()\n    dims = [random(2, 10) for i in range(ndim - 2)]\n    input_dims = [N, C] + dims\n    target_dims = [N] + dims\n    input = random_tensor(ndim, *input_dims)\n    target = random_tensor(\n        ndim - 1, *target_dims, low=0, high=C, dtype=int, requires_grad=False\n    )\n    weight = None\n    if has_weight:\n        weight = random_tensor(1, C, requires_grad=False)\n\n    device = random_device().value()\n    if not split_class_dim and not split_batch_dim:\n        input = input.to(device)\n        target = target.to(device)\n        if has_weight:\n            weight = weight.to(device)\n    else:\n        rank = flow.env.get_rank()\n        world_size = flow.env.get_world_size()\n        assert world_size % 2 == 0\n        ranks = np.array(range(world_size))\n\n        if split_batch_dim and split_class_dim:\n            placement = flow.placement(device, ranks.reshape((ranks.size // 2, 2)))\n            input_sbp = [flow.sbp.split(0), flow.sbp.split(1)]\n            target_sbp = [flow.sbp.split(0), flow.sbp.broadcast()]\n            weight_sbp = [flow.sbp.broadcast(), flow.sbp.split(0)]\n        elif split_batch_dim:\n            placement = flow.placement(device, ranks)\n            input_sbp = flow.sbp.split(0)\n            target_sbp = flow.sbp.split(0)\n            weight_sbp = flow.sbp.broadcast()\n        else:\n            placement = flow.placement(device, ranks)\n            input_sbp = flow.sbp.split(1)\n            target_sbp = flow.sbp.broadcast()\n            weight_sbp = flow.sbp.split(0)\n\n        input = input.to_global(placement=placement, sbp=input_sbp)\n        target = target.to_global(placement=placement, sbp=target_sbp)\n        # print(\n        #     f\"**[{rank}] input: {input.oneflow.shape} {input.oneflow.placement} {input.oneflow.sbp}\"\n        # )\n        # print(\n        #     f\"**[{rank}] target: {target.oneflow.shape} {target.oneflow.placement} {target.oneflow.sbp}\"\n        # )\n        if has_weight:\n            # print(f\"**[{rank}] weight: {weight.oneflow.numpy()}\")\n            weight = weight.to_global(placement=placement, sbp=weight_sbp)\n\n    # reduction = oneof(\"none\", \"sum\", \"mean\")\n    reduction = (\n        \"none\"  # Temporarily skip the test of \"sum\" and \"mean\" because of unknown error\n    )\n    if has_weight:\n        nll = torch.nn.NLLLoss(weight=weight, reduction=reduction)\n    else:\n        nll = torch.nn.NLLLoss(reduction=reduction)\n    return nll(input, target)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass 
NLLLossTestCase(flow.unittest.TestCase):\n    def test_local(test_case):\n        _test_nll_loss(test_case)\n\n    def test_weighted(test_case):\n        _test_nll_loss(test_case, has_weight=True)\n\n\n@flow.unittest.skip_unless_1n2d()\nclass ParallelNLLLossTestCase(flow.unittest.TestCase):\n    @globaltest\n    def test_data_parallel(test_case):\n        _test_nll_loss(test_case, split_batch_dim=True)\n\n    @globaltest\n    def test_data_parallel_weighted(test_case):\n        _test_nll_loss(test_case, has_weight=True, split_batch_dim=True)\n\n    @globaltest\n    def test_model_parallel(test_case):\n        _test_nll_loss(test_case, split_class_dim=True)\n\n    @globaltest\n    def test_model_parallel_weighted(test_case):\n        _test_nll_loss(test_case, has_weight=True, split_class_dim=True)\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TwoDParallelNLLLossTestCase(flow.unittest.TestCase):\n    @globaltest\n    def test_2d_parallel(test_case):\n        _test_nll_loss(test_case, split_batch_dim=True, split_class_dim=True)\n\n    @globaltest\n    def test_2d_parallel_weighted(test_case):\n        _test_nll_loss(\n            test_case, has_weight=True, split_batch_dim=True, split_class_dim=True\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_nms.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef box_area(boxes):\n    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])\n\n\ndef _box_inter_union_np(boxes1, boxes2):\n    area1 = box_area(boxes1)\n    area2 = box_area(boxes2)\n\n    lt = np.maximum(boxes1[:, np.newaxis, :2], boxes2[:, :2])\n    rb = np.minimum(boxes1[:, np.newaxis, 2:], boxes2[:, 2:])\n\n    wh = np.clip(rb - lt, a_min=0, a_max=np.inf)\n    inter = wh[:, :, 0] * wh[:, :, 1]\n\n    union = area1[:, np.newaxis] + area2 - inter\n\n    return inter, union\n\n\ndef box_iou_np(boxes1, boxes2):\n    inter, union = _box_inter_union_np(boxes1, boxes2)\n    iou = inter / union\n    return iou\n\n\ndef nms_np(boxes, scores, iou_threshold):\n    picked = []\n    indexes = np.argsort(-scores)\n    while len(indexes) > 0:\n        current = indexes[0]\n        picked.append(current.item())\n        if len(indexes) == 1:\n            break\n        current_box = boxes[current, :]\n        indexes = indexes[1:]\n        rest_boxes = boxes[indexes, :]\n        iou = np.squeeze(box_iou_np(rest_boxes, current_box[np.newaxis]), axis=1)\n        indexes = indexes[iou <= iou_threshold]\n\n    return np.asarray(picked)\n\n\ndef create_tensors_with_iou(N, iou_thresh):\n    boxes = np.random.rand(N, 4) * 100\n    boxes[:, 2:] += boxes[:, :2]\n    boxes[-1, :] = boxes[0, :]\n    x0, y0, x1, y1 = boxes[-1].tolist()\n    iou_thresh += 1e-5\n    boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh\n    # Avoid score lists have the same score which will\n    # result in an unstable sort.\n    scores = np.random.choice(N, N, replace=False)\n    return boxes, scores\n\n\ndef _test_nms(test_case, device):\n    iou = 0.5\n    boxes, scores = create_tensors_with_iou(1000, iou)\n    boxes = flow.tensor(boxes, dtype=flow.float32, device=flow.device(device))\n    scores = flow.tensor(scores, dtype=flow.float32, device=flow.device(device))\n    keep_np = nms_np(boxes.numpy(), scores.numpy(), iou)\n    keep = flow.nms(boxes, scores, iou)\n    test_case.assertTrue(np.allclose(keep.numpy(), keep_np))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNMS(flow.unittest.TestCase):\n    def test_nms(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_nms]\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_noncontiguous_binary_op.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_op(test_case, x, y, inplace):\n    ref1 = x + y\n    out1 = flow._C.noncontiguous_binary_op(x, y, op=\"add\", inplace=inplace)\n    test_case.assertTrue(np.allclose(ref1.numpy(), out1.numpy(), rtol=1e-5, atol=1e-5))\n\n    ref2 = x - y\n    out2 = flow._C.noncontiguous_binary_op(x, y, op=\"sub\", inplace=inplace)\n    test_case.assertTrue(np.allclose(ref2.numpy(), out2.numpy(), rtol=1e-5, atol=1e-5))\n\n    ref3 = x * y\n    out3 = flow._C.noncontiguous_binary_op(x, y, op=\"mul\", inplace=inplace)\n    test_case.assertTrue(np.allclose(ref3.numpy(), out3.numpy(), rtol=1e-5, atol=1e-5))\n\n    y = y.abs() + 1e-3  # incase zero\n    ref4 = x / y\n    out4 = flow._C.noncontiguous_binary_op(x, y, op=\"div\", inplace=inplace)\n    print(np.abs(ref4 - out4).max())\n    test_case.assertTrue(np.allclose(ref4.numpy(), out4.numpy(), rtol=1e-3, atol=1e-3))\n\n\ndef _test_noncontiguous_binary_op(test_case, dtype, pack_size, ndims, inplace):\n    shape = []\n    for _ in range(ndims - 1):\n        if np.random.uniform(-1, 1) > 0:\n            shape.append(1 << np.random.randint(4, 7))\n        else:\n            shape.append(np.random.randint(20, 100))\n    shape.append(1 << np.random.randint(3, 7) + pack_size)\n    # case 1\n    x = flow.randn(*shape, requires_grad=True).cuda().to(dtype)\n    y = flow.randn(*shape, requires_grad=True).cuda().to(dtype)\n    d1, d2 = np.random.choice(ndims, 2, replace=False)\n    x1 = x.transpose(d1, d2)\n    y1 = y.transpose(d1, d2)\n    _test_op(test_case, x1, y1, inplace)\n\n    # case 2\n    y2 = flow.randn(*shape, requires_grad=True).cuda().to(dtype)\n    shape[d1], shape[d2] = shape[d2], shape[d1]\n    x = flow.randn(*shape, requires_grad=True).cuda().to(dtype)\n    x2 = x.transpose(d1, d2)\n    _test_op(test_case, x2, y2, inplace)\n\n\n@unittest.skipIf(True, \"skip test for noncontiguous_binary_op.\")\n@flow.unittest.skip_unless_1n1d()\nclass TestNonContiguousBinaryOp(flow.unittest.TestCase):\n    def test_noncontiguous_binary_op(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fn\"] = [_test_noncontiguous_binary_op]\n        arg_dict[\"dtype\"] = [flow.float16, flow.float32]\n        arg_dict[\"pack_size\"] = [1, 2, 4]\n        arg_dict[\"ndims\"] = [2, 3, 4]\n        arg_dict[\"inplace\"] = [True, False]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_nonzero.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef np_nonzero(input, as_tuple):\n    if as_tuple:\n        return np.nonzero(input)\n    else:\n        return np.transpose(np.nonzero(input))\n\n\ndef _test_nonzero(test_case, shape, as_tuple, device):\n    np_input = np.random.randn(*shape)\n    input = flow.tensor(np_input, dtype=flow.float32, device=flow.device(device))\n    of_out = flow.nonzero(input, as_tuple)\n    np_out = np_nonzero(np_input, as_tuple)\n    if as_tuple:\n        test_case.assertTrue(\n            np.allclose(tuple(x.numpy() for x in of_out), np_out, 0.0001, 0.0001)\n        )\n    else:\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNonzero(flow.unittest.TestCase):\n    def test_nonzero(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_nonzero]\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6), (2, 3, 0, 4)]\n        arg_dict[\"as_tuple\"] = [True, False]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    # Not check graph because of one reason:\n    # Reason 1, lazy tensor cannot call .numpy(). tensor.numpy() is not allowed to called in nn.Graph.build(*args) or called by lazy tensor.\n    # Please refer to File \"python/oneflow/nn/modules/nonzero.py\", line 29, in nonzero_op.\n    @autotest(auto_backward=False, check_graph=\"ValidatedFalse\")\n    def test_nonzero_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(2, 5).to(int)).to(device)\n        y = torch.nonzero(x)\n        return y\n\n    # Not check graph because of one reason:\n    # Reason 1, lazy tensor cannot call .numpy(). tensor.numpy() is not allowed to called in nn.Graph.build(*args) or called by lazy tensor.\n    # Please refer to File \"python/oneflow/nn/modules/nonzero.py\", line 29, in nonzero_op.\n    @autotest(auto_backward=False, check_graph=\"ValidatedFalse\")\n    def test_nonzero_bool_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(2, 5).to(int)).to(device=device, dtype=torch.bool)\n        y = torch.nonzero(x)\n        return y\n\n    # Not check graph because of one reason:\n    # Reason 1, lazy tensor cannot call .numpy(). 
tensor.numpy() is not allowed to called in nn.Graph.build(*args) or called by lazy tensor.\n    @autotest(auto_backward=False, check_graph=\"ValidatedFalse\")\n    def test_half_nonzero_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(2, 5).to(int)).to(\n            device=device, dtype=torch.float16\n        )\n        y = torch.nonzero(x)\n        return y\n\n    # Not check graph because of one reason:\n    # Reason 1, lazy tensor cannot call .numpy(). tensor.numpy() is not allowed to called in nn.Graph.build(*args) or called by lazy tensor.\n    # Please refer to File \"python/oneflow/nn/modules/nonzero.py\", line 29, in nonzero_op.\n    @autotest(auto_backward=False, check_graph=\"ValidatedFalse\")\n    def test_nonzero_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.nonzero(x)\n        return y\n\n    # Not check graph because of one reason:\n    # Reason 1, lazy tensor cannot call .numpy(). tensor.numpy() is not allowed to called in nn.Graph.build(*args) or called by lazy tensor.\n    # Please refer to File \"python/oneflow/nn/modules/nonzero.py\", line 29, in nonzero_op.\n    @autotest(auto_backward=False, check_graph=\"ValidatedFalse\")\n    def test_nonzero_tuple_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(2, 5).to(int)).to(device)\n        y = torch.nonzero(x, as_tuple=True)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_norm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _np_vector_norm_backward(x, ord=2, dim=None):\n    re = np.zeros_like(x)\n    if isinstance(ord, int) and isinstance(dim, int):\n        if ord == 0:\n            return re\n        else:\n            temp = np.sum(np.abs(x ** ord), dim) ** (1.0 / ord - 1)\n            re = np.where(x ** ord < 0, -temp, temp) * x ** (ord - 1)\n    elif dim == None and x.ndim == 1:\n        if ord == 0:\n            return re\n        elif ord == float(\"inf\"):\n            max_ind = np.argmax(np.abs(x))\n            re[max_ind] += 1 if x[max_ind] != 0 else 0\n            re = np.where(x < 0, -re, re)\n        elif ord == float(\"-inf\"):\n            min_ind = np.argmin(np.abs(x))\n            re[min_ind] += 1 if x[min_ind] != 0 else 0\n            re = np.where(x < 0, -re, re)\n        else:\n            temp = np.sum(np.abs(x ** ord)) ** (1.0 / ord - 1)\n            re = np.where(x ** ord < 0, -temp, temp) * x ** (ord - 1)\n    elif (\n        isinstance(ord, float)\n        and isinstance(dim, int)\n        and (ord in [float(\"inf\"), float(\"-inf\")])\n    ):\n        if ord == float(\"inf\"):\n            max_ind = np.argmax(np.abs(x), dim)\n            index = (\n                [(i, max_ind[i]) for i in range(len(max_ind))]\n                if dim == 1\n                else [(max_ind[i], i) for i in range(len(max_ind))]\n            )\n            print(index)\n            for j in index:\n                re[j] += 1 if x[j] != 0 else 0\n            re = np.where(x < 0, -re, re)\n        else:\n            min_ind = np.argmin(np.abs(x), dim)\n            index = (\n                [(i, min_ind[i]) for i in range(len(min_ind))]\n                if dim == 1\n                else [(min_ind[i], i) for i in range(len(min_ind))]\n            )\n            for j in index:\n                re[j] += 1 if x[j] != 0 else 0\n            re = np.where(x < 0, -re, re)\n    return re\n\n\ndef _np_matrix_norm_backward(x, ord=\"fro\"):\n    re = np.zeros_like(x)\n    if isinstance(ord, int):\n        if ord == 1:\n            max_ind = np.argmax(np.sum(np.abs(x), 0))\n            index = [(i, max_ind) for i in range(x.shape[0])]\n            for j in index:\n                re[j] += 1 if x[j] != 0 else 0\n            re = np.where(x < 0, -re, re)\n        elif ord == -1:\n            min_ind = np.argmin(np.sum(np.abs(x), 0))\n            index = [(i, min_ind) for i in range(x.shape[0])]\n            for j in index:\n                re[j] += 1 if x[j] != 0 else 0\n            re = np.where(x < 0, -re, re)\n    elif ord == \"fro\":\n        re = np.sum(x ** 2) ** (-0.5) * x\n    elif isinstance(ord, float) and ord in [float(\"inf\"), float(\"-inf\")]:\n        
if ord == float(\"inf\"):\n            max_ind = np.argmax(np.sum(np.abs(x), 1))\n            index = [(max_ind, i) for i in range(x.shape[1])]\n            for j in index:\n                re[j] += 1 if x[j] != 0 else 0\n            re = np.where(x < 0, -re, re)\n        else:\n            min_ind = np.argmin(np.sum(np.abs(x), 1))\n            index = [(min_ind, i) for i in range(x.shape[1])]\n            for j in index:\n                re[j] += 1 if x[j] != 0 else 0\n            re = np.where(x < 0, -re, re)\n    return re\n\n\ndef _test_norm_1d(test_case, device):\n    input = flow.tensor(\n        np.random.randn(10), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out_1 = flow.linalg.norm(input)\n    of_out_2 = flow.linalg.norm(input, ord=0)\n    of_out_3 = flow.linalg.norm(input, ord=3)\n    of_out_4 = flow.linalg.norm(input, ord=float(\"inf\"))\n    of_out_5 = flow.linalg.norm(input, ord=-float(\"inf\"))\n    np_out_1 = np.linalg.norm(input.numpy())\n    np_out_2 = np.linalg.norm(input.numpy(), ord=0)\n    np_out_3 = np.linalg.norm(input.numpy(), ord=3)\n    np_out_4 = np.linalg.norm(input.numpy(), ord=float(\"inf\"))\n    np_out_5 = np.linalg.norm(input.numpy(), ord=-float(\"inf\"))\n    test_case.assertTrue(np.allclose(of_out_1.numpy(), np_out_1, 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(of_out_2.numpy(), np_out_2, 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(of_out_3.numpy(), np_out_3, 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(of_out_4.numpy(), np_out_4, 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(of_out_5.numpy(), np_out_5, 1e-05, 1e-05))\n\n\ndef _test_norm_2d(test_case, device):\n    input = flow.tensor(\n        np.random.randn(5, 4), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out_1 = flow.linalg.norm(input)\n    of_out_2 = flow.linalg.norm(input, dim=0)\n    of_out_3 = flow.linalg.norm(input, dim=1, keepdim=True)\n    of_out_4 = flow.linalg.norm(input, ord=1, dim=0)\n    of_out_5 = flow.linalg.norm(input, ord=-1, dim=1, keepdim=True)\n    np_out_1 = np.linalg.norm(input.numpy())\n    np_out_2 = np.linalg.norm(input.numpy(), axis=0)\n    np_out_3 = np.linalg.norm(input.numpy(), axis=1, keepdims=True)\n    np_out_4 = np.linalg.norm(input.numpy(), ord=1, axis=0)\n    np_out_5 = np.linalg.norm(input.numpy(), ord=-1, axis=1, keepdims=True)\n    test_case.assertTrue(np.allclose(of_out_1.numpy(), np_out_1, 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(of_out_2.numpy(), np_out_2, 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(of_out_3.numpy(), np_out_3, 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(of_out_4.numpy(), np_out_4, 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(of_out_5.numpy(), np_out_5, 1e-05, 1e-05))\n\n\ndef _test_norm_Nd(test_case, device):\n    input1 = flow.tensor(\n        np.random.randn(3, 4, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    input2 = flow.tensor(\n        np.random.randn(3, 4, 3, 5), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out_1 = flow.linalg.norm(input1)\n    of_out_2 = flow.linalg.norm(input1, dim=(0, 1))\n    of_out_3 = flow.linalg.norm(input2, dim=(0, 2))\n    np_out_1 = np.linalg.norm(input1.numpy())\n    np_out_2 = np.linalg.norm(input1.numpy(), axis=(0, 1))\n    np_out_3 = np.linalg.norm(input2.numpy(), axis=(0, 2))\n    test_case.assertTrue(np.allclose(of_out_1.numpy(), np_out_1, 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(of_out_2.numpy(), np_out_2, 1e-05, 1e-05))\n    
test_case.assertTrue(np.allclose(of_out_3.numpy(), np_out_3, 1e-05, 1e-05))\n\n\ndef _test_fro_order_norm_backward(test_case, device):\n    input = flow.tensor(\n        np.random.randn(5, 4),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out = flow.linalg.norm(input)\n    of_out.backward()\n    np_out_grad = _np_matrix_norm_backward(input.numpy())\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_out_grad, 1e-05, 1e-05))\n\n\ndef _test_1d_inf_order_norm_backward(test_case, device):\n    for ord in [float(\"inf\"), -float(\"inf\")]:\n        input = flow.tensor(\n            np.random.randn(5),\n            dtype=flow.float32,\n            device=flow.device(device),\n            requires_grad=True,\n        )\n        of_out = flow.linalg.norm(input, ord=ord)\n        of_out.backward()\n        np_out_grad = _np_vector_norm_backward(input.numpy(), ord=ord)\n        test_case.assertTrue(np.allclose(input.grad.numpy(), np_out_grad, 1e-05, 1e-05))\n\n\ndef _test_2d_inf_order_norm_backward(test_case, device):\n    for ord in [float(\"inf\"), -float(\"inf\")]:\n        input = flow.tensor(\n            np.random.randn(5, 4),\n            dtype=flow.float32,\n            device=flow.device(device),\n            requires_grad=True,\n        )\n        of_out = flow.linalg.norm(input, ord=ord)\n        of_out.backward()\n        np_out_grad = _np_matrix_norm_backward(input.numpy(), ord=ord)\n        test_case.assertTrue(np.allclose(input.grad.numpy(), np_out_grad, 1e-05, 1e-05))\n\n\ndef _test_1d_digits_order_norm_backward(test_case, device):\n    for ord in [1, -1, 2, -2, 5]:\n        input = flow.tensor(\n            np.random.randn(5),\n            dtype=flow.float32,\n            device=flow.device(device),\n            requires_grad=True,\n        )\n        of_out = flow.linalg.norm(input, ord=ord)\n        of_out.backward()\n        np_out_grad = _np_vector_norm_backward(input.numpy(), ord=ord)\n        test_case.assertTrue(np.allclose(input.grad.numpy(), np_out_grad, 1e-05, 1e-05))\n\n\ndef _test_2d_digits_order_norm_backward(test_case, device):\n    for ord in [1, -1]:\n        input = flow.tensor(\n            np.random.randn(4, 5),\n            dtype=flow.float32,\n            device=flow.device(device),\n            requires_grad=True,\n        )\n        of_out = flow.linalg.norm(input, ord=ord)\n        of_out.backward()\n        np_out_grad = _np_matrix_norm_backward(input.numpy(), ord=ord)\n        test_case.assertTrue(np.allclose(input.grad.numpy(), np_out_grad, 1e-05, 1e-05))\n\n\ndef _test_linalg_norm_shape_not_match(test_case, device):\n    x = flow.randn(1, 3, 1, 5, 2)\n    x = x.to(device)\n    y = flow.linalg.norm(x, keepdim=True)\n    test_case.assertEqual(y.size(), (1, 1, 1, 1, 1))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNormModule(flow.unittest.TestCase):\n    def test_norm(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"fun\"] = [\n            _test_norm_1d,\n            _test_norm_2d,\n            _test_norm_Nd,\n            _test_fro_order_norm_backward,\n            _test_1d_inf_order_norm_backward,\n            _test_2d_inf_order_norm_backward,\n            _test_1d_digits_order_norm_backward,\n            _test_2d_digits_order_norm_backward,\n            _test_linalg_norm_shape_not_match,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5)\n    def 
test_no_dim_no_ord_norm_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor().to(device)\n        keepdim = random_bool()\n        m = torch.linalg.norm(input, keepdim=keepdim)\n        n = torch.norm(input, keepdim=keepdim)\n        return m, n\n\n    @autotest(n=5)\n    def test_one_dim_norm_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=4).to(device)\n        dim = random(low=0, high=4).to(int)\n        k = random().to(float)\n        ord = oneof(float(\"inf\"), float(\"-inf\"), k, None)\n        keepdim = random_bool()\n        m = torch.linalg.norm(input, ord, dim, keepdim)\n        n = torch.norm(input, ord, dim, keepdim)\n        return m, n\n\n    @autotest(n=5)\n    def test_no_dim_one_shape_norm_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=1).to(device)\n        k = random().to(float)\n        ord = oneof(float(\"inf\"), float(\"-inf\"), k)\n        keepdim = random_bool()\n        m = torch.linalg.norm(input, ord=ord, keepdim=keepdim)\n        n = torch.norm(input, p=ord, keepdim=keepdim)\n        return m, n\n\n    @autotest(n=5)\n    def test_no_dim_two_shape_norm_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=2).to(device)\n        ord = oneof(float(\"inf\"), float(\"-inf\"), \"fro\", 1, -1)\n        keepdim = random().to(bool)\n        m = torch.linalg.norm(input, ord=ord, keepdim=keepdim)\n        return m\n\n    @autotest(n=5)\n    def test_tuple_dim_norm_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=2).to(device)\n        dim = oneof((-2, -1), (0, 1), (-1, 0))\n        ord = oneof(float(\"inf\"), float(\"-inf\"), \"fro\", 1, -1, None)\n        keepdim = random().to(bool)\n        m = torch.linalg.norm(input, ord=ord, dim=dim, keepdim=keepdim)\n        return m\n\n    @autotest(n=5)\n    def test_vector_norm_only_zero_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=2).to(device)\n        dim = oneof((-2, -1), (0, 1), (-1, 0))\n        keepdim = random().to(bool)\n        m = torch.linalg.vector_norm(input, ord=0, dim=dim, keepdim=keepdim)\n        return m\n\n    @autotest(n=5)\n    def test_ord_random_data(test_case):\n        device = random_device()\n        ndim = random(1, 3).to(int)\n        input = random_tensor(ndim).to(device)\n        p1 = random(-5, -1).to(int).value()\n        p2 = random(2, 6).to(int).value()\n        m = input.norm(p1)\n        n = input.norm(p2)\n        return m, n\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_normalize.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList, type_name_to_flow_type\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_functional_normalize_double_dtype(test_case, device, dtype):\n    dtype = type_name_to_flow_type[dtype]\n    x = flow.ones(2, 2, dtype=dtype).to(device)\n    y = flow.nn.functional.normalize(x, p=2, dim=0)\n    test_case.assertEqual((2, 2), y.shape)\n    out = np.array(\n        [\n            [0.7071067690849304, 0.7071067690849304],\n            [0.7071067690849304, 0.7071067690849304],\n        ]\n    )\n    test_case.assertTrue(np.allclose(y.numpy().tolist(), out, 1e-05, 1e-05))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFunctionalNormalize(flow.unittest.TestCase):\n    def test_functional_normalize_naive(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"fun\"] = [_test_functional_normalize_double_dtype]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"dtype\"] = [\"float32\", \"double\"]\n\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5)\n    def test_functional_normalize(test_case):\n        device = random_device()\n        ndim = random(low=2)\n\n        shape = list(random_tensor(ndim=ndim).oneflow.shape)\n        dim = random(low=0, high=ndim).to(int).value()\n        shape[dim] = random(low=2, high=8).to(int).value()\n        shape = tuple(shape)\n\n        x = random_tensor(len(shape), *shape).to(device)\n        y = torch.nn.functional.normalize(x, oneof(2, 3, 4), dim, 1e-12)\n\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_ofrecord_reader.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport os\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass OFRecordDataLoader(flow.nn.Module):\n    def __init__(self, batch_size, device=None, placement=None, sbp=None):\n        super().__init__()\n        # don't shuffle, for comparing\n        shuffle = False\n\n        self.ofrecord_reader = flow.nn.OFRecordReader(\n            flow.unittest.dataset_dir(\"imagenet_227/train/32\"),\n            batch_size=batch_size,\n            data_part_num=2,\n            random_shuffle=shuffle,\n            shuffle_after_epoch=shuffle,\n            device=device,\n            placement=placement,\n            sbp=sbp,\n        )\n\n        self.record_label_decoder = flow.nn.OFRecordRawDecoder(\n            \"class/label\", shape=(), dtype=flow.int32\n        )\n\n        self.record_image_decoder = flow.nn.OFRecordImageDecoder(\n            \"encoded\", color_space=\"RGB\"\n        )\n\n        self.resize = flow.nn.image.Resize(target_size=[227, 227], dtype=flow.float32)\n\n    def forward(self):\n        record = self.ofrecord_reader()\n        label = self.record_label_decoder(record)\n        image_raw_buffer = self.record_image_decoder(record)\n        image = self.resize(image_raw_buffer)[0]\n        return image, label\n\n\nclass DataLoaderGraph(flow.nn.Graph):\n    def __init__(self, loader):\n        super().__init__()\n        self.loader_ = loader\n\n    def build(self):\n        return self.loader_()\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@unittest.skipUnless(os.path.exists(flow.unittest.dataset_dir(\"imagenet_227\")), \"\")\n@flow.unittest.skip_unless_1n2d()\nclass DistributedOFRecordReaderTestCase(oneflow.unittest.TestCase):\n    def test(test_case):\n        rank = flow.env.get_rank()\n        # print(f\"DistributedOFRecordReaderTestCase.test on rank {rank} {os.getpid()}\")\n\n        eager_ofrecord_loader = OFRecordDataLoader(\n            batch_size=2, device=flow.device(\"cpu\", rank)\n        )\n\n        lazy_global_loader = OFRecordDataLoader(\n            batch_size=4,\n            placement=flow.placement(\"cpu\", ranks=[0, 1]),\n            sbp=[flow.sbp.split(0)],\n        )\n        loader_graph = DataLoaderGraph(lazy_global_loader)\n\n        iteration = 2\n        for i in range(iteration):\n            image, label = eager_ofrecord_loader()\n            # print(\n            #     f\"rank {rank} image: {image.shape}, {image.dtype}, device: {image.device}\"\n            #     f\"\\n{image.numpy().mean()}\"\n            # )\n            # print(\n            #     f\"rank {rank} label: {label.shape}, {label.dtype}, device: {label.device}\"\n            #     f\"\\n{label.numpy()}\"\n            # )\n\n            g_image, g_label = loader_graph()\n            # print(\n            #     f\"rank {rank} graph output image: {g_image.shape}, {g_image.dtype}, 
placement: {g_image.placement}\"\n            #     f\"\\n{g_image.to_local().numpy().mean()}\"\n            # )\n            # print(\n            #     f\"rank {rank} graph output label: {g_label.shape}, {g_label.dtype}, placement: {g_image.placement}\"\n            #     f\"\\n{g_label.to_local().numpy()}\"\n            # )\n\n            # print(f\"{'-' * 20} rank {rank} iter {i} complete {'-' * 20}\")\n            test_case.assertTrue(np.allclose(image.numpy(), g_image.to_local().numpy()))\n            test_case.assertTrue(np.allclose(label.numpy(), g_label.to_local().numpy()))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_one_embedding_adagrad.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport tempfile\n\nimport os\n\n# dynamic memory allocation can't be tested in unittest\nos.environ[\"ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION\"] = \"0\"\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgDict\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\nfrom oneflow.nn.parameter import Parameter\n\n\ndef compare_with_numpy_adagrad(\n    test_case, weight_decay, lr_decay, scale, learning_rate, train_iters,\n):\n\n    num_rows = 500\n    embedding_size = 128\n    model_shape = (num_rows, embedding_size)\n    line_size = embedding_size * 2\n\n    num_valid_seq = np.random.randint(1, num_rows, (train_iters))\n    skip_if_seq = [np.random.randint(2) for i in range(train_iters)]\n\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=model_shape).astype(np.float32))\n\n    init_value = np.random.uniform(size=(num_rows, line_size)).astype(np.float32)\n\n    down_scale_by = 10\n    epsilon = 1e-5\n\n    class TestGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(\n            self,\n            ids,\n            unique_embeddings,\n            embedding_grad,\n            lr_tensor,\n            down_scale_by_tensor,\n            skip_if,\n            train_step,\n        ):\n            # add id shuffle to set num_unique in op, and use it in update\n            (_, _, num_valid, _, _, _,) = flow._C.one_embedding_id_shuffle(\n                ids, table_ids=None, num_tables=1, embedding_name=\"\"\n            )\n            return flow._C.one_embedding_adagrad_update(\n                num_valid,\n                unique_embeddings,\n                embedding_grad,\n                lr_tensor,\n                down_scale_by_tensor,\n                skip_if,\n                train_step,\n                0,\n                0.0,\n                scale,\n                weight_decay,\n                lr_decay,\n                epsilon,\n                line_size,\n                embedding_size,\n                \"\",\n            )\n\n    graph = TestGraph()\n\n    def adagrad_by_oneflow():\n        unique_embeddings_tensor = flow.tensor(init_value, requires_grad=False).to(\n            \"cuda\"\n        )\n        lr_tensor = flow.tensor(\n            np.array(learning_rate).reshape(1,).astype(np.float32)\n        ).to(\"cuda\")\n        down_scale_by_tensor = flow.tensor(\n            np.array(down_scale_by).reshape(1,).astype(np.float32)\n        ).to(\"cuda\")\n\n        def train_one_iter(ids, unique_embeddings, embedding_grad, skip_if, train_step):\n            return graph(\n                ids,\n                unique_embeddings,\n                embedding_grad,\n                lr_tensor,\n                down_scale_by_tensor,\n                
skip_if,\n                train_step,\n            )\n\n        for i in range(1, train_iters):\n            np_ids = np.zeros(num_rows)\n            np_ids[0 : num_valid_seq[i]] = np.arange(num_valid_seq[i])\n            # construct num_valid unique ids so that id_shuffle's num_unique output matches the grad input\n            ids = flow.tensor(np_ids.astype(np.int32)).to(\"cuda\")\n            grad_tensor = flow.tensor(random_grad_seq[i]).to(\"cuda\")\n            skip_if_tensor = flow.tensor(\n                np.array(skip_if_seq[i]).reshape(1,).astype(np.int64)\n            ).to(\"cuda\")\n            step_tensor = flow.tensor(np.array(i).reshape(1,).astype(np.int64)).to(\n                \"cuda\"\n            )\n            updated_tensor = train_one_iter(\n                ids, unique_embeddings_tensor, grad_tensor, skip_if_tensor, step_tensor,\n            )\n            unique_embeddings_tensor[0 : num_valid_seq[i]] = updated_tensor[\n                0 : num_valid_seq[i]\n            ]\n        return unique_embeddings_tensor\n\n    def adagrad_by_numpy():\n        x = init_value[:, 0:embedding_size]\n        st = init_value[:, embedding_size:]\n\n        def train_one_iter(iter, num_valid, grad, model, state):\n            grad[0:num_valid] = grad[0:num_valid] * (scale / down_scale_by)\n            lr = learning_rate / (1 + iter * lr_decay)\n            state[0:num_valid] = (\n                state[0:num_valid] + grad[0:num_valid] * grad[0:num_valid]\n            )\n            model[0:num_valid] = (\n                model[0:num_valid]\n                - lr / (np.sqrt(state[0:num_valid]) + epsilon) * grad[0:num_valid]\n                - lr * weight_decay * model[0:num_valid]\n            )\n            return (model, state)\n\n        for i in range(1, train_iters):\n            if skip_if_seq[i] > 0:\n                pass\n            else:\n                (x, st) = train_one_iter(\n                    i, int(num_valid_seq[i]), random_grad_seq[i], x, st\n                )\n\n        return x, st\n\n    oneflow_res = adagrad_by_oneflow().numpy()\n    of_model = oneflow_res[:, 0:embedding_size]\n    of_sum = oneflow_res[:, embedding_size:]\n    np_model, np_sum = adagrad_by_numpy()\n    test_case.assertTrue(\n        np.allclose(of_model.flatten(), np_model.flatten(), rtol=0.001, atol=0.001)\n    )\n    test_case.assertTrue(\n        np.allclose(of_sum.flatten(), np_sum.flatten(), rtol=0.001, atol=0.001)\n    )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestOptimizers(flow.unittest.TestCase):\n    def test_one_embedding_adagrad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"weight_decay\"] = [0, 0.1]\n        arg_dict[\"lr_decay\"] = [0, 0.1]\n        arg_dict[\"scale\"] = [1, 0.1]\n        arg_dict[\"learning_rate\"] = [0.3, 1.5]\n        arg_dict[\"train_iters\"] = [10]\n        for arg in GenArgDict(arg_dict):\n            compare_with_numpy_adagrad(test_case, **arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_one_embedding_adam.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport tempfile\n\nimport os\n\n# dynamic memory allocation can't be tested in unittest\nos.environ[\"ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION\"] = \"0\"\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgDict\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\nfrom oneflow.nn.parameter import Parameter\n\n\ndef compare_with_numpy_adam(\n    test_case,\n    weight_decay,\n    scale,\n    learning_rate,\n    train_iters,\n    do_bias_correction,\n    beta1,\n    beta2,\n    use_optional_tensor,\n):\n\n    num_rows = 500\n    embedding_size = 128\n    model_shape = (num_rows, embedding_size)\n    line_size = embedding_size * 3\n\n    num_valid_seq = np.random.randint(1, num_rows, (train_iters))\n    skip_if_seq = [np.random.randint(2) for i in range(train_iters)]\n\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=model_shape).astype(np.float32))\n\n    init_value = np.random.uniform(size=(num_rows, line_size)).astype(np.float32)\n\n    down_scale_by = 10\n\n    \"\"\"\n    In OneFlow's optimizer, learning_rate is passed by attr in eager mode, and passed by tensor in lazy mode.\n    in this test, if use_optional_tensor is True, we also pass lr_tensor/down_scale_by_tensor/skip_if tensor for unittest.\n    if use_optional_tensor is False, we only pass lr by attr, and not have down_scale_by_tensor/skip_if, so mul down_scale_by to scale and skip skip_if's test.\n    \"\"\"\n    bias_correction1_val = 1.0\n    bias_correction2_val = 1.0\n    if use_optional_tensor:\n        scale_val = scale\n    else:\n        # if pass as attr instead of tensor, mul down_scale_by to scale_value\n        scale_val = scale / down_scale_by\n    epsilon = 1e-5\n\n    class TestGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(\n            self,\n            ids,\n            unique_embeddings,\n            embedding_grad,\n            lr_tensor,\n            down_scale_by_tensor,\n            skip_if,\n            bias_correction1,\n            bias_correction2,\n        ):\n            # add id shuffle to set num_unique in op, and use it in update\n            (_, _, num_valid, _, _, _,) = flow._C.one_embedding_id_shuffle(\n                ids, table_ids=None, num_tables=1, embedding_name=\"\"\n            )\n            return flow._C.one_embedding_adam_update(\n                num_valid,\n                unique_embeddings,\n                embedding_grad,\n                lr_tensor,\n                down_scale_by_tensor,\n                skip_if,\n                bias_correction1,\n                bias_correction2,\n                learning_rate,\n                scale_val,\n                weight_decay,\n                beta1,\n                beta2,\n                
bias_correction1_val,\n                bias_correction2_val,\n                epsilon,\n                do_bias_correction,\n                line_size,\n                embedding_size,\n                embedding_name=\"\",\n            )\n\n    graph = TestGraph()\n\n    def adam_by_oneflow():\n        unique_embeddings_tensor = flow.tensor(init_value, requires_grad=False).to(\n            \"cuda\"\n        )\n        if use_optional_tensor:\n            lr_tensor = flow.tensor(\n                np.array(learning_rate).reshape(1,).astype(np.float32)\n            ).to(\"cuda\")\n            down_scale_by_tensor = flow.tensor(\n                np.array(down_scale_by).reshape(1,).astype(np.float32)\n            ).to(\"cuda\")\n        else:\n            lr_tensor = None\n            down_scale_by_tensor = None\n\n        def train_one_iter(\n            ids,\n            unique_embeddings,\n            embedding_grad,\n            skip_if,\n            bias_correction1,\n            bias_correction2,\n        ):\n            return graph(\n                ids,\n                unique_embeddings,\n                embedding_grad,\n                lr_tensor,\n                down_scale_by_tensor,\n                skip_if,\n                bias_correction1,\n                bias_correction2,\n            )\n\n        for i in range(1, train_iters):\n            np_ids = np.zeros(num_rows)\n            np_ids[0 : num_valid_seq[i]] = np.arange(num_valid_seq[i])\n            # construct num_valid unique ids so that id_shuffle's num_unique output matches the grad input\n            ids = flow.tensor(np_ids.astype(np.int32)).to(\"cuda\")\n            grad_tensor = flow.tensor(random_grad_seq[i]).to(\"cuda\")\n            if use_optional_tensor:\n                skip_if_tensor = flow.tensor(\n                    np.array(skip_if_seq[i]).reshape(1,).astype(np.int64)\n                ).to(\"cuda\")\n            else:\n                skip_if_tensor = None\n            if do_bias_correction and use_optional_tensor:\n                bias_correction1 = 1.0 - np.power(beta1, i)\n                bias_correction2 = 1.0 - np.power(beta2, i)\n                bias_correction1_tensor = flow.tensor(\n                    np.array(bias_correction1).reshape(1,).astype(np.float32)\n                ).to(\"cuda\")\n                bias_correction2_tensor = flow.tensor(\n                    np.array(bias_correction2).reshape(1,).astype(np.float32)\n                ).to(\"cuda\")\n            else:\n                bias_correction1_tensor = None\n                bias_correction2_tensor = None\n            updated_tensor = train_one_iter(\n                ids,\n                unique_embeddings_tensor,\n                grad_tensor,\n                skip_if_tensor,\n                bias_correction1_tensor,\n                bias_correction2_tensor,\n            )\n            unique_embeddings_tensor[0 : num_valid_seq[i]] = updated_tensor[\n                0 : num_valid_seq[i]\n            ]\n        return unique_embeddings_tensor\n\n    def adam_by_numpy():\n        x = init_value[:, 0:embedding_size]\n        m = init_value[:, embedding_size : 2 * embedding_size]\n        v = init_value[:, 2 * embedding_size : 3 * embedding_size]\n\n        def np_train_one_iter(step, num_valid, grad, model, state_m, state_v):\n            grad[0:num_valid] = grad[0:num_valid] * (scale / down_scale_by)\n\n            bias_correction1 = 1.0\n            bias_correction2 = 1.0\n\n            if do_bias_correction and use_optional_tensor:\n                
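# standard Adam bias correction: m_hat = m / (1 - beta1^step), v_hat = v / (1 - beta2^step)\n                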
bias_correction1 = 1.0 - np.power(beta1, step)\n                bias_correction2 = 1.0 - np.power(beta2, step)\n\n            state_m[0:num_valid] = (\n                beta1 * state_m[0:num_valid] + (1 - beta1) * grad[0:num_valid]\n            )\n            state_v[0:num_valid] = (\n                beta2 * state_v[0:num_valid]\n                + (1 - beta2) * grad[0:num_valid] * grad[0:num_valid]\n            )\n            denom = np.sqrt(state_v[0:num_valid]) / np.sqrt(bias_correction2) + epsilon\n\n            model[0:num_valid] = (\n                model[0:num_valid]\n                - ((learning_rate / bias_correction1) * state_m[0:num_valid] / denom)\n                - learning_rate * weight_decay * model[0:num_valid]\n            )\n            return (model, state_m, state_v)\n\n        for i in range(1, train_iters):  # start from 1: at step 0, bias_correction2 would be 0\n            if skip_if_seq[i] > 0 and use_optional_tensor:\n                pass\n            else:\n                (x, m, v) = np_train_one_iter(\n                    i, int(num_valid_seq[i]), random_grad_seq[i], x, m, v\n                )\n        return x, m, v\n\n    oneflow_res = adam_by_oneflow().numpy()\n    of_model = oneflow_res[:, 0:embedding_size]\n    of_m = oneflow_res[:, embedding_size : 2 * embedding_size]\n    of_v = oneflow_res[:, 2 * embedding_size : 3 * embedding_size]\n    np_model, np_m, np_v = adam_by_numpy()\n    test_case.assertTrue(\n        np.allclose(of_model.flatten(), np_model.flatten(), rtol=0.001, atol=0.001)\n    )\n    test_case.assertTrue(\n        np.allclose(of_m.flatten(), np_m.flatten(), rtol=0.001, atol=0.001)\n    )\n    test_case.assertTrue(\n        np.allclose(of_v.flatten(), np_v.flatten(), rtol=0.001, atol=0.001)\n    )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestOptimizers(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, because it failed 16 times in the past week\")\n    def test_one_embedding_adam(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"weight_decay\"] = [0, 0.1]\n        arg_dict[\"scale\"] = [1, 0.1]\n        arg_dict[\"learning_rate\"] = [1, 1.5]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"do_bias_correction\"] = [True, False]\n        arg_dict[\"beta1\"] = [0.9, 0.8]\n        arg_dict[\"beta2\"] = [0.9, 0.8]\n        arg_dict[\"use_optional_tensor\"] = [True, False]\n\n        for arg in GenArgDict(arg_dict):\n            compare_with_numpy_adam(test_case, **arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_one_embedding_ftrl.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport tempfile\n\nimport os\n\n# dynamic memory allocation can't be tested in unittest\nos.environ[\"ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION\"] = \"0\"\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgDict\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\nfrom oneflow.nn.parameter import Parameter\n\n\ndef compare_with_numpy_ftrl(\n    test_case,\n    weight_decay,\n    lr_power,\n    lambda1,\n    lambda2,\n    beta,\n    scale,\n    learning_rate,\n    train_iters,\n    use_optional_tensor,\n):\n    num_rows = 500\n    embedding_size = 128\n    model_shape = (num_rows, embedding_size)\n    line_size = embedding_size * 3\n\n    num_valid_seq = np.random.randint(1, num_rows, (train_iters))\n    skip_if_seq = [np.random.randint(2) for i in range(train_iters)]\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=model_shape).astype(np.float32))\n\n    init_value = np.random.uniform(size=(num_rows, line_size)).astype(np.float32)\n\n    down_scale_by = 10\n\n    \"\"\"\n    In OneFlow's optimizer, learning_rate is passed by attr in eager mode, and passed by tensor in lazy mode.\n    in this test, if use_optional_tensor is True, we also pass lr_tensor/down_scale_by_tensor/skip_if tensor for unittest.\n    if use_optional_tensor is False, we only pass lr by attr, and not have down_scale_by_tensor/skip_if, so mul down_scale_by to scale and skip skip_if's test.\n    \"\"\"\n    if use_optional_tensor:\n        scale_val = scale\n    else:\n        # if pass as attr instead of tensor, mul down_scale_by to scale_value\n        scale_val = scale / down_scale_by\n\n    class TestGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(\n            self,\n            ids,\n            unique_embeddings,\n            embedding_grad,\n            lr_tensor,\n            down_scale_by_tensor,\n            skip_if,\n        ):\n            # add id shuffle to set num_unique in op, and use it in update\n            (_, _, num_valid, _, _, _,) = flow._C.one_embedding_id_shuffle(\n                ids, table_ids=None, num_tables=1, embedding_name=\"\"\n            )\n            return flow._C.one_embedding_ftrl_update(\n                num_valid,\n                unique_embeddings,\n                embedding_grad,\n                lr_tensor,\n                down_scale_by_tensor,\n                skip_if,\n                learning_rate,\n                scale_val,\n                weight_decay,\n                lr_power,\n                lambda1,\n                lambda2,\n                beta,\n                line_size,\n                embedding_size,\n                embedding_name=\"\",\n            )\n\n    graph = TestGraph()\n\n    def ftrl_by_oneflow():\n        
unique_embeddings_tensor = flow.tensor(init_value, requires_grad=False).to(\n            \"cuda\"\n        )\n        if use_optional_tensor:\n            lr_tensor = flow.tensor(\n                np.array(learning_rate).reshape(1,).astype(np.float32)\n            ).to(\"cuda\")\n            down_scale_by_tensor = flow.tensor(\n                np.array(down_scale_by).reshape(1,).astype(np.float32)\n            ).to(\"cuda\")\n        else:\n            lr_tensor = None\n            down_scale_by_tensor = None\n\n        def train_one_iter(ids, unique_embeddings, embedding_grad, skip_if):\n            return graph(\n                ids,\n                unique_embeddings,\n                embedding_grad,\n                lr_tensor,\n                down_scale_by_tensor,\n                skip_if,\n            )\n\n        for i in range(1, train_iters):\n            np_ids = np.zeros(num_rows)\n            np_ids[0 : num_valid_seq[i]] = np.arange(num_valid_seq[i])\n            # construct num_valid unique ids so that id_shuffle's num_unique output matches the grad input\n            ids = flow.tensor(np_ids.astype(np.int32)).to(\"cuda\")\n            grad_tensor = flow.tensor(random_grad_seq[i]).to(\"cuda\")\n            if use_optional_tensor:\n                skip_if_tensor = flow.tensor(\n                    np.array(skip_if_seq[i]).reshape(1,).astype(np.int64)\n                ).to(\"cuda\")\n            else:\n                skip_if_tensor = None\n\n            updated_tensor = train_one_iter(\n                ids, unique_embeddings_tensor, grad_tensor, skip_if_tensor,\n            )\n            unique_embeddings_tensor[0 : num_valid_seq[i]] = updated_tensor[\n                0 : num_valid_seq[i]\n            ]\n        return unique_embeddings_tensor\n\n    def ftrl_by_numpy():\n        x = init_value[:, 0:embedding_size]\n        accumulate = init_value[:, embedding_size : 2 * embedding_size]\n        z = init_value[:, 2 * embedding_size :]\n\n        def train_one_iter(iter, num_valid, grad, model, accum, z):\n            grad[0:num_valid] = grad[0:num_valid] * (scale / down_scale_by)\n\n            new_accum = accumulate[0:num_valid] + grad[0:num_valid] * grad[0:num_valid]\n\n            sigma = (\n                np.power(new_accum, lr_power)\n                - np.power(accumulate[0:num_valid], lr_power)\n            ) / learning_rate\n\n            new_z_val = z[0:num_valid] + grad[0:num_valid] - sigma * model[0:num_valid]\n\n            # Here weight_decay is AdamW-style decoupled weight decay, not L2 regularization.\n            update_val = (np.sign(new_z_val) * lambda1 - new_z_val) / (\n                (beta + np.power(new_accum, lr_power)) / learning_rate + lambda2\n            ) - learning_rate * weight_decay * model[0:num_valid]\n\n            model[0:num_valid] = np.where(np.abs(new_z_val) < lambda1, 0.0, update_val)\n            accumulate[0:num_valid] = new_accum\n            z[0:num_valid] = new_z_val\n\n            return (model, accumulate, z)\n\n        for i in range(1, train_iters):\n            # when use_optional_tensor is False, skip_if is not passed to the op\n            if skip_if_seq[i] > 0 and use_optional_tensor:\n                pass\n            else:\n                (x, accumulate, z) = train_one_iter(\n                    i, int(num_valid_seq[i]), random_grad_seq[i], x, accumulate, z\n                )\n\n        return x, accumulate, z\n\n    oneflow_res = ftrl_by_oneflow().numpy()\n    of_model = oneflow_res[:, 0:embedding_size]\n    of_accum = oneflow_res[:, embedding_size : 2 * 
embedding_size]\n    of_z = oneflow_res[:, 2 * embedding_size :]\n\n    np_model, np_accum, np_z = ftrl_by_numpy()\n    test_case.assertTrue(\n        np.allclose(of_model.flatten(), np_model.flatten(), rtol=1e-4, atol=1e-4)\n    )\n    test_case.assertTrue(\n        np.allclose(of_accum.flatten(), np_accum.flatten(), rtol=1e-4, atol=1e-4)\n    )\n    test_case.assertTrue(\n        np.allclose(of_z.flatten(), np_z.flatten(), rtol=1e-4, atol=1e-4)\n    )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestOptimizers(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, because it failed 2 times in the past week\")\n    def test_ftrl(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"weight_decay\"] = [\n            0.0\n        ]  # TODO(zzk): currently only weight_decay = 0.0 is supported.\n        arg_dict[\"lr_power\"] = [-0.2, -0.05]\n        arg_dict[\"lambda1\"] = [0.1]\n        arg_dict[\"lambda2\"] = [0.00]\n        arg_dict[\"beta\"] = [1.0]\n        arg_dict[\"scale\"] = [1, 0.1]\n        arg_dict[\"learning_rate\"] = [0.3, 1.5]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"use_optional_tensor\"] = [True, False]\n\n        for arg in GenArgDict(arg_dict):\n            compare_with_numpy_ftrl(test_case, **arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_one_embedding_sgd.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport tempfile\n\nimport os\n\n# dynamic memory allocation can't be tested in unittest\nos.environ[\"ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION\"] = \"0\"\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgDict\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\nfrom oneflow.nn.parameter import Parameter\n\n\ndef compare_with_numpy_sgd(\n    test_case,\n    momentum,\n    weight_decay,\n    scale,\n    learning_rate,\n    train_iters,\n    use_optional_tensor,\n):\n    # if use_optional_tensor, pass lr as tensor to sgd_update, else pass as attr.\n    num_rows = 500\n    embedding_size = 128\n    model_shape = (num_rows, embedding_size)\n    line_size = embedding_size * 2 if momentum > 0 else embedding_size\n\n    num_valid_seq = np.random.randint(1, num_rows, (train_iters))\n    skip_if_seq = [np.random.randint(2) for i in range(train_iters)]\n\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=model_shape).astype(np.float32))\n\n    init_value = np.random.uniform(size=(num_rows, line_size)).astype(np.float32)\n\n    \"\"\"\n    In OneFlow's optimizer, learning_rate is passed by attr in eager mode, and passed by tensor in lazy mode.\n    in this test, if use_optional_tensor is True, we also pass lr_tensor/down_scale_by_tensor/skip_if tensor for unittest.\n    if use_optional_tensor is False, we only pass lr by attr, and not have down_scale_by_tensor/skip_if, so mul down_scale_by to scale and skip skip_if's test.\n    \"\"\"\n    down_scale_by = 10\n    if use_optional_tensor:\n        scale_val = scale\n    else:\n        # if pass as attr instead of tensor, mul down_scale_by to scale_value\n        scale_val = scale / down_scale_by\n\n    class TestGraph(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(\n            self,\n            ids,\n            unique_embeddings,\n            embedding_grad,\n            lr_tensor,\n            down_scale_by_tensor,\n            skip_if,\n        ):\n            # add id shuffle to set num_unique in op, and use it in update\n            (_, _, num_valid, _, _, _,) = flow._C.one_embedding_id_shuffle(\n                ids, table_ids=None, num_tables=1, embedding_name=\"\"\n            )\n            return flow._C.one_embedding_sgd_update(\n                num_valid,\n                unique_embeddings,\n                embedding_grad,\n                lr_tensor,\n                down_scale_by_tensor,\n                skip_if,\n                learning_rate,\n                scale_val,\n                weight_decay,\n                momentum,\n                line_size,\n                embedding_size,\n                embedding_name=\"\",\n            )\n\n    graph = TestGraph()\n\n    def sgd_by_oneflow():\n    
unique_embeddings_tensor = flow.tensor(init_value, requires_grad=False).to(\n            \"cuda\"\n        )\n        if use_optional_tensor:\n            lr_tensor = flow.tensor(\n                np.array(learning_rate).reshape(1,).astype(np.float32)\n            ).to(\"cuda\")\n            down_scale_by_tensor = flow.tensor(\n                np.array((down_scale_by,)).astype(np.float32)\n            ).to(\"cuda\")\n        else:\n            # pass by attr\n            lr_tensor = None\n            down_scale_by_tensor = None\n\n        def train_one_iter(\n            ids,\n            unique_embeddings,\n            embedding_grad,\n            lr_tensor,\n            down_scale_by_tensor,\n            skip_if,\n        ):\n            return graph(\n                ids,\n                unique_embeddings,\n                embedding_grad,\n                lr_tensor,\n                down_scale_by_tensor,\n                skip_if,\n            )\n\n        for i in range(train_iters):\n            np_ids = np.zeros(num_rows)\n            np_ids[0 : num_valid_seq[i]] = np.arange(num_valid_seq[i])\n            # construct num_valid unique ids so that id_shuffle's num_unique output matches the grad input\n            ids = flow.tensor(np_ids.astype(np.int32)).to(\"cuda\")\n            grad_tensor = flow.tensor(random_grad_seq[i]).to(\"cuda\")\n            if use_optional_tensor:\n                skip_if_tensor = flow.tensor(\n                    np.array(skip_if_seq[i]).reshape(1,).astype(np.int64)\n                ).to(\"cuda\")\n            else:\n                skip_if_tensor = None\n            updated_tensor = train_one_iter(\n                ids,\n                unique_embeddings_tensor,\n                grad_tensor,\n                lr_tensor,\n                down_scale_by_tensor,\n                skip_if_tensor,\n            )\n            unique_embeddings_tensor[0 : num_valid_seq[i]] = updated_tensor[\n                0 : num_valid_seq[i]\n            ]\n        return unique_embeddings_tensor\n\n    def sgd_by_numpy():\n        x = init_value[:, 0:embedding_size]\n        vt = init_value[:, embedding_size:]\n\n        def train_one_iter(num_valid, grad, model, state):\n            grad[0:num_valid] = grad[0:num_valid] * (scale / down_scale_by)\n            next_state = (\n                (momentum * state[0:num_valid] + grad[0:num_valid])\n                if momentum > 0\n                else 0\n            )\n            if momentum > 0:\n                state[0:num_valid] = next_state\n                model[0:num_valid] = (\n                    model[0:num_valid]\n                    - learning_rate * next_state\n                    - learning_rate * weight_decay * model[0:num_valid]\n                )\n            else:\n                state[0:num_valid] = 0\n                model[0:num_valid] = (\n                    model[0:num_valid]\n                    - learning_rate * grad[0:num_valid]\n                    - learning_rate * weight_decay * model[0:num_valid]\n                )\n            return (model, state)\n\n        for i in range(train_iters):\n            if skip_if_seq[i] > 0 and use_optional_tensor:\n                pass\n            else:\n                (x, vt) = train_one_iter(\n                    int(num_valid_seq[i]), random_grad_seq[i], x, vt\n                )\n        return x, vt\n\n    oneflow_res = sgd_by_oneflow().numpy()\n    of_model = oneflow_res[:, 0:embedding_size]\n    of_momentum = oneflow_res[:, embedding_size:]\n    np_model, np_momentum = 
sgd_by_numpy()\n    test_case.assertTrue(\n        np.allclose(of_model.flatten(), np_model.flatten(), rtol=0.001, atol=0.001)\n    )\n    if momentum > 0:\n        test_case.assertTrue(\n            np.allclose(\n                of_momentum.flatten(), np_momentum.flatten(), rtol=0.001, atol=0.001\n            )\n        )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestOptimizers(flow.unittest.TestCase):\n    def test_one_embedding_sgd(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"momentum\"] = [0, 0.9]\n        arg_dict[\"weight_decay\"] = [0, 0.1]\n        arg_dict[\"scale\"] = [1, 0.1]\n        arg_dict[\"learning_rate\"] = [1, 0.9]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"use_optional_tensor\"] = [True, False]\n        for arg in GenArgDict(arg_dict):\n            compare_with_numpy_sgd(test_case, **arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_one_hot.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\n\n\ndef _test_one_hot(test_case, device, num_classes, size, on_value, off_value):\n    x = np.random.randint(9, size=size)\n    input = flow.tensor(x, device=flow.device(device), dtype=flow.int64)\n    output = flow.nn.functional.one_hot(input, num_classes, on_value, off_value)\n    if num_classes == -1:\n        np_outtmp = np.eye(np.max(x) + 1)[x]\n    else:\n        np_outtmp = np.eye(num_classes)[x]\n    np_out = np.where(np_outtmp == 1, on_value, off_value)\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestOnehot(flow.unittest.TestCase):\n    def test_onehot(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_one_hot,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"num_classes\"] = [-1, 10, 11]\n        arg_dict[\"size\"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)]\n        arg_dict[\"on_value\"] = [-1, -0.9, 0, 0.9, 1]\n        arg_dict[\"off_value\"] = [-2, -0.5, 0, 0.5, 2]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(auto_backward=False)\n    def test_one_hot_scalar(test_case):\n        x = torch.tensor(2)\n        y = torch.nn.functional.one_hot(x, num_classes=5)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_ones_like.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_ones_like_float(test_case, shape, device):\n    x = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    y = flow.ones_like(x)\n    test_case.assertTrue(y.dtype is flow.float32)\n    test_case.assertTrue(y.shape == x.shape)\n    test_case.assertTrue(y.device == x.device)\n    y_numpy = np.ones_like(x.numpy())\n    test_case.assertTrue(np.array_equal(y.numpy(), y_numpy))\n\n\ndef _test_ones_like_int(test_case, shape, device):\n    x = flow.tensor(np.random.randn(*shape), dtype=flow.int, device=flow.device(device))\n    y = flow.ones_like(x)\n    test_case.assertTrue(y.dtype is flow.int)\n    test_case.assertTrue(y.shape == x.shape)\n    test_case.assertTrue(y.device == x.device)\n    y_numpy = np.ones_like(x.numpy())\n    test_case.assertTrue(np.array_equal(y.numpy(), y_numpy))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModule(flow.unittest.TestCase):\n    def test_ones_like(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_ones_like_float, _test_ones_like_int]\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_optim_adadelta.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport tempfile\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\nfrom oneflow.nn.parameter import Parameter\n\n\ndef compare_with_numpy_adadelta(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    rho,\n    eps,\n    maximize,\n    weight_decay,\n    reload_state_step,\n    save_load_by_pickle,\n    contiguous_params,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    def train_by_oneflow():\n        x = Parameter(flow.Tensor(init_value, device=flow.device(device)))\n        adadelta = flow.optim.Adadelta(\n            [{\"params\": [x], \"lr\": learning_rate, \"weight_decay\": weight_decay,}],\n            rho=rho,\n            eps=eps,\n            maximize=maximize,\n            contiguous_params=contiguous_params,\n        )\n\n        def train_one_iter(grad):\n            grad_tensor = flow.tensor(\n                grad, requires_grad=False, device=flow.device(device)\n            )\n            loss = flow.sum(x * grad_tensor)\n            loss.backward()\n            adadelta.step()\n            adadelta.zero_grad()\n\n        for i in range(train_iters):\n            train_one_iter(random_grad_seq[i])\n            if i == reload_state_step:\n                state_dict = adadelta.state_dict()\n                adadelta = flow.optim.Adadelta([x], contiguous_params=contiguous_params)\n                if save_load_by_pickle:\n                    with tempfile.NamedTemporaryFile() as f:\n                        flow.save(state_dict, f.name)\n                        state_dict = flow.load(f.name)\n                adadelta.load_state_dict(state_dict)\n        return x\n\n    def train_by_numpy():\n        x = init_value\n        square_avgs = np.zeros_like(x)\n        acc_deltas = np.zeros_like(x)\n\n        def train_one_iter(grad):\n            grad = grad if not maximize else -grad\n            grad = grad + weight_decay * x\n            new_square_avgs = square_avgs * rho + (1.0 - rho) * grad * grad\n            std = np.sqrt(new_square_avgs + eps)\n            delta = np.sqrt(acc_deltas + eps) / std * grad\n            new_acc_deltas = acc_deltas * rho + delta * delta * (1 - rho)\n            param = x - learning_rate * delta\n            return (param, new_square_avgs, new_acc_deltas)\n\n        for i in range(1, train_iters + 1):\n            (x, square_avgs, acc_deltas) = train_one_iter(random_grad_seq[i - 1])\n        return x\n\n    oneflow_res = train_by_oneflow().numpy()\n    numpy_res = train_by_numpy()\n\n    test_case.assertTrue(\n        np.allclose(oneflow_res.flatten(), 
numpy_res.flatten(), rtol=1e-4, atol=1e-4)\n    )\n\n\ndef compare_with_numpy_adadelta_clip_grad(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    rho,\n    eps,\n    maximize,\n    weight_decay,\n    clip_grad_max_norm,\n    clip_grad_norm_type,\n    reload_state_step,\n    save_load_by_pickle,\n    contiguous_params,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    def train_by_oneflow():\n        x = Parameter(flow.Tensor(init_value, device=flow.device(device)))\n        adadelta = flow.optim.Adadelta(\n            [\n                {\n                    \"params\": [x],\n                    \"lr\": learning_rate,\n                    \"weight_decay\": weight_decay,\n                    \"clip_grad_max_norm\": clip_grad_max_norm,\n                    \"clip_grad_norm_type\": clip_grad_norm_type,\n                }\n            ],\n            rho=rho,\n            eps=eps,\n            maximize=maximize,\n            contiguous_params=contiguous_params,\n        )\n\n        def train_one_iter(grad):\n            grad_tensor = flow.tensor(\n                grad, requires_grad=False, device=flow.device(device)\n            )\n            loss = flow.sum(x * grad_tensor)\n            loss.backward()\n            adadelta.clip_grad()\n            adadelta.step()\n            adadelta.zero_grad()\n\n        for i in range(train_iters):\n            train_one_iter(random_grad_seq[i])\n            if i == reload_state_step:\n                state_dict = adadelta.state_dict()\n                adadelta = flow.optim.Adadelta([x], contiguous_params=contiguous_params)\n                if save_load_by_pickle:\n                    with tempfile.NamedTemporaryFile() as f:\n                        flow.save(state_dict, f.name)\n                        state_dict = flow.load(f.name)\n                adadelta.load_state_dict(state_dict)\n        return x\n\n    def train_by_numpy():\n        x = init_value\n        square_avgs = np.zeros_like(x)\n        acc_deltas = np.zeros_like(x)\n\n        def train_one_iter(grad):\n            total_norm, grad = clip_grad_norm_np(\n                grad, clip_grad_max_norm, clip_grad_norm_type\n            )\n            grad = grad if not maximize else -grad\n            grad = grad + weight_decay * x\n            new_square_avgs = square_avgs * rho + (1.0 - rho) * grad * grad\n            std = np.sqrt(new_square_avgs + eps)\n            delta = np.sqrt(acc_deltas + eps) / std * grad\n            new_acc_deltas = acc_deltas * rho + delta * delta * (1 - rho)\n            param = x - learning_rate * delta\n            return (param, new_square_avgs, new_acc_deltas)\n\n        for i in range(1, train_iters + 1):\n            (x, square_avgs, acc_deltas) = train_one_iter(random_grad_seq[i - 1])\n\n        return x\n\n    oneflow_res = train_by_oneflow().numpy()\n    numpy_res = train_by_numpy()\n\n    test_case.assertTrue(\n        np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=1e-4, atol=1e-4)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAdadelta(flow.unittest.TestCase):\n    def test_adadelta(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        
arg_dict[\"rho\"] = [0.9, 0.6]\n        arg_dict[\"eps\"] = [1e-6, 1e-4]\n        arg_dict[\"maximize\"] = [False]\n        arg_dict[\"weight_decay\"] = [0.0, 0.1]\n        arg_dict[\"reload_state_step\"] = [5]  # save and load optim state\n        arg_dict[\"save_load_by_pickle\"] = [False, True]\n        arg_dict[\"contiguous_params\"] = [False, True]\n\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adadelta(test_case, *arg)\n\n    def test_adadelta_clip_grad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\"]\n        if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            arg_dict[\"device\"] = [\"cpu\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"rho\"] = [0.9, 0.6]\n        arg_dict[\"eps\"] = [1e-6, 1e-4]\n        arg_dict[\"maximize\"] = [False]\n        arg_dict[\"weight_decay\"] = [0.0, 0.1]\n        arg_dict[\"clip_grad_max_norm\"] = [0, 0.5, 1.0]\n        arg_dict[\"clip_grad_norm_type\"] = [\"inf\", \"-inf\", 0.0, 1.0, 2.0, 3.5]\n        arg_dict[\"reload_state_step\"] = [5]  # save and load optim state\n        arg_dict[\"save_load_by_pickle\"] = [False, True]\n        arg_dict[\"contiguous_params\"] = [False, True]\n\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adadelta_clip_grad(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_optim_adagrad.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport tempfile\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\nfrom oneflow.nn.parameter import Parameter\n\n\ndef compare_with_numpy_adagrad(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    lr_decay,\n    weight_decay,\n    initial_accumulator_value,\n    eps,\n    reload_state_step,\n    save_load_by_pickle,\n    contiguous_params,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    def train_by_oneflow():\n        x = Parameter(flow.Tensor(init_value, device=flow.device(device)))\n        adagrad = flow.optim.Adagrad(\n            [\n                {\n                    \"params\": [x],\n                    \"lr\": learning_rate,\n                    \"eps\": eps,\n                    \"weight_decay\": weight_decay,\n                }\n            ],\n            lr_decay=lr_decay,\n            initial_accumulator_value=initial_accumulator_value,\n            contiguous_params=contiguous_params,\n        )\n\n        def train_one_iter(grad):\n            grad_tensor = flow.tensor(\n                grad, requires_grad=False, device=flow.device(device)\n            )\n            loss = flow.sum(x * grad_tensor)\n            loss.backward()\n            adagrad.step()\n            adagrad.zero_grad()\n\n        for i in range(train_iters):\n            train_one_iter(random_grad_seq[i])\n            if i == reload_state_step:\n                state_dict = adagrad.state_dict()\n                adagrad = flow.optim.Adagrad([x], contiguous_params=contiguous_params)\n                if save_load_by_pickle:\n                    with tempfile.NamedTemporaryFile() as f:\n                        flow.save(state_dict, f.name)\n                        state_dict = flow.load(f.name)\n                adagrad.load_state_dict(state_dict)\n        return x\n\n    def train_by_numpy():\n        x = init_value\n        st = np.ones_like(x) * initial_accumulator_value\n\n        def train_one_iter(iter, grad):\n            grad = grad + weight_decay * x\n            lr = learning_rate / (1 + (iter - 1) * lr_decay)\n            s = st + grad * grad\n            param = x - lr / (np.sqrt(s) + eps) * grad\n            return (param, s)\n\n        for i in range(1, train_iters + 1):\n            (x, st) = train_one_iter(i, random_grad_seq[i - 1])\n        return x\n\n    oneflow_res = train_by_oneflow().numpy()\n    numpy_res = train_by_numpy()\n\n    test_case.assertTrue(\n        np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=1e-3, atol=1e-3)\n    )\n\n\ndef compare_with_numpy_adagrad_clip_grad(\n 
   test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    lr_decay,\n    weight_decay,\n    initial_accumulator_value,\n    eps,\n    clip_grad_max_norm,\n    clip_grad_norm_type,\n    reload_state_step,\n    save_load_by_pickle,\n    contiguous_params,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    def train_by_oneflow():\n        x = Parameter(flow.Tensor(init_value, device=flow.device(device)))\n        adagrad = flow.optim.Adagrad(\n            [\n                {\n                    \"params\": [x],\n                    \"lr\": learning_rate,\n                    \"eps\": eps,\n                    \"weight_decay\": weight_decay,\n                    \"clip_grad_max_norm\": clip_grad_max_norm,\n                    \"clip_grad_norm_type\": clip_grad_norm_type,\n                }\n            ],\n            lr_decay=lr_decay,\n            initial_accumulator_value=initial_accumulator_value,\n            contiguous_params=contiguous_params,\n        )\n\n        def train_one_iter(grad):\n            grad_tensor = flow.tensor(\n                grad, requires_grad=False, device=flow.device(device)\n            )\n            loss = flow.sum(x * grad_tensor)\n            loss.backward()\n            adagrad.clip_grad()\n            adagrad.step()\n            adagrad.zero_grad()\n\n        for i in range(train_iters):\n            train_one_iter(random_grad_seq[i])\n            if i == reload_state_step:\n                state_dict = adagrad.state_dict()\n                adagrad = flow.optim.Adagrad([x], contiguous_params=contiguous_params)\n                if save_load_by_pickle:\n                    with tempfile.NamedTemporaryFile() as f:\n                        flow.save(state_dict, f.name)\n                        state_dict = flow.load(f.name)\n                adagrad.load_state_dict(state_dict)\n        return x\n\n    def train_by_numpy():\n        x = init_value\n        st = np.ones_like(x) * initial_accumulator_value\n\n        def train_one_iter(iter, grad):\n            total_norm, grad = clip_grad_norm_np(\n                grad, clip_grad_max_norm, clip_grad_norm_type\n            )\n            grad = grad + weight_decay * x\n\n            lr = learning_rate / (1 + (iter - 1) * lr_decay)\n            s = st + grad * grad\n            param = x - lr / (np.sqrt(s) + eps) * grad\n\n            return (param, s)\n\n        for i in range(1, train_iters + 1):\n            (x, st) = train_one_iter(i, random_grad_seq[i - 1])\n\n        return x\n\n    oneflow_res = train_by_oneflow().numpy()\n    numpy_res = train_by_numpy()\n\n    test_case.assertTrue(\n        np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=1e-3, atol=1e-3)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAdagrad(flow.unittest.TestCase):\n    def test_adagrad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"lr_decay\"] = [0.9, 0.75]\n        arg_dict[\"weight_decay\"] = [0.0, 0.1]\n        arg_dict[\"initial_accumulator_value\"] = [1.0, 2.1]\n        arg_dict[\"eps\"] = [1e-08, 1e-07]\n        arg_dict[\"reload_state_step\"] = [5]  # save and load optim state\n        
arg_dict[\"save_load_by_pickle\"] = [False, True]\n        arg_dict[\"contiguous_params\"] = [False, True]\n\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adagrad(test_case, *arg)\n\n    def test_adagrad_clip_grad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\"]\n        if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            arg_dict[\"device\"] = [\"cpu\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"lr_decay\"] = [0.9, 0.75]\n        arg_dict[\"weight_decay\"] = [0.0, 0.1]\n        arg_dict[\"initial_accumulator_value\"] = [2.1]\n        arg_dict[\"eps\"] = [1e-07]\n        arg_dict[\"clip_grad_max_norm\"] = [0, 0.5, 1.0]\n        arg_dict[\"clip_grad_norm_type\"] = [\"inf\", \"-inf\", 0.0, 1.0, 2.0, 3.5]\n        arg_dict[\"reload_state_step\"] = [5]  # save and load optim state\n        arg_dict[\"save_load_by_pickle\"] = [False, True]\n        arg_dict[\"contiguous_params\"] = [False, True]\n\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adagrad_clip_grad(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_optim_adam.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport tempfile\nimport unittest\nfrom collections import OrderedDict\nimport random as random_util\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import random_device, random_bool\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\nfrom oneflow.nn.parameter import Parameter\n\n\ndef compare_with_numpy_adam(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    betas,\n    weight_decay,\n    eps,\n    do_bias_correction,\n    amsgrad,\n    reload_state_step,\n    save_load_by_pickle,\n    contiguous_params,\n    fused,\n    tensor_num,\n):\n    random_grad_seq = []\n    init_value_seq = []\n\n    for i in range(tensor_num):\n        init_value_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n\n    for _ in range(train_iters):\n        random_grad_seq_per_iter = []\n        for i in range(tensor_num):\n            random_grad_seq_per_iter.append(\n                np.random.uniform(size=x_shape).astype(np.float32)\n            )\n        random_grad_seq.append(random_grad_seq_per_iter)\n\n    def train_by_oneflow():\n        x = []\n        for i in range(tensor_num):\n            x.append(\n                Parameter(flow.Tensor(init_value_seq[i], device=flow.device(device)))\n            )\n\n        adam = flow.optim.Adam(\n            [\n                {\n                    \"params\": x,\n                    \"lr\": learning_rate,\n                    \"betas\": betas,\n                    \"eps\": eps,\n                    \"weight_decay\": weight_decay,\n                }\n            ],\n            do_bias_correction=do_bias_correction,\n            amsgrad=amsgrad,\n            contiguous_params=contiguous_params,\n            fused=fused,\n        )\n\n        def train_one_iter(grad):\n            loss = 0.0\n            for i in range(tensor_num):\n                grad_tensor = flow.tensor(\n                    grad[i],\n                    dtype=flow.float32,\n                    requires_grad=False,\n                    device=flow.device(device),\n                )\n                loss += flow.sum(x[i] * grad_tensor)\n            loss.backward()\n            adam.step()\n            adam.zero_grad()\n\n        for i in range(train_iters):\n            train_one_iter(random_grad_seq[i])\n            if i == reload_state_step:\n                state_dict = adam.state_dict()\n                adam = flow.optim.Adam(\n                    [{\"params\": x,}], contiguous_params=contiguous_params\n                )\n                if save_load_by_pickle:\n                    with tempfile.NamedTemporaryFile() as f:\n                        flow.save(state_dict, f.name)\n                        state_dict = flow.load(f.name)\n                adam.load_state_dict(state_dict)\n        return x\n\n    def 
train_by_numpy(tensor_idx):\n        x = init_value_seq[tensor_idx]\n        vt = np.zeros_like(x)\n        st = np.zeros_like(x)\n        max_st = np.zeros_like(x)\n        beta1 = betas[0]\n        beta2 = betas[1]\n\n        def np_train_one_iter(step, grad):\n            grad = grad + weight_decay * x\n\n            bias_correction1 = 1.0\n            bias_correction2 = 1.0\n\n            if do_bias_correction:\n                bias_correction1 = 1.0 - np.power(beta1, step)\n                bias_correction2 = 1.0 - np.power(beta2, step)\n\n            v = beta1 * vt + (1 - beta1) * grad\n            s = beta2 * st + (1 - beta2) * grad * grad\n            max_s = np.zeros_like(x)\n\n            if amsgrad:\n                max_s = np.maximum(s, max_st)\n                denom = np.sqrt(max_s) / np.sqrt(bias_correction2) + eps\n            else:\n                denom = np.sqrt(s) / np.sqrt(bias_correction2) + eps\n\n            param = x - ((learning_rate / bias_correction1) * v / denom)\n            return (param, v, s, max_s)\n\n        for i in range(1, train_iters + 1):\n            (x, vt, st, max_st) = np_train_one_iter(\n                i, random_grad_seq[i - 1][tensor_idx]\n            )\n        return x\n\n    oneflow_res = train_by_oneflow()\n    numpy_res = []\n    for i in range(tensor_num):\n        numpy_res.append(train_by_numpy(i))\n\n    for i in range(tensor_num):\n        test_case.assertTrue(\n            np.allclose(\n                oneflow_res[i].numpy().flatten(),\n                numpy_res[i].flatten(),\n                rtol=0.001,\n                atol=0.0001,\n            )\n        )\n\n\ndef compare_with_numpy_adam_clip_grad(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    betas,\n    weight_decay,\n    eps,\n    do_bias_correction,\n    amsgrad,\n    clip_grad_max_norm,\n    clip_grad_norm_type,\n    reload_state_step,\n    save_load_by_pickle,\n    contiguous_params,\n    fused,\n    tensor_num,\n):\n    random_grad_seq = []\n    init_value_seq = []\n\n    for i in range(tensor_num):\n        init_value_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n\n    for _ in range(train_iters):\n        random_grad_seq_per_iter = []\n        for i in range(tensor_num):\n            random_grad_seq_per_iter.append(\n                np.random.uniform(size=x_shape).astype(np.float32)\n            )\n        random_grad_seq.append(random_grad_seq_per_iter)\n\n    def train_by_oneflow():\n        x = []\n        for i in range(tensor_num):\n            x.append(\n                Parameter(flow.Tensor(init_value_seq[i], device=flow.device(device)))\n            )\n\n        adam = flow.optim.Adam(\n            [\n                {\n                    \"params\": x,\n                    \"lr\": learning_rate,\n                    \"betas\": betas,\n                    \"eps\": eps,\n                    \"weight_decay\": weight_decay,\n                    \"clip_grad_max_norm\": clip_grad_max_norm,\n                    \"clip_grad_norm_type\": clip_grad_norm_type,\n                }\n            ],\n            do_bias_correction=do_bias_correction,\n            amsgrad=amsgrad,\n            contiguous_params=contiguous_params,\n            fused=fused,\n        )\n\n        def train_one_iter(grad):\n            loss = 0.0\n            for i in range(tensor_num):\n                grad_tensor = flow.tensor(\n                    grad[i],\n                    dtype=flow.float32,\n                    requires_grad=False,\n 
                   device=flow.device(device),\n                )\n                loss += flow.sum(x[i] * grad_tensor)\n            loss.backward()\n            adam.clip_grad()\n            adam.step()\n            adam.zero_grad()\n\n        for i in range(train_iters):\n            train_one_iter(random_grad_seq[i])\n            if i == reload_state_step:\n                state_dict = adam.state_dict()\n                adam = flow.optim.Adam(\n                    [{\"params\": x,}], contiguous_params=contiguous_params\n                )\n                if save_load_by_pickle:\n                    with tempfile.NamedTemporaryFile() as f:\n                        flow.save(state_dict, f.name)\n                        state_dict = flow.load(f.name)\n                adam.load_state_dict(state_dict)\n        return x\n\n    def train_by_numpy():\n        x = init_value_seq\n        vt = np.zeros_like(x)\n        st = np.zeros_like(x)\n        max_st = np.zeros_like(x)\n        beta1 = betas[0]\n        beta2 = betas[1]\n\n        def train_one_iter(step, grad):\n            total_norm, grad = clip_grad_norm_np(\n                grad, clip_grad_max_norm, clip_grad_norm_type\n            )\n\n            for i in range(tensor_num):\n                grad[i] = grad[i] + weight_decay * x[i]\n\n                bias_correction1 = 1.0\n                bias_correction2 = 1.0\n\n                if do_bias_correction:\n                    bias_correction1 = 1.0 - np.power(beta1, step)\n                    bias_correction2 = 1.0 - np.power(beta2, step)\n\n                vt[i] = beta1 * vt[i] + (1 - beta1) * grad[i]\n                st[i] = beta2 * st[i] + (1 - beta2) * grad[i] * grad[i]\n\n                if amsgrad:\n                    max_st[i] = np.maximum(st[i], max_st[i])\n                    denom = np.sqrt(max_st[i]) / np.sqrt(bias_correction2) + eps\n                else:\n                    denom = np.sqrt(st[i]) / np.sqrt(bias_correction2) + eps\n\n                x[i] = x[i] - ((learning_rate / bias_correction1) * vt[i] / denom)\n\n        for i in range(1, train_iters + 1):\n            train_one_iter(i, random_grad_seq[i - 1])\n        return x\n\n    oneflow_res = train_by_oneflow()\n    numpy_res = train_by_numpy()\n\n    for i in range(tensor_num):\n        test_case.assertTrue(\n            np.allclose(\n                oneflow_res[i].numpy().flatten(),\n                numpy_res[i].flatten(),\n                rtol=0.0001,\n                atol=0.0001,\n            )\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAdam(flow.unittest.TestCase):\n    def test_adam(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [random_device().value()]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"betas\"] = [(0.99, 0.9)]\n        arg_dict[\"weight_decay\"] = [0.9, 0.000]\n        arg_dict[\"eps\"] = [1e-08]\n        arg_dict[\"do_bias_correction\"] = [random_bool().value()]\n        arg_dict[\"amsgrad\"] = [random_bool().value()]\n        arg_dict[\"reload_state_step\"] = [5]  # save and load optim state\n        arg_dict[\"save_load_by_pickle\"] = [random_bool().value()]\n        arg_dict[\"contiguous_params\"] = [random_bool().value()]\n        arg_dict[\"fused\"] = [random_bool().value()]\n        arg_dict[\"tensor_num\"] = [1, 4]\n\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adam(test_case, *arg)\n\n    def 
test_adam_clip_grad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [random_device().value()]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"betas\"] = [(0.99, 0.9)]\n        arg_dict[\"weight_decay\"] = [0.1, 0.000]\n        arg_dict[\"eps\"] = [1e-08]\n        arg_dict[\"do_bias_correction\"] = [random_bool().value()]\n        arg_dict[\"amsgrad\"] = [random_bool().value()]\n        arg_dict[\"clip_grad_max_norm\"] = [0, 0.5, 1.0]\n        arg_dict[\"clip_grad_norm_type\"] = random_util.sample(\n            [\"inf\", \"-inf\", 0.0, 1.0, 2.0, 3.5], k=3\n        )\n        arg_dict[\"reload_state_step\"] = [5]  # save and load optim state\n        arg_dict[\"save_load_by_pickle\"] = [random_bool().value()]\n        arg_dict[\"contiguous_params\"] = [random_bool().value()]\n        arg_dict[\"fused\"] = [random_bool().value()]\n        arg_dict[\"tensor_num\"] = [1, 4]\n\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adam_clip_grad(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_optim_adamw.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport tempfile\nimport unittest\nfrom collections import OrderedDict\nimport random as random_util\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import random_bool, random_device\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\nfrom oneflow.nn.parameter import Parameter\n\n\ndef compare_with_numpy_adamw(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    betas,\n    weight_decay,\n    eps,\n    do_bias_correction,\n    amsgrad,\n    reload_state_step,\n    save_load_by_pickle,\n    contiguous_params,\n    fused,\n    tensor_num,\n):\n    random_grad_seq = []\n    init_value_seq = []\n\n    for i in range(tensor_num):\n        init_value_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n\n    for _ in range(train_iters):\n        random_grad_seq_per_iter = []\n        for i in range(tensor_num):\n            random_grad_seq_per_iter.append(\n                np.random.uniform(size=x_shape).astype(np.float32)\n            )\n        random_grad_seq.append(random_grad_seq_per_iter)\n\n    def train_by_oneflow():\n        x = []\n        for i in range(tensor_num):\n            x.append(\n                Parameter(flow.Tensor(init_value_seq[i], device=flow.device(device)))\n            )\n\n        adam = flow.optim.AdamW(\n            [\n                {\n                    \"params\": x,\n                    \"lr\": learning_rate,\n                    \"betas\": betas,\n                    \"eps\": eps,\n                    \"weight_decay\": weight_decay,\n                }\n            ],\n            do_bias_correction=do_bias_correction,\n            amsgrad=amsgrad,\n            contiguous_params=contiguous_params,\n            fused=fused,\n        )\n\n        def train_one_iter(grad):\n            loss = 0.0\n            for i in range(tensor_num):\n                grad_tensor = flow.tensor(\n                    grad[i],\n                    dtype=flow.float32,\n                    requires_grad=False,\n                    device=flow.device(device),\n                )\n                loss += flow.sum(x[i] * grad_tensor)\n            loss.backward()\n            adam.step()\n            adam.zero_grad()\n\n        for i in range(train_iters):\n            train_one_iter(random_grad_seq[i])\n            if i == reload_state_step:\n                state_dict = adam.state_dict()\n                adam = flow.optim.AdamW(x, contiguous_params=contiguous_params)\n                if save_load_by_pickle:\n                    with tempfile.NamedTemporaryFile() as f:\n                        flow.save(state_dict, f.name)\n                        state_dict = flow.load(f.name)\n                adam.load_state_dict(state_dict)\n        return x\n\n    def train_by_numpy(tensor_idx):\n        x = init_value_seq[tensor_idx]\n  
      vt = np.zeros_like(x)\n        st = np.zeros_like(x)\n        max_st = np.zeros_like(x)\n        beta1 = betas[0]\n        beta2 = betas[1]\n\n        def train_one_iter(step, grad):\n            v = beta1 * vt + (1 - beta1) * grad\n            s = beta2 * st + (1 - beta2) * grad * grad\n\n            bias_correction1 = 1.0\n            bias_correction2 = 1.0\n\n            if do_bias_correction:\n                bias_correction1 = 1.0 - np.power(beta1, step)\n                bias_correction2 = 1.0 - np.power(beta2, step)\n\n            max_s = np.zeros_like(x)\n            if amsgrad:\n                max_s = np.maximum(s, max_st)\n                denom = np.sqrt(max_s) / np.sqrt(bias_correction2) + eps\n            else:\n                denom = np.sqrt(s) / np.sqrt(bias_correction2) + eps\n\n            lr = learning_rate / bias_correction1 / denom\n            g = lr * v + learning_rate * weight_decay * x\n            param = x - g\n            return (param, v, s, max_s)\n\n        for i in range(1, train_iters + 1):\n            (x, vt, st, max_st) = train_one_iter(i, random_grad_seq[i - 1][tensor_idx])\n        return x\n\n    oneflow_res = train_by_oneflow()\n    numpy_res = []\n    for i in range(tensor_num):\n        numpy_res.append(train_by_numpy(i))\n\n    for i in range(tensor_num):\n        test_case.assertTrue(\n            np.allclose(\n                oneflow_res[i].numpy().flatten(),\n                numpy_res[i].flatten(),\n                rtol=0.0001,\n                atol=0.0001,\n            )\n        )\n\n\ndef compare_with_numpy_adamw_clip_grad(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    betas,\n    weight_decay,\n    eps,\n    do_bias_correction,\n    amsgrad,\n    clip_grad_max_norm,\n    clip_grad_norm_type,\n    reload_state_step,\n    save_load_by_pickle,\n    contiguous_params,\n    fused,\n    tensor_num,\n):\n    random_grad_seq = []\n    init_value_seq = []\n\n    for i in range(tensor_num):\n        init_value_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n\n    for _ in range(train_iters):\n        random_grad_seq_per_iter = []\n        for i in range(tensor_num):\n            random_grad_seq_per_iter.append(\n                np.random.uniform(size=x_shape).astype(np.float32)\n            )\n        random_grad_seq.append(random_grad_seq_per_iter)\n\n    def train_by_oneflow():\n        x = []\n        for i in range(tensor_num):\n            x.append(\n                Parameter(flow.Tensor(init_value_seq[i], device=flow.device(device)))\n            )\n\n        adam = flow.optim.AdamW(\n            [\n                {\n                    \"params\": x,\n                    \"lr\": learning_rate,\n                    \"betas\": betas,\n                    \"eps\": eps,\n                    \"weight_decay\": weight_decay,\n                    \"clip_grad_max_norm\": clip_grad_max_norm,\n                    \"clip_grad_norm_type\": clip_grad_norm_type,\n                }\n            ],\n            do_bias_correction=do_bias_correction,\n            amsgrad=amsgrad,\n            contiguous_params=contiguous_params,\n            fused=fused,\n        )\n\n        def train_one_iter(grad):\n            loss = 0.0\n            for i in range(tensor_num):\n                grad_tensor = flow.tensor(\n                    grad[i],\n                    dtype=flow.float32,\n                    requires_grad=False,\n                    device=flow.device(device),\n                )\n              
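# Using loss = sum(x[i] * g) makes d(loss)/d(x[i]) exactly g, so backward()\n                # hands each parameter the prepared random gradient for this iteration.\n              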
  loss += flow.sum(x[i] * grad_tensor)\n            loss.backward()\n            adam.clip_grad()\n            adam.step()\n            adam.zero_grad()\n\n        for i in range(train_iters):\n            train_one_iter(random_grad_seq[i])\n            if i == reload_state_step:\n                state_dict = adam.state_dict()\n                adam = flow.optim.AdamW(x, contiguous_params=contiguous_params)\n                if save_load_by_pickle:\n                    with tempfile.NamedTemporaryFile() as f:\n                        flow.save(state_dict, f.name)\n                        state_dict = flow.load(f.name)\n                adam.load_state_dict(state_dict)\n        return x\n\n    def train_by_numpy():\n        x = init_value_seq\n        vt = np.zeros_like(x)\n        st = np.zeros_like(x)\n        max_st = np.zeros_like(x)\n\n        beta1 = betas[0]\n        beta2 = betas[1]\n\n        def train_one_iter(step, grad):\n            total_norm, grad = clip_grad_norm_np(\n                grad, clip_grad_max_norm, clip_grad_norm_type\n            )\n\n            for i in range(tensor_num):\n                vt[i] = beta1 * vt[i] + (1 - beta1) * grad[i]\n                st[i] = beta2 * st[i] + (1 - beta2) * grad[i] * grad[i]\n\n                bias_correction1 = 1.0\n                bias_correction2 = 1.0\n\n                if do_bias_correction:\n                    bias_correction1 = 1.0 - np.power(beta1, step)\n                    bias_correction2 = 1.0 - np.power(beta2, step)\n\n                if amsgrad:\n                    max_st[i] = np.maximum(st[i], max_st[i])\n                    denom = np.sqrt(max_st[i]) / np.sqrt(bias_correction2) + eps\n                else:\n                    denom = np.sqrt(st[i]) / np.sqrt(bias_correction2) + eps\n\n                lr = learning_rate / bias_correction1 / denom\n                g = lr * vt[i] + learning_rate * weight_decay * x[i]\n                x[i] = x[i] - g\n\n        for i in range(1, train_iters + 1):\n            train_one_iter(i, random_grad_seq[i - 1])\n        return x\n\n    oneflow_res = train_by_oneflow()\n    numpy_res = train_by_numpy()\n\n    for i in range(tensor_num):\n        test_case.assertTrue(\n            np.allclose(\n                oneflow_res[i].numpy().flatten(),\n                numpy_res[i].flatten(),\n                rtol=0.0001,\n                atol=0.0001,\n            )\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestAdamW(flow.unittest.TestCase):\n    def test_adamw(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [random_device().value()]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"betas\"] = [(0.9, 0.999)]\n        arg_dict[\"weight_decay\"] = [0.01, 0.00]\n        arg_dict[\"eps\"] = [1e-8]\n        arg_dict[\"do_bias_correction\"] = [random_bool().value()]\n        arg_dict[\"amsgrad\"] = [random_bool().value()]\n        arg_dict[\"reload_state_step\"] = [5]  # save and load optim state\n        arg_dict[\"save_load_by_pickle\"] = [random_bool().value()]\n        arg_dict[\"contiguous_params\"] = [random_bool().value()]\n        arg_dict[\"fused\"] = [random_bool().value()]\n        arg_dict[\"tensor_num\"] = [1, 4]\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adamw(test_case, *arg)\n\n    def test_adamw_clip_grad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [random_device().value()]\n   
     arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"betas\"] = [(0.9, 0.999)]\n        arg_dict[\"weight_decay\"] = [0.001, 0.0]\n        arg_dict[\"eps\"] = [1e-8]\n        arg_dict[\"do_bias_correction\"] = [random_bool().value()]\n        arg_dict[\"amsgrad\"] = [random_bool().value()]\n        arg_dict[\"clip_grad_max_norm\"] = [0, 0.5, 1.0]\n        arg_dict[\"clip_grad_norm_type\"] = random_util.sample(\n            [\"inf\", \"-inf\", 0.0, 1.0, 2.0, 3.5], k=3\n        )\n        arg_dict[\"reload_state_step\"] = [5]  # save and load optim state\n        arg_dict[\"save_load_by_pickle\"] = [random_bool().value()]\n        arg_dict[\"contiguous_params\"] = [random_bool().value()]\n        arg_dict[\"fused\"] = [random_bool().value()]\n        arg_dict[\"tensor_num\"] = [1, 4]\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_adamw_clip_grad(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_optim_add_param_group.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom oneflow.test_utils.test_util import GenArgList\nimport oneflow as flow\n\n\ndef _test_sgd_add_param_group(test_case):\n    w1 = flow.ones(3, 3)\n    w1.requires_grad = True\n    w2 = flow.ones(3, 3)\n    w2.requires_grad = True\n    o = flow.optim.SGD([w1])\n    test_case.assertTrue(o.param_groups[0][\"lr\"] == 0.001)\n    test_case.assertTrue(o.param_groups[0][\"momentum\"] == 0.0)\n    test_case.assertTrue(o.param_groups[0][\"weight_decay\"] == 0.0)\n    test_case.assertTrue(o.param_groups[0][\"nesterov\"] == False)\n    test_case.assertTrue(o.param_groups[0][\"maximize\"] == False)\n    o.step()\n    o.add_param_group({\"params\": w2})\n    test_case.assertTrue(o.param_groups[1][\"lr\"] == 0.001)\n    test_case.assertTrue(o.param_groups[1][\"momentum\"] == 0.0)\n    test_case.assertTrue(o.param_groups[1][\"weight_decay\"] == 0.0)\n    test_case.assertTrue(o.param_groups[1][\"nesterov\"] == False)\n    test_case.assertTrue(o.param_groups[1][\"maximize\"] == False)\n    o.step()\n\n\nclass TestAddParamGroup(flow.unittest.TestCase):\n    def test_sgd_add_param_group(test_case):\n        _test_sgd_add_param_group(test_case)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_optim_ftrl.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport tempfile\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom optimizer_test_util import clip_grad_norm_np\nfrom oneflow.one_embedding import Ftrl\nimport oneflow as flow\nfrom oneflow.nn.parameter import Parameter\n\n\ndef compare_with_numpy_ftrl(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    weight_decay,\n    lr_power,\n    initial_accumulator_value,\n    lambda1,\n    lambda2,\n    beta,\n    reload_state_step,\n    save_load_by_pickle,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    def train_by_oneflow():\n        x = Parameter(flow.Tensor(init_value, device=flow.device(device)))\n        ftrl = Ftrl(\n            [\n                {\n                    \"params\": [x],\n                    \"lr\": learning_rate,\n                    \"weight_decay\": weight_decay,\n                    \"lr_power\": lr_power,\n                    \"initial_accumulator_value\": initial_accumulator_value,\n                    \"lambda1\": lambda1,\n                    \"lambda2\": lambda2,\n                    \"beta\": beta,\n                }\n            ]\n        )\n\n        def train_one_iter(grad):\n            grad_tensor = flow.tensor(\n                grad,\n                dtype=flow.float32,\n                requires_grad=False,\n                device=flow.device(device),\n            )\n            loss = flow.sum(x * grad_tensor)\n            loss.backward()\n            ftrl.step()\n            ftrl.zero_grad()\n\n        for i in range(train_iters):\n            train_one_iter(random_grad_seq[i])\n            if i == reload_state_step:\n                state_dict = ftrl.state_dict()\n                ftrl = Ftrl([{\"params\": [x],}],)\n                if save_load_by_pickle:\n                    with tempfile.NamedTemporaryFile() as f:\n                        flow.save(state_dict, f.name)\n                        state_dict = flow.load(f.name)\n                ftrl.load_state_dict(state_dict)\n        return x\n\n    def train_by_numpy():\n        x = init_value\n        accum = np.zeros_like(x)\n        accum.fill(initial_accumulator_value)\n        z_arr = np.zeros_like(x)\n\n        def np_train_one_iter(grad):\n            grad = grad + weight_decay * x\n\n            new_accum = accum + grad * grad\n            sigma = (\n                np.power(new_accum, lr_power) - np.power(accum, lr_power)\n            ) / learning_rate\n            new_z_val = z_arr + grad - sigma * x\n\n            update_val = (np.sign(new_z_val) * lambda1 - new_z_val) / (\n                (beta + np.power(new_accum, lr_power)) / learning_rate + lambda2\n            )\n            
param = np.where(np.abs(new_z_val) < lambda1, 0.0, update_val)\n            return (param, new_accum, new_z_val)\n\n        for i in range(1, train_iters + 1):\n            (x, accum, z_arr) = np_train_one_iter(random_grad_seq[i - 1])\n        return x\n\n    oneflow_res = train_by_oneflow().numpy()\n    numpy_res = train_by_numpy()\n    test_case.assertTrue(\n        np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=1e-4, atol=1e-4)\n    )\n\n\ndef compare_with_numpy_ftrl_clip_grad(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    weight_decay,\n    lr_power,\n    initial_accumulator_value,\n    lambda1,\n    lambda2,\n    beta,\n    clip_grad_max_norm,\n    clip_grad_norm_type,\n    reload_state_step,\n    save_load_by_pickle,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    def train_by_oneflow():\n        x = Parameter(flow.Tensor(init_value, device=flow.device(device)))\n        ftrl = Ftrl(\n            [\n                {\n                    \"params\": [x],\n                    \"lr\": learning_rate,\n                    \"weight_decay\": weight_decay,\n                    \"lr_power\": lr_power,\n                    \"initial_accumulator_value\": initial_accumulator_value,\n                    \"lambda1\": lambda1,\n                    \"lambda2\": lambda2,\n                    \"beta\": beta,\n                    \"clip_grad_max_norm\": clip_grad_max_norm,\n                    \"clip_grad_norm_type\": clip_grad_norm_type,\n                }\n            ]\n        )\n\n        def train_one_iter(grad):\n            grad_tensor = flow.tensor(\n                grad,\n                dtype=flow.float32,\n                requires_grad=False,\n                device=flow.device(device),\n            )\n            loss = flow.sum(x * grad_tensor)\n            loss.backward()\n            ftrl.clip_grad()\n            ftrl.step()\n            ftrl.zero_grad()\n\n        for i in range(train_iters):\n            train_one_iter(random_grad_seq[i])\n            if i == reload_state_step:\n                state_dict = ftrl.state_dict()\n                ftrl = Ftrl([{\"params\": [x],}])\n                if save_load_by_pickle:\n                    with tempfile.NamedTemporaryFile() as f:\n                        flow.save(state_dict, f.name)\n                        state_dict = flow.load(f.name)\n                ftrl.load_state_dict(state_dict)\n        return x\n\n    def train_by_numpy():\n        x = init_value\n        accum = np.zeros_like(x)\n        accum.fill(initial_accumulator_value)\n        z_arr = np.zeros_like(x)\n\n        def np_train_one_iter(grad):\n            total_norm, grad = clip_grad_norm_np(\n                grad, clip_grad_max_norm, clip_grad_norm_type\n            )\n            grad = grad + weight_decay * x\n\n            new_accum = accum + grad * grad\n            sigma = (\n                np.power(new_accum, lr_power) - np.power(accum, lr_power)\n            ) / learning_rate\n            new_z_val = z_arr + grad - sigma * x\n\n            update_val = (np.sign(new_z_val) * lambda1 - new_z_val) / (\n                (beta + np.power(new_accum, lr_power)) / learning_rate + lambda2\n            )\n            param = np.where(np.abs(new_z_val) < lambda1, 0.0, update_val)\n            return (param, new_accum, new_z_val)\n\n        for i in 
range(1, train_iters + 1):\n            (x, accum, z_arr) = np_train_one_iter(random_grad_seq[i - 1])\n        return x\n\n    oneflow_res = train_by_oneflow().numpy()\n    numpy_res = train_by_numpy()\n    test_case.assertTrue(\n        np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=1e-4, atol=1e-4)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestFtrl(flow.unittest.TestCase):\n    def test_ftrl(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"weight_decay\"] = [0.9, 0.000]\n        arg_dict[\"lr_power\"] = [-0.5, 0.5]\n        arg_dict[\"initial_accumulator_value\"] = [0.1, 0.05]\n        arg_dict[\"lambda1\"] = [0.01]\n        arg_dict[\"lambda2\"] = [0.0, 0.01]\n        arg_dict[\"beta\"] = [1.0]\n        arg_dict[\"reload_state_step\"] = [5]  # save and load optim state\n        arg_dict[\"save_load_by_pickle\"] = [False, True]\n\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_ftrl(test_case, *arg)\n\n    def test_ftrl_clip_grad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"weight_decay\"] = [0.9, 0.000]\n        arg_dict[\"lr_power\"] = [-0.5]\n        arg_dict[\"initial_accumulator_value\"] = [0.1, 0.05]\n        arg_dict[\"lambda1\"] = [0.01]\n        arg_dict[\"lambda2\"] = [0.0]\n        arg_dict[\"beta\"] = [1.0]\n        arg_dict[\"clip_grad_max_norm\"] = [0, 0.5, 1.0]\n        arg_dict[\"clip_grad_norm_type\"] = [\"inf\", \"-inf\", 0.0, 1.0, 2.0, 3.5]\n        arg_dict[\"reload_state_step\"] = [5]  # save and load optim state\n        arg_dict[\"save_load_by_pickle\"] = [False, True]\n\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_ftrl_clip_grad(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_optim_lamb.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport tempfile\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom optimizer_test_util import clip_grad_norm_np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\n\n\ndef compare_with_numpy_lamb(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    betas,\n    weight_decay,\n    eps,\n    do_bias_correction,\n    adam_w_mode,\n    clip_grad_max_norm,\n    clip_grad_norm_type,\n    reload_state_step,\n    save_load_by_pickle,\n    contiguous_params,\n):\n\n    np.random.seed(1000)\n\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    def train_by_oneflow():\n        x = flow.nn.Parameter(flow.Tensor(init_value, device=flow.device(device)))\n\n        optim_kwargs = {\n            \"params\": [x],\n            \"lr\": learning_rate,\n            \"betas\": betas,\n            \"eps\": eps,\n            \"weight_decay\": weight_decay,\n            \"adam_w_mode\": adam_w_mode,\n            \"do_bias_correction\": do_bias_correction,\n            \"contiguous_params\": contiguous_params,\n        }\n\n        if clip_grad_max_norm != -1:\n            optim_kwargs[\"clip_grad_max_norm\"] = clip_grad_max_norm\n            optim_kwargs[\"clip_grad_norm_type\"] = clip_grad_norm_type\n\n        lamb = flow.optim.LAMB([optim_kwargs])\n\n        def train_one_iter(grad):\n            grad_tensor = flow.tensor(\n                grad,\n                dtype=flow.float32,\n                requires_grad=False,\n                device=flow.device(device),\n            )\n\n            loss = flow.sum(x * grad_tensor)\n            loss.backward()\n            if clip_grad_max_norm != -1:\n                lamb.clip_grad()\n            lamb.step()\n            lamb.zero_grad()\n\n        for i in range(train_iters):\n            train_one_iter(random_grad_seq[i])\n            if i == reload_state_step:\n                state_dict = lamb.state_dict()\n                lamb = flow.optim.LAMB([optim_kwargs])\n                if save_load_by_pickle:\n                    with tempfile.NamedTemporaryFile() as f:\n                        flow.save(state_dict, f.name)\n                        state_dict = flow.load(f.name)\n                lamb.load_state_dict(state_dict)\n        return x\n\n    def train_by_numpy():\n        x = init_value\n        mt = np.zeros_like(x)\n        vt = np.zeros_like(x)\n        beta1 = betas[0]\n        beta2 = betas[1]\n        if adam_w_mode:\n            l2 = 0\n            wd = weight_decay\n        else:\n            l2 = weight_decay\n            wd = 0\n\n        def np_train_one_iter(step, grad):\n            if clip_grad_max_norm != -1:\n                _, grad = clip_grad_norm_np(\n       
             grad, clip_grad_max_norm, clip_grad_norm_type\n                )\n\n            grad = grad + l2 * x\n\n            bias_correction1 = 1.0\n            bias_correction2 = 1.0\n\n            if do_bias_correction:\n                bias_correction1 = 1.0 - np.power(beta1, step + 1)\n                bias_correction2 = 1.0 - np.power(beta2, step + 1)\n\n            m = beta1 * mt + (1 - beta1) * grad\n            v = beta2 * vt + (1 - beta2) * grad * grad\n\n            denom = np.sqrt(v) / np.sqrt(bias_correction2) + eps\n\n            adam_diff = m / bias_correction1 / denom\n\n            w_norm = np.linalg.norm(x, ord=2)\n            g_norm = np.linalg.norm(adam_diff, ord=2)\n            if w_norm > 0 and g_norm > 0:\n                trust_ratio = w_norm / g_norm\n            else:\n                trust_ratio = 1.0\n\n            param = x - learning_rate * trust_ratio * (adam_diff + wd * x)\n            return (param, m, v)\n\n        for i in range(train_iters):\n            (x, mt, vt) = np_train_one_iter(i, random_grad_seq[i])\n        return x\n\n    of_res = train_by_oneflow().numpy()\n    np_res = train_by_numpy()\n\n    test_case.assertTrue(\n        np.allclose(of_res.flatten(), np_res.flatten(), rtol=1e-3, atol=1e-3)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLamb(flow.unittest.TestCase):\n    def test_lamb(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\"]\n        if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            arg_dict[\"device\"] = [\"cpu\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [0.1, 1e-3]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"betas\"] = [(0.99, 0.9)]\n        arg_dict[\"weight_decay\"] = [0.001, 0.1]\n        arg_dict[\"eps\"] = [1e-6]\n        arg_dict[\"do_bias_correction\"] = [True, False]\n        arg_dict[\"adam_w_mode\"] = [True, False]\n        # NOTE(l1aoxingyu): max_norm = -1 means no clip grad\n        arg_dict[\"clip_grad_max_norm\"] = [-1, 0.0, 0.5, 1.0]\n        arg_dict[\"clip_grad_norm_type\"] = [\"inf\", \"-inf\", 0.0, 1.0, 2.0, 3.5]\n        arg_dict[\"reload_state_step\"] = [5]\n        arg_dict[\"save_load_by_pickle\"] = [False, True]\n        arg_dict[\"contiguous_params\"] = [False, True]\n\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_lamb(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_optim_lbfgs.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport tempfile\nimport unittest\nfrom collections import OrderedDict\nimport random as random_util\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import random_device, random_bool\nimport oneflow as flow\nfrom oneflow.nn.parameter import Parameter\nfrom collections import defaultdict\n\n\ndef _quadratic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):\n    if bounds is not None:\n        xmin_bound, xmax_bound = bounds\n    else:\n        xmin_bound, xmax_bound = (x1, x2) if x1 < x2 else (x2, x1)\n    if x1 == 0:\n        t_new = -(g1 * (x2 ** 2)) / (2 * (f2 - f1 - g1 * x2))\n    else:\n        a = -(f1 - f2 - g1 * (x1 - x2)) / ((x1 - x2) ** 2)\n        t_new = x1 - g1 / (2 * a)\n    return min(xmax_bound, max(xmin_bound, t_new))\n\n\ndef _strong_wolfe(\n    eval_closure, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25\n):\n\n    d_norm = max(map(abs, d))\n    g = np.copy(g)\n    f_new, g_new = eval_closure(x, t, d)\n    ls_func_evals = 1\n    gtd_new = g_new.dot(d)\n\n    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd\n    done = False\n    ls_iter = 0\n    while ls_iter < max_ls:\n        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new > f_prev):\n            search_area = [t_prev, t]\n            search_area_f = [f_prev, f_new]\n            search_area_g = [g_prev, np.copy(g_new)]\n            search_area_gtd = [gtd_prev, gtd_new]\n            break\n\n        if abs(gtd_new) <= -c2 * gtd:\n            search_area = [t]\n            search_area_f = [f_new]\n            search_area_g = [g_new]\n            done = True\n            break\n\n        if gtd_new >= 0:\n            search_area = [t_prev, t]\n            search_area_f = [f_prev, f_new]\n            search_area_g = [g_prev, np.copy(g_new)]\n            search_area_gtd = [gtd_prev, gtd_new]\n\n        min_step = t + 0.01 * (t - t_prev)\n        max_step = t * 10\n        tmp = t\n        t = _quadratic_interpolate(\n            t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step)\n        )\n        t_prev = tmp\n        f_prev = f_new\n        g_prev = np.copy(g_new)\n        gtd_prev = gtd_new\n        f_new, g_new = eval_closure(x, t, d)\n        ls_func_evals += 1\n        gtd_new = g_new.dot(d)\n        ls_iter += 1\n    if ls_iter == max_ls:\n        search_area = [0, t]\n        search_area_f = [f, f_new]\n        search_area_g = [g, g_new]\n\n    # zoom\n    low_pos, high_pos = (0, 1) if search_area_f[0] <= search_area_f[-1] else (1, 0)\n    while not done and ls_iter < max_ls:\n\n        if abs(search_area[1] - search_area[0]) * d_norm < tolerance_change:\n            break\n\n        t = _quadratic_interpolate(\n            search_area[0],\n            search_area_f[0],\n            search_area_gtd[0],\n            search_area[1],\n            search_area_f[1],\n            
search_area_gtd[1],\n        )\n\n        f_new, g_new = eval_closure(x, t, d)\n        ls_func_evals += 1\n        gtd_new = g_new.dot(d)\n        ls_iter += 1\n\n        if f_new > (f + c1 * t * gtd) or f_new >= search_area_f[low_pos]:\n            # Armijo condition not satisfied, or not lower than the lowest point\n            search_area[high_pos] = t\n            search_area_f[high_pos] = f_new\n            search_area_g[high_pos] = np.copy(g_new)\n            search_area_gtd[high_pos] = gtd_new\n            low_pos, high_pos = (\n                (0, 1) if search_area_f[0] <= search_area_f[1] else (1, 0)\n            )\n        else:\n            if abs(gtd_new) <= -c2 * gtd:\n                # strong Wolfe conditions satisfied\n                done = True\n            elif gtd_new * (search_area[high_pos] - search_area[low_pos]) >= 0:\n                # the old low end becomes the new high end\n                search_area[high_pos] = search_area[low_pos]\n                search_area_f[high_pos] = search_area_f[low_pos]\n                search_area_g[high_pos] = search_area_g[low_pos]\n                search_area_gtd[high_pos] = search_area_gtd[low_pos]\n\n            # the accepted trial point becomes the new low end\n            search_area[low_pos] = t\n            search_area_f[low_pos] = f_new\n            search_area_g[low_pos] = np.copy(g_new)\n            search_area_gtd[low_pos] = gtd_new\n\n    t = search_area[low_pos]\n    f_new = search_area_f[low_pos]\n    g_new = search_area_g[low_pos]\n    return f_new, g_new, t, ls_func_evals\n\n\ndef compare_with_numpy_lbfgs(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    train_iters,\n    max_iter,\n    max_eval,\n    tolerance_grad,\n    tolerance_change,\n    history_size,\n    line_search_fn,\n    reload_state_step,\n    save_load_by_pickle,\n    contiguous_params,\n    tensor_num,\n    use_float64,\n):\n    random_grad_seq = []\n    init_value_seq = []\n    if use_float64:\n        npType = np.float64\n        flowType = flow.float64\n        flow.set_default_tensor_type(flow.DoubleTensor)\n    else:\n        npType = np.float32\n        flowType = flow.float32\n        flow.set_default_tensor_type(flow.FloatTensor)\n    for _ in range(tensor_num):\n        init_value_seq.append(np.random.uniform(size=x_shape).astype(npType))\n    for _ in range(tensor_num):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(npType))\n\n    def train_by_oneflow():\n        x = []\n        for i in range(tensor_num):\n            x.append(\n                Parameter(\n                    flow.tensor(\n                        init_value_seq[i], device=flow.device(device), dtype=flowType\n                    )\n                )\n            )\n\n        lbfgs = flow.optim.LBFGS(\n            [{\"params\": x}],\n            lr=learning_rate,\n            max_iter=max_iter,\n            max_eval=max_eval,\n            tolerance_grad=tolerance_grad,\n            tolerance_change=tolerance_change,\n            history_size=history_size,\n            line_search_fn=line_search_fn,\n            contiguous_params=contiguous_params,\n        )\n\n        def compute_loss(grad):\n            loss = 0.0\n            for i in range(tensor_num):\n                grad_tensor = flow.tensor(\n                    grad[i],\n                    dtype=flowType,\n                    requires_grad=False,\n                    device=flow.device(device),\n                )\n                loss += flow.sum(x[i] * x[i] * grad_tensor)\n            loss.backward()\n            return loss\n\n        def train_one_iter(grad):\n            def closure():\n                lbfgs.zero_grad()\n                loss = compute_loss(grad)\n                return loss\n\n            return lbfgs.step(closure)\n\n        for i in range(train_iters):\n      
      train_one_iter(random_grad_seq)\n            if i == reload_state_step:\n                state_dict = lbfgs.state_dict()\n                lbfgs = flow.optim.LBFGS(\n                    [{\"params\": x,}], contiguous_params=contiguous_params\n                )\n                if save_load_by_pickle:\n                    with tempfile.NamedTemporaryFile() as f:\n                        flow.save(state_dict, f.name)\n                        state_dict = flow.load(f.name)\n                lbfgs.load_state_dict(state_dict)\n        return x\n\n    def train_by_numpy():\n        def compute_loss(param, grad):\n            loss = 0.0\n            loss += np.sum(param * param * grad)\n            return loss\n\n        x = np.concatenate(init_value_seq)\n\n        def np_train_one_iter(x, state, init_grad):\n            flat_grad = 2 * x * init_grad\n            if max(map(abs, flat_grad)) <= tolerance_grad:\n                return x\n            loss = compute_loss(x, init_grad)\n            current_evals = 1\n            state[\"func_evals\"] += 1\n            d = state.get(\"d\")\n            t = state.get(\"t\")\n            old_diffs = state.get(\"old_diffs\")\n            old_step_size = state.get(\"old_step_size\")\n            ro = state.get(\"ro\")\n            H_diag = state.get(\"H_diag\")\n            prev_flat_grad = state.get(\"prev_flat_grad\")\n            prev_loss = state.get(\"prev_loss\")\n            n_iter = 0\n            while n_iter < max_iter:\n                n_iter += 1\n                state[\"n_iter\"] += 1\n                if state[\"n_iter\"] == 1:\n                    d = -flat_grad\n                    old_diffs = []\n                    old_step_size = []\n                    ro = []\n                    H_diag = 1\n                else:\n                    y = flat_grad - prev_flat_grad\n                    s = d * t\n                    ys = y.dot(s)\n                    if ys > 1e-10:\n                        if len(old_diffs) == history_size:\n                            old_diffs.pop(0)\n                            old_step_size.pop(0)\n                            ro.pop(0)\n                        old_diffs.append(y)\n                        old_step_size.append(s)\n                        ro.append(1.0 / ys)\n                        H_diag = ys / y.dot(y)\n                    num_old = len(old_diffs)\n                    if \"alpha\" not in state:\n                        state[\"alpha\"] = [None] * history_size\n                    alpha = state[\"alpha\"]\n\n                    q = -flat_grad\n                    for i in range(num_old - 1, -1, -1):\n                        alpha[i] = old_step_size[i].dot(q) * ro[i]\n                        q += old_diffs[i] * -alpha[i]\n                    d = q * H_diag\n                    for i in range(num_old):\n                        beta_i = old_diffs[i].dot(d) * ro[i]\n                        d += old_step_size[i] * (alpha[i] - beta_i)\n\n                prev_flat_grad = np.copy(flat_grad)\n                prev_loss = loss\n                if state[\"n_iter\"] == 1:\n                    t = min(1.0, 1.0 / np.sum(np.abs(flat_grad))) * learning_rate\n                else:\n                    t = learning_rate\n                gtd = flat_grad.dot(d)\n                if gtd > -tolerance_change:\n                    break\n\n                ls_func_evals = 0\n                if line_search_fn is None:\n                    x += t * d\n                    if n_iter != max_iter:\n                        loss = 
float(compute_loss(x, init_grad))\n                        ls_func_evals = 1\n                        flat_grad = 2 * x * init_grad\n                else:\n                    assert (\n                        line_search_fn == \"strong_wolfe\"\n                    ), \"only strong_wolfe is expected\"\n                    init_param = np.copy(x)\n\n                    def eval_func(x, t, d):\n                        return (\n                            compute_loss(x + t * d, init_grad),\n                            2 * (x + t * d) * init_grad,\n                        )\n\n                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(\n                        eval_func, init_param, t, d, loss, flat_grad, gtd\n                    )\n                    x += t * d\n                current_evals += ls_func_evals\n                state[\"func_evals\"] += ls_func_evals\n                if n_iter == max_iter:\n                    break\n\n                if current_evals >= max_eval:\n                    break\n\n                if np.max(np.abs(flat_grad)) <= tolerance_grad:\n                    break\n\n                if np.max(np.abs(d * t)) <= tolerance_change:\n                    break\n\n                if abs(loss - prev_loss) < tolerance_change:\n                    break\n            state[\"d\"] = d\n            state[\"t\"] = t\n            state[\"old_diffs\"] = old_diffs\n            state[\"old_step_size\"] = old_step_size\n            state[\"ro\"] = ro\n            state[\"prev_flat_grad\"] = prev_flat_grad\n            state[\"prev_loss\"] = prev_loss\n            state[\"H_diag\"] = H_diag\n            return x\n\n        state = defaultdict(dict)\n        state.setdefault(\"func_evals\", 0)\n        state.setdefault(\"n_iter\", 0)\n        for _ in range(0, train_iters):\n            x = np_train_one_iter(x, state, np.concatenate(random_grad_seq))\n        return x\n\n    oneflow_res = flow.cat(train_by_oneflow(), 0)\n    numpy_res = train_by_numpy()\n    test_case.assertTrue(\n        np.allclose(\n            oneflow_res.numpy().flatten(), numpy_res.flatten(), rtol=0.01, atol=0.01,\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLBFGS(flow.unittest.TestCase):\n    def test_lbfgs_numpy(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [random_device().value()]\n        arg_dict[\"x_shape\"] = [10, 20]\n        arg_dict[\"learning_rate\"] = [0.01]\n        arg_dict[\"train_iters\"] = [20]\n        arg_dict[\"max_iter\"] = [20]\n        arg_dict[\"max_eval\"] = [25]\n        arg_dict[\"tolerance_grad\"] = [1e-7]\n        arg_dict[\"tolerance_change\"] = [1e-9]\n        arg_dict[\"history_size\"] = [100]\n        arg_dict[\"line_search_fn\"] = [None, \"strong_wolfe\"]\n        arg_dict[\"reload_state_step\"] = [5]\n        arg_dict[\"save_load_by_pickle\"] = [random_bool().value()]\n        arg_dict[\"contiguous_params\"] = [random_bool().value()]\n        arg_dict[\"tensor_num\"] = [3, 4, 7]\n        arg_dict[\"use_float64\"] = [True, False]\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_lbfgs(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_optim_rmsprop.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport tempfile\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\nfrom oneflow.nn.parameter import Parameter\n\n\ndef compare_with_numpy_rmsprop(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    momentum,\n    train_iters,\n    alpha,\n    eps,\n    weight_decay,\n    centered,\n    reload_state_step,\n    save_load_by_pickle,\n    contiguous_params,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    def train_by_oneflow():\n        x = Parameter(flow.Tensor(init_value, device=flow.device(device)))\n        param_list = list()\n        param_list.append(x)\n        rmsprop = flow.optim.RMSprop(\n            [\n                {\n                    \"params\": param_list,\n                    \"lr\": learning_rate,\n                    \"alpha\": alpha,\n                    \"eps\": eps,\n                    \"weight_decay\": weight_decay,\n                    \"momentum\": momentum,\n                    \"centered\": centered,\n                    \"contiguous_params\": contiguous_params,\n                }\n            ]\n        )\n\n        def train_one_iter(grad):\n            grad_tensor = flow.tensor(\n                grad,\n                dtype=flow.float32,\n                requires_grad=False,\n                device=flow.device(device),\n            )\n            loss = flow.sum(x * grad_tensor)\n            loss.backward()\n            rmsprop.step()\n            rmsprop.zero_grad()\n\n        for i in range(train_iters):\n            train_one_iter(random_grad_seq[i])\n            if i == reload_state_step:\n                state_dict = rmsprop.state_dict()\n                rmsprop = flow.optim.RMSprop([x], contiguous_params=contiguous_params)\n                if save_load_by_pickle:\n                    with tempfile.NamedTemporaryFile() as f:\n                        flow.save(state_dict, f.name)\n                        state_dict = flow.load(f.name)\n                rmsprop.load_state_dict(state_dict)\n        return x\n\n    def train_by_numpy():\n        x = init_value\n        r = np.zeros_like(x)\n        v = np.zeros_like(x)\n        g = np.zeros_like(x)\n\n        def train_one_iter(grad):\n\n            grad = grad + weight_decay * x\n            r_ = alpha * r + (1 - alpha) * grad * grad\n            if centered:\n                g_ = alpha * g + (1 - alpha) * grad\n                v_ = momentum * v + learning_rate / np.sqrt(r_ - g_ * g_ + eps) * grad\n            else:\n                g_ = g\n                v_ = momentum * v + learning_rate / np.sqrt(r_ + eps) * grad\n            param = x - v_\n 
           return (param, r_, g_, v_)\n\n        for i in range(train_iters):\n            (x, r, g, v) = train_one_iter(random_grad_seq[i])\n        return x\n\n    oneflow_res = train_by_oneflow().numpy()\n    numpy_res = train_by_numpy()\n    test_case.assertTrue(\n        np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=2e-3, atol=2e-3)\n    )\n\n\ndef compare_with_numpy_rmsprop_clip_grad(\n    test_case,\n    device,\n    x_shape,\n    learning_rate,\n    momentum,\n    train_iters,\n    alpha,\n    eps,\n    weight_decay,\n    centered,\n    clip_grad_max_norm,\n    clip_grad_norm_type,\n    reload_state_step,\n    save_load_by_pickle,\n    contiguous_params,\n):\n    random_grad_seq = []\n    for _ in range(train_iters):\n        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n    init_value = np.random.uniform(size=x_shape).astype(np.float32)\n\n    def train_by_oneflow():\n        x = Parameter(flow.Tensor(init_value, device=flow.device(device)))\n        param_list = list()\n        param_list.append(x)\n        rmsprop = flow.optim.RMSprop(\n            [\n                {\n                    \"params\": param_list,\n                    \"lr\": learning_rate,\n                    \"alpha\": alpha,\n                    \"eps\": eps,\n                    \"weight_decay\": weight_decay,\n                    \"momentum\": momentum,\n                    \"centered\": centered,\n                    \"clip_grad_max_norm\": clip_grad_max_norm,\n                    \"clip_grad_norm_type\": clip_grad_norm_type,\n                    \"contiguous_params\": contiguous_params,\n                }\n            ]\n        )\n\n        def train_one_iter(grad):\n            grad_tensor = flow.tensor(\n                grad,\n                dtype=flow.float32,\n                requires_grad=False,\n                device=flow.device(device),\n            )\n            loss = flow.sum(x * grad_tensor)\n            loss.backward()\n            rmsprop.clip_grad()\n            rmsprop.step()\n            rmsprop.zero_grad()\n\n        for i in range(train_iters):\n            train_one_iter(random_grad_seq[i])\n            if i == reload_state_step:\n                state_dict = rmsprop.state_dict()\n                rmsprop = flow.optim.RMSprop([x], contiguous_params=contiguous_params)\n                if save_load_by_pickle:\n                    with tempfile.NamedTemporaryFile() as f:\n                        flow.save(state_dict, f.name)\n                        state_dict = flow.load(f.name)\n                rmsprop.load_state_dict(state_dict)\n        return x\n\n    def train_by_numpy():\n        x = init_value\n        r = np.zeros_like(x)\n        v = np.zeros_like(x)\n        g = np.zeros_like(x)\n\n        def train_one_iter(grad):\n            total_norm, grad = clip_grad_norm_np(\n                grad, clip_grad_max_norm, clip_grad_norm_type\n            )\n            grad = grad + weight_decay * x\n            r_ = alpha * r + (1 - alpha) * grad * grad\n            if centered:\n                g_ = alpha * g + (1 - alpha) * grad\n                v_ = momentum * v + learning_rate / np.sqrt(r_ - g_ * g_ + eps) * grad\n            else:\n                g_ = g\n                v_ = momentum * v + learning_rate / np.sqrt(r_ + eps) * grad\n            param = x - v_\n            return (param, r_, g_, v_)\n\n        for i in range(train_iters):\n            (x, r, g, v) = train_one_iter(random_grad_seq[i])\n        return x\n\n    oneflow_res = 
train_by_oneflow().numpy()\n    numpy_res = train_by_numpy()\n    test_case.assertTrue(\n        np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=2e-3, atol=2e-3)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRMSProp(flow.unittest.TestCase):\n    def test_rmsprop(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1]\n        arg_dict[\"momentum\"] = [0.0]\n        arg_dict[\"train_iters\"] = [2]\n        arg_dict[\"alpha\"] = [0.9, 0.99]\n        arg_dict[\"eps\"] = [1e-08, 1e-05]\n        arg_dict[\"weight_decay\"] = [0.1, 0.99]\n        arg_dict[\"centered\"] = [False, True]\n        arg_dict[\"reload_state_step\"] = [5]  # save and load optim state\n        arg_dict[\"save_load_by_pickle\"] = [False, True]\n        arg_dict[\"contiguous_params\"] = [True, False]\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_rmsprop(test_case, *arg)\n\n    def test_rmsprop_clip_grad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"learning_rate\"] = [1]\n        arg_dict[\"momentum\"] = [0.0]\n        arg_dict[\"train_iters\"] = [2]\n        arg_dict[\"alpha\"] = [0.9, 0.99]\n        arg_dict[\"eps\"] = [1e-08, 1e-05]\n        arg_dict[\"weight_decay\"] = [0.1, 0.99]\n        arg_dict[\"centered\"] = [False, True]\n        arg_dict[\"clip_grad_max_norm\"] = [0, 0.5, 1.0]\n        arg_dict[\"clip_grad_norm_type\"] = [\"inf\", \"-inf\", 0.0, 1.0, 2.0, 3.5]\n        arg_dict[\"reload_state_step\"] = [5]  # save and load optim state\n        arg_dict[\"save_load_by_pickle\"] = [False, True]\n        arg_dict[\"contiguous_params\"] = [False, True]\n        for arg in GenArgList(arg_dict):\n            compare_with_numpy_rmsprop_clip_grad(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_optim_sgd.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nimport tempfile\nimport os\nimport random as random_util\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgDict\nfrom oneflow.test_utils.automated_test_util import random_bool, random_device\nfrom optimizer_test_util import clip_grad_norm_np\n\nimport oneflow as flow\nfrom oneflow.nn.parameter import Parameter\n\n\ndef compare_with_numpy_sgd(\n    test_case,\n    device,\n    x_shape,\n    momentum,\n    dampening,\n    nesterov,\n    maximize,\n    weight_decay,\n    learning_rate,\n    train_iters,\n    reload_state_step,\n    save_load_by_pickle,\n    contiguous_params,\n    fused,\n    tensor_num,\n):\n    random_grad_seq = []\n    init_value_seq = []\n\n    for i in range(tensor_num):\n        init_value_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n\n    for _ in range(train_iters):\n        random_grad_seq_per_iter = []\n        for i in range(tensor_num):\n            random_grad_seq_per_iter.append(\n                np.random.uniform(size=x_shape).astype(np.float32)\n            )\n        random_grad_seq.append(random_grad_seq_per_iter)\n\n    def train_by_oneflow():\n        x = []\n        for i in range(tensor_num):\n            x.append(\n                Parameter(flow.Tensor(init_value_seq[i], device=flow.device(device)))\n            )\n\n        sgd = flow.optim.SGD(\n            [{\"params\": x, \"lr\": learning_rate, \"weight_decay\": weight_decay,}],\n            momentum=momentum,\n            dampening=dampening,\n            nesterov=nesterov,\n            maximize=maximize,\n            contiguous_params=contiguous_params,\n            fused=fused,\n        )\n\n        def train_one_iter(grad):\n            loss = 0.0\n            for i in range(tensor_num):\n                grad_tensor = flow.tensor(\n                    grad[i],\n                    dtype=flow.float32,\n                    requires_grad=False,\n                    device=flow.device(device),\n                )\n                loss += flow.sum(x[i] * grad_tensor)\n            loss.backward()\n            sgd.step()\n            sgd.zero_grad()\n\n        for i in range(train_iters):\n            train_one_iter(random_grad_seq[i])\n            # test state_dict/load_state_dict\n            if i == reload_state_step:\n                state_dict = sgd.state_dict()\n                sgd = flow.optim.SGD(x, contiguous_params=contiguous_params)\n                if save_load_by_pickle:\n                    with tempfile.NamedTemporaryFile() as f:\n                        flow.save(state_dict, f.name)\n                        state_dict = flow.load(f.name)\n                sgd.load_state_dict(state_dict)\n        return x\n\n    def train_by_numpy(tensor_idx):\n        x = init_value_seq[tensor_idx]\n        vt = np.zeros_like(x)\n\n        def train_one_iter(grad):\n            grad = grad + 
weight_decay * x\n            if momentum > 0.0:\n                next_momentum = momentum * vt + (1 - dampening) * grad\n                v = next_momentum\n\n                if nesterov:\n                    grad += momentum * next_momentum\n                else:\n                    grad = next_momentum\n\n                alpha = -learning_rate\n                if maximize:\n                    alpha = learning_rate\n                next_model = x + alpha * grad\n                param = next_model\n            else:\n                v = learning_rate * grad\n                param = x - v\n            return (param, v)\n\n        for i in range(train_iters):\n            (x, vt) = train_one_iter(random_grad_seq[i][tensor_idx])\n        return x\n\n    oneflow_res = train_by_oneflow()\n    numpy_res = []\n    for i in range(tensor_num):\n        numpy_res.append(train_by_numpy(i))\n\n    for i in range(tensor_num):\n        test_case.assertTrue(\n            np.allclose(\n                oneflow_res[i].numpy().flatten(),\n                numpy_res[i].flatten(),\n                rtol=0.0001,\n                atol=0.0001,\n            )\n        )\n\n\ndef compare_with_numpy_sgd_clip_grad(\n    test_case,\n    device,\n    x_shape,\n    momentum,\n    dampening,\n    nesterov,\n    maximize,\n    weight_decay,\n    learning_rate,\n    clip_grad_max_norm,\n    clip_grad_norm_type,\n    train_iters,\n    reload_state_step,\n    save_load_by_pickle,\n    contiguous_params,\n    fused,\n    tensor_num,\n):\n    random_grad_seq = []\n    init_value_seq = []\n\n    for i in range(tensor_num):\n        init_value_seq.append(np.random.uniform(size=x_shape).astype(np.float32))\n\n    for _ in range(train_iters):\n        random_grad_seq_per_iter = []\n        for i in range(tensor_num):\n            random_grad_seq_per_iter.append(\n                np.random.uniform(size=x_shape).astype(np.float32)\n            )\n        random_grad_seq.append(random_grad_seq_per_iter)\n\n    def train_by_oneflow():\n        x = []\n        for i in range(tensor_num):\n            x.append(\n                Parameter(flow.Tensor(init_value_seq[i], device=flow.device(device)))\n            )\n\n        sgd = flow.optim.SGD(\n            [\n                {\n                    \"params\": x,\n                    \"lr\": learning_rate,\n                    \"dampening\": dampening,\n                    \"nesterov\": nesterov,\n                    \"maximize\": maximize,\n                    \"weight_decay\": weight_decay,\n                    \"clip_grad_max_norm\": clip_grad_max_norm,\n                    \"clip_grad_norm_type\": clip_grad_norm_type,\n                }\n            ],\n            momentum=momentum,\n            dampening=dampening,\n            nesterov=nesterov,\n            maximize=maximize,\n            contiguous_params=contiguous_params,\n            fused=fused,\n        )\n\n        def train_one_iter(grad):\n            loss = 0.0\n            for i in range(tensor_num):\n                grad_tensor = flow.tensor(\n                    grad[i],\n                    dtype=flow.float32,\n                    requires_grad=False,\n                    device=flow.device(device),\n                )\n                loss += flow.sum(x[i] * grad_tensor)\n            loss.backward()\n            sgd.clip_grad()\n            sgd.step()\n            sgd.zero_grad()\n\n        for i in range(train_iters):\n            train_one_iter(random_grad_seq[i])\n            # test state_dict/load_state_dict\n   
         if i == reload_state_step:\n                state_dict = sgd.state_dict()\n                sgd = flow.optim.SGD(x, contiguous_params=contiguous_params)\n                if save_load_by_pickle:\n                    with tempfile.NamedTemporaryFile() as f:\n                        flow.save(state_dict, f.name)\n                        state_dict = flow.load(f.name)\n                sgd.load_state_dict(state_dict)\n        return x\n\n    def train_by_numpy():\n        x = init_value_seq\n        vt = np.zeros_like(x)\n\n        def train_one_iter(grad):\n            total_norm, grad = clip_grad_norm_np(\n                grad, clip_grad_max_norm, clip_grad_norm_type\n            )\n\n            for i in range(tensor_num):\n                grad[i] = grad[i] + weight_decay * x[i]\n                if momentum > 0.0:\n                    next_momentum = momentum * vt[i] + (1 - dampening) * grad[i]\n                    vt[i] = next_momentum\n\n                    if nesterov:\n                        grad[i] += momentum * next_momentum\n                    else:\n                        grad[i] = next_momentum\n\n                    alpha = -learning_rate\n                    if maximize:\n                        alpha = learning_rate\n                    x[i] = x[i] + alpha * grad[i]\n                else:\n                    vt[i] = learning_rate * grad[i]\n                    x[i] = x[i] - vt[i]\n\n        for i in range(train_iters):\n            train_one_iter(random_grad_seq[i])\n        return x\n\n    oneflow_res = train_by_oneflow()\n    numpy_res = train_by_numpy()\n\n    for i in range(tensor_num):\n        test_case.assertTrue(\n            np.allclose(\n                oneflow_res[i].numpy().flatten(),\n                numpy_res[i].flatten(),\n                rtol=0.0001,\n                atol=0.0001,\n            )\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestOptimizers(flow.unittest.TestCase):\n    def test_sgd(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [random_device().value()]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"momentum\"] = [0.0, 0.9]\n        arg_dict[\"dampening\"] = [0.0, 0.9]\n        arg_dict[\"nesterov\"] = [random_bool().value()]\n        arg_dict[\"maximize\"] = [random_bool().value()]\n        arg_dict[\"weight_decay\"] = [0.0, 0.9]\n        arg_dict[\"learning_rate\"] = [1, 0.1]\n        arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"reload_state_step\"] = [5]  # save and load optim state\n        arg_dict[\"save_load_by_pickle\"] = [random_bool().value()]\n        arg_dict[\"contiguous_params\"] = [random_bool().value()]\n        arg_dict[\"fused\"] = [random_bool().value()]\n        arg_dict[\"tensor_num\"] = [1, 4]\n        for arg in GenArgDict(arg_dict):\n            compare_with_numpy_sgd(test_case, **arg)\n\n    def test_sgd_clip_grad(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [random_device().value()]\n        arg_dict[\"x_shape\"] = [(10,)]\n        arg_dict[\"momentum\"] = [0.0, 0.9]\n        arg_dict[\"dampening\"] = [0.0, 0.9]\n        arg_dict[\"nesterov\"] = [random_bool().value()]\n        arg_dict[\"maximize\"] = [random_bool().value()]\n        arg_dict[\"weight_decay\"] = [0.0, 0.9]\n        arg_dict[\"learning_rate\"] = [1, 0.1]\n        arg_dict[\"clip_grad_max_norm\"] = [0, 0.5, 1.0]\n        arg_dict[\"clip_grad_norm_type\"] = random_util.sample(\n            [\"inf\", \"-inf\", 0.0, 1.0, 2.0, 3.5], k=3\n        )\n       
 arg_dict[\"train_iters\"] = [10]\n        arg_dict[\"reload_state_step\"] = [5]  # save and load optim state\n        arg_dict[\"save_load_by_pickle\"] = [random_bool().value()]\n        arg_dict[\"contiguous_params\"] = [random_bool().value()]\n        arg_dict[\"fused\"] = [random_bool().value()]\n        arg_dict[\"tensor_num\"] = [1, 4]\n        for arg in GenArgDict(arg_dict):\n            compare_with_numpy_sgd_clip_grad(test_case, **arg)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_eager_global_zero_grad_sbp(test_case):\n        x = flow.nn.Parameter(\n            flow.zeros((10,)).to_global(\n                sbp=flow.sbp.broadcast, placement=flow.placement(\"cuda\", [0])\n            )\n        )\n        x.grad = flow.ones_like(x)\n        t = x.grad\n        test_case.assertEqual(len(t.sbp), 1)\n        test_case.assertEqual(t.sbp[0], flow.sbp.broadcast)\n        optimizer = flow.optim.SGD([x])\n        optimizer.zero_grad()\n        test_case.assertTrue(np.allclose(t.numpy(), 0.0))\n        test_case.assertEqual(len(t.sbp), 1)\n        test_case.assertEqual(t.sbp[0], flow.sbp.partial_sum)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_pairwise_distance.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestPairwiseDistance(flow.unittest.TestCase):\n    @autotest(n=3)\n    def test_pairwise_distance_module_with_random_data(test_case):\n        device = random_device()\n        a = random_tensor(ndim=2, dim0=10, dim1=128).to(device)\n        b = random_tensor(ndim=2, dim0=10, dim1=128).to(device)\n        cos = torch.nn.PairwiseDistance(p=2, eps=1e-6).to(device)\n        cos.train(random())\n        output = cos(a, b)\n        return output\n\n    @autotest(n=3)\n    def test_pairwise_distance_module_with_nonequal_dim_random_data(test_case):\n        device = random_device()\n        a = random_tensor(ndim=1, dim0=128).to(device)\n        b = random_tensor(ndim=2, dim0=10, dim1=128).to(device)\n        cos = torch.nn.PairwiseDistance(p=2, eps=1e-6).to(device)\n        cos.train(random())\n        output = cos(a, b)\n        return output\n\n    @autotest(n=3)\n    def test_pairwise_distance_functional_with_random_data(test_case):\n        device = random_device()\n        a = random_tensor(ndim=2, dim0=10, dim1=128).to(device)\n        b = random_tensor(ndim=2, dim0=10, dim1=128).to(device)\n        output = torch.nn.functional.pairwise_distance(a, b, p=2, eps=1e-6)\n        return output\n\n    @autotest(n=3)\n    def test_pairwise_distance_functional_with_nonequal_dim_random_data(test_case):\n        device = random_device()\n        a = random_tensor(ndim=1, dim0=128).to(device)\n        b = random_tensor(ndim=2, dim0=10, dim1=128).to(device)\n        output = torch.nn.functional.pairwise_distance(a, b, p=2, eps=1e-6)\n        return output\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_param_group.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestParamGroup(flow.unittest.TestCase):\n    def test_ParamGroup(test_case):\n        parameters = {\n            \"params\": [flow.ones(10), flow.ones(5)],\n            \"lr\": 0.01,\n        }\n        default_options = {\n            \"test_float\": 1e-3,\n            \"test_int\": 6,\n            \"test_list\": [1, 2, 3],\n            \"test_tensor\": flow.ones(10),\n            \"test_str\": \"test\",\n        }\n\n        pg = flow.optim.optimizer.ParamGroup(parameters, default_options)\n\n        test_case.assertEqual(pg[\"test_float\"], 1e-3)\n        test_case.assertEqual(pg[\"test_int\"], 6)\n        test_case.assertTrue(np.array_equal(pg.get(\"test_list\"), [1, 2, 3]))\n        test_case.assertTrue(\n            np.array_equal(pg.get(\"test_tensor\").numpy(), flow.ones(10).numpy())\n        )\n        test_case.assertEqual(pg[\"test_str\"], \"test\")\n        test_case.assertTrue(\"params\" in pg.keys())\n        test_case.assertTrue(\n            np.array_equal(pg[\"params\"][0].numpy(), flow.ones(10).numpy())\n        )\n        test_case.assertTrue(\n            np.array_equal(pg[\"params\"][1].numpy(), flow.ones(5).numpy())\n        )\n        test_case.assertEqual(pg[\"lr\"], 0.01)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_parameters_grouping.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nfrom collections import OrderedDict\n\nimport oneflow as flow\nfrom oneflow.test_utils.test_util import GenArgDict\nfrom oneflow.nn.utils.parameters_grouping import ContiguousParamsGroup as CPG\nfrom oneflow.nn.parameter import Parameter\nimport oneflow.unittest\n\n\ndef np_allclose_with_shape(a, b, *args, **kwargs):\n    return a.shape == b.shape and np.allclose(a, b, *args, **kwargs)\n\n\ndef module_grouping(test_case, device):\n    class Model(flow.nn.Module):\n        def __init__(self):\n            super(Model, self).__init__()\n            dtypes = [flow.float32, flow.float64]\n            for i in range(10):\n                self.register_parameter(\n                    f\"w{i}\",\n                    flow.nn.Parameter(\n                        flow.tensor([i % 2 + 1, i % 2 + 1], dtype=dtypes[i % 2])\n                    ),\n                )\n\n    m = Model().to(device)\n    m.make_contiguous_params_group()\n    cpg = CPG(\n        list(m.parameters())\n        + [flow.tensor([3, 3], dtype=flow.float32, requires_grad=True)]\n    )\n\n    test_case.assertTrue(len(m.cpg.grouped_parameters) == 2)\n    test_case.assertTrue(len(m.cpg.grouped_grads) == 2)\n    test_case.assertTrue(flow.max(m.cpg.grouped_parameters[0]) == 1)\n    test_case.assertTrue(flow.max(m.cpg.grouped_parameters[1]) == 2)\n\n    test_case.assertTrue(len(cpg.grouped_parameters) == 3)\n    test_case.assertTrue(len(cpg.grouped_grads) == 3)\n    test_case.assertTrue(flow.max(cpg.grouped_parameters[0]) == 1)\n    test_case.assertTrue(flow.max(cpg.grouped_parameters[1]) == 2)\n    test_case.assertTrue(flow.max(cpg.grouped_parameters[2]) == 3)\n\n\ndef direct_grouping(test_case, device):\n    x = [\n        Parameter(\n            flow.tensor(\n                [1, 2],\n                device=flow.device(device),\n                dtype=flow.float32,\n                requires_grad=True,\n            )\n        ),\n        Parameter(\n            flow.tensor(\n                [3, 4],\n                device=flow.device(device),\n                dtype=flow.float32,\n                requires_grad=True,\n            )\n        ),\n    ]\n    cpg = CPG([[x[0]], [x[1]]])\n    test_case.assertTrue(len(cpg.grouped_parameters) == 2)\n    test_case.assertTrue(len(cpg.grouped_grads) == 2)\n\n\ndef global_grouping(test_case, device):\n    x = flow.nn.Parameter(\n        flow.zeros((10,), dtype=flow.float32, requires_grad=True).to_global(\n            sbp=flow.sbp.broadcast, placement=flow.placement(device, [0])\n        )\n    )\n    y = flow.nn.Parameter(\n        flow.zeros((10,), dtype=flow.float32, requires_grad=True).to_global(\n            sbp=flow.sbp.split(0), placement=flow.placement(device, [0])\n        )\n    )\n    cpg = CPG([x, y], group_on_current_buffer=False)\n    test_case.assertTrue(len(cpg.grouped_parameters) == 2)\n    
test_case.assertTrue(len(cpg.grouped_grads) == 2)\n\n\ndef multi_module_grad(test_case, device):\n    class Module1(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.w1 = flow.nn.Parameter(flow.Tensor([1, 1]))\n            self.w2 = flow.nn.Parameter(flow.Tensor([1, 1]))\n\n        def forward(self, x):\n            return x * self.w1 * self.w2\n\n    class Module2(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.w1 = flow.nn.Parameter(flow.Tensor([2, 2]))\n            self.w2 = flow.nn.Parameter(flow.Tensor([2, 2]))\n\n        def forward(self, x):\n            return x * self.w1 * self.w2\n\n    m1 = Module1().to(device)\n    m1.make_contiguous_params_group()\n    m2 = Module2().to(device)\n    m2.make_contiguous_params_group()\n    optim1 = flow.optim.SGD(m1.parameters(), lr=1e-2, contiguous_params=True)\n    optim2 = flow.optim.SGD(m2.parameters(), lr=1e-2, contiguous_params=True)\n    x1 = flow.ones([1, 1]).to(device)\n    x2 = flow.ones([2, 2]).to(device)\n    flow.sum(m1(x1)).backward()\n    flow.sum(m2(x2)).backward()\n\n    for p in m1.parameters():\n        test_case.assertTrue(\n            np_allclose_with_shape(p.grad.numpy(), np.array([1.0, 1.0]))\n        )\n\n    for p in m2.parameters():\n        test_case.assertTrue(\n            np_allclose_with_shape(p.grad.numpy(), np.array([4.0, 4.0]))\n        )\n\n\ndef multi_module_lifecycle(test_case, device):\n    class Module1(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.w1 = flow.nn.Parameter(flow.Tensor([1, 1]))\n            self.w2 = flow.nn.Parameter(flow.Tensor([1, 1]))\n\n        def forward(self, x):\n            return x * self.w1 * self.w2\n\n    class Module2(flow.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.w1 = flow.nn.Parameter(flow.Tensor([2, 2]))\n            self.w2 = flow.nn.Parameter(flow.Tensor([2, 2]))\n\n        def forward(self, x):\n            return x * self.w1 * self.w2\n\n    m1 = Module1().to(device)\n    m1.make_contiguous_params_group()\n    m2 = Module2().to(device)\n    m2.make_contiguous_params_group()\n    del m1\n    cpg = CPG(list(m2.parameters()))\n    test_case.assertTrue(len(cpg.grouped_parameters) == 1)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestCPG(flow.unittest.TestCase):\n    def test_cpg(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgDict(arg_dict):\n            device = arg[\"device\"]\n            module_grouping(test_case, device)\n            direct_grouping(test_case, device)\n            global_grouping(test_case, device)\n            multi_module_lifecycle(test_case, device)\n            multi_module_grad(test_case, device)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_parital_fc.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n# TODO: guoran, fix this on multi gpu\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestParitalFC(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    def test_parital_fc(test_case):\n        p = flow.placement.all(\"cuda\")\n        w = flow.randn(\n            50000, 128, placement=p, sbp=flow.sbp.broadcast, requires_grad=True\n        )\n        label = flow.randint(0, 50000, (512,), placement=p, sbp=flow.sbp.broadcast)\n        num_sample = 5000\n        out = flow.distributed_partial_fc_sample(w, label, num_sample)\n        test_case.assertTrue(out[0].shape == flow.Size([512]))\n        test_case.assertTrue(out[1].shape == flow.Size([5000]))\n        test_case.assertTrue(out[2].shape == flow.Size([5000, 128]))\n        # test gradient function\n        sample_weight = out[2]\n        sample_weight.sum().backward()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_pixel_shuffle.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _np_pixel_shuffle(input, h_factor, w_factor):\n    (_batch, _channel, _height, _width) = input.shape\n    assert (\n        _channel % (h_factor * w_factor) == 0\n    ), \"The channels of input tensor must be divisible by (h_upscale_factor * w_upscale_factor)\"\n    _new_c = int(_channel / (h_factor * w_factor))\n    out = np.reshape(input, [_batch, _new_c, h_factor * w_factor, _height, _width])\n    out = np.reshape(out, [_batch, _new_c, h_factor, w_factor, _height, _width])\n    out = np.transpose(out, [0, 1, 4, 2, 5, 3])\n    out = np.reshape(out, [_batch, _new_c, _height * h_factor, _width * w_factor])\n    return out\n\n\ndef _np_pixel_shuffle_grad(input, h_factor, w_factor):\n    (_batch, _new_channel, _height_mul_factor, _width_mul_factor) = input.shape\n    _channel = _new_channel * (h_factor * w_factor)\n    _height = _height_mul_factor // h_factor\n    _width = _width_mul_factor // w_factor\n    out = np.ones(shape=(_batch, _channel, _height, _width))\n    return out\n\n\ndef _test_pixel_shuffle_impl(\n    test_case, device, shape, h_upscale_factor, w_upscale_factor\n):\n    x = np.random.randn(*shape)\n    input = flow.tensor(\n        x, dtype=flow.float32, requires_grad=True, device=flow.device(device)\n    )\n    m = flow.nn.PixelShuffle(\n        h_upscale_factor=h_upscale_factor, w_upscale_factor=w_upscale_factor\n    )\n    m = m.to(device)\n    of_out = m(input)\n    np_out = _np_pixel_shuffle(x, h_upscale_factor, w_upscale_factor)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = _np_pixel_shuffle_grad(np_out, h_upscale_factor, w_upscale_factor)\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestPixelShuffleModule(flow.unittest.TestCase):\n    def test_pixel_shuffle(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_pixel_shuffle_impl]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(2, 144, 5, 5), (11, 144, 1, 1)]\n        arg_dict[\"h_upscale_factor\"] = [2, 3, 4]\n        arg_dict[\"w_upscale_factor\"] = [2, 3, 4]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n        arg_dict[\"shape\"] = [(8, 25, 18, 18), (1, 25, 2, 2)]\n        arg_dict[\"h_upscale_factor\"] = [5]\n        arg_dict[\"w_upscale_factor\"] = [5]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest()\n    def test_pixel_shuffle_with_random_data(test_case):\n        upscale_factor = random().to(int)\n        num_channels 
= upscale_factor * upscale_factor * random().to(int)\n        m = torch.nn.PixelShuffle(upscale_factor=upscale_factor)\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=4, dim1=num_channels).to(device)\n        y = m(x)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_prelu.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestPReLU(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_prelu_4dim_module_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dim1=3).to(device)\n        m = torch.nn.PReLU(\n            num_parameters=3 | nothing(), init=random().to(float) | nothing(),\n        )\n        m.to(device)\n        m.train(random())\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_prelu_4dim_default_alpha_module_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dim1=3).to(device)\n        m = torch.nn.PReLU(init=random().to(float) | nothing(),)\n        m.to(device)\n        m.train(random())\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_prelu_2dim_module_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim1=3).to(device)\n        m = torch.nn.PReLU(\n            num_parameters=3 | nothing(), init=random().to(float) | nothing(),\n        )\n        m.to(device)\n        m.train(random())\n        y = m(x)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_prod.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestReduceProd(flow.unittest.TestCase):\n    @autotest(n=5, check_graph=True)\n    def test_reduce_prod_without_dim(test_case):\n        device = random_device()\n        ndim = random(1, 5).to(int)\n        x = random_tensor(ndim=ndim).to(device)\n        y = torch.prod(x)\n\n        return y\n\n    @autotest(n=5, check_graph=True)\n    def test_reduce_prod_with_dim(test_case):\n        device = random_device()\n        ndim = random(1, 5).to(int)\n        x = random_tensor(ndim=ndim).to(device)\n        dim = random(0, ndim).to(int)\n        y = torch.prod(x, dim)\n        y = torch.exp(y)\n\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_reduce_prod_bool_without_dim(test_case):\n        device = random_device()\n        ndim = random(1, 5).to(int)\n        x = random_tensor(ndim=ndim).to(device=device, dtype=torch.bool)\n        y = torch.prod(x)\n\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_reduce_prod_with_dtype(test_case):\n        device = random_device()\n        ndim = random(1, 5).to(int)\n        x = random_tensor(ndim=ndim, low=1.0, high=4.0, requires_grad=False).to(device)\n        dim = random(0, ndim).to(int)\n        y = torch.prod(x, dim, dtype=torch.int32)\n\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_pruning.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport sys\nfrom collections import OrderedDict\n\nimport numpy as np\nimport tempfile\nimport pickle\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow.nn.utils.prune as prune\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.nn as nn\nimport unittest.mock as mock\nfrom contextlib import contextmanager\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\nclass TestPrune(flow.unittest.TestCase):\n    def test_validate_pruning_amount_init(self):\n        r\"\"\"Test the first util function that validates the pruning\n            amount requested by the user the moment the pruning method\n            is initialized. This test checks that the expected errors are\n            raised whenever the amount is invalid.\n            The original function runs basic type checking + value range checks.\n            It doesn't check the validity of the pruning amount with\n            respect to the size of the tensor to prune. That's left to\n            `_validate_pruning_amount`, tested below.\n            \"\"\"\n        # neither float not int should raise TypeError\n        with self.assertRaises(TypeError):\n            prune._validate_pruning_amount_init(amount=\"I'm a string\")\n\n        # float not in [0, 1] should raise ValueError\n        with self.assertRaises(ValueError):\n            prune._validate_pruning_amount_init(amount=1.1)\n        with self.assertRaises(ValueError):\n            prune._validate_pruning_amount_init(amount=20.0)\n\n        # negative int should raise ValueError\n        with self.assertRaises(ValueError):\n            prune._validate_pruning_amount_init(amount=-10)\n\n        # all these should pass without errors because they're valid amounts\n        prune._validate_pruning_amount_init(amount=0.34)\n        prune._validate_pruning_amount_init(amount=1500)\n        prune._validate_pruning_amount_init(amount=0)\n        prune._validate_pruning_amount_init(amount=0.0)\n        prune._validate_pruning_amount_init(amount=1)\n        prune._validate_pruning_amount_init(amount=1.0)\n        self.assertTrue(True)\n\n    def test_validate_pruning_amount(self):\n        r\"\"\"Tests the second util function that validates the pruning\n        amount requested by the user, this time with respect to the size\n        of the tensor to prune. 
The rationale is that if the pruning amount,\n        converted to absolute value of units to prune, is larger than\n        the number of units in the tensor, then we expect the util function\n        to raise a value error.\n        \"\"\"\n        # if amount is int and amount > tensor_size, raise ValueError\n        with self.assertRaises(ValueError):\n            prune._validate_pruning_amount(amount=20, tensor_size=19)\n\n        # amount is a float so this should not raise an error\n        prune._validate_pruning_amount(amount=0.3, tensor_size=0)\n\n        # this is okay\n        prune._validate_pruning_amount(amount=19, tensor_size=20)\n        prune._validate_pruning_amount(amount=0, tensor_size=0)\n        prune._validate_pruning_amount(amount=1, tensor_size=1)\n        self.assertTrue(True)\n\n    def test_compute_nparams_to_prune(self):\n        r\"\"\"Test that requested pruning `amount` gets translated into the\n        correct absolute number of units to prune.\n        \"\"\"\n        self.assertEqual(prune._compute_nparams_toprune(amount=0, tensor_size=15), 0)\n        self.assertEqual(prune._compute_nparams_toprune(amount=10, tensor_size=15), 10)\n        # if 1 is int, means 1 unit\n        self.assertEqual(prune._compute_nparams_toprune(amount=1, tensor_size=15), 1)\n        # if 1. is float, means 100% of units\n        self.assertEqual(prune._compute_nparams_toprune(amount=1.0, tensor_size=15), 15)\n        self.assertEqual(prune._compute_nparams_toprune(amount=0.4, tensor_size=17), 7)\n\n    def test_random_pruning_sizes(self):\n        r\"\"\"Test that the new parameters and buffers created by the pruning\n        method have the same size as the input tensor to prune. These, in\n        fact, correspond to the pruned version of the tensor itself, its\n        mask, and its original copy, so the size must match.\n        \"\"\"\n        # fixturize test\n        # TODO: add other modules\n        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]\n        names = [\"weight\", \"bias\"]\n\n        for m in modules:\n            for name in names:\n                with self.subTest(m=m, name=name):\n                    original_tensor = getattr(m, name)\n\n                    prune.random_unstructured(m, name=name, amount=0.1)\n                    # mask has the same size as tensor being pruned\n                    self.assertEqual(\n                        original_tensor.size(), getattr(m, name + \"_mask\").size()\n                    )\n                    # 'orig' tensor has the same size as the original tensor\n                    self.assertEqual(\n                        original_tensor.size(), getattr(m, name + \"_orig\").size()\n                    )\n                    # new tensor has the same size as the original tensor\n                    self.assertEqual(original_tensor.size(), getattr(m, name).size())\n\n    def test_random_pruning_orig(self):\n        r\"\"\"Test that original tensor is correctly stored in 'orig'\n        after pruning is applied. 
Important to make sure we don't\n        lose info about the original unpruned parameter.\n        \"\"\"\n        # fixturize test\n        # TODO: add other modules\n        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]\n        names = [\"weight\", \"bias\"]\n\n        for m in modules:\n            for name in names:\n                with self.subTest(m=m, name=name):\n\n                    # tensor prior to pruning\n                    original_tensor = getattr(m, name)\n                    prune.random_unstructured(m, name=name, amount=0.1)\n                    result = flow.sum(\n                        original_tensor - getattr(m, name + \"_orig\")\n                    ).item()\n                    self.assertEqual(result, 0)\n\n    def test_random_pruning_new_weight(self):\n        r\"\"\"Test that module.name now contains a pruned version of\n        the original tensor obtained from multiplying it by the mask.\n        \"\"\"\n        # fixturize test\n        # TODO: add other modules\n        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]\n        names = [\"weight\", \"bias\"]\n\n        for m in modules:\n            for name in names:\n                with self.subTest(m=m, name=name):\n                    # tensor prior to pruning\n                    original_tensor = getattr(m, name)\n                    prune.random_unstructured(m, name=name, amount=0.1)\n                    # weight = weight_orig * weight_mask\n                    weight = getattr(m, name)\n                    weight_orig_mask = getattr(m, name + \"_orig\") * getattr(\n                        m, name + \"_mask\"\n                    ).to(dtype=original_tensor.dtype)\n                    result = flow.sum(weight - weight_orig_mask).item()\n\n                    self.assertEqual(result, 0)\n\n    def test_identity_pruning(self):\n        r\"\"\"Test that a mask of 1s does not change forward or backward.\n        \"\"\"\n        input_ = flow.ones(1, 5)\n        m = nn.Linear(5, 2)\n        y_prepruning = m(input_)  # output prior to pruning\n\n        # compute grad pre-pruning and check it's equal to all ones\n        y_prepruning.sum().backward()\n        old_grad_weight = m.weight.grad.clone()  # don't grab pointer!\n        self.assertEqual(flow.sum(old_grad_weight - flow.ones_like(m.weight)).item(), 0)\n        old_grad_bias = m.bias.grad.clone()\n        self.assertEqual(flow.sum(old_grad_bias - flow.ones_like(m.bias)).item(), 0)\n\n        # remove grads\n        m.zero_grad()\n\n        # force the mask to be made of all 1s\n        prune.identity(m, name=\"weight\")\n\n        # with mask of 1s, output should be identical to no mask\n        y_postpruning = m(input_)\n        self.assertEqual(flow.sum(y_prepruning - y_postpruning).item(), 0)\n\n        # with mask of 1s, grad should be identical to no mask\n        y_postpruning.sum().backward()\n        self.assertEqual(flow.sum(old_grad_weight - m.weight_orig.grad).item(), 0)\n        self.assertEqual(flow.sum(old_grad_bias - m.bias.grad).item(), 0)\n\n        # calling forward twice in a row shouldn't change output\n        y1 = m(input_)\n        y2 = m(input_)\n        self.assertEqual(flow.sum(y1 - y2).item(), 0)\n\n    def test_random_pruning_0perc(self):\n        r\"\"\"Test that a mask of 1s does not change forward or backward.\n        \"\"\"\n        input_ = flow.ones(1, 5)\n        m = nn.Linear(5, 2)\n        y_prepruning = m(input_)  # output prior to pruning\n\n        # compute grad pre-pruning and check it's equal to all 
ones\n        y_prepruning.sum().backward()\n        old_grad_weight = m.weight.grad.clone()  # don't grab pointer!\n        self.assertEqual(flow.sum(old_grad_weight - flow.ones_like(m.weight)).item(), 0)\n        old_grad_bias = m.bias.grad.clone()\n        self.assertEqual(flow.sum(old_grad_bias - flow.ones_like(m.bias)).item(), 0)\n\n        # remove grads\n        m.zero_grad()\n\n        # force the mask to be made of all 1s\n        with mock.patch(\n            \"oneflow.nn.utils.prune.RandomUnstructured.compute_mask\"\n        ) as compute_mask:\n            compute_mask.return_value = flow.ones_like(m.weight)\n            prune.random_unstructured(\n                m, name=\"weight\", amount=0.9\n            )  # amount won't count\n\n        # with mask of 1s, output should be identical to no mask\n        y_postpruning = m(input_)\n        self.assertEqual(flow.sum(y_prepruning - y_postpruning).item(), 0)\n\n        # with mask of 1s, grad should be identical to no mask\n        y_postpruning.sum().backward()\n        self.assertEqual(flow.sum(old_grad_weight - m.weight_orig.grad).item(), 0)\n        self.assertEqual(flow.sum(old_grad_bias - m.bias.grad).item(), 0)\n\n        # calling forward twice in a row shouldn't change output\n        y1 = m(input_)\n        y2 = m(input_)\n        self.assertEqual(flow.sum(y1 - y2).item(), 0)\n\n    def test_random_pruning(self):\n        input_ = flow.ones(1, 5)\n        m = nn.Linear(5, 2)\n\n        # define custom mask to assign with mock\n        mask = flow.ones_like(m.weight)\n        mask[1, 0] = 0\n        mask[0, 3] = 0\n\n        # check grad is zero for masked weights\n        with mock.patch(\n            \"oneflow.nn.utils.prune.RandomUnstructured.compute_mask\"\n        ) as compute_mask:\n            compute_mask.return_value = mask\n            prune.random_unstructured(m, name=\"weight\", amount=0.9)\n\n        y_postpruning = m(input_)\n        y_postpruning.sum().backward()\n        # weight_orig is the parameter, so it's the tensor that will accumulate the grad\n        self.assertEqual(\n            flow.sum(m.weight_orig.grad - mask).item(), 0\n        )  # all 1s, except for masked units\n        self.assertEqual(flow.sum(m.bias.grad - flow.ones_like(m.bias)).item(), 0)\n\n        # make sure that weight_orig update doesn't modify [1, 0] and [0, 3]\n        old_weight_orig = m.weight_orig.clone()\n        # update weights\n        learning_rate = 1.0\n        for p in m.parameters():\n            p.data.sub_(p.grad.data * learning_rate)\n        # since these are pruned, they should not be updated\n        self.assertEqual(\n            flow.sum(old_weight_orig[1, 0] - m.weight_orig[1, 0]).item(), 0\n        )\n        self.assertEqual(\n            flow.sum(old_weight_orig[0, 3] - m.weight_orig[0, 3]).item(), 0\n        )\n\n    def test_random_pruning_forward(self):\n        r\"\"\"check forward with mask (by hand).\n        \"\"\"\n        input_ = flow.ones(1, 5)\n        m = nn.Linear(5, 2)\n\n        # define custom mask to assign with mock\n        mask = flow.zeros_like(m.weight)\n        mask[1, 0] = 1\n        mask[0, 3] = 1\n\n        with mock.patch(\n            \"oneflow.nn.utils.prune.RandomUnstructured.compute_mask\"\n        ) as compute_mask:\n            compute_mask.return_value = mask\n            prune.random_unstructured(m, name=\"weight\", amount=0.9)\n\n        yhat = m(input_)\n        self.assertTrue(\n            flow.sum(yhat[0, 0] - m.weight_orig[0, 3] - m.bias[0]).item() - 0 < 1e-5\n 
       )\n        self.assertTrue(\n            flow.sum(yhat[0, 1] - m.weight_orig[1, 0] - m.bias[1]).item() - 0 < 1e-5\n        )\n\n    def test_remove_pruning_forward(self):\n        r\"\"\"Remove pruning and check forward is unchanged from previous\n        pruned state.\n        \"\"\"\n        input_ = flow.ones(1, 5)\n        m = nn.Linear(5, 2)\n\n        # define custom mask to assign with mock\n        mask = flow.ones_like(m.weight)\n        mask[1, 0] = 0\n        mask[0, 3] = 0\n\n        # check grad is zero for masked weights\n        with mock.patch(\n            \"oneflow.nn.utils.prune.RandomUnstructured.compute_mask\"\n        ) as compute_mask:\n            compute_mask.return_value = mask\n            prune.random_unstructured(m, name=\"weight\", amount=0.9)\n\n        y_postpruning = m(input_)\n\n        prune.remove(m, \"weight\")\n\n        y_postremoval = m(input_)\n        self.assertEqual(flow.sum(y_postpruning - y_postremoval).item(), 0)\n\n    def test_pruning_id_consistency(self):\n        r\"\"\"Test that pruning doesn't change the id of the parameters, which\n        would otherwise introduce issues with pre-existing optimizers that\n        point to old parameters.\n        \"\"\"\n        m = nn.Linear(5, 2, bias=False)\n\n        tensor_id = id(list(m.parameters())[0])\n\n        prune.random_unstructured(m, name=\"weight\", amount=0.9)\n        self.assertEqual(tensor_id, id(list(m.parameters())[0]))\n\n        prune.remove(m, \"weight\")\n        self.assertEqual(tensor_id, id(list(m.parameters())[0]))\n\n    def test_random_pruning_pickle(self):\n        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]\n        names = [\"weight\", \"bias\"]\n\n        for m in modules:\n            for name in names:\n                with self.subTest(m=m, name=name):\n                    prune.random_unstructured(m, name=name, amount=0.1)\n                    m_new = pickle.loads(pickle.dumps(m))\n                    self.assertIsInstance(m_new, type(m))\n\n    def test_multiple_pruning_calls(self):\n        # if you call pruning twice, the hook becomes a PruningContainer\n        m = nn.Conv3d(2, 2, 2)\n        prune.l1_unstructured(m, name=\"weight\", amount=0.1)\n        weight_mask0 = m.weight_mask  # save it for later sanity check\n\n        # prune again\n        prune.ln_structured(m, name=\"weight\", amount=0.3, n=2, dim=0)\n        hook = next(iter(m._forward_pre_hooks.values()))\n        self.assertIsInstance(hook, oneflow.nn.utils.prune.PruningContainer)\n        # check that container._tensor_name is correctly set no matter how\n        # many pruning methods are in the container\n        self.assertEqual(hook._tensor_name, \"weight\")\n\n        # check that the pruning container has the right length\n        # equal to the number of pruning iters\n        self.assertEqual(len(hook), 2)  # m.weight has been pruned twice\n\n        # check that the entries of the pruning container are of the expected\n        # type and in the expected order\n        self.assertIsInstance(hook[0], oneflow.nn.utils.prune.L1Unstructured)\n        self.assertIsInstance(hook[1], oneflow.nn.utils.prune.LnStructured)\n\n        # check that all entries that are 0 in the 1st mask are 0 in the\n        # 2nd mask too\n        self.assertTrue(flow.all(m.weight_mask[weight_mask0 == 0] == 0))\n\n        # prune again\n        prune.ln_structured(m, name=\"weight\", amount=0.1, n=float(\"inf\"), dim=1)\n        # check that container._tensor_name is correctly set no matter how\n      
  # many pruning methods are in the container\n        hook = next(iter(m._forward_pre_hooks.values()))\n        self.assertEqual(hook._tensor_name, \"weight\")\n\n    def test_pruning_container(self):\n        # create an empty container\n        container = prune.PruningContainer()\n        container._tensor_name = \"test\"\n        self.assertEqual(len(container), 0)\n\n        p = prune.L1Unstructured(amount=2)\n        p._tensor_name = \"test\"\n\n        # test adding a pruning method to a container\n        container.add_pruning_method(p)\n\n        # test error raised if tensor name is different\n        q = prune.L1Unstructured(amount=2)\n        q._tensor_name = \"another_test\"\n        with self.assertRaises(ValueError):\n            container.add_pruning_method(q)\n\n        # test that adding a non-pruning method object to a pruning container\n        # raises a TypeError\n        with self.assertRaises(TypeError):\n            container.add_pruning_method(10)\n        with self.assertRaises(TypeError):\n            container.add_pruning_method(\"ugh\")\n\n    def test_pruning_container_compute_mask(self):\n        r\"\"\"Test `compute_mask` of pruning container with a known `t` and\n        `default_mask`. Indirectly checks that Ln structured pruning is\n        acting on the right axis.\n        \"\"\"\n        # create an empty container\n        container = prune.PruningContainer()\n        container._tensor_name = \"test\"\n\n        # 1) test unstructured pruning\n        # create a new pruning method\n        p = prune.L1Unstructured(amount=2)\n        p._tensor_name = \"test\"\n        # add the pruning method to the container\n        container.add_pruning_method(p)\n\n        # create tensor to be pruned\n        t = flow.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=flow.float32)\n        # create prior mask by hand\n        default_mask = flow.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])\n        # since we are pruning the two lowest magnitude units, the outcome of\n        # the calculation should be this:\n        expected_mask = flow.tensor([[0, 0, 1, 0], [1, 1, 0, 1]], dtype=flow.float32)\n        computed_mask = container.compute_mask(t, default_mask)\n        self.assertEqual(flow.sum(expected_mask - computed_mask).item(), 0)\n\n        # 2) test structured pruning\n        q = prune.LnStructured(amount=1, n=2, dim=0)\n        q._tensor_name = \"test\"\n        container.add_pruning_method(q)\n        # since we are pruning the lowest magnitude one of the two rows, the\n        # outcome of the calculation should be this:\n        expected_mask = flow.tensor([[0, 0, 0, 0], [1, 1, 0, 1]], dtype=flow.float32)\n        computed_mask = container.compute_mask(t, default_mask)\n        self.assertEqual(flow.sum(expected_mask - computed_mask).item(), 0)\n\n        # 3) test structured pruning, along another axis\n        r = prune.LnStructured(amount=1, n=2, dim=1)\n        r._tensor_name = \"test\"\n        container.add_pruning_method(r)\n        # since we are pruning the lowest magnitude of the four columns, the\n        # outcome of the calculation should be this:\n        expected_mask = flow.tensor([[0, 1, 1, 0], [0, 1, 0, 1]], dtype=flow.float32)\n        computed_mask = container.compute_mask(t, default_mask)\n        self.assertEqual(flow.sum(expected_mask - computed_mask).item(), 0)\n\n    def test_l1_unstructured_pruning(self):\n        r\"\"\"Test that l1 unstructured pruning actually removes the lowest\n        entries by l1 norm (by hand). 
It also checks that applying l1\n        unstructured pruning more than once respects the previous mask.\n        \"\"\"\n        m = nn.Linear(4, 2)\n        # modify its weight matrix by hand\n        m.weight = flow.nn.Parameter(\n            flow.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=flow.float32)\n        )\n\n        prune.l1_unstructured(m, \"weight\", amount=2)\n        expected_weight = flow.tensor(\n            [[0, 2, 3, 4], [-4, -3, -2, 0]], dtype=m.weight.dtype\n        )\n        self.assertEqual(flow.sum(expected_weight - m.weight).item(), 0)\n\n        # check that pruning again removes the next two smallest entries\n        prune.l1_unstructured(m, \"weight\", amount=2)\n        expected_weight = flow.tensor(\n            [[0, 0, 3, 4], [-4, -3, 0, 0]], dtype=m.weight.dtype\n        )\n        self.assertEqual(flow.sum(expected_weight - m.weight).item(), 0)\n\n    def test_l1_unstructured_pruning_with_importance_scores(self):\n        r\"\"\"Test that l1 unstructured pruning actually removes the lowest\n        entries of importance scores and not the parameter by l1 norm (by hand).\n        It also checks that applying l1 unstructured pruning more than once\n        respects the previous mask.\n        \"\"\"\n        m = nn.Linear(4, 2)\n        # modify its weight matrix by hand\n        m.weight = flow.nn.Parameter(\n            flow.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=flow.float32)\n        )\n        importance_scores = flow.tensor(\n            [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=flow.float32\n        )\n\n        prune.l1_unstructured(\n            m, \"weight\", amount=2, importance_scores=importance_scores\n        )\n        expected_weight = flow.tensor(\n            [[1, 2, 0, 4], [-4, 0, -2, -1]], dtype=m.weight.dtype\n        )\n        self.assertEqual(flow.sum(expected_weight - m.weight).item(), 0)\n\n        # check that pruning again removes two entries of m.weight that are colocated with\n        # the next two smallest absolute values of importance scores.\n        prune.l1_unstructured(\n            m, \"weight\", amount=2, importance_scores=importance_scores\n        )\n        expected_weight = flow.tensor(\n            [[1, 0, 0, 4], [-4, 0, 0, -1]], dtype=m.weight.dtype\n        )\n        self.assertEqual(flow.sum(expected_weight - m.weight).item(), 0)\n\n    def test_unstructured_pruning_same_magnitude(self):\n        r\"\"\"Since it may happen that the tensor to prune has entries with the\n        same exact magnitude, it is important to check that pruning happens\n        consistently based on the bottom % of weights, and not by threshold,\n        which would instead kill off *all* units with magnitude = threshold.\n        \"\"\"\n        AMOUNT = 0.2\n        p = prune.L1Unstructured(amount=AMOUNT)\n        # create a random tensor with entries in {-2, 0, 2}\n        t = 2 * flow.randint(low=-1, high=2, size=(10, 7))\n        nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.nelement())\n\n        computed_mask = p.compute_mask(t, default_mask=flow.ones_like(t))\n        nparams_pruned = flow.sum(computed_mask == 0)\n        self.assertEqual(nparams_toprune, nparams_pruned)\n\n    def test_random_structured_pruning_amount(self):\n        AMOUNT = 0.6\n        AXIS = 2\n        p = prune.RandomStructured(amount=AMOUNT, dim=AXIS)\n        t = 2 * flow.randint(low=-1, high=2, size=(5, 4, 2)).to(dtype=flow.float32)\n        nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.shape[AXIS])\n\n        
computed_mask = p.compute_mask(t, default_mask=flow.ones_like(t))\n        # check that 1 column is fully pruned, the others are left untouched\n        remaining_axes = [_ for _ in range(len(t.shape)) if _ != AXIS]\n        per_column_sums = sorted(flow.sum(computed_mask == 0, dim=remaining_axes))\n        assert per_column_sums == [0, 20]\n\n    def test_ln_structured_pruning(self):\n        r\"\"\"Check Ln structured pruning by hand.\n        \"\"\"\n        m = nn.Conv2d(3, 1, 2)\n        m.weight.data = flow.tensor(\n            [\n                [\n                    [[1.0, 2.0], [1.0, 2.5]],\n                    [[0.5, 1.0], [0.1, 0.1]],\n                    [[-3.0, -5.0], [0.1, -1.0]],\n                ]\n            ]\n        )\n        # expected effect of pruning 1 of the 3 channels by L2-norm\n        expected_mask_axis1 = flow.ones_like(m.weight)\n        expected_mask_axis1[:, 1] = 0.0\n\n        prune.ln_structured(m, \"weight\", amount=1, n=2, dim=1)\n        self.assertEqual(flow.sum(expected_mask_axis1 - m.weight_mask).item(), 0)\n\n        # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm\n        expected_mask_axis3 = expected_mask_axis1\n        expected_mask_axis3[:, :, :, 0] = 0.0\n\n        prune.ln_structured(m, \"weight\", amount=1, n=1, dim=-1)\n        self.assertEqual(flow.sum(expected_mask_axis3 - m.weight_mask).item(), 0)\n\n    def test_ln_structured_pruning_importance_scores(self):\n        r\"\"\"Check Ln structured pruning by hand.\n        \"\"\"\n        m = nn.Conv2d(3, 1, 2)\n        m.weight.data = flow.tensor(\n            [\n                [\n                    [[1.0, 2.0], [1.0, 2.5]],\n                    [[0.5, 1.0], [0.1, 0.1]],\n                    [[-3.0, -5.0], [0.1, -1.0]],\n                ]\n            ]\n        )\n        importance_scores = flow.tensor(\n            [\n                [\n                    [[10.0, 1.0], [10.0, 1.0]],\n                    [[30.0, 3.0], [30.0, 3.0]],\n                    [[-20.0, -2.0], [-20.0, -2.0]],\n                ]\n            ]\n        )\n        # expected effect of pruning 1 of the 3 channels by L2-norm\n        expected_mask_axis1 = flow.ones_like(m.weight)\n        expected_mask_axis1[:, 0] = 0.0\n\n        prune.ln_structured(\n            m, \"weight\", amount=1, n=2, dim=1, importance_scores=importance_scores\n        )\n        self.assertEqual(flow.sum(expected_mask_axis1 - m.weight_mask).item(), 0)\n\n        # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm\n        expected_mask_axis3 = expected_mask_axis1\n        expected_mask_axis3[:, :, :, 1] = 0.0\n\n        prune.ln_structured(\n            m, \"weight\", amount=1, n=1, dim=-1, importance_scores=importance_scores\n        )\n        self.assertEqual(flow.sum(expected_mask_axis3 - m.weight_mask).item(), 0)\n\n    def test_remove_pruning(self):\n        r\"\"\"`prune.remove` removes the hook and the reparametrization\n        and makes the pruning final in the original parameter.\n        \"\"\"\n        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]\n        names = [\"weight\", \"bias\"]\n\n        for m in modules:\n            for name in names:\n                with self.subTest(m=m, name=name):\n                    # first prune\n                    prune.random_unstructured(m, name, amount=0.5)\n                    self.assertIn(name + \"_orig\", dict(m.named_parameters()))\n                    self.assertIn(name + \"_mask\", dict(m.named_buffers()))\n                    
self.assertNotIn(name, dict(m.named_parameters()))\n                    self.assertTrue(hasattr(m, name))\n                    pruned_t = getattr(m, name)\n\n                    # then remove pruning\n                    prune.remove(m, name)\n                    self.assertIn(name, dict(m.named_parameters()))\n                    self.assertNotIn(name + \"_orig\", dict(m.named_parameters()))\n                    self.assertNotIn(name + \"_mask\", dict(m.named_buffers()))\n                    final_t = getattr(m, name)\n\n                    self.assertEqual(flow.sum(pruned_t - final_t).item(), 0)\n\n    def test_remove_pruning_exception(self):\n        r\"\"\"Removing pruning from an unpruned tensor raises a ValueError\n        \"\"\"\n        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]\n        names = [\"weight\", \"bias\"]\n\n        for m in modules:\n            for name in names:\n                with self.subTest(m=m, name=name):\n                    # check that the module isn't pruned\n                    self.assertFalse(prune.is_pruned(m))\n                    # since it isn't pruned, pruning can't be removed from it\n                    with self.assertRaises(ValueError):\n                        prune.remove(m, name)\n\n    def test_global_pruning(self):\n        r\"\"\"Test that global l1 unstructured pruning over 2 parameters removes\n        the `amount=4` smallest global weights across the 2 parameters.\n        \"\"\"\n        m = nn.Linear(4, 2)\n        n = nn.Linear(3, 1)\n        # modify the weight matrices by hand\n        m.weight = flow.nn.Parameter(\n            flow.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to(dtype=flow.float32)\n        )\n        n.weight = flow.nn.Parameter(flow.tensor([[0, 0.1, -2]]).to(dtype=flow.float32))\n\n        params_to_prune = (\n            (m, \"weight\"),\n            (n, \"weight\"),\n        )\n\n        # prune the 4 smallest weights globally by L1 magnitude\n        prune.global_unstructured(\n            params_to_prune, pruning_method=prune.L1Unstructured, amount=4\n        )\n\n        expected_mweight = flow.tensor(\n            [[0, 2, 3, 4], [-4, -3, -2, 0]], dtype=m.weight.dtype\n        )\n        self.assertEqual(flow.sum(expected_mweight - m.weight).item(), 0)\n\n        expected_nweight = flow.tensor([[0, 0, -2]]).to(dtype=n.weight.dtype)\n        self.assertEqual(flow.sum(expected_nweight - n.weight).item(), 0)\n\n    def test_global_pruning_importance_scores(self):\n        r\"\"\"Test that global l1 unstructured pruning over 2 parameters removes\n        the `amount=4` smallest global weights across the 2 parameters.\n        \"\"\"\n        m = nn.Linear(4, 2)\n        n = nn.Linear(3, 1)\n        # modify the weight matrices by hand\n        m.weight = flow.nn.Parameter(\n            flow.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to(dtype=flow.float32)\n        )\n        m_importance_scores = flow.tensor(\n            [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=flow.float32\n        )\n        n.weight = flow.nn.Parameter(flow.tensor([[0, 0.1, -2]]).to(dtype=flow.float32))\n        n_importance_scores = flow.tensor([[0, 10.0, -0.2]]).to(dtype=flow.float32)\n\n        params_to_prune = (\n            (m, \"weight\"),\n            (n, \"weight\"),\n        )\n        importance_scores = {\n            (m, \"weight\"): m_importance_scores,\n            (n, \"weight\"): n_importance_scores,\n        }\n\n        # prune the 4 smallest weights globally by L1 magnitude\n        prune.global_unstructured(\n            
params_to_prune,\n            pruning_method=prune.L1Unstructured,\n            amount=4,\n            importance_scores=importance_scores,\n        )\n\n        expected_m_weight = flow.tensor(\n            [[1, 2, 0, 4], [-4, 0, -2, -1]], dtype=m.weight.dtype\n        )\n        self.assertEqual(flow.sum(expected_m_weight - m.weight).item(), 0)\n\n        expected_n_weight = flow.tensor([[0, 0.1, 0]]).to(dtype=n.weight.dtype)\n        self.assertEqual(flow.sum(expected_n_weight - n.weight).item(), 0)\n\n    def test_custom_from_mask_pruning(self):\n        r\"\"\"Test that the CustomFromMask is capable of receiving\n        as input at instantiation time a custom mask, and combining it with\n        the previous default mask to generate the correct final mask.\n        \"\"\"\n        # new mask\n        mask = flow.tensor([[0, 1, 1, 0], [0, 0, 1, 1]])\n        # old mask\n        default_mask = flow.tensor([[0, 0, 0, 0], [1, 1, 1, 1]])\n\n        # some tensor (not actually used)\n        t = flow.rand(mask.shape, dtype=flow.float32, device=mask.device)\n        # t = flow.rand_like(mask.to(dtype=flow.float32))\n\n        p = prune.CustomFromMask(mask=mask)\n\n        computed_mask = p.compute_mask(t, default_mask)\n        expected_mask = flow.tensor(\n            [[0, 0, 0, 0], [0, 0, 1, 1]], dtype=computed_mask.dtype\n        )\n\n        self.assertEqual(flow.sum(computed_mask - expected_mask).item(), 0)\n\n    def test_pruning_rollback(self):\n        r\"\"\"Test that if something fails when we try to compute the mask,\n        then the model isn't left in some intermediate half-pruned state.\n        The try/except statement in `apply` should handle rolling back\n        to the previous state before pruning began.\n        \"\"\"\n        modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)]\n        names = [\"weight\", \"bias\"]\n\n        for m in modules:\n            for name in names:\n                with self.subTest(m=m, name=name):\n\n                    with mock.patch(\n                        \"oneflow.nn.utils.prune.L1Unstructured.compute_mask\"\n                    ) as compute_mask:\n                        compute_mask.side_effect = Exception(\"HA!\")\n                        with self.assertRaises(Exception):\n                            prune.l1_unstructured(m, name=name, amount=0.9)\n\n                        self.assertTrue(name in dict(m.named_parameters()))\n                        self.assertFalse(name + \"_mask\" in dict(m.named_buffers()))\n                        self.assertFalse(name + \"_orig\" in dict(m.named_parameters()))\n\n    def test_pruning_serialization_model(self):\n        # create a model\n        model = flow.nn.Sequential(\n            flow.nn.Linear(10, 10), flow.nn.ReLU(), flow.nn.Linear(10, 1),\n        )\n        # check that everything looks normal before pruning\n        self.assertNotIn(\"0.weight_orig\", model.state_dict())\n        self.assertNotIn(\"0.weight_mask\", model.state_dict())\n        self.assertIn(\"0.weight\", model.state_dict())\n\n        # prune one of its parameters\n        prune.l1_unstructured(module=model[0], name=\"weight\", amount=0.9)\n\n        # check that the original weight and the new mask are present\n        self.assertIn(\"0.weight_orig\", model.state_dict())\n        self.assertIn(\"0.weight_mask\", model.state_dict())\n        self.assertNotIn(\"0.weight\", model.state_dict())\n        self.assertTrue(hasattr(model[0], \"weight\"))\n\n        pruned_weight = model[0].weight\n\n        with 
tempfile.NamedTemporaryFile() as f:\n            flow.save(model, f.name)\n            new_model = flow.load(f.name)\n\n        # check that the original weight and the new mask are present\n        self.assertIn(\"0.weight_orig\", new_model.state_dict())\n        self.assertIn(\"0.weight_mask\", new_model.state_dict())\n        self.assertNotIn(\"0.weight\", new_model.state_dict())\n        self.assertTrue(hasattr(new_model[0], \"weight\"))\n\n        self.assertEqual(flow.sum(pruned_weight - new_model[0].weight).item(), 0)\n\n    def test_prune(self):\n        # create a new pruning method\n        p = prune.L1Unstructured(amount=2)\n        # create tensor to be pruned\n        t = flow.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=flow.float32)\n        # create prior mask by hand\n        default_mask = flow.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])\n        # since we are pruning the two lowest magnitude units, the outcome of\n        # the calculation should be this:\n        expected_mask = flow.tensor([[0, 0, 1, 0], [1, 1, 0, 1]])\n        pruned_tensor = p.prune(t, default_mask)\n        self.assertEqual(flow.sum(t * expected_mask - pruned_tensor).item(), 0)\n\n    def test_prune_importance_scores(self):\n        # create a new pruning method\n        p = prune.L1Unstructured(amount=2)\n        # create tensor to be pruned\n        t = flow.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=flow.float32)\n        importance_scores = flow.tensor([[1, 2, 3, 4], [1.5, 1.6, 1.7, 1.8]]).to(\n            dtype=flow.float32\n        )\n        # create prior mask by hand\n        default_mask = flow.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])\n        # since we are pruning the two lowest magnitude units, the outcome of\n        # the calculation should be this:\n        expected_mask = flow.tensor([[0, 1, 1, 0], [0, 1, 0, 1]])\n        pruned_tensor = p.prune(t, default_mask, importance_scores=importance_scores)\n        self.assertEqual(flow.sum(t * expected_mask - pruned_tensor).item(), 0)\n\n    def test_prune_importance_scores_mimic_default(self):\n        # create a new pruning method\n        p = prune.L1Unstructured(amount=2)\n        # create tensor to be pruned\n        t = flow.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=flow.float32)\n        # create prior mask by hand\n        default_mask = flow.tensor([[1, 1, 1, 0], [1, 1, 0, 1]])\n        # since we are pruning the two lowest magnitude units, the outcome of\n        # the calculation should be this:\n        expected_mask = flow.tensor([[0, 0, 1, 0], [1, 1, 0, 1]])\n        pruned_tensor_without_importance_scores = p.prune(t, default_mask)\n        pruned_tensor_with_importance_scores = p.prune(\n            t, default_mask, importance_scores=t\n        )\n        self.assertEqual(\n            flow.sum(\n                pruned_tensor_without_importance_scores\n                - pruned_tensor_with_importance_scores\n            ).item(),\n            0,\n        )\n        self.assertEqual(\n            flow.sum(\n                t * expected_mask - pruned_tensor_without_importance_scores\n            ).item(),\n            0,\n        )\n\n    def test_rnn_pruning(self):\n        l = flow.nn.LSTM(32, 32)\n        # This Module has 4 parameters called:\n        # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0'\n\n        # Pruning one of them causes one of the weights to become a tensor\n        prune.l1_unstructured(l, \"weight_ih_l0\", 0.5)\n        assert sum([isinstance(p, flow.nn.Parameter) for p in l._flat_weights]) 
== 3\n\n        # Removing the pruning reparametrization restores the Parameter\n        prune.remove(l, \"weight_ih_l0\")\n        assert sum([isinstance(p, flow.nn.Parameter) for p in l._flat_weights]) == 4\n\n        # Make sure that, upon removal of the reparametrization, the\n        # `._parameters` and `.named_parameters` contain the right params.\n        # Specifically, the original weight ('weight_ih_l0') should be placed\n        # back in the parameters, while the reparametrization component\n        # ('weight_ih_l0_orig') should be removed.\n        assert \"weight_ih_l0\" in l._parameters\n        assert l._parameters[\"weight_ih_l0\"] is not None\n        assert \"weight_ih_l0_orig\" not in l._parameters\n        assert \"weight_ih_l0\" in dict(l.named_parameters())\n        assert dict(l.named_parameters())[\"weight_ih_l0\"] is not None\n        assert \"weight_ih_l0_orig\" not in dict(l.named_parameters())\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_qat_conv_modules.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\nimport random\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_qat_conv1d(\n    test_case,\n    device,\n    quantization_formula,\n    quantization_bit,\n    quantization_scheme,\n    weight_quant_per_layer,\n    input_quant_momentum,\n):\n    batch_size = random.randint(1, 5)\n    input_channels = random.randint(1, 3)\n    output_channels = random.randint(1, 3)\n    spatial_size = random.randint(8, 16)\n    kernel_size = random.randint(1, 3)\n    stride = random.randint(1, 2)\n    padding = random.randint(0, 2)\n\n    qat_conv1d = flow.nn.QatConv1d(\n        in_channels=input_channels,\n        out_channels=output_channels,\n        kernel_size=kernel_size,\n        stride=stride,\n        padding=padding,\n        quantization_formula=quantization_formula,\n        quantization_bit=quantization_bit,\n        quantization_scheme=quantization_scheme,\n        weight_quant_per_layer=weight_quant_per_layer,\n        input_quant_momentum=input_quant_momentum,\n    ).to(device)\n\n    qat_input = flow.rand(\n        batch_size,\n        input_channels,\n        spatial_size,\n        dtype=flow.float32,\n        requires_grad=True,\n        device=device,\n    )\n\n    qat_out = qat_conv1d(qat_input)\n    qat_out.sum().backward()\n    qat_out.numpy()\n    qat_input.grad.numpy()\n\n\ndef _test_qat_conv2d(\n    test_case,\n    device,\n    quantization_formula,\n    quantization_bit,\n    quantization_scheme,\n    weight_quant_per_layer,\n    input_quant_momentum,\n):\n    batch_size = random.randint(1, 5)\n    input_channels = random.randint(1, 3)\n    output_channels = random.randint(1, 3)\n    spatial_size = random.randint(8, 16)\n    kernel_size = random.randint(1, 3)\n    stride = random.randint(1, 2)\n    padding = random.randint(0, 2)\n\n    qat_conv2d = flow.nn.QatConv2d(\n        in_channels=input_channels,\n        out_channels=output_channels,\n        kernel_size=kernel_size,\n        stride=stride,\n        padding=padding,\n        quantization_formula=quantization_formula,\n        quantization_bit=quantization_bit,\n        quantization_scheme=quantization_scheme,\n        weight_quant_per_layer=weight_quant_per_layer,\n        input_quant_momentum=input_quant_momentum,\n    ).to(device)\n\n    qat_input = flow.rand(\n        batch_size,\n        input_channels,\n        spatial_size,\n        spatial_size,\n        dtype=flow.float32,\n        requires_grad=True,\n        device=device,\n    )\n    qat_out = qat_conv2d(qat_input)\n    qat_out.sum().backward()\n    qat_out.numpy()\n    qat_input.grad.numpy()\n\n\ndef _test_qat_conv3d(\n    test_case,\n    device,\n    quantization_formula,\n    quantization_bit,\n    quantization_scheme,\n    weight_quant_per_layer,\n    input_quant_momentum,\n):\n    batch_size = random.randint(1, 5)\n   
 input_channels = random.randint(1, 3)\n    output_channels = random.randint(1, 3)\n    spatial_size = random.randint(8, 16)\n    kernel_size = random.randint(1, 3)\n    stride = random.randint(1, 2)\n    padding = random.randint(0, 2)\n\n    qat_conv3d = flow.nn.QatConv3d(\n        in_channels=input_channels,\n        out_channels=output_channels,\n        kernel_size=kernel_size,\n        stride=stride,\n        padding=padding,\n        quantization_formula=quantization_formula,\n        quantization_bit=quantization_bit,\n        quantization_scheme=quantization_scheme,\n        weight_quant_per_layer=weight_quant_per_layer,\n        input_quant_momentum=input_quant_momentum,\n    ).to(device)\n\n    qat_input = flow.rand(\n        batch_size,\n        input_channels,\n        spatial_size,\n        spatial_size,\n        spatial_size,\n        dtype=flow.float32,\n        requires_grad=True,\n        device=device,\n    )\n    qat_out = qat_conv3d(qat_input)\n    qat_out.sum().backward()\n    qat_out.numpy()\n    qat_input.grad.numpy()\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestQatModules(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, because it failed 2 times in the past week\")\n    def test_qat_conv1d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"quantization_formula\"] = [\"google\"]\n        arg_dict[\"quantization_bit\"] = [4, 8]\n        arg_dict[\"quantization_scheme\"] = [\"symmetric\", \"affine\"]\n        arg_dict[\"weight_quant_per_layer\"] = [True, False]\n        arg_dict[\"input_quant_momentum\"] = [0.95]\n\n        for i in range(5):\n            for arg in GenArgList(arg_dict):\n                _test_qat_conv1d(test_case, *arg)\n\n    def test_qat_conv2d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"quantization_formula\"] = [\"google\"]\n        arg_dict[\"quantization_bit\"] = [4, 8]\n        arg_dict[\"quantization_scheme\"] = [\"symmetric\", \"affine\"]\n        arg_dict[\"weight_quant_per_layer\"] = [True, False]\n        arg_dict[\"input_quant_momentum\"] = [0.95]\n\n        for i in range(5):\n            for arg in GenArgList(arg_dict):\n                _test_qat_conv2d(test_case, *arg)\n\n    def test_qat_conv3d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"quantization_formula\"] = [\"google\"]\n        arg_dict[\"quantization_bit\"] = [4, 8]\n        arg_dict[\"quantization_scheme\"] = [\"symmetric\", \"affine\"]\n        arg_dict[\"weight_quant_per_layer\"] = [True, False]\n        arg_dict[\"input_quant_momentum\"] = [0.95]\n\n        for i in range(5):\n            for arg in GenArgList(arg_dict):\n                _test_qat_conv3d(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_quantile.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@autotest(n=3, check_graph=True)\ndef _test_quantile(test_cast, q):\n    device = random_device()\n    a = random_tensor(2, random(2, 5), random(2, 5)).to(device)\n    out = torch.quantile(a, q, dim=1, interpolation=\"linear\")\n    return out\n\n\n@unittest.skipIf(True, \"pytorch-1.10.0 will cause oneflow cudnn or cublas error\")\n@flow.unittest.skip_unless_1n1d()\nclass TestQuantile(flow.unittest.TestCase):\n    def test_quantile(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"q\"] = [0.2, 0.6, 0.8]\n        for arg in GenArgList(arg_dict):\n            _test_quantile(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_quantization.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport math\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.test_util import (\n    GenArgList,\n    type_name_to_flow_type,\n    type_name_to_np_type,\n)\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef gen_quant_scale_for_min_max_symmetric(weight, quantization_bit):\n    weight_max = np.max(np.abs(weight))\n    denominator = 2.0 ** (quantization_bit - 1) - 1\n    return (weight_max / denominator, 0)\n\n\ndef gen_quant_scale_for_min_max_affine(weight, quantization_bit):\n    weight_max = np.max(weight)\n    weight_min = np.min(weight)\n    denominator = 2.0 ** quantization_bit - 1\n    scale = (weight_max - weight_min) / denominator\n    zero_point = -np.round(weight_min / scale)\n    return (scale, zero_point)\n\n\ndef gen_quant_scale_for_min_max_cambricon(weight, quantization_bit):\n    weight_max = np.max(np.abs(weight))\n    scale = math.floor(math.log2(weight_max)) - (quantization_bit - 2)\n    return (scale, 0)\n\n\ndef product(tu):\n    return np.prod(tu).astype(np.int32).item()\n\n\ndef quant_per_layer_symmetric(input, quantization_bit, scale):\n    upper_bound = 2.0 ** (quantization_bit - 1) - 1\n    lower_bound = -upper_bound\n    return np.clip(np.rint(input / scale), lower_bound, upper_bound)\n\n\ndef quant_per_layer_affine(input, quantization_bit, scale, zero_point):\n    upper_bound = 2.0 ** quantization_bit - 1\n    lower_bound = 0\n    return np.clip(np.rint(input / scale + zero_point), lower_bound, upper_bound)\n\n\ndef quant_per_layer_cambricon(input, quantization_bit, shift):\n    upper_bound = 2.0 ** (quantization_bit - 1) - 1\n    lower_bound = -upper_bound\n    scale = 2 ** shift\n    return np.clip(np.rint(input / scale), lower_bound, upper_bound)\n\n\ndef _check_quantize(\n    test_case,\n    input,\n    out_of,\n    quantization_bit,\n    quantization_scheme,\n    quantization_formula,\n    per_layer_quantization,\n):\n    if per_layer_quantization or quantization_formula == \"cambricon\":\n        outer_num = 1\n        inner_num = product(input.shape[0:])\n    else:\n        outer_num = input.shape[0]\n        inner_num = product(input.shape[1:])\n    scale_np = np.zeros((outer_num,))\n    zero_point_np = np.zeros((outer_num,))\n    out_np = np.zeros((inner_num * outer_num,))\n    input_flatten = input.flatten()\n    input_diff_np = np.full((inner_num * outer_num,), 1.0 / (inner_num * outer_num))\n    if quantization_formula == \"google\":\n        if quantization_scheme == \"symmetric\":\n            for c in range(outer_num):\n                (scale_np[c], zero_point_np[c]) = gen_quant_scale_for_min_max_symmetric(\n                    input_flatten[c * inner_num : (c + 1) * inner_num], quantization_bit\n                )\n                out = 
quant_per_layer_symmetric(\n                    input_flatten[c * inner_num : (c + 1) * inner_num],\n                    quantization_bit,\n                    scale_np[c],\n                )\n                out_np[c * inner_num : (c + 1) * inner_num] = out\n        else:\n            for c in range(outer_num):\n                (scale_np[c], zero_point_np[c]) = gen_quant_scale_for_min_max_affine(\n                    input_flatten[c * inner_num : (c + 1) * inner_num], quantization_bit\n                )\n                out = quant_per_layer_affine(\n                    input_flatten[c * inner_num : (c + 1) * inner_num],\n                    quantization_bit,\n                    scale_np[c],\n                    zero_point_np[c],\n                )\n                out_np[c * inner_num : (c + 1) * inner_num] = out\n    else:\n        (scale_np[0], zero_point_np[0]) = gen_quant_scale_for_min_max_cambricon(\n            input_flatten, quantization_bit\n        )\n        out_np = quant_per_layer_cambricon(input_flatten, quantization_bit, scale_np[0])\n    rmse = np.sqrt(np.mean((out_of - out_np) ** 2))\n    assert rmse <= 2.0, \"quantization op has a bug!\"\n\n\ndef _run_test_quantize(\n    test_case,\n    device_type,\n    dtype,\n    in_shape,\n    quantization_bit,\n    quantization_scheme,\n    quantization_formula,\n    per_layer_quantization,\n):\n    input = (np.random.random(in_shape) - 0.5).astype(type_name_to_np_type[dtype])\n    input_tensor = flow.tensor(\n        input, dtype=flow.float32, device=flow.device(device_type)\n    )\n    min_max_observer = flow.nn.MinMaxObserver(\n        quantization_formula=quantization_formula,\n        quantization_bit=quantization_bit,\n        quantization_scheme=quantization_scheme,\n        per_layer_quantization=per_layer_quantization,\n    )\n    (scale, zero_point) = min_max_observer(input_tensor)\n    quantization = flow.nn.Quantization(\n        quantization_formula=quantization_formula,\n        quantization_bit=quantization_bit,\n        quantization_scheme=quantization_scheme,\n    )\n    output_tensor = quantization(input_tensor, scale, zero_point)\n\n    out = output_tensor.numpy()\n    _check_quantize(\n        test_case,\n        input,\n        out.flatten(),\n        quantization_bit,\n        quantization_scheme,\n        quantization_formula,\n        per_layer_quantization,\n    )\n\n\nclass TestQuantize(flow.unittest.TestCase):\n    def test_quantize(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_case\"] = [test_case]\n        arg_dict[\"device_type\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"dtype\"] = [\"float32\", \"double\"]\n        arg_dict[\"in_shape\"] = [(9, 40, 20, 10)]\n        arg_dict[\"quantization_bit\"] = [8, 2]\n        arg_dict[\"quantization_scheme\"] = [\"symmetric\", \"affine\"]\n        arg_dict[\"quantization_formula\"] = [\"google\"]\n        arg_dict[\"per_layer_quantization\"] = [True, False]\n        for arg in GenArgList(arg_dict):\n            if arg[-2] == \"cambricon\" and arg[-1] == False:\n                continue\n            _run_test_quantize(*arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
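  {
    "path": "python/oneflow/test/modules/example_min_max_observer_formulas_sketch.py",
    "content": "\"\"\"\nIllustrative sketch, not part of the test suite: a numpy-only walk-through of the\nthree min-max scale formulas exercised by the quantization test above. The file name\nand the sample values are hypothetical; the formulas mirror gen_quant_scale_for_min_max_*.\n\"\"\"\n\nimport math\n\nimport numpy as np\n\nweight = np.array([-1.5, 0.25, 2.0], dtype=np.float32)\nquantization_bit = 8\n\n# symmetric: max(|w|) is mapped onto the largest positive quantized value\nscale_symmetric = np.max(np.abs(weight)) / (2.0 ** (quantization_bit - 1) - 1)  # 2.0 / 127\n\n# affine: the [min, max] range is stretched over the full unsigned range, with a zero point\nscale_affine = (np.max(weight) - np.min(weight)) / (2.0 ** quantization_bit - 1)  # 3.5 / 255\nzero_point = -np.round(np.min(weight) / scale_affine)\n\n# cambricon: a power-of-two exponent (shift) instead of a floating-point scale\nshift = math.floor(math.log2(np.max(np.abs(weight)))) - (quantization_bit - 2)  # 1 - 6 = -5\n\nif __name__ == \"__main__\":\n    print(scale_symmetric, scale_affine, zero_point, shift)\n"
  },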
  {
    "path": "python/oneflow/test/modules/test_quick_gelu.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\nimport torch\n\n\nclass QuickGELUActivation(torch.nn.Module):\n    \"\"\"\n    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs\n    \"\"\"\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return input * torch.sigmoid(1.702 * input)\n\n\ndef _test_quick_gelu(test_case, device):\n    torch_quick_gelu = QuickGELUActivation()\n    x = np.random.randn(2, 4, 3)\n    torch_x = torch.tensor(x, requires_grad=True, device=torch.device(device))\n    oneflow_x = flow.tensor(x, requires_grad=True, device=flow.device(device))\n    torch_y = torch_quick_gelu(torch_x)\n    oneflow_y = flow._C.quick_gelu(oneflow_x)\n    test_case.assertTrue(np.allclose(torch_y.detach().cpu().numpy(), oneflow_y.numpy()))\n    torch_y_sum = torch_y.sum()\n    torch_y_sum.backward()\n    oneflow_y_sum = oneflow_y.sum()\n    oneflow_y_sum.backward()\n    test_case.assertTrue(\n        np.allclose(torch_x.grad.cpu().numpy(), oneflow_x.grad.numpy())\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModule(flow.unittest.TestCase):\n    def test_quick_gelu(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_quick_gelu]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
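  {
    "path": "python/oneflow/test/modules/example_quick_gelu_accuracy_sketch.py",
    "content": "\"\"\"\nIllustrative sketch, not part of the test suite: compares the quick-GELU\napproximation x * sigmoid(1.702 * x) checked by test_quick_gelu.py above against\nthe exact erf-based GELU. The file name and sample points are hypothetical.\n\"\"\"\n\nimport math\n\nimport numpy as np\n\n\ndef quick_gelu(x):\n    # x * sigmoid(1.702 * x), written as a single division\n    return x / (1.0 + np.exp(-1.702 * x))\n\n\ndef exact_gelu(x):\n    return np.array([0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in x])\n\n\nif __name__ == \"__main__\":\n    x = np.linspace(-4.0, 4.0, 9)\n    # the approximation tracks the exact GELU to within a few hundredths\n    print(np.max(np.abs(quick_gelu(x) - exact_gelu(x))))\n"
  },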
  {
    "path": "python/oneflow/test/modules/test_rand.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_rand(test_case, device, shape):\n    y1 = flow.rand(*shape, device=flow.device(device))\n    y2 = flow.rand(size=shape, device=flow.device(device))\n\n    test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy()))\n    test_case.assertTrue(shape == y1.shape)\n    test_case.assertTrue(shape == y2.shape)\n\n\ndef _test_rand_tuple_shape(test_case, device, shape):\n    y1 = flow.rand(shape, device=flow.device(device))\n    y2 = flow.rand(shape, device=flow.device(device))\n\n    test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy()))\n    test_case.assertTrue(shape == y1.shape)\n\n\ndef _test_0d_rand(test_case, device, shape):\n    y1 = flow.rand(*shape, device=flow.device(device))\n    y2 = flow.rand(*shape, device=flow.device(device))\n    test_case.assertTrue(\n        np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)\n    )  # 0d is [] and []\n    test_case.assertTrue(shape == y1.shape)\n\n\ndef _test_different_dtype(test_case, device, shape):\n    y1 = flow.rand(*shape, dtype=flow.float32, device=flow.device(device))\n    y2 = flow.rand(*shape, dtype=flow.float64, device=flow.device(device))\n    test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy()))\n    test_case.assertTrue(shape == y1.shape)\n\n    with test_case.assertRaises(NotImplementedError):\n        flow.rand(*shape, dtype=flow.int32, device=flow.device(device))\n\n\ndef _test_backward(test_case, device, shape):\n    x = flow.rand(*shape, device=flow.device(device), requires_grad=True)\n    y = x.sum()\n    y.backward()\n    test_case.assertTrue(np.array_equal(np.ones(shape), x.grad.numpy()))\n\n\ndef _test_with_generator(test_case, device, shape):\n    gen = flow.Generator()\n    gen.manual_seed(0)\n    y1 = flow.rand(\n        *shape, dtype=flow.float32, device=flow.device(device), generator=gen\n    )\n    gen.manual_seed(0)\n    y2 = flow.rand(\n        *shape, dtype=flow.float32, device=flow.device(device), generator=gen\n    )\n    test_case.assertTrue(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4))\n\n\ndef _test_rand_with_flow_size(test_case, device, shape):\n    y1 = flow.rand(flow.Size(shape), device=flow.device(device))\n    y2 = flow.rand(flow.Size(shape), device=flow.device(device))\n\n    test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy()))\n    test_case.assertTrue(shape == y1.shape)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRandModule(flow.unittest.TestCase):\n    def test_0d_randint(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_0d_rand]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(2, 0, 4), (2, 0, 2)]\n    
    for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_cases(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_rand,\n            _test_rand_tuple_shape,\n            _test_different_dtype,\n            _test_backward,\n            _test_with_generator,\n            _test_rand_with_flow_size,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5), (2, 4)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_half_rand(test_case):\n        for device in [\"cuda\", \"cpu\"]:\n            x = flow.rand(2, 3, dtype=flow.float16, device=flow.device(device))\n            test_case.assertTrue(x.dtype == flow.float16)\n            test_case.assertTrue(x.shape == flow.Size((2, 3)))\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestRandOnNonDefaultDevice(flow.unittest.TestCase):\n    def test_non_default_device(test_case):\n        x = flow.rand(2, 3, device=\"cuda:1\")\n        test_case.assertEqual(x.device, flow.device(\"cuda:1\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
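  {
    "path": "python/oneflow/test/modules/example_generator_reseed_sketch.py",
    "content": "\"\"\"\nIllustrative sketch, not part of the test suite: re-seeding a flow.Generator\nreproduces the same random stream, which is the property _test_with_generator in\ntest_rand.py above relies on. The file name is hypothetical.\n\"\"\"\n\nimport numpy as np\n\nimport oneflow as flow\n\nif __name__ == \"__main__\":\n    gen = flow.Generator()\n    gen.manual_seed(0)\n    y1 = flow.rand(2, 3, generator=gen)\n    gen.manual_seed(0)\n    y2 = flow.rand(2, 3, generator=gen)\n    # identical seeds yield identical samples; a third draw without re-seeding would differ\n    print(np.allclose(y1.numpy(), y2.numpy()))  # True\n"
  },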
  {
    "path": "python/oneflow/test/modules/test_randint.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_randint(test_case, device, shape, low, high):\n    y1 = flow.randint(low, high, shape, device=flow.device(device))\n    y2 = flow.randint(low, high, shape, device=flow.device(device))\n    test_case.assertFalse(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4))\n    test_case.assertTrue(shape == y1.shape)\n\n\ndef _test_0d_randint(test_case, device, shape, low, high):\n    y1 = flow.randint(low, high, shape, device=flow.device(device))\n    y2 = flow.randint(low, high, shape, device=flow.device(device))\n    test_case.assertTrue(\n        np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)\n    )  # 0d is [] and []\n    test_case.assertTrue(shape == y1.shape)\n\n\ndef _test_different_dtype(test_case, device, shape, low, high):\n    for dtype in [\n        flow.uint8,\n        flow.int8,\n        flow.int32,\n        flow.int64,\n        flow.float32,\n        flow.float64,\n    ]:\n        y = flow.randint(low, high, shape, dtype=dtype, device=flow.device(device))\n        test_case.assertTrue(y.dtype == dtype)\n        test_case.assertTrue(y.shape == shape)\n\n\ndef _test_with_generator(test_case, device, shape, low, high):\n    gen = flow.Generator()\n    gen.manual_seed(0)\n    y1 = flow.randint(\n        low, high, shape, dtype=flow.float32, device=flow.device(device), generator=gen\n    )\n    gen.manual_seed(0)\n    y2 = flow.randint(\n        low, high, shape, dtype=flow.float32, device=flow.device(device), generator=gen\n    )\n    test_case.assertTrue(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4))\n\n\ndef _test_high(test_case, device, shape, low, high):\n    y1 = flow._C.randint(high, shape, device=flow.device(device))\n    y2 = flow._C.randint(high, shape, device=flow.device(device))\n    test_case.assertFalse(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4))\n    test_case.assertTrue(shape == y1.shape)\n\n\ndef _test_0rank(test_case, device, shape, low, high):\n    y1 = flow.randint(low, high, shape, device=flow.device(device))\n    test_case.assertTrue(y1.shape == shape)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRandint(flow.unittest.TestCase):\n    def test_global_different_types(test_case):\n        for dtype in [\n            flow.int8,\n            flow.int32,\n            flow.int64,\n            flow.float32,\n            flow.float64,\n        ]:\n            placement = flow.placement(\"cpu\", ranks=[0])\n            sbp = (flow.sbp.broadcast,)\n            x = flow.randint(0, 16, (10, 1), placement=placement, sbp=sbp, dtype=dtype)\n            test_case.assertEqual(x.dtype, dtype)\n            test_case.assertEqual(x.sbp, sbp)\n            test_case.assertEqual(x.placement, placement)\n\n    def 
test_randint(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_randint,\n            _test_different_dtype,\n            _test_with_generator,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]\n        arg_dict[\"low\"] = [i for i in range(10)]\n        arg_dict[\"high\"] = [10 + np.random.randint(10, 20) for i in range(10)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_0d_randint(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_0d_randint]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(2, 0, 4), (2, 0, 2)]\n        arg_dict[\"low\"] = [i for i in range(10)]\n        arg_dict[\"high\"] = [10 + np.random.randint(1, 20) for i in range(10)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_high_randint(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_high]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(2, 3, 4), (2, 5, 2)]\n        arg_dict[\"low\"] = [i for i in range(10)]\n        arg_dict[\"high\"] = [10 + np.random.randint(10, 20) for i in range(10)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_0rank_randint(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_0rank]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [()]\n        arg_dict[\"low\"] = [i for i in range(10)]\n        arg_dict[\"high\"] = [1000 + np.random.randint(1, 10) for i in range(10)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestRandintOnNonDefaultDevice(flow.unittest.TestCase):\n    def test_non_default_device(test_case):\n        x = flow.randint(low=1, high=2, size=flow.Size((2, 3)), device=\"cuda:1\")\n        test_case.assertEqual(x.device, flow.device(\"cuda:1\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_randint_like.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_randint_like(test_case, device, shape, low, high):\n    x = flow.randn(shape)\n    y1 = flow.randint_like(x, low, high, device=flow.device(device))\n    y2 = flow.randint_like(x, low, high, device=flow.device(device))\n    test_case.assertFalse(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4))\n    test_case.assertTrue(shape == y1.shape)\n\n\ndef _test_0d_randint_like(test_case, device, shape, low, high):\n    x = flow.randn(shape)\n    y1 = flow.randint_like(x, low, high, device=flow.device(device))\n    y2 = flow.randint_like(x, low, high, device=flow.device(device))\n    test_case.assertTrue(\n        np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)\n    )  # 0d is [] and []\n    test_case.assertTrue(shape == y1.shape)\n\n\ndef _test_different_dtype(test_case, device, shape, low, high):\n    for dtype in [\n        flow.uint8,\n        flow.int8,\n        flow.int32,\n        flow.int64,\n        flow.float32,\n        flow.float64,\n    ]:\n        x = flow.randint(low, high, shape, dtype=dtype)\n        y = flow.randint_like(x, low, high, dtype=dtype, device=flow.device(device))\n        test_case.assertTrue(y.dtype == dtype)\n        test_case.assertTrue(y.shape == shape)\n\n\ndef _test_with_generator(test_case, device, shape, low, high):\n    gen = flow.Generator()\n    gen.manual_seed(0)\n    x = flow.randn(shape)\n    y1 = flow.randint_like(\n        x, low, high, dtype=flow.float32, device=flow.device(device), generator=gen\n    )\n    gen.manual_seed(0)\n    x = flow.randn(shape)\n    y2 = flow.randint_like(\n        x, low, high, dtype=flow.float32, device=flow.device(device), generator=gen\n    )\n    test_case.assertTrue(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4))\n\n\ndef _test_high(test_case, device, shape, low, high):\n    x = flow.randn(shape)\n    y1 = flow._C.randint_like(x, high, device=flow.device(device))\n    y2 = flow._C.randint_like(x, high, device=flow.device(device))\n    test_case.assertFalse(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4))\n    test_case.assertTrue(shape == y1.shape)\n\n\ndef _test_0rank(test_case, device, shape, low, high):\n    x = flow.randn(shape)\n    y1 = flow.randint_like(x, low, high, device=flow.device(device))\n    test_case.assertTrue(y1.shape == shape)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRandIntLike(flow.unittest.TestCase):\n    def test_global_different_types(test_case):\n        for dtype in [\n            flow.int8,\n            flow.int32,\n            flow.int64,\n            flow.float32,\n            flow.float64,\n        ]:\n            placement = flow.placement(\"cpu\", ranks=[0])\n            sbp = (flow.sbp.broadcast,)\n            x_ = 
flow.randn((10, 1))\n            x = flow.randint_like(x_, 0, 16, placement=placement, sbp=sbp, dtype=dtype)\n            test_case.assertEqual(x.dtype, dtype)\n            test_case.assertEqual(x.sbp, sbp)\n            test_case.assertEqual(x.placement, placement)\n\n    def test_randint_like(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_randint_like,\n            _test_different_dtype,\n            _test_with_generator,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]\n        arg_dict[\"low\"] = [i for i in range(10)]\n        arg_dict[\"high\"] = [10 + np.random.randint(10, 20) for i in range(10)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_0d_randint_like(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_0d_randint_like]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(2, 0, 4), (2, 0, 2)]\n        arg_dict[\"low\"] = [i for i in range(10)]\n        arg_dict[\"high\"] = [10 + np.random.randint(1, 20) for i in range(10)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_high_randint_like(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_high]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(2, 3, 4), (2, 5, 2)]\n        arg_dict[\"low\"] = [i for i in range(10)]\n        arg_dict[\"high\"] = [10 + np.random.randint(10, 20) for i in range(10)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_0rank_randint_like(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_0rank]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [()]\n        arg_dict[\"low\"] = [i for i in range(10)]\n        arg_dict[\"high\"] = [1000 + np.random.randint(1, 10) for i in range(10)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestRandIntLikeOnNonDefaultDevice(flow.unittest.TestCase):\n    def test_non_default_device(test_case):\n        x_ = flow.randn((2, 3))\n        x = flow.randint_like(x_, low=1, high=2, device=\"cuda:1\")\n        test_case.assertEqual(x.device, flow.device(\"cuda:1\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
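  {
    "path": "python/oneflow/test/modules/example_randint_like_shape_sketch.py",
    "content": "\"\"\"\nIllustrative sketch, not part of the test suite: randint_like draws fresh integers\nin [low, high) but inherits the shape of its reference tensor, which is the contract\nthe tests above check. The file name is hypothetical.\n\"\"\"\n\nimport oneflow as flow\n\nif __name__ == \"__main__\":\n    x = flow.randn(2, 3)\n    y = flow.randint_like(x, 0, 10)\n    print(tuple(y.shape))  # (2, 3)\n"
  },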
  {
    "path": "python/oneflow/test/modules/test_randn.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\n\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_randn(test_case, device, shape):\n    y1 = flow.randn(*shape, device=flow.device(device))\n    y2 = flow.randn(size=shape, device=flow.device(device))\n    test_case.assertTrue(not np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4))\n    test_case.assertTrue(shape == y1.shape)\n    test_case.assertTrue(shape == y2.shape)\n\n\ndef _test_0d_rand(test_case, device, shape):\n    y1 = flow.randn(*shape, device=flow.device(device))\n    y2 = flow.randn(*shape, device=flow.device(device))\n    test_case.assertTrue(\n        np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)\n    )  # 0d is [] and []\n    test_case.assertTrue(shape == y1.shape)\n\n\ndef _test_different_dtype(test_case, device, shape):\n    y1 = flow.randn(*shape, dtype=flow.float32, device=flow.device(device))\n    y2 = flow.randn(*shape, dtype=flow.float64, device=flow.device(device))\n    test_case.assertTrue(not np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4))\n    test_case.assertTrue(shape == y1.shape)\n\n    with test_case.assertRaises(NotImplementedError):\n        flow.randn(*shape, dtype=flow.int32, device=flow.device(device))\n\n\ndef _test_backward(test_case, device, shape):\n    x = flow.randn(*shape, device=flow.device(device), requires_grad=True)\n    y = x.sum()\n    y.backward()\n    test_case.assertTrue(\n        np.allclose(np.ones(shape), x.grad.numpy(), atol=1e-4, rtol=1e-4)\n    )\n\n\ndef _test_with_generator(test_case, device, shape):\n    gen = flow.Generator()\n    gen.manual_seed(0)\n    y1 = flow.randn(\n        *shape, dtype=flow.float32, device=flow.device(device), generator=gen\n    )\n    gen.manual_seed(0)\n    y2 = flow.randn(\n        *shape, dtype=flow.float32, device=flow.device(device), generator=gen\n    )\n    test_case.assertTrue(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4))\n\n\ndef _test_randn_tuple_shape(test_case, device, shape):\n    y1 = flow.randn(shape, device=flow.device(device))\n    y2 = flow.randn(shape, device=flow.device(device))\n\n    test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy()))\n    test_case.assertTrue(shape == y1.shape)\n\n\ndef _test_randn_with_flow_size(test_case, device, shape):\n    y1 = flow.randn(flow.Size(shape), device=flow.device(device))\n    y2 = flow.randn(flow.Size(shape), device=flow.device(device))\n\n    test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy()))\n    test_case.assertTrue(shape == y1.shape)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRandnModule(flow.unittest.TestCase):\n    def test_global_naive(test_case):\n        placement = flow.placement(\"cpu\", ranks=[0])\n        sbp = (flow.sbp.broadcast,)\n        
x = flow.randn(16, 16, placement=placement, sbp=sbp)\n        test_case.assertEqual(x.sbp, sbp)\n        test_case.assertEqual(x.placement, placement)\n\n    def test_randn(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_randn,\n            _test_different_dtype,\n            _test_backward,\n            _test_with_generator,\n            _test_randn_tuple_shape,\n            _test_randn_with_flow_size,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]\n\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_0d_randn(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_0d_rand]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(2, 0, 4), (2, 0, 2)]\n\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_half_randn(test_case):\n        for device in [\"cuda\", \"cpu\"]:\n            x = flow.randn(2, 3, dtype=flow.float16, device=flow.device(device))\n            test_case.assertTrue(x.dtype == flow.float16)\n            test_case.assertTrue(x.shape == flow.Size((2, 3)))\n\n    # Just check if `layout` param in api is available, there's no related implementation about it\n    # TODO(WangYi): remove this test when randn **really** supports `layout`\n    def test_randn_layout_param(test_case):\n        x = flow.randn(2, 3, layout=flow.strided)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestRandnOnNonDefaultDevice(flow.unittest.TestCase):\n    def test_non_default_device(test_case):\n        x = flow.randn(2, 3, device=\"cuda:1\")\n        test_case.assertEqual(x.device, flow.device(\"cuda:1\"))\n\n    def test_with_generator(test_case):\n        gen = flow.Generator(\"cuda\")\n        x = flow.randn(2, 3, device=\"cuda\", generator=gen)\n        test_case.assertEqual(x.device, flow.device(f\"cuda:{flow.env.get_rank()}\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_randn_like.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_randn_like(test_case, device, shape):\n    x = flow.randn(shape)\n    y = flow.randn_like(x, device=flow.device(device))\n    test_case.assertTrue(x.shape == y.shape)\n\n\ndef _test_0d_randn_like(test_case, device, shape):\n    x = flow.randn(shape)\n    y = flow.randn_like(x, device=flow.device(device))\n    test_case.assertTrue(x.shape == y.shape)\n\n\ndef _test_different_dtype(test_case, device, shape):\n    for dtype in [\n        flow.float16,\n        flow.float32,\n        flow.float64,\n        flow.double,\n    ]:\n        x = flow.randn(shape, dtype=dtype)\n        y = flow.randn_like(x, dtype=dtype, device=flow.device(device))\n        test_case.assertTrue(x.shape == y.shape)\n\n\ndef _test_with_generator(test_case, device, shape):\n    gen = flow.Generator()\n    gen.manual_seed(0)\n    x = flow.randn(shape)\n    y1 = flow.randn_like(\n        x, dtype=flow.float32, device=flow.device(device), generator=gen\n    )\n    gen.manual_seed(0)\n    x = flow.randn(shape)\n    y2 = flow.randn_like(\n        x, dtype=flow.float32, device=flow.device(device), generator=gen\n    )\n    test_case.assertTrue(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4))\n\n\ndef _test_0rank(test_case, device, shape):\n    x = flow.randn(shape)\n    y = flow.randn_like(x, device=flow.device(device))\n    test_case.assertTrue(x.shape == y.shape)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRandIntLike(flow.unittest.TestCase):\n    def test_global_different_types(test_case):\n        for dtype in [\n            flow.float16,\n            flow.float32,\n            flow.float64,\n            flow.double,\n        ]:\n            placement = flow.placement(\"cpu\", ranks=[0])\n            sbp = (flow.sbp.broadcast,)\n            x_ = flow.randn((10, 1), dtype=dtype)\n            x = flow.randn_like(x_, placement=placement, sbp=sbp, dtype=dtype)\n            test_case.assertEqual(x.dtype, dtype)\n            test_case.assertEqual(x.sbp, sbp)\n            test_case.assertEqual(x.placement, placement)\n\n    def test_randn_like(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_randn_like,\n            _test_different_dtype,\n            _test_with_generator,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_0d_randn_like(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_0d_randn_like]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [(2, 0, 4), (2, 0, 2)]\n        for arg in 
GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_0rank_randn_like(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_0rank]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"shape\"] = [()]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestRandIntLikeOnNonDefaultDevice(flow.unittest.TestCase):\n    def test_non_default_device(test_case):\n        x_ = flow.randn((2, 3))\n        x = flow.randn_like(x_, device=\"cuda:1\")\n        test_case.assertEqual(x.device, flow.device(\"cuda:1\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_random_generator_and_seed.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport numpy as np\nimport inspect\nimport types\nimport unittest\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\n\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\n# y1 = rand_op1(x)\n# y2 = rand_op2(x)\n# rand_op1 and rand_op2 should have different seed in graph, then lead to different result\ndef _inspect_rand_op_and_args(rand_op, **kwargs):\n    if inspect.isclass(rand_op) and issubclass(rand_op, nn.Module):\n        init_method_signature = inspect.signature(rand_op.__init__)\n\n        module_init_args = dict()\n        for arg_name in list(init_method_signature.parameters.keys())[1:]:\n            if arg_name in kwargs:\n                module_init_args[arg_name] = kwargs.pop(arg_name)\n\n        module_instance = rand_op(**module_init_args)\n        return module_instance, kwargs\n\n    if isinstance(rand_op, types.BuiltinFunctionType):\n        return rand_op, kwargs\n\n    if inspect.isfunction(rand_op):\n        return rand_op, kwargs\n\n    raise ValueError(f\"invalid rand_op {rand_op}, type: {type(rand_op)}\")\n\n\ndef _test_rand_op_unidentical(test_case, rand_op, input=None, **kwargs):\n    rand_op1, kwargs1 = _inspect_rand_op_and_args(rand_op, **kwargs)\n    rand_op2, kwargs2 = _inspect_rand_op_and_args(rand_op, **kwargs)\n\n    if input is None:\n        result1 = rand_op1(**kwargs1)\n        result2 = rand_op2(**kwargs2)\n    else:\n        x1 = input\n        x2 = input.clone()\n        result1 = rand_op1(x1, **kwargs1)\n        result2 = rand_op2(x2, **kwargs2)\n\n    if isinstance(result1, (list, tuple)):\n        result1 = result1[0]\n    if isinstance(result2, (list, tuple)):\n        result2 = result2[0]\n\n    test_case.assertFalse(\n        np.allclose(result1.numpy(), result2.numpy()),\n        f\"\\ninput:\\n{input}\\result1:\\n{result1}\\result2:\\n{result2}\",\n    )\n\n\ndef _test_global_rand_op_with_split(test_case, rand_op, input=None, **kwargs):\n    rand_op, kwargs = _inspect_rand_op_and_args(rand_op, **kwargs)\n    ranks = np.array(range(flow.env.get_world_size()))\n\n    if input is None:\n        device = kwargs.pop(\"device\", None)\n        placement = flow.placement(device, ranks)\n        y = rand_op(placement=placement, sbp=flow.sbp.split(0), **kwargs)\n    else:\n        x = flow.concat([input, input], dim=0)\n        placement = flow.placement(input.device.type, ranks)\n        # local to broadcast global\n        x_broadcast = x.to_global(\n            placement=placement, sbp=flow.sbp.broadcast(), copy=True\n        )\n        x_split = x_broadcast.to_global(sbp=flow.sbp.split(0))\n        y = rand_op(x_split, **kwargs)\n\n    if isinstance(y, (list, tuple)):\n        y = y[0]\n\n    y_broadcast = y.to_global(placement=placement, sbp=flow.sbp.broadcast())\n    half = y_broadcast.shape[0] // 2\n    first_half = y_broadcast[0:half]\n    
second_half = y_broadcast[half:]\n    test_case.assertFalse(np.allclose(first_half.numpy(), second_half.numpy()))\n\n\ndef _test_global_rand_op_with_broadcast(test_case, rand_op, input=None, **kwargs):\n    rand_op, kwargs = _inspect_rand_op_and_args(rand_op, **kwargs)\n    ranks = np.array(range(flow.env.get_world_size()))\n\n    if input is None:\n        device = kwargs.pop(\"device\", \"cpu\")\n        placement = flow.placement(device, ranks)\n        y = rand_op(placement=placement, sbp=flow.sbp.broadcast(), **kwargs)\n    else:\n        placement = flow.placement(input.device.type, ranks)\n        # local to broadcast global\n        x = input.to_global(placement=placement, sbp=flow.sbp.broadcast(), copy=True)\n        y = rand_op(x, **kwargs)\n\n    if isinstance(y, (list, tuple)):\n        y_local = y[0].to_local()\n    else:\n        y_local = y.to_local()\n\n    y_all_ranks = y_local.to_global(placement=placement, sbp=flow.sbp.split(0))\n    y_allgather = y_all_ranks.to_global(sbp=flow.sbp.broadcast())\n    half = y_allgather.shape[0] // 2\n    first_half = y_allgather[0:half]\n    second_half = y_allgather[half:]\n    test_case.assertTrue(np.allclose(first_half.numpy(), second_half.numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRandOpUnidentical(oneflow.unittest.TestCase):\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_usual_rand_op(self):\n        for device in (\"cpu\", \"cuda\"):\n            x = flow.randn(4, 16, device=device)\n            _test_rand_op_unidentical(self, nn.Dropout, x, p=0.5)\n            _test_rand_op_unidentical(self, flow._C.rrelu, x, training=True)\n            _test_rand_op_unidentical(self, nn.init.uniform_, x)\n            _test_rand_op_unidentical(self, flow._C.exponential_, x)\n\n            x1 = flow.rand(4, 16, device=device)\n            _test_rand_op_unidentical(\n                self, flow.multinomial, x1, num_samples=16, replacement=True\n            )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_source_rand_op(self):\n        shape = (4, 16)\n        for device in (\"cpu\", \"cuda\"):\n            _test_rand_op_unidentical(self, flow.rand, size=shape, device=device)\n            _test_rand_op_unidentical(\n                self, flow.normal, mean=0.0, std=1.0, size=shape, device=device\n            )\n            _test_rand_op_unidentical(\n                self, flow.randint, low=0, high=10, size=shape, device=device\n            )\n            _test_rand_op_unidentical(self, flow.randperm, n=32, device=device)\n\n    def test_bernoulli(self):\n        x1 = flow.randn(4, 16)\n        _test_rand_op_unidentical(self, flow.bernoulli, x1, p=0.5)\n        x2 = flow.rand(4, 16)\n        _test_rand_op_unidentical(self, flow.bernoulli, x2)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_random_mask_like(self):\n        x = flow.randn(4, 16, 64).to(\"cuda\")\n        _test_rand_op_unidentical(\n            self,\n            flow._C.fused_scale_tril_softmax_mask_scale,\n            x,\n            p=0.1,\n            diagonal=2,\n            tril_scale_value=-1000,\n        )\n\n\n@flow.unittest.skip_unless_1n2d()\nclass TestGlobalRandOp(oneflow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 4 times in past week\")\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_usual_rand_op_with_split(self):\n        
for device in (\"cpu\", \"cuda\"):\n            x = flow.randn(2, 4, device=device)\n            _test_global_rand_op_with_split(self, nn.Dropout, x, p=0.5)\n            _test_global_rand_op_with_split(self, flow._C.rrelu, x, training=True)\n            _test_global_rand_op_with_split(self, nn.init.uniform_, x)\n            _test_global_rand_op_with_split(self, flow._C.exponential_, x)\n\n            x1 = flow.rand(2, 8, device=device)\n            _test_global_rand_op_with_split(\n                self, flow.multinomial, x1, num_samples=8, replacement=True\n            )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_usual_rand_op_with_broadcast(self):\n        for device in (\"cpu\", \"cuda\"):\n            x = flow.randn(2, 4, device=device)\n            _test_global_rand_op_with_broadcast(self, nn.Dropout, x, p=0.5)\n            _test_global_rand_op_with_broadcast(self, flow._C.rrelu, x, training=True)\n            _test_global_rand_op_with_broadcast(self, nn.init.uniform_, x)\n            _test_global_rand_op_with_broadcast(self, flow._C.exponential_, x)\n\n            x1 = flow.rand(2, 8, device=device)\n            _test_global_rand_op_with_broadcast(\n                self, flow.multinomial, x1, num_samples=8, replacement=True\n            )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_source_rand_op_with_split(self):\n        shape = (4, 4)\n        for device in (\"cpu\", \"cuda\"):\n            _test_global_rand_op_with_split(self, flow.rand, size=shape, device=device)\n            _test_global_rand_op_with_split(\n                self, flow.normal, mean=0.0, std=1.0, size=shape, device=device\n            )\n            _test_global_rand_op_with_split(\n                self, flow.randint, low=0, high=10, size=shape, device=device\n            )\n            _test_global_rand_op_with_split(self, flow.randperm, n=32, device=device)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_source_rand_op_with_broadcast(self):\n        shape = (4, 4)\n        for device in (\"cpu\", \"cuda\"):\n            _test_global_rand_op_with_broadcast(\n                self, flow.rand, size=shape, device=device\n            )\n            _test_global_rand_op_with_broadcast(\n                self, flow.normal, mean=0.0, std=1.0, size=shape, device=device\n            )\n            _test_global_rand_op_with_broadcast(\n                self, flow.randint, low=0, high=10, size=shape, device=device\n            )\n            _test_global_rand_op_with_broadcast(\n                self, flow.randperm, n=32, device=device\n            )\n\n    @unittest.skip(\"skip for now, becase it failed 4 times in past week\")\n    def test_bernoulli_with_split(self):\n        x1 = flow.randn(2, 8)\n        _test_global_rand_op_with_split(self, flow.bernoulli, x1, p=0.5)\n        x2 = flow.rand(2, 8)\n        _test_global_rand_op_with_split(self, flow.bernoulli, x2)\n\n    def test_bernoulli_with_broadcast(self):\n        x1 = flow.randn(2, 8)\n        _test_global_rand_op_with_broadcast(self, flow.bernoulli, x1, p=0.5)\n        x2 = flow.rand(2, 8)\n        _test_global_rand_op_with_broadcast(self, flow.bernoulli, x2)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_random_mask_like_with_split(self):\n        x = flow.randn(2, 16, 64).to(\"cuda\")\n        _test_global_rand_op_with_split(\n            self,\n  
          flow._C.fused_scale_tril_softmax_mask_scale,\n            x,\n            p=0.1,\n            diagonal=0,\n            tril_scale_value=-1000,\n        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_random_mask_like_with_broadcast(self):\n        x = flow.randn(2, 16, 64).to(\"cuda\")\n        _test_global_rand_op_with_broadcast(\n            self,\n            flow._C.fused_scale_tril_softmax_mask_scale,\n            x,\n            p=0.2,\n            diagonal=1,\n            tril_scale_value=-100,\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
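  {
    "path": "python/oneflow/test/modules/example_global_split_vs_broadcast_sketch.py",
    "content": "\"\"\"\nIllustrative sketch, not part of the test suite: the intent behind\ntest_random_generator_and_seed.py above. Under sbp = split(0) each rank holds a\ndifferent shard of one random tensor, while under sbp = broadcast every rank holds\nthe same sample. Needs two ranks to observe (e.g. via oneflow.distributed.launch\nwith --nproc_per_node 2); the file name is hypothetical.\n\"\"\"\n\nimport oneflow as flow\n\nif __name__ == \"__main__\":\n    placement = flow.placement(\"cpu\", ranks=[0, 1])\n    y_split = flow.rand(4, 4, placement=placement, sbp=flow.sbp.split(0))\n    y_broadcast = flow.rand(4, 4, placement=placement, sbp=flow.sbp.broadcast())\n    # split(0): each rank's local shard is a distinct (2, 4) slice;\n    # broadcast: each rank's local tensor is the identical (4, 4) sample\n    print(y_split.to_local().shape, y_broadcast.to_local().shape)\n"
  },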
  {
    "path": "python/oneflow/test/modules/test_randperm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow as flow\nfrom collections import OrderedDict\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\nimport numpy as np\nimport unittest\n\n\ndef _test_randperm_with_generator(test_case, N, device, dtype):\n    generator = flow.Generator()\n    generator.manual_seed(0)\n    y_1 = flow.randperm(N, device=device, dtype=dtype, generator=generator)\n    generator.manual_seed(0)\n    y_2 = flow.randperm(N, device=device, dtype=dtype, generator=generator)\n    test_case.assertTrue(np.allclose(y_1.numpy(), y_2.numpy()))\n    test_case.assertTrue(\n        y_1.device == flow.device(device) and y_2.device == flow.device(device)\n    )\n    test_case.assertTrue(y_1.dtype == dtype and y_2.dtype == dtype)\n\n\ndef _test_randperm_backward(test_case, N, device, dtype):\n    dtype = flow.float32  # fix dtype here as reduce_sum doesn't support all dtypes yet\n    x = flow.randperm(N, device=device, dtype=dtype)\n    x.requires_grad = True\n    y = x.sum()\n    y.backward()\n    test_case.assertTrue(np.allclose(x.grad.numpy(), np.ones(N), 1e-05, 1e-05))\n\n\ndef _test_randperm_randomness(test_case, N, device, dtype):\n    n = np.random.randint(100, 1000)\n    x1 = flow.randperm(n, device=device)\n    x2 = flow.randperm(n, device=device)\n    test_case.assertFalse(np.all(x1.numpy() == x2.numpy()))\n\n\ndef _test_randperm_large_seq_randomness(test_case, N, device, dtype):\n    n = 65536\n    x1 = flow.randperm(n, device=device)\n    x2 = flow.randperm(n, device=device)\n    test_case.assertFalse(np.all(x1.numpy() == x2.numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass Testrandperm(flow.unittest.TestCase):\n    def test_global_naive(test_case):\n        placement = flow.placement(\"cpu\", ranks=[0])\n        sbp = (flow.sbp.broadcast,)\n        x = flow.randperm(10, placement=placement, sbp=sbp)\n        test_case.assertEqual(x.sbp, sbp)\n        test_case.assertEqual(x.placement, placement)\n\n    def test_global_different_types(test_case):\n        for dtype in [\n            flow.uint8,\n            flow.int8,\n            flow.int32,\n            flow.int64,\n            flow.float32,\n            flow.float64,\n        ]:\n            placement = flow.placement(\"cpu\", ranks=[0])\n            sbp = (flow.sbp.broadcast,)\n            x = flow.randperm(10, placement=placement, sbp=sbp, dtype=dtype)\n            test_case.assertEqual(x.dtype, dtype)\n            test_case.assertEqual(x.sbp, sbp)\n            test_case.assertEqual(x.placement, placement)\n\n    def test_randperm(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_functions\"] = [\n            _test_randperm_with_generator,\n            _test_randperm_randomness,\n            _test_randperm_large_seq_randomness,\n        ]\n        arg_dict[\"N\"] = [i for i in range(10, 100, 5)]\n        arg_dict[\"device\"] = [\"cpu\", 
\"cuda\"]\n        arg_dict[\"dtype\"] = [\n            flow.uint8,\n            flow.int8,\n            flow.int32,\n            flow.int64,\n            flow.float32,\n            flow.float64,\n        ]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    def test_randperm_backward(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_functions\"] = [\n            _test_randperm_backward,\n        ]\n        arg_dict[\"N\"] = [i for i in range(10, 100, 5)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"dtype\"] = [flow.float32, flow.float64]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_auto_1(test_case):\n        device = random_device()\n        y = torch.randperm(1, device=device)\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_auto_0(test_case):\n        device = random_device()\n        y = torch.randperm(0, device=device)\n        return y\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n2d()\nclass TestRandpermOnNonDefaultDevice(flow.unittest.TestCase):\n    def test_non_default_device(test_case):\n        x = flow.randperm(3, device=\"cuda:1\")\n        test_case.assertEqual(x.device, flow.device(\"cuda:1\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_reciprocal.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestReciprocalModule(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_flow_reciprocal_list_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int)\n        ).to(device)\n        y = torch.reciprocal(x)\n        return y\n\n    @autotest(check_graph=True)\n    def test_flow_reciprocal_list_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.reciprocal(x)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_reduce.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_reduce(test_case, dst, device):\n    if flow.env.get_rank() == 0:\n        np_arr = np.array(\n            [[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 1:\n        np_arr = np.array(\n            [[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]],\n            dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 2:\n        np_arr = np.array(\n            [[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32,\n        )\n    elif flow.env.get_rank() == 3:\n        np_arr = np.array(\n            [[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32,\n        )\n    x = flow.tensor(np_arr, device=device, dtype=flow.float32)\n    flow._C.local_reduce(x, dst=dst)\n    if flow.env.get_rank() == dst:\n        test_case.assertTrue(\n            np.allclose(\n                x.numpy(),\n                np.array(\n                    [\n                        [24, 26, 25, 43],\n                        [20, 28, 35, 10],\n                        [15, 21, 27, 20],\n                        [21, 31, 30, 12],\n                    ],\n                    dtype=np.float32,\n                ),\n            )\n        )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n4d()\nclass TestReduce(flow.unittest.TestCase):\n    def test_reduce(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"dst\"] = [0, 1, 2, 3]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n\n        for arg in GenArgList(arg_dict):\n            _test_reduce(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
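  {
    "path": "python/oneflow/test/modules/example_local_reduce_expected_sum_sketch.py",
    "content": "\"\"\"\nIllustrative sketch, not part of the test suite: flow._C.local_reduce in\ntest_reduce.py above sums the per-rank tensors elementwise into the dst rank. This\nnumpy-only check reproduces the expected matrix hard-coded in that test; the file\nname is hypothetical.\n\"\"\"\n\nimport numpy as np\n\nrank0 = np.array([[4, 6, 5, 20], [6, 8, 9, 0], [3, 7, 5, 0], [6, 8, 9, 0]], dtype=np.float32)\nrank1 = np.array([[2, 10, 10, 7], [3, 9, 10, 5], [4, 6, 6, 9], [6, 8, 6, 4]], dtype=np.float32)\nrank2 = np.array([[9, 6, 5, 8], [4, 9, 7, 0], [2, 5, 7, 9], [6, 8, 10, 0]], dtype=np.float32)\nrank3 = np.array([[9, 4, 5, 8], [7, 2, 9, 5], [6, 3, 9, 2], [3, 7, 5, 8]], dtype=np.float32)\n\nif __name__ == \"__main__\":\n    # elementwise sum over ranks; the first row is [24, 26, 25, 43], matching the test\n    print(rank0 + rank1 + rank2 + rank3)\n"
  },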
  {
    "path": "python/oneflow/test/modules/test_reduce_sum_like.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_reduce_sum_like(test_case, device):\n    input = flow.tensor(\n        np.ones(shape=(3, 3, 3), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    like_tensor = flow.tensor(\n        np.ones(shape=(3, 1, 1), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = flow._C.reduce_sum_like(input, like_tensor, axis=(1, 2))\n    np_out = np.full(shape=like_tensor.shape, fill_value=9)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_reduce_sum_like_one(test_case, device):\n    input = flow.tensor(\n        np.ones(shape=(1, 2, 3), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    like_tensor = flow.tensor(\n        np.ones(shape=(1, 1), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = flow._C.reduce_sum_like(input, like_tensor, axis=(1, 2))\n    np_out = np.full(like_tensor.shape, 6)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_reduce_sum_like_different_dim(test_case, device):\n    input = flow.tensor(\n        np.ones(shape=(2, 3, 4), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    like_tensor = flow.tensor(\n        np.ones(shape=(3, 1), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = flow._C.reduce_sum_like(input, like_tensor, axis=(0, 2))\n    np_out = np.full(like_tensor.shape, 8)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_reduce_sum_like_different_dim_with_input_axisvec(test_case, device):\n    input = flow.tensor(\n        np.ones(shape=(1, 5, 6, 1, 6), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    like_tensor = flow.tensor(\n        np.ones(shape=(1, 5, 6), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = flow._C.reduce_sum_like(input, like_tensor, axis=(3, 4))\n    np_out = np.full(like_tensor.shape, 6)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_reduce_sum_like_3dim(test_case, device):\n    input = flow.tensor(\n        np.ones(shape=(3, 3, 2), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    like_tensor = flow.tensor(\n        np.ones(shape=(1, 3, 2), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = flow._C.reduce_sum_like(input, like_tensor, axis=(0,))\n    np_out = 
np.full(like_tensor.shape, 3)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_reduce_sum_like_4dim(test_case, device):\n    input = flow.tensor(\n        np.ones(shape=(3, 3, 2, 3), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    like_tensor = flow.tensor(\n        np.ones(shape=(1, 3, 2, 1), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    of_out = flow._C.reduce_sum_like(input, like_tensor, axis=(0, 3))\n    np_out = np.full(like_tensor.shape, 9)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_reduce_sum_like_backward(test_case, device):\n    input = flow.tensor(\n        np.ones(shape=(3, 3, 3), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    like_tensor = flow.tensor(\n        np.ones(shape=(3, 1, 1), dtype=np.float32),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out = flow._C.reduce_sum_like(input, like_tensor, axis=(1, 2))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = np.full(input.shape, 1.0)\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestReduceSumLike(flow.unittest.TestCase):\n    def test_reduce_sum_like(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_reduce_sum_like,\n            _test_reduce_sum_like_one,\n            _test_reduce_sum_like_different_dim,\n            _test_reduce_sum_like_different_dim_with_input_axisvec,\n            _test_reduce_sum_like_3dim,\n            _test_reduce_sum_like_4dim,\n            _test_reduce_sum_like_backward,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
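  {
    "path": "python/oneflow/test/modules/example_reduce_sum_like_keepdims_sketch.py",
    "content": "\"\"\"\nIllustrative sketch, not part of the test suite: for the first case in\ntest_reduce_sum_like.py above, where the like tensor keeps every reduced axis at\nsize one, reduce_sum_like behaves like a numpy sum with keepdims=True. The file\nname is hypothetical.\n\"\"\"\n\nimport numpy as np\n\nif __name__ == \"__main__\":\n    x = np.ones((3, 3, 3), dtype=np.float32)\n    out = x.sum(axis=(1, 2), keepdims=True)  # shape (3, 1, 1), every entry 9.0\n    assert out.shape == (3, 1, 1)\n    print(out.reshape(-1))  # [9. 9. 9.]\n"
  },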
  {
    "path": "python/oneflow/test/modules/test_reflection_pad.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import (\n    Array2Numpy,\n    FlattenArray,\n    GenArgList,\n    Index2Coordinate,\n)\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef gen_numpy_test_sample(input, padding):\n    (c_idx, h_idx, w_idx) = (1, 2, 3)\n    pad_left = padding[0]\n    pad_right = padding[1]\n    pad_top = padding[2]\n    pad_bottom = padding[3]\n    pad_shape = ((0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right))\n\n    def _np_reflection_pad2d(input, pad_shape):\n        numpy_reflect = np.pad(input, pad_shape, \"reflect\")\n        return numpy_reflect\n\n    def _np_reflection_pad2d_grad(src, dest):\n        (dx_height, dx_width) = (input.shape[h_idx], input.shape[w_idx])\n        (dy_height, dy_width) = (output.shape[h_idx], output.shape[w_idx])\n        numpy_src = np.ones(src.shape, np.int32)\n        numpy_dest = np.zeros(dest.shape, np.int32)\n        array_src = FlattenArray(numpy_src)\n        array_dest = FlattenArray(numpy_dest)\n        src_num = src.shape[c_idx] * src.shape[h_idx] * src.shape[w_idx]\n        dest_num = dest.shape[c_idx] * dest.shape[h_idx] * dest.shape[w_idx]\n        elements_num = src.shape[0] * src_num\n        for iter_n in range(elements_num):\n            coords = Index2Coordinate(iter_n, src.shape)\n            (n, c, i, j) = (coords[0], coords[c_idx], coords[h_idx], coords[w_idx])\n            ip_x = ip_y = 0\n            if j < pad_left:\n                ip_x = pad_left * 2 - j\n            elif j >= pad_left and j < dx_width + pad_left:\n                ip_x = j\n            else:\n                ip_x = (dx_width + pad_left - 1) * 2 - j\n            if i < pad_top:\n                ip_y = pad_top * 2 - i\n            elif i >= pad_top and i < dx_height + pad_top:\n                ip_y = i\n            else:\n                ip_y = (dx_height + pad_top - 1) * 2 - i\n            ip_x = ip_x - pad_left\n            ip_y = ip_y - pad_top\n            src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j\n            dest_index = (\n                n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x\n            )\n            array_dest[dest_index] += array_src[src_index]\n        numpy_dest = Array2Numpy(array_dest, dest.shape)\n        return numpy_dest\n\n    output = _np_reflection_pad2d(input, pad_shape)\n    grad = _np_reflection_pad2d_grad(output, input)\n    return (output, grad)\n\n\ndef _test_reflection_pad2d(test_case, shape, padding, device):\n    np_input = np.random.randn(*shape).astype(np.float32)\n    of_input = flow.tensor(\n        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    if isinstance(padding, int):\n        boundary = [padding, padding, padding, padding]\n    elif 
isinstance(padding, tuple) and len(padding) == 4:\n        boundary = [padding[0], padding[1], padding[2], padding[3]]\n    else:\n        raise ValueError(\"padding must be an int or a 4-element tuple!\")\n    (np_out, np_grad) = gen_numpy_test_sample(np_input, boundary)\n    layer = flow.nn.ReflectionPad2d(padding=padding)\n    of_out = layer(of_input)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    of_out = of_out.sum()\n    of_out.backward()\n    test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 0.0001, 0.0001))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestReflectionPadModule(flow.unittest.TestCase):\n    def test_reflection_pad2d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(1, 2, 3, 4), (8, 3, 4, 4)]\n        arg_dict[\"padding\"] = [2, (1, 1, 2, 2)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_reflection_pad2d(test_case, *arg)\n\n    @autotest(n=5)\n    def test_reflection_pad_1d_with_3d_input(test_case):\n        c = random(1, 6).to(int)\n        w = random(1, 6).to(int)\n        m = torch.nn.ReflectionPad1d(padding=random(low=0, high=5).to(int))\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=3, dim1=c, dim2=w).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_reflection_pad_1d_with_2d_input(test_case):\n        w = random(1, 6).to(int)\n        m = torch.nn.ReflectionPad1d(padding=random(low=0, high=5).to(int))\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=2, dim1=w).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_reflection_pad_2d_with_random_data(test_case):\n        c = random(1, 6).to(int)\n        h = random(1, 6).to(int)\n        w = random(1, 6).to(int)\n        m = torch.nn.ReflectionPad2d(padding=random(low=0, high=5).to(int))\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=4, dim1=c, dim2=h, dim3=w).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_functional_reflection_pad_1d_with_random_data(test_case):\n        c = random(1, 6).to(int)\n        w = random(1, 6).to(int)\n        pad = [1, 2]\n        device = random_device()\n        x = random_tensor(ndim=3, dim1=c, dim2=w).to(device)\n        y = torch.nn.functional.pad(input=x, pad=pad, mode=\"reflect\")\n        return y\n\n    @autotest(n=5)\n    def test_functional_reflection_pad_2d_with_random_data(test_case):\n        c = random(1, 6).to(int)\n        h = random(1, 6).to(int)\n        w = random(1, 6).to(int)\n        pad = [0, 1, 2, 3]\n        device = random_device()\n        x = random_tensor(ndim=4, dim1=c, dim2=h, dim3=w).to(device)\n        y = torch.nn.functional.pad(input=x, pad=pad, mode=\"reflect\")\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_repeat.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRepeat(flow.unittest.TestCase):\n    @autotest(n=10)\n    def test_flow_tensor_repeat_with_random_data(test_case):\n        x = random_tensor(ndim=2, dim0=1, dim1=2)\n        sizes = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int))\n        y = x.repeat(sizes)\n        return y\n\n    @autotest(n=10, auto_backward=False)\n    def test_flow_tensor_repeat_bool_with_random_data(test_case):\n        x = random_tensor(ndim=2, dim0=1, dim1=2).to(torch.bool)\n        sizes = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int))\n        y = x.repeat(sizes)\n        return y\n\n    @autotest(n=10)\n    def test_flow_tensor_repeat_with_0dim_data(test_case):\n        x = random_tensor(ndim=0)\n        sizes = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int))\n        y = x.repeat(sizes)\n        return y\n\n    @autotest(n=5, auto_backward=False)\n    def test_complicated_repeat_case(test_case):\n        x = torch.ones(224, 224)\n        y = torch.triu(x, diagonal=1).repeat(32, 1, 1)\n        z = y.byte()\n        return z\n\n    @autotest(n=5)\n    def test_flow_tensor_0size_with_random_data(test_case):\n        x = random_tensor(ndim=2, dim0=3, dim1=1)\n        sizes = (1, 0)\n        y = x.repeat(sizes)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_repeat_interleave.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nimport torch as torch_original\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRepeatInterLeave(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_int_repeat_interleave_dim_none(test_case):\n        x = random_tensor(ndim=2, dim0=1, dim1=2)\n        y = torch.repeat_interleave(x, 2)\n        return y\n\n    @autotest(n=5)\n    def test_flow_int_repeat_interleave_with_dim(test_case):\n        x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3)\n        dim = random(low=0, high=2).to(int)\n        y = torch.repeat_interleave(x, 2, dim)\n        return y\n\n    @autotest(n=5)\n    def test_flow_tensor_repeat_interleave_dim(test_case):\n        x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3)\n        y = random_tensor(ndim=1, dim0=2, dtype=int, low=0, high=4)\n        z = torch.repeat_interleave(x, y, 1)\n        return z\n\n    @autotest(n=5)\n    def test_flow_tensor_repeat_interleave_dim_with_output_size(test_case):\n        x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3)\n        y = random_tensor(ndim=1, dim0=2, dtype=int, low=0, high=4)\n        z = torch.repeat_interleave(x, y, 1, output_size=2)\n        return z\n\n    def test_flow_tensor_repeat_interleave_0size_tensor(test_case):\n        np_arr = np.array(\n            [\n                [[0.8548, 0.0436, 0.7977], [0.1919, 0.4191, 0.2186]],\n                [[0.4741, 0.8896, 0.6859], [0.5223, 0.7803, 0.1134]],\n            ]\n        )\n        x_torch = torch_original.tensor(np_arr)\n        x_torch.requires_grad = True\n        y_torch = torch_original.tensor([0, 0])\n        z_torch = torch_original.repeat_interleave(x_torch, y_torch, 1)\n        z_torch.sum().backward()\n\n        x_flow = flow.tensor(np_arr)\n        x_flow.requires_grad = True\n        y_flow = flow.tensor([0, 0])\n        z_flow = flow.repeat_interleave(x_flow, y_flow, 1)\n        z_flow.sum().backward()\n        test_case.assertTrue(np.array_equal(x_torch.grad.numpy(), x_flow.grad.numpy()))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_replication_pad.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import (\n    Array2Numpy,\n    FlattenArray,\n    GenArgList,\n    Index2Coordinate,\n)\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _np_replication_pad2d_grad(src, dest, padding):\n    (c_idx, h_idx, w_idx) = (1, 2, 3)\n    pad_left = padding[0]\n    pad_right = padding[1]\n    pad_top = padding[2]\n    pad_bottom = padding[3]\n    (dx_height, dx_width) = (dest.shape[h_idx], dest.shape[w_idx])\n    (dy_height, dy_width) = (src.shape[h_idx], src.shape[w_idx])\n    numpy_src = np.ones(src.shape, np.int32)\n    numpy_dest = np.zeros(dest.shape, np.int32)\n    array_src = FlattenArray(numpy_src)\n    array_dest = FlattenArray(numpy_dest)\n    src_num = src.shape[c_idx] * src.shape[h_idx] * src.shape[w_idx]\n    dest_num = dest.shape[c_idx] * dest.shape[h_idx] * dest.shape[w_idx]\n    elements_num = src.shape[0] * src_num\n    for iter_n in range(elements_num):\n        coords = Index2Coordinate(iter_n, src.shape)\n        (n, c, i, j) = (coords[0], coords[c_idx], coords[h_idx], coords[w_idx])\n        ip_x = ip_y = 0\n        if j < pad_left:\n            ip_x = pad_left\n        elif j >= pad_left and j < dx_width + pad_left:\n            ip_x = j\n        else:\n            ip_x = dx_width + pad_left - 1\n        if i < pad_top:\n            ip_y = pad_top\n        elif i >= pad_top and i < dx_height + pad_top:\n            ip_y = i\n        else:\n            ip_y = dx_height + pad_top - 1\n        ip_x = ip_x - pad_left\n        ip_y = ip_y - pad_top\n        src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j\n        dest_index = n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x\n        array_dest[dest_index] += array_src[src_index]\n    numpy_dest = Array2Numpy(array_dest, dest.shape)\n    return numpy_dest\n\n\ndef _test_ReplicationPad2d(test_case, shape, padding, device):\n    np_input = np.random.random(shape).astype(np.float32)\n    of_input = flow.tensor(\n        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    if isinstance(padding, int):\n        np_boundary = ((0, 0), (0, 0), (padding, padding), (padding, padding))\n        boundry = [padding, padding, padding, padding]\n    elif isinstance(padding, (tuple, int)) and len(padding) == 4:\n        np_boundary = (\n            (0, 0),\n            (0, 0),\n            (padding[2], padding[3]),\n            (padding[0], padding[1]),\n        )\n        boundry = [padding[0], padding[1], padding[2], padding[3]]\n    else:\n        raise ValueError(\"padding must be in or list or tuple!\")\n    layer = flow.nn.ReplicationPad2d(padding=padding)\n    of_out = layer(of_input)\n    np_out = np.pad(np_input, np_boundary, mode=\"edge\")\n    
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_out_grad = _np_replication_pad2d_grad(np_out, np_input, boundry)\n    test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_out_grad, 0.001, 0.001))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestReplicationPadModule(flow.unittest.TestCase):\n    def test_ReplicationPad2d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(1, 2, 3, 4), (8, 3, 4, 4)]\n        arg_dict[\"padding\"] = [2, (1, 1, 2, 2)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_ReplicationPad2d(test_case, *arg)\n\n    @autotest(n=5)\n    def test_replication_pad1d_with_3d_input(test_case):\n        c = random(1, 6).to(int)\n        w = random(1, 6).to(int)\n        pad = random(low=0, high=5).to(int)\n        m = torch.nn.ReplicationPad1d(padding=pad)\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=3, dim1=c, dim2=w).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_replication_pad1d_with_2d_input(test_case):\n        w = random(1, 6).to(int)\n        m = torch.nn.ReplicationPad1d(padding=random(low=0, high=5).to(int))\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=2, dim1=w).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_replication_pad2d_with_random_data(test_case):\n        c = random(1, 6).to(int)\n        h = random(1, 6).to(int)\n        w = random(1, 6).to(int)\n        m = torch.nn.ReplicationPad2d(padding=random(low=0, high=5).to(int))\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(ndim=4, dim1=c, dim2=h, dim3=w).to(device)\n        y = m(x)\n        return y\n\n    @autotest(n=5)\n    def test_functional_replication_pad_1d_with_random_data(test_case):\n        c = random(1, 6).to(int)\n        w = random(1, 6).to(int)\n        pad = [0, 1]\n        device = random_device()\n        x = random_tensor(ndim=3, dim1=c, dim2=w).to(device)\n        y = torch.nn.functional.pad(input=x, pad=pad, mode=\"replicate\")\n        return y\n\n    @autotest(n=5)\n    def test_functional_replication_pad_2d_with_random_data(test_case):\n        c = random(1, 6).to(int)\n        h = random(1, 6).to(int)\n        w = random(1, 6).to(int)\n        pad = [0, 1, 2, 3]\n        device = random_device()\n        x = random_tensor(ndim=4, dim1=c, dim2=h, dim3=w).to(device)\n        y = torch.nn.functional.pad(input=x, pad=pad, mode=\"replicate\")\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_reshape.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_reshape(test_case, device):\n    x = np.array(\n        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]\n    ).astype(np.float32)\n    input = flow.tensor(x, dtype=flow.float32, device=flow.device(device))\n    of_shape = flow.reshape(input, shape=[2, 2, 2, -1]).numpy().shape\n    np_shape = (2, 2, 2, 2)\n    test_case.assertTrue(np.array_equal(of_shape, np_shape))\n\n\ndef _test_reshape_tuple(test_case, device):\n    x = np.array(\n        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]\n    ).astype(np.float32)\n    input = flow.tensor(x, dtype=flow.float32, device=flow.device(device))\n    of_shape = flow.reshape(input, shape=(2, 2, 2, -1)).numpy().shape\n    np_shape = (2, 2, 2, 2)\n    test_case.assertTrue(np.array_equal(of_shape, np_shape))\n\n\ndef _test_reshape_backward(test_case, device):\n    x = np.array(\n        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]\n    ).astype(np.float32)\n    input = flow.tensor(\n        x, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_out = flow.reshape(input, shape=[2, 2, 2, -1]).sum()\n    of_out.backward()\n    np_grad = np.array(\n        [\n            [1.0, 1.0, 1.0, 1.0],\n            [1.0, 1.0, 1.0, 1.0],\n            [1.0, 1.0, 1.0, 1.0],\n            [1.0, 1.0, 1.0, 1.0],\n        ]\n    )\n    test_case.assertTrue(np.allclose(np_grad, input.grad.numpy(), 0.0001, 0.0001))\n\n\ndef _test_reshape_scalar(test_case, device):\n    x = flow.tensor(2.0, device=flow.device(device))\n    test_case.assertTrue(np.array_equal(x.shape, ()))\n    a = flow.reshape(x, (1,))\n    test_case.assertTrue(np.array_equal(a.shape, (1,)))\n    b = flow.reshape(x, (1, 1, 1, 1,))\n    test_case.assertTrue(np.array_equal(b.shape, (1, 1, 1, 1)))\n    c = flow.reshape(b, ())\n    test_case.assertTrue(np.array_equal(c.shape, ()))\n    d = flow.reshape(x, ())\n    test_case.assertTrue(np.array_equal(d.shape, ()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModule(flow.unittest.TestCase):\n    def test_reshape(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_reshape,\n            _test_reshape_tuple,\n            _test_reshape_backward,\n            _test_reshape_scalar,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5)\n    def test_reshape_flow_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = torch.reshape(x, 
shape=(-1,))\n        return y\n\n    @autotest(n=5)\n    def test_reshape_flow_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.reshape(x, shape=(-1,))\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_reshape_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 0, 3).to(device)\n        y = torch.reshape(\n            x, shape=(random(0, 5).to(int).value(), 0, random(0, 5).to(int).value())\n        )\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_reshape_flow_bool_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device=device, dtype=torch.bool)\n        y = torch.reshape(x, shape=(-1,))\n        return y\n\n    @autotest(n=2, auto_backward=False, check_graph=True)\n    def test_reshape_like(test_case):\n        device = random_device()\n        shape = [random(1, 5).to(int).value() for _ in range(4)]\n        like_shape = np.random.choice(\n            np.array(shape), len(shape), replace=False\n        ).tolist()\n        x = (\n            random_tensor(4, *shape, requires_grad=False)\n            .to(device=device)\n            .requires_grad_()\n        )\n        y = (\n            random_tensor(4, *like_shape)\n            .to(device=device)\n            .requires_grad_(random_bool())\n        )\n        # forward\n        of_z = flow._C.reshape_like(x.oneflow, y.oneflow)\n        torch_z = torch.pytorch.reshape(x.pytorch, like_shape)\n        test_case.assertTrue(\n            np.array_equal(of_z.numpy(), torch_z.detach().cpu().numpy())\n        )\n        # backward\n        of_z.sum().backward()\n        torch_z.sum().backward()\n        test_case.assertTrue(\n            np.array_equal(\n                x.grad.oneflow.numpy(), x.grad.pytorch.detach().cpu().numpy()\n            )\n        )\n\n    @profile(torch.reshape)\n    def profile_reshape(test_case):\n        torch.reshape(torch.ones(50, 20), (20, 50))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_reshape_sbp.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport os\nimport oneflow.unittest\nimport oneflow as flow\n\n\n@flow.unittest.skip_unless_1n2d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestReshapeSbp(flow.unittest.TestCase):\n    def test_reshape_sbp(test_case):\n        input = flow.rand(\n            9, 9, 8, placement=flow.placement(\"cuda\", [0, 1]), sbp=flow.sbp.split(0)\n        )\n\n        output = input.view(81, 8)\n        test_case.assertTrue(output.sbp[0] != flow.sbp.split(0))\n\n\n@flow.unittest.skip_unless_1n4d()\nclass TestReshapeNdSbp(flow.unittest.TestCase):\n    def test_reshape_nd_sbp(test_case):\n        in_shape = (8, 4)\n        out_shape = (2, 4, 4)\n        P = flow.placement(\"cpu\", [[0, 1], [2, 3]])\n        in_sbp = [flow.sbp.split(0), flow.sbp.split(0)]\n        input = flow.rand(*in_shape, placement=P, sbp=in_sbp)\n        output = input.view(*out_shape)\n        out_sbp = output.sbp\n        test_case.assertTrue(len(in_sbp) == len(out_sbp))\n        test_case.assertTrue(out_sbp[0] == flow.sbp.split(0))\n        test_case.assertTrue(out_sbp[1] == flow.sbp.split(1))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_resnet_load_torch_weight_compatibile.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport numpy as np\nimport torch\nimport torchvision.models as models_torch\nimport flowvision.models as models_flow\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestResNet18LoadWeightCompatibile(flow.unittest.TestCase):\n    def test_resnet18_load_weight_compatibile(test_case):\n        resnet18_torch = models_torch.resnet18(pretrained=True)\n        resnet18_flow = models_flow.resnet18()\n        parameters = resnet18_torch.state_dict()\n        for key, value in parameters.items():\n            val = value.detach().cpu().numpy()\n            parameters[key] = val\n\n        resnet18_flow.load_state_dict(parameters)\n        torch_input = torch.randn(1, 3, 224, 224)\n        flow_input = flow.tensor(torch_input.cpu().numpy())\n        torch_output = resnet18_torch(torch_input)\n        flow_output = resnet18_flow(flow_input)\n        test_case.assertTrue(\n            np.allclose(\n                torch_output.detach().numpy(), flow_output.numpy(), atol=1e-4, rtol=1e-3\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_rmsnorm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport numpy as np\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\nimport oneflow.unittest\nimport torch\n\n\ndef _get_norm_dims(shape, normalized_shape):\n    lpad = len(shape) - len(normalized_shape)\n    assert lpad >= 0\n    return tuple(range(lpad, len(shape)))\n\n\ndef _torch_rmsnorm(x, weight, normalized_shape=None, eps=1e-6):\n    if weight is not None:\n        normalized_shape = weight.shape\n    else:\n        assert normalized_shape is not None\n    norm_dims = _get_norm_dims(x.shape, normalized_shape)\n    root_mean = torch.mean(x * x, dim=norm_dims, keepdim=True)\n    rms = torch.rsqrt(root_mean + eps)\n    normed = x * rms\n    return normed * weight if weight is not None else normed\n\n\ndef _test_rmsnorm(\n    test_case,\n    shape,\n    normalized_shape,\n    affine=True,\n    eps=1e-6,\n    dtype=flow.float32,\n    device=\"cuda\",\n):\n    np_x = np.random.randn(*shape).astype(np.float32)\n    np_weight = (\n        np.random.randn(*normalized_shape).astype(np.float32) if affine else None\n    )\n\n    torch_dtype = torch.float16 if dtype is flow.float16 else torch.float32\n    torch_x = torch.tensor(np_x).to(device=device, dtype=torch_dtype)\n    torch_weight = (\n        torch.tensor(np_weight).to(device=device, dtype=torch_dtype) if affine else None\n    )\n    torch_x.requires_grad_(True)\n    if affine:\n        torch_weight.requires_grad_(True)\n    torch_y = _torch_rmsnorm(torch_x, torch_weight, normalized_shape, eps)\n\n    np_rand_init_grad = np.random.randn(*tuple(torch_y.shape)).astype(np.float32)\n    torch_rand_init_grad = torch.tensor(np_rand_init_grad).to(\n        device=device, dtype=torch_dtype\n    )\n    (torch_y * torch_rand_init_grad).sum().backward()\n\n    torch_y = torch_y.detach().cpu().numpy()\n    torch_x_grad = torch_x.grad.detach().cpu().numpy()\n    if affine:\n        torch_weight_grad = torch_weight.grad.detach().cpu().numpy()\n\n    x = flow.tensor(np_x).to(device=device, dtype=dtype)\n    weight = flow.tensor(np_weight).to(device=device, dtype=dtype) if affine else None\n    x.requires_grad_(True)\n    if affine:\n        weight.requires_grad_(True)\n    y = flow._C.rms_norm(x, weight, normalized_shape, eps)\n    # np_rand_init_grad = np.random.randn(*tuple(y.shape)).astype(np.float32)\n    rand_init_grad = flow.tensor(np_rand_init_grad).to(device=device, dtype=dtype)\n    (y * rand_init_grad).sum().backward()\n\n    y = y.detach().cpu().numpy()\n    x_grad = x.grad.detach().cpu().numpy()\n    if affine:\n        weight_grad = weight.grad.detach().cpu().numpy()\n\n    def compare(a, b, a_name, b_name, atol=1e-5, rtol=1e-8):\n        test_case.assertTrue(\n            np.allclose(a, b, atol=atol, rtol=rtol),\n            f\"\\n{'=' * 80}\"\n            f\"\\n{a_name}:\"\n            f\"\\n{a}\"\n            f\"\\n{'-' * 80}\"\n            f\"\\n{b_name}:\"\n            
f\"\\n{b}\"\n            f\"\\n{'-' * 80}\"\n            f\"\\ndiff:\"\n            f\"\\n{a - b}\"\n            f\"\\n{'*' * 80}\"\n            f\"\\nshape={shape}\"\n            f\"\\normalized_shape={normalized_shape}\"\n            f\"\\naffine={affine}\"\n            f\"\\ndtype={dtype}\"\n            f\"\\ndevice={device}\"\n            f\"\\n{a_name} vs. {b_name} max abs diff: {np.max(np.abs(a - b))}\",\n        )\n\n    if dtype is flow.float16:\n        compare(y, torch_y, \"y\", \"torch_y\", 1e-3, 1e-2)\n        compare(x_grad, torch_x_grad, \"x_grad\", \"torch_x_grad\", 1e-2, 1e-2)\n        if affine:\n            compare(\n                weight_grad,\n                torch_weight_grad,\n                \"weight_grad\",\n                \"torch_weight_grad\",\n                0.1,\n                0.1,\n            )\n    else:\n        compare(y, torch_y, \"y\", \"torch_y\")\n        compare(x_grad, torch_x_grad, \"x_grad\", \"torch_x_grad\")\n        if affine:\n            compare(\n                weight_grad,\n                torch_weight_grad,\n                \"weight_grad\",\n                \"torch_weight_grad\",\n                1e-5,\n                1e-4,\n            )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestRMSNorm(flow.unittest.TestCase):\n    def test_real_example(test_case):\n        _test_rmsnorm(\n            test_case,\n            shape=[512, 4, 768],\n            normalized_shape=[768],\n            affine=True,\n            dtype=flow.float16,\n            device=\"cuda\",\n        )\n\n    def test_no_affine(test_case):\n        _test_rmsnorm(\n            test_case, shape=[4, 16], normalized_shape=[16], affine=False,\n        )\n\n    def test_warp_impl(test_case):\n        _test_rmsnorm(\n            test_case, shape=[32, 1024], normalized_shape=[1024], dtype=flow.float16,\n        )\n        _test_rmsnorm(test_case, shape=[16, 512], normalized_shape=[512])\n        _test_rmsnorm(test_case, shape=[15, 512], normalized_shape=[512])\n        _test_rmsnorm(test_case, shape=[16, 511], normalized_shape=[511])\n        _test_rmsnorm(test_case, shape=[13, 499], normalized_shape=[499])\n\n    def test_block_smem_impl(test_case):\n        _test_rmsnorm(\n            test_case, shape=[16, 2048], normalized_shape=[2048], dtype=flow.float16,\n        )\n        _test_rmsnorm(test_case, shape=[8, 1536], normalized_shape=[1536])\n        _test_rmsnorm(test_case, shape=[8, 2048], normalized_shape=[2048])\n        _test_rmsnorm(test_case, shape=[7, 1536], normalized_shape=[1536])\n        _test_rmsnorm(test_case, shape=[8, 1533], normalized_shape=[1533])\n        _test_rmsnorm(test_case, shape=[7, 1533], normalized_shape=[1533])\n\n    @unittest.skip(\"skip for now, becase it failed 4 times in past week\")\n    def test_block_uncached_impl(test_case):\n        _test_rmsnorm(\n            test_case,\n            shape=[16, 1024 * 1024],\n            normalized_shape=[1024 * 1024],\n            dtype=flow.float16,\n        )\n        _test_rmsnorm(\n            test_case, shape=[8, 1024], normalized_shape=[1024], dtype=flow.double\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_roc_auc_score.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nfrom oneflow.test_utils.test_util import GenArgList\nfrom sklearn.metrics import roc_auc_score\n\n\ndef _test_roc_auc_score(test_case, label_dtype, pred_dtype):\n    inputs = [\n        {\"label\": [0, 0, 1, 1], \"pred\": [0.1, 0.4, 0.35, 0.8], \"score\": 0.75},\n        {\"label\": [0, 1, 0, 1], \"pred\": [0.5, 0.5, 0.5, 0.5], \"score\": 0.5},\n    ]\n    for data in inputs:\n        label = flow.tensor(data[\"label\"], dtype=label_dtype)\n        pred = flow.tensor(data[\"pred\"], dtype=pred_dtype)\n        of_score = flow.roc_auc_score(label, pred)\n        test_case.assertTrue(np.allclose(of_score.numpy()[0], data[\"score\"]))\n\n\ndef _compare_roc_auc_score(test_case, label_dtype, pred_dtype):\n    n_examples = 16384\n    label = np.random.randint(0, 2, n_examples)\n    pred = np.random.random(n_examples)\n    score = roc_auc_score(label, pred)\n\n    label = flow.tensor(label, dtype=label_dtype)\n    pred = flow.tensor(pred, dtype=pred_dtype)\n    of_score = flow.roc_auc_score(label, pred)\n\n    test_case.assertTrue(np.allclose(of_score.numpy()[0], score))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestNMS(flow.unittest.TestCase):\n    def test_roc_auc_score(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_roc_auc_score, _compare_roc_auc_score]\n        arg_dict[\"label_dtype\"] = [\n            flow.double,\n            flow.int32,\n            flow.float,\n            flow.int64,\n            flow.int8,\n            flow.uint8,\n        ]\n        arg_dict[\"pred_dtype\"] = [flow.float]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_roi_align.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport math\n\nimport oneflow as flow\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ninput_np = np.array(\n    [\n        [\n            [\n                [\n                    0.33840093,\n                    1.1469249,\n                    1.0410756,\n                    -0.8350606,\n                    -1.782742,\n                    -0.00350855,\n                    -0.45829752,\n                    -1.0764053,\n                ],\n                [\n                    -0.4169678,\n                    -0.07322863,\n                    1.5186151,\n                    1.3238515,\n                    -0.3002863,\n                    0.90660757,\n                    -0.2955834,\n                    1.5069526,\n                ],\n                [\n                    0.3829125,\n                    1.0149552,\n                    -0.5808607,\n                    -0.4644214,\n                    1.2142111,\n                    0.668561,\n                    1.0866925,\n                    0.16446872,\n                ],\n                [\n                    0.14043295,\n                    -0.55108964,\n                    -0.8154048,\n                    1.1554539,\n                    2.421505,\n                    -0.54017824,\n                    0.32610297,\n                    -1.0632077,\n                ],\n                [\n                    -0.6218423,\n                    0.6000421,\n                    0.3742695,\n                    0.11130165,\n                    0.9991065,\n                    -0.28596586,\n                    -0.05164787,\n                    0.07725058,\n                ],\n                [\n                    0.6141537,\n                    0.2919493,\n                    0.2101646,\n                    -0.16639,\n                    1.145933,\n                    0.08825321,\n                    0.9865119,\n                    0.47285828,\n                ],\n                [\n                    -1.5073836,\n                    -0.8056736,\n                    -0.7402776,\n                    -0.9932287,\n                    0.74761075,\n                    -0.46474454,\n                    -0.22881153,\n                    0.6082243,\n                ],\n                [\n                    0.8328902,\n                    0.17223845,\n                    0.48917648,\n                    -1.6264182,\n                    0.248678,\n                    -1.2603166,\n                    1.2644174,\n                    0.06434552,\n                ],\n            ]\n        ],\n        [\n            [\n                [\n                    0.6627289,\n                    0.68173873,\n                    0.17659399,\n                    0.17474514,\n                    0.72995424,\n                    -0.47240442,\n                    
0.27204773,\n                    -0.5277862,\n                ],\n                [\n                    0.23609516,\n                    0.9604236,\n                    0.78075147,\n                    0.26125216,\n                    0.72746485,\n                    0.04412199,\n                    0.04948105,\n                    -0.08477508,\n                ],\n                [\n                    0.8646437,\n                    -0.20755729,\n                    1.0184883,\n                    0.06346282,\n                    -0.18039183,\n                    0.56243396,\n                    -0.07350786,\n                    -1.8523406,\n                ],\n                [\n                    -0.2267861,\n                    -1.6466936,\n                    2.1746075,\n                    -1.2284307,\n                    0.74488103,\n                    -0.13243976,\n                    -0.9046582,\n                    -2.2992454,\n                ],\n                [\n                    -0.56131303,\n                    -0.17723852,\n                    -0.6063047,\n                    2.4105318,\n                    0.96672636,\n                    -1.8386889,\n                    1.1021106,\n                    -0.65429336,\n                ],\n                [\n                    2.0618255,\n                    -0.86972237,\n                    -0.59159493,\n                    0.9894253,\n                    -0.26607743,\n                    -0.395585,\n                    -0.44035113,\n                    -0.663197,\n                ],\n                [\n                    -0.02398485,\n                    -0.04574186,\n                    -0.43163615,\n                    -0.42599657,\n                    -2.751177,\n                    -0.35520887,\n                    -0.413676,\n                    2.0098279,\n                ],\n                [\n                    1.5619192,\n                    -2.4961088,\n                    0.08771367,\n                    -2.289146,\n                    1.0729461,\n                    0.7120767,\n                    -0.09780294,\n                    -1.6628668,\n                ],\n            ]\n        ],\n    ]\n)\n\nrois_np = np.array(\n    [\n        [1.0, 2.0, 1.0324688, 2.5, 3.90168],\n        [1.0, 2.5, 2.8329468, 3.5, 3.2008305],\n        [0.0, 1.0, 1.6188955, 2.0, 0.99051666],\n        [1.0, 1.0, 1.843338, 1.0, 3.9240131],\n        [1.0, 2.0, 2.798994, 3.5, 1.2012959],\n        [0.0, 0.5, 2.7753997, 3.0, 0.8280029],\n        [1.0, 0.5, 2.167975, 2.0, 2.067833],\n        [0.0, 0.5, 2.6843219, 2.0, 3.9924717],\n        [0.0, 2.0, 2.8996983, 3.5, 2.356554],\n        [0.0, 1.5, 0.34730053, 3.0, 2.8540745],\n        [0.0, 0.0, 2.096885, 0.5, 3.357812],\n        [0.0, 1.5, 0.10133362, 3.0, 0.18236923],\n        [1.0, 1.0, 1.609498, 1.5, 3.8893862],\n        [0.0, 1.5, 0.03415012, 1.5, 1.2880297],\n        [0.0, 0.5, 3.9403543, 2.0, 3.8870106],\n        [0.0, 0.0, 3.7515945, 3.5, 0.5866394],\n        [1.0, 1.5, 1.7729645, 2.0, 1.2372265],\n        [1.0, 0.0, 1.5092888, 2.0, 3.1585617],\n        [1.0, 0.0, 2.9033833, 1.5, 1.659832],\n        [1.0, 0.5, 1.9115062, 3.0, 1.066021],\n        [0.0, 1.5, 3.185645, 2.0, 0.20558739],\n        [1.0, 2.0, 0.3081894, 2.5, 2.4888725],\n        [0.0, 0.5, 3.5662794, 3.5, 2.8792458],\n        [1.0, 0.5, 2.556768, 2.5, 2.1553097],\n        [0.0, 1.0, 1.397994, 3.5, 0.77407074],\n        [0.0, 0.5, 3.1722808, 3.5, 2.5378036],\n        [0.0, 0.5, 0.11013985, 3.5, 0.8963146],\n        
[0.0, 2.0, 1.1824799, 2.0, 3.2211132],\n        [1.0, 0.0, 3.9227288, 2.0, 2.0894089],\n        [0.0, 1.0, 0.79490566, 1.5, 3.4291687],\n    ]\n)\n\ninput_grad_np = np.array(\n    [\n        [\n            [\n                [\n                    0.2517704,\n                    1.7398968,\n                    8.248332,\n                    16.302334,\n                    11.048147,\n                    10.059495,\n                    2.800579,\n                    0.24844748,\n                ],\n                [\n                    0.790752,\n                    3.154358,\n                    13.0182705,\n                    15.519342,\n                    7.0133696,\n                    6.28652,\n                    3.9538488,\n                    0.51601994,\n                ],\n                [\n                    0.7077478,\n                    3.6854784,\n                    19.228241,\n                    22.597464,\n                    10.153106,\n                    6.2180595,\n                    3.5736852,\n                    0.44621366,\n                ],\n                [\n                    1.1430397,\n                    2.6666558,\n                    8.699481,\n                    12.510508,\n                    7.6093874,\n                    3.3150473,\n                    1.0373969,\n                    0.08225401,\n                ],\n                [\n                    7.372374,\n                    3.458156,\n                    6.5517087,\n                    10.535179,\n                    9.493686,\n                    5.800008,\n                    3.2196481,\n                    0.3790145,\n                ],\n                [\n                    9.979998,\n                    7.723156,\n                    11.384828,\n                    15.13672,\n                    14.71994,\n                    11.550301,\n                    8.666647,\n                    1.1556869,\n                ],\n                [\n                    7.4674473,\n                    7.990606,\n                    11.032139,\n                    10.031732,\n                    6.5969977,\n                    5.1203485,\n                    4.1267443,\n                    0.57233953,\n                ],\n                [\n                    1.9118737,\n                    10.9567,\n                    12.461995,\n                    10.991727,\n                    2.2403586,\n                    0.9002282,\n                    0.74645257,\n                    0.1000254,\n                ],\n            ]\n        ],\n        [\n            [\n                [0.0, 0.0, 0.0, 0.2796778, 1.6780672, 0.2796781, 0.0, 0.0],\n                [\n                    0.02485762,\n                    0.17400333,\n                    0.19886094,\n                    0.94998413,\n                    4.7056007,\n                    0.9251272,\n                    0.02485762,\n                    0.0,\n                ],\n                [\n                    0.54076296,\n                    2.3330488,\n                    4.2377095,\n                    13.100019,\n                    12.285746,\n                    4.7681584,\n                    1.6636131,\n                    0.18542966,\n                ],\n                [\n                    4.555413,\n                    9.538326,\n                    14.063398,\n                    17.882318,\n                    14.635002,\n                    5.9126663,\n                    2.6039343,\n                    0.3144545,\n    
            ],\n                [\n                    7.877132,\n                    19.767809,\n                    24.037426,\n                    15.584505,\n                    14.542083,\n                    4.4302306,\n                    2.3387682,\n                    0.3145125,\n                ],\n                [\n                    6.3498077,\n                    11.157468,\n                    14.465272,\n                    6.4254785,\n                    7.471047,\n                    7.448948,\n                    6.4972777,\n                    0.88493776,\n                ],\n                [\n                    2.473032,\n                    6.144208,\n                    9.52839,\n                    3.2779343,\n                    4.3061023,\n                    6.409383,\n                    5.87155,\n                    0.80066556,\n                ],\n                [\n                    1.4289956,\n                    4.5101476,\n                    7.2189507,\n                    2.2500885,\n                    2.8763475,\n                    0.45081174,\n                    0.0,\n                    0.0,\n                ],\n            ]\n        ],\n    ]\n)\n\n\ndef bilinear_interpolate(data, y, x, snap_border=False):\n    height, width = data.shape\n\n    if snap_border:\n        if -1 < y <= 0:\n            y = 0\n        elif height - 1 <= y < height:\n            y = height - 1\n\n        if -1 < x <= 0:\n            x = 0\n        elif width - 1 <= x < width:\n            x = width - 1\n\n    y_low = int(math.floor(y))\n    x_low = int(math.floor(x))\n    y_high = y_low + 1\n    x_high = x_low + 1\n\n    wy_h = y - y_low\n    wx_h = x - x_low\n    wy_l = 1 - wy_h\n    wx_l = 1 - wx_h\n\n    val = 0\n    for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):\n        for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):\n            if 0 <= yp < height and 0 <= xp < width:\n                val += wx * wy * data[yp, xp]\n    return val\n\n\ndef roi_align_np(\n    in_data,\n    rois,\n    pool_h,\n    pool_w,\n    spatial_scale=1,\n    sampling_ratio=-1,\n    aligned=False,\n    dtype=np.float32,\n):\n    n_channels = in_data.shape[1]\n    out_data = np.zeros((rois.shape[0], n_channels, pool_h, pool_w), dtype=dtype)\n\n    offset = 0.5 if aligned else 0.0\n\n    for r, roi in enumerate(rois):\n        batch_idx = int(roi[0])\n        j_begin, i_begin, j_end, i_end = (\n            x.item() * spatial_scale - offset for x in roi[1:]\n        )\n\n        roi_h = i_end - i_begin\n        roi_w = j_end - j_begin\n        bin_h = roi_h / pool_h\n        bin_w = roi_w / pool_w\n\n        for i in range(0, pool_h):\n            start_h = i_begin + i * bin_h\n            grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))\n            for j in range(0, pool_w):\n                start_w = j_begin + j * bin_w\n                grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))\n\n                for channel in range(0, n_channels):\n\n                    val = 0\n                    for iy in range(0, grid_h):\n                        y = start_h + (iy + 0.5) * bin_h / grid_h\n                        for ix in range(0, grid_w):\n                            x = start_w + (ix + 0.5) * bin_w / grid_w\n                            val += bilinear_interpolate(\n                                in_data[batch_idx, channel, :, :],\n                                y,\n                                x,\n                                
snap_border=True,\n                            )\n                    val /= grid_h * grid_w\n\n                    out_data[r, channel, i, j] = val\n    return out_data\n\n\ndef _test_roi_align(test_case, device):\n    input = flow.tensor(\n        np.random.randn(2, 3, 64, 64), dtype=flow.float32, device=flow.device(device)\n    )\n\n    random_img_idx = np.random.randint(low=0, high=2, size=(200, 1))\n    random_box_idx = np.random.uniform(low=0, high=64 * 64, size=(200, 2)).astype(\n        np.float32\n    )\n\n    def get_h_w(idx1, idx2):\n        if idx1 > idx2:\n            idx1, idx2 = idx2, idx1\n        h1 = idx1 // 64\n        w1 = idx1 % 64\n        h2 = idx2 // 64\n        w2 = idx2 % 64\n        return [x / 2 for x in [h1, w1, h2, w2]]\n\n    zipped = zip(random_box_idx[:, 0], random_box_idx[:, 1])\n    concated = [get_h_w(idx1, idx2) for (idx1, idx2) in zipped]\n    concated = np.array(concated)\n    rois = flow.tensor(\n        np.hstack((random_img_idx, concated)),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n\n    of_out = flow.roi_align(input, rois, 2.0, 14, 14, 2, True)\n    np_out = roi_align_np(input.numpy(), rois.numpy(), 14, 14, 2.0, 2, True)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, rtol=1e-4, atol=1e-4))\n\n\ndef _test_roi_align_backward(test_case, device):\n    input = flow.tensor(\n        input_np, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    rois = flow.tensor(rois_np, dtype=flow.float32, device=flow.device(device))\n    of_out = flow.roi_align(input, rois, 2.0, 5, 5, 2, True)\n    of_out.sum().backward()\n    test_case.assertTrue(\n        np.allclose(input.grad.numpy(), input_grad_np, rtol=1e-5, atol=1e-5)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRoIAlign(flow.unittest.TestCase):\n    def test_roi_align(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_roi_align, _test_roi_align_backward]\n        arg_dict[\"device\"] = [\"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_roll.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport torch\n\n\ndef _test_roll(test_case, device):\n    torch_x = torch.rand(\n        (2, 3, 5, 10, 20), device=device, dtype=torch.float32, requires_grad=True\n    )\n    torch_grad = torch.rand_like(torch_x, device=device)\n\n    shifts = (\n        np.random.randint(-100, 100),\n        np.random.randint(-100, 100),\n        np.random.randint(-100, 100),\n        np.random.randint(-100, 100),\n    )\n    dims = (0, 2, 3, 4)\n\n    torch_y = torch.roll(torch_x, shifts, dims)\n    torch_y.backward(torch_grad)\n\n    of_x = flow.tensor(\n        torch_x.detach().cpu().numpy(),\n        device=device,\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    of_y = flow.roll(of_x, shifts, dims)\n    of_grad = flow.tensor(torch_grad.cpu().numpy(), device=device, dtype=flow.float32)\n    of_y.backward(of_grad)\n\n    test_case.assertTrue(np.array_equal(of_y.numpy(), torch_y.detach().cpu().numpy()))\n    test_case.assertTrue(np.array_equal(of_x.grad.numpy(), torch_x.grad.cpu().numpy()))\n\n\ndef _test_roll_single_dims(test_case, device):\n    torch_x = torch.rand(\n        (2, 3, 5, 10, 20), device=device, dtype=torch.float32, requires_grad=True\n    )\n    torch_grad = torch.rand_like(torch_x, device=device)\n\n    shifts = np.random.randint(-100, 100)\n    dims = np.random.randint(0, 4)\n\n    torch_y = torch.roll(torch_x, shifts, dims)\n    torch_y.backward(torch_grad)\n\n    of_x = flow.tensor(\n        torch_x.detach().cpu().numpy(),\n        device=device,\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    of_y = flow.roll(of_x, shifts, dims)\n    of_grad = flow.tensor(torch_grad.cpu().numpy(), device=device, dtype=flow.float32)\n    of_y.backward(of_grad)\n\n    test_case.assertTrue(np.array_equal(of_y.numpy(), torch_y.detach().cpu().numpy()))\n    test_case.assertTrue(np.array_equal(of_x.grad.numpy(), torch_x.grad.cpu().numpy()))\n\n\ndef _test_roll_none_dims(test_case, device):\n    torch_x = torch.rand(\n        (2, 3, 5, 10, 20), device=device, dtype=torch.float32, requires_grad=True\n    )\n    torch_grad = torch.rand_like(torch_x, device=device)\n\n    shifts = np.random.randint(-100, 100)\n    dims = None\n\n    torch_y = torch.roll(torch_x, shifts, dims)\n    torch_y.backward(torch_grad)\n\n    of_x = flow.tensor(\n        torch_x.detach().cpu().numpy(),\n        device=device,\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    of_y = flow.roll(of_x, shifts, dims)\n    of_grad = flow.tensor(torch_grad.cpu().numpy(), device=device, dtype=flow.float32)\n    of_y.backward(of_grad)\n\n    test_case.assertTrue(np.array_equal(of_y.numpy(), torch_y.detach().cpu().numpy()))\n    test_case.assertTrue(np.array_equal(of_x.grad.numpy(), 
torch_x.grad.cpu().numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRoll(flow.unittest.TestCase):\n    def test_roll_compare_with_torch(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_roll,\n            _test_roll_single_dims,\n            _test_roll_none_dims,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_round.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow.test_utils.test_util import GenArgList\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestRound(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_flow_round_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.round(x)\n        return y\n\n    @autotest(check_graph=True)\n    def test_flow_round_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.round(x)\n        return y\n\n    @autotest(check_graph=True)\n    def test_flow_round_half_to_even(test_case):\n        device = random_device()\n        random_shape = [random(1, 10).to(int).value() for _ in range(4)]\n        random_tenosr = np.random.randint(-99999, 99999, size=random_shape)\n        x = torch.tensor(random_tenosr).to(device)\n        y = torch.full(x.shape, 0.5).to(device)\n        y += x\n        y = y.requires_grad_()\n        z = torch.round(y)\n        return z\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_rrelu.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch as torch_original\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef do_test_rrelu_same_bound(test_case, shape, device, dtype):\n    np_x = np.random.randn(*shape).astype(dtype)\n    flow.manual_seed(233)\n    torch_original.manual_seed(233)\n\n    flow_tensor = flow.tensor(np_x, requires_grad=True, device=device)\n    torch_tensor = torch_original.tensor(np_x, requires_grad=True, device=device)\n\n    rate = np.random.randn()\n    flow_rrelu = flow.nn.RReLU(lower=rate, upper=rate)\n    torch_rrelu = torch_original.nn.RReLU(lower=rate, upper=rate)\n    flow_out = flow_rrelu(flow_tensor)\n    torch_out = torch_rrelu(torch_tensor)\n\n    test_case.assertTrue(\n        np.allclose(\n            flow_out.cpu().detach().numpy(),\n            torch_out.cpu().detach().numpy(),\n            atol=1e-5,\n            rtol=1e-5,\n        )\n    )\n    flow_out.sum().backward()\n    torch_out.sum().backward()\n    test_case.assertTrue(\n        np.allclose(\n            flow_tensor.grad.cpu().detach().numpy(),\n            torch_tensor.grad.cpu().detach().numpy(),\n            atol=1e-5,\n            rtol=1e-5,\n        )\n    )\n\n\ndef do_test_rrelu_different_bound(test_case, shape, device, dtype):\n    np_x = np.random.randn(*shape).astype(dtype)\n    flow_tensor = flow.tensor(np_x, requires_grad=True, device=device)\n    rate = np.random.randn()\n    flow_rrelu = flow.nn.RReLU(lower=rate, upper=rate + 0.5)\n    flow_out = flow_rrelu(flow_tensor)\n    flow_out.sum().backward()\n    flow_grad = flow_tensor.grad\n    flow_div = flow_out / flow_tensor\n    test_case.assertTrue(\n        np.allclose(\n            (flow.where(flow_tensor >= 0, 1, 0)).cpu().detach().numpy(),\n            (flow.where(flow_div == 1.0, 1, 0)).cpu().detach().numpy(),\n            rtol=1e-4,\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            (flow.where(flow_tensor < 0, 1, 0)).cpu().detach().numpy(),\n            (\n                flow.where(\n                    flow.logical_and(\n                        flow.logical_and(flow_div >= rate, flow_div <= (rate + 0.5)),\n                        flow_tensor < 0,\n                    ),\n                    1,\n                    0,\n                )\n            )\n            .cpu()\n            .detach()\n            .numpy(),\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            flow_grad.cpu().detach().numpy(),\n            flow_div.cpu().detach().numpy(),\n            rtol=1e-1,\n            atol=1e-4,\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModule(flow.unittest.TestCase):\n    @unittest.skip(\"skip for now, becase it failed 4 times in past week\")\n    def 
def test_numpy_case(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            do_test_rrelu_same_bound,\n            do_test_rrelu_different_bound,\n        ]\n        arg_dict[\"shape\"] = [\n            [20],\n            [12, 32],\n            [4, 47, 156],\n            [5, 33, 65],\n            [3, 132, 94],\n            [9, 256, 63],\n        ]\n        # NOTE(hujiakui): in PyTorch < 1.13, the CUDA RReLU backward function is wrong.\n        # Compare the first two version components as a tuple; float(__version__[:4])\n        # breaks on versions like \"1.9.0\" or \"2.0.1\".\n        if tuple(int(v) for v in torch_original.__version__.split(\".\")[:2]) < (1, 13):\n            arg_dict[\"device\"] = [\"cpu\"]\n        else:\n            arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"dtype\"] = [np.float32, np.float64]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5)\n    def test_functional_rrelu(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device)\n        lower = np.abs(\n            np.random.randn()\n        )  # a negative lower bound would trigger the unsupported in-place leaky-relu backward, so keep it non-negative\n        return torch.nn.functional.rrelu(\n            x, lower=lower, upper=lower + 0.5, inplace=random_bool(), training=False,\n        )\n\n    @autotest(n=5)\n    @unittest.skipIf(\n        tuple(int(v) for v in torch_original.__version__.split(\".\")[:2]) < (1, 13)\n        and not os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"),\n        f\"RReLU CUDA test needs PyTorch >= 1.13, got {torch_original.__version__}\",\n    )\n    def test_rrelu_train(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device)\n        lower = np.abs(np.random.randn())\n        m = torch.nn.RReLU(lower=lower, upper=lower, inplace=random_bool())\n        return m(x)\n\n    @autotest(n=5, check_graph=False)\n    def test_rrelu_eval(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(), dim0=random(1, 8)).to(device)\n        lower = np.abs(np.random.randn())\n        m = torch.nn.RReLU(lower=lower, upper=lower, inplace=random_bool()).eval()\n        return m(x)\n\n    @profile(torch.nn.functional.rrelu)\n    def profile_rrelu(test_case):\n        lower = np.random.randn()\n        torch.nn.functional.rrelu(\n            torch.ones(1, 128, 28, 28),\n            lower=lower,\n            upper=lower + 0.5,\n            inplace=False,\n            training=True,\n        )\n        torch.nn.functional.rrelu(\n            torch.ones(1, 128, 28, 28),\n            lower=lower,\n            upper=lower + 0.5,\n            inplace=True,\n            training=True,\n        )\n\n        torch.nn.functional.rrelu(\n            torch.ones(16, 128, 28, 28),\n            lower=lower,\n            upper=lower + 0.5,\n            inplace=False,\n            training=True,\n        )\n\n        torch.nn.functional.rrelu(\n            torch.ones(16, 128, 28, 28),\n            lower=lower,\n            upper=lower + 0.5,\n            inplace=True,\n            training=True,\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_save_load.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport warnings\nimport tempfile\nimport unittest\nfrom pathlib import Path\nimport io\n\nimport numpy as np\nimport torch\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\n\n\nclass CustomModuleForSaveLoad(flow.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.param = flow.nn.Parameter(flow.randn(1, 3, 3, 3))\n\n    def forward(self, x):\n        return self.param + x\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestSaveLoad(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_load_map_location(test_case):\n        x = flow.ones(1, 2, 3)\n        y = flow.ones(2, 3, 4)\n        with tempfile.NamedTemporaryFile() as f:\n            flow.save({\"x\": x, \"y\": y}, f.name)\n            loaded = flow.load(f.name, map_location=\"cuda\")\n        assert np.array_equal(loaded[\"x\"].numpy(), x.numpy())\n        assert loaded[\"x\"].device == flow.device(\"cuda\")\n        assert np.array_equal(loaded[\"y\"].numpy(), y.numpy())\n        assert loaded[\"y\"].device == flow.device(\"cuda\")\n\n        with tempfile.NamedTemporaryFile() as f:\n            flow.save({\"x\": x, \"y\": y}, f.name)\n            loaded = flow.load(f.name, map_location=\"cpu\")\n        assert np.array_equal(loaded[\"x\"].numpy(), x.numpy())\n        assert loaded[\"x\"].device == flow.device(\"cpu\")\n        assert np.array_equal(loaded[\"y\"].numpy(), y.numpy())\n        assert loaded[\"y\"].device == flow.device(\"cpu\")\n\n        x = x.to_global(sbp=flow.sbp.broadcast, placement=flow.placement(\"cuda\", [0]))\n        y = y.to_global(sbp=flow.sbp.broadcast, placement=flow.placement(\"cuda\", [0]))\n\n        with tempfile.NamedTemporaryFile() as f:\n            flow.save({\"x\": x, \"y\": y}, f.name, global_dst_rank=0)\n            loaded = flow.load(\n                f.name, global_src_rank=0, map_location=flow.placement(\"cuda\", [0])\n            )\n        assert np.array_equal(loaded[\"x\"].numpy(), x.numpy())\n        assert loaded[\"x\"].placement == flow.placement(\"cuda\", [0])\n        assert np.array_equal(loaded[\"y\"].numpy(), y.numpy())\n        assert loaded[\"y\"].placement == flow.placement(\"cuda\", [0])\n\n        with tempfile.NamedTemporaryFile() as f:\n            flow.save({\"x\": x, \"y\": y}, f.name, global_dst_rank=0)\n            loaded = flow.load(\n                f.name, global_src_rank=0, map_location=flow.placement(\"cpu\", [0])\n            )\n        assert np.array_equal(loaded[\"x\"].numpy(), x.numpy())\n        assert loaded[\"y\"].placement == flow.placement(\"cpu\", [0])\n        assert np.array_equal(loaded[\"y\"].numpy(), y.numpy())\n        assert loaded[\"y\"].placement == flow.placement(\"cpu\", [0])\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_save_dir(test_case):\n        m1 = CustomModuleForSaveLoad()\n 
       with tempfile.TemporaryDirectory() as save_dir:\n            flow.save(m1.state_dict(), save_dir, save_as_external_data=True)\n            loaded_state_dict = flow.load(save_dir)\n        m2 = CustomModuleForSaveLoad()\n        m2.load_state_dict(loaded_state_dict)\n        test_case.assertTrue(np.array_equal(m1.param.numpy(), m2.param.numpy()))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_save_dir_fault_tolerance(test_case):\n        m1 = CustomModuleForSaveLoad()\n        with tempfile.TemporaryDirectory() as save_dir:\n            flow.save(m1.state_dict(), save_dir, save_as_external_data=True)\n            with open(os.path.join(save_dir, \"random_file\"), \"w\") as fp:\n                fp.write(\"nothing\")\n            with warnings.catch_warnings():\n                warnings.simplefilter(\"ignore\")\n                loaded_state_dict = flow.load(save_dir)\n        m2 = CustomModuleForSaveLoad()\n        m2.load_state_dict(loaded_state_dict)\n        test_case.assertTrue(np.array_equal(m1.param.numpy(), m2.param.numpy()))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_save_state_dict(test_case):\n        class CustomModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.param1 = flow.nn.Parameter(flow.Tensor(32, 1024, 1024))\n                self.param2 = flow.nn.Parameter(flow.Tensor(32, 1024, 1024))\n\n            def forward(self):\n                return self.param1 + self.param2\n\n        m = CustomModule()\n        res1 = m()\n        state_dict = m.state_dict()\n        with tempfile.NamedTemporaryFile() as f:\n            flow.save(state_dict, f.name)\n            test_case.assertTrue(os.path.exists(f.name))\n            loaded_state_dict = flow.load(f.name)\n            m.load_state_dict(loaded_state_dict)\n        res2 = m()\n        test_case.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_save_state_dict_bytes(test_case):\n        class CustomModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.param1 = flow.nn.Parameter(flow.Tensor(32, 1024, 1024))\n                self.param2 = flow.nn.Parameter(flow.Tensor(32, 1024, 1024))\n\n            def forward(self):\n                return self.param1 + self.param2\n\n        m = CustomModule()\n        res1 = m()\n        state_dict = m.state_dict()\n        with tempfile.NamedTemporaryFile() as path:\n            buffer = io.BytesIO()\n            flow.save(state_dict, buffer)\n            with open(path.name, \"wb\") as f:\n                f.write(buffer.getvalue())\n            test_case.assertTrue(os.path.exists(path.name))\n            loaded_state_dict = flow.load(path.name)\n            m.load_state_dict(loaded_state_dict)\n        res2 = m()\n        test_case.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))\n\n    def _test_save_and_load_global_from_nested_dict(test_case):\n        class CustomModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.param = flow.nn.Parameter(flow.randn(3, 32, 3, 3))\n\n            def forward(self):\n                return self.param\n\n        m1 = CustomModule()\n        m1 = m1.to_global(\n            flow.placement(\"cuda\", range(1, 3)), flow.sbp.broadcast\n        ).to_global(sbp=flow.sbp.split(1))\n        m2 = CustomModule()\n        m2 = m2.to_global(flow.placement(\"cuda\", range(1, 3)), 
flow.sbp.broadcast)\n        res1 = m1() + m2()\n        state_dict1 = m1.state_dict()\n        state_dict2 = m2.state_dict()\n        state_dict = {\"m1\": state_dict1, \"m2\": state_dict2}\n\n        with tempfile.TemporaryDirectory() as dir:\n            filename = os.path.join(dir, \"tmp\")\n            with test_case.assertRaises(Exception):\n                flow.save(state_dict, filename)\n\n            global_src_dst_rank = 0\n            flow.save(state_dict, filename, global_dst_rank=global_src_dst_rank)\n            rank = flow.env.get_rank()\n            if rank != global_src_dst_rank:\n                test_case.assertFalse(os.path.exists(filename))\n\n            m1 = CustomModule()\n            m1 = m1.to_global(\n                flow.placement(\"cuda\", [[0, 1], [2, 3]]),\n                [flow.sbp.broadcast, flow.sbp.broadcast],\n            ).to_global(sbp=[flow.sbp.split(1), flow.sbp.broadcast])\n            m2 = CustomModule()\n            m2 = m2.to_global(\n                flow.placement(\"cuda\", [[0, 1], [2, 3]]),\n                [flow.sbp.broadcast, flow.sbp.broadcast],\n            ).to_global(sbp=[flow.sbp.broadcast, flow.sbp.split(1)])\n\n            with test_case.assertRaises(Exception):\n                loaded_state_dict = flow.load(filename)\n                m1.load_state_dict(loaded_state_dict[\"m1\"])\n\n            loaded_state_dict = flow.load(filename, global_src_rank=global_src_dst_rank)\n            test_case.assertEqual(len(loaded_state_dict), 2)\n            m1.load_state_dict(loaded_state_dict[\"m1\"])\n            m2.load_state_dict(loaded_state_dict[\"m2\"])\n            res2 = m1() + m2()\n\n        test_case.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_save_and_load_global_from_nested_dict_1n4d(test_case):\n        test_case._test_save_and_load_global_from_nested_dict()\n\n    @flow.unittest.skip_unless_2n2d()\n    def test_save_and_load_global_from_nested_dict_2n2d(test_case):\n        test_case._test_save_and_load_global_from_nested_dict()\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_load_pytorch_weights(test_case):\n        for device in [\"cpu\", \"cuda\"]:\n            for map_location in [None, flow.device(\"cuda:0\")]:\n                conv_torch = torch.nn.Conv2d(3, 3, 3).to(device)\n\n                conv_flow1 = flow.nn.Conv2d(3, 3, 3).to(device)\n                with tempfile.NamedTemporaryFile() as f:\n                    torch.save(conv_torch.state_dict(), f.name)\n                    conv_flow1.load_state_dict(\n                        flow.load(f.name, map_location=map_location)\n                    )\n                test_case.assertTrue(\n                    np.array_equal(\n                        conv_torch.weight.detach().cpu().numpy(),\n                        conv_flow1.weight.numpy(),\n                    )\n                )\n\n                conv_flow2 = flow.nn.Conv2d(3, 3, 3).to(device)\n                with tempfile.NamedTemporaryFile() as f:\n                    torch.save({\"weights\": conv_torch.state_dict()}, f.name)\n                    conv_flow2.load_state_dict(\n                        flow.load(f.name, map_location=map_location)[\"weights\"]\n                    )\n                test_case.assertTrue(\n                    np.array_equal(\n                        conv_torch.weight.detach().cpu().numpy(),\n                        conv_flow2.weight.numpy(),\n                    )\n                )\n\n    
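# [added] Illustrative sketch (hypothetical helper, never collected by the test\n    # runner): the minimal torch -> oneflow checkpoint hand-off exercised above,\n    # assuming flow.load can read files produced by torch.save.\n    def _example_load_torch_checkpoint(test_case):\n        conv_torch = torch.nn.Conv2d(3, 3, 3)\n        conv_flow = flow.nn.Conv2d(3, 3, 3)\n        with tempfile.NamedTemporaryFile() as f:\n            torch.save(conv_torch.state_dict(), f.name)\n            conv_flow.load_state_dict(flow.load(f.name))\n\n    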
@flow.unittest.skip_unless_1n2d()\n    def test_load_pytorch_weights_global(test_case):\n        for device in [\"cpu\", \"cuda\"]:\n            for map_location in [None, flow.placement.all(\"cuda\")]:\n                conv_torch = torch.nn.Conv2d(3, 3, 3).to(device)\n\n                all_placement = flow.placement.all(device)\n                conv_flow1 = flow.nn.Conv2d(3, 3, 3).to_global(\n                    all_placement, flow.sbp.broadcast\n                )\n                with tempfile.NamedTemporaryFile() as f:\n                    if flow.env.get_rank() == 0:\n                        torch.save(conv_torch.state_dict(), f.name)\n                    conv_flow1.load_state_dict(\n                        flow.load(f.name, map_location=map_location, global_src_rank=0)\n                    )\n                if flow.env.get_rank() == 0:\n                    test_case.assertTrue(\n                        np.array_equal(\n                            conv_torch.weight.detach().cpu().numpy(),\n                            conv_flow1.weight.numpy(),\n                        )\n                    )\n\n                conv_flow2 = flow.nn.Conv2d(3, 3, 3).to_global(\n                    all_placement, flow.sbp.broadcast\n                )\n                with tempfile.NamedTemporaryFile() as f:\n                    if flow.env.get_rank() == 0:\n                        torch.save({\"weights\": conv_torch.state_dict()}, f.name)\n                    conv_flow2.load_state_dict(\n                        flow.load(f.name, map_location=map_location, global_src_rank=0)[\n                            \"weights\"\n                        ]\n                    )\n                if flow.env.get_rank() == 0:\n                    test_case.assertTrue(\n                        np.array_equal(\n                            conv_torch.weight.detach().cpu().numpy(),\n                            conv_flow2.weight.numpy(),\n                        )\n                    )\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_save_load_module_directly(test_case):\n        x = flow.randn(1, 3, 3, 3)\n\n        m = CustomModuleForSaveLoad()\n\n        with tempfile.NamedTemporaryFile() as f:\n            flow.save(m, f.name)\n            new_m = flow.load(f.name)\n            res = m(x)\n            new_res = new_m(x)\n            test_case.assertTrue(np.array_equal(res.numpy(), new_res.numpy()))\n\n        m = flow.nn.parallel.DistributedDataParallel(m)\n        test_case.assertTrue(m._is_ddp_module)\n\n        with tempfile.NamedTemporaryFile() as f:\n            flow.save(m, f.name)\n            new_m = flow.load(f.name)\n            test_case.assertTrue(new_m._is_ddp_module)\n            res = m(x)\n            new_res = new_m(x)\n            test_case.assertTrue(np.array_equal(res.numpy(), new_res.numpy()))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_save_load_module_directly_save_bytes(test_case):\n        x = flow.randn(1, 3, 3, 3)\n\n        m = CustomModuleForSaveLoad()\n\n        with tempfile.NamedTemporaryFile() as path:\n            buffer = io.BytesIO()\n            flow.save(m, buffer)\n            with open(path.name, \"wb\") as f:\n                f.write(buffer.getvalue())\n            new_m = flow.load(path.name)\n            res = m(x)\n            new_res = new_m(x)\n            test_case.assertTrue(np.array_equal(res.numpy(), new_res.numpy()))\n\n        m = flow.nn.parallel.DistributedDataParallel(m)\n        test_case.assertTrue(m._is_ddp_module)\n\n        with 
tempfile.NamedTemporaryFile() as path:\n            buffer = io.BytesIO()\n            flow.save(m, buffer)\n            with open(path.name, \"wb\") as f:\n                f.write(buffer.getvalue())\n            new_m = flow.load(path.name)\n            test_case.assertTrue(new_m._is_ddp_module)\n            res = m(x)\n            new_res = new_m(x)\n            test_case.assertTrue(np.array_equal(res.numpy(), new_res.numpy()))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_save_load_module_directly_load_filestream(test_case):\n        x = flow.randn(1, 3, 3, 3)\n\n        m = CustomModuleForSaveLoad()\n\n        with tempfile.NamedTemporaryFile() as f:\n            flow.save(m, f.name)\n            with open(f.name, \"rb\") as r:\n                new_m = flow.load(r)\n            res = m(x)\n            new_res = new_m(x)\n            test_case.assertTrue(np.array_equal(res.numpy(), new_res.numpy()))\n\n        m = flow.nn.parallel.DistributedDataParallel(m)\n        test_case.assertTrue(m._is_ddp_module)\n\n        with tempfile.NamedTemporaryFile() as f:\n            flow.save(m, f.name)\n            with open(f.name, \"rb\") as r:\n                new_m = flow.load(r)\n            test_case.assertTrue(new_m._is_ddp_module)\n            res = m(x)\n            new_res = new_m(x)\n            test_case.assertTrue(np.array_equal(res.numpy(), new_res.numpy()))\n\n    def test_load_old_dir_data(test_case):\n        test_data_dir = Path(__file__).parent / \"save_load_test_data\"\n        m1 = nn.Conv2d(3, 3, 3)\n        params = flow.load(test_data_dir / \"3x3_i3o3_conv2d_params\")\n        m1.load_state_dict(params)\n\n        m2 = flow.load(test_data_dir / \"3x3_i3o3_conv2d\")\n\n        x = flow.randn(1, 3, 3, 3)\n        y1 = m1(x)\n        y2 = m2(x)\n        test_case.assertTrue(np.array_equal(y1.numpy(), y2.numpy()))\n\n    def test_pytorch_non_tensor(test_case):\n        with tempfile.NamedTemporaryFile() as f:\n            torch.save({\"a\": 2}, f.name)\n            res = flow.load(f.name, map_location=\"cpu\")\n        test_case.assertTrue(isinstance(res, dict))\n        test_case.assertEqual(len(res), 1)\n        test_case.assertEqual(res[\"a\"], 2)\n\n    def test_pytorch_non_tensor_load_filestream(test_case):\n        with tempfile.NamedTemporaryFile() as f:\n            torch.save({\"a\": 2}, f.name)\n            with open(f.name, \"rb\") as r:\n                res = flow.load(r, map_location=\"cpu\")\n        test_case.assertTrue(isinstance(res, dict))\n        test_case.assertEqual(len(res), 1)\n        test_case.assertEqual(res[\"a\"], 2)\n\n    def test_pytorch_non_tensor_save_bytes(test_case):\n        with tempfile.NamedTemporaryFile() as path:\n            buffer = io.BytesIO()\n            torch.save({\"a\": 2}, buffer)\n            with open(path.name, \"wb\") as f:\n                f.write(buffer.getvalue())\n            res = flow.load(path.name, map_location=\"cpu\")\n        test_case.assertTrue(isinstance(res, dict))\n        test_case.assertEqual(len(res), 1)\n        test_case.assertEqual(res[\"a\"], 2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_saved_tensor_hooks.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestSavedTensorHooks(flow.unittest.TestCase):\n    def test_normal_saved_tensor_hooks(test_case):\n        x = flow.ones(1, 2, 3).to(\"cuda\").requires_grad_()\n        y = flow.zeros(1, 2, 3).to(\"cuda\").requires_grad_()\n        tensor_list = []\n\n        def pack(x):\n            tensor_list.append(x)\n            return len(tensor_list) - 1\n\n        def unpack(x):\n            return tensor_list[x]\n\n        with flow.autograd.graph.saved_tensors_hooks(pack, unpack):\n            z = x * y\n        z.sum().backward()\n        test_case.assertEqual(len(tensor_list), 2)\n        test_case.assertTrue(np.array_equal(tensor_list[0], y))\n        test_case.assertTrue(np.array_equal(tensor_list[1], x))\n        test_case.assertTrue(np.allclose(x.grad, y))\n        test_case.assertTrue(np.allclose(y.grad, x))\n\n    def test_saved_tensor_hooks_in_autograd_function(test_case):\n        x = flow.ones(1, 2, 3).to(\"cuda\").requires_grad_()\n        y = flow.zeros(1, 2, 3).to(\"cuda\").requires_grad_()\n        tensor_list = []\n\n        def pack(x):\n            tensor_list.append(x)\n            return len(tensor_list) - 1\n\n        def unpack(x):\n            return tensor_list[x]\n\n        class MulFunction(flow.autograd.Function):\n            @staticmethod\n            def forward(ctx, x, y):\n                ctx.save_for_backward(x, y)\n                return x * y\n\n            @staticmethod\n            def backward(ctx, dz):\n                x, y = ctx.saved_tensors\n                dx = dz * y\n                dy = dz * x\n                return dx, dy\n\n        with flow.autograd.graph.saved_tensors_hooks(pack, unpack):\n            z = MulFunction.apply(x, y)\n        z.sum().backward()\n        test_case.assertEqual(len(tensor_list), 2)\n        test_case.assertTrue(np.array_equal(tensor_list[0], x))\n        test_case.assertTrue(np.array_equal(tensor_list[1], y))\n        test_case.assertTrue(np.allclose(x.grad, y))\n        test_case.assertTrue(np.allclose(y.grad, x))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_sbp_symbol.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSBPSymbol(flow.unittest.TestCase):\n    def test_sbp_symbol(test_case):\n        test_case.assertTrue(flow.sbp.split(0) == flow.sbp.split(0)())\n        test_case.assertTrue(flow.sbp.split(1) == flow.sbp.split(1)())\n        test_case.assertTrue(flow.sbp.split(0) != flow.sbp.split(1))\n        test_case.assertTrue(flow.sbp.broadcast == flow.sbp.broadcast())\n        test_case.assertTrue(flow.sbp.partial_sum == flow.sbp.partial_sum())\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_scatter_nd.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_scatter_nd(test_case, device):\n    indices = flow.tensor(\n        np.array([[1], [6], [4]]), dtype=flow.int, device=flow.device(device)\n    )\n    update = flow.tensor(\n        np.array([10.2, 5.1, 12.7]), dtype=flow.float, device=flow.device(device)\n    )\n    np_out = np.array([0.0, 10.2, 0.0, 0.0, 12.7, 0.0, 5.1, 0.0])\n    output = flow.scatter_nd(indices, update, [8])\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 0.0001, 0.0001))\n\n\ndef _test_scatter_nd_t(test_case, device):\n    indices = flow.tensor(\n        np.array([[0], [4], [2]]), dtype=flow.int, device=flow.device(device)\n    )\n    update = flow.tensor(\n        np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]]),\n        dtype=flow.float,\n        device=flow.device(device),\n    )\n    np_out = np.array(\n        [\n            [1.0, 1.0, 1.0],\n            [0.0, 0.0, 0.0],\n            [3.0, 3.0, 3.0],\n            [0.0, 0.0, 0.0],\n            [2.0, 2.0, 2.0],\n        ]\n    )\n    output = flow.scatter_nd(indices, update, [5, 3])\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 0.0001, 0.0001))\n\n\ndef _test_scatter_nd_backward(test_case, device):\n    indices = flow.tensor(\n        np.array([[1], [6], [4]]), dtype=flow.int, device=flow.device(device)\n    )\n    of_update = flow.tensor(\n        np.array([10.2, 5.1, 12.7]),\n        requires_grad=True,\n        dtype=flow.float,\n        device=flow.device(device),\n    )\n    np_out = np.array([0.0, 10.2, 0.0, 0.0, 12.7, 0.0, 5.1, 0.0])\n    np_grad = np.array([1.0, 1.0, 1.0])\n    output = flow.scatter_nd(indices, of_update, [8])\n    out_sum = output.sum()\n    out_sum.backward()\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 0.0001, 0.0001))\n    test_case.assertTrue(np.array_equal(of_update.grad.numpy(), np_grad))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestScatter_nd(flow.unittest.TestCase):\n    def test_scatter_nd(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_scatter_nd,\n            _test_scatter_nd_t,\n            _test_scatter_nd_backward,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_scatter_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _get_indexes(device):\n    return (\n        constant(\n            torch.tensor(np.array([[0, 1], [1, 0]]), dtype=torch.int64, device=device)\n        ),\n        constant(\n            torch.tensor(np.array([[1, 0], [0, 1]]), dtype=torch.int64, device=device)\n        ),\n        constant(\n            torch.tensor(np.array([[1, 0], [1, 0]]), dtype=torch.int64, device=device)\n        ),\n        constant(\n            torch.tensor(np.array([[0, 1], [0, 1]]), dtype=torch.int64, device=device)\n        ),\n    )\n\n\ndef _test_scatter(test_case, test_scalar: bool, dim: int):\n    device = random_device()\n    input = random_tensor(ndim=2, dim0=2, dim1=2).to(device)\n    src = 3.14 if test_scalar else random_tensor(ndim=2, dim0=2, dim1=2).to(device)\n    y = torch.scatter(input, dim, oneof(*_get_indexes(device)), src)\n    return y\n\n\ndef _test_scatter_add(test_case, dim: int):\n    device = random_device()\n    input = random_tensor(ndim=2, dim0=2, dim1=2).to(device)\n    src = random_tensor(ndim=2, dim0=2, dim1=2).to(device)\n    y = torch.scatter_add(input, dim, oneof(*_get_indexes(device)), src)\n    return y\n\n\ndef _test_scatter_reduce(test_case, dim: int):\n    device = random_device()\n    input = random_tensor(ndim=2, dim0=2, dim1=2).to(device)\n    src = random_tensor(ndim=2, dim0=2, dim1=2).to(device)\n    y = torch.scatter(\n        input,\n        dim,\n        oneof(*_get_indexes(device)),\n        src,\n        reduce=oneof(\"add\", \"multiply\", nothing()),\n    )\n    return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestScatterOpsModule(flow.unittest.TestCase):\n    @autotest(n=10)\n    def test_scatter_with_random_data(test_case):\n        return _test_scatter(test_case, oneof(True, False), oneof(0, 1, -1))\n\n    @autotest(n=5)\n    def test_scatter_add_with_random_data(test_case):\n        return _test_scatter_add(test_case, oneof(0, 1))\n\n    @autotest(\n        n=5, auto_backward=False\n    )  # peihong: pytorch dose not support backward when reduce is add or multiply\n    def test_scatter_reduce_with_random_data(test_case):\n        return _test_scatter_reduce(test_case, oneof(0, 1))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_searchsorted.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util.torch_flow_dual_object import autotest\n\n\ndef _test_search_sorted(test_case, input_dtype, device):\n    sorted_sequence = flow.tensor(\n        np.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]),\n        dtype=input_dtype,\n        device=flow.device(device),\n    )\n    values = flow.tensor(\n        np.array([[3, 6, 9], [3, 6, 9]]), dtype=input_dtype, device=flow.device(device)\n    )\n    gt = np.array([[1, 3, 4], [1, 2, 4]])\n    output = flow.searchsorted(sorted_sequence, values)\n    test_case.assertTrue(np.allclose(output.numpy(), gt, 0.0001, 0.0001))\n    test_case.assertTrue(output.dtype == flow.int64)\n\n\ndef _test_search_sorted_1(test_case, input_dtype, device):\n    sorted_sequence = flow.tensor(\n        np.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]),\n        dtype=input_dtype,\n        device=flow.device(device),\n    )\n    values = flow.tensor(\n        np.array([[3, 6, 9], [3, 6, 9]]), dtype=input_dtype, device=flow.device(device)\n    )\n    gt = np.array([[2, 3, 5], [1, 3, 4]])\n    output = flow.searchsorted(sorted_sequence, values, right=True, side=\"right\")\n    test_case.assertTrue(np.allclose(output.numpy(), gt, 0.0001, 0.0001))\n    test_case.assertTrue(output.dtype == flow.int64)\n\n\ndef _test_search_sorted_2(test_case, input_dtype, device):\n    sorted_sequence_1d = flow.tensor(\n        np.array([1, 3, 5, 7, 9]), dtype=input_dtype, device=flow.device(device)\n    )\n    values = flow.tensor(\n        np.array([3, 6, 9]), dtype=input_dtype, device=flow.device(device)\n    )\n    gt = np.array([1, 3, 4])\n    output = flow.searchsorted(sorted_sequence_1d, values)\n    test_case.assertTrue(np.allclose(output.numpy(), gt, 0.0001, 0.0001))\n    test_case.assertTrue(output.dtype == flow.int64)\n\n\ndef _test_search_sorted_3(test_case, input_dtype, device):\n    sorted_sequence = flow.tensor(\n        np.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]),\n        dtype=input_dtype,\n        device=flow.device(device),\n    )\n    values = flow.tensor(\n        np.array([[3, 6, 9], [3, 6, 9]]), dtype=input_dtype, device=flow.device(device)\n    )\n    gt = np.array([[1, 3, 4], [1, 2, 4]])\n    output = flow.searchsorted(sorted_sequence, values, out_int32=True)\n    test_case.assertTrue(np.allclose(output.numpy(), gt, 0.0001, 0.0001))\n    test_case.assertTrue(output.dtype == flow.int32)\n\n\ndef _test_search_sorted_4(test_case, input_dtype, device):\n    sorted_sequence = flow.tensor(\n        np.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]),\n        dtype=input_dtype,\n        device=flow.device(device),\n    )\n    values = flow.tensor(\n        np.array([[3, 6, 9], [3, 6, 9]]), 
\n    sorter = flow.tensor(\n        np.array([[4, 3, 2, 1, 0], [3, 2, 4, 0, 1]]),\n        dtype=flow.int64,\n        device=flow.device(device),\n    )\n    gt = np.array([[0, 5, 5], [0, 0, 2]])\n    output = flow.searchsorted(sorted_sequence, values, sorter=sorter)\n    test_case.assertTrue(np.allclose(output.numpy(), gt, 0.0001, 0.0001))\n    test_case.assertTrue(output.dtype == flow.int64)\n\n\ndef _test_search_sorted_5(test_case, input_dtype, device):\n    sorted_sequence_1d = flow.tensor(\n        np.array([1, 3, 5, 7, 9]), dtype=input_dtype, device=flow.device(device)\n    )\n    gt = np.array(2)\n    output = flow.searchsorted(sorted_sequence_1d, 5)\n    test_case.assertTrue(np.allclose(output.numpy(), gt, 0.0001, 0.0001))\n    test_case.assertTrue(output.dtype == flow.int64)\n\n\ndef _test_search_sorted_6(test_case, input_dtype, device):\n    sorted_sequence_1d = flow.tensor(\n        np.array([1, 3, 5, 7, 9]), dtype=input_dtype, device=flow.device(device)\n    )\n    gt = np.array(3)\n    output = flow.searchsorted(sorted_sequence_1d, 5, right=True, side=\"right\")\n    test_case.assertTrue(np.allclose(output.numpy(), gt, 0.0001, 0.0001))\n    test_case.assertTrue(output.dtype == flow.int64)\n\n\ndef _test_search_sorted_7(test_case, input_dtype, device):\n    sorted_sequence_1d = flow.tensor(\n        np.array([1, 3, 5, 7, 9]), dtype=input_dtype, device=flow.device(device)\n    )\n    gt = np.array(2)\n    output = flow.searchsorted(sorted_sequence_1d, 5, out_int32=True)\n    test_case.assertTrue(np.allclose(output.numpy(), gt, 0.0001, 0.0001))\n    test_case.assertTrue(output.dtype == flow.int32)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSearchSorted(flow.unittest.TestCase):\n    def test_search_sorted(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_search_sorted,\n            _test_search_sorted_1,\n            _test_search_sorted_2,\n            _test_search_sorted_3,\n            _test_search_sorted_4,\n            _test_search_sorted_5,\n            _test_search_sorted_6,\n            _test_search_sorted_7,\n        ]\n        arg_dict[\"input_dtype\"] = [\n            flow.int8,\n            flow.int32,\n            flow.int64,\n            flow.float,\n            flow.double,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=20, auto_backward=False, check_dtype=True)\n    def test_search_sorted_with_random_data(test_case):\n        device = random_device()\n        sorted_sequence = random_tensor(ndim=2, dim0=2, dim1=3).to(device)\n        values = random_tensor(ndim=2, dim0=2).to(device)\n        right = oneof(True, False)\n        y = torch.searchsorted(\n            sorted_sequence, values, out_int32=oneof(True, False), right=right,\n        )\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_select.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom random import shuffle\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.automated_test_util import util\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSelect(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_flow_select(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4,\n            dim0=random(3, 6),\n            dim1=random(3, 6),\n            dim2=random(3, 6),\n            dim3=random(3, 6),\n        ).to(device)\n        dim = random(-4, 3).to(int)\n        index = random(0, 2).to(int)\n        z = torch.select(x, dim, index)\n        return z\n\n    # TODO:(zhaoluyang) some bug in as_strided backward to be fixed\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_flow_select_with_stride(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4,\n            dim0=random(3, 6),\n            dim1=random(3, 6),\n            dim2=random(3, 6),\n            dim3=random(3, 6),\n        ).to(device)\n        dim = random(-4, 3).to(int)\n        index = random(0, 2).to(int)\n        perm = [0, 1, 2, 3]\n        shuffle(perm)\n        y = x.permute(perm)\n        z = torch.select(y, dim, index)\n        return z\n\n    @autotest(check_graph=True)\n    def test_flow_select_1dim(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=random(3, 6),).to(device)\n        index = random(0, 2).to(int)\n        z = torch.select(x, 0, index)\n        return z\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_shutting_down.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow\nimport os\n\n\nworld_size = os.getenv(\"WORLD_SIZE\")\n\n\nclass _TestCallWhenShuttingDown:\n    def __init__(self):\n        self.oneflow = oneflow\n        tensor = oneflow.ones((2, 2))\n        print(tensor)\n\n    def __del__(self, of=oneflow):\n        try:\n            if world_size == 1:\n                tensor = of.ones((2, 2))\n        except:\n            # Please refer to: https://github.com/Oneflow-Inc/OneTeam/issues/1219#issuecomment-1092370402\n            print(\"__del__ at shutting down phase in Python is not stable.\")\n\n\ntest_call_when_shutting_down = _TestCallWhenShuttingDown()\n\n\nclass _TestSyncWhenShuttingDown:\n    def __init__(self):\n        self.eager = oneflow._oneflow_internal.eager\n\n    def __del__(self):\n        try:\n            self.eager.Sync()\n        except:\n            # Please refer to: https://github.com/Oneflow-Inc/OneTeam/issues/1219#issuecomment-1092370402\n            print(\"__del__ at shutting down phase in Python is not stable.\")\n\n\ntest_sync_when_shutting_down = _TestSyncWhenShuttingDown()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_sign.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_sign_impl(test_case, shape, device):\n    np_input = np.random.randn(*shape)\n    of_input = flow.tensor(\n        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_out = flow.sign(of_input)\n    np_out = np.sign(np_input)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = np.zeros_like(np_input)\n    test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 0.0001, 0.0001))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSign(flow.unittest.TestCase):\n    def test_sign(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(2, 3), (2, 4, 5, 6)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_sign_impl(test_case, *arg)\n\n    @autotest(n=5)\n    def test_sign_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.sign(x)\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_sign_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 3, 0, 4).to(device)\n        y = torch.sign(x)\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_sign_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device=device, dtype=torch.bool)\n        y = torch.sign(x)\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_sign_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.sign(x)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_single_threaded_vm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport subprocess\nimport sys\nimport os\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestSingleThreadedVM(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n2d()\n    def test_ddp_in_single_threaded_vm(test_case):\n        # Environment variables of current process like ONEFLOW_TEST_DEVICE_NUM\n        # and environment variables about distributed training (i.e. MASTER_ADDR,\n        # MASTER_PORT, WORLD_SIZE, RANK) are all in `env`.\n        env = os.environ.copy()\n        env[\"ONEFLOW_VM_MULTI_THREAD\"] = \"0\"\n        p = subprocess.run(\n            [sys.executable, \"test_ddp.py\"],\n            cwd=os.path.dirname(os.path.realpath(__file__)),\n            env=env,\n        )\n        test_case.assertEqual(p.returncode, 0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_skip_layer_norm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport numpy as np\nimport unittest\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgList\n\nis_profiling = False\n\n\ndef compare_result(test_case, a, b, rtol=1e-5, atol=1e-8):\n    test_case.assertTrue(\n        np.allclose(a.numpy(), b.numpy(), rtol=rtol, atol=atol),\n        f\"\\na\\n{a.numpy()}\\n{'-' * 80}\\nb:\\n{b.numpy()}\\n{'*' * 80}\\ndiff:\\n{a.numpy() - b.numpy()}\",\n    )\n\n\nclass NaiveSkipLayerNorm(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(\n        self,\n        x: flow.Tensor,\n        gamma: flow.Tensor,\n        beta: flow.Tensor,\n        bias: flow.Tensor = None,\n        skip: flow.Tensor = None,\n        alpha: float = 1e-5,\n        eps: float = 1e-6,\n    ) -> flow.Tensor:\n        begin_norm_axis = len(x.shape) - 1\n        begin_params_axis = len(x.shape) - 1\n        if bias is not None:\n            x = flow._C.add(input=x, other=bias)\n        if skip is not None:\n            skip = skip * alpha\n            x = flow._C.add(input=x, other=skip)\n        return flow._C.layer_norm_affine(\n            x,\n            gamma,\n            beta,\n            begin_norm_axis=begin_norm_axis,\n            begin_params_axis=begin_params_axis,\n            epsilon=eps,\n        )\n\n\nclass FusedSkipLayerNorm(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(\n        self,\n        x: flow.Tensor,\n        gamma: flow.Tensor,\n        beta: flow.Tensor,\n        bias: flow.Tensor = None,\n        skip: flow.Tensor = None,\n        alpha: float = 1e-5,\n        eps: float = 1e-6,\n    ) -> flow.Tensor:\n        return flow._C.skip_layer_norm(\n            x=x, gamma=gamma, beta=beta, bias=bias, skip=skip, alpha=alpha, epsilon=eps\n        )\n\n\ndef _test_skip_layer_norm(\n    test_case,\n    x_shape,\n    has_gamma,\n    has_beta,\n    has_bias,\n    has_skip,\n    eps=1e-6,\n    alpha=1e-5,\n    dtype=flow.float32,\n):\n    print(\n        f\"x_shape: {x_shape}\\nhas_gamma: {has_gamma}\\nhas_beta: {has_beta}\\nhas_bias: {has_bias}\\nhas_skip: {has_skip}\\ndtype: {dtype}\\n\"\n    )\n\n    normalize_shape = list()\n    normalize_shape.append(x_shape[-1])\n\n    np_dtype = np.float16 if dtype is flow.float16 else np.float32\n\n    # generate np array\n    np_x = np.random.randn(*x_shape).astype(np_dtype)\n\n    naive_flow_gamma = None\n    fused_flow_gamma = None\n    if has_gamma:\n        np_gamma = np.random.randn(*normalize_shape).astype(np_dtype)\n        naive_flow_gamma = flow.tensor(np_gamma).to(device=\"cuda\", dtype=dtype)\n        fused_flow_gamma = flow.tensor(np_gamma).to(device=\"cuda\", dtype=dtype)\n    else:\n        np_gamma = np.ones(*normalize_shape).astype(np_dtype)\n        naive_flow_gamma = 
\n\n    naive_flow_beta = None\n    fused_flow_beta = None\n    if has_beta:\n        np_beta = np.random.randn(*normalize_shape).astype(np_dtype)\n        naive_flow_beta = flow.tensor(np_beta).to(device=\"cuda\", dtype=dtype)\n        fused_flow_beta = flow.tensor(np_beta).to(device=\"cuda\", dtype=dtype)\n    else:\n        np_beta = np.zeros(*normalize_shape).astype(np_dtype)\n        naive_flow_beta = flow.tensor(np_beta).to(device=\"cuda\", dtype=dtype)\n\n    flow_bias = None\n    if has_bias:\n        np_bias = np.random.randn(*normalize_shape).astype(np_dtype)\n        flow_bias = flow.tensor(np_bias).to(device=\"cuda\", dtype=dtype)\n\n    flow_skip_naive = None\n    flow_skip_fused = None\n    np_skip = None\n    if has_skip:\n        np_skip = np.random.randn(*x_shape).astype(np_dtype)\n        flow_skip_naive = flow.tensor(np_skip).to(device=\"cuda\", dtype=dtype)\n        flow_skip_fused = flow.tensor(np_skip).to(device=\"cuda\", dtype=dtype)\n\n    # naive process\n    flow_naive_module = NaiveSkipLayerNorm()\n    flow_x_naive = flow.tensor(np_x).to(device=\"cuda\", dtype=dtype)\n    flow_y_naive = flow_naive_module.forward(\n        x=flow_x_naive,\n        gamma=naive_flow_gamma,\n        beta=naive_flow_beta,\n        bias=flow_bias,\n        skip=flow_skip_naive,\n        alpha=alpha,\n        eps=eps,\n    )\n\n    # fused process\n    flow_fused_module = FusedSkipLayerNorm()\n    flow_x_fused = flow.tensor(np_x).to(device=\"cuda\", dtype=dtype)\n    flow_y_fused = flow_fused_module.forward(\n        x=flow_x_fused,\n        gamma=fused_flow_gamma,\n        beta=fused_flow_beta,\n        bias=flow_bias,\n        skip=flow_skip_fused,\n        alpha=alpha,\n        eps=eps,\n    )\n\n    if dtype is flow.float16:\n        compare_result(test_case, flow_y_naive, flow_y_fused, 1e-2, 1e-2)\n    else:\n        compare_result(test_case, flow_y_naive, flow_y_fused, 1e-4, 1e-4)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestSkipLayerNorm(flow.unittest.TestCase):\n    def test_skip_layer_norm(test_case):\n        arg_dict = OrderedDict()\n\n        # set up test functions\n        arg_dict[\"test_fun\"] = [\n            _test_skip_layer_norm,\n        ]\n\n        # set up test parameters\n        if is_profiling:\n            arg_dict[\"x_shape\"] = [[1, 5120]]\n            arg_dict[\"has_gamma\"] = [True]\n            arg_dict[\"has_beta\"] = [True]\n            arg_dict[\"has_bias\"] = [True]\n            arg_dict[\"has_skip\"] = [True]\n            arg_dict[\"eps\"] = [1e-6]\n            arg_dict[\"alpha\"] = [1e-5]\n            arg_dict[\"dtype\"] = [flow.float32]\n        else:\n            arg_dict[\"x_shape\"] = [[1, 5120]]\n            arg_dict[\"has_gamma\"] = [True, False]\n            arg_dict[\"has_beta\"] = [True, False]\n            arg_dict[\"has_bias\"] = [True, False]\n            arg_dict[\"has_skip\"] = [True, False]\n            arg_dict[\"eps\"] = [1e-6]\n            arg_dict[\"alpha\"] = [1e-5]\n            arg_dict[\"dtype\"] = [flow.float32, flow.float16]\n\n        # run test functions\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_skip_rms_norm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport numpy as np\nimport unittest\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgList\n\nis_profiling = False\n\n\ndef compare_result(test_case, a, b, rtol=1e-5, atol=1e-8):\n    test_case.assertTrue(\n        np.allclose(a.numpy(), b.numpy(), rtol=rtol, atol=atol),\n        f\"\\na\\n{a.numpy()}\\n{'-' * 80}\\nb:\\n{b.numpy()}\\n{'*' * 80}\\ndiff:\\n{a.numpy() - b.numpy()}\",\n    )\n\n\nclass NaiveSkipRMSNorm(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(\n        self,\n        x: flow.Tensor,\n        weight: flow.Tensor,\n        bias: flow.Tensor = None,\n        skip: flow.Tensor = None,\n        alpha: float = 1e-5,\n        eps: float = 1e-6,\n    ) -> flow.Tensor:\n        if bias is not None:\n            x = flow._C.add(input=x, other=bias)\n        if skip is not None:\n            skip = skip * alpha\n            x = flow._C.add(input=x, other=skip)\n        return flow._C.rms_norm(x, weight, [x.shape[-1]], eps)\n\n\nclass FusedSkipRMSNorm(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(\n        self,\n        x: flow.Tensor,\n        weight: flow.Tensor,\n        bias: flow.Tensor = None,\n        skip: flow.Tensor = None,\n        alpha: float = 1e-5,\n        eps: float = 1e-6,\n    ) -> flow.Tensor:\n        return flow._C.skip_rms_norm(\n            x=x, weight=weight, bias=bias, skip=skip, epsilon=eps, alpha=alpha\n        )\n\n\ndef _test_skip_rms_norm(\n    test_case,\n    x_shape,\n    has_weight,\n    has_bias,\n    has_skip,\n    eps=1e-6,\n    alpha=1e-5,\n    dtype=flow.float32,\n):\n    print(\n        f\"x_shape: {x_shape}\\nhas_weight: {has_weight}\\nhas_bias: {has_bias}\\nhas_skip: {has_skip}\\ndtype: {dtype}\\n\"\n    )\n\n    normalize_shape = list()\n    normalize_shape.append(x_shape[-1])\n\n    np_dtype = np.float16 if dtype is flow.float16 else np.float32\n\n    # generate np array\n    np_x = np.random.randn(*x_shape).astype(np_dtype)\n\n    naive_flow_weight = None\n    fused_flow_weight = None\n    if has_weight:\n        np_gamma = np.random.randn(*normalize_shape).astype(np_dtype)\n        naive_flow_weight = flow.tensor(np_gamma).to(device=\"cuda\", dtype=dtype)\n        fused_flow_weight = flow.tensor(np_gamma).to(device=\"cuda\", dtype=dtype)\n    else:\n        np_gamma = np.ones(*normalize_shape).astype(np_dtype)\n        naive_flow_gamma = flow.tensor(np_gamma).to(device=\"cuda\", dtype=dtype)\n\n    flow_bias = None\n    if has_bias:\n        np_bias = np.random.randn(*normalize_shape).astype(np_dtype)\n        flow_bias = flow.tensor(np_bias).to(device=\"cuda\", dtype=dtype)\n\n    flow_skip_naive = None\n    flow_skip_fused = None\n    np_skip = None\n    if has_skip:\n        np_skip = 
\n        flow_skip_naive = flow.tensor(np_skip).to(device=\"cuda\", dtype=dtype)\n        flow_skip_fused = flow.tensor(np_skip).to(device=\"cuda\", dtype=dtype)\n\n    # naive process\n    flow_naive_module = NaiveSkipRMSNorm()\n    flow_x_naive = flow.tensor(np_x).to(device=\"cuda\", dtype=dtype)\n    flow_y_naive = flow_naive_module.forward(\n        x=flow_x_naive,\n        weight=naive_flow_weight,\n        bias=flow_bias,\n        skip=flow_skip_naive,\n        alpha=alpha,\n        eps=eps,\n    )\n\n    # fused process\n    flow_fused_module = FusedSkipRMSNorm()\n    flow_x_fused = flow.tensor(np_x).to(device=\"cuda\", dtype=dtype)\n    flow_y_fused = flow_fused_module.forward(\n        x=flow_x_fused,\n        weight=fused_flow_weight,\n        bias=flow_bias,\n        skip=flow_skip_fused,\n        alpha=alpha,\n        eps=eps,\n    )\n\n    if dtype is flow.float16:\n        compare_result(test_case, flow_y_naive, flow_y_fused, 1e-2, 1e-2)\n    else:\n        compare_result(test_case, flow_y_naive, flow_y_fused, 1e-4, 1e-4)\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestSkipRMSNorm(flow.unittest.TestCase):\n    def test_skip_rms_norm(test_case):\n        arg_dict = OrderedDict()\n\n        # set up test functions\n        arg_dict[\"test_fun\"] = [\n            _test_skip_rms_norm,\n        ]\n\n        # set up test parameters\n        if is_profiling:\n            arg_dict[\"x_shape\"] = [[1, 5120]]\n            arg_dict[\"has_weight\"] = [True]\n            arg_dict[\"has_bias\"] = [True]\n            arg_dict[\"has_skip\"] = [True]\n            arg_dict[\"eps\"] = [1e-6]\n            arg_dict[\"alpha\"] = [1e-5]\n            arg_dict[\"dtype\"] = [flow.float32]\n        else:\n            arg_dict[\"x_shape\"] = [[1, 5120]]\n            arg_dict[\"has_weight\"] = [True, False]\n            arg_dict[\"has_bias\"] = [True, False]\n            arg_dict[\"has_skip\"] = [True, False]\n            arg_dict[\"eps\"] = [1e-6]\n            arg_dict[\"alpha\"] = [1e-5]\n            arg_dict[\"dtype\"] = [flow.float32, flow.float16]\n\n        # run test functions\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_slice.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\nfrom random import randint\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_slice(test_case, device):\n    np_arr = np.random.randn(3, 6, 9).astype(np.float32)\n    x = flow.tensor(np_arr, device=flow.device(device))\n    tup_list = [[None, None, None], [0, 5, 2], [0, 6, 3]]\n    y = flow.slice(x, slice_tup_list=tup_list)\n    flow_tmp = x[0:3, 0:5, 0:6]\n    y = flow_tmp[::1, ::2, ::3]\n    tmp = np_arr[0:3, 0:5, 0:6]\n    np_out = tmp[::1, ::2, ::3]\n    test_case.assertTrue(np.array_equal(y.numpy(), np_out))\n\n\ndef _test_slice_empty(test_case, device):\n    np_arr = np.random.randn(10).astype(np.float32)\n    x = flow.tensor(np_arr, device=flow.device(device))\n    y = x[3:3]\n    test_case.assertTrue(y.shape, flow.Size((0,)))\n    np_out = np_arr[3:3]\n    test_case.assertTrue(np.array_equal(y.numpy(), np_out))\n\n\ndef _test_slice_1_dim(test_case, device):\n    np_arr = np.random.randn(100).astype(np.float32)\n    x = flow.tensor(np_arr, device=flow.device(device))\n    test_case.assertTrue(np.allclose(x[1].numpy(), np_arr[1], 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(x[99].numpy(), np_arr[99], 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(x[0:2].numpy(), np_arr[0:2], 1e-05, 1e-05))\n\n\ndef _test_slice_3_dim(test_case, device):\n    np_arr = np.random.randn(2, 3, 4).astype(np.float32)\n    x = flow.tensor(np_arr, device=flow.device(device))\n    test_case.assertTrue(np.allclose(x[:, 0].numpy(), np_arr[:, 0], 1e-05, 1e-05))\n\n\ndef _test_slice_4_dim(test_case, device):\n    np_arr = np.random.randn(5, 3, 6, 9).astype(np.float32)\n    x = flow.tensor(np_arr, device=flow.device(device))\n    tup_list = [[0, 5, 2], [None, None, None], [0, 5, 2], [0, 6, 3]]\n    y = flow.slice(x, slice_tup_list=tup_list)\n    tmp = np_arr[0:5, 0:3, 0:5, 0:6]\n    np_out = tmp[::2, ::1, ::2, ::3]\n    test_case.assertTrue(np.array_equal(y.numpy(), np_out))\n\n\ndef _test_slice_with_int_index(test_case, device):\n    np_arr = np.random.randn(2, 3, 4).astype(np.float32)\n    x = flow.tensor(np_arr, device=flow.device(device))\n    of_out = x[0, 1:2]\n    np_out = np_arr[0, 1:2]\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n    np_arr = np.random.randn(2, 3, 4).astype(np.float32)\n    x = flow.tensor(np_arr, device=flow.device(device))\n    of_out = x[0, :]\n    np_out = np_arr[0, :]\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n    np_arr = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]).astype(np.float32)\n    x = flow.tensor(np_arr, device=flow.device(device))\n    of_out = x[0, :, :]\n    np_out = np_arr[0, :, :]\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n    np_arr = np.random.randn(2, 3, 4, 
5).astype(np.float32)\n    x = flow.tensor(np_arr, device=flow.device(device))\n    of_out = x[0, :, :, :]\n    np_out = np_arr[0, :, :, :]\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_slice_negative_index(test_case, device):\n    np_arr = np.random.randn(4, 5, 6)\n    x = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    test_case.assertTrue(np.allclose(x[-1].numpy(), np_arr[-1], 0.0001, 0.0001))\n    test_case.assertTrue(np.allclose(x[-2].numpy(), np_arr[-2], 0.0001, 0.0001))\n    test_case.assertTrue(np.allclose(x[-3].numpy(), np_arr[-3], 0.0001, 0.0001))\n    test_case.assertTrue(np.allclose(x[-4].numpy(), np_arr[-4], 0.0001, 0.0001))\n\n\ndef _test_slice_ellipsis_type(test_case, device):\n    np_arr = np.random.randn(2, 3, 4, 5, 6, 7).astype(np.float32)\n    x = flow.tensor(np_arr, device=flow.device(device))\n    of_out = x[..., ::2, ::2, 3:4]\n    np_out = np_arr[..., ::2, ::2, 3:4]\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n    of_out = x[..., 1:2, ::2, 1, ::3]\n    np_out = np_arr[..., 1:2, ::2, 1, ::3]\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n    of_out = x[0, 2, ..., 1, 1:2]\n    np_out = np_arr[0, 2, ..., 1, 1:2]\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n    of_out = x[::2, ..., 1:2]\n    np_out = np_arr[::2, ..., 1:2]\n    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))\n\n\ndef _test_slice_backward(test_case, device):\n    np_arr = np.random.randn(3, 6, 9).astype(np.float32)\n    x = flow.tensor(np_arr, device=flow.device(device), requires_grad=True)\n    tup_list = [[None, None, None], [0, 5, 2], [0, 6, 3]]\n    y = flow.slice(x, slice_tup_list=tup_list)\n    z = y.sum()\n    z.backward()\n    np_grad = np.zeros((3, 6, 9))\n    np_grad[0:3, 0:5, 0:6][::1, ::2, ::3] = 1\n    test_case.assertTrue(np.array_equal(x.grad.numpy(), np_grad))\n\n\ndef _test_slice_scalar(test_case, device):\n    dtype = [flow.int8, flow.int16, flow.int32, flow.int64]\n    x = flow.randn(50, 534, 800, device=device)\n    for d in dtype:\n        scalar = flow.tensor(3, dtype=d, device=device)\n        y = x[scalar]\n        test_case.assertEqual(y.shape, flow.Size((534, 800)))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSlice(flow.unittest.TestCase):\n    def test_slice(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_slice,\n            _test_slice_empty,\n            _test_slice_1_dim,\n            _test_slice_3_dim,\n            _test_slice_4_dim,\n            _test_slice_with_int_index,\n            _test_slice_negative_index,\n            _test_slice_ellipsis_type,\n            _test_slice_backward,\n            _test_slice_scalar,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSliceUpdate(flow.unittest.TestCase):\n    def test_slice_update(test_case):\n        x = np.array([1, 1, 1, 1, 1]).astype(np.float32)\n        input = flow.tensor(x)\n        update = flow.tensor(np.array([2, 3, 4]).astype(np.float32))\n        output = np.array([1.0, 2.0, 3.0, 4.0, 1.0])\n        flow.slice_update(input, update, slice_tup_list=[[1, 4, 1]])\n        test_case.assertTrue(np.array_equal(input.numpy(), output))\n\n    def test_slice_update_negative_index(test_case):\n        np_arr = np.zeros(shape=(2, 3, 4))\n        input = flow.tensor(np_arr, dtype=flow.float32)\n        
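# apply the same negative-index update in NumPy and OneFlow and compare\n        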
np_arr[-1] = 1\n        input[-1] = 1\n        test_case.assertTrue(np.array_equal(input.numpy(), np_arr))\n\n    def test_slice_update_scalar_integer_tensor_index(test_case):\n        np_arr_a = np.random.rand(133, 1, 15)\n        np_arr_b = np.random.rand(133, 2, 1)\n\n        a_torch = torch.tensor(np_arr_a)\n        b_torch = torch.tensor(np_arr_b)\n        pos_torch = torch.tensor(0)\n        a_torch[:, 0, pos_torch] = b_torch[:, 1, 0]\n\n        a_flow = flow.tensor(np_arr_a)\n        b_flow = flow.tensor(np_arr_b)\n        pos_flow = flow.tensor(0)\n        a_flow[:, 0, pos_flow] = b_flow[:, 1, 0]\n\n        test_case.assertTrue(\n            np.allclose(a_flow.numpy(), a_torch.cpu().numpy(), rtol=1e-5, atol=1e-5,)\n        )\n\n    def test_slice_update_scalar_boolean_tensor_index(test_case):\n        np_arr_a = np.random.rand(2, 1, 2)\n        np_arr_b = np.random.rand(2, 2, 1)\n\n        a_torch = torch.tensor(np_arr_a)\n        b_torch = torch.tensor(np_arr_b)\n        pos_torch = torch.tensor(True)\n        a_torch[:, 0, pos_torch] = b_torch[:, 1, 0]\n\n        a_flow = flow.tensor(np_arr_a)\n        b_flow = flow.tensor(np_arr_b)\n        pos_flow = flow.tensor(True)\n        a_flow[:, 0, pos_flow] = b_flow[:, 1, 0]\n\n        test_case.assertTrue(\n            np.allclose(a_flow.numpy(), a_torch.cpu().numpy(), rtol=1e-5, atol=1e-5,)\n        )\n\n    def test_slice_update_negative_index_graph(test_case):\n        np_arr = np.zeros(shape=(2, 3, 4))\n        input = flow.tensor(np_arr, dtype=flow.float32)\n        np_arr[-1] = 1\n\n        @flow.nn.Graph.trace\n        def test_func():\n            input[-1] = 1\n            return input\n\n        out = test_func()\n        test_case.assertTrue(np.array_equal(out.numpy(), np_arr))\n\n    def test_slice_update_different_dtype(test_case):\n        x = np.array([1, 1, 1, 1, 1]).astype(np.float32)\n        for value_type in [np.int32, np.float64]:\n            input = flow.tensor(x)\n            update = flow.tensor(np.array([2, 3, 4]).astype(value_type))\n            output = np.array([1.0, 2.0, 3.0, 4.0, 1.0])\n            flow.slice_update(input, update, slice_tup_list=[[1, 4, 1]])\n            test_case.assertTrue(np.array_equal(input.numpy(), output))\n\n    def test_slice_update_ellipsis_type(test_case):\n        np_arr = np.zeros(shape=(2, 3, 4, 5, 6))\n        input = flow.tensor(np_arr, dtype=flow.float32)\n        np_arr[0, ::1, ..., 2:3] = 1\n        input[0, ::1, ..., 2:3] = 1\n        test_case.assertTrue(np.array_equal(input.numpy(), np_arr))\n\n    def test_slice_update_ellipsis_type_graph(test_case):\n        np_arr = np.zeros(shape=(2, 3, 4, 5, 6))\n        input = flow.tensor(np_arr, dtype=flow.float32)\n        np_arr[0, ::1, ..., 2:3] = 1\n\n        @flow.nn.Graph.trace\n        def test_func():\n            input[0, ::1, ..., 2:3] = 1\n            return input\n\n        out = test_func()\n        test_case.assertTrue(np.array_equal(out.numpy(), np_arr))\n\n    def test_slice_update_grad_graph(test_case):\n        x = np.array([1, 1, 1, 1, 1]).astype(np.float32)\n        input = flow.tensor(x, requires_grad=True)\n        update = flow.tensor(np.array([2, 3, 4]).astype(np.float32), requires_grad=True)\n        output = np.array([1.0, 2.0, 3.0, 4.0, 1.0])\n\n        class TestModule(flow.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.ref_grad = flow.nn.Parameter(flow.zeros(5))\n                self.value_grad = flow.nn.Parameter(flow.zeros(3))\n\n          
  def forward(self, ref, value):\n                x = ref + self.ref_grad\n                y = value + self.value_grad\n                return flow._C.slice_update(x, y, [1,], [4,], [1,])\n\n        test_m = TestModule()\n        of_sgd = flow.optim.SGD(test_m.parameters(), lr=1.0, momentum=0.0)\n\n        class TestSliceUpdateGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.m = test_m\n                self.add_optimizer(of_sgd)\n\n            def build(self, ref, update):\n                x = self.m(ref, update)\n                x.sum().backward()\n                return x\n\n        slice_update_g = TestSliceUpdateGraph()\n\n        y = slice_update_g(input, update)\n\n        # forward\n        test_case.assertTrue(np.array_equal(y.numpy(), output))\n        # ref grad\n        ref_grad = np.array([1.0, 0.0, 0.0, 0.0, 1.0]).astype(np.float32)\n        test_case.assertTrue(np.array_equal(-test_m.ref_grad, ref_grad))\n        # value grad\n        value_grad = np.array([1.0, 1.0, 1.0]).astype(np.float32)\n        test_case.assertTrue(np.array_equal(-test_m.value_grad, value_grad))\n\n    def test_random_nd_slice_update_in_non_contiguous_tensor(test_case):\n        def get_random_slice_tuple(shape):\n            slice_tup = []\n            slice_size = []\n            for i in range(len(shape)):\n                start = randint(0, shape[i] - 1)\n                end = randint(start + 1, shape[i])\n                step = randint(1, end - start + 1)\n                slice_tup.append(slice(start, end, step))\n                slice_size.append((end - start + step - 1) // step)\n            return tuple(slice_tup), tuple(slice_size)\n\n        def get_random_update_shape_and_perm(shape):\n            perm = flow.randperm(len(shape)).tolist()\n            no_perm_shape = [shape[i] for i in perm]\n            inv_perm = [0] * len(shape)\n            for i in range(len(shape)):\n                inv_perm[perm[i]] = i\n            return no_perm_shape, inv_perm\n\n        def compare_result_between_oneflow_and_numpy(test_case, shape):\n            device = random_device().value()\n            # non-contiguous ref\n            ref = (\n                flow.rand(shape, dtype=flow.float32)\n                .to(device)\n                .permute(flow.randperm(len(shape)).tolist())\n            )\n            ref_np = ref.detach().clone().numpy()\n            shape = ref.shape\n            # slice param\n            slice_tup, slice_size = get_random_slice_tuple(shape)\n            # non-contiguous update\n            no_perm_shape, perm = get_random_update_shape_and_perm(slice_size)\n            update = (\n                flow.rand(no_perm_shape, dtype=flow.float32).to(device).permute(perm)\n            )\n            update_np = update.detach().clone().numpy()\n\n            ref_np[slice_tup] = update_np\n            # non-inplace update\n            # NOTE: the non-inplace path should be tested first\n            def slice_tuple_to_slice_list(slice_tup):\n                # NOTE: oneflow.slice_update does not accept Python slice objects, so convert them to (start, stop, step) tuples.\n                slice_list = []\n                for i in range(len(slice_tup)):\n                    slice_list.append(\n                        (slice_tup[i].start, slice_tup[i].stop, slice_tup[i].step)\n                    )\n                return slice_list\n\n            of_res = flow.slice_update(\n                ref, update, slice_tuple_to_slice_list(slice_tup)\n            )\n            
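# the non-inplace result must match the NumPy reference exactly\n            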
test_case.assertTrue(np.array_equal(of_res.numpy(), ref_np))\n            # inplace update\n            ref[slice_tup] = update\n            test_case.assertTrue(np.array_equal(ref.numpy(), ref_np))\n\n        for dims in (2, 3, 4):\n            for _ in range(10):\n                shape = [randint(1, 21) for _ in range(dims)]\n                compare_result_between_oneflow_and_numpy(test_case, shape)\n\n    def test_slice_update_expand_value(test_case):\n        ref_np = np.random.rand(2, 3, 4)\n        ref_of = flow.tensor(ref_np)\n        update_np = np.random.rand(3,)\n        update_ref = flow.tensor(update_np)\n\n        ref_of[:, :, 1] = update_ref\n        ref_np[:, :, 1] = update_np\n        test_case.assertTrue(np.array_equal(ref_of.numpy(), ref_np))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_softmax.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport os\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _dtype_flow_to_np(dtype):\n    return {flow.float32: np.float32, flow.float16: np.float16}[dtype]\n\n\ndef _np_softmax(x, dtype=None):\n    if dtype is not None:\n        x = x.astype(dtype)\n    x -= np.max(x, axis=-1, keepdims=True)\n    x = np.exp(x)\n    return x / np.sum(x, axis=-1, keepdims=True)\n\n\ndef _test_softmax_impl(test_case, shape, input_dtype, output_dtype):\n    np_input = np.random.randn(*shape).astype(_dtype_flow_to_np(input_dtype))\n    of_input = flow.tensor(np_input, dtype=input_dtype, device=flow.device(\"cuda\"))\n    of_out = flow.nn.functional.softmax(of_input, dtype=output_dtype)\n    if output_dtype is not None:\n        np_out = _np_softmax(np_input, dtype=_dtype_flow_to_np(output_dtype))\n    else:\n        np_out = _np_softmax(np_input)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.001, 0.001))\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass Testsoftmax(flow.unittest.TestCase):\n    def test_softmax(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(20, 30), (32, 128)]\n        arg_dict[\"input_dtype\"] = [flow.float16, flow.float32]\n        arg_dict[\"output_dtype\"] = [None, flow.float32]\n        for arg in GenArgList(arg_dict):\n            _test_softmax_impl(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_softplus.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_softplus_impl(test_case, shape, device):\n    np_input = np.random.randn(*shape)\n    of_input = flow.tensor(\n        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    np_x_grad = np.exp(np_input) / (1 + np.exp(np_input))\n    of_out = flow.softplus(of_input)\n    np_out = np.log(1 + np.exp(np_input))\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n    of_out = of_out.sum()\n    of_out.backward()\n    test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_x_grad, 0.0001, 0.0001))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass Testsoftplus(flow.unittest.TestCase):\n    def test_softplus(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(2, 3), (2, 4, 5, 6)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_softplus_impl(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_sort.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList, type_name_to_flow_type\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_sort(test_case, data_shape, axis, descending, data_type, device):\n    input = flow.tensor(\n        np.random.randn(*data_shape),\n        dtype=type_name_to_flow_type[data_type],\n        device=flow.device(device),\n    )\n    (of_values, of_indices) = flow.sort(input, dim=axis, descending=descending)\n    np_input = -input.numpy() if descending else input.numpy()\n    np_indices = np.argsort(np_input, axis=axis)\n    np_out = np.sort(np_input, axis=axis)\n    np_values = -np_out if descending else np_out\n    test_case.assertTrue(\n        np.array_equal(of_values.numpy().flatten(), np_values.flatten())\n    )\n    test_case.assertTrue(\n        np.array_equal(of_indices.numpy().flatten(), np_indices.flatten())\n    )\n\n\ndef _test_tensor_sort(test_case, data_shape, axis, descending, data_type, device):\n    input = flow.tensor(\n        np.random.randn(*data_shape),\n        dtype=type_name_to_flow_type[data_type],\n        device=flow.device(device),\n    )\n    (of_values, of_indices) = input.sort(dim=axis, descending=descending)\n    np_input = -input.numpy() if descending else input.numpy()\n    np_indices = np.argsort(np_input, axis=axis)\n    np_out = np.sort(np_input, axis=axis)\n    np_values = -np_out if descending else np_out\n    test_case.assertTrue(\n        np.array_equal(of_values.numpy().flatten(), np_values.flatten())\n    )\n    test_case.assertTrue(\n        np.array_equal(of_indices.numpy().flatten(), np_indices.flatten())\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSort(flow.unittest.TestCase):\n    def test_sort(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_sort, _test_tensor_sort]\n        arg_dict[\"data_shape\"] = [(2, 6, 5, 4), (3, 4, 8)]\n        arg_dict[\"axis\"] = [-1, 0, 2]\n        arg_dict[\"descending\"] = [True, False]\n        arg_dict[\"data_type\"] = [\"double\", \"float32\", \"int32\"]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_sort_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = torch.sort(x, dim=random(low=-4, high=4).to(int), descending=random_bool())\n        return y[0], y[1]\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_sort_return_type_with_random_data_(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        result = torch.sort(\n            x, dim=random(low=-4, high=4).to(int), 
descending=random_bool()\n        )\n        return result.values, result.indices\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_sort_bool_with_random_data(test_case):\n        x = random_tensor(ndim=4).to(device=\"cpu\", dtype=torch.bool)\n        y = torch.sort(x, dim=random(low=-4, high=4).to(int), descending=random_bool())\n        return y[0], y[1]\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_sort_return_type_bool_with_random_data(test_case):\n        x = random_tensor(ndim=4).to(device=\"cpu\", dtype=torch.bool)\n        result = torch.sort(\n            x, dim=random(low=-4, high=4).to(int), descending=random_bool()\n        )\n        return result.values, result.indices\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_sparse.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nfrom collections import OrderedDict\nimport unittest\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\nimport oneflow as flow\nimport oneflow.unittest\nimport numpy as np\n\n\ndef _test_embedding_padding_idx(test_case, device):\n    indices = flow.tensor(\n        [[1, 0, 4, 8], [8, 3, 0, 9]],\n        dtype=flow.int,\n        device=flow.device(device),\n        requires_grad=False,\n    )\n    embedding = flow.nn.Embedding(10, 3, padding_idx=0).to(device)\n    output = embedding(indices)\n    test_case.assertEqual(output[0][1].sum(), 0)\n    test_case.assertEqual(output[1][2].sum(), 0)\n\n    # negative indexing check for padding_idx\n    # padding_idx=-2, num_embeddings=10 ==> index 8 padded\n    embedding = flow.nn.Embedding(10, 3, padding_idx=-2).to(device)\n    output = embedding(indices)\n    test_case.assertEqual(output[0][3].sum(), 0)\n    test_case.assertEqual(output[1][0].sum(), 0)\n\n    # out of bounds check for padding_idx\n    test_case.assertRaises(\n        AssertionError,\n        flow.nn.Embedding,\n        num_embeddings=10,\n        embedding_dim=3,\n        padding_idx=25,\n    )\n    test_case.assertRaises(\n        AssertionError,\n        flow.nn.Embedding,\n        num_embeddings=10,\n        embedding_dim=3,\n        padding_idx=-25,\n    )\n\n    padding_idx = 0\n    embedding = flow.nn.Embedding(10, 3, padding_idx=padding_idx).to(device)\n    indices = flow.tensor(\n        [[1, 0, 4, 8], [8, 3, 0, 9]],\n        dtype=flow.int,\n        device=flow.device(device),\n        requires_grad=False,\n    )\n    pre = embedding.weight[padding_idx].clone()\n    embedding(indices).sum().backward()\n    after = (embedding.weight + embedding.weight.grad)[padding_idx]\n    embedding.zero_grad()\n    test_case.assertTrue(flow.equal(after, pre))\n\n\ndef _test_embedding_scale_by_freq(test_case, device):\n    weight = np.array(\n        [\n            [0.68258786, 0.6957856, 1.1829041],\n            [1.0154, -1.0616943, 0.50303376],\n            [0.29679507, 0.65562993, 1.0424724],\n            [-0.42980736, -0.35347632, -0.15600166],\n            [0.6763601, -0.24286619, -2.0873115],\n            [-0.13371214, -0.5589277, 1.9173933],\n            [0.08762296, 1.0264007, -0.67938024],\n            [0.32019204, -0.26137325, -1.3534237],\n            [-1.1555519, -0.67776406, 0.27372134],\n            [1.0615997, -0.59715784, 1.9855849],\n        ],\n        dtype=np.float32,\n    )\n    output = np.array(\n        [\n            [\n                [1.0154, -1.0616943, 0.50303376],\n                [0.29679507, 0.65562993, 1.0424724],\n                [0.6763601, -0.24286619, -2.0873115],\n                [-0.13371214, -0.5589277, 1.9173933],\n            ],\n            [\n                [0.6763601, -0.24286619, -2.0873115],\n                [-0.42980736, -0.35347632, -0.15600166],\n                
[0.29679507, 0.65562993, 1.0424724],\n                [1.0615997, -0.59715784, 1.9855849],\n            ],\n        ],\n        dtype=np.float32,\n    )\n    indices = flow.tensor(\n        [[1, 2, 4, 5], [4, 3, 2, 9]],\n        dtype=flow.int,\n        device=flow.device(device),\n        requires_grad=False,\n    )\n    m = flow.nn.Embedding(10, 3, scale_grad_by_freq=True, _weight=flow.Tensor(weight))\n    m = m.to(device)\n    y = m(indices)\n    test_case.assertTrue(np.allclose(y.numpy(), output, 1e-05, 1e-05))\n    y = y.sum()\n    y.backward()\n    weight_grad_np = [\n        [0.0, 0.0, 0.0],\n        [1.0, 1.0, 1.0],\n        [1.0, 1.0, 1.0],\n        [1.0, 1.0, 1.0],\n        [1.0, 1.0, 1.0],\n        [1.0, 1.0, 1.0],\n        [0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0],\n        [1.0, 1.0, 1.0],\n    ]\n    test_case.assertTrue(\n        np.allclose(m.weight.grad.numpy(), weight_grad_np, 1e-05, 1e-05)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestEmbedding(flow.unittest.TestCase):\n    def test_padding_idx(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_embedding_padding_idx(test_case, *arg)\n            _test_embedding_scale_by_freq(test_case, *arg)\n\n    @unittest.skip(\"skip for now, because it failed twice in the past week\")\n    @autotest(n=5, check_graph=True)\n    def test_embedding_impl(test_case):\n        device = random_device()\n        emb_size = random(low=2) * 16\n        emb_dim = random(low=2) * 16\n        emb_shape = [emb_size, emb_dim]\n\n        idx_ndim = random(high=4).to(int).value()\n        idx_shape = [random(high=4) for i in range(idx_ndim)]\n\n        weight = random_tensor(len(emb_shape), *emb_shape).to(device)\n        indices = random_tensor(\n            len(idx_shape), *idx_shape, low=0, high=emb_size, dtype=int\n        ).to(device)\n\n        embedding = torch.nn.Embedding(emb_size, emb_dim, _weight=weight).to(device)\n        y = embedding(indices)\n        return y\n\n    @autotest(n=5, check_graph=True)\n    def test_embedding_functional(test_case):\n        device = random_device()\n        emb_size = random(low=2) * 16\n        emb_dim = random(low=2) * 16\n        emb_shape = [emb_size, emb_dim]\n\n        idx_ndim = random(high=4).to(int).value()\n        idx_shape = [random(high=4) for i in range(idx_ndim)]\n\n        weight = random_tensor(len(emb_shape), *emb_shape).to(device)\n        indices = random_tensor(\n            len(idx_shape), *idx_shape, low=0, high=emb_size, dtype=int\n        ).to(device)\n\n        y = torch.nn.functional.embedding(indices, weight)\n        return y\n\n    # NOTE(Yao Zihang): Set check_graph=False temporarily\n    # Graph mode does not support inplace ops under flow.no_grad()\n    # See this issue: https://github.com/Oneflow-Inc/OneTeam/issues/1382\n    @unittest.skip(\"still failing in CI; 
TODO(Yao Zihang)\")\n    @autotest(n=5, rtol=1e-03, atol=1e-03, check_graph=\"ValidatedFalse\")\n    def test_embedding_renorm(test_case):\n        device = random_device()\n        emb_size = random(low=2) * 16\n        emb_dim = random(low=2) * 16\n        emb_shape = [emb_size, emb_dim]\n\n        idx_ndim = 2\n        idx_shape = [random(high=4) for i in range(idx_ndim)]\n\n        weight = random_tensor(len(emb_shape), *emb_shape).to(device)\n        indices = random_tensor(\n            len(idx_shape), *idx_shape, low=0, high=emb_size, dtype=int\n        ).to(device)\n\n        embedding = torch.nn.Embedding(\n            emb_size, emb_dim, max_norm=1.0, _weight=weight\n        ).to(device)\n        y = embedding(indices)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_sparse_softmax_cross_entropy.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport os\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import (\n    GenArgList,\n    type_name_to_flow_type,\n    type_name_to_np_type,\n)\n\n\ndef compare_with_torch(\n    device_type, data_type, label_type, batch_size, num_classes,\n):\n    data_type = type_name_to_flow_type[data_type]\n    label_type = type_name_to_flow_type[label_type]\n    np_labels = np.random.randint(0, num_classes, size=(batch_size,)).astype(np.int32)\n    np_logits = np.random.random((batch_size, num_classes)).astype(np.float32)\n\n    torch_logits = torch.tensor(np_logits, dtype=torch.float32, requires_grad=True)\n    torch_labels = torch.tensor(np_labels, dtype=torch.int64)\n    torch_output = torch.nn.functional.cross_entropy(\n        torch_logits, torch_labels, reduction=\"none\"\n    )\n    torch_output.sum().backward()\n\n    of_logits = flow.tensor(\n        np_logits, device=device_type, dtype=data_type, requires_grad=True\n    )\n    of_labels = flow.tensor(np_labels, device=device_type, dtype=label_type)\n    of_output = flow.nn.functional.sparse_softmax_cross_entropy(\n        labels=of_labels, logits=of_logits\n    ).to(device_type)\n    of_output.sum().backward()\n\n    assert np.allclose(\n        of_output.numpy(), torch_output.detach().numpy(), rtol=1e-03, atol=1e-04\n    )\n    assert np.allclose(\n        of_logits.grad.numpy(), torch_logits.grad, rtol=1e-03, atol=1e-04\n    )\n\n\ndef compare_eager_global_with_torch(\n    device_type, data_type, label_type, batch_size, num_classes,\n):\n    data_type = type_name_to_flow_type[data_type]\n    label_type = type_name_to_flow_type[label_type]\n    np_labels = np.random.randint(0, num_classes, size=(batch_size,)).astype(np.int32)\n    np_logits = np.random.random((batch_size, num_classes)).astype(np.float32)\n    placement = flow.placement(device_type, range(4))\n    rank = flow.env.get_rank()\n    if rank == 0:\n        torch_logits = torch.tensor(np_logits, dtype=torch.float32, requires_grad=True)\n        torch_labels = torch.tensor(np_labels, dtype=torch.int64)\n        torch_output = torch.nn.functional.cross_entropy(\n            torch_logits, torch_labels, reduction=\"none\"\n        )\n        torch_output.sum().backward()\n\n    # 1D sbp\n    of_logits = flow.tensor(\n        np_logits, device=device_type, dtype=data_type, requires_grad=True\n    )\n    flow.comm.broadcast(of_logits, 0)\n    of_logits = of_logits.to_global(placement=placement, sbp=[flow.sbp.broadcast])\n    of_logits.retain_grad()\n    global_of_logits = of_logits.to_global(placement=placement, sbp=[flow.sbp.split(1)])\n    of_labels = flow.tensor(np_labels, device=device_type, dtype=label_type)\n    flow.comm.broadcast(of_labels, 0)\n    of_labels = of_labels.to_global(placement=placement, sbp=[flow.sbp.broadcast])\n\n    
of_output = flow.nn.functional.sparse_softmax_cross_entropy(\n        labels=of_labels, logits=global_of_logits\n    ).to(device_type)\n    of_output.sum().backward()\n    of_logits_grad = of_logits.grad.to_global(\n        placement=placement, sbp=[flow.sbp.broadcast]\n    )\n    of_logits_grad = of_logits_grad.to_local()\n    of_output = of_output.to_global(placement=placement, sbp=[flow.sbp.broadcast])\n    of_output = of_output.to_local()\n\n    if rank == 0:\n        assert np.allclose(\n            of_output.numpy(), torch_output.detach().numpy(), rtol=1e-03, atol=1e-04\n        )\n        assert np.allclose(\n            of_logits_grad.numpy(), torch_logits.grad, rtol=1e-03, atol=1e-04\n        )\n\n\ndef compare_eager_2d_global_with_torch(\n    device_type, data_type, label_type, batch_size, num_classes,\n):\n    data_type = type_name_to_flow_type[data_type]\n    label_type = type_name_to_flow_type[label_type]\n    np_labels = np.random.randint(0, num_classes, size=(batch_size,)).astype(np.int32)\n    np_logits = np.random.random((batch_size, num_classes)).astype(np.float32)\n\n    rank = flow.env.get_rank()\n    if rank == 0:\n        torch_logits = torch.tensor(np_logits, dtype=torch.float32, requires_grad=True)\n        torch_labels = torch.tensor(np_labels, dtype=torch.int64)\n        torch_output = torch.nn.functional.cross_entropy(\n            torch_logits, torch_labels, reduction=\"none\"\n        )\n        torch_output.sum().backward()\n\n    # 2D sbp\n    placement = flow.placement(\"cuda\", ranks=[[0, 1], [2, 3]])\n    of_logits = flow.tensor(\n        np_logits, device=device_type, dtype=data_type, requires_grad=True\n    )\n    flow.comm.broadcast(of_logits, 0)\n    of_logits = of_logits.to_global(\n        placement=placement, sbp=[flow.sbp.broadcast, flow.sbp.broadcast]\n    )\n    of_logits.retain_grad()\n    global_of_logits = of_logits.to_global(\n        placement=placement, sbp=[flow.sbp.split(0), flow.sbp.split(1)]\n    )\n    of_labels = flow.tensor(np_labels, device=device_type, dtype=label_type)\n    flow.comm.broadcast(of_labels, 0)\n    of_labels = of_labels.to_global(\n        placement=placement, sbp=[flow.sbp.broadcast, flow.sbp.broadcast]\n    )\n    of_labels = of_labels.to_global(\n        placement=placement, sbp=[flow.sbp.split(0), flow.sbp.broadcast]\n    )\n\n    of_output = flow.nn.functional.sparse_softmax_cross_entropy(\n        labels=of_labels, logits=global_of_logits\n    ).to(device_type)\n    of_output.sum().backward()\n    of_logits_grad = of_logits.grad.to_global(\n        placement=placement, sbp=[flow.sbp.broadcast, flow.sbp.broadcast]\n    )\n    of_logits_grad = of_logits_grad.to_local()\n    of_output = of_output.to_global(\n        placement=placement, sbp=[flow.sbp.broadcast, flow.sbp.broadcast]\n    )\n    of_output = of_output.to_local()\n\n    if rank == 0:\n        assert np.allclose(\n            of_output.numpy(), torch_output.detach().numpy(), rtol=1e-03, atol=1e-04\n        )\n        assert np.allclose(\n            of_logits_grad.numpy(),\n            torch_logits.grad.detach().numpy(),\n            rtol=1e-03,\n            atol=1e-04,\n        )\n\n\ndef compare_lazy_global_with_torch(\n    device_type, data_type, label_type, batch_size, num_classes,\n):\n    data_type = type_name_to_flow_type[data_type]\n    label_type = type_name_to_flow_type[label_type]\n    np_labels = np.random.randint(0, num_classes, size=(batch_size,)).astype(np.int32)\n    np_logits = np.random.random((batch_size, 
num_classes)).astype(np.float32)\n    placement = flow.placement(device_type, range(4))\n    rank = flow.env.get_rank()\n\n    if rank == 0:\n        torch_logits = torch.tensor(np_logits, dtype=torch.float32, requires_grad=True)\n        torch_labels = torch.tensor(np_labels, dtype=torch.int64)\n        torch_output = torch.nn.functional.cross_entropy(\n            torch_logits, torch_labels, reduction=\"none\"\n        )\n        torch_output.sum().backward()\n\n    class MyModule(flow.nn.Graph):\n        def __init__(self):\n            super(MyModule, self).__init__()\n\n        def build(self, logits, labels):\n            output = flow.nn.functional.sparse_softmax_cross_entropy(\n                labels=labels, logits=logits\n            )\n            # nn.Graph does not support retrieving input.grad, so the backward pass is not checked here\n            # output.sum().backward()\n            return output\n\n    of_logits = flow.tensor(\n        np_logits, device=device_type, dtype=data_type, requires_grad=True\n    )\n    flow.comm.broadcast(of_logits, 0)\n    of_logits = of_logits.to_global(placement=placement, sbp=[flow.sbp.broadcast])\n    of_logits.retain_grad()\n    global_of_logits = of_logits.to_global(placement=placement, sbp=[flow.sbp.split(1)])\n    of_labels = flow.tensor(np_labels, device=device_type, dtype=label_type)\n    flow.comm.broadcast(of_labels, 0)\n    of_labels = of_labels.to_global(placement=placement, sbp=[flow.sbp.broadcast])\n    graph = MyModule()\n    of_output = graph(global_of_logits, of_labels)\n    of_output = of_output.to_global(placement=placement, sbp=[flow.sbp.broadcast])\n    of_output = of_output.to_local()\n\n    flow._oneflow_internal.eager.Sync()\n\n    if rank == 0:\n        assert np.allclose(\n            of_output.numpy(), torch_output.detach().numpy(), rtol=1e-03, atol=1e-04\n        )\n\n\nclass TestSparseSoftmaxCrossEntropyWithLogits(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_sparse_softmax_cross_entropy(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_type\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"data_type\"] = [\"float32\", \"double\"]\n        arg_dict[\"label_type\"] = [\"int32\", \"int64\"]\n        arg_dict[\"batch_size\"] = [64, 16]\n        arg_dict[\"num_classes\"] = [100, 1000]\n        for arg in GenArgList(arg_dict):\n            compare_with_torch(*arg)\n\n\nclass TestSparseSoftmaxCrossEntropyMsWithLogits(flow.unittest.TestCase):\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n4d()\n    def test_distributed_sparse_softmax_cross_entropy(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device_type\"] = [\"cuda\"]\n        arg_dict[\"data_type\"] = [\"float32\", \"double\"]\n        arg_dict[\"label_type\"] = [\"int32\", \"int64\"]\n        arg_dict[\"batch_size\"] = [64]\n        arg_dict[\"num_classes\"] = [1000]\n        for arg in GenArgList(arg_dict):\n            # compare_eager_global_with_torch(*arg)\n            compare_eager_2d_global_with_torch(*arg)\n            compare_lazy_global_with_torch(*arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_special_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nimport torch as torch_original\nfrom packaging import version\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSpecialOps(flow.unittest.TestCase):\n    @autotest(n=5, auto_backward=\"auto\")\n    def test_flow_erf_with_random_data(test_case):\n        device = random_device()\n        x_dtype = random_dtype([\"arithmetic\"])\n        x = random_tensor().to(device).to(x_dtype)\n        y = torch.special.erf(x)\n        return y\n\n    @autotest(n=5, auto_backward=\"auto\")\n    def test_flow_erfc_with_random_data(test_case):\n        device = random_device()\n        x_dtype = random_dtype([\"arithmetic\"])\n        x = random_tensor().to(device).to(x_dtype)\n        y = torch.special.erfc(x)\n        return y\n\n    @autotest(n=5, auto_backward=\"auto\")\n    def test_flow_erfinv_with_random_data(test_case):\n        device = random_device()\n        x_dtype = random_dtype([\"float\"])\n        x = random_tensor(requires_grad=False).to(device).to(x_dtype)\n        y = torch.special.erfinv(x)\n        return y\n\n    @autotest(n=5, auto_backward=\"auto\")\n    def test_flow_exp2_with_random_data(test_case):\n        device = random_device()\n        x_dtype = random_dtype([\"arithmetic\"])\n        x = random_tensor().to(device).to(x_dtype)\n        y = torch.special.exp2(x)\n        return y\n\n    @autotest(n=5, auto_backward=\"auto\")\n    def test_flow_expm1_with_random_data(test_case):\n        device = random_device()\n        x_dtype = random_dtype([\"arithmetic\"])\n        x = random_tensor().to(device).to(x_dtype)\n        y = torch.special.expm1(x)\n        return y\n\n    @autotest(n=5, auto_backward=\"auto\")\n    def test_flow_round_with_random_data(test_case):\n        device = random_device()\n        x_dtype = random_dtype([\"arithmetic\"])\n        x = random_tensor().to(device).to(x_dtype)\n        y = torch.special.round(x)\n\n    @autotest(n=5, auto_backward=\"auto\")\n    def test_flow_log1p_with_random_data(test_case):\n        device = random_device()\n        x_dtype = random_dtype([\"arithmetic\"])\n        x = random_tensor().to(device).to(x_dtype)\n        y = torch.special.log1p(x)\n        return y\n\n    @autotest(n=5, auto_backward=\"auto\")\n    def test_flow_log_softmax_with_random_data(test_case):\n        num_dims = random(low=1, high=5).to(int)\n        device = random_device()\n        x = random_tensor(ndim=num_dims).to(device)\n        y = torch.special.log_softmax(x, dim=random(low=0, high=num_dims).to(int))\n        return y\n\n    @unittest.skipIf(\n        version.parse(torch_original.__version__) <= version.parse(\"1.13.0\"),\n        \"module 'torch.special' has no attribute 'softmax' 
before '1.13.0'\",\n    )\n    @autotest(n=5, auto_backward=\"auto\")\n    def test_flow_softmax_with_random_data(test_case):\n        num_dims = random(low=1, high=5).to(int)\n        device = random_device()\n        x = random_tensor(ndim=num_dims).to(device)\n        y = torch.special.softmax(x, dim=random(low=0, high=num_dims).to(int))\n        return y\n\n    @autotest(n=5, auto_backward=\"auto\")\n    def test_flow_logsumexp_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(4, random(0, 5), 2).to(device)\n        y = torch.special.logsumexp(x, dim=np.random.randint(0, 3))\n        return y\n\n    @autotest(n=5, auto_backward=\"auto\")\n    def test_flow_digamma_with_random_data(test_case):\n        device = random_device()\n        x_dtype = random_dtype([\"arithmetic\", \"half\"])\n        x = random_tensor().to(device).to(x_dtype)\n        y = torch.special.digamma(x)\n        return y\n\n    @autotest(n=5, auto_backward=\"auto\")\n    def test_flow_psi_with_random_data(test_case):\n        device = random_device()\n        x_dtype = random_dtype([\"arithmetic\", \"half\"])\n        x = random_tensor().to(device).to(x_dtype)\n        y = torch.special.psi(x)\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestZeta(flow.unittest.TestCase):\n    # the gradient of zeta is not supported, so backward is disabled\n    @autotest(n=5, auto_backward=False)\n    def test_flow_zeta_with_random_data(test_case):\n        device = random_device()\n        x_dtype = random_dtype([\"arithmetic\"])\n        input = (\n            random_tensor(ndim=2, dim0=20, dim1=20, low=1, high=10)\n            .to(device)\n            .to(x_dtype)\n        )\n        other = (\n            random_tensor(ndim=2, dim0=20, dim1=20, low=1, high=10)\n            .to(device)\n            .to(x_dtype)\n        )\n        out = torch.special.zeta(input, other)\n        return out\n\n    @autotest(n=5, auto_backward=False)\n    def test_flow_zeta_broadcast_input(test_case):\n        device = random_device()\n        x_dtype = random_dtype([\"arithmetic\"])\n        input = random_tensor(ndim=2, dim0=1, dim1=20).to(device).to(x_dtype)\n        other = random_tensor(ndim=2, dim0=20, dim1=20).to(device).to(x_dtype)\n        out = torch.special.zeta(input, other)\n        return out\n\n    @autotest(n=5, auto_backward=False)\n    def test_flow_zeta_broadcast_other(test_case):\n        device = random_device()\n        x_dtype = random_dtype([\"arithmetic\"])\n        input = random_tensor(ndim=2, dim0=20, dim1=20).to(device).to(x_dtype)\n        other = random_tensor(ndim=2, dim0=1, dim1=20).to(device).to(x_dtype)\n        out = torch.special.zeta(input, other)\n        return out\n\n    @autotest(n=5, auto_backward=False)\n    def test_flow_zeta_scalar_input(test_case):\n        device = random_device()\n        x_dtype = random_dtype([\"arithmetic\"])\n        input = random_tensor(ndim=2, dim0=2, dim1=20).to(device).to(x_dtype)\n        out = torch.special.zeta(0.5, input)\n        return out\n\n    @autotest(n=5, auto_backward=False)\n    def test_flow_zeta_scalar_other(test_case):\n        device = random_device()\n        x_dtype = random_dtype([\"arithmetic\"])\n        input = random_tensor(ndim=2, dim0=2, dim1=20).to(device).to(x_dtype)\n        out = torch.special.zeta(input, 0.5)\n        return out\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_split.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nfrom random import shuffle\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSplit(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_flow_split_with_random_data(test_case):\n        k0 = random(2, 6)\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        rand_dim = random(0, 3).to(int)\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device)\n        res = torch.split(x, 2, dim=rand_dim)\n        return torch.cat(res, rand_dim)\n\n    @autotest(n=5, check_graph=True)\n    def test_flow_split_with_stride(test_case):\n        k0 = random(2, 6)\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        rand_dim = random(0, 3).to(int)\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device)\n        perm = [0, 1, 2]\n        shuffle(perm)\n        y = x.permute(perm)\n        z = torch.split(y, 2, dim=rand_dim)\n        return torch.cat(z, rand_dim)\n\n    @autotest(n=5)\n    def test_flow_split_sizes_with_random_data(test_case):\n        k0 = random(2, 6)\n        k1 = 7\n        k2 = random(2, 6)\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device)\n        res = torch.split(x, [1, 2, 3, 1], dim=1)\n        return torch.cat(res, dim=1)\n\n    @autotest(n=5)\n    def test_flow_split_sizes_neg_dim_with_random_data(test_case):\n        k0 = random(2, 6)\n        k1 = 7\n        k2 = random(2, 6)\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device)\n        res = torch.split(x, [1, 2, 3, 1], dim=-2)\n        return torch.cat(res, dim=1)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_split_bool_with_random_data(test_case):\n        k0 = random(2, 6)\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        rand_dim = random(0, 3).to(int)\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(\n            device=device, dtype=torch.bool\n        )\n        res = torch.split(x, split_size_or_sections=2, dim=rand_dim)\n        return torch.cat(res, rand_dim)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_square_relu.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\nimport torch\n\n\nclass SquareReLUActivation(torch.nn.Module):\n    \"\"\"\n    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2\n    \"\"\"\n\n    def forward(self, input):\n        relu_applied = torch.nn.functional.relu(input)\n        squared = torch.square(relu_applied)\n        return squared\n\n\ndef _test_square_relu(test_case, device):\n    torch_square_relu = SquareReLUActivation()\n    x = np.random.randn(2, 4, 3)\n    torch_x = torch.tensor(x, requires_grad=True, device=torch.device(device))\n    oneflow_x = flow.tensor(x, requires_grad=True, device=flow.device(device))\n    torch_y = torch_square_relu(torch_x)\n    oneflow_y = flow._C.square_relu(oneflow_x)\n    test_case.assertTrue(np.allclose(torch_y.detach().cpu().numpy(), oneflow_y.numpy()))\n    torch_y_sum = torch_y.sum()\n    torch_y_sum.backward()\n    oneflow_y_sum = oneflow_y.sum()\n    oneflow_y_sum.backward()\n    test_case.assertTrue(\n        np.allclose(torch_x.grad.cpu().numpy(), oneflow_x.grad.numpy())\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModule(flow.unittest.TestCase):\n    def test_square_relu(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_square_relu]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_squeeze.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_squeeze(test_case, device):\n    np_arr = np.random.rand(1, 1, 1, 3)\n    input = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    of_shape = flow.squeeze(input, dim=[1, 2]).numpy().shape\n    np_shape = (1, 3)\n    test_case.assertTrue(np.array_equal(of_shape, np_shape))\n    test_case.assertTrue(\n        np.allclose(\n            flow.squeeze(input, dim=[1, 2]).numpy(),\n            np.squeeze(input.numpy(), axis=(1, 2)),\n            0.0001,\n            0.0001,\n        )\n    )\n\n\ndef _test_squeeze_1d_input(test_case, device):\n    np_arr = np.random.rand(10)\n    input = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    output = flow.squeeze(input)\n    test_case.assertTrue(np.allclose(output.numpy(), np_arr, 1e-05, 1e-05))\n\n\ndef _test_tensor_squeeze(test_case, device):\n    np_arr = np.random.rand(1, 1, 1, 3)\n    input = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    of_shape = input.squeeze(dim=[1, 2]).numpy().shape\n    np_shape = (1, 3)\n    test_case.assertTrue(np.array_equal(of_shape, np_shape))\n    test_case.assertTrue(\n        np.allclose(\n            input.squeeze(dim=[1, 2]).numpy(),\n            np.squeeze(input.numpy(), axis=(1, 2)),\n            0.0001,\n            0.0001,\n        )\n    )\n\n\ndef _test_squeeze_int(test_case, device):\n    np_arr = np.random.rand(1, 1, 1, 3)\n    input = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    of_shape = flow.squeeze(input, 1).numpy().shape\n    np_shape = (1, 1, 3)\n    test_case.assertTrue(np.array_equal(of_shape, np_shape))\n    test_case.assertTrue(\n        np.allclose(\n            input.squeeze(1).numpy(), np.squeeze(input.numpy(), axis=1), 0.0001, 0.0001\n        )\n    )\n\n\ndef _test_squeeze_backward(test_case, device):\n    np_arr = np.random.rand(1, 1, 1, 3)\n    input = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    y = flow.squeeze(input, dim=1).sum()\n    y.backward()\n    np_grad = np.ones((1, 1, 1, 3))\n    test_case.assertTrue(np.array_equal(input.grad.numpy(), np_grad))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSqueeze(flow.unittest.TestCase):\n    def test_squeeze(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_squeeze,\n            _test_squeeze_1d_input,\n            _test_squeeze_int,\n            _test_tensor_squeeze,\n            _test_squeeze_backward,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    
@autotest(check_graph=True)\n    def test_flow_squeeze_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.squeeze(x, random(1, 3).to(int))\n        return y\n\n    @autotest(n=10, check_graph=False, auto_backward=False)\n    def test_inplace_squeeze_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(requires_grad=False).to(device)\n        y = x.squeeze_(random(1, 3).to(int))\n        return y\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_squeeze_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(3, 2, 1, 0).to(device)\n        y = torch.squeeze(x)\n        return y\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_flow_squeeze_bool_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device=device, dtype=torch.bool)\n        y = torch.squeeze(x, random(1, 3).to(int))\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_stack.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestStackModule(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_stack_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dim1=3, dim2=4, dim3=5).to(device)\n        y = random_tensor(ndim=4, dim1=3, dim2=4, dim3=5).to(device)\n        out = torch.stack((x, y), dim=random(low=-5, high=5).to(int))\n        return out\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_stack_bool_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dim1=3, dim2=4, dim3=5).to(\n            device=device, dtype=torch.bool\n        )\n        y = random_tensor(ndim=4, dim1=3, dim2=4, dim3=5).to(\n            device=device, dtype=torch.bool\n        )\n        out = torch.stack((x, y), dim=random(low=1, high=4).to(int))\n        return out\n\n    @autotest(check_graph=True)\n    def test_column_stack_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=10).to(device)\n        y = random_tensor(ndim=2, dim0=10, dim1=5).to(device)\n        z = random_tensor(ndim=2, dim0=10, dim1=5).to(device)\n        out = torch.column_stack((x, y, z))\n        return out\n\n    def test_column_stack_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = random_tensor(ndim=1, dim0=1).to(device)\n        out = torch.column_stack((x, y))\n        return out\n\n    @autotest(check_graph=True)\n    def test_row_stack_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=10).to(device)\n        y = random_tensor(ndim=2, dim0=5, dim1=10).to(device)\n        z = random_tensor(ndim=2, dim0=5, dim1=10).to(device)\n        out = torch.row_stack((x, y, z))\n        return out\n\n    def test_row_stack_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = random_tensor(ndim=1, dim0=1).to(device)\n        out = torch.row_stack((x, y))\n        return out\n\n    @autotest(check_graph=True)\n    def test_hstack_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=5).to(device)\n        y = random_tensor(ndim=1, dim0=5).to(device)\n        out = torch.hstack((x, y))\n        return out\n\n    @autotest(check_graph=True)\n    def test_hstack_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = random_tensor(ndim=0).to(device)\n        # test 1-dim simultaneouslsimultaneouslyy\n        z = random_tensor(ndim=1, dim0=1).to(device)\n        out = torch.hstack((x, y, z))\n        return 
out\n\n    @autotest(check_graph=True)\n    def test_vstack_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim0=3, dim1=4).to(device)\n        y = random_tensor(ndim=1, dim0=4).to(device)\n        z = random_tensor(ndim=2, dim0=3, dim1=4).to(device)\n        out = torch.vstack((x, y, z))\n        return out\n\n    @autotest(check_graph=True)\n    def test_vstack_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = random_tensor(ndim=0).to(device)\n        out = torch.vstack((x, y))\n        return out\n\n    @autotest(check_graph=True)\n    def test_dstack_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim0=1, dim1=4).to(device)\n        y = random_tensor(ndim=3, dim0=1, dim1=4, dim2=1).to(device)\n        z = random_tensor(ndim=1, dim0=4).to(device)\n        out = torch.dstack((x, y, z))\n        return out\n\n    @autotest(check_graph=True)\n    def test_dstack_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = random_tensor(ndim=0).to(device)\n        z = random_tensor(ndim=0).to(device)\n        out = torch.dstack((x, y, z))\n        return out\n\n    @autotest(auto_backward=True, check_graph=True)\n    def test_stack_kMaxInputCount_inputs(test_case):\n        # stack one more tensor than the assumed per-kernel input limit (128) so\n        # that the input-splitting path is also covered\n        kMaxInputCount = 128 + 1\n        stack_list = [\n            random_tensor(ndim=2, dim0=3, dim1=4) for _ in range(kMaxInputCount)\n        ]\n        out = torch.stack(stack_list, 0)\n        return out\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_stateful_kernel_with_cache.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport os\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n2d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestStatefulKernelWithInpersistentState(flow.unittest.TestCase):\n    def test_stateful_kernel_with_inpersistent_state(test_case):\n        x = flow.arange(4).reshape(2, 2)\n        x = x.to_global(flow.placement.all(\"cuda\"), flow.sbp.split(0))\n        y = x[0:3, 0:1]\n        y_np = np.array([[0], [2], [0]])\n        test_case.assertTrue(\n            np.array_equal(y.to_global(sbp=flow.sbp.broadcast).to_local().numpy(), y_np)\n        )\n        x = x.to_global(sbp=flow.sbp.split(1))\n        y = x[0:3, 0:1]\n        test_case.assertTrue(\n            np.array_equal(y.to_global(sbp=flow.sbp.broadcast).to_local().numpy(), y_np)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_stateful_local_opkernel.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nimport numpy as np\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestStatefulLocalKernel(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_dynamic_attrs(test_case):\n        x = flow.full((2, 3), 3.0)\n        y = flow.unsqueeze(x, dim=1)\n        test_case.assertEqual(y.shape, flow.Size((2, 1, 3)))\n        y = flow.unsqueeze(x, dim=2)\n        test_case.assertEqual(y.shape, flow.Size((2, 3, 1)))\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_stateful_local_kernel_in_global_mode(test_case):\n        rank = int(os.getenv(\"RANK\"))\n\n        x = flow.tensor(np.array([1, 2]) * (rank + 1)).to(\"cuda\")\n        x = x.to_global(flow.placement(\"cuda\", range(2)), flow.sbp.split(0))\n\n        y = flow.tensor([3, 4, 5]).to(\"cuda\")\n        y = y.to_global(flow.placement(\"cuda\", range(2)), flow.sbp.broadcast)\n\n        # logical slice assign op needs sbp and logical shape from stateful local opkernel\n        x[:3] = y\n\n        x = x.to_global(sbp=flow.sbp.broadcast)\n\n        test_case.assertTrue(\n            np.array_equal(x.to_local().numpy(), np.array([3, 4, 5, 4]))\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_std.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestStd(flow.unittest.TestCase):\n    @autotest(n=10, auto_backward=False, rtol=0.01, atol=0.01, check_graph=True)\n    def test_std_flow_with_random_data(test_case):\n        device = random_device()\n        all_dim = random().to(int)\n        dim = random(low=0, high=6).to(int)\n        x = random_tensor(ndim=all_dim, low=2, high=6).to(device)\n        z = torch.std(\n            x, dim=dim, unbiased=random().to(bool), keepdim=random().to(bool),\n        )\n        return z\n\n    @autotest(n=10, auto_backward=False, rtol=0.01, atol=0.01, check_graph=True)\n    def test_std_tensor_with_random_data(test_case):\n        device = random_device()\n        dim = random(low=0, high=4).to(int)\n        x = random_tensor(\n            ndim=4,\n            dim0=random(2, 4),\n            dim1=random(2, 4),\n            dim2=random(2, 4),\n            dim3=random(2, 4),\n        ).to(device)\n        z = x.std(dim=dim, keepdim=random().to(bool),)\n        return z\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_stft.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom numpy import random\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport re\n\nimport oneflow as flow\nfrom oneflow.test_utils.test_util import GenArgList\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef getRandBoolvalue():\n    value = np.random.randint(0, 2)\n    if value == 1:\n        return True\n    else:\n        return False\n\n\ndef getRandFFtvalue():\n    pow = np.random.randint(2, 5)\n    result = 1\n    for i in range(pow):\n        result = result * 2\n    return result\n\n\ndef is_cufft_available():\n    if flow.cuda.is_available():\n        (major, _minor) = flow.cuda.get_device_capability()\n        return major >= 7\n    else:\n        return False\n\n\nclass TestStft(flow.unittest.TestCase):\n    @autotest(\n        n=20, check_graph=False, check_grad_use_random_data=False, auto_backward=False,\n    )\n    def test_stft_with_1D_random_data(test_case):\n        if is_cufft_available():\n            device = random_device()\n        else:\n            device = cpu_device()\n        rand_fft = getRandFFtvalue()\n        rand_size = np.random.randint(rand_fft, 300)\n        input_dims = [rand_size]\n        win_dims = [rand_fft]\n        x = random_tensor(1, *input_dims).to(device)\n        win = random_tensor(1, *win_dims).to(device)\n        onesided_value = getRandBoolvalue()\n        center_value = getRandBoolvalue()\n        normalized_value = getRandBoolvalue()\n        y = torch.stft(\n            x,\n            n_fft=rand_fft,\n            window=win,\n            return_complex=False,\n            onesided=onesided_value,\n            center=center_value,\n            normalized=normalized_value,\n        )\n        return y\n\n    def test_stft_with_2D_random_data(test_case):\n        if is_cufft_available():\n            device = random_device()\n        else:\n            device = cpu_device()\n        row_rand_size = np.random.randint(1, 50)\n        rand_fft = getRandFFtvalue()\n        col_rand_size = np.random.randint(rand_fft, 300)\n        input_dims = [row_rand_size, col_rand_size]\n        win_dims = [rand_fft]\n        x = random_tensor(2, *input_dims).to(device)\n        win = random_tensor(1, *win_dims).to(device)\n        onesided_value = getRandBoolvalue()\n        center_value = getRandBoolvalue()\n        normalized_value = getRandBoolvalue()\n        y = torch.stft(\n            x,\n            n_fft=rand_fft,\n            window=win,\n            return_complex=False,\n            onesided=onesided_value,\n            center=center_value,\n            normalized=normalized_value,\n        )\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_sub.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch as torch_original\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_sub_impl(test_case, shape, device):\n    x = flow.tensor(\n        np.random.randn(*shape),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    y = flow.tensor(\n        np.random.randn(*shape),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out = flow.sub(x, y)\n    np_out = np.subtract(x.numpy(), y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad_x = np.ones(shape)\n    np_grad_y = -np.ones(shape)\n    test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad_x, 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(y.grad.numpy(), np_grad_y, 1e-05, 1e-05))\n    x = 5\n    y = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.sub(x, y)\n    np_out = np.subtract(x, y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    x = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    y = 5\n    of_out = flow.sub(x, y)\n    np_out = np.subtract(x.numpy(), y)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    x = flow.tensor(\n        np.random.randn(*shape), dtype=flow.float32, device=flow.device(device)\n    )\n    y = flow.tensor(\n        np.random.randn(1, 1), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.sub(x, y)\n    np_out = np.subtract(x.numpy(), y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    x = flow.tensor(np.array([5.0]), dtype=flow.float32)\n    y = flow.tensor(np.random.randn(1, 1), dtype=flow.float32)\n    of_out = flow.sub(x, y)\n    np_out = np.subtract(x.numpy(), y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    x = flow.tensor(np.random.randn(1, 1), dtype=flow.float32, requires_grad=True)\n    y = flow.tensor(np.array([5.0]), dtype=flow.float32, requires_grad=True)\n    of_out = flow.sub(x, y)\n    np_out = np.subtract(x.numpy(), y.numpy())\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad_x = np.ones((1, 1))\n    np_grad_y = -np.ones(1)\n    test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad_x, 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(y.grad.numpy(), np_grad_y, 1e-05, 
1e-05))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSubModule(flow.unittest.TestCase):\n    def test_sub(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_sub_impl(test_case, *arg)\n\n    @autotest(n=5, auto_backward=False, check_graph=True, include_complex=True)\n    def test_random_dim_sub(test_case):\n        device = random_device()\n        dim0 = random(low=1, high=4).to(int)\n        dim1 = random(low=1, high=4).to(int)\n        x = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device)\n        y = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device)\n        z = x - y\n        return z\n\n    @autotest(n=5, auto_backward=False, check_graph=True, include_complex=True)\n    def test_random_dim_scalar_sub(test_case):\n        device = random_device()\n        dim0 = random(low=1, high=4).to(int)\n        dim1 = random(low=1, high=4).to(int)\n        x = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device)\n        y = random_tensor(ndim=0).to(device)\n        z = x - y\n        return z\n\n    @autotest(n=5, auto_backward=False, check_graph=True, include_complex=True)\n    def test_sub_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(2, 0, 3).to(device)\n        y = random_tensor(2, 1, 3).to(device)\n        out1 = x - y\n        out2 = x - 2\n        out3 = 2 - x\n        out4 = torch.sub(x, y)\n        return out1, out2, out3, out4\n\n    @autotest(n=5, auto_backward=False, check_graph=True, include_complex=True)\n    def test_sub_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = random_tensor(ndim=0).to(device)\n        out1 = x - y\n        out2 = x - 2\n        out3 = 2 - x\n        out4 = torch.sub(x, y)\n        return out1, out2, out3, out4\n\n    @autotest(n=5, include_complex=True)\n    def test_sub_with_alpha(test_case):\n        device = random_device()\n        x1 = random_tensor(2, 2, 3).to(device)\n        x2 = random_tensor(2, 2, 3).to(device)\n        x3 = random_tensor(2, 2, 3).to(device)\n        y = random_tensor(2, 2, 3).to(device)\n        s = random().to(float)\n        alpha = random().to(float)\n        z1 = torch.sub(x1, y, alpha=alpha)\n        z2 = torch.sub(x2, s, alpha=alpha)\n        z3 = torch.sub(s, x3, alpha=alpha)\n        return z1, z2, z3\n\n    @autotest(n=5, include_complex=True)\n    def test_non_contiguous_inplace_sub(test_case):\n        device = random_device()\n        x = random_tensor(2, 2, 4).to(device)\n        y = x + 1\n        y = y[:, 1:3]\n        y -= random_tensor(2, 2, 2).to(device)\n        return y\n\n    @unittest.skip(\"skip for now, because it failed 2 times in the past week\")\n    @autotest(n=5, include_complex=True)\n    def test_scalar_sub_with_random_devices(test_case):\n        x1_device = random_device()\n        x2_device = random_device()\n        x1 = random_tensor(2, 2, 3).to(x1_device).mean()\n        x2 = random_tensor(2, 2, 3).to(x2_device)\n        y = x1 - x2\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_sum.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_sum_impl(test_case, device, data_type):\n    if device == \"cpu\" and data_type == flow.float16:\n        return\n    input = flow.tensor(\n        np.random.randn(2, 3) - 0.5, dtype=data_type, device=flow.device(device)\n    )\n    of_out = flow.sum(input, dim=0)\n    np_out = np.sum(input.numpy(), axis=0)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    input = flow.tensor(\n        np.random.randn(2, 3), dtype=data_type, device=flow.device(device)\n    )\n    of_out = flow.sum(input, dim=0)\n    np_out = np.sum(input.numpy(), axis=0)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    input = flow.tensor(\n        np.random.randn(2, 3), dtype=data_type, device=flow.device(device)\n    )\n    of_out = flow.sum(input, dim=1)\n    of_out2 = input.sum(dim=1)\n    np_out = np.sum(input.numpy(), axis=1)\n    test_case.assertTrue(np.allclose(of_out2.numpy(), of_out.numpy(), 1e-05, 1e-05))\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    input = flow.tensor(\n        np.random.randn(4, 5, 6) - 0.5,\n        dtype=data_type,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out = flow.sum(input, dim=(2, 1))\n    np_out = np.sum(input.numpy(), axis=(2, 1))\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = np.ones((4, 5, 6))\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n    # For 0-dim tensor test\n    input = flow.tensor(1.0)\n    of_out = input.sum()\n    test_case.assertTrue(np.allclose(input.numpy(), of_out.numpy(), 1e-05, 1e-05))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSumModule(flow.unittest.TestCase):\n    def test_sum(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"data_type\"] = [flow.float16, flow.float32]\n        for arg in GenArgList(arg_dict):\n            _test_sum_impl(test_case, *arg)\n\n    @autotest(check_graph=True, include_complex=True)\n    def test_sum_against_pytorch(test_case):\n        device = random_device()\n        x = random_tensor(4, random(0, 5), 2).to(device)\n        y = torch.sum(x)\n        return y\n\n    @autotest(check_graph=True, auto_backward=False)\n    def test_sum_dtype(test_case):\n        device = random_device()\n        x = random_tensor(4, requires_grad=False).to(device)\n        y = torch.sum(\n            x,\n            dim=np.random.randint(0, 3),\n            keepdim=random_bool(),\n            dtype=random_dtype([\"arithmetic\"]),\n        )\n   
     return y\n\n    @autotest(\n        n=10,\n        check_graph=False,\n        auto_backward=True,\n        include_complex=True,\n        atol=1e-2,\n        rtol=1e-5,\n    )\n    def test_sum_complex_dtype(test_case):\n        device = random_device()\n        x = random_tensor(4, dtype=complex, requires_grad=True).to(\n            device=device, dtype=random_dtype([\"complex\"])\n        )\n        y = torch.sum(\n            x,\n            dim=np.random.randint(0, 3),\n            keepdim=random_bool(),\n            dtype=random_dtype([\"complex\"]),\n        )\n        return y\n\n    @autotest(check_graph=True, auto_backward=False)\n    def test_sum_arithmetic_dtype(test_case):\n        device = random_device()\n        x = random_tensor(4, requires_grad=False).to(device)\n        y = torch.sum(x, dtype=random_dtype([\"arithmetic\"]))\n        return y\n\n    @autotest(auto_backward=False, check_graph=True, include_complex=True)\n    def test_sum_with_0_size_tensor(test_case):\n        device = random_device()\n        x = random_tensor(4, 4, 3, 0, 2).to(device)\n        y = torch.sum(x, dim=np.random.randint(0, 3))\n        return y\n\n    @autotest(auto_backward=False, check_graph=True, include_complex=True)\n    def test_sum_with_0dim_tensor(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.sum(x)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_swapaxes.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom random import shuffle\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSwapaxes(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_swapaxes_flow_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3).to(device)\n        y = torch.swapaxes(x, random(0, 2).to(int), random(0, 2).to(int))\n        return y\n\n    @autotest(n=10)\n    def test_swapaxes_flow_with_stride(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3).to(device)\n        perm = [0, 1, 2]\n        shuffle(perm)\n        y = x.permute(perm)\n        z = torch.swapaxes(y, random(0, 2).to(int), random(0, 2).to(int))\n        return z\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_swapdims.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass Testswapdims(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_swapdims_flow_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3).to(device)\n        y = torch.swapdims(x, np.random.randint(0, 3), np.random.randint(0, 3))\n        return y\n\n    @autotest(check_graph=True)\n    def test_swapdims_flow_with_random_data2(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = torch.swapdims(x, np.random.randint(0, 4), np.random.randint(0, 4))\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_swautils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nr\"\"\"\nThis test module references to pytorch.\nhttps://github.com/pytorch/pytorch/blob/master/test/test_optim.py.\n\"\"\"\nimport math\nimport unittest\nimport itertools\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.optim as optim\nimport oneflow.nn.functional as F\nfrom oneflow.nn import Parameter\nfrom oneflow.optim import SGD, Optimizer\nfrom oneflow.nn.optimizer.lr_scheduler import LRScheduler\nfrom oneflow.nn.optimizer.multiplicative_lr import MultiplicativeLR\nfrom oneflow.nn.optimizer.swa_utils import AveragedModel, SWALR, update_bn\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestLRScheduler(flow.unittest.TestCase):\n    # This class mainly used to test MultiplicativeLR and SWALR\n    def setUp(self):\n        super(TestLRScheduler, self).setUp()\n        self.net = SchedulerTestNet()\n        self.opt = SGD(\n            [\n                {\"params\": self.net.conv1.parameters()},\n                {\"params\": self.net.conv2.parameters(), \"lr\": 0.5},\n            ],\n            lr=0.05,\n        )\n\n    def test_multiplicative_lr(self):\n        # test Multiplicative lr\n        epochs = 10\n        self.opt.param_groups[0][\"lr\"] = 0.05\n        self.opt.param_groups[1][\"lr\"] = 0.4\n        targets = [\n            [0.05 * (0.9 ** x) for x in range(epochs)],\n            [0.4 * (0.8 ** x) for x in range(epochs)],\n        ]\n        scheduler = MultiplicativeLR(\n            self.opt, lr_lambda=[lambda x1: 0.9, lambda x2: 0.8]\n        )\n        self._test(scheduler, targets, epochs)\n\n    def _test(self, schedulers, targets, epochs=10):\n        if isinstance(schedulers, LRScheduler):\n            schedulers = [schedulers]\n        for epoch in range(epochs):\n            for param_group, target in zip(self.opt.param_groups, targets):\n                self.assertTrue(\n                    np.allclose(\n                        target[epoch], param_group[\"lr\"], atol=1e-6, rtol=1e-5,\n                    ),\n                    msg=\"LR is wrong in epoch {}: expected {}, got {}\".format(\n                        epoch, target[epoch], param_group[\"lr\"]\n                    ),\n                )\n            [scheduler.step() for scheduler in schedulers]\n\n    def test_swa_lr_state_dict(self):\n        self._check_scheduler_state_dict(\n            lambda: SWALR(self.opt, anneal_epochs=3, swa_lr=0.5),\n            lambda: SWALR(\n                self.opt, anneal_epochs=10, anneal_strategy=\"linear\", swa_lr=5.0\n            ),\n        )\n\n    def _check_scheduler_state_dict(self, constr, constr2, epochs=10):\n        scheduler = constr()\n        for _ in range(epochs):\n            scheduler.optimizer.step()\n            scheduler.step()\n        scheduler_copy = constr2()\n        scheduler_copy.load_state_dict(scheduler.state_dict())\n        for 
key in scheduler.__dict__.keys():\n            if key != \"optimizer\":\n                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])\n        self.assertEqual(scheduler.get_last_lr(), scheduler_copy.get_last_lr())\n\n    def test_swalr_no_anneal(self):\n        epochs, swa_start, swa_lr = 10, 5, 0.01\n        initial_lrs = [group[\"lr\"] for group in self.opt.param_groups]\n        targets = [\n            [lr] * (swa_start + 1) + [swa_lr] * (epochs - swa_start - 1)\n            for lr in initial_lrs\n        ]\n        swa_scheduler = SWALR(self.opt, anneal_epochs=1, swa_lr=swa_lr)\n        self._test_swalr(swa_scheduler, None, targets, swa_start, epochs)\n\n    def test_swalr_cosine_anneal_after_multiplicative(self):\n        # same swa_lr for different param_groups\n        epochs, swa_start, swa_lr, anneal_epochs = 15, 5, 0.01, 5\n        mult_factor = 0.9\n        scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor)\n        swa_scheduler = SWALR(self.opt, anneal_epochs=anneal_epochs, swa_lr=swa_lr)\n\n        def anneal_coef(t):\n            if t + 1 >= anneal_epochs:\n                return 0.0\n            return (1 + math.cos(math.pi * (t + 1) / anneal_epochs)) / 2\n\n        initial_lrs = [group[\"lr\"] for group in self.opt.param_groups]\n        targets_before_swa = [\n            [lr * mult_factor ** i for i in range(swa_start + 1)] for lr in initial_lrs\n        ]\n        swa_epochs = epochs - swa_start - 1\n        targets = [\n            lrs\n            + [\n                lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t))\n                for t in range(swa_epochs)\n            ]\n            for lrs in targets_before_swa\n        ]\n\n        self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs)\n\n    def _test_swalr(self, swa_scheduler, scheduler, targets, swa_start, epochs):\n        for epoch in range(epochs):\n            for param_group, target in zip(self.opt.param_groups, targets):\n                self.assertTrue(\n                    np.allclose(\n                        target[epoch], param_group[\"lr\"], atol=1e-6, rtol=1e-5,\n                    ),\n                    msg=\"LR is wrong in epoch {}: expected {}, got {}\".format(\n                        epoch, target[epoch], param_group[\"lr\"]\n                    ),\n                )\n            if epoch >= swa_start:\n                self.opt.step()\n                swa_scheduler.step()\n            elif scheduler is not None:\n                self.opt.step()\n                scheduler.step()\n\n    def test_swalr_hypers(self):\n        # Test that SWALR raises errors for incorrect hyper-parameters\n        with self.assertRaisesRegex(ValueError, \"anneal_strategy must\"):\n            swa_scheduler = SWALR(self.opt, anneal_strategy=\"exponential\", swa_lr=1.0)\n\n        with self.assertRaisesRegex(ValueError, \"anneal_epochs must\"):\n            swa_scheduler = SWALR(self.opt, anneal_epochs=-1, swa_lr=1.0)\n        with self.assertRaisesRegex(ValueError, \"anneal_epochs must\"):\n            swa_scheduler = SWALR(self.opt, anneal_epochs=1.7, swa_lr=1.0)\n        with self.assertRaisesRegex(ValueError, \"swa_lr must\"):\n            swa_scheduler = SWALR(self.opt, swa_lr=[1.0, 0.1, 0.01])\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestSWAUtils(flow.unittest.TestCase):\n    # This class mainly used to test AveragedModel and update_bn\n    def _test_averaged_model(self, net_device, swa_device):\n        # test the average of 
AveragedModel\n        dnn = flow.nn.Sequential(\n            flow.nn.Conv2d(1, 5, kernel_size=3),\n            flow.nn.ReLU(),\n            flow.nn.MaxPool2d(kernel_size=2),\n            flow.nn.BatchNorm2d(5, momentum=0.3),\n            flow.nn.Conv2d(5, 2, kernel_size=3),\n            flow.nn.ReLU(),\n            flow.nn.Linear(5, 5),\n            flow.nn.ReLU(),\n            flow.nn.Linear(5, 10),\n        ).to(net_device)\n\n        averaged_dnn = AveragedModel(dnn, device=swa_device)\n        averaged_params = [flow.zeros_like(param) for param in dnn.parameters()]\n        n_updates = 10\n        for i in range(n_updates):\n            for p, p_avg in zip(dnn.parameters(), averaged_params):\n                p.detach().add_(flow.randn_like(p))\n                p_avg += p.detach() / n_updates\n            averaged_dnn.update_parameters(dnn)\n\n        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):\n            self.assertTrue(\n                flow.allclose(p_avg.cpu(), p_swa.cpu(), atol=1e-5, rtol=1e-4)\n            )\n            # Check that AveragedModel is on the correct device\n            self.assertTrue(p_swa.device == swa_device)\n            self.assertTrue(p.device == net_device)\n        self.assertTrue(averaged_dnn.n_averaged.device == swa_device)\n\n    def test_averaged_model_all_devices(self):\n        cpu = flow.device(\"cpu\")\n        self._test_averaged_model(cpu, cpu)\n        if flow.cuda.is_available():\n            cuda = flow.device(\"cuda:0\")\n            self._test_averaged_model(cuda, cpu)\n            self._test_averaged_model(cpu, cuda)\n            self._test_averaged_model(cuda, cuda)\n\n    def test_averaged_model_mixed_device(self):\n        if not flow.cuda.is_available():\n            return\n        dnn = flow.nn.Sequential(\n            flow.nn.Conv2d(1, 5, kernel_size=3), flow.nn.Linear(5, 10)\n        )\n        dnn[0].cuda()\n        dnn[1].cpu()\n        averaged_dnn = AveragedModel(dnn)\n        averaged_params = [flow.zeros_like(param) for param in dnn.parameters()]\n        n_updates = 10\n        for i in range(n_updates):\n            for p, p_avg in zip(dnn.parameters(), averaged_params):\n                p.detach().add_(flow.randn_like(p))\n                p_avg += p.detach() / n_updates\n            averaged_dnn.update_parameters(dnn)\n\n        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):\n            self.assertTrue(flow.allclose(p_avg, p_swa, atol=1e-5, rtol=1e-4))\n            # Check that AveragedModel is on the correct device\n            self.assertTrue(p_avg.device == p_swa.device)\n\n    def test_averaged_model_state_dict(self):\n        dnn = flow.nn.Sequential(\n            flow.nn.Conv2d(1, 5, kernel_size=3), flow.nn.Linear(5, 10)\n        )\n        averaged_dnn = AveragedModel(dnn)\n        averaged_dnn2 = AveragedModel(dnn)\n        n_updates = 10\n        for i in range(n_updates):\n            for p in dnn.parameters():\n                p.detach().add_(flow.randn_like(p))\n            averaged_dnn.update_parameters(dnn)\n        averaged_dnn2.load_state_dict(averaged_dnn.state_dict())\n        for p_swa, p_swa2 in zip(averaged_dnn.parameters(), averaged_dnn2.parameters()):\n            self.assertTrue(flow.allclose(p_swa, p_swa2, atol=1e-5, rtol=1e-4))\n        self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged)\n\n    def 
test_averaged_model_exponential(self):\n        # Test AveragedModel with EMA as avg_fn\n        dnn = flow.nn.Sequential(\n            flow.nn.Conv2d(1, 5, kernel_size=3),\n            flow.nn.BatchNorm2d(5, momentum=0.3),\n            flow.nn.Linear(5, 10),\n        )\n        alpha = 0.9\n\n        def avg_fn(p_avg, p, n_avg):\n            return alpha * p_avg + (1 - alpha) * p\n\n        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn)\n        averaged_params = [flow.zeros_like(param) for param in dnn.parameters()]\n        n_updates = 10\n        for i in range(n_updates):\n            updated_averaged_params = []\n            for p, p_avg in zip(dnn.parameters(), averaged_params):\n                p.detach().add_(flow.randn_like(p))\n                if i == 0:\n                    updated_averaged_params.append(p.clone())\n                else:\n                    updated_averaged_params.append(\n                        (p_avg * alpha + p * (1 - alpha)).clone()\n                    )\n            for b in dnn.buffers():\n                if b.size() != flow.Size([]):\n                    # oneflow doesn't support detach_\n                    # b.detach_().add_(flow.randn_like(b))\n                    b.detach().add_(flow.randn_like(b))\n\n            averaged_dnn.update_parameters(dnn)\n            averaged_params = updated_averaged_params\n\n        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):\n            self.assertTrue(flow.allclose(p_avg, p_swa, atol=1e-5, rtol=1e-4))\n        for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()):\n            self.assertTrue(flow.allclose(b_avg, b_swa, atol=1e-5, rtol=1e-4))\n\n    def test_averaged_model_exponential_buffers(self):\n        # Test AveragedModel with EMA as avg_fn and use_buffers as True.\n        dnn = flow.nn.Sequential(\n            flow.nn.Conv2d(1, 5, kernel_size=3),\n            flow.nn.BatchNorm2d(5, momentum=0.3),\n            flow.nn.Linear(5, 10),\n        )\n        alpha = 0.9\n\n        def avg_fn(p_avg, p, n_avg):\n            return alpha * p_avg + (1 - alpha) * p\n\n        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=True)\n        # materialize the chain: an itertools.chain iterator would be exhausted\n        # after building averaged_params and could not be re-iterated below\n        dnn_params = list(itertools.chain(dnn.parameters(), dnn.buffers()))\n        averaged_params = [\n            flow.zeros_like(param)\n            for param in dnn_params\n            if param.size() != flow.Size([])\n        ]\n        n_updates = 10\n        for i in range(n_updates):\n            updated_averaged_params = []\n            for p, p_avg in zip(dnn_params, averaged_params):\n                if p.size() == flow.Size([]):\n                    continue\n                p.detach().add_(flow.randn_like(p))\n                if i == 0:\n                    updated_averaged_params.append(p.clone())\n                else:\n                    updated_averaged_params.append(\n                        (p_avg * alpha + p * (1 - alpha)).clone()\n                    )\n            averaged_dnn.update_parameters(dnn)\n            averaged_params = updated_averaged_params\n\n        for p_avg, p_swa in zip(\n            averaged_params,\n            itertools.chain(\n                averaged_dnn.module.parameters(), averaged_dnn.module.buffers()\n            ),\n        ):\n            self.assertTrue(flow.allclose(p_avg, p_swa, atol=1e-5, rtol=1e-4))\n\n    def _test_update_bn(self, dnn, dl_x, dl_xy, momentum, cuda):\n\n        preactivation_sum = flow.zeros(dnn.n_features)\n        preactivation_squared_sum = 
flow.zeros(dnn.n_features)\n        if cuda:\n            preactivation_sum = preactivation_sum.cuda()\n            preactivation_squared_sum = preactivation_squared_sum.cuda()\n        total_num = 0\n        for x in dl_x:\n            x = x[0]\n            if cuda:\n                x = x.cuda()\n\n            dnn.forward(x)\n            preactivations = dnn.compute_preactivation(x)\n            if len(preactivations.shape) == 4:\n                preactivations = preactivations.transpose(1, 3)\n            preactivations = preactivations.contiguous().view(-1, dnn.n_features)\n            total_num += preactivations.shape[0]\n\n            preactivation_sum += flow.sum(preactivations, dim=0)\n            preactivation_squared_sum += flow.sum(preactivations ** 2, dim=0)\n\n        preactivation_mean = preactivation_sum / total_num\n        preactivation_var = preactivation_squared_sum / total_num\n        preactivation_var = preactivation_var - preactivation_mean ** 2\n\n        update_bn(dl_xy, dnn, device=x.device)\n        self.assertTrue(\n            flow.allclose(preactivation_mean, dnn.bn.running_mean, atol=1e-6, rtol=1e-3)\n        )\n        self.assertTrue(\n            flow.allclose(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=1e-1)\n        )\n\n        def _reset_bn(module):\n            if issubclass(module.__class__, flow.nn.modules.batchnorm._BatchNorm):\n                module.running_mean = flow.zeros_like(module.running_mean)\n                module.running_var = flow.ones_like(module.running_var)\n\n        # reset batch norm and run update_bn again\n        dnn.apply(_reset_bn)\n        update_bn(dl_xy, dnn, device=x.device)\n        self.assertTrue(\n            flow.allclose(preactivation_mean, dnn.bn.running_mean, atol=1e-6, rtol=1e-3)\n        )\n        self.assertTrue(\n            flow.allclose(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=1e-1)\n        )\n        # using the dl_x loader instead of dl_xy\n        dnn.apply(_reset_bn)\n        update_bn(dl_x, dnn, device=x.device)\n        self.assertTrue(\n            flow.allclose(preactivation_mean, dnn.bn.running_mean, atol=1e-6, rtol=1e-3)\n        )\n        self.assertTrue(\n            flow.allclose(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=1e-1)\n        )\n\n    def test_update_bn_dnn(self):\n        # Test update_bn for a fully-connected network with BatchNorm1d\n        objects, input_features = 100, 5\n        x = flow.rand(objects, input_features)\n        y = flow.rand(objects)\n        ds_x = flow.utils.data.TensorDataset(x)\n        ds_xy = flow.utils.data.TensorDataset(x, y)\n        dl_x = flow.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)\n        dl_xy = flow.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)\n        dnn = SWATestDNN(input_features=input_features)\n        dnn.train()\n        self._test_update_bn(dnn, dl_x, dl_xy, 0.1, False)\n        if flow.cuda.is_available():\n            dnn = SWATestDNN(input_features=input_features)\n            dnn.train()\n            self._test_update_bn(dnn.cuda(), dl_x, dl_xy, 0.1, True)\n        self.assertTrue(dnn.training)\n\n    def test_update_bn_cnn(self):\n        # Test update_bn for convolutional network and BatchNorm2d\n        objects = 100\n        input_channels = 3\n        height, width = 5, 5\n        x = flow.rand(objects, input_channels, height, width)\n        y = flow.rand(objects)\n        ds_x = flow.utils.data.TensorDataset(x)\n        ds_xy = flow.utils.data.TensorDataset(x, 
y)\n        dl_x = flow.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)\n        dl_xy = flow.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)\n        dnn = SWATestCNN(input_channels=input_channels)\n        dnn.train()\n        self._test_update_bn(dnn, dl_x, dl_xy, 0.3, False)\n        if flow.cuda.is_available():\n            dnn = SWATestCNN(input_channels=input_channels)\n            dnn.train()\n            self._test_update_bn(dnn.cuda(), dl_x, dl_xy, 0.3, True)\n        self.assertTrue(dnn.training)\n\n    def test_bn_update_eval_momentum(self):\n        # check that update_bn preserves eval mode\n        objects = 100\n        input_channels = 3\n        height, width = 5, 5\n        x = flow.rand(objects, input_channels, height, width)\n        ds_x = flow.utils.data.TensorDataset(x)\n        dl_x = flow.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)\n        dnn = SWATestCNN(input_channels=input_channels)\n        dnn.eval()\n        update_bn(dl_x, dnn)\n        self.assertFalse(dnn.training)\n\n        # check that momentum is preserved\n        self.assertEqual(dnn.bn.momentum, 0.3)\n\n\nclass SWATestDNN(flow.nn.Module):\n    def __init__(self, input_features):\n        super(SWATestDNN, self).__init__()\n        self.n_features = 100\n        self.fc1 = flow.nn.Linear(input_features, self.n_features)\n        self.bn = flow.nn.BatchNorm1d(self.n_features)\n\n    def compute_preactivation(self, x):\n        return self.fc1(x)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.bn(x)\n        return x\n\n\nclass SWATestCNN(flow.nn.Module):\n    def __init__(self, input_channels):\n        super(SWATestCNN, self).__init__()\n        self.n_features = 10\n        self.conv1 = flow.nn.Conv2d(\n            input_channels, self.n_features, kernel_size=3, padding=1\n        )\n        self.bn = flow.nn.BatchNorm2d(self.n_features, momentum=0.3)\n\n    def compute_preactivation(self, x):\n        return self.conv1(x)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn(x)\n        return x\n\n\nclass SchedulerTestNet(flow.nn.Module):\n    def __init__(self):\n        super(SchedulerTestNet, self).__init__()\n        self.conv1 = flow.nn.Conv2d(1, 1, 1)\n        self.conv2 = flow.nn.Conv2d(1, 1, 1)\n\n    def forward(self, x):\n        return self.conv2(F.relu(self.conv1(x)))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_sync_and_async_allreduce.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport numpy as np\nimport os\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef sync_allreduce(x):\n    return x.to_global(sbp=flow.sbp.broadcast)\n\n\ndef async_allreduce(x):\n    return flow._C.local_all_reduce(x)\n\n\n@flow.unittest.skip_unless_1n4d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestP2bOnGPU(flow.unittest.TestCase):\n    def test_p2b(test_case):\n        placement = flow.placement(\"cuda\", range(4))\n        sync_x = flow.ones(\n            (128, 1024),\n            placement=placement,\n            dtype=flow.int32,\n            sbp=flow.sbp.partial_sum,\n        )\n        async_x = flow.ones((128 * 2, 1024), device=\"cuda\", dtype=flow.int32)\n        i = 0\n        for i in range(500):\n            synced_y = sync_allreduce(sync_x)\n            asynced_y = async_allreduce(async_x)\n            if i % 20 == 0:\n                print(i)\n        print(synced_y.to_local().numpy())\n        print(asynced_y.numpy())\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_sync_batchnorm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\nfrom sync_batchnorm_test_util import ensure_datas\n\n\n@flow.unittest.skip_unless_1n2d()\n@unittest.skip(\"TODO(depeng): data too larger\")\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestSyncBatchNorm(flow.unittest.TestCase):\n    def test_sync_batchnorm3d(test_case):\n        data_path = ensure_datas()\n        os.environ[\"ONEFLOW_ENABLE_NHWC\"] = \"0\"\n        channel = 8\n        input_np = np.load(\n            f\"{data_path}/sync_bn3d_nchw_input_rank{flow.env.get_rank()}.npy\"\n        )\n        torch_out = np.load(\n            f\"{data_path}/sync_bn3d_nchw_torch_output_rank{flow.env.get_rank()}.npy\"\n        )\n        torch_grad = np.load(\n            f\"{data_path}/sync_bn3d_nchw_torch_grad_rank{flow.env.get_rank()}.npy\"\n        )\n\n        of_input = flow.tensor(input_np, requires_grad=True, device=\"cuda\")\n        of_bn = flow.nn.BatchNorm3d(channel)\n        of_bn = flow.nn.SyncBatchNorm.convert_sync_batchnorm(of_bn).cuda()\n        of_res = of_bn(of_input)\n        of_res.sum().backward()\n\n        test_case.assertTrue(np.allclose(torch_out, of_res.numpy(), atol=1e-8))\n        test_case.assertTrue(np.allclose(torch_grad, of_input.grad.numpy(), atol=1e-8,))\n\n    def test_sync_batchnorm2d(test_case):\n        data_path = ensure_datas()\n        os.environ[\"ONEFLOW_ENABLE_NHWC\"] = \"0\"\n        channel = 8\n        input_np = np.load(\n            f\"{data_path}/sync_bn2d_nchw_input_rank{flow.env.get_rank()}.npy\"\n        )\n        torch_out = np.load(\n            f\"{data_path}/sync_bn2d_nchw_torch_output_rank{flow.env.get_rank()}.npy\"\n        )\n        torch_grad = np.load(\n            f\"{data_path}/sync_bn2d_nchw_torch_grad_rank{flow.env.get_rank()}.npy\"\n        )\n\n        of_input = flow.tensor(input_np, requires_grad=True, device=\"cuda\")\n        of_bn = flow.nn.BatchNorm2d(channel)\n        of_bn = flow.nn.SyncBatchNorm.convert_sync_batchnorm(of_bn).cuda()\n        of_res = of_bn(of_input)\n        of_res.sum().backward()\n\n        test_case.assertTrue(np.allclose(torch_out, of_res.numpy(), atol=1e-8))\n        test_case.assertTrue(np.allclose(torch_grad, of_input.grad.numpy(), atol=1e-8,))\n\n    def test_sync_batchnorm1d(test_case):\n        data_path = ensure_datas()\n        os.environ[\"ONEFLOW_ENABLE_NHWC\"] = \"0\"\n        channel = 8\n        input_np = np.load(\n            f\"{data_path}/sync_bn2d_nchw_input_rank{flow.env.get_rank()}.npy\"\n        )\n        torch_out = np.load(\n            f\"{data_path}/sync_bn2d_nchw_torch_output_rank{flow.env.get_rank()}.npy\"\n        )\n        torch_grad = np.load(\n            f\"{data_path}/sync_bn2d_nchw_torch_grad_rank{flow.env.get_rank()}.npy\"\n        )\n\n        of_input = flow.tensor(input_np, requires_grad=True, 
device=\"cuda\")\n        of_bn = flow.nn.BatchNorm1d(channel)\n        of_bn = flow.nn.SyncBatchNorm.convert_sync_batchnorm(of_bn).cuda()\n        of_res = of_bn(of_input)\n        of_res.sum().backward()\n\n        test_case.assertTrue(np.allclose(torch_out, of_res.numpy(), atol=1e-8))\n        test_case.assertTrue(np.allclose(torch_grad, of_input.grad.numpy(), atol=1e-8,))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_t.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTransposeAllDimFunction(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_t_flow_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=constant(2).to(int), dim0=random(0, 64), dim1=random(0, 64)\n        ).to(device)\n        y = torch.t(x)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_t5_layernorm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport math\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\nimport torch\n\n\nclass TorchT5LayerNorm(torch.nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.\n        \"\"\"\n        super().__init__()\n        self.weight = torch.nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n\n        # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean\n        # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated\n        # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for\n        # half-precision inputs is done in fp32\n\n        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n\n        # convert into half-precision if necessary\n        if self.weight.dtype in [torch.float16, torch.bfloat16]:\n            hidden_states = hidden_states.to(self.weight.dtype)\n        res = self.weight * hidden_states\n        return res\n\n\ndef _test_t5_layer_norm(test_case, device):\n    torch_t5_layernrom = TorchT5LayerNorm(3)\n    oneflow_t5_layernorm = flow.nn.RMSLayerNorm(3)\n    torch_t5_layernrom.to(device)\n    oneflow_t5_layernorm.to(device)\n    x = np.random.randn(2, 4, 3)\n    torch_x = torch.tensor(x, requires_grad=True, device=torch.device(device))\n    oneflow_x = flow.tensor(x, requires_grad=True, device=flow.device(device))\n    torch_y = torch_t5_layernrom(torch_x)\n    oneflow_y = oneflow_t5_layernorm(oneflow_x)\n    test_case.assertTrue(\n        np.allclose(\n            torch_y.detach().cpu().numpy(), oneflow_y.numpy(), rtol=1e-4, atol=1e-4\n        )\n    )\n    torch_y_sum = torch_y.sum()\n    torch_y_sum.backward()\n    oneflow_y_sum = oneflow_y.sum()\n    oneflow_y_sum.backward()\n    test_case.assertTrue(\n        np.allclose(\n            torch_x.grad.cpu().numpy(), oneflow_x.grad.numpy(), rtol=1e-5, atol=1e-5\n        )\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestModule(flow.unittest.TestCase):\n    def test_t5_layernorm(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_t5_layer_norm]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_tensor_buffer.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList, type_name_to_flow_type\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_tensor_buffer_convert(test_case, device):\n    input = flow.tensor(\n        np.random.rand(16, 24, 32, 36), dtype=flow.float32, device=flow.device(device)\n    )\n    tensor_buffer = flow.tensor_to_tensor_buffer(input, instance_dims=2)\n    orig_tensor = flow.tensor_buffer_to_tensor(\n        tensor_buffer, dtype=flow.float32, instance_shape=[32, 36]\n    )\n    test_case.assertTrue(np.array_equal(input.numpy(), orig_tensor.numpy()))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTensorBufferOps(flow.unittest.TestCase):\n    def test_tensor_buffer_convert(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_tensor_buffer_convert]\n        arg_dict[\"device\"] = [\"cpu\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_tensor_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom random import shuffle\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_is_floating_point(test_case, shape, device, dtype):\n    np_input = np.random.rand(*shape)\n    input = flow.tensor(np_input, dtype=dtype, device=device)\n    output = input.is_floating_point()\n    if input.dtype in (flow.float, flow.float16, flow.float32, flow.double):\n        test_case.assertEqual(output, True)\n    else:\n        test_case.assertEqual(output, False)\n\n\ndef _test_type_dtype(test_case, shape, device, src_dtype, tgt_dtype):\n    # test tensor.type(x: dtype) rather than tensor.type_dtype\n    np_input = np.random.rand(*shape)\n    input = flow.tensor(np_input, dtype=src_dtype, device=device)\n    input = input.type(tgt_dtype)\n    test_case.assertEqual(input.dtype, tgt_dtype)\n    test_case.assertEqual(input.device, flow.device(device))\n\n\ndef _test_type_str(\n    test_case, tensortype_dict, shape, device, dtype, tgt_tensortype_str\n):\n    # test tensor.type(x: str) rather than tensor.type_tensortype\n    np_input = np.random.rand(*shape)\n    input = flow.tensor(np_input, dtype=dtype, device=device)\n    input = input.type(tgt_tensortype_str)\n    tgt_dtype, tgt_device = tensortype_dict[tgt_tensortype_str]\n    test_case.assertEqual(input.dtype, tgt_dtype)\n    test_case.assertEqual(input.device, tgt_device)\n\n\ndef _test_type_tensortype(\n    test_case, tensortype_dict, shape, device, dtype, tgt_tensortype\n):\n    # test tensor.type(x: tensortype) rather than tensor.type_tensortype\n    np_input = np.random.rand(*shape)\n    input = flow.tensor(np_input, dtype=dtype, device=device)\n    input = input.type(tgt_tensortype)\n    tgt_dtype, tgt_device = tensortype_dict[tgt_tensortype]\n    test_case.assertEqual(input.dtype, tgt_dtype)\n    test_case.assertEqual(input.device, tgt_device)\n\n\ndef _test_type_noargs(test_case, shape, device, dtype):\n    # test tensor.type() rather than tensor.type_noargs\n    def generate_tensortype_string(device, dtype):\n        dtype_to_str_dict = {\n            flow.uint8: \"ByteTensor\",\n            flow.int8: \"CharTensor\",\n            flow.int32: \"IntTensor\",\n            flow.int64: \"LongTensor\",\n            flow.float16: \"HalfTensor\",\n            flow.bfloat16: \"BFloat16Tensor\",  # Currently unsupport\n            flow.float32: \"FloatTensor\",\n            flow.float64: \"DoubleTensor\",\n        }\n        dtype = dtype_to_str_dict[dtype]\n        if device == \"cpu\":\n            return dtype\n        return \".\".join([device, dtype])\n\n    np_input = np.random.rand(*shape)\n    input = flow.tensor(np_input, dtype=dtype, device=device)\n    test_case.assertEqual(\n        input.type(), \"oneflow.\" + 
generate_tensortype_string(device, dtype)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestCuda(flow.unittest.TestCase):\n    @autotest(n=20, auto_backward=True, rtol=1e-4, atol=1e-4, check_graph=True)\n    def test_cuda(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        x = x.cuda()\n        y = x.sum()\n        return y\n\n    @autotest(n=20, auto_backward=True, rtol=1e-4, atol=1e-4, check_graph=True)\n    def test_cuda_0dim(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        x = x.cuda()\n        y = x.sum()\n        return y\n\n    @autotest(n=5)\n    def test_cuda_int_device(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        x = x.cuda(0)\n        y = x.sum()\n        return y\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTensorOps(flow.unittest.TestCase):\n    @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True)\n    def test_cpu(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        x = x.cpu()\n        y = x.sum()\n        return y\n\n    @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True)\n    def test_long(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.long()\n        return y\n\n    @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True)\n    def test_long_0dim(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = x.long()\n        return y\n\n    @autotest(n=5, auto_backward=False)\n    def test_long_with_non_contiguous_input(test_case):\n        device = random_device()\n        permute_list = list(range(4))\n        shuffle(permute_list)\n        input = random_tensor(ndim=4).to(device)\n        x = input.permute(permute_list)\n        y = x.long()\n        return y\n\n    @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True)\n    def test_int(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.int()\n        return y\n\n    @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True)\n    def test_int_0dim(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = x.int()\n        return y\n\n    @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True)\n    def test_half(test_case):\n        device = random_device()\n        x = random_tensor(dtype=int).to(device)\n        y = x.half()\n        return y\n\n    @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True)\n    def test_half_0dim(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0, dtype=int).to(device)\n        y = x.half()\n        return y\n\n    @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True)\n    def test_float(test_case):\n        device = random_device()\n        x = random_tensor(dtype=int).to(device)\n        y = x.float()\n        return y\n\n    @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True)\n    def test_float_0dim(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0, dtype=int).to(device)\n        y = x.float()\n        return y\n\n    @autotest(n=20, auto_backward=False, rtol=1e-4, 
atol=1e-4, check_graph=True)\n    def test_double(test_case):\n        device = random_device()\n        x = random_tensor(dtype=int).to(device)\n        y = x.double()\n        return y\n\n    @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True)\n    def test_double_0dim(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0, dtype=int).to(device)\n        y = x.double()\n        return y\n\n    @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True)\n    def test_bool(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.bool()\n        return y\n\n    @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True)\n    def test_bool_0dim(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = x.bool()\n        return y\n\n    @autotest(n=5, auto_backward=False)\n    def test_bool_with_non_contiguous_input(test_case):\n        device = random_device()\n        permute_list = list(range(4))\n        shuffle(permute_list)\n        input = random_tensor(ndim=4).to(device)\n        x = input.permute(permute_list)\n        y = x.bool()\n        return y\n\n    # Not checking graph because of 2 reasons.\n    # Reason 1, nn.Graph.build()'s input/output items only support types: Tensor/None.\n    # Reason 2, This op needs to convert the EagerTensor to a numpy array, so this op only supports eager mode.\n    # Please refer to File \"oneflow/api/python/utils/tensor_utils.h\", line 49, in EagerTensorToNumpy.\n    @autotest(\n        n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=\"ValidatedFalse\"\n    )\n    def test_item(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=1, dtype=int).to(device)\n        y = torch.tensor(x.item())\n        return y\n\n    # Not checking graph because of 2 reasons.\n    # Reason 1, nn.Graph.build()'s input/output items only support types: Tensor/None.\n    # Reason 2, This op needs to convert the EagerTensor to a numpy array, so this op only supports eager mode.\n    # Please refer to File \"oneflow/api/python/utils/tensor_utils.h\", line 49, in EagerTensorToNumpy.\n    @autotest(\n        n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=\"ValidatedFalse\"\n    )\n    def test_item_0dim(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0, dtype=int).to(device)\n        y = torch.tensor(x.item())\n        return y\n\n    # Not checking graph because of 2 reasons.\n    # Reason 1, nn.Graph.build()'s input/output items only support types: Tensor/None.\n    # Reason 2, This op needs to convert the EagerTensor to a numpy array, so this op only supports eager mode.\n    # Please refer to File \"oneflow/api/python/utils/tensor_utils.h\", line 49, in EagerTensorToNumpy.\n    @autotest(\n        n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=\"ValidatedFalse\"\n    )\n    def test_tolist(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = torch.tensor(x.tolist())\n        return y\n\n    # Not checking graph because of 2 reasons.\n    # Reason 1, nn.Graph.build()'s input/output items only support types: Tensor/None.\n    # Reason 2, This op needs to convert the EagerTensor to a numpy array, so this op only supports eager mode.\n    # Please refer to File \"oneflow/api/python/utils/tensor_utils.h\", line 49, in EagerTensorToNumpy.\n    @autotest(\n        n=20, 
auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=\"ValidatedFalse\"\n    )\n    def test_tolist_0dim(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y = torch.tensor(x.tolist())\n        return y\n\n    @autotest()\n    def test_type_as(test_case):\n        input = random_tensor().to(random_device())\n        target = random_tensor().to(random_device())\n        input = input.type_as(target)\n        return input\n\n    def test_is_floating_point(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(1, 2), (3, 4, 5), (2, 3, 4, 5)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"dtype\"] = [\n            flow.uint8,\n            flow.int8,\n            flow.int32,\n            flow.int64,\n            flow.float32,\n            flow.float64,\n            flow.double,\n            flow.float,\n            flow.int,\n        ]\n        for arg in GenArgList(arg_dict):\n            _test_is_floating_point(test_case, *arg)\n\n    def test_type_dtype(test_case):\n        # test tensor.type(x.dtype) rather than tensor.type_dtype\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(1, 2), (3, 4, 5), (2, 3, 4, 5)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"src_dtype\"] = [\n            flow.uint8,\n            flow.int8,\n            flow.int64,\n            flow.int32,\n            flow.float16,\n            flow.float32,\n            flow.float64,\n        ]\n        arg_dict[\"tgt_dtype\"] = arg_dict[\"src_dtype\"]\n        for arg in GenArgList(arg_dict):\n            _test_type_dtype(test_case, *arg)\n\n    def test_type_tensortype_str_cpu(test_case):\n        # test tensor.type(x: str) rather than tensor.type_tensortype\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(1, 2), (3, 4, 5), (2, 3, 4, 5)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"src_dtype\"] = [\n            flow.uint8,\n            flow.int8,\n            flow.int64,\n            flow.int32,\n            flow.float16,\n            flow.float32,\n            flow.float64,\n        ]\n        tensortype_dict = {\n            \"oneflow.CharTensor\": [flow.char, flow.device(\"cpu\")],\n            \"oneflow.ByteTensor\": [flow.uint8, flow.device(\"cpu\")],\n            \"oneflow.IntTensor\": [flow.int32, flow.device(\"cpu\")],\n            \"oneflow.LongTensor\": [flow.int64, flow.device(\"cpu\")],\n            \"oneflow.HalfTensor\": [flow.float16, flow.device(\"cpu\")],\n            \"oneflow.FloatTensor\": [flow.float32, flow.device(\"cpu\")],\n            \"oneflow.DoubleTensor\": [flow.float64, flow.device(\"cpu\")],\n        }\n        arg_dict[\"tgt_tensortype_str\"] = list(tensortype_dict.keys())\n        for arg in GenArgList(arg_dict):\n            _test_type_str(test_case, tensortype_dict, *arg)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_type_tensortype_str(test_case):\n        # test tensor.type(x: str) rather than tensor.type_tensortype\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(1, 2), (3, 4, 5), (2, 3, 4, 5)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"src_dtype\"] = [\n            flow.uint8,\n            flow.char,\n            flow.int64,\n            flow.int32,\n            flow.float16,\n            flow.float32,\n            flow.float64,\n        ]\n        tensortype_dict = {\n            
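# tensortype string -> expected [dtype, device] of the converted tensor\n            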
\"oneflow.CharTensor\": [flow.char, flow.device(\"cpu\")],\n            \"oneflow.ByteTensor\": [flow.uint8, flow.device(\"cpu\")],\n            \"oneflow.IntTensor\": [flow.int32, flow.device(\"cpu\")],\n            \"oneflow.LongTensor\": [flow.int64, flow.device(\"cpu\")],\n            \"oneflow.HalfTensor\": [flow.float16, flow.device(\"cpu\")],\n            \"oneflow.FloatTensor\": [flow.float32, flow.device(\"cpu\")],\n            \"oneflow.DoubleTensor\": [flow.float64, flow.device(\"cpu\")],\n            \"oneflow.cuda.CharTensor\": [flow.char, flow.device(\"cuda\")],\n            \"oneflow.cuda.ByteTensor\": [flow.uint8, flow.device(\"cuda\")],\n            \"oneflow.cuda.IntTensor\": [flow.int32, flow.device(\"cuda\")],\n            \"oneflow.cuda.LongTensor\": [flow.int64, flow.device(\"cuda\")],\n            \"oneflow.cuda.HalfTensor\": [flow.float16, flow.device(\"cuda\")],\n            \"oneflow.cuda.FloatTensor\": [flow.float32, flow.device(\"cuda\")],\n            \"oneflow.cuda.DoubleTensor\": [flow.float64, flow.device(\"cuda\")],\n        }\n        arg_dict[\"tgt_tensortype_str\"] = list(tensortype_dict.keys())\n        for arg in GenArgList(arg_dict):\n            _test_type_str(test_case, tensortype_dict, *arg)\n\n    def test_type_tensortype_cpu(test_case):\n        # test tensor.type(x: tensortype) rather than tensor.type_tensortype\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(1, 2), (3, 4, 5), (2, 3, 4, 5)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"src_dtype\"] = [\n            flow.uint8,\n            flow.int8,\n            flow.int64,\n            flow.int32,\n            flow.float16,\n            flow.float32,\n            flow.float64,\n        ]\n        tensortype_dict = {\n            flow.CharTensor: [flow.int8, flow.device(\"cpu\")],\n            flow.ByteTensor: [flow.uint8, flow.device(\"cpu\")],\n            flow.IntTensor: [flow.int32, flow.device(\"cpu\")],\n            flow.LongTensor: [flow.int64, flow.device(\"cpu\")],\n            flow.HalfTensor: [flow.float16, flow.device(\"cpu\")],\n            flow.FloatTensor: [flow.float32, flow.device(\"cpu\")],\n            flow.DoubleTensor: [flow.float64, flow.device(\"cpu\")],\n        }\n        arg_dict[\"tgt_tensortype\"] = list(tensortype_dict.keys())\n        for arg in GenArgList(arg_dict):\n            _test_type_tensortype(test_case, tensortype_dict, *arg)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_type_tensortype(test_case):\n        # test tensor.type(x: tensortype) rather than tensor.type_tensortype\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(1, 2), (3, 4, 5), (2, 3, 4, 5)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"src_dtype\"] = [\n            flow.uint8,\n            flow.int8,\n            flow.int64,\n            flow.int32,\n            flow.float16,\n            flow.float32,\n            flow.float64,\n        ]\n        tensortype_dict = {\n            flow.CharTensor: [flow.int8, flow.device(\"cpu\")],\n            flow.ByteTensor: [flow.uint8, flow.device(\"cpu\")],\n            flow.IntTensor: [flow.int32, flow.device(\"cpu\")],\n            flow.LongTensor: [flow.int64, flow.device(\"cpu\")],\n            flow.HalfTensor: [flow.float16, flow.device(\"cpu\")],\n            flow.Tensor: [flow.float32, flow.device(\"cpu\")],\n            flow.FloatTensor: [flow.float32, flow.device(\"cpu\")],\n            
flow.DoubleTensor: [flow.float64, flow.device(\"cpu\")],\n            flow.cuda.CharTensor: [flow.int8, flow.device(\"cuda\")],\n            flow.cuda.ByteTensor: [flow.uint8, flow.device(\"cuda\")],\n            flow.cuda.IntTensor: [flow.int32, flow.device(\"cuda\")],\n            flow.cuda.LongTensor: [flow.int64, flow.device(\"cuda\")],\n            flow.cuda.HalfTensor: [flow.float16, flow.device(\"cuda\")],\n            flow.cuda.FloatTensor: [flow.float32, flow.device(\"cuda\")],\n            flow.cuda.DoubleTensor: [flow.float64, flow.device(\"cuda\")],\n        }\n        arg_dict[\"tgt_tensortype\"] = list(tensortype_dict.keys())\n        for arg in GenArgList(arg_dict):\n            _test_type_tensortype(test_case, tensortype_dict, *arg)\n\n    def test_type_noargs(test_case):\n        # test tensor.type() rather than tensor.type_noargs\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(1, 2), (3, 4, 5), (2, 3, 4, 5)]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"dtype\"] = [\n            flow.uint8,\n            flow.int8,\n            flow.int64,\n            flow.int32,\n            flow.float16,\n            flow.float32,\n            flow.float64,\n        ]\n        for arg in GenArgList(arg_dict):\n            _test_type_noargs(test_case, *arg)\n\n    @autotest(n=3, auto_backward=False)\n    def test_bincount(test_case):\n        device = random_device()\n        length = random(1, 100)\n        input = random_tensor(1, length, dtype=int, low=0).to(device)\n        weight = random_tensor(1, length, dtype=float).to(device)\n        min_length = random(1, 100) | nothing()\n        return (\n            input.bincount(minlength=min_length),\n            input.bincount(weight, minlength=min_length),\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_tensor_scatter_nd_update.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_tensor_scatter_nd_update(test_case, device):\n    origin = flow.tensor(np.arange(8), dtype=flow.float, device=flow.device(device))\n    indices = flow.tensor(\n        np.array([[1], [6], [4]]), dtype=flow.int, device=flow.device(device)\n    )\n    update = flow.tensor(\n        np.array([10.2, 5.1, 12.7]), dtype=flow.float, device=flow.device(device)\n    )\n    np_out = np.array([0.0, 10.2, 2.0, 3.0, 12.7, 5.0, 5.1, 7.0])\n    output = flow.tensor_scatter_nd_update(origin, indices, update)\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 0.0001, 0.0001))\n\n\ndef _test_tensor_scatter_nd_update_with_non_contiguous_input(test_case, device):\n    # non-contiguous tensor with shape (2, 3, 4)\n    origin = flow.tensor(\n        np.ones((4, 3, 2)), dtype=flow.float, device=flow.device(device)\n    ).permute(2, 1, 0)\n    # indices with shape (3, 2)\n    indices = flow.tensor(\n        np.array([[0, 0], [1, 0], [1, 1]]), dtype=flow.int, device=flow.device(device)\n    )\n    # non-contiguous update with shape (3, 4)\n    update = flow.tensor(\n        np.zeros((4, 3)), dtype=flow.float, device=flow.device(device)\n    ).T\n    output = flow.tensor_scatter_nd_update(origin, indices, update)\n\n    np_res = np.ones((2, 3, 4))\n    np_res[0, 0] = 0\n    np_res[1, 0] = 0\n    np_res[1, 1] = 0\n    test_case.assertTrue(np.array_equal(output.numpy(), np_res))\n\n\ndef _test_tensor_scatter_nd_update_t(test_case, device):\n    origin = flow.tensor(\n        np.arange(15).reshape(5, 3), dtype=flow.float, device=flow.device(device)\n    )\n    indices = flow.tensor(\n        np.array([[0], [4], [2]]), dtype=flow.int, device=flow.device(device)\n    )\n    update = flow.tensor(\n        np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]]),\n        dtype=flow.float,\n        device=flow.device(device),\n    )\n    np_out = np.array(\n        [\n            [1.0, 1.0, 1.0],\n            [3.0, 4.0, 5.0],\n            [3.0, 3.0, 3.0],\n            [9.0, 10.0, 11.0],\n            [2.0, 2.0, 2.0],\n        ]\n    )\n    output = flow.tensor_scatter_nd_update(origin, indices, update)\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 0.0001, 0.0001))\n\n\ndef _test_tensor_scatter_nd_update_backward(test_case, device):\n    origin = flow.tensor(\n        np.arange(8), dtype=flow.float, device=flow.device(device), requires_grad=True,\n    )\n    indices = flow.tensor(\n        np.array([[1], [6], [4]]), dtype=flow.int, device=flow.device(device)\n    )\n    of_update = flow.tensor(\n        np.array([10.2, 5.1, 12.7]),\n        requires_grad=True,\n        dtype=flow.float,\n        device=flow.device(device),\n    )\n    np_out = np.array([0.0, 10.2, 2.0, 3.0, 12.7, 5.0, 5.1, 7.0])\n   
 np_update_grad = np.array([1.0, 1.0, 1.0])\n    np_origin_grad = np.array([1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0])\n    output = flow.tensor_scatter_nd_update(origin, indices, of_update)\n    out_sum = output.sum()\n    out_sum.backward()\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 0.0001, 0.0001))\n    test_case.assertTrue(np.allclose(of_update.grad.numpy(), np_update_grad))\n    test_case.assertTrue(np.allclose(origin.grad.numpy(), np_origin_grad))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTensorScatterNdUpdate(flow.unittest.TestCase):\n    def test_tensor_scatter_nd_update(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_tensor_scatter_nd_update,\n            _test_tensor_scatter_nd_update_with_non_contiguous_input,\n            _test_tensor_scatter_nd_update_t,\n            _test_tensor_scatter_nd_update_backward,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_tensor_split.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom random import shuffle\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTensorSplitVec(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_flow_tensor_split_vec(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4,\n            dim0=random(3, 6),\n            dim1=random(3, 6),\n            dim2=random(3, 6),\n            dim3=random(3, 6),\n        ).to(device)\n        dim = random(-3, 3).to(int)\n        z = torch.tensor_split(x, (1, 2), dim)\n        return z[0]\n\n    @autotest(n=5)\n    def test_flow_tensor_split_vec_with_stride(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4,\n            dim0=random(3, 6),\n            dim1=random(3, 6),\n            dim2=random(3, 6),\n            dim3=random(3, 6),\n        ).to(device)\n        dim = random(-3, 3).to(int)\n        perm = [0, 1, 2, 3]\n        shuffle(perm)\n        y = x.permute(perm)\n        z = torch.tensor_split(y, (1, 2), dim)\n        return z[0]\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTensorSplitInt(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_flow_tensor_split_int(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4,\n            dim0=random(3, 6),\n            dim1=random(3, 6),\n            dim2=random(3, 6),\n            dim3=random(3, 6),\n        ).to(device)\n        split = random(1, 3).to(int)\n        dim = random(-3, 3).to(int)\n        z = torch.tensor_split(x, split, dim)\n        return z[0]\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_tensor_to.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n2d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass Test2DeviceGlobalTensorTo(flow.unittest.TestCase):\n    def test_asymmetric_global_tensor_clone(test_case):\n        placement = flow.placement(\"cuda\", range(1))\n        x = flow.ones((4,), placement=placement, sbp=flow.sbp.broadcast)\n        cloned = x.detach().clone()\n        test_case.assertEqual(x.placement, cloned.placement)\n        test_case.assertEqual(x.sbp, cloned.sbp)\n        if flow.env.get_rank() == 0:\n            cloned_local = cloned.to_local()\n            cloned_local[0] = 0\n            test_case.assertEqual(cloned_local[0].numpy().item(), 0)\n            test_case.assertEqual(x.to_local()[0].numpy().item(), 1)\n\n    def test_global_tensor_clone(test_case):\n        placement = flow.placement(\"cuda\", range(2))\n        x = flow.ones((4,), placement=placement, sbp=flow.sbp.broadcast)\n        cloned = x.detach().clone()\n        test_case.assertEqual(x.placement, cloned.placement)\n        test_case.assertEqual(x.sbp, cloned.sbp)\n        cloned_local = cloned.to_local()\n        cloned_local[0] = 0\n        test_case.assertEqual(cloned_local[0].numpy().item(), 0)\n        test_case.assertEqual(x.to_local()[0].numpy().item(), 1)\n\n    def test_global_tensor_to(test_case):\n        placement = flow.placement(\"cuda\", range(2))\n        x = flow.ones((4,), placement=placement, sbp=flow.sbp.broadcast)\n        cloned = x.to(copy=True)\n        test_case.assertEqual(x.placement, cloned.placement)\n        test_case.assertEqual(x.sbp, cloned.sbp)\n        cloned_local = cloned.to_local()\n        cloned_local[0] = 0\n        test_case.assertEqual(cloned_local[0].numpy().item(), 0)\n        test_case.assertEqual(x.to_local()[0].numpy().item(), 1)\n\n    def test_tensor_to_h2d1(test_case):\n        input = flow.tensor(np.random.randn(2, 3, 4, 5), dtype=flow.int64)\n        output = input.to(device=flow.device(\"cuda:1\"), dtype=flow.int32)\n        test_case.assertEqual(output.device, flow.device(\"cuda:1\"))\n        test_case.assertEqual(output.dtype, flow.int32)\n        test_case.assertTrue(\n            np.allclose(input.numpy(), output.numpy(), rtol=0.0001, atol=0.0001)\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestTo(flow.unittest.TestCase):\n    def test_global_tensor_clone(test_case):\n        x = flow.ones(\n            (4,), placement=flow.placement(\"cuda\", ranks=[0]), sbp=flow.sbp.broadcast\n        )\n        cloned = x.detach().clone()\n        test_case.assertEqual(x.placement, cloned.placement)\n        test_case.assertEqual(x.sbp, cloned.sbp)\n        
cloned_local = cloned.to_local()\n        cloned_local[0] = 0\n        test_case.assertEqual(cloned_local[0].numpy().item(), 0)\n        test_case.assertEqual(x.to_local()[0].numpy().item(), 1)\n\n    def test_global_tensor_to(test_case):\n        x = flow.ones(\n            (4,), placement=flow.placement(\"cuda\", ranks=[0]), sbp=flow.sbp.broadcast\n        )\n        cloned = x.to(copy=True)\n        test_case.assertEqual(x.placement, cloned.placement)\n        test_case.assertEqual(x.sbp, cloned.sbp)\n        cloned_local = cloned.to_local()\n        cloned_local[0] = 0\n        test_case.assertEqual(cloned_local[0].numpy().item(), 0)\n        test_case.assertEqual(x.to_local()[0].numpy().item(), 1)\n\n    def test_empty_global_tensor_to(test_case):\n        x = flow.ones(\n            (0,), placement=flow.placement(\"cuda\", ranks=[0]), sbp=flow.sbp.broadcast\n        )\n        cloned = x.to(copy=True)\n        test_case.assertEqual(x.placement, cloned.placement)\n        test_case.assertEqual(x.sbp, cloned.sbp)\n        cloned_local = cloned.to_local()\n        test_case.assertEqual(tuple(cloned.shape), (0,))\n        test_case.assertEqual(tuple(cloned_local.shape), (0,))\n\n    def test_tensor_to_h2d(test_case):\n        input = flow.tensor(np.random.randn(2, 3, 4, 5), dtype=flow.float32)\n        output = input.to(device=flow.device(\"cuda\"))\n        test_case.assertEqual(output.device, flow.device(\"cuda\"))\n        test_case.assertTrue(\n            np.allclose(input.numpy(), output.numpy(), rtol=0.0001, atol=0.0001)\n        )\n        gpu_output = output.to(device=flow.device(\"cuda\"))\n        test_case.assertEqual(gpu_output.device, flow.device(\"cuda\"))\n        test_case.assertTrue(\n            np.allclose(input.numpy(), gpu_output.numpy(), rtol=0.0001, atol=0.0001)\n        )\n\n    def test_tensor_to_d2h(test_case):\n        input = flow.tensor(\n            np.random.randn(2, 3, 4, 5), dtype=flow.float32, device=flow.device(\"cuda\")\n        )\n        output = input.to(device=flow.device(\"cpu\"))\n        test_case.assertEqual(output.device, flow.device(\"cpu\"))\n        test_case.assertTrue(\n            np.allclose(input.numpy(), output.numpy(), rtol=0.0001, atol=0.0001)\n        )\n\n    def test_tensor_to_d2d(test_case):\n        input = flow.tensor(\n            np.random.randn(2, 3, 4, 5), dtype=flow.float32, device=flow.device(\"cuda\")\n        )\n        output = input.to(device=flow.device(\"cuda:0\"))\n        test_case.assertEqual(output.device, flow.device(\"cuda:0\"))\n        test_case.assertTrue(\n            np.allclose(input.numpy(), output.numpy(), rtol=0.0001, atol=0.0001)\n        )\n\n    def test_tensor_to_h2h(test_case):\n        input = flow.tensor(np.random.randn(2, 3, 4, 5), dtype=flow.float32)\n        output = input.to(device=flow.device(\"cpu\"))\n        test_case.assertEqual(output.device, flow.device(\"cpu\"))\n        test_case.assertTrue(\n            np.allclose(input.numpy(), output.numpy(), rtol=0.0001, atol=0.0001)\n        )\n\n    def test_tensor_to_cast(test_case):\n        input = flow.tensor(np.random.randn(2, 3, 4, 5), dtype=flow.float32)\n        output = input.to(dtype=flow.int)\n        test_case.assertEqual(output.dtype, flow.int)\n\n    def test_tensor_to_cast_h2d(test_case):\n        input = flow.tensor(np.random.randn(2, 3, 4, 5), dtype=flow.float32)\n        output = input.to(device=flow.device(\"cuda\"), dtype=flow.int)\n        test_case.assertEqual(output.dtype, flow.int)\n        
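# the result should carry the requested device as well as the new dtype\n        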
test_case.assertEqual(output.device, flow.device(\"cuda\"))\n\n    def test_tensor_using_tensor(test_case):\n        tensor = flow.tensor(np.random.randn(2, 3, 4, 5), device=\"cuda\", dtype=flow.int)\n        input = flow.tensor(np.random.randn(2, 3))\n        output = input.to(tensor)\n        test_case.assertEqual(output.dtype, flow.int)\n        test_case.assertEqual(output.device, flow.device(\"cuda\"))\n\n    @autotest(n=5, check_graph=True)\n    def test_int_to_args(test_case):\n        device_num = random(0, 2).to(int).value()\n        x = random_tensor(ndim=4).to(device_num)\n        return x\n\n    @autotest(n=5, check_graph=True)\n    def test_int_to_kwargs(test_case):\n        device_num = random(0, 2).to(int).value()\n        x = random_tensor(ndim=4).to(device=device_num)\n        return x\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_tensordot.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom collections import OrderedDict\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTensordot(flow.unittest.TestCase):\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_tensordot_intdim(test_case):\n        device = random_device()\n        dims = random()\n        dims_list = [random().to(int).value() for i in range(dims.to(int).value() + 3)]\n        x = random_tensor(\n            ndim=3, dim0=dims_list[0], dim1=dims_list[1], dim2=dims_list[2],\n        ).to(device)\n        y = random_tensor(\n            ndim=3,\n            dim0=dims_list[0 + dims.to(int).value()],\n            dim1=dims_list[1 + dims.to(int).value()],\n            dim2=dims_list[2 + dims.to(int).value()],\n        ).to(device)\n\n        z = torch.tensordot(x, y, dims=3 - dims.to(int).value())\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_tensordot_list_dim(test_case):\n        device = random_device()\n        x = random_tensor(4, 1, 3, 2, 5).to(device)\n        y = random_tensor(4, 4, 2, 3, 5).to(device)\n        z = torch.tensordot(x, y, dims=[[1, 2, 0], [2, 1, 0]])\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-2)\n    def test_tensordot_tuple_dim(test_case):\n        device = random_device()\n        x = random_tensor(4, 1, 3, 2, 5).to(device)\n        y = random_tensor(4, 4, 2, 3, 5).to(device)\n        z = torch.tensordot(x, y, dims=([1, 2, 0], [2, 1, 0]))\n        return z\n\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_tensordot_list_neg_dim(test_case):\n        device = random_device()\n        x = random_tensor(4, 1, 3, 2, 5).to(device)\n        y = random_tensor(4, 4, 2, 3, 5).to(device)\n        z = torch.tensordot(x, y, dims=[[-3, -2, -4], [-2, -3, -4]])\n        return z\n\n    @autotest(check_graph=False, rtol=1e-2, atol=1e-3)\n    def test_tensordot_backward(test_case):\n        device = random_device()\n        x = random_tensor(3, 3, 4, 5).to(device)\n        y = random_tensor(2, 4, 5).to(device)\n        z = torch.tensordot(x, y, dims=[[1, 2], [0, 1]])\n        z.sum().backward()\n\n    @autotest(check_graph=False)\n    def test_tensordot_tensor_dim(test_case):\n        def _test_tensor_dim(test_case, device):\n            np_dim = np.array([[1, 2, 3], [1, 2, 3]], dtype=int)\n            flow_dim = flow.tensor(np_dim).to(device)\n            torch_dim = torch.tensor(np_dim).to(device)\n\n            np_random_array = np.random.randn(2, 3, 4, 5)\n            flow_tensor = flow.tensor(np_random_array).to(device)\n            torch_tensor = torch.tensor(np_random_array).to(device)\n\n            flow_result = flow.tensordot(flow_tensor, flow_tensor, dims=flow_dim)\n            torch_result = torch.tensordot(torch_tensor, 
torch_tensor, dims=torch_dim)\n            test_case.assertTrue(\n                np.allclose(\n                    flow_result.numpy(),\n                    torch_result.cpu().numpy(),\n                    rtol=0.0001,\n                    atol=0.0001,\n                )\n            )\n\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_tensor_dim(test_case, arg[0])\n\n    @autotest(n=5, check_graph=False, rtol=1e-2, atol=1e-2)\n    def test_tensordot_single_item_tensor_dim(test_case):\n        device = random_device()\n        dims = random_tensor(1, dim0=1, low=0, high=4, dtype=int).to(device)\n        x = random_tensor(3, dim0=4, dim1=4, dim2=4).to(device)\n        y = random_tensor(3, dim0=4, dim1=4, dim2=4).to(device)\n        z = torch.tensordot(x, y, dims=dims)\n        return z\n\n    @autotest(n=5, rtol=1e-3, atol=1e-4)\n    def test_tensordot_broadcast(test_case):\n        device = random_device()\n        x = random_tensor(4, 1, 1, 1, 1).to(device)\n        y = random_tensor(4, 2, 3, 4, 5).to(device)\n        z = torch.tensordot(x, y, dims=random(high=5).to(int).value())\n        return z\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_tile.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTile(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_flow_tile_with_random_data(test_case):\n        x = random_tensor(ndim=2, dim0=1, dim1=2)\n        reps = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int))\n        z = torch.tile(x, reps)\n        return z\n\n    @autotest(check_graph=True)\n    def test_flow_tensor_tile_with_random_data(test_case):\n        x = random_tensor(ndim=2, dim0=1, dim1=2)\n        reps = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int))\n        y = x.tile(reps)\n        return y\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_flow_tile_bool_with_random_data(test_case):\n        x = random_tensor(ndim=2, dim0=1, dim1=2).to(torch.bool)\n        reps = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int))\n        z = torch.tile(x, reps)\n        return z\n\n    @autotest(check_graph=True)\n    def test_flow_tile_with_0dim_data(test_case):\n        x = random_tensor(ndim=0)\n        reps = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int))\n        z = torch.tile(x, reps)\n        return z\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_to_torch.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nimport os\n\nimport oneflow as flow\nimport oneflow.unittest\nimport torch\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestToTroch(flow.unittest.TestCase):\n    # NOTE: oneflow and torch cpu tensor shared the same memory, refer to File \"python/oneflow/test/modules/test_from_torch.py\", line 49, in test_from_torch_cpu.\n    def test_to_torch_cpu(test_case):\n        flow_t = flow.rand(5, 3, 3)\n        numpy_from_flow = flow_t.numpy()\n\n        torch_t = flow.utils.tensor.to_torch(flow_t)\n\n        test_case.assertEqual(\n            torch_t.data_ptr(), numpy_from_flow.__array_interface__[\"data\"][0]\n        )\n        numpy_from_flow[0][0] = [1, 2, 3]\n        test_case.assertTrue(\n            np.allclose(torch_t.numpy(), numpy_from_flow, rtol=0.001, atol=0.001)\n        )\n\n        test_case.assertTrue(\n            np.allclose(flow_t.numpy(), torch_t.numpy(), rtol=0.001, atol=0.001)\n        )\n        test_case.assertEqual(flow_t.numpy().dtype, torch_t.numpy().dtype)\n\n    # NOTE: For the case of 0 size tensor, no memory addresses are compared.\n    #  Because the address of 0 size tensor is random at this time.\n    def test_to_torch_cpu_with_0_size_data(test_case):\n        flow_t = flow.rand(5, 3, 0)\n\n        torch_t = flow.utils.tensor.to_torch(flow_t)\n\n        test_case.assertTrue(\n            np.allclose(flow_t.numpy(), torch_t.numpy(), rtol=0.001, atol=0.001)\n        )\n        test_case.assertEqual(flow_t.numpy().dtype, torch_t.numpy().dtype)\n\n    def test_to_torch_cpu_with_0dim_data(test_case):\n        flow_t = flow.tensor(5)\n        numpy_from_flow = flow_t.numpy()\n\n        torch_t = flow.utils.tensor.to_torch(flow_t)\n\n        test_case.assertEqual(\n            torch_t.data_ptr(), numpy_from_flow.__array_interface__[\"data\"][0]\n        )\n\n        test_case.assertTrue(\n            np.allclose(flow_t.numpy(), torch_t.numpy(), rtol=0.001, atol=0.001)\n        )\n        test_case.assertEqual(flow_t.numpy().dtype, torch_t.numpy().dtype)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_to_torch_gpu(test_case):\n        flow_t = flow.rand(5, 3, 3).to(\"cuda\")\n\n        torch_t = flow.utils.tensor.to_torch(flow_t)\n\n        flow_t[0][0] = flow.tensor([1, 2, 3]).to(flow.float32)\n        # NOTE: OneFlow operations are asynchoronously executed,\n        # so we need to synchronize explicitly here.\n        flow._oneflow_internal.eager.Sync()\n        test_case.assertTrue(np.array_equal(torch_t.cpu().numpy(), flow_t.numpy()))\n\n        test_case.assertEqual(flow_t.numpy().dtype, torch_t.cpu().numpy().dtype)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_to_torch_global(test_case):\n        flow_t = flow.rand(5, 3, 3).to_global(\n            placement=flow.placement.all(\"cuda\"), 
sbp=flow.sbp.broadcast\n        )\n\n        torch_t = flow.utils.tensor.to_torch(flow_t)\n\n        test_case.assertEqual(flow_t.numpy().dtype, torch_t.cpu().numpy().dtype)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_topk.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport torch\nimport oneflow.unittest\n\n\ndef _test_top_k(test_case, shape, k, dim, device):\n    if k >= shape[dim]:\n        return\n    x_np = np.random.randn(*shape)\n    x_of = flow.tensor(x_np, device=device)\n    of_out = flow.topk(x_of, k=k, dim=dim)\n    x_pt = torch.tensor(x_np, device=device)\n    pt_out = torch.topk(x_pt, k=k, dim=dim)\n    test_case.assertTrue(\n        np.array_equal(of_out.values.cpu().numpy(), pt_out.values.cpu().numpy())\n    )\n    test_case.assertTrue(\n        np.array_equal(of_out.indices.cpu().numpy(), pt_out.indices.cpu().numpy())\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTopK(flow.unittest.TestCase):\n    def test_in_top_k(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(1, 16), (1, 1024), (8, 8), (8, 256)]\n        arg_dict[\"k\"] = [1, 4, 64]\n        arg_dict[\"dim\"] = [0, 1]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_top_k(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_transpose.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nfrom cgi import test\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom random import shuffle\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_transpose(test_case, device):\n    input = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.transpose(input, 0, 1)\n    np_out = input.numpy().transpose((1, 0, 2, 3))\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\ndef _test_tensor_transpose(test_case, device):\n    input = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = input.transpose(0, 1)\n    np_out = input.numpy().transpose((1, 0, 2, 3))\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\ndef _test_tranpose_negative_dim(test_case, device):\n    input = flow.tensor(\n        np.random.randn(2, 6, 5, 3), dtype=flow.float32, device=flow.device(device)\n    )\n    of_out = flow.transpose(input, -4, -3)\n    np_out = input.numpy().transpose((1, 0, 2, 3))\n    test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))\n\n\ndef _test_transpose_backward(test_case, device):\n    x = flow.tensor(\n        np.random.randn(2, 6, 5, 3),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    y = flow.transpose(x, 0, 1).sum()\n    y.backward()\n    test_case.assertTrue(\n        np.allclose(x.grad.numpy(), np.ones((2, 6, 5, 3)), 1e-05, 1e-05)\n    )\n\n\ndef _test_transpose_backward_v2(test_case, device):\n    x = flow.tensor(\n        np.random.randn(2, 3, 4, 5),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    y = flow.transpose(x, 3, 1).sum()\n    y.backward()\n    test_case.assertTrue(\n        np.allclose(x.grad.numpy(), np.ones((2, 3, 4, 5)), 1e-05, 1e-05)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTranspose(flow.unittest.TestCase):\n    def test_transpose(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"fun\"] = [\n            _test_transpose,\n            _test_tensor_transpose,\n            _test_tranpose_negative_dim,\n            _test_transpose_backward,\n            _test_transpose_backward_v2,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=10, check_graph=True)\n    def test_transpose_flow_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int))\n        return y\n\n    @autotest(n=10, 
check_graph=True)\n    def test_transpose_with_stride(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        permute_list = [0, 1, 2, 3]\n        shuffle(permute_list)\n        x = x.permute(permute_list)\n        y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int))\n        return y\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_transpose_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 3, 0, 4).to(device)\n        y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int))\n        return y\n\n    @autotest(n=10, auto_backward=False, check_graph=True)\n    def test_transpose_flow_bool_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device=device, dtype=torch.bool)\n        y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int))\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_tril.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTril(flow.unittest.TestCase):\n    @autotest(n=5, check_graph=True)\n    def test_tril_without_diag(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4,\n            dim0=random(1, 5).to(int),\n            dim1=random(1, 5).to(int),\n            dim2=random(1, 5).to(int),\n            dim3=random(1, 5).to(int),\n        ).to(device)\n        y = torch.tril(x)\n        y = torch.exp(y)\n\n        return y\n\n    @autotest(n=5, check_graph=True)\n    def test_tril_with_diag(test_case):\n        device = random_device()\n        diagonal = random(-3, 3).to(int)\n        x = random_tensor(\n            ndim=4,\n            dim0=random(1, 5).to(int),\n            dim1=random(1, 5).to(int),\n            dim2=random(1, 5).to(int),\n            dim3=random(1, 5).to(int),\n        ).to(device)\n        y = torch.tril(x, diagonal)\n        y = torch.exp(y)\n\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_triu.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _test_triu(test_case, diagonal, device, dtype):\n    arr_shape = (4, 4, 8)\n    flow_dtype, np_dtype = dtype\n    np_arr = np.random.randn(*arr_shape).astype(np_dtype)\n    input_tensor = flow.tensor(\n        np_arr, dtype=flow_dtype, device=flow.device(device), requires_grad=True\n    )\n    output = flow.triu(input_tensor, diagonal=diagonal)\n    np_out = np.triu(np_arr, diagonal)\n    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06))\n    output = output.sum()\n    output.backward()\n    np_grad = np.triu(np.ones(shape=arr_shape, dtype=np_dtype), diagonal)\n    test_case.assertTrue(np.allclose(input_tensor.grad.numpy(), np_grad, 1e-06, 1e-06))\n\n\ndef _test_triu_(test_case, diagonal, device, dtype):\n    arr_shape = (4, 4, 8)\n    flow_dtype, np_dtype = dtype\n    np_arr = np.random.randn(*arr_shape).astype(np_dtype)\n    input = flow.tensor(np_arr, dtype=flow_dtype, device=flow.device(device))\n    np_out = np.triu(np_arr, diagonal)\n    test_case.assertFalse(np.allclose(input.numpy(), np_out))\n    input.triu_(diagonal=diagonal)\n    test_case.assertTrue(np.allclose(input.numpy(), np_out))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTriu(flow.unittest.TestCase):\n    def test_triu(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_triu, _test_triu_]\n        arg_dict[\"diagonal\"] = [2, -1]\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"dtype\"] = [(flow.float32, np.float32), (flow.float16, np.float16)]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest()\n    def test_triu_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 1, 0, 3).to(device)\n        y = torch.triu(x)\n        return y\n\n    @autotest()\n    def test_triu_with_0_size_data_fp16(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 1, 0, 3).to(device, torch.float16)\n        y = torch.triu(x)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_trunc.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTrunc(flow.unittest.TestCase):\n    @autotest(n=5, check_graph=True)\n    def test_trunc(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4,\n            dim0=random(1, 5).to(int),\n            dim1=random(1, 5).to(int),\n            dim2=random(1, 5).to(int),\n            dim3=random(1, 5).to(int),\n        ).to(device)\n        y = torch.trunc(x)\n\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_trunc_divide.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport torch as torch_original\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTruncDivide(flow.unittest.TestCase):\n    @autotest(n=5, check_allclose=False, check_graph=True)\n    def test_elementwise_trunc_divide_random_data(test_case):\n        device = random_device()\n        dim0 = random(1, 8)\n        dim1 = random(1, 8)\n        dim2 = random(1, 8)\n        dim3 = random(1, 8)\n        x = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=dim2, dim3=dim3).to(device)\n        y = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=dim2, dim3=dim3).to(device)\n\n        x.oneflow = x.oneflow.detach().requires_grad_()\n        x.pytorch = x.pytorch.detach().requires_grad_()\n        y.oneflow = y.oneflow.detach().requires_grad_()\n        y.pytorch = y.pytorch.detach().requires_grad_()\n\n        oneflow_out = flow._C.trunc_divide(x.oneflow, y.oneflow)\n        torch_out = torch_original.div(x.pytorch, y.pytorch, rounding_mode=\"trunc\")\n\n        test_case.assertTrue(\n            np.allclose(\n                oneflow_out.detach().cpu().numpy(),\n                torch_out.detach().cpu().numpy(),\n                rtol=0.0001,\n                atol=1e-05,\n            )\n        )\n\n        oneflow_out.sum().backward()\n        torch_out.sum().backward()\n\n        test_case.assertTrue(\n            np.allclose(\n                x.oneflow.grad.detach().cpu().numpy(),\n                x.pytorch.grad.detach().cpu().numpy(),\n                rtol=0.0001,\n                atol=1e-05,\n            )\n        )\n        test_case.assertTrue(\n            np.allclose(\n                y.oneflow.grad.detach().cpu().numpy(),\n                y.pytorch.grad.detach().cpu().numpy(),\n                rtol=0.0001,\n                atol=1e-05,\n            )\n        )\n\n    @autotest(n=5, check_allclose=False, check_graph=True)\n    def test_tensor_truncdiv_scalar_random_data(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4,\n            dim0=random(1, 8),\n            dim1=random(1, 8),\n            dim2=random(1, 8),\n            dim3=random(1, 8),\n        ).to(device)\n        x.oneflow = x.oneflow.detach().requires_grad_()\n        x.pytorch = x.pytorch.detach().requires_grad_()\n\n        scalar = random().to(float).value()\n\n        oneflow_out = oneflow._C.trunc_divide(x.oneflow, scalar)\n        torch_out = torch_original.div(x.pytorch, scalar, rounding_mode=\"trunc\")\n\n        test_case.assertTrue(\n            np.allclose(\n                oneflow_out.detach().cpu().numpy(),\n                torch_out.detach().cpu().numpy(),\n                rtol=0.0001,\n                atol=1e-5,\n            )\n        )\n\n        oneflow_out.sum().backward()\n        
torch_out.sum().backward()\n\n        test_case.assertTrue(\n            np.allclose(\n                x.oneflow.grad.detach().cpu().numpy(),\n                x.pytorch.grad.detach().cpu().numpy(),\n                rtol=0.0001,\n                atol=1e-5,\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_type_tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport random\nimport unittest\nimport os\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ntype_tensor_all = [\n    {\n        \"cpu_interface\": flow.HalfTensor,\n        \"cuda_interface\": flow.cuda.HalfTensor,\n        \"dtype\": flow.float16,\n    },\n    {\n        \"cpu_interface\": flow.FloatTensor,\n        \"cuda_interface\": flow.cuda.FloatTensor,\n        \"dtype\": flow.float32,\n    },\n    {\n        \"cpu_interface\": flow.DoubleTensor,\n        \"cuda_interface\": flow.cuda.DoubleTensor,\n        \"dtype\": flow.float64,\n    },\n    {\n        \"cpu_interface\": flow.BoolTensor,\n        \"cuda_interface\": flow.cuda.BoolTensor,\n        \"dtype\": flow.bool,\n    },\n    {\n        \"cpu_interface\": flow.ByteTensor,\n        \"cuda_interface\": flow.cuda.ByteTensor,\n        \"dtype\": flow.uint8,\n    },\n    {\n        \"cpu_interface\": flow.CharTensor,\n        \"cuda_interface\": flow.cuda.CharTensor,\n        \"dtype\": flow.int8,\n    },\n    {\n        \"cpu_interface\": flow.IntTensor,\n        \"cuda_interface\": flow.cuda.IntTensor,\n        \"dtype\": flow.int32,\n    },\n    {\n        \"cpu_interface\": flow.LongTensor,\n        \"cuda_interface\": flow.cuda.LongTensor,\n        \"dtype\": flow.int64,\n    },\n    # TODO: flow.BFloat16Tensor fails to creat Tensor.\n    # {\"cpu_interface\": flow.BFloat16Tensor, \"cuda_interface\": flow.cuda.BFloat16Tensor, \"dtype\": flow.bfloat16},\n]\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestTypeTensor(flow.unittest.TestCase):\n    def test_type_tensor(test_case):\n        for type_tensor_case in type_tensor_all:\n            x = type_tensor_case[\"cpu_interface\"](np.random.randn(2, 3, 4, 5))\n            test_case.assertEqual(x.device, flow.device(\"cpu\"))\n            test_case.assertEqual(x.dtype, type_tensor_case[\"dtype\"])\n            test_case.assertEqual(x.shape, (2, 3, 4, 5))\n            test_case.assertFalse(x.requires_grad)\n            test_case.assertTrue(x.is_leaf)\n            y = type_tensor_case[\"cuda_interface\"](np.random.randn(2, 3, 4, 5))\n            test_case.assertEqual(y.device, flow.device(\"cuda\"))\n            test_case.assertEqual(y.dtype, type_tensor_case[\"dtype\"])\n            test_case.assertEqual(y.shape, (2, 3, 4, 5))\n            test_case.assertFalse(y.requires_grad)\n            test_case.assertTrue(y.is_leaf)\n\n    def test_doubletensor_corner_cases(test_case):\n        corner_cases = [random.randint(1 << 24, 1 << 25) for _ in range(20)]\n        test_case.assertTrue(\n            np.allclose(\n                flow.DoubleTensor(corner_cases).numpy(),\n                np.array(corner_cases, dtype=np.float64),\n                1e-6,\n                1e-6,\n            )\n        )\n\n    def test_type_tensor_ctor(test_case):\n        for tensor_type in type_tensor_all:\n     
       cpu_type = tensor_type[\"cpu_interface\"]\n            cuda_type = tensor_type[\"cuda_interface\"]\n\n            # empty ctor\n            cpu_type_tensor = cpu_type()\n            cuda_type_tensor = cuda_type()\n            test_case.assertEqual(cpu_type_tensor.dtype, tensor_type[\"dtype\"])\n            test_case.assertEqual(cpu_type_tensor.device, flow.device(\"cpu\"))\n            test_case.assertEqual(cuda_type_tensor.dtype, tensor_type[\"dtype\"])\n            test_case.assertEqual(cuda_type_tensor.device, flow.device(\"cuda\"))\n\n            # other ctor\n            other_tensor = flow.Tensor(flow.Size([2, 3, 4, 5]))\n            cpu_type_tensor = cpu_type(other_tensor)\n            cuda_type_tensor = cuda_type(other_tensor)\n            test_case.assertEqual(cpu_type_tensor.dtype, tensor_type[\"dtype\"])\n            test_case.assertEqual(cpu_type_tensor.device, flow.device(\"cpu\"))\n            test_case.assertEqual(cuda_type_tensor.dtype, tensor_type[\"dtype\"])\n            test_case.assertEqual(cuda_type_tensor.device, flow.device(\"cuda\"))\n\n            # data ctor\n            # numpy inputs have been tested above in test_type_tensor\n            data = [random.random() for i in range(20)]\n            cpu_type_tensor = cpu_type(data)\n            cuda_type_tensor = cuda_type(data)\n            test_case.assertEqual(cpu_type_tensor.dtype, tensor_type[\"dtype\"])\n            test_case.assertEqual(cpu_type_tensor.device, flow.device(\"cpu\"))\n            test_case.assertEqual(cuda_type_tensor.dtype, tensor_type[\"dtype\"])\n            test_case.assertEqual(cuda_type_tensor.device, flow.device(\"cuda\"))\n\n            # shape ctor\n            shape = flow.Size([2, 3, 4, 5])\n            cpu_type_tensor = cpu_type(shape)\n            cuda_type_tensor = cuda_type(shape)\n            test_case.assertEqual(cpu_type_tensor.dtype, tensor_type[\"dtype\"])\n            test_case.assertEqual(cpu_type_tensor.device, flow.device(\"cpu\"))\n            test_case.assertEqual(cuda_type_tensor.dtype, tensor_type[\"dtype\"])\n            test_case.assertEqual(cuda_type_tensor.device, flow.device(\"cuda\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_unbind.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestUnbind(flow.unittest.TestCase):\n    @autotest(n=5, check_graph=True)\n    def test_unbind_flow_with_random_data1(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = torch.unbind(x, random(0, 4).to(int))\n        return y\n\n    @autotest(n=5, check_graph=True)\n    def test_unbind_flow_with_random_data2(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = torch.unbind(x, random(0, 4).to(int))\n        return y\n\n    @autotest(n=5, check_graph=True)\n    def test_unbind_flow_with_random_data3(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3).to(device)\n        y = torch.unbind(x, random(0, 3).to(int))\n        return y\n\n    @autotest(n=5, check_graph=True)\n    def test_unbind_flow_with_random_data4(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3).to(device)\n        y = torch.unbind(x, random(0, 3).to(int))\n        return y\n\n    @autotest(n=5, check_graph=True)\n    def test_unbind_flow_with_random_data5(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2).to(device)\n        y = torch.unbind(x, random(0, 2).to(int))\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_unfold.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.nn.common_types import _size_2_t\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestUnfold(flow.unittest.TestCase):\n    @autotest(n=50, auto_backward=True, rtol=1e-4, atol=1e-4)\n    def test_unfold_with_random_data(test_case):\n        m = torch.nn.Unfold(\n            kernel_size=random(1, 3).to(_size_2_t),\n            dilation=random(1, 2).to(_size_2_t) | nothing(),\n            padding=random(0, 1).to(_size_2_t) | nothing(),\n            stride=random(1, 2).to(_size_2_t) | nothing(),\n        )\n        m.train(random())\n        device = random_device()\n        m.to(device)\n        x = random_tensor(\n            ndim=4,\n            dim0=random(1, 5),\n            dim1=random(1, 5),\n            dim2=random(10, 20),\n            dim3=random(10, 20),\n        ).to(device)\n        y = m(x)\n        func_y = torch.nn.functional.unfold(\n            x,\n            kernel_size=random(1, 3).to(_size_2_t),\n            dilation=random(1, 2).to(_size_2_t) | nothing(),\n            padding=random(0, 1).to(_size_2_t) | nothing(),\n            stride=random(1, 2).to(_size_2_t) | nothing(),\n        )\n        return y, func_y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_unfold_tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport numpy as np\nfrom random import shuffle\n\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestUnfoldTensor(flow.unittest.TestCase):\n    @autotest(n=10, auto_backward=True, check_graph=True)\n    def test_unfold_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(3, 3, 4, 5).to(device)\n        dimension = random(0, 2).to(int).value()\n        size = random(1, 3).to(int).value()\n        step = random(1, 3).to(int).value()\n        y = x.unfold(dimension, size, step)\n        return y\n\n    @autotest(n=5)\n    def test_unfold_tensor_with_stride(test_case):\n        device = random_device()\n        x = random_tensor(3, 3, 4, 5).to(device)\n        perm = [0, 1, 2]\n        shuffle(perm)\n        y = x.permute(perm)\n        dimension = random(0, 2).to(int).value()\n        size = random(1, 3).to(int).value()\n        step = random(1, 3).to(int).value()\n        z = y.unfold(dimension, size, step)\n        return z\n\n    @autotest(n=10, auto_backward=True, check_graph=True)\n    def test_unfold_tensor_with_0dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        dimension = random(0, 2).to(int).value()\n        size = random(1, 3).to(int).value()\n        step = random(1, 3).to(int).value()\n        y = x.unfold(dimension, size, step)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_unique.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport random as random_util\nimport torch as torch_ori\nfrom collections import OrderedDict\n\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_unique_unsorted(test_case, device, return_inverse, return_counts):\n    dtype = random_util.choice([torch.int8, torch.int, torch.float, torch.double])\n    input = random_tensor(ndim=3, dim0=random(), dim1=random(), dim2=random(), high=20)\n    input = input.to(device).to(dtype)\n    oneflow_output = flow.unique(\n        input.oneflow,\n        sorted=False,\n        return_inverse=return_inverse,\n        return_counts=return_counts,\n    )\n    torch_output = torch_ori.unique(\n        input.pytorch,\n        sorted=False,\n        return_inverse=return_inverse,\n        return_counts=return_counts,\n    )\n\n    if not return_inverse and not return_counts:\n        oneflow_result = oneflow_output\n        torch_result = torch_output\n    else:\n        oneflow_result = oneflow_output[0]\n        torch_result = torch_output[0]\n\n    test_case.assertTrue(\n        np.allclose(\n            np.sort(oneflow_result.numpy()),\n            np.sort(torch_result.detach().cpu().numpy()),\n        )\n    )\n    test_case.assertEqual(list(oneflow_result.shape), list(torch_result.shape))\n\n    if return_inverse:\n        oneflow_indices = oneflow_output[1]\n        torch_indices = torch_output[1]\n        test_case.assertTrue(\n            np.allclose(\n                oneflow_result[oneflow_indices].numpy(),\n                torch_result[torch_indices].detach().cpu().numpy(),\n            )\n        )\n        test_case.assertEqual(list(oneflow_indices.shape), list(torch_indices.shape))\n\n    if return_counts:\n        oneflow_counts = oneflow_output[-1]\n        torch_counts = torch_output[-1]\n        test_case.assertTrue(\n            np.allclose(\n                oneflow_counts.numpy()[np.argsort(oneflow_result.numpy())],\n                torch_counts.detach()\n                .cpu()\n                .numpy()[np.argsort(torch_result.detach().cpu().numpy())],\n            )\n        )\n        test_case.assertEqual(list(oneflow_counts.shape), list(torch_counts.shape))\n\n\ndef _test_unique_sorted(test_case, device, return_inverse, return_counts):\n    dtype = random_util.choice([torch.int8, torch.int, torch.float, torch.double])\n    input = random_tensor(ndim=3, dim0=random(), dim1=random(), dim2=random(), high=20)\n    input = input.to(device).to(dtype)\n    oneflow_output = flow.unique(\n        input.oneflow,\n        sorted=True,\n        return_inverse=return_inverse,\n        return_counts=return_counts,\n    )\n    torch_output = torch_ori.unique(\n        input.pytorch,\n        sorted=True,\n        return_inverse=return_inverse,\n        
return_counts=return_counts,\n    )\n\n    if not return_inverse and not return_counts:\n        oneflow_result = oneflow_output\n        torch_result = torch_output\n    else:\n        oneflow_result = oneflow_output[0]\n        torch_result = torch_output[0]\n\n    test_case.assertTrue(\n        np.allclose(oneflow_result.numpy(), torch_result.detach().cpu().numpy(),)\n    )\n    test_case.assertEqual(list(oneflow_result.shape), list(torch_result.shape))\n\n    if return_inverse:\n        oneflow_indices = oneflow_output[1]\n        torch_indices = torch_output[1]\n        test_case.assertTrue(\n            np.allclose(oneflow_indices.numpy(), torch_indices.detach().cpu().numpy(),)\n        )\n        test_case.assertEqual(list(oneflow_indices.shape), list(torch_indices.shape))\n\n    if return_counts:\n        oneflow_counts = oneflow_output[-1]\n        torch_counts = torch_output[-1]\n        test_case.assertTrue(\n            np.allclose(oneflow_counts.numpy(), torch_counts.detach().cpu().numpy(),)\n        )\n        test_case.assertEqual(list(oneflow_counts.shape), list(torch_counts.shape))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestUnique(flow.unittest.TestCase):\n    @autotest(n=5)\n    def test_unique(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"return_inverse\"] = [False, True]\n        arg_dict[\"return_counts\"] = [False, True]\n        for arg in GenArgList(arg_dict):\n            _test_unique_unsorted(test_case, *arg)\n            _test_unique_sorted(test_case, *arg)\n\n    @profile(torch.unique)\n    def profile_unique(test_case):\n        input = torch.randint(0, 1000, (1000,))\n        torch.unique(input)\n        torch.unique(input, return_inverse=True, return_counts=True)\n        input = torch.randn(1000,)\n        torch.unique(input)\n        torch.unique(input, return_inverse=True, return_counts=True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_unsqueeze.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_unsqueeze(test_case, device):\n    np_arr = np.random.rand(2, 6, 9, 3)\n    x = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    y = flow.unsqueeze(x, dim=1)\n    output = np.expand_dims(np_arr, axis=1)\n    test_case.assertTrue(np.allclose(output, y.numpy(), 1e-05, 1e-05))\n    x_flow = flow.randn(5)\n    x_flow = flow.unsqueeze(x_flow, 0)\n    test_case.assertTrue(np.array_equal(x_flow.stride(), (5, 1)))\n    x_flow = flow.randn(5, 2)\n    x_flow = flow.unsqueeze(x_flow, 0)\n    test_case.assertTrue(np.array_equal(x_flow.stride(), (10, 2, 1)))\n\n\ndef _test_unsqueeze_tensor_function(test_case, device):\n    np_arr = np.random.rand(2, 3, 4)\n    x = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    y = x.unsqueeze(dim=2)\n    output = np.expand_dims(np_arr, axis=2)\n    test_case.assertTrue(np.allclose(output, y.numpy(), 1e-05, 1e-05))\n\n\ndef _test_unsqueeze_different_dim(test_case, device):\n    np_arr = np.random.rand(4, 5, 6, 7)\n    x = flow.tensor(np_arr, dtype=flow.float32, device=flow.device(device))\n    for axis in range(-5, 5):\n        y = flow.unsqueeze(x, dim=axis)\n        output = np.expand_dims(np_arr, axis=axis)\n        test_case.assertTrue(np.allclose(output, y.numpy(), 1e-05, 1e-05))\n\n\ndef _test_unsqueeze_backward(test_case, device):\n    np_arr = np.random.rand(2, 3, 4, 5)\n    x = flow.tensor(\n        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    y = flow.unsqueeze(x, dim=1).sum()\n    y.backward()\n    test_case.assertTrue(\n        np.allclose(x.grad.numpy(), np.ones((2, 3, 4, 5)), 1e-05, 1e-05)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestUnsqueeze(flow.unittest.TestCase):\n    def test_unsqueeze(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_unsqueeze,\n            _test_unsqueeze_tensor_function,\n            _test_unsqueeze_different_dim,\n            _test_unsqueeze_backward,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(check_graph=True)\n    def test_flow_unsqueeze_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.unsqueeze(x, random(1, 3).to(int))\n        return y\n\n    @autotest(n=10, check_graph=False, auto_backward=False)\n    def test_inplace_unsqueeze_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(requires_grad=False).to(device)\n        y = x.unsqueeze_(random(1, 3).to(int))\n        return y\n\n   
 @autotest(auto_backward=False, check_graph=True)\n    def test_unsqueeze_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(3, 2, 1, 0).to(device)\n        y = torch.unsqueeze(x, random(0, 2).to(int))\n        return y\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_flow_unsqueeze_bool_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device=device, dtype=torch.bool)\n        y = torch.unsqueeze(x, random(1, 3).to(int))\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_upsample.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_upsample2d_bilinear(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 5).reshape((1, 1, 2, 2)),\n        device=flow.device(device),\n        dtype=flow.float32,\n    )\n    m = flow.nn.Upsample(scale_factor=2.0, mode=\"bilinear\")\n    of_out = m(input)\n    np_out = np.array(\n        [\n            [\n                [\n                    [1.0, 1.25, 1.75, 2.0],\n                    [1.5, 1.75, 2.25, 2.5],\n                    [2.5, 2.75, 3.25, 3.5],\n                    [3.0, 3.25, 3.75, 4.0],\n                ]\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_upsample2d_bilinear_aligncorner(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 5).reshape((1, 1, 2, 2)),\n        device=flow.device(device),\n        dtype=flow.float32,\n    )\n    m = flow.nn.Upsample(scale_factor=2.0, mode=\"bilinear\", align_corners=True)\n    of_out = m(input)\n    np_out = np.array(\n        [\n            [\n                [\n                    [1.0, 1.3333, 1.6667, 2.0],\n                    [1.6667, 2.0, 2.3333, 2.6667],\n                    [2.3333, 2.6667, 3.0, 3.3333],\n                    [3.0, 3.3333, 3.6667, 4.0],\n                ]\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n\n\ndef _test_UpsamplingNearest2d(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 5).reshape((1, 1, 2, 2)),\n        device=flow.device(device),\n        dtype=flow.float32,\n    )\n    m = flow.nn.UpsamplingNearest2d(scale_factor=2.0)\n    of_out = m(input)\n    np_out = np.array(\n        [\n            [\n                [\n                    [1.0, 1.0, 2.0, 2.0],\n                    [1.0, 1.0, 2.0, 2.0],\n                    [3.0, 3.0, 4.0, 4.0],\n                    [3.0, 3.0, 4.0, 4.0],\n                ]\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_UpsamplingBilinear2d(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 5).reshape((1, 1, 2, 2)),\n        device=flow.device(device),\n        dtype=flow.float32,\n    )\n    m = flow.nn.UpsamplingBilinear2d(scale_factor=2.0)\n    of_out = m(input)\n    np_out = np.array(\n        [\n            [\n                [\n                    [1.0, 1.3333, 1.6667, 2.0],\n                    [1.6667, 2.0, 2.3333, 2.6667],\n                    [2.3333, 2.6667, 3.0, 3.3333],\n                    [3.0, 3.3333, 3.6667, 4.0],\n                ]\n            ]\n        ]\n    )\n    
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n\n\ndef _test_upsample2d_4dim(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 37).reshape((2, 2, 3, 3)),\n        device=flow.device(device),\n        dtype=flow.float32,\n    )\n    m = flow.nn.Upsample(scale_factor=2.0, mode=\"nearest\")\n    of_out = m(input)\n    np_out = np.array(\n        [\n            [\n                [\n                    [1.0, 1.0, 2.0, 2.0, 3.0, 3.0],\n                    [1.0, 1.0, 2.0, 2.0, 3.0, 3.0],\n                    [4.0, 4.0, 5.0, 5.0, 6.0, 6.0],\n                    [4.0, 4.0, 5.0, 5.0, 6.0, 6.0],\n                    [7.0, 7.0, 8.0, 8.0, 9.0, 9.0],\n                    [7.0, 7.0, 8.0, 8.0, 9.0, 9.0],\n                ],\n                [\n                    [10.0, 10.0, 11.0, 11.0, 12.0, 12.0],\n                    [10.0, 10.0, 11.0, 11.0, 12.0, 12.0],\n                    [13.0, 13.0, 14.0, 14.0, 15.0, 15.0],\n                    [13.0, 13.0, 14.0, 14.0, 15.0, 15.0],\n                    [16.0, 16.0, 17.0, 17.0, 18.0, 18.0],\n                    [16.0, 16.0, 17.0, 17.0, 18.0, 18.0],\n                ],\n            ],\n            [\n                [\n                    [19.0, 19.0, 20.0, 20.0, 21.0, 21.0],\n                    [19.0, 19.0, 20.0, 20.0, 21.0, 21.0],\n                    [22.0, 22.0, 23.0, 23.0, 24.0, 24.0],\n                    [22.0, 22.0, 23.0, 23.0, 24.0, 24.0],\n                    [25.0, 25.0, 26.0, 26.0, 27.0, 27.0],\n                    [25.0, 25.0, 26.0, 26.0, 27.0, 27.0],\n                ],\n                [\n                    [28.0, 28.0, 29.0, 29.0, 30.0, 30.0],\n                    [28.0, 28.0, 29.0, 29.0, 30.0, 30.0],\n                    [31.0, 31.0, 32.0, 32.0, 33.0, 33.0],\n                    [31.0, 31.0, 32.0, 32.0, 33.0, 33.0],\n                    [34.0, 34.0, 35.0, 35.0, 36.0, 36.0],\n                    [34.0, 34.0, 35.0, 35.0, 36.0, 36.0],\n                ],\n            ],\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_upsample2d_bilinear_4dim(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 37).reshape((2, 2, 3, 3)),\n        device=flow.device(device),\n        dtype=flow.float32,\n    )\n    m = flow.nn.Upsample(scale_factor=2.0, mode=\"bilinear\")\n    of_out = m(input)\n    np_out = np.array(\n        [\n            [\n                [\n                    [1.0, 1.25, 1.75, 2.25, 2.75, 3.0],\n                    [1.75, 2.0, 2.5, 3.0, 3.5, 3.75],\n                    [3.25, 3.5, 4.0, 4.5, 5.0, 5.25],\n                    [4.75, 5.0, 5.5, 6.0, 6.5, 6.75],\n                    [6.25, 6.5, 7.0, 7.5, 8.0, 8.25],\n                    [7.0, 7.25, 7.75, 8.25, 8.75, 9.0],\n                ],\n                [\n                    [10.0, 10.25, 10.75, 11.25, 11.75, 12.0],\n                    [10.75, 11.0, 11.5, 12.0, 12.5, 12.75],\n                    [12.25, 12.5, 13.0, 13.5, 14.0, 14.25],\n                    [13.75, 14.0, 14.5, 15.0, 15.5, 15.75],\n                    [15.25, 15.5, 16.0, 16.5, 17.0, 17.25],\n                    [16.0, 16.25, 16.75, 17.25, 17.75, 18.0],\n                ],\n            ],\n            [\n                [\n                    [19.0, 19.25, 19.75, 20.25, 20.75, 21.0],\n                    [19.75, 20.0, 20.5, 21.0, 21.5, 21.75],\n                    [21.25, 21.5, 22.0, 22.5, 23.0, 23.25],\n                    [22.75, 23.0, 23.5, 24.0, 24.5, 24.75],\n                    [24.25, 24.5, 25.0, 25.5, 
26.0, 26.25],\n                    [25.0, 25.25, 25.75, 26.25, 26.75, 27.0],\n                ],\n                [\n                    [28.0, 28.25, 28.75, 29.25, 29.75, 30.0],\n                    [28.75, 29.0, 29.5, 30.0, 30.5, 30.75],\n                    [30.25, 30.5, 31.0, 31.5, 32.0, 32.25],\n                    [31.75, 32.0, 32.5, 33.0, 33.5, 33.75],\n                    [33.25, 33.5, 34.0, 34.5, 35.0, 35.25],\n                    [34.0, 34.25, 34.75, 35.25, 35.75, 36.0],\n                ],\n            ],\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_upsample2d_backward(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 5).reshape((1, 1, 2, 2)),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    m = flow.nn.Upsample(scale_factor=2.0, mode=\"nearest\")\n    of_out = m(input)\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = [[[[4.0, 4.0], [4.0, 4.0]]]]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\ndef _test_upsample2d_bilinear_aligncorner_backward(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 5).reshape((1, 1, 2, 2)),\n        device=flow.device(device),\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    m = flow.nn.Upsample(scale_factor=2.0, mode=\"bilinear\", align_corners=True)\n    of_out = m(input)\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = [[[[3.999999523162842, 4.000000476837158], [3.999999761581421, 4.0]]]]\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\ndef _test_interpolate_nearest_float_scale(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 10).reshape((1, 1, 3, 3)),\n        device=flow.device(device),\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    m = flow.nn.Upsample(scale_factor=1.5)\n    of_out = m(input)\n    np_out = np.array(\n        [\n            [\n                [\n                    [1.0, 1.0, 2.0, 3.0],\n                    [1.0, 1.0, 2.0, 3.0],\n                    [4.0, 4.0, 5.0, 6.0],\n                    [7.0, 7.0, 8.0, 9.0],\n                ]\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = np.array([[[[4.0, 2.0, 2.0], [2.0, 1.0, 1.0], [2.0, 1.0, 1.0]]]])\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\ndef _test_interpolate_bilinear_float_scale(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 5, dtype=np.int32).reshape((1, 1, 2, 2)),\n        device=flow.device(device),\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    m = flow.nn.Upsample(scale_factor=0.5, mode=\"bilinear\")\n    of_out = m(input)\n    np_out = np.array([[[[2.5]]]])\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = np.array([[[[0.25, 0.25], [0.25, 0.25]]]])\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n    input = flow.tensor(\n        np.arange(1, 10, dtype=np.int32).reshape((1, 1, 3, 3)),\n        device=flow.device(device),\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    m = flow.nn.Upsample(scale_factor=0.5, mode=\"bilinear\")\n    of_out = m(input)\n    np_out = np.array([[[[3.0]]]])\n    
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = np.array([[[[0.25, 0.25, 0.0], [0.25, 0.25, 0.0], [0.0, 0.0, 0.0]]]])\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n    input = flow.tensor(\n        np.arange(1, 11, dtype=np.int32).reshape((1, 1, 5, 2)),\n        device=flow.device(device),\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    m = flow.nn.Upsample(size=(4, 4), mode=\"bilinear\")\n    of_out = m(input)\n    np_out = np.array(\n        [\n            [\n                [\n                    [1.25, 1.5, 2.0, 2.25],\n                    [3.75, 4.0, 4.5, 4.75],\n                    [6.25, 6.5, 7.0, 7.25],\n                    [8.75, 9.0, 9.5, 9.75],\n                ]\n            ]\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = np.array(\n        [[[[1.75, 1.75], [1.5, 1.5], [1.5, 1.5], [1.5, 1.5], [1.75, 1.75]]]]\n    )\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\ndef _test_upsample_bilinear_align_corners(test_case, device):\n    input = flow.tensor(\n        np.arange(1, 5, dtype=np.int32).reshape((1, 1, 2, 2)),\n        device=flow.device(device),\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    m = flow.nn.Upsample(scale_factor=0.5, mode=\"bilinear\", align_corners=True)\n    of_out = m(input)\n    np_out = np.array([[[[1.0]]]])\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = np.array([[[[1.0, 0.0], [0.0, 0.0]]]])\n    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestUpsample2d(flow.unittest.TestCase):\n    def test_upsample2d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_upsample2d_bilinear,\n            _test_upsample2d_bilinear_aligncorner,\n            _test_UpsamplingNearest2d,\n            _test_UpsamplingBilinear2d,\n            _test_upsample2d_4dim,\n            _test_upsample2d_bilinear_4dim,\n            _test_upsample2d_backward,\n            _test_upsample2d_bilinear_aligncorner_backward,\n            _test_interpolate_nearest_float_scale,\n            _test_interpolate_bilinear_float_scale,\n            _test_upsample_bilinear_align_corners,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @unittest.skip(\n        \"The nearest interpolate operation in pytorch has bug, https://github.com/pytorch/pytorch/issues/65200\"\n    )\n    @autotest()\n    def test_upsample2d_nearest(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        m = torch.nn.Upsample(scale_factor=random().to(float), mode=\"nearest\")\n        y = m(x)\n        return y\n\n    @unittest.skip(\n        \"The nearest interpolate operation in pytorch has bug, https://github.com/pytorch/pytorch/issues/65200\"\n    )\n    @autotest()\n    def test_upsample2d_nearest_half(test_case):\n        device = random_device()\n        x = random_tensor().to(device=device, dtype=torch.float16)\n        m = torch.nn.Upsample(scale_factor=random().to(float), mode=\"nearest\")\n        y = m(x)\n        return y\n\n    # The forward and 
backward results of PyTorch's bilinear interpolation differ between cpu and cuda\n    # in some corner cases. OneFlow matches PyTorch's cuda result on both devices.\n    # So here we only test the forward result on the cuda device.\n    @autotest(n=10, auto_backward=False, atol=1e-8)\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_upsample2d_bilinear(test_case):\n        x = random_tensor(ndim=4).to(\"cuda\")\n        x = x.permute(1, 3, 0, 2)\n        m = torch.nn.Upsample(\n            scale_factor=random().to(float),\n            mode=\"bilinear\",\n            align_corners=random_bool(),\n        )\n        y = m(x)\n        return y\n\n    @autotest(atol=1e-5)\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_upsample2d_bicubic(test_case):\n        x = random_tensor(ndim=4, dim0=16, dim1=8).to(\"cuda\")\n        m = torch.nn.Upsample(\n            scale_factor=random().to(float),\n            mode=\"bicubic\",\n            align_corners=random_bool(),\n        )\n        y = m(x)\n        return y\n\n    @autotest(n=5, atol=1e-5)\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_upsample1d_nearest_output_size(test_case):\n        x = random_tensor(ndim=3, dim0=1, dim1=2, dim2=12).to(\"cuda\")\n        m = torch.nn.Upsample(size=(13), mode=\"nearest\")\n        y = m(x)\n        return y\n\n    @autotest(n=5, atol=1e-5)\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_upsample2d_nearest_output_size(test_case):\n        x = random_tensor(ndim=4, dim0=1, dim1=1, dim2=1, dim3=937).to(\"cuda\")\n        m = torch.nn.Upsample(size=(1, 30), mode=\"nearest\")\n        y = m(x)\n        return y\n\n    @autotest(n=5, atol=1e-5)\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_upsample3d_nearest_output_size(test_case):\n        x = random_tensor(ndim=5, dim0=1, dim1=1, dim2=6, dim3=12, dim4=6).to(\"cuda\")\n        m = torch.nn.Upsample(size=(8, 10, 7), mode=\"nearest\")\n        y = m(x)\n        return y\n\n    @autotest(n=5, atol=1e-5)\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_upsample1d_linear_output_size(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=1, dim1=2, dim2=12).to(device)\n        m = torch.nn.Upsample(size=(13), mode=\"linear\")\n        y = m(x)\n        return y\n\n    @autotest(n=5, atol=1e-5)\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_upsample2d_bilinear_output_size(test_case):\n        x = random_tensor(ndim=4, dim0=1, dim1=1, dim2=12, dim3=21).to(\"cuda\")\n        m = torch.nn.Upsample(size=(14, 19), mode=\"bilinear\")\n        y = m(x)\n        return y\n\n    @autotest(n=5, atol=1e-5)\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_upsample2d_bicubic_output_size(test_case):\n        x = random_tensor(ndim=4, dim0=1, dim1=2, dim2=12, dim3=21).to(\"cuda\")\n        m = torch.nn.Upsample(size=(14, 19), mode=\"bicubic\")\n        y = m(x)\n        return y\n\n    @autotest(n=5, atol=1e-5)\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_upsample3d_trilinear_output_size(test_case):\n        x = random_tensor(ndim=5, dim0=1, dim1=2, dim2=1, dim3=12, 
dim4=17).to(\"cuda\")\n        m = torch.nn.Upsample(size=(1, 14, 23), mode=\"trilinear\")\n        y = m(x)\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_util_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nimport oneflow as flow\nfrom collections import OrderedDict\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\n# TODO(): random_tensor can't generate a tensor with nan or inf element.\ndef _test_isnan(test_case, shape, dtype, device):\n    np_array = np.random.randn(*shape)\n    mask = np.random.choice([1, 0], np_array.shape, p=[0.1, 0.9]).astype(bool)\n    np_array[mask] = np.nan\n    of_tensor = flow.tensor(np_array, dtype=dtype, device=device)\n    res = flow.isnan(of_tensor)\n    test_case.assertTrue(np.allclose(res.numpy(), np.isnan(of_tensor.numpy())))\n\n\ndef _test_isinf(test_case, shape, dtype, device):\n    np_array = np.random.randn(*shape)\n    mask = np.random.choice([1, 0], np_array.shape, p=[0.1, 0.9]).astype(bool)\n    np_array[mask] = np.inf\n    of_tensor = flow.tensor(np_array, dtype=dtype, device=device)\n    res = flow.isinf(of_tensor)\n    test_case.assertTrue(np.allclose(res.numpy(), np.isinf(of_tensor.numpy())))\n\n\ndef _test_isfinite(test_case, shape, dtype, device):\n    np_array = np.random.randn(*shape)\n    inf_mask = np.random.choice([1, 0], np_array.shape, p=[0.1, 0.9]).astype(bool)\n    nan_mask = np.random.choice([1, 0], np_array.shape, p=[0.1, 0.9]).astype(bool)\n    np_array[inf_mask] = np.inf\n    np_array[nan_mask] = np.nan\n    of_tensor = flow.tensor(np_array, dtype=dtype, device=device)\n    res = flow.isfinite(of_tensor)\n    test_case.assertTrue(np.allclose(res.numpy(), np.isfinite(of_tensor.numpy())))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestUtilOps(flow.unittest.TestCase):\n    def test_util_ops(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [_test_isnan, _test_isinf, _test_isfinite]\n        arg_dict[\"shape\"] = [(2, 3, 4), (1, 2, 3)]\n        arg_dict[\"dtype\"] = [flow.float, flow.int]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/modules/test_utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport torch\nfrom torch._utils import _flatten_dense_tensors as torch_flatten_dense_tensors\nfrom torch._utils import _unflatten_dense_tensors as torch_unflatten_dense_tensors\nfrom oneflow._utils import _flatten_dense_tensors, _unflatten_dense_tensors\nfrom collections import OrderedDict\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_flatten_dense_tensors(test_case, device):\n    torch_x = torch.randn(6, 6, device=device)\n    x = flow.utils.tensor.from_torch(torch_x)\n    torch_x_flatten = torch_flatten_dense_tensors([torch_x])\n    x_flatten = _flatten_dense_tensors([x])\n    test_case.assertTrue(np.array_equal(torch_x_flatten.size(), x_flatten.size()))\n    torch_x_flatten = torch_flatten_dense_tensors([torch_x, torch_x, torch_x])\n    x_flatten = _flatten_dense_tensors([x, x, x])\n    test_case.assertTrue(np.array_equal(torch_x_flatten.size(), x_flatten.size()))\n    test_case.assertTrue(\n        np.allclose(\n            torch_x_flatten.cpu().numpy(), x_flatten.cpu().numpy(), 1e-05, 1e-05\n        )\n    )\n\n\ndef _test_unflatten_dense_tensors(test_case, device):\n    torch_flat = torch.randn(6, 1, device=device)\n    torch_x1 = torch.randn(2, 1, device=device)\n    torch_x2 = torch.randn(2, 1, device=device)\n    torch_x3 = torch.randn(2, 1, device=device)\n    torch_tensors = [\n        torch_x1,\n        torch_x2,\n        torch_x3,\n    ]\n    tensors = [\n        flow.utils.tensor.from_torch(torch_x1),\n        flow.utils.tensor.from_torch(torch_x2),\n        flow.utils.tensor.from_torch(torch_x3),\n    ]\n    torch_outputs = torch_unflatten_dense_tensors(torch_flat, torch_tensors)\n    outputs = _unflatten_dense_tensors(\n        flow.utils.tensor.from_torch(torch_flat), tensors\n    )\n    for i in range(len(outputs)):\n        test_case.assertTrue(np.array_equal(torch_outputs[i].size(), outputs[i].size()))\n        test_case.assertTrue(\n            np.allclose(\n                torch_outputs[i].cpu().numpy(), outputs[i].cpu().numpy(), 1e-05, 1e-05\n            )\n        )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestUtilsFunction(flow.unittest.TestCase):\n    def test_utils_function(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_flatten_dense_tensors,\n            _test_unflatten_dense_tensors,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
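For reference, the behavior these parity tests expect from the private `_flatten_dense_tensors` / `_unflatten_dense_tensors` helpers can be sketched in plain numpy: flatten-and-concatenate into one contiguous buffer, then carve the buffer back into chunks shaped like the reference tensors. This is a sketch of the expected semantics, not the actual implementation; `np_flatten` / `np_unflatten` are illustrative names.

```python
import numpy as np

def np_flatten(tensors):
    # Concatenate the flattened tensors into a single contiguous 1-D buffer.
    return np.concatenate([t.reshape(-1) for t in tensors])

def np_unflatten(flat, refs):
    # Slice the flat buffer back into chunks shaped like the reference tensors.
    outputs, offset = [], 0
    for ref in refs:
        outputs.append(flat[offset : offset + ref.size].reshape(ref.shape))
        offset += ref.size
    return outputs

a, b, c = (np.random.randn(2, 1) for _ in range(3))
flat = np_flatten([a, b, c])
assert flat.shape == (6,)
for ref, out in zip([a, b, c], np_unflatten(flat, [a, b, c])):
    assert out.shape == ref.shape and np.allclose(out, ref)
```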
  {
    "path": "python/oneflow/test/modules/test_var.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nfrom oneflow.test_utils.automated_test_util.generators import random\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestVar(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_flow_var_all_dim_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = torch.var(x)\n        return y\n\n    @autotest(check_graph=True)\n    def test_flow_var_one_dim_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = torch.var(\n            x,\n            dim=random(low=-4, high=4).to(int),\n            unbiased=random().to(bool),\n            keepdim=random().to(bool),\n        )\n        return y\n\n    # In fp16 mode, variance op backward has a gap of 1e-3 between the gradient of PyTorch\n    # and OneFlow for some unknown reason. However, it is not important now, because both in\n    # PyTorch and OneFlow variance op don't need support fp16 backward in amp train.\n    @autotest(n=5, auto_backward=True, check_graph=True, rtol=1e-3, atol=1e-3)\n    def test_flow_var_one_dim_with_random_half_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device).to(torch.float16)\n        y = torch.var(\n            x,\n            dim=random(low=-4, high=4).to(int),\n            unbiased=random().to(bool),\n            keepdim=random().to(bool),\n        )\n        return y\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_flow_var_0_size_data_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 3, 0, 4).to(device)\n        y = torch.var(\n            x,\n            dim=random(low=-4, high=4).to(int),\n            unbiased=random().to(bool),\n            keepdim=random().to(bool),\n        )\n        return y\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_var_0_size_data_with_random_half_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 3, 0, 4).to(device).to(torch.float16)\n        y = torch.var(\n            x,\n            dim=random(low=-4, high=4).to(int),\n            unbiased=random().to(bool),\n            keepdim=random().to(bool),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_flow_var_all_dim_with_random_data_n5(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dim0=5, dim1=1, dim2=16, dim3=16).to(device)\n        y = torch.var(x, dim=[0, 2, 3])\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
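The `unbiased` flag randomized in the tests above follows the usual Bessel-correction convention; a quick numpy check of the two estimators (purely illustrative):

```python
import numpy as np

x = np.random.randn(4, 5)
mean = x.mean(axis=1, keepdims=True)

# unbiased=False divides the squared deviations by N (numpy's ddof=0) ...
biased = ((x - mean) ** 2).sum(axis=1) / 5
assert np.allclose(biased, x.var(axis=1, ddof=0))

# ... while unbiased=True divides by N - 1 (numpy's ddof=1).
unbiased = ((x - mean) ** 2).sum(axis=1) / (5 - 1)
assert np.allclose(unbiased, x.var(axis=1, ddof=1))
```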
  {
    "path": "python/oneflow/test/modules/test_view.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import GenArgList\nfrom oneflow.test_utils.automated_test_util import *\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_view(test_case, device):\n    x = np.array(\n        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]\n    ).astype(np.float32)\n    input = flow.tensor(\n        x, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    of_out = input.view(2, 2, 2, -1)\n    of_shape = of_out.numpy().shape\n    np_shape = (2, 2, 2, 2)\n    test_case.assertTrue(np.array_equal(of_shape, np_shape))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = np.array(\n        [\n            [1.0, 1.0, 1.0, 1.0],\n            [1.0, 1.0, 1.0, 1.0],\n            [1.0, 1.0, 1.0, 1.0],\n            [1.0, 1.0, 1.0, 1.0],\n        ]\n    )\n    test_case.assertTrue(np.allclose(np_grad, input.grad.numpy(), 0.0001, 0.0001))\n\n\ndef _test_view_flow_size(test_case, device):\n    x = np.array(\n        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]\n    ).astype(np.float32)\n    input = flow.tensor(\n        x, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    shape = flow.Size([2, 2, 2, -1])\n    of_out = input.view(shape)\n    np_shape = (2, 2, 2, 2)\n    test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_shape))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_grad = np.array(\n        [\n            [1.0, 1.0, 1.0, 1.0],\n            [1.0, 1.0, 1.0, 1.0],\n            [1.0, 1.0, 1.0, 1.0],\n            [1.0, 1.0, 1.0, 1.0],\n        ]\n    )\n    test_case.assertTrue(np.allclose(np_grad, input.grad.numpy(), 0.0001, 0.0001))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestView(flow.unittest.TestCase):\n    # TODO:(zhaoluyang) add test case that trigger tensor.view's check\n    def test_view(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_view,\n            _test_view_flow_size,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5, check_graph=True)\n    def test_view_with_0_dim_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=0).to(device)\n        y1 = torch.reshape(x, shape=(-1,))\n        y2 = x.view((1, 1, 1))\n        test_case.assertTrue(x.oneflow.stride() == x.pytorch.stride())\n        test_case.assertTrue(y1.oneflow.stride() == y1.pytorch.stride())\n        test_case.assertTrue(y2.oneflow.stride() == y2.pytorch.stride())\n        return y2\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
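The shapes asserted in the view tests follow from view's dimension inference: a `-1` dim is filled in so the total element count is preserved, and a contiguous result keeps row-major strides. A numpy sketch of the same bookkeeping (numpy's `reshape` stands in for `tensor.view` here):

```python
import numpy as np

x = np.arange(16, dtype=np.float32).reshape(4, 4)

# view(2, 2, 2, -1) infers the trailing dim as 16 / (2 * 2 * 2) = 2.
y = x.reshape(2, 2, 2, -1)
assert y.shape == (2, 2, 2, 2)

# A contiguous (2, 2, 2, 2) array has row-major element strides (8, 4, 2, 1),
# which is what the oneflow/pytorch stride comparison above checks.
assert tuple(s // y.itemsize for s in y.strides) == (8, 4, 2, 1)
```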
  {
    "path": "python/oneflow/test/modules/test_vsplit.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom random import shuffle\n\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestVsplitVec(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_flow_vsplit_vec(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4,\n            dim0=random(3, 6),\n            dim1=random(3, 6),\n            dim2=random(3, 6),\n            dim3=random(3, 6),\n        ).to(device)\n        z = torch.vsplit(x, (1, 2))\n        return z[0]\n\n    @autotest(n=10)\n    def test_flow_vsplit_vec_with_stride(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4,\n            dim0=random(3, 6),\n            dim1=random(3, 6),\n            dim2=random(3, 6),\n            dim3=random(3, 6),\n        ).to(device)\n        perm = [0, 1, 2, 3]\n        shuffle(perm)\n        y = x.permute(perm)\n        z = torch.vsplit(y, (1, 2))\n        return z[0]\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestVsplitInt(flow.unittest.TestCase):\n    @autotest(check_graph=True)\n    def test_flow_vsplit_int(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4, dim0=12, dim1=random(3, 6), dim2=random(3, 6), dim3=random(3, 6),\n        ).to(device)\n        split = oneof(2, 4, 6)\n        z = torch.vsplit(x, split)\n        return z[0]\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
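Both call styles exercised above split along dim 0; numpy's `vsplit` has the same contract, which makes the expected section sizes easy to see (illustrative only):

```python
import numpy as np

x = np.random.randn(12, 3, 4, 5)

# A tuple of indices splits at those offsets: sections [0:1), [1:2), [2:12).
parts = np.vsplit(x, (1, 2))
assert [p.shape[0] for p in parts] == [1, 1, 10]

# An int n requires dim 0 to be divisible by n and yields n equal sections,
# which is why the test fixes dim0=12 and draws the split from {2, 4, 6}.
assert all(p.shape[0] == 2 for p in np.vsplit(x, 6))
```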
  {
    "path": "python/oneflow/test/modules/test_weight_norm.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport torch as torch_original\nfrom oneflow.test_utils.automated_test_util import *\n\ninput_arr = np.array(\n    [\n        [-0.16046895, -1.03667831],\n        [-0.34974465, 0.26505867],\n        [-1.24111986, -0.53806001],\n        [1.72426331, 0.43572459],\n    ],\n    dtype=np.float64,\n)\n\n\ndef _test_weightnorm(test_case, device, dim):\n    model_flow = flow.nn.Linear(2, 4)\n    model_flow = model_flow.to(device)\n    with flow.no_grad():\n        for i in range(input_arr.shape[0]):\n            for j in range(input_arr.shape[1]):\n                model_flow.weight[i, j] = input_arr[i][j]\n    m_flow = flow.nn.utils.weight_norm(model_flow, name=\"weight\", dim=dim)\n\n    model_torch = torch_original.nn.Linear(2, 4)\n    model_torch = model_torch.to(device)\n    with torch_original.no_grad():\n        for i in range(input_arr.shape[0]):\n            for j in range(input_arr.shape[1]):\n                model_torch.weight[i, j] = input_arr[i][j]\n    m_torch = torch_original.nn.utils.weight_norm(model_torch, name=\"weight\", dim=dim)\n\n    if device == \"cpu\":\n        test_case.assertTrue(\n            np.allclose(\n                m_flow.weight_g.detach().numpy(),\n                m_torch.weight_g.detach().numpy(),\n                1e-05,\n                1e-05,\n            )\n        )\n        test_case.assertTrue(\n            np.allclose(\n                m_flow.weight_v.detach().numpy(),\n                m_torch.weight_v.detach().numpy(),\n                1e-05,\n                1e-05,\n            )\n        )\n    elif device == \"cuda\":\n        test_case.assertTrue(\n            np.allclose(\n                m_flow.weight_g.detach().cpu().numpy(),\n                m_torch.weight_g.detach().cpu().numpy(),\n                1e-05,\n                1e-05,\n            )\n        )\n        test_case.assertTrue(\n            np.allclose(\n                m_flow.weight_v.detach().numpy(),\n                m_torch.weight_v.detach().cpu().numpy(),\n                1e-05,\n                1e-05,\n            )\n        )\n\n\ndef _test_weightnorm_backward(test_case, device, dim):\n    linear = flow.nn.Linear(3, 8)\n    x = flow.tensor(\n        [\n            [-0.94630778, -0.83378579, -0.87060891],\n            [2.0289922, -0.28708987, -2.18369248],\n            [0.35217619, -0.67095644, -1.58943879],\n            [0.08086036, -1.81075924, 1.20752494],\n            [0.8901075, -0.49976737, -1.07153746],\n            [-0.44872912, -1.07275683, 0.06256855],\n            [-0.22556897, 0.74798368, 0.90416439],\n            [0.48339456, -2.32742195, -0.59321527],\n        ],\n        dtype=flow.float32,\n        requires_grad=True,\n    )\n    flow.nn.init.constant_(linear.weight, 
2.068758)\n    flow.nn.init.constant_(linear.bias, 0.23)\n\n    linear_wn = flow.nn.utils.weight_norm(linear, name=\"weight\", dim=dim)\n    of_out = linear_wn(x)\n\n    of_out = of_out.sum()\n    of_out.backward()\n\n    np_grad = np.array(\n        [\n            [16.5501, 16.5501, 16.5501],\n            [16.5501, 16.5501, 16.5501],\n            [16.5501, 16.5501, 16.5501],\n            [16.5501, 16.5501, 16.5501],\n            [16.5501, 16.5501, 16.5501],\n            [16.5501, 16.5501, 16.5501],\n            [16.5501, 16.5501, 16.5501],\n            [16.5501, 16.5501, 16.5501],\n        ]\n    )\n    test_case.assertTrue(np.allclose(np_grad, x.grad.numpy(), 0.0001, 0.0001))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestWeightNorm(flow.unittest.TestCase):\n    def test_weightnorm(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_weightnorm,\n            _test_weightnorm_backward,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        arg_dict[\"dim\"] = [None, -2, -1, 0, 1]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    # Not check graph because of one reason:\n    # Reason 1, Graph's build input nn.modules.linear.Linear type is not supported.\n    # Please refer to issue: https://github.com/Oneflow-Inc/oneflow/issues/7466\n    @autotest(n=10, auto_backward=True, check_graph=\"ValidatedFalse\")\n    def test_weight_norm_with_random_data(test_case):\n        device = random_device()\n\n        dim = random(-2, 2).to(int).value()\n        output = random(2, 6).to(int)\n        input = random(2, 6).to(int)\n\n        model_torch = torch.nn.Linear(output, input)\n        model_torch = model_torch.to(device)\n        m = torch.nn.utils.weight_norm(model_torch, name=\"weight\", dim=dim)\n        return m.weight_g, m.weight_v\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
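The quantities compared in the tests above come from weight norm's reparameterization, `weight = g * v / ||v||`, where the norm is taken over every dim except `dim` and `g` is initialized to `||v||` so the initial weight is unchanged. A small numpy sketch of that identity for `dim=0` on a Linear-shaped weight (illustrative names, not the framework implementation):

```python
import numpy as np

v = np.random.randn(4, 2)                     # weight_v: the direction
g = np.linalg.norm(v, axis=1, keepdims=True)  # weight_g: initialized to the norm

# weight = g * v / ||v|| reduces to v itself at initialization, which is why
# the test can compare weight_g / weight_v across frameworks directly.
weight = g * v / np.linalg.norm(v, axis=1, keepdims=True)
assert np.allclose(weight, v)
assert np.allclose(g.squeeze(), np.sqrt((v ** 2).sum(axis=1)))
```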
  {
    "path": "python/oneflow/test/modules/test_where.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import GenArgList\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_where(test_case, device):\n    x = flow.tensor(\n        np.array([[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    y = flow.tensor(\n        np.ones(shape=(3, 2)), dtype=flow.float32, device=flow.device(device)\n    )\n    condition = flow.tensor(\n        np.array([[0, 1], [1, 0], [1, 0]]), dtype=flow.int32, device=flow.device(device)\n    )\n    of_out = flow.where(condition, x, y)\n    np_out = np.array([[1.0, 0.3139], [0.3898, 1.0], [0.0478, 1.0]])\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_where_broadcast(test_case, device):\n    x = flow.tensor(\n        np.array([[[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    y = flow.tensor(\n        np.ones(shape=(3, 3, 2)), dtype=flow.float32, device=flow.device(device)\n    )\n    condition = flow.tensor(\n        np.array([[[0, 1], [1, 0], [1, 0]]]),\n        dtype=flow.int32,\n        device=flow.device(device),\n    )\n    of_out = flow.where(condition, x, y)\n    np_out = np.array(\n        [\n            [[1.0, 0.3139], [0.3898, 1.0], [0.0478, 1.0]],\n            [[1.0, 0.3139], [0.3898, 1.0], [0.0478, 1.0]],\n            [[1.0, 0.3139], [0.3898, 1.0], [0.0478, 1.0]],\n        ]\n    )\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_where_scalar(test_case, device):\n    x = 0.5\n    y = 2.0\n    condition = flow.tensor(np.array([1]), dtype=flow.int32)\n    of_out = flow.where(condition, x, y)\n    test_case.assertTrue(of_out.dtype == flow.float32)\n    np_out = np.array([0.5])\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    flow.set_default_dtype(flow.double)\n    of_out = flow.where(condition, x, y)\n    test_case.assertTrue(of_out.dtype == flow.double)\n    flow.set_default_dtype(flow.float16)\n    of_out = flow.where(condition, x, y)\n    test_case.assertTrue(of_out.dtype == flow.float16)\n    flow.set_default_dtype(flow.bfloat16)\n    of_out = flow.where(condition, x, y)\n    test_case.assertTrue(of_out.dtype == flow.bfloat16)\n\n\ndef _test_where_dim4(test_case, device):\n    x = flow.tensor(\n        np.array([[[[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]]),\n        dtype=flow.float32,\n        device=flow.device(device),\n    )\n    y = flow.tensor(\n        np.ones(shape=(1, 1, 3, 2)), dtype=flow.float32, device=flow.device(device)\n    )\n    condition = flow.tensor(\n        np.array([[[[0, 1], [1, 0], [1, 0]]]]),\n        dtype=flow.int32,\n        
device=flow.device(device),\n    )\n    of_out = flow.where(condition, x, y)\n    np_out = np.array([[[[1.0, 0.3139], [0.3898, 1.0], [0.0478, 1.0]]]])\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n\ndef _test_where_backward(test_case, device):\n    x = flow.tensor(\n        np.array([[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    y = flow.tensor(\n        np.ones(shape=(3, 2)),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    condition = flow.tensor(\n        np.array([[0, 1], [1, 0], [1, 0]]), dtype=flow.int32, device=flow.device(device)\n    )\n    of_out = flow.where(condition, x, y)\n    of_out = of_out.sum()\n    of_out.backward()\n    test_case.assertTrue(\n        np.allclose(x.grad.numpy(), condition.numpy() == 1, 1e-05, 1e-05)\n    )\n    test_case.assertTrue(\n        np.allclose(y.grad.numpy(), condition.numpy() == 0, 1e-05, 1e-05)\n    )\n\n\ndef _test_where_broadcast_backward(test_case, device):\n    x = flow.tensor(\n        np.array([[[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    y = flow.tensor(\n        np.ones(shape=(3, 3, 2)),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    condition = flow.tensor(\n        np.array([[[0, 1], [1, 0], [1, 0]]]),\n        dtype=flow.int32,\n        device=flow.device(device),\n    )\n    of_out = flow.where(condition, x, y)\n    of_out = of_out.sum()\n    of_out.backward()\n    x_grad = [[[0.0, 3.0], [3.0, 0.0], [3.0, 0.0]]]\n    test_case.assertTrue(np.allclose(x.grad.numpy(), x_grad, 1e-05, 1e-05))\n    y_grad = [\n        [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],\n        [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],\n        [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],\n    ]\n    test_case.assertTrue(np.allclose(y.grad.numpy(), y_grad, 1e-05, 1e-05))\n\n\ndef _test_where_broadcast_x_backward(test_case, device):\n    x = flow.tensor(\n        np.array([[[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    y = flow.tensor(\n        np.ones(shape=(3, 3, 2)), dtype=flow.float32, device=flow.device(device)\n    )\n    condition = flow.tensor(\n        np.array([[[0, 1], [1, 0], [1, 0]]]),\n        dtype=flow.int32,\n        device=flow.device(device),\n    )\n    of_out = flow.where(condition, x, y)\n    of_out = of_out.sum()\n    of_out.backward()\n    x_grad = [[[0.0, 3.0], [3.0, 0.0], [3.0, 0.0]]]\n    test_case.assertTrue(np.allclose(x.grad.numpy(), x_grad, 1e-05, 1e-05))\n\n\ndef _test_where_x_y_none(test_case, device):\n    condition = flow.tensor(\n        np.array([[[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]),\n        dtype=flow.float32,\n        device=flow.device(device),\n        requires_grad=True,\n    )\n    of_out = flow.where(condition)\n    of_nonzero = flow.nonzero(condition, as_tuple=True)\n    for i in range(len(of_out)):\n        test_case.assertTrue(\n            np.allclose(of_out[i].numpy(), of_nonzero[i].numpy(), 1e-05, 1e-05)\n        )\n\n\ndef _test_where_tensor_with_scalar(test_case, device):\n    x = flow.randn(5, 5)\n    y = flow.where(x > 0, x, 0.0)\n    test_case.assertTrue(np.array_equal(y.size(), (5, 5)))\n    y = flow.where(x > 0, 0.0, x)\n    test_case.assertTrue(np.array_equal(y.size(), (5, 5)))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestWhere(flow.unittest.TestCase):\n    def test_where(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"test_fun\"] = [\n            _test_where,\n            _test_where_broadcast,\n            _test_where_scalar,\n            _test_where_dim4,\n            _test_where_backward,\n            _test_where_broadcast_backward,\n            _test_where_broadcast_x_backward,\n            _test_where_x_y_none,\n            _test_where_tensor_with_scalar,\n        ]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=5)\n    def test_flow_where_tensor_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        x = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        y = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5)\n    def test_flow_where_tensor_with_0dim_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        x = random_tensor(ndim=0).to(device)\n        y = random_tensor(ndim=0).to(device)\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5)\n    def test_flow_where_tensor_broadcast_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        x = random_tensor(ndim=2, dim0=1, dim1=k2).to(device)\n        y = random_tensor(ndim=2, dim0=k1, dim1=1).to(device)\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5)\n    def test_flow_where_scalar_x_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        x = random().to(float)\n        y = random_tensor(ndim=2, dim0=k1, dim1=k2, dtype=float).to(\n            device=device, dtype=torch.float64\n        )\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5)\n    def test_flow_where_scalar_x_broadcast_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=1, dim1=k2).to(device)\n        x = random().to(float)\n        y = random_tensor(ndim=2, dim0=k1, dim1=1, dtype=float).to(\n            device=device, dtype=torch.float64\n        )\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_where_scalar_x_int_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        x = random().to(int)\n        y = random_tensor(ndim=2, dim0=k1, dim1=k2, dtype=int).to(device)\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5)\n    def test_flow_where_scalar_y_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        x = random_tensor(ndim=2, dim0=k1, dim1=k2, 
dtype=float).to(\n            device=device, dtype=torch.float64\n        )\n        y = random().to(float)\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5)\n    def test_flow_where_scalar_y_broadcast_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=1, dim1=k2).to(device)\n        x = random_tensor(ndim=2, dim0=k1, dim1=1, dtype=float).to(\n            device=device, dtype=torch.float64\n        )\n        y = random().to(float)\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_where_scalar_y_int_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        x = random_tensor(ndim=2, dim0=k1, dim1=k2, dtype=int).to(device)\n        y = random().to(int)\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_where_scalar_xy_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        x = random().to(float)\n        y = random().to(float)\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_where_scalar_xy_int_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        x = random().to(int)\n        y = random().to(int)\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_where_tensor_bool_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        x = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device=device, dtype=torch.bool)\n        y = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device=device, dtype=torch.bool)\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_where_tensor_broadcast_bool_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        x = random_tensor(ndim=2, dim0=1, dim1=k2).to(device=device, dtype=torch.bool)\n        y = random_tensor(ndim=2, dim0=k1, dim1=1).to(device=device, dtype=torch.bool)\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_where_scalar_x_bool_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        x = random().to(bool)\n        y = random_tensor(ndim=2, dim0=k1, dim1=k2, dtype=float).to(\n            device=device, dtype=torch.bool\n        )\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_where_scalar_x_broadcast_bool_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond 
= random_tensor(ndim=2, dim0=1, dim1=k2).to(device)\n        x = random().to(bool)\n        y = random_tensor(ndim=2, dim0=k1, dim1=1, dtype=float).to(\n            device=device, dtype=torch.bool\n        )\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_where_scalar_y_bool_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        x = random_tensor(ndim=2, dim0=k1, dim1=k2, dtype=float).to(\n            device=device, dtype=torch.bool\n        )\n        y = random().to(bool)\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_where_scalar_y_broadcast_bool_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=1, dim1=k2).to(device)\n        x = random_tensor(ndim=2, dim0=k1, dim1=1, dtype=float).to(\n            device=device, dtype=torch.bool\n        )\n        y = random().to(bool)\n        return torch.where(cond > 0, x, y)\n\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_flow_where_scalar_xy_bool_with_random_data(test_case):\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        device = random_device()\n        cond = random_tensor(ndim=2, dim0=k1, dim1=k2).to(device)\n        x = random().to(bool)\n        y = random().to(bool)\n        return torch.where(cond > 0, x, y)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
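The broadcast cases in the where tests rely on `cond`, `x`, and `y` being broadcast to a common shape before selection; numpy's `where` follows the same rule, so the expected shape is easy to verify (illustrative):

```python
import numpy as np

cond = np.random.randn(3, 4) > 0  # (3, 4)
x = np.random.randn(1, 4)         # broadcasts along dim 0
y = np.random.randn(3, 1)         # broadcasts along dim 1

out = np.where(cond, x, y)
assert out.shape == (3, 4)

# Where cond holds, out matches the broadcast x; elsewhere, the broadcast y.
assert np.allclose(out[cond], np.broadcast_to(x, (3, 4))[cond])
assert np.allclose(out[~cond], np.broadcast_to(y, (3, 4))[~cond])
```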
  {
    "path": "python/oneflow/test/modules/test_zeropad2d.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nfrom oneflow.test_utils.test_util import (\n    Array2Numpy,\n    FlattenArray,\n    GenArgList,\n    Index2Coordinate,\n)\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _np_zero_pad2d_grad(src, dest, padding):\n    (c_idx, h_idx, w_idx) = (1, 2, 3)\n    pad_left = padding[0]\n    pad_right = padding[1]\n    pad_top = padding[2]\n    pad_bottom = padding[3]\n    (dx_height, dx_width) = (dest.shape[h_idx], dest.shape[w_idx])\n    (dy_height, dy_width) = (src.shape[h_idx], src.shape[w_idx])\n    numpy_src = np.ones(src.shape, np.int32)\n    numpy_dest = np.zeros(dest.shape, np.int32)\n    array_src = FlattenArray(numpy_src)\n    array_dest = FlattenArray(numpy_dest)\n    src_num = src.shape[c_idx] * src.shape[h_idx] * src.shape[w_idx]\n    dest_num = dest.shape[c_idx] * dest.shape[h_idx] * dest.shape[w_idx]\n    elements_num = src.shape[0] * src_num\n    for iter_n in range(elements_num):\n        coords = Index2Coordinate(iter_n, src.shape)\n        (n, c, i, j) = (coords[0], coords[c_idx], coords[h_idx], coords[w_idx])\n        ip_x = ip_y = 0\n        if (\n            j >= pad_left\n            and j < dx_width + pad_left\n            and (i >= pad_top)\n            and (i < dx_height + pad_top)\n        ):\n            ip_x = j - pad_left\n            ip_y = i - pad_top\n            src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j\n            dest_index = (\n                n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x\n            )\n            array_dest[dest_index] += array_src[src_index]\n    numpy_dest = Array2Numpy(array_dest, dest.shape)\n    return numpy_dest\n\n\ndef _test_ZeroPad2d(test_case, shape, padding, value, device):\n    np_input = np.random.random(shape)\n    of_input = flow.tensor(\n        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True\n    )\n    if isinstance(padding, int):\n        np_boundary = ((0, 0), (0, 0), (padding, padding), (padding, padding))\n    elif isinstance(padding, (tuple, int)) and len(padding) == 4:\n        np_boundary = (\n            (0, 0),\n            (0, 0),\n            (padding[2], padding[3]),\n            (padding[0], padding[1]),\n        )\n    else:\n        raise ValueError(\"padding must be in  or tuple!\")\n    layer = flow.nn.ZeroPad2d(padding=padding)\n    of_out = layer(of_input)\n    np_out = np.pad(np_input, np_boundary, mode=\"constant\", constant_values=value)\n    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_out_grad = _np_zero_pad2d_grad(np_out, np_input, layer.padding)\n    test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_out_grad, 1e-05, 1e-05))\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestZeroPad2dModule(flow.unittest.TestCase):\n    def 
test_ConstantPad2d(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(1, 2, 3, 4), (8, 3, 4, 4)]\n        arg_dict[\"padding\"] = [2, (1, 1, 2, 2)]\n        arg_dict[\"value\"] = [0.0]\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgList(arg_dict):\n            _test_ZeroPad2d(test_case, *arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
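`_np_zero_pad2d_grad` computes the backward by brute-force index accumulation, but because zero padding copies each input element to exactly one output position, the gradient reduces to cropping the upstream gradient back to the unpadded region. A compact numpy check of that equivalence, using the test's (1, 1, 2, 2) padding (illustrative only):

```python
import numpy as np

pad_left, pad_right, pad_top, pad_bottom = 1, 1, 2, 2
dy = np.ones((1, 2, 3 + pad_top + pad_bottom, 4 + pad_left + pad_right))

# Gradient of ZeroPad2d w.r.t. its input: crop away the padded border.
dx = dy[:, :, pad_top : dy.shape[2] - pad_bottom, pad_left : dy.shape[3] - pad_right]
assert dx.shape == (1, 2, 3, 4)
```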
  {
    "path": "python/oneflow/test/profiler/test_events.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport json\nimport unittest\nimport oneflow.unittest\nimport oneflow as flow\nfrom oneflow.profiler.events import *\n\n\nclass TestEventAndEvents(flow.unittest.TestCase):\n    def test_event(test_case):\n        classes = [CustomEvent, KernelEvent]\n        custom_event = CustomEvent(\"custom\", 1234, CustomEventType.Default)\n        custom_event_json = {\n            \"name\": \"custom\",\n            \"time\": 1234,\n            \"custom_type\": 0,\n            \"type\": 0,\n        }\n        test_case.assertEqual(\n            custom_event,\n            classes[custom_event_json.get(\"type\")].from_dict(custom_event_json),\n        )\n\n        kernel_event = KernelEvent(\"kernel\", 1234, 1024, \"-\")\n        kernel_event_json = {\n            \"name\": \"kernel\",\n            \"time\": 1234,\n            \"memory_size\": 1024,\n            \"type\": 1,\n            \"input_shapes\": \"-\",\n        }\n        test_case.assertEqual(\n            kernel_event,\n            classes[kernel_event_json.get(\"type\")].from_dict(kernel_event_json),\n        )\n\n    def test_event_update(test_case):\n        event = CustomEvent(\"custom\", 1234, CustomEventType.Default)\n        event1 = CustomEvent(\"custom\", 3346, CustomEventType.Default)\n        event.update(event1)\n        test_case.assertEqual(event.count, 2)\n        test_case.assertEqual(event.cpu_time, 2290)\n        test_case.assertEqual(event.cpu_time_total, 4580)\n\n    def test_events(test_case):\n        events_json = json.dumps(\n            [\n                {\"name\": \"custom\", \"time\": 1234, \"custom_type\": 0, \"type\": 0},\n                {\"name\": \"custom\", \"time\": 3346, \"custom_type\": 0, \"type\": 0},\n            ]\n        )\n        events = [\n            CustomEvent(\"custom\", 1234, CustomEventType.Default),\n            CustomEvent(\"custom\", 3346, CustomEventType.Default),\n        ]\n        events_avg = [CustomEvent(\"custom\", 4580, CustomEventType.Default)]\n        events_avg[0].count = 2\n        test_case.assertEqual(Events(events_json), events)\n        test_case.assertEqual(Events(events_json).key_averages(), events_avg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
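The numbers asserted in `test_event_update` follow from a running-average scheme: `update` accumulates total time and a call count, and `cpu_time` is the mean. A minimal standalone sketch of those semantics (`AvgEvent` is a hypothetical stand-in, not the profiler API):

```python
class AvgEvent:
    """Toy stand-in mirroring the accumulate-then-average behavior."""

    def __init__(self, time):
        self.count = 1
        self.cpu_time_total = time

    def update(self, other):
        self.count += other.count
        self.cpu_time_total += other.cpu_time_total

    @property
    def cpu_time(self):
        return self.cpu_time_total / self.count

e = AvgEvent(1234)
e.update(AvgEvent(3346))
# Matches the test: count=2, cpu_time_total=4580, cpu_time=2290.
assert (e.count, e.cpu_time_total, e.cpu_time) == (2, 4580, 2290)
```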
  {
    "path": "python/oneflow/test/profiler/test_profile_lenet.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport oneflow.unittest\nimport oneflow as flow\nimport oneflow.nn as nn\nimport oneflow.nn.functional as F\nimport oneflow.profiler\nfrom collections import OrderedDict\nfrom oneflow.profiler.events import CustomEvent, KernelEvent\nfrom oneflow.test_utils.test_util import GenArgDict\n\n\nclass LeNet(nn.Module):\n    def __init__(self):\n        super(LeNet, self).__init__()\n        self.conv1 = nn.Conv2d(3, 6, 5)\n        self.conv2 = nn.Conv2d(6, 16, 5)\n        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n        self.fc2 = nn.Linear(120, 84)\n        self.fc3 = nn.Linear(84, 10)\n\n    def forward(self, x):\n        out = F.relu(self.conv1(x))\n        out = F.max_pool2d(out, 2)\n        out = F.relu(self.conv2(out))\n        out = F.max_pool2d(out, 2)\n        out = out.view(out.size(0), -1)\n        out = F.relu(self.fc1(out))\n        out = F.relu(self.fc2(out))\n        out = self.fc3(out)\n        return out\n\n\ndef get_event(events, name: str, input_shapes: str = \"\", attributes: str = \"\"):\n    for item in events:\n        if isinstance(item, CustomEvent):\n            if item.name == name:\n                return item\n        if isinstance(item, KernelEvent):\n            if (\n                item.name == name\n                and item.input_shapes == input_shapes\n                and item.attributes == attributes\n            ):\n                return item\n    return None\n\n\ndef _test_lenet(\n    test_case,\n    on_cuda: bool,\n    record_shapes: bool,\n    record_attrs: bool,\n    record_bandwidth_for_cuda: bool = False,\n):\n    x = flow.randn(2, 3, 32, 32)\n    lenet = LeNet()\n    if on_cuda:\n        x = x.to(\"cuda\")\n        lenet.to(\"cuda\")\n    activities = [oneflow.profiler.ProfilerActivity.CPU]\n    if on_cuda:\n        activities.append(oneflow.profiler.ProfilerActivity.CUDA)\n    with oneflow.profiler.profile(\n        activities=activities,\n        record_shapes=record_shapes,\n        record_attrs=record_attrs,\n        record_bandwidth_for_cuda=record_bandwidth_for_cuda,\n    ) as prof:\n        with oneflow.profiler.record_function(\"lenet_forward_total_time\") as f:\n            for _ in range(2):\n                eager_res = lenet(x)\n        with oneflow.profiler.record_function(\"lenet_backward_total_time\") as f:\n            eager_res.sum().backward()\n    events = prof.key_averages(group_by_input_shape=True, group_by_attributes=True)\n\n    conv_event_input_shapes = \"(2,3,32,32), (6,3,5,5)\" if record_shapes else \"\"\n    conv_event_attributes = (\n        \"data_format=channels_first, dilation_rate=[1, 1], filters=6, groups=1, kernel_size=[5, 5], padding_before=[0, 0], strides=[1, 1]\"\n        if record_attrs\n        else \"\"\n    )\n    conv_event = get_event(\n        events, \"conv2d\", conv_event_input_shapes, conv_event_attributes\n    )\n    test_case.assertIsNotNone(conv_event)\n\n    
if on_cuda:\n        test_case.assertGreater(conv_event.cpu_time, 0.0)\n        test_case.assertGreater(conv_event.cpu_time_total, 0.0)\n        test_case.assertGreater(conv_event.cuda_time, 0.0)\n        test_case.assertGreater(conv_event.cuda_time_total, 0.0)\n    else:\n        test_case.assertGreater(conv_event.cpu_time, 0.0)\n        test_case.assertGreater(conv_event.cpu_time_total, 0.0)\n\n    test_case.assertEqual(conv_event.count, 2 if record_shapes or record_attrs else 4)\n    if record_bandwidth_for_cuda and on_cuda:\n        test_case.assertNotEqual(conv_event.bandwidth, -1)\n\n    relu_grad_event_input_shapes = \"(2,6,28,28), (2,6,28,28)\" if record_shapes else \"\"\n    relu_grad_event = get_event(events, \"relu_grad\", relu_grad_event_input_shapes, \"\")\n    test_case.assertIsNotNone(relu_grad_event)\n    if on_cuda:\n        test_case.assertGreater(relu_grad_event.cpu_time, 0.0)\n        test_case.assertGreater(relu_grad_event.cpu_time_total, 0.0)\n        test_case.assertGreater(relu_grad_event.cuda_time, 0.0)\n        test_case.assertGreater(relu_grad_event.cuda_time_total, 0.0)\n    else:\n        test_case.assertGreater(relu_grad_event.cpu_time, 0.0)\n        test_case.assertGreater(relu_grad_event.cpu_time_total, 0.0)\n\n    test_case.assertEqual(relu_grad_event.count, 1 if record_shapes else 4)\n    if record_bandwidth_for_cuda and on_cuda:\n        test_case.assertNotEqual(relu_grad_event.bandwidth, -1)\n\n    test_case.assertIsNotNone(get_event(events, \"lenet_forward_total_time\"))\n    test_case.assertIsNotNone(get_event(events, \"lenet_backward_total_time\"))\n\n\nclass TestProfileLenet(flow.unittest.TestCase):\n    def test_lenet_cpu(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"record_shapes\"] = [True, False]\n        arg_dict[\"record_attrs\"] = [True, False]\n        for kwargs in GenArgDict(arg_dict):\n            _test_lenet(test_case, False, **kwargs)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_lenet_cuda(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"record_shapes\"] = [True, False]\n        arg_dict[\"record_attrs\"] = [True, False]\n        arg_dict[\"record_bandwidth_for_cuda\"] = [True, False]\n        for kwargs in GenArgDict(arg_dict):\n            _test_lenet(test_case, True, **kwargs)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_autocast.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"skip test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestAutoCast(flow.unittest.TestCase):\n    @autotest(n=1, auto_backward=True, check_graph=False)\n    def test_autocast_half_mm(test_case):\n        a = random_tensor(2, 2, 3).to(\"cuda\")\n        b = random_tensor(2, 3, 4).to(\"cuda\")\n        with torch.autocast(\"cuda\"):\n            x = torch.mm(a, b)\n        return x\n\n    @autotest(n=1, auto_backward=True, check_graph=False)\n    def test_autocast_half_mm_add(test_case):\n        a = random_tensor(2, 2, 3).to(\"cuda\")\n        b = random_tensor(2, 3, 4).to(\"cuda\")\n        c = random_tensor(2, 2, 4).to(\"cuda\")\n        with torch.autocast(\"cuda\"):\n            x = torch.mm(a, b)\n            y = x + c\n        return x.float() + y.float()\n\n    def test_autocast_graph(test_case):\n        class LinearGraph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linear = flow.nn.Linear(3, 4, bias=False).cuda().half()\n\n            def build(self, x):\n                return self.linear(x)\n\n        x = flow.Tensor(3, 3).cuda()\n\n        with flow.autocast(device_type=\"cuda\"):\n            linear = LinearGraph()\n            y = linear(x)\n            test_case.assertTrue(y.dtype == flow.float16)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
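For the eager cases above, the property under test is that matmul-like ops run in half precision inside the autocast region even when their inputs are float32. A minimal eager-mode sketch of that expectation (assumes a CUDA device is available, and that `flow.autocast` / `flow.mm` behave like their torch counterparts, as the autotest mapping above implies):

```python
import oneflow as flow

a = flow.randn(2, 3, device="cuda")
b = flow.randn(3, 4, device="cuda")

# Inside the autocast region, mm is expected to compute and return float16
# even though a and b are float32.
with flow.autocast(device_type="cuda"):
    out = flow.mm(a, b)
assert out.dtype == flow.float16

# Outside the region, the same op stays in float32.
assert flow.mm(a, b).dtype == flow.float32
```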
  {
    "path": "python/oneflow/test/tensor/test_bfloat16_activation.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestBfloat16Activatian(flow.unittest.TestCase):\n    def test_tan_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.tan(x)\n        fp32_y = flow.tan(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_tanh_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.tanh(x)\n        fp32_y = flow.tanh(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_sin_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.sin(x)\n        fp32_y = flow.sin(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_sinh_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.sinh(x)\n        fp32_y = flow.sinh(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_cos_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.cos(x)\n        fp32_y = flow.cos(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_cosh_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.cosh(x)\n        fp32_y = flow.cosh(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                
y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_atan_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.atan(x)\n        fp32_y = flow.atan(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_atanh_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.atanh(x)\n        fp32_y = flow.atanh(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_asin_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.asin(x)\n        fp32_y = flow.asin(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_asinh_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.asinh(x)\n        fp32_y = flow.asinh(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_acos_with_random_data(test_case):\n        np_array = np.random.uniform(-1, 1, (4, 4))\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.acos(x)\n        fp32_y = flow.acos(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_acosh_with_random_data(test_case):\n        np_array = np.random.uniform(1, 5, (4, 4))\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.acosh(x)\n        fp32_y = flow.acosh(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_sqrt_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.sqrt(x)\n        fp32_y = flow.sqrt(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    
def test_square_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.square(x)\n        fp32_y = flow.square(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_exp_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.exp(x)\n        fp32_y = flow.exp(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_exp2_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.exp2(x)\n        fp32_y = flow.exp2(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_ceil_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.ceil(x)\n        fp32_y = flow.ceil(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_erf_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.erf(x)\n        fp32_y = flow.erf(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_erfc_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.erfc(x)\n        fp32_y = flow.erfc(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_floor_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.floor(x)\n        fp32_y = flow.floor(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_expm1_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n      
  y = flow.expm1(x)\n        fp32_y = flow.expm1(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_lgamma_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.lgamma(x)\n        fp32_y = flow.lgamma(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_log_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.log(x)\n        fp32_y = flow.log(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_log2_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.log2(x)\n        fp32_y = flow.log2(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_log1p_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.log1p(x)\n        fp32_y = flow.log1p(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_sigmoid_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.sigmoid(x)\n        fp32_y = flow.sigmoid(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_round_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.round(x)\n        fp32_y = flow.round(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_rsqrt_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.rsqrt(x)\n        fp32_y = flow.rsqrt(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                
fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_softplus_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.softplus(x)\n        fp32_y = flow.softplus(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_softsign_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.softsign(x)\n        fp32_y = flow.softsign(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_softshrink_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.softshrink(x)\n        fp32_y = flow.softshrink(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_silu_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.silu(x)\n        fp32_y = flow.silu(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_selu_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.selu(x)\n        fp32_y = flow.selu(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_mish_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.mish(x)\n        fp32_y = flow.mish(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_gelu_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.gelu(x)\n        fp32_y = flow.gelu(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def 
test_elu_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        elu = flow.nn.ELU()\n        y = elu(x)\n        fp32_y = elu(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_celu_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        celu = flow.nn.CELU()\n        y = celu(x)\n        fp32_y = celu(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_hardswish_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        hardswish = flow.nn.Hardswish()\n        y = hardswish(x)\n        fp32_y = hardswish(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_hardsigmoid_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        hardsigmoid = flow.nn.Hardsigmoid()\n        y = hardsigmoid(x)\n        fp32_y = hardsigmoid(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_hardshrink_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        hardshrink = flow.nn.Hardshrink()\n        y = hardshrink(x)\n        fp32_y = hardshrink(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_hardtanh_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        hardtanh = flow.nn.Hardtanh()\n        y = hardtanh(x)\n        fp32_y = hardtanh(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_leakyrelu_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        leakyrelu = flow.nn.LeakyReLU(0.1)\n        y = leakyrelu(x)\n        fp32_y = leakyrelu(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_threshold_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        th = flow.nn.Threshold(threshold=0.5, value=0.2)\n        y = th(x)\n        fp32_y = th(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_logsigmoid_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        logsigmoid = flow.nn.LogSigmoid()\n        y = logsigmoid(x)\n        fp32_y = logsigmoid(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n    def test_digamma_with_random_data(test_case):\n        np_array = np.random.rand(4, 4)\n        x = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cpu\")\n        fp32_x = x.float()\n        y = flow.digamma(x)\n        fp32_y = flow.digamma(fp32_x)\n        test_case.assertTrue(\n            np.allclose(\n                y.float().numpy(),\n                fp32_y.bfloat16().float().numpy(),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_complex.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport numpy as np\nimport torch as torch_original\n\nimport os\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\nfrom oneflow.test_utils.test_util import (\n    Array2Numpy,\n    FlattenArray,\n    GenArgList,\n    Index2Coordinate,\n)\nfrom collections import OrderedDict\n\n\"\"\"\nTODO(lml): Support and test more apis.\nFinished: \nflow.from_numpy()\nflow.tensor()\nflow.ones()\nflow.zeros()\nflow.full()\nflow.add()\nflow.sub()\nflow.mul\nflow.sum()\nflow.equal()\nflow.not_equal()\nflow.cast()\nTensor.new_ones()\nTensor.new_zeros()\nTensor.new_full()\nTensor.real()\nTensor.imag()\nTensor.conj()\nTensor.conj_physical()\n\nTo complete:\nflow.randn()\nflow.div()\nflow.pow()\nTensor.adjoint()\nTensor.conj_physical_()\nTensor.resolve_conj()\nTensor.chalf()\nTensor.cfloat(),\nTensor.cdouble()\nMore apis..\n\"\"\"\n\n\ndef compare_result(a, b, rtol=1e-5, atol=1e-8):\n    assert np.allclose(\n        a, b, rtol=rtol, atol=atol\n    ), f\"\\na\\n{a}\\n{'-' * 80}\\nb:\\n{b}\\n{'*' * 80}\\ndiff:\\n{a - b}\"\n\n\ndef _np_zero_pad2d_grad(src, dest, padding):\n    (c_idx, h_idx, w_idx) = (1, 2, 3)\n    pad_left = padding[0]\n    pad_right = padding[1]\n    pad_top = padding[2]\n    pad_bottom = padding[3]\n    (dx_height, dx_width) = (dest.shape[h_idx], dest.shape[w_idx])\n    (dy_height, dy_width) = (src.shape[h_idx], src.shape[w_idx])\n    numpy_src = np.ones(src.shape, np.int32)\n    numpy_dest = np.zeros(dest.shape, np.int32)\n    array_src = FlattenArray(numpy_src)\n    array_dest = FlattenArray(numpy_dest)\n    src_num = src.shape[c_idx] * src.shape[h_idx] * src.shape[w_idx]\n    dest_num = dest.shape[c_idx] * dest.shape[h_idx] * dest.shape[w_idx]\n    elements_num = src.shape[0] * src_num\n    for iter_n in range(elements_num):\n        coords = Index2Coordinate(iter_n, src.shape)\n        (n, c, i, j) = (coords[0], coords[c_idx], coords[h_idx], coords[w_idx])\n        ip_x = ip_y = 0\n        if (\n            j >= pad_left\n            and j < dx_width + pad_left\n            and (i >= pad_top)\n            and (i < dx_height + pad_top)\n        ):\n            ip_x = j - pad_left\n            ip_y = i - pad_top\n            src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j\n            dest_index = (\n                n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x\n            )\n            array_dest[dest_index] += array_src[src_index]\n    numpy_dest = Array2Numpy(array_dest, dest.shape)\n    return numpy_dest\n\n\ndef _test_ZeroPad2d(test_case, shape, padding, value, device, rtol, atol):\n    np_input = np.random.random(shape)\n    of_input = flow.tensor(\n        np_input, dtype=test_case.dtype, device=flow.device(device), requires_grad=True\n    )\n    if isinstance(padding, int):\n        np_boundary = ((0, 0), (0, 0), (padding, padding), (padding, padding))\n 
   elif isinstance(padding, (tuple, int)) and len(padding) == 4:\n        np_boundary = (\n            (0, 0),\n            (0, 0),\n            (padding[2], padding[3]),\n            (padding[0], padding[1]),\n        )\n    else:\n        raise ValueError(\"padding must be in  or tuple!\")\n    layer = flow.nn.ZeroPad2d(padding=padding)\n    of_out = layer(of_input)\n    np_out = np.pad(np_input, np_boundary, mode=\"constant\", constant_values=value)\n    test_case.assertTrue(np.allclose(of_out.cpu().detach().numpy(), np_out, rtol, atol))\n    of_out = of_out.sum()\n    of_out.backward()\n    np_out_grad = _np_zero_pad2d_grad(np_out, np_input, layer.padding)\n    test_case.assertTrue(\n        np.allclose(of_input.grad.cpu().detach().numpy(), np_out_grad, rtol, atol)\n    )\n\n\nclass TestTensorComplex64(unittest.TestCase):\n    def setUp(self):\n        self.dtype = flow.cfloat\n        self.complex_dtype = flow.complex64\n        self.np_dtype = np.complex64\n        self.type_str = \"ComplexFloatTensor\"\n        self.real_dtype = flow.float\n        self.np_real_dtype = np.float32\n        self.rtol = 1e-5\n        self.atol = 1e-5\n        self.a = [1.0 + 1j, 2.0]\n        self.np_a = np.array(self.a, dtype=self.np_dtype)\n        self.b = [[1.0 + 1j, 2.0], [1.0, 2.0 - 1j], [-1.0, 1j]]\n        self.np_b = np.array(self.b, dtype=self.np_dtype)\n\n        self.lower_n_dims = 2\n        self.upper_n_dims = 5\n        self.shape = []\n        for _ in range(10):\n            num_dims = np.random.randint(self.lower_n_dims, self.upper_n_dims)\n            shape_ = [np.random.randint(1, 11) * 4 for _ in range(num_dims)]\n            self.shape.append(shape_)\n\n    def test_from_numpy(self):\n        a = flow.from_numpy(self.np_a)\n        self.assertEqual(a.dtype, self.dtype)\n        self.assertEqual(a.type(), \"oneflow.\" + self.type_str)\n        np_a = a.numpy()\n        self.assertEqual(np_a.dtype, self.np_dtype)\n        assert np.allclose(np_a, self.np_a)\n\n        b = flow.from_numpy(self.np_b)\n        self.assertEqual(b.dtype, self.dtype)\n        self.assertEqual(b.type(), \"oneflow.\" + self.type_str)\n        np_b = b.numpy()\n        self.assertEqual(np_b.dtype, self.np_dtype)\n        assert np.allclose(np_b, self.np_b)\n\n    def test_tensor(self):\n        a = flow.tensor(self.a, dtype=self.dtype)\n        self.assertEqual(a.dtype, self.dtype)\n        self.assertEqual(a.type(), \"oneflow.\" + self.type_str)\n        np_a = a.numpy()\n        self.assertEqual(np_a.dtype, self.np_dtype)\n        assert np.allclose(np_a, self.np_a)\n\n        a = flow.tensor(self.np_a, dtype=self.dtype)\n        self.assertEqual(a.dtype, self.dtype)\n        self.assertEqual(a.type(), \"oneflow.\" + self.type_str)\n        np_a = a.numpy()\n        self.assertEqual(np_a.dtype, self.np_dtype)\n        assert np.allclose(np_a, self.np_a)\n\n    @unittest.skip(\"skip for now, becase it failed 6 times in past week\")\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_tensor_cuda(self):\n        a = flow.tensor(self.a, dtype=self.dtype, device=\"cuda\")\n        self.assertEqual(a.dtype, self.dtype)\n        self.assertEqual(a.type(), \"oneflow.cuda.\" + self.type_str)\n        np_a = a.numpy()\n        self.assertEqual(np_a.dtype, self.np_dtype)\n        assert np.allclose(np_a, self.np_a)\n\n        a = flow.tensor(self.np_a, dtype=self.dtype, device=\"cuda\")\n        self.assertEqual(a.dtype, self.dtype)\n        self.assertEqual(a.type(), 
\"oneflow.cuda.\" + self.type_str)\n        np_a = a.numpy()\n        self.assertEqual(np_a.dtype, self.np_dtype)\n        assert np.allclose(np_a, self.np_a)\n\n    @unittest.skip(\"skip for now, becase it failed 2 times in past week\")\n    def test_slice(self):\n        a = flow.from_numpy(self.np_a)\n        np_slice_a = a[1].numpy()\n        self.assertEqual(np_slice_a.dtype, self.np_dtype)\n        assert np.allclose(np_slice_a, self.np_a[1])\n\n        b = flow.from_numpy(self.np_b)\n        np_slice_b = b[1].numpy()\n        self.assertEqual(np_slice_b.dtype, self.np_dtype)\n        assert np.allclose(np_slice_b, self.np_b[1])\n\n        c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype)\n        np_slice_c = c[0:2, :].numpy()\n        self.assertEqual(np_slice_c.dtype, self.np_dtype)\n        assert np.allclose(\n            np_slice_c, np.ones((2, 2), dtype=self.np_dtype) * (3.14 + 2j)\n        )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_slice_cuda(self):\n        a = flow.from_numpy(self.np_a).cuda()\n        np_slice_a = a[1].cpu().numpy()\n        self.assertEqual(np_slice_a.dtype, self.np_dtype)\n        assert np.allclose(np_slice_a, self.np_a[1])\n\n        b = flow.from_numpy(self.np_b).cuda()\n        np_slice_b = b[1].cpu().numpy()\n        self.assertEqual(np_slice_b.dtype, self.np_dtype)\n        assert np.allclose(np_slice_b, self.np_b[1])\n\n        c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype).cuda()\n        np_slice_c = c[0:2, :].cpu().numpy()\n        self.assertEqual(np_slice_c.dtype, self.np_dtype)\n        assert np.allclose(\n            np_slice_c, np.ones((2, 2), dtype=self.np_dtype) * (3.14 + 2j)\n        )\n\n    def test_new_tensor(self):\n        a = flow.tensor(self.a, dtype=self.dtype)\n        b = a.new_tensor(self.b)\n        self.assertEqual(b.dtype, self.dtype)\n        self.assertEqual(b.type(), \"oneflow.\" + self.type_str)\n        np_b = b.numpy()\n        self.assertEqual(np_b.dtype, self.np_dtype)\n        assert np.allclose(np_b, self.np_b)\n\n    def test_new_empty(self):\n        a = flow.tensor(self.a, dtype=self.dtype)\n        c = a.new_empty((3, 2))\n        self.assertEqual(c.dtype, self.dtype)\n        self.assertEqual(c.type(), \"oneflow.\" + self.type_str)\n        np_c = c.numpy()\n        self.assertEqual(np_c.dtype, self.np_dtype)\n\n    def test_ones(self):\n        c = flow.ones((3, 2), dtype=self.dtype)\n        self.assertEqual(c.dtype, self.dtype)\n        self.assertEqual(c.type(), \"oneflow.\" + self.type_str)\n        np_c = c.numpy()\n        self.assertEqual(np_c.dtype, self.np_dtype)\n        assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_dtype))\n\n    def test_new_ones(self):\n        b = flow.tensor(self.b, dtype=self.dtype)\n        c = b.new_ones((3, 2))\n        self.assertEqual(c.dtype, self.dtype)\n        self.assertEqual(c.type(), \"oneflow.\" + self.type_str)\n        np_c = c.numpy()\n        self.assertEqual(np_c.dtype, self.np_dtype)\n        assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_dtype))\n\n    def test_zeros(self):\n        c = flow.zeros((3, 2), dtype=self.dtype)\n        self.assertEqual(c.dtype, self.dtype)\n        self.assertEqual(c.type(), \"oneflow.\" + self.type_str)\n        np_c = c.numpy()\n        self.assertEqual(np_c.dtype, self.np_dtype)\n        assert np.allclose(np_c, np.zeros((3, 2), dtype=self.np_dtype))\n\n    def test_new_zeros(self):\n        b = flow.tensor(self.b, dtype=self.dtype)\n        
c = b.new_zeros((3, 2))\n        self.assertEqual(c.dtype, self.dtype)\n        self.assertEqual(c.type(), \"oneflow.\" + self.type_str)\n        np_c = c.numpy()\n        self.assertEqual(np_c.dtype, self.np_dtype)\n        assert np.allclose(np_c, np.zeros((3, 2), dtype=self.np_dtype))\n\n    def test_full(self):\n        c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype)\n        self.assertEqual(c.dtype, self.dtype)\n        self.assertEqual(c.type(), \"oneflow.\" + self.type_str)\n        np_c = c.numpy()\n        self.assertEqual(np_c.dtype, self.np_dtype)\n        assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_dtype) * (3.14 + 2j))\n\n    def test_new_full(self):\n        a = flow.tensor(self.a, dtype=self.dtype)\n        c = a.new_full((3, 2), 3.14 + 2j)\n        self.assertEqual(c.dtype, self.dtype)\n        self.assertEqual(c.type(), \"oneflow.\" + self.type_str)\n        np_c = c.numpy()\n        self.assertEqual(np_c.dtype, self.np_dtype)\n        assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_dtype) * (3.14 + 2j))\n\n    def test_real(self):\n        c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype).real()\n        self.assertEqual(c.dtype, self.real_dtype)\n        np_c = c.numpy()\n        self.assertEqual(np_c.dtype, self.np_real_dtype)\n        assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_real_dtype) * 3.14)\n\n    def test_imag(self):\n        c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype).imag()\n        self.assertEqual(c.dtype, self.real_dtype)\n        np_c = c.numpy()\n        self.assertEqual(np_c.dtype, self.np_real_dtype)\n        assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_real_dtype) * 2)\n\n    def test_conj(self):\n        c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype).conj()\n        self.assertEqual(c.dtype, self.dtype)\n        self.assertEqual(c.type(), \"oneflow.\" + self.type_str)\n        np_c = c.numpy()\n        self.assertEqual(np_c.dtype, self.np_dtype)\n        assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_dtype) * (3.14 - 2j))\n\n    def test_conj_physical(self):\n        c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype).conj_physical()\n        self.assertEqual(c.dtype, self.dtype)\n        self.assertEqual(c.type(), \"oneflow.\" + self.type_str)\n        np_c = c.numpy()\n        self.assertEqual(np_c.dtype, self.np_dtype)\n        assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_dtype) * (3.14 - 2j))\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_real_cuda(self):\n        c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype, device=\"cuda\").real()\n        self.assertEqual(c.dtype, self.real_dtype)\n        np_c = c.numpy()\n        self.assertEqual(np_c.dtype, self.np_real_dtype)\n        assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_real_dtype) * 3.14)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_imag_cuda(self):\n        c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype, device=\"cuda\").imag()\n        self.assertEqual(c.dtype, self.real_dtype)\n        np_c = c.numpy()\n        self.assertEqual(np_c.dtype, self.np_real_dtype)\n        assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_real_dtype) * 2)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_conj_cuda(self):\n        c = flow.full((3, 2), 3.14 + 2j, dtype=self.dtype, device=\"cuda\").conj()\n        self.assertEqual(c.dtype, self.dtype)\n        
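# conj() on a CUDA tensor should stay on the device, hence the cuda type string below\n        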
self.assertEqual(c.type(), \"oneflow.cuda.\" + self.type_str)\n        np_c = c.numpy()\n        self.assertEqual(np_c.dtype, self.np_dtype)\n        assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_dtype) * (3.14 - 2j))\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_conj_physical_cuda(self):\n        c = flow.full(\n            (3, 2), 3.14 + 2j, dtype=self.dtype, device=\"cuda\"\n        ).conj_physical()\n        self.assertEqual(c.dtype, self.dtype)\n        self.assertEqual(c.type(), \"oneflow.cuda.\" + self.type_str)\n        np_c = c.numpy()\n        self.assertEqual(np_c.dtype, self.np_dtype)\n        assert np.allclose(np_c, np.ones((3, 2), dtype=self.np_dtype) * (3.14 - 2j))\n\n    def test_add_cpu(self):\n        device = \"cpu\"\n        for i, input_shape in enumerate(self.shape):\n            np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_x = np_x.astype(self.np_dtype)\n\n            np_y = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_y = np_y.astype(self.np_dtype)\n\n            flow_x = flow.from_numpy(np_x).to(device).requires_grad_(True)\n            flow_y = flow.from_numpy(np_y).to(device).requires_grad_(True)\n            self.assertEqual(flow_x.dtype, self.dtype)\n            self.assertEqual(flow_y.dtype, self.dtype)\n\n            # forward\n            flow_ret = flow.add(flow_x, flow_y)\n            np_ret = np_x + np_y\n            compare_result(flow_ret, np_ret, self.rtol, self.atol)\n\n            # backward\n            flow_ret.sum().backward()\n            compare_result(\n                flow_x.grad.numpy(), np.ones(input_shape), self.rtol, self.atol\n            )\n            compare_result(\n                flow_y.grad.numpy(), np.ones(input_shape), self.rtol, self.atol\n            )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_add_cuda(self):\n        device = \"cuda\"\n        for i, input_shape in enumerate(self.shape):\n            np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_x = np_x.astype(self.np_dtype)\n\n            np_y = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_y = np_y.astype(self.np_dtype)\n\n            flow_x = flow.from_numpy(np_x).to(device).requires_grad_(True)\n            flow_y = flow.from_numpy(np_y).to(device).requires_grad_(True)\n            self.assertEqual(flow_x.dtype, self.dtype)\n            self.assertEqual(flow_y.dtype, self.dtype)\n\n            # forward\n            flow_ret = flow.add(flow_x, flow_y)\n            np_ret = np_x + np_y\n            compare_result(flow_ret.cpu().detach(), np_ret, self.rtol, self.atol)\n\n            # backward\n            flow_ret.sum().backward()\n            compare_result(\n                flow_x.grad.cpu().detach().numpy(),\n                np.ones(input_shape),\n                self.rtol,\n                self.atol,\n            )\n            compare_result(\n                flow_y.grad.cpu().detach().numpy(),\n                np.ones(input_shape),\n                self.rtol,\n                self.atol,\n            )\n\n    def test_sub_cpu(self):\n        device = \"cpu\"\n        for i, input_shape in enumerate(self.shape):\n            np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_x = np_x.astype(self.np_dtype)\n\n            np_y = 
np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_y = np_y.astype(self.np_dtype)\n\n            flow_x = flow.from_numpy(np_x).to(device).requires_grad_(True)\n            flow_y = flow.from_numpy(np_y).to(device).requires_grad_(True)\n            self.assertEqual(flow_x.dtype, self.dtype)\n            self.assertEqual(flow_y.dtype, self.dtype)\n\n            # forward\n            flow_ret = flow.sub(flow_x, flow_y)\n            np_ret = np_x - np_y\n            compare_result(flow_ret, np_ret, self.rtol, self.atol)\n\n            # backward\n            flow_ret.sum().backward()\n            compare_result(\n                flow_x.grad.numpy(), np.ones(input_shape), self.rtol, self.atol\n            )\n            compare_result(\n                flow_y.grad.numpy(), -np.ones(input_shape), self.rtol, self.atol\n            )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_sub_cuda(self):\n        device = \"cuda\"\n        for i, input_shape in enumerate(self.shape):\n            np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_x = np_x.astype(self.np_dtype)\n\n            np_y = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_y = np_y.astype(self.np_dtype)\n\n            flow_x = flow.from_numpy(np_x).to(device).requires_grad_(True)\n            flow_y = flow.from_numpy(np_y).to(device).requires_grad_(True)\n            self.assertEqual(flow_x.dtype, self.dtype)\n            self.assertEqual(flow_y.dtype, self.dtype)\n\n            # forward\n            flow_ret = flow.sub(flow_x, flow_y)\n            np_ret = np_x - np_y\n            compare_result(flow_ret.cpu().detach(), np_ret, self.rtol, self.atol)\n\n            # backward\n            flow_ret.sum().backward()\n            compare_result(\n                flow_x.grad.cpu().detach().numpy(),\n                np.ones(input_shape),\n                self.rtol,\n                self.atol,\n            )\n            compare_result(\n                flow_y.grad.cpu().detach().numpy(),\n                -np.ones(input_shape),\n                self.rtol,\n                self.atol,\n            )\n\n    def test_mul_cpu(self):\n        device = \"cpu\"\n        for i, input_shape in enumerate(self.shape):\n            np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_x = np_x.astype(self.np_dtype)\n\n            np_y = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_y = np_y.astype(self.np_dtype)\n\n            flow_x = flow.from_numpy(np_x).to(device).requires_grad_(True)\n            flow_y = flow.from_numpy(np_y).to(device).requires_grad_(True)\n            self.assertEqual(flow_x.dtype, self.dtype)\n            self.assertEqual(flow_y.dtype, self.dtype)\n\n            # forward\n            flow_ret = flow.mul(flow_x, flow_y)\n            np_ret = np_x * np_y\n            compare_result(flow_ret, np_ret, self.rtol, self.atol)\n\n            # backward\n            flow_ret.sum().backward()\n            compare_result(\n                flow_x.grad.numpy(), flow_y.numpy().conjugate(), self.rtol, self.atol\n            )\n            compare_result(\n                flow_y.grad.numpy(), flow_x.numpy().conjugate(), self.rtol, self.atol\n            )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_mul_cuda(self):\n        device = \"cuda\"\n   
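     # for z = x * y, the expected grads are the conjugated operands (conjugate Wirtinger convention), mirroring test_mul_cpu\n   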
     for i, input_shape in enumerate(self.shape):\n            np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_x = np_x.astype(self.np_dtype)\n\n            np_y = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_y = np_y.astype(self.np_dtype)\n\n            flow_x = flow.from_numpy(np_x).to(device).requires_grad_(True)\n            flow_y = flow.from_numpy(np_y).to(device).requires_grad_(True)\n            self.assertEqual(flow_x.dtype, self.dtype)\n            self.assertEqual(flow_y.dtype, self.dtype)\n\n            # forward\n            flow_ret = flow.mul(flow_x, flow_y)\n            np_ret = np_x * np_y\n            compare_result(flow_ret.cpu().detach(), np_ret, self.rtol, self.atol)\n\n            # backward\n            flow_ret.sum().backward()\n            compare_result(\n                flow_x.grad.cpu().detach().numpy(),\n                flow_y.numpy().conjugate(),\n                self.rtol,\n                self.atol,\n            )\n            compare_result(\n                flow_y.grad.cpu().detach().numpy(),\n                flow_x.numpy().conjugate(),\n                self.rtol,\n                self.atol,\n            )\n\n    def test_sum_cpu(self):\n        device = \"cpu\"\n        for i, input_shape in enumerate(self.shape):\n            n_dims = np.random.randint(1, len(input_shape))\n            dims = np.random.choice(\n                len(input_shape) - 1, n_dims, replace=False\n            ).tolist()\n            keepdim = True if np.random.randint(2) == 1 else False\n\n            np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_x = np_x.astype(self.np_dtype)\n\n            flow_x = flow.from_numpy(np_x).to(device).requires_grad_(True)\n            self.assertEqual(flow_x.dtype, self.dtype)\n\n            # forward\n            flow_ret = flow.sum(flow_x, dim=dims, keepdim=keepdim)\n            np_ret = np.sum(np_x, axis=tuple(dims), keepdims=keepdim)\n            compare_result(flow_ret, np_ret, self.rtol, self.atol * 1000)\n\n            # backward\n            flow_ret.sum().backward()\n            compare_result(\n                flow_x.grad.numpy(), np.ones(input_shape), self.rtol, self.atol\n            )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_sum_cuda(self):\n        device = \"cuda\"\n        for i, input_shape in enumerate(self.shape):\n            n_dims = np.random.randint(1, len(input_shape))\n            dims = np.random.choice(\n                len(input_shape) - 1, n_dims, replace=False\n            ).tolist()\n            keepdim = True if np.random.randint(2) == 1 else False\n\n            np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_x = np_x.astype(self.np_dtype)\n\n            flow_x = flow.from_numpy(np_x).to(device).requires_grad_(True)\n            self.assertEqual(flow_x.dtype, self.dtype)\n\n            # forward\n            flow_ret = flow.sum(flow_x, dim=dims, keepdim=keepdim)\n            np_ret = np.sum(np_x, axis=tuple(dims), keepdims=keepdim)\n            compare_result(flow_ret.cpu().detach(), np_ret, self.rtol, self.atol * 1000)\n\n            # backward\n            flow_ret.sum().backward()\n            compare_result(\n                flow_x.grad.cpu().detach().numpy(),\n                np.ones(input_shape),\n                self.rtol,\n                self.atol,\n            )\n\n    def 
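test_div_sketch_cpu(self):\n        # Hypothetical sketch, not from the original suite: flow.div() is listed\n        # under \"To complete\" in the TODO above, so this mirrors test_mul_cpu's\n        # forward check and stays skipped until complex div support lands.\n        self.skipTest(\"sketch only: complex flow.div() support is still TODO\")\n        np_x = (np.random.randn(4, 4) + 1.0j * np.random.randn(4, 4)).astype(self.np_dtype)\n        np_y = (np.random.randn(4, 4) + 1.0j * np.random.randn(4, 4)).astype(self.np_dtype)\n        flow_x = flow.from_numpy(np_x)\n        flow_y = flow.from_numpy(np_y)\n        compare_result(flow.div(flow_x, flow_y), np_x / np_y, self.rtol, self.atol)\n\n    def 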
test_equal_cpu(self):\n        device = \"cpu\"\n        for i, input_shape in enumerate(self.shape):\n\n            np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_x = np_x.astype(self.np_dtype)\n\n            np_y = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_y = np_y.astype(self.np_dtype)\n\n            np_z = np.copy(np_x)\n\n            flow_x = flow.from_numpy(np_x).to(device).requires_grad_(False)\n            flow_y = flow.from_numpy(np_y).to(device).requires_grad_(False)\n            flow_z = flow.from_numpy(np_z).to(device).requires_grad_(False)\n            self.assertEqual(flow_x.dtype, self.dtype)\n            self.assertEqual(flow_y.dtype, self.dtype)\n            self.assertEqual(flow_z.dtype, self.dtype)\n\n            # forward\n            flow_ret = flow.equal(flow_x, flow_y)\n            np_ret = np.equal(np_x, np_y)\n            compare_result(flow_ret, np_ret, self.rtol, self.atol)\n\n            flow_ret = flow.equal(flow_x, flow_z)\n            compare_result(\n                flow_ret, np.ones(flow_x.shape).astype(bool), self.rtol, self.atol\n            )\n\n            flow_ret = flow.not_equal(flow_x, flow_z)\n            compare_result(\n                flow_ret, np.zeros(flow_x.shape).astype(bool), self.rtol, self.atol\n            )\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_equal_cuda(self):\n        device = \"cuda\"\n        for i, input_shape in enumerate(self.shape):\n\n            np_x = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_x = np_x.astype(self.np_dtype)\n\n            np_y = np.random.randn(*input_shape) + 1.0j * np.random.randn(*input_shape)\n            np_y = np_y.astype(self.np_dtype)\n\n            np_z = np.copy(np_x)\n\n            flow_x = flow.from_numpy(np_x).to(device).requires_grad_(False)\n            flow_y = flow.from_numpy(np_y).to(device).requires_grad_(False)\n            flow_z = flow.from_numpy(np_z).to(device).requires_grad_(False)\n            self.assertEqual(flow_x.dtype, self.dtype)\n            self.assertEqual(flow_y.dtype, self.dtype)\n            self.assertEqual(flow_z.dtype, self.dtype)\n\n            # forward\n            flow_ret = flow.equal(flow_x, flow_y)\n            np_ret = np.equal(np_x, np_y)\n            compare_result(flow_ret, np_ret, self.rtol, self.atol)\n\n            flow_ret = flow.equal(flow_x, flow_z)\n            compare_result(\n                flow_ret, np.ones(flow_x.shape).astype(bool), self.rtol, self.atol\n            )\n\n            flow_ret = flow.not_equal(flow_x, flow_z)\n            compare_result(\n                flow_ret.cpu().detach(),\n                np.zeros(flow_x.shape).astype(bool),\n                self.rtol,\n                self.atol,\n            )\n\n    def test_constant_pad(self):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(1, 2, 3, 4), (8, 3, 4, 4)]\n        arg_dict[\"padding\"] = [2, (1, 1, 2, 2)]\n        arg_dict[\"value\"] = [0.0]\n        arg_dict[\"device\"] = (\n            [\"cpu\", \"cuda\"] if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\") is None else [\"cpu\"]\n        )\n        arg_dict[\"rtol\"] = [self.rtol]\n        arg_dict[\"atol\"] = [self.atol]\n        for arg in GenArgList(arg_dict):\n            _test_ZeroPad2d(self, *arg)\n\n    def test_cast(self):\n        dtype_pairs = [\n            (np.uint8, \"ByteTensor\"),\n            (np.int8, 
\"CharTensor\"),\n            (np.int32, \"IntTensor\"),\n            (np.int64, \"LongTensor\"),\n            (np.float32, \"FloatTensor\"),\n            (np.float64, \"DoubleTensor\"),\n        ]\n        shape = (3, 5, 2)\n        for np_dtype, type_str in dtype_pairs:\n            np_arr = np.random.randn(*shape).astype(np_dtype)\n            flow_tensor = flow.from_numpy(np_arr)\n            self.assertEqual(flow_tensor.type(), \"oneflow.\" + type_str)\n            np_out = np_arr.astype(self.np_dtype)\n            flow_out = flow.cast(flow_tensor, dtype=self.complex_dtype)\n            self.assertTrue(np.array_equal(flow_out.numpy(), np_out))\n\n        # cp64 -> cp128\n        np_arr = np.random.randn(*shape) + 1.0j * np.random.randn(*shape)\n        np_arr = np_arr.astype(np.complex64)\n        flow_tensor = flow.from_numpy(np_arr)\n        self.assertEqual(flow_tensor.dtype, flow.complex64)\n\n        np_out = np_arr.astype(np.complex128)\n        flow_out = flow.cast(flow_tensor, dtype=flow.complex128)\n        self.assertTrue(np.array_equal(flow_out.numpy(), np_out))\n\n        # cp128 -> cp64\n        np_out = np_out.astype(np.complex64)\n        flow_out = flow.cast(flow_out, dtype=flow.complex64)\n        self.assertTrue(np.array_equal(flow_out.numpy(), np_out))\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_cast_cuda(self):\n        dtype_pairs = [\n            (np.uint8, \"ByteTensor\"),\n            (np.int8, \"CharTensor\"),\n            (np.int32, \"IntTensor\"),\n            (np.int64, \"LongTensor\"),\n            (np.float32, \"FloatTensor\"),\n            (np.float64, \"DoubleTensor\"),\n        ]\n        shape = (7, 4, 11)\n        for np_dtype, type_str in dtype_pairs:\n            np_arr = np.random.randn(*shape).astype(np_dtype)\n            flow_tensor = flow.from_numpy(np_arr).cuda()\n            self.assertEqual(flow_tensor.type(), \"oneflow.cuda.\" + type_str)\n            np_out = np_arr.astype(self.np_dtype)\n            flow_out = flow.cast(flow_tensor, dtype=self.complex_dtype)\n            self.assertTrue(np.array_equal(flow_out.cpu().detach().numpy(), np_out))\n\n        # cp64 -> cp128\n        np_arr = np.random.randn(*shape) + 1.0j * np.random.randn(*shape)\n        np_arr = np_arr.astype(np.complex64)\n        flow_tensor = flow.from_numpy(np_arr).cuda()\n        self.assertEqual(flow_tensor.dtype, flow.complex64)\n\n        np_out = np_arr.astype(np.complex128)\n        flow_out = flow.cast(flow_tensor, dtype=flow.complex128)\n        self.assertTrue(np.array_equal(flow_out.cpu().detach().numpy(), np_out))\n\n        # cp128 -> cp64\n        np_out = np_out.astype(np.complex64)\n        flow_out = flow.cast(flow_out, dtype=flow.complex64)\n        self.assertTrue(np.array_equal(flow_out.cpu().detach().numpy(), np_out))\n\n\nclass TestTensorComplex128(TestTensorComplex64):\n    def setUp(self):\n        self.dtype = flow.cdouble\n        self.complex_dtype = flow.complex128\n        self.np_dtype = np.complex128\n        self.type_str = \"ComplexDoubleTensor\"\n        self.real_dtype = flow.double\n        self.np_real_dtype = np.float64\n        self.rtol = 1e-7\n        self.atol = 1e-7\n        self.a = [1.0 + 1j, 2.0]\n        self.np_a = np.array(self.a, dtype=self.np_dtype)\n        self.b = [[1.0 + 1j, 2.0], [1.0, 2.0 - 1j], [-1.0, 1j]]\n        self.np_b = np.array(self.b, dtype=self.np_dtype)\n\n        self.lower_n_dims = 2\n        self.upper_n_dims = 5\n        self.shape = []\n     
   for _ in range(10):\n            num_dims = np.random.randint(self.lower_n_dims, self.upper_n_dims)\n            shape_ = [np.random.randint(1, 11) * 4 for _ in range(num_dims)]\n            self.shape.append(shape_)\n\n\nclass TestAutograd(unittest.TestCase):\n    def test_backward(self):\n        a = flow.tensor([1.0 + 2j, 2.0 - 3j, 1j], dtype=flow.cfloat)\n        a.requires_grad = True\n        b = flow.conj(a)\n        loss = flow.sum(a.real() + b.imag())\n        loss.backward()\n        assert np.allclose(a.grad.numpy(), np.ones((3,), dtype=np.complex64) * (1 - 1j))\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_backward_cuda(self):\n        a = flow.tensor([1.0 + 2j, 2.0 - 3j, 1j], dtype=flow.cfloat, device=\"cuda\")\n        a.requires_grad = True\n        b = flow.conj(a)\n        loss = flow.sum(a.real() + b.imag())\n        loss.backward()\n        assert np.allclose(a.grad.numpy(), np.ones((3,), dtype=np.complex64) * (1 - 1j))\n\n    def test_grad(self):\n        a = flow.tensor([1.0 + 2j, 2.0 - 3j, 1j], dtype=flow.cfloat)\n        a.requires_grad = True\n        b = flow.conj(a)\n        c = a.real() + b.imag()\n        np_dc = np.ones((3,), dtype=np.float32)\n        dc = flow.tensor(np_dc)\n        (da,) = flow.autograd.grad(c, a, dc)\n        assert np.allclose(da.numpy(), np.ones((3,), dtype=np.complex64) * (1 - 1j))\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_grad_cuda(self):\n        a = flow.tensor([1.0 + 2j, 2.0 - 3j, 1j], dtype=flow.cfloat, device=\"cuda\")\n        a.requires_grad = True\n        b = flow.conj(a)\n        c = a.real() + b.imag()\n        np_dc = np.ones((3,), dtype=np.float32)\n        dc = flow.tensor(np_dc)\n        (da,) = flow.autograd.grad(c, a, dc)\n        assert np.allclose(da.numpy(), np.ones((3,), dtype=np.complex64) * (1 - 1j))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_data_ptr.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestDataPtr(unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_equality(test_case):\n        x = flow.ones(2, 3)\n        y = flow.ones(2, 3)\n        test_case.assertNotEqual(x.data_ptr(), y.data_ptr())\n\n        test_case.assertEqual(x.data_ptr(), x.data.data_ptr())\n\n        x_ptr = x.data_ptr()\n        x[:] = 2\n        test_case.assertEqual(x_ptr, x.data_ptr())\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_global_tensor(test_case):\n        x = flow.randn(\n            2, 3, placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.broadcast\n        )\n        test_case.assertEqual(x.data_ptr(), x.to_local().data_ptr())\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_global_tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestTensor(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_creating_global_tensor(test_case):\n        placement = flow.placement(\"cuda\", [0])\n        sbp = flow.sbp.broadcast\n\n        # Shape -> GlobalTensor\n        shape = (2, 3)\n        x = flow.Tensor(*shape, placement=placement, sbp=sbp)\n        test_case.assertTrue(x.is_global)\n        test_case.assertTrue(x.size() == shape)\n\n        shape = flow.Size((2, 3))\n        x = flow.Tensor(shape, placement=placement, sbp=sbp)\n        test_case.assertTrue(x.is_global)\n        test_case.assertTrue(x.size() == shape)\n\n        # LocalTensor -> GlobalTensor\n        x = flow.Tensor(*shape, device=\"cpu\")\n        test_case.assertTrue(x.is_local)\n        y = flow.Tensor(x, placement=placement, sbp=sbp)\n        test_case.assertTrue(y.is_global)\n\n        # GlobalTensor -> GlobalTensor\n        z = flow.Tensor(y, placement=placement, sbp=sbp)\n        test_case.assertTrue(z.is_global)\n\n        # TODO: ndarray -> GlobalTensor\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_construct_local_from_global_tensor(test_case):\n        placement = flow.placement(\"cuda\", [0])\n        sbp = flow.sbp.broadcast\n        shape = (2, 3)\n        x = flow.Tensor(*shape, placement=placement, sbp=sbp)\n        test_case.assertTrue(x.is_global)\n        # GlobalTensor -> LocalTensor\n        y = flow.Tensor(x, device=\"cpu\")\n        test_case.assertTrue(y.is_local)\n        y = flow.Tensor(x, device=\"cuda\")\n        test_case.assertTrue(y.is_local)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_global_set_data(test_case):\n        x_placement = flow.placement(\"cpu\", [0])\n        x_sbp = flow.sbp.broadcast\n        x = flow.ones(2, 3, placement=x_placement, sbp=x_sbp)\n        y_placement = flow.placement(\"cuda\", [0])\n        y_sbp = flow.sbp.split(0)\n        y = flow.ones(4, 5, placement=y_placement, sbp=y_sbp)\n        old_id = id(x)\n        x.data = y\n        test_case.assertEqual(old_id, id(x))\n        test_case.assertTrue(x.shape == (4, 5))\n        test_case.assertTrue(x.placement == y_placement)\n        test_case.assertTrue(x.sbp[0] == y_sbp)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_global_tensor_autograd_related_methods(test_case):\n        placement = flow.placement(\"cuda\", [0])\n        sbp = flow.sbp.split(0)\n        shape = (2, 3, 4, 5)\n        l_x = flow.Tensor(*shape)\n        test_case.assertFalse(l_x.requires_grad)\n        test_case.assertTrue(l_x.is_leaf)\n\n        l_y = flow.Tensor(*shape)\n        l_y.requires_grad = True\n        
test_case.assertTrue(l_y.requires_grad)\n        test_case.assertTrue(l_y.is_leaf)\n\n        x = l_x.to_global(placement=placement, sbp=sbp)\n        test_case.assertTrue(x.is_leaf)\n        y = l_y.to_global(placement=placement, sbp=sbp)\n        test_case.assertFalse(y.is_leaf)\n\n        z = x + y\n        test_case.assertTrue(z.requires_grad)\n        test_case.assertFalse(z.is_leaf)\n\n        with flow.no_grad():\n            m = x + y\n\n        test_case.assertTrue(m.is_leaf)\n        test_case.assertFalse(m.requires_grad)\n\n        l_v = flow.Tensor(*shape)\n        l_v.requires_grad = True\n        v = l_v.to_global(placement=placement, sbp=sbp)\n\n        z.retain_grad()\n        w = v + z\n\n        l_grad = flow.ones(*shape)\n        grad = l_grad.to_global(placement=placement, sbp=sbp)\n        w.backward(gradient=grad)\n\n        test_case.assertTrue(\n            np.allclose(l_v.grad.numpy(), np.ones(shape), atol=1e-4, rtol=1e-4)\n        )\n        test_case.assertTrue(\n            np.allclose(l_y.grad.numpy(), np.ones(shape), atol=1e-4, rtol=1e-4)\n        )\n        test_case.assertTrue(\n            np.allclose(\n                z.grad.to_global(sbp=flow.sbp.broadcast).to_local().numpy(),\n                np.ones(shape),\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n        test_case.assertIsNone(l_x.grad)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_global_tensor_unsupported_property(test_case):\n\n        shape = (2, 3)\n        placement = flow.placement(\"cuda\", [0])\n        sbp = flow.sbp.split(0)\n        a = flow.Tensor(*shape)\n        b = a.to_global(placement=placement, sbp=sbp)\n        test_case.assertTrue(b.is_global)\n\n        with test_case.assertRaises(RuntimeError):\n            b.device()\n\n        with test_case.assertRaises(RuntimeError):\n            b._tensor_buffer_shapes_and_dtypes\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_global_tensor_2d_sbp_init(test_case):\n        V = 10\n        H = 4\n        S = 6\n\n        P = flow.placement(\"cuda\", [[0, 1], [2, 3]])\n\n        wte = flow.nn.Parameter(\n            flow.empty(\n                (V, H),\n                dtype=flow.float32,\n                placement=P,\n                sbp=[flow.sbp.broadcast, flow.sbp.split(0)],\n            )\n        )\n\n        wpe = flow.nn.Parameter(\n            flow.empty(\n                (S, H),\n                dtype=flow.float32,\n                placement=P,\n                sbp=[flow.sbp.broadcast, flow.sbp.broadcast],\n            )\n        )\n\n        flow.nn.init.normal_(wte, std=0.02)\n        flow.nn.init.normal_(wpe, std=0.02)\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_copy(test_case):\n        x = flow.zeros(2, 3)\n        y = flow.ones(2, 3)\n        x.copy_(y)\n        test_case.assertTrue(np.array_equal(x.numpy(), y.numpy()))\n\n        x = flow.zeros(\n            4, 6, placement=flow.placement(\"cuda\", [0, 1]), sbp=flow.sbp.broadcast\n        )\n        y = flow.ones(\n            4, 6, placement=flow.placement(\"cpu\", [0]), sbp=flow.sbp.broadcast\n        )\n        x.copy_(y)\n        test_case.assertTrue(np.array_equal(x.numpy(), y.numpy()))\n\n        x = flow.zeros(\n            4, 6, placement=flow.placement(\"cuda\", [0, 1]), sbp=flow.sbp.broadcast\n        )\n        y = flow.ones(\n            4, 6, placement=flow.placement(\"cuda\", [0]), sbp=flow.sbp.broadcast\n        )\n        x.copy_(y)\n        test_case.assertTrue(np.array_equal(x.numpy(), 
y.numpy()))\n\n        x = flow.zeros(\n            4, 6, placement=flow.placement(\"cuda\", [0, 1]), sbp=flow.sbp.split(0)\n        )\n        y = flow.ones(\n            4, 6, placement=flow.placement(\"cuda\", [0, 1]), sbp=flow.sbp.broadcast\n        )\n        x.copy_(y)\n        test_case.assertTrue(np.array_equal(x.numpy(), y.numpy()))\n\n        x = flow.zeros(\n            4, 6, placement=flow.placement(\"cuda\", [0, 1]), sbp=flow.sbp.broadcast\n        )\n        y = flow.ones(\n            4, 6, placement=flow.placement(\"cuda\", [0, 1]), sbp=flow.sbp.broadcast\n        )\n        x.copy_(y)\n        test_case.assertTrue(np.array_equal(x.numpy(), y.numpy()))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_global_tensor_and_ndarray_compatibility.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nfrom collections import OrderedDict\n\n\nimport torch\nimport oneflow as flow\nimport unittest\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\n\nfrom oneflow.test_utils.automated_test_util import *\nimport numpy as np\n\nnp.random.seed(233)\n\ntest_compute_op_list = [\n    \"+\",\n    \"-\",\n    \"*\",\n    \"/\",\n    \"**\",\n    \"//\",\n    \"%\",\n]\n\n\ndef do_test_compute_op(test_case, ndim, placement, sbp):\n    dims = [random(1, 4) * 8 for i in range(ndim)]\n    x = random_tensor(ndim, *dims, dtype=int, low=0, high=5)\n    x = x.to_global(placement=placement, sbp=sbp)\n    x = x.to(\"cpu\")\n    flow_input = x.oneflow.detach()\n    torch_input = x.pytorch.detach()\n\n    for op in test_compute_op_list:\n        if op not in [\"**\"]:\n            random_numpy = np.random.randint(1, 30000, size=list(flow_input.shape))\n        else:\n            random_numpy = np.random.randint(1, 5, size=list(flow_input.shape))\n\n        z_flow = eval(f\"flow_input {op} random_numpy\")\n        z_torch = eval(f\"torch_input {op} random_numpy\")\n        test_case.assertTrue(np.allclose(z_flow.numpy(), z_torch.numpy()))\n\n\nclass TestGlobalTensorAndNdarrayCompatibility(flow.unittest.TestCase):\n    @globaltest\n    def test_tensor_and_ndarray_compatibility(test_case):\n        # random ndim in range [1,4]\n        ndim = random(1, 5).to(int).value()\n        for placement in all_placement():\n            for sbp in all_sbp(placement, max_dim=ndim):\n                do_test_compute_op(test_case, ndim, placement, sbp)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_global_tensor_indexing.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n# This test code is referenced from: https://github.com/pytorch/pytorch/blob/cd41c8f032dd06c445bf97fc76fb82008b19afcb/test/test_indexing.py\n\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow.unittest\n\n\ndef _randint(low, high):\n    \"\"\"\n    Get a random integer in the range [low, high).\n    \"\"\"\n    return random(low, high).to(int).value()\n\n\ndef _cpu_global_tensor(tensor):\n    return tensor.to_global(flow.placement.all(\"cpu\"), flow.sbp.broadcast)\n\n\ndef _assert_tensor_equal(test_case, tensor1, tensor2, atol=0.0, rtol=0.0):\n    test_case.assertTrue(\n        np.allclose(tensor1.numpy(), tensor2.numpy(), atol, rtol),\n        f\"{tensor1.numpy()} vs {tensor2.numpy()}\",\n    )\n\n\ndef global_broadcast_consec(size, start=1):\n    \"\"\"\n    Generate a arithmetic progression with given size and start value.\n    \"\"\"\n    sequence = flow.ones([int(np.array(size).prod(0)),]).cumsum(0)\n    sequence.add_(start - 1)\n    return _cpu_global_tensor(sequence.view(*size))\n\n\ndef _test_basic_slice(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n\n    ref_sbp = random_sbp(placement, max_dim=3).value()\n    reference = global_broadcast_consec((8, 8, 8)).to_global(placement, ref_sbp)\n\n    # empty tensor indexing\n    _assert_tensor_equal(\n        test_case,\n        reference[\n            _cpu_global_tensor(flow.LongTensor()).to_global(\n                placement, broadcast_for_placement\n            )\n        ],\n        flow.empty(0, 8, 8),\n        atol=0,\n        rtol=0,\n    )\n\n    _assert_tensor_equal(\n        test_case, reference[0], global_broadcast_consec((8, 8)), atol=0, rtol=0\n    )\n    _assert_tensor_equal(\n        test_case, reference[1], global_broadcast_consec((8, 8), 65), atol=0, rtol=0\n    )\n    _assert_tensor_equal(\n        test_case, reference[2], global_broadcast_consec((8, 8), 129), atol=0, rtol=0\n    )\n    _assert_tensor_equal(\n        test_case, reference[0, 1], global_broadcast_consec((8,), 9), atol=0, rtol=0\n    )\n    _assert_tensor_equal(\n        test_case, reference[0:2], global_broadcast_consec((2, 8, 8)), atol=0, rtol=0\n    )\n    test_case.assertEqual(reference[2, 2, 2].item(), 147)\n    _assert_tensor_equal(\n        test_case, reference[:], global_broadcast_consec((8, 8, 8)), atol=0, rtol=0\n    )\n\n    # indexing with Ellipsis\n    _assert_tensor_equal(\n        test_case,\n        reference[..., 2, 2],\n        flow.tensor([19, 83, 147, 211, 275, 339, 403, 467]),\n        atol=0,\n        rtol=0,\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[0, ..., 2],\n        flow.tensor([3, 11, 19, 27, 35, 43, 51, 59]),\n        atol=0,\n        rtol=0,\n    )\n    _assert_tensor_equal(\n        test_case, reference[..., 2], reference[:, :, 2], 
atol=0, rtol=0\n    )\n    _assert_tensor_equal(\n        test_case, reference[0, ..., 2], reference[0, :, 2], atol=0, rtol=0\n    )\n    _assert_tensor_equal(\n        test_case, reference[0, 2, ...], reference[0, 2], atol=0, rtol=0\n    )\n    test_case.assertEqual(reference[..., 2, 2, 2].item(), 147)\n    test_case.assertEqual(reference[2, ..., 2, 2].item(), 147)\n    test_case.assertEqual(reference[2, 2, ..., 2].item(), 147)\n    test_case.assertEqual(reference[2, 2, 2, ...].item(), 147)\n    _assert_tensor_equal(test_case, reference[...], reference, atol=0, rtol=0)\n\n    reference_5d = global_broadcast_consec((8, 8, 8, 8, 8)).to_global(\n        placement, sbp=random_sbp(placement, max_dim=5).value()\n    )\n    _assert_tensor_equal(\n        test_case, reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], atol=0, rtol=0\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference_5d[2, ..., 1, 0],\n        reference_5d[2, :, :, 1, 0],\n        atol=0,\n        rtol=0,\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference_5d[2, 1, 0, ..., 1],\n        reference_5d[2, 1, 0, :, 1],\n        atol=0,\n        rtol=0,\n    )\n    _assert_tensor_equal(test_case, reference_5d[...], reference_5d, atol=0, rtol=0)\n\n    # LongTensor indexing\n    sbp = random_sbp(placement, max_dim=3).value()\n    reference = global_broadcast_consec((8, 8, 8)).to_global(placement, sbp)\n    idx = _cpu_global_tensor(flow.LongTensor([2, 4])).to_global(\n        placement, broadcast_for_placement\n    )\n    _assert_tensor_equal(\n        test_case, reference[idx], flow.stack([reference[2], reference[4]])\n    )\n\n    # None indexing\n    _assert_tensor_equal(test_case, reference[2, None], reference[2].unsqueeze(0))\n    _assert_tensor_equal(\n        test_case, reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0)\n    )\n    _assert_tensor_equal(test_case, reference[2:4, None], reference[2:4].unsqueeze(1))\n    _assert_tensor_equal(\n        test_case,\n        reference[None, 2, None, None],\n        reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[None, 2:5, None, None],\n        reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2),\n    )\n\n    # indexing 0-length slice\n    _assert_tensor_equal(test_case, flow.empty(0, 8, 8), reference[slice(0)])\n    _assert_tensor_equal(test_case, flow.empty(0, 8), reference[slice(0), 2])\n    _assert_tensor_equal(test_case, flow.empty(0, 8), reference[2, slice(0)])\n    _assert_tensor_equal(test_case, flow.tensor([]), reference[2, 1:1, 2])\n\n    # indexing with step\n    sbp = random_sbp(placement, max_dim=3).value()\n    reference = global_broadcast_consec((8, 8, 8)).to_global(placement, sbp)\n    _assert_tensor_equal(\n        test_case, reference[1:5:2], flow.stack([reference[1], reference[3]], 0)\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[1:6:2],\n        flow.stack([reference[1], reference[3], reference[5]], 0),\n    )\n    _assert_tensor_equal(\n        test_case, reference[1:9:4], flow.stack([reference[1], reference[5]], 0)\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[2:4, 1:5:2],\n        flow.stack([reference[2:4, 1], reference[2:4, 3]], 1),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[3, 1:6:2],\n        flow.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[None, 
2, 1:9:4],\n        flow.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[:, 2, 1:6:2],\n        flow.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1),\n    )\n\n    # random check\n    lst = [\n        list(range(i, i + 16)) for i in range(0, 256, 16)\n    ]  # arange(256).reshape(16, 16)\n    tensor = _cpu_global_tensor(flow.DoubleTensor(lst))\n    for _ in range(5):\n        sbp = random_sbp(placement, max_dim=2).value()\n        cur_tensor = tensor.to_global(placement, sbp)\n\n        idx1_start = _randint(0, 16)\n        idx1_end = idx1_start + _randint(1, 16 - idx1_start + 1)\n        idx1_step = _randint(1, 14)\n        idx1 = slice(idx1_start, idx1_end, idx1_step)\n        if _randint(0, 2) == 0:\n            idx2_start = _randint(0, 16)\n            idx2_end = idx2_start + _randint(1, 16 - idx2_start + 1)\n            idx2_step = _randint(1, 14)\n            idx2 = slice(idx2_start, idx2_end, idx2_step)\n            lst_indexed = [l[idx2] for l in lst[idx1]]\n            tensor_indexed = cur_tensor[idx1, idx2]\n        else:\n            lst_indexed = lst[idx1]\n            tensor_indexed = cur_tensor[idx1]\n        _assert_tensor_equal(test_case, flow.DoubleTensor(lst_indexed), tensor_indexed)\n\n    # error check\n    sbp = random_sbp(placement, max_dim=3).value()\n    reference = global_broadcast_consec((8, 8, 8)).to_global(placement, sbp)\n    test_case.assertRaises(RuntimeError, lambda: reference[1:9:0])\n    test_case.assertRaises(RuntimeError, lambda: reference[1:9:-1])\n\n    test_case.assertRaises(IndexError, lambda: reference[1, 1, 1, 1])\n    test_case.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1])\n    test_case.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])\n\n    test_case.assertRaises(IndexError, lambda: reference[0.0])\n    test_case.assertRaises(RuntimeError, lambda: reference[0.0:2.0])\n    test_case.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0])\n    test_case.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0])\n    test_case.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0])\n    test_case.assertRaises(IndexError, lambda: reference[0.0, :, 0.0])\n\n\ndef _test_advanced_indexing(test_case, placement, dtype):\n    broadcast_for_placement = [flow.sbp.broadcast] * len(placement.ranks.shape)\n\n    # pick a random valid indexer type\n    def ri(indices):\n        choice = _randint(0, 3)\n        if choice == 0:\n            return _cpu_global_tensor(flow.LongTensor(indices)).to_global(\n                placement, broadcast_for_placement\n            )\n        elif choice == 1:\n            return list(indices)\n        else:\n            return tuple(indices)\n\n    def validate_indexing(x):\n        _assert_tensor_equal(test_case, x[[0]], global_broadcast_consec((1,)))\n        _assert_tensor_equal(test_case, x[ri([0]),], global_broadcast_consec((1,)))\n        _assert_tensor_equal(test_case, x[ri([3]),], global_broadcast_consec((1,), 4))\n        _assert_tensor_equal(test_case, x[[2, 3, 4]], global_broadcast_consec((3,), 3))\n        _assert_tensor_equal(\n            test_case, x[ri([2, 3, 4]),], global_broadcast_consec((3,), 3)\n        )\n        _assert_tensor_equal(\n            test_case, x[ri([0, 2, 4]),], flow.tensor([1, 3, 5], dtype=dtype),\n        )\n\n    def validate_setting(x):\n        x[[0]] = -2\n        _assert_tensor_equal(test_case, x[[0]], flow.tensor([-2], dtype=dtype))\n       
 x[[0]] = -1\n        _assert_tensor_equal(test_case, x[ri([0]),], flow.tensor([-1], dtype=dtype))\n        x[[2, 3, 4]] = 4\n        _assert_tensor_equal(\n            test_case, x[[2, 3, 4]], flow.tensor([4, 4, 4], dtype=dtype)\n        )\n        x[ri([2, 3, 4]),] = 3\n        _assert_tensor_equal(\n            test_case, x[ri([2, 3, 4]),], flow.tensor([3, 3, 3], dtype=dtype),\n        )\n        x[ri([0, 2, 4]),] = _cpu_global_tensor(flow.tensor([5, 4, 3], dtype=dtype))\n        _assert_tensor_equal(\n            test_case, x[ri([0, 2, 4]),], flow.tensor([5, 4, 3], dtype=dtype),\n        )\n\n    # 1d tensor and integer index setitem and getitem\n    sbp = random_sbp(placement, max_dim=1).value()\n    reference = global_broadcast_consec((8,)).to_global(placement, sbp)\n    validate_indexing(reference)\n    validate_setting(reference)\n\n    # reference is  1  2  3  4  5  6  7  8\n    #               9 10 11 12 13 14 15 16\n    #              17 18 19 20 21 22 23 24\n    #              25 26 27 28 29 30 31 32\n    #              33 34 35 36 37 38 39 40\n    #              41 42 43 44 45 46 47 48\n    #              49 50 51 52 53 54 55 56\n    #              57 58 59 60 61 62 63 64\n    sbp = random_sbp(placement, max_dim=2).value()\n    reference = global_broadcast_consec((8, 8)).to_global(placement, sbp)\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([0, 1, 2]), ri([0])],\n        flow.tensor([1, 9, 17], dtype=dtype),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([0, 1, 2]), ri([1])],\n        flow.tensor([2, 10, 18], dtype=dtype),\n    )\n    _assert_tensor_equal(\n        test_case, reference[ri([0]), ri([0])], global_broadcast_consec((1,))\n    )\n    _assert_tensor_equal(\n        test_case, reference[ri([2]), ri([1])], global_broadcast_consec((1,), 18)\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[[ri([0, 0]), ri([0, 1])]],\n        flow.tensor([1, 2], dtype=dtype),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[[ri([0, 1, 1, 0, 2, 7]), ri([1])]],\n        flow.tensor([2, 10, 10, 2, 18, 58], dtype=dtype),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],\n        flow.tensor([1, 2, 9, 9], dtype=dtype),\n    )\n\n    rows = ri([[0, 0], [1, 6]])\n    columns = ([0],)\n    _assert_tensor_equal(\n        test_case,\n        reference[rows, columns],\n        flow.tensor([[1, 1], [9, 49]], dtype=dtype),\n    )\n\n    rows = ri([[0, 0], [1, 6]])\n    columns = ri([6, 0])\n    _assert_tensor_equal(\n        test_case,\n        reference[rows, columns],\n        flow.tensor([[7, 1], [15, 49]], dtype=dtype),\n    )\n    rows = ri([[0, 0], [1, 2]])\n    columns = ri([[0, 1], [3, 7]])\n    _assert_tensor_equal(\n        test_case,\n        reference[rows, columns],\n        flow.tensor([[1, 2], [12, 24]], dtype=dtype),\n    )\n\n    # setting values\n    reference[ri([0]), ri([1])] = -1\n    _assert_tensor_equal(\n        test_case, reference[ri([0]), ri([1])], flow.tensor([-1], dtype=dtype),\n    )\n    reference[ri([0, 1, 2]), ri([0])] = _cpu_global_tensor(\n        flow.tensor([-1, 2, -4], dtype=dtype)\n    ).to_global(placement, broadcast_for_placement)\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([0, 1, 2]), ri([0])],\n        flow.tensor([-1, 2, -4], dtype=dtype),\n    )\n    reference[rows, columns] = _cpu_global_tensor(\n        flow.tensor([[4, 6], [2, 3]], dtype=dtype)\n    
).to_global(placement, broadcast_for_placement)\n    _assert_tensor_equal(\n        test_case, reference[rows, columns], flow.tensor([[4, 6], [2, 3]], dtype=dtype),\n    )\n\n    # Tests using less than the number of dims, and ellipsis\n    # reference is  1  2  3  4  5  6  7  8\n    #               9 10 11 12 13 14 15 16\n    #              17 18 19 20 21 22 23 24\n    #              25 26 27 28 29 30 31 32\n    #              33 34 35 36 37 38 39 40\n    #              41 42 43 44 45 46 47 48\n    #              49 50 51 52 53 54 55 56\n    #              57 58 59 60 61 62 63 64\n    sbp = random_sbp(placement, max_dim=2).value()\n    reference = global_broadcast_consec((8, 8)).to_global(placement, sbp)\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([0, 2]),],\n        flow.tensor(\n            [[1, 2, 3, 4, 5, 6, 7, 8], [17, 18, 19, 20, 21, 22, 23, 24]], dtype=dtype\n        ),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([1]), ...],\n        flow.tensor([[9, 10, 11, 12, 13, 14, 15, 16]], dtype=dtype),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[..., ri([1])],\n        flow.tensor([[2], [10], [18], [26], [34], [42], [50], [58]], dtype=dtype),\n    )\n\n    # verify too many indices fails\n    with test_case.assertRaises(IndexError):\n        reference[ri([1]), ri([0, 2]), ri([3])]\n\n    # test invalid index fails\n    sbp = random_sbp(placement, max_dim=1).value()\n    reference = _cpu_global_tensor(flow.empty(8, dtype=dtype)).to_global(placement, sbp)\n    for err_idx in (10, -11):\n        with test_case.assertRaisesRegex(IndexError, r\"out of bounds\"):\n            reference[err_idx]\n\n\ndef _test_combined_indexing(test_case, placement, dtype):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n\n    def tensor_indices_to_np(tensor, indices):\n        # convert the flow Tensor to a numpy array\n        npt = tensor.numpy()\n\n        # convert indices\n        idxs = tuple(\n            i.tolist() if isinstance(i, flow.LongTensor) else i for i in indices\n        )\n\n        return npt, idxs\n\n    def get_numpy(tensor, indices):\n        npt, idxs = tensor_indices_to_np(tensor, indices)\n\n        # index and return as a oneflow local Tensor\n        return flow.tensor(npt[idxs], dtype=dtype)\n\n    def set_numpy(tensor, indices, value):\n        if not isinstance(value, int):\n            value = value.numpy()\n\n        npt, idxs = tensor_indices_to_np(tensor, indices)\n        npt[idxs] = value\n        return npt\n\n    def assert_get_eq(tensor, indexer):\n        _assert_tensor_equal(test_case, tensor[indexer], get_numpy(tensor, indexer))\n\n    def assert_set_eq(tensor, indexer, val):\n        pyt = tensor.clone()\n        np_ref = tensor.clone()\n        pyt[indexer] = val\n        np_ref = flow.tensor(set_numpy(np_ref, indexer, val), dtype=dtype)\n        _assert_tensor_equal(test_case, pyt, np_ref)\n\n    def assert_backward_eq(tensor, indexer):\n        # compare gradient between cpu and cuda\n        cpu = (\n            tensor.float()\n            .clone()\n            .detach()\n            .to_global(placement, broadcast_for_placement)\n            .requires_grad_()\n        )\n        outcpu = cpu.clone()[indexer]\n        outcpu.sum().backward()\n        dev = (\n            cpu.detach()\n            .to_global(\n                placement, random_sbp(placement, max_dim=len(tensor.shape)).value()\n            )\n            .requires_grad_(True)\n      
  )\n        outdev = dev[indexer]\n        outdev.sum().backward()\n        _assert_tensor_equal(test_case, cpu.grad, dev.grad)\n\n    def get_set_tensor(indexed, indexer):\n        set_size = indexed[indexer].size()\n        set_count = indexed[indexer].numel()\n        set_tensor = _cpu_global_tensor(\n            flow.arange(set_count, 0, -1).view(set_size).to(dtype)\n        ).to_global(placement, broadcast_for_placement)\n        return set_tensor\n\n    # Tensor is  1  2  3  4  5  6  7  8\n    #            9  10 11 12 13 14 15 16\n    #            17 18 19 20 21 22 23 24\n    #            25 26 27 28 29 30 31 32\n    #            33 34 35 36 37 38 39 40\n    #            41 42 43 44 45 46 47 48\n    #            49 50 51 52 53 54 55 56\n    #            57 58 59 60 61 62 63 64\n    sbp = random_sbp(placement, max_dim=2).value()\n    reference = global_broadcast_consec((8, 8)).to_global(placement, sbp)\n\n    indices_to_test = [\n        # grab the fifth, seventh columns\n        [slice(None), [4, 6]],\n        # first, seventh rows\n        [[0, 6], slice(None)],\n        # TODO(wyg): only support getitem but not setitem\n        #  # weird shape\n        #  [slice(None), [[0, 1],\n        #                 [2, 3]]],\n        # negatives\n        [[-1], [0]],\n        [[0, 7], [-1]],\n        [slice(None), [-1]],\n    ]\n\n    # test getitem\n    get_indices_to_test = indices_to_test + [\n        [slice(None), [0, 1, 1, 2, 2]],\n        [slice(None), [[0, 1], [2, 3]]],  # TODO: test setitem\n    ]\n    for indexer in get_indices_to_test:\n        assert_get_eq(reference, indexer)\n        if placement.type != \"cpu\":\n            assert_backward_eq(reference, indexer)\n\n    # test setitem\n    for indexer in indices_to_test:\n        assert_set_eq(reference, indexer, 44)\n        assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))\n\n    #########################\n    # test more dims tensor #\n    #########################\n    sbp = random_sbp(placement, max_dim=3).value()\n    reference = global_broadcast_consec((8, 8, 8), 0).float().to_global(placement, sbp)\n\n    indices_to_test = [\n        [slice(None), slice(None), [0, 3, 4]],\n        [slice(None), [2, 4, 5, 7], slice(None)],\n        [[2, 3], slice(None), slice(None)],\n        [slice(None), [0, 2, 3], [1, 3, 4]],\n        [slice(None), [0], [1, 2, 4]],\n        [slice(None), [0, 1, 3], [4]],\n        [slice(None), [[0, 1], [1, 0]], [[2, 3]]],\n        [slice(None), [[0, 1], [2, 3]], [[0]]],\n        [slice(None), [[5, 6]], [[0, 3], [4, 4]]],\n        [[0, 2, 3], [1, 3, 4], slice(None)],\n        [[0], [1, 2, 4], slice(None)],\n        [[0, 1, 3], [4], slice(None)],\n        [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],\n        [[[0, 1], [1, 0]], [[2, 3]], slice(None)],\n        [[[0, 1], [2, 3]], [[0]], slice(None)],\n        [[[2, 1]], [[0, 3], [4, 4]], slice(None)],\n        [[[2]], [[0, 3], [4, 1]], slice(None)],\n        # non-contiguous indexing subspace\n        [[0, 2, 3], slice(None), [1, 3, 4]],\n        # less dim, ellipsis\n        [[0, 2],],\n        [[0, 2], slice(None)],\n        [[0, 2], Ellipsis],\n        [[0, 2], slice(None), Ellipsis],\n        [[0, 2], Ellipsis, slice(None)],\n        [[0, 2], [1, 3]],\n        [[0, 2], [1, 3], Ellipsis],\n        [Ellipsis, [1, 3], [2, 3]],\n        [Ellipsis, [2, 3, 4]],\n        [Ellipsis, slice(None), [2, 3, 4]],\n        [slice(None), Ellipsis, [2, 3, 4]],\n        # ellipsis counts for nothing\n        
[Ellipsis, slice(None), slice(None), [0, 3, 4]],\n        [slice(None), Ellipsis, slice(None), [0, 3, 4]],\n        [slice(None), slice(None), Ellipsis, [0, 3, 4]],\n        [slice(None), slice(None), [0, 3, 4], Ellipsis],\n        [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],\n        [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],\n        [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],\n    ]\n\n    for indexer in indices_to_test:\n        assert_get_eq(reference, indexer)\n        assert_set_eq(reference, indexer, 212)\n        assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))\n        if placement.type != \"cpu\":\n            assert_backward_eq(reference, indexer)\n\n    sbp = random_sbp(placement, max_dim=4).value()\n    reference = (\n        global_broadcast_consec((8, 8, 8, 8), 0).float().to_global(placement, sbp)\n    )\n\n    indices_to_test = [\n        [slice(None), slice(None), slice(None), [0, 3, 4]],\n        [slice(None), slice(None), [2, 4, 5, 7], slice(None)],\n        [slice(None), [2, 3], slice(None), slice(None)],\n        [[1, 2], slice(None), slice(None), slice(None)],\n        [slice(None), slice(None), [0, 2, 3], [1, 3, 4]],\n        [slice(None), slice(None), [0], [1, 2, 4]],\n        [slice(None), slice(None), [0, 1, 3], [4]],\n        [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],\n        [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],\n        [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],\n        [slice(None), [0, 2, 3], [1, 3, 4], slice(None)],\n        [slice(None), [0], [1, 2, 4], slice(None)],\n        [slice(None), [0, 1, 3], [4], slice(None)],\n        [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],\n        [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],\n        [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],\n        [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],\n        [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],\n        [[0, 1, 2], [1, 3, 4], slice(None), slice(None)],\n        [[0], [1, 2, 4], slice(None), slice(None)],\n        [[0, 1, 2], [4], slice(None), slice(None)],\n        [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],\n        [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],\n        [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],\n        [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],\n        [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],\n        [slice(None), [2, 3, 4], [1, 3, 4], [4]],\n        [slice(None), [0, 1, 3], [4], [1, 3, 4]],\n        [slice(None), [6], [0, 2, 3], [1, 3, 4]],\n        [slice(None), [2, 3, 5], [3], [4]],\n        [slice(None), [0], [4], [1, 3, 4]],\n        [slice(None), [6], [0, 2, 3], [1]],\n        [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],\n        [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],\n        [[2, 0, 1], [1, 2, 3], [4], slice(None)],\n        [[0, 1, 2], [4], [1, 3, 4], slice(None)],\n        [[0], [0, 2, 3], [1, 3, 4], slice(None)],\n        [[0, 2, 1], [3], [4], slice(None)],\n        [[0], [4], [1, 3, 4], slice(None)],\n        [[1], [0, 2, 3], [1], slice(None)],\n        [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],\n        # less dim, ellipsis\n        [Ellipsis, [0, 3, 4]],\n        [Ellipsis, slice(None), [0, 3, 4]],\n        [Ellipsis, slice(None), slice(None), [0, 3, 4]],\n        [slice(None), Ellipsis, [0, 3, 4]],\n        [slice(None), slice(None), Ellipsis, 
[0, 3, 4]],\n        [slice(None), [0, 2, 3], [1, 3, 4]],\n        [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],\n        [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],\n        [[0], [1, 2, 4]],\n        [[0], [1, 2, 4], slice(None)],\n        [[0], [1, 2, 4], Ellipsis],\n        [[0], [1, 2, 4], Ellipsis, slice(None)],\n        [[1],],\n        [[0, 2, 1], [3], [4]],\n        [[0, 2, 1], [3], [4], slice(None)],\n        [[0, 2, 1], [3], [4], Ellipsis],\n        [Ellipsis, [0, 2, 1], [3], [4]],\n    ]\n\n    for indexer in indices_to_test:\n        assert_get_eq(reference, indexer)\n        assert_set_eq(reference, indexer, 1333)\n        assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))\n    indices_to_test += [\n        [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],\n        [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],\n    ]\n    for indexer in indices_to_test:\n        assert_get_eq(reference, indexer)\n        assert_set_eq(reference, indexer, 1333)\n        if placement.type != \"cpu\":\n            assert_backward_eq(reference, indexer)\n\n\ndef _test_single_int(test_case, placement):\n    sbp = random_sbp(placement, max_dim=1).value()\n    v = _cpu_global_tensor(flow.zeros(8, 7, 3)).to_global(placement, sbp)\n    test_case.assertEqual(v[2].shape, (7, 3))\n    test_case.assertEqual(v[6].shape, (7, 3))\n\n\ndef _test_multiple_int(test_case, placement):\n    sbp = random_sbp(placement, max_dim=3).value()\n    v = _cpu_global_tensor(flow.zeros(8, 8, 8)).to_global(placement, sbp)\n    test_case.assertEqual(v[4, :, 1].shape, (8,))\n\n\ndef _test_none(test_case, placement):\n    sbp = random_sbp(placement, max_dim=3).value()\n    v = _cpu_global_tensor(flow.zeros(8, 8, 8)).to_global(placement, sbp)\n    test_case.assertEqual(v[None].shape, (1, 8, 8, 8))\n    test_case.assertEqual(v[:, None].shape, (8, 1, 8, 8))\n    test_case.assertEqual(v[:, None, None].shape, (8, 1, 1, 8, 8))\n    test_case.assertEqual(v[..., None].shape, (8, 8, 8, 1))\n\n\ndef _test_step(test_case, placement):\n    sbp = random_sbp(placement, max_dim=1).value()\n    v = _cpu_global_tensor(flow.arange(8)).to_global(placement, sbp)\n    _assert_tensor_equal(test_case, v[::1], v)\n    test_case.assertEqual(v[::2].tolist(), [0, 2, 4, 6])\n    test_case.assertEqual(v[::3].tolist(), [0, 3, 6])\n    test_case.assertEqual(v[::11].tolist(), [0])\n    test_case.assertEqual(v[1:6:2].tolist(), [1, 3, 5])\n\n\ndef _test_step_assignment(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    sbp = random_sbp(placement, max_dim=2).value()\n    v = _cpu_global_tensor(flow.zeros(8, 8)).to_global(placement, sbp)\n    v[0, 1::2] = _cpu_global_tensor(flow.tensor([3.0, 4.0, 5.0, 6.0])).to_global(\n        placement, broadcast_for_placement\n    )\n    test_case.assertEqual(v[0].tolist(), [0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0])\n    test_case.assertEqual(v[1:].sum(), 0)\n\n\ndef _test_bool_indices(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    sbp = random_sbp(placement, max_dim=3).value()\n    v = global_broadcast_consec((8, 8, 8)).to_global(placement, sbp)\n    boolIndices = _cpu_global_tensor(\n        flow.tensor(\n            [True, False, True, True, False, False, False, True], dtype=flow.bool\n        )\n    ).to_global(placement, broadcast_for_placement)\n    test_case.assertEqual(v[boolIndices].shape, (4, 8, 8))\n    _assert_tensor_equal(\n        test_case, 
v[boolIndices], flow.stack([v[0], v[2], v[3], v[7]])\n    )\n\n\ndef _test_multiple_bool_indices(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    sbp = random_sbp(placement, max_dim=2).value()\n    v = global_broadcast_consec((8, 8, 4)).to_global(placement, sbp)\n    # NOTE: these broadcast together and are transposed to the first dim\n    mask1 = _cpu_global_tensor(\n        flow.tensor([1, 0, 1, 0, 0, 1, 0, 0], dtype=flow.bool)\n    ).to_global(placement, broadcast_for_placement)\n    mask2 = _cpu_global_tensor(flow.tensor([1, 1, 1, 0], dtype=flow.bool)).to_global(\n        placement, broadcast_for_placement\n    )\n    test_case.assertEqual(v[mask1, :, mask2].shape, (3, 8))\n\n\ndef _test_int_indices(test_case, placement):\n    sbp = random_sbp(placement, max_dim=3).value()\n    v = global_broadcast_consec((8, 8, 8)).to_global(placement, sbp)\n    test_case.assertEqual(v[[0, 4, 2]].shape, (3, 8, 8))\n    test_case.assertEqual(v[:, [0, 4, 2]].shape, (8, 3, 8))\n    test_case.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (8, 2, 2, 8))\n\n\ndef _test_int_indices2d(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    sbp = random_sbp(placement, max_dim=2).value()\n    x = global_broadcast_consec((8, 8)).to_global(placement, sbp)\n    rows = _cpu_global_tensor(flow.tensor([[0, 0], [6, 3]])).to_global(\n        placement, broadcast_for_placement\n    )\n    columns = _cpu_global_tensor(flow.tensor([[0, 2], [0, 7]])).to_global(\n        placement, broadcast_for_placement\n    )\n    test_case.assertEqual(x[rows, columns].tolist(), [[1, 3], [49, 32]])\n\n\ndef _test_int_indices_broadcast(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    sbp = random_sbp(placement, max_dim=2).value()\n    x = global_broadcast_consec((8, 8)).to_global(placement, sbp)\n    rows = _cpu_global_tensor(flow.tensor([0, 7])).to_global(\n        placement, broadcast_for_placement\n    )\n    columns = _cpu_global_tensor(flow.tensor([7, 2])).to_global(\n        placement, broadcast_for_placement\n    )\n    result = x[rows[:, None], columns]\n    test_case.assertEqual(result.tolist(), [[8, 3], [64, 59]])\n\n\ndef _test_empty_index(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    # TODO:(wangyinggang): masked_fill support sbp:partial_sum\n    sbp = random_sbp(placement, max_dim=2, except_partial_sum=True).value()\n    x = global_broadcast_consec((8, 8)).to_global(placement, sbp)\n    idx = _cpu_global_tensor(flow.tensor([], dtype=flow.long)).to_global(\n        placement, broadcast_for_placement\n    )\n    test_case.assertEqual(x[idx].numel(), 0)\n\n    # empty assignment should have no effect but not throw an exception\n    y = x.clone()\n    y[idx] = -1\n    _assert_tensor_equal(test_case, x, y)\n\n    mask = _cpu_global_tensor(flow.zeros(8, 8).to(flow.bool)).to_global(\n        placement, broadcast_for_placement\n    )\n    y[mask] = -1\n    _assert_tensor_equal(test_case, x, y)\n\n\ndef _test_empty_ndim_index(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    sbp = random_sbp(placement, max_dim=1).value()\n    x = global_broadcast_consec((8,)).to_global(placement, sbp)\n    _assert_tensor_equal(\n        test_case,\n        x[\n            _cpu_global_tensor(flow.empty(0, 2, dtype=flow.int64)).to_global(\n                
placement, broadcast_for_placement\n            )\n        ],\n        flow.empty(0, 2),\n    )\n\n    sbp = random_sbp(placement, max_dim=1).value()\n    x = _cpu_global_tensor(flow.empty(8, 0)).to_global(placement, sbp)\n    test_case.assertEqual(x[[1, 2]].shape, (2, 0))\n    test_case.assertEqual(x[[], []].shape, (0,))\n    test_case.assertEqual(x[[[]]].shape, (0, 0))\n    test_case.assertEqual(x[[[[]]]].shape, (1, 0, 0))\n    test_case.assertEqual(x[[1], []].shape, (0,))\n    test_case.assertEqual(x[[], [2]].shape, (0,))\n    with test_case.assertRaisesRegex(IndexError, \"for dimension with size 0\"):\n        x[:, [0, 1]]\n\n\ndef _test_empty_ndim_index_bool(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    sbp = random_sbp(placement, max_dim=1).value()\n    x = global_broadcast_consec((8,)).to_global(placement, sbp)\n    test_case.assertRaises(\n        IndexError,\n        lambda: x[\n            _cpu_global_tensor(flow.empty(0, 2, dtype=flow.uint8)).to_global(\n                placement, broadcast_for_placement\n            )\n        ],\n    )\n\n\ndef _test_empty_slice(test_case, placement):\n    sbp = random_sbp(placement, max_dim=1).value()\n    x = global_broadcast_consec((8, 8, 8, 8)).to_global(placement, sbp)\n    y = x[:, :, :, 1]\n    z = y[:, 1:1, :]\n    test_case.assertEqual((8, 0, 8), z.shape)\n\n\ndef _test_index_getitem_copy_bools_slices(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    false = _cpu_global_tensor(flow.tensor(0, dtype=flow.uint8)).to_global(\n        placement, broadcast_for_placement\n    )\n\n    sbp = random_sbp(placement, max_dim=1).value()\n    tensor = global_broadcast_consec((8, 8)).to_global(placement, sbp)\n\n    _assert_tensor_equal(test_case, flow.empty(0, *tensor.shape), tensor[False])\n    _assert_tensor_equal(test_case, flow.empty(0, *tensor.shape), tensor[false])\n\n\ndef _test_setitem_scalars(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    zero = _cpu_global_tensor(flow.tensor(0, dtype=flow.int64)).to_global(\n        placement, broadcast_for_placement\n    )\n\n    # non-scalar indexed with scalars\n    a = global_broadcast_consec((8, 8)).to_global(\n        placement, random_sbp(placement, max_dim=2).value()\n    )\n    a_set_with_number = a.clone()\n    a_set_with_scalar = a.clone()\n    b = global_broadcast_consec((8,), 233).to_global(\n        placement, random_sbp(placement, max_dim=1).value()\n    )\n\n    a_set_with_number[0] = b\n    a_set_with_scalar[zero] = b\n    _assert_tensor_equal(test_case, a_set_with_number, a_set_with_scalar)\n    a[1, zero] = 7.7\n    value = a[1, 0].numpy()\n    test_case.assertEqual(np.array(7.7, dtype=value.dtype), value)\n\n    np_x = np.zeros((8, 8))\n    np_x[0, 6] = 1.0\n    x = _cpu_global_tensor(flow.tensor(np_x)).to_global(\n        placement, random_sbp(placement, max_dim=2).value()\n    )\n    x[0, 6] = 1.0\n    test_case.assertEqual(x.numpy().all(), np_x.all())\n\n    # scalar indexed with scalars\n    r = _cpu_global_tensor(flow.tensor(1.0)).to_global(\n        placement, random_sbp(placement, max_dim=0).value()\n    )\n    with test_case.assertRaises(IndexError):\n        r[:] = 8.8\n    with test_case.assertRaises(IndexError):\n        r[zero] = 8.8\n    r[...] 
= 9.9\n    test_case.assertEqual(r, 9.9)\n\n    # scalar indexed with oneflow.Size([1])\n    np_x = np.zeros((8, 8))\n    np_x[0, 6] = np.ones(1)\n    x = _cpu_global_tensor(flow.tensor(np_x)).to_global(\n        placement, random_sbp(placement, max_dim=2).value()\n    )\n    x[0, 0] = _cpu_global_tensor(flow.ones(1).to(flow.float64)).to_global(\n        placement, broadcast_for_placement\n    )\n    test_case.assertEqual(x.numpy().all(), np_x.all())\n\n\ndef _test_basic_advanced_combined(test_case, placement):\n    sbp = random_sbp(placement, max_dim=2).value()\n    x = global_broadcast_consec((8, 8)).to_global(placement, sbp)\n    _assert_tensor_equal(test_case, x[1:2, 3:5], x[1:2, [3, 4]])\n    test_case.assertEqual(x[1:2, 1:3].tolist(), [[10, 11]])\n\n    # Check that it is a copy\n    unmodified = x.clone()\n    x[1:2, [1, 2]].zero_()\n    _assert_tensor_equal(test_case, x, unmodified)\n\n    # But assignment should modify the original\n    unmodified = x.clone()\n    x[1:2, [1, 2]] = 0\n    test_case.assertFalse(np.array_equal(x.numpy(), unmodified.numpy()))\n\n\ndef _test_ellipsis_tensor(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    sbp = random_sbp(placement, max_dim=2).value()\n    x = global_broadcast_consec((8, 8)).to_global(placement, sbp)\n    idx = _cpu_global_tensor(flow.tensor([0, 7])).to_global(\n        placement, broadcast_for_placement\n    )\n    test_case.assertEqual(\n        x[..., idx].tolist(),\n        [[1, 8], [9, 16], [17, 24], [25, 32], [33, 40], [41, 48], [49, 56], [57, 64]],\n    )\n    test_case.assertEqual(\n        x[idx, ...].tolist(),\n        [[1, 2, 3, 4, 5, 6, 7, 8], [57, 58, 59, 60, 61, 62, 63, 64]],\n    )\n\n    # Test scalar ellipsis getitem\n    x_scalar = _cpu_global_tensor(flow.tensor(9.9)).to_global(\n        placement, broadcast_for_placement\n    )\n    test_case.assertEqual(x_scalar[...], 9.9)\n\n\nclass TestGlobalIndexing(flow.unittest.TestCase):\n    @globaltest\n    def test_global_slice(test_case):\n        for placement in all_placement():\n            for _ in range(5):\n                _test_basic_slice(test_case, placement)\n                _test_advanced_indexing(test_case, placement, dtype=flow.float32)\n                _test_combined_indexing(test_case, placement, dtype=flow.float32)\n                _test_single_int(test_case, placement)\n                _test_multiple_int(test_case, placement)\n                _test_none(test_case, placement)\n                _test_step(test_case, placement)\n                _test_step_assignment(test_case, placement)\n                _test_bool_indices(test_case, placement)\n                _test_multiple_bool_indices(test_case, placement)\n                _test_int_indices(test_case, placement)\n                _test_int_indices2d(test_case, placement)\n                _test_int_indices_broadcast(test_case, placement)\n                _test_empty_index(test_case, placement)\n                _test_empty_ndim_index(test_case, placement)\n                _test_empty_ndim_index_bool(test_case, placement)\n                _test_empty_slice(test_case, placement)\n                _test_index_getitem_copy_bools_slices(test_case, placement)\n                _test_setitem_scalars(test_case, placement)\n                _test_basic_advanced_combined(test_case, placement)\n                _test_ellipsis_tensor(test_case, placement)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_lazy_tensor_indexing.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow.test_utils.automated_test_util import *\nimport oneflow.unittest\nimport oneflow.framework.session_context as session_ctx\n\n\ndef get_graph_output(*args, func):\n    def generate_graph(func):\n        class Graph(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n\n            def build(self, *args):\n                return func(*args)\n\n        return Graph()\n\n    graph = generate_graph(func)\n    return graph(*args)\n\n\ndef setitem_and_return(ref, idx, value):\n    ref[idx] = value\n    return ref\n\n\ndef _randint(low, high):\n    \"\"\"\n    Get a random integer in the range [low, high).\n    \"\"\"\n    return random(low, high).to(int).value()\n\n\ndef _cpu_global_tensor(tensor):\n    return tensor.to_global(flow.placement.all(\"cpu\"), flow.sbp.broadcast)\n\n\ndef _assert_tensor_equal(test_case, tensor1, tensor2, atol=0.0, rtol=0.0):\n    test_case.assertTrue(\n        np.allclose(tensor1.numpy(), tensor2.numpy(), atol, rtol),\n        f\"{tensor1.numpy()} vs {tensor2.numpy()}\",\n    )\n\n\ndef global_broadcast_consec(size, start=1):\n    \"\"\"\n    Generate a arithmetic progression with given size and start value.\n    \"\"\"\n    sequence = flow.ones([int(np.array(size).prod(0)),]).cumsum(0)\n    sequence.add_(start - 1)\n    return _cpu_global_tensor(sequence.view(*size))\n\n\ndef _test_basic_slice(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n\n    ref_sbp = random_sbp(placement, max_dim=3).value()\n    reference = global_broadcast_consec((8, 8, 8)).to_global(placement, ref_sbp)\n\n    # empty tensor indexing\n    empty_index = _cpu_global_tensor(flow.LongTensor()).to_global(\n        placement, broadcast_for_placement\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[empty_index]),\n        flow.empty(0, 8, 8),\n        atol=0,\n        rtol=0,\n    )\n\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[1]),\n        global_broadcast_consec((8, 8), 65),\n        atol=0,\n        rtol=0,\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[0, 1]),\n        global_broadcast_consec((8,), 9),\n        atol=0,\n        rtol=0,\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[0:2]),\n        global_broadcast_consec((2, 8, 8)),\n        atol=0,\n        rtol=0,\n    )\n    test_case.assertEqual(\n        get_graph_output(reference, func=lambda x: x[2, 2, 2]).item(), 147\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[:]),\n        global_broadcast_consec((8, 8, 8)),\n        atol=0,\n        rtol=0,\n    )\n\n   
 # indexing with Ellipsis\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[..., 2, 2]),\n        flow.tensor([19, 83, 147, 211, 275, 339, 403, 467]),\n        atol=0,\n        rtol=0,\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[0, ..., 2]),\n        flow.tensor([3, 11, 19, 27, 35, 43, 51, 59]),\n        atol=0,\n        rtol=0,\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[0, 2, ...]),\n        reference[0, 2],\n        atol=0,\n        rtol=0,\n    )\n\n    reference_5d = global_broadcast_consec((8, 8, 8, 8, 8)).to_global(\n        placement, sbp=random_sbp(placement, max_dim=5).value()\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference_5d, func=lambda x: x[2, ..., 1, 0]),\n        get_graph_output(reference_5d, func=lambda x: x[2, :, :, 1, 0]),\n        atol=0,\n        rtol=0,\n    )\n\n    # LongTensor indexing\n    sbp = random_sbp(placement, max_dim=3).value()\n    reference = global_broadcast_consec((8, 8, 8)).to_global(placement, sbp)\n    idx = _cpu_global_tensor(flow.LongTensor([2, 4])).to_global(\n        placement, broadcast_for_placement\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, idx, func=lambda x, y: x[y]),\n        get_graph_output(reference, func=lambda x: flow.stack([x[2], x[4]])),\n    )\n\n    # None indexing\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[None, 2, None, None]),\n        reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0),\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[None, 2:5, None, None]),\n        reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2),\n    )\n\n    # indexing 0-length slice\n    _assert_tensor_equal(\n        test_case,\n        flow.empty(0, 8, 8),\n        get_graph_output(reference, func=lambda x: x[slice(0)]),\n    )\n    _assert_tensor_equal(\n        test_case,\n        flow.empty(0, 8),\n        get_graph_output(reference, func=lambda x: x[2, slice(0)]),\n    )\n    _assert_tensor_equal(\n        test_case,\n        flow.tensor([]),\n        get_graph_output(reference, func=lambda x: x[2, 1:1, 2]),\n    )\n\n    # indexing with step\n    sbp = random_sbp(placement, max_dim=3).value()\n    reference = global_broadcast_consec((8, 8, 8)).to_global(placement, sbp)\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[2:4, 1:5:2]),\n        get_graph_output(\n            reference, func=lambda x: flow.stack([x[2:4, 1], x[2:4, 3]], 1)\n        ),\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[3, 1:6:2]),\n        get_graph_output(\n            reference, func=lambda x: flow.stack([x[3, 1], x[3, 3], x[3, 5]], 0)\n        ),\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[None, 2, 1:9:4]),\n        get_graph_output(\n            reference, func=lambda x: flow.stack([x[2, 1], x[2, 5]], 0).unsqueeze(0)\n        ),\n    )\n\n\ndef _test_advanced_indexing(test_case, placement, dtype):\n    broadcast_for_placement = [flow.sbp.broadcast] * len(placement.ranks.shape)\n\n    # pick a random valid indexer type\n    def ri(indices):\n        choice = _randint(0, 3)\n        if choice == 0:\n          
  return flow.LongTensor(\n                indices, placement=flow.placement.all(\"cpu\"), sbp=flow.sbp.broadcast,\n            ).to_global(placement, broadcast_for_placement)\n        elif choice == 1:\n            return list(indices)\n        else:\n            return tuple(indices)\n\n    def validate_indexing(x):\n        _assert_tensor_equal(\n            test_case,\n            get_graph_output(x, func=lambda x: x[ri([3]),]),\n            global_broadcast_consec((1,), 4),\n        )\n        _assert_tensor_equal(\n            test_case,\n            get_graph_output(x, func=lambda x: x[ri([2, 3, 4]),]),\n            global_broadcast_consec((3,), 3),\n        )\n\n    def validate_setting(x):\n        #  x[[0]] = -2\n        x = get_graph_output(x, func=lambda x: setitem_and_return(x, [0], -2))\n        _assert_tensor_equal(test_case, x[0], flow.tensor([-2], dtype=dtype))\n        #  x[[0]] = -1\n        x = get_graph_output(x, func=lambda x: setitem_and_return(x, [0], -1))\n        _assert_tensor_equal(test_case, x[0], flow.tensor([-1], dtype=dtype))\n        #  x[[2, 3, 4]] = 4\n        x = get_graph_output(x, func=lambda x: setitem_and_return(x, [2, 3, 4], 4))\n        _assert_tensor_equal(\n            test_case, x[[2, 3, 4]], flow.tensor([4, 4, 4], dtype=dtype)\n        )\n        #  x[ri([2, 3, 4]),] = 3\n        x = get_graph_output(\n            x, func=lambda x: setitem_and_return(x, [ri([2, 3, 4]),], 3)\n        )\n        _assert_tensor_equal(\n            test_case, x[[2, 3, 4]], flow.tensor([3, 3, 3], dtype=dtype),\n        )\n        #  x[ri([0, 2, 4]),] = _cpu_global_tensor(flow.tensor([5, 4, 3], dtype=dtype))\n        value_tensor = _cpu_global_tensor(flow.tensor([5, 4, 3], dtype=dtype))\n        x = get_graph_output(\n            x, func=lambda x: setitem_and_return(x, [ri([0, 2, 4]),], value_tensor)\n        )\n        _assert_tensor_equal(\n            test_case, x[[0, 2, 4]], flow.tensor([5, 4, 3], dtype=dtype),\n        )\n\n    # 1d tensor and integer index setitem and getitem\n    sbp = random_sbp(placement, max_dim=1).value()\n    reference = global_broadcast_consec((8,)).to_global(placement, sbp)\n    validate_indexing(reference)\n    validate_setting(reference)\n\n    # reference is  1  2  3  4  5  6  7  8\n    #               9 10 11 12 13 14 15 16\n    #              17 18 19 20 21 22 23 24\n    #              25 26 27 28 29 30 31 32\n    #              33 34 35 36 37 38 39 40\n    #              41 42 43 44 45 46 47 48\n    #              49 50 51 52 53 54 55 56\n    #              57 58 59 60 61 62 63 64\n    sbp = random_sbp(placement, max_dim=2).value()\n    reference = global_broadcast_consec((8, 8)).to_global(placement, sbp)\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[ri([0, 1, 2]), ri([0])]),\n        flow.tensor([1, 9, 17], dtype=dtype),\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[ri([0, 1, 2]), ri([1])]),\n        flow.tensor([2, 10, 18], dtype=dtype),\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[ri([0]), ri([0])]),\n        global_broadcast_consec((1,)),\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[ri([2]), ri([1])]),\n        global_broadcast_consec((1,), 18),\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[ri([0, 0]), ri([0, 1])]),\n        
flow.tensor([1, 2], dtype=dtype),\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[ri([0, 1, 1, 0, 2, 7]), ri([1])]),\n        flow.tensor([2, 10, 10, 2, 18, 58], dtype=dtype),\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(\n            reference, func=lambda x: x[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]\n        ),\n        flow.tensor([1, 2, 9, 9], dtype=dtype),\n    )\n\n    rows = ri([[0, 0], [1, 6]])\n    columns = ([0],)\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[rows, columns]),\n        flow.tensor([[1, 1], [9, 49]], dtype=dtype),\n    )\n\n    rows = ri([[0, 0], [1, 6]])\n    columns = ri([6, 0])\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[rows, columns]),\n        flow.tensor([[7, 1], [15, 49]], dtype=dtype),\n    )\n    rows = ri([[0, 0], [1, 2]])\n    columns = ri([[0, 1], [3, 7]])\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[rows, columns]),\n        flow.tensor([[1, 2], [12, 24]], dtype=dtype),\n    )\n\n    # setting values\n    #  reference[ri([0]), ri([1])] = -1\n    reference = get_graph_output(\n        reference, func=lambda x: setitem_and_return(x, [ri([0]), ri([1])], -1)\n    )\n    _assert_tensor_equal(\n        test_case, reference[ri([0]), ri([1])], flow.tensor([-1], dtype=dtype),\n    )\n\n    value_tensor = _cpu_global_tensor(flow.tensor([-1, 2, -4], dtype=dtype)).to_global(\n        placement, broadcast_for_placement\n    )\n    reference = get_graph_output(\n        reference,\n        func=lambda x: setitem_and_return(x, [ri([0, 1, 2]), ri([0])], value_tensor),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([0, 1, 2]), ri([0])],\n        flow.tensor([-1, 2, -4], dtype=dtype),\n    )\n\n    value_tensor = _cpu_global_tensor(\n        flow.tensor([[4, 6], [2, 3]], dtype=dtype)\n    ).to_global(placement, broadcast_for_placement)\n    reference = get_graph_output(\n        reference, func=lambda x: setitem_and_return(x, [rows, columns], value_tensor)\n    )\n    _assert_tensor_equal(\n        test_case, reference[rows, columns], flow.tensor([[4, 6], [2, 3]], dtype=dtype),\n    )\n\n    # Tests using less than the number of dims, and ellipsis\n    # reference is  1  2  3  4  5  6  7  8\n    #               9 10 11 12 13 14 15 16\n    #              17 18 19 20 21 22 23 24\n    #              25 26 27 28 29 30 31 32\n    #              33 34 35 36 37 38 39 40\n    #              41 42 43 44 45 46 47 48\n    #              49 50 51 52 53 54 55 56\n    #              57 58 59 60 61 62 63 64\n    sbp = random_sbp(placement, max_dim=2).value()\n    reference = global_broadcast_consec((8, 8)).to_global(placement, sbp)\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[ri([0, 2]),]),\n        flow.tensor(\n            [[1, 2, 3, 4, 5, 6, 7, 8], [17, 18, 19, 20, 21, 22, 23, 24]], dtype=dtype\n        ),\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[ri([1]), ...]),\n        flow.tensor([[9, 10, 11, 12, 13, 14, 15, 16]], dtype=dtype),\n    )\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(reference, func=lambda x: x[..., ri([1])]),\n        flow.tensor([[2], [10], [18], [26], [34], [42], [50], [58]], dtype=dtype),\n    )\n\n\ndef 
_test_combined_indexing(test_case, placement, dtype):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n\n    def tensor_indices_to_np(tensor, indices):\n        # convert the flow Tensor to a numpy array\n        npt = tensor.numpy()\n\n        # convert indices\n        idxs = tuple(\n            i.tolist() if isinstance(i, flow.LongTensor) else i for i in indices\n        )\n\n        return npt, idxs\n\n    def get_numpy(tensor, indices):\n        npt, idxs = tensor_indices_to_np(tensor, indices)\n\n        # index and return as a oneflow local Tensor\n        return flow.tensor(npt[idxs], dtype=dtype)\n\n    def set_numpy(tensor, indices, value):\n        if not isinstance(value, int):\n            value = value.numpy()\n\n        npt, idxs = tensor_indices_to_np(tensor, indices)\n        npt[idxs] = value\n        return npt\n\n    def assert_get_eq(tensor, indexer):\n        _assert_tensor_equal(\n            test_case,\n            get_graph_output(tensor, func=lambda x: x[indexer]),\n            get_numpy(tensor, indexer),\n        )\n\n    def assert_set_eq(tensor, indexer, val):\n        pyt = tensor.clone()\n        np_ref = tensor.clone()\n        pyt = get_graph_output(pyt, func=lambda x: setitem_and_return(x, indexer, val))\n        np_ref = flow.tensor(set_numpy(np_ref, indexer, val), dtype=dtype)\n        _assert_tensor_equal(test_case, pyt, np_ref)\n\n    def get_set_tensor(indexed, indexer):\n        set_size = indexed[indexer].size()\n        set_count = indexed[indexer].numel()\n        set_tensor = _cpu_global_tensor(\n            flow.arange(set_count, 0, -1).view(set_size).to(dtype)\n        ).to_global(placement, broadcast_for_placement)\n        return set_tensor\n\n    # Tensor is  1  2  3  4  5  6  7  8\n    #            9  10 11 12 13 14 15 16\n    #            17 18 19 20 21 22 23 24\n    #            25 26 27 28 29 30 31 32\n    #            33 34 35 36 37 38 39 40\n    #            41 42 43 44 45 46 47 48\n    #            49 50 51 52 53 54 55 56\n    #            57 58 59 60 61 62 63 64\n    sbp = random_sbp(placement, max_dim=2).value()\n    reference = global_broadcast_consec((8, 8)).to_global(placement, sbp)\n\n    indices_to_test = [\n        # grab the fifth, seventh columns\n        [slice(None), [4, 6]],\n        # first, seventh rows\n        [[0, 6], slice(None)],\n        # TODO(wyg): only support getitem but not setitem\n        #  # weird shape\n        #  [slice(None), [[0, 1],\n        #                 [2, 3]]],\n        # negatives\n        [[-1], [0]],\n        [[0, 7], [-1]],\n        [slice(None), [-1]],\n    ]\n\n    # test getitem\n    get_indices_to_test = indices_to_test + [\n        [slice(None), [0, 1, 1, 2, 2]],\n        [slice(None), [[0, 1], [2, 3]]],  # TODO: test setitem\n    ]\n    for indexer in get_indices_to_test:\n        assert_get_eq(reference, indexer)\n\n    # test setitem\n    for indexer in indices_to_test:\n        assert_set_eq(reference, indexer, 44)\n        assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))\n\n    #########################\n    # test more dims tensor #\n    #########################\n    sbp = random_sbp(placement, max_dim=3).value()\n    reference = global_broadcast_consec((8, 8, 8), 0).float().to_global(placement, sbp)\n\n    indices_to_test = [\n        [slice(None), slice(None), [0, 3, 4]],\n        [slice(None), [2, 4, 5, 7], slice(None)],\n        [[2, 3], slice(None), slice(None)],\n        
[slice(None), [0, 2, 3], [1, 3, 4]],\n        [slice(None), [0], [1, 2, 4]],\n        [slice(None), [0, 1, 3], [4]],\n        [slice(None), [[0, 1], [1, 0]], [[2, 3]]],\n        [slice(None), [[0, 1], [2, 3]], [[0]]],\n        [slice(None), [[5, 6]], [[0, 3], [4, 4]]],\n        [[0, 2, 3], [1, 3, 4], slice(None)],\n        [[0], [1, 2, 4], slice(None)],\n        [[0, 1, 3], [4], slice(None)],\n        [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],\n        [[[0, 1], [1, 0]], [[2, 3]], slice(None)],\n        [[[0, 1], [2, 3]], [[0]], slice(None)],\n        [[[2, 1]], [[0, 3], [4, 4]], slice(None)],\n        [[[2]], [[0, 3], [4, 1]], slice(None)],\n        # non-contiguous indexing subspace\n        [[0, 2, 3], slice(None), [1, 3, 4]],\n        # less dim, ellipsis\n        [[0, 2],],\n        [[0, 2], slice(None)],\n        [[0, 2], Ellipsis],\n        [[0, 2], slice(None), Ellipsis],\n        [[0, 2], Ellipsis, slice(None)],\n        [[0, 2], [1, 3]],\n        [[0, 2], [1, 3], Ellipsis],\n        [Ellipsis, [1, 3], [2, 3]],\n        [Ellipsis, [2, 3, 4]],\n        [Ellipsis, slice(None), [2, 3, 4]],\n        [slice(None), Ellipsis, [2, 3, 4]],\n        # ellipsis counts for nothing\n        [Ellipsis, slice(None), slice(None), [0, 3, 4]],\n        [slice(None), Ellipsis, slice(None), [0, 3, 4]],\n        [slice(None), slice(None), Ellipsis, [0, 3, 4]],\n        [slice(None), slice(None), [0, 3, 4], Ellipsis],\n        [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],\n        [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],\n        [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],\n    ]\n\n    for indexer in indices_to_test:\n        assert_get_eq(reference, indexer)\n        assert_set_eq(reference, indexer, 212)\n        assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))\n\n    sbp = random_sbp(placement, max_dim=4).value()\n    reference = (\n        global_broadcast_consec((8, 8, 8, 8), 0).float().to_global(placement, sbp)\n    )\n\n    indices_to_test = [\n        [slice(None), slice(None), slice(None), [0, 3, 4]],\n        [slice(None), slice(None), [2, 4, 5, 7], slice(None)],\n        [slice(None), [2, 3], slice(None), slice(None)],\n        [[1, 2], slice(None), slice(None), slice(None)],\n        [slice(None), slice(None), [0, 2, 3], [1, 3, 4]],\n        [slice(None), slice(None), [0], [1, 2, 4]],\n        [slice(None), slice(None), [0, 1, 3], [4]],\n        [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],\n        [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],\n        [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],\n        [slice(None), [0, 2, 3], [1, 3, 4], slice(None)],\n        [slice(None), [0], [1, 2, 4], slice(None)],\n        [slice(None), [0, 1, 3], [4], slice(None)],\n        [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],\n        [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],\n        [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],\n        [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],\n        [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],\n        [[0, 1, 2], [1, 3, 4], slice(None), slice(None)],\n        [[0], [1, 2, 4], slice(None), slice(None)],\n        [[0, 1, 2], [4], slice(None), slice(None)],\n        [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],\n        [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],\n        [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],\n        [[[2]], [[0, 3], [4, 5]], 
slice(None), slice(None)],\n        [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],\n        [slice(None), [2, 3, 4], [1, 3, 4], [4]],\n        [slice(None), [0, 1, 3], [4], [1, 3, 4]],\n        [slice(None), [6], [0, 2, 3], [1, 3, 4]],\n        [slice(None), [2, 3, 5], [3], [4]],\n        [slice(None), [0], [4], [1, 3, 4]],\n        [slice(None), [6], [0, 2, 3], [1]],\n        [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],\n        [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],\n        [[2, 0, 1], [1, 2, 3], [4], slice(None)],\n        [[0, 1, 2], [4], [1, 3, 4], slice(None)],\n        [[0], [0, 2, 3], [1, 3, 4], slice(None)],\n        [[0, 2, 1], [3], [4], slice(None)],\n        [[0], [4], [1, 3, 4], slice(None)],\n        [[1], [0, 2, 3], [1], slice(None)],\n        [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],\n        # less dim, ellipsis\n        [Ellipsis, [0, 3, 4]],\n        [Ellipsis, slice(None), [0, 3, 4]],\n        [Ellipsis, slice(None), slice(None), [0, 3, 4]],\n        [slice(None), Ellipsis, [0, 3, 4]],\n        [slice(None), slice(None), Ellipsis, [0, 3, 4]],\n        [slice(None), [0, 2, 3], [1, 3, 4]],\n        [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],\n        [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],\n        [[0], [1, 2, 4]],\n        [[0], [1, 2, 4], slice(None)],\n        [[0], [1, 2, 4], Ellipsis],\n        [[0], [1, 2, 4], Ellipsis, slice(None)],\n        [[1],],\n        [[0, 2, 1], [3], [4]],\n        [[0, 2, 1], [3], [4], slice(None)],\n        [[0, 2, 1], [3], [4], Ellipsis],\n        [Ellipsis, [0, 2, 1], [3], [4]],\n    ]\n\n    for indexer in indices_to_test:\n        assert_get_eq(reference, indexer)\n        assert_set_eq(reference, indexer, 1333)\n        assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))\n    indices_to_test += [\n        [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],\n        [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],\n    ]\n    for indexer in indices_to_test:\n        assert_get_eq(reference, indexer)\n        assert_set_eq(reference, indexer, 1333)\n\n\ndef _test_single_int(test_case, placement):\n    sbp = random_sbp(placement, max_dim=1).value()\n    v = _cpu_global_tensor(flow.zeros(8, 7, 3)).to_global(placement, sbp)\n    test_case.assertEqual(get_graph_output(v, func=lambda x: x[2]).shape, (7, 3))\n\n\ndef _test_multiple_int(test_case, placement):\n    sbp = random_sbp(placement, max_dim=3).value()\n    v = _cpu_global_tensor(flow.zeros(8, 8, 8)).to_global(placement, sbp)\n    test_case.assertEqual(get_graph_output(v, func=lambda x: x[4, :, 1]).shape, (8,))\n\n\ndef _test_none(test_case, placement):\n    sbp = random_sbp(placement, max_dim=3).value()\n    v = _cpu_global_tensor(flow.zeros(8, 8, 8)).to_global(placement, sbp)\n    test_case.assertEqual(\n        get_graph_output(v, func=lambda x: x[None]).shape, (1, 8, 8, 8)\n    )\n    test_case.assertEqual(\n        get_graph_output(v, func=lambda x: x[:, None]).shape, (8, 1, 8, 8)\n    )\n    test_case.assertEqual(\n        get_graph_output(v, func=lambda x: x[:, None, None]).shape, (8, 1, 1, 8, 8)\n    )\n    test_case.assertEqual(\n        get_graph_output(v, func=lambda x: x[..., None]).shape, (8, 8, 8, 1)\n    )\n\n\ndef _test_step(test_case, placement):\n    sbp = random_sbp(placement, max_dim=1).value()\n    v = _cpu_global_tensor(flow.arange(8)).to_global(placement, sbp)\n    _assert_tensor_equal(test_case, v[::1], v)\n    test_case.assertEqual(\n        
get_graph_output(v, func=lambda x: x[::2]).tolist(), [0, 2, 4, 6]\n    )\n    test_case.assertEqual(\n        get_graph_output(v, func=lambda x: x[::3]).tolist(), [0, 3, 6]\n    )\n    test_case.assertEqual(get_graph_output(v, func=lambda x: x[::11]).tolist(), [0])\n    test_case.assertEqual(\n        get_graph_output(v, func=lambda x: x[1:6:2]).tolist(), [1, 3, 5]\n    )\n\n\ndef _test_step_assignment(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    sbp = random_sbp(placement, max_dim=2).value()\n    v = _cpu_global_tensor(flow.zeros(8, 8)).to_global(placement, sbp)\n    value_tensor = _cpu_global_tensor(flow.tensor([3.0, 4.0, 5.0, 6.0])).to_global(\n        placement, broadcast_for_placement\n    )\n    v = get_graph_output(\n        v, func=lambda x: setitem_and_return(x, [0, slice(1, None, 2)], value_tensor)\n    )\n    test_case.assertEqual(v[0].tolist(), [0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0])\n    test_case.assertEqual(v[1:].sum(), 0)\n\n\ndef _test_multiple_bool_indices(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    sbp = random_sbp(placement, max_dim=2).value()\n    v = global_broadcast_consec((8, 8, 4)).to_global(placement, sbp)\n    # NOTE: these broadcast together and are transposed to the first dim\n    mask1 = _cpu_global_tensor(\n        flow.tensor([1, 0, 1, 0, 0, 1, 0, 0], dtype=flow.bool)\n    ).to_global(placement, broadcast_for_placement)\n    mask2 = _cpu_global_tensor(flow.tensor([1, 1, 1, 0], dtype=flow.bool)).to_global(\n        placement, broadcast_for_placement\n    )\n    test_case.assertEqual(v[mask1, :, mask2].shape, (3, 8))\n\n\ndef _test_int_indices(test_case, placement):\n    sbp = random_sbp(placement, max_dim=3).value()\n    v = global_broadcast_consec((8, 8, 8)).to_global(placement, sbp)\n    test_case.assertEqual(\n        get_graph_output(v, func=lambda x: x[[0, 4, 2]]).shape, (3, 8, 8)\n    )\n    test_case.assertEqual(\n        get_graph_output(v, func=lambda x: x[:, [0, 4, 2]]).shape, (8, 3, 8)\n    )\n    test_case.assertEqual(\n        get_graph_output(v, func=lambda x: x[:, [[0, 1], [4, 3]]]).shape, (8, 2, 2, 8)\n    )\n\n\ndef _test_int_indices2d(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    sbp = random_sbp(placement, max_dim=2).value()\n    x = global_broadcast_consec((8, 8)).to_global(placement, sbp)\n    rows = _cpu_global_tensor(flow.tensor([[0, 0], [6, 3]])).to_global(\n        placement, broadcast_for_placement\n    )\n    columns = _cpu_global_tensor(flow.tensor([[0, 2], [0, 7]])).to_global(\n        placement, broadcast_for_placement\n    )\n    test_case.assertEqual(\n        get_graph_output(x, func=lambda x: x[rows, columns]).tolist(),\n        [[1, 3], [49, 32]],\n    )\n\n\ndef _test_int_indices_broadcast(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    sbp = random_sbp(placement, max_dim=2).value()\n    x = global_broadcast_consec((8, 8)).to_global(placement, sbp)\n    rows = _cpu_global_tensor(flow.tensor([0, 7])).to_global(\n        placement, broadcast_for_placement\n    )\n    columns = _cpu_global_tensor(flow.tensor([7, 2])).to_global(\n        placement, broadcast_for_placement\n    )\n    result = get_graph_output(x, func=lambda x: x[rows[:, None], columns])\n    test_case.assertEqual(result.tolist(), [[8, 3], [64, 59]])\n\n\ndef _test_empty_index(test_case, placement):\n    
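# an empty index tensor selects zero elements: getitem returns an empty tensor,\n    # and setitem through it must leave the input unchanged\n    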
broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    sbp = random_sbp(placement, max_dim=2).value()\n    x = global_broadcast_consec((8, 8)).to_global(placement, sbp)\n    idx = _cpu_global_tensor(flow.tensor([], dtype=flow.long)).to_global(\n        placement, broadcast_for_placement\n    )\n    test_case.assertEqual(get_graph_output(x, func=lambda x: x[idx]).numel(), 0)\n\n    # empty assignment should have no effect but not throw an exception\n    y = x.clone()\n    y = get_graph_output(y, func=lambda x: setitem_and_return(x, idx, -1))\n    _assert_tensor_equal(test_case, x, y)\n\n    # TODO(wyg): support eager bool indices tensor in lazy mode\n    #  mask = _cpu_global_tensor(flow.zeros(8, 8).to(flow.bool)).to_global(\n    #      placement, broadcast_for_placement\n    #  )\n    #  y = get_graph_output(y, func=lambda x: setitem_and_return(x, mask, -1))\n    #  _assert_tensor_equal(test_case, x, y)\n\n\ndef _test_empty_ndim_index(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    sbp = random_sbp(placement, max_dim=1).value()\n    x = global_broadcast_consec((8,)).to_global(placement, sbp)\n    index = _cpu_global_tensor(flow.empty(0, 2, dtype=flow.int64)).to_global(\n        placement, broadcast_for_placement\n    )\n    _assert_tensor_equal(\n        test_case, get_graph_output(x, func=lambda x: x[index]), flow.empty(0, 2),\n    )\n\n    sbp = random_sbp(placement, max_dim=1).value()\n    x = _cpu_global_tensor(flow.empty(8, 0)).to_global(placement, sbp)\n    test_case.assertEqual(get_graph_output(x, func=lambda x: x[[1, 2]]).shape, (2, 0))\n    test_case.assertEqual(get_graph_output(x, func=lambda x: x[[], []]).shape, (0,))\n    test_case.assertEqual(get_graph_output(x, func=lambda x: x[[[]]]).shape, (0, 0))\n    test_case.assertEqual(\n        get_graph_output(x, func=lambda x: x[[[[]]]]).shape, (1, 0, 0)\n    )\n    test_case.assertEqual(get_graph_output(x, func=lambda x: x[[1], []]).shape, (0,))\n    test_case.assertEqual(get_graph_output(x, func=lambda x: x[[], [2]]).shape, (0,))\n\n\ndef _test_empty_slice(test_case, placement):\n    sbp = random_sbp(placement, max_dim=1).value()\n    x = global_broadcast_consec((8, 8, 8, 8)).to_global(placement, sbp)\n    y = get_graph_output(x, func=lambda x: x[:, 1:1, :, 1])\n    test_case.assertEqual((8, 0, 8), y.shape)\n\n\ndef _test_setitem_scalars(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    zero = _cpu_global_tensor(flow.tensor(0, dtype=flow.int64)).to_global(\n        placement, broadcast_for_placement\n    )\n\n    # non-scalar indexed with scalars\n    a = global_broadcast_consec((8, 8)).to_global(\n        placement, random_sbp(placement, max_dim=2).value()\n    )\n    a_set_with_number = a.clone()\n    a_set_with_scalar = a.clone()\n    b = global_broadcast_consec((8,), 233).to_global(\n        placement, random_sbp(placement, max_dim=1).value()\n    )\n\n    a_set_with_number = get_graph_output(\n        a_set_with_number, func=lambda x: setitem_and_return(x, 0, b)\n    )\n    a_set_with_scalar = get_graph_output(\n        a_set_with_scalar, func=lambda x: setitem_and_return(x, zero, b)\n    )\n    _assert_tensor_equal(test_case, a_set_with_number, a_set_with_scalar)\n\n    #  a[1, zero] = 7.7\n    value = get_graph_output(\n        a, func=lambda x: setitem_and_return(x, [1, zero], 7.7)\n    ).numpy()\n    test_case.assertEqual(np.array(7.7, dtype=value.dtype), value[1, 0])\n\n    
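# setitem with literal int indices mirrors the same update applied to np_x\n    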
np_x = np.zeros((8, 8))\n    np_x[0, 6] = 1.0\n    x = _cpu_global_tensor(flow.tensor(np_x)).to_global(\n        placement, random_sbp(placement, max_dim=2).value()\n    )\n    #  x[0, 6] = 1.0\n    res = get_graph_output(x, func=lambda x: setitem_and_return(x, [0, 6], 1.0))\n    test_case.assertEqual(res.numpy().all(), np_x.all())\n\n    # scalar indexed with scalars\n    r = _cpu_global_tensor(flow.tensor(1.0)).to_global(\n        placement, random_sbp(placement, max_dim=0).value()\n    )\n    #  r[...] = 9.9\n    res = get_graph_output(r, func=lambda x: setitem_and_return(x, [...], 9.9))\n    test_case.assertEqual(res, 9.9)\n\n    # scalar indexed with oneflow.Size([1])\n    np_x = np.zeros((8, 8))\n    np_x[0, 6] = np.ones(1)\n    x = _cpu_global_tensor(flow.tensor(np_x)).to_global(\n        placement, random_sbp(placement, max_dim=2).value()\n    )\n    value_tensor = _cpu_global_tensor(flow.ones(1).to(flow.float64)).to_global(\n        placement, broadcast_for_placement\n    )\n    # x[0, 0] = value\n    res = get_graph_output(\n        x, func=lambda x: setitem_and_return(x, [0, 0], value_tensor)\n    )\n    test_case.assertEqual(res.numpy().all(), np_x.all())\n\n\ndef _test_ellipsis_tensor(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    sbp = random_sbp(placement, max_dim=2).value()\n    x = global_broadcast_consec((8, 8)).to_global(placement, sbp)\n    idx = _cpu_global_tensor(flow.tensor([0, 7])).to_global(\n        placement, broadcast_for_placement\n    )\n    test_case.assertEqual(\n        get_graph_output(x, func=lambda x: x[..., idx]).tolist(),\n        [[1, 8], [9, 16], [17, 24], [25, 32], [33, 40], [41, 48], [49, 56], [57, 64]],\n    )\n    test_case.assertEqual(\n        get_graph_output(x, func=lambda x: x[idx, ...]).tolist(),\n        [[1, 2, 3, 4, 5, 6, 7, 8], [57, 58, 59, 60, 61, 62, 63, 64]],\n    )\n\n    # Test scalar ellipsis getitem\n    x_scalar = _cpu_global_tensor(flow.tensor(9.9)).to_global(\n        placement, broadcast_for_placement\n    )\n    test_case.assertEqual(get_graph_output(x_scalar, func=lambda x: x[...]), 9.9)\n\n\ndef _test_bool_indices(test_case, placement):\n    broadcast_for_placement = [flow.sbp.broadcast,] * len(placement.ranks.shape)\n    sbp = random_sbp(placement, max_dim=1, except_partial_sum=True).value()\n    v = global_broadcast_consec((8,)).to_global(placement, sbp)\n    boolIndices = _cpu_global_tensor(\n        flow.tensor(\n            [True, False, True, True, False, False, False, True], dtype=flow.bool\n        )\n    ).to_global(placement, sbp)\n    _assert_tensor_equal(\n        test_case,\n        get_graph_output(v, func=lambda x: setitem_and_return(x, boolIndices, 6.6)),\n        flow.tensor([6.6, 2.0, 6.6, 6.6, 5.0, 6.0, 7.0, 6.6]),\n    )\n\n\nclass TestGlobalIndexing(flow.unittest.TestCase):\n    @globaltest\n    @unittest.skip(\n        \"TODO(wyg, zwx): test these cases after supporting clear session interface to avoid \"\n        \"getting 'stream_id.h:33 Check failed: stream_index <= kMaxStreamIndex (4096 vs. 
4095)' error\"\n    )\n    def test_global_slice(test_case):\n        for placement in all_placement():\n            for _ in range(5):\n                _test_basic_slice(test_case, placement)\n                _test_advanced_indexing(test_case, placement, dtype=flow.float32)\n                _test_combined_indexing(test_case, placement, dtype=flow.float32)\n                _test_single_int(test_case, placement)\n                _test_multiple_int(test_case, placement)\n                _test_none(test_case, placement)\n                _test_step(test_case, placement)\n                _test_step_assignment(test_case, placement)\n                _test_int_indices(test_case, placement)\n                _test_int_indices2d(test_case, placement)\n                _test_int_indices_broadcast(test_case, placement)\n                _test_empty_index(test_case, placement)\n                _test_empty_ndim_index(test_case, placement)\n                _test_empty_slice(test_case, placement)\n                _test_ellipsis_tensor(test_case, placement)\n                # TODO: cpu variables don't support common net\n                if placement.type != \"cpu\":\n                    _test_setitem_scalars(test_case, placement)\n\n    @globaltest\n    def test_bool_indices(test_case):\n        for placement in all_placement():\n            for _ in range(2):\n                _test_bool_indices(test_case, placement)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_meta_tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow import nn\n\n\nclass CustomModule(nn.Module):\n    def __init__(self, foo, bar, device=None):\n        super().__init__()\n        # ==== Case 1: Module creates parameters directly. ====\n        self.param1 = nn.Parameter(flow.empty((foo, bar), device=device))\n        self.register_parameter(\"param2\", nn.Parameter(flow.empty(bar, device=device)))\n        with flow.no_grad():\n            nn.init.kaiming_uniform_(self.param1)\n            nn.init.uniform_(self.param2)\n        # ==== Case 2: Module creates submodules. ====\n        self.fc = nn.Linear(bar, 5, device=device)\n        self.linears = nn.Sequential(\n            nn.Linear(5, 5, device=device), nn.Linear(5, 1, device=device)\n        )\n        # ==== Case 3: Module creates buffers. ====\n        self.register_buffer(\"some_buffer\", flow.ones(7, device=device))\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestMetaTensor(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_meta_tensor_local_mode_without_data(test_case):\n        x = flow.Tensor(3, 2, device=\"meta\")\n        y = flow.Tensor(3, 2, device=\"cpu\")\n        test_case.assertEqual(x.dtype, y.dtype)\n        test_case.assertEqual(x.shape, y.shape)\n        test_case.assertEqual(x.device, flow.device(\"meta\"))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_meta_tensor_local_mode_with_data(test_case):\n        x = flow.Tensor([3, 2], device=\"meta\")\n        y = flow.Tensor([3, 2], device=\"cpu\")\n        test_case.assertEqual(x.dtype, y.dtype)\n        test_case.assertEqual(x.shape, y.shape)\n        test_case.assertEqual(x.device, flow.device(\"meta\"))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_meta_tensor_func_local_mode_without_data(test_case):\n        x = flow.tensor([3, 2], device=\"meta\")\n        y = flow.tensor([3, 2], device=\"cpu\")\n        test_case.assertEqual(x.dtype, y.dtype)\n        test_case.assertEqual(x.shape, y.shape)\n        test_case.assertEqual(x.device, flow.device(\"meta\"))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_meta_tensor_func_local_mode_with_data(test_case):\n        x = flow.tensor([3, 2], device=\"meta\")\n        y = flow.tensor([3, 2], device=\"cpu\")\n        test_case.assertEqual(x.dtype, y.dtype)\n        test_case.assertEqual(x.shape, y.shape)\n        test_case.assertEqual(x.device, flow.device(\"meta\"))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_meta_tensor_local_mode_ones(test_case):\n        x = flow.ones(3, 2, device=\"meta\")\n        y = flow.ones([3, 2], device=\"cpu\")\n        test_case.assertEqual(x.dtype, y.dtype)\n        test_case.assertEqual(x.shape, y.shape)\n        test_case.assertEqual(x.device, flow.device(\"meta\"))\n\n    @flow.unittest.skip_unless_1n1d()\n    def 
test_meta_tensor_local_mode_linear(test_case):\n        x = flow.nn.Linear(3, 2, device=\"meta\")\n        y = flow.nn.Linear(3, 2, device=\"cpu\")\n        test_case.assertEqual(x.weight.dtype, y.weight.dtype)\n        test_case.assertEqual(x.weight.shape, y.weight.shape)\n        test_case.assertEqual(x.weight.requires_grad, y.weight.requires_grad)\n        test_case.assertEqual(x.weight.device, flow.device(\"meta\"))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_skip_init_function(test_case):\n        x = flow.nn.utils.skip_init(flow.nn.Linear, 4, 3)\n        y = flow.nn.Linear(4, 3, device=\"cpu\")\n        test_case.assertEqual(x.weight.dtype, y.weight.dtype)\n        test_case.assertEqual(x.weight.shape, y.weight.shape)\n        test_case.assertEqual(x.weight.requires_grad, y.weight.requires_grad)\n        test_case.assertEqual(x.weight.device, flow.device(\"cpu\"))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_skip_init_function_custom_module(test_case):\n        x = flow.nn.utils.skip_init(CustomModule, 4, 3)\n        y = CustomModule(4, 3, device=\"cpu\")\n        test_case.assertEqual(x.param1.dtype, y.param1.dtype)\n        test_case.assertEqual(x.param1.shape, y.param1.shape)\n        test_case.assertEqual(x.param1.requires_grad, y.param1.requires_grad)\n        test_case.assertEqual(x.param1.device, flow.device(\"cpu\"))\n        test_case.assertEqual(x.param2.dtype, y.param2.dtype)\n        test_case.assertEqual(x.param2.shape, y.param2.shape)\n        test_case.assertEqual(x.param2.requires_grad, y.param2.requires_grad)\n        test_case.assertEqual(x.param2.device, flow.device(\"cpu\"))\n        test_case.assertEqual(x.fc.weight.dtype, y.fc.weight.dtype)\n        test_case.assertEqual(x.fc.weight.shape, y.fc.weight.shape)\n        test_case.assertEqual(x.fc.weight.requires_grad, y.fc.weight.requires_grad)\n        test_case.assertEqual(x.fc.weight.device, flow.device(\"cpu\"))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_meta_tensor_local_mode_clone(test_case):\n        x = flow.tensor([3, 2], device=\"meta\")\n        y = x.clone()\n        test_case.assertEqual(x.dtype, y.dtype)\n        test_case.assertEqual(x.shape, y.shape)\n        test_case.assertEqual(x.device, y.device)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_meta_tensor_global_mode_without_data(test_case):\n        P1 = flow.placement(type=\"meta\", ranks=[0])\n        P2 = flow.placement(type=\"cpu\", ranks=[0])\n        sbp = flow.sbp.broadcast\n        x = flow.Tensor(3, 2, placement=P1, sbp=sbp)\n        y = flow.Tensor(3, 2, placement=P2, sbp=sbp)\n        test_case.assertEqual(x.dtype, y.dtype)\n        test_case.assertEqual(x.shape, y.shape)\n        test_case.assertEqual(x.sbp, y.sbp)\n        test_case.assertEqual(x.placement.type, \"meta\")\n        test_case.assertEqual(x.to_local().dtype, y.to_local().dtype)\n        test_case.assertEqual(x.to_local().shape, y.to_local().shape)\n        test_case.assertEqual(x.to_local().device.type, \"meta\")\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_meta_tensor_global_mode_with_data(test_case):\n        P1 = flow.placement(type=\"meta\", ranks=[0])\n        P2 = flow.placement(type=\"cpu\", ranks=[0])\n        sbp = flow.sbp.broadcast\n        x = flow.Tensor([3, 2], placement=P1, sbp=sbp)\n        y = flow.Tensor([3, 2], placement=P2, sbp=sbp)\n        test_case.assertEqual(x.dtype, y.dtype)\n        test_case.assertEqual(x.shape, y.shape)\n        test_case.assertEqual(x.sbp, y.sbp)\n        
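# a global meta tensor holds no data, so only metadata of the tensor and\n        # of its local component can be compared\n        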
test_case.assertEqual(x.placement.type, \"meta\")\n        test_case.assertEqual(x.to_local().dtype, y.to_local().dtype)\n        test_case.assertEqual(x.to_local().shape, y.to_local().shape)\n        test_case.assertEqual(x.to_local().device.type, \"meta\")\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_meta_tensor_func_global_mode_without_data(test_case):\n        P1 = flow.placement(type=\"meta\", ranks=[0])\n        P2 = flow.placement(type=\"cpu\", ranks=[0])\n        sbp = flow.sbp.broadcast\n        x = flow.tensor([3, 2], placement=P1, sbp=sbp)\n        y = flow.tensor([3, 2], placement=P2, sbp=sbp)\n        test_case.assertEqual(x.dtype, y.dtype)\n        test_case.assertEqual(x.shape, y.shape)\n        test_case.assertEqual(x.sbp, y.sbp)\n        test_case.assertEqual(x.placement.type, \"meta\")\n        test_case.assertEqual(x.to_local().dtype, y.to_local().dtype)\n        test_case.assertEqual(x.to_local().shape, y.to_local().shape)\n        test_case.assertEqual(x.to_local().device.type, \"meta\")\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_meta_tensor_global_mode_clone(test_case):\n        P = flow.placement(type=\"meta\", ranks=[0])\n        sbp = flow.sbp.broadcast\n        x = flow.tensor([3, 2], placement=P, sbp=sbp)\n        y = x.clone()\n        test_case.assertEqual(x.dtype, y.dtype)\n        test_case.assertEqual(x.shape, y.shape)\n        test_case.assertEqual(x.sbp, y.sbp)\n        test_case.assertEqual(x.placement, y.placement)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_meta_tensor_calculate(test_case):\n        x1 = flow.tensor([3, 2], device=\"meta\")\n        y1 = x1 + 1\n        P = flow.placement(type=\"meta\", ranks=[0])\n        sbp = flow.sbp.broadcast\n        x2 = flow.tensor([3, 2], placement=P, sbp=sbp)\n        y2 = x2 + 1\n        test_case.assertEqual(y1.device.type, \"meta\")\n        test_case.assertEqual(y2.placement.type, \"meta\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_new_tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\n\nclass TestNewTensor(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_new_tensor_local_mode_with_default_args(test_case):\n        tensor = flow.randn(5)\n        data = [[1, 2], [3, 4]]\n        new_tensor = tensor.new_tensor(data)\n        test_case.assertEqual(new_tensor.dtype, tensor.dtype)\n        test_case.assertEqual(new_tensor.device, tensor.device)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n1d()\n    def test_new_tensor_local_mode_with_spec_args(test_case):\n        tensor = flow.randn(5)\n        data = [[1, 2], [3, 4]]\n        new_tensor = tensor.new_tensor(data, flow.int64, \"cuda\")\n        test_case.assertEqual(new_tensor.dtype, flow.int64)\n        test_case.assertEqual(new_tensor.device, flow.device(\"cuda\"))\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_new_tensor_global_mode_with_default_args(test_case):\n        placement = flow.placement(type=\"cpu\", ranks=[0, 1])\n        sbp = flow.sbp.split(0)\n        tensor = flow.randn(4, 4, placement=placement, sbp=sbp)\n        data = [[1, 2], [3, 4]]\n        new_tensor = tensor.new_tensor(data)\n        test_case.assertEqual(new_tensor.dtype, tensor.dtype)\n        test_case.assertEqual(new_tensor.placement, placement)\n        test_case.assertEqual(new_tensor.sbp, (sbp,))\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n2d()\n    def test_new_tensor_global_mode_with_spec_args(test_case):\n        placement = flow.placement(type=\"cuda\", ranks=[0, 1])\n        sbp = flow.sbp.split(0)\n        tensor = flow.randn(4, 4, placement=placement, sbp=sbp)\n        data = [[1, 2], [3, 4]]\n        new_tensor = tensor.new_tensor(\n            data, placement=placement, sbp=flow.sbp.broadcast\n        )\n        test_case.assertEqual(new_tensor.dtype, tensor.dtype)\n        test_case.assertEqual(new_tensor.placement, placement)\n        test_case.assertEqual(new_tensor.sbp, (flow.sbp.broadcast,))\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n1d()\n    def test_new_cuda_bfloat16_local_tensor_with_numpy(test_case):\n        from oneflow import sysconfig\n\n        if sysconfig.get_cuda_version() < 11000:\n            return\n        np_array = np.random.rand(4, 4)\n        tensor = flow.tensor(np_array, dtype=flow.bfloat16, device=\"cuda\")\n        test_case.assertEqual(tensor.dtype, flow.bfloat16)\n        test_case.assertEqual(tensor.device, flow.device(\"cuda\"))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_new_cpu_bfloat16_local_tensor_with_numpy(test_case):\n        np_array = np.random.rand(4, 4)\n        tensor = flow.tensor(np_array, 
dtype=flow.bfloat16, device=\"cpu\")\n        test_case.assertEqual(tensor.dtype, flow.bfloat16)\n        test_case.assertEqual(tensor.device, flow.device(\"cpu\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_parameter.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestParameter(flow.unittest.TestCase):\n    @autotest(n=1, check_graph=True)\n    def test_parameter_grad_fn_none(test_case):\n        x = torch.ones(2, 3).requires_grad_(True)\n        y = x + x\n        z = torch.nn.Parameter(y)\n        return z.grad_fn\n\n    @autotest(n=1, check_graph=True)\n    def test_parameter_set_data_autograd_meta(test_case):\n        x = torch.ones(2, 3).requires_grad_(True)\n        y = x + x\n        z = torch.nn.Parameter(x)\n        z.data = y\n        return z.grad_fn, z.is_leaf\n\n    # Not check graph because of 2 reason.\n    # Reason 1, x.data return a new tensor but share storage with the origin tensor, this is not well dealed in nn.Graph.\n    # Reason 2, inplace operation mul_ can works well inside nn.Graph but will not change the value in free eager tensor.\n    # Please refer to test case: test_graph_return_inplace_free_eager_tensor\n    @autotest(n=1, check_graph=\"ValidatedFalse\")\n    def test_parameter_inplace_modify_data(test_case):\n        x = torch.nn.Parameter(torch.ones(2, 3))\n        x.data.mul_(2)\n        return x\n\n    def test_parameter_set_data(test_case):\n        a = flow.nn.Parameter(flow.ones(2, 3), False)\n        old_id = id(a)\n        b = flow.nn.Parameter(flow.ones(4, 5), True)\n        a.data = b\n        test_case.assertEqual(old_id, id(a))\n        test_case.assertTrue(a.shape == (4, 5))\n        test_case.assertFalse(a.requires_grad)\n        test_case.assertTrue(a.is_leaf)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_safetensors.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport unittest\nimport tempfile\nimport oneflow as flow\nimport oneflow.unittest\nimport oneflow.mock_torch as mock\n\n\ntensors = {\n    \"weight1\": flow.zeros((1024, 1024)),\n    \"weight2\": flow.ones((1024, 1024)),\n    \"weight3\": flow.rand((1024, 1024)),\n    \"weight4\": flow.eye(1024),\n}\n\n\ndef _test_save_safetensors(save_path):\n    with mock.enable():\n        from safetensors.torch import save_file\n\n        save_file(tensors, save_path)\n\n\ndef _test_load_safetensors(load_path):\n    with mock.enable():\n        from safetensors import safe_open\n\n        tensors_load = {}\n        with safe_open(load_path, framework=\"pt\", device=\"cpu\") as f:\n            for key in f.keys():\n                tensors_load[key] = f.get_tensor(key)\n        return tensors_load\n\n\nclass TestSafetensors(flow.unittest.TestCase):\n    def test_safetensors(test_case):\n        with tempfile.TemporaryDirectory() as f0:\n            _test_save_safetensors(os.path.join(f0, \"model.safetensors\"))\n            tensors_load = _test_load_safetensors(os.path.join(f0, \"model.safetensors\"))\n        for key in tensors.keys():\n            test_case.assertTrue((tensors[key] == tensors_load[key]).all())\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_tensor_and_ndarray_compatibility.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport unittest\nfrom collections import OrderedDict\n\nimport oneflow as flow\n\nfrom oneflow.test_utils.test_util import GenArgDict\nimport numpy as np\nimport torch\n\n\ntest_compute_op_list = [\n    \"+\",\n    \"-\",\n    \"*\",\n    \"/\",\n    \"**\",\n    \"//\",\n    \"%\",\n]\n\ntest_login_op_list = [\n    \"^\",\n    \"&\",\n    \"|\",\n]\n\ntest_compare_op_list = [\n    \"==\",\n    \"!=\",\n]\n\n\ndef _test_compute_operator(test_case, shape, dtype):\n    random_tensor = np.random.randn(*shape).astype(dtype)\n    x_flow = flow.tensor(random_tensor)\n    x_torch = torch.tensor(random_tensor)\n    random_numpy = np.random.randn(*shape)\n\n    for op in test_compute_op_list:\n        if op in [\"**\", \"//\", \"%\"]:\n            random_tensor = np.random.randint(1, 100, size=shape)\n            random_numpy = np.random.randint(1, 10, size=shape)\n        else:\n            random_tensor = np.random.randn(*shape)\n            random_numpy = np.random.randn(*shape)\n\n        x_flow = flow.tensor(random_tensor)\n        x_torch = torch.tensor(random_tensor)\n\n        z_flow = eval(f\"x_flow {op} random_numpy\")\n        z_torch = eval(f\"x_torch {op} random_numpy\")\n        test_case.assertTrue(np.allclose(z_flow.numpy(), z_torch.numpy()))\n\n        # TODO:support for \"+=\" compatibility\n        if op not in [\"**\", \"+\"]:\n            exec(f\"x_flow {op}= random_numpy\")\n            exec(f\"x_torch {op}= random_numpy\")\n            test_case.assertTrue(\n                np.allclose(z_flow.numpy(), z_torch.numpy(), 1e-05, 1e-05)\n            )\n\n\ndef _test_logic_operator(test_case, shape):\n    random_tensor = np.random.randint(100, size=shape)\n    x_flow = flow.tensor(random_tensor, dtype=flow.int64)\n    x_torch = torch.tensor(random_tensor, dtype=torch.int64)\n    random_numpy = np.random.randint(100, size=shape)\n\n    for op in test_login_op_list:\n        z_flow = eval(f\"x_flow {op} random_numpy\")\n        z_torch = eval(f\"x_torch {op} random_numpy\")\n        test_case.assertTrue(np.allclose(z_flow.numpy(), z_torch.numpy(), 1e-05, 1e-05))\n\n\ndef _test_compare_operator(test_case, shape):\n    random_tensor = np.random.randint(100, size=shape)\n    x_flow = flow.tensor(random_tensor, dtype=flow.int64)\n    x_torch = torch.tensor(random_tensor, dtype=torch.int64)\n    random_numpy = np.random.randint(100, size=shape)\n\n    for op in test_compare_op_list:\n        flow_bool_value = eval(f\"x_flow {op} random_numpy\")\n        torch_bool_value = eval(f\"x_torch {op} random_numpy\")\n        print(flow_bool_value)\n        print(torch_bool_value)\n        test_case.assertTrue(flow_bool_value, torch_bool_value)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTensorAndNdarrayCompatibility(flow.unittest.TestCase):\n    def test_op_compatibility(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"shape\"] = [(2, 3), (2, 3, 4), (2, 
3, 4, 5)]\n        arg_dict[\"dtype\"] = [np.float32, np.float64]\n\n        for arg in GenArgDict(arg_dict):\n            _test_compute_operator(test_case, **arg)\n            # TODO(yzm):support compare  operator Compatibility\n            # _test_compare_operator(test_case, **arg)\n\n            # TODO(yzm):fix the logic op bug\n            # _test_logic_operator(test_case, **arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_tensor_exponential.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport random\nimport numpy as np\nfrom collections import OrderedDict\nimport torch\n\nimport oneflow as flow\n\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_exponential(test_case, device, seed, lambd, dtype):\n    torch.manual_seed(seed)\n    flow.manual_seed(seed)\n\n    dim1 = random.randint(8, 64)\n    dim2 = random.randint(8, 64)\n\n    torch_arr = torch.zeros(\n        dim1, device=device, dtype=torch.float32 if dtype == \"float\" else torch.float64\n    ).exponential_(lambd=lambd, generator=None)\n    oneflow_arr = flow.zeros(\n        dim1, device=device, dtype=flow.float32 if dtype == \"float\" else flow.float64\n    ).exponential_(lambd=lambd, generator=None)\n\n    test_case.assertTrue(\n        np.allclose(torch_arr.cpu().numpy(), oneflow_arr.cpu().numpy(), atol=1e-8,)\n    )\n\n    torch_arr = torch.zeros(\n        dim1, device=device, dtype=torch.float32 if dtype == \"float\" else torch.float64\n    ).exponential_(lambd=lambd, generator=None)\n    oneflow_arr = flow.zeros(\n        dim1, device=device, dtype=flow.float32 if dtype == \"float\" else flow.float64\n    ).exponential_(lambd=lambd, generator=None)\n\n    test_case.assertTrue(\n        np.allclose(torch_arr.cpu().numpy(), oneflow_arr.cpu().numpy(), atol=1e-8,)\n    )\n\n    torch_gen = torch.Generator(device=device)\n    torch_gen.manual_seed(seed)\n    oneflow_gen = flow.Generator(device=device)\n    oneflow_gen.manual_seed(seed)\n\n    torch_arr = torch.zeros(\n        dim1, device=device, dtype=torch.float32 if dtype == \"float\" else torch.float64\n    ).exponential_(lambd=lambd, generator=torch_gen)\n    oneflow_arr = flow.zeros(\n        dim1, device=device, dtype=flow.float32 if dtype == \"float\" else flow.float64\n    ).exponential_(lambd=lambd, generator=oneflow_gen)\n\n    test_case.assertTrue(\n        np.allclose(torch_arr.cpu().numpy(), oneflow_arr.cpu().numpy(), atol=1e-8,)\n    )\n\n    torch_arr = torch.zeros(\n        dim1, device=device, dtype=torch.float32 if dtype == \"float\" else torch.float64\n    ).exponential_(lambd=lambd, generator=torch_gen)\n    oneflow_arr = flow.zeros(\n        dim1, device=device, dtype=flow.float32 if dtype == \"float\" else flow.float64\n    ).exponential_(lambd=lambd, generator=oneflow_gen)\n\n    test_case.assertTrue(\n        np.allclose(torch_arr.cpu().numpy(), oneflow_arr.cpu().numpy(), atol=1e-8,)\n    )\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestExponential(flow.unittest.TestCase):\n    def test_exponential(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        arg_dict[\"seed\"] = [0, 2, 4]\n        arg_dict[\"lambd\"] = [1, 0.5, 0.1]\n        arg_dict[\"dtype\"] = [\"double\", \"float\"]\n        for arg in GenArgList(arg_dict):\n            _test_exponential(test_case, *arg[0:])\n\n\nif __name__ == \"__main__\":\n   
 unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_tensor_indexing.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\nfrom oneflow.test_utils.test_util import GenArgList\nfrom collections import OrderedDict\nfrom oneflow.test_utils.automated_test_util import *\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef _test_numpy_scalar_indexing(test_case, numpy_x, np_scalar):\n    x = flow.Tensor(numpy_x)\n\n    # basic_slice\n    test_case.assertTrue(np.allclose(numpy_x[np_scalar(1)], x[np_scalar(1)].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[np_scalar(-2)], x[np_scalar(-2)].numpy()))\n    test_case.assertTrue(\n        np.allclose(\n            numpy_x[np_scalar(0), np_scalar(1)], x[np_scalar(0), np_scalar(1)].numpy()\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            numpy_x[(np_scalar(0), np_scalar(1))],\n            x[(np_scalar(0), np_scalar(1))].numpy(),\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            numpy_x[((np_scalar(0), np_scalar(1)))],\n            x[((np_scalar(0), np_scalar(1)))].numpy(),\n        )\n    )\n\n\ndef _test_numpy_scalar_advance_indexing(test_case, numpy_x, np_scalar):\n    x = flow.Tensor(numpy_x)\n\n    # advance indexing\n    test_case.assertTrue(\n        np.allclose(\n            numpy_x[[np_scalar(0), np_scalar(1)]],\n            x[[np_scalar(0), np_scalar(1)]].numpy(),\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            numpy_x[[np_scalar(0), np_scalar(1)], [np_scalar(1), np_scalar(0)]],\n            x[[np_scalar(0), np_scalar(1)], [np_scalar(1), np_scalar(0)]].numpy(),\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            numpy_x[\n                [np_scalar(0), np_scalar(1)],\n                [np_scalar(0), np_scalar(1)],\n                [np_scalar(1), np_scalar(0)],\n            ],\n            x[\n                [np_scalar(0), np_scalar(1)],\n                [np_scalar(0), np_scalar(1)],\n                [np_scalar(1), np_scalar(0)],\n            ].numpy(),\n        )\n    )\n\n\ndef _test_basic_slice(test_case, numpy_x):\n    x = flow.tensor(numpy_x)\n\n    test_case.assertTrue(np.allclose(numpy_x[1], x[1].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[-2], x[-2].numpy()))\n\n    test_case.assertTrue(np.allclose(numpy_x[0, 1], x[0, 1].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[(0, 1)], x[(0, 1)].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[((0, 1))], x[((0, 1))].numpy()))\n\n    test_case.assertTrue(np.allclose(numpy_x[None], x[None].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[True], x[True].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[1, None], x[1, None].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[1, None, 1], x[1, None, 1].numpy()))\n    test_case.assertTrue(\n        np.allclose(numpy_x[1, None, None, 1], x[1, None, None, 1].numpy())\n    )\n\n    test_case.assertTrue(np.allclose(numpy_x[:], 
x[:].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[:1], x[:1].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[0:1], x[0:1].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[-2:-1], x[-2:-1].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[2:100:200], x[2:100:200].numpy()))\n\n    test_case.assertTrue(np.allclose(numpy_x[0:2, ...], x[0:2, ...].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[0:2, ..., 1], x[0:2, ..., 1].numpy()))\n    test_case.assertTrue(\n        np.allclose(numpy_x[0:2, ..., 1, 1], x[0:2, ..., 1, 1].numpy())\n    )\n\n    test_case.assertTrue(np.allclose(numpy_x[0:4:2, ...], x[0:4:2, ...].numpy()))\n    test_case.assertTrue(\n        np.allclose(numpy_x[0:2, None, ..., True], x[0:2, None, ..., True].numpy())\n    )\n    test_case.assertTrue(\n        np.allclose(numpy_x[None, ..., 0:4:2, True], x[None, ..., 0:4:2, True].numpy())\n    )\n\n    test_case.assertTrue(np.allclose(numpy_x[False, ...], x[False, ...].numpy()))\n    test_case.assertTrue(\n        np.allclose(numpy_x[False, True, ...], x[False, True, ...].numpy())\n    )\n    test_case.assertTrue(\n        np.allclose(numpy_x[True, ..., False, True], x[True, ..., False, True].numpy())\n    )\n    test_case.assertTrue(\n        np.allclose(\n            numpy_x[True, None, ..., False, True],\n            x[True, None, ..., False, True].numpy(),\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            numpy_x[True, 1, ..., False, True], x[True, 1, ..., False, True].numpy()\n        )\n    )\n\n\n# NOTE: With numpy>=1.23.0, a list of indices is treated as basic indexing,\n#       and a tuple of indices is treated as advanced indexing.\ndef _test_advanced_indexing(test_case, numpy_x):\n    x = flow.tensor(numpy_x)\n\n    test_case.assertTrue(np.allclose(numpy_x[[0, 1]], x[[0, 1]].numpy()))\n    test_case.assertTrue(\n        np.allclose(numpy_x[[0, 1], [1, 0]], x[[0, 1], [1, 0]].numpy())\n    )\n    test_case.assertTrue(\n        np.allclose(\n            numpy_x[tuple([[0, 1], [0, 1], [1, 0]])],\n            x[[[0, 1], [0, 1], [1, 0]]].numpy(),\n        )\n    )\n    test_case.assertTrue(np.allclose(numpy_x[tuple([[0], [1]])], x[[[0], [1]]].numpy()))\n    test_case.assertTrue(\n        np.allclose(\n            numpy_x[tuple([[[0], [1]], [[0], [1]], [0, 1]])],\n            x[[[[0], [1]], [[0], [1]], [0, 1]]].numpy(),\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            numpy_x[tuple([[[0, 1], [1, 1]], [[0, 0], [1, 1]], [0, 1]])],\n            x[[[[0, 1], [1, 1]], [[0, 0], [1, 1]], [0, 1]]].numpy(),\n        )\n    )\n\n    # Tensor index\n    test_case.assertTrue(\n        np.allclose(\n            numpy_x[np.array([0, 1]), np.array([1, 0])],\n            x[flow.tensor([0, 1]), flow.tensor([1, 0])].numpy(),\n        )\n    )\n    test_case.assertTrue(\n        np.allclose(\n            numpy_x[:, np.array([[0, 1], [1, 1]]), np.array([[1, 0], [1, 1]])],\n            x[:, flow.tensor([[0, 1], [1, 1]]), flow.tensor([[1, 0], [1, 1]]),].numpy(),\n        )\n    )\n\n    # mask tensor index\n    mask = np.random.rand(numpy_x.shape[0], numpy_x.shape[1]).astype(np.float32)\n    y = flow.tensor(mask)\n    test_case.assertTrue(np.allclose(numpy_x[mask > 0.5], x[y > 0.5].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[mask > 0.5, 1], x[y > 0.5, 1].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[mask > 0], x[y > 0].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[mask > 0, 1], x[y > 0, 1].numpy()))\n    
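# mask values are drawn from [0, 1), so mask > 1 selects nothing and exercises the empty-mask path\n    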
test_case.assertTrue(np.allclose(numpy_x[mask > 1], x[y > 1].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[mask > 1, 1], x[y > 1, 1].numpy()))\n\n    mask = np.random.rand(*numpy_x.shape).astype(np.float32)\n    y = flow.tensor(mask)\n    test_case.assertTrue(np.allclose(numpy_x[mask > 0.5], x[y > 0.5].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[mask > 0], x[y > 0].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[mask > 1], x[y > 1].numpy()))\n\n\ndef _test_advanced_indexing_array(test_case, numpy_x, dtype):\n    x = flow.tensor(numpy_x)\n\n    idx = np.array([0, 1], dtype=dtype)\n    test_case.assertTrue(np.allclose(numpy_x[idx], x[idx].numpy()))\n\n    idx1 = np.array([0, 1], dtype=dtype)\n    idx2 = np.array([1, 0], dtype=dtype)\n    test_case.assertTrue(np.allclose(numpy_x[idx1, idx2], x[idx1, idx2].numpy()))\n\n    idx = np.array([[0, 1], [0, 1], [1, 0]], dtype=dtype)\n    test_case.assertTrue(np.allclose(numpy_x[idx, :, :], x[idx, :, :].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[idx, idx, :], x[idx, idx, :].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[idx, idx, idx], x[idx, idx, idx].numpy()))\n\n    idx1 = np.array([[1, 0, 1], [1, 1, 0]])\n    idx2 = np.array([[0], [1]])\n    test_case.assertTrue(\n        np.allclose(numpy_x[:, idx1, :, idx2].shape, x[:, idx1, :, idx2].shape)\n    )\n    test_case.assertTrue(\n        np.allclose(numpy_x[:, idx1, 1, idx2].shape, x[:, idx1, 1, idx2].shape)\n    )\n    test_case.assertTrue(\n        np.allclose(numpy_x[idx1, :, idx2, :].shape, x[idx1, :, idx2, :].shape)\n    )\n    test_case.assertTrue(\n        np.allclose(numpy_x[:, idx1, idx2, :].shape, x[:, idx1, idx2, :].shape)\n    )\n\n\ndef _test_combining_indexing(test_case, numpy_x):\n    x = flow.tensor(numpy_x)\n\n    test_case.assertTrue(\n        np.allclose(numpy_x[[0, 1], 1:2, [1, 0]], x[[0, 1], 1:2, [1, 0]].numpy())\n    )\n    test_case.assertTrue(\n        np.allclose(numpy_x[:, [0, 1], [1, 0]], x[:, [0, 1], [1, 0]].numpy())\n    )\n    test_case.assertTrue(np.allclose(numpy_x[:, [0, 1], 1], x[:, [0, 1], 1].numpy()))\n    test_case.assertTrue(\n        np.allclose(numpy_x[..., [0, 1], 1, [1, 0]], x[..., [0, 1], 1, [1, 0]].numpy())\n    )\n\n\ndef _test_mask_getitem(test_case, numpy_x):\n    x = flow.tensor(numpy_x)\n\n    mask = np.random.rand(*numpy_x.shape).astype(np.float32)\n    y = flow.tensor(mask)\n    test_case.assertTrue(np.allclose(numpy_x[mask > 0.5], x[y > 0.5].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[mask > 1.0], x[y > 1.0].numpy()))\n\n    mask = np.random.rand(numpy_x.shape[0]).astype(np.float32)\n    y = flow.tensor(mask)\n    test_case.assertTrue(np.allclose(numpy_x[mask > 0.5], x[y > 0.5].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[mask > 1.0], x[y > 1.0].numpy()))\n\n    test_case.assertTrue(np.allclose(numpy_x[mask > 0.5, 1], x[y > 0.5, 1].numpy()))\n    test_case.assertTrue(np.allclose(numpy_x[mask > 1.0, 1], x[y > 1.0, 1].numpy()))\n\n\ndef _test_mask_setitem(test_case, numpy_x):\n    x = flow.tensor(numpy_x)\n\n    # mask tensor index\n    mask = np.random.rand(*numpy_x.shape).astype(np.float32)\n    y = flow.tensor(mask)\n\n    # broadcast set\n    x[y > 0.5] = 1.0\n    numpy_x[mask > 0.5] = 1.0\n    test_case.assertTrue(np.allclose(numpy_x, x.numpy()))\n\n    # elementwise set\n    update = np.random.randn((mask > 0.5).sum()).astype(np.float32)\n    tensor_update = flow.tensor(update)\n    x[y > 0.5] = tensor_update\n    numpy_x[mask > 0.5] = update\n    
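# the flat update provides exactly (mask > 0.5).sum() values, one per selected element\n    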
test_case.assertTrue(np.allclose(numpy_x, x.numpy()))\n\n    # empty mask\n    x[y > 1.0] = 1.0\n    numpy_x[mask > 1.0] = 1.0\n    test_case.assertTrue(np.allclose(numpy_x, x.numpy()))\n\n\ndef _test_list_indexing_using_scalar_tensor(test_case, dtype):\n    y = np.random.randint(0, 100, size=100)\n    for i in range(len(y)):\n        x = flow.tensor(i, dtype=dtype)\n        test_case.assertEqual(y[i], y[x])\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTensorIndexing(flow.unittest.TestCase):\n    def test_basic_slice(test_case):\n        numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32)\n        _test_basic_slice(test_case, numpy_x)\n\n        numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32)\n        _test_basic_slice(test_case, numpy_x)\n\n        numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32)\n        _test_basic_slice(test_case, numpy_x)\n\n    def test_advanced_indexing(test_case):\n        numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32)\n        _test_advanced_indexing(test_case, numpy_x)\n\n        numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32)\n        _test_advanced_indexing(test_case, numpy_x)\n\n        numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32)\n        _test_advanced_indexing(test_case, numpy_x)\n\n    def test_advanced_indexing_array(test_case):\n        numpy_x = np.arange(0, 60, 1).reshape([3, 2, 2, 5]).astype(np.float32)\n        _test_advanced_indexing_array(test_case, numpy_x, np.int32)\n        _test_advanced_indexing_array(test_case, numpy_x, np.int64)\n\n        numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32)\n        _test_advanced_indexing_array(test_case, numpy_x, np.int32)\n        _test_advanced_indexing_array(test_case, numpy_x, np.int64)\n\n        numpy_x = np.arange(0, 720, 1).reshape([5, 8, 9, 2]).astype(np.float32)\n        _test_advanced_indexing_array(test_case, numpy_x, np.int32)\n        _test_advanced_indexing_array(test_case, numpy_x, np.int64)\n\n    def test_combining_indexing(test_case):\n        numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32)\n        _test_combining_indexing(test_case, numpy_x)\n\n        numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32)\n        _test_combining_indexing(test_case, numpy_x)\n\n        numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32)\n        _test_combining_indexing(test_case, numpy_x)\n\n    def test_numpy_scalar_indexing(test_case):\n        for np_scalar in [np.int8, np.int16, np.int32, np.int64]:\n            numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32)\n            _test_numpy_scalar_indexing(test_case, numpy_x, np_scalar)\n\n            numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32)\n            _test_numpy_scalar_indexing(test_case, numpy_x, np_scalar)\n\n            numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32)\n            _test_numpy_scalar_indexing(test_case, numpy_x, np_scalar)\n\n        # TODO: add np.int16 when advance indexing supports np.int16 mapping\n        for np_scalar in [np.int32, np.int64]:\n            numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32)\n            _test_numpy_scalar_advance_indexing(test_case, numpy_x, np_scalar)\n\n            numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32)\n            _test_numpy_scalar_advance_indexing(test_case, 
numpy_x, np_scalar)\n\n            numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32)\n            _test_numpy_scalar_advance_indexing(test_case, numpy_x, np_scalar)\n\n    def test_mask_getitem(test_case):\n        numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32)\n        _test_mask_getitem(test_case, numpy_x)\n\n        numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32)\n        _test_mask_getitem(test_case, numpy_x)\n\n        numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32)\n        _test_mask_getitem(test_case, numpy_x)\n\n        numpy_x = np.arange(0, 27, 1).reshape(3, 3, 3)\n        x = flow.tensor(numpy_x)\n        test_case.assertTrue(\n            np.allclose(\n                numpy_x[[False, True, False], 1], x[[False, True, False], 1].numpy()\n            )\n        )\n        test_case.assertTrue(\n            np.allclose(\n                numpy_x[[False, True, False], [True, False, False]],\n                x[[False, True, False], [True, False, False]].numpy(),\n            )\n        )\n\n    def test_mask_setitem(test_case):\n        numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32)\n        _test_mask_setitem(test_case, numpy_x)\n\n        numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32)\n        _test_mask_setitem(test_case, numpy_x)\n\n        numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32)\n        _test_mask_setitem(test_case, numpy_x)\n\n    def test_combined_mask_setitem(test_case):\n        np_in = np.random.rand(5, 4, 3, 2)\n        np_mask_dim1 = np.array([False, True, False, True])\n        np_mask_dim3 = np.array([True, False])\n        np_update = np.random.rand(2, 5, 3)\n        np_in[:, np_mask_dim1, :, np_mask_dim3] = np_update\n\n        flow_in = flow.tensor(np_in)\n        flow_mask_dim1 = flow.tensor(np_mask_dim1)\n        flow_mask_dim3 = flow.tensor(np_mask_dim3)\n        flow_update = flow.tensor(np_update)\n        flow_in[:, flow_mask_dim1, :, flow_mask_dim3] = flow_update\n        test_case.assertTrue(np.array_equal(flow_in.numpy(), np_in))\n\n    def test_non_contiguous_combined_mask_setitem(test_case):\n        np_in = np.random.rand(5, 4, 3, 2)\n        np_mask_dim1 = np.array([False, True, False])\n        np_mask_dim3 = np.array([True, False, False, True, True])\n        np_update = np.random.rand(4, 2, 3)\n\n        flow_in = flow.tensor(np_in).permute(3, 2, 1, 0)  # (2, 3, 4, 5)\n        flow_mask_dim1 = flow.tensor(np_mask_dim1)\n        flow_mask_dim3 = flow.tensor(np_mask_dim3)\n        flow_update = flow.tensor(np_update).permute(2, 1, 0)  # (3, 2, 4)\n        flow_in[:, flow_mask_dim1, :, flow_mask_dim3] = flow_update\n\n        np_in = np_in.transpose(3, 2, 1, 0)\n        np_update = np_update.transpose(2, 1, 0)\n        np_in[:, np_mask_dim1, :, np_mask_dim3] = np_update\n        test_case.assertTrue(np.array_equal(flow_in.numpy(), np_in))\n\n    def test_combined_indexing_setitem(test_case):\n        np_in = np.random.rand(2, 3, 4)\n        np_in[[0, 1], 1:2, [0, 1]] = 1.0\n\n        flow_in = flow.tensor(np_in)\n        flow_in[[0, 1], 1:2, [0, 1]] = 1.0\n        test_case.assertTrue(np.array_equal(flow_in.numpy(), np_in))\n\n    def test_expand_dim_setitem(test_case):\n        a = flow.tensor(1.0)\n        a[True, ...] = 0.0\n        test_case.assertTrue(np.array_equal(a.numpy(), 0.0))\n\n        a = flow.tensor(1.0)\n        a[False, ...] 
= 1.0\n        test_case.assertTrue(np.array_equal(a.numpy(), 1.0))\n\n    def test_advanced_indexing_with_scalar_index(test_case):\n        index = flow.tensor([0, 2])\n        x = flow.randn(5)\n        x[index[0]] = 1\n        test_case.assertTrue(np.allclose(x[0].numpy(), 1))\n\n    def test_list_indexing_using_scalar_tensor(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"function_test\"] = [\n            _test_list_indexing_using_scalar_tensor,\n        ]\n        arg_dict[\"dtype\"] = [flow.uint8, flow.int8, flow.int32, flow.int64]\n        for arg in GenArgList(arg_dict):\n            arg[0](test_case, *arg[1:])\n\n    @autotest(n=3, auto_backward=False)\n    def test_advanced_indexing_with_0_size_tensor(test_case):\n        device = random_device()\n        data = torch.arange(8).reshape(2, 2, 2).to(device)\n        ranges = []\n        ranges.append(torch.ones(0, 1).to(torch.int64))\n        ranges.append(torch.zeros(1, 3).to(torch.int64))\n        res = data[ranges]\n        return res\n\n    @autotest(n=1)\n    def test_dataloader_indexing_with_1_dim_tensor(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=512).to(device)\n        batch_data = list()\n        for i in range(512):\n            batch_data.append(x[i])\n        return torch.stack(batch_data)\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    def test_indices_on_different_devices(test_case):\n        x = flow.ones(3, 10)\n        y = flow.ones(3, 10, device=flow.device(\"cuda:0\"))\n\n        x_idx = [flow.tensor([1, 2]), flow.tensor([2, 0], device=flow.device(\"cuda:0\"))]\n        y_idx = [flow.tensor([1, 2], device=flow.device(\"cuda:0\")), flow.tensor([2, 0])]\n\n        test_case.assertTrue(np.allclose(x[x_idx].numpy(), np.array([1, 1])))\n        test_case.assertTrue(np.allclose(y[y_idx].numpy(), np.array([1, 1])))\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestTensorIndexingMultiGpu(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n2d()\n    def test_indices_on_different_devices(test_case):\n        x = flow.ones(3, 10, device=flow.device(\"cuda:0\"))\n        idx = [flow.tensor([1, 2], device=flow.device(\"cuda:1\")), flow.tensor([2, 0])]\n        test_case.assertTrue(np.allclose(x[idx].numpy(), np.array([1, 1])))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
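  {
    "path": "python/oneflow/test/tensor/example_combined_mask_setitem_sketch.py",
    "content": "\"\"\"\nIllustrative sketch only: this file and its path are hypothetical additions,\nnot part of the original OneFlow test suite. It walks through the NumPy\nsemantics that test_combined_mask_setitem above relies on: boolean masks\nseparated by slices are converted to integer index arrays, broadcast\ntogether, and the combined advanced-index dimension is moved to the front of\nthe update shape.\n\"\"\"\nimport numpy as np\n\nimport oneflow as flow\n\n# Base tensor of shape (5, 4, 3, 2), masked on dim 1 and dim 3.\nnp_in = np.random.rand(5, 4, 3, 2)\nflow_in = flow.tensor(np_in)\n\nnp_mask_dim1 = np.array([False, True, False, True])  # 2 True -> indices [1, 3]\nnp_mask_dim3 = np.array([True, False])  # 1 True -> index [0]\n\n# The index arrays broadcast to length 2 and, because the masked dims are\n# separated by slices, the combined dimension leads: update shape (2, 5, 3).\nnp_update = np.random.rand(2, 5, 3)\n\nnp_in[:, np_mask_dim1, :, np_mask_dim3] = np_update\nflow_in[:, flow.tensor(np_mask_dim1), :, flow.tensor(np_mask_dim3)] = flow.tensor(\n    np_update\n)\n\nassert np.array_equal(flow_in.numpy(), np_in)\n"
  },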
  {
    "path": "python/oneflow/test/tensor/test_tensor_indexing2.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n# This test code is referenced from: https://github.com/pytorch/pytorch/blob/cd41c8f032dd06c445bf97fc76fb82008b19afcb/test/test_indexing.py\n\nfrom collections import OrderedDict\nimport random\nfrom random import randrange\nimport unittest\n\nimport numpy as np\n\nimport oneflow as flow\nfrom oneflow.test_utils.test_util import GenArgDict\nimport oneflow.unittest\n\n\ndef _assert_tensor_equal(test_case, tensor1, tensor2, atol=0.0, rtol=0.0):\n    test_case.assertTrue(np.allclose(tensor1.numpy(), tensor2.numpy()))\n\n\ndef consec(size, start=1):\n    \"\"\"\n    Generate a arithmetic progression with given size and start value.\n    \"\"\"\n    sequence = flow.ones([int(np.array(size).prod(0)),]).cumsum(0)\n    sequence.add_(start - 1)\n    return sequence.view(*size)\n\n\ndef _test_basic_slice(test_case, device, dtype):\n    reference = consec((3, 3, 3)).to(device=device, dtype=dtype)\n\n    # empty tensor indexing\n    _assert_tensor_equal(\n        test_case,\n        reference[flow.LongTensor().to(device)],\n        flow.empty(0, 3, 3),\n        atol=0,\n        rtol=0,\n    )\n\n    _assert_tensor_equal(test_case, reference[0], consec((3, 3)), atol=0, rtol=0)\n    _assert_tensor_equal(test_case, reference[1], consec((3, 3), 10), atol=0, rtol=0)\n    _assert_tensor_equal(test_case, reference[2], consec((3, 3), 19), atol=0, rtol=0)\n    _assert_tensor_equal(test_case, reference[0, 1], consec((3,), 4), atol=0, rtol=0)\n    _assert_tensor_equal(test_case, reference[0:2], consec((2, 3, 3)), atol=0, rtol=0)\n    test_case.assertEqual(reference[2, 2, 2].item(), 27)\n    _assert_tensor_equal(test_case, reference[:], consec((3, 3, 3)), atol=0, rtol=0)\n\n    # indexing with Ellipsis\n    _assert_tensor_equal(\n        test_case,\n        reference[..., 2],\n        flow.tensor([[3.0, 6.0, 9.0], [12.0, 15.0, 18.0], [21.0, 24.0, 27.0]]),\n        atol=0,\n        rtol=0,\n    )\n    _assert_tensor_equal(\n        test_case, reference[0, ..., 2], flow.tensor([3.0, 6.0, 9.0]), atol=0, rtol=0\n    )\n    _assert_tensor_equal(\n        test_case, reference[..., 2], reference[:, :, 2], atol=0, rtol=0\n    )\n    _assert_tensor_equal(\n        test_case, reference[0, ..., 2], reference[0, :, 2], atol=0, rtol=0\n    )\n    _assert_tensor_equal(\n        test_case, reference[0, 2, ...], reference[0, 2], atol=0, rtol=0\n    )\n    test_case.assertEqual(reference[..., 2, 2, 2].item(), 27)\n    test_case.assertEqual(reference[2, ..., 2, 2].item(), 27)\n    test_case.assertEqual(reference[2, 2, ..., 2].item(), 27)\n    test_case.assertEqual(reference[2, 2, 2, ...].item(), 27)\n    _assert_tensor_equal(test_case, reference[...], reference, atol=0, rtol=0)\n\n    reference_5d = consec((3, 3, 3, 3, 3)).to(device)\n    _assert_tensor_equal(\n        test_case, reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], atol=0, rtol=0\n    )\n    _assert_tensor_equal(\n        test_case,\n        
reference_5d[2, ..., 1, 0],\n        reference_5d[2, :, :, 1, 0],\n        atol=0,\n        rtol=0,\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference_5d[2, 1, 0, ..., 1],\n        reference_5d[2, 1, 0, :, 1],\n        atol=0,\n        rtol=0,\n    )\n    _assert_tensor_equal(test_case, reference_5d[...], reference_5d, atol=0, rtol=0)\n\n    # LongTensor indexing\n    reference = consec((5, 5, 5)).to(device=device, dtype=dtype)\n    idx = flow.LongTensor([2, 4]).to(device)\n    _assert_tensor_equal(\n        test_case, reference[idx], flow.stack([reference[2], reference[4]])\n    )\n\n    # None indexing\n    _assert_tensor_equal(test_case, reference[2, None], reference[2].unsqueeze(0))\n    _assert_tensor_equal(\n        test_case, reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0)\n    )\n    _assert_tensor_equal(test_case, reference[2:4, None], reference[2:4].unsqueeze(1))\n    _assert_tensor_equal(\n        test_case,\n        reference[None, 2, None, None],\n        reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[None, 2:5, None, None],\n        reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2),\n    )\n\n    # indexing 0-length slice\n    _assert_tensor_equal(test_case, flow.empty(0, 5, 5), reference[slice(0)])\n    _assert_tensor_equal(test_case, flow.empty(0, 5), reference[slice(0), 2])\n    _assert_tensor_equal(test_case, flow.empty(0, 5), reference[2, slice(0)])\n    _assert_tensor_equal(test_case, flow.tensor([]), reference[2, 1:1, 2])\n\n    # indexing with step\n    reference = consec((10, 10, 10)).to(device=device, dtype=dtype)\n    _assert_tensor_equal(\n        test_case, reference[1:5:2], flow.stack([reference[1], reference[3]], 0)\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[1:6:2],\n        flow.stack([reference[1], reference[3], reference[5]], 0),\n    )\n    _assert_tensor_equal(\n        test_case, reference[1:9:4], flow.stack([reference[1], reference[5]], 0)\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[2:4, 1:5:2],\n        flow.stack([reference[2:4, 1], reference[2:4, 3]], 1),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[3, 1:6:2],\n        flow.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[None, 2, 1:9:4],\n        flow.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[:, 2, 1:6:2],\n        flow.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1),\n    )\n\n    lst = [list(range(i, i + 10)) for i in range(0, 100, 10)]\n    tensor = flow.DoubleTensor(lst).to(device=device, dtype=dtype)\n    for _ in range(10):\n        idx1_start = randrange(10)\n        idx1_end = idx1_start + randrange(1, 10 - idx1_start + 1)\n        idx1_step = randrange(1, 8)\n        idx1 = slice(idx1_start, idx1_end, idx1_step)\n        if randrange(2) == 0:\n            idx2_start = randrange(10)\n            idx2_end = idx2_start + randrange(1, 10 - idx2_start + 1)\n            idx2_step = randrange(1, 8)\n            idx2 = slice(idx2_start, idx2_end, idx2_step)\n            lst_indexed = [l[idx2] for l in lst[idx1]]\n            tensor_indexed = tensor[idx1, idx2]\n        else:\n            lst_indexed = lst[idx1]\n            tensor_indexed = tensor[idx1]\n        _assert_tensor_equal(\n        
    test_case, flow.DoubleTensor(lst_indexed).to(dtype), tensor_indexed\n        )\n\n    test_case.assertRaises(RuntimeError, lambda: reference[1:9:0])\n    test_case.assertRaises(RuntimeError, lambda: reference[1:9:-1])\n\n    test_case.assertRaises(IndexError, lambda: reference[1, 1, 1, 1])\n    test_case.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1])\n    test_case.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])\n\n    test_case.assertRaises(IndexError, lambda: reference[0.0])\n    test_case.assertRaises(RuntimeError, lambda: reference[0.0:2.0])\n    test_case.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0])\n    test_case.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0])\n    test_case.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0])\n    test_case.assertRaises(IndexError, lambda: reference[0.0, :, 0.0])\n\n\ndef _test_advanced_indexing(test_case, device, dtype):\n    # pick a random valid indexer type\n    def ri(indices):\n        choice = random.randint(0, 2)\n        if choice == 0:\n            return flow.LongTensor(indices).to(device)\n        elif choice == 1:\n            return list(indices)\n        else:\n            return tuple(indices)\n\n    def validate_indexing(x):\n        _assert_tensor_equal(test_case, x[[0]], consec((1,)))\n        _assert_tensor_equal(test_case, x[ri([0]),], consec((1,)))\n        _assert_tensor_equal(test_case, x[ri([3]),], consec((1,), 4))\n        _assert_tensor_equal(test_case, x[[2, 3, 4]], consec((3,), 3))\n        _assert_tensor_equal(test_case, x[ri([2, 3, 4]),], consec((3,), 3))\n        _assert_tensor_equal(\n            test_case,\n            x[ri([0, 2, 4]),],\n            flow.tensor([1, 3, 5], dtype=dtype, device=device),\n        )\n\n    def validate_setting(x):\n        x[[0]] = -2\n        _assert_tensor_equal(\n            test_case, x[[0]], flow.tensor([-2], dtype=dtype, device=device)\n        )\n        x[[0]] = -1\n        _assert_tensor_equal(\n            test_case, x[ri([0]),], flow.tensor([-1], dtype=dtype, device=device)\n        )\n        x[[2, 3, 4]] = 4\n        _assert_tensor_equal(\n            test_case, x[[2, 3, 4]], flow.tensor([4, 4, 4], dtype=dtype, device=device)\n        )\n        x[ri([2, 3, 4]),] = 3\n        _assert_tensor_equal(\n            test_case,\n            x[ri([2, 3, 4]),],\n            flow.tensor([3, 3, 3], dtype=dtype, device=device),\n        )\n        x[ri([0, 2, 4]),] = flow.tensor([5, 4, 3], dtype=dtype, device=device)\n        _assert_tensor_equal(\n            test_case,\n            x[ri([0, 2, 4]),],\n            flow.tensor([5, 4, 3], dtype=dtype, device=device),\n        )\n\n    # 1d tensor and integer index setitem and getitem\n    reference = consec((10,)).to(device=device, dtype=dtype)\n    validate_indexing(reference)\n    validate_setting(reference)\n\n    # reference is 1 2\n    #              3 4\n    #              5 6\n    reference = consec((3, 2)).to(device=device, dtype=dtype)\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([0, 1, 2]), ri([0])],\n        flow.tensor([1, 3, 5], dtype=dtype, device=device),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([0, 1, 2]), ri([1])],\n        flow.tensor([2, 4, 6], dtype=dtype, device=device),\n    )\n    _assert_tensor_equal(test_case, reference[ri([0]), ri([0])], consec((1,)))\n    _assert_tensor_equal(test_case, reference[ri([2]), ri([1])], consec((1,), 6))\n    _assert_tensor_equal(\n        test_case,\n    
    reference[[ri([0, 0]), ri([0, 1])]],\n        flow.tensor([1, 2], dtype=dtype, device=device),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[[ri([0, 1, 1, 0, 2]), ri([1])]],\n        flow.tensor([2, 4, 4, 2, 6], dtype=dtype, device=device),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],\n        flow.tensor([1, 2, 3, 3], dtype=dtype, device=device),\n    )\n\n    rows = ri([[0, 0], [1, 2]])\n    columns = ([0],)\n    _assert_tensor_equal(\n        test_case,\n        reference[rows, columns],\n        flow.tensor([[1, 1], [3, 5]], dtype=dtype, device=device),\n    )\n\n    rows = ri([[0, 0], [1, 2]])\n    columns = ri([1, 0])\n    _assert_tensor_equal(\n        test_case,\n        reference[rows, columns],\n        flow.tensor([[2, 1], [4, 5]], dtype=dtype, device=device),\n    )\n    rows = ri([[0, 0], [1, 2]])\n    columns = ri([[0, 1], [1, 0]])\n    _assert_tensor_equal(\n        test_case,\n        reference[rows, columns],\n        flow.tensor([[1, 2], [4, 5]], dtype=dtype, device=device),\n    )\n\n    # setting values\n    reference[ri([0]), ri([1])] = -1\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([0]), ri([1])],\n        flow.tensor([-1], dtype=dtype, device=device),\n    )\n    reference[ri([0, 1, 2]), ri([0])] = flow.tensor(\n        [-1, 2, -4], dtype=dtype, device=device\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([0, 1, 2]), ri([0])],\n        flow.tensor([-1, 2, -4], dtype=dtype, device=device),\n    )\n    reference[rows, columns] = flow.tensor([[4, 6], [2, 3]], dtype=dtype, device=device)\n    _assert_tensor_equal(\n        test_case,\n        reference[rows, columns],\n        flow.tensor([[4, 6], [2, 3]], dtype=dtype, device=device),\n    )\n\n    # Test non-contiguous(by transpose) reference\n    # Transposed: [[0, 4, 8],\n    #              [1, 5, 9],\n    #              [2, 6, 10],\n    #              [3, 7, 11]]\n    reference = flow.tensor(\n        [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype, device=device\n    ).T\n\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([0, 1, 2]), ri([0])],\n        flow.tensor([0, 1, 2], dtype=dtype, device=device),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([0, 1, 2]), ri([1])],\n        flow.tensor([4, 5, 6], dtype=dtype, device=device),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([0]), ri([0])],\n        flow.tensor([0], dtype=dtype, device=device),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([2]), ri([1])],\n        flow.tensor([6], dtype=dtype, device=device),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[[ri([0, 0]), ri([0, 1])]],\n        flow.tensor([0, 4], dtype=dtype, device=device),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[[ri([0, 1, 1, 0, 3]), ri([1])]],\n        flow.tensor([4, 5, 5, 4, 7], dtype=dtype, device=device),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],\n        flow.tensor([0, 4, 1, 1], dtype=dtype, device=device),\n    )\n\n    rows = ri([[0, 0], [1, 2]])\n    columns = ([0],)\n    _assert_tensor_equal(\n        test_case,\n        reference[rows, columns],\n        flow.tensor([[0, 0], [1, 2]], dtype=dtype, device=device),\n    )\n\n    rows = ri([[0, 0], [1, 2]])\n    columns = ri([1, 0])\n    
_assert_tensor_equal(\n        test_case,\n        reference[rows, columns],\n        flow.tensor([[4, 0], [5, 2]], dtype=dtype, device=device),\n    )\n    rows = ri([[0, 0], [1, 3]])\n    columns = ri([[0, 1], [1, 2]])\n    _assert_tensor_equal(\n        test_case,\n        reference[rows, columns],\n        flow.tensor([[0, 4], [5, 11]], dtype=dtype, device=device),\n    )\n\n    # setting values\n    reference[ri([0]), ri([1])] = -1\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([0]), ri([1])],\n        flow.tensor([-1], dtype=dtype, device=device),\n    )\n    reference[ri([0, 1, 2]), ri([0])] = flow.tensor(\n        [-1, 2, -4], dtype=dtype, device=device\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([0, 1, 2]), ri([0])],\n        flow.tensor([-1, 2, -4], dtype=dtype, device=device),\n    )\n    reference[rows, columns] = flow.tensor([[4, 6], [2, 3]], dtype=dtype, device=device)\n    _assert_tensor_equal(\n        test_case,\n        reference[rows, columns],\n        flow.tensor([[4, 6], [2, 3]], dtype=dtype, device=device),\n    )\n\n    # Tests using less than the number of dims, and ellipsis\n    # reference is 1 2\n    #              3 4\n    #              5 6\n    reference = consec((3, 2)).to(dtype=dtype, device=device)\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([0, 2]),],\n        flow.tensor([[1, 2], [5, 6]], dtype=dtype, device=device),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[ri([1]), ...],\n        flow.tensor([[3, 4]], dtype=dtype, device=device),\n    )\n    _assert_tensor_equal(\n        test_case,\n        reference[..., ri([1])],\n        flow.tensor([[2], [4], [6]], dtype=dtype, device=device),\n    )\n\n    # verify too many indices fails\n    with test_case.assertRaises(IndexError):\n        reference[ri([1]), ri([0, 2]), ri([3])]\n\n    # test invalid index fails\n    reference = flow.empty(10, dtype=dtype, device=device)\n    for err_idx in (10, -11):\n        with test_case.assertRaisesRegex(IndexError, r\"out of range\"):\n            reference[err_idx]\n\n\ndef _test_combined_indexing(test_case, device, dtype):\n    def tensor_indices_to_np(tensor, indices):\n        # convert the flow Tensor to a numpy array\n        tensor = tensor.to(device=\"cpu\")\n        npt = tensor.numpy()\n\n        # convert indices\n        idxs = tuple(\n            i.tolist() if isinstance(i, flow.LongTensor) else i for i in indices\n        )\n\n        return npt, idxs\n\n    def get_numpy(tensor, indices):\n        npt, idxs = tensor_indices_to_np(tensor, indices)\n\n        # index and return as a flow Tensor\n        return flow.tensor(npt[idxs], dtype=dtype, device=device)\n\n    def set_numpy(tensor, indices, value):\n        if not isinstance(value, int):\n            if device != \"cpu\":\n                value = value.cpu()\n            value = value.numpy()\n\n        npt, idxs = tensor_indices_to_np(tensor, indices)\n        npt[idxs] = value\n        return npt\n\n    def assert_get_eq(tensor, indexer):\n        _assert_tensor_equal(test_case, tensor[indexer], get_numpy(tensor, indexer))\n\n    def assert_set_eq(tensor, indexer, val):\n        pyt = tensor.clone()\n        np_ref = tensor.clone()\n        pyt[indexer] = val\n        np_ref = flow.tensor(\n            set_numpy(np_ref, indexer, val), dtype=dtype, device=device\n        )\n        _assert_tensor_equal(test_case, pyt, np_ref)\n\n    def assert_backward_eq(tensor, indexer):\n        cpu = 
tensor.cpu().float().clone().detach().requires_grad_(True)\n        outcpu = cpu[indexer]\n        grad = flow.rand(outcpu.shape)\n        outcpu.backward(grad)\n        dev = cpu.to(device).detach().requires_grad_(True)\n        outdev = dev[indexer]\n        outdev.backward(grad.to(device))\n        _assert_tensor_equal(test_case, cpu.grad, dev.grad)\n\n    def get_set_tensor(indexed, indexer):\n        set_size = indexed[indexer].size()\n        set_count = indexed[indexer].numel()\n        set_tensor = flow.randperm(set_count).view(set_size).to(dtype).to(device)\n        return set_tensor\n\n    # Tensor is  0  1  2  3  4\n    #            5  6  7  8  9\n    #           10 11 12 13 14\n    #           15 16 17 18 19\n    reference = flow.arange(0.0, 20, device=device).to(dtype).view(4, 5)\n\n    indices_to_test = [\n        # grab the second, fourth columns\n        [slice(None), [1, 3]],\n        # first, third rows,\n        [[0, 2], slice(None)],\n        # TODO(wyg): only support getitem but not setitem\n        #  # weird shape\n        #  [slice(None), [[0, 1],\n        #                 [2, 3]]],\n        # negatives\n        [[-1], [0]],\n        [[0, 2], [-1]],\n        [slice(None), [-1]],\n    ]\n\n    # test getitem\n    # NOTE: the extra cases below are getitem-only; TODO: test setitem for them\n    get_indices_to_test = indices_to_test + [\n        [slice(None), [0, 1, 1, 2, 2]],\n        [slice(None), [[0, 1], [2, 3]]],\n    ]\n    for indexer in get_indices_to_test:\n        assert_get_eq(reference, indexer)\n        if device != \"cpu\":\n            assert_backward_eq(reference, indexer)\n\n    # test setitem\n    for indexer in indices_to_test:\n        assert_set_eq(reference, indexer, 44)\n        assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))\n\n    #########################\n    # test more dims tensor #\n    #########################\n    reference = flow.arange(0.0, 160, device=device).to(dtype).view(4, 8, 5)\n\n    indices_to_test = [\n        [slice(None), slice(None), [0, 3, 4]],\n        [slice(None), [2, 4, 5, 7], slice(None)],\n        [[2, 3], slice(None), slice(None)],\n        [slice(None), [0, 2, 3], [1, 3, 4]],\n        [slice(None), [0], [1, 2, 4]],\n        [slice(None), [0, 1, 3], [4]],\n        [slice(None), [[0, 1], [1, 0]], [[2, 3]]],\n        [slice(None), [[0, 1], [2, 3]], [[0]]],\n        [slice(None), [[5, 6]], [[0, 3], [4, 4]]],\n        [[0, 2, 3], [1, 3, 4], slice(None)],\n        [[0], [1, 2, 4], slice(None)],\n        [[0, 1, 3], [4], slice(None)],\n        [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],\n        [[[0, 1], [1, 0]], [[2, 3]], slice(None)],\n        [[[0, 1], [2, 3]], [[0]], slice(None)],\n        [[[2, 1]], [[0, 3], [4, 4]], slice(None)],\n        [[[2]], [[0, 3], [4, 1]], slice(None)],\n        # non-contiguous indexing subspace\n        [[0, 2, 3], slice(None), [1, 3, 4]],\n        # less dim, ellipsis\n        [[0, 2],],\n        [[0, 2], slice(None)],\n        [[0, 2], Ellipsis],\n        [[0, 2], slice(None), Ellipsis],\n        [[0, 2], Ellipsis, slice(None)],\n        [[0, 2], [1, 3]],\n        [[0, 2], [1, 3], Ellipsis],\n        [Ellipsis, [1, 3], [2, 3]],\n        [Ellipsis, [2, 3, 4]],\n        [Ellipsis, slice(None), [2, 3, 4]],\n        [slice(None), Ellipsis, [2, 3, 4]],\n        # ellipsis counts for nothing\n        [Ellipsis, slice(None), slice(None), [0, 3, 4]],\n        [slice(None), Ellipsis, slice(None), [0, 3, 4]],\n        [slice(None), slice(None), Ellipsis, [0, 3, 4]],\n        [slice(None), 
slice(None), [0, 3, 4], Ellipsis],\n        [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],\n        [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],\n        [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],\n    ]\n\n    for indexer in indices_to_test:\n        assert_get_eq(reference, indexer)\n        assert_set_eq(reference, indexer, 212)\n        assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))\n        if device != \"cpu\":\n            assert_backward_eq(reference, indexer)\n\n    reference = flow.arange(0.0, 1296, device=device).to(dtype).view(3, 9, 8, 6)\n\n    indices_to_test = [\n        [slice(None), slice(None), slice(None), [0, 3, 4]],\n        [slice(None), slice(None), [2, 4, 5, 7], slice(None)],\n        [slice(None), [2, 3], slice(None), slice(None)],\n        [[1, 2], slice(None), slice(None), slice(None)],\n        [slice(None), slice(None), [0, 2, 3], [1, 3, 4]],\n        [slice(None), slice(None), [0], [1, 2, 4]],\n        [slice(None), slice(None), [0, 1, 3], [4]],\n        [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],\n        [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],\n        [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],\n        [slice(None), [0, 2, 3], [1, 3, 4], slice(None)],\n        [slice(None), [0], [1, 2, 4], slice(None)],\n        [slice(None), [0, 1, 3], [4], slice(None)],\n        [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],\n        [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],\n        [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],\n        [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],\n        [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],\n        [[0, 1, 2], [1, 3, 4], slice(None), slice(None)],\n        [[0], [1, 2, 4], slice(None), slice(None)],\n        [[0, 1, 2], [4], slice(None), slice(None)],\n        [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],\n        [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],\n        [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],\n        [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],\n        [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],\n        [slice(None), [2, 3, 4], [1, 3, 4], [4]],\n        [slice(None), [0, 1, 3], [4], [1, 3, 4]],\n        [slice(None), [6], [0, 2, 3], [1, 3, 4]],\n        [slice(None), [2, 3, 5], [3], [4]],\n        [slice(None), [0], [4], [1, 3, 4]],\n        [slice(None), [6], [0, 2, 3], [1]],\n        [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],\n        [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],\n        [[2, 0, 1], [1, 2, 3], [4], slice(None)],\n        [[0, 1, 2], [4], [1, 3, 4], slice(None)],\n        [[0], [0, 2, 3], [1, 3, 4], slice(None)],\n        [[0, 2, 1], [3], [4], slice(None)],\n        [[0], [4], [1, 3, 4], slice(None)],\n        [[1], [0, 2, 3], [1], slice(None)],\n        [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],\n        # less dim, ellipsis\n        [Ellipsis, [0, 3, 4]],\n        [Ellipsis, slice(None), [0, 3, 4]],\n        [Ellipsis, slice(None), slice(None), [0, 3, 4]],\n        [slice(None), Ellipsis, [0, 3, 4]],\n        [slice(None), slice(None), Ellipsis, [0, 3, 4]],\n        [slice(None), [0, 2, 3], [1, 3, 4]],\n        [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],\n        [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],\n        [[0], [1, 2, 4]],\n        [[0], [1, 2, 4], slice(None)],\n        [[0], [1, 2, 4], Ellipsis],\n   
     [[0], [1, 2, 4], Ellipsis, slice(None)],\n        [[1],],\n        [[0, 2, 1], [3], [4]],\n        [[0, 2, 1], [3], [4], slice(None)],\n        [[0, 2, 1], [3], [4], Ellipsis],\n        [Ellipsis, [0, 2, 1], [3], [4]],\n    ]\n\n    for indexer in indices_to_test:\n        assert_get_eq(reference, indexer)\n        assert_set_eq(reference, indexer, 1333)\n        assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))\n    indices_to_test += [\n        [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],\n        [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],\n    ]\n    for indexer in indices_to_test:\n        assert_get_eq(reference, indexer)\n        assert_set_eq(reference, indexer, 1333)\n        if device != \"cpu\":\n            assert_backward_eq(reference, indexer)\n\n\ndef _test_single_int(test_case, device):\n    v = flow.randn(5, 7, 3, device=device)\n    test_case.assertEqual(v[4].shape, (7, 3))\n\n\ndef _test_multiple_int(test_case, device):\n    v = flow.randn(5, 7, 3, device=device)\n    test_case.assertEqual(v[4].shape, (7, 3))\n    test_case.assertEqual(v[4, :, 1].shape, (7,))\n\n\ndef _test_none(test_case, device):\n    v = flow.randn(5, 7, 3, device=device)\n    test_case.assertEqual(v[None].shape, (1, 5, 7, 3))\n    test_case.assertEqual(v[:, None].shape, (5, 1, 7, 3))\n    test_case.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3))\n    test_case.assertEqual(v[..., None].shape, (5, 7, 3, 1))\n\n\ndef _test_step(test_case, device):\n    v = flow.arange(10, device=device)\n    _assert_tensor_equal(test_case, v[::1], v)\n    test_case.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8])\n    test_case.assertEqual(v[::3].tolist(), [0, 3, 6, 9])\n    test_case.assertEqual(v[::11].tolist(), [0])\n    test_case.assertEqual(v[1:6:2].tolist(), [1, 3, 5])\n\n\ndef _test_step_assignment(test_case, device):\n    v = flow.zeros(4, 4, device=device)\n    v[0, 1::2] = flow.tensor([3.0, 4.0], device=device)\n    test_case.assertEqual(v[0].tolist(), [0.0, 3.0, 0.0, 4.0])\n    test_case.assertEqual(v[1:].sum(), 0)\n\n\ndef _test_bool_indices(test_case, device):\n    v = flow.randn(5, 7, 3, device=device)\n    boolIndices = flow.tensor(\n        [True, False, True, True, False], dtype=flow.bool, device=device\n    )\n    test_case.assertEqual(v[boolIndices].shape, (3, 7, 3))\n    _assert_tensor_equal(test_case, v[boolIndices], flow.stack([v[0], v[2], v[3]]))\n\n    v = flow.tensor([True, False, True], dtype=flow.bool, device=device)\n    boolIndices = flow.tensor([True, False, False], dtype=flow.bool, device=device)\n    uint8Indices = flow.tensor([1, 0, 0], dtype=flow.uint8, device=device)\n    test_case.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)\n    test_case.assertEqual(v[boolIndices], v[uint8Indices])\n    test_case.assertEqual(\n        v[boolIndices], flow.tensor([True], dtype=flow.bool, device=device)\n    )\n\n\ndef _test_multiple_bool_indices(test_case, device):\n    v = flow.randn(5, 7, 3, device=device)\n    # NOTE: these broadcast together and are transposed to the first dim\n    mask1 = flow.tensor([1, 0, 1, 1, 0], dtype=flow.bool, device=device)\n    mask2 = flow.tensor([1, 1, 1], dtype=flow.bool, device=device)\n    test_case.assertEqual(v[mask1, :, mask2].shape, (3, 7))\n\n\ndef _test_int_indices(test_case, device):\n    v = flow.randn(5, 7, 3, device=device)\n    test_case.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3))\n    test_case.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))\n    test_case.assertEqual(v[:, [[0, 1], [4, 
3]]].shape, (5, 2, 2, 3))\n\n\ndef _test_int_indices2d(test_case, device):\n    x = flow.arange(0, 12, device=device).view(4, 3)\n    rows = flow.tensor([[0, 0], [3, 3]], device=device)\n    columns = flow.tensor([[0, 2], [0, 2]], device=device)\n    test_case.assertEqual(x[rows, columns].tolist(), [[0, 2], [9, 11]])\n\n\ndef _test_int_indices_broadcast(test_case, device):\n    x = flow.arange(0, 12, device=device).view(4, 3)\n    rows = flow.tensor([0, 3], device=device)\n    columns = flow.tensor([0, 2], device=device)\n    result = x[rows[:, None], columns]\n    test_case.assertEqual(result.tolist(), [[0, 2], [9, 11]])\n\n\ndef _test_empty_index(test_case, device):\n    x = flow.arange(0, 12, device=device).view(4, 3)\n    idx = flow.tensor([], dtype=flow.long, device=device)\n    test_case.assertEqual(x[idx].numel(), 0)\n\n    # empty assignment should have no effect but not throw an exception\n    y = x.clone()\n    y[idx] = -1\n    _assert_tensor_equal(test_case, x, y)\n\n    mask = flow.zeros(4, 3, device=device).to(flow.bool)\n    y[mask] = -1\n    _assert_tensor_equal(test_case, x, y)\n\n\ndef _test_empty_ndim_index(test_case, device):\n    x = flow.randn(5, device=device)\n    _assert_tensor_equal(\n        test_case,\n        flow.empty(0, 2, device=device),\n        x[flow.empty(0, 2, dtype=flow.int64, device=device)],\n    )\n\n    x = flow.randn(2, 3, 4, 5, device=device)\n    _assert_tensor_equal(\n        test_case,\n        flow.empty(2, 0, 6, 4, 5, device=device),\n        x[:, flow.empty(0, 6, dtype=flow.int64, device=device)],\n    )\n\n    x = flow.empty(10, 0, device=device)\n    test_case.assertEqual(x[[1, 2]].shape, (2, 0))\n    test_case.assertEqual(x[[], []].shape, (0,))\n    test_case.assertEqual(x[[[]]].shape, (0, 0))\n    test_case.assertEqual(x[[[[]]]].shape, (1, 0, 0))\n    test_case.assertEqual(x[[1], []].shape, (0,))\n    test_case.assertEqual(x[[], [2]].shape, (0,))\n    with test_case.assertRaisesRegex(IndexError, \"for dimension with size 0\"):\n        x[:, [0, 1]]\n\n\ndef _test_empty_ndim_index_bool(test_case, device):\n    x = flow.randn(5, device=device)\n    test_case.assertRaises(\n        IndexError, lambda: x[flow.empty(0, 2, dtype=flow.uint8, device=device)]\n    )\n\n\ndef _test_empty_slice(test_case, device):\n    x = flow.randn(2, 3, 4, 5, device=device)\n    y = x[:, :, :, 1]\n    z = y[:, 1:1, :]\n    test_case.assertEqual((2, 0, 4), z.shape)\n    # this isn't technically necessary, but matches NumPy stride calculations.\n    test_case.assertEqual((60, 20, 5), z.stride())\n    test_case.assertTrue(z.is_contiguous())\n\n\ndef _test_index_getitem_copy_bools_slices(test_case, device):\n    true = flow.tensor(1, dtype=flow.uint8, device=device)\n    false = flow.tensor(0, dtype=flow.uint8, device=device)\n\n    tensors = [flow.randn(2, 3, device=device), flow.tensor([1.0], device=device)]\n\n    # TODO: compare tensor_storage after exporting the interface\n    for a in tensors:\n        #  test_case.assertNotEqual(a.data_ptr(), a[True].data_ptr())\n        _assert_tensor_equal(test_case, flow.empty(0, *a.shape), a[False])\n        #  test_case.assertNotEqual(a.data_ptr(), a[true].data_ptr())\n        _assert_tensor_equal(test_case, flow.empty(0, *a.shape), a[false])\n        #  test_case.assertEqual(a.data_ptr(), a[None].data_ptr())\n        #  test_case.assertEqual(a.data_ptr(), a[...].data_ptr())\n\n\ndef _test_setitem_scalars(test_case, device):\n    zero = flow.tensor(0, dtype=flow.int64)\n\n    # non-scalar indexed with scalars\n    a = 
flow.randn(2, 3, device=device)\n    a_set_with_number = a.clone()\n    a_set_with_scalar = a.clone()\n    b = flow.randn(3, device=device)\n\n    a_set_with_number[0] = b\n    a_set_with_scalar[zero] = b\n    _assert_tensor_equal(test_case, a_set_with_number, a_set_with_scalar)\n    a[1, zero] = 7.7\n    value = a[1, 0].numpy()\n    test_case.assertEqual(np.array(7.7, dtype=value.dtype), value)\n\n    np_x = np.random.rand(2, 3)\n    np_x[0, 0] = 1.0\n    x = flow.tensor(np_x)\n    x[0, 0] = 1.0\n    test_case.assertTrue(np.array_equal(x.numpy(), np_x))\n\n    # scalar indexed with scalars\n    r = flow.tensor(1.0).to(device)\n    with test_case.assertRaises(IndexError):\n        r[:] = 8.8\n    with test_case.assertRaises(IndexError):\n        r[zero] = 8.8\n    r[...] = 9.9\n    test_case.assertEqual(r, 9.9)\n\n    # scalar indexed with oneflow.Size([1])\n    np_x = np.random.rand(2, 3)\n    np_x[0, 0] = np.ones(1)\n    x = flow.tensor(np_x)\n    x[0, 0] = flow.ones(1).to(flow.float64)\n    test_case.assertTrue(np.array_equal(x.numpy(), np_x))\n\n\ndef _test_basic_advanced_combined(test_case, device):\n    x = flow.arange(0, 12, device=device).view(4, 3)\n    _assert_tensor_equal(test_case, x[1:2, 1:3], x[1:2, [1, 2]])\n    test_case.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]])\n\n    # Check that it is a copy\n    unmodified = x.clone()\n    x[1:2, [1, 2]].zero_()\n    _assert_tensor_equal(test_case, x, unmodified)\n\n    # But assignment should modify the original\n    unmodified = x.clone()\n    x[1:2, [1, 2]] = 0\n    test_case.assertFalse(np.array_equal(x.numpy(), unmodified.numpy()))\n\n\ndef _test_ellipsis_tensor(test_case, device):\n    x = flow.arange(0, 9, device=device).view(3, 3)\n    idx = flow.tensor([0, 2], device=device)\n    test_case.assertEqual(x[..., idx].tolist(), [[0, 2], [3, 5], [6, 8]])\n    test_case.assertEqual(x[idx, ...].tolist(), [[0, 1, 2], [6, 7, 8]])\n\n    # Test scalar ellipsis getitem\n    y = flow.tensor(1.0).to(device)\n    x_scalar = flow.tensor(9.9)\n    y = x_scalar[...]\n    test_case.assertEqual(y, 9.9)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestIndexing(flow.unittest.TestCase):\n    def test_slice(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cpu\", \"cuda\"]\n        for arg in GenArgDict(arg_dict):\n            dtype_list = [flow.float32, flow.float16]\n            from oneflow import sysconfig\n\n            if sysconfig.get_cuda_version() >= 11000:\n                dtype_list.append(flow.bfloat16)\n\n            for dtype in dtype_list:\n                _test_basic_slice(test_case, **arg, dtype=dtype)\n                _test_advanced_indexing(test_case, **arg, dtype=dtype)\n                _test_combined_indexing(test_case, **arg, dtype=dtype)\n            _test_single_int(test_case, **arg)\n            _test_multiple_int(test_case, **arg)\n            _test_none(test_case, **arg)\n            _test_step(test_case, **arg)\n            _test_step_assignment(test_case, **arg)\n            _test_bool_indices(test_case, **arg)\n            _test_multiple_bool_indices(test_case, **arg)\n            _test_int_indices(test_case, **arg)\n            _test_int_indices2d(test_case, **arg)\n            _test_int_indices_broadcast(test_case, **arg)\n            _test_empty_index(test_case, **arg)\n            _test_empty_ndim_index(test_case, **arg)\n            _test_empty_ndim_index_bool(test_case, **arg)\n            _test_empty_slice(test_case, **arg)\n            _test_index_getitem_copy_bools_slices(test_case, 
**arg)\n            _test_setitem_scalars(test_case, **arg)\n            _test_basic_advanced_combined(test_case, **arg)\n            _test_ellipsis_tensor(test_case, **arg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
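  {
    "path": "python/oneflow/test/tensor/example_bool_indices_sketch.py",
    "content": "\"\"\"\nIllustrative sketch only: this file and its path are hypothetical additions,\nnot part of the original OneFlow test suite. It shows the boolean-indexing\nbehavior exercised by _test_bool_indices above: a 1-D bool mask on the first\ndimension selects whole rows, so the result stacks the rows whose mask entry\nis True.\n\"\"\"\nimport numpy as np\n\nimport oneflow as flow\n\nv = flow.randn(5, 7, 3)\nmask = flow.tensor([True, False, True, True, False], dtype=flow.bool)\n\n# Three mask entries are True, so the result has shape (3, 7, 3) and equals\n# the stack of rows 0, 2 and 3.\nselected = v[mask]\nassert selected.shape == (3, 7, 3)\nassert np.allclose(\n    selected.numpy(), np.stack([v.numpy()[0], v.numpy()[2], v.numpy()[3]])\n)\n"
  },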
  {
    "path": "python/oneflow/test/tensor/test_tensor_is_view.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport random\nimport numpy as np\nfrom collections import OrderedDict\n\nimport oneflow as flow\n\nimport oneflow.unittest\nfrom oneflow.test_utils.test_util import GenArgList\n\n\ndef _test_is_view(test_case, device):\n    shape = (2, 3, 4, 5)\n    xx = flow.randn(shape, device=device)\n    yy = xx.reshape(4, 5, 6)\n    test_case.assertEqual(xx.is_contiguous(), yy.is_contiguous())\n    test_case.assertEqual(yy.is_view(), True)\n    test_case.assertEqual(xx.is_view(), False)\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTensorIsView(flow.unittest.TestCase):\n    def test_is_view(test_case):\n        arg_dict = OrderedDict()\n        arg_dict[\"device\"] = [\"cuda\", \"cpu\"]\n        for arg in GenArgList(arg_dict):\n            _test_is_view(test_case, *arg[0:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
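  {
    "path": "python/oneflow/test/tensor/example_view_aliasing_sketch.py",
    "content": "\"\"\"\nIllustrative sketch only: this file and its path are hypothetical additions,\nnot part of the original OneFlow test suite. It combines two behaviors tested\nin this directory: a strided basic slice is a view that aliases the base\ntensor's storage (in-place writes through the slice are visible in the base,\nas in the non-contiguous init tests), and reshape of a contiguous tensor\nreports is_view() == True, as in test_tensor_is_view.py.\n\"\"\"\nimport numpy as np\n\nimport oneflow as flow\n\nx = flow.zeros(8, 8)\nsliced = x[::2, 1::2]  # strided basic slice -> a (4, 4) view of x\nflow.nn.init.ones_(sliced)  # in-place write through the view\n\n# Only the sliced positions of the base tensor were written.\nref = np.zeros((8, 8), dtype=np.float32)\nref[::2, 1::2] = 1.0\nassert np.array_equal(x.numpy(), ref)\n\n# reshape of a contiguous tensor is reported as a view; the base is not.\nassert x.reshape(4, 16).is_view()\nassert not x.is_view()\n"
  },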
  {
    "path": "python/oneflow/test/tensor/test_tensor_part_1.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport copy\nimport os\nimport numpy as np\nimport unittest\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestTensor(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n1d()\n    def test_numpy_and_default_dtype(test_case):\n        shape = (2, 3, 4, 5)\n        tensor = flow.Tensor(*shape)\n        flow.nn.init.ones_(tensor)\n        test_case.assertTrue(tensor.dtype == flow.float32)\n        test_case.assertTrue(\n            np.allclose(tensor.numpy(), np.ones(shape, dtype=np.float32))\n        )\n\n        shape = flow.Size((2, 3, 4, 5))\n        tensor = flow.Tensor(shape)\n        flow.nn.init.ones_(tensor)\n        test_case.assertTrue(tensor.dtype == flow.float32)\n        test_case.assertTrue(\n            np.allclose(tensor.numpy(), np.ones(shape, dtype=np.float32))\n        )\n\n        shape = flow.Size((2, 3))\n        tensor = flow.Tensor(shape)\n        flow.nn.init.eye_(tensor)\n        test_case.assertTrue(tensor.dtype == flow.float32)\n        test_case.assertTrue(np.allclose(tensor.numpy(), np.eye(2, 3)))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_deepcopy(test_case):\n        shape = (2, 3)\n        tensor1 = flow.ones(*shape).cuda()\n        tensor2 = copy.deepcopy(tensor1)\n        tensor1[0, 0] = 0\n        test_case.assertEqual(tensor1.device, tensor2.device)\n        test_case.assertEqual(tensor1[0, 0], 0)\n        test_case.assertEqual(tensor2[0, 0], 1)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_property(test_case):\n        shape = (2, 3, 4, 5)\n        tensor = flow.Tensor(*shape)\n        test_case.assertEqual(tensor.storage_offset(), 0)\n        test_case.assertEqual(tensor.stride(), (60, 20, 5, 1))\n        test_case.assertEqual(tensor.is_cuda, False)\n        test_case.assertTrue(tensor.is_contiguous())\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_copy_to_and_from_numpy(test_case):\n        np_arr = np.array([4, 6], dtype=np.float32)\n        tensor = flow.tensor(np_arr, dtype=flow.float32)\n        test_case.assertTrue(np.allclose(tensor.numpy(), np_arr))\n        test_case.assertEqual(np.float32, tensor.numpy().dtype)\n        np_arr = np.array([4, 6], dtype=np.int32)\n        tensor = flow.tensor(np_arr, dtype=flow.int32)\n        test_case.assertTrue(np.allclose(tensor.numpy(), np_arr))\n        test_case.assertEqual(np.int32, tensor.numpy().dtype)\n        np_arr = np.array([4, 6], dtype=np.float16)\n        tensor = flow.tensor(np_arr, dtype=flow.float16)\n        test_case.assertTrue(np.allclose(tensor.numpy(), np_arr))\n        test_case.assertEqual(np.float16, tensor.numpy().dtype)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_inplace_copy_from_contiguous_numpy(test_case):\n        
np_arr = np.arange(6).reshape(3, 2)\n        tensor = flow.zeros(3, 2).to(flow.int64)\n        tensor.copy_(np_arr)\n        test_case.assertTrue(np.allclose(tensor.numpy(), np_arr))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_inplace_copy_from_non_contiguous_numpy(test_case):\n        np_arr = np.arange(6).reshape(2, 3).transpose(1, 0)\n        tensor = flow.zeros(3, 2).to(flow.int64)\n        tensor.copy_(np_arr)\n        test_case.assertTrue(np.allclose(tensor.numpy(), np_arr))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_construct_from_numpy_or_list(test_case):\n        shape = (2, 3, 4, 5)\n        np_arr = np.random.rand(*shape).astype(np.float32)\n        tensor = flow.tensor(np_arr)\n        test_case.assertTrue(np.allclose(tensor.numpy(), np_arr))\n        np_int_arr = np.random.randint(-100, high=100, size=shape, dtype=np.int32)\n        tensor = flow.tensor(np_int_arr, dtype=flow.int32)\n        test_case.assertEqual(tensor.dtype, flow.int32)\n        test_case.assertTrue(np_arr.flags[\"C_CONTIGUOUS\"])\n        test_case.assertTrue(np.allclose(tensor.numpy(), np_int_arr))\n        np_arr = np.random.random((1, 256, 256, 3)).astype(np.float32)\n        np_arr = np_arr.transpose(0, 3, 1, 2)\n        tensor = flow.tensor(np_arr)\n        test_case.assertFalse(np_arr.flags[\"C_CONTIGUOUS\"])\n        test_case.assertTrue(np.allclose(tensor.numpy(), np_arr))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_construct_from_another_tensor(test_case):\n        shape = (2, 3, 4, 5)\n        np_arr = np.random.rand(*shape).astype(np.float32)\n        tensor = flow.tensor(np_arr)\n        output = flow.tensor(tensor)\n        test_case.assertEqual(output.dtype, flow.float32)\n        test_case.assertTrue(np.allclose(output.numpy(), np_arr))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_construct_np_array_from_tensor(test_case):\n        tensor = flow.randn(5)\n        np_arr = np.array(tensor)\n        test_case.assertEqual(np_arr.shape, (5,))\n        test_case.assertEqual(np_arr.dtype, np.float32)\n        test_case.assertTrue(np.allclose(np_arr, tensor.numpy()))\n        test_case.assertEqual(str(np_arr), str(tensor.numpy()))\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_tensor_sign_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.sign()\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_flow_tensor_gather_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=4, dim1=3, dim2=4, dim3=5).to(device)\n        dim = random(0, 4).to(int).value()\n        index = random_tensor(\n            ndim=4,\n            dim1=random(1, 3).to(int),\n            dim2=random(1, 4).to(int),\n            dim3=random(1, 5).to(int),\n            low=0,\n            high=1 if dim == 0 else dim,\n            dtype=int,\n        ).to(device)\n        return input.gather(dim, index)\n\n    def _test_tensor_init_methods(test_case, tensor_creator, get_numpy):\n        for dtype in [flow.float32, flow.float16]:\n            shape = (2, 3, 4, 5)\n            x = tensor_creator(*shape).to(dtype)\n            np_ones = np.ones(x.shape)\n            np_zeros = np.zeros(x.shape)\n            random_fill_val = 2.0\n            x.fill_(random_fill_val)\n            test_case.assertTrue(np.allclose(get_numpy(x), random_fill_val * np_ones))\n            flow.nn.init.ones_(x)\n            
test_case.assertTrue(np.allclose(get_numpy(x), np_ones))\n            flow.nn.init.zeros_(x)\n            test_case.assertTrue(np.allclose(get_numpy(x), np_zeros))\n            flow.nn.init.constant_(x, random_fill_val)\n            test_case.assertTrue(np.allclose(get_numpy(x), random_fill_val * np_ones))\n            z = tensor_creator(5, 4, 3, 2)\n            flow.nn.init.kaiming_normal_(z, a=0.1, mode=\"fan_out\", nonlinearity=\"relu\")\n            flow.nn.init.kaiming_uniform_(z)\n            z.requires_grad_()\n            flow.nn.init.xavier_normal_(z, flow.nn.init.calculate_gain(\"relu\"))\n            flow.nn.init.xavier_uniform_(z, flow.nn.init.calculate_gain(\"relu\"))\n            flow.nn.init.xavier_normal_(\n                z, flow.nn.init.calculate_gain(\"leaky_relu\", 0.2)\n            )\n            flow.nn.init.xavier_uniform_(\n                z, flow.nn.init.calculate_gain(\"leaky_relu\", 0.2)\n            )\n            flow.nn.init.trunc_normal_(z, mean=0.0, std=1.0, a=-2.0, b=2.0)\n            flow.nn.init.normal_(z, mean=0.0, std=1.0)\n            flow.nn.init.orthogonal_(z)\n\n        x = tensor_creator(*shape).to(dtype=flow.int32)\n        np_ones = np.ones(x.shape, dtype=np.int32)\n        np_zeros = np.zeros(x.shape, dtype=np.int32)\n        random_fill_val = -2\n        x.fill_(random_fill_val)\n        test_case.assertTrue(np.allclose(get_numpy(x), random_fill_val * np_ones))\n        flow.nn.init.ones_(x)\n        test_case.assertTrue(np.allclose(get_numpy(x), np_ones))\n        flow.nn.init.zeros_(x)\n        test_case.assertTrue(np.allclose(get_numpy(x), np_zeros))\n        flow.nn.init.constant_(x, random_fill_val)\n        test_case.assertTrue(np.allclose(get_numpy(x), random_fill_val * np_ones))\n        x.zero_()\n        test_case.assertTrue(np.array_equal(get_numpy(x), np_zeros))\n        test_case.assertEqual(flow.nn.init.calculate_gain(\"conv2d\"), 1)\n        test_case.assertEqual(flow.nn.init.calculate_gain(\"tanh\"), 5.0 / 3)\n\n    def _test_non_contiguous_tensor_init_methods(test_case, tensor_creator, get_numpy):\n        shape = (8, 8)\n        x = flow.zeros(shape)\n        sliced_x = x[::2, 1::2]\n        not_sliced_x = x[1::2, ::2]\n        random_fill_val = 923.53\n        np_zeros = np.zeros((4, 4))\n        # ones\n        flow.nn.init.ones_(sliced_x)\n        test_case.assertTrue(np.allclose(get_numpy(sliced_x), np.ones((4, 4))))\n        test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros))\n        # constant\n        flow.nn.init.constant_(sliced_x, random_fill_val)\n        test_case.assertTrue(\n            np.allclose(get_numpy(sliced_x), np.ones((4, 4)) * random_fill_val)\n        )\n        test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros))\n        # eye\n        flow.nn.init.eye_(sliced_x)\n        test_case.assertTrue(np.allclose(get_numpy(sliced_x), np.eye(4)))\n        test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros))\n        # kaiming_normal_\n        flow.nn.init.kaiming_normal_(\n            sliced_x, a=0.1, mode=\"fan_out\", nonlinearity=\"relu\"\n        )\n        test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros))\n        # kaiming_uniform_\n        flow.nn.init.kaiming_uniform_(sliced_x)\n        test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros))\n        # xavier_normal_ with relu gain\n        flow.nn.init.xavier_normal_(sliced_x, flow.nn.init.calculate_gain(\"relu\"))\n        
test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros))\n        # xavier_uniform_ with relu gain\n        flow.nn.init.xavier_uniform_(sliced_x, flow.nn.init.calculate_gain(\"relu\"))\n        test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros))\n        # trunc_normal_\n        flow.nn.init.trunc_normal_(sliced_x, mean=0.0, std=1.0, a=-2.0, b=2.0)\n        test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros))\n        # normal_\n        flow.nn.init.normal_(sliced_x, mean=0.0, std=1.0)\n        test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros))\n        # orthogonal_\n        flow.nn.init.orthogonal_(sliced_x)\n        test_case.assertTrue(np.allclose(get_numpy(not_sliced_x), np_zeros))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_local_tensor_init_methods(test_case):\n        for device in [\"cpu\", \"cuda\"]:\n            test_case._test_tensor_init_methods(\n                lambda *args, **kwargs: flow.Tensor(*args, **kwargs, device=device),\n                lambda x: x.numpy(),\n            )\n            test_case._test_non_contiguous_tensor_init_methods(\n                lambda *args, **kwargs: flow.Tensor(*args, **kwargs, device=device),\n                lambda x: x.numpy(),\n            )\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_global_tensor_init_methods(test_case):\n        for device in [\"cpu\", \"cuda\"]:\n            test_case._test_tensor_init_methods(\n                lambda *args, **kwargs: flow.Tensor(\n                    *args,\n                    **kwargs,\n                    sbp=flow.sbp.broadcast,\n                    placement=flow.placement(device, range(2))\n                ),\n                lambda x: x.to_global(sbp=flow.sbp.broadcast).to_local().numpy(),\n            )\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_with_single_int(test_case):\n        x = flow.Tensor(5)\n        test_case.assertEqual(x.shape, flow.Size([5]))\n        x = flow.tensor(5)\n        test_case.assertEqual(x.numpy().item(), 5)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_device(test_case):\n        shape = (2, 3, 4, 5)\n        x = flow.Tensor(*shape)\n        test_case.assertTrue(not x.is_cuda)\n        x = flow.Tensor(*shape, device=flow.device(\"cuda\"))\n        test_case.assertTrue(x.is_cuda)\n        x = flow.Tensor(*shape, device=flow.device(\"cpu\"))\n        test_case.assertTrue(not x.is_cuda)\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=1, check_graph=True)\n    def test_tensor_set_data_autograd_meta(test_case):\n        x = torch.ones(2, 3).requires_grad_(True)\n        y = x + x\n        z = torch.zeros(2, 3)\n        z.data = y\n        return z.grad_fn, z.is_leaf\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_set_data(test_case):\n        a = flow.ones(2, 3, requires_grad=False)\n        b = flow.ones(4, 5, requires_grad=True).to(\"cuda\")\n        old_id = id(a)\n        a.data = b\n        test_case.assertEqual(old_id, id(a))\n        test_case.assertTrue(a.shape == (4, 5))\n        test_case.assertTrue(a.device == flow.device(\"cuda\"))\n        test_case.assertFalse(a.requires_grad)\n        test_case.assertTrue(a.is_leaf)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_set_ref_tensor(test_case):\n        a = flow.ones(2, 3, requires_grad=False)\n        b = flow.ones(4, 5, requires_grad=True).to(\"cuda\")\n        test_case.assertEqual(a._ref_tensor, None)\n        test_case.assertEqual(a._ref_index, 
0)\n        a._ref_tensor = b\n        a._ref_index = 200\n        test_case.assertEqual(id(a._ref_tensor), id(b))\n        test_case.assertTrue(a._ref_tensor.shape == (4, 5))\n        test_case.assertTrue(a._ref_tensor.device == flow.device(\"cuda\"))\n        test_case.assertTrue(a._ref_tensor.requires_grad)\n        test_case.assertEqual(a._ref_index, 200)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_unsupported_property(test_case):\n\n        shape = (2, 3, 4, 5)\n        x = flow.Tensor(*shape)\n        test_case.assertTrue(x.is_local)\n\n        with test_case.assertRaises(RuntimeError):\n            x.global_id()\n\n        with test_case.assertRaises(RuntimeError):\n            x.sbp\n\n        with test_case.assertRaises(RuntimeError):\n            x.placement\n\n        if x.dtype != flow.tensor_buffer:\n            with test_case.assertRaises(RuntimeError):\n                x._tensor_buffer_shapes_and_dtypes\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_to_bool(test_case):\n        x = flow.tensor([0.0])\n        test_case.assertFalse(bool(x))\n        x = flow.tensor([0.0]).to(\"cuda\")\n        test_case.assertFalse(bool(x))\n        x = flow.tensor([1.5])\n        test_case.assertTrue(bool(x))\n        x = flow.tensor([3])\n        test_case.assertTrue(bool(x))\n        with test_case.assertRaises(RuntimeError):\n            bool(flow.tensor([1, 3, 5]))\n        with test_case.assertRaises(RuntimeError):\n            bool(flow.tensor([]))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_autograd_fill_cpu(test_case):\n        shape = (2, 3, 4, 5)\n        x = flow.Tensor(*shape)\n        y = flow.Tensor(*shape)\n        x.fill_(1.0)\n        y.fill_(flow.tensor(1.0))\n        y.requires_grad = True\n        z = x + y\n        test_case.assertFalse(x.requires_grad)\n        test_case.assertTrue(x.is_leaf)\n        test_case.assertTrue(y.requires_grad)\n        test_case.assertTrue(y.is_leaf)\n        test_case.assertTrue(z.requires_grad)\n        test_case.assertFalse(z.is_leaf)\n        with flow.no_grad():\n            m = x + y\n        test_case.assertTrue(m.is_leaf)\n        test_case.assertFalse(m.requires_grad)\n        m.requires_grad = True\n        v = flow.Tensor(*shape)\n        v.requires_grad = True\n        z.retain_grad()\n        w = v + z\n        grad = flow.Tensor(*shape)\n        grad.fill_(1.0)\n        w.backward(gradient=grad, retain_graph=True)\n        test_case.assertTrue(\n            np.allclose(v.grad.numpy(), np.ones(shape), atol=1e-4, rtol=1e-4)\n        )\n        test_case.assertTrue(\n            np.allclose(y.grad.numpy(), np.ones(shape), atol=1e-4, rtol=1e-4)\n        )\n        test_case.assertTrue(\n            np.allclose(z.grad.numpy(), np.ones(shape), atol=1e-4, rtol=1e-4)\n        )\n        test_case.assertIsNone(x.grad)\n        test_case.assertIsNotNone(y.grad)\n        w.backward(gradient=grad, retain_graph=True)\n        # autocast test for fill_\n        x = flow.tensor([2.4, 3.5], device=\"cuda\", dtype=flow.float16)\n        with flow.amp.autocast(\"cuda\", flow.float16):\n            y = x.clone()\n            y.fill_(2.36)\n            test_case.assertTrue(y.dtype == flow.float16)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_autograd_fill_cuda(test_case):\n        shape = (2, 3, 4, 5)\n        x = flow.Tensor(*shape).to(\"cuda:0\")\n        y = flow.Tensor(*shape).to(\"cuda:0\")\n        x.fill_(1.0)\n        y.fill_(flow.tensor(1.0).to(\"cuda:0\"))\n        y.requires_grad = True\n        z = x + y\n        
test_case.assertFalse(x.requires_grad)\n        test_case.assertTrue(x.is_leaf)\n        test_case.assertTrue(y.requires_grad)\n        test_case.assertTrue(y.is_leaf)\n        test_case.assertTrue(z.requires_grad)\n        test_case.assertFalse(z.is_leaf)\n        with flow.no_grad():\n            m = x + y\n        test_case.assertTrue(m.is_leaf)\n        test_case.assertFalse(m.requires_grad)\n        m.requires_grad = True\n        v = flow.Tensor(*shape).to(\"cuda:0\")\n        v.requires_grad = True\n        z.retain_grad()\n        w = v + z\n        grad = flow.Tensor(*shape)\n        grad.fill_(1.0)\n        w.backward(gradient=grad, retain_graph=True)\n        test_case.assertTrue(\n            np.allclose(v.grad.numpy(), np.ones(shape), atol=1e-4, rtol=1e-4)\n        )\n        test_case.assertTrue(\n            np.allclose(y.grad.numpy(), np.ones(shape), atol=1e-4, rtol=1e-4)\n        )\n        test_case.assertTrue(\n            np.allclose(z.grad.numpy(), np.ones(shape), atol=1e-4, rtol=1e-4)\n        )\n        test_case.assertIsNone(x.grad)\n        test_case.assertIsNotNone(y.grad)\n        w.backward(gradient=grad, retain_graph=True)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_register_post_grad_accumulation_hook(test_case):\n        shape = (2, 3)\n        x = flow.Tensor(*shape)\n        x.requires_grad = True\n        x._register_post_grad_accumulation_hook(lambda grad: grad * 2 + 1)\n        y = x.sum() + (x * 2).sum()\n        y.backward()\n        test_case.assertTrue(\n            np.allclose(x.grad.numpy(), np.ones(shape) * 7, atol=1e-4, rtol=1e-4)\n        )\n\n        x = flow.Tensor(*shape)\n        x.requires_grad = True\n\n        def inplace_add_and_return_none(x):\n            x.add_(1)\n            return None\n\n        x._register_post_grad_accumulation_hook(inplace_add_and_return_none)\n        y = x.sum() + (x * 2).sum()\n        y.backward()\n        test_case.assertTrue(\n            np.allclose(x.grad.numpy(), np.ones(shape) * 4, atol=1e-4, rtol=1e-4)\n        )\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_register_hook(test_case):\n        shape = (2, 3)\n        x = flow.Tensor(*shape)\n        x.requires_grad = True\n        x.register_hook(lambda grad: grad * 2 + 1)\n        y = x.sum() + (x * 2).sum()\n        y.backward()\n        test_case.assertTrue(\n            np.allclose(x.grad.numpy(), np.ones(shape) * 7, atol=1e-4, rtol=1e-4)\n        )\n        x = flow.Tensor(*shape)\n        x.requires_grad = True\n        new_grad = flow.Tensor([[1, 2, 3], [4, 5, 6]])\n        x.register_hook(lambda _: new_grad)\n        y = x.sum() + (x * 2).sum()\n        y.backward()\n        test_case.assertTrue(np.allclose(x.grad.numpy(), new_grad.numpy()))\n        grad_nonlocal = None\n\n        def assign_nonlocal_variable_and_return_none(grad):\n            nonlocal grad_nonlocal\n            grad_nonlocal = grad\n\n        x = flow.Tensor(*shape)\n        x.requires_grad = True\n        new_grad = flow.tensor([[1, 2, 3], [4, 5, 6]], dtype=flow.float32)\n        x.register_hook(assign_nonlocal_variable_and_return_none)\n        y = x.sum() + (x * 2).sum()\n        y.backward()\n        test_case.assertTrue(np.allclose(grad_nonlocal.numpy(), np.ones(shape) * 3))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_non_leaf_tensor_register_hook(test_case):\n        shape = (2, 3)\n        x = flow.Tensor(*shape).requires_grad_()\n        y = x + 1\n        y.register_hook(lambda grad: grad * 2)\n        z1 = y * 2\n       
 z2 = y * 3\n        loss = (z1 + z2).sum()\n        loss.backward(retain_graph=True)\n        loss.backward()\n        test_case.assertTrue(np.allclose(x.grad.numpy(), np.ones(shape) * 20))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_user_defined_data(test_case):\n        list_data = [5, 5]\n        tuple_data = (5, 5)\n        numpy_data = np.array((5, 5))\n        x = flow.Tensor(list_data)\n        y = flow.Tensor(tuple_data)\n        z = flow.Tensor(numpy_data)\n        test_case.assertTrue(np.allclose(x.numpy(), 5 * np.ones(x.shape)))\n        test_case.assertTrue(np.allclose(y.numpy(), 5 * np.ones(y.shape)))\n        test_case.assertTrue(np.allclose(z.numpy(), 5 * np.ones(z.shape)))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_local_tensor_and_op(test_case):\n        x1 = flow.Tensor([[1.0, 2.0]])\n        test_case.assertEqual(x1.dtype, flow.float32)\n        test_case.assertEqual(x1.shape, flow.Size((1, 2)))\n        x2 = flow.Tensor([[1.0], [2.0]])\n        y = flow.matmul(x1, x2)\n        test_case.assertTrue(\n            np.allclose(y.numpy(), np.array([[5.0]], dtype=np.float32))\n        )\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5, rtol=1e-2, atol=1e-3)\n    def test_matmul_with_random_data(test_case):\n        device = random_device()\n        dim0 = random(low=2, high=10).to(int)\n        dim1 = random(low=3, high=20).to(int)\n        dim2 = random(low=2, high=11).to(int)\n        a = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device)\n        b = random_tensor(ndim=2, dim0=dim1, dim1=dim2).to(device)\n        return a @ b\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_mv_with_random_data(test_case):\n        device = random_device()\n        dim0 = random(low=2, high=10).to(int)\n        dim1 = random(low=3, high=20).to(int)\n        a = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device)\n        b = random_tensor(ndim=1, dim0=dim1).to(device)\n        return a.mv(b)\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(check_graph=True, rtol=1e-2, atol=1e-3)\n    def test_mm_with_random_data(test_case):\n        device = random_device()\n        dim0 = random(low=2, high=10).to(int)\n        dim1 = random(low=3, high=20).to(int)\n        dim2 = random(low=2, high=11).to(int)\n        a = random_tensor(ndim=2, dim0=dim0, dim1=dim1).to(device)\n        b = random_tensor(ndim=2, dim0=dim1, dim1=dim2).to(device)\n        return a.mm(b)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_to_list(test_case):\n        list_data = [[1.0, 3.0], [5.0, 6.0]]\n        input = flow.Tensor(list_data)\n        test_case.assertEqual(list_data, input.tolist())\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_nelement(test_case):\n        shape = (2, 3, 4)\n        input = flow.Tensor(*shape)\n        test_case.assertEqual(input.nelement(), 24)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_numel(test_case):\n        shape = (2, 3, 4, 5)\n        input = flow.Tensor(*shape)\n        test_case.assertEqual(input.numel(), 120)\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_print(test_case):\n        shape = (2, 3, 4, 5)\n        input = flow.Tensor(*shape)\n        input_str = str(input)\n        test_case.assertTrue(input_str.startswith(\"tensor(\"))\n        test_case.assertTrue(\"device=\" not in input_str)\n        gpu_input = flow.Tensor(*shape, device=\"cuda\")\n        gpu_input_str = str(gpu_input)\n        test_case.assertTrue(\"device=\" in 
gpu_input_str)\n        test_case.assertTrue(\"cuda:0\" in gpu_input_str)\n        requires_grad_input = flow.Tensor(*shape)\n        requires_grad_input.requires_grad = True\n        requires_grad_input_str = str(requires_grad_input)\n        test_case.assertTrue(\"requires_grad=\" in requires_grad_input_str)\n\n    @unittest.skip(\"skip for now, because it failed 2 times in past week\")\n    @flow.unittest.skip_unless_1n1d()\n    def test_indexing(test_case):\n        class SliceExtractor:\n            def __getitem__(self, key):\n                return key\n\n        se = SliceExtractor()\n\n        def compare_getitem_with_numpy(tensor, slices):\n            np_arr = tensor.numpy()\n            test_case.assertTrue(np.allclose(np_arr[slices], tensor[slices].numpy()))\n\n        def compare_setitem_with_numpy(tensor, slices, value):\n            np_arr = tensor.numpy()\n            if isinstance(value, flow.Tensor):\n                np_value = value.numpy()\n            else:\n                np_value = value\n            np_arr[slices] = np_value\n            tensor[slices] = value\n            test_case.assertTrue(np.allclose(np_arr, tensor.numpy(), rtol=1e-4))\n\n        x = flow.randn(5, 5)\n        v = flow.Tensor([[0, 1, 2, 3, 4]])\n        compare_getitem_with_numpy(x, se[-4:-1:2])\n        compare_getitem_with_numpy(x, se[-1:])\n        compare_setitem_with_numpy(x, se[-1:], v)\n        compare_setitem_with_numpy(x, se[2::2], 2)\n        x = flow.Tensor(2, 3, 4)\n        v = flow.Tensor(3)\n        compare_setitem_with_numpy(x, se[:, :, 2], v)\n        x = flow.Tensor(2, 3, 4)\n        compare_setitem_with_numpy(x, se[1, :, 2], v)\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5, auto_backward=False)\n    def test_setitem_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=0, high=0, ndim=1, dim0=16, requires_grad=False).to(\n            device\n        )\n        y = random_tensor(low=-2, high=2, ndim=1, dim0=16).to(device)\n        idx = random_tensor(\n            low=0, high=15, ndim=1, dim0=20, dtype=int, requires_grad=False\n        ).to(device)\n\n        getitem_of = y.oneflow[idx.oneflow]\n        getitem_torch = y.pytorch[idx.pytorch]\n        test_case.assertTrue(\n            np.allclose(getitem_of.numpy(), getitem_torch.detach().cpu().numpy())\n        )\n\n        x.oneflow[idx.oneflow] = getitem_of\n        x.pytorch[idx.pytorch] = getitem_torch\n        return x\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_div(test_case):\n        x = flow.Tensor(np.random.randn(1, 1))\n        y = flow.Tensor(np.random.randn(2, 3))\n        of_out = x / y\n        np_out = np.divide(x.numpy(), y.numpy())\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n        x = flow.Tensor(np.random.randn(2, 3))\n        of_out = x / 3\n        np_out = np.divide(x.numpy(), 3)\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n        x = flow.Tensor(np.random.randn(2, 3))\n        of_out = 3 / x\n        np_out = np.divide(3, x.numpy())\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n        x = flow.Tensor(np.random.randn(1))\n        of_out = 3 / x\n        np_out = np.divide(3, x.numpy())\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_mul(test_case):\n        x = flow.Tensor(np.random.randn(1, 1))\n        y = 
flow.Tensor(np.random.randn(2, 3))\n        of_out = x * y\n        np_out = np.multiply(x.numpy(), y.numpy())\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n        x = flow.Tensor(np.random.randn(2, 3))\n        of_out = x * 3\n        np_out = np.multiply(x.numpy(), 3)\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n        x = flow.Tensor(np.random.randn(2, 3))\n        of_out = 3 * x\n        np_out = np.multiply(3, x.numpy())\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_mul_inplace_tensor(test_case):\n        device = random_device()\n        rand_tensor = random_tensor(\n            low=-2, high=2, ndim=4, dim0=16, dim1=9, dim2=4, dim3=7\n        ).to(device)\n        y = rand_tensor + 1\n        x = random_tensor(low=-2, high=2, ndim=4, dim0=16, dim1=9, dim2=4, dim3=7).to(\n            device\n        )\n        y.mul_(x)\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_broadcast_mul_inplace_tensor(test_case):\n        device = random_device()\n        rand_tensor = random_tensor(ndim=3, dim0=4, dim1=8, dim2=13).to(device)\n        y = rand_tensor + 1\n        x = random_tensor(ndim=2, dim0=8, dim1=13).to(device)\n        y.mul_(x)\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_div_inplace_tensor(test_case):\n        device = random_device()\n        rand_tensor = random_tensor(\n            low=-2, high=2, ndim=4, dim0=26, dim1=7, dim2=4, dim3=17\n        ).to(device)\n        y = rand_tensor + 1\n        x = random_tensor(low=-2, high=2, ndim=4, dim0=26, dim1=7, dim2=4, dim3=17).to(\n            device\n        )\n        y.div_(x)\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_broadcast_div_inplace_tensor(test_case):\n        device = random_device()\n        rand_tensor = random_tensor(ndim=3, dim0=4, dim1=8, dim2=13).to(device)\n        y = rand_tensor + 1\n        x = random_tensor(ndim=2, dim0=8, dim1=13).to(device)\n        y.div_(x)\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_add_inplace_tensor(test_case):\n        device = random_device()\n        rand_tensor = random_tensor(\n            low=-2, high=2, ndim=4, dim0=6, dim1=9, dim2=14, dim3=17\n        ).to(device)\n        y = rand_tensor + 1\n        x = random_tensor(low=-2, high=2, ndim=4, dim0=6, dim1=9, dim2=14, dim3=17).to(\n            device\n        )\n        y.add_(x)\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_broadcast_add_inplace_tensor(test_case):\n        device = random_device()\n        rand_tensor = random_tensor(ndim=3, dim0=5, dim1=9, dim2=23).to(device)\n        y = rand_tensor + 1\n        x = random_tensor(ndim=2, dim0=9, dim1=23).to(device)\n        y.add_(x)\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_sub_inplace_tensor(test_case):\n        device = random_device()\n        rand_tensor = random_tensor(\n            low=-2, high=2, ndim=4, dim0=6, dim1=9, dim2=14, dim3=17\n        ).to(device)\n        y = rand_tensor + 1\n        x = random_tensor(low=-2, high=2, ndim=4, dim0=6, dim1=9, dim2=14, dim3=17).to(\n            device\n        )\n        y.sub_(x)\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    
@autotest(n=5)\n    def test_broadcast_sub_inplace_tensor(test_case):\n        device = random_device()\n        rand_tensor = random_tensor(ndim=3, dim0=5, dim1=9, dim2=23).to(device)\n        y = rand_tensor + 1\n        x = random_tensor(ndim=2, dim0=9, dim1=23).to(device)\n        y.sub_(x)\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_add_tensor_method(test_case):\n        x = flow.Tensor(np.random.randn(1, 1))\n        y = flow.Tensor(np.random.randn(2, 3))\n        of_out = x + y\n        np_out = np.add(x.numpy(), y.numpy())\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n        x = flow.Tensor(np.random.randn(2, 3))\n        of_out = x + 3\n        np_out = np.add(x.numpy(), 3)\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n        x = flow.Tensor(np.random.randn(2, 3))\n        of_out = 3 + x\n        np_out = np.add(3, x.numpy())\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_sub_tensor_method(test_case):\n        x = flow.Tensor(np.random.randn(1, 1))\n        y = flow.Tensor(np.random.randn(2, 3))\n        of_out = x - y\n        np_out = np.subtract(x.numpy(), y.numpy())\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n        x = flow.Tensor(np.random.randn(2, 3))\n        of_out = x - 3\n        np_out = np.subtract(x.numpy(), 3)\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n        x = flow.Tensor(np.random.randn(2, 3))\n        of_out = 3 - x\n        np_out = np.subtract(3, x.numpy())\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_sum(test_case):\n        input = flow.tensor(np.random.randn(4, 5, 6), dtype=flow.float32)\n        of_out = input.sum(dim=(2, 1))\n        np_out = np.sum(input.numpy(), axis=(2, 1))\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_argwhere(test_case):\n        shape = (2, 3, 4, 5)\n        precision = 1e-5\n        np_input = np.random.randn(*shape)\n        input = flow.Tensor(np_input)\n        of_out = input.argwhere()\n        np_out = np.argwhere(np_input)\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, precision, precision))\n        test_case.assertTrue(np.allclose(of_out.numpy().shape, np_out.shape))\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5, auto_backward=False, check_graph=True)\n    def test_tensor_argmax_with_random_data(test_case):\n        device = random_device()\n        ndim = random(1, 6).to(int)\n        x = random_tensor(ndim=ndim).to(device)\n        y = x.argmax(dim=random(0, ndim).to(int), keepdim=random().to(bool))\n        return y\n\n    @autotest(auto_backward=False, check_graph=False)\n    def test_max_bool_input_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(\n            device, dtype=torch.bool\n        )\n        return x.max(dim)\n\n    @autotest(auto_backward=False, check_graph=False)\n    def test_min_bool_input_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float, requires_grad=False).to(\n            device, 
dtype=torch.bool\n        )\n        return x.min(dim)\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_tensor_tanh_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.tanh()\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_flow_tensor_asin_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-0.5, high=0.5).to(device)\n        y = x.asin()\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_flow_tensor_arcsin_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-0.5, high=0.5).to(device)\n        y = x.arcsin()\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_flow_tensor_asinh_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.asinh()\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_flow_tensor_arcsinh_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.arcsinh()\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_flow_tensor_sinh_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.sinh()\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_flow_tensor_atan2_with_random_data(test_case):\n        device = random_device()\n        x1 = random_tensor(ndim=1, dim0=1).to(device)\n        x2 = random_tensor(ndim=1, dim0=1).to(device)\n        y = x1.atan2(x2)\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_dot(test_case):\n        device = random_device()\n        k = random(10, 100)\n        x = random_tensor(ndim=1, dim0=k).to(device)\n        y = random_tensor(ndim=1, dim0=k).to(device)\n        z = x.dot(y)\n        return z\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_arccos_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=2, high=3).to(device)\n        y = x.arccos()\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_arccosh_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=2, high=3).to(device)\n        y = x.arccosh()\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_acosh_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=2, high=3).to(device)\n        y = x.acosh()\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(auto_backward=False, check_graph=True)\n    def test_sort_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = x.sort(dim=random(low=-4, high=4).to(int), descending=random_bool())\n        return y[0], y[1]\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(auto_backward=False, check_graph=True)\n    def test_sort_tensor_return_type(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        result = x.sort(dim=random(low=-4, high=4).to(int), descending=random_bool())\n        return result.values, 
result.indices\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(auto_backward=False, check_graph=True)\n    def test_argsort_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = x.argsort(dim=random(low=-4, high=4).to(int), descending=random_bool())\n        return y\n\n    @autotest(n=5)\n    def test_mean_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dtype=float).to(device)\n        return x.mean(dim)\n\n    @autotest(n=5)\n    def test_log_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return x.log()\n\n    @autotest(n=5)\n    def test_log1p_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return x.log1p()\n\n    @autotest(n=5)\n    def test_log2_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return x.log2()\n\n    @autotest(n=5)\n    def test_log10_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return x.log10()\n\n    @autotest(n=5)\n    def test_neg_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return -x\n\n    @autotest(n=5)\n    def test_negative_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return x.negative()\n\n    @autotest(n=5)\n    def test_neg_method_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return x.neg()\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_greater_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3, dim1=2, dim2=3).to(device)\n        y = random_tensor(ndim=3, dim1=2, dim2=3).to(device)\n        return x.gt(y)\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_less_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3, dim1=2, dim2=3).to(device)\n        y = random_tensor(ndim=3, dim1=2, dim2=3).to(device)\n        return x.lt(y)\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_tensor_topk_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dim1=8, dim2=9, dim3=10).to(device)\n        y = x.topk(\n            random(low=1, high=8).to(int),\n            dim=random(low=1, high=4).to(int) | nothing(),\n            largest=random_bool() | nothing(),\n            sorted=constant(True) | nothing(),\n        )\n        return y[0], y[1]\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_tensor_topk_return_type(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4, dim1=8, dim2=9, dim3=10).to(device)\n        result = x.topk(\n            random(low=1, high=8).to(int),\n            dim=random(low=1, high=4).to(int),\n            largest=random_bool(),\n            sorted=constant(True),\n        )\n        return result.values, result.indices\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_flow_fmod_element_with_random_data(test_case):\n        device = random_device()\n        dim1 = random().to(int)\n        dim2 = random().to(int)\n    
    input = random_tensor(ndim=3, dim1=dim1, dim2=dim2).to(device)\n        other = random_tensor(ndim=3, dim1=dim1, dim2=dim2).to(device)\n        return input.fmod(other)\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_flow_fmod_broadcast_with_random_data(test_case):\n        device = random_device()\n        dim1 = random().to(int)\n        dim2 = random().to(int)\n        input = random_tensor(ndim=3, dim1=constant(1), dim2=dim2).to(device)\n        other = random_tensor(ndim=3, dim1=dim1, dim2=constant(1)).to(device)\n        return input.fmod(other)\n\n    @autotest(auto_backward=True, check_graph=True)\n    def test_flow_fmod_scalar_with_random_data(test_case):\n        device = random_device()\n        dim1 = random().to(int)\n        dim2 = random().to(int)\n        input = random_tensor(ndim=3, dim1=dim1, dim2=dim2).to(device)\n        other = 3\n        return input.fmod(other)\n\n    @autotest(auto_backward=False, check_graph=True)\n    def test_fmod_with_0_size_data(test_case):\n        device = random_device()\n        x = random_tensor(4, 2, 1, 0, 3).to(device)\n        y = x.fmod(2)\n        return y\n\n    @autotest(n=5)\n    def test_tensor_flip_list_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int)\n        ).to(device)\n        y = x.flip(constant([0, 1, 2]))\n        return y\n\n    @autotest(n=5)\n    def test_tensor_flip_tuple_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int)\n        ).to(device)\n        y = x.flip(constant((0, 1, 2)))\n        return y\n\n    @autotest(n=5)\n    def test_tensor_chunk_list_with_random_data(test_case):\n        device = random_device()\n        dim = random(1, 4).to(int)\n        x = random_tensor(\n            ndim=4,\n            dim1=random(low=4, high=8).to(int),\n            dim2=random(low=4, high=8).to(int),\n            dim3=random(low=4, high=8).to(int),\n        ).to(device)\n        y = x.chunk(chunks=random(low=1, high=5).to(int), dim=dim)\n        z = torch.cat(y, dim=dim)\n        return z\n\n    @autotest(n=5)\n    def test_tensor_reciprocal_list_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int)\n        ).to(device)\n        y = x.reciprocal()\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_tensor_slice(test_case):\n        x = np.random.randn(2, 3, 4, 5).astype(np.float32)\n        input = flow.tensor(x)\n        test_case.assertTrue(np.allclose(input[0].numpy(), x[0], 1e-05, 1e-05))\n        test_case.assertTrue(np.allclose(input[1].numpy(), x[1], 1e-05, 1e-05))\n        test_case.assertTrue(np.allclose(input[0, :].numpy(), x[0, :], 1e-05, 1e-05))\n        test_case.assertTrue(\n            np.allclose(input[0, :, 0:2].numpy(), x[0, :, 0:2], 1e-05, 1e-05)\n        )\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_zeros_(test_case):\n        shape = (2, 3)\n        x = flow.tensor(np.random.randn(*shape), dtype=flow.float32)\n        x.zero_()\n        test_case.assertTrue(np.allclose(x.numpy(), np.zeros(shape)))\n\n    @flow.unittest.skip_unless_1n1d()\n    def test_construct_small_tensor(test_case):\n        shape = (2, 3, 4, 5)\n        np_arr = 
np.random.rand(*shape).astype(np.float32)\n        tensor = flow.tensor(np_arr)\n        test_case.assertTrue(np.allclose(tensor.numpy(), np_arr))\n        test_case.assertEqual(tensor.dtype, flow.float32)\n        np_int_arr = np.random.randint(-100, high=100, size=shape, dtype=np.int32)\n        tensor = flow.tensor(np_int_arr, dtype=flow.int32)\n        test_case.assertEqual(tensor.dtype, flow.int32)\n        list_data = [[1, 2.0], [5, 3]]\n        tensor = flow.tensor(list_data)\n        test_case.assertEqual(tensor.dtype, flow.float32)\n        test_case.assertTrue(\n            np.allclose(tensor.numpy(), np.array(list_data), 0.0001, 0.0001)\n        )\n        tuple_data = ((1, 2, 5), (4, 3, 10))\n        tensor = flow.tensor(tuple_data)\n        test_case.assertEqual(tensor.dtype, flow.int64)\n        test_case.assertTrue(np.allclose(tensor.numpy(), np.array(tuple_data)))\n        scalar = 5.5\n        tensor = flow.tensor(scalar)\n        test_case.assertEqual(tensor.dtype, flow.float32)\n        test_case.assertTrue(\n            np.allclose(tensor.numpy(), np.array(scalar), 0.0001, 0.0001)\n        )\n\n    @autotest(n=5)\n    def test_tensor_floor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.floor()\n        return y\n\n    @autotest(n=5)\n    def test_tensor_round_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.round()\n        return y\n\n    def _test_tensor_reshape(test_case):\n        x = np.array(\n            [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]\n        ).astype(np.float32)\n        input = flow.tensor(x)\n        of_shape = input.reshape(2, 2, 2, -1).numpy().shape\n        np_shape = (2, 2, 2, 2)\n        test_case.assertTrue(np.allclose(of_shape, np_shape))\n\n    @autotest(n=5)\n    def test_flatten_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.flatten(\n            start_dim=random(1, 6).to(int) | nothing(),\n            end_dim=random(1, 6).to(int) | nothing(),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_reshape_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = x.reshape(-1)\n        return y\n\n    @autotest(n=1)\n    def test_reshape_tensor_with_random_data_and_keyword(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = x.reshape(shape=[-1,])\n        return y\n\n    @autotest(n=5)\n    def test_reshape_as_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = x.reshape(-1)\n        z = y.reshape_as(other=x)\n        return z\n\n    @autotest(n=5)\n    def test_tensor_squeeze_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.squeeze(random().to(int))\n        return y\n\n    @autotest(n=5)\n    def test_flow_unsqueeze_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.unsqueeze(random(1, 3).to(int))\n        return y\n\n    @autotest(n=3, auto_backward=False, check_graph=True)\n    def test_flow_invert_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device, dtype=torch.bool)\n        y = ~x\n        return y\n\n    
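# NOTE: a minimal sketch, not part of the original suite, assuming the\n    # autotest harness semantics used throughout this file: random_tensor()\n    # yields a paired oneflow/pytorch tensor (cf. the .oneflow / .pytorch\n    # attributes in test_setitem_with_random_data above), and the returned\n    # value is compared across both frameworks, gradients included. The\n    # test name below is hypothetical.\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=1)\n    def test_autotest_pairing_sketch(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim0=3, dim1=4).to(device)\n        # any torch-compatible expression is replayed on both frameworks\n        return (x * 2 + 1).sum()\n\n    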
def test_tensor_float(test_case):\n        x = flow.tensor(1)\n        y = float(x)\n        test_case.assertTrue(np.array_equal(y, 1.0))\n\n    def test_tensor_int(test_case):\n        x = flow.tensor(2.3)\n        y = int(x)\n        test_case.assertTrue(np.array_equal(y, 2))\n\n    def test_none_equal(test_case):\n        xt = flow.randn(10)\n        yt = flow.randn(10)\n        z = None in [xt, yt]\n        test_case.assertTrue(np.array_equal(z, False))\n        zt = None\n        z = None in [xt, yt, zt]\n        test_case.assertTrue(np.array_equal(z, True))\n\n    def test_half(test_case):\n        x = flow.tensor([1], dtype=flow.int64)\n        test_case.assertTrue(x.dtype == flow.int64)\n        y = x.half()\n        test_case.assertTrue(y.dtype == flow.float16)\n\n    def test_byte(test_case):\n        x = flow.tensor([1.2], dtype=flow.float32)\n        test_case.assertTrue(x.dtype == flow.float32)\n        y = x.byte()\n        test_case.assertTrue(y.dtype == flow.uint8)\n\n    def test_tensor_constructor(test_case):\n        x = flow.tensor([1, 2, 3])\n        test_case.assertTrue(np.array_equal(x.numpy(), [1, 2, 3]))\n        test_case.assertEqual(x.dtype, flow.int64)\n        x = flow.tensor([1.0, 2.0, 3.0])\n        test_case.assertTrue(np.array_equal(x.numpy(), [1.0, 2.0, 3.0]))\n        test_case.assertEqual(x.dtype, flow.float32)\n        x = flow.tensor([1.0, 2.0, 3.0], dtype=flow.float64)\n        test_case.assertTrue(np.array_equal(x.numpy(), [1.0, 2.0, 3.0]))\n        test_case.assertEqual(x.dtype, flow.float64)\n        np_arr = np.array([1, 2, 3])\n        x = flow.tensor(np_arr)\n        test_case.assertTrue(np.array_equal(x.numpy(), [1, 2, 3]))\n        test_case.assertEqual(x.dtype, flow.int64)\n        np_arr = np.array([1, 2, 3], dtype=np.float64)\n        x = flow.tensor(np_arr)\n        test_case.assertTrue(np.array_equal(x.numpy(), [1.0, 2.0, 3.0]))\n        test_case.assertEqual(x.dtype, flow.float64)\n        x = flow.tensor(np_arr, dtype=flow.float32)\n        test_case.assertTrue(np.array_equal(x.numpy(), [1.0, 2.0, 3.0]))\n        test_case.assertEqual(x.dtype, flow.float32)\n        x = flow.tensor(np_arr, dtype=flow.int8)\n        test_case.assertTrue(np.array_equal(x.numpy(), [1.0, 2.0, 3.0]))\n        test_case.assertEqual(x.dtype, flow.int8)\n        x = flow.tensor([flow.tensor([1, 2])] * 3, dtype=flow.float32)\n        test_case.assertTrue(np.array_equal(x.numpy(), [[1, 2], [1, 2], [1, 2]]))\n        test_case.assertEqual(x.dtype, flow.float32)\n\n    def test_tensor_contains_magic_method(test_case):\n        x = flow.tensor([[1, 2, 3], [4, 5, 6]])\n        y = 1 in x\n        test_case.assertEqual(y, True)\n\n    @profile(torch.Tensor.fill_)\n    def profile_fill_(test_case):\n        torch.Tensor.fill_(torch.ones(1, 8, 16, 16), 2)\n        torch.Tensor.fill_(torch.ones(1000, 1000), 2)\n        torch.Tensor.fill_(torch.ones(1, 8, 16, 16), torch.tensor(2))\n        torch.Tensor.fill_(torch.ones(1000, 1000), torch.tensor(2))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_tensor_part_2.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport copy\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestTensor(flow.unittest.TestCase):\n    @autotest(n=10)\n    def test_permute_flow_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        permute_list = [0, 1, 2, 3]\n        np.random.shuffle(permute_list)\n        y = x.permute(permute_list)\n        return y\n\n    @autotest(n=1)\n    def test_permute_flow_with_random_data_and_keyword(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        permute_list = [0, 1, 2, 3]\n        np.random.shuffle(permute_list)\n        y = x.permute(dims=permute_list)\n        return y\n\n    @autotest(n=5)\n    def test_transpose_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        permute_list = np.random.permutation(4)\n        y = x.transpose(permute_list[0], permute_list[1])\n        return y\n\n    @autotest(n=5)\n    def test_t_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(\n            ndim=constant(2).to(int), dim0=random(0, 64), dim1=random(0, 64)\n        ).to(device)\n        y = x.t()\n        return y\n\n    @autotest(n=5)\n    def test_T_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=random(1, 4)).to(device)\n        y = x.T\n        return y\n\n    def test_tensor_where(test_case):\n        x = flow.tensor(\n            np.array([[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]),\n            dtype=flow.float32,\n        )\n        y = flow.tensor(np.ones(shape=(3, 2)), dtype=flow.float32)\n        condition = flow.tensor(np.array([[0, 1], [1, 0], [1, 0]]), dtype=flow.int32)\n        of_out = condition.where(x, y)\n        np_out = np.array([[1.0, 0.3139], [0.3898, 1.0], [0.0478, 1.0]])\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out))\n\n    def test_tensor_equal(test_case):\n        arr1 = np.random.randint(1, 10, size=(2, 3, 4, 5))\n        arr2 = np.random.randint(1, 10, size=(2, 3, 4, 5))\n        input = flow.tensor(arr1, dtype=flow.float32)\n        other = flow.tensor(arr2, dtype=flow.float32)\n        of_out = input.eq(other)\n        np_out = np.equal(arr1, arr2)\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out))\n\n    def test_tensor_equal_bool_dtype(test_case):\n        np_bool = np.random.randint(0, 2, size=()).astype(bool).item()\n        input = flow.tensor(np_bool, dtype=flow.bool)\n        input2 = flow.tensor([np_bool], dtype=flow.bool)\n        
test_case.assertTrue(input == np_bool)\n        test_case.assertTrue(input2 == np_bool)\n\n    def test_tensor_detach(test_case):\n        shape = (2, 3, 4, 5)\n        x = flow.tensor(np.random.randn(*shape), dtype=flow.float32, requires_grad=True)\n        test_case.assertTrue(np.allclose(x.detach().numpy(), x.numpy(), 0.0001, 0.0001))\n        test_case.assertEqual(x.detach().requires_grad, False)\n        y = x * 2\n        z = y.detach()\n        test_case.assertEqual(z.is_leaf, True)\n        test_case.assertEqual(z.grad_fn, None)\n\n    def _test_cast_tensor_function(test_case):\n        shape = (2, 3, 4, 5)\n        np_arr = np.random.randn(*shape).astype(np.float32)\n        input = flow.tensor(np_arr, dtype=flow.float32)\n        output = input.cast(flow.int8)\n        np_out = np_arr.astype(np.int8)\n        test_case.assertTrue(np.allclose(output.numpy(), np_out))\n\n    def _test_sin_tensor_function(test_case, shape, device):\n        input = flow.Tensor(np.random.randn(2, 3, 4, 5))\n        of_out = input.sin()\n        np_out = np.sin(input.numpy())\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n    def test_cos_tensor_function(test_case):\n        arr = np.random.randn(2, 3, 4, 5)\n        input = flow.tensor(arr, dtype=flow.float32)\n        np_out = np.cos(arr)\n        of_out = input.cos()\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n\n    def test_std_tensor_function(test_case):\n        np_arr = np.random.randn(9, 8, 7, 6)\n        input = flow.Tensor(np_arr)\n        of_out = input.std(dim=1, unbiased=False, keepdim=False)\n        np_out = np.std(np_arr, axis=1)\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-04, 1e-04))\n\n    def test_sqrt_tensor_function(test_case):\n        input_arr = np.random.rand(1, 6, 3, 8)\n        np_out = np.sqrt(input_arr)\n        x = flow.Tensor(input_arr)\n        of_out = x.sqrt()\n        test_case.assertTrue(\n            np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True)\n        )\n\n    def test_rsqrt_tensor_function(test_case):\n        np_arr = np.random.rand(3, 2, 5, 7)\n        np_out = 1 / np.sqrt(np_arr)\n        x = flow.Tensor(np_arr)\n        of_out = flow.rsqrt(x)\n        test_case.assertTrue(\n            np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True)\n        )\n\n    def test_square_tensor_function(test_case):\n        np_arr = np.random.randn(2, 7, 7, 3)\n        np_out = np.square(np_arr)\n        x = flow.Tensor(np_arr)\n        of_out = x.square()\n        test_case.assertTrue(\n            np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True)\n        )\n\n    # This test fails with the rtol and atol constraints under PyTorch 1.10, but succeeds with PyTorch 1.13.\n    # The constraints should be removed in the future.\n    @autotest(n=5, rtol=1e-3, atol=1e-3)\n    def test_addmm_tensor_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=2, dim0=2, dim1=3).to(device)\n        mat1 = random_tensor(ndim=2, dim0=2, dim1=4).to(device)\n        mat2 = random_tensor(ndim=2, dim0=4, dim1=3).to(device)\n        y = input.addmm(\n            mat1,\n            mat2,\n            beta=random().to(float) | nothing(),\n            alpha=random().to(float) | nothing(),\n        )\n        return y\n\n    # This test fails with the rtol and atol constraints under PyTorch 1.10, but succeeds with PyTorch 1.13.\n    # The constraints should 
be removed in the future.\n    @autotest(n=5, rtol=1e-3, atol=1e-2)\n    def test_addmm_broadcast_tensor_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=2, dim0=1, dim1=1).to(device)\n        mat1 = random_tensor(ndim=2, dim0=2, dim1=4).to(device)\n        mat2 = random_tensor(ndim=2, dim0=4, dim1=3).to(device)\n        y = input.addmm(\n            mat1,\n            mat2,\n            beta=random().to(float) | nothing(),\n            alpha=random().to(float) | nothing(),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_clamp_tensor_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(low=-2, high=2).to(device)\n        y = input.clamp(\n            min=random(low=-1, high=-0.5).to(float),\n            max=random(low=0.5, high=1).to(float),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_clamp_inplace_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-2, high=2).to(device)\n        y = x + 1\n        y.clamp_(\n            min=random(low=-1, high=-0.5).to(float),\n            max=random(low=0.5, high=1).to(float),\n        )\n        return y\n\n    @autotest(auto_backward=False)\n    def test_clamp_inplace_tensor_no_grad_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-2, high=2).to(device)\n        y = x + 1\n        y.clamp_(\n            min=random(low=-1, high=-0.5).to(float),\n            max=random(low=0.5, high=1).to(float),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_clamp_minnone_tensor_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(low=-2, high=2).to(device)\n        y = input.clamp(\n            min=random(low=-1, high=-0.5).to(float) | nothing(),\n            max=random(low=0.5, high=1).to(float),\n        )\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(auto_backward=False)\n    def test_clamp_minnone_tensor_no_grad_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(low=-2, high=2).to(device)\n        y = input.clamp(\n            min=random(low=-1, high=-0.5).to(float) | nothing(),\n            max=random(low=0.5, high=1).to(float),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_clamp_inplace_minnone_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-2, high=2).to(device)\n        y = x + 1\n        y.clamp_(\n            min=random(low=-1, high=-0.5).to(float) | nothing(),\n            max=random(low=0.5, high=1).to(float),\n        )\n        return y\n\n    @autotest(auto_backward=False)\n    def test_clamp_inplace_minnone_tensor_no_grad_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-2, high=2).to(device)\n        y = x + 1\n        y.clamp_(\n            min=random(low=-1, high=-0.5).to(float) | nothing(),\n            max=random(low=0.5, high=1).to(float),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_clamp_maxnone_tensor_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(low=-2, high=2).to(device)\n        y = input.clamp(\n            min=random(low=-1, high=-0.5).to(float),\n            max=random(low=0.5, high=1).to(float) | nothing(),\n        )\n        return y\n\n    @autotest(auto_backward=False)\n    def 
test_clamp_maxnone_tensor_no_grad_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(low=-2, high=2).to(device)\n        y = input.clamp(\n            min=random(low=-1, high=-0.5).to(float),\n            max=random(low=0.5, high=1).to(float) | nothing(),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_clamp_inplace_maxnone_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-2, high=2).to(device)\n        y = x + 1\n        y.clamp_(\n            min=random(low=-1, high=-0.5).to(float),\n            max=random(low=0.5, high=1).to(float) | nothing(),\n        )\n        return y\n\n    @autotest(auto_backward=False)\n    def test_clamp_inplace_maxnone_tensor_no_grad_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-2, high=2).to(device)\n        y = x + 1\n        y.clamp_(\n            min=random(low=-1, high=-0.5).to(float),\n            max=random(low=0.5, high=1).to(float) | nothing(),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_clamp_min_tensor_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(low=-2, high=2).to(device)\n        y = input.clamp_min(random(low=-0.5, high=0.5).to(float))\n        return y\n\n    @autotest(n=5)\n    def test_clamp_min_inplace_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-2, high=2).to(device)\n        y = x + 1\n        y.clamp_min_(random(low=-0.5, high=0.5).to(float))\n        return y\n\n    @autotest(auto_backward=False)\n    def test_clamp_min_tensor_no_grad_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(low=-2, high=2).to(device)\n        y = input.clamp_min(random(low=-0.5, high=0.5).to(float))\n        return y\n\n    @autotest(auto_backward=False)\n    def test_clamp_min_inplace_tensor_no_grad_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-2, high=2).to(device)\n        y = x + 1\n        y.clamp_min_(random(low=-0.5, high=0.5).to(float))\n        return y\n\n    @autotest(n=5)\n    def test_clamp_max_tensor_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(low=-2, high=2).to(device)\n        y = input.clamp_max(random(low=-0.5, high=0.5).to(float))\n        return y\n\n    @autotest(n=5)\n    def test_clamp_max_inplace_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-2, high=2).to(device)\n        y = x + 1\n        y.clamp_max_(random(low=-0.5, high=0.5).to(float))\n        return y\n\n    @autotest(auto_backward=False)\n    def test_clamp_max_tensor_no_grad_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(low=-2, high=2).to(device)\n        y = input.clamp_max(random(low=-0.5, high=0.5).to(float))\n        return y\n\n    @autotest(auto_backward=False)\n    def test_clamp_max_inplace_tensor_no_grad_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-2, high=2).to(device)\n        y = x + 1\n        y.clamp_max_(random(low=-0.5, high=0.5).to(float))\n        return y\n\n    @autotest(n=5)\n    def test_clip_tensor_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(low=-2, high=2).to(device)\n        y = input.clip(\n            min=random(low=-1, 
high=-0.5).to(float),\n            max=random(low=0.5, high=1).to(float),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_clip_inplace_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-2, high=2).to(device)\n        y = x + 1\n        y.clip_(\n            min=random(low=-1, high=-0.5).to(float),\n            max=random(low=0.5, high=1).to(float),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_clip_minnone_tensor_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor(low=-2, high=2).to(device)\n        y = input.clip(\n            min=random(low=-1, high=-0.5).to(float) | nothing(),\n            max=random(low=0.5, high=1).to(float),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_clip_inplace_maxnone_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-2, high=2).to(device)\n        y = x + 1\n        y.clip_(\n            min=random(low=-1, high=-0.5).to(float),\n            max=random(low=0.5, high=1).to(float) | nothing(),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_clip_maxnone_tensor_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor().to(device)\n        y = input.clip(\n            min=random(low=-1, high=-0.5).to(float),\n            max=random(low=0.5, high=1).to(float) | nothing(),\n        )\n        return y\n\n    @autotest(n=5)\n    def test_ceil_tensor_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor().to(device)\n        y = input.ceil()\n        return y\n\n    @autotest(n=5)\n    def test_expm1_tensor_with_random_data(test_case):\n        device = random_device()\n        input = random_tensor().to(device)\n        y = input.expm1()\n        return y\n\n    @autotest(n=5)\n    def test_floor_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.floor()\n        return y\n\n    @autotest(n=5)\n    def test_tensor_var_all_dim_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.var()\n        return y\n\n    # TODO(): 'var backward' is composed of several other ops,\n    # reducemean doesn't support 0-shape for now\n    @autotest(n=5, auto_backward=False)\n    def test_tensor_var_one_dim_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = x.var(\n            dim=random(low=0, high=4).to(int),\n            unbiased=random().to(bool),\n            keepdim=random().to(bool),\n        )\n        return y\n\n    def test_norm_tensor_function(test_case):\n        input = flow.tensor(\n            np.array([[-4.0, -3.0, -2.0], [-1.0, 0.0, 1.0], [2.0, 3.0, 4.0]]),\n            dtype=flow.float32,\n        )\n        of_out_1 = 
input.norm(\"fro\")\n        np_out_1 = np.linalg.norm(input.numpy(), \"fro\")\n        of_out_2 = input.norm(2, dim=1)\n        np_out_2 = np.linalg.norm(input.numpy(), ord=2, axis=1)\n        of_out_3 = input.norm(float(\"inf\"), dim=0, keepdim=True)\n        np_out_3 = np.linalg.norm(\n            input.numpy(), ord=float(\"inf\"), axis=0, keepdims=True\n        )\n        test_case.assertTrue(np.allclose(of_out_1.numpy(), np_out_1, 1e-05, 1e-05))\n        test_case.assertTrue(np.allclose(of_out_2.numpy(), np_out_2, 1e-05, 1e-05))\n        test_case.assertTrue(np.allclose(of_out_3.numpy(), np_out_3, 1e-05, 1e-05))\n\n    @autotest(n=5)\n    def test_pow_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = random().to(float)\n        z = x.pow(y)\n        return z\n\n    @autotest(n=5)\n    def test_atanh_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-0.5, high=0.49).to(device)\n        y = x.atanh()\n        return y\n\n    @autotest(n=5)\n    def test_acos_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-0.5, high=0.49).to(device)\n        y = x.acos()\n        return y\n\n    @autotest(n=5)\n    def test_acosh_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=2.0, high=3.0).to(device)\n        y = x.acosh()\n        return y\n\n    @autotest(n=5)\n    def test_atan_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.atan()\n        return y\n\n    @autotest(n=5)\n    def test_arctan_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.arctan()\n        return y\n\n    @autotest(n=5)\n    def test_tan_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        y = x.tan()\n        return y\n\n    @autotest(n=5)\n    def test_tan2_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim1=3).to(device)\n        y = random_tensor(ndim=2, dim1=3).to(device)\n        z = x.atan2(y)\n        return z\n\n    @autotest(n=5)\n    def test_arctanh_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-0.5, high=0.5).to(device)\n        y = x.arctanh()\n        return y\n\n    # Not check graph because of one reason:\n    # Reason 1, lazy tensor cannot call .numpy(). 
tensor.numpy() is not allowed to called in nn.Graph.build(*args) or called by lazy tensor.\n    # Please refer to File \"python/oneflow/nn/modules/nonzero.py\", line 29, in nonzero_op.\n    @autotest(n=5, auto_backward=False, check_graph=\"ValidatedFalse\")\n    def test_tensor_nonzero_with_random_data(test_case):\n        device = random_device()\n        ndim = random(2, 6).to(int)\n        x = random_tensor(ndim=ndim).to(device)\n        y = x.nonzero()\n        return y\n\n    @unittest.skipIf(\n        not flow.unittest.env.eager_execution_enabled(),\n        \"numpy doesn't work in lazy mode\",\n    )\n    def test_tensor_fmod(test_case):\n        x = flow.Tensor(np.random.uniform(-100, 100, (5, 5)))\n        x.requires_grad = True\n        y = np.random.uniform(-10, 10)\n        of_out = x.fmod(y)\n        np_out = np.sign(x.numpy()) * np.abs(np.fmod(x.numpy(), y))\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n        of_out = of_out.sum()\n        of_out.backward()\n        test_case.assertTrue(\n            np.allclose(x.grad.numpy(), np.ones((5, 5)), 0.0001, 0.0001)\n        )\n\n    @unittest.skipIf(\n        not flow.unittest.env.eager_execution_enabled(),\n        \"numpy doesn't work in lazy mode\",\n    )\n    def test_magic_fmod(test_case):\n        x = flow.Tensor(np.random.uniform(-100, 100, (5, 5)))\n        x.requires_grad = True\n        y = np.random.uniform(-10, 10)\n        of_out = x % y\n        np_out = np.sign(x.numpy()) * np.abs(np.fmod(x.numpy(), y))\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))\n        of_out = of_out.sum()\n        of_out.backward()\n        test_case.assertTrue(\n            np.allclose(x.grad.numpy(), np.ones((5, 5)), 0.0001, 0.0001)\n        )\n\n    def test_tensor_mish(test_case):\n        def np_mish(x):\n            f = 1 + np.exp(x)\n            y = x * ((f * f - 1) / (f * f + 1))\n            y_grad = (f * f - 1) / (f * f + 1) + x * (4 * f * (f - 1)) / (\n                (f * f + 1) * (f * f + 1)\n            )\n            return [y, y_grad]\n\n        np_input = np.random.randn(2, 4, 5, 6)\n        of_input = flow.tensor(np_input, dtype=flow.float32, requires_grad=True)\n        of_out = of_input.mish()\n        (np_out, np_grad) = np_mish(np_input)\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n        of_out = of_out.sum()\n        of_out.backward()\n        test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-05, 1e-05))\n\n    def test_tensor_triu(test_case):\n        def np_triu(x, diagonal):\n            y = np.triu(x, diagonal)\n            y_grad = np.triu(np.ones_like(x), diagonal)\n            return [y, y_grad]\n\n        diagonal_list = [2, -1]\n        for diagonal in diagonal_list:\n            np_input = np.random.randn(2, 4, 6)\n            of_input = flow.tensor(np_input, dtype=flow.float32, requires_grad=True)\n            of_out = of_input.triu(diagonal)\n            (np_out, np_grad) = np_triu(np_input, diagonal)\n            test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))\n            of_out = of_out.sum()\n            of_out.backward()\n            test_case.assertTrue(\n                np.allclose(of_input.grad.numpy(), np_grad, 1e-05, 1e-05)\n            )\n\n    def test_tensor_grad_assignment(test_case):\n        np_input = np.random.randn(2, 4, 5, 6)\n        of_input = flow.tensor(np_input, dtype=flow.float32, requires_grad=True)\n        of_output = 2 * 
of_input\n        of_output = of_output.sum()\n        of_output.backward()\n        new_grad = flow.tensor(\n            np.full(np_input.shape, np.random.randn(1)), dtype=flow.float32\n        )\n        of_input.grad = new_grad\n        test_case.assertTrue(\n            np.allclose(of_input.grad.detach().numpy(), new_grad.numpy(), 1e-05, 1e-05)\n        )\n        of_input.grad = None\n        test_case.assertTrue(of_input.grad is None)\n\n    def test_tensor_grad_assignment_sum(test_case):\n        np_input = np.random.randn(1, 5, 7, 3)\n        of_input = flow.tensor(np_input, dtype=flow.float32, requires_grad=True)\n        of_output = of_input.sum()\n        of_output.backward()\n        rand_init = np.random.randn(1)\n        rand_scale = np.random.randn(1)\n        new_grad = flow.tensor(np.full(np_input.shape, rand_init), dtype=flow.float32)\n        of_input.grad = new_grad\n        of_output = flow.tensor(rand_scale, dtype=flow.float32) * of_input\n        of_output = of_output.sum()\n        of_output.backward()\n        test_case.assertTrue(\n            np.allclose(\n                of_input.grad.detach().numpy(),\n                np.full(np_input.shape, rand_init + rand_scale),\n                1e-05,\n                1e-05,\n            )\n        )\n        of_input.grad = of_input.grad * 2\n        test_case.assertTrue(\n            np.allclose(\n                of_input.grad.detach().numpy(),\n                2 * np.full(np_input.shape, rand_init + rand_scale),\n                1e-05,\n                1e-05,\n            )\n        )\n\n    def test_tensor_silu(test_case):\n        def np_silu(x):\n            _sig = 1 / (1 + np.exp(-x))\n            y = x * _sig\n            y_grad = _sig * (1 + x * (1 - _sig))\n            return [y, y_grad]\n\n        np_input = np.random.randn(2, 4, 5, 6,)\n        of_input = flow.tensor(np_input, dtype=flow.float32, requires_grad=True)\n        of_out = of_input.silu()\n\n        np_out, np_grad = np_silu(np_input)\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))\n\n        of_out = of_out.sum()\n        of_out.backward()\n        test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-5, 1e-5))\n\n    def test_tensor_selu(test_case):\n        _scale = 1.0507009873554804934193349852946\n        _alpha = 1.6732632423543772848170429916717\n\n        def np_selu(x):\n            y = np.where(x < 0, _scale * _alpha * (np.exp(x) - 1), _scale * x)\n            y_grad = np.where(x < 0, _scale * _alpha * np.exp(x), _scale)\n            return [y, y_grad]\n\n        np_input = np.random.randn(2, 4, 5, 6,)\n        of_input = flow.tensor(np_input, dtype=flow.float32, requires_grad=True)\n        of_out = of_input.selu()\n\n        np_out, np_grad = 
np_selu(np_input)\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))\n\n        of_out = of_out.sum()\n        of_out.backward()\n        test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-5, 1e-5))\n\n    @unittest.skip(\"still has errors in CI\")\n    def test_tensor_softsign(test_case):\n        def np_softsign(x):\n            y = x / (1 + np.abs(x))\n            y_grad = 1 / np.square(1 + np.abs(x))\n            return [y, y_grad]\n\n        np_input = np.random.randn(2, 4, 5, 6,)\n        of_input = flow.tensor(np_input, dtype=flow.float32, requires_grad=True)\n        of_out = of_input.softsign()\n\n        np_out, np_grad = np_softsign(np_input)\n        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))\n\n        of_out = of_out.sum()\n        of_out.backward()\n        test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-5, 1e-5))\n\n    @autotest(auto_backward=False)\n    def test_eq_tensor_with_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        y = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        return x.eq(y)\n\n    @autotest(auto_backward=False)\n    def test_eq_tensor_with_same_random_data(test_case):\n        device = random_device()\n        shape = random_tensor().oneflow.shape\n        x = random_tensor(len(shape), *shape, requires_grad=False).to(device)\n        return x.eq(x)\n\n    @autotest(n=5)\n    def test_erf_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return x.erf()\n\n    @autotest(n=5)\n    def test_erfc_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return x.erfc()\n\n    @autotest(\n        auto_backward=False\n    )  # TODO: after the gradient function is added, set `auto_backward` to True\n    def test_erfinv_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-1, high=1).to(device).requires_grad_(False)\n        return x.erfinv()\n\n    @autotest(\n        n=10, auto_backward=False\n    )  # TODO: after the gradient function is added, set `auto_backward` to True\n    def test_erfinv_inplace_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor(low=-1, high=1).to(device).requires_grad_(False)\n        y = x + 1\n        y.erfinv_()\n        return y\n\n    @autotest(n=5)\n    def test_exp_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return x.exp()\n\n    @autotest(n=5)\n    def test_exp2_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return x.exp2()\n\n    @autotest(n=5)\n    def test_round_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return x.round()\n\n    @autotest(n=5)\n    def test_tensor_diag_one_dim(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1, dim0=random()).to(device)\n        return x.diag()\n\n    @autotest(n=5)\n    def test_flow_tensor_expand_with_random_data(test_case):\n        random_expand_size = random(1, 6).to(int).value()\n        x = random_tensor(ndim=5, dim0=1, dim1=1, dim2=1, dim3=1, dim4=1)\n        ndim = 
5\n        expand_size = random_expand_size\n        dim_size = [1,] * ndim\n        random_index = random(0, ndim).to(int).value()\n        dim_size[random_index] = expand_size\n        return x.expand(*dim_size)\n\n    @autotest(n=5)\n    def test_flow_tensor_expand_as_with_random_data(test_case):\n        random_expand_size = random(1, 6).to(int).value()\n        x = random_tensor(ndim=5, dim0=1, dim1=1, dim2=1, dim3=1, dim4=1)\n        ndim = 5\n        expand_size = random_expand_size\n        dim_size = [1,] * ndim\n        random_index = random(0, ndim).to(int).value()\n        dim_size[random_index] = expand_size\n        y = torch.ones(dim_size)\n        return x.expand_as(y)\n\n    @autotest(n=5)\n    def test_flow_tensor_view_with_random_data(test_case):\n        dim0_ = random(2, 4).to(int)\n        dim1_ = random(2, 4).to(int)\n        dim2_ = random(2, 4).to(int)\n        dim3_ = random(2, 4).to(int)\n        dim4_ = random(2, 4).to(int)\n        x = random_tensor(\n            ndim=5, dim0=dim0_, dim1=dim1_, dim2=dim2_, dim3=dim3_, dim4=dim4_\n        )\n        shape = [x.value() for x in [dim4_, dim3_, dim2_, dim1_, dim0_]]\n        return [x.view(shape), x.view(size=shape)]\n\n    @autotest(n=5)\n    def test_flow_tensor_view_as_with_random_data(test_case):\n        dim0_ = random(2, 4).to(int)\n        dim1_ = random(2, 4).to(int)\n        dim2_ = random(2, 4).to(int)\n        dim3_ = random(2, 4).to(int)\n        dim4_ = random(2, 4).to(int)\n        x = random_tensor(\n            ndim=5, dim0=dim0_, dim1=dim1_, dim2=dim2_, dim3=dim3_, dim4=dim4_\n        )\n        other = random_tensor(\n            ndim=5, dim0=dim4_, dim1=dim3_, dim2=dim2_, dim3=dim1_, dim4=dim0_\n        )\n        return x.view_as(other)\n\n    @autotest(n=5)\n    def test_tensor_diag_other_dim(test_case):\n        device = random_device()\n        x = random_tensor(ndim=2, dim0=random(), dim1=random()).to(device)\n        return x.diag()\n\n    @autotest(auto_backward=False)\n    def test_floordiv_elementwise_tensor_with_random_data(test_case):\n        device = random_device()\n        # The random values are narrowed to positive numbers because of an error in PyTorch 1.10.0.\n        # Please remove the value range restriction after updating the PyTorch version in CI to 1.13.\n        input = random_tensor(ndim=2, dim0=4, dim1=8, low=0, high=10).to(device)\n        other = random_tensor(ndim=2, dim0=4, dim1=8, low=0, high=10).to(device)\n        y = input.floor_divide(other)\n        return y\n\n    @autotest(auto_backward=False)\n    def test_scalar_floordiv_tensor_with_random_data(test_case):\n        device = random_device()\n        # The random values are narrowed to positive numbers because of an error in PyTorch 1.10.0.\n        # Please remove the value range restriction after updating the PyTorch version in CI to 1.13.\n        input = random_tensor(ndim=2, dim0=4, dim1=8, low=0, high=10).to(device)\n        other = random().to(int)\n        y = input.floor_divide(other)\n        return y\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_construct_global_tensor_by_numpy(test_case):\n        x = np.ones((4, 4), dtype=np.int32)\n        placement = flow.placement(\"cuda\", [0, 1, 2, 3])\n        y = flow.tensor(\n            x,\n            dtype=flow.float32,\n            placement=placement,\n            sbp=[flow.sbp.split(0)],\n            requires_grad=False,\n        )\n        test_case.assertTrue(y.dtype == flow.float32)\n        test_case.assertTrue(\n            
np.allclose(y.to_local().numpy(), np.ones((1, 4), dtype=np.float32))\n        )\n        test_case.assertEqual(y.placement, placement)\n\n        y_default_dtype = flow.tensor(\n            x, placement=placement, sbp=[flow.sbp.split(0)], requires_grad=False,\n        )\n        test_case.assertTrue(y_default_dtype.dtype == flow.int32)\n\n    @autotest(n=5)\n    def test_digamma_tensor_with_random_data(test_case):\n        device = random_device()\n        x = random_tensor().to(device)\n        return x.digamma()\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\nclass TestTensorNumpy(flow.unittest.TestCase):\n    @flow.unittest.skip_unless_1n2d()\n    def test_1d_sbp_tensor_numpy_1n2d(test_case):\n        ori_x = flow.tensor([1, 2, 3, 4]) + flow.env.get_rank()\n        placement = flow.placement.all(\"cpu\")\n        x = ori_x.to_global(placement=placement, sbp=flow.sbp.split(0))\n        test_case.assertTrue(np.allclose(x.numpy(), [1, 2, 3, 4, 2, 3, 4, 5]))\n\n        x = ori_x.to_global(placement=placement, sbp=flow.sbp.broadcast, copy=True)\n        test_case.assertTrue(np.allclose(x.numpy(), [1, 2, 3, 4]))\n\n        x = ori_x.to_global(placement=placement, sbp=flow.sbp.partial_sum)\n        test_case.assertTrue(np.allclose(x.numpy(), [3, 5, 7, 9]))\n\n        placement = flow.placement.all(\"cuda\")\n        x = ori_x.to_global(placement=placement, sbp=flow.sbp.split(0))\n        test_case.assertTrue(np.allclose(x.numpy(), [1, 2, 3, 4, 2, 3, 4, 5]))\n\n        x = ori_x.to_global(placement=placement, sbp=flow.sbp.broadcast, copy=True)\n        test_case.assertTrue(np.allclose(x.numpy(), [1, 2, 3, 4]))\n\n        x = ori_x.to_global(placement=placement, sbp=flow.sbp.partial_sum)\n        test_case.assertTrue(np.allclose(x.numpy(), [3, 5, 7, 9]))\n\n    @flow.unittest.skip_unless_1n2d()\n    def test_2d_sbp_tensor_numpy_1n2d(test_case):\n        ori_x = flow.tensor(np.ones((2, 2))) + flow.env.get_rank()\n        placement = flow.placement(\"cuda\", [[0], [1]])\n        x = ori_x.to_global(\n            placement=placement, sbp=[flow.sbp.split(0), flow.sbp.split(1)]\n        )\n        test_case.assertTrue(np.allclose(x.numpy(), [[1, 1], [1, 1], [2, 2], [2, 2]]))\n\n        x = ori_x.to_global(\n            placement=placement, sbp=[flow.sbp.broadcast, flow.sbp.split(0)]\n        )\n        test_case.assertTrue(np.allclose(x.numpy(), [[1, 1], [1, 1]]))\n\n        x = ori_x.to_global(\n            placement=placement,\n            sbp=[flow.sbp.partial_sum, flow.sbp.broadcast],\n            copy=True,\n        )\n        test_case.assertTrue(np.allclose(x.numpy(), [[3, 3], [3, 3]]))\n\n    @flow.unittest.skip_unless_1n4d()\n    def test_2d_sbp_tensor_numpy_1n4d(test_case):\n        ori_x = flow.tensor(np.ones((2, 2))) + flow.env.get_rank()\n        placement = flow.placement(\"cuda\", [[0, 1], [2, 3]])\n\n        x = ori_x.to_global(\n            placement=placement, sbp=[flow.sbp.split(0), flow.sbp.split(1)]\n        )\n        test_case.assertTrue(\n            np.allclose(\n                x.numpy(), [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]\n            )\n        )\n\n        x = ori_x.to_global(\n            placement=placement, sbp=[flow.sbp.split(0), flow.sbp.partial_sum]\n        )\n        test_case.assertTrue(np.allclose(x.numpy(), [[3, 3], [3, 3], [7, 7], [7, 7]]))\n\n        # TODO: (s0, b) has bug\n        # x = ori_x.to_global(placement=placement, sbp=[flow.sbp.split(0), flow.sbp.broadcast])\n\n    
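# Most of the remaining cases are single-device autotest cases that compare the\n    # OneFlow result with the corresponding PyTorch result.\n    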
@flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_tensor_bmm(test_case):\n        t = random(1, 5)\n        k = random(1, 5)\n        input1 = random_tensor(ndim=3, dim0=t, dim1=3, dim2=k)\n        input2 = random_tensor(ndim=3, dim0=t, dim1=k, dim2=5)\n        of_out = input1.bmm(input2)\n        return of_out\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_tensor_split(test_case):\n        k0 = random(2, 6)\n        k1 = random(2, 6)\n        k2 = random(2, 6)\n        rand_dim = random(0, 3).to(int)\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device)\n        res = x.split(2, dim=rand_dim)\n        return torch.cat(res, rand_dim)\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_tensor_split_sizes(test_case):\n        k0 = random(2, 6)\n        k1 = 7\n        k2 = random(2, 6)\n        device = random_device()\n        x = random_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device)\n        res = x.split([1, 2, 3, 1], dim=-2)\n        return torch.cat(res, dim=1)\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_tensor_unbind(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = x.unbind(random(0, 4).to(int))\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_tensor_swapaxes(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3).to(device)\n        y = x.swapaxes(random(0, 2).to(int), random(0, 2).to(int))\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_tensor_swapdims(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3).to(device)\n        y = x.swapdims(random(0, 3).to(int), random(0, 3).to(int))\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_tensor_int_repeat_interleave_dim_none(test_case):\n        x = random_tensor(ndim=2, dim0=1, dim1=2)\n        y = x.repeat_interleave(2)\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_tensor_int_repeat_interleave_with_dim(test_case):\n        x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3)\n        dim = random(low=0, high=2).to(int)\n        y = x.repeat_interleave(2, dim)\n        return y\n\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_tensor_tensor_repeat_interleave_dim(test_case):\n        x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3)\n        y = random_tensor(ndim=1, dim0=2, dtype=int, low=1, high=4)\n        z = x.repeat_interleave(y, 1)\n        return z\n\n    @unittest.skip(\"skip for now, because it failed 2 times in the past week\")\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5, rtol=1e-3)\n    def test_tensor_tensor_repeat_interleave_dim_with_output_size(test_case):\n        x = random_tensor(ndim=3, dim0=2, dim1=2, dim2=3)\n        y = random_tensor(ndim=1, dim0=2, dtype=int, low=1, high=4)\n        z = x.repeat_interleave(y, 1, output_size=2)\n        return z\n\n    @flow.unittest.skip_unless_1n2d()\n    @globaltest\n    def test_global_tensor_detach(test_case):\n        device = random_device().value()\n        placement = flow.placement(device, [0, 1])\n        a = flow.ones(4, 8).to_global(placement, flow.sbp.broadcast)\n        test_case.assertTrue(a.is_leaf)\n        b = a.float().clone().detach()\n        test_case.assertTrue(b.is_leaf)\n\n    
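# nansum treats the NaN entries injected via masked_fill below as zero.\n    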
@flow.unittest.skip_unless_1n1d()\n    @autotest(n=5)\n    def test_tensor_nansum(test_case):\n        device = random_device()\n        x = random_tensor(4, random(0, 5), 2).to(device)\n        mask = x < 0\n        x = x.masked_fill(mask, float(\"nan\"))\n        y = x.nansum()\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_tensor_part_3.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\n\nimport oneflow as flow\nimport oneflow.unittest\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\ndef _get_indexes(device):\n    return (\n        constant(\n            torch.tensor(np.array([[0, 1], [1, 0]]), dtype=torch.int64, device=device)\n        ),\n        constant(\n            torch.tensor(np.array([[1, 0], [0, 1]]), dtype=torch.int64, device=device)\n        ),\n        constant(\n            torch.tensor(np.array([[1, 0], [1, 0]]), dtype=torch.int64, device=device)\n        ),\n        constant(\n            torch.tensor(np.array([[0, 1], [0, 1]]), dtype=torch.int64, device=device)\n        ),\n    )\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestTensor(flow.unittest.TestCase):\n    @autotest(n=10)\n    def test_scatter_random_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=2, dim0=2, dim1=2).to(device)\n        src = oneof(3.14, random_tensor(ndim=2, dim0=2, dim1=2).to(device))\n        inplace = oneof(True, False)\n        dim = oneof(0, 1, -1)\n        if inplace:\n            y = input + 1\n            y.scatter_(dim, oneof(*_get_indexes(device)), src)\n            return y\n        return input.scatter(dim, oneof(*_get_indexes(device)), src)\n\n    @autotest(\n        n=10, auto_backward=False\n    )  # peihong: pytorch dose not support backward when reduce is add or multiply\n    def test_scatter_add_or_multiply_random_data(test_case):\n        device = random_device()\n        input = random_tensor(ndim=2, dim0=2, dim1=2).to(device)\n        src = random_tensor(ndim=2, dim0=2, dim1=2).to(device)\n        inplace = oneof(True, False)\n        reduce = oneof(\"add\", \"multiply\")\n        dim = oneof(0, 1)\n        if inplace:\n            y = input + 1\n            y.scatter_(\n                dim, oneof(*_get_indexes(device)), src, reduce=reduce,\n            )\n            return y\n        return input.scatter(dim, oneof(*_get_indexes(device)), src, reduce=reduce)\n\n    def test_tensor_element_size_api(test_case):\n        x = flow.ones(2, 1, dtype=flow.float)\n        test_case.assertEqual(x.element_size(), 4)\n\n    def test_tensor_new(test_case):\n        dtype = random_dtype([\"pod\"])\n        device = random_device()\n        x = random_tensor(ndim=3).to(dtype).to(device)\n        of_result = x.oneflow.new()\n        th_result = x.pytorch.new()\n        test_case.assertTrue(list(of_result.shape) == list(th_result.shape))\n        test_case.assertTrue(\n            of_result.numpy().dtype == th_result.detach().cpu().numpy().dtype\n        )\n        test_case.assertTrue(of_result.device.type == th_result.device.type)\n\n        y = random_tensor(ndim=3).to(dtype).to(device)\n        of_result = x.oneflow.new(y.oneflow)\n        th_result = x.pytorch.new(y.pytorch)\n        
test_case.assertTrue(list(of_result.shape) == list(th_result.shape))\n        test_case.assertTrue(\n            of_result.numpy().dtype == th_result.detach().cpu().numpy().dtype\n        )\n        test_case.assertTrue(of_result.device.type == th_result.device.type)\n\n        np_data = np.random.randn(3, 3)\n        of_result = x.oneflow.new(np_data)\n        th_result = x.pytorch.new(np_data)\n        test_case.assertTrue(list(of_result.shape) == list(th_result.shape))\n        test_case.assertTrue(\n            of_result.numpy().dtype == th_result.detach().cpu().numpy().dtype\n        )\n        test_case.assertTrue(of_result.device.type == th_result.device.type)\n\n        of_result = x.oneflow.new([1, 2, 3])\n        th_result = x.pytorch.new([1, 2, 3])\n        test_case.assertTrue(list(of_result.shape) == list(th_result.shape))\n        test_case.assertTrue(\n            of_result.numpy().dtype == th_result.detach().cpu().numpy().dtype\n        )\n        test_case.assertTrue(of_result.device.type == th_result.device.type)\n\n    @autotest(n=3)\n    def test_baddbmm(test_case):\n        device = random_device()\n        batch_dim = random().to(int)\n        dim1 = random().to(int)\n        dim2 = random().to(int)\n        dim3 = random().to(int)\n        x = random_tensor(\n            ndim=3, dim0=oneof(batch_dim, 1).value(), dim1=dim1, dim2=dim3\n        ).to(device)\n        batch1 = random_tensor(ndim=3, dim0=batch_dim, dim1=dim1, dim2=dim2).to(device)\n        batch2 = random_tensor(ndim=3, dim0=batch_dim, dim1=dim2, dim2=dim3).to(device)\n        alpha = random_or_nothing(-1, 1).to(float)\n        beta = random_or_nothing(-1, 1).to(float)\n        return x.baddbmm(batch1, batch2, alpha=alpha, beta=beta)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_tensor_pin_memory.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport copy\nimport os\nimport unittest\nfrom collections import OrderedDict\n\nimport numpy as np\nimport oneflow as flow\nimport oneflow.unittest\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@flow.unittest.skip_unless_1n1d()\nclass TestTensor(flow.unittest.TestCase):\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5, auto_backward=True, check_graph=False)\n    def test_tensor_pin_memory(test_case):\n        device = random_device()\n        x = random_tensor(ndim=3).to(device)\n        x2 = x.pin_memory()\n        x3 = x2.pin_memory()\n        test_case.assertTrue(id(x.pytorch) != id(x2.pytorch))\n        test_case.assertTrue(id(x3.pytorch) == id(x2.pytorch))\n        test_case.assertTrue(id(x.oneflow) != id(x2.oneflow))\n        test_case.assertTrue(id(x3.oneflow) == id(x2.oneflow))\n        return x3\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5, auto_backward=False, check_graph=False)\n    def test_0_dim_tensor_pin_memory(test_case):\n        device = random_device()\n        x = random_tensor(ndim=1).to(device)\n        x1 = x[0]\n        x2 = x1.pin_memory()\n        x3 = x2.pin_memory()\n        test_case.assertTrue(id(x1.pytorch) != id(x2.pytorch))\n        test_case.assertTrue(id(x3.pytorch) == id(x2.pytorch))\n        test_case.assertTrue(id(x1.oneflow) != id(x2.oneflow))\n        test_case.assertTrue(id(x3.oneflow) == id(x2.oneflow))\n        return x3\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5, auto_backward=False, check_graph=False)\n    def test_tensor_construct_with_pin_memory_param(test_case):\n        device = random_device()\n        n = random(1, 4).to(int)\n        c = random(1, 4).to(int)\n        h = random(1, 4).to(int)\n        w = random(1, 4).to(int)\n        x = random_tensor(ndim=4, dim0=n, dim1=c, dim2=h, dim3=w, pin_memory=True).to(\n            device\n        )\n        return x\n\n    @unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n    @flow.unittest.skip_unless_1n1d()\n    @autotest(n=5, auto_backward=True, check_graph=False)\n    def test_tensor_is_pinned(test_case):\n        device = random_device()\n        x = random_tensor(ndim=4).to(device)\n        y = x.pin_memory()\n        test_case.assertTrue(x.oneflow.is_pinned() == x.pytorch.is_pinned())\n        test_case.assertTrue(y.oneflow.is_pinned() == y.pytorch.is_pinned())\n        return y\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test/tensor/test_tensor_to_memory_format.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport unittest\nimport random as random_util\n\nimport oneflow as flow\nimport oneflow.unittest\nimport numpy as np\n\nfrom oneflow.test_utils.automated_test_util import *\n\n\n@unittest.skipIf(os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"), \"only test cpu cases\")\n@flow.unittest.skip_unless_1n1d()\nclass TestTensor(flow.unittest.TestCase):\n    @autotest(n=3)\n    def test_to_memory_format(test_case):\n        def check_equal(a, b):\n            test_case.assertEqual(list(a.shape), list(b.shape))\n            test_case.assertEqual(list(a.stride()), list(b.stride()))\n            test_case.assertEqual(a.is_contiguous(), b.is_contiguous())\n            test_case.assertTrue(\n                np.allclose(\n                    a.detach().cpu().numpy(), b.detach().cpu().numpy(), 1e-06, 1e-06\n                )\n            )\n\n        device = random_device()\n        x = random_tensor(\n            ndim=4,\n            dim0=random(1, 6).to(int),\n            dim1=random(1, 6).to(int),\n            dim2=random(1, 6).to(int),\n            dim3=random(1, 6).to(int),\n        ).to(device)\n\n        oneflow_x = x.oneflow\n        pytorch_x = x.pytorch\n\n        # TODO(): implement backward\n        with flow.no_grad():\n            oneflow_y = oneflow_x.to(memory_format=torch.contiguous_format.oneflow)\n            pytorch_y = pytorch_x.to(memory_format=torch.contiguous_format.pytorch)\n            check_equal(oneflow_y, pytorch_y)\n\n            oneflow_y = oneflow_x.to(memory_format=torch.channels_last.oneflow)\n            pytorch_y = pytorch_x.to(memory_format=torch.channels_last.pytorch)\n            # Note: pytorch Tensor.to(channels_last) won't change tensor shape, so we should\n            #       permute it that only change the tensor shape and won't relayout its storage.\n            # TODO(): align with pytorch\n            check_equal(oneflow_y, pytorch_y.permute(0, 2, 3, 1))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "python/oneflow/test_utils/__init__.py",
    "content": ""
  },
  {
    "path": "python/oneflow/test_utils/automated_test_util/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom .generators import *\nfrom .torch_flow_dual_object import *\nfrom .torch_flow_dual_object import torch\nfrom .profiler import profile\nimport os\n"
  },
  {
    "path": "python/oneflow/test_utils/automated_test_util/generators.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport inspect\nimport os\nimport random as random_util\nimport typing\nfrom collections import namedtuple\nfrom typing import Any, Dict, Optional, Tuple, Sequence, Union\nfrom itertools import product\n\nimport numpy as np\nimport torch\n\nimport oneflow as flow\n\nfrom .global_scope import *\nfrom .util import broadcast\n\npy_tuple = tuple\nNoneType = type(None)\n\nTEST_MODULE = 0\nTEST_FLOW = 1\nTEST_TENSOR = 2\nrng = np.random.default_rng()\nannotation2default_generator = {}\nannotation2torch_to_flow_converter = {}\nNoneType = type(None)\nrandom_value_default_range = {int: (-10, 11), float: (-1, 1), complex: (-10, 10)}\n\n\ndef data_generator(annotation):\n    def register_data_generator(cls):\n        annotation2default_generator[annotation] = lambda: cls()\n        return cls\n\n    return register_data_generator\n\n\ndef torch_to_flow_converter(annotation):\n    def register_flow_to_flow_converter(func):\n        annotation2torch_to_flow_converter[annotation] = func\n        return func\n\n    return register_flow_to_flow_converter\n\n\n@torch_to_flow_converter(torch.Tensor)\ndef tensor_converter(torch_tensor):\n    return flow.tensor(torch_tensor.cpu().numpy())\n\n\ndef convert_torch_object_to_flow(x):\n    for (annotation, converter) in annotation2torch_to_flow_converter.items():\n        if isinstance(x, annotation):\n            return converter(x)\n    return x\n\n\ndef pack(x):\n    if isinstance(x, generator):\n        return x\n    return constant(x)\n\n\nclass Nothing:\n    pass\n\n\nclass generator:\n    def __init__(self, children):\n        self.children = children\n        self._value = None\n        self._has_value = False\n\n    def _init(self):\n        self._value = None\n        self._has_value = False\n        for x in self.children:\n            x._init()\n\n    def eval(self):\n        self._init()\n        return self.value()\n\n    def _calc_value(self):\n        raise NotImplementedError()\n\n    def value(self):\n        if not self._has_value:\n            self._value = self._calc_value()\n            if is_global():\n                self._value = broadcast(self._value)\n            self._has_value = True\n        return self._value\n\n    def size(self):\n        return 1\n\n    def __or__(self, other):\n        other = pack(other)\n        return oneof(\n            self, other, possibility=self.size() / (self.size() + other.size())\n        )\n\n    def __ror__(self, other):\n        return self | other\n\n    def __add__(self, other):\n        return add(self, other)\n\n    def __radd__(self, other):\n        return self + other\n\n    def __sub__(self, other):\n        return self + neg(other)\n\n    def __rsub__(self, other):\n        return neg(self - other)\n\n    def __mul__(self, other):\n        return mul(self, other)\n\n    def __rmul__(self, other):\n        return self * other\n\n    def to(self, annotation):\n        
self._to(annotation)\n        for x in self.children:\n            x.to(annotation)\n        return self\n\n    def _to(self, annotation):\n        pass\n\n\nclass add(generator):\n    def __init__(self, a, b):\n        self.a = pack(a)\n        self.b = pack(b)\n        super().__init__([self.a, self.b])\n\n    def _calc_value(self):\n        return self.a.value() + self.b.value()\n\n\nclass mul(generator):\n    def __init__(self, a, b):\n        self.a = pack(a)\n        self.b = pack(b)\n        super(mul, self).__init__([self.a, self.b])\n\n    def _calc_value(self):\n        return self.a.value() * self.b.value()\n\n\nclass neg(generator):\n    def __init__(self, a):\n        self.a = pack(a)\n        super().__init__([self.a])\n\n    def _calc_value(self):\n        return -self.a.value()\n\n\nclass oneof(generator):\n    def __init__(self, *args, possibility=None):\n        self.args = list(map(pack, args))\n        super().__init__(self.args)\n        if isinstance(possibility, float):\n            assert len(args) == 2\n            possibility = [possibility, 1 - possibility]\n        if possibility is None:\n            possibility = [1 / len(args)] * len(args)\n        self.possibility = pack(possibility)\n\n    def _calc_value(self):\n        rand = rng.random()\n        sum = 0\n        for (i, possibility) in enumerate(self.possibility.value()):\n            sum += possibility\n            if sum > rand:\n                return self.args[i].value()\n        raise RuntimeError()\n\n    def __call__(self, *args: Any, **kwds: Any) -> Any:\n        return self._calc_value()(*args, **kwds)\n\n    def size(self):\n        return sum([x.size() for x in self.args])\n\n\nclass tuple(generator):\n    def __init__(self, *args):\n        self.args = list(map(pack, args))\n        super().__init__(self.args)\n\n    def _calc_value(self):\n        return py_tuple([x.value() for x in self.args])\n\n\nclass constant(generator):\n    def __init__(self, x):\n        super().__init__([])\n        self.x = x\n\n    def _calc_value(self):\n        return self.x\n\n\nclass nothing(generator):\n    def __init__(self):\n        super().__init__([])\n\n    def _calc_value(self):\n        return Nothing()\n\n\nclass random(generator):\n    def __init__(self, low=1, high=6):\n        self.low = pack(low)\n        self.high = pack(high)\n        super().__init__([self.low, self.high])\n        self.annotation = None\n\n    def _to(self, annotation):\n        if self.annotation is not None:\n            return\n        if hasattr(annotation, \"__origin__\"):\n            annotation = eval(repr(annotation))\n        self.annotation = annotation\n\n    def _generate(self, annotation):\n        if hasattr(annotation, \"__origin__\"):\n            if annotation.__origin__ is Union:\n                x = random_util.choice(annotation.__args__)\n                return self._generate(x)\n            if annotation.__origin__ is Tuple or annotation.__origin__ is py_tuple:\n                return [self._generate(x) for x in annotation.__args__]\n            else:\n                raise NotImplementedError(\n                    f\"Not implemented annotation {annotation} in random, type(annotation.__origin__) is {type(annotation.__origin__)}\"\n                )\n        (low, high) = (self.low.value(), self.high.value())\n        if annotation == int:\n            val = int(rng.integers(low, high))\n        elif annotation == float:\n            val = float(rng.random() * (high - low) + low)\n        elif annotation 
== bool:\n            val = random_util.choice([True, False])\n        elif annotation == complex:\n            val_real = float(rng.random() * (high - low) + low)\n            val_imag = float(rng.random() * (high - low) + low)\n            val = val_real + 1.0j * val_imag\n        elif annotation is None:\n            val = None\n        elif annotation is NoneType:\n            val = None\n        else:\n            raise NotImplementedError(\n                f\"Not implemented annotation {annotation} in random\"\n            )\n        return val\n\n    def _calc_value(self):\n        return self._generate(self.annotation)\n\n\ndef random_or_nothing(low, high):\n    return oneof(random(low, high), nothing(), possibility=2 / 3)\n\n\n@data_generator(torch.Tensor)\nclass random_pytorch_tensor(generator):\n    def __init__(\n        self,\n        ndim=None,\n        dim0=1,\n        dim1=None,\n        dim2=None,\n        dim3=None,\n        dim4=None,\n        low=None,\n        high=None,\n        dtype=float,\n        pin_memory=False,\n    ):\n        if ndim is None:\n            ndim = random(1, 6)\n        if dim0 is None:\n            dim0 = random(1, 8)\n        if dim1 is None:\n            dim1 = random(1, 8)\n        if dim2 is None:\n            dim2 = random(1, 8)\n        if dim3 is None:\n            dim3 = random(1, 8)\n        if dim4 is None:\n            dim4 = random(1, 8)\n        self.ndim = pack(ndim).to(int)\n        self.dim0 = pack(dim0).to(int)\n        self.dim1 = pack(dim1).to(int)\n        self.dim2 = pack(dim2).to(int)\n        self.dim3 = pack(dim3).to(int)\n        self.dim4 = pack(dim4).to(int)\n        self.low = pack(low).to(float)\n        self.high = pack(high).to(float)\n        self.dtype = pack(dtype)\n        self.pin_memory = pin_memory\n        super().__init__(\n            [\n                self.ndim,\n                self.dim0,\n                self.dim1,\n                self.dim2,\n                self.dim3,\n                self.dim4,\n                self.low,\n                self.high,\n                self.dtype,\n                self.pin_memory,\n            ]\n        )\n\n    def _calc_value(self):\n        ndim = self.ndim.value()\n        dim0 = self.dim0.value()\n        dim1 = self.dim1.value()\n        dim2 = self.dim2.value()\n        dim3 = self.dim3.value()\n        dim4 = self.dim4.value()\n        dtype = self.dtype.value()\n        low = self.low.value()\n        high = self.high.value()\n        if low is None:\n            low = random_value_default_range[dtype][0]\n        if high is None:\n            high = random_value_default_range[dtype][1]\n        pin_memory = self.pin_memory\n\n        shape = rng.integers(low=1, high=8, size=ndim)\n        if ndim == 0:\n            shape = []\n        if ndim >= 1 and dim0 is not None:\n            shape[0] = dim0\n        if ndim >= 2:\n            shape[1] = dim1\n        if ndim >= 3:\n            shape[2] = dim2\n        if ndim >= 4:\n            shape[3] = dim3\n        if ndim == 5:\n            shape[4] = dim4\n\n        pytorch_tensor = None\n        if dtype == float:\n            np_arr = rng.uniform(low=low, high=high, size=shape)\n            res = torch.Tensor(np_arr)\n            if pin_memory:\n                res = res.pin_memory()\n            return res\n        elif dtype == int:\n            np_arr = rng.integers(low=low, high=high, size=shape)\n            res = torch.tensor(np_arr, dtype=torch.int64)\n            if pin_memory:\n                res = 
res.pin_memory()\n            return res\n        elif dtype == complex:\n            np_arr = rng.uniform(low=low, high=high, size=shape) + 1.0j * rng.uniform(\n                low=low, high=high, size=shape\n            )\n            res = torch.tensor(np_arr, dtype=torch.complex64)\n            if pin_memory:\n                res = res.pin_memory()\n            return res\n        else:\n            raise NotImplementedError(f\"Not implemented dtype {dtype} in random\")\n\n\n@data_generator(bool)\ndef random_bool():\n    return random().to(bool)\n\n\nclass random_device(generator):\n    def __init__(self):\n        super().__init__([])\n\n    def _calc_value(self):\n        if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            return \"cpu\"\n        else:\n            return random_util.choice([\"cuda\", \"cpu\"])\n\n\nclass cpu_device(generator):\n    def __init__(self):\n        super().__init__([])\n\n    def _calc_value(self):\n        return random_util.choice([\"cpu\"])\n\n\nclass gpu_device(generator):\n    def __init__(self):\n        super().__init__([])\n\n    def _calc_value(self):\n        return random_util.choice([\"cuda\"])\n\n\n@data_generator(torch.dtype)\nclass random_pytorch_dtype(generator):\n    none_dtype_seq = [None]\n    bool_dtype_seq = [torch.bool]\n    floating_dtype_seq = [torch.float, torch.double]\n    half_dtype_seq = [torch.half]\n    bfloat16_dtype_seq = [torch.bfloat16]\n    complex_dtype_seq = [torch.complex64, torch.complex128]\n    signed_int_dtype_seq = [torch.int8, torch.int32, torch.int64]\n    unsigned_int_dtype_seq = [torch.uint8]\n    int_dtype_seq = [torch.int8, torch.int32, torch.int64]\n    image_dtype_seq = [torch.uint8, torch.float]\n    index_dtype_seq = [torch.int32, torch.int64]\n    arithmetic_dtype_seq = [*floating_dtype_seq, *int_dtype_seq]\n    pod_dtype_seq = [*arithmetic_dtype_seq, *unsigned_int_dtype_seq, *bool_dtype_seq]\n    all_dtype_seq = [*arithmetic_dtype_seq, torch.half, torch.bfloat16]\n\n    seq_name_to_seq = {\n        \"None\": none_dtype_seq,\n        \"bool\": bool_dtype_seq,\n        \"float\": floating_dtype_seq,\n        \"half\": half_dtype_seq,\n        \"bfloat16\": bfloat16_dtype_seq,\n        \"complex\": complex_dtype_seq,\n        \"signed\": signed_int_dtype_seq,\n        \"unsigned\": unsigned_int_dtype_seq,\n        \"int\": int_dtype_seq,\n        \"image\": image_dtype_seq,\n        \"index\": index_dtype_seq,\n        \"arithmetic\": arithmetic_dtype_seq,\n        \"pod\": pod_dtype_seq,\n        \"all\": all_dtype_seq,\n    }\n\n    def __init__(self, seq_names):\n        super().__init__([])\n        # concat related dtype_seq for name in seq_names\n        self.data_type_seq = [\n            dtype for name in seq_names for dtype in self.seq_name_to_seq[name]\n        ]\n\n    def _calc_value(self):\n        return random_util.choice(self.data_type_seq)\n\n\nclass all_placement(generator):\n    def __init__(self):\n        super().__init__([])\n        self.node_size = flow.env.get_node_size()\n        self.world_size = flow.env.get_world_size()\n        self.num_rank_for_each_node = self.world_size // self.node_size\n\n    def __len__(self):\n        return len(self.value())\n\n    def __getitem__(self, key):\n        return self.value()[key]\n\n    def _calc_device(self):\n        if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            return [\n                \"cpu\",\n            ]\n        else:\n            return [\"cuda\", \"cpu\"]\n\n    def _calc_all_placement(self):\n        all_device = 
self._calc_device()\n        all_hierarchy = [\n            (self.world_size,),\n            (self.node_size, self.num_rank_for_each_node),\n        ]\n        return [\n            flow.placement(device, np.array(range(self.world_size)).reshape(hierarchy))\n            for device, hierarchy in list(product(all_device, all_hierarchy))\n        ]\n\n    def _calc_value(self):\n        return self._calc_all_placement()\n\n\nclass all_cpu_placement(all_placement):\n    def __init__(self):\n        super().__init__()\n\n    def _calc_device(self):\n        return [\"cpu\"]\n\n\nclass all_cuda_placement(all_placement):\n    def __init__(self):\n        super().__init__()\n\n    def _calc_device(self):\n        return [\"cuda\"]\n\n\nclass random_placement(all_placement):\n    def __init__(self):\n        super().__init__()\n\n    def _calc_value(self):\n        return random_util.choice(self._calc_all_placement())\n\n\nclass random_cpu_placement(random_placement):\n    def __init__(self):\n        super().__init__()\n\n    def _calc_device(self):\n        return [\"cpu\"]\n\n\nclass random_gpu_placement(random_placement):\n    def __init__(self):\n        super().__init__()\n\n    def _calc_device(self):\n        return [\"cuda\"]\n\n\nclass all_sbp(generator):\n    def __init__(\n        self,\n        placement=None,\n        dim=1,\n        max_dim=0,\n        except_split=False,\n        except_broadcast=False,\n        except_partial_sum=False,\n        valid_split_axis: Optional[Union[int, Sequence[int]]] = None,\n    ):\n        super().__init__([])\n        if placement is not None:\n            if isinstance(placement, random_placement):\n                self.dim = len(placement.value().ranks.shape)\n            elif isinstance(placement, flow.placement):\n                self.dim = len(placement.ranks.shape)\n            else:\n                raise RuntimeError(\n                    f\"placement should be instance of random_placement or oneflow.placement\"\n                )\n        else:\n            self.dim = dim\n        self.max_dim = max_dim\n        self.except_split = except_split\n        self.except_broadcast = except_broadcast\n        self.except_partial_sum = except_partial_sum\n        if valid_split_axis is not None:\n            if isinstance(valid_split_axis, int):\n                self.valid_split_axis = [\n                    valid_split_axis,\n                ]\n            else:\n                self.valid_split_axis = list(valid_split_axis)\n        else:\n            self.valid_split_axis = [i for i in range(self.max_dim)]\n\n    def __len__(self):\n        return len(self.value())\n\n    def __getitem__(self, key):\n        return self.value()[key]\n\n    def _calc_all_sbp(self):\n        # scalar only use broadcast sbp\n        if self.max_dim == 0:\n            return [\n                [flow.sbp.broadcast for i in range(self.dim)],\n            ]\n        all_sbps = []\n        if not self.except_split:\n            for i in range(self.max_dim):\n                if i in self.valid_split_axis:\n                    all_sbps.append(flow.sbp.split(i))\n        if not self.except_broadcast:\n            all_sbps.append(flow.sbp.broadcast)\n        if not self.except_partial_sum:\n            all_sbps.append(flow.sbp.partial_sum)\n        return list(product(all_sbps, repeat=self.dim))\n\n    def _calc_value(self):\n        return self._calc_all_sbp()\n\n\nclass random_sbp(all_sbp):\n    def __init__(\n        self,\n        placement=None,\n        dim=1,\n     
   max_dim=0,\n        except_split=False,\n        except_broadcast=False,\n        except_partial_sum=False,\n        valid_split_axis: Optional[Union[int, Sequence[int]]] = None,\n    ):\n        super().__init__(\n            placement,\n            dim,\n            max_dim,\n            except_split,\n            except_broadcast,\n            except_partial_sum,\n            valid_split_axis,\n        )\n\n    def _calc_value(self):\n        return random_util.choice(self._calc_all_sbp())\n\n\n@data_generator(torch.Tensor)\nclass choice_pytorch_tensor(generator):\n    def __init__(self, a, size=None, replace=True, p=None, dtype=int):\n        self.a = a\n        self.size = size\n        self.replace = replace\n        self.p = p\n        self.dtype = dtype\n        super().__init__(\n            [self.a, self.size, self.replace, self.p, self.dtype,]\n        )\n\n    def _calc_value(self):\n        pytorch_tensor = None\n        np_arr = np.random.choice(self.a, self.size, self.replace, self.p)\n        torch_dtype = None\n        return torch.tensor(np_arr.astype(self.dtype))\n\n\n__all__ = [\n    \"random_pytorch_tensor\",\n    \"random_bool\",\n    \"random_device\",\n    \"random_pytorch_dtype\",\n    \"cpu_device\",\n    \"gpu_device\",\n    \"random_placement\",\n    \"random_cpu_placement\",\n    \"random_gpu_placement\",\n    \"all_placement\",\n    \"all_cpu_placement\",\n    \"all_cuda_placement\",\n    \"random_sbp\",\n    \"all_sbp\",\n    \"random\",\n    \"random_or_nothing\",\n    \"oneof\",\n    \"constant\",\n    \"nothing\",\n    \"choice_pytorch_tensor\",\n]\n"
  },
  {
    "path": "python/oneflow/test_utils/automated_test_util/global_scope.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n_global_is_global = False\n\n\nclass GlobalScope:\n    def __init__(self):\n        pass\n\n    def __enter__(self, *argc, **kwarg):\n        global _global_is_global\n        self.last_is_global = _global_is_global\n        _global_is_global = True\n\n    def __exit__(self, *argc, **kwarg):\n        global _global_is_global\n        _global_is_global = self.last_is_global\n\n\ndef is_global():\n    return _global_is_global\n"
  },
  {
    "path": "python/oneflow/test_utils/automated_test_util/profiler.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport functools\nimport os\nfrom typing import Any, Callable, Iterable, List, Optional, Tuple\n\nimport torch\nimport oneflow as flow\nimport oneflow.support.env_var_util\nfrom oneflow.test_utils.automated_test_util import (\n    torch_flow_dual_object as dual_object_module,\n)\n\n__all__ = [\"profile\", \"set_profiler_hook\", \"profile_dual_object\", \"profiled_framework\"]\n\n\ndef compose(*fs):\n    def compose2(f, g):\n        return lambda *a, **kw: f(g(*a, **kw))\n\n    return functools.reduce(compose2, fs)\n\n\nclass ProfResult:\n    def __init__(\n        self,\n        prof,\n        num,\n        kind,\n        device,\n        thread_num,\n        op_name,\n        args_description,\n        additional_description=None,\n    ):\n        self.prof = prof\n        self.num = num\n        self.kind = kind\n        self.device = device\n        self.thread_num = thread_num\n        self.op_name = op_name\n        self.args_description = args_description\n        self.additional_description = additional_description\n\n    def __getattr__(self, attr):\n        return getattr(self.prof, attr)\n\n\nWARMUP_NUM = int(os.getenv(\"ONEFLOW_PROFILE_WARMUP_NUM\", 10))\nRUN_NUM = int(os.getenv(\"ONEFLOW_PROFILE_RUN_NUM\", 1000))\nPROF_VERBOSE = flow.support.env_var_util.parse_boolean_from_env(\n    \"ONEFLOW_PROFILE_VERBOSE\", False\n)\nEND_TO_END = \"end-to-end\"\n\n\ndef run_torch(\n    op,\n    args,\n    kwargs,\n    device,\n    num_threads,\n    op_name,\n    args_description,\n    additional_description=None,\n):\n    assert device in [\"cpu\", \"cuda\"]\n    if device == \"cpu\":\n        torch.set_num_threads(num_threads)\n        assert torch.get_num_threads() == num_threads\n        activities = [torch.profiler.ProfilerActivity.CPU]\n    else:\n        activities = [torch.profiler.ProfilerActivity.CUDA]\n\n    def tensor_to_device(x):\n        if isinstance(x, torch.Tensor):\n            return x.to(device)\n        return x\n\n    args = [tensor_to_device(arg) for arg in args]\n    kwargs = {k: tensor_to_device(v) for k, v in kwargs.items()}\n    for _ in range(WARMUP_NUM):\n        op(*args, **kwargs)\n\n    if PROF_VERBOSE:\n        print(\n            f'PyTorch ({f\"CPU, num_threads={num_threads}\" if device == \"cpu\" else \"GPU\"}):'\n        )\n    with torch.profiler.profile(activities=activities) as prof:\n        with torch.profiler.record_function(END_TO_END):\n            for _ in range(RUN_NUM):\n                op(*args, **kwargs)\n\n    if PROF_VERBOSE:\n        print(prof.key_averages().table(row_limit=10))\n    return ProfResult(\n        prof,\n        RUN_NUM,\n        \"PyTorch\",\n        device,\n        num_threads,\n        op_name,\n        args_description,\n        additional_description,\n    )\n\n\ndef run_flow(\n    op,\n    args,\n    kwargs,\n    device,\n    num_threads,\n    op_name,\n    args_description,\n    
additional_description=None,\n):\n    assert device in [\"cpu\", \"cuda\"]\n    if device == \"cpu\":\n        # NOTE: there is no flow.get_num_threads()\n        flow.set_num_threads(num_threads)\n        activities = [flow.profiler.ProfilerActivity.CPU]\n    else:\n        activities = [flow.profiler.ProfilerActivity.CUDA]\n\n    def tensor_to_device(x):\n        if isinstance(x, flow.Tensor):\n            return x.to(device)\n        return x\n\n    args = [tensor_to_device(arg) for arg in args]\n    kwargs = {k: tensor_to_device(v) for k, v in kwargs.items()}\n    for _ in range(WARMUP_NUM):\n        op(*args, **kwargs)\n\n    if PROF_VERBOSE:\n        print(\n            f'OneFlow ({f\"CPU, num_threads={num_threads}\" if device == \"cpu\" else \"GPU\"}):'\n        )\n    with flow.profiler.profile(\n        activities=activities,\n        record_bandwidth_for_cuda=flow.profiler.ProfilerActivity.CUDA in activities,\n    ) as prof:\n        with flow.profiler.record_function(END_TO_END):\n            for _ in range(RUN_NUM):\n                op(*args, **kwargs)\n\n    if PROF_VERBOSE:\n        print(prof.key_averages())\n    return ProfResult(\n        prof,\n        RUN_NUM,\n        \"OneFlow\",\n        device,\n        num_threads,\n        op_name,\n        args_description,\n        additional_description,\n    )\n\n\ndef profile_dual_object(op):\n    assert isinstance(op, dual_object_module.DualObject)\n    torch_op = op.pytorch\n    flow_op = op.oneflow\n\n    def profiled_op(*args, **kwargs):\n        if \"profile_description\" in kwargs:\n            additional_description = kwargs[\"profile_description\"]\n            del kwargs[\"profile_description\"]\n        else:\n            additional_description = None\n\n        (\n            torch_args,\n            torch_kwargs,\n            flow_args,\n            flow_kwargs,\n        ) = dual_object_module.get_args(torch_op, *args, **kwargs)\n\n        op_name = dual_object_module.to_string(op)\n        args_description = dual_object_module.to_string(*args, **kwargs)\n\n        result = []\n        for hardware_info in _hardware_info_list:\n            if \"oneflow\" in profiled_framework:\n                result.append(\n                    run_flow(\n                        flow_op,\n                        flow_args,\n                        flow_kwargs,\n                        *hardware_info,\n                        op_name,\n                        args_description,\n                        additional_description,\n                    )\n                )\n            else:\n                result.append(None)\n        for hardware_info in _hardware_info_list:\n            if \"pytorch\" in profiled_framework:\n                result.append(\n                    run_torch(\n                        torch_op,\n                        torch_args,\n                        torch_kwargs,\n                        *hardware_info,\n                        op_name,\n                        args_description,\n                        additional_description,\n                    )\n                )\n            else:\n                result.append(None)\n        return _profiler_hook(result)\n\n    return profiled_op\n\n\nHardwareInfo = Tuple[str, Optional[int]]  # (device_type, num_threads)\n_hardware_info_list: List[HardwareInfo] = [(\"cpu\", 1), (\"cuda\", None)]\n_profiler_hook: Callable[[List[ProfResult]], Any] = lambda x: x\nprofiled_framework: List[str] = [\"oneflow\", \"pytorch\"]\n\n\ndef 
set_hardware_info_list(hardware_info_list: List[HardwareInfo]) -> None:\n    global _hardware_info_list\n    _hardware_info_list = hardware_info_list\n\n\ndef set_profiler_hook(hook: Callable[[List[ProfResult]], Any]) -> None:\n    global _profiler_hook\n    _profiler_hook = hook\n\n\ndef profile(op):\n    def deco(f):\n        def new_f(*args, **kwargs):\n            dual_object_module.profiled_method_name.append(op.name)\n            res = f(*args, **kwargs)\n            dual_object_module.profiled_method_name.pop()\n            return res\n\n        return new_f\n\n    return deco\n"
  },
  {
    "path": "python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport collections.abc\nimport functools\nimport inspect\nimport copy\nimport os\nimport warnings\nimport gc\nfrom typing import Union\n\nimport numpy as np\nimport oneflow as flow\nfrom oneflow.test_utils.automated_test_util import profiler as auto_profiler\nfrom oneflow.test_utils.test_util import type_name_to_flow_type\n\nflow.backends.cudnn.deterministic = True\n\ntry:\n    import torch as torch_original\n\n    torch_original.backends.cudnn.deterministic = True\n    torch_original.set_printoptions(profile=\"full\")\nexcept ImportError:\n    print(\n        \"automated_test_util module uses PyTorch to verify OneFlow module's interface and result. Please install Pytorch according `https://pytorch.org/get-started/locally/`.\"\n    )\n\n\nfrom .util import broadcast\nfrom .global_scope import *\nfrom .generators import (\n    Nothing,\n    generator,\n    random_pytorch_tensor,\n    random_pytorch_dtype,\n    choice_pytorch_tensor,\n    rng,\n)\n\npostulate = [\".rand\", \".Tensor\"]\n\ntesting = False\ntesting_graph = False\ntesting_complex = False\nglobal_check_allclose = True\nglobal_atol = 1e-5\nglobal_rtol = 1e-5\nglobal_backward = True\n\n\ndef torch_tensor_to_flow(x):\n    return flow.tensor(x.cpu().numpy())\n\n\nnote_pytorch_method_names = []\nnote_pytorch_args = []\nnote_pytorch_kwargs = []\nvis_tensor = []\nvis_parameters = {}\ncall_tensor_id = []\nextra_input_tensor = []\n\n\nclass PyTorchDoesNotSupportError(Exception):\n    def __init__(self, exc):\n        self.exc = exc\n\n    def __str__(self):\n        return repr(self)\n\n    def __repr__(self):\n        return f\"PyTorch error: {str(self.exc)}\"\n\n\nclass OneFlowGraphBuildOrRunError(Exception):\n    def __init__(self, exc):\n        self.exc = exc\n\n    def __str__(self):\n        return repr(self)\n\n    def __repr__(self):\n        return f\"OneFlow nn.Graph Build Or Run Error: {str(self.exc)}\"\n\n\nclass BothDoNotSupportError(Exception):\n    def __init__(self, th_exc, of_exc):\n        self.th_exc = th_exc\n        self.of_exc = of_exc\n\n    def __str__(self):\n        return repr(self)\n\n    def __repr__(self):\n        return f\"PyTorch error: {str(self.th_exc)}\\nOneFlow error: {str(self.of_exc)}\"\n\n\ncall_pytorch = None\n\n\ndef get_tensor_shape(call_pytorch):\n    shape_list = []\n    for i in range(len(call_pytorch.shape)):\n        shape_list.append(call_pytorch.shape[i])\n    return shape_list\n\n\ndef get_args(callable, *args, **kwargs):\n    try:\n        spec = inspect.getfullargspec(callable)\n        spec_args = spec.args\n        if spec_args[0] == \"self\":\n            del spec_args[0]\n        for (i, arg) in enumerate(args):\n            arg_name = spec_args[i]\n            annotation = spec.annotations[arg_name]\n            if isinstance(arg, generator):\n                arg.to(annotation)\n        for (arg_name, arg) in kwargs.items():\n            annotation = 
spec.annotations[arg_name]\n            if isinstance(arg, generator):\n                arg.to(annotation)\n    except Exception:\n        pass\n    (pytorch_args, pytorch_kwargs, oneflow_args, oneflow_kwargs) = ([], {}, [], {})\n\n    def get_pytorch_value(x):\n        if isinstance(x, DualObject):\n            return x.pytorch\n        return x\n\n    def get_oneflow_value(x):\n        if isinstance(x, DualObject):\n            return x.oneflow\n        return x\n\n    def get_generator_value(x):\n        if isinstance(x, generator):\n            return x.value()\n        return x\n\n    for arg in args:\n        # TODO: refine codes\n        if isinstance(arg, (tuple, list)):\n            pytorch_tuple_args = []\n            oneflow_tuple_args = []\n            for t in arg:\n                t = get_generator_value(t)\n                pytorch_tuple_args.append(get_pytorch_value(t))\n                oneflow_tuple_args.append(get_oneflow_value(t))\n            pytorch_args.append(tuple(pytorch_tuple_args))\n            oneflow_args.append(tuple(oneflow_tuple_args))\n        else:\n            arg = get_generator_value(arg)\n            pytorch_args.append(get_pytorch_value(arg))\n            oneflow_args.append(get_oneflow_value(arg))\n    for (key, value) in kwargs.items():\n        value = get_generator_value(value)\n        if isinstance(value, Nothing):\n            continue\n        pytorch_kwargs[key] = get_pytorch_value(value)\n        oneflow_kwargs[key] = get_oneflow_value(value)\n\n    new_pytorch_args = []\n    new_pytorch_kwargs = {}\n    for x in pytorch_args:\n        if isinstance(x, (tuple, list)):\n            # Render tensors as \"Tensor(shape)\" and join the elements with \", \".\n            parts = [\n                f\"Tensor({get_tensor_shape(t)})\"\n                if type(t) is torch_original.Tensor\n                else f\"{t}\"\n                for t in x\n            ]\n            new_pytorch_args.append(\"(\" + \", \".join(parts) + \")\")\n            continue\n        if type(x) is torch_original.Tensor:\n            new_pytorch_args.append(f\"Tensor({get_tensor_shape(x)})\")\n        else:\n            new_pytorch_args.append(x)\n    for key, value in pytorch_kwargs.items():\n        if type(value) is torch_original.Tensor:\n            new_pytorch_kwargs[key] = f\"Tensor({get_tensor_shape(value)})\"\n        else:\n            new_pytorch_kwargs[key] = value\n\n    if not isinstance(func, torch_original.nn.Module):\n        if isinstance(call_pytorch, torch_original.Tensor):\n            note_pytorch_method_names.append(\n                f\"Tensor({get_tensor_shape(call_pytorch)}).{func.__name__}\"\n            )\n        elif isinstance(call_pytorch, torch_original.nn.Module):\n            note_pytorch_method_names.append(f\"Module.{func.__name__}\")\n        else:\n            note_pytorch_method_names.append(f\"{func.__name__}\")\n    else:\n        note_pytorch_method_names.append(repr(func))\n\n    note_pytorch_args.append(new_pytorch_args)\n    note_pytorch_kwargs.append(new_pytorch_kwargs)\n\n    return (pytorch_args, pytorch_kwargs, oneflow_args, oneflow_kwargs)\n\n\n
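# Splitting sketch (illustrative comment, not part of the original module):\n# get_args unzips DualObject arguments into framework-specific tuples, e.g.\n# for a DualObject ``d`` wrapping a pair of tensors:\n#\n#   pt_args, pt_kwargs, of_args, of_kwargs = get_args(torch_original.add, d, 1)\n#   # pt_args == [d.pytorch, 1] and of_args == [d.oneflow, 1]\n\n\n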
def to_string(*args, **kwargs) -> str:\n    def _to_string(x):\n        if isinstance(x, DualObject):\n            return x.name\n        return str(x)\n\n    strs = []\n    if len(args) > 0:\n        strs.append(\", \".join([_to_string(arg) for arg in args]))\n    if len(kwargs) > 0:\n        strs.append(\", \".join([f\"{k}={_to_string(v)}\" for k, v in kwargs.items()]))\n    return \", \".join(strs)\n\n\ncounter = 0\nalign_exception = os.getenv(\"ONEFLOW_TEST_ALIGN_EXCEPTION\") is not None\n\n\ndef check_eager_graph_tensor(eager_res, graph_res):\n    if (\n        global_check_allclose\n        and isinstance(eager_res, flow.Tensor)\n        and isinstance(graph_res, flow.Tensor)\n    ):\n        equality_res = np.allclose(\n            eager_res.numpy(),\n            graph_res.numpy(),\n            rtol=global_rtol,\n            atol=global_atol,\n            equal_nan=True,\n        )\n        return equality_res\n    else:\n        return True\n\n\n# NOTE(lixiang): Deepcopy the input parameters in order to correctly test the inplace version of the op.\ndef get_args_copy(args, kwargs):\n    copy_args = []\n    for arg in args:\n        if flow.is_tensor(arg):\n            copy_arg = arg.clone().detach()\n        else:\n            copy_arg = copy.deepcopy(arg)\n        copy_args.append(copy_arg)\n    copy_kwargs = {}\n    for key, value in kwargs.items():\n        if flow.is_tensor(value):\n            copy_kwargs[key] = value.clone().detach()\n        else:\n            copy_kwargs[key] = copy.deepcopy(value)\n    return copy_args, copy_kwargs\n\n\ndef get_fake_program_more_detail(oneflow, mode, func, args=None, kwargs=None):\n    print(f\"\\033[1;33m============= {mode} ================\\033[1;33m\")\n    print(f\"\\033[1;33mEnter {func} function\\033[1;33m\")\n    try:\n        if \"__self__\" in dir(oneflow) and flow.is_tensor(oneflow.__self__):\n            print(f\"\\033[1;33m{oneflow.__self__}\\033[1;33m\")\n    except Exception:\n        if flow.is_tensor(oneflow):\n            print(f\"\\033[1;33m{oneflow}\\033[1;33m\")\n    if args is not None:\n        print(f\"\\033[1;33m{args}\\033[1;33m\")\n    if kwargs is not None:\n        print(f\"\\033[1;33m{kwargs}\\033[1;33m\")\n    print_note_fake_program()\n    print(f\"\\033[1;33mLeave {func} function\\033[1;33m\")\n    print(\"\\033[1;37m\\033[1;37m\")\n    print(\"\\n\\n\")\n\n\n# NOTE(lixiang): When the graph global test is executed, this func is used to get the device type.\ndef get_global_test_device(oneflow_args, oneflow_kwargs=None):\n    # The case when the op only takes kwargs.\n    if not oneflow_args:\n        return oneflow_kwargs[\"placement\"].type\n    # The case when the first positional argument is a tensor.\n    elif isinstance(oneflow_args[0], flow.Tensor):\n        return oneflow_args[0].placement.type\n    # The case when the first positional argument is a placement.\n    elif isinstance(oneflow_args[0], flow.placement):\n        return oneflow_args[0].type\n    # The case when the first positional argument is a tuple of tensors, e.g. test_0_dim_tensor.\n    elif isinstance(oneflow_args[0], tuple):\n        return oneflow_args[0][0].placement.type\n    # When oneflow_args[0] is an int or float, etc., fall back to the second argument.\n    else:\n        return oneflow_args[1].placement.type\n\n\n
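# Resolution sketch (illustrative comment, not part of the original module):\n#\n#   p = flow.placement(\"cuda\", ranks=[0])\n#   x = flow.ones(2, placement=p, sbp=flow.sbp.broadcast)\n#   get_global_test_device((x,))                   # -> \"cuda\", from the tensor\n#   get_global_test_device((), {\"placement\": p})   # -> \"cuda\", from kwargs\n\n\n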
# NOTE(lixiang): When oneflow is an nn.Module, build the following Graph for testing.\n#   graph_train_oneflow is a deepcopy of oneflow.\ndef get_module_graph_test(graph_train_oneflow, oneflow, verbose, oneflow_args, *args):\n    of_sgd = flow.optim.SGD(graph_train_oneflow.parameters(), lr=0.001, momentum=0.9)\n    graph_train_parameters_len = 0\n    for param in oneflow._parameters.values():\n        if param is not None:\n            graph_train_parameters_len += 1\n\n    if verbose:\n        get_fake_program_more_detail(\n            oneflow, \"nn.Graph\", \"get_module_graph_test\", oneflow_args\n        )\n\n    class TestGraphOfModule(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n            self.test_module = graph_train_oneflow\n            if global_backward and graph_train_parameters_len:\n                self.add_optimizer(of_sgd)\n\n        def build(self, *args):\n            res = self.test_module(*args)\n            forward_res = res\n            if global_backward and graph_train_parameters_len:\n                if isinstance(self.test_module.to(flow.nn.Module), flow.nn.LSTMCell):\n                    res = res[0] + res[1]\n                elif isinstance(self.test_module.to(flow.nn.Module), flow.nn.LSTM):\n                    res = res[0].sum() + res[1][0].sum() + res[1][1].sum()\n                elif isinstance(res, (tuple, list)):\n                    res = res[0]\n                res = res.sum()\n                res.backward()\n            return forward_res\n\n    try:\n        test_g = TestGraphOfModule()\n    except Exception as e:\n        if not verbose:\n            get_fake_program_more_detail(\n                oneflow, \"nn.Graph\", \"get_module_graph_test\", oneflow_args\n            )\n        raise OneFlowGraphBuildOrRunError(e)\n    return test_g\n\n\ndef check_oneflow_args_first_element_is_int(args):\n    if isinstance(args, (tuple, list)) and len(args) > 0:\n        if isinstance(args[0], (int, float)):\n            return True\n        elif isinstance(args[0], (tuple, list)):\n            return check_oneflow_args_first_element_is_int(args[0])\n    return False\n\n\n
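# e.g. (illustrative comment, not part of the original module): a leading\n# int/float means there is no leading tensor to read a placement from.\n#\n#   check_oneflow_args_first_element_is_int((3, x))       # True\n#   check_oneflow_args_first_element_is_int(((2, 3), x))  # True, recurses\n#   check_oneflow_args_first_element_is_int((x, 3))       # False for a tensor x\n\n\n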
# NOTE(lixiang): When oneflow is a functional op, build the following Graph for testing, and return the test result in Graph mode.\n#   graph_functional_oneflow is a deepcopy of oneflow.\ndef get_functional_graph_res(\n    graph_functional_oneflow,\n    oneflow,\n    oneflow_res,\n    oneflow_args,\n    oneflow_kwargs,\n    verbose,\n    *graph_args,\n    **graph_kwargs,\n):\n    test_g_res = []\n\n    if verbose:\n        get_fake_program_more_detail(\n            oneflow,\n            \"nn.Graph\",\n            \"get_functional_graph_res\",\n            oneflow_args,\n            oneflow_kwargs,\n        )\n\n    class TestGraphOfFunctional(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self):\n            return graph_functional_oneflow(*graph_args, **graph_kwargs)\n\n    try:\n        is_global_flag = is_global()\n\n        # In graph mode, calling to(\"cpu\") on a tensor that is already on the cpu triggers a check error.\n        if oneflow.__name__ == \"to\" or oneflow.__name__ == \"_to\":\n            if isinstance(oneflow_res, flow.Tensor):\n                # The global tensor needs to obtain the device type through placement.type.\n                if is_global_flag:\n                    if (\n                        oneflow_args and oneflow_res.placement.type == oneflow_args[0]\n                    ) or (\n                        oneflow_kwargs\n                        and oneflow_res.placement.type == oneflow_kwargs[\"device\"]\n                    ):\n                        test_g_res = oneflow_res\n                # The local tensor needs to obtain the device type through device.type.\n                else:\n                    if (\n                        oneflow_args and oneflow_res.device.type == oneflow_args[0]\n                    ) or (\n                        oneflow_kwargs\n                        and oneflow_res.device.type == oneflow_kwargs[\"device\"]\n                    ):\n                        test_g_res = oneflow_res\n            else:\n                pass\n        # nn.Graph does not handle the Module type, e.g. m.to_global(placement, sbp).\n        elif oneflow.__name__ == \"to_global\":\n            test_g_res = oneflow_res\n        elif oneflow.__name__ == \"Parameter\":\n            # nn.Graph does not handle Parameter creation.\n            test_g_res = oneflow_res\n        # oneflow_args may be empty, e.g. for dropout.\n        elif is_global_flag and len(oneflow_args) == 0:\n            test_g_res = oneflow_res\n        # For ops whose leading input parameters are ints, 'int' object has no attribute 'placement'.\n        elif (\n            is_global_flag\n            and len(oneflow_args) != 0\n            and (check_oneflow_args_first_element_is_int(oneflow_args))\n        ):\n            test_g_res = oneflow_res\n        # When doing the global op test, get_global_test_device() is executed, and the graph autotest is temporarily skipped on the cpu device.\n        elif (\n            is_global_flag\n            and oneflow.__name__ != \"weight_norm\"\n            and (get_global_test_device(oneflow_args, oneflow_kwargs) == \"cpu\")\n        ):\n            test_g_res = oneflow_res\n        else:\n            test_g = TestGraphOfFunctional()\n            test_g_res = test_g()\n    except Exception as e:\n        if not verbose:\n            get_fake_program_more_detail(\n                oneflow,\n                \"nn.Graph\",\n                \"get_functional_graph_res\",\n                oneflow_args,\n                oneflow_kwargs,\n            )\n        raise OneFlowGraphBuildOrRunError(e)\n    return test_g_res\n\n\n# NOTE(lixiang): When oneflow is a tensor method, build the following Graph for testing, and return the test result in Graph mode.\n#   graph_tensor_oneflow is a deepcopy of oneflow.\ndef get_tensor_graph_res(\n    graph_tensor_oneflow, oneflow, verbose, *tensor_graph_args, **tensor_graph_kwargs\n):\n    test_g_res = []\n\n    if verbose:\n        get_fake_program_more_detail(\n            oneflow,\n            \"nn.Graph\",\n            \"get_tensor_graph_res\",\n            tensor_graph_args,\n            tensor_graph_kwargs,\n        )\n\n    class TestGraphOfTensorMethod(flow.nn.Graph):\n        def __init__(self):\n            super().__init__()\n\n        def build(self):\n            return graph_tensor_oneflow(*tensor_graph_args, **tensor_graph_kwargs)\n\n    try:\n        # Set test_g_res = None so that check_eager_graph_tensor returns True; the purpose is to temporarily skip the Graph global test on cpu.\n        if is_global() and (get_global_test_device((oneflow,)) == \"cpu\"):\n            test_g_res = None\n        else:\n            test_g = TestGraphOfTensorMethod()\n            test_g_res = test_g()\n    except Exception as e:\n        if not verbose:\n            get_fake_program_more_detail(\n                oneflow,\n                \"nn.Graph\",\n                \"get_tensor_graph_res\",\n                tensor_graph_args,\n                tensor_graph_kwargs,\n            )\n        raise OneFlowGraphBuildOrRunError(e)\n    return test_g_res\n\n\n
\"cpu\"):\n            test_g_res = None\n        else:\n            test_g = TestGraphOfTensorMethod()\n            test_g_res = test_g()\n    except Exception as e:\n        if not verbose:\n            get_fake_program_more_detail(\n                oneflow,\n                \"nn.Graph\",\n                \"get_tensor_graph_res\",\n                tensor_graph_args,\n                tensor_graph_kwargs,\n            )\n        raise OneFlowGraphBuildOrRunError(e)\n    return test_g_res\n\n\ndef get_oneflow_eager_res(\n    oneflow, oneflow_args, oneflow_kwargs, verbose, is_tesnor_method=False\n):\n    if verbose:\n        get_fake_program_more_detail(\n            oneflow, \"Eager\", \"get_oneflow_eager_res\", oneflow_args, oneflow_kwargs\n        )\n    if not is_tesnor_method:\n        oneflow_res = oneflow(*oneflow_args, **oneflow_kwargs)\n    else:\n        oneflow_res = oneflow(*oneflow_args, **oneflow_kwargs)\n    return oneflow_res\n\n\n# NOTE(lixiang): Check if the results of eager and graph are equal when oneflow is of type nn.Module or functional.\ndef oneflow_eager_run_with_graph_check(\n    oneflow, oneflow_args, oneflow_kwargs, testing_graph, verbose, *args\n):\n    if testing_graph:\n        graph_args, graph_kwargs = get_args_copy(oneflow_args, oneflow_kwargs)\n\n        if isinstance(oneflow, flow.nn.Module):\n            graph_train_oneflow = copy.deepcopy(oneflow)\n            if not is_global():\n                arg_device_type = \"cpu\"\n                for arg in oneflow_args:\n                    if flow.is_tensor(arg):\n                        arg_device_type = arg.device.type\n                graph_train_oneflow = graph_train_oneflow.to(arg_device_type)\n\n        else:\n            graph_functional_oneflow = copy.deepcopy(oneflow)\n\n    oneflow_res = get_oneflow_eager_res(oneflow, oneflow_args, oneflow_kwargs, verbose)\n    if testing_graph:\n        find_check_module_func = True\n        ignore_apis_list = [\"tensor\", \"train\"]\n        test_g_res = []\n        if isinstance(oneflow, flow.nn.Module):\n            test_g = get_module_graph_test(\n                graph_train_oneflow, oneflow, verbose, oneflow_args, *args\n            )\n            # When doing the global op test, get_global_test_device() will be executed, and temporarily skipping the graph autotest on cpu device.\n            if is_global() and (\n                get_global_test_device(oneflow_args, oneflow_kwargs) == \"cpu\"\n            ):\n                test_g_res = oneflow_res\n            else:\n                # When testing module methods, kwargs are not considered.\n                test_g_res = test_g(*graph_args)\n        elif oneflow.__name__ in ignore_apis_list:\n            find_check_module_func = False\n        # 1. \"oneflow.nn.modules\" not in oneflow.__module__: For avoid run nn.Module branch graph test, like fold op call Fold Module actually.\n        # 2. inspect.isfunction(oneflow): Compared with the ordinary flow.xxx, oneflow.nn.modules.math_ops series op exist an extra layer of python wrapper.\n        # 3. 
# NOTE(lixiang): Check if the results of eager and graph are equal when oneflow is of type nn.Module or functional.\ndef oneflow_eager_run_with_graph_check(\n    oneflow, oneflow_args, oneflow_kwargs, testing_graph, verbose, *args\n):\n    if testing_graph:\n        graph_args, graph_kwargs = get_args_copy(oneflow_args, oneflow_kwargs)\n\n        if isinstance(oneflow, flow.nn.Module):\n            graph_train_oneflow = copy.deepcopy(oneflow)\n            if not is_global():\n                arg_device_type = \"cpu\"\n                for arg in oneflow_args:\n                    if flow.is_tensor(arg):\n                        arg_device_type = arg.device.type\n                graph_train_oneflow = graph_train_oneflow.to(arg_device_type)\n\n        else:\n            graph_functional_oneflow = copy.deepcopy(oneflow)\n\n    oneflow_res = get_oneflow_eager_res(oneflow, oneflow_args, oneflow_kwargs, verbose)\n    if testing_graph:\n        find_check_module_func = True\n        ignore_apis_list = [\"tensor\", \"train\"]\n        test_g_res = []\n        if isinstance(oneflow, flow.nn.Module):\n            test_g = get_module_graph_test(\n                graph_train_oneflow, oneflow, verbose, oneflow_args, *args\n            )\n            # When doing the global op test, get_global_test_device() is executed, and the graph autotest is temporarily skipped on the cpu device.\n            if is_global() and (\n                get_global_test_device(oneflow_args, oneflow_kwargs) == \"cpu\"\n            ):\n                test_g_res = oneflow_res\n            else:\n                # When testing module methods, kwargs are not considered.\n                test_g_res = test_g(*graph_args)\n        elif oneflow.__name__ in ignore_apis_list:\n            find_check_module_func = False\n        # 1. \"oneflow.nn.modules\" not in oneflow.__module__: avoid running the nn.Module branch of the graph test, e.g. the fold op actually calls the Fold Module.\n        # 2. inspect.isfunction(oneflow): ops in the oneflow.nn.modules.math_ops series have an extra layer of Python wrapper compared with the ordinary flow.xxx functions.\n        # 3. inspect.ismethod(oneflow) and \"oneflow.nn.modules\" in oneflow.__module__: for ops that only have a Tensor.xxx method but actually call oneflow.xxx, like masked_fill.\n        elif (\n            (\n                oneflow.__module__ is not None\n                and (\"oneflow.nn.modules\" not in oneflow.__module__)\n            )\n            or inspect.isfunction(oneflow)\n            or (\n                inspect.ismethod(oneflow) and \"oneflow.nn.modules\" in oneflow.__module__\n            )\n        ):\n\n            test_g_res = get_functional_graph_res(\n                graph_functional_oneflow,\n                oneflow,\n                oneflow_res,\n                oneflow_args,\n                oneflow_kwargs,\n                verbose,\n                *graph_args,\n                **graph_kwargs,\n            )\n        if find_check_module_func:\n            if isinstance(test_g_res, tuple):\n                for _, g_res in enumerate(test_g_res):\n                    if not check_eager_graph_tensor(oneflow_res, g_res):\n                        get_fake_program_more_detail(\n                            oneflow,\n                            \"Eager + nn.Graph\",\n                            \"oneflow_eager_run_with_graph_check\",\n                            oneflow_args,\n                            oneflow_kwargs,\n                        )\n            else:\n                if not check_eager_graph_tensor(oneflow_res, test_g_res):\n                    get_fake_program_more_detail(\n                        oneflow,\n                        \"Eager + nn.Graph\",\n                        \"oneflow_eager_run_with_graph_check\",\n                        oneflow_args,\n                        oneflow_kwargs,\n                    )\n    return oneflow_res\n\n\n# NOTE(lixiang): Check if the results of eager and graph are equal when oneflow is a tensor method.\ndef oneflow_tensor_eager_run_with_graph_check(\n    oneflow, oneflow_method, oneflow_args, oneflow_kwargs, testing_graph, verbose\n):\n    if testing_graph:\n        tensor_graph_args, tensor_graph_kwargs = get_args_copy(\n            oneflow_args, oneflow_kwargs\n        )\n        graph_tensor_oneflow = copy.deepcopy(oneflow_method)\n\n    oneflow_res = get_oneflow_eager_res(\n        oneflow_method, oneflow_args, oneflow_kwargs, verbose, is_tensor_method=True\n    )\n\n    if testing_graph:\n        test_g_res = get_tensor_graph_res(\n            graph_tensor_oneflow,\n            oneflow,\n            verbose,\n            *tensor_graph_args,\n            **tensor_graph_kwargs,\n        )\n\n        if isinstance(test_g_res, tuple):\n            for _, g_res in enumerate(test_g_res):\n                if not check_eager_graph_tensor(oneflow_res, g_res):\n                    get_fake_program_more_detail(\n                        oneflow,\n                        \"nn.Graph\",\n                        \"oneflow_tensor_eager_run_with_graph_check\",\n                        oneflow_args,\n                        oneflow_kwargs,\n                    )\n        else:\n            if not check_eager_graph_tensor(oneflow_res, test_g_res):\n                get_fake_program_more_detail(\n                    oneflow,\n                    \"nn.Graph\",\n                    \"oneflow_tensor_eager_run_with_graph_check\",\n                    oneflow_args,\n                    oneflow_kwargs,\n                )\n    return oneflow_res\n\n\ndef get_pytorch_oneflow_res(\n    pytorch,\n    oneflow,\n    pytorch_args,\n    pytorch_kwargs,\n   
 oneflow_args,\n    oneflow_kwargs,\n    name,\n    verbose,\n    testing_graph,\n    *args,\n):\n    try:\n        pytorch_res = pytorch(*pytorch_args, **pytorch_kwargs)\n\n        if isinstance(pytorch_res, torch_original.Tensor):\n            call_flag = True\n            source_flag = True\n            for x in pytorch_args:\n                if isinstance(x, (tuple, list)):\n                    for y in x:\n                        if torch_original.is_tensor(y):\n                            source_flag = False\n                            if (\n                                id(pytorch_res) == id(y)\n                                and pytorch_res.device.type == y.device.type\n                            ):\n                                call_flag = False\n                                break\n                elif torch_original.is_tensor(x):\n                    source_flag = False\n                    if (\n                        id(pytorch_res) == id(x)\n                        and pytorch_res.device.type == x.device.type\n                    ):\n                        call_flag = False\n                        break\n            for x in pytorch_kwargs.values():\n                if isinstance(x, (tuple, list)):\n                    for y in x:\n                        if torch_original.is_tensor(y):\n                            source_flag = False\n                            if (\n                                id(pytorch_res) == id(y)\n                                and pytorch_res.device.type == y.device.type\n                            ):\n                                call_flag = False\n                                break\n                elif torch_original.is_tensor(x):\n                    source_flag = False\n                    if (\n                        id(pytorch_res) == id(x)\n                        and pytorch_res.device.type == x.device.type\n                    ):\n                        call_flag = False\n                        break\n            if source_flag and pytorch.__name__ != \"to\":\n                call_tensor_id.append(id(pytorch_res))\n                extra_input_tensor.append(pytorch_res)\n            elif call_flag:\n                call_tensor_id.append(id(pytorch_res))\n\n    except Exception as e:\n        if align_exception:\n            try:\n                oneflow_res = oneflow(*oneflow_args, **oneflow_kwargs)\n            except Exception as ee:\n                raise BothDoNotSupportError(e, ee) from None\n            print(\n                \"PyTorch has an error but OneFlow is ok, maybe you should check your implementation to align with PyTorch.\"\n            )\n            get_fake_program_more_detail(\n                oneflow,\n                \"Eager\",\n                \"get_pytorch_oneflow_res\",\n                oneflow_args,\n                oneflow_kwargs,\n            )\n        raise PyTorchDoesNotSupportError(e)\n\n    if name in postulate:\n        oneflow_res = torch_tensor_to_flow(pytorch_res)\n    else:\n        oneflow_res = oneflow_eager_run_with_graph_check(\n            oneflow, oneflow_args, oneflow_kwargs, testing_graph, verbose, *args,\n        )\n    return pytorch_res, oneflow_res\n\n\ndef get_pytorch_oneflow_tensor_res(\n    pytorch_method,\n    oneflow_method,\n    oneflow,\n    pytorch_args,\n    pytorch_kwargs,\n    oneflow_args,\n    oneflow_kwargs,\n    testing_graph,\n    verbose,\n):\n    try:\n        pytorch_res = pytorch_method(*pytorch_args, **pytorch_kwargs)\n        if 
isinstance(pytorch_res, torch_original.Tensor):\n            if (\n                id(pytorch_res) != id(pytorch_method.__self__)\n                or pytorch_res.device.type == pytorch_method.__self__.device.type\n            ):\n                call_tensor_id.append(id(pytorch_res))\n    except Exception as e:\n        if align_exception:\n            try:\n                oneflow_res = oneflow_method(*oneflow_args, **oneflow_kwargs)\n            except Exception as ee:\n                raise BothDoNotSupportError(e, ee) from None\n            print(\n                \"PyTorch has an error but OneFlow is ok, maybe you should check your implementation to align with PyTorch.\"\n            )\n        raise PyTorchDoesNotSupportError(e)\n    oneflow_res = oneflow_tensor_eager_run_with_graph_check(\n        oneflow, oneflow_method, oneflow_args, oneflow_kwargs, testing_graph, verbose,\n    )\n    return pytorch_res, oneflow_res\n\n\nprofiled_method_name = []\n\n\ndef GetDualObject(name, pytorch, oneflow):\n    global counter\n    counter += 1\n    skipped_magic_methods = [\n        \"__class__\",\n        \"__mro__\",\n        \"__new__\",\n        \"__init__\",\n        \"__getattr__\",\n        \"__setattr__\",\n        \"__getattribute__\",\n        \"__dict__\",\n        \"__weakref__\",\n        \"__builtins__\",\n        \"__qualname__\",\n        \"__name__\",\n        \"__str__\",\n        \"__repr__\",\n    ]\n    verbose = os.getenv(\"ONEFLOW_TEST_VERBOSE\") is not None\n    pytorch_methods = dir(pytorch)\n    if hasattr(pytorch, \"__call__\") and \"__call__\" not in pytorch_methods:\n        pytorch_methods.append(\"__call__\")\n    magic_methods_for_new_cls = {}\n    for method_name in pytorch_methods:\n        if method_name.startswith(\"__\") and method_name not in skipped_magic_methods:\n\n            def get_dual_method(method_name):\n                if method_name == \"__call__\":\n\n                    if name in profiled_method_name:\n\n                        def method(self, *args, **kwargs):\n                            return auto_profiler.profile_dual_object(self)(\n                                *args, **kwargs\n                            )\n\n                        return method\n\n                    def dual_method(self, *args, **kwargs):\n                        param_str = to_string(*args, **kwargs)\n                        (\n                            pytorch_args,\n                            pytorch_kwargs,\n                            oneflow_args,\n                            oneflow_kwargs,\n                        ) = get_args(pytorch, *args, **kwargs)\n\n                        pytorch_res, oneflow_res = get_pytorch_oneflow_res(\n                            pytorch,\n                            oneflow,\n                            pytorch_args,\n                            pytorch_kwargs,\n                            oneflow_args,\n                            oneflow_kwargs,\n                            name,\n                            verbose,\n                            testing_graph,\n                            *args,\n                        )\n                        return GetDualObject(\n                            f\"{name}({param_str})\", pytorch_res, oneflow_res\n                        )\n\n                else:\n\n                    def dual_method(self, *args, **kwargs):\n                        pytorch_method = getattr(pytorch, method_name)\n                        oneflow_method = getattr(oneflow, method_name)\n                   
     (\n                            pytorch_args,\n                            pytorch_kwargs,\n                            oneflow_args,\n                            oneflow_kwargs,\n                        ) = get_args(pytorch_method, *args, **kwargs)\n                        pytorch_res, oneflow_res = get_pytorch_oneflow_tensor_res(\n                            pytorch_method,\n                            oneflow_method,\n                            oneflow,\n                            pytorch_args,\n                            pytorch_kwargs,\n                            oneflow_args,\n                            oneflow_kwargs,\n                            testing_graph,\n                            verbose,\n                        )\n                        return GetDualObject(\"unused\", pytorch_res, oneflow_res)\n\n                return dual_method\n\n            magic_methods_for_new_cls[method_name] = get_dual_method(method_name)\n    Cls = type(f\"{name}_{counter}\", (DualObject,), magic_methods_for_new_cls)\n    return Cls(name, pytorch, oneflow)\n\n\n
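# Dispatch sketch (illustrative comment, not part of the original module):\n# every attribute access on a DualObject yields another DualObject whose\n# generated magic methods run both frameworks in lockstep, e.g.\n#\n#   dual = GetDualObject(\"x\", torch_original.ones(2), flow.ones(2))\n#   y = dual.sum()          # runs the PyTorch and OneFlow .sum methods\n#   y.pytorch, y.oneflow    # the paired results\n\n\n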
def note_print_args(x, end=True):\n    # `end=True` means more arguments follow, so print a trailing comma.\n    if end:\n        print(f\"\\033[32m{x}, \\033[0m\", end=\"\")\n    else:\n        print(f\"\\033[32m{x}\\033[0m\", end=\"\")\n\n\ndef note_print_kwargs(x, y, end=True):\n    # `end=True` means more keyword arguments follow, so print a trailing comma.\n    if end:\n        print(f\"\\033[32m{x}={y}, \\033[0m\", end=\"\")\n    else:\n        print(f\"\\033[32m{x}={y}\\033[0m\", end=\"\")\n\n\ndef print_note_fake_program(detail=False):\n    code_len = len(note_pytorch_method_names)\n    for i in range(code_len):\n        note_pytorch_args_len = len(note_pytorch_args[i])\n        note_pytorch_kwargs_len = len(note_pytorch_kwargs[i])\n        print(f\"\\033[32m{note_pytorch_method_names[i]}\\033[0m\", end=\"\")\n        print(\"\\033[32m(\\033[0m\", end=\"\")\n        if note_pytorch_args[i]:\n            index = 0\n            for x in note_pytorch_args[i]:\n                index += 1\n                note_print_args(x, index < note_pytorch_args_len)\n\n        if note_pytorch_kwargs[i]:\n            index = 0\n            if note_pytorch_args[i]:\n                print(\"\\033[32m, \\033[0m\", end=\"\")\n            for x in note_pytorch_kwargs[i].keys():\n                index += 1\n                note_print_kwargs(\n                    x, note_pytorch_kwargs[i][x], index < note_pytorch_kwargs_len\n                )\n        print(\"\\033[32m)\\033[0m\")\n    if detail:\n        print(\n            \"\\033[32m-----------------------------------------------------------\\033[0m\"\n        )\n        # Deduplicate the recorded tensors by object identity.\n        unique_vis_tensor = []\n        flag_vis_input_tensor = [False for _ in range(len(vis_tensor))]\n        for i in range(len(vis_tensor)):\n            if flag_vis_input_tensor[i]:\n                continue\n            unique_vis_tensor.append(vis_tensor[i])\n            flag_vis_input_tensor[i] = True\n            for j in range(i + 1, len(vis_tensor)):\n                if (\n                    id(vis_tensor[i]) == id(vis_tensor[j])\n                    and not flag_vis_input_tensor[j]\n                ):\n                    flag_vis_input_tensor[j] = True\n        unique_extra_tensor = []\n        flag_vis_extra_tensor = [False for _ in range(len(extra_input_tensor))]\n        for i in range(len(extra_input_tensor)):\n            if flag_vis_extra_tensor[i]:\n                continue\n            unique_extra_tensor.append(extra_input_tensor[i])\n            flag_vis_extra_tensor[i] = True\n            for j in range(i + 1, len(extra_input_tensor)):\n                if (\n                    id(extra_input_tensor[i]) == id(extra_input_tensor[j])\n                    and not flag_vis_extra_tensor[j]\n                ):\n                    flag_vis_extra_tensor[j] = True\n\n        print(\n            f\"\\033[32mThis program has {len(unique_extra_tensor) + len(unique_vis_tensor)} input tensors: \\033[0m\"\n        )\n        for input_tensor in iter(unique_extra_tensor):\n            print(f\"\\033[32mShape{get_tensor_shape(input_tensor)}\\033[0m\")\n            print(f\"\\033[32m{input_tensor}\\033[0m\")\n            print(\n                \"\\033[32m-----------------------------------------------------------\\033[0m\"\n            )\n        for input_tensor in iter(unique_vis_tensor):\n            print(f\"\\033[32mShape{get_tensor_shape(input_tensor)}\\033[0m\")\n            print(f\"\\033[32m{input_tensor}\\033[0m\")\n            print(\n                \"\\033[32m-----------------------------------------------------------\\033[0m\"\n            )\n        if vis_parameters:\n            print(\n                \"\\033[32m-------------------nn.Module Parameters---------------------\\033[0m\"\n            )\n            for name, param in vis_parameters.items():\n                print(f\"\\033[32m{name}: {param}\\033[0m\")\n\n\ndef clear_note_fake_program():\n    note_pytorch_method_names.clear()\n    note_pytorch_args.clear()\n    note_pytorch_kwargs.clear()\n    call_tensor_id.clear()\n    vis_tensor.clear()\n    vis_parameters.clear()\n    extra_input_tensor.clear()\n    flow.set_printoptions(profile=\"full\")\n\n\ntensor_size_limit_mb = int(os.getenv(\"ONEFLOW_TEST_TENSOR_SIZE_LIMIT_MB\", 32))\n\n\nclass DualObject:\n    def __init__(self, name, pytorch, oneflow):\n        self.name = name\n        if isinstance(pytorch, torch_original.nn.Module):\n            if is_global():\n                pytorch.load_state_dict(broadcast(pytorch).state_dict())\n            state_dict = pytorch.state_dict()\n            state_dict = {k: v.detach().cpu().numpy() for (k, v) in state_dict.items()}\n            oneflow_state_dict = oneflow.state_dict()\n            oneflow_state_dict = {\n                k: v.detach() for (k, v) in oneflow_state_dict.items()\n            }\n            already_global = any(v.is_global for v in oneflow_state_dict.values())\n            if is_global() and already_global:\n                for k, v in state_dict.items():\n                    if k not in oneflow_state_dict:\n                        continue\n                    of_state = oneflow_state_dict[k]\n                    if of_state.is_global:\n                        state_dict[k] = flow.tensor(\n                            v, sbp=of_state.sbp, placement=of_state.placement\n                        )\n\n            oneflow.load_state_dict(state_dict, strict=False)\n\n            if is_global():\n                if already_global:\n                    for (k, v) in oneflow_state_dict.items():\n
                        if v.is_global:\n                            t = getattr(oneflow, k)\n                            new = t.to_global(placement=v.placement, sbp=v.sbp)\n                            if isinstance(t, flow.nn.Parameter):\n                                new = flow.nn.Parameter(new)\n                            setattr(\n                                oneflow, k, new,\n                            )\n                else:\n                    oneflow = oneflow.to_global(\n                        placement=flow.placement.all(\"cpu\"), sbp=[flow.sbp.broadcast,],\n                    )\n            if testing:\n                dual_modules_to_test.append(self)\n        if isinstance(pytorch, torch_original.Tensor):\n            tensor_size_mb = pytorch.nelement() * pytorch.element_size() / 1024 / 1024\n            assert (\n                tensor_size_mb < tensor_size_limit_mb\n            ), f\"Tensor memory in autotest cannot be larger than {tensor_size_limit_mb}MB, but got {tensor_size_mb}MB\"\n            if testing:\n                dual_objects_to_test.append(self)\n        self.pytorch = pytorch\n        self.oneflow = oneflow\n\n    def __repr__(self):\n        return f\"PyTorch object:\\n{self.pytorch}\\n\\nOneFlow object:\\n{self.oneflow}\"\n\n    def __getattr__(self, key):\n        if key in [\"to_global\", \"to_local\"]:\n\n            def identity(*args, **kwargs):\n                if isinstance(self.pytorch, torch_original.Tensor):\n                    return self.pytorch.clone()\n                return self.pytorch\n\n            pytorch_attr = identity\n        elif key in [\"placement\", \"sbp\"]:\n            pytorch_attr = \"unused\"\n        elif key in [\"broadcast_like\"]:\n\n            def broadcast_like(x, y, *args, **kwargs):\n                return self.pytorch.broadcast_to(x, y.size())\n\n            pytorch_attr = broadcast_like\n        else:\n            pytorch_attr = getattr(self.pytorch, key)\n        oneflow_attr = getattr(self.oneflow, key)\n        if pytorch_attr is None:\n            assert (\n                oneflow_attr is None\n            ), f\"pytorch value is None for attr {key}, but oneflow is not.\"\n            return None\n        if self.name == \"\":\n            new_name = key\n        else:\n            new_name = f\"{self.name}.{key}\"\n        global call_pytorch\n        call_pytorch = self.pytorch\n        return GetDualObject(new_name, pytorch_attr, oneflow_attr)\n\n    def __setattr__(self, key, value):\n        if isinstance(value, DualObject):\n            setattr(self.pytorch, key, value.pytorch)\n            setattr(self.oneflow, key, value.oneflow)\n        else:\n            self.__dict__[key] = value\n\n    def __eq__(self, other):\n        if isinstance(other, DualObject):\n            return self.pytorch == other.pytorch and self.oneflow == other.oneflow\n        else:\n            return self.pytorch == other\n\n\ndual_modules_to_test = []\ndual_objects_to_test = []\ntorch_type2checker = {}\n\n\ndef equality_checker(torch_type, flow_type):\n    def deco(f):\n        torch_type2checker[torch_type, flow_type] = f\n        return f\n\n    return deco\n\n\ndef check_equality(dual_object: DualObject, rtol=0.0001, atol=1e-05, check_dtype=False):\n    checker = torch_type2checker.get(\n        (type(dual_object.pytorch), type(dual_object.oneflow)), None\n    )\n    if checker is None:\n        for (key, value) in torch_type2checker.items():\n            if isinstance(dual_object.pytorch, key[0]) and isinstance(\n
                dual_object.oneflow, key[1]\n            ):\n                checker = value\n                break\n    assert checker is not None, (\n        \"checker not found for type \"\n        + str(type(dual_object.pytorch))\n        + \" and \"\n        + str(type(dual_object.oneflow))\n    )\n    return checker(dual_object.pytorch, dual_object.oneflow, rtol, atol, check_dtype)\n\n\n@equality_checker(torch_original.Tensor, flow.Tensor)\n@equality_checker(torch_original.Tensor, flow._oneflow_internal.Tensor)\ndef check_tensor_equality(\n    torch_tensor, flow_tensor, rtol=0.0001, atol=1e-05, check_dtype=False\n):\n    if torch_tensor.grad is not None:\n        if flow_tensor.grad is None:\n            print_note_fake_program(detail=True)\n        assert (\n            flow_tensor.grad is not None\n        ), f\"OneFlow tensor doesn't have grad while PyTorch tensor has one, PyTorch tensor is\\n {torch_tensor}\\n, OneFlow tensor is\\n{flow_tensor} \"\n        torch_grad = (\n            torch_tensor.grad.detach().cpu().numpy()\n            if not torch_original.is_conj(torch_tensor.grad)\n            else torch_original.resolve_conj(torch_tensor.grad.detach()).cpu().numpy()\n        )\n        flow_grad = flow_tensor.grad.numpy()\n        if not np.allclose(\n            torch_grad, flow_grad, rtol=rtol, atol=atol, equal_nan=True,\n        ):\n            print_note_fake_program(detail=True)\n            print(\"---------Grad Shape--------\")\n            print(torch_grad.shape)\n            print(flow_grad.shape)\n            print(\n                f\"Grads are not equal. PyTorch grad: \\n{torch_grad}\\n, OneFlow grad: \\n{flow_grad}\"\n            )\n            return False\n    torch_numpy = (\n        torch_tensor.detach().cpu().numpy()\n        if not torch_original.is_conj(torch_tensor)\n        else torch_original.resolve_conj(torch_tensor.detach()).cpu().numpy()\n    )\n    oneflow_numpy = flow_tensor.numpy()\n    equality_res = np.allclose(\n        torch_numpy, oneflow_numpy, rtol=rtol, atol=atol, equal_nan=True,\n    )\n    # NOTE: if check_dtype=True, also check that the data types match.\n    if check_dtype:\n        equality_res = equality_res and (torch_numpy.dtype == oneflow_numpy.dtype)\n\n    if not equality_res:\n        print_note_fake_program(detail=True)\n        print(\"---------Tensor Shape--------\")\n        print(torch_tensor.shape)\n        print(flow_tensor.shape)\n        print(\"---------Tensor dtype--------\")\n        print(torch_tensor.dtype)\n        print(flow_tensor.dtype)\n    return equality_res\n\n\n@equality_checker(int, int)\n@equality_checker(bool, bool)\ndef check_basetype_equality(a, b, ignored1, ignored2, check_dtype=False):\n    if check_dtype:\n        return (a == b) and (type(a) == type(b))\n    return a == b\n\n\n@equality_checker(tuple, tuple)\n@equality_checker(list, list)\ndef check_sequence_equality(a, b, rtol=0.0001, atol=1e-05, check_dtype=False):\n    if len(a) != len(b):\n        equality_res = False\n    else:\n        equality_res = True\n        for i in range(len(a)):\n            torch_np = a[i].detach().cpu().numpy()\n            flow_np = b[i].detach().cpu().numpy()\n            equality_res = np.allclose(\n                torch_np, flow_np, rtol=rtol, atol=atol, equal_nan=True,\n            )\n            if check_dtype:\n                equality_res = equality_res and (torch_np.dtype == flow_np.dtype)\n            if not equality_res:\n                print_note_fake_program(detail=True)\n                print(\"---------Tensor Shape--------\")\n
                print(a[i].shape)\n                print(b[i].shape)\n                print(\"---------Tensor dtype--------\")\n                print(a[i].dtype)\n                print(b[i].dtype)\n                break\n\n    return equality_res\n\n\n@equality_checker(type(None), type(None))\ndef check_nonetype_equality(a, b, ignored1, ignored2, check_dtype=False):\n    return True\n\n\ndef autotest(\n    n=20,\n    auto_backward: Union[bool, str] = True,\n    rtol=0.0001,\n    atol=1e-05,\n    check_graph=True,\n    check_allclose=True,\n    check_dtype=False,\n    check_grad_use_random_data=True,\n    include_complex=False,\n):\n    verbose = os.getenv(\"ONEFLOW_TEST_VERBOSE\") is not None\n\n    if check_graph == \"ValidatedFalse\":\n        # The graph check is intentionally disabled, with a validated reason.\n        check_graph = False\n\n    def deco(f):\n        @functools.wraps(f)\n        def new_f(test_case, *args, **kwargs):\n            successful_runs_needed = n\n            loop_limit = successful_runs_needed * 20\n            current_run = 0\n            while successful_runs_needed > 0:\n                clear_note_fake_program()\n                if current_run > loop_limit:\n                    raise ValueError(\n                        \"autotest stuck in an endless loop, usually it is caused by invalid code in the test case\"\n                    )\n                dual_modules_to_test.clear()\n                dual_objects_to_test.clear()\n                global global_check_allclose, global_rtol, global_atol, global_backward\n                global_check_allclose = check_allclose\n                global_rtol = rtol\n                global_atol = atol\n                global_backward = auto_backward\n\n                try:\n                    global testing_graph\n                    # for generating the fake-program input tensors\n                    global testing\n                    testing = True\n                    if check_graph:\n                        testing_graph = True\n\n                    global testing_complex\n                    if include_complex:\n                        testing_complex = True\n                        testing_graph = False\n\n                    res = f(test_case, *args, **kwargs)\n\n                    testing = False\n                    testing_graph = False\n                    testing_complex = False\n                except (PyTorchDoesNotSupportError, BothDoNotSupportError) as e:\n                    if verbose:\n                        print(f\"{f.__name__}\")\n                        print(e)\n                    current_run += 1\n                    continue\n                if res is not None:\n                    if not isinstance(res, collections.abc.Sequence):\n                        res = [res]\n                    for x in res:\n                        if x is None:\n                            continue\n                        if auto_backward:\n                            if isinstance(x.pytorch, torch_original.Tensor):\n                                if auto_backward == \"auto\" and (\n                                    not x.pytorch.requires_grad\n                                    or not x.oneflow.requires_grad\n                                ):\n                                    continue\n                                call_tensor_id.append(id(x.pytorch))\n                                if check_grad_use_random_data:\n                                    np_arr = rng.uniform(\n
                                        low=0, high=1, size=list(x.oneflow.shape)\n                                    )\n                                    if is_global():\n                                        np_arr = broadcast(np_arr)\n                                        flow_tensor = flow.tensor(\n                                            np_arr,\n                                            dtype=x.oneflow.dtype,\n                                            placement=x.oneflow.placement,\n                                            sbp=len(x.oneflow.sbp)\n                                            * [flow.sbp.broadcast],\n                                        )\n                                    else:\n                                        flow_tensor = flow.tensor(\n                                            np_arr,\n                                            dtype=x.oneflow.dtype,\n                                            device=x.oneflow.device,\n                                        )\n                                    # TODO(): the inferred shape of some ops differs between oneflow and torch\n                                    pytorch_tensor = torch_original.tensor(\n                                        np_arr.reshape(list(x.pytorch.shape)),\n                                        dtype=x.pytorch.dtype,\n                                        device=x.pytorch.device,\n                                    )\n                                    call_tensor_id.append(id(pytorch_tensor))\n                                    diff_output = GetDualObject(\n                                        \"unused\", pytorch_tensor, flow_tensor\n                                    )\n                                    x.backward(diff_output)\n                                else:\n                                    x.sum().backward()\n                        dual_objects_to_test.append(x)\n                for x in dual_modules_to_test:\n                    for key in x.pytorch.state_dict().keys():\n                        if key not in x.oneflow.state_dict().keys():\n                            warnings.warn(f\"oneflow module doesn't have `{key}`\")\n                            continue\n                        vis_parameters[key] = x.pytorch.state_dict()[key]\n                        dual_objects_to_test.append(\n                            GetDualObject(\n                                \"unused\",\n                                getattr(x.pytorch, key),\n                                getattr(x.oneflow, key),\n                            )\n                        )\n                        call_tensor_id.append(id(getattr(x.pytorch, key)))\n                        dual_objects_to_test.append(\n                            GetDualObject(\n                                \"unused\",\n                                getattr(x.pytorch, key).grad,\n                                getattr(x.oneflow, key).grad,\n                            )\n                        )\n                        call_tensor_id.append(id(getattr(x.pytorch, key).grad))\n\n                for x in dual_objects_to_test:\n                    if (\n                        isinstance(x.pytorch, torch_original.Tensor)\n                        and id(x.pytorch) not in call_tensor_id\n                    ):\n                        vis_tensor.append(x.pytorch)\n\n                # check eager\n                for x in dual_objects_to_test:\n                    if check_allclose:\n                        test_case.assertTrue(\n
                            check_equality(\n                                x, rtol=rtol, atol=atol, check_dtype=check_dtype,\n                            ),\n                            x,\n                        )\n\n                if verbose:\n                    print(f\"{f.__name__} test eager passed.\")\n\n                if verbose and check_graph:\n                    print(f\"{f.__name__} test graph passed.\")\n\n                successful_runs_needed -= 1\n                current_run += 1\n\n        return new_f\n\n    return deco\n\n\ndef globaltest(f):\n    @functools.wraps(f)\n    def new_f(*args, **kwargs):\n        with GlobalScope():\n            return f(*args, **kwargs)\n\n    return new_f\n\n\ndef random_tensor(\n    ndim=None,\n    dim0=1,\n    dim1=None,\n    dim2=None,\n    dim3=None,\n    dim4=None,\n    low=None,\n    high=None,\n    dtype=float,\n    requires_grad=True,\n    pin_memory=False,\n):\n    if isinstance(requires_grad, generator):\n        requires_grad = requires_grad.value()\n    if dtype == float and testing_complex:\n        # Generate a complex dtype with probability 0.5\n        dtype = complex if rng.integers(0, 2) == 1 else float\n\n    pytorch_tensor = (\n        random_pytorch_tensor(\n            ndim, dim0, dim1, dim2, dim3, dim4, low, high, dtype, pin_memory\n        )\n        .value()\n        .requires_grad_(requires_grad and dtype != int)\n    )\n    extra_input_tensor.append(pytorch_tensor)\n    if is_global():\n        flow_tensor = flow.tensor(\n            pytorch_tensor.detach().cpu().numpy(),\n            requires_grad=(requires_grad and dtype != int),\n            placement=flow.placement.all(\"cpu\"),\n            sbp=flow.sbp.broadcast,\n        )\n    else:\n        flow_tensor = flow.tensor(\n            pytorch_tensor.detach().cpu().numpy(),\n            requires_grad=(requires_grad and dtype != int),\n            pin_memory=pin_memory,\n        )\n\n    return GetDualObject(\"unused\", pytorch_tensor, flow_tensor)\n\n\ndef random_dtype(seq_names):\n    pytorch_dtype = random_pytorch_dtype(seq_names).value()\n    if pytorch_dtype is None:\n        flow_dtype = None\n    else:\n        flow_dtype = type_name_to_flow_type[pytorch_dtype.__str__().split(\".\")[-1]]\n    return GetDualObject(\"DualDType\", pytorch_dtype, flow_dtype)\n\n\ndef choice_tensor(\n    a, size=None, replace=True, p=None, dtype=int, requires_grad=False,\n):\n    \"\"\"Generates a random sample from a given 1-D array, which aligns with numpy.random.choice.\n    See https://numpy.org/doc/stable/reference/random/generated/numpy.random.choice.html for details.\n\n    \"\"\"\n    if isinstance(requires_grad, generator):\n        requires_grad = requires_grad.value()\n    pytorch_tensor = (\n        choice_pytorch_tensor(a, size, replace, p, dtype)\n        .value()\n        .requires_grad_(requires_grad and dtype != int)\n    )\n    if is_global():\n        flow_tensor = flow.tensor(\n            pytorch_tensor.detach().cpu().numpy(),\n            requires_grad=(requires_grad and dtype != int),\n            placement=flow.placement.all(\"cpu\"),\n            sbp=flow.sbp.broadcast,\n        )\n    else:\n        flow_tensor = flow.tensor(\n            pytorch_tensor.detach().cpu().numpy(),\n            requires_grad=(requires_grad and dtype != int),\n        )\n\n    return GetDualObject(\"unused\", pytorch_tensor, flow_tensor)\n\n\ntorch = GetDualObject(\"\", torch_original, flow)\n\n
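# Usage sketch (illustrative comment, not part of the original module): a\n# typical autotest case builds dual tensors and returns a DualObject, e.g.\n#\n#   @autotest(n=5)\n#   def test_relu(test_case):\n#       x = random_tensor(ndim=2, dim0=4, dim1=8)\n#       return torch.relu(x)\n#\n# ``torch`` here is the DualObject defined just above, so the call runs on\n# both PyTorch and OneFlow, and results and grads are compared afterwards.\n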
\"choice_tensor\"]\n"
  },
  {
    "path": "python/oneflow/test_utils/automated_test_util/util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport pickle\nimport oneflow as flow\n\n\ndef broadcast(obj, src: int = 0):\n    rank = flow.env.get_rank()\n    if src == rank:\n        obj_bytes = pickle.dumps(obj)\n        obj_bytes = flow._oneflow_internal.cpu_broadcast(obj_bytes, src)\n    else:\n        obj_bytes = flow._oneflow_internal.cpu_broadcast(None, src)\n    return pickle.loads(obj_bytes)\n"
  },
  {
    "path": "python/oneflow/test_utils/oneflow_pytorch_compatibility/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom .oneflow_pytorch_compatiblity_test import *\n"
  },
  {
    "path": "python/oneflow/test_utils/oneflow_pytorch_compatibility/oneflow_pytorch_compatiblity_test.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport importlib.util\nimport unittest\n\nimport numpy as np\nimport time\nimport tempfile\nimport argparse\n\nimport oneflow as flow\nimport torch\nimport oneflow.unittest\nimport shutil\nimport matplotlib as mpl\n\nmpl.use(\"Agg\")\nimport matplotlib.pyplot as plt\n\nverbose = os.getenv(\"ONEFLOW_TEST_VERBOSE\") is not None\n\n\ndef cos_sim(vector_a, vector_b):\n    vector_a = np.mat(vector_a)\n    vector_b = np.mat(vector_b)\n    num = float(vector_a * vector_b.T)\n    denom = np.linalg.norm(vector_a) * np.linalg.norm(vector_b)\n    cos = num / denom\n    sim = 0.5 + 0.5 * cos\n    return sim\n\n\ndef import_file(source):\n    with tempfile.NamedTemporaryFile(\"w\", suffix=\".py\") as f:\n        f.write(source)\n        f.flush()\n        spec = importlib.util.spec_from_file_location(\"mod\", f.name)\n        mod = importlib.util.module_from_spec(spec)\n        spec.loader.exec_module(mod)\n        return mod\n\n\ndef get_loss(\n    image_nd,\n    label_nd,\n    model_path: str,\n    module_name: str,\n    test_pytorch: bool = True,\n    device: str = \"cuda\",\n    tmpfilename: str = \"/tmp/oneflow_tmp_file\",\n):\n    model_loss = []\n    learning_rate = 0.01\n    mom = 0.9\n    bp_iters = 100\n\n    for_time = 0.0\n    bp_time = 0.0\n    update_time = 0.0\n\n    if test_pytorch == True:\n        image = flow.tensor(image_nd)\n        label = flow.tensor(label_nd)\n        corss_entropy = flow.nn.CrossEntropyLoss(reduction=\"mean\")\n\n        with open(model_path) as f:\n            buf = f.read()\n            lines = buf.split(\"\\n\")\n            buf = \"\\n\".join(lines)\n            python_module = import_file(buf)\n\n        Net = getattr(python_module, module_name)\n        pytorch_model = Net()\n\n        w = pytorch_model.state_dict()\n        new_parameters = dict()\n        for k, v in w.items():\n            if \"num_batches_tracked\" not in k:\n                new_parameters[k] = flow.tensor(w[k].detach().numpy())\n\n        flow.save(new_parameters, tmpfilename)\n\n        pytorch_model.to(device)\n        torch_sgd = torch.optim.SGD(\n            pytorch_model.parameters(), lr=learning_rate, momentum=mom\n        )\n\n        image = torch.tensor(image_nd)\n        image_gpu = image.to(device)\n        corss_entropy = torch.nn.CrossEntropyLoss()\n        corss_entropy.to(device)\n        label = torch.tensor(label_nd, dtype=torch.long).to(device)\n\n        print(\"start pytorch training loop....\")\n        start_t = time.time()\n        for i in range(bp_iters):\n            s_t = time.time()\n            logits = pytorch_model(image_gpu)\n            loss = corss_entropy(logits, label)\n            for_time += time.time() - s_t\n\n            s_t = time.time()\n            loss.backward()\n            bp_time += time.time() - s_t\n\n            model_loss.append(loss.detach().cpu().numpy())\n\n            s_t = time.time()\n            
torch_sgd.step()\n            torch_sgd.zero_grad()\n            update_time += time.time() - s_t\n\n        end_t = time.time()\n\n        if verbose:\n            print(\n                \"pytorch training loop avg time : {}\".format(\n                    (end_t - start_t) / bp_iters\n                )\n            )\n            print(\"forward avg time : {}\".format(for_time / bp_iters))\n            print(\"backward avg time : {}\".format(bp_time / bp_iters))\n            print(\"update parameters avg time : {}\".format(update_time / bp_iters))\n    else:\n        with open(model_path) as f:\n            buf = f.read()\n\n            lines = buf.split(\"\\n\")\n            for i, line in enumerate(lines):\n                if (\n                    i > 15 and \"import\" not in line and len(line.strip()) != 0\n                ):  # skip the license header (the first 15 lines) and the imports\n                    break\n            lines = (\n                lines[:i]\n                + [\n                    \"import oneflow as torch\",\n                    \"import oneflow.nn as nn\",\n                    \"import oneflow.nn.init as init\",\n                    \"import oneflow.nn.functional as F\",\n                    \"from oneflow import Tensor\",\n                    \"from oneflow.nn import Parameter\",\n                    \"import math\",\n                    \"from flowvision.layers import *\",\n                ]\n                + lines[i:]\n            )\n            buf = \"\\n\".join(lines)\n\n            python_module = import_file(buf)\n\n        Net = getattr(python_module, module_name)\n        oneflow_model = Net()\n\n        image = flow.tensor(image_nd)\n        label = flow.tensor(label_nd)\n        cross_entropy = flow.nn.CrossEntropyLoss(reduction=\"mean\")\n\n        image_gpu = image.to(device)\n        label = label.to(device)\n        oneflow_model.to(device)\n        cross_entropy.to(device)\n\n        params = flow.load(tmpfilename)\n        oneflow_model.load_state_dict(params)\n\n        of_sgd = flow.optim.SGD(\n            oneflow_model.parameters(), lr=learning_rate, momentum=mom\n        )\n\n        print(\"start oneflow training loop....\")\n        start_t = time.time()\n        for i in range(bp_iters):\n            s_t = time.time()\n            logits = oneflow_model(image_gpu)\n            loss = cross_entropy(logits, label)\n            for_time += time.time() - s_t\n\n            s_t = time.time()\n            loss.backward()\n            bp_time += time.time() - s_t\n\n            model_loss.append(loss.numpy())\n\n            s_t = time.time()\n            of_sgd.step()\n            of_sgd.zero_grad()\n            update_time += time.time() - s_t\n\n        end_t = time.time()\n\n        if verbose:\n            print(\n                \"oneflow training loop avg time : {}\".format(\n                    (end_t - start_t) / bp_iters\n                )\n            )\n            print(\"forward avg time : {}\".format(for_time / bp_iters))\n            print(\"backward avg time : {}\".format(bp_time / bp_iters))\n            print(\"update parameters avg time : {}\".format(update_time / bp_iters))\n\n    return model_loss\n\n\ndef do_test_train_loss_oneflow_pytorch(\n    test_case,\n    model_path: str,\n    module_name: str,\n    device: str = \"cuda\",\n    batch_size: int = 16,\n    img_size: int = 224,\n):\n    image_nd = np.random.rand(batch_size, 3, img_size, img_size).astype(np.float32)\n    label_nd = np.array([e for e in range(batch_size)], dtype=np.int32)\n    
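# Both frameworks train on the same random batch from identical initial\n    # weights, so their loss curves should nearly coincide (checked below).\n    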
oneflow_model_loss = []\n    pytorch_model_loss = []\n\n    with tempfile.NamedTemporaryFile() as f:\n        pytorch_model_loss = get_loss(\n            image_nd, label_nd, model_path, module_name, True, device, f.name\n        )\n        oneflow_model_loss = get_loss(\n            image_nd, label_nd, model_path, module_name, False, device, f.name\n        )\n\n    if verbose:\n        indices = list(range(len(oneflow_model_loss)))\n\n        plt.plot(indices, oneflow_model_loss, label=\"oneflow\")\n        plt.plot(indices, pytorch_model_loss, label=\"pytorch\")\n\n        plt.xlabel(\"iter\")\n        # Set the y axis label of the current axis.\n        plt.ylabel(\"loss\")\n        # Set a title of the current axes.\n        plt.title(\"loss comparison\")\n        # show a legend on the plot\n        plt.legend()\n\n        # Display a figure.\n        plt.savefig(\"./loss_compare.png\")\n        plt.show()\n\n    test_case.assertTrue(\n        np.allclose(\n            cos_sim(oneflow_model_loss, pytorch_model_loss), 1.0, rtol=1e-1, atol=1e-1\n        )\n    )\n"
  },
  {
    "path": "python/oneflow/test_utils/test_util.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport itertools\nimport os\nfrom collections import OrderedDict\nfrom collections.abc import Iterable\n\nimport numpy as np\n\nimport oneflow as flow\nimport oneflow.unittest\n\n\ndef GenCartesianProduct(sets):\n    assert isinstance(sets, Iterable)\n    for set in sets:\n        assert isinstance(set, Iterable)\n        if os.getenv(\"ONEFLOW_TEST_CPU_ONLY\"):\n            if \"cuda\" in set:\n                set.remove(\"cuda\")\n    return itertools.product(*sets)\n\n\ndef GenArgList(arg_dict):\n    assert isinstance(arg_dict, OrderedDict)\n    assert all([isinstance(x, list) for x in arg_dict.values()])\n    sets = [arg_set for (_, arg_set) in arg_dict.items()]\n    return GenCartesianProduct(sets)\n\n\ndef GenArgDict(arg_dict):\n    return [dict(zip(arg_dict.keys(), x)) for x in GenArgList(arg_dict)]\n\n\nclass Args:\n    def __init__(self, flow_args, tf_args=None):\n        super().__init__()\n        if tf_args is None:\n            tf_args = flow_args\n        self.flow_args = flow_args\n        self.tf_args = tf_args\n\n    def __str__(self):\n        return \"flow_args={} tf_args={}\".format(self.flow_args, self.tf_args)\n\n    def __repr__(self):\n        return self.__str__()\n\n\ntype_name_to_flow_type = {\n    \"bool\": flow.bool,\n    \"float16\": flow.float16,\n    \"float32\": flow.float32,\n    \"double\": flow.double,\n    \"float64\": flow.double,\n    \"int8\": flow.int8,\n    \"int32\": flow.int32,\n    \"int64\": flow.int64,\n    \"uint8\": flow.uint8,\n    \"half\": flow.half,\n    \"bfloat16\": flow.bfloat16,\n    \"complex64\": flow.complex64,\n    \"complex128\": flow.complex128,\n}\ntype_name_to_np_type = {\n    \"float16\": np.float16,\n    \"float32\": np.float32,\n    \"double\": np.float64,\n    \"int8\": np.int8,\n    \"int32\": np.int32,\n    \"int64\": np.int64,\n    \"uint8\": np.uint8,\n    \"complex64\": np.complex64,\n    \"complex128\": np.complex128,\n}\n\n\ndef FlattenArray(input_array):\n    output_array = list()\n    for x in np.nditer(input_array):\n        output_array.append(x.tolist())\n    return output_array\n\n\ndef Array2Numpy(input_array, target_shape):\n    return np.array(input_array).reshape(target_shape, order=\"C\")\n\n\ndef Index2Coordinate(idx, tensor_shape):\n    coordinate = []\n    tmp = idx\n    for i in range(len(tensor_shape) - 1, -1, -1):\n        axis_size = tensor_shape[i]\n        coor = tmp % axis_size\n        coordinate.insert(0, int(coor))\n        tmp = (tmp - coor) / axis_size\n    return coordinate\n\n\ndef Coordinate2Index(coordinate, tensor_shape):\n    if len(coordinate) != len(tensor_shape):\n        raise \"wrong coordinate or shape\"\n    idx = 0\n    for (i, coor) in enumerate(coordinate):\n        size_at_axis = coor\n        for j in range(i + 1, len(tensor_shape)):\n            size_at_axis *= tensor_shape[j]\n        idx += size_at_axis\n    return idx\n"
  },
  {
    "path": "python/oneflow/test_utils/throttle.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport argparse\nimport hashlib\nimport subprocess\nimport portalocker\nimport os\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Control when the script runs through special variables.\"\n    )\n    parser.add_argument(\n        \"--with-cuda\", type=int, default=1, help=\"whether has cuda device.\"\n    )\n    parser.add_argument(\"cmd\", type=str, nargs=\"...\", help=\"command to run\")\n    return parser.parse_args()\n\n\ndef hash_cli2gpu(cmd: list):\n    import pynvml\n\n    pynvml.nvmlInit()\n    slot = pynvml.nvmlDeviceGetCount()\n    hash = hashlib.sha1(\" \".join(cmd).encode(\"utf-8\")).hexdigest()\n    gpu_id = int(hash, 16) % slot\n    return [gpu_id]\n\n\ndef main():\n    args = parse_args()\n    if args.with_cuda:\n        cuda_visible_devices = [str(i) for i in hash_cli2gpu(args.cmd)]\n        with portalocker.Lock(\n            \".oneflow-throttle-gpu-\" + \"-\".join(cuda_visible_devices) + \".lock\",\n            timeout=400,\n        ):\n            env = dict(os.environ, CUDA_VISIBLE_DEVICES=\",\".join(cuda_visible_devices))\n            return subprocess.call(args.cmd, env=env)\n    else:\n        return subprocess.call(args.cmd)\n\n\nif __name__ == \"__main__\":\n    returncode = main()\n    exit(returncode)\n"
  },
  {
    "path": "python/oneflow/unittest/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.framework.unittest import (\n    TestCase,\n    num_nodes_required,\n    register_test_cases,\n    skip_unless_1n1d,\n    skip_unless_1n2d,\n    skip_unless_1n4d,\n    skip_unless_2n1d,\n    skip_unless_2n2d,\n    skip_unless_2n4d,\n)\n\nfrom . import env\nfrom .mlir import MLIRTestCase\nfrom .dataset import dataset_dir\n"
  },
  {
    "path": "python/oneflow/unittest/dataset.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\n\n\ndef dataset_dir(sub_dir=None):\n    base_dir = os.getenv(\"ONEFLOW_TEST_DATASET_DIR\")\n    if base_dir == None:\n        base_dir = \"/dataset\"\n    if sub_dir == None:\n        return base_dir\n    else:\n        return os.path.join(base_dir, sub_dir)\n"
  },
  {
    "path": "python/oneflow/unittest/env.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.framework.unittest import (\n    device_num,\n    eager_execution_enabled,\n    has_node_list,\n    has_world_size,\n    node_list,\n    node_size,\n    typing_check_enabled,\n    world_size,\n)\n"
  },
  {
    "path": "python/oneflow/unittest/mlir.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport unittest\n\n\nclass MLIRTestCase(unittest.TestCase):\n    def tearDown(self):\n        for key in os.environ.keys():\n            if key.startswith(\"ONEFLOW_MLIR\"):\n                os.environ.pop(key)\n"
  },
  {
    "path": "python/oneflow/utils/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.framework.config_util import api_load_library as load_library\nfrom oneflow.utils import tensor\nfrom oneflow.utils import global_view\nfrom oneflow.utils import model_zoo\nfrom . import checkpoint\nfrom . import hooks\n"
  },
  {
    "path": "python/oneflow/utils/checkpoint.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# This file is mostly copied from PyTorch\n\nimport oneflow as flow\nfrom typing import List, Union\n\n\ndef _checkpoint_without_reentrant(function, *args):\n    \"\"\"Checkpointining without re-entrant autograd\n    Args:\n        function: describes what to run in the forward pass of the model or\n            part of the model. It should also know how to handle the inputs\n            passed as the tuple. For example, in LSTM, if user passes\n            ``(activation, hidden)``, :attr:`function` should correctly use the\n            first input as ``activation`` and the second input as ``hidden``\n        *args: Arguments to pass in to the given ``function``.\n    \"\"\"\n\n    storage: List[Union[flow.Tensor, None]] = []\n    counter = 0\n\n    def pack(x):\n        nonlocal counter\n        counter += 1\n        return counter - 1\n\n    # TODO(jianhao): support restoring rng state once we have flow.random.fork_rng\n    def unpack(x):\n        if len(storage) == 0:\n\n            def inner_pack(inner):\n                storage.append(inner)\n                return None\n\n            def inner_unpack(packed):\n                raise RuntimeError(\n                    \"You are calling backwards on a tensor that is never exposed. Please open an issue.\"\n                )\n\n            with flow.enable_grad():\n                with flow.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):\n                    _unused = function(*args)\n\n        return storage[x]\n\n    with flow.autograd.graph.saved_tensors_hooks(pack, unpack):\n        output = function(*args)\n\n    return output\n\n\ndef checkpoint(function, *args):\n    r\"\"\"Checkpoint a model or part of the model\n\n    Checkpointing works by trading compute for memory. Rather than storing all\n    intermediate activations of the entire computation graph for computing\n    backward, the checkpointed part does **not** save intermediate activations,\n    and instead recomputes them in backward pass. It can be applied on any part\n    of a model.\n\n    Specifically, in the forward pass, :attr:`function` will run in\n    :func:`flow.no_grad` manner, i.e., not storing the intermediate\n    activations. Instead, the forward pass saves the inputs tuple and the\n    :attr:`function` parameter. In the backwards pass, the saved inputs and\n    :attr:`function` is retrieved, and the forward pass is computed on\n    :attr:`function` again, now tracking the intermediate activations, and then\n    the gradients are calculated using these activation values.\n\n    The output of :attr:`function` can contain non-Tensor values and gradient\n    recording is only performed for the Tensor values. 
Note that if the output\n    consists of nested structures (ex: custom objects, lists, dicts etc.)\n    consisting of Tensors, these Tensors nested in custom structures will not\n    be considered part of autograd.\n\n\n    .. warning::\n        If :attr:`function` invocation during backward does anything different\n        than the one during forward, e.g., due to some global variable, the\n        checkpointed version won't be equivalent, and unfortunately it can't be\n        detected.\n\n    .. warning::\n        Preserving RNG states is not supported yet, so the behavior of\n        checkpointing does not fully align with PyTorch's.\n\n    Args:\n        function: describes what to run in the forward pass of the model or\n            part of the model. It should also know how to handle the inputs\n            passed as the tuple. For example, in LSTM, if user passes\n            ``(activation, hidden)``, :attr:`function` should correctly use the\n            first input as ``activation`` and the second input as ``hidden``\n        args: tuple containing inputs to the :attr:`function`\n\n    Returns:\n        Output of running :attr:`function` on :attr:`*args`\n    \"\"\"\n    return _checkpoint_without_reentrant(function, *args)\n"
  },
  {
    "path": "python/oneflow/utils/data/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.utils.data.sampler import (\n    Sampler,\n    SequentialSampler,\n    RandomSampler,\n    SubsetRandomSampler,\n    BatchSampler,\n)\nfrom oneflow.utils.data.dataset import (\n    Dataset,\n    IterableDataset,\n    TensorDataset,\n    ConcatDataset,\n    Subset,\n    random_split,\n)\nfrom oneflow.utils.data.dataset import IterableDataset as IterDataPipe\nfrom oneflow.utils.data.dataloader import (\n    DataLoader,\n    _DatasetKind,\n    get_worker_info,\n)\nfrom oneflow.utils.data.decorator import (\n    functional_datapipe,\n    guaranteed_datapipes_determinism,\n    non_deterministic,\n)\nfrom oneflow.utils.data.distributed import DistributedSampler\n\n\n__all__ = [\n    \"Sampler\",\n    \"SequentialSampler\",\n    \"RandomSampler\",\n    \"SubsetRandomSampler\",\n    \"BatchSampler\",\n    \"Dataset\",\n    \"IterableDataset\",\n    \"TensorDataset\",\n    \"ConcatDataset\",\n    \"Subset\",\n    \"random_split\",\n    \"DataLoader\",\n    \"_DatasetKind\",\n    \"IterDataPipe\",\n    \"functional_datapipe\",\n    \"guaranteed_datapipes_determinism\",\n    \"non_deterministic\",\n    \"DistributedSampler\",\n]\n"
  },
  {
    "path": "python/oneflow/utils/data/_utils/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nr\"\"\"Utility classes & functions for data loading. Code in this folder is mostly\nused by ../dataloder.py.\n\nA lot of multiprocessing is used in data loading, which only supports running\nfunctions defined in global environment (py2 can't serialize static methods).\nTherefore, for code tidiness we put these functions into different files in this\nfolder.\n\"\"\"\nimport sys\nimport atexit\n\n\nIS_WINDOWS = sys.platform == \"win32\"\n\n# pytorch's check interval is 5.0 seconds\nMP_STATUS_CHECK_INTERVAL = 10.0\nr\"\"\"Interval (in seconds) to check status of processes to avoid hanging in\n    multiprocessing data loading. This is mainly used in getting data from\n    another process, in which case we need to periodically check whether the\n    sender is alive to prevent hanging.\"\"\"\n\n\npython_exit_status = False\nr\"\"\"Whether Python is shutting down. This flag is guaranteed to be set before\nthe Python core library resources are freed, but Python may already be exiting\nfor some time when this is set.\n\nHook to set this flag is `_set_python_exit_flag`, and is inspired by a similar\nhook in Python 3.7 multiprocessing library:\nhttps://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327\n\"\"\"\n\ntry:\n    import numpy\n\n    HAS_NUMPY = True\nexcept ModuleNotFoundError:\n    HAS_NUMPY = False\n\n\ndef _set_python_exit_flag():\n    global python_exit_status\n    python_exit_status = True\n\n\natexit.register(_set_python_exit_flag)\n\n\nfrom . import worker, signal_handling, collate, fetch, pin_memory\n"
  },
  {
    "path": "python/oneflow/utils/data/_utils/collate.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nr\"\"\"\"Contains definitions of the methods used by the _BaseDataLoaderIter workers to\ncollate samples fetched from dataset into Tensor(s).\n\nThese **needs** to be in global scope since Py2 doesn't support serializing\nstatic methods.\n\"\"\"\nimport re\nimport collections\n\nimport oneflow as flow\n\n\nstring_classes = (str, bytes)\n\nnp_str_obj_array_pattern = re.compile(r\"[SaUO]\")\n\n\ndef default_convert(data):\n    r\"\"\"Converts each NumPy array data field into a tensor\"\"\"\n    elem_type = type(data)\n    if isinstance(data, (flow.Tensor, flow._oneflow_internal.Tensor)):\n        return data\n    elif (\n        elem_type.__module__ == \"numpy\"\n        and elem_type.__name__ != \"str_\"\n        and elem_type.__name__ != \"string_\"\n    ):\n        # array of string classes and object\n        if (\n            elem_type.__name__ == \"ndarray\"\n            and np_str_obj_array_pattern.search(data.dtype.str) is not None\n        ):\n            return data\n        return flow.tensor(data)\n    elif isinstance(data, collections.abc.Mapping):\n        return {key: default_convert(data[key]) for key in data}\n    elif isinstance(data, tuple) and hasattr(data, \"_fields\"):  # namedtuple\n        return elem_type(*(default_convert(d) for d in data))\n    elif isinstance(data, collections.abc.Sequence) and not isinstance(\n        data, string_classes\n    ):\n        return [default_convert(d) for d in data]\n    else:\n        # NOTE: pytorch just return data here, and not raise any exception!\n        raise TypeError(default_convert_err_msg_format.format(elem_type))\n\n\ndefault_collate_err_msg_format = (\n    \"default_collate: batch must contain tensors, numpy arrays, numbers, \"\n    \"dicts or lists; found {}\"\n)\n\ndefault_convert_err_msg_format = (\n    \"default_convert: batch must contain tensors, numpy arrays, numbers, \"\n    \"dicts or lists; found {}\"\n)\n\n\ndef default_collate(batch):\n    r\"\"\"Puts each data field into a tensor with outer dimension batch size\"\"\"\n\n    elem = batch[0]\n    elem_type = type(elem)\n    if isinstance(elem, (flow.Tensor, flow._oneflow_internal.Tensor)):\n        # TODO: tensor.storage()._new_shared(numel)\n        return flow._C.stack(batch, dim=0)\n    elif (\n        elem_type.__module__ == \"numpy\"\n        and elem_type.__name__ != \"str_\"\n        and elem_type.__name__ != \"string_\"\n    ):\n        if elem_type.__name__ == \"ndarray\" or elem_type.__name__ == \"memmap\":\n            # array of string classes and object\n            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:\n                raise TypeError(default_collate_err_msg_format.format(elem.dtype))\n\n            return default_collate([flow.tensor(b) for b in batch])\n        elif elem.shape == ():  # scalars\n            return flow.tensor(batch)\n    elif isinstance(elem, float):\n        return 
flow.tensor(batch, dtype=flow.float64)\n    elif isinstance(elem, int):\n        return flow.tensor(batch)\n    elif isinstance(elem, string_classes):\n        return batch\n    elif isinstance(elem, collections.abc.Mapping):\n        return {key: default_collate([d[key] for d in batch]) for key in elem}\n    elif isinstance(elem, tuple) and hasattr(elem, \"_fields\"):  # namedtuple\n        return elem_type(*(default_collate(samples) for samples in zip(*batch)))\n    elif isinstance(elem, collections.abc.Sequence):\n        # check to make sure that the elements in batch have consistent size\n        it = iter(batch)\n        elem_size = len(next(it))\n        if not all(len(elem) == elem_size for elem in it):\n            raise RuntimeError(\"each element in list of batch should be of equal size\")\n        transposed = zip(*batch)\n        return [default_collate(samples) for samples in transposed]\n\n    raise TypeError(default_collate_err_msg_format.format(elem_type))\n"
  },
  {
    "path": "python/oneflow/utils/data/_utils/fetch.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\"\"\"\"Contains definitions of the methods used by the _BaseDataLoaderIter to fetch\ndata from an iterable-style or map-style dataset. This logic is shared in both\nsingle- and multi-processing data loading.\n\"\"\"\n\n\nclass _BaseDatasetFetcher(object):\n    def __init__(self, dataset, auto_collation, collate_fn, drop_last):\n        self.dataset = dataset\n        self.auto_collation = auto_collation\n        self.collate_fn = collate_fn\n        self.drop_last = drop_last\n\n    def fetch(self, possibly_batched_index):\n        raise NotImplementedError()\n\n\nclass _IterableDatasetFetcher(_BaseDatasetFetcher):\n    def __init__(self, dataset, auto_collation, collate_fn, drop_last):\n        super(_IterableDatasetFetcher, self).__init__(\n            dataset, auto_collation, collate_fn, drop_last\n        )\n        self.dataset_iter = iter(dataset)\n\n    def fetch(self, possibly_batched_index):\n        if self.auto_collation:\n            data = []\n            for _ in possibly_batched_index:\n                try:\n                    data.append(next(self.dataset_iter))\n                except StopIteration:\n                    break\n            if len(data) == 0 or (\n                self.drop_last and len(data) < len(possibly_batched_index)\n            ):\n                raise StopIteration\n        else:\n            data = next(self.dataset_iter)\n        return self.collate_fn(data)\n\n\nclass _MapDatasetFetcher(_BaseDatasetFetcher):\n    def __init__(self, dataset, auto_collation, collate_fn, drop_last):\n        super(_MapDatasetFetcher, self).__init__(\n            dataset, auto_collation, collate_fn, drop_last\n        )\n\n    def fetch(self, possibly_batched_index):\n        if self.auto_collation:\n            data = [self.dataset[idx] for idx in possibly_batched_index]\n        else:\n            data = self.dataset[possibly_batched_index]\n        return self.collate_fn(data)\n"
  },
  {
    "path": "python/oneflow/utils/data/_utils/pin_memory.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nr\"\"\"\"Contains definitions of the methods used by the _BaseDataLoaderIter to put\nfetched tensors into pinned memory.\n\nThese **needs** to be in global scope since Py2 doesn't support serializing\nstatic methods.\n\"\"\"\n\nimport oneflow as flow\nimport collections.abc\nimport queue\n\nfrom . import MP_STATUS_CHECK_INTERVAL\nfrom oneflow._utils import ExceptionWrapper\n\ncontainer_abcs = collections.abc\nstring_classes = (str, bytes)\n\n\ndef _pin_memory_loop(in_queue, out_queue, device_id, done_event):\n    # This setting is thread local, and prevents the copy in pin_memory from\n    # consuming all CPU cores.\n    flow.set_num_threads(1)\n\n    # TODO: support flow.cuda.set_device\n    # flow.cuda.set_device(device_id)\n\n    while not done_event.is_set():\n        try:\n            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)\n        except queue.Empty:\n            continue\n        idx, data = r\n        if not done_event.is_set() and not isinstance(data, ExceptionWrapper):\n            try:\n                data = pin_memory(data)\n            except Exception:\n                data = ExceptionWrapper(\n                    where=\"in pin memory thread for device {}\".format(device_id)\n                )\n            r = (idx, data)\n        while not done_event.is_set():\n            try:\n                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)\n                break\n            except queue.Full:\n                continue\n        del r  # save memory\n\n\ndef pin_memory(data):\n    if isinstance(data, flow.Tensor):\n        return data.pin_memory()\n    elif isinstance(data, string_classes):\n        return data\n    elif isinstance(data, container_abcs.Mapping):\n        return {k: pin_memory(sample) for k, sample in data.items()}\n    elif isinstance(data, tuple) and hasattr(data, \"_fields\"):  # namedtuple\n        return type(data)(*(pin_memory(sample) for sample in data))\n    elif isinstance(data, container_abcs.Sequence):\n        return [pin_memory(sample) for sample in data]\n    elif hasattr(data, \"pin_memory\"):\n        return data.pin_memory()\n    else:\n        return data\n"
  },
  {
    "path": "python/oneflow/utils/data/_utils/signal_handling.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nr\"\"\"Signal handling for multiprocessing data loading.\nNOTE [ Signal handling in multiprocessing data loading ]\nIn cases like DataLoader, if a worker process dies due to bus error/segfault\nor just hang, the main process will hang waiting for data. This is difficult\nto avoid on OneFlow side as it can be caused by limited shm, or other\nlibraries users call in the workers. In this file and `DataLoader.cpp`, we make\nour best effort to provide some error message to users when such unfortunate\nevents happen.\nWhen a _BaseDataLoaderIter starts worker processes, their pids are registered in a\ndefined in `DataLoader.cpp`: id(_BaseDataLoaderIter) => Collection[ Worker pids ]\nvia `_set_worker_pids`.\nWhen an error happens in a worker process, the main process received a SIGCHLD,\nand Python will eventually call the handler registered below\n(in `_set_SIGCHLD_handler`). In the handler, the `_error_if_any_worker_fails`\ncall checks all registered worker pids and raise proper error message to\nprevent main process from hanging waiting for data from worker.\nAdditionally, at the beginning of each worker's `_utils.worker._worker_loop`,\n`_set_worker_signal_handlers` is called to register critical signal handlers\n(e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just prints an error\nmessage to stderr before triggering the default handler. So a message will also\nbe printed from the worker process when it is killed by such signals.\nSee NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning of\nthis signal handling design and other mechanism we implement to make our\nmultiprocessing data loading robust to errors.\n\"\"\"\n\nimport signal\nimport threading\nfrom . import IS_WINDOWS\n\n# Some of the following imported functions are not used in this file, but are to\n# be used `_utils.signal_handling.XXXXX`.\n\n\nfrom oneflow._oneflow_internal import (\n    _set_worker_pids,\n    _remove_worker_pids,\n    _error_if_any_worker_fails,\n    _set_worker_signal_handlers,\n)\n\n_SIGCHLD_handler_set = False\nr\"\"\"Whether SIGCHLD handler is set for DataLoader worker failures. Only one\nhandler needs to be set for all DataLoaders in a process.\"\"\"\n\n\ndef _set_SIGCHLD_handler():\n    # Windows doesn't support SIGCHLD handler\n    if IS_WINDOWS:\n        return\n    # can't set signal in child threads\n    if not isinstance(threading.current_thread(), threading._MainThread):  # type: ignore[attr-defined]\n        return\n    global _SIGCHLD_handler_set\n    if _SIGCHLD_handler_set:\n        return\n    previous_handler = signal.getsignal(signal.SIGCHLD)\n    if not callable(previous_handler):\n        # This doesn't catch default handler, but SIGCHLD default handler is a\n        # no-op.\n        previous_handler = None\n\n    def handler(signum, frame):\n        # This following call uses `waitid` with WNOHANG from C side. 
Therefore,\n        # Python can still get and update the process status successfully.\n        _error_if_any_worker_fails()\n        if previous_handler is not None:\n            assert callable(previous_handler)\n            previous_handler(signum, frame)\n\n    signal.signal(signal.SIGCHLD, handler)\n    _SIGCHLD_handler_set = True\n"
  },
  {
    "path": "python/oneflow/utils/data/_utils/worker.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nr\"\"\"\"Contains definitions of the methods used by the _BaseDataLoaderIter workers.\nThese **needs** to be in global scope since Py2 doesn't support serializing\nstatic methods.\n\"\"\"\nimport random\nimport os\nimport sys\nimport traceback\nimport queue\nfrom dataclasses import dataclass\nfrom typing import Union\nfrom oneflow.multiprocessing import _prctl_pr_set_pdeathsig  # type: ignore[attr-defined]\nfrom oneflow.multiprocessing import unlink_all_shared_memory\nimport signal\n\nimport oneflow as flow\nfrom . import signal_handling, MP_STATUS_CHECK_INTERVAL, IS_WINDOWS, HAS_NUMPY\nfrom oneflow._utils import ExceptionWrapper\n\n\nif IS_WINDOWS:\n    import ctypes\n    from ctypes.wintypes import DWORD, BOOL, HANDLE\n\n    # On Windows, the parent ID of the worker process remains unchanged when the manager process\n    # is gone, and the only way to check it through OS is to let the worker have a process handle\n    # of the manager and ask if the process status has changed.\n    class ManagerWatchdog(object):\n        def __init__(self):\n            self.manager_pid = os.getppid()\n\n            # mypy cannot detect this code is windows only\n            self.kernel32 = ctypes.WinDLL(\"kernel32\", use_last_error=True)  # type: ignore[attr-defined]\n            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)\n            self.kernel32.OpenProcess.restype = HANDLE\n            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)\n            self.kernel32.WaitForSingleObject.restype = DWORD\n\n            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx\n            SYNCHRONIZE = 0x00100000\n            self.manager_handle = self.kernel32.OpenProcess(\n                SYNCHRONIZE, 0, self.manager_pid\n            )\n\n            if not self.manager_handle:\n                raise ctypes.WinError(ctypes.get_last_error())  # type: ignore[attr-defined]\n\n            self.manager_dead = False\n\n        def is_alive(self):\n            if not self.manager_dead:\n                # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx\n                self.manager_dead = (\n                    self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0\n                )\n            return not self.manager_dead\n\n\nelse:\n\n    class ManagerWatchdog(object):  # type: ignore[no-redef]\n        def __init__(self):\n            self.manager_pid = os.getppid()\n            self.manager_dead = False\n\n        def is_alive(self):\n            if not self.manager_dead:\n                self.manager_dead = os.getppid() != self.manager_pid\n            return not self.manager_dead\n\n\n_worker_info = None\n\n\nclass WorkerInfo(object):\n    __initialized = False\n\n    def __init__(self, **kwargs):\n        for k, v in kwargs.items():\n            setattr(self, k, v)\n        self.__keys = 
tuple(kwargs.keys())\n        self.__initialized = True\n\n    def __setattr__(self, key, val):\n        if self.__initialized:\n            raise RuntimeError(\n                \"Cannot assign attributes to {} objects\".format(self.__class__.__name__)\n            )\n        return super(WorkerInfo, self).__setattr__(key, val)\n\n    def __repr__(self):\n        items = []\n        for k in self.__keys:\n            items.append(\"{}={}\".format(k, getattr(self, k)))\n        return \"{}({})\".format(self.__class__.__name__, \", \".join(items))\n\n\ndef get_worker_info():\n    r\"\"\"Returns the information about the current\n    :class:`~flow.utils.data.DataLoader` iterator worker process.\n    When called in a worker, this returns an object guaranteed to have the\n    following attributes:\n    * :attr:`id`: the current worker id.\n    * :attr:`num_workers`: the total number of workers.\n    * :attr:`seed`: the random seed set for the current worker. This value is\n      determined by main process RNG and the worker id. See\n      :class:`~flow.utils.data.DataLoader`'s documentation for more details.\n    * :attr:`dataset`: the copy of the dataset object in **this** process. Note\n      that this will be a different object in a different process than the one\n      in the main process.\n    When called in the main process, this returns ``None``.\n    .. note::\n       When used in a :attr:`worker_init_fn` passed over to\n       :class:`~flow.utils.data.DataLoader`, this method can be useful to\n       set up each worker process differently, for instance, using ``worker_id``\n       to configure the ``dataset`` object to only read a specific fraction of a\n       sharded dataset, or use ``seed`` to seed other libraries used in dataset\n       code.\n    \"\"\"\n    return _worker_info\n\n\nr\"\"\"Dummy class used to signal the end of an IterableDataset\"\"\"\n\n\n@dataclass(frozen=True)\nclass _IterableDatasetStopIteration(object):\n    worker_id: int\n\n\nr\"\"\"Dummy class used to resume the fetching when worker reuse is enabled\"\"\"\n\n\n@dataclass(frozen=True)\nclass _ResumeIteration(object):\n    pass\n\n\n# The function `_generate_state` is adapted from `numpy.random.SeedSequence`\n# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx\n# It's MIT licensed, here is the copyright:\n\n# Copyright (c) 2015 Melissa E. O'Neill\n# Copyright (c) 2019 NumPy Developers\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in\n# all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\n# This function generates an array of int32 as the seed for\n# `numpy.random`, in order to prevent state collision due to same\n# seed and algorithm for `numpy.random` and `random` modules.\n# TODO: Implement `SeedSequence` like object for `flow.random`\ndef _generate_state(base_seed, worker_id):\n    INIT_A = 0x43B0D7E5\n    MULT_A = 0x931E8875\n    INIT_B = 0x8B51F9DD\n    MULT_B = 0x58F38DED\n    MIX_MULT_L = 0xCA01F9DD\n    MIX_MULT_R = 0x4973F715\n    XSHIFT = 4 * 8 // 2\n    MASK32 = 0xFFFFFFFF\n\n    entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]\n    pool = [0] * 4\n\n    hash_const_A = INIT_A\n\n    def hash(value):\n        nonlocal hash_const_A\n        value = (value ^ hash_const_A) & MASK32\n        hash_const_A = (hash_const_A * MULT_A) & MASK32\n        value = (value * hash_const_A) & MASK32\n        value = (value ^ (value >> XSHIFT)) & MASK32\n        return value\n\n    def mix(x, y):\n        result_x = (MIX_MULT_L * x) & MASK32\n        result_y = (MIX_MULT_R * y) & MASK32\n        result = (result_x - result_y) & MASK32\n        result = (result ^ (result >> XSHIFT)) & MASK32\n        return result\n\n    # Add in the entropy to the pool.\n    for i in range(len(pool)):\n        pool[i] = hash(entropy[i])\n\n    # Mix all bits together so late bits can affect earlier bits.\n    for i_src in range(len(pool)):\n        for i_dst in range(len(pool)):\n            if i_src != i_dst:\n                pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))\n\n    hash_const_B = INIT_B\n    state = []\n    for i_dst in range(4):\n        data_val = pool[i_dst]\n        data_val = (data_val ^ hash_const_B) & MASK32\n        hash_const_B = (hash_const_B * MULT_B) & MASK32\n        data_val = (data_val * hash_const_B) & MASK32\n        data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32\n        state.append(data_val)\n    return state\n\n\ndef _worker_loop(\n    dataset_kind,\n    dataset,\n    index_queue,\n    data_queue,\n    done_event,\n    auto_collation,\n    collate_fn,\n    drop_last,\n    base_seed,\n    init_fn,\n    worker_id,\n    num_workers,\n    persistent_workers,\n):\n    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the\n    # logic of this function.\n    try:\n\n        def cleanup_shm_at_exit(num, frame):\n            unlink_all_shared_memory()\n            # Use os._exit() to handle the exit of the subprocess to avoid shared memory leaks\n            # caused by the subprocess continuing for a period of time after the parent process ends.\n            os._exit(0)\n\n        _prctl_pr_set_pdeathsig(signal.SIGINT)\n\n        # Initialize C side signal handlers for SIGBUS and SIGSEGV. 
Python signal\n        # module's handlers are executed after Python returns from C low-level\n        # handlers, likely when the same fatal signal had already happened\n        # again.\n        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers\n        signal_handling._set_worker_signal_handlers()\n        signal.signal(signal.SIGTERM, cleanup_shm_at_exit)\n        signal.signal(signal.SIGINT, cleanup_shm_at_exit)\n        flow.set_num_threads(1)\n        seed = base_seed + worker_id\n        random.seed(seed)\n        flow.manual_seed(seed)\n        if HAS_NUMPY:\n            np_seed = _generate_state(base_seed, worker_id)\n            import numpy as np\n\n            np.random.seed(np_seed)\n\n        global _worker_info\n        _worker_info = WorkerInfo(\n            id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset\n        )\n\n        from oneflow.utils.data import _DatasetKind\n\n        init_exception = None\n\n        try:\n            if init_fn is not None:\n                init_fn(worker_id)\n\n            fetcher = _DatasetKind.create_fetcher(\n                dataset_kind, dataset, auto_collation, collate_fn, drop_last\n            )\n        except Exception:\n            init_exception = ExceptionWrapper(\n                where=\"in DataLoader worker process {}\".format(worker_id)\n            )\n\n        # When using Iterable mode, some workers can exit earlier than others due\n        # to the IterableDataset behaving differently for different workers.\n        # When such things happen, an `_IterableDatasetStopIteration` object is\n        # sent over to the main process with the ID of this worker, so that the\n        # main process won't send more tasks to this worker, and will send\n        # `None` to this worker to properly exit it.\n        #\n        # Note that we cannot set `done_event` from a worker as it is shared\n        # among all processes. Instead, we set the `iteration_end` flag to\n        # signify that the iterator is exhausted. When either `done_event` or\n        # `iteration_end` is set, we skip all processing steps and just wait for\n        # `None`.\n        iteration_end = False\n\n        watchdog = ManagerWatchdog()\n\n        while watchdog.is_alive():\n            try:\n                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)\n            except queue.Empty:\n                continue\n            if isinstance(r, _ResumeIteration):\n                # Acknowledge the main process\n                data_queue.put((r, None))\n                iteration_end = False\n                # Recreate the fetcher for worker-reuse policy\n                fetcher = _DatasetKind.create_fetcher(\n                    dataset_kind, dataset, auto_collation, collate_fn, drop_last\n                )\n                continue\n            elif r is None:\n                # Received the final signal\n                assert done_event.is_set() or iteration_end\n                break\n            elif done_event.is_set() or iteration_end:\n                # `done_event` is set. But I haven't received the final signal\n                # (None) yet. 
I will keep looping until I get it, and skip the\n                # processing steps.\n                continue\n            idx, index = r\n            data: Union[_IterableDatasetStopIteration, ExceptionWrapper]\n\n            if init_exception is not None:\n                data = init_exception\n                init_exception = None\n            else:\n                try:\n                    data = fetcher.fetch(index)\n                except Exception as e:\n                    if (\n                        isinstance(e, StopIteration)\n                        and dataset_kind == _DatasetKind.Iterable\n                    ):\n                        data = _IterableDatasetStopIteration(worker_id)\n                        # Set `iteration_end`\n                        #   (1) to save future `next(...)` calls, and\n                        #   (2) to avoid sending multiple `_IterableDatasetStopIteration`s.\n                        iteration_end = True\n                    else:\n                        # It is important that we don't store exc_info in a variable.\n                        # `ExceptionWrapper` does the correct thing.\n                        # See NOTE [ Python Traceback Reference Cycle Problem ]\n                        data = ExceptionWrapper(\n                            where=\"in DataLoader worker process {}\".format(worker_id)\n                        )\n            data_queue.put((idx, data))\n            del data, idx, index, r  # save memory\n    except KeyboardInterrupt:\n        # Main process will raise KeyboardInterrupt anyway.\n        pass\n    if done_event.is_set():\n        data_queue.cancel_join_thread()\n        data_queue.close()\n\n    # Python subprocess will be exited by os._exit(), which skips destructors of\n    # C++ objects, so we should explicitly call unlink_all_shared_memory() here\n    unlink_all_shared_memory()\n"
  },
  {
    "path": "python/oneflow/utils/data/dataloader.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport warnings\nimport os\nimport threading\nimport itertools\nimport queue\n\nfrom typing import Any, Callable, TypeVar, Generic, Sequence, List, Optional\nimport multiprocessing as python_multiprocessing\n\nimport oneflow.multiprocessing as multiprocessing\nfrom oneflow._utils import ExceptionWrapper\nimport oneflow as flow\nimport numpy as np\n\nstring_classes = (str, bytes)\n\nfrom . import (\n    IterableDataset,\n    Sampler,\n    SequentialSampler,\n    RandomSampler,\n    BatchSampler,\n    Dataset,\n)\nfrom . import _utils\n\nT_co = TypeVar(\"T_co\", covariant=True)\nT = TypeVar(\"T\")\n_worker_init_fn_t = Callable[[int], None]\n\n# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that\n# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.\n# See https://github.com/python/mypy/issues/3737.\n_collate_fn_t = Callable[[List[T]], Any]\n\n\n# This function used to be defined in this file. However, it was moved to\n# _utils/collate.py. Although it is rather hard to access this from user land\n# (one has to explicitly directly `import flow.utils.data.dataloader`), there\n# probably is user code out there using it. This aliasing maintains BC in this\n# aspect.\ndefault_collate: _collate_fn_t = _utils.collate.default_collate\n\nget_worker_info = _utils.worker.get_worker_info\n\n\nclass _DatasetKind(object):\n    Map = 0\n    Iterable = 1\n\n    @staticmethod\n    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):\n        if kind == _DatasetKind.Map:\n            return _utils.fetch._MapDatasetFetcher(\n                dataset, auto_collation, collate_fn, drop_last\n            )\n        else:\n            return _utils.fetch._IterableDatasetFetcher(\n                dataset, auto_collation, collate_fn, drop_last\n            )\n\n\nclass _InfiniteConstantSampler(Sampler):\n    r\"\"\"Analogous to ``itertools.repeat(None, None)``.\n    Used as sampler for :class:`~flow.utils.data.IterableDataset`.\n\n    Args:\n        data_source (Dataset): dataset to sample from\n    \"\"\"\n\n    def __init__(self):\n        super(_InfiniteConstantSampler, self).__init__(None)\n\n    def __iter__(self):\n        while True:\n            yield None\n\n\nclass DataLoader(Generic[T_co]):\n    r\"\"\"\n    Data loader. 
Combines a dataset and a sampler, and provides an iterable over\n    the given dataset.\n\n    The :class:`~oneflow.utils.data.DataLoader` supports both map-style and\n    iterable-style datasets with single- or multi-process loading, customizing\n    loading order and optional automatic batching (collation) and memory pinning.\n\n    See :py:mod:`oneflow.utils.data` documentation page for more details.\n\n    For compatibility, the design of our dataloader is consistent with PyTorch's, ref: https://github.com/pytorch/pytorch/tree/v1.7.0\n\n    Args:\n        dataset (Dataset): dataset from which to load the data.\n        batch_size (int, optional): how many samples per batch to load\n            (default: ``1``).\n        shuffle (bool, optional): set to ``True`` to have the data reshuffled\n            at every epoch (default: ``False``).\n        sampler (Sampler or Iterable, optional): defines the strategy to draw\n            samples from the dataset. Can be any ``Iterable`` with ``__len__``\n            implemented. If specified, :attr:`shuffle` must not be specified.\n        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but\n            returns a batch of indices at a time. Mutually exclusive with\n            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,\n            and :attr:`drop_last`.\n        num_workers (int, optional): how many subprocesses to use for data\n            loading (default: ``0``). ``0`` means that the data will be loaded in the main process.\n        collate_fn (callable, optional): merges a list of samples to form a\n            mini-batch of Tensor(s).  Used when using batched loading from a\n            map-style dataset.\n        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors\n            into CUDA pinned memory before returning them.  If your data elements\n            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,\n            see the example below. (default: ``False``)\n        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,\n            if the dataset size is not divisible by the batch size. If ``False`` and\n            the size of dataset is not divisible by the batch size, then the last batch\n            will be smaller. (default: ``False``)\n        timeout (numeric, optional): if positive, the timeout value for collecting a batch\n            from workers. Should always be non-negative. (default: ``0``)\n        worker_init_fn (callable, optional): If not ``None``, this will be called on each\n            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as\n            input, after seeding and before data loading. (default: ``None``)\n        prefetch_factor (int, optional, keyword-only arg): Number of samples loaded\n            in advance by each worker. ``2`` means there will be a total of\n            2 * num_workers samples prefetched across all workers. (default: ``2``)\n        persistent_workers (bool, optional): If ``True``, the data loader will immediately\n            initialize worker processes and will not shut them down after a dataset has been\n            consumed once. This keeps the worker `Dataset` instances alive.\n            If you are using oneflow with RDMA support in distributed training,\n            ``persistent_workers`` must be ``True``; otherwise you may encounter a\n            segmentation fault. (default: ``False``)\n\n\n    .. 
    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`\n                 cannot be an unpicklable object, e.g., a lambda function.\n\n    .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.\n                 When :attr:`dataset` is an :class:`~flow.utils.data.IterableDataset`,\n                 it instead returns an estimate based on ``len(dataset) / batch_size``, with proper\n                 rounding depending on :attr:`drop_last`, regardless of multi-process loading\n                 configurations. This represents the best guess OneFlow can make because OneFlow\n                 trusts user :attr:`dataset` code in correctly handling multi-process\n                 loading to avoid duplicate data.\n\n                 However, if sharding results in multiple workers having incomplete last batches,\n                 this estimate can still be inaccurate, because (1) an otherwise complete batch can\n                 be broken into multiple ones and (2) more than one batch worth of samples can be\n                 dropped when :attr:`drop_last` is set. Unfortunately, OneFlow cannot detect such\n                 cases in general.\n\n    \"\"\"\n    dataset: Dataset[T_co]\n    batch_size: Optional[int]\n    num_workers: int\n    pin_memory: bool\n    drop_last: bool\n    timeout: float\n    sampler: Sampler\n    prefetch_factor: int\n    _iterator: Optional[\"_BaseDataLoaderIter\"]\n    __initialized = False\n\n    def __init__(\n        self,\n        dataset: Dataset[T_co],\n        batch_size: Optional[int] = 1,\n        shuffle: bool = False,\n        sampler: Optional[Sampler[int]] = None,\n        batch_sampler: Optional[Sampler[Sequence[int]]] = None,\n        num_workers: int = 0,\n        collate_fn: Optional[_collate_fn_t] = None,\n        pin_memory: bool = False,\n        drop_last: bool = False,\n        timeout: float = 0,\n        worker_init_fn: Optional[_worker_init_fn_t] = None,\n        multiprocessing_context=None,\n        generator=flow.Generator(\"cpu\"),\n        *,\n        prefetch_factor: int = 2,\n        persistent_workers: bool = False\n    ):\n\n        if num_workers < 0:\n            raise ValueError(\n                \"num_workers option should be non-negative; \"\n                \"use num_workers=0 to disable multiprocessing.\"\n            )\n        else:\n            self.num_workers = num_workers\n\n        if timeout < 0:\n            raise ValueError(\"timeout option should be non-negative\")\n\n        if self.num_workers == 0 and prefetch_factor != 2:\n            raise ValueError(\n                \"prefetch_factor option can only be specified in multiprocessing; \"\n                \"let num_workers > 0 to enable multiprocessing.\"\n            )\n        assert prefetch_factor > 0\n\n        if persistent_workers and num_workers == 0:\n            raise ValueError(\"persistent_workers option needs num_workers > 0\")\n\n        self.dataset = dataset\n        self.prefetch_factor = prefetch_factor\n        self.pin_memory = pin_memory\n        self.timeout = timeout\n        self.worker_init_fn = worker_init_fn\n        self.multiprocessing_context = multiprocessing_context\n\n        # Check dataset-related args before checking samplers because we want to\n        # tell users that iterable-style datasets are incompatible with custom\n        # samplers first, so that they don't learn that this combo doesn't work\n        # after spending time fixing the custom sampler errors.\n
        if isinstance(dataset, IterableDataset):\n            self._dataset_kind = _DatasetKind.Iterable\n            # NOTE [ Custom Samplers and IterableDataset ]\n            #\n            # `IterableDataset` does not support custom `batch_sampler` or\n            # `sampler` since the key is irrelevant (unless we support\n            # generator-style dataset one day...).\n            #\n            # For `sampler`, we always create a dummy sampler. This is an\n            # infinite sampler even when the dataset may have an implemented\n            # finite `__len__` because in multi-process data loading, naive\n            # settings will return duplicated data (which may be desired), and\n            # thus using a sampler with length matching that of the dataset will\n            # cause data loss (you may have duplicates of the first couple of\n            # batches, but never see anything afterwards). Therefore,\n            # `IterableDataset` always uses an infinite sampler, an instance of\n            # `_InfiniteConstantSampler` defined above.\n            #\n            # A custom `batch_sampler` essentially only controls the batch size.\n            # However, it is unclear how useful it would be since an iterable-style\n            # dataset can handle that within itself. Moreover, it is pointless\n            # in multi-process data loading as the assignment order of batches\n            # to workers is an implementation detail so users cannot control\n            # how to batchify each worker's iterable. Thus, we disable this\n            # option. If this turns out to be useful in the future, we can re-enable\n            # this, and support custom samplers that specify the assignments to\n            # specific workers.\n            if shuffle is not False:\n                raise ValueError(\n                    \"DataLoader with IterableDataset: expected unspecified \"\n                    \"shuffle option, but got shuffle={}\".format(shuffle)\n                )\n            elif sampler is not None:\n                # See NOTE [ Custom Samplers and IterableDataset ]\n                raise ValueError(\n                    \"DataLoader with IterableDataset: expected unspecified \"\n                    \"sampler option, but got sampler={}\".format(sampler)\n                )\n            elif batch_sampler is not None:\n                # See NOTE [ Custom Samplers and IterableDataset ]\n                raise ValueError(\n                    \"DataLoader with IterableDataset: expected unspecified \"\n                    \"batch_sampler option, but got batch_sampler={}\".format(\n                        batch_sampler\n                    )\n                )\n        else:\n            self._dataset_kind = _DatasetKind.Map\n\n        if sampler is not None and shuffle:\n            raise ValueError(\"sampler option is mutually exclusive with shuffle\")\n\n        if batch_sampler is not None:\n            # auto_collation with custom batch_sampler\n            if batch_size != 1 or shuffle or sampler is not None or drop_last:\n                raise ValueError(\n                    \"batch_sampler option is mutually exclusive \"\n                    \"with batch_size, shuffle, sampler, and \"\n                    \"drop_last\"\n                )\n            batch_size = None\n            drop_last = False\n        elif batch_size is None:\n            # no auto_collation\n            if drop_last:\n                raise ValueError(\n
auto-batching \"\n                    \"and is mutually exclusive with drop_last\"\n                )\n\n        if sampler is None:  # give default samplers\n            if self._dataset_kind == _DatasetKind.Iterable:\n                # See NOTE [ Custom Samplers and IterableDataset ]\n                sampler = _InfiniteConstantSampler()\n            else:  # map-style\n                if shuffle:\n                    # Cannot statically verify that dataset is Sized\n                    # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]\n                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore\n                else:\n                    sampler = SequentialSampler(dataset)\n\n        if batch_size is not None and batch_sampler is None:\n            # auto_collation without custom batch_sampler\n            batch_sampler = BatchSampler(sampler, batch_size, drop_last)\n\n        self.batch_size = batch_size\n        self.drop_last = drop_last\n        self.sampler = sampler\n        self.batch_sampler = batch_sampler\n        self.generator = generator\n\n        if collate_fn is None:\n            if self._auto_collation:\n                collate_fn = _utils.collate.default_collate\n            else:\n                collate_fn = _utils.collate.default_convert\n\n        self.collate_fn = collate_fn\n        self.persistent_workers = persistent_workers\n\n        self.__initialized = True\n        self._IterableDataset_len_called = (\n            None  # See NOTE [ IterableDataset and __len__ ]\n        )\n\n        self._iterator = self._get_iterator() if self.persistent_workers else None\n\n    def _get_iterator(self) -> \"_BaseDataLoaderIter\":\n        if self.num_workers == 0:\n            return _SingleProcessDataLoaderIter(self)\n        else:\n            self.check_worker_number_rationality()\n            return _MultiProcessingDataLoaderIter(self)\n\n    def __setattr__(self, attr, val):\n        if self.__initialized and attr in (\n            \"batch_size\",\n            \"batch_sampler\",\n            \"sampler\",\n            \"drop_last\",\n            \"dataset\",\n            \"persistent_workers\",\n        ):\n            raise ValueError(\n                \"{} attribute should not be set after {} is \"\n                \"initialized\".format(attr, self.__class__.__name__)\n            )\n\n        super(DataLoader, self).__setattr__(attr, val)\n\n    # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up\n    # since '_BaseDataLoaderIter' references 'DataLoader'.\n    def __iter__(self) -> \"_BaseDataLoaderIter\":\n        # When using a single worker the returned iterator should be\n        # created everytime to avoid reseting its state\n        # However, in the case of a multiple workers iterator\n        # the iterator is only created once in the lifetime of the\n        # DataLoader object so that workers can be reused\n        if self.persistent_workers and self.num_workers > 0:\n            if self._iterator is None:\n                self._iterator = self._get_iterator()\n            elif not self._iterator._status_reset:\n                self._iterator._reset(self)\n            return self._iterator\n        else:\n            return self._get_iterator()\n\n    @property\n    def _auto_collation(self):\n        return self.batch_sampler is not None\n\n    @property\n    def _index_sampler(self):\n        # The actual sampler used for generating 
        # The actual sampler used for generating indices for `_DatasetFetcher`\n        # (see _utils/fetch.py) to read data each time. This would be\n        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.\n        # We can't change `.sampler` and `.batch_sampler` attributes for BC\n        # reasons.\n        if self._auto_collation:\n            return self.batch_sampler\n        else:\n            return self.sampler\n\n    def __len__(self) -> int:\n        if self._dataset_kind == _DatasetKind.Iterable:\n            # NOTE [ IterableDataset and __len__ ]\n            #\n            # For `IterableDataset`, `__len__` could be inaccurate when one naively\n            # does multi-processing data loading, since the samples will be duplicated.\n            # However, no real use case should be actually using that behavior, so\n            # it should count as a user error. We should generally trust user\n            # code to do the proper thing (e.g., configure each replica differently\n            # in `__iter__`), and give us the correct `__len__` if they choose to\n            # implement it (this will still throw if the dataset does not implement\n            # a `__len__`).\n            #\n            # To provide a further warning, we track if `__len__` was called on the\n            # `DataLoader`, save the returned value in `self._IterableDataset_len_called`,\n            # and warn if the iterator ends up yielding more than this number of samples.\n\n            # Cannot statically verify that dataset is Sized\n            length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore\n            if (\n                self.batch_size is not None\n            ):  # IterableDataset doesn't allow custom sampler or batch_sampler\n                from math import ceil\n\n
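                # e.g., with len(dataset) == 10 and batch_size == 4:\n                #   drop_last=True  -> 10 // 4 == 2 batches (2 trailing samples dropped)\n                #   drop_last=False -> ceil(10 / 4) == 3 batches (the last batch has 2 samples)\n                if self.drop_last:\n                    length = length // self.batch_size\n                else:\n                    length = ceil(length / self.batch_size)\n            return length\n        else:\n            return len(self._index_sampler)\n\n    def check_worker_number_rationality(self):\n        def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):\n\n            suggested_max_worker_msg = (\n                (\n                    (\n                        \"Our suggested max number of workers in the current system is {}{}, which is smaller \"\n                        \"than what this DataLoader is going to create.\"\n                    ).format(\n                        num_worker_suggest,\n                        (\n                            \"\"\n                            if cpuset_checked\n                            else \" (`cpuset` is not taken into account)\"\n                        ),\n                    )\n                )\n                if num_worker_suggest is not None\n                else (\n                    \"DataLoader is not able to compute a suggested max number of workers in the current system.\"\n                )\n            )\n\n            warn_msg = (\n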
{} \"\n                \"Please be aware that excessive worker creation might get DataLoader running slow or even freeze, \"\n                \"lower the worker number to avoid potential slowness/freeze if necessary.\"\n            ).format(num_worker_created, suggested_max_worker_msg)\n            return warn_msg\n\n        if not self.num_workers or self.num_workers == 0:\n            return\n\n        # try to compute a suggested max number of worker based on system's resource\n        max_num_worker_suggest = None\n        cpuset_checked = False\n        if hasattr(os, \"sched_getaffinity\"):\n            try:\n                max_num_worker_suggest = len(os.sched_getaffinity(0))\n                cpuset_checked = True\n            except Exception:\n                pass\n        if max_num_worker_suggest is None:\n            # os.cpu_count() could return Optional[int]\n            # get cpu count first and check None in order to satify mypy check\n            cpu_count = os.cpu_count()\n            if cpu_count is not None:\n                max_num_worker_suggest = cpu_count\n\n        if max_num_worker_suggest is None:\n            warnings.warn(\n                _create_warning_msg(\n                    max_num_worker_suggest, self.num_workers, cpuset_checked\n                )\n            )\n            return\n\n        if self.num_workers > max_num_worker_suggest:\n            warnings.warn(\n                _create_warning_msg(\n                    max_num_worker_suggest, self.num_workers, cpuset_checked\n                )\n            )\n\n\nclass _BaseDataLoaderIter(object):\n    def __init__(self, loader: DataLoader) -> None:\n        self._dataset = loader.dataset\n        self._dataset_kind = loader._dataset_kind\n        self._IterableDataset_len_called = loader._IterableDataset_len_called\n        self._auto_collation = loader._auto_collation\n        self._drop_last = loader.drop_last\n        self._index_sampler = loader._index_sampler\n        self._num_workers = loader.num_workers\n        self._prefetch_factor = loader.prefetch_factor\n        self._pin_memory = loader.pin_memory and flow.cuda.is_available()\n        self._timeout = loader.timeout\n        self._collate_fn = loader.collate_fn\n        self._sampler_iter = iter(self._index_sampler)\n        # self._base_seed = flow.empty((), dtype=flow.int64).random_(generator=loader.generator).item()\n        self._base_seed = flow.randint(\n            0, np.iinfo(np.int64).max, (), generator=loader.generator\n        ).item()\n        self._persistent_workers = loader.persistent_workers\n        self._num_yielded = 0\n        self._profile_name = \"enumerate(DataLoader)#{}.__next__\".format(\n            self.__class__.__name__\n        )\n        self._status_reset = True\n\n    def __iter__(self) -> \"_BaseDataLoaderIter\":\n        return self\n\n    def _reset(self, loader, first_iter=False):\n        self._status_reset = True\n        self._sampler_iter = iter(self._index_sampler)\n        self._num_yielded = 0\n        self._IterableDataset_len_called = loader._IterableDataset_len_called\n\n    def _next_index(self):\n        return next(self._sampler_iter)  # may raise StopIteration\n\n    def _next_data(self):\n        raise NotImplementedError\n\n    def __next__(self) -> Any:\n        self._status_reset = False\n        if self._sampler_iter is None:\n            self._reset()\n        data = self._next_data()\n        self._num_yielded += 1\n        if (\n            self._dataset_kind == 
            self._dataset_kind == _DatasetKind.Iterable\n            and self._IterableDataset_len_called is not None\n            and self._num_yielded > self._IterableDataset_len_called\n        ):\n            warn_msg = (\n                \"Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} \"\n                \"samples have been fetched. \"\n            ).format(self._dataset, self._IterableDataset_len_called, self._num_yielded)\n            if self._num_workers > 1:\n                warn_msg += \"Multiprocessing dataloader is not supported yet!\"\n            warnings.warn(warn_msg)\n\n        return data\n\n    next = __next__\n\n    def __len__(self) -> int:\n        return len(self._index_sampler)\n\n    def __getstate__(self):\n        raise NotImplementedError(\n            \"{} cannot be pickled\".format(self.__class__.__name__)\n        )\n\n\nclass _SingleProcessDataLoaderIter(_BaseDataLoaderIter):\n    def __init__(self, loader):\n        super(_SingleProcessDataLoaderIter, self).__init__(loader)\n        assert self._timeout == 0\n        assert 0 <= self._num_workers <= 1\n\n        self._dataset_fetcher = _DatasetKind.create_fetcher(\n            self._dataset_kind,\n            self._dataset,\n            self._auto_collation,\n            self._collate_fn,\n            self._drop_last,\n        )\n\n    def _next_data(self):\n        index = self._next_index()  # may raise StopIteration\n        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration\n        if self._pin_memory:\n            data = _utils.pin_memory.pin_memory(data)\n        return data\n\n\nclass _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):\n    r\"\"\"Iterates once over the DataLoader's dataset, as specified by the sampler\"\"\"\n\n    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]\n    #\n    # Preliminary:\n    #\n    # Our data model looks like this (queues are indicated with curly brackets):\n    #\n    #                main process                              ||\n    #                     |                                    ||\n    #               {index_queue}                              ||\n    #                     |                                    ||\n    #              worker processes                            ||     DATA\n    #                     |                                    ||\n    #            {worker_result_queue}                         ||     FLOW\n    #                     |                                    ||\n    #      pin_memory_thread of main process                   ||   DIRECTION\n    #                     |                                    ||\n    #               {data_queue}                               ||\n    #                     |                                    ||\n    #                data output                               \\/\n    #\n    # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if\n    #      `pin_memory=False`.\n    #\n    #\n    # Terminating multiprocessing logic requires very careful design. In\n    # particular, we need to make sure that\n    #\n
    #   1. The iterator gracefully exits the workers when its last reference is\n    #      gone or it is depleted.\n    #\n    #      In this case, the workers should be gracefully exited because the\n    #      main process may still need to continue to run, and we want cleaning\n    #      up code in the workers to be executed (e.g., releasing GPU memory).\n    #      Naturally, we implement the shutdown logic in `__del__` of\n    #      DataLoaderIterator.\n    #\n    #      We delay the discussion on the logic in this case until later.\n    #\n    #   2. The iterator exits the workers when the loader process and/or worker\n    #      processes exit normally or with an error.\n    #\n    #      We set all workers and `pin_memory_thread` to have `daemon=True`.\n    #\n    #      You may ask, why can't we make the workers non-daemonic, and\n    #      gracefully exit using the same logic as we have in `__del__` when the\n    #      iterator gets deleted (see 1 above)?\n    #\n    #      First of all, `__del__` is **not** guaranteed to be called when the\n    #      interpreter exits. Even if it is called, by the time it executes,\n    #      many Python core library resources may already be freed, and even\n    #      simple things like acquiring an internal lock of a queue may hang.\n    #      Therefore, in this case, we actually need to prevent `__del__` from\n    #      being executed, and rely on the automatic termination of daemonic\n    #      children.\n    #\n    #      Thus, we register an `atexit` hook that sets a global flag\n    #      `_utils.python_exit_status`. Since `atexit` hooks are executed in the\n    #      reverse order of registration, we are guaranteed that this flag is\n    #      set before library resources we use are freed (which, at least in\n    #      CPython, is done via an `atexit` handler defined in\n    #      `multiprocessing/util.py`\n    #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362\n    #      registered when an object requiring this mechanism is first\n    #      created, e.g., `mp.Queue`\n    #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103\n    #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29\n    #      )\n    #\n    #      So in `__del__`, we check if `_utils.python_exit_status` is set or\n    #      `None` (freed), and perform a no-op if so.\n    #\n    #      However, simply letting library clean-up code run can also be bad,\n    #      because such code (i.e., `multiprocessing.util._exit_function()`)\n    #      includes joining putting threads for `mp.Queue`, which can be blocking.\n    #      Hence, the main process calls `cancel_join_thread` on its putting\n    #      queues at creation.  See later section\n    #      [ 3b. A process won't hang when putting into a queue; ]\n    #      for more details.\n    #\n    #      Here are two example cases where library clean-up code can run\n    #      before `__del__` is called:\n    #\n    #        1. If we hold onto a reference to the iterator, it more often\n    #           than not tries to do `multiprocessing` library cleaning before\n    #           clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)\n    #           and thus prevents our cleaning-up code from running first.\n    #\n
    #        2. A similar issue arises when a `DataLoader` is used in a subprocess.\n    #           When a process ends, it shuts all its daemonic children\n    #           down with a SIGTERM (instead of joining them without a timeout).\n    #           Similarly for threads, but by a different mechanism. This fact,\n    #           together with a few implementation details of multiprocessing, forces\n    #           us to make workers daemonic. All of our problems arise when a\n    #           DataLoader is used in a subprocess, and are caused by multiprocessing\n    #           code which looks more or less like this:\n    #\n    #               try:\n    #                   your_function_using_a_dataloader()\n    #               finally:\n    #                   multiprocessing.util._exit_function()\n    #\n    #           The joining/termination mentioned above happens inside\n    #           `_exit_function()`. Now, if `your_function_using_a_dataloader()`\n    #           throws, the stack trace stored in the exception will prevent the\n    #           frame which uses `DataLoaderIter` from being freed. If the frame has any\n    #           reference to the `DataLoaderIter` (e.g., in a method of the iter),\n    #           its `__del__`, which starts the shutdown procedure, will not be\n    #           called. That, in turn, means that workers aren't notified. Attempting\n    #           to join in `_exit_function` will then result in a hang.\n    #\n    #           For context, `_exit_function` is also registered as an `atexit` call.\n    #           So it is unclear to me (@ssnl) why this is needed in a finally block.\n    #           The code dates back to 2008 and there is no comment on the original\n    #           PEP 371 or patch https://bugs.python.org/issue3050 (containing both\n    #           the finally block and the `atexit` registration) that explains this.\n    #\n    #\n    #      Finally, another choice is to just shut down workers with the logic in 1\n    #      above whenever we see an error in `next`. This isn't ideal because\n    #        a. It prevents users from using try-catch to resume data loading.\n    #        b. It doesn't prevent hanging if users have references to the\n    #           iterator.\n    #\n    #   3. All processes exit if any of them die unexpectedly by fatal signals.\n    #\n    #      As shown above, the workers are set as daemonic children of the main\n    #      process. However, automatic cleaning-up of such child processes only\n    #      happens if the parent process exits gracefully (e.g., not via fatal\n    #      signals like SIGKILL). So we must ensure that each process will exit\n    #      even if the process that should send/receive data to/from it was\n    #      killed, i.e.,\n    #\n
    #        a. A process won't hang when getting from a queue.\n    #\n    #           Even with carefully designed data dependencies (i.e., a `put()`\n    #           always corresponding to a `get()`), hanging on `get()` can still\n    #           happen when data in the queue is corrupted (e.g., due to\n    #           `cancel_join_thread` or unexpected exit).\n    #\n    #           For child exit, we set a timeout whenever we try to get data\n    #           from `data_queue`, and check the workers' status on each timeout\n    #           and error.\n    #           See `_DataLoaderiter._get_batch()` and\n    #           `_DataLoaderiter._try_get_data()` for details.\n    #\n    #           Additionally, for child exit on non-Windows platforms, we also\n    #           register a SIGCHLD handler (which is not supported on Windows) on\n    #           the main process, which checks if any of the workers fail in the\n    #           (Python) handler. This is more efficient and faster in detecting\n    #           worker failures, compared to only using the above mechanism.\n    #           See `DataLoader.cpp` and `_utils/signal_handling.py` for details.\n    #\n    #           For `.get()` calls where the sender(s) is not the workers, we\n    #           guard them with timeouts, and check the status of the sender\n    #           when a timeout happens:\n    #             + in the workers, the `_utils.worker.ManagerWatchdog` class\n    #               checks the status of the main process.\n    #             + if `pin_memory=True`, when getting from `pin_memory_thread`,\n    #               check `pin_memory_thread` status periodically until `.get()`\n    #               returns or we see that `pin_memory_thread` has died.\n    #\n    #        b. A process won't hang when putting into a queue;\n    #\n    #           We use `mp.Queue` which has a separate background thread to put\n    #           objects from an unbounded buffer array. The background thread is\n    #           daemonic and usually automatically joined when the process\n    #           *exits*.\n    #\n    #           In case that the receiver has ended abruptly while\n    #           reading from the pipe, the join will hang forever.\n
    #           The usual solution for this in Python is calling\n    #           `q.cancel_join_thread`, which prevents automatically joining it\n    #           when finalizing (exiting).\n    #\n    #           Nonetheless, `cancel_join_thread` must only be called when the\n    #           queue is **not** going to be read from or written to by another\n    #           process, because it may hold onto a lock or leave corrupted data\n    #           in the queue, leading other readers/writers to hang.\n    #\n    #           Hence,\n    #             + For worker processes, we only do so (for their output\n    #               queues, i.e., `worker_result_queue`) before exiting.\n    #             + For `pin_memory_thread`, its output queue `data_queue` is a\n    #               `queue.Queue` that does blocking `put` if the queue is full.\n    #               So there is no above problem, but as a result, in\n    #               `_pin_memory_loop`, we do need to wrap the `put` in a loop\n    #               that breaks not only upon success, but also when the main\n    #               process stops reading, i.e., is shutting down.\n    #             + For the loader process, we `cancel_join_thread()` for all\n    #               `_index_queues` because the whole purpose of workers and\n    #               `pin_memory_thread` is to serve the loader process.  If the\n    #               loader process is already exiting, we don't really care if\n    #               the queues are corrupted.\n    #\n    #\n    # Now let's get back to 1:\n    #   how we gracefully exit the workers when the last reference to the\n    #   iterator is gone.\n    #\n    # To achieve this, we implement the following logic along with the design\n    # choices mentioned above:\n    #\n    # `workers_done_event`:\n    #   A `multiprocessing.Event` shared among the main process and all worker\n    #   processes. This is used to signal the workers that the iterator is\n    #   shutting down. After it is set, they will not send processed data to\n    #   queues anymore, and only wait for the final `None` before exiting.\n    #   `done_event` isn't strictly needed. I.e., we can just check for `None`\n    #   from the input queue, but it allows us to skip wasting resources\n    #   processing data if we are already shutting down.\n    #\n    # `pin_memory_thread_done_event`:\n    #   A `threading.Event` for a similar purpose to that of\n    #   `workers_done_event`, but is for the `pin_memory_thread`. The reason\n    #   that separate events are needed is that `pin_memory_thread` reads from\n    #   the output queue of the workers. But the workers, upon seeing that\n    #   `workers_done_event` is set, only want to see the final `None`, and are\n    #   not required to flush all data in the output queue (e.g., it may call\n    #   `cancel_join_thread` on that queue if its `IterableDataset` iterator\n    #   happens to exhaust coincidentally, which is out of the control of the\n    #   main process).\n
    #   Thus, since we will exit `pin_memory_thread` before the workers\n    #   (see below), two separate events are used.\n    #\n    # NOTE: In short, the protocol is that the main process will set these\n    #       `done_event`s and then send the corresponding processes/threads a `None`,\n    #       and that they may exit at any time after receiving the `None`.\n    #\n    # NOTE: Using `None` as the final signal is valid, since normal data will\n    #       always be a 2-tuple with the 1st element being the index of the data\n    #       transferred (different from dataset index/key), and the 2nd being\n    #       either the dataset key or the data sample (depending on which part\n    #       of the data model the queue is at).\n    #\n    # [ worker processes ]\n    #   While the loader process is alive:\n    #     Get from `index_queue`.\n    #       If we get anything else,\n    #          Check `workers_done_event`.\n    #            If set, continue to next iteration\n    #                    i.e., keep getting until we see the `None`, then exit.\n    #            Otherwise, process data:\n    #                If fetching from an `IterableDataset` and the iterator\n    #                    is exhausted, send an `_IterableDatasetStopIteration`\n    #                    object to signal iteration end. The main process, upon\n    #                    receiving such an object, will send `None` to this\n    #                    worker and not use the corresponding `index_queue`\n    #                    anymore.\n    #       If timed out,\n    #          Whether or not `workers_done_event` is set (we still need to see the\n    #          `None`), we must continue to the next iteration.\n    #   (outside loop)\n    #   If `workers_done_event` is set,  (this can be False with `IterableDataset`)\n    #     `data_queue.cancel_join_thread()`.  (Everything is ending here:\n    #                                          main process won't read from it;\n    #                                          other workers will also call\n    #                                          `cancel_join_thread`.)\n    #\n    # [ pin_memory_thread ]\n    #   # No need to check main thread. If this thread is alive, the main loader\n    #   # thread must be alive, because this thread is set as daemonic.\n    #   While `pin_memory_thread_done_event` is not set:\n    #     Get from `worker_result_queue`.\n    #       If timed out, continue to get in the next iteration.\n    #       Otherwise, process data.\n    #       While `pin_memory_thread_done_event` is not set:\n    #         Put processed data to `data_queue` (a `queue.Queue` with blocking put)\n    #         If timed out, continue to put in the next iteration.\n    #         Otherwise, break, i.e., continue with the outer loop.\n    #\n    #   NOTE: we don't check the status of the main thread because\n    #           1. if the process is killed by a fatal signal, `pin_memory_thread`\n    #              ends.\n    #           2. in other cases, either the cleaning-up in __del__ or the\n    #              automatic exit of the daemonic thread will take care of it.\n    #              This won't busy-wait either because `.get(timeout)` does not\n    #              busy-wait.\n    #\n    # [ main process ]\n    #   In the DataLoader Iter's `__del__`\n    #     b. Exit `pin_memory_thread`\n    #          i.   Set `pin_memory_thread_done_event`.\n    #          ii.  Put `None` in `worker_result_queue`.\n    #          iii. Join the `pin_memory_thread`.\n
    #          iv.  `worker_result_queue.cancel_join_thread()`.\n    #\n    #     c. Exit the workers.\n    #          i.   Set `workers_done_event`.\n    #          ii.  Put `None` in each worker's `index_queue`.\n    #          iii. Join the workers.\n    #          iv.  Call `.cancel_join_thread()` on each worker's `index_queue`.\n    #\n    #        NOTE: (c) is better placed after (b) because it may leave corrupted\n    #              data in `worker_result_queue`, which `pin_memory_thread`\n    #              reads from, in which case the `pin_memory_thread` can only\n    #              exit after timing out, which is slow. Nonetheless, the same thing\n    #              happens if a worker is killed by a signal at an unfortunate time,\n    #              but in other cases, we are better off having a non-corrupted\n    #              `worker_result_queue` for `pin_memory_thread`.\n    #\n    #   NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)\n    #         can be omitted.\n    #\n    # NB: `done_event`s aren't strictly needed. E.g., we can just check for\n    #     `None` from `index_queue`, but it allows us to skip wasting resources\n    #     processing indices already in `index_queue` if we are already shutting\n    #     down.\n\n    def __init__(self, loader):\n        super(_MultiProcessingDataLoaderIter, self).__init__(loader)\n        assert not flow.env.rdma_is_initialized(), (\n            \"RDMA is initialized! A _MultiProcessingDataLoaderIter can no longer be created. \"\n            \"Please make sure the DataLoader is created before invoking oneflow.env.init_rdma(). \"\n            \"If this condition is met, you can pass persistent_workers=True to the \"\n            \"DataLoader to avoid this error!\"\n        )\n        assert self._num_workers > 0\n        assert self._prefetch_factor > 0\n\n        if loader.multiprocessing_context is None:\n            multiprocessing_context = multiprocessing\n        else:\n            multiprocessing_context = loader.multiprocessing_context\n\n        self._worker_init_fn = loader.worker_init_fn\n        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))\n        # No certainty which module multiprocessing_context is\n        self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]\n        self._worker_pids_set = False\n        self._shutdown = False\n        self._workers_done_event = multiprocessing_context.Event()\n\n        self._index_queues = []\n        self._workers = []\n        for i in range(self._num_workers):\n            # No certainty which module multiprocessing_context is\n            index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]\n            # Need to `cancel_join_thread` here!\n            # See sections (2) and (3b) above.\n            index_queue.cancel_join_thread()\n\n            w = multiprocessing_context.Process(\n                target=_utils.worker._worker_loop,\n                args=(\n                    self._dataset_kind,\n                    self._dataset,\n                    index_queue,\n                    self._worker_result_queue,\n                    self._workers_done_event,\n                    self._auto_collation,\n                    self._collate_fn,\n                    self._drop_last,\n                    self._base_seed,\n                    self._worker_init_fn,\n                    i,\n                    self._num_workers,\n                    self._persistent_workers,\n                ),\n
            )\n            w.daemon = True\n            # NB: Process.start() actually takes some time as it needs to\n            #     start a process and pass the arguments over via a pipe.\n            #     Therefore, we only add a worker to the self._workers list after\n            #     it has started, so that we do not call .join() if the program dies\n            #     before it starts, and __del__ tries to join but will get:\n            #     AssertionError: can only join a started process.\n            w.start()\n            self._index_queues.append(index_queue)\n            self._workers.append(w)\n\n        if self._pin_memory:\n            self._pin_memory_thread_done_event = threading.Event()\n\n            # Queue is not type-annotated\n            self._data_queue = queue.Queue()  # type: ignore[var-annotated]\n            pin_memory_thread = threading.Thread(\n                target=_utils.pin_memory._pin_memory_loop,\n                args=(\n                    self._worker_result_queue,\n                    self._data_queue,\n                    flow.cuda.current_device(),\n                    self._pin_memory_thread_done_event,\n                ),\n            )\n            pin_memory_thread.daemon = True\n            pin_memory_thread.start()\n            # Similar to workers (see comment above), we only register\n            # pin_memory_thread once it is started.\n            self._pin_memory_thread = pin_memory_thread\n        else:\n            self._data_queue = self._worker_result_queue\n\n        # .pid can be None only before process is spawned (not the case, so ignore)\n        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]\n        _utils.signal_handling._set_SIGCHLD_handler()\n        self._worker_pids_set = True\n        self._reset(loader, first_iter=True)\n\n    def _reset(self, loader, first_iter=False):\n        super()._reset(loader, first_iter)\n        self._send_idx = 0  # idx of the next task to be sent to workers\n        self._rcvd_idx = 0  # idx of the next task to be returned in __next__\n        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).\n        # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)\n        #                  \\ (worker_id, data)   if data is already fetched (out-of-order)\n        self._task_info = {}\n        self._tasks_outstanding = (\n            0  # always equal to count(v for v in task_info.values() if len(v) == 1)\n        )\n        # A list of booleans representing whether each worker still has work to\n        # do, i.e., not having exhausted its iterable dataset object. It always\n        # contains all `True`s if not using an iterable-style dataset\n        # (i.e., if kind != Iterable).\n        # Note that this indicates that a worker still has work to do *for this epoch*.\n        # It does not mean that a worker is dead.\n
        # In case of `_persistent_workers`, the worker will be reset to\n        # available in the next epoch.\n        self._workers_status = [True for i in range(self._num_workers)]\n        # We resume the prefetching in case it was enabled\n        if not first_iter:\n            for idx in range(self._num_workers):\n                self._index_queues[idx].put(_utils.worker._ResumeIteration())\n            resume_iteration_cnt = self._num_workers\n            while resume_iteration_cnt > 0:\n                return_idx, return_data = self._get_data()\n                if isinstance(return_idx, _utils.worker._ResumeIteration):\n                    assert return_data is None\n                    resume_iteration_cnt -= 1\n        # prime the prefetch loop\n        for _ in range(self._prefetch_factor * self._num_workers):\n            self._try_put_index()\n\n    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):\n        # Tries to fetch data from `self._data_queue` once for a given timeout.\n        # This can also be used as the inner loop of fetching without timeout, with\n        # the sender status as the loop condition.\n        #\n        # This raises a `RuntimeError` if any worker died unexpectedly. This error\n        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`\n        # (only for non-Windows platforms), or the manual check below on errors\n        # and timeouts.\n        #\n        # Returns a 2-tuple:\n        #   (bool: whether we successfully got data, any: the data if successful else None)\n        try:\n            data = self._data_queue.get(timeout=timeout)\n            return (True, data)\n        except Exception as e:\n            # At timeout and error, we manually check whether any worker has\n            # failed. Note that this is the only mechanism for Windows to detect\n            # worker failures.\n            failed_workers = []\n            for worker_id, w in enumerate(self._workers):\n                if self._workers_status[worker_id] and not w.is_alive():\n                    failed_workers.append(w)\n                    self._mark_worker_as_unavailable(worker_id)\n            if len(failed_workers) > 0:\n                pids_str = \", \".join(str(w.pid) for w in failed_workers)\n                raise RuntimeError(\n                    \"DataLoader worker (pid(s) {}) exited unexpectedly\".format(pids_str)\n                ) from e\n            if isinstance(e, queue.Empty):\n                return (False, None)\n            import tempfile\n            import errno\n\n            try:\n                # Raise an exception if we are this close to the FDs limit.\n                # Apparently, trying to open only one file is not a sufficient\n                # test.\n                # See NOTE [ DataLoader on Linux and open files limit ]\n                fds_limit_margin = 10\n                fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]\n            except OSError as e:\n                if e.errno == errno.EMFILE:\n                    raise RuntimeError(\n                        \"Too many open files. Communication with the\"\n
Please increase the\"\n                        \" limit using `ulimit -n` in the shell or change the\"\n                        \" sharing strategy by calling\"\n                        \" `flow.multiprocessing.set_sharing_strategy('file_system')`\"\n                        \" at the beginning of your code\"\n                    ) from None\n            raise\n\n    def _get_data(self):\n        # Fetches data from `self._data_queue`.\n        #\n        # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,\n        # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`\n        # in a loop. This is the only mechanism to detect worker failures for\n        # Windows. For other platforms, a SIGCHLD handler is also used for\n        # worker failure detection.\n        #\n        # If `pin_memory=True`, we also need check if `pin_memory_thread` had\n        # died at timeouts.\n        if self._timeout > 0:\n            success, data = self._try_get_data(self._timeout)\n            if success:\n                return data\n            else:\n                raise RuntimeError(\n                    \"DataLoader timed out after {} seconds\".format(self._timeout)\n                )\n        elif self._pin_memory:\n            while self._pin_memory_thread.is_alive():\n                success, data = self._try_get_data()\n                if success:\n                    return data\n            else:\n                # while condition is false, i.e., pin_memory_thread died.\n                raise RuntimeError(\"Pin memory thread exited unexpectedly\")\n            # In this case, `self._data_queue` is a `queue.Queue`,. But we don't\n            # need to call `.task_done()` because we don't use `.join()`.\n        else:\n            while True:\n                success, data = self._try_get_data()\n                if success:\n                    return data\n\n    def _next_data(self):\n        while True:\n            # If the worker responsible for `self._rcvd_idx` has already ended\n            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),\n            # we try to advance `self._rcvd_idx` to find the next valid index.\n            #\n            # This part needs to run in the loop because both the `self._get_data()`\n            # call and `_IterableDatasetStopIteration` check below can mark\n            # extra worker(s) as dead.\n            while self._rcvd_idx < self._send_idx:\n                info = self._task_info[self._rcvd_idx]\n                worker_id = info[0]\n                if (\n                    len(info) == 2 or self._workers_status[worker_id]\n                ):  # has data or is still active\n                    break\n                del self._task_info[self._rcvd_idx]\n                self._rcvd_idx += 1\n            else:\n                # no valid `self._rcvd_idx` is found (i.e., didn't break)\n                if not self._persistent_workers:\n                    self._shutdown_workers()\n                raise StopIteration\n\n            # Now `self._rcvd_idx` is the batch index we want to fetch\n\n            # Check if the next sample has already been generated\n            if len(self._task_info[self._rcvd_idx]) == 2:\n                data = self._task_info.pop(self._rcvd_idx)[1]\n                return self._process_data(data)\n\n            assert not self._shutdown and self._tasks_outstanding > 0\n            idx, data = self._get_data()\n            self._tasks_outstanding -= 
            self._tasks_outstanding -= 1\n            if self._dataset_kind == _DatasetKind.Iterable:\n                # Check for _IterableDatasetStopIteration\n                if isinstance(data, _utils.worker._IterableDatasetStopIteration):\n                    if self._persistent_workers:\n                        self._workers_status[data.worker_id] = False\n                    else:\n                        self._mark_worker_as_unavailable(data.worker_id)\n                    self._try_put_index()\n                    continue\n\n            if idx != self._rcvd_idx:\n                # store out-of-order samples\n                self._task_info[idx] += (data,)\n            else:\n                del self._task_info[idx]\n                return self._process_data(data)\n\n    def _try_put_index(self):\n        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers\n\n        try:\n            index = self._next_index()\n        except StopIteration:\n            return\n        for _ in range(self._num_workers):  # find the next active worker, if any\n            worker_queue_idx = next(self._worker_queue_idx_cycle)\n            if self._workers_status[worker_queue_idx]:\n                break\n        else:\n            # not found (i.e., didn't break)\n            return\n\n        self._index_queues[worker_queue_idx].put((self._send_idx, index))\n        self._task_info[self._send_idx] = (worker_queue_idx,)\n        self._tasks_outstanding += 1\n        self._send_idx += 1\n\n    def _process_data(self, data):\n        self._rcvd_idx += 1\n        self._try_put_index()\n        if isinstance(data, ExceptionWrapper):\n            data.reraise()\n        return data\n\n    def _mark_worker_as_unavailable(self, worker_id, shutdown=False):\n        # Mark a worker as having finished its work, e.g., due to\n        # exhausting an `IterableDataset`. This should be used only when this\n        # `_MultiProcessingDataLoaderIter` is going to continue running.\n\n        assert self._workers_status[worker_id] or (\n            self._persistent_workers and shutdown\n        )\n\n        # Signal termination to that specific worker.\n        q = self._index_queues[worker_id]\n        # Indicate that no more data will be put on this queue by the current\n        # process.\n        q.put(None)\n\n        # Note that we don't actually join the worker here, nor do we remove the\n        # worker's pid from the C side struct because (1) joining may be slow, and\n        # (2) since we don't join, the worker may still raise an error, and we\n        # prefer capturing those, rather than ignoring them, even though they\n        # are raised after the worker has finished its job.\n        # Joining is deferred to `_shutdown_workers`, which is called when\n        # all workers finish their jobs (e.g., `IterableDataset` replicas) or\n        # when this iterator is garbage collected.\n\n        self._workers_status[worker_id] = False\n\n        assert self._workers_done_event.is_set() == shutdown\n\n    def _shutdown_workers(self):\n        # Called when shutting down this `_MultiProcessingDataLoaderIter`.\n        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on\n        # the logic of this function.\n\n        # See (2) of the note. 
If Python is shutting down, do no-op.\n        try:\n            python_exit_status = _utils.python_exit_status\n        except AttributeError:\n            # Python is shutting down and `_utils` has been freed\n            assert _utils is None\n            return\n        if python_exit_status is True or python_exit_status is None:\n            return\n        # Normal exit when last reference is gone / iterator is depleted.\n        # See (1) and the second half of the note.\n        if not self._shutdown:\n            self._shutdown = True\n            try:\n                # Normal exit when last reference is gone / iterator is depleted.\n                # See (1) and the second half of the note.\n\n                # Exit `pin_memory_thread` first because exiting workers may leave\n                # corrupted data in `worker_result_queue` which `pin_memory_thread`\n                # reads from.\n                if hasattr(self, \"_pin_memory_thread\"):\n                    # Use hasattr in case error happens before we set the attribute.\n                    self._pin_memory_thread_done_event.set()\n                    # Send something to pin_memory_thread in case it is waiting\n                    # so that it can wake up and check `pin_memory_thread_done_event`\n                    self._worker_result_queue.put((None, None))\n                    self._pin_memory_thread.join()\n                    self._worker_result_queue.cancel_join_thread()\n                    self._worker_result_queue.close()\n\n                # Exit workers now.\n                self._workers_done_event.set()\n                for worker_id in range(len(self._workers)):\n                    # Get number of workers from `len(self._workers)` instead of\n                    # `self._num_workers` in case we error before starting all\n                    # workers.\n                    # If we are using workers_status with persistent_workers\n                    # we have to shut it down because the worker is paused\n                    if self._persistent_workers or self._workers_status[worker_id]:\n                        self._mark_worker_as_unavailable(worker_id, shutdown=True)\n                for w in self._workers:\n                    # We should be able to join here, but in case anything went\n                    # wrong, we set a timeout and if the workers fail to join,\n                    # they are killed in the `finally` block.\n                    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)\n                for q in self._index_queues:\n                    q.cancel_join_thread()\n                    q.close()\n            finally:\n                # Even though all this function does is putting into queues that\n                # we have called `cancel_join_thread` on, weird things can\n                # happen when a worker is killed by a signal, e.g., hanging in\n                # `Event.set()`. 
So we need to guard this with SIGCHLD handler,\n                # and remove pids from the C side data structure only at the\n                # end.\n                #\n                # FIXME: Unfortunately, for Windows, we are missing a worker\n                #        error detection mechanism here in this function, as it\n                #        doesn't provide a SIGCHLD handler.\n                if self._worker_pids_set:\n                    _utils.signal_handling._remove_worker_pids(id(self))\n                    self._worker_pids_set = False\n                for w in self._workers:\n                    if w.is_alive():\n                        # Existing mechanisms try to make the workers exit\n                        # peacefully, but in case that we unfortunately reach\n                        # here, which we shouldn't, (e.g., pytorch/pytorch#39570),\n                        # we kill the worker.\n                        w.terminate()\n\n    def __del__(self):\n        self._shutdown_workers()\n"
  },
  {
    "path": "python/oneflow/utils/data/dataset.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport bisect\nimport functools\nfrom typing import (\n    TypeVar,\n    Generic,\n    Iterable,\n    Iterator,\n    Sequence,\n    List,\n    Optional,\n    Tuple,\n    Dict,\n    Callable,\n)\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\n\n\ndefault_generator = flow._oneflow_internal.default_generator\n\n# Taken from python 3.5 docs\ndef _accumulate(iterable, fn=lambda x, y: x + y):\n    \"Return running totals\"\n    # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15\n    # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120\n    it = iter(iterable)\n    try:\n        total = next(it)\n    except StopIteration:\n        return\n    yield total\n    for element in it:\n        total = fn(total, element)\n        yield total\n\n\nT_co = TypeVar(\"T_co\", covariant=True)\nT = TypeVar(\"T\")\n\n\nclass Dataset(Generic[T_co]):\n    r\"\"\"An abstract class representing a :class:`Dataset`.\n\n    All datasets that represent a map from keys to data samples should subclass\n    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a\n    data sample for a given key. Subclasses could also optionally overwrite\n    :meth:`__len__`, which is expected to return the size of the dataset by many\n    :class:`~flow.utils.data.Sampler` implementations and the default options\n    of :class:`~flow.utils.data.DataLoader`.\n\n    .. note::\n      :class:`~flow.utils.data.DataLoader` by default constructs a index\n      sampler that yields integral indices.  To make it work with a map-style\n      dataset with non-integral indices/keys, a custom sampler must be provided.\n    \"\"\"\n\n    def __getitem__(self, index) -> T_co:\n        raise NotImplementedError\n\n    def __add__(self, other: \"Dataset[T_co]\") -> \"ConcatDataset[T_co]\":\n        return ConcatDataset([self, other])\n\n\nclass IterableDataset(Dataset[T_co]):\n    r\"\"\"An iterable Dataset.\n\n    All datasets that represent an iterable of data samples should subclass it.\n    Such form of datasets is particularly useful when data come from a stream.\n\n    All subclasses should overwrite :meth:`__iter__`, which would return an\n    iterator of samples in this dataset.\n\n    When a subclass is used with :class:`~flow.utils.data.DataLoader`, each\n    item in the dataset will be yielded from the :class:`~flow.utils.data.DataLoader`\n    iterator. When :attr:`num_workers > 0`, each worker process will have a\n    different copy of the dataset object, so it is often desired to configure\n    each copy independently to avoid having duplicate data returned from the\n    workers.\n\n    Example 1: splitting workload across all workers in :meth:`__iter__`::\n\n        >>> class MyIterableDataset(flow.utils.data.IterableDataset):\n        ...     def __init__(self, start, end):\n        ...         super(MyIterableDataset).__init__()\n        ...         
\n\nclass IterableDataset(Dataset[T_co]):\n    r\"\"\"An iterable Dataset.\n\n    All datasets that represent an iterable of data samples should subclass it.\n    Such form of datasets is particularly useful when data come from a stream.\n\n    All subclasses should overwrite :meth:`__iter__`, which would return an\n    iterator of samples in this dataset.\n\n    When a subclass is used with :class:`~flow.utils.data.DataLoader`, each\n    item in the dataset will be yielded from the :class:`~flow.utils.data.DataLoader`\n    iterator. When :attr:`num_workers > 0`, each worker process will have a\n    different copy of the dataset object, so it is often desired to configure\n    each copy independently to avoid having duplicate data returned from the\n    workers.\n\n    Example 1: splitting workload across all workers in :meth:`__iter__`::\n\n        >>> class MyIterableDataset(flow.utils.data.IterableDataset):\n        ...     def __init__(self, start, end):\n        ...         super().__init__()\n        ...         assert end > start, \"this example code only works with end > start\"\n        ...         self.start = start\n        ...         self.end = end\n        ...\n        ...     def __iter__(self):\n        ...         iter_start = self.start\n        ...         iter_end = self.end\n        ...         return iter(range(iter_start, iter_end))\n        ...\n        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].\n        >>> ds = MyIterableDataset(start=3, end=7)\n\n        >>> # Single-process loading\n        >>> print(list(flow.utils.data.DataLoader(ds, num_workers=0)))\n        [3, 4, 5, 6]\n\n\n    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::\n\n        >>> class MyIterableDataset(flow.utils.data.IterableDataset):\n        ...     def __init__(self, start, end):\n        ...         super().__init__()\n        ...         assert end > start, \"this example code only works with end > start\"\n        ...         self.start = start\n        ...         self.end = end\n        ...\n        ...     def __iter__(self):\n        ...         return iter(range(self.start, self.end))\n        ...\n        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].\n        >>> ds = MyIterableDataset(start=3, end=7)\n\n        >>> # Single-process loading\n        >>> print(list(flow.utils.data.DataLoader(ds, num_workers=0)))\n        [3, 4, 5, 6]\n\n    \"\"\"\n    functions: Dict[str, Callable] = {}\n    reduce_ex_hook: Optional[Callable] = None\n\n    def __iter__(self) -> Iterator[T_co]:\n        raise NotImplementedError\n\n    def __add__(self, other: Dataset[T_co]):\n        return ChainDataset([self, other])\n\n    def __getattr__(self, attribute_name):\n        if attribute_name in IterableDataset.functions:\n            function = functools.partial(\n                IterableDataset.functions[attribute_name], self\n            )\n            return function\n        else:\n            raise AttributeError\n\n    @classmethod\n    def register_function(cls, function_name, function):\n        IterableDataset.functions[function_name] = function\n\n    @classmethod\n    def register_datapipe_as_function(cls, function_name, cls_to_register):\n        if function_name in IterableDataset.functions:\n            raise Exception(\n                \"Unable to add DataPipe function name {} as it is already taken\".format(\n                    function_name\n                )\n            )\n\n        def class_function(cls, source_dp, *args, **kwargs):\n            return cls(source_dp, *args, **kwargs)\n\n        function = functools.partial(class_function, cls_to_register)\n        IterableDataset.functions[function_name] = function\n\n    def __reduce_ex__(self, *args, **kwargs):\n        if IterableDataset.reduce_ex_hook is not None:\n            try:\n                return IterableDataset.reduce_ex_hook(self)\n            except NotImplementedError:\n                pass\n        return super().__reduce_ex__(*args, **kwargs)\n\n    @classmethod\n    def set_reduce_ex_hook(cls, hook_fn):\n        if IterableDataset.reduce_ex_hook is not None and hook_fn is not None:\n            raise Exception(\"Attempt to override existing reduce_ex_hook\")\n        IterableDataset.reduce_ex_hook = hook_fn\n\n\n
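# Sketch of how the registration hooks above are used (the function name\n# \"take\" is hypothetical): after\n#\n#   IterableDataset.register_function(\"take\", some_callable)\n#\n# an access like `ds.take(3)` is resolved by __getattr__, which binds the\n# registered callable to the instance via functools.partial, i.e. it ends up\n# calling some_callable(ds, 3).\n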
\n\nclass TensorDataset(Dataset[Tuple[Tensor, ...]]):\n    r\"\"\"Dataset wrapping tensors.\n\n    Each sample will be retrieved by indexing tensors along the first dimension.\n\n    Args:\n        *tensors (Tensor): tensors that have the same size in the first dimension.\n    \"\"\"\n\n    def __init__(self, *tensors: Tensor) -> None:\n        assert all(\n            tensors[0].size(0) == tensor.size(0) for tensor in tensors\n        ), \"Size mismatch between tensors\"\n        self.tensors = tensors\n\n    def __getitem__(self, index):\n        return tuple(tensor[index] for tensor in self.tensors)\n\n    def __len__(self):\n        return self.tensors[0].size(0)\n\n\nclass ConcatDataset(Dataset[T_co]):\n    r\"\"\"Dataset as a concatenation of multiple datasets.\n\n    This class is useful to assemble different existing datasets.\n\n    Args:\n        datasets (sequence): List of datasets to be concatenated\n    \"\"\"\n    datasets: List[Dataset[T_co]]\n    cumulative_sizes: List[int]\n\n    @staticmethod\n    def cumsum(sequence):\n        r, s = [], 0\n        for e in sequence:\n            l = len(e)\n            r.append(l + s)\n            s += l\n        return r\n\n    def __init__(self, datasets: Iterable[Dataset]) -> None:\n        super(ConcatDataset, self).__init__()\n        # Cannot verify that datasets is Sized\n        assert len(datasets) > 0, \"datasets should not be an empty iterable\"  # type: ignore\n        self.datasets = list(datasets)\n        for d in self.datasets:\n            assert not isinstance(\n                d, IterableDataset\n            ), \"ConcatDataset does not support IterableDataset\"\n        self.cumulative_sizes = self.cumsum(self.datasets)\n\n    def __len__(self):\n        return self.cumulative_sizes[-1]\n\n    def __getitem__(self, idx):\n        if idx < 0:\n            if -idx > len(self):\n                raise ValueError(\n                    \"absolute value of index should not exceed dataset length\"\n                )\n            idx = len(self) + idx\n        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)\n        if dataset_idx == 0:\n            sample_idx = idx\n        else:\n            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]\n        return self.datasets[dataset_idx][sample_idx]\n\n\n
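# A worked trace of the index arithmetic above (numbers are illustrative):\n# with two datasets of lengths [3, 5], cumulative_sizes is [3, 8]. A global\n# idx of 4 gives bisect.bisect_right([3, 8], 4) == 1, so sample_idx is\n# 4 - 3 = 1, i.e. the second sample of the second dataset.\n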
\n\nclass ChainDataset(IterableDataset):\n    r\"\"\"Dataset for chaining multiple :class:`IterableDataset` s.\n\n    This class is useful to assemble different existing dataset streams. The\n    chaining operation is done on-the-fly, so concatenating large-scale\n    datasets with this class will be efficient.\n\n    Args:\n        datasets (iterable of IterableDataset): datasets to be chained together\n    \"\"\"\n\n    def __init__(self, datasets: Iterable[Dataset]) -> None:\n        super(ChainDataset, self).__init__()\n        self.datasets = datasets\n\n    def __iter__(self):\n        for d in self.datasets:\n            assert isinstance(\n                d, IterableDataset\n            ), \"ChainDataset only supports IterableDataset\"\n            for x in d:\n                yield x\n\n    def __len__(self):\n        total = 0\n        for d in self.datasets:\n            assert isinstance(\n                d, IterableDataset\n            ), \"ChainDataset only supports IterableDataset\"\n            # Cannot verify that all self.datasets are Sized\n            total += len(d)\n        return total\n\n\nclass Subset(Dataset[T_co]):\n    r\"\"\"\n    Subset of a dataset at specified indices.\n\n    Args:\n        dataset (Dataset): The whole Dataset\n        indices (sequence): Indices in the whole set selected for subset\n    \"\"\"\n    dataset: Dataset[T_co]\n    indices: Sequence[int]\n\n    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:\n        self.dataset = dataset\n        self.indices = indices\n\n    def __getitem__(self, idx):\n        return self.dataset[self.indices[idx]]\n\n    def __len__(self):\n        return len(self.indices)\n\n\ndef random_split(\n    dataset: Dataset[T],\n    lengths: Sequence[int],\n    generator: Optional[object] = default_generator,\n) -> List[Subset[T]]:\n    r\"\"\"\n    Randomly split a dataset into non-overlapping new datasets of given lengths.\n    Optionally fix the generator for reproducible results, e.g.:\n\n    >>> random_split(range(10), [3, 7], generator=flow.Generator().manual_seed(42))\n\n    Args:\n        dataset (Dataset): Dataset to be split\n        lengths (sequence): lengths of splits to be produced\n        generator (Generator): Generator used for the random permutation.\n    \"\"\"\n    # Cannot verify that dataset is Sized\n    if sum(lengths) != len(dataset):  # type: ignore\n        raise ValueError(\n            \"Sum of input lengths does not equal the length of the input dataset!\"\n        )\n\n    indices = flow._C.randperm(sum(lengths), generator=generator).tolist()\n    return [\n        Subset(dataset, indices[offset - length : offset])\n        for offset, length in zip(_accumulate(lengths), lengths)\n    ]\n"
  },
  {
    "path": "python/oneflow/utils/data/decorator.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Any, Callable, Optional, Type, Union\n\nfrom oneflow.utils.data import IterDataPipe\n\n\nclass functional_datapipe(object):\n    name: str\n\n    def __init__(self, name: str) -> None:\n        self.name = name\n\n    def __call__(self, cls):\n        if isinstance(cls, Type):  # type: ignore\n            if not issubclass(cls, IterDataPipe):\n                raise TypeError(\"`functional_datapipe` can only decorate IterDataPipe\")\n        # with non_deterministic decorator\n        else:\n            if not isinstance(cls, non_deterministic) and not (\n                hasattr(cls, \"__self__\") and isinstance(cls.__self__, non_deterministic)\n            ):\n                raise TypeError(\"`functional_datapipe` can only decorate IterDataPipe\")\n        IterDataPipe.register_datapipe_as_function(self.name, cls)\n        return cls\n\n\n_determinism: bool = False\n\n\nclass guaranteed_datapipes_determinism(object):\n    prev: bool\n\n    def __init__(self) -> None:\n        global _determinism\n        self.prev = _determinism\n        _determinism = True\n\n    def __enter__(self) -> None:\n        pass\n\n    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:\n        global _determinism\n        _determinism = self.prev\n\n\nclass non_deterministic(object):\n    cls: Optional[Type[IterDataPipe]] = None\n    # TODO: Lambda for picking\n    deterministic_fn: Callable[[], bool]\n\n    def __init__(self, arg: Union[Type[IterDataPipe], Callable[[], bool]]) -> None:\n        # 1. Decorator doesn't have any argument\n        if isinstance(arg, Type):  # type: ignore\n            if not issubclass(arg, IterDataPipe):  # type: ignore\n                raise TypeError(\n                    \"Only `IterDataPipe` can be decorated with `non_deterministic`\"\n                    \", but {} is found\".format(arg.__name__)\n                )\n            self.cls = arg  # type: ignore\n        # 2. Decorator has an argument of a function\n        #    This class should behave differently given different inputs. Use this\n        #    function to verify the determinism for each instance.\n        #    When the function returns True, the instance is non-deterministic. Otherwise,\n        #    the instance is a deterministic DataPipe.\n        elif isinstance(arg, Callable):  # type:ignore\n            self.deterministic_fn = arg  # type: ignore\n        else:\n            raise TypeError(\"{} can not be decorated by non_deterministic\".format(arg))\n\n    def __call__(self, *args, **kwargs):\n        global _determinism\n        #  Decorate IterDataPipe\n        if self.cls is not None:\n            if _determinism:\n                raise TypeError(\n                    \"{} is non-deterministic, but you set 'guaranteed_datapipes_determinism'. 
\"\n                    \"You can turn off determinism for this DataPipe if that is acceptable \"\n                    \"for your application\".format(self.cls.__name__)\n                )\n            return self.cls(*args, **kwargs)  # type: ignore\n\n        # Decorate with a functional argument\n        if not (\n            isinstance(args[0], Type)\n            and issubclass(  # type: ignore\n                args[0], IterDataPipe\n            )\n        ):\n            raise TypeError(\n                \"Only `IterDataPipe` can be decorated, but {} is found\".format(\n                    args[0].__name__\n                )\n            )\n        self.cls = args[0]\n        return self.deterministic_wrapper_fn\n\n    def deterministic_wrapper_fn(self, *args, **kwargs) -> IterDataPipe:\n        res = self.deterministic_fn(*args, **kwargs)  # type: ignore\n        if not isinstance(res, bool):\n            raise TypeError(\n                \"deterministic_fn of `non_deterministic` decorator is required \"\n                \"to return a boolean value, but {} is found\".format(type(res))\n            )\n        global _determinism\n        if _determinism and res:\n            raise TypeError(\n                \"{} is non-deterministic with the inputs, but you set \"\n                \"'guaranteed_datapipes_determinism'. You can turn off determinism \"\n                \"for this DataPipe if that is acceptable for your application\".format(\n                    self.cls.__name__\n                )\n            )  # type: ignore\n        return self.cls(*args, **kwargs)  # type: ignore\n"
  },
  {
    "path": "python/oneflow/utils/data/distributed.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport math\nimport numpy as np\nfrom typing import TypeVar, Optional, Iterator\n\nimport oneflow as flow\nfrom oneflow.utils.data import Sampler, Dataset\n\n\nT_co = TypeVar(\"T_co\", covariant=True)\n\n\nclass DistributedSampler(Sampler[T_co]):\n    r\"\"\"Sampler that restricts data loading to a subset of the dataset.\n\n    It is especially useful in conjunction with\n    :class:`flow.nn.parallel.DistributedDataParallel`. In such a case, each\n    process can pass a :class:`~flow.utils.data.DistributedSampler` instance as a\n    :class:`~flow.utils.data.DataLoader` sampler, and load a subset of the\n    original dataset that is exclusive to it.\n\n    .. note::\n        Dataset is assumed to be of constant size.\n\n    Args:\n        dataset: Dataset used for sampling.\n        num_replicas (int, optional): Number of processes participating in\n            distributed training. By default, :attr:`world_size` is retrieved from the\n            current distributed group.\n        rank (int, optional): Rank of the current process within :attr:`num_replicas`.\n            By default, :attr:`rank` is retrieved from the current distributed\n            group.\n        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the\n            indices.\n        seed (int, optional): random seed used to shuffle the sampler if\n            :attr:`shuffle=True`. This number should be identical across all\n            processes in the distributed group. Default: ``0``.\n        drop_last (bool, optional): if ``True``, then the sampler will drop the\n            tail of the data to make it evenly divisible across the number of\n            replicas. If ``False``, the sampler will add extra indices to make\n            the data evenly divisible across the replicas. Default: ``False``.\n\n    .. warning::\n        In distributed mode, calling the :meth:`set_epoch` method at\n        the beginning of each epoch **before** creating the :class:`DataLoader` iterator\n        is necessary to make shuffling work properly across multiple epochs. Otherwise,\n        the same ordering will be always used.\n\n    For example:\n\n    .. code-block:: python\n\n        >>> sampler = DistributedSampler(dataset) if is_distributed else None\n        >>> loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)\n        >>> for epoch in range(start_epoch, n_epochs):\n        ...     if is_distributed:\n        ...         sampler.set_epoch(epoch)\n        ...     
\n    def __init__(\n        self,\n        dataset: Dataset,\n        num_replicas: Optional[int] = None,\n        rank: Optional[int] = None,\n        shuffle: bool = True,\n        seed: int = 0,\n        drop_last: bool = False,\n    ) -> None:\n        if num_replicas is None:\n            num_replicas = flow.env.get_world_size()\n        if rank is None:\n            rank = flow.env.get_rank()\n\n        if rank >= num_replicas or rank < 0:\n            raise ValueError(\n                \"Invalid rank {}, rank should be in the interval\"\n                \" [0, {}]\".format(rank, num_replicas - 1)\n            )\n        self.dataset = dataset\n        self.num_replicas = num_replicas\n        self.rank = rank\n        self.epoch = 0\n        self.drop_last = drop_last\n        # If the dataset length is evenly divisible by # of replicas, then there\n        # is no need to drop any data, since the dataset will be split equally.\n        if self.drop_last and len(self.dataset) % self.num_replicas != 0:\n            # Split to nearest available length that is evenly divisible.\n            # This is to ensure each rank receives the same amount of data when\n            # using this Sampler.\n            self.num_samples = math.ceil(\n                # Dataset does not provide a default __len__, so the subclass\n                # is expected to implement it.\n                (len(self.dataset) - self.num_replicas)\n                / self.num_replicas\n            )\n        else:\n            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)\n        self.total_size = self.num_samples * self.num_replicas\n        self.shuffle = shuffle\n        self.seed = seed\n\n    def __iter__(self) -> Iterator[T_co]:\n        if self.shuffle:\n            # deterministically shuffle based on epoch and seed\n            g = flow.Generator(\"cpu\")\n            g.manual_seed(self.seed + self.epoch)\n            indices = flow._C.randperm(len(self.dataset), generator=g).tolist()\n        else:\n            indices = list(range(len(self.dataset)))\n\n        if not self.drop_last:\n            # add extra samples to make it evenly divisible\n            padding_size = self.total_size - len(indices)\n            if padding_size <= len(indices):\n                indices += indices[:padding_size]\n            else:\n                indices += (indices * math.ceil(padding_size / len(indices)))[\n                    :padding_size\n                ]\n        else:\n            # remove tail of data to make it evenly divisible.\n            indices = indices[: self.total_size]\n        assert len(indices) == self.total_size\n\n        # subsample\n        indices = indices[self.rank : self.total_size : self.num_replicas]\n        assert len(indices) == self.num_samples\n\n        return iter(indices)\n\n    def __len__(self) -> int:\n        return self.num_samples\n\n    def set_epoch(self, epoch: int) -> None:\n        \"\"\"Sets the epoch for this sampler.\n        When :attr:`shuffle=True`, this ensures all replicas use a different random\n        ordering for each epoch. Otherwise, the next iteration of this sampler\n        will yield the same ordering.\n\n        Args:\n            epoch (int): Epoch number.\n        \"\"\"\n        self.epoch = epoch\n"
  },
  {
    "path": "python/oneflow/utils/data/sampler.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom typing import Iterator, Optional, Sequence, List, TypeVar, Generic, Sized\nimport numpy as np\n\nimport oneflow as flow\n\n\nT_co = TypeVar(\"T_co\", covariant=True)\n\n\nclass Sampler(Generic[T_co]):\n    r\"\"\"Base class for all Samplers.\n\n    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a\n    way to iterate over indices of dataset elements, and a :meth:`__len__` method\n    that returns the length of the returned iterators.\n\n    .. note:: The :meth:`__len__` method isn't strictly required by\n              :class:`~flow.utils.data.DataLoader`, but is expected in any\n              calculation involving the length of a :class:`~flow.utils.data.DataLoader`.\n    \"\"\"\n\n    def __init__(self, data_source: Optional[Sized]) -> None:\n        pass\n\n    def __iter__(self) -> Iterator[T_co]:\n        raise NotImplementedError\n\n    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]\n    #\n    # Many times we have an abstract class representing a collection/iterable of\n    # data, e.g., `flow.utils.data.Sampler`, with its subclasses optionally\n    # implementing a `__len__` method. In such cases, we must make sure to not\n    # provide a default implementation, because both straightforward default\n    # implementations have their issues:\n    #\n    #   + `return NotImplemented`:\n    #     Calling `len(subclass_instance)` raises:\n    #       TypeError: 'NotImplementedType' object cannot be interpreted as an integer\n    #\n    #   + `raise NotImplementedError()`:\n    #     This prevents triggering some fallback behavior. E.g., the built-in\n    #     `list(X)` tries to call `len(X)` first, and executes a different code\n    #     path if the method is not found or `NotImplemented` is returned, while\n    #     raising an `NotImplementedError` will propagate and and make the call\n    #     fail where it could have use `__iter__` to complete the call.\n    #\n    # Thus, the only two sensible things to do are\n    #\n    #   + **not** provide a default `__len__`.\n    #\n    #   + raise a `TypeError` instead, which is what Python uses when users call\n    #     a method that is not defined on an object.\n    #     (@ssnl verifies that this works on at least Python 3.7.)\n\n\nclass SequentialSampler(Sampler[int]):\n    r\"\"\"Samples elements sequentially, always in the same order.\n\n    Args:\n        data_source (Dataset): dataset to sample from\n    \"\"\"\n    data_source: Sized\n\n    def __init__(self, data_source):\n        self.data_source = data_source\n\n    def __iter__(self):\n        return iter(range(len(self.data_source)))\n\n    def __len__(self) -> int:\n        return len(self.data_source)\n\n\nclass RandomSampler(Sampler[int]):\n    r\"\"\"Samples elements randomly. 
\n\nclass SequentialSampler(Sampler[int]):\n    r\"\"\"Samples elements sequentially, always in the same order.\n\n    Args:\n        data_source (Dataset): dataset to sample from\n    \"\"\"\n    data_source: Sized\n\n    def __init__(self, data_source):\n        self.data_source = data_source\n\n    def __iter__(self):\n        return iter(range(len(self.data_source)))\n\n    def __len__(self) -> int:\n        return len(self.data_source)\n\n\nclass RandomSampler(Sampler[int]):\n    r\"\"\"Samples elements randomly. If without replacement, then sample from a shuffled dataset.\n    If with replacement, then user can specify :attr:`num_samples` to draw.\n\n    Args:\n        data_source (Dataset): dataset to sample from\n        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``\n        num_samples (int): number of samples to draw, default=`len(dataset)`. This argument\n            is supposed to be specified only when `replacement` is ``True``.\n        generator (Generator): Generator used in sampling.\n    \"\"\"\n    data_source: Sized\n    replacement: bool\n\n    def __init__(\n        self,\n        data_source: Sized,\n        replacement: bool = False,\n        num_samples: Optional[int] = None,\n        generator=None,\n    ) -> None:\n        self.data_source = data_source\n        self.replacement = replacement\n        self._num_samples = num_samples\n        self.generator = generator\n\n        if not isinstance(self.replacement, bool):\n            raise TypeError(\n                \"replacement should be a boolean value, but got \"\n                \"replacement={}\".format(self.replacement)\n            )\n\n        if self._num_samples is not None and not replacement:\n            raise ValueError(\n                \"With replacement=False, num_samples should not be specified, \"\n                \"since a random permutation will be performed.\"\n            )\n\n        if not isinstance(self.num_samples, int) or self.num_samples <= 0:\n            raise ValueError(\n                \"num_samples should be a positive integer \"\n                \"value, but got num_samples={}\".format(self.num_samples)\n            )\n\n    @property\n    def num_samples(self) -> int:\n        # dataset size might change at runtime\n        if self._num_samples is None:\n            return len(self.data_source)\n        return self._num_samples\n\n    def __iter__(self):\n        n = len(self.data_source)\n        if self.generator is None:\n            generator = flow.Generator(\"cpu\")\n            generator.manual_seed(np.random.randint(0, np.iinfo(np.int64).max))\n            # TODO: use Tensor.random_\n            # generator.manual_seed(\n            #     int(flow.empty((), dtype=flow.int64).random_().item())\n            # )\n        else:\n            generator = self.generator\n        if self.replacement:\n            for _ in range(self.num_samples // 32):\n                yield from flow._C.randint(\n                    high=n, size=(32,), dtype=flow.int64, generator=generator\n                ).numpy().tolist()\n            yield from flow._C.randint(\n                high=n,\n                size=(self.num_samples % 32,),\n                dtype=flow.int64,\n                generator=generator,\n            ).numpy().tolist()\n        else:\n            yield from flow._C.randperm(n, generator=generator).numpy().tolist()\n\n    def __len__(self):\n        return self.num_samples\n\n\nclass SubsetRandomSampler(Sampler[int]):\n    r\"\"\"Samples elements randomly from a given list of indices, without replacement.\n\n    Args:\n        indices (sequence): a sequence of indices\n        generator (Generator): Generator used in sampling.\n    \"\"\"\n    indices: Sequence[int]\n\n    def __init__(self, indices: Sequence[int], generator=None) -> None:\n        self.indices = indices\n        self.generator = generator\n\n    def __iter__(self):\n        return (\n            self.indices[i]\n            for i in flow._C.randperm(len(self.indices), 
generator=self.generator)\n        )\n\n    def __len__(self):\n        return len(self.indices)\n\n\nclass BatchSampler(Sampler[List[int]]):\n    r\"\"\"Wraps another sampler to yield a mini-batch of indices.\n\n    Args:\n        sampler (Sampler or Iterable): Base sampler. Can be any iterable object\n        batch_size (int): Size of mini-batch.\n        drop_last (bool): If ``True``, the sampler will drop the last batch if\n            its size would be less than ``batch_size``\n\n    Example:\n        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))\n        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]\n        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))\n        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]\n    \"\"\"\n\n    def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None:\n        # Since collections.abc.Iterable does not check for `__getitem__`, which\n        # is one way for an object to be an iterable, we don't do an `isinstance`\n        # check here.\n        if (\n            not isinstance(batch_size, int)\n            or isinstance(batch_size, bool)\n            or batch_size <= 0\n        ):\n            raise ValueError(\n                \"batch_size should be a positive integer value, \"\n                \"but got batch_size={}\".format(batch_size)\n            )\n        if not isinstance(drop_last, bool):\n            raise ValueError(\n                \"drop_last should be a boolean value, but got \"\n                \"drop_last={}\".format(drop_last)\n            )\n        self.sampler = sampler\n        self.batch_size = batch_size\n        self.drop_last = drop_last\n\n    def __iter__(self):\n        batch = []\n        for idx in self.sampler:\n            batch.append(idx)\n            if len(batch) == self.batch_size:\n                yield batch\n                batch = []\n        if len(batch) > 0 and not self.drop_last:\n            yield batch\n\n    def __len__(self):\n        # Can only be called if self.sampler has __len__ implemented\n        # We cannot enforce this condition, so we turn off typechecking for the\n        # implementation below.\n        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]\n        if self.drop_last:\n            return len(self.sampler) // self.batch_size  # type: ignore\n        else:\n            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore\n"
  },
  {
    "path": "python/oneflow/utils/global_view/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.utils.global_view.to_global import to_global\nfrom oneflow.utils.global_view.to_local import to_local\nfrom oneflow.utils.global_view.global_mode import global_mode, current_global_mode\n\n__all__ = [\n    \"to_global\",\n    \"to_local\",\n    \"global_mode\",\n    \"current_global_mode\",\n]\n"
  },
  {
    "path": "python/oneflow/utils/global_view/global_mode.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport oneflow._oneflow_internal.global_view as internal_global_view\n\n\nclass global_mode(internal_global_view.global_mode):\n    r\"\"\"Create a scope to provide global information for the computation process within it.\n    \n    It provides convinience for converting from local execution to global execution, especially for converting to ddp global execution.\n    \n    1) Make the source op create the global tensor directly.\n    2) Make it legal for the \"to(device)\" API  of the global tensor.\n    3) Make it legal to use \".device\" to get the device type of the global tensor.\n    \n    Note:\n        Both placement and sbp are required if the global mode is enabled.\n        \n    Args:\n        enabled (bool): whether the global mode is enbaled.\n        placement (oneflow.placement, optional): the desired placement of the input. Default: None\n        sbp (oneflow.sbp.sbp, list/tuple of oneflow.sbp.sbp, optional): the desired sbp of the input or self-defined functions in order to specify SBP. Default: None\n\n    For example:\n\n    .. code-block:: python\n\n        class LinearEvalGraphWithDDP(flow.nn.Graph):\n            def __init__(self):\n                super().__init__()\n                self.linear_dp = linear_dp\n\n            def build(self, x):\n                with global_mode(True, placement=P, sbp=B):\n                    device = self.linear_dp.weight.device\n\n                    x = x.to(device)\n\n                    out = self.linear_dp(x)\n\n                    # The local tensor will be converted to global\n                    sample = flow.randn(out.shape, device=\"cpu\").to(device)\n                    out = out + sample * 100\n                    out = out - sample * 100\n\n                return out\n         \n    .. code-block:: python       \n\n        with global_mode(False):\n            # The tensor will be keeped as local.\n            sample = flow.randn(out.shape, device=\"cpu\").to(device)\n            out = out + sample * 100\n            out = out - sample * 100\n    \"\"\"\n\n    def __init__(self, enabled, placement=None, sbp=None) -> None:\n        if not enabled:\n            super().__init__(enabled)\n        else:\n            super().__init__(enabled, placement, sbp)\n\n    def __enter__(self):\n        pass\n\n    def __exit__(self, type, value, traceback):\n        pass\n\n\nclass current_global_mode(internal_global_view.current_global_mode):\n    r\"\"\"Get the current global mode information.\n    \n    Use the current_global_mode to get the information of global mode, including enabled, placement and sbp.\n\n    Note: \n        The sbp property is supposed to return a list/tuple of `oneflow.sbp.sbp`.\n\n    For example:\n\n    .. 
class current_global_mode(internal_global_view.current_global_mode):\n    r\"\"\"Get the current global mode information.\n\n    Use the current_global_mode to get the information of global mode, including enabled, placement and sbp.\n\n    Note:\n        The sbp property is supposed to return a list/tuple of `oneflow.sbp.sbp`.\n\n    For example:\n\n    .. code-block:: python\n\n        with global_mode(True, placement=P, sbp=B):\n            # Get the global mode info.\n            cur_global_mode = global_view.current_global_mode()\n            test_case.assertTrue(cur_global_mode.is_enabled)\n            test_case.assertEqual(cur_global_mode.placement, P)\n            test_case.assertEqual(cur_global_mode.sbp[0], B)\n    \"\"\"\n\n    def __init__(self) -> None:\n        super().__init__()\n\n    @property\n    def is_enabled(self):\n        return super().is_enabled\n\n    @property\n    def sbp(self):\n        return super().sbp\n\n    @property\n    def placement(self):\n        return super().placement\n"
  },
  {
    "path": "python/oneflow/utils/global_view/global_utils.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport warnings\nimport pickle\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.framework.args_tree import ArgsTree\n\n\ndef to_global_tensor(input_tensor, placement=None, sbp=None, **kwargs):\n    # specific operation for None\n    if input_tensor is None:\n        return flow.local_to_global(\n            input=input_tensor, placement=placement, sbp=sbp, **kwargs\n        )\n\n    if input_tensor.is_global:\n        return flow.global_to_global(\n            input=input_tensor, placement=placement, sbp=sbp, **kwargs\n        )\n    else:\n        if \"grad_sbp\" in kwargs:\n            del kwargs[\"grad_sbp\"]\n        return flow.local_to_global(\n            input=input_tensor, placement=placement, sbp=sbp, **kwargs\n        )\n\n\ndef to_local_tensor(input_tensor, copy):\n    if not input_tensor.is_global:\n        warnings.warn(\"The tensor should be global, local tensor will remain the same.\")\n        return input_tensor\n    return flow._C.to_local(input_tensor, copy)\n\n\ndef check_input_global(input):\n    is_input_global = False\n    if input is not None:\n        if isinstance(input, Tensor):\n            is_input_global = input.is_global\n        elif isinstance(input, (dict, tuple, list)):\n            is_first_tensor_in_input = True\n            input_tree_for_is_global = ArgsTree(input)\n            for arg in input_tree_for_is_global.iter_nodes():\n                if isinstance(arg, Tensor):\n                    if is_first_tensor_in_input:\n                        is_input_global = arg.is_global\n                        is_first_tensor_in_input = False\n                    else:\n                        assert (\n                            arg.is_global == is_input_global\n                        ), \"Tensor(s) in the input must be all local or all global.\"\n\n    return is_input_global\n\n\ndef check_placement_on_all_ranks(placement):\n    # Determine whether the ranks of placement are same as all ranks\n    is_placement_on_all_ranks = False\n    all_ranks = flow.placement.all(\"cpu\").ranks\n    if (\n        all_ranks.shape == placement.ranks.shape\n        and (all_ranks == placement.ranks).all()\n    ):\n        is_placement_on_all_ranks = True\n\n    return is_placement_on_all_ranks\n\n\ndef src_sbp_broadcast(obj, src: int = 0):\n    rank = flow.env.get_rank()\n    if src == rank:\n        obj_bytes = pickle.dumps(obj)\n        obj_bytes = flow._oneflow_internal.cpu_broadcast(obj_bytes, src)\n    else:\n        obj_bytes = flow._oneflow_internal.cpu_broadcast(None, src)\n    return pickle.loads(obj_bytes)\n"
  },
  {
    "path": "python/oneflow/utils/global_view/to_global.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport warnings\nimport pickle\nimport types\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.framework.args_tree import ArgsTree\nfrom oneflow.utils.global_view.global_utils import (\n    to_global_tensor,\n    check_input_global,\n    check_placement_on_all_ranks,\n    src_sbp_broadcast,\n)\n\n\ndef to_global(input, placement=None, sbp=None, warn_on_non_tensor_leaf=True, **kwargs):\n    r\"\"\"Converts the input tensor or input tensor(s) in list/tuple/dict to global tensor(s).\n    \n    Note:\n        Both placement and sbp are required if the input is local, otherwise at least one of placement and sbp is required.\n\n    Args:\n        input (oneflow.Tensor/None/list/tuple/dict): the input that needs to be converted.\n        placement (oneflow.placement, optional): the desired placement of the input. Default: None\n        sbp (oneflow.sbp.sbp, list/tuple of oneflow.sbp.sbp or Callable[[Tensor], oneflow.sbp.sbp], optional): the desired sbp of the input or self-defined functions in order to specify SBP. Default: None\n        warn_on_non_tensor_leaf (bool, optional): whether to warn when the leaf is not a tensor. Default: True\n    \n    Returns:\n        The converted input.\n\n    For a tensor input: please refer to the examples in :func:`oneflow.Tensor.to_global`.\n\n    For an input of other type (take a state dict as an example):\n\n    .. code-block:: python\n\n        >>> # Run on 2 ranks respectively\n        >>> import oneflow as flow\n        >>> from oneflow import nn\n        >>> placement = flow.placement(\"cpu\", ranks=[0, 1]) # doctest: +SKIP\n        >>> sbp = (flow.sbp.broadcast,) # doctest: +SKIP\n        >>> model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2)) # doctest: +SKIP\n        >>> global_state_dict = flow.utils.global_view.to_global(model.state_dict(), placement, sbp) # doctest: +SKIP\n        >>> for val in state_dict.values(): # doctest: +SKIP\n        >>>     print(val.is_global) # doctest: +SKIP\n\n    .. code-block:: python\n\n        >>> # results on rank 0\n        True\n        True\n        True\n        True\n\n    .. code-block:: python\n\n        >>> # results on rank 1\n        True\n        True\n        True\n        True\n\n    Note:\n        For the input of dict type, such as the state dict of the model, the unified sbp cannot be used when calling the to_global method, and the sbp needs to be specialized. \n        Usually used for making graph models's state dict global.\n\n    If you want to do the `split(0)` operation, but there are tensors that cannot be split by dim 0, then these tensors can specify sbp. \n    It is worth noting that, for a tensor of shape `(1, n)`, you can specify SBP is `oneflow.sbp.split(1)`.\n    For example:\n\n    .. 
        # Define a function that returns the specified SBP for particular tensors.\n        def get_sbp(state_dict, tensor):\n            if tensor is state_dict[\"System-Train-TrainStep\"]:\n                return oneflow.sbp.broadcast\n            if tensor is state_dict[\"module_pipeline\"][\"m_stage3.linear.weight\"]:\n                return oneflow.sbp.split(1)\n            if tensor is state_dict[\"module_pipeline\"][\"m_stage3.linear.bias\"]:\n                return oneflow.sbp.broadcast\n            return oneflow.sbp.split(0)\n\n        flow.utils.global_view.to_global(state_dict, placement=placement, sbp=get_sbp)\n\n    \"\"\"\n    is_input_not_tensor_or_none = False\n    if (input is not None) and (not isinstance(input, (Tensor, dict, tuple, list))):\n        is_input_not_tensor_or_none = True\n\n    if (\n        (not is_input_not_tensor_or_none)\n        and (placement is not None)\n        and (not check_input_global(input))\n        and (not check_placement_on_all_ranks(placement))\n    ):\n        src_rank = placement.ranks.flat[0]\n        cur_rank = flow.env.get_rank()\n\n        if cur_rank == src_rank:\n            # Replace tensor(s) in the input with None, in order to reduce communication cost\n            if isinstance(input, Tensor) or input is None:\n                mapped_input_none = None\n            else:\n                input_tree_none = ArgsTree(input)\n\n                def leaf_fn_to_none(node):\n                    if isinstance(node, Tensor):\n                        # Ensure that each rank has a tensor instance, so that\n                        # identity checks in a user-defined get_sbp function do\n                        # not accidentally match on `None is None`.\n                        return flow.empty(0, 1)\n                    else:\n                        if warn_on_non_tensor_leaf:\n                            warnings.warn(\n                                \"Non-Tensor type: {} encountered, it will remain the same.\".format(\n                                    type(node)\n                                )\n                            )\n                        return node\n\n                mapped_input_none = input_tree_none.map_leaf(leaf_fn_to_none)\n\n            obj_input = pickle.dumps(mapped_input_none)\n            flow._oneflow_internal.cpu_broadcast(obj_input, src_rank)\n        else:\n            if cur_rank in placement.ranks:\n                # Participating in the broadcast process but retaining original value\n                flow._oneflow_internal.cpu_broadcast(None, src_rank)\n            else:\n                # The input of other ranks will always be overwritten no matter what is passed in\n                input = pickle.loads(\n                    flow._oneflow_internal.cpu_broadcast(None, src_rank)\n                )\n\n    if isinstance(input, (Tensor, dict, tuple, list)):\n        input_tree = ArgsTree(input)\n\n        def leaf_fn(node):\n            if isinstance(node, Tensor) or node is None:\n                if isinstance(sbp, types.FunctionType):\n                    return to_global_tensor(node, placement, sbp(input, node), **kwargs)\n                else:\n                    return to_global_tensor(node, placement, sbp, **kwargs)\n            else:\n                if warn_on_non_tensor_leaf:\n                    warnings.warn(\n                        \"Non-Tensor type: {} encountered, it will remain the same.\".format(\n                            type(node)\n                        )\n                    )\n                return node\n\n        mapped_input = 
input_tree.map_leaf(leaf_fn)\n        return mapped_input\n\n    else:\n        if warn_on_non_tensor_leaf:\n            warnings.warn(\n                \"Non-Tensor type: {} encountered, it will remain the same.\".format(\n                    type(input)\n                )\n            )\n        return input\n"
  },
  {
    "path": "python/oneflow/utils/global_view/to_local.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport warnings\n\nimport oneflow as flow\nfrom oneflow.framework.tensor import Tensor\nfrom oneflow.framework.args_tree import ArgsTree\nfrom oneflow.utils.global_view.global_utils import to_local_tensor\n\n\ndef to_local(input, *, copy=False):\n    r\"\"\"Returns the local part of the input.\n    \n    Returns:\n        The converted input.\n\n    For a tensor input: please refer to the examples in :func:`oneflow.Tensor.to_local`.\n\n    For an input of other type (take a state dict as an example):\n\n    .. code-block:: python\n\n        >>> # Run on 2 ranks respectively\n        >>> import oneflow as flow\n        >>> from oneflow import nn\n        >>> placement = flow.placement(\"cpu\", ranks=[0, 1]) # doctest: +SKIP\n        >>> sbp = (flow.sbp.broadcast,) # doctest: +SKIP\n        >>> model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2)) # doctest: +SKIP\n        >>> model = model.to_global(placement=placement, sbp=sbp) # doctest: +SKIP\n        >>> local_state_dict = flow.utils.global_view.to_local(model.state_dict()) # doctest: +SKIP\n        >>> for val in local_state_dict.values(): # doctest: +SKIP\n        >>>     print(val.is_global) # doctest: +SKIP\n\n    .. code-block:: python\n\n        >>> # results on rank 0\n        False\n        False\n        False\n        False\n\n    .. code-block:: python\n\n        >>> # results on rank 1\n        False\n        False\n        False\n        False\n    \"\"\"\n    if isinstance(input, Tensor):\n        return to_local_tensor(input, copy)\n    elif isinstance(input, (dict, tuple, list)):\n        input_tree = ArgsTree(input)\n\n        def leaf_fn(node):\n            if isinstance(node, Tensor):\n                return to_local_tensor(node, copy)\n            else:\n                warnings.warn(\n                    \"Non-Tensor type: {} encountered, it will remain the same.\".format(\n                        type(node)\n                    )\n                )\n                return node\n\n        mapped_input = input_tree.map_leaf(leaf_fn)\n        return mapped_input\n    else:\n        warnings.warn(\n            \"Non-Tensor type: {} encountered, it will remain the same.\".format(\n                type(input)\n            )\n        )\n        return input\n"
  },
  {
    "path": "python/oneflow/utils/hooks.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# This file is mostly copied from PyTorch's torch/utils/hooks.py\nimport oneflow as flow\nimport oneflow.nn.modules._functions\nfrom oneflow.framework.tensor_tuple_util import convert_to_tensor_tuple\nfrom collections import OrderedDict\nimport weakref\nimport warnings\nfrom typing import Any\n\n__all__ = [\"BackwardHook\", \"RemovableHandle\"]\n\n\nclass RemovableHandle(object):\n    \"\"\"A handle which provides the capability to remove a hook.\"\"\"\n\n    id: int\n    next_id: int = 0\n\n    def __init__(self, hooks_dict: Any) -> None:\n        self.hooks_dict_ref = weakref.ref(hooks_dict)\n        self.id = RemovableHandle.next_id\n        RemovableHandle.next_id += 1\n\n    def remove(self) -> None:\n        hooks_dict = self.hooks_dict_ref()\n        if hooks_dict is not None and self.id in hooks_dict:\n            del hooks_dict[self.id]\n\n    def __getstate__(self):\n        return (self.hooks_dict_ref(), self.id)\n\n    def __setstate__(self, state) -> None:\n        if state[0] is None:\n            # create a dead reference\n            self.hooks_dict_ref = weakref.ref(OrderedDict())\n        else:\n            self.hooks_dict_ref = weakref.ref(state[0])\n        self.id = state[1]\n        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)\n\n    def __enter__(self) -> \"RemovableHandle\":\n        return self\n\n    def __exit__(self, type: Any, value: Any, tb: Any) -> None:\n        self.remove()\n\n\nclass BackwardHook(object):\n    \"\"\"\n    A wrapper class to implement nn.Module backward hooks.\n    It handles:\n      - Ignoring non-Tensor inputs and replacing them by None before calling the user hook\n      - Generating the proper Node to capture a set of Tensor's gradients\n      - Linking the gradients captures for the outputs with the gradients captured for the input\n      - Calling the user hook once both output and input gradients are available\n    \"\"\"\n\n    def __init__(self, module, user_hooks, user_pre_hooks):\n        self.user_hooks = user_hooks\n        self.user_pre_hooks = user_pre_hooks\n        self.module = module\n\n        self.grad_outputs = None\n        self.n_outputs = -1\n        self.output_tensors_index = None\n        self.n_inputs = -1\n        self.input_tensors_index = None\n\n    def _pack_with_none(self, indices, values, size):\n        res = [None] * size\n        for idx, val in zip(indices, values):\n            res[idx] = val\n\n        return convert_to_tensor_tuple(res)\n\n    def _unpack_none(self, indices, values):\n        res = []\n        for idx in indices:\n            res.append(values[idx])\n\n        return convert_to_tensor_tuple(res)\n\n    def _set_user_hook(self, grad_fn):\n        def fn(grad_input, _):\n            # TODO(hujiakui): in pytorch, it should raise Error.\n            if self.grad_outputs is None:\n                warnings.warn(\n                    \"Module 
backward hook for grad_input is called before \"\n                    \"the grad_output one. This happens because the gradient \"\n                    \"in your nn.Module flows to the Module's input without \"\n                    \"passing through the Module's output. Make sure that the \"\n                    \"output depends on the input and that the loss is computed \"\n                    \"based on the output.\"\n                )\n                return\n\n            res = self._pack_with_none(\n                self.input_tensors_index, grad_input, self.n_inputs\n            )\n\n            for hook in self.user_hooks:\n                out = hook(self.module, res, self.grad_outputs)\n\n                if out is None:\n                    continue\n\n                if len(out) != len(res):\n                    raise RuntimeError(\n                        \"Backward hook returned an invalid number of grad_input, \"\n                        \"got {}, but expected {}\".format(len(out), len(res))\n                    )\n\n                res = out\n\n            if res is None:\n                return res\n\n            if len(res) != len(grad_input):\n                raise RuntimeError(\n                    \"Backward hook returned an invalid number of grad_input, \"\n                    \"got {}, but expected {}\".format(len(res), len(grad_input))\n                )\n            self.grad_outputs = None\n            return self._unpack_none(self.input_tensors_index, res)\n\n        grad_fn.register_hook(fn)\n\n    def _apply_on_tensors(self, fn, args):\n        # Can be used to apply the given function to the tensors contained in the\n        # args. Will return updated args and the tensors indices\n        tensors_idx = []\n        tensors = []\n\n        requires_grad = False\n        for i, arg in enumerate(args):\n            if isinstance(arg, flow.Tensor):\n                tensors_idx.append(i)\n                tensors.append(arg)\n                requires_grad |= arg.requires_grad\n\n        if not (requires_grad and flow.is_grad_enabled()):\n            return args, None\n\n        # FIXME: BackwardFunction should not return a single Tensor when the return type is tuple\n        new_tensors = flow.nn.modules._functions.BackwardHookFunction.apply(*tensors)\n        if not isinstance(new_tensors, tuple):\n            new_tensors = (new_tensors,)\n\n        if len(new_tensors) == 0:\n            raise RuntimeError(\n                \"Cannot set Module backward hook for a Module with no input Tensors.\"\n            )\n\n        grad_fns = [\n            t.grad_fn\n            for t in new_tensors\n            if t.grad_fn is not None\n            and t.grad_fn.name() == \"BackwardHookFunctionBackward\"\n        ]\n        if len(grad_fns) == 0:\n            raise RuntimeError(\n                \"Error while setting up backward hooks. 
Please open \"\n                \"an issue with a code sample to reproduce this.\"\n            )\n\n        fn(grad_fns[0])\n\n        arg_list = list(args)\n        for idx, val in zip(tensors_idx, new_tensors):\n            arg_list[idx] = val\n\n        return tuple(arg_list), tensors_idx\n\n    def setup_input_hook(self, args):\n        def fn(grad_fn):\n            self._set_user_hook(grad_fn)\n\n        res, input_idx = self._apply_on_tensors(fn, args)\n        self.n_inputs = len(args)\n        self.input_tensors_index = input_idx\n        return res\n\n    def setup_output_hook(self, args):\n        def fn(grad_fn):\n            def hook(_, grad_output):\n                self.grad_outputs = self._pack_with_none(\n                    self.output_tensors_index, grad_output, self.n_outputs\n                )\n\n                if self.user_pre_hooks:\n                    expected_len = len(self.grad_outputs)\n                    for user_pre_hook in self.user_pre_hooks:\n                        hook_grad_outputs = user_pre_hook(\n                            self.module, self.grad_outputs\n                        )\n                        if hook_grad_outputs is None:\n                            continue\n\n                        actual_len = len(hook_grad_outputs)\n                        if actual_len != expected_len:\n                            raise RuntimeError(\n                                \"Backward pre hook returned an invalid number of grad_output, \"\n                                \"got {}, but expected {}\".format(\n                                    actual_len, expected_len\n                                )\n                            )\n                        self.grad_outputs = hook_grad_outputs\n\n                # Special case if no input required gradients, this hook should call the user\n                # hook directly\n                if self.input_tensors_index is None:\n                    grad_inputs = self._pack_with_none([], [], self.n_inputs)\n                    for user_hook in self.user_hooks:\n                        res = user_hook(self.module, grad_inputs, self.grad_outputs)\n                        if res is not None and not (\n                            isinstance(res, tuple) and all(el is None for el in res)\n                        ):\n                            raise RuntimeError(\n                                \"Backward hook for Modules where no input requires \"\n                                \"gradient should always return None or None for all gradients.\"\n                            )\n                    self.grad_outputs = None\n\n            grad_fn.register_hook(hook)\n\n        is_tuple = True\n        if not isinstance(args, tuple):\n            args = (args,)\n            is_tuple = False\n\n        res, output_idx = self._apply_on_tensors(fn, args)\n\n        self.n_outputs = len(args)\n        self.output_tensors_index = output_idx\n\n        if not is_tuple:\n            res = res[0]\n        return res\n"
  },
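The `BackwardHook` plumbing above is what powers full backward hooks on modules. A minimal usage sketch, assuming the torch-compatible `register_full_backward_hook` API is exposed on `flow.nn.Module`:

```python
import oneflow as flow

# Register a full backward hook; the machinery above packs grad_input /
# grad_output into tuples (with None at positions that carried no gradient).
m = flow.nn.Linear(3, 2)

def hook(module, grad_input, grad_output):
    print("grad_input shapes:", [None if g is None else tuple(g.shape) for g in grad_input])
    print("grad_output shapes:", [tuple(g.shape) for g in grad_output])

m.register_full_backward_hook(hook)

x = flow.randn(4, 3, requires_grad=True)
m(x).sum().backward()
```

The hook sees `grad_input`/`grad_output` exactly as `_pack_with_none` builds them, which is why returning a tuple of a different length raises the "invalid number of grad_input" error above.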
  {
    "path": "python/oneflow/utils/insight/README.md",
    "content": "# OneFlow Insight\n\n## Overview\n\nOneFlow Insight is a module designed for profiling CUDA kernel execution time and bottleneck analysis. Typically, this is done using the nsys command provided by Nvidia, which generates corresponding profile files (formerly .qdrep and now .nsys-rep). These files can be visualized and analyzed using Nvidia's GUI software, Nsight Systems.\n\nIn addition to generating profile files, nsys also produces platform-independent data information recorded in a .sqlite file. The OneFlow Insight module can parse this .sqlite file to generate a JSON file formatted according to the Google Chrome Trace Event standard. This allows for direct visualization and analysis through Chrome or Edge browsers using chrome://tracing/ or edge://tracing/ (supported by trace-event-profiling-tool, see:https://www.chromium.org/developers/how-tos/trace-event-profiling-tool/).\n\n\n## Usage\n\n1. Generate profile files using the following nsys command:\n\n    ```bash\n    nsys profile --export=sqlite -o profile_data\n    ```\n\n    This will produce .nsys-rep files along with a .sqlite file.\n\n2. Use OneFlow Insight to parse the .sqlite file and generate a JSON file:\n\n    ```bash\n    python3 sqlite_to_google_trace_event.py --input 'profile_data.sqlite' -o trace.json\n    ```\n\n3. Open Chrome or Edge browser and navigate to chrome://tracing/ or edge://tracing/.\n\n4. Load the generated trace.json file for visualizing and analyzing the profiling data.\n\n## Visualization Example\n\n![OneFlow Insight Visualization](trace.json.png)\n\nThe above image demonstrates the visualization capabilities using Chrome or Edge browser with the generated JSON file.\n\nFeel free to explore and gain insights into your CUDA kernel execution performance!\n"
  },
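For reference, the trace.json written in step 2 follows the Chrome Trace Event format: `"ph": "X"` marks a complete event, `"ph": "M"` marks metadata, and `ts`/`dur` are in microseconds. A minimal hand-written file in the same shape (all names here are made up):

```python
import json

# Minimal Chrome Trace Event file; chrome://tracing accepts an object
# with a "traceEvents" list. "ts" and "dur" are in microseconds.
trace = {
    "traceEvents": [
        # "M" = metadata: label the row for pid 0
        {"ph": "M", "name": "process_name", "pid": 0, "tid": 0,
         "args": {"name": "GPU 0"}},
        # "X" = complete event: one 250 us kernel on stream (tid) 7
        {"ph": "X", "cat": "CUDA Kernel", "name": "example_kernel",
         "pid": 0, "tid": 7, "ts": 1000.0, "dur": 250.0,
         "args": {"stream id": 7}},
    ]
}

with open("trace.json", "w") as f:
    json.dump(trace, f, indent=2)
```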
  {
    "path": "python/oneflow/utils/insight/requirements.txt",
    "content": "sqlite3\nargparse\ntraceback"
  },
  {
    "path": "python/oneflow/utils/insight/sqlite_to_google_trace_event.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport json\nimport sqlite3\nimport argparse\nimport traceback\n\n\nclass DatabaseManager:\n    def __init__(self, db_file):\n        self.db_file = db_file\n        self.connection = None\n        self.cursor = None\n\n    def open_connection(self):\n        self.connection = sqlite3.connect(self.db_file)\n        self.cursor = self.connection.cursor()\n\n    def close_connection(self):\n        if self.cursor:\n            self.cursor.close()\n        if self.connection:\n            self.connection.close()\n\n    def execute_sql(self, sql):\n        try:\n            self.cursor.execute(sql)\n            self.connection.commit()\n        except sqlite3.Error as e:\n            print(f\"Execute sql '{sql}' error: {e}\")\n            traceback.print_exc()\n\n\ndef are_tables_exist(db_manager, table_names):\n    try:\n        # Query for the existence of sqlite database tables with specific names\n        results = {}\n        for table_name in table_names:\n            db_manager.execute_sql(\n                f\"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}'\"\n            )\n            result = db_manager.cursor.fetchone()\n            results[table_name] = result is not None\n        return results\n\n    except sqlite3.Error as e:\n        print(f\"are_tables_exist() SQLite error: {e}\")\n        return {}\n\n\ndef print_db_info(db_manager):\n    # execute sql\n    db_manager.execute_sql(\"SELECT name, sql FROM sqlite_master WHERE type='table';\")\n    # get results\n    tables = db_manager.cursor.fetchall()\n    # print infomation\n    for table in tables:\n        print(f\"Table Name: {table[0]}\\nCreate Table SQL: {table[1]}\\n\")\n\n\ndef get_start_time(db_manager):\n    \"\"\"\n    get session start time(timestamp) from table TARGET_INFO_SESSION_START_TIME\n    \"\"\"\n    sql = \"SELECT utcEpochNs FROM TARGET_INFO_SESSION_START_TIME LIMIT 1;\"\n    db_manager.execute_sql(sql)\n    result = db_manager.cursor.fetchone()\n    timestamp = result[0]\n    return timestamp\n\n\ndef get_process_id(db_manager):\n    \"\"\"\n    get process id from table TARGET_INFO_CUDA_NULL_STREAM\n    \"\"\"\n    sql = \"SELECT processId FROM TARGET_INFO_CUDA_NULL_STREAM LIMIT 1;\"\n    db_manager.execute_sql(sql)\n    result = db_manager.cursor.fetchone()\n    process_id = result[0]\n    return process_id\n\n\ndef get_device_property(db_manager):\n    \"\"\"\n    get device properties from TARGET_INFO_GPU\n    \"\"\"\n    sql = (\n        \"SELECT name,totalMemory,computeMajor,computeMinor,\"\n        \"maxThreadsPerBlock,maxBlocksPerSm,maxRegistersPerBlock,\"\n        \"maxRegistersPerSm,threadsPerWarp,maxShmemPerBlock,\"\n        \"maxRegistersPerSm,smCount,maxShmemPerBlockOptin \"\n        \"FROM TARGET_INFO_GPU WHERE id is 0;\"\n    )\n    db_manager.execute_sql(sql)\n    (\n        name,\n        totalGlobalMem,\n        computeMajor,\n        
computeMinor,\n        maxThreadsPerBlock,\n        maxBlocksPerSm,\n        regsPerBlock,\n        regsPerMultiprocessor,\n        warpSize,\n        sharedMemPerBlock,\n        sharedMemPerMultiprocessor,\n        numSms,\n        sharedMemPerBlockOptin,\n    ) = db_manager.cursor.fetchone()\n    maxThreadsPerMultiprocessor = maxThreadsPerBlock * maxBlocksPerSm\n\n    prop = {\n        \"id\": 0,\n        \"name\": name,\n        \"totalGlobalMem\": totalGlobalMem,\n        \"computeMajor\": computeMajor,\n        \"computeMinor\": computeMinor,\n        \"maxThreadsPerBlock\": maxThreadsPerBlock,\n        \"maxThreadsPerMultiprocessor\": maxThreadsPerMultiprocessor,\n        \"regsPerBlock\": regsPerBlock,\n        \"regsPerMultiprocessor\": regsPerMultiprocessor,\n        \"warpSize\": warpSize,\n        \"sharedMemPerBlock\": sharedMemPerBlock,\n        \"sharedMemPerMultiprocessor\": sharedMemPerMultiprocessor,\n        \"numSms\": numSms,\n        \"sharedMemPerBlockOptin\": sharedMemPerBlockOptin,\n    }\n    return prop\n\n\ndef sqlite_to_google_trace_event(args, tables):\n    try:\n        database_path = args.input\n        print(\"Opening sqlite database:\", database_path)\n        db_manager = DatabaseManager(database_path)\n        db_manager.open_connection()\n\n        # print basic database information\n        if args.info:\n            print_db_info(db_manager)\n\n        print(\"Checking if the following tables exist:\")\n        results = are_tables_exist(db_manager, tables)\n        for table_name, exists in results.items():\n            if not exists:\n                print(f\"'{table_name}' does not exist.\")\n                raise ValueError(\n                    f\"Table '{table_name}' does not exist in the database.\"\n                )\n            else:\n                print(f\"'{table_name}' exists.\")\n\n        # get some necessary information\n        session_start_time = get_start_time(db_manager)  # session start time\n        process_id = get_process_id(db_manager)  # process id\n        device_property = get_device_property(db_manager)  # properties of cuda device\n\n        deviceProperties = [device_property]\n        db_manager.execute_sql(\n            \"SELECT name,busLocation FROM TARGET_INFO_GPU WHERE id is 0;\"\n        )\n        name, bus_location = db_manager.cursor.fetchone()\n        db_manager.execute_sql(\n            \"SELECT duration, startTime, stopTime FROM ANALYSIS_DETAILS LIMIT 1;\"\n        )\n        trace_duration, trace_start_time, trace_stop_time = db_manager.cursor.fetchone()\n\n        raw_start_time = session_start_time + trace_start_time\n        start_time = round(raw_start_time / 1000)  # ns to μs\n        end_time = round((session_start_time + trace_stop_time) / 1000)  # ns to μs\n        duration = round(trace_duration / 1000)  # ns to μs\n        traceEvents_data = []\n        # construct process meta information\n        traceEvents_meta = [\n            {\n                \"name\": \"process_name\",\n                \"ph\": \"M\",\n                \"ts\": start_time,\n                \"pid\": process_id,\n                \"tid\": 0,\n                \"args\": {\"name\": \"python3\"},\n            },\n            {\n                \"name\": \"process_labels\",\n                \"ph\": \"M\",\n                \"ts\": start_time,\n                \"pid\": process_id,\n                \"tid\": 0,\n                \"args\": {\"labels\": \"CPU\"},\n            },\n            {\n                \"name\": 
\"process_sort_index\",\n                \"ph\": \"M\",\n                \"ts\": start_time,\n                \"pid\": process_id,\n                \"tid\": 0,\n                \"args\": {\"sort_index\": process_id},\n            },\n            {\n                \"name\": \"process_name\",\n                \"ph\": \"M\",\n                \"ts\": start_time,\n                \"pid\": 0,\n                \"tid\": 0,\n                \"args\": {\"name\": \"python3\"},\n            },\n            {\n                \"name\": \"process_labels\",\n                \"ph\": \"M\",\n                \"ts\": start_time,\n                \"pid\": 0,\n                \"tid\": 0,\n                \"args\": {\"labels\": f\"GPU 0(CUDA HW {bus_location} - {name})\"},\n            },\n            {\n                \"name\": \"process_sort_index\",\n                \"ph\": \"M\",\n                \"ts\": start_time,\n                \"pid\": 0,\n                \"tid\": 0,\n                \"args\": {\"sort_index\": process_id},\n            },\n            {\n                \"ph\": \"X\",\n                \"cat\": \"Trace\",\n                \"ts\": start_time,\n                \"dur\": duration,\n                \"pid\": \"Spans\",\n                \"tid\": \"OneFlow Insight\",\n                \"name\": \"OneFlow Insight (0)\",\n                \"args\": {\"Op count\": 0},\n            },\n            {\n                \"name\": \"process_sort_index\",\n                \"ph\": \"M\",\n                \"ts\": start_time,\n                \"pid\": \"Spans\",\n                \"tid\": 0,\n                \"args\": {\"sort_index\": \"Spans\"},\n            },\n            {\n                \"name\": \"Iteration Start: OneFlow Insight\",\n                \"ph\": \"i\",\n                \"s\": \"g\",\n                \"pid\": \"Traces\",\n                \"tid\": \"Trace OneFlow Insight\",\n                \"ts\": start_time,\n            },\n            {\n                \"name\": \"Record Window End\",\n                \"ph\": \"i\",\n                \"s\": \"g\",\n                \"pid\": \"\",\n                \"tid\": \"\",\n                \"ts\": end_time,\n            },\n        ]\n\n        # construct vm threads meta infomations\n        db_manager.execute_sql(\"SELECT text,globalTid FROM NVTX_EVENTS;\")\n        globalTids = []\n        for row in db_manager.cursor.fetchall():\n            text, globalTid = row\n            globalTids.append(globalTid)\n            osrt_name = {\n                \"name\": \"thread_name\",\n                \"ph\": \"M\",\n                \"ts\": start_time,\n                \"pid\": process_id,\n                \"tid\": f\"[OSRT API]{globalTid}\",\n                \"args\": {\"name\": f\"[OSRT API]{text}\"},\n            }\n            osrt_sort_index = {\n                \"name\": \"thread_sort_index\",\n                \"ph\": \"M\",\n                \"ts\": start_time,\n                \"pid\": process_id,\n                \"tid\": f\"[OSRT API]{globalTid}\",\n                \"args\": {\"sort_index\": globalTid - 1},\n            }\n            cu_api_name = {\n                \"name\": \"thread_name\",\n                \"ph\": \"M\",\n                \"ts\": start_time,\n                \"pid\": process_id,\n                \"tid\": globalTid,\n                \"args\": {\"name\": f\"[CUDA API]{text}\"},\n            }\n            cu_api_name_index = {\n                \"name\": \"thread_sort_index\",\n                \"ph\": \"M\",\n                
\"ts\": start_time,\n                \"pid\": process_id,\n                \"tid\": globalTid,\n                \"args\": {\"sort_index\": globalTid},\n            }\n            traceEvents_meta.append(osrt_name)\n            traceEvents_meta.append(osrt_sort_index)\n            traceEvents_meta.append(cu_api_name)\n            traceEvents_meta.append(cu_api_name_index)\n\n        # construct cuda stream meta infomations\n        db_manager.execute_sql(\n            \"SELECT streamId,processId FROM TARGET_INFO_CUDA_STREAM;\"\n        )\n        temp_time = start_time\n        for row in db_manager.cursor.fetchall():\n            temp_time += 187000\n            streamId, processId = row\n            thread_name = {\n                \"name\": \"thread_name\",\n                \"ph\": \"M\",\n                \"ts\": start_time,\n                \"pid\": 0,\n                \"tid\": streamId,\n                \"args\": {\"name\": f\"cuda stream {streamId}\", \"stream\": streamId,},\n            }\n            thread_sort_index = {\n                \"name\": \"thread_sort_index\",\n                \"ph\": \"M\",\n                \"ts\": start_time,\n                \"pid\": 0,\n                \"tid\": streamId,\n                \"args\": {\"sort_index\": streamId},\n            }\n            traceEvents_meta.append(thread_name)\n            traceEvents_meta.append(thread_sort_index)\n\n        # insert os runtime events\n        global_tids = \", \".join(map(str, globalTids))\n        db_manager.execute_sql(\n            f\"SELECT start,end,globalTid,nameId  FROM OSRT_API WHERE globalTid IN ({global_tids});\"\n        )\n        for row in db_manager.cursor.fetchall():\n            start, end, globalTid, nameId = row\n            db_manager.execute_sql(f\"SELECT value FROM StringIds WHERE id = {nameId};\")\n            name = db_manager.cursor.fetchone()[0]\n            ts = (raw_start_time + start) / 1000\n            dur = (end - start) / 1000\n            row_data = {\n                \"ph\": \"X\",\n                \"cat\": \"OS RUNTIME API\",\n                \"name\": name,\n                \"pid\": process_id,\n                \"tid\": f\"[OSRT API]{globalTid}\",\n                \"ts\": ts,\n                \"dur\": dur,\n                \"args\": {\"global tid\": f\"{globalTid}(serialized)\",},\n            }\n            traceEvents_data.append(row_data)\n\n        # insert cuda runtime api events\n        db_manager.execute_sql(\n            \"SELECT start,end,globalTid,correlationId,nameId  FROM CUPTI_ACTIVITY_KIND_RUNTIME;\"\n        )\n        for row in db_manager.cursor.fetchall():\n            start, end, globalTid, correlationId, nameId = row\n            db_manager.execute_sql(f\"SELECT value FROM StringIds WHERE id is {nameId};\")\n            name = db_manager.cursor.fetchone()[0]\n            short_name = name.split(\"_\", 1)[0]\n            ts = (raw_start_time + start) / 1000\n            dur = (end - start) / 1000\n            row_data = {\n                \"ph\": \"X\",\n                \"cat\": \"CUDA API\",\n                \"name\": short_name,\n                \"pid\": process_id,\n                \"tid\": globalTid,\n                \"ts\": ts,\n                \"dur\": dur,\n                \"args\": {\n                    \"name\": f\"Call to {name}\",\n                    \"begins\": f\"{start/(10**9)}s\",\n                    \"ends\": f\"{end/(10**9)}s(+{dur}ms)\",\n                    \"global tid\": f\"{globalTid}(serialized)\",\n                    
\"correlation id\": correlationId,\n                },\n            }\n            traceEvents_data.append(row_data)\n\n        # insert cuda kernel events\n        db_manager.execute_sql(\n            (\n                \"SELECT start,end,deviceId,contextId,streamId,\"\n                \"correlationId,globalPid,demangledName,shortName,\"\n                \"gridX,gridY,gridZ,blockX,blockY,blockZ,\"\n                \"staticSharedMemory,dynamicSharedMemory,localMemoryTotal \"\n                \"FROM CUPTI_ACTIVITY_KIND_KERNEL;\"\n            )\n        )\n        for row in db_manager.cursor.fetchall():\n            (\n                start,\n                end,\n                deviceId,\n                contextId,\n                streamId,\n                correlationId,\n                globalPid,\n                demangledName,\n                shortName,\n                gridX,\n                gridY,\n                gridZ,\n                blockX,\n                blockY,\n                blockZ,\n                staticSharedMemory,\n                dynamicSharedMemory,\n                localMemoryTotal,\n            ) = row\n            db_manager.execute_sql(\n                f\"SELECT value FROM StringIds WHERE id is {shortName}\"\n            )\n            short_name = db_manager.cursor.fetchone()[0]\n            db_manager.execute_sql(\n                f\"SELECT value FROM StringIds WHERE id is {demangledName}\"\n            )\n            name = db_manager.cursor.fetchone()[0]\n            ts = (raw_start_time + start) / 1000\n            dur = (end - start) / 1000\n            row_data = {\n                \"ph\": \"X\",\n                \"cat\": \"CUDA Kernel\",\n                \"name\": short_name,\n                \"pid\": 0,\n                \"tid\": streamId,\n                \"ts\": ts,\n                \"dur\": dur,\n                \"args\": {\n                    \"name\": name,\n                    \"begins\": f\"{start/(10**9)}s\",\n                    \"ends\": f\"{end/(10**9)}s(+{dur}ms)\",\n                    \"grid\": f\"<<<{gridX},{gridY},{gridZ}>>>\",\n                    \"block\": f\"<<<{blockX},{blockY},{blockZ}>>>\",\n                    \"static shared memory\": f\"{staticSharedMemory}bytes\",\n                    \"dynamic shared memory\": f\"{dynamicSharedMemory}bytes\",\n                    \"local memory total\": f\"{localMemoryTotal}bytes\",\n                    \"global pid\": f\"{globalPid}(serialized)\",\n                    \"device id\": deviceId,\n                    \"context id\": contextId,\n                    \"stream id\": streamId,\n                    \"correlation id\": correlationId,\n                },\n            }\n\n            traceEvents_data.append(row_data)\n\n        # construct trace event dict\n        traceEvents = traceEvents_data + traceEvents_meta\n        data = {\"deviceProperties\": deviceProperties, \"traceEvents\": traceEvents}\n\n        # the path to the JSON file to be written\n        json_fpath = args.output\n\n        # write dict content into a JSON file using json.dump\n        with open(json_fpath, \"w\") as json_file:\n            json.dump(data, json_file, indent=2)\n        print(f\"Successfully converted content to file: {json_fpath}\")\n\n    except BaseException as e:\n        print(f\"An exception occurred: {type(e).__name__}: {e}\")\n        traceback.print_exc()\n\n    finally:\n        # close db connection\n        db_manager.close_connection()\n\n\nif __name__ == \"__main__\":\n    parser 
= argparse.ArgumentParser(\n        description=\"Convert an Nsight Systems .sqlite export into a Google Trace Event JSON file\"\n    )\n\n    parser.add_argument(\n        \"--input\", required=True, help=\"Input Nvidia Nsight Systems .sqlite file path\"\n    )\n    parser.add_argument(\n        \"--output\",\n        \"-o\",\n        help=\"Output json file path (google trace format)\",\n        default=\"sqlite_to_google_trace_event.json\",\n    )\n    parser.add_argument(\n        \"--info\",\n        \"-v\",\n        action=\"store_true\",\n        help=\"Print schema information of the sqlite database\",\n        default=False,\n    )\n\n    args = parser.parse_args()\n    # necessary tables that must exist in the export\n    tables_to_check = [\n        \"TARGET_INFO_GPU\",\n        \"TARGET_INFO_SESSION_START_TIME\",\n        \"TARGET_INFO_CUDA_NULL_STREAM\",\n        \"ANALYSIS_DETAILS\",\n        \"NVTX_EVENTS\",\n        \"TARGET_INFO_CUDA_STREAM\",\n        \"OSRT_API\",\n        \"StringIds\",\n        \"CUPTI_ACTIVITY_KIND_RUNTIME\",\n        \"CUPTI_ACTIVITY_KIND_KERNEL\",\n    ]\n\n    # Usage:\n    # python3 sqlite_to_google_trace_event.py --input 'your_file.sqlite'\n    sqlite_to_google_trace_event(args, tables_to_check)\n"
  },
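Before running the converter, it can be worth confirming that the export actually contains the tables listed in `tables_to_check`; a stdlib-only sketch (assuming `profile_data.sqlite` exists):

```python
import sqlite3

# List the tables in an nsys .sqlite export; the converter raises a
# ValueError if any table it expects is missing.
con = sqlite3.connect("profile_data.sqlite")
try:
    rows = con.execute(
        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
    ).fetchall()
    for (name,) in rows:
        print(name)
finally:
    con.close()
```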
  {
    "path": "python/oneflow/utils/model_zoo.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n# torchvision/flowvision imports tqdm from here.\nfrom oneflow.hub import tqdm, load_state_dict_from_url as load_url  # noqa: F401\n"
  },
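Downstream packages reach these re-exports through the torch-style path. A hypothetical caller (the URL is a placeholder, and the `progress` keyword assumes the torch-style `load_state_dict_from_url` signature):

```python
# How a torchvision/flowvision-style caller reaches the re-exports above
# (hypothetical usage; the URL is a placeholder, not a real checkpoint).
from oneflow.utils.model_zoo import load_url, tqdm  # noqa: F401

state_dict = load_url(
    "https://example.com/checkpoints/model.pth",  # placeholder URL
    progress=True,
)
```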
  {
    "path": "python/oneflow/utils/tensor/__init__.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nfrom oneflow.utils.tensor.from_or_to_torch_tensor import from_torch, to_torch\n\n__all__ = [\n    \"from_torch\",\n    \"to_torch\",\n]\n"
  },
  {
    "path": "python/oneflow/utils/tensor/from_or_to_torch_tensor.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport sys\nimport oneflow as flow\nfrom oneflow._C import from_numpy as flow_from_numpy\n\n\ndef print_error_msg():\n    msg = \"\"\n    exc_info = sys.exc_info()\n    if len(exc_info) > 0:\n        msg += str(exc_info[0])\n    if len(exc_info) > 1:\n        msg += \" \" + str(exc_info[1])\n    print(msg)\n\n\ndef from_torch(torch_tensor):\n    r\"\"\"\n    from_torch(torch_tensor) -> Tensor\n\n    Create a oneflow tensor from torch tensor.\n\n    The returned tensor and torch tensor share the same memory. \n    \n    .. note::\n        This function can be used in special data processing stages, torch's some cpu ops can be used. \n\n    Args:\n        input (torch.Tensor): Input Tensor\n\n    Returns:\n        oneflow.Tensor\n\n    For example:\n\n    .. code-block:: python\n\n        import oneflow as flow\n        import torch\n\n        torch_t = torch.tensor([[1, 2, 3], [4, 5, 6]])\n        flow_t = flow.utils.tensor.from_torch(torch_t)\n    \n    This feature ``from_torch`` is at Alpha Stage.\n    \"\"\"\n    try:\n        import torch\n    except:\n        print_error_msg()\n    assert isinstance(torch_tensor, torch.Tensor)\n    return flow.from_dlpack(torch.to_dlpack(torch_tensor))\n\n\ndef to_torch(flow_tensor):\n    r\"\"\"\n    to_torch(flow_tensor) -> Tensor\n\n    Create a torch tensor from oneflow tensor.\n\n    The returned tensor and oneflow tensor share the same memory. \n    \n    .. note::\n        Currently only local tensor is supported.\n\n    Args:\n        input (oneflow.Tensor): Input Tensor\n\n    Returns:\n        torch.Tensor\n\n    For example:\n\n    .. code-block:: python\n\n        import oneflow as flow\n        import torch\n\n        flow_t = flow.tensor([[1, 2, 3], [4, 5, 6]])\n        torch_t = flow.utils.tensor.to_torch(flow_t)\n\n    This feature ``to_torch`` is at Alpha Stage.\n    \"\"\"\n    try:\n        import torch\n    except:\n        print_error_msg()\n    assert isinstance(flow_tensor, flow.Tensor)\n    if flow_tensor.is_global:\n        print(\n            \"WARNING: `to_torch` received a global tensor. A PyTorch CPU tensor which is a copy of its data will be returned.\"\n        )\n        return torch.from_numpy(flow_tensor.numpy())\n    return torch.from_dlpack(flow.to_dlpack(flow_tensor))\n"
  },
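A round-trip sketch of the zero-copy behavior documented above (requires both torch and oneflow installed):

```python
import oneflow as flow
import torch

# Round-trip through DLPack; both tensors view the same buffer, so an
# in-place update on the torch side is visible on the oneflow side.
torch_t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
flow_t = flow.utils.tensor.from_torch(torch_t)

torch_t.add_(1)                    # in-place update on the torch side
print(flow_t.numpy())              # expected to reflect the +1

back = flow.utils.tensor.to_torch(flow_t)
print(torch.equal(back, torch_t))  # expected True: same underlying storage
```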
  {
    "path": "python/setup.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nfrom __future__ import absolute_import\n\nimport argparse\nimport glob\nimport os\nimport sys\nimport numpy as np\n\nfrom setuptools import find_packages, setup\nfrom setuptools.command.install import install\nfrom setuptools.dist import Distribution\n\n\n# https://github.com/google/or-tools/issues/616\nclass InstallPlatlib(install):\n    def finalize_options(self):\n        install.finalize_options(self)\n        if self.distribution.has_ext_modules():\n            self.install_lib = self.install_platlib\n\n\nparser = argparse.ArgumentParser()\nparser.register(\"type\", \"bool\", lambda v: v.lower() == \"true\")\nparser.add_argument(\"--package_name\", type=str, default=\"oneflow\")\nargs, remain_args = parser.parse_known_args()\nsys.argv = [\"setup.py\"] + remain_args\n\n\ndef get_version():\n    import importlib.util\n\n    spec = importlib.util.spec_from_file_location(\n        \"version\", os.path.join(\"oneflow\", \"version.py\")\n    )\n    m = importlib.util.module_from_spec(spec)\n    spec.loader.exec_module(m)\n    return m.__version__\n\n\nREQUIRED_PACKAGES = [\n    f\"numpy>={np.__version__}, <2.0\",\n    \"protobuf>=3.9.2, <4.0\",\n    \"typing-extensions>=4.0.0, <5.0\",\n    \"tqdm\",\n    \"requests\",\n    \"pillow\",\n    \"rich\",\n]\n\nONEFLOW_VERSION = get_version()\nif \"cu11\" in ONEFLOW_VERSION and \"cu112\" not in ONEFLOW_VERSION:\n    REQUIRED_PACKAGES.append(\"nvidia-cudnn-cu11>=8.9,<9.0\")\n    REQUIRED_PACKAGES.append(\"nvidia-cublas-cu11\")\n    REQUIRED_PACKAGES.append(\"nvidia-nccl-cu11\")\n    REQUIRED_PACKAGES.append(\"nvidia-cusparse-cu11\")\n    REQUIRED_PACKAGES.append(\"nvidia-cufft-cu11\")\n\nif \"cu12\" in ONEFLOW_VERSION:\n    REQUIRED_PACKAGES.append(\"nvidia-cudnn-cu12>=8.9,<9.0\")\n    REQUIRED_PACKAGES.append(\"nvidia-cublas-cu12\")\n    REQUIRED_PACKAGES.append(\"nvidia-nccl-cu12\")\n    REQUIRED_PACKAGES.append(\"nvidia-cusparse-cu12\")\n    REQUIRED_PACKAGES.append(\"nvidia-cufft-cu12\")\n\n# if python version < 3.7.x, than need pip install dataclasses\nif sys.version_info.minor < 7:\n    REQUIRED_PACKAGES.append(\"dataclasses\")\n\n\nclass BinaryDistribution(Distribution):\n    def is_pure(self):\n        return False\n\n    def has_ext_modules(self):\n        return True\n\n\ninclude_files = glob.glob(\"oneflow/include/**/*\", recursive=True)\ninclude_files = [os.path.relpath(p, \"oneflow\") for p in include_files]\nassert len(include_files) > 0, os.path.abspath(\"oneflow/include\")\n\n\ndef get_oneflow_internal_so_path():\n    import importlib\n\n    suffixes = importlib.machinery.EXTENSION_SUFFIXES\n    loader = importlib.machinery.ExtensionFileLoader\n    lazy_loader = importlib.util.LazyLoader.factory(loader)\n    finder = importlib.machinery.FileFinder(\"oneflow\", (lazy_loader, suffixes))\n    spec = finder.find_spec(\"_oneflow_internal\")\n    pathname = spec.origin\n    assert os.path.isfile(pathname)\n    return 
os.path.basename(pathname)\n\n\npackage_data = {\"oneflow\": [get_oneflow_internal_so_path()] + include_files}\n\n\nsetup(\n    name=args.package_name,\n    version=get_version(),\n    url=\"https://www.oneflow.org/\",\n    install_requires=REQUIRED_PACKAGES,\n    packages=find_packages(),\n    package_dir={\"oneflow\": \"oneflow\"},\n    package_data=package_data,\n    zip_safe=False,\n    distclass=BinaryDistribution,\n    cmdclass={\"install\": InstallPlatlib},\n    entry_points={\n        \"console_scripts\": [\"oneflow-mock-torch=oneflow.mock_torch.__main__:main\"]\n    },\n)\n"
  },
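The `get_version` trick in setup.py loads `__version__` from a single file without importing the package itself (which may not be importable at build time). The same importlib pattern in isolation, with the file path as a stand-in:

```python
import importlib.util
import os

# Load an attribute from a single .py file without importing its package;
# "oneflow/version.py" is a stand-in for any file that defines __version__.
def load_version(py_file):
    spec = importlib.util.spec_from_file_location("version", py_file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.__version__

print(load_version(os.path.join("oneflow", "version.py")))
```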
  {
    "path": "tools/check_src.py",
    "content": "import os\nfrom pathlib import Path\n\n\nthis_file = os.path.dirname(os.path.abspath(__file__))\nsrc_root = os.path.join(this_file, \"..\")\nsrc_root = Path(os.path.abspath(src_root))\n\n\ndef check_unwanted_test_scripts(python_test_dir=None, allowed=None):\n    python_test_dir = os.path.abspath(python_test_dir)\n\n    allowed_full = [\n        os.path.relpath(os.path.join(python_test_dir, a), src_root) for a in allowed\n    ]\n    for (dirpath, dirnames, filenames) in os.walk(src_root):\n        if (\n            dirpath.startswith(os.path.abspath(python_test_dir) + os.sep)\n            and \"__pycache__\" not in dirpath\n        ):\n            rel_to_python_test = os.path.relpath(dirpath, python_test_dir)\n            rel_to_src_root = os.path.relpath(dirpath, src_root)\n            print(f\"checking: {rel_to_src_root}\")\n            if (\n                rel_to_python_test not in allowed\n                and rel_to_python_test != \".\"\n                and \"custom_ops\" not in rel_to_python_test\n            ):\n                if filenames == []:\n                    raise ValueError(f\"delete this directory: {rel_to_src_root}\")\n                else:\n                    filenames_full = [\n                        os.path.relpath(os.path.join(dirpath, a), src_root)\n                        for a in filenames\n                    ]\n                    raise ValueError(\n                        f\"\"\"move these files:\n    {filenames_full}\n    inside one of these directories:\n    {allowed_full},\n    and delete this directory: {rel_to_src_root}\"\"\"\n                    )\n\n\ndef check_dir_empty(path):\n    if os.path.exists(path):\n        for dirpath, dirnames, files in os.walk(path):\n            if files:\n                raise ValueError(dirpath, \"must be empty\")\n\n\noneflow_test_dir = src_root / \"python\" / \"oneflow\" / \"test\"\nsave_load_test_data_dirs = [\n    os.path.relpath(x[0], oneflow_test_dir)\n    for x in os.walk(oneflow_test_dir / \"modules\" / \"save_load_test_data\")\n]\n\nprint(save_load_test_data_dirs)\n\ncheck_unwanted_test_scripts(\n    python_test_dir=oneflow_test_dir,\n    allowed=[\n        \"custom_ops\",\n        \"dataloader\",\n        \"graph\",\n        \"models\",\n        \"modules\",\n        *save_load_test_data_dirs,\n        \"tensor\",\n        \"exceptions\",\n        \"expensive\",\n        \"ddp\",\n        \"misc\",\n        \"profiler\",\n    ],\n)\n"
  },
  {
    "path": "tools/clean_generated_api.py",
    "content": "import argparse\nimport glob\nimport os\nimport shutil\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"-root\", \"--root_path\", type=str, required=True)\nargs = parser.parse_args()\n\n\ndef main():\n    for p in glob.glob(os.path.join(args.root_path, \"oneflow/*/\")):\n        if p.endswith(\"python/\") or p.endswith(\"include/\"):\n            pass\n        else:\n            shutil.rmtree(p)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tools/create_pip_index.py",
    "content": "# python3 -m pip install oss2 beautifulsoup4 --user\nfrom bs4 import BeautifulSoup\nimport os\nimport oss2\nimport urllib\nimport urllib.parse\n\nos.environ[\"no_proxy\"] = \"*\"\npage_template = \"\"\"\n<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.01//EN\" \"http://www.w3.org/TR/html4/strict.dtd\">\n<html>\n\n<head>\n    <meta http-equiv=\"Content-Type\" content=\"text/html; charset=utf-8\">\n    <title>Directory listing for /oneflow/</title>\n</head>\n\n<body>\n    <h1>Directory listing for /oneflow/</h1>\n    <hr>\n    <ul>\n    </ul>\n    <hr>\n</body>\n\n</html>\n\"\"\"\nsoup = BeautifulSoup(page_template, \"html.parser\")\n\n\ndef url4key(endpoint, bucket, key):\n    return \"https://{}.{}/{}\".format(bucket, endpoint, urllib.parse.quote(key))\n\n\ndef append_link(soup, link):\n    li_tag = soup.new_tag(\"li\")\n    soup.body.ul.append(li_tag)\n\n    a_tag = soup.new_tag(\"a\", href=link)\n    a_tag.append(os.path.basename(link))\n    li_tag.append(a_tag)\n\n\ndef generate_index_file(endpoint, bucket, dir_key, file_path, index_keys=None):\n    ki = os.getenv(\"OSS_ACCESS_KEY_ID\")\n    ks = os.getenv(\"OSS_ACCESS_KEY_SECRET\")\n    auth = oss2.Auth(ki, ks)\n    bucket_obj = oss2.Bucket(auth, endpoint, bucket)\n    should_continue = True\n    count = 0\n    next_marker = \"\"\n    while should_continue:\n        files = bucket_obj.list_objects(dir_key + \"/\", marker=next_marker)\n        for f in files.object_list:\n            key = f.key\n            if key.endswith(\".whl\"):\n                link = url4key(endpoint, bucket, key)\n                append_link(soup, link)\n                count += 1\n        next_marker = files.next_marker\n        should_continue = next_marker != \"\"\n    print(\"count\", count)\n    assert count\n    html = soup.prettify()\n    with open(file_path, \"w+\") as f:\n        f.write(html)\n    if index_keys == None:\n        index_keys = [dir_key + \".index.html\"]\n    for index_key in index_keys:\n        bucket_obj.put_object_from_file(index_key, file_path)\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"-o\", \"--output_path\", type=str, required=False, default=\"pip_index.html\"\n    )\n    parser.add_argument(\n        \"-e\",\n        \"--endpoint\",\n        type=str,\n        required=False,\n        default=\"oss-cn-beijing.aliyuncs.com\",\n    )\n    parser.add_argument(\n        \"-b\", \"--bucket\", type=str, required=False, default=\"oneflow-public\",\n    )\n    parser.add_argument(\n        \"-d\", \"--dir_key\", type=str, required=False, default=\"nightly\",\n    )\n    parser.add_argument(\"--index_key\", action=\"append\", nargs=\"+\")\n    args = parser.parse_args()\n    assert args.dir_key[-1] != \"/\"\n    index_keys = sum(args.index_key, [])\n    generate_index_file(\n        args.endpoint,\n        args.bucket,\n        args.dir_key,\n        args.output_path,\n        index_keys=index_keys,\n    )\n"
  },
  {
    "path": "tools/flags_from_git_diff.py",
    "content": "import subprocess\n\n\ndef get_changed_files(base=None, head=None):\n    changed = subprocess.check_output(\n        f\"git diff --name-only --diff-filter=ACMRT {base} {head}\",\n        shell=True,\n        text=True,\n    )\n    changed = str(changed).splitlines()\n    return changed\n\n\ndef should_run_single_client_tests(changed=None):\n    not_single_client_files = [\n        f\n        for f in changed\n        if (\n            f.endswith(\".py\")\n            and not f.startswith(\"python/oneflow/compatible/single_client\")\n        )\n        or f.endswith(\".yml\")\n        or f.endswith(\".rst\")\n        or f.endswith(\".md\")\n        or f.endswith(\".cmake\")\n        or f.endswith(\"CMakeLists.txt\")\n    ]\n    print(\"[changed]\", changed)\n    print(\"[not_single_client_files]\", not_single_client_files)\n    return len(not_single_client_files) < len(changed)\n\n\ndef print_github_action_output(name=None, value=None):\n    print(f\"::set-output name={name}::{value}\")\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--base\", type=str, required=True)\n    parser.add_argument(\"--head\", type=str, required=True)\n    parser.add_argument(\"--need_single_client_tests\", action=\"store_true\")\n    args = parser.parse_args()\n    files = get_changed_files(base=args.base, head=args.head)\n    if should_run_single_client_tests(changed=files) or args.need_single_client_tests:\n        print_github_action_output(name=\"should_run_single_client_tests\", value=\"1\")\n"
  },
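The filter logic reads inverted at first glance: `should_run_single_client_tests` returns True only when at least one changed file falls outside the "pure Python / docs / CMake" set. A small demonstration, assuming it is run next to tools/flags_from_git_diff.py:

```python
from flags_from_git_diff import should_run_single_client_tests

# Changing only regular Python and docs does not trigger single-client tests...
print(should_run_single_client_tests(changed=[
    "python/oneflow/nn/modules/conv.py",
    "README.md",
]))  # False

# ...but touching C++ sources (or single-client Python) does.
print(should_run_single_client_tests(changed=[
    "oneflow/core/kernel/conv_kernel.cpp",
]))  # True
```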
  {
    "path": "tools/functional/generate_dispatch_stateful_ops.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport re\nimport argparse\nimport yaml\n\nfrom generator import Generator\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    \"--project_source_dir\", type=str, help=\"The project source code directory.\",\n)\nargs = parser.parse_args()\n\nlicense = \"\"\"/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \\\"License\\\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \\\"AS IS\\\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n// Generated from oneflow/api/python/functional/dispatch_stateful_ops.yaml. DO NOT EDIT!\"\"\"\n\nheader_fmt = (\n    license\n    + \"\"\"\n\n#ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_GENERATED_DISPATCH_OP_API_H_\n#define ONEFLOW_API_PYTHON_FUNCTIONAL_GENERATED_DISPATCH_OP_API_H_\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/op_expr.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#include \"oneflow/core/functional/tensor_index.h\"\n\nnamespace oneflow {{\nnamespace one {{\nnamespace functional {{\n{0}\n}}  // namespace functional\n}}  // namespace one\n}}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_FUNCTIONAL_GENERATED_DISPATCH_OP_API_H_\"\"\"\n)\n\nsource_fmt = (\n    license\n    + \"\"\"\n\n#include \"oneflow/api/python/functional/dispatch_stateful_ops.yaml.h\"\n#include \"oneflow/core/functional/function_library.h\"\n\nnamespace oneflow {{\nnamespace one {{\nnamespace functional {{\n{0}\n}}  // namespace functional\n}}  // namespace one\n}}  // namespace oneflow\n\"\"\"\n)\n\npybind_header_fmt = (\n    license\n    + \"\"\"\n\nnamespace oneflow {{\nnamespace one {{\nnamespace functional {{\n{0}\n}}  // namespace functional\n}}  // namespace one\n}}  // namespace oneflow\n\"\"\"\n)\n\npybind_source_fmt = (\n    license\n    + \"\"\"\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/api/python/exception/exception.h\"\n#include \"oneflow/api/python/functional/function_def.h\"\n#include \"oneflow/api/python/functional/python_arg.h\"\n#include \"oneflow/api/python/functional/python_arg_parser.h\"\n#include \"oneflow/api/python/functional/dispatch_stateful_ops.yaml.h\"\n#include 
\"oneflow/api/python/functional/dispatch_stateful_ops.yaml.pybind.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/extension/stack/python/stack_getter.h\"\n\nnamespace oneflow {{\nnamespace one {{\nnamespace functional {{\n{0}\n}}  // namespace functional\n}}  // namespace one\n\nnamespace functional = one::functional;\n\nONEFLOW_API_PYBIND11_MODULE(\"_C\", m) {{\n  static PyMethodDef functions[] = {{\n{1}\n    {{NULL, NULL, 0, NULL}}\n  }};\n\n  PyObject* module = m.ptr();\n  if (module) {{\n    PyModule_AddFunctions(module, functions);\n  }}\n}}\n\n}}  // namespace oneflow\n\"\"\"\n)\n\nyaml_file_path = os.path.join(\n    args.project_source_dir, \"oneflow/api/python/functional/dispatch_stateful_ops.yaml\"\n)\ngenerated_api_dir = \"oneflow/api/python/functional\"\ngenerated_pybind_dir = \"oneflow/api/python/functional\"\n\nif __name__ == \"__main__\":\n    assert os.path.isfile(yaml_file_path), (\n        \"It is not a regular file for the yaml file which is \" + yaml_file_path\n    )\n    g = Generator(yaml_file_path)\n\n    assert os.path.isdir(generated_api_dir), (\n        \"Could not locate the api generate directory which is \" + generated_api_dir\n    )\n    target_header_file = os.path.join(generated_api_dir, \"dispatch_stateful_ops.yaml.h\")\n    g.generate_cpp_header_file(header_fmt, target_header_file)\n    target_source_file = os.path.join(\n        generated_api_dir, \"dispatch_stateful_ops.yaml.cpp\"\n    )\n    g.generate_cpp_source_file(source_fmt, target_source_file)\n\n    assert os.path.isdir(generated_pybind_dir), (\n        \"Could not locate the pybind generate directory which is \"\n        + generated_pybind_dir\n    )\n    target_pybind_header_file = os.path.join(\n        generated_pybind_dir, \"dispatch_stateful_ops.yaml.pybind.h\"\n    )\n    target_pybind_source_file = os.path.join(\n        generated_pybind_dir, \"dispatch_stateful_ops.yaml.pybind.cpp\"\n    )\n    g.generate_pybind_for_python(\n        pybind_header_fmt,\n        pybind_source_fmt,\n        target_pybind_header_file,\n        target_pybind_source_file,\n    )\n"
  },
  {
    "path": "tools/functional/generate_functional_api.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport re\nimport argparse\nimport yaml\n\nfrom generator import Generator\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    \"--project_source_dir\", type=str, help=\"The project source code directory.\",\n)\nparser.add_argument(\n    \"--export_pybind\",\n    action=\"store_true\",\n    default=False,\n    help=\"Whether to export pybind related files.\",\n)\nargs = parser.parse_args()\n\nlicense = \"\"\"/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \\\"License\\\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \\\"AS IS\\\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n// Generated from oneflow/core/functional/functional_api.yaml. DO NOT EDIT!\"\"\"\n\nheader_fmt = (\n    license\n    + \"\"\"\n\n#ifndef ONEFLOW_CORE_FUNCTIONAL_GENERATED_FUNCTIONAL_API_H_\n#define ONEFLOW_CORE_FUNCTIONAL_GENERATED_FUNCTIONAL_API_H_\n\n#include \"oneflow/core/common/memory_format.pb.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/layout.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#include \"oneflow/core/functional/tensor_index.h\"\n\nnamespace oneflow {{\nnamespace one {{\nnamespace functional {{\n{0}\n}}  // namespace functional\n}}  // namespace one\n}}  // namespace oneflow\n\n#endif  // ONEFLOW_CORE_FUNCTIONAL_GENERATED_FUNCTIONAL_API_H_\"\"\"\n)\n\nsource_fmt = (\n    license\n    + \"\"\"\n\n#include \"oneflow/core/functional/functional_api.yaml.h\"\n#include \"oneflow/core/functional/function_library.h\"\n\nnamespace oneflow {{\nnamespace one {{\nnamespace functional {{\n{0}\n}}  // namespace functional\n}}  // namespace one\n}}  // namespace oneflow\n\"\"\"\n)\n\npybind_header_fmt = (\n    license\n    + \"\"\"\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n\nnamespace oneflow {{\nnamespace one {{\nnamespace functional {{\n{0}\n}}  // namespace functional\n}}  // namespace one\n}}  // namespace oneflow\n\"\"\"\n)\n\npybind_source_fmt = (\n    license\n    + \"\"\"\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/api/python/exception/exception.h\"\n#include \"oneflow/api/python/functional/common.h\"\n#include 
\"oneflow/api/python/functional/function_def.h\"\n#include \"oneflow/api/python/functional/python_arg.h\"\n#include \"oneflow/api/python/functional/python_arg_parser.h\"\n#include \"oneflow/api/python/functional/python_return_types.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/functional/functional.h\"\n#include \"oneflow/extension/stack/python/stack_getter.h\"\n\nnamespace {{\n// This return type template code is referenced from:\n// https://github.com/pytorch/pytorch/blob/master/tools/autograd/gen_python_functions.py\nusing oneflow::one::functional::returned_structseq_repr;\n{2}\n\nstd::unordered_map<std::string, PyTypeObject*>& get_namedtuple_types_map() {{\n  static std::unordered_map<std::string, PyTypeObject*> namedtuple_types_map = {{\n{3}\n  }};\n  return namedtuple_types_map;\n}}\n\nPyTypeObject* get_namedtuple(const std::string& name) {{\n  static auto& namedtuple_types_map = get_namedtuple_types_map();\n  return namedtuple_types_map[name];\n}}\n\n}} // namespace\n\n\nnamespace oneflow {{\nnamespace one {{\nnamespace functional {{\n\nPyObject* WrapTensorTuple(const TensorTuple& tensortuple,\n                           const std::string& name) {{\n  PyObjectPtr r(PyStructSequence_New(get_namedtuple(name)));\n  if (!r) {{ throw py::error_already_set(); }}\n  for (int i = 0; i < tensortuple.size(); ++i) {{\n    PyTuple_SET_ITEM(r.get(), i, CastToPyObject(tensortuple[i]));\n  }}\n  return r.release();\n}}\n{0}\n}}  // namespace functional\n}}  // namespace one\n\nnamespace functional = one::functional;\n\nONEFLOW_API_PYBIND11_MODULE(\"_C\", m) {{\n  static PyMethodDef functions[] = {{\n{1}\n    {{NULL, NULL, 0, NULL}}\n  }};\n\n  PyObject* module = m.ptr();\n  if (module) {{\n    PyModule_AddFunctions(module, functions);\n  }}\n}}\n\n}}  // namespace oneflow\n\"\"\"\n)\n\nyaml_file_path = os.path.join(\n    args.project_source_dir, \"oneflow/core/functional/functional_api.yaml\"\n)\ngenerated_api_dir = \"oneflow/core/functional\"\ngenerated_pybind_dir = \"oneflow/api/python/functional\"\n\nif __name__ == \"__main__\":\n    assert os.path.isfile(yaml_file_path), (\n        \"It is not a regular file for the yaml file which is \" + yaml_file_path\n    )\n    g = Generator(yaml_file_path)\n\n    assert os.path.isdir(generated_api_dir), (\n        \"Could not locate the api generate directory which is \" + generated_api_dir\n    )\n    target_header_file = os.path.join(generated_api_dir, \"functional_api.yaml.h\")\n    g.generate_cpp_header_file(header_fmt, target_header_file)\n    target_source_file = os.path.join(generated_api_dir, \"functional_api.yaml.cpp\")\n    g.generate_cpp_source_file(source_fmt, target_source_file)\n    if args.export_pybind:\n        assert os.path.isdir(generated_pybind_dir), (\n            \"Could not locate the pybind generate directory which is \"\n            + generated_pybind_dir\n        )\n        target_pybind_header_file = os.path.join(\n            generated_pybind_dir, \"functional_api.yaml.pybind.h\"\n        )\n        target_pybind_source_file = os.path.join(\n            generated_pybind_dir, \"functional_api.yaml.pybind.cpp\"\n        )\n        g.generate_pybind_for_python(\n            pybind_header_fmt,\n            pybind_source_fmt,\n            target_pybind_header_file,\n            target_pybind_source_file,\n        )\n"
  },
  {
    "path": "tools/functional/generate_tensor_api.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\nimport os\nimport re\nimport argparse\nimport yaml\n\nfrom generator import Generator\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    \"--project_source_dir\", type=str, help=\"The project source code directory.\",\n)\nargs = parser.parse_args()\n\nlicense = \"\"\"/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \\\"License\\\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \\\"AS IS\\\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n// Generated from oneflow/api/python/functional/tensor_api.yaml. DO NOT EDIT!\"\"\"\n\nheader_fmt = (\n    license\n    + \"\"\"\n\n#ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_GENERATED_TENSOR_API_H_\n#define ONEFLOW_API_PYTHON_FUNCTIONAL_GENERATED_TENSOR_API_H_\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/core/common/scalar.h\"\n#include \"oneflow/core/framework/dtype.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/tensor.h\"\n#include \"oneflow/core/framework/tensor_tuple.h\"\n#include \"oneflow/core/framework/random_generator.h\"\n#include \"oneflow/core/functional/tensor_index.h\"\n\nnamespace oneflow {{\nnamespace one {{\nnamespace functional {{\n{0}\n}}  // namespace functional\n}}  // namespace one\n}}  // namespace oneflow\n\n#endif  // ONEFLOW_API_PYTHON_FUNCTIONAL_GENERATED_TENSOR_API_H_\"\"\"\n)\n\nsource_fmt = (\n    license\n    + \"\"\"\n\n#include \"oneflow/api/python/functional/tensor_api.yaml.h\"\n#include \"oneflow/core/functional/function_library.h\"\n\nnamespace oneflow {{\nnamespace one {{\nnamespace functional {{\n{0}\n}}  // namespace functional\n}}  // namespace one\n}}  // namespace oneflow\n\"\"\"\n)\n\npybind_header_fmt = (\n    license\n    + \"\"\"\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n\nnamespace oneflow {{\nnamespace one {{\nnamespace functional {{\n{0}\n}}  // namespace functional\n}}  // namespace one\n}}  // namespace oneflow\n\"\"\"\n)\n\npybind_source_fmt = (\n    license\n    + \"\"\"\n\n#include <Python.h>\n#undef _PyGC_FINALIZED\n\n#include \"oneflow/api/python/of_api_registry.h\"\n#include \"oneflow/api/python/exception/exception.h\"\n#include \"oneflow/api/python/functional/common.h\"\n#include \"oneflow/api/python/functional/function_def.h\"\n#include \"oneflow/api/python/functional/python_arg.h\"\n#include \"oneflow/api/python/functional/python_arg_parser.h\"\n#include \"oneflow/api/python/functional/tensor_api.yaml.h\"\n#include 
\"oneflow/api/python/functional/tensor_api.yaml.pybind.h\"\n#include \"oneflow/core/common/maybe.h\"\n#include \"oneflow/core/common/optional.h\"\n#include \"oneflow/extension/stack/python/stack_getter.h\"\n\nnamespace oneflow {{\nnamespace one {{\nnamespace functional {{\n{0}\n}}  // namespace functional\n}}  // namespace one\n\nnamespace functional = one::functional;\n\nONEFLOW_API_PYBIND11_MODULE(\"_C\", m) {{\n  static PyMethodDef functions[] = {{\n{1}\n    {{NULL, NULL, 0, NULL}}\n  }};\n\n  PyObject* module = m.ptr();\n  if (module) {{\n    PyModule_AddFunctions(module, functions);\n  }}\n}}\n\n}}  // namespace oneflow\n\"\"\"\n)\n\nyaml_file_path = os.path.join(\n    args.project_source_dir, \"oneflow/api/python/functional/tensor_api.yaml\"\n)\ngenerated_api_dir = \"oneflow/api/python/functional\"\ngenerated_pybind_dir = \"oneflow/api/python/functional\"\n\nif __name__ == \"__main__\":\n    assert os.path.isfile(yaml_file_path), (\n        \"It is not a regular file for the yaml file which is \" + yaml_file_path\n    )\n    g = Generator(yaml_file_path)\n\n    assert os.path.isdir(generated_api_dir), (\n        \"Could not locate the api generate directory which is \" + generated_api_dir\n    )\n    target_header_file = os.path.join(generated_api_dir, \"tensor_api.yaml.h\")\n    g.generate_cpp_header_file(header_fmt, target_header_file)\n    target_source_file = os.path.join(generated_api_dir, \"tensor_api.yaml.cpp\")\n    g.generate_cpp_source_file(source_fmt, target_source_file)\n\n    assert os.path.isdir(generated_pybind_dir), (\n        \"Could not locate the pybind generate directory which is \"\n        + generated_pybind_dir\n    )\n    target_pybind_header_file = os.path.join(\n        generated_pybind_dir, \"tensor_api.yaml.pybind.h\"\n    )\n    target_pybind_source_file = os.path.join(\n        generated_pybind_dir, \"tensor_api.yaml.pybind.cpp\"\n    )\n    g.generate_pybind_for_python(\n        pybind_header_fmt,\n        pybind_source_fmt,\n        target_pybind_header_file,\n        target_pybind_source_file,\n    )\n"
  },
  {
    "path": "tools/functional/generator.py",
    "content": "\"\"\"\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport os\nimport re\nimport argparse\nimport yaml\n\ntypes_allowed = {\n    \"Void\",\n    \"Tensor\",\n    \"TensorTuple\",\n    \"Scalar\",\n    \"Int\",\n    \"Int32\",\n    \"Int64\",\n    \"Float\",\n    \"Double\",\n    \"String\",\n    \"Bool\",\n    \"ScalarList\",\n    \"IntList\",\n    \"Int32List\",\n    \"Int64List\",\n    \"FloatList\",\n    \"DoubleList\",\n    \"StringList\",\n    \"BoolList\",\n    \"DataType\",\n    \"Shape\",\n    \"Generator\",\n    \"TensorIndex\",\n    \"Device\",\n    \"Placement\",\n    \"Sbp\",\n    \"SbpList\",\n    \"OpExpr\",\n    \"PyObject*\",\n    \"ShapeList\",\n    \"DataTypeList\",\n    \"Layout\",\n    \"MemoryFormat\",\n}\n\nmangled_name = {\n    \"Void\": \"V\",\n    \"Tensor\": \"T\",\n    \"TensorTuple\": \"Tt\",\n    \"Scalar\": \"Sc\",\n    \"Int\": \"I\",\n    \"Int32\": \"I32\",\n    \"Int64\": \"I64\",\n    \"Float\": \"F\",\n    \"Double\": \"D\",\n    \"String\": \"S\",\n    \"Bool\": \"B\",\n    \"ScalarList\": \"Scl\",\n    \"IntList\": \"Il\",\n    \"Int32List\": \"I32l\",\n    \"Int64List\": \"I64l\",\n    \"FloatList\": \"Fl\",\n    \"DoubleList\": \"Dl\",\n    \"StringList\": \"Sl\",\n    \"BoolList\": \"Bl\",\n    \"DataType\": \"Dt\",\n    \"Shape\": \"Sh\",\n    \"Generator\": \"G\",\n    \"TensorIndex\": \"Ti\",\n    \"Device\": \"De\",\n    \"Placement\": \"P\",\n    \"Sbp\": \"Sbp\",\n    \"SbpList\": \"Sbpl\",\n    \"OpExpr\": \"Op\",\n    \"PyObject*\": \"Pyo\",\n    \"ShapeList\": \"Shl\",\n    \"DataTypeList\": \"Dtl\",\n    \"Layout\": \"Lo\",\n    \"MemoryFormat\": \"Memf\",\n}\n\ngeneric_type_aliases = {\n    \"Int\": \"int32_t\",\n    \"Int32\": \"int32_t\",\n    \"Int64\": \"int64_t\",\n    \"Float\": \"float\",\n    \"Double\": \"double\",\n    \"Bool\": \"bool\",\n}\n\nargument_type_aliases = {\n    \"Tensor\": \"const std::shared_ptr<one::Tensor>&\",\n    \"TensorTuple\": \"const TensorTuple&\",\n    \"Scalar\": \"const Scalar&\",\n    \"ScalarList\": \"const std::vector<Scalar>&\",\n    \"IntList\": \"const std::vector<int32_t>&\",\n    \"Int32List\": \"const std::vector<int32_t>&\",\n    \"Int64List\": \"const std::vector<int64_t>&\",\n    \"FloatList\": \"const std::vector<float>&\",\n    \"DoubleList\": \"const std::vector<double>&\",\n    \"String\": \"const std::string&\",\n    \"StringList\": \"const std::vector<std::string>&\",\n    \"BoolList\": \"const std::vector<bool>&\",\n    \"DataType\": \"const Symbol<DType>&\",\n    \"Shape\": \"const Shape&\",\n    \"Generator\": \"const std::shared_ptr<one::Generator>&\",\n    \"TensorIndex\": \"const TensorIndex&\",\n    \"Device\": \"const Symbol<Device>&\",\n    \"Placement\": \"const Symbol<ParallelDesc>&\",\n    \"Sbp\": \"const Symbol<SbpParallel>&\",\n    \"SbpList\": \"const std::vector<Symbol<SbpParallel>>&\",\n    \"OpExpr\": \"const std::shared_ptr<one::OpExpr>&\",\n    \"PyObject*\": \"PyObject*\",\n    \"ShapeList\": 
\"const std::vector<Shape>&\",\n    \"DataTypeList\": \"const std::vector<Symbol<DType>>&\",\n    \"Layout\": \"const Symbol<Layout>&\",\n    \"MemoryFormat\": \"MemoryFormat\",\n    **generic_type_aliases,\n}\n\noptional_argument_type_aliases = {\n    \"Tensor\": \"const Optional<one::Tensor>&\",\n    \"TensorTuple\": \"const Optional<TensorTuple>&\",\n    \"Scalar\": \"const Optional<Scalar>&\",\n    \"ScalarList\": \"const Optional<std::vector<Scalar>>&\",\n    \"IntList\": \"const Optional<std::vector<int32_t>>&\",\n    \"Int32List\": \"const Optional<std::vector<int32_t>>&\",\n    \"Int64List\": \"const Optional<std::vector<int64_t>>&\",\n    \"FloatList\": \"const Optional<std::vector<float>>&\",\n    \"DoubleList\": \"const Optional<std::vector<double>>&\",\n    \"String\": \"const Optional<std::string>&\",\n    \"StringList\": \"const Optional<std::vector<std::string>>&\",\n    \"BoolList\": \"const Optional<std::vector<bool>>&\",\n    \"DataType\": \"const Optional<Symbol<DType>>&\",\n    \"Shape\": \"const Optional<Shape>&\",\n    \"Generator\": \"const Optional<one::Generator>&\",\n    \"TensorIndex\": \"const Optional<TensorIndex>&\",\n    \"Device\": \"const Optional<Symbol<Device>>&\",\n    \"Placement\": \"const Optional<Symbol<ParallelDesc>>&\",\n    \"Sbp\": \"const Optional<Symbol<SbpParallel>>&\",\n    \"SbpList\": \"const Optional<std::vector<Symbol<SbpParallel>>>&\",\n    \"OpExpr\": \"const Optional<one::OpExpr>&\",\n    \"PyObject*\": \"const Optional<PyObject*>&\",\n    \"ShapeList\": \"const Optional<std::vector<Shape>>&\",\n    \"DataTypeList\": \"const Optional<std::vector<Symbol<DType>>>&\",\n    \"Layout\": \"const Optional<Symbol<Layout>>&\",\n    \"MemoryFormat\": \"const Optional<MemoryFormat>&\",\n    **{k: \"const Optional<{0}>&\".format(v) for k, v in generic_type_aliases.items()},\n}\n\nreturn_type_aliases = {\n    \"Void\": \"Maybe<void>\",\n    \"Tensor\": \"Maybe<one::Tensor>\",\n    \"TensorTuple\": \"Maybe<one::TensorTuple>\",\n    \"String\": \"Maybe<std::string>\",\n    \"Shape\": \"Maybe<Shape>\",\n    **{k: \"Maybe<{0}>\".format(v) for k, v in generic_type_aliases.items()},\n}\n\nvalue_aliases = {\n    \"True\": \"true\",\n    \"False\": \"false\",\n    \"kInt\": \"DType::Int32()\",\n    \"kInt8\": \"DType::Int8()\",\n    \"kUInt8\": \"DType::UInt8()\",\n    \"kInt32\": \"DType::Int32()\",\n    \"kInt64\": \"DType::Int64()\",\n    \"kFloat\": \"DType::Float()\",\n    \"kDouble\": \"DType::Double()\",\n    \"kBool\": \"DType::Bool()\",\n    \"kStrided\": \"Layout::Strided()\",\n}\n\n\ndef _escape_quote(fmt):\n    return re.sub(r\"\\\"|\\'\", '\\\\\"', fmt)\n\n\ndef _normalize(fmt):\n    fmt = fmt.strip()\n    return re.sub(r\"\\s+\", \" \", fmt)\n\n\ndef _remove_square_brackets_and_content_inside(fmt):\n    # \"TensorTuple[values], TensorTuple[indices]\" -> \"TensorTuple, TensorTuple\"\n    return re.sub(r\"\\[[^()]*?\\]\", \"\", fmt)\n\n\ndef _std_decay(fmt):\n    fmt = fmt.strip()\n    fmt = re.sub(r\"(const|&)\", \"\", fmt)\n    return _normalize(fmt)\n\n\ndef parse_function_params(fmt):\n    params = []\n    fmt = _normalize(fmt)\n    open_paren = fmt.find(\"(\")\n    if open_paren == -1:\n        raise ValueError('Missing \"(\" in function def: ' + fmt)\n\n    header = _normalize(fmt[0:open_paren])\n    items = _normalize(_remove_square_brackets_and_content_inside(header)).split(\" \")\n    if (len(items)) != 1:\n        raise ValueError(\n            \"Missing return type or more than 1 return type in function def: \" + fmt\n        )\n\n   
    params.append(header)\n\n    close_paren = fmt.rfind(\")\")\n    if close_paren == -1:\n        raise ValueError('Missing \")\" in function def: ' + fmt)\n\n    tail = fmt[open_paren + 1 : close_paren]\n    # TODO(): Parse the parameter list more comprehensively.\n    items = tail.split(\",\")\n    for param in items:\n        params.append(_normalize(param))\n\n    pos = fmt.rfind(\"=>\")\n    if pos == -1:\n        raise ValueError('Missing \"=>\" in function def: ' + fmt)\n    function_name = _normalize(fmt[pos + 2 :])\n    return function_name, params\n\n\ndef render_file_if_different(target_file, content):\n    if not os.path.isfile(target_file):\n        with open(target_file, \"w\") as f:\n            f.write(content)\n    else:\n        old_content = None\n        with open(target_file, \"r\") as f:\n            old_content = f.read()\n        if old_content is None or old_content != content:\n            with open(target_file, \"w\") as f:\n                f.write(content)\n\n\ndef generate_return_types_named_tuple(return_names, func_name, block_name):\n    param_names = \", \".join(\n        [\n            '{{const_cast<char*>(\"{}\"), const_cast<char*>(\"\")}}'.format(x)\n            for x in return_names\n        ]\n    )\n    code = f\"\"\"PyTypeObject* Get{func_name}NamedTuple() {{\n  static PyStructSequence_Field NamedTuple_fields[] = {{ {param_names},  {{nullptr}} }};\n  static PyTypeObject {func_name}NamedTuple;\n  static bool is_initialized = false;\n  static PyStructSequence_Desc desc = {{ const_cast<char*>(\"oneflow.return_types.{block_name}\"), nullptr, NamedTuple_fields, {len(return_names)} }};\n  if (!is_initialized) {{\n      PyStructSequence_InitType(&{func_name}NamedTuple, &desc);\n      {func_name}NamedTuple.tp_repr = (reprfunc)returned_structseq_repr;\n      is_initialized = true;\n  }}\n  return &{func_name}NamedTuple;\n}}\n\"\"\"\n    return code\n\n\nclass Argument:\n    def __init__(self, fmt, keyword_only=False):\n        self._keyword_only = keyword_only\n        self._type = None\n        self._name = None\n        self._default_value = None\n        self._size = 0\n\n        fmt = _normalize(fmt)\n        sp = fmt.rfind(\" \")\n        if sp == -1:\n            raise ValueError(\"Missing argument type or name for argument def: \" + fmt)\n        type_name = fmt[0:sp]\n        arg_name = fmt[sp + 1 :]\n        sp = type_name.find(\"[\")\n        if sp != -1:\n            self._type = _normalize(type_name[0:sp])\n            size = type_name[sp + 1 :]\n            sp = size.find(\"]\")\n            assert sp != -1, \"Missing ']' for argument def: \" + fmt\n            size = _normalize(size[0:sp])\n            assert size.isnumeric(), (\n                \"list size is not an integer for argument def: \" + fmt\n            )\n            self._size = int(size)\n        else:\n            self._type = _normalize(type_name)\n        assert self._type in types_allowed, \"Unknown type: \" + self._type\n\n        self._optional = False\n        self._name = _normalize(arg_name)\n        sp = self._name.find(\"=\")\n        if sp != -1:\n            self._default_value = _normalize(self._name[sp + 1 :])\n            if self._default_value == \"None\":\n                self._optional = True\n                self._default_cpp_value = \"\"\n            elif self._type.endswith(\"List\"):\n                _value_list = [\n                    self._default_value for _ in range(self._size)\n                ]  # For Int32List[2] = 2, _value_list will be [\"2\", \"2\"]\n                self._default_cpp_value = (\n                    \"{\" + \", \".join(_value_list) + \"}\"\n                )  # [\"2\", \"2\"] -> \"{2, 2}\"\n            elif self._default_value in value_aliases:\n                self._default_cpp_value = value_aliases[self._default_value]\n            else:\n                self._default_cpp_value = self._default_value\n            self._name = _normalize(self._name[0:sp])\n\n        if not self._optional and self._type in argument_type_aliases:\n            self._cpp_type = argument_type_aliases[self._type]\n        elif self._optional and self._type in optional_argument_type_aliases:\n            self._cpp_type = optional_argument_type_aliases[self._type]\n        else:\n            self._cpp_type = self._type\n\n    @property\n    def has_default_value(self):\n        return self._default_value is not None\n\n    def to_string(self, to_cpp=False):\n        fmt = \"{0} {1}\".format(self._cpp_type if to_cpp else self._type, self._name)\n        if not to_cpp and self.has_default_value:\n            fmt += \"={0}\".format(self._default_value)\n        return fmt\n\n\nclass Return:\n    def __init__(self, fmt):\n        self._type, self._return_names = self.check_named_tuple(_normalize(fmt))\n        assert self._type in types_allowed, \"Unknown type: \" + self._type\n\n        if self._type in return_type_aliases:\n            self._cpp_type = return_type_aliases[self._type]\n        else:\n            self._cpp_type = self._type\n\n    @property\n    def type(self):\n        return self._type\n\n    def to_string(self, to_cpp=False):\n        return self._cpp_type if to_cpp else self._type\n\n    def check_named_tuple(self, fmt):\n        # \"TensorTuple[values, indices]\" -> (\"TensorTuple\", [\"values\", \"indices\"])\n        matches = re.match(r\"(.*?)\\s*\\[(.*?)\\]\", fmt)\n        if matches is None:\n            ret_type, return_names = _normalize(fmt), None\n        else:\n            ret_type = matches.group(1)\n            return_names = [_normalize(x) for x in matches.group(2).split(\",\")]\n        return ret_type, return_names\n\n\nclass FunctionSignature:\n    def __init__(self, fmt):\n        self._fmt = fmt\n        self._name, self._params = parse_function_params(fmt)\n        self._ret = Return(self._params[0])\n        keyword_only = False\n        self._args = []\n        self._max_positional_args_count = 0\n        for arg in self._params[1:]:\n            if arg == \"*\":\n                keyword_only = True\n                continue\n            self._args.append(Argument(arg, keyword_only=keyword_only))\n            if not keyword_only:\n                self._max_positional_args_count += 1\n\n        self._max_args_count = len(self._args)\n        count = 0\n        for arg in self._args:\n            if arg._keyword_only:\n                count += 1\n        self._max_keyword_args_count = count\n\n    @property\n    def num_of_args(self):\n        return len(self._args)\n\n    def to_string(self, to_cpp=False, drop_name=False):\n        if drop_name:\n            fmt = \"{0} (\".format(self._ret.to_string(to_cpp=to_cpp))\n        else:\n            fmt = \"{0} {1}(\".format(self._ret.to_string(to_cpp=to_cpp), self._name)\n        keyword_start = False\n        for i, arg in enumerate(self._args):\n            if i > 0:\n                fmt += \", \"\n            if not keyword_start and arg._keyword_only:\n                keyword_start = True\n                if not to_cpp:\n                    fmt += 
\"*, \"\n            fmt += arg.to_string(to_cpp=to_cpp)\n        fmt += \")\"\n        return fmt\n\n    def get_mangled_type(self):\n        fmt = mangled_name[self._ret._type]\n        for _, arg in enumerate(self._args):\n            fmt += mangled_name[arg._type]\n        return fmt\n\n    def get_schema_name(self):\n        return \"{0}Schema_{1}\".format(self._name, self.get_mangled_type())\n\n\nclass Block:\n    def __init__(self, name, signature, bind_python):\n        self._name = name\n        self._signature = signature\n        self._bind_python = bind_python\n\n\nclass Generator:\n    def __init__(self, input_file):\n        self._blocks = {}\n        with open(input_file) as f:\n            doc = yaml.load(f, Loader=yaml.FullLoader)\n            for block in doc:\n                assert \"name\" in block\n                assert \"signature\" in block\n                name = block[\"name\"]\n                signature = block[\"signature\"]\n                bind_python = False\n                if \"bind_python\" in block:\n                    bind_python = block[\"bind_python\"]\n                self._blocks[name] = list()\n                if isinstance(signature, list):\n                    for s in signature:\n                        self._blocks[name].append(\n                            Block(name, FunctionSignature(s), bind_python)\n                        )\n                else:\n                    self._blocks[name].append(\n                        Block(name, FunctionSignature(signature), bind_python)\n                    )\n\n    def generate_cpp_header_file(self, header_fmt, target_header_file):\n        fmt = \"\"\n        for name, blocks in self._blocks.items():\n            for block in blocks:\n                fmt += \"\\n\"\n                fmt += block._signature.to_string(to_cpp=True)\n                fmt += \";\\n\"\n\n        render_file_if_different(target_header_file, header_fmt.format(fmt))\n\n    def generate_cpp_source_file(self, source_fmt, target_source_file):\n        fmt = \"\"\n        for name, blocks in self._blocks.items():\n            for block in blocks:\n                signature = block._signature\n                fmt += \"\\n\"\n                fmt += signature.to_string(to_cpp=True)\n                fmt += \" {\\n\"\n                fmt += '  static thread_local const auto& __op = CHECK_JUST(FunctionLibrary::Global()->find<{0}, {1}>(\"{2}\"));\\n'.format(\n                    signature._ret._cpp_type,\n                    \", \".join([arg._cpp_type for arg in signature._args]),\n                    signature._name,\n                )\n                fmt += \"  return __op->call({0});\\n\".format(\n                    \", \".join([arg._name for arg in signature._args]),\n                )\n                fmt += \"}\\n\"\n\n        render_file_if_different(target_source_file, source_fmt.format(fmt))\n\n    def generate_pybind_for_python(\n        self,\n        pybind_header_fmt,\n        pybind_source_fmt,\n        target_pybind_header_file,\n        target_pybind_source_file,\n    ):\n        schema_fmt = \"\"\n        module_fmt = \"\"\n        header_fmt = \"\"\n\n        return_type_fmt = \"\"\n        map_pairs = []\n        for name, blocks in self._blocks.items():\n            schema_types = []\n            max_args_count = 0\n            for block in blocks:\n                if not block._bind_python:\n                    continue\n                signature = block._signature\n                max_args_count = 
max(max_args_count, signature._max_args_count)\n                schema_types.append(\n                    \"functional::{0}\".format(signature.get_schema_name())\n                )\n                return_type = signature._ret._cpp_type\n                schema_fmt += \"\\n\"\n                schema_fmt += \"struct {0} {{\\n\".format(signature.get_schema_name())\n                schema_fmt += \"  using FType = {0};\\n\".format(\n                    signature.to_string(to_cpp=True, drop_name=True)\n                )\n                schema_fmt += \"  using R = {0};\\n\".format(return_type)\n                schema_fmt += \"\\n\"\n                schema_fmt += \"  static constexpr FType* func = &functional::{0};\\n\".format(\n                    signature._name\n                )\n                schema_fmt += \"  static constexpr size_t max_args = {0};\\n\".format(\n                    signature._max_args_count\n                )\n                schema_fmt += \"  static constexpr size_t max_pos_args = {0};\\n\".format(\n                    signature._max_positional_args_count\n                )\n                schema_fmt += '  static constexpr char const* signature = \"{0}\";\\n'.format(\n                    _escape_quote(signature.to_string(drop_name=True))\n                )\n                schema_fmt += \"  static FunctionDef function_def;\\n\"\n                schema_fmt += \"};\\n\"\n                schema_fmt += \"\\n\"\n                schema_fmt += \"constexpr size_t {0}::max_args;\\n\".format(\n                    signature.get_schema_name()\n                )\n                schema_fmt += \"constexpr size_t {0}::max_pos_args;\\n\".format(\n                    signature.get_schema_name()\n                )\n                schema_fmt += \"constexpr char const* {0}::signature;\\n\".format(\n                    signature.get_schema_name()\n                )\n                return_def = \"ReturnDef(ValueTypeOf<{0}>())\".format(return_type)\n                argument_def = []\n                for arg in signature._args:\n                    keyword_only = \"true\" if arg._keyword_only else \"false\"\n                    optional = \"true\" if arg._optional else \"false\"\n                    if arg.has_default_value:\n                        argument_def.append(\n                            '  ArgumentDef(/*name*/\"{0}\", /*default_value*/{1}({2}), /*size*/{3}, /*keyword_only*/{4}, /*optional*/{5})'.format(\n                                arg._name,\n                                _std_decay(arg._cpp_type),\n                                arg._default_cpp_value,\n                                arg._size,\n                                keyword_only,\n                                optional,\n                            )\n                        )\n                    else:\n                        argument_def.append(\n                            '  ArgumentDef(/*name*/\"{0}\", /*value_type*/ValueTypeOf<{1}>(), /*size*/{2}, /*keyword_only*/{3}, /*optional*/{4})'.format(\n                                arg._name,\n                                _std_decay(arg._cpp_type),\n                                arg._size,\n                                keyword_only,\n                                optional,\n                            )\n                        )\n                schema_fmt += 'FunctionDef {0}::function_def = {{\\n/*name*/\"{1}\",\\n/*return_def*/{2},\\n/*argument_def*/{{\\n{3}\\n}}\\n}};\\n'.format(\n                    signature.get_schema_name(),\n             
       name,\n                    return_def,\n                    \",\\n\".join(argument_def),\n                )\n\n            if len(schema_types) > 0:\n                module_fmt += '    {{\"{0}\", (PyCFunction)functional::{1}, METH_VARARGS | METH_KEYWORDS, NULL}},\\n'.format(\n                    name, name\n                )\n\n                header_fmt += \"\\n\"\n                header_fmt += \"PyObject* {0}(PyObject* self, PyObject* args, PyObject* kwargs);\\n\".format(\n                    name\n                )\n                schema_fmt += \"\\n\"\n                schema_fmt += \"PyObject* {0}(PyObject* self, PyObject* args, PyObject* kwargs) {{\\n\".format(\n                    name\n                )\n                schema_fmt += \"  HANDLE_ERRORS\\n\"\n                schema_fmt += '  OF_PROFILER_RANGE_GUARD(\"{0}\");\\n'.format(name)\n                schema_fmt += \"  PythonFrameGuard pf;\\n\"\n                schema_fmt += '  static PythonArgParser<{0}> parser(\"{1}\");\\n'.format(\n                    \", \".join(schema_types), name\n                )\n                schema_fmt += \"  ParsedArgs<{0}> r;\\n\".format(max_args_count)\n                schema_fmt += \"  int idx = parser.Parse(args, kwargs, &r);\\n\"\n                i = 0\n                for block in blocks:\n                    signature = block._signature\n                    schema_fmt += \"  if (idx == {0}) {{\\n\".format(i)\n                    params = []\n                    for j in range(len(signature._args)):\n                        cpp_type = _std_decay(signature._args[j]._cpp_type)\n                        params.append(\"r[{0}].As<{1}>()\".format(j, cpp_type))\n                    if signature._ret._return_names is None:\n                        schema_fmt += \"    return CastToPyObject(functional::{0}({1}));\\n\".format(\n                            signature._name, \", \".join(params)\n                        )\n                    else:\n                        schema_fmt += '    return WrapTensorTuple(functional::{0}({1}).GetOrThrow(), \"{2}\");\\n'.format(\n                            signature._name, \", \".join(params), signature._name,\n                        )\n                        return_type_fmt += generate_return_types_named_tuple(\n                            signature._ret._return_names, signature._name, block._name,\n                        )\n                        map_pairs.append(\n                            f'    {{\"{signature._name}\", Get{signature._name}NamedTuple()}},'\n                        )\n                    schema_fmt += \"  }\\n\"\n                    i += 1\n                schema_fmt += \"  Py_RETURN_NONE;\\n\"\n                schema_fmt += \"  END_HANDLE_ERRORS\\n\"\n                schema_fmt += \"}\\n\"\n\n        render_file_if_different(\n            target_pybind_header_file, pybind_header_fmt.format(header_fmt)\n        )\n        render_file_if_different(\n            target_pybind_source_file,\n            pybind_source_fmt.format(\n                schema_fmt, module_fmt, return_type_fmt, \"\\n\".join(map_pairs)\n            ),\n        )\n"
  },
  {
    "path": "tools/generate_header_list.py",
    "content": "import glob\nimport argparse\nimport os\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"-i\", \"--src_path\", type=str, required=True)\nparser.add_argument(\"-o\", \"--dst_file\", type=str, required=True)\nargs = parser.parse_args()\n\n\ndef glob_by_pattern(pattern):\n    result = []\n    for x in glob.glob(os.path.join(args.src_path, pattern), recursive=True):\n        result.append(os.path.relpath(x, args.src_path))\n    return result\n\n\nheaders = (\n    glob_by_pattern(\"**/*.h\")\n    + glob_by_pattern(\"**/*.hpp\")\n    + glob_by_pattern(\"**/*.cuh\")\n    + glob_by_pattern(\"**/*.proto\")\n    + glob_by_pattern(\"**/*.inc\")\n)\nwith open(args.dst_file, \"w\") as f:\n    for item in headers:\n        f.write(\"{}\\n\".format(item))\n"
  },
  {
    "path": "tools/generate_pip_version.py",
    "content": "import os\nimport subprocess\nimport argparse\nfrom datetime import date\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--cuda\", type=str, required=False)\nparser.add_argument(\"--cmake_project_binary_dir\", type=str, required=False)\nparser.add_argument(\"--src\", type=str, required=False)\nparser.add_argument(\"--out\", type=str, required=False)\nargs = parser.parse_args()\n\nlocal_label = \"\"\nversion = f\"1.0.0\"\n\n# set version if release of nightly\nassert (\n    os.getenv(\"ONEFLOW_RELEASE_VERSION\") != \"\"\n), \"ONEFLOW_RELEASE_VERSION should be either None or a valid string\"\nis_release = False\nis_nightly = False\n\ndate_str = os.getenv(\"ONEFLOW_NIGHTLY_DATE\")\n\nif os.getenv(\"ONEFLOW_RELEASE_VERSION\"):\n    release_version = os.getenv(\"ONEFLOW_RELEASE_VERSION\")\n    version = f\"{release_version}\"\n    is_release = True\nelif date_str:\n    version += f\".dev{date_str}\"\n    is_nightly = True\n\n# append compute_platform\ncompute_platform = \"\"\nif args.cuda:\n    # TODO: use a proper semver lib to handle versions\n    splits = args.cuda.split(\".\")[0:2]\n    assert len(splits) == 2\n    compute_platform = \"\".join(splits)\n    compute_platform = \"cu\" + compute_platform\nelse:\n    compute_platform = \"cpu\"\nassert compute_platform\nversion += f\"+{compute_platform}\"\n\ntry:\n    git_hash = (\n        subprocess.check_output(\"git rev-parse --short HEAD\", shell=True, cwd=args.src)\n        .decode()\n        .strip()\n    )\nexcept:\n    git_hash = \"unknown\"\n\n# append git if not release\nif not os.getenv(\"ONEFLOW_RELEASE_VERSION\") and not os.getenv(\"ONEFLOW_NIGHTLY_DATE\"):\n    version += f\".git.{git_hash}\"\n\n\nprint(f\"-- Generating pip version: {version}, writing to: {args.out}\")\nassert args.out\nwith open(args.out, \"w+\") as f:\n    f.write(f'__version__ = \"{version}\"\\n')\n    f.write(f'__git_commit__ = \"{git_hash}\"\\n')\n    if not (is_nightly or is_release):\n        f.write(f'__cmake_project_binary_dir__ = \"{args.cmake_project_binary_dir}\"\\n')\n"
  },
  {
    "path": "tools/oneflow-tblgen/CMakeLists.txt",
    "content": "set(LLVM_LINK_COMPONENTS Support)\ninclude(FetchContent)\n\nset(JSON_Install ON CACHE STRING \"\" FORCE)\nFetchContent_Declare(json URL ${JSON_URL} URL_HASH MD5=${JSON_URL_HASH})\n\nset(INJA_USE_EMBEDDED_JSON OFF CACHE STRING \"\" FORCE)\nset(INJA_BUILD_TESTS OFF CACHE STRING \"\" FORCE)\nset(BUILD_BENCHMARK OFF CACHE STRING \"\" FORCE)\nFetchContent_Declare(inja URL ${INJA_URL} URL_HASH MD5=${INJA_URL_HASH})\n\nFetchContent_MakeAvailable(json inja)\n\nadd_tablegen(oneflow_tblgen llvm tablegen.cpp op_schema_emitter.cpp)\n\nif(LLVM_ENABLE_OBJLIB)\n  set(OF_TBLGEN_TARGET obj.oneflow_tblgen)\nelse()\n  set(OF_TBLGEN_TARGET oneflow_tblgen)\nendif()\n\ntarget_link_libraries(${OF_TBLGEN_TARGET} PRIVATE nlohmann_json::nlohmann_json pantor::inja)\n\ninstall(TARGETS oneflow_tblgen LLVMTableGen LLVMDemangle LLVMSupport COMPONENT OneFlowTableGen\n        LIBRARY DESTINATION lib)\nadd_custom_target(\n  install-oneflow-tblgen DEPENDS oneflow_tblgen\n  COMMAND \"${CMAKE_COMMAND}\" -DCMAKE_INSTALL_COMPONENT=OneFlowTableGen -P\n          \"${CMAKE_BINARY_DIR}/cmake_install.cmake\")\n"
  },
  {
    "path": "tools/oneflow-tblgen/backends.h",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#ifndef ONEFLOW_TBLGEN_BACKENDS_H\n#define ONEFLOW_TBLGEN_BACKENDS_H\n\nnamespace llvm {\nclass raw_ostream;\nclass RecordKeeper;\n}  // namespace llvm\n\nnamespace oneflow {\n\nnamespace tblgen {\n\nusing llvm::raw_ostream;\nusing llvm::RecordKeeper;\n\nvoid EmitOpSchemaHeader(RecordKeeper& RK, raw_ostream& OS);\nvoid EmitOpSchemaSource(RecordKeeper& RK, raw_ostream& OS);\n\n}  // namespace tblgen\n\n}  // namespace oneflow\n\n#endif  // ONEFLOW_TBLGEN_BACKENDS_H\n"
  },
  {
    "path": "tools/oneflow-tblgen/example/constant.td",
    "content": "include \"mlir/Interfaces/SideEffectInterfaces.td\"\ninclude \"OneFlowEnums.td\"\ninclude \"OneFlowBase.td\"\n\ndef OneFlow_ConstantOp : OneFlow_BaseOp<\"constant\", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {\n  let output = (outs\n    AnyType:$out\n  );\n  let attrs = (ins\n    DefaultValuedAttr<F64Attr, \"0.\">:$floating_value,\n    DefaultValuedAttr<SI64Attr, \"0\">:$integer_value,\n    DefaultValuedAttr<BoolAttr, \"false\">:$is_floating_value,\n    StrAttr:$dtype,\n    AnyI64ElementsAttr:$shape,\n    StrArrayAttr:$nd_sbp\n  );\n}\n"
  },
  {
    "path": "tools/oneflow-tblgen/op_schema_emitter.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n#include \"llvm/ADT/ArrayRef.h\"\n#include \"llvm/ADT/DenseMap.h\"\n#include \"llvm/ADT/StringExtras.h\"\n#include \"llvm/Support/CommandLine.h\"\n#include \"llvm/Support/Format.h\"\n#include \"llvm/Support/MemoryBuffer.h\"\n#include \"llvm/Support/SourceMgr.h\"\n#include \"llvm/Support/raw_ostream.h\"\n#include \"llvm/TableGen/Error.h\"\n#include \"llvm/TableGen/Record.h\"\n#include \"llvm/TableGen/TableGenBackend.h\"\n#include \"inja/inja.hpp\"\n\n#include <iomanip>\n#include <string>\n\nusing namespace llvm;\nusing inja::json;\n\nnamespace oneflow {\nnamespace tblgen {\n\ncl::OptionCategory opSchemaCat(\"Options for -gen-op-schema\");\n\ncl::opt<std::string> sourceIncludeFilename{\n    \"op-include\", cl::desc(\"header filename to include in source file\"),\n    cl::value_desc(\"include filename\"), cl::init(\"\"), cl::cat(opSchemaCat)};\n\ncl::opt<std::string> dumpJson{\"op-dump-json\",\n                              cl::desc(\"dump tablegen code to json in provided file\"),\n                              cl::value_desc(\"filename\"), cl::init(\"\"), cl::cat(opSchemaCat)};\n\nenum class FileTarget {\n  kHeader = 1,\n  kSource,\n};\n\ntemplate<FileTarget Target>\nclass OpSchemaEmitter {\n public:\n  explicit OpSchemaEmitter(RecordKeeper& RK);\n\n  void run(raw_ostream& os);\n\n  void emitInputAndOutput(const Record* def, json* op) const;\n\n  void emitAttrs(const Record* def, json* op) const;\n\n  void emitInt(const Record* def, StringRef fieldname, json* op) const;\n  void emitBit(const Record* def, StringRef fieldname, json* op) const;\n  void emitTrait(const Record* def, StringRef fieldname, StringRef traitname, json* op) const;\n\n private:\n  static std::string emitType(const std::string& ods_type) {\n#define OP_SCHEMA(ods, cpp) \\\n  if (ods_type == #ods) return #cpp;\n#include \"op_schema_types.inc\"\n#undef OP_SCHEMA\n    PrintFatalError(\"undefined attribute type: \" + ods_type);\n  }\n\n private:\n  RecordKeeper& records;\n\n  StringRef op_type_name;\n  StringRef op_name;\n\n  inja::Environment env;\n  inja::Template temp;\n  static const std::string code;\n};\n\ntemplate<FileTarget Target>\nOpSchemaEmitter<Target>::OpSchemaEmitter(RecordKeeper& RK) : records(RK) {\n  env.add_callback(\"quoted\", 1, [](inja::Arguments& args) {\n    auto str = args.at(0)->get<std::string>();\n    std::ostringstream os;\n    os << std::quoted(str);\n    return os.str();\n  });\n  env.add_callback(\"to_header\", 1, [](inja::Arguments& args) {\n    auto str = args.at(0)->get<std::string>();\n    auto dot_pos = str.find_last_of('.');\n    if (dot_pos != std::string::npos) { str.replace(dot_pos, str.size() - dot_pos, \".h\"); }\n\n    // assume that the source and header file is in the same directory\n    auto slash_pos = str.find_last_of('/');\n    if (slash_pos != std::string::npos) { str.replace(0, slash_pos + 1, \"\"); }\n    return str;\n  });\n  temp = 
env.parse(code);\n}\n\ntemplate<FileTarget Target>\nvoid OpSchemaEmitter<Target>::run(raw_ostream& os) {\n  emitSourceFileHeader(\"oneflow op schema\", os);\n  json ops = json::object();\n\n  for (const auto& def : records.getAllDerivedDefinitions(\"OneFlow_BaseOp\")) {\n    op_type_name = def->getValueAsString(\"opName\");\n    if (op_type_name.empty()) {\n      PrintFatalError(def, \"`opName` of op definitions cannot be omitted\");\n    }\n    op_name = def->getName();\n    if (!op_name.consume_front(\"OneFlow_\")) {\n      PrintFatalError(def, \"op name does not start with `OneFlow_`: \" + op_name.str());\n    }\n    json op{{\"name\", op_type_name},\n            {\"input\", json::array()},\n            {\"output\", json::array()},\n            {\"attrs\", json::array()}};\n\n    emitInputAndOutput(def, &op);\n    emitAttrs(def, &op);\n    emitInt(def, \"same_output_regst_num\", &op);\n    emitTrait(def, \"no_grad\", \"NoGrad\", &op);\n    emitTrait(def, \"support_non_contiguous\", \"SupportNonContiguous\", &op);\n    emitTrait(def, \"cpu_only\", \"CpuOnly\", &op);\n    emitBit(def, \"has_nd_sbp_infer_fn\", &op);\n    emitBit(def, \"has_get_sbp_fn\", &op);\n    emitBit(def, \"has_logical_tensor_desc_infer_fn\", &op);\n    emitBit(def, \"has_physical_tensor_desc_infer_fn\", &op);\n    emitBit(def, \"has_data_type_infer_fn\", &op);\n    emitBit(def, \"has_device_and_stream_infer_fn\", &op);\n    emitBit(def, \"has_input_arg_modify_fn\", &op);\n    emitBit(def, \"has_output_arg_modify_fn\", &op);\n    emitBit(def, \"has_output_blob_time_shape_infer_fn\", &op);\n    emitBit(def, \"has_sbp_signature_infer_fn\", &op);\n    emitBit(def, \"has_get_nd_sbp_fn\", &op);\n    emitBit(def, \"has_enumerate_nd_sbp_signatures_fn\", &op);\n    emitBit(def, \"has_dump_nd_sbp_signature_for_op_conf_fn\", &op);\n    emitBit(def, \"has_compute_complexity_fn\", &op);\n    emitBit(def, \"has_check_fn\", &op);\n    ops[op_name.str()] = op;\n  }\n\n  auto* option = static_cast<cl::opt<std::string>*>(cl::getRegisteredOptions().lookup(\"o\"));\n  auto filename = option->getValue();\n  filename = filename != \"-\" ? 
filename : \"\";\n  json data{{\"filename\", filename}, {\"ops\", ops}};\n\n  if (Target == FileTarget::kSource) { data[\"include\"] = sourceIncludeFilename.getValue(); }\n  if (!dumpJson.empty()) {\n    std::ofstream file(dumpJson);\n    file << data.dump();\n  }\n  os << env.render(temp, data);\n}\n\ntemplate<FileTarget Target>\nvoid OpSchemaEmitter<Target>::emitInputAndOutput(const Record* def, json* op) const {\n  const auto* input = def->getValueAsDag(\"input\");\n  for (size_t i = 0; i < input->getNumArgs(); ++i) {\n    const auto* A = dyn_cast<DefInit>(input->getArg(i))->getDef();\n    bool is_optional = A->isSubClassOf(\"Optional\");\n    auto NS = input->getArgName(i)->getAsUnquotedString();\n    (*op)[\"input\"].push_back({{\"name\", NS}, {\"is_optional\", is_optional}, {\"size\", 1}});\n  }\n  const auto* output = def->getValueAsDag(\"output\");\n  for (size_t i = 0; i < output->getNumArgs(); ++i) {\n    const auto* A = dyn_cast<DefInit>(output->getArg(i))->getDef();\n    bool is_optional = A->isSubClassOf(\"Optional\");\n    auto NS = output->getArgName(i)->getAsUnquotedString();\n    (*op)[\"output\"].push_back({{\"name\", NS}, {\"is_optional\", is_optional}, {\"size\", 1}});\n  }\n}\n\ntemplate<FileTarget Target>\nvoid OpSchemaEmitter<Target>::emitAttrs(const Record* def, json* op) const {\n  const auto* attrs = def->getValueAsDag(\"attrs\");\n  for (size_t i = 0; i < attrs->getNumArgs(); ++i) {\n    const auto* A = dyn_cast<DefInit>(attrs->getArg(i))->getDef();\n    std::string AS;\n    if (!A->isAnonymous()) {\n      AS = A->getNameInitAsString();\n    } else {\n      AS = A->getValueAsDef(\"baseAttr\")->getNameInitAsString();\n    }\n    auto NS = attrs->getArgName(i)->getAsUnquotedString();\n    // FlatSymbolRefAttr:$callee,\n    if (\"callee\" == NS && \"FlatSymbolRefAttr\" == AS) { continue; }\n    json attr{{\"name\", NS}, {\"type\", emitType(AS)}};\n\n    if (auto DV = A->getValueAsOptionalString(\"defaultValue\")) { attr[\"default\"] = DV.value(); }\n\n    (*op)[\"attrs\"].push_back(attr);\n  }\n}\n\ntemplate<FileTarget Target>\nvoid OpSchemaEmitter<Target>::emitBit(const Record* def, StringRef fieldname, json* op) const {\n  (*op)[fieldname.str()] = def->getValueAsBit(fieldname);\n}\n\ntemplate<FileTarget Target>\nvoid OpSchemaEmitter<Target>::emitTrait(const Record* def, StringRef fieldname, StringRef traitname,\n                                        json* op) const {\n  bool hasTrait = false;\n\n  for (auto elem : *def->getValueAsListInit(\"traits\")) {\n    if (elem->getAsString() == traitname) {\n      hasTrait = true;\n      break;\n    }\n  }\n\n  (*op)[fieldname.str()] = hasTrait;\n}\n\ntemplate<FileTarget Target>\nvoid OpSchemaEmitter<Target>::emitInt(const Record* def, StringRef fieldname, json* op) const {\n  (*op)[fieldname.str()] = def->getValueAsInt(fieldname);\n}\n\ntemplate<>\nconst std::string OpSchemaEmitter<FileTarget::kHeader>::code{\n#include \"op_schema_header.inc\"\n};\n\ntemplate<>\nconst std::string OpSchemaEmitter<FileTarget::kSource>::code{\n#include \"op_schema_source.inc\"\n};\n\nvoid EmitOpSchemaHeader(RecordKeeper& RK, raw_ostream& os) {\n  OpSchemaEmitter<FileTarget::kHeader>(RK).run(os);\n}\n\nvoid EmitOpSchemaSource(RecordKeeper& RK, raw_ostream& os) {\n  OpSchemaEmitter<FileTarget::kSource>(RK).run(os);\n}\n\n}  // namespace tblgen\n}  // namespace oneflow\n"
  },
  {
    "path": "tools/oneflow-tblgen/op_schema_header.inc",
    "content": "R\"OP_SCHEMA_INC(\n#include \"oneflow/core/common/data_type.h\"\n#include \"oneflow/core/common/shape.h\"\n#include \"oneflow/core/common/symbol.h\"\n#include \"oneflow/core/operator/op_conf.pb.h\"\n#include \"oneflow/core/job/sbp_parallel.pb.h\"\n#include \"oneflow/core/framework/op_definition.h\"\n\n#include <string>\n#include <vector>\n#include <functional>\n#include <complex>\n\nclass OperatorConf;\nclass NdSbpSignature;\n\nnamespace oneflow {\n\nclass Device;\nclass Stream;\nclass InputBlobModifier;\nclass OutputBlobModifier;\n\nnamespace user_op {\nclass UserOpDefWrapper;\nclass UserOpConfWrapper;\nclass InferContext;\nclass SbpContext;\nclass InferSbpSignatureFnContext;\nclass InferOutputBlobTimeShapeFnContext;\nclass InferNdSbpFnContext;\nclass DeviceAndStreamInferContext;\nclass ComputeComplexityFnContext;\nclass GetNdSbpSignatureListContext;\n}  // namespace user_op\n\nusing GetInputArgModifier =\n    std::function<InputBlobModifier*(const std::string& in_arg_name, int32_t in_arg_index)>;\nusing GetOutputArgModifier =\n    std::function<OutputBlobModifier*(const std::string& out_arg_name, int32_t out_arg_index)>;\n\n{% for opname, op in ops %}\nclass {{opname}} : public OpDefinition<{{opname}}> {\n public:\n  virtual ~{{opname}}() = default;\n  {% if op.has_nd_sbp_infer_fn -%}\n  static Maybe<void> InferNdSbp(user_op::InferNdSbpFnContext* ctx);\n  {% endif -%}\n  {% if op.has_get_sbp_fn -%}\n  static Maybe<void> GetSbp(user_op::SbpContext* ctx);\n  {% endif -%}\n  {% if op.has_get_nd_sbp_fn -%}\n  static Maybe<void> GetNdSbpSignatureList(user_op::GetNdSbpSignatureListContext* ctx);\n  {% endif -%}\n  {% if op.has_enumerate_nd_sbp_signatures_fn -%}\n  static Maybe<void> EnumerateNdSbpSignatures(user_op::GetNdSbpSignatureListContext* ctx);\n  {% endif -%}\n  {% if op.has_dump_nd_sbp_signature_for_op_conf_fn -%}\n  static Maybe<void> DumpNdSbpSignatureForOpConfFn(const NdSbpSignature& nd_sbp_sig, OperatorConf* op_conf);\n  {% endif -%}\n  {% if op.has_logical_tensor_desc_infer_fn -%}\n  static Maybe<void> InferLogicalTensorDesc(user_op::InferContext* ctx);\n  {% endif -%}\n  {% if op.has_physical_tensor_desc_infer_fn -%}\n  static Maybe<void> InferPhysicalTensorDesc(user_op::InferContext* ctx);\n  {% endif -%}\n  {% if op.has_data_type_infer_fn -%}\n  static Maybe<void> InferDataType(user_op::InferContext* ctx);\n  {% endif -%}\n  {% if op.has_device_and_stream_infer_fn -%}\n  static Maybe<Symbol<Stream>> InferDeviceAndStream(user_op::DeviceAndStreamInferContext* ctx);\n  {% endif -%}\n  {% if op.has_sbp_signature_infer_fn -%}\n  static Maybe<void> InferSbpSignature(user_op::InferSbpSignatureFnContext* ctx);\n  {% endif -%}\n  {% if op.has_compute_complexity_fn -%}\n  static Maybe<double> GetComputeComplexity(user_op::ComputeComplexityFnContext* ctx);\n  {% endif -%}\n  {% if op.has_input_arg_modify_fn -%}\n  static Maybe<void> ModifyInputArg(const GetInputArgModifier&, const user_op::UserOpConfWrapper&);\n  {% endif -%}\n  {% if op.has_output_arg_modify_fn -%}\n  static Maybe<void> ModifyOutputArg(const GetOutputArgModifier&, const user_op::UserOpConfWrapper&);\n  {% endif -%}\n  {% if op.has_output_blob_time_shape_infer_fn -%}\n  static Maybe<void> InferOutputBlobTimeShape(user_op::InferOutputBlobTimeShapeFnContext* ctx);\n  {% endif -%}\n  {% if op.has_check_fn -%}\n  static Maybe<void> CheckAttr(const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper&);\n  {% endif -%}\n\n  {% for attr in op.attrs -%}\n  virtual const {{attr.type}}& {{attr.name}}() 
const = 0;\n  virtual {{attr.type}}* mutable_{{attr.name}}() = 0;\n  virtual void set_{{attr.name}}(const {{attr.type}}& {{attr.name}}) = 0;\n\n  {% endfor -%}\n  static const HashSet<std::string>& AttrNames();\n};\n\nnamespace schema {\nclass {{opname}} : public oneflow::{{opname}} {\n public:\n  {% for attr in op.attrs -%}\n  const {{attr.type}}& {{attr.name}}() const override { return {{attr.name}}_; }\n  {{attr.type}}* mutable_{{attr.name}}() override { return &{{attr.name}}_; }\n  void set_{{attr.name}}(const {{attr.type}}& {{attr.name}}) override { {{attr.name}}_ = {{attr.name}}; }\n\n  {% endfor -%}\n\n  Maybe<AttrVal> Attr(const std::string& attr_name) const override;\n\n private:\n  {% for attr in op.attrs -%}\n  {{attr.type}} {{attr.name}}_{% if existsIn(attr, \"default\") %} = {{attr.default}}{% endif %};\n  {% endfor %}\n};\n}  // namespace schema\n{% endfor %}\n} // namespace oneflow\n)OP_SCHEMA_INC\"\n"
  },
  {
    "path": "tools/oneflow-tblgen/op_schema_source.inc",
    "content": "R\"OP_SCHEMA_INC(\n{% if include != \"\" %}#include \"{{ include }}\"\n{% else if filename != \"\" %}#include \"{{ to_header(filename) }}\"\n{% endif %}\n#include \"oneflow/core/common/auto_registration_factory.h\"\n#include \"oneflow/core/framework/attr_value.h\"\n#include \"oneflow/core/framework/nd_sbp.h\"\n#include \"oneflow/core/framework/infer_nd_sbp_fn_context.h\"\n#include \"oneflow/core/framework/user_op_registry_manager.h\"\n#include <complex>\n\nnamespace oneflow {\n\n#define REGISTER_OP_SCHEMA(op_type, schema) \\\n  REGISTER_CLASS_CREATOR(std::string, op_type, OpDefinitionBase, ([]() { return new schema; }))\n\n{% for opname, op in ops %}\n/*static*/ const HashSet<std::string>& {{opname}}::AttrNames() {\n  static const HashSet<std::string> attr_names = { {%- for attr in op.attrs -%}\"{{attr.name}}\", {%- endfor -%} };\n  return attr_names;\n}\n\nnamespace schema {\nMaybe<AttrVal> {{opname}}::Attr(const std::string& attr_name) const {\n  {% for attr in op.attrs %}if(attr_name == \"{{attr.name}}\") {\n    return CastAttrValue(&{{attr.name}}_);\n  }\n  {% endfor -%}\n  return Error::RuntimeError() << \"{{op.name}} op has no attribute named \" << attr_name;\n}\n}  // namespace schema\n\nREGISTER_OP_SCHEMA(\"user.{{op.name}}\", schema::{{opname}});\n\nREGISTER_USER_OP(\"{{op.name}}\")\n{%- if op.input -%}\n{%- for input in op.input -%}\n{%- if input.is_optional -%}\n    .OptionalInput(\"{{input.name}}\")\n{%- else -%}\n    .Input(\"{{input.name}}\")\n{%- endif -%}\n{%- endfor -%}\n{%- endif -%}\n{%- if op.output -%}\n{%- for output in op.output -%}\n{%- if output.is_optional -%}\n    .OptionalOutput(\"{{output.name}}\")\n{%- else -%}\n    .Output(\"{{output.name}}\")\n{%- endif -%}\n{%- endfor -%}\n{%- endif -%}\n\n{%- for attr in op.attrs -%}\n{%- if existsIn(attr, \"default\") -%}\n    .Attr<{{attr.type}}>(\"{{attr.name}}\", {{attr.default}})\n{%- else -%}\n    .Attr<{{attr.type}}>(\"{{attr.name}}\")\n{%- endif -%}\n{%- endfor -%}\n{%- if op.cpu_only -%}\n    .SupportCpuOnly()\n{%- endif -%}\n{%- if op.no_grad -%}\n    .NoGrad()\n{%- endif -%}\n{%- if op.support_non_contiguous -%}\n    .SupportNonContiguous()\n{%- endif -%}\n{%- if op.same_output_regst_num != -1 -%}\n    .SetOutputBufferNum({{op.same_output_regst_num}})\n{%- endif -%}\n{%- if op.has_nd_sbp_infer_fn -%}\n    .SetNdSbpInferFn(&{{opname}}::InferNdSbp)\n{%- endif -%}\n{%- if op.has_get_sbp_fn -%}\n    .SetGetSbpFn(&{{opname}}::GetSbp)\n{%- endif -%}\n{%- if op.has_get_nd_sbp_fn -%}\n    .SetGetNdSbpSignatureListFn(&{{opname}}::GetNdSbpSignatureList)\n{%- endif -%}\n{%- if op.has_enumerate_nd_sbp_signatures_fn -%}\n    .SetEnumerateNdSbpSignaturesFn(&{{opname}}::EnumerateNdSbpSignatures)\n{%- endif -%}\n{%- if op.has_dump_nd_sbp_signature_for_op_conf_fn -%}\n    .SetDumpNdSbpSignatureForOpConfFn(&{{opname}}::DumpNdSbpSignatureForOpConfFn)\n{%- endif -%}\n{%- if op.has_compute_complexity_fn -%}\n    .SetComputeComplexityFn(&{{opname}}::GetComputeComplexity)\n{%- endif -%}\n{%- if op.has_logical_tensor_desc_infer_fn -%}\n    .SetLogicalTensorDescInferFn(&{{opname}}::InferLogicalTensorDesc)\n{%- endif -%}\n{%- if op.has_physical_tensor_desc_infer_fn -%}\n    .SetPhysicalTensorDescInferFn(&{{opname}}::InferPhysicalTensorDesc)\n{%- endif -%}\n{%- if op.has_data_type_infer_fn -%}\n    .SetDataTypeInferFn(&{{opname}}::InferDataType)\n{%- endif -%}\n{%- if op.has_device_and_stream_infer_fn -%}\n    .SetDeviceAndStreamInferFn(&{{opname}}::InferDeviceAndStream)\n{%- endif -%}\n{%- if op.has_sbp_signature_infer_fn 
-%}\n    .SetSbpSignatureInferFn(&{{opname}}::InferSbpSignature)\n{% endif -%}\n{%- if op.has_input_arg_modify_fn -%}\n    .SetInputArgModifyFn(&{{opname}}::ModifyInputArg)\n{%- endif -%}\n{%- if op.has_output_arg_modify_fn -%}\n    .SetOutputArgModifyFn(&{{opname}}::ModifyOutputArg)\n{%- endif -%}\n{%- if op.has_output_blob_time_shape_infer_fn -%}\n    .SetOutputBlobTimeShapeInferFn(&{{opname}}::InferOutputBlobTimeShape)\n{%- endif -%}\n{%- if op.has_check_fn -%}\n    .SetCheckAttrFn(&{{opname}}::CheckAttr)\n{%- endif -%}\n;\n{%- endfor %}\n} // namespace oneflow\n)OP_SCHEMA_INC\"\n"
  },
  {
    "path": "tools/oneflow-tblgen/op_schema_types.inc",
    "content": "OP_SCHEMA(SI32Attr, int32_t)\nOP_SCHEMA(SI64Attr, int64_t)\nOP_SCHEMA(BoolAttr, bool)\nOP_SCHEMA(F32Attr, float)\nOP_SCHEMA(F64Attr, double)\nOP_SCHEMA(StrAttr, std::string)\nOP_SCHEMA(ShapeAttr, Shape)\nOP_SCHEMA(OneFlow_DataType, DataType)\nOP_SCHEMA(OneFlow_MemoryFormat, MemoryFormat)\nOP_SCHEMA(SI32ArrayAttr, std::vector<std::int32_t>)\nOP_SCHEMA(SI64ArrayAttr, std::vector<std::int64_t>)\nOP_SCHEMA(F32ArrayAttr, std::vector<float>)\nOP_SCHEMA(DTArrayAttr, std::vector<DataType>)\nOP_SCHEMA(ShapeArrayAttr, std::vector<Shape>)\nOP_SCHEMA(StrArrayAttr, std::vector<std::string>)\nOP_SCHEMA(ComplexDoubleAttr, std::complex<double>)\nOP_SCHEMA(BytesAttr, std::vector<char>)\n"
  },
  {
    "path": "tools/oneflow-tblgen/tablegen.cpp",
    "content": "/*\nCopyright 2020 The OneFlow Authors. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"llvm/Support/CommandLine.h\"\n#include \"llvm/Support/InitLLVM.h\"\n#include \"llvm/TableGen/Main.h\"\n#include \"llvm/TableGen/Record.h\"\n#include \"llvm/TableGen/SetTheory.h\"\n\n#include \"backends.h\"\n\nusing namespace llvm;\nusing namespace oneflow::tblgen;\n\nenum ActionType {\n  PrintRecords,\n  PrintDetailedRecords,\n  NullBackend,\n  DumpJSON,\n  PrintEnums,\n  PrintSets,\n  GenOpSchemaHeader,\n  GenOpSchemaSource,\n};\n\nnamespace llvm {\ncl::opt<bool> EmitLongStrLiterals(\n    \"long-string-literals\",\n    cl::desc(\"when emitting large string tables, prefer string literals over \"\n             \"comma-separated char literals. This can be a readability and \"\n             \"compile-time performance win, but upsets some compilers\"),\n    cl::Hidden, cl::init(true));\n}  // end namespace llvm\n\nnamespace {\ncl::opt<ActionType> Action(\n    cl::desc(\"Action to perform:\"),\n    cl::values(clEnumValN(PrintRecords, \"print-records\", \"Print all records to stdout (default)\"),\n               clEnumValN(PrintDetailedRecords, \"print-detailed-records\",\n                          \"Print full details of all records to stdout\"),\n               clEnumValN(NullBackend, \"null-backend\",\n                          \"Do nothing after parsing (useful for timing)\"),\n               clEnumValN(DumpJSON, \"dump-json\", \"Dump all records as machine-readable JSON\"),\n               clEnumValN(PrintEnums, \"print-enums\", \"Print enum values for a class\"),\n               clEnumValN(PrintSets, \"print-sets\", \"Print expanded sets for testing DAG exprs\"),\n               clEnumValN(GenOpSchemaHeader, \"gen-op-schema-h\",\n                          \"Generate oneflow op schema header code (.h)\"),\n               clEnumValN(GenOpSchemaSource, \"gen-op-schema-cpp\",\n                          \"Generate oneflow op schema source code (.cpp)\")));\n\ncl::OptionCategory PrintEnumsCat(\"Options for -print-enums\");\ncl::opt<std::string> Class(\"class\", cl::desc(\"Print Enum list for this class\"),\n                           cl::value_desc(\"class name\"), cl::cat(PrintEnumsCat));\n\nbool LLVMTableGenMain(raw_ostream& OS, RecordKeeper& Records) {\n  switch (Action) {\n    case PrintRecords: OS << Records; break;\n    case PrintDetailedRecords: EmitDetailedRecords(Records, OS); break;\n    case NullBackend: break;\n    case DumpJSON: EmitJSON(Records, OS); break;\n    case PrintEnums: {\n      for (Record* Rec : Records.getAllDerivedDefinitions(Class)) OS << Rec->getName() << \", \";\n      OS << \"\\n\";\n      break;\n    }\n    case PrintSets: {\n      SetTheory Sets;\n      Sets.addFieldExpander(\"Set\", \"Elements\");\n      for (Record* Rec : Records.getAllDerivedDefinitions(\"Set\")) {\n        OS << Rec->getName() << \" = [\";\n        const std::vector<Record*>* Elts = Sets.expand(Rec);\n        assert(Elts && \"Couldn't expand Set instance\");\n        
for (Record* Elt : *Elts) OS << ' ' << Elt->getName();\n        OS << \" ]\\n\";\n      }\n      break;\n    }\n    case GenOpSchemaHeader: EmitOpSchemaHeader(Records, OS); break;\n    case GenOpSchemaSource: EmitOpSchemaSource(Records, OS); break;\n  }\n\n  return false;\n}\n}  // namespace\n\nint main(int argc, char** argv) {\n  InitLLVM X(argc, argv);\n  cl::ParseCommandLineOptions(argc, argv);\n\n  return TableGenMain(argv[0], &LLVMTableGenMain);\n}\n"
  },
  {
    "path": "tools/oss_file_exist.py",
    "content": "import os\nimport oss2\n\n\ndef check_existence(endpoint, bucket, path):\n    ki = os.getenv(\"OSS_ACCESS_KEY_ID\")\n    ks = os.getenv(\"OSS_ACCESS_KEY_SECRET\")\n    auth = oss2.Auth(ki, ks)\n    bucket_obj = oss2.Bucket(auth, endpoint, bucket)\n    files = bucket_obj.list_objects(path)\n    file_cnt = 0\n    for f in files.object_list:\n        file_cnt += 1\n    is_existed = bucket_obj.object_exists(path) or file_cnt > 0\n    if is_existed:\n        print(\"export OSS_FILE_EXISTED=1\")\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\n        \"-e\",\n        \"--endpoint\",\n        type=str,\n        required=False,\n        default=\"oss-cn-beijing.aliyuncs.com\",\n    )\n    parser.add_argument(\"--bucket\", type=str, required=True)\n    parser.add_argument(\"--path\", type=str, required=True)\n\n    args = parser.parse_args()\n\n    check_existence(args.endpoint, args.bucket, args.path)\n"
  },
  {
    "path": "tools/package_mirror.py",
    "content": "import glob\nimport argparse\nimport os\nimport re\nfrom urllib.parse import urlparse\nimport hashlib\nimport base64\nimport tempfile\nimport subprocess\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"-i\", \"--src_path\", type=str, required=False)\nparser.add_argument(\"-u\", \"--url\", type=str, required=False)\nargs = parser.parse_args()\n\n\ndef glob_by_pattern(dir_path, pattern):\n    result = []\n    for x in glob.glob(os.path.join(dir_path, pattern), recursive=True):\n        result.append(x)\n    return result\n\n\ndef scan_urls(dir_path):\n    cmakes = glob_by_pattern(dir_path, \"**/*.cmake\")\n    cmakes += glob_by_pattern(dir_path, \"**/*.bzl\")\n    cmakes += glob_by_pattern(dir_path, \"**/CMakeLists.txt\")\n    urls = []\n    for cmake_path in cmakes:\n        with open(cmake_path) as f:\n            content = f.read()\n            urls += re.findall(r'https?://[^\\s<>\"\\)]+|www\\.[^\\s<>\"]+', content)\n    return urls\n\n\ndef convert_url_to_oss_key(url):\n    parsed = urlparse(url)\n    assert parsed.scheme == \"https\", url\n    assert not parsed.params\n    assert not parsed.query\n    assert not parsed.port\n    assert not parsed.fragment\n    assert parsed.path.startswith(\"/\")\n    path = parsed.path[1::]\n    ret = os.path.join(\"third_party_mirror\", parsed.scheme, parsed.netloc, path)\n    assert convert_url_to_oss_key1(url) == ret\n    return ret\n\n\ndef convert_url_to_oss_key1(url):\n    path = url[len(\"https://\") : :]\n    return \"/\".join([\"third_party_mirror\", \"https\", path])\n\n\ndef convert_url_to_oss_https_url(url):\n    if should_be_mirrored(url):\n        key = convert_url_to_oss_key(url)\n        return \"https://oneflow-static.oss-cn-beijing.aliyuncs.com/\" + key\n    else:\n        return url\n\n\ndef should_be_mirrored(url: str):\n    parsed = urlparse(url)\n    return (\n        not parsed.port\n        and not parsed.query\n        and not parsed.params\n        and url.endswith((\"gz\", \"tar\", \"zip\", \"xz\"))\n        and not \"mirror.tensorflow.org\" in url\n        and not \"mirror.bazel.build\" in url\n        and not \"aliyuncs.com\" in url\n        and not \"file:\" in url\n    )\n\n\ndef calculate_data_md5(data):\n    md5 = hashlib.md5()\n    md5.update(data)\n    digest = md5.digest()\n    return base64.b64encode(digest)\n\n\ndef upload_one_to_aliyun(url: str):\n    ki = os.getenv(\"OSS_ACCESS_KEY_ID\")\n    ks = os.getenv(\"OSS_ACCESS_KEY_SECRET\")\n    import oss2\n\n    auth = oss2.Auth(ki, ks)\n    endpoint = \"oss-cn-beijing.aliyuncs.com\"\n    bucket = oss2.Bucket(auth, endpoint, \"oneflow-static\")\n    key = convert_url_to_oss_key(url)\n\n    if bucket.object_exists(key):\n        print(\"exists: \", key)\n    else:\n        d = tempfile.gettempdir()\n        dst = os.path.join(d, os.path.basename(key))\n        if os.path.isdir(dst):\n            raise ValueError(\"must not be a dir\", dst)\n        else:\n            if os.path.isfile(dst):\n                print(\"[removing]\", dst)\n                os.remove(dst)\n        subprocess.check_call(f\"wget {url} -O {dst}\", shell=True)\n        bucket.put_object_from_file(key, dst)\n\n\ndef upload_to_aliyun(dir_path):\n    urls = scan_urls(dir_path)\n    for url in urls:\n        if should_be_mirrored(url):\n            print(\"mirroring: \", url)\n            upload_one_to_aliyun(url)\n        else:\n            print(\"skipped: \", url)\n            continue\n\n\nif __name__ == \"__main__\":\n    if args.src_path != None:\n        
upload_to_aliyun(args.src_path)\n    if args.url != None:\n        oss_url = convert_url_to_oss_https_url(args.url)\n        print(oss_url, end=\"\")\n"
  }
]